diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/BlobBuilder.java b/src/main/java/net/snowflake/ingest/streaming/internal/BlobBuilder.java
index b88090e01..8c489c4ff 100644
--- a/src/main/java/net/snowflake/ingest/streaming/internal/BlobBuilder.java
+++ b/src/main/java/net/snowflake/ingest/streaming/internal/BlobBuilder.java
@@ -23,15 +23,21 @@
 import java.security.NoSuchAlgorithmException;
 import java.util.ArrayList;
 import java.util.List;
+import java.util.stream.Collectors;
 import java.util.zip.CRC32;
 import javax.crypto.BadPaddingException;
 import javax.crypto.IllegalBlockSizeException;
 import javax.crypto.NoSuchPaddingException;
 import net.snowflake.ingest.utils.Constants;
 import net.snowflake.ingest.utils.Cryptor;
+import net.snowflake.ingest.utils.ErrorCode;
 import net.snowflake.ingest.utils.Logging;
 import net.snowflake.ingest.utils.Pair;
+import net.snowflake.ingest.utils.SFException;
 import org.apache.commons.codec.binary.Hex;
+import org.apache.parquet.hadoop.BdecParquetReader;
+import org.apache.parquet.hadoop.metadata.BlockMetaData;
+import org.apache.parquet.hadoop.metadata.ParquetMetadata;
 
 /**
  * Build a single blob file that contains file header plus data. The header will be a
@@ -89,6 +95,8 @@ static Blob constructBlobAndMetadata(
         byte[] paddedChunkData = paddedChunk.getFirst();
         int paddedChunkLength = paddedChunk.getSecond();
 
+        verifyRowCounts(serializedChunk, channelsDataPerTable, paddedChunkData, paddedChunkLength);
+
         // Encrypt the compressed chunk data, the encryption key is derived using the key from
         // server with the full blob path.
         // We need to maintain IV as a block counter for the whole file, even interleaved,
@@ -152,6 +160,50 @@ static Blob constructBlobAndMetadata(
     return new Blob(blobBytes, chunksMetadataList, new BlobStats());
   }
 
+  /**
+   * Safety check to verify whether the number of rows in the parquet footer matches the number of
+   * rows in metadata
+   */
+  static <T> void verifyRowCounts(
+      Flusher.SerializationResult serializationResult,
+      List<ChannelData<T>> channelsDataPerTable,
+      byte[] paddedChunkData,
+      int chunkLength) {
+    final ParquetMetadata metadata =
+        BdecParquetReader.readParquetFooter(paddedChunkData, chunkLength);
+    long totalRowCountFromParquetFooter = 0;
+    for (BlockMetaData blockMetaData : metadata.getBlocks()) {
+      totalRowCountFromParquetFooter += blockMetaData.getRowCount();
+    }
+
+    if (totalRowCountFromParquetFooter != serializationResult.rowCount) {
+      final String perChannelRowCounts =
+          channelsDataPerTable.stream()
+              .map(x -> String.valueOf(x.getRowCount()))
+              .collect(Collectors.joining(","));
+
+      final String perFooterBlockRowCounts =
+          metadata.getBlocks().stream()
+              .map(x -> String.valueOf(x.getRowCount()))
+              .collect(Collectors.joining(","));
+
+      throw new SFException(
+          ErrorCode.INTERNAL_ERROR,
+          String.format(
+              "The number of rows in Parquet footer does not match the number of rows in metadata. "
+                  + "totalRowCountFromParquetFooter=%d "
+                  + "totalRowCountFromSerializationResult=%d "
+                  + "channelCountFromSerializationResult=%d "
+                  + "perChannelRowCountsFromChannelData=%s "
+                  + "perFooterBlockRowCounts=%s",
+              totalRowCountFromParquetFooter,
+              serializationResult.rowCount,
+              serializationResult.channelsMetadataList.size(),
+              perChannelRowCounts,
+              perFooterBlockRowCounts));
+    }
+  }
+
   /**
    * Pad the compressed data for encryption. Encryption needs padding to the
    * ENCRYPTION_ALGORITHM_BLOCK_SIZE_BYTES to align with decryption on the Snowflake query path
diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/ParquetFlusher.java b/src/main/java/net/snowflake/ingest/streaming/internal/ParquetFlusher.java
index 39ec66dbb..32912e11a 100644
--- a/src/main/java/net/snowflake/ingest/streaming/internal/ParquetFlusher.java
+++ b/src/main/java/net/snowflake/ingest/streaming/internal/ParquetFlusher.java
@@ -67,7 +67,10 @@ private SerializationResult serializeFromParquetWriteBuffers(
     ByteArrayOutputStream mergedChunkData = new ByteArrayOutputStream();
     Pair<Long, Long> chunkMinMaxInsertTimeInMs = null;
 
-    for (ChannelData<ParquetChunkData> data : channelsDataPerTable) {
+    for (int channelPosition = 0;
+        channelPosition < channelsDataPerTable.size();
+        channelPosition++) {
+      final ChannelData<ParquetChunkData> data = channelsDataPerTable.get(channelPosition);
       // Create channel metadata
       ChannelMetadata channelMetadata =
           ChannelMetadata.builder()
@@ -112,6 +115,23 @@ private SerializationResult serializeFromParquetWriteBuffers(
       }
 
       rowCount += data.getRowCount();
+
+      // We check if the number of rows collectively written for channels encountered so far matches
+      // the number of rows in metadata
+      if (mergedChannelWriter.getRowsWritten() != rowCount) {
+        throw new SFException(
+            ErrorCode.INTERNAL_ERROR,
+            String.format(
+                "[serializeFromParquetWriteBuffers] The actual number of rows and the number of"
+                    + " rows in channel metadata do not match channelName=%s rowCountWritten=%d"
+                    + " rowCountInMetadata=%d currentChannelPosition=%d totalChannelCount=%d",
+                data.getChannelContext().getName(),
+                mergedChannelWriter.getRowsWritten(),
+                rowCount,
+                channelPosition,
+                channelsDataPerTable.size()));
+      }
+
       chunkEstimatedUncompressedSize += data.getBufferSize();
 
       logger.logDebug(
@@ -147,7 +167,10 @@ private SerializationResult serializeFromJavaObjects(
     ByteArrayOutputStream mergedData = new ByteArrayOutputStream();
     Pair<Long, Long> chunkMinMaxInsertTimeInMs = null;
 
-    for (ChannelData<ParquetChunkData> data : channelsDataPerTable) {
+    for (int channelPosition = 0;
+        channelPosition < channelsDataPerTable.size();
+        channelPosition++) {
+      final ChannelData<ParquetChunkData> data = channelsDataPerTable.get(channelPosition);
       // Create channel metadata
       ChannelMetadata channelMetadata =
           ChannelMetadata.builder()
@@ -189,6 +212,22 @@ private SerializationResult serializeFromJavaObjects(
                 chunkMinMaxInsertTimeInMs, data.getMinMaxInsertTimeInMs());
       }
 
+      // We check if the number of rows in the current channel matches the number of rows in
+      // metadata
+      if (data.getVectors().rows.size() != data.getRowCount()) {
+        throw new SFException(
+            ErrorCode.INTERNAL_ERROR,
+            String.format(
+                "[serializeFromJavaObjects] The actual number of rows and the number of rows in"
+                    + " channel metadata do not match channelName=%s actualRowCount=%d"
+                    + " rowCountInMetadata=%d currentChannelPosition=%d totalChannelCount=%d",
+                data.getChannelContext().getName(),
+                data.getVectors().rows.size(),
+                data.getRowCount(),
+                channelPosition,
+                channelsDataPerTable.size()));
+      }
+
       rows.addAll(data.getVectors().rows);
 
       rowCount += data.getRowCount();
diff --git a/src/main/java/org/apache/parquet/hadoop/BdecParquetReader.java b/src/main/java/org/apache/parquet/hadoop/BdecParquetReader.java
index 1a92a8cd4..0b97b5b0e 100644
--- a/src/main/java/org/apache/parquet/hadoop/BdecParquetReader.java
+++ b/src/main/java/org/apache/parquet/hadoop/BdecParquetReader.java
@@ -15,6 +15,7 @@
 import org.apache.parquet.ParquetReadOptions;
 import org.apache.parquet.hadoop.api.InitContext;
 import org.apache.parquet.hadoop.api.ReadSupport;
+import org.apache.parquet.hadoop.metadata.ParquetMetadata;
 import org.apache.parquet.io.DelegatingSeekableInputStream;
 import org.apache.parquet.io.InputFile;
 import org.apache.parquet.io.SeekableInputStream;
@@ -40,8 +41,17 @@ public class BdecParquetReader implements AutoCloseable {
    * @throws IOException
    */
   public BdecParquetReader(byte[] data) throws IOException {
+    this(data, data.length);
+  }
+
+  /**
+   * @param data buffer where the data that has to be read resides.
+   * @param length Length of the data to read
+   * @throws IOException
+   */
+  public BdecParquetReader(byte[] data, int length) throws IOException {
     ParquetReadOptions options = ParquetReadOptions.builder().build();
-    ParquetFileReader fileReader = ParquetFileReader.open(new BdecInputFile(data), options);
+    ParquetFileReader fileReader = ParquetFileReader.open(new BdecInputFile(data, length), options);
     reader = new InternalParquetRecordReader<>(new BdecReadSupport(), options.getRecordFilter());
     reader.initialize(fileReader, options);
   }
@@ -86,21 +96,31 @@ public static void readFileIntoWriter(byte[] data, BdecParquetWriter outputWrite
     }
   }
 
+  public static ParquetMetadata readParquetFooter(byte[] data, int length) {
+    try (final ParquetFileReader reader = ParquetFileReader.open(new BdecInputFile(data, length))) {
+      return reader.getFooter();
+    } catch (IOException e) {
+      throw new SFException(ErrorCode.INTERNAL_ERROR, "Failed to read parquet footer", e);
+    }
+  }
+
   private static class BdecInputFile implements InputFile {
     private final byte[] data;
+    private final int length;
 
-    private BdecInputFile(byte[] data) {
+    private BdecInputFile(byte[] data, int length) {
       this.data = data;
+      this.length = length;
     }
 
     @Override
     public long getLength() {
-      return data.length;
+      return length;
     }
 
     @Override
     public SeekableInputStream newStream() {
-      return new BdecSeekableInputStream(new BdecByteArrayInputStream(data));
+      return new BdecSeekableInputStream(new BdecByteArrayInputStream(data, length));
     }
   }
 
@@ -124,8 +144,8 @@ public void seek(long newPos) {
   }
 
   private static class BdecByteArrayInputStream extends ByteArrayInputStream {
-    public BdecByteArrayInputStream(byte[] buf) {
-      super(buf);
+    public BdecByteArrayInputStream(byte[] buf, int length) {
+      super(buf, 0, length);
     }
 
     long getPos() {
diff --git a/src/main/java/org/apache/parquet/hadoop/BdecParquetWriter.java b/src/main/java/org/apache/parquet/hadoop/BdecParquetWriter.java
index 8b71cfd0e..6e55625b8 100644
--- a/src/main/java/org/apache/parquet/hadoop/BdecParquetWriter.java
+++ b/src/main/java/org/apache/parquet/hadoop/BdecParquetWriter.java
@@ -8,6 +8,7 @@
 import java.io.IOException;
 import java.util.List;
 import java.util.Map;
+import java.util.concurrent.atomic.AtomicLong;
 import net.snowflake.ingest.utils.Constants;
 import net.snowflake.ingest.utils.ErrorCode;
 import net.snowflake.ingest.utils.SFException;
@@ -35,6 +36,7 @@ public class BdecParquetWriter implements AutoCloseable {
   private final InternalParquetRecordWriter<List<Object>> writer;
   private final CodecFactory codecFactory;
+  private final AtomicLong rowsWritten = new AtomicLong(0);
 
   /**
    * Creates a BDEC specific parquet writer.
@@ -103,11 +105,16 @@ public BdecParquetWriter(
   public void writeRow(List<Object> row) {
     try {
       writer.write(row);
+      rowsWritten.incrementAndGet();
     } catch (InterruptedException | IOException e) {
       throw new SFException(ErrorCode.INTERNAL_ERROR, "parquet row write failed", e);
     }
   }
 
+  public long getRowsWritten() {
+    return rowsWritten.get();
+  }
+
   @Override
   public void close() throws IOException {
     try {
diff --git a/src/test/java/net/snowflake/ingest/streaming/internal/BlobBuilderTest.java b/src/test/java/net/snowflake/ingest/streaming/internal/BlobBuilderTest.java
new file mode 100644
index 000000000..30aac21ab
--- /dev/null
+++ b/src/test/java/net/snowflake/ingest/streaming/internal/BlobBuilderTest.java
@@ -0,0 +1,165 @@
+package net.snowflake.ingest.streaming.internal;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import net.snowflake.ingest.utils.Constants;
+import net.snowflake.ingest.utils.ErrorCode;
+import net.snowflake.ingest.utils.Pair;
+import net.snowflake.ingest.utils.SFException;
+import org.apache.parquet.hadoop.BdecParquetWriter;
+import org.apache.parquet.schema.MessageType;
+import org.junit.Assert;
+import org.junit.Test;
+import org.mockito.Mockito;
+
+public class BlobBuilderTest {
+
+  @Test
+  public void testParquetFooterWrongValue() throws Exception {
+    List<ChannelData<ParquetChunkData>> channelDataPerTable = createChannelDataPerTable(1, false);
+    ParquetFlusher flusher =
+        new ParquetFlusher(createSchema("C1"), false, 1000, Constants.BdecParquetCompression.GZIP);
+    Flusher.SerializationResult serializationResult =
+        flusher.serialize(channelDataPerTable, "a.bdec");
+
+    byte[] plainChunk = serializationResult.chunkData.toByteArray();
+    byte[] paddedChunk = new byte[plainChunk.length + 10];
+    System.arraycopy(plainChunk, 0, paddedChunk, 0, plainChunk.length);
+
+    // Create a new serializationResult with wrong rowCount
+    Flusher.SerializationResult serializationResultIncorrect =
+        new Flusher.SerializationResult(
+            serializationResult.channelsMetadataList,
+            serializationResult.columnEpStatsMapCombined,
+            serializationResult.rowCount + 1,
+            serializationResult.chunkEstimatedUncompressedSize,
+            serializationResult.chunkData,
+            serializationResult.chunkMinMaxInsertTimeInMs);
+
+    try {
+      BlobBuilder.verifyRowCounts(
+          serializationResultIncorrect, channelDataPerTable, paddedChunk, plainChunk.length);
+      Assert.fail("Should not pass enableParquetInternalBuffering=false");
+    } catch (SFException e) {
+      Assert.assertEquals(ErrorCode.INTERNAL_ERROR.getMessageCode(), e.getVendorCode());
+      Assert.assertTrue(e.getMessage().contains("totalRowCountFromParquetFooter=1"));
+      Assert.assertTrue(e.getMessage().contains("totalRowCountFromSerializationResult=2"));
+      Assert.assertTrue(e.getMessage().contains("channelCountFromSerializationResult=1"));
+      Assert.assertTrue(e.getMessage().contains("perChannelRowCountsFromChannelData=1"));
+      Assert.assertTrue(e.getMessage().contains("perFooterBlockRowCounts=1"));
+    }
+  }
+
+  @Test
+  public void testSerializationErrors() throws Exception {
+    // Construction succeeds if both data and metadata contain 1 row
+    BlobBuilder.constructBlobAndMetadata(
+        "a.bdec",
+        Collections.singletonList(createChannelDataPerTable(1, false)),
+        Constants.BdecVersion.THREE);
+    BlobBuilder.constructBlobAndMetadata(
+        "a.bdec",
+        Collections.singletonList(createChannelDataPerTable(1, true)),
+        Constants.BdecVersion.THREE);
+
+    // Construction fails if metadata contains 0 rows and data 1 row
+    try {
+      BlobBuilder.constructBlobAndMetadata(
+          "a.bdec",
+          Collections.singletonList(createChannelDataPerTable(0, false)),
+          Constants.BdecVersion.THREE);
+      Assert.fail("Should not pass enableParquetInternalBuffering=false");
+    } catch (SFException e) {
+      Assert.assertEquals(ErrorCode.INTERNAL_ERROR.getMessageCode(), e.getVendorCode());
+      Assert.assertTrue(e.getMessage().contains("serializeFromJavaObjects"));
+      Assert.assertTrue(
+          e.getMessage()
+              .contains(
+                  "channelName=channel1 actualRowCount=1 rowCountInMetadata=0"
+                      + " currentChannelPosition=0 totalChannelCount=1"));
+    }
+
+    try {
+      BlobBuilder.constructBlobAndMetadata(
+          "a.bdec",
+          Collections.singletonList(createChannelDataPerTable(0, true)),
+          Constants.BdecVersion.THREE);
+      Assert.fail("Should not pass enableParquetInternalBuffering=true");
+    } catch (SFException e) {
+      Assert.assertEquals(ErrorCode.INTERNAL_ERROR.getMessageCode(), e.getVendorCode());
+      Assert.assertTrue(e.getMessage().contains("serializeFromParquetWriteBuffers"));
+      Assert.assertTrue(
+          e.getMessage()
+              .contains(
+                  "channelName=channel1 rowCountWritten=1 rowCountInMetadata=0"
+                      + " currentChannelPosition=0 totalChannelCount=1"));
+    }
+  }
+
+  /**
+   * Creates channel data with a configurable number of rows in metadata and 1 physical row (using
+   * both with and without internal buffering optimization)
+   */
+  private List<ChannelData<ParquetChunkData>> createChannelDataPerTable(
+      int metadataRowCount, boolean enableParquetInternalBuffering) throws IOException {
+    String columnName = "C1";
+    ChannelData<ParquetChunkData> channelData = Mockito.spy(new ChannelData<>());
+    MessageType schema = createSchema(columnName);
+    Mockito.doReturn(
+            new ParquetFlusher(
+                schema,
+                enableParquetInternalBuffering,
+                100L,
+                Constants.BdecParquetCompression.GZIP))
+        .when(channelData)
+        .createFlusher();
+
+    channelData.setRowSequencer(1L);
+    ByteArrayOutputStream stream = new ByteArrayOutputStream();
+    BdecParquetWriter bdecParquetWriter =
+        new BdecParquetWriter(
+            stream,
+            schema,
+            new HashMap<>(),
+            "CHANNEL",
+            1000,
+            Constants.BdecParquetCompression.GZIP);
+    bdecParquetWriter.writeRow(Collections.singletonList("1"));
+    channelData.setVectors(
+        new ParquetChunkData(
+            Collections.singletonList(Collections.singletonList("A")),
+            bdecParquetWriter,
+            stream,
+            new HashMap<>()));
+    channelData.setColumnEps(new HashMap<>());
+    channelData.setRowCount(metadataRowCount);
+    channelData.setMinMaxInsertTimeInMs(new Pair<>(2L, 3L));
+
+    channelData.getColumnEps().putIfAbsent(columnName, new RowBufferStats(columnName, null, 1));
+    channelData.setChannelContext(
+        new ChannelFlushContext("channel1", "DB", "SCHEMA", "TABLE", 1L, "enc", 1L));
+    return Collections.singletonList(channelData);
+  }
+
+  private static MessageType createSchema(String columnName) {
+    ParquetTypeGenerator.ParquetTypeInfo c1 =
+        ParquetTypeGenerator.generateColumnParquetTypeInfo(createTestTextColumn(columnName), 1);
+    return new MessageType("bdec", Collections.singletonList(c1.getParquetType()));
+  }
+
+  private static ColumnMetadata createTestTextColumn(String name) {
+    ColumnMetadata colChar = new ColumnMetadata();
+    colChar.setOrdinal(1);
+    colChar.setName(name);
+    colChar.setPhysicalType("LOB");
+    colChar.setNullable(true);
+    colChar.setLogicalType("TEXT");
+    colChar.setByteLength(14);
+    colChar.setLength(11);
+    colChar.setScale(0);
+    return colChar;
+  }
+}
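
Note (not part of the diff): the footer-based safety net added in BlobBuilder.verifyRowCounts can also be exercised on its own through the new BdecParquetReader.readParquetFooter helper. The following is a minimal sketch under that assumption; the class name and the expectedRowCount parameter are illustrative only, while readParquetFooter, ParquetMetadata, and BlockMetaData come from the changes above and from parquet-mr.

// Illustrative sketch only: sums the per-row-group row counts reported by the Parquet
// footer of a serialized chunk and compares them to the row count the caller expects
// from channel metadata, mirroring what BlobBuilder.verifyRowCounts enforces.
import org.apache.parquet.hadoop.BdecParquetReader;
import org.apache.parquet.hadoop.metadata.BlockMetaData;
import org.apache.parquet.hadoop.metadata.ParquetMetadata;

public class FooterRowCountCheckSketch {
  static boolean footerMatchesExpectedRowCount(
      byte[] chunkData, int chunkLength, long expectedRowCount) {
    // Read only the footer; no row data is deserialized.
    ParquetMetadata footer = BdecParquetReader.readParquetFooter(chunkData, chunkLength);
    long totalRowsInFooter = 0;
    for (BlockMetaData block : footer.getBlocks()) {
      totalRowsInFooter += block.getRowCount();
    }
    return totalRowsInFooter == expectedRowCount;
  }
}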