diff --git a/nemo/export/tensorrt_llm.py b/nemo/export/tensorrt_llm.py index 3c6946f129a7..f2bb9d36b377 100644 --- a/nemo/export/tensorrt_llm.py +++ b/nemo/export/tensorrt_llm.py @@ -1177,7 +1177,7 @@ def triton_infer_fn(self, **inputs: np.ndarray): lora_uids = np.char.decode(inputs.pop("lora_uids").astype("bytes"), encoding="utf-8") infer_input["lora_uids"] = lora_uids[0].tolist() if "output_generation_logits" in inputs: - generation_logits_available = inputs["output_generation_logits"] + generation_logits_available = inputs["output_generation_logits"][0][0] infer_input["output_generation_logits"] = inputs.pop("output_generation_logits")[0][0] if generation_logits_available: