diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/ArrowFullVectorConverterUtil.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/ArrowFullVectorConverterUtil.java index 2f2654b21..6898f42e9 100644 --- a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/ArrowFullVectorConverterUtil.java +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/ArrowFullVectorConverterUtil.java @@ -108,6 +108,12 @@ public static FieldVector convert( return new TimeMicroVectorConverter(allocator, vector).convert(); case TIMENANO: return new TimeNanoVectorConverter(allocator, vector).convert(); + case TIMESTAMPNANOTZ: + return new TimestampVectorConverter(allocator, vector, context, timeZoneToUse, false) + .convert(); + case TIMESTAMPNANO: + return new TimestampVectorConverter(allocator, vector, context, timeZoneToUse, true) + .convert(); case STRUCT: return new StructVectorConverter( allocator, vector, context, session, timeZoneToUse, idx, null) diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TimestampVectorConverter.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TimestampVectorConverter.java new file mode 100644 index 000000000..2d9f5b121 --- /dev/null +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TimestampVectorConverter.java @@ -0,0 +1,186 @@ +package net.snowflake.client.core.arrow.fullvectorconverters; + +import java.util.ArrayList; +import java.util.List; +import java.util.TimeZone; +import net.snowflake.client.core.DataConversionContext; +import net.snowflake.client.core.SFException; +import net.snowflake.client.core.SnowflakeJdbcInternalApi; +import net.snowflake.client.core.arrow.ArrowResultUtil; +import net.snowflake.client.jdbc.SnowflakeSQLException; +import net.snowflake.client.util.SFPair; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.ValueVector; +import org.apache.arrow.vector.complex.StructVector; +import org.apache.arrow.vector.types.pojo.Field; + +@SnowflakeJdbcInternalApi +public class TimestampVectorConverter implements ArrowFullVectorConverter { + private RootAllocator allocator; + private ValueVector vector; + private DataConversionContext context; + private TimeZone timeZoneToUse; + + // This parameter is used to distinguish between NTZ and LTZ (TZ is distinct in having the offset + // vector) + private boolean isNTZ; + + /** Field names of the struct vectors used by timestamp */ + private static final String FIELD_NAME_EPOCH = "epoch"; // seconds since epoch + + private static final String FIELD_NAME_TIME_ZONE_INDEX = "timezone"; // time zone index + private static final String FIELD_NAME_FRACTION = "fraction"; // fraction in nanoseconds + private static final int UTC_OFFSET = 1440; + private static final long NANOS_PER_MILLI = 1000000L; + private static final int MILLIS_PER_SECOND = 1000; + private static final int SECONDS_PER_MINUTE = 60; + + public TimestampVectorConverter( + RootAllocator allocator, + ValueVector vector, + DataConversionContext context, + TimeZone timeZoneToUse, + boolean isNTZ) { + this.allocator = allocator; + this.vector = vector; + this.context = context; + this.timeZoneToUse = timeZoneToUse; + this.isNTZ = isNTZ; + } + + private IntVector makeVectorOfZeroes(int length) { + IntVector vector = new IntVector(FIELD_NAME_FRACTION, allocator); + vector.allocateNew(length); + vector.zeroVector(); + vector.setValueCount(length); + return vector; + } + + private IntVector makeVectorOfUTCOffsets(int length) { + IntVector vector = new IntVector(FIELD_NAME_TIME_ZONE_INDEX, allocator); + vector.allocateNew(length); + vector.setValueCount(length); + for (int i = 0; i < length; i++) { + vector.set(i, UTC_OFFSET); + } + return vector; + } + + private SFPair normalizeTimeSinceEpoch(BigIntVector vector) { + int length = vector.getValueCount(); + int scale = Integer.parseInt(vector.getField().getMetadata().get("scale")); + if (scale == 0) { + IntVector fractions = makeVectorOfZeroes(length); + BigIntVector epoch = new BigIntVector(FIELD_NAME_EPOCH, allocator); + fractions + .getValidityBuffer() + .setBytes(0L, vector.getValidityBuffer(), 0L, vector.getValidityBuffer().capacity()); + vector.makeTransferPair(epoch).transfer(); + return SFPair.of(epoch, fractions); + } + long scaleFactor = ArrowResultUtil.powerOfTen(scale); + long fractionScaleFactor = ArrowResultUtil.powerOfTen(9 - scale); + BigIntVector epoch = new BigIntVector(FIELD_NAME_EPOCH, allocator); + epoch.allocateNew(length); + epoch.setValueCount(length); + IntVector fractions = new IntVector(FIELD_NAME_FRACTION, allocator); + fractions.allocateNew(length); + fractions.setValueCount(length); + for (int i = 0; i < length; i++) { + epoch.set(i, vector.get(i) / scaleFactor); + fractions.set(i, (int) ((vector.get(i) % scaleFactor) * fractionScaleFactor)); + } + return SFPair.of(epoch, fractions); + } + + private IntVector makeTimeZoneOffsets( + BigIntVector seconds, IntVector fractions, TimeZone timeZone) { + IntVector offsets = new IntVector(FIELD_NAME_TIME_ZONE_INDEX, allocator); + offsets.allocateNew(seconds.getValueCount()); + offsets.setValueCount(seconds.getValueCount()); + for (int i = 0; i < seconds.getValueCount(); i++) { + offsets.set( + i, + UTC_OFFSET + + timeZone.getOffset( + seconds.get(i) * MILLIS_PER_SECOND + fractions.get(i) / NANOS_PER_MILLI) + / (MILLIS_PER_SECOND * SECONDS_PER_MINUTE)); + } + return offsets; + } + + private StructVector pack(BigIntVector seconds, IntVector fractions, IntVector offsets) { + StructVector result = StructVector.empty(vector.getName(), allocator); + List fields = + new ArrayList() { + { + add(seconds.getField()); + add(fractions.getField()); + add(offsets.getField()); + } + }; + result.setInitialCapacity(seconds.getValueCount()); + result.initializeChildrenFromFields(fields); + seconds.makeTransferPair(result.getChild(FIELD_NAME_EPOCH)).transfer(); + fractions.makeTransferPair(result.getChild(FIELD_NAME_FRACTION)).transfer(); + offsets.makeTransferPair(result.getChild(FIELD_NAME_TIME_ZONE_INDEX)).transfer(); + result.setValueCount(seconds.getValueCount()); + result + .getValidityBuffer() + .setBytes(0L, seconds.getValidityBuffer(), 0L, seconds.getValidityBuffer().capacity()); + return result; + } + + @Override + public FieldVector convert() throws SFException, SnowflakeSQLException { + boolean returnedOriginal = false; + try { + BigIntVector seconds; + IntVector fractions; + IntVector timeZoneIndices = null; + if (vector instanceof BigIntVector) { + SFPair normalized = normalizeTimeSinceEpoch((BigIntVector) vector); + seconds = normalized.left; + fractions = normalized.right; + } else { + StructVector structVector = (StructVector) vector; + if (structVector.getChildrenFromFields().size() == 3) { + returnedOriginal = true; + return structVector; + } + if (structVector.getChild(FIELD_NAME_FRACTION) == null) { + SFPair normalized = + normalizeTimeSinceEpoch(structVector.getChild(FIELD_NAME_EPOCH, BigIntVector.class)); + seconds = normalized.left; + fractions = normalized.right; + } else { + seconds = structVector.getChild(FIELD_NAME_EPOCH, BigIntVector.class); + fractions = structVector.getChild(FIELD_NAME_FRACTION, IntVector.class); + } + timeZoneIndices = structVector.getChild(FIELD_NAME_TIME_ZONE_INDEX, IntVector.class); + } + if (timeZoneIndices == null) { + if (isNTZ && context.getHonorClientTZForTimestampNTZ()) { + timeZoneIndices = makeTimeZoneOffsets(seconds, fractions, TimeZone.getDefault()); + for (int i = 0; i < seconds.getValueCount(); i++) { + seconds.set( + i, + seconds.get(i) - (long) (timeZoneIndices.get(i) - UTC_OFFSET) * SECONDS_PER_MINUTE); + } + } else if (isNTZ || timeZoneToUse == null) { + timeZoneIndices = makeVectorOfUTCOffsets(seconds.getValueCount()); + } else { + timeZoneIndices = makeTimeZoneOffsets(seconds, fractions, timeZoneToUse); + } + } + return pack(seconds, fractions, timeZoneIndices); + } finally { + if (!returnedOriginal) { + vector.close(); + } + } + } +} diff --git a/src/test/java/net/snowflake/client/jdbc/ArrowBatchesIT.java b/src/test/java/net/snowflake/client/jdbc/ArrowBatchesIT.java index 12bcc19e0..a8f7b7793 100644 --- a/src/test/java/net/snowflake/client/jdbc/ArrowBatchesIT.java +++ b/src/test/java/net/snowflake/client/jdbc/ArrowBatchesIT.java @@ -7,6 +7,7 @@ import java.sql.ResultSet; import java.sql.SQLException; import java.sql.Statement; +import java.sql.Timestamp; import java.time.LocalDate; import java.time.LocalTime; import java.util.ArrayList; @@ -18,6 +19,7 @@ import net.snowflake.client.category.TestCategoryArrow; import net.snowflake.client.core.SFArrowResultSet; import net.snowflake.client.core.SFException; +import net.snowflake.client.core.arrow.ArrowVectorConverter; import org.apache.arrow.vector.BigIntVector; import org.apache.arrow.vector.BitVector; import org.apache.arrow.vector.DateDayVector; @@ -66,9 +68,9 @@ public static void tearDown() throws Exception { private static void assertNoMemoryLeaks(ResultSet rs) throws SQLException { assertEquals( + 0, ((SFArrowResultSet) rs.unwrap(SnowflakeResultSetV1.class).sfBaseResultSet) - .getAllocatedMemory(), - 0); + .getAllocatedMemory()); } @Test @@ -831,4 +833,65 @@ public void testFixedSizeListBatch() throws Exception { assertEquals(2, totalRows); } + + private void testTimestampCase(String query) throws Exception, SFException { + Timestamp tsFromBatch; + Timestamp tsFromRow; + + try (Statement statement = connection.createStatement()) { + try (ResultSet rs = statement.executeQuery(query)) { + ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); + + ArrowBatch batch = batches.next(); + VectorSchemaRoot root = batch.fetch().get(0); + assertTrue(root.getVector(0) instanceof StructVector); + ArrowVectorConverter converter = batch.getTimestampConverter(root.getVector(0), 1); + tsFromBatch = converter.toTimestamp(0, null); + root.close(); + assertNoMemoryLeaks(rs); + } + try (ResultSet rs = statement.executeQuery(query)) { + rs.next(); + tsFromRow = rs.getTimestamp(1); + } + } + assertTrue(tsFromBatch.equals(tsFromRow)); + } + + private void testTimestampBase(String query) throws Exception, SFException { + testTimestampCase(query); + testTimestampCase(query + "(0)"); + testTimestampCase(query + "(1)"); + } + + @Test + public void testTimestampTZBatch() throws Exception, SFException { + testTimestampBase("select '2020-04-05 12:22:12+0700'::TIMESTAMP_TZ"); + } + + @Test + public void testTimestampLTZUseSessionTimezoneBatch() throws Exception, SFException { + Statement statement = connection.createStatement(); + statement.execute("alter session set JDBC_USE_SESSION_TIMEZONE=true"); + testTimestampBase("select '2020-04-05 12:22:12'::TIMESTAMP_LTZ"); + statement.execute("alter session unset JDBC_USE_SESSION_TIMEZONE"); + } + + @Test + public void testTimestampLTZBatch() throws Exception, SFException { + testTimestampBase("select '2020-04-05 12:22:12'::TIMESTAMP_LTZ"); + } + + @Test + public void testTimestampNTZBatch() throws Exception, SFException { + testTimestampBase("select '2020-04-05 12:22:12'::TIMESTAMP_NTZ"); + } + + @Test + public void testTimestampNTZDontHonorClientTimezone() throws Exception, SFException { + Statement statement = connection.createStatement(); + statement.execute("alter session set CLIENT_HONOR_CLIENT_TZ_FOR_TIMESTAMP_NTZ=false"); + testTimestampBase("select '2020-04-05 12:22:12'::TIMESTAMP_LTZ"); + statement.execute("alter session unset CLIENT_HONOR_CLIENT_TZ_FOR_TIMESTAMP_NTZ"); + } }