Skip to content

Commit

Permalink
address comment + add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-tzhang committed Oct 4, 2023
1 parent 6e20f47 commit 825dd23
Show file tree
Hide file tree
Showing 4 changed files with 190 additions and 108 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ public class OpenChannelRequest {
public enum OnErrorOption {
CONTINUE, // CONTINUE loading the rows, and return all the errors in the response
ABORT, // ABORT the entire batch, and throw an exception when we hit the first error
SKIP_BATCH, // Skip the batch after the first error, and return all the errors in the response
SKIP_BATCH, // If an error in the batch is detected return a response containing all error row
// indexes. No data is ingested
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -297,113 +297,140 @@ Set<String> verifyInputColumns(
@Override
public InsertValidationResponse insertRows(
Iterable<Map<String, Object>> rows, String offsetToken) {
float rowsSizeInBytes = 0F;
if (!hasColumns()) {
throw new SFException(ErrorCode.INTERNAL_ERROR, "Empty column fields");
}
InsertValidationResponse response = new InsertValidationResponse();
InsertValidationResponse response = null;
this.flushLock.lock();
try {
this.channelState.updateInsertStats(System.currentTimeMillis(), this.bufferedRowCount);
if (onErrorOption == OpenChannelRequest.OnErrorOption.CONTINUE) {
// Used to map incoming row(nth row) to InsertError(for nth row) in response
int rowIndex = 0;
for (Map<String, Object> row : rows) {
InsertValidationResponse.InsertError error =
new InsertValidationResponse.InsertError(row, rowIndex);
try {
Set<String> inputColumnNames = verifyInputColumns(row, error, rowIndex);
rowsSizeInBytes +=
addRow(row, this.bufferedRowCount, this.statsMap, inputColumnNames, rowIndex);
this.bufferedRowCount++;
} catch (SFException e) {
error.setException(e);
response.addError(error);
} catch (Throwable e) {
logger.logWarn("Unexpected error happens during insertRows: {}", e);
error.setException(new SFException(e, ErrorCode.INTERNAL_ERROR, e.getMessage()));
response.addError(error);
}
checkBatchSizeEnforcedMaximum(rowsSizeInBytes);
rowIndex++;
if (this.bufferedRowCount == Integer.MAX_VALUE) {
throw new SFException(ErrorCode.INTERNAL_ERROR, "Row count reaches MAX value");
}
}
checkBatchSizeRecommendedMaximum(rowsSizeInBytes);
this.channelState.setOffsetToken(offsetToken);
response = insertRowsHelperForContinue(rows, offsetToken);
} else if (onErrorOption == OpenChannelRequest.OnErrorOption.ABORT) {
// If the on_error option is ABORT, simply throw the first exception
float tempRowsSizeInBytes = 0F;
int tempRowCount = 0;
for (Map<String, Object> row : rows) {
Set<String> inputColumnNames = verifyInputColumns(row, null, tempRowCount);
tempRowsSizeInBytes +=
addTempRow(row, tempRowCount, this.tempStatsMap, inputColumnNames, tempRowCount);
checkBatchSizeEnforcedMaximum(tempRowsSizeInBytes);
tempRowCount++;
}
checkBatchSizeRecommendedMaximum(tempRowsSizeInBytes);

moveTempRowsToActualBuffer(tempRowCount);

rowsSizeInBytes = tempRowsSizeInBytes;
if ((long) this.bufferedRowCount + tempRowCount >= Integer.MAX_VALUE) {
throw new SFException(ErrorCode.INTERNAL_ERROR, "Row count reaches MAX value");
}
this.bufferedRowCount += tempRowCount;
this.statsMap.forEach(
(colName, stats) ->
this.statsMap.put(
colName,
RowBufferStats.getCombinedStats(stats, this.tempStatsMap.get(colName))));
this.channelState.setOffsetToken(offsetToken);
response = insertRowsHelperForAbort(rows, offsetToken);
} else if (onErrorOption == OpenChannelRequest.OnErrorOption.SKIP_BATCH) {
float tempRowsSizeInBytes = 0F;
int tempRowCount = 0;
for (Map<String, Object> row : rows) {
InsertValidationResponse.InsertError error =
new InsertValidationResponse.InsertError(row, tempRowCount);
try {
Set<String> inputColumnNames = verifyInputColumns(row, error, tempRowCount);
tempRowsSizeInBytes +=
addTempRow(row, tempRowCount, this.tempStatsMap, inputColumnNames, tempRowCount);
tempRowCount++;
} catch (SFException e) {
error.setException(e);
response.addError(error);
} catch (Throwable e) {
logger.logWarn("Unexpected error happens during insertRows: {}", e);
error.setException(new SFException(e, ErrorCode.INTERNAL_ERROR, e.getMessage()));
response.addError(error);
}
checkBatchSizeEnforcedMaximum(tempRowsSizeInBytes);
}

if (!response.hasErrors()) {
checkBatchSizeRecommendedMaximum(tempRowsSizeInBytes);
moveTempRowsToActualBuffer(tempRowCount);
rowsSizeInBytes = tempRowsSizeInBytes;
if ((long) this.bufferedRowCount + tempRowCount >= Integer.MAX_VALUE) {
throw new SFException(ErrorCode.INTERNAL_ERROR, "Row count reaches MAX value");
}
this.bufferedRowCount += tempRowCount;
this.statsMap.forEach(
(colName, stats) ->
this.statsMap.put(
colName,
RowBufferStats.getCombinedStats(stats, this.tempStatsMap.get(colName))));
this.channelState.setOffsetToken(offsetToken);
}
response = insertRowsHelperForSkipBatch(rows, offsetToken);
}
this.bufferSize += rowsSizeInBytes;
this.rowSizeMetric.accept(rowsSizeInBytes);
} finally {
this.tempStatsMap.values().forEach(RowBufferStats::reset);
clearTempRows();
this.flushLock.unlock();
}
return response;
}

/**
 * Row insertion path used when the channel was opened with ON_ERROR=CONTINUE.
 *
 * <p>Every row is validated and written straight into the actual buffer. A row that fails is
 * recorded as an {@code InsertError} in the returned response and the loop moves on to the next
 * row. Exceptions raised by the batch-size check or by the row-count limit check are not caught
 * here and propagate to the caller.
 *
 * @param rows the batch of rows to insert
 * @param offsetToken offset token stored on the channel state once the whole batch is processed
 * @return a response holding one error entry per row that failed
 */
private InsertValidationResponse insertRowsHelperForContinue(
    Iterable<Map<String, Object>> rows, String offsetToken) {
  final InsertValidationResponse result = new InsertValidationResponse();
  float batchSizeInBytes = 0F;
  // Position of the current row within the incoming batch; ties an InsertError to its input row.
  int position = 0;

  for (Map<String, Object> currentRow : rows) {
    final InsertValidationResponse.InsertError rowError =
        new InsertValidationResponse.InsertError(currentRow, position);
    try {
      final Set<String> columnNames = verifyInputColumns(currentRow, rowError, position);
      batchSizeInBytes +=
          addRow(currentRow, this.bufferedRowCount, this.statsMap, columnNames, position);
      this.bufferedRowCount++;
    } catch (SFException sfe) {
      // Known validation failure: report it for this row and keep going.
      rowError.setException(sfe);
      result.addError(rowError);
    } catch (Throwable unexpected) {
      // Anything else is wrapped as an internal error but still reported per row.
      logger.logWarn("Unexpected error happens during insertRows: {}", unexpected);
      rowError.setException(
          new SFException(unexpected, ErrorCode.INTERNAL_ERROR, unexpected.getMessage()));
      result.addError(rowError);
    }

    checkBatchSizeEnforcedMaximum(batchSizeInBytes);
    if (this.bufferedRowCount == Integer.MAX_VALUE) {
      throw new SFException(ErrorCode.INTERNAL_ERROR, "Row count reaches MAX value");
    }
    position++;
  }

  checkBatchSizeRecommendedMaximum(batchSizeInBytes);
  this.channelState.setOffsetToken(offsetToken);
  this.bufferSize += batchSizeInBytes;
  this.rowSizeMetric.accept(batchSizeInBytes);
  return result;
}

/**
 * Insert rows function helper for ON_ERROR=ABORT.
 *
 * <p>Rows are first validated and staged in the temp buffer. Nothing in this method catches
 * exceptions, so the first failure raised while validating or staging a row propagates to the
 * caller and the actual buffer is left untouched. Only when the whole batch has been staged are
 * the rows moved into the actual buffer and the row count, column stats, offset token and size
 * metrics updated.
 *
 * @param rows the batch of rows to insert
 * @param offsetToken offset token stored on the channel state after a successful insert
 * @return an empty (error-free) response when the whole batch was accepted
 * @throws SFException if accepting the batch would push the buffered row count to {@code
 *     Integer.MAX_VALUE}, or if thrown by any of the validation/staging helpers
 */
private InsertValidationResponse insertRowsHelperForAbort(
    Iterable<Map<String, Object>> rows, String offsetToken) {
  // If the on_error option is ABORT, simply throw the first exception
  InsertValidationResponse response = new InsertValidationResponse();
  float tempRowsSizeInBytes = 0F;
  int tempRowCount = 0;
  for (Map<String, Object> row : rows) {
    Set<String> inputColumnNames = verifyInputColumns(row, null, tempRowCount);
    tempRowsSizeInBytes +=
        addTempRow(row, tempRowCount, this.tempStatsMap, inputColumnNames, tempRowCount);
    checkBatchSizeEnforcedMaximum(tempRowsSizeInBytes);
    tempRowCount++;
  }
  checkBatchSizeRecommendedMaximum(tempRowsSizeInBytes);

  // Check the row-count limit BEFORE moving anything: throwing after the move would leave rows
  // in the actual buffer that bufferedRowCount and statsMap do not account for.
  if ((long) this.bufferedRowCount + tempRowCount >= Integer.MAX_VALUE) {
    throw new SFException(ErrorCode.INTERNAL_ERROR, "Row count reaches MAX value");
  }

  moveTempRowsToActualBuffer(tempRowCount);
  this.bufferedRowCount += tempRowCount;
  // Fold the stats gathered for the staged rows into the stats of the actual buffer.
  this.statsMap.replaceAll(
      (colName, stats) ->
          RowBufferStats.getCombinedStats(stats, this.tempStatsMap.get(colName)));
  this.channelState.setOffsetToken(offsetToken);
  this.bufferSize += tempRowsSizeInBytes;
  this.rowSizeMetric.accept(tempRowsSizeInBytes);
  return response;
}

/**
 * Insert rows function helper for ON_ERROR=SKIP_BATCH.
 *
 * <p>Every row is validated and staged in the temp buffer. Rows that fail are recorded in the
 * returned response, keyed by their position in the incoming batch. If at least one row failed,
 * nothing is moved to the actual buffer and the offset token is not updated. Otherwise the
 * staged rows are moved and the row count, column stats and offset token are updated.
 *
 * @param rows the batch of rows to insert
 * @param offsetToken offset token stored on the channel state when the batch is accepted
 * @return a response holding one error entry per failed row; empty when the batch was accepted
 * @throws SFException if accepting the batch would push the buffered row count to {@code
 *     Integer.MAX_VALUE}, or if thrown by the batch-size checks
 */
private InsertValidationResponse insertRowsHelperForSkipBatch(
    Iterable<Map<String, Object>> rows, String offsetToken) {
  InsertValidationResponse response = new InsertValidationResponse();
  float rowsSizeInBytes = 0F;
  float tempRowsSizeInBytes = 0F;
  // Number of rows successfully staged; also the next free slot in the temp buffer.
  int tempRowCount = 0;
  // Position of the current row in the incoming batch. It must advance for failed rows too,
  // otherwise every error after the first one would be reported against the wrong row.
  int rowIndex = 0;
  for (Map<String, Object> row : rows) {
    InsertValidationResponse.InsertError error =
        new InsertValidationResponse.InsertError(row, rowIndex);
    try {
      Set<String> inputColumnNames = verifyInputColumns(row, error, rowIndex);
      tempRowsSizeInBytes +=
          addTempRow(row, tempRowCount, this.tempStatsMap, inputColumnNames, rowIndex);
      tempRowCount++;
    } catch (SFException e) {
      error.setException(e);
      response.addError(error);
    } catch (Throwable e) {
      logger.logWarn("Unexpected error happens during insertRows: {}", e);
      error.setException(new SFException(e, ErrorCode.INTERNAL_ERROR, e.getMessage()));
      response.addError(error);
    }
    checkBatchSizeEnforcedMaximum(tempRowsSizeInBytes);
    rowIndex++;
  }

  if (!response.hasErrors()) {
    checkBatchSizeRecommendedMaximum(tempRowsSizeInBytes);
    // Check the row-count limit BEFORE moving anything: throwing after the move would leave
    // rows in the actual buffer that bufferedRowCount and statsMap do not account for.
    if ((long) this.bufferedRowCount + tempRowCount >= Integer.MAX_VALUE) {
      throw new SFException(ErrorCode.INTERNAL_ERROR, "Row count reaches MAX value");
    }
    moveTempRowsToActualBuffer(tempRowCount);
    rowsSizeInBytes = tempRowsSizeInBytes;
    this.bufferedRowCount += tempRowCount;
    // Fold the stats gathered for the staged rows into the stats of the actual buffer.
    this.statsMap.replaceAll(
        (colName, stats) ->
            RowBufferStats.getCombinedStats(stats, this.tempStatsMap.get(colName)));
    this.channelState.setOffsetToken(offsetToken);
  }

  // rowsSizeInBytes stays 0 when the batch was skipped, so the metric records an empty insert.
  this.bufferSize += rowsSizeInBytes;
  this.rowSizeMetric.accept(rowsSizeInBytes);
  return response;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1405,7 +1405,8 @@ public void testOnErrorAbortFailures() {

@Test
public void testOnErrorAbortSkipBatch() {
AbstractRowBuffer<?> innerBuffer = createTestBuffer(OpenChannelRequest.OnErrorOption.ABORT);
AbstractRowBuffer<?> innerBuffer =
createTestBuffer(OpenChannelRequest.OnErrorOption.SKIP_BATCH);

ColumnMetadata colDecimal = new ColumnMetadata();
colDecimal.setName("COLDECIMAL");
Expand Down Expand Up @@ -1433,37 +1434,32 @@ public void testOnErrorAbortSkipBatch() {

Map<String, Object> row2 = new HashMap<>();
row2.put("COLDECIMAL", 2);
response = innerBuffer.insertRows(Collections.singletonList(row2), "2");
Assert.assertFalse(response.hasErrors());
Map<String, Object> row3 = new HashMap<>();
row3.put("COLDECIMAL", true);

Assert.assertEquals(2, innerBuffer.bufferedRowCount);
response = innerBuffer.insertRows(Arrays.asList(row2, row3), "3");
Assert.assertTrue(response.hasErrors());

Assert.assertEquals(1, innerBuffer.bufferedRowCount);
Assert.assertEquals(0, innerBuffer.getTempRowCount());
Assert.assertEquals(
2, innerBuffer.statsMap.get("COLDECIMAL").getCurrentMaxIntValue().intValue());
1, innerBuffer.statsMap.get("COLDECIMAL").getCurrentMaxIntValue().intValue());
Assert.assertEquals(
1, innerBuffer.statsMap.get("COLDECIMAL").getCurrentMinIntValue().intValue());
Assert.assertNull(innerBuffer.tempStatsMap.get("COLDECIMAL").getCurrentMaxIntValue());
Assert.assertNull(innerBuffer.tempStatsMap.get("COLDECIMAL").getCurrentMinIntValue());

Map<String, Object> row3 = new HashMap<>();
row3.put("COLDECIMAL", true);
try {
innerBuffer.insertRows(Collections.singletonList(row3), "3");
} catch (SFException e) {
Assert.assertEquals(ErrorCode.INVALID_FORMAT_ROW.getMessageCode(), e.getVendorCode());
}

Assert.assertEquals(2, innerBuffer.bufferedRowCount);
Assert.assertEquals(1, innerBuffer.bufferedRowCount);
Assert.assertEquals(0, innerBuffer.getTempRowCount());
Assert.assertEquals(
2, innerBuffer.statsMap.get("COLDECIMAL").getCurrentMaxIntValue().intValue());
1, innerBuffer.statsMap.get("COLDECIMAL").getCurrentMaxIntValue().intValue());
Assert.assertEquals(
1, innerBuffer.statsMap.get("COLDECIMAL").getCurrentMinIntValue().intValue());
Assert.assertNull(innerBuffer.tempStatsMap.get("COLDECIMAL").getCurrentMaxIntValue());
Assert.assertNull(innerBuffer.tempStatsMap.get("COLDECIMAL").getCurrentMinIntValue());

row3.put("COLDECIMAL", 3);
response = innerBuffer.insertRows(Collections.singletonList(row3), "3");
response = innerBuffer.insertRows(Arrays.asList(row2, row3), "3");
Assert.assertFalse(response.hasErrors());
Assert.assertEquals(3, innerBuffer.bufferedRowCount);
Assert.assertEquals(0, innerBuffer.getTempRowCount());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -860,6 +860,64 @@ public void testAbortOnErrorOption() throws Exception {
Assert.fail("Row sequencer not updated before timeout");
}

/**
 * End-to-end check of ON_ERROR=SKIP_BATCH: single-row inserts are ingested, while a batch that
 * contains one invalid row is rejected as a whole and none of its rows reach the table.
 */
@Test
public void testAbortOnErrorSkipBatch() throws Exception {
  String onErrorOptionTable = "skip_batch_on_error_option";
  jdbcConnection
      .createStatement()
      .execute(String.format("create or replace table %s (c1 int);", onErrorOptionTable));

  OpenChannelRequest request =
      OpenChannelRequest.builder("CHANNEL")
          .setDBName(testDb)
          .setSchemaName(TEST_SCHEMA)
          .setTableName(onErrorOptionTable)
          .setOnErrorOption(OpenChannelRequest.OnErrorOption.SKIP_BATCH)
          .build();

  // Open a streaming ingest channel from the given client
  SnowflakeStreamingIngestChannel channel = client.openChannel(request);
  Map<String, Object> row1 = new HashMap<>();
  row1.put("c1", 1);
  channel.insertRow(row1, "1");
  Map<String, Object> row2 = new HashMap<>();
  row2.put("c1", 2);
  channel.insertRow(row2, "2");
  // "a" is not a valid value for the int column c1
  Map<String, Object> row3 = new HashMap<>();
  row3.put("c1", "a");

  verifyInsertValidationResponse(channel.insertRow(row1, "1"));

  // The batch contains one bad row, so the whole batch must be skipped
  InsertValidationResponse response = channel.insertRows(Arrays.asList(row1, row2, row3), "3");
  Assert.assertTrue(response.hasErrors());

  Map<String, Object> row4 = new HashMap<>();
  row4.put("c1", 4);
  verifyInsertValidationResponse(channel.insertRow(row4, "4"));

  // Close the channel after insertion
  channel.close().get();

  // Poll until the last offset token is committed, then verify the table content
  for (int i = 1; i < 15; i++) {
    if (channel.getLatestCommittedOffsetToken() != null
        && channel.getLatestCommittedOffsetToken().equals("4")) {
      ResultSet result =
          jdbcConnection
              .createStatement()
              .executeQuery(
                  String.format(
                      "select count(c1), min(c1), max(c1) from %s.%s.%s",
                      testDb, TEST_SCHEMA, onErrorOptionTable));
      result.next();
      // Ingested values are 1, 2, 1 and 4; the skipped batch contributes nothing
      Assert.assertEquals("4", result.getString(1));
      Assert.assertEquals("1", result.getString(2));
      Assert.assertEquals("4", result.getString(3));
      return;
    }
    Thread.sleep(1000);
  }
  Assert.fail("Row sequencer not updated before timeout");
}

@Test
public void testChannelClose() throws Exception {
OpenChannelRequest request1 =
Expand Down

0 comments on commit 825dd23

Please sign in to comment.