Skip to content

Commit

Permalink
Initial updates to support 'Contains' operator after JSON changes in …
Browse files Browse the repository at this point in the history
…EFCore 8.
  • Loading branch information
StevenRasmussen committed May 28, 2024
1 parent bb65717 commit e59b869
Show file tree
Hide file tree
Showing 8 changed files with 118 additions and 7 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.SqlServer.NodaTime.Extensions;
using NodaTime;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Xunit;
Expand Down Expand Up @@ -193,5 +194,19 @@ public async Task LocalDate_DateDiff_Day()

Assert.Equal(6, raceResults.Count);
}

[Fact]
public async Task LocalDate_Contains()
{
var dates = new[]
{
new LocalDate(2024, 04, 22),
new LocalDate(2024, 04, 23)
};

var results = await this.Db.Race
.Where(x => dates.Contains(x.Date))
.ToArrayAsync();
}
}
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
using Microsoft.EntityFrameworkCore.Diagnostics;
using Microsoft.EntityFrameworkCore.Query;
using Microsoft.EntityFrameworkCore.Query.SqlExpressions;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;

namespace Microsoft.EntityFrameworkCore.SqlServer.Query.ExpressionTranslators
Expand All @@ -15,21 +18,27 @@ public abstract class BaseNodaTimeMethodCallTranslator : IMethodCallTranslator
private readonly Dictionary<MethodInfo, string> _methodInfoDatePartExtensionMapping;
private readonly Dictionary<MethodInfo, string> _methodInfoDateDiffMapping;
private readonly Dictionary<MethodInfo, string> _methodInfoDateDiffBigMapping;
private readonly Dictionary<MethodInfo, string> _methodInfoContainsMapping;

protected static MethodInfo ContainsMethod { get; } = typeof(Enumerable).GetMethods()
.First(x => x.Name == nameof(Enumerable.Contains) && x.GetParameters().Count() == 2);

public BaseNodaTimeMethodCallTranslator(
ISqlExpressionFactory sqlExpressionFactory,
Dictionary<MethodInfo, string> methodInfoDateAddMapping,
Dictionary<MethodInfo, string> methodInfoDateAddExtensionMapping,
Dictionary<MethodInfo, string> methodInfoDatePartExtensionMapping,
Dictionary<MethodInfo, string> methodInfoDateDiffMapping,
Dictionary<MethodInfo, string> methodInfoDateDiffBigMapping)
Dictionary<MethodInfo, string> methodInfoDateDiffBigMapping,
Dictionary<MethodInfo, string> methodInfoContainsMapping)
{
_sqlExpressionFactory = sqlExpressionFactory;
_methodInfoDateAddMapping = methodInfoDateAddMapping;
_methodInfoDateAddExtensionMapping = methodInfoDateAddExtensionMapping;
_methodInfoDatePartExtensionMapping = methodInfoDatePartExtensionMapping;
_methodInfoDateDiffMapping = methodInfoDateDiffMapping;
_methodInfoDateDiffBigMapping = methodInfoDateDiffBigMapping;
_methodInfoContainsMapping = methodInfoContainsMapping;
}

