Skip to content

Commit

Permalink
Merge branch 'master' into gma/add_autotp_workflow
Browse files Browse the repository at this point in the history
  • Loading branch information
tjruwase authored Feb 6, 2024
2 parents 552aa5b + c3cfe96 commit 0961d5f
Show file tree
Hide file tree
Showing 20 changed files with 872 additions and 192 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/nv-torch-latest-cpu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ jobs:

- name: Install pytorch
run: |
pip install torch==1.12.0+cpu torchvision==0.13.0+cpu torchaudio==0.12.0 --extra-index-url https://download.pytorch.org/whl/cpu
pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cpu
python -c "import torch; print('torch:', torch.__version__, torch)"
python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
Expand All @@ -46,5 +46,5 @@ jobs:
run: |
unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch
cd tests
TRANSFORMERS_CACHE=/tmp/transformers_cache/ pytest $PYTEST_OPTS -n 4 unit/ --torch_ver="1.12"
TRANSFORMERS_CACHE=/tmp/transformers_cache/ pytest $PYTEST_OPTS -m 'sequential' unit/ --torch_ver="1.12"
TRANSFORMERS_CACHE=/tmp/transformers_cache/ pytest $PYTEST_OPTS -n 4 unit/ --torch_ver="2.2"
TRANSFORMERS_CACHE=/tmp/transformers_cache/ pytest $PYTEST_OPTS -m 'sequential' unit/ --torch_ver="2.2"
1 change: 1 addition & 0 deletions deepspeed/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@

from .runtime import zero
from .runtime import DeepSpeedOptimizer, ZeROOptimizer
from .runtime.compiler import is_compile_supported

from .pipe import PipelineModule

Expand Down
24 changes: 24 additions & 0 deletions deepspeed/comm/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .utils import *
from .backend import *
from .comm import *
from ..runtime import compiler
import os

DS_COMM_ALL_GATHER_OFF = False
Expand Down Expand Up @@ -120,6 +121,7 @@ def __init__(self, backend, timeout, init_method, rank=-1, world_size=-1, name='
self.init_process_group(backend, timeout, init_method, rank, world_size)

@classmethod
@compiler.disable
def get_all_gather_function(self):
if hasattr(torch.distributed, "all_gather_into_tensor"):
return torch.distributed.all_gather_into_tensor
Expand All @@ -128,6 +130,7 @@ def get_all_gather_function(self):
return None

@classmethod
@compiler.disable
def get_reduce_scatter_function(self):
if hasattr(torch.distributed, "reduce_scatter_tensor"):
return torch.distributed.reduce_scatter_tensor
Expand All @@ -150,14 +153,17 @@ def init_process_group(self, backend, timeout, init_method, rank, world_size):
world_size=world_size)
self.using_mpi = torch.distributed.get_backend() == 'mpi'

@compiler.disable
def all_reduce(self, tensor, op=torch.distributed.ReduceOp.SUM, group=None, async_op=False):
op = self._reduce_op(op)
return torch.distributed.all_reduce(tensor=tensor, op=op, group=group, async_op=async_op)

@compiler.disable
def inference_all_reduce(self, tensor, op=torch.distributed.ReduceOp.SUM, group=None, async_op=False):
op = self._reduce_op(op)
return torch.distributed.all_reduce(tensor=tensor, op=op, group=group, async_op=async_op)

