Address comments

sfc-gh-alhuang committed Sep 24, 2024
1 parent 6423f7c commit d30e7fe
Showing 8 changed files with 136 additions and 111 deletions.

net/snowflake/ingest/streaming/internal/DataValidationUtil.java
@@ -1067,7 +1067,7 @@ static int validateAndParseBoolean(String columnName, Object input, long insertR
* @param insertRowIndex Row index for error reporting
* @return Object cast to Map
*/
- static Map<String, Object> validateAndParseIcebergStruct(
+ static Map<String, ?> validateAndParseIcebergStruct(
String columnName, Object input, long insertRowIndex) {
if (!(input instanceof Map)) {
throw typeNotAllowedException(
@@ -1077,14 +1077,16 @@ static Map<String, Object> validateAndParseIcebergStruct(
new String[] {"Map<String, Object>"},
insertRowIndex);
}
- if (!((Map<?, ?>) input).keySet().stream().allMatch(key -> key instanceof String)) {
- throw new SFException(
- ErrorCode.INVALID_FORMAT_ROW,
- String.format(
- "Flied name of a struct must be of type String, rowIndex:%d", insertRowIndex));
+ for (Object key : ((Map<?, ?>) input).keySet()) {
+ if (!(key instanceof String)) {
+ throw new SFException(
+ ErrorCode.INVALID_FORMAT_ROW,
+ String.format(
+ "Field name of a struct must be of type String, rowIndex:%d", insertRowIndex));
+ }
}

- return (Map<String, Object>) input;
+ return (Map<String, ?>) input;
}

/**
@@ -1099,13 +1101,13 @@ static Map<String, Object> validateAndParseIcebergStruct(
* @param insertRowIndex Row index for error reporting
* @return Object cast to Iterable
*/
- static Iterable<Object> validateAndParseIcebergList(
+ static Iterable<?> validateAndParseIcebergList(
String columnName, Object input, long insertRowIndex) {
if (!(input instanceof Iterable)) {
throw typeNotAllowedException(
columnName, input.getClass(), "LIST", new String[] {"Iterable"}, insertRowIndex);
}
- return (Iterable<Object>) input;
+ return (Iterable<?>) input;
}

/**
@@ -1120,13 +1122,13 @@ static Iterable<Object> validateAndParseIcebergList(
* @param insertRowIndex Row index for error reporting
* @return Object cast to Map
*/
- static Map<Object, Object> validateAndParseIcebergMap(
+ static Map<?, ?> validateAndParseIcebergMap(
String columnName, Object input, long insertRowIndex) {
if (!(input instanceof Map)) {
throw typeNotAllowedException(
columnName, input.getClass(), "MAP", new String[] {"Map"}, insertRowIndex);
}
- return (Map<Object, Object>) input;
+ return (Map<?, ?>) input;
}

static void checkValueInRange(
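
A note on the signature changes above, since the same pattern repeats for struct, list, and map: Java generics are invariant, so a caller's Map<String, String> is not a Map<String, Object>, and the old casts were unchecked claims about the value type. The wildcard forms Map<String, ?>, Iterable<?>, and Map<?, ?> accept any value or element type without making that claim. A minimal sketch, not part of the commit:

// Sketch: invariance of Java generics and why the wildcard is the honest type.
import java.util.HashMap;
import java.util.Map;

class WildcardSketch {
  public static void main(String[] args) {
    Map<String, String> row = new HashMap<>();
    row.put("name", "alice");

    Map<String, ?> asWildcard = row;        // compiles: wildcard admits any value type
    // Map<String, Object> asObject = row;  // would not compile without an unchecked cast

    Object v = asWildcard.get("name");      // reads still work through the wildcard
    System.out.println(v);
  }
}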

net/snowflake/ingest/streaming/internal/FileColumnProperties.java
@@ -5,19 +5,11 @@
package net.snowflake.ingest.streaming.internal;

import static net.snowflake.ingest.streaming.internal.BinaryStringUtils.truncateBytesAsHex;
- import static net.snowflake.ingest.utils.Constants.EP_NDV_UNKNOWN;

import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import java.math.BigInteger;
import java.util.Objects;
- import org.apache.parquet.column.statistics.BinaryStatistics;
- import org.apache.parquet.column.statistics.BooleanStatistics;
- import org.apache.parquet.column.statistics.DoubleStatistics;
- import org.apache.parquet.column.statistics.FloatStatistics;
- import org.apache.parquet.column.statistics.IntStatistics;
- import org.apache.parquet.column.statistics.LongStatistics;
- import org.apache.parquet.column.statistics.Statistics;
import org.apache.parquet.schema.LogicalTypeAnnotation;

/** Audit register endpoint/FileColumnPropertyDTO property list. */
class FileColumnProperties {
@@ -60,6 +52,7 @@ class FileColumnProperties {

FileColumnProperties(RowBufferStats stats, boolean setDefaultValues) {
this.setColumnOrdinal(stats.getOrdinal());
+ this.setFieldId(stats.getFieldId());
this.setCollation(stats.getCollationDefinitionString());
this.setMaxIntValue(
stats.getCurrentMaxIntValue() == null
@@ -98,38 +91,6 @@ class FileColumnProperties {
this.setDistinctValues(stats.getDistinctValues());
}

- FileColumnProperties(int fieldId, Statistics<?> statistics) {
- this.setColumnOrdinal(fieldId);
- this.setFieldId(fieldId);
- this.setNullCount(statistics.getNumNulls());
- this.setDistinctValues(EP_NDV_UNKNOWN);
- this.setCollation(null);
- this.setMaxStrNonCollated(null);
- this.setMinStrNonCollated(null);
-
- if (statistics instanceof BooleanStatistics) {
- this.setMinIntValue(
- ((BooleanStatistics) statistics).genericGetMin() ? BigInteger.ONE : BigInteger.ZERO);
- this.setMaxIntValue(
- ((BooleanStatistics) statistics).genericGetMax() ? BigInteger.ONE : BigInteger.ZERO);
- } else if (statistics instanceof IntStatistics || statistics instanceof LongStatistics) {
- this.setMinIntValue(BigInteger.valueOf(((Number) statistics.genericGetMin()).longValue()));
- this.setMaxIntValue(BigInteger.valueOf(((Number) statistics.genericGetMax()).longValue()));
- } else if (statistics instanceof FloatStatistics || statistics instanceof DoubleStatistics) {
- this.setMinRealValue((Double) statistics.genericGetMin());
- this.setMaxRealValue((Double) statistics.genericGetMax());
- } else if (statistics instanceof BinaryStatistics) {
- if (statistics.type().getLogicalTypeAnnotation()
- instanceof LogicalTypeAnnotation.DecimalLogicalTypeAnnotation) {
- this.setMinIntValue(new BigInteger(statistics.getMinBytes()));
- this.setMaxIntValue(new BigInteger(statistics.getMaxBytes()));
- } else {
- this.setMinStrValue(truncateBytesAsHex(statistics.getMinBytes(), false));
- this.setMaxStrValue(truncateBytesAsHex(statistics.getMaxBytes(), true));
- }
- }
- }

@JsonProperty("columnId")
public int getColumnOrdinal() {
return columnOrdinal;
@@ -140,6 +101,7 @@ public void setColumnOrdinal(int columnOrdinal) {
}

@JsonProperty("fieldId")
+ @JsonInclude(JsonInclude.Include.NON_DEFAULT)
public int getFieldId() {
return fieldId;
}
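
The @JsonInclude(JsonInclude.Include.NON_DEFAULT) annotation added on getFieldId() makes Jackson serialize fieldId only when it differs from the int default of 0, so register payloads for columns without a field id stay unchanged. A self-contained sketch of that Jackson behavior, using hypothetical class names rather than the SDK's:

// Sketch (hypothetical names): NON_DEFAULT drops a primitive int property from the
// JSON while it still holds its default value (0).
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.databind.ObjectMapper;

class NonDefaultSketch {
  static class Props {
    @JsonProperty("fieldId")
    @JsonInclude(JsonInclude.Include.NON_DEFAULT)
    public int fieldId;
  }

  public static void main(String[] args) throws Exception {
    ObjectMapper mapper = new ObjectMapper();
    Props p = new Props();
    System.out.println(mapper.writeValueAsString(p)); // {} : fieldId omitted at default 0
    p.fieldId = 7;
    System.out.println(mapper.writeValueAsString(p)); // {"fieldId":7}
  }
}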

net/snowflake/ingest/streaming/internal/IcebergParquetValueParser.java
@@ -5,6 +5,7 @@
package net.snowflake.ingest.streaming.internal;

import static net.snowflake.ingest.streaming.internal.DataValidationUtil.checkFixedLengthByteArray;
+ import static net.snowflake.ingest.utils.Utils.concatDotPath;

import java.math.BigDecimal;
import java.math.BigInteger;
@@ -17,7 +18,9 @@
import java.util.List;
import java.util.Map;
import java.util.Optional;
+ import java.util.Set;
import java.util.concurrent.atomic.AtomicReference;
+ import java.util.stream.Collectors;
import net.snowflake.ingest.utils.Constants;
import net.snowflake.ingest.utils.ErrorCode;
import net.snowflake.ingest.utils.SFException;
@@ -36,6 +39,9 @@
/** Parses a user Iceberg column value into Parquet internal representation for buffering. */
class IcebergParquetValueParser {

+ static final String THREE_LEVEL_MAP_GROUP_NAME = "key_value";
+ static final String THREE_LEVEL_LIST_GROUP_NAME = "list";

/**
* Parses a user column value into Parquet internal representation for buffering.
*
@@ -65,20 +71,23 @@ private static ParquetBufferValue parseColumnValueToParquet(
ZoneId defaultTimezone,
long insertRowsCurrIndex,
String path,
- boolean isdDescendantsOfRepeatingGroup) {
- path = (path == null || path.isEmpty()) ? type.getName() : path + "." + type.getName();
+ boolean isDescendantsOfRepeatingGroup) {
+ path = concatDotPath(path, type.getName());
float estimatedParquetSize = 0F;

+ if (type.isPrimitive()) {
+ if (!statsMap.containsKey(path)) {
+ throw new SFException(
+ ErrorCode.INTERNAL_ERROR, String.format("Stats not found for column: %s", path));
+ }
+ }

if (value != null) {
if (type.isPrimitive()) {
- if (!statsMap.containsKey(path)) {
- throw new SFException(
- ErrorCode.INTERNAL_ERROR,
- String.format("Stats not found for column: %s", type.getName()));
- }
RowBufferStats stats = statsMap.get(path);
estimatedParquetSize += ParquetBufferValue.DEFINITION_LEVEL_ENCODING_BYTE_LEN;
estimatedParquetSize +=
- isdDescendantsOfRepeatingGroup
+ isDescendantsOfRepeatingGroup
? ParquetBufferValue.REPETITION_LEVEL_ENCODING_BYTE_LEN
: 0;
PrimitiveType primitiveType = type.asPrimitiveType();
@@ -148,7 +157,7 @@ private static ParquetBufferValue parseColumnValueToParquet(
defaultTimezone,
insertRowsCurrIndex,
path,
- isdDescendantsOfRepeatingGroup);
+ isDescendantsOfRepeatingGroup);
}
}

@@ -158,11 +167,6 @@ private static ParquetBufferValue parseColumnValueToParquet(
ErrorCode.INVALID_FORMAT_ROW, type.getName(), "Passed null to non nullable field");
}
if (type.isPrimitive()) {
- if (!statsMap.containsKey(path)) {
- throw new SFException(
- ErrorCode.INTERNAL_ERROR,
- String.format("Stats not found for column: %s", type.getName()));
- }
statsMap.get(path).incCurrentNullCount();
}
}
@@ -357,7 +361,7 @@ private static int timeUnitToScale(LogicalTypeAnnotation.TimeUnit timeUnit) {
* @param defaultTimezone default timezone to use for timestamp parsing
* @param insertRowsCurrIndex Used for logging the row of index given in insertRows API
* @param path dot path of the column
- * @param isdDescendantsOfRepeatingGroup true if the column is a descendant of a repeating group,
+ * @param isDescendantsOfRepeatingGroup true if the column is a descendant of a repeating group,
* @return list of parsed values
*/
private static ParquetBufferValue getGroupValue(
@@ -367,7 +371,7 @@ private static ParquetBufferValue getGroupValue(
ZoneId defaultTimezone,
final long insertRowsCurrIndex,
String path,
- boolean isdDescendantsOfRepeatingGroup) {
+ boolean isDescendantsOfRepeatingGroup) {
LogicalTypeAnnotation logicalTypeAnnotation = type.getLogicalTypeAnnotation();
if (logicalTypeAnnotation == null) {
return getStructValue(
Expand All @@ -377,7 +381,7 @@ private static ParquetBufferValue getGroupValue(
defaultTimezone,
insertRowsCurrIndex,
path,
- isdDescendantsOfRepeatingGroup);
+ isDescendantsOfRepeatingGroup);
}
if (logicalTypeAnnotation instanceof LogicalTypeAnnotation.ListLogicalTypeAnnotation) {
return get3LevelListValue(value, type, statsMap, defaultTimezone, insertRowsCurrIndex, path);
@@ -390,7 +394,7 @@ private static ParquetBufferValue getGroupValue(
defaultTimezone,
insertRowsCurrIndex,
path,
- isdDescendantsOfRepeatingGroup);
+ isDescendantsOfRepeatingGroup);
}
throw new SFException(
ErrorCode.UNKNOWN_DATA_TYPE, logicalTypeAnnotation, type.getClass().getSimpleName());
@@ -408,10 +412,11 @@ private static ParquetBufferValue getStructValue(
ZoneId defaultTimezone,
final long insertRowsCurrIndex,
String path,
- boolean isdDescendantsOfRepeatingGroup) {
- Map<String, Object> structVal =
+ boolean isDescendantsOfRepeatingGroup) {
+ Map<String, ?> structVal =
DataValidationUtil.validateAndParseIcebergStruct(
type.getName(), value, insertRowsCurrIndex);
+ Set<String> extraFields = structVal.keySet();
List<Object> listVal = new ArrayList<>(type.getFieldCount());
float estimatedParquetSize = 0f;
for (int i = 0; i < type.getFieldCount(); i++) {
@@ -423,10 +428,21 @@ private static ParquetBufferValue getStructValue(
defaultTimezone,
insertRowsCurrIndex,
path,
- isdDescendantsOfRepeatingGroup);
+ isDescendantsOfRepeatingGroup);
+ extraFields.remove(type.getFieldName(i));
listVal.add(parsedValue.getValue());
estimatedParquetSize += parsedValue.getSize();
}
+ if (!extraFields.isEmpty()) {
+ extraFields =
+ extraFields.stream().map(f -> concatDotPath(path, f)).collect(Collectors.toSet());
+ throw new SFException(
+ ErrorCode.INVALID_FORMAT_ROW,
+ "Extra fields: " + extraFields,
+ String.format(
+ "Fields not present in the struct shouldn't be specified, rowIndex:%d",
+ insertRowsCurrIndex));
+ }
return new ParquetBufferValue(listVal, estimatedParquetSize);
}

@@ -444,7 +460,7 @@ private static ParquetBufferValue get3LevelListValue(
ZoneId defaultTimezone,
final long insertRowsCurrIndex,
String path) {
- Iterable<Object> iterableVal =
+ Iterable<?> iterableVal =
DataValidationUtil.validateAndParseIcebergList(type.getName(), value, insertRowsCurrIndex);
List<Object> listVal = new ArrayList<>();
final AtomicReference<Float> estimatedParquetSize = new AtomicReference<>(0f);
Expand All @@ -457,7 +473,7 @@ private static ParquetBufferValue get3LevelListValue(
statsMap,
defaultTimezone,
insertRowsCurrIndex,
- path,
+ concatDotPath(path, THREE_LEVEL_LIST_GROUP_NAME),
true);
listVal.add(Collections.singletonList(parsedValue.getValue()));
estimatedParquetSize.updateAndGet(sz -> sz + parsedValue.getSize());
@@ -479,8 +495,8 @@ private static ParquetBufferValue get3LevelMapValue(
ZoneId defaultTimezone,
final long insertRowsCurrIndex,
String path,
- boolean isdDescendantsOfRepeatingGroup) {
- Map<Object, Object> mapVal =
+ boolean isDescendantsOfRepeatingGroup) {
+ Map<?, ?> mapVal =
DataValidationUtil.validateAndParseIcebergMap(type.getName(), value, insertRowsCurrIndex);
List<Object> listVal = new ArrayList<>();
final AtomicReference<Float> estimatedParquetSize = new AtomicReference<>(0f);
Expand All @@ -493,7 +509,7 @@ private static ParquetBufferValue get3LevelMapValue(
statsMap,
defaultTimezone,
insertRowsCurrIndex,
- path,
+ concatDotPath(path, THREE_LEVEL_MAP_GROUP_NAME),
true);
ParquetBufferValue parsedValue =
parseColumnValueToParquet(
Expand All @@ -502,8 +518,8 @@ private static ParquetBufferValue get3LevelMapValue(
statsMap,
defaultTimezone,
insertRowsCurrIndex,
- path,
- isdDescendantsOfRepeatingGroup);
+ concatDotPath(path, THREE_LEVEL_MAP_GROUP_NAME),
+ isDescendantsOfRepeatingGroup);
listVal.add(Arrays.asList(parsedKey.getValue(), parsedValue.getValue()));
estimatedParquetSize.updateAndGet(sz -> sz + parsedKey.getSize() + parsedValue.getSize());
});
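
The concatDotPath changes above mean the stats-map keys for nested columns now include the Parquet three-level group names: "list" for lists and "key_value" for maps. A sketch of the resulting keys, with a local stand-in for Utils.concatDotPath (assumed to join non-empty segments with a dot) and the conventional Parquet leaf names "element", "key", and "value":

// Sketch: dot-path keys produced for three-level LIST and MAP columns.
class DotPathSketch {
  // Stand-in for net.snowflake.ingest.utils.Utils.concatDotPath (assumed semantics).
  static String concatDotPath(String parent, String child) {
    return (parent == null || parent.isEmpty()) ? child : parent + "." + child;
  }

  public static void main(String[] args) {
    // LIST column: my_list -> list (repeated group) -> element
    System.out.println(concatDotPath(concatDotPath("my_list", "list"), "element"));
    // MAP column: my_map -> key_value (repeated group) -> key / value
    System.out.println(concatDotPath(concatDotPath("my_map", "key_value"), "key"));
    System.out.println(concatDotPath(concatDotPath("my_map", "key_value"), "value"));
    // Prints: my_list.list.element, my_map.key_value.key, my_map.key_value.value
  }
}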

net/snowflake/ingest/streaming/internal/ParquetBufferValue.java
@@ -9,14 +9,13 @@ class ParquetBufferValue {
// Parquet uses BitPacking to encode boolean, hence 1 bit per value
public static final float BIT_ENCODING_BYTE_LEN = 1.0f / 8;

- public static final float REPETITION_LEVEL_ENCODING_BYTE_LEN = 2.0f / 8;

/**
- * On average parquet needs 2 bytes / 8 values for the RLE+bitpack encoded definition level.
+ * On average parquet needs 2 bytes / 8 values for the RLE+bitpack encoded definition and
+ * repetition level.
*
* <ul>
- * There are two cases how definition level (0 for null values, 1 for non-null values) is
- * encoded:
+ * There are two cases how definition and repetition level (0 for null values, 1 for non-null
+ * values) is encoded:
* <li>If there are at least 8 repeated values in a row, they are run-length encoded (length +
* value itself). E.g. 11111111 -> 8 1
* <li>If there are less than 8 repeated values, they are written in group as part of a
@@ -33,6 +32,8 @@
*/
public static final float DEFINITION_LEVEL_ENCODING_BYTE_LEN = 2.0f / 8;

+ public static final float REPETITION_LEVEL_ENCODING_BYTE_LEN = 2.0f / 8;

// Parquet stores length in 4 bytes before the actual data bytes
public static final int BYTE_ARRAY_LENGTH_ENCODING_BYTE_LEN = 4;
private final Object value;
Expand Down