Merge pull request #6 from unitaryai/lightweight_models
Added lightweight models
laurahanu authored Dec 16, 2020
2 parents 1da4222 + d43924b commit 035f92b
Showing 6 changed files with 116 additions and 7 deletions.
42 changes: 42 additions & 0 deletions configs/Toxic_comment_classification_ALBERT.json
@@ -0,0 +1,42 @@
{
    "name": "Jigsaw_ALBERT",
    "n_gpu": 1,
    "batch_size": 10,
    "accumulate_grad_batches": 3,
    "loss": "binary_cross_entropy",
    "arch": {
        "type": "ALBERT",
        "args": {
            "num_classes": 6,
            "model_type": "albert-base-v2",
            "model_name": "AlbertForSequenceClassification",
            "tokenizer_name": "AlbertTokenizer"
        }
    },
    "dataset": {
        "type": "JigsawDataOriginal",
        "args": {
            "train_csv_file": "jigsaw_data/jigsaw-toxic-comment-classification-challenge/train.csv",
            "test_csv_file": "jigsaw_data/jigsaw-toxic-comment-classification-challenge/val.csv",
            "val_fraction": null,
            "create_val_set": false,
            "add_test_labels": false,
            "classes": [
                "toxic",
                "severe_toxic",
                "obscene",
                "threat",
                "insult",
                "identity_hate"
            ]
        }
    },
    "optimizer": {
        "type": "Adam",
        "args": {
            "lr": 3e-5,
            "weight_decay": 3e-6,
            "amsgrad": true
        }
    }
}
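
The "arch" block above names Hugging Face classes as plain strings. Below is a minimal sketch of how those strings could be resolved into a model and tokenizer with the transformers library; this is an illustration only, not the repository's actual loading code, and it assumes a recent transformers version for the `.logits` attribute.

# Illustrative only: resolve the string names from the config's "arch" block
# into transformers classes and run one comment through the classifier.
import transformers

arch_args = {
    "num_classes": 6,
    "model_type": "albert-base-v2",
    "model_name": "AlbertForSequenceClassification",
    "tokenizer_name": "AlbertTokenizer",
}

# The config stores class names as strings, so look them up on the package.
model_class = getattr(transformers, arch_args["model_name"])
tokenizer_class = getattr(transformers, arch_args["tokenizer_name"])

model = model_class.from_pretrained(arch_args["model_type"], num_labels=arch_args["num_classes"])
tokenizer = tokenizer_class.from_pretrained(arch_args["model_type"])

inputs = tokenizer("example comment", return_tensors="pt", truncation=True)
logits = model(**inputs).logits  # shape (1, 6): one logit per toxicity class
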
55 changes: 55 additions & 0 deletions configs/Unintended_bias_toxic_comment_classification_Albert.json
@@ -0,0 +1,55 @@
{
    "name": "Jigsaw_ALBERT_bias",
    "n_gpu": 1,
    "batch_size": 10,
    "accumulate_grad_batches": 3,
    "num_main_classes": 7,
    "loss_weight": 0.75,
    "arch": {
        "type": "ALBERT",
        "args": {
            "num_classes": 16,
            "model_type": "albert-base-v2",
            "model_name": "AlbertForSequenceClassification",
            "tokenizer_name": "AlbertTokenizer"
        }
    },
    "dataset": {
        "type": "JigsawDataBias",
        "args": {
            "train_csv_file": "jigsaw_data/jigsaw-unintended-bias-in-toxicity-classification/train.csv",
            "test_csv_file": "jigsaw_data/jigsaw-unintended-bias-in-toxicity-classification/test_public_expanded.csv",
            "val_fraction": null,
            "create_val_set": false,
            "loss_weight": 0.75,
            "classes": [
                "toxicity",
                "severe_toxicity",
                "obscene",
                "identity_attack",
                "insult",
                "threat",
                "sexual_explicit"
            ],
            "identity_classes": [
                "male",
                "female",
                "homosexual_gay_or_lesbian",
                "christian",
                "jewish",
                "muslim",
                "black",
                "white",
                "psychiatric_or_mental_illness"
            ]
        }
    },
    "optimizer": {
        "type": "Adam",
        "args": {
            "lr": 3e-5,
            "weight_decay": 3e-6,
            "amsgrad": true
        }
    }
}
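
The "optimizer" block follows the same string-based pattern. The sketch below is a hypothetical helper, not code from the repository, showing how such a block could be turned into a torch.optim.Adam instance.

# Hypothetical helper, not part of the repo: build the optimizer described by
# the "optimizer" block of a config file like the one above.
import json

import torch


def build_optimizer(config_path, params):
    with open(config_path) as f:
        config = json.load(f)
    optimizer_cfg = config["optimizer"]
    optimizer_class = getattr(torch.optim, optimizer_cfg["type"])  # e.g. torch.optim.Adam
    # "args" expands to lr=3e-5, weight_decay=3e-6, amsgrad=True for this config
    return optimizer_class(params, **optimizer_cfg["args"])

# usage (assuming `model` already exists):
# optimizer = build_optimizer(
#     "configs/Unintended_bias_toxic_comment_classification_Albert.json",
#     model.parameters(),
# )
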
17 changes: 14 additions & 3 deletions detoxify/detoxify.py
@@ -5,6 +5,8 @@
     "original": "https://github.com/unitaryai/detoxify/releases/download/v0.1-alpha/toxic_original-c1212f89.ckpt",
     "unbiased": "https://github.com/unitaryai/detoxify/releases/download/v0.1-alpha/toxic_bias-4e693588.ckpt",
     "multilingual": "https://github.com/unitaryai/detoxify/releases/download/v0.1-alpha/toxic_multilingual-bbddc277.ckpt",
+    "original-small": "https://github.com/unitaryai/detoxify/releases/download/v0.1.2/original-albert-0e1d6498.ckpt",
+    "unbiased-small": "https://github.com/unitaryai/detoxify/releases/download/v0.1.2/unbiased-albert-c8519128.ckpt"
 }
 
 PRETRAINED_MODEL = None
@@ -58,13 +60,13 @@ class Detoxify:
     Easily predict if a comment or list of comments is toxic.
     Can initialize 3 different model types from model type or checkpoint path:
         - original:
-            BERT model trained on data from the Jigsaw Toxic Comment
+            model trained on data from the Jigsaw Toxic Comment
             Classification Challenge
         - unbiased:
-            RoBERTa model trained on data from the Jigsaw Unintended Bias in
+            model trained on data from the Jigsaw Unintended Bias in
             Toxicity Classification Challenge
         - multilingual:
-            XLM-RoBertA model trained on data from the Jigsaw Multilingual
+            model trained on data from the Jigsaw Multilingual
             Toxic Comment Classification Challenge
     Args:
         model_type(str): model type to be loaded, can be either original,
@@ -82,6 +84,7 @@ def __init__(self, model_type="original", checkpoint=PRETRAINED_MODEL):
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
         self.model.to(self.device)
 
+
     @torch.no_grad()
     def predict(self, text):
         self.model.eval()
@@ -104,9 +107,17 @@ def toxic_bert():
     return load_model("original")
 
 
+def toxic_albert():
+    return load_model("original-small")
+
+
 def unbiased_toxic_roberta():
     return load_model("unbiased")
 
 
+def unbiased_albert():
+    return load_model("unbiased-small")
+
+
 def multilingual_toxic_xlm_r():
     return load_model("multilingual")
4 changes: 4 additions & 0 deletions hubconf.py
@@ -5,3 +5,7 @@
 from detoxify import unbiased_toxic_roberta
 
 from detoxify import multilingual_toxic_xlm_r
+
+from detoxify import toxic_albert
+
+from detoxify import unbiased_albert
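
These hubconf entry points also expose the lightweight models through torch.hub, along the lines of the sketch below (the "unitaryai/detoxify" repository path and a network download on first use are assumed):

import torch

# Load the ALBERT-based models via the new hub entry points.
toxic_albert_model = torch.hub.load("unitaryai/detoxify", "toxic_albert")
unbiased_albert_model = torch.hub.load("unitaryai/detoxify", "unbiased_albert")
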
2 changes: 1 addition & 1 deletion requirements.txt
@@ -6,4 +6,4 @@ pandas >= 1.1.2
 scikit-learn >= 0.23.2
 datasets >= 1.0.2
 tqdm == 4.41.0
-sentencepiece >= 0.1.94
+sentencepiece >= 0.1.94
3 changes: 0 additions & 3 deletions train.py
@@ -170,9 +170,6 @@ def cli_main():
         type=str,
         help="number of workers used in the data loader (default: 10)",
     )
-    parser.add_argument(
-        "-g", "--n_gpu", default=None, type=int, help="if given, override the num"
-    )
     parser.add_argument(
         "-e", "--n_epochs", default=100, type=int, help="if given, override the num"
     )
