Skip to content

Commit

Permalink
pw_rpc: Merge PendingRpc and RpcIds
Browse files Browse the repository at this point in the history
- Define __hash__ and __eq__ for PendingRpc so it can serve the role of
  both classes.
- Restore the __str__ implementation for PendingRpc so logs include the
  service/method name instead of just IDs.
- Instead of having packet encoding functions support plain tuples as
  they originally did, add a helper for creating a test-only PendingRpc.

Change-Id: I5f706c43bc58d2b1fefdd962cbb9619b82f76794
Reviewed-on: https://pigweed-review.googlesource.com/c/pigweed/pigweed/+/228952
Reviewed-by: Alexei Frolov <[email protected]>
Lint: Lint 🤖 <[email protected]>
Pigweed-Auto-Submit: Wyatt Hepler <[email protected]>
Commit-Queue: Auto-Submit <[email protected]>
  • Loading branch information
255 authored and CQ Bot Account committed Aug 15, 2024
1 parent 19cb9e4 commit 35c221f
Show file tree
Hide file tree
Showing 6 changed files with 104 additions and 97 deletions.
42 changes: 22 additions & 20 deletions pw_log_rpc/py/rpc_log_stream_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,55 +23,56 @@
from pw_log.log_decoder import Log, LogStreamDecoder
from pw_log.proto import log_pb2
from pw_log_rpc.rpc_log_stream import LogStreamHandler
from pw_rpc import callback_client, client, packets
from pw_rpc import callback_client, client
from pw_rpc.descriptors import fake_pending_rpc, PendingRpc
from pw_rpc.internal import packet_pb2
from pw_status import Status

_LOG = logging.getLogger(__name__)


def _encode_server_stream_packet(
rpc: packets.RpcIds, payload: message.Message
rpc: PendingRpc, payload: message.Message
) -> bytes:
return packet_pb2.RpcPacket(
type=packet_pb2.PacketType.SERVER_STREAM,
channel_id=rpc.channel_id,
service_id=rpc.service_id,
method_id=rpc.method_id,
channel_id=rpc.channel.id,
service_id=rpc.service.id,
method_id=rpc.method.id,
call_id=rpc.call_id,
payload=payload.SerializeToString(),
).SerializeToString()


def _encode_cancel(rpc: packets.RpcIds) -> bytes:
def _encode_cancel(rpc: PendingRpc) -> bytes:
return packet_pb2.RpcPacket(
type=packet_pb2.PacketType.SERVER_ERROR,
status=Status.CANCELLED.value,
channel_id=rpc.channel_id,
service_id=rpc.service_id,
method_id=rpc.method_id,
channel_id=rpc.channel.id,
service_id=rpc.service.id,
method_id=rpc.method.id,
call_id=rpc.call_id,
).SerializeToString()


def _encode_error(rpc: packets.RpcIds) -> bytes:
def _encode_error(rpc: PendingRpc) -> bytes:
return packet_pb2.RpcPacket(
type=packet_pb2.PacketType.SERVER_ERROR,
status=Status.UNKNOWN.value,
channel_id=rpc.channel_id,
service_id=rpc.service_id,
method_id=rpc.method_id,
channel_id=rpc.channel.id,
service_id=rpc.service.id,
method_id=rpc.method.id,
call_id=rpc.call_id,
).SerializeToString()


def _encode_completed(rpc: packets.RpcIds, status: Status) -> bytes:
def _encode_completed(rpc: PendingRpc, status: Status) -> bytes:
return packet_pb2.RpcPacket(
type=packet_pb2.PacketType.RESPONSE,
status=status.value,
channel_id=rpc.channel_id,
service_id=rpc.service_id,
method_id=rpc.method_id,
channel_id=rpc.channel.id,
service_id=rpc.service.id,
method_id=rpc.method.id,
call_id=rpc.call_id,
).SerializeToString()

Expand Down Expand Up @@ -121,14 +122,15 @@ def decoded_log_handler(log: Log) -> None:
self.client.channel(self._channel_id).rpcs, log_decoder
)

def _get_rpc_ids(self) -> packets.RpcIds:
def _get_rpc_ids(self) -> PendingRpc:
service = next(iter(self.client.services))
method = next(iter(service.methods))

# To handle unrequested log streams, packets' call Ids are set to
# kOpenCallId.
call_id = client.OPEN_CALL_ID
return packets.RpcIds(self._channel_id, service.id, method.id, call_id)
return fake_pending_rpc(
self._channel_id, service.id, method.id, client.OPEN_CALL_ID
)

def test_listen_to_logs_subsequent_calls(self):
"""Test a stream of RPC Logs."""
Expand Down
28 changes: 1 addition & 27 deletions pw_rpc/py/pw_rpc/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from pw_status import Status

from pw_rpc import descriptors, packets
from pw_rpc.descriptors import Channel, Service, Method
from pw_rpc.descriptors import Channel, Service, Method, PendingRpc
from pw_rpc.internal.packet_pb2 import PacketType, RpcPacket

