diff --git a/src/main/java/net/snowflake/client/core/SFArrowResultSet.java b/src/main/java/net/snowflake/client/core/SFArrowResultSet.java index 69195e8a4..7d8258110 100644 --- a/src/main/java/net/snowflake/client/core/SFArrowResultSet.java +++ b/src/main/java/net/snowflake/client/core/SFArrowResultSet.java @@ -26,6 +26,7 @@ import java.util.stream.Stream; import net.snowflake.client.core.arrow.ArrayConverter; import net.snowflake.client.core.arrow.ArrowVectorConverter; +import net.snowflake.client.core.arrow.StructConverter; import net.snowflake.client.core.arrow.StructObjectWrapper; import net.snowflake.client.core.arrow.VarCharConverter; import net.snowflake.client.core.arrow.VectorTypeConverter; @@ -564,32 +565,80 @@ public Timestamp getTimestamp(int columnIndex, TimeZone tz) throws SFException { @Override public Object getObject(int columnIndex) throws SFException { - int type = resultSetMetaData.getColumnType(columnIndex); - if (type == SnowflakeUtil.EXTRA_TYPES_VECTOR) { + int columnType = resultSetMetaData.getColumnType(columnIndex); + if (columnType == SnowflakeUtil.EXTRA_TYPES_VECTOR) { return getString(columnIndex); } - ArrowVectorConverter converter = currentChunkIterator.getCurrentConverter(columnIndex - 1); + + ArrowVectorConverter converter = getConfiguredConverter(columnIndex); int index = currentChunkIterator.getCurrentRowInRecordBatch(); wasNull = converter.isNull(index); - converter.setTreatNTZAsUTC(treatNTZAsUTC); - converter.setUseSessionTimezone(useSessionTimezone); - converter.setSessionTimeZone(sessionTimeZone); - Object obj = converter.toObject(index); boolean isStructuredType = resultSetMetaData.isStructuredTypeColumn(columnIndex); - if (isVarcharConvertedStruct(type, isStructuredType, converter)) { - if (obj != null) { + Object obj = converter.toObject(index); + if (obj == null) { + return null; + } + if (columnType == Types.STRUCT && isStructuredType) { + if (converter instanceof VarCharConverter) { return new StructObjectWrapper((String) obj, createJsonSqlInput(columnIndex, obj)); + } else if (converter instanceof StructConverter) { + return new StructObjectWrapper( + converter.toString(index), + createArrowSqlInput(columnIndex, (Map) obj), + obj); + } + } + return new StructObjectWrapper(converter.toString(index), null, obj); + } + + @Override + public Object getObject(int columnIndex, Class type) throws SFException { + if (String.class.isAssignableFrom(type)) { + return getObject(columnIndex); + } + + int columnType = resultSetMetaData.getColumnType(columnIndex); + if (columnType == SnowflakeUtil.EXTRA_TYPES_VECTOR) { + return getString(columnIndex); + } + ArrowVectorConverter converter = getConfiguredConverter(columnIndex); + int index = currentChunkIterator.getCurrentRowInRecordBatch(); + wasNull = converter.isNull(index); + Object obj = converter.toObject(index); + boolean isStructuredType = resultSetMetaData.isStructuredTypeColumn(columnIndex); + if (obj == null) { + return null; + } + if (columnType == Types.STRUCT && isStructuredType) { + if (converter instanceof VarCharConverter) { + return new StructObjectWrapper(null, createJsonSqlInput(columnIndex, obj)); + } else if (converter instanceof StructConverter) { + return new StructObjectWrapper( + null, createArrowSqlInput(columnIndex, (Map) obj)); } } - return obj; + return new StructObjectWrapper(null, null, obj); + } + + private ArrowVectorConverter getConfiguredConverter(int columnIndex) throws SFException { + ArrowVectorConverter converter = currentChunkIterator.getCurrentConverter(columnIndex - 1); + converter.setTreatNTZAsUTC(treatNTZAsUTC); + converter.setUseSessionTimezone(useSessionTimezone); + converter.setSessionTimeZone(sessionTimeZone); + + return converter; } - private boolean isVarcharConvertedStruct( - int type, boolean isStructuredType, ArrowVectorConverter converter) { - return type == Types.STRUCT && isStructuredType && converter instanceof VarCharConverter; + private SQLInput createArrowSqlInput(int columnIndex, Map input) + throws SFException { + if (input == null) { + return null; + } + return new ArrowSqlInput( + input, session, converters, resultSetMetaData.getColumnFields(columnIndex)); } - private Object createJsonSqlInput(int columnIndex, Object obj) throws SFException { + private SQLInput createJsonSqlInput(int columnIndex, Object obj) throws SFException { try { if (obj == null) { return null; @@ -620,11 +669,7 @@ public Array getArray(int columnIndex) throws SFException { if (converter instanceof VarCharConverter) { return getJsonArray((String) obj, columnIndex); } else if (converter instanceof ArrayConverter || converter instanceof VectorTypeConverter) { - StructObjectWrapper structObjectWrapper = (StructObjectWrapper) obj; - return getArrowArray( - structObjectWrapper.getJsonString(), - (List) structObjectWrapper.getObject(), - columnIndex); + return getArrowArray(converter.toString(), (List) obj, columnIndex); } else { throw new SFException(queryId, ErrorCode.INVALID_STRUCT_DATA); } diff --git a/src/main/java/net/snowflake/client/core/SFBaseResultSet.java b/src/main/java/net/snowflake/client/core/SFBaseResultSet.java index c0b6256ad..1ea2b09f9 100644 --- a/src/main/java/net/snowflake/client/core/SFBaseResultSet.java +++ b/src/main/java/net/snowflake/client/core/SFBaseResultSet.java @@ -113,6 +113,8 @@ public abstract class SFBaseResultSet { public abstract Object getObject(int columnIndex) throws SFException; + public abstract Object getObject(int columnIndex, Class object) throws SFException; + public Array getArray(int columnIndex) throws SFException { throw new UnsupportedOperationException(); } diff --git a/src/main/java/net/snowflake/client/core/SFJsonResultSet.java b/src/main/java/net/snowflake/client/core/SFJsonResultSet.java index 04e8d3fba..2ce4eb824 100644 --- a/src/main/java/net/snowflake/client/core/SFJsonResultSet.java +++ b/src/main/java/net/snowflake/client/core/SFJsonResultSet.java @@ -103,6 +103,11 @@ public Object getObject(int columnIndex) throws SFException { } } + @Override + public Object getObject(int columnIndex, Class object) throws SFException { + return getObject(columnIndex); + } + /** * Sometimes large BIGINTS overflow the java Long type. In these cases, return a BigDecimal type * instead. @@ -292,7 +297,7 @@ public Converters getConverters() { return converters; } - private Object getSqlInput(String input, int columnIndex) throws SFException { + private SQLInput getSqlInput(String input, int columnIndex) throws SFException { try { JsonNode jsonNode = OBJECT_MAPPER.readTree(input); return new JsonSqlInput( diff --git a/src/main/java/net/snowflake/client/core/arrow/ArrayConverter.java b/src/main/java/net/snowflake/client/core/arrow/ArrayConverter.java index 96b683151..6c6afcd62 100644 --- a/src/main/java/net/snowflake/client/core/arrow/ArrayConverter.java +++ b/src/main/java/net/snowflake/client/core/arrow/ArrayConverter.java @@ -25,7 +25,7 @@ public ArrayConverter(ListVector valueVector, int vectorIndex, DataConversionCon @Override public Object toObject(int index) throws SFException { - return isNull(index) ? null : new StructObjectWrapper(toString(index), vector.getObject(index)); + return isNull(index) ? null : vector.getObject(index); } @Override diff --git a/src/main/java/net/snowflake/client/core/arrow/MapConverter.java b/src/main/java/net/snowflake/client/core/arrow/MapConverter.java index 0c6ca072e..d556ef586 100644 --- a/src/main/java/net/snowflake/client/core/arrow/MapConverter.java +++ b/src/main/java/net/snowflake/client/core/arrow/MapConverter.java @@ -2,7 +2,6 @@ import java.nio.charset.StandardCharsets; import java.util.List; -import java.util.Map; import java.util.stream.Collectors; import net.snowflake.client.core.DataConversionContext; import net.snowflake.client.core.SFException; @@ -36,12 +35,9 @@ public Object toObject(int index) throws SFException { List> entriesList = (List>) vector.getObject(index); - Map map = - entriesList.stream() - .collect( - Collectors.toMap( - entry -> entry.get("key").toString(), entry -> entry.get("value"))); - return new StructObjectWrapper(toString(index), map); + return entriesList.stream() + .collect( + Collectors.toMap(entry -> entry.get("key").toString(), entry -> entry.get("value"))); } @Override diff --git a/src/main/java/net/snowflake/client/core/arrow/StructConverter.java b/src/main/java/net/snowflake/client/core/arrow/StructConverter.java index ab7d20382..c2ef36ef7 100644 --- a/src/main/java/net/snowflake/client/core/arrow/StructConverter.java +++ b/src/main/java/net/snowflake/client/core/arrow/StructConverter.java @@ -21,9 +21,7 @@ public StructConverter(StructVector vector, int columnIndex, DataConversionConte @Override public Object toObject(int index) throws SFException { - return isNull(index) - ? null - : new StructObjectWrapper(toString(index), structVector.getObject(index)); + return isNull(index) ? null : structVector.getObject(index); } @Override diff --git a/src/main/java/net/snowflake/client/core/arrow/StructObjectWrapper.java b/src/main/java/net/snowflake/client/core/arrow/StructObjectWrapper.java index 8219c110a..75b2f453c 100644 --- a/src/main/java/net/snowflake/client/core/arrow/StructObjectWrapper.java +++ b/src/main/java/net/snowflake/client/core/arrow/StructObjectWrapper.java @@ -3,15 +3,24 @@ */ package net.snowflake.client.core.arrow; +import java.sql.SQLInput; import net.snowflake.client.core.SnowflakeJdbcInternalApi; @SnowflakeJdbcInternalApi public class StructObjectWrapper { private final String jsonString; + private final SQLInput sqlInput; private final Object object; - public StructObjectWrapper(String jsonString, Object object) { + public StructObjectWrapper(String jsonString, SQLInput sqlInput) { this.jsonString = jsonString; + this.sqlInput = sqlInput; + this.object = null; + } + + public StructObjectWrapper(String jsonString, SQLInput sqlInput, Object object) { + this.jsonString = jsonString; + this.sqlInput = sqlInput; this.object = object; } @@ -19,6 +28,10 @@ public String getJsonString() { return jsonString; } + public SQLInput getSqlInput() { + return sqlInput; + } + public Object getObject() { return object; } diff --git a/src/main/java/net/snowflake/client/core/arrow/VectorTypeConverter.java b/src/main/java/net/snowflake/client/core/arrow/VectorTypeConverter.java index cb9dcad73..2050a3d1f 100644 --- a/src/main/java/net/snowflake/client/core/arrow/VectorTypeConverter.java +++ b/src/main/java/net/snowflake/client/core/arrow/VectorTypeConverter.java @@ -27,8 +27,7 @@ public Object toObject(int index) throws SFException { if (isNull(index)) { return null; } - Object object = vector.getObject(index); - return new StructObjectWrapper(object.toString(), object); + return vector.getObject(index); } @Override diff --git a/src/main/java/net/snowflake/client/jdbc/SnowflakeBaseResultSet.java b/src/main/java/net/snowflake/client/jdbc/SnowflakeBaseResultSet.java index 633083391..e92f4c86f 100644 --- a/src/main/java/net/snowflake/client/jdbc/SnowflakeBaseResultSet.java +++ b/src/main/java/net/snowflake/client/jdbc/SnowflakeBaseResultSet.java @@ -1399,7 +1399,10 @@ public T getObject(int columnIndex, Class type) throws SQLException { () -> { StructObjectWrapper structObjectWrapper = (StructObjectWrapper) sfBaseResultSet.getObject(columnIndex); - return (SQLInput) createJsonSqlInput(columnIndex, structObjectWrapper); + if (structObjectWrapper == null) { + return null; + } + return structObjectWrapper.getSqlInput(); }); if (sqlInput == null) { return null; @@ -1637,23 +1640,25 @@ public Map getMap(int columnIndex, Class type) throws SQLExcep int scale = valueFieldMetadata.getScale(); TimeZone tz = sfBaseResultSet.getSessionTimeZone(); StructObjectWrapper structObjectWrapper = - (StructObjectWrapper) - SnowflakeUtil.mapSFExceptionToSQLException( - () -> sfBaseResultSet.getObject(columnIndex)); + SnowflakeUtil.mapSFExceptionToSQLException( + () -> (StructObjectWrapper) sfBaseResultSet.getObject(columnIndex, type)); if (structObjectWrapper == null) { return null; } + Map map = - mapSFExceptionToSQLException( - () -> prepareMapWithValues(structObjectWrapper.getObject(), type)); + mapSFExceptionToSQLException(() -> prepareMapWithValues(structObjectWrapper, type)); Map resultMap = new HashMap<>(); for (Map.Entry entry : map.entrySet()) { if (SQLData.class.isAssignableFrom(type)) { SQLData instance = (SQLData) SQLDataCreationHelper.create(type); + SQLInput parentSqlInput = structObjectWrapper.getSqlInput(); SQLInput sqlInput = sfBaseResultSet.createSqlInputForColumn( entry.getValue(), - structObjectWrapper.getObject().getClass(), + parentSqlInput == null + ? structObjectWrapper.getObject().getClass() + : parentSqlInput.getClass(), columnIndex, session, valueFieldMetadata.getFields()); @@ -1812,11 +1817,12 @@ public boolean isWrapperFor(Class iface) throws SQLException { return iface.isInstance(this); } - private Map prepareMapWithValues(Object object, Class type) - throws SFException { - if (object instanceof JsonSqlInput) { + private Map prepareMapWithValues( + StructObjectWrapper structObjectWrapper, Class type) throws SFException { + SQLInput sqlInput = structObjectWrapper.getSqlInput(); + if (sqlInput instanceof JsonSqlInput) { Map map = new HashMap<>(); - JsonNode jsonNode = ((JsonSqlInput) object).getInput(); + JsonNode jsonNode = ((JsonSqlInput) sqlInput).getInput(); for (Iterator it = jsonNode.fieldNames(); it.hasNext(); ) { String name = it.next(); map.put( @@ -1826,14 +1832,14 @@ private Map prepareMapWithValues(Object object, Class typ : SnowflakeUtil.getJsonNodeStringValue(jsonNode.get(name))); } return map; - } else if (object instanceof Map) { - return (Map) object; + } else if (structObjectWrapper.getObject() != null) { + return (Map) structObjectWrapper.getObject(); } else { throw new SFException(ErrorCode.INVALID_STRUCT_DATA, "Object couldn't be converted to map"); } } - private Object createJsonSqlInput(int columnIndex, StructObjectWrapper obj) throws SFException { + private SQLInput createJsonSqlInput(int columnIndex, StructObjectWrapper obj) throws SFException { try { if (obj == null) { return null; diff --git a/src/main/java/net/snowflake/client/jdbc/SnowflakeResultSetV1.java b/src/main/java/net/snowflake/client/jdbc/SnowflakeResultSetV1.java index e9db5ec71..6e4bb2d93 100644 --- a/src/main/java/net/snowflake/client/jdbc/SnowflakeResultSetV1.java +++ b/src/main/java/net/snowflake/client/jdbc/SnowflakeResultSetV1.java @@ -274,18 +274,27 @@ public Object getObject(int columnIndex) throws SQLException { SnowflakeUtil.mapSFExceptionToSQLException(() -> sfBaseResultSet.getObject(columnIndex)); if (object == null) { return null; - } else if (object instanceof JsonSqlInput) { - return ((JsonSqlInput) object).getText(); - } else if (object instanceof StructObjectWrapper) { - return ((StructObjectWrapper) object).getJsonString(); - } else if (object instanceof SfSqlArray) { + } + if (object instanceof SfSqlArray) { return ((SfSqlArray) object).getText(); - } else if (object instanceof ArrowSqlInput) { - throw new SQLException( - "Arrow native struct couldn't be converted to String. To map to SqlData the method getObject(int columnIndex, Class type) should be used"); - } else { - return object; } + if (object instanceof StructObjectWrapper) { + StructObjectWrapper structObjectWrapper = (StructObjectWrapper) object; + if (structObjectWrapper.getSqlInput() instanceof JsonSqlInput) { + return ((JsonSqlInput) structObjectWrapper.getSqlInput()).getText(); + } + if (structObjectWrapper.getSqlInput() instanceof ArrowSqlInput) { + throw new SQLException( + "Arrow native struct couldn't be converted to String. To map to SqlData the method getObject(int columnIndex, Class type) should be used"); + } + if (structObjectWrapper.getObject() != null) { + return structObjectWrapper.getObject(); + } + if (structObjectWrapper.getJsonString() != null) { + return structObjectWrapper.getJsonString(); + } + } + return object; } public Array getArray(int columnIndex) throws SQLException {