package container import ( "reflect" "github.com/pkg/errors" ) // In can be embedded in another struct to inform the container that the // fields of the struct should be treated as dependency inputs. // This allows a struct to be used to specify dependencies rather than // positional parameters. // // Fields of the struct may support the following tags: // optional if set to true, the dependency is optional and will // be set to its default value if not found, rather than causing // an error type In struct{} func (In) isIn() {} type isIn interface{ isIn() } var isInType = reflect.TypeOf((*isIn)(nil)).Elem() // Out can be embedded in another struct to inform the container that the // fields of the struct should be treated as dependency outputs. // This allows a struct to be used to specify outputs rather than // positional return values. type Out struct{} func (Out) isOut() {} type isOut interface{ isOut() } var isOutType = reflect.TypeOf((*isOut)(nil)).Elem() func expandStructArgsConstructor(constructor ProviderDescriptor) (ProviderDescriptor, error) { var foundStructArgs bool var newIn []ProviderInput for _, in := range constructor.Inputs { if in.Type.AssignableTo(isInType) { foundStructArgs = true inTypes, err := structArgsInTypes(in.Type) if err != nil { return ProviderDescriptor{}, err } newIn = append(newIn, inTypes...) } else { newIn = append(newIn, in) } } var newOut []ProviderOutput for _, out := range constructor.Outputs { if out.Type.AssignableTo(isOutType) { foundStructArgs = true newOut = append(newOut, structArgsOutTypes(out.Type)...) } else { newOut = append(newOut, out) } } if foundStructArgs { return ProviderDescriptor{ Inputs: newIn, Outputs: newOut, Fn: expandStructArgsFn(constructor), Location: constructor.Location, }, nil } return constructor, nil } func expandStructArgsFn(constructor ProviderDescriptor) func(inputs []reflect.Value) ([]reflect.Value, error) { fn := constructor.Fn inParams := constructor.Inputs outParams := constructor.Outputs return func(inputs []reflect.Value) ([]reflect.Value, error) { j := 0 inputs1 := make([]reflect.Value, len(inParams)) for i, in := range inParams { if in.Type.AssignableTo(isInType) { v, n := buildIn(in.Type, inputs[j:]) inputs1[i] = v j += n } else { inputs1[i] = inputs[j] j++ } } outputs, err := fn(inputs1) if err != nil { return nil, err } var outputs1 []reflect.Value for i, out := range outParams { if out.Type.AssignableTo(isOutType) { outputs1 = append(outputs1, extractFromOut(out.Type, outputs[i])...) } else { outputs1 = append(outputs1, outputs[i]) } } return outputs1, nil } } func structArgsInTypes(typ reflect.Type) ([]ProviderInput, error) { n := typ.NumField() var res []ProviderInput for i := 0; i < n; i++ { f := typ.Field(i) if f.Type.AssignableTo(isInType) { continue } var optional bool optTag, found := f.Tag.Lookup("optional") if found { if optTag == "true" { optional = true } else { return nil, errors.Errorf("bad optional tag %q (should be \"true\") in %v", optTag, typ) } } res = append(res, ProviderInput{ Type: f.Type, Optional: optional, }) } return res, nil } func structArgsOutTypes(typ reflect.Type) []ProviderOutput { n := typ.NumField() var res []ProviderOutput for i := 0; i < n; i++ { f := typ.Field(i) if f.Type.AssignableTo(isOutType) { continue } res = append(res, ProviderOutput{ Type: f.Type, }) } return res } func buildIn(typ reflect.Type, values []reflect.Value) (reflect.Value, int) { numFields := typ.NumField() j := 0 res := reflect.New(typ) for i := 0; i < numFields; i++ { f := typ.Field(i) if f.Type.AssignableTo(isInType) { continue } res.Elem().Field(i).Set(values[j]) j++ } return res.Elem(), j } func extractFromOut(typ reflect.Type, value reflect.Value) []reflect.Value { numFields := typ.NumField() var res []reflect.Value for i := 0; i < numFields; i++ { f := typ.Field(i) if f.Type.AssignableTo(isOutType) { continue } res = append(res, value.Field(i)) } return res }