Merge pull request #265 from tomaarsen/refactor_v2
Refactor to introduce `Trainer` & `TrainingArguments`, add SetFit ABSA
tomaarsen authored Nov 10, 2023
2 parents a893809 + d85f0d9 commit b636cd7
Showing 37 changed files with 3,665 additions and 1,110 deletions.
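At a glance, the user-facing change condensed from the README diff below: hyperparameters move from the `SetFitTrainer` constructor into `TrainingArguments`, and the trainer class is now simply `Trainer`. A minimal sketch (the two-example dataset is an illustrative assumption, not part of this commit):

```python
from datasets import Dataset
from setfit import SetFitModel, Trainer, TrainingArguments

# Illustrative toy dataset; any dataset with "text" and "label" columns works the same way.
train_dataset = Dataset.from_dict({"text": ["i loved it", "i hated it"], "label": [1, 0]})
model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-mpnet-base-v2")

# Previously these keyword arguments were passed to SetFitTrainer directly.
args = TrainingArguments(batch_size=16, num_iterations=20, num_epochs=1)

trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
trainer.train()
```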
3 changes: 3 additions & 0 deletions .github/workflows/quality.yml
@@ -5,9 +5,12 @@ on:
    branches:
      - main
      - v*-release
      - v*-pre
  pull_request:
    branches:
      - main
      - v*-pre
  workflow_dispatch:

jobs:

5 changes: 5 additions & 0 deletions .github/workflows/tests.yml
@@ -5,9 +5,12 @@ on:
    branches:
      - main
      - v*-release
      - v*-pre
  pull_request:
    branches:
      - main
      - v*-pre
  workflow_dispatch:

jobs:

@@ -40,6 +43,8 @@ jobs:
        run: |
          python -m pip install --no-cache-dir --upgrade pip
          python -m pip install --no-cache-dir ${{ matrix.requirements }}
          python -m spacy download en_core_web_lg
          python -m spacy download en_core_web_sm
        if: steps.restore-cache.outputs.cache-hit != 'true'

      - name: Install the checked-out setfit
4 changes: 4 additions & 0 deletions .gitignore
@@ -149,3 +149,7 @@ scripts/tfew/run_tmux.sh
# macOS
.DS_Store
.vscode/settings.json

# Common SetFit Trainer logging folders
wandb
runs/
1 change: 1 addition & 0 deletions MANIFEST.in
@@ -0,0 +1 @@
include src/setfit/span/model_card_template.md
120 changes: 56 additions & 64 deletions README.md
@@ -39,16 +39,14 @@ The examples below provide a quick overview on the various features supported in
`setfit` is integrated with the [Hugging Face Hub](https://huggingface.co/) and provides two main classes:

* `SetFitModel`: a wrapper that combines a pretrained body from `sentence_transformers` and a classification head from either [`scikit-learn`](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html) or [`SetFitHead`](https://github.com/huggingface/setfit/blob/main/src/setfit/modeling.py) (a differentiable head built upon `PyTorch` with similar APIs to `sentence_transformers`).
* `Trainer`: a helper class that wraps the fine-tuning process of SetFit.

Here is an end-to-end example using a classification head from `scikit-learn`:


```python
from datasets import load_dataset
from setfit import SetFitModel, Trainer, TrainingArguments, sample_dataset


# Load a dataset from the Hugging Face Hub
@@ -61,17 +59,19 @@ eval_dataset = dataset["validation"]
# Load a SetFit model from Hub
model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-mpnet-base-v2")

args = TrainingArguments(
    batch_size=16,
    num_iterations=20,  # The number of text pairs to generate for contrastive learning
    num_epochs=1,  # The number of epochs to use for contrastive learning
)

# Create trainer
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    metric="accuracy",
    column_mapping={"sentence": "text", "label": "label"}  # Map dataset columns to text/label expected by trainer
)

# Train and evaluate
@@ -81,7 +81,7 @@ metrics = trainer.evaluate()
# Push model to the Hub
trainer.push_to_hub("my-awesome-setfit-model")

# Download from Hub
model = SetFitModel.from_pretrained("lewtun/my-awesome-setfit-model")
# Run inference
preds = model(["i loved the spiderman movie!", "pineapple on pizza is the worst 🤮"])
```

@@ -92,9 +92,7 @@ Here is an end-to-end example using `SetFitHead`:

```python
from datasets import load_dataset
from setfit import SetFitModel, Trainer, TrainingArguments, sample_dataset


# Load a dataset from the Hugging Face Hub
@@ -103,6 +101,7 @@ dataset = load_dataset("sst2")
# Simulate the few-shot regime by sampling 8 examples per class
train_dataset = sample_dataset(dataset["train"], label_column="label", num_samples=8)
eval_dataset = dataset["validation"]
num_classes = 2

# Load a SetFit model from Hub
model = SetFitModel.from_pretrained(
@@ -111,36 +110,26 @@ model = SetFitModel.from_pretrained(
    head_params={"out_features": num_classes},
)

args = TrainingArguments(
    body_learning_rate=2e-5,
    head_learning_rate=1e-2,
    batch_size=16,
    num_iterations=20,  # The number of text pairs to generate for contrastive learning
    num_epochs=(1, 25),  # For finetuning the embeddings and training the classifier, respectively
    l2_weight=0.0,
    end_to_end=False,  # Don't train the classifier end-to-end, i.e. only train the head
)

# Create trainer
trainer = Trainer(
    model=model,
    args=args,  # Pass the TrainingArguments defined above
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    metric="accuracy",
    column_mapping={"sentence": "text", "label": "label"}  # Map dataset columns to text/label expected by trainer
)

# Train and evaluate
trainer.train()
metrics = trainer.evaluate()

# Push model to the Hub
```

@@ -175,7 +164,7 @@ This will initialise a multilabel classification head from `sklearn` - the following strategies are supported:
* `multi-output`: uses a `MultiOutputClassifier` head.
* `classifier-chain`: uses a `ClassifierChain` head.

From here, you can instantiate a `Trainer` using the same example above, and train it as usual.
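For instance, a minimal sketch of the `multi-output` strategy, assuming the `multi_target_strategy` keyword on `SetFitModel.from_pretrained` (the toy dataset is an illustrative assumption):

```python
from datasets import Dataset
from setfit import SetFitModel, Trainer, TrainingArguments

# Toy multilabel data: each label is a binary vector with one entry per class.
train_dataset = Dataset.from_dict({
    "text": ["great camera, poor battery", "fast shipping, fair price", "screen broke on arrival"],
    "label": [[1, 1], [0, 0], [1, 0]],
})

model = SetFitModel.from_pretrained(
    "sentence-transformers/paraphrase-mpnet-base-v2",
    multi_target_strategy="multi-output",  # or "classifier-chain", as listed above
)

trainer = Trainer(
    model=model,
    args=TrainingArguments(batch_size=16, num_iterations=20),
    train_dataset=train_dataset,
)
trainer.train()
```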

#### Example using the differentiable `SetFitHead`:

@@ -196,7 +185,6 @@ model = SetFitModel.from_pretrained(
SetFit can also be applied to scenarios where no labels are available. To do so, create a synthetic dataset of training examples:

```python
from setfit import get_templated_dataset

candidate_labels = ["negative", "positive"]
@@ -206,22 +194,22 @@ train_dataset = get_templated_dataset(candidate_labels=candidate_labels, sample_
```
This will create examples of the form `"This sentence is {}"`, where the `{}` is filled in with one of the candidate labels. From here you can train a SetFit model as usual:

```python
from setfit import SetFitModel, Trainer

model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-mpnet-base-v2")
trainer = Trainer(
    model=model,
    train_dataset=train_dataset
)
trainer.train()
```

We find this approach typically outperforms the [zero-shot pipeline](https://huggingface.co/docs/transformers/v4.24.0/en/main_classes/pipelines#transformers.ZeroShotClassificationPipeline) in 🤗 Transformers (based on MNLI with BART), while being 5x faster to generate predictions with.


### Running hyperparameter search

`Trainer` provides a `hyperparameter_search()` method that you can use to find good hyperparameters for your data. To use this feature, first install the `optuna` backend:

```bash
python -m pip install setfit[optuna]
```

@@ -267,23 +255,23 @@ def hp_space(trial): # Training parameters

**Note:** In practice, we found `num_iterations` to be the most important hyperparameter for the contrastive learning process.
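The collapsed hunk above defines the `model_init` and `hp_space` callables used below; as an illustrative sketch only, they can look roughly like this (the exact search space and head parameters are assumptions, not the ones from this commit):

```python
from setfit import SetFitModel

def model_init(params=None):
    # A fresh model for every trial; max_iter/solver are forwarded to the logistic regression head.
    params = params or {}
    head_params = {key: params[key] for key in ("max_iter", "solver") if key in params}
    return SetFitModel.from_pretrained(
        "sentence-transformers/paraphrase-mpnet-base-v2", head_params=head_params
    )

def hp_space(trial):  # Training parameters
    # body_learning_rate / num_epochs / batch_size / num_iterations mirror TrainingArguments
    # fields used elsewhere in this README; max_iter / solver are meant for model_init.
    return {
        "body_learning_rate": trial.suggest_float("body_learning_rate", 1e-6, 1e-4, log=True),
        "num_epochs": trial.suggest_int("num_epochs", 1, 3),
        "batch_size": trial.suggest_categorical("batch_size", [16, 32, 64]),
        "num_iterations": trial.suggest_categorical("num_iterations", [5, 10, 20]),
        "max_iter": trial.suggest_int("max_iter", 50, 300),
        "solver": trial.suggest_categorical("solver", ["newton-cg", "lbfgs", "liblinear"]),
    }
```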

The next step is to instantiate a `Trainer` and call `hyperparameter_search()`:

```python
from datasets import Dataset
from setfit import Trainer

dataset = Dataset.from_dict(
    {"text_new": ["a", "b", "c"], "label_new": [0, 1, 2], "extra_column": ["d", "e", "f"]}
)

trainer = Trainer(
    train_dataset=dataset,
    eval_dataset=dataset,
    model_init=model_init,
    column_mapping={"text_new": "text", "label_new": "label"},
)
best_run = trainer.hyperparameter_search(direction="maximize", hp_space=hp_space, n_trials=5)
```

Finally, you can apply the hyperparameters you found to the trainer, and lock in the optimal model, before training for
@@ -300,9 +288,8 @@ If you have access to unlabeled data, you can use knowledge distillation to comp

```python
from datasets import load_dataset
from setfit import SetFitModel, Trainer, DistillationTrainer, sample_dataset
from setfit.training_args import TrainingArguments

# Load a dataset from the Hugging Face Hub
dataset = load_dataset("ag_news")
@@ -320,34 +307,37 @@ teacher_model = SetFitModel.from_pretrained(
)

# Create trainer for teacher model
teacher_trainer = Trainer(
    model=teacher_model,
    train_dataset=train_dataset_teacher,
    eval_dataset=eval_dataset,
)

# Train teacher model
teacher_trainer.train()
teacher_metrics = teacher_trainer.evaluate()

# Load small student model
student_model = SetFitModel.from_pretrained("paraphrase-MiniLM-L3-v2")

args = TrainingArguments(
    batch_size=16,
    num_iterations=20,
    num_epochs=1,
)

# Create trainer for knowledge distillation
student_trainer = DistillationTrainer(
    teacher_model=teacher_model,
    student_model=student_model,
    args=args,
    train_dataset=train_dataset_student,
    eval_dataset=eval_dataset,
)

# Train student with knowledge distillation
student_trainer.train()
student_metrics = student_trainer.evaluate()
```


@@ -403,13 +393,15 @@ make style && make quality

## Citation

```
@misc{https://doi.org/10.48550/arxiv.2209.11055,
doi = {10.48550/ARXIV.2209.11055},
url = {https://arxiv.org/abs/2209.11055},
author = {Tunstall, Lewis and Reimers, Nils and Jo, Unso Eun Seo and Bates, Luke and Korat, Daniel and Wasserblat, Moshe and Pereg, Oren},
keywords = {Computation and Language (cs.CL), FOS: Computer and information sciences, FOS: Computer and information sciences},
title = {Efficient Few-Shot Learning Without Prompts},
publisher = {arXiv},
year = {2022},
copyright = {Creative Commons Attribution 4.0 International}
}
```
4 changes: 4 additions & 0 deletions docs/source/en/api/main.mdx
@@ -6,3 +6,7 @@
# SetFitHead

[[autodoc]] SetFitHead

# AbsaModel

[[autodoc]] AbsaModel
12 changes: 8 additions & 4 deletions docs/source/en/api/trainer.mdx
@@ -1,8 +1,12 @@

# Trainer

[[autodoc]] Trainer

# DistillationTrainer

[[autodoc]] DistillationTrainer

# AbsaTrainer

[[autodoc]] AbsaTrainer
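For orientation, a rough sketch of how the new ABSA classes documented above fit together; the dataset columns, spaCy model, and exact keyword arguments are assumptions for illustration, not taken from this diff:

```python
from datasets import Dataset
from setfit import AbsaModel, AbsaTrainer

# Assumed ABSA training format: the full text, an aspect span from it, the span's polarity,
# and the ordinal of that span if it occurs multiple times in the text.
train_dataset = Dataset.from_dict({
    "text": ["The food was great but the service was slow."] * 2,
    "span": ["food", "service"],
    "label": ["positive", "negative"],
    "ordinal": [0, 0],
})

# An AbsaModel pairs a spaCy pipeline for extracting candidate aspects with SetFit models
# for filtering aspects and classifying their polarity.
model = AbsaModel.from_pretrained(
    "sentence-transformers/paraphrase-mpnet-base-v2",
    spacy_model="en_core_web_sm",  # one of the models the updated CI workflow downloads
)

trainer = AbsaTrainer(model, train_dataset=train_dataset)
trainer.train()

preds = model.predict(["The pizza was delicious but the waiter was rude."])
```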