diff --git a/src/setfit/sampler.py b/src/setfit/sampler.py index a78f836c..195476df 100644 --- a/src/setfit/sampler.py +++ b/src/setfit/sampler.py @@ -89,21 +89,20 @@ def __init__( def generate_pairs(self) -> None: for (_text, _label), (text, label) in shuffle_combinations(self.sentence_labels): - if _label == label: + if (_label == label) and ((self.max_pairs == -1) or (len(self.pos_pairs) < self.max_pairs)): self.pos_pairs.append({"sentence_1": _text, "sentence_2": text, "label": 1.0}) - else: + elif (_label != label) and ((self.max_pairs == -1) or (len(self.neg_pairs) < self.max_pairs)): self.neg_pairs.append({"sentence_1": _text, "sentence_2": text, "label": 0.0}) - if self.max_pairs != -1 and len(self.pos_pairs) > self.max_pairs and len(self.neg_pairs) > self.max_pairs: + elif self.max_pairs != -1 and len(self.pos_pairs) >= self.max_pairs and len(self.neg_pairs) >= self.max_pairs: break def generate_multilabel_pairs(self) -> None: for (_text, _label), (text, label) in shuffle_combinations(self.sentence_labels): - if any(np.logical_and(_label, label)): - # logical_and checks if labels are both set for each class + if self.max_pairs != -1 and len(self.pos_pairs) > self.max_pairs and len(self.neg_pairs) > self.max_pairs: self.pos_pairs.append({"sentence_1": _text, "sentence_2": text, "label": 1.0}) - else: + elif any(np.logical_xor(_label, label)) and ((self.max_pairs == -1) or (len(self.neg_pairs) < self.max_pairs)): self.neg_pairs.append({"sentence_1": _text, "sentence_2": text, "label": 0.0}) - if self.max_pairs != -1 and len(self.pos_pairs) > self.max_pairs and len(self.neg_pairs) > self.max_pairs: + elif self.max_pairs != -1 and len(self.pos_pairs) >= self.max_pairs and len(self.neg_pairs) >= self.max_pairs: break def get_positive_pairs(self) -> List[Dict[str, Union[str, float]]]: