Merge branch 'main' into athitten/fix_output_generation_logits
athitten authored Jan 10, 2025
2 parents 25241e0 + 1ab22d1 commit 5f01584
Showing 16 changed files with 254 additions and 54 deletions.
30 changes: 30 additions & 0 deletions .github/workflows/cicd-main.yml
@@ -4871,6 +4871,35 @@ jobs:
rm -rf /tmp/nemo2_ckpt
rm -rf /tmp/nemo2_ptq_engine
L2_NeMo_2_Export_In_Framework:
needs: [cicd-test-container-setup]
uses: ./.github/workflows/_test_template.yml
if: contains(fromJSON(needs.cicd-test-container-setup.outputs.test_to_run), 'L2_NeMo_2_Export_In_Framework') || needs.cicd-test-container-setup.outputs.all == 'true'
with:
RUNNER: self-hosted-azure
SCRIPT: |
python tests/collections/llm/test_hf_import.py \
--hf_model /home/TestData/nlp/megatron_llama/llama-ci-hf \
--output_path /tmp/nemo2_ckpt
python tests/setup/data/create_sample_lambada.py \
--output_file /tmp/lambada.json
python tests/export/nemo_export.py \
--model_name test \
--model_type llama \
--checkpoint_dir /tmp/nemo2_ckpt \
--min_tps 1 \
--in_framework True \
--test_deployment True \
--run_accuracy True \
--test_data_path /tmp/lambada.json \
--accuracy_threshold 0.0 \
--debug
AFTER_SCRIPT: |
rm -rf /tmp/nemo2_ckpt /tmp/lambada.json
L2_NeMo_2_LLAVA_NEXT_MOCK_TRAINING:
needs: [cicd-test-container-setup]
uses: ./.github/workflows/_test_template.yml
@@ -5068,6 +5097,7 @@ jobs:
- L2_Megatron_GPT_Reranker
- L2_NeMo_2_NeMo_Mcore_Mixtral_bitexact
- L2_NeMo_2_PTQ_Llama2_FP8
- L2_NeMo_2_Export_In_Framework
- L2_NeMo_2_jit_callback
- L2_NeMo_2_LLAVA_NEXT_MOCK_TRAINING
- L2_HF_Transformer_SFT_FSDP2_2gpu
7 changes: 1 addition & 6 deletions nemo/collections/common/parts/utils.py
@@ -112,14 +112,9 @@ def extend_instance(obj, mixin):
) # mixin needs to go first for our forward() logic to work


def apply_rope_scaling(freqs):
def apply_rope_scaling(freqs, scale_factor=8, low_freq_factor=1, high_freq_factor=4, old_context_len=8192):
# Apply scaling for RoPE frequencies
logger.info("apply rope scaling ...")
# Values obtained from grid search
scale_factor = 8
low_freq_factor = 1
high_freq_factor = 4
old_context_len = 8192 # original llama3 length

low_freq_wavelen = old_context_len / low_freq_factor
high_freq_wavelen = old_context_len / high_freq_factor
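For context, the constants exposed here feed the Llama 3-style "NTK-by-parts" frequency scaling: short-wavelength components are kept, long-wavelength components are divided by `scale_factor`, and the band in between is interpolated. A minimal, illustrative sketch (the helper name is made up; the exact NeMo implementation may differ in details):

```python
import math

import torch


def rope_scaling_sketch(freqs, scale_factor=8, low_freq_factor=1, high_freq_factor=4, old_context_len=8192):
    """Llama 3-style scaling of RoPE inverse frequencies (illustrative sketch)."""
    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor

    wavelen = 2 * math.pi / freqs
    # Interpolation factor for wavelengths between the two cut-offs.
    smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)

    # Long wavelengths are scaled down; short wavelengths are left untouched.
    scaled = torch.where(wavelen > low_freq_wavelen, freqs / scale_factor, freqs)
    # The in-between band blends the scaled and unscaled frequencies.
    in_between = (wavelen >= high_freq_wavelen) & (wavelen <= low_freq_wavelen)
    return torch.where(in_between, (1 - smooth) * freqs / scale_factor + smooth * freqs, scaled)
```

Under these assumptions, the call site would pass the rotary embedding's inverse frequencies, e.g. `rope_scaling_sketch(model.rotary_pos_emb.inv_freq)`.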
4 changes: 2 additions & 2 deletions nemo/collections/llm/gpt/model/mixtral.py
@@ -118,8 +118,8 @@ class MixtralConfig8x22B(MixtralConfig):
hidden_size: int = 6144
num_attention_heads: int = 48
ffn_hidden_size: int = 16384
max_position_embeddings: int = 4096
seq_length: int = 4096
max_position_embeddings: int = 65536
seq_length: int = 65536


class MixtralModel(GPTModel):
2 changes: 1 addition & 1 deletion nemo/collections/llm/recipes/mixtral_8x22b.py
@@ -183,7 +183,7 @@ def pretrain_recipe(
trainer=trainer(
num_nodes=num_nodes, num_gpus_per_node=num_gpus_per_node, callbacks=[run.Config(TimingCallback)]
),
data=run.Config(MockDataModule, seq_length=4096, global_batch_size=512, micro_batch_size=1),
data=run.Config(MockDataModule, seq_length=65536, global_batch_size=512, micro_batch_size=1),
log=default_log(dir=dir, name=name, tensorboard_logger=tensorboard_logger(name=name)),
optim=distributed_fused_adam_with_cosine_annealing(max_lr=3e-4),
resume=default_resume(),
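Because the 8x22B recipe now defaults to a 64K sequence length, runs on smaller clusters may need to dial it back. A hypothetical override (the recipe call and field names follow the recipe module above; exact arguments may differ):

```python
from nemo.collections.llm.recipes import mixtral_8x22b

# Build the default pretraining recipe (argument values here are placeholders).
recipe = mixtral_8x22b.pretrain_recipe(num_nodes=8, num_gpus_per_node=8)

# Keep the data module and model config consistent if a shorter context is needed.
recipe.data.seq_length = 4096
recipe.model.config.seq_length = 4096
recipe.model.config.max_position_embeddings = 4096
```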
@@ -199,7 +199,13 @@ def mcore_model_customize(cfg, model):
if cfg.get("apply_embedding_scaling", False) and parallel_state.is_pipeline_first_stage():
extend_instance(model.embedding, EmbeddingScalingMixin)
if cfg.get("scale_positional_embedding", False):
model.rotary_pos_emb.inv_freq = apply_rope_scaling(model.rotary_pos_emb.inv_freq)
model.rotary_pos_emb.inv_freq = apply_rope_scaling(
model.rotary_pos_emb.inv_freq,
scale_factor=cfg.get('scale_factor', 8),
low_freq_factor=cfg.get('low_freq_factor', 1),
high_freq_factor=cfg.get('high_freq_factor', 4),
old_context_len=cfg.get('old_context_len', 8192),
)
if cfg.get("mcore_customization_config", {}).get("final_logit_softcapping", 0):
from nemo.collections.nlp.models.language_modeling.megatron.gemma2.gemma2_modules import Gemma2OutputLayer

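With this change the four RoPE-scaling constants are no longer hard-coded and can be supplied through the model config. A hypothetical config fragment showing the keys the code now reads via `cfg.get(...)` (the values shown match the former defaults):

```python
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "scale_positional_embedding": True,
        "scale_factor": 8,        # default if unset
        "low_freq_factor": 1,     # default if unset
        "high_freq_factor": 4,    # default if unset
        "old_context_len": 8192,  # default if unset
    }
)
```

In a YAML model config the same keys would sit alongside `scale_positional_embedding`.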
91 changes: 71 additions & 20 deletions nemo/deploy/nlp/megatronllm_deployable.py
@@ -15,14 +15,15 @@
import logging
from enum import IntEnum, auto
from pathlib import Path
from typing import List
from typing import List, Optional

import numpy as np
import torch
import torch.distributed
import wrapt
from lightning.pytorch.trainer.trainer import Trainer
from megatron.core.inference.common_inference_params import CommonInferenceParams
from pytorch_lightning.trainer.trainer import Trainer
from megatron.core.inference.inference_request import InferenceRequest

import nemo.lightning as nl
from nemo.collections.llm import inference
@@ -94,7 +95,7 @@ def GetNumpyDtype(pyvalue):


class ServerSync(IntEnum):
"""Enum for synchronization messages using torch.distributed"""
"""Enum for synchronization messages using torch.distributed."""

WAIT = auto()
SIGNAL = auto()
@@ -104,17 +105,35 @@ def to_long_tensor(self):


class MegatronLLMDeploy:
"""
A factory class for creating deployable instances of Megatron LLM models.
This class provides a method to get the appropriate deployable instance
based on the version of the NeMo checkpoint model used.
"""

@staticmethod
def get_deployable(
nemo_checkpoint_filepath: str = None,
nemo_checkpoint_filepath: str,
num_devices: int = 1,
num_nodes: int = 1,
tensor_model_parallel_size: int = 1,
pipeline_model_parallel_size: int = 1,
context_parallel_size: int = 1,
):

"""
Returns the appropriate deployable instance for the given NeMo checkpoint.
Args:
nemo_checkpoint_filepath (str): Path to the .nemo checkpoint file.
num_devices (int): Number of devices to use for deployment.
num_nodes (int): Number of nodes to use for deployment.
tensor_model_parallel_size (int): Size of the tensor model parallelism.
pipeline_model_parallel_size (int): Size of the pipeline model parallelism.
context_parallel_size (int): Size of the context parallelism.
Returns:
ITritonDeployable: An instance of a deployable class compatible with Triton inference server.
"""
if nemo_checkpoint_version(nemo_checkpoint_filepath) == NEMO2:
return MegatronLLMDeployableNemo2(
nemo_checkpoint_filepath=nemo_checkpoint_filepath,
@@ -178,6 +197,39 @@ def __init__(
inference_batch_times_seqlen_threshold=inference_batch_times_seqlen_threshold,
)

def generate(
self,
prompts: List[str],
max_batch_size: int = 4,
inference_params: Optional[CommonInferenceParams] = None,
random_seed: Optional[int] = None,
) -> List[InferenceRequest]:
"""
Generates text based on the provided input prompts.
Args:
prompts (List[str]): A list of input strings.
max_batch_size (int): The maximum batch size used for inference.
inference_params (Optional[CommonInferenceParams]): Parameters for controlling the inference process.
random_seed (Optional[int]): A random seed for reproducibility.
Returns:
List[InferenceRequest]: A list containing the generated results.
"""
# TODO: This function doesn't account for parallelism settings currently

inference_params = inference_params or CommonInferenceParams()

results = inference.generate(
model=self.inference_wrapped_model,
tokenizer=self.mcore_tokenizer,
prompts=prompts,
max_batch_size=max_batch_size,
random_seed=random_seed,
inference_params=inference_params,
)
return list(results)

@property
def get_triton_input(self):
inputs = (
@@ -222,14 +274,7 @@ def triton_infer_fn(self, **inputs: np.ndarray):
return_log_probs=log_probs,
)

results = inference.generate(
model=self.inference_wrapped_model,
tokenizer=self.mcore_tokenizer,
prompts=prompts,
max_batch_size=max_batch_size,
random_seed=random_seed,
inference_params=inference_params,
)
results = self.generate(prompts, max_batch_size, inference_params, random_seed)

output_texts = [r.generated_text if text_only else r for r in results]
output_infer = {"sentences": cast_output(output_texts, np.bytes_)}
@@ -263,11 +308,14 @@ def __init__(
raise IMPORT_ERROR
if nemo_checkpoint_filepath is None and existing_model is None:
raise ValueError(
"MegatronLLMDeployable requires either a .nemo checkpoint filepath or an existing MegatronGPTModel, but both provided were None"
"MegatronLLMDeployable requires either a .nemo checkpoint filepath "
"or an existing MegatronGPTModel, but both provided were None."
)
if num_devices > 1:
LOGGER.warning(
"Creating a MegatronLLMDeployable with num_devices>1 will assume running with a PyTorch Lightning DDP-variant strategy, which will run the main script once per device. Make sure any user code is compatible with multiple executions!"
"Creating a MegatronLLMDeployable with num_devices > 1 will assume running with "
"a PyTorch Lightning DDP-variant strategy, which will run the main script once per device. "
"Make sure any user code is compatible with multiple executions!"
)

# if both existing_model and nemo_checkpoint_filepath are provided, existing_model will take precedence
@@ -292,14 +340,16 @@ def _load_from_nemo_checkpoint(self, nemo_checkpoint_filepath: str, num_devices:
# transformer_engine should always be true according to EricH, but GPT-2B model will fail if it is enabled
if not custom_config.transformer_engine:
LOGGER.warning(
"MegatronLLMDeployable expects model config transformer_engine=True, but this model has it =False. "
"Overriding it to =True, but this may break certain checkpoints converted on older Nemo versions. "
"MegatronLLMDeployable expects model config transformer_engine=True, but this model uses False. "
"Overriding it to True, but this may break certain checkpoints converted on older Nemo versions. "
"If your model breaks, please try re-converting the checkpoint on the current Nemo version."
)
custom_config.transformer_engine = True
# using multi-gpu for tensor parallelism directly for now, could do pipeline parallel instead or a combination
# using multi-gpu for tensor parallelism directly for now,
# could do pipeline parallel instead or a combination
custom_config.tensor_model_parallel_size = num_devices
# had to override these to make Nemotron3-22B work, see sample_sequence_batch() in text_generation_utils.py
# had to override these to make Nemotron3-22B work,
# see sample_sequence_batch() in text_generation_utils.py
custom_config.activations_checkpoint_granularity = None
custom_config.activations_checkpoint_method = None
# Models trained with TE < 1.10 and loaded with TE >= 1.10 require
@@ -398,7 +448,8 @@ def generate(self, inputs: List[str], length_params: LengthParam, sampling_param
distributed_rank = torch.distributed.get_rank()
if distributed_rank != 0:
raise ValueError(
f"Triton inference function should not be called on a thread with torch.distributed rank != 0, but this thread is rank {distributed_rank}"
"Triton inference function should not be called on a thread with "
f"torch.distributed rank != 0, but this thread is rank {distributed_rank}."
)
signal_value = ServerSync.SIGNAL.to_long_tensor()
torch.distributed.broadcast(signal_value, 0)
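Putting the new pieces together, the factory plus the extracted `generate()` helper can be driven roughly as follows; the checkpoint path, prompt, and sampling values are placeholders, and parallelism settings are not yet handled by `generate()`, per the TODO above:

```python
from megatron.core.inference.common_inference_params import CommonInferenceParams

from nemo.deploy.nlp.megatronllm_deployable import MegatronLLMDeploy

# Placeholder checkpoint path; a NeMo 2.0 checkpoint routes to MegatronLLMDeployableNemo2.
deployable = MegatronLLMDeploy.get_deployable(
    nemo_checkpoint_filepath="/tmp/nemo2_ckpt",
    num_devices=1,
    tensor_model_parallel_size=1,
)

results = deployable.generate(
    prompts=["What is the color of a banana?"],
    max_batch_size=4,
    inference_params=CommonInferenceParams(temperature=1.0, top_k=1, num_tokens_to_generate=32),
    random_seed=1234,
)
print(results[0].generated_text)  # InferenceRequest objects expose the decoded text
```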
6 changes: 3 additions & 3 deletions nemo/deploy/nlp/query_llm.py
@@ -13,7 +13,7 @@
# limitations under the License.

import time
from abc import ABC, abstractmethod
from abc import ABC

import numpy as np

@@ -141,7 +141,7 @@ def query_llm(
"object": "text_completion",
"created": int(time.time()),
"model": self.model_name,
"choices": [{"text": str(sentences)}],
"choices": [{"text": sentences}],
}
if log_probs_output is not None:
openai_response["log_probs"] = log_probs_output
@@ -297,7 +297,7 @@ def query_llm(
"object": "text_completion",
"created": int(time.time()),
"model": self.model_name,
"choices": [{"text": str(sentences)}],
"choices": [{"text": sentences}],
}
if output_generation_logits:
openai_response["choices"][0]["generation_logits"] = result_dict["generation_logits"]
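The net effect of dropping the `str()` cast is that `"choices"` now carries the decoded sentences directly, and generation logits are attached when requested. A rough, illustrative shape of the resulting response (all values below are placeholders, not taken from the diff):

```python
# Illustrative only: approximate response shape after this change.
openai_response = {
    "id": "cmpl-...",              # placeholder id
    "object": "text_completion",
    "created": 1736467200,         # int(time.time())
    "model": "megatron_model",     # placeholder model name
    # Previously this field held str(sentences); now the decoded sentences pass through unchanged.
    "choices": [{"text": ["The capital of France is Paris."]}],
}
# With output_generation_logits=True, the first choice additionally gets:
# openai_response["choices"][0]["generation_logits"] = result_dict["generation_logits"]
```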
2 changes: 2 additions & 0 deletions nemo/lightning/run/plugins.py
@@ -155,6 +155,7 @@ class NsysPlugin(run.Plugin):
end_step: int
ranks: Optional[list[int]] = None
nsys_trace: Optional[list[str]] = None
gen_shape: bool = False

def setup(self, task: run.Partial | run.Script, executor: run.Executor):
if isinstance(task, run.Partial):
@@ -163,6 +164,7 @@ def setup(self, task: run.Partial | run.Script, executor: run.Executor):
start_step=self.start_step,
end_step=self.end_step,
ranks=self.ranks or [0],
gen_shape=self.gen_shape,
)
callbacks: list[run.Config[Callback]] = [nsys_callback] # type: ignore
_merge_callbacks(task, callbacks=callbacks)
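The new `gen_shape` flag is simply forwarded to the underlying `NsysCallback`. A hypothetical way to opt in when building the profiling plugin (step numbers are placeholders):

```python
from nemo.lightning.run.plugins import NsysPlugin

nsys_plugin = NsysPlugin(
    start_step=100,   # placeholder
    end_step=110,     # placeholder
    ranks=[0],        # unchanged default when omitted
    gen_shape=True,   # new flag: record tensor shapes in the Nsys trace
)
# The plugin is then passed to the launcher alongside the task and executor,
# e.g. run.run(task, executor=executor, plugins=[nsys_plugin]) in a NeMo-Run script.
```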
6 changes: 6 additions & 0 deletions requirements/requirements_deploy.txt
@@ -0,0 +1,6 @@
fastapi
nvidia-pytriton
pydantic-settings
tensorstore==0.1.45
uvicorn
zarr
2 changes: 2 additions & 0 deletions requirements/requirements_infer.txt
@@ -1,3 +1,5 @@
# This is a copy of requirements_deploy.txt for a seamless rename 'infer' -> 'deploy'.
# TODO: Remove this file once it is not used in container build anywhere.
fastapi
nvidia-pytriton
pydantic-settings
3 changes: 2 additions & 1 deletion setup.py
@@ -77,6 +77,7 @@ def req_file(filename, folder="requirements"):
'slu': req_file("requirements_slu.txt"),
'multimodal': req_file("requirements_multimodal.txt"),
'audio': req_file("requirements_audio.txt"),
'deploy': req_file("requirements_deploy.txt"),
}


@@ -257,7 +258,7 @@ def finalize_options(self):
extras_require=extras_require,
# Add in any packaged data.
include_package_data=True,
exclude=['tools', 'tests', 'nemo.deploy', 'nemo.export'],
exclude=['tools', 'tests'],
package_data={'': ['*.tsv', '*.txt', '*.far', '*.fst', '*.cpp', 'Makefile']},
zip_safe=False,
# PyPI package information.
4 changes: 2 additions & 2 deletions tests/collections/llm/gpt/model/test_mixtral.py
@@ -79,5 +79,5 @@ def test_mixtral_config_8x22b():
assert config.hidden_size == 6144
assert config.num_attention_heads == 48
assert config.ffn_hidden_size == 16384
assert config.max_position_embeddings == 4096
assert config.seq_length == 4096
assert config.max_position_embeddings == 65536
assert config.seq_length == 65536
12 changes: 10 additions & 2 deletions tests/collections/llm/recipes/test_mixtral_8x22b.py
@@ -73,7 +73,11 @@ def test_pretrain_recipe(self, recipe_module):
assert recipe.trainer.__fn_or_cls__ == Trainer
assert isinstance(recipe.data, run.Config)
assert recipe.data.__fn_or_cls__ == MockDataModule
assert recipe.data.seq_length == 4096
assert isinstance(recipe.model.config, run.Config)
if recipe.model.config.__fn_or_cls__ == MixtralConfig8x22B:
assert recipe.data.seq_length == 65536
else:
assert recipe.data.seq_length == 4096
assert recipe.data.global_batch_size == 512
assert recipe.data.micro_batch_size == 1

@@ -118,8 +122,12 @@ def test_trainer_parallelism_options(self, recipe_module):
def test_model_config_parameters(self, recipe_module):
model_config = recipe_module.model()
mixtral_config = model_config.config
assert isinstance(mixtral_config, run.Config)
assert mixtral_config.num_layers == 56
assert mixtral_config.hidden_size == 6144
assert mixtral_config.num_attention_heads == 48
assert mixtral_config.seq_length == 4096
if mixtral_config.__fn_or_cls__ == MixtralConfig8x22B:
assert mixtral_config.seq_length == 65536
else:
assert mixtral_config.seq_length == 4096
assert mixtral_config.num_moe_experts == 8