Skip to content

Commit

Permalink
Merge branch 'develop'
Browse files Browse the repository at this point in the history
  • Loading branch information
bfarmer67 committed Oct 23, 2024
2 parents a2b09d4 + bf0e669 commit 59bc2a6
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 22 deletions.
4 changes: 2 additions & 2 deletions src/Hyperbee.Expressions/Transformation/AwaitBinder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ internal TResult AwaitResult<TAwaitable, TAwaiter, TResult>( ref TAwaitable awai

default:
var awaiter = GetAwaiter<TAwaitable, TAwaiter>( ref awaitable, configureAwait );
return GetResultValue<TAwaiter, TResult>( ref awaiter );
return GetResult<TAwaiter, TResult>( ref awaiter );
}
}

Expand Down Expand Up @@ -127,7 +127,7 @@ internal void GetResult<TAwaiter>( ref TAwaiter awaiter )
internal static TResult GetResult<TResult>( ref ConfiguredValueTaskAwaitable<TResult>.ConfiguredValueTaskAwaiter awaiter ) => awaiter.GetResult();

[MethodImpl( MethodImplOptions.AggressiveInlining )]
internal TResult GetResultValue<TAwaiter, TResult>( ref TAwaiter awaiter )
internal TResult GetResult<TAwaiter, TResult>( ref TAwaiter awaiter )
{
if ( GetResultImplDelegate == null )
throw new InvalidOperationException( $"The {nameof( GetResultImplDelegate )} is not set for {awaiter.GetType()}." );
Expand Down
26 changes: 22 additions & 4 deletions src/Hyperbee.Expressions/Transformation/AwaitBinderFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,14 @@ internal static class AwaitBinderFactory
private static MethodInfo GetAwaiterTaskResultMethod;
private static MethodInfo GetAwaiterValueTaskMethod;
private static MethodInfo GetAwaiterValueTaskResultMethod;
private static MethodInfo GetAwaiterCustomMethod;

private static MethodInfo GetResultTaskMethod;
private static MethodInfo GetResultTaskResultMethod;
private static MethodInfo GetResultValueTaskMethod;
private static MethodInfo GetResultValueTaskResultMethod;
private static MethodInfo GetResultCustomMethod;
private static MethodInfo GetResultCustomResultMethod;

private static MethodInfo CreateGetAwaiterImplDelegateMethod;
private static MethodInfo CreateGetResultImplDelegateMethod;
Expand Down Expand Up @@ -157,14 +160,14 @@ private static AwaitBinder CreateAwaitableTypeAwaitBinder( Type awaitableType )
var awaiterResultType = awaiterType.GetGenericArguments()[0];

awaitMethod = AwaitResultMethod.MakeGenericMethod( awaitableType, awaiterType, awaiterResultType );
getAwaiterMethod = GetAwaiterTaskResultMethod.MakeGenericMethod( awaiterResultType );
getResultMethod = GetResultTaskResultMethod.MakeGenericMethod( awaiterResultType );
getAwaiterMethod = GetAwaiterCustomMethod.MakeGenericMethod( awaitableType, awaiterType );
getResultMethod = GetResultCustomResultMethod.MakeGenericMethod( awaiterType, awaiterResultType );
}
else
{
awaitMethod = AwaitMethod.MakeGenericMethod( awaitableType, awaiterType );
getAwaiterMethod = GetAwaiterTaskMethod.MakeGenericMethod();
getResultMethod = GetResultTaskMethod.MakeGenericMethod();
getAwaiterMethod = GetAwaiterCustomMethod.MakeGenericMethod( awaitableType, awaiterType );
getResultMethod = GetResultCustomMethod.MakeGenericMethod( awaiterType );
}

// Return the AwaitBinder
Expand Down Expand Up @@ -326,6 +329,11 @@ when matches( [typeof( ValueTask ).MakeByRefType(), typeof( bool )] ):
GetAwaiterValueTaskMethod = method;
break;

case nameof(AwaitBinder.GetAwaiter) // custom awaitable
when matches( [null, typeof(bool)], argCount: 2 ):
GetAwaiterCustomMethod = method;
break;

case nameof( AwaitBinder.GetResult )
when matches( [typeof( ConfiguredTaskAwaitable.ConfiguredTaskAwaiter ).MakeByRefType()] ):
GetResultTaskMethod = method;
Expand All @@ -345,6 +353,16 @@ when matches( [typeof( ConfiguredTaskAwaitable<>.ConfiguredTaskAwaiter ).MakeByR
when matches( [typeof( ConfiguredValueTaskAwaitable<>.ConfiguredValueTaskAwaiter ).MakeByRefType()], argCount: 1 ):
GetResultValueTaskResultMethod = method;
break;

case nameof(AwaitBinder.GetResult) // custom awaitable
when matches( [null], argCount: 1 ):
GetResultCustomMethod = method;
break;

case nameof(AwaitBinder.GetResult) // custom awaitable
when matches( [null], argCount: 2 ):
GetResultCustomResultMethod = method;
break;
}
}
);
Expand Down
10 changes: 5 additions & 5 deletions src/Hyperbee.Expressions/Transformation/LoweringVisitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -173,10 +173,10 @@ protected override Expression VisitLoop( LoopExpression node )
sourceState.ResultVariable = resultVariable;
joinState.ResultValue = resultVariable;

