-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathcreate_graph_data.py
167 lines (153 loc) · 7.11 KB
/
create_graph_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
####
# This file was created because of the heaviness of the computations required to create a large dataset using graphs
####
import json
import torch
import time
from transformers import BertTokenizer, BertModel
from graph_utils import load_json
import argparse
# Command-line interface of the script: `--dataset` selects which dataset's
# triplets are converted into graph data.
parser = argparse.ArgumentParser(description='Extract graph data from triplets.')
# The flag used to be declared with both `default='ucm'` and `required=True`,
# which made the default unreachable. It is now optional so the default
# actually applies, and `choices` rejects unknown dataset names at parse time
# instead of letting the script run without doing anything.
parser.add_argument(
    '--dataset',
    default='ucm',
    choices=('rsicd', 'ucm'),
    help='name of the dataset of which you want to create the triplets.',
)
def RSICD_tripl2graph(triplets_path, model, tokenizer):
    '''
    Function that extract DGLGraph data from triplets and puts them into files.

    The triplets file is loaded with `load_json` and processed split by split
    (the 'discarded_images' entry is skipped). For every id of a split:
      - each distinct triplet element becomes a node, numbered in order of
        first appearance;
      - the node feature is the row of the model's `pooler_output` computed
        for the triplet in which the element first appears;
      - each triplet adds the edges element 0 -> element 1 and
        element 1 -> element 2.
    For every split four JSON files are written in the current working
    directory: 'src_ids_<split>.json', 'dst_ids_<split>.json',
    'node_feats_<split>.json' and 'num_nodes_<split>.json', each mapping an
    id to its data.

    Assumes each triplet is a sequence of three strings, so that the
    tokenizer produces a batch of three rows -- TODO confirm against the
    triplets file format.

    Args:
        triplets_path (str): path of the triplet's file
        model (torch.nn.Module): model to extract the embeddings from tokens;
            its output must expose `pooler_output`
        tokenizer: tokenizer to tokenize the captions
    Return:
        None
    '''
    all_triplets = load_json(triplets_path)
    for split in all_triplets:
        if str(split) == 'discarded_images':
            continue
        split_time = time.time()
        triplets = all_triplets[split]
        node_feats = {}
        num_nodes = {}
        src_ids = {}
        dst_ids = {}
        # TODO: check what happens when split is not passed
        for sample_id in triplets:
            node_index = {}  # triplet element -> node id within this graph
            tmp_src_ids = []
            tmp_dst_ids = []
            tmp_node_feats = []
            # Extract features from triplets
            for tripl in triplets[sample_id]:
                encoded_input = tokenizer(tripl, return_tensors='pt', add_special_tokens=False, padding=True)
                # Inference only: without no_grad every forward pass would
                # build (and keep alive) an autograd graph.
                with torch.no_grad():
                    pooled = model(**encoded_input).pooler_output
                for pos in range(3):
                    element = tripl[pos]
                    if element not in node_index:
                        node_index[element] = len(node_index)
                        # Plain Python floats, directly JSON-serialisable
                        tmp_node_feats.append(pooled[pos].tolist())
                # Create source and destination lists
                first, second, third = (node_index[tripl[pos]] for pos in range(3))
                tmp_src_ids.extend((first, second))
                tmp_dst_ids.extend((second, third))
            src_ids[sample_id] = tmp_src_ids
            dst_ids[sample_id] = tmp_dst_ids
            node_feats[sample_id] = tmp_node_feats
            num_nodes[sample_id] = len(tmp_node_feats)
        # Write onto files
        outputs = (
            ('src_ids_', src_ids),
            ('dst_ids_', dst_ids),
            ('node_feats_', node_feats),
            ('num_nodes_', num_nodes),
        )
        for prefix, payload in outputs:
            with open(prefix + str(split) + '.json', 'w') as f:
                json.dump(payload, f)
        print("{} split done! Total time: {}".format(str(split), (time.time()-split_time)))
