Skip to content

Commit

Permalink
Separate entity state from TaskEntityContext onto its own object
Browse files Browse the repository at this point in the history
  • Loading branch information
jviau committed Aug 25, 2023
1 parent ca8d7ad commit a41d258
Show file tree
Hide file tree
Showing 11 changed files with 115 additions and 212 deletions.
22 changes: 9 additions & 13 deletions src/Abstractions/Entities/TaskEntity.cs
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ public abstract class TaskEntity<TState> : ITaskEntity
/// </para>
/// <para><b>Persistence</b></para>
/// <para>
/// The contents of this property will be persisted to <see cref="TaskEntityContext.SetState(object?)"/> at the end
/// The contents of this property will be persisted to <see cref="TaskEntityState.SetState(object?)"/> at the end
/// of the operation.
/// </para>
/// <para><b>Deletion</b></para>
Expand All @@ -105,29 +105,25 @@ public abstract class TaskEntity<TState> : ITaskEntity
/// </remarks>
protected TState State { get; set; } = default!;

/// <summary>
/// Gets the entity operation.
/// </summary>
protected TaskEntityOperation Operation { get; private set; } = null!;

/// <summary>
/// Gets the entity context.
/// </summary>
protected TaskEntityContext Context => this.Operation.Context;
protected TaskEntityContext Context { get; private set; } = null!;

/// <inheritdoc/>
public ValueTask<object?> RunAsync(TaskEntityOperation operation)
{
this.Operation = Check.NotNull(operation);
object? state = operation.Context.GetState(typeof(TState));
Check.NotNull(operation);
this.Context = operation.Context;
object? state = operation.State.GetState(typeof(TState));
this.State = state is null ? this.InitializeState() : (TState)state;
if (!operation.TryDispatch(this, out object? result, out Type returnType)
&& !this.TryDispatchState(out result, out returnType))
&& !this.TryDispatchState(operation, out result, out returnType))
{
throw new NotSupportedException($"No suitable method found for entity operation '{operation}'.");
}

return TaskEntityHelpers.UnwrapAsync(this.Context, () => this.State, result, returnType);
return TaskEntityHelpers.UnwrapAsync(operation.State, () => this.State, result, returnType);
}

/// <summary>
Expand All @@ -148,7 +144,7 @@ protected virtual TState InitializeState()
return Activator.CreateInstance<TState>();
}

bool TryDispatchState(out object? result, out Type returnType)
bool TryDispatchState(TaskEntityOperation operation, out object? result, out Type returnType)
{
if (!this.AllowStateDispatch)
{
Expand All @@ -162,6 +158,6 @@ bool TryDispatchState(out object? result, out Type returnType)
throw new InvalidOperationException("Attempting to dispatch to state, but entity state is null.");
}

return this.Operation.TryDispatch(this.State, out result, out returnType);
return operation.TryDispatch(this.State, out result, out returnType);
}
}
14 changes: 0 additions & 14 deletions src/Abstractions/Entities/TaskEntityContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -51,18 +51,4 @@ public virtual void StartOrchestration(TaskName name, StartOrchestrationOptions
/// <param name="options">The options for starting the orchestration.</param>
public abstract void StartOrchestration(
TaskName name, object? input = null, StartOrchestrationOptions? options = null);

/// <summary>
/// Gets the current state for the entity this context is for. This will return <c>null</c> if no state is present,
/// regardless if <paramref name="type"/> is a value-type or not.
/// </summary>
/// <param name="type">The type to retrieve the state as.</param>
/// <returns>The entity state.</returns>
public abstract object? GetState(Type type);

/// <summary>
/// Sets the entity state. Setting of <c>null</c> will delete entity state.
/// </summary>
/// <param name="state">The state to set.</param>
public abstract void SetState(object? state);
}
32 changes: 16 additions & 16 deletions src/Abstractions/Entities/TaskEntityHelpers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,47 +13,47 @@ static class TaskEntityHelpers
/// <summary>
/// Unwraps a dispatched result for a <see cref="TaskEntityOperation"/> into a <see cref="ValueTask{Object}"/>.
/// </summary>
/// <param name="context">The entity context.</param>
/// <param name="state">Delegate to resolve new state for the entity.</param>
/// <param name="state">The entity state.</param>
/// <param name="stateCallback">Delegate to resolve new state for the entity.</param>
/// <param name="result">The result of the operation.</param>
/// <param name="resultType">The declared type of the result (may be different that actual type).</param>
/// <returns>A value task which holds the result of the operation and sets state before it completes.</returns>
public static ValueTask<object?> UnwrapAsync(
TaskEntityContext context, Func<object?> state, object? result, Type resultType)
TaskEntityState state, Func<object?> stateCallback, object? result, Type resultType)
{
// NOTE: Func<object?> is used for state so that we can lazily resolve it AFTER the operation has ran.
Check.NotNull(context);
Check.NotNull(state);
Check.NotNull(resultType);

if (typeof(Task).IsAssignableFrom(resultType))
{
// Task or Task<T>
// We assume a declared Task return type is never null.
return new(UnwrapTask(context, state, (Task)result!, resultType));
return new(UnwrapTask(state, stateCallback, (Task)result!, resultType));
}

if (resultType == typeof(ValueTask))
{
// ValueTask
// We assume a declared ValueTask return type is never null.
return UnwrapValueTask(context, state, (ValueTask)result!);
return UnwrapValueTask(state, stateCallback, (ValueTask)result!);
}

if (resultType.IsGenericType && resultType.GetGenericTypeDefinition() == typeof(ValueTask<>))
{
// ValueTask<T>
// No inheritance, have to do purely via reflection.
return UnwrapValueTaskOfT(context, state, result!, resultType);
return UnwrapValueTaskOfT(state, stateCallback, result!, resultType);
}

context.SetState(state());
state.SetState(stateCallback());
return new(result);
}

