Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use a weakref finalizer to stop notifier thread #143

Open
wants to merge 3 commits into
base: branch-0.36
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 13 additions & 10 deletions python/distributed-ucxx/distributed_ucxx/ucxx.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,24 +566,27 @@ def address(self):
return f"{self.prefix}{self.ip}:{self.port}"

async def start(self):
async def serve_forever(client_ep):
ucx = self.comm_class(
async def serve_forever(client_ep, *, selfref):
ucx = selfref().comm_class(
client_ep,
local_addr=self.address,
peer_addr=self.address,
deserialize=self.deserialize,
local_addr=selfref().address,
peer_addr=selfref().address,
deserialize=selfref().deserialize,
)
ucx.allow_offload = self.allow_offload
ucx.allow_offload = selfref().allow_offload
try:
await self.on_connection(ucx)
await selfref().on_connection(ucx)
except CommClosedError:
logger.debug("Connection closed before handshake completed")
return
if self.comm_handler:
await self.comm_handler(ucx)
if selfref().comm_handler:
await selfref().comm_handler(ucx)

init_once()
self.ucxx_server = ucxx.create_listener(serve_forever, port=self._input_port)
self.ucxx_server = ucxx.create_listener(
functools.partial(serve_forever, selfref=weakref.ref(self)),
port=self._input_port,
)

def stop(self):
self.ucxx_server = None
Expand Down
46 changes: 28 additions & 18 deletions python/ucxx/_lib_async/application_context.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: BSD-3-Clause

import functools
import logging
import os
import threading
Expand Down Expand Up @@ -144,37 +145,46 @@ def start_notifier_thread(self):
name="UCX-Py Async Notifier Thread",
)
self.notifier_thread.start()
weakref.finalize(
self,
functools.partial(
self.stop_notifier_thread,
self.notifier_thread_q,
self.notifier_thread,
),
)
else:
logger.debug(
"UCXX not compiled with UCXX_ENABLE_PYTHON, disabling notifier thread"
)

def stop_notifier_thread(self):
# Must be a staticmethod so that it can be used in a weakref
# finalizer on the application context
@staticmethod
def stop_notifier_thread(queue, thread):
"""
Stop Python future notifier thread

Stop the notifier thread if context is running with Python future
notification enabled via `UCXPY_ENABLE_PYTHON_FUTURE=1` or
`ucxx.init(..., enable_python_future=True)`.

.. warning:: When the notifier thread is enabled it may be necessary to
explicitly call this method before shutting down the process or
application, otherwise it may block indefinitely waiting for
the thread to terminate. Executing `ucxx.reset()` will also run
this method, so it's not necessary to have both.
The application context arranges to call this function
automatically in a weakref finalizer when it goes out of
scope. If using the global application context, `ucxx.reset()`
will drop the reference and cause the notifier thread to be
stopped. For a user-maintained context, one must just ensure
that the reference is dropped.
"""
if self.notifier_thread_q and self.notifier_thread:
self.notifier_thread_q.put("shutdown")
while True:
# Having a timeout is required. During the notifier thread shutdown
# it may require the GIL, which will cause a deadlock with the `join()`
# call otherwise.
self.notifier_thread.join(timeout=0.01)
if not self.notifier_thread.is_alive():
break
logger.debug("Notifier thread stopped")
else:
logger.debug("Notifier thread not running")
queue.put("shutdown")
while True:
# Having a timeout is required. During the notifier thread shutdown
# it may require the GIL, which will cause a deadlock with the `join()`
# call otherwise.
thread.join(timeout=0.01)
if not thread.is_alive():
break
logger.debug("Notifier thread stopped")

def create_listener(
self,
Expand Down
11 changes: 7 additions & 4 deletions python/ucxx/_lib_async/tests/test_custom_send_recv.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
# SPDX-License-Identifier: BSD-3-Clause

import asyncio
import functools
import pickle
import weakref

import numpy as np
import pytest
Expand Down Expand Up @@ -98,11 +100,12 @@ def __init__(self):
self.comm = None

def start(self):
async def serve_forever(ep):
ucx = UCX(ep)
self.comm = ucx
async def serve_forever(ep, *, selfref):
selfref().comm = UCX(ep)

self.ucxx_server = ucxx.create_listener(serve_forever)
self.ucxx_server = ucxx.create_listener(
functools.partial(serve_forever, selfref=weakref.ref(self))
)

uu = UCXListener()
uu.start()
Expand Down
7 changes: 3 additions & 4 deletions python/ucxx/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,6 @@ def reset():

The library is initiated at next API call.
"""
stop_notifier_thread()
global _ctx
if _ctx is not None:
weakref_ctx = weakref.ref(_ctx)
Expand All @@ -104,7 +103,7 @@ def reset():
msg = (
"Trying to reset UCX but not all Endpoints and/or Listeners "
"are closed(). The following objects are still referencing "
"ApplicationContext: "
f"the global ApplicationContext {weakref_ctx()}: "
)
for o in gc.get_referrers(weakref_ctx()):
msg += "\n %s" % str(o)
Expand All @@ -113,8 +112,8 @@ def reset():

def stop_notifier_thread():
global _ctx
if _ctx:
_ctx.stop_notifier_thread()
if _ctx and _ctx.notifier_thread is not None:
_ctx.stop_notifier_thread(_ctx.notifier_thread_q, _ctx.notifier_thread)
else:
logger.debug("UCX is not initialized.")

Expand Down
Loading