Skip to content

Commit

Permalink
SNOW-1465503 Check row count in Parquet footer before committing
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-lsembera committed Jun 25, 2024
1 parent 1357e74 commit 02de7dc
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,21 @@
import java.security.NoSuchAlgorithmException;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import java.util.zip.CRC32;
import javax.crypto.BadPaddingException;
import javax.crypto.IllegalBlockSizeException;
import javax.crypto.NoSuchPaddingException;
import net.snowflake.ingest.utils.Constants;
import net.snowflake.ingest.utils.Cryptor;
import net.snowflake.ingest.utils.ErrorCode;
import net.snowflake.ingest.utils.Logging;
import net.snowflake.ingest.utils.Pair;
import net.snowflake.ingest.utils.SFException;
import org.apache.commons.codec.binary.Hex;
import org.apache.parquet.hadoop.BdecParquetReader;
import org.apache.parquet.hadoop.metadata.BlockMetaData;
import org.apache.parquet.hadoop.metadata.ParquetMetadata;

/**
* Build a single blob file that contains file header plus data. The header will be a
Expand Down Expand Up @@ -89,6 +95,8 @@ static <T> Blob constructBlobAndMetadata(
byte[] paddedChunkData = paddedChunk.getFirst();
int paddedChunkLength = paddedChunk.getSecond();

verifyRowCounts(serializedChunk, channelsDataPerTable, paddedChunkData, paddedChunkLength);

// Encrypt the compressed chunk data, the encryption key is derived using the key from
// server with the full blob path.
// We need to maintain IV as a block counter for the whole file, even interleaved,
Expand Down Expand Up @@ -152,6 +160,50 @@ static <T> Blob constructBlobAndMetadata(
return new Blob(blobBytes, chunksMetadataList, new BlobStats());
}

/**
* Safety check to verify whether the number of rows in the parquet footer matches the number of
* rows in metadata
*/
private static <T> void verifyRowCounts(
Flusher.SerializationResult serializationResult,
List<ChannelData<T>> channelsDataPerTable,
byte[] paddedChunkData,
int chunkLength) {
final ParquetMetadata metadata =
BdecParquetReader.readParquetFooter(paddedChunkData, chunkLength);
long totalRowCountFromParquetFooter = 0;
for (BlockMetaData blockMetaData : metadata.getBlocks()) {
totalRowCountFromParquetFooter += blockMetaData.getRowCount();
}

if (totalRowCountFromParquetFooter != serializationResult.rowCount) {
final String perChannelRowCounts =
channelsDataPerTable.stream()
.map(x -> String.valueOf(x.getRowCount()))
.collect(Collectors.joining(","));

final String perFooterBlockRowCounts =
metadata.getBlocks().stream()
.map(x -> String.valueOf(x.getRowCount()))
.collect(Collectors.joining(","));

throw new SFException(
ErrorCode.INTERNAL_ERROR,
String.format(
"The number of rows in Parquet footer does not match the number of rows in metadata. "
+ "totalRowCountFromParquetFooter=%d "
+ "totalRowCountFromSerializationResult=%d "
+ "channelCountFromSerializationResult=%d "
+ "perChannelRowCountsFromChannelData=%s "
+ "perFooterBlockRowCounts=%s",
totalRowCountFromParquetFooter,
serializationResult.rowCount,
serializationResult.channelsMetadataList.size(),
perChannelRowCounts,
perFooterBlockRowCounts));
}
}

