Skip to content

Commit

Permalink
Add tests for SetFitABSA as well
Browse files Browse the repository at this point in the history
  • Loading branch information
tomaarsen committed Nov 24, 2023
1 parent 44daad4 commit 6f06204
Showing 1 changed file with 10 additions and 0 deletions.
10 changes: 10 additions & 0 deletions tests/span/test_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from setfit import AbsaModel
from setfit.span.aspect_extractor import AspectExtractor
from setfit.span.modeling import AspectModel, PolarityModel
from tests.test_modeling import torch_cuda_available


def test_loading():
Expand Down Expand Up @@ -84,3 +85,12 @@ def test_to(absa_model: AbsaModel) -> None:
assert absa_model.device.type == "cpu"
assert absa_model.aspect_model.device.type == "cpu"
assert absa_model.polarity_model.device.type == "cpu"


@torch_cuda_available
@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_load_model_on_device(device):
model = AbsaModel.from_pretrained("sentence-transformers/paraphrase-albert-small-v2", device=device)
assert model.device.type == device
assert model.polarity_model.device.type == device
assert model.aspect_model.device.type == device

0 comments on commit 6f06204

Please sign in to comment.