Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
WIP switching to LinkDictionary to track scopes
Browse files Browse the repository at this point in the history
MattEdwardsWaggleBee committed Dec 17, 2024
1 parent 8192ee1 commit 1d104e1
Showing 7 changed files with 343 additions and 45 deletions.
14 changes: 9 additions & 5 deletions src/Hyperbee.Expressions/AsyncBlockExpression.cs
Original file line number Diff line number Diff line change
@@ -2,6 +2,7 @@
using System.Diagnostics;
using System.Linq.Expressions;
using Hyperbee.Expressions.CompilerServices;
using Hyperbee.Expressions.CompilerServices.Collections;

namespace Hyperbee.Expressions;

@@ -13,7 +14,7 @@ public class AsyncBlockExpression : Expression

public ReadOnlyCollection<Expression> Expressions { get; }
public ReadOnlyCollection<ParameterExpression> Variables { get; }
internal ReadOnlyCollection<ParameterExpression> ExternVariables { get; set; }
internal LinkedDictionary<ParameterExpression, ParameterExpression> ScopedVariables { get; set; }

public Expression Result => Expressions[^1];

@@ -25,15 +26,15 @@ internal AsyncBlockExpression( ReadOnlyCollection<ParameterExpression> variables
internal AsyncBlockExpression(
ReadOnlyCollection<ParameterExpression> variables,
ReadOnlyCollection<Expression> expressions,
ReadOnlyCollection<ParameterExpression> externVariables
LinkedDictionary<ParameterExpression, ParameterExpression> scopedVariables
)
{
if ( expressions == null || expressions.Count == 0 )
throw new ArgumentException( $"{nameof( AsyncBlockExpression )} must contain at least one expression.", nameof( expressions ) );

Variables = variables;
Expressions = expressions;
ExternVariables = externVariables;
ScopedVariables = scopedVariables;

_taskType = GetTaskType( Result.Type );
}
@@ -55,11 +56,14 @@ private LoweringInfo LoweringTransformer()
{
var visitor = new LoweringVisitor();

var scope = ScopedVariables ?? [];
scope.Push();

return visitor.Transform(
Result.Type,
[.. Variables],
[.. Expressions],
ExternVariables != null ? [.. ExternVariables] : []
scope
);
}
catch ( LoweringException ex )
@@ -76,7 +80,7 @@ protected override Expression VisitChildren( ExpressionVisitor visitor )
if ( Compare( newVariables, Variables ) && Compare( newExpressions, Expressions ) )
return this;

return new AsyncBlockExpression( newVariables, newExpressions, ExternVariables );
return new AsyncBlockExpression( newVariables, newExpressions, ScopedVariables );
}

internal static bool Compare<T>( ICollection<T> compare, IReadOnlyList<T> current )
Original file line number Diff line number Diff line change
@@ -0,0 +1,281 @@
using System.Collections;
using System.Collections.Concurrent;
using System.Collections.Immutable;

namespace Hyperbee.Expressions.CompilerServices.Collections;

// a dictionary comprised of a stack of dictionaries

public record LinkedDictionaryNode<TKey, TValue>
{
public string Name { get; init; }
public IDictionary<TKey, TValue> Dictionary { get; init; }
}

public enum KeyValueOptions
{
None,
All,
Current,
First
}

public interface ILinkedDictionary<TKey, TValue> : IDictionary<TKey, TValue>
{
IEqualityComparer<TKey> Comparer { get; }

string Name { get; }

IEnumerable<LinkedDictionaryNode<TKey, TValue>> Nodes();
IEnumerable<KeyValuePair<TKey, TValue>> Items( KeyValueOptions options = KeyValueOptions.None );

TValue this[TKey key, KeyValueOptions options] { set; } // let and set support
void Clear( KeyValueOptions options );
bool Remove( TKey key, KeyValueOptions options );

void Push( IEnumerable<KeyValuePair<TKey, TValue>> collection = default );
void Push( string name, IEnumerable<KeyValuePair<TKey, TValue>> collection = default );
LinkedDictionaryNode<TKey, TValue> Pop();
}

