Bump mcore #11740

Merged
merged 7 commits on Jan 7, 2025

5 changes: 3 additions & 2 deletions .github/workflows/cicd-main.yml
@@ -2937,7 +2937,7 @@ jobs:
with:
RUNNER: self-hosted-azure-gpus-2-h100
SCRIPT: |
CUDA_DEVICE_MAX_CONNECTIONS=1 NVTE_FLASH_ATTN=0 NVTE_FUSED_ATTN=1 python examples/nlp/language_modeling/tuning/megatron_gpt_finetuning.py \
CUDA_DEVICE_MAX_CONNECTIONS=1 python examples/nlp/language_modeling/tuning/megatron_gpt_finetuning.py \
trainer.devices=2 \
trainer.log_every_n_steps=1 \
trainer.max_epochs=9999 \
@@ -2965,6 +2965,7 @@ jobs:
+model.tp_comm_overlap_ag=False \
+model.tp_comm_overlap_rs=False \
+model.tp_comm_overlap_disable_qkv=True \
+model.attention_backend="unfused" \
model.peft.peft_scheme="lora" \
model.peft.lora_tuning.adapter_dim=16 \
model.peft.lora_tuning.alpha=32 \
@@ -4267,7 +4268,7 @@ jobs:
with:
RUNNER: self-hosted-azure
SCRIPT: |
NVTE_FUSED_ATTN=0 NVTE_FLASH_ATTN=0 python3 tests/collections/llm/megatron_mixtral_pretraining.py \
python3 tests/collections/llm/megatron_mixtral_pretraining.py \
--experiment-dir=/tmp/mixtral_pretrain_results \
--data-path=/home/TestData/nlp/megatron_t5/data/pile_val_small_bert_tokenizer_text_document

2 changes: 1 addition & 1 deletion Dockerfile.ci
@@ -54,7 +54,7 @@ RUN pip install nemo_run@git+https://github.com/NVIDIA/NeMo-Run.git@${NEMO_RUN_T
# Install NeMo requirements
ARG TE_TAG=7d576ed25266a17a7b651f2c12e8498f67e0baea
ARG MODELOPT_VERSION=0.21.0
ARG MCORE_TAG=bd677bfb13ac2f19deaa927adc6da6f9201d66aa
ARG MCORE_TAG=076972e37420b5325c5fe06e7131be7d96f05b53

ARG APEX_TAG=810ffae374a2b9cb4b5c5e28eaeca7d7998fca0c
RUN \
3 changes: 1 addition & 2 deletions docs/source/nlp/information_retrieval.rst
@@ -70,9 +70,7 @@ Then you can fine-tune the sentence-BERT model using the following script:
VALIDATION_DATASET_PATH= # Path to validation dataset
SAVE_DIR= # where the checkpoint and logs are saved
mkdir -p $SAVE_DIR
export NVTE_FLASH_ATTN=0
export NVTE_ALLOW_NONDETERMINISTIC_ALGO=0
export NVTE_FUSED_ATTN=0

python NeMo/examples/nlp/information_retrieval/megatron_bert_embedding_finetuning.py \
--config-path=${CONFIG_PATH} \
@@ -87,6 +85,7 @@ Then you can fine-tune the sentence-BERT model using the following script:
model.post_process=False \
model.global_batch_size=8 \ # should be NUM_DEVICES * model.micro_batch_size
model.micro_batch_size=8 \
model.attention_backend="unfused" \
model.optim.lr=0.000005 \
model.optim.sched.min_lr=0.00000001 \
model.optim.sched.warmup_steps=100 \
1 change: 0 additions & 1 deletion nemo/collections/diffusion/scripts/train.sh
@@ -20,7 +20,6 @@
export WANDB_PROJECT=xxx
export WANDB_RUN_ID=xxx
export WANDB_RESUME=allow
export NVTE_FUSED_ATTN=0
export CUDA_DEVICE_MAX_CONNECTIONS=1
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

3 changes: 3 additions & 0 deletions nemo/collections/llm/gpt/model/gemma.py
@@ -18,6 +18,7 @@

import torch
from megatron.core import parallel_state
from megatron.core.transformer.enums import AttnBackend
from torch import nn

from nemo.collections.llm.fn.activation import openai_gelu
@@ -53,6 +54,8 @@ class GemmaConfig(GPTConfig):
# Legacy NeMo does not set layernorm_zero_centered_gamma and instead adds 1 in the HF -> NeMo conversion script
# The present implementation is more in line with the official implementation
layernorm_zero_centered_gamma: bool = True
# Disable cuDNN attention since TE 1.8 does not support head dim > 128
attention_backend: AttnBackend = AttnBackend.flash


@dataclass
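
Note on the pattern: the attention backend now lives on the model config instead of process-wide NVTE_* environment variables. A minimal standalone sketch of the idea (MyGPTConfig is a hypothetical stand-in, not the real GemmaConfig):

from dataclasses import dataclass

from megatron.core.transformer.enums import AttnBackend


@dataclass
class MyGPTConfig:
    # Hypothetical stand-in for GemmaConfig/GPTConfig; only the relevant field is shown.
    attention_backend: AttnBackend = AttnBackend.flash


cfg = MyGPTConfig()                                             # defaults to flash, as above
override = MyGPTConfig(attention_backend=AttnBackend.unfused)   # per-model override, no env vars
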
2 changes: 0 additions & 2 deletions nemo/collections/llm/recipes/gemma_2b.py
@@ -51,8 +51,6 @@ def model() -> run.Config[pl.LightningModule]:
>>> model_config = model()
>>> print(model_config)
"""
# Disable cuDNN attention since TE 1.8 does not support head dim > 128
os.environ['NVTE_FUSED_ATTN'] = "0"
return run.Config(GemmaModel, config=run.Config(GemmaConfig2B))


4 changes: 0 additions & 4 deletions nemo/collections/llm/recipes/gemma_7b.py
@@ -51,8 +51,6 @@ def model() -> run.Config[pl.LightningModule]:
>>> model_config = model()
>>> print(model_config)
"""
# Disable cuDNN attention since TE 1.8 does not support head dim > 128
os.environ['NVTE_FUSED_ATTN'] = "0"
return run.Config(GemmaModel, config=run.Config(GemmaConfig7B))


@@ -173,8 +171,6 @@ def pretrain_recipe(
For more details on pre-training LLMs with NeMo, see the pre-training
guide in the `examples/llm/pretrain/` directory.
"""
# Disable cuDNN attention since TE 1.8 does not support head dim > 128
os.environ['NVTE_FUSED_ATTN'] = "0"

return run.Partial(
fn,
@@ -50,6 +50,7 @@
try:
from megatron.core import ModelParallelConfig, parallel_state
from megatron.core.distributed import DistributedDataParallel as McoreDDP
from megatron.core.transformer.enums import AttnBackend
from megatron.core.transformer.module import Float16Module as MCoreFloat16Module
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.utils import init_method_normal, scaled_init_method_normal
@@ -537,6 +538,9 @@ def build_transformer_config(self) -> TransformerConfig:

tp_only_amax_red = self.cfg.get('tp_only_amax_red', False)

attention_backend = self.cfg.get('attention_backend', "auto")
attention_backend = AttnBackend[attention_backend]

# any configs that are not in the nemo model config will be added here
config_mapping = {
'apply_query_key_layer_scaling': apply_query_key_layer_scaling,
@@ -561,6 +565,7 @@
'rotary_interleaved': rotary_interleaved,
'deallocate_pipeline_outputs': True,
'tp_only_amax_red': tp_only_amax_red,
'attention_backend': attention_backend,
}

# populate the transformer config dict
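
For context on the change above: AttnBackend[attention_backend] is a plain name-based Enum lookup, so the configured string must match an enum member exactly. A short sketch (the "unfused"/"auto" values are the ones used elsewhere in this PR; the rest is illustrative):

from megatron.core.transformer.enums import AttnBackend

backend = AttnBackend["unfused"]   # e.g. cfg value "unfused" -> AttnBackend.unfused
default = AttnBackend["auto"]      # the fallback when 'attention_backend' is absent
# A misspelled name such as "flashy" raises KeyError, which surfaces config typos early.
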
@@ -76,6 +76,7 @@
from megatron.core.models.retro.utils import get_config_path as get_retro_config_path
from megatron.core.models.retro.utils import get_gpt_data_dir as get_retro_data_dir
from megatron.core.pipeline_parallel.schedules import get_forward_backward_func
from megatron.core.transformer.enums import AttnBackend
from megatron.core.transformer.module import Float16Module as MCoreFloat16Module
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.utils import init_method_normal, scaled_init_method_normal
@@ -431,6 +432,8 @@ def build_retro_config(self) -> RetroConfig:

te_version = packaging.version.Version(version("transformer-engine"))
if te_version >= packaging.version.Version("1.3"):
if HAVE_MEGATRON_CORE:
retro_config.attention_backend = AttnBackend.unfused
try:
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
4 changes: 2 additions & 2 deletions nemo/collections/vlm/mllama/model/language.py
@@ -390,7 +390,7 @@ def sharded_state_dict(
layer_prefix = f'{prefix}layers.'
num_layers = self.config.num_layers
for layer in self.layers:
offset = layer._get_layer_offset()
offset = layer._get_layer_offset(layer.config)
global_layer_offset = layer.layer_number - 1 # self.layer_number starts at 1
state_dict_prefix = f'{layer_prefix}{global_layer_offset - offset}.' # module list index in TransformerBlock # pylint: disable=line-too-long
sharded_prefix = layer_prefix
@@ -403,7 +403,7 @@
for xlayer in self.xattn_layers:
if isinstance(xlayer, DummyCrossAttentionTransformerLayer):
continue
offset = xlayer._get_layer_offset()
offset = xlayer._get_layer_offset(xlayer.config)
global_layer_offset = xlayer.layer_number - 1
state_dict_prefix = f'{xlayer_prefix}{global_layer_offset - offset}.' # module list index in TransformerBlock # pylint: disable=line-too-long
sharded_prefix = f'{xlayer_prefix}{global_layer_offset}.'
2 changes: 1 addition & 1 deletion nemo/lightning/pytorch/callbacks/peft.py
@@ -448,7 +448,7 @@ def load_checkpoint(
if getattr(path, "base_model_path", None):
## PEFT Resume, FIRST TIME
self.adapter_ckpt_path = Path(str(path))
adapter_ckpt = self.checkpoint_io.load_checkpoint(path) # Loads only metadata
adapter_ckpt = self.checkpoint_io.load_checkpoint(path, sharded_state_dict={}) # Loads only metadata
# path is adapter path to restore the training metadata, but switch to loading base model here.
path = self.model_ckpt_path = path.base_model_path
elif adapter_meta_path.exists():
3 changes: 2 additions & 1 deletion scripts/checkpoint_converters/convert_bert_hf_to_nemo.py
@@ -84,6 +84,8 @@ def convert(args):
nemo_config.model = adjust_nemo_config(nemo_config.model, hf_model.config.to_dict(), mcore_bert=args.mcore)

nemo_config.trainer["precision"] = args.precision
# Bert doesn't support FLASH_ATTN
nemo_config.model["attention_backend"] = "fused"
trainer = MegatronTrainerBuilder(nemo_config).create_trainer()
model = MegatronBertModel(nemo_config.model, trainer)

@@ -288,6 +290,5 @@ def convert(args):


if __name__ == '__main__':
os.environ['NVTE_FLASH_ATTN'] = '0' # Bert doesn't support FLASH_ATTN
args = get_args()
convert(args)
@@ -17,6 +17,7 @@

import torch
from megatron.core.distributed import DistributedDataParallelConfig as McoreDDPConfig
from megatron.core.transformer.enums import AttnBackend
from megatron.core.utils import init_method_normal, scaled_init_method_normal

from nemo.collections.llm import MixtralConfig8x7B, MixtralModel, PreTrainingDataModule
@@ -102,6 +103,7 @@ def main(args):
bias_dropout_fusion=True,
apply_rope_fusion=True,
distribute_saved_activations=False,
attention_backend=AttnBackend.unfused,
)

data = PreTrainingDataModule(
4 changes: 2 additions & 2 deletions tests/collections/llm/bitexact/mixtral/run.sh
@@ -8,7 +8,7 @@ MCORE_OUTPUT_PATH="/tmp/bex_mixtral_mcore_output/"
NEMO_OUTPUT_PATH="/tmp/bex_mixtral_nemo_output/"

# Run Mcore
CUDA_DEVICE_MAX_CONNECTIONS=1 CUDA_LAUNCH_BLOCKING=1 TORCH_COMPILE_DISABLE=1 NVTE_FLASH_ATTN=0 NVTE_FUSED_ATTN=0 \
CUDA_DEVICE_MAX_CONNECTIONS=1 CUDA_LAUNCH_BLOCKING=1 TORCH_COMPILE_DISABLE=1 \
torchrun --nproc-per-node 1 --nnodes 1 /workspace/Megatron-LM/pretrain_gpt.py \
--apply-layernorm-1p --rotary-percent 1.0 --rotary-base 1000000 \
--no-position-embedding --position-embedding-type rope \
@@ -30,7 +30,7 @@ torchrun --nproc-per-node 1 --nnodes 1 /workspace/Megatron-LM/pretrain_gpt.py \
--split 99,1,0 --log-interval 10 --save-interval 20000 --eval-interval 1000 --eval-iters 32 \
--save "$MCORE_OUTPUT_PATH" \
--log-num-zeros-in-grad --distributed-timeout-minutes 6000 --moe-router-topk 1 --num-experts 2 \
--moe-router-pre-softmax --expert-model-parallel-size 1 --eval-iters=0
--moe-router-pre-softmax --expert-model-parallel-size 1 --eval-iters=0 --attention-backend unfused

# Run NeMo
CUDA_LAUNCH_BLOCKING=1 TORCH_COMPILE_DISABLE=1 NVTE_FLASH_ATTN=0 NVTE_FUSED_ATTN=0 \
5 changes: 5 additions & 0 deletions tests/collections/llm/gpt/model/test_model_import.py
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import torch

torch.set_grad_enabled(False)
@@ -95,5 +97,8 @@ def import_from_hf(config_name, hf_path):

if __name__ == '__main__':
for config_name, hf_id in config_name_to_hf_id.items():
for env_var in ['NVTE_FLASH_ATTN', 'NVTE_FUSED_ATTN', 'NVTE_UNFUSED_ATTN']:
if env_var in os.environ:
del os.environ[env_var]
src = f'hf:///home/TestData/nemo2_ckpt/{config_name}'
import_from_hf(config_name, src)
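
Side note on the cleanup loop above: an equivalent, slightly more compact form uses dict.pop with a default, which skips the membership check (a sketch, not part of the diff):

import os

for env_var in ("NVTE_FLASH_ATTN", "NVTE_FUSED_ATTN", "NVTE_UNFUSED_ATTN"):
    os.environ.pop(env_var, None)  # removes the override if present, no-op otherwise
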
1 change: 0 additions & 1 deletion tests/collections/llm/hf/peft_nemorun.py
@@ -28,7 +28,6 @@ def local_executor_torchrun(nodes: int = 1, devices: int = 2) -> run.LocalExecut
"NCCL_NVLS_ENABLE": "0",
"NVTE_DP_AMAX_REDUCE_INTERVAL": "0",
"NVTE_ASYNC_AMAX_REDUCTION": "1",
"NVTE_FUSED_ATTN": "0",
}

executor = run.LocalExecutor(ntasks_per_node=devices, launcher="torchrun", env_vars=env_vars)
1 change: 0 additions & 1 deletion tests/collections/llm/hf/sft_nemorun.py
@@ -29,7 +29,6 @@ def local_executor_torchrun(nodes: int = 1, devices: int = 2) -> run.LocalExecut
"NCCL_NVLS_ENABLE": "0",
"NVTE_DP_AMAX_REDUCE_INTERVAL": "0",
"NVTE_ASYNC_AMAX_REDUCTION": "1",
"NVTE_FUSED_ATTN": "0",
}

executor = run.LocalExecutor(ntasks_per_node=devices, launcher="torchrun", env_vars=env_vars)
2 changes: 2 additions & 0 deletions tests/collections/llm/megatron_mixtral_pretraining.py
@@ -18,6 +18,7 @@

import torch
from megatron.core.distributed import DistributedDataParallelConfig as McoreDDPConfig
from megatron.core.transformer.enums import AttnBackend

from nemo.collections.llm import MixtralConfig8x3B, MixtralModel, PreTrainingDataModule
from nemo.collections.llm.api import train
@@ -117,6 +118,7 @@ def main(args):
bf16=True,
params_dtype=torch.bfloat16,
pipeline_dtype=torch.bfloat16,
attention_backend=AttnBackend.unfused,
)
mixtral_config.overlap_param_gather_with_optimizer_step = True

14 changes: 14 additions & 0 deletions tests/conftest.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import os.path
import shutil
import tarfile
@@ -122,6 +123,19 @@ def reset_singletons():
Singleton._Singleton__instances = {}


@pytest.fixture(autouse=True)
def reset_env_vars():
# Store the original environment variables before the test
original_env = dict(os.environ)

# Run the test
yield

# After the test, restore the original environment
os.environ.clear()
os.environ.update(original_env)


@pytest.fixture(scope="session")
def test_data_dir():
"""
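
What the new autouse fixture guarantees, as a usage sketch (the test name is hypothetical; every test in the suite gets the same isolation automatically):

import os


def test_backend_env_override_is_isolated():
    os.environ["NVTE_FUSED_ATTN"] = "0"          # mutate the environment inside a test
    assert os.environ["NVTE_FUSED_ATTN"] == "0"
    # reset_env_vars restores the snapshot taken before this test ran, so the
    # override does not leak into later tests in the session.
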
4 changes: 2 additions & 2 deletions tests/core/test_exp_manager.py
@@ -280,15 +280,15 @@ def test_log_dir_overrides(self, monkeypatch, tmp_path):
assert Path(tmp_path).exists()
assert Path(tmp_path / "test_no_name" / "default" / "957").exists()

monkeypatch.delenv(NEMO_ENV_VARNAME_VERSION)
monkeypatch.delenv(NEMO_ENV_VARNAME_VERSION, raising=False)
# Checks that use_datetime_version False toggle works
test_trainer = pl.Trainer(accelerator='cpu', enable_checkpointing=False, logger=False)
log_dir = exp_manager(test_trainer, {"exp_dir": str(tmp_path / "test_no_name"), "use_datetime_version": False})
assert log_dir.resolve() == (tmp_path / "test_no_name" / "default" / "version_0").resolve()
assert Path(tmp_path).exists()
assert Path(tmp_path / "test_no_name" / "default" / "version_0").exists()

monkeypatch.delenv(NEMO_ENV_VARNAME_VERSION)
monkeypatch.delenv(NEMO_ENV_VARNAME_VERSION, raising=False)
# Checks that use_datetime_version False toggle works and version increments
test_trainer = pl.Trainer(accelerator='cpu', enable_checkpointing=False, logger=False)
log_dir = exp_manager(test_trainer, {"exp_dir": str(tmp_path / "test_no_name"), "use_datetime_version": False})
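
The raising=False addition makes the cleanup tolerant of NEMO_ENV_VARNAME_VERSION already being unset (plausibly the case now that environment state is reset between tests). A minimal sketch of the behaviour (variable and test names are placeholders):

def test_delenv_tolerates_missing_variable(monkeypatch):
    # With raising=False, pytest's delenv silently does nothing when the
    # variable is absent, rather than failing the test with a KeyError.
    monkeypatch.delenv("SOME_VAR_THAT_MAY_BE_UNSET", raising=False)
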
10 changes: 5 additions & 5 deletions tests/lightning/test_nemo_resume_from_ckpt.py
@@ -12,13 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import List, Optional

import pytest


def set_env():
os.environ['NVTE_FLASH_ATTN'] = '0'
os.environ['NVTE_FUSED_ATTN'] = '0'
os.environ['NVTE_APPLY_QK_LAYER_SCALING'] = '0'


@@ -28,6 +27,7 @@ def set_env():
import pytest
import torch
from megatron.core.optimizer import OptimizerConfig
from megatron.core.transformer.enums import AttnBackend

import nemo.lightning as nl
from nemo.collections import llm
@@ -68,7 +68,8 @@ def load_dcp(ckpt_dir, torch_tensor=True):
return state_dict


def compare_ckpts(a, b, path=[]):
def compare_ckpts(a, b, path: Optional[List[str]] = None):
path = path if path is not None else []
if isinstance(a, dict):
assert isinstance(b, dict)
assert set(a.keys()) == set(b.keys())
@@ -125,6 +126,7 @@ def setup_model_optim(log_dir, n_steps, tokenizer, gbs=2, mbs=1):
make_vocab_size_divisible_by=128,
normalization='RMSNorm',
masked_softmax_fusion=False,
attention_backend=AttnBackend.local,
)

model = llm.GPTModel(gpt_config, tokenizer=tokenizer)
@@ -269,8 +271,6 @@ def train(n_steps, resume):
trainer._teardown()

set_env()
assert os.environ['NVTE_FLASH_ATTN'] == '0'
assert os.environ['NVTE_FUSED_ATTN'] == '0'
assert os.environ['NVTE_APPLY_QK_LAYER_SCALING'] == '0'

# Train for 40 steps
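
The compare_ckpts signature change also fixes the classic mutable-default-argument pitfall; a standalone illustration of why (nothing below comes from the PR beyond the pattern itself):

def buggy(path=[]):              # the default list is created once and shared across calls
    path.append("layer")
    return path


def fixed(path=None):            # the pattern adopted in compare_ckpts
    path = path if path is not None else []
    path.append("layer")
    return path


print(buggy())   # ['layer']
print(buggy())   # ['layer', 'layer']  -- state leaked from the first call
print(fixed())   # ['layer']
print(fixed())   # ['layer']           -- each call gets a fresh list
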
2 changes: 0 additions & 2 deletions tutorials/llm/llama-3/nemo2-sft-peft/nemo2-peft.ipynb
@@ -341,7 +341,6 @@
" \"NCCL_NVLS_ENABLE\": \"0\",\n",
" \"NVTE_DP_AMAX_REDUCE_INTERVAL\": \"0\",\n",
" \"NVTE_ASYNC_AMAX_REDUCTION\": \"1\",\n",
" \"NVTE_FUSED_ATTN\": \"0\",\n",
" }\n",
"\n",
" executor = run.LocalExecutor(ntasks_per_node=devices, launcher=\"torchrun\", env_vars=env_vars)\n",
@@ -457,7 +456,6 @@
" \"NCCL_NVLS_ENABLE\": \"0\",\n",
" \"NVTE_DP_AMAX_REDUCE_INTERVAL\": \"0\",\n",
" \"NVTE_ASYNC_AMAX_REDUCTION\": \"1\",\n",
" \"NVTE_FUSED_ATTN\": \"0\",\n",
" }\n",
"\n",
" executor = run.LocalExecutor(ntasks_per_node=devices, launcher=\"torchrun\", env_vars=env_vars)\n",