Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move EZSP send lock from EZSP to individual protocol handlers #649

Merged
merged 20 commits into from
Sep 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 45 additions & 23 deletions bellows/ash.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import asyncio
import binascii
from collections.abc import Coroutine
import contextlib
import dataclasses
import enum
import logging
Expand Down Expand Up @@ -62,7 +63,7 @@ class Reserved(enum.IntEnum):
# Maximum number of consecutive timeouts allowed while waiting to receive an ACK before
# going to the FAILED state. The value 0 prevents the NCP from entering the error state
# due to timeouts.
ACK_TIMEOUTS = 4
ACK_TIMEOUTS = 5


def generate_random_sequence(length: int) -> bytes:
Expand Down Expand Up @@ -368,14 +369,26 @@ def connection_made(self, transport):
self._ezsp_protocol.connection_made(self)

def connection_lost(self, exc):
self._transport = None
self._cancel_pending_data_frames()
self._ezsp_protocol.connection_lost(exc)

def eof_received(self):
self._ezsp_protocol.eof_received()

def _cancel_pending_data_frames(
self, exc: BaseException = RuntimeError("Connection has been closed")
):
for fut in self._pending_data_frames.values():
if not fut.done():
fut.set_exception(exc)

def close(self):
self._cancel_pending_data_frames()

if self._transport is not None:
self._transport.close()
self._transport = None

@staticmethod
def _stuff_bytes(data: bytes) -> bytes:
Expand All @@ -399,7 +412,9 @@ def _unstuff_bytes(data: bytes) -> bytes:
for c in data:
if escaped:
byte = c ^ 0b00100000
assert byte in RESERVED_BYTES
if byte not in RESERVED_BYTES:
raise ParsingError(f"Invalid escaped byte: 0x{byte:02X}")

out.append(byte)
escaped = False
elif c == Reserved.ESCAPE:
Expand All @@ -417,7 +432,7 @@ def data_received(self, data: bytes) -> None:
_LOGGER.debug(
"Truncating buffer to %s bytes, it is growing too fast", MAX_BUFFER_SIZE
)
self._buffer = self._buffer[:MAX_BUFFER_SIZE]
self._buffer = self._buffer[-MAX_BUFFER_SIZE:]

while self._buffer:
if self._discarding_until_next_flag:
Expand Down Expand Up @@ -447,14 +462,19 @@ def data_received(self, data: bytes) -> None:
if not frame_bytes:
continue

data = self._unstuff_bytes(frame_bytes)

try:
data = self._unstuff_bytes(frame_bytes)
frame = parse_frame(data)
except Exception:
_LOGGER.debug(
"Failed to parse frame %r", frame_bytes, exc_info=True
)

with contextlib.suppress(NcpFailure):
self._write_frame(
NakFrame(res=0, ncp_ready=0, ack_num=self._rx_seq),
prefix=(Reserved.CANCEL,),
)
else:
self.frame_received(frame)
elif reserved_byte == Reserved.CANCEL:
Expand All @@ -479,7 +499,7 @@ def data_received(self, data: bytes) -> None:
f"Unexpected reserved byte found: 0x{reserved_byte:02X}"
) # pragma: no cover

def _handle_ack(self, frame: DataFrame | AckFrame) -> None:
def _handle_ack(self, frame: DataFrame | AckFrame | NakFrame) -> None:
# Note that ackNum is the number of the next frame the receiver expects and it
# is one greater than the last frame received.
for ack_num_offset in range(-TX_K, 0):
Expand All @@ -494,14 +514,19 @@ def _handle_ack(self, frame: DataFrame | AckFrame) -> None:
def frame_received(self, frame: AshFrame) -> None:
_LOGGER.debug("Received frame %r", frame)

# If a frame has ACK information (DATA, ACK, or NAK), it should be used even if
# the frame is out of sequence or invalid
if isinstance(frame, DataFrame):
self._handle_ack(frame)
self.data_frame_received(frame)
elif isinstance(frame, RStackFrame):
self.rstack_frame_received(frame)
elif isinstance(frame, AckFrame):
self._handle_ack(frame)
self.ack_frame_received(frame)
elif isinstance(frame, NakFrame):
self._handle_ack(frame)
self.nak_frame_received(frame)
elif isinstance(frame, RStackFrame):
self.rstack_frame_received(frame)
elif isinstance(frame, RstFrame):
self.rst_frame_received(frame)
elif isinstance(frame, ErrorFrame):
Expand All @@ -513,7 +538,6 @@ def data_frame_received(self, frame: DataFrame) -> None:
# The Host may not piggyback acknowledgments and should promptly send an ACK
# frame when it receives a DATA frame.
if frame.frm_num == self._rx_seq:
self._handle_ack(frame)
self._rx_seq = (frame.frm_num + 1) % 8
self._write_frame(AckFrame(res=0, ncp_ready=0, ack_num=self._rx_seq))

