Add database features (#86)
* add database features

* add evaluation of extra splits to test_model routine

* add test for dataset caching
lubbersnick authored Aug 14, 2024
1 parent cf16e9f commit 9a1a662
Showing 10 changed files with 975 additions and 196 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.rst
@@ -7,15 +7,20 @@ New Features:
-------------

- Added a new custom CUDA kernel implementation using Triton. These kernels are highly performant and are now the default implementation.
- Exporting a database to NPZ or H5 format after preprocessing is now just a function call away (see the sketch after this list).
- The SNAPJson format now supports a configurable number of comment lines (`n_comments`).
- Added batch optimizer features in order to optimize geometries in parallel on the GPU; algorithms include FIRE and BFGS.
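
  A hedged sketch of the export feature above. The method names `write_npz` and `write_h5` and their arguments are assumptions based on the changelog entry, not confirmed by this diff; `DirectoryDatabase` is one of hippynn's existing database classes.

      # Hypothetical sketch: export a preprocessed database to NPZ or H5.
      from hippynn.databases import DirectoryDatabase

      db = DirectoryDatabase(name="data-", directory="./my_dataset/",
                             inputs=["Z", "R"], targets=["T"], seed=0)

      # Assumed method names for the new export feature:
      db.write_npz("exported_dataset")      # NumPy .npz archive
      db.write_h5("exported_dataset.h5")    # HDF5 file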

Improvements:
-------------

- Eliminated dependency on pyanitools for loading ANI-style H5 datasets.

Bug Fixes:
----------

- Fixed a bug where custom kernels were not launching properly on non-default GPUs.

0.0.3
=======

7 changes: 7 additions & 0 deletions docs/source/user_guide/databases.rst
@@ -24,6 +24,13 @@ Note that input of bond variables for periodic systems can be ill-defined
if there are multiple bonds between the same pairs of atoms. This is not yet
supported.

A note on *cell* variables. The shape of a cell variable should be specified as (n_systems, 3, 3).
There are two common conventions for the cell matrix itself; we use the convention that the basis index
comes first and the cartesian index comes second. That is, as in `ase`,
the [i, j] element of the cell gives the j-th cartesian coordinate of cell vector i. If you experience
severe difficulties fitting to periodic boundary conditions, check the transposed version
of your cell data, or compute the radial distribution function (RDF) to verify that your cells are
being interpreted correctly.
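
A concrete sketch of this convention (assuming `ase` and `numpy` are installed; the systems here are illustrative):

    import numpy as np
    from ase.build import bulk

    # Each periodic system contributes one 3x3 cell matrix.
    systems = [bulk("Cu", "fcc", a=3.6), bulk("Al", "fcc", a=4.05)]

    # ase uses the row-vector convention described above:
    # cell[i, j] is the j-th cartesian coordinate of cell vector i.
    cells = np.stack([np.asarray(s.get_cell()) for s in systems])
    print(cells.shape)  # (n_systems, 3, 3) -> here (2, 3, 3)

    # If fitting periodic data goes badly, try the transposed convention:
    cells_transposed = cells.transpose(0, 2, 1)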


ASE Objects Database handling
----------------------------------------------------------
7 changes: 4 additions & 3 deletions hippynn/databases/SNAPJson.py
@@ -25,7 +25,7 @@ def __init__(
         transpose_cell=True,
         allow_unfound=False,
         quiet=False,
-        comments=1,
+        n_comments=1,
         **kwargs,
     ):

@@ -35,7 +35,7 @@ def __init__(
         self.targets = targets
         self.transpose_cell = transpose_cell
         self.depth = depth
-        self.comments = comments
+        self.n_comments = n_comments
         arr_dict = self.load_arrays(quiet=quiet, allow_unfound=allow_unfound)
 
         super().__init__(arr_dict, inputs, targets, *args, **kwargs, allow_unfound=allow_unfound, quiet=quiet)
@@ -48,6 +48,7 @@ def __init__(
             transpose_cell=transpose_cell,
             files=files,
             allow_unfound=allow_unfound,
+            n_comments=n_comments,
             **kwargs,
             quiet=quiet,
         )
@@ -98,7 +99,7 @@ def filter_arrays(self, arr_dict, allow_unfound=False, quiet=False):
 
     def extract_snap_file(self, file):
         with open(file, "rt") as jf:
-            for i in range(self.comments):
+            for i in range(self.n_comments):
                 comment = jf.readline()
             content = jf.read()
             parsed = json.loads(content)
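
A hypothetical usage sketch for the renamed parameter. The class name `SNAPDirectoryDatabase` comes from this module, but the remaining constructor arguments shown are assumptions for illustration:

    from hippynn.databases.SNAPJson import SNAPDirectoryDatabase

    db = SNAPDirectoryDatabase(
        directory="./snap_data/",                      # hypothetical data location
        inputs=["Positions", "Species", "Lattice"],    # illustrative column names
        targets=["Energy", "Forces"],
        n_comments=2,          # renamed from comments=; skip two header lines per JSON file
        transpose_cell=True,
        seed=0,
    )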
142 changes: 142 additions & 0 deletions hippynn/databases/_ani_reader.py
@@ -0,0 +1,142 @@
"""
Based on pyanitools.py written by Roman Zubatyuk and Justin S. Smith:
https://github.com/atomistic-ml/ani-al/blob/master/readers/lib/pyanitools.py
"""

import os
import numpy as np
import h5py


class DataPacker:
    def __init__(self, store_file, mode='w-', compression_lib='gzip', compression_level=6, driver=None):
        """
        Wrapper to store arrays within an HDF5 file.
        """
        self.store = h5py.File(store_file, mode=mode, driver=driver)
        self.compression = compression_lib
        self.compression_opts = compression_level

    def store_data(self, store_location, **kwargs):
        """
        Put arrays to store.
        """
        group = self.store.create_group(store_location)

        for name, data in kwargs.items():
            if isinstance(data, list):
                if len(data) != 0:
                    # h5py cannot store unicode string lists directly; encode them to bytes.
                    if type(data[0]) is np.str_ or type(data[0]) is str:
                        data = [a.encode('utf8') for a in data]

            group.create_dataset(name, data=data, compression=self.compression, compression_opts=self.compression_opts)

    def cleanup(self):
        """
        Wrapper to close the HDF5 file.
        """
        self.store.close()

    def __del__(self):
        # Guard with getattr: if __init__ raised before self.store was set,
        # __del__ would otherwise fail with an AttributeError.
        if getattr(self, 'store', None) is not None:
            self.cleanup()


class AniDataLoader(object):
    def __init__(self, store_file, driver=None):
        """
        Open an HDF5 store file for reading.
        """
        if not os.path.exists(store_file):
            self.store = None
            raise FileNotFoundError(f'File not found: {os.path.realpath(store_file)}')
        self.store = h5py.File(store_file, 'r', driver=driver)

    def h5py_dataset_iterator(self, g, prefix=''):
        """
        Group recursive iterator (iterate through all groups in all branches and return datasets in dicts).
        """
        for key, item in g.items():
            path = f'{prefix}/{key}'

            # Peek at the first child to decide whether this group holds datasets directly.
            first_subkey = next(iter(item.keys()))
            first_subitem = item[first_subkey]

            if isinstance(first_subitem, h5py.Dataset):
                # If dataset, yield the data from it.
                data = self.populate_data_dict({'path': path}, item)
                yield data
            else:
                # Otherwise assume it's a group and iterate from that.
                yield from self.h5py_dataset_iterator(item, path)

    def __iter__(self):
        """
        Default class iterator (iterate through all data).
        """
        for data in self.h5py_dataset_iterator(self.store):
            yield data

    def get_group_list(self):
        """
        Returns a list of the top-level groups in the file.
        """
        return list(self.store.values())

    def iter_group(self, g):
        """
        Allows iteration through the data in a given group.
        """
        for data in self.h5py_dataset_iterator(g):
            yield data

    def get_data(self, path, prefix=''):
        """
        Returns the requested dataset.
        """
        item = self.store[path]
        data = self.populate_data_dict({'path': f'{prefix}/{path}'}, item)
        return data

    @staticmethod
    def populate_data_dict(data, group):
        for key, value in group.items():
            if not isinstance(value, h5py.Group):
                dataset = np.asarray(value[()])

                # Decode bytes objects to ascii strings.
                if isinstance(dataset, np.ndarray):
                    if dataset.size != 0:
                        if type(dataset[0]) is np.bytes_:
                            dataset = [a.decode('ascii') for a in dataset]

                data.update({key: dataset})

        return data

    def group_size(self):
        """
        Returns the number of top-level groups.
        """
        return len(self.get_group_list())

    def size(self):
        """
        Returns the total number of items across all top-level groups.
        """
        count = 0
        for g in self.store.values():
            count = count + len(g.items())
        return count

    def cleanup(self):
        """
        Close the HDF5 file.
        """
        self.store.close()

    def __del__(self):
        # Guard with getattr in case __init__ raised before self.store was set.
        if getattr(self, 'store', None) is not None:
            self.cleanup()
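
A minimal round-trip sketch of the two classes above; the file name and array contents are illustrative only:

    import numpy as np

    # Pack a small store with DataPacker...
    packer = DataPacker('example_ani.h5')
    packer.store_data('molecule_0001',
                      coordinates=np.zeros((1, 3, 3)),
                      species=['C', 'H', 'H'],
                      energies=np.array([-40.5]))
    packer.cleanup()

    # ...then iterate over every leaf group with AniDataLoader.
    loader = AniDataLoader('example_ani.h5')
    for entry in loader:
        print(entry['path'], sorted(entry.keys()))
    loader.cleanup()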
