Skip to content

Commit

Permalink
Merge branch 'master' into baichuan_support
Browse files Browse the repository at this point in the history
  • Loading branch information
mrwyattii authored Dec 18, 2023
2 parents acf40e7 + 83fa673 commit 2b724d9
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 53 deletions.
62 changes: 31 additions & 31 deletions CODEOWNERS
Validating CODEOWNERS rules …
Original file line number Diff line number Diff line change
Expand Up @@ -7,50 +7,50 @@


# top-level repo folders
/.github/ @jeffra @mrwyattii @loadams
/azure/ @jeffra @awan-10
/benchmarks/ @jeffra @awan-10 @mrwyattii @molly-smith
/bin/ @jeffra
/csrc/ @RezaYazdaniAminabadi @awan-10 @jeffra @cmikeh2 @arashb
/deepspeed/ @jeffra
/docker/ @jeffra @awan-10
/docs/ @jeffra @mrwyattii
/examples/ @jeffra @awan-10 @mrwyattii
/op_builder/ @jeffra @RezaYazdaniAminabadi @cmikeh2
/release/ @jeffra @mrwyattii
/requirements/ @jeffra @mrwyattii
/scripts/ @jeffra @awan-10
/tests/ @jeffra @mrwyattii @tjruwase
/.github/ @mrwyattii @loadams
/azure/ @mrwyattii @awan-10
/benchmarks/ @awan-10 @mrwyattii
/bin/ @mrwyattii
/csrc/ @awan-10 @mrwyattii @cmikeh2 @arashb
/deepspeed/ @mrwyattii
/docker/ @mrwyattii @awan-10
/docs/ @mrwyattii
/examples/ @awan-10 @mrwyattii
/op_builder/ @mrwyattii @cmikeh2
/release/ @loadams @mrwyattii
/requirements/ @loadams @mrwyattii
/scripts/ @mrwyattii @awan-10
/tests/ @mrwyattii @tjruwase @loadams

# deepspeed
/deepspeed/autotuning/ @cli99
/deepspeed/autotuning/ @mrwyattii
/deepspeed/checkpoint/ @tjruwase
/deepspeed/comm/ @awan-10
/deepspeed/compression/ @yaozhewei @minjiaz @xiaoxiawu-microsoft @conglongli
/deepspeed/elasticity/ @jeffra @awan-10
/deepspeed/launcher/ @jeffra @awan-10
/deepspeed/module_inject/ @RezaYazdaniAminabadi @jeffra @mrwyattii @awan-10 @cmikeh2 @arashb
/deepspeed/compression/ @minjiaz @xiaoxiawu-microsoft @conglongli
/deepspeed/elasticity/ @mrwyattii @awan-10
/deepspeed/launcher/ @mrwyattii @awan-10
/deepspeed/module_inject/ @mrwyattii @awan-10 @cmikeh2 @arashb
/deepspeed/moe/ @awan-10
/deepspeed/monitor/ @awan-10 @jeffra
/deepspeed/nebula/ @tjruwase @jeffra
/deepspeed/ops/ @RezaYazdaniAminabadi @jeffra @mrwyattii @awan-10 @cmikeh2 @arashb
/deepspeed/monitor/ @awan-10 @mrwyattii
/deepspeed/nebula/ @tjruwase @mrwyattii
/deepspeed/ops/ @mrwyattii @awan-10 @cmikeh2 @arashb
/deepspeed/pipe/ @ShadenSmith @duli2012
/deepspeed/profiling/ @cli99
/deepspeed/utils/ @jeffra @tjruwase @awan-10
/deepspeed/profiling/ @ShijieZZZZ
/deepspeed/utils/ @mrwyattii @tjruwase @awan-10

# inference
/deepspeed/inference/ @RezaYazdaniAminabadi @jeffra @mrwyattii @awan-10 @cmikeh2 @arashb
/deepspeed/model_implementations/ @RezaYazdaniAminabadi @jeffra @mrwyattii @awan-10 @cmikeh2 @arashb
/deepspeed/inference/ @mrwyattii @awan-10 @cmikeh2 @arashb
/deepspeed/model_implementations/ @mrwyattii @awan-10 @cmikeh2 @arashb

# training
/deepspeed/runtime/ @jeffra @tjruwase
/deepspeed/runtime/activation_checkpointing/ @jeffra @tjruwase
/deepspeed/runtime/checkpoint_engine/ @tjruwase @jeffra
/deepspeed/runtime/ @mrwyattii @tjruwase
/deepspeed/runtime/activation_checkpointing/ @mrwyattii @tjruwase
/deepspeed/runtime/checkpoint_engine/ @tjruwase @mrwyattii
/deepspeed/runtime/comm/ @awan-10
/deepspeed/runtime/compression/ @awan-10 @conglongli
/deepspeed/runtime/data_pipeline/ @conglongli
/deepspeed/runtime/fp16/ @jeffra @tjruwase
/deepspeed/runtime/fp16/ @mrwyattii @tjruwase
/deepspeed/runtime/fp16/onebit/ @conglongli @awan-10
/deepspeed/runtime/pipe/ @ShadenSmith @duli2012
/deepspeed/runtime/swap_tensor/ @tjruwase @mrwyattii
/deepspeed/runtime/zero/ @jeffra @tjruwase @samyam @mrwyattii
/deepspeed/runtime/zero/ @tjruwase @mrwyattii
12 changes: 6 additions & 6 deletions deepspeed/inference/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,11 +109,6 @@ def __init__(self, model, config):
assert pkg_version.parse(torch.__version__) >= pkg_version.parse("1.10"), \
"If you want to use cuda graph, please upgrade torch to at least v1.10"

