Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplifying ITaskOrchestration detection logic #305

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 0 additions & 7 deletions src/Analyzers/KnownTypeSymbols.Durable.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ namespace Microsoft.DurableTask.Analyzers;
public sealed partial class KnownTypeSymbols
{
INamedTypeSymbol? taskOrchestratorInterface;
INamedTypeSymbol? taskOrchestratorBaseClass;
INamedTypeSymbol? durableTaskRegistry;
INamedTypeSymbol? taskOrchestrationContext;
INamedTypeSymbol? durableTaskClient;
Expand All @@ -24,11 +23,6 @@ public sealed partial class KnownTypeSymbols
/// </summary>
public INamedTypeSymbol? TaskOrchestratorInterface => this.GetOrResolveFullyQualifiedType("Microsoft.DurableTask.ITaskOrchestrator", ref this.taskOrchestratorInterface);

/// <summary>
/// Gets a TaskOrchestrator type symbol.
/// </summary>
public INamedTypeSymbol? TaskOrchestratorBaseClass => this.GetOrResolveFullyQualifiedType("Microsoft.DurableTask.TaskOrchestrator`2", ref this.taskOrchestratorBaseClass);

/// <summary>
/// Gets a DurableTaskRegistry type symbol.
/// </summary>
Expand All @@ -39,7 +33,6 @@ public sealed partial class KnownTypeSymbols
/// </summary>
public INamedTypeSymbol? TaskOrchestrationContext => this.GetOrResolveFullyQualifiedType("Microsoft.DurableTask.TaskOrchestrationContext", ref this.taskOrchestrationContext);


/// <summary>
/// Gets a DurableTaskClient type symbol.
/// </summary>
Expand Down
84 changes: 13 additions & 71 deletions src/Analyzers/Orchestration/OrchestrationAnalyzer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,20 +28,13 @@ public override void Initialize(AnalysisContext context)
KnownTypeSymbols knownSymbols = new(context.Compilation);

if (knownSymbols.FunctionOrchestrationAttribute == null || knownSymbols.FunctionNameAttribute == null ||
knownSymbols.TaskOrchestratorInterface == null || knownSymbols.TaskOrchestratorBaseClass == null ||
knownSymbols.TaskOrchestratorInterface == null ||
knownSymbols.DurableTaskRegistry == null)
{
// symbols not available in this compilation, skip analysis
return;
}

IMethodSymbol? runAsyncTaskOrchestratorInterface = knownSymbols.TaskOrchestratorInterface.GetMembers("RunAsync").OfType<IMethodSymbol>().FirstOrDefault();
IMethodSymbol? runAsyncTaskOrchestratorBase = knownSymbols.TaskOrchestratorBaseClass.GetMembers("RunAsync").OfType<IMethodSymbol>().FirstOrDefault();
if (runAsyncTaskOrchestratorInterface == null || runAsyncTaskOrchestratorBase == null)
{
return;
}

TOrchestrationVisitor visitor = new();
if (!visitor.Initialize(context.Compilation, knownSymbols))
{
Expand Down Expand Up @@ -75,7 +68,7 @@ public override void Initialize(AnalysisContext context)
},
SyntaxKind.MethodDeclaration);

// look for TaskOrchestrator`2 Orchestrations
// look for ITaskOrchestrator/TaskOrchestrator`2 Orchestrations
context.RegisterSyntaxNodeAction(
ctx =>
{
Expand All @@ -86,57 +79,24 @@ public override void Initialize(AnalysisContext context)
return;
}

if (!classSymbol.BaseTypeIsConstructedFrom(knownSymbols.TaskOrchestratorBaseClass))
bool implementsITaskOrchestrator = classSymbol.AllInterfaces.Any(i => i.Equals(knownSymbols.TaskOrchestratorInterface, SymbolEqualityComparer.Default));
if (!implementsITaskOrchestrator)
{
return;
}

// Get the method that overrides TaskOrchestrator.RunAsync
IMethodSymbol? methodSymbol = classSymbol.GetOverridenMethod(runAsyncTaskOrchestratorBase);
if (methodSymbol == null)
{
return;
}
IEnumerable<IMethodSymbol> orchestrationMethods = classSymbol.GetMembers().OfType<IMethodSymbol>()
.Where(m => m.Parameters.Any(p => p.Type.Equals(knownSymbols.TaskOrchestrationContext, SymbolEqualityComparer.Default)));

string functionName = classSymbol.Name;

IEnumerable<MethodDeclarationSyntax> methodSyntaxes = methodSymbol.GetSyntaxNodes();
foreach (MethodDeclarationSyntax rootMethodSyntax in methodSyntaxes)
foreach (IMethodSymbol? methodSymbol in orchestrationMethods)
{
visitor.VisitTaskOrchestrator(ctx.SemanticModel, rootMethodSyntax, methodSymbol, functionName, ctx.ReportDiagnostic);
}
},
SyntaxKind.ClassDeclaration);

