diff --git a/nemo/collections/multimodal/models/vision_language_foundation/clip/megatron_clip_models.py b/nemo/collections/multimodal/models/vision_language_foundation/clip/megatron_clip_models.py
index 84718f99262f..9016ec054f84 100644
--- a/nemo/collections/multimodal/models/vision_language_foundation/clip/megatron_clip_models.py
+++ b/nemo/collections/multimodal/models/vision_language_foundation/clip/megatron_clip_models.py
@@ -812,7 +812,7 @@ def setup_mcore_distributed_parallel(self):
                     ddp_config,
                     model_chunk,
                     data_parallel_group=parallel_state.get_data_parallel_group(with_context_parallel=True),
-                    expert_data_parallel_group=parallel_state.get_data_modulo_expert_parallel_group(),
+                    expert_data_parallel_group=parallel_state.get_expert_data_parallel_group(),
                     # Turn off bucketing for model_chunk 2 onwards, since communication for these
                     # model chunks is overlapped with compute anyway.
                     disable_bucketing=(model_chunk_idx > 0),
diff --git a/nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py b/nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py
index e530a40d8aaa..43fd971ca117 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py
@@ -1882,7 +1882,7 @@ def setup_mcore_distributed_parallel(self):
                     ddp_config,
                     model_chunk,
                     data_parallel_group=parallel_state.get_data_parallel_group(with_context_parallel=True),
-                    expert_data_parallel_group=parallel_state.get_data_modulo_expert_parallel_group(),
+                    expert_data_parallel_group=parallel_state.get_expert_data_parallel_group(),
                     # Turn off bucketing for model_chunk 2 onwards, since communication for these
                     # model chunks is overlapped with compute anyway.
                     disable_bucketing=(model_chunk_idx > 0),
diff --git a/nemo/core/optim/optimizer_with_main_params.py b/nemo/core/optim/optimizer_with_main_params.py
index 0f549443b772..9723a876e58a 100755
--- a/nemo/core/optim/optimizer_with_main_params.py
+++ b/nemo/core/optim/optimizer_with_main_params.py
@@ -30,9 +30,9 @@
 try:
     from megatron.core.parallel_state import (
-        get_data_modulo_expert_parallel_group,
         get_data_parallel_group,
         get_data_parallel_world_size,
+        get_expert_data_parallel_group,
     )
     from megatron.core.tensor_parallel import copy_tensor_model_parallel_attributes
@@ -74,7 +74,7 @@ def _multi_tensor_copy_this_to_that(this, that, overflow_buf):
 def _get_grad_data_group(is_expert_group):
     if is_expert_group:
-        data_group = get_data_modulo_expert_parallel_group()
+        data_group = get_expert_data_parallel_group()
     else:
         data_group = get_data_parallel_group(with_context_parallel=True)
     return data_group
diff --git a/nemo/lightning/_strategy_lib.py b/nemo/lightning/_strategy_lib.py
index 30090896ac8e..f3d735c604f7 100644
--- a/nemo/lightning/_strategy_lib.py
+++ b/nemo/lightning/_strategy_lib.py
@@ -81,6 +81,7 @@ def init_parallel_ranks(
         local_rank=init_local_rank,
         tensor_model_parallel_size=parallel_config.tensor_model_parallel_size,
         expert_model_parallel_size=parallel_config.expert_model_parallel_size,
+        expert_tensor_parallel_size=parallel_config.expert_tensor_parallel_size,
         pipeline_model_parallel_size=parallel_config.pipeline_model_parallel_size,
         virtual_pipeline_model_parallel_size=parallel_config.virtual_pipeline_model_parallel_size,
         context_parallel_size=parallel_config.context_parallel_size,
diff --git a/nemo/lightning/megatron_init.py b/nemo/lightning/megatron_init.py
index 9e163f349206..5a618d1a278b 100644
--- a/nemo/lightning/megatron_init.py
+++ b/nemo/lightning/megatron_init.py
@@ -91,6 +91,7 @@ def initialize_model_parallel_for_nemo(
     local_rank,
     tensor_model_parallel_size=1,
     expert_model_parallel_size=1,
+    expert_tensor_parallel_size=None,
     pipeline_model_parallel_size=1,
     virtual_pipeline_model_parallel_size=None,
     pipeline_model_parallel_split_rank=None,
@@ -126,6 +127,7 @@ def initialize_model_parallel_for_nemo(
     app_state.encoder_pipeline_model_parallel_size = encoder_pipeline_model_parallel_size
     app_state.use_fp8 = use_fp8
     app_state.init_mpi_proc_group = init_mpi_proc_group
+    app_state.expert_tensor_parallel_size = expert_tensor_parallel_size
     (
         app_state.tensor_model_parallel_rank,
         app_state.pipeline_model_parallel_rank,
@@ -144,6 +146,7 @@ def initialize_model_parallel_for_nemo(
         pipeline_model_parallel_split_rank_=pipeline_model_parallel_split_rank,
         context_parallel_size_=context_parallel_size,
         expert_model_parallel_size_=expert_model_parallel_size,
+        expert_tensor_parallel_size_=expert_tensor_parallel_size,
         encoder_tensor_model_parallel_size_=encoder_tensor_model_parallel_size,
         encoder_pipeline_model_parallel_size_=encoder_pipeline_model_parallel_size,
         use_tp_pp_dp_mapping=use_tp_pp_dp_mapping,
diff --git a/nemo/lightning/pytorch/strategies/megatron_strategy.py b/nemo/lightning/pytorch/strategies/megatron_strategy.py
index ad15a0677f6c..21c6f892bfae 100644
--- a/nemo/lightning/pytorch/strategies/megatron_strategy.py
+++ b/nemo/lightning/pytorch/strategies/megatron_strategy.py
@@ -126,6 +126,7 @@ class MegatronStrategy(DDPStrategy, io.IOMixin):
             parallelizing layer norms and dropout sequentially. Defaults to False.
         expert_model_parallel_size (int): Distributes MoE Experts across sub data parallel dimension.
             Defaults to 1.
+        expert_tensor_parallel_size (Optional[int]): Sets MoE Experts tensor parallelism size. Defaults to None.
         moe_extended_tp (bool): Alternative parallelization strategy for expert parallelism. Defaults to False.
         data_sampler (Optional['DataSampler']): Custom data sampler for distributed training. Defaults to None.
         parallel_devices (Optional[List[torch.device]]): List of devices to use for parallelism. Defaults to None.
diff --git a/tests/lightning/test_strategy_lib.py b/tests/lightning/test_strategy_lib.py
index 017c325842d4..197fc8d4982c 100644
--- a/tests/lightning/test_strategy_lib.py
+++ b/tests/lightning/test_strategy_lib.py
@@ -78,6 +78,7 @@ def test_init_parallel_ranks() -> None:
     mock_parallel_config.virtual_pipeline_model_parallel_size = 4
     mock_parallel_config.context_parallel_size = 2
     mock_parallel_config.expert_model_parallel_size = 2
+    mock_parallel_config.expert_tensor_parallel_size = None
     mock_parallel_config.encoder_tensor_model_parallel_size = 0
     mock_parallel_config.encoder_pipeline_model_parallel_size = 0
     mock_parallel_config.tp_comm_overlap = False
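
Usage sketch (not part of the diff): with this change, an MoE run configured through NeMo's
MegatronStrategy can set the expert tensor-parallel size independently of the dense
tensor-parallel size; the value is threaded through init_parallel_ranks and
initialize_model_parallel_for_nemo into Megatron-Core. The surrounding parallelism values and
trainer sizing below are illustrative assumptions, not defaults taken from this PR.

    from nemo import lightning as nl

    strategy = nl.MegatronStrategy(
        tensor_model_parallel_size=4,       # TP applied to the dense (non-expert) layers
        expert_model_parallel_size=8,       # experts sharded across the sub-data-parallel dimension
        expert_tensor_parallel_size=1,      # new knob added here; None leaves the expert TP size at its default
        pipeline_model_parallel_size=1,
    )
    trainer = nl.Trainer(devices=8, num_nodes=4, accelerator="gpu", strategy=strategy)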