Commit 0e622c2 (v0.1)

Su Dan committed Apr 11, 2020
1 parent b0f4f52 commit 0e622c2
Showing 44 changed files with 7,740 additions and 0 deletions.
5 changes: 5 additions & 0 deletions build/lib/covidSumm/__init__.py
@@ -0,0 +1,5 @@
from .abstractive_bart_model import *
from .abstractive_config import *
from .abstractive_model import *
from .abstractive_utils import *
from .abstractive_api import *
14 changes: 14 additions & 0 deletions build/lib/covidSumm/abstractive_api.py
@@ -0,0 +1,14 @@
from .abstractive_utils import abstractive_api


def abstractive_api_uni_para(query):
return abstractive_api(query, 'unilm_para')

def abstractive_api_bart_para(query):
return abstractive_api(query, 'bart_para')

def abstractive_api_uni_article(query):
return abstractive_api(query, 'unilm_article')

def abstractive_api_bart_article(query):
return abstractive_api(query, 'bart_article')
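
A hedged usage sketch for these wrappers (the query string is illustrative; it assumes the retrieval and summarization backends that `abstractive_utils.abstractive_api` talks to are running):

from covidSumm.abstractive_api import abstractive_api_bart_para

# Each wrapper forwards the query to abstractive_api with a fixed
# model/granularity tag ('bart_para' here: BART summaries at the
# paragraph level).
result = abstractive_api_bart_para("What are the symptoms of COVID-19?")
print(result)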
86 changes: 86 additions & 0 deletions build/lib/covidSumm/abstractive_bart_model.py
@@ -0,0 +1,86 @@
import torch
import json
from .abstractive_utils import get_ir_result, result_to_json, get_qa_result
from fairseq.models.bart import BARTModel


class Bart_model(object):
    def __init__(self, model_path):
        # self.model = torch.hub.load('pytorch/fairseq', 'bart.large.cnn')
        self.model = BARTModel.from_pretrained(model_path, checkpoint_file='model.pt')
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(device)
        self.model.eval()
        if torch.cuda.is_available():
            # Half precision is only reliable on GPU.
            self.model.half()
        self.bsz = 2
        self.summary_list = []
        self.slines = []

    def bart_generate_summary(self, paragraphs_list):
        self.summary_list = []
        self.slines = []
        count = 1  # local counter; keeping it on the instance leaked state across calls
        for i in range(len(paragraphs_list)):
            sline = paragraphs_list[i]['src'].strip()
            self.slines.append(sline)
            if count % self.bsz == 0:
                # Decode once a full batch of bsz inputs has accumulated.
                with torch.no_grad():
                    hypotheses_batch = self.model.sample(self.slines, beam=4, lenpen=2.0, max_len_b=520, min_len=55, no_repeat_ngram_size=3)

                for hypothesis in hypotheses_batch:
                    self.summary_list.append(hypothesis)
                self.slines = []
            count += 1

        # Flush any leftover inputs that did not fill a full batch.
        if self.slines != []:
            hypotheses_batch = self.model.sample(self.slines, beam=4, lenpen=2.0, max_len_b=520, min_len=55, no_repeat_ngram_size=3)
            for hypothesis in hypotheses_batch:
                self.summary_list.append(hypothesis)
        return self.summary_list


def bart_generate_summary_list(list_of_paragraphs_list, bart_model):
    # Summarize each article (a list of paragraph dicts) and join the
    # per-paragraph summaries into one string per article.
    list_of_summary_list = []

    for paragraphs_list in list_of_paragraphs_list:
        summary_list = bart_model.bart_generate_summary(paragraphs_list)
        summary_result = ""
        for item in summary_list:
            summary_result += item.replace("\n", ' ')

        list_of_summary_list.append(summary_result)

    return list_of_summary_list



def get_bart_answer_summary(query, bart_model):
    # Summarize the top QA passages retrieved for the query.
    paragraphs_list = get_qa_result(query, topk=3)
    answer_summary_list = bart_model.bart_generate_summary(paragraphs_list)
    answer_summary_result = ""
    for item in answer_summary_list:
        answer_summary_result += item.replace('\n', ' ')

    answer_summary_json = {}
    answer_summary_json['summary'] = answer_summary_result
    answer_summary_json['question'] = query
    return answer_summary_json


def get_bart_article_summary(query, bart_model, topk=3):
    # Summarize the top-k retrieved articles and also dump the results to
    # summary_bart.output as one JSON object per line.
    article_list, meta_info_list = get_ir_result(query, topk)
    summary_list = bart_generate_summary_list(article_list, bart_model)
    summary_list_json = []
    with open('summary_bart.output', 'w') as fout:
        for i in range(len(summary_list)):
            json_summary = result_to_json(meta_info_list[i], summary_list[i])
            summary_list_json.append(json_summary)
            json.dump(json_summary, fout)
            fout.write('\n')

    return summary_list_json
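
A minimal usage sketch for this module, assuming a local checkpoint directory containing model.pt (the path is a placeholder) and reachable IR/QA backends behind `get_ir_result` and `get_qa_result`:

from covidSumm.abstractive_bart_model import Bart_model, get_bart_answer_summary

# Hypothetical checkpoint directory; must contain model.pt.
bart = Bart_model('./bart_checkpoints')
answer = get_bart_answer_summary("How does COVID-19 spread?", bart)
print(answer['summary'])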

114 changes: 114 additions & 0 deletions build/lib/covidSumm/abstractive_config.py
@@ -0,0 +1,114 @@
import argparse
import os

import easydict
from transformers import BertTokenizer

from .s2s_ft.tokenization_unilm import UnilmTokenizer


TOKENIZER_CLASSES = {
    'bert': BertTokenizer,
    'unilm': UnilmTokenizer,
}

def set_config():

    args = easydict.EasyDict({
        "model_type": 'unilm',
        "tokenizer_name": 'unilm1.2-base-uncased',
        "config_path": None,
        "max_seq_length": 512,
        "fp16": True,
        "split": "validation",
        "seed": 123,
        "do_lower_case": True,
        "batch_size": 1,
        "beam_size": 5,
        "length_penalty": 0,
        "forbid_duplicate_ngrams": True,
        "forbid_ignore_word": '.',
        "min_len": 50,
        "ngram_size": 3,
        "mode": 's2s',
        "max_tgt_length": 48,
        "cache_dir": None,
        "pos_shift": False,
        "need_score_traces": False,
        "model_path": ""
    })

    return args
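
A short sketch of consuming this in-code config (EasyDict exposes dict keys as attributes, which is why the rest of the code can write `args.beam_size`):

from covidSumm.abstractive_config import set_config

args = set_config()
print(args.model_type, args.beam_size)   # -> unilm 5
args.model_path = './checkpoint/ckpt-32000'   # hypothetical checkpoint path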



def set_config1():

    parser = argparse.ArgumentParser()

    # argparse's type=bool treats any non-empty string as True, so flags
    # that default to True use an explicit string-to-bool converter.
    str2bool = lambda s: str(s).lower() in ('true', '1', 'yes')

    # Required parameters
    parser.add_argument("--model_type", default='unilm', type=str,
                        help="Model type selected in the list: " + ", ".join(TOKENIZER_CLASSES.keys()))
    parser.add_argument("--model_path", default='./checkpoint/ckpt-32000', type=str,
                        help="Path to the model checkpoint.")
    parser.add_argument("--config_path", default=None, type=str,
                        help="Path to config.json for the model.")

    # Tokenizer
    parser.add_argument("--tokenizer_name", default='unilm1.2-base-uncased', type=str,
                        help="Tokenizer name.")
    parser.add_argument("--max_seq_length", default=512, type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. "
                             "Sequences longer than this will be truncated, and sequences shorter "
                             "than this will be padded.")

    # Decoding parameters
    parser.add_argument('--fp16', default=True, type=str2bool,
                        help="Whether to use 16-bit float precision instead of 32-bit.")
    parser.add_argument('--amp', action='store_true',
                        help="Whether to use amp for fp16.")
    # parser.add_argument("--input_file", type=str, help="Input file")
    parser.add_argument('--subset', type=int, default=0,
                        help="Decode a subset of the input dataset.")
    parser.add_argument("--output_file", type=str, help="Output file.")
    parser.add_argument("--split", type=str, default="validation",
                        help="Data split (train/val/test).")
    parser.add_argument('--tokenized_input', action='store_true',
                        help="Whether the input is tokenized.")
    parser.add_argument('--seed', type=int, default=123,
                        help="Random seed for initialization.")
    parser.add_argument("--do_lower_case", default=True, type=str2bool,
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument('--batch_size', type=int, default=1,
                        help="Batch size for decoding.")
    parser.add_argument('--beam_size', type=int, default=5,
                        help="Beam size for searching.")
    parser.add_argument('--length_penalty', type=float, default=0,
                        help="Length penalty for beam search.")

    parser.add_argument('--forbid_duplicate_ngrams', type=str2bool, default=True)
    parser.add_argument('--forbid_ignore_word', type=str, default='.',
                        help="Words to ignore when applying forbid_duplicate_ngrams.")
    parser.add_argument("--min_len", default=50, type=int)
    parser.add_argument('--need_score_traces', action='store_true')
    parser.add_argument('--ngram_size', type=int, default=3)
    parser.add_argument('--mode', default="s2s",
                        choices=["s2s", "l2r", "both"])
    parser.add_argument('--max_tgt_length', type=int, default=48,
                        help="Maximum length of the target sequence.")
    parser.add_argument('--s2s_special_token', action='store_true',
                        help="New special tokens ([S2S_SEP]/[S2S_CLS]) of S2S.")
    parser.add_argument('--s2s_add_segment', action='store_true',
                        help="Additional segment embeddings for the encoder of S2S.")
    parser.add_argument('--s2s_share_segment', action='store_true',
                        help="Sharing segment embeddings for the encoder of S2S (used with --s2s_add_segment).")
    parser.add_argument('--pos_shift', action='store_true',
                        help="Use position shift for fine-tuning.")
    parser.add_argument("--cache_dir", default=None, type=str,
                        help="Where to store pre-trained models downloaded from S3.")

    args = parser.parse_args()
    # args, unknown = parser.parse_known_args()

    return args
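
For completeness, a hedged sketch of exercising `set_config1` without a shell, by populating sys.argv before the call (in a real run argparse reads these flags from the invoking command line):

import sys
from covidSumm.abstractive_config import set_config1

# Simulated command line for illustration only.
sys.argv = ['decode.py', '--model_path', './checkpoint/ckpt-32000', '--beam_size', '5']
args = set_config1()
print(args.beam_size)   # -> 5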