Skip to content

Commit

Permalink
Merge pull request #148 from stealthrocket/no-double-pickle
Browse files Browse the repository at this point in the history
Don't pickle coroutine state twice
  • Loading branch information
chriso authored Apr 9, 2024
2 parents c9cdf68 + e9abbe4 commit 54c509e
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 17 deletions.
13 changes: 4 additions & 9 deletions src/dispatch/proto.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,7 @@ def __init__(self, req: function_pb.RunRequest):
else:
self._input = _pb_any_unpack(req.input)
else:
state_bytes = req.poll_result.coroutine_state
if len(state_bytes) > 0:
self._coroutine_state = pickle.loads(state_bytes)
else:
self._coroutine_state = None
self._coroutine_state = req.poll_result.coroutine_state
self._call_results = [
CallResult._from_proto(r) for r in req.poll_result.results
]
Expand Down Expand Up @@ -143,7 +139,7 @@ def from_input_arguments(cls, function: str, *args, **kwargs):
def from_poll_results(
cls,
function: str,
coroutine_state: Any,
coroutine_state: Optional[bytes],
call_results: List[CallResult],
error: Optional[Error] = None,
):
Expand Down Expand Up @@ -220,7 +216,7 @@ def exit(
@classmethod
def poll(
cls,
state: Any,
coroutine_state: Optional[bytes] = None,
calls: Optional[List[Call]] = None,
min_results: int = 1,
max_results: int = 10,
Expand All @@ -229,14 +225,13 @@ def poll(
"""Suspend the function with a set of Calls, instructing the
orchestrator to resume the function with the provided state when
call results are ready."""
state_bytes = pickle.dumps(state)
max_wait = (
duration_pb2.Duration(seconds=max_wait_seconds)
if max_wait_seconds is not None
else None
)
poll = poll_pb.Poll(
coroutine_state=state_bytes,
coroutine_state=coroutine_state,
min_results=min_results,
max_results=max_results,
max_wait=max_wait,
Expand Down
2 changes: 1 addition & 1 deletion src/dispatch/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,7 +558,7 @@ def _run(self, input: Input) -> Output:
len(serialized_state),
)
return Output.poll(
state=serialized_state,
coroutine_state=serialized_state,
calls=pending_calls,
min_results=max(1, min(state.outstanding_calls, self.poll_min_results)),
max_results=max(1, min(state.outstanding_calls, self.poll_max_results)),
Expand Down
18 changes: 11 additions & 7 deletions tests/test_fastapi.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import base64
import os
import pickle
import struct
import unittest
from typing import Any
from unittest import mock
Expand Down Expand Up @@ -282,7 +283,7 @@ def test_error_on_access_input_in_second_call(self):
@self.dispatch.primitive_function
def my_function(input: Input) -> Output:
if input.is_first_call:
return Output.poll(state=42)
return Output.poll(coroutine_state=b"42")
try:
print(input.input)
except ValueError:
Expand All @@ -294,7 +295,7 @@ def my_function(input: Input) -> Output:
return Output.value("not reached")

resp = self.execute(my_function, input="cool stuff")
self.assertEqual(42, pickle.loads(resp.poll.coroutine_state))
self.assertEqual(b"42", resp.poll.coroutine_state)

resp = self.execute(my_function, state=resp.poll.coroutine_state)
self.assertEqual("ValueError", resp.exit.result.error.type)
Expand Down Expand Up @@ -337,11 +338,12 @@ def coroutine3(input: Input) -> Output:
if input.is_first_call:
counter = input.input
else:
counter = input.coroutine_state
(counter,) = struct.unpack("@i", input.coroutine_state)
counter -= 1
if counter <= 0:
return Output.value("done")
return Output.poll(state=counter)
coroutine_state = struct.pack("@i", counter)
return Output.poll(coroutine_state=coroutine_state)

# first call
resp = self.execute(coroutine3, input=4)
Expand Down Expand Up @@ -375,9 +377,10 @@ def coroutine_main(input: Input) -> Output:
if input.is_first_call:
text: str = input.input
return Output.poll(
state=text, calls=[coro_compute_len._build_primitive_call(text)]
coroutine_state=text.encode(),
calls=[coro_compute_len._build_primitive_call(text)],
)
text = input.coroutine_state
text = input.coroutine_state.decode()
length = input.call_results[0].output
return Output.value(f"length={length} text='{text}'")

Expand Down Expand Up @@ -415,7 +418,8 @@ def coroutine_main(input: Input) -> Output:
if input.is_first_call:
text: str = input.input
return Output.poll(
state=text, calls=[coro_compute_len._build_primitive_call(text)]
coroutine_state=text.encode(),
calls=[coro_compute_len._build_primitive_call(text)],
)
error = input.call_results[0].error
if error is not None:
Expand Down

0 comments on commit 54c509e

Please sign in to comment.