# model.py
import torch
import torch.nn as nn
from transformers import AutoModel
from modules import Instruction, EntityInit, GATLayer, NSMLayer, GraphEncType


class QAModel(nn.Module):
    def __init__(self, word_size, word_dim, hidden_dim, question_dropout, linear_dropout, num_step, pretrained_emb,
                 entity_size, entity_dim, relation_size, relation_dim, pretrained_relation, direction,
                 graph_encoder_type, gat_head_dim, gat_head_size, gat_dropout, gat_skip, gat_bias,
                 attn_key, attn_value, pretrained_model_name, hugging_face_cache):
        super(QAModel, self).__init__()
        assert direction in ('all', 'inward', 'outward')
        self.num_step = num_step
        self.direction = direction

        #
        # Question Encoder
        #
        self.pretrained_model_name = pretrained_model_name
        if pretrained_model_name is not None:
            self.pretrained_model = AutoModel.from_pretrained(pretrained_model_name, cache_dir=hugging_face_cache)
            self.pretrained_model.requires_grad_(False)
        else:
            if pretrained_emb is None:
                self.word_embedding = nn.Embedding(word_size, word_dim, padding_idx=0)
            else:
                self.word_embedding = nn.Embedding.from_pretrained(pretrained_emb, padding_idx=0, freeze=False)
        self.instruction_generator = Instruction(word_dim, hidden_dim, question_dropout, linear_dropout, num_step,
                                                 pretrained_model_name is not None)
        if entity_size > 0:
            self.entity_embedding = nn.Embedding(entity_size + 1, entity_dim, padding_idx=entity_size)
            self.ent_linear = nn.Linear(entity_dim, hidden_dim * (2 if direction == 'all' else 1))
        else:
            self.entity_embedding = None

        #
        # Relation Embedding & Entity Encoder
        #
        if pretrained_relation is None:
            self.relation_embedding = nn.Embedding(relation_size, relation_dim)
        else:
            self.relation_embedding = nn.Embedding.from_pretrained(pretrained_relation, freeze=False)
        self.rel_norm = nn.LayerNorm(relation_dim)
        self.entity_encoder = EntityInit(relation_dim, hidden_dim, direction)

        #
        # Graph Encoder
        #
        self.graph_encoder_type = graph_encoder_type
        num_dir = 2 if direction == 'all' else 1
        layers = []
        if graph_encoder_type == GraphEncType.GAT.name:
            num_in_features = [hidden_dim * num_dir] + [gat_head_dim * gat_head_size * num_dir for _ in range(num_step - 1)]
            for i in range(num_step):
                layers.append(GATLayer(
                    num_in_features[i], gat_head_dim, gat_head_size, relation_dim, hidden_dim,
                    concat=True, activation=nn.ELU(),
                    dropout_prob=gat_dropout, add_skip_connection=gat_skip, bias=gat_bias, direction=direction
                ))
            self.entity_proj = nn.Linear(gat_head_dim * gat_head_size * num_dir, hidden_dim)
        elif graph_encoder_type == GraphEncType.NSM.name:
            for i in range(num_step):
                layers.append(NSMLayer(
                    hidden_dim * num_dir, gat_head_dim, gat_head_size, hidden_dim, relation_dim, concat=True,
                    dropout=gat_dropout, direction=direction, skip=gat_skip, attn_key=attn_key, attn_value=attn_value
                ))
            self.entity_proj = nn.Linear(hidden_dim * num_dir, hidden_dim)
        elif graph_encoder_type == GraphEncType.MIX.name:
            # MIX alternates one GAT layer and one NSM layer per reasoning step,
            # which is why forward() indexes self.layers with i*2 and i*2+1.
            num_in_features = [hidden_dim * num_dir] + [gat_head_dim * gat_head_size * num_dir for _ in range(num_step - 1)]
            for i in range(num_step):
                layers.append(GATLayer(
                    num_in_features[i], gat_head_dim, gat_head_size, relation_dim, hidden_dim,
                    concat=True, activation=nn.ELU(),
                    dropout_prob=gat_dropout, add_skip_connection=gat_skip, bias=gat_bias, direction=direction
                ))
                layers.append(NSMLayer(
                    gat_head_dim * gat_head_size * num_dir, gat_head_dim, gat_head_size, hidden_dim, relation_dim,
                    concat=True, dropout=gat_dropout, direction=direction, skip=gat_skip
                ))
            self.entity_proj = nn.Linear(gat_head_dim * gat_head_size * num_dir, hidden_dim)
        else:
            raise ValueError("Unknown Graph Encoder Type: " + graph_encoder_type)
        self.layers = nn.ModuleList(layers)

    def forward(self, batch):
        question, question_mask, topic_label, candidate_entity, entity_mask, subgraph = batch
        batch_ids, batch_relations, edge_index = subgraph
        batch_size, max_local_entity = topic_label.shape

        if self.pretrained_model_name is not None:
            if 't5' in self.pretrained_model_name:
                # T5-style encoder-decoder: only the encoder hidden states are used.
                question = self.pretrained_model(
                    input_ids=question, attention_mask=question_mask,
                    decoder_input_ids=question, decoder_attention_mask=question_mask
                ).encoder_last_hidden_state
            else:
                question = self.pretrained_model(input_ids=question, attention_mask=question_mask).last_hidden_state
            instructions, question, _ = self.instruction_generator(question, question_mask)
        else:
            # [batch size, max seq len, word dim]
            question = self.word_embedding(question)
            # instructions: num step x [batch size, hidden dim]
            # question: [batch size, 1, hidden dim]
            # attentions: num step x [batch size, max seq len]
            instructions, question, attentions = self.instruction_generator(question, question_mask)
        # print("Question: %s" % question[0, 0, :5].tolist())

        # [fact size, relation dim]
        fact_relations = self.rel_norm(self.relation_embedding(batch_relations))
        # print("Relation Embedding: %s" % fact_relations[0, :5].tolist())

        # [batch size * max local entity, hidden dim * num dir]
        if self.entity_embedding is None:
            entity_emb = self.entity_encoder(fact_relations, edge_index, batch_size * max_local_entity)
        else:
            entity_emb = self.ent_linear(self.entity_embedding(candidate_entity).view(batch_size * max_local_entity, -1))
        # print("Entity Embedding: %s" % entity_emb[0, :5].tolist())

        for i in range(self.num_step):
            # print("instruction: %s" % instructions[i][0, :5].tolist())
            if self.graph_encoder_type == GraphEncType.GAT.name:
                entity_emb = self.layers[i](
                    entity_emb, edge_index, fact_relations, instructions[i], batch_ids, max_local_entity
                )
                # print("Step %d entity embedding: %s" % (i+1, entity_emb[0, :5].tolist()))
            elif self.graph_encoder_type == GraphEncType.NSM.name:
                entity_emb, topic_label = self.layers[i](
                    entity_emb, fact_relations, instructions[i], edge_index, batch_ids, topic_label, entity_mask
                )
                # print("Step %d entity embedding: %s" % (i+1, entity_emb[0, :5].tolist()))
            elif self.graph_encoder_type == GraphEncType.MIX.name:
                entity_emb = self.layers[i * 2](
                    entity_emb, edge_index, fact_relations, instructions[i], batch_ids, max_local_entity
                )
                # print("Step %d gat entity embedding: %s" % (i+1, entity_emb[0, :5].tolist()))
                entity_emb, topic_label = self.layers[i * 2 + 1](
                    entity_emb, fact_relations, instructions[i], edge_index, batch_ids, topic_label, entity_mask
                )
                # print("Step %d nsm entity embedding: %s" % (i+1, entity_emb[0, :5].tolist()))
            else:
                raise ValueError("Unknown Graph Encoder Type: " + self.graph_encoder_type)

        # [batch size, max local entity, hidden dim]
        entity_emb = self.entity_proj(entity_emb.view(batch_size, max_local_entity, -1))
        # print("Project entity emb: %s" % entity_emb[0, 0, :5].tolist())

        # [batch size, max local entity]; padded entities get a large negative score so they never win argmax.
        predict_scores = entity_mask * question.matmul(entity_emb.transpose(1, 2)).squeeze(1) + (1 - entity_mask) * -1e20
        # print("Scores: %s" % predict_scores[0, :5].tolist())
        return predict_scores
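

if __name__ == '__main__':
    # Minimal smoke-test sketch (not part of the original module). All hyperparameter
    # values and tensor shapes below are assumptions inferred from how __init__ and
    # forward() consume their arguments; they may need adjusting to match the layer
    # implementations in `modules`. It is meant only to illustrate the expected batch layout.
    batch_size, seq_len, max_local_entity, num_facts = 2, 8, 16, 40
    hidden_dim, word_dim, relation_dim = 64, 300, 100

    model = QAModel(
        word_size=1000, word_dim=word_dim, hidden_dim=hidden_dim,
        question_dropout=0.1, linear_dropout=0.1, num_step=3, pretrained_emb=None,
        entity_size=0, entity_dim=hidden_dim, relation_size=50, relation_dim=relation_dim,
        pretrained_relation=None, direction='all', graph_encoder_type='NSM',
        gat_head_dim=16, gat_head_size=4, gat_dropout=0.1, gat_skip=True, gat_bias=True,
        attn_key=True, attn_value=True, pretrained_model_name=None, hugging_face_cache=None
    )

    question = torch.randint(1, 1000, (batch_size, seq_len))          # word ids (no pretrained LM here)
    question_mask = torch.ones(batch_size, seq_len)
    topic_label = torch.zeros(batch_size, max_local_entity)
    topic_label[:, 0] = 1.0                                           # assume one topic entity per question
    candidate_entity = torch.zeros(batch_size, max_local_entity, dtype=torch.long)
    entity_mask = torch.ones(batch_size, max_local_entity)
    batch_ids = torch.randint(0, batch_size, (num_facts,))            # which question each fact belongs to
    batch_relations = torch.randint(0, 50, (num_facts,))              # relation id per fact
    edge_index = torch.randint(0, batch_size * max_local_entity, (2, num_facts))  # head/tail entity per fact

    scores = model((question, question_mask, topic_label, candidate_entity, entity_mask,
                    (batch_ids, batch_relations, edge_index)))
    print(scores.shape)  # expected: [batch_size, max_local_entity]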