diff --git a/airbyte-config/models/src/main/resources/types/ReplicationAttemptSummary.yaml b/airbyte-config/models/src/main/resources/types/ReplicationAttemptSummary.yaml index 90356db63cbe..fc3eee68b54b 100644 --- a/airbyte-config/models/src/main/resources/types/ReplicationAttemptSummary.yaml +++ b/airbyte-config/models/src/main/resources/types/ReplicationAttemptSummary.yaml @@ -9,17 +9,25 @@ required: - bytesSynced - startTime - endTime + - totalStats + - streamStats additionalProperties: false properties: status: "$ref": ReplicationStatus.yaml - recordsSynced: + recordsSynced: # TODO (parker) remove in favor of totalRecordsEmitted type: integer minValue: 0 - bytesSynced: + bytesSynced: # TODO (parker) remove in favor of totalBytesEmitted type: integer minValue: 0 startTime: type: integer endTime: type: integer + totalStats: + "$ref": SyncStats.yaml + streamStats: + type: array + items: + "$ref": StreamSyncStats.yaml diff --git a/airbyte-config/models/src/main/resources/types/StandardSyncSummary.yaml b/airbyte-config/models/src/main/resources/types/StandardSyncSummary.yaml index 49d84d49e0a1..a305f7fc4410 100644 --- a/airbyte-config/models/src/main/resources/types/StandardSyncSummary.yaml +++ b/airbyte-config/models/src/main/resources/types/StandardSyncSummary.yaml @@ -12,17 +12,25 @@ required: - bytesSynced - startTime - endTime + - totalStats + - streamStats additionalProperties: false properties: status: "$ref": ReplicationStatus.yaml - recordsSynced: + recordsSynced: # TODO (parker) remove in favor of totalRecordsEmitted type: integer minValue: 0 - bytesSynced: + bytesSynced: # TODO (parker) remove in favor of totalBytesEmitted type: integer minValue: 0 startTime: type: integer endTime: type: integer + totalStats: + "$ref": SyncStats.yaml + streamStats: + type: array + items: + "$ref": StreamSyncStats.yaml diff --git a/airbyte-config/models/src/main/resources/types/StreamSyncStats.yaml b/airbyte-config/models/src/main/resources/types/StreamSyncStats.yaml new 
file mode 100644 index 000000000000..c20003f72c5d --- /dev/null +++ b/airbyte-config/models/src/main/resources/types/StreamSyncStats.yaml @@ -0,0 +1,15 @@ +--- +"$schema": http://json-schema.org/draft-07/schema# +"$id": https://github.com/airbytehq/airbyte/blob/master/airbyte-config/models/src/main/resources/types/StreamSyncStats.yaml +title: StreamSyncStats +description: Sync stats for a particular stream. +type: object +required: + - streamName + - stats +additionalProperties: false +properties: + streamName: + type: string + stats: + "$ref": SyncStats.yaml diff --git a/airbyte-config/models/src/main/resources/types/SyncStats.yaml b/airbyte-config/models/src/main/resources/types/SyncStats.yaml new file mode 100644 index 000000000000..5c38885e6dc2 --- /dev/null +++ b/airbyte-config/models/src/main/resources/types/SyncStats.yaml @@ -0,0 +1,19 @@ +--- +"$schema": http://json-schema.org/draft-07/schema# +"$id": https://github.com/airbytehq/airbyte/blob/master/airbyte-config/models/src/main/resources/types/SyncStats.yaml +title: SyncStats +description: sync stats. 
+type: object +required: + - recordsEmitted + - bytesEmitted +additionalProperties: false +properties: + recordsEmitted: + type: integer + bytesEmitted: + type: integer + stateMessagesEmitted: # TODO make required once per-stream state messages are supported in V2 + type: integer + recordsCommitted: + type: integer # if unset, committed records could not be computed diff --git a/airbyte-container-orchestrator/src/main/java/io/airbyte/container_orchestrator/ReplicationJobOrchestrator.java b/airbyte-container-orchestrator/src/main/java/io/airbyte/container_orchestrator/ReplicationJobOrchestrator.java index ea8953e74976..cbcc4338c056 100644 --- a/airbyte-container-orchestrator/src/main/java/io/airbyte/container_orchestrator/ReplicationJobOrchestrator.java +++ b/airbyte-container-orchestrator/src/main/java/io/airbyte/container_orchestrator/ReplicationJobOrchestrator.java @@ -91,7 +91,6 @@ public void runJob() throws Exception { airbyteSource, new NamespacingMapper(syncInput.getNamespaceDefinition(), syncInput.getNamespaceFormat(), syncInput.getPrefix()), new DefaultAirbyteDestination(workerConfigs, destinationLauncher), - new AirbyteMessageTracker(), new AirbyteMessageTracker()); log.info("Running replication worker..."); diff --git a/airbyte-workers/src/main/java/io/airbyte/workers/DefaultReplicationWorker.java b/airbyte-workers/src/main/java/io/airbyte/workers/DefaultReplicationWorker.java index 595681649885..752526699bf2 100644 --- a/airbyte-workers/src/main/java/io/airbyte/workers/DefaultReplicationWorker.java +++ b/airbyte-workers/src/main/java/io/airbyte/workers/DefaultReplicationWorker.java @@ -9,6 +9,8 @@ import io.airbyte.config.StandardSyncInput; import io.airbyte.config.StandardSyncSummary.ReplicationStatus; import io.airbyte.config.State; +import io.airbyte.config.StreamSyncStats; +import io.airbyte.config.SyncStats; import io.airbyte.config.WorkerDestinationConfig; import io.airbyte.config.WorkerSourceConfig; import 
io.airbyte.protocol.models.AirbyteMessage; @@ -17,6 +19,7 @@ import io.airbyte.workers.protocols.airbyte.AirbyteSource; import io.airbyte.workers.protocols.airbyte.MessageTracker; import java.nio.file.Path; +import java.util.List; import java.util.Map; import java.util.Optional; import java.util.concurrent.CompletableFuture; @@ -55,8 +58,7 @@ public class DefaultReplicationWorker implements ReplicationWorker { private final AirbyteSource source; private final AirbyteMapper mapper; private final AirbyteDestination destination; - private final MessageTracker sourceMessageTracker; - private final MessageTracker destinationMessageTracker; + private final MessageTracker messageTracker; private final ExecutorService executors; private final AtomicBoolean cancelled; @@ -67,15 +69,13 @@ public DefaultReplicationWorker(final String jobId, final AirbyteSource source, final AirbyteMapper mapper, final AirbyteDestination destination, - final MessageTracker sourceMessageTracker, - final MessageTracker destinationMessageTracker) { + final MessageTracker messageTracker) { this.jobId = jobId; this.attempt = attempt; this.source = source; this.mapper = mapper; this.destination = destination; - this.sourceMessageTracker = sourceMessageTracker; - this.destinationMessageTracker = destinationMessageTracker; + this.messageTracker = messageTracker; this.executors = Executors.newFixedThreadPool(2); this.cancelled = new AtomicBoolean(false); @@ -120,11 +120,11 @@ public ReplicationOutput run(final StandardSyncInput syncInput, final Path jobRo source.start(sourceConfig, jobRoot); final CompletableFuture destinationOutputThreadFuture = CompletableFuture.runAsync( - getDestinationOutputRunnable(destination, cancelled, destinationMessageTracker, mdc), + getDestinationOutputRunnable(destination, cancelled, messageTracker, mdc), executors); final CompletableFuture replicationThreadFuture = CompletableFuture.runAsync( - getReplicationRunnable(source, destination, cancelled, mapper, 
sourceMessageTracker, mdc), + getReplicationRunnable(source, destination, cancelled, mapper, messageTracker, mdc), executors); LOGGER.info("Waiting for source and destination threads to complete."); @@ -155,10 +155,45 @@ else if (hasFailed.get()) { outputStatus = ReplicationStatus.COMPLETED; } + final SyncStats totalSyncStats = new SyncStats() + .withRecordsEmitted(messageTracker.getTotalRecordsEmitted()) + .withBytesEmitted(messageTracker.getTotalBytesEmitted()) + .withStateMessagesEmitted(messageTracker.getTotalStateMessagesEmitted()); + + if (outputStatus == ReplicationStatus.COMPLETED) { + totalSyncStats.setRecordsCommitted(totalSyncStats.getRecordsEmitted()); + } else if (messageTracker.getTotalRecordsCommitted().isPresent()) { + totalSyncStats.setRecordsCommitted(messageTracker.getTotalRecordsCommitted().get()); + } else { + LOGGER.warn("Could not reliably determine committed record counts, committed record stats will be set to null"); + totalSyncStats.setRecordsCommitted(null); + } + + // assume every stream with stats is in streamToEmittedRecords map + final List streamSyncStats = messageTracker.getStreamToEmittedRecords().keySet().stream().map(stream -> { + final SyncStats syncStats = new SyncStats() + .withRecordsEmitted(messageTracker.getStreamToEmittedRecords().get(stream)) + .withBytesEmitted(messageTracker.getStreamToEmittedBytes().get(stream)) + .withStateMessagesEmitted(null); // TODO (parker) populate per-stream state messages emitted once supported in V2 + + if (outputStatus == ReplicationStatus.COMPLETED) { + syncStats.setRecordsCommitted(messageTracker.getStreamToEmittedRecords().get(stream)); + } else if (messageTracker.getStreamToCommittedRecords().isPresent()) { + syncStats.setRecordsCommitted(messageTracker.getStreamToCommittedRecords().get().get(stream)); + } else { + syncStats.setRecordsCommitted(null); + } + return new StreamSyncStats() + .withStreamName(stream) + .withStats(syncStats); + }).collect(Collectors.toList()); + final 
ReplicationAttemptSummary summary = new ReplicationAttemptSummary() .withStatus(outputStatus) - .withRecordsSynced(sourceMessageTracker.getRecordCount()) - .withBytesSynced(sourceMessageTracker.getBytesCount()) + .withRecordsSynced(messageTracker.getTotalRecordsEmitted()) // TODO (parker) remove in favor of totalRecordsEmitted + .withBytesSynced(messageTracker.getTotalBytesEmitted()) // TODO (parker) remove in favor of totalBytesEmitted + .withTotalStats(totalSyncStats) + .withStreamStats(streamSyncStats) .withStartTime(startTime) .withEndTime(System.currentTimeMillis()); @@ -168,15 +203,15 @@ else if (hasFailed.get()) { .withReplicationAttemptSummary(summary) .withOutputCatalog(destinationConfig.getCatalog()); - if (sourceMessageTracker.getOutputState().isPresent()) { + if (messageTracker.getSourceOutputState().isPresent()) { LOGGER.info("Source output at least one state message"); } else { LOGGER.info("Source did not output any state messages"); } - if (destinationMessageTracker.getOutputState().isPresent()) { - LOGGER.info("State capture: Updated state to: {}", destinationMessageTracker.getOutputState()); - final State state = destinationMessageTracker.getOutputState().get(); + if (messageTracker.getDestinationOutputState().isPresent()) { + LOGGER.info("State capture: Updated state to: {}", messageTracker.getDestinationOutputState()); + final State state = messageTracker.getDestinationOutputState().get(); output.withState(state); } else if (syncInput.getState() != null) { LOGGER.warn("State capture: No new state, falling back on input state: {}", syncInput.getState()); @@ -196,7 +231,7 @@ private static Runnable getReplicationRunnable(final AirbyteSource source, final AirbyteDestination destination, final AtomicBoolean cancelled, final AirbyteMapper mapper, - final MessageTracker sourceMessageTracker, + final MessageTracker messageTracker, final Map mdc) { return () -> { MDC.setContextMap(mdc); @@ -208,7 +243,7 @@ private static Runnable 
getReplicationRunnable(final AirbyteSource source, if (messageOptional.isPresent()) { final AirbyteMessage message = mapper.mapMessage(messageOptional.get()); - sourceMessageTracker.accept(message); + messageTracker.acceptFromSource(message); destination.accept(message); recordsRead += 1; @@ -235,7 +270,7 @@ private static Runnable getReplicationRunnable(final AirbyteSource source, private static Runnable getDestinationOutputRunnable(final AirbyteDestination destination, final AtomicBoolean cancelled, - final MessageTracker destinationMessageTracker, + final MessageTracker messageTracker, final Map mdc) { return () -> { MDC.setContextMap(mdc); @@ -245,7 +280,7 @@ private static Runnable getDestinationOutputRunnable(final AirbyteDestination de final Optional messageOptional = destination.attemptRead(); if (messageOptional.isPresent()) { LOGGER.info("state in DefaultReplicationWorker from Destination: {}", messageOptional.get()); - destinationMessageTracker.accept(messageOptional.get()); + messageTracker.acceptFromDestination(messageOptional.get()); } } if (!cancelled.get() && destination.getExitValue() != 0) { diff --git a/airbyte-workers/src/main/java/io/airbyte/workers/protocols/airbyte/AirbyteMessageTracker.java b/airbyte-workers/src/main/java/io/airbyte/workers/protocols/airbyte/AirbyteMessageTracker.java index 8cbf4be5308b..4d4f93ed6f0a 100644 --- a/airbyte-workers/src/main/java/io/airbyte/workers/protocols/airbyte/AirbyteMessageTracker.java +++ b/airbyte-workers/src/main/java/io/airbyte/workers/protocols/airbyte/AirbyteMessageTracker.java @@ -4,51 +4,232 @@ package io.airbyte.workers.protocols.airbyte; +import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Charsets; +import com.google.common.collect.BiMap; +import com.google.common.collect.HashBiMap; +import com.google.common.hash.HashFunction; +import com.google.common.hash.Hashing; import io.airbyte.commons.json.Jsons; import io.airbyte.config.State; import 
io.airbyte.protocol.models.AirbyteMessage; +import io.airbyte.protocol.models.AirbyteRecordMessage; +import io.airbyte.protocol.models.AirbyteStateMessage; +import io.airbyte.workers.protocols.airbyte.StateDeltaTracker.StateDeltaTrackerException; +import java.util.HashMap; +import java.util.Map; import java.util.Optional; import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Collectors; +import lombok.extern.slf4j.Slf4j; +@Slf4j public class AirbyteMessageTracker implements MessageTracker { - private final AtomicLong recordCount; - private final AtomicLong numBytes; - private final AtomicReference outputState; + private static final long STATE_DELTA_TRACKER_MEMORY_LIMIT_BYTES = 20L * 1024L * 1024L; // 20 MiB, ~10% of default cloud worker memory + + private final AtomicReference sourceOutputState; + private final AtomicReference destinationOutputState; + private final AtomicLong totalEmittedStateMessages; + private final Map streamToRunningCount; + private final HashFunction hashFunction; + private final BiMap streamNameToIndex; + private final Map streamToTotalBytesEmitted; + private final Map streamToTotalRecordsEmitted; + private final StateDeltaTracker stateDeltaTracker; + + private short nextStreamIndex; + + /** + * If the StateDeltaTracker throws an exception, this flag is set to true and committed counts are + * not returned. 
+ */ + private boolean unreliableCommittedCounts; public AirbyteMessageTracker() { - this.recordCount = new AtomicLong(); - this.numBytes = new AtomicLong(); - this.outputState = new AtomicReference<>(); + this(new StateDeltaTracker(STATE_DELTA_TRACKER_MEMORY_LIMIT_BYTES)); + } + + @VisibleForTesting + protected AirbyteMessageTracker(final StateDeltaTracker stateDeltaTracker) { + this.sourceOutputState = new AtomicReference<>(); + this.destinationOutputState = new AtomicReference<>(); + this.totalEmittedStateMessages = new AtomicLong(0L); + this.streamToRunningCount = new HashMap<>(); + this.streamNameToIndex = HashBiMap.create(); + this.hashFunction = Hashing.murmur3_32_fixed(); + this.streamToTotalBytesEmitted = new HashMap<>(); + this.streamToTotalRecordsEmitted = new HashMap<>(); + this.stateDeltaTracker = stateDeltaTracker; + this.nextStreamIndex = 0; + this.unreliableCommittedCounts = false; } @Override - public void accept(final AirbyteMessage message) { - if (message.getType() == AirbyteMessage.Type.RECORD) { - recordCount.incrementAndGet(); - // todo (cgardens) - pretty wasteful to do an extra serialization just to get size. - numBytes.addAndGet(Jsons.serialize(message.getRecord().getData()).getBytes(Charsets.UTF_8).length); + public void acceptFromSource(final AirbyteMessage message) { + switch (message.getType()) { + case RECORD -> handleSourceEmittedRecord(message.getRecord()); + case STATE -> handleSourceEmittedState(message.getState()); + default -> log.warn("Invalid message type for message: {}", message); + } + } + + @Override + public void acceptFromDestination(final AirbyteMessage message) { + switch (message.getType()) { + case STATE -> handleDestinationEmittedState(message.getState()); + default -> log.warn("Invalid message type for message: {}", message); + } + } + + /** + * When a source emits a record, increment the running record count, the total record count, and the + * total byte count for the record's stream. 
+ */ + private void handleSourceEmittedRecord(final AirbyteRecordMessage recordMessage) { + final short streamIndex = getStreamIndex(recordMessage.getStream()); + + final long currentRunningCount = streamToRunningCount.getOrDefault(streamIndex, 0L); + streamToRunningCount.put(streamIndex, currentRunningCount + 1); + + final long currentTotalCount = streamToTotalRecordsEmitted.getOrDefault(streamIndex, 0L); + streamToTotalRecordsEmitted.put(streamIndex, currentTotalCount + 1); + + // todo (cgardens) - pretty wasteful to do an extra serialization just to get size. + final int numBytes = Jsons.serialize(recordMessage.getData()).getBytes(Charsets.UTF_8).length; + final long currentTotalStreamBytes = streamToTotalBytesEmitted.getOrDefault(streamIndex, 0L); + streamToTotalBytesEmitted.put(streamIndex, currentTotalStreamBytes + numBytes); + } + + /** + * When a source emits a state, persist the current running count per stream to the + * {@link StateDeltaTracker}. Then, reset the running count per stream so that new counts can start + * recording for the next state. Also add the state to list so that state order is tracked + * correctly. + */ + private void handleSourceEmittedState(final AirbyteStateMessage stateMessage) { + sourceOutputState.set(new State().withState(stateMessage.getData())); + totalEmittedStateMessages.incrementAndGet(); + final int stateHash = getStateHashCode(stateMessage); + try { + if (!unreliableCommittedCounts) { + stateDeltaTracker.addState(stateHash, streamToRunningCount); + } + } catch (final StateDeltaTrackerException e) { + log.error(e.getMessage(), e); + unreliableCommittedCounts = true; + } + streamToRunningCount.clear(); + } + + /** + * When a destination emits a state, mark all uncommitted states up to and including this state as + * committed in the {@link StateDeltaTracker}. Also record this state as the last committed state. 
+ */ + private void handleDestinationEmittedState(final AirbyteStateMessage stateMessage) { + destinationOutputState.set(new State().withState(stateMessage.getData())); + try { + if (!unreliableCommittedCounts) { + stateDeltaTracker.commitStateHash(getStateHashCode(stateMessage)); + } + } catch (final StateDeltaTrackerException e) { + log.error(e.getMessage(), e); + unreliableCommittedCounts = true; + } + } + + private short getStreamIndex(final String streamName) { + if (!streamNameToIndex.containsKey(streamName)) { + streamNameToIndex.put(streamName, nextStreamIndex); + nextStreamIndex++; } - if (message.getType() == AirbyteMessage.Type.STATE) { - outputState.set(new State().withState(message.getState().getData())); + return streamNameToIndex.get(streamName); + } + + private int getStateHashCode(final AirbyteStateMessage stateMessage) { + return hashFunction.hashBytes(Jsons.serialize(stateMessage.getData()).getBytes(Charsets.UTF_8)).hashCode(); + } + + @Override + public Optional getSourceOutputState() { + return Optional.ofNullable(sourceOutputState.get()); + } + + @Override + public Optional getDestinationOutputState() { + return Optional.ofNullable(destinationOutputState.get()); + } + + /** + * Fetch committed stream index to record count from the {@link StateDeltaTracker}. Then, swap out + * stream indices for stream names. If the delta tracker has exceeded its capacity, return empty + * because committed record counts cannot be reliably computed. 
+ */ + @Override + public Optional> getStreamToCommittedRecords() { + if (unreliableCommittedCounts) { + return Optional.empty(); } + final Map streamIndexToCommittedRecordCount = stateDeltaTracker.getStreamToCommittedRecords(); + return Optional.of( + streamIndexToCommittedRecordCount.entrySet().stream().collect( + Collectors.toMap( + entry -> streamNameToIndex.inverse().get(entry.getKey()), + Map.Entry::getValue))); + } + + /** + * Swap out stream indices for stream names and return total records emitted by stream. + */ + @Override + public Map getStreamToEmittedRecords() { + return streamToTotalRecordsEmitted.entrySet().stream().collect(Collectors.toMap( + entry -> streamNameToIndex.inverse().get(entry.getKey()), + Map.Entry::getValue)); } + /** + * Swap out stream indices for stream names and return total bytes emitted by stream. + */ @Override - public long getRecordCount() { - return recordCount.get(); + public Map getStreamToEmittedBytes() { + return streamToTotalBytesEmitted.entrySet().stream().collect(Collectors.toMap( + entry -> streamNameToIndex.inverse().get(entry.getKey()), + Map.Entry::getValue)); } + /** + * Compute sum of emitted record counts across all streams. + */ @Override - public long getBytesCount() { - return numBytes.get(); + public long getTotalRecordsEmitted() { + return streamToTotalRecordsEmitted.values().stream().reduce(0L, Long::sum); + } + + /** + * Compute sum of emitted bytes across all streams. + */ + @Override + public long getTotalBytesEmitted() { + return streamToTotalBytesEmitted.values().stream().reduce(0L, Long::sum); + } + + /** + * Compute sum of committed record counts across all streams. If the delta tracker has exceeded its + * capacity, return empty because committed record counts cannot be reliably computed. 
+ */ + @Override + public Optional getTotalRecordsCommitted() { + if (unreliableCommittedCounts) { + return Optional.empty(); + } + return Optional.of(stateDeltaTracker.getStreamToCommittedRecords().values().stream().reduce(0L, Long::sum)); } @Override - public Optional getOutputState() { - return Optional.ofNullable(outputState.get()); + public Long getTotalStateMessagesEmitted() { + return totalEmittedStateMessages.get(); } } diff --git a/airbyte-workers/src/main/java/io/airbyte/workers/protocols/airbyte/MessageTracker.java b/airbyte-workers/src/main/java/io/airbyte/workers/protocols/airbyte/MessageTracker.java index 5213abf76a99..9e0a770e60cf 100644 --- a/airbyte-workers/src/main/java/io/airbyte/workers/protocols/airbyte/MessageTracker.java +++ b/airbyte-workers/src/main/java/io/airbyte/workers/protocols/airbyte/MessageTracker.java @@ -6,42 +6,100 @@ import io.airbyte.config.State; import io.airbyte.protocol.models.AirbyteMessage; +import java.util.Map; import java.util.Optional; -import java.util.function.Consumer; /** * Interface to handle extracting metadata from the stream of data flowing from a Source to a * Destination. */ -public interface MessageTracker extends Consumer { +public interface MessageTracker { /** - * Accepts an AirbyteMessage and tracks any metadata about it that is required by the Platform. + * Accepts an AirbyteMessage emitted from a source and tracks any metadata about it that is required + * by the Platform. * * @param message message to derive metadata from. */ - @Override - void accept(AirbyteMessage message); + void acceptFromSource(AirbyteMessage message); /** - * Gets the records replicated. + * Accepts an AirbyteMessage emitted from a destination and tracks any metadata about it that is + * required by the Platform. * - * @return total records that passed from Source to Destination. + * @param message message to derive metadata from. 
+ */ + void acceptFromDestination(AirbyteMessage message); + + /** + * Get the current source state of the stream. + * + * @return returns the last StateMessage that was accepted from the source. If no StateMessage was + * accepted, empty. + */ + Optional getSourceOutputState(); + + /** + * Get the current destination state of the stream. + * + * @return returns the last StateMessage that was accepted from the destination. If no StateMessage + * was accepted, empty. + */ + Optional getDestinationOutputState(); + + /** + * Get the per-stream committed record count. + * + * @return returns a map of committed record count by stream name. If committed record counts cannot + * be computed, empty. + */ + Optional> getStreamToCommittedRecords(); + + /** + * Get the per-stream emitted record count. This includes messages that were emitted by the source, + * but never committed by the destination. + * + * @return returns a map of emitted record count by stream name. + */ + Map getStreamToEmittedRecords(); + + /** + * Get the per-stream emitted byte count. This includes messages that were emitted by the source, + * but never committed by the destination. + * + * @return returns a map of emitted byte count by stream name. + */ + Map getStreamToEmittedBytes(); + + /** + * Get the overall emitted record count. This includes messages that were emitted by the source, but + * never committed by the destination. + * + * @return returns the total count of emitted records across all streams. + */ + long getTotalRecordsEmitted(); + + /** + * Get the overall emitted bytes. This includes messages that were emitted by the source, but never + * committed by the destination. + * + * @return returns the total emitted bytes across all streams. */ - long getRecordCount(); + long getTotalBytesEmitted(); /** - * Gets the bytes replicated. + * Get the overall committed record count. * - * @return total bytes that passed from Source to Destination. 
+ * @return returns the total count of committed records across all streams. If total committed + * record count cannot be computed, empty. */ - long getBytesCount(); + Optional getTotalRecordsCommitted(); /** - * Get the current state of the stream. + * Get the overall emitted state message count. * - * @return returns the last StateMessage that was accepted. If no StateMessage was accepted, empty. + * @return returns the total count of emitted state messages. */ - Optional getOutputState(); + Long getTotalStateMessagesEmitted(); } diff --git a/airbyte-workers/src/main/java/io/airbyte/workers/protocols/airbyte/StateDeltaTracker.java b/airbyte-workers/src/main/java/io/airbyte/workers/protocols/airbyte/StateDeltaTracker.java new file mode 100644 index 000000000000..93963fd38e01 --- /dev/null +++ b/airbyte-workers/src/main/java/io/airbyte/workers/protocols/airbyte/StateDeltaTracker.java @@ -0,0 +1,158 @@ +/* + * Copyright (c) 2021 Airbyte, Inc., all rights reserved. + */ + +package io.airbyte.workers.protocols.airbyte; + +import com.google.common.annotations.VisibleForTesting; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import lombok.extern.slf4j.Slf4j; + +/** + * This class tracks "deltas" between states in compact {@code byte[]}s with the following schema: + * + *
+ *  [(state hash),(stream index),(record count)...] with the last two elements repeating per stream in the delta.
+ * 
+ *

+ * This class also maintains a {@code Set} of {@code committedStateHashes} so that it can accumulate + * both committed and total record counts per stream. + *

+ * The StateDeltaTracker is initialized with a memory limit. If this memory limit is exceeded, new + * state deltas will not be added and per-stream record counts will not be able to be computed. + * This is to prevent OutOfMemoryErrors from crashing the sync. + */ +@Slf4j +public class StateDeltaTracker { + + private static final int STATE_HASH_BYTES = Integer.BYTES; + private static final int STREAM_INDEX_BYTES = Short.BYTES; + private static final int RECORD_COUNT_BYTES = Long.BYTES; + private static final int BYTES_PER_STREAM = STREAM_INDEX_BYTES + RECORD_COUNT_BYTES; + + private final Set committedStateHashes; + private final Map streamToCommittedRecords; + + /** + * Every time a state is added, a new byte[] containing the state hash and per-stream delta will be + * added to this list. Every time a state is committed, state deltas up to the committed state are + * removed from the head of the list and aggregated into the committed count map. The source thread + * adds while the destination thread removes, so synchronization is necessary to provide + * thread-safety. + */ + @VisibleForTesting + protected final List stateDeltas; + + @VisibleForTesting + protected long remainingCapacity; + @VisibleForTesting + protected boolean capacityExceeded; + + public StateDeltaTracker(final long memoryLimitBytes) { + this.committedStateHashes = new HashSet<>(); + this.streamToCommittedRecords = new HashMap<>(); + this.stateDeltas = new ArrayList<>(); + this.remainingCapacity = memoryLimitBytes; + this.capacityExceeded = false; + } + + /** + * Converts the given state hash and per-stream record count map into a {@code byte[]} and stores + * it. + *

+ * This method leverages a synchronized block to provide thread safety between the source thread + * calling addState while the destination thread calls commitStateHash. + * + * @throws StateDeltaTrackerException thrown when the memory footprint of stateDeltas exceeds + * available capacity. + */ + public void addState(final int stateHash, final Map streamIndexToRecordCount) throws StateDeltaTrackerException { + synchronized (this) { + final int size = STATE_HASH_BYTES + (streamIndexToRecordCount.size() * BYTES_PER_STREAM); + + if (capacityExceeded || remainingCapacity < size) { + capacityExceeded = true; + throw new StateDeltaTrackerException("Memory capacity is exceeded for StateDeltaTracker."); + } + + final ByteBuffer delta = ByteBuffer.allocate(size); + + delta.putInt(stateHash); + + for (final Map.Entry entry : streamIndexToRecordCount.entrySet()) { + delta.putShort(entry.getKey()); + delta.putLong(entry.getValue()); + } + + stateDeltas.add(delta.array()); + remainingCapacity -= delta.array().length; + } + } + + /** + * Mark the given {@code stateHash} as committed. + *

+ * This method leverages a synchronized block to provide thread safety between the source thread + * calling addState while the destination thread calls commitStateHash. + * + * @throws StateDeltaTrackerException thrown when committed counts can no longer be reliably + * computed. + */ + public void commitStateHash(final int stateHash) throws StateDeltaTrackerException { + synchronized (this) { + if (capacityExceeded) { + throw new StateDeltaTrackerException("Memory capacity exceeded for StateDeltaTracker, so states cannot be reliably committed"); + } + if (committedStateHashes.contains(stateHash)) { + throw new StateDeltaTrackerException( + String.format("State hash %d was already committed, likely indicating a state hash collision", stateHash)); + } + + committedStateHashes.add(stateHash); + int currStateHash; + do { + if (stateDeltas.isEmpty()) { + throw new StateDeltaTrackerException(String.format("Delta was not stored for state hash %d", stateHash)); + } + // as deltas are removed and aggregated into committed count map, reclaim capacity + final ByteBuffer currDelta = ByteBuffer.wrap(stateDeltas.remove(0)); + remainingCapacity += currDelta.capacity(); + + currStateHash = currDelta.getInt(); + + final int numStreams = (currDelta.capacity() - STATE_HASH_BYTES) / BYTES_PER_STREAM; + for (int i = 0; i < numStreams; i++) { + final short streamIndex = currDelta.getShort(); + final long recordCount = currDelta.getLong(); + + // aggregate delta into committed count map + final long committedRecordCount = streamToCommittedRecords.getOrDefault(streamIndex, 0L); + streamToCommittedRecords.put(streamIndex, committedRecordCount + recordCount); + } + } while (currStateHash != stateHash); // repeat until each delta up to the committed state is aggregated + } + } + + public Map getStreamToCommittedRecords() { + return streamToCommittedRecords; + } + + /** + * Thrown when the StateDeltaTracker encounters an issue that prevents it from reliably computing + * committed record 
deltas. + */ + public static class StateDeltaTrackerException extends Exception { + + public StateDeltaTrackerException(final String message) { + super(message); + } + + } + +} diff --git a/airbyte-workers/src/main/java/io/airbyte/workers/temporal/sync/ReplicationActivityImpl.java b/airbyte-workers/src/main/java/io/airbyte/workers/temporal/sync/ReplicationActivityImpl.java index bb4aae0e1c84..8eb3533b4e1d 100644 --- a/airbyte-workers/src/main/java/io/airbyte/workers/temporal/sync/ReplicationActivityImpl.java +++ b/airbyte-workers/src/main/java/io/airbyte/workers/temporal/sync/ReplicationActivityImpl.java @@ -117,7 +117,7 @@ public StandardSyncOutput replicate(final JobRunConfig jobRunConfig, return fullSyncInput; }; - CheckedSupplier, Exception> workerFactory; + final CheckedSupplier, Exception> workerFactory; if (containerOrchestratorEnabled) { workerFactory = getContainerLauncherWorkerFactory(sourceLauncherConfig, destinationLauncherConfig, jobRunConfig, syncInput); @@ -156,6 +156,8 @@ private static StandardSyncOutput reduceReplicationOutput(final ReplicationOutpu syncSummary.setStartTime(output.getReplicationAttemptSummary().getStartTime()); syncSummary.setEndTime(output.getReplicationAttemptSummary().getEndTime()); syncSummary.setStatus(output.getReplicationAttemptSummary().getStatus()); + syncSummary.setTotalStats(output.getReplicationAttemptSummary().getTotalStats()); + syncSummary.setStreamStats(output.getReplicationAttemptSummary().getStreamStats()); final StandardSyncOutput standardSyncOutput = new StandardSyncOutput(); standardSyncOutput.setState(output.getState()); @@ -195,7 +197,6 @@ private CheckedSupplier, Exception> airbyteSource, new NamespacingMapper(syncInput.getNamespaceDefinition(), syncInput.getNamespaceFormat(), syncInput.getPrefix()), new DefaultAirbyteDestination(workerConfigs, destinationLauncher), - new AirbyteMessageTracker(), new AirbyteMessageTracker()); }; } diff --git 
a/airbyte-workers/src/test/java/io/airbyte/workers/DefaultReplicationWorkerTest.java b/airbyte-workers/src/test/java/io/airbyte/workers/DefaultReplicationWorkerTest.java index a0f53ce0f005..f62e7b4d848e 100644 --- a/airbyte-workers/src/test/java/io/airbyte/workers/DefaultReplicationWorkerTest.java +++ b/airbyte-workers/src/test/java/io/airbyte/workers/DefaultReplicationWorkerTest.java @@ -29,6 +29,8 @@ import io.airbyte.config.StandardSyncInput; import io.airbyte.config.StandardSyncSummary.ReplicationStatus; import io.airbyte.config.State; +import io.airbyte.config.StreamSyncStats; +import io.airbyte.config.SyncStats; import io.airbyte.config.WorkerDestinationConfig; import io.airbyte.config.WorkerSourceConfig; import io.airbyte.config.helpers.LogClientSingleton; @@ -44,6 +46,8 @@ import java.nio.file.Files; import java.nio.file.Path; import java.time.Duration; +import java.util.Collections; +import java.util.List; import java.util.Optional; import java.util.Set; import java.util.concurrent.atomic.AtomicReference; @@ -78,8 +82,7 @@ class DefaultReplicationWorkerTest { private StandardSyncInput syncInput; private WorkerSourceConfig sourceConfig; private WorkerDestinationConfig destinationConfig; - private AirbyteMessageTracker sourceMessageTracker; - private AirbyteMessageTracker destinationMessageTracker; + private AirbyteMessageTracker messageTracker; @SuppressWarnings("unchecked") @BeforeEach @@ -97,8 +100,7 @@ void setup() throws Exception { source = mock(AirbyteSource.class); mapper = mock(NamespacingMapper.class); destination = mock(AirbyteDestination.class); - sourceMessageTracker = mock(AirbyteMessageTracker.class); - destinationMessageTracker = mock(AirbyteMessageTracker.class); + messageTracker = mock(AirbyteMessageTracker.class); when(source.isFinished()).thenReturn(false, false, false, true); when(destination.isFinished()).thenReturn(false, false, false, true); @@ -121,8 +123,7 @@ void test() throws Exception { source, mapper, destination, - 
sourceMessageTracker, - destinationMessageTracker); + messageTracker); worker.run(syncInput, jobRoot); @@ -144,8 +145,7 @@ void testSourceNonZeroExitValue() throws Exception { source, mapper, destination, - sourceMessageTracker, - destinationMessageTracker); + messageTracker); final ReplicationOutput output = worker.run(syncInput, jobRoot); assertEquals(ReplicationStatus.FAILED, output.getReplicationAttemptSummary().getStatus()); @@ -161,8 +161,7 @@ void testDestinationNonZeroExitValue() throws Exception { source, mapper, destination, - sourceMessageTracker, - destinationMessageTracker); + messageTracker); final ReplicationOutput output = worker.run(syncInput, jobRoot); assertEquals(ReplicationStatus.FAILED, output.getReplicationAttemptSummary().getStatus()); @@ -181,8 +180,7 @@ void testLoggingInThreads() throws IOException, WorkerException { source, mapper, destination, - sourceMessageTracker, - destinationMessageTracker); + messageTracker); worker.run(syncInput, jobRoot); @@ -213,7 +211,7 @@ void testLogMaskRegex() throws IOException { void testCancellation() throws InterruptedException { final AtomicReference output = new AtomicReference<>(); when(source.isFinished()).thenReturn(false); - when(destinationMessageTracker.getOutputState()).thenReturn(Optional.of(new State().withState(STATE_MESSAGE.getState().getData()))); + when(messageTracker.getDestinationOutputState()).thenReturn(Optional.of(new State().withState(STATE_MESSAGE.getState().getData()))); final ReplicationWorker worker = new DefaultReplicationWorker( JOB_ID, @@ -221,8 +219,7 @@ void testCancellation() throws InterruptedException { source, mapper, destination, - sourceMessageTracker, - destinationMessageTracker); + messageTracker); final Thread workerThread = new Thread(() -> { try { @@ -235,7 +232,7 @@ void testCancellation() throws InterruptedException { workerThread.start(); // verify the worker is actually running before we kill it. 
- while (Mockito.mockingDetails(sourceMessageTracker).getInvocations().size() < 5) { + while (Mockito.mockingDetails(messageTracker).getInvocations().size() < 5) { LOGGER.info("waiting for worker to start running"); sleep(100); } @@ -249,9 +246,12 @@ void testCancellation() throws InterruptedException { @Test void testPopulatesOutputOnSuccess() throws WorkerException { final JsonNode expectedState = Jsons.jsonNode(ImmutableMap.of("updated_at", 10L)); - when(sourceMessageTracker.getRecordCount()).thenReturn(12L); - when(sourceMessageTracker.getBytesCount()).thenReturn(100L); - when(destinationMessageTracker.getOutputState()).thenReturn(Optional.of(new State().withState(expectedState))); + when(messageTracker.getDestinationOutputState()).thenReturn(Optional.of(new State().withState(expectedState))); + when(messageTracker.getTotalRecordsEmitted()).thenReturn(12L); + when(messageTracker.getTotalBytesEmitted()).thenReturn(100L); + when(messageTracker.getTotalStateMessagesEmitted()).thenReturn(3L); + when(messageTracker.getStreamToEmittedBytes()).thenReturn(Collections.singletonMap("stream1", 100L)); + when(messageTracker.getStreamToEmittedRecords()).thenReturn(Collections.singletonMap("stream1", 12L)); final ReplicationWorker worker = new DefaultReplicationWorker( JOB_ID, @@ -259,15 +259,27 @@ void testPopulatesOutputOnSuccess() throws WorkerException { source, mapper, destination, - sourceMessageTracker, - destinationMessageTracker); + messageTracker); final ReplicationOutput actual = worker.run(syncInput, jobRoot); final ReplicationOutput replicationOutput = new ReplicationOutput() .withReplicationAttemptSummary(new ReplicationAttemptSummary() .withRecordsSynced(12L) .withBytesSynced(100L) - .withStatus(ReplicationStatus.COMPLETED)) + .withStatus(ReplicationStatus.COMPLETED) + .withTotalStats(new SyncStats() + .withRecordsEmitted(12L) + .withBytesEmitted(100L) + .withStateMessagesEmitted(3L) + .withRecordsCommitted(12L)) // since success, should use emitted count + 
.withStreamStats(Collections.singletonList( + new StreamSyncStats() + .withStreamName("stream1") + .withStats(new SyncStats() + .withBytesEmitted(100L) + .withRecordsEmitted(12L) + .withRecordsCommitted(12L) // since success, should use emitted count + .withStateMessagesEmitted(null))))) .withOutputCatalog(syncInput.getCatalog()) .withState(new State().withState(expectedState)); @@ -291,7 +303,7 @@ void testPopulatesOutputOnSuccess() throws WorkerException { @Test void testPopulatesStateOnFailureIfAvailable() throws Exception { doThrow(new IllegalStateException("induced exception")).when(source).close(); - when(destinationMessageTracker.getOutputState()).thenReturn(Optional.of(new State().withState(STATE_MESSAGE.getState().getData()))); + when(messageTracker.getDestinationOutputState()).thenReturn(Optional.of(new State().withState(STATE_MESSAGE.getState().getData()))); final ReplicationWorker worker = new DefaultReplicationWorker( JOB_ID, @@ -299,8 +311,7 @@ void testPopulatesStateOnFailureIfAvailable() throws Exception { source, mapper, destination, - sourceMessageTracker, - destinationMessageTracker); + messageTracker); final ReplicationOutput actual = worker.run(syncInput, jobRoot); assertNotNull(actual); @@ -317,8 +328,7 @@ void testRetainsStateOnFailureIfNewStateNotAvailable() throws Exception { source, mapper, destination, - sourceMessageTracker, - destinationMessageTracker); + messageTracker); final ReplicationOutput actual = worker.run(syncInput, jobRoot); @@ -326,6 +336,45 @@ void testRetainsStateOnFailureIfNewStateNotAvailable() throws Exception { assertEquals(syncInput.getState().getState(), actual.getState().getState()); } + @Test + void testPopulatesStatsOnFailureIfAvailable() throws Exception { + doThrow(new IllegalStateException("induced exception")).when(source).close(); + when(messageTracker.getTotalRecordsEmitted()).thenReturn(12L); + when(messageTracker.getTotalBytesEmitted()).thenReturn(100L); + 
when(messageTracker.getTotalRecordsCommitted()).thenReturn(Optional.of(6L)); + when(messageTracker.getTotalStateMessagesEmitted()).thenReturn(3L); + when(messageTracker.getStreamToEmittedBytes()).thenReturn(Collections.singletonMap("stream1", 100L)); + when(messageTracker.getStreamToEmittedRecords()).thenReturn(Collections.singletonMap("stream1", 12L)); + when(messageTracker.getStreamToCommittedRecords()).thenReturn(Optional.of(Collections.singletonMap("stream1", 6L))); + + final ReplicationWorker worker = new DefaultReplicationWorker( + JOB_ID, + JOB_ATTEMPT, + source, + mapper, + destination, + messageTracker); + + final ReplicationOutput actual = worker.run(syncInput, jobRoot); + final SyncStats expectedTotalStats = new SyncStats() + .withRecordsEmitted(12L) + .withBytesEmitted(100L) + .withStateMessagesEmitted(3L) + .withRecordsCommitted(6L); + final List expectedStreamStats = Collections.singletonList( + new StreamSyncStats() + .withStreamName("stream1") + .withStats(new SyncStats() + .withBytesEmitted(100L) + .withRecordsEmitted(12L) + .withRecordsCommitted(6L) + .withStateMessagesEmitted(null))); + + assertNotNull(actual); + assertEquals(expectedTotalStats, actual.getReplicationAttemptSummary().getTotalStats()); + assertEquals(expectedStreamStats, actual.getReplicationAttemptSummary().getStreamStats()); + } + @Test void testDoesNotPopulatesStateOnFailureIfNotAvailable() throws Exception { final StandardSyncInput syncInputWithoutState = Jsons.clone(syncInput); @@ -339,8 +388,7 @@ void testDoesNotPopulatesStateOnFailureIfNotAvailable() throws Exception { source, mapper, destination, - sourceMessageTracker, - destinationMessageTracker); + messageTracker); final ReplicationOutput actual = worker.run(syncInputWithoutState, jobRoot); @@ -350,7 +398,7 @@ void testDoesNotPopulatesStateOnFailureIfNotAvailable() throws Exception { @Test void testDoesNotPopulateOnIrrecoverableFailure() { - doThrow(new IllegalStateException("induced 
exception")).when(sourceMessageTracker).getRecordCount(); + doThrow(new IllegalStateException("induced exception")).when(messageTracker).getTotalRecordsEmitted(); final ReplicationWorker worker = new DefaultReplicationWorker( JOB_ID, @@ -358,8 +406,7 @@ void testDoesNotPopulateOnIrrecoverableFailure() { source, mapper, destination, - sourceMessageTracker, - destinationMessageTracker); + messageTracker); assertThrows(WorkerException.class, () -> worker.run(syncInput, jobRoot)); } diff --git a/airbyte-workers/src/test/java/io/airbyte/workers/protocols/airbyte/AirbyteMessageTrackerTest.java b/airbyte-workers/src/test/java/io/airbyte/workers/protocols/airbyte/AirbyteMessageTrackerTest.java index 030e4c403e33..8634c6eedbdb 100644 --- a/airbyte-workers/src/test/java/io/airbyte/workers/protocols/airbyte/AirbyteMessageTrackerTest.java +++ b/airbyte-workers/src/test/java/io/airbyte/workers/protocols/airbyte/AirbyteMessageTrackerTest.java @@ -7,58 +7,275 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; -import com.fasterxml.jackson.databind.JsonNode; -import com.google.common.collect.ImmutableMap; +import com.google.common.base.Charsets; import io.airbyte.commons.json.Jsons; import io.airbyte.config.State; import io.airbyte.protocol.models.AirbyteMessage; import io.airbyte.protocol.models.AirbyteRecordMessage; import io.airbyte.protocol.models.AirbyteStateMessage; +import io.airbyte.workers.protocols.airbyte.StateDeltaTracker.StateDeltaTrackerException; +import java.util.HashMap; +import java.util.Map; +import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import org.testcontainers.shaded.com.google.common.base.Charsets; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.junit.jupiter.MockitoExtension; +@ExtendWith(MockitoExtension.class) class AirbyteMessageTrackerTest { + private static final String 
STREAM_1 = "stream1"; + private static final String STREAM_2 = "stream2"; + private static final String STREAM_3 = "stream3"; + + private AirbyteMessageTracker messageTracker; + + @Mock + private StateDeltaTracker mStateDeltaTracker; + + @BeforeEach + public void setup() { + this.messageTracker = new AirbyteMessageTracker(mStateDeltaTracker); + } + @Test - public void testIncrementsWhenRecord() { - final AirbyteMessage message = new AirbyteMessage() - .withType(AirbyteMessage.Type.RECORD) - .withRecord(new AirbyteRecordMessage().withData(Jsons.jsonNode(ImmutableMap.of("name", "rudolph")))); + public void testGetTotalRecordsStatesAndBytesEmitted() { + final AirbyteMessage r1 = createRecordMessage(STREAM_1, 123); + final AirbyteMessage s1 = createStateMessage(1); + final AirbyteMessage s2 = createStateMessage(2); - final AirbyteMessageTracker messageTracker = new AirbyteMessageTracker(); - messageTracker.accept(message); - messageTracker.accept(message); - messageTracker.accept(message); + messageTracker.acceptFromSource(r1); + messageTracker.acceptFromSource(r1); + messageTracker.acceptFromSource(r1); + messageTracker.acceptFromSource(s1); + messageTracker.acceptFromSource(s2); - assertEquals(3, messageTracker.getRecordCount()); - assertEquals(3 * Jsons.serialize(message.getRecord().getData()).getBytes(Charsets.UTF_8).length, messageTracker.getBytesCount()); + assertEquals(3, messageTracker.getTotalRecordsEmitted()); + assertEquals(3 * Jsons.serialize(r1.getRecord().getData()).getBytes(Charsets.UTF_8).length, messageTracker.getTotalBytesEmitted()); + assertEquals(2, messageTracker.getTotalStateMessagesEmitted()); } @Test - public void testRetainsLatestState() { - final JsonNode oldStateValue = Jsons.jsonNode(ImmutableMap.builder().put("lastSync", "1598900000").build()); - final AirbyteMessage oldStateMessage = new AirbyteMessage() - .withType(AirbyteMessage.Type.STATE) - .withState(new AirbyteStateMessage().withData(oldStateValue)); + public void 
testRetainsLatestSourceAndDestinationState() { + final int s1Value = 111; + final int s2Value = 222; + final int s3Value = 333; + final AirbyteMessage s1 = createStateMessage(s1Value); + final AirbyteMessage s2 = createStateMessage(s2Value); + final AirbyteMessage s3 = createStateMessage(s3Value); - final JsonNode newStateValue = Jsons.jsonNode(ImmutableMap.builder().put("lastSync", "1598993526").build()); - final AirbyteMessage newStateMessage = new AirbyteMessage() - .withType(AirbyteMessage.Type.STATE) - .withState(new AirbyteStateMessage().withData(newStateValue)); + messageTracker.acceptFromSource(s1); + messageTracker.acceptFromSource(s2); + messageTracker.acceptFromSource(s3); + messageTracker.acceptFromDestination(s1); + messageTracker.acceptFromDestination(s2); - final AirbyteMessageTracker messageTracker = new AirbyteMessageTracker(); - messageTracker.accept(oldStateMessage); - messageTracker.accept(oldStateMessage); - messageTracker.accept(newStateMessage); + assertTrue(messageTracker.getSourceOutputState().isPresent()); + assertEquals(new State().withState(Jsons.jsonNode(s3Value)), messageTracker.getSourceOutputState().get()); - assertTrue(messageTracker.getOutputState().isPresent()); - assertEquals(new State().withState(newStateValue), messageTracker.getOutputState().get()); + assertTrue(messageTracker.getDestinationOutputState().isPresent()); + assertEquals(new State().withState(Jsons.jsonNode(s2Value)), messageTracker.getDestinationOutputState().get()); } @Test public void testReturnEmptyStateIfNoneEverAccepted() { - final AirbyteMessageTracker MessageTracker = new AirbyteMessageTracker(); - assertTrue(MessageTracker.getOutputState().isEmpty()); + assertTrue(messageTracker.getSourceOutputState().isEmpty()); + assertTrue(messageTracker.getDestinationOutputState().isEmpty()); + } + + @Test + public void testEmittedRecordsByStream() { + final AirbyteMessage r1 = createRecordMessage(STREAM_1, 1); + final AirbyteMessage r2 = createRecordMessage(STREAM_2, 
2); + final AirbyteMessage r3 = createRecordMessage(STREAM_3, 3); + + messageTracker.acceptFromSource(r1); + messageTracker.acceptFromSource(r2); + messageTracker.acceptFromSource(r2); + messageTracker.acceptFromSource(r3); + messageTracker.acceptFromSource(r3); + messageTracker.acceptFromSource(r3); + + final Map expected = new HashMap<>(); + expected.put(STREAM_1, 1L); + expected.put(STREAM_2, 2L); + expected.put(STREAM_3, 3L); + + assertEquals(expected, messageTracker.getStreamToEmittedRecords()); + } + + @Test + public void testEmittedBytesByStream() { + final AirbyteMessage r1 = createRecordMessage(STREAM_1, 1); + final AirbyteMessage r2 = createRecordMessage(STREAM_2, 2); + final AirbyteMessage r3 = createRecordMessage(STREAM_3, 3); + + final long r1Bytes = Jsons.serialize(r1.getRecord().getData()).getBytes(Charsets.UTF_8).length; + final long r2Bytes = Jsons.serialize(r2.getRecord().getData()).getBytes(Charsets.UTF_8).length; + final long r3Bytes = Jsons.serialize(r3.getRecord().getData()).getBytes(Charsets.UTF_8).length; + + messageTracker.acceptFromSource(r1); + messageTracker.acceptFromSource(r2); + messageTracker.acceptFromSource(r2); + messageTracker.acceptFromSource(r3); + messageTracker.acceptFromSource(r3); + messageTracker.acceptFromSource(r3); + + final Map expected = new HashMap<>(); + expected.put(STREAM_1, r1Bytes); + expected.put(STREAM_2, r2Bytes * 2); + expected.put(STREAM_3, r3Bytes * 3); + + assertEquals(expected, messageTracker.getStreamToEmittedBytes()); + } + + @Test + public void testGetCommittedRecordsByStream() { + final AirbyteMessage r1 = createRecordMessage(STREAM_1, 1); + final AirbyteMessage r2 = createRecordMessage(STREAM_2, 2); + final AirbyteMessage r3 = createRecordMessage(STREAM_3, 3); + final AirbyteMessage s1 = createStateMessage(1); + final AirbyteMessage s2 = createStateMessage(2); + + messageTracker.acceptFromSource(r1); // should make stream 1 index 0 + messageTracker.acceptFromSource(r2); // should make stream 2 index 
1 + messageTracker.acceptFromSource(r2); + messageTracker.acceptFromSource(s1); // emit state 1 + messageTracker.acceptFromSource(r1); + messageTracker.acceptFromSource(r2); + messageTracker.acceptFromDestination(s1); // commit state 1 + messageTracker.acceptFromSource(r3); // should make stream 3 index 2 + messageTracker.acceptFromSource(r1); + messageTracker.acceptFromSource(s2); // emit state 2 + + final Map countsByIndex = new HashMap<>(); + final Map expected = new HashMap<>(); + Mockito.when(mStateDeltaTracker.getStreamToCommittedRecords()).thenReturn(countsByIndex); + + countsByIndex.put((short) 0, 1L); + countsByIndex.put((short) 1, 2L); + // result only contains counts up to state 1 + expected.put(STREAM_1, 1L); + expected.put(STREAM_2, 2L); + assertEquals(expected, messageTracker.getStreamToCommittedRecords().get()); + + countsByIndex.clear(); + expected.clear(); + messageTracker.acceptFromDestination(s2); // now commit state 2 + countsByIndex.put((short) 0, 3L); + countsByIndex.put((short) 1, 3L); + countsByIndex.put((short) 2, 1L); + // result updated with counts between state 1 and state 2 + expected.put(STREAM_1, 3L); + expected.put(STREAM_2, 3L); + expected.put(STREAM_3, 1L); + assertEquals(expected, messageTracker.getStreamToCommittedRecords().get()); + } + + @Test + public void testGetCommittedRecordsByStream_emptyWhenAddStateThrowsException() throws Exception { + Mockito.doThrow(new StateDeltaTrackerException("induced exception")).when(mStateDeltaTracker).addState(Mockito.anyInt(), Mockito.anyMap()); + + final AirbyteMessage r1 = createRecordMessage(STREAM_1, 1); + final AirbyteMessage s1 = createStateMessage(1); + + messageTracker.acceptFromSource(r1); + messageTracker.acceptFromSource(s1); + messageTracker.acceptFromDestination(s1); + + assertTrue(messageTracker.getStreamToCommittedRecords().isEmpty()); + } + + @Test + public void testGetCommittedRecordsByStream_emptyWhenCommitStateHashThrowsException() throws Exception { + Mockito.doThrow(new 
StateDeltaTrackerException("induced exception")).when(mStateDeltaTracker).commitStateHash(Mockito.anyInt()); + + final AirbyteMessage r1 = createRecordMessage(STREAM_1, 1); + final AirbyteMessage s1 = createStateMessage(1); + + messageTracker.acceptFromSource(r1); + messageTracker.acceptFromSource(s1); + messageTracker.acceptFromDestination(s1); + + assertTrue(messageTracker.getStreamToCommittedRecords().isEmpty()); + } + + @Test + public void testTotalRecordsCommitted() { + final AirbyteMessage r1 = createRecordMessage(STREAM_1, 1); + final AirbyteMessage r2 = createRecordMessage(STREAM_2, 2); + final AirbyteMessage r3 = createRecordMessage(STREAM_3, 3); + final AirbyteMessage s1 = createStateMessage(1); + final AirbyteMessage s2 = createStateMessage(2); + + messageTracker.acceptFromSource(r1); + messageTracker.acceptFromSource(r2); + messageTracker.acceptFromSource(r2); + messageTracker.acceptFromSource(s1); // emit state 1 + messageTracker.acceptFromSource(r1); + messageTracker.acceptFromSource(r2); + messageTracker.acceptFromDestination(s1); // commit state 1 + messageTracker.acceptFromSource(r3); + messageTracker.acceptFromSource(r1); + messageTracker.acceptFromSource(s2); // emit state 2 + + final Map countsByIndex = new HashMap<>(); + Mockito.when(mStateDeltaTracker.getStreamToCommittedRecords()).thenReturn(countsByIndex); + + countsByIndex.put((short) 0, 1L); + countsByIndex.put((short) 1, 2L); + // result only contains counts up to state 1 + assertEquals(3L, messageTracker.getTotalRecordsCommitted().get()); + + countsByIndex.clear(); + messageTracker.acceptFromDestination(s2); // now commit state 2 + countsByIndex.put((short) 0, 3L); + countsByIndex.put((short) 1, 3L); + countsByIndex.put((short) 2, 1L); + // result updated with counts between state 1 and state 2 + assertEquals(7L, messageTracker.getTotalRecordsCommitted().get()); + } + + @Test + public void testGetTotalRecordsCommitted_emptyWhenAddStateThrowsException() throws Exception { + 
Mockito.doThrow(new StateDeltaTrackerException("induced exception")).when(mStateDeltaTracker).addState(Mockito.anyInt(), Mockito.anyMap()); + + final AirbyteMessage r1 = createRecordMessage(STREAM_1, 1); + final AirbyteMessage s1 = createStateMessage(1); + + messageTracker.acceptFromSource(r1); + messageTracker.acceptFromSource(s1); + messageTracker.acceptFromDestination(s1); + + assertTrue(messageTracker.getTotalRecordsCommitted().isEmpty()); + } + + @Test + public void testGetTotalRecordsCommitted_emptyWhenCommitStateHashThrowsException() throws Exception { + Mockito.doThrow(new StateDeltaTrackerException("induced exception")).when(mStateDeltaTracker).commitStateHash(Mockito.anyInt()); + + final AirbyteMessage r1 = createRecordMessage(STREAM_1, 1); + final AirbyteMessage s1 = createStateMessage(1); + + messageTracker.acceptFromSource(r1); + messageTracker.acceptFromSource(s1); + messageTracker.acceptFromDestination(s1); + + assertTrue(messageTracker.getTotalRecordsCommitted().isEmpty()); + } + + private AirbyteMessage createRecordMessage(final String streamName, final int recordData) { + return new AirbyteMessage() + .withType(AirbyteMessage.Type.RECORD) + .withRecord(new AirbyteRecordMessage().withStream(streamName).withData(Jsons.jsonNode(recordData))); + } + + private AirbyteMessage createStateMessage(final int stateData) { + return new AirbyteMessage() + .withType(AirbyteMessage.Type.STATE) + .withState(new AirbyteStateMessage().withData(Jsons.jsonNode(stateData))); } } diff --git a/airbyte-workers/src/test/java/io/airbyte/workers/protocols/airbyte/StateDeltaTrackerTest.java b/airbyte-workers/src/test/java/io/airbyte/workers/protocols/airbyte/StateDeltaTrackerTest.java new file mode 100644 index 000000000000..f7a50d038bc0 --- /dev/null +++ b/airbyte-workers/src/test/java/io/airbyte/workers/protocols/airbyte/StateDeltaTrackerTest.java @@ -0,0 +1,121 @@ +/* + * Copyright (c) 2021 Airbyte, Inc., all rights reserved. 
+ */ + +package io.airbyte.workers.protocols.airbyte; + +import io.airbyte.workers.protocols.airbyte.StateDeltaTracker.StateDeltaTrackerException; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +public class StateDeltaTrackerTest { + + private static final int STATE_1_HASH = 1; + private static final int STATE_2_HASH = 2; + private static final int STATE_3_HASH = Integer.MAX_VALUE; + private static final int NEVER_ADDED_STATE_HASH = 20; + + private static final short STREAM_INDEX_1 = (short) 111; + private static final short STREAM_INDEX_2 = (short) 222; + private static final short STREAM_INDEX_3 = (short) 333; + private static final short STREAM_INDEX_4 = Short.MAX_VALUE; + + private static final long STATE_1_STREAM_1_COUNT = 11L; + private static final long STATE_1_STREAM_2_COUNT = 12L; + + private static final long STATE_2_STREAM_1_COUNT = 21L; + private static final long STATE_2_STREAM_3_COUNT = 23L; + + private static final long STATE_3_STREAM_3_COUNT = 33L; + private static final long STATE_3_STREAM_4_COUNT = 34L; + + // enough capacity for above 3 states, which are each 24 bytes (4 byte int hash + two 10 byte + // stream counts, each a 2 byte short index plus an 8 byte long count) + private static final long INITIAL_DELTA_MEMORY_CAPACITY = 72L; + + private StateDeltaTracker stateDeltaTracker; + + @BeforeEach + public void setup() throws Exception { + final Map state1Counts = new HashMap<>(); + state1Counts.put(STREAM_INDEX_1, STATE_1_STREAM_1_COUNT); + state1Counts.put(STREAM_INDEX_2, STATE_1_STREAM_2_COUNT); + + final Map state2Counts = new HashMap<>(); + state2Counts.put(STREAM_INDEX_1, STATE_2_STREAM_1_COUNT); + state2Counts.put(STREAM_INDEX_3, STATE_2_STREAM_3_COUNT); + + final Map state3Counts = new HashMap<>(); + state3Counts.put(STREAM_INDEX_3, STATE_3_STREAM_3_COUNT); + state3Counts.put(STREAM_INDEX_4, STATE_3_STREAM_4_COUNT); + + stateDeltaTracker = 
new StateDeltaTracker(INITIAL_DELTA_MEMORY_CAPACITY); + stateDeltaTracker.addState(STATE_1_HASH, state1Counts); + stateDeltaTracker.addState(STATE_2_HASH, state2Counts); + stateDeltaTracker.addState(STATE_3_HASH, state3Counts); + } + + @Test + public void testAddState_throwsExceptionWhenCapacityExceeded() { + Assertions.assertThrows(StateDeltaTrackerException.class, () -> stateDeltaTracker.addState(4, Collections.singletonMap((short) 444, 44L))); + Assertions.assertTrue(stateDeltaTracker.capacityExceeded); + } + + @Test + public void testCommitStateHash_throwsExceptionWhenStateHashConflict() throws Exception { + stateDeltaTracker.commitStateHash(STATE_1_HASH); + stateDeltaTracker.commitStateHash(STATE_2_HASH); + + Assertions.assertThrows(StateDeltaTrackerException.class, () -> stateDeltaTracker.commitStateHash(STATE_1_HASH)); + } + + @Test + public void testCommitStateHash_throwsExceptionIfCapacityExceededEarlier() { + stateDeltaTracker.capacityExceeded = true; + Assertions.assertThrows(StateDeltaTrackerException.class, () -> stateDeltaTracker.commitStateHash(STATE_1_HASH)); + } + + @Test + public void testCommitStateHash_throwsExceptionIfCommitStateHashCalledBeforeAddingState() { + Assertions.assertThrows(StateDeltaTrackerException.class, () -> stateDeltaTracker.commitStateHash(NEVER_ADDED_STATE_HASH)); + } + + @Test + public void testGetCommittedRecordsByStream() throws Exception { + // before anything is committed, returned map should be empty and deltas should contain three states + final Map expected = new HashMap<>(); + Assertions.assertEquals(expected, stateDeltaTracker.getStreamToCommittedRecords()); + Assertions.assertEquals(3, stateDeltaTracker.stateDeltas.size()); + + stateDeltaTracker.commitStateHash(STATE_1_HASH); + expected.put(STREAM_INDEX_1, STATE_1_STREAM_1_COUNT); + expected.put(STREAM_INDEX_2, STATE_1_STREAM_2_COUNT); + Assertions.assertEquals(expected, stateDeltaTracker.getStreamToCommittedRecords()); + Assertions.assertEquals(2, 
stateDeltaTracker.stateDeltas.size()); + expected.clear(); + + stateDeltaTracker.commitStateHash(STATE_2_HASH); + expected.put(STREAM_INDEX_1, STATE_1_STREAM_1_COUNT + STATE_2_STREAM_1_COUNT); + expected.put(STREAM_INDEX_2, STATE_1_STREAM_2_COUNT); + expected.put(STREAM_INDEX_3, STATE_2_STREAM_3_COUNT); + Assertions.assertEquals(expected, stateDeltaTracker.getStreamToCommittedRecords()); + Assertions.assertEquals(1, stateDeltaTracker.stateDeltas.size()); + expected.clear(); + + stateDeltaTracker.commitStateHash(STATE_3_HASH); + expected.put(STREAM_INDEX_1, STATE_1_STREAM_1_COUNT + STATE_2_STREAM_1_COUNT); + expected.put(STREAM_INDEX_2, STATE_1_STREAM_2_COUNT); + expected.put(STREAM_INDEX_3, STATE_2_STREAM_3_COUNT + STATE_3_STREAM_3_COUNT); + expected.put(STREAM_INDEX_4, STATE_3_STREAM_4_COUNT); + Assertions.assertEquals(expected, stateDeltaTracker.getStreamToCommittedRecords()); + + // since all states are committed, capacity should be freed and the delta queue should be empty + Assertions.assertEquals(INITIAL_DELTA_MEMORY_CAPACITY, stateDeltaTracker.remainingCapacity); + Assertions.assertEquals(0, stateDeltaTracker.stateDeltas.size()); + } + +}