Skip to content

Commit

Permalink
pw_rpc: Restore RpcIds for testing; move packet encoding to packets.py
Browse files Browse the repository at this point in the history
- Restore the RpcIds class for backwards compatibility and use it in
  tests.
- Consolidate packet encoding functions in packets.py.

Change-Id: I94fdd967951b61abfff315438ad637cc368c5a54
Reviewed-on: https://pigweed-review.googlesource.com/c/pigweed/pigweed/+/230471
Lint: Lint 🤖 <[email protected]>
Reviewed-by: Alexei Frolov <[email protected]>
Commit-Queue: Auto-Submit <[email protected]>
Pigweed-Auto-Submit: Wyatt Hepler <[email protected]>
Presubmit-Verified: CQ Bot Account <[email protected]>
  • Loading branch information
255 authored and CQ Bot Account committed Aug 19, 2024
1 parent 269b600 commit 829519b
Show file tree
Hide file tree
Showing 5 changed files with 162 additions and 123 deletions.
86 changes: 22 additions & 64 deletions pw_log_rpc/py/rpc_log_stream_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,67 +15,17 @@
"""RPC log stream handler tests."""

from dataclasses import dataclass
import logging
from typing import Any, Callable
from unittest import TestCase, main, mock

from google.protobuf import message
from pw_log.log_decoder import Log, LogStreamDecoder
from pw_log.proto import log_pb2
from pw_log_rpc.rpc_log_stream import LogStreamHandler
from pw_rpc import callback_client, client
from pw_rpc.descriptors import fake_pending_rpc, PendingRpc
from pw_rpc.internal import packet_pb2
from pw_rpc.descriptors import RpcIds
from pw_rpc import packets
from pw_status import Status

_LOG = logging.getLogger(__name__)


def _encode_server_stream_packet(
rpc: PendingRpc, payload: message.Message
) -> bytes:
return packet_pb2.RpcPacket(
type=packet_pb2.PacketType.SERVER_STREAM,
channel_id=rpc.channel.id,
service_id=rpc.service.id,
method_id=rpc.method.id,
call_id=rpc.call_id,
payload=payload.SerializeToString(),
).SerializeToString()


def _encode_cancel(rpc: PendingRpc) -> bytes:
return packet_pb2.RpcPacket(
type=packet_pb2.PacketType.SERVER_ERROR,
status=Status.CANCELLED.value,
channel_id=rpc.channel.id,
service_id=rpc.service.id,
method_id=rpc.method.id,
call_id=rpc.call_id,
).SerializeToString()


def _encode_error(rpc: PendingRpc) -> bytes:
return packet_pb2.RpcPacket(
type=packet_pb2.PacketType.SERVER_ERROR,
status=Status.UNKNOWN.value,
channel_id=rpc.channel.id,
service_id=rpc.service.id,
method_id=rpc.method.id,
call_id=rpc.call_id,
).SerializeToString()


def _encode_completed(rpc: PendingRpc, status: Status) -> bytes:
return packet_pb2.RpcPacket(
type=packet_pb2.PacketType.RESPONSE,
status=status.value,
channel_id=rpc.channel.id,
service_id=rpc.service.id,
method_id=rpc.method.id,
call_id=rpc.call_id,
).SerializeToString()


class _CallableWithCounter:
"""Wraps a function and counts how many time it was called."""
Expand Down Expand Up @@ -122,13 +72,13 @@ def decoded_log_handler(log: Log) -> None:
self.client.channel(self._channel_id).rpcs, log_decoder
)

def _get_rpc_ids(self) -> PendingRpc:
def _get_rpc_ids(self) -> RpcIds:
service = next(iter(self.client.services))
method = next(iter(service.methods))

