From bf0e669a0a7f09155bc4004554cc78d2c46f7c55 Mon Sep 17 00:00:00 2001 From: Brenton Farmer Date: Wed, 23 Oct 2024 10:35:22 -0700 Subject: [PATCH] fix custom awaiters --- .../Transformation/AwaitBinder.cs | 4 +-- .../Transformation/AwaitBinderFactory.cs | 26 +++++++++++++++--- .../Transformation/LoweringVisitor.cs | 10 +++---- .../Transitions/AwaitResultTransition.cs | 10 ++++--- .../Transitions/AwaitTransition.cs | 13 +++++---- .../CustomAwaiterTests.cs | 27 +++++++++++++++++-- 6 files changed, 68 insertions(+), 22 deletions(-) diff --git a/src/Hyperbee.Expressions/Transformation/AwaitBinder.cs b/src/Hyperbee.Expressions/Transformation/AwaitBinder.cs index 269a4b2..cbd2e70 100644 --- a/src/Hyperbee.Expressions/Transformation/AwaitBinder.cs +++ b/src/Hyperbee.Expressions/Transformation/AwaitBinder.cs @@ -62,7 +62,7 @@ internal TResult AwaitResult( ref TAwaitable awai default: var awaiter = GetAwaiter( ref awaitable, configureAwait ); - return GetResultValue( ref awaiter ); + return GetResult( ref awaiter ); } } @@ -127,7 +127,7 @@ internal void GetResult( ref TAwaiter awaiter ) internal static TResult GetResult( ref ConfiguredValueTaskAwaitable.ConfiguredValueTaskAwaiter awaiter ) => awaiter.GetResult(); [MethodImpl( MethodImplOptions.AggressiveInlining )] - internal TResult GetResultValue( ref TAwaiter awaiter ) + internal TResult GetResult( ref TAwaiter awaiter ) { if ( GetResultImplDelegate == null ) throw new InvalidOperationException( $"The {nameof( GetResultImplDelegate )} is not set for {awaiter.GetType()}." ); diff --git a/src/Hyperbee.Expressions/Transformation/AwaitBinderFactory.cs b/src/Hyperbee.Expressions/Transformation/AwaitBinderFactory.cs index 49a0712..227f0f8 100644 --- a/src/Hyperbee.Expressions/Transformation/AwaitBinderFactory.cs +++ b/src/Hyperbee.Expressions/Transformation/AwaitBinderFactory.cs @@ -22,11 +22,14 @@ internal static class AwaitBinderFactory private static MethodInfo GetAwaiterTaskResultMethod; private static MethodInfo GetAwaiterValueTaskMethod; private static MethodInfo GetAwaiterValueTaskResultMethod; + private static MethodInfo GetAwaiterCustomMethod; private static MethodInfo GetResultTaskMethod; private static MethodInfo GetResultTaskResultMethod; private static MethodInfo GetResultValueTaskMethod; private static MethodInfo GetResultValueTaskResultMethod; + private static MethodInfo GetResultCustomMethod; + private static MethodInfo GetResultCustomResultMethod; private static MethodInfo CreateGetAwaiterImplDelegateMethod; private static MethodInfo CreateGetResultImplDelegateMethod; @@ -157,14 +160,14 @@ private static AwaitBinder CreateAwaitableTypeAwaitBinder( Type awaitableType ) var awaiterResultType = awaiterType.GetGenericArguments()[0]; awaitMethod = AwaitResultMethod.MakeGenericMethod( awaitableType, awaiterType, awaiterResultType ); - getAwaiterMethod = GetAwaiterTaskResultMethod.MakeGenericMethod( awaiterResultType ); - getResultMethod = GetResultTaskResultMethod.MakeGenericMethod( awaiterResultType ); + getAwaiterMethod = GetAwaiterCustomMethod.MakeGenericMethod( awaitableType, awaiterType ); + getResultMethod = GetResultCustomResultMethod.MakeGenericMethod( awaiterType, awaiterResultType ); } else { awaitMethod = AwaitMethod.MakeGenericMethod( awaitableType, awaiterType ); - getAwaiterMethod = GetAwaiterTaskMethod.MakeGenericMethod(); - getResultMethod = GetResultTaskMethod.MakeGenericMethod(); + getAwaiterMethod = GetAwaiterCustomMethod.MakeGenericMethod( awaitableType, awaiterType ); + getResultMethod = GetResultCustomMethod.MakeGenericMethod( awaiterType ); } // Return the AwaitBinder @@ -326,6 +329,11 @@ when matches( [typeof( ValueTask ).MakeByRefType(), typeof( bool )] ): GetAwaiterValueTaskMethod = method; break; + case nameof(AwaitBinder.GetAwaiter) // custom awaitable + when matches( [null, typeof(bool)], argCount: 2 ): + GetAwaiterCustomMethod = method; + break; + case nameof( AwaitBinder.GetResult ) when matches( [typeof( ConfiguredTaskAwaitable.ConfiguredTaskAwaiter ).MakeByRefType()] ): GetResultTaskMethod = method; @@ -345,6 +353,16 @@ when matches( [typeof( ConfiguredTaskAwaitable<>.ConfiguredTaskAwaiter ).MakeByR when matches( [typeof( ConfiguredValueTaskAwaitable<>.ConfiguredValueTaskAwaiter ).MakeByRefType()], argCount: 1 ): GetResultValueTaskResultMethod = method; break; + + case nameof(AwaitBinder.GetResult) // custom awaitable + when matches( [null], argCount: 1 ): + GetResultCustomMethod = method; + break; + + case nameof(AwaitBinder.GetResult) // custom awaitable + when matches( [null], argCount: 2 ): + GetResultCustomResultMethod = method; + break; } } ); diff --git a/src/Hyperbee.Expressions/Transformation/LoweringVisitor.cs b/src/Hyperbee.Expressions/Transformation/LoweringVisitor.cs index 9f0497f..881e511 100644 --- a/src/Hyperbee.Expressions/Transformation/LoweringVisitor.cs +++ b/src/Hyperbee.Expressions/Transformation/LoweringVisitor.cs @@ -173,10 +173,10 @@ protected override Expression VisitLoop( LoopExpression node ) sourceState.ResultVariable = resultVariable; joinState.ResultValue = resultVariable; - // TODO: This seems wrong, I shouldn't have to cast to GotoTransition (maybe all types have a TargetNode?) + if ( _states.TailState.Transition is not GotoTransition gotoTransition ) + throw new InvalidOperationException( "Loop must have a goto transition." ); - if ( _states.TailState.Transition is GotoTransition gotoTransition ) - gotoTransition.TargetNode = loopTransition.BodyNode; + gotoTransition.TargetNode = loopTransition.BodyNode; _states.ExitGroup( sourceState, loopTransition ); @@ -320,7 +320,7 @@ protected Expression VisitAwait( AwaitExpression node ) TargetNode = joinState, AwaiterVariable = awaiterVariable, ResultVariable = resultVariable, - GetResultMethod = awaitBinder.GetResultMethod + AwaitBinder = awaitBinder }; _states.AddJumpCase( completionState.NodeLabel, joinState.NodeLabel, sourceState.StateId ); @@ -331,7 +331,7 @@ protected Expression VisitAwait( AwaitExpression node ) StateId = sourceState.StateId, AwaiterVariable = awaiterVariable, CompletionNode = completionState, - GetAwaiterMethod = awaitBinder.GetAwaiterMethod, + AwaitBinder = awaitBinder, ConfigureAwait = node.ConfigureAwait }; diff --git a/src/Hyperbee.Expressions/Transformation/Transitions/AwaitResultTransition.cs b/src/Hyperbee.Expressions/Transformation/Transitions/AwaitResultTransition.cs index 7876795..550c6b0 100644 --- a/src/Hyperbee.Expressions/Transformation/Transitions/AwaitResultTransition.cs +++ b/src/Hyperbee.Expressions/Transformation/Transitions/AwaitResultTransition.cs @@ -9,13 +9,15 @@ public class AwaitResultTransition : Transition public ParameterExpression AwaiterVariable { get; set; } public ParameterExpression ResultVariable { get; set; } public NodeExpression TargetNode { get; set; } - public MethodInfo GetResultMethod { get; set; } + public AwaitBinder AwaitBinder { get; set; } internal override Expression Reduce( int order, NodeExpression expression, IHoistingSource resolverSource ) { - var getResultCall = GetResultMethod.IsStatic - ? Call( GetResultMethod, AwaiterVariable ) - : Call( AwaiterVariable, GetResultMethod ); + var getResultMethod = AwaitBinder.GetResultMethod; + + var getResultCall = getResultMethod.IsStatic + ? Call( getResultMethod, AwaiterVariable ) + : Call( Constant( AwaitBinder ), getResultMethod, AwaiterVariable ); if ( ResultVariable == null ) { diff --git a/src/Hyperbee.Expressions/Transformation/Transitions/AwaitTransition.cs b/src/Hyperbee.Expressions/Transformation/Transitions/AwaitTransition.cs index 606cfce..abed484 100644 --- a/src/Hyperbee.Expressions/Transformation/Transitions/AwaitTransition.cs +++ b/src/Hyperbee.Expressions/Transformation/Transitions/AwaitTransition.cs @@ -1,4 +1,5 @@ -using System.Linq.Expressions; +using System.Diagnostics; +using System.Linq.Expressions; using System.Reflection; using static System.Linq.Expressions.Expression; @@ -11,16 +12,18 @@ public class AwaitTransition : Transition public Expression Target { get; set; } public ParameterExpression AwaiterVariable { get; set; } public NodeExpression CompletionNode { get; set; } - public MethodInfo GetAwaiterMethod { get; set; } + public AwaitBinder AwaitBinder { get; set; } public bool ConfigureAwait { get; set; } internal override Expression Reduce( int order, NodeExpression expression, IHoistingSource resolverSource ) { var awaitable = Variable( Target.Type, "awaitable" ); - var getAwaiterCall = GetAwaiterMethod.IsStatic - ? Call( GetAwaiterMethod, awaitable, Constant( ConfigureAwait ) ) - : Call( awaitable, GetAwaiterMethod, Constant( ConfigureAwait ) ); + var getAwaiterMethod = AwaitBinder.GetAwaiterMethod; + + var getAwaiterCall = getAwaiterMethod.IsStatic + ? Call( getAwaiterMethod, awaitable, Constant( ConfigureAwait ) ) + : Call( Constant( AwaitBinder ), getAwaiterMethod, awaitable, Constant( ConfigureAwait ) ); // Get AwaitUnsafeOnCompleted( ref awaiter, ref state-machine ) var awaitUnsafeOnCompleted = resolverSource.BuilderField.Type diff --git a/test/Hyperbee.Expressions.Tests/CustomAwaiterTests.cs b/test/Hyperbee.Expressions.Tests/CustomAwaiterTests.cs index 6b4bbad..2a4986e 100644 --- a/test/Hyperbee.Expressions.Tests/CustomAwaiterTests.cs +++ b/test/Hyperbee.Expressions.Tests/CustomAwaiterTests.cs @@ -3,7 +3,7 @@ namespace Hyperbee.Expressions.Tests; -internal readonly struct LazyAwaiter : INotifyCompletion +public readonly struct LazyAwaiter : ICriticalNotifyCompletion //INotifyCompletion { private readonly Lazy _lazy; @@ -12,9 +12,10 @@ namespace Hyperbee.Expressions.Tests; public T GetResult() => _lazy.Value; public bool IsCompleted => true; public void OnCompleted( Action continuation ) { } + public void UnsafeOnCompleted( Action continuation ) { } } -internal static class LazyAwaiterExtensions +public static class LazyAwaiterExtensions { public static LazyAwaiter GetAwaiter( this Lazy lazy ) { @@ -44,4 +45,26 @@ public void TestCustomAwaiter_Await() Assert.AreEqual( 42, result, "The result should be 42." ); } + + [TestMethod] + public async Task TestCustomAwaiter_AsyncBlock() + { + // var lazy = new Lazy( () => 42 ); + // var result = await lazy; + + Expression> valueExpression = () => 42; + var lazyConstructor = typeof(Lazy).GetConstructor( [typeof(Func)] ); + var lazyExpression = Expression.New( lazyConstructor!, valueExpression ); + + var block = ExpressionExtensions.BlockAsync( + ExpressionExtensions.Await( lazyExpression, configureAwait: false ) + ); + + var lambda = Expression.Lambda>>( block ); + var compiledLambda = lambda.Compile(); + + var result = await compiledLambda(); + + Assert.AreEqual( 42, result, "The result should be 42." ); + } }