diff --git a/spark/src/main/scala/io/delta/sharing/spark/RemoteDeltaFileIndex.scala b/spark/src/main/scala/io/delta/sharing/spark/RemoteDeltaFileIndex.scala
index befdc5d90..00bde16a2 100644
--- a/spark/src/main/scala/io/delta/sharing/spark/RemoteDeltaFileIndex.scala
+++ b/spark/src/main/scala/io/delta/sharing/spark/RemoteDeltaFileIndex.scala
@@ -16,16 +16,12 @@

 package io.delta.sharing.spark

-import java.lang.ref.WeakReference
-
 import org.apache.hadoop.fs.{FileStatus, Path}
-import org.apache.spark.delta.sharing.CachedTableManager
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{Column, SparkSession}
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
-import org.apache.spark.sql.catalyst.expressions.{And, Attribute, Cast, Expression, GenericInternalRow, Literal, SubqueryExpression}
-import org.apache.spark.sql.execution.datasources.{FileFormat, FileIndex, HadoopFsRelation, PartitionDirectory}
+import org.apache.spark.sql.catalyst.expressions.{And, Cast, Expression, GenericInternalRow, Literal}
+import org.apache.spark.sql.execution.datasources.{FileIndex, PartitionDirectory}
 import org.apache.spark.sql.types.{DataType, StructType}

 import io.delta.sharing.spark.filters.{BaseOp, OpConverter}
@@ -70,7 +66,7 @@ private[sharing] abstract class RemoteDeltaFileIndexBase(
     actions.groupBy(_.getPartitionValuesInDF()).map {
       case (partitionValues, files) =>
         val rowValues: Array[Any] = partitionSchema.map { p =>
-          Cast(Literal(partitionValues(p.name)), p.dataType, Option(timeZone)).eval()
+          new Cast(Literal(partitionValues(p.name)), p.dataType, Option(timeZone)).eval()
         }.toArray

         val fileStats = files.map { f =>
@@ -101,7 +97,8 @@ private[sharing] abstract class RemoteDeltaFileIndexBase(
     val rewrittenFilters = DeltaTableUtils.rewritePartitionFilters(
       params.snapshotAtAnalysis.partitionSchema,
       params.spark.sessionState.conf.resolver,
-      partitionFilters)
+      partitionFilters,
+      params.spark.sessionState.conf.sessionLocalTimeZone)
     new Column(rewrittenFilters.reduceLeftOption(And).getOrElse(Literal(true)))
   }

diff --git a/spark/src/main/scala/io/delta/sharing/spark/RemoteDeltaLog.scala b/spark/src/main/scala/io/delta/sharing/spark/RemoteDeltaLog.scala
index 7f01e7906..2c99aea9e 100644
--- a/spark/src/main/scala/io/delta/sharing/spark/RemoteDeltaLog.scala
+++ b/spark/src/main/scala/io/delta/sharing/spark/RemoteDeltaLog.scala
@@ -17,20 +17,17 @@
 package io.delta.sharing.spark

 import java.lang.ref.WeakReference
-import java.util.concurrent.TimeUnit

 import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.fs.{FileStatus, Path}
+import org.apache.hadoop.fs.Path
 import org.apache.spark.SparkException
 import org.apache.spark.delta.sharing.CachedTableManager
 import org.apache.spark.internal.Logging
-import org.apache.spark.network.util.JavaUtils
 import org.apache.spark.sql.{Column, Encoder, SparkSession}
-import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.{Resolver, UnresolvedAttribute}
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
-import org.apache.spark.sql.catalyst.expressions.{And, Attribute, Cast, Expression, GenericInternalRow, Literal, SubqueryExpression}
-import org.apache.spark.sql.execution.datasources.{FileFormat, FileIndex, HadoopFsRelation, PartitionDirectory}
+import org.apache.spark.sql.catalyst.expressions.{And, Attribute, Cast, Expression, Literal, SubqueryExpression}
+import org.apache.spark.sql.execution.datasources.{FileFormat, HadoopFsRelation}
 import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
 import org.apache.spark.sql.sources.BaseRelation
 import org.apache.spark.sql.types.{DataType, StructField, StructType}
@@ -38,8 +35,6 @@ import org.apache.spark.sql.types.{DataType, StructField, StructType}
 import io.delta.sharing.spark.model.{
   AddFile,
   CDFColumnInfo,
-  DeltaTableFiles,
-  FileAction,
   Metadata,
   Protocol,
   Table => DeltaSharingTable
@@ -259,7 +254,8 @@ class RemoteSnapshot(
       val rewrittenFilters = DeltaTableUtils.rewritePartitionFilters(
         partitionSchema,
         spark.sessionState.conf.resolver,
-        partitionFilters)
+        partitionFilters,
+        spark.sessionState.conf.sessionLocalTimeZone)
       val predicates = rewrittenFilters.map(_.sql)

       if (predicates.nonEmpty) {
@@ -420,7 +416,8 @@ private[sharing] object DeltaTableUtils {
   def rewritePartitionFilters(
       partitionSchema: StructType,
       resolver: Resolver,
-      partitionFilters: Seq[Expression]): Seq[Expression] = {
+      partitionFilters: Seq[Expression],
+      timeZone: String): Seq[Expression] = {
     partitionFilters.map(_.transformUp {
       case a: Attribute =>
         // If we have a special column name, e.g. `a.a`, then an UnresolvedAttribute returns
@@ -429,9 +426,7 @@ private[sharing] object DeltaTableUtils {
         val partitionCol = partitionSchema.find { field => resolver(field.name, unquoted) }
         partitionCol match {
           case Some(StructField(name, dataType, _, _)) =>
-            Cast(
-              UnresolvedAttribute(Seq("partitionValues", name)),
-              dataType)
+            new Cast(UnresolvedAttribute(Seq("partitionValues", name)), dataType, Option(timeZone))
           case None =>
             // This should not be able to happen, but the case was present in the original code so
             // we kept it to be safe.
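For context, a minimal standalone sketch (not part of the patch) of why the session-local time zone is threaded into Cast here: the same partition-value string resolves to different timestamp values depending on which zone the cast uses. The class names are from Spark 3.x catalyst; the sample value, zones, and object name are made up for illustration.

// Standalone sketch, assuming a Spark 3.x catalyst dependency on the classpath.
import org.apache.spark.sql.catalyst.expressions.{Cast, Literal}
import org.apache.spark.sql.types.TimestampType

object PartitionCastSketch {
  def main(args: Array[String]): Unit = {
    // A timestamp partition value in string form, as it would arrive from the server.
    val partitionValue = "2021-06-01 00:00:00"
    // Casting the same string under two different session-local time zones yields
    // different microsecond values, which is why rewritePartitionFilters now receives
    // spark.sessionState.conf.sessionLocalTimeZone and passes it to Cast.
    val inUtc = new Cast(Literal(partitionValue), TimestampType, Option("UTC")).eval()
    val inLa = new Cast(Literal(partitionValue), TimestampType, Option("America/Los_Angeles")).eval()
    println(s"UTC micros: $inUtc, America/Los_Angeles micros: $inLa")
  }
}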