Expand All @@ -536,14 +560,10 @@ def rstack_frame_received(self, frame: RStackFrame) -> None:
self._ezsp_protocol.reset_received(frame.reset_code)

def ack_frame_received(self, frame: AckFrame) -> None:
self._handle_ack(frame)
pass

def nak_frame_received(self, frame: NakFrame) -> None:
err = NotAcked(frame=frame)

for fut in self._pending_data_frames.values():
if not fut.done():
fut.set_exception(err)
self._cancel_pending_data_frames(NotAcked(frame=frame))

def rst_frame_received(self, frame: RstFrame) -> None:
self._ncp_reset_code = None
Expand All @@ -558,12 +578,8 @@ def error_frame_received(self, frame: ErrorFrame) -> None:
self._enter_failed_state(self._ncp_reset_code)

def _enter_failed_state(self, reset_code: t.NcpResetCode) -> None:
exc = NcpFailure(code=reset_code)

for fut in self._pending_data_frames.values():
if not fut.done():
fut.set_exception(exc)

self._ncp_state = NcpState.FAILED
self._cancel_pending_data_frames(NcpFailure(code=reset_code))
self._ezsp_protocol.reset_received(reset_code)

def _write_frame(
Expand All @@ -573,6 +589,9 @@ def _write_frame(
prefix: tuple[Reserved] = (),
suffix: tuple[Reserved] = (Reserved.FLAG,),
) -> None:
if self._transport is None or self._transport.is_closing():
raise NcpFailure("Transport is closed, cannot send frame")

if _LOGGER.isEnabledFor(logging.DEBUG):
prefix_str = "".join([f"{r.name} + " for r in prefix])
suffix_str = "".join([f" + {r.name}" for r in suffix])
Expand Down Expand Up @@ -631,7 +650,9 @@ async def _send_data_frame(self, frame: AshFrame) -> None:
await ack_future
except NotAcked:
_LOGGER.debug(
"NCP responded with NAK. Retrying (attempt %d)", attempt + 1
"NCP responded with NAK to %r. Retrying (attempt %d)",
frame,
attempt + 1,
)

# For timing purposes, NAK can be treated as an ACK
Expand All @@ -650,9 +671,10 @@ async def _send_data_frame(self, frame: AshFrame) -> None:
raise
except asyncio.TimeoutError:
_LOGGER.debug(
"No ACK received in %0.2fs (attempt %d)",
"No ACK received in %0.2fs (attempt %d) for %r",
self._t_rx_ack,
attempt + 1,
frame,
)
# If a DATA frame acknowledgement is not received within the
# current timeout value, then t_rx_ack is doubled.
Expand Down
23 changes: 1 addition & 22 deletions bellows/ezsp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
from typing import Any, Callable, Generator
import urllib.parse

from zigpy.datastructures import PriorityDynamicBoundedSemaphore

if sys.version_info[:2] < (3, 11):
from async_timeout import timeout as asyncio_timeout # pragma: no cover
else:
Expand Down Expand Up @@ -41,8 +39,6 @@
NETWORK_OPS_TIMEOUT = 10
NETWORK_COORDINATOR_STARTUP_RESET_WAIT = 1

MAX_COMMAND_CONCURRENCY = 1


class EZSP:
_BY_VERSION = {
Expand All @@ -66,7 +62,6 @@ def __init__(self, device_config: dict):
self._ezsp_version = v4.EZSPv4.VERSION
self._gw = None
self._protocol = None
self._send_sem = PriorityDynamicBoundedSemaphore(value=MAX_COMMAND_CONCURRENCY)

self._stack_status_listeners: collections.defaultdict[
t.sl_Status, list[asyncio.Future]
Expand Down Expand Up @@ -190,21 +185,6 @@ def close(self):
self._gw.close()
self._gw = None

def _get_command_priority(self, name: str) -> int:
return {
# Deprioritize any commands that send packets
"set_source_route": -1,
"setExtendedTimeout": -1,
"send_unicast": -1,
"send_multicast": -1,
"send_broadcast": -1,
# Prioritize watchdog commands
"nop": 999,
"readCounters": 999,
"readAndClearCounters": 999,
"getValue": 999,
}.get(name, 0)

async def _command(self, name: str, *args: Any, **kwargs: Any) -> Any:
command = getattr(self._protocol, name)

Expand All @@ -217,8 +197,7 @@ async def _command(self, name: str, *args: Any, **kwargs: Any) -> Any:
)
raise EzspError("EZSP is not running")

async with self._send_sem(priority=self._get_command_priority(name)):
return await command(*args, **kwargs)
return await command(*args, **kwargs)

async def _list_command(
self, name, item_frames, completion_frame, spos, *args, **kwargs
Expand Down
70 changes: 60 additions & 10 deletions bellows/ezsp/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import functools
import logging
import sys
import time
from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Iterable

import zigpy.state
Expand All @@ -15,6 +16,8 @@
else:
from asyncio import timeout as asyncio_timeout # pragma: no cover

from zigpy.datastructures import PriorityDynamicBoundedSemaphore

from bellows.config import CONF_EZSP_POLICIES
from bellows.exception import InvalidCommandError
import bellows.types as t
Expand All @@ -23,7 +26,9 @@
from bellows.uart import Gateway

LOGGER = logging.getLogger(__name__)
EZSP_CMD_TIMEOUT = 6 # Sum of all ASH retry timeouts: 0.4 + 0.8 + 1.6 + 3.2

EZSP_CMD_TIMEOUT = 10
MAX_COMMAND_CONCURRENCY = 1


class ProtocolHandler(abc.ABC):
Expand All @@ -42,6 +47,9 @@ def __init__(self, cb_handler: Callable, gateway: Gateway) -> None:
for name, (cmd_id, tx_schema, rx_schema) in self.COMMANDS.items()
}
self.tc_policy = 0
self._send_semaphore = PriorityDynamicBoundedSemaphore(
value=MAX_COMMAND_CONCURRENCY
)

# Cached by `set_extended_timeout` so subsequent calls are a little faster
self._address_table_size: int | None = None
Expand All @@ -65,18 +73,60 @@ def _ezsp_frame_rx(self, data: bytes) -> tuple[int, int, bytes]:
def _ezsp_frame_tx(self, name: str) -> bytes:
"""Serialize the named frame."""

def _get_command_priority(self, name: str) -> int:
return {
# Deprioritize any commands that send packets
"setSourceRoute": -1,
"setExtendedTimeout": -1,
"sendUnicast": -1,
"sendMulticast": -1,
"sendBroadcast": -1,
# Prioritize watchdog commands
"nop": 999,
"readCounters": 999,
"readAndClearCounters": 999,
"getValue": 999,
}.get(name, 0)

async def command(self, name, *args, **kwargs) -> Any:
"""Serialize command and send it."""
LOGGER.debug("Sending command %s: %s %s", name, args, kwargs)
data = self._ezsp_frame(name, *args, **kwargs)
cmd_id, _, rx_schema = self.COMMANDS[name]
future = asyncio.get_running_loop().create_future()
self._awaiting[self._seq] = (cmd_id, rx_schema, future)
self._seq = (self._seq + 1) % 256

async with asyncio_timeout(EZSP_CMD_TIMEOUT):
delayed = False
send_time = None

if self._send_semaphore.locked():
delayed = True
send_time = time.monotonic()

LOGGER.debug(
"Send semaphore is locked, delaying before sending %s(%r, %r)",
name,
args,
kwargs,
)

async with self._send_semaphore(priority=self._get_command_priority(name)):
if delayed:
LOGGER.debug(
"Sending command %s: %s %s after %0.2fs delay",
name,
args,
kwargs,
time.monotonic() - send_time,
)
else:
LOGGER.debug("Sending command %s: %s %s", name, args, kwargs)

data = self._ezsp_frame(name, *args, **kwargs)
cmd_id, _, rx_schema = self.COMMANDS[name]

future = asyncio.get_running_loop().create_future()
self._awaiting[self._seq] = (cmd_id, rx_schema, future)
self._seq = (self._seq + 1) % 256

await self._gw.send_data(data)
return await future

async with asyncio_timeout(EZSP_CMD_TIMEOUT):
return await future

async def update_policies(self, policy_config: dict) -> None:
"""Set up the policies for what the NCP should do."""
Expand Down
15 changes: 0 additions & 15 deletions bellows/uart.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,6 @@


class Gateway(asyncio.Protocol):
FLAG = b"\x7E" # Marks end of frame
ESCAPE = b"\x7D"
XON = b"\x11" # Resume transmission
XOFF = b"\x13" # Stop transmission
SUBSTITUTE = b"\x18"
CANCEL = b"\x1A" # Terminates a frame in progress
STUFF = 0x20
RANDOMIZE_START = 0x42
RANDOMIZE_SEQ = 0xB8

RESERVED = FLAG + ESCAPE + XON + XOFF + SUBSTITUTE + CANCEL

class Terminator:
pass

def __init__(self, application, connected_future=None, connection_done_future=None):
self._application = application

Expand Down
Loading
Loading