From 8c3e7399197573d78cefab169af7016e32780f26 Mon Sep 17 00:00:00 2001 From: Konstantinos Kloudas Date: Thu, 8 Aug 2024 14:23:16 +0000 Subject: [PATCH] Fix PRIMARY_FILE_ID_KEY --- .../streaming/internal/ParquetChunkData.java | 12 ++- .../streaming/internal/RowBufferTest.java | 100 +++++++++++++++--- 2 files changed, 95 insertions(+), 17 deletions(-) diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/ParquetChunkData.java b/src/main/java/net/snowflake/ingest/streaming/internal/ParquetChunkData.java index 16b1ededa..4165079de 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/ParquetChunkData.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/ParquetChunkData.java @@ -5,6 +5,7 @@ package net.snowflake.ingest.streaming.internal; import java.io.ByteArrayOutputStream; +import java.util.HashMap; import java.util.List; import java.util.Map; import org.apache.parquet.hadoop.BdecParquetWriter; @@ -34,6 +35,15 @@ public ParquetChunkData( this.rows = rows; this.parquetWriter = parquetWriter; this.output = output; - this.metadata = metadata; + // create a defensive copy of the parameter map. + this.metadata = createDefensiveCopy(metadata); + } + + private Map createDefensiveCopy(final Map metadata) { + final Map copy = new HashMap<>(metadata); + for (String k: metadata.keySet()) { + copy.put(k, metadata.get(k)); + } + return copy; } } diff --git a/src/test/java/net/snowflake/ingest/streaming/internal/RowBufferTest.java b/src/test/java/net/snowflake/ingest/streaming/internal/RowBufferTest.java index 2a3bb7edd..c59638c47 100644 --- a/src/test/java/net/snowflake/ingest/streaming/internal/RowBufferTest.java +++ b/src/test/java/net/snowflake/ingest/streaming/internal/RowBufferTest.java @@ -3,17 +3,15 @@ import static java.time.ZoneOffset.UTC; import static net.snowflake.ingest.utils.ParameterProvider.MAX_ALLOWED_ROW_SIZE_IN_BYTES_DEFAULT; import static net.snowflake.ingest.utils.ParameterProvider.MAX_CHUNK_SIZE_IN_BYTES_DEFAULT; +import static org.junit.Assert.fail; import java.math.BigDecimal; import java.math.BigInteger; import java.nio.charset.StandardCharsets; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; -import java.util.HashMap; -import java.util.LinkedHashMap; -import java.util.List; -import java.util.Map; +import java.util.*; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicReference; + import net.snowflake.ingest.streaming.InsertValidationResponse; import net.snowflake.ingest.streaming.OpenChannelRequest; import net.snowflake.ingest.utils.Constants; @@ -144,7 +142,7 @@ public void testCollatedColumnsAreRejected() { collatedColumn.setCollation("en-ci"); try { this.rowBufferOnErrorAbort.setupSchema(Collections.singletonList(collatedColumn)); - Assert.fail("Collated columns are not supported"); + fail("Collated columns are not supported"); } catch (SFException e) { Assert.assertEquals(ErrorCode.UNSUPPORTED_DATA_TYPE.getMessageCode(), e.getVendorCode()); } @@ -164,7 +162,7 @@ public void buildFieldErrorStates() { testCol.setPrecision(4); try { this.rowBufferOnErrorContinue.setupSchema(Collections.singletonList(testCol)); - Assert.fail("Expected error"); + fail("Expected error"); } catch (SFException e) { Assert.assertEquals(ErrorCode.UNKNOWN_DATA_TYPE.getMessageCode(), e.getVendorCode()); } @@ -176,7 +174,7 @@ public void buildFieldErrorStates() { testCol.setLogicalType("FIXED"); try { this.rowBufferOnErrorContinue.setupSchema(Collections.singletonList(testCol)); - Assert.fail("Expected error"); + fail("Expected error"); } catch (SFException e) { Assert.assertEquals(ErrorCode.UNKNOWN_DATA_TYPE.getMessageCode(), e.getVendorCode()); } @@ -188,7 +186,7 @@ public void buildFieldErrorStates() { testCol.setLogicalType("TIMESTAMP_NTZ"); try { this.rowBufferOnErrorContinue.setupSchema(Collections.singletonList(testCol)); - Assert.fail("Expected error"); + fail("Expected error"); } catch (SFException e) { Assert.assertEquals(ErrorCode.UNKNOWN_DATA_TYPE.getMessageCode(), e.getVendorCode()); } @@ -200,7 +198,7 @@ public void buildFieldErrorStates() { testCol.setLogicalType("TIMESTAMP_TZ"); try { this.rowBufferOnErrorContinue.setupSchema(Collections.singletonList(testCol)); - Assert.fail("Expected error"); + fail("Expected error"); } catch (SFException e) { Assert.assertEquals(ErrorCode.UNKNOWN_DATA_TYPE.getMessageCode(), e.getVendorCode()); } @@ -212,7 +210,7 @@ public void buildFieldErrorStates() { testCol.setLogicalType("TIME"); try { this.rowBufferOnErrorContinue.setupSchema(Collections.singletonList(testCol)); - Assert.fail("Expected error"); + fail("Expected error"); } catch (SFException e) { Assert.assertEquals(ErrorCode.UNKNOWN_DATA_TYPE.getMessageCode(), e.getVendorCode()); } @@ -244,7 +242,7 @@ public void testInvalidLogicalType() { try { this.rowBufferOnErrorContinue.setupSchema(Collections.singletonList(colInvalidLogical)); - Assert.fail("Setup should fail if invalid column metadata is provided"); + fail("Setup should fail if invalid column metadata is provided"); } catch (SFException e) { Assert.assertEquals(ErrorCode.UNKNOWN_DATA_TYPE.getMessageCode(), e.getVendorCode()); // Do nothing @@ -264,7 +262,7 @@ public void testInvalidPhysicalType() { try { this.rowBufferOnErrorContinue.setupSchema(Collections.singletonList(colInvalidPhysical)); - Assert.fail("Setup should fail if invalid column metadata is provided"); + fail("Setup should fail if invalid column metadata is provided"); } catch (SFException e) { Assert.assertEquals(e.getVendorCode(), ErrorCode.UNKNOWN_DATA_TYPE.getMessageCode()); } @@ -630,7 +628,7 @@ public void testInvalidEPInfo() { try { AbstractRowBuffer.buildEpInfoFromStats(1, colStats); - Assert.fail("should fail when row count is smaller than null count."); + fail("should fail when row count is smaller than null count."); } catch (SFException e) { Assert.assertEquals(ErrorCode.INTERNAL_ERROR.getMessageCode(), e.getVendorCode()); } @@ -1725,4 +1723,74 @@ public void testOnErrorAbortRowsWithError() { Assert.assertEquals(1, snapshotAbortParquet.size()); Assert.assertEquals(Arrays.asList("a"), snapshotAbortParquet.get(0)); } + + @Test + public void testParquetChunkMetadataCreationIsThreadSafe() throws InterruptedException { + final String testFileA = "testFileA"; + final String testFileB = "testFileB"; + + final ParquetRowBuffer bufferUnderTest = + (ParquetRowBuffer) createTestBuffer(OpenChannelRequest.OnErrorOption.CONTINUE); + + final ColumnMetadata colChar = new ColumnMetadata(); + colChar.setOrdinal(1); + colChar.setName("COLCHAR"); + colChar.setPhysicalType("LOB"); + colChar.setNullable(true); + colChar.setLogicalType("TEXT"); + colChar.setByteLength(14); + colChar.setLength(11); + colChar.setScale(0); + + bufferUnderTest.setupSchema(Collections.singletonList(colChar)); + + loadData(bufferUnderTest, Collections.singletonMap("colChar", "a")); + + final CountDownLatch latch = new CountDownLatch(1); + final AtomicReference> firstFlushResult = new AtomicReference<>(); + final Thread t = getThreadThatWaitsForLockReleaseAndFlushes(bufferUnderTest, testFileA, latch, firstFlushResult); + t.start(); + + final ChannelData secondFlushResult = bufferUnderTest.flush(testFileB); + Assert.assertEquals(testFileB, getPrimaryFileId(secondFlushResult)); + + latch.countDown(); + t.join(); + + Assert.assertNotNull(firstFlushResult.get()); + Assert.assertEquals(testFileA, getPrimaryFileId(firstFlushResult.get())); + Assert.assertEquals(testFileB, getPrimaryFileId(secondFlushResult)); + } + + private static Thread getThreadThatWaitsForLockReleaseAndFlushes( + final ParquetRowBuffer bufferUnderTest, + final String filenameToFlush, + final CountDownLatch latch, + final AtomicReference> flushResult) { + return new Thread(() -> { + try { + latch.await(); + } catch (InterruptedException e) { + fail("Thread was unexpectedly interrupted"); + } + + final ChannelData flush = + loadData(bufferUnderTest, Collections.singletonMap("colChar", "b")) + .flush(filenameToFlush); + flushResult.set(flush); + }); + } + + private static ParquetRowBuffer loadData(final ParquetRowBuffer bufferToLoad, final Map data) { + final List> validRows = new ArrayList<>(); + validRows.add(data); + + final InsertValidationResponse nResponse = bufferToLoad.insertRows(validRows, "1", "1"); + Assert.assertFalse(nResponse.hasErrors()); + return bufferToLoad; + } + + private static String getPrimaryFileId(final ChannelData chunkData) { + return chunkData.getVectors().metadata.get(Constants.PRIMARY_FILE_ID_KEY); + } }