
Commit

etp
Signed-off-by: Alexandros Koumparoulis <[email protected]>
akoumpa committed Jan 20, 2025
1 parent 0075ed0 commit 6dd3e81
Showing 4 changed files with 54 additions and 3 deletions.
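
For orientation, expert tensor parallelism (ETP) lets the MoE expert layers use a tensor-parallel degree that differs from the one applied to dense/attention layers. A hedged, illustrative bit of arithmetic (not NeMo code; all values hypothetical) showing what the new expert_tensor_parallel_size controls:

```python
# Illustrative only: how expert weights are sharded once ETP is decoupled from TP.
world_size = 16
tp = 4    # tensor-parallel degree for dense/attention layers
ep = 4    # expert model parallel degree: experts are spread over 4 EP ranks
etp = 2   # NEW: tensor-parallel degree inside each expert's weights

dense_shards_per_weight = tp          # each dense weight is split 4 ways
expert_weight_set_span = etp * ep     # the full set of expert weights spans 8 GPUs
print(dense_shards_per_weight, expert_weight_set_span)  # 4 8
```

The diff below threads this one setting from MegatronStrategy through ParallelismConfig, AppState, and NeMo's rank bookkeeping down to Megatron-Core initialization.
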
2 changes: 2 additions & 0 deletions nemo/lightning/_strategy_lib.py
@@ -121,11 +121,13 @@ def init_model_parallel(model: Optional[nn.Module] = None) -> None:
encoder_tensor_model_parallel_size=app_state.encoder_tensor_model_parallel_size,
context_parallel_size=app_state.context_parallel_size,
expert_model_parallel_size=app_state.expert_model_parallel_size,
expert_tensor_parallel_size=app_state.expert_tensor_parallel_size,
)

# assert that fake tp and pp rank match after model parallel init
assert app_state.tensor_model_parallel_rank == parallel_state.get_tensor_model_parallel_rank()
assert app_state.pipeline_model_parallel_rank == parallel_state.get_pipeline_model_parallel_rank()
assert app_state.expert_tensor_parallel_rank == parallel_state.get_expert_tensor_parallel_rank()

app_state.tensor_model_parallel_group = parallel_state.get_tensor_model_parallel_group()
app_state.data_parallel_group = parallel_state.get_data_parallel_group()
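
A minimal sketch of the call this hunk extends, assuming Megatron-Core's parallel_state API (the same module the asserts above query) and an already-initialized torch.distributed; init_parallel_from_app_state is a hypothetical wrapper, not the NeMo function:

```python
from megatron.core import parallel_state

def init_parallel_from_app_state(app_state) -> None:
    # Forward the new ETP degree to Megatron-Core when building process groups.
    parallel_state.initialize_model_parallel(
        tensor_model_parallel_size=app_state.tensor_model_parallel_size,
        pipeline_model_parallel_size=app_state.pipeline_model_parallel_size,
        context_parallel_size=app_state.context_parallel_size,
        expert_model_parallel_size=app_state.expert_model_parallel_size,
        expert_tensor_parallel_size=app_state.expert_tensor_parallel_size,  # new field
    )
    # The ETP rank NeMo precomputed must agree with Megatron-Core's process-group view.
    assert app_state.expert_tensor_parallel_rank == parallel_state.get_expert_tensor_parallel_rank()
```
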
9 changes: 9 additions & 0 deletions nemo/lightning/megatron_init.py
@@ -130,6 +130,7 @@ def initialize_model_parallel_for_nemo(
app_state.tensor_model_parallel_rank,
app_state.pipeline_model_parallel_rank,
app_state.expert_model_parallel_rank,
app_state.expert_tensor_parallel_rank,
app_state.model_parallel_size,
app_state.data_parallel_size,
app_state.pipeline_model_parallel_split_rank,
@@ -482,6 +483,13 @@ def generator_wrapper(group_type, is_expert=False, **kwargs):
if rank in ranks:
expert_model_parallel_rank = list(ranks).index(rank)

# ETP
expert_tensor_parallel_rank = 0
if expert_tensor_parallel_size_ is not None and expert_tensor_parallel_size_ > 1:
for ranks in generator_wrapper('tp-ep', is_expert=True):
if rank in ranks:
expert_tensor_parallel_rank = list(ranks).index(rank)

# Build the pipeline model-parallel groups and embedding groups
# (first and last rank in each pipeline model-parallel group).
all_pipeline_model_parallel_group_ranks = []
@@ -520,6 +528,7 @@ def generator_wrapper(group_type, is_expert=False, **kwargs):
tensor_model_parallel_rank,
pipeline_model_parallel_rank,
expert_model_parallel_rank,
expert_tensor_parallel_rank,
model_parallel_size,
data_parallel_size,
pipeline_model_parallel_split_rank_,
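
The "# ETP" hunk above reuses the existing pattern: walk the rank groups produced for one parallelism dimension and take this rank's position inside its group. A self-contained sketch of that pattern (rank_in_dimension is a hypothetical helper, not NeMo code):

```python
from typing import Iterable, Sequence

def rank_in_dimension(rank: int, groups: Iterable[Sequence[int]]) -> int:
    """Return the position of `rank` within the group that contains it, else 0."""
    for ranks in groups:
        if rank in ranks:
            return list(ranks).index(rank)
    return 0

# e.g. four GPUs split into two 'tp-ep'-style groups of two:
assert rank_in_dimension(3, [[0, 1], [2, 3]]) == 1  # rank 3 is second in its group
assert rank_in_dimension(0, [[0, 1], [2, 3]]) == 0
```
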
6 changes: 6 additions & 0 deletions nemo/lightning/pytorch/strategies/megatron_strategy.py
@@ -100,6 +100,7 @@ class ParallelismConfig:
encoder_tensor_model_parallel_size: int = 0
encoder_pipeline_model_parallel_size: int = 0
use_te_rng_tracker: bool = False
expert_tensor_parallel_size: int = None


class MegatronStrategy(DDPStrategy, io.IOMixin):
@@ -188,6 +189,7 @@ def __init__(
sequence_parallel: bool = False,
expert_model_parallel_size: int = 1,
moe_extended_tp: bool = False,
expert_tensor_parallel_size: int = None,
encoder_tensor_model_parallel_size: Optional[int] = 0,
encoder_pipeline_model_parallel_size: Optional[int] = 0,
data_sampler: Optional["DataSampler"] = None,
@@ -237,6 +239,7 @@ def __init__(
)
self.context_parallel_size = context_parallel_size
self.expert_model_parallel_size = expert_model_parallel_size
self.expert_tensor_parallel_size = expert_tensor_parallel_size
self.moe_extended_tp = moe_extended_tp
self.virtual_pipeline_model_parallel_size = virtual_pipeline_model_parallel_size
self.sequence_parallel = sequence_parallel
@@ -302,6 +305,8 @@ def connect(self, model: pl.LightningModule) -> None:
assert not 'is_hf_model' in model.__dict__, "Cannot use HFAutoModelForCausalLM with MegatronParallel"

dtype_config = getattr(self._precision_plugin, "dtype_config", None)
if self.pipeline_dtype is None and dtype_config:
self.pipeline_dtype = dtype_config.pipeline_dtype

_maybe_mcore_config = _strategy_lib.set_model_parallel_attributes(model, self.parallelism)
if _maybe_mcore_config:
@@ -899,6 +904,7 @@ def parallelism(self) -> ParallelismConfig:
context_parallel_size=self.context_parallel_size,
sequence_parallel=self.sequence_parallel,
expert_model_parallel_size=self.expert_model_parallel_size,
expert_tensor_parallel_size=self.expert_tensor_parallel_size,
moe_extended_tp=self.moe_extended_tp,
encoder_tensor_model_parallel_size=self.encoder_tensor_model_parallel_size,
encoder_pipeline_model_parallel_size=self.encoder_pipeline_model_parallel_size,
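
A hedged usage sketch of the new constructor argument (illustrative values; assumes the usual nemo.lightning import path):

```python
import nemo.lightning as nl

strategy = nl.MegatronStrategy(
    tensor_model_parallel_size=4,
    expert_model_parallel_size=4,
    expert_tensor_parallel_size=1,  # new argument from this commit; defaults to None
)
# The value is carried into ParallelismConfig and, from there, to model-parallel init.
print(strategy.parallelism.expert_tensor_parallel_size)  # 1
```
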
40 changes: 37 additions & 3 deletions nemo/utils/app_state.py
@@ -40,6 +40,7 @@ def __init__(self):
self._global_rank = None
self._tensor_model_parallel_rank = None
self._expert_model_parallel_rank = None
self._expert_tensor_parallel_rank = None
self._pipeline_model_parallel_rank = None
self._data_parallel_rank = None

@@ -48,6 +49,7 @@ def __init__(self):
self._tensor_model_parallel_size = None
self._tensor_model_parallel_group = None
self._expert_model_parallel_size = None
self._expert_tensor_parallel_size = None
self._pipeline_model_parallel_size = None
self._virtual_pipeline_model_parallel_size = None
self._encoder_tensor_model_parallel_size = None
@@ -181,12 +183,44 @@ def expert_model_parallel_size(self):

@expert_model_parallel_size.setter
def expert_model_parallel_size(self, size):
"""Property sets the number of GPUs in each expert parallel group.
Args:
size (int): Number of GPUs in each expert parallel group.
"""Property returns the number of GPUs in each expert parallel group.
Returns:
Number of GPUs in each expert parallel group.
"""
self._expert_model_parallel_size = size

@property
def expert_tensor_parallel_size(self):
"""Property returns the number of GPUs in each expert tensor parallel group.
Returns:
Number of GPUs in each expert tensor parallel group.
"""
return self._expert_tensor_parallel_size

@expert_tensor_parallel_size.setter
def expert_tensor_parallel_size(self, size):
"""Property sets the number of GPUs in each expert tensor parallel group.
Args:
size (int): Number of GPUs in each tensor expert parallel group.
"""
self._expert_tensor_parallel_size = size

@property
def expert_tensor_parallel_rank(self):
"""Property returns the expert tensor model parallel rank.
Returns:
Tensor model parallel rank.
"""
return self._expert_tensor_parallel_rank

@expert_tensor_parallel_rank.setter
def expert_tensor_parallel_rank(self, rank):
"""Property sets the expert tensor model parallel rank.
Args:
rank (int): Tensor model parallel rank.
"""
self._expert_tensor_parallel_rank = rank

@property
def pipeline_model_parallel_size(self):
"""Property returns the number of GPUs in each model parallel group.