Skip to content

Commit

Permalink
Use asyncio for synchronization
Browse files Browse the repository at this point in the history
  • Loading branch information
MarkusSintonen committed Jun 15, 2024
1 parent e987df2 commit eb72670
Show file tree
Hide file tree
Showing 11 changed files with 154 additions and 163 deletions.
15 changes: 10 additions & 5 deletions httpcore/_async/connection_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from .._backends.base import SOCKET_OPTION, AsyncNetworkBackend
from .._exceptions import ConnectionNotAvailable, UnsupportedProtocol
from .._models import Origin, Request, Response
from .._synchronization import AsyncEvent, AsyncShieldCancellation, AsyncThreadLock
from .._synchronization import AsyncEvent, AsyncThreadLock, async_cancel_shield
from .connection import AsyncHTTPConnection
from .interfaces import AsyncConnectionInterface, AsyncRequestInterface

Expand Down Expand Up @@ -299,11 +299,16 @@ def _assign_requests_to_connections(self) -> List[AsyncConnectionInterface]:
return closing_connections

async def _close_connections(self, closing: List[AsyncConnectionInterface]) -> None:
if not closing:
return

# Close connections which have been removed from the pool.
with AsyncShieldCancellation():
async def close() -> None:
for connection in closing:
await connection.aclose()

await async_cancel_shield(close)

async def aclose(self) -> None:
# Explicitly close the connection pool.
# Clears all existing requests and connections.
Expand Down Expand Up @@ -369,9 +374,9 @@ async def __aiter__(self) -> AsyncIterator[bytes]:
async def aclose(self) -> None:
if not self._closed:
self._closed = True
with AsyncShieldCancellation():
if hasattr(self._stream, "aclose"):
await self._stream.aclose()

if hasattr(self._stream, "aclose"):
await async_cancel_shield(self._stream.aclose)

with self._pool._optional_thread_lock:
self._pool._requests.remove(self._pool_request)
Expand Down
10 changes: 4 additions & 6 deletions httpcore/_async/http11.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
map_exceptions,
)
from .._models import Origin, Request, Response
from .._synchronization import AsyncLock, AsyncShieldCancellation
from .._synchronization import AsyncLock, async_cancel_shield
from .._trace import Trace
from .interfaces import AsyncConnectionInterface

Expand Down Expand Up @@ -137,9 +137,8 @@ async def handle_async_request(self, request: Request) -> Response:
},
)
except BaseException as exc:
with AsyncShieldCancellation():
async with Trace("response_closed", logger, request) as trace:
await self._response_closed()
async with Trace("response_closed", logger, request) as trace:
await async_cancel_shield(self._response_closed)
raise exc

# Sending the request...
Expand Down Expand Up @@ -344,8 +343,7 @@ async def __aiter__(self) -> AsyncIterator[bytes]:
# If we get an exception while streaming the response,
# we want to close the response (and possibly the connection)
# before raising that exception.
with AsyncShieldCancellation():
await self.aclose()
await async_cancel_shield(self.aclose)
raise exc

async def aclose(self) -> None:
Expand Down
19 changes: 9 additions & 10 deletions httpcore/_async/http2.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
RemoteProtocolError,
)
from .._models import Origin, Request, Response
from .._synchronization import AsyncLock, AsyncSemaphore, AsyncShieldCancellation
from .._synchronization import AsyncLock, AsyncSemaphore, async_cancel_shield
from .._trace import Trace
from .interfaces import AsyncConnectionInterface

Expand Down Expand Up @@ -108,8 +108,7 @@ async def handle_async_request(self, request: Request) -> Response:
async with Trace("send_connection_init", logger, request, kwargs):
await self._send_connection_init(**kwargs)
except BaseException as exc:
with AsyncShieldCancellation():
await self.aclose()
await async_cancel_shield(self.aclose)
raise exc

self._sent_connection_init = True
Expand Down Expand Up @@ -160,11 +159,12 @@ async def handle_async_request(self, request: Request) -> Response:
"stream_id": stream_id,
},
)
except BaseException as exc: # noqa: PIE786
with AsyncShieldCancellation():
kwargs = {"stream_id": stream_id}
async with Trace("response_closed", logger, request, kwargs):
await self._response_closed(stream_id=stream_id)
except BaseException as exc:
kwargs = {"stream_id": stream_id}
async with Trace("response_closed", logger, request, kwargs):
await async_cancel_shield(
lambda: self._response_closed(stream_id=stream_id)
)

if isinstance(exc, h2.exceptions.ProtocolError):
# One case where h2 can raise a protocol error is when a
Expand Down Expand Up @@ -577,8 +577,7 @@ async def __aiter__(self) -> typing.AsyncIterator[bytes]:
# If we get an exception while streaming the response,
# we want to close the response (and possibly the connection)
# before raising that exception.
with AsyncShieldCancellation():
await self.aclose()
await async_cancel_shield(self.aclose)
raise exc

