Optimize zero3 fetch params using all_reduce (microsoft#5420)
* Use all_reduce instead of all_gather to fetch module parameters. This
improves performance by eliminating the concatenation and slicing
steps, which are no longer required.
* Instead, all tensor views are created prior to the collective
(all_reduce), so upon its completion only the parameter status is
updated (see the sketch after this list).
* The behavior is enabled via a new boolean flag under the section
"zero_optimization": { "stage3_use_all_reduce_for_fetch_params": true }
* By default the optimization is not enabled.
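
A minimal standalone sketch of the idea (plain PyTorch rather than the
DeepSpeed implementation; it assumes torch.distributed is already initialized
and that every rank holds an equal-sized shard):

import torch
import torch.distributed as dist

def fetch_with_all_reduce(shard: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    """Reassemble the full flat tensor from equal-sized per-rank shards."""
    numel = shard.numel()
    flat = torch.zeros(numel * world_size, dtype=shard.dtype, device=shard.device)
    # Each rank writes its shard into its own slot; every other slot stays zero.
    flat.narrow(0, rank * numel, numel).copy_(shard)
    # Summing across ranks fills each slot exactly once, so the result equals
    # the concatenation an all_gather of the shards would have produced.
    dist.all_reduce(flat, op=dist.ReduceOp.SUM)
    return flat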

Co-authored-by: Logan Adams <[email protected]>
Co-authored-by: Olatunji Ruwase <[email protected]>
3 people authored and sfc-gh-reyazda committed Jun 10, 2024
1 parent 2c0dcac commit f53895f
Showing 2 changed files with 125 additions and 63 deletions.
7 changes: 7 additions & 0 deletions deepspeed/runtime/zero/config.py
@@ -20,6 +20,7 @@
"stage": [0|1|2],
"stage3_max_live_parameters" : 1000000000,
"stage3_max_reuse_distance" : 1000000000,
"stage3_use_all_reduce_for_fetch_params": [true|false],
"allgather_partitions": [true|false],
"use_multi_rank_bucket_allreduce": [true|false],
"allgather_bucket_size": 500000000,
@@ -234,6 +235,12 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel):
this option is enabled and then saves the fp16 model weights.
"""

use_all_reduce_for_fetch_params: bool = Field(False, alias="stage3_use_all_reduce_for_fetch_params")
"""
Use the all_reduce op when fetching module parameters at stage 3. This improves performance by reducing
the overhead of concatenation and slicing on the host.
"""

stage3_gather_fp16_weights_on_model_save: bool = Field(False,
deprecated=True,
new_param="gather_16bit_weights_on_model_save")
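
For reference, a minimal hand-written config sketch enabling the new path (the
surrounding keys and values are illustrative placeholders, not prescribed
defaults):

# Illustrative DeepSpeed config fragment; only
# stage3_use_all_reduce_for_fetch_params is specific to this change.
ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "zero_optimization": {
        "stage": 3,
        "stage3_use_all_reduce_for_fetch_params": True,
    },
}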
181 changes: 118 additions & 63 deletions deepspeed/runtime/zero/partition_parameters.py
@@ -23,7 +23,7 @@

from deepspeed.utils import groups
import deepspeed
from ..utils import see_memory_usage
from ..utils import see_memory_usage, get_only_unique_item
from deepspeed.runtime.zero.config import DeepSpeedZeroConfig
from deepspeed.runtime.zero.utils import assert_ints_same_as_other_ranks, is_zero_param
from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum
@@ -716,6 +716,31 @@ def wait(self) -> None:
handle.wait()


class AllReduceCoalescedHandle:
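# Completion handle for the async all_reduce used to fetch partitioned parameters:
# wait() blocks on the collective and then flips each param status to AVAILABLE.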

def __init__(self, handle, params: List[Parameter]) -> None:
self.handle = handle
self.params = params
self.complete = False

for param in self.params:
if param.ds_status != ZeroParamStatus.INFLIGHT:
raise RuntimeError(f"expected param {param.ds_summary()} to not be available")

@instrument_w_nvtx
def wait(self) -> None:
if self.complete:
return

instrument_w_nvtx(self.handle.wait)()

for param in self.params:
assert param.ds_status == ZeroParamStatus.INFLIGHT, f"expected param {param.ds_summary()} to be inflight"
param.ds_status = ZeroParamStatus.AVAILABLE

self.complete = True


class QuantizationInfo:
# a placeholder object to store all quant related vars used in handles
def __init__(self) -> None:
@@ -1003,6 +1028,11 @@ def __init__(
if not self.use_all_gather_into_tensor:
logger.info(f"all_gather_into_tensor API is not available in torch {torch.__version__}")

self.use_all_reduce_for_fetch_params = get_config_default(DeepSpeedZeroConfig,
"use_all_reduce_for_fetch_params")
if _ds_config is not None:
self.use_all_reduce_for_fetch_params = _ds_config.zero_config.use_all_reduce_for_fetch_params

def _update_persist_config(self, ds_config):
Init.apply_param_persistence = True
Init.param_persistence_threshold = ds_config.zero_config.param_persistence_threshold
@@ -1250,75 +1280,99 @@ def all_gather_coalesced(params: Iterable[Parameter],
return AllGatherHandle(handle, param, quantization=quant_info)

else:
if not quantize:
dtype_params = defaultdict(list)
for p in params:
dtype_params[p.ds_tensor.dtype].append(p)
handles = []
for dtype, params in dtype_params.items():
handles.append(_all_gather_dtype(dtype, params, world_size, rank_in_group, ds_process_group))
if self.use_all_reduce_for_fetch_params and not quantize and not use_secondary_tensor:
# Use all_reduce instead of all_gather to fetch the module params
flat_buffer_size = sum(p.ds_numel_aligned for p in params)
flat_tensor = torch.zeros(flat_buffer_size,
dtype=get_only_unique_item(p.ds_tensor.dtype for p in params),
device=get_accelerator().current_device_name(),
requires_grad=False)
start_param = 0
for param in params:
param.data = flat_tensor.narrow(0, start_param, param.ds_numel).view(param.ds_shape)
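# The view into the flat buffer is created before the collective is launched,
# so once it completes only the parameter status has to be updated.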
start = start_param + param.ds_tensor.ds_numel * self.get_partition_rank()
flat_tensor.narrow(0, start, param.ds_tensor.ds_numel).copy_(param.ds_tensor)
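# Only this rank's slice of the zero-filled buffer is populated, so the sum
# all_reduce across ranks reassembles every full parameter in place.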

return MultipleAllGatherHandles(handles)
start_param += param.ds_numel

handle = dist.all_reduce(flat_tensor, group=ds_process_group, async_op=True)

return AllReduceCoalescedHandle(handle=handle, params=params)
else:
partition_sz = sum(p.ds_tensor.ds_numel for p in params)
if not quantize:
dtype_params = defaultdict(list)
for p in params:
dtype_params[p.ds_tensor.dtype].append(p)
handles = []
for dtype, params in dtype_params.items():
handles.append(
_all_gather_dtype(dtype, params, world_size, rank_in_group, ds_process_group))

if use_secondary_tensor:
partition_sz = sum(p.ds_tensor.ds_numel * p.ds_secondary_tensor_num_of_groups for p in params)
return MultipleAllGatherHandles(handles)

flat_tensor = torch.empty(partition_sz * world_size,
dtype=torch.int8,
device=get_accelerator().current_device_name(),
requires_grad=False)

if use_secondary_tensor:
if hasattr(params[0].ds_secondary_tensor, "ds_quant_scale"):
quantized_param = instrument_w_nvtx(torch.cat)([
p.ds_secondary_tensor.data.to(get_accelerator().current_device_name()) for p in params
])
scales = instrument_w_nvtx(torch.cat)([
p.ds_secondary_tensor.ds_quant_scale.to(get_accelerator().current_device_name())
for p in params
])
else:
quantized_param, scales = self.quantizer_module.quantize(
instrument_w_nvtx(torch.cat)([
p.ds_secondary_tensor.to(get_accelerator().current_device_name()) for p in params
]))
else:
if hasattr(params[0].ds_tensor, "ds_quant_scale"):
quantized_param = instrument_w_nvtx(torch.cat)(
[p.ds_tensor.data.to(get_accelerator().current_device_name()) for p in params])
scales = instrument_w_nvtx(torch.cat)([
p.ds_tensor.ds_quant_scale.to(get_accelerator().current_device_name()) for p in params
])
partition_sz = sum(p.ds_tensor.ds_numel for p in params)

if use_secondary_tensor:
partition_sz = sum(p.ds_tensor.ds_numel * p.ds_secondary_tensor_num_of_groups
for p in params)

flat_tensor = torch.empty(partition_sz * world_size,
dtype=torch.int8,
device=get_accelerator().current_device_name(),
requires_grad=False)

if use_secondary_tensor:
if hasattr(params[0].ds_secondary_tensor, "ds_quant_scale"):
quantized_param = instrument_w_nvtx(torch.cat)([
p.ds_secondary_tensor.data.to(get_accelerator().current_device_name())
for p in params
])
scales = instrument_w_nvtx(torch.cat)([
p.ds_secondary_tensor.ds_quant_scale.to(get_accelerator().current_device_name())
for p in params
])
else:
quantized_param, scales = self.quantizer_module.quantize(
instrument_w_nvtx(torch.cat)([
p.ds_secondary_tensor.to(get_accelerator().current_device_name())
for p in params
]))
else:
quantized_param, scales = self.quantizer_module.quantize(
instrument_w_nvtx(torch.cat)(
[p.ds_tensor.to(get_accelerator().current_device_name()) for p in params]))
quant_scale_buffer = torch.empty(
scales.numel() * world_size,
dtype=torch.float32,
device=get_accelerator().current_device_name(),
requires_grad=False,
)
handle = _dist_allgather_fn(quantized_param, flat_tensor, ds_process_group)
quant_handle = _dist_allgather_fn(scales, quant_scale_buffer, ds_process_group)
quant_info = QuantizationInfo()
quant_info.quantized_param = flat_tensor
quant_info.backend = self.quantizer_module
quant_info.quant_handle = quant_handle
quant_info.scale_buffer = quant_scale_buffer
quant_info.partition_sz = partition_sz
quant_info.world_size = world_size
return AllGatherCoalescedHandle(
allgather_handle=handle,
params=params,
partitions=None,
world_size=world_size,
use_secondary_tensor=use_secondary_tensor,
quantization=quant_info,
)
if hasattr(params[0].ds_tensor, "ds_quant_scale"):
quantized_param = instrument_w_nvtx(torch.cat)(
[p.ds_tensor.data.to(get_accelerator().current_device_name()) for p in params])
scales = instrument_w_nvtx(torch.cat)([
p.ds_tensor.ds_quant_scale.to(get_accelerator().current_device_name())
for p in params
])
else:
quantized_param, scales = self.quantizer_module.quantize(
instrument_w_nvtx(torch.cat)(
[p.ds_tensor.to(get_accelerator().current_device_name()) for p in params]))
quant_scale_buffer = torch.empty(
scales.numel() * world_size,
dtype=torch.float32,
device=get_accelerator().current_device_name(),
requires_grad=False,
)
handle = _dist_allgather_fn(quantized_param, flat_tensor, ds_process_group)
quant_handle = _dist_allgather_fn(scales, quant_scale_buffer, ds_process_group)
quant_info = QuantizationInfo()
quant_info.quantized_param = flat_tensor
quant_info.backend = self.quantizer_module
quant_info.quant_handle = quant_handle
quant_info.scale_buffer = quant_scale_buffer
quant_info.partition_sz = partition_sz
quant_info.world_size = world_size
return AllGatherCoalescedHandle(
allgather_handle=handle,
params=params,
partitions=None,
world_size=world_size,
use_secondary_tensor=use_secondary_tensor,
quantization=quant_info,
)

def partition(param_list=None, hierarchy=0, has_been_updated=False):
cls = param
@@ -1554,6 +1608,7 @@ def _partition_param(self, param, buffer=None, has_been_updated=False):
param.ds_tensor.ds_numel = partition_size
param.ds_tensor.status = PartitionedParamStatus.AVAILABLE
param.ds_tensor.final_location = final_location
param.ds_numel_aligned = tensor_size
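# Record the aligned parameter numel; the all_reduce fetch path uses it to size its flat buffer.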

start = partition_size * self.get_partition_rank()
end = start + partition_size
