diff --git a/examples/EdgeDB.Examples.CSharp/Examples/QueryBuilder.cs b/examples/EdgeDB.Examples.CSharp/Examples/QueryBuilder.cs index b59856a..6310f87 100644 --- a/examples/EdgeDB.Examples.CSharp/Examples/QueryBuilder.cs +++ b/examples/EdgeDB.Examples.CSharp/Examples/QueryBuilder.cs @@ -35,6 +35,10 @@ public async Task ExecuteAsync(EdgeDBClient client) { try { + var test = QueryBuilder + .SelectExpression(ctx => EdgeQL.Count(ctx.SubQuery(QueryBuilder.Select()))) + .Compile(true); + await QueryBuilderDemo(client); } catch (Exception x) diff --git a/src/EdgeDB.Net.QueryBuilder/Extensions/IEnumerableExtensions.cs b/src/EdgeDB.Net.QueryBuilder/Extensions/IEnumerableExtensions.cs new file mode 100644 index 0000000..d1bf372 --- /dev/null +++ b/src/EdgeDB.Net.QueryBuilder/Extensions/IEnumerableExtensions.cs @@ -0,0 +1,24 @@ +namespace EdgeDB; + +public static class EnumerableExtensions +{ + public static Dictionary> ToBucketedDictionary(this IEnumerable collection, + Func selectKey, Func selectValue) + where T: notnull + { + var dict = new Dictionary>(); + + foreach (var item in collection) + { + var key = selectKey(item); + var value = selectValue(item); + + if (!dict.TryGetValue(key, out var bucket)) + dict[key] = bucket = new(); + + bucket.AddLast(value); + } + + return dict; + } +} diff --git a/src/EdgeDB.Net.QueryBuilder/Extensions/RangeExtensions.cs b/src/EdgeDB.Net.QueryBuilder/Extensions/RangeExtensions.cs index 444f226..0fda7fd 100644 --- a/src/EdgeDB.Net.QueryBuilder/Extensions/RangeExtensions.cs +++ b/src/EdgeDB.Net.QueryBuilder/Extensions/RangeExtensions.cs @@ -11,4 +11,7 @@ public static Range Normalize(this Range range) { return range.Start..(range.End.Value + range.Start.Value); } + + public static bool Contains(this Range range, int point) + => range.Start.Value <= point && range.End.Value >= point; } diff --git a/src/EdgeDB.Net.QueryBuilder/Extensions/TypeExtensions.cs b/src/EdgeDB.Net.QueryBuilder/Extensions/TypeExtensions.cs index 11b8731..5e43336 100644 --- a/src/EdgeDB.Net.QueryBuilder/Extensions/TypeExtensions.cs +++ b/src/EdgeDB.Net.QueryBuilder/Extensions/TypeExtensions.cs @@ -10,6 +10,25 @@ namespace EdgeDB { internal static class TypeExtensions { + public static bool References(this Type type, Type other) + => References(type, other, true, []); + + private static bool References(Type type, Type other, bool checkInterfaces, HashSet hasChecked) + { + if (!hasChecked.Add(type)) + return false; + + if (type == other) + return true; + + return type switch + { + { IsArray: true } => References(type.GetElementType()!, other, true, hasChecked), + { IsGenericType: true } => type.GetGenericArguments().Any(x => References(x, other, true, hasChecked)), + _ => (type.BaseType?.References(other) ?? false) || (checkInterfaces && type.GetInterfaces().Any(x => References(x, other, false, hasChecked))) + }; + } + public static IEnumerable GetEdgeDBTargetProperties(this Type type, bool excludeId = false) => type.GetProperties().Where(x => x.GetCustomAttribute() == null && !(excludeId && x.Name == "Id" && (x.PropertyType == typeof(Guid) || x.PropertyType == typeof(Guid?)))); diff --git a/src/EdgeDB.Net.QueryBuilder/Grammar/Reducers/GlobalReducer.cs b/src/EdgeDB.Net.QueryBuilder/Grammar/Reducers/GlobalReducer.cs index 3e89f4a..aa0cbfd 100644 --- a/src/EdgeDB.Net.QueryBuilder/Grammar/Reducers/GlobalReducer.cs +++ b/src/EdgeDB.Net.QueryBuilder/Grammar/Reducers/GlobalReducer.cs @@ -1,10 +1,11 @@ using EdgeDB.QueryNodes; +using System.Diagnostics.CodeAnalysis; namespace EdgeDB; internal sealed class GlobalReducer : IReducer { - public void Reduce(IQueryBuilder builder, QueryWriter writer, Queue shouldRunAfter) + public void Reduce(IQueryBuilder builder, QueryWriter writer) { if (!writer.Markers.MarkersByType.TryGetValue(MarkerType.QueryNode, out var nodes)) return; @@ -42,16 +43,31 @@ public void Reduce(IQueryBuilder builder, QueryWriter writer, Queue sh withNode.Remove(); withNode.Kill(); } - - if(reducedCount > 0) - shouldRunAfter.Enqueue(QueryReducer.Get()); } - private bool CanReduceWithNestedTypeSafety(QueryGlobal global, Marker marker, QueryWriter writer) + private static bool CanReduceWithNestedTypeSafety(QueryGlobal global, Marker marker, QueryWriter writer) { - // TODO: - // we cant reduce a global when: - // - is a query builder inside of a nested query that selects the same type. + var bannedTypes = global switch + { + {Reference: IQueryBuilder builder} => builder.Nodes.Select(x => x.GetOperatingType()).ToHashSet(), + {Value: IQueryBuilder builder} => builder.Nodes.Select(x => x.GetOperatingType()).ToHashSet(), + _ => null + }; + + if (bannedTypes is null) + return false; + + var nodes = writer.Markers.GetParents(marker).Where(x => x.Type is MarkerType.QueryNode); + + + foreach (var node in nodes) + { + if (node.Metadata is not QueryNodeMetadata nodeMetadata) + return false; + + if (bannedTypes.Contains(nodeMetadata.Node.OperatingType)) + return false; + } return true; } diff --git a/src/EdgeDB.Net.QueryBuilder/Grammar/Reducers/IReducer.cs b/src/EdgeDB.Net.QueryBuilder/Grammar/Reducers/IReducer.cs index 3acc98d..ca79f70 100644 --- a/src/EdgeDB.Net.QueryBuilder/Grammar/Reducers/IReducer.cs +++ b/src/EdgeDB.Net.QueryBuilder/Grammar/Reducers/IReducer.cs @@ -2,5 +2,5 @@ internal interface IReducer { - void Reduce(IQueryBuilder builder, QueryWriter writer, Queue shouldRunAfter); + void Reduce(IQueryBuilder builder, QueryWriter writer); } diff --git a/src/EdgeDB.Net.QueryBuilder/Grammar/Reducers/NestedSelectReducer.cs b/src/EdgeDB.Net.QueryBuilder/Grammar/Reducers/NestedSelectReducer.cs index 6dfebf3..1f9a9f3 100644 --- a/src/EdgeDB.Net.QueryBuilder/Grammar/Reducers/NestedSelectReducer.cs +++ b/src/EdgeDB.Net.QueryBuilder/Grammar/Reducers/NestedSelectReducer.cs @@ -13,8 +13,7 @@ internal sealed class NestedSelectReducer : IReducer /// /// /// - /// - public void Reduce(IQueryBuilder builder, QueryWriter writer, Queue shouldRunAfter) + public void Reduce(IQueryBuilder builder, QueryWriter writer) { if (!writer.Markers.MarkersByType.TryGetValue(MarkerType.QueryNode, out var nodes)) return; diff --git a/src/EdgeDB.Net.QueryBuilder/Grammar/Reducers/QueryReducer.cs b/src/EdgeDB.Net.QueryBuilder/Grammar/Reducers/QueryReducer.cs index 99bd966..0bc4172 100644 --- a/src/EdgeDB.Net.QueryBuilder/Grammar/Reducers/QueryReducer.cs +++ b/src/EdgeDB.Net.QueryBuilder/Grammar/Reducers/QueryReducer.cs @@ -4,44 +4,19 @@ namespace EdgeDB; internal static class QueryReducer { - private static readonly Dictionary _reducers; - - private static readonly Type[] ExcludedReducers = + // important: order matters here + private static readonly IReducer[] _reducers = [ - typeof(WhitespaceReducer) + new NestedSelectReducer(), + new GlobalReducer(), + new TypeCastReducer(), + new SelectShapeReducer(), + new WhitespaceReducer() ]; - static QueryReducer() - { - _reducers = typeof(QueryReducer).Assembly.GetTypes() - .Where(x => x.IsAssignableTo(typeof(IReducer)) && x.IsClass && !ExcludedReducers.Contains(x)) - .ToDictionary(x => x, x => (IReducer)Activator.CreateInstance(x)!); - } - - public static T Get() where T : IReducer - { - if (!_reducers.TryGetValue(typeof(T), out var reducer)) - throw new KeyNotFoundException($"Could not find an instance of the reducer {typeof(T).Name}"); - - if (reducer is not T asType) - throw new InvalidCastException( - $"Expected reducer {reducer?.GetType().Name ?? "null"} to be of type {typeof(T).Name}"); - - return asType; - } - - public static void Apply(IQueryBuilder builder, QueryWriter writer) { - var shouldRunAfter = new Queue(); - foreach (var (_, reducer) in _reducers) - { - reducer.Reduce(builder, writer, shouldRunAfter); - - while(shouldRunAfter.TryDequeue(out var subReducer)) - subReducer.Reduce(builder, writer, shouldRunAfter); - } - - WhitespaceReducer.Instance.Reduce(builder, writer, shouldRunAfter); + for(var i = 0; i != _reducers.Length; i++) + _reducers[i].Reduce(builder, writer); } } diff --git a/src/EdgeDB.Net.QueryBuilder/Grammar/Reducers/SelectShapeReducer.cs b/src/EdgeDB.Net.QueryBuilder/Grammar/Reducers/SelectShapeReducer.cs new file mode 100644 index 0000000..b707272 --- /dev/null +++ b/src/EdgeDB.Net.QueryBuilder/Grammar/Reducers/SelectShapeReducer.cs @@ -0,0 +1,65 @@ +using EdgeDB.QueryNodes; + +namespace EdgeDB; + +internal sealed class SelectShapeReducer : IReducer +{ + public void Reduce(IQueryBuilder builder, QueryWriter writer) + { + // return early if theres no query nodes + if (!writer.Markers.MarkersByType.TryGetValue(MarkerType.QueryNode, out var selects)) + return; + + foreach (var select in selects.Where(x => x.Metadata is QueryNodeMetadata {Node: SelectNode})) + { + var shape = writer.Markers.GetDirectChildrenOfType(select, MarkerType.Shape).FirstOrDefault(); + + if (shape is null) + continue; + + var parents = writer.Markers.GetParents(select).ToBucketedDictionary(x => x.Type, x => x); + + // shapes are non-persistent in with statements + if (parents.TryGetValue(MarkerType.GlobalDeclaration, out _)) + RemoveShape(writer, shape); + // shapes are not used in functions that don't return the provided input + else if (parents.TryGetValue(MarkerType.Function, out var functions)) + { + // if the function contains no args, return early + if (!parents.TryGetValue(MarkerType.FunctionArg, out var argMarkers)) + continue; + + foreach (var function in functions) + { + // pull the argument marker that represents our query node + var ourArgument = argMarkers.MinBy(x => x.SizeDistance(function)); + + if (ourArgument?.Metadata is not FunctionArgumentMetadata argumentMetadata || + function.Metadata is not FunctionMetadata functionMetadata) + continue; + + // get all the arguments of the function + var args = writer.Markers.GetDirectChildrenOfType(function, MarkerType.FunctionArg).ToList(); + + // resolve the method info for the function + if (!functionMetadata.TryResolveExactFunctionInfo(args, out var methodInfo)) + continue; + + // remove the shape if the return type of the function doesn't include the result of the select + if (!methodInfo.ReturnType.References(methodInfo.GetParameters()[argumentMetadata.Index] + .ParameterType)) + RemoveShape(writer, shape); + } + } + } + } + + private static void RemoveShape(QueryWriter writer, Marker marker) + { + // remove whitespace around the shape + WhitespaceReducer.TrimWhitespaceAround(writer, marker); + + marker.Remove(); + marker.Kill(); + } +} diff --git a/src/EdgeDB.Net.QueryBuilder/Grammar/Reducers/TypeCastReducer.cs b/src/EdgeDB.Net.QueryBuilder/Grammar/Reducers/TypeCastReducer.cs index 19c6fe4..090c62a 100644 --- a/src/EdgeDB.Net.QueryBuilder/Grammar/Reducers/TypeCastReducer.cs +++ b/src/EdgeDB.Net.QueryBuilder/Grammar/Reducers/TypeCastReducer.cs @@ -6,7 +6,7 @@ namespace EdgeDB; internal sealed class TypeCastReducer : IReducer { - public void Reduce(IQueryBuilder builder, QueryWriter writer, Queue shouldRunAfter) + public void Reduce(IQueryBuilder builder, QueryWriter writer) { foreach (var marker in writer.Markers) { diff --git a/src/EdgeDB.Net.QueryBuilder/Grammar/Reducers/WhitespaceReducer.cs b/src/EdgeDB.Net.QueryBuilder/Grammar/Reducers/WhitespaceReducer.cs index 17487c1..85a6871 100644 --- a/src/EdgeDB.Net.QueryBuilder/Grammar/Reducers/WhitespaceReducer.cs +++ b/src/EdgeDB.Net.QueryBuilder/Grammar/Reducers/WhitespaceReducer.cs @@ -4,7 +4,7 @@ internal sealed class WhitespaceReducer : IReducer { public static readonly WhitespaceReducer Instance = new(); - public void Reduce(IQueryBuilder builder, QueryWriter writer, Queue shouldRunAfter) + public void Reduce(IQueryBuilder builder, QueryWriter writer) { TrimStart(writer); TrimEnd(writer); @@ -12,43 +12,49 @@ public void Reduce(IQueryBuilder builder, QueryWriter writer, Queue sh private void TrimEnd(QueryWriter writer) { - var token = writer.Tokens.Last; + if (writer.Tokens.Last is null) + return; - var count = 0; - while (token is not null && IsWhitespace(token.Value)) - { - count++; - - if (token.Previous is null) - break; + Trim(writer, writer.TailIndex, writer.Tokens.Last, false); + } - token = token.Previous; - } + private void TrimStart(QueryWriter writer) + { + if (writer.Tokens.First is null) + return; - if (count > 0) - writer.Remove(writer.Tokens.Count - count, token!, count); + Trim(writer, 0, writer.Tokens.First, true); } - private void TrimStart(QueryWriter writer) + private static void Trim(QueryWriter writer, int position, LooseLinkedList.Node node, bool dir) { - var token = writer.Tokens.First; + var token = node; + var lastValidNode = node; var count = 0; while (token is not null && IsWhitespace(token.Value)) { count++; - token = token.Next; + lastValidNode = token; + token = dir ? token.Next : token.Previous; } - if (count > 0) - writer.Remove(0, writer.Tokens.First!, count); + writer.Remove(position, dir ? node : lastValidNode, count); } - private bool IsWhitespace(in Value value) + public static bool IsWhitespace(in Value value) { if (value.CharValue.HasValue) return char.IsWhiteSpace(value.CharValue.Value); return value.StringValue is not null && string.IsNullOrWhiteSpace(value.StringValue); } + + public static void TrimWhitespaceAround(QueryWriter writer, Marker marker) + { + if (marker.Slice.Head?.Previous is not null) + Trim(writer, marker.Position - 1, marker.Slice.Head.Previous, false); + if(marker.Slice.Tail?.Next is not null) + Trim(writer, marker.Position + marker.Size + 1, marker.Slice.Tail.Next, true); + } } diff --git a/src/EdgeDB.Net.QueryBuilder/Grammar/Terms.cs b/src/EdgeDB.Net.QueryBuilder/Grammar/Terms.cs index d042da1..81006f3 100644 --- a/src/EdgeDB.Net.QueryBuilder/Grammar/Terms.cs +++ b/src/EdgeDB.Net.QueryBuilder/Grammar/Terms.cs @@ -154,6 +154,7 @@ public static QueryWriter Function(this QueryWriter writer, string name, Deferra MarkerType.FunctionArg, $"func_{name}_arg_{i}", null, + metadata: new FunctionArgumentMetadata(checked((uint)i - 1), name, arg.Named), Value.Of( writer => { diff --git a/src/EdgeDB.Net.QueryBuilder/Lexical/Marker.cs b/src/EdgeDB.Net.QueryBuilder/Lexical/Marker.cs index 81160e0..5ebd3cc 100644 --- a/src/EdgeDB.Net.QueryBuilder/Lexical/Marker.cs +++ b/src/EdgeDB.Net.QueryBuilder/Lexical/Marker.cs @@ -56,6 +56,17 @@ internal Marker(string name, MarkerType type, QueryWriter writer, int size, int Metadata = metadata; } + public bool IsChildOf(Marker marker) + => marker.Position <= Position && marker.Size >= Size; + + public int SizeDistance(Marker marker) + { + var a = Position - marker.Position; + var b = marker.Size; + + return a + b; + } + internal int UpdatePosition(int delta) { if (delta != 0) diff --git a/src/EdgeDB.Net.QueryBuilder/Lexical/MarkerCollection.cs b/src/EdgeDB.Net.QueryBuilder/Lexical/MarkerCollection.cs index 5159c7a..a85b7dd 100644 --- a/src/EdgeDB.Net.QueryBuilder/Lexical/MarkerCollection.cs +++ b/src/EdgeDB.Net.QueryBuilder/Lexical/MarkerCollection.cs @@ -57,7 +57,7 @@ public void Remove(Range range) { marker.UpdateSize(-range.End.Value); } - else if (markerLower > rangeUpper) + else if (markerLower >= rangeUpper) { // move the position marker.UpdatePosition(-range.End.Value); @@ -243,7 +243,25 @@ public IEnumerable GetParents(Marker marker) => _markers.Where(x => x.Position < marker.Position && x.Position + x.Size > marker.Position + marker.Size); public IEnumerable GetChildren(Marker marker) - => _markers.Where(x => x.Position > marker.Position && x.Position + x.Size < marker.Position + marker.Size); + => _markers.Where(x => + x.Position != marker.Position && x.Size != marker.Size && + x.Position >= marker.Position && x.Position + x.Size <= marker.Position + marker.Size + ); + + public IEnumerable GetChildrenOfType(Marker marker, MarkerType type) + => MarkersByType.TryGetValue(type, out var candidates) + ? candidates.Where(x => + x.Position != marker.Position && x.Size != marker.Size && x.Position >= marker.Position && + x.Position + x.Size <= marker.Position + marker.Size) + : Array.Empty(); + + public IEnumerable GetDirectChildrenOfType(Marker marker, MarkerType type) + { + var children = GetChildrenOfType(marker, type).ToList(); + + return children.Where(child => !children.Any(x => x != child && x.Range.Contains(child.Position))); + } + public bool TryGetNextNeighbours(Marker marker, [MaybeNullWhen(false)] out LinkedList neighbours) => _markersByPosition.TryGetValue(marker.Position + marker.Size, out neighbours); diff --git a/src/EdgeDB.Net.QueryBuilder/Lexical/MarkerMetadata/FunctionArgumentMetadata.cs b/src/EdgeDB.Net.QueryBuilder/Lexical/MarkerMetadata/FunctionArgumentMetadata.cs new file mode 100644 index 0000000..2597b9d --- /dev/null +++ b/src/EdgeDB.Net.QueryBuilder/Lexical/MarkerMetadata/FunctionArgumentMetadata.cs @@ -0,0 +1,3 @@ +namespace EdgeDB; + +internal sealed record FunctionArgumentMetadata(uint Index, string FunctionName, string? NamedParameter) : IMarkerMetadata; diff --git a/src/EdgeDB.Net.QueryBuilder/Lexical/MarkerMetadata/FunctionMetadata.cs b/src/EdgeDB.Net.QueryBuilder/Lexical/MarkerMetadata/FunctionMetadata.cs index 16c92d7..c7d3607 100644 --- a/src/EdgeDB.Net.QueryBuilder/Lexical/MarkerMetadata/FunctionMetadata.cs +++ b/src/EdgeDB.Net.QueryBuilder/Lexical/MarkerMetadata/FunctionMetadata.cs @@ -1,5 +1,58 @@ -using System.Reflection; +using System.Diagnostics.CodeAnalysis; +using System.Reflection; namespace EdgeDB; -public sealed record FunctionMetadata(string FunctionName, MethodInfo? Function = null) : IMarkerMetadata; +internal sealed record FunctionMetadata(string FunctionName, MethodInfo? Function = null) : IMarkerMetadata +{ + public bool TryResolveExactFunctionInfo(List arguments, [MaybeNullWhen(false)] out MethodInfo methodInfo) + => (methodInfo = null) is null && + TryResolveFunctionInfos(out var infos) && + TryResolveExactFunctionInfo(infos, arguments, out methodInfo); + + public bool TryResolveExactFunctionInfo(List potentials, List arguments, [MaybeNullWhen(false)] out MethodInfo methodInfo) + { + if (potentials.Count == 1) + { + methodInfo = potentials[0]; + return true; + } + + foreach (var potential in potentials) + { + var parameters = potential.GetParameters(); + var optionalParamsCount = parameters.Count(x => x.IsOptional); + var shouldBeIn = (parameters.Length - optionalParamsCount)..parameters.Length; + + if(!shouldBeIn.Contains(arguments.Count)) + continue; + + methodInfo = potential; + return true; + } + + methodInfo = null; + return false; + } + + public bool TryResolveFunctionInfos([MaybeNullWhen(false)] out List infos) + { + if (Function is not null) + { + infos = [Function]; + return true; + } + + if (FunctionName.Contains("::")) + { + var functionNameModule = FunctionName.Split("::"); + var functionName = functionNameModule[^1]; + var functionModule = string.Join("::", functionNameModule[..^1]); + + return EdgeQL.TryGetMethods(functionName, functionModule, out infos); + } + + infos = EdgeQL.SearchMethods(FunctionName); + return infos.Count > 0; + } +}