Skip to content

Commit

Permalink
Add arctic model support by adding w2 to all_reduce
Browse files Browse the repository at this point in the history
Signed-off-by: Daniel Huang <[email protected]>
  • Loading branch information
pi314ever committed Dec 11, 2024
1 parent 074d5c6 commit 5b72093
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion deepspeed/module_inject/auto_tp.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,11 +346,15 @@ def _replace(self, child, name, conv_linear_layer):
weight, bias = shard_value_with_share_qk(child.weight.data, child.bias, dist.get_rank(),
dist.get_world_size(), False)
return LinearAllreduce(weight, bias, self.mp_group)
# For Arctic model, bypass to all_reduce replacement for w2 weights
arctic_w2_all_reduce_linear = False
if 'Arctic' in str(self.module) and 'w2' in name:
arctic_w2_all_reduce_linear = True
# For MLP including chunk layer.
if 'gate_up_proj' in name or ('dense_h_to_4h' in name and 'GLM' in str(self.module)):
weight, bias = shard_chunk_mlp(child.weight.data, child.bias, dist.get_rank(), dist.get_world_size())
return LinearLayer(weight=weight, bias=bias)
if name in self.all_reduce_linears:
if name in self.all_reduce_linears or arctic_w2_all_reduce_linear:
# if conv_linear_layer [weight_shape[1], weight_shape[0] // mp_size]
# else [weight_shape[0], weight_shape[1] // mp_size]

Expand Down

0 comments on commit 5b72093

Please sign in to comment.