From 8c209d2a2b0673e43abf00bc01a892229dfe8f71 Mon Sep 17 00:00:00 2001 From: Antoni Stachowski Date: Wed, 16 Oct 2024 16:46:15 +0200 Subject: [PATCH] Arrow batches structured types (#1879) --- .../client/core/SFArrowResultSet.java | 6 +- .../AbstractFullVectorConverter.java | 23 + .../fullvectorconverters/ArrowErrorCode.java | 7 + .../ArrowFullVectorConverter.java | 2 +- .../ArrowFullVectorConverterUtil.java | 77 ++- .../FixedSizeListVectorConverter.java | 70 +++ .../ListVectorConverter.java | 77 +++ .../MapVectorConverter.java | 36 ++ .../SFArrowException.java | 23 + .../SimpleArrowFullVectorConverter.java | 5 +- .../StructVectorConverter.java | 83 +++ .../TimeVectorConverter.java | 30 +- .../VarCharVectorConverter.java | 41 ++ .../net/snowflake/client/jdbc/ArrowBatch.java | 3 +- .../client/jdbc/ArrowResultChunk.java | 37 +- .../client/core/arrow/ArrowBatchesTest.java | 55 ++ .../snowflake/client/jdbc/ArrowBatchesIT.java | 528 ++++++++++++++---- 17 files changed, 921 insertions(+), 182 deletions(-) create mode 100644 src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/AbstractFullVectorConverter.java create mode 100644 src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/ArrowErrorCode.java create mode 100644 src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/FixedSizeListVectorConverter.java create mode 100644 src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/ListVectorConverter.java create mode 100644 src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/MapVectorConverter.java create mode 100644 src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/SFArrowException.java create mode 100644 src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/StructVectorConverter.java create mode 100644 src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/VarCharVectorConverter.java create mode 100644 src/test/java/net/snowflake/client/core/arrow/ArrowBatchesTest.java diff --git a/src/main/java/net/snowflake/client/core/SFArrowResultSet.java b/src/main/java/net/snowflake/client/core/SFArrowResultSet.java index 9c278a873..f14e74e5d 100644 --- a/src/main/java/net/snowflake/client/core/SFArrowResultSet.java +++ b/src/main/java/net/snowflake/client/core/SFArrowResultSet.java @@ -816,11 +816,13 @@ public ArrowBatch next() throws SQLException { firstFetched = true; return currentChunkIterator .getChunk() - .getArrowBatch(SFArrowResultSet.this, useSessionTimezone ? sessionTimeZone : null); + .getArrowBatch( + SFArrowResultSet.this, useSessionTimezone ? sessionTimeZone : null, nextChunkIndex); } else { nextChunkIndex++; return fetchNextChunk() - .getArrowBatch(SFArrowResultSet.this, useSessionTimezone ? sessionTimeZone : null); + .getArrowBatch( + SFArrowResultSet.this, useSessionTimezone ? sessionTimeZone : null, nextChunkIndex); } } } diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/AbstractFullVectorConverter.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/AbstractFullVectorConverter.java new file mode 100644 index 000000000..f2cf5867e --- /dev/null +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/AbstractFullVectorConverter.java @@ -0,0 +1,23 @@ +package net.snowflake.client.core.arrow.fullvectorconverters; + +import net.snowflake.client.core.SFException; +import net.snowflake.client.jdbc.SnowflakeSQLException; +import org.apache.arrow.vector.FieldVector; + +public abstract class AbstractFullVectorConverter implements ArrowFullVectorConverter { + private boolean converted; + + protected abstract FieldVector convertVector() + throws SFException, SnowflakeSQLException, SFArrowException; + + @Override + public FieldVector convert() throws SFException, SnowflakeSQLException, SFArrowException { + if (converted) { + throw new SFArrowException( + ArrowErrorCode.VECTOR_ALREADY_CONVERTED, "Convert has already been called"); + } else { + converted = true; + return convertVector(); + } + } +} diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/ArrowErrorCode.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/ArrowErrorCode.java new file mode 100644 index 000000000..b0af45dfb --- /dev/null +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/ArrowErrorCode.java @@ -0,0 +1,7 @@ +package net.snowflake.client.core.arrow.fullvectorconverters; + +public enum ArrowErrorCode { + VECTOR_ALREADY_CONVERTED, + CONVERT_FAILED, + CHUNK_FETCH_FAILED, +} diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/ArrowFullVectorConverter.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/ArrowFullVectorConverter.java index 929dcdc1e..29dc8143f 100644 --- a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/ArrowFullVectorConverter.java +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/ArrowFullVectorConverter.java @@ -7,5 +7,5 @@ @SnowflakeJdbcInternalApi public interface ArrowFullVectorConverter { - FieldVector convert() throws SFException, SnowflakeSQLException; + FieldVector convert() throws SFException, SnowflakeSQLException, SFArrowException; } diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/ArrowFullVectorConverterUtil.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/ArrowFullVectorConverterUtil.java index 6b5c3d9c1..2f2654b21 100644 --- a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/ArrowFullVectorConverterUtil.java +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/ArrowFullVectorConverterUtil.java @@ -4,15 +4,12 @@ import java.util.Map; import java.util.TimeZone; - import net.snowflake.client.core.DataConversionContext; import net.snowflake.client.core.SFBaseSession; import net.snowflake.client.core.SFException; -import net.snowflake.client.jdbc.ErrorCode; import net.snowflake.client.jdbc.SnowflakeSQLException; import net.snowflake.client.jdbc.SnowflakeSQLLoggedException; import net.snowflake.client.jdbc.SnowflakeType; -import net.snowflake.common.core.SqlState; import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.ValueVector; @@ -21,7 +18,8 @@ public class ArrowFullVectorConverterUtil { private ArrowFullVectorConverterUtil() {} - static Types.MinorType deduceType(ValueVector vector, SFBaseSession session) throws SnowflakeSQLLoggedException { + static Types.MinorType deduceType(ValueVector vector, SFBaseSession session) + throws SnowflakeSQLLoggedException { Types.MinorType type = Types.getMinorTypeForArrowType(vector.getField().getType()); // each column's metadata Map customMeta = vector.getField().getMetadata(); @@ -39,23 +37,25 @@ static Types.MinorType deduceType(ValueVector vector, SFBaseSession session) thr } break; } + case VECTOR: + return Types.MinorType.FIXED_SIZE_LIST; case TIME: - { - String scaleStr = vector.getField().getMetadata().get("scale"); - int sfScale = Integer.parseInt(scaleStr); - if (sfScale == 0) { - return Types.MinorType.TIMESEC; - } - if (sfScale <= 3) { - return Types.MinorType.TIMEMILLI; - } - if (sfScale <= 6) { - return Types.MinorType.TIMEMICRO; - } - if (sfScale <= 9) { - return Types.MinorType.TIMENANO; + { + String scaleStr = vector.getField().getMetadata().get("scale"); + int sfScale = Integer.parseInt(scaleStr); + if (sfScale == 0) { + return Types.MinorType.TIMESEC; + } + if (sfScale <= 3) { + return Types.MinorType.TIMEMILLI; + } + if (sfScale <= 6) { + return Types.MinorType.TIMEMICRO; + } + if (sfScale <= 9) { + return Types.MinorType.TIMENANO; + } } - } case TIMESTAMP_NTZ: return Types.MinorType.TIMESTAMPNANO; case TIMESTAMP_LTZ: @@ -74,7 +74,7 @@ public static FieldVector convert( TimeZone timeZoneToUse, int idx, Object targetType) - throws SnowflakeSQLException { + throws SFArrowException { try { if (targetType == null) { targetType = deduceType(vector, session); @@ -99,7 +99,7 @@ public static FieldVector convert( return new BinaryVectorConverter(allocator, vector, context, session, idx).convert(); case DATEDAY: return new DateVectorConverter(allocator, vector, context, session, idx, timeZoneToUse) - .convert(); + .convert(); case TIMESEC: return new TimeSecVectorConverter(allocator, vector).convert(); case TIMEMILLI: @@ -108,18 +108,35 @@ public static FieldVector convert( return new TimeMicroVectorConverter(allocator, vector).convert(); case TIMENANO: return new TimeNanoVectorConverter(allocator, vector).convert(); + case STRUCT: + return new StructVectorConverter( + allocator, vector, context, session, timeZoneToUse, idx, null) + .convert(); + case LIST: + return new ListVectorConverter( + allocator, vector, context, session, timeZoneToUse, idx, null) + .convert(); + case VARCHAR: + return new VarCharVectorConverter(allocator, vector, context, session, idx).convert(); + case MAP: + return new MapVectorConverter( + allocator, vector, context, session, timeZoneToUse, idx, null) + .convert(); + case FIXED_SIZE_LIST: + return new FixedSizeListVectorConverter( + allocator, vector, context, session, timeZoneToUse, idx, null) + .convert(); default: - throw new SnowflakeSQLLoggedException( - session, - ErrorCode.INTERNAL_ERROR.getMessageCode(), - SqlState.INTERNAL_ERROR, - "Unsupported target type"); + throw new SFArrowException( + ArrowErrorCode.CONVERT_FAILED, + "Unexpected arrow type " + targetType + " at index " + idx); } } - } catch (SFException ex) { - throw new SnowflakeSQLException( - ex.getCause(), ex.getSqlState(), ex.getVendorCode(), ex.getParams()); + } catch (SnowflakeSQLException | SFException | SFArrowException e) { + throw new SFArrowException( + ArrowErrorCode.CONVERT_FAILED, "Converting vector at index " + idx + " failed", e); } - return null; + throw new SFArrowException( + ArrowErrorCode.CONVERT_FAILED, "Converting vector at index " + idx + " failed"); } } diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/FixedSizeListVectorConverter.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/FixedSizeListVectorConverter.java new file mode 100644 index 000000000..30d9f9f77 --- /dev/null +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/FixedSizeListVectorConverter.java @@ -0,0 +1,70 @@ +package net.snowflake.client.core.arrow.fullvectorconverters; + +import java.util.ArrayList; +import java.util.TimeZone; +import net.snowflake.client.core.DataConversionContext; +import net.snowflake.client.core.SFBaseSession; +import net.snowflake.client.core.SFException; +import net.snowflake.client.core.SnowflakeJdbcInternalApi; +import net.snowflake.client.jdbc.SnowflakeSQLException; +import org.apache.arrow.memory.ArrowBuf; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.ValueVector; +import org.apache.arrow.vector.complex.FixedSizeListVector; +import org.apache.arrow.vector.types.pojo.Field; + +@SnowflakeJdbcInternalApi +public class FixedSizeListVectorConverter extends AbstractFullVectorConverter { + protected RootAllocator allocator; + protected ValueVector vector; + protected DataConversionContext context; + protected SFBaseSession session; + protected int idx; + protected Object valueTargetType; + private TimeZone timeZoneToUse; + + FixedSizeListVectorConverter( + RootAllocator allocator, + ValueVector vector, + DataConversionContext context, + SFBaseSession session, + TimeZone timeZoneToUse, + int idx, + Object valueTargetType) { + this.allocator = allocator; + this.vector = vector; + this.context = context; + this.session = session; + this.timeZoneToUse = timeZoneToUse; + this.idx = idx; + this.valueTargetType = valueTargetType; + } + + @Override + protected FieldVector convertVector() + throws SFException, SnowflakeSQLException, SFArrowException { + try { + FixedSizeListVector listVector = (FixedSizeListVector) vector; + FieldVector dataVector = listVector.getDataVector(); + FieldVector convertedDataVector = + ArrowFullVectorConverterUtil.convert( + allocator, dataVector, context, session, timeZoneToUse, 0, valueTargetType); + FixedSizeListVector convertedListVector = + FixedSizeListVector.empty(listVector.getName(), listVector.getListSize(), allocator); + ArrayList fields = new ArrayList<>(); + fields.add(convertedDataVector.getField()); + convertedListVector.initializeChildrenFromFields(fields); + convertedListVector.allocateNew(); + convertedListVector.setValueCount(listVector.getValueCount()); + ArrowBuf validityBuffer = listVector.getValidityBuffer(); + convertedListVector + .getValidityBuffer() + .setBytes(0L, validityBuffer, 0L, validityBuffer.capacity()); + convertedDataVector.makeTransferPair(convertedListVector.getDataVector()).transfer(); + return convertedListVector; + } finally { + vector.close(); + } + } +} diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/ListVectorConverter.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/ListVectorConverter.java new file mode 100644 index 000000000..bf61245b4 --- /dev/null +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/ListVectorConverter.java @@ -0,0 +1,77 @@ +package net.snowflake.client.core.arrow.fullvectorconverters; + +import java.util.ArrayList; +import java.util.TimeZone; +import net.snowflake.client.core.DataConversionContext; +import net.snowflake.client.core.SFBaseSession; +import net.snowflake.client.core.SFException; +import net.snowflake.client.core.SnowflakeJdbcInternalApi; +import net.snowflake.client.jdbc.SnowflakeSQLException; +import org.apache.arrow.memory.ArrowBuf; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.ValueVector; +import org.apache.arrow.vector.complex.ListVector; +import org.apache.arrow.vector.types.pojo.Field; + +@SnowflakeJdbcInternalApi +public class ListVectorConverter extends AbstractFullVectorConverter { + protected RootAllocator allocator; + protected ValueVector vector; + protected DataConversionContext context; + protected SFBaseSession session; + protected int idx; + protected Object valueTargetType; + private TimeZone timeZoneToUse; + + ListVectorConverter( + RootAllocator allocator, + ValueVector vector, + DataConversionContext context, + SFBaseSession session, + TimeZone timeZoneToUse, + int idx, + Object valueTargetType) { + this.allocator = allocator; + this.vector = vector; + this.context = context; + this.session = session; + this.timeZoneToUse = timeZoneToUse; + this.idx = idx; + this.valueTargetType = valueTargetType; + } + + protected ListVector initVector(String name, Field field) { + ListVector convertedListVector = ListVector.empty(name, allocator); + ArrayList fields = new ArrayList<>(); + fields.add(field); + convertedListVector.initializeChildrenFromFields(fields); + return convertedListVector; + } + + @Override + protected FieldVector convertVector() + throws SFException, SnowflakeSQLException, SFArrowException { + try { + ListVector listVector = (ListVector) vector; + FieldVector dataVector = listVector.getDataVector(); + FieldVector convertedDataVector = + ArrowFullVectorConverterUtil.convert( + allocator, dataVector, context, session, timeZoneToUse, 0, valueTargetType); + // TODO: change to convertedDataVector and make all necessary changes to make it work + ListVector convertedListVector = initVector(vector.getName(), dataVector.getField()); + convertedListVector.allocateNew(); + convertedListVector.setValueCount(listVector.getValueCount()); + convertedListVector.getOffsetBuffer().setBytes(0, listVector.getOffsetBuffer()); + ArrowBuf validityBuffer = listVector.getValidityBuffer(); + convertedListVector + .getValidityBuffer() + .setBytes(0L, validityBuffer, 0L, validityBuffer.capacity()); + convertedListVector.setLastSet(listVector.getLastSet()); + convertedDataVector.makeTransferPair(convertedListVector.getDataVector()).transfer(); + return convertedListVector; + } finally { + vector.close(); + } + } +} diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/MapVectorConverter.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/MapVectorConverter.java new file mode 100644 index 000000000..0b8ec963e --- /dev/null +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/MapVectorConverter.java @@ -0,0 +1,36 @@ +package net.snowflake.client.core.arrow.fullvectorconverters; + +import java.util.ArrayList; +import java.util.TimeZone; +import net.snowflake.client.core.DataConversionContext; +import net.snowflake.client.core.SFBaseSession; +import net.snowflake.client.core.SnowflakeJdbcInternalApi; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.ValueVector; +import org.apache.arrow.vector.complex.ListVector; +import org.apache.arrow.vector.complex.MapVector; +import org.apache.arrow.vector.types.pojo.Field; + +@SnowflakeJdbcInternalApi +public class MapVectorConverter extends ListVectorConverter { + + MapVectorConverter( + RootAllocator allocator, + ValueVector vector, + DataConversionContext context, + SFBaseSession session, + TimeZone timeZoneToUse, + int idx, + Object valueTargetType) { + super(allocator, vector, context, session, timeZoneToUse, idx, valueTargetType); + } + + @Override + protected ListVector initVector(String name, Field field) { + MapVector convertedMapVector = MapVector.empty(name, allocator, false); + ArrayList fields = new ArrayList<>(); + fields.add(field); + convertedMapVector.initializeChildrenFromFields(fields); + return convertedMapVector; + } +} diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/SFArrowException.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/SFArrowException.java new file mode 100644 index 000000000..953e4004c --- /dev/null +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/SFArrowException.java @@ -0,0 +1,23 @@ +package net.snowflake.client.core.arrow.fullvectorconverters; + +public class SFArrowException extends Exception { + private final ArrowErrorCode errorCode; + + public SFArrowException(ArrowErrorCode errorCode, String message) { + this(errorCode, message, null); + } + + public SFArrowException(ArrowErrorCode errorCode, String message, Throwable cause) { + super(message, cause); + this.errorCode = errorCode; + } + + public ArrowErrorCode getErrorCode() { + return errorCode; + } + + @Override + public String toString() { + return super.toString() + (getErrorCode() != null ? ", errorCode = " + getErrorCode() : ""); + } +} diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/SimpleArrowFullVectorConverter.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/SimpleArrowFullVectorConverter.java index f2d5c1d27..8f1a8c69a 100644 --- a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/SimpleArrowFullVectorConverter.java +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/SimpleArrowFullVectorConverter.java @@ -13,7 +13,7 @@ @SnowflakeJdbcInternalApi public abstract class SimpleArrowFullVectorConverter - implements ArrowFullVectorConverter { + extends AbstractFullVectorConverter { protected RootAllocator allocator; protected ValueVector vector; @@ -42,7 +42,8 @@ public SimpleArrowFullVectorConverter( protected void additionalConverterInit(ArrowVectorConverter converter) {} - public FieldVector convert() throws SFException, SnowflakeSQLException { + protected FieldVector convertVector() + throws SFException, SnowflakeSQLException, SFArrowException { if (matchingType()) { return (FieldVector) vector; } diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/StructVectorConverter.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/StructVectorConverter.java new file mode 100644 index 000000000..2dbbe4dec --- /dev/null +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/StructVectorConverter.java @@ -0,0 +1,83 @@ +package net.snowflake.client.core.arrow.fullvectorconverters; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.TimeZone; +import java.util.stream.Collectors; +import net.snowflake.client.core.DataConversionContext; +import net.snowflake.client.core.SFBaseSession; +import net.snowflake.client.core.SFException; +import net.snowflake.client.core.SnowflakeJdbcInternalApi; +import net.snowflake.client.jdbc.SnowflakeSQLException; +import org.apache.arrow.memory.ArrowBuf; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.ValueVector; +import org.apache.arrow.vector.complex.StructVector; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.util.TransferPair; + +@SnowflakeJdbcInternalApi +public class StructVectorConverter extends AbstractFullVectorConverter { + protected RootAllocator allocator; + protected ValueVector vector; + protected DataConversionContext context; + protected SFBaseSession session; + protected int idx; + protected Map targetTypes; + private TimeZone timeZoneToUse; + + StructVectorConverter( + RootAllocator allocator, + ValueVector vector, + DataConversionContext context, + SFBaseSession session, + TimeZone timeZoneToUse, + int idx, + Map targetTypes) { + this.allocator = allocator; + this.vector = vector; + this.context = context; + this.session = session; + this.timeZoneToUse = timeZoneToUse; + this.idx = idx; + this.targetTypes = targetTypes; + } + + protected FieldVector convertVector() + throws SFException, SnowflakeSQLException, SFArrowException { + try { + StructVector structVector = (StructVector) vector; + List childVectors = structVector.getChildrenFromFields(); + List convertedVectors = new ArrayList<>(); + for (FieldVector childVector : childVectors) { + Object targetType = null; + if (targetTypes != null) { + targetType = targetTypes.get(childVector.getName()); + } + convertedVectors.add( + ArrowFullVectorConverterUtil.convert( + allocator, childVector, context, session, timeZoneToUse, idx, targetType)); + } + + List convertedFields = + convertedVectors.stream().map(ValueVector::getField).collect(Collectors.toList()); + StructVector converted = StructVector.empty(vector.getName(), allocator); + converted.allocateNew(); + converted.initializeChildrenFromFields(convertedFields); + for (FieldVector convertedVector : convertedVectors) { + TransferPair transferPair = + convertedVector.makeTransferPair(converted.getChild(convertedVector.getName())); + transferPair.transfer(); + } + ArrowBuf validityBuffer = structVector.getValidityBuffer(); + converted.getValidityBuffer().setBytes(0L, validityBuffer, 0L, validityBuffer.capacity()); + converted.setValueCount(vector.getValueCount()); + + return converted; + } finally { + vector.close(); + } + } +} diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TimeVectorConverter.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TimeVectorConverter.java index baba5931a..60d3bae3f 100644 --- a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TimeVectorConverter.java +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/TimeVectorConverter.java @@ -12,7 +12,7 @@ @SnowflakeJdbcInternalApi public abstract class TimeVectorConverter - implements ArrowFullVectorConverter { + extends AbstractFullVectorConverter { protected RootAllocator allocator; protected ValueVector vector; @@ -28,18 +28,22 @@ public TimeVectorConverter(RootAllocator allocator, ValueVector vector) { protected abstract int targetScale(); @Override - public FieldVector convert() throws SFException, SnowflakeSQLException { - int size = vector.getValueCount(); - T converted = initVector(); - converted.allocateNew(size); - BaseIntVector srcVector = (BaseIntVector) vector; - int scale = Integer.parseInt(vector.getField().getMetadata().get("scale")); - long scalingFactor = ArrowResultUtil.powerOfTen(targetScale() - scale); - for (int i = 0; i < size; i++) { - convertValue(converted, i, srcVector.getValueAsLong(i) * scalingFactor); + protected FieldVector convertVector() + throws SFException, SnowflakeSQLException, SFArrowException { + try { + int size = vector.getValueCount(); + T converted = initVector(); + converted.allocateNew(size); + BaseIntVector srcVector = (BaseIntVector) vector; + int scale = Integer.parseInt(vector.getField().getMetadata().get("scale")); + long scalingFactor = ArrowResultUtil.powerOfTen(targetScale() - scale); + for (int i = 0; i < size; i++) { + convertValue(converted, i, srcVector.getValueAsLong(i) * scalingFactor); + } + converted.setValueCount(size); + return converted; + } finally { + vector.close(); } - converted.setValueCount(size); - vector.close(); - return converted; } } diff --git a/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/VarCharVectorConverter.java b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/VarCharVectorConverter.java new file mode 100644 index 000000000..8898d4498 --- /dev/null +++ b/src/main/java/net/snowflake/client/core/arrow/fullvectorconverters/VarCharVectorConverter.java @@ -0,0 +1,41 @@ +package net.snowflake.client.core.arrow.fullvectorconverters; + +import net.snowflake.client.core.DataConversionContext; +import net.snowflake.client.core.SFBaseSession; +import net.snowflake.client.core.SFException; +import net.snowflake.client.core.SnowflakeJdbcInternalApi; +import net.snowflake.client.core.arrow.ArrowVectorConverter; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.ValueVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.util.Text; + +@SnowflakeJdbcInternalApi +public class VarCharVectorConverter extends SimpleArrowFullVectorConverter { + public VarCharVectorConverter( + RootAllocator allocator, + ValueVector vector, + DataConversionContext context, + SFBaseSession session, + int idx) { + super(allocator, vector, context, session, idx); + } + + @Override + protected boolean matchingType() { + return (vector instanceof VarCharVector); + } + + @Override + protected VarCharVector initVector() { + VarCharVector resultVector = new VarCharVector(vector.getName(), allocator); + resultVector.allocateNew(vector.getValueCount()); + return resultVector; + } + + @Override + protected void convertValue(ArrowVectorConverter from, VarCharVector to, int idx) + throws SFException { + to.set(idx, new Text(from.toString(idx))); + } +} diff --git a/src/main/java/net/snowflake/client/jdbc/ArrowBatch.java b/src/main/java/net/snowflake/client/jdbc/ArrowBatch.java index c9dd11c12..d823fd112 100644 --- a/src/main/java/net/snowflake/client/jdbc/ArrowBatch.java +++ b/src/main/java/net/snowflake/client/jdbc/ArrowBatch.java @@ -2,11 +2,12 @@ import java.util.List; import net.snowflake.client.core.arrow.ArrowVectorConverter; +import net.snowflake.client.core.arrow.fullvectorconverters.SFArrowException; import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.VectorSchemaRoot; public interface ArrowBatch { - List fetch() throws SnowflakeSQLException; + List fetch() throws SnowflakeSQLException, SFArrowException; ArrowVectorConverter getTimestampConverter(FieldVector vector, int colIdx); diff --git a/src/main/java/net/snowflake/client/jdbc/ArrowResultChunk.java b/src/main/java/net/snowflake/client/jdbc/ArrowResultChunk.java index 1d66cb38e..c080a2f36 100644 --- a/src/main/java/net/snowflake/client/jdbc/ArrowResultChunk.java +++ b/src/main/java/net/snowflake/client/jdbc/ArrowResultChunk.java @@ -17,7 +17,9 @@ import net.snowflake.client.core.arrow.ArrowResultChunkIndexSorter; import net.snowflake.client.core.arrow.ArrowVectorConverter; import net.snowflake.client.core.arrow.ThreeFieldStructToTimestampTZConverter; +import net.snowflake.client.core.arrow.fullvectorconverters.ArrowErrorCode; import net.snowflake.client.core.arrow.fullvectorconverters.ArrowFullVectorConverterUtil; +import net.snowflake.client.core.arrow.fullvectorconverters.SFArrowException; import net.snowflake.client.log.SFLogger; import net.snowflake.client.log.SFLoggerFactory; import net.snowflake.common.core.SqlState; @@ -501,9 +503,10 @@ private void sortFirstResultChunk(List converters) } } - public ArrowBatch getArrowBatch(DataConversionContext context, TimeZone timeZoneToUse) { + public ArrowBatch getArrowBatch( + DataConversionContext context, TimeZone timeZoneToUse, long batchIndex) { batchesMode = true; - return new ArrowResultBatch(context, timeZoneToUse); + return new ArrowResultBatch(context, timeZoneToUse, batchIndex); } private boolean sortFirstResultChunkEnabled() { @@ -533,24 +536,32 @@ public final void freeData() { public class ArrowResultBatch implements ArrowBatch { private DataConversionContext context; private TimeZone timeZoneToUse; + private long batchIndex; - ArrowResultBatch(DataConversionContext context, TimeZone timeZoneToUse) { + ArrowResultBatch(DataConversionContext context, TimeZone timeZoneToUse, long batchIndex) { this.context = context; this.timeZoneToUse = timeZoneToUse; + this.batchIndex = batchIndex; } - public List fetch() throws SnowflakeSQLException { - List result = new ArrayList<>(); - for (List record : batchOfVectors) { - List convertedVectors = new ArrayList<>(); - for (int i = 0; i < record.size(); i++) { - ValueVector vector = record.get(i); - convertedVectors.add( - ArrowFullVectorConverterUtil.convert(rootAllocator, vector, context, session, timeZoneToUse, i, null)); + public List fetch() throws SFArrowException { + try { + List result = new ArrayList<>(); + for (List record : batchOfVectors) { + List convertedVectors = new ArrayList<>(); + for (int i = 0; i < record.size(); i++) { + ValueVector vector = record.get(i); + convertedVectors.add( + ArrowFullVectorConverterUtil.convert( + rootAllocator, vector, context, session, timeZoneToUse, i, null)); + } + result.add(new VectorSchemaRoot(convertedVectors)); } - result.add(new VectorSchemaRoot(convertedVectors)); + return result; + } catch (SFArrowException e) { + throw new SFArrowException( + ArrowErrorCode.CHUNK_FETCH_FAILED, "Failed to fetch batch number " + batchIndex, e); } - return result; } @Override diff --git a/src/test/java/net/snowflake/client/core/arrow/ArrowBatchesTest.java b/src/test/java/net/snowflake/client/core/arrow/ArrowBatchesTest.java new file mode 100644 index 000000000..dfb1fe598 --- /dev/null +++ b/src/test/java/net/snowflake/client/core/arrow/ArrowBatchesTest.java @@ -0,0 +1,55 @@ +package net.snowflake.client.core.arrow; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +import net.snowflake.client.core.SFException; +import net.snowflake.client.core.arrow.fullvectorconverters.ArrowFullVectorConverterUtil; +import net.snowflake.client.core.arrow.fullvectorconverters.IntVectorConverter; +import net.snowflake.client.core.arrow.fullvectorconverters.SFArrowException; +import net.snowflake.client.jdbc.SnowflakeSQLException; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.IntervalDayVector; +import org.junit.Test; + +public class ArrowBatchesTest extends BaseConverterTest { + @Test + public void testRepeatedConvert() throws SFException, SnowflakeSQLException, SFArrowException { + RootAllocator allocator = new RootAllocator(Long.MAX_VALUE); + IntVector intVector = new IntVector("test", allocator); + intVector.allocateNew(2); + intVector.set(0, 1); + intVector.set(1, 4); + intVector.setValueCount(2); + + IntVectorConverter converter = new IntVectorConverter(allocator, intVector, this, null, 0); + IntVector convertedIntVector = (IntVector) converter.convert(); + assertEquals(convertedIntVector.getValueCount(), 2); + assertEquals(convertedIntVector.get(0), 1); + assertEquals(convertedIntVector.get(1), 4); + try { + converter.convert().clear(); + } catch (SFArrowException e) { + // should throw + return; + } + fail("Second conversion should throw"); + } + + @Test + public void testUnknownType() { + RootAllocator allocator = new RootAllocator(Long.MAX_VALUE); + // Vector of unsupported type + IntervalDayVector vector = new IntervalDayVector("test", allocator); + try { + ArrowFullVectorConverterUtil.convert(allocator, vector, this, null, null, 0, null); + } catch (SFArrowException e) { + assertTrue(e.getCause() instanceof SFArrowException); + // should throw + return; + } + fail("Should throw on unsupported type"); + } +} diff --git a/src/test/java/net/snowflake/client/jdbc/ArrowBatchesIT.java b/src/test/java/net/snowflake/client/jdbc/ArrowBatchesIT.java index 6bfeb7f44..12bcc19e0 100644 --- a/src/test/java/net/snowflake/client/jdbc/ArrowBatchesIT.java +++ b/src/test/java/net/snowflake/client/jdbc/ArrowBatchesIT.java @@ -10,7 +10,11 @@ import java.time.LocalDate; import java.time.LocalTime; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.Stream; import net.snowflake.client.category.TestCategoryArrow; import net.snowflake.client.core.SFArrowResultSet; import net.snowflake.client.core.SFException; @@ -26,7 +30,13 @@ import org.apache.arrow.vector.TimeSecVector; import org.apache.arrow.vector.TinyIntVector; import org.apache.arrow.vector.VarBinaryVector; +import org.apache.arrow.vector.VarCharVector; import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.complex.FixedSizeListVector; +import org.apache.arrow.vector.complex.ListVector; +import org.apache.arrow.vector.complex.MapVector; +import org.apache.arrow.vector.complex.StructVector; +import org.apache.arrow.vector.util.Text; import org.junit.AfterClass; import org.junit.BeforeClass; import org.junit.Test; @@ -39,6 +49,9 @@ public class ArrowBatchesIT extends BaseJDBCWithSharedConnectionIT { public static void setUp() throws Exception { try (Statement statement = connection.createStatement()) { statement.execute("alter session set jdbc_query_result_format = 'arrow'"); + statement.execute("alter session set ENABLE_STRUCTURED_TYPES_NATIVE_ARROW_FORMAT = true"); + statement.execute( + "alter session set FORCE_ENABLE_STRUCTURED_TYPES_NATIVE_ARROW_FORMAT = true"); } } @@ -46,6 +59,8 @@ public static void setUp() throws Exception { public static void tearDown() throws Exception { try (Statement statement = connection.createStatement()) { statement.execute("alter session unset jdbc_query_result_format"); + statement.execute("alter session unset ENABLE_STRUCTURED_TYPES_NATIVE_ARROW_FORMAT"); + statement.execute("alter session unset FORCE_ENABLE_STRUCTURED_TYPES_NATIVE_ARROW_FORMAT"); } } @@ -58,29 +73,32 @@ private static void assertNoMemoryLeaks(ResultSet rs) throws SQLException { @Test public void testMultipleBatches() throws Exception { - Statement statement = connection.createStatement(); - ResultSet rs = - statement.executeQuery( - "select seq1(), seq2(), seq4(), seq8() from TABLE (generator(rowcount => 300000))"); - ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); - assertEquals(batches.getRowCount(), 300000); int totalRows = 0; ArrayList allRoots = new ArrayList<>(); - while (batches.hasNext()) { - ArrowBatch batch = batches.next(); - List roots = batch.fetch(); - for (VectorSchemaRoot root : roots) { - totalRows += root.getRowCount(); - allRoots.add(root); - assertTrue(root.getVector(0) instanceof TinyIntVector); - assertTrue(root.getVector(1) instanceof SmallIntVector); - assertTrue(root.getVector(2) instanceof IntVector); - assertTrue(root.getVector(3) instanceof BigIntVector); + // Result set is not in the try-with-resources statement, as we want to check access to memory + // after its closure + // and then check the memory allocation. + ResultSet rs; + try (Statement statement = connection.createStatement()) { + rs = + statement.executeQuery( + "select seq1(), seq2(), seq4(), seq8() from TABLE (generator(rowcount => 300000))"); + ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); + assertEquals(batches.getRowCount(), 300000); + while (batches.hasNext()) { + ArrowBatch batch = batches.next(); + List roots = batch.fetch(); + for (VectorSchemaRoot root : roots) { + totalRows += root.getRowCount(); + allRoots.add(root); + assertTrue(root.getVector(0) instanceof TinyIntVector); + assertTrue(root.getVector(1) instanceof SmallIntVector); + assertTrue(root.getVector(2) instanceof IntVector); + assertTrue(root.getVector(3) instanceof BigIntVector); + } } } - rs.close(); - // The memory should not be freed when closing the result set. for (VectorSchemaRoot root : allRoots) { assertTrue(root.getVector(0).getValueCount() > 0); @@ -92,28 +110,28 @@ public void testMultipleBatches() throws Exception { @Test public void testTinyIntBatch() throws Exception { - Statement statement = connection.createStatement(); - ResultSet rs = statement.executeQuery("select 1 union select 2 union select 3;"); - ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); - int totalRows = 0; List values = new ArrayList<>(); - while (batches.hasNext()) { - ArrowBatch batch = batches.next(); - List roots = batch.fetch(); - for (VectorSchemaRoot root : roots) { - totalRows += root.getRowCount(); - assertTrue(root.getVector(0) instanceof TinyIntVector); - TinyIntVector vector = (TinyIntVector) root.getVector(0); - for (int i = 0; i < root.getRowCount(); i++) { - values.add(vector.get(i)); + try (Statement statement = connection.createStatement(); + ResultSet rs = statement.executeQuery("select 1 union select 2 union select 3;")) { + ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); + + while (batches.hasNext()) { + ArrowBatch batch = batches.next(); + List roots = batch.fetch(); + for (VectorSchemaRoot root : roots) { + totalRows += root.getRowCount(); + assertTrue(root.getVector(0) instanceof TinyIntVector); + TinyIntVector vector = (TinyIntVector) root.getVector(0); + for (int i = 0; i < root.getRowCount(); i++) { + values.add(vector.get(i)); + } + root.close(); } - root.close(); } + assertNoMemoryLeaks(rs); } - assertNoMemoryLeaks(rs); - rs.close(); // All expected values are present for (byte i = 1; i < 4; i++) { @@ -125,28 +143,27 @@ public void testTinyIntBatch() throws Exception { @Test public void testSmallIntBatch() throws Exception { - Statement statement = connection.createStatement(); - ResultSet rs = statement.executeQuery("select 129 union select 130 union select 131;"); - ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); - int totalRows = 0; List values = new ArrayList<>(); + try (Statement statement = connection.createStatement(); + ResultSet rs = statement.executeQuery("select 129 union select 130 union select 131;")) { + ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); - while (batches.hasNext()) { - ArrowBatch batch = batches.next(); - List roots = batch.fetch(); - for (VectorSchemaRoot root : roots) { - totalRows += root.getRowCount(); - assertTrue(root.getVector(0) instanceof SmallIntVector); - SmallIntVector vector = (SmallIntVector) root.getVector(0); - for (int i = 0; i < root.getRowCount(); i++) { - values.add(vector.get(i)); + while (batches.hasNext()) { + ArrowBatch batch = batches.next(); + List roots = batch.fetch(); + for (VectorSchemaRoot root : roots) { + totalRows += root.getRowCount(); + assertTrue(root.getVector(0) instanceof SmallIntVector); + SmallIntVector vector = (SmallIntVector) root.getVector(0); + for (int i = 0; i < root.getRowCount(); i++) { + values.add(vector.get(i)); + } + root.close(); } - root.close(); } + assertNoMemoryLeaks(rs); } - assertNoMemoryLeaks(rs); - rs.close(); // All expected values are present for (short i = 129; i < 132; i++) { @@ -158,28 +175,29 @@ public void testSmallIntBatch() throws Exception { @Test public void testIntBatch() throws Exception { - Statement statement = connection.createStatement(); - ResultSet rs = statement.executeQuery("select 100000 union select 100001 union select 100002;"); - ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); - int totalRows = 0; List values = new ArrayList<>(); - while (batches.hasNext()) { - ArrowBatch batch = batches.next(); - List roots = batch.fetch(); - for (VectorSchemaRoot root : roots) { - totalRows += root.getRowCount(); - assertTrue(root.getVector(0) instanceof IntVector); - IntVector vector = (IntVector) root.getVector(0); - for (int i = 0; i < root.getRowCount(); i++) { - values.add(vector.get(i)); + try (Statement statement = connection.createStatement(); + ResultSet rs = + statement.executeQuery("select 100000 union select 100001 union select 100002;")) { + ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); + + while (batches.hasNext()) { + ArrowBatch batch = batches.next(); + List roots = batch.fetch(); + for (VectorSchemaRoot root : roots) { + totalRows += root.getRowCount(); + assertTrue(root.getVector(0) instanceof IntVector); + IntVector vector = (IntVector) root.getVector(0); + for (int i = 0; i < root.getRowCount(); i++) { + values.add(vector.get(i)); + } + root.close(); } - root.close(); } + assertNoMemoryLeaks(rs); } - assertNoMemoryLeaks(rs); - rs.close(); // All expected values are present for (int i = 100000; i < 100003; i++) { @@ -191,30 +209,30 @@ public void testIntBatch() throws Exception { @Test public void testBigIntBatch() throws Exception { - Statement statement = connection.createStatement(); - ResultSet rs = - statement.executeQuery( - "select 10000000000 union select 10000000001 union select 10000000002;"); - ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); - int totalRows = 0; List values = new ArrayList<>(); - while (batches.hasNext()) { - ArrowBatch batch = batches.next(); - List roots = batch.fetch(); - for (VectorSchemaRoot root : roots) { - totalRows += root.getRowCount(); - assertTrue(root.getVector(0) instanceof BigIntVector); - BigIntVector vector = (BigIntVector) root.getVector(0); - for (int i = 0; i < root.getRowCount(); i++) { - values.add(vector.get(i)); + try (Statement statement = connection.createStatement(); + ResultSet rs = + statement.executeQuery( + "select 10000000000 union select 10000000001 union select 10000000002;")) { + ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); + + while (batches.hasNext()) { + ArrowBatch batch = batches.next(); + List roots = batch.fetch(); + for (VectorSchemaRoot root : roots) { + totalRows += root.getRowCount(); + assertTrue(root.getVector(0) instanceof BigIntVector); + BigIntVector vector = (BigIntVector) root.getVector(0); + for (int i = 0; i < root.getRowCount(); i++) { + values.add(vector.get(i)); + } + root.close(); } - root.close(); } + assertNoMemoryLeaks(rs); } - assertNoMemoryLeaks(rs); - rs.close(); // All expected values are present for (long i = 10000000000L; i < 10000000003L; i++) { @@ -226,28 +244,28 @@ public void testBigIntBatch() throws Exception { @Test public void testDecimalBatch() throws Exception { - Statement statement = connection.createStatement(); - ResultSet rs = statement.executeQuery("select 1.1 union select 1.2 union select 1.3;"); - ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); - int totalRows = 0; List values = new ArrayList<>(); - while (batches.hasNext()) { - ArrowBatch batch = batches.next(); - List roots = batch.fetch(); - for (VectorSchemaRoot root : roots) { - totalRows += root.getRowCount(); - assertTrue(root.getVector(0) instanceof DecimalVector); - DecimalVector vector = (DecimalVector) root.getVector(0); - for (int i = 0; i < root.getRowCount(); i++) { - values.add(vector.getObject(i)); + try (Statement statement = connection.createStatement(); + ResultSet rs = statement.executeQuery("select 1.1 union select 1.2 union select 1.3;")) { + ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); + + while (batches.hasNext()) { + ArrowBatch batch = batches.next(); + List roots = batch.fetch(); + for (VectorSchemaRoot root : roots) { + totalRows += root.getRowCount(); + assertTrue(root.getVector(0) instanceof DecimalVector); + DecimalVector vector = (DecimalVector) root.getVector(0); + for (int i = 0; i < root.getRowCount(); i++) { + values.add(vector.getObject(i)); + } + root.close(); } - root.close(); } + assertNoMemoryLeaks(rs); } - assertNoMemoryLeaks(rs); - rs.close(); // All expected values are present for (int i = 1; i < 4; i++) { @@ -264,9 +282,9 @@ public void testBitBatch() throws Exception { try (Statement statement = connection.createStatement(); ResultSet rs = - statement.executeQuery( - "select true union all select false union all select true union all select false" - + " union all select true union all select false union all select true")) { + statement.executeQuery( + "select true union all select false union all select true union all select false" + + " union all select true union all select false union all select true")) { ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); while (batches.hasNext()) { @@ -297,9 +315,10 @@ public void testBinaryBatch() throws Exception { int totalRows = 0; List> values = new ArrayList<>(); - try(Statement statement = connection.createStatement(); + try (Statement statement = connection.createStatement(); ResultSet rs = - statement.executeQuery("select TO_BINARY('546AB0') union select TO_BINARY('018E3271')")) { + statement.executeQuery( + "select TO_BINARY('546AB0') union select TO_BINARY('018E3271')")) { ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); while (batches.hasNext()) { @@ -312,13 +331,13 @@ public void testBinaryBatch() throws Exception { for (int i = 0; i < root.getRowCount(); i++) { byte[] bytes = vector.getObject(i); ArrayList byteArrayList = - new ArrayList() { - { - for (byte aByte : bytes) { - add(aByte); - } - } - }; + new ArrayList() { + { + for (byte aByte : bytes) { + add(aByte); + } + } + }; values.add(byteArrayList); } root.close(); @@ -361,7 +380,7 @@ public void testDateBatch() throws Exception, SFException { try (Statement statement = connection.createStatement(); ResultSet rs = - statement.executeQuery("select '1119-02-01'::DATE union select '2021-09-11'::DATE")) { + statement.executeQuery("select '1119-02-01'::DATE union select '2021-09-11'::DATE")) { ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); while (batches.hasNext()) { @@ -399,7 +418,7 @@ public void testTimeSecBatch() throws Exception, SFException { try (Statement statement = connection.createStatement(); ResultSet rs = - statement.executeQuery("select '11:32:54'::TIME(0) union select '8:11:25'::TIME(0)")) { + statement.executeQuery("select '11:32:54'::TIME(0) union select '8:11:25'::TIME(0)")) { ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); while (batches.hasNext()) { @@ -436,8 +455,9 @@ public void testTimeMilliBatch() throws Exception, SFException { List values = new ArrayList<>(); try (Statement statement = connection.createStatement(); - ResultSet rs = - statement.executeQuery("select '11:32:54.13'::TIME(2) union select '8:11:25.91'::TIME(2)")) { + ResultSet rs = + statement.executeQuery( + "select '11:32:54.13'::TIME(2) union select '8:11:25.91'::TIME(2)")) { ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); while (batches.hasNext()) { @@ -474,8 +494,9 @@ public void testTimeMicroBatch() throws Exception, SFException { List values = new ArrayList<>(); try (Statement statement = connection.createStatement(); - ResultSet rs = - statement.executeQuery("select '11:32:54.139901'::TIME(6) union select '8:11:25.911765'::TIME(6)")) { + ResultSet rs = + statement.executeQuery( + "select '11:32:54.139901'::TIME(6) union select '8:11:25.911765'::TIME(6)")) { ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); while (batches.hasNext()) { @@ -512,8 +533,9 @@ public void testTimeNanoBatch() throws Exception, SFException { List values = new ArrayList<>(); try (Statement statement = connection.createStatement(); - ResultSet rs = - statement.executeQuery("select '11:32:54.1399013'::TIME(7) union select '8:11:25.9117654'::TIME(7)")) { + ResultSet rs = + statement.executeQuery( + "select '11:32:54.1399013'::TIME(7) union select '8:11:25.9117654'::TIME(7)")) { ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); while (batches.hasNext()) { @@ -543,4 +565,270 @@ public void testTimeNanoBatch() throws Exception, SFException { assertEquals(2, totalRows); assertTrue(values.containsAll(expected)); } + + @Test + public void testVarCharBatch() throws Exception { + int totalRows = 0; + List values = new ArrayList<>(); + + try (Statement statement = connection.createStatement(); + ResultSet rs = + statement.executeQuery( + "select 'Gallia est ' union select 'omnis divisa ' union select 'in partes tres';")) { + ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); + + while (batches.hasNext()) { + ArrowBatch batch = batches.next(); + List roots = batch.fetch(); + for (VectorSchemaRoot root : roots) { + totalRows += root.getRowCount(); + assertTrue(root.getVector(0) instanceof VarCharVector); + VarCharVector vector = (VarCharVector) root.getVector(0); + for (int i = 0; i < root.getRowCount(); i++) { + values.add(vector.getObject(i)); + } + root.close(); + } + } + assertNoMemoryLeaks(rs); + } + + List expected = + new ArrayList() { + { + add(new Text("Gallia est ")); + add(new Text("omnis divisa ")); + add(new Text("in partes tres")); + } + }; + + assertTrue(values.containsAll(expected)); + + assertEquals(3, totalRows); + } + + private class Pair { + private final A first; + private final B second; + + Pair(A first, B second) { + this.first = first; + this.second = second; + } + + @Override + public boolean equals(Object obj) { + if (!(obj instanceof Pair)) { + return false; + } + Pair other = (Pair) obj; + return first.equals(other.first) && second.equals(other.second); + } + } + + @Test + public void testStructBatch() throws Exception { + int totalRows = 0; + List> values = new ArrayList<>(); + + try (Statement statement = connection.createStatement(); + ResultSet rs = + statement.executeQuery( + "select {'a': 3.1, 'b': 3.2}::object(a decimal(18, 3), b decimal(18, 3))" + + " union select {'a': 2.2, 'b': 2.3}::object(a decimal(18, 3), b decimal(18, 3))")) { + ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); + + while (batches.hasNext()) { + ArrowBatch batch = batches.next(); + List roots = batch.fetch(); + for (VectorSchemaRoot root : roots) { + totalRows += root.getRowCount(); + assertTrue(root.getVector(0) instanceof StructVector); + StructVector vector = (StructVector) root.getVector(0); + DecimalVector aVector = (DecimalVector) vector.getChild("a"); + DecimalVector bVector = (DecimalVector) vector.getChild("b"); + for (int i = 0; i < root.getRowCount(); i++) { + values.add(new Pair<>(aVector.getObject(i), bVector.getObject(i))); + } + root.close(); + } + } + assertNoMemoryLeaks(rs); + } + + List> expected = + new ArrayList>() { + { + add(new Pair<>(new BigDecimal("3.100"), new BigDecimal("3.200"))); + add(new Pair<>(new BigDecimal("2.200"), new BigDecimal("2.300"))); + } + }; + + assertTrue(values.containsAll(expected)); + + assertEquals(2, totalRows); + } + + @Test + public void testListBatch() throws Exception { + int totalRows = 0; + List> values = new ArrayList<>(); + + try (Statement statement = connection.createStatement(); + ResultSet rs = + statement.executeQuery( + "select array_construct(1.2, 2.3)::array(decimal(18, 3)) union all select array_construct(2.1, 1.0)::array(decimal(18, 3))")) { + ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); + + while (batches.hasNext()) { + ArrowBatch batch = batches.next(); + List roots = batch.fetch(); + for (VectorSchemaRoot root : roots) { + totalRows += root.getRowCount(); + assertTrue(root.getVector(0) instanceof ListVector); + ListVector vector = (ListVector) root.getVector(0); + for (int i = 0; i < root.getRowCount(); i++) { + values.add((List) vector.getObject(i)); + } + root.close(); + } + } + assertNoMemoryLeaks(rs); + } + + List> expected = + new ArrayList>() { + { + add( + new ArrayList() { + { + add(new BigDecimal("1.200")); + add(new BigDecimal("2.300")); + } + }); + add( + new ArrayList() { + { + add(new BigDecimal("2.100")); + add(new BigDecimal("1.000")); + } + }); + } + }; + + assertTrue(expected.containsAll(values)); + + assertEquals(2, totalRows); + } + + @Test + public void testMapBatch() throws Exception { + int totalRows = 0; + List> values = new ArrayList<>(); + + try (Statement statement = connection.createStatement(); + ResultSet rs = + statement.executeQuery( + "select {'a': 3.1, 'b': 4.3}::map(varchar, decimal(18,3)) union" + + " select {'c': 2.2, 'd': 1.5}::map(varchar, decimal(18,3))")) { + ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); + + while (batches.hasNext()) { + ArrowBatch batch = batches.next(); + List roots = batch.fetch(); + for (VectorSchemaRoot root : roots) { + totalRows += root.getRowCount(); + assertTrue(root.getVector(0) instanceof MapVector); + MapVector vector = (MapVector) root.getVector(0); + VarCharVector keyVector = + (VarCharVector) vector.getChildrenFromFields().get(0).getChildrenFromFields().get(0); + DecimalVector valueVector = + (DecimalVector) vector.getChildrenFromFields().get(0).getChildrenFromFields().get(1); + for (int i = 0; i < root.getRowCount(); i++) { + int startIndex = vector.getElementStartIndex(i); + int endIndex = vector.getElementEndIndex(i); + Map map = new HashMap<>(); + for (int j = startIndex; j < endIndex; j++) { + map.put(keyVector.getObject(j), valueVector.getObject(j)); + } + values.add(map); + } + root.close(); + } + } + } + + // All expected values are present + List> expected = + Stream.of( + new HashMap() { + { + put(new Text("a"), new BigDecimal("3.100")); + put(new Text("b"), new BigDecimal("4.300")); + } + }, + new HashMap() { + { + put(new Text("c"), new BigDecimal("2.200")); + put(new Text("d"), new BigDecimal("1.500")); + } + }) + .collect(Collectors.toList()); + + assertTrue(values.containsAll(expected)); + + assertEquals(2, totalRows); + } + + @Test + public void testFixedSizeListBatch() throws Exception { + int totalRows = 0; + List> values = new ArrayList<>(); + + try (Statement statement = connection.createStatement(); + ResultSet rs = + statement.executeQuery( + "select [1, 2]::vector(int, 2) union all select [3, 4]::vector(int, 2)")) { + ArrowBatches batches = rs.unwrap(SnowflakeResultSet.class).getArrowBatches(); + + while (batches.hasNext()) { + ArrowBatch batch = batches.next(); + List roots = batch.fetch(); + for (VectorSchemaRoot root : roots) { + totalRows += root.getRowCount(); + assertTrue(root.getVector(0) instanceof FixedSizeListVector); + FixedSizeListVector vector = (FixedSizeListVector) root.getVector(0); + for (int i = 0; i < root.getRowCount(); i++) { + values.add((List) vector.getObject(i)); + } + root.close(); + } + } + assertNoMemoryLeaks(rs); + } + + List> expected = + new ArrayList>() { + { + add( + new ArrayList() { + { + add(1); + add(2); + } + }); + add( + new ArrayList() { + { + add(3); + add(4); + } + }); + } + }; + + assertTrue(expected.containsAll(values)); + + assertEquals(2, totalRows); + } }