Skip to content

Commit

Permalink
Arrow batches structured types (#1879)
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-astachowski authored Oct 16, 2024
1 parent d11a12b commit 8c209d2
Show file tree
Hide file tree
Showing 17 changed files with 921 additions and 182 deletions.
6 changes: 4 additions & 2 deletions src/main/java/net/snowflake/client/core/SFArrowResultSet.java
Original file line number Diff line number Diff line change
Expand Up @@ -816,11 +816,13 @@ public ArrowBatch next() throws SQLException {
firstFetched = true;
return currentChunkIterator
.getChunk()
.getArrowBatch(SFArrowResultSet.this, useSessionTimezone ? sessionTimeZone : null);
.getArrowBatch(
SFArrowResultSet.this, useSessionTimezone ? sessionTimeZone : null, nextChunkIndex);
} else {
nextChunkIndex++;
return fetchNextChunk()
.getArrowBatch(SFArrowResultSet.this, useSessionTimezone ? sessionTimeZone : null);
.getArrowBatch(
SFArrowResultSet.this, useSessionTimezone ? sessionTimeZone : null, nextChunkIndex);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package net.snowflake.client.core.arrow.fullvectorconverters;

import net.snowflake.client.core.SFException;
import net.snowflake.client.jdbc.SnowflakeSQLException;
import org.apache.arrow.vector.FieldVector;

public abstract class AbstractFullVectorConverter implements ArrowFullVectorConverter {
private boolean converted;

protected abstract FieldVector convertVector()
throws SFException, SnowflakeSQLException, SFArrowException;

@Override
public FieldVector convert() throws SFException, SnowflakeSQLException, SFArrowException {
if (converted) {
throw new SFArrowException(
ArrowErrorCode.VECTOR_ALREADY_CONVERTED, "Convert has already been called");
} else {
converted = true;
return convertVector();
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package net.snowflake.client.core.arrow.fullvectorconverters;

public enum ArrowErrorCode {
VECTOR_ALREADY_CONVERTED,
CONVERT_FAILED,
CHUNK_FETCH_FAILED,
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@

@SnowflakeJdbcInternalApi
public interface ArrowFullVectorConverter {
FieldVector convert() throws SFException, SnowflakeSQLException;
FieldVector convert() throws SFException, SnowflakeSQLException, SFArrowException;
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,12 @@

import java.util.Map;
import java.util.TimeZone;

import net.snowflake.client.core.DataConversionContext;
import net.snowflake.client.core.SFBaseSession;
import net.snowflake.client.core.SFException;
import net.snowflake.client.jdbc.ErrorCode;
import net.snowflake.client.jdbc.SnowflakeSQLException;
import net.snowflake.client.jdbc.SnowflakeSQLLoggedException;
import net.snowflake.client.jdbc.SnowflakeType;
import net.snowflake.common.core.SqlState;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.ValueVector;
Expand All @@ -21,7 +18,8 @@
public class ArrowFullVectorConverterUtil {
private ArrowFullVectorConverterUtil() {}

static Types.MinorType deduceType(ValueVector vector, SFBaseSession session) throws SnowflakeSQLLoggedException {
static Types.MinorType deduceType(ValueVector vector, SFBaseSession session)
throws SnowflakeSQLLoggedException {
Types.MinorType type = Types.getMinorTypeForArrowType(vector.getField().getType());
// each column's metadata
Map<String, String> customMeta = vector.getField().getMetadata();
Expand All @@ -39,23 +37,25 @@ static Types.MinorType deduceType(ValueVector vector, SFBaseSession session) thr
}
break;
}
case VECTOR:
return Types.MinorType.FIXED_SIZE_LIST;
case TIME:
{
String scaleStr = vector.getField().getMetadata().get("scale");
int sfScale = Integer.parseInt(scaleStr);
if (sfScale == 0) {
return Types.MinorType.TIMESEC;
}
if (sfScale <= 3) {
return Types.MinorType.TIMEMILLI;
}
if (sfScale <= 6) {
return Types.MinorType.TIMEMICRO;
}
if (sfScale <= 9) {
return Types.MinorType.TIMENANO;
{
String scaleStr = vector.getField().getMetadata().get("scale");
int sfScale = Integer.parseInt(scaleStr);
if (sfScale == 0) {
return Types.MinorType.TIMESEC;
}
if (sfScale <= 3) {
return Types.MinorType.TIMEMILLI;
}
if (sfScale <= 6) {
return Types.MinorType.TIMEMICRO;
}
if (sfScale <= 9) {
return Types.MinorType.TIMENANO;
}
}
}
case TIMESTAMP_NTZ:
return Types.MinorType.TIMESTAMPNANO;
case TIMESTAMP_LTZ:
Expand All @@ -74,7 +74,7 @@ public static FieldVector convert(
TimeZone timeZoneToUse,
int idx,
Object targetType)
throws SnowflakeSQLException {
throws SFArrowException {
try {
if (targetType == null) {
targetType = deduceType(vector, session);
Expand All @@ -99,7 +99,7 @@ public static FieldVector convert(
return new BinaryVectorConverter(allocator, vector, context, session, idx).convert();
case DATEDAY:
return new DateVectorConverter(allocator, vector, context, session, idx, timeZoneToUse)
.convert();
.convert();
case TIMESEC:
return new TimeSecVectorConverter(allocator, vector).convert();
case TIMEMILLI:
Expand All @@ -108,18 +108,35 @@ public static FieldVector convert(
return new TimeMicroVectorConverter(allocator, vector).convert();
case TIMENANO:
return new TimeNanoVectorConverter(allocator, vector).convert();
case STRUCT:
return new StructVectorConverter(
allocator, vector, context, session, timeZoneToUse, idx, null)
.convert();
case LIST:
return new ListVectorConverter(
allocator, vector, context, session, timeZoneToUse, idx, null)
.convert();
case VARCHAR:
return new VarCharVectorConverter(allocator, vector, context, session, idx).convert();
case MAP:
return new MapVectorConverter(
allocator, vector, context, session, timeZoneToUse, idx, null)
.convert();
case FIXED_SIZE_LIST:
return new FixedSizeListVectorConverter(
allocator, vector, context, session, timeZoneToUse, idx, null)
.convert();
default:
throw new SnowflakeSQLLoggedException(
session,
ErrorCode.INTERNAL_ERROR.getMessageCode(),
SqlState.INTERNAL_ERROR,
"Unsupported target type");
throw new SFArrowException(
ArrowErrorCode.CONVERT_FAILED,
"Unexpected arrow type " + targetType + " at index " + idx);
}
}
} catch (SFException ex) {
throw new SnowflakeSQLException(
ex.getCause(), ex.getSqlState(), ex.getVendorCode(), ex.getParams());
} catch (SnowflakeSQLException | SFException | SFArrowException e) {
throw new SFArrowException(
ArrowErrorCode.CONVERT_FAILED, "Converting vector at index " + idx + " failed", e);
}
return null;
throw new SFArrowException(
ArrowErrorCode.CONVERT_FAILED, "Converting vector at index " + idx + " failed");
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
package net.snowflake.client.core.arrow.fullvectorconverters;

import java.util.ArrayList;
import java.util.TimeZone;
import net.snowflake.client.core.DataConversionContext;
import net.snowflake.client.core.SFBaseSession;
import net.snowflake.client.core.SFException;
import net.snowflake.client.core.SnowflakeJdbcInternalApi;
import net.snowflake.client.jdbc.SnowflakeSQLException;
import org.apache.arrow.memory.ArrowBuf;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.ValueVector;
import org.apache.arrow.vector.complex.FixedSizeListVector;
import org.apache.arrow.vector.types.pojo.Field;

@SnowflakeJdbcInternalApi
public class FixedSizeListVectorConverter extends AbstractFullVectorConverter {
protected RootAllocator allocator;
protected ValueVector vector;
protected DataConversionContext context;
protected SFBaseSession session;
protected int idx;
protected Object valueTargetType;
private TimeZone timeZoneToUse;

FixedSizeListVectorConverter(
RootAllocator allocator,
ValueVector vector,
DataConversionContext context,
SFBaseSession session,
TimeZone timeZoneToUse,
int idx,
Object valueTargetType) {
this.allocator = allocator;
this.vector = vector;
this.context = context;
this.session = session;
this.timeZoneToUse = timeZoneToUse;
this.idx = idx;
this.valueTargetType = valueTargetType;
}

@Override
protected FieldVector convertVector()
throws SFException, SnowflakeSQLException, SFArrowException {
try {
FixedSizeListVector listVector = (FixedSizeListVector) vector;
FieldVector dataVector = listVector.getDataVector();
FieldVector convertedDataVector =
ArrowFullVectorConverterUtil.convert(
allocator, dataVector, context, session, timeZoneToUse, 0, valueTargetType);
FixedSizeListVector convertedListVector =
FixedSizeListVector.empty(listVector.getName(), listVector.getListSize(), allocator);
ArrayList<Field> fields = new ArrayList<>();
fields.add(convertedDataVector.getField());
convertedListVector.initializeChildrenFromFields(fields);
convertedListVector.allocateNew();
convertedListVector.setValueCount(listVector.getValueCount());
ArrowBuf validityBuffer = listVector.getValidityBuffer();
convertedListVector
.getValidityBuffer()
.setBytes(0L, validityBuffer, 0L, validityBuffer.capacity());
convertedDataVector.makeTransferPair(convertedListVector.getDataVector()).transfer();
return convertedListVector;
} finally {
vector.close();
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package net.snowflake.client.core.arrow.fullvectorconverters;

import java.util.ArrayList;
import java.util.TimeZone;
import net.snowflake.client.core.DataConversionContext;
import net.snowflake.client.core.SFBaseSession;
import net.snowflake.client.core.SFException;
import net.snowflake.client.core.SnowflakeJdbcInternalApi;
import net.snowflake.client.jdbc.SnowflakeSQLException;
import org.apache.arrow.memory.ArrowBuf;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.ValueVector;
import org.apache.arrow.vector.complex.ListVector;
import org.apache.arrow.vector.types.pojo.Field;

@SnowflakeJdbcInternalApi
public class ListVectorConverter extends AbstractFullVectorConverter {
protected RootAllocator allocator;
protected ValueVector vector;
protected DataConversionContext context;
protected SFBaseSession session;
protected int idx;
protected Object valueTargetType;
private TimeZone timeZoneToUse;

ListVectorConverter(
RootAllocator allocator,
ValueVector vector,
DataConversionContext context,
SFBaseSession session,
TimeZone timeZoneToUse,
int idx,
Object valueTargetType) {
this.allocator = allocator;
this.vector = vector;
this.context = context;
this.session = session;
this.timeZoneToUse = timeZoneToUse;
this.idx = idx;
this.valueTargetType = valueTargetType;
}

protected ListVector initVector(String name, Field field) {
ListVector convertedListVector = ListVector.empty(name, allocator);
ArrayList<Field> fields = new ArrayList<>();
fields.add(field);
convertedListVector.initializeChildrenFromFields(fields);
return convertedListVector;
}

@Override
protected FieldVector convertVector()
throws SFException, SnowflakeSQLException, SFArrowException {
try {
ListVector listVector = (ListVector) vector;
FieldVector dataVector = listVector.getDataVector();
FieldVector convertedDataVector =
ArrowFullVectorConverterUtil.convert(
allocator, dataVector, context, session, timeZoneToUse, 0, valueTargetType);
// TODO: change to convertedDataVector and make all necessary changes to make it work
ListVector convertedListVector = initVector(vector.getName(), dataVector.getField());
convertedListVector.allocateNew();
convertedListVector.setValueCount(listVector.getValueCount());
convertedListVector.getOffsetBuffer().setBytes(0, listVector.getOffsetBuffer());
ArrowBuf validityBuffer = listVector.getValidityBuffer();
convertedListVector
.getValidityBuffer()
.setBytes(0L, validityBuffer, 0L, validityBuffer.capacity());
convertedListVector.setLastSet(listVector.getLastSet());
convertedDataVector.makeTransferPair(convertedListVector.getDataVector()).transfer();
return convertedListVector;
} finally {
vector.close();
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package net.snowflake.client.core.arrow.fullvectorconverters;

import java.util.ArrayList;
import java.util.TimeZone;
import net.snowflake.client.core.DataConversionContext;
import net.snowflake.client.core.SFBaseSession;
import net.snowflake.client.core.SnowflakeJdbcInternalApi;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.ValueVector;
import org.apache.arrow.vector.complex.ListVector;
import org.apache.arrow.vector.complex.MapVector;
import org.apache.arrow.vector.types.pojo.Field;

@SnowflakeJdbcInternalApi
public class MapVectorConverter extends ListVectorConverter {

MapVectorConverter(
RootAllocator allocator,
ValueVector vector,
DataConversionContext context,
SFBaseSession session,
TimeZone timeZoneToUse,
int idx,
Object valueTargetType) {
super(allocator, vector, context, session, timeZoneToUse, idx, valueTargetType);
}

@Override
protected ListVector initVector(String name, Field field) {
MapVector convertedMapVector = MapVector.empty(name, allocator, false);
ArrayList<Field> fields = new ArrayList<>();
fields.add(field);
convertedMapVector.initializeChildrenFromFields(fields);
return convertedMapVector;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package net.snowflake.client.core.arrow.fullvectorconverters;

public class SFArrowException extends Exception {
private final ArrowErrorCode errorCode;

public SFArrowException(ArrowErrorCode errorCode, String message) {
this(errorCode, message, null);
}

public SFArrowException(ArrowErrorCode errorCode, String message, Throwable cause) {
super(message, cause);
this.errorCode = errorCode;
}

public ArrowErrorCode getErrorCode() {
return errorCode;
}

@Override
public String toString() {
return super.toString() + (getErrorCode() != null ? ", errorCode = " + getErrorCode() : "");
}
}
Loading

0 comments on commit 8c209d2

Please sign in to comment.