From c18b205bc0b9e18dcd1ec3d61572f9b96284d19c Mon Sep 17 00:00:00 2001 From: hw <45089338+MorningForest@users.noreply.github.com> Date: Tue, 9 Nov 2021 16:14:05 +0800 Subject: [PATCH] update construct_graph (#393) --- fastNLP/io/file_utils.py | 5 + fastNLP/io/pipe/__init__.py | 20 ++- fastNLP/io/pipe/construct_graph.py | 268 +++++++++++++++++++++++++++++ 3 files changed, 287 insertions(+), 6 deletions(-) create mode 100644 fastNLP/io/pipe/construct_graph.py diff --git a/fastNLP/io/file_utils.py b/fastNLP/io/file_utils.py index c7893f38..bbf3de1e 100644 --- a/fastNLP/io/file_utils.py +++ b/fastNLP/io/file_utils.py @@ -103,6 +103,11 @@ "yelp-review-polarity": "yelp_review_polarity.tar.gz", "sst-2": "SST-2.zip", "sst": "SST.zip", + 'mr': 'mr.zip', + "R8": "R8.zip", + "R52": "R52.zip", + "20ng": "20ng.zip", + "ohsumed": "ohsumed.zip", # Classification, Chinese "chn-senti-corp": "chn_senti_corp.zip", diff --git a/fastNLP/io/pipe/__init__.py b/fastNLP/io/pipe/__init__.py index 94784515..35965ca3 100644 --- a/fastNLP/io/pipe/__init__.py +++ b/fastNLP/io/pipe/__init__.py @@ -23,15 +23,15 @@ "ChnSentiCorpPipe", "THUCNewsPipe", "WeiboSenti100kPipe", - "MRPipe", "R52Pipe", "R8Pipe", "OhsumedPipe", "NG20Loader", - + "MRPipe", "R52Pipe", "R8Pipe", "OhsumedPipe", "NG20Pipe", + "Conll2003NERPipe", "OntoNotesNERPipe", "MsraNERPipe", "WeiboNERPipe", "PeopleDailyPipe", "Conll2003Pipe", - + "MatchingBertPipe", "RTEBertPipe", "SNLIBertPipe", @@ -53,14 +53,20 @@ "RenamePipe", "GranularizePipe", "MachingTruncatePipe", - + "CoReferencePipe", - "CMRC2018BertPipe" + "CMRC2018BertPipe", + + "R52PmiGraphPipe", + "R8PmiGraphPipe", + "OhsumedPmiGraphPipe", + "NG20PmiGraphPipe", + "MRPmiGraphPipe" ] from .classification import CLSBasePipe, YelpFullPipe, YelpPolarityPipe, SSTPipe, SST2Pipe, IMDBPipe, ChnSentiCorpPipe, THUCNewsPipe, \ - WeiboSenti100kPipe, AGsNewsPipe, DBPediaPipe, MRPipe, R8Pipe, R52Pipe, OhsumedPipe, NG20Loader + WeiboSenti100kPipe, AGsNewsPipe, DBPediaPipe, MRPipe, R8Pipe, R52Pipe, OhsumedPipe, NG20Pipe from .conll import Conll2003NERPipe, OntoNotesNERPipe, MsraNERPipe, WeiboNERPipe, PeopleDailyPipe from .conll import Conll2003Pipe from .coreference import CoReferencePipe @@ -70,3 +76,5 @@ LCQMCPipe, BQCorpusPipe, LCQMCBertPipe, RenamePipe, GranularizePipe, MachingTruncatePipe from .pipe import Pipe from .qa import CMRC2018BertPipe + +from .construct_graph import MRPmiGraphPipe, R8PmiGraphPipe, R52PmiGraphPipe, NG20PmiGraphPipe, OhsumedPmiGraphPipe diff --git a/fastNLP/io/pipe/construct_graph.py b/fastNLP/io/pipe/construct_graph.py new file mode 100644 index 00000000..d597da9d --- /dev/null +++ b/fastNLP/io/pipe/construct_graph.py @@ -0,0 +1,268 @@ + +__all__ =[ + 'MRPmiGraphPipe', + 'R8PmiGraphPipe', + 'R52PmiGraphPipe', + 'OhsumedPmiGraphPipe', + 'NG20PmiGraphPipe' +] +try: + import networkx as nx + from sklearn.feature_extraction.text import CountVectorizer + from sklearn.feature_extraction.text import TfidfTransformer + from sklearn.pipeline import Pipeline +except: + pass +from collections import defaultdict +import itertools +import math +from tqdm import tqdm +import numpy as np + +from ..data_bundle import DataBundle +from ...core.const import Const +from ..loader.classification import MRLoader, OhsumedLoader, R52Loader, R8Loader, NG20Loader + + +def _get_windows(content_lst: list, window_size:int): + r""" + 滑动窗口处理文本,获取词频和共现词语的词频 + :param content_lst: + :param window_size: + :return: 词频,共现词频,窗口化后文本段的数量 + """ + word_window_freq = defaultdict(int) # w(i) 单词在窗口单位内出现的次数 + word_pair_count = defaultdict(int) # w(i, j) + windows_len = 0 + for words in tqdm(content_lst, desc="Split by window"): + windows = list() + + if isinstance(words, str): + words = words.split() + length = len(words) + + if length <= window_size: + windows.append(words) + else: + for j in range(length - window_size + 1): + window = words[j: j + window_size] + windows.append(list(set(window))) + + for window in windows: + for word in window: + word_window_freq[word] += 1 + + for word_pair in itertools.combinations(window, 2): + word_pair_count[word_pair] += 1 + + windows_len += len(windows) + return word_window_freq, word_pair_count, windows_len + +def _cal_pmi(W_ij, W, word_freq_i, word_freq_j): + r""" + params: w_ij:为词语i,j的共现词频 + w:文本数量 + word_freq_i: 词语i的词频 + word_freq_j: 词语j的词频 + return: 词语i,j的tfidf值 + """ + p_i = word_freq_i / W + p_j = word_freq_j / W + p_i_j = W_ij / W + pmi = math.log(p_i_j / (p_i * p_j)) + + return pmi + +def _count_pmi(windows_len, word_pair_count, word_window_freq, threshold): + r""" + params: windows_len: 文本段数量 + word_pair_count: 词共现频率字典 + word_window_freq: 词频率字典 + threshold: 阈值 + return 词语pmi的list列表,其中元素为[word1, word2, pmi] + """ + word_pmi_lst = list() + for word_pair, W_i_j in tqdm(word_pair_count.items(), desc="Calculate pmi between words"): + word_freq_1 = word_window_freq[word_pair[0]] + word_freq_2 = word_window_freq[word_pair[1]] + + pmi = _cal_pmi(W_i_j, windows_len, word_freq_1, word_freq_2) + if pmi <= threshold: + continue + word_pmi_lst.append([word_pair[0], word_pair[1], pmi]) + return word_pmi_lst + +class GraphBuilderBase: + def __init__(self, graph_type='pmi', widow_size=10, threshold=0.): + self.graph = nx.Graph() + self.word2id = dict() + self.graph_type = graph_type + self.window_size = widow_size + self.doc_node_num = 0 + self.tr_doc_index = None + self.te_doc_index = None + self.dev_doc_index = None + self.doc = None + self.threshold = threshold + + def _get_doc_edge(self, data_bundle: DataBundle): + r''' + 对输入的DataBundle进行处理,然后生成文档-单词的tfidf值 + :param: data_bundle中的文本若为英文,形式为[ 'This is the first document.'],若为中文则为['他 喜欢 吃 苹果'] + : return 返回带有具有tfidf边文档-单词稀疏矩阵 + ''' + tr_doc = list(data_bundle.get_dataset("train").get_field(Const.RAW_WORD)) + val_doc = list(data_bundle.get_dataset("dev").get_field(Const.RAW_WORD)) + te_doc = list(data_bundle.get_dataset("test").get_field(Const.RAW_WORD)) + doc = tr_doc + val_doc + te_doc + self.doc = doc + self.tr_doc_index = [ind for ind in range(len(tr_doc))] + self.dev_doc_index = [ind+len(tr_doc) for ind in range(len(val_doc))] + self.te_doc_index = [ind+len(tr_doc)+len(val_doc) for ind in range(len(te_doc))] + text_tfidf = Pipeline([('count', CountVectorizer(token_pattern=r'\S+', min_df=1, max_df=1.0)), + ('tfidf', TfidfTransformer(norm=None, use_idf=True, smooth_idf=False, sublinear_tf=False))]) + + tfidf_vec = text_tfidf.fit_transform(doc) + self.doc_node_num = tfidf_vec.shape[0] + vocab_lst = text_tfidf['count'].get_feature_names() + for ind, word in enumerate(vocab_lst): + self.word2id[word] = ind + for ind, row in enumerate(tfidf_vec): + for col_index, value in zip(row.indices, row.data): + self.graph.add_edge(ind, self.doc_node_num+col_index, weight=value) + return nx.to_scipy_sparse_matrix(self.graph) + + def _get_word_edge(self): + word_window_freq, word_pair_count, windows_len = _get_windows(self.doc, self.window_size) + pmi_edge_lst = _count_pmi(windows_len, word_pair_count, word_window_freq, self.threshold) + for edge_item in pmi_edge_lst: + word_indx1 = self.doc_node_num + self.word2id[edge_item[0]] + word_indx2 = self.doc_node_num + self.word2id[edge_item[1]] + if word_indx1 == word_indx2: + continue + self.graph.add_edge(word_indx1, word_indx2, weight=edge_item[2]) + + def build_graph(self, data_bundle: DataBundle): + r""" + 对输入的DataBundle进行处理,然后返回该scipy_sparse_matrix类型的邻接矩阵。 + + :param ~fastNLP.DataBundle data_bundle: 需要处理的DataBundle对象 + :return: + """ + raise NotImplementedError + + def build_graph_from_file(self, path: str): + r""" + 传入文件路径,生成处理好的scipy_sparse_matrix对象。paths支持的路径形式可以参考 ::meth:`fastNLP.io.Loader.load()` + + :param paths: + :return: scipy_sparse_matrix + """ + raise NotImplementedError + + +class MRPmiGraphPipe(GraphBuilderBase): + + def __init__(self, graph_type='pmi', widow_size=10, threshold=0.): + super().__init__(graph_type=graph_type, widow_size=widow_size, threshold=threshold) + + def build_graph(self, data_bundle: DataBundle): + r''' + params: ~fastNLP.DataBundle data_bundle: 需要处理的DataBundle对象. + return 返回csr类型的稀疏矩阵图;训练集,验证集,测试集,在图中的index. + ''' + self._get_doc_edge(data_bundle) + self._get_word_edge() + return nx.to_scipy_sparse_matrix(self.graph, + nodelist=list(range(self.graph.number_of_nodes())), + weight='weight', dtype=np.float32, format='csr'), (self.tr_doc_index, self.dev_doc_index, self.te_doc_index) + + def build_graph_from_file(self, path: str): + data_bundle = MRLoader().load(path) + return self.build_graph(data_bundle) + +class R8PmiGraphPipe(GraphBuilderBase): + + def __init__(self, graph_type='pmi', widow_size=10, threshold=0.): + super().__init__(graph_type=graph_type, widow_size=widow_size, threshold=threshold) + + def build_graph(self, data_bundle: DataBundle): + r''' + params: ~fastNLP.DataBundle data_bundle: 需要处理的DataBundle对象. + return 返回csr类型的稀疏矩阵图;训练集,验证集,测试集,在图中的index. + ''' + self._get_doc_edge(data_bundle) + self._get_word_edge() + return nx.to_scipy_sparse_matrix(self.graph, + nodelist=list(range(self.graph.number_of_nodes())), + weight='weight', dtype=np.float32, format='csr'), (self.tr_doc_index, self.dev_doc_index, self.te_doc_index) + + def build_graph_from_file(self, path: str): + data_bundle = R8Loader().load(path) + return self.build_graph(data_bundle) + +class R52PmiGraphPipe(GraphBuilderBase): + + def __init__(self, graph_type='pmi', widow_size=10, threshold=0.): + super().__init__(graph_type=graph_type, widow_size=widow_size, threshold=threshold) + + def build_graph(self, data_bundle: DataBundle): + r''' + params: ~fastNLP.DataBundle data_bundle: 需要处理的DataBundle对象. + return 返回csr类型的稀疏矩阵;训练集,验证集,测试集,在图中的index. + ''' + self._get_doc_edge(data_bundle) + self._get_word_edge() + return nx.to_scipy_sparse_matrix(self.graph, + nodelist=list(range(self.graph.number_of_nodes())), + weight='weight', dtype=np.float32, format='csr'), (self.tr_doc_index, self.dev_doc_index, self.te_doc_index) + + def build_graph_from_file(self, path: str): + data_bundle = R52Loader().load(path) + return self.build_graph(data_bundle) + +class OhsumedPmiGraphPipe(GraphBuilderBase): + + def __init__(self, graph_type='pmi', widow_size=10, threshold=0.): + super().__init__(graph_type=graph_type, widow_size=widow_size, threshold=threshold) + + def build_graph(self, data_bundle: DataBundle): + r''' + params: ~fastNLP.DataBundle data_bundle: 需要处理的DataBundle对象. + return 返回csr类型的稀疏矩阵图;训练集,验证集,测试集,在图中的index. + ''' + self._get_doc_edge(data_bundle) + self._get_word_edge() + return nx.to_scipy_sparse_matrix(self.graph, + nodelist=list(range(self.graph.number_of_nodes())), + weight='weight', dtype=np.float32, format='csr'), (self.tr_doc_index, self.dev_doc_index, self.te_doc_index) + + def build_graph_from_file(self, path: str): + data_bundle = OhsumedLoader().load(path) + return self.build_graph(data_bundle) + + +class NG20PmiGraphPipe(GraphBuilderBase): + + def __init__(self, graph_type='pmi', widow_size=10, threshold=0.): + super().__init__(graph_type=graph_type, widow_size=widow_size, threshold=threshold) + + def build_graph(self, data_bundle: DataBundle): + r''' + params: ~fastNLP.DataBundle data_bundle: 需要处理的DataBundle对象. + return 返回csr类型的稀疏矩阵图;训练集,验证集,测试集,在图中的index. + ''' + self._get_doc_edge(data_bundle) + self._get_word_edge() + return nx.to_scipy_sparse_matrix(self.graph, + nodelist=list(range(self.graph.number_of_nodes())), + weight='weight', dtype=np.float32, format='csr'), ( + self.tr_doc_index, self.dev_doc_index, self.te_doc_index) + + def build_graph_from_file(self, path: str): + r''' + param: path->数据集的路径. + return: 返回csr类型的稀疏矩阵图;训练集,验证集,测试集,在图中的index. + ''' + data_bundle = NG20Loader().load(path) + return self.build_graph(data_bundle)