Skip to content

Commit

Permalink
feat: support nullable as optional dependency
Browse files Browse the repository at this point in the history
  • Loading branch information
skarllot committed Nov 21, 2023
1 parent 53ac7bc commit 6396f1b
Show file tree
Hide file tree
Showing 11 changed files with 165 additions and 10 deletions.
33 changes: 33 additions & 0 deletions src/Jab.FunctionalTests.Common/ConstructorSelectionTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ public void PassesOptionalParametersWhenAvailable()
Assert.NotNull(service.Parameter1);
Assert.Null(service.Parameter2);
Assert.NotNull(service.Parameter3);
Assert.False(typeof(IServiceProvider<IService2>).IsAssignableFrom(typeof(PassesOptionalParametersWhenAvailableContainer)));
}

[ServiceProvider]
Expand All @@ -83,5 +84,37 @@ public void IgnoresNonReferenceTypedParameters()
[Transient(typeof(IService3), typeof(ServiceImplementation))]
[Transient(typeof(IService), typeof(ServiceImplementationWithParameter<IService1, int, IService3>))]
internal partial class IgnoresNonReferenceTypedParametersContainer { }

[Fact]
public void IgnoresNullableParametersWhenNotAvailable()
{
IgnoresNullableParametersWhenNotAvailableContainer c = new();
var service = Assert.IsType<ServiceImplementationWithNullable>(c.GetService<IService>());
Assert.NotNull(service.Parameter1);
Assert.Null(service.Parameter2);
Assert.Empty(service.Parameter3!);
Assert.False(typeof(IServiceProvider<IService2>).IsAssignableFrom(typeof(IgnoresNullableOptionalParametersWhenNotAvailableContainer)));
}

[ServiceProvider]
[Transient(typeof(IService1), typeof(ServiceImplementation))]
[Transient(typeof(IService), typeof(ServiceImplementationWithNullable))]
internal partial class IgnoresNullableParametersWhenNotAvailableContainer { }

[Fact]
public void IgnoresNullableOptionalParametersWhenNotAvailable()
{
IgnoresNullableOptionalParametersWhenNotAvailableContainer c = new();
var service = Assert.IsType<ServiceImplementationWithNullableOptional>(c.GetService<IService>());
Assert.NotNull(service.Parameter1);
Assert.Null(service.Parameter2);
Assert.Empty(service.Parameter3!);
Assert.False(typeof(IServiceProvider<IService2>).IsAssignableFrom(typeof(IgnoresNullableOptionalParametersWhenNotAvailableContainer)));
}

[ServiceProvider]
[Transient(typeof(IService1), typeof(ServiceImplementation))]
[Transient(typeof(IService), typeof(ServiceImplementationWithNullableOptional))]
internal partial class IgnoresNullableOptionalParametersWhenNotAvailableContainer { }
}
}
1 change: 1 addition & 0 deletions src/Jab.FunctionalTests.Common/Jab.FunctionalTest.props
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
<LangVersion>preview</LangVersion>
<RootNamespace>JabTests</RootNamespace>
<IsTestProject Condition="'$(TargetFramework)' == 'netstandard2.0'">false</IsTestProject>
<NoWarn>$(NoWarn);JAB0013;JAB0014</NoWarn>
</PropertyGroup>

<Target Name="_SetProperties">
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
using System.Collections.Generic;

namespace JabTests;

#nullable enable