public class LinkedDictionary<TKey, TValue> : ILinkedDictionary<TKey, TValue>
{
public IEqualityComparer<TKey> Comparer { get; }

public ImmutableStack<LinkedDictionaryNode<TKey, TValue>> _nodes = [];

// ctors

public LinkedDictionary()
{
}

public LinkedDictionary( IEqualityComparer<TKey> comparer )
: this( null, comparer )
{
}

public LinkedDictionary( IEnumerable<KeyValuePair<TKey, TValue>> collection )
: this( collection, null )
{
}

public LinkedDictionary( IEnumerable<KeyValuePair<TKey, TValue>> collection, IEqualityComparer<TKey> comparer )
{
Comparer = comparer;

if ( collection != null )
Push( collection );
}

public LinkedDictionary( ILinkedDictionary<TKey, TValue> inner )
: this( inner, null )
{
}

public LinkedDictionary( ILinkedDictionary<TKey, TValue> inner, IEnumerable<KeyValuePair<TKey, TValue>> collection )
{
Comparer = inner.Comparer;
_nodes = ImmutableStack.CreateRange( inner.Nodes() );

if ( collection != null )
Push( collection );
}

// Stack

public void Push( IEnumerable<KeyValuePair<TKey, TValue>> collection = default )
{
Push( null, collection );
}

public void Push( string name, IEnumerable<KeyValuePair<TKey, TValue>> collection = default )
{
var dictionary = collection == null
? new ConcurrentDictionary<TKey, TValue>( Comparer )
: new ConcurrentDictionary<TKey, TValue>( collection, Comparer );

_nodes = _nodes.Push( new LinkedDictionaryNode<TKey, TValue>
{
Name = name ?? Guid.NewGuid().ToString(),
Dictionary = dictionary
} );
}

public LinkedDictionaryNode<TKey, TValue> Pop()
{
_nodes = _nodes.Pop( out var node );
return node;
}

// ILinkedDictionary

public string Name => _nodes.PeekRef().Name;

public TValue this[TKey key, KeyValueOptions options]
{
set
{
// support both 'let' and 'set' style assignments
//
// 'set' will assign value to the nearest existing key, or to the current node if no key is found.
// 'let' will assign value to the current node dictionary.

if ( options != KeyValueOptions.Current )
{
// find and set if exists in an inner node
foreach ( var scope in _nodes.Where( scope => scope.Dictionary.ContainsKey( key ) ) )
{
scope.Dictionary[key] = value;
return;
}
}

// set in current node
_nodes.PeekRef().Dictionary[key] = value;
}
}

public IEnumerable<LinkedDictionaryNode<TKey, TValue>> Nodes() => _nodes;

public IEnumerable<KeyValuePair<TKey, TValue>> Items( KeyValueOptions options = KeyValueOptions.None )
{
var keys = options == KeyValueOptions.First ? new HashSet<TKey>( Comparer ) : null;

foreach ( var scope in _nodes )
{
foreach ( var pair in scope.Dictionary )
{
if ( options == KeyValueOptions.First )
{
if ( keys!.Contains( pair.Key ) )
continue;

keys.Add( pair.Key );
}

yield return pair;
}

if ( options == KeyValueOptions.Current )
break;
}
}

public void Clear( KeyValueOptions options )
{
if ( options != KeyValueOptions.Current && options != KeyValueOptions.First )
{
_nodes.Pop( out var node );
_nodes = [node];
}

_nodes.PeekRef().Dictionary.Clear();
}

public bool Remove( TKey key, KeyValueOptions options )
{
var result = false;

foreach ( var _ in _nodes.Where( scope => scope.Dictionary.Remove( key ) ) )
{
result = true;

if ( options == KeyValueOptions.First )
break;
}

return result;
}

// IDictionary

public TValue this[TKey key]
{
get
{
if ( !TryGetValue( key, out var result ) )
throw new KeyNotFoundException();

return result;
}

set => this[key, KeyValueOptions.First] = value;
}

public bool IsReadOnly => false;
public int Count => _nodes.Count();

public void Add( TKey key, TValue value )
{
if ( ContainsKey( key ) )
throw new ArgumentException( "Key already exists." );

this[key, KeyValueOptions.First] = value;
}

public void Clear() => Clear( KeyValueOptions.All );

public bool ContainsKey( TKey key )
{
return _nodes.Any( scope => scope.Dictionary.ContainsKey( key ) );
}

public bool Remove( TKey key ) => Remove( key, KeyValueOptions.First );

public bool TryGetValue( TKey key, out TValue value )
{
foreach ( var scope in _nodes )
{
if ( scope.Dictionary.TryGetValue( key, out value ) )
return true;
}

value = default;
return false;
}

// ICollection

void ICollection<KeyValuePair<TKey, TValue>>.Add( KeyValuePair<TKey, TValue> item )
{
var (key, value) = item;
Add( key, value );
}

bool ICollection<KeyValuePair<TKey, TValue>>.Contains( KeyValuePair<TKey, TValue> item )
{
return _nodes.Any( scope => scope.Dictionary.Contains( item ) );
}

void ICollection<KeyValuePair<TKey, TValue>>.CopyTo( KeyValuePair<TKey, TValue>[] array, int arrayIndex )
{
ArgumentNullException.ThrowIfNull( array, nameof( array ) );

if ( (uint) arrayIndex > (uint) array.Length )
throw new IndexOutOfRangeException();

if ( array.Length - arrayIndex < Count )
throw new IndexOutOfRangeException( "Array plus offset is out of range." );

foreach ( var current in _nodes.Select( scope => scope.Dictionary ).Where( current => current.Count != 0 ) )
{
current.CopyTo( array, arrayIndex );
arrayIndex += current.Count;
}
}

bool ICollection<KeyValuePair<TKey, TValue>>.Remove( KeyValuePair<TKey, TValue> item )
{
return _nodes.Any( scope => scope.Dictionary.Remove( item ) );
}

ICollection<TKey> IDictionary<TKey, TValue>.Keys => Items( KeyValueOptions.First ).Select( pair => pair.Key ).ToArray();
ICollection<TValue> IDictionary<TKey, TValue>.Values => Items( KeyValueOptions.First ).Select( pair => pair.Value ).ToArray();

// Enumeration

public IEnumerator<KeyValuePair<TKey, TValue>> GetEnumerator() => Items( KeyValueOptions.All ).GetEnumerator();

IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
}
11 changes: 8 additions & 3 deletions src/Hyperbee.Expressions/CompilerServices/LoweringVisitor.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using System.Linq.Expressions;
using System.Runtime.CompilerServices;
using Hyperbee.Expressions.CompilerServices.Collections;
using Hyperbee.Expressions.CompilerServices.Transitions;
using Hyperbee.Expressions.Visitors;