static async Task<object?> UnwrapTask(TaskEntityContext context, Func<object?> state, Task task, Type declared)
static async Task<object?> UnwrapTask(TaskEntityState state, Func<object?> callback, Task task, Type declared)
{
await task;
context.SetState(state());
state.SetState(callback());
if (declared.IsGenericType && declared.GetGenericTypeDefinition() == typeof(Task<>))
{
return declared.GetProperty("Result", BindingFlags.Public | BindingFlags.Instance).GetValue(task);
Expand All @@ -62,39 +62,39 @@ static class TaskEntityHelpers
return null;
}

static ValueTask<object?> UnwrapValueTask(TaskEntityContext context, Func<object?> state, ValueTask t)
static ValueTask<object?> UnwrapValueTask(TaskEntityState state, Func<object?> callback, ValueTask t)
{
async Task<object?> Await(ValueTask t)
{
await t;
context.SetState(state());
state.SetState(callback());
return null;
}

if (t.IsCompletedSuccessfully)
{
context.SetState(state());
state.SetState(callback());
return default;
}

return new(Await(t));
}

static ValueTask<object?> UnwrapValueTaskOfT(
TaskEntityContext context, Func<object?> state, object result, Type type)
TaskEntityState state, Func<object?> callback, object result, Type type)
{
// Result and type here must be some form of ValueTask<T>.
// TODO: can this amount of reflection be avoided?
if ((bool)type.GetProperty("IsCompletedSuccessfully").GetValue(result))
{
context.SetState(state());
state.SetState(callback());
return new(type.GetProperty("Result").GetValue(result));
}
else
{
Task t = (Task)type.GetMethod("AsTask", BindingFlags.Instance | BindingFlags.Public).Invoke(result, null);
Type taskType = typeof(Task<>).MakeGenericType(type.GetGenericArguments()[0]);
return new(UnwrapTask(context, state, t, taskType));
return new(UnwrapTask(state, callback, t, taskType));
}
}
}
5 changes: 5 additions & 0 deletions src/Abstractions/Entities/TaskEntityOperation.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@ public abstract class TaskEntityOperation
/// </summary>
public abstract TaskEntityContext Context { get; }

/// <summary>
/// Gets the state of the entity.
/// </summary>
public abstract TaskEntityState State { get; }

