diff --git a/optimum/neuron/trainers.py b/optimum/neuron/trainers.py
index 6d9f25348..6f5f04afb 100755
--- a/optimum/neuron/trainers.py
+++ b/optimum/neuron/trainers.py
@@ -401,14 +401,14 @@ def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[s
         self._update_input_specs_in_model_cache_entry(input_specs_for_cache_entry)
         return inputs
 
-    def compute_loss(self, model, inputs, return_outputs: bool = False):
+    def compute_loss(self, model, inputs, num_items_in_batch):
         from neuronx_distributed.pipeline import NxDPPModel
 
         if isinstance(model, NxDPPModel):
             inputs = self._prepare_inputs(inputs)
             loss = model.run_train(**inputs)
         else:
-            loss = super().compute_loss(model, inputs, return_outputs=return_outputs)
+            loss = super().compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
         return loss
 
     def autocast_smart_context_manager(self, cache_enabled: Optional[bool] = True):