Skip to content

Commit

Permalink
Merge pull request #13 from westermo/7-support-ilifetimescope-for-asp…
Browse files Browse the repository at this point in the history
…net-compatibility

Add support for Lifetime Scoping
  • Loading branch information
carl-andersson-at-westermo authored Apr 29, 2024
2 parents cad8f51 + afb49e0 commit da17549
Show file tree
Hide file tree
Showing 7 changed files with 236 additions and 64 deletions.
6 changes: 6 additions & 0 deletions FactoryGenerator.Attributes/Attributes/ScopedAttribute.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
using System;

namespace FactoryGenerator.Attributes;

[AttributeUsage(AttributeTargets.Class | AttributeTargets.Method | AttributeTargets.Property | AttributeTargets.Interface)]
public class ScopedAttribute : Attribute;
7 changes: 6 additions & 1 deletion FactoryGenerator.Attributes/IContainer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

namespace FactoryGenerator;
#nullable enable
public interface IContainer : IDisposable
public interface ILifetimeScope : IDisposable
{
T Resolve<T>();
object Resolve(Type type);
Expand All @@ -16,4 +16,9 @@ public interface IContainer : IDisposable

bool IsRegistered(Type type) => TryResolve(type, out _);
bool IsRegistered<T>() => IsRegistered(typeof(T));
}

public interface IContainer : ILifetimeScope
{
public ILifetimeScope BeginLifetimeScope();
}
171 changes: 135 additions & 36 deletions FactoryGenerator/FactoryGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ private void MakeAutofacModule(SourceProductionContext context,
context.AddSource("DependencyInjectionContainer.Constructor.g.cs", source[1]);
context.AddSource("DependencyInjectionContainer.Declarations.g.cs", source[2]);
context.AddSource("DependencyInjectionContainer.EnumerableDeclarations.g.cs", source[3]);
context.AddSource("LifetimeScope.Lookup.g.cs", source[4]);
context.AddSource("LifetimeScope.Constructor.g.cs", source[5]);
context.AddSource("LifetimeScope.Declarations.g.cs", source[6]);
context.AddSource("LifetimeScope.EnumerableDeclarations.g.cs", source[7]);
}

private INamedTypeSymbol? ResolveTransformations(GeneratorSyntaxContext context, CancellationToken token)
Expand Down Expand Up @@ -126,6 +130,9 @@ private static INamespaceSymbol GetGlobalNamespace(Compilation compilation, Canc
return compilation.GlobalNamespace;
}

private const string ClassName = "DependencyInjectionContainer";
private const string LifetimeName = "LifetimeScope";

private static IEnumerable<string> GenerateCode(ImmutableArray<Injection> dataInjections,
Compilation compilation, ImmutableArray<INamedTypeSymbol?> usages, ILogger log)
{
Expand All @@ -142,8 +149,10 @@ namespace {compilation.Assembly.Name}.Generated;
yield return $@"{usingStatements}
[GeneratedCode(""{ToolName}"", ""{Version}"")]
#nullable disable
public partial class DependencyInjectionContainer : IContainer
public sealed partial class {ClassName} : IContainer
{{
private object m_lock = new();
private Dictionary<Type,Func<object>> m_lookup;
private readonly List<WeakReference<IDisposable>> resolvedInstances = new();
public T Resolve<T>()
Expand Down Expand Up @@ -195,13 +204,12 @@ public bool TryResolve<T>(out T resolved)
resolved = default;
return false;
}}
private Dictionary<Type,Func<object>> m_lookup;
private object m_lock = new();
}}";

var booleans = dataInjections.Select(inj => inj.BooleanInjection).Where(b => b is not null)
.Select(b => b!.Key).Distinct();
.Select(b => b!.Key).Distinct().ToArray();
var allArguments = booleans.Select(b => $"bool {b}").ToList();
var allParameters = booleans.Select(b => $"{b}").ToList();
var ordered = dataInjections.Reverse().ToList();
//Put all test-overrides at the end
foreach (var injection in ordered.ToArray())
Expand All @@ -228,14 +236,15 @@ public bool TryResolve<T>(out T resolved)
}

var declarations = new Dictionary<string, string>();
var scopedDeclarations = new Dictionary<string, string>();
var availableInterfaces = interfaceInjectors.Keys.ToImmutableArray();
var constructorParameters = new List<IParameterSymbol>();
var disposables = new List<Injection>();