/**
* Pad the compressed data for encryption. Encryption needs padding to the
* ENCRYPTION_ALGORITHM_BLOCK_SIZE_BYTES to align with decryption on the Snowflake query path
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,10 @@ private SerializationResult serializeFromParquetWriteBuffers(
ByteArrayOutputStream mergedChunkData = new ByteArrayOutputStream();
Pair<Long, Long> chunkMinMaxInsertTimeInMs = null;

for (ChannelData<ParquetChunkData> data : channelsDataPerTable) {
for (int channelPosition = 0;
channelPosition < channelsDataPerTable.size();
channelPosition++) {
final ChannelData<ParquetChunkData> data = channelsDataPerTable.get(channelPosition);
// Create channel metadata
ChannelMetadata channelMetadata =
ChannelMetadata.builder()
Expand Down Expand Up @@ -112,6 +115,23 @@ private SerializationResult serializeFromParquetWriteBuffers(
}

rowCount += data.getRowCount();

// We check if the number of rows collectively written for channels encountered so far matches
// the number of rows in metadata
if (mergedChannelWriter.getRowsWritten() != rowCount) {
throw new SFException(
ErrorCode.INTERNAL_ERROR,
String.format(
"[serializeFromParquetWriteBuffers] The actual number of rows and the number of"
+ " rows in channel metadata do not match channelName=%s rowCountWritten=%d"
+ " rowCountInMetadata=%d currentChannelPosition=%d totalChannelCount=%d",
data.getChannelContext().getName(),
mergedChannelWriter.getRowsWritten(),
rowCount,
channelPosition,
channelsDataPerTable.size()));
}

chunkEstimatedUncompressedSize += data.getBufferSize();

logger.logDebug(
Expand Down Expand Up @@ -147,7 +167,10 @@ private SerializationResult serializeFromJavaObjects(
ByteArrayOutputStream mergedData = new ByteArrayOutputStream();
Pair<Long, Long> chunkMinMaxInsertTimeInMs = null;

for (ChannelData<ParquetChunkData> data : channelsDataPerTable) {
for (int channelPosition = 0;
channelPosition < channelsDataPerTable.size();
channelPosition++) {
final ChannelData<ParquetChunkData> data = channelsDataPerTable.get(channelPosition);
// Create channel metadata
ChannelMetadata channelMetadata =
ChannelMetadata.builder()
Expand Down Expand Up @@ -189,6 +212,22 @@ private SerializationResult serializeFromJavaObjects(
chunkMinMaxInsertTimeInMs, data.getMinMaxInsertTimeInMs());
}

// We check if the number of rows in the current channel matches the number of rows in
// metadata
if (data.getVectors().rows.size() != data.getRowCount()) {
throw new SFException(
ErrorCode.INTERNAL_ERROR,
String.format(
"[serializeFromJavaObjects] The actual number of rows and the number of rows in"
+ " channel metadata do not match channelName=%s actualRowCount=%d"
+ " rowCountInMetadata=%d currentChannelPosition=%d totalChannelCount=%d",
data.getChannelContext().getName(),
data.getVectors().rows.size(),
data.getRowCount(),
channelPosition,
channelsDataPerTable.size()));
}

rows.addAll(data.getVectors().rows);

rowCount += data.getRowCount();
Expand Down
32 changes: 26 additions & 6 deletions src/main/java/org/apache/parquet/hadoop/BdecParquetReader.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import org.apache.parquet.ParquetReadOptions;
import org.apache.parquet.hadoop.api.InitContext;
import org.apache.parquet.hadoop.api.ReadSupport;
import org.apache.parquet.hadoop.metadata.ParquetMetadata;
import org.apache.parquet.io.DelegatingSeekableInputStream;
import org.apache.parquet.io.InputFile;
import org.apache.parquet.io.SeekableInputStream;
Expand All @@ -40,8 +41,17 @@ public class BdecParquetReader implements AutoCloseable {
* @throws IOException
*/
public BdecParquetReader(byte[] data) throws IOException {
this(data, data.length);
}

/**
* @param data buffer where the data that has to be read resides.
* @param length Length of the data to read
* @throws IOException
*/
public BdecParquetReader(byte[] data, int length) throws IOException {
ParquetReadOptions options = ParquetReadOptions.builder().build();
ParquetFileReader fileReader = ParquetFileReader.open(new BdecInputFile(data), options);
ParquetFileReader fileReader = ParquetFileReader.open(new BdecInputFile(data, length), options);
reader = new InternalParquetRecordReader<>(new BdecReadSupport(), options.getRecordFilter());
reader.initialize(fileReader, options);
}
Expand Down Expand Up @@ -86,21 +96,31 @@ public static void readFileIntoWriter(byte[] data, BdecParquetWriter outputWrite
}
}

public static ParquetMetadata readParquetFooter(byte[] data, int length) {
try (final ParquetFileReader reader = ParquetFileReader.open(new BdecInputFile(data, length))) {
return reader.getFooter();
} catch (IOException e) {
throw new SFException(ErrorCode.INTERNAL_ERROR, "Failed to read parquet footer", e);
}
}

private static class BdecInputFile implements InputFile {
private final byte[] data;
private final int length;

private BdecInputFile(byte[] data) {
private BdecInputFile(byte[] data, int length) {
this.data = data;
this.length = length;
}

@Override
public long getLength() {
return data.length;
return length;
}

@Override
public SeekableInputStream newStream() {
return new BdecSeekableInputStream(new BdecByteArrayInputStream(data));
return new BdecSeekableInputStream(new BdecByteArrayInputStream(data, length));
}
}

Expand All @@ -124,8 +144,8 @@ public void seek(long newPos) {
}

private static class BdecByteArrayInputStream extends ByteArrayInputStream {
public BdecByteArrayInputStream(byte[] buf) {
super(buf);
public BdecByteArrayInputStream(byte[] buf, int length) {
super(buf, 0, length);
}

long getPos() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicLong;
import net.snowflake.ingest.utils.Constants;
import net.snowflake.ingest.utils.ErrorCode;
import net.snowflake.ingest.utils.SFException;
Expand Down Expand Up @@ -35,6 +36,7 @@
public class BdecParquetWriter implements AutoCloseable {
private final InternalParquetRecordWriter<List<Object>> writer;
private final CodecFactory codecFactory;
private final AtomicLong rowsWritten = new AtomicLong(0);

/**
* Creates a BDEC specific parquet writer.
Expand Down Expand Up @@ -103,11 +105,16 @@ public BdecParquetWriter(
public void writeRow(List<Object> row) {
try {
writer.write(row);
rowsWritten.incrementAndGet();
} catch (InterruptedException | IOException e) {
throw new SFException(ErrorCode.INTERNAL_ERROR, "parquet row write failed", e);
}
}

public long getRowsWritten() {
return rowsWritten.get();
}

@Override
public void close() throws IOException {
try {
Expand Down

0 comments on commit 02de7dc

Please sign in to comment.