import bisect
import copy
import functools
import glob
import logging
import os
import warnings
from abc import ABC, abstractmethod
from pathlib import Path

import ase
import ase.db
import ase.io
import numpy as np
from torch import tensor
from torch.utils.data import Dataset
from tqdm import tqdm

from ocpmodels.common.registry import registry
from ocpmodels.datasets.lmdb_database import LMDBDatabase
from ocpmodels.datasets.target_metadata_guesser import guess_property_metadata
from ocpmodels.preprocessing import AtomsToGraphs


def apply_one_tags(atoms, skip_if_nonzero=True, skip_always=False):
    """
    Apply tags of 1 to an ASE atoms object.

    This function is used as an atoms_transform in the datasets contained in this file.
    Certain models treat atoms differently depending on their tags. For example,
    GemNet-OC by default only computes triplet and quadruplet interactions for atoms
    with non-zero tags, and it throws an error if there are no tagged atoms.
    For this reason, the default behavior is to tag all atoms in structures that have
    no tags.

    args:
        atoms (ase.Atoms): The atoms object to tag
        skip_if_nonzero (bool): If at least one atom has a nonzero tag, do not tag any atoms
        skip_always (bool): Do not apply any tags. This arg exists so that this function
            can be disabled without needing to pass a callable (which is currently
            difficult to do with main.py)
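
    Example (an illustrative doctest-style sketch):

        >>> from ase.build import molecule
        >>> atoms = molecule("H2O")
        >>> apply_one_tags(atoms).get_tags()
        array([1, 1, 1])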
"""
    if skip_always:
        return atoms

    if np.all(atoms.get_tags() == 0) or not skip_if_nonzero:
        atoms.set_tags(np.ones(len(atoms)))

    return atoms


class AseAtomsDataset(Dataset, ABC):
    """
    This is an abstract Dataset that includes helpful utilities for turning
    ASE atoms objects into OCP-usable data objects. It should not be instantiated
    directly, as get_atoms_object and load_dataset_get_ids are not implemented in
    this base class.

    Derived classes must implement at least two methods:
        self.get_atoms_object(identifier): Takes an identifier and returns the
            corresponding atoms object.
        self.load_dataset_get_ids(config: dict): Performs any initialization/loading
            of the dataset and, importantly, returns a list of all identifiers that
            can be passed to self.get_atoms_object(identifier).
    Identifiers need not be of any particular type.
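
    A minimal sketch of a derived class (the names here are illustrative, not part
    of this codebase):

        class InMemoryAseDataset(AseAtomsDataset):
            def load_dataset_get_ids(self, config):
                # e.g. load a list of ase.Atoms from config["src"]
                self.atoms_list = my_load_function(config["src"])
                return list(range(len(self.atoms_list)))

            def get_atoms_object(self, identifier):
                return self.atoms_list[identifier]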
"""

    def __init__(self, config, transform=None, atoms_transform=apply_one_tags):
        self.config = config

        a2g_args = config.get("a2g_args", {})
        # Make sure we always include PBC info in the resulting atoms objects
        a2g_args["r_pbc"] = True
        self.a2g = AtomsToGraphs(**a2g_args)

        self.transform = transform
        self.atoms_transform = atoms_transform

        if self.config.get("keep_in_memory", False):
            # Note: this caches explicit self.__getitem__(idx) calls. Indexing with
            # dataset[idx] looks __getitem__ up on the type, so it bypasses this
            # instance-level cache.
            self.__getitem__ = functools.cache(self.__getitem__)

        # self.ids is populated by the derived class's load_dataset_get_ids(),
        # which must return the list of identifiers accepted by get_atoms_object()
        self.ids = self.load_dataset_get_ids(config)

    def __len__(self):
return len(self.ids)

    def __getitem__(self, idx):
        # Handle slicing
        if isinstance(idx, slice):
            return [self[i] for i in range(*idx.indices(len(self.ids)))]

        # Get atoms object via derived class method
        atoms = self.get_atoms_object(self.ids[idx])

        # Transform atoms object
        if self.atoms_transform is not None:
            atoms = self.atoms_transform(
                atoms, **self.config.get("atoms_transform_args", {})
            )

        if "sid" in atoms.info:
            sid = atoms.info["sid"]
        else:
            sid = tensor([idx])

        # Convert to data object
        data_object = self.a2g.convert(atoms, sid)

        data_object.pbc = tensor(atoms.pbc)

        # Transform data object
        if self.transform is not None:
            data_object = self.transform(
                data_object, **self.config.get("transform_args", {})
            )

        return data_object

    @abstractmethod
    def get_atoms_object(self, identifier):
        # This function should return an ASE atoms object.
        raise NotImplementedError(
            "Returns an ASE atoms object. Derived classes should implement this function."
        )

    @abstractmethod
    def load_dataset_get_ids(self, config):
        # This function should return a list of ids that can be used to index into the database
        raise NotImplementedError(
            "Every ASE dataset needs to declare a function to load the dataset and return a list of ids."
        )

    def close_db(self):
        # This method is sometimes called by a trainer
        pass

    def guess_target_metadata(self, num_samples=100):
        metadata = {}

        if num_samples < len(self):
            metadata["targets"] = guess_property_metadata(
                [
                    self.get_atoms_object(self.ids[idx])
                    for idx in np.random.choice(
                        len(self), size=(num_samples,), replace=False
                    )
                ]
            )
        else:
            metadata["targets"] = guess_property_metadata(
                [self.get_atoms_object(self.ids[idx]) for idx in range(len(self))]
            )

        return metadata

    def get_metadata(self):
return self.guess_target_metadata()


@registry.register_dataset("ase_read")
class AseReadDataset(AseAtomsDataset):
    """
    This Dataset uses ase.io.read to load data from a directory on disk.
    This is intended for small-scale testing and demonstrations of OCP.
    Larger datasets are better served by the efficiency of other dataset types
    such as LMDB.

    For a full list of ASE-readable filetypes, see
    https://wiki.fysik.dtu.dk/ase/ase/io/io.html

    args:
        config (dict):
            src (str): The source folder that contains your ASE-readable files

            pattern (str): Filepath matching each file you want to read
                    ex. "*/POSCAR", "*.cif", "*.xyz"
                    search recursively with two wildcards: "**/POSCAR" or "**/*.cif"

            a2g_args (dict): Keyword arguments for ocpmodels.preprocessing.AtomsToGraphs();
                    the default options will work for most users.
                    If you are using this for a training dataset, set
                    "r_energy": True and/or "r_forces": True as appropriate.
                    In that case, energy/forces must be in the files you read (ex. OUTCAR).

            ase_read_args (dict): Keyword arguments for ase.io.read()

            keep_in_memory (bool): Store data in memory. This helps avoid random reads if you need
                    to iterate over a dataset many times (e.g. training for many epochs).
                    Not recommended for large datasets.

            atoms_transform_args (dict): Additional keyword arguments for the atoms_transform callable

            transform_args (dict): Additional keyword arguments for the transform callable

        atoms_transform (callable, optional): Additional preprocessing function applied to the Atoms
            object. Useful for applying tags, for example.

        transform (callable, optional): Additional preprocessing function for the Data object
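
    Example config (a minimal sketch; the path and pattern are illustrative):

        config = {
            "src": "/path/to/structures",
            "pattern": "**/*.cif",
        }
        dataset = AseReadDataset(config)
        data_object = dataset[0]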
"""

    def load_dataset_get_ids(self, config):
        self.ase_read_args = config.get("ase_read_args", {})

        if ":" in str(self.ase_read_args.get("index", "")):
            raise NotImplementedError(
                "To read multiple structures from a single file, please use AseReadMultiStructureDataset."
            )

        self.path = Path(config["src"])
        if self.path.is_file():
            raise ValueError("The specified src is not a directory")

        return list(self.path.glob(config["pattern"]))

    def get_atoms_object(self, identifier):
        try:
            atoms = ase.io.read(identifier, **self.ase_read_args)
        except Exception as err:
            warnings.warn(f"{err} occurred for: {identifier}")
            raise

        return atoms


@registry.register_dataset("ase_read_multi")
class AseReadMultiStructureDataset(AseAtomsDataset):
    """
    This Dataset can read multiple structures from each file using ase.io.read.
    The disadvantage is that all files must be read at startup, which is a
    significant cost for large datasets.

    This is intended for small-scale testing and demonstrations of OCP.
    Larger datasets are better served by the efficiency of other dataset types
    such as LMDB.

    For a full list of ASE-readable filetypes, see
    https://wiki.fysik.dtu.dk/ase/ase/io/io.html

    args:
        config (dict):
            src (str): The source folder that contains your ASE-readable files

            pattern (str): Filepath matching each file you want to read
                    ex. "*.traj", "*.xyz"
                    search recursively with two wildcards: "**/POSCAR" or "**/*.cif"

            index_file (str): Filepath to an indexing file, which contains each filename
                    and the number of structures contained in each file. For instance:

                    /path/to/relaxation1.traj 200
                    /path/to/relaxation2.traj 150

                    This will overrule the src and pattern that you specify!

            a2g_args (dict): Keyword arguments for ocpmodels.preprocessing.AtomsToGraphs();
                    the default options will work for most users.
                    If you are using this for a training dataset, set
                    "r_energy": True and/or "r_forces": True as appropriate.
                    In that case, energy/forces must be in the files you read (ex. OUTCAR).

            ase_read_args (dict): Keyword arguments for ase.io.read()

            keep_in_memory (bool): Store data in memory. This helps avoid random reads if you need
                    to iterate over a dataset many times (e.g. training for many epochs).
                    Not recommended for large datasets.

            use_tqdm (bool): Use TQDM progress bar when initializing dataset

            atoms_transform_args (dict): Additional keyword arguments for the atoms_transform callable

            transform_args (dict): Additional keyword arguments for the transform callable

        atoms_transform (callable, optional): Additional preprocessing function applied to the Atoms
            object. Useful for applying tags, for example.

        transform (callable, optional): Additional preprocessing function for the Data object
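
    Example config (a minimal sketch; paths are illustrative):

        config = {
            "src": "/path/to/trajectories",
            "pattern": "*.traj",
            "ase_read_args": {"index": ":"},
        }
        dataset = AseReadMultiStructureDataset(config)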
"""

    def load_dataset_get_ids(self, config):
        self.ase_read_args = config.get("ase_read_args", {})
        # ase_read_args is a dict, so check for the key rather than using hasattr()
        if "index" not in self.ase_read_args:
            self.ase_read_args["index"] = ":"

        if config.get("index_file", None) is not None:
            with open(config["index_file"], "r") as f:
                index = f.readlines()

            ids = []
            for line in index:
                filename = line.split(" ")[0]
                for i in range(int(line.split(" ")[1])):
                    ids.append(f"{filename} {i}")

            return ids

        self.path = Path(config["src"])
        if self.path.is_file():
            raise ValueError("The specified src is not a directory")
        filenames = list(self.path.glob(config["pattern"]))

        ids = []

        if config.get("use_tqdm", True):
            filenames = tqdm(filenames)

        for filename in filenames:
            try:
                structures = ase.io.read(filename, **self.ase_read_args)
            except Exception as err:
                warnings.warn(f"{err} occurred for: {filename}")
            else:
                for i in range(len(structures)):
                    ids.append(f"{filename} {i}")

        return ids

    def get_atoms_object(self, identifier):
        try:
            # The identifier is "<filename> <index>"; rejoin with spaces so that
            # filenames containing spaces are reconstructed correctly.
            atoms = ase.io.read(
                " ".join(identifier.split(" ")[:-1]), **self.ase_read_args
            )[int(identifier.split(" ")[-1])]
        except Exception as err:
            warnings.warn(f"{err} occurred for: {identifier}")
            raise

        return atoms

    def get_metadata(self):
return {}


class dummy_list(list):
    """
    A lightweight stand-in for a list of sequential integer ids [0, max).
    AseDBDataset uses it to expose an id list without materializing one
    entry per database row, which matters for large databases.
    """

    def __init__(self, max) -> None:
        self.max = max

    def __len__(self):
        return self.max

    def __getitem__(self, idx):
        # Handle slicing
        if isinstance(idx, slice):
            return [self[i] for i in range(*idx.indices(self.max))]

        # Cast idx as int since it could be a tensor index
        idx = int(idx)

        # Handle negative indices (referenced from end)
        if idx < 0:
            idx += self.max

        if 0 <= idx < self.max:
            return idx
        else:
            raise IndexError


@registry.register_dataset("ase_db")
class AseDBDataset(AseAtomsDataset):
    """
    This Dataset connects to an ASE Database, allowing the storage of atoms objects
    with a variety of backends including JSON, SQLite, and database server options.

    For more information, see:
    https://databases.fysik.dtu.dk/ase/ase/db/db.html

    args:
        config (dict):
            src (str): Either
                - the path to an ASE DB,
                - the connection address of an ASE DB,
                - a folder with multiple ASE DBs,
                - a glob string to use to find ASE DBs, or
                - a list of ASE db paths/addresses.
                If a folder, every file will be attempted as an ASE DB, and warnings
                are raised for any files that can't connect cleanly.

                Note that for large datasets, ID loading can be slow and there can be many
                ids, so it's advised to make loading the id list as easy as possible. There is not
                an obvious way to get a full list of ids from most ASE dbs besides simply looping
                through the entire dataset. See the AseLMDBDataset, which was written with this
                use case in mind.

            connect_args (dict): Keyword arguments for ase.db.connect()

            select_args (dict): Keyword arguments for ase.db.select()
                You can use this to query/filter your database

            a2g_args (dict): Keyword arguments for ocpmodels.preprocessing.AtomsToGraphs();
                the default options will work for most users.
                If you are using this for a training dataset, set
                "r_energy": True and/or "r_forces": True as appropriate.
                In that case, energy/forces must be in the database.

            keep_in_memory (bool): Store data in memory. This helps avoid random reads if you need
                to iterate over a dataset many times (e.g. training for many epochs).
                Not recommended for large datasets.

            atoms_transform_args (dict): Additional keyword arguments for the atoms_transform callable

            transform_args (dict): Additional keyword arguments for the transform callable

        atoms_transform (callable, optional): Additional preprocessing function applied to the Atoms
            object. Useful for applying tags, for example.

        transform (callable, optional): Additional preprocessing function for the Data object
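
    Example config (a minimal sketch; the path and query string are illustrative):

        config = {
            "src": "/path/to/dataset.db",
            "select_args": {"selection": "natoms<=20"},
        }
        dataset = AseDBDataset(config)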
"""

    def load_dataset_get_ids(self, config):
        if isinstance(config["src"], list):
            filepaths = config["src"]
        elif os.path.isfile(config["src"]):
            filepaths = [config["src"]]
        elif os.path.isdir(config["src"]):
            filepaths = glob.glob(f'{config["src"]}/*')
        else:
            filepaths = glob.glob(config["src"])

        self.dbs = []

        for path in filepaths:
            try:
                self.dbs.append(self.connect_db(path, config.get("connect_args", {})))
            except ValueError:
                logging.warning(
                    f"Tried to connect to {path} but it's not an ASE database!"
                )

        self.select_args = config.get("select_args", {})

        # In order to get all of the unique IDs using the default ASE db interface
        # we have to load all the data and check ids using a select. This is extremely
        # inefficient for large datasets. If the db we're using already presents a list
        # of ids and there is no query, we can just use that list instead and save
        # ourselves a lot of time!
        self.db_ids = []
        for db in self.dbs:
            if hasattr(db, "ids") and self.select_args == {}:
                self.db_ids.append(db.ids)
            else:
                self.db_ids.append([row.id for row in db.select(**self.select_args)])

        idlens = [len(ids) for ids in self.db_ids]
        self._idlen_cumulative = np.cumsum(idlens).tolist()

        return dummy_list(sum(idlens))

    def get_atoms_object(self, idx):
        # Figure out which db this should be indexed from.
        db_idx = bisect.bisect(self._idlen_cumulative, idx)
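        # Worked example (hypothetical sizes): with per-db id counts [100, 50],
        # self._idlen_cumulative == [100, 150], so idx=120 falls in db_idx=1 and
        # maps to el_idx = 120 - 100 = 20 within that db.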

        # Extract index of element within that db
        el_idx = idx
        if db_idx != 0:
            el_idx = idx - self._idlen_cumulative[db_idx - 1]
        assert el_idx >= 0

        atoms_row = self.dbs[db_idx]._get_row(self.db_ids[db_idx][el_idx])
        atoms = atoms_row.toatoms()

        if isinstance(atoms_row.data, dict):
            atoms.info.update(atoms_row.data)

        return atoms

    def connect_db(self, address, connect_args=None):
        # Avoid a mutable default argument
        if connect_args is None:
            connect_args = {}

        db_type = connect_args.get("type", "extract_from_name")
        if db_type == "lmdb" or (
            db_type == "extract_from_name" and address.split(".")[-1] == "lmdb"
        ):
            return LMDBDatabase(address, readonly=True, **connect_args)
        else:
            return ase.db.connect(address, **connect_args)

    def close_db(self):
for db in self.dbs:
if hasattr(db, "close"):
db.close()

    def get_metadata(self):
        logging.warning(
            "You specified a folder of ASE dbs, so it's impossible to know which metadata to use. Using the first!"
        )
        if self.dbs[0].metadata == {}:
            return self.guess_target_metadata()
        else:
            return copy.deepcopy(self.dbs[0].metadata)