From 59169072d08e1f20a389e6581ef576117ef5fb8d Mon Sep 17 00:00:00 2001 From: Alec Huang Date: Thu, 17 Oct 2024 16:08:46 -0700 Subject: [PATCH] done --- .../streaming/internal/ParquetRowBuffer.java | 4 +- .../ingest/utils/IcebergDataTypeParser.java | 207 +++++++++++++++++- .../datatypes/IcebergNumericTypesIT.java | 54 +++++ 3 files changed, 255 insertions(+), 10 deletions(-) diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/ParquetRowBuffer.java b/src/main/java/net/snowflake/ingest/streaming/internal/ParquetRowBuffer.java index 5e3fa1191..782c84647 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/ParquetRowBuffer.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/ParquetRowBuffer.java @@ -86,7 +86,9 @@ public class ParquetRowBuffer extends AbstractRowBuffer { public void setupSchema(List columns) { fieldIndex.clear(); metadata.clear(); - metadata.put("sfVer", "1,1"); + if (!clientBufferParameters.getIsIcebergMode()) { + metadata.put("sfVer", "1,1"); + } List parquetTypes = new ArrayList<>(); int id = 1; diff --git a/src/main/java/net/snowflake/ingest/utils/IcebergDataTypeParser.java b/src/main/java/net/snowflake/ingest/utils/IcebergDataTypeParser.java index abb03cdef..3f1f5fa01 100644 --- a/src/main/java/net/snowflake/ingest/utils/IcebergDataTypeParser.java +++ b/src/main/java/net/snowflake/ingest/utils/IcebergDataTypeParser.java @@ -4,6 +4,14 @@ package net.snowflake.ingest.utils; +import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.BINARY; +import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.BOOLEAN; +import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.DOUBLE; +import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY; +import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.FLOAT; +import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.INT32; +import static 
org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.INT64; + import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; @@ -12,10 +20,11 @@ import java.util.Iterator; import java.util.List; import javax.annotation.Nonnull; -import org.apache.iceberg.parquet.TypeToMessageType; import org.apache.iceberg.types.Type; import org.apache.iceberg.types.Types; import org.apache.iceberg.util.JsonUtil; +import org.apache.parquet.schema.GroupType; +import org.apache.parquet.schema.LogicalTypeAnnotation; /** * This class is used to Iceberg data type (include primitive types and nested types) serialization @@ -23,7 +32,9 @@ * *

This code is modified from * GlobalServices/modules/data-lake/datalake-api/src/main/java/com/snowflake/metadata/iceberg - * /IcebergDataTypeParser.java + * /IcebergDataTypeParser.java and + * TypeToMessageType.java */ public class IcebergDataTypeParser { private static final String TYPE = "type"; @@ -44,12 +55,26 @@ public class IcebergDataTypeParser { private static final String ELEMENT_REQUIRED = "element-required"; private static final String VALUE_REQUIRED = "value-required"; + private static final LogicalTypeAnnotation STRING = LogicalTypeAnnotation.stringType(); + private static final LogicalTypeAnnotation DATE = LogicalTypeAnnotation.dateType(); + private static final LogicalTypeAnnotation TIME_MICROS = + LogicalTypeAnnotation.timeType( + false /* not adjusted to UTC */, LogicalTypeAnnotation.TimeUnit.MICROS); + private static final LogicalTypeAnnotation TIMESTAMP_MICROS = + LogicalTypeAnnotation.timestampType( + false /* not adjusted to UTC */, LogicalTypeAnnotation.TimeUnit.MICROS); + private static final LogicalTypeAnnotation TIMESTAMPTZ_MICROS = + LogicalTypeAnnotation.timestampType( + true /* adjusted to UTC */, LogicalTypeAnnotation.TimeUnit.MICROS); + + private static final int DECIMAL_INT32_MAX_DIGITS = 9; + private static final int DECIMAL_INT64_MAX_DIGITS = 18; + private static final int DECIMAL_MAX_DIGITS = 38; + private static final int DECIMAL_MAX_BYTES = 16; + /** Object mapper for this class */ private static final ObjectMapper MAPPER = new ObjectMapper(); - /** Util class that contains the mapping between Iceberg data type and Parquet data type */ - private static final TypeToMessageType typeToMessageType = new TypeToMessageType(); - /** * Get Iceberg data type information by deserialization. 
* @@ -66,15 +91,15 @@ public static org.apache.parquet.schema.Type parseIcebergDataTypeStringToParquet String name) { Type icebergType = deserializeIcebergType(icebergDataType); if (icebergType.isPrimitiveType()) { - return typeToMessageType.primitive(icebergType.asPrimitiveType(), repetition, id, name); + return primitive(icebergType.asPrimitiveType(), repetition, id, name); } else { switch (icebergType.typeId()) { case LIST: - return typeToMessageType.list(icebergType.asListType(), repetition, id, name); + return list(icebergType.asListType(), repetition, id, name); case MAP: - return typeToMessageType.map(icebergType.asMapType(), repetition, id, name); + return map(icebergType.asMapType(), repetition, id, name); case STRUCT: - return typeToMessageType.struct(icebergType.asStructType(), repetition, id, name); + return struct(icebergType.asStructType(), repetition, id, name); default: throw new SFException( ErrorCode.INTERNAL_ERROR, @@ -208,4 +233,168 @@ public static Types.MapType mapFromJson(JsonNode json) { return Types.MapType.ofOptional(keyId, valueId, keyType, valueType); } } + + /** Converts an Iceberg struct type into a Parquet group with the given repetition, id and name; each entry of struct.fields() is converted by field(...) and added in the order returned. */ private static GroupType struct( + Types.StructType struct, + org.apache.parquet.schema.Type.Repetition repetition, + int id, + String name) { + org.apache.parquet.schema.Types.GroupBuilder builder = + org.apache.parquet.schema.Types.buildGroup(repetition); + + for (Types.NestedField field : struct.fields()) { + builder.addField(field(field)); + } + + return builder.id(id).named(name); + } + + /** Converts one Iceberg nested field into a Parquet type. Repetition is OPTIONAL when the field is optional and REQUIRED otherwise; id and name are taken from the field. Primitive types go to primitive(...); struct, map and list types go to the matching helper; any other nested type raises UnsupportedOperationException. */ private static org.apache.parquet.schema.Type field(Types.NestedField field) { + org.apache.parquet.schema.Type.Repetition repetition = + field.isOptional() + ? 
org.apache.parquet.schema.Type.Repetition.OPTIONAL + : org.apache.parquet.schema.Type.Repetition.REQUIRED; + int id = field.fieldId(); + String name = field.name(); + + if (field.type().isPrimitiveType()) { + return primitive(field.type().asPrimitiveType(), repetition, id, name); + + } else { + Type.NestedType nested = field.type().asNestedType(); + if (nested.isStructType()) { + return struct(nested.asStructType(), repetition, id, name); + } else if (nested.isMapType()) { + return map(nested.asMapType(), repetition, id, name); + } else if (nested.isListType()) { + return list(nested.asListType(), repetition, id, name); + } + throw new UnsupportedOperationException("Can't convert unknown type: " + nested); + } + } + + /** Converts an Iceberg list type into a Parquet list group with the given repetition, id and name; the element is the converted first entry of list.fields(). */ private static GroupType list( + Types.ListType list, + org.apache.parquet.schema.Type.Repetition repetition, + int id, + String name) { + Types.NestedField elementField = list.fields().get(0); + return org.apache.parquet.schema.Types.list(repetition) + .element(field(elementField)) + .id(id) + .named(name); + } + + /** Converts an Iceberg map type into a Parquet map group with the given repetition, id and name; entry 0 of map.fields() becomes the key and entry 1 the value. */ private static GroupType map( + Types.MapType map, + org.apache.parquet.schema.Type.Repetition repetition, + int id, + String name) { + Types.NestedField keyField = map.fields().get(0); + Types.NestedField valueField = map.fields().get(1); + return org.apache.parquet.schema.Types.map(repetition) + .key(field(keyField)) + .value(field(valueField)) + .id(id) + .named(name); + } + + /** Maps an Iceberg primitive type to a Parquet primitive type that carries the given repetition, field id and name. DATE, TIME, TIMESTAMP, STRING, DECIMAL and UUID also receive a logical type annotation. DECIMAL is stored as INT32 up to DECIMAL_INT32_MAX_DIGITS digits of precision, as INT64 up to DECIMAL_INT64_MAX_DIGITS, and otherwise as a FIXED_LEN_BYTE_ARRAY of DECIMAL_MAX_BYTES bytes. @throws IllegalArgumentException if a decimal precision exceeds DECIMAL_MAX_DIGITS. @throws UnsupportedOperationException if the type id has no case in the switch. */ public static org.apache.parquet.schema.Type primitive( + Type.PrimitiveType primitive, + org.apache.parquet.schema.Type.Repetition repetition, + int id, + String name) { + switch (primitive.typeId()) { + case BOOLEAN: + return org.apache.parquet.schema.Types.primitive(BOOLEAN, repetition).id(id).named(name); + case INTEGER: + return org.apache.parquet.schema.Types.primitive(INT32, repetition).id(id).named(name); + case LONG: + return org.apache.parquet.schema.Types.primitive(INT64, repetition).id(id).named(name); + case FLOAT: + return 
org.apache.parquet.schema.Types.primitive(FLOAT, repetition).id(id).named(name); + case DOUBLE: + return org.apache.parquet.schema.Types.primitive(DOUBLE, repetition).id(id).named(name); + case DATE: + return org.apache.parquet.schema.Types.primitive(INT32, repetition) + .as(DATE) + .id(id) + .named(name); + case TIME: + return org.apache.parquet.schema.Types.primitive(INT64, repetition) + .as(TIME_MICROS) + .id(id) + .named(name); + case TIMESTAMP: + if (((Types.TimestampType) primitive).shouldAdjustToUTC()) { + return org.apache.parquet.schema.Types.primitive(INT64, repetition) + .as(TIMESTAMPTZ_MICROS) + .id(id) + .named(name); + } else { + return org.apache.parquet.schema.Types.primitive(INT64, repetition) + .as(TIMESTAMP_MICROS) + .id(id) + .named(name); + } + case STRING: + return org.apache.parquet.schema.Types.primitive(BINARY, repetition) + .as(STRING) + .id(id) + .named(name); + case BINARY: + return org.apache.parquet.schema.Types.primitive(BINARY, repetition).id(id).named(name); + case FIXED: + Types.FixedType fixed = (Types.FixedType) primitive; + + return org.apache.parquet.schema.Types.primitive(FIXED_LEN_BYTE_ARRAY, repetition) + .length(fixed.length()) + .id(id) + .named(name); + + case DECIMAL: + Types.DecimalType decimal = (Types.DecimalType) primitive; + + if (decimal.precision() <= DECIMAL_INT32_MAX_DIGITS) { + /* Store as an int. */ + return org.apache.parquet.schema.Types.primitive(INT32, repetition) + .as(decimalAnnotation(decimal.precision(), decimal.scale())) + .id(id) + .named(name); + + } else if (decimal.precision() <= DECIMAL_INT64_MAX_DIGITS) { + /* Store as a long. */ + return org.apache.parquet.schema.Types.primitive(INT64, repetition) + .as(decimalAnnotation(decimal.precision(), decimal.scale())) + .id(id) + .named(name); + + } else { + /* Deviates from the Iceberg spec, which calls for the minimum number of bytes for the precision: a fixed 16-byte array, fixed(16) (SB16), is used instead. 
*/ + if (decimal.precision() > DECIMAL_MAX_DIGITS) { + throw new IllegalArgumentException( + "Unsupported decimal precision: " + decimal.precision()); + } + return org.apache.parquet.schema.Types.primitive(FIXED_LEN_BYTE_ARRAY, repetition) + .length(DECIMAL_MAX_BYTES) + .as(decimalAnnotation(decimal.precision(), decimal.scale())) + .id(id) + .named(name); + } + + case UUID: + return org.apache.parquet.schema.Types.primitive(FIXED_LEN_BYTE_ARRAY, repetition) + .length(16) + .as(LogicalTypeAnnotation.uuidType()) + .id(id) + .named(name); + + default: + throw new UnsupportedOperationException("Unsupported type for Parquet: " + primitive); + } + } + + /** Returns the Parquet decimal annotation for the given precision and scale; LogicalTypeAnnotation.decimalType is called with scale first, then precision. */ private static LogicalTypeAnnotation decimalAnnotation(int precision, int scale) { + return LogicalTypeAnnotation.decimalType(scale, precision); + } } diff --git a/src/test/java/net/snowflake/ingest/streaming/internal/datatypes/IcebergNumericTypesIT.java b/src/test/java/net/snowflake/ingest/streaming/internal/datatypes/IcebergNumericTypesIT.java index e4b0783d4..81f4558e0 100644 --- a/src/test/java/net/snowflake/ingest/streaming/internal/datatypes/IcebergNumericTypesIT.java +++ b/src/test/java/net/snowflake/ingest/streaming/internal/datatypes/IcebergNumericTypesIT.java @@ -1,10 +1,18 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. 
+ */ + package net.snowflake.ingest.streaming.internal.datatypes; import java.math.BigDecimal; +import java.util.ArrayList; import java.util.Arrays; +import java.util.List; +import java.util.Random; import net.snowflake.ingest.utils.Constants; import net.snowflake.ingest.utils.ErrorCode; import net.snowflake.ingest.utils.SFException; +import org.apache.commons.lang3.RandomStringUtils; import org.assertj.core.api.Assertions; import org.junit.Before; import org.junit.Ignore; @@ -26,6 +34,8 @@ public static Object[][] parameters() { @Parameterized.Parameter(1) public static Constants.IcebergSerializationPolicy icebergSerializationPolicy; + /** Fixed-seed random source; randomBigDecimal uses it to pick how many digits each generated value gets. */ static final Random generator = new Random(0x5EED); + @Before public void before() throws Exception { super.beforeIceberg(compressionAlgorithm, icebergSerializationPolicy); @@ -306,6 +316,7 @@ public void testDecimal() throws Exception { testIcebergIngestion("decimal(3, 1)", 12.5f, new FloatProvider()); testIcebergIngestion("decimal(3, 1)", -99, new IntProvider()); testIcebergIngestion("decimal(38, 0)", Long.MAX_VALUE, new LongProvider()); + testIcebergIngestion("decimal(21, 0)", .0, new DoubleProvider()); testIcebergIngestion("decimal(38, 10)", null, new BigDecimalProvider()); testIcebergIngestion( @@ -368,5 +379,48 @@ public void testDecimalAndQueries() throws Exception { Arrays.asList(new BigDecimal("-12.3"), new BigDecimal("-12.3"), null), "select COUNT({columnName}) from {tableName} where {columnName} = -12.3", Arrays.asList(2L)); + + List bigDecimals_9_4 = randomBigDecimal(200, 9, 4); + testIcebergIngestAndQuery( + "decimal(9, 4)", bigDecimals_9_4, "select {columnName} from {tableName}", bigDecimals_9_4); + + List bigDecimals_18_9 = randomBigDecimal(200, 18, 9); + testIcebergIngestAndQuery( + "decimal(18, 9)", + bigDecimals_18_9, + "select {columnName} from {tableName}", + bigDecimals_18_9); + + List bigDecimals_21_0 = randomBigDecimal(200, 21, 0); + testIcebergIngestAndQuery( + "decimal(21, 0)", + bigDecimals_21_0, + "select 
{columnName} from {tableName}", + bigDecimals_21_0); + + List bigDecimals_38_10 = randomBigDecimal(200, 38, 10); + testIcebergIngestAndQuery( + "decimal(38, 10)", + bigDecimals_38_10, + "select {columnName} from {tableName}", + bigDecimals_38_10); + } + + /** Builds a list of count test values for a decimal(precision, scale) column. For each entry the seeded generator picks a number of integer digits in [0, precision - scale] and a number of fraction digits in [0, scale]; when both are zero the entry is null, otherwise the two random digit strings are joined with a dot and parsed as a BigDecimal. NOTE(review): the digits themselves come from RandomStringUtils.randomNumeric, which presumably does not draw from the seeded generator, so the values may differ between runs - confirm whether reproducible data is intended. */ private static List randomBigDecimal(int count, int precision, int scale) { + List list = new ArrayList<>(); + for (int i = 0; i < count; i++) { + int intPart = generator.nextInt(precision - scale + 1); + int floatPart = generator.nextInt(scale + 1); + if (intPart == 0 && floatPart == 0) { + list.add(null); + continue; + } + list.add( + new BigDecimal( + RandomStringUtils.randomNumeric(intPart) + + "." + + RandomStringUtils.randomNumeric(floatPart))); + } + return list; } }