Skip to content

Commit

Permalink
Fix pandas groupby -> apply warning (#555)
Browse files Browse the repository at this point in the history
* Fix pandas groupby -> apply warning

* Fix model card test pattern

We sample slightly different samples, so the training set metrics are obviously also different
  • Loading branch information
tomaarsen authored Sep 18, 2024
1 parent 6c8d739 commit edee867
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 8 deletions.
6 changes: 2 additions & 4 deletions src/setfit/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,10 +148,8 @@ def sample_dataset(dataset: Dataset, label_column: str = "label", num_samples: i
shuffled_dataset = dataset.shuffle(seed=seed)

df = shuffled_dataset.to_pandas()
df = df.groupby(label_column)

# sample num_samples, or at least as much as possible
df = df.apply(lambda x: x.sample(min(num_samples, len(x)), random_state=seed))
# Sample (at most) num_samples examples per class
df = df.groupby(label_column).head(n=num_samples)
df = df.reset_index(drop=True)

all_samples = Dataset.from_pandas(df, features=dataset.features)
Expand Down
8 changes: 4 additions & 4 deletions tests/model_card_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@
### Model Labels
\| Label\s+\| Examples\s+\|
\|:-+\|:-+\|
\| negative\s+\| [^\|]+ \|
\| positive\s+\| [^\|]+ \|
\| negative\s+\| [^\|]+ \|
## Evaluation
Expand Down Expand Up @@ -97,9 +97,9 @@
## Training Details
### Training Set Metrics
\| Training set \| Min \| Median \| Max \|
\|:-------------\|:----\|:--------\|:----\|
\| Word count \| 2 \| 11.4375 \| 33 \|
\| Training set \| Min \| Median \| Max \|
\|:-------------\|:----\|:-------\|:----\|
\| Word count \| 3 \| 7.875 \| 18 \|
\| Label \| Training Sample Count \|
\|:---------\|:----------------------\|
Expand Down

0 comments on commit edee867

Please sign in to comment.