
[MoE] add expert tensor parallelism support for NeMo2.0 MoE #11880

Merged
@@ -812,7 +812,7 @@ def setup_mcore_distributed_parallel(self):
     ddp_config,
     model_chunk,
     data_parallel_group=parallel_state.get_data_parallel_group(with_context_parallel=True),
-    expert_data_parallel_group=parallel_state.get_data_modulo_expert_parallel_group(),
+    expert_data_parallel_group=parallel_state.get_expert_data_parallel_group(),
     # Turn off bucketing for model_chunk 2 onwards, since communication for these
     # model chunks is overlapped with compute anyway.
     disable_bucketing=(model_chunk_idx > 0),
@@ -1882,7 +1882,7 @@ def setup_mcore_distributed_parallel(self):
     ddp_config,
     model_chunk,
     data_parallel_group=parallel_state.get_data_parallel_group(with_context_parallel=True),
-    expert_data_parallel_group=parallel_state.get_data_modulo_expert_parallel_group(),
+    expert_data_parallel_group=parallel_state.get_expert_data_parallel_group(),
     # Turn off bucketing for model_chunk 2 onwards, since communication for these
     # model chunks is overlapped with compute anyway.
     disable_bucketing=(model_chunk_idx > 0),
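
Both hunks above swap the accessor for the expert data-parallel process group. For context, a minimal compatibility sketch (not part of the PR), assuming only that the two spellings shown in the diff belong to different Megatron-Core releases:

```python
# Sketch only: resolve the expert data-parallel group across Megatron-Core versions.
# Both accessor names are taken from the diff above; everything else is illustrative.
from megatron.core import parallel_state


def expert_data_parallel_group():
    """Return the process group over which MoE expert gradients are averaged."""
    if hasattr(parallel_state, "get_expert_data_parallel_group"):
        # Newer Megatron-Core spelling, adopted by this change.
        return parallel_state.get_expert_data_parallel_group()
    # Older spelling that the change replaces.
    return parallel_state.get_data_modulo_expert_parallel_group()
```
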
4 changes: 2 additions & 2 deletions nemo/core/optim/optimizer_with_main_params.py
@@ -30,9 +30,9 @@

 try:
     from megatron.core.parallel_state import (
-        get_data_modulo_expert_parallel_group,
         get_data_parallel_group,
         get_data_parallel_world_size,
+        get_expert_data_parallel_group,
     )
     from megatron.core.tensor_parallel import copy_tensor_model_parallel_attributes

@@ -74,7 +74,7 @@ def _multi_tensor_copy_this_to_that(this, that, overflow_buf):

 def _get_grad_data_group(is_expert_group):
     if is_expert_group:
-        data_group = get_data_modulo_expert_parallel_group()
+        data_group = get_expert_data_parallel_group()
     else:
         data_group = get_data_parallel_group(with_context_parallel=True)
     return data_group
1 change: 1 addition & 0 deletions nemo/lightning/_strategy_lib.py
@@ -81,6 +81,7 @@ def init_parallel_ranks(
         local_rank=init_local_rank,
         tensor_model_parallel_size=parallel_config.tensor_model_parallel_size,
         expert_model_parallel_size=parallel_config.expert_model_parallel_size,
+        expert_tensor_parallel_size=parallel_config.expert_tensor_parallel_size,
         pipeline_model_parallel_size=parallel_config.pipeline_model_parallel_size,
         virtual_pipeline_model_parallel_size=parallel_config.virtual_pipeline_model_parallel_size,
         context_parallel_size=parallel_config.context_parallel_size,
3 changes: 3 additions & 0 deletions nemo/lightning/megatron_init.py
@@ -91,6 +91,7 @@ def initialize_model_parallel_for_nemo(
     local_rank,
     tensor_model_parallel_size=1,
     expert_model_parallel_size=1,
+    expert_tensor_parallel_size=None,
     pipeline_model_parallel_size=1,
     virtual_pipeline_model_parallel_size=None,
     pipeline_model_parallel_split_rank=None,
@@ -126,6 +127,7 @@ def initialize_model_parallel_for_nemo(
     app_state.encoder_pipeline_model_parallel_size = encoder_pipeline_model_parallel_size
     app_state.use_fp8 = use_fp8
     app_state.init_mpi_proc_group = init_mpi_proc_group
+    app_state.expert_tensor_parallel_size = expert_tensor_parallel_size
     (
         app_state.tensor_model_parallel_rank,
         app_state.pipeline_model_parallel_rank,
@@ -144,6 +146,7 @@
         pipeline_model_parallel_split_rank_=pipeline_model_parallel_split_rank,
         context_parallel_size_=context_parallel_size,
         expert_model_parallel_size_=expert_model_parallel_size,
+        expert_tensor_parallel_size_=expert_tensor_parallel_size,
         encoder_tensor_model_parallel_size_=encoder_tensor_model_parallel_size,
         encoder_pipeline_model_parallel_size_=encoder_pipeline_model_parallel_size,
         use_tp_pp_dp_mapping=use_tp_pp_dp_mapping,
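
The new expert_tensor_parallel_size value is stored on app_state and forwarded into model-parallel initialization. A back-of-the-envelope sketch of the rank bookkeeping it implies, assuming the usual Megatron-Core decomposition for expert layers (world = PP × ETP × EP × expert-DP); the sizes below are illustrative only:

```python
# Illustrative numbers only; real values come from the launched job configuration.
world_size = 16
pipeline_model_parallel_size = 2   # PP
expert_model_parallel_size = 4     # EP: experts are sharded across 4 ranks
expert_tensor_parallel_size = 1    # ETP: the new knob (None lets Megatron-Core pick a default)

# Ranks left over for replicating each expert shard, i.e. the expert data-parallel
# group whose accessor was renamed in the hunks above.
expert_data_parallel_size = world_size // (
    pipeline_model_parallel_size * expert_model_parallel_size * expert_tensor_parallel_size
)
assert expert_data_parallel_size == 2
```
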
1 change: 1 addition & 0 deletions nemo/lightning/pytorch/strategies/megatron_strategy.py
@@ -126,6 +126,7 @@ class MegatronStrategy(DDPStrategy, io.IOMixin):
             parallelizing layer norms and dropout sequentially. Defaults to False.
         expert_model_parallel_size (int): Distributes MoE Experts across sub data parallel dimension.
             Defaults to 1.
+        expert_tensor_parallel_size (Optional[int]): Sets MoE Experts tensor parallelism size. Defaults to None.
         moe_extended_tp (bool): Alternative parallelization strategy for expert parallelism. Defaults to False.
         data_sampler (Optional['DataSampler']): Custom data sampler for distributed training. Defaults to None.
         parallel_devices (Optional[List[torch.device]]): List of devices to use for parallelism. Defaults to None.
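
With the docstring entry above, a hedged usage sketch for the new argument on MegatronStrategy; the parallel sizes are illustrative and must multiply out to the launched world size:

```python
# Sketch, not a tested recipe: only expert_tensor_parallel_size is new in this change;
# the other arguments already exist on MegatronStrategy.
from nemo.lightning.pytorch.strategies.megatron_strategy import MegatronStrategy

strategy = MegatronStrategy(
    tensor_model_parallel_size=2,   # TP for the dense (non-expert) layers
    expert_model_parallel_size=4,   # EP: distribute MoE experts across 4 ranks
    expert_tensor_parallel_size=1,  # new: tensor parallelism applied inside each expert
)
```
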
1 change: 1 addition & 0 deletions tests/lightning/test_strategy_lib.py
@@ -78,6 +78,7 @@ def test_init_parallel_ranks() -> None:
     mock_parallel_config.virtual_pipeline_model_parallel_size = 4
     mock_parallel_config.context_parallel_size = 2
     mock_parallel_config.expert_model_parallel_size = 2
+    mock_parallel_config.expert_tensor_parallel_size = None
     mock_parallel_config.encoder_tensor_model_parallel_size = 0
    mock_parallel_config.encoder_pipeline_model_parallel_size = 0
     mock_parallel_config.tp_comm_overlap = False