# Check if model passed to engine is loaded w/ meta tensors, in which case
# kernel injection must be enabled.
# NOTE: This check assumes a Hugging Face hierarchy for the device type i.e. module.device.type
self.model_meta_device = self.module.device.type == 'meta' if hasattr(self.module, "device") else False

# convert model to intended dtype
if config.dtype:
self._convert_to_dtype(config)
Expand Down Expand Up @@ -170,7 +165,12 @@ def __init__(self, model, config):
self._apply_injection_policy(config, client_module)

device = get_accelerator().current_device_name()
self.module.to(device)
# NOTE: This check assumes a Hugging Face hierarchy for the device type i.e. module.device.type
is_meta_device = hasattr(self.module, "device") and self.module.device.type == 'meta'
if is_meta_device:
self.module.to_empty(device=device)
else:
self.module.to(device)

if config.tensor_parallel.tp_size > 1:
_rng_state = get_accelerator().get_rng_state().to(get_accelerator().current_device_name())
Expand Down
34 changes: 23 additions & 11 deletions deepspeed/module_inject/auto_tp.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,16 @@
from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list


def move(tensor, device):
    """Return a new tensor on ``device`` corresponding to ``tensor``.

    Args:
        tensor: The source tensor; may be a meta tensor.
        device: The target device for the returned tensor.

    Returns:
        For a meta tensor, an uninitialized tensor created with
        ``torch.empty_like`` on ``device``. Otherwise, the result of
        ``tensor.to(device, copy=True)``.
    """
    if not tensor.is_meta:
        # A fresh tensor lets the memory of the full tensor be freed (for
        # example after a split); this used to be done by calling clone().
        # copy=True is used instead of clone() to cover the cpu --> cpu case:
        # without it, to() would not create a new copy for a view of the full
        # tensor, and the full tensor would not be de-referenced.
        return tensor.to(device, copy=True)
    return torch.empty_like(tensor, device=device)


class ReplaceWithTensorSlicing:

def __init__(self, mp_group=None, mp_size=1, out_dim=1, in_dim=0):
Expand Down Expand Up @@ -318,17 +328,18 @@ def _replace(self, child, name, conv_linear_layer):
data = child.weight.data.split(get_shard_size_list(
weight_shape[0] if self.conv_linear_layer else weight_shape[1], self.mp_size),
dim=1)
data_dc = data[mp_replace.gpu_index].to(get_accelerator().current_device_name()).clone().detach()
data_dc = move(data[mp_replace.gpu_index], get_accelerator().current_device_name()).detach()
del data

setattr(child, "replaced", True)
if name == "lm_head" or name == 'embed_out':
return LmHeadLinearAllreduce(
torch.nn.parameter.Parameter(data_dc, requires_grad=False), dist.get_rank(), dist.get_world_size(),
child.bias if child.bias is None else torch.nn.parameter.Parameter(
child.bias.to(get_accelerator().current_device_name())), self.mp_group)
move(child.bias,
get_accelerator().current_device_name())), self.mp_group)
return LinearAllreduce(torch.nn.parameter.Parameter(data_dc, requires_grad=False), child.bias if child.bias is None else \
torch.nn.parameter.Parameter(child.bias.to(get_accelerator().current_device_name())), self.mp_group)
torch.nn.parameter.Parameter(move(child.bias, get_accelerator().current_device_name())), self.mp_group)
else:

# if conv_linear_layer [weight_shape[1], weight_shape[0] // mp_size]
Expand All @@ -340,30 +351,31 @@ def _replace(self, child, name, conv_linear_layer):
#for detecting fused type
module_str = str(self.module).strip()
# The copy is a regular copy; the shapes of dst and src are the same.
data_dc = prepare_tp_fused_qkvw(module_str, child.weight.data, self.mp_size, mp_replace.gpu_index)
data_dc = move(
prepare_tp_fused_qkvw(module_str, child.weight.data, self.mp_size, mp_replace.gpu_index),
get_accelerator().current_device_name())

bias_data_dc = None if child.bias is None else prepare_tp_fused_qkvw(
module_str, child.bias.data, self.mp_size, mp_replace.gpu_index).to(
get_accelerator().current_device_name())
bias_data_dc = None if child.bias is None else move(
prepare_tp_fused_qkvw(module_str, child.bias.data, self.mp_size, mp_replace.gpu_index),
get_accelerator().current_device_name())
else:
data = child.weight.data.split(get_shard_size_list(weight_shape[0], self.mp_size),
dim=1 if self.conv_linear_layer else 0)
data_dc = data[mp_replace.gpu_index].to(get_accelerator().current_device_name()).clone().detach()
data_dc = move(data[mp_replace.gpu_index], get_accelerator().current_device_name()).detach()
del data

