import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
from dgl.data.utils import load_graphs
from dgl.sampling import sample_neighbors
from dgl.nn.pytorch import GATConv, GraphConv
from openhgnn.models import BaseModel, register_model
from ..utils.utils import extract_metapaths
def init_drop(dropout):
    # Return a Dropout module, or an identity function when dropout is disabled
if dropout > 0:
return nn.Dropout(dropout)
else:
return lambda x: x
@register_model('HeCo')
class HeCo(BaseModel):
r"""
**Title:** Self-supervised Heterogeneous Graph Neural Network with Co-contrastive Learning
**Authors:** Xiao Wang, Nian Liu, Hui Han, Chuan Shi
HeCo was introduced in `[paper] <http://shichuan.org/doc/112.pdf>`_
and parameters are defined as follows:
Parameters
----------
    meta_paths_dict : dict
        Metapaths used to build the meta-path view, extracted from the graph
network_schema : dict
Directed edges from other types to target type
    category : str
        The type of the nodes to be classified
hidden_size : int
Hidden units size
feat_drop : float
Dropout rate for projected feature
attn_drop : float
Dropout rate for attentions used in two view guided encoders
    sample_rate : dict
        The number of neighbors of each type sampled for the network schema view
tau : float
Temperature parameter used for contrastive loss
lam : float
Balance parameter for two contrastive losses
"""
@classmethod
def build_model_from_args(cls, args, hg):
if args.meta_paths_dict is None:
meta_paths_dict = extract_metapaths(args.category, hg.canonical_etypes)
else:
meta_paths_dict = args.meta_paths_dict
schema = []
for etype in hg.canonical_etypes:
if etype[2] == args.category:
schema.append(etype)
return cls(meta_paths_dict=meta_paths_dict, network_schema=schema, category=args.category,
hidden_size=args.hidden_dim, feat_drop=args.feat_drop,
attn_drop=args.attn_drop, sample_rate=args.sample_rate, tau=args.tau, lam=args.lam)
    def __init__(self, meta_paths_dict, network_schema, category, hidden_size,
                 feat_drop, attn_drop, sample_rate, tau, lam):
super(HeCo, self).__init__()
self.category = category # target node type
self.feat_drop = init_drop(feat_drop)
self.attn_drop = attn_drop
self.mp = Mp_encoder(meta_paths_dict, hidden_size, self.attn_drop)
self.sc = Sc_encoder(network_schema, hidden_size, self.attn_drop, sample_rate, self.category)
self.contrast = Contrast(hidden_size, tau, lam)
def forward(self, g, h_dict, pos):
r"""
This is the forward part of model HeCo.
Parameters
----------
        g : DGLGraph
            The heterogeneous input graph
        h_dict : dict
            Projected features after linear projection
        pos : sparse matrix
            A matrix indicating the positives of each node
Returns
-------
        loss : float
            The contrastive objective to optimize
Note
-----------
        The pos matrix is pre-defined by users. The related tool is provided in the original code.
"""
new_h = {}
for key, value in h_dict.items():
new_h[key] = F.elu(self.feat_drop(value))
z_mp = self.mp(g, new_h[self.category])
z_sc = self.sc(g, new_h)
loss = self.contrast(z_mp, z_sc, pos)
return loss
def get_embeds(self, g, h_dict):
r"""
        Get the final embeddings of the target nodes from the meta-path view.
"""
z_mp = F.elu(h_dict[self.category])
z_mp = self.mp(g, z_mp)
return z_mp.detach()
class SelfAttention(nn.Module):
def __init__(self, hidden_dim, attn_drop, txt):
r"""
        This part calculates type-level and semantic-level attention and uses them to generate :math:`z^{sc}` and :math:`z^{mp}`.
.. math::
w_{n}&=\frac{1}{|V|}\sum\limits_{i\in V} \textbf{a}^\top \cdot \tanh\left(\textbf{W}h_i^{n}+\textbf{b}\right) \\
\beta_{n}&=\frac{\exp\left(w_{n}\right)}{\sum_{i=1}^M\exp\left(w_{i}\right)} \\
z &= \sum_{n=1}^M \beta_{n}\cdot h^{n}
Parameters
----------
        hidden_dim : int
            Size of the hidden embeddings
        attn_drop : float
            Dropout rate applied to the attention vector
        txt : str
            A string identifying the view, "mp" or "sc"
Returns
-------
z : matrix
The fused embedding matrix
"""
super(SelfAttention, self).__init__()
self.fc = nn.Linear(hidden_dim, hidden_dim, bias=True)
nn.init.xavier_normal_(self.fc.weight, gain=1.414)
self.tanh = nn.Tanh()
self.att = nn.Parameter(torch.empty(size=(1, hidden_dim)), requires_grad=True)
nn.init.xavier_normal_(self.att.data, gain=1.414)
self.softmax = nn.Softmax(dim=0)
self.attn_drop = init_drop(attn_drop)
self.txt = txt
def forward(self, embeds):
beta = []
attn_curr = self.attn_drop(self.att)
for embed in embeds:
sp = self.tanh(self.fc(embed)).mean(dim=0)
beta.append(attn_curr.matmul(sp.t()))
beta = torch.cat(beta, dim=-1).view(-1)
beta = self.softmax(beta)
        print(self.txt, beta.data.cpu().numpy())  # log the learned view-level attention weights
z = 0
for i in range(len(embeds)):
z += embeds[i] * beta[i]
return z
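

# A minimal sanity check for SelfAttention (editor's sketch, not part of the
# original model). The sizes are illustrative assumptions: three views over
# five nodes with 8-dimensional embeddings, fused by the learned view-level
# attention. Because this module uses relative imports, run the demos below
# with ``python -m openhgnn.models.HeCo`` (module path assumed).
if __name__ == "__main__":
    _att = SelfAttention(hidden_dim=8, attn_drop=0.0, txt="demo")
    _views = [torch.randn(5, 8) for _ in range(3)]  # one matrix per view
    _fused = _att(_views)  # beta-weighted sum over the three views
    assert _fused.shape == (5, 8)
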
class Mp_encoder(nn.Module):
def __init__(self, meta_paths_dict, hidden_size, attn_drop):
r"""
This part is to encode meta-path view.
Returns
-------
z_mp : matrix
The embedding matrix under meta-path view.
"""
super(Mp_encoder, self).__init__()
        # One GCN layer per meta-path-based adjacency matrix
self.act = nn.PReLU()
self.gcn_layers = nn.ModuleDict()
for mp in meta_paths_dict:
one_layer = GraphConv(hidden_size, hidden_size, activation=self.act, allow_zero_in_degree=True)
one_layer.reset_parameters()
self.gcn_layers[mp] = one_layer
self.meta_paths_dict = meta_paths_dict
self._cached_graph = None
self._cached_coalesced_graph = {}
self.semantic_attention = SelfAttention(hidden_size, attn_drop, "mp")
def forward(self, g, h):
semantic_embeddings = []
if self._cached_graph is None or self._cached_graph is not g:
self._cached_graph = g
self._cached_coalesced_graph.clear()
for mp, meta_path in self.meta_paths_dict.items():
self._cached_coalesced_graph[mp] = dgl.metapath_reachable_graph(
g, meta_path)
for mp, meta_path in self.meta_paths_dict.items():
new_g = self._cached_coalesced_graph[mp]
one = self.gcn_layers[mp](new_g, h)
            semantic_embeddings.append(one)  # one embedding matrix per meta-path
z_mp = self.semantic_attention(semantic_embeddings)
return z_mp
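

# A hedged smoke test for Mp_encoder (editor's sketch). The toy paper/author
# graph and the single PAP meta-path are illustrative assumptions, not data
# from the paper or this repository.
if __name__ == "__main__":
    _g = dgl.heterograph({
        ('paper', 'pa', 'author'): (torch.tensor([0, 1]), torch.tensor([0, 1])),
        ('author', 'ap', 'paper'): (torch.tensor([0, 1]), torch.tensor([0, 1])),
    })
    _mp_enc = Mp_encoder({'PAP': [('paper', 'pa', 'author'), ('author', 'ap', 'paper')]},
                         hidden_size=8, attn_drop=0.0)
    _z_mp = _mp_enc(_g, torch.randn(2, 8))  # one row per target ('paper') node
    assert _z_mp.shape == (2, 8)
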
class Sc_encoder(nn.Module):
def __init__(self, network_schema, hidden_size, attn_drop, sample_rate, category):
r"""
This part is to encode network schema view.
Returns
-------
        z_sc : matrix
            The embedding matrix under network schema view.
Note
-----------
        The sampling strategy differs from the original code: there, the authors sample without replacement when a node has more neighbors than the fanout, and with replacement otherwise. This version simply uses the API dgl.sampling.sample_neighbors with replace=True.
"""
super(Sc_encoder, self).__init__()
self.gat_layers = nn.ModuleList()
for i in range(len(network_schema)):
one_layer = GATConv((hidden_size, hidden_size), hidden_size, num_heads=1, attn_drop=attn_drop,
allow_zero_in_degree=True)
one_layer.reset_parameters()
self.gat_layers.append(one_layer)
self.network_schema = list(tuple(ns) for ns in network_schema)
self._cached_graph = None
self._cached_coalesced_graph = {}
self.inter = SelfAttention(hidden_size, attn_drop, "sc")
self.sample_rate = sample_rate
self.category = category
def forward(self, g, h):
intra_embeddings = []
for i, network_schema in enumerate(self.network_schema):
src_type = network_schema[0]
one_graph = g[network_schema]
cate_num = torch.arange(0, g.num_nodes(self.category)).to(g.device)
sub_graph = sample_neighbors(one_graph, {self.category: cate_num},
{network_schema[1]: self.sample_rate[src_type]}, replace=True)
one = self.gat_layers[i](sub_graph, (h[src_type], h[self.category]))
one = one.squeeze(1)
intra_embeddings.append(one)
z_sc = self.inter(intra_embeddings)
return z_sc
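

# A hedged smoke test for Sc_encoder (editor's sketch). The single
# author->paper relation and the fanout of 2 are illustrative assumptions.
if __name__ == "__main__":
    _g = dgl.heterograph({
        ('author', 'ap', 'paper'): (torch.tensor([0, 1, 1]), torch.tensor([0, 0, 1])),
    })
    _sc_enc = Sc_encoder([('author', 'ap', 'paper')], hidden_size=8, attn_drop=0.0,
                         sample_rate={'author': 2}, category='paper')
    _z_sc = _sc_enc(_g, {'author': torch.randn(2, 8), 'paper': torch.randn(2, 8)})
    assert _z_sc.shape == (2, 8)
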
class Contrast(nn.Module):
def __init__(self, hidden_dim, tau, lam):
r"""
This part is used to calculate the contrastive loss.
Returns
-------
contra_loss : float
The calculated loss
"""
super(Contrast, self).__init__()
self.proj = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.ELU(),
nn.Linear(hidden_dim, hidden_dim)
)
self.tau = tau
self.lam = lam
for model in self.proj:
if isinstance(model, nn.Linear):
nn.init.xavier_normal_(model.weight, gain=1.414)
def sim(self, z1, z2):
r"""
        Calculate the exponentiated, temperature-scaled cosine similarity of each pair of nodes from the two views.
"""
z1_norm = torch.norm(z1, dim=-1, keepdim=True)
z2_norm = torch.norm(z2, dim=-1, keepdim=True)
dot_numerator = torch.mm(z1, z2.t())
dot_denominator = torch.mm(z1_norm, z2_norm.t())
sim_matrix = torch.exp(dot_numerator / dot_denominator / self.tau)
return sim_matrix
def forward(self, z_mp, z_sc, pos):
r"""
This is the forward part of contrast part.
        We first project the embeddings from the two views into the space where the contrastive loss is calculated. Then we compute the contrastive loss with the projected embeddings in a cross-view way.
.. math::
            \mathcal{L}_i^{sc}=-\log\frac{\sum_{j\in\mathbb{P}_i}\exp\left(\mathrm{sim}\left(z_i^{sc\_proj},z_j^{mp\_proj}\right)/\tau\right)}{\sum_{k\in\mathbb{P}_i\cup\mathbb{N}_i}\exp\left(\mathrm{sim}\left(z_i^{sc\_proj},z_k^{mp\_proj}\right)/\tau\right)}
where we show the contrastive loss :math:`\mathcal{L}_i^{sc}` under network schema view, and :math:`\mathbb{P}_i` and :math:`\mathbb{N}_i` are positives and negatives for node :math:`i`.
        In a similar way, we can get the contrastive loss :math:`\mathcal{L}_i^{mp}` under the meta-path view. Finally, we combine the two losses with the balance parameter :math:`\lambda`.
Note
-----------
        In the implementation, each row of ``matrix_mp2sc`` holds the exponentiated similarities between one node under the meta-path view and all nodes under the network schema view. Each row is normalized, the entries belonging to positive pairs are selected and summed, and the negative log of this sum, averaged over all rows, gives the final loss.
"""
z_proj_mp = self.proj(z_mp)
z_proj_sc = self.proj(z_sc)
matrix_mp2sc = self.sim(z_proj_mp, z_proj_sc)
matrix_sc2mp = matrix_mp2sc.t()
matrix_mp2sc = matrix_mp2sc / (torch.sum(matrix_mp2sc, dim=1).view(-1, 1) + 1e-8)
lori_mp = -torch.log(matrix_mp2sc.mul(pos.to_dense()).sum(dim=-1)).mean()
matrix_sc2mp = matrix_sc2mp / (torch.sum(matrix_sc2mp, dim=1).view(-1, 1) + 1e-8)
lori_sc = -torch.log(matrix_sc2mp.mul(pos.to_dense()).sum(dim=-1)).mean()
contra_loss = self.lam * lori_mp + (1 - self.lam) * lori_sc
return contra_loss
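

# A hedged numerical check for Contrast (editor's sketch). Random embeddings
# and an identity positive mask are illustrative assumptions; with them each
# node is its own sole positive, and the resulting InfoNCE-style loss must be
# a finite scalar.
if __name__ == "__main__":
    _con = Contrast(hidden_dim=8, tau=0.5, lam=0.5)
    _pos = torch.eye(4).to_sparse()  # each node is its own sole positive
    _loss = _con(torch.randn(4, 8), torch.randn(4, 8), _pos)
    assert _loss.dim() == 0 and torch.isfinite(_loss)
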
# Logistic-regression head for downstream evaluation
class LogReg(nn.Module):
r"""
Parameters
----------
    ft_in : int
        Size of the input features (hidden units)
    nb_classes : int
        The number of classes
"""
def __init__(self, ft_in, nb_classes):
super(LogReg, self).__init__()
self.fc = nn.Linear(ft_in, nb_classes)
for m in self.modules():
self.weights_init(m)
def weights_init(self, m):
if isinstance(m, nn.Linear):
torch.nn.init.xavier_uniform_(m.weight.data)
if m.bias is not None:
m.bias.data.fill_(0.0)
def forward(self, seq):
ret = self.fc(seq)
return ret
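

# End-to-end smoke test (editor's sketch): wire HeCo and LogReg together on a
# toy graph. The graph, sizes and hyper-parameters are illustrative
# assumptions, not repository defaults.
if __name__ == "__main__":
    _g = dgl.heterograph({
        ('author', 'ap', 'paper'): (torch.tensor([0, 1, 2]), torch.tensor([0, 1, 2])),
        ('paper', 'pa', 'author'): (torch.tensor([0, 1, 2]), torch.tensor([0, 1, 2])),
    })
    _model = HeCo(meta_paths_dict={'PAP': [('paper', 'pa', 'author'),
                                           ('author', 'ap', 'paper')]},
                  network_schema=[('author', 'ap', 'paper')], category='paper',
                  hidden_size=8, feat_drop=0.0, attn_drop=0.0,
                  sample_rate={'author': 2}, tau=0.5, lam=0.5)
    _h = {'paper': torch.randn(3, 8), 'author': torch.randn(3, 8)}
    _loss = _model(_g, _h, torch.eye(3).to_sparse())  # self-supervised loss
    _emb = _model.get_embeds(_g, _h)  # (3, 8) frozen target-node embeddings
    _logits = LogReg(ft_in=8, nb_classes=2)(_emb)  # linear evaluation head
    assert _logits.shape == (3, 2) and torch.isfinite(_loss)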