diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/AbstractRowBuffer.java b/src/main/java/net/snowflake/ingest/streaming/internal/AbstractRowBuffer.java index f6c26f9bd..8becb190b 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/AbstractRowBuffer.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/AbstractRowBuffer.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022 Snowflake Computing Inc. All rights reserved. + * Copyright (c) 2022-2024 Snowflake Computing Inc. All rights reserved. */ package net.snowflake.ingest.streaming.internal; diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/DataValidationUtil.java b/src/main/java/net/snowflake/ingest/streaming/internal/DataValidationUtil.java index 6e3281997..0712c68a0 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/DataValidationUtil.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/DataValidationUtil.java @@ -1055,6 +1055,80 @@ static int validateAndParseBoolean(String columnName, Object input, long insertR insertRowIndex); } + /** + * Validate and cast Iceberg struct column to Map. Allowed Java type: + * + * <ul> + *   <li>Map + * </ul> + * + * @param columnName Column name, used in validation error messages + * @param input Object to validate and parse + * @param insertRowIndex Row index for error reporting + * @return Object cast to Map + */ + static Map validateAndParseIcebergStruct( + String columnName, Object input, long insertRowIndex) { + if (!(input instanceof Map)) { + throw typeNotAllowedException( + columnName, + input.getClass(), + "STRUCT", + new String[] {"Map"}, + insertRowIndex); + } + if (!((Map) input).keySet().stream().allMatch(key -> key instanceof String)) { + throw new SFException( + ErrorCode.INVALID_FORMAT_ROW, + String.format( + "Field name of a struct must be of type String, rowIndex:%d", insertRowIndex)); + } + + return (Map) input; + } + + /** + * Validate and parse Iceberg list column to an Iterable. Allowed Java type: + * + * <ul> + *   <li>Iterable + * </ul>
+ * + * @param columnName Column name, used in validation error messages + * @param input Object to validate and parse + * @param insertRowIndex Row index for error reporting + * @return Object cast to Iterable + */ + static Iterable validateAndParseIcebergList( + String columnName, Object input, long insertRowIndex) { + if (!(input instanceof Iterable)) { + throw typeNotAllowedException( + columnName, input.getClass(), "LIST", new String[] {"Iterable"}, insertRowIndex); + } + return (Iterable) input; + } + + /** + * Validate and parse Iceberg map column to a map. Allowed Java type: + * + *
+ * <ul> + *   <li>Map + * </ul>
+ * + * @param columnName Column name, used in validation error messages + * @param input Object to validate and parse + * @param insertRowIndex Row index for error reporting + * @return Object cast to Map + */ + static Map validateAndParseIcebergMap( + String columnName, Object input, long insertRowIndex) { + if (!(input instanceof Map)) { + throw typeNotAllowedException( + columnName, input.getClass(), "MAP", new String[] {"Map"}, insertRowIndex); + } + return (Map) input; + } + static void checkValueInRange( BigDecimal bigDecimalValue, int scale, int precision, final long insertRowIndex) { BigDecimal comparand = diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/FileColumnProperties.java b/src/main/java/net/snowflake/ingest/streaming/internal/FileColumnProperties.java index 0ec671bf7..ca2b8398d 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/FileColumnProperties.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/FileColumnProperties.java @@ -1,14 +1,28 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. + */ + package net.snowflake.ingest.streaming.internal; import static net.snowflake.ingest.streaming.internal.BinaryStringUtils.truncateBytesAsHex; +import static net.snowflake.ingest.utils.Constants.EP_NDV_UNKNOWN; import com.fasterxml.jackson.annotation.JsonProperty; import java.math.BigInteger; import java.util.Objects; +import org.apache.parquet.column.statistics.BinaryStatistics; +import org.apache.parquet.column.statistics.BooleanStatistics; +import org.apache.parquet.column.statistics.DoubleStatistics; +import org.apache.parquet.column.statistics.FloatStatistics; +import org.apache.parquet.column.statistics.IntStatistics; +import org.apache.parquet.column.statistics.LongStatistics; +import org.apache.parquet.column.statistics.Statistics; +import org.apache.parquet.schema.LogicalTypeAnnotation; /** Audit register endpoint/FileColumnPropertyDTO property list. */ class FileColumnProperties { private int columnOrdinal; + private int fieldId; private String minStrValue; private String maxStrValue; @@ -84,6 +98,38 @@ class FileColumnProperties { this.setDistinctValues(stats.getDistinctValues()); } + FileColumnProperties(int fieldId, Statistics statistics) { + this.setColumnOrdinal(fieldId); + this.setFieldId(fieldId); + this.setNullCount(statistics.getNumNulls()); + this.setDistinctValues(EP_NDV_UNKNOWN); + this.setCollation(null); + this.setMaxStrNonCollated(null); + this.setMinStrNonCollated(null); + + if (statistics instanceof BooleanStatistics) { + this.setMinIntValue( + ((BooleanStatistics) statistics).genericGetMin() ? BigInteger.ONE : BigInteger.ZERO); + this.setMaxIntValue( + ((BooleanStatistics) statistics).genericGetMax() ? 
BigInteger.ONE : BigInteger.ZERO); + } else if (statistics instanceof IntStatistics || statistics instanceof LongStatistics) { + this.setMinIntValue(BigInteger.valueOf(((Number) statistics.genericGetMin()).longValue())); + this.setMaxIntValue(BigInteger.valueOf(((Number) statistics.genericGetMax()).longValue())); + } else if (statistics instanceof FloatStatistics || statistics instanceof DoubleStatistics) { + this.setMinRealValue(((Number) statistics.genericGetMin()).doubleValue()); + this.setMaxRealValue(((Number) statistics.genericGetMax()).doubleValue()); + } else if (statistics instanceof BinaryStatistics) { + if (statistics.type().getLogicalTypeAnnotation() + instanceof LogicalTypeAnnotation.DecimalLogicalTypeAnnotation) { + this.setMinIntValue(new BigInteger(statistics.getMinBytes())); + this.setMaxIntValue(new BigInteger(statistics.getMaxBytes())); + } else { + this.setMinStrValue(truncateBytesAsHex(statistics.getMinBytes(), false)); + this.setMaxStrValue(truncateBytesAsHex(statistics.getMaxBytes(), true)); + } + } + } + @JsonProperty("columnId") public int getColumnOrdinal() { return columnOrdinal; @@ -93,6 +139,15 @@ public void setColumnOrdinal(int columnOrdinal) { this.columnOrdinal = columnOrdinal; } + @JsonProperty("fieldId") + public int getFieldId() { + return fieldId; + } + + public void setFieldId(int fieldId) { + this.fieldId = fieldId; + } + // Annotation required in order to have package private fields serialized @JsonProperty("minStrValue") String getMinStrValue() { diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/Flusher.java b/src/main/java/net/snowflake/ingest/streaming/internal/Flusher.java index 0cf8220bb..3fdfa2c1f 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/Flusher.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/Flusher.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022 Snowflake Computing Inc. All rights reserved. + * Copyright (c) 2022-2024 Snowflake Computing Inc. All rights reserved. 
*/ package net.snowflake.ingest.streaming.internal; @@ -9,6 +9,7 @@ import java.util.List; import java.util.Map; import net.snowflake.ingest.utils.Pair; +import org.apache.parquet.hadoop.metadata.BlockMetaData; /** * Interface to convert {@link ChannelData} buffered in {@link RowBuffer} to the underlying format @@ -39,6 +40,7 @@ class SerializationResult { final float chunkEstimatedUncompressedSize; final ByteArrayOutputStream chunkData; final Pair chunkMinMaxInsertTimeInMs; + final List blocksMetadata; public SerializationResult( List channelsMetadataList, @@ -46,13 +48,15 @@ public SerializationResult( long rowCount, float chunkEstimatedUncompressedSize, ByteArrayOutputStream chunkData, - Pair chunkMinMaxInsertTimeInMs) { + Pair chunkMinMaxInsertTimeInMs, + List blocksMetadata) { this.channelsMetadataList = channelsMetadataList; this.columnEpStatsMapCombined = columnEpStatsMapCombined; this.rowCount = rowCount; this.chunkEstimatedUncompressedSize = chunkEstimatedUncompressedSize; this.chunkData = chunkData; this.chunkMinMaxInsertTimeInMs = chunkMinMaxInsertTimeInMs; + this.blocksMetadata = blocksMetadata; } } } diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/IcebergParquetValueParser.java b/src/main/java/net/snowflake/ingest/streaming/internal/IcebergParquetValueParser.java index 18a66f4d5..d722d26ad 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/IcebergParquetValueParser.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/IcebergParquetValueParser.java @@ -11,12 +11,18 @@ import java.math.RoundingMode; import java.nio.charset.StandardCharsets; import java.time.ZoneId; +import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; import java.util.Optional; +import java.util.concurrent.atomic.AtomicReference; import net.snowflake.ingest.utils.Constants; import net.snowflake.ingest.utils.ErrorCode; import net.snowflake.ingest.utils.SFException; import net.snowflake.ingest.utils.Utils; +import org.apache.parquet.schema.GroupType; import org.apache.parquet.schema.LogicalTypeAnnotation; import org.apache.parquet.schema.LogicalTypeAnnotation.DateLogicalTypeAnnotation; import org.apache.parquet.schema.LogicalTypeAnnotation.DecimalLogicalTypeAnnotation; @@ -35,7 +41,7 @@ class IcebergParquetValueParser { * * @param value column value provided by user in a row * @param type Parquet column type - * @param stats column stats to update + * @param statsMap column stats map to update * @param defaultTimezone default timezone to use for timestamp parsing * @param insertRowsCurrIndex Row index corresponding the row to parse (w.r.t input rows in * insertRows API, and not buffered row) @@ -44,69 +50,105 @@ class IcebergParquetValueParser { static ParquetBufferValue parseColumnValueToParquet( Object value, Type type, - RowBufferStats stats, + Map statsMap, ZoneId defaultTimezone, long insertRowsCurrIndex) { - Utils.assertNotNull("Parquet column stats", stats); + Utils.assertNotNull("Parquet column stats map", statsMap); + return parseColumnValueToParquet( + value, type, statsMap, defaultTimezone, insertRowsCurrIndex, null, false); + } + + private static ParquetBufferValue parseColumnValueToParquet( + Object value, + Type type, + Map statsMap, + ZoneId defaultTimezone, + long insertRowsCurrIndex, + String path, + boolean isdDescendantsOfRepeatingGroup) { + path = (path == null || path.isEmpty()) ? type.getName() : path + "." 
+ type.getName(); float estimatedParquetSize = 0F; if (value != null) { - estimatedParquetSize += ParquetBufferValue.DEFINITION_LEVEL_ENCODING_BYTE_LEN; - PrimitiveType primitiveType = type.asPrimitiveType(); - switch (primitiveType.getPrimitiveTypeName()) { - case BOOLEAN: - int intValue = - DataValidationUtil.validateAndParseBoolean( - type.getName(), value, insertRowsCurrIndex); - value = intValue > 0; - stats.addIntValue(BigInteger.valueOf(intValue)); - estimatedParquetSize += ParquetBufferValue.BIT_ENCODING_BYTE_LEN; - break; - case INT32: - int intVal = getInt32Value(value, primitiveType, insertRowsCurrIndex); - value = intVal; - stats.addIntValue(BigInteger.valueOf(intVal)); - estimatedParquetSize += 4; - break; - case INT64: - long longVal = getInt64Value(value, primitiveType, defaultTimezone, insertRowsCurrIndex); - value = longVal; - stats.addIntValue(BigInteger.valueOf(longVal)); - estimatedParquetSize += 8; - break; - case FLOAT: - float floatVal = - (float) - DataValidationUtil.validateAndParseReal( - type.getName(), value, insertRowsCurrIndex); - value = floatVal; - stats.addRealValue((double) floatVal); - estimatedParquetSize += 4; - break; - case DOUBLE: - double doubleVal = - DataValidationUtil.validateAndParseReal(type.getName(), value, insertRowsCurrIndex); - value = doubleVal; - stats.addRealValue(doubleVal); - estimatedParquetSize += 8; - break; - case BINARY: - byte[] byteVal = getBinaryValue(value, primitiveType, stats, insertRowsCurrIndex); - value = byteVal; - estimatedParquetSize += - ParquetBufferValue.BYTE_ARRAY_LENGTH_ENCODING_BYTE_LEN + byteVal.length; - break; - case FIXED_LEN_BYTE_ARRAY: - byte[] fixedLenByteArrayVal = - getFixedLenByteArrayValue(value, primitiveType, stats, insertRowsCurrIndex); - value = fixedLenByteArrayVal; - estimatedParquetSize += - ParquetBufferValue.BYTE_ARRAY_LENGTH_ENCODING_BYTE_LEN + fixedLenByteArrayVal.length; - break; - default: + if (type.isPrimitive()) { + if (!statsMap.containsKey(path)) { throw new SFException( - ErrorCode.UNKNOWN_DATA_TYPE, - type.getLogicalTypeAnnotation(), - primitiveType.getPrimitiveTypeName()); + ErrorCode.INTERNAL_ERROR, + String.format("Stats not found for column: %s", type.getName())); + } + RowBufferStats stats = statsMap.get(path); + estimatedParquetSize += ParquetBufferValue.DEFINITION_LEVEL_ENCODING_BYTE_LEN; + estimatedParquetSize += + isdDescendantsOfRepeatingGroup + ? 
ParquetBufferValue.REPETITION_LEVEL_ENCODING_BYTE_LEN + : 0; + PrimitiveType primitiveType = type.asPrimitiveType(); + switch (primitiveType.getPrimitiveTypeName()) { + case BOOLEAN: + int intValue = + DataValidationUtil.validateAndParseBoolean( + type.getName(), value, insertRowsCurrIndex); + value = intValue > 0; + stats.addIntValue(BigInteger.valueOf(intValue)); + estimatedParquetSize += ParquetBufferValue.BIT_ENCODING_BYTE_LEN; + break; + case INT32: + int intVal = getInt32Value(value, primitiveType, insertRowsCurrIndex); + value = intVal; + stats.addIntValue(BigInteger.valueOf(intVal)); + estimatedParquetSize += 4; + break; + case INT64: + long longVal = + getInt64Value(value, primitiveType, defaultTimezone, insertRowsCurrIndex); + value = longVal; + stats.addIntValue(BigInteger.valueOf(longVal)); + estimatedParquetSize += 8; + break; + case FLOAT: + float floatVal = + (float) + DataValidationUtil.validateAndParseReal( + type.getName(), value, insertRowsCurrIndex); + value = floatVal; + stats.addRealValue((double) floatVal); + estimatedParquetSize += 4; + break; + case DOUBLE: + double doubleVal = + DataValidationUtil.validateAndParseReal(type.getName(), value, insertRowsCurrIndex); + value = doubleVal; + stats.addRealValue(doubleVal); + estimatedParquetSize += 8; + break; + case BINARY: + byte[] byteVal = getBinaryValue(value, primitiveType, stats, insertRowsCurrIndex); + value = byteVal; + estimatedParquetSize += + ParquetBufferValue.BYTE_ARRAY_LENGTH_ENCODING_BYTE_LEN + byteVal.length; + break; + case FIXED_LEN_BYTE_ARRAY: + byte[] fixedLenByteArrayVal = + getFixedLenByteArrayValue(value, primitiveType, stats, insertRowsCurrIndex); + value = fixedLenByteArrayVal; + estimatedParquetSize += + ParquetBufferValue.BYTE_ARRAY_LENGTH_ENCODING_BYTE_LEN + + fixedLenByteArrayVal.length; + break; + default: + throw new SFException( + ErrorCode.UNKNOWN_DATA_TYPE, + type.getLogicalTypeAnnotation(), + primitiveType.getPrimitiveTypeName()); + } + } else { + return getGroupValue( + value, + type.asGroupType(), + statsMap, + defaultTimezone, + insertRowsCurrIndex, + path, + isdDescendantsOfRepeatingGroup); } } @@ -115,7 +157,14 @@ static ParquetBufferValue parseColumnValueToParquet( throw new SFException( ErrorCode.INVALID_FORMAT_ROW, type.getName(), "Passed null to non nullable field"); } - stats.incCurrentNullCount(); + if (type.isPrimitive()) { + if (!statsMap.containsKey(path)) { + throw new SFException( + ErrorCode.INTERNAL_ERROR, + String.format("Stats not found for column: %s", type.getName())); + } + statsMap.get(path).incCurrentNullCount(); + } } return new ParquetBufferValue(value, estimatedParquetSize); @@ -298,4 +347,166 @@ private static int timeUnitToScale(LogicalTypeAnnotation.TimeUnit timeUnit) { ErrorCode.INTERNAL_ERROR, String.format("Unknown time unit: %s", timeUnit)); } } + + /** + * Parses a group value based on Parquet group logical type. 
+ * + * @param value value to parse + * @param type Parquet column type + * @param statsMap column stats map to update + * @param defaultTimezone default timezone to use for timestamp parsing + * @param insertRowsCurrIndex Used for logging the row of index given in insertRows API + * @param path dot path of the column + * @param isdDescendantsOfRepeatingGroup true if the column is a descendant of a repeating group, + * @return list of parsed values + */ + private static ParquetBufferValue getGroupValue( + Object value, + GroupType type, + Map statsMap, + ZoneId defaultTimezone, + final long insertRowsCurrIndex, + String path, + boolean isdDescendantsOfRepeatingGroup) { + LogicalTypeAnnotation logicalTypeAnnotation = type.getLogicalTypeAnnotation(); + if (logicalTypeAnnotation == null) { + return getStructValue( + value, + type, + statsMap, + defaultTimezone, + insertRowsCurrIndex, + path, + isdDescendantsOfRepeatingGroup); + } + if (logicalTypeAnnotation instanceof LogicalTypeAnnotation.ListLogicalTypeAnnotation) { + return get3LevelListValue(value, type, statsMap, defaultTimezone, insertRowsCurrIndex, path); + } + if (logicalTypeAnnotation instanceof LogicalTypeAnnotation.MapLogicalTypeAnnotation) { + return get3LevelMapValue( + value, + type, + statsMap, + defaultTimezone, + insertRowsCurrIndex, + path, + isdDescendantsOfRepeatingGroup); + } + throw new SFException( + ErrorCode.UNKNOWN_DATA_TYPE, logicalTypeAnnotation, type.getClass().getSimpleName()); + } + + /** + * Parses a struct value based on Parquet group logical type. The parsed value is a list of + * values, where each element represents a field in the group. For example, an input {@code + * {"field1": 1, "field2": 2}} will be parsed as {@code [1, 2]}. + */ + private static ParquetBufferValue getStructValue( + Object value, + GroupType type, + Map statsMap, + ZoneId defaultTimezone, + final long insertRowsCurrIndex, + String path, + boolean isdDescendantsOfRepeatingGroup) { + Map structVal = + DataValidationUtil.validateAndParseIcebergStruct( + type.getName(), value, insertRowsCurrIndex); + List listVal = new ArrayList<>(type.getFieldCount()); + float estimatedParquetSize = 0f; + for (int i = 0; i < type.getFieldCount(); i++) { + ParquetBufferValue parsedValue = + parseColumnValueToParquet( + structVal.getOrDefault(type.getFieldName(i), null), + type.getType(i), + statsMap, + defaultTimezone, + insertRowsCurrIndex, + path, + isdDescendantsOfRepeatingGroup); + listVal.add(parsedValue.getValue()); + estimatedParquetSize += parsedValue.getSize(); + } + return new ParquetBufferValue(listVal, estimatedParquetSize); + } + + /** + * Parses an iterable value based on Parquet 3-level list logical type. Please check Parquet + * Logical Types#Lists for more details. The parsed value is a list of lists, where each inner + * list represents a list of elements in the group. For example, an input {@code [1, 2, 3, 4]} + * will be parsed as {@code [[1], [2], [3], [4]]}. 
+ */ + private static ParquetBufferValue get3LevelListValue( + Object value, + GroupType type, + Map statsMap, + ZoneId defaultTimezone, + final long insertRowsCurrIndex, + String path) { + Iterable iterableVal = + DataValidationUtil.validateAndParseIcebergList(type.getName(), value, insertRowsCurrIndex); + List listVal = new ArrayList<>(); + final AtomicReference estimatedParquetSize = new AtomicReference<>(0f); + iterableVal.forEach( + element -> { + ParquetBufferValue parsedValue = + parseColumnValueToParquet( + element, + type.getType(0).asGroupType().getType(0), + statsMap, + defaultTimezone, + insertRowsCurrIndex, + path, + true); + listVal.add(Collections.singletonList(parsedValue.getValue())); + estimatedParquetSize.updateAndGet(sz -> sz + parsedValue.getSize()); + }); + return new ParquetBufferValue(listVal, estimatedParquetSize.get()); + } + + /** + * Parses a map value based on Parquet 3-level map logical type. Please check Parquet + * Logical Types#Maps for more details. The parsed value is a list of lists, where each inner + * list represents a key-value pair in the group. For example, an input {@code {"a": 1, "b": 2}} + * will be parsed as {@code [["a", 1], ["b", 2]]}. + */ + private static ParquetBufferValue get3LevelMapValue( + Object value, + GroupType type, + Map statsMap, + ZoneId defaultTimezone, + final long insertRowsCurrIndex, + String path, + boolean isdDescendantsOfRepeatingGroup) { + Map mapVal = + DataValidationUtil.validateAndParseIcebergMap(type.getName(), value, insertRowsCurrIndex); + List listVal = new ArrayList<>(); + final AtomicReference estimatedParquetSize = new AtomicReference<>(0f); + mapVal.forEach( + (k, v) -> { + ParquetBufferValue parsedKey = + parseColumnValueToParquet( + k, + type.getType(0).asGroupType().getType(0), + statsMap, + defaultTimezone, + insertRowsCurrIndex, + path, + true); + ParquetBufferValue parsedValue = + parseColumnValueToParquet( + v, + type.getType(0).asGroupType().getType(1), + statsMap, + defaultTimezone, + insertRowsCurrIndex, + path, + isdDescendantsOfRepeatingGroup); + listVal.add(Arrays.asList(parsedKey.getValue(), parsedValue.getValue())); + estimatedParquetSize.updateAndGet(sz -> sz + parsedKey.getSize() + parsedValue.getSize()); + }); + return new ParquetBufferValue(listVal, estimatedParquetSize.get()); + } } diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/ParquetBufferValue.java b/src/main/java/net/snowflake/ingest/streaming/internal/ParquetBufferValue.java index f89de0aa7..4602b0c9a 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/ParquetBufferValue.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/ParquetBufferValue.java @@ -9,6 +9,8 @@ class ParquetBufferValue { // Parquet uses BitPacking to encode boolean, hence 1 bit per value public static final float BIT_ENCODING_BYTE_LEN = 1.0f / 8; + public static final float REPETITION_LEVEL_ENCODING_BYTE_LEN = 2.0f / 8; + /** * On average parquet needs 2 bytes / 8 values for the RLE+bitpack encoded definition level. * diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/ParquetFlusher.java b/src/main/java/net/snowflake/ingest/streaming/internal/ParquetFlusher.java index 8865a88c3..cb2349e55 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/ParquetFlusher.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/ParquetFlusher.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022 Snowflake Computing Inc. All rights reserved. + * Copyright (c) 2022-2024 Snowflake Computing Inc. 
All rights reserved. */ package net.snowflake.ingest.streaming.internal; @@ -143,7 +143,8 @@ private SerializationResult serializeFromJavaObjects( rowCount, chunkEstimatedUncompressedSize, mergedData, - chunkMinMaxInsertTimeInMs); + chunkMinMaxInsertTimeInMs, + parquetWriter.getBlocksMetadata()); } private static void addFileIdToMetadata( diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/ParquetRowBuffer.java b/src/main/java/net/snowflake/ingest/streaming/internal/ParquetRowBuffer.java index b8054ad9f..8132a3798 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/ParquetRowBuffer.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/ParquetRowBuffer.java @@ -25,6 +25,7 @@ import net.snowflake.ingest.utils.Constants; import net.snowflake.ingest.utils.ErrorCode; import net.snowflake.ingest.utils.SFException; +import org.apache.parquet.column.ColumnDescriptor; import org.apache.parquet.schema.MessageType; import org.apache.parquet.schema.Type; @@ -91,20 +92,46 @@ public void setupSchema(List columns) { if (!column.getNullable()) { addNonNullableFieldName(column.getInternalName()); } - this.statsMap.put( - column.getInternalName(), - new RowBufferStats(column.getName(), column.getCollation(), column.getOrdinal())); - - if (onErrorOption == OpenChannelRequest.OnErrorOption.ABORT - || onErrorOption == OpenChannelRequest.OnErrorOption.SKIP_BATCH) { - this.tempStatsMap.put( + if (!clientBufferParameters.getIsIcebergMode()) { + this.statsMap.put( column.getInternalName(), - new RowBufferStats(column.getName(), column.getCollation(), column.getOrdinal())); + new RowBufferStats(column.getName(), column.getCollation(), column.getOrdinal(), 0)); + + if (onErrorOption == OpenChannelRequest.OnErrorOption.ABORT + || onErrorOption == OpenChannelRequest.OnErrorOption.SKIP_BATCH) { + this.tempStatsMap.put( + column.getInternalName(), + new RowBufferStats(column.getName(), column.getCollation(), column.getOrdinal(), 0)); + } } id++; } schema = new MessageType(PARQUET_MESSAGE_TYPE_NAME, parquetTypes); + if (clientBufferParameters.getIsIcebergMode()) { + int ordinal = 0; + String prevParentColumnName = ""; + for (ColumnDescriptor columnDescriptor : schema.getColumns()) { + String parentColumnName = columnDescriptor.getPath()[0]; + if (!parentColumnName.equals(prevParentColumnName)) { + ordinal++; + prevParentColumnName = parentColumnName; + } + String subColumnName = String.join(".", columnDescriptor.getPath()); + int fieldId = + parentColumnName.equals(subColumnName) + ? 0 + : columnDescriptor.getPrimitiveType().getId().intValue(); + RowBufferStats stats = new RowBufferStats(subColumnName, null, ordinal, fieldId); + this.statsMap.put(subColumnName, stats); + + if (onErrorOption == OpenChannelRequest.OnErrorOption.ABORT + || onErrorOption == OpenChannelRequest.OnErrorOption.SKIP_BATCH) { + this.tempStatsMap.put( + subColumnName, new RowBufferStats(subColumnName, null, ordinal, fieldId)); + } + } + } tempData.clear(); data.clear(); } @@ -161,6 +188,7 @@ private float addRow( // Create new empty stats just for the current row. 
Map forkedStatsMap = new HashMap<>(); + statsMap.forEach((columnName, stats) -> forkedStatsMap.put(columnName, stats.forkEmpty())); for (Map.Entry entry : row.entrySet()) { String key = entry.getKey(); @@ -168,18 +196,16 @@ private float addRow( String columnName = LiteralQuoteUtils.unquoteColumnName(key); ParquetColumn parquetColumn = fieldIndex.get(columnName); int colIndex = parquetColumn.index; - RowBufferStats forkedStats = statsMap.get(columnName).forkEmpty(); - forkedStatsMap.put(columnName, forkedStats); ColumnMetadata column = parquetColumn.columnMetadata; ParquetBufferValue valueWithSize = (clientBufferParameters.getIsIcebergMode() ? IcebergParquetValueParser.parseColumnValueToParquet( - value, parquetColumn.type, forkedStats, defaultTimezone, insertRowsCurrIndex) + value, parquetColumn.type, forkedStatsMap, defaultTimezone, insertRowsCurrIndex) : SnowflakeParquetValueParser.parseColumnValueToParquet( value, column, parquetColumn.type.asPrimitiveType().getPrimitiveTypeName(), - forkedStats, + forkedStatsMap.get(columnName), defaultTimezone, insertRowsCurrIndex, clientBufferParameters.isEnableNewJsonParsingLogic())); diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/RowBufferStats.java b/src/main/java/net/snowflake/ingest/streaming/internal/RowBufferStats.java index 395123f1f..5a1aa33b4 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/RowBufferStats.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/RowBufferStats.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Snowflake Computing Inc. All rights reserved. + * Copyright (c) 2021-2024 Snowflake Computing Inc. All rights reserved. */ package net.snowflake.ingest.streaming.internal; @@ -16,6 +16,7 @@ class RowBufferStats { private final int ordinal; + private final int fieldId; private byte[] currentMinStrValue; private byte[] currentMaxStrValue; private BigInteger currentMinIntValue; @@ -30,15 +31,17 @@ class RowBufferStats { private final String columnDisplayName; /** Creates empty stats */ - RowBufferStats(String columnDisplayName, String collationDefinitionString, int ordinal) { + RowBufferStats( + String columnDisplayName, String collationDefinitionString, int ordinal, int fieldId) { this.columnDisplayName = columnDisplayName; this.collationDefinitionString = collationDefinitionString; this.ordinal = ordinal; + this.fieldId = fieldId; reset(); } RowBufferStats(String columnDisplayName) { - this(columnDisplayName, null, -1); + this(columnDisplayName, null, -1, 0); } void reset() { @@ -55,7 +58,10 @@ void reset() { /** Create new statistics for the same column, with all calculated values set to empty */ RowBufferStats forkEmpty() { return new RowBufferStats( - this.getColumnDisplayName(), this.getCollationDefinitionString(), this.getOrdinal()); + this.getColumnDisplayName(), + this.getCollationDefinitionString(), + this.getOrdinal(), + this.getFieldId()); } // TODO performance test this vs in place update @@ -70,7 +76,10 @@ static RowBufferStats getCombinedStats(RowBufferStats left, RowBufferStats right } RowBufferStats combined = new RowBufferStats( - left.columnDisplayName, left.getCollationDefinitionString(), left.getOrdinal()); + left.columnDisplayName, + left.getCollationDefinitionString(), + left.getOrdinal(), + left.getFieldId()); if (left.currentMinIntValue != null) { combined.addIntValue(left.currentMinIntValue); @@ -217,6 +226,10 @@ public int getOrdinal() { return ordinal; } + int getFieldId() { + return fieldId; + } + /** * Compares two byte arrays lexicographically. 
If the two arrays share a common prefix then the * lexicographic comparison is the result of comparing two elements, as if by Byte.compare(byte, diff --git a/src/main/java/net/snowflake/ingest/utils/IcebergDataTypeParser.java b/src/main/java/net/snowflake/ingest/utils/IcebergDataTypeParser.java index 95e484b70..abb03cdef 100644 --- a/src/main/java/net/snowflake/ingest/utils/IcebergDataTypeParser.java +++ b/src/main/java/net/snowflake/ingest/utils/IcebergDataTypeParser.java @@ -65,11 +65,24 @@ public static org.apache.parquet.schema.Type parseIcebergDataTypeStringToParquet int id, String name) { Type icebergType = deserializeIcebergType(icebergDataType); - if (!icebergType.isPrimitiveType()) { - throw new IllegalArgumentException( - String.format("Snowflake supports only primitive Iceberg types, got '%s'", icebergType)); + if (icebergType.isPrimitiveType()) { + return typeToMessageType.primitive(icebergType.asPrimitiveType(), repetition, id, name); + } else { + switch (icebergType.typeId()) { + case LIST: + return typeToMessageType.list(icebergType.asListType(), repetition, id, name); + case MAP: + return typeToMessageType.map(icebergType.asMapType(), repetition, id, name); + case STRUCT: + return typeToMessageType.struct(icebergType.asStructType(), repetition, id, name); + default: + throw new SFException( + ErrorCode.INTERNAL_ERROR, + String.format( + "Cannot convert Iceberg column to parquet type, name=%s, dataType=%s", + name, icebergDataType)); + } } - return typeToMessageType.primitive(icebergType.asPrimitiveType(), repetition, id, name); } /** diff --git a/src/main/java/org/apache/parquet/hadoop/BdecParquetWriter.java b/src/main/java/org/apache/parquet/hadoop/BdecParquetWriter.java index 45c65a4ea..3d4c38ed1 100644 --- a/src/main/java/org/apache/parquet/hadoop/BdecParquetWriter.java +++ b/src/main/java/org/apache/parquet/hadoop/BdecParquetWriter.java @@ -131,6 +131,10 @@ public List getRowCountsFromFooter() { return blockRowCounts; } + public List getBlocksMetadata() { + return writer.getFooter().getBlocks(); + } + public void writeRow(List row) { try { writer.write(row); @@ -344,7 +348,19 @@ private void writeValues(List values, GroupType type) { "Unsupported column type: " + cols.get(i).asPrimitiveType()); } } else { - throw new ParquetEncodingException("Unsupported column type: " + cols.get(i)); + if (cols.get(i).isRepetition(Type.Repetition.REPEATED)) { + for (Object o : (List) val) { + recordConsumer.startGroup(); + if (o != null) { + writeValues((List) o, cols.get(i).asGroupType()); + } + recordConsumer.endGroup(); + } + } else { + recordConsumer.startGroup(); + writeValues((List) val, cols.get(i).asGroupType()); + recordConsumer.endGroup(); + } } recordConsumer.endField(fieldName, i); } diff --git a/src/test/java/net/snowflake/ingest/streaming/internal/BlobBuilderTest.java b/src/test/java/net/snowflake/ingest/streaming/internal/BlobBuilderTest.java index fff1fe53e..05db18923 100644 --- a/src/test/java/net/snowflake/ingest/streaming/internal/BlobBuilderTest.java +++ b/src/test/java/net/snowflake/ingest/streaming/internal/BlobBuilderTest.java @@ -96,7 +96,7 @@ private List> createChannelDataPerTable(int metada channelData.setRowCount(metadataRowCount); channelData.setMinMaxInsertTimeInMs(new Pair<>(2L, 3L)); - channelData.getColumnEps().putIfAbsent(columnName, new RowBufferStats(columnName, null, 1)); + channelData.getColumnEps().putIfAbsent(columnName, new RowBufferStats(columnName, null, 1, 0)); channelData.setChannelContext( new ChannelFlushContext("channel1", "DB", "SCHEMA", 
"TABLE", 1L, "enc", 1L)); return Collections.singletonList(channelData); diff --git a/src/test/java/net/snowflake/ingest/streaming/internal/FileColumnPropertiesTest.java b/src/test/java/net/snowflake/ingest/streaming/internal/FileColumnPropertiesTest.java index 63ba51abb..444c7cd8f 100644 --- a/src/test/java/net/snowflake/ingest/streaming/internal/FileColumnPropertiesTest.java +++ b/src/test/java/net/snowflake/ingest/streaming/internal/FileColumnPropertiesTest.java @@ -1,3 +1,7 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. + */ + package net.snowflake.ingest.streaming.internal; import org.junit.Assert; @@ -15,7 +19,7 @@ public static Object[] isIceberg() { @Test public void testFileColumnPropertiesConstructor() { // Test simple construction - RowBufferStats stats = new RowBufferStats("COL", null, 1); + RowBufferStats stats = new RowBufferStats("COL", null, 1, 0); stats.addStrValue("bcd"); stats.addStrValue("abcde"); FileColumnProperties props = new FileColumnProperties(stats, isIceberg); @@ -26,7 +30,7 @@ public void testFileColumnPropertiesConstructor() { Assert.assertNull(props.getMaxStrNonCollated()); // Test that truncation is performed - stats = new RowBufferStats("COL", null, 1); + stats = new RowBufferStats("COL", null, 1, 0); stats.addStrValue("aßßßßßßßßßßßßßßßß"); Assert.assertEquals(33, stats.getCurrentMinStrValue().length); props = new FileColumnProperties(stats, isIceberg); diff --git a/src/test/java/net/snowflake/ingest/streaming/internal/IcebergParquetValueParserTest.java b/src/test/java/net/snowflake/ingest/streaming/internal/IcebergParquetValueParserTest.java index a0b4caa1c..6ecf63960 100644 --- a/src/test/java/net/snowflake/ingest/streaming/internal/IcebergParquetValueParserTest.java +++ b/src/test/java/net/snowflake/ingest/streaming/internal/IcebergParquetValueParserTest.java @@ -7,26 +7,46 @@ import static java.time.ZoneOffset.UTC; import static net.snowflake.ingest.streaming.internal.ParquetBufferValue.BYTE_ARRAY_LENGTH_ENCODING_BYTE_LEN; import static net.snowflake.ingest.streaming.internal.ParquetBufferValue.DEFINITION_LEVEL_ENCODING_BYTE_LEN; +import static net.snowflake.ingest.streaming.internal.ParquetBufferValue.REPETITION_LEVEL_ENCODING_BYTE_LEN; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; import java.math.BigDecimal; import java.math.BigInteger; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import net.snowflake.ingest.utils.Pair; +import net.snowflake.ingest.utils.SFException; import org.apache.parquet.schema.LogicalTypeAnnotation; import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName; import org.apache.parquet.schema.Type; import org.apache.parquet.schema.Type.Repetition; import org.apache.parquet.schema.Types; +import org.junit.Assert; import org.junit.Test; public class IcebergParquetValueParserTest { + static ObjectMapper objectMapper = new ObjectMapper(); + @Test public void parseValueBoolean() { Type type = Types.primitive(PrimitiveTypeName.BOOLEAN, Repetition.OPTIONAL).named("BOOLEAN_COL"); RowBufferStats rowBufferStats = new RowBufferStats("BOOLEAN_COL"); + Map rowBufferStatsMap = + new HashMap() { + { + put("BOOLEAN_COL", rowBufferStats); + } + }; ParquetBufferValue pv = - IcebergParquetValueParser.parseColumnValueToParquet(true, type, rowBufferStats, UTC, 0); + IcebergParquetValueParser.parseColumnValueToParquet(true, type, 
rowBufferStatsMap, UTC, 0); ParquetValueParserAssertionBuilder.newBuilder() .parquetBufferValue(pv) .rowBufferStats(rowBufferStats) @@ -42,9 +62,15 @@ public void parseValueInt() { Type type = Types.primitive(PrimitiveTypeName.INT32, Repetition.OPTIONAL).named("INT_COL"); RowBufferStats rowBufferStats = new RowBufferStats("INT_COL"); + Map rowBufferStatsMap = + new HashMap() { + { + put("INT_COL", rowBufferStats); + } + }; ParquetBufferValue pv = IcebergParquetValueParser.parseColumnValueToParquet( - Integer.MAX_VALUE, type, rowBufferStats, UTC, 0); + Integer.MAX_VALUE, type, rowBufferStatsMap, UTC, 0); ParquetValueParserAssertionBuilder.newBuilder() .parquetBufferValue(pv) .rowBufferStats(rowBufferStats) @@ -63,9 +89,15 @@ public void parseValueDecimalToInt() { .named("DECIMAL_COL"); RowBufferStats rowBufferStats = new RowBufferStats("DECIMAL_COL"); + Map rowBufferStatsMap = + new HashMap() { + { + put("DECIMAL_COL", rowBufferStats); + } + }; ParquetBufferValue pv = IcebergParquetValueParser.parseColumnValueToParquet( - new BigDecimal("12345.6789"), type, rowBufferStats, UTC, 0); + new BigDecimal("12345.6789"), type, rowBufferStatsMap, UTC, 0); ParquetValueParserAssertionBuilder.newBuilder() .parquetBufferValue(pv) .rowBufferStats(rowBufferStats) @@ -84,9 +116,15 @@ public void parseValueDateToInt() { .named("DATE_COL"); RowBufferStats rowBufferStats = new RowBufferStats("DATE_COL"); + Map rowBufferStatsMap = + new HashMap() { + { + put("DATE_COL", rowBufferStats); + } + }; ParquetBufferValue pv = IcebergParquetValueParser.parseColumnValueToParquet( - "2024-01-01", type, rowBufferStats, UTC, 0); + "2024-01-01", type, rowBufferStatsMap, UTC, 0); ParquetValueParserAssertionBuilder.newBuilder() .parquetBufferValue(pv) .rowBufferStats(rowBufferStats) @@ -102,9 +140,15 @@ public void parseValueLong() { Type type = Types.primitive(PrimitiveTypeName.INT64, Repetition.OPTIONAL).named("LONG_COL"); RowBufferStats rowBufferStats = new RowBufferStats("LONG_COL"); + Map rowBufferStatsMap = + new HashMap() { + { + put("LONG_COL", rowBufferStats); + } + }; ParquetBufferValue pv = IcebergParquetValueParser.parseColumnValueToParquet( - Long.MAX_VALUE, type, rowBufferStats, UTC, 0); + Long.MAX_VALUE, type, rowBufferStatsMap, UTC, 0); ParquetValueParserAssertionBuilder.newBuilder() .parquetBufferValue(pv) .rowBufferStats(rowBufferStats) @@ -123,9 +167,15 @@ public void parseValueDecimalToLong() { .named("DECIMAL_COL"); RowBufferStats rowBufferStats = new RowBufferStats("DECIMAL_COL"); + Map rowBufferStatsMap = + new HashMap() { + { + put("DECIMAL_COL", rowBufferStats); + } + }; ParquetBufferValue pv = IcebergParquetValueParser.parseColumnValueToParquet( - new BigDecimal("123456789.123456789"), type, rowBufferStats, UTC, 0); + new BigDecimal("123456789.123456789"), type, rowBufferStatsMap, UTC, 0); ParquetValueParserAssertionBuilder.newBuilder() .parquetBufferValue(pv) .rowBufferStats(rowBufferStats) @@ -144,9 +194,15 @@ public void parseValueTimeToLong() { .named("TIME_COL"); RowBufferStats rowBufferStats = new RowBufferStats("TIME_COL"); + Map rowBufferStatsMap = + new HashMap() { + { + put("TIME_COL", rowBufferStats); + } + }; ParquetBufferValue pv = IcebergParquetValueParser.parseColumnValueToParquet( - "12:34:56.789", type, rowBufferStats, UTC, 0); + "12:34:56.789", type, rowBufferStatsMap, UTC, 0); ParquetValueParserAssertionBuilder.newBuilder() .parquetBufferValue(pv) .rowBufferStats(rowBufferStats) @@ -165,9 +221,15 @@ public void parseValueTimestampToLong() { .named("TIMESTAMP_COL"); RowBufferStats 
rowBufferStats = new RowBufferStats("TIMESTAMP_COL"); + Map rowBufferStatsMap = + new HashMap() { + { + put("TIMESTAMP_COL", rowBufferStats); + } + }; ParquetBufferValue pv = IcebergParquetValueParser.parseColumnValueToParquet( - "2024-01-01T12:34:56.789+08:00", type, rowBufferStats, UTC, 0); + "2024-01-01T12:34:56.789+08:00", type, rowBufferStatsMap, UTC, 0); ParquetValueParserAssertionBuilder.newBuilder() .parquetBufferValue(pv) .rowBufferStats(rowBufferStats) @@ -186,9 +248,15 @@ public void parseValueTimestampTZToLong() { .named("TIMESTAMP_TZ_COL"); RowBufferStats rowBufferStats = new RowBufferStats("TIMESTAMP_TZ_COL"); + Map rowBufferStatsMap = + new HashMap() { + { + put("TIMESTAMP_TZ_COL", rowBufferStats); + } + }; ParquetBufferValue pv = IcebergParquetValueParser.parseColumnValueToParquet( - "2024-01-01T12:34:56.789+08:00", type, rowBufferStats, UTC, 0); + "2024-01-01T12:34:56.789+08:00", type, rowBufferStatsMap, UTC, 0); ParquetValueParserAssertionBuilder.newBuilder() .parquetBufferValue(pv) .rowBufferStats(rowBufferStats) @@ -204,9 +272,15 @@ public void parseValueFloat() { Type type = Types.primitive(PrimitiveTypeName.FLOAT, Repetition.OPTIONAL).named("FLOAT_COL"); RowBufferStats rowBufferStats = new RowBufferStats("FLOAT_COL"); + Map rowBufferStatsMap = + new HashMap() { + { + put("FLOAT_COL", rowBufferStats); + } + }; ParquetBufferValue pv = IcebergParquetValueParser.parseColumnValueToParquet( - Float.MAX_VALUE, type, rowBufferStats, UTC, 0); + Float.MAX_VALUE, type, rowBufferStatsMap, UTC, 0); ParquetValueParserAssertionBuilder.newBuilder() .parquetBufferValue(pv) .rowBufferStats(rowBufferStats) @@ -222,9 +296,15 @@ public void parseValueDouble() { Type type = Types.primitive(PrimitiveTypeName.DOUBLE, Repetition.OPTIONAL).named("DOUBLE_COL"); RowBufferStats rowBufferStats = new RowBufferStats("DOUBLE_COL"); + Map rowBufferStatsMap = + new HashMap() { + { + put("DOUBLE_COL", rowBufferStats); + } + }; ParquetBufferValue pv = IcebergParquetValueParser.parseColumnValueToParquet( - Double.MAX_VALUE, type, rowBufferStats, UTC, 0); + Double.MAX_VALUE, type, rowBufferStatsMap, UTC, 0); ParquetValueParserAssertionBuilder.newBuilder() .parquetBufferValue(pv) .rowBufferStats(rowBufferStats) @@ -240,9 +320,15 @@ public void parseValueBinary() { Type type = Types.primitive(PrimitiveTypeName.BINARY, Repetition.OPTIONAL).named("BINARY_COL"); RowBufferStats rowBufferStats = new RowBufferStats("BINARY_COL"); + Map rowBufferStatsMap = + new HashMap() { + { + put("BINARY_COL", rowBufferStats); + } + }; byte[] value = "snowflake_to_the_moon".getBytes(); ParquetBufferValue pv = - IcebergParquetValueParser.parseColumnValueToParquet(value, type, rowBufferStats, UTC, 0); + IcebergParquetValueParser.parseColumnValueToParquet(value, type, rowBufferStatsMap, UTC, 0); ParquetValueParserAssertionBuilder.newBuilder() .parquetBufferValue(pv) .rowBufferStats(rowBufferStats) @@ -262,9 +348,15 @@ public void parseValueStringToBinary() { .named("BINARY_COL"); RowBufferStats rowBufferStats = new RowBufferStats("BINARY_COL"); + Map rowBufferStatsMap = + new HashMap() { + { + put("BINARY_COL", rowBufferStats); + } + }; String value = "snowflake_to_the_moon"; ParquetBufferValue pv = - IcebergParquetValueParser.parseColumnValueToParquet(value, type, rowBufferStats, UTC, 0); + IcebergParquetValueParser.parseColumnValueToParquet(value, type, rowBufferStatsMap, UTC, 0); ParquetValueParserAssertionBuilder.newBuilder() .parquetBufferValue(pv) .rowBufferStats(rowBufferStats) @@ -286,9 +378,15 @@ public void 
parseValueFixed() { .named("FIXED_COL"); RowBufferStats rowBufferStats = new RowBufferStats("FIXED_COL"); + Map rowBufferStatsMap = + new HashMap() { + { + put("FIXED_COL", rowBufferStats); + } + }; byte[] value = "snow".getBytes(); ParquetBufferValue pv = - IcebergParquetValueParser.parseColumnValueToParquet(value, type, rowBufferStats, UTC, 0); + IcebergParquetValueParser.parseColumnValueToParquet(value, type, rowBufferStatsMap, UTC, 0); ParquetValueParserAssertionBuilder.newBuilder() .parquetBufferValue(pv) .rowBufferStats(rowBufferStats) @@ -309,9 +407,15 @@ public void parseValueDecimalToFixed() { .named("FIXED_COL"); RowBufferStats rowBufferStats = new RowBufferStats("FIXED_COL"); + Map rowBufferStatsMap = + new HashMap() { + { + put("FIXED_COL", rowBufferStats); + } + }; BigDecimal value = new BigDecimal("1234567890.0123456789"); ParquetBufferValue pv = - IcebergParquetValueParser.parseColumnValueToParquet(value, type, rowBufferStats, UTC, 0); + IcebergParquetValueParser.parseColumnValueToParquet(value, type, rowBufferStatsMap, UTC, 0); ParquetValueParserAssertionBuilder.newBuilder() .parquetBufferValue(pv) .rowBufferStats(rowBufferStats) @@ -322,4 +426,338 @@ public void parseValueDecimalToFixed() { .expectedMinMax(value.unscaledValue()) .assertMatches(); } + + @Test + public void parseList() throws JsonProcessingException { + Type list = + Types.optionalList() + .element(Types.optional(PrimitiveTypeName.INT32).named("element")) + .named("LIST_COL"); + RowBufferStats rowBufferStats = new RowBufferStats("LIST_COL.element"); + Map rowBufferStatsMap = + new HashMap() { + { + put("LIST_COL.element", rowBufferStats); + } + }; + + IcebergParquetValueParser.parseColumnValueToParquet(null, list, rowBufferStatsMap, UTC, 0); + ParquetBufferValue pv = + IcebergParquetValueParser.parseColumnValueToParquet( + Arrays.asList(1, 2, 3, 4, 5), list, rowBufferStatsMap, UTC, 0); + ParquetValueParserAssertionBuilder.newBuilder() + .rowBufferStats(rowBufferStats) + .parquetBufferValue(pv) + .expectedValueClass(ArrayList.class) + .expectedParsedValue( + convertToArrayList( + objectMapper.readValue("[[1], [2], [3], [4], [5]]", ArrayList.class))) + .expectedSize( + (4.0f + REPETITION_LEVEL_ENCODING_BYTE_LEN + DEFINITION_LEVEL_ENCODING_BYTE_LEN) * 5) + .expectedMin(BigInteger.valueOf(1)) + .expectedMax(BigInteger.valueOf(5)) + .assertMatches(); + + /* Test required list */ + Type requiredList = + Types.requiredList() + .element(Types.optional(PrimitiveTypeName.INT32).named("element")) + .named("LIST_COL"); + Assert.assertThrows( + SFException.class, + () -> + IcebergParquetValueParser.parseColumnValueToParquet( + null, requiredList, rowBufferStatsMap, UTC, 0)); + pv = + IcebergParquetValueParser.parseColumnValueToParquet( + new ArrayList<>(), requiredList, rowBufferStatsMap, UTC, 0); + ParquetValueParserAssertionBuilder.newBuilder() + .rowBufferStats(rowBufferStats) + .parquetBufferValue(pv) + .expectedValueClass(ArrayList.class) + .expectedParsedValue(convertToArrayList(objectMapper.readValue("[]", ArrayList.class))) + .expectedSize(0) + .expectedMin(BigInteger.valueOf(1)) + .expectedMax(BigInteger.valueOf(5)) + .assertMatches(); + + /* Test required list with required elements */ + Type requiredElements = + Types.requiredList() + .element(Types.required(PrimitiveTypeName.INT32).named("element")) + .named("LIST_COL"); + Assert.assertThrows( + SFException.class, + () -> + IcebergParquetValueParser.parseColumnValueToParquet( + Collections.singletonList(null), requiredElements, rowBufferStatsMap, UTC, 0)); + 
} + + @Test + public void parseMap() throws JsonProcessingException { + Type map = + Types.optionalMap() + .key(Types.required(PrimitiveTypeName.INT32).named("key")) + .value(Types.optional(PrimitiveTypeName.INT32).named("value")) + .named("MAP_COL"); + RowBufferStats rowBufferKeyStats = new RowBufferStats("MAP_COL.key"); + RowBufferStats rowBufferValueStats = new RowBufferStats("MAP_COL.value"); + Map rowBufferStatsMap = + new HashMap() { + { + put("MAP_COL.key", rowBufferKeyStats); + put("MAP_COL.value", rowBufferValueStats); + } + }; + IcebergParquetValueParser.parseColumnValueToParquet(null, map, rowBufferStatsMap, UTC, 0); + ParquetBufferValue pv = + IcebergParquetValueParser.parseColumnValueToParquet( + new java.util.HashMap() { + { + put(1, 1); + put(2, 2); + } + }, + map, + rowBufferStatsMap, + UTC, + 0); + ParquetValueParserAssertionBuilder.newBuilder() + .rowBufferStats(rowBufferKeyStats) + .parquetBufferValue(pv) + .expectedValueClass(ArrayList.class) + .expectedParsedValue( + convertToArrayList(objectMapper.readValue("[[1, 1], [2, 2]]", ArrayList.class))) + .expectedSize( + (4.0f + REPETITION_LEVEL_ENCODING_BYTE_LEN + DEFINITION_LEVEL_ENCODING_BYTE_LEN) * 2 + + (4.0f + DEFINITION_LEVEL_ENCODING_BYTE_LEN) * 2) + .expectedMin(BigInteger.valueOf(1)) + .expectedMax(BigInteger.valueOf(2)) + .assertMatches(); + + /* Test required map */ + Type requiredMap = + Types.requiredMap() + .key(Types.required(PrimitiveTypeName.INT32).named("key")) + .value(Types.optional(PrimitiveTypeName.INT32).named("value")) + .named("MAP_COL"); + Assert.assertThrows( + SFException.class, + () -> + IcebergParquetValueParser.parseColumnValueToParquet( + null, requiredMap, rowBufferStatsMap, UTC, 0)); + pv = + IcebergParquetValueParser.parseColumnValueToParquet( + new java.util.HashMap(), requiredMap, rowBufferStatsMap, UTC, 0); + ParquetValueParserAssertionBuilder.newBuilder() + .rowBufferStats(rowBufferKeyStats) + .parquetBufferValue(pv) + .expectedValueClass(ArrayList.class) + .expectedParsedValue(convertToArrayList(objectMapper.readValue("[]", ArrayList.class))) + .expectedSize(0) + .expectedMin(BigInteger.valueOf(1)) + .expectedMax(BigInteger.valueOf(2)) + .assertMatches(); + + /* Test required map with required values */ + Type requiredValues = + Types.requiredMap() + .key(Types.required(PrimitiveTypeName.INT32).named("key")) + .value(Types.required(PrimitiveTypeName.INT32).named("value")) + .named("MAP_COL"); + Assert.assertThrows( + SFException.class, + () -> + IcebergParquetValueParser.parseColumnValueToParquet( + new java.util.HashMap() { + { + put(1, null); + } + }, + requiredValues, + rowBufferStatsMap, + UTC, + 0)); + } + + @Test + public void parseStruct() throws JsonProcessingException { + Type struct = + Types.optionalGroup() + .addField(Types.optional(PrimitiveTypeName.INT32).named("a")) + .addField( + Types.required(PrimitiveTypeName.BINARY) + .as(LogicalTypeAnnotation.stringType()) + .named("b")) + .named("STRUCT_COL"); + + RowBufferStats rowBufferAStats = new RowBufferStats("STRUCT_COL.a"); + RowBufferStats rowBufferBStats = new RowBufferStats("STRUCT_COL.b"); + Map rowBufferStatsMap = + new HashMap() { + { + put("STRUCT_COL.a", rowBufferAStats); + put("STRUCT_COL.b", rowBufferBStats); + } + }; + + IcebergParquetValueParser.parseColumnValueToParquet(null, struct, rowBufferStatsMap, UTC, 0); + Assert.assertThrows( + SFException.class, + () -> + IcebergParquetValueParser.parseColumnValueToParquet( + new java.util.HashMap() { + { + put("a", 1); + } + }, + struct, + rowBufferStatsMap, + UTC, 
+ 0)); + ParquetBufferValue pv = + IcebergParquetValueParser.parseColumnValueToParquet( + new java.util.HashMap() { + { + // a is null + put("b", "2"); + put("c", 1); // Ignored + } + }, + struct, + rowBufferStatsMap, + UTC, + 0); + ParquetValueParserAssertionBuilder.newBuilder() + .parquetBufferValue(pv) + .expectedValueClass(ArrayList.class) + .expectedParsedValue( + convertToArrayList(objectMapper.readValue("[null, \"2\"]", ArrayList.class))) + .expectedSize(1 + BYTE_ARRAY_LENGTH_ENCODING_BYTE_LEN + DEFINITION_LEVEL_ENCODING_BYTE_LEN) + .expectedMinMax(BigInteger.valueOf(1)) + .assertMatches(); + + /* Test required struct */ + Type requiredStruct = + Types.requiredGroup() + .addField(Types.optional(PrimitiveTypeName.INT32).named("a")) + .addField( + Types.optional(PrimitiveTypeName.BINARY) + .as(LogicalTypeAnnotation.stringType()) + .named("b")) + .named("STRUCT_COL"); + Assert.assertThrows( + SFException.class, + () -> + IcebergParquetValueParser.parseColumnValueToParquet( + null, requiredStruct, rowBufferStatsMap, UTC, 0)); + pv = + IcebergParquetValueParser.parseColumnValueToParquet( + new java.util.HashMap(), requiredStruct, rowBufferStatsMap, UTC, 0); + ParquetValueParserAssertionBuilder.newBuilder() + .parquetBufferValue(pv) + .expectedValueClass(ArrayList.class) + .expectedParsedValue( + convertToArrayList(objectMapper.readValue("[null, null]", ArrayList.class))) + .expectedSize(0) + .expectedMinMax(BigInteger.valueOf(1)) + .assertMatches(); + } + + @Test + public void parseNestedTypes() { + for (int depth = 1; depth <= 100; depth *= 10) { + Map rowBufferStatsMap = new HashMap<>(); + Type type = generateNestedTypeAndStats(depth, "a", rowBufferStatsMap, "a"); + Pair res = generateNestedValueAndReference(depth); + Object value = res.getFirst(); + List reference = (List) res.getSecond(); + ParquetBufferValue pv = + IcebergParquetValueParser.parseColumnValueToParquet( + value, type, rowBufferStatsMap, UTC, 0); + ParquetValueParserAssertionBuilder.newBuilder() + .parquetBufferValue(pv) + .expectedValueClass(ArrayList.class) + .expectedParsedValue(convertToArrayList(reference)) + .expectedSize( + (4.0f + REPETITION_LEVEL_ENCODING_BYTE_LEN + DEFINITION_LEVEL_ENCODING_BYTE_LEN) + * (depth / 3 + 1)) + .assertMatches(); + } + } + + private static Type generateNestedTypeAndStats( + int depth, String name, Map rowBufferStatsMap, String path) { + if (depth == 0) { + rowBufferStatsMap.put(path, new RowBufferStats(path)); + return Types.optional(PrimitiveTypeName.INT32).named(name); + } + switch (depth % 3) { + case 1: + return Types.optionalList() + .element( + generateNestedTypeAndStats( + depth - 1, "element", rowBufferStatsMap, path + ".element")) + .named(name); + case 2: + return Types.optionalGroup() + .addField(generateNestedTypeAndStats(depth - 1, "a", rowBufferStatsMap, path + ".a")) + .named(name); + case 0: + rowBufferStatsMap.put(path + ".key", new RowBufferStats(path + ".key")); + return Types.optionalMap() + .key(Types.required(PrimitiveTypeName.INT32).named("key")) + .value( + generateNestedTypeAndStats(depth - 1, "value", rowBufferStatsMap, path + ".value")) + .named(name); + } + return null; + } + + private static Pair generateNestedValueAndReference(int depth) { + if (depth == 0) { + return new Pair<>(1, 1); + } + Pair res = generateNestedValueAndReference(depth - 1); + Assert.assertNotNull(res); + switch (depth % 3) { + case 1: + return new Pair<>( + Collections.singletonList(res.getFirst()), + Collections.singletonList(Collections.singletonList(res.getSecond()))); + case 2: 
+ return new Pair<>( + new java.util.HashMap() { + { + put("a", res.getFirst()); + } + }, + Collections.singletonList(res.getSecond())); + case 0: + return new Pair<>( + new java.util.HashMap() { + { + put(1, res.getFirst()); + } + }, + Collections.singletonList(Arrays.asList(1, res.getSecond()))); + } + return null; + } + + private static ArrayList convertToArrayList(List list) { + ArrayList arrayList = new ArrayList<>(); + for (Object element : list) { + if (element instanceof List) { + // Recursively convert nested lists + arrayList.add(convertToArrayList((List) element)); + } else if (element instanceof String) { + // Convert string to byte array + arrayList.add(((String) element).getBytes()); + } else { + arrayList.add(element); + } + } + return arrayList; + } } diff --git a/src/test/java/net/snowflake/ingest/streaming/internal/ParquetTypeGeneratorTest.java b/src/test/java/net/snowflake/ingest/streaming/internal/ParquetTypeGeneratorTest.java index c83d339c1..99411be5f 100644 --- a/src/test/java/net/snowflake/ingest/streaming/internal/ParquetTypeGeneratorTest.java +++ b/src/test/java/net/snowflake/ingest/streaming/internal/ParquetTypeGeneratorTest.java @@ -693,6 +693,313 @@ public void buildFieldIcebergBinary() { .assertMatches(); } + @Test + public void buildFieldIcebergStruct() { + ColumnMetadata testCol = + createColumnMetadataBuilder() + .logicalType("") + .sourceIcebergDataType( + "{" + + " \"type\": \"struct\"," + + " \"fields\":" + + " [" + + " {" + + " \"id\": 1," + + " \"name\": \"id\"," + + " \"required\": true," + + " \"type\": \"string\"" + + " }," + + " {" + + " \"id\": 2," + + " \"name\": \"age\"," + + " \"required\": false," + + " \"type\": \"int\"" + + " }" + + " ]" + + "}") + .nullable(true) + .build(); + + ParquetTypeInfo typeInfo = ParquetTypeGenerator.generateColumnParquetTypeInfo(testCol, 0); + createParquetTypeInfoAssertionBuilder(false) + .typeInfo(typeInfo) + .expectedFieldName("TESTCOL") + .expectedLogicalTypeAnnotation(null) + .expectedRepetition(Type.Repetition.OPTIONAL) + .expectedColMetaData(null) + .expectedFieldCount(2) + .assertMatches(); + + ParquetTypeInfo firstFieldTypeInfo = + new ParquetTypeInfo(typeInfo.getParquetType().asGroupType().getType(0), null); + createParquetTypeInfoAssertionBuilder(false) + .typeInfo(firstFieldTypeInfo) + .expectedFieldName("id") + .expectedTypeLength(0) + .expectedPrimitiveTypeName(PrimitiveType.PrimitiveTypeName.BINARY) + .expectedLogicalTypeAnnotation(LogicalTypeAnnotation.stringType()) + .expectedRepetition(Type.Repetition.REQUIRED) + .expectedColMetaData(null) + .assertMatches(); + + ParquetTypeInfo secondFieldTypeInfo = + new ParquetTypeInfo(typeInfo.getParquetType().asGroupType().getType(1), null); + createParquetTypeInfoAssertionBuilder(false) + .typeInfo(secondFieldTypeInfo) + .expectedFieldName("age") + .expectedTypeLength(0) + .expectedPrimitiveTypeName(PrimitiveType.PrimitiveTypeName.INT32) + .expectedLogicalTypeAnnotation(null) + .expectedRepetition(Type.Repetition.OPTIONAL) + .expectedColMetaData(null) + .assertMatches(); + } + + @Test + public void buildFieldIcebergList() { + ColumnMetadata testCol = + createColumnMetadataBuilder() + .logicalType("") + .sourceIcebergDataType( + "{" + + " \"type\": \"list\"," + + " \"element\": \"int\"," + + " \"element-required\": true," + + " \"element-id\": 1" + + "}") + .nullable(true) + .build(); + + ParquetTypeInfo typeInfo = ParquetTypeGenerator.generateColumnParquetTypeInfo(testCol, 0); + createParquetTypeInfoAssertionBuilder(false) + .typeInfo(typeInfo) + 
.expectedFieldName("TESTCOL") + .expectedLogicalTypeAnnotation(LogicalTypeAnnotation.listType()) + .expectedRepetition(Type.Repetition.OPTIONAL) + .expectedColMetaData(null) + .expectedFieldCount(1) + .assertMatches(); + + ParquetTypeInfo elementTypeInfo = + new ParquetTypeInfo(typeInfo.getParquetType().asGroupType().getType(0), null); + createParquetTypeInfoAssertionBuilder(false) + .typeInfo(elementTypeInfo) + .expectedFieldName("list") + .expectedLogicalTypeAnnotation(null) + .expectedRepetition(Type.Repetition.REPEATED) + .expectedColMetaData(null) + .expectedFieldCount(1) + .assertMatches(); + + ParquetTypeInfo elementFieldTypeInfo = + new ParquetTypeInfo(elementTypeInfo.getParquetType().asGroupType().getType(0), null); + createParquetTypeInfoAssertionBuilder(false) + .typeInfo(elementFieldTypeInfo) + .expectedFieldName("element") + .expectedTypeLength(0) + .expectedPrimitiveTypeName(PrimitiveType.PrimitiveTypeName.INT32) + .expectedLogicalTypeAnnotation(null) + .expectedRepetition(Type.Repetition.REQUIRED) + .expectedColMetaData(null) + .assertMatches(); + } + + @Test + public void buildFieldIcebergMap() { + ColumnMetadata testCol = + createColumnMetadataBuilder() + .logicalType("") + .sourceIcebergDataType( + "{" + + " \"type\": \"map\"," + + " \"key\": \"string\"," + + " \"value\": \"int\"," + + " \"key-required\": true," + + " \"value-required\": false," + + " \"key-id\": 1," + + " \"value-id\": 2" + + "}") + .nullable(true) + .build(); + + ParquetTypeInfo typeInfo = ParquetTypeGenerator.generateColumnParquetTypeInfo(testCol, 0); + createParquetTypeInfoAssertionBuilder(false) + .typeInfo(typeInfo) + .expectedFieldName("TESTCOL") + .expectedLogicalTypeAnnotation(LogicalTypeAnnotation.mapType()) + .expectedRepetition(Type.Repetition.OPTIONAL) + .expectedColMetaData(null) + .expectedFieldCount(1) + .assertMatches(); + + ParquetTypeInfo mapTypeInfo = + new ParquetTypeInfo(typeInfo.getParquetType().asGroupType().getType(0), null); + createParquetTypeInfoAssertionBuilder(false) + .typeInfo(mapTypeInfo) + .expectedFieldName("key_value") + .expectedLogicalTypeAnnotation(null) + .expectedRepetition(Type.Repetition.REPEATED) + .expectedColMetaData(null) + .expectedFieldCount(2) + .assertMatches(); + + ParquetTypeInfo keyTypeInfo = + new ParquetTypeInfo(mapTypeInfo.getParquetType().asGroupType().getType(0), null); + createParquetTypeInfoAssertionBuilder(false) + .typeInfo(keyTypeInfo) + .expectedFieldName("key") + .expectedTypeLength(0) + .expectedPrimitiveTypeName(PrimitiveType.PrimitiveTypeName.BINARY) + .expectedLogicalTypeAnnotation(LogicalTypeAnnotation.stringType()) + .expectedRepetition(Type.Repetition.REQUIRED) + .expectedColMetaData(null) + .assertMatches(); + + ParquetTypeInfo valueTypeInfo = + new ParquetTypeInfo(mapTypeInfo.getParquetType().asGroupType().getType(1), null); + createParquetTypeInfoAssertionBuilder(false) + .typeInfo(valueTypeInfo) + .expectedFieldName("value") + .expectedTypeLength(0) + .expectedPrimitiveTypeName(PrimitiveType.PrimitiveTypeName.INT32) + .expectedLogicalTypeAnnotation(null) + .expectedRepetition(Type.Repetition.OPTIONAL) + .expectedColMetaData(null) + .assertMatches(); + } + + @Test + public void buildFieldIcebergNestedStructuredDataType() { + ColumnMetadata testCol = + createColumnMetadataBuilder() + .logicalType("") + .sourceIcebergDataType( + "{" + + " \"type\": \"map\"," + + " \"key\": \"string\"," + + " \"value\": {" + + " \"type\": \"list\"," + + " \"element\": {" + + " \"type\": \"struct\"," + + " \"fields\":" + + " [" + + " {" + + " \"id\": 
1," + + " \"name\": \"id\"," + + " \"required\": true," + + " \"type\": \"string\"" + + " }," + + " {" + + " \"id\": 2," + + " \"name\": \"age\"," + + " \"required\": false," + + " \"type\": \"int\"" + + " }" + + " ]" + + " }," + + " \"element-required\": true," + + " \"element-id\": 1" + + " }," + + " \"key-required\": true," + + " \"value-required\": false," + + " \"key-id\": 1," + + " \"value-id\": 2" + + "}") + .nullable(true) + .build(); + + ParquetTypeInfo typeInfo = ParquetTypeGenerator.generateColumnParquetTypeInfo(testCol, 0); + + createParquetTypeInfoAssertionBuilder(false) + .typeInfo(typeInfo) + .expectedFieldName("TESTCOL") + .expectedLogicalTypeAnnotation(LogicalTypeAnnotation.mapType()) + .expectedRepetition(Type.Repetition.OPTIONAL) + .expectedColMetaData(null) + .expectedFieldCount(1) + .assertMatches(); + + ParquetTypeInfo mapTypeInfo = + new ParquetTypeInfo(typeInfo.getParquetType().asGroupType().getType(0), null); + createParquetTypeInfoAssertionBuilder(false) + .typeInfo(mapTypeInfo) + .expectedFieldName("key_value") + .expectedLogicalTypeAnnotation(null) + .expectedRepetition(Type.Repetition.REPEATED) + .expectedColMetaData(null) + .expectedFieldCount(2) + .assertMatches(); + + ParquetTypeInfo keyTypeInfo = + new ParquetTypeInfo(mapTypeInfo.getParquetType().asGroupType().getType(0), null); + createParquetTypeInfoAssertionBuilder(false) + .typeInfo(keyTypeInfo) + .expectedFieldName("key") + .expectedTypeLength(0) + .expectedPrimitiveTypeName(PrimitiveType.PrimitiveTypeName.BINARY) + .expectedLogicalTypeAnnotation(LogicalTypeAnnotation.stringType()) + .expectedRepetition(Type.Repetition.REQUIRED) + .expectedColMetaData(null) + .assertMatches(); + + ParquetTypeInfo valueTypeInfo = + new ParquetTypeInfo(mapTypeInfo.getParquetType().asGroupType().getType(1), null); + createParquetTypeInfoAssertionBuilder(false) + .typeInfo(valueTypeInfo) + .expectedFieldName("value") + .expectedLogicalTypeAnnotation(LogicalTypeAnnotation.listType()) + .expectedRepetition(Type.Repetition.OPTIONAL) + .expectedColMetaData(null) + .expectedFieldCount(1) + .assertMatches(); + + ParquetTypeInfo elementTypeInfo = + new ParquetTypeInfo(valueTypeInfo.getParquetType().asGroupType().getType(0), null); + createParquetTypeInfoAssertionBuilder(false) + .typeInfo(elementTypeInfo) + .expectedFieldName("list") + .expectedLogicalTypeAnnotation(null) + .expectedRepetition(Type.Repetition.REPEATED) + .expectedColMetaData(null) + .expectedFieldCount(1) + .assertMatches(); + + ParquetTypeInfo elementFieldTypeInfo = + new ParquetTypeInfo(elementTypeInfo.getParquetType().asGroupType().getType(0), null); + createParquetTypeInfoAssertionBuilder(false) + .typeInfo(elementFieldTypeInfo) + .expectedFieldName("element") + .expectedLogicalTypeAnnotation(null) + .expectedRepetition(Type.Repetition.REQUIRED) + .expectedColMetaData(null) + .expectedFieldCount(2) + .assertMatches(); + + ParquetTypeInfo firstFieldTypeInfo = + new ParquetTypeInfo(elementFieldTypeInfo.getParquetType().asGroupType().getType(0), null); + createParquetTypeInfoAssertionBuilder(false) + .typeInfo(firstFieldTypeInfo) + .expectedFieldName("id") + .expectedTypeLength(0) + .expectedPrimitiveTypeName(PrimitiveType.PrimitiveTypeName.BINARY) + .expectedLogicalTypeAnnotation(LogicalTypeAnnotation.stringType()) + .expectedRepetition(Type.Repetition.REQUIRED) + .expectedColMetaData(null) + .assertMatches(); + + ParquetTypeInfo secondFieldTypeInfo = + new ParquetTypeInfo(elementFieldTypeInfo.getParquetType().asGroupType().getType(1), null); + 
createParquetTypeInfoAssertionBuilder(false) + .typeInfo(secondFieldTypeInfo) + .expectedFieldName("age") + .expectedTypeLength(0) + .expectedPrimitiveTypeName(PrimitiveType.PrimitiveTypeName.INT32) + .expectedLogicalTypeAnnotation(null) + .expectedRepetition(Type.Repetition.OPTIONAL) + .expectedColMetaData(null) + .assertMatches(); + } + /** Builder that helps to assert parquet type info */ private static class ParquetTypeInfoAssertionBuilder { private String fieldName; diff --git a/src/test/java/net/snowflake/ingest/streaming/internal/ParquetValueParserAssertionBuilder.java b/src/test/java/net/snowflake/ingest/streaming/internal/ParquetValueParserAssertionBuilder.java index 8480311fa..1aef68663 100644 --- a/src/test/java/net/snowflake/ingest/streaming/internal/ParquetValueParserAssertionBuilder.java +++ b/src/test/java/net/snowflake/ingest/streaming/internal/ParquetValueParserAssertionBuilder.java @@ -7,6 +7,7 @@ import java.math.BigDecimal; import java.math.BigInteger; import java.nio.charset.StandardCharsets; +import java.util.List; import org.junit.Assert; /** Builder that helps to assert parsing of values to parquet types */ @@ -16,7 +17,8 @@ class ParquetValueParserAssertionBuilder { private Class valueClass; private Object value; private float size; - private Object minMaxStat; + private Object minStat; + private Object maxStat; private long currentNullCount; static ParquetValueParserAssertionBuilder newBuilder() { @@ -50,7 +52,18 @@ ParquetValueParserAssertionBuilder expectedSize(float size) { } public ParquetValueParserAssertionBuilder expectedMinMax(Object minMaxStat) { - this.minMaxStat = minMaxStat; + this.minStat = minMaxStat; + this.maxStat = minMaxStat; + return this; + } + + public ParquetValueParserAssertionBuilder expectedMin(Object minStat) { + this.minStat = minStat; + return this; + } + + public ParquetValueParserAssertionBuilder expectedMax(Object maxStat) { + this.maxStat = maxStat; return this; } @@ -64,41 +77,64 @@ void assertMatches() { if (valueClass.equals(byte[].class)) { Assert.assertArrayEquals((byte[]) value, (byte[]) parquetBufferValue.getValue()); } else { - Assert.assertEquals(value, parquetBufferValue.getValue()); + assertValueEquals(value, parquetBufferValue.getValue()); } Assert.assertEquals(size, parquetBufferValue.getSize(), 0); - if (minMaxStat instanceof BigInteger) { - Assert.assertEquals(minMaxStat, rowBufferStats.getCurrentMinIntValue()); - Assert.assertEquals(minMaxStat, rowBufferStats.getCurrentMaxIntValue()); - return; - } else if (minMaxStat instanceof byte[]) { - Assert.assertArrayEquals((byte[]) minMaxStat, rowBufferStats.getCurrentMinStrValue()); - Assert.assertArrayEquals((byte[]) minMaxStat, rowBufferStats.getCurrentMaxStrValue()); - return; - } else if (valueClass.equals(String.class)) { - // String can have null min/max stats for variant data types - Object min = - rowBufferStats.getCurrentMinStrValue() != null - ? new String(rowBufferStats.getCurrentMinStrValue(), StandardCharsets.UTF_8) - : rowBufferStats.getCurrentMinStrValue(); - Object max = - rowBufferStats.getCurrentMaxStrValue() != null - ? 
new String(rowBufferStats.getCurrentMaxStrValue(), StandardCharsets.UTF_8)
-              : rowBufferStats.getCurrentMaxStrValue();
-      Assert.assertEquals(minMaxStat, min);
-      Assert.assertEquals(minMaxStat, max);
-      return;
-    } else if (minMaxStat instanceof Double || minMaxStat instanceof BigDecimal) {
-      Assert.assertEquals(minMaxStat, rowBufferStats.getCurrentMinRealValue());
-      Assert.assertEquals(minMaxStat, rowBufferStats.getCurrentMaxRealValue());
-      return;
+    if (rowBufferStats != null) {
+      if (minStat instanceof BigInteger) {
+        Assert.assertEquals(minStat, rowBufferStats.getCurrentMinIntValue());
+        Assert.assertEquals(maxStat, rowBufferStats.getCurrentMaxIntValue());
+        return;
+      } else if (minStat instanceof byte[]) {
+        Assert.assertArrayEquals((byte[]) minStat, rowBufferStats.getCurrentMinStrValue());
+        Assert.assertArrayEquals((byte[]) maxStat, rowBufferStats.getCurrentMaxStrValue());
+        return;
+      } else if (valueClass.equals(String.class)) {
+        // String can have null min/max stats for variant data types
+        Object min =
+            rowBufferStats.getCurrentMinStrValue() != null
+                ? new String(rowBufferStats.getCurrentMinStrValue(), StandardCharsets.UTF_8)
+                : rowBufferStats.getCurrentMinStrValue();
+        Object max =
+            rowBufferStats.getCurrentMaxStrValue() != null
+                ? new String(rowBufferStats.getCurrentMaxStrValue(), StandardCharsets.UTF_8)
+                : rowBufferStats.getCurrentMaxStrValue();
+        Assert.assertEquals(minStat, min);
+        Assert.assertEquals(maxStat, max);
+        return;
+      } else if (minStat instanceof Double || minStat instanceof BigDecimal) {
+        Assert.assertEquals(minStat, rowBufferStats.getCurrentMinRealValue());
+        Assert.assertEquals(maxStat, rowBufferStats.getCurrentMaxRealValue());
+        return;
+      }
+      throw new IllegalArgumentException(
+          String.format("Unknown data type for min stat: %s", minStat.getClass()));
     }
-    throw new IllegalArgumentException(
-        String.format("Unknown data type for min stat: %s", minMaxStat.getClass()));
   }
 
   void assertNull() {
     Assert.assertNull(parquetBufferValue.getValue());
     Assert.assertEquals(currentNullCount, rowBufferStats.getCurrentNullCount());
   }
+
+  /**
+   * Recursively asserts equality of expected and actual values, descending into nested lists and
+   * comparing byte arrays by content.
+   */
+  void assertValueEquals(Object expectedValue, Object actualValue) {
+    if (expectedValue == null) {
+      Assert.assertNull(actualValue);
+      return;
+    }
+    if (expectedValue instanceof List) {
+      Assert.assertTrue(actualValue instanceof List);
+      List<?> expectedList = (List<?>) expectedValue;
+      List<?> actualList = (List<?>) actualValue;
+      Assert.assertEquals(expectedList.size(), actualList.size());
+      for (int i = 0; i < expectedList.size(); i++) {
+        assertValueEquals(expectedList.get(i), actualList.get(i));
+      }
+    } else if (expectedValue.getClass().equals(byte[].class)) {
+      Assert.assertEquals(byte[].class, actualValue.getClass());
+      Assert.assertArrayEquals((byte[]) expectedValue, (byte[]) actualValue);
+    } else {
+      Assert.assertEquals(expectedValue, actualValue);
+    }
+  }
 }
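
Usage note (outside the diff): taken together, the new validators and parsers define which Java values a caller may pass for Iceberg structured columns: a Map whose keys are all Strings for STRUCT (a non-String key fails with INVALID_FORMAT_ROW), any Iterable for LIST, and any Map for MAP, with extra struct keys ignored and missing optional fields written as null (see the struct test above). The sketch below shows a client-side row these checks would accept. It is illustrative only: the channel variable, table schema, and offset token are hypothetical, and only the public SnowflakeStreamingIngestChannel.insertRow(Map, String) entry point is assumed.

import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import net.snowflake.ingest.streaming.SnowflakeStreamingIngestChannel;

public class IcebergStructuredRowSketch {

  /**
   * Builds and ingests one row for a hypothetical Iceberg table with columns
   * person STRUCT(id STRING, age INT), tags LIST(STRING), and scores MAP(STRING, INT).
   */
  static void ingestOneRow(SnowflakeStreamingIngestChannel channel) {
    Map<String, Object> row = new HashMap<>();

    // STRUCT: must be a Map with String keys; a non-String key would be rejected
    // by validateAndParseIcebergStruct with INVALID_FORMAT_ROW.
    Map<String, Object> person = new HashMap<>();
    person.put("id", "u-42");
    person.put("age", 30);
    row.put("person", person);

    // LIST: any Iterable is accepted by validateAndParseIcebergList.
    row.put("tags", Arrays.asList("streaming", "iceberg"));

    // MAP: any Map is accepted by validateAndParseIcebergMap.
    Map<String, Integer> scores = new HashMap<>();
    scores.put("math", 95);
    row.put("scores", scores);

    channel.insertRow(row, "offset-1"); // offset token is arbitrary in this sketch
  }
}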