From a1ab1074f657d8ed515a0b7fc90847f2f66c3955 Mon Sep 17 00:00:00 2001
From: Alec Huang
Date: Wed, 25 Sep 2024 18:18:14 -0700
Subject: [PATCH] Address comments & add tests

---
 .../internal/FileColumnProperties.java        |  8 +-
 .../internal/IcebergParquetValueParser.java   | 84 ++++++++++---------
 .../streaming/internal/ParquetRowBuffer.java  |  7 +-
 .../streaming/internal/RowBufferStats.java    | 10 +--
 .../streaming/internal/BlobBuilderTest.java   |  4 +-
 .../internal/DataValidationUtilTest.java      | 30 +++++++
 .../internal/FileColumnPropertiesTest.java    |  5 +-
 7 files changed, 95 insertions(+), 53 deletions(-)

diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/FileColumnProperties.java b/src/main/java/net/snowflake/ingest/streaming/internal/FileColumnProperties.java
index eef55cf94..3a8dbc2b6 100644
--- a/src/main/java/net/snowflake/ingest/streaming/internal/FileColumnProperties.java
+++ b/src/main/java/net/snowflake/ingest/streaming/internal/FileColumnProperties.java
@@ -14,7 +14,7 @@
 /** Audit register endpoint/FileColumnPropertyDTO property list. */
 class FileColumnProperties {
   private int columnOrdinal;
-  private int fieldId;
+  private Integer fieldId;
 
   private String minStrValue;
 
   private String maxStrValue;
@@ -101,12 +101,12 @@ public void setColumnOrdinal(int columnOrdinal) {
   }
 
   @JsonProperty("fieldId")
-  @JsonInclude(JsonInclude.Include.NON_DEFAULT)
-  public int getFieldId() {
+  @JsonInclude(JsonInclude.Include.NON_NULL)
+  public Integer getFieldId() {
     return fieldId;
   }
 
-  public void setFieldId(int fieldId) {
+  public void setFieldId(Integer fieldId) {
     this.fieldId = fieldId;
   }
 
diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/IcebergParquetValueParser.java b/src/main/java/net/snowflake/ingest/streaming/internal/IcebergParquetValueParser.java
index 721731a84..93b62d7c4 100644
--- a/src/main/java/net/snowflake/ingest/streaming/internal/IcebergParquetValueParser.java
+++ b/src/main/java/net/snowflake/ingest/streaming/internal/IcebergParquetValueParser.java
@@ -94,50 +94,47 @@ private static ParquetBufferValue parseColumnValueToParquet(
       switch (primitiveType.getPrimitiveTypeName()) {
         case BOOLEAN:
           int intValue =
-              DataValidationUtil.validateAndParseBoolean(
-                  type.getName(), value, insertRowsCurrIndex);
+              DataValidationUtil.validateAndParseBoolean(path, value, insertRowsCurrIndex);
           value = intValue > 0;
           stats.addIntValue(BigInteger.valueOf(intValue));
           estimatedParquetSize += ParquetBufferValue.BIT_ENCODING_BYTE_LEN;
           break;
         case INT32:
-          int intVal = getInt32Value(value, primitiveType, insertRowsCurrIndex);
+          int intVal = getInt32Value(value, primitiveType, path, insertRowsCurrIndex);
           value = intVal;
           stats.addIntValue(BigInteger.valueOf(intVal));
           estimatedParquetSize += 4;
           break;
         case INT64:
           long longVal =
-              getInt64Value(value, primitiveType, defaultTimezone, insertRowsCurrIndex);
+              getInt64Value(value, primitiveType, defaultTimezone, path, insertRowsCurrIndex);
           value = longVal;
           stats.addIntValue(BigInteger.valueOf(longVal));
           estimatedParquetSize += 8;
           break;
         case FLOAT:
           float floatVal =
-              (float)
-                  DataValidationUtil.validateAndParseReal(
-                      type.getName(), value, insertRowsCurrIndex);
+              (float) DataValidationUtil.validateAndParseReal(path, value, insertRowsCurrIndex);
           value = floatVal;
           stats.addRealValue((double) floatVal);
           estimatedParquetSize += 4;
           break;
         case DOUBLE:
           double doubleVal =
-              DataValidationUtil.validateAndParseReal(type.getName(), value, insertRowsCurrIndex);
+              DataValidationUtil.validateAndParseReal(path, value, insertRowsCurrIndex);
           value = doubleVal;
           stats.addRealValue(doubleVal);
           estimatedParquetSize += 8;
           break;
         case BINARY:
-          byte[] byteVal = getBinaryValue(value, primitiveType, stats, insertRowsCurrIndex);
+          byte[] byteVal = getBinaryValue(value, primitiveType, stats, path, insertRowsCurrIndex);
           value = byteVal;
           estimatedParquetSize +=
               ParquetBufferValue.BYTE_ARRAY_LENGTH_ENCODING_BYTE_LEN + byteVal.length;
           break;
         case FIXED_LEN_BYTE_ARRAY:
           byte[] fixedLenByteArrayVal =
-              getFixedLenByteArrayValue(value, primitiveType, stats, insertRowsCurrIndex);
+              getFixedLenByteArrayValue(value, primitiveType, stats, path, insertRowsCurrIndex);
           value = fixedLenByteArrayVal;
           estimatedParquetSize +=
               ParquetBufferValue.BYTE_ARRAY_LENGTH_ENCODING_BYTE_LEN
@@ -164,7 +161,7 @@ private static ParquetBufferValue parseColumnValueToParquet(
     if (value == null) {
       if (type.isRepetition(Repetition.REQUIRED)) {
         throw new SFException(
-            ErrorCode.INVALID_FORMAT_ROW, type.getName(), "Passed null to non nullable field");
+            ErrorCode.INVALID_FORMAT_ROW, path, "Passed null to non nullable field");
       }
       if (type.isPrimitive()) {
         statsMap.get(path).incCurrentNullCount();
@@ -179,21 +176,21 @@
    *
    * @param value column value provided by user in a row
    * @param type Parquet column type
+   * @param path column path, used for logging
    * @param insertRowsCurrIndex Used for logging the row of index given in insertRows API
    * @return parsed int32 value
    */
   private static int getInt32Value(
-      Object value, PrimitiveType type, final long insertRowsCurrIndex) {
+      Object value, PrimitiveType type, String path, final long insertRowsCurrIndex) {
     LogicalTypeAnnotation logicalTypeAnnotation = type.getLogicalTypeAnnotation();
     if (logicalTypeAnnotation == null) {
-      return DataValidationUtil.validateAndParseIcebergInt(
-          type.getName(), value, insertRowsCurrIndex);
+      return DataValidationUtil.validateAndParseIcebergInt(path, value, insertRowsCurrIndex);
     }
     if (logicalTypeAnnotation instanceof DecimalLogicalTypeAnnotation) {
-      return getDecimalValue(value, type, insertRowsCurrIndex).unscaledValue().intValue();
+      return getDecimalValue(value, type, path, insertRowsCurrIndex).unscaledValue().intValue();
     }
     if (logicalTypeAnnotation instanceof DateLogicalTypeAnnotation) {
-      return DataValidationUtil.validateAndParseDate(type.getName(), value, insertRowsCurrIndex);
+      return DataValidationUtil.validateAndParseDate(path, value, insertRowsCurrIndex);
     }
     throw new SFException(
         ErrorCode.UNKNOWN_DATA_TYPE, logicalTypeAnnotation, type.getPrimitiveTypeName());
@@ -204,22 +201,26 @@ private static int getInt32Value(
    *
    * @param value column value provided by user in a row
    * @param type Parquet column type
+   * @param path column path, used for logging
    * @param insertRowsCurrIndex Used for logging the row of index given in insertRows API
    * @return parsed int64 value
    */
   private static long getInt64Value(
-      Object value, PrimitiveType type, ZoneId defaultTimezone, final long insertRowsCurrIndex) {
+      Object value,
+      PrimitiveType type,
+      ZoneId defaultTimezone,
+      String path,
+      final long insertRowsCurrIndex) {
     LogicalTypeAnnotation logicalTypeAnnotation = type.getLogicalTypeAnnotation();
     if (logicalTypeAnnotation == null) {
-      return DataValidationUtil.validateAndParseIcebergLong(
-          type.getName(), value, insertRowsCurrIndex);
+      return DataValidationUtil.validateAndParseIcebergLong(path, value, insertRowsCurrIndex);
     }
     if (logicalTypeAnnotation instanceof DecimalLogicalTypeAnnotation) {
-      return getDecimalValue(value, type, insertRowsCurrIndex).unscaledValue().longValue();
+      return getDecimalValue(value, type, path, insertRowsCurrIndex).unscaledValue().longValue();
     }
     if (logicalTypeAnnotation instanceof TimeLogicalTypeAnnotation) {
       return DataValidationUtil.validateAndParseTime(
-          type.getName(),
+          path,
           value,
           timeUnitToScale(((TimeLogicalTypeAnnotation) logicalTypeAnnotation).getUnit()),
           insertRowsCurrIndex)
@@ -250,29 +251,28 @@ private static long getInt64Value(
    * @param value value to parse
    * @param type Parquet column type
    * @param stats column stats to update
+   * @param path column path, used for logging
    * @param insertRowsCurrIndex Used for logging the row of index given in insertRows API
    * @return string representation
    */
   private static byte[] getBinaryValue(
-      Object value, PrimitiveType type, RowBufferStats stats, final long insertRowsCurrIndex) {
+      Object value,
+      PrimitiveType type,
+      RowBufferStats stats,
+      String path,
+      final long insertRowsCurrIndex) {
     LogicalTypeAnnotation logicalTypeAnnotation = type.getLogicalTypeAnnotation();
     if (logicalTypeAnnotation == null) {
       byte[] bytes =
           DataValidationUtil.validateAndParseBinary(
-              type.getName(),
-              value,
-              Optional.of(Constants.BINARY_COLUMN_MAX_SIZE),
-              insertRowsCurrIndex);
+              path, value, Optional.of(Constants.BINARY_COLUMN_MAX_SIZE), insertRowsCurrIndex);
       stats.addBinaryValue(bytes);
       return bytes;
     }
     if (logicalTypeAnnotation instanceof StringLogicalTypeAnnotation) {
       String string =
           DataValidationUtil.validateAndParseString(
-              type.getName(),
-              value,
-              Optional.of(Constants.VARCHAR_COLUMN_MAX_SIZE),
-              insertRowsCurrIndex);
+              path, value, Optional.of(Constants.VARCHAR_COLUMN_MAX_SIZE), insertRowsCurrIndex);
       stats.addStrValue(string);
       return string.getBytes(StandardCharsets.UTF_8);
     }
@@ -286,22 +286,28 @@ private static byte[] getBinaryValue(
    * @param value value to parse
    * @param type Parquet column type
    * @param stats column stats to update
+   * @param path column path, used for logging
    * @param insertRowsCurrIndex Used for logging the row of index given in insertRows API
    * @return string representation
    */
   private static byte[] getFixedLenByteArrayValue(
-      Object value, PrimitiveType type, RowBufferStats stats, final long insertRowsCurrIndex) {
+      Object value,
+      PrimitiveType type,
+      RowBufferStats stats,
+      String path,
+      final long insertRowsCurrIndex) {
     LogicalTypeAnnotation logicalTypeAnnotation = type.getLogicalTypeAnnotation();
     int length = type.getTypeLength();
     byte[] bytes = null;
     if (logicalTypeAnnotation == null) {
       bytes =
           DataValidationUtil.validateAndParseBinary(
-              type.getName(), value, Optional.of(length), insertRowsCurrIndex);
+              path, value, Optional.of(length), insertRowsCurrIndex);
       stats.addBinaryValue(bytes);
     }
     if (logicalTypeAnnotation instanceof DecimalLogicalTypeAnnotation) {
-      BigInteger bigIntegerVal = getDecimalValue(value, type, insertRowsCurrIndex).unscaledValue();
+      BigInteger bigIntegerVal =
+          getDecimalValue(value, type, path, insertRowsCurrIndex).unscaledValue();
       stats.addIntValue(bigIntegerVal);
       bytes = bigIntegerVal.toByteArray();
       if (bytes.length < length) {
@@ -324,15 +330,16 @@
    *
    * @param value value to parse
    * @param type Parquet column type
+   * @param path column path, used for logging
    * @param insertRowsCurrIndex Used for logging the row of index given in insertRows API
    * @return BigDecimal representation
    */
   private static BigDecimal getDecimalValue(
-      Object value, PrimitiveType type, final long insertRowsCurrIndex) {
+      Object value, PrimitiveType type, String path, final long insertRowsCurrIndex) {
     int scale = ((DecimalLogicalTypeAnnotation) type.getLogicalTypeAnnotation()).getScale();
     int precision = ((DecimalLogicalTypeAnnotation) type.getLogicalTypeAnnotation()).getPrecision();
     BigDecimal bigDecimalValue =
-        DataValidationUtil.validateAndParseBigDecimal(type.getName(), value, insertRowsCurrIndex);
+        DataValidationUtil.validateAndParseBigDecimal(path, value, insertRowsCurrIndex);
     bigDecimalValue = bigDecimalValue.setScale(scale, RoundingMode.HALF_UP);
     DataValidationUtil.checkValueInRange(bigDecimalValue, scale, precision, insertRowsCurrIndex);
     return bigDecimalValue;
   }
@@ -414,8 +421,7 @@ private static ParquetBufferValue getStructValue(
       String path,
       boolean isDescendantsOfRepeatingGroup) {
     Map<String, ?> structVal =
-        DataValidationUtil.validateAndParseIcebergStruct(
-            type.getName(), value, insertRowsCurrIndex);
+        DataValidationUtil.validateAndParseIcebergStruct(path, value, insertRowsCurrIndex);
     Set<String> extraFields = structVal.keySet();
     List<Object> listVal = new ArrayList<>(type.getFieldCount());
     float estimatedParquetSize = 0f;
@@ -461,7 +467,7 @@ private static ParquetBufferValue get3LevelListValue(
       final long insertRowsCurrIndex,
       String path) {
     Iterable<?> iterableVal =
-        DataValidationUtil.validateAndParseIcebergList(type.getName(), value, insertRowsCurrIndex);
+        DataValidationUtil.validateAndParseIcebergList(path, value, insertRowsCurrIndex);
     List<Object> listVal = new ArrayList<>();
     final AtomicReference<Float> estimatedParquetSize = new AtomicReference<>(0f);
     iterableVal.forEach(
@@ -497,7 +503,7 @@ private static ParquetBufferValue get3LevelMapValue(
       String path,
       boolean isDescendantsOfRepeatingGroup) {
     Map<?, ?> mapVal =
-        DataValidationUtil.validateAndParseIcebergMap(type.getName(), value, insertRowsCurrIndex);
+        DataValidationUtil.validateAndParseIcebergMap(path, value, insertRowsCurrIndex);
     List<Object> listVal = new ArrayList<>();
     final AtomicReference<Float> estimatedParquetSize = new AtomicReference<>(0f);
     mapVal.forEach(
diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/ParquetRowBuffer.java b/src/main/java/net/snowflake/ingest/streaming/internal/ParquetRowBuffer.java
index 9e7988211..5dd6b4640 100644
--- a/src/main/java/net/snowflake/ingest/streaming/internal/ParquetRowBuffer.java
+++ b/src/main/java/net/snowflake/ingest/streaming/internal/ParquetRowBuffer.java
@@ -99,14 +99,17 @@ public void setupSchema(List<ColumnMetadata> columns) {
       this.statsMap.put(
           column.getInternalName(),
           new RowBufferStats(
-              column.getName(), column.getCollation(), column.getOrdinal(), 0 /* fieldId */));
+              column.getName(), column.getCollation(), column.getOrdinal(), null /* fieldId */));
 
       if (onErrorOption == OpenChannelRequest.OnErrorOption.ABORT
           || onErrorOption == OpenChannelRequest.OnErrorOption.SKIP_BATCH) {
         this.tempStatsMap.put(
             column.getInternalName(),
             new RowBufferStats(
-                column.getName(), column.getCollation(), column.getOrdinal(), 0 /* fieldId */));
+                column.getName(),
+                column.getCollation(),
+                column.getOrdinal(),
+                null /* fieldId */));
       }
     }
 
diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/RowBufferStats.java b/src/main/java/net/snowflake/ingest/streaming/internal/RowBufferStats.java
index 48f7b47d9..2fae695f0 100644
--- a/src/main/java/net/snowflake/ingest/streaming/internal/RowBufferStats.java
+++ b/src/main/java/net/snowflake/ingest/streaming/internal/RowBufferStats.java
@@ -20,10 +20,10 @@ class RowBufferStats {
 
   /*
    * Field id of a column.
-   * For FDN columns, it's always 0.
+   * For FDN columns, it's always null.
   * For Iceberg columns, set to nonzero Iceberg field id if it's a sub-column, otherwise zero.
    */
-  private final int fieldId;
+  private final Integer fieldId;
 
   private byte[] currentMinStrValue;
 
   private byte[] currentMaxStrValue;
@@ -40,7 +40,7 @@ class RowBufferStats {
 
   /** Creates empty stats */
   RowBufferStats(
-      String columnDisplayName, String collationDefinitionString, int ordinal, int fieldId) {
+      String columnDisplayName, String collationDefinitionString, int ordinal, Integer fieldId) {
     this.columnDisplayName = columnDisplayName;
     this.collationDefinitionString = collationDefinitionString;
     this.ordinal = ordinal;
@@ -49,7 +49,7 @@ class RowBufferStats {
 
   RowBufferStats(String columnDisplayName) {
-    this(columnDisplayName, null, -1, 0);
+    this(columnDisplayName, null, -1, null);
   }
 
   void reset() {
@@ -234,7 +234,7 @@ public int getOrdinal() {
     return ordinal;
   }
 
-  int getFieldId() {
+  Integer getFieldId() {
     return fieldId;
   }
 
diff --git a/src/test/java/net/snowflake/ingest/streaming/internal/BlobBuilderTest.java b/src/test/java/net/snowflake/ingest/streaming/internal/BlobBuilderTest.java
index 05db18923..185fa5ded 100644
--- a/src/test/java/net/snowflake/ingest/streaming/internal/BlobBuilderTest.java
+++ b/src/test/java/net/snowflake/ingest/streaming/internal/BlobBuilderTest.java
@@ -96,7 +96,9 @@ private List<ChannelData<ParquetChunkData>> createChannelDataPerTable(int metada
     channelData.setRowCount(metadataRowCount);
     channelData.setMinMaxInsertTimeInMs(new Pair<>(2L, 3L));
 
-    channelData.getColumnEps().putIfAbsent(columnName, new RowBufferStats(columnName, null, 1, 0));
+    channelData
+        .getColumnEps()
+        .putIfAbsent(columnName, new RowBufferStats(columnName, null, 1, isIceberg ? 0 : null));
     channelData.setChannelContext(
         new ChannelFlushContext("channel1", "DB", "SCHEMA", "TABLE", 1L, "enc", 1L));
     return Collections.singletonList(channelData);
diff --git a/src/test/java/net/snowflake/ingest/streaming/internal/DataValidationUtilTest.java b/src/test/java/net/snowflake/ingest/streaming/internal/DataValidationUtilTest.java
index 0e738a4b3..4d0c51596 100644
--- a/src/test/java/net/snowflake/ingest/streaming/internal/DataValidationUtilTest.java
+++ b/src/test/java/net/snowflake/ingest/streaming/internal/DataValidationUtilTest.java
@@ -16,7 +16,10 @@
 import static net.snowflake.ingest.streaming.internal.DataValidationUtil.validateAndParseBoolean;
 import static net.snowflake.ingest.streaming.internal.DataValidationUtil.validateAndParseDate;
 import static net.snowflake.ingest.streaming.internal.DataValidationUtil.validateAndParseIcebergInt;
+import static net.snowflake.ingest.streaming.internal.DataValidationUtil.validateAndParseIcebergList;
 import static net.snowflake.ingest.streaming.internal.DataValidationUtil.validateAndParseIcebergLong;
+import static net.snowflake.ingest.streaming.internal.DataValidationUtil.validateAndParseIcebergMap;
+import static net.snowflake.ingest.streaming.internal.DataValidationUtil.validateAndParseIcebergStruct;
 import static net.snowflake.ingest.streaming.internal.DataValidationUtil.validateAndParseObject;
 import static net.snowflake.ingest.streaming.internal.DataValidationUtil.validateAndParseObjectNew;
 import static net.snowflake.ingest.streaming.internal.DataValidationUtil.validateAndParseReal;
@@ -50,6 +53,7 @@
 import java.util.Collections;
 import java.util.Date;
 import java.util.HashMap;
+import java.util.List;
 import java.util.Map;
 import java.util.Optional;
 import java.util.TimeZone;
@@ -1281,6 +1285,32 @@ public void testValidateAndParseIcebergLong() {
     expectError(
         ErrorCode.INVALID_FORMAT_ROW,
         () -> validateAndParseIcebergLong("COL", Double.NEGATIVE_INFINITY, 0));
validateAndParseIcebergLong("COL", Double.NEGATIVE_INFINITY, 0)); } + @Test + public void testValidateAndParseIcebergStruct() throws JsonProcessingException { + Map validStruct = + objectMapper.readValue("{\"a\": 1, \"b\":[1, 2, 3], \"c\":{\"d\":3}}", Map.class); + assertEquals(validStruct, validateAndParseIcebergStruct("COL", validStruct, 0)); + expectError( + ErrorCode.INVALID_FORMAT_ROW, + () -> validateAndParseIcebergStruct("COL", Collections.singletonMap(1, new Object()), 0)); + } + + @Test + public void testValidateAndParseIcebergList() throws JsonProcessingException { + List validList = objectMapper.readValue("[1, [2, 3, 4], 5]", List.class); + assertEquals(validList, validateAndParseIcebergList("COL", validList, 0)); + + expectError(ErrorCode.INVALID_FORMAT_ROW, () -> validateAndParseIcebergList("COL", 1, 0)); + } + + @Test + public void testValidateAndParseIcebergMap() { + Map validMap = Collections.singletonMap(1, 1); + assertEquals(validMap, validateAndParseIcebergMap("COL", validMap, 0)); + + expectError(ErrorCode.INVALID_FORMAT_ROW, () -> validateAndParseIcebergMap("COL", 1, 0)); + } + /** * Tests that exception message are constructed correctly when ingesting forbidden Java type, as * well a value of an allowed type, but in invalid format diff --git a/src/test/java/net/snowflake/ingest/streaming/internal/FileColumnPropertiesTest.java b/src/test/java/net/snowflake/ingest/streaming/internal/FileColumnPropertiesTest.java index 444c7cd8f..7b131b310 100644 --- a/src/test/java/net/snowflake/ingest/streaming/internal/FileColumnPropertiesTest.java +++ b/src/test/java/net/snowflake/ingest/streaming/internal/FileColumnPropertiesTest.java @@ -19,18 +19,19 @@ public static Object[] isIceberg() { @Test public void testFileColumnPropertiesConstructor() { // Test simple construction - RowBufferStats stats = new RowBufferStats("COL", null, 1, 0); + RowBufferStats stats = new RowBufferStats("COL", null, 1, isIceberg ? 1 : null); stats.addStrValue("bcd"); stats.addStrValue("abcde"); FileColumnProperties props = new FileColumnProperties(stats, isIceberg); Assert.assertEquals(1, props.getColumnOrdinal()); + Assert.assertEquals(isIceberg ? 1 : null, props.getFieldId()); Assert.assertEquals("6162636465", props.getMinStrValue()); Assert.assertNull(props.getMinStrNonCollated()); Assert.assertEquals("626364", props.getMaxStrValue()); Assert.assertNull(props.getMaxStrNonCollated()); // Test that truncation is performed - stats = new RowBufferStats("COL", null, 1, 0); + stats = new RowBufferStats("COL", null, 1, isIceberg ? 0 : null); stats.addStrValue("aßßßßßßßßßßßßßßßß"); Assert.assertEquals(33, stats.getCurrentMinStrValue().length); props = new FileColumnProperties(stats, isIceberg);