Skip to content

Commit

Permalink
Add native arrow structured types support
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-dheyman committed Mar 7, 2024
1 parent 63ba15c commit 2cfa1e9
Show file tree
Hide file tree
Showing 14 changed files with 510 additions and 71 deletions.
293 changes: 293 additions & 0 deletions src/main/java/net/snowflake/client/core/ArrowSqlInput.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,293 @@
package net.snowflake.client.core;

import net.snowflake.client.core.json.Converters;
import net.snowflake.client.core.structs.SQLDataCreationHelper;
import net.snowflake.client.jdbc.FieldMetadata;
import net.snowflake.client.jdbc.SnowflakeLoggedFeatureNotSupportedException;
import net.snowflake.client.util.ThrowingBiFunction;
import net.snowflake.common.core.SFTimestamp;
import net.snowflake.common.core.SnowflakeDateTimeFormat;
import org.apache.arrow.vector.util.JsonStringHashMap;

import java.io.InputStream;
import java.io.Reader;
import java.math.BigDecimal;
import java.net.URL;
import java.sql.*;
import java.time.Instant;
import java.time.ZoneOffset;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.TimeZone;

import static net.snowflake.client.jdbc.SnowflakeUtil.mapExceptions;

@SnowflakeJdbcInternalApi
public class ArrowSqlInput implements SFSqlInput {

  /** Raw Arrow struct value: field name -> field value, in field order. */
  private final JsonStringHashMap<String, Object> input;

  private final SFBaseSession session;

  /** Iterator over the struct's values; each read*() call consumes exactly one field. */
  private final Iterator<Object> elements;

  private final Converters converters;

  /** Per-field metadata, parallel to the iteration order of {@code input}. */
  private final List<FieldMetadata> fields;

  /** Index of the next field to be read; advanced by {@link #withNextValue}. */
  private int currentIndex = 0;

  /**
   * Creates a {@link java.sql.SQLInput} implementation backed by a deserialized Arrow struct.
   *
   * @param input the Arrow struct as a field-name-to-value map
   * @param session session used for parameter lookup and error reporting
   * @param converters type converters used to coerce raw values to JDBC types
   * @param fields metadata describing each field of the struct, in field order
   */
  public ArrowSqlInput(
      JsonStringHashMap<String, Object> input,
      SFBaseSession session,
      Converters converters,
      List<FieldMetadata> fields) {
    this.input = input;
    this.elements = input.values().iterator();
    this.session = session;
    this.converters = converters;
    this.fields = fields;
  }

  /** Reads the next field as a String using the string converter. */
  @Override
  public String readString() throws SQLException {
    return withNextValue(
        (value, fieldMetadata) -> {
          int columnType = ColumnTypeHelper.getColumnType(fieldMetadata.getType(), session);
          int columnSubType = fieldMetadata.getType();
          int scale = fieldMetadata.getScale();
          return mapExceptions(
              () ->
                  converters
                      .getStringConverter()
                      .getString(value, columnType, columnSubType, scale));
        });
  }

  /** Reads the next field as a boolean. */
  @Override
  public boolean readBoolean() throws SQLException {
    return withNextValue(
        (value, fieldMetadata) -> {
          int columnType = ColumnTypeHelper.getColumnType(fieldMetadata.getType(), session);
          return mapExceptions(
              () -> converters.getBooleanConverter().getBoolean(value, columnType));
        });
  }

  /** Reads the next field as a byte. */
  @Override
  public byte readByte() throws SQLException {
    return withNextValue(
        (value, fieldMetadata) ->
            mapExceptions(() -> converters.getNumberConverter().getByte(value)));
  }

  /** Reads the next field as a short. */
  @Override
  public short readShort() throws SQLException {
    return withNextValue(
        (value, fieldMetadata) -> {
          int columnType = ColumnTypeHelper.getColumnType(fieldMetadata.getType(), session);
          return mapExceptions(() -> converters.getNumberConverter().getShort(value, columnType));
        });
  }

  /** Reads the next field as an int. */
  @Override
  public int readInt() throws SQLException {
    return withNextValue(
        (value, fieldMetadata) -> {
          int columnType = ColumnTypeHelper.getColumnType(fieldMetadata.getType(), session);
          return mapExceptions(() -> converters.getNumberConverter().getInt(value, columnType));
        });
  }

  /** Reads the next field as a long. */
  @Override
  public long readLong() throws SQLException {
    return withNextValue(
        (value, fieldMetadata) -> {
          int columnType = ColumnTypeHelper.getColumnType(fieldMetadata.getType(), session);
          return mapExceptions(() -> converters.getNumberConverter().getLong(value, columnType));
        });
  }

