diff --git a/src/Analyzers/Activities/MatchingInputOutputTypeActivityAnalyzer.cs b/src/Analyzers/Activities/MatchingInputOutputTypeActivityAnalyzer.cs new file mode 100644 index 00000000..d749df24 --- /dev/null +++ b/src/Analyzers/Activities/MatchingInputOutputTypeActivityAnalyzer.cs @@ -0,0 +1,352 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Collections.Concurrent; +using System.Collections.Immutable; +using System.Diagnostics; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.Diagnostics; +using Microsoft.CodeAnalysis.Operations; + +namespace Microsoft.DurableTask.Analyzers.Activities; + +/// +/// Analyzer that checks for mismatches between the input and output types of Activities invocations and their definitions. +/// +[DiagnosticAnalyzer(LanguageNames.CSharp)] +public class MatchingInputOutputTypeActivityAnalyzer : DiagnosticAnalyzer +{ + /// + /// The diagnostic ID for the diagnostic that reports when the input argument type of an Activity invocation does not match the input parameter type of the Activity definition. + /// + public const string InputArgumentTypeMismatchDiagnosticId = "DURABLE2001"; + + /// + /// The diagnostic ID for the diagnostic that reports when the output argument type of an Activity invocation does not match the return type of the Activity definition. + /// + public const string OutputArgumentTypeMismatchDiagnosticId = "DURABLE2002"; + + static readonly LocalizableString InputArgumentTypeMismatchTitle = new LocalizableResourceString(nameof(Resources.InputArgumentTypeMismatchAnalyzerTitle), Resources.ResourceManager, typeof(Resources)); + static readonly LocalizableString InputArgumentTypeMismatchMessageFormat = new LocalizableResourceString(nameof(Resources.InputArgumentTypeMismatchAnalyzerMessageFormat), Resources.ResourceManager, typeof(Resources)); + + static readonly LocalizableString OutputArgumentTypeMismatchTitle = new LocalizableResourceString(nameof(Resources.OutputArgumentTypeMismatchAnalyzerTitle), Resources.ResourceManager, typeof(Resources)); + static readonly LocalizableString OutputArgumentTypeMismatchMessageFormat = new LocalizableResourceString(nameof(Resources.OutputArgumentTypeMismatchAnalyzerMessageFormat), Resources.ResourceManager, typeof(Resources)); + + static readonly DiagnosticDescriptor InputArgumentTypeMismatchRule = new( + InputArgumentTypeMismatchDiagnosticId, + InputArgumentTypeMismatchTitle, + InputArgumentTypeMismatchMessageFormat, + AnalyzersCategories.Activity, + DiagnosticSeverity.Warning, + customTags: [WellKnownDiagnosticTags.CompilationEnd], + isEnabledByDefault: true); + + static readonly DiagnosticDescriptor OutputArgumentTypeMismatchRule = new( + OutputArgumentTypeMismatchDiagnosticId, + OutputArgumentTypeMismatchTitle, + OutputArgumentTypeMismatchMessageFormat, + AnalyzersCategories.Activity, + DiagnosticSeverity.Warning, + customTags: [WellKnownDiagnosticTags.CompilationEnd], + isEnabledByDefault: true); + + /// + public override ImmutableArray SupportedDiagnostics => [InputArgumentTypeMismatchRule, OutputArgumentTypeMismatchRule]; + + /// + public override void Initialize(AnalysisContext context) + { + context.EnableConcurrentExecution(); + context.ConfigureGeneratedCodeAnalysis(GeneratedCodeAnalysisFlags.None); + + context.RegisterCompilationStartAction(context => + { + KnownTypeSymbols knownSymbols = new(context.Compilation); + + if (knownSymbols.ActivityTriggerAttribute == null || knownSymbols.FunctionNameAttribute == null || + knownSymbols.DurableTaskRegistry == null || knownSymbols.TaskActivityBase == null || + knownSymbols.Task == null || knownSymbols.TaskT == null) + { + // symbols not available in this compilation, skip analysis + return; + } + + IMethodSymbol taskActivityRunAsync = knownSymbols.TaskActivityBase.GetMembers("RunAsync").OfType().Single(); + INamedTypeSymbol voidSymbol = context.Compilation.GetSpecialType(SpecialType.System_Void); + + // Search for Activity invocations + ConcurrentBag invocations = []; + context.RegisterOperationAction( + ctx => + { + ctx.CancellationToken.ThrowIfCancellationRequested(); + + if (ctx.Operation is not IInvocationOperation invocationOperation) + { + return; + } + + IMethodSymbol targetMethod = invocationOperation.TargetMethod; + if (!targetMethod.IsEqualTo(knownSymbols.TaskOrchestrationContext, "CallActivityAsync")) + { + return; + } + + Debug.Assert(invocationOperation.Arguments.Length is 2 or 3, "CallActivityAsync has 2 or 3 parameters"); + Debug.Assert(invocationOperation.Arguments[0].Parameter?.Name == "name", "First parameter of CallActivityAsync is name"); + IArgumentOperation activityNameArgumentOperation = invocationOperation.Arguments[0]; + + // extracts the constant value from the argument (e.g.: it can be a nameof, string literal or const field) + Optional constant = ctx.Operation.SemanticModel!.GetConstantValue(activityNameArgumentOperation.Value.Syntax); + if (!constant.HasValue) + { + // not a constant value, we cannot correlate this invocation to an existent activity in compile time + return; + } + + string activityName = constant.Value!.ToString(); + + // Try to extract the input argument from the invocation + ITypeSymbol? inputType = null; + IArgumentOperation? inputArgumentParameter = invocationOperation.Arguments.SingleOrDefault(a => a.Parameter?.Name == "input"); + if (inputArgumentParameter != null && inputArgumentParameter.ArgumentKind != ArgumentKind.DefaultValue) + { + // if the argument is not null or a default value provided by the compiler, get the type before the conversion to object + TypeInfo inputTypeInfo = ctx.Operation.SemanticModel.GetTypeInfo(inputArgumentParameter.Value.Syntax, ctx.CancellationToken); + inputType = inputTypeInfo.Type; + } + + // If the CallActivityAsync method is used, we extract the output type from TypeArguments + ITypeSymbol? outputType = targetMethod.OriginalDefinition.Arity == 1 && targetMethod.TypeArguments.Length == 1 ? + targetMethod.TypeArguments[0] : null; + + invocations.Add(new ActivityInvocation() + { + Name = activityName, + InputType = inputType, + OutputType = outputType, + InvocationSyntaxNode = invocationOperation.Syntax, + }); + }, + OperationKind.Invocation); + + // Search for Durable Functions Activities definitions + ConcurrentBag activities = []; + context.RegisterSymbolAction( + ctx => + { + ctx.CancellationToken.ThrowIfCancellationRequested(); + + if (ctx.Symbol is not IMethodSymbol methodSymbol) + { + return; + } + + if (!methodSymbol.ContainsAttributeInAnyMethodArguments(knownSymbols.ActivityTriggerAttribute)) + { + return; + } + + if (!methodSymbol.TryGetSingleValueFromAttribute(knownSymbols.FunctionNameAttribute, out string functionName)) + { + return; + } + + IParameterSymbol? inputParam = methodSymbol.Parameters.SingleOrDefault( + p => p.GetAttributes().Any(a => knownSymbols.ActivityTriggerAttribute.Equals(a.AttributeClass, SymbolEqualityComparer.Default))); + if (inputParam == null) + { + // Azure Functions Activity methods must have an input parameter + return; + } + + ITypeSymbol? inputType = inputParam.Type; + + ITypeSymbol? outputType = methodSymbol.ReturnType; + if (outputType.Equals(voidSymbol, SymbolEqualityComparer.Default) || + outputType.Equals(knownSymbols.Task, SymbolEqualityComparer.Default)) + { + // If the method returns void or Task, we consider it as having no output + outputType = null; + } + else if (outputType.OriginalDefinition.Equals(knownSymbols.TaskT, SymbolEqualityComparer.Default) && + outputType is INamedTypeSymbol outputNamedType) + { + // If the method is Task, we consider T as the output type + Debug.Assert(outputNamedType.TypeArguments.Length == 1, "Task has one type argument"); + outputType = outputNamedType.TypeArguments[0]; + } + + activities.Add(new ActivityDefinition() + { + Name = functionName, + InputType = inputType, + OutputType = outputType, + }); + }, + SymbolKind.Method); + + // Search for TaskActivity definitions + context.RegisterSyntaxNodeAction( + ctx => + { + ctx.CancellationToken.ThrowIfCancellationRequested(); + + if (ctx.ContainingSymbol is not INamedTypeSymbol classSymbol) + { + return; + } + + if (classSymbol.IsAbstract) + { + return; + } + + // Check if the class has a method that overrides TaskActivity.RunAsync + IMethodSymbol? methodOverridingRunAsync = null; + INamedTypeSymbol? baseType = classSymbol; // start from the current class + while (baseType != null) + { + foreach (IMethodSymbol method in baseType.GetMembers().OfType()) + { + if (SymbolEqualityComparer.Default.Equals(method.OverriddenMethod?.OriginalDefinition, taskActivityRunAsync)) + { + methodOverridingRunAsync = method.OverriddenMethod; + break; + } + } + + baseType = baseType.BaseType; + } + + // TaskActivity.RunAsync method not found in the class hierarchy + if (methodOverridingRunAsync == null) + { + return; + } + + // gets the closed constructed TaskActivity type, so we can extract TInput and TOutput + INamedTypeSymbol closedConstructedTaskActivity = methodOverridingRunAsync.ContainingType; + Debug.Assert(closedConstructedTaskActivity.TypeArguments.Length == 2, "TaskActivity has TInput and TOutput"); + + activities.Add(new ActivityDefinition() + { + Name = classSymbol.Name, + InputType = closedConstructedTaskActivity.TypeArguments[0], + OutputType = closedConstructedTaskActivity.TypeArguments[1], + }); + }, + SyntaxKind.ClassDeclaration); + + // Search for Func/Action activities directly registered through DurableTaskRegistry + context.RegisterOperationAction( + ctx => + { + ctx.CancellationToken.ThrowIfCancellationRequested(); + + if (ctx.Operation is not IInvocationOperation invocation) + { + return; + } + + if (!SymbolEqualityComparer.Default.Equals(invocation.Type, knownSymbols.DurableTaskRegistry)) + { + return; + } + + // there are 8 AddActivityFunc overloads, with combinations of Activity Name, TInput and TOutput + if (invocation.TargetMethod.Name != "AddActivityFunc") + { + return; + } + + // all overloads have the parameter 'name', either as an Action or a Func + IArgumentOperation? activityNameArgumentOperation = invocation.Arguments.SingleOrDefault(a => a.Parameter!.Name == "name"); + if (activityNameArgumentOperation == null) + { + return; + } + + // extracts the constant value from the argument (e.g.: it can be a nameof, string literal or const field) + Optional constant = ctx.Operation.SemanticModel!.GetConstantValue(activityNameArgumentOperation.Value.Syntax); + if (!constant.HasValue) + { + // not a constant value, we cannot correlate this invocation to an existent activity in compile time + return; + } + + string activityName = constant.Value!.ToString(); + + ITypeSymbol? inputType = invocation.TargetMethod.GetTypeArgumentByParameterName("TInput"); + ITypeSymbol? outputType = invocation.TargetMethod.GetTypeArgumentByParameterName("TOutput"); + + activities.Add(new ActivityDefinition() + { + Name = activityName, + InputType = inputType, + OutputType = outputType, + }); + }, + OperationKind.Invocation); + + // At the end of the compilation, we correlate the invocations with the definitions + context.RegisterCompilationEndAction(ctx => + { + // index by name for faster lookup + Dictionary activitiesByName = activities.ToDictionary(a => a.Name, a => a); + + foreach (ActivityInvocation invocation in invocations) + { + if (!activitiesByName.TryGetValue(invocation.Name, out ActivityDefinition activity)) + { + // Activity not found, we cannot correlate this invocation to an existent activity in compile time. + // We could add a diagnostic here if we want to enforce that, but while we experiment with this analyzer, + // we should prevent false positives. + continue; + } + + if (!SymbolEqualityComparer.Default.Equals(invocation.InputType, activity.InputType)) + { + string actual = invocation.InputType?.ToDisplayString(SymbolDisplayFormat.CSharpShortErrorMessageFormat) ?? "none"; + string expected = activity.InputType?.ToDisplayString(SymbolDisplayFormat.CSharpShortErrorMessageFormat) ?? "none"; + string activityName = invocation.Name; + + Diagnostic diagnostic = RoslynExtensions.BuildDiagnostic(InputArgumentTypeMismatchRule, invocation.InvocationSyntaxNode, actual, expected, activityName); + ctx.ReportDiagnostic(diagnostic); + } + + if (!SymbolEqualityComparer.Default.Equals(invocation.OutputType, activity.OutputType)) + { + string actual = invocation.OutputType?.ToDisplayString(SymbolDisplayFormat.CSharpShortErrorMessageFormat) ?? "none"; + string expected = activity.OutputType?.ToDisplayString(SymbolDisplayFormat.CSharpShortErrorMessageFormat) ?? "none"; + string activityName = invocation.Name; + + Diagnostic diagnostic = RoslynExtensions.BuildDiagnostic(OutputArgumentTypeMismatchRule, invocation.InvocationSyntaxNode, actual, expected, activityName); + ctx.ReportDiagnostic(diagnostic); + } + } + }); + }); + } + + struct ActivityInvocation + { + public string Name { get; set; } + + public ITypeSymbol? InputType { get; set; } + + public ITypeSymbol? OutputType { get; set; } + + public SyntaxNode InvocationSyntaxNode { get; set; } + } + + struct ActivityDefinition + { + public string Name { get; set; } + + public ITypeSymbol? InputType { get; set; } + + public ITypeSymbol? OutputType { get; set; } + } +} diff --git a/src/Analyzers/AnalyzerReleases.Unshipped.md b/src/Analyzers/AnalyzerReleases.Unshipped.md index c2109dea..c9ecda5e 100644 --- a/src/Analyzers/AnalyzerReleases.Unshipped.md +++ b/src/Analyzers/AnalyzerReleases.Unshipped.md @@ -15,4 +15,6 @@ DURABLE0007 | Orchestration | Warning | CancellationTokenOrchestrationAnalyzer DURABLE0008 | Orchestration | Warning | OtherBindingsOrchestrationAnalyzer DURABLE1001 | Attribute Binding | Error | OrchestrationTriggerBindingAnalyzer DURABLE1002 | Attribute Binding | Error | DurableClientBindingAnalyzer -DURABLE1003 | Attribute Binding | Error | EntityTriggerBindingAnalyzer \ No newline at end of file +DURABLE1003 | Attribute Binding | Error | EntityTriggerBindingAnalyzer +DURABLE2001 | Activity | Warning | MatchingInputOutputTypeActivityAnalyzer +DURABLE2002 | Activity | Warning | MatchingInputOutputTypeActivityAnalyzer \ No newline at end of file diff --git a/src/Analyzers/AnalyzersCategories.cs b/src/Analyzers/AnalyzersCategories.cs index 8bdae885..4e57086c 100644 --- a/src/Analyzers/AnalyzersCategories.cs +++ b/src/Analyzers/AnalyzersCategories.cs @@ -17,4 +17,9 @@ static class AnalyzersCategories /// The category for the attribute binding related analyzers. /// public const string AttributeBinding = "Attribute Binding"; + + /// + /// The category for the activity related analyzers. + /// + public const string Activity = "Activity"; } diff --git a/src/Analyzers/KnownTypeSymbols.Durable.cs b/src/Analyzers/KnownTypeSymbols.Durable.cs index 454b8d6d..41d0719d 100644 --- a/src/Analyzers/KnownTypeSymbols.Durable.cs +++ b/src/Analyzers/KnownTypeSymbols.Durable.cs @@ -14,6 +14,7 @@ namespace Microsoft.DurableTask.Analyzers; public sealed partial class KnownTypeSymbols { INamedTypeSymbol? taskOrchestratorInterface; + INamedTypeSymbol? taskActivityBase; INamedTypeSymbol? durableTaskRegistry; INamedTypeSymbol? taskOrchestrationContext; INamedTypeSymbol? durableTaskClient; @@ -23,6 +24,11 @@ public sealed partial class KnownTypeSymbols /// public INamedTypeSymbol? TaskOrchestratorInterface => this.GetOrResolveFullyQualifiedType("Microsoft.DurableTask.ITaskOrchestrator", ref this.taskOrchestratorInterface); + /// + /// Gets a TaskActivity<TInput,TOutput> type symbol. + /// + public INamedTypeSymbol? TaskActivityBase => this.GetOrResolveFullyQualifiedType("Microsoft.DurableTask.TaskActivity`2", ref this.taskActivityBase); + /// /// Gets a DurableTaskRegistry type symbol. /// diff --git a/src/Analyzers/KnownTypeSymbols.Functions.cs b/src/Analyzers/KnownTypeSymbols.Functions.cs index a0c106ba..cd217b58 100644 --- a/src/Analyzers/KnownTypeSymbols.Functions.cs +++ b/src/Analyzers/KnownTypeSymbols.Functions.cs @@ -16,6 +16,7 @@ public sealed partial class KnownTypeSymbols INamedTypeSymbol? functionOrchestrationAttribute; INamedTypeSymbol? functionNameAttribute; INamedTypeSymbol? durableClientAttribute; + INamedTypeSymbol? activityTriggerAttribute; INamedTypeSymbol? entityTriggerAttribute; INamedTypeSymbol? taskEntityDispatcher; @@ -34,6 +35,11 @@ public sealed partial class KnownTypeSymbols /// public INamedTypeSymbol? DurableClientAttribute => this.GetOrResolveFullyQualifiedType("Microsoft.Azure.Functions.Worker.DurableClientAttribute", ref this.durableClientAttribute); + /// + /// Gets an ActivityTriggerAttribute type symbol. + /// + public INamedTypeSymbol? ActivityTriggerAttribute => this.GetOrResolveFullyQualifiedType("Microsoft.Azure.Functions.Worker.ActivityTriggerAttribute", ref this.activityTriggerAttribute); + /// /// Gets an EntityTriggerAttribute type symbol. /// diff --git a/src/Analyzers/Resources.resx b/src/Analyzers/Resources.resx index 6af3d76f..b8c51631 100644 --- a/src/Analyzers/Resources.resx +++ b/src/Analyzers/Resources.resx @@ -183,4 +183,16 @@ Thread and Task calls must be deterministic inside an orchestrator function + + CallActivityAsync is passing the incorrect type '{0}' instead of '{1}' to the activity function '{2}' + + + Activity function calls use the wrong argument type + + + CallActivityAsync is expecting the return type '{0}' and that does not match the return type '{1}' of the activity function '{2}' + + + Activity function call return type doesn't match the function definition return type + \ No newline at end of file diff --git a/src/Analyzers/RoslynExtensions.cs b/src/Analyzers/RoslynExtensions.cs index e1920524..10339f7a 100644 --- a/src/Analyzers/RoslynExtensions.cs +++ b/src/Analyzers/RoslynExtensions.cs @@ -2,6 +2,7 @@ // Licensed under the MIT License. using System.Collections.Immutable; +using System.Diagnostics; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp.Syntax; using Microsoft.CodeAnalysis.Diagnostics; @@ -85,6 +86,28 @@ public static bool BaseTypeIsConstructedFrom(this INamedTypeSymbol symbol, IType return methods.FirstOrDefault(m => m.OverriddenMethod != null && m.OverriddenMethod.OriginalDefinition.Equals(methodSymbol, SymbolEqualityComparer.Default)); } + /// + /// Gets the type argument of a method by its parameter name. + /// + /// Method symbol. + /// Type argument name. + /// The type argument symbol. + public static ITypeSymbol? GetTypeArgumentByParameterName(this IMethodSymbol method, string parameterName) + { + (ITypeParameterSymbol param, int idx) = method.TypeParameters + .Where(t => t.Name == parameterName) + .Select((t, i) => (t, i)) + .SingleOrDefault(); + + if (param != null) + { + Debug.Assert(idx >= 0, "parameter index is not negative"); + return method.TypeArguments[idx]; + } + + return null; + } + /// /// Gets the syntax nodes of a method symbol. /// diff --git a/test/Analyzers.Tests/Activities/MatchingInputOutputTypeActivityAnalyzerTests.cs b/test/Analyzers.Tests/Activities/MatchingInputOutputTypeActivityAnalyzerTests.cs new file mode 100644 index 00000000..2970f6ad --- /dev/null +++ b/test/Analyzers.Tests/Activities/MatchingInputOutputTypeActivityAnalyzerTests.cs @@ -0,0 +1,419 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using Microsoft.CodeAnalysis.Testing; +using Microsoft.DurableTask.Analyzers.Activities; +using VerifyCS = Microsoft.DurableTask.Analyzers.Tests.Verifiers.CSharpAnalyzerVerifier; + +namespace Microsoft.DurableTask.Analyzers.Tests.Activities; + +public class MatchingInputOutputTypeActivityAnalyzerTests +{ + [Fact] + public async Task DurableFunctionActivityInvocationWithMatchingInputType() + { + string code = Wrapper.WrapDurableFunctionOrchestration(@" +async Task Method(TaskOrchestrationContext context) +{ + await context.CallActivityAsync(nameof(SayHello), ""Tokyo""); +} + +[Function(nameof(SayHello))] +void SayHello([ActivityTrigger] string name) +{ +} +"); + + await VerifyCS.VerifyDurableTaskAnalyzerAsync(code); + } + + [Fact] + public async Task DurableFunctionActivityInvocationWithMismatchedInputType() + { + string code = Wrapper.WrapDurableFunctionOrchestration(@" +async Task Method(TaskOrchestrationContext context) +{ + await {|#0:context.CallActivityAsync(nameof(SayHello), 123456)|}; +} + +[Function(nameof(SayHello))] +void SayHello([ActivityTrigger] string name) +{ +} +"); + DiagnosticResult expected = BuildInputDiagnostic().WithLocation(0).WithArguments("int", "string", "SayHello"); + + await VerifyCS.VerifyDurableTaskAnalyzerAsync(code, expected); + } + + [Fact] + public async Task DurableFunctionActivityInvocationWithMissingInputType() + { + string code = Wrapper.WrapDurableFunctionOrchestration(@" +async Task Method(TaskOrchestrationContext context) +{ + await {|#0:context.CallActivityAsync(nameof(SayHello))|}; +} + +[Function(nameof(SayHello))] +void SayHello([ActivityTrigger] string name) +{ +} +"); + DiagnosticResult expected = BuildInputDiagnostic().WithLocation(0).WithArguments("none", "string", "SayHello"); + + await VerifyCS.VerifyDurableTaskAnalyzerAsync(code, expected); + } + + [Fact] + public async Task DurableFunctionActivityInvocationWithMatchingOutputType() + { + string code = Wrapper.WrapDurableFunctionOrchestration(@" +async Task Method(TaskOrchestrationContext context) +{ + int output = await {|#0:context.CallActivityAsync(nameof(SayHello), ""Tokyo"")|}; +} + +[Function(nameof(SayHello))] +int SayHello([ActivityTrigger] string name) +{ + return 42; +} +"); + + await VerifyCS.VerifyDurableTaskAnalyzerAsync(code); + } + + [Fact] + public async Task DurableFunctionActivityInvocationWithMatchingTaskTOutputType() + { + string code = Wrapper.WrapDurableFunctionOrchestration(@" +async Task Method(TaskOrchestrationContext context) +{ + int output = await {|#0:context.CallActivityAsync(nameof(SayHello), ""Tokyo"")|}; +} + +[Function(nameof(SayHello))] +Task SayHello([ActivityTrigger] string name) +{ + return Task.FromResult(42); +} +"); + + await VerifyCS.VerifyDurableTaskAnalyzerAsync(code); + } + + [Fact] + public async Task DurableFunctionActivityInvocationWithMatchingVoidOutputType() + { + string code = Wrapper.WrapDurableFunctionOrchestration(@" +async Task Method(TaskOrchestrationContext context) +{ + await {|#0:context.CallActivityAsync(nameof(SayHello), ""Tokyo"")|}; +} + +[Function(nameof(SayHello))] +void SayHello([ActivityTrigger] string name) +{ +} +"); + + await VerifyCS.VerifyDurableTaskAnalyzerAsync(code); + } + + [Fact] + public async Task DurableFunctionActivityInvocationWithMatchingTaskOutputType() + { + string code = Wrapper.WrapDurableFunctionOrchestration(@" +async Task Method(TaskOrchestrationContext context) +{ + await {|#0:context.CallActivityAsync(nameof(SayHello), ""Tokyo"")|}; +} + +[Function(nameof(SayHello))] +Task SayHello([ActivityTrigger] string name) +{ + return Task.CompletedTask; +} +"); + + await VerifyCS.VerifyDurableTaskAnalyzerAsync(code); + } + + [Fact] + public async Task DurableFunctionActivityInvocationWithMismatchedOutputType() + { + string code = Wrapper.WrapDurableFunctionOrchestration(@" +async Task Method(TaskOrchestrationContext context) +{ + string output = await {|#0:context.CallActivityAsync(nameof(SayHello), ""Tokyo"")|}; +} + +[Function(nameof(SayHello))] +int SayHello([ActivityTrigger] string name) +{ + return 42; +} +"); + + DiagnosticResult expected = BuildOutputDiagnostic().WithLocation(0).WithArguments("string", "int", "SayHello"); + + await VerifyCS.VerifyDurableTaskAnalyzerAsync(code, expected); + } + + + [Fact] + public async Task TaskActivityInvocationWithMatchingInputTypeAndOutputType() + { + string code = Wrapper.WrapTaskOrchestrator(@" +public class Caller { + async Task Method(TaskOrchestrationContext context) + { + await context.CallActivityAsync(nameof(MyActivity), ""Tokyo""); + } +} + +public class MyActivity : TaskActivity +{ + public override Task RunAsync(TaskActivityContext context, string cityName) + { + return Task.FromResult(cityName); + } +} +"); + + await VerifyCS.VerifyDurableTaskAnalyzerAsync(code); + } + + [Fact] + public async Task TaskActivityInvocationWithMismatchedInputType() + { + string code = Wrapper.WrapTaskOrchestrator(@" +public class Caller { + async Task Method(TaskOrchestrationContext context) + { + await {|#0:context.CallActivityAsync(nameof(MyActivity), ""Tokyo"")|}; + } +} + +public class MyActivity : TaskActivity +{ + public override Task RunAsync(TaskActivityContext context, int cityCode) + { + return Task.FromResult(cityCode.ToString()); + } +} +"); + + DiagnosticResult expected = BuildInputDiagnostic().WithLocation(0).WithArguments("string", "int", "MyActivity"); + + await VerifyCS.VerifyDurableTaskAnalyzerAsync(code, expected); + } + + [Fact] + public async Task TaskActivityInvocationWithMismatchedOutputType() + { + string code = Wrapper.WrapTaskOrchestrator(@" +public class Caller { + async Task Method(TaskOrchestrationContext context) + { + await {|#0:context.CallActivityAsync(nameof(MyActivity), ""Tokyo"")|}; + } +} + +public class MyActivity : TaskActivity +{ + public override Task RunAsync(TaskActivityContext context, string city) + { + return Task.FromResult(city.Length); + } +} +"); + + DiagnosticResult expected = BuildOutputDiagnostic().WithLocation(0).WithArguments("string", "int", "MyActivity"); + + await VerifyCS.VerifyDurableTaskAnalyzerAsync(code, expected); + } + + [Fact] + public async Task TaskActivityInvocationWithOneTypeParameterDefinedInAbstractClass() + { + string code = Wrapper.WrapTaskOrchestrator(@" +public class Caller { + async Task Method(TaskOrchestrationContext context) + { + await context.CallActivityAsync(nameof(AnotherActivity), 5); + } +} + +public class AnotherActivity : EchoActivity { } + +public abstract class EchoActivity : TaskActivity +{ + public override Task RunAsync(TaskActivityContext context, T input) + { + return Task.FromResult(input); + } +} +"); + + await VerifyCS.VerifyDurableTaskAnalyzerAsync(code); + } + + + [Fact] + public async Task LambdaActivityInvocationWithMatchingInputType() + { + string code = Wrapper.WrapFuncOrchestrator(@" +tasks.AddOrchestratorFunc(""HelloSequence"", async context => + await context.CallActivityAsync(""SayHello"", ""Tokyo"")); + +tasks.AddActivityFunc(""SayHello"", (context, city) => { }); +"); + + await VerifyCS.VerifyDurableTaskAnalyzerAsync(code); + } + + [Fact] + public async Task LambdaActivityInvocationWithMatchingNoInputTypeAndNoOutputType() + { + string code = Wrapper.WrapFuncOrchestrator(@" +tasks.AddOrchestratorFunc(""HelloSequence"", async context => + await context.CallActivityAsync(""SayHello"")); + +tasks.AddActivityFunc(""SayHello"", (context) => { }); +"); + + await VerifyCS.VerifyDurableTaskAnalyzerAsync(code); + } + + [Fact] + public async Task LambdaActivityInvocationWithMismatchedInputType() + { + string code = Wrapper.WrapFuncOrchestrator(@" +tasks.AddOrchestratorFunc(""HelloSequence"", async context => + await {|#0:context.CallActivityAsync(""SayHello"", 42)|}); + +tasks.AddActivityFunc(""SayHello"", (context, city) => { }); +"); + + DiagnosticResult expected = BuildInputDiagnostic().WithLocation(0).WithArguments("int", "string", "SayHello"); + + await VerifyCS.VerifyDurableTaskAnalyzerAsync(code, expected); + } + + [Fact] + public async Task LambdaActivityInvocationWithMismatchedNoInputType() + { + string code = Wrapper.WrapFuncOrchestrator(@" +tasks.AddOrchestratorFunc(""HelloSequence"", async context => + await {|#0:context.CallActivityAsync(""SayHello"", ""Tokyo"")|}); + +tasks.AddActivityFunc(""SayHello"", (context) => { }); +"); + + DiagnosticResult expected = BuildInputDiagnostic().WithLocation(0).WithArguments("string", "none", "SayHello"); + + await VerifyCS.VerifyDurableTaskAnalyzerAsync(code, expected); + } + + [Fact] + public async Task LambdaActivityInvocationWithMatchingOutputType() + { + string code = Wrapper.WrapFuncOrchestrator(@" +tasks.AddOrchestratorFunc(""HelloSequence"", async context => + await context.CallActivityAsync(""SayHello"")); + +tasks.AddActivityFunc(""SayHello"", (context) => ""hello""); +"); + + await VerifyCS.VerifyDurableTaskAnalyzerAsync(code); + } + + [Fact] + public async Task LambdaActivityInvocationWithMismatchedOutputType() + { + string code = Wrapper.WrapFuncOrchestrator(@" +tasks.AddOrchestratorFunc(""HelloSequence"", async context => + await {|#0:context.CallActivityAsync(""SayHello"")|}); + +tasks.AddActivityFunc(""SayHello"", (context) => ""hello""); +"); + + DiagnosticResult expected = BuildOutputDiagnostic().WithLocation(0).WithArguments("int", "string", "SayHello"); + + await VerifyCS.VerifyDurableTaskAnalyzerAsync(code, expected); + } + + [Fact] + public async Task LambdaActivityInvocationWithMismatchedNoOutputType() + { + string code = Wrapper.WrapFuncOrchestrator(@" +tasks.AddOrchestratorFunc(""HelloSequence"", async context => + await {|#0:context.CallActivityAsync(""SayHello"")|}); + +tasks.AddActivityFunc(""SayHello"", (context) => { }); +"); + + DiagnosticResult expected = BuildOutputDiagnostic().WithLocation(0).WithArguments("int", "none", "SayHello"); + + await VerifyCS.VerifyDurableTaskAnalyzerAsync(code, expected); + } + + + [Fact] + public async Task ActivityInvocationWithConstantNameIsDiscovered() + { + string code = Wrapper.WrapDurableFunctionOrchestration(@" +const string name = ""SayHello""; + +async Task Method(TaskOrchestrationContext context) +{ + // the ones containing the output mismatch diagnostic mean they were discovered + await {|#0:context.CallActivityAsync(""SayHello"", ""Tokyo"")|}; + await {|#1:context.CallActivityAsync(nameof(SayHello), ""Tokyo"")|}; + await {|#2:context.CallActivityAsync(name, ""Tokyo"")|}; + + // not diagnostics here, because the name could not be determined (since it is not a constant) + string anotherName = ""SayHello""; + await context.CallActivityAsync(anotherName, ""Tokyo""); +} + +[Function(nameof(SayHello))] +int SayHello([ActivityTrigger] string name) => 42; +"); + + DiagnosticResult[] expected = Enumerable.Range(0, 3).Select(i => + BuildOutputDiagnostic().WithLocation(i).WithArguments("string", "int", "SayHello")).ToArray(); + + await VerifyCS.VerifyDurableTaskAnalyzerAsync(code, expected); + } + + [Fact] + public async Task ActivityInvocationWithNonExistentActivity() + { + // When the Activity is not found, we cannot correlate this invocation to an existent activity in compile time + // or it is defined in another assembly. We could add a diagnostic here if we want to enforce that, + // but while we experiment with this analyzer, we will not report a diagnostic to prevent false positives. + + string code = Wrapper.WrapDurableFunctionOrchestration(@" +async Task Method(TaskOrchestrationContext context) +{ + await context.CallActivityAsync(""ActivityNotFound"", ""Tokyo""); +} +"); + + await VerifyCS.VerifyDurableTaskAnalyzerAsync(code); + } + + + static DiagnosticResult BuildInputDiagnostic() + { + return VerifyCS.Diagnostic(MatchingInputOutputTypeActivityAnalyzer.InputArgumentTypeMismatchDiagnosticId); + } + + static DiagnosticResult BuildOutputDiagnostic() + { + return VerifyCS.Diagnostic(MatchingInputOutputTypeActivityAnalyzer.OutputArgumentTypeMismatchDiagnosticId); + } +}