// look for ITaskOrchestrator Orchestrations
context.RegisterSyntaxNodeAction(
ctx =>
{
ctx.CancellationToken.ThrowIfCancellationRequested();

if (ctx.ContainingSymbol is not INamedTypeSymbol classSymbol)
{
return;
}

// Gets the method that implements ITaskOrchestrator.RunAsync
if (classSymbol.FindImplementationForInterfaceMember(runAsyncTaskOrchestratorInterface) is not IMethodSymbol methodSymbol)
{
return;
}

// Skip if the found method is implemented in TaskOrchestrator<TInput, TOutput>
if (methodSymbol.ContainingType.ConstructedFrom.Equals(knownSymbols.TaskOrchestratorBaseClass, SymbolEqualityComparer.Default))
{
return;
}

string functionName = classSymbol.Name;

IEnumerable<MethodDeclarationSyntax> methodSyntaxes = methodSymbol.GetSyntaxNodes();
foreach (MethodDeclarationSyntax rootMethodSyntax in methodSyntaxes)
{
visitor.VisitITaskOrchestrator(ctx.SemanticModel, rootMethodSyntax, methodSymbol, functionName, ctx.ReportDiagnostic);
IEnumerable<MethodDeclarationSyntax> methodSyntaxes = methodSymbol.GetSyntaxNodes();
foreach (MethodDeclarationSyntax rootMethodSyntax in methodSyntaxes)
{
visitor.VisitTaskOrchestrator(ctx.SemanticModel, rootMethodSyntax, methodSymbol, functionName, ctx.ReportDiagnostic);
}
}
},
SyntaxKind.ClassDeclaration);
Expand Down Expand Up @@ -256,7 +216,7 @@ public virtual void VisitDurableFunction(SemanticModel semanticModel, MethodDecl
}

/// <summary>
/// Visits a TaskOrchestrator&lt;T1,T2&gt; orchestration.
/// Visits a strongly typed Task Orchestrator that implements an ITaskOrchestrator orchestration.
/// </summary>
/// <param name="semanticModel">Semantic Model.</param>
/// <param name="methodSyntax">Method Syntax Node.</param>
Expand All @@ -267,18 +227,6 @@ public virtual void VisitTaskOrchestrator(SemanticModel semanticModel, MethodDec
{
}

/// <summary>
/// Visits an ITaskOrchestrator orchestration.
/// </summary>
/// <param name="semanticModel">Semantic Model.</param>
/// <param name="methodSyntax">Method Syntax Node.</param>
/// <param name="methodSymbol">Method Symbol.</param>
/// <param name="orchestrationName">Class name.</param>
/// <param name="reportDiagnostic">Function that can be used to report diagnostics.</param>
public virtual void VisitITaskOrchestrator(SemanticModel semanticModel, MethodDeclarationSyntax methodSyntax, IMethodSymbol methodSymbol, string orchestrationName, Action<Diagnostic> reportDiagnostic)
{
}

/// <summary>
/// Visits an Orchestrator Func orchestration.
/// </summary>
Expand Down Expand Up @@ -311,12 +259,6 @@ public override void VisitTaskOrchestrator(SemanticModel semanticModel, MethodDe
this.FindInvokedMethods(semanticModel, methodSyntax, methodSymbol, orchestrationName, reportDiagnostic);
}

/// <inheritdoc/>
public override void VisitITaskOrchestrator(SemanticModel semanticModel, MethodDeclarationSyntax methodSyntax, IMethodSymbol methodSymbol, string orchestrationName, Action<Diagnostic> reportDiagnostic)
{
this.FindInvokedMethods(semanticModel, methodSyntax, methodSymbol, orchestrationName, reportDiagnostic);
}

/// <inheritdoc/>
public override void VisitFuncOrchestrator(SemanticModel semanticModel, SyntaxNode methodSyntax, IMethodSymbol methodSymbol, string orchestrationName, Action<Diagnostic> reportDiagnostic)
{
Expand Down
Loading