diff --git a/tests/span/test_modeling.py b/tests/span/test_modeling.py index 0bc3ccb8..81675aba 100644 --- a/tests/span/test_modeling.py +++ b/tests/span/test_modeling.py @@ -7,6 +7,7 @@ from setfit import AbsaModel from setfit.span.aspect_extractor import AspectExtractor from setfit.span.modeling import AspectModel, PolarityModel +from tests.test_modeling import torch_cuda_available def test_loading(): @@ -84,3 +85,12 @@ def test_to(absa_model: AbsaModel) -> None: assert absa_model.device.type == "cpu" assert absa_model.aspect_model.device.type == "cpu" assert absa_model.polarity_model.device.type == "cpu" + + +@torch_cuda_available +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +def test_load_model_on_device(device): + model = AbsaModel.from_pretrained("sentence-transformers/paraphrase-albert-small-v2", device=device) + assert model.device.type == device + assert model.polarity_model.device.type == device + assert model.aspect_model.device.type == device