CGDIALOG

This repository contains the CGDIALOG dataset and the PyTorch implementations of the models from the paper "Less is More: Mitigate Spurious Correlations for Open-Domain Dialogue Response Generation Models by Causal Discovery".

Dataset

We build CGDIALOG, a corpus in which responses are annotated with their direct causes, from two public dialogue corpora: ESConv and MSC. The original annotated dataset is in datasets/CGDIALOG.

Statistics of CGDIALOG

| Item | ESConv | MSC | Total |
|---|---|---|---|
| Dialogues | 80 | 80 | 160 |
| History-response pairs | 694 | 800 | 922 |
| Utterances | 2301 | 3807 | 6108 |
| Utterances containing direct causes | 1347 | 1525 | 2872 |
| Average token length of direct causes | 24.01 (std=16.61) | 22.22 (std=13.79) | 23.05 (std=15.20) |
| Proportion of direct causes in original utterances | 0.86 (std=0.22) | 0.72 (std=0.27) | 0.79 (std=0.26) |

Figures: number of direct causes; distance to response.

Data Format

Each item in the dataset has the following format:

{
"HITId": "3RIHDBQ1OL0HF1CNXY41R5Y9Q7OHMC",
"WorkerId": "ARH3NPT7GUFQ6",
"history": [ # dialogue history
    "seeker: Hi!",
    "supporter: Hello, how are you doing today?",
    "seeker: Not so good. I have conspiracy theorist as a friend who is now mad at me because I told her to pull up her mask while talking to me.",
    "seeker: We have been friends for 13 years",
    "seeker: I am hurt and confused that she still thinks this is a game.",
    "seeker: She thought Corona was fake until someone we know caught it.",
    "seeker: It is like she is mad she was wrong and is taking it out and lashing out at those who have been trying to persuade her the whole time...",
    "seeker: what do you think?",
    ""
],
"response": "It sounds like you care a lot about your friend and others. How old is your friend?",
"entities": [ # direct causes of responses that are annotated by workers.
    "Not so good. I have conspiracy theorist as a friend who is now mad at me because I told her to pull up her mask while talking to me.",
    "I am hurt and confused that she still thinks this is a game.",
    "It is like she is mad she was wrong and is taking it out and lashing out at those who have been trying to persuade her the whole time..."
]
}
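To show how an item's annotations line up with its history, here is a minimal sketch that finds the history turn each direct cause was copied from and how many turns that turn lies before the response. It assumes, as in the example above, that each entry in `entities` is a verbatim substring of exactly one history utterance. `cause_turns` and `distances_to_response` are hypothetical helpers, not part of this repository, and the distance definition is an illustration, not necessarily the one used in the paper.

```python
from typing import Dict, List

# Example item copied from the format above (the "#" annotations removed).
ITEM = {
    "history": [
        "seeker: Hi!",
        "supporter: Hello, how are you doing today?",
        "seeker: Not so good. I have conspiracy theorist as a friend who is now mad at me because I told her to pull up her mask while talking to me.",
        "seeker: We have been friends for 13 years",
        "seeker: I am hurt and confused that she still thinks this is a game.",
        "seeker: She thought Corona was fake until someone we know caught it.",
        "seeker: It is like she is mad she was wrong and is taking it out and lashing out at those who have been trying to persuade her the whole time...",
        "seeker: what do you think?",
        "",
    ],
    "response": "It sounds like you care a lot about your friend and others. How old is your friend?",
    "entities": [
        "Not so good. I have conspiracy theorist as a friend who is now mad at me because I told her to pull up her mask while talking to me.",
        "I am hurt and confused that she still thinks this is a game.",
        "It is like she is mad she was wrong and is taking it out and lashing out at those who have been trying to persuade her the whole time...",
    ],
}


def non_empty_turns(item: Dict) -> List[str]:
    # Drop empty strings such as the trailing "" in the example history.
    return [u for u in item["history"] if u]


def cause_turns(item: Dict) -> List[int]:
    """Index (into the non-empty history) of the turn containing each direct cause.

    Assumption: each entity is copied verbatim from one utterance, without the
    "speaker: " prefix. Returns -1 for an entity that matches no turn.
    """
    turns = non_empty_turns(item)
    return [next((i for i, u in enumerate(turns) if entity in u), -1)
            for entity in item["entities"]]


def distances_to_response(item: Dict) -> List[int]:
    """Turns between each matched cause and the response (1 = the last history turn)."""
    n = len(non_empty_turns(item))
    return [n - i for i in cause_turns(item) if i >= 0]
```

For the example item, `cause_turns(ITEM)` gives `[2, 4, 6]` and `distances_to_response(ITEM)` gives `[6, 4, 2]`: all three causes come from the seeker's turns, and none is the most recent utterance.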

Setup

The code is based on PyTorch and Hugging Face Transformers.

cd CGDIALOG
conda create --prefix env/ python=3.6
conda activate env/
pip install -r requirements.txt 
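The training and generation commands load BlenderBot and the RoBERTa tokenizer from local directories (models/blenderbot_400M_distill/ and models/roberta_base/), and this README does not say how to obtain them. Here is a minimal sketch, assuming these directories are save_pretrained() copies of the public Hugging Face Hub checkpoints facebook/blenderbot-400M-distill and roberta-base. `download_base_models` is a hypothetical helper, and the exact checkpoints the authors used are not confirmed here.

```python
from pathlib import Path

# Assumption: the local model directories are plain save_pretrained() copies
# of these public Hugging Face Hub checkpoints.
BLENDERBOT = ("facebook/blenderbot-400M-distill", "models/blenderbot_400M_distill")
ROBERTA = ("roberta-base", "models/roberta_base")


def download_base_models(root: str = ".") -> None:
    """Save the base generator (model + tokenizer) and the RoBERTa tokenizer under root."""
    # Imported lazily so the paths above can be inspected without transformers installed.
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

    # Base generator: the commands use this directory as --model_name_or_path
    # (training) and as --tokenizer_name (generation).
    hub_id, local_dir = BLENDERBOT
    target = str(Path(root) / local_dir)
    AutoTokenizer.from_pretrained(hub_id).save_pretrained(target)
    AutoModelForSeq2SeqLM.from_pretrained(hub_id).save_pretrained(target)

    # The generation commands pass models/roberta_base/ only as --tc_tokenizer_name,
    # so only the tokenizer is saved here.
    hub_id, local_dir = ROBERTA
    AutoTokenizer.from_pretrained(hub_id).save_pretrained(str(Path(root) / local_dir))
```

Run `download_base_models()` once from the repository root after installing the requirements. It needs network access, and save_pretrained() creates the target directories if they are missing.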

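Train models

Fine-tune the causal response generator for each corpus, starting from the distilled 400M BlenderBot checkpoint: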

python run_dialogue_generation_no_trainer.py --train_file datasets/CGDIALOG/ESConv_causal_generator_train.csv --model_name_or_path models/blenderbot_400M_distill/ --output_dir models/ESConv_causal_generator_model_new
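python run_dialogue_generation_no_trainer.py --train_file datasets/CGDIALOG/msc_causal_generator_train.csv --model_name_or_path models/blenderbot_400M_distill/ --output_dir models/msc_causal_generator_model_new

Both commands write to an output directory ending in _new, but the generation commands below load models/ESConv_causal_generator_model/ and models/msc_causal_generator_model/. To generate with a model you trained yourself, pass its --output_dir as --model_name_or_path.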

Generate ESConv responses

python test_causal_response_generator.py --validation_file datasets/ESConv/test_dataset.json --model_name_or_path models/ESConv_causal_generator_model/ --tokenizer_name models/blenderbot_400M_distill/ --twoCondition_tc_model_name_or_path models/ESConv_classifier/ --tc_tokenizer_name models/roberta_base/ --output_dir outputs/ESConv_causal_generator_test_result

Generate MSC responses

python test_causal_response_generator.py --validation_file datasets/msc/msc_dialogue/session_4/test.json --model_name_or_path models/msc_causal_generator_model/ --tokenizer_name models/blenderbot_400M_distill/ --twoCondition_tc_model_name_or_path models/msc_classifier/ --tc_tokenizer_name models/roberta_base/ --output_dir outputs/msc_causal_generator_test_result
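The two generation commands differ only in the corpus name and test file. Here is a minimal Python sketch that builds and runs them with subprocess, optionally swapping in your own generator directory. `TEST_FILES`, `generation_command` and `run_generation` are hypothetical helpers, not part of this repository, and they use only the flags and paths shown in the commands above.

```python
import subprocess
import sys
from typing import Dict, List, Optional

# Test files taken from the two generation commands above.
TEST_FILES = {
    "ESConv": "datasets/ESConv/test_dataset.json",
    "msc": "datasets/msc/msc_dialogue/session_4/test.json",
}


def generation_command(corpus: str, generator_dir: Optional[str] = None) -> List[str]:
    """Build the test_causal_response_generator.py call for one corpus.

    generator_dir defaults to the directory used in the README commands; pass
    e.g. models/ESConv_causal_generator_model_new to use your own checkpoint.
    """
    if generator_dir is None:
        generator_dir = f"models/{corpus}_causal_generator_model/"
    return [
        sys.executable, "test_causal_response_generator.py",
        "--validation_file", TEST_FILES[corpus],
        "--model_name_or_path", generator_dir,
        "--tokenizer_name", "models/blenderbot_400M_distill/",
        "--twoCondition_tc_model_name_or_path", f"models/{corpus}_classifier/",
        "--tc_tokenizer_name", "models/roberta_base/",
        "--output_dir", f"outputs/{corpus}_causal_generator_test_result",
    ]


def run_generation(generator_dirs: Optional[Dict[str, str]] = None) -> None:
    """Run generation for both corpora; generator_dirs maps corpus -> model directory."""
    generator_dirs = generator_dirs or {}
    for corpus in TEST_FILES:
        # check=True stops at the first failing run instead of continuing silently.
        subprocess.run(generation_command(corpus, generator_dirs.get(corpus)), check=True)
```

For example, `run_generation({"ESConv": "models/ESConv_causal_generator_model_new"})` evaluates your own ESConv generator and the default MSC one. Passing the arguments as a list avoids shell quoting, and `sys.executable` runs the scripts with the interpreter of the active environment.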

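Select final response

From the repository root, start a Python session and run select_highScore_response on each generation output to keep the highest-scoring response: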

python
from select_best_response_ourModel import select_highScore_response
select_highScore_response(ourModel_file="outputs/ESConv_causal_generator_test_result", save_file="outputs/ESConv_test_result_in_testset_highestScore.json")
select_highScore_response(ourModel_file="outputs/msc_causal_generator_test_result", save_file="outputs/msc_test_result_in_testset_highestScore.json")