Skip to content

Commit

Permalink
SNOW-1787322 Fix InsertError for structured data type (#888)
Browse files Browse the repository at this point in the history
Currently the `InsertError` doesn't populate the extra column names, the missing non-nullable column names, or the names of non-nullable columns with null values in `InsertValidationResponse` when ingesting structured data types into Iceberg tables. This PR fixes that.

We use parquet dot path with escaping dot and back slash character in field name to represent sub-columns. For example, column `x` in `map_col(string, object("a.b" array(object("c\d" object(x int)))))` has a path `MAP_COL.key_value.value.a\.b.list.element.c\\d.x`
  • Loading branch information
sfc-gh-alhuang authored Nov 8, 2024
1 parent c696cd6 commit 70943f7
Show file tree
Hide file tree
Showing 7 changed files with 500 additions and 152 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021 Snowflake Computing Inc. All rights reserved.
* Copyright (c) 2021-2024 Snowflake Computing Inc. All rights reserved.
*/

package net.snowflake.ingest.streaming;
Expand Down Expand Up @@ -115,6 +115,18 @@ public void setExtraColNames(List<String> extraColNames) {
this.extraColNames = extraColNames;
}

/**
 * Record a column name that appears in the input row but not in the table schema.
 *
 * @param extraColName the extra column name
 */
public void addExtraColName(String extraColName) {
  // Lazily create the backing list on first use.
  if (this.extraColNames == null) {
    this.extraColNames = new ArrayList<>();
  }
  this.extraColNames.add(extraColName);
}

/** Get the list of extra column names in the input row compared with the table schema */
public List<String> getExtraColNames() {
return extraColNames;
Expand All @@ -125,6 +137,18 @@ public void setMissingNotNullColNames(List<String> missingNotNullColNames) {
this.missingNotNullColNames = missingNotNullColNames;
}

/**
 * Record a non-nullable column that is present in the table schema but missing from the input
 * row.
 *
 * @param missingNotNullColName the missing non-nullable column name
 */
public void addMissingNotNullColName(String missingNotNullColName) {
  // Lazily create the backing list on first use.
  if (this.missingNotNullColNames == null) {
    this.missingNotNullColNames = new ArrayList<>();
  }
  this.missingNotNullColNames.add(missingNotNullColName);
}

/**
* Get the list of missing non-nullable column names in the input row compared with the table
* schema
Expand All @@ -141,6 +165,19 @@ public void setNullValueForNotNullColNames(List<String> nullValueForNotNullColNa
this.nullValueForNotNullColNames = nullValueForNotNullColNames;
}

/**
 * Record the name of a non-nullable column that has a null value in the input row.
 *
 * @param nullValueForNotNullColName name of the non-nullable column that has a null value
 */
public void addNullValueForNotNullColName(String nullValueForNotNullColName) {
  // Lazily create the backing list on first use.
  if (this.nullValueForNotNullColNames == null) {
    this.nullValueForNotNullColNames = new ArrayList<>();
  }
  this.nullValueForNotNullColNames.add(nullValueForNotNullColName);
}

/**
* Get the list of names of non-nullable column which have null value in the input row compared
* with the table schema
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,12 @@ public InsertValidationResponse insertRows(
Set<String> inputColumnNames = verifyInputColumns(row, error, rowIndex);
rowsSizeInBytes +=
addRow(
row, rowBuffer.bufferedRowCount, rowBuffer.statsMap, inputColumnNames, rowIndex);
row,
rowBuffer.bufferedRowCount,
rowBuffer.statsMap,
inputColumnNames,
rowIndex,
error);
rowBuffer.bufferedRowCount++;
} catch (SFException e) {
error.setException(e);
Expand Down Expand Up @@ -200,7 +205,13 @@ public InsertValidationResponse insertRows(
for (Map<String, Object> row : rows) {
Set<String> inputColumnNames = verifyInputColumns(row, null, tempRowCount);
tempRowsSizeInBytes +=
addTempRow(row, tempRowCount, rowBuffer.tempStatsMap, inputColumnNames, tempRowCount);
addTempRow(
row,
tempRowCount,
rowBuffer.tempStatsMap,
inputColumnNames,
tempRowCount,
new InsertValidationResponse.InsertError(row, 0) /* dummy error */);
tempRowCount++;
if ((long) rowBuffer.bufferedRowCount + tempRowCount >= Integer.MAX_VALUE) {
throw new SFException(ErrorCode.INTERNAL_ERROR, "Row count reaches MAX value");
Expand Down Expand Up @@ -249,7 +260,8 @@ public InsertValidationResponse insertRows(
try {
Set<String> inputColumnNames = verifyInputColumns(row, error, rowIndex);
tempRowsSizeInBytes +=
addTempRow(row, tempRowCount, rowBuffer.tempStatsMap, inputColumnNames, rowIndex);
addTempRow(
row, tempRowCount, rowBuffer.tempStatsMap, inputColumnNames, rowIndex, error);
tempRowCount++;
} catch (SFException e) {
error.setException(e);
Expand Down Expand Up @@ -575,14 +587,17 @@ public ChannelData<T> flush() {
* @param formattedInputColumnNames list of input column names after formatting
* @param insertRowIndex Index of the rows given in insertRows API. Not the same as
* bufferedRowIndex
* @param error Insert error object, used to populate error details when doing structured data
* type parsing
* @return row size
*/
abstract float addRow(
Map<String, Object> row,
int bufferedRowIndex,
Map<String, RowBufferStats> statsMap,
Set<String> formattedInputColumnNames,
final long insertRowIndex);
final long insertRowIndex,
InsertValidationResponse.InsertError error);

/**
* Add an input row to the temporary row buffer.
Expand All @@ -595,14 +610,17 @@ abstract float addRow(
* @param statsMap column stats map
* @param formattedInputColumnNames list of input column names after formatting
* @param insertRowIndex index of the row being inserted from User Input List
* @param error Insert error object, used to populate error details when doing structured data
* type parsing
* @return row size
*/
abstract float addTempRow(
Map<String, Object> row,
int curRowIndex,
Map<String, RowBufferStats> statsMap,
Set<String> formattedInputColumnNames,
long insertRowIndex);
long insertRowIndex,
InsertValidationResponse.InsertError error);

/** Move rows from the temporary buffer to the current row buffer. */
abstract void moveTempRowsToActualBuffer(int tempRowCount);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package net.snowflake.ingest.streaming.internal;

import static net.snowflake.ingest.streaming.internal.DataValidationUtil.checkFixedLengthByteArray;
import static net.snowflake.ingest.utils.Utils.concatDotPath;

import java.math.BigDecimal;
import java.math.BigInteger;
Expand All @@ -19,7 +20,7 @@
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import net.snowflake.ingest.streaming.InsertValidationResponse;
import net.snowflake.ingest.utils.Constants;
import net.snowflake.ingest.utils.ErrorCode;
import net.snowflake.ingest.utils.SFException;
Expand Down Expand Up @@ -49,6 +50,7 @@ class IcebergParquetValueParser {
* @param defaultTimezone default timezone to use for timestamp parsing
* @param insertRowsCurrIndex Row index corresponding the row to parse (w.r.t input rows in
* insertRows API, and not buffered row)
* @param error InsertError object to add errors
* @return parsed value and byte size of Parquet internal representation
*/
static ParquetBufferValue parseColumnValueToParquet(
Expand All @@ -57,10 +59,18 @@ static ParquetBufferValue parseColumnValueToParquet(
Map<String, RowBufferStats> statsMap,
SubColumnFinder subColumnFinder,
ZoneId defaultTimezone,
long insertRowsCurrIndex) {
long insertRowsCurrIndex,
InsertValidationResponse.InsertError error) {
Utils.assertNotNull("Parquet column stats map", statsMap);
return parseColumnValueToParquet(
value, type, statsMap, subColumnFinder, defaultTimezone, insertRowsCurrIndex, false);
value,
type,
statsMap,
subColumnFinder,
defaultTimezone,
insertRowsCurrIndex,
false /* isDescendantsOfRepeatingGroup */,
error);
}

private static ParquetBufferValue parseColumnValueToParquet(
Expand All @@ -70,7 +80,8 @@ private static ParquetBufferValue parseColumnValueToParquet(
SubColumnFinder subColumnFinder,
ZoneId defaultTimezone,
long insertRowsCurrIndex,
boolean isDescendantsOfRepeatingGroup) {
boolean isDescendantsOfRepeatingGroup,
InsertValidationResponse.InsertError error) {
float estimatedParquetSize = 0F;

if (type.getId() == null) {
Expand All @@ -86,8 +97,8 @@ private static ParquetBufferValue parseColumnValueToParquet(
}
}

String path = subColumnFinder.getDotPath(id);
if (value != null) {
String path = subColumnFinder.getDotPath(id);
if (type.isPrimitive()) {
RowBufferStats stats = statsMap.get(id.toString());
estimatedParquetSize += ParquetBufferValue.DEFINITION_LEVEL_ENCODING_BYTE_LEN;
Expand Down Expand Up @@ -160,18 +171,14 @@ private static ParquetBufferValue parseColumnValueToParquet(
subColumnFinder,
defaultTimezone,
insertRowsCurrIndex,
isDescendantsOfRepeatingGroup);
isDescendantsOfRepeatingGroup,
error);
}
}

if (value == null) {
if (type.isRepetition(Repetition.REQUIRED)) {
throw new SFException(
ErrorCode.INVALID_FORMAT_ROW,
subColumnFinder.getDotPath(id),
String.format(
"Passed null to non nullable field, rowIndex:%d, column:%s",
insertRowsCurrIndex, subColumnFinder.getDotPath(id)));
error.addNullValueForNotNullColName(path);
}
subColumnFinder
.getSubColumns(id)
Expand Down Expand Up @@ -381,6 +388,7 @@ private static int timeUnitToScale(LogicalTypeAnnotation.TimeUnit timeUnit) {
* @param defaultTimezone default timezone to use for timestamp parsing
* @param insertRowsCurrIndex Used for logging the row of index given in insertRows API
* @param isDescendantsOfRepeatingGroup true if the column is a descendant of a repeating group,
* @param error InsertError object to add errors
* @return list of parsed values
*/
private static ParquetBufferValue getGroupValue(
Expand All @@ -390,7 +398,8 @@ private static ParquetBufferValue getGroupValue(
SubColumnFinder subColumnFinder,
ZoneId defaultTimezone,
final long insertRowsCurrIndex,
boolean isDescendantsOfRepeatingGroup) {
boolean isDescendantsOfRepeatingGroup,
InsertValidationResponse.InsertError error) {
LogicalTypeAnnotation logicalTypeAnnotation = type.getLogicalTypeAnnotation();
if (logicalTypeAnnotation == null) {
return getStructValue(
Expand All @@ -400,15 +409,16 @@ private static ParquetBufferValue getGroupValue(
subColumnFinder,
defaultTimezone,
insertRowsCurrIndex,
isDescendantsOfRepeatingGroup);
isDescendantsOfRepeatingGroup,
error);
}
if (logicalTypeAnnotation instanceof LogicalTypeAnnotation.ListLogicalTypeAnnotation) {
return get3LevelListValue(
value, type, statsMap, subColumnFinder, defaultTimezone, insertRowsCurrIndex);
value, type, statsMap, subColumnFinder, defaultTimezone, insertRowsCurrIndex, error);
}
if (logicalTypeAnnotation instanceof LogicalTypeAnnotation.MapLogicalTypeAnnotation) {
return get3LevelMapValue(
value, type, statsMap, subColumnFinder, defaultTimezone, insertRowsCurrIndex);
value, type, statsMap, subColumnFinder, defaultTimezone, insertRowsCurrIndex, error);
}
throw new SFException(
ErrorCode.UNKNOWN_DATA_TYPE,
Expand All @@ -429,35 +439,48 @@ private static ParquetBufferValue getStructValue(
SubColumnFinder subColumnFinder,
ZoneId defaultTimezone,
final long insertRowsCurrIndex,
boolean isDescendantsOfRepeatingGroup) {
boolean isDescendantsOfRepeatingGroup,
InsertValidationResponse.InsertError error) {
Map<String, ?> structVal =
DataValidationUtil.validateAndParseIcebergStruct(
subColumnFinder.getDotPath(type.getId()), value, insertRowsCurrIndex);
Set<String> extraFields = new HashSet<>(structVal.keySet());
List<String> missingFields = new ArrayList<>();
List<Object> listVal = new ArrayList<>(type.getFieldCount());
float estimatedParquetSize = 0f;
for (int i = 0; i < type.getFieldCount(); i++) {
ParquetBufferValue parsedValue =
parseColumnValueToParquet(
structVal.getOrDefault(type.getFieldName(i), null),
type.getType(i),
statsMap,
subColumnFinder,
defaultTimezone,
insertRowsCurrIndex,
isDescendantsOfRepeatingGroup);
extraFields.remove(type.getFieldName(i));
listVal.add(parsedValue.getValue());
estimatedParquetSize += parsedValue.getSize();
if (structVal.containsKey(type.getFieldName(i))) {
ParquetBufferValue parsedValue =
parseColumnValueToParquet(
structVal.get(type.getFieldName(i)),
type.getType(i),
statsMap,
subColumnFinder,
defaultTimezone,
insertRowsCurrIndex,
isDescendantsOfRepeatingGroup,
error);
listVal.add(parsedValue.getValue());
estimatedParquetSize += parsedValue.getSize();
} else {
if (type.getType(i).isRepetition(Repetition.REQUIRED)) {
missingFields.add(type.getFieldName(i));
} else {
listVal.add(null);
}
}
}
if (!extraFields.isEmpty()) {
String extraFieldsStr = extraFields.stream().collect(Collectors.joining(", ", "[", "]"));
throw new SFException(
ErrorCode.INVALID_FORMAT_ROW,
"Extra fields: " + extraFieldsStr,
String.format(
"Fields not present in the struct %s shouldn't be specified, rowIndex:%d",
subColumnFinder.getDotPath(type.getId()), insertRowsCurrIndex));

for (String missingField : missingFields) {
List<String> missingFieldPath = new ArrayList<>(subColumnFinder.getPath(type.getId()));
missingFieldPath.add(missingField);
error.addMissingNotNullColName(concatDotPath(missingFieldPath.toArray(new String[0])));
}
for (String extraField : extraFields) {
List<String> extraFieldPath = new ArrayList<>(subColumnFinder.getPath(type.getId()));
extraFieldPath.add(extraField);
error.addExtraColName(concatDotPath(extraFieldPath.toArray(new String[0])));
}
return new ParquetBufferValue(listVal, estimatedParquetSize);
}
Expand All @@ -475,7 +498,8 @@ private static ParquetBufferValue get3LevelListValue(
Map<String, RowBufferStats> statsMap,
SubColumnFinder subColumnFinder,
ZoneId defaultTimezone,
final long insertRowsCurrIndex) {
final long insertRowsCurrIndex,
InsertValidationResponse.InsertError error) {
Iterable<?> iterableVal =
DataValidationUtil.validateAndParseIcebergList(
subColumnFinder.getDotPath(type.getId()), value, insertRowsCurrIndex);
Expand All @@ -490,7 +514,8 @@ private static ParquetBufferValue get3LevelListValue(
subColumnFinder,
defaultTimezone,
insertRowsCurrIndex,
true);
true /* isDecedentOfRepeatingGroup */,
error);
listVal.add(Collections.singletonList(parsedValue.getValue()));
estimatedParquetSize += parsedValue.getSize();
}
Expand All @@ -515,7 +540,8 @@ private static ParquetBufferValue get3LevelMapValue(
Map<String, RowBufferStats> statsMap,
SubColumnFinder subColumnFinder,
ZoneId defaultTimezone,
final long insertRowsCurrIndex) {
final long insertRowsCurrIndex,
InsertValidationResponse.InsertError error) {
Map<?, ?> mapVal =
DataValidationUtil.validateAndParseIcebergMap(
subColumnFinder.getDotPath(type.getId()), value, insertRowsCurrIndex);
Expand All @@ -530,7 +556,8 @@ private static ParquetBufferValue get3LevelMapValue(
subColumnFinder,
defaultTimezone,
insertRowsCurrIndex,
true);
true /* isDecedentOfRepeatingGroup */,
error);
ParquetBufferValue parsedValue =
parseColumnValueToParquet(
entry.getValue(),
Expand All @@ -539,7 +566,8 @@ private static ParquetBufferValue get3LevelMapValue(
subColumnFinder,
defaultTimezone,
insertRowsCurrIndex,
true);
true /* isDecedentOfRepeatingGroup */,
error);
listVal.add(Arrays.asList(parsedKey.getValue(), parsedValue.getValue()));
estimatedParquetSize += parsedKey.getSize() + parsedValue.getSize();
}
Expand Down
Loading

0 comments on commit 70943f7

Please sign in to comment.