diff --git a/src/main/java/net/snowflake/client/core/SFArrowResultSet.java b/src/main/java/net/snowflake/client/core/SFArrowResultSet.java
index cbb92bcde..dcea1d575 100644
--- a/src/main/java/net/snowflake/client/core/SFArrowResultSet.java
+++ b/src/main/java/net/snowflake/client/core/SFArrowResultSet.java
@@ -30,6 +30,8 @@
 import net.snowflake.client.core.arrow.VarCharConverter;
 import net.snowflake.client.core.arrow.VectorTypeConverter;
 import net.snowflake.client.core.json.Converters;
+import net.snowflake.client.jdbc.ArrowBatch;
+import net.snowflake.client.jdbc.ArrowBatches;
 import net.snowflake.client.jdbc.ArrowResultChunk;
 import net.snowflake.client.jdbc.ArrowResultChunk.ArrowChunkIterator;
 import net.snowflake.client.jdbc.ErrorCode;
@@ -112,6 +114,15 @@ public class SFArrowResultSet extends SFBaseResultSet implements DataConversionC
    */
   private boolean formatDateWithTimezone;
 
+  /** The result set should be read either only as rows or only as batches */
+  private enum ReadingMode {
+    UNSPECIFIED,
+    ROW_MODE,
+    BATCH_MODE
+  }
+
+  private ReadingMode readingMode = ReadingMode.UNSPECIFIED;
+
   @SnowflakeJdbcInternalApi protected Converters converters;
 
   /**
@@ -239,6 +250,12 @@ public SFArrowResultSet(
     }
   }
 
+  /** Returns the amount of memory currently allocated through {@code rootAllocator}. */
+  @SnowflakeJdbcInternalApi
+  public long getAllocatedMemory() {
+    return rootAllocator.getAllocatedMemory();
+  }
+
   private boolean fetchNextRow() throws SnowflakeSQLException {
     if (sortResult) {
       return fetchNextRowSorted();
@@ -247,6 +264,40 @@ private boolean fetchNextRow() throws SnowflakeSQLException {
     }
   }
 
+  /**
+   * Fetches the next chunk from the chunk downloader after reporting the CONSUMING_RESULT state
+   * transition to the event handler.
+   *
+   * @return the downloaded chunk, never {@code null}
+   * @throws SnowflakeSQLException if the downloader yields no chunk or the thread is interrupted
+   */
+  private ArrowResultChunk fetchNextChunk() throws SnowflakeSQLException {
+    try {
+      logger.debug("Fetching chunk number {}", nextChunkIndex);
+      eventHandler.triggerStateTransition(
+          BasicEvent.QueryState.CONSUMING_RESULT,
+          String.format(
+              BasicEvent.QueryState.CONSUMING_RESULT.getArgString(), queryId, nextChunkIndex));
+      ArrowResultChunk nextChunk = (ArrowResultChunk) chunkDownloader.getNextChunkToConsume();
+
+      if (nextChunk == null) {
+        throw new SnowflakeSQLLoggedException(
+            queryId,
+            session,
+            ErrorCode.INTERNAL_ERROR.getMessageCode(),
+            SqlState.INTERNAL_ERROR,
+            "Expect chunk but got null for chunk index " + nextChunkIndex);
+      }
+      logger.debug("Chunk number {} fetched successfully.", nextChunkIndex);
+      return nextChunk;
+    } catch (InterruptedException ex) {
+      // Restore the interrupt flag so code further up the stack can still observe it.
+      Thread.currentThread().interrupt();
+      throw new SnowflakeSQLLoggedException(
+          queryId, session, ErrorCode.INTERRUPTED.getMessageCode(), SqlState.QUERY_CANCELED);
+    }
+  }
+
   /**
    * Goto next row. If end of current chunk, update currentChunkIterator to the beginning of next
    * chunk, if any chunk not being consumed yet.
@@ -260,40 +311,19 @@ private boolean fetchNextRowUnsorted() throws SnowflakeSQLException {
       return true;
     } else {
       if (nextChunkIndex < chunkCount) {
-        try {
-          eventHandler.triggerStateTransition(
-              BasicEvent.QueryState.CONSUMING_RESULT,
-              String.format(
-                  BasicEvent.QueryState.CONSUMING_RESULT.getArgString(), queryId, nextChunkIndex));
-
-          ArrowResultChunk nextChunk = (ArrowResultChunk) chunkDownloader.getNextChunkToConsume();
-
-          if (nextChunk == null) {
-            throw new SnowflakeSQLLoggedException(
-                queryId,
-                session,
-                ErrorCode.INTERNAL_ERROR.getMessageCode(),
-                SqlState.INTERNAL_ERROR,
-                "Expect chunk but got null for chunk index " + nextChunkIndex);
-          }
+        ArrowResultChunk nextChunk = fetchNextChunk();
 
-          currentChunkIterator.getChunk().freeData();
-          currentChunkIterator = nextChunk.getIterator(this);
-          if (currentChunkIterator.next()) {
-
-            logger.debug(
-                "Moving to chunk index: {}, row count: {}",
-                nextChunkIndex,
-                nextChunk.getRowCount());
-
-            nextChunkIndex++;
-            return true;
-          } else {
-            return false;
-          }
-        } catch (InterruptedException ex) {
-          throw new SnowflakeSQLLoggedException(
-              queryId, session, ErrorCode.INTERRUPTED.getMessageCode(), SqlState.QUERY_CANCELED);
+        currentChunkIterator.getChunk().freeData();
+        currentChunkIterator = nextChunk.getIterator(this);
+        if (currentChunkIterator.next()) {
+
+          logger.debug(
+              "Moving to chunk index: {}, row count: {}", nextChunkIndex, nextChunk.getRowCount());
+
+          nextChunkIndex++;
+          return true;
+        } else {
+          return false;
         }
       } else {
         // always free current chunk
@@ -432,6 +462,11 @@ public boolean next() throws SFException, SnowflakeSQLException {
     if (isClosed()) {
       return false;
     }
+    if (readingMode == ReadingMode.BATCH_MODE) {
+      logger.warn("Cannot read rows after getArrowBatches() was called.");
+      return false;
+    }
+    readingMode = ReadingMode.ROW_MODE;
 
     // otherwise try to fetch again
     if (fetchNextRow()) {
@@ -764,6 +799,54 @@ public BigDecimal getBigDecimal(int columnIndex, int scale) throws SFException {
     return bigDec == null ? null : bigDec.setScale(scale, RoundingMode.HALF_UP);
   }
 
+  /**
+   * Switches this result set to batch mode and exposes its chunks as Arrow batches.
+   *
+   * @return the batch iterator, or {@code null} once rows were read via {@link #next()}
+   */
+  @Override
+  public ArrowBatches getArrowBatches() {
+    if (readingMode == ReadingMode.ROW_MODE) {
+      logger.warn("Cannot read arrow batches after next() was called.");
+      return null;
+    }
+    readingMode = ReadingMode.BATCH_MODE;
+    return new SFArrowBatchesIterator();
+  }
+
+  /** Hands out the current chunk first, then each remaining chunk, as Arrow batches. */
+  private class SFArrowBatchesIterator implements ArrowBatches {
+    private boolean firstFetched = false;
+
+    @Override
+    public long getRowCount() throws SQLException {
+      return resultSetSerializable.getRowCount();
+    }
+
+    @Override
+    public boolean hasNext() {
+      return nextChunkIndex < chunkCount || !firstFetched;
+    }
+
+    @Override
+    public ArrowBatch next() throws SQLException {
+      if (!firstFetched) {
+        firstFetched = true;
+        return currentChunkIterator
+            .getChunk()
+            .getArrowBatch(
+                SFArrowResultSet.this, useSessionTimezone ? sessionTimeZone : null, nextChunkIndex);
+      } else {
+        // Advance the index only after the fetch succeeded, as fetchNextRowUnsorted() does, so
+        // the chunk number reported by fetchNextChunk() is the chunk actually being fetched.
+        ArrowResultChunk nextChunk = fetchNextChunk();
+        nextChunkIndex++;
+        return nextChunk.getArrowBatch(
+            SFArrowResultSet.this, useSessionTimezone ? sessionTimeZone : null, nextChunkIndex);
+      }
+    }
+  }
+
   @Override
   public boolean isLast() {
     return nextChunkIndex == chunkCount && currentChunkIterator.isLast();
diff --git a/src/main/java/net/snowflake/client/core/SFBaseResultSet.java b/src/main/java/net/snowflake/client/core/SFBaseResultSet.java
index 71e56a515..f9fcf5c01 100644
--- a/src/main/java/net/snowflake/client/core/SFBaseResultSet.java
+++ b/src/main/java/net/snowflake/client/core/SFBaseResultSet.java
@@ -30,8 +30,10 @@
 import java.util.stream.Stream;
 import java.util.stream.StreamSupport;
 import net.snowflake.client.core.json.Converters;
+import net.snowflake.client.jdbc.ArrowBatches;
 import net.snowflake.client.jdbc.ErrorCode;
 import net.snowflake.client.jdbc.FieldMetadata;
+import net.snowflake.client.jdbc.SnowflakeLoggedFeatureNotSupportedException;
 import net.snowflake.client.jdbc.SnowflakeResultSetSerializable;
 import net.snowflake.client.jdbc.SnowflakeResultSetSerializableV1;
 import net.snowflake.client.jdbc.SnowflakeSQLException;
@@ -137,6 +139,11 @@ public SFBaseSession getSession() {
     return this.session;
   }
 
+  /** Returns the result as Arrow batches; this base implementation always rejects the call. */
+  public ArrowBatches getArrowBatches() throws SnowflakeLoggedFeatureNotSupportedException {
+    throw new SnowflakeLoggedFeatureNotSupportedException(session);
+  }
+
   // default implementation
   public boolean next() throws SFException, SnowflakeSQLException {
     logger.trace("boolean next()", false);
diff --git a/src/main/java/net/snowflake/client/core/arrow/ArrowVectorConverterUtil.java b/src/main/java/net/snowflake/client/core/arrow/ArrowVectorConverterUtil.java
index 68ccd2a14..a6799b223 100644
--- a/src/main/java/net/snowflake/client/core/arrow/ArrowVectorConverterUtil.java
+++ b/src/main/java/net/snowflake/client/core/arrow/ArrowVectorConverterUtil.java
@@ -22,6 +22,25 @@ public final class ArrowVectorConverterUtil {
 
   private ArrowVectorConverterUtil() {}
 
+  /**
+   * Reads the "scale" entry of the vector's field metadata.
+   *
+   * @throws SnowflakeSQLLoggedException if the entry is missing or is not an integer
+   */
+  public static int getScale(ValueVector vector, SFBaseSession session)
+      throws SnowflakeSQLLoggedException {
+    try {
+      String scaleStr = vector.getField().getMetadata().get("scale");
+      return Integer.parseInt(scaleStr);
+    } catch (NullPointerException | NumberFormatException e) {
+      throw new SnowflakeSQLLoggedException(
+          session,
+          ErrorCode.INTERNAL_ERROR.getMessageCode(),
+          SqlState.INTERNAL_ERROR,
+          "Invalid scale metadata");
+    }
+  }
+
   public static SnowflakeType getSnowflakeTypeFromFieldMetadata(Field field) {
     Map<String, String> customMeta = field.getMetadata();
     if (customMeta != null && customMeta.containsKey("logicalType")) {
@@ -111,8 +130,7 @@ public static ArrowVectorConverter initConverter(
           return new DateConverter(vector, idx, context, getFormatDateWithTimeZone);
 
         case FIXED:
-          String scaleStr = vector.getField().getMetadata().get("scale");
-          int sfScale = Integer.parseInt(scaleStr);
+          int sfScale = getScale(vector, session);
           switch (type) {
             case TINYINT:
               if (sfScale == 0) {
diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/AbstractFullVectorConverter.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/AbstractFullVectorConverter.java
new file mode 100644
index 000000000..f2cf5867e
--- /dev/null
+++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/AbstractFullVectorConverter.java
@@ -0,0 +1,24 @@
+package net.snowflake.client.core.arrow.fullvectorconverters;
+
+import net.snowflake.client.core.SFException;
+import net.snowflake.client.jdbc.SnowflakeSQLException;
+import org.apache.arrow.vector.FieldVector;
+
+/** Base converter that allows {@link #convert()} to be called only once per instance. */
+public abstract class AbstractFullVectorConverter implements ArrowFullVectorConverter {
+  private boolean converted;
+
+  protected abstract FieldVector convertVector()
+      throws SFException, SnowflakeSQLException, SFArrowException;
+
+  @Override
+  public FieldVector convert() throws SFException, SnowflakeSQLException, SFArrowException {
+    if (converted) {
+      throw new SFArrowException(
+          ArrowErrorCode.VECTOR_ALREADY_CONVERTED, "Convert has already been called");
+    } else {
+      converted = true;
+      return convertVector();
+    }
+  }
+}
diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/ArrowErrorCode.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/ArrowErrorCode.java new file mode 100644 index 000000000..b0af45dfb --- /dev/null +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/ArrowErrorCode.java @@ -0,0 +1,7 @@ +package net.snowflake.client.core.arrow.fullvectorconverters; + +public enum ArrowErrorCode { + VECTOR_ALREADY_CONVERTED, + CONVERT_FAILED, + CHUNK_FETCH_FAILED, +} diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/ArrowFullVectorConverter.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/ArrowFullVectorConverter.java new file mode 100644 index 000000000..29dc8143f --- /dev/null +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/ArrowFullVectorConverter.java @@ -0,0 +1,11 @@ +package net.snowflake.client.core.arrow.fullvectorconverters; + +import net.snowflake.client.core.SFException; +import net.snowflake.client.core.SnowflakeJdbcInternalApi; +import net.snowflake.client.jdbc.SnowflakeSQLException; +import org.apache.arrow.vector.FieldVector; + +@SnowflakeJdbcInternalApi +public interface ArrowFullVectorConverter { + FieldVector convert() throws SFException, SnowflakeSQLException, SFArrowException; +} diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/ArrowFullVectorConverterUtil.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/ArrowFullVectorConverterUtil.java new file mode 100644 index 000000000..6898f42e9 --- /dev/null +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/ArrowFullVectorConverterUtil.java @@ -0,0 +1,148 @@ +package net.snowflake.client.core.arrow.fullvectorconverters; + +import static net.snowflake.client.core.arrow.ArrowVectorConverterUtil.getScale; + +import java.util.Map; +import java.util.TimeZone; +import 
net.snowflake.client.core.DataConversionContext; +import net.snowflake.client.core.SFBaseSession; +import net.snowflake.client.core.SFException; +import net.snowflake.client.jdbc.SnowflakeSQLException; +import net.snowflake.client.jdbc.SnowflakeSQLLoggedException; +import net.snowflake.client.jdbc.SnowflakeType; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.ValueVector; +import org.apache.arrow.vector.types.Types; + +public class ArrowFullVectorConverterUtil { + private ArrowFullVectorConverterUtil() {} + + static Types.MinorType deduceType(ValueVector vector, SFBaseSession session) + throws SnowflakeSQLLoggedException { + Types.MinorType type = Types.getMinorTypeForArrowType(vector.getField().getType()); + // each column's metadata + Map customMeta = vector.getField().getMetadata(); + if (type == Types.MinorType.DECIMAL) { + // Note: Decimal vector is different from others + return Types.MinorType.DECIMAL; + } else if (!customMeta.isEmpty()) { + SnowflakeType st = SnowflakeType.valueOf(customMeta.get("logicalType")); + switch (st) { + case FIXED: + { + int sfScale = getScale(vector, session); + if (sfScale != 0) { + return Types.MinorType.DECIMAL; + } + break; + } + case VECTOR: + return Types.MinorType.FIXED_SIZE_LIST; + case TIME: + { + String scaleStr = vector.getField().getMetadata().get("scale"); + int sfScale = Integer.parseInt(scaleStr); + if (sfScale == 0) { + return Types.MinorType.TIMESEC; + } + if (sfScale <= 3) { + return Types.MinorType.TIMEMILLI; + } + if (sfScale <= 6) { + return Types.MinorType.TIMEMICRO; + } + if (sfScale <= 9) { + return Types.MinorType.TIMENANO; + } + } + case TIMESTAMP_NTZ: + return Types.MinorType.TIMESTAMPNANO; + case TIMESTAMP_LTZ: + case TIMESTAMP_TZ: + return Types.MinorType.TIMESTAMPNANOTZ; + } + } + return type; + } + + public static FieldVector convert( + RootAllocator allocator, + ValueVector vector, + DataConversionContext context, + 
SFBaseSession session, + TimeZone timeZoneToUse, + int idx, + Object targetType) + throws SFArrowException { + try { + if (targetType == null) { + targetType = deduceType(vector, session); + } + if (targetType instanceof Types.MinorType) { + switch ((Types.MinorType) targetType) { + case TINYINT: + return new TinyIntVectorConverter(allocator, vector, context, session, idx).convert(); + case SMALLINT: + return new SmallIntVectorConverter(allocator, vector, context, session, idx).convert(); + case INT: + return new IntVectorConverter(allocator, vector, context, session, idx).convert(); + case BIGINT: + return new BigIntVectorConverter(allocator, vector, context, session, idx).convert(); + case DECIMAL: + return new DecimalVectorConverter(allocator, vector, context, session, idx).convert(); + case FLOAT8: + return new FloatVectorConverter(allocator, vector, context, session, idx).convert(); + case BIT: + return new BitVectorConverter(allocator, vector, context, session, idx).convert(); + case VARBINARY: + return new BinaryVectorConverter(allocator, vector, context, session, idx).convert(); + case DATEDAY: + return new DateVectorConverter(allocator, vector, context, session, idx, timeZoneToUse) + .convert(); + case TIMESEC: + return new TimeSecVectorConverter(allocator, vector).convert(); + case TIMEMILLI: + return new TimeMilliVectorConverter(allocator, vector).convert(); + case TIMEMICRO: + return new TimeMicroVectorConverter(allocator, vector).convert(); + case TIMENANO: + return new TimeNanoVectorConverter(allocator, vector).convert(); + case TIMESTAMPNANOTZ: + return new TimestampVectorConverter(allocator, vector, context, timeZoneToUse, false) + .convert(); + case TIMESTAMPNANO: + return new TimestampVectorConverter(allocator, vector, context, timeZoneToUse, true) + .convert(); + case STRUCT: + return new StructVectorConverter( + allocator, vector, context, session, timeZoneToUse, idx, null) + .convert(); + case LIST: + return new ListVectorConverter( + 
allocator, vector, context, session, timeZoneToUse, idx, null) + .convert(); + case VARCHAR: + return new VarCharVectorConverter(allocator, vector, context, session, idx).convert(); + case MAP: + return new MapVectorConverter( + allocator, vector, context, session, timeZoneToUse, idx, null) + .convert(); + case FIXED_SIZE_LIST: + return new FixedSizeListVectorConverter( + allocator, vector, context, session, timeZoneToUse, idx, null) + .convert(); + default: + throw new SFArrowException( + ArrowErrorCode.CONVERT_FAILED, + "Unexpected arrow type " + targetType + " at index " + idx); + } + } + } catch (SnowflakeSQLException | SFException | SFArrowException e) { + throw new SFArrowException( + ArrowErrorCode.CONVERT_FAILED, "Converting vector at index " + idx + " failed", e); + } + throw new SFArrowException( + ArrowErrorCode.CONVERT_FAILED, "Converting vector at index " + idx + " failed"); + } +} diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/BigIntVectorConverter.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/BigIntVectorConverter.java new file mode 100644 index 000000000..04e90e1a4 --- /dev/null +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/BigIntVectorConverter.java @@ -0,0 +1,41 @@ +package net.snowflake.client.core.arrow.fullvectorconverters; + +import net.snowflake.client.core.DataConversionContext; +import net.snowflake.client.core.SFBaseSession; +import net.snowflake.client.core.SFException; +import net.snowflake.client.core.SnowflakeJdbcInternalApi; +import net.snowflake.client.core.arrow.ArrowVectorConverter; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.ValueVector; + +@SnowflakeJdbcInternalApi +public class BigIntVectorConverter extends SimpleArrowFullVectorConverter { + + public BigIntVectorConverter( + RootAllocator allocator, + ValueVector vector, + DataConversionContext context, + 
SFBaseSession session, + int idx) { + super(allocator, vector, context, session, idx); + } + + @Override + protected boolean matchingType() { + return (vector instanceof BigIntVector); + } + + @Override + protected BigIntVector initVector() { + BigIntVector resultVector = new BigIntVector(vector.getName(), allocator); + resultVector.allocateNew(vector.getValueCount()); + return resultVector; + } + + @Override + protected void convertValue(ArrowVectorConverter from, BigIntVector to, int idx) + throws SFException { + to.set(idx, from.toLong(idx)); + } +} diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/BinaryVectorConverter.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/BinaryVectorConverter.java new file mode 100644 index 000000000..8cee6d3f5 --- /dev/null +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/BinaryVectorConverter.java @@ -0,0 +1,40 @@ +package net.snowflake.client.core.arrow.fullvectorconverters; + +import net.snowflake.client.core.DataConversionContext; +import net.snowflake.client.core.SFBaseSession; +import net.snowflake.client.core.SFException; +import net.snowflake.client.core.SnowflakeJdbcInternalApi; +import net.snowflake.client.core.arrow.ArrowVectorConverter; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.ValueVector; +import org.apache.arrow.vector.VarBinaryVector; + +@SnowflakeJdbcInternalApi +public class BinaryVectorConverter extends SimpleArrowFullVectorConverter { + public BinaryVectorConverter( + RootAllocator allocator, + ValueVector vector, + DataConversionContext context, + SFBaseSession session, + int idx) { + super(allocator, vector, context, session, idx); + } + + @Override + protected boolean matchingType() { + return vector instanceof VarBinaryVector; + } + + @Override + protected VarBinaryVector initVector() { + VarBinaryVector resultVector = new VarBinaryVector(vector.getName(), allocator); + 
resultVector.allocateNew(vector.getValueCount()); + return resultVector; + } + + @Override + protected void convertValue(ArrowVectorConverter from, VarBinaryVector to, int idx) + throws SFException { + to.set(idx, from.toBytes(idx)); + } +} diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/BitVectorConverter.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/BitVectorConverter.java new file mode 100644 index 000000000..76701800f --- /dev/null +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/BitVectorConverter.java @@ -0,0 +1,40 @@ +package net.snowflake.client.core.arrow.fullvectorconverters; + +import net.snowflake.client.core.DataConversionContext; +import net.snowflake.client.core.SFBaseSession; +import net.snowflake.client.core.SFException; +import net.snowflake.client.core.SnowflakeJdbcInternalApi; +import net.snowflake.client.core.arrow.ArrowVectorConverter; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.BitVector; +import org.apache.arrow.vector.ValueVector; + +@SnowflakeJdbcInternalApi +public class BitVectorConverter extends SimpleArrowFullVectorConverter { + + public BitVectorConverter( + RootAllocator allocator, + ValueVector vector, + DataConversionContext context, + SFBaseSession session, + int idx) { + super(allocator, vector, context, session, idx); + } + + @Override + protected boolean matchingType() { + return vector instanceof BitVector; + } + + @Override + protected BitVector initVector() { + BitVector resultVector = new BitVector(vector.getName(), allocator); + resultVector.allocateNew(vector.getValueCount()); + return resultVector; + } + + @Override + protected void convertValue(ArrowVectorConverter from, BitVector to, int idx) throws SFException { + to.set(idx, from.toBoolean(idx) ? 
1 : 0); + } +} diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/DateVectorConverter.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/DateVectorConverter.java new file mode 100644 index 000000000..c509af685 --- /dev/null +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/DateVectorConverter.java @@ -0,0 +1,53 @@ +package net.snowflake.client.core.arrow.fullvectorconverters; + +import java.util.TimeZone; +import net.snowflake.client.core.DataConversionContext; +import net.snowflake.client.core.SFBaseSession; +import net.snowflake.client.core.SFException; +import net.snowflake.client.core.SnowflakeJdbcInternalApi; +import net.snowflake.client.core.arrow.ArrowVectorConverter; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.DateDayVector; +import org.apache.arrow.vector.ValueVector; + +@SnowflakeJdbcInternalApi +public class DateVectorConverter extends SimpleArrowFullVectorConverter { + private TimeZone timeZone; + + public DateVectorConverter( + RootAllocator allocator, + ValueVector vector, + DataConversionContext context, + SFBaseSession session, + int idx, + TimeZone timeZone) { + super(allocator, vector, context, session, idx); + this.timeZone = timeZone; + } + + @Override + protected boolean matchingType() { + return vector instanceof DateDayVector; + } + + @Override + protected DateDayVector initVector() { + DateDayVector resultVector = new DateDayVector(vector.getName(), allocator); + resultVector.allocateNew(vector.getValueCount()); + return resultVector; + } + + @Override + protected void additionalConverterInit(ArrowVectorConverter converter) { + if (timeZone != null) { + converter.setSessionTimeZone(timeZone); + converter.setUseSessionTimezone(true); + } + } + + @Override + protected void convertValue(ArrowVectorConverter from, DateDayVector to, int idx) + throws SFException { + to.set(idx, (int) (from.toDate(idx, null, false).getTime() / (1000 * 3600 
* 24))); + } +} diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/DecimalVectorConverter.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/DecimalVectorConverter.java new file mode 100644 index 000000000..d7421f858 --- /dev/null +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/DecimalVectorConverter.java @@ -0,0 +1,45 @@ +package net.snowflake.client.core.arrow.fullvectorconverters; + +import net.snowflake.client.core.DataConversionContext; +import net.snowflake.client.core.SFBaseSession; +import net.snowflake.client.core.SFException; +import net.snowflake.client.core.SnowflakeJdbcInternalApi; +import net.snowflake.client.core.arrow.ArrowVectorConverter; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.DecimalVector; +import org.apache.arrow.vector.ValueVector; + +@SnowflakeJdbcInternalApi +public class DecimalVectorConverter extends SimpleArrowFullVectorConverter { + + public DecimalVectorConverter( + RootAllocator allocator, + ValueVector vector, + DataConversionContext context, + SFBaseSession session, + int idx) { + super(allocator, vector, context, session, idx); + } + + @Override + protected boolean matchingType() { + return (vector instanceof DecimalVector); + } + + @Override + protected DecimalVector initVector() { + String scaleString = vector.getField().getMetadata().get("scale"); + String precisionString = vector.getField().getMetadata().get("precision"); + int scale = Integer.parseInt(scaleString); + int precision = Integer.parseInt(precisionString); + DecimalVector resultVector = new DecimalVector(vector.getName(), allocator, precision, scale); + resultVector.allocateNew(vector.getValueCount()); + return resultVector; + } + + @Override + protected void convertValue(ArrowVectorConverter from, DecimalVector to, int idx) + throws SFException { + to.set(idx, from.toBigDecimal(idx)); + } +} diff --git 
a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/FixedSizeListVectorConverter.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/FixedSizeListVectorConverter.java new file mode 100644 index 000000000..30d9f9f77 --- /dev/null +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/FixedSizeListVectorConverter.java @@ -0,0 +1,70 @@ +package net.snowflake.client.core.arrow.fullvectorconverters; + +import java.util.ArrayList; +import java.util.TimeZone; +import net.snowflake.client.core.DataConversionContext; +import net.snowflake.client.core.SFBaseSession; +import net.snowflake.client.core.SFException; +import net.snowflake.client.core.SnowflakeJdbcInternalApi; +import net.snowflake.client.jdbc.SnowflakeSQLException; +import org.apache.arrow.memory.ArrowBuf; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.ValueVector; +import org.apache.arrow.vector.complex.FixedSizeListVector; +import org.apache.arrow.vector.types.pojo.Field; + +@SnowflakeJdbcInternalApi +public class FixedSizeListVectorConverter extends AbstractFullVectorConverter { + protected RootAllocator allocator; + protected ValueVector vector; + protected DataConversionContext context; + protected SFBaseSession session; + protected int idx; + protected Object valueTargetType; + private TimeZone timeZoneToUse; + + FixedSizeListVectorConverter( + RootAllocator allocator, + ValueVector vector, + DataConversionContext context, + SFBaseSession session, + TimeZone timeZoneToUse, + int idx, + Object valueTargetType) { + this.allocator = allocator; + this.vector = vector; + this.context = context; + this.session = session; + this.timeZoneToUse = timeZoneToUse; + this.idx = idx; + this.valueTargetType = valueTargetType; + } + + @Override + protected FieldVector convertVector() + throws SFException, SnowflakeSQLException, SFArrowException { + try { + FixedSizeListVector listVector = 
(FixedSizeListVector) vector; + FieldVector dataVector = listVector.getDataVector(); + FieldVector convertedDataVector = + ArrowFullVectorConverterUtil.convert( + allocator, dataVector, context, session, timeZoneToUse, 0, valueTargetType); + FixedSizeListVector convertedListVector = + FixedSizeListVector.empty(listVector.getName(), listVector.getListSize(), allocator); + ArrayList fields = new ArrayList<>(); + fields.add(convertedDataVector.getField()); + convertedListVector.initializeChildrenFromFields(fields); + convertedListVector.allocateNew(); + convertedListVector.setValueCount(listVector.getValueCount()); + ArrowBuf validityBuffer = listVector.getValidityBuffer(); + convertedListVector + .getValidityBuffer() + .setBytes(0L, validityBuffer, 0L, validityBuffer.capacity()); + convertedDataVector.makeTransferPair(convertedListVector.getDataVector()).transfer(); + return convertedListVector; + } finally { + vector.close(); + } + } +} diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/FloatVectorConverter.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/FloatVectorConverter.java new file mode 100644 index 000000000..e47079293 --- /dev/null +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/FloatVectorConverter.java @@ -0,0 +1,41 @@ +package net.snowflake.client.core.arrow.fullvectorconverters; + +import net.snowflake.client.core.DataConversionContext; +import net.snowflake.client.core.SFBaseSession; +import net.snowflake.client.core.SFException; +import net.snowflake.client.core.SnowflakeJdbcInternalApi; +import net.snowflake.client.core.arrow.ArrowVectorConverter; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.Float8Vector; +import org.apache.arrow.vector.ValueVector; + +@SnowflakeJdbcInternalApi +public class FloatVectorConverter extends SimpleArrowFullVectorConverter { + + public FloatVectorConverter( + RootAllocator allocator, + ValueVector vector, + 
DataConversionContext context, + SFBaseSession session, + int idx) { + super(allocator, vector, context, session, idx); + } + + @Override + protected boolean matchingType() { + return vector instanceof Float8Vector; + } + + @Override + protected Float8Vector initVector() { + Float8Vector resultVector = new Float8Vector(vector.getName(), allocator); + resultVector.allocateNew(vector.getValueCount()); + return resultVector; + } + + @Override + protected void convertValue(ArrowVectorConverter from, Float8Vector to, int idx) + throws SFException { + to.set(idx, from.toDouble(idx)); + } +} diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/IntVectorConverter.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/IntVectorConverter.java new file mode 100644 index 000000000..db199e703 --- /dev/null +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/IntVectorConverter.java @@ -0,0 +1,40 @@ +package net.snowflake.client.core.arrow.fullvectorconverters; + +import net.snowflake.client.core.DataConversionContext; +import net.snowflake.client.core.SFBaseSession; +import net.snowflake.client.core.SFException; +import net.snowflake.client.core.SnowflakeJdbcInternalApi; +import net.snowflake.client.core.arrow.ArrowVectorConverter; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.ValueVector; + +@SnowflakeJdbcInternalApi +public class IntVectorConverter extends SimpleArrowFullVectorConverter { + + public IntVectorConverter( + RootAllocator allocator, + ValueVector vector, + DataConversionContext context, + SFBaseSession session, + int idx) { + super(allocator, vector, context, session, idx); + } + + @Override + protected boolean matchingType() { + return (vector instanceof IntVector); + } + + @Override + protected IntVector initVector() { + IntVector resultVector = new IntVector(vector.getName(), allocator); + 
resultVector.allocateNew(vector.getValueCount()); + return resultVector; + } + + @Override + protected void convertValue(ArrowVectorConverter from, IntVector to, int idx) throws SFException { + to.set(idx, from.toInt(idx)); + } +} diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/ListVectorConverter.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/ListVectorConverter.java new file mode 100644 index 000000000..bf61245b4 --- /dev/null +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/ListVectorConverter.java @@ -0,0 +1,77 @@ +package net.snowflake.client.core.arrow.fullvectorconverters; + +import java.util.ArrayList; +import java.util.TimeZone; +import net.snowflake.client.core.DataConversionContext; +import net.snowflake.client.core.SFBaseSession; +import net.snowflake.client.core.SFException; +import net.snowflake.client.core.SnowflakeJdbcInternalApi; +import net.snowflake.client.jdbc.SnowflakeSQLException; +import org.apache.arrow.memory.ArrowBuf; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.ValueVector; +import org.apache.arrow.vector.complex.ListVector; +import org.apache.arrow.vector.types.pojo.Field; + +@SnowflakeJdbcInternalApi +public class ListVectorConverter extends AbstractFullVectorConverter { + protected RootAllocator allocator; + protected ValueVector vector; + protected DataConversionContext context; + protected SFBaseSession session; + protected int idx; + protected Object valueTargetType; + private TimeZone timeZoneToUse; + + ListVectorConverter( + RootAllocator allocator, + ValueVector vector, + DataConversionContext context, + SFBaseSession session, + TimeZone timeZoneToUse, + int idx, + Object valueTargetType) { + this.allocator = allocator; + this.vector = vector; + this.context = context; + this.session = session; + this.timeZoneToUse = timeZoneToUse; + this.idx = idx; + 
this.valueTargetType = valueTargetType; + } + + protected ListVector initVector(String name, Field field) { + ListVector convertedListVector = ListVector.empty(name, allocator); + ArrayList fields = new ArrayList<>(); + fields.add(field); + convertedListVector.initializeChildrenFromFields(fields); + return convertedListVector; + } + + @Override + protected FieldVector convertVector() + throws SFException, SnowflakeSQLException, SFArrowException { + try { + ListVector listVector = (ListVector) vector; + FieldVector dataVector = listVector.getDataVector(); + FieldVector convertedDataVector = + ArrowFullVectorConverterUtil.convert( + allocator, dataVector, context, session, timeZoneToUse, 0, valueTargetType); + // TODO: change to convertedDataVector and make all necessary changes to make it work + ListVector convertedListVector = initVector(vector.getName(), dataVector.getField()); + convertedListVector.allocateNew(); + convertedListVector.setValueCount(listVector.getValueCount()); + convertedListVector.getOffsetBuffer().setBytes(0, listVector.getOffsetBuffer()); + ArrowBuf validityBuffer = listVector.getValidityBuffer(); + convertedListVector + .getValidityBuffer() + .setBytes(0L, validityBuffer, 0L, validityBuffer.capacity()); + convertedListVector.setLastSet(listVector.getLastSet()); + convertedDataVector.makeTransferPair(convertedListVector.getDataVector()).transfer(); + return convertedListVector; + } finally { + vector.close(); + } + } +} diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/MapVectorConverter.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/MapVectorConverter.java new file mode 100644 index 000000000..0b8ec963e --- /dev/null +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/MapVectorConverter.java @@ -0,0 +1,36 @@ +package net.snowflake.client.core.arrow.fullvectorconverters; + +import java.util.ArrayList; +import java.util.TimeZone; +import 
net.snowflake.client.core.DataConversionContext; +import net.snowflake.client.core.SFBaseSession; +import net.snowflake.client.core.SnowflakeJdbcInternalApi; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.ValueVector; +import org.apache.arrow.vector.complex.ListVector; +import org.apache.arrow.vector.complex.MapVector; +import org.apache.arrow.vector.types.pojo.Field; + +/** Full-vector converter for map vectors. It inherits the whole conversion from ListVectorConverter and only overrides how the empty result vector is created. */ @SnowflakeJdbcInternalApi +public class MapVectorConverter extends ListVectorConverter { + + /** Passes all arguments unchanged to the ListVectorConverter constructor. */ MapVectorConverter( + RootAllocator allocator, + ValueVector vector, + DataConversionContext context, + SFBaseSession session, + TimeZone timeZoneToUse, + int idx, + Object valueTargetType) { + super(allocator, vector, context, session, timeZoneToUse, idx, valueTargetType); + } + + /** Creates an empty MapVector called {@code name} on this converter's allocator (third argument of MapVector.empty is the keys-sorted flag, passed as false) and initializes its single child from {@code field}. */ @Override + protected ListVector initVector(String name, Field field) { + MapVector convertedMapVector = MapVector.empty(name, allocator, false); + ArrayList fields = new ArrayList<>(); + fields.add(field); + convertedMapVector.initializeChildrenFromFields(fields); + return convertedMapVector; + } +} diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/SFArrowException.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/SFArrowException.java new file mode 100644 index 000000000..953e4004c --- /dev/null +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/SFArrowException.java @@ -0,0 +1,23 @@ +package net.snowflake.client.core.arrow.fullvectorconverters; + +/** Checked exception that carries an ArrowErrorCode in addition to the usual message and optional cause. */ public class SFArrowException extends Exception { + private final ArrowErrorCode errorCode; + + /** Creates the exception without a cause (delegates with a null cause). */ public SFArrowException(ArrowErrorCode errorCode, String message) { + this(errorCode, message, null); + } + + /** Creates the exception with the given error code, message and cause. */ public SFArrowException(ArrowErrorCode errorCode, String message, Throwable cause) { + super(message, cause); + this.errorCode = errorCode; + } + + /** Returns the error code given at construction time. */ public ArrowErrorCode getErrorCode() { + return errorCode; + } + + @Override + public String toString() { + return 
super.toString() + (getErrorCode() != null ? ", errorCode = " + getErrorCode() : ""); + } +} diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/SimpleArrowFullVectorConverter.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/SimpleArrowFullVectorConverter.java new file mode 100644 index 000000000..8f1a8c69a --- /dev/null +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/SimpleArrowFullVectorConverter.java @@ -0,0 +1,64 @@ +package net.snowflake.client.core.arrow.fullvectorconverters; + +import net.snowflake.client.core.DataConversionContext; +import net.snowflake.client.core.SFBaseSession; +import net.snowflake.client.core.SFException; +import net.snowflake.client.core.SnowflakeJdbcInternalApi; +import net.snowflake.client.core.arrow.ArrowVectorConverter; +import net.snowflake.client.core.arrow.ArrowVectorConverterUtil; +import net.snowflake.client.jdbc.SnowflakeSQLException; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.ValueVector; + +@SnowflakeJdbcInternalApi +public abstract class SimpleArrowFullVectorConverter + extends AbstractFullVectorConverter { + + protected RootAllocator allocator; + protected ValueVector vector; + protected DataConversionContext context; + protected SFBaseSession session; + protected int idx; + + public SimpleArrowFullVectorConverter( + RootAllocator allocator, + ValueVector vector, + DataConversionContext context, + SFBaseSession session, + int idx) { + this.allocator = allocator; + this.vector = vector; + this.context = context; + this.session = session; + this.idx = idx; + } + + protected abstract boolean matchingType(); + + protected abstract T initVector(); + + protected abstract void convertValue(ArrowVectorConverter from, T to, int idx) throws SFException; + + protected void additionalConverterInit(ArrowVectorConverter converter) {} + + protected FieldVector convertVector() + 
throws SFException, SnowflakeSQLException, SFArrowException { + /* Already the target type: the source vector itself is handed back to the caller, so it must stay open. */ + if (matchingType()) { + return (FieldVector) vector; + } + T converted = null; + boolean succeeded = false; + try { + int size = vector.getValueCount(); + converted = initVector(); + ArrowVectorConverter converter = + ArrowVectorConverterUtil.initConverter(vector, context, session, idx); + additionalConverterInit(converter); + for (int i = 0; i < size; i++) { + /* Null slots are skipped, so nothing is written for them in the result. */ + if (!vector.isNull(i)) { + convertValue(converter, converted, i); + } + } + converted.setValueCount(size); + succeeded = true; + return converted; + } finally { + /* Close the source on every exit path (the List, Struct and Time converters do the same in their finally blocks). If conversion failed, also close the partially built result so its buffers are not leaked. */ + vector.close(); + if (!succeeded && converted != null) { + converted.close(); + } + } + } +} diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/SmallIntVectorConverter.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/SmallIntVectorConverter.java new file mode 100644 index 000000000..f15a027ef --- /dev/null +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/SmallIntVectorConverter.java @@ -0,0 +1,41 @@ +package net.snowflake.client.core.arrow.fullvectorconverters; + +import net.snowflake.client.core.DataConversionContext; +import net.snowflake.client.core.SFBaseSession; +import net.snowflake.client.core.SFException; +import net.snowflake.client.core.SnowflakeJdbcInternalApi; +import net.snowflake.client.core.arrow.ArrowVectorConverter; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.SmallIntVector; +import org.apache.arrow.vector.ValueVector; + +@SnowflakeJdbcInternalApi +public class SmallIntVectorConverter extends SimpleArrowFullVectorConverter { + + public SmallIntVectorConverter( + RootAllocator allocator, + ValueVector vector, + DataConversionContext context, + SFBaseSession session, + int idx) { + super(allocator, vector, context, session, idx); + } + + @Override + protected boolean matchingType() { + return (vector instanceof SmallIntVector); + } + + @Override + protected SmallIntVector initVector() { + SmallIntVector resultVector = new SmallIntVector(vector.getName(), allocator); + resultVector.allocateNew(vector.getValueCount()); 
+ return resultVector; + } + + @Override + protected void convertValue(ArrowVectorConverter from, SmallIntVector to, int idx) + throws SFException { + to.set(idx, from.toShort(idx)); + } +} diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/StructVectorConverter.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/StructVectorConverter.java new file mode 100644 index 000000000..2dbbe4dec --- /dev/null +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/StructVectorConverter.java @@ -0,0 +1,83 @@ +package net.snowflake.client.core.arrow.fullvectorconverters; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.TimeZone; +import java.util.stream.Collectors; +import net.snowflake.client.core.DataConversionContext; +import net.snowflake.client.core.SFBaseSession; +import net.snowflake.client.core.SFException; +import net.snowflake.client.core.SnowflakeJdbcInternalApi; +import net.snowflake.client.jdbc.SnowflakeSQLException; +import org.apache.arrow.memory.ArrowBuf; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.ValueVector; +import org.apache.arrow.vector.complex.StructVector; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.util.TransferPair; + +@SnowflakeJdbcInternalApi +public class StructVectorConverter extends AbstractFullVectorConverter { + protected RootAllocator allocator; + protected ValueVector vector; + protected DataConversionContext context; + protected SFBaseSession session; + protected int idx; + protected Map targetTypes; + private TimeZone timeZoneToUse; + + StructVectorConverter( + RootAllocator allocator, + ValueVector vector, + DataConversionContext context, + SFBaseSession session, + TimeZone timeZoneToUse, + int idx, + Map targetTypes) { + this.allocator = allocator; + this.vector = vector; + this.context = context; + 
this.session = session; + this.timeZoneToUse = timeZoneToUse; + this.idx = idx; + this.targetTypes = targetTypes; + } + + /** Converts every child of the source struct with ArrowFullVectorConverterUtil (the target type is looked up by child name in targetTypes when that map is set), moves the converted children into a new StructVector of the same name and copies the source validity buffer. The source vector is closed on every exit path. */ + @Override + protected FieldVector convertVector() + throws SFException, SnowflakeSQLException, SFArrowException { + try { + StructVector structVector = (StructVector) vector; + List childVectors = structVector.getChildrenFromFields(); + List convertedVectors = new ArrayList<>(); + for (FieldVector childVector : childVectors) { + /* A null target type is passed when no map was supplied or the map has no entry for this child. */ + Object targetType = null; + if (targetTypes != null) { + targetType = targetTypes.get(childVector.getName()); + } + convertedVectors.add( + ArrowFullVectorConverterUtil.convert( + allocator, childVector, context, session, timeZoneToUse, idx, targetType)); + } + + List convertedFields = + convertedVectors.stream().map(ValueVector::getField).collect(Collectors.toList()); + StructVector converted = StructVector.empty(vector.getName(), allocator); + converted.allocateNew(); + converted.initializeChildrenFromFields(convertedFields); + for (FieldVector convertedVector : convertedVectors) { + TransferPair transferPair = + convertedVector.makeTransferPair(converted.getChild(convertedVector.getName())); + transferPair.transfer(); + } + /* NOTE(review): this copies validityBuffer.capacity() bytes before setValueCount() is called - confirm the validity buffer from allocateNew() is always at least that large. */ + ArrowBuf validityBuffer = structVector.getValidityBuffer(); + converted.getValidityBuffer().setBytes(0L, validityBuffer, 0L, validityBuffer.capacity()); + converted.setValueCount(vector.getValueCount()); + + return converted; + } finally { + vector.close(); + } + } +} diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TimeMicroVectorConverter.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TimeMicroVectorConverter.java new file mode 100644 index 000000000..93bc6318e --- /dev/null +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TimeMicroVectorConverter.java @@ -0,0 +1,29 @@ +package net.snowflake.client.core.arrow.fullvectorconverters; + +import net.snowflake.client.core.SnowflakeJdbcInternalApi; +import 
org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.TimeMicroVector; +import org.apache.arrow.vector.ValueVector; + +/** TimeVectorConverter specialization whose result is a TimeMicroVector; it declares a target scale of 6. */ @SnowflakeJdbcInternalApi +public class TimeMicroVectorConverter extends TimeVectorConverter { + + public TimeMicroVectorConverter(RootAllocator allocator, ValueVector vector) { + super(allocator, vector); + } + + /** Creates the destination vector under the source vector's name; it is not allocated here (TimeVectorConverter.convertVector calls allocateNew on it). */ @Override + protected TimeMicroVector initVector() { + return new TimeMicroVector(vector.getName(), allocator); + } + + /** Stores the already rescaled value at {@code idx} without further conversion. */ @Override + protected void convertValue(TimeMicroVector dstVector, int idx, long value) { + dstVector.set(idx, value); + } + + /** Scale (power of ten) that TimeVectorConverter rescales source values to: 6. */ @Override + protected int targetScale() { + return 6; + } +} diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TimeMilliVectorConverter.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TimeMilliVectorConverter.java new file mode 100644 index 000000000..63a56c73c --- /dev/null +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TimeMilliVectorConverter.java @@ -0,0 +1,28 @@ +package net.snowflake.client.core.arrow.fullvectorconverters; + +import net.snowflake.client.core.SnowflakeJdbcInternalApi; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.TimeMilliVector; +import org.apache.arrow.vector.ValueVector; + +/** TimeVectorConverter specialization whose result is a TimeMilliVector; it declares a target scale of 3. */ @SnowflakeJdbcInternalApi +public class TimeMilliVectorConverter extends TimeVectorConverter { + public TimeMilliVectorConverter(RootAllocator allocator, ValueVector vector) { + super(allocator, vector); + } + + /** Creates the destination vector under the source vector's name; it is not allocated here (TimeVectorConverter.convertVector calls allocateNew on it). */ @Override + protected TimeMilliVector initVector() { + return new TimeMilliVector(vector.getName(), allocator); + } + + /** Stores the rescaled value at {@code idx}; the long is narrowed to int before it is written. */ @Override + protected void convertValue(TimeMilliVector dstVector, int idx, long value) { + dstVector.set(idx, (int) value); + } + + /** Scale (power of ten) that TimeVectorConverter rescales source values to: 3. */ @Override + protected int targetScale() { + return 3; + } +} diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TimeNanoVectorConverter.java 
b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TimeNanoVectorConverter.java new file mode 100644 index 000000000..ad91e7a67 --- /dev/null +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TimeNanoVectorConverter.java @@ -0,0 +1,29 @@ +package net.snowflake.client.core.arrow.fullvectorconverters; + +import net.snowflake.client.core.SnowflakeJdbcInternalApi; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.TimeNanoVector; +import org.apache.arrow.vector.ValueVector; + +@SnowflakeJdbcInternalApi +public class TimeNanoVectorConverter extends TimeVectorConverter { + + public TimeNanoVectorConverter(RootAllocator allocator, ValueVector vector) { + super(allocator, vector); + } + + @Override + protected TimeNanoVector initVector() { + return new TimeNanoVector(vector.getName(), allocator); + } + + @Override + protected void convertValue(TimeNanoVector dstVector, int idx, long value) { + dstVector.set(idx, value); + } + + @Override + protected int targetScale() { + return 9; + } +} diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TimeSecVectorConverter.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TimeSecVectorConverter.java new file mode 100644 index 000000000..64498c715 --- /dev/null +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TimeSecVectorConverter.java @@ -0,0 +1,28 @@ +package net.snowflake.client.core.arrow.fullvectorconverters; + +import net.snowflake.client.core.SnowflakeJdbcInternalApi; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.TimeSecVector; +import org.apache.arrow.vector.ValueVector; + +@SnowflakeJdbcInternalApi +public class TimeSecVectorConverter extends TimeVectorConverter { + public TimeSecVectorConverter(RootAllocator allocator, ValueVector vector) { + super(allocator, vector); + } + + @Override + protected TimeSecVector initVector() { + return new 
TimeSecVector(vector.getName(), allocator); + } + + @Override + protected void convertValue(TimeSecVector dstVector, int idx, long value) { + dstVector.set(idx, (int) value); + } + + @Override + protected int targetScale() { + return 0; + } +} diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TimeVectorConverter.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TimeVectorConverter.java new file mode 100644 index 000000000..60d3bae3f --- /dev/null +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TimeVectorConverter.java @@ -0,0 +1,49 @@ +package net.snowflake.client.core.arrow.fullvectorconverters; + +import net.snowflake.client.core.SFException; +import net.snowflake.client.core.SnowflakeJdbcInternalApi; +import net.snowflake.client.core.arrow.ArrowResultUtil; +import net.snowflake.client.jdbc.SnowflakeSQLException; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.BaseFixedWidthVector; +import org.apache.arrow.vector.BaseIntVector; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.ValueVector; + +@SnowflakeJdbcInternalApi +public abstract class TimeVectorConverter + extends AbstractFullVectorConverter { + protected RootAllocator allocator; + protected ValueVector vector; + + public TimeVectorConverter(RootAllocator allocator, ValueVector vector) { + this.allocator = allocator; + this.vector = vector; + } + + protected abstract T initVector(); + + protected abstract void convertValue(T dstVector, int idx, long value); + + protected abstract int targetScale(); + + @Override + protected FieldVector convertVector() + throws SFException, SnowflakeSQLException, SFArrowException { + try { + int size = vector.getValueCount(); + T converted = initVector(); + converted.allocateNew(size); + BaseIntVector srcVector = (BaseIntVector) vector; + int scale = Integer.parseInt(vector.getField().getMetadata().get("scale")); + long scalingFactor = 
ArrowResultUtil.powerOfTen(targetScale() - scale); + for (int i = 0; i < size; i++) { + /* Only non-null slots are read and written: a null slot has no meaningful value for getValueAsLong, and set() would mark the slot as non-null. Slots left untouched keep the cleared validity bit from allocateNew(size), the same way SimpleArrowFullVectorConverter handles nulls. */ + if (!vector.isNull(i)) { + convertValue(converted, i, srcVector.getValueAsLong(i) * scalingFactor); + } + } + converted.setValueCount(size); + return converted; + } finally { + vector.close(); + } + } +} diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TimestampVectorConverter.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TimestampVectorConverter.java new file mode 100644 index 000000000..2d9f5b121 --- /dev/null +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TimestampVectorConverter.java @@ -0,0 +1,186 @@ +package net.snowflake.client.core.arrow.fullvectorconverters; + +import java.util.ArrayList; +import java.util.List; +import java.util.TimeZone; +import net.snowflake.client.core.DataConversionContext; +import net.snowflake.client.core.SFException; +import net.snowflake.client.core.SnowflakeJdbcInternalApi; +import net.snowflake.client.core.arrow.ArrowResultUtil; +import net.snowflake.client.jdbc.SnowflakeSQLException; +import net.snowflake.client.util.SFPair; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.ValueVector; +import org.apache.arrow.vector.complex.StructVector; +import org.apache.arrow.vector.types.pojo.Field; + +@SnowflakeJdbcInternalApi +public class TimestampVectorConverter implements ArrowFullVectorConverter { + private RootAllocator allocator; + private ValueVector vector; + private DataConversionContext context; + private TimeZone timeZoneToUse; + + // This parameter is used to distinguish between NTZ and LTZ (TZ is distinct in having the offset + // vector) + private boolean isNTZ; + + /** Field names of the struct vectors used by timestamp */ + private static final String FIELD_NAME_EPOCH = "epoch"; // seconds since epoch + + private static 
final String FIELD_NAME_TIME_ZONE_INDEX = "timezone"; // time zone index + private static final String FIELD_NAME_FRACTION = "fraction"; // fraction in nanoseconds + private static final int UTC_OFFSET = 1440; + private static final long NANOS_PER_MILLI = 1000000L; + private static final int MILLIS_PER_SECOND = 1000; + private static final int SECONDS_PER_MINUTE = 60; + + public TimestampVectorConverter( + RootAllocator allocator, + ValueVector vector, + DataConversionContext context, + TimeZone timeZoneToUse, + boolean isNTZ) { + this.allocator = allocator; + this.vector = vector; + this.context = context; + this.timeZoneToUse = timeZoneToUse; + this.isNTZ = isNTZ; + } + + private IntVector makeVectorOfZeroes(int length) { + IntVector vector = new IntVector(FIELD_NAME_FRACTION, allocator); + vector.allocateNew(length); + vector.zeroVector(); + vector.setValueCount(length); + return vector; + } + + private IntVector makeVectorOfUTCOffsets(int length) { + IntVector vector = new IntVector(FIELD_NAME_TIME_ZONE_INDEX, allocator); + vector.allocateNew(length); + vector.setValueCount(length); + for (int i = 0; i < length; i++) { + vector.set(i, UTC_OFFSET); + } + return vector; + } + + private SFPair normalizeTimeSinceEpoch(BigIntVector vector) { + int length = vector.getValueCount(); + int scale = Integer.parseInt(vector.getField().getMetadata().get("scale")); + if (scale == 0) { + IntVector fractions = makeVectorOfZeroes(length); + BigIntVector epoch = new BigIntVector(FIELD_NAME_EPOCH, allocator); + fractions + .getValidityBuffer() + .setBytes(0L, vector.getValidityBuffer(), 0L, vector.getValidityBuffer().capacity()); + vector.makeTransferPair(epoch).transfer(); + return SFPair.of(epoch, fractions); + } + long scaleFactor = ArrowResultUtil.powerOfTen(scale); + long fractionScaleFactor = ArrowResultUtil.powerOfTen(9 - scale); + BigIntVector epoch = new BigIntVector(FIELD_NAME_EPOCH, allocator); + epoch.allocateNew(length); + epoch.setValueCount(length); + IntVector 
fractions = new IntVector(FIELD_NAME_FRACTION, allocator); + fractions.allocateNew(length); + fractions.setValueCount(length); + for (int i = 0; i < length; i++) { + /* Keep null rows null in both outputs, matching the scale == 0 path, which carries the source validity over instead of writing values. */ + if (vector.isNull(i)) { + epoch.setNull(i); + fractions.setNull(i); + continue; + } + /* NOTE(review): `/` and `%` truncate toward zero, so a value before the epoch yields a negative fraction - confirm that consumers of the (epoch, fraction) pair expect this. */ + long value = vector.get(i); + epoch.set(i, value / scaleFactor); + fractions.set(i, (int) ((value % scaleFactor) * fractionScaleFactor)); + } + return SFPair.of(epoch, fractions); + } + + /** Builds the per-row time zone child: UTC_OFFSET plus the offset of {@code timeZone}, in minutes, at the instant described by each row's seconds and fraction. Rows whose seconds are null get a null offset. */ + private IntVector makeTimeZoneOffsets( + BigIntVector seconds, IntVector fractions, TimeZone timeZone) { + int length = seconds.getValueCount(); + IntVector offsets = new IntVector(FIELD_NAME_TIME_ZONE_INDEX, allocator); + offsets.allocateNew(length); + offsets.setValueCount(length); + for (int i = 0; i < length; i++) { + /* A null row has no instant to look an offset up for, so its offset stays null as well. */ + if (seconds.isNull(i)) { + offsets.setNull(i); + continue; + } + offsets.set( + i, + UTC_OFFSET + + timeZone.getOffset( + seconds.get(i) * MILLIS_PER_SECOND + fractions.get(i) / NANOS_PER_MILLI) + / (MILLIS_PER_SECOND * SECONDS_PER_MINUTE)); + } + return offsets; + } + + /** Assembles the result struct (named after the source vector) from the three child vectors by transferring their buffers into children created from their fields. */ + private StructVector pack(BigIntVector seconds, IntVector fractions, IntVector offsets) { + StructVector result = StructVector.empty(vector.getName(), allocator); + List fields = + new ArrayList() { + { + add(seconds.getField()); + add(fractions.getField()); + add(offsets.getField()); + } + }; + result.setInitialCapacity(seconds.getValueCount()); + result.initializeChildrenFromFields(fields); + seconds.makeTransferPair(result.getChild(FIELD_NAME_EPOCH)).transfer(); + fractions.makeTransferPair(result.getChild(FIELD_NAME_FRACTION)).transfer(); + offsets.makeTransferPair(result.getChild(FIELD_NAME_TIME_ZONE_INDEX)).transfer(); + /* NOTE(review): the value count and validity buffer below are read from `seconds` after its buffers were transferred into the struct child above - confirm the transferred-from vector still reports them. */ + result.setValueCount(seconds.getValueCount()); + result + .getValidityBuffer() + .setBytes(0L, seconds.getValidityBuffer(), 0L, seconds.getValidityBuffer().capacity()); + return result; + } + + @Override + public FieldVector convert() throws SFException, SnowflakeSQLException { + boolean returnedOriginal = false; + try { + BigIntVector seconds; + IntVector fractions; + IntVector timeZoneIndices = null; + if (vector instanceof BigIntVector) { + SFPair normalized = 
normalizeTimeSinceEpoch((BigIntVector) vector); + seconds = normalized.left; + fractions = normalized.right; + } else { + StructVector structVector = (StructVector) vector; + /* A struct that already has three children is returned as is and is deliberately left open (see the finally block). */ + if (structVector.getChildrenFromFields().size() == 3) { + returnedOriginal = true; + return structVector; + } + if (structVector.getChild(FIELD_NAME_FRACTION) == null) { + SFPair normalized = + normalizeTimeSinceEpoch(structVector.getChild(FIELD_NAME_EPOCH, BigIntVector.class)); + seconds = normalized.left; + fractions = normalized.right; + } else { + seconds = structVector.getChild(FIELD_NAME_EPOCH, BigIntVector.class); + fractions = structVector.getChild(FIELD_NAME_FRACTION, IntVector.class); + } + timeZoneIndices = structVector.getChild(FIELD_NAME_TIME_ZONE_INDEX, IntVector.class); + } + /* No time zone child in the source: derive one from the NTZ setting, the default zone or timeZoneToUse. */ + if (timeZoneIndices == null) { + if (isNTZ && context.getHonorClientTZForTimestampNTZ()) { + timeZoneIndices = makeTimeZoneOffsets(seconds, fractions, TimeZone.getDefault()); + for (int i = 0; i < seconds.getValueCount(); i++) { + /* Skip null rows: set() would flag them as non-null, and pack() takes the struct validity from `seconds`. */ + if (seconds.isNull(i)) { + continue; + } + seconds.set( + i, + seconds.get(i) - (long) (timeZoneIndices.get(i) - UTC_OFFSET) * SECONDS_PER_MINUTE); + } + } else if (isNTZ || timeZoneToUse == null) { + timeZoneIndices = makeVectorOfUTCOffsets(seconds.getValueCount()); + } else { + timeZoneIndices = makeTimeZoneOffsets(seconds, fractions, timeZoneToUse); + } + } + return pack(seconds, fractions, timeZoneIndices); + } finally { + if (!returnedOriginal) { + vector.close(); + } + } + } +} diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TinyIntVectorConverter.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TinyIntVectorConverter.java new file mode 100644 index 000000000..a4c7bdb22 --- /dev/null +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TinyIntVectorConverter.java @@ -0,0 +1,41 @@ +package net.snowflake.client.core.arrow.fullvectorconverters; + +import net.snowflake.client.core.DataConversionContext; +import 
net.snowflake.client.core.SFBaseSession; +import net.snowflake.client.core.SFException; +import net.snowflake.client.core.SnowflakeJdbcInternalApi; +import net.snowflake.client.core.arrow.ArrowVectorConverter; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.TinyIntVector; +import org.apache.arrow.vector.ValueVector; + +@SnowflakeJdbcInternalApi +public class TinyIntVectorConverter extends SimpleArrowFullVectorConverter { + + public TinyIntVectorConverter( + RootAllocator allocator, + ValueVector vector, + DataConversionContext context, + SFBaseSession session, + int idx) { + super(allocator, vector, context, session, idx); + } + + @Override + protected boolean matchingType() { + return (vector instanceof TinyIntVector); + } + + @Override + protected TinyIntVector initVector() { + TinyIntVector resultVector = new TinyIntVector(vector.getName(), allocator); + resultVector.allocateNew(vector.getValueCount()); + return resultVector; + } + + @Override + protected void convertValue(ArrowVectorConverter from, TinyIntVector to, int idx) + throws SFException { + to.set(idx, from.toByte(idx)); + } +} diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/VarCharVectorConverter.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/VarCharVectorConverter.java new file mode 100644 index 000000000..8898d4498 --- /dev/null +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/VarCharVectorConverter.java @@ -0,0 +1,41 @@ +package net.snowflake.client.core.arrow.fullvectorconverters; + +import net.snowflake.client.core.DataConversionContext; +import net.snowflake.client.core.SFBaseSession; +import net.snowflake.client.core.SFException; +import net.snowflake.client.core.SnowflakeJdbcInternalApi; +import net.snowflake.client.core.arrow.ArrowVectorConverter; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.ValueVector; +import 
org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.util.Text; + +/** Converts a non-VarChar source vector into a VarCharVector holding each value's string form (ArrowVectorConverter.toString). */ @SnowflakeJdbcInternalApi +public class VarCharVectorConverter extends SimpleArrowFullVectorConverter { + public VarCharVectorConverter( + RootAllocator allocator, + ValueVector vector, + DataConversionContext context, + SFBaseSession session, + int idx) { + super(allocator, vector, context, session, idx); + } + + /** True when the source is already a VarCharVector, in which case the base class returns it unchanged. */ @Override + protected boolean matchingType() { + return (vector instanceof VarCharVector); + } + + /** Creates the result vector under the source's name, allocated for the source's value count. */ @Override + protected VarCharVector initVector() { + VarCharVector resultVector = new VarCharVector(vector.getName(), allocator); + resultVector.allocateNew(vector.getValueCount()); + return resultVector; + } + + /** Writes the string form of the source value at {@code idx}. setSafe is used because allocateNew(valueCount) in initVector does not size the data buffer for the actual string bytes; setSafe grows it when needed, plain set does not. */ @Override + protected void convertValue(ArrowVectorConverter from, VarCharVector to, int idx) + throws SFException { + to.setSafe(idx, new Text(from.toString(idx))); + } +} diff --git a/src/main/java/net/snowflake/client/jdbc/ArrowBatch.java b/src/main/java/net/snowflake/client/jdbc/ArrowBatch.java new file mode 100644 index 000000000..d823fd112 --- /dev/null +++ b/src/main/java/net/snowflake/client/jdbc/ArrowBatch.java @@ -0,0 +1,15 @@ +package net.snowflake.client.jdbc; + +import java.util.List; +import net.snowflake.client.core.arrow.ArrowVectorConverter; +import net.snowflake.client.core.arrow.fullvectorconverters.SFArrowException; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.VectorSchemaRoot; + +/** One batch of a result exposed as Arrow vectors; implemented by ArrowResultChunk.ArrowResultBatch. */ public interface ArrowBatch { + /** Returns the batch contents; the ArrowResultChunk implementation returns one VectorSchemaRoot per record batch of the chunk. */ List fetch() throws SnowflakeSQLException, SFArrowException; + + /** Returns a converter for reading timestamp values from {@code vector}, the column at index {@code colIdx}. */ ArrowVectorConverter getTimestampConverter(FieldVector vector, int colIdx); + + /** Returns the number of rows in this batch. */ long getRowCount(); +} diff --git a/src/main/java/net/snowflake/client/jdbc/ArrowBatches.java b/src/main/java/net/snowflake/client/jdbc/ArrowBatches.java new file mode 100644 index 000000000..fba1d8d3e --- /dev/null +++ b/src/main/java/net/snowflake/client/jdbc/ArrowBatches.java @@ -0,0 +1,11 @@ +package net.snowflake.client.jdbc; + +import java.sql.SQLException; + 
+/** NOTE(review): presumably an iterator-style handle over the Arrow batches of one result set, obtained through SnowflakeResultSet#getArrowBatches() - confirm against the implementing result set. */ public interface ArrowBatches { + /** NOTE(review): presumably true while next() can deliver another batch - confirm. */ boolean hasNext(); + + /** Returns the next batch. */ ArrowBatch next() throws SQLException; + + /** NOTE(review): presumably the total row count over all batches - confirm with the implementation. */ long getRowCount() throws SQLException; +} diff --git a/src/main/java/net/snowflake/client/jdbc/ArrowResultChunk.java b/src/main/java/net/snowflake/client/jdbc/ArrowResultChunk.java index cf641dc10..896437def 100644 --- a/src/main/java/net/snowflake/client/jdbc/ArrowResultChunk.java +++ b/src/main/java/net/snowflake/client/jdbc/ArrowResultChunk.java @@ -10,11 +10,16 @@ import java.nio.channels.ClosedByInterruptException; import java.util.ArrayList; import java.util.List; +import java.util.TimeZone; import net.snowflake.client.core.DataConversionContext; import net.snowflake.client.core.SFBaseSession; import net.snowflake.client.core.SFException; import net.snowflake.client.core.arrow.ArrowResultChunkIndexSorter; import net.snowflake.client.core.arrow.ArrowVectorConverter; +import net.snowflake.client.core.arrow.ThreeFieldStructToTimestampTZConverter; +import net.snowflake.client.core.arrow.fullvectorconverters.ArrowErrorCode; +import net.snowflake.client.core.arrow.fullvectorconverters.ArrowFullVectorConverterUtil; +import net.snowflake.client.core.arrow.fullvectorconverters.SFArrowException; import net.snowflake.client.log.SFLogger; import net.snowflake.client.log.SFLoggerFactory; import net.snowflake.common.core.SqlState; @@ -55,6 +60,7 @@ public class ArrowResultChunk extends SnowflakeResultChunk { private IntVector firstResultChunkSortedIndices; private VectorSchemaRoot root; private SFBaseSession session; + /* Set to true by getArrowBatch(); freeData() then returns without closing anything. */ private boolean batchesMode = false; public ArrowResultChunk( String url, @@ -126,6 +132,9 @@ public long computeNeededChunkMemory() { @Override public void freeData() { + /* NOTE(review): presumably the vectors are left open here because the ArrowBatch returned by getArrowBatch() still reads them - confirm who releases them afterwards. */ if (batchesMode) { + return; + } batchOfVectors.forEach(list -> list.forEach(ValueVector::close)); this.batchOfVectors.clear(); if (firstResultChunkSortedIndices != null) { @@ -505,6 +514,12 @@ private void sortFirstResultChunk(List converters) } } + /* Switches this chunk to batches mode and returns an ArrowResultBatch bound to it. */ public ArrowBatch getArrowBatch( + 
DataConversionContext context, TimeZone timeZoneToUse, long batchIndex) { + batchesMode = true; + return new ArrowResultBatch(context, timeZoneToUse, batchIndex); + } + private boolean sortFirstResultChunkEnabled() { return enableSortFirstResultChunk; } @@ -528,4 +543,46 @@ public final void freeData() { // do nothing } } + + public class ArrowResultBatch implements ArrowBatch { + private DataConversionContext context; + private TimeZone timeZoneToUse; + private long batchIndex; + + ArrowResultBatch(DataConversionContext context, TimeZone timeZoneToUse, long batchIndex) { + this.context = context; + this.timeZoneToUse = timeZoneToUse; + this.batchIndex = batchIndex; + } + + public List fetch() throws SFArrowException { + try { + List result = new ArrayList<>(); + for (List record : batchOfVectors) { + List convertedVectors = new ArrayList<>(); + for (int i = 0; i < record.size(); i++) { + ValueVector vector = record.get(i); + convertedVectors.add( + ArrowFullVectorConverterUtil.convert( + rootAllocator, vector, context, session, timeZoneToUse, i, null)); + } + result.add(new VectorSchemaRoot(convertedVectors)); + } + return result; + } catch (SFArrowException e) { + throw new SFArrowException( + ArrowErrorCode.CHUNK_FETCH_FAILED, "Failed to fetch batch number " + batchIndex, e); + } + } + + @Override + public ArrowVectorConverter getTimestampConverter(FieldVector vector, int colIdx) { + return new ThreeFieldStructToTimestampTZConverter(vector, colIdx, context); + } + + @Override + public long getRowCount() { + return rowCount; + } + } } diff --git a/src/main/java/net/snowflake/client/jdbc/SFAsyncResultSet.java b/src/main/java/net/snowflake/client/jdbc/SFAsyncResultSet.java index 0bafbf12d..51cc14a04 100644 --- a/src/main/java/net/snowflake/client/jdbc/SFAsyncResultSet.java +++ b/src/main/java/net/snowflake/client/jdbc/SFAsyncResultSet.java @@ -402,4 +402,9 @@ public List getResultSetSerializables(long maxSi .unwrap(SnowflakeResultSet.class) 
.getResultSetSerializables(maxSizeInBytes); } + + @Override + public ArrowBatches getArrowBatches() throws SQLException { + throw new SnowflakeLoggedFeatureNotSupportedException(session); + } } diff --git a/src/main/java/net/snowflake/client/jdbc/SnowflakeChunkDownloader.java b/src/main/java/net/snowflake/client/jdbc/SnowflakeChunkDownloader.java index fe1880083..5163f3299 100644 --- a/src/main/java/net/snowflake/client/jdbc/SnowflakeChunkDownloader.java +++ b/src/main/java/net/snowflake/client/jdbc/SnowflakeChunkDownloader.java @@ -845,7 +845,7 @@ public DownloaderMetrics terminate() throws InterruptedException { logger.info( "Completed processing {} {} chunks for query {} in {} ms. Download took {} ms (average: {} ms)," + " parsing took {} ms (average: {} ms). Chunks uncompressed size: {} MB (average: {} MB)," - + " rows in chunks: {} (total: {}, average in chunk: {}), total memory used: {} MB", + + " rows in chunks: {} (total: {}, average in chunk: {}), total memory used: {} MB, free memory {} MB", chunksSize, queryResultFormat == QueryResultFormat.ARROW ? 
"ARROW" : "JSON", queryId, @@ -859,7 +859,8 @@ public DownloaderMetrics terminate() throws InterruptedException { rowsInChunks, firstChunkRowCount + rowsInChunks, rowsInChunks / chunksSize, - Runtime.getRuntime().totalMemory() / MB); + Runtime.getRuntime().totalMemory() / MB, + Runtime.getRuntime().freeMemory() / MB); return new DownloaderMetrics( numberMillisWaitingForChunks, diff --git a/src/main/java/net/snowflake/client/jdbc/SnowflakeResultSet.java b/src/main/java/net/snowflake/client/jdbc/SnowflakeResultSet.java index 2df8975b5..5b6304bdf 100644 --- a/src/main/java/net/snowflake/client/jdbc/SnowflakeResultSet.java +++ b/src/main/java/net/snowflake/client/jdbc/SnowflakeResultSet.java @@ -63,4 +63,6 @@ public interface SnowflakeResultSet { */ List getResultSetSerializables(long maxSizeInBytes) throws SQLException; + + ArrowBatches getArrowBatches() throws SQLException; } diff --git a/src/main/java/net/snowflake/client/jdbc/SnowflakeResultSetV1.java b/src/main/java/net/snowflake/client/jdbc/SnowflakeResultSetV1.java index 4f73b4c18..e290db3f2 100644 --- a/src/main/java/net/snowflake/client/jdbc/SnowflakeResultSetV1.java +++ b/src/main/java/net/snowflake/client/jdbc/SnowflakeResultSetV1.java @@ -397,6 +397,11 @@ public List getResultSetSerializables(long maxSi return sfBaseResultSet.getResultSetSerializables(maxSizeInBytes); } + @Override + public ArrowBatches getArrowBatches() throws SQLException { + return sfBaseResultSet.getArrowBatches(); + } + /** Empty result set */ static class EmptyResultSet implements ResultSet { private boolean isClosed; diff --git a/src/test/java/net/snowflake/client/core/arrow/ArrowBatchesTest.java b/src/test/java/net/snowflake/client/core/arrow/ArrowBatchesTest.java new file mode 100644 index 000000000..dfb1fe598 --- /dev/null +++ b/src/test/java/net/snowflake/client/core/arrow/ArrowBatchesTest.java @@ -0,0 +1,55 @@ +package net.snowflake.client.core.arrow; + +import static org.junit.Assert.assertEquals; +import static 
org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +import net.snowflake.client.core.SFException; +import net.snowflake.client.core.arrow.fullvectorconverters.ArrowFullVectorConverterUtil; +import net.snowflake.client.core.arrow.fullvectorconverters.IntVectorConverter; +import net.snowflake.client.core.arrow.fullvectorconverters.SFArrowException; +import net.snowflake.client.jdbc.SnowflakeSQLException; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.IntervalDayVector; +import org.junit.Test; + +public class ArrowBatchesTest extends BaseConverterTest { + @Test + public void testRepeatedConvert() throws SFException, SnowflakeSQLException, SFArrowException { + RootAllocator allocator = new RootAllocator(Long.MAX_VALUE); + IntVector intVector = new IntVector("test", allocator); + intVector.allocateNew(2); + intVector.set(0, 1); + intVector.set(1, 4); + intVector.setValueCount(2); + + IntVectorConverter converter = new IntVectorConverter(allocator, intVector, this, null, 0); + IntVector convertedIntVector = (IntVector) converter.convert(); + assertEquals(convertedIntVector.getValueCount(), 2); + assertEquals(convertedIntVector.get(0), 1); + assertEquals(convertedIntVector.get(1), 4); + try { + converter.convert().clear(); + } catch (SFArrowException e) { + // should throw + return; + } + fail("Second conversion should throw"); + } + + @Test + public void testUnknownType() { + RootAllocator allocator = new RootAllocator(Long.MAX_VALUE); + // Vector of unsupported type + IntervalDayVector vector = new IntervalDayVector("test", allocator); + try { + ArrowFullVectorConverterUtil.convert(allocator, vector, this, null, null, 0, null); + } catch (SFArrowException e) { + assertTrue(e.getCause() instanceof SFArrowException); + // should throw + return; + } + fail("Should throw on unsupported type"); + } +} diff --git a/src/test/java/net/snowflake/client/jdbc/ArrowBatchesIT.java 
b/src/test/java/net/snowflake/client/jdbc/ArrowBatchesIT.java new file mode 100644 index 000000000..52161e185 --- /dev/null +++ b/src/test/java/net/snowflake/client/jdbc/ArrowBatchesIT.java @@ -0,0 +1,898 @@ +package net.snowflake.client.jdbc; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.math.BigDecimal; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Statement; +import java.sql.Timestamp; +import java.time.LocalDate; +import java.time.LocalTime; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import net.snowflake.client.category.TestTags; +import net.snowflake.client.core.SFArrowResultSet; +import net.snowflake.client.core.SFException; +import net.snowflake.client.core.arrow.ArrowVectorConverter; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.BitVector; +import org.apache.arrow.vector.DateDayVector; +import org.apache.arrow.vector.DecimalVector; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.SmallIntVector; +import org.apache.arrow.vector.TimeMicroVector; +import org.apache.arrow.vector.TimeMilliVector; +import org.apache.arrow.vector.TimeNanoVector; +import org.apache.arrow.vector.TimeSecVector; +import org.apache.arrow.vector.TinyIntVector; +import org.apache.arrow.vector.VarBinaryVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.complex.FixedSizeListVector; +import org.apache.arrow.vector.complex.ListVector; +import org.apache.arrow.vector.complex.MapVector; +import org.apache.arrow.vector.complex.StructVector; +import org.apache.arrow.vector.util.Text; +import org.junit.jupiter.api.AfterAll; +import 
org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Tag; +import org.junit.jupiter.api.Test; + +@Tag(TestTags.ARROW) +public class ArrowBatchesIT extends BaseJDBCWithSharedConnectionIT { + + @BeforeAll + public static void setUp() throws Exception { + try (Statement statement = connection.createStatement()) { + statement.execute("alter session set jdbc_query_result_format = 'arrow'"); + statement.execute("alter session set ENABLE_STRUCTURED_TYPES_NATIVE_ARROW_FORMAT = true"); + statement.execute( + "alter session set FORCE_ENABLE_STRUCTURED_TYPES_NATIVE_ARROW_FORMAT = true"); + } + } + + @AfterAll + public static void tearDown() throws Exception { + try (Statement statement = connection.createStatement()) { + statement.execute("alter session unset jdbc_query_result_format"); + statement.execute("alter session unset ENABLE_STRUCTURED_TYPES_NATIVE_ARROW_FORMAT"); + statement.execute("alter session unset FORCE_ENABLE_STRUCTURED_TYPES_NATIVE_ARROW_FORMAT"); + } + } + + private static void assertNoMemoryLeaks(ResultSet rs) throws SQLException { + assertEquals( + 0, + ((SFArrowResultSet) rs.unwrap(SnowflakeResultSetV1.class).sfBaseResultSet) + .getAllocatedMemory()); + } + + @Test + public void testMultipleBatches() throws Exception { + int totalRows = 0; + ArrayList allRoots = new ArrayList<>(); + // Result set is not in the try-with-resources statement, as we want to check access to memory + // after its closure + // and then check the memory allocation. 
+ ResultSet rs; + try (Statement statement = connection.createStatement()) { + rs = + statement.executeQuery( + "select seq1(), seq2(), seq4(), seq8() from TABLE (generator(rowcount => 300000))"); + ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); + assertEquals(batches.getRowCount(), 300000); + while (batches.hasNext()) { + ArrowBatch batch = batches.next(); + List roots = batch.fetch(); + for (VectorSchemaRoot root : roots) { + totalRows += root.getRowCount(); + allRoots.add(root); + assertTrue(root.getVector(0) instanceof TinyIntVector); + assertTrue(root.getVector(1) instanceof SmallIntVector); + assertTrue(root.getVector(2) instanceof IntVector); + assertTrue(root.getVector(3) instanceof BigIntVector); + } + } + } + + // The memory should not be freed when closing the result set. + for (VectorSchemaRoot root : allRoots) { + assertTrue(root.getVector(0).getValueCount() > 0); + root.close(); + } + assertNoMemoryLeaks(rs); + assertEquals(300000, totalRows); + } + + @Test + public void testTinyIntBatch() throws Exception { + int totalRows = 0; + List values = new ArrayList<>(); + + try (Statement statement = connection.createStatement(); + ResultSet rs = statement.executeQuery("select 1 union select 2 union select 3;")) { + ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); + + while (batches.hasNext()) { + ArrowBatch batch = batches.next(); + List roots = batch.fetch(); + for (VectorSchemaRoot root : roots) { + totalRows += root.getRowCount(); + assertTrue(root.getVector(0) instanceof TinyIntVector); + TinyIntVector vector = (TinyIntVector) root.getVector(0); + for (int i = 0; i < root.getRowCount(); i++) { + values.add(vector.get(i)); + } + root.close(); + } + } + assertNoMemoryLeaks(rs); + } + + // All expected values are present + for (byte i = 1; i < 4; i++) { + assertTrue(values.contains(i)); + } + + assertEquals(3, totalRows); + } + + @Test + public void testSmallIntBatch() throws Exception { + int 
totalRows = 0; + List values = new ArrayList<>(); + try (Statement statement = connection.createStatement(); + ResultSet rs = statement.executeQuery("select 129 union select 130 union select 131;")) { + ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); + + while (batches.hasNext()) { + ArrowBatch batch = batches.next(); + List roots = batch.fetch(); + for (VectorSchemaRoot root : roots) { + totalRows += root.getRowCount(); + assertInstanceOf(SmallIntVector.class, root.getVector(0)); + SmallIntVector vector = (SmallIntVector) root.getVector(0); + for (int i = 0; i < root.getRowCount(); i++) { + values.add(vector.get(i)); + } + root.close(); + } + } + assertNoMemoryLeaks(rs); + } + + // All expected values are present + for (short i = 129; i < 132; i++) { + assertTrue(values.contains(i)); + } + + assertEquals(3, totalRows); + } + + @Test + public void testIntBatch() throws Exception { + int totalRows = 0; + List values = new ArrayList<>(); + + try (Statement statement = connection.createStatement(); + ResultSet rs = + statement.executeQuery("select 100000 union select 100001 union select 100002;")) { + ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); + + while (batches.hasNext()) { + ArrowBatch batch = batches.next(); + List roots = batch.fetch(); + for (VectorSchemaRoot root : roots) { + totalRows += root.getRowCount(); + assertTrue(root.getVector(0) instanceof IntVector); + IntVector vector = (IntVector) root.getVector(0); + for (int i = 0; i < root.getRowCount(); i++) { + values.add(vector.get(i)); + } + root.close(); + } + } + assertNoMemoryLeaks(rs); + } + + // All expected values are present + for (int i = 100000; i < 100003; i++) { + assertTrue(values.contains(i)); + } + + assertEquals(3, totalRows); + } + + @Test + public void testBigIntBatch() throws Exception { + int totalRows = 0; + List values = new ArrayList<>(); + + try (Statement statement = connection.createStatement(); + ResultSet rs = + 
statement.executeQuery( + "select 10000000000 union select 10000000001 union select 10000000002;")) { + ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); + + while (batches.hasNext()) { + ArrowBatch batch = batches.next(); + List roots = batch.fetch(); + for (VectorSchemaRoot root : roots) { + totalRows += root.getRowCount(); + assertTrue(root.getVector(0) instanceof BigIntVector); + BigIntVector vector = (BigIntVector) root.getVector(0); + for (int i = 0; i < root.getRowCount(); i++) { + values.add(vector.get(i)); + } + root.close(); + } + } + assertNoMemoryLeaks(rs); + } + + // All expected values are present + for (long i = 10000000000L; i < 10000000003L; i++) { + assertTrue(values.contains(i)); + } + + assertEquals(3, totalRows); + } + + @Test + public void testDecimalBatch() throws Exception { + int totalRows = 0; + List values = new ArrayList<>(); + + try (Statement statement = connection.createStatement(); + ResultSet rs = statement.executeQuery("select 1.1 union select 1.2 union select 1.3;")) { + ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); + + while (batches.hasNext()) { + ArrowBatch batch = batches.next(); + List roots = batch.fetch(); + for (VectorSchemaRoot root : roots) { + totalRows += root.getRowCount(); + assertTrue(root.getVector(0) instanceof DecimalVector); + DecimalVector vector = (DecimalVector) root.getVector(0); + for (int i = 0; i < root.getRowCount(); i++) { + values.add(vector.getObject(i)); + } + root.close(); + } + } + assertNoMemoryLeaks(rs); + } + + // All expected values are present + for (int i = 1; i < 4; i++) { + assertTrue(values.contains(new BigDecimal("1." 
+ i))); + } + + assertEquals(3, totalRows); + } + + @Test + public void testBitBatch() throws Exception { + int trueCount = 0; + int falseCount = 0; + + try (Statement statement = connection.createStatement(); + ResultSet rs = + statement.executeQuery( + "select true union all select false union all select true union all select false" + + " union all select true union all select false union all select true")) { + ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); + + while (batches.hasNext()) { + ArrowBatch batch = batches.next(); + List roots = batch.fetch(); + for (VectorSchemaRoot root : roots) { + assertTrue(root.getVector(0) instanceof BitVector); + BitVector vector = (BitVector) root.getVector(0); + for (int i = 0; i < root.getRowCount(); i++) { + if (vector.getObject(i)) { + trueCount++; + } else { + falseCount++; + } + } + root.close(); + } + } + assertNoMemoryLeaks(rs); + } + + assertEquals(4, trueCount); + assertEquals(3, falseCount); + } + + @Test + public void testBinaryBatch() throws Exception { + int totalRows = 0; + List> values = new ArrayList<>(); + + try (Statement statement = connection.createStatement(); + ResultSet rs = + statement.executeQuery( + "select TO_BINARY('546AB0') union select TO_BINARY('018E3271')")) { + ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); + + while (batches.hasNext()) { + ArrowBatch batch = batches.next(); + List roots = batch.fetch(); + for (VectorSchemaRoot root : roots) { + assertTrue(root.getVector(0) instanceof VarBinaryVector); + VarBinaryVector vector = (VarBinaryVector) root.getVector(0); + totalRows += root.getRowCount(); + for (int i = 0; i < root.getRowCount(); i++) { + byte[] bytes = vector.getObject(i); + ArrayList byteArrayList = + new ArrayList() { + { + for (byte aByte : bytes) { + add(aByte); + } + } + }; + values.add(byteArrayList); + } + root.close(); + } + } + assertNoMemoryLeaks(rs); + } + + List> expected = + new ArrayList>() { + { + add( + 
new ArrayList() { + { + add((byte) 0x54); + add((byte) 0x6A); + add((byte) 0xB0); + } + }); + add( + new ArrayList() { + { + add((byte) 0x01); + add((byte) 0x8E); + add((byte) 0x32); + add((byte) 0x71); + } + }); + } + }; + + assertEquals(2, totalRows); + assertTrue(values.containsAll(expected)); + } + + @Test + public void testDateBatch() throws Exception, SFException { + int totalRows = 0; + List values = new ArrayList<>(); + + try (Statement statement = connection.createStatement(); + ResultSet rs = + statement.executeQuery("select '1119-02-01'::DATE union select '2021-09-11'::DATE")) { + ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); + + while (batches.hasNext()) { + ArrowBatch batch = batches.next(); + List roots = batch.fetch(); + for (VectorSchemaRoot root : roots) { + totalRows += root.getRowCount(); + assertTrue(root.getVector(0) instanceof DateDayVector); + DateDayVector vector = (DateDayVector) root.getVector(0); + for (int i = 0; i < root.getRowCount(); i++) { + values.add(LocalDate.ofEpochDay(vector.get(i))); + } + root.close(); + } + } + assertNoMemoryLeaks(rs); + } + + List expected = + new ArrayList() { + { + add(LocalDate.of(1119, 2, 1)); + add(LocalDate.of(2021, 9, 11)); + } + }; + + assertEquals(2, totalRows); + assertTrue(values.containsAll(expected)); + } + + @Test + public void testTimeSecBatch() throws Exception, SFException { + int totalRows = 0; + List values = new ArrayList<>(); + + try (Statement statement = connection.createStatement(); + ResultSet rs = + statement.executeQuery("select '11:32:54'::TIME(0) union select '8:11:25'::TIME(0)")) { + ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); + + while (batches.hasNext()) { + ArrowBatch batch = batches.next(); + List roots = batch.fetch(); + for (VectorSchemaRoot root : roots) { + totalRows += root.getRowCount(); + assertTrue(root.getVector(0) instanceof TimeSecVector); + TimeSecVector vector = (TimeSecVector) root.getVector(0); 
+ for (int i = 0; i < root.getRowCount(); i++) { + values.add(LocalTime.ofSecondOfDay(vector.get(i))); + } + root.close(); + } + } + assertNoMemoryLeaks(rs); + } + + List expected = + new ArrayList() { + { + add(LocalTime.of(11, 32, 54)); + add(LocalTime.of(8, 11, 25)); + } + }; + + assertEquals(2, totalRows); + assertTrue(values.containsAll(expected)); + } + + @Test + public void testTimeMilliBatch() throws Exception, SFException { + int totalRows = 0; + List values = new ArrayList<>(); + + try (Statement statement = connection.createStatement(); + ResultSet rs = + statement.executeQuery( + "select '11:32:54.13'::TIME(2) union select '8:11:25.91'::TIME(2)")) { + ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); + + while (batches.hasNext()) { + ArrowBatch batch = batches.next(); + List roots = batch.fetch(); + for (VectorSchemaRoot root : roots) { + totalRows += root.getRowCount(); + assertTrue(root.getVector(0) instanceof TimeMilliVector); + TimeMilliVector vector = (TimeMilliVector) root.getVector(0); + for (int i = 0; i < root.getRowCount(); i++) { + values.add(vector.getObject(i).toLocalTime()); + } + root.close(); + } + } + assertNoMemoryLeaks(rs); + } + + List expected = + new ArrayList() { + { + add(LocalTime.of(11, 32, 54, 130 * 1000 * 1000)); + add(LocalTime.of(8, 11, 25, 910 * 1000 * 1000)); + } + }; + + assertEquals(2, totalRows); + assertTrue(values.containsAll(expected)); + } + + @Test + public void testTimeMicroBatch() throws Exception, SFException { + int totalRows = 0; + List values = new ArrayList<>(); + + try (Statement statement = connection.createStatement(); + ResultSet rs = + statement.executeQuery( + "select '11:32:54.139901'::TIME(6) union select '8:11:25.911765'::TIME(6)")) { + ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); + + while (batches.hasNext()) { + ArrowBatch batch = batches.next(); + List roots = batch.fetch(); + for (VectorSchemaRoot root : roots) { + totalRows += 
root.getRowCount(); + assertTrue(root.getVector(0) instanceof TimeMicroVector); + TimeMicroVector vector = (TimeMicroVector) root.getVector(0); + for (int i = 0; i < root.getRowCount(); i++) { + values.add(LocalTime.ofNanoOfDay(vector.get(i) * 1000)); + } + root.close(); + } + } + assertNoMemoryLeaks(rs); + } + + List expected = + new ArrayList() { + { + add(LocalTime.of(11, 32, 54, 139901 * 1000)); + add(LocalTime.of(8, 11, 25, 911765 * 1000)); + } + }; + + assertEquals(2, totalRows); + assertTrue(values.containsAll(expected)); + } + + @Test + public void testTimeNanoBatch() throws Exception, SFException { + int totalRows = 0; + List values = new ArrayList<>(); + + try (Statement statement = connection.createStatement(); + ResultSet rs = + statement.executeQuery( + "select '11:32:54.1399013'::TIME(7) union select '8:11:25.9117654'::TIME(7)")) { + ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); + + while (batches.hasNext()) { + ArrowBatch batch = batches.next(); + List roots = batch.fetch(); + for (VectorSchemaRoot root : roots) { + totalRows += root.getRowCount(); + assertTrue(root.getVector(0) instanceof TimeNanoVector); + TimeNanoVector vector = (TimeNanoVector) root.getVector(0); + for (int i = 0; i < root.getRowCount(); i++) { + values.add(LocalTime.ofNanoOfDay(vector.get(i))); + } + root.close(); + } + } + assertNoMemoryLeaks(rs); + } + + List expected = + new ArrayList() { + { + add(LocalTime.of(11, 32, 54, 139901300)); + add(LocalTime.of(8, 11, 25, 911765400)); + } + }; + + assertEquals(2, totalRows); + assertTrue(values.containsAll(expected)); + } + + @Test + public void testVarCharBatch() throws Exception { + int totalRows = 0; + List values = new ArrayList<>(); + + try (Statement statement = connection.createStatement(); + ResultSet rs = + statement.executeQuery( + "select 'Gallia est ' union select 'omnis divisa ' union select 'in partes tres';")) { + ArrowBatches batches = 
rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); + + while (batches.hasNext()) { + ArrowBatch batch = batches.next(); + List roots = batch.fetch(); + for (VectorSchemaRoot root : roots) { + totalRows += root.getRowCount(); + assertTrue(root.getVector(0) instanceof VarCharVector); + VarCharVector vector = (VarCharVector) root.getVector(0); + for (int i = 0; i < root.getRowCount(); i++) { + values.add(vector.getObject(i)); + } + root.close(); + } + } + assertNoMemoryLeaks(rs); + } + + List expected = + new ArrayList() { + { + add(new Text("Gallia est ")); + add(new Text("omnis divisa ")); + add(new Text("in partes tres")); + } + }; + + assertTrue(values.containsAll(expected)); + + assertEquals(3, totalRows); + } + + private class Pair { + private final A first; + private final B second; + + Pair(A first, B second) { + this.first = first; + this.second = second; + } + + @Override + public boolean equals(Object obj) { + if (!(obj instanceof Pair)) { + return false; + } + Pair other = (Pair) obj; + return first.equals(other.first) && second.equals(other.second); + } + } + + @Test + public void testStructBatch() throws Exception { + int totalRows = 0; + List> values = new ArrayList<>(); + + try (Statement statement = connection.createStatement(); + ResultSet rs = + statement.executeQuery( + "select {'a': 3.1, 'b': 3.2}::object(a decimal(18, 3), b decimal(18, 3))" + + " union select {'a': 2.2, 'b': 2.3}::object(a decimal(18, 3), b decimal(18, 3))")) { + ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); + + while (batches.hasNext()) { + ArrowBatch batch = batches.next(); + List roots = batch.fetch(); + for (VectorSchemaRoot root : roots) { + totalRows += root.getRowCount(); + assertTrue(root.getVector(0) instanceof StructVector); + StructVector vector = (StructVector) root.getVector(0); + DecimalVector aVector = (DecimalVector) vector.getChild("a"); + DecimalVector bVector = (DecimalVector) vector.getChild("b"); + for (int i = 0; i < 
root.getRowCount(); i++) { + values.add(new Pair<>(aVector.getObject(i), bVector.getObject(i))); + } + root.close(); + } + } + assertNoMemoryLeaks(rs); + } + + List> expected = + new ArrayList>() { + { + add(new Pair<>(new BigDecimal("3.100"), new BigDecimal("3.200"))); + add(new Pair<>(new BigDecimal("2.200"), new BigDecimal("2.300"))); + } + }; + + assertTrue(values.containsAll(expected)); + + assertEquals(2, totalRows); + } + + @Test + public void testListBatch() throws Exception { + int totalRows = 0; + List> values = new ArrayList<>(); + + try (Statement statement = connection.createStatement(); + ResultSet rs = + statement.executeQuery( + "select array_construct(1.2, 2.3)::array(decimal(18, 3)) union all select array_construct(2.1, 1.0)::array(decimal(18, 3))")) { + ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); + + while (batches.hasNext()) { + ArrowBatch batch = batches.next(); + List roots = batch.fetch(); + for (VectorSchemaRoot root : roots) { + totalRows += root.getRowCount(); + assertTrue(root.getVector(0) instanceof ListVector); + ListVector vector = (ListVector) root.getVector(0); + for (int i = 0; i < root.getRowCount(); i++) { + values.add((List) vector.getObject(i)); + } + root.close(); + } + } + assertNoMemoryLeaks(rs); + } + + List> expected = + new ArrayList>() { + { + add( + new ArrayList() { + { + add(new BigDecimal("1.200")); + add(new BigDecimal("2.300")); + } + }); + add( + new ArrayList() { + { + add(new BigDecimal("2.100")); + add(new BigDecimal("1.000")); + } + }); + } + }; + + assertTrue(expected.containsAll(values)); + + assertEquals(2, totalRows); + } + + @Test + public void testMapBatch() throws Exception { + int totalRows = 0; + List> values = new ArrayList<>(); + + try (Statement statement = connection.createStatement(); + ResultSet rs = + statement.executeQuery( + "select {'a': 3.1, 'b': 4.3}::map(varchar, decimal(18,3)) union" + + " select {'c': 2.2, 'd': 1.5}::map(varchar, decimal(18,3))")) { + 
ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); + + while (batches.hasNext()) { + ArrowBatch batch = batches.next(); + List roots = batch.fetch(); + for (VectorSchemaRoot root : roots) { + totalRows += root.getRowCount(); + assertTrue(root.getVector(0) instanceof MapVector); + MapVector vector = (MapVector) root.getVector(0); + VarCharVector keyVector = + (VarCharVector) vector.getChildrenFromFields().get(0).getChildrenFromFields().get(0); + DecimalVector valueVector = + (DecimalVector) vector.getChildrenFromFields().get(0).getChildrenFromFields().get(1); + for (int i = 0; i < root.getRowCount(); i++) { + int startIndex = vector.getElementStartIndex(i); + int endIndex = vector.getElementEndIndex(i); + Map map = new HashMap<>(); + for (int j = startIndex; j < endIndex; j++) { + map.put(keyVector.getObject(j), valueVector.getObject(j)); + } + values.add(map); + } + root.close(); + } + } + } + + // All expected values are present + List> expected = + Stream.of( + new HashMap() { + { + put(new Text("a"), new BigDecimal("3.100")); + put(new Text("b"), new BigDecimal("4.300")); + } + }, + new HashMap() { + { + put(new Text("c"), new BigDecimal("2.200")); + put(new Text("d"), new BigDecimal("1.500")); + } + }) + .collect(Collectors.toList()); + + assertTrue(values.containsAll(expected)); + + assertEquals(2, totalRows); + } + + @Test + public void testFixedSizeListBatch() throws Exception { + int totalRows = 0; + List> values = new ArrayList<>(); + + try (Statement statement = connection.createStatement(); + ResultSet rs = + statement.executeQuery( + "select [1, 2]::vector(int, 2) union all select [3, 4]::vector(int, 2)")) { + ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); + + while (batches.hasNext()) { + ArrowBatch batch = batches.next(); + List roots = batch.fetch(); + for (VectorSchemaRoot root : roots) { + totalRows += root.getRowCount(); + assertTrue(root.getVector(0) instanceof FixedSizeListVector); + 
FixedSizeListVector vector = (FixedSizeListVector) root.getVector(0); + for (int i = 0; i < root.getRowCount(); i++) { + values.add((List) vector.getObject(i)); + } + root.close(); + } + } + assertNoMemoryLeaks(rs); + } + + List> expected = + new ArrayList>() { + { + add( + new ArrayList() { + { + add(1); + add(2); + } + }); + add( + new ArrayList() { + { + add(3); + add(4); + } + }); + } + }; + + assertTrue(expected.containsAll(values)); + + assertEquals(2, totalRows); + } + + private void testTimestampCase(String query) throws Exception, SFException { + Timestamp tsFromBatch; + Timestamp tsFromRow; + + try (Statement statement = connection.createStatement()) { + try (ResultSet rs = statement.executeQuery(query)) { + ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); + + ArrowBatch batch = batches.next(); + VectorSchemaRoot root = batch.fetch().get(0); + assertTrue(root.getVector(0) instanceof StructVector); + ArrowVectorConverter converter = batch.getTimestampConverter(root.getVector(0), 1); + tsFromBatch = converter.toTimestamp(0, null); + root.close(); + assertNoMemoryLeaks(rs); + } + try (ResultSet rs = statement.executeQuery(query)) { + rs.next(); + tsFromRow = rs.getTimestamp(1); + } + } + assertTrue(tsFromBatch.equals(tsFromRow)); + } + + private void testTimestampBase(String query) throws Exception, SFException { + testTimestampCase(query); + testTimestampCase(query + "(0)"); + testTimestampCase(query + "(1)"); + } + + @Test + public void testTimestampTZBatch() throws Exception, SFException { + testTimestampBase("select '2020-04-05 12:22:12+0700'::TIMESTAMP_TZ"); + } + + @Test + public void testTimestampLTZUseSessionTimezoneBatch() throws Exception, SFException { + Statement statement = connection.createStatement(); + statement.execute("alter session set JDBC_USE_SESSION_TIMEZONE=true"); + testTimestampBase("select '2020-04-05 12:22:12'::TIMESTAMP_LTZ"); + statement.execute("alter session unset JDBC_USE_SESSION_TIMEZONE"); + } + 
+  /**
+   * Verifies that a TIMESTAMP_LTZ value read through the Arrow batch API equals the value read
+   * row by row, for the default precision and for precisions 0 and 1.
+   */
+  @Test
+  public void testTimestampLTZBatch() throws Exception, SFException {
+    testTimestampBase("select '2020-04-05 12:22:12'::TIMESTAMP_LTZ");
+  }
+
+  /**
+   * Verifies that a TIMESTAMP_NTZ value read through the Arrow batch API equals the value read
+   * row by row, for the default precision and for precisions 0 and 1.
+   */
+  @Test
+  public void testTimestampNTZBatch() throws Exception, SFException {
+    testTimestampBase("select '2020-04-05 12:22:12'::TIMESTAMP_NTZ");
+  }
+
+  /**
+   * Repeats the batch-versus-row timestamp comparison with the session parameter
+   * CLIENT_HONOR_CLIENT_TZ_FOR_TIMESTAMP_NTZ set to false.
+   *
+   * <p>The statement is closed and the parameter is unset in all cases, including assertion
+   * failures, so the setting cannot leak into other tests that reuse the shared connection.
+   */
+  @Test
+  public void testTimestampNTZDontHonorClientTimezone() throws Exception, SFException {
+    try (Statement statement = connection.createStatement()) {
+      statement.execute("alter session set CLIENT_HONOR_CLIENT_TZ_FOR_TIMESTAMP_NTZ=false");
+      try {
+        // NOTE(review): the test name targets TIMESTAMP_NTZ but the query casts to TIMESTAMP_LTZ,
+        // which duplicates testTimestampLTZBatch - confirm whether ::TIMESTAMP_NTZ was intended.
+        testTimestampBase("select '2020-04-05 12:22:12'::TIMESTAMP_LTZ");
+      } finally {
+        // Always restore the session default, even if the comparison above throws.
+        statement.execute("alter session unset CLIENT_HONOR_CLIENT_TZ_FOR_TIMESTAMP_NTZ");
+      }
+    }
+  }
+}