diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/BlobBuilder.java b/src/main/java/net/snowflake/ingest/streaming/internal/BlobBuilder.java index 177710714..18b28a037 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/BlobBuilder.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/BlobBuilder.java @@ -161,7 +161,8 @@ static Blob constructBlobAndMetadata( chunkMetadataBuilder .setMajorVersion(Constants.PARQUET_MAJOR_VERSION) .setMinorVersion(Constants.PARQUET_MINOR_VERSION) - .setCreatedOn(0L) + // set createdOn in seconds + .setCreatedOn(System.currentTimeMillis() / 1000) .setExtendedMetadataSize(-1L); } diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/ClientBufferParameters.java b/src/main/java/net/snowflake/ingest/streaming/internal/ClientBufferParameters.java index dfadd029a..0a9711ee8 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/ClientBufferParameters.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/ClientBufferParameters.java @@ -10,6 +10,8 @@ /** Channel's buffer relevant parameters that are set at the owning client level. */ public class ClientBufferParameters { + private static final String BDEC_PARQUET_MESSAGE_TYPE_NAME = "bdec"; + private static final String PARQUET_MESSAGE_TYPE_NAME = "schema"; private long maxChunkSizeInBytes; @@ -118,4 +120,8 @@ public boolean getIsIcebergMode() { public Optional getMaxRowGroups() { return maxRowGroups; } + + public String getParquetMessageTypeName() { + return isIcebergMode ? PARQUET_MESSAGE_TYPE_NAME : BDEC_PARQUET_MESSAGE_TYPE_NAME; + } } diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/DataValidationUtil.java b/src/main/java/net/snowflake/ingest/streaming/internal/DataValidationUtil.java index 6e3281997..8d8bff3f5 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/DataValidationUtil.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/DataValidationUtil.java @@ -1055,6 +1055,83 @@ static int validateAndParseBoolean(String columnName, Object input, long insertR insertRowIndex); } + /** + * Validate and cast Iceberg struct column to Map. Allowed Java type: + * + *
    + *
  • Map<String, Object>
+ * + * @param columnName Column name, used in validation error messages + * @param input Object to validate and parse + * @param insertRowIndex Row index for error reporting + * @return Object cast to Map + */ + static Map validateAndParseIcebergStruct( + String columnName, Object input, long insertRowIndex) { + if (!(input instanceof Map)) { + throw typeNotAllowedException( + columnName, + input.getClass(), + "STRUCT", + new String[] {"Map"}, + insertRowIndex); + } + for (Object key : ((Map) input).keySet()) { + if (!(key instanceof String)) { + throw new SFException( + ErrorCode.INVALID_FORMAT_ROW, + String.format( + "Field name of struct %s must be of type String, rowIndex:%d", + columnName, insertRowIndex)); + } + } + + return (Map) input; + } + + /** + * Validate and parse Iceberg list column to an Iterable. Allowed Java type: + * + *
    + *
  • Iterable + *
+ * + * @param columnName Column name, used in validation error messages + * @param input Object to validate and parse + * @param insertRowIndex Row index for error reporting + * @return Object cast to Iterable + */ + static Iterable validateAndParseIcebergList( + String columnName, Object input, long insertRowIndex) { + if (!(input instanceof Iterable)) { + throw typeNotAllowedException( + columnName, input.getClass(), "LIST", new String[] {"Iterable"}, insertRowIndex); + } + return (Iterable) input; + } + + /** + * Validate and parse Iceberg map column to a map. Allowed Java type: + * + *
    + *
  • Map<Object, Object>
+ * + * @param columnName Column name, used in validation error messages + * @param input Object to validate and parse + * @param insertRowIndex Row index for error reporting + * @return Object cast to Map + */ + static Map validateAndParseIcebergMap( + String columnName, Object input, long insertRowIndex) { + if (!(input instanceof Map)) { + throw typeNotAllowedException( + columnName, input.getClass(), "MAP", new String[] {"Map"}, insertRowIndex); + } + return (Map) input; + } + static void checkValueInRange( BigDecimal bigDecimalValue, int scale, int precision, final long insertRowIndex) { BigDecimal comparand = diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/DropChannelRequestInternal.java b/src/main/java/net/snowflake/ingest/streaming/internal/DropChannelRequestInternal.java index 322b53acf..d199531b7 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/DropChannelRequestInternal.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/DropChannelRequestInternal.java @@ -33,6 +33,9 @@ class DropChannelRequestInternal implements IStreamingIngestRequest { @JsonProperty("client_sequencer") Long clientSequencer; + @JsonProperty("is_iceberg") + boolean isIceberg; + DropChannelRequestInternal( String requestId, String role, @@ -40,6 +43,7 @@ class DropChannelRequestInternal implements IStreamingIngestRequest { String schema, String table, String channel, + boolean isIceberg, Long clientSequencer) { this.requestId = requestId; this.role = role; @@ -47,6 +51,7 @@ class DropChannelRequestInternal implements IStreamingIngestRequest { this.schema = schema; this.table = table; this.channel = channel; + this.isIceberg = isIceberg; this.clientSequencer = clientSequencer; } @@ -74,6 +79,10 @@ String getSchema() { return schema; } + boolean isIceberg() { + return isIceberg; + } + Long getClientSequencer() { return clientSequencer; } @@ -86,7 +95,7 @@ String getFullyQualifiedTableName() { public String getStringForLogging() { return String.format( "DropChannelRequest(requestId=%s, role=%s, db=%s, schema=%s, table=%s, channel=%s," - + " clientSequencer=%s)", - requestId, role, database, schema, table, channel, clientSequencer); + + " isIceberg=%s, clientSequencer=%s)", + requestId, role, database, schema, table, channel, isIceberg, clientSequencer); } } diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/FileColumnProperties.java b/src/main/java/net/snowflake/ingest/streaming/internal/FileColumnProperties.java index 0ec671bf7..3a8dbc2b6 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/FileColumnProperties.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/FileColumnProperties.java @@ -1,7 +1,12 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. + */ + package net.snowflake.ingest.streaming.internal; import static net.snowflake.ingest.streaming.internal.BinaryStringUtils.truncateBytesAsHex; +import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; import java.math.BigInteger; import java.util.Objects; @@ -9,6 +14,7 @@ /** Audit register endpoint/FileColumnPropertyDTO property list. 
*/ class FileColumnProperties { private int columnOrdinal; + private Integer fieldId; private String minStrValue; private String maxStrValue; @@ -46,6 +52,7 @@ class FileColumnProperties { FileColumnProperties(RowBufferStats stats, boolean setDefaultValues) { this.setColumnOrdinal(stats.getOrdinal()); + this.setFieldId(stats.getFieldId()); this.setCollation(stats.getCollationDefinitionString()); this.setMaxIntValue( stats.getCurrentMaxIntValue() == null @@ -93,6 +100,16 @@ public void setColumnOrdinal(int columnOrdinal) { this.columnOrdinal = columnOrdinal; } + @JsonProperty("fieldId") + @JsonInclude(JsonInclude.Include.NON_NULL) + public Integer getFieldId() { + return fieldId; + } + + public void setFieldId(Integer fieldId) { + this.fieldId = fieldId; + } + // Annotation required in order to have package private fields serialized @JsonProperty("minStrValue") String getMinStrValue() { @@ -206,6 +223,7 @@ void setMaxStrNonCollated(String maxStrNonCollated) { public String toString() { final StringBuilder sb = new StringBuilder("{"); sb.append("\"columnOrdinal\": ").append(columnOrdinal); + sb.append(", \"fieldId\": ").append(fieldId); if (minIntValue != null) { sb.append(", \"minIntValue\": ").append(minIntValue); sb.append(", \"maxIntValue\": ").append(maxIntValue); diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/Flusher.java b/src/main/java/net/snowflake/ingest/streaming/internal/Flusher.java index 0cf8220bb..eb1d580c5 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/Flusher.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/Flusher.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022 Snowflake Computing Inc. All rights reserved. + * Copyright (c) 2022-2024 Snowflake Computing Inc. All rights reserved. */ package net.snowflake.ingest.streaming.internal; diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/IcebergParquetValueParser.java b/src/main/java/net/snowflake/ingest/streaming/internal/IcebergParquetValueParser.java index 18a66f4d5..963dbf188 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/IcebergParquetValueParser.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/IcebergParquetValueParser.java @@ -5,18 +5,28 @@ package net.snowflake.ingest.streaming.internal; import static net.snowflake.ingest.streaming.internal.DataValidationUtil.checkFixedLengthByteArray; +import static net.snowflake.ingest.utils.Utils.concatDotPath; +import static net.snowflake.ingest.utils.Utils.isNullOrEmpty; import java.math.BigDecimal; import java.math.BigInteger; import java.math.RoundingMode; import java.nio.charset.StandardCharsets; import java.time.ZoneId; +import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Map; import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; import net.snowflake.ingest.utils.Constants; import net.snowflake.ingest.utils.ErrorCode; import net.snowflake.ingest.utils.SFException; import net.snowflake.ingest.utils.Utils; +import org.apache.parquet.schema.GroupType; import org.apache.parquet.schema.LogicalTypeAnnotation; import org.apache.parquet.schema.LogicalTypeAnnotation.DateLogicalTypeAnnotation; import org.apache.parquet.schema.LogicalTypeAnnotation.DecimalLogicalTypeAnnotation; @@ -30,12 +40,15 @@ /** Parses a user Iceberg column value into Parquet internal representation for buffering. 
*/ class IcebergParquetValueParser { + static final String THREE_LEVEL_MAP_GROUP_NAME = "key_value"; + static final String THREE_LEVEL_LIST_GROUP_NAME = "list"; + /** * Parses a user column value into Parquet internal representation for buffering. * * @param value column value provided by user in a row * @param type Parquet column type - * @param stats column stats to update + * @param statsMap column stats map to update * @param defaultTimezone default timezone to use for timestamp parsing * @param insertRowsCurrIndex Row index corresponding the row to parse (w.r.t input rows in * insertRows API, and not buffered row) @@ -44,78 +57,116 @@ class IcebergParquetValueParser { static ParquetBufferValue parseColumnValueToParquet( Object value, Type type, - RowBufferStats stats, + Map statsMap, ZoneId defaultTimezone, long insertRowsCurrIndex) { - Utils.assertNotNull("Parquet column stats", stats); + Utils.assertNotNull("Parquet column stats map", statsMap); + return parseColumnValueToParquet( + value, type, statsMap, defaultTimezone, insertRowsCurrIndex, null, false); + } + + private static ParquetBufferValue parseColumnValueToParquet( + Object value, + Type type, + Map statsMap, + ZoneId defaultTimezone, + long insertRowsCurrIndex, + String path, + boolean isDescendantsOfRepeatingGroup) { + path = isNullOrEmpty(path) ? type.getName() : concatDotPath(path, type.getName()); float estimatedParquetSize = 0F; + + if (type.isPrimitive()) { + if (!statsMap.containsKey(path)) { + throw new SFException( + ErrorCode.INTERNAL_ERROR, String.format("Stats not found for column: %s", path)); + } + } + if (value != null) { - estimatedParquetSize += ParquetBufferValue.DEFINITION_LEVEL_ENCODING_BYTE_LEN; - PrimitiveType primitiveType = type.asPrimitiveType(); - switch (primitiveType.getPrimitiveTypeName()) { - case BOOLEAN: - int intValue = - DataValidationUtil.validateAndParseBoolean( - type.getName(), value, insertRowsCurrIndex); - value = intValue > 0; - stats.addIntValue(BigInteger.valueOf(intValue)); - estimatedParquetSize += ParquetBufferValue.BIT_ENCODING_BYTE_LEN; - break; - case INT32: - int intVal = getInt32Value(value, primitiveType, insertRowsCurrIndex); - value = intVal; - stats.addIntValue(BigInteger.valueOf(intVal)); - estimatedParquetSize += 4; - break; - case INT64: - long longVal = getInt64Value(value, primitiveType, defaultTimezone, insertRowsCurrIndex); - value = longVal; - stats.addIntValue(BigInteger.valueOf(longVal)); - estimatedParquetSize += 8; - break; - case FLOAT: - float floatVal = - (float) - DataValidationUtil.validateAndParseReal( - type.getName(), value, insertRowsCurrIndex); - value = floatVal; - stats.addRealValue((double) floatVal); - estimatedParquetSize += 4; - break; - case DOUBLE: - double doubleVal = - DataValidationUtil.validateAndParseReal(type.getName(), value, insertRowsCurrIndex); - value = doubleVal; - stats.addRealValue(doubleVal); - estimatedParquetSize += 8; - break; - case BINARY: - byte[] byteVal = getBinaryValue(value, primitiveType, stats, insertRowsCurrIndex); - value = byteVal; - estimatedParquetSize += - ParquetBufferValue.BYTE_ARRAY_LENGTH_ENCODING_BYTE_LEN + byteVal.length; - break; - case FIXED_LEN_BYTE_ARRAY: - byte[] fixedLenByteArrayVal = - getFixedLenByteArrayValue(value, primitiveType, stats, insertRowsCurrIndex); - value = fixedLenByteArrayVal; - estimatedParquetSize += - ParquetBufferValue.BYTE_ARRAY_LENGTH_ENCODING_BYTE_LEN + fixedLenByteArrayVal.length; - break; - default: - throw new SFException( - ErrorCode.UNKNOWN_DATA_TYPE, - 
type.getLogicalTypeAnnotation(), - primitiveType.getPrimitiveTypeName()); + if (type.isPrimitive()) { + RowBufferStats stats = statsMap.get(path); + estimatedParquetSize += ParquetBufferValue.DEFINITION_LEVEL_ENCODING_BYTE_LEN; + estimatedParquetSize += + isDescendantsOfRepeatingGroup + ? ParquetBufferValue.REPETITION_LEVEL_ENCODING_BYTE_LEN + : 0; + PrimitiveType primitiveType = type.asPrimitiveType(); + switch (primitiveType.getPrimitiveTypeName()) { + case BOOLEAN: + int intValue = + DataValidationUtil.validateAndParseBoolean(path, value, insertRowsCurrIndex); + value = intValue > 0; + stats.addIntValue(BigInteger.valueOf(intValue)); + estimatedParquetSize += ParquetBufferValue.BIT_ENCODING_BYTE_LEN; + break; + case INT32: + int intVal = getInt32Value(value, primitiveType, path, insertRowsCurrIndex); + value = intVal; + stats.addIntValue(BigInteger.valueOf(intVal)); + estimatedParquetSize += 4; + break; + case INT64: + long longVal = + getInt64Value(value, primitiveType, defaultTimezone, path, insertRowsCurrIndex); + value = longVal; + stats.addIntValue(BigInteger.valueOf(longVal)); + estimatedParquetSize += 8; + break; + case FLOAT: + float floatVal = + (float) DataValidationUtil.validateAndParseReal(path, value, insertRowsCurrIndex); + value = floatVal; + stats.addRealValue((double) floatVal); + estimatedParquetSize += 4; + break; + case DOUBLE: + double doubleVal = + DataValidationUtil.validateAndParseReal(path, value, insertRowsCurrIndex); + value = doubleVal; + stats.addRealValue(doubleVal); + estimatedParquetSize += 8; + break; + case BINARY: + byte[] byteVal = getBinaryValue(value, primitiveType, stats, path, insertRowsCurrIndex); + value = byteVal; + estimatedParquetSize += + ParquetBufferValue.BYTE_ARRAY_LENGTH_ENCODING_BYTE_LEN + byteVal.length; + break; + case FIXED_LEN_BYTE_ARRAY: + byte[] fixedLenByteArrayVal = + getFixedLenByteArrayValue(value, primitiveType, stats, path, insertRowsCurrIndex); + value = fixedLenByteArrayVal; + estimatedParquetSize += + ParquetBufferValue.BYTE_ARRAY_LENGTH_ENCODING_BYTE_LEN + + fixedLenByteArrayVal.length; + break; + default: + throw new SFException( + ErrorCode.UNKNOWN_DATA_TYPE, + type.getLogicalTypeAnnotation(), + primitiveType.getPrimitiveTypeName()); + } + } else { + return getGroupValue( + value, + type.asGroupType(), + statsMap, + defaultTimezone, + insertRowsCurrIndex, + path, + isDescendantsOfRepeatingGroup); } } if (value == null) { if (type.isRepetition(Repetition.REQUIRED)) { throw new SFException( - ErrorCode.INVALID_FORMAT_ROW, type.getName(), "Passed null to non nullable field"); + ErrorCode.INVALID_FORMAT_ROW, path, "Passed null to non nullable field"); + } + if (type.isPrimitive()) { + statsMap.get(path).incCurrentNullCount(); } - stats.incCurrentNullCount(); } return new ParquetBufferValue(value, estimatedParquetSize); @@ -126,21 +177,21 @@ static ParquetBufferValue parseColumnValueToParquet( * * @param value column value provided by user in a row * @param type Parquet column type + * @param path column path, used for logging * @param insertRowsCurrIndex Used for logging the row of index given in insertRows API * @return parsed int32 value */ private static int getInt32Value( - Object value, PrimitiveType type, final long insertRowsCurrIndex) { + Object value, PrimitiveType type, String path, final long insertRowsCurrIndex) { LogicalTypeAnnotation logicalTypeAnnotation = type.getLogicalTypeAnnotation(); if (logicalTypeAnnotation == null) { - return DataValidationUtil.validateAndParseIcebergInt( - type.getName(), value, 
insertRowsCurrIndex); + return DataValidationUtil.validateAndParseIcebergInt(path, value, insertRowsCurrIndex); } if (logicalTypeAnnotation instanceof DecimalLogicalTypeAnnotation) { - return getDecimalValue(value, type, insertRowsCurrIndex).unscaledValue().intValue(); + return getDecimalValue(value, type, path, insertRowsCurrIndex).unscaledValue().intValue(); } if (logicalTypeAnnotation instanceof DateLogicalTypeAnnotation) { - return DataValidationUtil.validateAndParseDate(type.getName(), value, insertRowsCurrIndex); + return DataValidationUtil.validateAndParseDate(path, value, insertRowsCurrIndex); } throw new SFException( ErrorCode.UNKNOWN_DATA_TYPE, logicalTypeAnnotation, type.getPrimitiveTypeName()); @@ -151,22 +202,26 @@ private static int getInt32Value( * * @param value column value provided by user in a row * @param type Parquet column type + * @param path column path, used for logging * @param insertRowsCurrIndex Used for logging the row of index given in insertRows API * @return parsed int64 value */ private static long getInt64Value( - Object value, PrimitiveType type, ZoneId defaultTimezone, final long insertRowsCurrIndex) { + Object value, + PrimitiveType type, + ZoneId defaultTimezone, + String path, + final long insertRowsCurrIndex) { LogicalTypeAnnotation logicalTypeAnnotation = type.getLogicalTypeAnnotation(); if (logicalTypeAnnotation == null) { - return DataValidationUtil.validateAndParseIcebergLong( - type.getName(), value, insertRowsCurrIndex); + return DataValidationUtil.validateAndParseIcebergLong(path, value, insertRowsCurrIndex); } if (logicalTypeAnnotation instanceof DecimalLogicalTypeAnnotation) { - return getDecimalValue(value, type, insertRowsCurrIndex).unscaledValue().longValue(); + return getDecimalValue(value, type, path, insertRowsCurrIndex).unscaledValue().longValue(); } if (logicalTypeAnnotation instanceof TimeLogicalTypeAnnotation) { return DataValidationUtil.validateAndParseTime( - type.getName(), + path, value, timeUnitToScale(((TimeLogicalTypeAnnotation) logicalTypeAnnotation).getUnit()), insertRowsCurrIndex) @@ -197,29 +252,28 @@ private static long getInt64Value( * @param value value to parse * @param type Parquet column type * @param stats column stats to update + * @param path column path, used for logging * @param insertRowsCurrIndex Used for logging the row of index given in insertRows API * @return string representation */ private static byte[] getBinaryValue( - Object value, PrimitiveType type, RowBufferStats stats, final long insertRowsCurrIndex) { + Object value, + PrimitiveType type, + RowBufferStats stats, + String path, + final long insertRowsCurrIndex) { LogicalTypeAnnotation logicalTypeAnnotation = type.getLogicalTypeAnnotation(); if (logicalTypeAnnotation == null) { byte[] bytes = DataValidationUtil.validateAndParseBinary( - type.getName(), - value, - Optional.of(Constants.BINARY_COLUMN_MAX_SIZE), - insertRowsCurrIndex); + path, value, Optional.of(Constants.BINARY_COLUMN_MAX_SIZE), insertRowsCurrIndex); stats.addBinaryValue(bytes); return bytes; } if (logicalTypeAnnotation instanceof StringLogicalTypeAnnotation) { String string = DataValidationUtil.validateAndParseString( - type.getName(), - value, - Optional.of(Constants.VARCHAR_COLUMN_MAX_SIZE), - insertRowsCurrIndex); + path, value, Optional.of(Constants.VARCHAR_COLUMN_MAX_SIZE), insertRowsCurrIndex); stats.addStrValue(string); return string.getBytes(StandardCharsets.UTF_8); } @@ -233,22 +287,28 @@ private static byte[] getBinaryValue( * @param value value to parse * @param type 
Parquet column type * @param stats column stats to update + * @param path column path, used for logging * @param insertRowsCurrIndex Used for logging the row of index given in insertRows API * @return string representation */ private static byte[] getFixedLenByteArrayValue( - Object value, PrimitiveType type, RowBufferStats stats, final long insertRowsCurrIndex) { + Object value, + PrimitiveType type, + RowBufferStats stats, + String path, + final long insertRowsCurrIndex) { LogicalTypeAnnotation logicalTypeAnnotation = type.getLogicalTypeAnnotation(); int length = type.getTypeLength(); byte[] bytes = null; if (logicalTypeAnnotation == null) { bytes = DataValidationUtil.validateAndParseBinary( - type.getName(), value, Optional.of(length), insertRowsCurrIndex); + path, value, Optional.of(length), insertRowsCurrIndex); stats.addBinaryValue(bytes); } if (logicalTypeAnnotation instanceof DecimalLogicalTypeAnnotation) { - BigInteger bigIntegerVal = getDecimalValue(value, type, insertRowsCurrIndex).unscaledValue(); + BigInteger bigIntegerVal = + getDecimalValue(value, type, path, insertRowsCurrIndex).unscaledValue(); stats.addIntValue(bigIntegerVal); bytes = bigIntegerVal.toByteArray(); if (bytes.length < length) { @@ -271,15 +331,16 @@ private static byte[] getFixedLenByteArrayValue( * * @param value value to parse * @param type Parquet column type + * @param path column path, used for logging * @param insertRowsCurrIndex Used for logging the row of index given in insertRows API * @return BigDecimal representation */ private static BigDecimal getDecimalValue( - Object value, PrimitiveType type, final long insertRowsCurrIndex) { + Object value, PrimitiveType type, String path, final long insertRowsCurrIndex) { int scale = ((DecimalLogicalTypeAnnotation) type.getLogicalTypeAnnotation()).getScale(); int precision = ((DecimalLogicalTypeAnnotation) type.getLogicalTypeAnnotation()).getPrecision(); BigDecimal bigDecimalValue = - DataValidationUtil.validateAndParseBigDecimal(type.getName(), value, insertRowsCurrIndex); + DataValidationUtil.validateAndParseBigDecimal(path, value, insertRowsCurrIndex); bigDecimalValue = bigDecimalValue.setScale(scale, RoundingMode.HALF_UP); DataValidationUtil.checkValueInRange(bigDecimalValue, scale, precision, insertRowsCurrIndex); return bigDecimalValue; @@ -298,4 +359,169 @@ private static int timeUnitToScale(LogicalTypeAnnotation.TimeUnit timeUnit) { ErrorCode.INTERNAL_ERROR, String.format("Unknown time unit: %s", timeUnit)); } } + + /** + * Parses a group value based on Parquet group logical type. 
+ * + * @param value value to parse + * @param type Parquet column type + * @param statsMap column stats map to update + * @param defaultTimezone default timezone to use for timestamp parsing + * @param insertRowsCurrIndex Used for logging the row of index given in insertRows API + * @param path dot path of the column + * @param isDescendantsOfRepeatingGroup true if the column is a descendant of a repeating group, + * @return list of parsed values + */ + private static ParquetBufferValue getGroupValue( + Object value, + GroupType type, + Map statsMap, + ZoneId defaultTimezone, + final long insertRowsCurrIndex, + String path, + boolean isDescendantsOfRepeatingGroup) { + LogicalTypeAnnotation logicalTypeAnnotation = type.getLogicalTypeAnnotation(); + if (logicalTypeAnnotation == null) { + return getStructValue( + value, + type, + statsMap, + defaultTimezone, + insertRowsCurrIndex, + path, + isDescendantsOfRepeatingGroup); + } + if (logicalTypeAnnotation instanceof LogicalTypeAnnotation.ListLogicalTypeAnnotation) { + return get3LevelListValue(value, type, statsMap, defaultTimezone, insertRowsCurrIndex, path); + } + if (logicalTypeAnnotation instanceof LogicalTypeAnnotation.MapLogicalTypeAnnotation) { + return get3LevelMapValue(value, type, statsMap, defaultTimezone, insertRowsCurrIndex, path); + } + throw new SFException( + ErrorCode.UNKNOWN_DATA_TYPE, logicalTypeAnnotation, type.getClass().getSimpleName()); + } + + /** + * Parses a struct value based on Parquet group logical type. The parsed value is a list of + * values, where each element represents a field in the group. For example, an input {@code + * {"field1": 1, "field2": 2}} will be parsed as {@code [1, 2]}. + */ + private static ParquetBufferValue getStructValue( + Object value, + GroupType type, + Map statsMap, + ZoneId defaultTimezone, + final long insertRowsCurrIndex, + String path, + boolean isDescendantsOfRepeatingGroup) { + Map structVal = + DataValidationUtil.validateAndParseIcebergStruct(path, value, insertRowsCurrIndex); + Set extraFields = new HashSet<>(structVal.keySet()); + List listVal = new ArrayList<>(type.getFieldCount()); + float estimatedParquetSize = 0f; + for (int i = 0; i < type.getFieldCount(); i++) { + ParquetBufferValue parsedValue = + parseColumnValueToParquet( + structVal.getOrDefault(type.getFieldName(i), null), + type.getType(i), + statsMap, + defaultTimezone, + insertRowsCurrIndex, + path, + isDescendantsOfRepeatingGroup); + extraFields.remove(type.getFieldName(i)); + listVal.add(parsedValue.getValue()); + estimatedParquetSize += parsedValue.getSize(); + } + if (!extraFields.isEmpty()) { + String extraFieldsStr = + extraFields.stream().map(f -> concatDotPath(path, f)).collect(Collectors.joining(", ")); + throw new SFException( + ErrorCode.INVALID_FORMAT_ROW, + "Extra fields: " + extraFieldsStr, + String.format( + "Fields not present in the struct shouldn't be specified, rowIndex:%d", + insertRowsCurrIndex)); + } + return new ParquetBufferValue(listVal, estimatedParquetSize); + } + + /** + * Parses an iterable value based on Parquet 3-level list logical type. Please check Parquet + * Logical Types#Lists for more details. The parsed value is a list of lists, where each inner + * list represents a list of elements in the group. For example, an input {@code [1, 2, 3, 4]} + * will be parsed as {@code [[1], [2], [3], [4]]}. 
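A rough end-to-end sketch of these shapes (hypothetical column name "L" and hand-built stats; in Iceberg mode ParquetRowBuffer.setupSchema registers one RowBufferStats per leaf path, keyed by the dot path that Utils.concatDotPath produces):

    // Sketch only; uses the package-private parser the way the unit tests do.
    // Imports assumed: org.apache.parquet.schema.Types, org.apache.parquet.schema.Type,
    // org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName, java.util.*, java.time.ZoneOffset.
    Type listType =
        Types.optionalList()
            .element(Types.optional(PrimitiveTypeName.INT32).named("element"))
            .named("L");
    Map<String, RowBufferStats> statsMap = new HashMap<>();
    // The single leaf path is "L.list.element"; the ordinal (1) and fieldId (3) are illustrative.
    statsMap.put("L.list.element", new RowBufferStats("L.list.element", null, 1, 3));

    ParquetBufferValue pv =
        IcebergParquetValueParser.parseColumnValueToParquet(
            Arrays.asList(1, 2, 3), listType, statsMap, ZoneOffset.UTC, 0 /* row index */);
    // Per the javadocs above, pv.getValue() should come back as [[1], [2], [3]];
    // a struct {"f1": 1, "f2": 2} would parse to [1, 2] and a map {"a": 1} to [["a", 1]].
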
+ */ + private static ParquetBufferValue get3LevelListValue( + Object value, + GroupType type, + Map statsMap, + ZoneId defaultTimezone, + final long insertRowsCurrIndex, + String path) { + Iterable iterableVal = + DataValidationUtil.validateAndParseIcebergList(path, value, insertRowsCurrIndex); + List listVal = new ArrayList<>(); + float estimatedParquetSize = 0; + String listGroupPath = concatDotPath(path, THREE_LEVEL_LIST_GROUP_NAME); + for (Object val : iterableVal) { + ParquetBufferValue parsedValue = + parseColumnValueToParquet( + val, + type.getType(0).asGroupType().getType(0), + statsMap, + defaultTimezone, + insertRowsCurrIndex, + listGroupPath, + true); + listVal.add(Collections.singletonList(parsedValue.getValue())); + estimatedParquetSize += parsedValue.getSize(); + } + return new ParquetBufferValue(listVal, estimatedParquetSize); + } + + /** + * Parses a map value based on Parquet 3-level map logical type. Please check Parquet + * Logical Types#Maps for more details. The parsed value is a list of lists, where each inner + * list represents a key-value pair in the group. For example, an input {@code {"a": 1, "b": 2}} + * will be parsed as {@code [["a", 1], ["b", 2]]}. + */ + private static ParquetBufferValue get3LevelMapValue( + Object value, + GroupType type, + Map statsMap, + ZoneId defaultTimezone, + final long insertRowsCurrIndex, + String path) { + Map mapVal = + DataValidationUtil.validateAndParseIcebergMap(path, value, insertRowsCurrIndex); + List listVal = new ArrayList<>(); + float estimatedParquetSize = 0; + String mapGroupPath = concatDotPath(path, THREE_LEVEL_MAP_GROUP_NAME); + for (Map.Entry entry : mapVal.entrySet()) { + ParquetBufferValue parsedKey = + parseColumnValueToParquet( + entry.getKey(), + type.getType(0).asGroupType().getType(0), + statsMap, + defaultTimezone, + insertRowsCurrIndex, + mapGroupPath, + true); + ParquetBufferValue parsedValue = + parseColumnValueToParquet( + entry.getValue(), + type.getType(0).asGroupType().getType(1), + statsMap, + defaultTimezone, + insertRowsCurrIndex, + mapGroupPath, + true); + listVal.add(Arrays.asList(parsedKey.getValue(), parsedValue.getValue())); + estimatedParquetSize += parsedKey.getSize() + parsedValue.getSize(); + } + return new ParquetBufferValue(listVal, estimatedParquetSize); + } } diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/ParquetBufferValue.java b/src/main/java/net/snowflake/ingest/streaming/internal/ParquetBufferValue.java index f89de0aa7..48987bd74 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/ParquetBufferValue.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/ParquetBufferValue.java @@ -10,11 +10,12 @@ class ParquetBufferValue { public static final float BIT_ENCODING_BYTE_LEN = 1.0f / 8; /** - * On average parquet needs 2 bytes / 8 values for the RLE+bitpack encoded definition level. + * On average parquet needs 2 bytes / 8 values for the RLE+bitpack encoded definition and + * repetition level. * *
    - * There are two cases how definition level (0 for null values, 1 for non-null values) is - * encoded: + * There are two cases how definition and repetition level (0 for null values, 1 for non-null + * values) is encoded: *
  • If there are at least 8 repeated values in a row, they are run-length encoded (length + * value itself). E.g. 11111111 -> 8 1 *
  • If there are less than 8 repeated values, they are written in group as part of a @@ -31,6 +32,8 @@ class ParquetBufferValue { */ public static final float DEFINITION_LEVEL_ENCODING_BYTE_LEN = 2.0f / 8; + public static final float REPETITION_LEVEL_ENCODING_BYTE_LEN = 2.0f / 8; + // Parquet stores length in 4 bytes before the actual data bytes public static final int BYTE_ARRAY_LENGTH_ENCODING_BYTE_LEN = 4; private final Object value; diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/ParquetFlusher.java b/src/main/java/net/snowflake/ingest/streaming/internal/ParquetFlusher.java index 8865a88c3..f8f60cecb 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/ParquetFlusher.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/ParquetFlusher.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022 Snowflake Computing Inc. All rights reserved. + * Copyright (c) 2022-2024 Snowflake Computing Inc. All rights reserved. */ package net.snowflake.ingest.streaming.internal; diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/ParquetRowBuffer.java b/src/main/java/net/snowflake/ingest/streaming/internal/ParquetRowBuffer.java index b8054ad9f..38f8a3bce 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/ParquetRowBuffer.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/ParquetRowBuffer.java @@ -4,6 +4,8 @@ package net.snowflake.ingest.streaming.internal; +import static net.snowflake.ingest.utils.Utils.concatDotPath; + import java.math.BigDecimal; import java.math.BigInteger; import java.nio.charset.StandardCharsets; @@ -25,6 +27,7 @@ import net.snowflake.ingest.utils.Constants; import net.snowflake.ingest.utils.ErrorCode; import net.snowflake.ingest.utils.SFException; +import org.apache.parquet.column.ColumnDescriptor; import org.apache.parquet.schema.MessageType; import org.apache.parquet.schema.Type; @@ -33,7 +36,6 @@ * converted to Parquet format for faster processing */ public class ParquetRowBuffer extends AbstractRowBuffer { - private static final String PARQUET_MESSAGE_TYPE_NAME = "bdec"; private final Map fieldIndex; @@ -71,6 +73,11 @@ public class ParquetRowBuffer extends AbstractRowBuffer { this.tempData = new ArrayList<>(); } + /** + * Set up the parquet schema. + * + * @param columns top level columns list of column metadata + */ @Override public void setupSchema(List columns) { fieldIndex.clear(); @@ -79,7 +86,9 @@ public void setupSchema(List columns) { metadata.put(Constants.SDK_VERSION_KEY, RequestBuilder.DEFAULT_VERSION); List parquetTypes = new ArrayList<>(); int id = 1; + for (ColumnMetadata column : columns) { + /* Set up fields using top level column information */ validateColumnCollation(column); ParquetTypeInfo typeInfo = ParquetTypeGenerator.generateColumnParquetTypeInfo(column, id); parquetTypes.add(typeInfo.getParquetType()); @@ -91,20 +100,105 @@ public void setupSchema(List columns) { if (!column.getNullable()) { addNonNullableFieldName(column.getInternalName()); } - this.statsMap.put( - column.getInternalName(), - new RowBufferStats(column.getName(), column.getCollation(), column.getOrdinal())); - - if (onErrorOption == OpenChannelRequest.OnErrorOption.ABORT - || onErrorOption == OpenChannelRequest.OnErrorOption.SKIP_BATCH) { - this.tempStatsMap.put( + if (!clientBufferParameters.getIsIcebergMode()) { + /* Streaming to FDN table doesn't support sub-columns, set up the stats here. 
*/ + this.statsMap.put( column.getInternalName(), - new RowBufferStats(column.getName(), column.getCollation(), column.getOrdinal())); + new RowBufferStats( + column.getName(), column.getCollation(), column.getOrdinal(), null /* fieldId */)); + + if (onErrorOption == OpenChannelRequest.OnErrorOption.ABORT + || onErrorOption == OpenChannelRequest.OnErrorOption.SKIP_BATCH) { + /* + * tempStatsMap is used to store stats for the current batch, + * create a separate stats in case current batch has invalid rows which ruins the original stats. + */ + this.tempStatsMap.put( + column.getInternalName(), + new RowBufferStats( + column.getName(), + column.getCollation(), + column.getOrdinal(), + null /* fieldId */)); + } } id++; } - schema = new MessageType(PARQUET_MESSAGE_TYPE_NAME, parquetTypes); + schema = new MessageType(clientBufferParameters.getParquetMessageTypeName(), parquetTypes); + + /* + * Iceberg mode requires stats for all primitive columns and sub-columns, set them up here. + * + * There are two values that are used to identify a column in the stats map: + * 1. ordinal - The ordinal is the index of the top level column in the schema. + * 2. fieldId - The fieldId is the id of all sub-columns in the schema. + * It's indexed by the level and order of the column in the schema. + * Note that the fieldId is set to 0 for non-structured columns. + * + * For example, consider the following schema: + * F1 INT, + * F2 STRUCT(F21 STRUCT(F211 INT), F22 INT), + * F3 INT, + * F4 MAP(INT, MAP(INT, INT)), + * F5 INT, + * F6 ARRAY(INT), + * F7 INT + * + * The ordinal and fieldId will look like this: + * F1: ordinal=1, fieldId=1 + * F2: ordinal=2, fieldId=2 + * F2.F21: ordinal=2, fieldId=8 + * F2.F21.F211: ordinal=2, fieldId=13 + * F2.F22: ordinal=2, fieldId=9 + * F3: ordinal=3, fieldId=3 + * F4: ordinal=4, fieldId=4 + * F4.key: ordinal=4, fieldId=10 + * F4.value: ordinal=4, fieldId=11 + * F4.value.key: ordinal=4, fieldId=14 + * F4.value.value: ordinal=4, fieldId=15 + * F5: ordinal=5, fieldId=5 + * F6: ordinal=6, fieldId=6 + * F6.element: ordinal=6, fieldId=12 + * F7: ordinal=7, fieldId=7 + * + * The stats map will contain the following entries: + * F1: ordinal=1, fieldId=0 + * F2: ordinal=2, fieldId=0 + * F2.F21.F211: ordinal=2, fieldId=13 + * F2.F22: ordinal=2, fieldId=9 + * F3: ordinal=3, fieldId=0 + * F4.key: ordinal=4, fieldId=10 + * F4.value.key: ordinal=4, fieldId=14 + * F4.value.value: ordinal=4, fieldId=15 + * F5: ordinal=5, fieldId=0 + * F6.element: ordinal=6, fieldId=12 + * F7: ordinal=7, fieldId=0 + */ + if (clientBufferParameters.getIsIcebergMode()) { + for (ColumnDescriptor columnDescriptor : schema.getColumns()) { + String columnPath = concatDotPath(columnDescriptor.getPath()); + + /* set fieldId to 0 for non-structured columns */ + int fieldId = + columnDescriptor.getPath().length == 1 + ? 0 + : columnDescriptor.getPrimitiveType().getId().intValue(); + int ordinal = schema.getType(columnDescriptor.getPath()[0]).getId().intValue(); + + this.statsMap.put( + columnPath, + new RowBufferStats(columnPath, null /* collationDefinitionString */, ordinal, fieldId)); + + if (onErrorOption == OpenChannelRequest.OnErrorOption.ABORT + || onErrorOption == OpenChannelRequest.OnErrorOption.SKIP_BATCH) { + this.tempStatsMap.put( + columnPath, + new RowBufferStats( + columnPath, null /* collationDefinitionString */, ordinal, fieldId)); + } + } + } tempData.clear(); data.clear(); } @@ -161,6 +255,7 @@ private float addRow( // Create new empty stats just for the current row. 
Map forkedStatsMap = new HashMap<>(); + statsMap.forEach((columnName, stats) -> forkedStatsMap.put(columnName, stats.forkEmpty())); for (Map.Entry entry : row.entrySet()) { String key = entry.getKey(); @@ -168,18 +263,16 @@ private float addRow( String columnName = LiteralQuoteUtils.unquoteColumnName(key); ParquetColumn parquetColumn = fieldIndex.get(columnName); int colIndex = parquetColumn.index; - RowBufferStats forkedStats = statsMap.get(columnName).forkEmpty(); - forkedStatsMap.put(columnName, forkedStats); ColumnMetadata column = parquetColumn.columnMetadata; ParquetBufferValue valueWithSize = (clientBufferParameters.getIsIcebergMode() ? IcebergParquetValueParser.parseColumnValueToParquet( - value, parquetColumn.type, forkedStats, defaultTimezone, insertRowsCurrIndex) + value, parquetColumn.type, forkedStatsMap, defaultTimezone, insertRowsCurrIndex) : SnowflakeParquetValueParser.parseColumnValueToParquet( value, column, parquetColumn.type.asPrimitiveType().getPrimitiveTypeName(), - forkedStats, + forkedStatsMap.get(columnName), defaultTimezone, insertRowsCurrIndex, clientBufferParameters.isEnableNewJsonParsingLogic())); diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/RowBufferStats.java b/src/main/java/net/snowflake/ingest/streaming/internal/RowBufferStats.java index 395123f1f..2fae695f0 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/RowBufferStats.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/RowBufferStats.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Snowflake Computing Inc. All rights reserved. + * Copyright (c) 2021-2024 Snowflake Computing Inc. All rights reserved. */ package net.snowflake.ingest.streaming.internal; @@ -15,7 +15,16 @@ /** Keeps track of the active EP stats, used to generate a file EP info */ class RowBufferStats { + /* Ordinal of a column, one-based. */ private final int ordinal; + + /* + * Field id of a column. + * For FDN columns, it's always null. + * For Iceberg columns, set to nonzero Iceberg field id if it's a sub-column, otherwise zero. 
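For illustration, the two ways the extended constructor ends up being used (column names and ids below are hypothetical; the FDN path passes null, the Iceberg path passes 0 for a top-level leaf or the Parquet field id for a sub-column):

    // FDN (non-Iceberg) column: field ids are not tracked
    RowBufferStats fdnStats =
        new RowBufferStats("MY_COL", null /* collation */, 1 /* ordinal */, null /* fieldId */);

    // Iceberg sub-column, e.g. the element leaf of a list column at ordinal 2
    RowBufferStats icebergStats =
        new RowBufferStats("V.list.element", null /* collation */, 2 /* ordinal */, 5 /* fieldId */);

Because FileColumnProperties marks fieldId with JsonInclude.Include.NON_NULL, the FDN stats above serialize exactly as before, while the Iceberg stats add a "fieldId" entry to the EP metadata.
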
+ */ + private final Integer fieldId; + private byte[] currentMinStrValue; private byte[] currentMaxStrValue; private BigInteger currentMinIntValue; @@ -30,15 +39,17 @@ class RowBufferStats { private final String columnDisplayName; /** Creates empty stats */ - RowBufferStats(String columnDisplayName, String collationDefinitionString, int ordinal) { + RowBufferStats( + String columnDisplayName, String collationDefinitionString, int ordinal, Integer fieldId) { this.columnDisplayName = columnDisplayName; this.collationDefinitionString = collationDefinitionString; this.ordinal = ordinal; + this.fieldId = fieldId; reset(); } RowBufferStats(String columnDisplayName) { - this(columnDisplayName, null, -1); + this(columnDisplayName, null, -1, null); } void reset() { @@ -55,7 +66,10 @@ void reset() { /** Create new statistics for the same column, with all calculated values set to empty */ RowBufferStats forkEmpty() { return new RowBufferStats( - this.getColumnDisplayName(), this.getCollationDefinitionString(), this.getOrdinal()); + this.getColumnDisplayName(), + this.getCollationDefinitionString(), + this.getOrdinal(), + this.getFieldId()); } // TODO performance test this vs in place update @@ -70,7 +84,10 @@ static RowBufferStats getCombinedStats(RowBufferStats left, RowBufferStats right } RowBufferStats combined = new RowBufferStats( - left.columnDisplayName, left.getCollationDefinitionString(), left.getOrdinal()); + left.columnDisplayName, + left.getCollationDefinitionString(), + left.getOrdinal(), + left.getFieldId()); if (left.currentMinIntValue != null) { combined.addIntValue(left.currentMinIntValue); @@ -217,6 +234,10 @@ public int getOrdinal() { return ordinal; } + Integer getFieldId() { + return fieldId; + } + /** * Compares two byte arrays lexicographically. If the two arrays share a common prefix then the * lexicographic comparison is the result of comparing two elements, as if by Byte.compare(byte, diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestClientInternal.java b/src/main/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestClientInternal.java index 793f8bc9a..843b8975a 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestClientInternal.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestClientInternal.java @@ -438,6 +438,7 @@ public void dropChannel(DropChannelRequest request) { request.getSchemaName(), request.getTableName(), request.getChannelName(), + this.isIcebergMode, request instanceof DropChannelVersionRequest ? 
((DropChannelVersionRequest) request).getClientSequencer() : null); diff --git a/src/main/java/net/snowflake/ingest/utils/IcebergDataTypeParser.java b/src/main/java/net/snowflake/ingest/utils/IcebergDataTypeParser.java index 95e484b70..abb03cdef 100644 --- a/src/main/java/net/snowflake/ingest/utils/IcebergDataTypeParser.java +++ b/src/main/java/net/snowflake/ingest/utils/IcebergDataTypeParser.java @@ -65,11 +65,24 @@ public static org.apache.parquet.schema.Type parseIcebergDataTypeStringToParquet int id, String name) { Type icebergType = deserializeIcebergType(icebergDataType); - if (!icebergType.isPrimitiveType()) { - throw new IllegalArgumentException( - String.format("Snowflake supports only primitive Iceberg types, got '%s'", icebergType)); + if (icebergType.isPrimitiveType()) { + return typeToMessageType.primitive(icebergType.asPrimitiveType(), repetition, id, name); + } else { + switch (icebergType.typeId()) { + case LIST: + return typeToMessageType.list(icebergType.asListType(), repetition, id, name); + case MAP: + return typeToMessageType.map(icebergType.asMapType(), repetition, id, name); + case STRUCT: + return typeToMessageType.struct(icebergType.asStructType(), repetition, id, name); + default: + throw new SFException( + ErrorCode.INTERNAL_ERROR, + String.format( + "Cannot convert Iceberg column to parquet type, name=%s, dataType=%s", + name, icebergDataType)); + } } - return typeToMessageType.primitive(icebergType.asPrimitiveType(), repetition, id, name); } /** diff --git a/src/main/java/net/snowflake/ingest/utils/Utils.java b/src/main/java/net/snowflake/ingest/utils/Utils.java index 5220625da..95d941036 100644 --- a/src/main/java/net/snowflake/ingest/utils/Utils.java +++ b/src/main/java/net/snowflake/ingest/utils/Utils.java @@ -411,4 +411,23 @@ public static String getFullyQualifiedChannelName( String dbName, String schemaName, String tableName, String channelName) { return String.format("%s.%s.%s.%s", dbName, schemaName, tableName, channelName); } + + /* + * Get concat dot path, check if any path is empty or null + * + * @param path the path + */ + public static String concatDotPath(String... 
path) { + StringBuilder sb = new StringBuilder(); + for (String p : path) { + if (isNullOrEmpty(p)) { + throw new IllegalArgumentException("Path cannot be null or empty"); + } + if (sb.length() > 0) { + sb.append("."); + } + sb.append(p); + } + return sb.toString(); + } } diff --git a/src/main/java/org/apache/parquet/hadoop/BdecParquetWriter.java b/src/main/java/org/apache/parquet/hadoop/BdecParquetWriter.java index 45c65a4ea..c73269748 100644 --- a/src/main/java/org/apache/parquet/hadoop/BdecParquetWriter.java +++ b/src/main/java/org/apache/parquet/hadoop/BdecParquetWriter.java @@ -14,7 +14,6 @@ import net.snowflake.ingest.utils.ErrorCode; import net.snowflake.ingest.utils.SFException; import org.apache.hadoop.conf.Configuration; -import org.apache.parquet.column.ColumnDescriptor; import org.apache.parquet.column.ParquetProperties; import org.apache.parquet.column.values.factory.DefaultV1ValuesWriterFactory; import org.apache.parquet.crypto.FileEncryptionProperties; @@ -283,7 +282,8 @@ public void prepareForWrite(RecordConsumer recordConsumer) { @Override public void write(List values) { - List cols = schema.getColumns(); + List cols = + schema.getFields(); /* getFields() returns top level columns in the schema */ if (values.size() != cols.size()) { throw new ParquetEncodingException( "Invalid input data in channel '" @@ -302,7 +302,7 @@ public void write(List values) { recordConsumer.endMessage(); } - private void writeValues(List values, GroupType type) { + private void writeValues(List values, GroupType type) { List cols = type.getFields(); for (int i = 0; i < cols.size(); ++i) { Object val = values.get(i); @@ -344,7 +344,31 @@ private void writeValues(List values, GroupType type) { "Unsupported column type: " + cols.get(i).asPrimitiveType()); } } else { - throw new ParquetEncodingException("Unsupported column type: " + cols.get(i)); + if (cols.get(i).isRepetition(Type.Repetition.REPEATED)) { + /* List and Map */ + for (Object o : values) { + recordConsumer.startGroup(); + if (o != null) { + if (o instanceof List) { + writeValues((List) o, cols.get(i).asGroupType()); + } else { + throw new ParquetEncodingException( + String.format("Field %s should be a 3 level list or map", fieldName)); + } + } + recordConsumer.endGroup(); + } + } else { + /* Struct */ + recordConsumer.startGroup(); + if (val instanceof List) { + writeValues((List) val, cols.get(i).asGroupType()); + } else { + throw new ParquetEncodingException( + String.format("Field %s should be a 2 level struct", fieldName)); + } + recordConsumer.endGroup(); + } } recordConsumer.endField(fieldName, i); } diff --git a/src/test/java/net/snowflake/ingest/streaming/internal/BlobBuilderTest.java b/src/test/java/net/snowflake/ingest/streaming/internal/BlobBuilderTest.java index 708fd81bd..8a6190a71 100644 --- a/src/test/java/net/snowflake/ingest/streaming/internal/BlobBuilderTest.java +++ b/src/test/java/net/snowflake/ingest/streaming/internal/BlobBuilderTest.java @@ -105,7 +105,9 @@ private List> createChannelDataPerTable(int metada channelData.setRowCount(metadataRowCount); channelData.setMinMaxInsertTimeInMs(new Pair<>(2L, 3L)); - channelData.getColumnEps().putIfAbsent(columnName, new RowBufferStats(columnName, null, 1)); + channelData + .getColumnEps() + .putIfAbsent(columnName, new RowBufferStats(columnName, null, 1, isIceberg ? 
0 : null)); channelData.setChannelContext( new ChannelFlushContext("channel1", "DB", "SCHEMA", "TABLE", 1L, "enc", 1L)); return Collections.singletonList(channelData); diff --git a/src/test/java/net/snowflake/ingest/streaming/internal/DataValidationUtilTest.java b/src/test/java/net/snowflake/ingest/streaming/internal/DataValidationUtilTest.java index 0e738a4b3..4d0c51596 100644 --- a/src/test/java/net/snowflake/ingest/streaming/internal/DataValidationUtilTest.java +++ b/src/test/java/net/snowflake/ingest/streaming/internal/DataValidationUtilTest.java @@ -16,7 +16,10 @@ import static net.snowflake.ingest.streaming.internal.DataValidationUtil.validateAndParseBoolean; import static net.snowflake.ingest.streaming.internal.DataValidationUtil.validateAndParseDate; import static net.snowflake.ingest.streaming.internal.DataValidationUtil.validateAndParseIcebergInt; +import static net.snowflake.ingest.streaming.internal.DataValidationUtil.validateAndParseIcebergList; import static net.snowflake.ingest.streaming.internal.DataValidationUtil.validateAndParseIcebergLong; +import static net.snowflake.ingest.streaming.internal.DataValidationUtil.validateAndParseIcebergMap; +import static net.snowflake.ingest.streaming.internal.DataValidationUtil.validateAndParseIcebergStruct; import static net.snowflake.ingest.streaming.internal.DataValidationUtil.validateAndParseObject; import static net.snowflake.ingest.streaming.internal.DataValidationUtil.validateAndParseObjectNew; import static net.snowflake.ingest.streaming.internal.DataValidationUtil.validateAndParseReal; @@ -50,6 +53,7 @@ import java.util.Collections; import java.util.Date; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.Optional; import java.util.TimeZone; @@ -1281,6 +1285,32 @@ public void testValidateAndParseIcebergLong() { () -> validateAndParseIcebergLong("COL", Double.NEGATIVE_INFINITY, 0)); } + @Test + public void testValidateAndParseIcebergStruct() throws JsonProcessingException { + Map validStruct = + objectMapper.readValue("{\"a\": 1, \"b\":[1, 2, 3], \"c\":{\"d\":3}}", Map.class); + assertEquals(validStruct, validateAndParseIcebergStruct("COL", validStruct, 0)); + expectError( + ErrorCode.INVALID_FORMAT_ROW, + () -> validateAndParseIcebergStruct("COL", Collections.singletonMap(1, new Object()), 0)); + } + + @Test + public void testValidateAndParseIcebergList() throws JsonProcessingException { + List validList = objectMapper.readValue("[1, [2, 3, 4], 5]", List.class); + assertEquals(validList, validateAndParseIcebergList("COL", validList, 0)); + + expectError(ErrorCode.INVALID_FORMAT_ROW, () -> validateAndParseIcebergList("COL", 1, 0)); + } + + @Test + public void testValidateAndParseIcebergMap() { + Map validMap = Collections.singletonMap(1, 1); + assertEquals(validMap, validateAndParseIcebergMap("COL", validMap, 0)); + + expectError(ErrorCode.INVALID_FORMAT_ROW, () -> validateAndParseIcebergMap("COL", 1, 0)); + } + /** * Tests that exception message are constructed correctly when ingesting forbidden Java type, as * well a value of an allowed type, but in invalid format diff --git a/src/test/java/net/snowflake/ingest/streaming/internal/FileColumnPropertiesTest.java b/src/test/java/net/snowflake/ingest/streaming/internal/FileColumnPropertiesTest.java index 63ba51abb..7b131b310 100644 --- a/src/test/java/net/snowflake/ingest/streaming/internal/FileColumnPropertiesTest.java +++ b/src/test/java/net/snowflake/ingest/streaming/internal/FileColumnPropertiesTest.java @@ -1,3 +1,7 @@ +/* + * Copyright 
(c) 2024 Snowflake Computing Inc. All rights reserved. + */ + package net.snowflake.ingest.streaming.internal; import org.junit.Assert; @@ -15,18 +19,19 @@ public static Object[] isIceberg() { @Test public void testFileColumnPropertiesConstructor() { // Test simple construction - RowBufferStats stats = new RowBufferStats("COL", null, 1); + RowBufferStats stats = new RowBufferStats("COL", null, 1, isIceberg ? 1 : null); stats.addStrValue("bcd"); stats.addStrValue("abcde"); FileColumnProperties props = new FileColumnProperties(stats, isIceberg); Assert.assertEquals(1, props.getColumnOrdinal()); + Assert.assertEquals(isIceberg ? 1 : null, props.getFieldId()); Assert.assertEquals("6162636465", props.getMinStrValue()); Assert.assertNull(props.getMinStrNonCollated()); Assert.assertEquals("626364", props.getMaxStrValue()); Assert.assertNull(props.getMaxStrNonCollated()); // Test that truncation is performed - stats = new RowBufferStats("COL", null, 1); + stats = new RowBufferStats("COL", null, 1, isIceberg ? 0 : null); stats.addStrValue("aßßßßßßßßßßßßßßßß"); Assert.assertEquals(33, stats.getCurrentMinStrValue().length); props = new FileColumnProperties(stats, isIceberg); diff --git a/src/test/java/net/snowflake/ingest/streaming/internal/IcebergParquetValueParserTest.java b/src/test/java/net/snowflake/ingest/streaming/internal/IcebergParquetValueParserTest.java index a0b4caa1c..007dc3e23 100644 --- a/src/test/java/net/snowflake/ingest/streaming/internal/IcebergParquetValueParserTest.java +++ b/src/test/java/net/snowflake/ingest/streaming/internal/IcebergParquetValueParserTest.java @@ -7,26 +7,46 @@ import static java.time.ZoneOffset.UTC; import static net.snowflake.ingest.streaming.internal.ParquetBufferValue.BYTE_ARRAY_LENGTH_ENCODING_BYTE_LEN; import static net.snowflake.ingest.streaming.internal.ParquetBufferValue.DEFINITION_LEVEL_ENCODING_BYTE_LEN; +import static net.snowflake.ingest.streaming.internal.ParquetBufferValue.REPETITION_LEVEL_ENCODING_BYTE_LEN; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; import java.math.BigDecimal; import java.math.BigInteger; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import net.snowflake.ingest.utils.Pair; +import net.snowflake.ingest.utils.SFException; import org.apache.parquet.schema.LogicalTypeAnnotation; import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName; import org.apache.parquet.schema.Type; import org.apache.parquet.schema.Type.Repetition; import org.apache.parquet.schema.Types; +import org.junit.Assert; import org.junit.Test; public class IcebergParquetValueParserTest { + static ObjectMapper objectMapper = new ObjectMapper(); + @Test public void parseValueBoolean() { Type type = Types.primitive(PrimitiveTypeName.BOOLEAN, Repetition.OPTIONAL).named("BOOLEAN_COL"); RowBufferStats rowBufferStats = new RowBufferStats("BOOLEAN_COL"); + Map rowBufferStatsMap = + new HashMap() { + { + put("BOOLEAN_COL", rowBufferStats); + } + }; ParquetBufferValue pv = - IcebergParquetValueParser.parseColumnValueToParquet(true, type, rowBufferStats, UTC, 0); + IcebergParquetValueParser.parseColumnValueToParquet(true, type, rowBufferStatsMap, UTC, 0); ParquetValueParserAssertionBuilder.newBuilder() .parquetBufferValue(pv) .rowBufferStats(rowBufferStats) @@ -42,9 +62,15 @@ public void parseValueInt() { Type type = Types.primitive(PrimitiveTypeName.INT32, 
Repetition.OPTIONAL).named("INT_COL"); RowBufferStats rowBufferStats = new RowBufferStats("INT_COL"); + Map rowBufferStatsMap = + new HashMap() { + { + put("INT_COL", rowBufferStats); + } + }; ParquetBufferValue pv = IcebergParquetValueParser.parseColumnValueToParquet( - Integer.MAX_VALUE, type, rowBufferStats, UTC, 0); + Integer.MAX_VALUE, type, rowBufferStatsMap, UTC, 0); ParquetValueParserAssertionBuilder.newBuilder() .parquetBufferValue(pv) .rowBufferStats(rowBufferStats) @@ -63,9 +89,15 @@ public void parseValueDecimalToInt() { .named("DECIMAL_COL"); RowBufferStats rowBufferStats = new RowBufferStats("DECIMAL_COL"); + Map rowBufferStatsMap = + new HashMap() { + { + put("DECIMAL_COL", rowBufferStats); + } + }; ParquetBufferValue pv = IcebergParquetValueParser.parseColumnValueToParquet( - new BigDecimal("12345.6789"), type, rowBufferStats, UTC, 0); + new BigDecimal("12345.6789"), type, rowBufferStatsMap, UTC, 0); ParquetValueParserAssertionBuilder.newBuilder() .parquetBufferValue(pv) .rowBufferStats(rowBufferStats) @@ -84,9 +116,15 @@ public void parseValueDateToInt() { .named("DATE_COL"); RowBufferStats rowBufferStats = new RowBufferStats("DATE_COL"); + Map rowBufferStatsMap = + new HashMap() { + { + put("DATE_COL", rowBufferStats); + } + }; ParquetBufferValue pv = IcebergParquetValueParser.parseColumnValueToParquet( - "2024-01-01", type, rowBufferStats, UTC, 0); + "2024-01-01", type, rowBufferStatsMap, UTC, 0); ParquetValueParserAssertionBuilder.newBuilder() .parquetBufferValue(pv) .rowBufferStats(rowBufferStats) @@ -102,9 +140,15 @@ public void parseValueLong() { Type type = Types.primitive(PrimitiveTypeName.INT64, Repetition.OPTIONAL).named("LONG_COL"); RowBufferStats rowBufferStats = new RowBufferStats("LONG_COL"); + Map rowBufferStatsMap = + new HashMap() { + { + put("LONG_COL", rowBufferStats); + } + }; ParquetBufferValue pv = IcebergParquetValueParser.parseColumnValueToParquet( - Long.MAX_VALUE, type, rowBufferStats, UTC, 0); + Long.MAX_VALUE, type, rowBufferStatsMap, UTC, 0); ParquetValueParserAssertionBuilder.newBuilder() .parquetBufferValue(pv) .rowBufferStats(rowBufferStats) @@ -123,9 +167,15 @@ public void parseValueDecimalToLong() { .named("DECIMAL_COL"); RowBufferStats rowBufferStats = new RowBufferStats("DECIMAL_COL"); + Map rowBufferStatsMap = + new HashMap() { + { + put("DECIMAL_COL", rowBufferStats); + } + }; ParquetBufferValue pv = IcebergParquetValueParser.parseColumnValueToParquet( - new BigDecimal("123456789.123456789"), type, rowBufferStats, UTC, 0); + new BigDecimal("123456789.123456789"), type, rowBufferStatsMap, UTC, 0); ParquetValueParserAssertionBuilder.newBuilder() .parquetBufferValue(pv) .rowBufferStats(rowBufferStats) @@ -144,9 +194,15 @@ public void parseValueTimeToLong() { .named("TIME_COL"); RowBufferStats rowBufferStats = new RowBufferStats("TIME_COL"); + Map rowBufferStatsMap = + new HashMap() { + { + put("TIME_COL", rowBufferStats); + } + }; ParquetBufferValue pv = IcebergParquetValueParser.parseColumnValueToParquet( - "12:34:56.789", type, rowBufferStats, UTC, 0); + "12:34:56.789", type, rowBufferStatsMap, UTC, 0); ParquetValueParserAssertionBuilder.newBuilder() .parquetBufferValue(pv) .rowBufferStats(rowBufferStats) @@ -165,9 +221,15 @@ public void parseValueTimestampToLong() { .named("TIMESTAMP_COL"); RowBufferStats rowBufferStats = new RowBufferStats("TIMESTAMP_COL"); + Map rowBufferStatsMap = + new HashMap() { + { + put("TIMESTAMP_COL", rowBufferStats); + } + }; ParquetBufferValue pv = IcebergParquetValueParser.parseColumnValueToParquet( - 
"2024-01-01T12:34:56.789+08:00", type, rowBufferStats, UTC, 0); + "2024-01-01T12:34:56.789+08:00", type, rowBufferStatsMap, UTC, 0); ParquetValueParserAssertionBuilder.newBuilder() .parquetBufferValue(pv) .rowBufferStats(rowBufferStats) @@ -186,9 +248,15 @@ public void parseValueTimestampTZToLong() { .named("TIMESTAMP_TZ_COL"); RowBufferStats rowBufferStats = new RowBufferStats("TIMESTAMP_TZ_COL"); + Map rowBufferStatsMap = + new HashMap() { + { + put("TIMESTAMP_TZ_COL", rowBufferStats); + } + }; ParquetBufferValue pv = IcebergParquetValueParser.parseColumnValueToParquet( - "2024-01-01T12:34:56.789+08:00", type, rowBufferStats, UTC, 0); + "2024-01-01T12:34:56.789+08:00", type, rowBufferStatsMap, UTC, 0); ParquetValueParserAssertionBuilder.newBuilder() .parquetBufferValue(pv) .rowBufferStats(rowBufferStats) @@ -204,9 +272,15 @@ public void parseValueFloat() { Type type = Types.primitive(PrimitiveTypeName.FLOAT, Repetition.OPTIONAL).named("FLOAT_COL"); RowBufferStats rowBufferStats = new RowBufferStats("FLOAT_COL"); + Map rowBufferStatsMap = + new HashMap() { + { + put("FLOAT_COL", rowBufferStats); + } + }; ParquetBufferValue pv = IcebergParquetValueParser.parseColumnValueToParquet( - Float.MAX_VALUE, type, rowBufferStats, UTC, 0); + Float.MAX_VALUE, type, rowBufferStatsMap, UTC, 0); ParquetValueParserAssertionBuilder.newBuilder() .parquetBufferValue(pv) .rowBufferStats(rowBufferStats) @@ -222,9 +296,15 @@ public void parseValueDouble() { Type type = Types.primitive(PrimitiveTypeName.DOUBLE, Repetition.OPTIONAL).named("DOUBLE_COL"); RowBufferStats rowBufferStats = new RowBufferStats("DOUBLE_COL"); + Map rowBufferStatsMap = + new HashMap() { + { + put("DOUBLE_COL", rowBufferStats); + } + }; ParquetBufferValue pv = IcebergParquetValueParser.parseColumnValueToParquet( - Double.MAX_VALUE, type, rowBufferStats, UTC, 0); + Double.MAX_VALUE, type, rowBufferStatsMap, UTC, 0); ParquetValueParserAssertionBuilder.newBuilder() .parquetBufferValue(pv) .rowBufferStats(rowBufferStats) @@ -240,9 +320,15 @@ public void parseValueBinary() { Type type = Types.primitive(PrimitiveTypeName.BINARY, Repetition.OPTIONAL).named("BINARY_COL"); RowBufferStats rowBufferStats = new RowBufferStats("BINARY_COL"); + Map rowBufferStatsMap = + new HashMap() { + { + put("BINARY_COL", rowBufferStats); + } + }; byte[] value = "snowflake_to_the_moon".getBytes(); ParquetBufferValue pv = - IcebergParquetValueParser.parseColumnValueToParquet(value, type, rowBufferStats, UTC, 0); + IcebergParquetValueParser.parseColumnValueToParquet(value, type, rowBufferStatsMap, UTC, 0); ParquetValueParserAssertionBuilder.newBuilder() .parquetBufferValue(pv) .rowBufferStats(rowBufferStats) @@ -262,9 +348,15 @@ public void parseValueStringToBinary() { .named("BINARY_COL"); RowBufferStats rowBufferStats = new RowBufferStats("BINARY_COL"); + Map rowBufferStatsMap = + new HashMap() { + { + put("BINARY_COL", rowBufferStats); + } + }; String value = "snowflake_to_the_moon"; ParquetBufferValue pv = - IcebergParquetValueParser.parseColumnValueToParquet(value, type, rowBufferStats, UTC, 0); + IcebergParquetValueParser.parseColumnValueToParquet(value, type, rowBufferStatsMap, UTC, 0); ParquetValueParserAssertionBuilder.newBuilder() .parquetBufferValue(pv) .rowBufferStats(rowBufferStats) @@ -286,9 +378,15 @@ public void parseValueFixed() { .named("FIXED_COL"); RowBufferStats rowBufferStats = new RowBufferStats("FIXED_COL"); + Map rowBufferStatsMap = + new HashMap() { + { + put("FIXED_COL", rowBufferStats); + } + }; byte[] value = "snow".getBytes(); 
ParquetBufferValue pv = - IcebergParquetValueParser.parseColumnValueToParquet(value, type, rowBufferStats, UTC, 0); + IcebergParquetValueParser.parseColumnValueToParquet(value, type, rowBufferStatsMap, UTC, 0); ParquetValueParserAssertionBuilder.newBuilder() .parquetBufferValue(pv) .rowBufferStats(rowBufferStats) @@ -309,9 +407,15 @@ public void parseValueDecimalToFixed() { .named("FIXED_COL"); RowBufferStats rowBufferStats = new RowBufferStats("FIXED_COL"); + Map rowBufferStatsMap = + new HashMap() { + { + put("FIXED_COL", rowBufferStats); + } + }; BigDecimal value = new BigDecimal("1234567890.0123456789"); ParquetBufferValue pv = - IcebergParquetValueParser.parseColumnValueToParquet(value, type, rowBufferStats, UTC, 0); + IcebergParquetValueParser.parseColumnValueToParquet(value, type, rowBufferStatsMap, UTC, 0); ParquetValueParserAssertionBuilder.newBuilder() .parquetBufferValue(pv) .rowBufferStats(rowBufferStats) @@ -322,4 +426,353 @@ public void parseValueDecimalToFixed() { .expectedMinMax(value.unscaledValue()) .assertMatches(); } + + @Test + public void parseList() throws JsonProcessingException { + Type list = + Types.optionalList() + .element(Types.optional(PrimitiveTypeName.INT32).named("element")) + .named("LIST_COL"); + RowBufferStats rowBufferStats = new RowBufferStats("LIST_COL.list.element"); + Map rowBufferStatsMap = + new HashMap() { + { + put("LIST_COL.list.element", rowBufferStats); + } + }; + + IcebergParquetValueParser.parseColumnValueToParquet(null, list, rowBufferStatsMap, UTC, 0); + ParquetBufferValue pv = + IcebergParquetValueParser.parseColumnValueToParquet( + Arrays.asList(1, 2, 3, 4, 5), list, rowBufferStatsMap, UTC, 0); + ParquetValueParserAssertionBuilder.newBuilder() + .rowBufferStats(rowBufferStats) + .parquetBufferValue(pv) + .expectedValueClass(ArrayList.class) + .expectedParsedValue( + convertToArrayList( + objectMapper.readValue("[[1], [2], [3], [4], [5]]", ArrayList.class))) + .expectedSize( + (4.0f + REPETITION_LEVEL_ENCODING_BYTE_LEN + DEFINITION_LEVEL_ENCODING_BYTE_LEN) * 5) + .expectedMin(BigInteger.valueOf(1)) + .expectedMax(BigInteger.valueOf(5)) + .assertMatches(); + + /* Test required list */ + Type requiredList = + Types.requiredList() + .element(Types.optional(PrimitiveTypeName.INT32).named("element")) + .named("LIST_COL"); + Assert.assertThrows( + SFException.class, + () -> + IcebergParquetValueParser.parseColumnValueToParquet( + null, requiredList, rowBufferStatsMap, UTC, 0)); + pv = + IcebergParquetValueParser.parseColumnValueToParquet( + new ArrayList<>(), requiredList, rowBufferStatsMap, UTC, 0); + ParquetValueParserAssertionBuilder.newBuilder() + .rowBufferStats(rowBufferStats) + .parquetBufferValue(pv) + .expectedValueClass(ArrayList.class) + .expectedParsedValue(convertToArrayList(objectMapper.readValue("[]", ArrayList.class))) + .expectedSize(0) + .expectedMin(BigInteger.valueOf(1)) + .expectedMax(BigInteger.valueOf(5)) + .assertMatches(); + + /* Test required list with required elements */ + Type requiredElements = + Types.requiredList() + .element(Types.required(PrimitiveTypeName.INT32).named("element")) + .named("LIST_COL"); + Assert.assertThrows( + SFException.class, + () -> + IcebergParquetValueParser.parseColumnValueToParquet( + Collections.singletonList(null), requiredElements, rowBufferStatsMap, UTC, 0)); + } + + @Test + public void parseMap() throws JsonProcessingException { + Type map = + Types.optionalMap() + .key(Types.required(PrimitiveTypeName.INT32).named("key")) + 
.value(Types.optional(PrimitiveTypeName.INT32).named("value")) + .named("MAP_COL"); + RowBufferStats rowBufferKeyStats = new RowBufferStats("MAP_COL.key_value.key"); + RowBufferStats rowBufferValueStats = new RowBufferStats("MAP_COL.key_value.value"); + Map rowBufferStatsMap = + new HashMap() { + { + put("MAP_COL.key_value.key", rowBufferKeyStats); + put("MAP_COL.key_value.value", rowBufferValueStats); + } + }; + IcebergParquetValueParser.parseColumnValueToParquet(null, map, rowBufferStatsMap, UTC, 0); + ParquetBufferValue pv = + IcebergParquetValueParser.parseColumnValueToParquet( + new java.util.HashMap() { + { + put(1, 1); + put(2, 2); + } + }, + map, + rowBufferStatsMap, + UTC, + 0); + ParquetValueParserAssertionBuilder.newBuilder() + .rowBufferStats(rowBufferKeyStats) + .parquetBufferValue(pv) + .expectedValueClass(ArrayList.class) + .expectedParsedValue( + convertToArrayList(objectMapper.readValue("[[1, 1], [2, 2]]", ArrayList.class))) + .expectedSize( + (4.0f + REPETITION_LEVEL_ENCODING_BYTE_LEN + DEFINITION_LEVEL_ENCODING_BYTE_LEN) * 2 + + (4.0f + REPETITION_LEVEL_ENCODING_BYTE_LEN + DEFINITION_LEVEL_ENCODING_BYTE_LEN) + * 2) + .expectedMin(BigInteger.valueOf(1)) + .expectedMax(BigInteger.valueOf(2)) + .assertMatches(); + + /* Test required map */ + Type requiredMap = + Types.requiredMap() + .key(Types.required(PrimitiveTypeName.INT32).named("key")) + .value(Types.optional(PrimitiveTypeName.INT32).named("value")) + .named("MAP_COL"); + Assert.assertThrows( + SFException.class, + () -> + IcebergParquetValueParser.parseColumnValueToParquet( + null, requiredMap, rowBufferStatsMap, UTC, 0)); + pv = + IcebergParquetValueParser.parseColumnValueToParquet( + new java.util.HashMap(), requiredMap, rowBufferStatsMap, UTC, 0); + ParquetValueParserAssertionBuilder.newBuilder() + .rowBufferStats(rowBufferKeyStats) + .parquetBufferValue(pv) + .expectedValueClass(ArrayList.class) + .expectedParsedValue(convertToArrayList(objectMapper.readValue("[]", ArrayList.class))) + .expectedSize(0) + .expectedMin(BigInteger.valueOf(1)) + .expectedMax(BigInteger.valueOf(2)) + .assertMatches(); + + /* Test required map with required values */ + Type requiredValues = + Types.requiredMap() + .key(Types.required(PrimitiveTypeName.INT32).named("key")) + .value(Types.required(PrimitiveTypeName.INT32).named("value")) + .named("MAP_COL"); + Assert.assertThrows( + SFException.class, + () -> + IcebergParquetValueParser.parseColumnValueToParquet( + new java.util.HashMap() { + { + put(1, null); + } + }, + requiredValues, + rowBufferStatsMap, + UTC, + 0)); + } + + @Test + public void parseStruct() throws JsonProcessingException { + Type struct = + Types.optionalGroup() + .addField(Types.optional(PrimitiveTypeName.INT32).named("a")) + .addField( + Types.required(PrimitiveTypeName.BINARY) + .as(LogicalTypeAnnotation.stringType()) + .named("b")) + .named("STRUCT_COL"); + + RowBufferStats rowBufferAStats = new RowBufferStats("STRUCT_COL.a"); + RowBufferStats rowBufferBStats = new RowBufferStats("STRUCT_COL.b"); + Map rowBufferStatsMap = + new HashMap() { + { + put("STRUCT_COL.a", rowBufferAStats); + put("STRUCT_COL.b", rowBufferBStats); + } + }; + + IcebergParquetValueParser.parseColumnValueToParquet(null, struct, rowBufferStatsMap, UTC, 0); + Assert.assertThrows( + SFException.class, + () -> + IcebergParquetValueParser.parseColumnValueToParquet( + new java.util.HashMap() { + { + put("a", 1); + } + }, + struct, + rowBufferStatsMap, + UTC, + 0)); + Assert.assertThrows( + SFException.class, + () -> + 
IcebergParquetValueParser.parseColumnValueToParquet( + new java.util.HashMap() { + { + put("c", 1); + } + }, + struct, + rowBufferStatsMap, + UTC, + 0)); + ParquetBufferValue pv = + IcebergParquetValueParser.parseColumnValueToParquet( + Collections.unmodifiableMap( + new java.util.HashMap() { + { + // a is null + put("b", "2"); + } + }), + struct, + rowBufferStatsMap, + UTC, + 0); + ParquetValueParserAssertionBuilder.newBuilder() + .parquetBufferValue(pv) + .expectedValueClass(ArrayList.class) + .expectedParsedValue( + convertToArrayList(objectMapper.readValue("[null, \"2\"]", ArrayList.class))) + .expectedSize(1 + BYTE_ARRAY_LENGTH_ENCODING_BYTE_LEN + DEFINITION_LEVEL_ENCODING_BYTE_LEN) + .expectedMinMax(BigInteger.valueOf(1)) + .assertMatches(); + + /* Test required struct */ + Type requiredStruct = + Types.requiredGroup() + .addField(Types.optional(PrimitiveTypeName.INT32).named("a")) + .addField( + Types.optional(PrimitiveTypeName.BINARY) + .as(LogicalTypeAnnotation.stringType()) + .named("b")) + .named("STRUCT_COL"); + Assert.assertThrows( + SFException.class, + () -> + IcebergParquetValueParser.parseColumnValueToParquet( + null, requiredStruct, rowBufferStatsMap, UTC, 0)); + pv = + IcebergParquetValueParser.parseColumnValueToParquet( + new java.util.HashMap(), requiredStruct, rowBufferStatsMap, UTC, 0); + ParquetValueParserAssertionBuilder.newBuilder() + .parquetBufferValue(pv) + .expectedValueClass(ArrayList.class) + .expectedParsedValue( + convertToArrayList(objectMapper.readValue("[null, null]", ArrayList.class))) + .expectedSize(0) + .expectedMinMax(BigInteger.valueOf(1)) + .assertMatches(); + } + + @Test + public void parseNestedTypes() { + for (int depth = 1; depth <= 100; depth *= 10) { + Map rowBufferStatsMap = new HashMap<>(); + Type type = generateNestedTypeAndStats(depth, "a", rowBufferStatsMap, "a"); + Pair res = generateNestedValueAndReference(depth); + Object value = res.getFirst(); + List reference = (List) res.getSecond(); + ParquetBufferValue pv = + IcebergParquetValueParser.parseColumnValueToParquet( + value, type, rowBufferStatsMap, UTC, 0); + ParquetValueParserAssertionBuilder.newBuilder() + .parquetBufferValue(pv) + .expectedValueClass(ArrayList.class) + .expectedParsedValue(convertToArrayList(reference)) + .expectedSize( + (4.0f + REPETITION_LEVEL_ENCODING_BYTE_LEN + DEFINITION_LEVEL_ENCODING_BYTE_LEN) + * (depth / 3 + 1)) + .assertMatches(); + } + } + + private static Type generateNestedTypeAndStats( + int depth, String name, Map rowBufferStatsMap, String path) { + if (depth == 0) { + rowBufferStatsMap.put(path, new RowBufferStats(path)); + return Types.optional(PrimitiveTypeName.INT32).named(name); + } + switch (depth % 3) { + case 1: + return Types.optionalList() + .element( + generateNestedTypeAndStats( + depth - 1, "element", rowBufferStatsMap, path + ".list.element")) + .named(name); + case 2: + return Types.optionalGroup() + .addField(generateNestedTypeAndStats(depth - 1, "a", rowBufferStatsMap, path + ".a")) + .named(name); + case 0: + rowBufferStatsMap.put(path + ".key_value.key", new RowBufferStats(path + ".key_value.key")); + return Types.optionalMap() + .key(Types.required(PrimitiveTypeName.INT32).named("key")) + .value( + generateNestedTypeAndStats( + depth - 1, "value", rowBufferStatsMap, path + ".key_value.value")) + .named(name); + } + return null; + } + + private static Pair generateNestedValueAndReference(int depth) { + if (depth == 0) { + return new Pair<>(1, 1); + } + Pair res = generateNestedValueAndReference(depth - 1); + 
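      // Wrap the inner value one level deeper, mirroring generateNestedTypeAndStats:
      // depth % 3 == 1 wraps it in a list, == 2 in a struct under field "a", == 0 in a map under key 1.
      // The Pair's second element carries the expected flattened Parquet representation that
      // parseNestedTypes asserts via expectedParsedValue.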
Assert.assertNotNull(res); + switch (depth % 3) { + case 1: + return new Pair<>( + Collections.singletonList(res.getFirst()), + Collections.singletonList(Collections.singletonList(res.getSecond()))); + case 2: + return new Pair<>( + new java.util.HashMap() { + { + put("a", res.getFirst()); + } + }, + Collections.singletonList(res.getSecond())); + case 0: + return new Pair<>( + new java.util.HashMap() { + { + put(1, res.getFirst()); + } + }, + Collections.singletonList(Arrays.asList(1, res.getSecond()))); + } + return null; + } + + private static ArrayList convertToArrayList(List list) { + ArrayList arrayList = new ArrayList<>(); + for (Object element : list) { + if (element instanceof List) { + // Recursively convert nested lists + arrayList.add(convertToArrayList((List) element)); + } else if (element instanceof String) { + // Convert string to byte array + arrayList.add(((String) element).getBytes()); + } else { + arrayList.add(element); + } + } + return arrayList; + } } diff --git a/src/test/java/net/snowflake/ingest/streaming/internal/ParquetTypeGeneratorTest.java b/src/test/java/net/snowflake/ingest/streaming/internal/ParquetTypeGeneratorTest.java index c83d339c1..99411be5f 100644 --- a/src/test/java/net/snowflake/ingest/streaming/internal/ParquetTypeGeneratorTest.java +++ b/src/test/java/net/snowflake/ingest/streaming/internal/ParquetTypeGeneratorTest.java @@ -693,6 +693,313 @@ public void buildFieldIcebergBinary() { .assertMatches(); } + @Test + public void buildFieldIcebergStruct() { + ColumnMetadata testCol = + createColumnMetadataBuilder() + .logicalType("") + .sourceIcebergDataType( + "{" + + " \"type\": \"struct\"," + + " \"fields\":" + + " [" + + " {" + + " \"id\": 1," + + " \"name\": \"id\"," + + " \"required\": true," + + " \"type\": \"string\"" + + " }," + + " {" + + " \"id\": 2," + + " \"name\": \"age\"," + + " \"required\": false," + + " \"type\": \"int\"" + + " }" + + " ]" + + "}") + .nullable(true) + .build(); + + ParquetTypeInfo typeInfo = ParquetTypeGenerator.generateColumnParquetTypeInfo(testCol, 0); + createParquetTypeInfoAssertionBuilder(false) + .typeInfo(typeInfo) + .expectedFieldName("TESTCOL") + .expectedLogicalTypeAnnotation(null) + .expectedRepetition(Type.Repetition.OPTIONAL) + .expectedColMetaData(null) + .expectedFieldCount(2) + .assertMatches(); + + ParquetTypeInfo firstFieldTypeInfo = + new ParquetTypeInfo(typeInfo.getParquetType().asGroupType().getType(0), null); + createParquetTypeInfoAssertionBuilder(false) + .typeInfo(firstFieldTypeInfo) + .expectedFieldName("id") + .expectedTypeLength(0) + .expectedPrimitiveTypeName(PrimitiveType.PrimitiveTypeName.BINARY) + .expectedLogicalTypeAnnotation(LogicalTypeAnnotation.stringType()) + .expectedRepetition(Type.Repetition.REQUIRED) + .expectedColMetaData(null) + .assertMatches(); + + ParquetTypeInfo secondFieldTypeInfo = + new ParquetTypeInfo(typeInfo.getParquetType().asGroupType().getType(1), null); + createParquetTypeInfoAssertionBuilder(false) + .typeInfo(secondFieldTypeInfo) + .expectedFieldName("age") + .expectedTypeLength(0) + .expectedPrimitiveTypeName(PrimitiveType.PrimitiveTypeName.INT32) + .expectedLogicalTypeAnnotation(null) + .expectedRepetition(Type.Repetition.OPTIONAL) + .expectedColMetaData(null) + .assertMatches(); + } + + @Test + public void buildFieldIcebergList() { + ColumnMetadata testCol = + createColumnMetadataBuilder() + .logicalType("") + .sourceIcebergDataType( + "{" + + " \"type\": \"list\"," + + " \"element\": \"int\"," + + " \"element-required\": true," + + " \"element-id\": 1" + + 
"}") + .nullable(true) + .build(); + + ParquetTypeInfo typeInfo = ParquetTypeGenerator.generateColumnParquetTypeInfo(testCol, 0); + createParquetTypeInfoAssertionBuilder(false) + .typeInfo(typeInfo) + .expectedFieldName("TESTCOL") + .expectedLogicalTypeAnnotation(LogicalTypeAnnotation.listType()) + .expectedRepetition(Type.Repetition.OPTIONAL) + .expectedColMetaData(null) + .expectedFieldCount(1) + .assertMatches(); + + ParquetTypeInfo elementTypeInfo = + new ParquetTypeInfo(typeInfo.getParquetType().asGroupType().getType(0), null); + createParquetTypeInfoAssertionBuilder(false) + .typeInfo(elementTypeInfo) + .expectedFieldName("list") + .expectedLogicalTypeAnnotation(null) + .expectedRepetition(Type.Repetition.REPEATED) + .expectedColMetaData(null) + .expectedFieldCount(1) + .assertMatches(); + + ParquetTypeInfo elementFieldTypeInfo = + new ParquetTypeInfo(elementTypeInfo.getParquetType().asGroupType().getType(0), null); + createParquetTypeInfoAssertionBuilder(false) + .typeInfo(elementFieldTypeInfo) + .expectedFieldName("element") + .expectedTypeLength(0) + .expectedPrimitiveTypeName(PrimitiveType.PrimitiveTypeName.INT32) + .expectedLogicalTypeAnnotation(null) + .expectedRepetition(Type.Repetition.REQUIRED) + .expectedColMetaData(null) + .assertMatches(); + } + + @Test + public void buildFieldIcebergMap() { + ColumnMetadata testCol = + createColumnMetadataBuilder() + .logicalType("") + .sourceIcebergDataType( + "{" + + " \"type\": \"map\"," + + " \"key\": \"string\"," + + " \"value\": \"int\"," + + " \"key-required\": true," + + " \"value-required\": false," + + " \"key-id\": 1," + + " \"value-id\": 2" + + "}") + .nullable(true) + .build(); + + ParquetTypeInfo typeInfo = ParquetTypeGenerator.generateColumnParquetTypeInfo(testCol, 0); + createParquetTypeInfoAssertionBuilder(false) + .typeInfo(typeInfo) + .expectedFieldName("TESTCOL") + .expectedLogicalTypeAnnotation(LogicalTypeAnnotation.mapType()) + .expectedRepetition(Type.Repetition.OPTIONAL) + .expectedColMetaData(null) + .expectedFieldCount(1) + .assertMatches(); + + ParquetTypeInfo mapTypeInfo = + new ParquetTypeInfo(typeInfo.getParquetType().asGroupType().getType(0), null); + createParquetTypeInfoAssertionBuilder(false) + .typeInfo(mapTypeInfo) + .expectedFieldName("key_value") + .expectedLogicalTypeAnnotation(null) + .expectedRepetition(Type.Repetition.REPEATED) + .expectedColMetaData(null) + .expectedFieldCount(2) + .assertMatches(); + + ParquetTypeInfo keyTypeInfo = + new ParquetTypeInfo(mapTypeInfo.getParquetType().asGroupType().getType(0), null); + createParquetTypeInfoAssertionBuilder(false) + .typeInfo(keyTypeInfo) + .expectedFieldName("key") + .expectedTypeLength(0) + .expectedPrimitiveTypeName(PrimitiveType.PrimitiveTypeName.BINARY) + .expectedLogicalTypeAnnotation(LogicalTypeAnnotation.stringType()) + .expectedRepetition(Type.Repetition.REQUIRED) + .expectedColMetaData(null) + .assertMatches(); + + ParquetTypeInfo valueTypeInfo = + new ParquetTypeInfo(mapTypeInfo.getParquetType().asGroupType().getType(1), null); + createParquetTypeInfoAssertionBuilder(false) + .typeInfo(valueTypeInfo) + .expectedFieldName("value") + .expectedTypeLength(0) + .expectedPrimitiveTypeName(PrimitiveType.PrimitiveTypeName.INT32) + .expectedLogicalTypeAnnotation(null) + .expectedRepetition(Type.Repetition.OPTIONAL) + .expectedColMetaData(null) + .assertMatches(); + } + + @Test + public void buildFieldIcebergNestedStructuredDataType() { + ColumnMetadata testCol = + createColumnMetadataBuilder() + .logicalType("") + .sourceIcebergDataType( + "{" + 
+ " \"type\": \"map\"," + + " \"key\": \"string\"," + + " \"value\": {" + + " \"type\": \"list\"," + + " \"element\": {" + + " \"type\": \"struct\"," + + " \"fields\":" + + " [" + + " {" + + " \"id\": 1," + + " \"name\": \"id\"," + + " \"required\": true," + + " \"type\": \"string\"" + + " }," + + " {" + + " \"id\": 2," + + " \"name\": \"age\"," + + " \"required\": false," + + " \"type\": \"int\"" + + " }" + + " ]" + + " }," + + " \"element-required\": true," + + " \"element-id\": 1" + + " }," + + " \"key-required\": true," + + " \"value-required\": false," + + " \"key-id\": 1," + + " \"value-id\": 2" + + "}") + .nullable(true) + .build(); + + ParquetTypeInfo typeInfo = ParquetTypeGenerator.generateColumnParquetTypeInfo(testCol, 0); + + createParquetTypeInfoAssertionBuilder(false) + .typeInfo(typeInfo) + .expectedFieldName("TESTCOL") + .expectedLogicalTypeAnnotation(LogicalTypeAnnotation.mapType()) + .expectedRepetition(Type.Repetition.OPTIONAL) + .expectedColMetaData(null) + .expectedFieldCount(1) + .assertMatches(); + + ParquetTypeInfo mapTypeInfo = + new ParquetTypeInfo(typeInfo.getParquetType().asGroupType().getType(0), null); + createParquetTypeInfoAssertionBuilder(false) + .typeInfo(mapTypeInfo) + .expectedFieldName("key_value") + .expectedLogicalTypeAnnotation(null) + .expectedRepetition(Type.Repetition.REPEATED) + .expectedColMetaData(null) + .expectedFieldCount(2) + .assertMatches(); + + ParquetTypeInfo keyTypeInfo = + new ParquetTypeInfo(mapTypeInfo.getParquetType().asGroupType().getType(0), null); + createParquetTypeInfoAssertionBuilder(false) + .typeInfo(keyTypeInfo) + .expectedFieldName("key") + .expectedTypeLength(0) + .expectedPrimitiveTypeName(PrimitiveType.PrimitiveTypeName.BINARY) + .expectedLogicalTypeAnnotation(LogicalTypeAnnotation.stringType()) + .expectedRepetition(Type.Repetition.REQUIRED) + .expectedColMetaData(null) + .assertMatches(); + + ParquetTypeInfo valueTypeInfo = + new ParquetTypeInfo(mapTypeInfo.getParquetType().asGroupType().getType(1), null); + createParquetTypeInfoAssertionBuilder(false) + .typeInfo(valueTypeInfo) + .expectedFieldName("value") + .expectedLogicalTypeAnnotation(LogicalTypeAnnotation.listType()) + .expectedRepetition(Type.Repetition.OPTIONAL) + .expectedColMetaData(null) + .expectedFieldCount(1) + .assertMatches(); + + ParquetTypeInfo elementTypeInfo = + new ParquetTypeInfo(valueTypeInfo.getParquetType().asGroupType().getType(0), null); + createParquetTypeInfoAssertionBuilder(false) + .typeInfo(elementTypeInfo) + .expectedFieldName("list") + .expectedLogicalTypeAnnotation(null) + .expectedRepetition(Type.Repetition.REPEATED) + .expectedColMetaData(null) + .expectedFieldCount(1) + .assertMatches(); + + ParquetTypeInfo elementFieldTypeInfo = + new ParquetTypeInfo(elementTypeInfo.getParquetType().asGroupType().getType(0), null); + createParquetTypeInfoAssertionBuilder(false) + .typeInfo(elementFieldTypeInfo) + .expectedFieldName("element") + .expectedLogicalTypeAnnotation(null) + .expectedRepetition(Type.Repetition.REQUIRED) + .expectedColMetaData(null) + .expectedFieldCount(2) + .assertMatches(); + + ParquetTypeInfo firstFieldTypeInfo = + new ParquetTypeInfo(elementFieldTypeInfo.getParquetType().asGroupType().getType(0), null); + createParquetTypeInfoAssertionBuilder(false) + .typeInfo(firstFieldTypeInfo) + .expectedFieldName("id") + .expectedTypeLength(0) + .expectedPrimitiveTypeName(PrimitiveType.PrimitiveTypeName.BINARY) + .expectedLogicalTypeAnnotation(LogicalTypeAnnotation.stringType()) + .expectedRepetition(Type.Repetition.REQUIRED) + 
.expectedColMetaData(null) + .assertMatches(); + + ParquetTypeInfo secondFieldTypeInfo = + new ParquetTypeInfo(elementFieldTypeInfo.getParquetType().asGroupType().getType(1), null); + createParquetTypeInfoAssertionBuilder(false) + .typeInfo(secondFieldTypeInfo) + .expectedFieldName("age") + .expectedTypeLength(0) + .expectedPrimitiveTypeName(PrimitiveType.PrimitiveTypeName.INT32) + .expectedLogicalTypeAnnotation(null) + .expectedRepetition(Type.Repetition.OPTIONAL) + .expectedColMetaData(null) + .assertMatches(); + } + /** Builder that helps to assert parquet type info */ private static class ParquetTypeInfoAssertionBuilder { private String fieldName; diff --git a/src/test/java/net/snowflake/ingest/streaming/internal/ParquetValueParserAssertionBuilder.java b/src/test/java/net/snowflake/ingest/streaming/internal/ParquetValueParserAssertionBuilder.java index 8480311fa..1aef68663 100644 --- a/src/test/java/net/snowflake/ingest/streaming/internal/ParquetValueParserAssertionBuilder.java +++ b/src/test/java/net/snowflake/ingest/streaming/internal/ParquetValueParserAssertionBuilder.java @@ -7,6 +7,7 @@ import java.math.BigDecimal; import java.math.BigInteger; import java.nio.charset.StandardCharsets; +import java.util.List; import org.junit.Assert; /** Builder that helps to assert parsing of values to parquet types */ @@ -16,7 +17,8 @@ class ParquetValueParserAssertionBuilder { private Class valueClass; private Object value; private float size; - private Object minMaxStat; + private Object minStat; + private Object maxStat; private long currentNullCount; static ParquetValueParserAssertionBuilder newBuilder() { @@ -50,7 +52,18 @@ ParquetValueParserAssertionBuilder expectedSize(float size) { } public ParquetValueParserAssertionBuilder expectedMinMax(Object minMaxStat) { - this.minMaxStat = minMaxStat; + this.minStat = minMaxStat; + this.maxStat = minMaxStat; + return this; + } + + public ParquetValueParserAssertionBuilder expectedMin(Object minStat) { + this.minStat = minStat; + return this; + } + + public ParquetValueParserAssertionBuilder expectedMax(Object maxStat) { + this.maxStat = maxStat; return this; } @@ -64,41 +77,64 @@ void assertMatches() { if (valueClass.equals(byte[].class)) { Assert.assertArrayEquals((byte[]) value, (byte[]) parquetBufferValue.getValue()); } else { - Assert.assertEquals(value, parquetBufferValue.getValue()); + assertValueEquals(value, parquetBufferValue.getValue()); } Assert.assertEquals(size, parquetBufferValue.getSize(), 0); - if (minMaxStat instanceof BigInteger) { - Assert.assertEquals(minMaxStat, rowBufferStats.getCurrentMinIntValue()); - Assert.assertEquals(minMaxStat, rowBufferStats.getCurrentMaxIntValue()); - return; - } else if (minMaxStat instanceof byte[]) { - Assert.assertArrayEquals((byte[]) minMaxStat, rowBufferStats.getCurrentMinStrValue()); - Assert.assertArrayEquals((byte[]) minMaxStat, rowBufferStats.getCurrentMaxStrValue()); - return; - } else if (valueClass.equals(String.class)) { - // String can have null min/max stats for variant data types - Object min = - rowBufferStats.getCurrentMinStrValue() != null - ? new String(rowBufferStats.getCurrentMinStrValue(), StandardCharsets.UTF_8) - : rowBufferStats.getCurrentMinStrValue(); - Object max = - rowBufferStats.getCurrentMaxStrValue() != null - ? 
new String(rowBufferStats.getCurrentMaxStrValue(), StandardCharsets.UTF_8) - : rowBufferStats.getCurrentMaxStrValue(); - Assert.assertEquals(minMaxStat, min); - Assert.assertEquals(minMaxStat, max); - return; - } else if (minMaxStat instanceof Double || minMaxStat instanceof BigDecimal) { - Assert.assertEquals(minMaxStat, rowBufferStats.getCurrentMinRealValue()); - Assert.assertEquals(minMaxStat, rowBufferStats.getCurrentMaxRealValue()); - return; + if (rowBufferStats != null) { + if (minStat instanceof BigInteger) { + Assert.assertEquals(minStat, rowBufferStats.getCurrentMinIntValue()); + Assert.assertEquals(maxStat, rowBufferStats.getCurrentMaxIntValue()); + return; + } else if (minStat instanceof byte[]) { + Assert.assertArrayEquals((byte[]) minStat, rowBufferStats.getCurrentMinStrValue()); + Assert.assertArrayEquals((byte[]) maxStat, rowBufferStats.getCurrentMaxStrValue()); + return; + } else if (valueClass.equals(String.class)) { + // String can have null min/max stats for variant data types + Object min = + rowBufferStats.getCurrentMinStrValue() != null + ? new String(rowBufferStats.getCurrentMinStrValue(), StandardCharsets.UTF_8) + : rowBufferStats.getCurrentMinStrValue(); + Object max = + rowBufferStats.getCurrentMaxStrValue() != null + ? new String(rowBufferStats.getCurrentMaxStrValue(), StandardCharsets.UTF_8) + : rowBufferStats.getCurrentMaxStrValue(); + Assert.assertEquals(minStat, min); + Assert.assertEquals(maxStat, max); + return; + } else if (minStat instanceof Double || minStat instanceof BigDecimal) { + Assert.assertEquals(minStat, rowBufferStats.getCurrentMinRealValue()); + Assert.assertEquals(maxStat, rowBufferStats.getCurrentMaxRealValue()); + return; + } + throw new IllegalArgumentException( + String.format("Unknown data type for min stat: %s", minStat.getClass())); } - throw new IllegalArgumentException( - String.format("Unknown data type for min stat: %s", minMaxStat.getClass())); } void assertNull() { Assert.assertNull(parquetBufferValue.getValue()); Assert.assertEquals(currentNullCount, rowBufferStats.getCurrentNullCount()); } + + void assertValueEquals(Object expectedValue, Object actualValue) { + if (expectedValue == null) { + Assert.assertNull(actualValue); + return; + } + if (expectedValue instanceof List) { + Assert.assertTrue(actualValue instanceof List); + List expectedList = (List) expectedValue; + List actualList = (List) actualValue; + Assert.assertEquals(expectedList.size(), actualList.size()); + for (int i = 0; i < expectedList.size(); i++) { + assertValueEquals(expectedList.get(i), actualList.get(i)); + } + } else if (expectedValue.getClass().equals(byte[].class)) { + Assert.assertEquals(byte[].class, actualValue.getClass()); + Assert.assertArrayEquals((byte[]) expectedValue, (byte[]) actualValue); + } else { + Assert.assertEquals(expectedValue, actualValue); + } + } } diff --git a/src/test/java/net/snowflake/ingest/streaming/internal/SnowflakeServiceClientTest.java b/src/test/java/net/snowflake/ingest/streaming/internal/SnowflakeServiceClientTest.java index 253dee01c..f77a9bb60 100644 --- a/src/test/java/net/snowflake/ingest/streaming/internal/SnowflakeServiceClientTest.java +++ b/src/test/java/net/snowflake/ingest/streaming/internal/SnowflakeServiceClientTest.java @@ -9,8 +9,18 @@ import net.snowflake.ingest.utils.Constants; import org.junit.Before; import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +@RunWith(Parameterized.class) public class SnowflakeServiceClientTest { + @Parameterized.Parameters(name 
= "isIceberg: {0}") + public static Object[] isIceberg() { + return new Object[] {false, true}; + } + + @Parameterized.Parameter public boolean isIceberg; + private SnowflakeServiceClient snowflakeServiceClient; @Before @@ -50,7 +60,7 @@ public void testOpenChannel() throws IngestResponseException, IOException { "test_table", "test_channel", Constants.WriteMode.CLOUD_STORAGE, - false, + isIceberg, "test_offset_token"); OpenChannelResponse openChannelResponse = snowflakeServiceClient.openChannel(openChannelRequest); @@ -72,7 +82,14 @@ public void testOpenChannel() throws IngestResponseException, IOException { public void testDropChannel() throws IngestResponseException, IOException { DropChannelRequestInternal dropChannelRequest = new DropChannelRequestInternal( - "request_id", "test_role", "test_db", "test_schema", "test_table", "test_channel", 0L); + "request_id", + "test_role", + "test_db", + "test_schema", + "test_table", + "test_channel", + isIceberg, + 0L); DropChannelResponse dropChannelResponse = snowflakeServiceClient.dropChannel(dropChannelRequest); assert dropChannelResponse.getStatusCode() == 0L;