@@ -19,12 +20,16 @@ internal class LoweringVisitor : ExpressionVisitor

private VariableResolver _variableResolver;

public LoweringInfo Transform( Type resultType, ParameterExpression[] variables, Expression[] expressions, ParameterExpression[] externVariables )
public LoweringInfo Transform(
Type resultType,
ParameterExpression[] variables,
Expression[] expressions,
LinkedDictionary<ParameterExpression, ParameterExpression> externScopes = null )
{
ArgumentNullException.ThrowIfNull( expressions, nameof( expressions ) );
ArgumentOutOfRangeException.ThrowIfZero( expressions.Length, nameof( expressions ) );

_variableResolver = new VariableResolver( variables, _states );
_variableResolver = new VariableResolver( variables, externScopes, _states );
_finalResultVariable = CreateFinalResultVariable( resultType, _variableResolver );

VisitExpressions( expressions );
@@ -39,7 +44,7 @@ public LoweringInfo Transform( Type resultType, ParameterExpression[] variables,
HasFinalResultVariable = _hasFinalResultVariable,
AwaitCount = _awaitCount,
Variables = _variableResolver.GetMappedVariables(),
ExternVariables = externVariables
ExternScopes = externScopes
};

// helpers
Original file line number Diff line number Diff line change
@@ -84,10 +84,10 @@ public Expression CreateStateMachine( LoweringTransformer loweringTransformer, i
};

bodyExpression.AddRange( // Assign extern variables to state-machine
loweringInfo.ExternVariables.Select( externVariable =>
loweringInfo.ExternScopes.Items().Select( externVariable =>
Assign(
Field( stateMachineVariable, fields.First( field => field.Name == externVariable.Name ) ),
externVariable
Field( stateMachineVariable, fields.First( field => field.Name == externVariable.Value.Name ) ),
externVariable.Value
)
)
);
@@ -163,7 +163,7 @@ private Type CreateStateMachineType( StateMachineContext context, out FieldInfo[

// variables from other state-machines

foreach ( var parameterExpression in context.LoweringInfo.ExternVariables )
foreach ( var parameterExpression in context.LoweringInfo.ExternScopes.Items().Select( x => x.Value ) )
{
typeBuilder.DefineField(
parameterExpression.Name ?? parameterExpression.ToString(),
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@

using System.Linq.Expressions;
using Hyperbee.Expressions.CompilerServices.Collections;

namespace Hyperbee.Expressions.CompilerServices;

@@ -22,7 +23,9 @@ internal record LoweringInfo
{
public IReadOnlyList<StateContext.Scope> Scopes { get; init; }
public IReadOnlyCollection<Expression> Variables { get; init; }
public IReadOnlyCollection<ParameterExpression> ExternVariables { get; init; }
public LinkedDictionary<ParameterExpression, ParameterExpression> ExternScopes { get; init; }

public int AwaitCount { get; init; }
public bool HasFinalResultVariable { get; init; }

}
65 changes: 35 additions & 30 deletions src/Hyperbee.Expressions/CompilerServices/VariableResolver.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using System.Collections.ObjectModel;
using System.Linq.Expressions;
using System.Runtime.CompilerServices;
using Hyperbee.Expressions.CompilerServices.Collections;
using static System.Linq.Expressions.Expression;

namespace Hyperbee.Expressions.CompilerServices;
@@ -33,20 +34,28 @@ internal static class VariableName

private const int InitialCapacity = 8;

private readonly Dictionary<ParameterExpression, ParameterExpression> _variableMap = new( InitialCapacity );
private readonly Dictionary<ParameterExpression, ParameterExpression> _externVariableMap = new( InitialCapacity );
private readonly Dictionary<Type, ParameterExpression> _awaiters = new( InitialCapacity );

private readonly HashSet<ParameterExpression> _variables;
private readonly Stack<ICollection<ParameterExpression>> _variableBlockScope = new( InitialCapacity );

private readonly StateContext _states;

private readonly Dictionary<LabelTarget, Expression> _labels = [];

private readonly LinkedDictionary<ParameterExpression, ParameterExpression> _scopedVariables;

private int _variableId;

public VariableResolver( ParameterExpression[] variables, StateContext states )
private readonly Dictionary<ParameterExpression, ParameterExpression> _variableMap = new( InitialCapacity );

public VariableResolver(
ParameterExpression[] variables,
LinkedDictionary<ParameterExpression, ParameterExpression> scopedVariables,
StateContext states )
{
_states = states;
_variables = [.. variables];
_scopedVariables = scopedVariables ?? [];
_states = states;
}

// Helpers
@@ -63,7 +72,13 @@ public Expression GetResultVariable( Expression node, int stateId )
[MethodImpl( MethodImplOptions.AggressiveInlining )]
public Expression GetAwaiterVariable( Type type, int stateId )
{
return AddVariable( Variable( type, VariableName.Awaiter( stateId, ref _variableId ) ) );
if( _awaiters.ContainsKey( type ) )
return _awaiters[type];

var awaiter = AddVariable( Variable( type, VariableName.Awaiter( stateId, ref _variableId ) ) );
_awaiters[type] = awaiter;

return AddVariable( awaiter );
}

[MethodImpl( MethodImplOptions.AggressiveInlining )]
@@ -104,28 +119,25 @@ internal bool TryResolveLabel( GotoExpression node, out Expression label )

protected override Expression VisitBlock( BlockExpression node )
{
var newVars = CreateLocalVariables( node.Variables );

_variableBlockScope.Push( newVars );
_scopedVariables.Push();

var newVars = CreateExternVariables( node.Variables );
var returnNode = base.VisitBlock( node.Update( newVars, node.Expressions ) );

_variableBlockScope.Pop();
_scopedVariables.Pop();

return returnNode;
}

#if FAST_COMPILER
protected override Expression VisitLambda<T>( Expression<T> node )
{
// Add Params to Externals
var newVars = CreateLocalVariables( node.Parameters );
_scopedVariables.Push();

_variableBlockScope.Push( newVars );
var newParams = CreateExternVariables( node.Parameters );
var returnNode = base.VisitLambda( node.Update( node.Body, newParams ) );

var returnNode = base.VisitLambda( node );

_variableBlockScope.Pop();
_scopedVariables.Pop();

return returnNode;
}
@@ -142,10 +154,7 @@ protected override Expression VisitExtension( Expression node )
{
if ( node is AsyncBlockExpression asyncBlockExpression )
{
asyncBlockExpression.ExternVariables = _variableBlockScope
.SelectMany( x => x )
.ToList()
.AsReadOnly();
asyncBlockExpression.ScopedVariables = _scopedVariables;
}

return base.VisitExtension( node );
@@ -200,16 +209,14 @@ private ParameterExpression AddVariable( ParameterExpression variable )
return variable;
}

List<ParameterExpression> CreateLocalVariables( ReadOnlyCollection<ParameterExpression> parameters )
private IEnumerable<ParameterExpression> CreateExternVariables( ReadOnlyCollection<ParameterExpression> parameters )
{
var vars = new List<ParameterExpression>();

foreach ( var variable in parameters )
{
if ( variable.Name!.StartsWith( "__extern." ) )
{
vars.Add( variable );
_externVariableMap.TryAdd( variable, variable );
_scopedVariables.TryAdd( variable, variable );
yield return variable;
continue;
}

@@ -218,11 +225,9 @@ List<ParameterExpression> CreateLocalVariables( ReadOnlyCollection<ParameterExpr
VariableName.ExternVariable( variable.Name, _states.TailState.StateId, ref _variableId )
);

_externVariableMap.TryAdd( variable, newVar );
vars.Add( newVar );
_scopedVariables.TryAdd( variable, newVar );
yield return newVar;
}

return vars;
}

private bool TryAddVariable(
@@ -232,7 +237,7 @@ out ParameterExpression updatedParameterExpression
)
{
if ( _variableMap.TryGetValue( parameter, out updatedParameterExpression ) ||
_externVariableMap.TryGetValue( parameter, out updatedParameterExpression ) )
_scopedVariables.TryGetValue( parameter, out updatedParameterExpression ) )
{
return true;
}
Original file line number Diff line number Diff line change
@@ -7,11 +7,11 @@
<IsPackable>false</IsPackable>
<IsTestProject>true</IsTestProject>
</PropertyGroup>

<!--
<PropertyGroup>
<DefineConstants>$(DefineConstants);FAST_COMPILER</DefineConstants>
</PropertyGroup>
-->
<ItemGroup>
<PackageReference Include="coverlet.collector" Version="6.0.2">
<PrivateAssets>all</PrivateAssets>

0 comments on commit 1d104e1

Please sign in to comment.