address comments
sfc-gh-tzhang committed Sep 9, 2024
1 parent a3f5a3b commit 3d49a91
Showing 8 changed files with 33 additions and 44 deletions.
@@ -496,11 +496,10 @@ public InsertValidationResponse insertRows(
* Flush the data in the row buffer by taking the ownership of the old vectors and pass all the
* required info back to the flush service to build the blob
*
- * @param filePath the name of the file the data will be written in
* @return A ChannelData object that contains the info needed by the flush service to build a blob
*/
@Override
- public ChannelData<T> flush(final String filePath) {
+ public ChannelData<T> flush() {
logger.logDebug("Start get data for channel={}", channelFullyQualifiedName);
if (this.bufferedRowCount > 0) {
Optional<T> oldData = Optional.empty();
@@ -518,7 +517,7 @@ public ChannelData<T> flush(final String filePath) {
try {
if (this.bufferedRowCount > 0) {
// Transfer the ownership of the vectors
- oldData = getSnapshot(filePath);
+ oldData = getSnapshot();
oldRowCount = this.bufferedRowCount;
oldBufferSize = this.bufferSize;
oldRowSequencer = this.channelState.incrementAndGetRowSequencer();
@@ -617,10 +616,8 @@ void reset() {

/**
* Get buffered data snapshot for later flushing.
- *
- * @param filePath the name of the file the data will be written in
*/
- abstract Optional<T> getSnapshot(final String filePath);
+ abstract Optional<T> getSnapshot();

@VisibleForTesting
abstract Object getVectorValueAt(String column, int index);
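
For context, a minimal, self-contained sketch of the contract this change establishes, using hypothetical Buffer/Snapshot stand-ins rather than the SDK's AbstractRowBuffer/ChannelData: flush() no longer receives a file name, it only hands back a snapshot of the buffered rows, and the caller picks the blob path afterwards.

import java.util.ArrayList;
import java.util.List;
import java.util.Optional;

// Hypothetical, simplified types -- not the SDK's classes -- used only to illustrate
// the shape of the API after this commit.
public class FlushWithoutFilePathSketch {

  // Stand-in for ChannelData<T>: just the snapshotted rows.
  static final class Snapshot {
    final List<String> rows;
    Snapshot(List<String> rows) {
      this.rows = rows;
    }
  }

  // Stand-in for the row buffer: flush() transfers ownership of the buffered rows.
  static final class Buffer {
    private final List<String> buffered = new ArrayList<>();

    void insert(String row) {
      buffered.add(row);
    }

    Optional<Snapshot> flush() {
      if (buffered.isEmpty()) {
        return Optional.empty(); // nothing buffered, nothing to hand back
      }
      Snapshot snapshot = new Snapshot(new ArrayList<>(buffered));
      buffered.clear(); // the buffer is reset; the caller now owns the old data
      return Optional.of(snapshot);
    }
  }

  public static void main(String[] args) {
    Buffer buffer = new Buffer();
    buffer.insert("row-1");
    buffer.insert("row-2");

    // The caller, not the buffer, decides where the blob goes.
    Optional<Snapshot> data = buffer.flush();
    String blobPath = "2022/7/13/16/56/example_streaming.bdec"; // chosen by the caller
    data.ifPresent(
        s -> System.out.println("building blob " + blobPath + " with " + s.rows.size() + " rows"));
  }
}

Keeping the buffer unaware of file names is what lets the flush service defer blob naming until it actually builds the blob.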
@@ -430,7 +430,7 @@ void distributeFlushTasks(Set<String> tablesToFlush) {
.forEach(
channel -> {
if (channel.isValid()) {
- ChannelData<T> data = channel.getData(blobPath);
+ ChannelData<T> data = channel.getData();
if (data != null) {
channelsDataPerTable.add(data);
}
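
A runnable sketch of the call-site shape above, using hypothetical Channel/Data stand-ins (not the SDK's channel or ChannelData types): each valid channel is asked for its data with no blob path argument, and only non-null results are collected for the blob build.

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class CollectChannelDataSketch {

  static final class Data {
    final int rowCount;
    Data(int rowCount) {
      this.rowCount = rowCount;
    }
  }

  static final class Channel {
    private final boolean valid;
    private final Data data; // null when the channel has nothing buffered
    Channel(boolean valid, Data data) {
      this.valid = valid;
      this.data = data;
    }
    boolean isValid() {
      return valid;
    }
    Data getData() {
      return data; // no file path parameter anymore
    }
  }

  public static void main(String[] args) {
    List<Channel> channels =
        Arrays.asList(
            new Channel(true, new Data(3)),   // valid, has data -> collected
            new Channel(true, null),          // valid, empty -> skipped
            new Channel(false, new Data(1))); // invalid -> skipped

    List<Data> channelsDataPerTable = new ArrayList<>();
    channels.forEach(
        channel -> {
          if (channel.isValid()) {
            Data data = channel.getData();
            if (data != null) {
              channelsDataPerTable.add(data);
            }
          }
        });

    System.out.println("chunks collected: " + channelsDataPerTable.size()); // prints 1
  }
}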
@@ -261,7 +261,7 @@ boolean hasColumns() {
}

@Override
- Optional<ParquetChunkData> getSnapshot(final String filePath) {
+ Optional<ParquetChunkData> getSnapshot() {
List<List<Object>> oldData = new ArrayList<>();
if (!clientBufferParameters.getEnableParquetInternalBuffering()) {
data.forEach(r -> oldData.add(new ArrayList<>(r)));
@@ -37,10 +37,9 @@ InsertValidationResponse insertRows(
* Flush the data in the row buffer by taking the ownership of the old vectors and pass all the
* required info back to the flush service to build the blob
*
- * @param filePath the name of the file the data will be written in
* @return A ChannelData object that contains the info needed by the flush service to build a blob
*/
- ChannelData<T> flush(final String filePath);
+ ChannelData<T> flush();

/**
* Close the row buffer and release resources. Note that the caller needs to handle
@@ -220,11 +220,10 @@ public String getFullyQualifiedTableName() {
/**
* Get all the data needed to build the blob during flush
*
- * @param filePath the name of the file the data will be written in
* @return a ChannelData object
*/
- ChannelData<T> getData(final String filePath) {
- ChannelData<T> data = this.rowBuffer.flush(filePath);
+ ChannelData<T> getData() {
+ ChannelData<T> data = this.rowBuffer.flush();
if (data != null) {
data.setChannelContext(channelFlushContext);
}
@@ -131,7 +131,7 @@ void setParameterOverride(Map<String, Object> parameterOverride) {

ChannelData<T> flushChannel(String name) {
SnowflakeStreamingIngestChannelInternal<T> channel = channels.get(name);
- ChannelData<T> channelData = channel.getRowBuffer().flush(name + "_snowpipe_streaming.bdec");
+ ChannelData<T> channelData = channel.getRowBuffer().flush();
channelData.setChannelContext(channel.getChannelContext());
this.channelData.add(channelData);
return channelData;
@@ -499,7 +499,7 @@ private void testFlushHelper(AbstractRowBuffer<?> rowBuffer) {
float bufferSize = rowBuffer.getSize();

final String filename = "2022/7/13/16/56/testFlushHelper_streaming.bdec";
- ChannelData<?> data = rowBuffer.flush(filename);
+ ChannelData<?> data = rowBuffer.flush();
Assert.assertEquals(2, data.getRowCount());
Assert.assertEquals((Long) 1L, data.getRowSequencer());
Assert.assertEquals(startOffsetToken, data.getStartOffsetToken());
@@ -734,7 +734,7 @@ private void testStatsE2EHelper(AbstractRowBuffer<?> rowBuffer) {
final String filename = "testStatsE2EHelper_streaming.bdec";
InsertValidationResponse response = rowBuffer.insertRows(Arrays.asList(row1, row2), null, null);
Assert.assertFalse(response.hasErrors());
- ChannelData<?> result = rowBuffer.flush(filename);
+ ChannelData<?> result = rowBuffer.flush();
Map<String, RowBufferStats> columnEpStats = result.getColumnEps();

Assert.assertEquals(
@@ -779,7 +779,7 @@ private void testStatsE2EHelper(AbstractRowBuffer<?> rowBuffer) {
Assert.assertEquals(-1, columnEpStats.get("COLCHAR").getDistinctValues());

// Confirm we reset
- ChannelData<?> resetResults = rowBuffer.flush("my_snowpipe_streaming.bdec");
+ ChannelData<?> resetResults = rowBuffer.flush();
Assert.assertNull(resetResults);
}

@@ -838,7 +838,7 @@ private void testStatsE2ETimestampHelper(OpenChannelRequest.OnErrorOption onErro
InsertValidationResponse response =
innerBuffer.insertRows(Arrays.asList(row1, row2, row3), null, null);
Assert.assertFalse(response.hasErrors());
- ChannelData<?> result = innerBuffer.flush("my_snowpipe_streaming.bdec");
+ ChannelData<?> result = innerBuffer.flush();
Assert.assertEquals(3, result.getRowCount());

Assert.assertEquals(
@@ -907,7 +907,7 @@ private void testE2EDateHelper(OpenChannelRequest.OnErrorOption onErrorOption) {
Assert.assertNull(innerBuffer.getVectorValueAt("COLDATE", 2));

// Check stats generation
- ChannelData<?> result = innerBuffer.flush("my_snowpipe_streaming.bdec");
+ ChannelData<?> result = innerBuffer.flush();
Assert.assertEquals(3, result.getRowCount());

Assert.assertEquals(
@@ -973,7 +973,7 @@ private void testE2ETimeHelper(OpenChannelRequest.OnErrorOption onErrorOption) {
Assert.assertNull(innerBuffer.getVectorValueAt("COLTIMESB8", 2));

// Check stats generation
- ChannelData<?> result = innerBuffer.flush("my_snowpipe_streaming.bdec");
+ ChannelData<?> result = innerBuffer.flush();
Assert.assertEquals(3, result.getRowCount());

Assert.assertEquals(
@@ -1204,7 +1204,7 @@ private void doTestFailureHalfwayThroughColumnProcessing(
}
}

- ChannelData<?> channelData = innerBuffer.flush("my_snowpipe_streaming.bdec");
+ ChannelData<?> channelData = innerBuffer.flush();
RowBufferStats statsCol1 = channelData.getColumnEps().get("COLVARCHAR1");
RowBufferStats statsCol2 = channelData.getColumnEps().get("COLVARCHAR2");
RowBufferStats statsCol3 = channelData.getColumnEps().get("COLBOOLEAN1");
@@ -1264,7 +1264,7 @@ private void testE2EBooleanHelper(OpenChannelRequest.OnErrorOption onErrorOption
Assert.assertNull(innerBuffer.getVectorValueAt("COLBOOLEAN", 2));

// Check stats generation
- ChannelData<?> result = innerBuffer.flush("my_snowpipe_streaming.bdec");
+ ChannelData<?> result = innerBuffer.flush();
Assert.assertEquals(3, result.getRowCount());

Assert.assertEquals(
@@ -1319,7 +1319,7 @@ private void testE2EBinaryHelper(OpenChannelRequest.OnErrorOption onErrorOption)
Assert.assertNull(innerBuffer.getVectorValueAt("COLBINARY", 2));

// Check stats generation
- ChannelData<?> result = innerBuffer.flush("my_snowpipe_streaming.bdec");
+ ChannelData<?> result = innerBuffer.flush();

Assert.assertEquals(3, result.getRowCount());
Assert.assertEquals(11L, result.getColumnEps().get("COLBINARY").getCurrentMaxLength());
@@ -1371,7 +1371,7 @@ private void testE2ERealHelper(OpenChannelRequest.OnErrorOption onErrorOption) {
Assert.assertNull(innerBuffer.getVectorValueAt("COLREAL", 2));

// Check stats generation
- ChannelData<?> result = innerBuffer.flush("my_snowpipe_streaming.bdec");
+ ChannelData<?> result = innerBuffer.flush();

Assert.assertEquals(3, result.getRowCount());
Assert.assertEquals(
@@ -1454,7 +1454,7 @@ public void testOnErrorAbortFailures() {
Assert.assertNull(innerBuffer.tempStatsMap.get("COLDECIMAL").getCurrentMaxIntValue());
Assert.assertNull(innerBuffer.tempStatsMap.get("COLDECIMAL").getCurrentMinIntValue());

- ChannelData<?> data = innerBuffer.flush("my_snowpipe_streaming.bdec");
+ ChannelData<?> data = innerBuffer.flush();
Assert.assertEquals(3, data.getRowCount());
Assert.assertEquals(0, innerBuffer.bufferedRowCount);
}
@@ -1528,7 +1528,7 @@ public void testOnErrorAbortSkipBatch() {
Assert.assertNull(innerBuffer.tempStatsMap.get("COLDECIMAL").getCurrentMaxIntValue());
Assert.assertNull(innerBuffer.tempStatsMap.get("COLDECIMAL").getCurrentMinIntValue());

- ChannelData<?> data = innerBuffer.flush("my_snowpipe_streaming.bdec");
+ ChannelData<?> data = innerBuffer.flush();
Assert.assertEquals(3, data.getRowCount());
Assert.assertEquals(0, innerBuffer.bufferedRowCount);
}
@@ -1579,7 +1579,7 @@ private void testE2EVariantHelper(OpenChannelRequest.OnErrorOption onErrorOption
Assert.assertEquals("3", innerBuffer.getVectorValueAt("COLVARIANT", 4));

// Check stats generation
- ChannelData<?> result = innerBuffer.flush("my_snowpipe_streaming.bdec");
+ ChannelData<?> result = innerBuffer.flush();
Assert.assertEquals(5, result.getRowCount());
Assert.assertEquals(2, result.getColumnEps().get("COLVARIANT").getCurrentNullCount());
}
@@ -1613,7 +1613,7 @@ private void testE2EObjectHelper(OpenChannelRequest.OnErrorOption onErrorOption)
Assert.assertEquals("{\"key\":1}", innerBuffer.getVectorValueAt("COLOBJECT", 0));

// Check stats generation
- ChannelData<?> result = innerBuffer.flush("my_snowpipe_streaming.bdec");
+ ChannelData<?> result = innerBuffer.flush();
Assert.assertEquals(1, result.getRowCount());
}

@@ -1663,7 +1663,7 @@ private void testE2EArrayHelper(OpenChannelRequest.OnErrorOption onErrorOption)
Assert.assertEquals("[1,2,3]", innerBuffer.getVectorValueAt("COLARRAY", 4));

// Check stats generation
- ChannelData<?> result = innerBuffer.flush("my_snowpipe_streaming.bdec");
+ ChannelData<?> result = innerBuffer.flush();
Assert.assertEquals(5, result.getRowCount());
}

@@ -1710,24 +1710,21 @@ public void testOnErrorAbortRowsWithError() {
SFException.class, () -> innerBufferOnErrorAbort.insertRows(mixedRows, "1", "3"));

List<List<Object>> snapshotContinueParquet =
- ((ParquetChunkData) innerBufferOnErrorContinue.getSnapshot("fake/filePath").get()).rows;
+ ((ParquetChunkData) innerBufferOnErrorContinue.getSnapshot().get()).rows;
// validRows and only the good row from mixedRows are in the buffer
Assert.assertEquals(2, snapshotContinueParquet.size());
Assert.assertEquals(Arrays.asList("a"), snapshotContinueParquet.get(0));
Assert.assertEquals(Arrays.asList("b"), snapshotContinueParquet.get(1));

List<List<Object>> snapshotAbortParquet =
- ((ParquetChunkData) innerBufferOnErrorAbort.getSnapshot("fake/filePath").get()).rows;
+ ((ParquetChunkData) innerBufferOnErrorAbort.getSnapshot().get()).rows;
// only validRows and none of the mixedRows are in the buffer
Assert.assertEquals(1, snapshotAbortParquet.size());
Assert.assertEquals(Arrays.asList("a"), snapshotAbortParquet.get(0));
}

@Test
public void testParquetChunkMetadataCreationIsThreadSafe() throws InterruptedException {
- final String testFileA = "testFileA";
- final String testFileB = "testFileB";
-
final ParquetRowBuffer bufferUnderTest =
(ParquetRowBuffer) createTestBuffer(OpenChannelRequest.OnErrorOption.CONTINUE);

@@ -1749,23 +1746,20 @@ public void testParquetChunkMetadataCreationIsThreadSafe() throws InterruptedExc
final AtomicReference<ChannelData<ParquetChunkData>> firstFlushResult = new AtomicReference<>();
final Thread t =
getThreadThatWaitsForLockReleaseAndFlushes(
- bufferUnderTest, testFileA, latch, firstFlushResult);
+ bufferUnderTest, latch, firstFlushResult);
t.start();

- final ChannelData<ParquetChunkData> secondFlushResult = bufferUnderTest.flush(testFileB);
- Assert.assertEquals(testFileB, getPrimaryFileId(secondFlushResult));
+ final ChannelData<ParquetChunkData> secondFlushResult = bufferUnderTest.flush();
+ // TODO: need to verify other fields

latch.countDown();
t.join();

Assert.assertNotNull(firstFlushResult.get());
- Assert.assertEquals(testFileA, getPrimaryFileId(firstFlushResult.get()));
- Assert.assertEquals(testFileB, getPrimaryFileId(secondFlushResult));
}

private static Thread getThreadThatWaitsForLockReleaseAndFlushes(
final ParquetRowBuffer bufferUnderTest,
- final String filenameToFlush,
final CountDownLatch latch,
final AtomicReference<ChannelData<ParquetChunkData>> flushResult) {
return new Thread(
@@ -1775,10 +1769,10 @@ private static Thread getThreadThatWaitsForLockReleaseAndFlushes(
} catch (InterruptedException e) {
fail("Thread was unexpectedly interrupted");
}

final ChannelData<ParquetChunkData> flush =
loadData(bufferUnderTest, Collections.singletonMap("colChar", "b"))
- .flush(filenameToFlush);
+ .flush();
flushResult.set(flush);
});
}
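
A self-contained sketch of the coordination pattern the thread-safety test above relies on, with a hypothetical flush() stand-in in place of ParquetRowBuffer: the background thread blocks on a CountDownLatch, the main thread flushes first, then releases the latch, joins, and inspects both results. With the file-name parameter gone there is no primary file ID to compare, which is presumably why those assertions gave way to the TODO.

import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicReference;

public class LatchedFlushSketch {

  static final class Result {
    final int rowCount;
    Result(int rowCount) {
      this.rowCount = rowCount;
    }
  }

  // Placeholder for bufferUnderTest.flush(); returns a fixed result here.
  static Result flush() {
    return new Result(1);
  }

  public static void main(String[] args) throws InterruptedException {
    final CountDownLatch latch = new CountDownLatch(1);
    final AtomicReference<Result> firstFlushResult = new AtomicReference<>();

    final Thread t =
        new Thread(
            () -> {
              try {
                latch.await(); // block until the main thread releases the latch
              } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                return;
              }
              firstFlushResult.set(flush()); // flush from the background thread
            });
    t.start();

    final Result secondFlushResult = flush(); // main thread flushes while the other thread waits

    latch.countDown(); // let the background thread proceed
    t.join();

    System.out.println(
        "main flushed " + secondFlushResult.rowCount
            + " row(s), background flushed " + firstFlushResult.get().rowCount + " row(s)");
  }
}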
@@ -592,7 +592,7 @@ public void testInsertRow() {
row.put("col", 1);

// Get data before insert to verify that there is no row (data should be null)
- ChannelData<?> data = channel.getData("my_snowpipe_streaming.bdec");
+ ChannelData<?> data = channel.getData();
Assert.assertNull(data);

long insertStartTimeInMs = System.currentTimeMillis();
@@ -605,7 +605,7 @@
long insertEndTimeInMs = System.currentTimeMillis();

// Get data again to verify the row is inserted
- data = channel.getData("my_snowpipe_streaming.bdec");
+ data = channel.getData();
Assert.assertEquals(3, data.getRowCount());
Assert.assertEquals((Long) 1L, data.getRowSequencer());
Assert.assertEquals(1, ((ChannelData<ParquetChunkData>) data).getVectors().rows.get(0).size());