internal class ServiceImplementationWithNullable : IService
{
public IService1 Parameter1 { get; }
public IService2? Parameter2 { get; }
public IEnumerable<IService3>? Parameter3 { get; }

public ServiceImplementationWithNullable(
IService1 parameter1,
IService2? parameter2,
IEnumerable<IService3>? parameter3)
{
Parameter1 = parameter1;
Parameter2 = parameter2;
Parameter3 = parameter3;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
using System.Collections.Generic;

namespace JabTests;

#nullable enable

internal class ServiceImplementationWithNullableOptional : IService
{
public IService1 Parameter1 { get; }
public IService2? Parameter2 { get; }
public IEnumerable<IService3>? Parameter3 { get; }

public ServiceImplementationWithNullableOptional(
IService1 parameter1,
IService2? parameter2 = null,
IEnumerable<IService3>? parameter3 = null)
{
Parameter1 = parameter1;
Parameter2 = parameter2;
Parameter3 = parameter3;
}
}
41 changes: 41 additions & 0 deletions src/Jab.Tests/DiagnosticsTest.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System.Threading.Tasks;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.Testing;
using Xunit;
using Verify = JabTests.GeneratorAnalyzerVerifier<Jab.ContainerGenerator>;
Expand Down Expand Up @@ -279,5 +280,45 @@ await Verify.VerifyAnalyzerAsync(testCode,
.WithLocation(1)
.WithArguments("IService"));
}

[Fact]
public async Task ProducesJAB0013WhenNullableNonOptionalDependencyNotFound()
{
string testCode = $@"
#nullable enable
interface IDependency {{ }}
class Service {{ public Service(IDependency? dep) {{}} }}
[ServiceProvider]
[{{|#1:Transient(typeof(Service))|}}]
public partial class Container {{}}
";
await Verify.VerifyAnalyzerAsync(testCode,
DiagnosticResult
.CompilerError("JAB0013")
.WithSeverity(DiagnosticSeverity.Warning)
.WithLocation(1)
.WithArguments("IDependency?", "Service"));
}

[Fact]
public async Task ProducesJAB0014WhenNullableNonOptionalDependencyFound()
{
string testCode = $@"
#nullable enable
interface IDependency {{ }}
class Dependency : IDependency {{ }}
class Service {{ public Service(IDependency? dep) {{}} }}
[ServiceProvider]
[{{|#1:Transient(typeof(Service))|}}]
[{{|#2:Transient(typeof(IDependency), typeof(Dependency))|}}]
public partial class Container {{}}
";
await Verify.VerifyAnalyzerAsync(testCode,
DiagnosticResult
.CompilerError("JAB0014")
.WithSeverity(DiagnosticSeverity.Warning)
.WithLocation(1)
.WithArguments("IDependency?", "Service"));
}
}
}
2 changes: 1 addition & 1 deletion src/Jab/CodeWriter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ private void AppendType(INamedTypeSymbol namedTypeSymbol)
{
if (!_typeNameCache.TryGetValue(namedTypeSymbol, out var name))
{
name = _typeNameCache[namedTypeSymbol] = namedTypeSymbol.ToDisplayString();
name = _typeNameCache[namedTypeSymbol] = namedTypeSymbol.ToDisplayString(NullableFlowState.NotNull);
}

AppendRaw(name);
Expand Down
22 changes: 14 additions & 8 deletions src/Jab/ContainerGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,13 @@ private void GenerateCallSiteWithCache(CodeWriter codeWriter, string rootReferen

private void WriteResolutionCall(CodeWriter codeWriter, ServiceCallSite other, string reference)
{
if (other.IsMainImplementation)
if (other is DefaultValueCallSite)
{
codeWriter.Append($"{reference}.GetService<{other.ServiceType}>()");
codeWriter.Append($"default({other.ServiceTypeString})");
}
else if (other.IsMainImplementation)
{
codeWriter.Append($"{reference}.GetService<{other.ServiceTypeString}>()");
}
else
{
Expand Down Expand Up @@ -208,7 +212,7 @@ private void Execute(GeneratorContext context)

foreach (var rootService in root.RootCallSites)
{
var rootServiceType = rootService.ServiceType;
string rootServiceType = rootService.ServiceTypeString;
if (rootService.IsMainImplementation)
{
codeWriter.Append($"{rootServiceType} IServiceProvider<{rootServiceType}>.GetService()");
Expand Down Expand Up @@ -267,7 +271,7 @@ private void Execute(GeneratorContext context)
{
codeWriter.Line($" ||");
}
codeWriter.Append($"typeof({rootService.ServiceType}) == service");
codeWriter.Append($"typeof({rootService.ServiceType.ToDisplayString(NullableFlowState.NotNull)}) == service");
}
if (first)
{
Expand Down Expand Up @@ -296,7 +300,7 @@ private void Execute(GeneratorContext context)

foreach (var rootService in root.RootCallSites)
{
var rootServiceType = rootService.ServiceType;
string rootServiceType = rootService.ServiceTypeString;

using (rootService.IsMainImplementation ?
codeWriter.Scope($"{rootServiceType} IServiceProvider<{rootServiceType}>.GetService()") :
Expand Down Expand Up @@ -358,7 +362,7 @@ private void WriteServiceProvider(CodeWriter codeWriter, ServiceProvider root)
{
if (rootRootCallSite.IsMainImplementation)
{
codeWriter.Append($"if (type == typeof({rootRootCallSite.ServiceType})) return ");
codeWriter.Append($"if (type == typeof({rootRootCallSite.ServiceType.ToDisplayString(NullableFlowState.NotNull)})) return ");
WriteResolutionCall(codeWriter, rootRootCallSite, "this");
codeWriter.Line($";");
}
Expand Down Expand Up @@ -494,7 +498,7 @@ private static void WriteInterfaces(CodeWriter codeWriter, ServiceProvider root,
{
if (serviceCallSite.IsMainImplementation)
{
codeWriter.Line($" IServiceProvider<{serviceCallSite.ServiceType}>,");
codeWriter.Line($" IServiceProvider<{serviceCallSite.ServiceTypeString}>,");
}
}

Expand All @@ -510,7 +514,7 @@ private void WriteCacheLocations(ServiceProvider root, CodeWriter codeWriter, bo
(rootService.Lifetime == ServiceLifetime.Scoped && !isScope) ||
rootService.Lifetime == ServiceLifetime.Transient) continue;

codeWriter.Line($"private {rootService.ImplementationType}? {GetCacheLocation(rootService)};");
codeWriter.Line($"private {rootService.ImplementationType.ToDisplayString(NullableFlowState.NotNull)}? {GetCacheLocation(rootService)};");
}
codeWriter.Line();
}
Expand Down Expand Up @@ -589,6 +593,8 @@ public override void Initialize(AnalysisContext context)
DiagnosticDescriptors.NoServiceTypeRegistered,
DiagnosticDescriptors.ImplementationTypeAndFactoryNotAllowed,
DiagnosticDescriptors.FactoryMemberMustBeAMethodOrHaveDelegateType,
DiagnosticDescriptors.NullableServiceNotRegistered,
DiagnosticDescriptors.NullableServiceRegistered,
}.ToImmutableArray();

private static string ReadAttributesFile()
Expand Down
10 changes: 10 additions & 0 deletions src/Jab/DefaultValueCallSite.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
namespace Jab;

internal record DefaultValueCallSite: ServiceCallSite
{
public DefaultValueCallSite(ITypeSymbol serviceType) : base(serviceType, serviceType, ServiceLifetime.Transient, 0, false)
{
}

public override string ServiceTypeString => ServiceType.ToDisplayString(NullableFlowState.MaybeNull);
}
8 changes: 8 additions & 0 deletions src/Jab/DiagnosticDescriptors.cs
Original file line number Diff line number Diff line change
Expand Up @@ -50,4 +50,12 @@ internal static class DiagnosticDescriptors
"The factory member has to be a method or have a delegate type",
"The factory member '{0}' has to be a method of have a delegate type, for service '{1}'", "Usage", DiagnosticSeverity.Error, true);

public static readonly DiagnosticDescriptor NullableServiceNotRegistered = new("JAB0013",
"Not registered nullable dependency without a default value",
"'{0}' parameter to construct '{1}' will always be null when constructing using a service provider. Add a default value to make the service reference optional", "Usage", DiagnosticSeverity.Warning, true);

public static readonly DiagnosticDescriptor NullableServiceRegistered = new("JAB0014",
"Nullable dependency without a default value",
"'{0}' parameter to construct '{1}' will never be null when constructing using a service provider. Add a default value to make the service reference optional", "Usage", DiagnosticSeverity.Warning, true);

}
1 change: 1 addition & 0 deletions src/Jab/ServiceCallSite.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@ internal abstract record ServiceCallSite(ITypeSymbol ServiceType, ITypeSymbol Im
public int ReverseIndex { get; } = ReverseIndex;
public bool? IsDisposable { get; } = IsDisposable;
public bool IsMainImplementation => ReverseIndex == 0;
public virtual string ServiceTypeString => ServiceType.ToDisplayString(NullableFlowState.NotNull);
}
13 changes: 12 additions & 1 deletion src/Jab/ServiceProviderBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -690,7 +690,18 @@ private ServiceCallSite CreateConstructorCallSite(
}
else
{
if (parameterCallSite == null)
if (parameterSymbol.Type.NullableAnnotation == NullableAnnotation.Annotated)
{
var diagnostic = Diagnostic.Create(
parameterCallSite is null ? DiagnosticDescriptors.NullableServiceNotRegistered : DiagnosticDescriptors.NullableServiceRegistered,
registrationLocation,
parameterSymbol.ToDisplayString(SymbolDisplayFormat.CSharpErrorMessageFormat),
implementationType.ToDisplayString(SymbolDisplayFormat.CSharpErrorMessageFormat));

_context.ReportDiagnostic(diagnostic);
callSites.Add(parameterCallSite ?? new DefaultValueCallSite(parameterSymbol.Type));
}
else if (parameterCallSite == null)
{
var diagnostic = Diagnostic.Create(DiagnosticDescriptors.ServiceRequiredToConstructNotRegistered,
registrationLocation,
Expand Down

0 comments on commit 6396f1b

Please sign in to comment.