diff --git a/src/setfit/modeling.py b/src/setfit/modeling.py index e1d85809..ec0d7a3c 100644 --- a/src/setfit/modeling.py +++ b/src/setfit/modeling.py @@ -632,11 +632,12 @@ def _from_pretrained( multi_target_strategy: Optional[str] = None, use_differentiable_head: bool = False, normalize_embeddings: bool = False, + device: Optional[Union[torch.device, str]] = None, **model_kwargs, ) -> "SetFitModel": - model_body = SentenceTransformer(model_id, cache_folder=cache_dir, use_auth_token=token) - target_device = model_body._target_device - model_body.to(target_device) # put `model_body` on the target device + model_body = SentenceTransformer(model_id, cache_folder=cache_dir, use_auth_token=token, device=device) + device = model_body._target_device + model_body.to(device) # put `model_body` on the target device if os.path.isdir(model_id): if MODEL_HEAD_NAME in os.listdir(model_id): @@ -671,7 +672,7 @@ def _from_pretrained( if model_head_file is not None: model_head = joblib.load(model_head_file) if isinstance(model_head, torch.nn.Module): - model_head.to(target_device) + model_head.to(device) else: head_params = model_kwargs.pop("head_params", {}) if use_differentiable_head: @@ -689,7 +690,7 @@ def _from_pretrained( # - follow the `model_body`, put `model_head` on the target device base_head_params = { "in_features": model_body.get_sentence_embedding_dimension(), - "device": target_device, + "device": device, "multitarget": use_multitarget, } model_head = SetFitHead(**{**head_params, **base_head_params}) diff --git a/tests/span/test_modeling.py b/tests/span/test_modeling.py index 0bc3ccb8..81675aba 100644 --- a/tests/span/test_modeling.py +++ b/tests/span/test_modeling.py @@ -7,6 +7,7 @@ from setfit import AbsaModel from setfit.span.aspect_extractor import AspectExtractor from setfit.span.modeling import AspectModel, PolarityModel +from tests.test_modeling import torch_cuda_available def test_loading(): @@ -84,3 +85,12 @@ def test_to(absa_model: AbsaModel) -> None: assert absa_model.device.type == "cpu" assert absa_model.aspect_model.device.type == "cpu" assert absa_model.polarity_model.device.type == "cpu" + + +@torch_cuda_available +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +def test_load_model_on_device(device): + model = AbsaModel.from_pretrained("sentence-transformers/paraphrase-albert-small-v2", device=device) + assert model.device.type == device + assert model.polarity_model.device.type == device + assert model.aspect_model.device.type == device diff --git a/tests/test_modeling.py b/tests/test_modeling.py index a5e279f6..71c683a1 100644 --- a/tests/test_modeling.py +++ b/tests/test_modeling.py @@ -246,3 +246,13 @@ def test_to_sentence_transformer_device_reset(use_differentiable_head): model.model_body.encode("This is a test sample to encode") assert model.model_body.device == torch.device("cpu") + + +@torch_cuda_available +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +def test_load_model_on_device(device): + model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-albert-small-v2", device=device) + assert model.device.type == device + assert model.model_body.device.type == device + + model.model_body.encode("This is a test sample to encode")