Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-tzhang committed Oct 12, 2023
1 parent b269cb9 commit 123b972
Show file tree
Hide file tree
Showing 3 changed files with 168 additions and 127 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,132 @@ public int getOrdinal() {
}
}

/** Insert rows function strategy for ON_ERROR=CONTINUE */
public class ContinueIngestionStrategy<T> implements IngestionStrategy<T> {
@Override
public InsertValidationResponse insertRows(
AbstractRowBuffer<T> rowBuffer, Iterable<Map<String, Object>> rows, String offsetToken) {
InsertValidationResponse response = new InsertValidationResponse();
float rowsSizeInBytes = 0F;
int rowIndex = 0;
for (Map<String, Object> row : rows) {
InsertValidationResponse.InsertError error =
new InsertValidationResponse.InsertError(row, rowIndex);
try {
Set<String> inputColumnNames = verifyInputColumns(row, error, rowIndex);
rowsSizeInBytes +=
addRow(
row, rowBuffer.bufferedRowCount, rowBuffer.statsMap, inputColumnNames, rowIndex);
rowBuffer.bufferedRowCount++;
} catch (SFException e) {
error.setException(e);
response.addError(error);
} catch (Throwable e) {
logger.logWarn("Unexpected error happens during insertRows: {}", e);
error.setException(new SFException(e, ErrorCode.INTERNAL_ERROR, e.getMessage()));
response.addError(error);
}
checkBatchSizeEnforcedMaximum(rowsSizeInBytes);
rowIndex++;
if (rowBuffer.bufferedRowCount == Integer.MAX_VALUE) {
throw new SFException(ErrorCode.INTERNAL_ERROR, "Row count reaches MAX value");
}
}
checkBatchSizeRecommendedMaximum(rowsSizeInBytes);
rowBuffer.channelState.setOffsetToken(offsetToken);
rowBuffer.bufferSize += rowsSizeInBytes;
rowBuffer.rowSizeMetric.accept(rowsSizeInBytes);
return response;
}
}

/**
 * Insert rows function strategy for ON_ERROR=ABORT.
 *
 * <p>Every row is first validated and staged through {@code addTempRow}. No exception is caught
 * here, so the first failing row aborts the whole call. Only after all rows were staged are they
 * moved into the actual buffer and the row count, column stats, offset token and size metrics
 * updated.
 */
public class AbortIngestionStrategy<T> implements IngestionStrategy<T> {
  /**
   * Insert a batch of rows, propagating the first exception raised by any row.
   *
   * @param rowBuffer row buffer whose counters, stats and channel state are updated
   * @param rows input rows
   * @param offsetToken offset token stored on the channel state once the batch is accepted
   * @return a response without errors; failures are reported by exception instead
   */
  @Override
  public InsertValidationResponse insertRows(
      AbstractRowBuffer<T> rowBuffer, Iterable<Map<String, Object>> rows, String offsetToken) {
    // If the on_error option is ABORT, simply throw the first exception
    InsertValidationResponse response = new InsertValidationResponse();
    float tempRowsSizeInBytes = 0F;
    int tempRowCount = 0;
    for (Map<String, Object> row : rows) {
      Set<String> inputColumnNames = verifyInputColumns(row, null, tempRowCount);
      tempRowsSizeInBytes +=
          addTempRow(row, tempRowCount, rowBuffer.tempStatsMap, inputColumnNames, tempRowCount);
      checkBatchSizeEnforcedMaximum(tempRowsSizeInBytes);
      tempRowCount++;
    }
    checkBatchSizeRecommendedMaximum(tempRowsSizeInBytes);

    // Reject the batch BEFORE touching the actual buffer, so that a row count overflow cannot
    // leave moved rows behind without a matching bufferedRowCount update.
    if ((long) rowBuffer.bufferedRowCount + tempRowCount >= Integer.MAX_VALUE) {
      throw new SFException(ErrorCode.INTERNAL_ERROR, "Row count reaches MAX value");
    }

    moveTempRowsToActualBuffer(tempRowCount);
    rowBuffer.bufferedRowCount += tempRowCount;

    // Fold the stats collected for the staged rows into the stats of the actual buffer
    rowBuffer.statsMap.replaceAll(
        (colName, stats) ->
            RowBufferStats.getCombinedStats(stats, rowBuffer.tempStatsMap.get(colName)));
    rowBuffer.channelState.setOffsetToken(offsetToken);
    rowBuffer.bufferSize += tempRowsSizeInBytes;
    rowBuffer.rowSizeMetric.accept(tempRowsSizeInBytes);
    return response;
  }
}

