Skip to content

Commit

Permalink
Arrow batches all simple types (#1883)
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-astachowski authored Sep 20, 2024
1 parent 7d93e03 commit d11a12b
Show file tree
Hide file tree
Showing 15 changed files with 703 additions and 54 deletions.
7 changes: 5 additions & 2 deletions src/main/java/net/snowflake/client/core/SFArrowResultSet.java
Original file line number Diff line number Diff line change
Expand Up @@ -814,10 +814,13 @@ public boolean hasNext() {
public ArrowBatch next() throws SQLException {
if (!firstFetched) {
firstFetched = true;
return currentChunkIterator.getChunk().getArrowBatch(SFArrowResultSet.this);
return currentChunkIterator
.getChunk()
.getArrowBatch(SFArrowResultSet.this, useSessionTimezone ? sessionTimeZone : null);
} else {
nextChunkIndex++;
return fetchNextChunk().getArrowBatch(SFArrowResultSet.this);
return fetchNextChunk()
.getArrowBatch(SFArrowResultSet.this, useSessionTimezone ? sessionTimeZone : null);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import static net.snowflake.client.core.arrow.ArrowVectorConverterUtil.getScale;

import java.util.Map;
import java.util.TimeZone;

import net.snowflake.client.core.DataConversionContext;
import net.snowflake.client.core.SFBaseSession;
import net.snowflake.client.core.SFException;
Expand All @@ -19,8 +21,7 @@
public class ArrowFullVectorConverterUtil {
private ArrowFullVectorConverterUtil() {}

public static Types.MinorType deduceType(ValueVector vector, SFBaseSession session)
throws SnowflakeSQLLoggedException {
static Types.MinorType deduceType(ValueVector vector, SFBaseSession session) throws SnowflakeSQLLoggedException {
Types.MinorType type = Types.getMinorTypeForArrowType(vector.getField().getType());
// each column's metadata
Map<String, String> customMeta = vector.getField().getMetadata();
Expand All @@ -39,52 +40,27 @@ public static Types.MinorType deduceType(ValueVector vector, SFBaseSession sessi
break;
}
case TIME:
return Types.MinorType.TIMEMILLI;
case TIMESTAMP_LTZ:
{
int sfScale = getScale(vector, session);
switch (sfScale) {
case 0:
return Types.MinorType.TIMESTAMPSECTZ;
case 3:
return Types.MinorType.TIMESTAMPMILLITZ;
case 6:
return Types.MinorType.TIMESTAMPMICROTZ;
case 9:
return Types.MinorType.TIMESTAMPNANOTZ;
}
break;
{
String scaleStr = vector.getField().getMetadata().get("scale");
int sfScale = Integer.parseInt(scaleStr);
if (sfScale == 0) {
return Types.MinorType.TIMESEC;
}
case TIMESTAMP_TZ:
{
int sfScale = getScale(vector, session);
switch (sfScale) {
case 0:
return Types.MinorType.TIMESTAMPSECTZ;
case 3:
return Types.MinorType.TIMESTAMPMILLITZ;
case 6:
return Types.MinorType.TIMESTAMPMICROTZ;
case 9:
return Types.MinorType.TIMESTAMPNANOTZ;
}
break;
if (sfScale <= 3) {
return Types.MinorType.TIMEMILLI;
}
case TIMESTAMP_NTZ:
{
int sfScale = getScale(vector, session);
switch (sfScale) {
case 0:
return Types.MinorType.TIMESTAMPSEC;
case 3:
return Types.MinorType.TIMESTAMPMILLI;
case 6:
return Types.MinorType.TIMESTAMPMICRO;
case 9:
return Types.MinorType.TIMESTAMPNANO;
}
break;
if (sfScale <= 6) {
return Types.MinorType.TIMEMICRO;
}
if (sfScale <= 9) {
return Types.MinorType.TIMENANO;
}
}
case TIMESTAMP_NTZ:
return Types.MinorType.TIMESTAMPNANO;
case TIMESTAMP_LTZ:
case TIMESTAMP_TZ:
return Types.MinorType.TIMESTAMPNANOTZ;
}
}
return type;
Expand All @@ -95,6 +71,7 @@ public static FieldVector convert(
ValueVector vector,
DataConversionContext context,
SFBaseSession session,
TimeZone timeZoneToUse,
int idx,
Object targetType)
throws SnowflakeSQLException {
Expand All @@ -114,6 +91,23 @@ public static FieldVector convert(
return new BigIntVectorConverter(allocator, vector, context, session, idx).convert();
case DECIMAL:
return new DecimalVectorConverter(allocator, vector, context, session, idx).convert();
case FLOAT8:
return new FloatVectorConverter(allocator, vector, context, session, idx).convert();
case BIT:
return new BitVectorConverter(allocator, vector, context, session, idx).convert();
case VARBINARY:
return new BinaryVectorConverter(allocator, vector, context, session, idx).convert();
case DATEDAY:
return new DateVectorConverter(allocator, vector, context, session, idx, timeZoneToUse)
.convert();
case TIMESEC:
return new TimeSecVectorConverter(allocator, vector).convert();
case TIMEMILLI:
return new TimeMilliVectorConverter(allocator, vector).convert();
case TIMEMICRO:
return new TimeMicroVectorConverter(allocator, vector).convert();
case TIMENANO:
return new TimeNanoVectorConverter(allocator, vector).convert();
default:
throw new SnowflakeSQLLoggedException(
session,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package net.snowflake.client.core.arrow.fullvectorconverters;

import net.snowflake.client.core.DataConversionContext;
import net.snowflake.client.core.SFBaseSession;
import net.snowflake.client.core.SFException;
import net.snowflake.client.core.SnowflakeJdbcInternalApi;
import net.snowflake.client.core.arrow.ArrowVectorConverter;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.ValueVector;
import org.apache.arrow.vector.VarBinaryVector;

/**
 * Full-vector converter that produces a {@link VarBinaryVector} from a source vector.
 *
 * <p>If the source vector is already a {@link VarBinaryVector}, {@link #matchingType()} reports
 * {@code true}; otherwise every non-null value is read as a byte array through the row-level
 * {@link ArrowVectorConverter} and written into a freshly allocated vector.
 */
@SnowflakeJdbcInternalApi
public class BinaryVectorConverter extends SimpleArrowFullVectorConverter<VarBinaryVector> {

  /**
   * @param allocator allocator used for the result vector
   * @param vector source vector to convert
   * @param context data conversion context forwarded to the base converter
   * @param session session forwarded to the base converter
   * @param idx column index forwarded to the base converter
   */
  public BinaryVectorConverter(
      RootAllocator allocator,
      ValueVector vector,
      DataConversionContext context,
      SFBaseSession session,
      int idx) {
    super(allocator, vector, context, session, idx);
  }

  /** Returns {@code true} when the source vector is already a {@link VarBinaryVector}. */
  @Override
  protected boolean matchingType() {
    return vector instanceof VarBinaryVector;
  }

  /**
   * Creates the destination vector, named after the source and allocated for the source's value
   * count. The value count only sizes the per-row buffers; the total payload size is not known
   * here, which is why {@link #convertValue} writes with {@code setSafe}.
   */
  @Override
  protected VarBinaryVector initVector() {
    VarBinaryVector resultVector = new VarBinaryVector(vector.getName(), allocator);
    resultVector.allocateNew(vector.getValueCount());
    return resultVector;
  }

  /**
   * Copies the binary value at {@code idx} into the destination vector.
   *
   * <p>Uses {@code setSafe} rather than {@code set}: for a variable-width vector the data buffer
   * is not sized from the actual payload, so an unchecked {@code set} can write past its capacity
   * once the accumulated bytes exceed the initial allocation. {@code setSafe} grows the buffers
   * as needed.
   *
   * @throws SFException if the source value cannot be read as bytes
   */
  @Override
  protected void convertValue(ArrowVectorConverter from, VarBinaryVector to, int idx)
      throws SFException {
    to.setSafe(idx, from.toBytes(idx));
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package net.snowflake.client.core.arrow.fullvectorconverters;

import net.snowflake.client.core.DataConversionContext;
import net.snowflake.client.core.SFBaseSession;
import net.snowflake.client.core.SFException;
import net.snowflake.client.core.SnowflakeJdbcInternalApi;
import net.snowflake.client.core.arrow.ArrowVectorConverter;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.BitVector;
import org.apache.arrow.vector.ValueVector;

/**
 * Full-vector converter that produces a {@link BitVector} from a source vector by reading each
 * value as a boolean through the row-level {@link ArrowVectorConverter}.
 */
@SnowflakeJdbcInternalApi
public class BitVectorConverter extends SimpleArrowFullVectorConverter<BitVector> {

  /**
   * @param allocator allocator used for the result vector
   * @param vector source vector to convert
   * @param context data conversion context forwarded to the base converter
   * @param session session forwarded to the base converter
   * @param idx column index forwarded to the base converter
   */
  public BitVectorConverter(
      RootAllocator allocator,
      ValueVector vector,
      DataConversionContext context,
      SFBaseSession session,
      int idx) {
    super(allocator, vector, context, session, idx);
  }

  /** Returns {@code true} when the source vector is already a {@link BitVector}. */
  @Override
  protected boolean matchingType() {
    boolean alreadyBitVector = vector instanceof BitVector;
    return alreadyBitVector;
  }

  /** Creates the destination vector, named after the source and sized to its value count. */
  @Override
  protected BitVector initVector() {
    String name = vector.getName();
    BitVector target = new BitVector(name, allocator);
    int capacity = vector.getValueCount();
    target.allocateNew(capacity);
    return target;
  }

  /**
   * Writes the boolean at {@code idx} into the destination as a single bit (1 for true, 0 for
   * false).
   *
   * @throws SFException if the source value cannot be read as a boolean
   */
  @Override
  protected void convertValue(ArrowVectorConverter from, BitVector to, int idx) throws SFException {
    if (from.toBoolean(idx)) {
      to.set(idx, 1);
    } else {
      to.set(idx, 0);
    }
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package net.snowflake.client.core.arrow.fullvectorconverters;

import java.util.TimeZone;
import net.snowflake.client.core.DataConversionContext;
import net.snowflake.client.core.SFBaseSession;
import net.snowflake.client.core.SFException;
import net.snowflake.client.core.SnowflakeJdbcInternalApi;
import net.snowflake.client.core.arrow.ArrowVectorConverter;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.DateDayVector;
import org.apache.arrow.vector.ValueVector;

/**
 * Full-vector converter that produces a {@link DateDayVector} from a source vector, optionally
 * configuring the row-level converter with a session time zone first.
 */
@SnowflakeJdbcInternalApi
public class DateVectorConverter extends SimpleArrowFullVectorConverter<DateDayVector> {
  /** Number of milliseconds in one 24-hour day. */
  private static final int MILLIS_PER_DAY = 1000 * 3600 * 24;

  /** Session time zone to apply to the row-level converter; {@code null} means leave it as is. */
  private final TimeZone timeZone;

  /**
   * @param allocator allocator used for the result vector
   * @param vector source vector to convert
   * @param context data conversion context forwarded to the base converter
   * @param session session forwarded to the base converter
   * @param idx column index forwarded to the base converter
   * @param timeZone session time zone to set on the row-level converter, or {@code null} for none
   */
  public DateVectorConverter(
      RootAllocator allocator,
      ValueVector vector,
      DataConversionContext context,
      SFBaseSession session,
      int idx,
      TimeZone timeZone) {
    super(allocator, vector, context, session, idx);
    this.timeZone = timeZone;
  }

  /** Returns {@code true} when the source vector is already a {@link DateDayVector}. */
  @Override
  protected boolean matchingType() {
    boolean alreadyDateDay = vector instanceof DateDayVector;
    return alreadyDateDay;
  }

  /** Creates the destination vector, named after the source and sized to its value count. */
  @Override
  protected DateDayVector initVector() {
    String name = vector.getName();
    DateDayVector target = new DateDayVector(name, allocator);
    int capacity = vector.getValueCount();
    target.allocateNew(capacity);
    return target;
  }

  /**
   * When a time zone was supplied, sets it on the row-level converter and switches the converter
   * to use the session time zone; otherwise does nothing.
   */
  @Override
  protected void additionalConverterInit(ArrowVectorConverter converter) {
    if (timeZone == null) {
      return;
    }
    converter.setSessionTimeZone(timeZone);
    converter.setUseSessionTimezone(true);
  }

  /**
   * Reads the value at {@code idx} as a date and stores its day number, computed as the date's
   * epoch milliseconds divided by the milliseconds in a day.
   *
   * @throws SFException if the source value cannot be read as a date
   */
  @Override
  protected void convertValue(ArrowVectorConverter from, DateDayVector to, int idx)
      throws SFException {
    long epochMillis = from.toDate(idx, null, false).getTime();
    // NOTE(review): this division truncates toward zero, so the result is only exact when the
    // returned date sits on a UTC day boundary. If toDate yields a midnight in a non-UTC zone,
    // some dates may land on the neighbouring day — confirm against toDate's contract.
    int epochDay = (int) (epochMillis / MILLIS_PER_DAY);
    to.set(idx, epochDay);
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package net.snowflake.client.core.arrow.fullvectorconverters;

import net.snowflake.client.core.DataConversionContext;
import net.snowflake.client.core.SFBaseSession;
import net.snowflake.client.core.SFException;
import net.snowflake.client.core.SnowflakeJdbcInternalApi;
import net.snowflake.client.core.arrow.ArrowVectorConverter;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.Float8Vector;
import org.apache.arrow.vector.ValueVector;

/**
 * Full-vector converter that produces a {@link Float8Vector} from a source vector by reading each
 * value as a double through the row-level {@link ArrowVectorConverter}.
 */
@SnowflakeJdbcInternalApi
public class FloatVectorConverter extends SimpleArrowFullVectorConverter<Float8Vector> {

  /**
   * @param allocator allocator used for the result vector
   * @param vector source vector to convert
   * @param context data conversion context forwarded to the base converter
   * @param session session forwarded to the base converter
   * @param idx column index forwarded to the base converter
   */
  public FloatVectorConverter(
      RootAllocator allocator,
      ValueVector vector,
      DataConversionContext context,
      SFBaseSession session,
      int idx) {
    super(allocator, vector, context, session, idx);
  }

  /** Returns {@code true} when the source vector is already a {@link Float8Vector}. */
  @Override
  protected boolean matchingType() {
    boolean alreadyFloat8 = vector instanceof Float8Vector;
    return alreadyFloat8;
  }

  /** Creates the destination vector, named after the source and sized to its value count. */
  @Override
  protected Float8Vector initVector() {
    String name = vector.getName();
    Float8Vector target = new Float8Vector(name, allocator);
    int capacity = vector.getValueCount();
    target.allocateNew(capacity);
    return target;
  }

  /**
   * Writes the double read from the source at {@code idx} into the destination at the same index.
   *
   * @throws SFException if the source value cannot be read as a double
   */
  @Override
  protected void convertValue(ArrowVectorConverter from, Float8Vector to, int idx)
      throws SFException {
    double value = from.toDouble(idx);
    to.set(idx, value);
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ public SimpleArrowFullVectorConverter(

protected abstract void convertValue(ArrowVectorConverter from, T to, int idx) throws SFException;

protected void additionalConverterInit(ArrowVectorConverter converter) {}

public FieldVector convert() throws SFException, SnowflakeSQLException {
if (matchingType()) {
return (FieldVector) vector;
Expand All @@ -48,10 +50,14 @@ public FieldVector convert() throws SFException, SnowflakeSQLException {
T converted = initVector();
ArrowVectorConverter converter =
ArrowVectorConverterUtil.initConverter(vector, context, session, idx);
additionalConverterInit(converter);
for (int i = 0; i < size; i++) {
convertValue(converter, converted, i);
if (!vector.isNull(i)) {
convertValue(converter, converted, i);
}
}
converted.setValueCount(size);
vector.close();
return converted;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package net.snowflake.client.core.arrow.fullvectorconverters;

import net.snowflake.client.core.SnowflakeJdbcInternalApi;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.TimeMicroVector;
import org.apache.arrow.vector.ValueVector;

/**
 * {@link TimeVectorConverter} specialization whose destination is a {@link TimeMicroVector}.
 *
 * <p>NOTE(review): allocation of the destination and rescaling of values presumably happen in
 * {@code TimeVectorConverter}, which is not part of this file — confirm there.
 */
@SnowflakeJdbcInternalApi
public class TimeMicroVectorConverter extends TimeVectorConverter<TimeMicroVector> {
  /** Scale reported by {@link #targetScale()}; presumably fractional-second digits (micros). */
  private static final int TARGET_SCALE = 6;

  /**
   * @param allocator allocator used for the result vector
   * @param vector source vector to convert
   */
  public TimeMicroVectorConverter(RootAllocator allocator, ValueVector vector) {
    super(allocator, vector);
  }

  /** Creates the destination vector, named after the source; no buffers are allocated here. */
  @Override
  protected TimeMicroVector initVector() {
    String name = vector.getName();
    return new TimeMicroVector(name, allocator);
  }

  /** Stores {@code value} unchanged at position {@code idx} of the destination vector. */
  @Override
  protected void convertValue(TimeMicroVector dstVector, int idx, long value) {
    long micros = value;
    dstVector.set(idx, micros);
  }

  /** Returns the scale of the destination vector's values, which is 6. */
  @Override
  protected int targetScale() {
    return TARGET_SCALE;
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package net.snowflake.client.core.arrow.fullvectorconverters;

import net.snowflake.client.core.SnowflakeJdbcInternalApi;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.TimeMilliVector;
import org.apache.arrow.vector.ValueVector;

/**
 * {@link TimeVectorConverter} specialization whose destination is a {@link TimeMilliVector}.
 *
 * <p>NOTE(review): allocation of the destination and rescaling of values presumably happen in
 * {@code TimeVectorConverter}, which is not part of this file — confirm there.
 */
@SnowflakeJdbcInternalApi
public class TimeMilliVectorConverter extends TimeVectorConverter<TimeMilliVector> {
  /** Scale reported by {@link #targetScale()}; presumably fractional-second digits (millis). */
  private static final int TARGET_SCALE = 3;

  /**
   * @param allocator allocator used for the result vector
   * @param vector source vector to convert
   */
  public TimeMilliVectorConverter(RootAllocator allocator, ValueVector vector) {
    super(allocator, vector);
  }

  /** Creates the destination vector, named after the source; no buffers are allocated here. */
  @Override
  protected TimeMilliVector initVector() {
    String name = vector.getName();
    return new TimeMilliVector(name, allocator);
  }

  /**
   * Stores {@code value} at position {@code idx} of the destination vector, narrowed to
   * {@code int} because the destination holds 32-bit values.
   */
  @Override
  protected void convertValue(TimeMilliVector dstVector, int idx, long value) {
    int millis = (int) value;
    dstVector.set(idx, millis);
  }

  /** Returns the scale of the destination vector's values, which is 3. */
  @Override
  protected int targetScale() {
    return TARGET_SCALE;
  }
}
Loading

0 comments on commit d11a12b

Please sign in to comment.