From afb49e03c2df9362543df693b82f697e09858955 Mon Sep 17 00:00:00 2001 From: caran Date: Mon, 29 Apr 2024 13:59:32 +0200 Subject: [PATCH] Lifetime scope is working. --- .../Attributes/ScopedAttribute.cs | 6 + FactoryGenerator.Attributes/IContainer.cs | 7 +- FactoryGenerator/FactoryGenerator.cs | 171 ++++++++++++++---- FactoryGenerator/Injection.cs | 30 ++- FactoryGenerator/SymbolUtility.cs | 34 ++-- .../InjectionDetectionTests.cs | 36 ++++ Tests/TestData/Inherited/Types.cs | 16 ++ 7 files changed, 236 insertions(+), 64 deletions(-) create mode 100644 FactoryGenerator.Attributes/Attributes/ScopedAttribute.cs diff --git a/FactoryGenerator.Attributes/Attributes/ScopedAttribute.cs b/FactoryGenerator.Attributes/Attributes/ScopedAttribute.cs new file mode 100644 index 0000000..35bb910 --- /dev/null +++ b/FactoryGenerator.Attributes/Attributes/ScopedAttribute.cs @@ -0,0 +1,6 @@ +using System; + +namespace FactoryGenerator.Attributes; + +[AttributeUsage(AttributeTargets.Class | AttributeTargets.Method | AttributeTargets.Property | AttributeTargets.Interface)] +public class ScopedAttribute : Attribute; \ No newline at end of file diff --git a/FactoryGenerator.Attributes/IContainer.cs b/FactoryGenerator.Attributes/IContainer.cs index 185e218..326ecf6 100644 --- a/FactoryGenerator.Attributes/IContainer.cs +++ b/FactoryGenerator.Attributes/IContainer.cs @@ -4,7 +4,7 @@ namespace FactoryGenerator; #nullable enable -public interface IContainer : IDisposable +public interface ILifetimeScope : IDisposable { T Resolve(); object Resolve(Type type); @@ -16,4 +16,9 @@ public interface IContainer : IDisposable bool IsRegistered(Type type) => TryResolve(type, out _); bool IsRegistered() => IsRegistered(typeof(T)); +} + +public interface IContainer : ILifetimeScope +{ + public ILifetimeScope BeginLifetimeScope(); } \ No newline at end of file diff --git a/FactoryGenerator/FactoryGenerator.cs b/FactoryGenerator/FactoryGenerator.cs index 4253fc9..4b818b9 100644 --- a/FactoryGenerator/FactoryGenerator.cs +++ b/FactoryGenerator/FactoryGenerator.cs @@ -66,6 +66,10 @@ private void MakeAutofacModule(SourceProductionContext context, context.AddSource("DependencyInjectionContainer.Constructor.g.cs", source[1]); context.AddSource("DependencyInjectionContainer.Declarations.g.cs", source[2]); context.AddSource("DependencyInjectionContainer.EnumerableDeclarations.g.cs", source[3]); + context.AddSource("LifetimeScope.Lookup.g.cs", source[4]); + context.AddSource("LifetimeScope.Constructor.g.cs", source[5]); + context.AddSource("LifetimeScope.Declarations.g.cs", source[6]); + context.AddSource("LifetimeScope.EnumerableDeclarations.g.cs", source[7]); } private INamedTypeSymbol? ResolveTransformations(GeneratorSyntaxContext context, CancellationToken token) @@ -126,6 +130,9 @@ private static INamespaceSymbol GetGlobalNamespace(Compilation compilation, Canc return compilation.GlobalNamespace; } + private const string ClassName = "DependencyInjectionContainer"; + private const string LifetimeName = "LifetimeScope"; + private static IEnumerable GenerateCode(ImmutableArray dataInjections, Compilation compilation, ImmutableArray usages, ILogger log) { @@ -142,8 +149,10 @@ namespace {compilation.Assembly.Name}.Generated; yield return $@"{usingStatements} [GeneratedCode(""{ToolName}"", ""{Version}"")] #nullable disable -public partial class DependencyInjectionContainer : IContainer +public sealed partial class {ClassName} : IContainer {{ + private object m_lock = new(); + private Dictionary> m_lookup; private readonly List> resolvedInstances = new(); public T Resolve() @@ -195,13 +204,12 @@ public bool TryResolve(out T resolved) resolved = default; return false; }} - private Dictionary> m_lookup; - private object m_lock = new(); }}"; var booleans = dataInjections.Select(inj => inj.BooleanInjection).Where(b => b is not null) - .Select(b => b!.Key).Distinct(); + .Select(b => b!.Key).Distinct().ToArray(); var allArguments = booleans.Select(b => $"bool {b}").ToList(); + var allParameters = booleans.Select(b => $"{b}").ToList(); var ordered = dataInjections.Reverse().ToList(); //Put all test-overrides at the end foreach (var injection in ordered.ToArray()) @@ -228,14 +236,15 @@ public bool TryResolve(out T resolved) } var declarations = new Dictionary(); + var scopedDeclarations = new Dictionary(); var availableInterfaces = interfaceInjectors.Keys.ToImmutableArray(); var constructorParameters = new List(); - var disposables = new List(); foreach (var injection in ordered) { - if (injection.Disposable && injection.Singleton) disposables.Add(injection); - declarations[injection.Name] = injection.Declaration(availableInterfaces); + declarations[injection.Name] = injection.Declaration(availableInterfaces, false); + scopedDeclarations[injection.Name] = injection.Declaration(availableInterfaces, true); + HashSet? missing = null; injection.GetBestConstructor(availableInterfaces, ref missing); if (missing is null) continue; @@ -256,7 +265,9 @@ public bool TryResolve(out T resolved) { log.Log(LogLevel.Information, $"Selecting {chosen.Name} for {interfaceSymbol}"); declarations[SymbolUtility.MemberName(interfaceSymbol)] = - $"private {interfaceSymbol} {SymbolUtility.MemberName(interfaceSymbol)} => {chosen.Name};"; + $"internal {interfaceSymbol} {SymbolUtility.MemberName(interfaceSymbol)} => {chosen.Name};"; + scopedDeclarations[SymbolUtility.MemberName(interfaceSymbol)] = + $"internal {interfaceSymbol} {SymbolUtility.MemberName(interfaceSymbol)} => {chosen.Name};"; } } } @@ -272,21 +283,16 @@ public bool TryResolve(out T resolved) var trueValue = possibilities.LastOrDefault(p => p.BooleanInjection?.Value == true && p.BooleanInjection?.Key == key); trueValue ??= fallback; - if (key == last) - { - ternary.Append($"{key} ? {trueValue?.Name ?? "null!"} : {fallback?.Name ?? "null!"}"); - } - else - { - ternary.Append($"{key} ? {trueValue?.Name ?? "null!"} : "); - } + ternary.Append(key == last ? $"{key} ? {trueValue?.Name ?? "null!"} : {fallback?.Name ?? "null!"}" : $"{key} ? {trueValue?.Name ?? "null!"} : "); } if (!declarations.ContainsKey(SymbolUtility.MemberName(interfaceSymbol))) { log.Log(LogLevel.Information, $"Selecting {ternary} for {interfaceSymbol}"); declarations[SymbolUtility.MemberName(interfaceSymbol)] = - $"private {interfaceSymbol} {SymbolUtility.MemberName(interfaceSymbol)} => {ternary};"; + $"internal {interfaceSymbol} {SymbolUtility.MemberName(interfaceSymbol)} => {ternary};"; + scopedDeclarations[SymbolUtility.MemberName(interfaceSymbol)] = + $"internal {interfaceSymbol} {SymbolUtility.MemberName(interfaceSymbol)} => {ternary};"; } } } @@ -343,15 +349,22 @@ public bool TryResolve(out T resolved) { log.Log(LogLevel.Debug, $"Registering {parameter.Name} as Self"); declarations[parameter.Name] = $"private IContainer {parameter.Name} => this;"; + scopedDeclarations[parameter.Name] = $"private IContainer {parameter.Name} => this;"; constructorParameters.Remove(parameter); } } var arguments = constructorParameters.OrderBy(p => p.Type.ToString()).Select(p => $"{p.Type} {p.Name}") .Distinct(); + + var parameters = constructorParameters.OrderBy(p => p.Type.ToString()).Select(p => $"{p.Name}") + .Distinct(); allArguments.AddRange(arguments); + allParameters.AddRange(parameters); var constructor = "(" + string.Join(", ", allArguments) + ")"; + var lifetimeConstructor = "(" + $"{ClassName} fallback, " + string.Join(", ", allArguments) + ")"; + var lifetimeParameters = "this, " + string.Join(", ", allParameters); log.Log(LogLevel.Debug, $"Resulting Constructor: {constructor}"); var constructorFields = string.Join("\n\t", allArguments.Select(arg => arg + ";")); @@ -361,11 +374,83 @@ public bool TryResolve(out T resolved) constructorParameters.Count; yield return Constructor(usingStatements, constructorFields, constructor, constructorAssignments, - dictSize, interfaceInjectors, + dictSize, interfaceInjectors.Keys, localizedParameters, requested, - constructorParameters); - yield return Declarations(usingStatements, declarations, disposables); - yield return ArrayDeclarations(usingStatements, arrayDeclarations); + constructorParameters, true, ClassName, lifetimeParameters); + yield return Declarations(usingStatements, declarations, ClassName); + yield return ArrayDeclarations(usingStatements, arrayDeclarations, ClassName); + yield return $@"{usingStatements} +[GeneratedCode(""{ToolName}"", ""{Version}"")] +#nullable disable +public sealed partial class LifetimeScope : IContainer +{{ + public ILifetimeScope BeginLifetimeScope() + {{ + return m_fallback.BeginLifetimeScope(); + }} + private object m_lock = new(); + private {ClassName} m_fallback; + private Dictionary> m_lookup; + private readonly List> resolvedInstances = new(); + + public T Resolve() + {{ + return (T)Resolve(typeof(T)); + }} + + public object Resolve(Type type) + {{ + var instance = m_lookup[type](); + return instance; + }} + + public void Dispose() + {{ + foreach (var weakReference in resolvedInstances) + {{ + if(weakReference.TryGetTarget(out var disposable)) + {{ + disposable.Dispose(); + }} + }} + resolvedInstances.Clear(); + }} + + public bool TryResolve(Type type, out object resolved) + {{ + if(m_lookup.TryGetValue(type, out var factory)) + {{ + resolved = factory(); + return true; + }} + resolved = default; + return false; + }} + + + public bool TryResolve(out T resolved) + {{ + if(m_lookup.TryGetValue(typeof(T), out var factory)) + {{ + var value = factory(); + if(value is T t) + {{ + resolved = t; + return true; + }} + }} + resolved = default; + return false; + }} +}} +"; + yield return Constructor(usingStatements, constructorFields, + lifetimeConstructor, constructorAssignments, + dictSize, interfaceInjectors.Keys, + localizedParameters, requested, + constructorParameters, false, LifetimeName); + yield return Declarations(usingStatements, scopedDeclarations, LifetimeName); + yield return ArrayDeclarations(usingStatements, arrayDeclarations, LifetimeName); } private static void CheckForCycles(ImmutableArray dataInjections) @@ -380,6 +465,7 @@ private static void CheckForCycles(ImmutableArray dataInjections) tree[iface] = new List(); } } + foreach (var constructor in injection.Type.Constructors) { foreach (var parameter in constructor.Parameters) @@ -391,11 +477,12 @@ private static void CheckForCycles(ImmutableArray dataInjections) if (named.TypeArguments.Length != 1) continue; named = (INamedTypeSymbol) named.TypeArguments[0]; } + foreach (var iface in injection.Interfaces) { tree[iface].Add(named); - } + if (tree.TryGetValue(named, out var list)) { foreach (var iface in injection.Interfaces) @@ -404,6 +491,7 @@ private static void CheckForCycles(ImmutableArray dataInjections) } } } + if (parameter.Type is IArrayTypeSymbol array) { if (array.ElementType is INamedTypeSymbol arrType) @@ -424,41 +512,52 @@ private static void CheckForCycles(ImmutableArray dataInjections) } private static string Constructor(string usingStatements, string constructorFields, string constructor, string constructorAssignments, int dictSize, - Dictionary> interfaceInjectors, List localizedParameters, List requested, - List constructorParameters) + IEnumerable interfaceInjectors, List localizedParameters, List requested, + List constructorParameters, bool addLifetimeScopeFunction, string className, string? lifetimeParameters = null) { + var lifetimeScopeFunction = addLifetimeScopeFunction + ? $@" +public ILifetimeScope BeginLifetimeScope() +{{ + return new {LifetimeName}({lifetimeParameters}); +}}" + : string.Empty; + var extraConstruction = addLifetimeScopeFunction ? string.Empty : "m_fallback = fallback;"; return $@"{usingStatements} -public partial class DependencyInjectionContainer +public partial class {className} {{ {constructorFields} - public DependencyInjectionContainer{constructor} + public {className}{constructor} {{ + {extraConstruction} {constructorAssignments} m_lookup = new({dictSize}) {{ -{MakeDictionary(interfaceInjectors.Keys)} +{MakeDictionary(interfaceInjectors)} {MakeDictionary(localizedParameters)} {MakeDictionary(requested)} {MakeDictionary(constructorParameters)} }}; - }} + }} + + {lifetimeScopeFunction} }}"; } - private static string ArrayDeclarations(string usingStatements, Dictionary arrayDeclarations) + private static string ArrayDeclarations(string usingStatements, Dictionary arrayDeclarations, string className) { return $@"{usingStatements} -public partial class DependencyInjectionContainer +public partial class {className} {{ {string.Join("\n\t", arrayDeclarations.Values)} }}"; } - private static string Declarations(string usingStatements, Dictionary declarations, List disposables) + private static string Declarations(string usingStatements, Dictionary declarations, string className) { return $@"{usingStatements} -public partial class DependencyInjectionContainer +public partial class {className} {{ {string.Join("\n\t", declarations.Values)} }}"; @@ -485,7 +584,7 @@ private static void MakeArray(Dictionary declarations, string na if (function) { declarations[name] = $@" - private {type}[] {name}() + internal {type}[] {name}() {{ if (m_{name} != null) return m_{name}; @@ -497,12 +596,12 @@ private static void MakeArray(Dictionary declarations, string na return m_{name} = {factoryName}; }} }} - private {type}[]? m_{name};" + factory; + internal {type}[]? m_{name};" + factory; } else { declarations[name] = $@" - private {type}[] {name} + internal {type}[] {name} {{ get {{ @@ -517,7 +616,7 @@ private static void MakeArray(Dictionary declarations, string na }} }} }} - private {type}[]? m_{name};" + factory; + internal {type}[]? m_{name};" + factory; } } diff --git a/FactoryGenerator/Injection.cs b/FactoryGenerator/Injection.cs index ad41f2d..fd5b01f 100644 --- a/FactoryGenerator/Injection.cs +++ b/FactoryGenerator/Injection.cs @@ -12,16 +12,23 @@ public class Injection public INamedTypeSymbol Type { get; } public ImmutableArray Interfaces { get; } public bool Singleton { get; } + public bool Scoped { get; } public BooleanInjection? BooleanInjection { get; } private ISymbol? Lambda { get; } private string LazyName => "m_" + Name.Replace("()", ""); public string Name => SymbolUtility.MemberName(Type).Replace("()", "") + (Lambda is not null ? Lambda.Name : string.Empty) + "()"; public bool Disposable { get; } - public string Declaration(ImmutableArray availableParameters) + public string Declaration(ImmutableArray availableParameters, bool forLifetimeScope) { var creationCall = CreationCall(availableParameters); - return Singleton ? SymbolUtility.SingletonFactory(Type, Name, LazyName, creationCall, Disposable) : Disposable ? SymbolUtility.DisposableFactory(Type, Name, creationCall) : $"private {Type} {Name} => {creationCall};"; + if (forLifetimeScope && Singleton) + { + return $"internal {Type} {Name} => m_fallback.{Name};"; + } + + return (Singleton || Scoped) ? SymbolUtility.SingletonFactory(Type, Name, LazyName, creationCall, Disposable) : + Disposable ? SymbolUtility.DisposableFactory(Type, Name, creationCall) : $"internal {Type} {Name} => {creationCall};"; } public string DisposeCall => LazyName + "?.Dispose();"; @@ -94,6 +101,7 @@ private string MakeMethodCall(IMethodSymbol? constructor, HashSet interfaces, bool singleton, bool disposable, - BooleanInjection? booleanInjection, ISymbol? lambda) + BooleanInjection? booleanInjection, ISymbol? lambda, bool scoped) { Type = type; Interfaces = interfaces; @@ -143,6 +151,7 @@ private Injection(INamedTypeSymbol type, ImmutableArray interf Disposable = disposable; BooleanInjection = booleanInjection; Lambda = lambda; + Scoped = scoped; } @@ -177,15 +186,13 @@ private Injection(INamedTypeSymbol type, ImmutableArray interf var singleInstance = false; var acquireChildInterfaces = false; - var asSelf = false; - if(namedTypeSymbol.Interfaces.Length == 0) - { - asSelf = true; - } + var asSelf = namedTypeSymbol.Interfaces.Length == 0; + var scoped = false; if (namedTypeSymbol.TypeKind == TypeKind.Interface) { asSelf = true; } + BooleanInjection? boolean = null; HashSet attributedInterfaces = new(SymbolEqualityComparer.Default); HashSet preventedInterfaces = new(SymbolEqualityComparer.Default); @@ -221,6 +228,9 @@ private Injection(INamedTypeSymbol type, ImmutableArray interf case "BooleanAttribute": boolean = HandleBoolean(attributeData); break; + case "ScopedAttribute": + scoped = true; + break; default: continue; } @@ -244,7 +254,7 @@ private Injection(INamedTypeSymbol type, ImmutableArray interf interfaces = interfaces.RemoveRange(preventedInterfaces).Distinct(SymbolEqualityComparer.Default).Cast().ToImmutableArray(); - return new Injection(namedTypeSymbol, interfaces, singleInstance, isDisposable, boolean, lambda); + return new Injection(namedTypeSymbol, interfaces, singleInstance, isDisposable, boolean, lambda, scoped); } private static BooleanInjection? HandleBoolean(AttributeData attributeData) @@ -257,4 +267,4 @@ private Injection(INamedTypeSymbol type, ImmutableArray interf return null; } } -} +} \ No newline at end of file diff --git a/FactoryGenerator/SymbolUtility.cs b/FactoryGenerator/SymbolUtility.cs index ea5036a..da32a59 100644 --- a/FactoryGenerator/SymbolUtility.cs +++ b/FactoryGenerator/SymbolUtility.cs @@ -14,11 +14,11 @@ public static IEnumerable GetAllTypes(INamespaceSymbol root) switch (namespaceOrTypeSymbol) { case INamespaceSymbol @namespace: - { - foreach (var nested in GetAllTypes(@namespace)) - yield return nested; - break; - } + { + foreach (var nested in GetAllTypes(@namespace)) + yield return nested; + break; + } case INamedTypeSymbol type: foreach (var nested in GetAllTypes(type)) @@ -28,6 +28,7 @@ public static IEnumerable GetAllTypes(INamespaceSymbol root) } } } + public static IEnumerable GetAllTypes(INamedTypeSymbol root) { foreach (var namespaceOrTypeSymbol in root.GetMembers()) @@ -35,11 +36,11 @@ public static IEnumerable GetAllTypes(INamedTypeSymbol root) switch (namespaceOrTypeSymbol) { case INamespaceSymbol @namespace: - { - foreach (var nested in GetAllTypes(@namespace)) - yield return nested; - break; - } + { + foreach (var nested in GetAllTypes(@namespace)) + yield return nested; + break; + } case INamedTypeSymbol type: foreach (var nested in GetAllTypes(type)) @@ -49,6 +50,7 @@ public static IEnumerable GetAllTypes(INamedTypeSymbol root) } } } + internal static bool IsEnumerable(ITypeSymbol symbol) { return symbol.Name.Contains("IEnumerable") || ImplementsInterface(symbol, "IEnumerable"); @@ -76,9 +78,8 @@ public static string SingletonFactory(INamedTypeSymbol type, string name, string { if (disposable) { - return $@" - private {type} {name} + internal {type} {name} {{ if ({lazyName} != null) return {lazyName}; @@ -92,11 +93,11 @@ public static string SingletonFactory(INamedTypeSymbol type, string name, string return {lazyName} = value; }} }} - private {type}? {lazyName};"; + internal {type}? {lazyName};"; } return $@" - private {type} {name} + internal {type} {name} {{ if ({lazyName} != null) return {lazyName}; @@ -108,19 +109,18 @@ public static string SingletonFactory(INamedTypeSymbol type, string name, string return {lazyName} = {creation}; }} }} - private {type}? {lazyName};"; + internal {type}? {lazyName};"; } internal static string DisposableFactory(INamedTypeSymbol type, string name, string creationCall) { return $@" - private {type} {name} + internal {type} {name} {{ var value = {creationCall}; resolvedInstances.Add(new WeakReference(value)); return value; }}"; } - } } \ No newline at end of file diff --git a/Tests/FactoryGenerator.Tests/InjectionDetectionTests.cs b/Tests/FactoryGenerator.Tests/InjectionDetectionTests.cs index e137b16..3116e7b 100644 --- a/Tests/FactoryGenerator.Tests/InjectionDetectionTests.cs +++ b/Tests/FactoryGenerator.Tests/InjectionDetectionTests.cs @@ -132,6 +132,38 @@ public void DisposingContainerDisposesSingletons() ((DisposableSingleton) singleton).WasDisposed.ShouldBeTrue(); } + [Fact] + public void DisposingLifetimeContainerDoesNotDisposeSingletons() + { + ISingletonDisposer singleton; + using (var myContainer = new DependencyInjectionContainer(false, default, default!)) + { + using (var lifetime = myContainer.BeginLifetimeScope()) + { + singleton = lifetime.Resolve(); + } + + singleton.ShouldBeOfType(); + ((DisposableSingleton) singleton).WasDisposed.ShouldBeFalse(); + } + + ((DisposableSingleton) singleton).WasDisposed.ShouldBeTrue(); + } + + [Fact] + public void DisposingLifetimeContainerDisposesScoped() + { + IScoped singleton; + using var myContainer = new DependencyInjectionContainer(false, default, default!); + using (var lifetime = myContainer.BeginLifetimeScope()) + { + singleton = lifetime.Resolve(); + } + + singleton.ShouldBeOfType(); + singleton.WasDisposed.ShouldBeTrue(); + } + [Fact] public void DisposingContainerDoesNotDisposeUntrackedInstances() { @@ -161,23 +193,27 @@ public void RequestedArraysArePresent() { Program.Method().Count().ShouldBe(3); } + [Fact] public void BooleanFallbackIsOverriden() { m_container.Resolve().ShouldBeOfType(); } + [Fact] public void TryResolveWithTypeArgumentsWorks() { m_container.TryResolve(out var type).ShouldBeTrue(); type.ShouldBeOfType(); } + [Fact] public void TryResolveWithTypeParameterWorks() { m_container.TryResolve(typeof(IType), out var type).ShouldBeTrue(); type.ShouldBeOfType(); } + [Fact] public void ClassesInsideOtherClassesCanBeInjected() { diff --git a/Tests/TestData/Inherited/Types.cs b/Tests/TestData/Inherited/Types.cs index 8a4358d..42b1e87 100644 --- a/Tests/TestData/Inherited/Types.cs +++ b/Tests/TestData/Inherited/Types.cs @@ -149,4 +149,20 @@ public class Containing { [Inject] public class Containee; +} + +public interface IScoped +{ + bool WasDisposed { get; } +} + +[Inject, Scoped] +public class Scoped : IDisposable, IScoped +{ + public bool WasDisposed { get; private set; } + + public void Dispose() + { + WasDisposed = true; + } } \ No newline at end of file