diff --git a/src/main/java/com/vinted/flink/bigquery/sink/async/AsyncBigQuerySinkWriter.java b/src/main/java/com/vinted/flink/bigquery/sink/async/AsyncBigQuerySinkWriter.java
index 4e895a8..1d1b6ac 100644
--- a/src/main/java/com/vinted/flink/bigquery/sink/async/AsyncBigQuerySinkWriter.java
+++ b/src/main/java/com/vinted/flink/bigquery/sink/async/AsyncBigQuerySinkWriter.java
@@ -26,6 +26,7 @@
 import java.util.concurrent.*;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.function.Consumer;
+import java.util.stream.Collectors;
 
 public class AsyncBigQuerySinkWriter<A> extends AsyncSinkWriter<Rows<A>, StreamRequest> {
     private static final Logger logger = LoggerFactory.getLogger(AsyncSinkWriter.class);
@@ -112,32 +113,119 @@ protected final void recreateStreamWriter(String traceId, String streamName, Str
     @Override
     protected void submitRequestEntries(List<StreamRequest> list, Consumer<List<StreamRequest>> consumer) {
         var traceId = UUID.randomUUID().toString();
-        var parent = this;
-
-        CompletableFuture.runAsync(() -> {
-            var counter = new CountDownLatch(list.size());
-            var result = new ConcurrentLinkedDeque<StreamRequest>();
-            list.forEach(request -> {
-                registerAppendMetrics(request);
-                var writer = streamWriter(traceId, request.getStream(), request.getTable());
-                logger.trace("Trace-id {}, Writing rows stream {} to steamWriter for {} writer id {}", traceId, request.getStream(), writer.getStreamName(), writer.getWriterId());
+        // Every request is appended by a blocking task on appendExecutor. Each task resolves to the
+        // requests handed back to consumer: none on success or once a fatal error was reported.
+        var requests = list.stream().map(request -> {
+            registerAppendMetrics(request);
+            var writer = streamWriter(traceId, request.getStream(), request.getTable());
+            logger.trace("Trace-id {}, Writing rows stream {} to steamWriter for {} writer id {}", traceId, request.getStream(), writer.getStreamName(), writer.getWriterId());
+            return CompletableFuture.<List<StreamRequest>>supplyAsync(() -> {
                 try {
-                    var apiFuture = writer.append(request.getData());
-                    ApiFutures.addCallback(apiFuture, new AppendCallBack<>(parent, writer.getWriterId(), traceId, request, result, counter), appendExecutor);
+                    writer.append(request.getData()).get();
+                    return List.of();
                 } catch (Throwable t) {
                     logger.error("Trace-id {}, StreamWriter failed to append {}", traceId, t.getMessage());
-                    counter.countDown();
-                    getFatalExceptionCons().accept(new AsyncWriterException(traceId, Status.Code.INTERNAL, t));
+                    if (t instanceof InterruptedException) {
+                        // get() cleared the interrupt flag when it threw; restore it for the executor thread.
+                        Thread.currentThread().interrupt();
+                    }
+                    var status = Status.fromThrowable(t);
+                    switch (status.getCode()) {
+                        case UNAVAILABLE: {
+                            this.recreateStreamWriter(traceId, request.getStream(), writer.getWriterId(), request.getTable());
+                            return retry(t, traceId, request);
+                        }
+                        case INVALID_ARGUMENT:
+                            // getMessage() may be null; String.valueOf keeps the check from throwing.
+                            if (String.valueOf(t.getMessage()).contains("INVALID_ARGUMENT: MessageSize is too large.")) {
+                                var data = request.getData().getSerializedRowsList();
+                                if (data.size() < 2) {
+                                    // A single row cannot be split any further; re-queueing it would never terminate.
+                                    this.getFatalExceptionCons().accept(new AsyncWriterException(traceId, status.getCode(), t));
+                                    return List.of();
+                                }
+                                Optional.ofNullable(this.metrics.get(request.getStream())).ifPresent(BigQueryStreamMetrics::incSplitCount);
+                                logger.warn("Trace-id {} MessageSize is too large. Splitting batch", traceId);
+                                var first = data.subList(0, data.size() / 2);
+                                var second = data.subList(data.size() / 2, data.size());
+                                try {
+                                    return List.of(
+                                            new StreamRequest(request.getStream(), request.getTable(), ProtoRows.newBuilder().addAllSerializedRows(first).build(), request.getRetries() - 1),
+                                            new StreamRequest(request.getStream(), request.getTable(), ProtoRows.newBuilder().addAllSerializedRows(second).build(), request.getRetries() - 1)
+                                    );
+                                } catch (Throwable e) {
+                                    this.getFatalExceptionCons().accept(new AsyncWriterException(traceId, status.getCode(), e));
+                                    return List.of();
+                                }
+                            } else {
+                                logger.error("Trace-id {} Received error {} with status {}", traceId, t.getMessage(), status.getCode());
+                                this.getFatalExceptionCons().accept(new AsyncWriterException(traceId, status.getCode(), t));
+                                return List.of();
+                            }
+                        case UNKNOWN:
+                            if (isRequestTimeout(t)) {
+                                logger.info("Trace-id {} request timed out: {}", traceId, t.getMessage());
+                                Optional.ofNullable(this.metrics.get(request.getStream()))
+                                        .ifPresent(BigQueryStreamMetrics::incrementTimeoutCount);
+                                this.recreateStreamWriter(traceId, request.getStream(), writer.getWriterId(), request.getTable());
+                                return retry(t, traceId, request);
+                            } else {
+                                logger.error("Trace-id {} Received error {} with status {}", traceId, t.getMessage(), status.getCode());
+                                this.getFatalExceptionCons().accept(new AsyncWriterException(traceId, status.getCode(), t));
+                                return List.of();
+                            }
+                        default:
+                            logger.error("Trace-id {} Received error {} with status {}", traceId, t.getMessage(), status.getCode());
+                            this.getFatalExceptionCons().accept(new AsyncWriterException(traceId, status.getCode(), t));
+                            return List.of();
+                    }
                 }
-            });
-            try {
-                counter.await();
-                var finalResult = new ArrayList<>(result);
-                consumer.accept(finalResult);
-            } catch (InterruptedException e) {
-                getFatalExceptionCons().accept(new AsyncWriterException(traceId, Status.Code.INTERNAL, e));
-            }
-        }, waitExecutor);
+            }, appendExecutor);
+        }).collect(Collectors.toList());
+
+        CompletableFuture
+                .allOf(requests.toArray(new CompletableFuture[0]))
+                .thenApplyAsync(v -> requests.stream().flatMap(s -> s.join().stream()).collect(Collectors.toList()), appendExecutor)
+                .thenAcceptAsync(consumer, appendExecutor)
+                .exceptionally(e -> {
+                    // If a task or the consumer throws, report it instead of leaving the batch unanswered.
+                    getFatalExceptionCons().accept(new AsyncWriterException(traceId, Status.Code.INTERNAL, e));
+                    return null;
+                });
+    }
+
+    /**
+     * Tells whether an append failure was caused by the request callback wait time being exceeded.
+     * Walks the causal chain because Future.get() wraps the failure in an ExecutionException.
+     */
+    private static boolean isRequestTimeout(Throwable t) {
+        for (var cause = t; cause != null; cause = cause.getCause()) {
+            if (cause instanceof Exceptions.MaximumRequestCallbackWaitTimeExceededException) {
+                return true;
+            }
+        }
+        return false;
+    }
+
+    /**
+     * Spends one retry of the request and decides whether it is handed back for another attempt.
+     *
+     * @param t failure that triggered the retry
+     * @param traceId id used in the log lines of the current submit call
+     * @param request request whose retry counter is decremented in place
+     * @return the request when retries remain, otherwise an empty list after reporting a fatal error
+     */
+    private List<StreamRequest> retry(Throwable t, String traceId, StreamRequest request) {
+        var status = Status.fromThrowable(t);
+        request.setRetries(request.getRetries() - 1);
+        if (request.getRetries() > 0) {
+            logger.warn("Trace-id {} Recoverable error {}. Retrying {} ...", traceId, status.getCode(), request.getRetries());
+            return List.of(request);
+        } else {
+            logger.error("Trace-id {} Recoverable error {}. No more retries left", traceId, status.getCode(), t);
+            this.getFatalExceptionCons().accept(new AsyncWriterException(traceId, status.getCode(), t));
+            return List.of();
+        }
     }
 
     @Override
@@ -146,89 +234,4 @@ protected long getSizeInBytes(StreamRequest StreamRequest) {
     }
 
 
-    static class AppendCallBack<A> implements ApiFutureCallback<AppendRowsResponse> {
-        private final AsyncBigQuerySinkWriter<A> parent;
-        private final StreamRequest request;
-
-        private final String writerId;
-        private final String traceId;
-
-        private final ConcurrentLinkedDeque<StreamRequest> out;
-
-        private final CountDownLatch counter;
-
-        public AppendCallBack(AsyncBigQuerySinkWriter<A> parent, String writerId, String traceId, StreamRequest request, ConcurrentLinkedDeque<StreamRequest> out, CountDownLatch counter) {
-            this.parent = parent;
-            this.writerId = writerId;
-            this.traceId = traceId;
-            this.request = request;
-            this.out = out;
-            this.counter = counter;
-        }
-
-        @Override
-        public void onSuccess(AppendRowsResponse result) {
-            counter.countDown();
-        }
-
-
-        @Override
-        public void onFailure(Throwable t) {
-            var status = Status.fromThrowable(t);
-            switch (status.getCode()) {
-                case UNAVAILABLE: {
-                    this.parent.recreateStreamWriter(traceId, request.getStream(), writerId, request.getTable());
-                    retry(t, traceId, request);
-                    break;
-                }
-                case INVALID_ARGUMENT:
-                    if (t.getMessage().contains("INVALID_ARGUMENT: MessageSize is too large.")) {
-                        Optional.ofNullable(this.parent.metrics.get(request.getStream())).ifPresent(BigQueryStreamMetrics::incSplitCount);
-                        logger.warn("Trace-id {} MessageSize is too large. Splitting batch", traceId);
-                        var data = request.getData().getSerializedRowsList();
-                        var first = data.subList(0, data.size() / 2);
-                        var second = data.subList(data.size() / 2, data.size());
-                        try {
-                            out.add(new StreamRequest(request.getStream(), request.getTable(), ProtoRows.newBuilder().addAllSerializedRows(first).build(), request.getRetries() - 1));
-                            out.add(new StreamRequest(request.getStream(), request.getTable(), ProtoRows.newBuilder().addAllSerializedRows(second).build(), request.getRetries() - 1));
-                        } catch (Throwable e) {
-                            this.parent.getFatalExceptionCons().accept(new AsyncWriterException(traceId, status.getCode(), e));
-                        }
-                    } else {
-                        logger.error("Trace-id {} Received error {} with status {}", traceId, t.getMessage(), status.getCode());
-                        this.parent.getFatalExceptionCons().accept(new AsyncWriterException(traceId, status.getCode(), t));
-                    }
-                    break;
-                case UNKNOWN:
-                    if (t instanceof Exceptions.MaximumRequestCallbackWaitTimeExceededException || t.getCause() instanceof Exceptions.MaximumRequestCallbackWaitTimeExceededException) {
-                        logger.info("Trace-id {} request timed out: {}", traceId, t.getMessage());
-                        Optional.ofNullable(this.parent.metrics.get(request.getStream()))
-                                .ifPresent(BigQueryStreamMetrics::incrementTimeoutCount);
-                        this.parent.recreateStreamWriter(traceId, request.getStream(), writerId, request.getTable());
-                        retry(t, traceId, request);
-                    } else {
-                        logger.error("Trace-id {} Received error {} with status {}", traceId, t.getMessage(), status.getCode());
-                        this.parent.getFatalExceptionCons().accept(new AsyncWriterException(traceId, status.getCode(), t));
-                    }
-                    break;
-                default:
-                    logger.error("Trace-id {} Received error {} with status {}", traceId, t.getMessage(), status.getCode());
-                    this.parent.getFatalExceptionCons().accept(new AsyncWriterException(traceId, status.getCode(), t));
-            }
-
-            counter.countDown();
-        }
-
-        private void retry(Throwable t, String traceId, StreamRequest request) {
-            var status = Status.fromThrowable(t);
-            request.setRetries(request.getRetries() - 1);
-            if (request.getRetries() > 0) {
-                logger.warn("Trace-id {} Recoverable error {}. Retrying {} ...", traceId, status.getCode(), request.getRetries());
-                out.add(request);
-            } else {
-                logger.error("Trace-id {} Recoverable error {}. No more retries left", traceId, status.getCode(), t);
-                this.parent.getFatalExceptionCons().accept(new AsyncWriterException(traceId, status.getCode(), t));
-            }
-        }
-    }
 }