In-framework deployment NeMo 2.0 nemo_export.py test #11749
Changes from 8 commits
```diff
@@ -15,14 +15,15 @@
 import logging
 from enum import IntEnum, auto
 from pathlib import Path
-from typing import List
+from typing import List, Optional

 import numpy as np
 import torch
 import torch.distributed
 import wrapt
-from pytorch_lightning.trainer.trainer import Trainer
+from lightning.pytorch.trainer.trainer import Trainer
+from megatron.core.inference.common_inference_params import CommonInferenceParams
+from megatron.core.inference.inference_request import InferenceRequest

 import nemo.lightning as nl
 from nemo.collections.llm import inference
```
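A side note on the Trainer import swap above: `pytorch_lightning` is the legacy namespace, and recent Lightning releases expose the same Trainer under `lightning.pytorch`. A minimal sketch, assuming Lightning 2.x is installed (the Trainer configuration shown is hypothetical):

```python
# Both namespaces resolve to the same Trainer class in Lightning 2.x;
# this PR standardizes on the newer `lightning.pytorch` spelling.
from lightning.pytorch.trainer.trainer import Trainer

trainer = Trainer(accelerator="gpu", devices=1)  # hypothetical configuration
```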
```diff
@@ -107,7 +108,7 @@ class MegatronLLMDeploy:

     @staticmethod
     def get_deployable(
-        nemo_checkpoint_filepath: str = None,
+        nemo_checkpoint_filepath: str,
         num_devices: int = 1,
         num_nodes: int = 1,
         tensor_model_parallel_size: int = 1,
```
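With this change the checkpoint path becomes a required argument instead of silently defaulting to `None`. A minimal call-site sketch; the import path and checkpoint path are assumptions, not taken from this diff:

```python
# Hypothetical call site: nemo_checkpoint_filepath must now be supplied explicitly.
from nemo.deploy.nlp import MegatronLLMDeploy  # import path is an assumption

deployable = MegatronLLMDeploy.get_deployable(
    nemo_checkpoint_filepath="/checkpoints/model.nemo",  # hypothetical path
    num_devices=1,
    num_nodes=1,
    tensor_model_parallel_size=1,
)
```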
```diff
@@ -178,6 +179,39 @@ def __init__(
             inference_batch_times_seqlen_threshold=inference_batch_times_seqlen_threshold,
         )

+    def generate(
+        self,
+        prompts: List[str],
+        max_batch_size: int = 4,
+        inference_params: Optional[CommonInferenceParams] = None,
+        random_seed: Optional[int] = None,
+    ) -> List[InferenceRequest]:
+        """
+        Generates text based on the provided input prompts.
+
+        Args:
+            prompts (List[str]): A list of input strings.
+            max_batch_size (int): The maximum batch size used for inference.
+            inference_params (Optional[CommonInferenceParams]): Parameters for controlling the inference process.
+            random_seed (Optional[int]): A random seed for reproducibility.
+
+        Returns:
+            List[InferenceRequest]: A list containing the generated results.
+        """
+        # TODO: This function doesn't account for parallelism settings currently
+
+        inference_params = inference_params or CommonInferenceParams()
+
+        results = inference.generate(
+            model=self.inference_wrapped_model,
+            tokenizer=self.mcore_tokenizer,
+            prompts=prompts,
+            max_batch_size=max_batch_size,
+            random_seed=random_seed,
+            inference_params=inference_params,
+        )
+        return list(results)
+
     @property
     def get_triton_input(self):
         inputs = (
```
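For context, here is how the new `generate` method might be called directly, bypassing Triton. This is a sketch under stated assumptions: `deployable` stands in for an already-constructed instance of the deployable class this diff modifies, and the sampling values are illustrative only.

```python
from megatron.core.inference.common_inference_params import CommonInferenceParams

# Hypothetical direct call to the helper added in this hunk.
requests = deployable.generate(
    prompts=["Deep learning is", "The capital of France is"],
    max_batch_size=2,
    inference_params=CommonInferenceParams(temperature=0.7, num_tokens_to_generate=32),
    random_seed=1234,
)
for request in requests:
    print(request.generated_text)  # each InferenceRequest carries its generated text
```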
```diff
@@ -222,14 +256,7 @@ def triton_infer_fn(self, **inputs: np.ndarray):
             return_log_probs=log_probs,
         )

-        results = inference.generate(
-            model=self.inference_wrapped_model,
-            tokenizer=self.mcore_tokenizer,
-            prompts=prompts,
-            max_batch_size=max_batch_size,
-            random_seed=random_seed,
-            inference_params=inference_params,
-        )
+        results = self.generate(prompts, max_batch_size, inference_params, random_seed)

         output_texts = [r.generated_text if text_only else r for r in results]
         output_infer = {"sentences": cast_output(output_texts, np.bytes_)}
```
```diff
@@ -141,7 +141,7 @@ def query_llm(
             "object": "text_completion",
             "created": int(time.time()),
             "model": self.model_name,
-            "choices": [{"text": str(sentences)}],
+            "choices": [{"text": sentences}],
         }
         if log_probs_output is not None:
             openai_response["log_probs"] = log_probs_output
```

Review comment on this change:

> Projecting a NumPy array to a string makes it hard to recover later for further processing (note missing comma …
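A minimal sketch of the reviewer's point, using only standard NumPy behavior: `str()` flattens an array into its display form, which is not valid JSON and is awkward to parse back, whereas keeping the array (or converting via `.tolist()`) preserves the structure.

```python
import numpy as np

sentences = np.array(["hello world", "general kenobi"])

# str() yields the array's repr-style rendering: ['hello world' 'general kenobi']
# -- note the missing commas, which make the original items hard to recover.
print(str(sentences))

# .tolist() keeps the structure intact and is JSON-serializable.
print(sentences.tolist())  # ['hello world', 'general kenobi']
```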
```diff
@@ -297,7 +297,7 @@ def query_llm(
             "object": "text_completion",
             "created": int(time.time()),
             "model": self.model_name,
-            "choices": [{"text": str(sentences)}],
+            "choices": [{"text": sentences}],
         }
         if output_generation_logits:
             openai_response["choices"][0]["generation_logits"] = result_dict["generation_logits"]
```
New file added by this PR (66 lines):

```python
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import json
import os

"""
Create a dataset with five Lambada test examples for functional testing. The output
file holds a JSON list of dictionaries, each with "text_before_last_word" and
"last_word" keys.
"""


def create_sample_lambada(output_file: str, overwrite: bool = False):
    """Create a JSON file with a few Lambada examples."""
    if os.path.isfile(output_file) and not overwrite:
        print(f"File {output_file} exists and overwrite flag is not set, so exiting.")
        return

    texts = [
        {
            "text_before_last_word": "In my palm is a clear stone , and inside it is a small ivory statuette . A guardian angel .\n\n\" Figured if you re going to be out at night getting hit by cars , you might as well have some backup .\"\n\n I look at him , feeling stunned . Like this is some sort of sign . But as I stare at Harlin , his mouth curved in a confident grin , I don t care about",
            "last_word": "signs",
        },
        {
            "text_before_last_word": "Give me a minute to change and I ll meet you at the docks .\" She d forced those words through her teeth .\n\n\" No need to change . We won t be that long .\"\n\n Shane gripped her arm and started leading her to the dock .\n\n\" I can make it there on my own ,",
            "last_word": "Shane",
        },
        {
            "text_before_last_word": "\" Only one source I know of that would be likely to cough up enough money to finance a phony sleep research facility and pay people big bucks to solve crimes in their dreams ,\" Farrell concluded dryly .\n\n\" What can I say ?\" Ellis unfolded his arms and widened his hands . \" Your tax dollars at work .\"\n\n Before Farrell could respond , Leila s voice rose from inside the house .\n\n\" No insurance ?\" she wailed . \" What do you mean you don t have any",
            "last_word": "insurance",
        },
        {
            "text_before_last_word": "Helen s heart broke a little in the face of Miss Mabel s selfless courage . She thought that because she was old , her life was of less value than the others . For all Helen knew , Miss Mabel had a lot more years to live than she did . \" Not going to happen ,\" replied",
            "last_word": "Helen",
        },
        {
            "text_before_last_word": "Preston had been the last person to wear those chains , and I knew what I d see and feel if they were slipped onto my skin the Reaper s unending hatred of me . I d felt enough of that emotion already in the amphitheater . I didn t want to feel anymore .\n\n\" Don t put those on me ,\" I whispered . \" Please .\"\n\n Sergei looked at me , surprised by my low , raspy please , but he put down the",
            "last_word": "chains",
        },
    ]

    print(f"Writing {len(texts)} example(s) to {output_file}...")
    os.makedirs(os.path.dirname(os.path.abspath(output_file)), exist_ok=True)
    with open(output_file, mode="w", encoding="utf-8") as f:
        json.dump(texts, f)
    print("OK.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser("Create a sample from Lambada test dataset.")
    parser.add_argument("--output_file", required=True, help="Output file name")
    parser.add_argument("--overwrite", action="store_true", help="Overwrite file if it exists")
    args = parser.parse_args()
    create_sample_lambada(args.output_file, args.overwrite)
```
Review comment:

> Would it make sense to embed this into setup.py? What reasons may speak against it?

Reply:

> Good idea. I can offer something like 398b0e7.