llm performance scripts #11736

Merged

37 commits from malay/llama3_finetuning, merged on Jan 25, 2025

Changes from 6 commits

Commits (37)
687f58f
finetuning llama3 8b
malay-nagda Jan 2, 2025
95f1809
llama3 70b
malay-nagda Jan 7, 2025
c5c42cc
Apply isort and black reformatting
malay-nagda Jan 7, 2025
4cde0fc
peft and slurm functional
malay-nagda Jan 8, 2025
372d376
Apply isort and black reformatting
malay-nagda Jan 8, 2025
423d7b2
Merge branch 'main' into malay/llama3_finetuning
malay-nagda Jan 8, 2025
92c381e
formatting and cleanup
malay-nagda Jan 9, 2025
1bc19a1
Apply isort and black reformatting
malay-nagda Jan 9, 2025
916d220
Merge branch 'main' into malay/llama3_finetuning
malay-nagda Jan 9, 2025
cf3bc02
405b lora + more cleanup
malay-nagda Jan 9, 2025
fcbe667
Apply isort and black reformatting
malay-nagda Jan 9, 2025
46c89bd
no tp comm, import ckpt, data filename
malay-nagda Jan 10, 2025
4df3bda
Apply isort and black reformatting
malay-nagda Jan 10, 2025
4a59854
renamed files
malay-nagda Jan 10, 2025
c6d3c82
tp comm
malay-nagda Jan 10, 2025
c6c044a
mpi tp comm
malay-nagda Jan 12, 2025
bc65fe4
nemotron recipes
malay-nagda Jan 12, 2025
d108b37
Apply isort and black reformatting
malay-nagda Jan 12, 2025
ec7f073
Merge branch 'main' into malay/llama3_finetuning
malay-nagda Jan 12, 2025
d45209b
formatting, cleanup & nemotron tokenizer
malay-nagda Jan 13, 2025
7ddc1ac
Apply isort and black reformatting
malay-nagda Jan 13, 2025
35e60ed
supported tokenizers
malay-nagda Jan 13, 2025
afb329a
cleanup
malay-nagda Jan 13, 2025
5984bb6
Apply isort and black reformatting
malay-nagda Jan 13, 2025
c8ffcc1
tp and pp related cfgs
malay-nagda Jan 13, 2025
bb72366
formatting
malay-nagda Jan 13, 2025
6022e64
340b fused attn
malay-nagda Jan 13, 2025
dc8819c
Apply isort and black reformatting
malay-nagda Jan 13, 2025
dc1e8ee
conditional nccl_pp_comm_chunksize
malay-nagda Jan 13, 2025
10145fe
null tokenizer
malay-nagda Jan 13, 2025
467ba30
logs msgs
malay-nagda Jan 15, 2025
c0d291b
logs msgs
malay-nagda Jan 15, 2025
a91071c
temp mem measurement
malay-nagda Jan 15, 2025
c1a25db
mem usage, 8b tp4
malay-nagda Jan 18, 2025
c80f80a
nccl backend
malay-nagda Jan 22, 2025
35f4a92
Merge branch 'main' into malay/llama3_finetuning
erhoo82 Jan 24, 2025
50008ad
Merge branch 'main' into malay/llama3_finetuning
erhoo82 Jan 25, 2025
187 changes: 187 additions & 0 deletions scripts/llm/performance/finetuning_llama3_70b.py
@@ -0,0 +1,187 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
from typing import Optional

import nemo_run as run
from nemo_run.config import NEMORUN_HOME
from utils import (
get_comm_overlap_callback_idx,
hf_tokenizer,
import_ckpt_experiment,
isfile_train_pack_metadata,
parse_cli_args,
slurm_executor,
)

from nemo.collections.llm.recipes.llama3_70b import finetune_recipe, model
from nemo.collections.llm.recipes.precision.mixed_precision import bf16_with_fp8_mixed
from nemo.lightning.run.plugins import NsysPlugin, PerfEnvPlugin
from nemo.utils import logging

NUM_NODES = 1
NUM_GPUS_PER_NODE = 8
MICRO_BATCH_SIZE = 1
GLOBAL_BATCH_SIZE = 32
TP_SIZE = 2
PP_SIZE = 4
CP_SIZE = 1
VP_SIZE = 20
MAX_STEPS = 100

HF_MODEL_URI = "meta-llama/Meta-Llama-3-70B"


def llama3_70b_performance_recipe(
finetuning_scheme: str,
compute_dtype: str,
num_nodes: int,
num_gpus_per_node: int,
mbs: int,
gbs: int,
tp_size: int,
pp_size: int,
cp_size: int,
vp_size: Optional[int],
max_steps: int,
):
"""
llama3 70b pre-train recipe aimed at achieving best possible performance.

NOTE: Use fp8 precision training with caution. It might not give desirable results.
"""
finetuning_scheme = "none" if finetuning_scheme == "sft" else finetuning_scheme
recipe = finetune_recipe(peft_scheme=finetuning_scheme, performance_mode=True)

# data module configs
recipe.data.micro_batch_size = mbs
recipe.data.global_batch_size = gbs
recipe.data.tokenizer = hf_tokenizer(HF_MODEL_URI)
if not isfile_train_pack_metadata(HF_MODEL_URI, recipe.data):
recipe.data.force_redownload = True

recipe.trainer.max_steps = max_steps
recipe.trainer.num_nodes = num_nodes
recipe.trainer.devices = num_gpus_per_node

# parallelism configs
recipe.trainer.strategy.tensor_model_parallel_size = tp_size
recipe.trainer.strategy.pipeline_model_parallel_size = pp_size
recipe.trainer.strategy.context_parallel_size = cp_size
recipe.trainer.strategy.virtual_pipeline_model_parallel_size = vp_size
recipe.trainer.strategy.sequence_parallel = tp_size > 1

comm_overlap_callback_idx = get_comm_overlap_callback_idx(recipe.trainer.callbacks)

# compute dtype configs
if compute_dtype.lower() == "fp8":
recipe.trainer.plugins = bf16_with_fp8_mixed()
recipe.trainer.plugins.grad_reduce_in_fp32 = False # bf16 grad dtype

# callback configs
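# With this script's defaults (1 node x 8 GPUs and TP=2 x PP=4 x CP=1), dp_size works out to 1,
# so the param-gather/optimizer-step overlap below only takes effect when scaling to more nodes.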
dp_size = (num_nodes * num_gpus_per_node) / (tp_size * pp_size * cp_size)
if dp_size > 1 and pp_size > 1 and vp_size and vp_size > 1:
if comm_overlap_callback_idx >= 0:
recipe.trainer.callbacks[comm_overlap_callback_idx].overlap_param_gather_with_optimizer_step = True

# Misc. for overall faster experiment runtime
recipe.log.ckpt = None
recipe.trainer.enable_checkpointing = False
recipe.trainer.val_check_interval = max_steps
recipe.trainer.log_every_n_steps = 1

return recipe


if __name__ == "__main__":
args = parse_cli_args().parse_args()
if args.log_dir != NEMORUN_HOME:
logging.error(f"Run `export NEMORUN_HOME={args.log_dir}` in your shell environment and rerun this script.")
sys.exit(1)
if args.nemo_home and args.nemo_home != os.getenv("NEMO_HOME"):
logging.error(f"Run `export NEMO_HOME={args.nemo_home}` in your shell environment and rerun this script.")
sys.exit(1)

exp_name = "_".join(
[
args.finetuning.lower(),
f"llama3_70b",
args.compute_dtype,
f"{NUM_NODES}nodes",
f"tp{TP_SIZE}_pp{PP_SIZE}_cp{CP_SIZE}_vp{VP_SIZE}",
f"{MICRO_BATCH_SIZE}mbs_{GLOBAL_BATCH_SIZE}gbs",
]
)

executor = slurm_executor(
args.account,
args.partition,
args.log_dir,
NUM_NODES,
NUM_GPUS_PER_NODE,
args.time_limit,
args.container_image,
custom_mounts=[],
custom_env_vars={
"NVTE_FUSED_ATTN": "0",
"NVTE_FLASH_ATTN": "1",
},
hf_token=args.hf_token,
nemo_home=args.nemo_home,
)

recipe = llama3_70b_performance_recipe(
args.finetuning.lower(),
args.compute_dtype,
NUM_NODES,
NUM_GPUS_PER_NODE,
MICRO_BATCH_SIZE,
GLOBAL_BATCH_SIZE,
TP_SIZE,
PP_SIZE,
CP_SIZE,
VP_SIZE,
MAX_STEPS,
)

if not args.tensorboard: # tensorboard adds performance overhead.
recipe.log.tensorboard = None
recipe.trainer.logger = False
else:
# The default TensorBoard path is not intuitive: `<log_dir>/code/nemo_experiments/tb_logs/default/<tfevents_file>`.
# The following line writes the events file to `<log_dir>/lightning_logs/tb_logs/default/<tfevents_file>` instead.
recipe.log.log_dir = "/nemo_run/lightning_logs"

plugins = [PerfEnvPlugin(enable_vboost=True)]
if args.enable_profiling:
plugins.append(NsysPlugin(start_step=5, end_step=6))

with run.Experiment(exp_name) as exp:
exp.add(*import_ckpt_experiment(NUM_NODES, executor, model(), source=f"hf://{HF_MODEL_URI}"))
exp.add(
recipe,
executor=executor,
name=exp_name,
plugins=plugins,
)

if not args.dryrun:
exp.run(sequential=True, detach=True)
else:
exp.dryrun()
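
For reference, the constants at the top of finetuning_llama3_70b.py pin down the data-parallel size and, under the usual Megatron-style relation global_batch = micro_batch x dp_size x grad_accumulation, the gradient-accumulation depth. The short standalone sketch below is not part of the PR files, and that relation is an assumption about the trainer rather than something stated in this diff; it just reproduces the arithmetic as a pre-launch sanity check.

NUM_NODES, NUM_GPUS_PER_NODE = 1, 8
MICRO_BATCH_SIZE, GLOBAL_BATCH_SIZE = 1, 32
TP_SIZE, PP_SIZE, CP_SIZE = 2, 4, 1

world_size = NUM_NODES * NUM_GPUS_PER_NODE  # 8 GPUs total
assert world_size % (TP_SIZE * PP_SIZE * CP_SIZE) == 0, "model parallelism must divide world size"
dp_size = world_size // (TP_SIZE * PP_SIZE * CP_SIZE)  # 8 // (2*4*1) = 1
grad_accum = GLOBAL_BATCH_SIZE // (MICRO_BATCH_SIZE * dp_size)  # 32 // (1*1) = 32
print(f"dp_size={dp_size}, grad_accum_steps={grad_accum}")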
187 changes: 187 additions & 0 deletions scripts/llm/performance/finetuning_llama3_8b.py
@@ -0,0 +1,187 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
from typing import Optional

import nemo_run as run
from nemo_run.config import NEMORUN_HOME
from utils import (
get_comm_overlap_callback_idx,
hf_tokenizer,
import_ckpt_experiment,
isfile_train_pack_metadata,
parse_cli_args,
slurm_executor,
)

from nemo.collections.llm.recipes.llama3_8b import finetune_recipe, model
from nemo.collections.llm.recipes.precision.mixed_precision import bf16_with_fp8_mixed
from nemo.lightning.run.plugins import NsysPlugin, PerfEnvPlugin
from nemo.utils import logging

NUM_NODES = 1
NUM_GPUS_PER_NODE = 8
MICRO_BATCH_SIZE = 1
GLOBAL_BATCH_SIZE = 32
TP_SIZE = 1
PP_SIZE = 1
CP_SIZE = 1
VP_SIZE = None
MAX_STEPS = 100

HF_MODEL_URI = "meta-llama/Meta-Llama-3-8B"


def llama3_8b_performance_recipe(
finetuning_scheme: str,
compute_dtype: str,
num_nodes: int,
num_gpus_per_node: int,
mbs: int,
gbs: int,
tp_size: int,
pp_size: int,
cp_size: int,
vp_size: Optional[int],
max_steps: int,
):
"""
llama3 8b pre-train recipe aimed at achieving best possible performance.

NOTE: Use fp8 precision training with caution. It might not give desirable results.
"""
finetuning_scheme = "none" if finetuning_scheme == "sft" else finetuning_scheme
recipe = finetune_recipe(peft_scheme=finetuning_scheme, performance_mode=True)

# data module configs
recipe.data.micro_batch_size = mbs
recipe.data.global_batch_size = gbs
recipe.data.tokenizer = hf_tokenizer(HF_MODEL_URI)
if not isfile_train_pack_metadata(HF_MODEL_URI, recipe.data):
recipe.data.force_redownload = True

recipe.trainer.max_steps = max_steps
recipe.trainer.num_nodes = num_nodes
recipe.trainer.devices = num_gpus_per_node

# parallelism configs
recipe.trainer.strategy.tensor_model_parallel_size = tp_size
recipe.trainer.strategy.pipeline_model_parallel_size = pp_size
recipe.trainer.strategy.context_parallel_size = cp_size
recipe.trainer.strategy.virtual_pipeline_model_parallel_size = vp_size
recipe.trainer.strategy.sequence_parallel = tp_size > 1

comm_overlap_callback_idx = get_comm_overlap_callback_idx(recipe.trainer.callbacks)

# compute dtype configs
if compute_dtype.lower() == "fp8":
recipe.trainer.plugins = bf16_with_fp8_mixed()
recipe.trainer.plugins.grad_reduce_in_fp32 = False # bf16 grad dtype

# callback configs
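# With this script's defaults (TP=PP=CP=1 on 1 node x 8 GPUs), dp_size is 8 but pp_size is 1,
# so the param-gather/optimizer-step overlap below stays disabled for this configuration.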
dp_size = (num_nodes * num_gpus_per_node) / (tp_size * pp_size * cp_size)
if dp_size > 1 and pp_size > 1 and vp_size and vp_size > 1:
if comm_overlap_callback_idx >= 0:
recipe.trainer.callbacks[comm_overlap_callback_idx].overlap_param_gather_with_optimizer_step = True

# Misc. for overall faster experiment runtime
recipe.log.ckpt = None
recipe.trainer.enable_checkpointing = False
recipe.trainer.val_check_interval = max_steps
recipe.trainer.log_every_n_steps = 1

return recipe


if __name__ == "__main__":
args = parse_cli_args().parse_args()
if args.log_dir != NEMORUN_HOME:
logging.error(f"Run `export NEMORUN_HOME={args.log_dir}` in your shell environment and rerun this script.")
sys.exit(1)
if args.nemo_home and args.nemo_home != os.getenv("NEMO_HOME"):
logging.error(f"Run `export NEMO_HOME={args.nemo_home}` in your shell environment and rerun this script.")
sys.exit(1)

exp_name = "_".join(
[
args.finetuning.lower(),
f"llama3_8b",
args.compute_dtype,
f"{NUM_NODES}nodes",
f"tp{TP_SIZE}_pp{PP_SIZE}_cp{CP_SIZE}_vp{VP_SIZE}",
f"{MICRO_BATCH_SIZE}mbs_{GLOBAL_BATCH_SIZE}gbs",
]
)

executor = slurm_executor(
args.account,
args.partition,
args.log_dir,
NUM_NODES,
NUM_GPUS_PER_NODE,
args.time_limit,
args.container_image,
custom_mounts=[],
custom_env_vars={
"NVTE_FUSED_ATTN": "0",
"NVTE_FLASH_ATTN": "1",
},
hf_token=args.hf_token,
nemo_home=args.nemo_home,
)

recipe = llama3_8b_performance_recipe(
args.finetuning.lower(),
args.compute_dtype,
NUM_NODES,
NUM_GPUS_PER_NODE,
MICRO_BATCH_SIZE,
GLOBAL_BATCH_SIZE,
TP_SIZE,
PP_SIZE,
CP_SIZE,
VP_SIZE,
MAX_STEPS,
)

if not args.tensorboard: # tensorboard adds performance overhead.
recipe.log.tensorboard = None
recipe.trainer.logger = False
else:
# The default TensorBoard path is not intuitive: `<log_dir>/code/nemo_experiments/tb_logs/default/<tfevents_file>`.
# The following line writes the events file to `<log_dir>/lightning_logs/tb_logs/default/<tfevents_file>` instead.
recipe.log.log_dir = "/nemo_run/lightning_logs"

plugins = [PerfEnvPlugin(enable_vboost=True)]
if args.enable_profiling:
plugins.append(NsysPlugin(start_step=5, end_step=6))

with run.Experiment(exp_name) as exp:
exp.add(*import_ckpt_experiment(NUM_NODES, executor, model(), source=f"hf://{HF_MODEL_URI}"))
exp.add(
recipe,
executor=executor,
name=exp_name,
plugins=plugins,
)

if not args.dryrun:
exp.run(sequential=True, detach=True)
else:
exp.dryrun()
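
Both scripts assume a Slurm cluster via slurm_executor(). As a hypothetical single-node smoke test, not part of this PR, one could swap in nemo_run's LocalExecutor with the torchrun launcher, as in the NeMo 2.0 quickstart; the module import path, the "lora" scheme, and running from scripts/llm/performance/ are assumptions here, not taken from the diff.

import nemo_run as run
# Assumes the working directory is scripts/llm/performance/ so the script imports as a module.
from finetuning_llama3_8b import llama3_8b_performance_recipe

recipe = llama3_8b_performance_recipe(
    "lora", "bf16",
    num_nodes=1, num_gpus_per_node=8, mbs=1, gbs=32,
    tp_size=1, pp_size=1, cp_size=1, vp_size=None, max_steps=10,
)
# LocalExecutor with the torchrun launcher, as used in the NeMo 2.0 quickstart examples.
executor = run.LocalExecutor(ntasks_per_node=8, launcher="torchrun")

with run.Experiment("llama3_8b_local_smoke") as exp:
    exp.add(recipe, executor=executor, name="llama3_8b_local_smoke")
    exp.dryrun()  # swap for exp.run(sequential=True) on a machine with 8 GPUs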