cosmos-sdk/depinject/container.go

563 lines
14 KiB
Go

package depinject
import (
"bytes"
"fmt"
"reflect"
"github.com/pkg/errors"
"github.com/cosmos/cosmos-sdk/depinject/internal/graphviz"
)
// container holds all state accumulated while providers, supplied values,
// interface bindings and invokers are registered, and consulted when values
// are resolved during build.
type container struct {
// Embedded debug configuration. NOTE(review): logf, indentLogger, dedentLogger,
// locationGraphNode, typeGraphNode and addGraphEdge are called on the container
// but are not defined in this section; presumably promoted from here — confirm.
*debugConfig
// resolvers maps fullyQualifiedTypeName(type) to the resolver providing that type.
resolvers map[string]resolver
// interfaceBindings maps bindingKeyFromTypeName(interfaceName, moduleKey) to the
// explicit binding registered via addBinding.
interfaceBindings map[string]interfaceBinding
// invokers are run, in registration order, at the end of build.
invokers []invoker
// moduleKeys holds one *moduleKey per module name (see createOrGetModuleKey).
moduleKeys map[string]*moduleKey
// resolveStack receives a frame (caller location, requested type) from each
// resolve call; formatResolveStack renders it into error messages.
resolveStack []resolveFrame
// callerStack and callerMap record the providers whose inputs call is currently
// resolving; meeting one of them again is reported as a cyclic dependency.
callerStack []Location
callerMap map[Location]bool
}
// invoker is a provider queued by addInvoker together with the module scope
// (nil for the global scope) in which build later calls it.
type invoker struct {
fn *ProviderDescriptor
modKey *moduleKey
}
// resolveFrame records one resolution request: the location that asked for a
// value and the type it asked for.
type resolveFrame struct {
loc Location
typ reflect.Type
}
// interfaceBinding binds the interface named interfaceName to the implementation
// type named implTypeName when the interface is requested as a dependency by the
// module identified by moduleKey. If moduleKey is nil the binding applies
// globally rather than to a single module.
type interfaceBinding struct {
interfaceName string
implTypeName string
moduleKey *moduleKey
// resolver caches the implementation's resolver; when it is non-nil,
// getExplicitResolver returns it directly instead of looking it up by name.
resolver resolver
}
// newContainer returns an empty container that embeds cfg as its debug
// configuration. Every lookup map is allocated up front so resolvers, bindings,
// module keys and in-progress callers can be recorded immediately; the slices
// (invokers, resolveStack, callerStack) start out nil and grow on append.
func newContainer(cfg *debugConfig) *container {
	c := &container{debugConfig: cfg}
	c.resolvers = make(map[string]resolver)
	c.moduleKeys = make(map[string]*moduleKey)
	c.interfaceBindings = make(map[string]interfaceBinding)
	c.callerMap = make(map[Location]bool)
	return c
}
// call resolves every input of provider from the scope of moduleKey (nil for
// the global scope), invokes provider.Fn with the resolved values and returns
// its outputs.
//
// The provider's graph node is marked as failed up front and re-marked as used
// only after Fn has returned successfully, so an aborted call shows up as a
// failure in the debug graph.
//
// Re-entering call for a location whose inputs are still being resolved means
// the dependency graph is cyclic; the error lists the chain of providers that
// forms the cycle. The in-progress bookkeeping (callerMap, callerStack) and the
// logger indentation are restored even when input resolution fails.
func (c *container) call(provider *ProviderDescriptor, moduleKey *moduleKey) ([]reflect.Value, error) {
	loc := provider.Location
	graphNode := c.locationGraphNode(loc, moduleKey)
	markGraphNodeAsFailed(graphNode)

	if c.callerMap[loc] {
		// loc is already resolving its inputs further up the stack; report the
		// providers from that earlier occurrence back to loc.
		cycle := ""
		for i, pending := range c.callerStack {
			if pending == loc {
				for _, l := range c.callerStack[i:] {
					cycle += l.Name() + " -> "
				}
				break
			}
		}
		return nil, errors.Errorf("cyclic dependency: %s%s", cycle, loc.Name())
	}

	inVals, err := func() ([]reflect.Value, error) {
		c.callerMap[loc] = true
		c.callerStack = append(c.callerStack, loc)
		c.logf("Resolving dependencies for %s", loc)
		c.indentLogger()
		// Undo the in-progress bookkeeping on every exit path, including
		// resolution errors, so loc is never left marked as pending.
		defer func() {
			c.dedentLogger()
			delete(c.callerMap, loc)
			c.callerStack = c.callerStack[:len(c.callerStack)-1]
		}()

		vals := make([]reflect.Value, len(provider.Inputs))
		for i, in := range provider.Inputs {
			val, err := c.resolve(in, moduleKey, loc)
			if err != nil {
				return nil, err
			}
			vals[i] = val
		}
		return vals, nil
	}()
	if err != nil {
		return nil, err
	}

	c.logf("Calling %s", loc)
	out, err := provider.Fn(inVals)
	if err != nil {
		return nil, errors.Wrapf(err, "error calling provider %s", loc)
	}

	markGraphNodeAsUsed(graphNode)
	return out, nil
}
// getResolver returns the resolver responsible for typ when it is requested
// from the scope of key (nil means the global scope). Candidates are tried in
// this order:
//
//  1. an explicit interface binding (module-scoped first, then global);
//  2. a resolver already registered for typ;
//  3. for a many-per-container or one-per-module type (or its []T /
//     map[string]T collection type), a group resolver registered on demand;
//  4. for an interface type, the single registered concrete type implementing
//     it, which is then also registered under the interface type.
//
// A nil resolver with a nil error means nothing can provide typ. An error is
// returned when an explicit binding cannot be satisfied or when more than one
// registered concrete type implicitly satisfies the interface.
func (c *container) getResolver(typ reflect.Type, key *moduleKey) (resolver, error) {
	explicit, err := c.getExplicitResolver(typ, key)
	switch {
	case err != nil:
		return nil, err
	case explicit != nil:
		return explicit, nil
	}

	if existing, ok := c.resolverByType(typ); ok {
		return existing, nil
	}

	// A []T or map[string]T collection is served by the resolver registered for T.
	elem := typ
	if isManyPerContainerSliceType(elem) || isOnePerModuleMapType(elem) {
		elem = elem.Elem()
	}

	switch {
	case isManyPerContainerType(elem):
		c.logf("Registering resolver for many-per-container type %v", elem)
		sliceType := reflect.SliceOf(elem)
		node := c.typeGraphNode(sliceType)
		node.SetComment("many-per-container")
		group := &groupResolver{
			typ:       elem,
			sliceType: sliceType,
			graphNode: node,
		}
		c.addResolver(elem, group)
		c.addResolver(sliceType, &sliceGroupResolver{group})
	case isOnePerModuleType(elem):
		c.logf("Registering resolver for one-per-module type %v", elem)
		mapType := reflect.MapOf(stringType, elem)
		node := c.typeGraphNode(mapType)
		node.SetComment("one-per-module")
		perModule := &onePerModuleResolver{
			typ:       elem,
			mapType:   mapType,
			providers: map[*moduleKey]*simpleProvider{},
			idxMap:    map[*moduleKey]int{},
			graphNode: node,
		}
		c.addResolver(elem, perModule)
		c.addResolver(mapType, &mapOfOnePerModuleResolver{perModule})
	}

	res, found := c.resolverByType(typ)
	if found || typ.Kind() != reflect.Interface {
		return res, nil
	}

	// Look for exactly one registered concrete type implementing the interface.
	candidates := map[reflect.Type]reflect.Type{}
	var candidate reflect.Type
	for _, r := range c.resolvers {
		if rt := r.getType(); rt.Kind() != reflect.Interface && rt.Implements(typ) {
			candidate = rt
			candidates[rt] = rt
		}
	}

	switch len(candidates) {
	case 0:
		return res, nil
	case 1:
		res, _ = c.resolverByType(candidate)
		c.logf("Implicitly registering resolver %v for interface type %v", candidate, typ)
		c.addResolver(typ, res)
		return res, nil
	default:
		return nil, newErrMultipleImplicitInterfaceBindings(typ, candidates)
	}
}
// getExplicitResolver returns the resolver selected for typ by an explicit
// interface binding, or (nil, nil) when no binding applies. A binding scoped to
// key takes precedence over a global (nil-key) binding.
//
// The resolver for the bound implementation type is looked up by name on first
// use and cached on the stored binding, so later lookups return it directly.
// An error is returned when no resolver exists for the bound implementation type.
func (c *container) getExplicitResolver(typ reflect.Type, key *moduleKey) (resolver, error) {
	// module scoped binding takes precedence
	bindingKey := bindingKeyFromType(typ, key)
	pref, found := c.interfaceBindings[bindingKey]

	// fallback to global scope binding
	if !found {
		bindingKey = bindingKeyFromType(typ, nil)
		pref, found = c.interfaceBindings[bindingKey]
	}

	if !found {
		return nil, nil
	}

	if pref.resolver != nil {
		return pref.resolver, nil
	}

	res, ok := c.resolverByTypeName(pref.implTypeName)
	if !ok {
		return nil, newErrNoTypeForExplicitBindingFound(pref)
	}

	c.logf("Registering resolver %v for interface type %v by explicit binding", res.getType(), typ)
	// interfaceBindings stores struct values, so the cached resolver must be
	// written back to the map; assigning to pref alone would update a copy only.
	pref.resolver = res
	c.interfaceBindings[bindingKey] = pref
	return res, nil
}
// stringType is the reflect.Type of string. It serves as the key type of the
// map[string]T collections that getResolver builds for one-per-module types.
var stringType = reflect.TypeOf((*string)(nil)).Elem()
// addNode registers provider in the container, scoped to key (nil for the
// global scope), and wires its inputs and outputs into the debug graph.
//
// Individual many-per-container or one-per-module values may not be declared
// as inputs; getResolver serves them through their []T / map[string]T
// collection types instead.
//
// A provider that takes a ModuleKey becomes a module-scoped provider and the
// returned value is a *moduleDepProvider. Each of its outputs gets its own
// moduleDepResolver and must not already be provided by anything else. Such a
// provider may not also take an OwnModuleKey.
//
// Any other provider becomes a *simpleProvider. Each output is attached to the
// existing resolver for its type, or a new simpleResolver is registered for it.
// One-per-module maps cannot be returned. A returned []T of a
// many-per-container T is registered against the resolver for T itself.
func (c *container) addNode(provider *ProviderDescriptor, key *moduleKey) (interface{}, error) {
	providerGraphNode := c.locationGraphNode(provider.Location, key)
	hasModuleKeyParam := false
	hasOwnModuleKeyParam := false
	for _, in := range provider.Inputs {
		typ := in.Type
		if typ == moduleKeyType {
			hasModuleKeyParam = true
		}
		if typ == ownModuleKeyType {
			hasOwnModuleKeyParam = true
		}

		if isManyPerContainerType(typ) {
			return nil, fmt.Errorf("many-per-container type %v can't be used as an input parameter", typ)
		}
		if isOnePerModuleType(typ) {
			return nil, fmt.Errorf("one-per-module type %v can't be used as an input parameter", typ)
		}

		vr, err := c.getResolver(typ, key)
		if err != nil {
			return nil, err
		}

		// Link the input's type node to this provider; if nothing resolves the
		// type yet, a node for the bare type is used instead.
		var typeGraphNode *graphviz.Node
		if vr != nil {
			typeGraphNode = vr.typeGraphNode()
		} else {
			typeGraphNode = c.typeGraphNode(typ)
		}
		c.addGraphEdge(typeGraphNode, providerGraphNode)
	}

	if hasModuleKeyParam {
		if hasOwnModuleKeyParam {
			return nil, errors.Errorf("%T and %T must not be declared as dependencies on the same provider",
				ModuleKey{}, OwnModuleKey{})
		}

		c.logf("Registering module-scoped provider: %s", provider.Location.String())
		c.indentLogger()
		defer c.dedentLogger()

		node := &moduleDepProvider{
			provider:        provider,
			calledForModule: map[*moduleKey]bool{},
			valueMap:        map[*moduleKey][]reflect.Value{},
		}

		for i, out := range provider.Outputs {
			typ := out.Type
			c.logf("Registering resolver for module-scoped type %v", typ)
			// Module-scoped outputs must be provided exclusively by this provider.
			if existing, ok := c.resolverByType(typ); ok {
				return nil, errors.Errorf("duplicate provision of type %v by module-scoped provider %s\n\talready provided by %s",
					typ, provider.Location, existing.describeLocation())
			}

			typeGraphNode := c.typeGraphNode(typ)
			c.addResolver(typ, &moduleDepResolver{
				typ:         typ,
				idxInValues: i,
				node:        node,
				valueMap:    map[*moduleKey]reflect.Value{},
				graphNode:   typeGraphNode,
			})
			c.addGraphEdge(providerGraphNode, typeGraphNode)
		}
		return node, nil
	}

	c.logf("Registering %s", provider.Location.String())
	c.indentLogger()
	defer c.dedentLogger()

	sp := &simpleProvider{
		provider:  provider,
		moduleKey: key,
	}

	for i, out := range provider.Outputs {
		typ := out.Type

		// one-per-module maps can't be used as a return type
		if isOnePerModuleMapType(typ) {
			return nil, fmt.Errorf("%v cannot be used as a return type because %v is a one-per-module type",
				typ, typ.Elem())
		}

		// many-per-container slices of many-per-container types are registered
		// against the resolver for the element type
		if isManyPerContainerSliceType(typ) {
			typ = typ.Elem()
		}

		vr, err := c.getResolver(typ, key)
		if err != nil {
			return nil, err
		}

		if vr != nil {
			c.logf("Found resolver for %v: %T", typ, vr)
			if err := vr.addNode(sp, i); err != nil {
				return nil, err
			}
		} else {
			c.logf("Registering resolver for simple type %v", typ)
			vr = &simpleResolver{
				node:        sp,
				typ:         typ,
				graphNode:   c.typeGraphNode(typ),
				idxInValues: i,
			}
			c.addResolver(typ, vr)
		}
		c.addGraphEdge(providerGraphNode, vr.typeGraphNode())
	}
	return sp, nil
}
// supply registers value as the sole provider of its dynamic type, attributed
// to location. The supplying location is drawn into the debug graph (and marked
// as used) before the duplicate check, so a rejected supply still appears there.
// It returns an error if a resolver for that type is already registered.
func (c *container) supply(value reflect.Value, location Location) error {
	valueType := value.Type()

	locationNode := c.locationGraphNode(location, nil)
	markGraphNodeAsUsed(locationNode)
	valueNode := c.typeGraphNode(valueType)
	c.addGraphEdge(locationNode, valueNode)

	existing, alreadyProvided := c.resolverByType(valueType)
	if alreadyProvided {
		return duplicateDefinitionError(valueType, location, existing.describeLocation())
	}

	c.addResolver(valueType, &supplyResolver{
		typ:       valueType,
		value:     value,
		loc:       location,
		graphNode: valueNode,
	})
	return nil
}
// addInvoker queues provider to be called, scoped to key, once build has
// resolved the requested outputs. Invokers are called only for their effects,
// so a provider that declares any outputs is rejected.
func (c *container) addInvoker(provider *ProviderDescriptor, key *moduleKey) error {
	if n := len(provider.Outputs); n != 0 {
		return fmt.Errorf("invoker function %s should not return any outputs", provider.Location)
	}
	queued := invoker{fn: provider, modKey: key}
	c.invokers = append(c.invokers, queued)
	return nil
}
// resolve produces the value for the input in, as requested by caller from the
// scope of moduleKey (nil for the global scope).
//
// ModuleKey and OwnModuleKey are synthesized from moduleKey and require a
// module scope. Any other type is obtained from the resolver returned by
// getResolver. An optional input without a resolver yields the zero value of
// its type. A required input without a resolver is an error that includes the
// current resolve stack.
//
// A frame for (caller, in.Type) sits on c.resolveStack for the duration of the
// call, so errors raised by nested resolutions can describe the request chain.
func (c *container) resolve(in ProviderInput, moduleKey *moduleKey, caller Location) (reflect.Value, error) {
	c.resolveStack = append(c.resolveStack, resolveFrame{loc: caller, typ: in.Type})
	// Pop the frame on every exit path. Error messages that embed the stack are
	// formatted while evaluating the return values, i.e. before this runs.
	defer func() {
		c.resolveStack = c.resolveStack[:len(c.resolveStack)-1]
	}()

	typeGraphNode := c.typeGraphNode(in.Type)

	if in.Type == moduleKeyType {
		if moduleKey == nil {
			return reflect.Value{}, errors.Errorf("trying to resolve %v for %s but not inside of any module's scope", in.Type, caller)
		}
		c.logf("Providing ModuleKey %s", moduleKey.name)
		markGraphNodeAsUsed(typeGraphNode)
		return reflect.ValueOf(ModuleKey{moduleKey}), nil
	}

	if in.Type == ownModuleKeyType {
		if moduleKey == nil {
			return reflect.Value{}, errors.Errorf("trying to resolve %v for %s but not inside of any module's scope", in.Type, caller)
		}
		c.logf("Providing OwnModuleKey %s", moduleKey.name)
		markGraphNodeAsUsed(typeGraphNode)
		return reflect.ValueOf(OwnModuleKey{moduleKey}), nil
	}

	vr, err := c.getResolver(in.Type, moduleKey)
	if err != nil {
		return reflect.Value{}, err
	}

	if vr == nil {
		if in.Optional {
			c.logf("Providing zero value for optional dependency %v", in.Type)
			return reflect.Zero(in.Type), nil
		}
		markGraphNodeAsFailed(typeGraphNode)
		return reflect.Value{}, errors.Errorf("can't resolve type %v for %s:\n%s",
			fullyQualifiedTypeName(in.Type), caller, c.formatResolveStack())
	}

	res, err := vr.resolve(c, moduleKey, caller)
	if err != nil {
		markGraphNodeAsFailed(typeGraphNode)
		return reflect.Value{}, err
	}

	markGraphNodeAsUsed(typeGraphNode)
	return res, nil
}
// build resolves the value pointed to by each element of outputs, stores it
// through that pointer, and then calls every registered invoker in
// registration order. loc identifies the caller in log output, errors and the
// debug graph, where it is drawn as a hexagon.
//
// Every element of outputs must be a non-nil pointer. The synthesized
// descriptor is passed through expandStructArgsProvider before registration.
// Errors from registration, resolution or any invoker are returned unchanged.
func (c *container) build(loc Location, outputs ...interface{}) error {
	var providerIn []ProviderInput
	for _, output := range outputs {
		typ := reflect.TypeOf(output)
		// reflect.TypeOf(nil) returns a nil Type; calling Kind on it would panic.
		if typ == nil {
			return fmt.Errorf("output type must be a pointer, nil is invalid")
		}
		if typ.Kind() != reflect.Pointer {
			return fmt.Errorf("output type must be a pointer, %s is invalid", typ)
		}
		// A typed nil pointer passes the kind check but cannot be stored through.
		if reflect.ValueOf(output).IsNil() {
			return fmt.Errorf("output must be a non-nil pointer, got nil %s", typ)
		}
		providerIn = append(providerIn, ProviderInput{Type: typ.Elem()})
	}

	desc := ProviderDescriptor{
		Inputs:  providerIn,
		Outputs: nil,
		Fn: func(values []reflect.Value) ([]reflect.Value, error) {
			if len(values) != len(outputs) {
				return nil, fmt.Errorf("internal error, unexpected number of values")
			}
			// Store each resolved value through the caller's pointer.
			for i, output := range outputs {
				reflect.ValueOf(output).Elem().Set(values[i])
			}
			return nil, nil
		},
		Location: loc,
	}

	callerGraphNode := c.locationGraphNode(loc, nil)
	callerGraphNode.SetShape("hexagon")

	desc, err := expandStructArgsProvider(desc)
	if err != nil {
		return err
	}

	c.logf("Registering outputs")
	c.indentLogger()
	node, err := c.addNode(&desc, nil)
	// Restore indentation before checking the error so a failed registration
	// does not leave the logger indented.
	c.dedentLogger()
	if err != nil {
		return err
	}

	sn, ok := node.(*simpleProvider)
	if !ok {
		return errors.Errorf("cannot run module-scoped provider as an invoker")
	}

	c.logf("Building container")
	if _, err := sn.resolveValues(c); err != nil {
		return err
	}
	c.logf("Done building container")

	c.logf("Calling invokers")
	for _, inv := range c.invokers {
		if _, err := c.call(inv.fn, inv.modKey); err != nil {
			return err
		}
	}
	c.logf("Done calling invokers")
	return nil
}
// createOrGetModuleKey returns the *moduleKey registered under name. On first
// request it creates one and caches it in c.moduleKeys, so every lookup of the
// same name yields the same pointer.
//
// It uses a pointer receiver like the container's other methods, rather than
// copying the whole struct on each call.
func (c *container) createOrGetModuleKey(name string) *moduleKey {
	if existing, ok := c.moduleKeys[name]; ok {
		return existing
	}
	created := &moduleKey{name: name}
	c.moduleKeys[name] = created
	return created
}
// formatResolveStack renders c.resolveStack as a tab-indented "while resolving:"
// section for error messages. The most recent (innermost) request is listed
// first, one "<type> for <location>" line per frame.
//
// It uses a pointer receiver like the container's other methods, rather than
// copying the whole struct on each call.
func (c *container) formatResolveStack() string {
	buf := &bytes.Buffer{}
	// Writes to a bytes.Buffer always return a nil error, so ignoring it is safe.
	_, _ = fmt.Fprintf(buf, "\twhile resolving:\n")
	for i := len(c.resolveStack) - 1; i >= 0; i-- {
		frame := c.resolveStack[i]
		_, _ = fmt.Fprintf(buf, "\t\t%v for %s\n", frame.typ, frame.loc)
	}
	return buf.String()
}
// fullyQualifiedTypeName returns the key under which typ is registered in the
// container, formatted as "<package path>/<type string>". Predeclared types
// have an empty package path, e.g. "/string".
//
// For a pointer, slice, map or array, the package path is first taken from the
// element type, as it always has been. If that element type is itself an
// unnamed composite (pointer, slice, array, map or chan) with no package path,
// the element chain is followed further until a type with a package path (or a
// non-composite type) is reached. Without this, types such as
// []*text/template.Template and []*html/template.Template, which share the type
// string "[]*template.Template", would both map to "/[]*template.Template" and
// collide.
func fullyQualifiedTypeName(typ reflect.Type) string {
	pkgType := typ
	if typ.Kind() == reflect.Pointer || typ.Kind() == reflect.Slice || typ.Kind() == reflect.Map || typ.Kind() == reflect.Array {
		pkgType = typ.Elem()
	}
	for pkgType.PkgPath() == "" {
		k := pkgType.Kind()
		if k != reflect.Pointer && k != reflect.Slice && k != reflect.Map && k != reflect.Array && k != reflect.Chan {
			break
		}
		pkgType = pkgType.Elem()
	}
	return fmt.Sprintf("%s/%v", pkgType.PkgPath(), typ)
}
// bindingKeyFromTypeName builds the interfaceBindings key for typeName. The key
// is "<typeName>;" for a global binding (nil key) and "<typeName>;<module name>"
// for a binding scoped to key.
func bindingKeyFromTypeName(typeName string, key *moduleKey) string {
	scope := ""
	if key != nil {
		scope = key.name
	}
	return typeName + ";" + scope
}
// bindingKeyFromType returns the interfaceBindings key for typ in the scope of
// key, using typ's fully qualified name as the type-name component.
func bindingKeyFromType(typ reflect.Type, key *moduleKey) string {
	name := fullyQualifiedTypeName(typ)
	return bindingKeyFromTypeName(name, key)
}
// addBinding stores p under the key derived from its interface name and module
// scope. It replaces any earlier binding for the same interface and scope.
func (c *container) addBinding(p interfaceBinding) {
	k := bindingKeyFromTypeName(p.interfaceName, p.moduleKey)
	c.interfaceBindings[k] = p
}
// addResolver registers r for typ under fullyQualifiedTypeName(typ). It
// overwrites any resolver already stored under that key.
func (c *container) addResolver(typ reflect.Type, r resolver) {
	name := fullyQualifiedTypeName(typ)
	c.resolvers[name] = r
}
// resolverByType looks up the resolver registered for typ. The boolean reports
// whether one exists.
func (c *container) resolverByType(typ reflect.Type) (resolver, bool) {
	name := fullyQualifiedTypeName(typ)
	return c.resolverByTypeName(name)
}
// resolverByTypeName looks up a resolver by fully qualified type name. It
// returns (nil, false) when none is registered.
func (c *container) resolverByTypeName(typeName string) (resolver, bool) {
	if r, ok := c.resolvers[typeName]; ok {
		return r, true
	}
	return nil, false
}
// markGraphNodeAsUsed styles node as successfully used: black outline, a
// pen width of 1.5 and a black label.
func markGraphNodeAsUsed(node *graphviz.Node) {
	const black = "black"
	node.SetColor(black)
	node.SetPenWidth("1.5")
	node.SetFontColor(black)
}
// markGraphNodeAsFailed styles node as failed: red outline and red label. The
// pen width is left as it is.
func markGraphNodeAsFailed(node *graphviz.Node) {
	const red = "red"
	node.SetColor(red)
	node.SetFontColor(red)
}