From 765c8af6eacc88ec9c984ef99cf99c1c7d1a376d Mon Sep 17 00:00:00 2001 From: "Oscar Martinez, MS" <82097051+OscarRunsCode@users.noreply.github.com> Date: Thu, 5 Sep 2024 14:37:21 -0400 Subject: [PATCH 1/5] Check if max pairs limit reached in `generate_pairs` and `generate_multilabel_pairs` --- src/setfit/sampler.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/src/setfit/sampler.py b/src/setfit/sampler.py index a78f836c..195476df 100644 --- a/src/setfit/sampler.py +++ b/src/setfit/sampler.py @@ -89,21 +89,20 @@ def __init__( def generate_pairs(self) -> None: for (_text, _label), (text, label) in shuffle_combinations(self.sentence_labels): - if _label == label: + if (_label == label) and ((self.max_pairs == -1) or (len(self.pos_pairs) < self.max_pairs)): self.pos_pairs.append({"sentence_1": _text, "sentence_2": text, "label": 1.0}) - else: + elif (_label != label) and ((self.max_pairs == -1) or (len(self.neg_pairs) < self.max_pairs)): self.neg_pairs.append({"sentence_1": _text, "sentence_2": text, "label": 0.0}) - if self.max_pairs != -1 and len(self.pos_pairs) > self.max_pairs and len(self.neg_pairs) > self.max_pairs: + elif self.max_pairs != -1 and len(self.pos_pairs) >= self.max_pairs and len(self.neg_pairs) >= self.max_pairs: break def generate_multilabel_pairs(self) -> None: for (_text, _label), (text, label) in shuffle_combinations(self.sentence_labels): - if any(np.logical_and(_label, label)): - # logical_and checks if labels are both set for each class + if self.max_pairs != -1 and len(self.pos_pairs) > self.max_pairs and len(self.neg_pairs) > self.max_pairs: self.pos_pairs.append({"sentence_1": _text, "sentence_2": text, "label": 1.0}) - else: + elif any(np.logical_xor(_label, label)) and ((self.max_pairs == -1) or (len(self.neg_pairs) < self.max_pairs)): self.neg_pairs.append({"sentence_1": _text, "sentence_2": text, "label": 0.0}) - if self.max_pairs != -1 and len(self.pos_pairs) > self.max_pairs and len(self.neg_pairs) > self.max_pairs: + elif self.max_pairs != -1 and len(self.pos_pairs) >= self.max_pairs and len(self.neg_pairs) >= self.max_pairs: break def get_positive_pairs(self) -> List[Dict[str, Union[str, float]]]: From 496cf0d1b66f5b02f4fcf6985665553adb10804b Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Wed, 18 Sep 2024 13:20:38 +0200 Subject: [PATCH 2/5] Resolve merge mistake --- src/setfit/sampler.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/setfit/sampler.py b/src/setfit/sampler.py index 195476df..97442870 100644 --- a/src/setfit/sampler.py +++ b/src/setfit/sampler.py @@ -98,7 +98,8 @@ def generate_pairs(self) -> None: def generate_multilabel_pairs(self) -> None: for (_text, _label), (text, label) in shuffle_combinations(self.sentence_labels): - if self.max_pairs != -1 and len(self.pos_pairs) > self.max_pairs and len(self.neg_pairs) > self.max_pairs: + if any(np.logical_and(_label, label)) and ((self.max_pairs == -1) or (len(self.pos_pairs) < self.max_pairs)): + # logical_and checks if labels are both set for each class self.pos_pairs.append({"sentence_1": _text, "sentence_2": text, "label": 1.0}) elif any(np.logical_xor(_label, label)) and ((self.max_pairs == -1) or (len(self.neg_pairs) < self.max_pairs)): self.neg_pairs.append({"sentence_1": _text, "sentence_2": text, "label": 0.0}) From 8ace7c648e249d0bb0c9c9358913bf719ebeec2e Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Wed, 18 Sep 2024 13:21:10 +0200 Subject: [PATCH 3/5] Run 'make style' --- src/setfit/sampler.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/src/setfit/sampler.py b/src/setfit/sampler.py index 97442870..48df2e47 100644 --- a/src/setfit/sampler.py +++ b/src/setfit/sampler.py @@ -93,17 +93,29 @@ def generate_pairs(self) -> None: self.pos_pairs.append({"sentence_1": _text, "sentence_2": text, "label": 1.0}) elif (_label != label) and ((self.max_pairs == -1) or (len(self.neg_pairs) < self.max_pairs)): self.neg_pairs.append({"sentence_1": _text, "sentence_2": text, "label": 0.0}) - elif self.max_pairs != -1 and len(self.pos_pairs) >= self.max_pairs and len(self.neg_pairs) >= self.max_pairs: + elif ( + self.max_pairs != -1 + and len(self.pos_pairs) >= self.max_pairs + and len(self.neg_pairs) >= self.max_pairs + ): break def generate_multilabel_pairs(self) -> None: for (_text, _label), (text, label) in shuffle_combinations(self.sentence_labels): - if any(np.logical_and(_label, label)) and ((self.max_pairs == -1) or (len(self.pos_pairs) < self.max_pairs)): + if any(np.logical_and(_label, label)) and ( + (self.max_pairs == -1) or (len(self.pos_pairs) < self.max_pairs) + ): # logical_and checks if labels are both set for each class self.pos_pairs.append({"sentence_1": _text, "sentence_2": text, "label": 1.0}) - elif any(np.logical_xor(_label, label)) and ((self.max_pairs == -1) or (len(self.neg_pairs) < self.max_pairs)): + elif any(np.logical_xor(_label, label)) and ( + (self.max_pairs == -1) or (len(self.neg_pairs) < self.max_pairs) + ): self.neg_pairs.append({"sentence_1": _text, "sentence_2": text, "label": 0.0}) - elif self.max_pairs != -1 and len(self.pos_pairs) >= self.max_pairs and len(self.neg_pairs) >= self.max_pairs: + elif ( + self.max_pairs != -1 + and len(self.pos_pairs) >= self.max_pairs + and len(self.neg_pairs) >= self.max_pairs + ): break def get_positive_pairs(self) -> List[Dict[str, Union[str, float]]]: From e4d806b81e831c466ce39d48dfd8a00e8136f7bf Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Wed, 18 Sep 2024 13:58:53 +0200 Subject: [PATCH 4/5] Rewrite to be simpler/clearer; should be equivalent --- src/setfit/sampler.py | 42 +++++++++++++++++++++--------------------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/src/setfit/sampler.py b/src/setfit/sampler.py index 48df2e47..1f7727c9 100644 --- a/src/setfit/sampler.py +++ b/src/setfit/sampler.py @@ -89,33 +89,33 @@ def __init__( def generate_pairs(self) -> None: for (_text, _label), (text, label) in shuffle_combinations(self.sentence_labels): - if (_label == label) and ((self.max_pairs == -1) or (len(self.pos_pairs) < self.max_pairs)): - self.pos_pairs.append({"sentence_1": _text, "sentence_2": text, "label": 1.0}) - elif (_label != label) and ((self.max_pairs == -1) or (len(self.neg_pairs) < self.max_pairs)): + is_positive = _label == label + is_positive_full = self.max_pairs != -1 and len(self.pos_pairs) >= self.max_pairs + is_negative_full = self.max_pairs != -1 and len(self.neg_pairs) >= self.max_pairs + + if is_positive: + if not is_positive_full: + self.pos_pairs.append({"sentence_1": _text, "sentence_2": text, "label": 1.0}) + elif not is_negative_full: self.neg_pairs.append({"sentence_1": _text, "sentence_2": text, "label": 0.0}) - elif ( - self.max_pairs != -1 - and len(self.pos_pairs) >= self.max_pairs - and len(self.neg_pairs) >= self.max_pairs - ): + + if is_positive_full and is_negative_full: break def generate_multilabel_pairs(self) -> None: for (_text, _label), (text, label) in shuffle_combinations(self.sentence_labels): - if any(np.logical_and(_label, label)) and ( - (self.max_pairs == -1) or (len(self.pos_pairs) < self.max_pairs) - ): - # logical_and checks if labels are both set for each class - self.pos_pairs.append({"sentence_1": _text, "sentence_2": text, "label": 1.0}) - elif any(np.logical_xor(_label, label)) and ( - (self.max_pairs == -1) or (len(self.neg_pairs) < self.max_pairs) - ): + # logical_and checks if labels are both set for each class + is_positive = any(np.logical_and(_label, label)) + is_positive_full = self.max_pairs != -1 and len(self.pos_pairs) >= self.max_pairs + is_negative_full = self.max_pairs != -1 and len(self.neg_pairs) >= self.max_pairs + + if is_positive: + if not is_positive_full: + self.pos_pairs.append({"sentence_1": _text, "sentence_2": text, "label": 1.0}) + elif not is_negative_full: self.neg_pairs.append({"sentence_1": _text, "sentence_2": text, "label": 0.0}) - elif ( - self.max_pairs != -1 - and len(self.pos_pairs) >= self.max_pairs - and len(self.neg_pairs) >= self.max_pairs - ): + + if is_positive_full and is_negative_full: break def get_positive_pairs(self) -> List[Dict[str, Union[str, float]]]: From 41d2906401e4ae98fbbc3001c4f33e5965727799 Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Wed, 18 Sep 2024 14:05:19 +0200 Subject: [PATCH 5/5] Update model card patterns with new max_pairs --- tests/model_card_pattern.py | 4 ++-- tests/span/aspect_model_card_pattern.py | 4 ++-- tests/span/polarity_model_card_pattern.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/model_card_pattern.py b/tests/model_card_pattern.py index a77cdabd..59e290ec 100644 --- a/tests/model_card_pattern.py +++ b/tests/model_card_pattern.py @@ -125,8 +125,8 @@ - load_best_model_at_end: False ### Training Results -\| Epoch \| Step \| Training Loss \| Validation Loss \| -\|:------:\|:----:\|:-------------:\|:---------------:\| +\| Epoch \| Step \| Training Loss \| Validation Loss \| +\|:-----:\|:----:\|:-------------:\|:---------------:\| (\| [\d\.]+ +\| [\d\.]+ +\| [\d\.]+ +\| [\d\.]+ +\|\n)+ ### Environmental Impact Carbon emissions were measured using \[CodeCarbon\]\(https://github.com/mlco2/codecarbon\)\. diff --git a/tests/span/aspect_model_card_pattern.py b/tests/span/aspect_model_card_pattern.py index 2f33d252..7cf7393b 100644 --- a/tests/span/aspect_model_card_pattern.py +++ b/tests/span/aspect_model_card_pattern.py @@ -137,8 +137,8 @@ - load_best_model_at_end: False ### Training Results -\| Epoch \| Step \| Training Loss \| Validation Loss \| -\|:------:\|:----:\|:-------------:\|:---------------:\| +\| Epoch \| Step \| Training Loss \| Validation Loss \| +\|:-----:\|:----:\|:-------------:\|:---------------:\| (\| [\d\.]+ +\| [\d\.]+ +\| [\d\.]+ +\| [\d\.]+ +\|\n)+ ### Environmental Impact Carbon emissions were measured using \[CodeCarbon\]\(https://github.com/mlco2/codecarbon\)\. diff --git a/tests/span/polarity_model_card_pattern.py b/tests/span/polarity_model_card_pattern.py index 4895956d..e9eabf67 100644 --- a/tests/span/polarity_model_card_pattern.py +++ b/tests/span/polarity_model_card_pattern.py @@ -137,8 +137,8 @@ - load_best_model_at_end: False ### Training Results -\| Epoch \| Step \| Training Loss \| Validation Loss \| -\|:------:\|:----:\|:-------------:\|:---------------:\| +\| Epoch \| Step \| Training Loss \| Validation Loss \| +\|:-----:\|:----:\|:-------------:\|:---------------:\| (\| [\d\.]+ +\| [\d\.]+ +\| [\d\.]+ +\| [\d\.]+ +\|\n)+ ### Environmental Impact Carbon emissions were measured using \[CodeCarbon\]\(https://github.com/mlco2/codecarbon\)\.