-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathnn_utils.py
91 lines (86 loc) · 3.58 KB
/
nn_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
# -*- coding: utf-8 -*-
import numpy as np
import pandas as pd
def get_batch_input(dataset, bs, args, idx=None,all=False):
    """Slice one mini-batch out of ``dataset`` and convert its fields to arrays.

    Args:
        dataset: Sliceable collection of per-example records (e.g. a list of
            dicts) accepted by ``pd.DataFrame.from_dict``. Every record must
            provide each field selected below.
        bs: Batch size. Unused when ``all`` is true.
        args: Config object. ``args.model`` selects the field set. For
            ``'ASGCN'``/``'KGNN'``, ``args.gcn`` (1, 2 or other) also
            selects extra fields.
        idx: Batch index. The batch is ``dataset[idx*bs:(idx+1)*bs]``.
            Required unless ``all`` is true.
        all: If true, use the whole ``dataset`` as one batch. (The name
            shadows the builtin and is kept for caller compatibility.)

    Returns:
        A list of ``np.ndarray``, one per selected field, in the order the
        fields are listed below. ``'pw'`` and ``'graph_embedding'`` are
        float32; every other field is int32.

    Raises:
        TypeError: If ``all`` is false and ``idx`` is None.
        KeyError: If a selected field is missing from the batch.
        ValueError: If a field cannot be packed into a rectangular array
            (e.g. ragged, unpadded sequences).
    """
    if all:
        batch_input = dataset
    else:
        if idx is None:
            raise TypeError("get_batch_input: idx is required when all=False")
        batch_input = dataset[idx * bs:(idx + 1) * bs]
    batch_data = pd.DataFrame.from_dict(batch_input)

    if args.model == 'RGAT':
        target_fields = ['wids', 'tids', 'dep_tag_ids', 'y', 'pw']
    elif args.model in ['ASGCN', 'KGNN']:
        target_fields = ['wids', 'tids', 'y', 'pw', 'adj', 'mask']
        if args.gcn == 1:
            target_fields += ["know_adj", "know_mask", "wids_know"]
        elif args.gcn == 2:
            target_fields += ["graph_embedding"]
    else:
        target_fields = ['wids', 'tids', 'y', 'pw']

    float_fields = {'pw', 'graph_embedding'}
    batch_input_var = []
    for key in target_fields:
        dtype = 'float32' if key in float_fields else 'int32'
        try:
            batch_input_var.append(np.array(list(batch_data[key].values), dtype=dtype))
        except ValueError as err:
            # The old code printed the values and skipped this field. That
            # shifted every later field in the returned list, so callers
            # unpacking by position silently got the wrong arrays. Fail loudly.
            raise ValueError(
                f"get_batch_input: cannot convert field {key!r} to a {dtype} "
                f"array; check that every example's value has the same shape"
            ) from err
    return batch_input_var
