Skip to content

Commit

Permalink
Address comments & add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-alhuang committed Sep 26, 2024
1 parent acd7dea commit 901a24d
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -94,50 +94,47 @@ private static ParquetBufferValue parseColumnValueToParquet(
switch (primitiveType.getPrimitiveTypeName()) {
case BOOLEAN:
int intValue =
DataValidationUtil.validateAndParseBoolean(
type.getName(), value, insertRowsCurrIndex);
DataValidationUtil.validateAndParseBoolean(path, value, insertRowsCurrIndex);
value = intValue > 0;
stats.addIntValue(BigInteger.valueOf(intValue));
estimatedParquetSize += ParquetBufferValue.BIT_ENCODING_BYTE_LEN;
break;
case INT32:
int intVal = getInt32Value(value, primitiveType, insertRowsCurrIndex);
int intVal = getInt32Value(value, primitiveType, path, insertRowsCurrIndex);
value = intVal;
stats.addIntValue(BigInteger.valueOf(intVal));
estimatedParquetSize += 4;
break;
case INT64:
long longVal =
getInt64Value(value, primitiveType, defaultTimezone, insertRowsCurrIndex);
getInt64Value(value, primitiveType, defaultTimezone, path, insertRowsCurrIndex);
value = longVal;
stats.addIntValue(BigInteger.valueOf(longVal));
estimatedParquetSize += 8;
break;
case FLOAT:
float floatVal =
(float)
DataValidationUtil.validateAndParseReal(
type.getName(), value, insertRowsCurrIndex);
(float) DataValidationUtil.validateAndParseReal(path, value, insertRowsCurrIndex);
value = floatVal;
stats.addRealValue((double) floatVal);
estimatedParquetSize += 4;
break;
case DOUBLE:
double doubleVal =
DataValidationUtil.validateAndParseReal(type.getName(), value, insertRowsCurrIndex);
DataValidationUtil.validateAndParseReal(path, value, insertRowsCurrIndex);
value = doubleVal;
stats.addRealValue(doubleVal);
estimatedParquetSize += 8;
break;
case BINARY:
byte[] byteVal = getBinaryValue(value, primitiveType, stats, insertRowsCurrIndex);
byte[] byteVal = getBinaryValue(value, primitiveType, stats, path, insertRowsCurrIndex);
value = byteVal;
estimatedParquetSize +=
ParquetBufferValue.BYTE_ARRAY_LENGTH_ENCODING_BYTE_LEN + byteVal.length;
break;
case FIXED_LEN_BYTE_ARRAY:
byte[] fixedLenByteArrayVal =
getFixedLenByteArrayValue(value, primitiveType, stats, insertRowsCurrIndex);
getFixedLenByteArrayValue(value, primitiveType, stats, path, insertRowsCurrIndex);
value = fixedLenByteArrayVal;
estimatedParquetSize +=
ParquetBufferValue.BYTE_ARRAY_LENGTH_ENCODING_BYTE_LEN
Expand All @@ -164,7 +161,7 @@ private static ParquetBufferValue parseColumnValueToParquet(
if (value == null) {
if (type.isRepetition(Repetition.REQUIRED)) {
throw new SFException(
ErrorCode.INVALID_FORMAT_ROW, type.getName(), "Passed null to non nullable field");
ErrorCode.INVALID_FORMAT_ROW, path, "Passed null to non nullable field");
}
if (type.isPrimitive()) {
statsMap.get(path).incCurrentNullCount();
Expand All @@ -179,21 +176,21 @@ private static ParquetBufferValue parseColumnValueToParquet(
*
* @param value column value provided by user in a row
* @param type Parquet column type
* @param path column path, used for logging
* @param insertRowsCurrIndex Used for logging the row of index given in insertRows API
* @return parsed int32 value
*/
private static int getInt32Value(
Object value, PrimitiveType type, final long insertRowsCurrIndex) {
Object value, PrimitiveType type, String path, final long insertRowsCurrIndex) {
LogicalTypeAnnotation logicalTypeAnnotation = type.getLogicalTypeAnnotation();
if (logicalTypeAnnotation == null) {
return DataValidationUtil.validateAndParseIcebergInt(
type.getName(), value, insertRowsCurrIndex);
return DataValidationUtil.validateAndParseIcebergInt(path, value, insertRowsCurrIndex);
}
if (logicalTypeAnnotation instanceof DecimalLogicalTypeAnnotation) {
return getDecimalValue(value, type, insertRowsCurrIndex).unscaledValue().intValue();
return getDecimalValue(value, type, path, insertRowsCurrIndex).unscaledValue().intValue();
}
if (logicalTypeAnnotation instanceof DateLogicalTypeAnnotation) {
return DataValidationUtil.validateAndParseDate(type.getName(), value, insertRowsCurrIndex);
return DataValidationUtil.validateAndParseDate(path, value, insertRowsCurrIndex);
}
throw new SFException(
ErrorCode.UNKNOWN_DATA_TYPE, logicalTypeAnnotation, type.getPrimitiveTypeName());
Expand All @@ -204,22 +201,26 @@ private static int getInt32Value(
*
* @param value column value provided by user in a row
* @param type Parquet column type
* @param path column path, used for logging
* @param insertRowsCurrIndex Used for logging the row of index given in insertRows API
* @return parsed int64 value
*/
private static long getInt64Value(
Object value, PrimitiveType type, ZoneId defaultTimezone, final long insertRowsCurrIndex) {
Object value,
PrimitiveType type,
ZoneId defaultTimezone,
String path,
final long insertRowsCurrIndex) {
LogicalTypeAnnotation logicalTypeAnnotation = type.getLogicalTypeAnnotation();
if (logicalTypeAnnotation == null) {
return DataValidationUtil.validateAndParseIcebergLong(
type.getName(), value, insertRowsCurrIndex);
return DataValidationUtil.validateAndParseIcebergLong(path, value, insertRowsCurrIndex);
}
if (logicalTypeAnnotation instanceof DecimalLogicalTypeAnnotation) {
return getDecimalValue(value, type, insertRowsCurrIndex).unscaledValue().longValue();
return getDecimalValue(value, type, path, insertRowsCurrIndex).unscaledValue().longValue();
}
if (logicalTypeAnnotation instanceof TimeLogicalTypeAnnotation) {
return DataValidationUtil.validateAndParseTime(
type.getName(),
path,
value,
timeUnitToScale(((TimeLogicalTypeAnnotation) logicalTypeAnnotation).getUnit()),
insertRowsCurrIndex)
Expand Down Expand Up @@ -250,29 +251,28 @@ private static long getInt64Value(
* @param value value to parse
* @param type Parquet column type
* @param stats column stats to update
* @param path column path, used for logging
* @param insertRowsCurrIndex Used for logging the row of index given in insertRows API
* @return parsed value as a byte array
*/
private static byte[] getBinaryValue(
Object value, PrimitiveType type, RowBufferStats stats, final long insertRowsCurrIndex) {
Object value,
PrimitiveType type,
RowBufferStats stats,
String path,
final long insertRowsCurrIndex) {
LogicalTypeAnnotation logicalTypeAnnotation = type.getLogicalTypeAnnotation();
if (logicalTypeAnnotation == null) {
byte[] bytes =
DataValidationUtil.validateAndParseBinary(
type.getName(),
value,
Optional.of(Constants.BINARY_COLUMN_MAX_SIZE),
insertRowsCurrIndex);
path, value, Optional.of(Constants.BINARY_COLUMN_MAX_SIZE), insertRowsCurrIndex);
stats.addBinaryValue(bytes);
return bytes;
}
if (logicalTypeAnnotation instanceof StringLogicalTypeAnnotation) {
String string =
DataValidationUtil.validateAndParseString(
type.getName(),
value,
Optional.of(Constants.VARCHAR_COLUMN_MAX_SIZE),
insertRowsCurrIndex);
path, value, Optional.of(Constants.VARCHAR_COLUMN_MAX_SIZE), insertRowsCurrIndex);
stats.addStrValue(string);
return string.getBytes(StandardCharsets.UTF_8);
}
Expand All @@ -286,22 +286,28 @@ private static byte[] getBinaryValue(
* @param value value to parse
* @param type Parquet column type
* @param stats column stats to update
* @param path column path, used for logging
* @param insertRowsCurrIndex Used for logging the row of index given in insertRows API
* @return parsed value as a byte array
*/
private static byte[] getFixedLenByteArrayValue(
Object value, PrimitiveType type, RowBufferStats stats, final long insertRowsCurrIndex) {
Object value,
PrimitiveType type,
RowBufferStats stats,
String path,
final long insertRowsCurrIndex) {
LogicalTypeAnnotation logicalTypeAnnotation = type.getLogicalTypeAnnotation();
int length = type.getTypeLength();
byte[] bytes = null;
if (logicalTypeAnnotation == null) {
bytes =
DataValidationUtil.validateAndParseBinary(
type.getName(), value, Optional.of(length), insertRowsCurrIndex);
path, value, Optional.of(length), insertRowsCurrIndex);
stats.addBinaryValue(bytes);
}
if (logicalTypeAnnotation instanceof DecimalLogicalTypeAnnotation) {
BigInteger bigIntegerVal = getDecimalValue(value, type, insertRowsCurrIndex).unscaledValue();
BigInteger bigIntegerVal =
getDecimalValue(value, type, path, insertRowsCurrIndex).unscaledValue();
stats.addIntValue(bigIntegerVal);
bytes = bigIntegerVal.toByteArray();
if (bytes.length < length) {
Expand All @@ -324,15 +330,16 @@ private static byte[] getFixedLenByteArrayValue(
*
* @param value value to parse
* @param type Parquet column type
* @param path column path, used for logging
* @param insertRowsCurrIndex Used for logging the row of index given in insertRows API
* @return BigDecimal representation
*/
private static BigDecimal getDecimalValue(
Object value, PrimitiveType type, final long insertRowsCurrIndex) {
Object value, PrimitiveType type, String path, final long insertRowsCurrIndex) {
int scale = ((DecimalLogicalTypeAnnotation) type.getLogicalTypeAnnotation()).getScale();
int precision = ((DecimalLogicalTypeAnnotation) type.getLogicalTypeAnnotation()).getPrecision();
BigDecimal bigDecimalValue =
DataValidationUtil.validateAndParseBigDecimal(type.getName(), value, insertRowsCurrIndex);
DataValidationUtil.validateAndParseBigDecimal(path, value, insertRowsCurrIndex);
bigDecimalValue = bigDecimalValue.setScale(scale, RoundingMode.HALF_UP);
DataValidationUtil.checkValueInRange(bigDecimalValue, scale, precision, insertRowsCurrIndex);
return bigDecimalValue;
Expand Down Expand Up @@ -414,8 +421,7 @@ private static ParquetBufferValue getStructValue(
String path,
boolean isDescendantsOfRepeatingGroup) {
Map<String, ?> structVal =
DataValidationUtil.validateAndParseIcebergStruct(
type.getName(), value, insertRowsCurrIndex);
DataValidationUtil.validateAndParseIcebergStruct(path, value, insertRowsCurrIndex);
Set<String> extraFields = structVal.keySet();
List<Object> listVal = new ArrayList<>(type.getFieldCount());
float estimatedParquetSize = 0f;
Expand Down Expand Up @@ -461,7 +467,7 @@ private static ParquetBufferValue get3LevelListValue(
final long insertRowsCurrIndex,
String path) {
Iterable<?> iterableVal =
DataValidationUtil.validateAndParseIcebergList(type.getName(), value, insertRowsCurrIndex);
DataValidationUtil.validateAndParseIcebergList(path, value, insertRowsCurrIndex);
List<Object> listVal = new ArrayList<>();
final AtomicReference<Float> estimatedParquetSize = new AtomicReference<>(0f);
iterableVal.forEach(
Expand Down Expand Up @@ -497,7 +503,7 @@ private static ParquetBufferValue get3LevelMapValue(
String path,
boolean isDescendantsOfRepeatingGroup) {
Map<?, ?> mapVal =
DataValidationUtil.validateAndParseIcebergMap(type.getName(), value, insertRowsCurrIndex);
DataValidationUtil.validateAndParseIcebergMap(path, value, insertRowsCurrIndex);
List<Object> listVal = new ArrayList<>();
final AtomicReference<Float> estimatedParquetSize = new AtomicReference<>(0f);
mapVal.forEach(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@
import static net.snowflake.ingest.streaming.internal.DataValidationUtil.validateAndParseBoolean;
import static net.snowflake.ingest.streaming.internal.DataValidationUtil.validateAndParseDate;
import static net.snowflake.ingest.streaming.internal.DataValidationUtil.validateAndParseIcebergInt;
import static net.snowflake.ingest.streaming.internal.DataValidationUtil.validateAndParseIcebergList;
import static net.snowflake.ingest.streaming.internal.DataValidationUtil.validateAndParseIcebergLong;
import static net.snowflake.ingest.streaming.internal.DataValidationUtil.validateAndParseIcebergMap;
import static net.snowflake.ingest.streaming.internal.DataValidationUtil.validateAndParseIcebergStruct;
import static net.snowflake.ingest.streaming.internal.DataValidationUtil.validateAndParseObject;
import static net.snowflake.ingest.streaming.internal.DataValidationUtil.validateAndParseObjectNew;
import static net.snowflake.ingest.streaming.internal.DataValidationUtil.validateAndParseReal;
Expand Down Expand Up @@ -50,6 +53,7 @@
import java.util.Collections;
import java.util.Date;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.TimeZone;
Expand Down Expand Up @@ -1281,6 +1285,32 @@ public void testValidateAndParseIcebergLong() {
() -> validateAndParseIcebergLong("COL", Double.NEGATIVE_INFINITY, 0));
}

@Test
public void testValidateAndParseIcebergStruct() throws JsonProcessingException {
  // Happy path: a struct holding a scalar, a list, and a nested object is returned as an equal map.
  String json = "{\"a\": 1, \"b\":[1, 2, 3], \"c\":{\"d\":3}}";
  Map<String, ?> input = objectMapper.readValue(json, Map.class);
  Map<String, ?> parsed = validateAndParseIcebergStruct("COL", input, 0);
  assertEquals(input, parsed);

  // Failure path: a single-entry map whose key is an Integer must raise INVALID_FORMAT_ROW.
  Map<Integer, Object> badStruct = Collections.singletonMap(1, new Object());
  expectError(
      ErrorCode.INVALID_FORMAT_ROW, () -> validateAndParseIcebergStruct("COL", badStruct, 0));
}

@Test
public void testValidateAndParseIcebergList() throws JsonProcessingException {
  // Happy path: a list containing scalars and a nested list is returned as an equal value.
  String json = "[1, [2, 3, 4], 5]";
  List<?> input = objectMapper.readValue(json, List.class);
  Iterable<?> parsed = validateAndParseIcebergList("COL", input, 0);
  assertEquals(input, parsed);

  // Failure path: a bare integer passed as the list value must raise INVALID_FORMAT_ROW.
  expectError(ErrorCode.INVALID_FORMAT_ROW, () -> validateAndParseIcebergList("COL", 1, 0));
}

@Test
public void testValidateAndParseIcebergMap() {
  // Happy path: a single-entry map is returned as an equal map.
  Map<?, ?> input = Collections.singletonMap(1, 1);
  Map<?, ?> parsed = validateAndParseIcebergMap("COL", input, 0);
  assertEquals(input, parsed);

  // Failure path: a bare integer passed as the map value must raise INVALID_FORMAT_ROW.
  expectError(ErrorCode.INVALID_FORMAT_ROW, () -> validateAndParseIcebergMap("COL", 1, 0));
}

/**
* Tests that exception messages are constructed correctly when ingesting a forbidden Java type,
* as well as a value of an allowed type but in an invalid format
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,12 @@ public static Object[] isIceberg() {
@Test
public void testFileColumnPropertiesConstructor() {
// Test simple construction
RowBufferStats stats = new RowBufferStats("COL", null, 1, 0);
RowBufferStats stats = new RowBufferStats("COL", null, 1, 1);
stats.addStrValue("bcd");
stats.addStrValue("abcde");
FileColumnProperties props = new FileColumnProperties(stats, isIceberg);
Assert.assertEquals(1, props.getColumnOrdinal());
Assert.assertEquals(1, props.getFieldId());
Assert.assertEquals("6162636465", props.getMinStrValue());
Assert.assertNull(props.getMinStrNonCollated());
Assert.assertEquals("626364", props.getMaxStrValue());
Expand Down

0 comments on commit 901a24d

Please sign in to comment.