Skip to content

Commit

Permalink
PYTHON-4782 Fix deadlock and blocking behavior in _ACondition.wait (#…
Browse files Browse the repository at this point in the history
…1875)

(cherry picked from commit 821811e)
  • Loading branch information
ShaneHarvey authored and blink1073 committed Oct 1, 2024
1 parent d712bc1 commit 6a7fae1
Show file tree
Hide file tree
Showing 14 changed files with 693 additions and 48 deletions.
7 changes: 4 additions & 3 deletions pymongo/asynchronous/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -992,7 +992,8 @@ def __init__(
# from the right side.
self.conns: collections.deque = collections.deque()
self.active_contexts: set[_CancellationContext] = set()
self.lock = _ALock(_create_lock())
_lock = _create_lock()
self.lock = _ALock(_lock)
self.active_sockets = 0
# Monotonically increasing connection ID required for CMAP Events.
self.next_connection_id = 1
Expand All @@ -1018,15 +1019,15 @@ def __init__(
# The first portion of the wait queue.
# Enforces: maxPoolSize
# Also used for: clearing the wait queue
self.size_cond = _ACondition(threading.Condition(self.lock)) # type: ignore[arg-type]
self.size_cond = _ACondition(threading.Condition(_lock))
self.requests = 0
self.max_pool_size = self.opts.max_pool_size
if not self.max_pool_size:
self.max_pool_size = float("inf")
# The second portion of the wait queue.
# Enforces: maxConnecting
# Also used for: clearing the wait queue
self._max_connecting_cond = _ACondition(threading.Condition(self.lock)) # type: ignore[arg-type]
self._max_connecting_cond = _ACondition(threading.Condition(_lock))
self._max_connecting = self.opts.max_connecting
self._pending = 0
self._client_id = client_id
Expand Down
5 changes: 3 additions & 2 deletions pymongo/asynchronous/topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,9 @@ def __init__(self, topology_settings: TopologySettings):
self._seed_addresses = list(topology_description.server_descriptions())
self._opened = False
self._closed = False
self._lock = _ALock(_create_lock())
self._condition = _ACondition(self._settings.condition_class(self._lock)) # type: ignore[arg-type]
_lock = _create_lock()
self._lock = _ALock(_lock)
self._condition = _ACondition(self._settings.condition_class(_lock))
self._servers: dict[_Address, Server] = {}
self._pid: Optional[int] = None
self._max_cluster_time: Optional[ClusterTime] = None
Expand Down
147 changes: 126 additions & 21 deletions pymongo/lock.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,20 @@
from __future__ import annotations

import asyncio
import collections
import os
import threading
import time
import weakref
from typing import Any, Callable, Optional
from typing import Any, Callable, Optional, TypeVar

_HAS_REGISTER_AT_FORK = hasattr(os, "register_at_fork")

# References to instances of _create_lock
_forkable_locks: weakref.WeakSet[threading.Lock] = weakref.WeakSet()

_T = TypeVar("_T")


def _create_lock() -> threading.Lock:
"""Represents a lock that is tracked upon instantiation using a WeakSet and
Expand All @@ -43,7 +46,14 @@ def _release_locks() -> None:
lock.release()


# Needed only for synchro.py compat.
def _Lock(lock: threading.Lock) -> threading.Lock:
return lock


class _ALock:
__slots__ = ("_lock",)

def __init__(self, lock: threading.Lock) -> None:
self._lock = lock

Expand Down Expand Up @@ -81,9 +91,18 @@ async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
self.release()


def _safe_set_result(fut: asyncio.Future) -> None:
# Ensure the future hasn't been cancelled before calling set_result.
if not fut.done():
fut.set_result(False)


class _ACondition:
__slots__ = ("_condition", "_waiters")

def __init__(self, condition: threading.Condition) -> None:
self._condition = condition
self._waiters: collections.deque = collections.deque()

async def acquire(self, blocking: bool = True, timeout: float = -1) -> bool:
if timeout > 0:
Expand All @@ -99,30 +118,116 @@ async def acquire(self, blocking: bool = True, timeout: float = -1) -> bool:
await asyncio.sleep(0)

async def wait(self, timeout: Optional[float] = None) -> bool:
if timeout is not None:
tstart = time.monotonic()
while True:
notified = self._condition.wait(0.001)
if notified:
return True
if timeout is not None and (time.monotonic() - tstart) > timeout:
return False

async def wait_for(self, predicate: Callable, timeout: Optional[float] = None) -> bool:
if timeout is not None:
tstart = time.monotonic()
while True:
notified = self._condition.wait_for(predicate, 0.001)
if notified:
return True
if timeout is not None and (time.monotonic() - tstart) > timeout:
return False
"""Wait until notified.
If the calling task has not acquired the lock when this
method is called, a RuntimeError is raised.
This method releases the underlying lock, and then blocks
until it is awakened by a notify() or notify_all() call for
the same condition variable in another task. Once
awakened, it re-acquires the lock and returns True.
This method may return spuriously,
which is why the caller should always
re-check the state and be prepared to wait() again.
"""
loop = asyncio.get_running_loop()
fut = loop.create_future()
self._waiters.append((loop, fut))
self.release()
try:
try:
try:
await asyncio.wait_for(fut, timeout)
return True
except asyncio.TimeoutError:
return False # Return false on timeout for sync pool compat.
finally:
# Must re-acquire lock even if wait is cancelled.
# We only catch CancelledError here, since we don't want any
# other (fatal) errors with the future to cause us to spin.
err = None
while True:
try:
await self.acquire()
break
except asyncio.exceptions.CancelledError as e:
err = e

self._waiters.remove((loop, fut))
if err is not None:
try:
raise err # Re-raise most recent exception instance.
finally:
err = None # Break reference cycles.
except BaseException:
# Any error raised out of here _may_ have occurred after this Task
# believed to have been successfully notified.
# Make sure to notify another Task instead. This may result
# in a "spurious wakeup", which is allowed as part of the
# Condition Variable protocol.
self.notify(1)
raise

async def wait_for(self, predicate: Callable[[], _T]) -> _T:
"""Wait until a predicate becomes true.
The predicate should be a callable whose result will be
interpreted as a boolean value. The method will repeatedly
wait() until it evaluates to true. The final predicate value is
the return value.
"""
result = predicate()
while not result:
await self.wait()
result = predicate()
return result

def notify(self, n: int = 1) -> None:
self._condition.notify(n)
"""By default, wake up one coroutine waiting on this condition, if any.
If the calling coroutine has not acquired the lock when this method
is called, a RuntimeError is raised.
This method wakes up at most n of the coroutines waiting for the
condition variable; it is a no-op if no coroutines are waiting.
Note: an awakened coroutine does not actually return from its
wait() call until it can reacquire the lock. Since notify() does
not release the lock, its caller should.
"""
idx = 0
to_remove = []
for loop, fut in self._waiters:
if idx >= n:
break

if fut.done():
continue

try:
loop.call_soon_threadsafe(_safe_set_result, fut)
except RuntimeError:
# Loop was closed, ignore.
to_remove.append((loop, fut))
continue

idx += 1

for waiter in to_remove:
self._waiters.remove(waiter)

def notify_all(self) -> None:
self._condition.notify_all()
"""Wake up all threads waiting on this condition. This method acts
like notify(), but wakes up all waiting threads instead of one. If the
calling thread has not acquired the lock when this method is called,
a RuntimeError is raised.
"""
self.notify(len(self._waiters))

def locked(self) -> bool:
"""Only needed for tests in test_locks."""
return self._condition._lock.locked() # type: ignore[attr-defined]

def release(self) -> None:
self._condition.release()
Expand Down
9 changes: 5 additions & 4 deletions pymongo/synchronous/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
_CertificateError,
)
from pymongo.hello import Hello, HelloCompat
from pymongo.lock import _create_lock
from pymongo.lock import _create_lock, _Lock
from pymongo.logger import (
_CONNECTION_LOGGER,
_ConnectionStatusMessage,
Expand Down Expand Up @@ -988,7 +988,8 @@ def __init__(
# from the right side.
self.conns: collections.deque = collections.deque()
self.active_contexts: set[_CancellationContext] = set()
self.lock = _create_lock()
_lock = _create_lock()
self.lock = _Lock(_lock)
self.active_sockets = 0
# Monotonically increasing connection ID required for CMAP Events.
self.next_connection_id = 1
Expand All @@ -1014,15 +1015,15 @@ def __init__(
# The first portion of the wait queue.
# Enforces: maxPoolSize
# Also used for: clearing the wait queue
self.size_cond = threading.Condition(self.lock) # type: ignore[arg-type]
self.size_cond = threading.Condition(_lock)
self.requests = 0
self.max_pool_size = self.opts.max_pool_size
if not self.max_pool_size:
self.max_pool_size = float("inf")
# The second portion of the wait queue.
# Enforces: maxConnecting
# Also used for: clearing the wait queue
self._max_connecting_cond = threading.Condition(self.lock) # type: ignore[arg-type]
self._max_connecting_cond = threading.Condition(_lock)
self._max_connecting = self.opts.max_connecting
self._pending = 0
self._client_id = client_id
Expand Down
7 changes: 4 additions & 3 deletions pymongo/synchronous/topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
WriteError,
)
from pymongo.hello import Hello
from pymongo.lock import _create_lock
from pymongo.lock import _create_lock, _Lock
from pymongo.logger import (
_SDAM_LOGGER,
_SERVER_SELECTION_LOGGER,
Expand Down Expand Up @@ -170,8 +170,9 @@ def __init__(self, topology_settings: TopologySettings):
self._seed_addresses = list(topology_description.server_descriptions())
self._opened = False
self._closed = False
self._lock = _create_lock()
self._condition = self._settings.condition_class(self._lock) # type: ignore[arg-type]
_lock = _create_lock()
self._lock = _Lock(_lock)
self._condition = self._settings.condition_class(_lock)
self._servers: dict[_Address, Server] = {}
self._pid: Optional[int] = None
self._max_cluster_time: Optional[ClusterTime] = None
Expand Down
4 changes: 3 additions & 1 deletion test/asynchronous/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2351,7 +2351,9 @@ async def test_reconnect(self):

# But it can reconnect.
c.revive_host("a:1")
await (await c._get_topology()).select_servers(writable_server_selector, _Op.TEST)
await (await c._get_topology()).select_servers(
writable_server_selector, _Op.TEST, server_selection_timeout=10
)
self.assertEqual(await c.address, ("a", 1))

async def _test_network_error(self, operation_callback):
Expand Down
5 changes: 4 additions & 1 deletion test/asynchronous/test_client_bulk_write.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
)
from unittest.mock import patch

import pymongo
from pymongo.asynchronous.client_bulk import _AsyncClientBulk
from pymongo.encryption_options import _HAVE_PYMONGOCRYPT, AutoEncryptionOpts
from pymongo.errors import (
Expand Down Expand Up @@ -597,7 +598,9 @@ async def test_timeout_in_multi_batch_bulk_write(self):
timeoutMS=2000,
w="majority",
)
await client.admin.command("ping") # Init the client first.
# Initialize the client with a larger timeout to help make test less flakey
with pymongo.timeout(10):
await client.admin.command("ping")
with self.assertRaises(ClientBulkWriteException) as context:
await client.bulk_write(models=models)
self.assertIsInstance(context.exception.error, NetworkTimeout)
Expand Down
4 changes: 2 additions & 2 deletions test/asynchronous/test_cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1414,7 +1414,7 @@ async def test_to_list_length(self):
async def test_to_list_csot_applied(self):
client = await self.async_single_client(timeoutMS=500)
# Initialize the client with a larger timeout to help make test less flakey
with pymongo.timeout(2):
with pymongo.timeout(10):
await client.admin.command("ping")
coll = client.pymongo.test
await coll.insert_many([{} for _ in range(5)])
Expand Down Expand Up @@ -1456,7 +1456,7 @@ async def test_command_cursor_to_list_length(self):
async def test_command_cursor_to_list_csot_applied(self):
client = await self.async_single_client(timeoutMS=500)
# Initialize the client with a larger timeout to help make test less flakey
with pymongo.timeout(2):
with pymongo.timeout(10):
await client.admin.command("ping")
coll = client.pymongo.test
await coll.insert_many([{} for _ in range(5)])
Expand Down
Loading

0 comments on commit 6a7fae1

Please sign in to comment.