Skip to content

Commit

Permalink
SNOW-1234214 Add support for maps in native arrow structured types
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-dheyman committed Mar 27, 2024
1 parent baf98e9 commit 73c0f05
Show file tree
Hide file tree
Showing 9 changed files with 170 additions and 81 deletions.
12 changes: 11 additions & 1 deletion src/main/java/net/snowflake/client/core/ArrowSqlInput.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import java.sql.Timestamp;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.TimeZone;
import net.snowflake.client.core.json.Converters;
import net.snowflake.client.core.structs.SQLDataCreationHelper;
Expand All @@ -24,6 +25,7 @@
@SnowflakeJdbcInternalApi
public class ArrowSqlInput extends BaseSqlInput {

private final JsonStringHashMap<String, Object> input;
private final Iterator<Object> structuredTypeFields;
private int currentIndex = 0;

Expand All @@ -34,6 +36,11 @@ public ArrowSqlInput(
List<FieldMetadata> fields) {
super(session, converters, fields);
this.structuredTypeFields = input.values().iterator();
this.input = input;
}

/**
 * Returns the raw Arrow map backing this SQLInput: field name -> field value as
 * produced by the Arrow reader. Exposed so callers (e.g. getObject(Map.class))
 * can hand back the structured value without re-parsing.
 *
 * @return the underlying field-name-to-value map (not a defensive copy)
 */
public Map<String, Object> getInput() {
return input;
}

@Override
Expand Down Expand Up @@ -167,14 +174,17 @@ public Timestamp readTimestamp(TimeZone tz) throws SQLException {
if (value == null) {
return null;
}
int columnType = ColumnTypeHelper.getColumnType(fieldMetadata.getType(), session);
int columnSubType = fieldMetadata.getType();
int scale = fieldMetadata.getScale();
return mapSFExceptionToSQLException(
() ->
converters
.getStructuredTypeDateTimeConverter()
.getTimestamp(
(JsonStringHashMap<String, Object>) value,
fieldMetadata.getBase(),
columnType,
columnSubType,
tz,
scale));
});
Expand Down
14 changes: 10 additions & 4 deletions src/main/java/net/snowflake/client/core/SFArrowResultSet.java
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ public class SFArrowResultSet extends SFBaseResultSet implements DataConversionC
*/
private boolean formatDateWithTimezone;

@SnowflakeJdbcInternalApi protected Converters jsonConverters;
@SnowflakeJdbcInternalApi protected Converters converters;

/**
* Constructor takes a result from the API response that we get from executing a SQL statement.
Expand All @@ -124,7 +124,7 @@ public SFArrowResultSet(
boolean sortResult)
throws SQLException {
this(resultSetSerializable, session.getTelemetryClient(), sortResult);
this.jsonConverters =
this.converters =
new Converters(
resultSetSerializable.getTimeZone(),
session,
Expand Down Expand Up @@ -356,6 +356,12 @@ private boolean fetchNextRowSorted() throws SnowflakeSQLException {
}
}

/**
 * Returns the converter set this Arrow result set was constructed with.
 *
 * @return the {@code Converters} instance used for value conversion
 */
@Override
@SnowflakeJdbcInternalApi
public Converters getConverters() {
return converters;
}

/**
* Advance to next row
*
Expand Down Expand Up @@ -522,7 +528,7 @@ private Object createJsonSqlInput(int columnIndex, Object obj) throws SFExceptio
return new JsonSqlInput(
jsonNode,
session,
jsonConverters,
converters,
resultSetMetaData.getColumnMetadata().get(columnIndex - 1).getFields(),
sessionTimezone);
} catch (JsonProcessingException e) {
Expand All @@ -534,7 +540,7 @@ private Object createArrowSqlInput(int columnIndex, JsonStringHashMap<String, Ob
return new ArrowSqlInput(
input,
session,
jsonConverters,
converters,
resultSetMetaData.getColumnMetadata().get(columnIndex - 1).getFields());
}

Expand Down
31 changes: 31 additions & 0 deletions src/main/java/net/snowflake/client/core/arrow/MapConverter.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package net.snowflake.client.core.arrow;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import net.snowflake.client.core.DataConversionContext;
import net.snowflake.client.core.SFException;
import net.snowflake.client.jdbc.SnowflakeType;
import org.apache.arrow.vector.complex.MapVector;
import org.apache.arrow.vector.util.JsonStringHashMap;

/**
 * Arrow vector converter for MAP columns. Each row of the {@link MapVector} is a list of
 * key/value structs; this converter materializes that list into a {@code Map<String, Object>}
 * keyed by the string form of each entry's "key" field.
 */