async def aclose(self) -> None:
Expand Down
4 changes: 2 additions & 2 deletions httpcore/_backends/auto.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import typing
from typing import Optional

from .._synchronization import current_async_library
from .._synchronization import current_async_backend
from .base import SOCKET_OPTION, AsyncNetworkBackend, AsyncNetworkStream


class AutoBackend(AsyncNetworkBackend):
async def _init_backend(self) -> None:
if not (hasattr(self, "_backend")):
backend = current_async_library()
backend = current_async_backend()
if backend == "trio":
from .trio import TrioBackend

Expand Down
15 changes: 10 additions & 5 deletions httpcore/_sync/connection_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from .._backends.base import SOCKET_OPTION, NetworkBackend
from .._exceptions import ConnectionNotAvailable, UnsupportedProtocol
from .._models import Origin, Request, Response
from .._synchronization import Event, ShieldCancellation, ThreadLock
from .._synchronization import Event, ThreadLock, sync_cancel_shield
from .connection import HTTPConnection
from .interfaces import ConnectionInterface, RequestInterface

Expand Down Expand Up @@ -299,11 +299,16 @@ def _assign_requests_to_connections(self) -> List[ConnectionInterface]:
return closing_connections

def _close_connections(self, closing: List[ConnectionInterface]) -> None:
if not closing:
return

# Close connections which have been removed from the pool.
with ShieldCancellation():
def close() -> None:
for connection in closing:
connection.close()

sync_cancel_shield(close)

def close(self) -> None:
# Explicitly close the connection pool.
# Clears all existing requests and connections.
Expand Down Expand Up @@ -369,9 +374,9 @@ def __iter__(self) -> Iterator[bytes]:
def close(self) -> None:
if not self._closed:
self._closed = True
with ShieldCancellation():
if hasattr(self._stream, "close"):
self._stream.close()

if hasattr(self._stream, "close"):
sync_cancel_shield(self._stream.close)

with self._pool._optional_thread_lock:
self._pool._requests.remove(self._pool_request)
Expand Down
10 changes: 4 additions & 6 deletions httpcore/_sync/http11.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
map_exceptions,
)
from .._models import Origin, Request, Response
from .._synchronization import Lock, ShieldCancellation
from .._synchronization import Lock, sync_cancel_shield
from .._trace import Trace
from .interfaces import ConnectionInterface

Expand Down Expand Up @@ -137,9 +137,8 @@ def handle_request(self, request: Request) -> Response:
},
)
except BaseException as exc:
with ShieldCancellation():
with Trace("response_closed", logger, request) as trace:
self._response_closed()
with Trace("response_closed", logger, request) as trace:
sync_cancel_shield(self._response_closed)
raise exc

# Sending the request...
Expand Down Expand Up @@ -344,8 +343,7 @@ def __iter__(self) -> Iterator[bytes]:
# If we get an exception while streaming the response,
# we want to close the response (and possibly the connection)
# before raising that exception.
with ShieldCancellation():
self.close()
sync_cancel_shield(self.close)
raise exc

def close(self) -> None:
Expand Down
19 changes: 9 additions & 10 deletions httpcore/_sync/http2.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
RemoteProtocolError,
)
from .._models import Origin, Request, Response
from .._synchronization import Lock, Semaphore, ShieldCancellation
from .._synchronization import Lock, Semaphore, sync_cancel_shield
from .._trace import Trace
from .interfaces import ConnectionInterface

Expand Down Expand Up @@ -108,8 +108,7 @@ def handle_request(self, request: Request) -> Response:
with Trace("send_connection_init", logger, request, kwargs):
self._send_connection_init(**kwargs)
except BaseException as exc:
with ShieldCancellation():
self.close()
sync_cancel_shield(self.close)
raise exc

self._sent_connection_init = True
Expand Down Expand Up @@ -160,11 +159,12 @@ def handle_request(self, request: Request) -> Response:
"stream_id": stream_id,
},
)
except BaseException as exc: # noqa: PIE786
with ShieldCancellation():
kwargs = {"stream_id": stream_id}
with Trace("response_closed", logger, request, kwargs):
self._response_closed(stream_id=stream_id)
except BaseException as exc:
kwargs = {"stream_id": stream_id}
with Trace("response_closed", logger, request, kwargs):
sync_cancel_shield(
lambda: self._response_closed(stream_id=stream_id)
)

if isinstance(exc, h2.exceptions.ProtocolError):
# One case where h2 can raise a protocol error is when a
Expand Down Expand Up @@ -577,8 +577,7 @@ def __iter__(self) -> typing.Iterator[bytes]:
# If we get an exception while streaming the response,
# we want to close the response (and possibly the connection)
# before raising that exception.
with ShieldCancellation():
self.close()
sync_cancel_shield(self.close)
raise exc

def close(self) -> None:
Expand Down
Loading

0 comments on commit eb72670

Please sign in to comment.