Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize zero3 fetch params using all_reduce #5420

Merged
merged 5 commits into from
May 20, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions deepspeed/runtime/zero/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
"stage": [0|1|2],
"stage3_max_live_parameters" : 1000000000,
"stage3_max_reuse_distance" : 1000000000,
"stage3_use_all_reduce_for_fetch_params": [true|false],
"allgather_partitions": [true|false],
"use_multi_rank_bucket_allreduce": [true|false],
"allgather_bucket_size": 500000000,
Expand Down Expand Up @@ -234,6 +235,12 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel):
this option is enabled and then saves the fp16 model weights.
"""

use_all_reduce_for_fetch_params: bool = Field(False, alias="stage3_use_all_reduce_for_fetch_params")
"""
Use all_reduce op when fetching module parameters at stage3. This improves performance by reducing
the overhead of concatenation and slicing on the host.
"""

stage3_gather_fp16_weights_on_model_save: bool = Field(False,
deprecated=True,
new_param="gather_16bit_weights_on_model_save")
Expand Down
181 changes: 118 additions & 63 deletions deepspeed/runtime/zero/partition_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

from deepspeed.utils import groups
import deepspeed
from ..utils import see_memory_usage
from ..utils import see_memory_usage, get_only_unique_item
from deepspeed.runtime.zero.config import DeepSpeedZeroConfig
from deepspeed.runtime.zero.utils import assert_ints_same_as_other_ranks, is_zero_param
from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum
Expand Down Expand Up @@ -716,6 +716,31 @@ def wait(self) -> None:
handle.wait()


class AllReduceCoalescedHandle:

def __init__(self, handle, params: List[Parameter]) -> None:
self.handle = handle
self.params = params
self.complete = False

for param in self.params:
if param.ds_status != ZeroParamStatus.INFLIGHT:
raise RuntimeError(f"expected param {param.ds_summary()} to not be available")

@instrument_w_nvtx
def wait(self) -> None:
if self.complete:
return

instrument_w_nvtx(self.handle.wait)()

for param in self.params:
assert param.ds_status == ZeroParamStatus.INFLIGHT, f"expected param {param.ds_summary()} to be inflight"
param.ds_status = ZeroParamStatus.AVAILABLE

self.complete = True


class QuantizationInfo:
# a placeholder object to store all quant related vars used in handles
def __init__(self) -> None:
Expand Down Expand Up @@ -1003,6 +1028,11 @@ def __init__(
if not self.use_all_gather_into_tensor:
logger.info(f"all_gather_into_tensor API is not available in torch {torch.__version__}")

self.use_all_reduce_for_fetch_params = get_config_default(DeepSpeedZeroConfig,
"use_all_reduce_for_fetch_params")
if _ds_config is not None:
self.use_all_reduce_for_fetch_params = _ds_config.zero_config.use_all_reduce_for_fetch_params

def _update_persist_config(self, ds_config):
Init.apply_param_persistence = True
Init.param_persistence_threshold = ds_config.zero_config.param_persistence_threshold
Expand Down Expand Up @@ -1250,75 +1280,99 @@ def all_gather_coalesced(params: Iterable[Parameter],
return AllGatherHandle(handle, param, quantization=quant_info)

else:
if not quantize:
dtype_params = defaultdict(list)
for p in params:
dtype_params[p.ds_tensor.dtype].append(p)
handles = []
for dtype, params in dtype_params.items():
handles.append(_all_gather_dtype(dtype, params, world_size, rank_in_group, ds_process_group))
if self.use_all_reduce_for_fetch_params and not quantize and not use_secondary_tensor:
tjruwase marked this conversation as resolved.
Show resolved Hide resolved
# Use all_reduce instead of all_gather to fetch the module params
flat_buffer_size = sum(p.ds_numel_aligned for p in params)
flat_tensor = torch.zeros(flat_buffer_size,
dtype=get_only_unique_item(p.ds_tensor.dtype for p in params),
device=get_accelerator().current_device_name(),
requires_grad=False)
start_param = 0
for param in params:
param.data = flat_tensor.narrow(0, start_param, param.ds_numel).view(param.ds_shape)
start = start_param + param.ds_tensor.ds_numel * self.get_partition_rank()
flat_tensor.narrow(0, start, param.ds_tensor.ds_numel).copy_(param.ds_tensor)

return MultipleAllGatherHandles(handles)
start_param += param.ds_numel

handle = dist.all_reduce(flat_tensor, group=ds_process_group, async_op=True)

return AllReduceCoalescedHandle(handle=handle, params=params)
else:
partition_sz = sum(p.ds_tensor.ds_numel for p in params)
if not quantize:
dtype_params = defaultdict(list)
for p in params:
dtype_params[p.ds_tensor.dtype].append(p)
handles = []
for dtype, params in dtype_params.items():
handles.append(
_all_gather_dtype(dtype, params, world_size, rank_in_group, ds_process_group))

if use_secondary_tensor:
partition_sz = sum(p.ds_tensor.ds_numel * p.ds_secondary_tensor_num_of_groups for p in params)
return MultipleAllGatherHandles(handles)
tjruwase marked this conversation as resolved.
Show resolved Hide resolved

flat_tensor = torch.empty(partition_sz * world_size,
dtype=torch.int8,
device=get_accelerator().current_device_name(),
requires_grad=False)

if use_secondary_tensor:
if hasattr(params[0].ds_secondary_tensor, "ds_quant_scale"):
quantized_param = instrument_w_nvtx(torch.cat)([
p.ds_secondary_tensor.data.to(get_accelerator().current_device_name()) for p in params
])
scales = instrument_w_nvtx(torch.cat)([
p.ds_secondary_tensor.ds_quant_scale.to(get_accelerator().current_device_name())
for p in params
])
else:
quantized_param, scales = self.quantizer_module.quantize(
instrument_w_nvtx(torch.cat)([
p.ds_secondary_tensor.to(get_accelerator().current_device_name()) for p in params
]))
else:
if hasattr(params[0].ds_tensor, "ds_quant_scale"):
quantized_param = instrument_w_nvtx(torch.cat)(
[p.ds_tensor.data.to(get_accelerator().current_device_name()) for p in params])
scales = instrument_w_nvtx(torch.cat)([
p.ds_tensor.ds_quant_scale.to(get_accelerator().current_device_name()) for p in params
])
partition_sz = sum(p.ds_tensor.ds_numel for p in params)

if use_secondary_tensor:
partition_sz = sum(p.ds_tensor.ds_numel * p.ds_secondary_tensor_num_of_groups
for p in params)

flat_tensor = torch.empty(partition_sz * world_size,
dtype=torch.int8,
device=get_accelerator().current_device_name(),
requires_grad=False)

if use_secondary_tensor:
if hasattr(params[0].ds_secondary_tensor, "ds_quant_scale"):
quantized_param = instrument_w_nvtx(torch.cat)([
p.ds_secondary_tensor.data.to(get_accelerator().current_device_name())
for p in params
])
scales = instrument_w_nvtx(torch.cat)([
p.ds_secondary_tensor.ds_quant_scale.to(get_accelerator().current_device_name())
for p in params
])
else:
quantized_param, scales = self.quantizer_module.quantize(
instrument_w_nvtx(torch.cat)([
p.ds_secondary_tensor.to(get_accelerator().current_device_name())
for p in params
]))
else:
quantized_param, scales = self.quantizer_module.quantize(
instrument_w_nvtx(torch.cat)(
[p.ds_tensor.to(get_accelerator().current_device_name()) for p in params]))
quant_scale_buffer = torch.empty(
scales.numel() * world_size,
dtype=torch.float32,
device=get_accelerator().current_device_name(),
requires_grad=False,
)
handle = _dist_allgather_fn(quantized_param, flat_tensor, ds_process_group)
quant_handle = _dist_allgather_fn(scales, quant_scale_buffer, ds_process_group)
quant_info = QuantizationInfo()
quant_info.quantized_param = flat_tensor
quant_info.backend = self.quantizer_module
quant_info.quant_handle = quant_handle
quant_info.scale_buffer = quant_scale_buffer
quant_info.partition_sz = partition_sz
quant_info.world_size = world_size
return AllGatherCoalescedHandle(
allgather_handle=handle,
params=params,
partitions=None,
world_size=world_size,
use_secondary_tensor=use_secondary_tensor,
quantization=quant_info,
)
if hasattr(params[0].ds_tensor, "ds_quant_scale"):
quantized_param = instrument_w_nvtx(torch.cat)(
[p.ds_tensor.data.to(get_accelerator().current_device_name()) for p in params])
scales = instrument_w_nvtx(torch.cat)([
p.ds_tensor.ds_quant_scale.to(get_accelerator().current_device_name())
for p in params
])
else:
quantized_param, scales = self.quantizer_module.quantize(
instrument_w_nvtx(torch.cat)(
[p.ds_tensor.to(get_accelerator().current_device_name()) for p in params]))
quant_scale_buffer = torch.empty(
scales.numel() * world_size,
dtype=torch.float32,
device=get_accelerator().current_device_name(),
requires_grad=False,
)
handle = _dist_allgather_fn(quantized_param, flat_tensor, ds_process_group)
quant_handle = _dist_allgather_fn(scales, quant_scale_buffer, ds_process_group)
quant_info = QuantizationInfo()
quant_info.quantized_param = flat_tensor
quant_info.backend = self.quantizer_module
quant_info.quant_handle = quant_handle
quant_info.scale_buffer = quant_scale_buffer
quant_info.partition_sz = partition_sz
quant_info.world_size = world_size
return AllGatherCoalescedHandle(
allgather_handle=handle,
params=params,
partitions=None,
world_size=world_size,
use_secondary_tensor=use_secondary_tensor,
quantization=quant_info,
)

def partition(param_list=None, hierarchy=0, has_been_updated=False):
cls = param
Expand Down Expand Up @@ -1554,6 +1608,7 @@ def _partition_param(self, param, buffer=None, has_been_updated=False):
param.ds_tensor.ds_numel = partition_size
param.ds_tensor.status = PartitionedParamStatus.AVAILABLE
param.ds_tensor.final_location = final_location
param.ds_numel_aligned = tensor_size

start = partition_size * self.get_partition_rank()
end = start + partition_size
Expand Down
Loading