Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(ingest/partition-executor): Fix deadlock by recomputing ready items #11853

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions metadata-ingestion/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,6 +573,7 @@

test_api_requirements = {
"pytest>=6.2.2",
"pytest-timeout",
# Missing numpy requirement in 8.0.0
"deepdiff!=8.0.0",
"PyYAML",
Expand Down
25 changes: 16 additions & 9 deletions metadata-ingestion/src/datahub/utilities/partition_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,11 @@ def __init__(
process_batch: Callable[[List], None],
max_per_batch: int = 100,
min_process_interval: timedelta = _DEFAULT_BATCHER_MIN_PROCESS_INTERVAL,
# Why 3 seconds? It's somewhat arbitrary.
# We don't want it to be too high, since then liveness suffers,
# particularly during a dirty shutdown. If it's too low, then we'll
# waste CPU cycles rechecking the timer, only to call get again.
read_from_pending_interval: timedelta = timedelta(seconds=3),
) -> None:
"""Similar to PartitionExecutor, but with batching.

Expand All @@ -262,8 +267,10 @@ def __init__(
self.max_per_batch = max_per_batch
self.process_batch = process_batch
self.min_process_interval = min_process_interval
self.read_from_pending_interval = read_from_pending_interval
assert self.max_workers > 1

self.state_lock = threading.Lock()
self._executor = ThreadPoolExecutor(
# We add one here to account for the clearinghouse worker thread.
max_workers=max_workers + 1,
Expand Down Expand Up @@ -362,12 +369,8 @@ def _build_batch() -> List[_BatchPartitionWorkItem]:
if not blocking:
next_item = self._pending.get_nowait()
else:
# Why 3 seconds? It's somewhat arbitrary.
# We don't want it to be too high, since then liveness suffers,
# particularly during a dirty shutdown. If it's too low, then we'll
# waste CPU cycles rechecking the timer, only to call get again.
next_item = self._pending.get(
timeout=3, # seconds
timeout=self.read_from_pending_interval.total_seconds(),
)

if next_item is None: # None is the shutdown signal
Expand All @@ -379,6 +382,9 @@ def _build_batch() -> List[_BatchPartitionWorkItem]:
pending_key_completion.append(next_item)
else:
next_batch.append(next_item)

if not next_batch:
next_batch = _find_ready_items()
except queue.Empty:
if not blocking:
break
Expand Down Expand Up @@ -452,10 +458,11 @@ def _ensure_clearinghouse_started(self) -> None:
f"{self.__class__.__name__} is shutting down; cannot submit new work items."
)

# Lazily start the clearinghouse worker.
if not self._clearinghouse_started:
self._clearinghouse_started = True
self._executor.submit(self._clearinghouse_worker)
with self.state_lock:
# Lazily start the clearinghouse worker.
if not self._clearinghouse_started:
self._clearinghouse_started = True
self._executor.submit(self._clearinghouse_worker)

def submit(
self,
Expand Down
17 changes: 14 additions & 3 deletions metadata-ingestion/tests/unit/utilities/test_partition_executor.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import logging
import math
import time
from concurrent.futures import Future

import pytest
from datetime import timedelta

from datahub.utilities.partition_executor import (
BatchPartitionExecutor,
PartitionExecutor,
Expand Down Expand Up @@ -129,23 +133,30 @@ def process_batch(batch):
}


@pytest.mark.timeout(10)  # Fail fast instead of hanging if the deadlock regresses.
def test_batch_partition_executor_max_batch_size():
    """Verify BatchPartitionExecutor honors max_per_batch without deadlocking.

    Submits more items than max_pending allows (the scenario that previously
    deadlocked) and checks that every item is processed in batches of at most
    max_per_batch.
    """
    n = 20  # Exceed max_pending to test for deadlocks when max_pending exceeded
    batches_processed = []

    def process_batch(batch):
        batches_processed.append(batch)
        time.sleep(0.1)  # Simulate batch processing time

    with BatchPartitionExecutor(
        max_workers=5,
        max_pending=10,
        process_batch=process_batch,
        max_per_batch=2,
        min_process_interval=timedelta(seconds=1),
        read_from_pending_interval=timedelta(seconds=1),
    ) as executor:
        # Submit more tasks than the max_per_batch to test batching limits.
        # All items share one partition key, so batching is fully exercised.
        for i in range(n):
            executor.submit("key3", "key3", f"task{i}")

    # Check the batches: every item accounted for, none over the batch cap.
    logger.info(f"batches_processed: {batches_processed}")
    assert len(batches_processed) == math.ceil(n / 2), "Incorrect number of batches"
    for batch in batches_processed:
        assert len(batch) <= 2, "Batch size exceeded max_per_batch limit"

Expand Down
Loading