# To handle unrequested log streams, packets' call Ids are set to
# kOpenCallId.
return fake_pending_rpc(
return RpcIds(
self._channel_id, service.id, method.id, client.OPEN_CALL_ID
)

Expand All @@ -140,7 +90,7 @@ def test_listen_to_logs_subsequent_calls(self):

self.assertIs(
self.client.process_packet(
_encode_server_stream_packet(
packets.encode_server_stream(
self._get_rpc_ids(),
log_pb2.LogEntries(
first_entry_sequence_id=0,
Expand All @@ -162,7 +112,7 @@ def test_listen_to_logs_subsequent_calls(self):
# A subsequent RPC packet should be handled successfully.
self.assertIs(
self.client.process_packet(
_encode_server_stream_packet(
packets.encode_server_stream(
self._get_rpc_ids(),
log_pb2.LogEntries(
first_entry_sequence_id=2,
Expand Down Expand Up @@ -195,7 +145,7 @@ def test_log_stream_cancelled(self):
# Send logs prior to cancellation.
self.assertIs(
self.client.process_packet(
_encode_server_stream_packet(
packets.encode_server_stream(
self._get_rpc_ids(),
log_pb2.LogEntries(
first_entry_sequence_id=0,
Expand All @@ -209,7 +159,11 @@ def test_log_stream_cancelled(self):
Status.OK,
)
self.assertIs(
self.client.process_packet(_encode_cancel(self._get_rpc_ids())),
self.client.process_packet(
packets.encode_server_error(
self._get_rpc_ids(), Status.CANCELLED
)
),
Status.OK,
)
self.log_stream_handler.handle_log_stream_error.assert_called_once_with(
Expand Down Expand Up @@ -239,7 +193,7 @@ def test_log_stream_error_stream_restarted(self):
# Send logs prior to cancellation.
self.assertIs(
self.client.process_packet(
_encode_server_stream_packet(
packets.encode_server_stream(
self._get_rpc_ids(),
log_pb2.LogEntries(
first_entry_sequence_id=0,
Expand All @@ -253,7 +207,9 @@ def test_log_stream_error_stream_restarted(self):
Status.OK,
)
self.assertIs(
self.client.process_packet(_encode_error(self._get_rpc_ids())),
self.client.process_packet(
packets.encode_server_error(self._get_rpc_ids(), Status.UNKNOWN)
),
Status.OK,
)

Expand Down Expand Up @@ -283,7 +239,7 @@ def test_log_stream_completed_ok_stream_restarted(self):
# Send logs prior to cancellation.
self.assertIs(
self.client.process_packet(
_encode_server_stream_packet(
packets.encode_server_stream(
self._get_rpc_ids(),
log_pb2.LogEntries(
first_entry_sequence_id=0,
Expand All @@ -298,7 +254,7 @@ def test_log_stream_completed_ok_stream_restarted(self):
)
self.assertIs(
self.client.process_packet(
_encode_completed(self._get_rpc_ids(), Status.OK)
packets.encode_response(self._get_rpc_ids(), status=Status.OK)
),
Status.OK,
)
Expand Down Expand Up @@ -327,7 +283,7 @@ def test_log_stream_completed_with_error_stream_restarted(self):
# Send logs prior to cancellation.
self.assertIs(
self.client.process_packet(
_encode_server_stream_packet(
packets.encode_server_stream(
self._get_rpc_ids(),
log_pb2.LogEntries(
first_entry_sequence_id=0,
Expand All @@ -342,7 +298,9 @@ def test_log_stream_completed_with_error_stream_restarted(self):
)
self.assertIs(
self.client.process_packet(
_encode_completed(self._get_rpc_ids(), Status.UNKNOWN)
packets.encode_response(
self._get_rpc_ids(), status=Status.UNKNOWN
)
),
Status.OK,
)
Expand Down
47 changes: 29 additions & 18 deletions pw_rpc/py/pw_rpc/descriptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,26 +458,50 @@ def get_method(service_accessor: ServiceAccessor, name: str):
return service[method_name]


@dataclass(frozen=True)
class RpcIds:
"""Integer IDs that uniquely identify a remote procedure call."""

channel_id: int
service_id: int
method_id: int
call_id: int


@dataclass(frozen=True)
class PendingRpc:
"""Uniquely identifies an RPC call."""
"""Tracks an active RPC call."""

channel: Channel
service: Service
method: Method
call_id: int

@property
def channel_id(self) -> int:
return self.channel.id

@property
def service_id(self) -> int:
return self.service.id

@property
def method_id(self) -> int:
return self.method.id

def __eq__(self, other: Any) -> bool:
if isinstance(other, PendingRpc):
return self._ids() == other._ids()
return self.ids() == other.ids()

return NotImplemented

def __hash__(self) -> int:
return hash(self._ids())
return hash(self.ids())

def _ids(self) -> tuple[int, int, int, int]:
return self.channel.id, self.service.id, self.method.id, self.call_id
def ids(self) -> RpcIds:
return RpcIds(
self.channel.id, self.service.id, self.method.id, self.call_id
)

def __str__(self) -> str:
return (
Expand All @@ -491,16 +515,3 @@ def matches_channel_service_method(self, other: PendingRpc) -> bool:
and self.service.id == other.service.id
and self.method.id == other.method.id
)


def fake_pending_rpc(
channel_id: int, service_id: int, method_id: int, call_id: int
) -> PendingRpc:
"""Creates a fake PendingRpc for testing: ONLY the *_id properties work!"""
service = Service(None, service_id, None) # type: ignore[arg-type]
return PendingRpc(
Channel(channel_id, lambda _: None),
service,
Method(None, service, method_id, False, False, None, None), # type: ignore[arg-type] # pylint: disable=line-too-long
call_id,
)
76 changes: 54 additions & 22 deletions pw_rpc/py/pw_rpc/packets.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from google.protobuf import message
from pw_status import Status

from pw_rpc.descriptors import PendingRpc
from pw_rpc.descriptors import RpcIds, PendingRpc
from pw_rpc.internal import packet_pb2


Expand All @@ -32,36 +32,45 @@ def decode_payload(packet, payload_type):
return payload


def encode_request(rpc: PendingRpc, request: message.Message | None) -> bytes:
def encode_request(
rpc: PendingRpc | RpcIds, request: message.Message | None
) -> bytes:
payload = request.SerializeToString() if request is not None else bytes()

return packet_pb2.RpcPacket(
type=packet_pb2.PacketType.REQUEST,
channel_id=rpc.channel.id,
service_id=rpc.service.id,
method_id=rpc.method.id,
channel_id=rpc.channel_id,
service_id=rpc.service_id,
method_id=rpc.method_id,
call_id=rpc.call_id,
payload=payload,
).SerializeToString()


def encode_response(rpc: PendingRpc, response: message.Message) -> bytes:
def encode_response(
rpc: PendingRpc | RpcIds,
response: message.Message | None = None,
status: Status = Status.OK,
) -> bytes:
return packet_pb2.RpcPacket(
type=packet_pb2.PacketType.RESPONSE,
channel_id=rpc.channel.id,
service_id=rpc.service.id,
method_id=rpc.method.id,
channel_id=rpc.channel_id,
service_id=rpc.service_id,
method_id=rpc.method_id,
call_id=rpc.call_id,
payload=response.SerializeToString(),
payload=b'' if response is None else response.SerializeToString(),
status=status.value,
).SerializeToString()


def encode_client_stream(rpc: PendingRpc, request: message.Message) -> bytes:
def encode_client_stream(
rpc: PendingRpc | RpcIds, request: message.Message
) -> bytes:
return packet_pb2.RpcPacket(
type=packet_pb2.PacketType.CLIENT_STREAM,
channel_id=rpc.channel.id,
service_id=rpc.service.id,
method_id=rpc.method.id,
channel_id=rpc.channel_id,
service_id=rpc.service_id,
method_id=rpc.method_id,
call_id=rpc.call_id,
payload=request.SerializeToString(),
).SerializeToString()
Expand All @@ -78,23 +87,46 @@ def encode_client_error(packet: packet_pb2.RpcPacket, status: Status) -> bytes:
).SerializeToString()


def encode_cancel(rpc: PendingRpc) -> bytes:
def encode_cancel(rpc: PendingRpc | RpcIds) -> bytes:
return packet_pb2.RpcPacket(
type=packet_pb2.PacketType.CLIENT_ERROR,
status=Status.CANCELLED.value,
channel_id=rpc.channel.id,
service_id=rpc.service.id,
method_id=rpc.method.id,
channel_id=rpc.channel_id,
service_id=rpc.service_id,
method_id=rpc.method_id,
call_id=rpc.call_id,
).SerializeToString()


def encode_client_stream_end(rpc: PendingRpc) -> bytes:
def encode_client_stream_end(rpc: PendingRpc | RpcIds) -> bytes:
return packet_pb2.RpcPacket(
type=packet_pb2.PacketType.CLIENT_REQUEST_COMPLETION,
channel_id=rpc.channel.id,
service_id=rpc.service.id,
method_id=rpc.method.id,
channel_id=rpc.channel_id,
service_id=rpc.service_id,
method_id=rpc.method_id,
call_id=rpc.call_id,
).SerializeToString()


def encode_server_stream(rpc: RpcIds, payload: message.Message) -> bytes:
return packet_pb2.RpcPacket(
type=packet_pb2.PacketType.SERVER_STREAM,
channel_id=rpc.channel_id,
service_id=rpc.service_id,
method_id=rpc.method_id,
call_id=rpc.call_id,
payload=payload.SerializeToString(),
).SerializeToString()


def encode_server_error(rpc: RpcIds, status: Status) -> bytes:
assert not status.ok()
return packet_pb2.RpcPacket(
type=packet_pb2.PacketType.SERVER_ERROR,
status=status.value,
channel_id=rpc.channel_id,
service_id=rpc.service_id,
method_id=rpc.method_id,
call_id=rpc.call_id,
).SerializeToString()

Expand Down
Loading

0 comments on commit 829519b

Please sign in to comment.