Merge pull request #445 from tomaarsen/feat/load_on_device
Allow 'device' on SetFitModel.from_pretrained()
tomaarsen authored Nov 24, 2023
2 parents 93c52dd + 6f06204 commit c41b7c3
Showing 3 changed files with 26 additions and 5 deletions.
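
With this change, `from_pretrained()` forwards a new `device` argument to the underlying SentenceTransformer body, and the classification head follows the body onto the same device. A minimal usage sketch, based on the tests added in this PR; the checkpoint is the one the tests use, and the "cuda" case assumes a CUDA-capable GPU is available:

import torch

from setfit import SetFitModel

# Pick an explicit device at load time; a plain string like "cpu"/"cuda" or a
# torch.device both work. AbsaModel.from_pretrained() accepts the same argument.
device = "cuda" if torch.cuda.is_available() else "cpu"

model = SetFitModel.from_pretrained(
    "sentence-transformers/paraphrase-albert-small-v2",
    device=device,
)

# Both the wrapper and the SentenceTransformer body end up on the requested device.
assert model.device.type == device
assert model.model_body.device.type == device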
src/setfit/modeling.py: 11 changes (6 additions, 5 deletions)
@@ -632,11 +632,12 @@ def _from_pretrained(
         multi_target_strategy: Optional[str] = None,
         use_differentiable_head: bool = False,
         normalize_embeddings: bool = False,
+        device: Optional[Union[torch.device, str]] = None,
         **model_kwargs,
     ) -> "SetFitModel":
-        model_body = SentenceTransformer(model_id, cache_folder=cache_dir, use_auth_token=token)
-        target_device = model_body._target_device
-        model_body.to(target_device)  # put `model_body` on the target device
+        model_body = SentenceTransformer(model_id, cache_folder=cache_dir, use_auth_token=token, device=device)
+        device = model_body._target_device
+        model_body.to(device)  # put `model_body` on the target device

         if os.path.isdir(model_id):
             if MODEL_HEAD_NAME in os.listdir(model_id):
@@ -671,7 +672,7 @@ def _from_pretrained(
         if model_head_file is not None:
             model_head = joblib.load(model_head_file)
             if isinstance(model_head, torch.nn.Module):
-                model_head.to(target_device)
+                model_head.to(device)
         else:
             head_params = model_kwargs.pop("head_params", {})
             if use_differentiable_head:
@@ -689,7 +690,7 @@ def _from_pretrained(
                 # - follow the `model_body`, put `model_head` on the target device
                 base_head_params = {
                     "in_features": model_body.get_sentence_embedding_dimension(),
-                    "device": target_device,
+                    "device": device,
                     "multitarget": use_multitarget,
                 }
                 model_head = SetFitHead(**{**head_params, **base_head_params})
tests/span/test_modeling.py: 10 changes (10 additions, 0 deletions)
@@ -7,6 +7,7 @@
 from setfit import AbsaModel
 from setfit.span.aspect_extractor import AspectExtractor
 from setfit.span.modeling import AspectModel, PolarityModel
+from tests.test_modeling import torch_cuda_available


 def test_loading():
@@ -84,3 +85,12 @@ def test_to(absa_model: AbsaModel) -> None:
     assert absa_model.device.type == "cpu"
     assert absa_model.aspect_model.device.type == "cpu"
     assert absa_model.polarity_model.device.type == "cpu"
+
+
+@torch_cuda_available
+@pytest.mark.parametrize("device", ["cpu", "cuda"])
+def test_load_model_on_device(device):
+    model = AbsaModel.from_pretrained("sentence-transformers/paraphrase-albert-small-v2", device=device)
+    assert model.device.type == device
+    assert model.polarity_model.device.type == device
+    assert model.aspect_model.device.type == device
tests/test_modeling.py: 10 changes (10 additions, 0 deletions)
@@ -246,3 +246,13 @@ def test_to_sentence_transformer_device_reset(use_differentiable_head):

     model.model_body.encode("This is a test sample to encode")
     assert model.model_body.device == torch.device("cpu")
+
+
+@torch_cuda_available
+@pytest.mark.parametrize("device", ["cpu", "cuda"])
+def test_load_model_on_device(device):
+    model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-albert-small-v2", device=device)
+    assert model.device.type == device
+    assert model.model_body.device.type == device
+
+    model.model_body.encode("This is a test sample to encode")