_LOG = logging.getLogger(__package__)
Expand All @@ -47,32 +47,6 @@ class Error(Exception):
"""Error from incorrectly using the RPC client classes."""


class PendingRpc(packets.RpcIds):
"""Uniquely identifies an RPC call.
Attributes:
channel: Channel
service: Service
method: Method
channel_id: int
service_id: int
method_id: int
call_id: int
"""

def __init__(
self,
channel: Channel,
service: Service,
method: Method,
call_id: int,
) -> None:
super().__init__(channel.id, service.id, method.id, call_id)
self.channel = channel
self.service = service
self.method = method


class _PendingRpcMetadata:
def __init__(self, context: object):
self.context = context
Expand Down
41 changes: 41 additions & 0 deletions pw_rpc/py/pw_rpc/descriptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,3 +456,44 @@ def get_method(service_accessor: ServiceAccessor, name: str):
service = service.methods

return service[method_name]


@dataclass(frozen=True)
class PendingRpc:
"""Uniquely identifies an RPC call."""

channel: Channel
service: Service
method: Method
call_id: int

def __eq__(self, other: Any) -> bool:
if isinstance(other, PendingRpc):
return self._ids() == other._ids()

return NotImplemented

def __hash__(self) -> int:
return hash(self._ids())

def _ids(self) -> tuple[int, int, int, int]:
return self.channel.id, self.service.id, self.method.id, self.call_id

def __str__(self) -> str:
return (
f'PendingRpc(channel={self.channel.id}, method={self.method}, '
f'call_id={self.call_id})'
)


def fake_pending_rpc(
channel_id: int, service_id: int, method_id: int, call_id: int
) -> PendingRpc:
"""Creates a fake PendingRpc for testing: ONLY the *_id properties work!"""
service = Service(None, service_id, None) # type: ignore[arg-type]
return PendingRpc(
Channel(channel_id, lambda _: None),
service,
Method(None, service, method_id, False, False, None, None), # type: ignore[arg-type] # pylint: disable=line-too-long
call_id,
)
53 changes: 21 additions & 32 deletions pw_rpc/py/pw_rpc/packets.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,10 @@
# the License.
"""Functions for working with pw_rpc packets."""

import dataclasses

from google.protobuf import message
from pw_status import Status

from pw_rpc.descriptors import PendingRpc
from pw_rpc.internal import packet_pb2


Expand All @@ -33,46 +32,36 @@ def decode_payload(packet, payload_type):
return payload


@dataclasses.dataclass(eq=True, frozen=True)
class RpcIds:
"""Integer IDs that uniquely identify a remote procedure call."""

channel_id: int
service_id: int
method_id: int
call_id: int


def encode_request(rpc: RpcIds, request: message.Message | None) -> bytes:
def encode_request(rpc: PendingRpc, request: message.Message | None) -> bytes:
payload = request.SerializeToString() if request is not None else bytes()

return packet_pb2.RpcPacket(
type=packet_pb2.PacketType.REQUEST,
channel_id=rpc.channel_id,
service_id=rpc.service_id,
method_id=rpc.method_id,
channel_id=rpc.channel.id,
service_id=rpc.service.id,
method_id=rpc.method.id,
call_id=rpc.call_id,
payload=payload,
).SerializeToString()


def encode_response(rpc: RpcIds, response: message.Message) -> bytes:
def encode_response(rpc: PendingRpc, response: message.Message) -> bytes:
return packet_pb2.RpcPacket(
type=packet_pb2.PacketType.RESPONSE,
channel_id=rpc.channel_id,
service_id=rpc.service_id,
method_id=rpc.method_id,
channel_id=rpc.channel.id,
service_id=rpc.service.id,
method_id=rpc.method.id,
call_id=rpc.call_id,
payload=response.SerializeToString(),
).SerializeToString()


def encode_client_stream(rpc: RpcIds, request: message.Message) -> bytes:
def encode_client_stream(rpc: PendingRpc, request: message.Message) -> bytes:
return packet_pb2.RpcPacket(
type=packet_pb2.PacketType.CLIENT_STREAM,
channel_id=rpc.channel_id,
service_id=rpc.service_id,
method_id=rpc.method_id,
channel_id=rpc.channel.id,
service_id=rpc.service.id,
method_id=rpc.method.id,
call_id=rpc.call_id,
payload=request.SerializeToString(),
).SerializeToString()
Expand All @@ -89,23 +78,23 @@ def encode_client_error(packet: packet_pb2.RpcPacket, status: Status) -> bytes:
).SerializeToString()


def encode_cancel(rpc: RpcIds) -> bytes:
def encode_cancel(rpc: PendingRpc) -> bytes:
return packet_pb2.RpcPacket(
type=packet_pb2.PacketType.CLIENT_ERROR,
status=Status.CANCELLED.value,
channel_id=rpc.channel_id,
service_id=rpc.service_id,
method_id=rpc.method_id,
channel_id=rpc.channel.id,
service_id=rpc.service.id,
method_id=rpc.method.id,
call_id=rpc.call_id,
).SerializeToString()


