Skip to content

Commit

Permalink
Check if max pairs limit reached in generate_pairs and `generate_mu…
Browse files Browse the repository at this point in the history
…ltilabel_pairs`
  • Loading branch information
OscarRunsCode authored and tomaarsen committed Sep 18, 2024
1 parent 6d4010d commit 765c8af
Showing 1 changed file with 6 additions and 7 deletions.
13 changes: 6 additions & 7 deletions src/setfit/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,21 +89,20 @@ def __init__(

def generate_pairs(self) -> None:
for (_text, _label), (text, label) in shuffle_combinations(self.sentence_labels):
if _label == label:
if (_label == label) and ((self.max_pairs == -1) or (len(self.pos_pairs) < self.max_pairs)):
self.pos_pairs.append({"sentence_1": _text, "sentence_2": text, "label": 1.0})
else:
elif (_label != label) and ((self.max_pairs == -1) or (len(self.neg_pairs) < self.max_pairs)):
self.neg_pairs.append({"sentence_1": _text, "sentence_2": text, "label": 0.0})
if self.max_pairs != -1 and len(self.pos_pairs) > self.max_pairs and len(self.neg_pairs) > self.max_pairs:
elif self.max_pairs != -1 and len(self.pos_pairs) >= self.max_pairs and len(self.neg_pairs) >= self.max_pairs:
break

def generate_multilabel_pairs(self) -> None:
for (_text, _label), (text, label) in shuffle_combinations(self.sentence_labels):
if any(np.logical_and(_label, label)):
# logical_and checks if labels are both set for each class
if self.max_pairs != -1 and len(self.pos_pairs) > self.max_pairs and len(self.neg_pairs) > self.max_pairs:
self.pos_pairs.append({"sentence_1": _text, "sentence_2": text, "label": 1.0})
else:
elif any(np.logical_xor(_label, label)) and ((self.max_pairs == -1) or (len(self.neg_pairs) < self.max_pairs)):
self.neg_pairs.append({"sentence_1": _text, "sentence_2": text, "label": 0.0})
if self.max_pairs != -1 and len(self.pos_pairs) > self.max_pairs and len(self.neg_pairs) > self.max_pairs:
elif self.max_pairs != -1 and len(self.pos_pairs) >= self.max_pairs and len(self.neg_pairs) >= self.max_pairs:
break

def get_positive_pairs(self) -> List[Dict[str, Union[str, float]]]:
Expand Down

0 comments on commit 765c8af

Please sign in to comment.