From 824fd708e186c6ffcaf8aa7d54c88fa3c764c059 Mon Sep 17 00:00:00 2001 From: Jacob Viau Date: Fri, 11 Aug 2023 14:30:27 -0700 Subject: [PATCH 1/9] Refactor TaskEntity to have State as a property --- src/Abstractions/Entities/TaskEntity.cs | 181 +++--------------- .../Entities/TaskEntityContext.cs | 16 +- .../Entities/TaskEntityHelpers.cs | 105 ++++++++++ .../Entities/TaskEntityOperation.cs | 9 + .../Entities/TaskEntityOperationExtensions.cs | 143 ++++++++++++++ .../Entities/TaskEntityTests.cs | 79 ++++++-- 6 files changed, 348 insertions(+), 185 deletions(-) create mode 100644 src/Abstractions/Entities/TaskEntityHelpers.cs create mode 100644 src/Abstractions/Entities/TaskEntityOperationExtensions.cs diff --git a/src/Abstractions/Entities/TaskEntity.cs b/src/Abstractions/Entities/TaskEntity.cs index ec40154c..f7336f57 100644 --- a/src/Abstractions/Entities/TaskEntity.cs +++ b/src/Abstractions/Entities/TaskEntity.cs @@ -2,7 +2,6 @@ // Licensed under the MIT License. using System.Reflection; -using System.Threading.Tasks; namespace Microsoft.DurableTask.Entities; @@ -30,6 +29,7 @@ public interface ITaskEntity /// /// An which dispatches its operations to public instance methods or properties. /// +/// The state type held by this entity. /// /// Method Binding /// @@ -75,172 +75,35 @@ public interface ITaskEntity /// completes. /// /// -public abstract class TaskEntity : ITaskEntity +public abstract class TaskEntity : ITaskEntity { - /** - * TODO: - * 1. Consider caching a compiled delegate for a given operation name. - */ - static readonly BindingFlags InstanceBindingFlags - = BindingFlags.Public | BindingFlags.Instance | BindingFlags.IgnoreCase; + /// + /// Gets or sets the state for this entity. + /// + protected TState State { get; set; } = default!; // leave null-checks to end implementation. + + /// + /// Gets the entity operation. + /// + protected TaskEntityOperation Operation { get; private set; } = null!; + + /// + /// Gets the entity context. + /// + protected TaskEntityContext Context => this.Operation.Context; /// public ValueTask RunAsync(TaskEntityOperation operation) { - Check.NotNull(operation); - if (!this.TryDispatchMethod(operation, out object? result, out Type returnType)) + this.Operation = Check.NotNull(operation); + object? state = operation.Context.GetState(typeof(TState)); + this.State = state is null ? default! : (TState)state; + if (!operation.TryDispatch(this, out object? result, out Type returnType) + && !operation.TryDispatch(this.State, out result, out returnType)) { throw new NotSupportedException($"No suitable method found for entity operation '{operation}'."); } - if (typeof(Task).IsAssignableFrom(returnType)) - { - // Task or Task - return new(AsGeneric((Task)result!, returnType)); // we assume a declared Task return type is never null. - } - - if (returnType == typeof(ValueTask)) - { - // ValueTask - return AsGeneric((ValueTask)result!); // we assume a declared ValueTask return type is never null. - } - - if (returnType.IsGenericType && returnType.GetGenericTypeDefinition() == typeof(ValueTask<>)) - { - // ValueTask - return AsGeneric(result!, returnType); // No inheritance, have to do purely via reflection. - } - - return new(result); - } - - static bool TryGetInput(ParameterInfo parameter, TaskEntityOperation operation, out object? input) - { - if (!operation.HasInput) - { - if (parameter.HasDefaultValue) - { - input = parameter.DefaultValue; - return true; - } - - input = null; - return false; - } - - input = operation.GetInput(parameter.ParameterType); - return true; - } - - static async Task AsGeneric(Task task, Type declared) - { - await task; - if (declared.IsGenericType && declared.GetGenericTypeDefinition() == typeof(Task<>)) - { - return declared.GetProperty("Result", BindingFlags.Public | BindingFlags.Instance).GetValue(task); - } - - return null; - } - - static ValueTask AsGeneric(ValueTask t) - { - static async Task Await(ValueTask t) - { - await t; - return null; - } - - if (t.IsCompletedSuccessfully) - { - return default; - } - - return new(Await(t)); - } - - static ValueTask AsGeneric(object result, Type type) - { - // result and type here must be some form of ValueTask. - if ((bool)type.GetProperty("IsCompletedSuccessfully").GetValue(result)) - { - return new(type.GetProperty("Result").GetValue(result)); - } - else - { - Task t = (Task)type.GetMethod("AsTask", BindingFlags.Instance | BindingFlags.Public) - .Invoke(result, null); - return new(t.ToGeneric()); - } - } - - bool TryDispatchMethod(TaskEntityOperation operation, out object? result, out Type returnType) - { - Type t = this.GetType(); - - // Will throw AmbiguousMatchException if more than 1 overload for the method name exists. - MethodInfo? method = t.GetMethod(operation.Name, InstanceBindingFlags); - if (method is null) - { - result = null; - returnType = typeof(void); - return false; - } - - ParameterInfo[] parameters = method.GetParameters(); - object?[] inputs = new object[parameters.Length]; - - int i = 0; - ParameterInfo? inputResolved = null; - ParameterInfo? contextResolved = null; - ParameterInfo? operationResolved = null; - foreach (ParameterInfo parameter in parameters) - { - if (parameter.ParameterType == typeof(TaskEntityContext)) - { - ThrowIfDuplicateBinding(contextResolved, parameter, "context", operation); - inputs[i] = operation.Context; - contextResolved = parameter; - } - else if (parameter.ParameterType == typeof(TaskEntityOperation)) - { - ThrowIfDuplicateBinding(operationResolved, parameter, "operation", operation); - inputs[i] = operation; - operationResolved = parameter; - } - else - { - ThrowIfDuplicateBinding(inputResolved, parameter, "input", operation); - if (TryGetInput(parameter, operation, out object? input)) - { - inputs[i] = input; - inputResolved = parameter; - } - else - { - throw new InvalidOperationException($"Error dispatching {operation} to '{method}'.\n" + - $"There was an error binding parameter '{parameter}'. The operation expected an input value, " + - "but no input was provided by the caller."); - } - } - - i++; - } - - result = method.Invoke(this, inputs); - returnType = method.ReturnType; - return true; - - static void ThrowIfDuplicateBinding( - ParameterInfo? existing, ParameterInfo parameter, string bindingConcept, TaskEntityOperation operation) - { - if (existing is not null) - { - throw new InvalidOperationException($"Error dispatching {operation} to '{parameter.Member}'.\n" + - $"Unable to bind {bindingConcept} to '{parameter}' because it has " + - $"already been bound to parameter '{existing}'. Please remove the duplicate parameter in method " + - $"'{parameter.Member}'.\nEntity operation: {operation}."); - } - } + return TaskEntityHelpers.UnwrapAsync(this.Context, () => this.State, result, returnType); } } diff --git a/src/Abstractions/Entities/TaskEntityContext.cs b/src/Abstractions/Entities/TaskEntityContext.cs index 1b551256..cf35d0c7 100644 --- a/src/Abstractions/Entities/TaskEntityContext.cs +++ b/src/Abstractions/Entities/TaskEntityContext.cs @@ -53,11 +53,15 @@ public abstract void StartOrchestration( TaskName name, object? input = null, StartOrchestrationOptions? options = null); /// - /// Deletes the state of this entity after the current operation completes. + /// Gets the current state for the entity this context is for. /// - /// - /// The state deletion only takes effect after the current operation completes. Any state changes made during the - /// current operation will be ignored in favor of the deletion. - /// - public abstract void DeleteState(); + /// The type to retrieve the state as. + /// The entity state. + public abstract object? GetState(Type type); + + /// + /// Sets the entity state. Setting of null will clear entity state. + /// + /// The state to set. + public abstract void SetState(object? state); } diff --git a/src/Abstractions/Entities/TaskEntityHelpers.cs b/src/Abstractions/Entities/TaskEntityHelpers.cs new file mode 100644 index 00000000..14a03cf6 --- /dev/null +++ b/src/Abstractions/Entities/TaskEntityHelpers.cs @@ -0,0 +1,105 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Reflection; + +namespace Microsoft.DurableTask.Entities; + +/// +/// Helpers for task entities. +/// +static class TaskEntityHelpers +{ + /// + /// Unwraps a dispatched result for a into a + /// + /// The entity context. + /// Delegate to resolve new state for the entity. + /// The result of the operation. + /// The declared type of the result (may be different that actual type). + /// A value task which holds the result of the operation and sets state before it completes. + public static ValueTask UnwrapAsync( + TaskEntityContext context, Func state, object? result, Type resultType) + { + // NOTE: Func is used for state so that we can lazily resolve it AFTER the operation has ran. + Check.NotNull(context); + Check.NotNull(resultType); + + if (typeof(Task).IsAssignableFrom(resultType)) + { + // Task or Task + // We assume a declared Task return type is never null. + return new(UnwrapTask(context, state, (Task)result!, resultType)); + } + + if (resultType == typeof(ValueTask)) + { + // ValueTask + // We assume a declared ValueTask return type is never null. + return UnwrapValueTask(context, state, (ValueTask)result!); + } + + if (resultType.IsGenericType && resultType.GetGenericTypeDefinition() == typeof(ValueTask<>)) + { + // ValueTask + // No inheritance, have to do purely via reflection. + return UnwrapValueTaskOfT(context, state, result!, resultType); + } + + context.SetState(state()); + return new(result); + } + + static async Task UnwrapTask(TaskEntityContext context, Func state, Task task, Type declared) + { + await task; + context.SetState(state()); + if (declared.IsGenericType && declared.GetGenericTypeDefinition() == typeof(Task<>)) + { + return declared.GetProperty("Result", BindingFlags.Public | BindingFlags.Instance).GetValue(task); + } + + return null; + } + + static ValueTask UnwrapValueTask(TaskEntityContext context, Func state, ValueTask t) + { + async Task Await(ValueTask t) + { + await t; + context.SetState(state()); + return null; + } + + if (t.IsCompletedSuccessfully) + { + context.SetState(state()); + return default; + } + + return new(Await(t)); + } + + static ValueTask UnwrapValueTaskOfT(TaskEntityContext context, Func state, object result, Type type) + { + async Task Await(Task t) + { + await t; + context.SetState(state()); + return null; + } + + // result and type here must be some form of ValueTask. + if ((bool)type.GetProperty("IsCompletedSuccessfully").GetValue(result)) + { + context.SetState(state()); + return new(type.GetProperty("Result").GetValue(result)); + } + else + { + Task t = (Task)type.GetMethod("AsTask", BindingFlags.Instance | BindingFlags.Public) + .Invoke(result, null); + return new(Await(t)); + } + } +} diff --git a/src/Abstractions/Entities/TaskEntityOperation.cs b/src/Abstractions/Entities/TaskEntityOperation.cs index 80f89461..0ae38796 100644 --- a/src/Abstractions/Entities/TaskEntityOperation.cs +++ b/src/Abstractions/Entities/TaskEntityOperation.cs @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. +using System.Reflection; + namespace Microsoft.DurableTask.Entities; /// @@ -8,6 +10,13 @@ namespace Microsoft.DurableTask.Entities; /// public abstract class TaskEntityOperation { + /** + * TODO: + * 1. Consider caching a compiled delegate for a given operation name. + */ + static readonly BindingFlags InstanceBindingFlags + = BindingFlags.Public | BindingFlags.Instance | BindingFlags.IgnoreCase; + /// /// Gets the name of the operation. /// diff --git a/src/Abstractions/Entities/TaskEntityOperationExtensions.cs b/src/Abstractions/Entities/TaskEntityOperationExtensions.cs new file mode 100644 index 00000000..9a0423bf --- /dev/null +++ b/src/Abstractions/Entities/TaskEntityOperationExtensions.cs @@ -0,0 +1,143 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Reflection; + +namespace Microsoft.DurableTask.Entities; + +/// +/// Extensions for . +/// +public static class TaskEntityOperationExtensions +{ + /** + * TODO: + * 1. Consider caching a compiled delegate for a given operation name. + */ + static readonly BindingFlags InstanceBindingFlags + = BindingFlags.Public | BindingFlags.Instance | BindingFlags.IgnoreCase; + + /// + /// Dispatches an operation to a type . Entity state will be resolved as + /// and then the operation dispatch will be attempted on the constructed instance. + /// + /// The type to dispatch to. + /// The operation to dispatch. + /// The result returned by the method on the operation is dispatched to. + /// + /// If no suitable method is found on for dispatch. + /// + public static ValueTask DispatchAsync(this TaskEntityOperation operation) + { + // NOTE: when dispatching this way, we do not support value types as we have no way of capturing the changed + // value due to defensive copies. + Check.NotNull(operation); + object target = operation.GetInput(typeof(T)) ?? Activator.CreateInstance(typeof(T)); + if (!operation.TryDispatch(target, out object? result, out Type returnType)) + { + throw new NotSupportedException( + $"No suitable method on {typeof(T)} found for entity operation '{operation}'."); + } + + return TaskEntityHelpers.UnwrapAsync(operation.Context, () => target, result, returnType); + } + + /// + /// Try to dispatch this operation via reflection to a method on . + /// + /// The operation that is being dispatched. + /// The target to dispatch to. + /// The result of the dispatch. + /// The declared return type of the dispatched method. + /// True if dispatch successful, false otherwise. + internal static bool TryDispatch( + this TaskEntityOperation operation, object target, out object? result, out Type returnType) + { + Check.NotNull(operation); + Check.NotNull(target); + Type t = target.GetType(); + + // Will throw AmbiguousMatchException if more than 1 overload for the method name exists. + MethodInfo? method = t.GetMethod(operation.Name, InstanceBindingFlags); + if (method is null) + { + result = null; + returnType = typeof(void); + return false; + } + + ParameterInfo[] parameters = method.GetParameters(); + object?[] inputs = new object[parameters.Length]; + + int i = 0; + ParameterInfo? inputResolved = null; + ParameterInfo? contextResolved = null; + ParameterInfo? operationResolved = null; + foreach (ParameterInfo parameter in parameters) + { + if (parameter.ParameterType == typeof(TaskEntityContext)) + { + ThrowIfDuplicateBinding(contextResolved, parameter, "context", operation); + inputs[i] = operation.Context; + contextResolved = parameter; + } + else if (parameter.ParameterType == typeof(TaskEntityOperation)) + { + ThrowIfDuplicateBinding(operationResolved, parameter, "operation", operation); + inputs[i] = operation; + operationResolved = parameter; + } + else + { + ThrowIfDuplicateBinding(inputResolved, parameter, "input", operation); + if (operation.TryGetInput(parameter, out object? input)) + { + inputs[i] = input; + inputResolved = parameter; + } + else + { + throw new InvalidOperationException($"Error dispatching {operation} to '{method}'.\n" + + $"There was an error binding parameter '{parameter}'. The operation expected an input value, " + + "but no input was provided by the caller."); + } + } + + i++; + } + + result = method.Invoke(target, inputs); + returnType = method.ReturnType; + return true; + + static void ThrowIfDuplicateBinding( + ParameterInfo? existing, ParameterInfo parameter, string bindingConcept, TaskEntityOperation operation) + { + if (existing is not null) + { + throw new InvalidOperationException($"Error dispatching {operation} to '{parameter.Member}'.\n" + + $"Unable to bind {bindingConcept} to '{parameter}' because it has " + + $"already been bound to parameter '{existing}'. Please remove the duplicate parameter in method " + + $"'{parameter.Member}'.\nEntity operation: {operation}."); + } + } + } + + static bool TryGetInput(this TaskEntityOperation operation, ParameterInfo parameter, out object? input) + { + if (!operation.HasInput) + { + if (parameter.HasDefaultValue) + { + input = parameter.DefaultValue; + return true; + } + + input = null; + return false; + } + + input = operation.GetInput(parameter.ParameterType); + return true; + } +} diff --git a/test/Abstractions.Tests/Entities/TaskEntityTests.cs b/test/Abstractions.Tests/Entities/TaskEntityTests.cs index ea5129c8..901230a4 100644 --- a/test/Abstractions.Tests/Entities/TaskEntityTests.cs +++ b/test/Abstractions.Tests/Entities/TaskEntityTests.cs @@ -14,7 +14,7 @@ public class TaskEntityTests [InlineData("staticMethod")] // public static methods are not supported. public async Task OperationNotSupported_Fails(string name) { - Operation operation = new(name, Mock.Of(), 10); + Operation operation = new(name, new Context(null), 10); TestEntity entity = new(); Func> action = () => entity.RunAsync(operation).AsTask(); @@ -28,7 +28,7 @@ public async Task TaskOperation_Success( [CombinatorialValues("TaskOp", "TaskOfStringOp", "ValueTaskOp", "ValueTaskOfStringOp")] string name, bool sync) { object? expected = name.Contains("OfString") ? "success" : null; - Operation operation = new(name, Mock.Of(), sync); + Operation operation = new(name, new Context(null), sync); TestEntity entity = new(); object? result = await entity.RunAsync(operation); @@ -43,13 +43,14 @@ public async Task Add_Success([CombinatorialRange(0, 14)] int method, bool lower int start = Random.Shared.Next(0, 10); int toAdd = Random.Shared.Next(0, 10); string opName = lowercase ? "add" : "Add"; - Operation operation = new($"{opName}{method}", Mock.Of(), toAdd); - TestEntity entity = new() { Value = start }; + Context context = new(start); + Operation operation = new($"{opName}{method}", context, toAdd); + TestEntity entity = new(); object? result = await entity.RunAsync(operation); int expected = start + toAdd; - entity.Value.Should().Be(expected); + context.GetState(typeof(int)).Should().BeOfType().Which.Should().Be(expected); result.Should().BeOfType().Which.Should().Be(expected); } @@ -59,19 +60,20 @@ public async Task Get_Success([CombinatorialRange(0, 2)] int method, bool lowerc { int expected = Random.Shared.Next(0, 10); string opName = lowercase ? "get" : "Get"; - Operation operation = new($"{opName}{method}", Mock.Of(), default); - TestEntity entity = new() { Value = expected }; + Context context = new(expected); + Operation operation = new($"{opName}{method}", context, default); + TestEntity entity = new(); object? result = await entity.RunAsync(operation); - entity.Value.Should().Be(expected); + context.GetState(typeof(int)).Should().BeOfType().Which.Should().Be(expected); result.Should().BeOfType().Which.Should().Be(expected); } [Fact] public async Task Add_NoInput_Fails() { - Operation operation = new("add0", Mock.Of(), default); + Operation operation = new("add0", new Context(null), default); TestEntity entity = new(); Func> action = () => entity.RunAsync(operation).AsTask(); @@ -83,7 +85,7 @@ public async Task Add_NoInput_Fails() [CombinatorialData] public async Task Dispatch_AmbiguousArgs_Fails([CombinatorialRange(0, 3)] int method) { - Operation operation = new($"ambiguousArgs{method}", Mock.Of(), 10); + Operation operation = new($"ambiguousArgs{method}", new Context(null), 10); TestEntity entity = new(); Func> action = () => entity.RunAsync(operation).AsTask(); @@ -94,7 +96,7 @@ public async Task Dispatch_AmbiguousArgs_Fails([CombinatorialRange(0, 3)] int me [Fact] public async Task Dispatch_AmbiguousMatch_Fails() { - Operation operation = new("ambiguousMatch", Mock.Of(), 10); + Operation operation = new("ambiguousMatch", new Context(null), 10); TestEntity entity = new(); Func> action = () => entity.RunAsync(operation).AsTask(); @@ -104,7 +106,7 @@ public async Task Dispatch_AmbiguousMatch_Fails() [Fact] public async Task DefaultValue_NoInput_Succeeds() { - Operation operation = new("defaultValue", Mock.Of(), default); + Operation operation = new("defaultValue", new Context(null), default); TestEntity entity = new(); object? result = await entity.RunAsync(operation); @@ -115,7 +117,7 @@ public async Task DefaultValue_NoInput_Succeeds() [Fact] public async Task DefaultValue_Input_Succeeds() { - Operation operation = new("defaultValue", Mock.Of(), "not-default"); + Operation operation = new("defaultValue", new Context(null), "not-default"); TestEntity entity = new(); object? result = await entity.RunAsync(operation); @@ -123,6 +125,45 @@ public async Task DefaultValue_Input_Succeeds() result.Should().BeOfType().Which.Should().Be("not-default"); } + class Context : TaskEntityContext + { + public Context(object? state) + { + this.State = state; + } + + public object? State { get; private set; } + + public override EntityInstanceId Id { get; } + + public override object? GetState(Type type) + { + return this.State switch + { + null => null, + _ when type.IsAssignableFrom(this.State.GetType()) => this.State, + _ => throw new InvalidCastException() + }; + } + + public override void SetState(object? state) + { + this.State = state; + } + + public override void SignalEntity( + EntityInstanceId id, string operationName, object? input = null, SignalEntityOptions? options = null) + { + throw new NotImplementedException(); + } + + public override void StartOrchestration( + TaskName name, object? input = null, StartOrchestrationOptions? options = null) + { + throw new NotImplementedException(); + } + } + class Operation : TaskEntityOperation { readonly Optional input; @@ -161,10 +202,8 @@ public Operation(string name, TaskEntityContext context, Optional input } } - class TestEntity : TaskEntity + class TestEntity : TaskEntity { - public int Value { get; set; } - public static string StaticMethod() => throw new NotImplementedException(); // All possible permutations of the 3 inputs we support: object, context, operation @@ -210,9 +249,9 @@ public int Add13(TaskEntityContext context, TaskEntityOperation operation, int v public int Get1(TaskEntityContext context) => this.Get(context); - public int AmbiguousMatch(TaskEntityContext context) => this.Value; + public int AmbiguousMatch(TaskEntityContext context) => this.State; - public int AmbiguousMatch(TaskEntityOperation operation) => this.Value; + public int AmbiguousMatch(TaskEntityOperation operation) => this.State; public int AmbiguousArgs0(int value, object other) => this.Add0(value); @@ -283,7 +322,7 @@ int Add(int? value, Optional context, Optional context) @@ -293,7 +332,7 @@ int Get(Optional context) context.Value.Should().NotBeNull(); } - return this.Value; + return this.State; } } } From 644b15e6c64477f5136e4a3a01d20d3396a1a285 Mon Sep 17 00:00:00 2001 From: Jacob Viau Date: Fri, 11 Aug 2023 15:32:21 -0700 Subject: [PATCH 2/9] Allow for disabling state dispatch on TaskEntity --- src/Abstractions/Entities/TaskEntity.cs | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/Abstractions/Entities/TaskEntity.cs b/src/Abstractions/Entities/TaskEntity.cs index f7336f57..32ec1363 100644 --- a/src/Abstractions/Entities/TaskEntity.cs +++ b/src/Abstractions/Entities/TaskEntity.cs @@ -77,6 +77,12 @@ public interface ITaskEntity /// public abstract class TaskEntity : ITaskEntity { + /// + /// Gets a value indicating whether dispatching operations to is allowed. State dispatch will + /// only be attempted if entity-level dispatch does not succeed. Default is true. + /// + protected bool AllowStateDispatch => true; + /// /// Gets or sets the state for this entity. /// @@ -99,7 +105,7 @@ public abstract class TaskEntity : ITaskEntity object? state = operation.Context.GetState(typeof(TState)); this.State = state is null ? default! : (TState)state; if (!operation.TryDispatch(this, out object? result, out Type returnType) - && !operation.TryDispatch(this.State, out result, out returnType)) + && (this.AllowStateDispatch && !operation.TryDispatch(this.State, out result, out returnType))) { throw new NotSupportedException($"No suitable method found for entity operation '{operation}'."); } From 36062ceda61bb3b88e936566ac1e85db0a62e961 Mon Sep 17 00:00:00 2001 From: Jacob Viau Date: Fri, 11 Aug 2023 16:41:40 -0700 Subject: [PATCH 3/9] Add more test cases, fix ValueTask result --- .../Entities/TaskEntityHelpers.cs | 19 +- .../Entities/TaskEntityOperation.cs | 9 - .../Entities/TaskEntityOperationExtensions.cs | 25 -- .../Entities/Mocks/TestEntityContext.cs | 43 ++++ .../Entities/Mocks/TestEntityOperation.cs | 54 ++++ .../Entities/TaskEntityHelpersTests.cs | 238 ++++++++++++++++++ .../Entities/TaskEntityTests.cs | 101 +------- 7 files changed, 354 insertions(+), 135 deletions(-) create mode 100644 test/Abstractions.Tests/Entities/Mocks/TestEntityContext.cs create mode 100644 test/Abstractions.Tests/Entities/Mocks/TestEntityOperation.cs create mode 100644 test/Abstractions.Tests/Entities/TaskEntityHelpersTests.cs diff --git a/src/Abstractions/Entities/TaskEntityHelpers.cs b/src/Abstractions/Entities/TaskEntityHelpers.cs index 14a03cf6..14a39505 100644 --- a/src/Abstractions/Entities/TaskEntityHelpers.cs +++ b/src/Abstractions/Entities/TaskEntityHelpers.cs @@ -80,16 +80,11 @@ static class TaskEntityHelpers return new(Await(t)); } - static ValueTask UnwrapValueTaskOfT(TaskEntityContext context, Func state, object result, Type type) + static ValueTask UnwrapValueTaskOfT( + TaskEntityContext context, Func state, object result, Type type) { - async Task Await(Task t) - { - await t; - context.SetState(state()); - return null; - } - - // result and type here must be some form of ValueTask. + // Result and type here must be some form of ValueTask. + // TODO: can this amount of reflection be avoided? if ((bool)type.GetProperty("IsCompletedSuccessfully").GetValue(result)) { context.SetState(state()); @@ -97,9 +92,9 @@ static class TaskEntityHelpers } else { - Task t = (Task)type.GetMethod("AsTask", BindingFlags.Instance | BindingFlags.Public) - .Invoke(result, null); - return new(Await(t)); + Task t = (Task)type.GetMethod("AsTask", BindingFlags.Instance | BindingFlags.Public).Invoke(result, null); + Type taskType = typeof(Task<>).MakeGenericType(type.GetGenericArguments()[0]); + return new(UnwrapTask(context, state, t, taskType)); } } } diff --git a/src/Abstractions/Entities/TaskEntityOperation.cs b/src/Abstractions/Entities/TaskEntityOperation.cs index 0ae38796..80f89461 100644 --- a/src/Abstractions/Entities/TaskEntityOperation.cs +++ b/src/Abstractions/Entities/TaskEntityOperation.cs @@ -1,8 +1,6 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. -using System.Reflection; - namespace Microsoft.DurableTask.Entities; /// @@ -10,13 +8,6 @@ namespace Microsoft.DurableTask.Entities; /// public abstract class TaskEntityOperation { - /** - * TODO: - * 1. Consider caching a compiled delegate for a given operation name. - */ - static readonly BindingFlags InstanceBindingFlags - = BindingFlags.Public | BindingFlags.Instance | BindingFlags.IgnoreCase; - /// /// Gets the name of the operation. /// diff --git a/src/Abstractions/Entities/TaskEntityOperationExtensions.cs b/src/Abstractions/Entities/TaskEntityOperationExtensions.cs index 9a0423bf..303032da 100644 --- a/src/Abstractions/Entities/TaskEntityOperationExtensions.cs +++ b/src/Abstractions/Entities/TaskEntityOperationExtensions.cs @@ -17,31 +17,6 @@ public static class TaskEntityOperationExtensions static readonly BindingFlags InstanceBindingFlags = BindingFlags.Public | BindingFlags.Instance | BindingFlags.IgnoreCase; - /// - /// Dispatches an operation to a type . Entity state will be resolved as - /// and then the operation dispatch will be attempted on the constructed instance. - /// - /// The type to dispatch to. - /// The operation to dispatch. - /// The result returned by the method on the operation is dispatched to. - /// - /// If no suitable method is found on for dispatch. - /// - public static ValueTask DispatchAsync(this TaskEntityOperation operation) - { - // NOTE: when dispatching this way, we do not support value types as we have no way of capturing the changed - // value due to defensive copies. - Check.NotNull(operation); - object target = operation.GetInput(typeof(T)) ?? Activator.CreateInstance(typeof(T)); - if (!operation.TryDispatch(target, out object? result, out Type returnType)) - { - throw new NotSupportedException( - $"No suitable method on {typeof(T)} found for entity operation '{operation}'."); - } - - return TaskEntityHelpers.UnwrapAsync(operation.Context, () => target, result, returnType); - } - /// /// Try to dispatch this operation via reflection to a method on . /// diff --git a/test/Abstractions.Tests/Entities/Mocks/TestEntityContext.cs b/test/Abstractions.Tests/Entities/Mocks/TestEntityContext.cs new file mode 100644 index 00000000..010cc70e --- /dev/null +++ b/test/Abstractions.Tests/Entities/Mocks/TestEntityContext.cs @@ -0,0 +1,43 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +namespace Microsoft.DurableTask.Entities.Tests; + +public class TestEntityContext : TaskEntityContext +{ + public TestEntityContext(object? state) + { + this.State = state; + } + + public object? State { get; private set; } + + public override EntityInstanceId Id { get; } + + public override object? GetState(Type type) + { + return this.State switch + { + null => null, + _ when type.IsAssignableFrom(this.State.GetType()) => this.State, + _ => throw new InvalidCastException() + }; + } + + public override void SetState(object? state) + { + this.State = state; + } + + public override void SignalEntity( + EntityInstanceId id, string operationName, object? input = null, SignalEntityOptions? options = null) + { + throw new NotImplementedException(); + } + + public override void StartOrchestration( + TaskName name, object? input = null, StartOrchestrationOptions? options = null) + { + throw new NotImplementedException(); + } +} diff --git a/test/Abstractions.Tests/Entities/Mocks/TestEntityOperation.cs b/test/Abstractions.Tests/Entities/Mocks/TestEntityOperation.cs new file mode 100644 index 00000000..912b96e6 --- /dev/null +++ b/test/Abstractions.Tests/Entities/Mocks/TestEntityOperation.cs @@ -0,0 +1,54 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using DotNext; + +namespace Microsoft.DurableTask.Entities.Tests; + +public class TestEntityOperation : TaskEntityOperation +{ + readonly Optional input; + + public TestEntityOperation(string name, Optional input) + : this(name, new TestEntityContext(null), input) + { + } + + public TestEntityOperation(string name, object? state, Optional input) + : this(name, new TestEntityContext(state), input) + { + } + + public TestEntityOperation(string name, TaskEntityContext context, Optional input) + { + this.Name = name; + this.Context = context; + this.input = input; + } + + public override string Name { get; } + + public override TaskEntityContext Context { get; } + + public override bool HasInput => this.input.IsPresent; + + public override object? GetInput(Type inputType) + { + if (!this.input.IsPresent) + { + throw new InvalidOperationException("No input available."); + } + + if (this.input.Value is null) + { + return null; + } + + if (!inputType.IsAssignableFrom(this.input.Value.GetType())) + { + throw new InvalidCastException("Cannot convert input type."); + } + + return this.input.Value; + } +} diff --git a/test/Abstractions.Tests/Entities/TaskEntityHelpersTests.cs b/test/Abstractions.Tests/Entities/TaskEntityHelpersTests.cs new file mode 100644 index 00000000..2521ddb3 --- /dev/null +++ b/test/Abstractions.Tests/Entities/TaskEntityHelpersTests.cs @@ -0,0 +1,238 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +namespace Microsoft.DurableTask.Entities.Tests; + +public class TaskEntityHelpersTests +{ + [Fact] + public async Task UnwrapAsync_Void() + { + int state = Random.Shared.Next(1, 10); + TestEntityContext context = new(null); + + object? result = await TaskEntityHelpers.UnwrapAsync(context, () => state, null, typeof(void)); + + result.Should().BeNull(); + context.State.Should().BeOfType().Which.Should().Be(state); + } + + [Fact] + public async Task UnwrapAsync_Object() + { + int state = Random.Shared.Next(1, 10); + int value = Random.Shared.Next(1, 10); + TestEntityContext context = new(null); + + object? result = await TaskEntityHelpers.UnwrapAsync(context, () => state, value, typeof(int)); + + result.Should().BeOfType().Which.Should().Be(value); + context.State.Should().BeOfType().Which.Should().Be(state); + } + + [Theory] + [CombinatorialData] + public async Task UnwrapAsync_Task(bool async) + { + TaskCompletionSource tcs = new(); + int state = Random.Shared.Next(1, 10); + TestEntityContext context = new(null); + + if (!async) + { + tcs.TrySetResult(); + } + + ValueTask task = TaskEntityHelpers.UnwrapAsync(context, () => state, tcs.Task, typeof(Task)); + + if (async) + { + state++; // Make sure state changes are captured + tcs.TrySetResult(); + } + + object? result = await task; + + result.Should().BeNull(); + context.State.Should().BeOfType().Which.Should().Be(state); + } + + [Theory] + [CombinatorialData] + public async Task UnwrapAsync_Task_Throws(bool async) + { + TaskCompletionSource tcs = new(); + TestEntityContext context = new(null); + + if (!async) + { + tcs.SetException(new OperationCanceledException()); + } + + Func throws = async () => await TaskEntityHelpers.UnwrapAsync(context, () => 0, tcs.Task, typeof(Task)); + + if (async) + { + tcs.SetException(new OperationCanceledException()); + } + + await throws.Should().ThrowExactlyAsync(); + } + + [Theory] + [CombinatorialData] + public async Task UnwrapAsync_TaskOfInt(bool async) + { + TaskCompletionSource tcs = new(); + + int state = Random.Shared.Next(1, 10); + int value = Random.Shared.Next(1, 10); + TestEntityContext context = new(null); + + if (!async) + { + tcs.TrySetResult(value); + } + + ValueTask task = TaskEntityHelpers.UnwrapAsync(context, () => state, tcs.Task, typeof(Task)); + + if (async) + { + state++; // Make sure state changes are captured + tcs.TrySetResult(value); + } + + object? result = await task; + + result.Should().BeOfType().Which.Should().Be(value); + context.State.Should().BeOfType().Which.Should().Be(state); + } + + [Theory] + [CombinatorialData] + public async Task UnwrapAsync_TaskOfInt_Throws(bool async) + { + TaskCompletionSource tcs = new(); + TestEntityContext context = new(null); + + if (!async) + { + tcs.SetException(new OperationCanceledException()); + } + + Func throws = async () => await TaskEntityHelpers.UnwrapAsync( + context, () => 0, tcs.Task, typeof(Task)); + + if (async) + { + tcs.SetException(new OperationCanceledException()); + } + + await throws.Should().ThrowExactlyAsync(); + } + + + [Theory] + [CombinatorialData] + public async Task UnwrapAsync_ValueTask(bool async) + { + TaskCompletionSource tcs = new(); + + int state = Random.Shared.Next(1, 10); + TestEntityContext context = new(null); + + if (!async) + { + tcs.TrySetResult(); + } + + ValueTask task = TaskEntityHelpers.UnwrapAsync( + context, () => state, new ValueTask(tcs.Task), typeof(ValueTask)); + + if (async) + { + state++; // Make sure state changes are captured + tcs.TrySetResult(); + } + + object? result = await task; + result.Should().BeNull(); + context.State.Should().BeOfType().Which.Should().Be(state); + } + + [Theory] + [CombinatorialData] + public async Task UnwrapAsync_ValueTask_Throws(bool async) + { + TaskCompletionSource tcs = new(); + TestEntityContext context = new(null); + + if (!async) + { + tcs.SetException(new OperationCanceledException()); + } + + Func throws = async () => await TaskEntityHelpers.UnwrapAsync( + context, () => 0, new ValueTask(tcs.Task), typeof(ValueTask)); + + if (async) + { + tcs.SetException(new OperationCanceledException()); + } + + await throws.Should().ThrowExactlyAsync(); + } + + + [Theory] + [CombinatorialData] + public async Task UnwrapAsync_ValueTaskOfInt(bool async) + { + TaskCompletionSource tcs = new(); + int state = Random.Shared.Next(1, 10); + int value = Random.Shared.Next(1, 10); + TestEntityContext context = new(null); + + if (!async) + { + tcs.TrySetResult(value); + } + + ValueTask task = TaskEntityHelpers.UnwrapAsync( + context, () => state, new ValueTask(tcs.Task), typeof(ValueTask)); + + if (async) + { + state++; // Make sure state changes are captured + tcs.TrySetResult(value); + } + + object? result = await task; + + result.Should().BeOfType().Which.Should().Be(value); + context.State.Should().BeOfType().Which.Should().Be(state); + } + + [Theory] + [CombinatorialData] + public async Task UnwrapAsync_ValueTaskOfInt_Throws(bool async) + { + TaskCompletionSource tcs = new(); + TestEntityContext context = new(null); + + if (!async) + { + tcs.SetException(new OperationCanceledException()); + } + + Func throws = async () => await TaskEntityHelpers.UnwrapAsync( + context, () => 0, new ValueTask(tcs.Task), typeof(ValueTask)); + + if (async) + { + tcs.SetException(new OperationCanceledException()); + } + + await throws.Should().ThrowExactlyAsync(); + } +} diff --git a/test/Abstractions.Tests/Entities/TaskEntityTests.cs b/test/Abstractions.Tests/Entities/TaskEntityTests.cs index 901230a4..ceaef970 100644 --- a/test/Abstractions.Tests/Entities/TaskEntityTests.cs +++ b/test/Abstractions.Tests/Entities/TaskEntityTests.cs @@ -6,7 +6,7 @@ namespace Microsoft.DurableTask.Entities.Tests; -public class TaskEntityTests +public partial class TaskEntityTests { [Theory] [InlineData("doesNotExist")] // method does not exist. @@ -14,7 +14,7 @@ public class TaskEntityTests [InlineData("staticMethod")] // public static methods are not supported. public async Task OperationNotSupported_Fails(string name) { - Operation operation = new(name, new Context(null), 10); + TestEntityOperation operation = new(name, 10); TestEntity entity = new(); Func> action = () => entity.RunAsync(operation).AsTask(); @@ -28,7 +28,7 @@ public async Task TaskOperation_Success( [CombinatorialValues("TaskOp", "TaskOfStringOp", "ValueTaskOp", "ValueTaskOfStringOp")] string name, bool sync) { object? expected = name.Contains("OfString") ? "success" : null; - Operation operation = new(name, new Context(null), sync); + TestEntityOperation operation = new(name, sync); TestEntity entity = new(); object? result = await entity.RunAsync(operation); @@ -43,8 +43,8 @@ public async Task Add_Success([CombinatorialRange(0, 14)] int method, bool lower int start = Random.Shared.Next(0, 10); int toAdd = Random.Shared.Next(0, 10); string opName = lowercase ? "add" : "Add"; - Context context = new(start); - Operation operation = new($"{opName}{method}", context, toAdd); + TestEntityContext context = new(start); + TestEntityOperation operation = new($"{opName}{method}", start, toAdd); TestEntity entity = new(); object? result = await entity.RunAsync(operation); @@ -60,8 +60,8 @@ public async Task Get_Success([CombinatorialRange(0, 2)] int method, bool lowerc { int expected = Random.Shared.Next(0, 10); string opName = lowercase ? "get" : "Get"; - Context context = new(expected); - Operation operation = new($"{opName}{method}", context, default); + TestEntityContext context = new(expected); + TestEntityOperation operation = new($"{opName}{method}", context, default); TestEntity entity = new(); object? result = await entity.RunAsync(operation); @@ -73,7 +73,7 @@ public async Task Get_Success([CombinatorialRange(0, 2)] int method, bool lowerc [Fact] public async Task Add_NoInput_Fails() { - Operation operation = new("add0", new Context(null), default); + TestEntityOperation operation = new("add0", new TestEntityContext(null), default); TestEntity entity = new(); Func> action = () => entity.RunAsync(operation).AsTask(); @@ -85,7 +85,7 @@ public async Task Add_NoInput_Fails() [CombinatorialData] public async Task Dispatch_AmbiguousArgs_Fails([CombinatorialRange(0, 3)] int method) { - Operation operation = new($"ambiguousArgs{method}", new Context(null), 10); + TestEntityOperation operation = new($"ambiguousArgs{method}", new TestEntityContext(null), 10); TestEntity entity = new(); Func> action = () => entity.RunAsync(operation).AsTask(); @@ -96,7 +96,7 @@ public async Task Dispatch_AmbiguousArgs_Fails([CombinatorialRange(0, 3)] int me [Fact] public async Task Dispatch_AmbiguousMatch_Fails() { - Operation operation = new("ambiguousMatch", new Context(null), 10); + TestEntityOperation operation = new("ambiguousMatch", new TestEntityContext(null), 10); TestEntity entity = new(); Func> action = () => entity.RunAsync(operation).AsTask(); @@ -106,7 +106,7 @@ public async Task Dispatch_AmbiguousMatch_Fails() [Fact] public async Task DefaultValue_NoInput_Succeeds() { - Operation operation = new("defaultValue", new Context(null), default); + TestEntityOperation operation = new("defaultValue", new TestEntityContext(null), default); TestEntity entity = new(); object? result = await entity.RunAsync(operation); @@ -117,7 +117,7 @@ public async Task DefaultValue_NoInput_Succeeds() [Fact] public async Task DefaultValue_Input_Succeeds() { - Operation operation = new("defaultValue", new Context(null), "not-default"); + TestEntityOperation operation = new("defaultValue", new TestEntityContext(null), "not-default"); TestEntity entity = new(); object? result = await entity.RunAsync(operation); @@ -125,83 +125,6 @@ public async Task DefaultValue_Input_Succeeds() result.Should().BeOfType().Which.Should().Be("not-default"); } - class Context : TaskEntityContext - { - public Context(object? state) - { - this.State = state; - } - - public object? State { get; private set; } - - public override EntityInstanceId Id { get; } - - public override object? GetState(Type type) - { - return this.State switch - { - null => null, - _ when type.IsAssignableFrom(this.State.GetType()) => this.State, - _ => throw new InvalidCastException() - }; - } - - public override void SetState(object? state) - { - this.State = state; - } - - public override void SignalEntity( - EntityInstanceId id, string operationName, object? input = null, SignalEntityOptions? options = null) - { - throw new NotImplementedException(); - } - - public override void StartOrchestration( - TaskName name, object? input = null, StartOrchestrationOptions? options = null) - { - throw new NotImplementedException(); - } - } - - class Operation : TaskEntityOperation - { - readonly Optional input; - - public Operation(string name, TaskEntityContext context, Optional input) - { - this.Name = name; - this.Context = context; - this.input = input; - } - - public override string Name { get; } - - public override TaskEntityContext Context { get; } - - public override bool HasInput => this.input.IsPresent; - - public override object? GetInput(Type inputType) - { - if (!this.input.IsPresent) - { - throw new InvalidOperationException("No input available."); - } - - if (this.input.Value is null) - { - return null; - } - - if (!inputType.IsAssignableFrom(this.input.Value.GetType())) - { - throw new InvalidCastException("Cannot convert input type."); - } - - return this.input.Value; - } - } - class TestEntity : TaskEntity { public static string StaticMethod() => throw new NotImplementedException(); From 330f90c6878236fad27fb45a7e51bbd1c3b27e8d Mon Sep 17 00:00:00 2001 From: Jacob Viau Date: Fri, 11 Aug 2023 16:51:39 -0700 Subject: [PATCH 4/9] update state comments --- src/Abstractions/Entities/TaskEntity.cs | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/Abstractions/Entities/TaskEntity.cs b/src/Abstractions/Entities/TaskEntity.cs index 32ec1363..c248cd8e 100644 --- a/src/Abstractions/Entities/TaskEntity.cs +++ b/src/Abstractions/Entities/TaskEntity.cs @@ -11,9 +11,7 @@ namespace Microsoft.DurableTask.Entities; /// /// Entity State /// -/// All entity implementations are required to be serializable by the configured . An entity -/// will have its state deserialized before executing an operation, and then the new state will be the serialized value -/// of the implementation instance post-operation. +/// The state of an entity can be retrieved and updated via . /// /// public interface ITaskEntity @@ -71,8 +69,9 @@ public interface ITaskEntity /// /// Entity State /// -/// Unchanged from . Entity state is the serialized value of the entity after an operation -/// completes. +/// Entity state will be hydrated into the property. The contents of this +/// property will be persisted to when the operation has completed. +/// Deleting entity state can be accomplished by setting to default(). /// /// public abstract class TaskEntity : ITaskEntity @@ -86,6 +85,11 @@ public abstract class TaskEntity : ITaskEntity /// /// Gets or sets the state for this entity. /// + /// + /// This will be hydrated as part of . The contents of this property + /// will be persisted to when the operation completes. Deleting + /// entity state can be accomplished by setting this to default(). + /// protected TState State { get; set; } = default!; // leave null-checks to end implementation. /// From 2b10bd5a887d02473b9e0a2bab76d5807fa6e28f Mon Sep 17 00:00:00 2001 From: Jacob Viau Date: Mon, 14 Aug 2023 10:27:04 -0700 Subject: [PATCH 5/9] Fix package ref and tests --- .../Entities/Mocks/TestEntityOperation.cs | 8 ++++---- test/Abstractions.Tests/Entities/TaskEntityTests.cs | 12 ++++++------ test/Directory.Build.targets | 2 +- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/test/Abstractions.Tests/Entities/Mocks/TestEntityOperation.cs b/test/Abstractions.Tests/Entities/Mocks/TestEntityOperation.cs index 912b96e6..63c277a4 100644 --- a/test/Abstractions.Tests/Entities/Mocks/TestEntityOperation.cs +++ b/test/Abstractions.Tests/Entities/Mocks/TestEntityOperation.cs @@ -30,21 +30,21 @@ public TestEntityOperation(string name, TaskEntityContext context, Optional this.input.IsPresent; + public override bool HasInput => this.input.HasValue; public override object? GetInput(Type inputType) { - if (!this.input.IsPresent) + if (this.input.IsUndefined) { throw new InvalidOperationException("No input available."); } - if (this.input.Value is null) + if (this.input.IsNull) { return null; } - if (!inputType.IsAssignableFrom(this.input.Value.GetType())) + if (!inputType.IsAssignableFrom(this.input.Value!.GetType())) { throw new InvalidCastException("Cannot convert input type."); } diff --git a/test/Abstractions.Tests/Entities/TaskEntityTests.cs b/test/Abstractions.Tests/Entities/TaskEntityTests.cs index ceaef970..bbfbbd38 100644 --- a/test/Abstractions.Tests/Entities/TaskEntityTests.cs +++ b/test/Abstractions.Tests/Entities/TaskEntityTests.cs @@ -41,10 +41,10 @@ public async Task TaskOperation_Success( public async Task Add_Success([CombinatorialRange(0, 14)] int method, bool lowercase) { int start = Random.Shared.Next(0, 10); - int toAdd = Random.Shared.Next(0, 10); + int toAdd = Random.Shared.Next(1, 10); string opName = lowercase ? "add" : "Add"; TestEntityContext context = new(start); - TestEntityOperation operation = new($"{opName}{method}", start, toAdd); + TestEntityOperation operation = new($"{opName}{method}", context, toAdd); TestEntity entity = new(); object? result = await entity.RunAsync(operation); @@ -229,17 +229,17 @@ static async Task Slow() int Add(int? value, Optional context, Optional operation) { - if (context.IsPresent) + if (context.HasValue) { context.Value.Should().NotBeNull(); } - if (operation.IsPresent) + if (operation.HasValue) { operation.Value.Should().NotBeNull(); } - if (!value.HasValue && operation.TryGet(out TaskEntityOperation op)) + if (!value.HasValue && operation.TryGet(out TaskEntityOperation? op)) { value = (int)op.GetInput(typeof(int))!; } @@ -250,7 +250,7 @@ int Add(int? value, Optional context, Optional context) { - if (context.IsPresent) + if (context.HasValue) { context.Value.Should().NotBeNull(); } diff --git a/test/Directory.Build.targets b/test/Directory.Build.targets index 294f47da..c0558551 100644 --- a/test/Directory.Build.targets +++ b/test/Directory.Build.targets @@ -4,7 +4,7 @@ Condition=" '$([MSBuild]::GetDirectoryNameOfFileAbove($(MSBuildThisFileDirectory)../, $(_DirectoryBuildTargetsFile)))' != '' " /> - + runtime; build; native; contentfiles; analyzers; buildtransitive From 315d7f9acbb22839cac32a3d449a78d55d6dd316 Mon Sep 17 00:00:00 2001 From: Jacob Viau Date: Mon, 21 Aug 2023 14:20:06 -0700 Subject: [PATCH 6/9] Address PR comments --- src/Abstractions/Entities/TaskEntity.cs | 35 +- .../Entities/TaskEntityContext.cs | 3 +- .../Entities/TaskEntityHelpers.cs | 2 +- src/Shared/Core/TaskExtensions.cs | 2 +- ...ntityTests.cs => EntityTaskEntityTests.cs} | 6 +- .../Entities/StateTaskEntityTests.cs | 325 ++++++++++++++++++ 6 files changed, 363 insertions(+), 10 deletions(-) rename test/Abstractions.Tests/Entities/{TaskEntityTests.cs => EntityTaskEntityTests.cs} (97%) create mode 100644 test/Abstractions.Tests/Entities/StateTaskEntityTests.cs diff --git a/src/Abstractions/Entities/TaskEntity.cs b/src/Abstractions/Entities/TaskEntity.cs index c248cd8e..f5c128b0 100644 --- a/src/Abstractions/Entities/TaskEntity.cs +++ b/src/Abstractions/Entities/TaskEntity.cs @@ -77,10 +77,10 @@ public interface ITaskEntity public abstract class TaskEntity : ITaskEntity { /// - /// Gets a value indicating whether dispatching operations to is allowed. State dispatch will - /// only be attempted if entity-level dispatch does not succeed. Default is true. + /// Gets a value indicating whether dispatching operations to is allowed. State dispatch + /// will only be attempted if entity-level dispatch does not succeed. Default is false. /// - protected bool AllowStateDispatch => true; + protected virtual bool AllowStateDispatch => false; /// /// Gets or sets the state for this entity. @@ -90,7 +90,7 @@ public abstract class TaskEntity : ITaskEntity /// will be persisted to when the operation completes. Deleting /// entity state can be accomplished by setting this to default(). /// - protected TState State { get; set; } = default!; // leave null-checks to end implementation. + protected TState? State { get; set; } /// /// Gets the entity operation. @@ -107,13 +107,36 @@ public abstract class TaskEntity : ITaskEntity { this.Operation = Check.NotNull(operation); object? state = operation.Context.GetState(typeof(TState)); - this.State = state is null ? default! : (TState)state; + this.State = state is null ? this.InitializeState() : (TState)state; if (!operation.TryDispatch(this, out object? result, out Type returnType) - && (this.AllowStateDispatch && !operation.TryDispatch(this.State, out result, out returnType))) + && !this.TryDispatchState(out result, out returnType)) { throw new NotSupportedException($"No suitable method found for entity operation '{operation}'."); } return TaskEntityHelpers.UnwrapAsync(this.Context, () => this.State, result, returnType); } + + /// + /// Initializes the entity state. + /// + /// The entity state. + protected virtual TState? InitializeState() => default; + + bool TryDispatchState(out object? result, out Type returnType) + { + if (!this.AllowStateDispatch) + { + result = null; + returnType = typeof(void); + return false; + } + + if (this.State is null) + { + throw new InvalidOperationException("Attempting to dispatch to state, but entity state is null."); + } + + return this.Operation.TryDispatch(this.State, out result, out returnType); + } } diff --git a/src/Abstractions/Entities/TaskEntityContext.cs b/src/Abstractions/Entities/TaskEntityContext.cs index cf35d0c7..8491e45b 100644 --- a/src/Abstractions/Entities/TaskEntityContext.cs +++ b/src/Abstractions/Entities/TaskEntityContext.cs @@ -53,7 +53,8 @@ public abstract void StartOrchestration( TaskName name, object? input = null, StartOrchestrationOptions? options = null); /// - /// Gets the current state for the entity this context is for. + /// Gets the current state for the entity this context is for. This will return null if no state is present, + /// regardless if is a value-type or not. /// /// The type to retrieve the state as. /// The entity state. diff --git a/src/Abstractions/Entities/TaskEntityHelpers.cs b/src/Abstractions/Entities/TaskEntityHelpers.cs index 14a39505..4437d981 100644 --- a/src/Abstractions/Entities/TaskEntityHelpers.cs +++ b/src/Abstractions/Entities/TaskEntityHelpers.cs @@ -11,7 +11,7 @@ namespace Microsoft.DurableTask.Entities; static class TaskEntityHelpers { /// - /// Unwraps a dispatched result for a into a + /// Unwraps a dispatched result for a into a . /// /// The entity context. /// Delegate to resolve new state for the entity. diff --git a/src/Shared/Core/TaskExtensions.cs b/src/Shared/Core/TaskExtensions.cs index 116c6623..69c9356f 100644 --- a/src/Shared/Core/TaskExtensions.cs +++ b/src/Shared/Core/TaskExtensions.cs @@ -24,7 +24,7 @@ static class TaskExtensions Type t = task.GetType(); if (t.IsGenericType) { - return (T)t.GetProperty("Result", BindingFlags.Public | BindingFlags.Instance).GetValue(task); + return (T)t.GetProperty("Result", BindingFlags.Public | BindingFlags.Instance)!.GetValue(task)!; } return default; diff --git a/test/Abstractions.Tests/Entities/TaskEntityTests.cs b/test/Abstractions.Tests/Entities/EntityTaskEntityTests.cs similarity index 97% rename from test/Abstractions.Tests/Entities/TaskEntityTests.cs rename to test/Abstractions.Tests/Entities/EntityTaskEntityTests.cs index bbfbbd38..ecc5541f 100644 --- a/test/Abstractions.Tests/Entities/TaskEntityTests.cs +++ b/test/Abstractions.Tests/Entities/EntityTaskEntityTests.cs @@ -6,7 +6,7 @@ namespace Microsoft.DurableTask.Entities.Tests; -public partial class TaskEntityTests +public class EntityTaskEntityTests { [Theory] [InlineData("doesNotExist")] // method does not exist. @@ -125,6 +125,8 @@ public async Task DefaultValue_Input_Succeeds() result.Should().BeOfType().Which.Should().Be("not-default"); } +#pragma warning disable CA1822 // Mark members as static +#pragma warning disable IDE0060 // Remove unused parameter class TestEntity : TaskEntity { public static string StaticMethod() => throw new NotImplementedException(); @@ -258,4 +260,6 @@ int Get(Optional context) return this.State; } } +#pragma warning restore IDE0060 // Remove unused parameter +#pragma warning restore CA1822 // Mark members as static } diff --git a/test/Abstractions.Tests/Entities/StateTaskEntityTests.cs b/test/Abstractions.Tests/Entities/StateTaskEntityTests.cs new file mode 100644 index 00000000..c65ee1bf --- /dev/null +++ b/test/Abstractions.Tests/Entities/StateTaskEntityTests.cs @@ -0,0 +1,325 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Reflection; +using DotNext; + +namespace Microsoft.DurableTask.Entities.Tests; + +public class StateTaskEntityTests +{ + [Fact] + public async Task Precedence_ChoosesEntity() + { + TestEntityOperation operation = new("Precedence", default); + TestEntity entity = new(); + + object? result = await entity.RunAsync(operation); + + result.Should().Be(20); + } + + [Fact] + public async Task StateDispatchDisallowed_Throws() + { + TestEntityOperation operation = new("add0", 10); + TestEntity entity = new(false); + + Func> action = () => entity.RunAsync(operation).AsTask(); + + await action.Should().ThrowAsync(); + } + + [Fact] + public async Task StateDispatch_NullState_Throws() + { + TestEntityOperation operation = new("add0", 10); + NullStateEntity entity = new(); + + Func> action = () => entity.RunAsync(operation).AsTask(); + + await action.Should().ThrowAsync(); + } + + [Theory] + [InlineData("doesNotExist")] // method does not exist. + [InlineData("add")] // private method, should not work. + [InlineData("staticMethod")] // public static methods are not supported. + public async Task OperationNotSupported_Fails(string name) + { + TestEntityOperation operation = new(name, 10); + TestEntity entity = new(); + + Func> action = () => entity.RunAsync(operation).AsTask(); + + await action.Should().ThrowAsync(); + } + + [Theory] + [CombinatorialData] + public async Task TaskOperation_Success( + [CombinatorialValues("TaskOp", "TaskOfStringOp", "ValueTaskOp", "ValueTaskOfStringOp")] string name, bool sync) + { + object? expected = name.Contains("OfString") ? "success" : null; + TestEntityOperation operation = new(name, sync); + TestEntity entity = new(); + + object? result = await entity.RunAsync(operation); + + result.Should().Be(expected); + } + + [Theory] + [CombinatorialData] + public async Task Add_Success([CombinatorialRange(0, 14)] int method, bool lowercase) + { + int start = Random.Shared.Next(0, 10); + int toAdd = Random.Shared.Next(1, 10); + string opName = lowercase ? "add" : "Add"; + TestEntityContext context = new(State(start)); + TestEntityOperation operation = new($"{opName}{method}", context, toAdd); + TestEntity entity = new(); + + object? result = await entity.RunAsync(operation); + + int expected = start + toAdd; + context.GetState(typeof(TestState)).Should().BeOfType().Which.Value.Should().Be(expected); + result.Should().BeOfType().Which.Should().Be(expected); + } + + [Theory] + [CombinatorialData] + public async Task Get_Success([CombinatorialRange(0, 2)] int method, bool lowercase) + { + int expected = Random.Shared.Next(0, 10); + string opName = lowercase ? "get" : "Get"; + TestEntityContext context = new(State(expected)); + TestEntityOperation operation = new($"{opName}{method}", context, default); + TestEntity entity = new(); + + object? result = await entity.RunAsync(operation); + + context.GetState(typeof(TestState)).Should().BeOfType().Which.Value.Should().Be(expected); + result.Should().BeOfType().Which.Should().Be(expected); + } + + [Fact] + public async Task Add_NoInput_Fails() + { + TestEntityOperation operation = new("add0", new TestEntityContext(null), default); + TestEntity entity = new(); + + Func> action = () => entity.RunAsync(operation).AsTask(); + + await action.Should().ThrowAsync(); + } + + [Theory] + [CombinatorialData] + public async Task Dispatch_AmbiguousArgs_Fails([CombinatorialRange(0, 3)] int method) + { + TestEntityOperation operation = new($"ambiguousArgs{method}", new TestEntityContext(null), 10); + TestEntity entity = new(); + + Func> action = () => entity.RunAsync(operation).AsTask(); + + await action.Should().ThrowAsync(); + } + + [Fact] + public async Task Dispatch_AmbiguousMatch_Fails() + { + TestEntityOperation operation = new("ambiguousMatch", new TestEntityContext(null), 10); + TestEntity entity = new(); + + Func> action = () => entity.RunAsync(operation).AsTask(); + await action.Should().ThrowAsync(); + } + + [Fact] + public async Task DefaultValue_NoInput_Succeeds() + { + TestEntityOperation operation = new("defaultValue", new TestEntityContext(null), default); + TestEntity entity = new(); + + object? result = await entity.RunAsync(operation); + + result.Should().BeOfType().Which.Should().Be("default"); + } + + [Fact] + public async Task DefaultValue_Input_Succeeds() + { + TestEntityOperation operation = new("defaultValue", new TestEntityContext(null), "not-default"); + TestEntity entity = new(); + + object? result = await entity.RunAsync(operation); + + result.Should().BeOfType().Which.Should().Be("not-default"); + } + + static TestState State(int value) => new() { Value = value }; + + class NullStateEntity : TestEntity + { + protected override TestState InitializeState() => null!; + } + + class TestEntity : TaskEntity + { + readonly bool allowStateDispatch; + + public TestEntity(bool allowStateDispatch = true) + { + this.allowStateDispatch = allowStateDispatch; + } + + protected override bool AllowStateDispatch => this.allowStateDispatch; + + public int Precedence() => this.State!.Precedence() * 2; + + protected override TestState InitializeState() => new(); + } + +#pragma warning disable CA1822 // Mark members as static +#pragma warning disable IDE0060 // Remove unused parameter + class TestState + { + public int Value { get; set; } + + public static string StaticMethod() => throw new NotImplementedException(); + + public int Precedence() => 10; + + // All possible permutations of the 3 inputs we support: object, context, operation + // 14 via Add, 2 via Get: 16 total. + public int Add0(int value) => this.Add(value, default, default); + + public int Add1(int value, TaskEntityContext context) => this.Add(value, context, default); + + public int Add2(int value, TaskEntityOperation operation) => this.Add(value, default, operation); + + public int Add3(int value, TaskEntityContext context, TaskEntityOperation operation) + => this.Add(value, context, operation); + + public int Add4(int value, TaskEntityOperation operation, TaskEntityContext context) + => this.Add(value, context, operation); + + public int Add5(TaskEntityOperation operation) => this.Add(default, default, operation); + + public int Add6(TaskEntityOperation operation, int value) => this.Add(value, default, operation); + + public int Add7(TaskEntityOperation operation, TaskEntityContext context) + => this.Add(default, context, operation); + + public int Add8(TaskEntityOperation operation, int value, TaskEntityContext context) + => this.Add(value, context, operation); + + public int Add9(TaskEntityOperation operation, TaskEntityContext context, int value) + => this.Add(value, context, operation); + + public int Add10(TaskEntityContext context, int value) + => this.Add(value, context, default); + + public int Add11(TaskEntityContext context, TaskEntityOperation operation) + => this.Add(default, context, operation); + + public int Add12(TaskEntityContext context, int value, TaskEntityOperation operation) + => this.Add(value, context, operation); + + public int Add13(TaskEntityContext context, TaskEntityOperation operation, int value) + => this.Add(value, context, operation); + + public int Get0() => this.Get(default); + + public int Get1(TaskEntityContext context) => this.Get(context); + + public int AmbiguousMatch(TaskEntityContext context) => this.Value; + + public int AmbiguousMatch(TaskEntityOperation operation) => this.Value; + + public int AmbiguousArgs0(int value, object other) => this.Add0(value); + + public int AmbiguousArgs1(int value, TaskEntityContext context, TaskEntityContext context2) => this.Add0(value); + + public int AmbiguousArgs2(int value, TaskEntityOperation operation, TaskEntityOperation operation2) + => this.Add0(value); + + public string DefaultValue(string toReturn = "default") => toReturn; + + public Task TaskOp(bool sync) + { + static async Task Slow() + { + await Task.Yield(); + } + + return sync ? Task.CompletedTask : Slow(); + } + + public Task TaskOfStringOp(bool sync) + { + static async Task Slow() + { + await Task.Yield(); + return "success"; + } + + return sync ? Task.FromResult("success") : Slow(); + } + + public ValueTask ValueTaskOp(bool sync) + { + static async Task Slow() + { + await Task.Yield(); + } + + return sync ? default : new(Slow()); + } + + public ValueTask ValueTaskOfStringOp(bool sync) + { + static async Task Slow() + { + await Task.Yield(); + return "success"; + } + + return sync ? new("success") : new(Slow()); + } + + int Add(int? value, Optional context, Optional operation) + { + if (context.HasValue) + { + context.Value.Should().NotBeNull(); + } + + if (operation.HasValue) + { + operation.Value.Should().NotBeNull(); + } + + if (!value.HasValue && operation.TryGet(out TaskEntityOperation? op)) + { + value = (int)op.GetInput(typeof(int))!; + } + + value.HasValue.Should().BeTrue(); + return this.Value += value!.Value; + } + + int Get(Optional context) + { + if (context.HasValue) + { + context.Value.Should().NotBeNull(); + } + + return this.Value; + } + } +#pragma warning restore IDE0060 // Remove unused parameter +#pragma warning restore CA1822 // Mark members as static +} From 7a791028373af924526c6977e62982a586b7c0ad Mon Sep 17 00:00:00 2001 From: Jacob Viau Date: Mon, 21 Aug 2023 14:35:03 -0700 Subject: [PATCH 7/9] Default to non-null state --- src/Abstractions/Entities/TaskEntity.cs | 5 +++-- test/Abstractions.Tests/Entities/StateTaskEntityTests.cs | 2 -- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/Abstractions/Entities/TaskEntity.cs b/src/Abstractions/Entities/TaskEntity.cs index f5c128b0..aefae3aa 100644 --- a/src/Abstractions/Entities/TaskEntity.cs +++ b/src/Abstractions/Entities/TaskEntity.cs @@ -118,10 +118,11 @@ public abstract class TaskEntity : ITaskEntity } /// - /// Initializes the entity state. + /// Initializes the entity state. This is only called when there is no current state for this entity. /// /// The entity state. - protected virtual TState? InitializeState() => default; + /// The default implementation uses . + protected virtual TState? InitializeState() => Activator.CreateInstance(); bool TryDispatchState(out object? result, out Type returnType) { diff --git a/test/Abstractions.Tests/Entities/StateTaskEntityTests.cs b/test/Abstractions.Tests/Entities/StateTaskEntityTests.cs index c65ee1bf..3227a5ae 100644 --- a/test/Abstractions.Tests/Entities/StateTaskEntityTests.cs +++ b/test/Abstractions.Tests/Entities/StateTaskEntityTests.cs @@ -177,8 +177,6 @@ public TestEntity(bool allowStateDispatch = true) protected override bool AllowStateDispatch => this.allowStateDispatch; public int Precedence() => this.State!.Precedence() * 2; - - protected override TestState InitializeState() => new(); } #pragma warning disable CA1822 // Mark members as static From a76657140afa54c6780ded7c1f09c71e1556cabc Mon Sep 17 00:00:00 2001 From: Jacob Viau Date: Wed, 23 Aug 2023 12:51:03 -0700 Subject: [PATCH 8/9] TState? -> TState. Updated comments and InitializeState --- src/Abstractions/Entities/TaskEntity.cs | 44 +++++++++++++++++++------ 1 file changed, 34 insertions(+), 10 deletions(-) diff --git a/src/Abstractions/Entities/TaskEntity.cs b/src/Abstractions/Entities/TaskEntity.cs index aefae3aa..891bb385 100644 --- a/src/Abstractions/Entities/TaskEntity.cs +++ b/src/Abstractions/Entities/TaskEntity.cs @@ -69,16 +69,16 @@ public interface ITaskEntity /// /// Entity State /// -/// Entity state will be hydrated into the property. The contents of this -/// property will be persisted to when the operation has completed. -/// Deleting entity state can be accomplished by setting to default(). +/// Entity state will be hydrated into the property. See for +/// more information. /// /// public abstract class TaskEntity : ITaskEntity { /// /// Gets a value indicating whether dispatching operations to is allowed. State dispatch - /// will only be attempted if entity-level dispatch does not succeed. Default is false. + /// will only be attempted if entity-level dispatch does not succeed. Default is false. Dispatching to state + /// follows the same rules as dispatching to this entity. /// protected virtual bool AllowStateDispatch => false; @@ -86,11 +86,24 @@ public abstract class TaskEntity : ITaskEntity /// Gets or sets the state for this entity. /// /// - /// This will be hydrated as part of . The contents of this property - /// will be persisted to when the operation completes. Deleting - /// entity state can be accomplished by setting this to default(). + /// Initialization + /// + /// This will be hydrated as part of . will + /// be called when state is null at the start of an operation only. + /// + /// Persistence + /// + /// The contents of this property will be persisted to at the end + /// of the operation. + /// + /// Deletion + /// + /// Deleting entity state is possible by setting this to null. Setting to default of a value-type will + /// not delete state. This means deleting entity state is only possible for reference types or using ? + /// on a value-type (ie: TaskEntity<int?>). + /// /// - protected TState? State { get; set; } + protected TState State { get; set; } = default!; /// /// Gets the entity operation. @@ -121,8 +134,19 @@ public abstract class TaskEntity : ITaskEntity /// Initializes the entity state. This is only called when there is no current state for this entity. /// /// The entity state. - /// The default implementation uses . - protected virtual TState? InitializeState() => Activator.CreateInstance(); + /// The default implementation uses . + protected virtual TState InitializeState() + { + if (Nullable.GetUnderlyingType(typeof(TState)) is Type t) + { + // Activator.CreateInstance>() returns null. To avoid this, we will instantiate via underlying + // type if it is Nullable. This keeps the experience consistent between value and reference type. If an + // implementation wants null, they must override this method and explicitly provide null. + return (TState)Activator.CreateInstance(t); + } + + return Activator.CreateInstance(); + } bool TryDispatchState(out object? result, out Type returnType) { From 35bd3c2b7c2fdec95ea7b608fc8a1d4cc0f180bb Mon Sep 17 00:00:00 2001 From: Jacob Viau Date: Fri, 25 Aug 2023 09:52:59 -0700 Subject: [PATCH 9/9] Reword comments --- src/Abstractions/Entities/TaskEntity.cs | 2 +- src/Abstractions/Entities/TaskEntityContext.cs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Abstractions/Entities/TaskEntity.cs b/src/Abstractions/Entities/TaskEntity.cs index 891bb385..837df7ab 100644 --- a/src/Abstractions/Entities/TaskEntity.cs +++ b/src/Abstractions/Entities/TaskEntity.cs @@ -99,7 +99,7 @@ public abstract class TaskEntity : ITaskEntity /// Deletion /// /// Deleting entity state is possible by setting this to null. Setting to default of a value-type will - /// not delete state. This means deleting entity state is only possible for reference types or using ? + /// not perform a delete. This means deleting entity state is only possible for reference types or using ? /// on a value-type (ie: TaskEntity<int?>). /// /// diff --git a/src/Abstractions/Entities/TaskEntityContext.cs b/src/Abstractions/Entities/TaskEntityContext.cs index 8491e45b..8ce86fb9 100644 --- a/src/Abstractions/Entities/TaskEntityContext.cs +++ b/src/Abstractions/Entities/TaskEntityContext.cs @@ -61,7 +61,7 @@ public abstract void StartOrchestration( public abstract object? GetState(Type type); /// - /// Sets the entity state. Setting of null will clear entity state. + /// Sets the entity state. Setting of null will delete entity state. /// /// The state to set. public abstract void SetState(object? state);