diff --git a/integration_tests/src/main/python/cast_test.py b/integration_tests/src/main/python/cast_test.py index f7784178182..044f1d46322 100644 --- a/integration_tests/src/main/python/cast_test.py +++ b/integration_tests/src/main/python/cast_test.py @@ -726,8 +726,6 @@ def test_cast_int_to_string_not_UTC(): {"spark.sql.session.timeZone": "+08"}) not_utc_fallback_test_params = [(timestamp_gen, 'STRING'), - # python does not like year 0, and with time zones the default start date can become year 0 :( - (DateGen(start=date(1, 1, 1)), 'TIMESTAMP'), (SetValuesGen(StringType(), ['2023-03-20 10:38:50', '2023-03-20 10:39:02']), 'TIMESTAMP')] @allow_non_gpu('ProjectExec') diff --git a/integration_tests/src/main/python/date_time_test.py b/integration_tests/src/main/python/date_time_test.py index 558ce728d9b..5143c2b0bda 100644 --- a/integration_tests/src/main/python/date_time_test.py +++ b/integration_tests/src/main/python/date_time_test.py @@ -684,3 +684,15 @@ def test_timestamp_millis_long_overflow(): def test_timestamp_micros(data_gen): assert_gpu_and_cpu_are_equal_collect( lambda spark : unary_op_df(spark, data_gen).selectExpr("timestamp_micros(a)")) + + +@pytest.mark.skipif(not is_supported_time_zone(), reason="not all time zones are supported now, refer to https://github.com/NVIDIA/spark-rapids/issues/6839, please update after all time zones are supported") +@pytest.mark.parametrize('parser_policy', ['LEGACY', 'CORRECTED', 'EXCEPTION'], ids=idfn) +def test_date_to_timestamp(parser_policy): + assert_gpu_and_cpu_are_equal_sql( + lambda spark : unary_op_df(spark, date_gen), + "tab", + "SELECT cast(a as timestamp) from tab", + conf = { + "spark.sql.legacy.timeParserPolicy": parser_policy, + "spark.rapids.sql.incompatibleDateFormats.enabled": True}) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala index 8ae3450c0af..f101b8a33eb 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala @@ -90,6 +90,7 @@ abstract class CastExprMetaBase[INPUT <: UnaryExpression with TimeZoneAwareExpre override def isTimeZoneSupported: Boolean = { (fromType, toType) match { case (TimestampType, DateType) => true // this is for to_date(...) + case (DateType, TimestampType) => true case _ => false } } @@ -631,6 +632,11 @@ object GpuCast { zoneId.normalized())) { shifted => shifted.castTo(GpuColumnVector.getNonNestedRapidsType(toDataType)) } + case (DateType, TimestampType) if options.timeZoneId.isDefined => + val zoneId = DateTimeUtils.getZoneId(options.timeZoneId.get) + withResource(input.castTo(GpuColumnVector.getNonNestedRapidsType(toDataType))) { cv => + GpuTimeZoneDB.fromTimestampToUtcTimestamp(cv, zoneId.normalized()) + } case _ => input.castTo(GpuColumnVector.getNonNestedRapidsType(toDataType)) }