Skip to content

Commit

Permalink
Only use self.processing_class with transformers v4.46.0+
Browse files Browse the repository at this point in the history
and use *args, **kwargs in log; seems safer in case of future changes.
  • Loading branch information
tomaarsen committed Jan 10, 2025
1 parent cb4e803 commit 8c53ce8
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions src/setfit/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@
from sentence_transformers.model_card import ModelCardCallback as STModelCardCallback
from sentence_transformers.training_args import BatchSamplers
from sklearn.preprocessing import LabelEncoder
from packaging.version import parse as parse_version
from torch import nn
from transformers.integrations import CodeCarbonCallback
from transformers.trainer_callback import IntervalStrategy, TrainerCallback
from transformers.trainer_utils import HPSearchBackend, default_compute_objective, number_of_arguments, set_seed
from transformers.utils.import_utils import is_in_notebook
from transformers import __version__ as transformers_version

from setfit.model_card import ModelCardCallback

Expand Down Expand Up @@ -72,7 +74,7 @@ def overwritten_call_event(self, event, args, state, control, **kwargs):
model=self.setfit_model,
st_model=self.model,
st_args=args,
tokenizer=self.processing_class,
tokenizer=self.processing_class if parse_version(transformers_version) >= parse_version("4.46.0") else self.tokenizer,
optimizer=self.optimizer,
lr_scheduler=self.lr_scheduler,
train_dataloader=self.train_dataloader,
Expand Down Expand Up @@ -156,9 +158,9 @@ def _set_logs_prefix(self, logs_prefix: str) -> None:
"""
self.logs_prefix = logs_prefix

def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
def log(self, logs: Dict[str, float], *args, **kwargs) -> None:
logs = {f"{self.logs_prefix}_{k}" if k == "loss" else k: v for k, v in logs.items()}
return super().log(logs, start_time)
return super().log(logs, *args, **kwargs)

def evaluate(
self,
Expand Down

0 comments on commit 8c53ce8

Please sign in to comment.