Add non-mcore fsdp2 strategy #11525

Merged 135 commits into main from boxiangw/non-mcore-fsdp2 on Jan 8, 2025

Commits (135)
3338e06
Add fsdp2 strategy
BoxiangW Dec 9, 2024
4c9b5df
Apply isort and black reformatting
BoxiangW Dec 9, 2024
11a4637
Add imports
BoxiangW Dec 10, 2024
5971cf4
Apply isort and black reformatting
BoxiangW Dec 10, 2024
3533b89
Merge branch 'main' into boxiangw/non-mcore-fsdp2
BoxiangW Dec 10, 2024
7c30f82
Add init import
BoxiangW Dec 21, 2024
ef11d67
Apply isort and black reformatting
BoxiangW Dec 21, 2024
3d50f08
Fix mixtral export for NeMo 2.0 (#11532)
Laplasjan107 Dec 10, 2024
6ea60de
Make HFDatasetDataModule a datasets.load_dataset wrapper (#11500)
akoumpa Dec 11, 2024
b81df9e
ci: Bump release workflow (#11544)
ko3n1g Dec 11, 2024
dfbb87f
ci: Use SHA for cut-off (#11545)
ko3n1g Dec 11, 2024
1fe0310
link to mcore documentation (#11538)
ashors1 Dec 11, 2024
02c2cdf
ci: Adjust inputs for code-freeze workflow (#11550)
ko3n1g Dec 11, 2024
c37570a
ci: Bump release freeze (#11551)
ko3n1g Dec 11, 2024
37ee432
Ko3n1g/ci/commit sha for cutoff (#11553)
ko3n1g Dec 11, 2024
fc39d24
ci: Bump code-freeze workflow (#11554)
ko3n1g Dec 11, 2024
2aff616
ci: Bump code freeze workflow (#11557)
ko3n1g Dec 11, 2024
726e50a
Fix deploy conflicts in llm.api (#11367)
hemildesai Dec 11, 2024
f68208e
perf summary docs link (#11262)
malay-nagda Dec 11, 2024
867dd0c
Add vlm nemo run scripts (#11394)
yaoyu-33 Dec 11, 2024
3867529
Add from_dict to HFDatasetDataModule (#11559)
akoumpa Dec 12, 2024
8a44926
Prevent llama3.1 from using Linear interpolation (#11548)
suiyoubi Dec 12, 2024
c1bb950
[TTS] Add audio and mel codec HF models to docs (#11526)
rlangman Dec 12, 2024
4017f42
Update for NEST release (#11537)
stevehuang52 Dec 12, 2024
4b714b3
Merging SpeechLLM development branch (#11462)
pzelasko Dec 12, 2024
9c264b7
Sync validation metrics for ASRModel (#11533)
pzelasko Dec 12, 2024
4972cc3
NeMo 2.0 In-framework deployment support (#11523)
oyilmaz-nvidia Dec 12, 2024
73181b4
Add SFT/PEFT HF tests (#11519)
akoumpa Dec 12, 2024
e32ded1
Fix typo: LocalNonpersitentObject -> LocalNonpersistentObject (#11546)
ananthsub Dec 12, 2024
81729e3
Adding documentation for packed dataset preparation with context para…
tomlifu Dec 12, 2024
729d2ee
have micro_batch_size and global_batch_size as class attributes in mo…
yashaswikarnati Dec 12, 2024
bd6d6ff
Revert "Fix the names of two sets of weight and bias in mcore_to_nemo…
ashors1 Dec 13, 2024
820b3ec
add huggingface-based tokenizer support for mixtral HF -> .nemo (#11572)
dimapihtar Dec 13, 2024
ad1282e
Github Actions tests for Llava Next and modify pretrain recipe to hav…
yashaswikarnati Dec 13, 2024
f11585e
Fix SingleDeviceStrategy support in Nsys callback (#11574)
akoumpa Dec 13, 2024
fd4c302
remove dialogue scripts and docs (#11577)
dimapihtar Dec 13, 2024
fc809b8
add JitTransform (#11131)
akoumpa Dec 13, 2024
36ee6f3
NeMo 2.0 documentation upgrade (#11235)
dimapihtar Dec 13, 2024
a3c377c
Remove auto-import of lhotse when importing nemo.collections.common.d…
pzelasko Dec 13, 2024
9d3d36e
Fix example configs (#11571)
BoxiangW Dec 13, 2024
b8f9b0b
fix (#11575)
ko3n1g Dec 13, 2024
7bc32cd
NIM supporting changes for nemo.export for NeMo 2.0 (#11488)
janekl Dec 14, 2024
929d643
AED greedy confidence estimation (#11573)
GNroy Dec 14, 2024
82544df
gemma fix (#11587)
suiyoubi Dec 14, 2024
3483208
Update T5 DataModule regarding Pretrain/Finetune validate (#11584)
huvunvidia Dec 14, 2024
6f084dd
fix llama3 (#11580)
suiyoubi Dec 15, 2024
a3eb280
Add Hf nemorun tests (#11566)
HuiyingLi Dec 16, 2024
b98160c
[🤖]: Howdy folks, let's bump NeMo-Toolkit to `2.2.0rc0` ! (#11555)
github-actions[bot] Dec 16, 2024
711176a
Pass the number of experts to modelopt layer spec (#11607)
janekl Dec 16, 2024
7210212
Adding changes to asr documentation (#11397)
Ssofja Dec 16, 2024
86393b5
Support Cosmos tokenizer TensorRT inference (#11472)
meatybobby Dec 16, 2024
5298ce3
Neva updates to latest mcore and some fixes (#11565)
yaoyu-33 Dec 16, 2024
08bf53c
add nemo2-sft-peft to readme (#11613)
HuiyingLi Dec 16, 2024
75bc074
Set Minitron width pruning batch size 1 (#11603)
kevalmorabia97 Dec 16, 2024
69d84cc
Disable CP for running Inference using megatron_gpt_eval (#11547)
suiyoubi Dec 16, 2024
0b14618
ci: Add `no-fail-fast` mode (#11608)
ko3n1g Dec 16, 2024
b975aaa
Chat dataset support (#11423)
cuichenx Dec 16, 2024
aa7a4e1
Sortformer Diarizer 4spk v1 model PR Part 2: Unit-tests for Sortforme…
tango4j Dec 16, 2024
4256580
2x more memory efficient Graph-based RNN-T (#11169)
artbataev Dec 17, 2024
2b6100d
Use explicit subpaths in io for exporting a checkpoint (#11352)
hemildesai Dec 17, 2024
1da9632
Remove triton requirement (#11627)
thomasdhc Dec 17, 2024
2eb897b
ci: Remove comment if no changes required anymore (#11624)
ko3n1g Dec 17, 2024
e4afd2d
Jit with peft (#11586)
akoumpa Dec 17, 2024
44689b1
NeMo-UX: add Hf's AutoModelForImageTextToText (#11321)
akoumpa Dec 17, 2024
ff568cd
ci: Bump release workflow (#11635)
ko3n1g Dec 17, 2024
b9457db
Add fix docstring for speech commands (#11638)
titu1994 Dec 18, 2024
53b8eb4
Fixing Multi_Task_Adapters.ipynb by replacing canary2 with canary_cus…
weiqingw4ng Dec 18, 2024
8829106
fixed config name in online augmentation tutorial (#11628)
nasretdinovr Dec 18, 2024
faa04ed
fix default nodes (#11632)
suiyoubi Dec 18, 2024
276c075
add renormalize_blend_weights param (#11647)
dimapihtar Dec 18, 2024
368ed62
Sortformer Diarizer 4spk v1 model PR Part 3: Speaker Diarization Mixi…
tango4j Dec 18, 2024
2b3b158
Fix peft inference (#11568)
cuichenx Dec 18, 2024
45f2a4c
Add fix docstring for speech commands (#11659)
titu1994 Dec 19, 2024
90f6fb7
update nemo container version for notebooks (#11651)
HuiyingLi Dec 19, 2024
18448b9
Fix Optimizer & LR scheduler & Consume Samples when Resuming in PEFT …
suiyoubi Dec 19, 2024
ccb3dc1
Add vlm generation function (#11063)
meatybobby Dec 19, 2024
db0a2d0
ci: Small pylint fix (#11667)
ko3n1g Dec 19, 2024
53c64ed
Add slimpajama example (#10671)
hemildesai Dec 19, 2024
5cebc85
Remove NeMo 1 docs (#11670)
cuichenx Dec 19, 2024
c872739
Change peft merged model to bf16 (#11663)
HuiyingLi Dec 19, 2024
72634f3
Add Minitron depth pruning (layer dropping) to megatron_gpt_prune.py …
kevalmorabia97 Dec 19, 2024
6dced1d
add documentation for checkpoint averaging (#11594)
dimapihtar Dec 20, 2024
8078cd4
Downgrading the 'datasets' package from 3.0.0 to 2.21.0 for Multilang…
weiqingw4ng Dec 20, 2024
054bd46
Utilities to detect and drop deprecated arguments from NeMo 2.0 check…
janekl Dec 20, 2024
fc54cee
NIM supporting changes for nemo.export for NeMo 2.0 (part II) (#11669)
janekl Dec 20, 2024
3c9c3f6
Add check for symlink in _safe_extract (#11611)
athitten Dec 20, 2024
86c0f1a
Fix baichuan export (#11640)
cuichenx Dec 20, 2024
3a8e75d
Rename multimodal data module - EnergonMultiModalDataModule (#11654)
yashaswikarnati Dec 20, 2024
e0b14e7
ci: Bump release workflow (#11686)
ko3n1g Dec 20, 2024
6e3cccf
Fixing the device assignment issues during inference (test_batch) in …
tango4j Dec 20, 2024
b4aecaf
add timestamp support (#11591)
kevinhu-nv Dec 20, 2024
099bc80
Make LinearAdapter a nn.Linear child to maintain ckpt structure (#11642)
akoumpa Dec 20, 2024
da10109
Add multi images support for mllama generate (#11672)
meatybobby Dec 20, 2024
597b387
Merge branch 'main' into boxiangw/non-mcore-fsdp2
BoxiangW Dec 23, 2024
7682cf2
Fix merge
BoxiangW Dec 23, 2024
0e694f7
Fix non-mcore fsdp2
BoxiangW Dec 24, 2024
e50be32
Merge branch 'main' into boxiangw/non-mcore-fsdp2
BoxiangW Dec 30, 2024
a2a67ef
Add run code
BoxiangW Dec 30, 2024
446c888
Apply isort and black reformatting
BoxiangW Dec 30, 2024
4782c7e
Add hooks
BoxiangW Dec 31, 2024
d53ab15
Apply isort and black reformatting
BoxiangW Dec 31, 2024
a1f08a2
Add test
BoxiangW Jan 2, 2025
837035f
Merge branch 'main' into boxiangw/non-mcore-fsdp2
BoxiangW Jan 2, 2025
ae1104e
Apply isort and black reformatting
BoxiangW Jan 2, 2025
a31f367
Add fsdp2 support on hf_auto_model_for_causal_lm
BoxiangW Jan 3, 2025
b5014b8
REvert model hooks for fsdp2 sharding
BoxiangW Jan 3, 2025
1d9257a
REvert some test changes
BoxiangW Jan 3, 2025
27ffb82
REvert tesst
BoxiangW Jan 3, 2025
346554d
Add line
BoxiangW Jan 3, 2025
7cd8186
Fix test
BoxiangW Jan 3, 2025
0a10868
fix test
BoxiangW Jan 3, 2025
03110a9
Fix test
BoxiangW Jan 3, 2025
bb52f17
Apply isort and black reformatting
BoxiangW Jan 3, 2025
648d830
Add CI test
BoxiangW Jan 3, 2025
752ff1a
Revert test change
BoxiangW Jan 3, 2025
6650a1f
Fix test
BoxiangW Jan 3, 2025
6edd7fc
Fix test
BoxiangW Jan 3, 2025
6f4cb64
Apply isort and black reformatting
BoxiangW Jan 3, 2025
7e18cfe
Add check for parallel
BoxiangW Jan 4, 2025
70b9f7c
Move function into nl.strategy
BoxiangW Jan 6, 2025
94d1f90
Apply isort and black reformatting
BoxiangW Jan 6, 2025
1e18b9e
Merge branch 'main' into boxiangw/non-mcore-fsdp2
BoxiangW Jan 6, 2025
c175ca0
Add tests
BoxiangW Jan 6, 2025
adeeffd
Remove test
BoxiangW Jan 6, 2025
6dc57c4
Apply isort and black reformatting
BoxiangW Jan 6, 2025
9c9b972
Add test
BoxiangW Jan 7, 2025
ff2c54c
Add fsdp2 ci test with memory check
BoxiangW Jan 7, 2025
db85277
Apply isort and black reformatting
BoxiangW Jan 7, 2025
ca5ffe4
Fix import
BoxiangW Jan 7, 2025
42f4ee8
Merge branch 'main' into boxiangw/non-mcore-fsdp2
BoxiangW Jan 7, 2025
1430226
Merge branch 'main' into boxiangw/non-mcore-fsdp2
BoxiangW Jan 7, 2025
b78205a
fix test
BoxiangW Jan 7, 2025
1e069c3
Apply isort and black reformatting
BoxiangW Jan 7, 2025
81aeb50
Add copyright
BoxiangW Jan 7, 2025
d8e7247
include test list
akoumpa Jan 8, 2025
24 changes: 24 additions & 0 deletions .github/workflows/cicd-main.yml
@@ -3700,6 +3700,17 @@ jobs:
TRANSFORMERS_OFFLINE=1 python tests/collections/llm/hf/sft.py --model /home/TestData/nlp/hf_gemma/hf_gemma_2b --max-steps 10 --devices 2 --strategy ddp
AFTER_SCRIPT: |
rm -rf nemo_experiments

L2_HF_Transformer_SFT_FSDP2_2gpu:
needs: [ cicd-test-container-setup ]
uses: ./.github/workflows/_test_template.yml
if: contains(fromJSON(needs.cicd-test-container-setup.outputs.test_to_run), 'L2_HF_Transformer_SFT_FSDP2_2gpu') || needs.cicd-test-container-setup.outputs.all == 'true'
with:
RUNNER: self-hosted-azure
SCRIPT: |
TRANSFORMERS_OFFLINE=1 python tests/collections/llm/hf/sft_fsdp2.py --model /home/TestData/nlp/hf_gemma/hf_gemma_2b --max-steps 10 --devices 2
AFTER_SCRIPT: |
rm -rf nemo_experiments

L2_HF_Transformer_PT_2gpu:
needs: [ cicd-test-container-setup ]
@@ -3722,6 +3733,17 @@ jobs:
TRANSFORMERS_OFFLINE=1 python tests/collections/llm/hf/sft_nemorun.py --model /home/TestData/nlp/hf_gemma/hf_gemma_2b --max-steps 10 --devices 2 --strategy ddp
AFTER_SCRIPT: |
rm -rf nemo_experiments

L2_HF_Transformer_SFT_2gpu_nemorun_fsdp2:
needs: [ cicd-test-container-setup ]
uses: ./.github/workflows/_test_template.yml
if: contains(fromJSON(needs.cicd-test-container-setup.outputs.test_to_run), 'L2_HF_Transformer_SFT_2gpu_nemorun_fsdp2') || needs.cicd-test-container-setup.outputs.all == 'true'
with:
RUNNER: self-hosted-azure
SCRIPT: |
TRANSFORMERS_OFFLINE=1 python tests/collections/llm/hf/sft_nemorun_fsdp2.py --model /home/TestData/nlp/hf_gemma/hf_gemma_2b --max-steps 10 --devices 2
AFTER_SCRIPT: |
rm -rf nemo_experiments

L2_HF_Transformer_PT_2gpu_nemorun:
needs: [ cicd-test-container-setup ]
@@ -5047,6 +5069,8 @@ jobs:
- L2_NeMo_2_PTQ_Llama2_FP8
- L2_NeMo_2_jit_callback
- L2_NeMo_2_LLAVA_NEXT_MOCK_TRAINING
- L2_HF_Transformer_SFT_FSDP2_2gpu
- L2_HF_Transformer_SFT_2gpu_nemorun_fsdp2
if: always()
runs-on: ubuntu-latest
steps:
@@ -20,6 +20,7 @@
from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer
from nemo.collections.llm import fn
from nemo.lightning import io
from nemo.lightning.pytorch.strategies.utils import fsdp2_strategy_parallelize
from nemo.utils import logging


@@ -91,6 +92,10 @@ def configure_model(self):
config, torch_dtype=dtype, trust_remote_code=self.trust_remote_code
)

# Apply FSDP2 and TP to the model
if self.device_mesh is not None:
fsdp2_strategy_parallelize(self.model, device_mesh=self.device_mesh)

if self.model_accelerator is not None:
self.model_accelerator(self.model)

@@ -99,7 +104,7 @@ def configure_model(self):
def forward(self, batch):
return self.model(**batch)

def training_step(self, batch):
def training_step(self, batch, batch_idx=None):
labels = batch.pop('labels').to(self.model.device)
loss_mask = batch.pop('loss_mask', None)

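For orientation, the sketch below shows the general idea behind a helper like `fsdp2_strategy_parallelize`: wrap each transformer block in its own FSDP2 unit under the strategy's device mesh, then wrap the root module. This is a minimal illustration built on PyTorch's composable `fully_shard` API, not the actual NeMo utility; the `"data_parallel"` mesh-dimension name and the `DecoderLayer` class-name heuristic are assumptions made here for illustration.

```python
# Minimal sketch (not the NeMo implementation): per-layer FSDP2 wrapping under a
# device mesh, using PyTorch's composable fully_shard API (recent PyTorch assumed).
import torch.nn as nn
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed._composable.fsdp import fully_shard


def parallelize_per_layer(model: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
    """Shard each transformer block, then the root module, with FSDP2."""
    dp_mesh = device_mesh["data_parallel"]  # assumed mesh-dimension name
    for module in model.modules():
        # Heuristic: treat modules whose class name ends in "DecoderLayer" as FSDP units.
        if type(module).__name__.endswith("DecoderLayer"):
            fully_shard(module, mesh=dp_mesh)
    fully_shard(model, mesh=dp_mesh)  # root wrap gathers any remaining parameters
    return model
```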
3 changes: 2 additions & 1 deletion nemo/lightning/__init__.py
@@ -31,7 +31,7 @@
from nemo.lightning.pytorch.optim import LRSchedulerModule, MegatronOptimizerModule, OptimizerModule, lr_scheduler
from nemo.lightning.pytorch.plugins import MegatronDataSampler, MegatronMixedPrecision
from nemo.lightning.pytorch.plugins import data_sampler as _data_sampler
from nemo.lightning.pytorch.strategies import FSDPStrategy, MegatronStrategy
from nemo.lightning.pytorch.strategies import FSDP2Strategy, FSDPStrategy, MegatronStrategy
from nemo.lightning.pytorch.strategies.utils import RestoreConfig
from nemo.lightning.pytorch.trainer import Trainer, configure_no_restart_validation_training_loop
from nemo.lightning.resume import AutoResume
@@ -60,6 +60,7 @@ def _is_slurm_interactive_mode():
"MegatronMixedPrecision",
"MegatronOptimizerModule",
"FSDPStrategy",
"FSDP2Strategy",
"RestoreConfig",
"lr_scheduler",
"NeMoLogger",
3 changes: 2 additions & 1 deletion nemo/lightning/pytorch/strategies/__init__.py
@@ -12,11 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from nemo.lightning.pytorch.strategies.fsdp2_strategy import FSDP2Strategy
from nemo.lightning.pytorch.strategies.fsdp_strategy import FSDPStrategy
from nemo.lightning.pytorch.strategies.megatron_strategy import MegatronStrategy


__all__ = [
"FSDPStrategy",
"FSDP2Strategy",
"MegatronStrategy",
]
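With this export in place, the new strategy is importable alongside the existing ones, for example:

```python
# Sketch: importing the new strategy next to the existing ones.
from nemo.lightning.pytorch.strategies import FSDP2Strategy, FSDPStrategy, MegatronStrategy
```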
276 changes: 276 additions & 0 deletions nemo/lightning/pytorch/strategies/fsdp2_strategy.py
@@ -0,0 +1,276 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import shutil
from collections import OrderedDict
from pathlib import Path
from typing import Any, Dict, Literal, Optional, Union

import lightning.pytorch as pl
import torch
from lightning.fabric.plugins import CheckpointIO
from lightning.fabric.strategies.fsdp import _get_sharded_state_dict_context
from lightning.pytorch.strategies.model_parallel import ModelParallelStrategy as PLModelParallelStrategy
from lightning.pytorch.trainer.states import TrainerFn
from lightning.pytorch.utilities.types import STEP_OUTPUT
from torch.distributed.checkpoint.state_dict import ( # get_state_dict,
StateDictOptions,
get_optimizer_state_dict,
set_state_dict,
)
from torch.utils.data import DataLoader
from typing_extensions import override

from nemo.lightning import io
from nemo.lightning.pytorch.strategies.utils import (
ckpt_to_dir,
create_checkpoint_io,
fix_progress_bar,
init_model_parallel,
mcore_to_pyt_sharded_state_dict,
pyt_to_mcore_state_dict,
setup_data_sampler,
setup_parallel_ranks,
)


class FSDP2Strategy(PLModelParallelStrategy, io.IOMixin):
"""Megatron plugin for Pytorch Lightning.

This strategy implements FSDP 2 using PyTorch's native FSDP 2 methods. Comparing with
MegatronStrategy, FSDP2Strategy is designed to be more lightweight, with minimal
modifications over Lightning's ModelParallelStrategy which supports FSDP2 + TP
parallelization but preserves necessary features to be compatible with nemo and mcore.
By default, this strategy wraps FSDP2 per TransformerLayer.

Note:
This strategy is designed to work with NVIDIA's Megatron-LM framework and requires
specific model implementations that are compatible with Megatron's parallelism techniques.
Note:
Due to the different optimizer structure (FSDP2 only uses torch native optimizers),
MegatronStrategy cannot resume training from checkpoints saved by FSDP2Strategy, and vice
versa. However, the model weights structure is made compatible, so switching strategy is
possible if users only need the weights not the optimizer states. (E.g. run pretrain with
megatron 4D parallelism and run SFT with FSDP2.)
"""

def __init__(
self,
data_parallel_size: Union[Literal["auto"], int] = "auto",
tensor_parallel_size: Union[Literal["auto"], int] = "auto",
ckpt_load_optimizer: bool = True,
ckpt_save_optimizer: bool = True,
data_sampler=None,
**kwargs,
):
super().__init__(data_parallel_size=data_parallel_size, tensor_parallel_size=tensor_parallel_size, **kwargs)

self.data_sampler = data_sampler
self.ckpt_load_optimizer = ckpt_load_optimizer
self.ckpt_save_optimizer = ckpt_save_optimizer

@override
def setup_environment(self) -> None:
setup_parallel_ranks(self)
super().setup_environment()
init_model_parallel(self.model)

@override
def setup(self, trainer: pl.Trainer) -> None:
self.trainer = trainer
setup_data_sampler(self.trainer)
fix_progress_bar(trainer)
super().setup(trainer)

def _get_loss_reduction(self, step_type: str):
for fn_name in [f"{step_type}_loss_reduction", "loss_reduction"]:
if hasattr(self.lightning_module, fn_name):
return getattr(self.lightning_module, fn_name)
return None

def _step_proxy(self, step_type, batch, batch_idx=None):
method_name = f"{step_type}_step"
if self.model != self.lightning_module:
loss = self._forward_redirection(self.model, self.lightning_module, method_name, batch, batch_idx)
else:
loss = getattr(self.lightning_module, method_name)(batch, batch_idx)

_loss_reduction = self._get_loss_reduction(step_type)
if _loss_reduction:
return _loss_reduction.forward(batch, loss)
return loss, {'avg': loss}

@override
def training_step(self, batch, batch_idx=None) -> STEP_OUTPUT:
assert self.lightning_module is not None
assert self.model is not None
with self.precision_plugin.train_step_context():
loss, reduced = self._step_proxy("training", batch, batch_idx)

self.lightning_module.log(
'global_step',
self.trainer.global_step,
prog_bar=True,
rank_zero_only=True,
batch_size=1,
)

self.lightning_module.log(
'step',
self.trainer.global_step,
)
self.lightning_module.log(
'reduced_train_loss', reduced['avg'], prog_bar=True, rank_zero_only=True, batch_size=1
)

# returns unreduced loss for backward
return loss

@override
def validation_step(self, batch, batch_idx=None) -> Any:
assert self.lightning_module is not None
assert self.model is not None
with self.precision_plugin.val_step_context():
loss, reduced = self._step_proxy("validation", batch, batch_idx)
self.lightning_module.log('val_loss', reduced['avg'], rank_zero_only=True, batch_size=1)
return loss

@override
def test_step(self, batch, batch_idx=None) -> STEP_OUTPUT:
assert self.lightning_module is not None
assert self.model is not None
with self.precision_plugin.test_step_context():
loss, reduced = self._step_proxy("test", batch, batch_idx)
self.lightning_module.log('test_loss', reduced['avg'], rank_zero_only=True, batch_size=1)

return loss

@override
def predict_step(self, batch, batch_idx=None) -> STEP_OUTPUT:
assert self.lightning_module is not None
assert self.model is not None
with self.precision_plugin.predict_step_context():
loss, reduced = self._step_proxy("predict", batch, batch_idx)
return reduced

@override
def process_dataloader(self, dataloader: DataLoader) -> DataLoader:
if self.data_sampler:
return self.data_sampler.transform_dataloader(dataloader)

return dataloader

@property
@override
def checkpoint_io(self) -> CheckpointIO:
if not self._checkpoint_io:
self._checkpoint_io = create_checkpoint_io()

return self._checkpoint_io

@checkpoint_io.setter
def checkpoint_io(self, io: CheckpointIO) -> None:
self._checkpoint_io = io

@property
def current_epoch_step(self) -> int:
"""
Get the value of step within an epoch.
"""
return max(
self.trainer.fit_loop.epoch_loop.automatic_optimization.optim_progress.optimizer.step.current.completed,
self.trainer.fit_loop.epoch_loop.manual_optimization.optim_step_progress.current.completed,
)

@override
def remove_checkpoint(self, filepath: Union[str, Path]) -> None:
# Taken from MegatronStrategy
ckpt = ckpt_to_dir(filepath)
if self.is_global_zero:
if os.path.islink(ckpt):
os.unlink(ckpt)
else:
shutil.rmtree(ckpt)

@override
def save_checkpoint(
self, checkpoint: Dict[str, Any], filepath: Union[str, Path], storage_options: Optional[Any] = None
) -> None:
"""Converts PyT checkpoints to MCore format and save using MCore dist ckpt library."""
checkpoint["sharded_state_dict"] = pyt_to_mcore_state_dict(
checkpoint.pop("state_dict"), device_mesh=self.device_mesh
)
checkpoint["state_dict"] = OrderedDict([])

if "optimizer_states" in checkpoint and self.trainer.state.fn == TrainerFn.FITTING:
# Clear the optimizer states. This handles the case where ckpt_save_optimizer=False
# Ideally, the optimizer state dicts should not be generated in this case
checkpoint["optimizer_states"] = {}

## replace unsharded optimizer_states with sharded dict.
## note that if trainer.save_checkpoint(path, save_weights_only=True) is called,
## the checkpoint will contain only model weights. Optimizer states will be omitted.
if self.ckpt_save_optimizer:
checkpoint['optimizer'] = get_optimizer_state_dict(self.model, self.optimizers)
pyt_to_mcore_state_dict(
checkpoint['optimizer']['state'], prefix="optimizer.state.", device_mesh=self.device_mesh
)

self.checkpoint_io.save_checkpoint(checkpoint, filepath, storage_options=storage_options)

@override
def load_checkpoint(self, checkpoint_path: str | Path) -> Dict[str, Any]:
"""PTL method which we override to integrate distributed checkpoints for FSDP models.
Different from MegatronStrategy, both model and optimizer states are restore within
this method.

The logic here is slightly more complicated:
1. Obtain PyT state dicts (sharded & unflattened) for model and optim -> torch::ShardedTensor
2. Convert to MCore state dicts -> mcore::ShardedTensor
3. Load from checkpoint using MCore dist ckpt API -> torch::Tensor
4. Convert to PyT state dicts (sharded & unflattened) -> torch::ShardedTensor
5. Load into model and optim using PyT dist ckpt API
6. Return the loaded checkpoint for lightning to load other metadata
"""
path = Path(self.broadcast(checkpoint_path))
torch.cuda.empty_cache()

# TODO: the more elegant way to load both state dicts at once requires PyTorch 2.3.1
# msd, osd = get_state_dict(self.model, self.optimizers, options=StateDictOptions(cpu_offload=True))
sharded_state_dict = {}
with _get_sharded_state_dict_context(self.model):
msd = self.model.state_dict()
pyt_to_mcore_state_dict(msd, device_mesh=self.device_mesh)
sharded_state_dict["sharded_state_dict"] = msd

if self.ckpt_load_optimizer and self.trainer.state.fn == TrainerFn.FITTING:
osd = get_optimizer_state_dict(self.model, self.optimizers, options=StateDictOptions(cpu_offload=True))
pyt_to_mcore_state_dict(osd['state'], prefix="optimizer.state.", device_mesh=self.device_mesh)
sharded_state_dict["optimizer"] = osd

checkpoint = self.checkpoint_io.load_checkpoint(path, sharded_state_dict=sharded_state_dict)
mcore_to_pyt_sharded_state_dict(checkpoint['sharded_state_dict'], msd)

if self.ckpt_load_optimizer and self.trainer.state.fn == TrainerFn.FITTING:
mcore_to_pyt_sharded_state_dict(checkpoint['optimizer']['state'], osd['state'])

set_state_dict(
self.model,
self.optimizers if self.ckpt_load_optimizer else [],
model_state_dict=checkpoint['sharded_state_dict'],
optim_state_dict=checkpoint['optimizer'] if self.ckpt_load_optimizer else None,
)

return checkpoint
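For context, a minimal usage sketch of the new strategy through the top-level `nemo.lightning` API; the argument values and the `model`/`datamodule` objects are placeholders from a user recipe, not part of this PR.

```python
# Minimal usage sketch (assumed setup): construct a Trainer with FSDP2Strategy on 2 GPUs.
from nemo import lightning as nl

strategy = nl.FSDP2Strategy(
    data_parallel_size=2,      # shard model parameters across 2 ranks
    tensor_parallel_size=1,    # set > 1 to add a tensor-parallel mesh dimension
    ckpt_save_optimizer=True,  # include optimizer state in distributed checkpoints
)

trainer = nl.Trainer(
    devices=2,
    accelerator="gpu",
    strategy=strategy,
    max_steps=10,
)
# trainer.fit(model, datamodule)  # model/datamodule come from the user's own recipe
```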
2 changes: 1 addition & 1 deletion nemo/lightning/pytorch/strategies/megatron_strategy.py
@@ -116,7 +116,7 @@ class MegatronStrategy(DDPStrategy, io.IOMixin):
across GPU ranks. Defaults to 1.
virtual_pipeline_model_parallel_size (Optional[int]): Interleaved pipeline parallelism used to
improve performance by reducing the pipeline bubble. Defaults to None.
microbatch_group_size_per_vp_stageOptional[int]: the number of micro-batches that are executed
microbatch_group_size_per_vp_stage (Optional[int]): the number of micro-batches that are executed
at a time for a given virtual stage (both forward and backward). Defaults to None and convert
to pipeline_parallel_size. which specifies a depth-first schedule.
context_parallel_size (int): Splits network input along sequence dimension across GPU ranks.