public class MapConverter extends AbstractArrowVectorConverter {

  private final MapVector vector;

  /**
   * @param valueVector the Arrow map vector to read from
   * @param columnIndex zero-based column index of this vector in the result chunk
   * @param context conversion context shared across converters
   */
  public MapConverter(MapVector valueVector, int columnIndex, DataConversionContext context) {
    super(SnowflakeType.MAP.name(), valueVector, columnIndex, context);
    this.vector = valueVector;
  }

  /**
   * Materializes row {@code index} as a {@code Map<String, Object>}.
   *
   * @param index row index within the vector
   * @return the row's entries as a map, or null for a SQL NULL map
   * @throws SFException declared for interface compatibility
   */
  @Override
  public Object toObject(int index) throws SFException {
    @SuppressWarnings("unchecked") // MapVector.getObject returns the row as a list of entry structs
    List<JsonStringHashMap<String, Object>> entries =
        (List<JsonStringHashMap<String, Object>>) vector.getObject(index);
    if (entries == null) {
      return null;
    }
    // Collect by hand rather than Collectors.toMap: toMap throws NullPointerException
    // when a value is null, but a map entry's value may legitimately be SQL NULL.
    Map<String, Object> result = new HashMap<>();
    for (JsonStringHashMap<String, Object> entry : entries) {
      result.put(entry.get("key").toString(), entry.get("value"));
    }
    return result;
  }

