create_dataset.py
import math
from multiprocessing import Pool, cpu_count
import pandas as pd
from pathlib import Path
import pickle
from backend.utils.molecule_formats import *  # provides smiles_to_fingerprint and smiles_to_mol_graph
import numpy as np
"""
current dataset version, raise when altering dataset content
keep in sync with storage_handler dataset_version
"""
_version = 5
def smiles_to_fingerprints(smiles, sizes, radius=2):
    result_dict = dict()
    for size in sizes:
        result_dict[str(size)] = smiles_to_fingerprint(smiles, size, radius)
        if not result_dict[str(size)]:
            return None
    return result_dict
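# Example of the returned shape (the fingerprint values themselves come from
# backend.utils.molecule_formats.smiles_to_fingerprint; 'CCO' is just an illustrative SMILES):
#   smiles_to_fingerprints('CCO', [128, 512]) -> {'128': <fingerprint>, '512': <fingerprint>}
# Returns None as soon as any requested size fails to produce a fingerprint.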
def smiles_list_to_fingerprint_input(smiles_list, sizes, radius=2):
    result = list()
    for smiles in smiles_list:
        result.append((smiles, sizes, radius))
    return result
def make_data_label_pairs(data, label):
    result_list = list()
    for x in data:
        result_list.append({str(label): x})
    return result_list
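# Example: make_data_label_pairs([0.0552, 0.0532], 'homo') -> [{'homo': 0.0552}, {'homo': 0.0532}]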
def create_dataset(path: str,
                   max_size: int,
                   data_offset: int,
                   labels: list,
                   smiles_fingerprint_sizes: list,
                   smiles_fingerprint_radius: int):
    """
    Creates a new dataset from a given csv file path, with a given maximum size, starting at a given offset, with
    specific labels and fingerprint sizes.
    :param path: Path to the .csv file, relative to this file
    :param max_size: How many entries the dataset should have at most (it may end up with fewer)
    :param data_offset: At what point to start taking data from the .csv file (ex: 5 means starting at the 6th entry)
    :param labels: List of strings of labels included in the dataset
    :param smiles_fingerprint_sizes: List of integers (usually powers of 2)
    :param smiles_fingerprint_radius: numeric value, usually left at 2
    :return: A list of dictionaries containing inputs and outputs. Pickle & save it, or add a descriptor
    """
    csv = pd.read_csv(path)
    data = csv[min(data_offset, len(csv)):min(max_size + data_offset, len(csv))]
    data_smiles = data["SMILES"].tolist()
    print(f'creating set with {len(data)} entries, starting at entry {data_offset}')
    print(f'loaded set with labels: {labels} of {list(csv.columns)}')
    if not labels or not set(labels).issubset(set(csv.columns)):
        # raise a ValueError (raising a bare string is invalid in Python 3) so that
        # update_dataset() can catch sets that cannot be rebuilt automatically
        raise ValueError('Error with label selection')
    num_workers = max(cpu_count() - 2, 1)
    print(f'using {num_workers} worker processes')
    with Pool(num_workers) as p:
        fingerprints_input = smiles_list_to_fingerprint_input(data_smiles, smiles_fingerprint_sizes,
                                                              smiles_fingerprint_radius)
        fingerprints = p.starmap(smiles_to_fingerprints, fingerprints_input)
        data_mol_graphs = p.map(smiles_to_mol_graph, data_smiles)
    print('converted smiles to input')
    label_result_list = []
    for label in labels:
        label_result_list.append(make_data_label_pairs(data=data[label].tolist(), label=label))
    # creates one list of {label: value} dictionaries per label
    print('created output-label-pairs')
    y_list = []
    for pairing in zip(*label_result_list):
        new_pairing = dict()
        for entry in pairing:
            new_pairing |= entry  # dict union (Python 3.9+): merge the per-label dicts of one entry
        y_list.append(new_pairing)
    # merges into one dictionary per entry, ex [{'homo': 0.0552, 'lumo': 15.2}, ...]
    print('created output dictionary')
    # creates a dictionary combining all different input types for each entry
    # current input types are SMILES fingerprints & mol_graphs (v3)
    x_list = []
    for fingerprint, mol_graph in zip(fingerprints, data_mol_graphs):
        if mol_graph[0] is not None and fingerprint:
            x_list.append({'fingerprints': fingerprint, 'mol_graph': mol_graph})
        else:
            x_list.append(None)
    print('created input dictionary')
    # pairs each entry of our input list with an entry of our output list
    data_zip = zip(x_list, y_list)
    dataset = list()
    for input_dict, output_dict in data_zip:
        if input_dict:
            dataset.append({'x': input_dict, 'y': output_dict})
    print('done')
    return dataset
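# Rough sketch of one resulting dataset entry (values are illustrative, not real output):
#   {'x': {'fingerprints': {'128': <fingerprint>, '512': <fingerprint>}, 'mol_graph': <graph>},
#    'y': {'Solubility': -0.77}}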
def add_dataset_descriptor(dataset, name, histograms, parameters):
    """
    Takes a dataset and wraps it in a descriptor dictionary containing various metadata fields
    :param dataset: The dataset. Presumably created using 'create_dataset'
    :param name: The name of the dataset
    :param histograms: histograms of the dataset, per label
    :param parameters: parameters that were used to create the dataset
    :return: A dictionary with the fields 'name', 'size', 'labels', 'dataset', 'version', 'histograms', 'parameters'
    """
    size = len(dataset)
    labels = list(dataset[0].get('y').keys())
    print(f'adding descriptor with size {size}, labels {labels}')
    return {'name': name, 'size': size, 'labels': labels, 'dataset': dataset,
            'version': _version, 'histograms': histograms, 'parameters': parameters}
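# Example descriptor shape (field values elided):
#   {'name': 'Medium Solubility Set', 'size': 100, 'labels': ['Solubility'],
#    'dataset': [...], 'version': 5, 'histograms': {...}, 'parameters': [...]}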
def create_complete_dataset(path, max_size, data_offset, smiles_fingerprint_sizes, smiles_fingerprint_radius, labels,
                            name):
    """
    Creates dataset, histograms and descriptor according to the given parameters
    :param path: path string to the csv file
    :param max_size: int, maximum size of the dataset
    :param data_offset: offset from which to start taking data from the .csv file
    :param smiles_fingerprint_sizes: List of integers (usually powers of 2)
    :param smiles_fingerprint_radius: numeric value, usually left at 2
    :param labels: List of strings of labels included in the dataset
    :param name: string name of the dataset
    :return: the dataset descriptor
    """
    raw_dataset = create_dataset(path,
                                 max_size,
                                 data_offset,
                                 labels,
                                 smiles_fingerprint_sizes,
                                 smiles_fingerprint_radius)
    histograms = create_histograms(raw_dataset, labels)
    # the parameter order must match this function's signature so that
    # update_dataset can replay it via create_complete_dataset(*parameters)
    return add_dataset_descriptor(raw_dataset,
                                  name,
                                  histograms,
                                  [path, max_size, data_offset, smiles_fingerprint_sizes, smiles_fingerprint_radius,
                                   labels, name])
def update_dataset(path):
    """
    If necessary, updates the referenced dataset to the current version by creating it anew with
    create_complete_dataset, and writes the changes back to the pkl file
    :param path: string path to the dataset pickle (pkl) file
    :return: the latest version of the dataset, or None if it could not be upgraded
    """
    path = (Path.cwd() / path)
    with path.open('rb') as old_set_file_read:
        old_set = pickle.load(old_set_file_read)
    if old_set.get('version') == _version:
        return old_set
    try:
        new_set = create_complete_dataset(*old_set.get('parameters'))
        with path.open('wb') as old_set_file_write:
            pickle.dump(new_set, old_set_file_write)
        return new_set  # return the upgraded set, not the stale one
    except ValueError:
        print('Dataset too old to automatically upgrade')
        return None
def create_histograms(dataset, labels):
    """
    Creates a histogram of the dataset for each label.
    Its granularity can be customized according to the required degree of detail
    :param dataset: dataset to create histograms of
    :param labels: labels for which to create histograms
    :return: dictionary containing a histogram (dictionary with lists for buckets and interval edges) for each label
    """
    histograms = dict()
    # prep: separate data by label
    columns_by_label = dict()
    for label in labels:
        columns_by_label[label] = list()
    for entry in dataset:
        for label, value in entry['y'].items():
            columns_by_label[label].append(value)
    # create a histogram for each label
    for label, data in columns_by_label.items():
        # the bucket count is currently an arbitrary value; change it according to the degree of detail required
        # (at least 1 bucket, so datasets with fewer than 100 entries do not crash np.histogram)
        hist, bin_edges = np.histogram(data, max(math.floor(len(data) / 100), 1))
        histograms[label] = dict({
            'buckets': hist.tolist(),
            'bin_edges': bin_edges.tolist()
        })
    return histograms
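# Sketch of the returned structure for a single label (numbers are illustrative):
#   {'Solubility': {'buckets': [3, 10, 5, ...], 'bin_edges': [-5.0, -4.2, -3.4, ...]}}
# 'buckets' holds the count per interval; 'bin_edges' has one more entry than 'buckets'.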
# HOW TO USE:
# Look at the examples below
if __name__ == '__main__':
    '''
    Example for creating a new dataset:
    new_set = create_complete_dataset(path="../../storage/csv_data/solubility.csv",
                                      max_size=100,
                                      data_offset=0,
                                      smiles_fingerprint_sizes=[128, 512, 1024],
                                      smiles_fingerprint_radius=2,
                                      labels=['Solubility'],
                                      name='Medium Solubility Set')
    file = (Path.cwd() / 'output.pkl').open('wb')
    pickle.dump(new_set, file)
    file.close()

    Example for updating a dataset:
    updated_set = update_dataset('../../storage/data/solubility.pkl')
    '''