diff --git a/bellows/ash.py b/bellows/ash.py index 66c0e516..1349e36d 100644 --- a/bellows/ash.py +++ b/bellows/ash.py @@ -4,6 +4,7 @@ import asyncio import binascii from collections.abc import Coroutine +import contextlib import dataclasses import enum import logging @@ -62,7 +63,7 @@ class Reserved(enum.IntEnum): # Maximum number of consecutive timeouts allowed while waiting to receive an ACK before # going to the FAILED state. The value 0 prevents the NCP from entering the error state # due to timeouts. -ACK_TIMEOUTS = 4 +ACK_TIMEOUTS = 5 def generate_random_sequence(length: int) -> bytes: @@ -368,14 +369,26 @@ def connection_made(self, transport): self._ezsp_protocol.connection_made(self) def connection_lost(self, exc): + self._transport = None + self._cancel_pending_data_frames() self._ezsp_protocol.connection_lost(exc) def eof_received(self): self._ezsp_protocol.eof_received() + def _cancel_pending_data_frames( + self, exc: BaseException = RuntimeError("Connection has been closed") + ): + for fut in self._pending_data_frames.values(): + if not fut.done(): + fut.set_exception(exc) + def close(self): + self._cancel_pending_data_frames() + if self._transport is not None: self._transport.close() + self._transport = None @staticmethod def _stuff_bytes(data: bytes) -> bytes: @@ -399,7 +412,9 @@ def _unstuff_bytes(data: bytes) -> bytes: for c in data: if escaped: byte = c ^ 0b00100000 - assert byte in RESERVED_BYTES + if byte not in RESERVED_BYTES: + raise ParsingError(f"Invalid escaped byte: 0x{byte:02X}") + out.append(byte) escaped = False elif c == Reserved.ESCAPE: @@ -417,7 +432,7 @@ def data_received(self, data: bytes) -> None: _LOGGER.debug( "Truncating buffer to %s bytes, it is growing too fast", MAX_BUFFER_SIZE ) - self._buffer = self._buffer[:MAX_BUFFER_SIZE] + self._buffer = self._buffer[-MAX_BUFFER_SIZE:] while self._buffer: if self._discarding_until_next_flag: @@ -447,14 +462,19 @@ def data_received(self, data: bytes) -> None: if not frame_bytes: continue - data = self._unstuff_bytes(frame_bytes) - try: + data = self._unstuff_bytes(frame_bytes) frame = parse_frame(data) except Exception: _LOGGER.debug( "Failed to parse frame %r", frame_bytes, exc_info=True ) + + with contextlib.suppress(NcpFailure): + self._write_frame( + NakFrame(res=0, ncp_ready=0, ack_num=self._rx_seq), + prefix=(Reserved.CANCEL,), + ) else: self.frame_received(frame) elif reserved_byte == Reserved.CANCEL: @@ -479,7 +499,7 @@ def data_received(self, data: bytes) -> None: f"Unexpected reserved byte found: 0x{reserved_byte:02X}" ) # pragma: no cover - def _handle_ack(self, frame: DataFrame | AckFrame) -> None: + def _handle_ack(self, frame: DataFrame | AckFrame | NakFrame) -> None: # Note that ackNum is the number of the next frame the receiver expects and it # is one greater than the last frame received. for ack_num_offset in range(-TX_K, 0): @@ -494,14 +514,19 @@ def _handle_ack(self, frame: DataFrame | AckFrame) -> None: def frame_received(self, frame: AshFrame) -> None: _LOGGER.debug("Received frame %r", frame) + # If a frame has ACK information (DATA, ACK, or NAK), it should be used even if + # the frame is out of sequence or invalid if isinstance(frame, DataFrame): + self._handle_ack(frame) self.data_frame_received(frame) - elif isinstance(frame, RStackFrame): - self.rstack_frame_received(frame) elif isinstance(frame, AckFrame): + self._handle_ack(frame) self.ack_frame_received(frame) elif isinstance(frame, NakFrame): + self._handle_ack(frame) self.nak_frame_received(frame) + elif isinstance(frame, RStackFrame): + self.rstack_frame_received(frame) elif isinstance(frame, RstFrame): self.rst_frame_received(frame) elif isinstance(frame, ErrorFrame): @@ -513,7 +538,6 @@ def data_frame_received(self, frame: DataFrame) -> None: # The Host may not piggyback acknowledgments and should promptly send an ACK # frame when it receives a DATA frame. if frame.frm_num == self._rx_seq: - self._handle_ack(frame) self._rx_seq = (frame.frm_num + 1) % 8 self._write_frame(AckFrame(res=0, ncp_ready=0, ack_num=self._rx_seq)) @@ -536,14 +560,10 @@ def rstack_frame_received(self, frame: RStackFrame) -> None: self._ezsp_protocol.reset_received(frame.reset_code) def ack_frame_received(self, frame: AckFrame) -> None: - self._handle_ack(frame) + pass def nak_frame_received(self, frame: NakFrame) -> None: - err = NotAcked(frame=frame) - - for fut in self._pending_data_frames.values(): - if not fut.done(): - fut.set_exception(err) + self._cancel_pending_data_frames(NotAcked(frame=frame)) def rst_frame_received(self, frame: RstFrame) -> None: self._ncp_reset_code = None @@ -558,12 +578,8 @@ def error_frame_received(self, frame: ErrorFrame) -> None: self._enter_failed_state(self._ncp_reset_code) def _enter_failed_state(self, reset_code: t.NcpResetCode) -> None: - exc = NcpFailure(code=reset_code) - - for fut in self._pending_data_frames.values(): - if not fut.done(): - fut.set_exception(exc) - + self._ncp_state = NcpState.FAILED + self._cancel_pending_data_frames(NcpFailure(code=reset_code)) self._ezsp_protocol.reset_received(reset_code) def _write_frame( @@ -573,6 +589,9 @@ def _write_frame( prefix: tuple[Reserved] = (), suffix: tuple[Reserved] = (Reserved.FLAG,), ) -> None: + if self._transport is None or self._transport.is_closing(): + raise NcpFailure("Transport is closed, cannot send frame") + if _LOGGER.isEnabledFor(logging.DEBUG): prefix_str = "".join([f"{r.name} + " for r in prefix]) suffix_str = "".join([f" + {r.name}" for r in suffix]) @@ -631,7 +650,9 @@ async def _send_data_frame(self, frame: AshFrame) -> None: await ack_future except NotAcked: _LOGGER.debug( - "NCP responded with NAK. Retrying (attempt %d)", attempt + 1 + "NCP responded with NAK to %r. Retrying (attempt %d)", + frame, + attempt + 1, ) # For timing purposes, NAK can be treated as an ACK @@ -650,9 +671,10 @@ async def _send_data_frame(self, frame: AshFrame) -> None: raise except asyncio.TimeoutError: _LOGGER.debug( - "No ACK received in %0.2fs (attempt %d)", + "No ACK received in %0.2fs (attempt %d) for %r", self._t_rx_ack, attempt + 1, + frame, ) # If a DATA frame acknowledgement is not received within the # current timeout value, then t_rx_ack is doubled. diff --git a/bellows/ezsp/__init__.py b/bellows/ezsp/__init__.py index 273add3b..1ff511d3 100644 --- a/bellows/ezsp/__init__.py +++ b/bellows/ezsp/__init__.py @@ -12,8 +12,6 @@ from typing import Any, Callable, Generator import urllib.parse -from zigpy.datastructures import PriorityDynamicBoundedSemaphore - if sys.version_info[:2] < (3, 11): from async_timeout import timeout as asyncio_timeout # pragma: no cover else: @@ -41,8 +39,6 @@ NETWORK_OPS_TIMEOUT = 10 NETWORK_COORDINATOR_STARTUP_RESET_WAIT = 1 -MAX_COMMAND_CONCURRENCY = 1 - class EZSP: _BY_VERSION = { @@ -66,7 +62,6 @@ def __init__(self, device_config: dict): self._ezsp_version = v4.EZSPv4.VERSION self._gw = None self._protocol = None - self._send_sem = PriorityDynamicBoundedSemaphore(value=MAX_COMMAND_CONCURRENCY) self._stack_status_listeners: collections.defaultdict[ t.sl_Status, list[asyncio.Future] @@ -190,21 +185,6 @@ def close(self): self._gw.close() self._gw = None - def _get_command_priority(self, name: str) -> int: - return { - # Deprioritize any commands that send packets - "set_source_route": -1, - "setExtendedTimeout": -1, - "send_unicast": -1, - "send_multicast": -1, - "send_broadcast": -1, - # Prioritize watchdog commands - "nop": 999, - "readCounters": 999, - "readAndClearCounters": 999, - "getValue": 999, - }.get(name, 0) - async def _command(self, name: str, *args: Any, **kwargs: Any) -> Any: command = getattr(self._protocol, name) @@ -217,8 +197,7 @@ async def _command(self, name: str, *args: Any, **kwargs: Any) -> Any: ) raise EzspError("EZSP is not running") - async with self._send_sem(priority=self._get_command_priority(name)): - return await command(*args, **kwargs) + return await command(*args, **kwargs) async def _list_command( self, name, item_frames, completion_frame, spos, *args, **kwargs diff --git a/bellows/ezsp/protocol.py b/bellows/ezsp/protocol.py index f9eca74e..7dedb8a8 100644 --- a/bellows/ezsp/protocol.py +++ b/bellows/ezsp/protocol.py @@ -6,6 +6,7 @@ import functools import logging import sys +import time from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Iterable import zigpy.state @@ -15,6 +16,8 @@ else: from asyncio import timeout as asyncio_timeout # pragma: no cover +from zigpy.datastructures import PriorityDynamicBoundedSemaphore + from bellows.config import CONF_EZSP_POLICIES from bellows.exception import InvalidCommandError import bellows.types as t @@ -23,7 +26,9 @@ from bellows.uart import Gateway LOGGER = logging.getLogger(__name__) -EZSP_CMD_TIMEOUT = 6 # Sum of all ASH retry timeouts: 0.4 + 0.8 + 1.6 + 3.2 + +EZSP_CMD_TIMEOUT = 10 +MAX_COMMAND_CONCURRENCY = 1 class ProtocolHandler(abc.ABC): @@ -42,6 +47,9 @@ def __init__(self, cb_handler: Callable, gateway: Gateway) -> None: for name, (cmd_id, tx_schema, rx_schema) in self.COMMANDS.items() } self.tc_policy = 0 + self._send_semaphore = PriorityDynamicBoundedSemaphore( + value=MAX_COMMAND_CONCURRENCY + ) # Cached by `set_extended_timeout` so subsequent calls are a little faster self._address_table_size: int | None = None @@ -65,18 +73,60 @@ def _ezsp_frame_rx(self, data: bytes) -> tuple[int, int, bytes]: def _ezsp_frame_tx(self, name: str) -> bytes: """Serialize the named frame.""" + def _get_command_priority(self, name: str) -> int: + return { + # Deprioritize any commands that send packets + "setSourceRoute": -1, + "setExtendedTimeout": -1, + "sendUnicast": -1, + "sendMulticast": -1, + "sendBroadcast": -1, + # Prioritize watchdog commands + "nop": 999, + "readCounters": 999, + "readAndClearCounters": 999, + "getValue": 999, + }.get(name, 0) + async def command(self, name, *args, **kwargs) -> Any: """Serialize command and send it.""" - LOGGER.debug("Sending command %s: %s %s", name, args, kwargs) - data = self._ezsp_frame(name, *args, **kwargs) - cmd_id, _, rx_schema = self.COMMANDS[name] - future = asyncio.get_running_loop().create_future() - self._awaiting[self._seq] = (cmd_id, rx_schema, future) - self._seq = (self._seq + 1) % 256 - - async with asyncio_timeout(EZSP_CMD_TIMEOUT): + delayed = False + send_time = None + + if self._send_semaphore.locked(): + delayed = True + send_time = time.monotonic() + + LOGGER.debug( + "Send semaphore is locked, delaying before sending %s(%r, %r)", + name, + args, + kwargs, + ) + + async with self._send_semaphore(priority=self._get_command_priority(name)): + if delayed: + LOGGER.debug( + "Sending command %s: %s %s after %0.2fs delay", + name, + args, + kwargs, + time.monotonic() - send_time, + ) + else: + LOGGER.debug("Sending command %s: %s %s", name, args, kwargs) + + data = self._ezsp_frame(name, *args, **kwargs) + cmd_id, _, rx_schema = self.COMMANDS[name] + + future = asyncio.get_running_loop().create_future() + self._awaiting[self._seq] = (cmd_id, rx_schema, future) + self._seq = (self._seq + 1) % 256 + await self._gw.send_data(data) - return await future + + async with asyncio_timeout(EZSP_CMD_TIMEOUT): + return await future async def update_policies(self, policy_config: dict) -> None: """Set up the policies for what the NCP should do.""" diff --git a/bellows/uart.py b/bellows/uart.py index ee2aea08..d48838d5 100644 --- a/bellows/uart.py +++ b/bellows/uart.py @@ -19,21 +19,6 @@ class Gateway(asyncio.Protocol): - FLAG = b"\x7E" # Marks end of frame - ESCAPE = b"\x7D" - XON = b"\x11" # Resume transmission - XOFF = b"\x13" # Stop transmission - SUBSTITUTE = b"\x18" - CANCEL = b"\x1A" # Terminates a frame in progress - STUFF = 0x20 - RANDOMIZE_START = 0x42 - RANDOMIZE_SEQ = 0xB8 - - RESERVED = FLAG + ESCAPE + XON + XOFF + SUBSTITUTE + CANCEL - - class Terminator: - pass - def __init__(self, application, connected_future=None, connection_done_future=None): self._application = application diff --git a/tests/test_ash.py b/tests/test_ash.py index 9b479c2d..cb7c356a 100644 --- a/tests/test_ash.py +++ b/tests/test_ash.py @@ -11,11 +11,6 @@ import bellows.types as t -@pytest.fixture(autouse=True, scope="function") -def random_seed(): - random.seed(0) - - class AshNcpProtocol(ash.AshProtocol): def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) @@ -86,14 +81,21 @@ def send_reset(self) -> None: class FakeTransport: - def __init__(self, receiver): + def __init__(self, receiver) -> None: self.receiver = receiver - self.paused = False + self.paused: bool = False + self.closing: bool = False def write(self, data: bytes) -> None: if not self.paused: self.receiver.data_received(data) + def close(self) -> None: + self.closing = True + + def is_closing(self) -> bool: + return self.closing + class FakeTransportOneByteAtATime(FakeTransport): def write(self, data: bytes) -> None: @@ -175,6 +177,10 @@ def test_stuffing(): assert ash.AshProtocol._stuff_bytes(b"\x7F") == b"\x7F" assert ash.AshProtocol._unstuff_bytes(b"\x7F") == b"\x7F" + with pytest.raises(ash.ParsingError): + # AB is not a sequence of bytes that can be unescaped + assert ash.AshProtocol._unstuff_bytes(b"\x7D\xAB") + def test_pseudo_random_data_sequence(): assert ash.PSEUDO_RANDOM_DATA_SEQUENCE.startswith(b"\x42\x21\xA8\x54\x2A") @@ -317,6 +323,7 @@ async def test_sequence(): loop = asyncio.get_running_loop() ezsp = MagicMock() transport = MagicMock() + transport.is_closing.return_value = False protocol = ash.AshProtocol(ezsp) protocol._write_frame = MagicMock(wraps=protocol._write_frame) @@ -408,6 +415,7 @@ async def test_ash_protocol_startup(caplog): ezsp = MagicMock() transport = MagicMock() + transport.is_closing.return_value = False protocol = ash.AshProtocol(ezsp) protocol._write_frame = MagicMock(wraps=protocol._write_frame) @@ -493,7 +501,7 @@ async def test_ash_protocol_startup(caplog): ], ) async def test_ash_end_to_end(transport_cls: type[FakeTransport]) -> None: - asyncio.get_running_loop() + random.seed(2) host_ezsp = MagicMock() ncp_ezsp = MagicMock() @@ -549,8 +557,9 @@ async def test_ash_end_to_end(transport_cls: type[FakeTransport]) -> None: send_task = asyncio.create_task(host.send_data(b"ncp NAKing")) await asyncio.sleep(host._t_rx_ack) - # It'll still succeed - await send_task + # The NCP is in a failed state, we can't send it + with pytest.raises(ash.NcpFailure): + await send_task ncp_ezsp.data_received.reset_mock() host_ezsp.data_received.reset_mock() diff --git a/tests/test_ezsp_v4.py b/tests/test_ezsp_v4.py index 1fec0187..fc31d251 100644 --- a/tests/test_ezsp_v4.py +++ b/tests/test_ezsp_v4.py @@ -1,5 +1,6 @@ +import asyncio import logging -from unittest.mock import MagicMock, call, patch +from unittest.mock import AsyncMock, MagicMock, call, patch import pytest import zigpy.state @@ -515,3 +516,27 @@ async def test_set_extended_timeout_bad_table_size(ezsp_f) -> None: assert ezsp_f.getConfigurationValue.mock_calls == [ call(t.EzspConfigId.CONFIG_ADDRESS_TABLE_SIZE) ] + + +async def test_send_concurrency(ezsp_f, caplog) -> None: + async def send_data(data: bytes) -> None: + await asyncio.sleep(0.1) + + rsp_data = bytearray(data) + rsp_data[1] |= 0x80 + + ezsp_f.__call__(rsp_data) + + ezsp_f._gw.send_data = AsyncMock(side_effect=send_data) + + with caplog.at_level(logging.DEBUG): + await asyncio.gather( + ezsp_f.command("nop"), + ezsp_f.command("nop"), + ezsp_f.command("nop"), + ezsp_f.command("nop"), + ) + + # All but the first queue up + assert caplog.text.count("Send semaphore is locked, delaying before sending") == 3 + assert caplog.text.count("s delay") == 3 diff --git a/tests/test_uart.py b/tests/test_uart.py index aadacdf8..fdd404fb 100644 --- a/tests/test_uart.py +++ b/tests/test_uart.py @@ -188,6 +188,7 @@ def test_eof_received(gw): async def test_connection_lost_reset_error_propagation(monkeypatch): app = MagicMock() transport = MagicMock() + transport.is_closing.return_value = False async def mockconnect(loop, protocol_factory, **kwargs): protocol = protocol_factory()