diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/ParquetChunkData.java b/src/main/java/net/snowflake/ingest/streaming/internal/ParquetChunkData.java index 16b1ededa..9950c44aa 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/ParquetChunkData.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/ParquetChunkData.java @@ -5,6 +5,7 @@ package net.snowflake.ingest.streaming.internal; import java.io.ByteArrayOutputStream; +import java.util.HashMap; import java.util.List; import java.util.Map; import org.apache.parquet.hadoop.BdecParquetWriter; @@ -34,6 +35,16 @@ public ParquetChunkData( this.rows = rows; this.parquetWriter = parquetWriter; this.output = output; - this.metadata = metadata; + // create a defensive copy of the parameter map because the argument map passed here + // may currently be shared across multiple threads. + this.metadata = createDefensiveCopy(metadata); + } + + private Map createDefensiveCopy(final Map metadata) { + final Map copy = new HashMap<>(metadata); + for (String k : metadata.keySet()) { + copy.put(k, metadata.get(k)); + } + return copy; } } diff --git a/src/test/java/net/snowflake/ingest/streaming/internal/RowBufferTest.java b/src/test/java/net/snowflake/ingest/streaming/internal/RowBufferTest.java index 60f711700..eb84e363d 100644 --- a/src/test/java/net/snowflake/ingest/streaming/internal/RowBufferTest.java +++ b/src/test/java/net/snowflake/ingest/streaming/internal/RowBufferTest.java @@ -4,6 +4,7 @@ import static net.snowflake.ingest.utils.ParameterProvider.ENABLE_NEW_JSON_PARSING_LOGIC_DEFAULT; import static net.snowflake.ingest.utils.ParameterProvider.MAX_ALLOWED_ROW_SIZE_IN_BYTES_DEFAULT; import static net.snowflake.ingest.utils.ParameterProvider.MAX_CHUNK_SIZE_IN_BYTES_DEFAULT; +import static org.junit.Assert.fail; import java.math.BigDecimal; import java.math.BigInteger; @@ -15,6 +16,8 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicReference; import net.snowflake.ingest.streaming.InsertValidationResponse; import net.snowflake.ingest.streaming.OpenChannelRequest; import net.snowflake.ingest.utils.Constants; @@ -146,7 +149,7 @@ public void testCollatedColumnsAreRejected() { collatedColumn.setCollation("en-ci"); try { this.rowBufferOnErrorAbort.setupSchema(Collections.singletonList(collatedColumn)); - Assert.fail("Collated columns are not supported"); + fail("Collated columns are not supported"); } catch (SFException e) { Assert.assertEquals(ErrorCode.UNSUPPORTED_DATA_TYPE.getMessageCode(), e.getVendorCode()); } @@ -166,7 +169,7 @@ public void buildFieldErrorStates() { testCol.setPrecision(4); try { this.rowBufferOnErrorContinue.setupSchema(Collections.singletonList(testCol)); - Assert.fail("Expected error"); + fail("Expected error"); } catch (SFException e) { Assert.assertEquals(ErrorCode.UNKNOWN_DATA_TYPE.getMessageCode(), e.getVendorCode()); } @@ -178,7 +181,7 @@ public void buildFieldErrorStates() { testCol.setLogicalType("FIXED"); try { this.rowBufferOnErrorContinue.setupSchema(Collections.singletonList(testCol)); - Assert.fail("Expected error"); + fail("Expected error"); } catch (SFException e) { Assert.assertEquals(ErrorCode.UNKNOWN_DATA_TYPE.getMessageCode(), e.getVendorCode()); } @@ -190,7 +193,7 @@ public void buildFieldErrorStates() { testCol.setLogicalType("TIMESTAMP_NTZ"); try { this.rowBufferOnErrorContinue.setupSchema(Collections.singletonList(testCol)); - Assert.fail("Expected error"); + fail("Expected error"); } catch (SFException e) { Assert.assertEquals(ErrorCode.UNKNOWN_DATA_TYPE.getMessageCode(), e.getVendorCode()); } @@ -202,7 +205,7 @@ public void buildFieldErrorStates() { testCol.setLogicalType("TIMESTAMP_TZ"); try { this.rowBufferOnErrorContinue.setupSchema(Collections.singletonList(testCol)); - Assert.fail("Expected error"); + fail("Expected error"); } catch (SFException e) { Assert.assertEquals(ErrorCode.UNKNOWN_DATA_TYPE.getMessageCode(), e.getVendorCode()); } @@ -214,7 +217,7 @@ public void buildFieldErrorStates() { testCol.setLogicalType("TIME"); try { this.rowBufferOnErrorContinue.setupSchema(Collections.singletonList(testCol)); - Assert.fail("Expected error"); + fail("Expected error"); } catch (SFException e) { Assert.assertEquals(ErrorCode.UNKNOWN_DATA_TYPE.getMessageCode(), e.getVendorCode()); } @@ -246,7 +249,7 @@ public void testInvalidLogicalType() { try { this.rowBufferOnErrorContinue.setupSchema(Collections.singletonList(colInvalidLogical)); - Assert.fail("Setup should fail if invalid column metadata is provided"); + fail("Setup should fail if invalid column metadata is provided"); } catch (SFException e) { Assert.assertEquals(ErrorCode.UNKNOWN_DATA_TYPE.getMessageCode(), e.getVendorCode()); // Do nothing @@ -266,7 +269,7 @@ public void testInvalidPhysicalType() { try { this.rowBufferOnErrorContinue.setupSchema(Collections.singletonList(colInvalidPhysical)); - Assert.fail("Setup should fail if invalid column metadata is provided"); + fail("Setup should fail if invalid column metadata is provided"); } catch (SFException e) { Assert.assertEquals(e.getVendorCode(), ErrorCode.UNKNOWN_DATA_TYPE.getMessageCode()); } @@ -627,7 +630,7 @@ public void testInvalidEPInfo() { try { AbstractRowBuffer.buildEpInfoFromStats(1, colStats); - Assert.fail("should fail when row count is smaller than null count."); + fail("should fail when row count is smaller than null count."); } catch (SFException e) { Assert.assertEquals(ErrorCode.INTERNAL_ERROR.getMessageCode(), e.getVendorCode()); } @@ -1719,4 +1722,78 @@ public void testOnErrorAbortRowsWithError() { Assert.assertEquals(1, snapshotAbortParquet.size()); Assert.assertEquals(Arrays.asList("a"), snapshotAbortParquet.get(0)); } + + @Test + public void testParquetChunkMetadataCreationIsThreadSafe() throws InterruptedException { + final String testFileA = "testFileA"; + final String testFileB = "testFileB"; + + final ParquetRowBuffer bufferUnderTest = + (ParquetRowBuffer) createTestBuffer(OpenChannelRequest.OnErrorOption.CONTINUE); + + final ColumnMetadata colChar = new ColumnMetadata(); + colChar.setOrdinal(1); + colChar.setName("COLCHAR"); + colChar.setPhysicalType("LOB"); + colChar.setNullable(true); + colChar.setLogicalType("TEXT"); + colChar.setByteLength(14); + colChar.setLength(11); + colChar.setScale(0); + + bufferUnderTest.setupSchema(Collections.singletonList(colChar)); + + loadData(bufferUnderTest, Collections.singletonMap("colChar", "a")); + + final CountDownLatch latch = new CountDownLatch(1); + final AtomicReference> firstFlushResult = new AtomicReference<>(); + final Thread t = + getThreadThatWaitsForLockReleaseAndFlushes( + bufferUnderTest, testFileA, latch, firstFlushResult); + t.start(); + + final ChannelData secondFlushResult = bufferUnderTest.flush(testFileB); + Assert.assertEquals(testFileB, getPrimaryFileId(secondFlushResult)); + + latch.countDown(); + t.join(); + + Assert.assertNotNull(firstFlushResult.get()); + Assert.assertEquals(testFileA, getPrimaryFileId(firstFlushResult.get())); + Assert.assertEquals(testFileB, getPrimaryFileId(secondFlushResult)); + } + + private static Thread getThreadThatWaitsForLockReleaseAndFlushes( + final ParquetRowBuffer bufferUnderTest, + final String filenameToFlush, + final CountDownLatch latch, + final AtomicReference> flushResult) { + return new Thread( + () -> { + try { + latch.await(); + } catch (InterruptedException e) { + fail("Thread was unexpectedly interrupted"); + } + + final ChannelData flush = + loadData(bufferUnderTest, Collections.singletonMap("colChar", "b")) + .flush(filenameToFlush); + flushResult.set(flush); + }); + } + + private static ParquetRowBuffer loadData( + final ParquetRowBuffer bufferToLoad, final Map data) { + final List> validRows = new ArrayList<>(); + validRows.add(data); + + final InsertValidationResponse nResponse = bufferToLoad.insertRows(validRows, "1", "1"); + Assert.assertFalse(nResponse.hasErrors()); + return bufferToLoad; + } + + private static String getPrimaryFileId(final ChannelData chunkData) { + return chunkData.getVectors().metadata.get(Constants.PRIMARY_FILE_ID_KEY); + } }