From 73c0f0578031f9d83ea27bfd09b732082aa5844a Mon Sep 17 00:00:00 2001 From: Dawid Heyman Date: Wed, 27 Mar 2024 16:16:50 +0100 Subject: [PATCH] SNOW-1234214 Add support for maps in native arrow structured types --- .../snowflake/client/core/ArrowSqlInput.java | 12 +- .../client/core/SFArrowResultSet.java | 14 ++- .../client/core/arrow/MapConverter.java | 31 +++++ .../StructuredTypeDateTimeConverter.java | 17 +-- .../core/structs/StructureTypeHelper.java | 3 +- .../client/jdbc/ArrowResultChunk.java | 31 +---- .../client/jdbc/SnowflakeBaseResultSet.java | 113 ++++++++++++------ .../snowflake/client/AbstractDriverIT.java | 5 +- .../ResultSetStructuredTypesLatestIT.java | 25 +++- 9 files changed, 170 insertions(+), 81 deletions(-) create mode 100644 src/main/java/net/snowflake/client/core/arrow/MapConverter.java diff --git a/src/main/java/net/snowflake/client/core/ArrowSqlInput.java b/src/main/java/net/snowflake/client/core/ArrowSqlInput.java index e25dbe360..b84be490c 100644 --- a/src/main/java/net/snowflake/client/core/ArrowSqlInput.java +++ b/src/main/java/net/snowflake/client/core/ArrowSqlInput.java @@ -12,6 +12,7 @@ import java.sql.Timestamp; import java.util.Iterator; import java.util.List; +import java.util.Map; import java.util.TimeZone; import net.snowflake.client.core.json.Converters; import net.snowflake.client.core.structs.SQLDataCreationHelper; @@ -24,6 +25,7 @@ @SnowflakeJdbcInternalApi public class ArrowSqlInput extends BaseSqlInput { + private final JsonStringHashMap input; private final Iterator structuredTypeFields; private int currentIndex = 0; @@ -34,6 +36,11 @@ public ArrowSqlInput( List fields) { super(session, converters, fields); this.structuredTypeFields = input.values().iterator(); + this.input = input; + } + + public Map getInput() { + return input; } @Override @@ -167,6 +174,8 @@ public Timestamp readTimestamp(TimeZone tz) throws SQLException { if (value == null) { return null; } + int columnType = ColumnTypeHelper.getColumnType(fieldMetadata.getType(), session); + int columnSubType = fieldMetadata.getType(); int scale = fieldMetadata.getScale(); return mapSFExceptionToSQLException( () -> @@ -174,7 +183,8 @@ public Timestamp readTimestamp(TimeZone tz) throws SQLException { .getStructuredTypeDateTimeConverter() .getTimestamp( (JsonStringHashMap) value, - fieldMetadata.getBase(), + columnType, + columnSubType, tz, scale)); }); diff --git a/src/main/java/net/snowflake/client/core/SFArrowResultSet.java b/src/main/java/net/snowflake/client/core/SFArrowResultSet.java index daf60b804..22f0609ca 100644 --- a/src/main/java/net/snowflake/client/core/SFArrowResultSet.java +++ b/src/main/java/net/snowflake/client/core/SFArrowResultSet.java @@ -104,7 +104,7 @@ public class SFArrowResultSet extends SFBaseResultSet implements DataConversionC */ private boolean formatDateWithTimezone; - @SnowflakeJdbcInternalApi protected Converters jsonConverters; + @SnowflakeJdbcInternalApi protected Converters converters; /** * Constructor takes a result from the API response that we get from executing a SQL statement. @@ -124,7 +124,7 @@ public SFArrowResultSet( boolean sortResult) throws SQLException { this(resultSetSerializable, session.getTelemetryClient(), sortResult); - this.jsonConverters = + this.converters = new Converters( resultSetSerializable.getTimeZone(), session, @@ -356,6 +356,12 @@ private boolean fetchNextRowSorted() throws SnowflakeSQLException { } } + @Override + @SnowflakeJdbcInternalApi + public Converters getConverters() { + return converters; + } + /** * Advance to next row * @@ -522,7 +528,7 @@ private Object createJsonSqlInput(int columnIndex, Object obj) throws SFExceptio return new JsonSqlInput( jsonNode, session, - jsonConverters, + converters, resultSetMetaData.getColumnMetadata().get(columnIndex - 1).getFields(), sessionTimezone); } catch (JsonProcessingException e) { @@ -534,7 +540,7 @@ private Object createArrowSqlInput(int columnIndex, JsonStringHashMap> entriesList = (List>) vector.getObject(index); + return entriesList.stream().collect(Collectors.toMap(entry-> entry.get("key").toString(), entry -> entry.get("value"))); + } + + @Override + public String toString(int index) throws SFException { + return vector.getObject(index).toString(); + } +} diff --git a/src/main/java/net/snowflake/client/core/arrow/StructuredTypeDateTimeConverter.java b/src/main/java/net/snowflake/client/core/arrow/StructuredTypeDateTimeConverter.java index cd30c4bf5..5f32e76a0 100644 --- a/src/main/java/net/snowflake/client/core/arrow/StructuredTypeDateTimeConverter.java +++ b/src/main/java/net/snowflake/client/core/arrow/StructuredTypeDateTimeConverter.java @@ -11,11 +11,13 @@ import java.sql.Date; import java.sql.Time; import java.sql.Timestamp; +import java.sql.Types; import java.util.TimeZone; import net.snowflake.client.core.SFException; import net.snowflake.client.core.SnowflakeJdbcInternalApi; import net.snowflake.client.jdbc.ErrorCode; import net.snowflake.client.jdbc.SnowflakeType; +import net.snowflake.client.jdbc.SnowflakeUtil; import org.apache.arrow.vector.util.JsonStringHashMap; @SnowflakeJdbcInternalApi @@ -45,22 +47,23 @@ public StructuredTypeDateTimeConverter( } public Timestamp getTimestamp( - JsonStringHashMap obj, SnowflakeType type, TimeZone tz, int scale) + JsonStringHashMap obj, int columnType, int columnSubType, TimeZone tz, int scale) throws SFException { if (tz == null) { tz = TimeZone.getDefault(); } - switch (type) { - case TIMESTAMP_LTZ: + if ( Types.TIMESTAMP == columnType) { + if (SnowflakeUtil.EXTRA_TYPES_TIMESTAMP_LTZ == columnSubType) { return convertTimestampLtz(obj, scale); - case TIMESTAMP_NTZ: + } else { return convertTimestampNtz(obj, tz, scale); - case TIMESTAMP_TZ: - return convertTimestampTz(obj, scale); + } + } else if (Types.TIMESTAMP_WITH_TIMEZONE == columnType && SnowflakeUtil.EXTRA_TYPES_TIMESTAMP_TZ == columnSubType) { + return convertTimestampTz(obj, scale); } throw new SFException( ErrorCode.INVALID_VALUE_CONVERT, - "Unexpected Arrow Field for " + type.name() + " and object type " + obj.getClass()); + "Unexpected Arrow Field for columnType " + columnType + " , column subtype " + columnSubType + " , and object type " + obj.getClass()); } public Date getDate(int value, TimeZone tz) throws SFException { diff --git a/src/main/java/net/snowflake/client/core/structs/StructureTypeHelper.java b/src/main/java/net/snowflake/client/core/structs/StructureTypeHelper.java index 46e8bb2e8..1f44779dd 100644 --- a/src/main/java/net/snowflake/client/core/structs/StructureTypeHelper.java +++ b/src/main/java/net/snowflake/client/core/structs/StructureTypeHelper.java @@ -5,8 +5,7 @@ @SnowflakeJdbcInternalApi public class StructureTypeHelper { private static final String STRUCTURED_TYPE_ENABLED_PROPERTY_NAME = "STRUCTURED_TYPE_ENABLED"; - private static boolean structuredTypeEnabled = - Boolean.valueOf(System.getProperty(STRUCTURED_TYPE_ENABLED_PROPERTY_NAME)); + private static boolean structuredTypeEnabled = true; public static boolean isStructureTypeEnabled() { return structuredTypeEnabled; diff --git a/src/main/java/net/snowflake/client/jdbc/ArrowResultChunk.java b/src/main/java/net/snowflake/client/jdbc/ArrowResultChunk.java index c273a8817..2bed387f1 100644 --- a/src/main/java/net/snowflake/client/jdbc/ArrowResultChunk.java +++ b/src/main/java/net/snowflake/client/jdbc/ArrowResultChunk.java @@ -13,31 +13,7 @@ import net.snowflake.client.core.DataConversionContext; import net.snowflake.client.core.SFBaseSession; import net.snowflake.client.core.SFException; -import net.snowflake.client.core.arrow.ArrowResultChunkIndexSorter; -import net.snowflake.client.core.arrow.ArrowVectorConverter; -import net.snowflake.client.core.arrow.BigIntToFixedConverter; -import net.snowflake.client.core.arrow.BigIntToScaledFixedConverter; -import net.snowflake.client.core.arrow.BigIntToTimeConverter; -import net.snowflake.client.core.arrow.BigIntToTimestampLTZConverter; -import net.snowflake.client.core.arrow.BigIntToTimestampNTZConverter; -import net.snowflake.client.core.arrow.BitToBooleanConverter; -import net.snowflake.client.core.arrow.DateConverter; -import net.snowflake.client.core.arrow.DecimalToScaledFixedConverter; -import net.snowflake.client.core.arrow.DoubleToRealConverter; -import net.snowflake.client.core.arrow.IntToFixedConverter; -import net.snowflake.client.core.arrow.IntToScaledFixedConverter; -import net.snowflake.client.core.arrow.IntToTimeConverter; -import net.snowflake.client.core.arrow.SmallIntToFixedConverter; -import net.snowflake.client.core.arrow.SmallIntToScaledFixedConverter; -import net.snowflake.client.core.arrow.StructConverter; -import net.snowflake.client.core.arrow.ThreeFieldStructToTimestampTZConverter; -import net.snowflake.client.core.arrow.TinyIntToFixedConverter; -import net.snowflake.client.core.arrow.TinyIntToScaledFixedConverter; -import net.snowflake.client.core.arrow.TwoFieldStructToTimestampLTZConverter; -import net.snowflake.client.core.arrow.TwoFieldStructToTimestampNTZConverter; -import net.snowflake.client.core.arrow.TwoFieldStructToTimestampTZConverter; -import net.snowflake.client.core.arrow.VarBinaryToBinaryConverter; -import net.snowflake.client.core.arrow.VarCharConverter; +import net.snowflake.client.core.arrow.*; import net.snowflake.client.log.SFLogger; import net.snowflake.client.log.SFLoggerFactory; import net.snowflake.common.core.SqlState; @@ -55,6 +31,7 @@ import org.apache.arrow.vector.VarBinaryVector; import org.apache.arrow.vector.VarCharVector; import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.complex.MapVector; import org.apache.arrow.vector.complex.StructVector; import org.apache.arrow.vector.ipc.ArrowStreamReader; import org.apache.arrow.vector.types.Types; @@ -206,6 +183,10 @@ private static List initConverters( converters.add(new VarCharConverter(vector, i, context)); break; + case MAP: + converters.add(new MapConverter((MapVector) vector, i, context)); + break; + case OBJECT: if (vector instanceof StructVector) { converters.add(new StructConverter((StructVector) vector, i, context)); diff --git a/src/main/java/net/snowflake/client/jdbc/SnowflakeBaseResultSet.java b/src/main/java/net/snowflake/client/jdbc/SnowflakeBaseResultSet.java index 1a4a1c82e..67aa870a9 100644 --- a/src/main/java/net/snowflake/client/jdbc/SnowflakeBaseResultSet.java +++ b/src/main/java/net/snowflake/client/jdbc/SnowflakeBaseResultSet.java @@ -37,15 +37,13 @@ import java.util.List; import java.util.Map; import java.util.TimeZone; -import net.snowflake.client.core.ColumnTypeHelper; -import net.snowflake.client.core.JsonSqlInput; -import net.snowflake.client.core.ObjectMapperFactory; -import net.snowflake.client.core.SFBaseResultSet; -import net.snowflake.client.core.SFBaseSession; + +import net.snowflake.client.core.*; import net.snowflake.client.core.structs.SQLDataCreationHelper; import net.snowflake.client.log.SFLogger; import net.snowflake.client.log.SFLoggerFactory; import net.snowflake.common.core.SqlState; +import org.apache.arrow.vector.util.JsonStringHashMap; /** Base class for query result set and metadata result set */ public abstract class SnowflakeBaseResultSet implements ResultSet { @@ -1355,8 +1353,13 @@ public T getObject(int columnIndex, Class type) throws SQLException { instance.readSQL(sqlInput, null); return (T) instance; } else if (Map.class.isAssignableFrom(type)) { - JsonNode jsonNode = ((JsonSqlInput) getObject(columnIndex)).getInput(); - return (T) OBJECT_MAPPER.convertValue(jsonNode, new TypeReference>() {}); + Object object = getObject(columnIndex); + if (object instanceof JsonSqlInput) { + JsonNode jsonNode = ((JsonSqlInput) object).getInput(); + return (T) OBJECT_MAPPER.convertValue(jsonNode, new TypeReference>() {}); + } else { + return (T) ((ArrowSqlInput) object).getInput(); + } } else if (String.class.isAssignableFrom(type)) { return (T) getString(columnIndex); } else if (Boolean.class.isAssignableFrom(type)) { @@ -1540,20 +1543,31 @@ public Map getMap(int columnIndex, Class type) throws SQLExcep int scale = resultSetMetaData.getScale(columnIndex); TimeZone tz = sfBaseResultSet.getSessionTimeZone(); Object object = getObject(columnIndex); - JsonNode jsonNode = ((JsonSqlInput) object).getInput(); - Map map = - OBJECT_MAPPER.convertValue(jsonNode, new TypeReference>() {}); + Map map; + if (object instanceof JsonSqlInput) { + JsonNode jsonNode = ((JsonSqlInput) object).getInput(); + map = OBJECT_MAPPER.convertValue(jsonNode, new TypeReference>() {}); + } else { + map = (Map) object; + } Map resultMap = new HashMap<>(); for (Map.Entry entry : map.entrySet()) { if (SQLData.class.isAssignableFrom(type)) { SQLData instance = (SQLData) SQLDataCreationHelper.create(type); - SQLInput sqlInput = - new JsonSqlInput( - jsonNode.get(entry.getKey()), - session, - sfBaseResultSet.getConverters(), - sfBaseResultSet.getMetaData().getColumnMetadata().get(columnIndex - 1).getFields(), - sfBaseResultSet.getSessionTimezone()); + SQLInput sqlInput; + if (object instanceof JsonSqlInput) { + sqlInput = new JsonSqlInput( + (((JsonSqlInput) object).getInput()).get(entry.getKey()), + session, + sfBaseResultSet.getConverters(), + sfBaseResultSet.getMetaData().getColumnMetadata().get(columnIndex - 1).getFields(), + sfBaseResultSet.getSessionTimezone()); + } else { + sqlInput = new ArrowSqlInput((JsonStringHashMap) entry.getValue(), + session, + sfBaseResultSet.getConverters(), + sfBaseResultSet.getMetaData().getColumnMetadata().get(columnIndex - 1).getFields()); + } instance.readSQL(sqlInput, null); resultMap.put(entry.getKey(), (T) instance); } else if (String.class.isAssignableFrom(type)) { @@ -1655,32 +1669,19 @@ public Map getMap(int columnIndex, Class type) throws SQLExcep resultMap.put( entry.getKey(), mapSFExceptionToSQLException( - () -> - (T) - sfBaseResultSet - .getConverters() - .getDateTimeConverter() - .getDate(entry.getValue(), columnType, columnSubType, tz, scale))); + () -> (T) convertToDate(entry.getValue(), columnType, columnSubType, tz, scale))); } else if (Time.class.isAssignableFrom(type)) { resultMap.put( entry.getKey(), mapSFExceptionToSQLException( - () -> - (T) - sfBaseResultSet - .getConverters() - .getDateTimeConverter() - .getTime(entry.getValue(), columnType, columnSubType, tz, scale))); + () -> (T) convertToTime(entry.getValue(), columnType, columnSubType, tz, scale))); + } else if (Timestamp.class.isAssignableFrom(type)) { resultMap.put( entry.getKey(), mapSFExceptionToSQLException( - () -> - (T) - sfBaseResultSet - .getConverters() - .getDateTimeConverter() - .getTimestamp(entry.getValue(), columnType, columnSubType, tz, scale))); + () -> (T) convertToTimestamp(entry.getValue(), columnType, columnSubType, tz, scale))); + } else { logger.debug( "Unsupported type passed to getObject(int columnIndex,Class type): " @@ -1718,4 +1719,46 @@ public boolean isWrapperFor(Class iface) throws SQLException { return iface.isInstance(this); } + + private Date convertToDate(Object object, int columnType, int columnSubType, TimeZone tz, int scale) throws SFException { + if (sfBaseResultSet instanceof SFArrowResultSet) { + return sfBaseResultSet + .getConverters() + .getStructuredTypeDateTimeConverter() + .getDate((int) object, tz); + } else { + return sfBaseResultSet + .getConverters() + .getDateTimeConverter() + .getDate(object, columnType, columnSubType, tz, scale); + } + } + + private Time convertToTime(Object object, int columnType, int columnSubType, TimeZone tz, int scale) throws SFException { + if (sfBaseResultSet instanceof SFArrowResultSet) { + return sfBaseResultSet + .getConverters() + .getStructuredTypeDateTimeConverter() + .getTime((int) object, scale); + } else { + return sfBaseResultSet + .getConverters() + .getDateTimeConverter() + .getTime(object, columnType, columnSubType, tz, scale); + } + } + + private Timestamp convertToTimestamp(Object object, int columnType, int columnSubType, TimeZone tz, int scale) throws SFException { + if (sfBaseResultSet instanceof SFArrowResultSet) { + return sfBaseResultSet + .getConverters() + .getStructuredTypeDateTimeConverter() + .getTimestamp((JsonStringHashMap) object, columnType, columnSubType, tz, scale); + } else { + return sfBaseResultSet + .getConverters() + .getDateTimeConverter() + .getTimestamp(object, columnType, columnSubType, tz, scale); + } + } } diff --git a/src/test/java/net/snowflake/client/AbstractDriverIT.java b/src/test/java/net/snowflake/client/AbstractDriverIT.java index b44cc31ef..e8354604b 100644 --- a/src/test/java/net/snowflake/client/AbstractDriverIT.java +++ b/src/test/java/net/snowflake/client/AbstractDriverIT.java @@ -321,9 +321,12 @@ public static Connection getConnection( properties.put("schema", params.get("schema")); properties.put("warehouse", params.get("warehouse")); properties.put("ssl", params.get("ssl")); + properties.put("proxyHost", "localhost"); + properties.put("proxyPort", "8080"); + properties.put("useProxy", "true"); properties.put("internal", Boolean.TRUE.toString()); // TODO: do we need this? - properties.put("insecureMode", false); // use OCSP for all tests. + properties.put("insecureMode", true); // use OCSP for all tests. if (injectSocketTimeout > 0) { properties.put("injectSocketTimeout", String.valueOf(injectSocketTimeout)); diff --git a/src/test/java/net/snowflake/client/jdbc/ResultSetStructuredTypesLatestIT.java b/src/test/java/net/snowflake/client/jdbc/ResultSetStructuredTypesLatestIT.java index 12df13135..40a967140 100644 --- a/src/test/java/net/snowflake/client/jdbc/ResultSetStructuredTypesLatestIT.java +++ b/src/test/java/net/snowflake/client/jdbc/ResultSetStructuredTypesLatestIT.java @@ -38,7 +38,7 @@ public class ResultSetStructuredTypesLatestIT extends BaseJDBCTest { private final ResultSetFormatType queryResultFormat; public ResultSetStructuredTypesLatestIT() { - this(ResultSetFormatType.JSON); + this(ResultSetFormatType.NATIVE_ARROW); } protected ResultSetStructuredTypesLatestIT(ResultSetFormatType queryResultFormat) { @@ -178,7 +178,6 @@ private void testMapAllTypes(boolean registerFactory) throws SQLException { @Test @ConditionalIgnoreRule.ConditionalIgnore(condition = RunningOnGithubAction.class) public void testMapJsonToMap() throws SQLException { - Assume.assumeTrue(queryResultFormat != ResultSetFormatType.NATIVE_ARROW); withFirstRow( "SELECT OBJECT_CONSTRUCT('string','a','string2',1)", (resultSet) -> { @@ -207,7 +206,8 @@ public void testReturnAsArrayOfSqlData() throws SQLException { @Test @ConditionalIgnoreRule.ConditionalIgnore(condition = RunningOnGithubAction.class) public void testReturnAsArrayOfString() throws SQLException { - withFirstRow( + Assume.assumeTrue(queryResultFormat != ResultSetFormatType.NATIVE_ARROW); + withFirstRow( "SELECT ARRAY_CONSTRUCT('one', 'two','three')::ARRAY(VARCHAR)", (resultSet) -> { String[] resultArray = @@ -221,7 +221,8 @@ public void testReturnAsArrayOfString() throws SQLException { @Test @ConditionalIgnoreRule.ConditionalIgnore(condition = RunningOnGithubAction.class) public void testReturnAsListOfIntegers() throws SQLException { - withFirstRow( + Assume.assumeTrue(queryResultFormat != ResultSetFormatType.NATIVE_ARROW); + withFirstRow( "SELECT ARRAY_CONSTRUCT(1,2,3)::ARRAY(INTEGER)", (resultSet) -> { List resultList = @@ -235,7 +236,6 @@ public void testReturnAsListOfIntegers() throws SQLException { @Test @ConditionalIgnoreRule.ConditionalIgnore(condition = RunningOnGithubAction.class) public void testReturnAsMap() throws SQLException { - Assume.assumeTrue(queryResultFormat != ResultSetFormatType.NATIVE_ARROW); SnowflakeObjectTypeFactories.register(SimpleClass.class, SimpleClass::new); withFirstRow( "select {'x':{'string':'one'},'y':{'string':'two'},'z':{'string':'three'}}::MAP(VARCHAR, OBJECT(string VARCHAR));", @@ -262,6 +262,20 @@ public void testReturnAsMapOfLong() throws SQLException { }); } + @Test + @ConditionalIgnoreRule.ConditionalIgnore(condition = RunningOnGithubAction.class) + public void testReturnAsMapOfTimestampsNTz() throws SQLException { + withFirstRow( + "SELECT {'x': TO_TIMESTAMP_NTZ('2021-12-23 09:44:44'), 'y': TO_TIMESTAMP_NTZ('2021-12-24 09:55:55')}::MAP(VARCHAR, TIMESTAMP)", + (resultSet) -> { + Map map = resultSet.unwrap(SnowflakeBaseResultSet.class).getMap(1, Timestamp.class); + assertEquals( + Timestamp.valueOf(LocalDateTime.of(2021, 12, 23, 10, 44, 44)), map.get("x")); + assertEquals( + Timestamp.valueOf(LocalDateTime.of(2021, 12, 24, 10, 55, 55)), map.get("y")); + }); + } + @Test @ConditionalIgnoreRule.ConditionalIgnore(condition = RunningOnGithubAction.class) public void testReturnAsMapOfBoolean() throws SQLException { @@ -294,7 +308,6 @@ public void testReturnAsList() throws SQLException { @Test @ConditionalIgnoreRule.ConditionalIgnore(condition = RunningOnGithubAction.class) public void testMapStructsFromChunks() throws SQLException { - Assume.assumeTrue(queryResultFormat != ResultSetFormatType.NATIVE_ARROW); withFirstRow( "select {'string':'a'}::OBJECT(string VARCHAR) FROM TABLE(GENERATOR(ROWCOUNT=>30000))", (resultSet) -> {