def encode_client_stream_end(rpc: RpcIds) -> bytes:
def encode_client_stream_end(rpc: PendingRpc) -> bytes:
return packet_pb2.RpcPacket(
type=packet_pb2.PacketType.CLIENT_REQUEST_COMPLETION,
channel_id=rpc.channel_id,
service_id=rpc.service_id,
method_id=rpc.method_id,
channel_id=rpc.channel.id,
service_id=rpc.service.id,
method_id=rpc.method.id,
call_id=rpc.call_id,
).SerializeToString()

Expand Down
13 changes: 6 additions & 7 deletions pw_rpc/py/tests/client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@
from pw_rpc import callback_client, client, packets
import pw_rpc.ids
from pw_rpc.internal.packet_pb2 import PacketType, RpcPacket

RpcIds = packets.RpcIds
from pw_rpc.descriptors import fake_pending_rpc

TEST_PROTO_1 = """\
syntax = "proto3";
Expand Down Expand Up @@ -290,7 +289,7 @@ def test_process_packet_unrecognized_channel(self) -> None:
self.assertIs(
self._client.process_packet(
packets.encode_response(
RpcIds(
fake_pending_rpc(
SOME_CHANNEL_ID,
SOME_SERVICE_ID,
SOME_METHOD_ID,
Expand All @@ -306,7 +305,7 @@ def test_process_packet_unrecognized_service(self) -> None:
self.assertIs(
self._client.process_packet(
packets.encode_response(
RpcIds(
fake_pending_rpc(
CLIENT_FIRST_CHANNEL_ID,
SOME_SERVICE_ID,
SOME_METHOD_ID,
Expand Down Expand Up @@ -336,7 +335,7 @@ def test_process_packet_unrecognized_method(self) -> None:
self.assertIs(
self._client.process_packet(
packets.encode_response(
RpcIds(
fake_pending_rpc(
CLIENT_FIRST_CHANNEL_ID,
service.id,
SOME_METHOD_ID,
Expand Down Expand Up @@ -367,7 +366,7 @@ def test_process_packet_non_pending_method(self) -> None:
self.assertIs(
self._client.process_packet(
packets.encode_response(
RpcIds(
fake_pending_rpc(
CLIENT_FIRST_CHANNEL_ID,
service.id,
method.id,
Expand Down Expand Up @@ -417,7 +416,7 @@ def response_callback(
self.assertIs(
self._client.process_packet(
packets.encode_response(
RpcIds(
fake_pending_rpc(
CLIENT_FIRST_CHANNEL_ID,
method.service.id,
method.id,
Expand Down
24 changes: 13 additions & 11 deletions pw_rpc/py/tests/packets_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,22 +20,24 @@

from pw_rpc.internal.packet_pb2 import PacketType, RpcPacket
from pw_rpc import packets
from pw_rpc.descriptors import fake_pending_rpc

_TEST_IDS = fake_pending_rpc(1, 2, 3, 4)

_TEST_IDS = packets.RpcIds(1, 2, 3, 4)
_TEST_STATUS = 321
_TEST_REQUEST = RpcPacket(
type=PacketType.REQUEST,
channel_id=_TEST_IDS.channel_id,
service_id=_TEST_IDS.service_id,
method_id=_TEST_IDS.method_id,
channel_id=_TEST_IDS.channel.id,
service_id=_TEST_IDS.service.id,
method_id=_TEST_IDS.method.id,
call_id=_TEST_IDS.call_id,
payload=RpcPacket(status=_TEST_STATUS).SerializeToString(),
)
_TEST_RESPONSE = RpcPacket(
type=PacketType.RESPONSE,
channel_id=_TEST_IDS.channel_id,
service_id=_TEST_IDS.service_id,
method_id=_TEST_IDS.method_id,
channel_id=_TEST_IDS.channel.id,
service_id=_TEST_IDS.service.id,
method_id=_TEST_IDS.method.id,
call_id=_TEST_IDS.call_id,
payload=RpcPacket(status=_TEST_STATUS).SerializeToString(),
)
Expand All @@ -61,7 +63,7 @@ def test_encode_response(self):
self.assertEqual(_TEST_RESPONSE, packet)

def test_encode_cancel(self):
data = packets.encode_cancel(packets.RpcIds(9, 8, 7, 6))
data = packets.encode_cancel(fake_pending_rpc(9, 8, 7, 6))

packet = RpcPacket()
packet.ParseFromString(data)
Expand All @@ -88,9 +90,9 @@ def test_encode_client_error(self):
packet,
RpcPacket(
type=PacketType.CLIENT_ERROR,
channel_id=_TEST_IDS.channel_id,
service_id=_TEST_IDS.service_id,
method_id=_TEST_IDS.method_id,
channel_id=_TEST_IDS.channel.id,
service_id=_TEST_IDS.service.id,
method_id=_TEST_IDS.method.id,
call_id=_TEST_IDS.call_id,
status=Status.NOT_FOUND.value,
),
Expand Down

0 comments on commit 35c221f

Please sign in to comment.