foreach (var injection in ordered)
{
if (injection.Disposable && injection.Singleton) disposables.Add(injection);
declarations[injection.Name] = injection.Declaration(availableInterfaces);
declarations[injection.Name] = injection.Declaration(availableInterfaces, false);
scopedDeclarations[injection.Name] = injection.Declaration(availableInterfaces, true);

HashSet<IParameterSymbol>? missing = null;
injection.GetBestConstructor(availableInterfaces, ref missing);
if (missing is null) continue;
Expand All @@ -256,7 +265,9 @@ public bool TryResolve<T>(out T resolved)
{
log.Log(LogLevel.Information, $"Selecting {chosen.Name} for {interfaceSymbol}");
declarations[SymbolUtility.MemberName(interfaceSymbol)] =
$"private {interfaceSymbol} {SymbolUtility.MemberName(interfaceSymbol)} => {chosen.Name};";
$"internal {interfaceSymbol} {SymbolUtility.MemberName(interfaceSymbol)} => {chosen.Name};";
scopedDeclarations[SymbolUtility.MemberName(interfaceSymbol)] =
$"internal {interfaceSymbol} {SymbolUtility.MemberName(interfaceSymbol)} => {chosen.Name};";
}
}
}
Expand All @@ -272,21 +283,16 @@ public bool TryResolve<T>(out T resolved)
var trueValue = possibilities.LastOrDefault(p =>
p.BooleanInjection?.Value == true && p.BooleanInjection?.Key == key);
trueValue ??= fallback;
if (key == last)
{
ternary.Append($"{key} ? {trueValue?.Name ?? "null!"} : {fallback?.Name ?? "null!"}");
}
else
{
ternary.Append($"{key} ? {trueValue?.Name ?? "null!"} : ");
}
ternary.Append(key == last ? $"{key} ? {trueValue?.Name ?? "null!"} : {fallback?.Name ?? "null!"}" : $"{key} ? {trueValue?.Name ?? "null!"} : ");
}

if (!declarations.ContainsKey(SymbolUtility.MemberName(interfaceSymbol)))
{
log.Log(LogLevel.Information, $"Selecting {ternary} for {interfaceSymbol}");
declarations[SymbolUtility.MemberName(interfaceSymbol)] =
$"private {interfaceSymbol} {SymbolUtility.MemberName(interfaceSymbol)} => {ternary};";
$"internal {interfaceSymbol} {SymbolUtility.MemberName(interfaceSymbol)} => {ternary};";
scopedDeclarations[SymbolUtility.MemberName(interfaceSymbol)] =
$"internal {interfaceSymbol} {SymbolUtility.MemberName(interfaceSymbol)} => {ternary};";
}
}
}
Expand Down Expand Up @@ -343,15 +349,22 @@ public bool TryResolve<T>(out T resolved)
{
log.Log(LogLevel.Debug, $"Registering {parameter.Name} as Self");
declarations[parameter.Name] = $"private IContainer {parameter.Name} => this;";
scopedDeclarations[parameter.Name] = $"private IContainer {parameter.Name} => this;";
constructorParameters.Remove(parameter);
}
}

var arguments = constructorParameters.OrderBy(p => p.Type.ToString()).Select(p => $"{p.Type} {p.Name}")
.Distinct();

var parameters = constructorParameters.OrderBy(p => p.Type.ToString()).Select(p => $"{p.Name}")
.Distinct();
allArguments.AddRange(arguments);
allParameters.AddRange(parameters);

var constructor = "(" + string.Join(", ", allArguments) + ")";
var lifetimeConstructor = "(" + $"{ClassName} fallback, " + string.Join(", ", allArguments) + ")";
var lifetimeParameters = "this, " + string.Join(", ", allParameters);

log.Log(LogLevel.Debug, $"Resulting Constructor: {constructor}");
var constructorFields = string.Join("\n\t", allArguments.Select(arg => arg + ";"));
Expand All @@ -361,11 +374,83 @@ public bool TryResolve<T>(out T resolved)
constructorParameters.Count;
yield return Constructor(usingStatements, constructorFields,
constructor, constructorAssignments,
dictSize, interfaceInjectors,
dictSize, interfaceInjectors.Keys,
localizedParameters, requested,
constructorParameters);
yield return Declarations(usingStatements, declarations, disposables);
yield return ArrayDeclarations(usingStatements, arrayDeclarations);
constructorParameters, true, ClassName, lifetimeParameters);
yield return Declarations(usingStatements, declarations, ClassName);
yield return ArrayDeclarations(usingStatements, arrayDeclarations, ClassName);
yield return $@"{usingStatements}
[GeneratedCode(""{ToolName}"", ""{Version}"")]
#nullable disable
public sealed partial class LifetimeScope : IContainer
{{
public ILifetimeScope BeginLifetimeScope()
{{
return m_fallback.BeginLifetimeScope();
}}
private object m_lock = new();
private {ClassName} m_fallback;
private Dictionary<Type,Func<object>> m_lookup;
private readonly List<WeakReference<IDisposable>> resolvedInstances = new();
public T Resolve<T>()
{{
return (T)Resolve(typeof(T));
}}
public object Resolve(Type type)
{{
var instance = m_lookup[type]();
return instance;
}}
public void Dispose()
{{
foreach (var weakReference in resolvedInstances)
{{
if(weakReference.TryGetTarget(out var disposable))
{{
disposable.Dispose();
}}
}}
resolvedInstances.Clear();
}}
public bool TryResolve(Type type, out object resolved)
{{
if(m_lookup.TryGetValue(type, out var factory))
{{
resolved = factory();
return true;
}}
resolved = default;
return false;
}}
public bool TryResolve<T>(out T resolved)
{{
if(m_lookup.TryGetValue(typeof(T), out var factory))
{{
var value = factory();
if(value is T t)
{{
resolved = t;
return true;
}}
}}
resolved = default;
return false;
}}
}}
";
yield return Constructor(usingStatements, constructorFields,
lifetimeConstructor, constructorAssignments,
dictSize, interfaceInjectors.Keys,
localizedParameters, requested,
constructorParameters, false, LifetimeName);
yield return Declarations(usingStatements, scopedDeclarations, LifetimeName);
yield return ArrayDeclarations(usingStatements, arrayDeclarations, LifetimeName);
}

