diff --git a/CHANGELOG.rst b/CHANGELOG.rst index d2668081..1618d1bd 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -7,15 +7,20 @@ New Features: ------------- - Added a new custom cuda kernel implementation using triton. These are highly performant and now the default implementation. +- Exporting a database to NPZ or H5 format after preprocessing is now just a function call away. +- SNAPjson format can now support an optional number of comment lines. - Added Batch optimizer features in order to optimize geometries in parallel on the GPU. Algorithms include FIRE and BFGS. Improvements: ------------- +- Eliminated dependency on pyanitools for loading ANI-style H5 datasets. Bug Fixes: ---------- +- Fixed bug where custom kernels were not launching properly on non-default GPUs + 0.0.3 ======= diff --git a/docs/source/user_guide/databases.rst b/docs/source/user_guide/databases.rst index 09836397..4448033b 100644 --- a/docs/source/user_guide/databases.rst +++ b/docs/source/user_guide/databases.rst @@ -24,6 +24,13 @@ Note that input of bond variables for periodic systems can be ill-defined if there are multiple bonds between the same pairs of atoms. This is not yet supported. +A note on *cell* variables. The shape of a cell variable should be specified as (n_atoms,3,3). +There are two common conventions for the cell matrix itself; we use the convention that the basis index +comes first, and the cartesian index comes second. That is similar to `ase`, +the [i,j] element of the cell gives the j cartesian coordinate of cell vector i. If you experience +massive difficulties fitting to periodic boundary conditions, you may check the transposed version +of your cell data, or compute the RDF. + ASE Objects Database handling ---------------------------------------------------------- diff --git a/hippynn/databases/SNAPJson.py b/hippynn/databases/SNAPJson.py index c9cb208b..92d17a58 100644 --- a/hippynn/databases/SNAPJson.py +++ b/hippynn/databases/SNAPJson.py @@ -25,7 +25,7 @@ def __init__( transpose_cell=True, allow_unfound=False, quiet=False, - comments=1, + n_comments=1, **kwargs, ): @@ -35,7 +35,7 @@ def __init__( self.targets = targets self.transpose_cell = transpose_cell self.depth = depth - self.comments = comments + self.n_comments = n_comments arr_dict = self.load_arrays(quiet=quiet, allow_unfound=allow_unfound) super().__init__(arr_dict, inputs, targets, *args, **kwargs, allow_unfound=allow_unfound, quiet=quiet) @@ -48,6 +48,7 @@ def __init__( transpose_cell=transpose_cell, files=files, allow_unfound=allow_unfound, + n_comments=n_comments, **kwargs, quiet=quiet, ) @@ -98,7 +99,7 @@ def filter_arrays(self, arr_dict, allow_unfound=False, quiet=False): def extract_snap_file(self, file): with open(file, "rt") as jf: - for i in range(self.comments): + for i in range(self.n_comments): comment = jf.readline() content = jf.read() parsed = json.loads(content) diff --git a/hippynn/databases/_ani_reader.py b/hippynn/databases/_ani_reader.py new file mode 100644 index 00000000..104db0f1 --- /dev/null +++ b/hippynn/databases/_ani_reader.py @@ -0,0 +1,142 @@ +""" +Based on pyanitools.py written by Roman Zubatyuk and Justin S. 
Smith: +https://github.com/atomistic-ml/ani-al/blob/master/readers/lib/pyanitools.py +""" + +import os +import numpy as np +import h5py + + +class DataPacker: + def __init__(self, store_file, mode='w-', compression_lib='gzip', compression_level=6, driver=None): + """ + Wrapper to store arrays within HFD5 file + """ + self.store = h5py.File(store_file, mode=mode, driver=driver) + self.compression = compression_lib + self.compression_opts = compression_level + + def store_data(self, store_location, **kwargs): + """ + Put arrays to store + """ + group = self.store.create_group(store_location) + + for name, data in kwargs.items(): + if isinstance(data, list): + if len(data) != 0: + if type(data[0]) is np.str_ or type(data[0]) is str: + data = [a.encode('utf8') for a in data] + + group.create_dataset(name, data=data, compression=self.compression, compression_opts=self.compression_opts) + + def cleanup(self): + """ + Wrapper to close HDF5 file + """ + self.store.close() + + def __del__(self): + if self.store is not None: + self.cleanup() + + +class AniDataLoader(object): + def __init__(self, store_file, driver=None): + """ + Constructor + """ + if not os.path.exists(store_file): + store_file = os.path.realpath(store_file) + self.store = None + raise FileNotFoundError(f'File not found: {store_file}') + self.store = h5py.File(store_file, driver=driver) + + def h5py_dataset_iterator(self, g, prefix=''): + """ + Group recursive iterator (iterate through all groups in all branches and return datasets in dicts) + """ + + for key, item in g.items(): + + path = f'{prefix}/{key}' + + first_subkey = list(item.keys())[0] + first_subitem = item[first_subkey] + + if isinstance(first_subitem, h5py.Dataset): + # If dataset, yield the data from it. + data = self.populate_data_dict({'path': path}, item) + yield data + else: + # If not a dataset, assume it's a group and iterate from that. + yield from self.h5py_dataset_iterator(item, path) + + def __iter__(self): + """ + Default class iterator (iterate through all data) + """ + for data in self.h5py_dataset_iterator(self.store): + yield data + + def get_group_list(self): + """ + Returns a list of all groups in the file + """ + return [g for g in self.store.values()] + + def iter_group(self, g): + """ + Allows interation through the data in a given group + """ + for data in self.h5py_dataset_iterator(g): + yield data + + def get_data(self, path, prefix=''): + """ + Returns the requested dataset + """ + item = self.store[path] + data = self.populate_data_dict({'path': f'{prefix}/{path}'}, item) + + return data + + @staticmethod + def populate_data_dict(data, group): + for key, value in group.items(): + + if not isinstance(value, h5py.Group): + dataset = np.asarray(value[()]) + + # decode bytes objects to ascii strings. 
+ if isinstance(dataset, np.ndarray): + if dataset.size != 0: + if type(dataset[0]) is np.bytes_: + dataset = [a.decode('ascii') for a in dataset] + + data.update({key: dataset}) + + return data + + def group_size(self): + """ + Returns the number of groups + """ + return len(self.get_group_list()) + + def size(self): + count = 0 + for g in self.store.values(): + count = count + len(g.items()) + return count + + def cleanup(self): + """ + Close the HDF5 file + """ + self.store.close() + + def __del__(self): + if self.store is not None: + self.cleanup() diff --git a/hippynn/databases/database.py b/hippynn/databases/database.py index 3be163e1..19acbb52 100644 --- a/hippynn/databases/database.py +++ b/hippynn/databases/database.py @@ -4,14 +4,14 @@ import warnings import numpy as np import torch -import importlib.util from pathlib import Path from .restarter import NoRestart -from ..tools import arrdict_len, device_fallback +from ..tools import arrdict_len, device_fallback, unsqueeze_multiple from torch.utils.data import DataLoader, TensorDataset, Subset +_AUTO_SPLIT_PREFIX = "split_mask_" class Database: """ @@ -29,19 +29,24 @@ def __init__( num_workers=0, pin_memory=True, allow_unfound=False, + auto_split=False, + device=None, quiet=False, ): """ :param arr_dict: dictionary mapping strings to numpy arrays :param inputs: list of strings for input db_names :param targets: list of strings for output db_namees - :param seed: int, for random splitting + :param seed: int, for random splitting, or "mask" for pre-split. + Can also be existing numpy.random.RandomState. + Can also be tuple from numpy.random.RandomState.get_state() :param test_size: fraction of data to use in test split :param valid_size: fraction of data to use in train split :param num_workers: passed to pytorch dataloaders :param pin_memory: passed to pytorch dataloaders :param allow_unfound: If true, skip checking if the needed inputs and targets are found. This allows setting inputs=None and/or targets=None. + :param auto_split: If true, look for keys like "split_*" to make initial splits from. See write_npz() method. :param quiet: If True, print little or nothing while loading. """ @@ -54,6 +59,7 @@ def __init__( self.splitting_completed = False self.num_workers = num_workers self.pin_memory = pin_memory + self.auto_split = auto_split if not quiet: print(f"All arrays:") @@ -90,14 +96,33 @@ def __init__( print("Database: Using pre-specified data indices.") self.splits = {} - self.random_state = np.random.RandomState(seed=seed) + + if isinstance(seed, np.random.RandomState): + self.random_state = seed + elif isinstance(seed, tuple): + self.random_state = np.random.RandomState() + self.random_state.set_state(seed) + else: + self.random_state = np.random.RandomState(seed=seed) + + if self.auto_split: + if test_size is not None or valid_size is not None: + warnings.warn(f"Auto split was set but test and valid size was also set." 
+ f" Ignoring supplied test and validation sizes ({test_size} and {valid_size}.") + self.make_automatic_splits() if test_size is not None or valid_size is not None: if test_size is None or valid_size is None: - raise ValueError("Both test and valid size must be set for auto-splitting") + raise ValueError("Both test and valid size must be set for auto-splitting based on fractions") else: self.make_trainvalidtest_split(test_size=test_size, valid_size=valid_size) + if device is not None: + if not self.splitting_completed: + raise ValueError("Device cannot be set in constructor unless automatic split provided.") + else: + self.send_to_device(device) + def __len__(self): return arrdict_len(self.arr_dict) @@ -110,7 +135,13 @@ def var_list(self): return self.inputs + self.targets def send_to_device(self, device=None): - "Send a database to a device" + """ + Move the database to an accelerator device if possible. + In some circumstances this can accelerate training. + + :param device: device to move to, if None, try to auto-detect. + :return: + """ if len(self.splits) == 0: raise RuntimeError("Arrays must be split before sending database to device.") if device is None: @@ -167,30 +198,187 @@ def make_trainvalidtest_split(self, test_size, valid_size): self.split_the_rest("train") def make_explicit_split(self, evaluation_mode, split_indices): + """ + + :param evaluation_mode: name for split, typically 'train', 'valid', 'test' + :param split_indices: the indices of the items for the split + :return: + """ if self.splitting_completed: - raise RuntimeError("Database already split!") + raise RuntimeError("Database splitting already complete!") if len(split_indices) == 0: raise ValueError("Cannot make split of size 0.") # Compute which indices are not being split off. index_mask = compute_index_mask(split_indices, self.arr_dict["indices"]) - complement_mask = ~index_mask + # Precompute the actual integer indices, because indexing with a boolean mask + # requires doing this, and we have to index with a boolean several times. + where_index = np.where(index_mask) + where_complement = np.where(complement_mask) + # Split off data, and keep the rest. - self.splits[evaluation_mode] = {k: torch.from_numpy(self.arr_dict[k][index_mask]) for k in self.arr_dict} - self.splits[evaluation_mode]["split_indices"] = torch.arange(len(split_indices), dtype=torch.int64) + self.splits[evaluation_mode] = {k: torch.from_numpy(self.arr_dict[k][where_index]) for k in self.arr_dict} + if "split_indices" not in self.splits[evaluation_mode]: + if not self.quiet: + print(f"Adding split indices for split: {evaluation_mode}") + self.splits[evaluation_mode]["split_indices"] = torch.arange(len(split_indices), dtype=torch.int64) for k, v in self.arr_dict.items(): - self.arr_dict[k] = v[complement_mask] + self.arr_dict[k] = v[where_complement] if not self.quiet: print(f"Arrays for split: {evaluation_mode}") prettyprint_arrays(self.splits[evaluation_mode]) + if arrdict_len(self.arr_dict) == 0: + if not self.quiet: + print("Database: Splitting complete.") + self.splitting_completed = True + return + + def make_explicit_split_bool(self, evaluation_mode, split_mask): + """ + + :param evaluation_mode: name for split, typically 'train', 'valid', 'test' + :param split_mask: a boolean array for where to split + :return: + """ + if split_mask.dtype != np.bool_: + if not np.isin(split_mask, [0, 1]).all(): + raise ValueError(f"Mask function contains invalid values. 
Values found: {np.unique(split_mask)}") + else: + split_mask = split_mask.astype(np.bool_) + + indices = self.arr_dict['indices'][split_mask] + self.make_explicit_split(evaluation_mode, indices) + return + def split_the_rest(self, evaluation_mode): self.make_explicit_split(evaluation_mode, self.arr_dict["indices"]) self.splitting_completed = True + return + + def add_split_masks(self, dict_to_add_to=None, split_prefix=None): + """ + Add split masks to the dataset. This function is used internally before writing databases. + + When using the dict_to_add_to parameter, this function writes numpy arrays. + When adding to self.splits, this function writes tensors. + :param dict_to_add_to: where to put the split masks. Default to self.splits. + :param split_prefix: prefix for mask names + :return: + """ + + if not self.splitting_completed: + raise ValueError("Can't add split masks until splitting is complete.") + + if split_prefix is None: + split_prefix = _AUTO_SPLIT_PREFIX + + if dict_to_add_to is None: + dict_to_add_to = self.splits + write_tensor = True + else: + write_tensor = False + + for s in self.splits.keys(): + mask_name = split_prefix + s + for sprime, split in self.splits.items(): + + if sprime == s: + mask = np.ones_like(split['indices'], dtype=np.bool_) + else: + mask = np.zeros_like(split['indices'], dtype=np.bool_) + + if write_tensor: + mask = torch.as_tensor(mask) + + if mask_name in split: + # Check that the mask is correct and in the np_dict + old_mask = dict_to_add_to[sprime][mask_name] + if (old_mask != mask).all(): + raise ValueError(f"Mask in database did not match existing split structure: {mask_name} ") + else: + # if not present, write it. + dict_to_add_to[sprime][mask_name] = mask + + def make_automatic_splits(self, split_prefix=None, dry_run=False): + """ + Split the database automatically. Since the user specifies this routine, + it fails pretty strictly. + + :param split_prefix: None, use default. + If otherwise, use this prefix to determine what arrays are masks. + :param dry_run: Only validate that existing split masks are correct; don't perform splitting. + :return: + """ + + if split_prefix is None: + split_prefix = _AUTO_SPLIT_PREFIX + if not self.quiet: + print("Attempting automatically splitting.") + # Find mask-like variables + mask_vars = set() + + # Here we validate existing masks. + # We want to make sure that if someone did it manually there was not a mistake. + for k, arr in self.arr_dict.items(): + if k.startswith(split_prefix): + if arr.ndim != 1: + raise ValueError(f"Split mask for '{k}' has too many dimensions. Shape: {arr.shape=}") + if arr.dtype == np.dtype('bool'): + mask_vars.add(k) + elif arr.dtype is np.int and arr.ndim == 1: + if np.isin(arr, [0, 1]).all(): + mask_vars.add(k) + else: + arr_values = np.unique(arr) + raise ValueError(f"Integer masks for split contain invalid values: {arr_values}") + else: + raise ValueError(f"Failed on split {k} Split arrays must be 1-d boolean or (0,1)-valued integer arrays.") + + if not len(mask_vars): + raise ValueError("No split mask detected.") + + masks = {k[len(split_prefix):]: self.arr_dict[k].astype(bool) for k in mask_vars} + + if not self.quiet: + print("Auto-detected splits:", list(masks.keys())) + + # Check masks are all the same length. 
+ lengths = set(x.shape[0] for x in masks.values()) + if len(lengths) == 0: + raise ValueError("No split masks found.") + elif len(lengths) != 1: + raise ValueError(f"Mask arrays must all be the same size, got sizes: {lengths}") + n_sys = list(lengths)[0] + + # Check that masks define a complete split + mask_counts = np.zeros(n_sys, dtype=int) + for k, arr in masks.items(): + mask_counts += arr.astype(int) + if not (mask_counts == 1).all(): + set_of_counts = set(mask_counts) + raise ValueError(f" Auto-splitting requires unique split for each item." + + f" Items with the following split counts were detected: {set_of_counts}") + + if dry_run: + return + + masks = {k: self.arr_dict['indices'][m] for k, m in masks.items()} + for k, m in masks.items(): + self.make_explicit_split(k, m) + + if not self.quiet: + print("Finished automatic splitting.") + + assert arrdict_len(self.arr_dict) == 0, "Not all items were successfully auto-split." + + self.splitting_completed = True + + return def make_generator(self, split_type, evaluation_mode, batch_size=None, subsample=False): """ @@ -240,125 +428,274 @@ def make_generator(self, split_type, evaluation_mode, batch_size=None, subsample ) return generator - - def trim_all_arrays(self,index): + + def _array_stat_helper(self, key, species_key, atomwise, norm_per_atom, norm_axis): + + prop = self.arr_dict[key] + + if norm_axis: + prop = np.linalg.norm(prop, axis=norm_axis) + + if atomwise: + if norm_per_atom: + raise ValueError("norm_per_atom and atom_var cannot both be True!") + if species_key is None: + raise RuntimeError("species_key must be given to trim a atomwise quantity") + + real_atoms = self.arr_dict[species_key] > 0 + stat_prop = prop[real_atoms] + else: + stat_prop = prop + + if norm_per_atom: + if species_key is None: + raise RuntimeError("species_key must be given to trim an atom-normalized quantity") + + n_atoms = (self.arr_dict[species_key] > 0).sum(axis=1) + # Transposes broadcast the result rightwards instead of leftwards. + # numpy transpose on higher-order arrays reverses all dimensions. + prop = (prop.T/n_atoms).T + stat_prop = (stat_prop.T/n_atoms).T + + mean = stat_prop.mean() + std = stat_prop.std() + if np.isnan(mean) or np.isnan(std): + warnings.warn(f"Array statistics, {mean=},{std=} contain NaN.", stacklevel=3) + + return prop, mean, std + + + def remove_high_property(self, key, atomwise, norm_per_atom=False, species_key=None, cut=None, std_factor=10, norm_axis=None): """ - To be used in conjuction with remove_high_property + For removing outliers from a dataset. Use with caution; do not inadvertently remove outliers from benchmarks! + + The parameters cut and std_factor can be set to `None` to avoid their steps. + the per_atom and atom_var properties are exclusive; they cannot both be true. + + :param key: The property key in the dataset to check for high values + :param atomwise: True if the property is defined per atom in axis 1, otherwise property is treated as whole-system value + :param norm_per_atom: True if the property should be normalized by atom counts + :param species_key: Which array represents the atom presence; required if per_atom is True + :param cut: If values > mu + cut, the system is removed. The step done first. + :param std_factor: If (value-mu)/std > std_fact, the system is trimmed. This step done second. + :param norm_axis: if not None, the property array is normed on the axis. Useful for vector properties like force. 
+ :return: """ - for key in self.arr_dict: - self.arr_dict[key] = self.arr_dict[key][index] - - def remove_high_property(self,key,perAtom,species_key=None,cut=None,std_factor=10): + print(f"Cutting on variable: {key}") + if cut is not None: + prop, mean, std = self._array_stat_helper(key, species_key, atomwise, norm_per_atom, norm_axis) + + large_property_mask = np.abs(prop - mean) > cut + # Scan over all non-batch indices. + non_batch_axes = tuple(range(1, prop.ndim)) + drop_mask = np.sum(large_property_mask, axis=non_batch_axes) > 0 + indices = self.arr_dict["indices"][drop_mask] + if drop_mask.any(): + print(f"Removed {drop_mask.astype(int).sum()} outlier systems in variable {key} due to static cut.") + self.make_explicit_split(f"failed_cut_{key}", indices) + + if std_factor is not None: + prop, mean, std = self._array_stat_helper(key, species_key, atomwise, norm_per_atom, norm_axis) + large_property_mask = np.abs(prop - mean)/std > std_factor + # Scan over all non-batch indices. + non_batch_axes = tuple(range(1, prop.ndim)) + drop_mask = np.sum(large_property_mask, axis=non_batch_axes) > 0 + indices = self.arr_dict["indices"][drop_mask] + if drop_mask.any(): + print(f"Removed {drop_mask.astype(int).sum()} outlier systems in variable {key} due to std. factor.") + self.make_explicit_split(f"failed_std_fac_{key}", indices) + + def write_h5(self, split=None, h5path=None, species_key='species', overwrite=False): + + try: + from .h5_pyanitools import write_h5 as write_h5_function + except ImportError as ie: + raise ImportError("Writing h5 versions of databases not available.") from ie + + return write_h5_function(self, split=split, file=h5path, species_key=species_key, overwrite=overwrite) + + def write_npz(self, file: str, record_split_masks: bool = True, overwrite: bool = False, split_prefix=None, return_only=False): """ - This function removes outlier data from the dataset - Must be called before splitting - "key": the property key in the dataset to check for high values - "perAtom": True if the property is defined per atom in axis 1, otherwise property is treated as full system - "std_factor": systems with values larger than this multiplier time the standard deviation of all data will be reomved. None to skip this step - "cut_factor": systems with values larger than this number are reomved. None to skip this step. This step is done first. + :param file: str, Path, or file object compatible with np.save + :param record_split_masks: + :param overwrite: Whether to accept an existing path. Only used if fname is str or path. + :param split_prefix: optionally change the prefix for the masks computed by the splits. + :param return_only: if True, ignore the file string and just return the resulting dictionary of numpy arrays. 
+ :return: """ - if perAtom: - if species_key==None: - raise RuntimeError("species_key must be defined to trim a per atom quantity") - atom_ind = self.arr_dict[species_key] > 0 - ndim = len(self.arr_dict[key].shape) - if cut!=None: - if perAtom: - Kmean = np.mean(self.arr_dict[key][atom_ind]) - else: - Kmean = np.mean(self.arr_dict[key]) - failArr = np.abs(self.arr_dict[key]-Kmean)>cut - #This does nothing with ndim=1 - trimArr = np.sum(failArr,axis=tuple(range(1,ndim)))==0 - self.trim_all_arrays(trimArr) - - if std_factor!=None: - if perAtom: - atom_ind = self.arr_dict[species_key] > 0 - Kmean = np.mean(self.arr_dict[key][atom_ind]) - std_cut = np.std(self.arr_dict[key][atom_ind]) * std_factor - else: - Kmean = np.mean(self.arr_dict[key]) - std_cut = np.std(self.arr_dict[key]) * std_factor - failArr = np.abs(self.arr_dict[key]-Kmean)>std_cut - #This does nothing with ndim=1 - trimArr = np.sum(failArr,axis=tuple(range(1,ndim)))==0 - self.trim_all_arrays(trimArr) - - if importlib.util.find_spec("pyanitools") is not None: - def write_h5(self,split=None,h5path=None,species_key='species',overwrite=False,return_dictionary=False): - """ - Writes database as ANI-style h5 file - - :param split: str or None; selects data split to save. If None, contents of arr_dict are used. - :param species_key: str; the key that designates atomic species. Used for determine number of atoms. Assumed to be [N_structures,N_atom]. Default: 'species' - :param overwrite: boolean; enables over-writing of h5 file. - :param return_dictionary: boolean; return dictionary style database for writing. - :return: dataloader containing relevant data - - """ - import pyanitools as pyt - if split in self.splits: - database = self.splits[split] - elif split is None: - database = self.arr_dict - else: - raise Exception(f"Unknown split name: {split:s}") - dataDict = {} - if (h5path is not None) : - if Path(h5path).exists(): - if overwrite: - Path(h5path).unlink() + if split_prefix is None: + split_prefix = _AUTO_SPLIT_PREFIX + if not self.splitting_completed: + raise ValueError("Cannot write an incompletely split database to npz file.\n" + + "You can split the rest using `database.split_the_rest('other_data')`\n" + + "to put the remaining data into a new split named 'other_data'") + + # get combined dictionary of arrays. + np_dict = {sname: + {arr_name: array.to('cpu').numpy() for arr_name, array in split.items()} + for sname, split in self.splits.items()} + + # insert split masks if requested. + + if record_split_masks: + self.add_split_masks(dict_to_add_to=np_dict, split_prefix=split_prefix) + + + # Stack numpy arrays: + arr_dict = {} + a_split = list(np_dict.values())[0] + keys = a_split.keys() + + for k in list(keys): + list_of_arrays = [split_dict[k] for split_dict in np_dict.values()] + arr_dict[k] = np.concatenate(list_of_arrays, axis=0) + + # Put results where requested. + if return_only: + return arr_dict + + if isinstance(file, str): + file = Path(file) + + if isinstance(file, Path): + if file.exists() and not overwrite: + raise FileExistsError(f"File exists: {file}") + + np.savez_compressed(file, **arr_dict) + + return arr_dict + + def sort_by_index(self, index_name='indices'): + """ + Sort arrays in each split of the database by an index key. + The default is 'indices', also possible is 'split_indices', or any other variable name in the database. + + :param index_name: + :return: None + """ + for sname, split in self.splits.items(): + ind = split[index_name] + ind_order = torch.argsort(ind) + # Modify dictionary in-place. 
+ for k, v in split.items(): + split[k] = v[ind_order] + + def trim_by_species(self, species_key: str, keep_splits_same_size: bool =True): + """ + Remove any excess padding in a database. + :param species_key: what array to use to mark atom presence. + :param keep_splits_same_size: true: trim by the minimum amount across splits, false: trim by the maximum amount for each split. + :return: + """ + if not self.splitting_completed: + raise ValueError("Cannot trim arrays until splitting has been completed.") + + split_max_max_atom_size = {} + for k, split in self.splits.items(): + species_array = split[species_key] + max_atoms = (species_array != 0).sum(axis=1) + max_max_atoms = max_atoms.max().item() + split_max_max_atom_size[k] = max_max_atoms + del max_atoms, species_array, max_max_atoms # Marking unneeded. + + if keep_splits_same_size: + # find the longest of the split sizes + max_max_max_atoms = max(split_max_max_atom_size.values()) + # store that back into the dictionary + split_max_max_atom_size = {k: max_max_max_atoms for k, v in split_max_max_atom_size.items()} + del max_max_max_atoms # Marking unneeded. + + for k, split in self.splits.items(): + species_array = split[species_key] + orig_atom_size = species_array.shape[1] + max_max_atoms = split_max_max_atom_size[k] + order = torch.argsort(species_array, dim=1, descending=True, stable=True) + + assert max_max_atoms > 7, "Max atoms bigger than 7 required for automatic atom dimension detection." + + for key, arr in split.items(): + + # determine where to broadcast sorting indices for this array. + non_species_non_batch_axes = [] + for dim, length in enumerate(arr.shape[1:], start=1): + if length == orig_atom_size: + pass else: - raise Exception(f"h5path {h5path:s} exists.") - print("Saving h5 file: " + h5path) - dpack = pyt.datapacker(h5path) - else: - dpack = None - totalNumber = database[species_key].shape[0] - atomDim = database[species_key].shape[1] - isAtomKey={} - #determine which keys have second element N atoms - for curK in database.keys(): - #Lazy if evaluation - if (len(database[curK].shape)>1) and (database[curK].shape[1] == atomDim): - isAtomKey[curK] = True - else: - isAtomKey[curK] = False - del(isAtomKey[species_key]) - for sysI,sysV in enumerate(database[species_key]): - # We can append the system data to an existing set of system data - molkey = hash(np.array(sysV).tobytes()) - molnAtom = np.count_nonzero(sysV) - if molkey in dataDict: - if (database[species_key][sysI,:molnAtom].shape != dataDict[molkey][species_key].shape) or not((database[species_key][sysI,:molnAtom]==dataDict[molkey][species_key]).all()): - raise Exception("Error. Hash not unique. 
You should never see this.") - for curK in isAtomKey.keys(): - if isAtomKey[curK]: - dataDict[molkey][curK].append(database[curK][sysI,:molnAtom]) - else: - dataDict[molkey][curK].append(database[curK][sysI]) - else: - dataDict[molkey] = {} - for curK in isAtomKey.keys(): - if isAtomKey[curK]: - dataDict[molkey][curK] = [database[curK][sysI,:molnAtom]] - else: - dataDict[molkey][curK] = [database[curK][sysI]] - dataDict[molkey][species_key] = database[species_key][sysI,:molnAtom] - for sysV in dataDict.keys(): - for curK in isAtomKey.keys(): - dataDict[sysV][curK] = np.array(dataDict[sysV][curK]) - if np.issubdtype(dataDict[sysV][curK].dtype,np.unicode_): - dataDict[sysV][curK] = [el.encode('utf-8') for el in list(dataDict[sysV][curK])] - dataDict[sysV][curK] = np.array(dataDict[sysV][curK]) - - if dpack is not None: - for key in dataDict: - dpack.store_data(str(key),**dataDict[key]) - dpack.cleanup() - if (h5path is None) or return_dictionary: - return(dataDict) + non_species_non_batch_axes.append(dim) + + for dim, length in enumerate(arr.shape[1:], start=1): + if dim in non_species_non_batch_axes: + continue + + unsq_dims = tuple(x for x in non_species_non_batch_axes if x != dim) + this_order = unsqueeze_multiple(order, unsq_dims) + arr, this_order = torch.broadcast_tensors(arr, this_order) + arr = torch.take_along_dim(arr, this_order, dim) + arr = torch.narrow_copy(arr, dim, 0, max_max_atoms) + if not self.quiet: + print(f"Resorting {key} along axis {dim}. {arr.shape=},{this_order.shape=}") + + split[key] = arr + # end loop over arrays + # end loop over splits + + return + + def get_device(self): + if not self.splitting_completed: + raise ValueError("Device should not be changed before splitting is complete.") + + devices = set(a.device for s, split in self.splits.items() for k, a in split.items()) + if len(devices) != 1: + raise ValueError(f"Devices for tensors are not uniform, got: {devices}") + + device = devices.pop() + return device + + def make_database_cache(self, file="./hippynn_db_cache.npz", overwrite=False, **override_kwargs): + """ + Cache the database as-is, and re-open it. + + Useful for creating an easy restart script if the storage space is available. + + :param file: where to store the database + :param overwrite: whether to overwrite an existing cache file with this name. + :param override_kwargs: passed to NPZDictionary instead of the current database settings. + :return: The new database created from the cache. + """ + from .ondisk import NPZDatabase + + # first prepare arguments + arguments = dict( + file=file, + inputs=self.inputs, + targets=self.targets, + seed=self.random_state.get_state(), + test_size=None, # using auto_split + valid_size=None, # using_auto_split + num_workers=self.num_workers, + pin_memory=self.pin_memory, + allow_unfound=True, # We may have extra arrays; reproduce them. + auto_split=True, # Inherit splitting from this db. + device=self.get_device(), + quiet=self.quiet, + ) + + if override_kwargs: + if not self.quiet: + print("Overriding arguments to database cache:", override_kwargs) + arguments.update(override_kwargs) + + # now write cache + if not self.quiet: + print("Writing Cached database to", file) + + self.write_npz(file=file, + record_split_masks=True, # allows inheriting of splits from this db. + overwrite=overwrite, + return_only=False) + # now reload cached file. 
+ return NPZDatabase(**arguments) def compute_index_mask(indices, index_pool): if not np.all(np.isin(indices, index_pool)): @@ -379,7 +716,7 @@ def prettyprint_arrays(arr_dict): Pretty-print array dictionary :return: None """ - column_format = "| {:<18} | {:<18} | {:<40} |" + column_format = "| {:<30} | {:<18} | {:<28} |" ncols = len(column_format.format("", "", "")) def printrow(*args): diff --git a/hippynn/databases/h5_pyanitools.py b/hippynn/databases/h5_pyanitools.py index bbe71e01..41c1d2cd 100644 --- a/hippynn/databases/h5_pyanitools.py +++ b/hippynn/databases/h5_pyanitools.py @@ -4,20 +4,22 @@ Note: You will need `pyanitools.py` to be importable to import this module. """ -import pyanitools - import os +import collections +from pathlib import Path + +import h5py # If dependency not available then just fail here. import numpy as np import torch -from ..tools import progress_bar, np_of_torchdefaultdtype -from ..tools import pad_np_array_to_length_with_zeros as pad_atoms +from ase.data import atomic_numbers from . import Database from .restarter import Restartable -from ase.data import atomic_numbers +from ..tools import progress_bar, np_of_torchdefaultdtype +from ._ani_reader import AniDataLoader, DataPacker numpy_map_elements = np.vectorize(atomic_numbers.__getitem__) @@ -29,7 +31,8 @@ class PyAniMethods: def extract_full_file(self, file, species_key="species"): n_atoms_max = 0 batches = [] - x = pyanitools.anidataloader(file) + x = AniDataLoader(file, driver=self.driver) # Engine=core reads the entire file at once. + sys_counter = collections.Counter() for c in progress_bar(x, desc="Data Groups", unit="group", total=x.group_size()): batch_dict = {} @@ -46,7 +49,7 @@ def extract_full_file(self, file, species_key="species"): # Special logic for species if k == species_key: - # Groups have the same species, broad-cast out the batch axis + # Groups have the same species, broadcast out the batch axis v = np.expand_dims(v, 0) # If given as strings, map to atomic elements if (not isinstance(v.dtype, type)) and issubclass(v.dtype.type, np.str_): @@ -54,13 +57,15 @@ def extract_full_file(self, file, species_key="species"): n_atoms_max = max(n_atoms_max, v.shape[1]) + sys_counter[k] += v.shape[0] batch_dict[k] = v batches.append(batch_dict) - return batches, n_atoms_max + sys_count = max(sys_counter.values()) # some variables are batch-wise, but most of these should be the same + return batches, n_atoms_max, sys_count - def determine_key_structure(self, batch_list, species_key="species"): + def determine_key_structure(self, batch_list, sys_count, n_atoms_max, species_key="species"): """Determine what arrays to pad""" batch = batch_list[0] n_atoms = batch[species_key].shape[1] @@ -79,6 +84,7 @@ def determine_key_structure(self, batch_list, species_key="species"): # dict of which axes need to be padded; # pad if the array size is equal to the number of atoms along a given axis padding_scheme = {k: [] for k in batch.keys()} + shape_scheme = {} bsize = 0 bkey = None for k, v in batch.items(): @@ -93,46 +99,56 @@ def determine_key_structure(self, batch_list, species_key="species"): if this_bsize > bsize: bsize = this_bsize bkey = k + shape_scheme[k] = list(v.shape) + for axis in padding_scheme[k]: + shape_scheme[k][axis] = n_atoms_max + shape_scheme[k][0] = sys_count padding_scheme['sys_number'] = [] - return padding_scheme, bkey + return padding_scheme, shape_scheme, bkey + + def process_batches(self, batches, n_atoms_max, sys_count, species_key="species"): + + # Get padding abd shape info and 
batch size key + padding_scheme, shape_scheme, size_key =\ + self.determine_key_structure(batches, sys_count, n_atoms_max, species_key=species_key) - def process_batches(self, batches, n_atoms_max, species_key="species"): + # add system numbers to the final arrays + shape_scheme['sys_number'] = [sys_count, ] + batches[0]['sys_number'] = np.asarray([0], dtype=np.int64) - # Get padding info and batch size key - padding_scheme, size_key = self.determine_key_structure(batches, species_key=species_key) + arr_dict = {} + for k, shape in shape_scheme.items(): + dtype = batches[0][k].dtype + arr_dict[k] = np.zeros(shape, dtype=dtype) - # Pad the arrays - padded_batches = [] + sys_start = 0 for i, b in enumerate(progress_bar(batches, desc="Processing Batches", unit="batch")): - pb = {} + # Get batch metadata + n_sys = b[size_key].shape[0] b['sys_number'] = np.asarray([i], dtype=np.int64) - for k, v in b.items(): - bsize = len(b[size_key]) - # Expand species array to fit batch size - if k == species_key: + sys_end = sys_start + n_sys + # n_atoms_batch = b[species_key].shape[1] # don't need this! - v = np.repeat(v, bsize, axis=0) + for k, arr in b.items(): - # Perform padding as needed - for axis in padding_scheme[k]: - v = pad_atoms(v, n_atoms_max, axis=axis) - if 0 not in padding_scheme[k] and v.shape[0] == 1: - v = np.broadcast_to(v, (bsize, *v.shape[1:])) + if k == species_key: + arr = np.repeat(arr, n_sys, axis=0) - pb[k] = v + # set up slicing for non-batch axes + where = tuple(slice(0, s) for s in arr.shape[1:]) + # add batch slicing + where = (slice(sys_start, sys_end), *where) - padded_batches.append(pb) + # store array! + arr_dict[k][where] = arr - arr_dict = {} + sys_start += n_sys - for k in b.keys(): - try: - arr_dict[k] = np.concatenate([pb[k] for pb in padded_batches]) - except ValueError as ve: - print("Error occured:",ve) - print("Skipping key:",k) - continue + if sys_start != sys_count: + # Just in case someone tries to change this code later, + # Here is a consistency check. + raise RuntimeError(f"Number of systems was inconsistent: {sys_start} vs. {sys_count}") return arr_dict @@ -157,32 +173,45 @@ def filter_arrays(self, arr_dict, allow_unfound=False, quiet=False): class PyAniFileDB(Database, PyAniMethods, Restartable): - def __init__(self, file, inputs, targets, *args, allow_unfound=False, species_key="species", quiet=False, **kwargs): + def __init__(self, file, inputs, targets, *args, allow_unfound=False, species_key="species", quiet=False, driver='core', **kwargs): + """ + + :param file: + :param inputs: + :param targets: + :param args: + :param allow_unfound: + :param species_key: + :param quiet: + :param driver: h5 file driver. 
+ :param kwargs: + """ self.file = file self.inputs = inputs self.targets = targets self.species_key = species_key + self.driver = driver arr_dict = self.load_arrays(quiet=quiet, allow_unfound=allow_unfound) super().__init__(arr_dict, inputs, targets, *args, **kwargs, quiet=quiet, allow_unfound=allow_unfound) self.restarter = self.make_restarter( - file, inputs, targets, *args, **kwargs, quiet=quiet, allow_unfound=allow_unfound, + file, inputs, targets, *args, **kwargs, driver=driver, quiet=quiet, allow_unfound=allow_unfound, species_key=species_key, ) def load_arrays(self, allow_unfound=False, quiet=False): if not quiet: print("Loading arrays from", self.file) - batches, n_atoms_max = self.extract_full_file(self.file,species_key=self.species_key) - arr_dict = self.process_batches(batches, n_atoms_max,species_key=self.species_key) + batches, n_atoms_max, sys_count = self.extract_full_file(self.file, species_key=self.species_key) + arr_dict = self.process_batches(batches, n_atoms_max, sys_count, species_key=self.species_key) arr_dict = self.filter_arrays(arr_dict, quiet=quiet, allow_unfound=allow_unfound) return arr_dict class PyAniDirectoryDB(Database, PyAniMethods, Restartable): - def __init__(self, directory, inputs, targets, *args, files=None, allow_unfound=False,species_key="species", + def __init__(self, directory, inputs, targets, *args, files=None, allow_unfound=False, species_key="species", quiet=False,**kwargs): self.directory = directory @@ -220,6 +249,103 @@ def load_arrays(self, allow_unfound=False, quiet=False): n_atoms_max = max(max_atoms_list) batches = [item for fb in data for item in fb] - arr_dict = self.process_batches(batches, n_atoms_max,species_key=self.species_key) + arr_dict = self.process_batches(batches, n_atoms_max, species_key=self.species_key) arr_dict = self.filter_arrays(arr_dict, quiet=quiet, allow_unfound=allow_unfound) return arr_dict + + +def write_h5(database: Database, split: str = None, file: Path = None, species_key: str = 'species', overwrite=False): + """ + :param database: database to get + :param split: str, None, or True; selects data split to save. + If None, contents of arr_dict are used. + If True, save all splits and save split masks as well. + :param file: where to save the database. + :param species_key: the key used for system contents (padding and chemical formulas) + :param overwrite: boolean; enables over-writing of h5 file. + :return: dictionary of ANI-style systems. 
+ """ + + if split is True: + database = database.write_npz("", record_split_masks=True, return_only=True) + print("writenpz", database.keys()) + elif split in database.splits: + database = database.splits[split] + database = {k: v.to('cpu').numpy() for k,v in database.items()} + elif split is None: + database = database.arr_dict + else: + raise Exception(f"Unknown split variable supplied (must be True, None, or str): {split:s}") + + if file is not None: + if Path(file).exists(): + if overwrite: + print("Overwriting h5 file:", file) + Path(file).unlink() + else: + raise FileExistsError(f"h5path {file:s} exists.") + print("Saving h5 file:", file) + packer = DataPacker(file) + else: + packer = None + + db_species = database[species_key] + total_systems = db_species.shape[0] + n_atoms_max = db_species.shape[1] + + # determine which keys have second shape of N atoms + is_atom_var = { + k: (len(k_arr.shape) > 1) and (k_arr.shape[1] == n_atoms_max) for k, k_arr in database.items() + } + del (is_atom_var[species_key]) # species handled separately + + # Create the data dictionary + # Maps hashes of system chemical formulas to dictionaries of system information. + data = {} + for i, db_mol_species in enumerate(db_species): + # We can append the system data to an existing set of system data + mol_n_atom = np.count_nonzero(db_mol_species) + if np.count_nonzero(db_mol_species[mol_n_atom:]) > 0: + raise ValueError(f"Malformed species row with non-standard padding: {db_mol_species}") + db_mol_species = db_mol_species[:mol_n_atom] + mhash = hash(np.array(db_mol_species).tobytes()) # molecule hash + + if mhash not in data: + # need to make a new mol entry in the data for this chemical formula + # the mol dictionary maps a data key to the array for that mol. + # The species key has one value, but the other keys can store batch of values. + mol = {species_key: db_mol_species} + for k, k_is_atom_based in is_atom_var.items(): + db_arr = database[k] + store_arr = db_arr[i, :mol_n_atom] if k_is_atom_based else db_arr[i] + mol[k] = [store_arr] + data[mhash] = mol + else: + # If there is already an entry for this chemical formula, append it to the current one + mol = data[mhash] + mol_species = mol[species_key] + # First sanity check that the mhash hash we are using is unique, or else BAD. + if (db_mol_species.shape != mol_species.shape) or not (db_mol_species == mol_species).all(): + raise ValueError("Error. Hash not unique. You should never see this.") + + # Now append the system to the set of systems with this chemical formula. + for k, k_is_atom_based in is_atom_var.items(): + db_arr = database[k] + store_arr = db_arr[i, :mol_n_atom] if k_is_atom_based else db_arr[i] + mol[k].append(store_arr) + + # post-process atom variables into arrays and handle strings. 
+ for mhash, mol in data.items(): + for k in is_atom_var.keys(): + mol[k] = np.asarray(mol[k]) + + if np.issubdtype(mol[k].dtype, np.unicode_): + mol[k] = [el.encode('utf-8') for el in list(mol[k])] + mol[k] = np.array(mol[k]) + # Store data + if packer is not None: + for key in data: + packer.store_data(str(key), **data[key]) + packer.cleanup() + + return data diff --git a/hippynn/experiment/metric_tracker.py b/hippynn/experiment/metric_tracker.py index 989042c7..f43426e6 100644 --- a/hippynn/experiment/metric_tracker.py +++ b/hippynn/experiment/metric_tracker.py @@ -30,7 +30,7 @@ class MetricTracker: """ - def __init__(self, metric_names, stopping_key, quiet=False, split_names=("train", "valid", "test")): + def __init__(self, metric_names, stopping_key, quiet=False): """ :param metric_names: @@ -48,7 +48,7 @@ def __init__(self, metric_names, stopping_key, quiet=False, split_names=("train" self.n_metrics = len(metric_names) # State variables - self.best_metric_values = {split: {mtype: float("inf") for mtype in self.metric_names} for split in split_names} + self.best_metric_values = {} self.other_metric_values = {} self.best_model = None self.epoch_times = [] @@ -78,7 +78,18 @@ def register_metrics(self, metric_info, when): better_metrics = {k: {} for k in self.best_metric_values} for split_type, typevals in metric_info.items(): for mname, mval in typevals.items(): - better = self.best_metric_values[split_type][mname] > mval + try: + old_best = self.best_metric_values[split_type][mname] + better = old_best > mval + del old_best # marking not needed. + except KeyError: + if split_type not in self.best_metric_values: + # Haven't seen this split before! + print("ADDING ",split_type) + self.best_metric_values[split_type] = {} + better_metrics[split_type] = {} + better = True # old best was not found! + if better: self.best_metric_values[split_type][mname] = mval better_metrics[split_type][mname] = better @@ -97,13 +108,17 @@ def register_metrics(self, metric_info, when): return better_metrics, better_model, stopping_key_metric - def evaluation_print(self, evaluation_dict): - if self.quiet: + def evaluation_print(self, evaluation_dict, quiet=None): + if quiet is None: + quiet = self.quiet + if quiet: return table_evaluation_print(evaluation_dict, self.metric_names, self.name_column_width) - def evaluation_print_better(self, evaluation_dict, better_dict): - if self.quiet: + def evaluation_print_better(self, evaluation_dict, better_dict, quiet=None): + if quiet is None: + quiet = self.quiet + if quiet: return table_evaluation_print_better(evaluation_dict, better_dict, self.metric_names, self.name_column_width) if self.stopping_key: @@ -119,14 +134,14 @@ def plot_over_time(self): # Driver for printing evaluation table results, with * for better entries. # Decoupled from the estate in case we want to more easily change print formatting. -def table_evaluation_print_better(evaluation_dict, better_dict, metric_names, ncs): +def table_evaluation_print_better(evaluation_dict, better_dict, metric_names, n_columns): """ Print metric results as a table, add a '*' character for metrics in better_dict. :param evaluation_dict: dict[eval type]->dict[metric]->value :param better_dict: dict[eval type]->dict[metric]->bool :param metric_names: Names - :param ncs: Number of columns for name fields. + :param n_columns: Number of columns for name fields. 
:return: None """ type_names = evaluation_dict.keys() @@ -139,8 +154,8 @@ def table_evaluation_print_better(evaluation_dict, better_dict, metric_names, nc n_types = len(type_names) - header = " " * (ncs + 2) + "".join("{:>14}".format(tn) for tn in type_names) - rowstring = "{:<" + str(ncs) + "}: " + " {}{:>10.5g}" * n_types + header = " " * (n_columns + 2) + "".join("{:>14}".format(tn) for tn in type_names) + rowstring = "{:<" + str(n_columns) + "}: " + " {}{:>10.5g}" * n_types print(header) print("-" * len(header)) @@ -151,13 +166,13 @@ def table_evaluation_print_better(evaluation_dict, better_dict, metric_names, nc # Driver for printing evaluation table results. # Decoupled from the estate in case we want to more easily change print formatting. -def table_evaluation_print(evaluation_dict, metric_names, ncs): +def table_evaluation_print(evaluation_dict, metric_names, n_columns): """ Print metric results as a table. :param evaluation_dict: dict[eval type]->dict[metric]->value :param metric_names: Names - :param ncs: Number of columns for name fields. + :param n_columns: Number of columns for name fields. :return: None """ @@ -166,8 +181,8 @@ def table_evaluation_print(evaluation_dict, metric_names, ncs): n_types = len(type_names) - header = " " * (ncs + 2) + "".join("{:>14}".format(tn) for tn in type_names) - rowstring = "{:<" + str(ncs) + "}: " + " {:>10.5g}" * n_types + header = " " * (n_columns + 2) + "".join("{:>14}".format(tn) for tn in type_names) + rowstring = "{:<" + str(n_columns) + "}: " + " {:>10.5g}" * n_types print(header) print("-" * len(header)) diff --git a/hippynn/experiment/routines.py b/hippynn/experiment/routines.py index 8d3b37cb..f6aee191 100644 --- a/hippynn/experiment/routines.py +++ b/hippynn/experiment/routines.py @@ -207,8 +207,8 @@ def setup_training( model, loss, evaluator, optimizer, setup_params.device or tools.device_fallback() ) - metrics = MetricTracker(evaluator.loss_names, stopping_key=controller.stopping_key) - + metrics = MetricTracker(evaluator.loss_names, + stopping_key=controller.stopping_key) return training_modules, controller, metrics @@ -353,17 +353,21 @@ def test_model(database, evaluator, batch_size, when, metric_tracker=None): if metric_tracker is None: metric_tracker = MetricTracker(evaluator.loss_names, stopping_key=None) - metric_tracker.quiet = False + + # A little dance to make sure train, valid, test always come first, when present. 
+ basic_splits = ["train", "valid", "test"] + basic_splits = [s for s in basic_splits if s in database.splits] + splits = basic_splits + [s for s in database.splits if s not in basic_splits] + evaluation_data = collections.OrderedDict( ( - ("train", database.make_generator("train", "eval", batch_size)), - ("valid", database.make_generator("valid", "eval", batch_size)), - ("test", database.make_generator("test", "eval", batch_size)), + (key, database.make_generator(key, "eval", batch_size)) + for key in splits ) ) evaluation_metrics = {k: evaluator.evaluate(gen, eval_type=k, when=when) for k, gen in evaluation_data.items()} metric_tracker.register_metrics(evaluation_metrics, when=when) - metric_tracker.evaluation_print(evaluation_metrics) + metric_tracker.evaluation_print(evaluation_metrics, quiet=False) return metric_tracker diff --git a/hippynn/tools.py b/hippynn/tools.py index 680a9a99..d2c78133 100644 --- a/hippynn/tools.py +++ b/hippynn/tools.py @@ -151,11 +151,25 @@ def pad_np_array_to_length_with_zeros(array, length, axis=0): pad_width[axis][1] = m return np.pad(array, pad_width, mode="constant") - +def unsqueeze_multiple(tensor, dims: tuple): + """ + Adds unsqueezing dimensions dimensions + :param tensor: + :param dims: + :return: + """ + if len(dims)==0: + return tensor + dims = tuple(sorted(dims)) + while dims: + d, *rest = dims + tensor = tensor.unsqueeze(d) + dims = tuple(d+1 for d in rest) + return tensor def np_of_torchdefaultdtype(): return torch.ones(1, dtype=torch.get_default_dtype()).numpy().dtype -def is_equal_state_dict(d1, d2): +def is_equal_state_dict(d1, d2, raise_where=False): """ Checks if two pytorch state dictionaries are equal. Calls itself recursively if the value for a parameter is a dictionary. @@ -163,25 +177,41 @@ def is_equal_state_dict(d1, d2): :param d1: :param d2: + :param raise_where: if not equal, use an assertion to fail. :return: """ if set(d1.keys()) != set(d2.keys()): + if raise_where: + raise AssertionError(f"State dictionaries not equal keys: {set(d1.keys())=}, {d2.keys()=}") # They have different sets of keys. 
return False for k in d1: v1 = d1[k] v2 = d2[k] if type(v1) != type(v2): + if raise_where: + raise AssertionError(f"State dictionaries not equal at key {k}; {v1} != {v2})") return False if isinstance(v1, torch.Tensor): if torch.equal(v1, v2): continue else: + + if raise_where: + if v1.shape!=v2.shape: + raise AssertionError(f"State dictionaries not equal at key {k}" + + f" due to shapes: {v1.shape=},{v2.shape=}") + where_not_equal = torch.where(torch.ne(v1,v2)) + raise AssertionError(f"State dictionaries not equal at key {k}" + + f" at locations {where_not_equal};" + + f" {v1[where_not_equal]} != {v2[where_not_equal]})") return False elif isinstance(v1, dict): # call recursive: - return is_equal_state_dict(v1, v2) + return is_equal_state_dict(v1, v2, raise_where=raise_where) elif v1 != v2: + if raise_where: + raise AssertionError(f"State dictionaries not equal at key {k}; {v1} != {v2})") return False return True diff --git a/tests/dataset_writing.py b/tests/dataset_writing.py new file mode 100644 index 00000000..6bc67148 --- /dev/null +++ b/tests/dataset_writing.py @@ -0,0 +1,112 @@ +import glob +import torch + +torch.set_default_dtype(torch.float64) + +from hippynn.databases.h5_pyanitools import PyAniFileDB +from hippynn.databases import NPZDatabase +from hippynn.tools import is_equal_state_dict + + +# compare if databases are equal, split by split +def eqsplit(db1, db2, raise_error=True): + return is_equal_state_dict(db1.splits, db2.splits, raise_where=raise_error) + + +if __name__ == "__main__": + CLEANUP = True # delete datasets afterwards + # Example dataset + location = "../../datasets/new_qm9_clean.npz" + seed = 1 + num_workers = 0 + db_info = {} + db1 = NPZDatabase( + file=location, + seed=seed, + num_workers=num_workers, + allow_unfound=True, + **db_info, + inputs=None, + targets=None, + ) + + # test remove_high_property + db1.remove_high_property("E", species_key="Z", atomwise=False, norm_per_atom=True, std_factor=5) + + # throw stuff away + db1.make_random_split("random stuff", 0.99) + del db1.splits["random stuff"] # remove something at random + + new_ani_file = "TEST_clean_ani1x.h5" + new_npz_file = "TEST_clean_ani1x.npz" + + # Divide up the dataset in a bunch of splits. + db1.make_random_split("first", 0.5) + db1.make_random_split("second", 0.2) + db1.make_random_split("third", 3) # integer + db1.split_the_rest("remaining") + # write an npz version and reload it. + db1.write_npz(file=new_npz_file, record_split_masks=True, overwrite=True) + db3 = NPZDatabase(file=new_npz_file, seed=seed, num_workers=num_workers, allow_unfound=True, inputs=None, targets=None, auto_split=True, **db_info) + + # write an h5 version and reload it. + db1.write_h5(split=True, h5path=new_ani_file, species_key="Z", overwrite=True) + db2 = PyAniFileDB( + file=new_ani_file, + species_key="Z", + seed=seed, + num_workers=num_workers, + allow_unfound=True, + **db_info, + inputs=None, + targets=None, + auto_split=True, + ) + new_ani_filetwo = "TEST_clean_ani1x_2.h5" + # trim this dataset earlier than others. + db2.trim_by_species("Z") + # write and load new database. 
+ db2.write_h5(split=True, h5path=new_ani_filetwo, species_key="Z", overwrite=True) + db4 = PyAniFileDB( + file=new_ani_filetwo, + species_key="Z", + seed=seed, + num_workers=num_workers, + allow_unfound=True, + **db_info, + inputs=None, + targets=None, + auto_split=True, + ) + + for i, d in enumerate([db1, db2, db3, db4]): + print("sorting", i) + d.sort_by_index() + print("trimming", i) + d.trim_by_species("Z", keep_splits_same_size=True) # can do either way if disable caching test. + + # "sys_number" comes from h5 format databases, but not present in others. + for d in [db2, db4]: + for s in d.splits: + del d.splits[s]["sys_number"] + + db1.add_split_masks() # this first didn't ever get split masks! add them now. + + print("NPZ Equality:", eqsplit(db1, db3)) + print("H5 Equality:", eqsplit(db2, db4)) + print("NPZ-H5 Equality1:", eqsplit(db1, db2)) + print("NPZ-H5 Equality2:", eqsplit(db2, db3)) + + # test caching routine. + db2p = db2.make_database_cache(file="TEST_CACHE_FROMH5.npz", overwrite=True, quiet=True) + print("H5 to cache Equality:", eqsplit(db2, db2p)) + db3p = db3.make_database_cache(file="TEST_CACHE_FROMNPZ.npz", overwrite=True, quiet=True) + print("NPZ to cache Equality:", eqsplit(db3, db3p)) + + for ext in ["npz", "h5"]: + + for file in glob.glob(f"./TEST*.{ext}"): + print(file) + import os + + os.remove(file)
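Not part of the patch: a minimal usage sketch of the export / auto-split workflow this changeset introduces, condensed from tests/dataset_writing.py above. The dataset path "my_dataset.npz" and the array keys "E" (energy) and "Z" (species) are placeholders; any real dataset and key names can be substituted.

    # Illustrative sketch only; file path and array keys below are placeholders.
    import torch
    from hippynn.databases import NPZDatabase

    torch.set_default_dtype(torch.float64)

    # Load a dataset; allow_unfound=True lets inputs/targets be resolved later.
    db = NPZDatabase(
        file="my_dataset.npz",   # placeholder path
        inputs=None,
        targets=None,
        allow_unfound=True,
        seed=1,
        num_workers=0,
    )

    # Optionally remove outlier systems in a per-system property (assumed key "E"),
    # normalized by the atom count taken from the assumed species key "Z".
    db.remove_high_property("E", atomwise=False, norm_per_atom=True,
                            species_key="Z", std_factor=10)

    # Make splits, then export; record_split_masks stores "split_mask_*" arrays.
    db.make_random_split("valid", 0.1)
    db.make_random_split("test", 0.1)
    db.split_the_rest("train")
    db.write_npz(file="exported.npz", record_split_masks=True, overwrite=True)

    # Reloading with auto_split=True rebuilds the same splits from those masks.
    db2 = NPZDatabase(
        file="exported.npz",
        inputs=None,
        targets=None,
        allow_unfound=True,
        seed=1,
        auto_split=True,
    )

The same database object could instead be written with write_h5(split=True, ...) and reloaded through PyAniFileDB, or snapshotted and reopened in one step with make_database_cache(), as exercised in the test script above.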