-
Notifications
You must be signed in to change notification settings - Fork 22
/
Copy pathmodel.py
51 lines (43 loc) · 2.2 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
"""
@file : model.py
@author : xiaolu
@email : [email protected]
@time : 2022-01-07
"""
import torch
from torch import nn
from transformers import BertConfig, BertModel
class Model(nn.Module):
    """Sentence encoder on top of a pretrained BERT checkpoint.

    ``forward`` runs BERT with hidden states enabled and reduces the
    token-level outputs to one vector per input sequence, using the strategy
    selected by ``encoder_type``.
    """

    # Accepted names for the first/last-layer average. The misspelled
    # 'fist-last-avg' was the original default and is kept as an alias so
    # existing callers that pass it explicitly keep working.
    _FIRST_LAST_AVG = ('first-last-avg', 'fist-last-avg')

    def __init__(self, config_path='./mengzi_pretrain/config.json',
                 model_path='./mengzi_pretrain/pytorch_model.bin'):
        """Load the BERT config and weights.

        :param config_path: path handed to ``BertConfig.from_pretrained``;
            defaults to the original hard-coded Mengzi config location.
        :param model_path: path handed to ``BertModel.from_pretrained``;
            defaults to the original hard-coded Mengzi weights location.
        """
        super().__init__()
        self.config = BertConfig.from_pretrained(config_path)
        self.bert = BertModel.from_pretrained(model_path, config=self.config)

    @staticmethod
    def _masked_mean(hidden, attention_mask):
        """Average ``hidden`` over dim 1, counting only unmasked positions.

        :param hidden: tensor indexed as (batch, seq_len, hidden_size).
        :param attention_mask: (batch, seq_len) tensor, non-zero for real
            tokens; if ``None`` every position is averaged.
        :return: (batch, hidden_size) tensor.
        """
        if attention_mask is None:
            return hidden.mean(dim=1)
        mask = attention_mask.unsqueeze(-1).to(hidden.dtype)  # (batch, seq_len, 1)
        summed = (hidden * mask).sum(dim=1)
        # Clamp so a row whose mask is all zeros yields zeros instead of NaN.
        counts = mask.sum(dim=1).clamp(min=1e-9)
        return summed / counts

    def forward(self, input_ids, attention_mask, encoder_type='first-last-avg'):
        """Encode a batch of token sequences into fixed-size vectors.

        :param input_ids: token ids passed straight to the BERT model.
        :param attention_mask: mask passed to BERT and also used to exclude
            padding from the averaging strategies (may be ``None``).
        :param encoder_type: pooling strategy, one of
            ``"first-last-avg"`` (alias ``"fist-last-avg"``): mean of the
            masked token averages of the first transformer layer and the
            last layer;
            ``"last-avg"``: masked token average of the last layer;
            ``"cls"``: last-layer vector at position 0;
            ``"pooler"``: BERT's ``pooler_output`` (CLS + dense layer).
        :return: tensor of shape (batch, hidden_size).
        :raises ValueError: if ``encoder_type`` is not one of the above.
        """
        output = self.bert(input_ids, attention_mask=attention_mask,
                           output_hidden_states=True)

        if encoder_type in self._FIRST_LAST_AVG:
            # hidden_states[0] is the embedding-layer output; index 1 is the
            # first transformer layer and -1 the last one.
            first_avg = self._masked_mean(output.hidden_states[1], attention_mask)
            last_avg = self._masked_mean(output.hidden_states[-1], attention_mask)
            return (first_avg + last_avg) / 2

        if encoder_type == 'last-avg':
            # last_hidden_state: (batch_size, max_len, hidden_size)
            return self._masked_mean(output.last_hidden_state, attention_mask)

        if encoder_type == 'cls':
            return output.last_hidden_state[:, 0]  # (batch, hidden_size)

        if encoder_type == 'pooler':
            return output.pooler_output  # (batch, hidden_size)

        raise ValueError(
            "unknown encoder_type %r; expected one of 'first-last-avg', "
            "'last-avg', 'cls', 'pooler'" % (encoder_type,)
        )