if child.bias is not None:
bias_data = child.bias.data.split(get_shard_size_list(
weight_shape[1] if self.conv_linear_layer else weight_shape[0], self.mp_size),
dim=0)
bias_data = bias_data[mp_replace.gpu_index].to(get_accelerator().current_device_name())
bias_data = move(bias_data[mp_replace.gpu_index], get_accelerator().current_device_name())
bias_data_dc = torch.nn.parameter.Parameter(bias_data, requires_grad=False)
del bias_data
else:
bias_data_dc = None

setattr(child, "replaced", True)
return LinearLayer(weight=torch.nn.parameter.Parameter(data_dc.to(get_accelerator().current_device_name()), requires_grad=False), \
bias=bias_data_dc)
return LinearLayer(weight=torch.nn.parameter.Parameter(data_dc, requires_grad=False), bias=bias_data_dc)

def _slice_embedding(self, child, name, conv_linear_layer):
if getattr(child, "replaced", False) == True:
Expand Down
25 changes: 21 additions & 4 deletions deepspeed/runtime/pipe/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,10 @@ def __init__(self, has_bool_tensors=False, *super_args, **super_kwargs):

self.first_output_send = True
self.first_gradient_send = True
self.pipe_partition_input_meta_cache = None
self.pipe_partition_output_meta_cache = None
self.pipe_partition_grad_meta_cache = None
self.grad_partition_grad_layer_meta_cache = None

#stores the loss for the current micro batch being processed
self.loss = torch.tensor(0.0).to(self.device)
Expand Down Expand Up @@ -309,6 +313,11 @@ def reset_activation_shape(self):
self.grad_layer = None
self.meta_buffer = None

self.pipe_partition_input_meta_cache = None
self.pipe_partition_output_meta_cache = None
self.pipe_partition_grad_meta_cache = None
self.grad_partition_grad_layer_meta_cache = None

def train_batch(self, data_iter=None):
"""Progress the pipeline to train the next batch of data. The engine will ingest
``self.train_batch_size()`` total samples collectively across all workers.
Expand Down Expand Up @@ -641,7 +650,9 @@ def _exec_forward_pass(self, buffer_id):

# collect the partitioned input from the previous stage
if self.is_pipe_partitioned and not self.is_first_stage():
part_input = PartitionedTensor.from_meta(meta=inputs[0],
if self.pipe_partition_input_meta_cache is None:
self.pipe_partition_input_meta_cache = inputs[0].to('cpu')
part_input = PartitionedTensor.from_meta(meta=self.pipe_partition_input_meta_cache,
local_part=inputs[1],
group=self.grid.get_slice_parallel_group())

Expand Down Expand Up @@ -732,7 +743,9 @@ def _exec_backward_pass(self, buffer_id):
# careful to also restore the computational graph of the tensors we partitioned.
if self.is_pipe_partitioned:
if self.is_grad_partitioned:
part_output = PartitionedTensor.from_meta(meta=outputs[0],
if self.pipe_partition_output_meta_cache is None:
self.pipe_partition_output_meta_cache = outputs[0].to('cpu')
part_output = PartitionedTensor.from_meta(meta=self.pipe_partition_output_meta_cache,
local_part=outputs[1],
group=self.grid.get_slice_parallel_group())
self.pipe_buffers['output_tensors'][buffer_id].data = part_output.full()
Expand All @@ -745,7 +758,9 @@ def _exec_backward_pass(self, buffer_id):
grad_tensors = self.grad_layer
if self.is_grad_partitioned:
#print(f'RANK={self.global_rank} BEFORE-BWD restoring grad={self.grad_layer[0].size()} {self.grad_layer[1].size()}')
part_grad = PartitionedTensor.from_meta(meta=self.grad_layer[0],
if self.grad_partition_grad_layer_meta_cache is None:
self.grad_partition_grad_layer_meta_cache = self.grad_layer[0].to('cpu')
part_grad = PartitionedTensor.from_meta(meta=self.grad_partition_grad_layer_meta_cache,
local_part=self.grad_layer[1],
group=self.grid.get_slice_parallel_group())
grad_tensors = (part_grad.full(), *grad_tensors[2:])
Expand Down Expand Up @@ -1088,7 +1103,9 @@ def _exec_recv_grads(self, buffer_id):
# XXX these shapes are hardcoded for Megatron
# Restore partitioned output if it was partitioned and we are sending full gradients
if self.is_pipe_partitioned and not self.is_grad_partitioned:
part_output = PartitionedTensor.from_meta(meta=outputs[0],
if self.pipe_partition_grad_meta_cache is None:
self.pipe_partition_grad_meta_cache = outputs[0].to('cpu')
part_output = PartitionedTensor.from_meta(meta=self.pipe_partition_grad_meta_cache,
local_part=outputs[1],
group=self.grid.get_slice_parallel_group())
outputs[0].data = part_output.full()
Expand Down
2 changes: 1 addition & 1 deletion version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.12.5
0.12.6

0 comments on commit 2b724d9

Please sign in to comment.