def get_batch_input_inference(dataset, bs, args, idx=None,all=False):
    """Like ``get_batch_input`` but also returns raw token fields for inspection.

    For ``'ASGCN'``/``'KGNN'`` the raw ``'words'`` and ``'twords'`` columns are
    prepended as plain Python lists (not converted to arrays).

    Args:
        dataset: Sliceable collection of per-example records (e.g. a list of
            dicts) accepted by ``pd.DataFrame.from_dict``.
        bs: Batch size. Unused when ``all`` is true.
        args: Config object. ``args.model`` selects the field set. For
            ``'ASGCN'``/``'KGNN'``, ``args.gcn`` (1, 2 or other) also
            selects extra fields.
        idx: Batch index. The batch is ``dataset[idx*bs:(idx+1)*bs]``.
            Required unless ``all`` is true.
        all: If true, use the whole ``dataset`` as one batch. (The name
            shadows the builtin and is kept for caller compatibility.)

    Returns:
        A list with one entry per selected field, in order. ``'words'`` and
        ``'twords'`` are lists of the raw per-example values. ``'pw'`` and
        ``'graph_embedding'`` are float32 arrays. All other fields are int32
        arrays.

    Raises:
        TypeError: If ``all`` is false and ``idx`` is None.
        KeyError: If a selected field is missing from the batch.
        ValueError: If a numeric field cannot be packed into a rectangular
            array (e.g. ragged, unpadded sequences).
    """
    if all:
        batch_input = dataset
    else:
        if idx is None:
            raise TypeError("get_batch_input_inference: idx is required when all=False")
        batch_input = dataset[idx * bs:(idx + 1) * bs]
    batch_data = pd.DataFrame.from_dict(batch_input)

    # NOTE(review): only the ASGCN/KGNN branches include 'words'/'twords';
    # the RGAT and default branches match get_batch_input. Confirm whether
    # that asymmetry is intentional.
    if args.model == 'RGAT':
        target_fields = ['wids', 'tids', 'dep_tag_ids', 'y', 'pw']
    elif args.model in ['ASGCN', 'KGNN']:
        target_fields = ['words', 'twords', 'wids', 'tids', 'y', 'pw', 'adj', 'mask']
        if args.gcn == 1:
            target_fields += ["know_adj", "know_mask", "wids_know"]
        elif args.gcn == 2:
            target_fields += ["graph_embedding"]
    else:
        target_fields = ['wids', 'tids', 'y', 'pw']

    float_fields = {'pw', 'graph_embedding'}
    raw_fields = {'words', 'twords'}
    batch_input_var = []
    for key in target_fields:
        data = list(batch_data[key].values)
        if key in raw_fields:
            batch_input_var.append(data)
            continue
        dtype = 'float32' if key in float_fields else 'int32'
        try:
            batch_input_var.append(np.array(data, dtype=dtype))
        except ValueError as err:
            # Skipping the field (the old behavior) misaligned every later
            # entry of the returned list; fail with the offending field name.
            raise ValueError(
                f"get_batch_input_inference: cannot convert field {key!r} to a "
                f"{dtype} array; check that every example's value has the same shape"
            ) from err
    return batch_input_var
def get_batch_input_bert(dataset, bs, args, idx=None,all=False):
    """Slice one mini-batch of BERT-tokenized examples and convert fields to arrays.

    Args:
        dataset: Sliceable collection of per-example records (e.g. a list of
            dicts) accepted by ``pd.DataFrame.from_dict``.
        bs: Batch size. Unused when ``all`` is true.
        args: Config object. ``args.model`` selects the field set:
            ``'RGAT'``, ``'ASGCN'``/``'KGNN'``, ``'BERT_vanilla'``, or the
            default.
        idx: Batch index. The batch is ``dataset[idx*bs:(idx+1)*bs]``.
            Required unless ``all`` is true.
        all: If true, use the whole ``dataset`` as one batch. (The name
            shadows the builtin and is kept for caller compatibility.)

    Returns:
        A list of ``np.ndarray``, one per selected field, in order.
        ``'pw'`` is float32; every other field is int32.

    Raises:
        TypeError: If ``all`` is false and ``idx`` is None.
        KeyError: If a selected field is missing from the batch.
        ValueError: If a field cannot be packed into a rectangular array
            (e.g. ragged, unpadded token sequences).
    """
    if all:
        batch_input = dataset
    else:
        if idx is None:
            raise TypeError("get_batch_input_bert: idx is required when all=False")
        batch_input = dataset[idx * bs:(idx + 1) * bs]
    batch_data = pd.DataFrame.from_dict(batch_input)

    if args.model == 'RGAT':
        target_fields = ['bert_token', 'bert_token_aspect', 'dep_tag_ids', 'y', 'pw']
    elif args.model in ['ASGCN', 'KGNN']:
        target_fields = ['bert_token', 'bert_token_aspect', 'y', 'pw', 'adj', 'mask']
    elif args.model == "BERT_vanilla":
        target_fields = ['concat_token', 'y']
    else:
        target_fields = ['bert_token', 'bert_token_aspect', 'y', 'pw']

    batch_input_var = []
    for key in target_fields:
        dtype = 'float32' if key == 'pw' else 'int32'
        try:
            batch_input_var.append(np.array(list(batch_data[key].values), dtype=dtype))
        except ValueError as err:
            # Skipping the field (the old behavior) misaligned every later
            # entry of the returned list; fail with the offending field name.
            raise ValueError(
                f"get_batch_input_bert: cannot convert field {key!r} to a {dtype} "
                f"array; check that every example's value has the same shape"
            ) from err
    return batch_input_var