  /**
   * Returns the Arrow-native string form of row {@code index}, or null for a SQL NULL map
   * (the original unconditional {@code toString()} would have thrown NPE on null rows).
   */
  @Override
  public String toString(int index) throws SFException {
    Object value = vector.getObject(index);
    return value == null ? null : value.toString();
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,13 @@
import java.sql.Date;
import java.sql.Time;
import java.sql.Timestamp;
import java.sql.Types;
import java.util.TimeZone;
import net.snowflake.client.core.SFException;
import net.snowflake.client.core.SnowflakeJdbcInternalApi;
import net.snowflake.client.jdbc.ErrorCode;
import net.snowflake.client.jdbc.SnowflakeType;
import net.snowflake.client.jdbc.SnowflakeUtil;
import org.apache.arrow.vector.util.JsonStringHashMap;

@SnowflakeJdbcInternalApi
Expand Down Expand Up @@ -45,22 +47,23 @@ public StructuredTypeDateTimeConverter(
}

public Timestamp getTimestamp(
JsonStringHashMap<String, Object> obj, SnowflakeType type, TimeZone tz, int scale)
JsonStringHashMap<String, Object> obj, int columnType, int columnSubType, TimeZone tz, int scale)
throws SFException {
if (tz == null) {
tz = TimeZone.getDefault();
}
switch (type) {
case TIMESTAMP_LTZ:
if ( Types.TIMESTAMP == columnType) {
if (SnowflakeUtil.EXTRA_TYPES_TIMESTAMP_LTZ == columnSubType) {
return convertTimestampLtz(obj, scale);
case TIMESTAMP_NTZ:
} else {
return convertTimestampNtz(obj, tz, scale);
case TIMESTAMP_TZ:
return convertTimestampTz(obj, scale);
}
} else if (Types.TIMESTAMP_WITH_TIMEZONE == columnType && SnowflakeUtil.EXTRA_TYPES_TIMESTAMP_TZ == columnSubType) {
return convertTimestampTz(obj, scale);
}
throw new SFException(
ErrorCode.INVALID_VALUE_CONVERT,
"Unexpected Arrow Field for " + type.name() + " and object type " + obj.getClass());
"Unexpected Arrow Field for columnType " + columnType + " , column subtype " + columnSubType + " , and object type " + obj.getClass());
}

public Date getDate(int value, TimeZone tz) throws SFException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
@SnowflakeJdbcInternalApi
public class StructureTypeHelper {
private static final String STRUCTURED_TYPE_ENABLED_PROPERTY_NAME = "STRUCTURED_TYPE_ENABLED";
private static boolean structuredTypeEnabled =
Boolean.valueOf(System.getProperty(STRUCTURED_TYPE_ENABLED_PROPERTY_NAME));
private static boolean structuredTypeEnabled = true;

public static boolean isStructureTypeEnabled() {
return structuredTypeEnabled;
Expand Down
31 changes: 6 additions & 25 deletions src/main/java/net/snowflake/client/jdbc/ArrowResultChunk.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,31 +13,7 @@
import net.snowflake.client.core.DataConversionContext;
import net.snowflake.client.core.SFBaseSession;
import net.snowflake.client.core.SFException;
import net.snowflake.client.core.arrow.ArrowResultChunkIndexSorter;
import net.snowflake.client.core.arrow.ArrowVectorConverter;
import net.snowflake.client.core.arrow.BigIntToFixedConverter;
import net.snowflake.client.core.arrow.BigIntToScaledFixedConverter;
import net.snowflake.client.core.arrow.BigIntToTimeConverter;
import net.snowflake.client.core.arrow.BigIntToTimestampLTZConverter;
import net.snowflake.client.core.arrow.BigIntToTimestampNTZConverter;
import net.snowflake.client.core.arrow.BitToBooleanConverter;
import net.snowflake.client.core.arrow.DateConverter;
import net.snowflake.client.core.arrow.DecimalToScaledFixedConverter;
import net.snowflake.client.core.arrow.DoubleToRealConverter;
import net.snowflake.client.core.arrow.IntToFixedConverter;
import net.snowflake.client.core.arrow.IntToScaledFixedConverter;
import net.snowflake.client.core.arrow.IntToTimeConverter;
import net.snowflake.client.core.arrow.SmallIntToFixedConverter;
import net.snowflake.client.core.arrow.SmallIntToScaledFixedConverter;
import net.snowflake.client.core.arrow.StructConverter;
import net.snowflake.client.core.arrow.ThreeFieldStructToTimestampTZConverter;
import net.snowflake.client.core.arrow.TinyIntToFixedConverter;
import net.snowflake.client.core.arrow.TinyIntToScaledFixedConverter;
import net.snowflake.client.core.arrow.TwoFieldStructToTimestampLTZConverter;
import net.snowflake.client.core.arrow.TwoFieldStructToTimestampNTZConverter;
import net.snowflake.client.core.arrow.TwoFieldStructToTimestampTZConverter;
import net.snowflake.client.core.arrow.VarBinaryToBinaryConverter;
import net.snowflake.client.core.arrow.VarCharConverter;
import net.snowflake.client.core.arrow.*;
import net.snowflake.client.log.SFLogger;
import net.snowflake.client.log.SFLoggerFactory;
import net.snowflake.common.core.SqlState;
Expand All @@ -55,6 +31,7 @@
import org.apache.arrow.vector.VarBinaryVector;
import org.apache.arrow.vector.VarCharVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.complex.MapVector;
import org.apache.arrow.vector.complex.StructVector;
import org.apache.arrow.vector.ipc.ArrowStreamReader;
import org.apache.arrow.vector.types.Types;
Expand Down Expand Up @@ -206,6 +183,10 @@ private static List<ArrowVectorConverter> initConverters(
converters.add(new VarCharConverter(vector, i, context));
break;

case MAP:
converters.add(new MapConverter((MapVector) vector, i, context));
break;

case OBJECT:
if (vector instanceof StructVector) {
converters.add(new StructConverter((StructVector) vector, i, context));
Expand Down
113 changes: 78 additions & 35 deletions src/main/java/net/snowflake/client/jdbc/SnowflakeBaseResultSet.java
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,13 @@
import java.util.List;
import java.util.Map;
import java.util.TimeZone;
import net.snowflake.client.core.ColumnTypeHelper;
import net.snowflake.client.core.JsonSqlInput;
import net.snowflake.client.core.ObjectMapperFactory;
import net.snowflake.client.core.SFBaseResultSet;
import net.snowflake.client.core.SFBaseSession;

import net.snowflake.client.core.*;
import net.snowflake.client.core.structs.SQLDataCreationHelper;
import net.snowflake.client.log.SFLogger;
import net.snowflake.client.log.SFLoggerFactory;
import net.snowflake.common.core.SqlState;
import org.apache.arrow.vector.util.JsonStringHashMap;

/** Base class for query result set and metadata result set */
public abstract class SnowflakeBaseResultSet implements ResultSet {
Expand Down Expand Up @@ -1355,8 +1353,13 @@ public <T> T getObject(int columnIndex, Class<T> type) throws SQLException {
instance.readSQL(sqlInput, null);
return (T) instance;
} else if (Map.class.isAssignableFrom(type)) {
JsonNode jsonNode = ((JsonSqlInput) getObject(columnIndex)).getInput();
return (T) OBJECT_MAPPER.convertValue(jsonNode, new TypeReference<Map<String, Object>>() {});
Object object = getObject(columnIndex);
if (object instanceof JsonSqlInput) {
JsonNode jsonNode = ((JsonSqlInput) object).getInput();
return (T) OBJECT_MAPPER.convertValue(jsonNode, new TypeReference<Map<String, Object>>() {});
} else {
return (T) ((ArrowSqlInput) object).getInput();
}
} else if (String.class.isAssignableFrom(type)) {
return (T) getString(columnIndex);
} else if (Boolean.class.isAssignableFrom(type)) {
Expand Down Expand Up @@ -1540,20 +1543,31 @@ public <T> Map<String, T> getMap(int columnIndex, Class<T> type) throws SQLExcep
int scale = resultSetMetaData.getScale(columnIndex);
TimeZone tz = sfBaseResultSet.getSessionTimeZone();
Object object = getObject(columnIndex);
JsonNode jsonNode = ((JsonSqlInput) object).getInput();
Map<String, Object> map =
OBJECT_MAPPER.convertValue(jsonNode, new TypeReference<Map<String, Object>>() {});
Map<String, Object> map;
if (object instanceof JsonSqlInput) {
JsonNode jsonNode = ((JsonSqlInput) object).getInput();
map = OBJECT_MAPPER.convertValue(jsonNode, new TypeReference<Map<String, Object>>() {});
} else {
map = (Map<String, Object>) object;
}
Map<String, T> resultMap = new HashMap<>();
for (Map.Entry<String, Object> entry : map.entrySet()) {
if (SQLData.class.isAssignableFrom(type)) {
SQLData instance = (SQLData) SQLDataCreationHelper.create(type);
SQLInput sqlInput =
new JsonSqlInput(
jsonNode.get(entry.getKey()),
session,
sfBaseResultSet.getConverters(),
sfBaseResultSet.getMetaData().getColumnMetadata().get(columnIndex - 1).getFields(),
sfBaseResultSet.getSessionTimezone());
SQLInput sqlInput;
if (object instanceof JsonSqlInput) {
sqlInput = new JsonSqlInput(
(((JsonSqlInput) object).getInput()).get(entry.getKey()),
session,
sfBaseResultSet.getConverters(),
sfBaseResultSet.getMetaData().getColumnMetadata().get(columnIndex - 1).getFields(),
sfBaseResultSet.getSessionTimezone());
} else {
sqlInput = new ArrowSqlInput((JsonStringHashMap<String, Object>) entry.getValue(),
session,
sfBaseResultSet.getConverters(),
sfBaseResultSet.getMetaData().getColumnMetadata().get(columnIndex - 1).getFields());
}
instance.readSQL(sqlInput, null);
resultMap.put(entry.getKey(), (T) instance);
} else if (String.class.isAssignableFrom(type)) {
Expand Down Expand Up @@ -1655,32 +1669,19 @@ public <T> Map<String, T> getMap(int columnIndex, Class<T> type) throws SQLExcep
resultMap.put(
entry.getKey(),
mapSFExceptionToSQLException(
() ->
(T)
sfBaseResultSet
.getConverters()
.getDateTimeConverter()
.getDate(entry.getValue(), columnType, columnSubType, tz, scale)));
() -> (T) convertToDate(entry.getValue(), columnType, columnSubType, tz, scale)));
} else if (Time.class.isAssignableFrom(type)) {
resultMap.put(
entry.getKey(),
mapSFExceptionToSQLException(
() ->
(T)
sfBaseResultSet
.getConverters()
.getDateTimeConverter()
.getTime(entry.getValue(), columnType, columnSubType, tz, scale)));
() -> (T) convertToTime(entry.getValue(), columnType, columnSubType, tz, scale)));

} else if (Timestamp.class.isAssignableFrom(type)) {
resultMap.put(
entry.getKey(),
mapSFExceptionToSQLException(
() ->
(T)
sfBaseResultSet
.getConverters()
.getDateTimeConverter()
.getTimestamp(entry.getValue(), columnType, columnSubType, tz, scale)));
() -> (T) convertToTimestamp(entry.getValue(), columnType, columnSubType, tz, scale)));

} else {
logger.debug(
"Unsupported type passed to getObject(int columnIndex,Class<T> type): "
Expand Down Expand Up @@ -1718,4 +1719,46 @@ public boolean isWrapperFor(Class<?> iface) throws SQLException {

return iface.isInstance(this);
}

/**
 * Converts a raw map-entry value to a {@link Date}, picking the converter that matches
 * the result set's wire format (Arrow vs. JSON).
 */
private Date convertToDate(Object object, int columnType, int columnSubType, TimeZone tz, int scale) throws SFException {
  if (!(sfBaseResultSet instanceof SFArrowResultSet)) {
    // JSON path: generic date-time converter handles the string/number representation.
    return sfBaseResultSet
        .getConverters()
        .getDateTimeConverter()
        .getDate(object, columnType, columnSubType, tz, scale);
  }
  // Arrow path: value arrives as an epoch-day int — TODO confirm against ArrowSqlInput.
  return sfBaseResultSet.getConverters().getStructuredTypeDateTimeConverter().getDate((int) object, tz);
}

/**
 * Converts a raw map-entry value to a {@link Time}, picking the converter that matches
 * the result set's wire format (Arrow vs. JSON).
 */
private Time convertToTime(Object object, int columnType, int columnSubType, TimeZone tz, int scale) throws SFException {
  if (!(sfBaseResultSet instanceof SFArrowResultSet)) {
    // JSON path: generic date-time converter handles the string/number representation.
    return sfBaseResultSet
        .getConverters()
        .getDateTimeConverter()
        .getTime(object, columnType, columnSubType, tz, scale);
  }
  // Arrow path: value arrives as an int scaled by the column's fractional-second scale.
  return sfBaseResultSet.getConverters().getStructuredTypeDateTimeConverter().getTime((int) object, scale);
}

/**
 * Converts a raw map-entry value to a {@link Timestamp}, picking the converter that matches
 * the result set's wire format (Arrow vs. JSON).
 */
private Timestamp convertToTimestamp(Object object, int columnType, int columnSubType, TimeZone tz, int scale) throws SFException {
  if (!(sfBaseResultSet instanceof SFArrowResultSet)) {
    // JSON path: generic date-time converter handles the string/number representation.
    return sfBaseResultSet
        .getConverters()
        .getDateTimeConverter()
        .getTimestamp(object, columnType, columnSubType, tz, scale);
  }
  // Arrow path: value arrives as a key/value struct of timestamp fields.
  return sfBaseResultSet
      .getConverters()
      .getStructuredTypeDateTimeConverter()
      .getTimestamp((JsonStringHashMap<String, Object>) object, columnType, columnSubType, tz, scale);
}
}
Loading

0 comments on commit 73c0f05

Please sign in to comment.