/**
 * Insert rows function strategy for ON_ERROR=SKIP_BATCH.
 *
 * <p>Every row is validated and staged through {@code addTempRow}; failures are collected in the
 * response. The staged rows are moved into the actual buffer only when no row failed, otherwise
 * the buffer's rows, row count, stats and offset token are left unchanged.
 */
public class SkipBatchIngestionStrategy<T> implements IngestionStrategy<T> {
  /**
   * Insert a batch of rows, skipping the whole batch if any row fails.
   *
   * @param rowBuffer row buffer whose counters, stats and channel state are updated
   * @param rows input rows
   * @param offsetToken offset token stored on the channel state when the batch is accepted
   * @return insert response containing one error per failed row
   */
  @Override
  public InsertValidationResponse insertRows(
      AbstractRowBuffer<T> rowBuffer, Iterable<Map<String, Object>> rows, String offsetToken) {
    InsertValidationResponse response = new InsertValidationResponse();
    float rowsSizeInBytes = 0F;
    float tempRowsSizeInBytes = 0F;
    // Number of rows staged successfully so far, i.e. the next position in the temp buffer
    int tempRowCount = 0;
    // Position of the current row in the input batch. Tracked separately from tempRowCount, which
    // does not advance on a failed row and would make every later error report a stale index.
    int rowIndex = 0;
    for (Map<String, Object> row : rows) {
      InsertValidationResponse.InsertError error =
          new InsertValidationResponse.InsertError(row, rowIndex);
      try {
        Set<String> inputColumnNames = verifyInputColumns(row, error, rowIndex);
        tempRowsSizeInBytes +=
            addTempRow(row, tempRowCount, rowBuffer.tempStatsMap, inputColumnNames, rowIndex);
        tempRowCount++;
      } catch (SFException e) {
        error.setException(e);
        response.addError(error);
      } catch (Throwable e) {
        logger.logWarn("Unexpected error happens during insertRows: {}", e);
        error.setException(new SFException(e, ErrorCode.INTERNAL_ERROR, e.getMessage()));
        response.addError(error);
      }
      checkBatchSizeEnforcedMaximum(tempRowsSizeInBytes);
      rowIndex++;
    }

    if (!response.hasErrors()) {
      checkBatchSizeRecommendedMaximum(tempRowsSizeInBytes);
      // Reject the batch BEFORE touching the actual buffer, so that a row count overflow cannot
      // leave moved rows behind without a matching bufferedRowCount update.
      if ((long) rowBuffer.bufferedRowCount + tempRowCount >= Integer.MAX_VALUE) {
        throw new SFException(ErrorCode.INTERNAL_ERROR, "Row count reaches MAX value");
      }
      moveTempRowsToActualBuffer(tempRowCount);
      rowsSizeInBytes = tempRowsSizeInBytes;
      rowBuffer.bufferedRowCount += tempRowCount;
      // Fold the stats collected for the staged rows into the stats of the actual buffer
      rowBuffer.statsMap.replaceAll(
          (colName, stats) ->
              RowBufferStats.getCombinedStats(stats, rowBuffer.tempStatsMap.get(colName)));
      rowBuffer.channelState.setOffsetToken(offsetToken);
    }
    // rowsSizeInBytes stays 0 when the batch was skipped
    rowBuffer.bufferSize += rowsSizeInBytes;
    rowBuffer.rowSizeMetric.accept(rowsSizeInBytes);
    return response;
  }
}

