[new] Add a trick for StaticEmbedding #317

Open · wants to merge 1 commit into base: master
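For context, a minimal usage sketch (not part of the diff below), assuming the usual fastNLP API (Vocabulary, StaticEmbedding, and model_dir_or_name pointing at a local embedding file); my_vectors.txt is a hypothetical whitespace-separated GloVe-style file. With this change, a vocab word such as "Apple" can still pick up the vector stored under "apple" when the exact casing is absent from the file.

    # Sketch only; assumes the standard fastNLP API and a hypothetical local file.
    from fastNLP import Vocabulary
    from fastNLP.embeddings import StaticEmbedding

    vocab = Vocabulary()
    vocab.add_word_lst(["Apple", "Banana", "NLP"])  # note the capitalized surface forms

    # Before this PR, "Apple" only matched a pretrained entry spelled exactly "Apple";
    # with it, the loader also tries "apple", "APPLE", and the capitalized form.
    embed = StaticEmbedding(vocab, model_dir_or_name="my_vectors.txt")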
42 changes: 30 additions & 12 deletions fastNLP/embeddings/static_embedding.py
@@ -260,12 +260,8 @@ def _load_with_vocab(self, embed_filepath, vocab, dtype=np.float32, padding='<pa
             else:
                 dim = len(parts) - 1
                 f.seek(0)
-            matrix = {}  # key: the word's index in vocab; value: its vector, or None if the word was not found in the pretrained file
-            if vocab.padding:
-                matrix[vocab.padding_idx] = torch.zeros(dim)
-            if vocab.unknown:
-                matrix[vocab.unknown_idx] = torch.zeros(dim)
-            found_count = 0
+            pre_train_matrix = {}  # key: the word; value: its pretrained vector
+            # First, load all pretrained word vectors.
             found_unknown = False
             for idx, line in enumerate(f, start_idx):
                 try:
@@ -278,19 +274,41 @@ def _load_with_vocab(self, embed_filepath, vocab, dtype=np.float32, padding='<pa
                     elif word == unknown and vocab.unknown is not None:
                         word = vocab.unknown
                         found_unknown = True
-                    if word in vocab:
-                        index = vocab.to_index(word)
-                        matrix[index] = torch.from_numpy(np.fromstring(' '.join(nums), sep=' ', dtype=dtype, count=dim))
-                        if self.only_norm_found_vector:
-                            matrix[index] = matrix[index] / np.linalg.norm(matrix[index])
-                        found_count += 1
+                    pre_train_matrix[word] = torch.from_numpy(
+                        np.fromstring(' '.join(nums), sep=' ', dtype=dtype, count=dim))
+                    if self.only_norm_found_vector:
+                        pre_train_matrix[word] = pre_train_matrix[word] / np.linalg.norm(pre_train_matrix[word])
                 except Exception as e:
                     if error == 'ignore':
                         warnings.warn("Error occurred at the {} line.".format(idx))
                     else:
                         logger.error("Error occurred at the {} line.".format(idx))
                         raise e
+
+            matrix = {}  # key: the word's index in vocab; value: its vector, or None if the word was not found in the pretrained file
+            if vocab.padding:
+                matrix[vocab.padding_idx] = torch.zeros(dim)
+            if vocab.unknown:
+                matrix[vocab.unknown_idx] = torch.zeros(dim)
+            found_count = 0
+            # Iterate over vocab and check whether each word exists in the pretrained vectors; if the original
+            # word is missing, also try its lowercased, uppercased and capitalized forms to raise the match ratio.
+            for word, index_in_vocab in vocab:
+                if word in pre_train_matrix:
+                    matrix[index_in_vocab] = pre_train_matrix[word]
+                    found_count += 1
+                elif word.lower() in pre_train_matrix:
+                    matrix[index_in_vocab] = pre_train_matrix[word.lower()]
+                    found_count += 1
+                elif word.upper() in pre_train_matrix:
+                    matrix[index_in_vocab] = pre_train_matrix[word.upper()]
+                    found_count += 1
+                elif word.capitalize() in pre_train_matrix:
+                    matrix[index_in_vocab] = pre_train_matrix[word.capitalize()]
+                    found_count += 1
+
             logger.info("Found {} out of {} words in the pre-training embedding.".format(found_count, len(vocab)))
+
             if not self.only_use_pretrain_word:  # if only pretrained values are used, do not create entries for words that were not found
                 for word, index in vocab:
                     if index not in matrix and not vocab._is_word_no_create_entry(word):
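To make the new matching order concrete, here is a small self-contained sketch (illustrative only, not code from this patch) of the same exact-then-lower-then-upper-then-capitalize lookup against a toy pretrained table:

    import torch

    # Toy stand-in for the pre_train_matrix built in the first pass over the embedding file.
    pre_train_matrix = {
        "apple": torch.tensor([0.1, 0.2]),
        "NLP": torch.tensor([0.3, 0.4]),
        "Paris": torch.tensor([0.5, 0.6]),
    }

    def lookup(word):
        # Same priority as the patch: exact form first, then lower/upper/capitalized fallbacks.
        for candidate in (word, word.lower(), word.upper(), word.capitalize()):
            if candidate in pre_train_matrix:
                return pre_train_matrix[candidate]
        return None  # not found in the pretrained vectors

    print(lookup("Apple"))   # falls back to "apple"
    print(lookup("nlp"))     # falls back to "NLP" via word.upper()
    print(lookup("paris"))   # falls back to "Paris" via word.capitalize()
    print(lookup("banana"))  # None

Because the exact surface form is checked first, a vocab word whose own casing appears in the pretrained file keeps that vector; the fallbacks only apply to forms that would otherwise go unmatched.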