-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdata_utils.py
33 lines (25 loc) · 1.01 KB
/
data_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
#%%
import pickle
from torch_geometric.data import Dataset, Batch
import numpy as np
class GraphSetDataset(Dataset):
    """Dataset of pickled samples whose last element is a regression target.

    The whole pickled container is loaded into memory at construction time.
    The mean and standard deviation of all targets are computed once and
    used by ``get`` to z-score each sample's target on access; ``get_orig``
    applies the inverse transform.

    Attributes set by ``__init__``:
        data_list: the object loaded from the pickle file; it must support
            ``len()`` and integer indexing, and each sample must be a tuple
            (``get`` concatenates a tuple onto ``sample[:-1]``).
        target_mean: ``np.mean`` over every sample's last element.
        target_std: ``np.std`` over every sample's last element.
    """

    def __init__(self, data_path):
        """Load samples from ``data_path`` and compute target statistics.

        Args:
            data_path: path of the pickle file holding the samples.
        """
        super().__init__()
        # NOTE(security): pickle.load can execute arbitrary code embedded in
        # the file. Only point this at dataset files from a trusted source.
        with open(data_path, 'rb') as f:
            self.data_list = pickle.load(f)
        # Index explicitly (rather than iterating) so that any container
        # supporting len() and integer indexing keeps working, exactly the
        # access pattern `get` relies on.
        targets = np.array(
            [self.data_list[i][-1] for i in range(len(self.data_list))]
        )
        self.target_mean = np.mean(targets)
        self.target_std = np.std(targets)

    def len(self):
        """Return the number of samples loaded from the pickle file."""
        return len(self.data_list)

    def get(self, idx):
        """Return sample ``idx`` with its target replaced by a z-scored value.

        Args:
            idx: index into ``data_list``.

        Returns:
            A tuple equal to the stored sample except that the last element
            is a one-element list ``[float]`` holding
            ``(target - target_mean) / target_std``.
        """
        sample = self.data_list[idx]
        # A zero standard deviation (all targets equal, or a single sample)
        # would turn every normalized target into NaN/inf. Fall back to a
        # unit divisor so such targets normalize to 0.0; `get_orig` still
        # maps that back to the mean because it multiplies by the stored
        # (zero) std. The comparison is also False for a NaN std.
        scale = self.target_std if self.target_std > 0 else 1.0
        normalized_target = [float((sample[-1] - self.target_mean) / scale)]
        return sample[:-1] + (normalized_target,)

    def get_orig(self, target):
        """Map a normalized target back to the original target scale.

        Args:
            target: value (or array-like supporting ``*`` and ``+``) in
                normalized units.

        Returns:
            ``target * target_std + target_mean``.
        """
        return target * self.target_std + self.target_mean
def graph_set_collate(batch):
    """Collate ``(graph_list, y)`` samples into per-sample batched graphs.

    Args:
        batch: iterable of two-element samples; the first element is a list
            of graph data objects, the second is that sample's target.

    Returns:
        A pair ``(batched_graph_sets, ys)``: a list holding one ``Batch``
        per sample (built from that sample's graph list), and a tuple of
        the targets in the same order.
    """
    # Transpose the samples: one tuple of graph lists, one tuple of targets.
    per_sample_graphs, ys = zip(*batch)

    batched_graph_sets = []
    for graphs in per_sample_graphs:
        # Each sample's graphs are merged into their own Batch object.
        batched_graph_sets.append(Batch.from_data_list(graphs))

    return batched_graph_sets, ys