Skip to content

Commit

Permalink
Firm up max_pairs limit in `generate_pairs` and `generate_multilabel_pairs` (#549)
Browse files Browse the repository at this point in the history

* Check if max pairs limit reached in `generate_pairs` and `generate_multilabel_pairs`

* Resolve merge mistake

* Run 'make style'

* Rewrite to be simpler/clearer; should be equivalent

* Update model card patterns with new max_pairs

---------

Co-authored-by: Tom Aarsen <[email protected]>
  • Loading branch information
OscarRunsCode and tomaarsen authored Sep 18, 2024
1 parent edee867 commit ac0e8bc
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 15 deletions.
30 changes: 21 additions & 9 deletions src/setfit/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,21 +89,33 @@ def __init__(

def generate_pairs(self) -> None:
    """Fill ``self.pos_pairs`` and ``self.neg_pairs`` from ``self.sentence_labels``.

    Each item yielded by ``shuffle_combinations`` is unpacked as two
    ``(text, label)`` entries. Entries with equal labels form a positive
    pair (``"label": 1.0``); all others form a negative pair
    (``"label": 0.0``).

    ``self.max_pairs`` caps each list independently; ``-1`` disables the
    cap. Once a list has reached the cap, further pairs of that kind are
    dropped, and iteration stops as soon as both lists are full.
    """
    for (_text, _label), (text, label) in shuffle_combinations(self.sentence_labels):
        is_positive_full = self.max_pairs != -1 and len(self.pos_pairs) >= self.max_pairs
        is_negative_full = self.max_pairs != -1 and len(self.neg_pairs) >= self.max_pairs

        # Nothing more can be stored once both buckets hit the limit.
        if is_positive_full and is_negative_full:
            break

        is_positive = _label == label

        if is_positive:
            if not is_positive_full:
                self.pos_pairs.append({"sentence_1": _text, "sentence_2": text, "label": 1.0})
        elif not is_negative_full:
            self.neg_pairs.append({"sentence_1": _text, "sentence_2": text, "label": 0.0})

def generate_multilabel_pairs(self) -> None:
    """Fill ``self.pos_pairs`` and ``self.neg_pairs`` for multilabel data.

    Each item yielded by ``shuffle_combinations`` is unpacked as two
    ``(text, label)`` entries. Two entries form a positive pair
    (``"label": 1.0``) when their labels share at least one set position,
    and a negative pair (``"label": 0.0``) otherwise.

    ``self.max_pairs`` caps each list independently; ``-1`` disables the
    cap. Once a list has reached the cap, further pairs of that kind are
    dropped, and iteration stops as soon as both lists are full.
    """
    for (_text, _label), (text, label) in shuffle_combinations(self.sentence_labels):
        is_positive_full = self.max_pairs != -1 and len(self.pos_pairs) >= self.max_pairs
        is_negative_full = self.max_pairs != -1 and len(self.neg_pairs) >= self.max_pairs

        # Nothing more can be stored once both buckets hit the limit.
        if is_positive_full and is_negative_full:
            break

        # logical_and checks if labels are both set for each class
        is_positive = any(np.logical_and(_label, label))

        if is_positive:
            if not is_positive_full:
                self.pos_pairs.append({"sentence_1": _text, "sentence_2": text, "label": 1.0})
        elif not is_negative_full:
            self.neg_pairs.append({"sentence_1": _text, "sentence_2": text, "label": 0.0})

def get_positive_pairs(self) -> List[Dict[str, Union[str, float]]]:
Expand Down
4 changes: 2 additions & 2 deletions tests/model_card_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,8 @@
- load_best_model_at_end: False
### Training Results
\| Epoch \| Step \| Training Loss \| Validation Loss \|
\|:------:\|:----:\|:-------------:\|:---------------:\|
\| Epoch \| Step \| Training Loss \| Validation Loss \|
\|:-----:\|:----:\|:-------------:\|:---------------:\|
(\| [\d\.]+ +\| [\d\.]+ +\| [\d\.]+ +\| [\d\.]+ +\|\n)+
### Environmental Impact
Carbon emissions were measured using \[CodeCarbon\]\(https://github.com/mlco2/codecarbon\)\.
Expand Down
4 changes: 2 additions & 2 deletions tests/span/aspect_model_card_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,8 @@
- load_best_model_at_end: False
### Training Results
\| Epoch \| Step \| Training Loss \| Validation Loss \|
\|:------:\|:----:\|:-------------:\|:---------------:\|
\| Epoch \| Step \| Training Loss \| Validation Loss \|
\|:-----:\|:----:\|:-------------:\|:---------------:\|
(\| [\d\.]+ +\| [\d\.]+ +\| [\d\.]+ +\| [\d\.]+ +\|\n)+
### Environmental Impact
Carbon emissions were measured using \[CodeCarbon\]\(https://github.com/mlco2/codecarbon\)\.
Expand Down
4 changes: 2 additions & 2 deletions tests/span/polarity_model_card_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,8 @@
- load_best_model_at_end: False
### Training Results
\| Epoch \| Step \| Training Loss \| Validation Loss \|
\|:------:\|:----:\|:-------------:\|:---------------:\|
\| Epoch \| Step \| Training Loss \| Validation Loss \|
\|:-----:\|:----:\|:-------------:\|:---------------:\|
(\| [\d\.]+ +\| [\d\.]+ +\| [\d\.]+ +\| [\d\.]+ +\|\n)+
### Environmental Impact
Carbon emissions were measured using \[CodeCarbon\]\(https://github.com/mlco2/codecarbon\)\.
Expand Down

0 comments on commit ac0e8bc

Please sign in to comment.