create patch for allgather fn for deepspeed stage 3 (microsoft#17855)
### Description
Patch DeepSpeed's `allgather_fn` for ZeRO Stage 3: during ONNX export, the call is short-circuited to return a dummy `Work` handle instead of performing the real collective.
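The core pattern here is a guard around the collective: while tracing for ONNX export, skip the real all-gather and hand back a pre-completed dummy work handle. A minimal, self-contained sketch of that pattern is below; the helper name `make_export_safe` and its arguments are illustrative only, not part of this change.

```python
import torch


def make_export_safe(original_collective, dummy_result):
    """Wrap a collective so it is skipped while tracing for ONNX export (illustrative sketch)."""

    def wrapper(*args, **kwargs):
        if torch.onnx.is_in_onnx_export():
            # No process group is usable while exporting, so return a placeholder
            # that behaves like an already-completed async work handle.
            return dummy_result
        return original_collective(*args, **kwargs)

    return wrapper
```

In the diff below, `_get_ort_compatible_allgather_fn` applies the same idea to `deepspeed.comm.allgather_fn`, returning a `DummyWork` instance during export.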


ajindal1 authored and kleiti committed Mar 22, 2024
1 parent c435429 commit 528d4ef
Showing 1 changed file with 53 additions and 1 deletion.
@@ -7,8 +7,9 @@
import inspect
import warnings
from collections import OrderedDict
from datetime import timedelta
from types import CodeType, FunctionType
from typing import Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import onnx
import torch
@@ -56,6 +57,52 @@ def _setup_zero_stage3_ort_compatible_hooks(self):
return _setup_zero_stage3_ort_compatible_hooks


# This dummy class is needed because several distributed functions are not available during the export step.
class DummyWork(torch.distributed.distributed_c10d.Work):
    def is_completed(self) -> bool:
        return True

    def is_success(self) -> bool:
        return True

    def exception(self) -> Any:
        return None

    def wait(self, timeout: timedelta = timedelta(0)) -> bool:
        return True

    def source_rank(self) -> int:
        return 0

    def _source_rank(self) -> int:
        return 0

    def result(self) -> List[torch.Tensor]:
        return []

    def synchronize(self):
        pass


def _get_ort_compatible_allgather_fn():
    from deepspeed.utils import get_caller_func

    original_allgather_fn = deepspeed.comm.allgather_fn
    output_get_caller_func = get_caller_func()

    # For monkey patching the original function.
    # Original code https://github.com/microsoft/DeepSpeed/blob/604d701e35548e5407b017c088bdc3760832c9e0/deepspeed/comm/comm.py#L315
    def _ort_compatible_allgather_fn_zero_stage3(
        output_tensor, input_tensor, group=None, async_op=False, debug=output_get_caller_func
    ):
        if torch.onnx.is_in_onnx_export():
            return DummyWork()

        return original_allgather_fn(output_tensor, input_tensor, group=group, async_op=async_op, debug=debug)

    return _ort_compatible_allgather_fn_zero_stage3


# Adapted from https://github.com/microsoft/DeepSpeed/blob/e8318634b4313eaad89842cf4322e1762d34ced3/deepspeed/runtime/zero/linear.py#L104
# In the original logic, if bias is None, it becomes a constant after export to ONNX, so the backward op
# complains that the output count is more than needed.
@@ -87,6 +134,7 @@ def collect_code(self, function: Callable):
try:
    # Have to import below explicitly, otherwise it complains about _apply_to_tensors_only not found.
    # The hooks reference functions or classes in that file.
    import deepspeed
    from deepspeed.runtime.zero.parameter_offload import *  # noqa: F403
    from deepspeed.runtime.zero.parameter_offload import DeepSpeedZeRoOffload, _apply_to_tensors_only  # noqa: F401
    from deepspeed.utils import instrument_w_nvtx  # noqa: F401
@@ -109,6 +157,10 @@ def configure_ort_compatible_zero_stage3(debug=False, stats_output_dir="./", sta
debug, stats_output_dir, stats_overwrite
)

# This function will overwrite the original allgather_fn in deepspeed comm to make it ort compatible.
# Only need to define it once
deepspeed.comm.allgather_fn = _get_ort_compatible_allgather_fn()

from deepspeed.runtime.zero.linear import zero3_linear_wrap

if torch.nn.functional.linear is zero3_linear_wrap:
