diff --git a/src/Hyperbee.Expressions/AsyncBlockExpression.cs b/src/Hyperbee.Expressions/AsyncBlockExpression.cs index 37cd4b7..c79be04 100644 --- a/src/Hyperbee.Expressions/AsyncBlockExpression.cs +++ b/src/Hyperbee.Expressions/AsyncBlockExpression.cs @@ -2,6 +2,7 @@ using System.Diagnostics; using System.Linq.Expressions; using Hyperbee.Expressions.CompilerServices; +using Hyperbee.Expressions.CompilerServices.Collections; namespace Hyperbee.Expressions; @@ -13,7 +14,7 @@ public class AsyncBlockExpression : Expression public ReadOnlyCollection Expressions { get; } public ReadOnlyCollection Variables { get; } - internal ReadOnlyCollection ExternVariables { get; set; } + internal LinkedDictionary ScopedVariables { get; set; } public Expression Result => Expressions[^1]; @@ -25,7 +26,7 @@ internal AsyncBlockExpression( ReadOnlyCollection variables internal AsyncBlockExpression( ReadOnlyCollection variables, ReadOnlyCollection expressions, - ReadOnlyCollection externVariables + LinkedDictionary scopedVariables ) { if ( expressions == null || expressions.Count == 0 ) @@ -33,7 +34,7 @@ ReadOnlyCollection externVariables Variables = variables; Expressions = expressions; - ExternVariables = externVariables; + ScopedVariables = scopedVariables; _taskType = GetTaskType( Result.Type ); } @@ -55,11 +56,14 @@ private LoweringInfo LoweringTransformer() { var visitor = new LoweringVisitor(); + var scope = ScopedVariables ?? []; + scope.Push(); + return visitor.Transform( Result.Type, [.. Variables], [.. Expressions], - ExternVariables != null ? [.. ExternVariables] : [] + scope ); } catch ( LoweringException ex ) @@ -76,7 +80,7 @@ protected override Expression VisitChildren( ExpressionVisitor visitor ) if ( Compare( newVariables, Variables ) && Compare( newExpressions, Expressions ) ) return this; - return new AsyncBlockExpression( newVariables, newExpressions, ExternVariables ); + return new AsyncBlockExpression( newVariables, newExpressions, ScopedVariables ); } internal static bool Compare( ICollection compare, IReadOnlyList current ) diff --git a/src/Hyperbee.Expressions/CompilerServices/Collections/LinkedDictionary.cs b/src/Hyperbee.Expressions/CompilerServices/Collections/LinkedDictionary.cs new file mode 100644 index 0000000..4ba9d2d --- /dev/null +++ b/src/Hyperbee.Expressions/CompilerServices/Collections/LinkedDictionary.cs @@ -0,0 +1,281 @@ +using System.Collections; +using System.Collections.Concurrent; +using System.Collections.Immutable; + +namespace Hyperbee.Expressions.CompilerServices.Collections; + +// a dictionary comprised of a stack of dictionaries + +public record LinkedDictionaryNode +{ + public string Name { get; init; } + public IDictionary Dictionary { get; init; } +} + +public enum KeyValueOptions +{ + None, + All, + Current, + First +} + +public interface ILinkedDictionary : IDictionary +{ + IEqualityComparer Comparer { get; } + + string Name { get; } + + IEnumerable> Nodes(); + IEnumerable> Items( KeyValueOptions options = KeyValueOptions.None ); + + TValue this[TKey key, KeyValueOptions options] { set; } // let and set support + void Clear( KeyValueOptions options ); + bool Remove( TKey key, KeyValueOptions options ); + + void Push( IEnumerable> collection = default ); + void Push( string name, IEnumerable> collection = default ); + LinkedDictionaryNode Pop(); +} + +public class LinkedDictionary : ILinkedDictionary +{ + public IEqualityComparer Comparer { get; } + + public ImmutableStack> _nodes = []; + + // ctors + + public LinkedDictionary() + { + } + + public LinkedDictionary( IEqualityComparer comparer ) + : this( null, comparer ) + { + } + + public LinkedDictionary( IEnumerable> collection ) + : this( collection, null ) + { + } + + public LinkedDictionary( IEnumerable> collection, IEqualityComparer comparer ) + { + Comparer = comparer; + + if ( collection != null ) + Push( collection ); + } + + public LinkedDictionary( ILinkedDictionary inner ) + : this( inner, null ) + { + } + + public LinkedDictionary( ILinkedDictionary inner, IEnumerable> collection ) + { + Comparer = inner.Comparer; + _nodes = ImmutableStack.CreateRange( inner.Nodes() ); + + if ( collection != null ) + Push( collection ); + } + + // Stack + + public void Push( IEnumerable> collection = default ) + { + Push( null, collection ); + } + + public void Push( string name, IEnumerable> collection = default ) + { + var dictionary = collection == null + ? new ConcurrentDictionary( Comparer ) + : new ConcurrentDictionary( collection, Comparer ); + + _nodes = _nodes.Push( new LinkedDictionaryNode + { + Name = name ?? Guid.NewGuid().ToString(), + Dictionary = dictionary + } ); + } + + public LinkedDictionaryNode Pop() + { + _nodes = _nodes.Pop( out var node ); + return node; + } + + // ILinkedDictionary + + public string Name => _nodes.PeekRef().Name; + + public TValue this[TKey key, KeyValueOptions options] + { + set + { + // support both 'let' and 'set' style assignments + // + // 'set' will assign value to the nearest existing key, or to the current node if no key is found. + // 'let' will assign value to the current node dictionary. + + if ( options != KeyValueOptions.Current ) + { + // find and set if exists in an inner node + foreach ( var scope in _nodes.Where( scope => scope.Dictionary.ContainsKey( key ) ) ) + { + scope.Dictionary[key] = value; + return; + } + } + + // set in current node + _nodes.PeekRef().Dictionary[key] = value; + } + } + + public IEnumerable> Nodes() => _nodes; + + public IEnumerable> Items( KeyValueOptions options = KeyValueOptions.None ) + { + var keys = options == KeyValueOptions.First ? new HashSet( Comparer ) : null; + + foreach ( var scope in _nodes ) + { + foreach ( var pair in scope.Dictionary ) + { + if ( options == KeyValueOptions.First ) + { + if ( keys!.Contains( pair.Key ) ) + continue; + + keys.Add( pair.Key ); + } + + yield return pair; + } + + if ( options == KeyValueOptions.Current ) + break; + } + } + + public void Clear( KeyValueOptions options ) + { + if ( options != KeyValueOptions.Current && options != KeyValueOptions.First ) + { + _nodes.Pop( out var node ); + _nodes = [node]; + } + + _nodes.PeekRef().Dictionary.Clear(); + } + + public bool Remove( TKey key, KeyValueOptions options ) + { + var result = false; + + foreach ( var _ in _nodes.Where( scope => scope.Dictionary.Remove( key ) ) ) + { + result = true; + + if ( options == KeyValueOptions.First ) + break; + } + + return result; + } + + // IDictionary + + public TValue this[TKey key] + { + get + { + if ( !TryGetValue( key, out var result ) ) + throw new KeyNotFoundException(); + + return result; + } + + set => this[key, KeyValueOptions.First] = value; + } + + public bool IsReadOnly => false; + public int Count => _nodes.Count(); + + public void Add( TKey key, TValue value ) + { + if ( ContainsKey( key ) ) + throw new ArgumentException( "Key already exists." ); + + this[key, KeyValueOptions.First] = value; + } + + public void Clear() => Clear( KeyValueOptions.All ); + + public bool ContainsKey( TKey key ) + { + return _nodes.Any( scope => scope.Dictionary.ContainsKey( key ) ); + } + + public bool Remove( TKey key ) => Remove( key, KeyValueOptions.First ); + + public bool TryGetValue( TKey key, out TValue value ) + { + foreach ( var scope in _nodes ) + { + if ( scope.Dictionary.TryGetValue( key, out value ) ) + return true; + } + + value = default; + return false; + } + + // ICollection + + void ICollection>.Add( KeyValuePair item ) + { + var (key, value) = item; + Add( key, value ); + } + + bool ICollection>.Contains( KeyValuePair item ) + { + return _nodes.Any( scope => scope.Dictionary.Contains( item ) ); + } + + void ICollection>.CopyTo( KeyValuePair[] array, int arrayIndex ) + { + ArgumentNullException.ThrowIfNull( array, nameof( array ) ); + + if ( (uint) arrayIndex > (uint) array.Length ) + throw new IndexOutOfRangeException(); + + if ( array.Length - arrayIndex < Count ) + throw new IndexOutOfRangeException( "Array plus offset is out of range." ); + + foreach ( var current in _nodes.Select( scope => scope.Dictionary ).Where( current => current.Count != 0 ) ) + { + current.CopyTo( array, arrayIndex ); + arrayIndex += current.Count; + } + } + + bool ICollection>.Remove( KeyValuePair item ) + { + return _nodes.Any( scope => scope.Dictionary.Remove( item ) ); + } + + ICollection IDictionary.Keys => Items( KeyValueOptions.First ).Select( pair => pair.Key ).ToArray(); + ICollection IDictionary.Values => Items( KeyValueOptions.First ).Select( pair => pair.Value ).ToArray(); + + // Enumeration + + public IEnumerator> GetEnumerator() => Items( KeyValueOptions.All ).GetEnumerator(); + + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); +} diff --git a/src/Hyperbee.Expressions/CompilerServices/LoweringVisitor.cs b/src/Hyperbee.Expressions/CompilerServices/LoweringVisitor.cs index 0a49b8f..2ddb583 100644 --- a/src/Hyperbee.Expressions/CompilerServices/LoweringVisitor.cs +++ b/src/Hyperbee.Expressions/CompilerServices/LoweringVisitor.cs @@ -1,5 +1,6 @@ using System.Linq.Expressions; using System.Runtime.CompilerServices; +using Hyperbee.Expressions.CompilerServices.Collections; using Hyperbee.Expressions.CompilerServices.Transitions; using Hyperbee.Expressions.Visitors; @@ -19,12 +20,16 @@ internal class LoweringVisitor : ExpressionVisitor private VariableResolver _variableResolver; - public LoweringInfo Transform( Type resultType, ParameterExpression[] variables, Expression[] expressions, ParameterExpression[] externVariables ) + public LoweringInfo Transform( + Type resultType, + ParameterExpression[] variables, + Expression[] expressions, + LinkedDictionary externScopes = null ) { ArgumentNullException.ThrowIfNull( expressions, nameof( expressions ) ); ArgumentOutOfRangeException.ThrowIfZero( expressions.Length, nameof( expressions ) ); - _variableResolver = new VariableResolver( variables, _states ); + _variableResolver = new VariableResolver( variables, externScopes, _states ); _finalResultVariable = CreateFinalResultVariable( resultType, _variableResolver ); VisitExpressions( expressions ); @@ -39,7 +44,7 @@ public LoweringInfo Transform( Type resultType, ParameterExpression[] variables, HasFinalResultVariable = _hasFinalResultVariable, AwaitCount = _awaitCount, Variables = _variableResolver.GetMappedVariables(), - ExternVariables = externVariables + ExternScopes = externScopes }; // helpers diff --git a/src/Hyperbee.Expressions/CompilerServices/StateMachineBuilder.cs b/src/Hyperbee.Expressions/CompilerServices/StateMachineBuilder.cs index ea9117f..37b1bee 100644 --- a/src/Hyperbee.Expressions/CompilerServices/StateMachineBuilder.cs +++ b/src/Hyperbee.Expressions/CompilerServices/StateMachineBuilder.cs @@ -84,10 +84,10 @@ public Expression CreateStateMachine( LoweringTransformer loweringTransformer, i }; bodyExpression.AddRange( // Assign extern variables to state-machine - loweringInfo.ExternVariables.Select( externVariable => + loweringInfo.ExternScopes.Items().Select( externVariable => Assign( - Field( stateMachineVariable, fields.First( field => field.Name == externVariable.Name ) ), - externVariable + Field( stateMachineVariable, fields.First( field => field.Name == externVariable.Value.Name ) ), + externVariable.Value ) ) ); @@ -163,7 +163,7 @@ private Type CreateStateMachineType( StateMachineContext context, out FieldInfo[ // variables from other state-machines - foreach ( var parameterExpression in context.LoweringInfo.ExternVariables ) + foreach ( var parameterExpression in context.LoweringInfo.ExternScopes.Items().Select( x => x.Value ) ) { typeBuilder.DefineField( parameterExpression.Name ?? parameterExpression.ToString(), diff --git a/src/Hyperbee.Expressions/CompilerServices/StateMachineContext.cs b/src/Hyperbee.Expressions/CompilerServices/StateMachineContext.cs index 5df8b95..cb40c9d 100644 --- a/src/Hyperbee.Expressions/CompilerServices/StateMachineContext.cs +++ b/src/Hyperbee.Expressions/CompilerServices/StateMachineContext.cs @@ -1,5 +1,6 @@  using System.Linq.Expressions; +using Hyperbee.Expressions.CompilerServices.Collections; namespace Hyperbee.Expressions.CompilerServices; @@ -22,7 +23,9 @@ internal record LoweringInfo { public IReadOnlyList Scopes { get; init; } public IReadOnlyCollection Variables { get; init; } - public IReadOnlyCollection ExternVariables { get; init; } + public LinkedDictionary ExternScopes { get; init; } + public int AwaitCount { get; init; } public bool HasFinalResultVariable { get; init; } + } diff --git a/src/Hyperbee.Expressions/CompilerServices/VariableResolver.cs b/src/Hyperbee.Expressions/CompilerServices/VariableResolver.cs index 3478575..b758659 100644 --- a/src/Hyperbee.Expressions/CompilerServices/VariableResolver.cs +++ b/src/Hyperbee.Expressions/CompilerServices/VariableResolver.cs @@ -1,6 +1,7 @@ using System.Collections.ObjectModel; using System.Linq.Expressions; using System.Runtime.CompilerServices; +using Hyperbee.Expressions.CompilerServices.Collections; using static System.Linq.Expressions.Expression; namespace Hyperbee.Expressions.CompilerServices; @@ -33,20 +34,28 @@ internal static class VariableName private const int InitialCapacity = 8; - private readonly Dictionary _variableMap = new( InitialCapacity ); - private readonly Dictionary _externVariableMap = new( InitialCapacity ); + private readonly Dictionary _awaiters = new( InitialCapacity ); + private readonly HashSet _variables; - private readonly Stack> _variableBlockScope = new( InitialCapacity ); + private readonly StateContext _states; private readonly Dictionary _labels = []; + + private readonly LinkedDictionary _scopedVariables; private int _variableId; - public VariableResolver( ParameterExpression[] variables, StateContext states ) + private readonly Dictionary _variableMap = new( InitialCapacity ); + + public VariableResolver( + ParameterExpression[] variables, + LinkedDictionary scopedVariables, + StateContext states ) { - _states = states; _variables = [.. variables]; + _scopedVariables = scopedVariables ?? []; + _states = states; } // Helpers @@ -63,7 +72,13 @@ public Expression GetResultVariable( Expression node, int stateId ) [MethodImpl( MethodImplOptions.AggressiveInlining )] public Expression GetAwaiterVariable( Type type, int stateId ) { - return AddVariable( Variable( type, VariableName.Awaiter( stateId, ref _variableId ) ) ); + if( _awaiters.ContainsKey( type ) ) + return _awaiters[type]; + + var awaiter = AddVariable( Variable( type, VariableName.Awaiter( stateId, ref _variableId ) ) ); + _awaiters[type] = awaiter; + + return AddVariable( awaiter ); } [MethodImpl( MethodImplOptions.AggressiveInlining )] @@ -104,13 +119,12 @@ internal bool TryResolveLabel( GotoExpression node, out Expression label ) protected override Expression VisitBlock( BlockExpression node ) { - var newVars = CreateLocalVariables( node.Variables ); - - _variableBlockScope.Push( newVars ); + _scopedVariables.Push(); + var newVars = CreateExternVariables( node.Variables ); var returnNode = base.VisitBlock( node.Update( newVars, node.Expressions ) ); - _variableBlockScope.Pop(); + _scopedVariables.Pop(); return returnNode; } @@ -118,14 +132,12 @@ protected override Expression VisitBlock( BlockExpression node ) #if FAST_COMPILER protected override Expression VisitLambda( Expression node ) { - // Add Params to Externals - var newVars = CreateLocalVariables( node.Parameters ); + _scopedVariables.Push(); - _variableBlockScope.Push( newVars ); + var newParams = CreateExternVariables( node.Parameters ); + var returnNode = base.VisitLambda( node.Update( node.Body, newParams ) ); - var returnNode = base.VisitLambda( node ); - - _variableBlockScope.Pop(); + _scopedVariables.Pop(); return returnNode; } @@ -142,10 +154,7 @@ protected override Expression VisitExtension( Expression node ) { if ( node is AsyncBlockExpression asyncBlockExpression ) { - asyncBlockExpression.ExternVariables = _variableBlockScope - .SelectMany( x => x ) - .ToList() - .AsReadOnly(); + asyncBlockExpression.ScopedVariables = _scopedVariables; } return base.VisitExtension( node ); @@ -200,16 +209,14 @@ private ParameterExpression AddVariable( ParameterExpression variable ) return variable; } - List CreateLocalVariables( ReadOnlyCollection parameters ) + private IEnumerable CreateExternVariables( ReadOnlyCollection parameters ) { - var vars = new List(); - foreach ( var variable in parameters ) { if ( variable.Name!.StartsWith( "__extern." ) ) { - vars.Add( variable ); - _externVariableMap.TryAdd( variable, variable ); + _scopedVariables.TryAdd( variable, variable ); + yield return variable; continue; } @@ -218,11 +225,9 @@ List CreateLocalVariables( ReadOnlyCollectionfalse true - + all