Skip to content

Commit

Permalink
pw_rpc: Avoid recompiling protos for every test
Browse files Browse the repository at this point in the history
Tests using python_protos.Library.from_strings() write a .proto file to
a temp directory, invoke protoc, and import the generated Python for
each test. Since the protos do not change, only do this once.

Bug: b/360184800
Change-Id: I9b31e03f21bc47e579c745d1be57923733ad1ed8
Reviewed-on: https://pigweed-review.googlesource.com/c/pigweed/pigweed/+/230135
Lint: Lint 🤖 <[email protected]>
Pigweed-Auto-Submit: Wyatt Hepler <[email protected]>
Reviewed-by: Armando Montanez <[email protected]>
Commit-Queue: Auto-Submit <[email protected]>
Presubmit-Verified: CQ Bot Account <[email protected]>
  • Loading branch information
255 authored and CQ Bot Account committed Aug 22, 2024
1 parent a75b716 commit 9175df7
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 18 deletions.
14 changes: 7 additions & 7 deletions pw_rpc/py/tests/callback_client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
}
"""

PROTOS = python_protos.Library.from_strings(TEST_PROTO_1)
CLIENT_CHANNEL_ID: int = 489


Expand All @@ -63,13 +64,12 @@ class _CallbackClientImplTestBase(unittest.TestCase):
"""Supports writing tests that require responses from an RPC server."""

def setUp(self) -> None:
self._protos = python_protos.Library.from_strings(TEST_PROTO_1)
self._request = self._protos.packages.pw.test1.SomeMessage
self._request = PROTOS.packages.pw.test1.SomeMessage

self._client = client.Client.from_modules(
callback_client.Impl(),
[client.Channel(CLIENT_CHANNEL_ID, self._handle_packet)],
self._protos.modules(),
PROTOS.modules(),
)
self._service = self._client.channel(
CLIENT_CHANNEL_ID
Expand Down Expand Up @@ -311,7 +311,7 @@ def test_default_timeouts_set_for_all_rpcs(self) -> None:
rpc_client = client.Client.from_modules(
callback_client.Impl(99, 100),
[client.Channel(CLIENT_CHANNEL_ID, lambda *a, **b: None)],
self._protos.modules(),
PROTOS.modules(),
)
rpcs = rpc_client.channel(CLIENT_CHANNEL_ID).rpcs

Expand Down Expand Up @@ -540,7 +540,7 @@ def test_nonblocking_exception_in_callback(self) -> None:
self.assertEqual(context.exception.__cause__, exception)

def test_unary_response(self) -> None:
proto = self._protos.packages.pw.test1.SomeMessage(magic_number=123)
proto = PROTOS.packages.pw.test1.SomeMessage(magic_number=123)
self.assertEqual(
repr(callback_client.UnaryResponse(Status.ABORTED, proto)),
'(Status.ABORTED, pw.test1.SomeMessage(magic_number=123))',
Expand All @@ -556,7 +556,7 @@ def test_on_call_hook(self) -> None:
self._client = client.Client.from_modules(
callback_client.Impl(on_call_hook=hook_function),
[client.Channel(CLIENT_CHANNEL_ID, self._handle_packet)],
self._protos.modules(),
PROTOS.modules(),
)

self._service = self._client.channel(
Expand Down Expand Up @@ -1263,7 +1263,7 @@ def test_max_responses(self) -> None:
self.assertEqual(result.responses, list(call.responses))

def test_stream_response(self) -> None:
proto = self._protos.packages.pw.test1.SomeMessage(magic_number=123)
proto = PROTOS.packages.pw.test1.SomeMessage(magic_number=123)
self.assertEqual(
repr(callback_client.StreamResponse(Status.ABORTED, [proto] * 2)),
'(Status.ABORTED, [pw.test1.SomeMessage(magic_number=123), '
Expand Down
19 changes: 8 additions & 11 deletions pw_rpc/py/tests/client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@
}
"""

PROTOS = python_protos.Library.from_strings([TEST_PROTO_1, TEST_PROTO_2])

SOME_CHANNEL_ID: int = 237
SOME_SERVICE_ID: int = 193
SOME_METHOD_ID: int = 769
Expand All @@ -83,10 +85,6 @@
CLIENT_SECOND_CHANNEL_ID: int = 474


def create_protos() -> Any:
return python_protos.Library.from_strings([TEST_PROTO_1, TEST_PROTO_2])


def create_client(
proto_modules: Any,
first_channel_output_fn: Callable[[bytes], Any] = lambda _: None,
Expand All @@ -105,7 +103,7 @@ class ChannelClientTest(unittest.TestCase):
"""Tests the ChannelClient."""

def setUp(self) -> None:
client_instance = create_client(create_protos().modules())
client_instance = create_client(PROTOS.modules())
self._channel_client: client.ChannelClient = client_instance.channel(
CLIENT_FIRST_CHANNEL_ID
)
Expand Down Expand Up @@ -203,8 +201,7 @@ class ClientTest(unittest.TestCase):

def setUp(self) -> None:
self._last_packet_sent_bytes: bytes | None = None
self._protos = create_protos()
self._client = create_client(self._protos.modules(), self._save_packet)
self._client = create_client(PROTOS.modules(), self._save_packet)

def _save_packet(self, packet) -> None:
self._last_packet_sent_bytes = packet
Expand Down Expand Up @@ -295,7 +292,7 @@ def test_process_packet_unrecognized_channel(self) -> None:
SOME_METHOD_ID,
SOME_CALL_ID,
),
self._protos.packages.pw.test2.Request(),
PROTOS.packages.pw.test2.Request(),
)
),
Status.NOT_FOUND,
Expand All @@ -311,7 +308,7 @@ def test_process_packet_unrecognized_service(self) -> None:
SOME_METHOD_ID,
SOME_CALL_ID,
),
self._protos.packages.pw.test2.Request(),
PROTOS.packages.pw.test2.Request(),
)
),
Status.OK,
Expand Down Expand Up @@ -341,7 +338,7 @@ def test_process_packet_unrecognized_method(self) -> None:
SOME_METHOD_ID,
SOME_CALL_ID,
),
self._protos.packages.pw.test2.Request(),
PROTOS.packages.pw.test2.Request(),
)
),
Status.OK,
Expand Down Expand Up @@ -372,7 +369,7 @@ def test_process_packet_non_pending_method(self) -> None:
method.id,
SOME_CALL_ID,
),
self._protos.packages.pw.test2.Request(),
PROTOS.packages.pw.test2.Request(),
)
),
Status.OK,
Expand Down

0 comments on commit 9175df7

Please sign in to comment.