fix(ingest/partition-executor): Fix deadlock by recomputing ready ite…
asikowitz authored Nov 14, 2024
1 parent 383a70a commit 5ff6295
Showing 3 changed files with 31 additions and 12 deletions.
1 change: 1 addition & 0 deletions metadata-ingestion/setup.py
@@ -573,6 +573,7 @@

 test_api_requirements = {
     "pytest>=6.2.2",
+    "pytest-timeout",
     # Missing numpy requirement in 8.0.0
     "deepdiff!=8.0.0",
     "PyYAML",
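The new pytest-timeout dependency supplies the @pytest.mark.timeout marker used by the regression test below, so a hung test fails after its deadline instead of blocking CI indefinitely. A minimal sketch of how the marker behaves (the test name here is illustrative):

import time

import pytest


@pytest.mark.timeout(10)  # fail, rather than hang, if the test exceeds 10s
def test_finishes_quickly():
    time.sleep(0.1)  # a deadlock here would now surface as a timeout failure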
25 changes: 16 additions & 9 deletions metadata-ingestion/src/datahub/utilities/partition_executor.py
@@ -237,6 +237,11 @@ def __init__(
         process_batch: Callable[[List], None],
         max_per_batch: int = 100,
         min_process_interval: timedelta = _DEFAULT_BATCHER_MIN_PROCESS_INTERVAL,
+        # Why 3 seconds? It's somewhat arbitrary.
+        # We don't want it to be too high, since then liveness suffers,
+        # particularly during a dirty shutdown. If it's too low, then we'll
+        # waste CPU cycles rechecking the timer, only to call get again.
+        read_from_pending_interval: timedelta = timedelta(seconds=3),
     ) -> None:
         """Similar to PartitionExecutor, but with batching.
@@ -262,8 +267,10 @@
         self.max_per_batch = max_per_batch
         self.process_batch = process_batch
         self.min_process_interval = min_process_interval
+        self.read_from_pending_interval = read_from_pending_interval
         assert self.max_workers > 1

+        self.state_lock = threading.Lock()
         self._executor = ThreadPoolExecutor(
             # We add one here to account for the clearinghouse worker thread.
             max_workers=max_workers + 1,
@@ -362,12 +369,8 @@ def _build_batch() -> List[_BatchPartitionWorkItem]:
                     if not blocking:
                         next_item = self._pending.get_nowait()
                     else:
-                        # Why 3 seconds? It's somewhat arbitrary.
-                        # We don't want it to be too high, since then liveness suffers,
-                        # particularly during a dirty shutdown. If it's too low, then we'll
-                        # waste CPU cycles rechecking the timer, only to call get again.
                         next_item = self._pending.get(
-                            timeout=3,  # seconds
+                            timeout=self.read_from_pending_interval.total_seconds(),
                         )

                     if next_item is None:  # None is the shutdown signal
@@ -379,6 +382,9 @@ def _build_batch() -> List[_BatchPartitionWorkItem]:
                         pending_key_completion.append(next_item)
                     else:
                         next_batch.append(next_item)
+
+                    if not next_batch:
+                        next_batch = _find_ready_items()
                 except queue.Empty:
                     if not blocking:
                         break
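For context, the stall this hunk targets: _build_batch defers any item whose partition key is already in flight, and previously, if every item pulled during a pass was deferred, the loop went straight back to a blocking get() without re-checking whether deferred items had become ready. A toy sketch of the recompute fix, with simplified stand-ins for the real internals (pending, deferred, in_flight_keys, and find_ready_items are illustrative names, not DataHub's actual API):

import queue
from typing import Any, List, Set, Tuple

Item = Tuple[str, Any]  # (partition key, work payload)


def build_batch(
    pending: "queue.Queue[Item]",
    deferred: List[Item],
    in_flight_keys: Set[str],
    max_per_batch: int = 2,
) -> List[Item]:
    def find_ready_items() -> List[Item]:
        # Deferred items whose keys are no longer in flight are ready to batch.
        ready = [it for it in deferred if it[0] not in in_flight_keys]
        ready = ready[:max_per_batch]
        for it in ready:
            deferred.remove(it)
        return ready

    batch: List[Item] = []
    while len(batch) < max_per_batch:
        try:
            item = pending.get(timeout=3)
        except queue.Empty:
            break
        if item[0] in in_flight_keys:
            deferred.append(item)  # key is busy; items must not run concurrently
        else:
            batch.append(item)
        if not batch:
            # The fix: before blocking on get() again, re-check deferred items.
            # Without this, a queue full of busy-key items could leave the loop
            # waiting forever while submitters block on the max_pending limit.
            batch = find_ready_items()
    return batch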
@@ -452,10 +458,11 @@ def _ensure_clearinghouse_started(self) -> None:
                 f"{self.__class__.__name__} is shutting down; cannot submit new work items."
             )

-        # Lazily start the clearinghouse worker.
-        if not self._clearinghouse_started:
-            self._clearinghouse_started = True
-            self._executor.submit(self._clearinghouse_worker)
+        with self.state_lock:
+            # Lazily start the clearinghouse worker.
+            if not self._clearinghouse_started:
+                self._clearinghouse_started = True
+                self._executor.submit(self._clearinghouse_worker)

     def submit(
         self,
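The state_lock addition closes a startup race: two threads calling submit() concurrently could both observe _clearinghouse_started == False and each spawn a clearinghouse worker. A stripped-down sketch of the guarded lazy start (class and method names here are illustrative):

import threading
from concurrent.futures import ThreadPoolExecutor


class LazyClearinghouse:
    def __init__(self) -> None:
        self.state_lock = threading.Lock()
        self._started = False
        self._executor = ThreadPoolExecutor(max_workers=2)

    def ensure_started(self) -> None:
        # Holding the lock makes the check-and-set atomic, so exactly one
        # clearinghouse worker is submitted no matter how many threads race.
        with self.state_lock:
            if not self._started:
                self._started = True
                self._executor.submit(self._worker)

    def _worker(self) -> None:
        pass  # drain the pending queue and dispatch batches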
17 changes: 14 additions & 3 deletions metadata-ingestion/tests/unit/utilities/test_partition_executor.py
@@ -1,7 +1,11 @@
 import logging
+import math
 import time
 from concurrent.futures import Future

+import pytest
+from pydantic.schema import timedelta
+
 from datahub.utilities.partition_executor import (
     BatchPartitionExecutor,
     PartitionExecutor,
@@ -129,23 +133,30 @@ def process_batch(batch):
     }


+@pytest.mark.timeout(10)
 def test_batch_partition_executor_max_batch_size():
+    n = 20  # Exceed max_pending to test for deadlocks when max_pending exceeded
     batches_processed = []

     def process_batch(batch):
         batches_processed.append(batch)
         time.sleep(0.1)  # Simulate batch processing time

     with BatchPartitionExecutor(
-        max_workers=5, max_pending=20, process_batch=process_batch, max_per_batch=2
+        max_workers=5,
+        max_pending=10,
+        process_batch=process_batch,
+        max_per_batch=2,
+        min_process_interval=timedelta(seconds=1),
+        read_from_pending_interval=timedelta(seconds=1),
     ) as executor:
         # Submit more tasks than the max_per_batch to test batching limits.
-        for i in range(5):
+        for i in range(n):
             executor.submit("key3", "key3", f"task{i}")

     # Check the batches.
     logger.info(f"batches_processed: {batches_processed}")
-    assert len(batches_processed) == 3
+    assert len(batches_processed) == math.ceil(n / 2), "Incorrect number of batches"
     for batch in batches_processed:
         assert len(batch) <= 2, "Batch size exceeded max_per_batch limit"
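For reference, how a caller might exercise the new knob. This sketch mirrors the call shape used in the test above (the keys and task payloads are made up); a shorter read_from_pending_interval trades a little CPU for faster liveness checks, per the constructor comment:

from datetime import timedelta

from datahub.utilities.partition_executor import BatchPartitionExecutor


def process_batch(batch):
    print(f"processing {len(batch)} items")


with BatchPartitionExecutor(
    max_workers=4,
    max_pending=100,
    process_batch=process_batch,
    max_per_batch=10,
    read_from_pending_interval=timedelta(seconds=1),
) as executor:
    for i in range(25):
        key = f"partition-{i % 5}"
        executor.submit(key, key, f"task-{i}")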
