diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/ParquetFlusher.java b/src/main/java/net/snowflake/ingest/streaming/internal/ParquetFlusher.java index 39ec66dbb..771e5f71a 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/ParquetFlusher.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/ParquetFlusher.java @@ -9,6 +9,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; +import java.util.stream.Collectors; import net.snowflake.ingest.utils.Constants; import net.snowflake.ingest.utils.ErrorCode; import net.snowflake.ingest.utils.Logging; @@ -124,6 +125,8 @@ private SerializationResult serializeFromParquetWriteBuffers( if (mergedChannelWriter != null) { mergedChannelWriter.close(); + this.verifyRowCounts( + "serializeFromParquetWriteBuffers", mergedChannelWriter, channelsDataPerTable, -1); } return new SerializationResult( channelsMetadataList, @@ -216,6 +219,9 @@ private SerializationResult serializeFromJavaObjects( rows.forEach(parquetWriter::writeRow); parquetWriter.close(); + this.verifyRowCounts( + "serializeFromJavaObjects", parquetWriter, channelsDataPerTable, rows.size()); + return new SerializationResult( channelsMetadataList, columnEpStatsMapCombined, @@ -224,4 +230,64 @@ private SerializationResult serializeFromJavaObjects( mergedData, chunkMinMaxInsertTimeInMs); } + + /** + * Validates that rows count in metadata matches the row count in Parquet footer and the row count + * written by the parquet writer + * + * @param serializationType Serialization type, used for logging purposes only + * @param writer Parquet writer writing the data + * @param channelsDataPerTable Channel data + * @param javaSerializationTotalRowCount Total row count when java object serialization is used. + * Used only for logging purposes if there is a mismatch. + */ + private void verifyRowCounts( + String serializationType, + BdecParquetWriter writer, + List> channelsDataPerTable, + long javaSerializationTotalRowCount) { + long parquetTotalRowsWritten = writer.getRowsWritten(); + + List parquetFooterRowsPerBlock = writer.getRowCountFromFooter(); + long parquetTotalRowsInFooter = 0; + for (long perBlockCount : parquetFooterRowsPerBlock) parquetTotalRowsInFooter += perBlockCount; + + long totalRowsInMetadata = 0; + for (ChannelData channelData : channelsDataPerTable) + totalRowsInMetadata += channelData.getRowCount(); + + if (parquetTotalRowsInFooter != totalRowsInMetadata + || parquetTotalRowsWritten != totalRowsInMetadata) { + + final String perChannelRowCountsInMetadata = + channelsDataPerTable.stream() + .map(x -> String.valueOf(x.getRowCount())) + .collect(Collectors.joining(",")); + + final String perBlockRowCountsInFooter = + parquetFooterRowsPerBlock.stream().map(String::valueOf).collect(Collectors.joining(",")); + + final long channelsCountInMetadata = channelsDataPerTable.size(); + + throw new SFException( + ErrorCode.INTERNAL_ERROR, + String.format( + "[%s]The number of rows in Parquet does not match the number of rows in metadata. " + + "parquetTotalRowsInFooter=%d " + + "totalRowsInMetadata=%d " + + "parquetTotalRowsWritten=%d " + + "perChannelRowCountsInMetadata=%s " + + "perBlockRowCountsInFooter=%s " + + "channelsCountInMetadata=%d " + + "countOfSerializedJavaObjects=%d", + serializationType, + parquetTotalRowsInFooter, + totalRowsInMetadata, + parquetTotalRowsWritten, + perChannelRowCountsInMetadata, + perBlockRowCountsInFooter, + channelsCountInMetadata, + javaSerializationTotalRowCount)); + } + } } diff --git a/src/main/java/org/apache/parquet/hadoop/BdecParquetWriter.java b/src/main/java/org/apache/parquet/hadoop/BdecParquetWriter.java index 8b71cfd0e..6bb8837bd 100644 --- a/src/main/java/org/apache/parquet/hadoop/BdecParquetWriter.java +++ b/src/main/java/org/apache/parquet/hadoop/BdecParquetWriter.java @@ -6,8 +6,10 @@ import java.io.ByteArrayOutputStream; import java.io.IOException; +import java.util.ArrayList; import java.util.List; import java.util.Map; +import java.util.concurrent.atomic.AtomicLong; import net.snowflake.ingest.utils.Constants; import net.snowflake.ingest.utils.ErrorCode; import net.snowflake.ingest.utils.SFException; @@ -17,6 +19,7 @@ import org.apache.parquet.column.values.factory.DefaultV1ValuesWriterFactory; import org.apache.parquet.crypto.FileEncryptionProperties; import org.apache.parquet.hadoop.api.WriteSupport; +import org.apache.parquet.hadoop.metadata.BlockMetaData; import org.apache.parquet.io.DelegatingPositionOutputStream; import org.apache.parquet.io.OutputFile; import org.apache.parquet.io.ParquetEncodingException; @@ -35,6 +38,7 @@ public class BdecParquetWriter implements AutoCloseable { private final InternalParquetRecordWriter> writer; private final CodecFactory codecFactory; + private long rowsWritten = 0; /** * Creates a BDEC specific parquet writer. @@ -100,14 +104,28 @@ public BdecParquetWriter( encodingProps); } + /** @return List of row counts per block stored in the parquet footer */ + public List getRowCountFromFooter() { + final List blockRowCounts = new ArrayList<>(); + for (BlockMetaData metadata : writer.getFooter().getBlocks()) { + blockRowCounts.add(metadata.getRowCount()); + } + return blockRowCounts; + } + public void writeRow(List row) { try { writer.write(row); + rowsWritten++; } catch (InterruptedException | IOException e) { throw new SFException(ErrorCode.INTERNAL_ERROR, "parquet row write failed", e); } } + public long getRowsWritten() { + return rowsWritten; + } + @Override public void close() throws IOException { try { diff --git a/src/test/java/net/snowflake/ingest/streaming/internal/BlobBuilderTest.java b/src/test/java/net/snowflake/ingest/streaming/internal/BlobBuilderTest.java new file mode 100644 index 000000000..6f9fd2dc2 --- /dev/null +++ b/src/test/java/net/snowflake/ingest/streaming/internal/BlobBuilderTest.java @@ -0,0 +1,133 @@ +package net.snowflake.ingest.streaming.internal; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import net.snowflake.ingest.utils.Constants; +import net.snowflake.ingest.utils.ErrorCode; +import net.snowflake.ingest.utils.Pair; +import net.snowflake.ingest.utils.SFException; +import org.apache.parquet.hadoop.BdecParquetWriter; +import org.apache.parquet.schema.MessageType; +import org.junit.Assert; +import org.junit.Test; +import org.mockito.Mockito; + +public class BlobBuilderTest { + + @Test + public void testSerializationErrors() throws Exception { + // Construction succeeds if both data and metadata contain 1 row + BlobBuilder.constructBlobAndMetadata( + "a.bdec", + Collections.singletonList(createChannelDataPerTable(1, false)), + Constants.BdecVersion.THREE); + BlobBuilder.constructBlobAndMetadata( + "a.bdec", + Collections.singletonList(createChannelDataPerTable(1, true)), + Constants.BdecVersion.THREE); + + // Construction fails if metadata contains 0 rows and data 1 row + try { + BlobBuilder.constructBlobAndMetadata( + "a.bdec", + Collections.singletonList(createChannelDataPerTable(0, false)), + Constants.BdecVersion.THREE); + Assert.fail("Should not pass enableParquetInternalBuffering=false"); + } catch (SFException e) { + Assert.assertEquals(ErrorCode.INTERNAL_ERROR.getMessageCode(), e.getVendorCode()); + Assert.assertTrue(e.getMessage().contains("serializeFromJavaObjects")); + Assert.assertTrue(e.getMessage().contains("parquetTotalRowsInFooter=1")); + Assert.assertTrue(e.getMessage().contains("totalRowsInMetadata=0")); + Assert.assertTrue(e.getMessage().contains("parquetTotalRowsWritten=1")); + Assert.assertTrue(e.getMessage().contains("perChannelRowCountsInMetadata=0")); + Assert.assertTrue(e.getMessage().contains("perBlockRowCountsInFooter=1")); + Assert.assertTrue(e.getMessage().contains("channelsCountInMetadata=1")); + Assert.assertTrue(e.getMessage().contains("countOfSerializedJavaObjects=1")); + } + + try { + BlobBuilder.constructBlobAndMetadata( + "a.bdec", + Collections.singletonList(createChannelDataPerTable(0, true)), + Constants.BdecVersion.THREE); + Assert.fail("Should not pass enableParquetInternalBuffering=true"); + } catch (SFException e) { + Assert.assertEquals(ErrorCode.INTERNAL_ERROR.getMessageCode(), e.getVendorCode()); + Assert.assertTrue(e.getMessage().contains("serializeFromParquetWriteBuffers")); + Assert.assertTrue(e.getMessage().contains("parquetTotalRowsInFooter=1")); + Assert.assertTrue(e.getMessage().contains("totalRowsInMetadata=0")); + Assert.assertTrue(e.getMessage().contains("parquetTotalRowsWritten=1")); + Assert.assertTrue(e.getMessage().contains("perChannelRowCountsInMetadata=0")); + Assert.assertTrue(e.getMessage().contains("perBlockRowCountsInFooter=1")); + Assert.assertTrue(e.getMessage().contains("channelsCountInMetadata=1")); + Assert.assertTrue(e.getMessage().contains("countOfSerializedJavaObjects=-1")); + } + } + + /** + * Creates a channel data configurable number of rows in metadata and 1 physical row (using both + * with and without internal buffering optimization) + */ + private List> createChannelDataPerTable( + int metadataRowCount, boolean enableParquetInternalBuffering) throws IOException { + String columnName = "C1"; + ChannelData channelData = Mockito.spy(new ChannelData<>()); + MessageType schema = createSchema(columnName); + Mockito.doReturn( + new ParquetFlusher( + schema, + enableParquetInternalBuffering, + 100L, + Constants.BdecParquetCompression.GZIP)) + .when(channelData) + .createFlusher(); + + channelData.setRowSequencer(1L); + ByteArrayOutputStream stream = new ByteArrayOutputStream(); + BdecParquetWriter bdecParquetWriter = + new BdecParquetWriter( + stream, + schema, + new HashMap<>(), + "CHANNEL", + 1000, + Constants.BdecParquetCompression.GZIP); + bdecParquetWriter.writeRow(Collections.singletonList("1")); + channelData.setVectors( + new ParquetChunkData( + Collections.singletonList(Collections.singletonList("A")), + bdecParquetWriter, + stream, + new HashMap<>())); + channelData.setColumnEps(new HashMap<>()); + channelData.setRowCount(metadataRowCount); + channelData.setMinMaxInsertTimeInMs(new Pair<>(2L, 3L)); + + channelData.getColumnEps().putIfAbsent(columnName, new RowBufferStats(columnName, null, 1)); + channelData.setChannelContext( + new ChannelFlushContext("channel1", "DB", "SCHEMA", "TABLE", 1L, "enc", 1L)); + return Collections.singletonList(channelData); + } + + private static MessageType createSchema(String columnName) { + ParquetTypeGenerator.ParquetTypeInfo c1 = + ParquetTypeGenerator.generateColumnParquetTypeInfo(createTestTextColumn(columnName), 1); + return new MessageType("bdec", Collections.singletonList(c1.getParquetType())); + } + + private static ColumnMetadata createTestTextColumn(String name) { + ColumnMetadata colChar = new ColumnMetadata(); + colChar.setOrdinal(1); + colChar.setName(name); + colChar.setPhysicalType("LOB"); + colChar.setNullable(true); + colChar.setLogicalType("TEXT"); + colChar.setByteLength(14); + colChar.setLength(11); + colChar.setScale(0); + return colChar; + } +}