- Folder: `/dart/`
- Datasets: `dart` (Google drive link)
- Teacher model
  - Init: `bart-large`
  - Fine-tune: `scripts/run_finetune_teacher.sh` (Google drive link)
- Student init model:
  - Pre-train init student model: `scripts/run_pretrain_student_distill.sh` (Google drive link)
- Generate pseudo-target with Teacher model (useful for Seqkd, JS, TVD): `scripts/run_teacher_label_all.sh` (Google drive link), replace `train.json`
- Run KD methods (`/scripts/`)
  - Seqkd: `run_seqkd.sh`
  - ENGINE: `run_engine.sh`
  - RKL: `run_rkl.sh`
  - KL: `run_kl_sample.sh`
  - JS: `run_js.sh`
  - TVD: `run_tvd_symm.sh`
- Decode (`/scripts/`) (Google drive link): `run_eval.sh`
- Eval: `cd /evaluation/` (Google drive link), then `sh ./run_eval_on_dart.sh` (need to modify `$OUTPUT_FILE` and download the bert-base-uncased model)
- Calculate coverage loss (PPL of teacher)
  - run `python3 dart/run_calc_ppl.py --reference_path [teacher_output_path] --input_path [input_path] --model_name [student_model_path] --save_path /tmp/`
- Calculate likelihood loss (PPL of student)
  - run `python3 dart/run_calc_ppl.py --model_name [teacher_model_path] --input_path [input_path] --reference_path [student_output_path] --save_path /tmp/`
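
The KL, RKL, JS, and TVD script names above refer to divergences between the teacher and student distributions. As a rough illustration only, the sketch below computes the textbook forms of these quantities for two toy next-token distributions. It assumes the standard definitions (forward KL with the teacher as the reference distribution, reverse KL with the arguments swapped, Jensen-Shannon, and total variation distance). The function names and the probabilities are hypothetical, and the actual training objectives in the scripts may differ in detail.

```python
import math


def kl(p, q, eps=1e-12):
    """KL(p || q) = sum_i p_i * log(p_i / q_i); terms with p_i == 0 contribute 0."""
    return sum(pi * math.log(pi / max(qi, eps)) for pi, qi in zip(p, q) if pi > 0)


def reverse_kl(teacher, student):
    """Reverse KL: the student distribution takes the role of the reference."""
    return kl(student, teacher)


def js(p, q):
    """Jensen-Shannon divergence: average KL of p and q to their mixture m."""
    m = [(pi + qi) / 2 for pi, qi in zip(p, q)]
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)


def tvd(p, q):
    """Total variation distance: half of the L1 distance between p and q."""
    return 0.5 * sum(abs(pi - qi) for pi, qi in zip(p, q))


# Toy next-token distributions over a 3-word vocabulary (made-up numbers).
teacher = [0.7, 0.2, 0.1]
student = [0.5, 0.3, 0.2]

print("KL :", kl(teacher, student))
print("RKL:", reverse_kl(teacher, student))
print("JS :", js(teacher, student))
print("TVD:", tvd(teacher, student))
```

KL and RKL are asymmetric, so swapping teacher and student changes the value, while JS and TVD are symmetric and bounded (by log 2 and 1 respectively).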
- Folder: `/summa/`
- Dataset: `xsum` (`wget https://cdn-datasets.huggingface.co/summarization/xsum.tar.gz`)
- Teacher model: https://huggingface.co/facebook/bart-large-xsum
- Student init model:
  - Pre-train init student model: `run_pretrain_student_distill.sh` (Google drive link)
- Generate pseudo-target with Teacher model (useful for Seqkd, JS, TVD): `run_teacher_label.sh` (Google drive link), replace `train.target`
- Run KD methods (`/scripts/`)
  - Seqkd: `run_seqkd.sh`
  - ENGINE: `run_engine.sh`
  - RKL: `run_rkl.sh`
  - KL: `run_kl.sh`
  - JS: `run_js.sh`
  - TVD: `run_tvd_symm.sh`
- Decode and Evaluate
  - run `eval.sh`
- Calculate coverage loss (PPL of teacher)
  - run `python3 summa/run_calc_ppl.py --reference_path [teacher_output_path] --input_path [input_path] --model_name [student_model_path] --save_path /tmp/`
- Calculate likelihood loss (PPL of student)
  - run `python3 summa/run_calc_ppl.py --model_name [teacher_model_path] --input_path [input_path] --reference_path [student_output_path] --save_path /tmp/`
- Folder: `/t5mt/`
- Dataset: `wmt_en_ro_100k` (Google drive link)
- Teacher model
  - Init: `t5-base`
  - Fine-tune: `scripts_sm/run_finetune_teacher.sh`
- Student init model:
  - Pre-train init student model: `scripts_sm/run_pretrain_student_distill.sh` (Google drive link)
- Generate pseudo-target with Teacher model (useful for Seqkd, JS, TVD): `scripts_sm/run_teacher_label.sh` (Google drive link), replace `train.target`
- Run KD methods (`/scripts/`)
  - Seqkd: `run_seqkd.sh`
  - ENGINE: `run_engine.sh`
  - RKL: `run_rkl.sh`
  - KL: `run_kl.sh`
  - JS: `run_js.sh`
  - TVD: `run_tvd_symm.sh`
- Decode and eval: `sh scripts_sm/run_eval.sh`
- Calculate coverage loss (PPL of teacher)
  - run `python3 t5mt/run_calc_ppl.py --reference_path [teacher_output_path] --input_path [input_path] --model_name [student_model_path] --save_path /tmp/`
- Calculate likelihood loss (PPL of student)
  - run `python3 t5mt/run_calc_ppl.py --model_name [teacher_model_path] --input_path [input_path] --reference_path [student_output_path] --save_path /tmp/`
- Folder: `/chat/`
- Dataset: `Commonsense-Dialogues` (Google drive link)
- Teacher model
  - Init: `microsoft/DialoGPT-medium`
  - Fine-tune: `scripts/run_finetune_teacher.sh` (Google drive link)
- Student init model:
  - Pre-train init student model: `scripts/run_pretrain_student_distill.sh` (Google drive link)
- Generate pseudo-target with Teacher model (useful for Seqkd, JS, TVD): `scripts/run_teacher_label.sh` (Google drive link), replace `train.target`
- Run KD methods (`/scripts/`)
  - Seqkd: `run_seqkd.sh`
  - ENGINE: `run_engine.sh`
  - RKL: `run_rkl.sh`
  - KL: `run_kl.sh`
  - JS: `run_js.sh`
  - TVD: `run_tvd_symm.sh`
- Decode and eval
  - `sh scripts/run_eval.sh` (need to download the bert-base-uncased model)
- Calculate coverage loss (PPL of teacher)
  - run `python3 chat/run_calc_ppl.py --reference_path [teacher_output_path] --input_path [input_path] --model_name [student_model_path] --save_path /tmp/`
- Calculate likelihood loss (PPL of student)
  - run `python3 chat/run_calc_ppl.py --model_name [teacher_model_path] --input_path [input_path] --reference_path [student_output_path] --save_path /tmp/`
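
The coverage and likelihood steps above call `run_calc_ppl.py` with the roles of the two models swapped: for coverage loss the student model scores the teacher's outputs, and for likelihood loss the teacher model scores the student's outputs. The sketch below shows only the final arithmetic, assuming perplexity is the exponential of the mean negative token log-likelihood. The function name and the log-probabilities are made up for illustration, and how `run_calc_ppl.py` aggregates tokens internally is not stated here.

```python
import math


def perplexity(token_logprobs):
    """Corpus-level perplexity from per-token natural-log probabilities.

    token_logprobs: a list of sequences, each a list of log p(token | context)
    assigned by the scoring model to the reference text.
    """
    total_logprob = 0.0
    total_tokens = 0
    for seq in token_logprobs:
        total_logprob += sum(seq)
        total_tokens += len(seq)
    if total_tokens == 0:
        raise ValueError("need at least one token to compute perplexity")
    # PPL = exp(-(1/N) * sum_t log p(y_t | y_<t, x))
    return math.exp(-total_logprob / total_tokens)


# Hypothetical scores: the student model scoring two teacher outputs
# (the "coverage" direction). Swap the roles for the "likelihood" direction.
student_scores_on_teacher_outputs = [
    [math.log(0.5), math.log(0.25)],
    [math.log(0.5), math.log(0.5), math.log(0.25)],
]
print("coverage PPL:", perplexity(student_scores_on_teacher_outputs))
```

A lower value in the coverage direction means the student assigns higher probability to what the teacher generates. A lower value in the likelihood direction means the teacher finds the student's outputs more plausible.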
- The methods in our codebase are mainly implemented with the PyTorch and Hugging Face Transformers libraries.
- The pre-distillation part is based on the method proposed in Shleifer & Rush (2020) and their implementation.
- We use Maluuba's nlg-eval to measure the BLEU score for the dialogue task.