public SqlExpression Translate(SqlExpression instance, MethodInfo method, IReadOnlyList<SqlExpression> arguments, IDiagnosticsLogger<DbLoggerCategory.Query> logger)
Expand Down Expand Up @@ -120,7 +129,58 @@ public SqlExpression Translate(SqlExpression instance, MethodInfo method, IReadO
method.ReturnType,
null);
}
else if ((_methodInfoContainsMapping?.TryGetValue(method, out var containsMapping) ?? false)
&& ValidateValues(arguments[0]))
{
// Note that almost all forms of Contains are queryable (e.g. over inline/parameter collections), and translated in
// RelationalQueryableMethodTranslatingExpressionVisitor.TranslateContains.
// This enumerable Contains translation is still needed for entity Contains (#30712)
SqlExpression itemExpression = null, valuesExpression = null;

// Identify static Enumerable.Contains and instance List.Contains
if (method.IsGenericMethod
&& ValidateValues(arguments[0]))
{
(itemExpression, valuesExpression) = (RemoveObjectConvert(arguments[1]), arguments[0]);
}

if (arguments.Count == 1
&& instance != null
&& ValidateValues(instance))
{
(itemExpression, valuesExpression) = (RemoveObjectConvert(arguments[0]), instance);
}

if (itemExpression is not null && valuesExpression is not null)
{
switch (valuesExpression)
{
case SqlParameterExpression parameter:
return _sqlExpressionFactory.In(itemExpression, parameter);

case SqlConstantExpression { Value: IEnumerable values }:
var valuesExpressions = new List<SqlExpression>();

foreach (var value in values)
{
valuesExpressions.Add(_sqlExpressionFactory.Constant(value));
}

return _sqlExpressionFactory.In(itemExpression, valuesExpressions);
}
}
}

return null;
}

private static bool ValidateValues(SqlExpression values)
=> values is SqlConstantExpression or SqlParameterExpression;

private static SqlExpression RemoveObjectConvert(SqlExpression expression)
=> expression is SqlUnaryExpression { OperatorType: ExpressionType.Convert } sqlUnaryExpression
&& sqlUnaryExpression.Type == typeof(object)
? sqlUnaryExpression.Operand
: expression;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
using Microsoft.EntityFrameworkCore.SqlServer.NodaTime.Extensions;
using NodaTime;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;

namespace Microsoft.EntityFrameworkCore.SqlServer.Query.ExpressionTranslators
Expand All @@ -23,6 +24,11 @@ public class DurationMethodTranslator : BaseNodaTimeMethodCallTranslator
{ typeof(DurationExtensions).GetRuntimeMethod(nameof(DurationExtensions.Microseconds), new[] { typeof(Duration) }), "microsecond" },
};

private static readonly Dictionary<MethodInfo, string> _methodInfoContainsMapping = new Dictionary<MethodInfo, string>
{
{ BaseNodaTimeMethodCallTranslator.ContainsMethod.MakeGenericMethod(typeof(Duration)) , "contains" },
};

