Skip to content

Commit

Permalink
style
Browse files Browse the repository at this point in the history
  • Loading branch information
xtinkt committed Jul 5, 2024
1 parent 269028d commit 9aecb3f
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 7 deletions.
23 changes: 17 additions & 6 deletions src/petals/client/inference_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,13 @@ async def _read_inputs_from_queue(queue: asyncio.Queue, input_timeout: Optional[
break # this message means "done sending"

def step(
self, inputs: torch.Tensor, prompts: torch.Tensor, hypo_ids: torch.LongTensor, *,
step_id: str, start_from_position: int
self,
inputs: torch.Tensor,
prompts: torch.Tensor,
hypo_ids: torch.LongTensor,
*,
step_id: str,
start_from_position: int,
) -> torch.Tensor:
"""
Inference step: send a chunk of input tensors and receive a chunk of outputs
Expand Down Expand Up @@ -266,8 +271,11 @@ def __enter__(self) -> "InferenceSession":
return self

def step(
self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None,
hypo_ids: Optional[torch.Tensor] = None, start_from_position: Optional[int] = None
self,
inputs: torch.Tensor,
prompts: Optional[torch.Tensor] = None,
hypo_ids: Optional[torch.Tensor] = None,
start_from_position: Optional[int] = None,
) -> torch.Tensor:

if start_from_position is not None:
Expand Down Expand Up @@ -317,8 +325,11 @@ def step(

server_session = self._server_sessions[server_idx]
inputs = server_session.step(
inputs, prompts[server_session.span.start : server_session.span.end], hypo_ids,
step_id=step_id, start_from_position=start_from_position
inputs,
prompts[server_session.span.start : server_session.span.end],
hypo_ids,
step_id=step_id,
start_from_position=start_from_position,
)

server_idx += 1
Expand Down
4 changes: 3 additions & 1 deletion src/petals/server/block_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,9 @@ async def iterate_rpc_inference(
async for request, step_metadata in input_iterator:
if "start_from_position" in step_metadata:
start_from_position = step_metadata["start_from_position"]
assert prefix_length >= start_from_position, f"prefix_length={prefix_length}, start_from_position={start_from_position}"
                assert (
                    prefix_length >= start_from_position
                ), f"prefix_length={prefix_length}, start_from_position={start_from_position}"
prefix_length = start_from_position

flat_tensors = tuple(deserialize_torch_tensor(tensor) for tensor in request.tensors)
Expand Down

0 comments on commit 9aecb3f

Please sign in to comment.