/// <summary>
/// Gets a value indicating whether this operation has input or not.
/// </summary>
Expand Down
7 changes: 0 additions & 7 deletions src/Abstractions/Entities/TaskEntityOperationExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ internal static bool TryDispatch(
int i = 0;
ParameterInfo? inputResolved = null;
ParameterInfo? contextResolved = null;
ParameterInfo? operationResolved = null;
foreach (ParameterInfo parameter in parameters)
{
if (parameter.ParameterType == typeof(TaskEntityContext))
Expand All @@ -56,12 +55,6 @@ internal static bool TryDispatch(
inputs[i] = operation.Context;
contextResolved = parameter;
}
else if (parameter.ParameterType == typeof(TaskEntityOperation))
{
ThrowIfDuplicateBinding(operationResolved, parameter, "operation", operation);
inputs[i] = operation;
operationResolved = parameter;
}
else
{
ThrowIfDuplicateBinding(inputResolved, parameter, "input", operation);
Expand Down
24 changes: 24 additions & 0 deletions src/Abstractions/Entities/TaskEntityState.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

namespace Microsoft.DurableTask.Entities;

/// <summary>
/// Represents the persisted state of an entity.
/// </summary>
public abstract class TaskEntityState
{
/// <summary>
/// Gets the current state for the entity this context is for. This will return <c>null</c> if no state is present,
/// regardless if <paramref name="type"/> is a value-type or not.
/// </summary>
/// <param name="type">The type to retrieve the state as.</param>
/// <returns>The entity state.</returns>
public abstract object? GetState(Type type);

/// <summary>
/// Sets the entity state. Setting of <c>null</c> will delete entity state.
/// </summary>
/// <param name="state">The state to set.</param>
public abstract void SetState(object? state);
}
75 changes: 15 additions & 60 deletions test/Abstractions.Tests/Entities/EntityTaskEntityTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -38,19 +38,19 @@ public async Task TaskOperation_Success(

[Theory]
[CombinatorialData]
public async Task Add_Success([CombinatorialRange(0, 14)] int method, bool lowercase)
public async Task Add_Success([CombinatorialRange(0, 2)] int method, bool lowercase)
{
int start = Random.Shared.Next(0, 10);
int toAdd = Random.Shared.Next(1, 10);
string opName = lowercase ? "add" : "Add";
TestEntityContext context = new(start);
TestEntityOperation operation = new($"{opName}{method}", context, toAdd);
TestEntityState state = new(start);
TestEntityOperation operation = new($"{opName}{method}", state, toAdd);
TestEntity entity = new();

object? result = await entity.RunAsync(operation);

int expected = start + toAdd;
context.GetState(typeof(int)).Should().BeOfType<int>().Which.Should().Be(expected);
state.GetState(typeof(int)).Should().BeOfType<int>().Which.Should().Be(expected);
result.Should().BeOfType<int>().Which.Should().Be(expected);
}

Expand All @@ -60,20 +60,20 @@ public async Task Get_Success([CombinatorialRange(0, 2)] int method, bool lowerc
{
int expected = Random.Shared.Next(0, 10);
string opName = lowercase ? "get" : "Get";
TestEntityContext context = new(expected);
TestEntityOperation operation = new($"{opName}{method}", context, default);
TestEntityState state = new(expected);
TestEntityOperation operation = new($"{opName}{method}", state, default);
TestEntity entity = new();

object? result = await entity.RunAsync(operation);

context.GetState(typeof(int)).Should().BeOfType<int>().Which.Should().Be(expected);
state.GetState(typeof(int)).Should().BeOfType<int>().Which.Should().Be(expected);
result.Should().BeOfType<int>().Which.Should().Be(expected);
}

[Fact]
public async Task Add_NoInput_Fails()
{
TestEntityOperation operation = new("add0", new TestEntityContext(null), default);
TestEntityOperation operation = new("add0", new TestEntityState(null), default);
TestEntity entity = new();

Func<Task<object?>> action = () => entity.RunAsync(operation).AsTask();
Expand All @@ -85,7 +85,7 @@ public async Task Add_NoInput_Fails()
[CombinatorialData]
public async Task Dispatch_AmbiguousArgs_Fails([CombinatorialRange(0, 3)] int method)
{
TestEntityOperation operation = new($"ambiguousArgs{method}", new TestEntityContext(null), 10);
TestEntityOperation operation = new($"ambiguousArgs{method}", new TestEntityState(null), 10);
TestEntity entity = new();

Func<Task<object?>> action = () => entity.RunAsync(operation).AsTask();
Expand All @@ -96,7 +96,7 @@ public async Task Dispatch_AmbiguousArgs_Fails([CombinatorialRange(0, 3)] int me
[Fact]
public async Task Dispatch_AmbiguousMatch_Fails()
{
TestEntityOperation operation = new("ambiguousMatch", new TestEntityContext(null), 10);
TestEntityOperation operation = new("ambiguousMatch", new TestEntityState(null), 10);
TestEntity entity = new();

Func<Task<object?>> action = () => entity.RunAsync(operation).AsTask();
Expand All @@ -106,7 +106,7 @@ public async Task Dispatch_AmbiguousMatch_Fails()
[Fact]
public async Task DefaultValue_NoInput_Succeeds()
{
TestEntityOperation operation = new("defaultValue", new TestEntityContext(null), default);
TestEntityOperation operation = new("defaultValue", new TestEntityState(null), default);
TestEntity entity = new();

object? result = await entity.RunAsync(operation);
Expand All @@ -117,7 +117,7 @@ public async Task DefaultValue_NoInput_Succeeds()
[Fact]
public async Task DefaultValue_Input_Succeeds()
{
TestEntityOperation operation = new("defaultValue", new TestEntityContext(null), "not-default");
TestEntityOperation operation = new("defaultValue", new TestEntityState(null), "not-default");
TestEntity entity = new();

object? result = await entity.RunAsync(operation);
Expand All @@ -131,44 +131,9 @@ class TestEntity : TaskEntity<int>
{
public static string StaticMethod() => throw new NotImplementedException();

// All possible permutations of the 3 inputs we support: object, context, operation
// 14 via Add, 2 via Get: 16 total.
public int Add0(int value) => this.Add(value, default, default);
public int Add0(int value) => this.Add(value, default);

public int Add1(int value, TaskEntityContext context) => this.Add(value, context, default);

public int Add2(int value, TaskEntityOperation operation) => this.Add(value, default, operation);

public int Add3(int value, TaskEntityContext context, TaskEntityOperation operation)
=> this.Add(value, context, operation);

public int Add4(int value, TaskEntityOperation operation, TaskEntityContext context)
=> this.Add(value, context, operation);

public int Add5(TaskEntityOperation operation) => this.Add(default, default, operation);

public int Add6(TaskEntityOperation operation, int value) => this.Add(value, default, operation);

public int Add7(TaskEntityOperation operation, TaskEntityContext context)
=> this.Add(default, context, operation);

public int Add8(TaskEntityOperation operation, int value, TaskEntityContext context)
=> this.Add(value, context, operation);

public int Add9(TaskEntityOperation operation, TaskEntityContext context, int value)
=> this.Add(value, context, operation);

public int Add10(TaskEntityContext context, int value)
=> this.Add(value, context, default);

public int Add11(TaskEntityContext context, TaskEntityOperation operation)
=> this.Add(default, context, operation);

public int Add12(TaskEntityContext context, int value, TaskEntityOperation operation)
=> this.Add(value, context, operation);

public int Add13(TaskEntityContext context, TaskEntityOperation operation, int value)
=> this.Add(value, context, operation);
public int Add1(int value, TaskEntityContext context) => this.Add(value, context);

public int Get0() => this.Get(default);

Expand Down Expand Up @@ -229,23 +194,13 @@ static async Task<string> Slow()
return sync ? new("success") : new(Slow());
}

int Add(int? value, Optional<TaskEntityContext> context, Optional<TaskEntityOperation> operation)
int Add(int? value, Optional<TaskEntityContext> context)
{
if (context.HasValue)
{
context.Value.Should().NotBeNull();
}

if (operation.HasValue)
{
operation.Value.Should().NotBeNull();
}

if (!value.HasValue && operation.TryGet(out TaskEntityOperation? op))
{
value = (int)op.GetInput(typeof(int))!;
}

value.HasValue.Should().BeTrue();
return this.State += value!.Value;
}
Expand Down
Loading

0 comments on commit a41d258

Please sign in to comment.