# This is a temporary version due to the changes introduced with the RSICD dataset
def UCM_tripl2graph(triplets_path, model, tokenizer, dataset):
    '''
    Function that extract DGLGraph data from triplets and puts them into files.

    The file '<triplets_path>/triplets<dataset>.json' is loaded with
    `load_json` and its 'train', 'val' and 'test' entries are processed in
    turn. For every id of a split:
      - each distinct triplet element becomes a node, numbered in order of
        first appearance;
      - the node feature is the row of the model's `pooler_output` computed
        for the triplet in which the element first appears;
      - each triplet adds the edges element 0 -> element 1 and
        element 1 -> element 2.
    For every split four JSON files are written in the current working
    directory: 'src_ids_<split>.json', 'dst_ids_<split>.json',
    'node_feats_<split>.json' and 'num_nodes_<split>.json', each mapping an
    id to its data.

    Assumes each triplet is a sequence of three strings, so that the
    tokenizer produces a batch of three rows -- TODO confirm against the
    triplets file format.

    Args:
        triplets_path (str): directory containing the triplet's file
        model (torch.nn.Module): model to extract the embeddings from tokens;
            its output must expose `pooler_output`
        tokenizer: tokenizer to tokenize the captions
        dataset (str): suffix inserted in the triplets file name
            (e.g. '_ucm' -> 'triplets_ucm.json')
    Return:
        None
    Raises:
        SystemExit: with a non-zero status when the triplets file cannot be
            read or parsed
    '''
    try:
        triplets = load_json(triplets_path+'/'+'triplets'+dataset+'.json')
    except (OSError, ValueError) as err:
        # Only I/O and JSON-decoding failures are treated as "no file"; a
        # bare except would also hide Ctrl-C and programming errors. The
        # exit status is non-zero so that callers can detect the failure.
        print("No triplets file for {} split".format(dataset))
        raise SystemExit(1) from err
    splits = ['train', 'val', 'test']
    for split in splits:
        split_time = time.time()
        caption_tripl = triplets[split]
        node_feats = {}
        num_nodes = {}
        src_ids = {}
        dst_ids = {}
        # TODO: check what happens when split is not passed
        for sample_id in caption_tripl:
            node_index = {}  # triplet element -> node id within this graph
            tmp_src_ids = []
            tmp_dst_ids = []
            tmp_node_feats = []
            # Extract features from triplets
            for tripl in caption_tripl[sample_id]:
                encoded_input = tokenizer(tripl, return_tensors='pt', add_special_tokens=False, padding=True)
                # Inference only: without no_grad every forward pass would
                # build an autograd graph that is never used.
                with torch.no_grad():
                    pooled = model(**encoded_input).pooler_output
                for pos in range(3):
                    element = tripl[pos]
                    if element not in node_index:
                        node_index[element] = len(node_index)
                        # Plain Python floats, directly JSON-serialisable
                        tmp_node_feats.append(pooled[pos].tolist())
                # Create source and destination lists
                first, second, third = (node_index[tripl[pos]] for pos in range(3))
                tmp_src_ids.extend((first, second))
                tmp_dst_ids.extend((second, third))
            src_ids[sample_id] = tmp_src_ids
            dst_ids[sample_id] = tmp_dst_ids
            node_feats[sample_id] = tmp_node_feats
            num_nodes[sample_id] = len(tmp_node_feats)
        # Write onto files
        outputs = (
            ('src_ids_', src_ids),
            ('dst_ids_', dst_ids),
            ('node_feats_', node_feats),
            ('num_nodes_', num_nodes),
        )
        for prefix, payload in outputs:
            with open(prefix + str(split) + '.json', 'w') as f:
                json.dump(payload, f)
        print("{} split done! Total time: {}".format(str(split), (time.time()-split_time)))
if __name__ == '__main__':
    # Parse and validate the arguments first, so that a wrong invocation
    # fails before the BERT weights are loaded.
    args = parser.parse_args()
    if args.dataset not in ('rsicd', 'ucm'):
        # Previously an unknown dataset fell through both branches and the
        # script reported success without producing any file.
        parser.error("unsupported dataset '{}' (expected 'rsicd' or 'ucm')".format(args.dataset))

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = BertModel.from_pretrained("bert-base-uncased")

    # The timer covers only the graph extraction, not the model loading
    total_time = time.time()
    if args.dataset == 'rsicd':
        triplets_path = 'dataset/RSICD_dataset/triplets_rsicd.json'
        RSICD_tripl2graph(triplets_path, model, tokenizer)
    else:
        triplets_path = 'dataset/UCM_dataset'
        UCM_tripl2graph(triplets_path, model, tokenizer, '_ucm')
    print("Done everything! Total time: {}".format((time.time()-total_time)))