diff --git a/factory.go b/factory.go new file mode 100644 index 0000000..758d5db --- /dev/null +++ b/factory.go @@ -0,0 +1,69 @@ +package injector + +import ( + "fmt" + "reflect" +) + +type FactoryFunction func() (interface{}, error) + +// TransientFactory is a function wrapper for a transient dependency +// to provide dependency injection inside the factory function +func TransientFactory(depFactory interface{}) (FactoryFunction, error) { + depFactoryType := reflect.TypeOf(depFactory) + depFactoryValue := reflect.ValueOf(depFactory) + + if depFactoryType.Kind() != reflect.Func { + return nil, ErrorDepFactoryNotAFunc + } + + if depFactoryType.NumOut() != 1 { + return nil, ErrorDepFactoryReturnCount + } + + // Dependency-Inject arguments for factory function + args, err := injectFuncArgs(depFactory) + if err != nil { + return nil, err + } + + // Create factory wrapper to inject dependencies + return func() (dep interface{}, err error) { + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("dependency injection failed because factory paniced, recovered value: %v", r) + dep = nil + } + }() + returnVals := depFactoryValue.Call(args) + dep = returnVals[0].Interface() + return + }, nil +} + +// SingletonFactory is a function wrapper for a singleton dependency factory +// to provide dependency injection inside the factory function and to retain +// the singleton instance once instantiated. +func SingletonFactory(depFactory interface{}) (FactoryFunction, error) { + factory, err := TransientFactory(depFactory) + if err != nil { + return nil, err + } + + // Wrap factory wrapper to ensure existing singleton value is used if + // it already exists. + var singleton interface{} + return func() (interface{}, error) { + if singleton != nil { + return singleton, nil + } + + // Singleton is not ready, call transient factory to instantiate dependency + dep, err := factory() + if err != nil { + return nil, err + } + singleton = dep + return singleton, nil + }, nil +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..52af2df --- /dev/null +++ b/go.mod @@ -0,0 +1,5 @@ +module github.com/infinytum/injector + +go 1.18 + +require github.com/infinytum/structures v0.0.1 diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..6f7ff5d --- /dev/null +++ b/go.sum @@ -0,0 +1,2 @@ +github.com/infinytum/structures v0.0.1 h1:DgwnAkvodCn2Zn07/PGM+vZZIlhCZJ+602KudwbQSt8= +github.com/infinytum/structures v0.0.1/go.mod h1:4vDl7BamOX2fF0B+h2n/xGIorGv1u9ZQ9wgF6IW/Kc8= diff --git a/injector_read.go b/injector_read.go new file mode 100644 index 0000000..4a4edee --- /dev/null +++ b/injector_read.go @@ -0,0 +1,115 @@ +package injector + +import ( + "fmt" + "reflect" +) + +// Inject tries to resolve a dependency by its type and optionally its name +// If the dependency is unknown, ErrorDepFactoryNotFound is returned +func Inject[T any](name ...string) (T, error) { + argType := reflect.TypeOf((*T)(nil)).Elem() + factory := depMap.GetOrDefault(reflectTypeKey(argType), nameOrDefault(name), nil) + if factory == nil { + return reflect.Zero(argType).Interface().(T), ErrorDepFactoryNotFound + } + + dep, err := factory() + if err != nil { + return reflect.Zero(argType).Interface().(T), err + } + + castDep, ok := dep.(T) + if !ok { + return reflect.Zero(argType).Interface().(T), ErrorDependencyTypeMismatch + } + return castDep, nil +} + +// InjectT tries to resolve a dependency by its type and optionally its name +// If the dependency is unknown, ErrorDepFactoryNotFound is returned +func InjectT(depType reflect.Type, name ...string) (interface{}, error) { + factory := depMap.GetOrDefault(reflectTypeKey(depType), nameOrDefault(name), nil) + if factory == nil { + return nil, ErrorDepFactoryNotFound + } + return factory() +} + +// InjectT tries to resolve a dependency by its type and optionally its name +// If the dependency is unknown, ErrorDepFactoryNotFound is returned +func InjectInto(out interface{}, name ...string) error { + argType := reflect.TypeOf(out) + if argType.Kind() != reflect.Pointer { + return ErrorDepNotAPointer + } + factory := depMap.GetOrDefault(reflectTypeKey(argType.Elem()), nameOrDefault(name), nil) + if factory == nil { + return ErrorDepFactoryNotFound + } + dep, err := factory() + if err != nil { + return err + } + reflect.ValueOf(out).Elem().Set(reflect.ValueOf(dep)) + return nil +} + +// MustInject tries to resolve a dependency by its type and optionally its name or panics +func MustInject[T any](name ...string) T { + dep, err := Inject[T](name...) + if err != nil { + panic(err) + } + return dep +} + +// MustInjectT tries to resolve a dependency by its type and optionally its name or panics +func MustInjectT(depType reflect.Type, name ...string) interface{} { + dep, err := InjectT(depType, name...) + if err != nil { + panic(err) + } + return dep +} + +// MustInjectInto tries to resolve a dependency by its type and optionally its name or panics +func MustInjectInto(out interface{}, name ...string) { + if err := InjectInto(out, name...); err != nil { + panic(err) + } +} + +// Call will attempt to resolve all arguments of the function and then call it +func Call(fn interface{}) (err error) { + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("dependency injection failed because factory paniced, recovered value: %v", r) + } + }() + args, err := injectFuncArgs(fn) + if err != nil { + return + } + reflect.ValueOf(fn).Call(args) + return +} + +// MustCall will attempt to resolve all arguments of the function and then call it or panic +func MustCall(fn interface{}) { + if err := Call(fn); err != nil { + panic(err) + } +} + +// Fill will attempt to resolve all tagged fields of a struct with their matching dependency +func Fill(strct interface{}) error { + return injectStructFields(strct) +} + +// MustFill will attempt to resolve all tagged fields of a struct with their matching dependency or panic +func MustFill(strct interface{}) { + if err := Fill(strct); err != nil { + panic(err) + } +} diff --git a/injector_write.go b/injector_write.go new file mode 100644 index 0000000..cc4399d --- /dev/null +++ b/injector_write.go @@ -0,0 +1,32 @@ +package injector + +import ( + "reflect" + + "github.com/infinytum/structures" +) + +var ( + depMap structures.Table[string, string, FactoryFunction] = structures.NewTable[string, string, FactoryFunction]() +) + +// Singleton will register a dependency that is only instantiated once, then re-used +// for all future resolve calls +func Singleton(resolver interface{}, name ...string) error { + factory, err := SingletonFactory(resolver) + if err != nil { + return err + } + depFactoryType := reflect.TypeOf(resolver) + return depMap.Set(reflectTypeKey(depFactoryType.Out(0)), nameOrDefault(name), factory) +} + +// Transient will register a dependency that is instantiated every time it's resolved. +func Transient(resolver interface{}, name ...string) error { + factory, err := TransientFactory(resolver) + if err != nil { + return err + } + depFactoryType := reflect.TypeOf(resolver) + return depMap.Set(reflectTypeKey(depFactoryType.Out(0)), nameOrDefault(name), factory) +} diff --git a/introspect.go b/introspect.go new file mode 100644 index 0000000..d98369f --- /dev/null +++ b/introspect.go @@ -0,0 +1,74 @@ +package injector + +import ( + "reflect" + "unsafe" +) + +func injectFuncArgs(fn interface{}) ([]reflect.Value, error) { + depFactoryType := reflect.TypeOf(fn) + + if depFactoryType.Kind() != reflect.Func { + return nil, ErrorDepFactoryNotAFunc + } + + // Retrieve dependencies for all factory arguments + args := make([]reflect.Value, depFactoryType.NumIn()) + for i := 0; i < depFactoryType.NumIn(); i++ { + inType := depFactoryType.In(i) + resolvedDep, err := InjectT(inType) + if err != nil { + // If the field was not resolvable, attempt to fill it as a struct + possibleStruct := reflect.New(inType) + fillStruct := possibleStruct + + // If the field is a pointer, the fill struct must be initialized and de-refereced, else we have a double-pointer **Type + if inType.Kind() == reflect.Pointer { + possibleStruct.Elem().Set(reflect.New(inType.Elem())) + fillStruct = possibleStruct.Elem() + } + + // Attempt to fill in the struct with dependencies + if err := injectStructFields(fillStruct.Interface()); err != nil { + return nil, err + } + args[i] = reflect.ValueOf(possibleStruct.Elem().Interface()) + } else { + args[i] = reflect.ValueOf(resolvedDep) + } + } + return args, nil +} + +func injectStructFields(strct interface{}) error { + depFactoryType := reflect.TypeOf(strct) + + if depFactoryType.Kind() != reflect.Ptr { + return ErrorDepNotAPointer + } + + if depFactoryType.Elem().Kind() != reflect.Struct { + return ErrorDepNotAStruct + } + + // Set struct values to injected dependencies + depFactoryValue := reflect.ValueOf(strct).Elem() + for i := 0; i < depFactoryType.Elem().NumField(); i++ { + field := depFactoryType.Elem().Field(i) + fieldVal := depFactoryValue.Field(i) + if lookupType, exists := field.Tag.Lookup("injector"); exists { + name := DefaultForType + switch lookupType { + case "name": + name = field.Name + } + resolvedDep, err := InjectT(field.Type, name) + if err != nil { + return err + } + ptr := reflect.NewAt(fieldVal.Type(), unsafe.Pointer(fieldVal.UnsafeAddr())).Elem() + ptr.Set(reflect.ValueOf(resolvedDep)) + } + } + return nil +} diff --git a/util.go b/util.go new file mode 100644 index 0000000..181af7a --- /dev/null +++ b/util.go @@ -0,0 +1,45 @@ +package injector + +import ( + "errors" + "fmt" + "reflect" + "strings" +) + +var ( + DefaultForType = "" + + ErrorDependencyTypeMismatch = errors.New("the resolved dependency does not match the generic type") + + ErrorDepFactoryNotAFunc = errors.New("the provided dependency factory is not a function") + ErrorDepFactoryNotFound = errors.New("the requested type/name combination is not a registered dependency") + ErrorDepFactoryReturnCount = errors.New("the provided dependency factory must return exactly 1 value") + + ErrorDepNotAPointer = errors.New("the provided value must be a pointer to the struct you want to inject into") + ErrorDepNotAStruct = errors.New("the provided value must be a struct") + + ErrorInvalidTag = errors.New("the provided injector tag is not valid") +) + +// nameOrDefault will return the default name if the name array is nil or empty +func nameOrDefault(name []string) string { + depName := DefaultForType + if name != nil && len(name) > 0 { + depName = strings.Join(name, "") + } + return depName +} + +func reflectTypeKey(t reflect.Type) string { + nameType := t + for nameType != nil && nameType.Kind() == reflect.Pointer { + nameType = nameType.Elem() + } + + pkg, name := "UNKNOWN_PACKAGE", t.String() + if nameType != nil { + pkg = nameType.PkgPath() + } + return fmt.Sprintf("%s/%s", pkg, name) +}