// TODO: This seems wrong, I shouldn't have to cast to GotoTransition (maybe all types have a TargetNode?)
if ( _states.TailState.Transition is not GotoTransition gotoTransition )
throw new InvalidOperationException( "Loop must have a goto transition." );

if ( _states.TailState.Transition is GotoTransition gotoTransition )
gotoTransition.TargetNode = loopTransition.BodyNode;
gotoTransition.TargetNode = loopTransition.BodyNode;

_states.ExitGroup( sourceState, loopTransition );

Expand Down Expand Up @@ -320,7 +320,7 @@ protected Expression VisitAwait( AwaitExpression node )
TargetNode = joinState,
AwaiterVariable = awaiterVariable,
ResultVariable = resultVariable,
GetResultMethod = awaitBinder.GetResultMethod
AwaitBinder = awaitBinder
};

_states.AddJumpCase( completionState.NodeLabel, joinState.NodeLabel, sourceState.StateId );
Expand All @@ -331,7 +331,7 @@ protected Expression VisitAwait( AwaitExpression node )
StateId = sourceState.StateId,
AwaiterVariable = awaiterVariable,
CompletionNode = completionState,
GetAwaiterMethod = awaitBinder.GetAwaiterMethod,
AwaitBinder = awaitBinder,
ConfigureAwait = node.ConfigureAwait
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,15 @@ public class AwaitResultTransition : Transition
public ParameterExpression AwaiterVariable { get; set; }
public ParameterExpression ResultVariable { get; set; }
public NodeExpression TargetNode { get; set; }
public MethodInfo GetResultMethod { get; set; }
public AwaitBinder AwaitBinder { get; set; }

internal override Expression Reduce( int order, NodeExpression expression, IHoistingSource resolverSource )
{
var getResultCall = GetResultMethod.IsStatic
? Call( GetResultMethod, AwaiterVariable )
: Call( AwaiterVariable, GetResultMethod );
var getResultMethod = AwaitBinder.GetResultMethod;

var getResultCall = getResultMethod.IsStatic
? Call( getResultMethod, AwaiterVariable )
: Call( Constant( AwaitBinder ), getResultMethod, AwaiterVariable );

if ( ResultVariable == null )
{
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System.Linq.Expressions;
using System.Diagnostics;
using System.Linq.Expressions;
using System.Reflection;
using static System.Linq.Expressions.Expression;

Expand All @@ -11,16 +12,18 @@ public class AwaitTransition : Transition
public Expression Target { get; set; }
public ParameterExpression AwaiterVariable { get; set; }
public NodeExpression CompletionNode { get; set; }
public MethodInfo GetAwaiterMethod { get; set; }
public AwaitBinder AwaitBinder { get; set; }
public bool ConfigureAwait { get; set; }

internal override Expression Reduce( int order, NodeExpression expression, IHoistingSource resolverSource )
{
var awaitable = Variable( Target.Type, "awaitable" );

var getAwaiterCall = GetAwaiterMethod.IsStatic
? Call( GetAwaiterMethod, awaitable, Constant( ConfigureAwait ) )
: Call( awaitable, GetAwaiterMethod, Constant( ConfigureAwait ) );
var getAwaiterMethod = AwaitBinder.GetAwaiterMethod;

var getAwaiterCall = getAwaiterMethod.IsStatic
? Call( getAwaiterMethod, awaitable, Constant( ConfigureAwait ) )
: Call( Constant( AwaitBinder ), getAwaiterMethod, awaitable, Constant( ConfigureAwait ) );

// Get AwaitUnsafeOnCompleted<TAwaiter, TStateMachine>( ref awaiter, ref state-machine )
var awaitUnsafeOnCompleted = resolverSource.BuilderField.Type
Expand Down
27 changes: 25 additions & 2 deletions test/Hyperbee.Expressions.Tests/CustomAwaiterTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

namespace Hyperbee.Expressions.Tests;

internal readonly struct LazyAwaiter<T> : INotifyCompletion
public readonly struct LazyAwaiter<T> : ICriticalNotifyCompletion //INotifyCompletion
{
private readonly Lazy<T> _lazy;

Expand All @@ -12,9 +12,10 @@ namespace Hyperbee.Expressions.Tests;
public T GetResult() => _lazy.Value;
public bool IsCompleted => true;
public void OnCompleted( Action continuation ) { }
public void UnsafeOnCompleted( Action continuation ) { }
}

internal static class LazyAwaiterExtensions
public static class LazyAwaiterExtensions
{
public static LazyAwaiter<T> GetAwaiter<T>( this Lazy<T> lazy )
{
Expand Down Expand Up @@ -44,4 +45,26 @@ public void TestCustomAwaiter_Await()

Assert.AreEqual( 42, result, "The result should be 42." );
}

[TestMethod]
public async Task TestCustomAwaiter_AsyncBlock()
{
// var lazy = new Lazy<int>( () => 42 );
// var result = await lazy;

Expression<Func<int>> valueExpression = () => 42;
var lazyConstructor = typeof(Lazy<int>).GetConstructor( [typeof(Func<int>)] );
var lazyExpression = Expression.New( lazyConstructor!, valueExpression );

var block = ExpressionExtensions.BlockAsync(
ExpressionExtensions.Await( lazyExpression, configureAwait: false )
);

var lambda = Expression.Lambda<Func<Task<int>>>( block );
var compiledLambda = lambda.Compile();

var result = await compiledLambda();

Assert.AreEqual( 42, result, "The result should be 42." );
}
}

0 comments on commit 59bc2a6

Please sign in to comment.