@compiler.disable
def all_reduce_coalesced(self, tensors, op=torch.distributed.ReduceOp.SUM, group=None, async_op=False):
""" proxy func to torch.distributed.all_reduce_coalesced,
which is included in PyTorch 1.13 and above
Expand All @@ -168,13 +174,15 @@ def all_reduce_coalesced(self, tensors, op=torch.distributed.ReduceOp.SUM, group
op = self._reduce_op(op)
return torch.distributed.all_reduce_coalesced(tensors=tensors, op=op, group=group, async_op=async_op)

@compiler.disable
def reduce(self, tensor, dst, op=ReduceOp.SUM, group=None, async_op=False):
if DS_COMM_REDUCE_OFF:
if int(os.getenv('RANK', '0')) == 0:
utils.logger.warning("REDUCE is OFF")
return Noop()
return torch.distributed.reduce(tensor=tensor, dst=dst, op=self._reduce_op(op), group=group, async_op=async_op)

@compiler.disable
def reduce_scatter(self, output, input_list, op=ReduceOp.SUM, group=None, async_op=False):
if DS_COMM_REDUCE_SCATTER_OFF:
if int(os.getenv('RANK', '0')) == 0:
Expand All @@ -187,6 +195,7 @@ def reduce_scatter(self, output, input_list, op=ReduceOp.SUM, group=None, async_
group=group,
async_op=async_op)

@compiler.disable
def broadcast(self, tensor, src, group=None, async_op=False):
if DS_COMM_BROADCAST_OFF:
if int(os.getenv('RANK', '0')) == 0:
Expand All @@ -195,6 +204,7 @@ def broadcast(self, tensor, src, group=None, async_op=False):
else:
return torch.distributed.broadcast(tensor=tensor, src=src, group=group, async_op=async_op)

@compiler.disable
def all_gather(self, tensor_list, tensor, group=None, async_op=False):
if DS_COMM_ALL_GATHER_OFF:
if int(os.getenv('RANK', '0')) == 0:
Expand All @@ -203,13 +213,15 @@ def all_gather(self, tensor_list, tensor, group=None, async_op=False):
else:
return torch.distributed.all_gather(tensor_list=tensor_list, tensor=tensor, group=group, async_op=async_op)

@compiler.disable
def all_gather_into_tensor(self, output_tensor, input_tensor, group=None, async_op=False):
if self.has_all_gather_into_tensor():
return self.all_gather_function(output_tensor=output_tensor,
input_tensor=input_tensor,
group=group,
async_op=async_op)

@compiler.disable
def all_gather_base(self, output_tensor, input_tensor, group=None, async_op=False):
if DS_COMM_ALL_GATHER_OFF:
if int(os.getenv('RANK', '0')) == 0:
Expand All @@ -227,6 +239,7 @@ def all_gather_base(self, output_tensor, input_tensor, group=None, async_op=Fals
"please consider upgrading your pytorch installation.")
pass

@compiler.disable
def all_gather_coalesced(self, output_tensors, input_tensors, group=None, async_op=False):
""""""
assert len(output_tensors) == len(input_tensors), ""
Expand All @@ -250,6 +263,7 @@ def all_gather_coalesced(self, output_tensors, input_tensors, group=None, async_
else:
reqs[-1].wait()

@compiler.disable
def reduce_scatter_tensor(self, output_tensor, input_tensor, op=ReduceOp.SUM, group=None, async_op=False):
if self.has_reduce_scatter_tensor():
return self.reduce_scatter_function(output_tensor,
Expand All @@ -263,6 +277,7 @@ def reduce_scatter_tensor(self, output_tensor, input_tensor, op=ReduceOp.SUM, gr
"please consider upgrading your pytorch installation.")
pass

@compiler.disable
def all_to_all_single(self,
output,
input,
Expand All @@ -277,40 +292,49 @@ def all_to_all_single(self,
group=group,
async_op=async_op)

@compiler.disable
def all_to_all(self, output_tensor_list, input_tensor_list, group=None, async_op=False):
return torch.distributed.all_to_all(output_tensor_list, input_tensor_list, group=group, async_op=async_op)

@compiler.disable
def send(self, tensor, dst, group=None, tag=0):
return torch.distributed.send(tensor=tensor, dst=dst, group=group, tag=tag)

@compiler.disable
def recv(self, tensor, src=None, group=None, tag=0):
return torch.distributed.recv(tensor=tensor, src=src, group=group, tag=tag)

@compiler.disable
def isend(self, tensor, dst, group=None, tag=0):
return torch.distributed.isend(tensor=tensor, dst=dst, group=group, tag=tag)

@compiler.disable
def irecv(self, tensor, src=None, group=None, tag=0):
return torch.distributed.irecv(tensor=tensor, src=src, group=group, tag=tag)

@compiler.disable
def gather(self, tensor, gather_list=None, dst=0, group=None, async_op=False):
return torch.distributed.gather(tensor=tensor,
gather_list=gather_list,
dst=dst,
group=group,
async_op=async_op)

@compiler.disable
def scatter(self, tensor, scatter_list=None, src=0, group=None, async_op=False):
return torch.distributed.scatter(tensor=tensor,
scatter_list=scatter_list,
src=src,
group=group,
async_op=async_op)

@compiler.disable
def barrier(self, group=torch.distributed.GroupMember.WORLD, async_op=False, device_ids=None):
if group is None:
group = torch.distributed.GroupMember.WORLD
return torch.distributed.barrier(group=group, async_op=async_op, device_ids=device_ids)

@compiler.disable
def monitored_barrier(self, group=torch.distributed.GroupMember.WORLD, timeout=None, wait_all_ranks=False):
if group is None:
group = torch.distributed.GroupMember.WORLD
Expand Down
14 changes: 10 additions & 4 deletions deepspeed/module_inject/containers/vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,17 @@ def __init__(self):
super().__init__()
try:
import diffusers
if hasattr(diffusers.models.autoencoders.vae, "AutoencoderKL"):
self._orig_layer_class = diffusers.models.autoencoders.vae.AutoencoderKL
else:
# Diffusers >= 0.12.0 changes location of AutoencoderKL
if hasattr(diffusers.models, "autoencoders"):
# Diffusers >= 0.25.0
# Changes location to 'autoencoders' directory
self._orig_layer_class = diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL
elif hasattr(diffusers.models.vae, "AutoencoderKL"):
# Diffusers < 0.12.0
self._orig_layer_class = diffusers.models.vae.AutoencoderKL
else:
# Diffusers >= 0.12.0 & < 0.25.0
# Changes location of AutoencoderKL
self._orig_layer_class = diffusers.models.autoencoder_kl.AutoencoderKL
except ImportError:
self._orig_layer_class = None

Expand Down
55 changes: 29 additions & 26 deletions deepspeed/moe/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,23 @@

# DeepSpeed Team

import torch
from typing import Optional, Tuple

from deepspeed.utils import log_dist
import torch
from torch import nn
from torch.nn import functional as F

from deepspeed.utils import groups
from .sharded_moe import MOELayer, TopKGate
from deepspeed.utils import groups, log_dist
from .experts import Experts
import typing
from .sharded_moe import MOELayer, TopKGate


class MoE(torch.nn.Module):
class MoE(nn.Module):
"""Initialize an MoE layer.
Arguments:
hidden_size (int): the hidden dimension of the model, importantly this is also the input and output dimension.
expert (torch.nn.Module): the torch module that defines the expert (e.g., MLP, torch.linear).
expert (nn.Module): the torch module that defines the expert (e.g., MLP, torch.linear).
num_experts (int, optional): default=1, the total number of experts per layer.
ep_size (int, optional): default=1, number of ranks in the expert parallel world or group.
k (int, optional): default=1, top-k gating value, only supports k=1 or k=2.
Expand All @@ -34,20 +35,20 @@ class MoE(torch.nn.Module):
"""

def __init__(self,
hidden_size,
expert,
num_experts=1,
ep_size=1,
k=1,
capacity_factor=1.,
eval_capacity_factor=1.,
min_capacity=4,
use_residual=False,
noisy_gate_policy: typing.Optional[str] = None,
hidden_size: int,
expert: nn.Module,
num_experts: int = 1,
ep_size: int = 1,
k: int = 1,
capacity_factor: float = 1.0,
eval_capacity_factor: float = 1.0,
min_capacity: int = 4,
use_residual: bool = False,
noisy_gate_policy: Optional[str] = None,
drop_tokens: bool = True,
use_rts=True,
use_rts: bool = True,
use_tutel: bool = False,
enable_expert_tensor_parallelism: bool = False):
enable_expert_tensor_parallelism: bool = False) -> None:

super(MoE, self).__init__()

Expand Down Expand Up @@ -77,12 +78,12 @@ def __init__(self,
if self.use_residual:
self.mlp = expert
# coefficient is used for weighted sum of the output of expert and mlp
self.coefficient = torch.nn.Linear(hidden_size, 2)
self.coefficient = nn.Linear(hidden_size, 2)

def set_deepspeed_parallelism(self, use_data_before_expert_parallel_=False):
def set_deepspeed_parallelism(self, use_data_before_expert_parallel_: bool = False) -> None:
self._create_process_groups(use_data_before_expert_parallel_=use_data_before_expert_parallel_)

def _create_process_groups(self, use_data_before_expert_parallel_=False):
def _create_process_groups(self, use_data_before_expert_parallel_: bool = False) -> None:
# Create process group for a layer if needed
if self.expert_group_name not in groups._get_expert_parallel_group_dict():
print(f"No existing process group found, creating a new group named: {self.expert_group_name}")
Expand All @@ -98,7 +99,9 @@ def _create_process_groups(self, use_data_before_expert_parallel_=False):
# Set the group handle for the MOELayer (deepspeed_moe) object
self.deepspeed_moe._set_ep_group(groups._get_expert_parallel_group(self.expert_group_name))

def forward(self, hidden_states, used_token=None):
def forward(self,
hidden_states: torch.Tensor,
used_token: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
""" MoE forward
Arguments:
Expand All @@ -112,15 +115,15 @@ def forward(self, hidden_states, used_token=None):
* l_aux (Tensor): gate loss value
* exp_counts (int): expert count
* exp_counts (Tensor): expert count
"""
output = self.deepspeed_moe(hidden_states, used_token)
if self.use_residual:
# Residual MoE
output_mlp = self.mlp(hidden_states)
if type(output_mlp) is tuple:
if isinstance(output_mlp, tuple):
output_mlp = output_mlp[0] # Ignore the bias term for now
coef = self.coefficient(hidden_states)
coef = torch.nn.functional.softmax(coef, dim=-1)
coef = F.softmax(coef, dim=-1)
output = output * coef[..., 0:1] + output_mlp * coef[..., 1:]
return output, self.deepspeed_moe.l_aux, self.deepspeed_moe.exp_counts
29 changes: 12 additions & 17 deletions deepspeed/moe/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

# DeepSpeed Team

from collections import defaultdict
from typing import Any, Dict, List, Set, Tuple, Union, cast

import torch
Expand Down Expand Up @@ -68,10 +69,9 @@ def split_params_grads_into_shared_and_expert_params(
return shared_grads, expert_grads


def split_params_into_different_moe_groups_for_optimizer(param_groups: Union[Dict[str, Any], Tuple[Dict[str, Any],
...],
List[Dict[str, Any]]],
max_group_size: int = 178956971) -> List[Dict[str, Any]]:
def split_params_into_different_moe_groups_for_optimizer(
param_groups: Union[Dict[str, Any], Tuple[Dict[str, Any], ...], List[Dict[str, Any]]],
max_group_size: Union[int, float] = 178956971) -> List[Dict[str, Any]]:
"""Split parameters into different MoE groups for optimizer
Args:
Expand All @@ -97,18 +97,15 @@ def split_params_into_different_moe_groups_for_optimizer(param_groups: Union[Dic
data_parallel_group_names.add(param.group_name)

# Create the param MoE groups, leave param assign to next step
group_moe: Dict[str, Dict[str, Dict[str, Any]]] = {}
group_moe: Dict[str, Dict[str, Dict[str, Any]]] = defaultdict(lambda: defaultdict(dict))
for param_group in param_groups:
group_moe[param_group['name']] = {}
for key in data_parallel_group_names:
group_moe[param_group['name']][key] = {}
group_moe[param_group['name']][key]['name'] = key
group_moe[param_group['name']][key]['moe'] = True

for ori_key in param_group.keys():
if ori_key != 'name':
group_moe[param_group['name']][key][ori_key] = ([]
if ori_key == 'params' else param_group[ori_key])
group_moe[param_group['name']][key] = {
**param_group,
'name': key,
'moe': True,
'params': [],
}

# Assign param
for param_group in param_groups:
Expand Down Expand Up @@ -142,9 +139,7 @@ def split_params_into_different_moe_groups_for_optimizer(param_groups: Union[Dic
all_groups.append(cur_group)

for group in all_groups:
new_dict = dict(param_group)
new_dict['params'] = group
param_groups.append(new_dict)
param_groups.append({**param_group, 'params': group})
else:
for moe_group in group_moe.values():
for param_group in moe_group.values():
Expand Down
Loading

0 comments on commit 0961d5f

Please sign in to comment.