Skip to content

Commit

Permalink
Fix SetFitModel: not a dataclass, not a PyTorchModelHubMixin (#505)
Browse files Browse the repository at this point in the history
* Fix SetFitModel: not a dataclass, not a PyTorchModelHubMixin

* adapt SpanSetFitModel and PolarityModel

* super

* fix tests ?

* fix tests

* Apply formatting

---------

Co-authored-by: Tom Aarsen <[email protected]>
  • Loading branch information
Wauplin and tomaarsen authored Sep 12, 2024
1 parent 6167a03 commit 3904e53
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 35 deletions.
40 changes: 23 additions & 17 deletions src/setfit/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,14 @@
import os
import tempfile
import warnings
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List, Literal, Optional, Set, Tuple, Union

import joblib
import numpy as np
import requests
import torch
from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
from huggingface_hub import ModelHubMixin, hf_hub_download
from huggingface_hub.utils import validate_hf_hub_args
from packaging.version import Version, parse
from sentence_transformers import SentenceTransformer
Expand Down Expand Up @@ -189,8 +188,7 @@ def __repr__(self) -> str:
return "SetFitHead({})".format(self.get_config_dict())


@dataclass
class SetFitModel(PyTorchModelHubMixin):
class SetFitModel(ModelHubMixin):
"""A SetFit model with integration to the [Hugging Face Hub](https://huggingface.co).
Example::
Expand All @@ -205,19 +203,27 @@ class SetFitModel(PyTorchModelHubMixin):
['positive', 'negative', 'negative']
"""

model_body: Optional[SentenceTransformer] = None
model_head: Optional[Union[SetFitHead, LogisticRegression]] = None
multi_target_strategy: Optional[str] = None
normalize_embeddings: bool = False
labels: Optional[List[str]] = None
model_card_data: Optional[SetFitModelCardData] = field(default_factory=SetFitModelCardData)
sentence_transformers_kwargs: Dict = field(default_factory=dict, repr=False)

attributes_to_save: Set[str] = field(
init=False, repr=False, default_factory=lambda: {"normalize_embeddings", "labels"}
)

def __post_init__(self):
def __init__(
self,
model_body: Optional[SentenceTransformer] = None,
model_head: Optional[Union[SetFitHead, LogisticRegression]] = None,
multi_target_strategy: Optional[str] = None,
normalize_embeddings: bool = False,
labels: Optional[List[str]] = None,
model_card_data: Optional[SetFitModelCardData] = None,
sentence_transformers_kwargs: Optional[Dict] = None,
**kwargs,
) -> None:
super(SetFitModel, self).__init__()
self.model_body = model_body
self.model_head = model_head
self.multi_target_strategy = multi_target_strategy
self.normalize_embeddings = normalize_embeddings
self.labels = labels
self.model_card_data = model_card_data or SetFitModelCardData()
self.sentence_transformers_kwargs = sentence_transformers_kwargs or {}

self.attributes_to_save: Set[str] = {"normalize_embeddings", "labels"}
self.model_card_data.register_model(self)

@property
Expand Down
30 changes: 16 additions & 14 deletions src/setfit/span/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
import tempfile
import types
from collections import defaultdict
from dataclasses import dataclass, field
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union

import torch
from datasets import Dataset
Expand All @@ -25,18 +25,19 @@
logger = logging.get_logger(__name__)


@dataclass
class SpanSetFitModel(SetFitModel):
spacy_model: str = "en_core_web_lg"
span_context: int = 0

attributes_to_save: Set[str] = field(
init=False,
repr=False,
default_factory=lambda: {"normalize_embeddings", "labels", "span_context", "spacy_model"},
)
def __init__(
self,
spacy_model: str = "en_core_web_lg",
span_context: int = 0,
**kwargs,
):
super().__init__(**kwargs)
self.spacy_model = spacy_model
self.span_context = span_context
self.attributes_to_save = {"normalize_embeddings", "labels", "span_context", "spacy_model"}

def prepend_aspects(self, docs: List["Doc"], aspects_list: List[List[slice]]) -> List[str]:
def prepend_aspects(self, docs: List["Doc"], aspects_list: List[List[slice]]) -> Iterable[str]:
for doc, aspects in zip(docs, aspects_list):
for aspect_slice in aspects:
aspect = doc[max(aspect_slice.start - self.span_context, 0) : aspect_slice.stop + self.span_context]
Expand Down Expand Up @@ -137,9 +138,10 @@ def __call__(self, docs: List["Doc"], aspects_list: List[List[slice]]) -> List[b
AspectModel.from_pretrained = types.MethodType(AspectModel.from_pretrained.__func__, AspectModel)


@dataclass
class PolarityModel(SpanSetFitModel):
span_context: int = 3
def __init__(self, span_context: int = 3, **kwargs):
super().__init__(**kwargs)
self.span_context = span_context


PolarityModel.from_pretrained = types.MethodType(PolarityModel.from_pretrained.__func__, PolarityModel)
Expand Down
4 changes: 2 additions & 2 deletions tests/span/aspect_model_card_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
- sentence-transformers
- text-classification
- generated_from_setfit_trainer
base_model: sentence-transformers/paraphrase-albert-small-v2
metrics:
- accuracy
widget:
Expand All @@ -31,8 +32,7 @@
ram_total_size: [\d\.]+
hours_used: [\d\.]+
( hardware_used: .+
)?base_model: sentence-transformers/paraphrase-albert-small-v2
model-index:
)?model-index:
- name: SetFit Aspect Model with sentence-transformers\/paraphrase-albert-small-v2
results:
- task:
Expand Down
4 changes: 2 additions & 2 deletions tests/span/polarity_model_card_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
- sentence-transformers
- text-classification
- generated_from_setfit_trainer
base_model: sentence-transformers/paraphrase-albert-small-v2
metrics:
- accuracy
widget:
Expand All @@ -31,8 +32,7 @@
ram_total_size: [\d\.]+
hours_used: [\d\.]+
( hardware_used: .+
)?base_model: sentence-transformers/paraphrase-albert-small-v2
model-index:
)?model-index:
- name: SetFit Polarity Model with sentence-transformers\/paraphrase-albert-small-v2
results:
- task:
Expand Down

0 comments on commit 3904e53

Please sign in to comment.