From 35c221facea3abda2ad8375672e9da541eec417f Mon Sep 17 00:00:00 2001 From: Wyatt Hepler Date: Thu, 15 Aug 2024 21:22:50 +0000 Subject: [PATCH] pw_rpc: Merge PendingRpc and RpcIds MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Define __hash__ and __eq__ for PendingRpc so it can serve the role of both classes. - Restore the __str__ implementation for PendingRpc so logs include the service/method name instead of just IDs. - Instead of having packet encoding functions support plain tuples as they originally did, add a helper for creating a test-only PendingRpc. Change-Id: I5f706c43bc58d2b1fefdd962cbb9619b82f76794 Reviewed-on: https://pigweed-review.googlesource.com/c/pigweed/pigweed/+/228952 Reviewed-by: Alexei Frolov Lint: Lint 🤖 Pigweed-Auto-Submit: Wyatt Hepler Commit-Queue: Auto-Submit --- pw_log_rpc/py/rpc_log_stream_test.py | 42 +++++++++++----------- pw_rpc/py/pw_rpc/client.py | 28 +-------------- pw_rpc/py/pw_rpc/descriptors.py | 41 +++++++++++++++++++++ pw_rpc/py/pw_rpc/packets.py | 53 +++++++++++----------------- pw_rpc/py/tests/client_test.py | 13 ++++--- pw_rpc/py/tests/packets_test.py | 24 +++++++------ 6 files changed, 104 insertions(+), 97 deletions(-) diff --git a/pw_log_rpc/py/rpc_log_stream_test.py b/pw_log_rpc/py/rpc_log_stream_test.py index d21919c8bc..05b58b8535 100644 --- a/pw_log_rpc/py/rpc_log_stream_test.py +++ b/pw_log_rpc/py/rpc_log_stream_test.py @@ -23,7 +23,8 @@ from pw_log.log_decoder import Log, LogStreamDecoder from pw_log.proto import log_pb2 from pw_log_rpc.rpc_log_stream import LogStreamHandler -from pw_rpc import callback_client, client, packets +from pw_rpc import callback_client, client +from pw_rpc.descriptors import fake_pending_rpc, PendingRpc from pw_rpc.internal import packet_pb2 from pw_status import Status @@ -31,47 +32,47 @@ def _encode_server_stream_packet( - rpc: packets.RpcIds, payload: message.Message + rpc: PendingRpc, payload: message.Message ) -> bytes: return packet_pb2.RpcPacket( type=packet_pb2.PacketType.SERVER_STREAM, - channel_id=rpc.channel_id, - service_id=rpc.service_id, - method_id=rpc.method_id, + channel_id=rpc.channel.id, + service_id=rpc.service.id, + method_id=rpc.method.id, call_id=rpc.call_id, payload=payload.SerializeToString(), ).SerializeToString() -def _encode_cancel(rpc: packets.RpcIds) -> bytes: +def _encode_cancel(rpc: PendingRpc) -> bytes: return packet_pb2.RpcPacket( type=packet_pb2.PacketType.SERVER_ERROR, status=Status.CANCELLED.value, - channel_id=rpc.channel_id, - service_id=rpc.service_id, - method_id=rpc.method_id, + channel_id=rpc.channel.id, + service_id=rpc.service.id, + method_id=rpc.method.id, call_id=rpc.call_id, ).SerializeToString() -def _encode_error(rpc: packets.RpcIds) -> bytes: +def _encode_error(rpc: PendingRpc) -> bytes: return packet_pb2.RpcPacket( type=packet_pb2.PacketType.SERVER_ERROR, status=Status.UNKNOWN.value, - channel_id=rpc.channel_id, - service_id=rpc.service_id, - method_id=rpc.method_id, + channel_id=rpc.channel.id, + service_id=rpc.service.id, + method_id=rpc.method.id, call_id=rpc.call_id, ).SerializeToString() -def _encode_completed(rpc: packets.RpcIds, status: Status) -> bytes: +def _encode_completed(rpc: PendingRpc, status: Status) -> bytes: return packet_pb2.RpcPacket( type=packet_pb2.PacketType.RESPONSE, status=status.value, - channel_id=rpc.channel_id, - service_id=rpc.service_id, - method_id=rpc.method_id, + channel_id=rpc.channel.id, + service_id=rpc.service.id, + method_id=rpc.method.id, call_id=rpc.call_id, ).SerializeToString() @@ -121,14 +122,15 @@ def decoded_log_handler(log: Log) -> None: self.client.channel(self._channel_id).rpcs, log_decoder ) - def _get_rpc_ids(self) -> packets.RpcIds: + def _get_rpc_ids(self) -> PendingRpc: service = next(iter(self.client.services)) method = next(iter(service.methods)) # To handle unrequested log streams, packets' call Ids are set to # kOpenCallId. - call_id = client.OPEN_CALL_ID - return packets.RpcIds(self._channel_id, service.id, method.id, call_id) + return fake_pending_rpc( + self._channel_id, service.id, method.id, client.OPEN_CALL_ID + ) def test_listen_to_logs_subsequent_calls(self): """Test a stream of RPC Logs.""" diff --git a/pw_rpc/py/pw_rpc/client.py b/pw_rpc/py/pw_rpc/client.py index 51cb93802a..93b73fbc42 100644 --- a/pw_rpc/py/pw_rpc/client.py +++ b/pw_rpc/py/pw_rpc/client.py @@ -30,7 +30,7 @@ from pw_status import Status from pw_rpc import descriptors, packets -from pw_rpc.descriptors import Channel, Service, Method +from pw_rpc.descriptors import Channel, Service, Method, PendingRpc from pw_rpc.internal.packet_pb2 import PacketType, RpcPacket _LOG = logging.getLogger(__package__) @@ -47,32 +47,6 @@ class Error(Exception): """Error from incorrectly using the RPC client classes.""" -class PendingRpc(packets.RpcIds): - """Uniquely identifies an RPC call. - - Attributes: - channel: Channel - service: Service - method: Method - channel_id: int - service_id: int - method_id: int - call_id: int - """ - - def __init__( - self, - channel: Channel, - service: Service, - method: Method, - call_id: int, - ) -> None: - super().__init__(channel.id, service.id, method.id, call_id) - self.channel = channel - self.service = service - self.method = method - - class _PendingRpcMetadata: def __init__(self, context: object): self.context = context diff --git a/pw_rpc/py/pw_rpc/descriptors.py b/pw_rpc/py/pw_rpc/descriptors.py index 72f007d57d..8030115dfe 100644 --- a/pw_rpc/py/pw_rpc/descriptors.py +++ b/pw_rpc/py/pw_rpc/descriptors.py @@ -456,3 +456,44 @@ def get_method(service_accessor: ServiceAccessor, name: str): service = service.methods return service[method_name] + + +@dataclass(frozen=True) +class PendingRpc: + """Uniquely identifies an RPC call.""" + + channel: Channel + service: Service + method: Method + call_id: int + + def __eq__(self, other: Any) -> bool: + if isinstance(other, PendingRpc): + return self._ids() == other._ids() + + return NotImplemented + + def __hash__(self) -> int: + return hash(self._ids()) + + def _ids(self) -> tuple[int, int, int, int]: + return self.channel.id, self.service.id, self.method.id, self.call_id + + def __str__(self) -> str: + return ( + f'PendingRpc(channel={self.channel.id}, method={self.method}, ' + f'call_id={self.call_id})' + ) + + +def fake_pending_rpc( + channel_id: int, service_id: int, method_id: int, call_id: int +) -> PendingRpc: + """Creates a fake PendingRpc for testing: ONLY the *_id properties work!""" + service = Service(None, service_id, None) # type: ignore[arg-type] + return PendingRpc( + Channel(channel_id, lambda _: None), + service, + Method(None, service, method_id, False, False, None, None), # type: ignore[arg-type] # pylint: disable=line-too-long + call_id, + ) diff --git a/pw_rpc/py/pw_rpc/packets.py b/pw_rpc/py/pw_rpc/packets.py index 2feb3fd5c6..5aa8798f65 100644 --- a/pw_rpc/py/pw_rpc/packets.py +++ b/pw_rpc/py/pw_rpc/packets.py @@ -13,11 +13,10 @@ # the License. """Functions for working with pw_rpc packets.""" -import dataclasses - from google.protobuf import message from pw_status import Status +from pw_rpc.descriptors import PendingRpc from pw_rpc.internal import packet_pb2 @@ -33,46 +32,36 @@ def decode_payload(packet, payload_type): return payload -@dataclasses.dataclass(eq=True, frozen=True) -class RpcIds: - """Integer IDs that uniquely identify a remote procedure call.""" - - channel_id: int - service_id: int - method_id: int - call_id: int - - -def encode_request(rpc: RpcIds, request: message.Message | None) -> bytes: +def encode_request(rpc: PendingRpc, request: message.Message | None) -> bytes: payload = request.SerializeToString() if request is not None else bytes() return packet_pb2.RpcPacket( type=packet_pb2.PacketType.REQUEST, - channel_id=rpc.channel_id, - service_id=rpc.service_id, - method_id=rpc.method_id, + channel_id=rpc.channel.id, + service_id=rpc.service.id, + method_id=rpc.method.id, call_id=rpc.call_id, payload=payload, ).SerializeToString() -def encode_response(rpc: RpcIds, response: message.Message) -> bytes: +def encode_response(rpc: PendingRpc, response: message.Message) -> bytes: return packet_pb2.RpcPacket( type=packet_pb2.PacketType.RESPONSE, - channel_id=rpc.channel_id, - service_id=rpc.service_id, - method_id=rpc.method_id, + channel_id=rpc.channel.id, + service_id=rpc.service.id, + method_id=rpc.method.id, call_id=rpc.call_id, payload=response.SerializeToString(), ).SerializeToString() -def encode_client_stream(rpc: RpcIds, request: message.Message) -> bytes: +def encode_client_stream(rpc: PendingRpc, request: message.Message) -> bytes: return packet_pb2.RpcPacket( type=packet_pb2.PacketType.CLIENT_STREAM, - channel_id=rpc.channel_id, - service_id=rpc.service_id, - method_id=rpc.method_id, + channel_id=rpc.channel.id, + service_id=rpc.service.id, + method_id=rpc.method.id, call_id=rpc.call_id, payload=request.SerializeToString(), ).SerializeToString() @@ -89,23 +78,23 @@ def encode_client_error(packet: packet_pb2.RpcPacket, status: Status) -> bytes: ).SerializeToString() -def encode_cancel(rpc: RpcIds) -> bytes: +def encode_cancel(rpc: PendingRpc) -> bytes: return packet_pb2.RpcPacket( type=packet_pb2.PacketType.CLIENT_ERROR, status=Status.CANCELLED.value, - channel_id=rpc.channel_id, - service_id=rpc.service_id, - method_id=rpc.method_id, + channel_id=rpc.channel.id, + service_id=rpc.service.id, + method_id=rpc.method.id, call_id=rpc.call_id, ).SerializeToString() -def encode_client_stream_end(rpc: RpcIds) -> bytes: +def encode_client_stream_end(rpc: PendingRpc) -> bytes: return packet_pb2.RpcPacket( type=packet_pb2.PacketType.CLIENT_REQUEST_COMPLETION, - channel_id=rpc.channel_id, - service_id=rpc.service_id, - method_id=rpc.method_id, + channel_id=rpc.channel.id, + service_id=rpc.service.id, + method_id=rpc.method.id, call_id=rpc.call_id, ).SerializeToString() diff --git a/pw_rpc/py/tests/client_test.py b/pw_rpc/py/tests/client_test.py index ced61a6126..b2b111348c 100755 --- a/pw_rpc/py/tests/client_test.py +++ b/pw_rpc/py/tests/client_test.py @@ -23,8 +23,7 @@ from pw_rpc import callback_client, client, packets import pw_rpc.ids from pw_rpc.internal.packet_pb2 import PacketType, RpcPacket - -RpcIds = packets.RpcIds +from pw_rpc.descriptors import fake_pending_rpc TEST_PROTO_1 = """\ syntax = "proto3"; @@ -290,7 +289,7 @@ def test_process_packet_unrecognized_channel(self) -> None: self.assertIs( self._client.process_packet( packets.encode_response( - RpcIds( + fake_pending_rpc( SOME_CHANNEL_ID, SOME_SERVICE_ID, SOME_METHOD_ID, @@ -306,7 +305,7 @@ def test_process_packet_unrecognized_service(self) -> None: self.assertIs( self._client.process_packet( packets.encode_response( - RpcIds( + fake_pending_rpc( CLIENT_FIRST_CHANNEL_ID, SOME_SERVICE_ID, SOME_METHOD_ID, @@ -336,7 +335,7 @@ def test_process_packet_unrecognized_method(self) -> None: self.assertIs( self._client.process_packet( packets.encode_response( - RpcIds( + fake_pending_rpc( CLIENT_FIRST_CHANNEL_ID, service.id, SOME_METHOD_ID, @@ -367,7 +366,7 @@ def test_process_packet_non_pending_method(self) -> None: self.assertIs( self._client.process_packet( packets.encode_response( - RpcIds( + fake_pending_rpc( CLIENT_FIRST_CHANNEL_ID, service.id, method.id, @@ -417,7 +416,7 @@ def response_callback( self.assertIs( self._client.process_packet( packets.encode_response( - RpcIds( + fake_pending_rpc( CLIENT_FIRST_CHANNEL_ID, method.service.id, method.id, diff --git a/pw_rpc/py/tests/packets_test.py b/pw_rpc/py/tests/packets_test.py index 3edded35ea..c7d53d6e5d 100755 --- a/pw_rpc/py/tests/packets_test.py +++ b/pw_rpc/py/tests/packets_test.py @@ -20,22 +20,24 @@ from pw_rpc.internal.packet_pb2 import PacketType, RpcPacket from pw_rpc import packets +from pw_rpc.descriptors import fake_pending_rpc + +_TEST_IDS = fake_pending_rpc(1, 2, 3, 4) -_TEST_IDS = packets.RpcIds(1, 2, 3, 4) _TEST_STATUS = 321 _TEST_REQUEST = RpcPacket( type=PacketType.REQUEST, - channel_id=_TEST_IDS.channel_id, - service_id=_TEST_IDS.service_id, - method_id=_TEST_IDS.method_id, + channel_id=_TEST_IDS.channel.id, + service_id=_TEST_IDS.service.id, + method_id=_TEST_IDS.method.id, call_id=_TEST_IDS.call_id, payload=RpcPacket(status=_TEST_STATUS).SerializeToString(), ) _TEST_RESPONSE = RpcPacket( type=PacketType.RESPONSE, - channel_id=_TEST_IDS.channel_id, - service_id=_TEST_IDS.service_id, - method_id=_TEST_IDS.method_id, + channel_id=_TEST_IDS.channel.id, + service_id=_TEST_IDS.service.id, + method_id=_TEST_IDS.method.id, call_id=_TEST_IDS.call_id, payload=RpcPacket(status=_TEST_STATUS).SerializeToString(), ) @@ -61,7 +63,7 @@ def test_encode_response(self): self.assertEqual(_TEST_RESPONSE, packet) def test_encode_cancel(self): - data = packets.encode_cancel(packets.RpcIds(9, 8, 7, 6)) + data = packets.encode_cancel(fake_pending_rpc(9, 8, 7, 6)) packet = RpcPacket() packet.ParseFromString(data) @@ -88,9 +90,9 @@ def test_encode_client_error(self): packet, RpcPacket( type=PacketType.CLIENT_ERROR, - channel_id=_TEST_IDS.channel_id, - service_id=_TEST_IDS.service_id, - method_id=_TEST_IDS.method_id, + channel_id=_TEST_IDS.channel.id, + service_id=_TEST_IDS.service.id, + method_id=_TEST_IDS.method.id, call_id=_TEST_IDS.call_id, status=Status.NOT_FOUND.value, ),