
Merge pull request #463 from tomaarsen/update/sbert_2.3_support
Prepare SetFit for upcoming 2.3.0 release of SentenceTransformers
tomaarsen authored Jan 11, 2024
2 parents 3e3d828 + f3f8666 commit f387387
Showing 3 changed files with 57 additions and 14 deletions.
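The entire change gates behavior on the installed sentence-transformers version, which is why `packaging` is added as a dependency in setup.py below. A minimal sketch of that version gate (assuming `packaging` and `sentence-transformers` are installed; the constant name is illustrative):

    from packaging.version import Version, parse
    from sentence_transformers import __version__ as sentence_transformers_version

    # True when the installed sentence-transformers exposes the 2.3.0+ API:
    # a reliable `device` property and the `token=` / `trust_remote_code=` keyword arguments.
    IS_ST_230_OR_NEWER = parse(sentence_transformers_version) >= Version("2.3.0")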
1 change: 1 addition & 0 deletions setup.py
@@ -16,6 +16,7 @@
"evaluate>=0.3.0",
"huggingface_hub>=0.13.0",
"scikit-learn",
"packaging",
]
ABSA_REQUIRE = ["spacy"]
QUALITY_REQUIRE = ["black", "flake8", "isort", "tabulate"]
52 changes: 46 additions & 6 deletions src/setfit/modeling.py
@@ -19,7 +19,10 @@
import torch
from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
from huggingface_hub.utils import validate_hf_hub_args
from sentence_transformers import SentenceTransformer, models
from packaging.version import Version, parse
from sentence_transformers import SentenceTransformer
from sentence_transformers import __version__ as sentence_transformers_version
from sentence_transformers import models
from sklearn.linear_model import LogisticRegression
from sklearn.multiclass import OneVsRestClassifier
from sklearn.multioutput import ClassifierChain, MultiOutputClassifier
@@ -215,6 +218,7 @@ class SetFitModel(PyTorchModelHubMixin):
normalize_embeddings: bool = False
labels: Optional[List[str]] = None
model_card_data: Optional[SetFitModelCardData] = field(default_factory=SetFitModelCardData)
sentence_transformers_kwargs: Dict = field(default_factory=dict, repr=False)

attributes_to_save: Set[str] = field(
init=False, repr=False, default_factory=lambda: {"normalize_embeddings", "labels"}
@@ -605,6 +609,9 @@ def device(self) -> torch.device:
Returns:
torch.device: The device that the model is on.
"""
# SentenceTransformers.device is reliable from 2.3.0 onwards
if parse(sentence_transformers_version) >= Version("2.3.0"):
return self.model_body.device
return self.model_body._target_device
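For illustration, the gated lookup above as a standalone helper (a sketch only; `_target_device` is the private attribute older SentenceTransformer releases use to track the intended device):

    import torch
    from packaging.version import Version, parse
    from sentence_transformers import SentenceTransformer
    from sentence_transformers import __version__ as sentence_transformers_version

    def resolve_device(model_body: SentenceTransformer) -> torch.device:
        # SentenceTransformer.device is reliable from 2.3.0 onwards;
        # earlier releases only track the intended device in `_target_device`.
        if parse(sentence_transformers_version) >= Version("2.3.0"):
            return model_body.device
        return model_body._target_device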

def to(self, device: Union[str, torch.device]) -> "SetFitModel":
@@ -622,9 +629,10 @@ def to(self, device: Union[str, torch.device]) -> "SetFitModel":
Returns:
SetFitModel: Returns the original model, but now on the desired device.
"""
# Note that we must also set _target_device, or any SentenceTransformer.fit() call will reset
# the body location
self.model_body._target_device = device if isinstance(device, torch.device) else torch.device(device)
# Note that we must also set _target_device with sentence-transformers <2.3.0,
# or any SentenceTransformer.fit() call will reset the body location
if parse(sentence_transformers_version) < Version("2.3.0"):
self.model_body._target_device = device if isinstance(device, torch.device) else torch.device(device)
self.model_body = self.model_body.to(device)

if self.has_differentiable_head:
@@ -701,10 +709,37 @@ def _from_pretrained(
multi_target_strategy: Optional[str] = None,
use_differentiable_head: bool = False,
device: Optional[Union[torch.device, str]] = None,
trust_remote_code: bool = False,
**model_kwargs,
) -> "SetFitModel":
model_body = SentenceTransformer(model_id, cache_folder=cache_dir, use_auth_token=token, device=device)
device = model_body._target_device
sentence_transformers_kwargs = {
"cache_folder": cache_dir,
"use_auth_token": token,
"device": device,
"trust_remote_code": trust_remote_code,
}
if parse(sentence_transformers_version) >= Version("2.3.0"):
sentence_transformers_kwargs = {
"cache_folder": cache_dir,
"token": token,
"device": device,
"trust_remote_code": trust_remote_code,
}
else:
if trust_remote_code:
raise ValueError(
"The `trust_remote_code` argument is only supported for `sentence-transformers` >= 2.3.0."
)
sentence_transformers_kwargs = {
"cache_folder": cache_dir,
"use_auth_token": token,
"device": device,
}
model_body = SentenceTransformer(model_id, **sentence_transformers_kwargs)
if parse(sentence_transformers_version) >= Version("2.3.0"):
device = model_body.device
else:
device = model_body._target_device
model_body.to(device) # put `model_body` on the target device
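A standalone sketch of the keyword-argument selection above (names mirror the diff; per this change, sentence-transformers 2.3.0 accepts `token` and `trust_remote_code`, while older releases only understand `use_auth_token`):

    from packaging.version import Version, parse
    from sentence_transformers import __version__ as sentence_transformers_version

    def build_sentence_transformers_kwargs(cache_dir, token, device, trust_remote_code=False):
        if parse(sentence_transformers_version) >= Version("2.3.0"):
            # Newer releases take `token` and can load custom modeling code from the Hub.
            return {
                "cache_folder": cache_dir,
                "token": token,
                "device": device,
                "trust_remote_code": trust_remote_code,
            }
        if trust_remote_code:
            raise ValueError("`trust_remote_code` requires sentence-transformers >= 2.3.0.")
        return {"cache_folder": cache_dir, "use_auth_token": token, "device": device}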

# Try to load a SetFit config file
@@ -827,6 +862,7 @@ def _from_pretrained(
model_head=model_head,
multi_target_strategy=multi_target_strategy,
model_card_data=model_card_data,
sentence_transformers_kwargs=sentence_transformers_kwargs,
**model_kwargs,
)

@@ -851,6 +887,10 @@ def _from_pretrained(
Whether to apply normalization on the embeddings produced by the Sentence Transformer body.
device (`Union[torch.device, str]`, *optional*):
The device on which to load the SetFit model, e.g. `"cuda:0"`, `"mps"` or `torch.device("cuda")`.
trust_remote_code (`bool`, defaults to `False`): Whether or not to allow for custom Sentence Transformers
models defined on the Hub in their own modeling files. This option should only be set to True for
repositories you trust and in which you have read the code, as it will execute code present on
the Hub on your local machine. Defaults to False.
Example::
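A hedged usage sketch of the loading path documented above (the checkpoint name is only illustrative; `trust_remote_code=True` additionally requires sentence-transformers >= 2.3.0 and a repository you trust):

    from setfit import SetFitModel

    # `trust_remote_code` only matters for bodies that ship custom modeling code on the Hub.
    model = SetFitModel.from_pretrained(
        "sentence-transformers/paraphrase-mpnet-base-v2",
        trust_remote_code=False,
    )
    model = model.to("cpu")  # on sentence-transformers <2.3.0 this also updates model_body._target_device
    print(model.device)      # resolved through the version-gated property shown earlier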
18 changes: 10 additions & 8 deletions src/setfit/trainer.py
@@ -574,8 +574,8 @@ def _train_sentence_transformer(
if args.use_amp:
scaler = torch.cuda.amp.GradScaler()

model_body.to(model_body._target_device)
loss_func.to(model_body._target_device)
model_body.to(self.model.device)
loss_func.to(self.model.device)

# Use smart batching
train_dataloader.collate_fn = model_body.smart_batching_collate
@@ -623,8 +623,8 @@ def _train_sentence_transformer(
data = next(data_iterator)

features, labels = data
labels = labels.to(model_body._target_device)
features = list(map(lambda batch: batch_to_device(batch, model_body._target_device), features))
labels = labels.to(self.model.device)
features = list(map(lambda batch: batch_to_device(batch, self.model.device), features))

if args.use_amp:
with autocast():
@@ -671,10 +671,12 @@ def _train_sentence_transformer(
step_to_load = dir_name[5:]
logger.info(f"Loading best SentenceTransformer model from step {step_to_load}.")
self.model.model_card_data.set_best_model_step(int(step_to_load))
sentence_transformer_kwargs = self.model.sentence_transformers_kwargs
sentence_transformer_kwargs["device"] = self.model.device
self.model.model_body = SentenceTransformer(
self.state.best_model_checkpoint, device=model_body._target_device
self.state.best_model_checkpoint, **sentence_transformer_kwargs
)
self.model.model_body.to(model_body._target_device)
self.model.model_body.to(self.model.device)
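A minimal sketch of the reload step above, assuming a trained SetFitModel named `model` and a hypothetical checkpoint path; the stored `sentence_transformers_kwargs` are reused so the body is re-created with the same options it was originally loaded with:

    from sentence_transformers import SentenceTransformer

    reload_kwargs = dict(model.sentence_transformers_kwargs)  # `model`: a SetFitModel (assumed)
    reload_kwargs["device"] = model.device                     # pin to the model's current device
    model.model_body = SentenceTransformer("checkpoints/step_500", **reload_kwargs)  # hypothetical path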

# Ensure logging the speed metrics
num_train_samples = self.state.max_steps * args.embedding_batch_size # * args.gradient_accumulation_steps
@@ -734,8 +736,8 @@ def _evaluate_with_loss(
tqdm(iter(eval_dataloader), total=eval_steps, leave=False, disable=not args.show_progress_bar), start=1
):
features, labels = data
labels = labels.to(model_body._target_device)
features = list(map(lambda batch: batch_to_device(batch, model_body._target_device), features))
labels = labels.to(self.model.device)
features = list(map(lambda batch: batch_to_device(batch, self.model.device), features))

if args.use_amp:
with autocast():
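As an illustration of the trainer-side device handling above, a sketch that moves a tokenized batch with `batch_to_device` from `sentence_transformers.util` (the model name and sentences are placeholders):

    import torch
    from sentence_transformers import SentenceTransformer
    from sentence_transformers.util import batch_to_device

    model_body = SentenceTransformer("sentence-transformers/paraphrase-mpnet-base-v2")
    device = model_body.device  # public property; on <2.3.0 SetFit falls back to _target_device

    features = model_body.tokenize(["a training sentence", "another one"])
    features = batch_to_device(features, device)  # move every tensor in the feature dict
    labels = torch.tensor([0, 1]).to(device)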
