[MoE] add expert tensor parallelism support for NeMo2.0 MoE (#11880)
* add support for expert tensor parallelism

Signed-off-by: gdeng <[email protected]>

* remove function that's going to be deprecated

Signed-off-by: gdeng <[email protected]>

* format

Signed-off-by: gdeng <[email protected]>

* Apply isort and black reformatting

Signed-off-by: gdengk <[email protected]>

* add missing initialization

Signed-off-by: gdeng <[email protected]>

* fix unit test

Signed-off-by: gdeng <[email protected]>

---------

Signed-off-by: gdeng <[email protected]>
Signed-off-by: gdengk <[email protected]>
Co-authored-by: gdengk <[email protected]>
Signed-off-by: Abhinav Garg <[email protected]>
2 people authored and abhinavg4 committed Jan 30, 2025
1 parent 35b4c44 commit 7aedf22
Showing 7 changed files with 10 additions and 4 deletions.
@@ -874,7 +874,7 @@ def setup_mcore_distributed_parallel(self):
ddp_config,
model_chunk,
data_parallel_group=parallel_state.get_data_parallel_group(with_context_parallel=True),
- expert_data_parallel_group=parallel_state.get_data_modulo_expert_parallel_group(),
+ expert_data_parallel_group=parallel_state.get_expert_data_parallel_group(),
# Turn off bucketing for model_chunk 2 onwards, since communication for these
# model chunks is overlapped with compute anyway.
disable_bucketing=(model_chunk_idx > 0),
@@ -1882,7 +1882,7 @@ def setup_mcore_distributed_parallel(self):
ddp_config,
model_chunk,
data_parallel_group=parallel_state.get_data_parallel_group(with_context_parallel=True),
- expert_data_parallel_group=parallel_state.get_data_modulo_expert_parallel_group(),
+ expert_data_parallel_group=parallel_state.get_expert_data_parallel_group(),
# Turn off bucketing for model_chunk 2 onwards, since communication for these
# model chunks is overlapped with compute anyway.
disable_bucketing=(model_chunk_idx > 0),
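Both hunks above track a Megatron-Core rename: the expert data-parallel group accessor `get_data_modulo_expert_parallel_group` is replaced by `get_expert_data_parallel_group`. A minimal compatibility sketch, assuming only the two names shown in this diff and that which one exists depends on the installed Megatron-Core release:

```python
# Hedged sketch, not part of this PR: resolve the expert data-parallel group
# accessor across Megatron-Core versions. Only the two names shown in the diff
# are assumed.
from megatron.core import parallel_state

if hasattr(parallel_state, "get_expert_data_parallel_group"):
    get_expert_dp_group = parallel_state.get_expert_data_parallel_group
else:
    # Older releases expose the same group under the previous name.
    get_expert_dp_group = parallel_state.get_data_modulo_expert_parallel_group
```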
4 changes: 2 additions & 2 deletions nemo/core/optim/optimizer_with_main_params.py
@@ -30,9 +30,9 @@

try:
    from megatron.core.parallel_state import (
-       get_data_modulo_expert_parallel_group,
        get_data_parallel_group,
        get_data_parallel_world_size,
+       get_expert_data_parallel_group,
    )
    from megatron.core.tensor_parallel import copy_tensor_model_parallel_attributes

@@ -74,7 +74,7 @@ def _multi_tensor_copy_this_to_that(this, that, overflow_buf):

def _get_grad_data_group(is_expert_group):
    if is_expert_group:
-        data_group = get_data_modulo_expert_parallel_group()
+        data_group = get_expert_data_parallel_group()
    else:
        data_group = get_data_parallel_group(with_context_parallel=True)
    return data_group
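For context, `_get_grad_data_group` only chooses which process group a gradient is reduced over: expert parameters reduce over the expert data-parallel group, everything else over the regular data-parallel group with context parallelism folded in. Below is a minimal sketch of how such a group is typically consumed; `_allreduce_main_grad` is a hypothetical helper, not code from this file:

```python
# Hypothetical usage sketch (not from this PR): average a main-precision
# gradient over the group selected by _get_grad_data_group.
import torch
import torch.distributed as dist


def _allreduce_main_grad(main_grad: torch.Tensor, is_expert_group: bool) -> None:
    group = _get_grad_data_group(is_expert_group)
    # Pre-divide so the summed result is the data-parallel average.
    main_grad.div_(dist.get_world_size(group=group))
    dist.all_reduce(main_grad, op=dist.ReduceOp.SUM, group=group)
```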
1 change: 1 addition & 0 deletions nemo/lightning/_strategy_lib.py
@@ -81,6 +81,7 @@ def init_parallel_ranks(
local_rank=init_local_rank,
tensor_model_parallel_size=parallel_config.tensor_model_parallel_size,
expert_model_parallel_size=parallel_config.expert_model_parallel_size,
+ expert_tensor_parallel_size=parallel_config.expert_tensor_parallel_size,
pipeline_model_parallel_size=parallel_config.pipeline_model_parallel_size,
virtual_pipeline_model_parallel_size=parallel_config.virtual_pipeline_model_parallel_size,
context_parallel_size=parallel_config.context_parallel_size,
3 changes: 3 additions & 0 deletions nemo/lightning/megatron_init.py
@@ -91,6 +91,7 @@ def initialize_model_parallel_for_nemo(
local_rank,
tensor_model_parallel_size=1,
expert_model_parallel_size=1,
+ expert_tensor_parallel_size=None,
pipeline_model_parallel_size=1,
virtual_pipeline_model_parallel_size=None,
pipeline_model_parallel_split_rank=None,
@@ -126,6 +127,7 @@ def initialize_model_parallel_for_nemo(
app_state.encoder_pipeline_model_parallel_size = encoder_pipeline_model_parallel_size
app_state.use_fp8 = use_fp8
app_state.init_mpi_proc_group = init_mpi_proc_group
+ app_state.expert_tensor_parallel_size = expert_tensor_parallel_size
(
app_state.tensor_model_parallel_rank,
app_state.pipeline_model_parallel_rank,
@@ -144,6 +146,7 @@
pipeline_model_parallel_split_rank_=pipeline_model_parallel_split_rank,
context_parallel_size_=context_parallel_size,
expert_model_parallel_size_=expert_model_parallel_size,
+ expert_tensor_parallel_size_=expert_tensor_parallel_size,
encoder_tensor_model_parallel_size_=encoder_tensor_model_parallel_size,
encoder_pipeline_model_parallel_size_=encoder_pipeline_model_parallel_size,
use_tp_pp_dp_mapping=use_tp_pp_dp_mapping,
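The new `expert_tensor_parallel_size_` argument mirrors the kwarg that recent Megatron-Core releases accept when building process groups. The call below is a sketch of that downstream interface, not the NeMo call site shown in this diff; the sizes are assumed, illustrative values and kwarg availability depends on the installed Megatron-Core version:

```python
# Hedged sketch: expressing a decoupled expert TP size at the Megatron-Core
# level. Sizes here are illustrative assumptions (e.g. 8 GPUs total).
from megatron.core import parallel_state

parallel_state.initialize_model_parallel(
    tensor_model_parallel_size=4,    # TP for dense (non-expert) layers
    pipeline_model_parallel_size=1,
    expert_model_parallel_size=2,    # experts sharded across 2 ranks
    expert_tensor_parallel_size=1,   # experts skip tensor parallelism
)
```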
1 change: 1 addition & 0 deletions nemo/lightning/pytorch/strategies/megatron_strategy.py
@@ -127,6 +127,7 @@ class MegatronStrategy(DDPStrategy, io.IOMixin):
parallelizing layer norms and dropout sequentially. Defaults to False.
expert_model_parallel_size (int): Distributes MoE Experts across sub data parallel dimension.
Defaults to 1.
+ expert_tensor_parallel_size (Optional[int]): Sets MoE Experts tensor parallelism size. Defaults to None.
moe_extended_tp (bool): Alternative parallelization strategy for expert parallelism. Defaults to False.
data_sampler (Optional['DataSampler']): Custom data sampler for distributed training. Defaults to None.
parallel_devices (Optional[List[torch.device]]): List of devices to use for parallelism. Defaults to None.
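A usage sketch for the documented argument; the import path follows the file touched in this diff, and the parallel sizes are assumed values chosen only to show experts running with a smaller TP degree than the dense layers:

```python
# Hedged usage sketch (assumed sizes, not from this PR): configure NeMo 2.0
# MoE training with expert tensor parallelism decoupled from dense-layer TP.
from nemo.lightning.pytorch.strategies import MegatronStrategy

strategy = MegatronStrategy(
    tensor_model_parallel_size=4,    # dense layers use TP=4
    expert_model_parallel_size=2,    # experts distributed over 2 ranks
    expert_tensor_parallel_size=1,   # experts keep TP=1; None keeps the Megatron-Core default
)
```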
1 change: 1 addition & 0 deletions tests/lightning/test_strategy_lib.py
@@ -78,6 +78,7 @@ def test_init_parallel_ranks() -> None:
mock_parallel_config.virtual_pipeline_model_parallel_size = 4
mock_parallel_config.context_parallel_size = 2
mock_parallel_config.expert_model_parallel_size = 2
+ mock_parallel_config.expert_tensor_parallel_size = None
mock_parallel_config.encoder_tensor_model_parallel_size = 0
mock_parallel_config.encoder_pipeline_model_parallel_size = 0
mock_parallel_config.tp_comm_overlap = False
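The test only needs the mocked config to carry the new attribute so that `init_parallel_ranks` can read and forward it. A minimal standalone illustration, with values taken from the mock above:

```python
# Minimal illustration (assumption, not the full test): a MagicMock config
# only needs the new attribute set so init_parallel_ranks can forward it.
from unittest.mock import MagicMock

mock_parallel_config = MagicMock()
mock_parallel_config.expert_model_parallel_size = 2
mock_parallel_config.expert_tensor_parallel_size = None  # exercise the default path
```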
