Skip to content

Commit

Permalink
Add arrow with json structured types support
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-dheyman committed Apr 2, 2024
1 parent ec1325f commit af14680
Show file tree
Hide file tree
Showing 10 changed files with 263 additions and 213 deletions.
35 changes: 28 additions & 7 deletions src/main/java/net/snowflake/client/core/SFArrowResultSet.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import java.util.Map;
import java.util.TimeZone;
import java.util.stream.Stream;
import net.snowflake.client.core.arrow.ArrayConverter;
import net.snowflake.client.core.arrow.ArrowVectorConverter;
import net.snowflake.client.core.arrow.StructConverter;
import net.snowflake.client.core.arrow.VarCharConverter;
Expand Down Expand Up @@ -371,30 +372,44 @@ public Converters getConverters() {

@Override
@SnowflakeJdbcInternalApi
public SQLInput createSqlInputForColumn(Object input, int columnIndex, SFBaseSession session) {
return new ArrowSqlInput(
(Map<String, Object>) input,
session,
converters,
resultSetMetaData.getColumnMetadata().get(columnIndex - 1).getFields());
public SQLInput createSqlInputForColumn(
Object input, Class<?> parentObjectClass, int columnIndex, SFBaseSession session) {
if (parentObjectClass.equals(JsonSqlInput.class)) {
return createJsonSqlInputForColumn(input, columnIndex, session);
} else {
return new ArrowSqlInput(
(Map<String, Object>) input,
session,
converters,
resultSetMetaData.getColumnMetadata().get(columnIndex - 1).getFields());
}
}

@Override
@SnowflakeJdbcInternalApi
public Date convertToDate(Object object, TimeZone tz) throws SFException {
  // Structured-type DATE values arrive either as JSON strings or as Arrow
  // epoch-day integers; dispatch on the runtime representation.
  if (!(object instanceof String)) {
    return converters.getStructuredTypeDateTimeConverter().getDate((int) object, tz);
  }
  return convertStringToDate(object, tz);
}

@Override
@SnowflakeJdbcInternalApi
public Time convertToTime(Object object, int scale) throws SFException {
  // Structured-type TIME values arrive either as JSON strings or as Arrow
  // long values; dispatch on the runtime representation.
  if (!(object instanceof String)) {
    return converters.getStructuredTypeDateTimeConverter().getTime((long) object, scale);
  }
  return convertStringToTime(object, scale);
}

@Override
@SnowflakeJdbcInternalApi
public Timestamp convertToTimestamp(
Object object, int columnType, int columnSubType, TimeZone tz, int scale) throws SFException {
if (object instanceof String) {
return convertStringToTimestamp(object, columnType, columnSubType, tz, scale);
}
return converters
.getStructuredTypeDateTimeConverter()
.getTimestamp(
Expand Down Expand Up @@ -589,7 +604,13 @@ public Array getArray(int columnIndex) throws SFException {
int index = currentChunkIterator.getCurrentRowInRecordBatch();
wasNull = converter.isNull(index);
Object obj = converter.toObject(index);
return getArrayInternal((List<Object>) obj, columnIndex);
if (converter instanceof VarCharConverter) {
return getJsonArrayInternal((String) obj, columnIndex);
} else if (converter instanceof ArrayConverter) {
return getArrayInternal((List<Object>) obj, columnIndex);
} else {
throw new SFException(ErrorCode.INTERNAL_ERROR);
}
}

private SfSqlArray getArrayInternal(List<Object> elements, int columnIndex) throws SFException {
Expand Down
202 changes: 201 additions & 1 deletion src/main/java/net/snowflake/client/core/SFBaseResultSet.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,31 +4,48 @@

package net.snowflake.client.core;

import static net.snowflake.client.jdbc.SnowflakeUtil.getJsonNodeStringValue;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ArrayNode;
import java.math.BigDecimal;
import java.sql.Array;
import java.sql.Date;
import java.sql.SQLException;
import java.sql.SQLInput;
import java.sql.Time;
import java.sql.Timestamp;
import java.sql.Types;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Spliterator;
import java.util.Spliterators;
import java.util.TimeZone;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;
import net.snowflake.client.core.json.Converters;
import net.snowflake.client.jdbc.ErrorCode;
import net.snowflake.client.jdbc.FieldMetadata;
import net.snowflake.client.jdbc.SnowflakeColumnMetadata;
import net.snowflake.client.jdbc.SnowflakeResultSetSerializable;
import net.snowflake.client.jdbc.SnowflakeResultSetSerializableV1;
import net.snowflake.client.jdbc.SnowflakeSQLException;
import net.snowflake.client.log.SFLogger;
import net.snowflake.client.log.SFLoggerFactory;
import net.snowflake.client.util.Converter;
import net.snowflake.common.core.SFBinaryFormat;
import net.snowflake.common.core.SnowflakeDateTimeFormat;

/** Base class for query result set and metadata result set */
public abstract class SFBaseResultSet {
private static final SFLogger logger = SFLoggerFactory.getLogger(SFBaseResultSet.class);
protected static final ObjectMapper OBJECT_MAPPER = ObjectMapperFactory.getObjectMapper();

boolean wasNull = false;

Expand Down Expand Up @@ -219,7 +236,7 @@ public TimeZone getSessionTimeZone() {

@SnowflakeJdbcInternalApi
public abstract SQLInput createSqlInputForColumn(
Object input, int columnIndex, SFBaseSession session);
Object input, Class<?> parentObjectClass, int columnIndex, SFBaseSession session);

@SnowflakeJdbcInternalApi
public abstract Date convertToDate(Object object, TimeZone tz) throws SFException;
Expand All @@ -230,4 +247,187 @@ public abstract SQLInput createSqlInputForColumn(
@SnowflakeJdbcInternalApi
public abstract Timestamp convertToTimestamp(
Object object, int columnType, int columnSubType, TimeZone tz, int scale) throws SFException;

@SnowflakeJdbcInternalApi
protected SQLInput createJsonSqlInputForColumn(
    Object input, int columnIndex, SFBaseSession session) {
  // Accept either an already-parsed JsonNode or any value Jackson can map to one.
  JsonNode inputNode =
      input instanceof JsonNode
          ? (JsonNode) input
          : OBJECT_MAPPER.convertValue(input, JsonNode.class);
  // Column metadata is 1-based from the JDBC caller's perspective.
  return new JsonSqlInput(
      inputNode,
      session,
      getConverters(),
      resultSetMetaData.getColumnMetadata().get(columnIndex - 1).getFields(),
      sessionTimezone);
}

/**
 * Builds a {@link SfSqlArray} from the JSON text representation of an ARRAY column.
 *
 * <p>The element type is taken from the first (and only) field of the array column's metadata;
 * each JSON element is converted with the matching typed converter from {@code getConverters()}.
 *
 * @param obj JSON array text for the column value
 * @param columnIndex 1-based column index into the result set metadata
 * @return the converted array wrapped as {@code SfSqlArray}
 * @throws SFException if the element type is unsupported (FEATURE_UNSUPPORTED) or the JSON text
 *     cannot be parsed (INVALID_STRUCT_DATA)
 */
@SnowflakeJdbcInternalApi
protected SfSqlArray getJsonArrayInternal(String obj, int columnIndex) throws SFException {
  try {
    SnowflakeColumnMetadata arrayMetadata =
        resultSetMetaData.getColumnMetadata().get(columnIndex - 1);
    // An ARRAY column carries exactly one field describing its element type.
    FieldMetadata fieldMetadata = arrayMetadata.getFields().get(0);

    int columnSubType = fieldMetadata.getType();
    int columnType = ColumnTypeHelper.getColumnType(columnSubType, session);
    int scale = fieldMetadata.getScale();

    ArrayNode arrayNode = (ArrayNode) OBJECT_MAPPER.readTree(obj);
    Iterator<JsonNode> nodeElements = arrayNode.elements();

    // Dispatch on the element's java.sql.Types code; each branch streams the JSON
    // elements through the appropriate converter and collects into a typed array.
    switch (columnSubType) {
      case Types.INTEGER:
        return new SfSqlArray(
            columnSubType,
            getStream(nodeElements, getConverters().integerConverter(columnType))
                .toArray(Integer[]::new));
      case Types.SMALLINT:
        return new SfSqlArray(
            columnSubType,
            getStream(nodeElements, getConverters().smallIntConverter(columnType))
                .toArray(Short[]::new));
      case Types.TINYINT:
        return new SfSqlArray(
            columnSubType,
            getStream(nodeElements, getConverters().tinyIntConverter(columnType))
                .toArray(Byte[]::new));
      case Types.BIGINT:
        return new SfSqlArray(
            columnSubType,
            getStream(nodeElements, getConverters().bigIntConverter(columnType))
                .toArray(Long[]::new));
      case Types.DECIMAL:
      case Types.NUMERIC:
        // FIXED values may convert to Long or BigDecimal; pick the array component
        // type after seeing the converted elements.
        return new SfSqlArray(
            columnSubType,
            convertToFixedArray(
                getStream(nodeElements, getConverters().bigDecimalConverter(columnType))));
      case Types.CHAR:
      case Types.VARCHAR:
      case Types.LONGNVARCHAR:
        return new SfSqlArray(
            columnSubType,
            getStream(
                    nodeElements,
                    getConverters().varcharConverter(columnType, columnSubType, scale))
                .toArray(String[]::new));
      case Types.BINARY:
        return new SfSqlArray(
            columnSubType,
            getStream(nodeElements, getConverters().bytesConverter(columnType, scale))
                .toArray(Byte[][]::new));
      case Types.FLOAT:
      case Types.REAL:
        return new SfSqlArray(
            columnSubType,
            getStream(nodeElements, getConverters().floatConverter(columnType))
                .toArray(Float[]::new));
      case Types.DOUBLE:
        return new SfSqlArray(
            columnSubType,
            getStream(nodeElements, getConverters().doubleConverter(columnType))
                .toArray(Double[]::new));
      case Types.DATE:
        return new SfSqlArray(
            columnSubType,
            getStream(nodeElements, getConverters().dateStringConverter(session))
                .toArray(Date[]::new));
      case Types.TIME:
        return new SfSqlArray(
            columnSubType,
            getStream(nodeElements, getConverters().timeFromStringConverter(session))
                .toArray(Time[]::new));
      case Types.TIMESTAMP:
        // NOTE(review): the format parameter is passed as null — presumably the converter
        // falls back to session formats; confirm against Converters.
        return new SfSqlArray(
            columnSubType,
            getStream(
                    nodeElements,
                    getConverters()
                        .timestampFromStringConverter(
                            columnSubType, columnType, scale, session, null, sessionTimezone))
                .toArray(Timestamp[]::new));
      case Types.BOOLEAN:
        return new SfSqlArray(
            columnSubType,
            getStream(nodeElements, getConverters().booleanConverter(columnType))
                .toArray(Boolean[]::new));
      case Types.STRUCT:
        // Nested structs are surfaced as Map entries.
        return new SfSqlArray(
            columnSubType,
            getStream(nodeElements, getConverters().structConverter(OBJECT_MAPPER))
                .toArray(Map[]::new));
      case Types.ARRAY:
        // Nested arrays of structs are surfaced as Map[] entries.
        return new SfSqlArray(
            columnSubType,
            getStream(nodeElements, getConverters().arrayConverter(OBJECT_MAPPER))
                .toArray(Map[][]::new));
      default:
        throw new SFException(
            ErrorCode.FEATURE_UNSUPPORTED,
            "Can't construct array for data type: " + columnSubType);
    }
  } catch (JsonProcessingException e) {
    throw new SFException(e, ErrorCode.INVALID_STRUCT_DATA);
  }
}

@SnowflakeJdbcInternalApi
protected Date convertStringToDate(Object object, TimeZone tz) throws SFException {
  // NOTE(review): tz is not consulted here — the string converter presumably applies
  // session settings; confirm callers do not expect tz-sensitive parsing.
  Object converted = getConverters().dateStringConverter(session).convert(object);
  return (Date) converted;
}

@SnowflakeJdbcInternalApi
protected Time convertStringToTime(Object object, int scale) throws SFException {
  // NOTE(review): scale is not consulted here — presumably handled inside the converter;
  // confirm against Converters.timeFromStringConverter.
  Object converted = getConverters().timeFromStringConverter(session).convert(object);
  return (Time) converted;
}

@SnowflakeJdbcInternalApi
protected Timestamp convertStringToTimestamp(
    Object object, int columnType, int columnSubType, TimeZone tz, int scale) throws SFException {
  // Build the typed converter first, then apply it; format argument is null as in the
  // other string-to-temporal paths in this class.
  Converter timestampConverter =
      getConverters()
          .timestampFromStringConverter(columnSubType, columnType, scale, session, null, tz);
  return (Timestamp) timestampConverter.convert(object);
}

private Stream getStream(Iterator nodeElements, Converter converter) {
  // Wrap the JSON element iterator as an ordered, sequential stream and convert
  // each element; checked SFException cannot cross a stream boundary, so it is
  // rethrown unchecked with its cause preserved.
  Spliterator elementSpliterator =
      Spliterators.spliteratorUnknownSize(nodeElements, Spliterator.ORDERED);
  return StreamSupport.stream(elementSpliterator, false)
      .map(
          elem -> {
            try {
              return convert(converter, (JsonNode) elem);
            } catch (SFException conversionError) {
              throw new RuntimeException(conversionError);
            }
          });
}

// Normalizes the JSON node to its string form before delegating to the typed converter.
private static Object convert(Converter converter, JsonNode node) throws SFException {
  return converter.convert(getJsonNodeStringValue(node));
}

/**
 * Materializes a stream of converted FIXED (DECIMAL/NUMERIC) elements into an array whose
 * component type is {@code BigDecimal} if any element is a {@code BigDecimal}, and {@code Long}
 * otherwise.
 *
 * <p>Rewritten to materialize the stream before choosing the component type: the previous
 * implementation counted BigDecimals via {@code Stream.peek} side effects consumed inside
 * {@code toArray(generator)}, which depends on unspecified evaluation order and is an idiom
 * Error Prone flags ({@code peek} is documented as a debugging aid).
 *
 * @param inputStream stream of converted elements (expected Long or BigDecimal)
 * @return a {@code BigDecimal[]} or {@code Long[]} (statically typed as {@code Object[]})
 */
private Object[] convertToFixedArray(Stream inputStream) {
  Object[] buffered = inputStream.toArray();
  boolean hasBigDecimal = false;
  for (Object element : buffered) {
    if (element instanceof BigDecimal) {
      hasBigDecimal = true;
      break;
    }
  }
  Class<?> componentType = hasBigDecimal ? BigDecimal.class : Long.class;
  Object[] result =
      (Object[]) java.lang.reflect.Array.newInstance(componentType, buffered.length);
  // Same failure mode as before: mixing Long into a BigDecimal[] raises ArrayStoreException.
  System.arraycopy(buffered, 0, result, 0, buffered.length);
  return result;
}
}
Loading

0 comments on commit af14680

Please sign in to comment.