diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/BlobBuilder.java b/src/main/java/net/snowflake/ingest/streaming/internal/BlobBuilder.java
index b88090e01..3bd1cd7cf 100644
--- a/src/main/java/net/snowflake/ingest/streaming/internal/BlobBuilder.java
+++ b/src/main/java/net/snowflake/ingest/streaming/internal/BlobBuilder.java
@@ -23,15 +23,21 @@
 import java.security.NoSuchAlgorithmException;
 import java.util.ArrayList;
 import java.util.List;
+import java.util.stream.Collectors;
 import java.util.zip.CRC32;
 import javax.crypto.BadPaddingException;
 import javax.crypto.IllegalBlockSizeException;
 import javax.crypto.NoSuchPaddingException;
 import net.snowflake.ingest.utils.Constants;
 import net.snowflake.ingest.utils.Cryptor;
+import net.snowflake.ingest.utils.ErrorCode;
 import net.snowflake.ingest.utils.Logging;
 import net.snowflake.ingest.utils.Pair;
+import net.snowflake.ingest.utils.SFException;
 import org.apache.commons.codec.binary.Hex;
+import org.apache.parquet.hadoop.BdecParquetReader;
+import org.apache.parquet.hadoop.metadata.BlockMetaData;
+import org.apache.parquet.hadoop.metadata.ParquetMetadata;
 
 /**
  * Build a single blob file that contains file header plus data. The header will be a
@@ -89,6 +95,8 @@ static Blob constructBlobAndMetadata(
         byte[] paddedChunkData = paddedChunk.getFirst();
         int paddedChunkLength = paddedChunk.getSecond();
 
+        verifyRowCounts(serializedChunk, channelsDataPerTable, paddedChunkData, paddedChunkLength);
+
         // Encrypt the compressed chunk data, the encryption key is derived using the key from
         // server with the full blob path.
         // We need to maintain IV as a block counter for the whole file, even interleaved,
@@ -152,6 +160,50 @@ static Blob constructBlobAndMetadata(
     return new Blob(blobBytes, chunksMetadataList, new BlobStats());
   }
 
+  /**
+   * Safety check to verify whether the number of rows in the parquet footer matches the number of
+   * rows in metadata
+   */
+  private static <T> void verifyRowCounts(
+      Flusher.SerializationResult serializationResult,
+      List<ChannelData<T>> channelsDataPerTable,
+      byte[] paddedChunkData,
+      int chunkLength) {
+    final ParquetMetadata metadata =
+        BdecParquetReader.readParquetFooter(paddedChunkData, chunkLength);
+    long totalRowCountFromParquetFooter = 0;
+    for (BlockMetaData blockMetaData : metadata.getBlocks()) {
+      totalRowCountFromParquetFooter += blockMetaData.getRowCount();
+    }
+
+    if (totalRowCountFromParquetFooter != serializationResult.rowCount) {
+      final String perChannelRowCounts =
+          channelsDataPerTable.stream()
+              .map(x -> String.valueOf(x.getRowCount()))
+              .collect(Collectors.joining(","));
+
+      final String perFooterBlockRowCounts =
+          metadata.getBlocks().stream()
+              .map(x -> String.valueOf(x.getRowCount()))
+              .collect(Collectors.joining(","));
+
+      throw new SFException(
+          ErrorCode.INTERNAL_ERROR,
+          String.format(
+              "The number of rows in Parquet footer does not match the number of rows in metadata. "
+                  + "totalRowCountFromParquetFooter=%d "
+                  + "totalRowCountFromSerializationResult=%d "
+                  + "channelCountFromSerializationResult=%d "
+                  + "perChannelRowCountsFromChannelData=%s "
+                  + "perFooterBlockRowCounts=%s",
+              totalRowCountFromParquetFooter,
+              serializationResult.rowCount,
+              serializationResult.channelsMetadataList.size(),
+              perChannelRowCounts,
+              perFooterBlockRowCounts));
+    }
+  }
+
   /**
    * Pad the compressed data for encryption. Encryption needs padding to the
    * ENCRYPTION_ALGORITHM_BLOCK_SIZE_BYTES to align with decryption on the Snowflake query path
diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/ParquetFlusher.java b/src/main/java/net/snowflake/ingest/streaming/internal/ParquetFlusher.java
index 39ec66dbb..32912e11a 100644
--- a/src/main/java/net/snowflake/ingest/streaming/internal/ParquetFlusher.java
+++ b/src/main/java/net/snowflake/ingest/streaming/internal/ParquetFlusher.java
@@ -67,7 +67,10 @@ private SerializationResult serializeFromParquetWriteBuffers(
     ByteArrayOutputStream mergedChunkData = new ByteArrayOutputStream();
     Pair<Long, Long> chunkMinMaxInsertTimeInMs = null;
 
-    for (ChannelData<ParquetChunkData> data : channelsDataPerTable) {
+    for (int channelPosition = 0;
+        channelPosition < channelsDataPerTable.size();
+        channelPosition++) {
+      final ChannelData<ParquetChunkData> data = channelsDataPerTable.get(channelPosition);
       // Create channel metadata
       ChannelMetadata channelMetadata =
           ChannelMetadata.builder()
@@ -112,6 +115,23 @@ private SerializationResult serializeFromParquetWriteBuffers(
       }
 
       rowCount += data.getRowCount();
+
+      // We check if the number of rows collectively written for channels encountered so far matches
+      // the number of rows in metadata
+      if (mergedChannelWriter.getRowsWritten() != rowCount) {
+        throw new SFException(
+            ErrorCode.INTERNAL_ERROR,
+            String.format(
+                "[serializeFromParquetWriteBuffers] The actual number of rows and the number of"
+                    + " rows in channel metadata do not match channelName=%s rowCountWritten=%d"
+                    + " rowCountInMetadata=%d currentChannelPosition=%d totalChannelCount=%d",
+                data.getChannelContext().getName(),
+                mergedChannelWriter.getRowsWritten(),
+                rowCount,
+                channelPosition,
+                channelsDataPerTable.size()));
+      }
+
       chunkEstimatedUncompressedSize += data.getBufferSize();
 
       logger.logDebug(
@@ -147,7 +167,10 @@ private SerializationResult serializeFromJavaObjects(
     ByteArrayOutputStream mergedData = new ByteArrayOutputStream();
     Pair<Long, Long> chunkMinMaxInsertTimeInMs = null;
 
-    for (ChannelData<ParquetChunkData> data : channelsDataPerTable) {
+    for (int channelPosition = 0;
+        channelPosition < channelsDataPerTable.size();
+        channelPosition++) {
+      final ChannelData<ParquetChunkData> data = channelsDataPerTable.get(channelPosition);
       // Create channel metadata
       ChannelMetadata channelMetadata =
           ChannelMetadata.builder()
@@ -189,6 +212,22 @@ private SerializationResult serializeFromJavaObjects(
                 chunkMinMaxInsertTimeInMs, data.getMinMaxInsertTimeInMs());
       }
 
+      // We check if the number of rows in the current channel matches the number of rows in
+      // metadata
+      if (data.getVectors().rows.size() != data.getRowCount()) {
+        throw new SFException(
+            ErrorCode.INTERNAL_ERROR,
+            String.format(
+                "[serializeFromJavaObjects] The actual number of rows and the number of rows in"
+                    + " channel metadata do not match channelName=%s actualRowCount=%d"
+                    + " rowCountInMetadata=%d currentChannelPosition=%d totalChannelCount=%d",
+                data.getChannelContext().getName(),
+                data.getVectors().rows.size(),
+                data.getRowCount(),
+                channelPosition,
+                channelsDataPerTable.size()));
+      }
+
       rows.addAll(data.getVectors().rows);
 
       rowCount += data.getRowCount();
diff --git a/src/main/java/org/apache/parquet/hadoop/BdecParquetReader.java b/src/main/java/org/apache/parquet/hadoop/BdecParquetReader.java
index 1a92a8cd4..0b97b5b0e 100644
--- a/src/main/java/org/apache/parquet/hadoop/BdecParquetReader.java
+++ b/src/main/java/org/apache/parquet/hadoop/BdecParquetReader.java
@@ -15,6 +15,7 @@
 import org.apache.parquet.ParquetReadOptions;
 import org.apache.parquet.hadoop.api.InitContext;
 import org.apache.parquet.hadoop.api.ReadSupport;
+import org.apache.parquet.hadoop.metadata.ParquetMetadata;
 import org.apache.parquet.io.DelegatingSeekableInputStream;
 import org.apache.parquet.io.InputFile;
 import org.apache.parquet.io.SeekableInputStream;
@@ -40,8 +41,17 @@ public class BdecParquetReader implements AutoCloseable {
    * @throws IOException
    */
   public BdecParquetReader(byte[] data) throws IOException {
+    this(data, data.length);
+  }
+
+  /**
+   * @param data buffer where the data that has to be read resides.
+   * @param length Length of the data to read
+   * @throws IOException
+   */
+  public BdecParquetReader(byte[] data, int length) throws IOException {
     ParquetReadOptions options = ParquetReadOptions.builder().build();
-    ParquetFileReader fileReader = ParquetFileReader.open(new BdecInputFile(data), options);
+    ParquetFileReader fileReader = ParquetFileReader.open(new BdecInputFile(data, length), options);
     reader = new InternalParquetRecordReader<>(new BdecReadSupport(), options.getRecordFilter());
     reader.initialize(fileReader, options);
   }
@@ -86,21 +96,31 @@ public static void readFileIntoWriter(byte[] data, BdecParquetWriter outputWrite
     }
   }
 
+  public static ParquetMetadata readParquetFooter(byte[] data, int length) {
+    try (final ParquetFileReader reader = ParquetFileReader.open(new BdecInputFile(data, length))) {
+      return reader.getFooter();
+    } catch (IOException e) {
+      throw new SFException(ErrorCode.INTERNAL_ERROR, "Failed to read parquet footer", e);
+    }
+  }
+
   private static class BdecInputFile implements InputFile {
     private final byte[] data;
+    private final int length;
 
-    private BdecInputFile(byte[] data) {
+    private BdecInputFile(byte[] data, int length) {
       this.data = data;
+      this.length = length;
     }
 
     @Override
     public long getLength() {
-      return data.length;
+      return length;
     }
 
     @Override
     public SeekableInputStream newStream() {
-      return new BdecSeekableInputStream(new BdecByteArrayInputStream(data));
+      return new BdecSeekableInputStream(new BdecByteArrayInputStream(data, length));
     }
   }
 
@@ -124,8 +144,8 @@ public void seek(long newPos) {
   }
 
   private static class BdecByteArrayInputStream extends ByteArrayInputStream {
-    public BdecByteArrayInputStream(byte[] buf) {
-      super(buf);
+    public BdecByteArrayInputStream(byte[] buf, int length) {
+      super(buf, 0, length);
     }
 
     long getPos() {
diff --git a/src/main/java/org/apache/parquet/hadoop/BdecParquetWriter.java b/src/main/java/org/apache/parquet/hadoop/BdecParquetWriter.java
index 8b71cfd0e..6e55625b8 100644
--- a/src/main/java/org/apache/parquet/hadoop/BdecParquetWriter.java
+++ b/src/main/java/org/apache/parquet/hadoop/BdecParquetWriter.java
@@ -8,6 +8,7 @@
 import java.io.IOException;
 import java.util.List;
 import java.util.Map;
+import java.util.concurrent.atomic.AtomicLong;
 import net.snowflake.ingest.utils.Constants;
 import net.snowflake.ingest.utils.ErrorCode;
 import net.snowflake.ingest.utils.SFException;
@@ -35,6 +36,7 @@ public class BdecParquetWriter implements AutoCloseable {
   private final InternalParquetRecordWriter<List<Object>> writer;
   private final CodecFactory codecFactory;
+  private final AtomicLong rowsWritten = new AtomicLong(0);
 
   /**
    * Creates a BDEC specific parquet writer.
    *
@@ -103,11 +105,16 @@ public BdecParquetWriter(
   public void writeRow(List<Object> row) {
     try {
       writer.write(row);
+      rowsWritten.incrementAndGet();
     } catch (InterruptedException | IOException e) {
       throw new SFException(ErrorCode.INTERNAL_ERROR, "parquet row write failed", e);
     }
   }
 
+  public long getRowsWritten() {
+    return rowsWritten.get();
+  }
+
   @Override
   public void close() throws IOException {
     try {
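
For reference, a minimal usage sketch of the APIs introduced above. Only BdecParquetReader.readParquetFooter(byte[], int), ParquetMetadata.getBlocks(), BlockMetaData.getRowCount(), and BdecParquetWriter.getRowsWritten() come from this change; the wrapper class and method names below are illustrative, not part of the diff.

import org.apache.parquet.hadoop.BdecParquetReader;
import org.apache.parquet.hadoop.BdecParquetWriter;
import org.apache.parquet.hadoop.metadata.BlockMetaData;
import org.apache.parquet.hadoop.metadata.ParquetMetadata;

final class RowCountCheckSketch {
  // Sums the per-row-group row counts from the footer of a serialized chunk. The buffer may be
  // larger than the chunk itself (e.g. padded for encryption), so the valid length is passed
  // explicitly instead of relying on data.length.
  static long footerRowCount(byte[] data, int length) {
    ParquetMetadata footer = BdecParquetReader.readParquetFooter(data, length);
    long rows = 0;
    for (BlockMetaData block : footer.getBlocks()) {
      rows += block.getRowCount();
    }
    return rows;
  }

  // Cross-checks the footer count against the writer-side counter maintained by writeRow().
  static boolean rowCountsAgree(BdecParquetWriter writer, byte[] data, int length) {
    return writer.getRowsWritten() == footerRowCount(data, length);
  }
}

footerRowCount mirrors the summation performed in BlobBuilder.verifyRowCounts, while the getRowsWritten() comparison corresponds to the per-channel check added to serializeFromParquetWriteBuffers.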