Skip to content

Commit

Permalink
Added write_h5 function (#70)
Browse files Browse the repository at this point in the history
Facilitates dumping of database to h5 format. See documentation string
for more details.
  • Loading branch information
bnebgen-LANL authored Apr 30, 2024
1 parent cefc093 commit 8d5fd8a
Showing 1 changed file with 77 additions and 2 deletions.
79 changes: 77 additions & 2 deletions hippynn/databases/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import warnings
import numpy as np
import torch
import importlib.util
from pathlib import Path

from .restarter import NoRestart
from ..tools import arrdict_len, device_fallback
Expand Down Expand Up @@ -282,8 +284,81 @@ def remove_high_property(self,key,perAtom,species_key=None,cut=None,std_factor=1
#This does nothing with ndim=1
trimArr = np.sum(failArr,axis=tuple(range(1,ndim)))==0
self.trim_all_arrays(trimArr)



if importlib.util.find_spec("pyanitools") is not None:
def write_h5(self,split=None,h5path=None,species_key='species',overwrite=False,return_dictionary=False):
"""
Writes database as ANI-style h5 file
:param split: str or None; selects data split to save. If None, contents of arr_dict are used.
:param species_key: str; the key that designates atomic species. Used for determine number of atoms. Assumed to be [N_structures,N_atom]. Default: 'species'
:param overwrite: boolean; enables over-writing of h5 file.
:param return_dictionary: boolean; return dictionary style database for writing.
:return: dataloader containing relevant data
"""
import pyanitools as pyt
if split in self.splits:
database = self.splits[split]
elif split is None:
database = self.arr_dict
else:
raise Exception(f"Unknown split name: {split:s}")
dataDict = {}
if (h5path is not None) :
if Path(h5path).exists():
if overwrite:
Path(h5path).unlink()
else:
raise Exception(f"h5path {h5path:s} exists.")
print("Saving h5 file: " + h5path)
dpack = pyt.datapacker(h5path)
else:
dpack = None
totalNumber = database[species_key].shape[0]
atomDim = database[species_key].shape[1]
isAtomKey={}
#determine which keys have second element N atoms
for curK in database.keys():
#Lazy if evaluation
if (len(database[curK].shape)>1) and (database[curK].shape[1] == atomDim):
isAtomKey[curK] = True
else:
isAtomKey[curK] = False
del(isAtomKey[species_key])
for sysI,sysV in enumerate(database[species_key]):
# We can append the system data to an existing set of system data
molkey = hash(np.array(sysV).tobytes())
molnAtom = np.count_nonzero(sysV)
if molkey in dataDict:
if (database[species_key][sysI,:molnAtom].shape != dataDict[molkey][species_key].shape) or not((database[species_key][sysI,:molnAtom]==dataDict[molkey][species_key]).all()):
raise Exception("Error. Hash not unique. You should never see this.")
for curK in isAtomKey.keys():
if isAtomKey[curK]:
dataDict[molkey][curK].append(database[curK][sysI,:molnAtom])
else:
dataDict[molkey][curK].append(database[curK][sysI])
else:
dataDict[molkey] = {}
for curK in isAtomKey.keys():
if isAtomKey[curK]:
dataDict[molkey][curK] = [database[curK][sysI,:molnAtom]]
else:
dataDict[molkey][curK] = [database[curK][sysI]]
dataDict[molkey][species_key] = database[species_key][sysI,:molnAtom]
for sysV in dataDict.keys():
for curK in isAtomKey.keys():
dataDict[sysV][curK] = np.array(dataDict[sysV][curK])
if np.issubdtype(dataDict[sysV][curK].dtype,np.unicode_):
dataDict[sysV][curK] = [el.encode('utf-8') for el in list(dataDict[sysV][curK])]
dataDict[sysV][curK] = np.array(dataDict[sysV][curK])

if dpack is not None:
for key in dataDict:
dpack.store_data(str(key),**dataDict[key])
dpack.cleanup()
if (h5path is None) or return_dictionary:
return(dataDict)

def compute_index_mask(indices, index_pool):
if not np.all(np.isin(indices, index_pool)):
Expand Down

0 comments on commit 8d5fd8a

Please sign in to comment.