diff --git a/Source/EntityFramework.Extended/Batch/OracleBatchRunner.cs b/Source/EntityFramework.Extended/Batch/OracleBatchRunner.cs new file mode 100644 index 0000000..d777f7e --- /dev/null +++ b/Source/EntityFramework.Extended/Batch/OracleBatchRunner.cs @@ -0,0 +1,470 @@ +using System; +using System.Collections.Generic; +using System.Data; +using System.Data.Common; +using System.Data.Entity.Core.EntityClient; +using System.Data.Entity.Core.Objects; +using System.Globalization; +using System.Linq; +using System.Linq.Dynamic; +using System.Linq.Expressions; +using System.Text; +using System.Text.RegularExpressions; +using System.Threading.Tasks; +using EntityFramework.Batch; +using EntityFramework.Extensions; +using EntityFramework.Mapping; +using EntityFramework.Reflection; + + +namespace EntityFramework.Batch +{ + /// + /// A batch execution runner for Oracle. + /// + public class OracleBatchRunner : IBatchRunner + { + /// + /// Create and run a batch delete statement. + /// + /// The type of the entity. + /// The to get connection and metadata information from. + /// The for . + /// The query to create the where clause from. + /// + /// The number of rows deleted. + /// + public int Delete(ObjectContext objectContext, EntityMap entityMap, ObjectQuery query) where TEntity : class + { +#if NET45 + return InternalDelete(objectContext, entityMap, query).Result; +#else + return InternalDelete(objectContext, entityMap, query); +#endif + } + +#if NET45 + /// + /// Create and run a batch delete statement asynchronously. + /// + /// The type of the entity. + /// The to get connection and metadata information from. + /// The for . + /// The query to create the where clause from. + /// + /// The number of rows deleted. + /// + public Task DeleteAsync(ObjectContext objectContext, EntityMap entityMap, ObjectQuery query) where TEntity : class + { + return InternalDelete(objectContext, entityMap, query, true); + } +#endif + +#if NET45 + private async Task InternalDelete(ObjectContext objectContext, EntityMap entityMap, ObjectQuery query, bool async = false) + where TEntity : class +#else + private int InternalDelete(ObjectContext objectContext, EntityMap entityMap, ObjectQuery query) + where TEntity : class +#endif + { + DbConnection deleteConnection = null; + DbTransaction deleteTransaction = null; + bool ownConnection = false; + bool ownTransaction = false; + + try + { + InitializeConnectionAndTransaction(objectContext, ref deleteConnection, ref deleteTransaction, ref ownConnection, ref ownTransaction); + + using (var deleteCommand = CreateCommand(objectContext, deleteConnection, deleteTransaction)) + { + deleteCommand.Transaction = deleteTransaction; + if (objectContext.CommandTimeout.HasValue) + { + deleteCommand.CommandTimeout = objectContext.CommandTimeout.Value; + } + + var innerSelect = GetSelectSql(query, entityMap, deleteCommand); + var sqlBuilder = new StringBuilder(); + sqlBuilder.Append("DELETE "); + sqlBuilder.AppendLine(entityMap.TableName.Replace('[', '\"').Replace(']', '\"')); + sqlBuilder.AppendLine("WHERE ROWID IN"); + sqlBuilder.AppendLine("("); + sqlBuilder.AppendLine("SELECT \"Extent1\".ROWID"); + sqlBuilder.AppendLine(innerSelect.Substring(innerSelect.IndexOf("FROM"))); + sqlBuilder.AppendLine(")"); + + deleteCommand.CommandText = sqlBuilder.ToString(); + +#if NET45 + int result = async + ? await deleteCommand.ExecuteNonQueryAsync().ConfigureAwait(false) + : deleteCommand.ExecuteNonQuery(); +#else + int result = deleteCommand.ExecuteNonQuery(); +#endif + + return result; + } + } + finally + { + ReleaseConnectionAndTransaction(deleteConnection, deleteTransaction, ownConnection, ownTransaction); + } + } + + /// + /// Create and run a batch update statement. + /// + /// The type of the entity. + /// The to get connection and metadata information from. + /// The for . + /// The query to create the where clause from. + /// The update expression. + /// + /// The number of rows updated. + /// + public int Update(ObjectContext objectContext, EntityMap entityMap, ObjectQuery query, Expression> updateExpression) where TEntity : class + { +#if NET45 + return InternalUpdate(objectContext, entityMap, query, updateExpression, false).Result; +#else + return InternalUpdate(objectContext, entityMap, query, updateExpression); +#endif + } + +#if NET45 + /// + /// Create and run a batch update statement asynchronously. + /// + /// The type of the entity. + /// The to get connection and metadata information from. + /// The for . + /// The query to create the where clause from. + /// The update expression. + /// + /// The number of rows updated. + /// + public Task UpdateAsync(ObjectContext objectContext, EntityMap entityMap, ObjectQuery query, Expression> updateExpression) where TEntity : class + { + return InternalUpdate(objectContext, entityMap, query, updateExpression, true); + } +#endif +#if NET45 + private async Task InternalUpdate(ObjectContext objectContext, EntityMap entityMap, ObjectQuery query, Expression> updateExpression, bool async = false) + where TEntity : class +#else + private int InternalUpdate(ObjectContext objectContext, EntityMap entityMap, ObjectQuery query, Expression> updateExpression, bool async = false) + where TEntity : class +#endif + { + DbConnection updateConnection = null; + DbTransaction updateTransaction = null; + bool ownConnection = false; + bool ownTransaction = false; + + try + { + InitializeConnectionAndTransaction(objectContext, ref updateConnection, ref updateTransaction, ref ownConnection, ref ownTransaction); + + using (var updateCommand = CreateCommand(objectContext, updateConnection, updateTransaction)) + { + var memberInitExpression = updateExpression.Body as MemberInitExpression; + if (memberInitExpression == null) + { + throw new ArgumentException("The update expression must be of type MemberInitExpression.", "updateExpression"); + } + + var innerSelect = GetSelectSql(query, entityMap, updateCommand); + var sqlBuilder = BuildUpdateSql(objectContext, entityMap, updateCommand, innerSelect, memberInitExpression); + updateCommand.CommandText = sqlBuilder.ToString(); + +#if NET45 + int result = async + ? await updateCommand.ExecuteNonQueryAsync().ConfigureAwait(false) + : updateCommand.ExecuteNonQuery(); +#else + int result = updateCommand.ExecuteNonQuery(); +#endif + + return result; + } + } + finally + { + ReleaseConnectionAndTransaction(updateConnection, updateTransaction, ownConnection, ownTransaction); + } + } + + #region Connection & Transaction Management + + private static Tuple GetStore(ObjectContext objectContext) + { + var dbConnection = objectContext.Connection; + var entityConnection = dbConnection as EntityConnection; + + // by-pass entity connection + if (entityConnection == null) + { + return new Tuple(dbConnection, null); + } + + // get internal transaction + var connection = entityConnection.StoreConnection; + dynamic connectionProxy = new DynamicProxy(entityConnection); + dynamic entityTransaction = connectionProxy.CurrentTransaction; + if (entityTransaction == null) + { + return new Tuple(connection, null); + } + + var transaction = entityTransaction.StoreTransaction; + return new Tuple(connection, transaction); + } + + private static void InitializeConnectionAndTransaction(ObjectContext objectContext, ref DbConnection connection, ref DbTransaction transaction, ref bool ownConnection, ref bool ownTransaction) + { + // get store connection and transaction + var store = GetStore(objectContext); + connection = store.Item1; + transaction = store.Item2; + + if (connection.State != ConnectionState.Open) + { + connection.Open(); + ownConnection = true; + } + + // use existing transaction or create new + if (transaction == null) + { + transaction = connection.BeginTransaction(); + ownTransaction = true; + } + } + + private static DbCommand CreateCommand(ObjectContext objectContext, DbConnection connection, DbTransaction transaction) + { + var command = connection.CreateCommand(); + + command.Transaction = transaction; + if (objectContext.CommandTimeout.HasValue) + { + command.CommandTimeout = objectContext.CommandTimeout.Value; + } + + return command; + } + + + private static void ReleaseConnectionAndTransaction(DbConnection connection, DbTransaction transaction, bool ownConnection, bool ownTransaction) + { + if (transaction != null && ownTransaction) + { + transaction.Dispose(); + } + + if (connection != null && ownConnection) + { + connection.Close(); + } + } + + #endregion + + #region Update Helpers + + private static StringBuilder BuildUpdateSql(ObjectContext objectContext, EntityMap entityMap, DbCommand updateCommand, string innerSelect, MemberInitExpression memberInitExpression) where TEntity : class + { + int nameCount = 0; + bool wroteSet = false; + var fieldsToUpdate = new StringBuilder(); + var valuesToUpdate = new StringBuilder(); + foreach (MemberBinding binding in memberInitExpression.Bindings) + { + var memberAssignment = binding as MemberAssignment; + if (memberAssignment == null) + { + throw new ArgumentException("The update expression MemberBinding must only by type MemberAssignment.", "updateExpression"); + } + + if (wroteSet) + { + fieldsToUpdate.Append(", "); + valuesToUpdate.Append(", "); + } + + string propertyName = binding.Member.Name; + string columnName = entityMap.PropertyMaps.Where(p => p.PropertyName == propertyName) + .Select(p => p.ColumnName) + .FirstOrDefault(); + + var memberExpression = memberAssignment.Expression; + ParameterExpression parameterExpression = null; + memberExpression.Visit((ParameterExpression p) => + { + if (p.Type == entityMap.EntityType) + { + parameterExpression = p; + } + + return p; + }); + + if (parameterExpression == null) + { + nameCount = BuildUpdateParameterWithExpression(updateCommand, fieldsToUpdate, valuesToUpdate, nameCount, columnName, memberExpression); + } + else + { + nameCount = BuildUpdateParameterWithoutExpression(objectContext, entityMap, updateCommand, fieldsToUpdate, valuesToUpdate, nameCount, columnName, memberExpression, parameterExpression); + } + + wroteSet = true; + } + + var sqlBuilder = new StringBuilder(); + sqlBuilder.Append("UPDATE "); + sqlBuilder.AppendLine(entityMap.TableName.Replace('[', '\"').Replace(']', '\"')); + sqlBuilder.Append("SET ("); + sqlBuilder.Append(fieldsToUpdate); + sqlBuilder.AppendLine(") = ("); + sqlBuilder.Append("SELECT "); + sqlBuilder.Append(valuesToUpdate); + sqlBuilder.AppendLine(); + sqlBuilder.AppendLine(innerSelect.Substring(innerSelect.IndexOf("FROM"))); + sqlBuilder.AppendLine(")"); + return sqlBuilder; + } + + private static int BuildUpdateParameterWithExpression(DbCommand updateCommand, StringBuilder fieldsToUpdate, StringBuilder valuesToUpdate, int nameCount, string columnName, Expression memberExpression) + { + object value; + + if (memberExpression.NodeType == ExpressionType.Constant) + { + var constantExpression = memberExpression as ConstantExpression; + if (constantExpression == null) + { + throw new ArgumentException("The MemberAssignment expression is not a ConstantExpression.", "updateExpression"); + } + + value = constantExpression.Value; + } + else + { + LambdaExpression lambda = Expression.Lambda(memberExpression, null); + value = lambda.Compile().DynamicInvoke(); + } + + if (value != null) + { + string parameterName = "p__update__" + nameCount++; + var parameter = updateCommand.CreateParameter(); + parameter.ParameterName = parameterName; + parameter.Value = value; + updateCommand.Parameters.Add(parameter); + + fieldsToUpdate.AppendFormat("\"{0}\"", columnName); + valuesToUpdate.AppendFormat(":{0}", parameterName); + } + else + { + fieldsToUpdate.AppendFormat("\"{0}\"", columnName); + valuesToUpdate.Append("NULL"); + } + + return nameCount; + } + + private static int BuildUpdateParameterWithoutExpression(ObjectContext objectContext, EntityMap entityMap, DbCommand updateCommand, StringBuilder fieldsToUpdate, StringBuilder valuesToUpdate, int nameCount, string columnName, Expression memberExpression, ParameterExpression parameterExpression) where TEntity : class + { + // create clean objectset to build query from + var objectSet = objectContext.CreateObjectSet(); + var typeArguments = new[] { entityMap.EntityType, memberExpression.Type }; + var constantExpression = Expression.Constant(objectSet); + var lambdaExpression = Expression.Lambda(memberExpression, parameterExpression); + var selectExpression = Expression.Call(typeof(Queryable), "Select", typeArguments, constantExpression, lambdaExpression); + + // create query from expression + var selectQuery = objectSet.CreateQuery(selectExpression, entityMap.EntityType); + var sql = selectQuery.ToTraceString(); + + // parse select part of sql to use as update + var regex = @"SELECT\s*\r\n\s*(?.+)?\s*AS\s*(?\[\w+\])\r\n\s*FROM\s*(?\[\w+\]\.\[\w+\]|\[\w+\])\s*AS\s*(?\[\w+\])"; + var match = Regex.Match(sql, regex); + if (!match.Success) + { + throw new ArgumentException("The MemberAssignment expression could not be processed.", "updateExpression"); + } + + var alias = match.Groups["TableAlias"].Value; + var value = match.Groups["ColumnValue"].Value.Replace(alias + ".", ""); + + foreach (ObjectParameter objectParameter in selectQuery.Parameters) + { + var parameterName = "p__update__" + nameCount++; + var parameter = updateCommand.CreateParameter(); + + parameter.ParameterName = parameterName; + parameter.Value = objectParameter.Value ?? DBNull.Value; + updateCommand.Parameters.Add(parameter); + + value = value.Replace(objectParameter.Name, parameterName); + } + + fieldsToUpdate.AppendFormat("\"{0}\"", columnName); + valuesToUpdate.AppendFormat("\"{0}\"", value); + + return nameCount; + } + + #endregion + + #region Select Builder + + private static string GetSelectSql(ObjectQuery query, EntityMap entityMap, DbCommand command) + where TEntity : class + { + // changing query to only select keys + var selector = new StringBuilder(50); + selector.Append("new("); + + foreach (var propertyMap in entityMap.KeyMaps) + { + if (selector.Length > 4) + { + selector.Append((", ")); + } + + selector.Append(propertyMap.PropertyName); + } + selector.Append(")"); + + var selectQuery = DynamicQueryable.Select(query, selector.ToString()); + var objectQuery = selectQuery as ObjectQuery; + + if (objectQuery == null) + { + throw new ArgumentException("The query must be of type ObjectQuery.", "query"); + } + + var innerJoinSql = objectQuery.ToTraceString(); + + // create parameters + foreach (var objectParameter in objectQuery.Parameters) + { + var parameter = command.CreateParameter(); + parameter.ParameterName = objectParameter.Name; + parameter.Value = objectParameter.Value ?? DBNull.Value; + + command.Parameters.Add(parameter); + } + + return innerJoinSql; + } + + #endregion + } +} diff --git a/Source/EntityFramework.Extended/EntityFramework.Extended.net40.csproj b/Source/EntityFramework.Extended/EntityFramework.Extended.net40.csproj index f9857cf..54478f9 100644 --- a/Source/EntityFramework.Extended/EntityFramework.Extended.net40.csproj +++ b/Source/EntityFramework.Extended/EntityFramework.Extended.net40.csproj @@ -79,6 +79,7 @@ + @@ -161,4 +162,4 @@ --> - + \ No newline at end of file diff --git a/Source/EntityFramework.Extended/EntityFramework.Extended.net45.csproj b/Source/EntityFramework.Extended/EntityFramework.Extended.net45.csproj index ddf8af6..987f8a0 100644 --- a/Source/EntityFramework.Extended/EntityFramework.Extended.net45.csproj +++ b/Source/EntityFramework.Extended/EntityFramework.Extended.net45.csproj @@ -79,6 +79,7 @@ +