diff --git a/Benchmarking/Benchmarks/Program.cs b/Benchmarking/Benchmarks/Program.cs index 3c32307..71e07b6 100644 --- a/Benchmarking/Benchmarks/Program.cs +++ b/Benchmarking/Benchmarks/Program.cs @@ -12,7 +12,7 @@ namespace Benchmarks; [JsonExporterAttribute.FullCompressed] public class ResolveBenchmarks { - private readonly IContainer m_container = new DependencyInjectionContainer(default, default, default!); + private readonly IContainer m_container = new DependencyInjectionContainer(default, default, new NonInjectedClass()); [Benchmark] public ChainA ResolveChain() => m_container.Resolve(); @@ -31,6 +31,9 @@ public class ResolveBenchmarks [Benchmark] public IContainer Create() => new DependencyInjectionContainer(default, default, default!); + + [Benchmark] + public IContainer CreateFromSelf() => new DependencyInjectionContainer(m_container); } internal static class Program diff --git a/FactoryGenerator.Attributes/IContainer.cs b/FactoryGenerator.Attributes/IContainer.cs index abf990b..6cc3b9b 100644 --- a/FactoryGenerator.Attributes/IContainer.cs +++ b/FactoryGenerator.Attributes/IContainer.cs @@ -19,4 +19,6 @@ public interface ILifetimeScope : IDisposable public interface IContainer : ILifetimeScope { + IContainer? Base { get; } + IContainer? Inheritor { get; set; } } \ No newline at end of file diff --git a/FactoryGenerator/FactoryGenerator.cs b/FactoryGenerator/FactoryGenerator.cs index 9117c6c..7919e6c 100644 --- a/FactoryGenerator/FactoryGenerator.cs +++ b/FactoryGenerator/FactoryGenerator.cs @@ -154,22 +154,43 @@ namespace {compilation.Assembly.Name}.Generated; yield return $@"{usingStatements} [GeneratedCode(""{ToolName}"", ""{Version}"")] -#nullable disable +#nullable enable public sealed partial class {ClassName} : IContainer {{ + + bool Reentrant = false; + private IContainer GetRoot() + {{ + IContainer root = this; + while(root.Base != null) + {{ + root = root.Base; + }} + return root; + }} + private IContainer GetTop() + {{ + IContainer top = this; + while(top.Inheritor != null) + {{ + top = top.Inheritor; + }} + return top; + }} + public IContainer? Base {{ get; }} + public IContainer? Inheritor {{ get; set; }} private object m_lock = new(); private Dictionary> m_lookup; private readonly List> resolvedInstances = new(); public T Resolve() {{ - return (T)Resolve(typeof(T)); + return TryResolve(out var resolved) ? resolved! : throw new KeyNotFoundException($""The type {{typeof(T)}} has not been registered, and thus cannot be resolved""); }} public object Resolve(Type type) {{ - var instance = m_lookup[type](); - return instance; + return TryResolve(type, out var resolved) ? resolved! : throw new KeyNotFoundException($""The type {{type}} has not been registered, and thus cannot be resolved""); }} public void Dispose() @@ -182,22 +203,28 @@ public void Dispose() }} }} resolvedInstances.Clear(); + Base?.Dispose(); }} - public bool TryResolve(Type type, out object resolved) + public bool TryResolve(Type type, out object? resolved) {{ + resolved = default; if(m_lookup.TryGetValue(type, out var factory)) {{ resolved = factory(); return true; }} - resolved = default; + if(Base is not null) + {{ + return Base.TryResolve(type, out resolved); + }} return false; }} - public bool TryResolve(out T resolved) + public bool TryResolve(out T? resolved) {{ + resolved = default; if(m_lookup.TryGetValue(typeof(T), out var factory)) {{ var value = factory(); @@ -207,12 +234,15 @@ public bool TryResolve(out T resolved) return true; }} }} - resolved = default; + if(Base is not null) + {{ + return Base.TryResolve(out resolved); + }} return false; }} public bool IsRegistered(Type type) {{ - return m_lookup.ContainsKey(type); + return m_lookup.ContainsKey(type) ? true : Base?.IsRegistered(type) == true; }} public bool IsRegistered() => IsRegistered(typeof(T)); }}"; @@ -220,6 +250,7 @@ public bool IsRegistered(Type type) var booleans = dataInjections.Select(inj => inj.BooleanInjection).Where(b => b is not null) .Select(b => b!.Key).Distinct().ToArray(); var allArguments = booleans.Select(b => $"bool {b}").ToList(); + var justBooleans = allArguments.ToList(); var allParameters = booleans.Select(b => $"{b}").ToList(); var ordered = dataInjections.Reverse().ToList(); //Put all test-overrides at the end @@ -384,20 +415,45 @@ public bool IsRegistered(Type type) var constructorFields = string.Join("\n\t", allArguments.Select(arg => arg + ";")); var constructorAssignments = string.Join("\n\t\t", allArguments.Select(arg => arg.Split(' ').Last()).Select(arg => $"this.{arg} = {arg};")); + var resolvedConstructorAssignments = string.Join("\n\t\t", allArguments.Select(a => a.Split(' ')).Where(a => a[0] != "bool").Select(a => $"this.{a[1]} = Base.Resolve<{a[0]}>();")); var dictSize = interfaceInjectors.Count + localizedParameters.Count + requested.Count + + constructorParameters.Count; yield return Constructor(usingStatements, constructorFields, constructor, constructorAssignments, dictSize, interfaceInjectors.Keys, localizedParameters, requested, - constructorParameters, true, ClassName, lifetimeParameters); + constructorParameters, true, ClassName, lifetimeParameters, + resolvingConstructorAssignments: resolvedConstructorAssignments); yield return Declarations(usingStatements, declarations, ClassName); yield return ArrayDeclarations(usingStatements, arrayDeclarations, ClassName); yield return $@"{usingStatements} [GeneratedCode(""{ToolName}"", ""{Version}"")] -#nullable disable +#nullable enable public sealed partial class LifetimeScope : IContainer {{ + private bool Reentrant = false; + private IContainer GetRoot() + {{ + IContainer root = this; + while(root.Base != null) + {{ + root = root.Base; + }} + return root; + }} + private IContainer GetTop() + {{ + IContainer top = this; + while(top.Inheritor != null) + {{ + top = top.Inheritor; + }} + return top; + }} + + public IContainer? Base {{ get; }} + public IContainer? Inheritor {{ get; set; }} public ILifetimeScope BeginLifetimeScope() {{ var scope = m_fallback.BeginLifetimeScope(); @@ -409,15 +465,14 @@ public ILifetimeScope BeginLifetimeScope() private Dictionary> m_lookup; private readonly List> resolvedInstances = new(); - public T Resolve() + public T Resolve() {{ - return (T)Resolve(typeof(T)); + return TryResolve(out var resolved) ? resolved! : throw new KeyNotFoundException($""The type {{typeof(T)}} has not been registered, and thus cannot be resolved""); }} public object Resolve(Type type) {{ - var instance = m_lookup[type](); - return instance; + return TryResolve(type, out var resolved) ? resolved! : throw new KeyNotFoundException($""The type {{type}} has not been registered, and thus cannot be resolved""); }} public void Dispose() @@ -430,22 +485,28 @@ public void Dispose() }} }} resolvedInstances.Clear(); + Base?.Dispose(); }} - public bool TryResolve(Type type, out object resolved) + public bool TryResolve(Type type, out object? resolved) {{ + resolved = default; if(m_lookup.TryGetValue(type, out var factory)) {{ resolved = factory(); return true; }} - resolved = default; + else if(Base is not null) + {{ + return Base.TryResolve(type, out resolved); + }} return false; }} - public bool TryResolve(out T resolved) + public bool TryResolve(out T? resolved) {{ + resolved = default; if(m_lookup.TryGetValue(typeof(T), out var factory)) {{ var value = factory(); @@ -455,12 +516,15 @@ public bool TryResolve(out T resolved) return true; }} }} - resolved = default; + else if(Base is not null) + {{ + return Base.TryResolve(out resolved); + }} return false; }} public bool IsRegistered(Type type) {{ - return m_lookup.ContainsKey(type); + return m_lookup.ContainsKey(type) ? true : Base?.IsRegistered(type) == true; }} public bool IsRegistered() => IsRegistered(typeof(T)); }} @@ -469,7 +533,8 @@ public bool IsRegistered(Type type) lifetimeConstructor, constructorAssignments, dictSize, interfaceInjectors.Keys, localizedParameters, requested, - constructorParameters, false, LifetimeName); + constructorParameters, false, LifetimeName, + resolvingConstructorAssignments: resolvedConstructorAssignments, addMergingConstructor: false); yield return Declarations(usingStatements, scopedDeclarations, LifetimeName); yield return ArrayDeclarations(usingStatements, arrayDeclarations, LifetimeName); } @@ -513,7 +578,7 @@ private static void CheckForCycles(ImmutableArray dataInjections) } } - if (parameter.Type is not IArrayTypeSymbol {ElementType: INamedTypeSymbol arrType}) continue; + if (parameter.Type is not IArrayTypeSymbol { ElementType: INamedTypeSymbol arrType }) continue; { node.Add(arrType); if (!tree.TryGetValue(arrType, out var list)) continue; @@ -529,7 +594,8 @@ private static void CheckForCycles(ImmutableArray dataInjections) private static string Constructor(string usingStatements, string constructorFields, string constructor, string constructorAssignments, int dictSize, IEnumerable interfaceInjectors, List localizedParameters, List requested, - List constructorParameters, bool addLifetimeScopeFunction, string className, string? lifetimeParameters = null) + List constructorParameters, bool addLifetimeScopeFunction, string className, string? lifetimeParameters = null, + string? fromConstructor = null, string? resolvingConstructorAssignments = null, bool addMergingConstructor = true) { var lifetimeScopeFunction = addLifetimeScopeFunction ? $@" @@ -538,8 +604,24 @@ public ILifetimeScope BeginLifetimeScope() var scope = new {LifetimeName}({lifetimeParameters}); resolvedInstances.Add(new WeakReference(scope)); return scope; -}}" - : string.Empty; +}}" : string.Empty; + + var mergingConstructor = addMergingConstructor ? $@" +public {className}(IContainer Base{fromConstructor}) +{{ + this.Base = Base; + Base.Inheritor = this; + {resolvingConstructorAssignments} + + m_lookup = new({dictSize}) {{ +{MakeDictionary(interfaceInjectors)} +{MakeDictionary(localizedParameters)} +{MakeDictionary(requested)} +{MakeDictionary(constructorParameters)} + }}; +}}" : string.Empty; + + var extraConstruction = addLifetimeScopeFunction ? string.Empty : "m_fallback = fallback;"; return $@"{usingStatements} public partial class {className} @@ -549,15 +631,15 @@ public partial class {className} {{ {extraConstruction} {constructorAssignments} - m_lookup = new({dictSize}) - {{ + + m_lookup = new({dictSize}) {{ {MakeDictionary(interfaceInjectors)} {MakeDictionary(localizedParameters)} {MakeDictionary(requested)} {MakeDictionary(constructorParameters)} }}; }} - + {mergingConstructor} {lifetimeScopeFunction} }}"; @@ -586,57 +668,53 @@ private static void MakeArray(Dictionary declarations, string na { var factoryName = $"new {type}[0]"; var factory = string.Empty; + var functionString = function ? "()" : string.Empty; + var starter = function ? string.Empty : "get {"; + var ender = function ? string.Empty : "}"; if (interfaceInjectors.TryGetValue(type, out var injections)) { factoryName = $"Create{name}()".Replace("_", ""); factory = @$" IEnumerable<{type}> {factoryName} {{ + if(Reentrant) return Array.Empty<{type}>(); + Reentrant = true; List<{type}> source = new List<{type}> {{ {string.Join(",\n\t\t\t", injections.Where(i => i.BooleanInjection == null).Select(i => i.Name))} }}; {string.Join("\n\t\t\t", injections.Where(b => b.BooleanInjection != null).Select(i => $"if({i.BooleanInjection!.Key}) source.Add({i.Name});"))} + var b = Base; + while(b is not null) + {{ + if(b.TryResolve>(out var additional)) source.AddRange(additional!); + b = b.Base; + }} + b = Inheritor; + while(b is not null) + {{ + if(b.TryResolve>(out var additional)) source.AddRange(additional!); + b = b.Inheritor; + }} + Reentrant = false; return source; }}"; } - - if (function) - { - declarations[name] = $@" - internal IEnumerable<{type}> {name}() + declarations[name] = $@" + internal IEnumerable<{type}> {name}{functionString} {{ + {starter} if (m_{name} != null) return m_{name}; - + lock (m_lock) {{ if (m_{name} != null) return m_{name}; return m_{name} = {factoryName}; }} + {ender} }} internal IEnumerable<{type}>? m_{name};" + factory; - } - else - { - declarations[name] = $@" - internal IEnumerable<{type}> {name} - {{ - get - {{ - if (m_{name} != null) - return m_{name}; - - lock (m_lock) - {{ - if (m_{name} != null) - return m_{name}; - return m_{name} = {factoryName}; - }} - }} - }} - internal IEnumerable<{type}>? m_{name};" + factory; - } } private static string MakeDictionary(IEnumerable types) diff --git a/Tests/FactoryGenerator.Tests/InjectionDetectionTests.cs b/Tests/FactoryGenerator.Tests/InjectionDetectionTests.cs index 3116e7b..355b861 100644 --- a/Tests/FactoryGenerator.Tests/InjectionDetectionTests.cs +++ b/Tests/FactoryGenerator.Tests/InjectionDetectionTests.cs @@ -219,4 +219,79 @@ public void ClassesInsideOtherClassesCanBeInjected() { m_container.Resolve(); } + [Fact] + public void ContainerMayCreateItself() + { + var newContainer = new DependencyInjectionContainer(m_container); + var resolved = m_container.Resolve>(); + resolved.Count().ShouldBe(6); + var nonInjected = m_container.Resolve(); + } + [Fact] + public void HierarchicalContainersResolveArraysProperly() + { + var newContainer = new DependencyInjectionContainer(m_container); + newContainer.Resolve().Arrays.Count().ShouldBe(6); + } + [Fact] + public void HierarchicalContainersResolveUsesFallBackIfItCannotFindImplementation() + { + var newContainer = new DependencyInjectionContainer(new DummyContainer()); + newContainer.Resolve().ShouldBe(DummyContainer.DummyText); + } + private class DummyContainer : IContainer + { + public const string DummyText = "I am a bit of text"; + + public static NonInjectedClass m_dummy = new(); + public IContainer? Base => null; + + public IContainer? Inheritor { get; set; } = null; + + public ILifetimeScope BeginLifetimeScope() + { + return this; + } + + public void Dispose() + { + } + + public bool IsRegistered(System.Type type) + { + return true; + } + + public bool IsRegistered() + { + return true; + } + + public T Resolve() + { + if (typeof(T) == typeof(string)) return (T) (object) DummyText; + return (T) (object) m_dummy; + } + + public object Resolve(System.Type type) + { + if (type == typeof(string)) return DummyText; + return m_dummy; + } + + public bool TryResolve(System.Type type, out object? resolved) + { + resolved = null; + if (type == typeof(string)) resolved = DummyText; + return resolved != null; + } + + public bool TryResolve(out T? resolved) + { + resolved = default; + if (typeof(T) == typeof(string)) resolved = (T) (object) DummyText; + return resolved != null; + } + + } } \ No newline at end of file diff --git a/Tests/TestData/Inheritor/Inheritor.csproj b/Tests/TestData/Inheritor/Inheritor.csproj index d650a88..8dd149a 100644 --- a/Tests/TestData/Inheritor/Inheritor.csproj +++ b/Tests/TestData/Inheritor/Inheritor.csproj @@ -1,16 +1,18 @@  - - - - + + + + net8.0 + true enable enable - + \ No newline at end of file