refactor(trainer): adapt to new compute_loss signature
dacorvo committed Jan 2, 2025
1 parent fa0d0b9 commit c2529b1
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions optimum/neuron/trainers.py
@@ -401,14 +401,14 @@ def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[s
         self._update_input_specs_in_model_cache_entry(input_specs_for_cache_entry)
         return inputs

-    def compute_loss(self, model, inputs, return_outputs: bool = False):
+    def compute_loss(self, model, inputs, num_items_in_batch):
         from neuronx_distributed.pipeline import NxDPPModel

         if isinstance(model, NxDPPModel):
             inputs = self._prepare_inputs(inputs)
             loss = model.run_train(**inputs)
         else:
-            loss = super().compute_loss(model, inputs, return_outputs=return_outputs)
+            loss = super().compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
         return loss

     def autocast_smart_context_manager(self, cache_enabled: Optional[bool] = True):
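
For context: the upstream transformers Trainer gained a num_items_in_batch argument on compute_loss (around v4.46, to normalize the loss over the true token count under gradient accumulation), which is the signature change this commit adapts to. Below is a minimal sketch, not the commit's code, of a custom subclass that stays compatible with the newer signature; MyTrainer is a hypothetical name for illustration.

# Minimal sketch (assumes transformers >= 4.46, where compute_loss accepts
# num_items_in_batch). MyTrainer is a hypothetical subclass for illustration.
from typing import Optional

from transformers import Trainer


class MyTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs: bool = False,
                     num_items_in_batch: Optional[int] = None):
        # Forward the extra argument unchanged; swallowing it would break
        # the base class's per-token loss averaging under gradient
        # accumulation on newer transformers versions.
        return super().compute_loss(
            model,
            inputs,
            return_outputs=return_outputs,
            num_items_in_batch=num_items_in_batch,
        )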