  /** Reads the next field as a float. */
  @Override
  public float readFloat() throws SQLException {
    return withNextValue(
        (value, fieldMetadata) -> {
          int columnType = ColumnTypeHelper.getColumnType(fieldMetadata.getType(), session);
          return mapExceptions(() -> converters.getNumberConverter().getFloat(value, columnType));
        });
  }

  /** Reads the next field as a double. */
  @Override
  public double readDouble() throws SQLException {
    return withNextValue(
        (value, fieldMetadata) -> {
          int columnType = ColumnTypeHelper.getColumnType(fieldMetadata.getType(), session);
          return mapExceptions(() -> converters.getNumberConverter().getDouble(value, columnType));
        });
  }

  /** Reads the next field as a BigDecimal. */
  @Override
  public BigDecimal readBigDecimal() throws SQLException {
    return withNextValue(
        (value, fieldMetadata) -> {
          int columnType = ColumnTypeHelper.getColumnType(fieldMetadata.getType(), session);
          return mapExceptions(
              () -> converters.getNumberConverter().getBigDecimal(value, columnType));
        });
  }

  /** Reads the next field as a byte array. */
  @Override
  public byte[] readBytes() throws SQLException {
    return withNextValue(
        (value, fieldMetadata) -> {
          int columnType = ColumnTypeHelper.getColumnType(fieldMetadata.getType(), session);
          int columnSubType = fieldMetadata.getType();
          int scale = fieldMetadata.getScale();
          return mapExceptions(
              () ->
                  converters.getBytesConverter().getBytes(value, columnType, columnSubType, scale));
        });
  }

  /**
   * Reads the next field as a {@link Date}, parsing the string value with the session's
   * DATE_OUTPUT_FORMAT. Returns null for a null field value (guard added for parity with
   * {@link #readTimestamp(TimeZone)}; previously a null value would fail on the String cast).
   */
  @Override
  public Date readDate() throws SQLException {
    return withNextValue(
        (value, fieldMetadata) -> {
          if (value == null) {
            return null;
          }
          SnowflakeDateTimeFormat formatter =
              SqlInputTimestampUtil.extractDateTimeFormat(session, "DATE_OUTPUT_FORMAT");
          SFTimestamp timestamp = formatter.parse((String) value);
          return Date.valueOf(
              Instant.ofEpochMilli(timestamp.getTime()).atZone(ZoneOffset.UTC).toLocalDate());
        });
  }

  /**
   * Reads the next field as a {@link Time}, parsing the string value with the session's
   * TIME_OUTPUT_FORMAT. Returns null for a null field value (guard added for parity with
   * {@link #readTimestamp(TimeZone)}).
   */
  @Override
  public Time readTime() throws SQLException {
    return withNextValue(
        (value, fieldMetadata) -> {
          if (value == null) {
            return null;
          }
          SnowflakeDateTimeFormat formatter =
              SqlInputTimestampUtil.extractDateTimeFormat(session, "TIME_OUTPUT_FORMAT");
          SFTimestamp timestamp = formatter.parse((String) value);
          return Time.valueOf(
              Instant.ofEpochMilli(timestamp.getTime()).atZone(ZoneOffset.UTC).toLocalTime());
        });
  }

  /** Reads the next field as a {@link Timestamp} in the default time zone. */
  @Override
  public Timestamp readTimestamp() throws SQLException {
    return readTimestamp(null);
  }

  /**
   * Reads the next field as a {@link Timestamp}.
   *
   * @param tz target time zone; may be null
   * @return the timestamp value, or null if the field value is null
   */
  @Override
  public Timestamp readTimestamp(TimeZone tz) throws SQLException {
    return withNextValue(
        (value, fieldMetadata) -> {
          if (value == null) {
            return null;
          }
          int columnType = ColumnTypeHelper.getColumnType(fieldMetadata.getType(), session);
          int columnSubType = fieldMetadata.getType();
          int scale = fieldMetadata.getScale();
          // TODO structuredType what if not a string value?
          Timestamp result =
              SqlInputTimestampUtil.getTimestampFromType(columnSubType, (String) value, session);
          if (result != null) {
            return result;
          }
          // Fall back to the generic date-time converter when no type-specific format applied.
          return mapExceptions(
              () ->
                  converters
                      .getDateTimeConverter()
                      .getTimestamp(value, columnType, columnSubType, tz, scale));
        });
  }

  @Override
  public Reader readCharacterStream() throws SQLException {
    throw new SnowflakeLoggedFeatureNotSupportedException(session, "readCharacterStream");
  }