// Map the column name to the stats
@VisibleForTesting Map<String, RowBufferStats> statsMap;

Expand Down Expand Up @@ -304,13 +430,8 @@ public InsertValidationResponse insertRows(
this.flushLock.lock();
try {
this.channelState.updateInsertStats(System.currentTimeMillis(), this.bufferedRowCount);
if (onErrorOption == OpenChannelRequest.OnErrorOption.CONTINUE) {
response = insertRowsHelperForContinue(rows, offsetToken);
} else if (onErrorOption == OpenChannelRequest.OnErrorOption.ABORT) {
response = insertRowsHelperForAbort(rows, offsetToken);
} else if (onErrorOption == OpenChannelRequest.OnErrorOption.SKIP_BATCH) {
response = insertRowsHelperForSkipBatch(rows, offsetToken);
}
IngestionStrategy<T> ingestionStrategy = createIngestionStrategy(onErrorOption);
response = ingestionStrategy.insertRows(this, rows, offsetToken);
} finally {
this.tempStatsMap.values().forEach(RowBufferStats::reset);
clearTempRows();
Expand All @@ -319,121 +440,6 @@ public InsertValidationResponse insertRows(
return response;
}

/**
 * Helper for ON_ERROR=CONTINUE: adds the rows one by one to the buffer, records every failing row
 * in the response and keeps going with the rest of the batch.
 */
private InsertValidationResponse insertRowsHelperForContinue(
    Iterable<Map<String, Object>> rows, String offsetToken) {
  InsertValidationResponse result = new InsertValidationResponse();
  float totalSizeInBytes = 0F;
  int inputPosition = 0;
  for (Map<String, Object> currentRow : rows) {
    InsertValidationResponse.InsertError rowError =
        new InsertValidationResponse.InsertError(currentRow, inputPosition);
    try {
      Set<String> columnNames = verifyInputColumns(currentRow, rowError, inputPosition);
      totalSizeInBytes +=
          addRow(currentRow, this.bufferedRowCount, this.statsMap, columnNames, inputPosition);
      this.bufferedRowCount++;
    } catch (Throwable failure) {
      SFException reported;
      if (failure instanceof SFException) {
        // Already an SDK exception, report it unchanged
        reported = (SFException) failure;
      } else {
        logger.logWarn("Unexpected error happens during insertRows: {}", failure);
        reported = new SFException(failure, ErrorCode.INTERNAL_ERROR, failure.getMessage());
      }
      rowError.setException(reported);
      result.addError(rowError);
    }
    checkBatchSizeEnforcedMaximum(totalSizeInBytes);
    inputPosition++;
    if (this.bufferedRowCount == Integer.MAX_VALUE) {
      throw new SFException(ErrorCode.INTERNAL_ERROR, "Row count reaches MAX value");
    }
  }
  checkBatchSizeRecommendedMaximum(totalSizeInBytes);

  // Batch-level bookkeeping, done once all rows were processed
  this.channelState.setOffsetToken(offsetToken);
  this.bufferSize += totalSizeInBytes;
  this.rowSizeMetric.accept(totalSizeInBytes);
  return result;
}

/**
 * Insert rows function helper for ON_ERROR=ABORT.
 *
 * <p>All rows are staged through {@code addTempRow} without catching anything, so the first
 * failing row aborts the call. The staged rows are moved into the actual buffer only after the
 * whole batch was staged.
 *
 * @param rows input rows
 * @param offsetToken offset token stored on the channel state once the batch is accepted
 * @return a response without errors; failures are reported by exception instead
 */
private InsertValidationResponse insertRowsHelperForAbort(
    Iterable<Map<String, Object>> rows, String offsetToken) {
  // If the on_error option is ABORT, simply throw the first exception
  InsertValidationResponse response = new InsertValidationResponse();
  float tempRowsSizeInBytes = 0F;
  int tempRowCount = 0;
  for (Map<String, Object> row : rows) {
    Set<String> inputColumnNames = verifyInputColumns(row, null, tempRowCount);
    tempRowsSizeInBytes +=
        addTempRow(row, tempRowCount, this.tempStatsMap, inputColumnNames, tempRowCount);
    checkBatchSizeEnforcedMaximum(tempRowsSizeInBytes);
    tempRowCount++;
  }
  checkBatchSizeRecommendedMaximum(tempRowsSizeInBytes);

  // Reject the batch BEFORE touching the actual buffer, so that a row count overflow cannot
  // leave moved rows behind without a matching bufferedRowCount update.
  if ((long) this.bufferedRowCount + tempRowCount >= Integer.MAX_VALUE) {
    throw new SFException(ErrorCode.INTERNAL_ERROR, "Row count reaches MAX value");
  }

  moveTempRowsToActualBuffer(tempRowCount);
  this.bufferedRowCount += tempRowCount;

  // Fold the stats collected for the staged rows into the stats of the actual buffer
  this.statsMap.replaceAll(
      (colName, stats) ->
          RowBufferStats.getCombinedStats(stats, this.tempStatsMap.get(colName)));
  this.channelState.setOffsetToken(offsetToken);
  this.bufferSize += tempRowsSizeInBytes;
  this.rowSizeMetric.accept(tempRowsSizeInBytes);
  return response;
}

/**
 * Insert rows function helper for ON_ERROR=SKIP_BATCH.
 *
 * <p>All rows are staged through {@code addTempRow} and failures are collected in the response.
 * The staged rows are moved into the actual buffer only when no row failed.
 *
 * @param rows input rows
 * @param offsetToken offset token stored on the channel state when the batch is accepted
 * @return insert response containing one error per failed row
 */
private InsertValidationResponse insertRowsHelperForSkipBatch(
    Iterable<Map<String, Object>> rows, String offsetToken) {
  InsertValidationResponse response = new InsertValidationResponse();
  float rowsSizeInBytes = 0F;
  float tempRowsSizeInBytes = 0F;
  // Number of rows staged successfully so far, i.e. the next position in the temp buffer
  int tempRowCount = 0;
  // Position of the current row in the input batch. Tracked separately from tempRowCount, which
  // does not advance on a failed row and would make every later error report a stale index.
  int rowIndex = 0;
  for (Map<String, Object> row : rows) {
    InsertValidationResponse.InsertError error =
        new InsertValidationResponse.InsertError(row, rowIndex);
    try {
      Set<String> inputColumnNames = verifyInputColumns(row, error, rowIndex);
      tempRowsSizeInBytes +=
          addTempRow(row, tempRowCount, this.tempStatsMap, inputColumnNames, rowIndex);
      tempRowCount++;
    } catch (SFException e) {
      error.setException(e);
      response.addError(error);
    } catch (Throwable e) {
      logger.logWarn("Unexpected error happens during insertRows: {}", e);
      error.setException(new SFException(e, ErrorCode.INTERNAL_ERROR, e.getMessage()));
      response.addError(error);
    }
    checkBatchSizeEnforcedMaximum(tempRowsSizeInBytes);
    rowIndex++;
  }

  if (!response.hasErrors()) {
    checkBatchSizeRecommendedMaximum(tempRowsSizeInBytes);
    // Reject the batch BEFORE touching the actual buffer, so that a row count overflow cannot
    // leave moved rows behind without a matching bufferedRowCount update.
    if ((long) this.bufferedRowCount + tempRowCount >= Integer.MAX_VALUE) {
      throw new SFException(ErrorCode.INTERNAL_ERROR, "Row count reaches MAX value");
    }
    moveTempRowsToActualBuffer(tempRowCount);
    rowsSizeInBytes = tempRowsSizeInBytes;
    this.bufferedRowCount += tempRowCount;
    // Fold the stats collected for the staged rows into the stats of the actual buffer
    this.statsMap.replaceAll(
        (colName, stats) ->
            RowBufferStats.getCombinedStats(stats, this.tempStatsMap.get(colName)));
    this.channelState.setOffsetToken(offsetToken);
  }

  // rowsSizeInBytes stays 0 when the batch was skipped
  this.bufferSize += rowsSizeInBytes;
  this.rowSizeMetric.accept(rowsSizeInBytes);
  return response;
}

/**
* Flush the data in the row buffer by taking the ownership of the old vectors and pass all the
* required info back to the flush service to build the blob
Expand Down Expand Up @@ -645,4 +651,18 @@ private void checkBatchSizeRecommendedMaximum(float batchSizeInBytes) {
INSERT_ROWS_RECOMMENDED_MAX_BATCH_SIZE_IN_BYTES);
}
}

/**
 * Create the ingestion strategy based on the channel on_error option
 *
 * @param onErrorOption on_error option to select the strategy for
 * @return a new strategy instance for CONTINUE, ABORT or SKIP_BATCH
 * @throws IllegalArgumentException if the option matches none of the handled cases
 */
IngestionStrategy<T> createIngestionStrategy(OpenChannelRequest.OnErrorOption onErrorOption) {
  switch (onErrorOption) {
    case CONTINUE:
      return new ContinueIngestionStrategy<>();
    case ABORT:
      return new AbortIngestionStrategy<>();
    case SKIP_BATCH:
      return new SkipBatchIngestionStrategy<>();
    default:
      // Include the offending value; the message previously ended right after the colon
      throw new IllegalArgumentException("Unknown on error option: " + onErrorOption);
  }
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
/*
* Copyright (c) 2023 Snowflake Computing Inc. All rights reserved.
*/

package net.snowflake.ingest.streaming.internal;

import java.util.Map;
import net.snowflake.ingest.streaming.InsertValidationResponse;

/**
 * Interface to insert a batch of rows into the row buffer based on different on error options
 */
public interface IngestionStrategy<T> {
/**
* Insert a batch of rows into the row buffer
*
* @param rowBuffer row buffer the rows are inserted into
* @param rows input rows
* @param offsetToken offset token of the latest row in the batch
* @return insert response that possibly contains errors because of insertion failures
*/
InsertValidationResponse insertRows(AbstractRowBuffer<T> rowBuffer, Iterable<Map<String, Object>> rows, String offsetToken);
}
Original file line number Diff line number Diff line change
Expand Up @@ -879,12 +879,10 @@ public void testAbortOnErrorSkipBatch() throws Exception {
SnowflakeStreamingIngestChannel channel = client.openChannel(request);
Map<String, Object> row1 = new HashMap<>();
row1.put("c1", 1);
channel.insertRow(row1, "1");
Map<String, Object> row2 = new HashMap<>();
row2.put("c1", 2);
channel.insertRow(row2, "2");
Map<String, Object> row3 = new HashMap<>();
row3.put("c1", "a");
row3.put("c1", "3");

verifyInsertValidationResponse(channel.insertRow(row1, "1"));

Expand All @@ -909,8 +907,9 @@ public void testAbortOnErrorSkipBatch() throws Exception {
"select count(c1), min(c1), max(c1) from %s.%s.%s",
testDb, TEST_SCHEMA, onErrorOptionTable));
result.next();
Assert.assertEquals("1", result.getString(1));
Assert.assertEquals("4", result.getString(2));
Assert.assertEquals("2", result.getString(1));
Assert.assertEquals("1", result.getString(2));
Assert.assertEquals("4", result.getString(3));
return;
}
Thread.sleep(1000);
Expand Down

0 comments on commit 123b972

Please sign in to comment.