Forward arbitrary kwargs to remote blocks #467
base: main
Conversation
Note to self: the old client runs backward with inputs that do not require_grad; we must support that!
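A quick, self-contained illustration of that note (the shapes and variable names are made up, this is not the PR's code): backward must still work when the hidden states themselves are frozen and only the trainable soft prompts require grad.

import torch

inputs = torch.randn(1, 8, 16)                        # hidden states, requires_grad=False
prompts = torch.randn(1, 4, 16, requires_grad=True)   # trainable soft prompts

# Backward must not fail just because `inputs` does not require grad:
loss = torch.cat([prompts, inputs], dim=1).sum()
loss.backward()

print(inputs.grad)          # None: no gradient is produced for the frozen inputs
print(prompts.grad.shape)   # torch.Size([1, 4, 16])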
Note to self: on wake up, do …
@@ -141,7 +145,7 @@ async def sequential_backward(
        try:
            if attempt_no >= 1:
                _, backup_inputs, backup_sequences = await sequential_forward(
-                   inputs, prompts, sequence_manager, start_index=span.start, end_index=span.end
+                   sequence_manager, inputs, prompts, start_index=span.start, end_index=span.end
Subjective matter: sequence_manager is the first parameter to most internal functions; I can roll back if the reviewer disagrees.
value = value[:, offset : offset + max_chunk_length]
kwargs_chunk[key] = value
return kwargs_chunk
Note: this is a potential problem; not all tensors where shape[-2] == seq_len can be time-sliced.
Counter-example: a LoRA adapter might accidentally have its rank equal to the sequence length.
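A minimal sketch of the heuristic in question and of the failure mode from the counter-example (the helper name and the shapes are hypothetical, not taken from the PR):

import torch

def slice_kwargs_chunk(kwargs: dict, seq_len: int, offset: int, max_chunk_length: int) -> dict:
    """Time-slice every tensor whose second-to-last dim equals seq_len; pass the rest through."""
    kwargs_chunk = {}
    for key, value in kwargs.items():
        if isinstance(value, torch.Tensor) and value.ndim >= 2 and value.shape[-2] == seq_len:
            value = value[..., offset : offset + max_chunk_length, :]
        kwargs_chunk[key] = value
    return kwargs_chunk

seq_len, hidden = 8, 16
position_bias = torch.zeros(1, seq_len, hidden)   # genuinely indexed by time: slicing it is correct
lora_A = torch.randn(seq_len, hidden)             # LoRA matrix whose rank happens to equal seq_len
chunk = slice_kwargs_chunk({"bias": position_bias, "lora_A": lora_A}, seq_len, offset=0, max_chunk_length=4)
print(chunk["bias"].shape)    # torch.Size([1, 4, 16])
print(chunk["lora_A"].shape)  # torch.Size([4, 16]) -- wrongly sliced, which would break the adapter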
@@ -227,15 +222,17 @@ class _RemoteSequentialAutogradFunction(torch.autograd.Function):
    """

    @staticmethod
-   def forward(ctx, inputs: torch.Tensor, prompts: torch.Tensor, sequence_manager: RemoteSequenceManager):
+   def forward(ctx, sequence_manager: RemoteSequenceManager, inputs: torch.Tensor, prompts: torch.Tensor):
        # TODO add kwargs here; figure out a way to split kwargs across servers
problem: how do we split args/kwargs into sub-batches?
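One possible answer, sketched here only to make the question concrete (not the PR's implementation): split tensors whose leading dimension equals the batch size along dim 0 and replicate everything else into every sub-batch.

import torch

def split_kwargs_into_subbatches(kwargs: dict, batch_size: int, max_batch_size: int) -> list:
    """Split batched tensors along dim 0; replicate non-batched values into every sub-batch."""
    num_subbatches = (batch_size + max_batch_size - 1) // max_batch_size
    subbatches = [dict() for _ in range(num_subbatches)]
    for key, value in kwargs.items():
        if isinstance(value, torch.Tensor) and value.ndim >= 1 and value.shape[0] == batch_size:
            for sub_kwargs, piece in zip(subbatches, value.split(max_batch_size, dim=0)):
                sub_kwargs[key] = piece
        else:
            for sub_kwargs in subbatches:   # scalars / shared tensors go to every sub-batch as-is
                sub_kwargs[key] = value
    return subbatches

parts = split_kwargs_into_subbatches(
    {"attention_mask": torch.ones(10, 8), "temperature": 0.9}, batch_size=10, max_batch_size=4
)
print([p["attention_mask"].shape[0] for p in parts])  # [4, 4, 2]

The ambiguity is the same one noted above for time-slicing: a tensor whose first dim coincidentally equals the batch size would be split when it should be shared.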
# Conflicts:
#   src/petals/__init__.py
#   src/petals/client/inference_session.py
@justheuristic solemnly swears to
NB: this pull request makes several drastic changes to the backend, block_functions, and pools. It might be better if I walk you through them before the review. On a related note, if it interferes with long-term plans for the codebase, please raise a concern - I'm happy to roll back any detrimental changes.
Why this exists:
A client should be able to forward arbitrary keyword arguments to remote blocks, e.g.
output_with_lora = internal_model_interface.forward(inputs, **lora_adapters)
output = internal_model_interface.forward(inputs, layer_past=make_method_dependent_tensors())
output_with_lora = internal_model_interface.forward(inputs, **ia3_state_dict)
and expect that the outputs are the same.
What does this PR contain
New functionality
Internal codebase changes:
RemoteSequenceManager.get_request_metadata now always accepts (server_id, protocol, block_uids, args, kwargs), in that order (see the calling sketch after this list)
client-side code: packing args/kwargs and forming metadata was moved from sequential_autograd to remote_forward_backward
Task size is now specified explicitly in block_functions
Task and PrioritizedTaskPool support kwargs, and therefore this pull request does not make server-side batching any more complicated than it already is
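A hedged calling sketch for the unified get_request_metadata signature mentioned above (only the argument order comes from this list; the surrounding variable names, the protocol string, and the kwargs are illustrative):

# inside client-side remote_forward_backward (hypothetical variables):
metadata = sequence_manager.get_request_metadata(
    span.peer_id,        # server_id
    "rpc_forward",       # protocol
    span_uids,           # block_uids of the remote span being called
    inputs, prompts,     # positional args forwarded to the remote blocks
    **block_kwargs,      # arbitrary kwargs, e.g. LoRA / IA3 tensors or layer_past
)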
Notable missing functionality
(implementation issue) _RemoteSequentialAutogradFunction can't split sub-batches with kwargs
(implementation issue) InferenceSession only accepts kwargs during its creation (illustrated below)
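To make the second limitation concrete (a sketch assuming the public Petals inference API; the model attribute path and the kwargs shown are made up):

# kwargs can be supplied once, when the session is created...
with model.transformer.h.inference_session(max_length=128, **block_kwargs) as session:
    hidden = session.step(input_embeds)        # every step reuses the kwargs given at creation
    # session.step(next_embeds, **new_kwargs)  # ...but per-step kwargs are not supported yet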
Tests & sanity checks
Sanity checks:
CI tests