From c2529b113a9b7afa20f4233004d6563309b48aad Mon Sep 17 00:00:00 2001 From: David Corvoysier Date: Thu, 2 Jan 2025 16:01:37 +0000 Subject: [PATCH] refactor(trainer): adapt to new compute_loss signature --- optimum/neuron/trainers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/optimum/neuron/trainers.py b/optimum/neuron/trainers.py index 6d9f25348..6f5f04afb 100755 --- a/optimum/neuron/trainers.py +++ b/optimum/neuron/trainers.py @@ -401,14 +401,14 @@ def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[s self._update_input_specs_in_model_cache_entry(input_specs_for_cache_entry) return inputs - def compute_loss(self, model, inputs, return_outputs: bool = False): + def compute_loss(self, model, inputs, num_items_in_batch): from neuronx_distributed.pipeline import NxDPPModel if isinstance(model, NxDPPModel): inputs = self._prepare_inputs(inputs) loss = model.run_train(**inputs) else: - loss = super().compute_loss(model, inputs, return_outputs=return_outputs) + loss = super().compute_loss(model, inputs, num_items_in_batch=num_items_in_batch) return loss def autocast_smart_context_manager(self, cache_enabled: Optional[bool] = True):