  @Override
  public InputStream readAsciiStream() throws SQLException {
    throw new SnowflakeLoggedFeatureNotSupportedException(session, "readAsciiStream");
  }

  @Override
  public InputStream readBinaryStream() throws SQLException {
    throw new SnowflakeLoggedFeatureNotSupportedException(session, "readBinaryStream");
  }

  /**
   * Reads the next field as a raw map (nested structured value).
   *
   * @throws SQLException if the value is not a map; the message is null-safe (previously a null
   *     value caused a NullPointerException while building the error message)
   */
  @Override
  public Object readObject() throws SQLException {
    return withNextValue(
        (value, fieldMetadata) -> {
          if (!(value instanceof JsonStringHashMap)) {
            throw new SQLException(
                "Invalid value passed to 'readObject()', expected Map; got: "
                    + (value == null ? "null" : value.getClass()));
          }
          return value;
        });
  }

  /**
   * Reads the next field as an instance of the given {@link SQLData} class, recursively
   * populating it from the nested Arrow struct via a child {@code ArrowSqlInput}.
   *
   * @param type the target class; must be creatable by {@link SQLDataCreationHelper}
   */
  @Override
  public <T> T readObject(Class<T> type) throws SQLException {
    return withNextValue(
        (value, fieldMetadata) -> {
          SQLData instance = (SQLData) SQLDataCreationHelper.create(type);
          instance.readSQL(
              new ArrowSqlInput(
                  (JsonStringHashMap<String, Object>) value,
                  session,
                  converters,
                  Arrays.asList(fieldMetadata.getFields())),
              null);
          @SuppressWarnings("unchecked") // instance was created from 'type' by the helper
          T result = (T) instance;
          return result;
        });
  }

  @Override
  public Ref readRef() throws SQLException {
    throw new SnowflakeLoggedFeatureNotSupportedException(session, "readRef");
  }

  @Override
  public Blob readBlob() throws SQLException {
    throw new SnowflakeLoggedFeatureNotSupportedException(session, "readBlob");
  }

  @Override
  public Clob readClob() throws SQLException {
    throw new SnowflakeLoggedFeatureNotSupportedException(session, "readClob");
  }

  @Override
  public Array readArray() throws SQLException {
    throw new SnowflakeLoggedFeatureNotSupportedException(session, "readArray");
  }

  /**
   * Always returns false.
   *
   * <p>NOTE(review): this assumes nulls never occur in structured types, yet
   * {@link #readTimestamp(TimeZone)} explicitly handles a null value — confirm whether null
   * tracking is required here.
   */
  @Override
  public boolean wasNull() throws SQLException {
    return false; // nulls are not allowed in structure types
  }

  @Override
  public URL readURL() throws SQLException {
    // Fixed copy-paste bug: previously reported the unsupported operation as
    // "readCharacterStream" instead of "readURL".
    throw new SnowflakeLoggedFeatureNotSupportedException(session, "readURL");
  }

  @Override
  public NClob readNClob() throws SQLException {
    throw new SnowflakeLoggedFeatureNotSupportedException(session, "readNClob");
  }

  @Override
  public String readNString() throws SQLException {
    throw new SnowflakeLoggedFeatureNotSupportedException(session, "readNString");
  }

  @Override
  public SQLXML readSQLXML() throws SQLException {
    throw new SnowflakeLoggedFeatureNotSupportedException(session, "readSQLXML");
  }

  @Override
  public RowId readRowId() throws SQLException {
    throw new SnowflakeLoggedFeatureNotSupportedException(session, "readRowId");
  }

  /**
   * Applies {@code action} to the next struct value and its corresponding field metadata,
   * advancing the field cursor by one. Every read*() method funnels through here so the value
   * iterator and the metadata index stay in lockstep.
   */
  private <T> T withNextValue(ThrowingBiFunction<Object, FieldMetadata, T, SQLException> action)
      throws SQLException {
    return action.apply(elements.next(), fields.get(currentIndex++));
  }
}
46 changes: 5 additions & 41 deletions src/main/java/net/snowflake/client/core/JsonSqlInput.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import java.sql.SQLXML;
import java.sql.Time;
import java.sql.Timestamp;
import java.sql.Types;
import java.time.Instant;
import java.time.ZoneOffset;
import java.util.Arrays;
Expand All @@ -31,12 +30,12 @@
import net.snowflake.client.core.structs.SQLDataCreationHelper;
import net.snowflake.client.jdbc.FieldMetadata;
import net.snowflake.client.jdbc.SnowflakeLoggedFeatureNotSupportedException;
import net.snowflake.client.jdbc.SnowflakeUtil;
import net.snowflake.client.util.ThrowingCallable;
import net.snowflake.client.util.ThrowingTriFunction;
import net.snowflake.common.core.SFTimestamp;
import net.snowflake.common.core.SnowflakeDateTimeFormat;

import static net.snowflake.client.jdbc.SnowflakeUtil.mapExceptions;

@SnowflakeJdbcInternalApi
public class JsonSqlInput implements SFSqlInput {
private final JsonNode input;
Expand Down Expand Up @@ -163,7 +162,7 @@ public byte[] readBytes() throws SQLException {
public Date readDate() throws SQLException {
return withNextValue(
(value, jsonNode, fieldMetadata) -> {
SnowflakeDateTimeFormat formatter = getFormat(session, "DATE_OUTPUT_FORMAT");
SnowflakeDateTimeFormat formatter = SqlInputTimestampUtil.extractDateTimeFormat(session, "DATE_OUTPUT_FORMAT");
SFTimestamp timestamp = formatter.parse((String) value);
return Date.valueOf(
Instant.ofEpochMilli(timestamp.getTime()).atZone(ZoneOffset.UTC).toLocalDate());
Expand All @@ -174,7 +173,7 @@ public Date readDate() throws SQLException {
public Time readTime() throws SQLException {
return withNextValue(
(value, jsonNode, fieldMetadata) -> {
SnowflakeDateTimeFormat formatter = getFormat(session, "TIME_OUTPUT_FORMAT");
SnowflakeDateTimeFormat formatter = SqlInputTimestampUtil.extractDateTimeFormat(session, "TIME_OUTPUT_FORMAT");
SFTimestamp timestamp = formatter.parse((String) value);
return Time.valueOf(
Instant.ofEpochMilli(timestamp.getTime()).atZone(ZoneOffset.UTC).toLocalTime());
Expand All @@ -197,7 +196,7 @@ public Timestamp readTimestamp(TimeZone tz) throws SQLException {
int columnSubType = fieldMetadata.getType();
int scale = fieldMetadata.getScale();
// TODO structuredType what if not a string value?
Timestamp result = getTimestampFromType(columnSubType, (String) value);
Timestamp result = SqlInputTimestampUtil.getTimestampFromType(columnSubType, (String) value, session);
if (result != null) {
return result;
}
Expand All @@ -209,28 +208,6 @@ public Timestamp readTimestamp(TimeZone tz) throws SQLException {
});
}

private Timestamp getTimestampFromType(int columnSubType, String value) {
if (columnSubType == SnowflakeUtil.EXTRA_TYPES_TIMESTAMP_LTZ) {
return getTimestampFromFormat("TIMESTAMP_LTZ_OUTPUT_FORMAT", value);
} else if (columnSubType == SnowflakeUtil.EXTRA_TYPES_TIMESTAMP_NTZ
|| columnSubType == Types.TIMESTAMP) {
return getTimestampFromFormat("TIMESTAMP_NTZ_OUTPUT_FORMAT", value);
} else if (columnSubType == SnowflakeUtil.EXTRA_TYPES_TIMESTAMP_TZ) {
return getTimestampFromFormat("TIMESTAMP_TZ_OUTPUT_FORMAT", value);
} else {
return null;
}
}

private Timestamp getTimestampFromFormat(String format, String value) {
String rawFormat = (String) session.getCommonParameters().get(format);
if (rawFormat == null || rawFormat.isEmpty()) {
rawFormat = (String) session.getCommonParameters().get("TIMESTAMP_OUTPUT_FORMAT");
}
SnowflakeDateTimeFormat formatter = SnowflakeDateTimeFormat.fromSqlFormat(rawFormat);
return formatter.parse(value).getTimestamp();
}

@Override
public Reader readCharacterStream() throws SQLException {
throw new SnowflakeLoggedFeatureNotSupportedException(session, "readCharacterStream");
Expand Down Expand Up @@ -333,17 +310,4 @@ private Object getValue(JsonNode jsonNode) {
}
return null;
}

private <T> T mapExceptions(ThrowingCallable<T, SFException> action) throws SQLException {
try {
return action.call();
} catch (SFException e) {
throw new SQLException(e);
}
}

private static SnowflakeDateTimeFormat getFormat(SFBaseSession session, String format) {
return SnowflakeDateTimeFormat.fromSqlFormat(
(String) session.getCommonParameters().get(format));
}
}
Loading

0 comments on commit 2cfa1e9

Please sign in to comment.