private static void CheckForCycles(ImmutableArray<Injection> dataInjections)
Expand All @@ -380,6 +465,7 @@ private static void CheckForCycles(ImmutableArray<Injection> dataInjections)
tree[iface] = new List<INamedTypeSymbol>();
}
}

foreach (var constructor in injection.Type.Constructors)
{
foreach (var parameter in constructor.Parameters)
Expand All @@ -391,11 +477,12 @@ private static void CheckForCycles(ImmutableArray<Injection> dataInjections)
if (named.TypeArguments.Length != 1) continue;
named = (INamedTypeSymbol) named.TypeArguments[0];
}

foreach (var iface in injection.Interfaces)
{
tree[iface].Add(named);

}

if (tree.TryGetValue(named, out var list))
{
foreach (var iface in injection.Interfaces)
Expand All @@ -404,6 +491,7 @@ private static void CheckForCycles(ImmutableArray<Injection> dataInjections)
}
}
}

if (parameter.Type is IArrayTypeSymbol array)
{
if (array.ElementType is INamedTypeSymbol arrType)
Expand All @@ -424,41 +512,52 @@ private static void CheckForCycles(ImmutableArray<Injection> dataInjections)
}

private static string Constructor(string usingStatements, string constructorFields, string constructor, string constructorAssignments, int dictSize,
Dictionary<INamedTypeSymbol, List<Injection>> interfaceInjectors, List<IParameterSymbol> localizedParameters, List<INamedTypeSymbol> requested,
List<IParameterSymbol> constructorParameters)
IEnumerable<INamedTypeSymbol> interfaceInjectors, List<IParameterSymbol> localizedParameters, List<INamedTypeSymbol> requested,
List<IParameterSymbol> constructorParameters, bool addLifetimeScopeFunction, string className, string? lifetimeParameters = null)
{
var lifetimeScopeFunction = addLifetimeScopeFunction
? $@"
public ILifetimeScope BeginLifetimeScope()
{{
return new {LifetimeName}({lifetimeParameters});
}}"
: string.Empty;
var extraConstruction = addLifetimeScopeFunction ? string.Empty : "m_fallback = fallback;";
return $@"{usingStatements}
public partial class DependencyInjectionContainer
public partial class {className}
{{
{constructorFields}
public DependencyInjectionContainer{constructor}
public {className}{constructor}
{{
{extraConstruction}
{constructorAssignments}
m_lookup = new({dictSize})
{{
{MakeDictionary(interfaceInjectors.Keys)}
{MakeDictionary(interfaceInjectors)}
{MakeDictionary(localizedParameters)}
{MakeDictionary(requested)}
{MakeDictionary(constructorParameters)}
}};
}}
}}
{lifetimeScopeFunction}
}}";
}

private static string ArrayDeclarations(string usingStatements, Dictionary<string, string> arrayDeclarations)
private static string ArrayDeclarations(string usingStatements, Dictionary<string, string> arrayDeclarations, string className)
{
return $@"{usingStatements}
public partial class DependencyInjectionContainer
public partial class {className}
{{
{string.Join("\n\t", arrayDeclarations.Values)}
}}";
}

private static string Declarations(string usingStatements, Dictionary<string, string> declarations, List<Injection> disposables)
private static string Declarations(string usingStatements, Dictionary<string, string> declarations, string className)
{
return $@"{usingStatements}
public partial class DependencyInjectionContainer
public partial class {className}
{{
{string.Join("\n\t", declarations.Values)}
}}";
Expand All @@ -485,7 +584,7 @@ private static void MakeArray(Dictionary<string, string> declarations, string na
if (function)
{
declarations[name] = $@"
private {type}[] {name}()
internal {type}[] {name}()
{{
if (m_{name} != null)
return m_{name};
Expand All @@ -497,12 +596,12 @@ private static void MakeArray(Dictionary<string, string> declarations, string na
return m_{name} = {factoryName};
}}
}}
private {type}[]? m_{name};" + factory;
internal {type}[]? m_{name};" + factory;
}
else
{
declarations[name] = $@"
private {type}[] {name}
internal {type}[] {name}
{{
get
{{
Expand All @@ -517,7 +616,7 @@ private static void MakeArray(Dictionary<string, string> declarations, string na
}}
}}
}}
private {type}[]? m_{name};" + factory;
internal {type}[]? m_{name};" + factory;
}
}

Expand Down
Loading

0 comments on commit da17549

Please sign in to comment.