private static readonly Dictionary<MethodInfo, string> _methodInfoDateDiffMapping = new Dictionary<MethodInfo, string>
{
{
Expand Down Expand Up @@ -152,7 +158,7 @@ public class DurationMethodTranslator : BaseNodaTimeMethodCallTranslator
};

public DurationMethodTranslator(ISqlExpressionFactory sqlExpressionFactory)
: base(sqlExpressionFactory, null, _methodInfoDateAddExtensionMapping, _methodInfoDatePartExtensionMapping, _methodInfoDateDiffMapping, _methodInfoDateDiffBigMapping)
: base(sqlExpressionFactory, null, _methodInfoDateAddExtensionMapping, _methodInfoDatePartExtensionMapping, _methodInfoDateDiffMapping, _methodInfoDateDiffBigMapping, _methodInfoContainsMapping)
{
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
using Microsoft.EntityFrameworkCore.SqlServer.NodaTime.Extensions;
using NodaTime;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;

namespace Microsoft.EntityFrameworkCore.SqlServer.Query.ExpressionTranslators
Expand Down Expand Up @@ -38,6 +39,11 @@ internal class InstantMethodTranslator : BaseNodaTimeMethodCallTranslator
{ typeof(InstantExtensions).GetRuntimeMethod(nameof(InstantExtensions.IsoWeek), new[] { typeof(Instant), }), "iso_week" },
};

private static readonly Dictionary<MethodInfo, string> _methodInfoContainsMapping = new Dictionary<MethodInfo, string>
{
{ BaseNodaTimeMethodCallTranslator.ContainsMethod.MakeGenericMethod(typeof(Instant)) , "contains" },
};

private static readonly Dictionary<MethodInfo, string> _methodInfoDateDiffMapping = new Dictionary<MethodInfo, string>
{
{
Expand Down Expand Up @@ -215,7 +221,7 @@ internal class InstantMethodTranslator : BaseNodaTimeMethodCallTranslator
};

public InstantMethodTranslator(ISqlExpressionFactory sqlExpressionFactory)
: base(sqlExpressionFactory, null, _methodInfoDateAddExtensionMapping, _methodInfoDatePartExtensionMapping, _methodInfoDateDiffMapping, _methodInfoDateDiffBigMapping)
: base(sqlExpressionFactory, null, _methodInfoDateAddExtensionMapping, _methodInfoDatePartExtensionMapping, _methodInfoDateDiffMapping, _methodInfoDateDiffBigMapping, _methodInfoContainsMapping)
{
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
using Microsoft.EntityFrameworkCore.SqlServer.NodaTime.Extensions;
using NodaTime;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;

namespace Microsoft.EntityFrameworkCore.SqlServer.Query.ExpressionTranslators
Expand All @@ -27,6 +28,11 @@ internal class LocalDateMethodTranslator : BaseNodaTimeMethodCallTranslator
{ typeof(LocalDateExtensions).GetRuntimeMethod(nameof(LocalDateExtensions.Week), new[] { typeof(LocalDate) }), "week" },
};

private static readonly Dictionary<MethodInfo, string> _methodInfoContainsMapping = new Dictionary<MethodInfo, string>
{
{ BaseNodaTimeMethodCallTranslator.ContainsMethod.MakeGenericMethod(typeof(LocalDate)) , "contains" },
};

private static readonly Dictionary<MethodInfo, string> _methodInfoDateDiffMapping = new Dictionary<MethodInfo, string>
{
// Local Date
Expand Down Expand Up @@ -81,7 +87,7 @@ internal class LocalDateMethodTranslator : BaseNodaTimeMethodCallTranslator
};

public LocalDateMethodTranslator(ISqlExpressionFactory sqlExpressionFactory)
: base(sqlExpressionFactory, _methodInfoDateAddMapping, _methodInfoDateAddExtensionMapping, _methodInfoDatePartExtensionMapping, _methodInfoDateDiffMapping, null)
: base(sqlExpressionFactory, _methodInfoDateAddMapping, _methodInfoDateAddExtensionMapping, _methodInfoDatePartExtensionMapping, _methodInfoDateDiffMapping, null, _methodInfoContainsMapping)
{
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
using Microsoft.EntityFrameworkCore.SqlServer.NodaTime.Extensions;
using NodaTime;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;

namespace Microsoft.EntityFrameworkCore.SqlServer.Query.ExpressionTranslators
Expand Down Expand Up @@ -34,6 +35,11 @@ internal class LocalDateTimeMethodTranslator : BaseNodaTimeMethodCallTranslator
{ typeof(LocalDateTimeExtensions).GetRuntimeMethod(nameof(LocalDateTimeExtensions.Microsecond), new[] { typeof(LocalDateTime) }), "microsecond" },
};

private static readonly Dictionary<MethodInfo, string> _methodInfoContainsMapping = new Dictionary<MethodInfo, string>
{
{ BaseNodaTimeMethodCallTranslator.ContainsMethod.MakeGenericMethod(typeof(LocalDateTime)) , "contains" },
};

private static readonly Dictionary<MethodInfo, string> _methodInfoDateDiffMapping = new Dictionary<MethodInfo, string>
{
{
Expand Down Expand Up @@ -211,7 +217,7 @@ internal class LocalDateTimeMethodTranslator : BaseNodaTimeMethodCallTranslator
};

public LocalDateTimeMethodTranslator(ISqlExpressionFactory sqlExpressionFactory)
: base(sqlExpressionFactory, _methodInfoDateAddMapping, _methodInfoDateAddExtensionMapping, _methodInfoDatePartExtensionMapping, _methodInfoDateDiffMapping, _methodInfoDateDiffBigMapping)
: base(sqlExpressionFactory, _methodInfoDateAddMapping, _methodInfoDateAddExtensionMapping, _methodInfoDatePartExtensionMapping, _methodInfoDateDiffMapping, _methodInfoDateDiffBigMapping, _methodInfoContainsMapping)
{
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
using Microsoft.EntityFrameworkCore.SqlServer.NodaTime.Extensions;
using NodaTime;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;

namespace Microsoft.EntityFrameworkCore.SqlServer.Query.ExpressionTranslators
Expand All @@ -27,6 +28,11 @@ internal class LocalTimeMethodTranslator : BaseNodaTimeMethodCallTranslator
{ typeof(LocalTimeExtensions).GetRuntimeMethod(nameof(LocalTimeExtensions.Microsecond), new[] { typeof(LocalTime) }), "microsecond" },
};

private static readonly Dictionary<MethodInfo, string> _methodInfoContainsMapping = new Dictionary<MethodInfo, string>
{
{ BaseNodaTimeMethodCallTranslator.ContainsMethod.MakeGenericMethod(typeof(LocalTime)) , "contains" },
};

private static readonly Dictionary<MethodInfo, string> _methodInfoDateDiffMapping = new Dictionary<MethodInfo, string>
{
{
Expand Down Expand Up @@ -156,7 +162,7 @@ internal class LocalTimeMethodTranslator : BaseNodaTimeMethodCallTranslator
};

public LocalTimeMethodTranslator(ISqlExpressionFactory sqlExpressionFactory)
: base(sqlExpressionFactory, _methodInfoDateAddMapping, _methodInfoDateAddExtensionMapping, _methodInfoDatePartExtensionMapping, _methodInfoDateDiffMapping, _methodInfoDateDiffBigMapping)
: base(sqlExpressionFactory, _methodInfoDateAddMapping, _methodInfoDateAddExtensionMapping, _methodInfoDatePartExtensionMapping, _methodInfoDateDiffMapping, _methodInfoDateDiffBigMapping, _methodInfoContainsMapping)
{
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
using Microsoft.EntityFrameworkCore.SqlServer.NodaTime.Extensions;
using NodaTime;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;

namespace Microsoft.EntityFrameworkCore.SqlServer.Query.ExpressionTranslators
Expand Down Expand Up @@ -37,6 +38,11 @@ internal class OffsetDateTimeMethodTranslator : BaseNodaTimeMethodCallTranslator
{ typeof(OffsetDateTimeExtensions).GetRuntimeMethod(nameof(OffsetDateTimeExtensions.Microsecond), new[] { typeof(OffsetDateTime) }), "microsecond" },
};

private static readonly Dictionary<MethodInfo, string> _methodInfoContainsMapping = new Dictionary<MethodInfo, string>
{
{ BaseNodaTimeMethodCallTranslator.ContainsMethod.MakeGenericMethod(typeof(OffsetDateTime)) , "contains" },
};

private static readonly Dictionary<MethodInfo, string> _methodInfoDateDiffMapping = new Dictionary<MethodInfo, string>
{
// Offset Date Time
Expand Down Expand Up @@ -215,7 +221,7 @@ internal class OffsetDateTimeMethodTranslator : BaseNodaTimeMethodCallTranslator
};

public OffsetDateTimeMethodTranslator(ISqlExpressionFactory sqlExpressionFactory)
: base(sqlExpressionFactory, _methodInfoDateAddMapping, _methodInfoDateAddExtensionMapping, _methodInfoDatePartExtensionMapping, _methodInfoDateDiffMapping, _methodInfoDateDiffBigMapping)
: base(sqlExpressionFactory, _methodInfoDateAddMapping, _methodInfoDateAddExtensionMapping, _methodInfoDatePartExtensionMapping, _methodInfoDateDiffMapping, _methodInfoDateDiffBigMapping, _methodInfoContainsMapping)
{
}
}
Expand Down

0 comments on commit e59b869

Please sign in to comment.