vocabulary.py
import json
import logging
import os.path
import warnings
from typing import List, Set, Iterable

import numpy as np
import spacy
from tqdm import tqdm


# Vocabulary: converts text to index sequences and index sequences back to text.
class OOVWarning(Warning):
    """Warned when an out-of-vocabulary word is encoded as <unk>."""
    pass


class EndLabelNotFoundWarning(Warning):
    """Warned when a decoded sequence contains no <end> token."""
    pass
class Vocabulary:
    # Reserved indices for the special tokens.
    pad = 0
    start = 1
    end = 2
    unk = 3

    def __init__(self, json_filepath):
        with open(json_filepath, 'rb') as fp:
            self._content: dict = json.load(fp)  # str -> int
        self.inv = {v: k for k, v in self._content.items()}  # int -> str
        self.spec_words: Set[int] = {self.pad, self.start, self.end, self.unk}  # special tokens
        self.size = len(self._content)
        self.nlp = spacy.load("en_core_web_sm")
    @staticmethod
    def build(raw_documents: Iterable[str], save_path):
        """
        Vocabulary factory: tokenize the raw documents and assign an index to
        every distinct lower-cased token, after the four special tokens.
        :param raw_documents: iterable of raw text documents
        :param save_path: path of the JSON file the word->index mapping is written to
        :return: a Vocabulary built from the saved file
        """
        if os.path.exists(save_path):
            warnings.warn('vocabulary file already exists, this operation is going to overwrite it')
        content = {'<pad>': 0, '<start>': 1, '<end>': 2, '<unk>': 3}
        nlp = spacy.load("en_core_web_sm")
        for raw_document in tqdm(raw_documents, desc='building vocabulary'):
            for token in nlp(raw_document):
                if token.text.lower() not in content:
                    content[token.text.lower()] = len(content)
        with open(save_path, 'w') as fw:
            json.dump(content, fw)
        return Vocabulary(save_path)
    def __len__(self):
        return len(self._content)

    def __getitem__(self, word):
        if word in self._content:
            return self._content[word]
        else:
            warnings.warn(f'out-of-vocabulary word "{word}" encoded as <unk>', OOVWarning)
            return self.unk
    def get_word2vec(self, cache_path='word2vec.npy'):
        """Word-vector embeddings for the vocabulary, backed by spaCy, which provides fixed 96-dim embeddings."""
        try:
            return np.load(cache_path)
        except FileNotFoundError:
            logging.info('building word2vec matrix...')
            word2vecs = []
            for idx, word in self.inv.items():
                # spaCy provides a 96-dim tensor per token; special tokens get random vectors.
                word2vecs.append(
                    self.nlp(word)[0].tensor if idx not in self.spec_words else np.random.randn(96))
            matrix = np.array(word2vecs)
            np.save(cache_path, matrix)
            return matrix
    def split(self, string: str):
        return [token.text.lower() for token in self.nlp(string)]

    @staticmethod
    def post_process(string: str):
        # remove the spaces the tokenizer left around punctuation
        s1 = string.replace(' .', '.').replace(' ,', ',').replace(' -', '-').replace('- ', '-')
        # capitalize the first letter of every sentence
        s2 = '. '.join([s[0].upper() + s[1:] for s in s1.split('. ') if s])
        return s2

    def pad_sequence(self, indices, fixed_length):
        # pad with <pad> up to fixed_length, then truncate to exactly fixed_length
        return (indices + [self.pad] * (fixed_length - len(indices)))[:fixed_length]
    def encode(self, sentence) -> List[int]:
        """
        Encode text as an index sequence, wrapped in <start> and <end>.
        :param sentence: raw text to encode
        :return: list of token indices
        """
        indices = [self.start] + [self[token.text.lower()] for token in self.nlp(sentence)] + [self.end]
        return indices
    def decode(self, sequence: List[int]) -> str:
        """
        Decode an index sequence into text, skipping special tokens and
        stopping at the first <end> token.
        :param sequence: list of token indices
        :return: post-processed text
        """
        words = []
        for idx in sequence:
            if idx == self.end:
                break
            if idx not in self.inv:
                raise KeyError(f'index {idx} is out of the vocabulary range, vocabulary size: {self.__len__()}')
            elif idx not in self.spec_words:  # skip special tokens
                words.append(self.inv[idx])
        else:  # for-else: runs only when the loop never hit <end> (no break)
            warnings.warn('no <end> token found in the sequence', EndLabelNotFoundWarning)
        return self.post_process(' '.join(words))
if __name__ == "__main__":
    train_labels_path = 'data/deepfashion-multimodal/train_captions.json'
    with open(train_labels_path, 'rb') as fp:
        train_labels: dict = json.load(fp)
    vocab = Vocabulary.build(train_labels.values(), 'vocabulary/vocab.json')
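
    # Minimal usage sketch of the round trip the class provides: the sample
    # caption and the fixed length of 32 below are arbitrary illustrations,
    # not values taken from the dataset.
    sample = "A woman wears a long-sleeved shirt ."
    ids = vocab.encode(sample)            # [<start>, token indices..., <end>]
    padded = vocab.pad_sequence(ids, 32)  # fixed-length input for a model
    print(vocab.decode(padded))           # stops at <end>, skips special tokens
    embeddings = vocab.get_word2vec()     # (len(vocab), 96) matrix, cached to word2vec.npy
    print(embeddings.shape)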