Skip to content

Commit

Permalink
Add trio.open_channel
Browse files Browse the repository at this point in the history
  • Loading branch information
njsmith committed Jul 30, 2018
1 parent 3bf8564 commit 045cd60
Show file tree
Hide file tree
Showing 2 changed files with 205 additions and 0 deletions.
3 changes: 3 additions & 0 deletions trio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
from ._sync import *
__all__ += _sync.__all__

from ._channel import *
__all__ += _channel.__all__

from ._threads import *
__all__ += _threads.__all__

Expand Down
202 changes: 202 additions & 0 deletions trio/_channel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
from collections import deque, OrderedDict
from math import inf

import attr
from outcome import Error, Value

from . import _core
from ._util import aiter_compat

__all__ = ["open_channel", "EndOfChannel", "BrokenChannelError"]

# TODO:
# - introspection:
# - statistics
# - capacity, usage
# - repr
# - BrokenResourceError?
# - tests
# - docs


class EndOfChannel(Exception):
pass


class BrokenChannelError(Exception):
pass


def open_channel(capacity):
if capacity != inf and not isinstance(capacity, int):
raise TypeError("capacity must be an integer or math.inf")
if capacity < 0:
raise ValueError("capacity must be >= 0")
buf = ChannelBuf(capacity)
return PutChannel(buf), GetChannel(buf)


@attr.s(cmp=False, hash=False)
class ChannelBuf:
capacity = attr.ib()
data = attr.ib(default=attr.Factory(deque))
# counts
put_channels = attr.ib(default=0)
get_channels = attr.ib(default=0)
# {task: value}
put_tasks = attr.ib(default=attr.Factory(OrderedDict))
# {task: None}
get_tasks = attr.ib(default=attr.Factory(OrderedDict))


class PutChannel:
def __init__(self, buf):
self._buf = buf
self.closed = False
self._tasks = set()
self._buf.put_channels += 1

@_core.disable_ki_protection
def put_nowait(self, value):
if self.closed:
raise _core.ClosedResourceError
if not self._buf.get_channels:
raise BrokenChannelError
if self._buf.get_tasks:
assert not self._buf.data
task = next(iter(self._buf.get_tasks))
_core.reschedule(task, Value(value))
elif len(self._buf.data) < self._buf.capacity:
self._buf.data.append(value)
else:
raise _core.WouldBlock

@_core.disable_ki_protection
async def put(self, value):
await _core.checkpoint_if_cancelled()
try:
self.put_nowait(value)
except _core.WouldBlock:
pass
else:
await _core.cancel_shielded_checkpoint()
return

task = _core.current_task()
self._tasks.add(task)
self._buf.put_tasks[task] = value

def abort_fn(_):
self._tasks.remove(task)
del self._buf.put_tasks[task]
return _core.Abort.SUCCEEDED

await _core.wait_task_rescheduled(abort_fn, always_abort=True)

@_core.disable_ki_protection
def clone(self):
if self.closed:
raise _core.ClosedResourceError
return PutChannel(self._buf)

@_core.disable_ki_protection
def close(self):
if self.closed:
return
self.closed = True
for task in list(self._tasks):
_core.reschedule(task, Error(ClosedResourceError()))
self._buf.put_channels -= 1
if self._buf.put_channels == 0:
assert not self._buf.put_tasks
for task in list(self._buf.get_tasks):
_core.reschedule(task, Error(EndOfChannel()))

def __enter__(self):
return self

def __exit__(self, *args):
self.close()


class GetChannel:
def __init__(self, buf):
self._buf = buf
self.closed = False
self._tasks = set()
self._buf.get_channels += 1

@_core.disable_ki_protection
def get_nowait(self):
if self.closed:
raise _core.ClosedResourceError
buf = self._buf
if buf.put_tasks:
task, value = next(iter(buf.put_tasks.items()))
_core.reschedule(task)
return value
if buf.data:
return buf.data.popleft()
if not buf.put_channels:
raise EndOfChannel
raise _core.WouldBlock

@_core.disable_ki_protection
async def get(self):
await _core.checkpoint_if_cancelled()
try:
return self.get_nowait()
except _core.WouldBlock:
pass
else:
await _core.cancel_shielded_checkpoint()
return

task = _core.current_task()
self._tasks.add(task)
self._buf.get_tasks[task] = None

def abort_fn(_):
self._tasks.remove(task)
del self._buf.get_tasks[task]
return _core.Abort.SUCCEEDED

return await _core.wait_task_rescheduled(abort_fn, always_abort=True)

@_core.disable_ki_protection
def clone(self):
if self.closed:
raise _core.ClosedResourceError
return GetChannel(self._buf)

@_core.disable_ki_protection
def close(self):
if self.closed:
return
self.closed = True
for task in list(self._tasks):
_core.reschedule(task, Error(ClosedResourceError()))
self._buf.get_channels -= 1
if self._buf.get_channels == 0:
assert not self._buf.get_tasks
for task in list(self._buf.put_tasks):
_core.reschedule(task, Error(BrokenChannelError()))
# XX: or if we're losing data, maybe we should raise a
# BrokenChannelError here?
self._buf.data.clear()

@aiter_compat
def __aiter__(self):
return self

async def __anext__(self):
try:
return await self.get()
except EndOfChannel:
raise StopAsyncIteration

def __enter__(self):
return self

def __exit__(self, *args):
self.close()

0 comments on commit 045cd60

Please sign in to comment.