From bb38683e1563eec4f96bc38c60f2315cddb532ee Mon Sep 17 00:00:00 2001 From: chrisiacovella Date: Fri, 29 Mar 2024 09:44:38 -0700 Subject: [PATCH 01/37] Added in download_from_url function to remote that checks md5 checksum. --- modelforge/tests/test_remote.py | 44 ++++++++++++++++++++++++++ modelforge/utils/remote.py | 55 +++++++++++++++++++++++++++++++++ 2 files changed, 99 insertions(+) diff --git a/modelforge/tests/test_remote.py b/modelforge/tests/test_remote.py index f8653db4..8e314b56 100644 --- a/modelforge/tests/test_remote.py +++ b/modelforge/tests/test_remote.py @@ -26,6 +26,50 @@ def test_is_url(): ) +def test_download_from_url(prep_temp_dir): + url = "https://raw.githubusercontent.com/choderalab/modelforge/e3e65e15e23ccc55d03dd7abb4b9add7a7dd15c3/modelforge/modelforge.py" + checksum = "66ec18ca5db3df5791ff1ffc584363a8" + # Download the file + download_from_url( + url, + md5_checksum=checksum, + output_path=str(prep_temp_dir), + output_filename="modelforge.py", + force_download=True, + ) + + file_name_path = str(prep_temp_dir) + "/modelforge.py" + assert os.path.isfile(file_name_path) + + # create a dummy document to test the case where + # the checksum doesn't match so it will redownload + with open(file_name_path, "w") as f: + f.write("dummy document") + + # This will force a download because the checksum doesn't match + download_from_url( + url, + md5_checksum=checksum, + output_path=str(prep_temp_dir), + output_filename="modelforge.py", + force_download=False, + ) + + file_name_path = str(prep_temp_dir) + "/modelforge.py" + assert os.path.isfile(file_name_path) + + # let us change the expected checksum to cause a failure + with pytest.raises(Exception): + url = "https://choderalab.com/modelforge.py" + download_from_url( + url, + md5_checksum="checksum_garbage", + output_path=str(prep_temp_dir), + output_filename="modelforge.py", + force_download=True, + ) + + def test_download_from_figshare(prep_temp_dir): url = "https://figshare.com/ndownloader/files/22247589" name = download_from_figshare( diff --git a/modelforge/utils/remote.py b/modelforge/utils/remote.py index 6252c84e..7a68c337 100644 --- a/modelforge/utils/remote.py +++ b/modelforge/utils/remote.py @@ -93,6 +93,61 @@ def calculate_md5_checksum(file_name: str, file_path: str) -> str: return file_hash.hexdigest() +def download_from_url( + url: str, + md5_checksum: str, + output_path: str, + output_filename: str, + force_download=False, +) -> str: + + import requests + import os + from tqdm import tqdm + + chunk_size = 512 + + if os.path.isfile(f"{output_path}/{output_filename}"): + calculated_checksum = calculate_md5_checksum( + file_name=output_filename, file_path=output_path + ) + if calculated_checksum != md5_checksum: + force_download = True + logger.debug( + f"Checksum {calculated_checksum} of existing file {output_filename} does not match expected checksum {md5_checksum}, re-downloading." + ) + + if not os.path.isfile(f"{output_path}/{output_filename}") or force_download: + logger.debug( + f"Downloading datafile from {url} to {output_path}/{output_filename}." + ) + + r = requests.get(url, stream=True) + + os.makedirs(output_path, exist_ok=True) + + with open(f"{output_path}/{output_filename}", "wb") as fd: + for chunk in tqdm( + r.iter_content(chunk_size=chunk_size), + ascii=True, + desc="downloading", + ): + fd.write(chunk) + calculated_checksum = calculate_md5_checksum( + file_name=output_filename, file_path=output_path + ) + if calculated_checksum != md5_checksum: + raise Exception( + f"Checksum of downloaded file {calculated_checksum} does not match expected checksum {md5_checksum}." + ) + + else: # if the file exists and we don't set force_download to True, just use the cached version + logger.debug(f"Datafile {output_filename} already exists in {output_path}.") + logger.debug( + "Using previously downloaded file; set force_download=True to re-download." + ) + + # Figshare helper functions def download_from_figshare( url: str, md5_checksum: str, output_path: str, force_download=False From 142a53ea48f99a568cc15523acad16c5c7eb5609 Mon Sep 17 00:00:00 2001 From: chrisiacovella Date: Thu, 4 Apr 2024 07:56:45 -0700 Subject: [PATCH 02/37] Checksum checking is in place for each step. Dataset prep logic updated. --- modelforge/dataset/dataset.py | 117 ++++++++++++++++++++++++++++------ modelforge/dataset/qm9.py | 80 +++++++++++++++++++---- scripts/dataset_curation.py | 56 ++++++++++------ scripts/training.py | 2 +- 4 files changed, 204 insertions(+), 51 deletions(-) diff --git a/modelforge/dataset/dataset.py b/modelforge/dataset/dataset.py index 48714562..9ae8ceb8 100644 --- a/modelforge/dataset/dataset.py +++ b/modelforge/dataset/dataset.py @@ -226,7 +226,7 @@ class HDF5Dataset: """ def __init__( - self, raw_data_file: str, processed_data_file: str, local_cache_dir: str + self, # raw_data_file: str, processed_data_file: str, local_cache_dir: str ): """ Initializes the HDF5Dataset with paths to raw and processed data files. @@ -241,11 +241,25 @@ def __init__( Directory to store temporary processing files. """ - self.raw_data_file = raw_data_file - self.processed_data_file = processed_data_file + # self.raw_data_file = raw_data_file + # self.processed_data_file = processed_data_file self.hdf5data: Optional[Dict[str, List[np.ndarray]]] = None self.numpy_data: Optional[np.ndarray] = None - self.local_cache_dir = local_cache_dir + # self.local_cache_dir = local_cache_dir + + def _ungzip_hdf5(self) -> None: + """ + Unzips an HDF5.gz file. + + Examples + ------- + """ + import gzip + import shutil + + with gzip.open(self.raw_data_file, "rb") as gz_file: + with open(self.unzipped_data_file, "wb") as out_file: + shutil.copyfileobj(gz_file, out_file) def _from_hdf5(self) -> None: """ @@ -257,24 +271,27 @@ def _from_hdf5(self) -> None: >>> processed_data = hdf5_data._from_hdf5() """ - import gzip from collections import OrderedDict import h5py import tqdm - import shutil - - log.debug(f"Processing and extracting data from {self.raw_data_file}") # this will create an unzipped file which we can then load in # this is substantially faster than passing gz_file directly to h5py.File() # by avoiding data chunking issues. - temp_hdf5_file = f"{self.local_cache_dir}/temp_unzipped.hdf5" - with gzip.open(self.raw_data_file, "rb") as gz_file: - with open(temp_hdf5_file, "wb") as out_file: - shutil.copyfileobj(gz_file, out_file) + temp_hdf5_file = f"{self.local_cache_dir}/{self.unzipped_data_file}" log.debug("Reading in and processing hdf5 file ...") + from modelforge.utils.remote import calculate_md5_checksum + + # add in a check to make sure the checksum matches before loading the file + # this also appears in the dataset factory, but having this also here is good if used as a standalone function + checksum = calculate_md5_checksum(self.unzipped_data_file, self.local_cache_dir) + + if checksum != self.md5_unzipped_checksum: + raise ValueError( + f"Checksum mismatch for unzipped data file {temp_hdf5_file}. Found {checksum}, Expected {self.md5_unzipped_checksum}" + ) with h5py.File(temp_hdf5_file, "r") as hf: # create dicts to store data for each format type @@ -408,6 +425,15 @@ def _from_file_cache(self) -> None: >>> hdf5_data = HDF5Dataset("raw_data.hdf5", "processed_data.npz") >>> processed_data = hdf5_data._from_file_cache() """ + from modelforge.utils.remote import calculate_md5_checksum + + checksum = calculate_md5_checksum( + self.processed_data_file, self.local_cache_dir + ) + if checksum != self.md5_processed_checksum: + raise ValueError( + f"Checksum mismatch for processed data file {self.processed_data_file}.Found {checksum}, expected {self.md5_processed_checksum}" + ) log.debug(f"Loading processed data from {self.processed_data_file}") self.numpy_data = np.load(self.processed_data_file) @@ -466,15 +492,68 @@ def _load_or_process_data( The HDF5 dataset instance to use. """ - # if not cached, download and process - if not os.path.exists(data.processed_data_file): - if not os.path.exists(data.raw_data_file): - data._download() - # load from hdf5 and process + import os + from modelforge.utils.remote import calculate_md5_checksum + + # if the dataset was initialize with force_download, we will skip all other checking and just download and process + if data.force_download: + data._download() + data._ungzip_hdf5() data._from_hdf5() - # save to cache data._to_file_cache() - # load from cache + else: + file_loaded = False + if os.path.exists(data.processed_data_file): + checksum = calculate_md5_checksum( + data.processed_data_file, data.local_cache_dir + ) + if checksum != data.md5_processed_checksum: + + log.warning( + f"Checksum mismatch for processed data file {data.processed_data_file}. Re-processing." + ) + log.debug( + f"Checksum mismatch, found {checksum}, expected {data.md5_processed_checksum}" + ) + else: + data._from_file_cache() + file_loaded = True + + if os.path.exists(data.unzipped_data_file) and not file_loaded: + checksum = calculate_md5_checksum( + data.unzipped_data_file, data.local_cache_dir + ) + if checksum != data.md5_unzipped_checksum: + log.warning( + f"Checksum mismatch for unzipped data file {data.unzipped_data_file}. Re-processing." + ) + log.debug( + f"Checksum mismatch, found {checksum}, expected {data.md5_unzipped_checksum}" + ) + + # the download function automatically checks if the file is already downloaded and if the checksum matches + # if either are false, it will redownload the file + data._download() + data._ungzip_hdf5() + checksum = calculate_md5_checksum( + data.unzipped_data_file, data.local_cache_dir + ) + if checksum != data.md5_unzipped_checksum: + raise ValueError( + f"Checksum mismatch for unzipped data file {data.unzipped_data_file}. Expected {data.md5_unzipped_checksum}, found {checksum}. Please check the raw gzipped file or try running with force_download." + ) + + data._from_hdf5() + data._to_file_cache() + else: + data._from_hdf5() + data._to_file_cache() + else: + data._download() + data._ungzip_hdf5() + data._from_hdf5() + data._to_file_cache() + data._from_file_cache() @staticmethod diff --git a/modelforge/dataset/qm9.py b/modelforge/dataset/qm9.py index afdecb62..05327119 100644 --- a/modelforge/dataset/qm9.py +++ b/modelforge/dataset/qm9.py @@ -56,6 +56,7 @@ def __init__( dataset_name: str = "QM9", for_unit_testing: bool = False, local_cache_dir: str = ".", + force_download: bool = False, overwrite: bool = False, ) -> None: """ @@ -69,6 +70,8 @@ def __init__( If set to True, a subset of the dataset is used for unit testing purposes; by default False. local_cache_dir: str, optional Path to the local cache directory, by default ".". + force_download: bool, optional + If set to True, we will download the dataset even if it already exists; by default False. Examples -------- @@ -76,6 +79,8 @@ def __init__( >>> test_data = QM9Dataset(for_unit_testing=True) # Testing subset """ + self.force_download = force_download + _default_properties_of_interest = [ "geometry", "atomic_numbers", @@ -87,17 +92,29 @@ def __init__( if for_unit_testing: dataset_name = f"{dataset_name}_subset" - super().__init__( - f"{local_cache_dir}/{dataset_name}_cache.hdf5.gz", - f"{local_cache_dir}/{dataset_name}_processed.npz", - local_cache_dir=local_cache_dir, - ) + super().__init__() + # super().__init__( + # f"{local_cache_dir}/{dataset_name}_cache.hdf5.gz", + # f"{local_cache_dir}/{dataset_name}_processed.npz", + # local_cache_dir=local_cache_dir, + # ) self.dataset_name = dataset_name self.for_unit_testing = for_unit_testing - self.test_url = ( - "https://github.com/wiederm/gm9/raw/main/qm9_dataset_n100.hdf5.gz" - ) - self.full_url = "https://github.com/wiederm/gm9/raw/main/qm9.hdf5.gz" + self.local_cache_dir = local_cache_dir + + # self.test_url = ( + # "https://github.com/wiederm/gm9/raw/main/qm9_dataset_n100.hdf5.gz" + # ) + # to ensure we have the same checksum each time, we need to copy the permalink to the file + # otherwise the checksum will change each time we download the file from the same link because of how github works + # self.test_url = "https://github.com/wiederm/gm9/blob/264af75e41e6e296f400d9a1019f082b21d5bc36/qm9_dataset_n100.hdf5.gz" + self.test_url = "https://www.dropbox.com/scl/fi/9jeselknatcw9xi0qp940/qm9_dataset_n100.hdf5.gz?rlkey=50of7gn2s12i65c6j06r73c97&dl=1" + # self.test_url = "https://drive.google.com/file/d/1yE8krZo3MMI84unZH0_H01C4m5RkXJ8d/view?usp=sharing" + # self.full_url = "https://github.com/wiederm/gm9/raw/main/qm9.hdf5.gz" + + # to ensure we have the same checksum each time, we need to copy the permalink to the file + self.full_url = "https://www.dropbox.com/scl/fi/4wu7zlpuuixttp0u741rv/qm9_dataset.hdf5.gz?rlkey=nszkqt2t4kmghih5mt4ssppvo&dl=1" + # self.full_url = "https://github.com/wiederm/gm9/blob/264af75e41e6e296f400d9a1019f082b21d5bc36/qm9.hdf5.gz" self._ase = { "H": -1313.4668615546, "C": -99366.70745535441, @@ -105,6 +122,34 @@ def __init__( "O": -197082.0671774158, "F": -261811.54555874597, } + from loguru import logger + + if self.for_unit_testing: + + self.url = self.test_url + self.md5_raw_checksum = "af3afda5c3265c9c096935ab060f537a" + self.raw_data_file = "qm9_dataset_n100.hdf5.gz" + + # define the name and checksum of the unzipped file + self.unzipped_data_file = "qm9_dataset_n100.hdf5" + self.md5_unzipped_checksum = "77df0e1df7a5ec5629be52181e82a7d7" + + self.processed_data_file = "qm9_dataset_n100_processed.npz" + self.md5_processed_checksum = "9d671b54f7b9d454db9a3dd7f4ef2020" + logger.info("Using test dataset") + + else: + self.url = self.full_url + self.md5_raw_checksum = "d172127848de114bd9cc47da2bc72566" + self.raw_data_file = "qm9_dataset.hdf5.gz" + + self.unzipped_data_file = "qm9_dataset.hdf5" + self.md5_unzipped_checksum = "0b22dc048f3361875889f832527438db" + + self.processed_data_file = "qm9_dataset_processed.npz" + self.md5_processed_checksum = "62d98cf38bcf02966e1fa2d9e44b3fa0" + + logger.info("Using full dataset") @property def atomic_self_energies(self): @@ -178,7 +223,16 @@ def _download(self) -> None: >>> data.download() # Downloads the dataset from Google Drive """ - from modelforge.dataset.utils import _download_from_url - - url = self.test_url if self.for_unit_testing else self.full_url - _download_from_url(url, self.raw_data_file) + from modelforge.utils.remote import download_from_url + + download_from_url( + url=self.url, + md5_checksum=self.md5_raw_checksum, + output_path=self.local_cache_dir, + output_filename=self.raw_data_file, + force_download=self.force_download, + ) + # from modelforge.dataset.utils import _download_from_url + # + # url = self.test_url if self.for_unit_testing else self.full_url + # _download_from_url(url, self.raw_data_file) diff --git a/scripts/dataset_curation.py b/scripts/dataset_curation.py index fd953341..d3ad9d1c 100644 --- a/scripts/dataset_curation.py +++ b/scripts/dataset_curation.py @@ -183,6 +183,7 @@ def QM9( output_file_dir: str, local_cache_dir: str, force_download: bool = False, + unit_testing_max_records=None, ): """ This fetches and process the QM9 dataset into a curated hdf5 file. @@ -218,7 +219,13 @@ def QM9( output_file_dir=output_file_dir, local_cache_dir=local_cache_dir, ) - qm9.process(force_download=force_download) + if unit_testing_max_records is None: + qm9.process(force_download=force_download) + else: + qm9.process( + force_download=force_download, + unit_testing_max_records=unit_testing_max_records, + ) def ANI1x( @@ -320,28 +327,41 @@ def ANI2x( # define the local path prefix local_prefix = "/Users/cri/Documents/Projects-msk/datasets" - -# we will save all the files to a central location output_file_dir = f"{local_prefix}/hdf5_files" -# SPICE 2 dataset -local_cache_dir = f"{local_prefix}/spice2_dataset" -hdf5_file_name = "spice_2_dataset.hdf5" - -SPICE_2(hdf5_file_name, output_file_dir, local_cache_dir, force_download=False) - -# SPICE 1.1.4 OpenFF dataset -local_cache_dir = f"{local_prefix}/spice_openff_dataset" -hdf5_file_name = "spice_114_openff_dataset.hdf5" - -SPICE_114_OpenFF(hdf5_file_name, output_file_dir, local_cache_dir, force_download=False) +# # QM9 dataset +local_cache_dir = f"{local_prefix}/qm9_dataset" +hdf5_file_name = "qm9_dataset.hdf5" +QM9( + hdf5_file_name, + output_file_dir, + local_cache_dir, + force_download=False, + unit_testing_max_records=100, +) -# SPICE 1.1.4 dataset -local_cache_dir = f"{local_prefix}/spice_114_dataset" -hdf5_file_name = "spice_114_dataset.hdf5" -SPICE_114(hdf5_file_name, output_file_dir, local_cache_dir, force_download=False) +# we will save all the files to a central location +# output_file_dir = f"{local_prefix}/hdf5_files" +# +# # SPICE 2 dataset +# local_cache_dir = f"{local_prefix}/spice2_dataset" +# hdf5_file_name = "spice_2_dataset.hdf5" +# +# SPICE_2(hdf5_file_name, output_file_dir, local_cache_dir, force_download=False) +# +# # SPICE 1.1.4 OpenFF dataset +# local_cache_dir = f"{local_prefix}/spice_openff_dataset" +# hdf5_file_name = "spice_114_openff_dataset.hdf5" +# +# SPICE_114_OpenFF(hdf5_file_name, output_file_dir, local_cache_dir, force_download=False) +# +# # SPICE 1.1.4 dataset +# local_cache_dir = f"{local_prefix}/spice_114_dataset" +# hdf5_file_name = "spice_114_dataset.hdf5" # +# SPICE_114(hdf5_file_name, output_file_dir, local_cache_dir, force_download=False) +# # # # QM9 dataset # local_cache_dir = f"{local_prefix}/qm9_dataset" # hdf5_file_name = "qm9_dataset.hdf5" diff --git a/scripts/training.py b/scripts/training.py index cc0f9eba..d7a9fa38 100644 --- a/scripts/training.py +++ b/scripts/training.py @@ -10,7 +10,7 @@ logger = TensorBoardLogger("tb_logs", name="training") # Set up dataset -data = QM9Dataset() +data = QM9Dataset(force_download=True, for_unit_testing=True) dataset = TorchDataModule( data, batch_size=512, splitting_strategy=FirstComeFirstServeSplittingStrategy() ) From 430121b8952e21a7e6b15b6c4e67825c0cd4bf05 Mon Sep 17 00:00:00 2001 From: chrisiacovella Date: Thu, 4 Apr 2024 11:43:05 -0700 Subject: [PATCH 03/37] ensuring local_cache_dir is being used in dataset.py --- modelforge/dataset/dataset.py | 69 ++++++++++++++++-------- modelforge/dataset/qm9.py | 99 +++++++++++++++++++++++------------ 2 files changed, 113 insertions(+), 55 deletions(-) diff --git a/modelforge/dataset/dataset.py b/modelforge/dataset/dataset.py index 9ae8ceb8..71c6ffee 100644 --- a/modelforge/dataset/dataset.py +++ b/modelforge/dataset/dataset.py @@ -226,26 +226,41 @@ class HDF5Dataset: """ def __init__( - self, # raw_data_file: str, processed_data_file: str, local_cache_dir: str + self, + url: str, + gz_data_file: Dict[str, str], + hdf5_data_file: Dict[str, str], + processed_data_file: Dict[str, str], + local_cache_dir: str, + force_download: bool = False, ): """ Initializes the HDF5Dataset with paths to raw and processed data files. Parameters ---------- - raw_data_file : str - Path to the raw HDF5 data file. - processed_data_file : str - Path to the processed data file. + url : str + URL of the hdf5.gz data file. + gz_data_file : Dict[str, str] + Name of the gzipped data file (name) and checksum (md5). + hdf5_data_file : Dict[str, str] + Name of the hdf5 data file (name) and checksum (md5). + processed_data_file : Dict[str, str] + Name of the processed npz data file (name) and checksum (md5). local_cache_dir : str - Directory to store temporary processing files. + Directory to store the files. + force_download : bool, optional + If set to True, the data will be downloaded even if it already exists. Default is False. """ + self.url = url + self.gz_data_file = gz_data_file + self.hdf5_data_file = hdf5_data_file + self.processed_data_file = processed_data_file + self.local_cache_dir = local_cache_dir + self.force_download = force_download - # self.raw_data_file = raw_data_file - # self.processed_data_file = processed_data_file self.hdf5data: Optional[Dict[str, List[np.ndarray]]] = None self.numpy_data: Optional[np.ndarray] = None - # self.local_cache_dir = local_cache_dir def _ungzip_hdf5(self) -> None: """ @@ -257,8 +272,12 @@ def _ungzip_hdf5(self) -> None: import gzip import shutil - with gzip.open(self.raw_data_file, "rb") as gz_file: - with open(self.unzipped_data_file, "wb") as out_file: + with gzip.open( + f"{self.local_cache_dir}/{self.gz_data_file['name']}", "rb" + ) as gz_file: + with open( + f"{self.local_cache_dir}/{self.hdf5_data_file['name']}", "wb" + ) as out_file: shutil.copyfileobj(gz_file, out_file) def _from_hdf5(self) -> None: @@ -279,18 +298,20 @@ def _from_hdf5(self) -> None: # this is substantially faster than passing gz_file directly to h5py.File() # by avoiding data chunking issues. - temp_hdf5_file = f"{self.local_cache_dir}/{self.unzipped_data_file}" + temp_hdf5_file = f"{self.local_cache_dir}/{self.hdf5_data_file['name']}" log.debug("Reading in and processing hdf5 file ...") from modelforge.utils.remote import calculate_md5_checksum # add in a check to make sure the checksum matches before loading the file # this also appears in the dataset factory, but having this also here is good if used as a standalone function - checksum = calculate_md5_checksum(self.unzipped_data_file, self.local_cache_dir) + checksum = calculate_md5_checksum( + self.hdf5_data_file["name"], self.local_cache_dir + ) - if checksum != self.md5_unzipped_checksum: + if checksum != self.hdf5_data_file["md5"]: raise ValueError( - f"Checksum mismatch for unzipped data file {temp_hdf5_file}. Found {checksum}, Expected {self.md5_unzipped_checksum}" + f"Checksum mismatch for unzipped data file {temp_hdf5_file}. Found {checksum}, Expected {self.hdf5_data_file['md5']}" ) with h5py.File(temp_hdf5_file, "r") as hf: @@ -428,14 +449,16 @@ def _from_file_cache(self) -> None: from modelforge.utils.remote import calculate_md5_checksum checksum = calculate_md5_checksum( - self.processed_data_file, self.local_cache_dir + self.processed_data_file["name"], self.local_cache_dir ) - if checksum != self.md5_processed_checksum: + if checksum != self.processed_data_file["md5"]: raise ValueError( - f"Checksum mismatch for processed data file {self.processed_data_file}.Found {checksum}, expected {self.md5_processed_checksum}" + f"Checksum mismatch for processed data file {self.processed_data_file}.Found {checksum}, expected {self.processed_data_file['md5']}" ) - log.debug(f"Loading processed data from {self.processed_data_file}") - self.numpy_data = np.load(self.processed_data_file) + log.debug(f"Loading processed data from {self.processed_data_file['name']}") + self.numpy_data = np.load( + f"{self.local_cache_dir}/{self.processed_data_file['name']}" + ) def _to_file_cache( self, @@ -450,10 +473,12 @@ def _to_file_cache( >>> hdf5_data = HDF5Dataset("raw_data.hdf5", "processed_data.npz") >>> hdf5_data._to_file_cache() """ - log.debug(f"Writing data cache to {self.processed_data_file}") + log.debug( + f"Writing npz file to {self.local_cache_dir}/{self.processed_data_file['name']}" + ) np.savez( - self.processed_data_file, + f"{self.local_cache_dir}/{self.processed_data_file['name']}", atomic_subsystem_counts=self.atomic_subsystem_counts, n_confs=self.n_confs, **self.hdf5data, diff --git a/modelforge/dataset/qm9.py b/modelforge/dataset/qm9.py index 05327119..1e38e493 100644 --- a/modelforge/dataset/qm9.py +++ b/modelforge/dataset/qm9.py @@ -79,8 +79,6 @@ def __init__( >>> test_data = QM9Dataset(for_unit_testing=True) # Testing subset """ - self.force_download = force_download - _default_properties_of_interest = [ "geometry", "atomic_numbers", @@ -92,7 +90,6 @@ def __init__( if for_unit_testing: dataset_name = f"{dataset_name}_subset" - super().__init__() # super().__init__( # f"{local_cache_dir}/{dataset_name}_cache.hdf5.gz", # f"{local_cache_dir}/{dataset_name}_processed.npz", @@ -100,21 +97,12 @@ def __init__( # ) self.dataset_name = dataset_name self.for_unit_testing = for_unit_testing - self.local_cache_dir = local_cache_dir + # self.local_cache_dir = local_cache_dir - # self.test_url = ( - # "https://github.com/wiederm/gm9/raw/main/qm9_dataset_n100.hdf5.gz" - # ) - # to ensure we have the same checksum each time, we need to copy the permalink to the file - # otherwise the checksum will change each time we download the file from the same link because of how github works - # self.test_url = "https://github.com/wiederm/gm9/blob/264af75e41e6e296f400d9a1019f082b21d5bc36/qm9_dataset_n100.hdf5.gz" + # note, need to change the end of the url to dl=1 instead of dl=0 (the default when you grab the share list), to ensure the same checksum each time we download self.test_url = "https://www.dropbox.com/scl/fi/9jeselknatcw9xi0qp940/qm9_dataset_n100.hdf5.gz?rlkey=50of7gn2s12i65c6j06r73c97&dl=1" - # self.test_url = "https://drive.google.com/file/d/1yE8krZo3MMI84unZH0_H01C4m5RkXJ8d/view?usp=sharing" - # self.full_url = "https://github.com/wiederm/gm9/raw/main/qm9.hdf5.gz" - # to ensure we have the same checksum each time, we need to copy the permalink to the file self.full_url = "https://www.dropbox.com/scl/fi/4wu7zlpuuixttp0u741rv/qm9_dataset.hdf5.gz?rlkey=nszkqt2t4kmghih5mt4ssppvo&dl=1" - # self.full_url = "https://github.com/wiederm/gm9/blob/264af75e41e6e296f400d9a1019f082b21d5bc36/qm9.hdf5.gz" self._ase = { "H": -1313.4668615546, "C": -99366.70745535441, @@ -124,33 +112,76 @@ def __init__( } from loguru import logger + # We need to define the checksums for the various files that we will be dealing with to load up the data + # There are 3 files types that need name/checksum defined, of extensions hdf5.gz, hdf5, and npz. + if self.for_unit_testing: - self.url = self.test_url - self.md5_raw_checksum = "af3afda5c3265c9c096935ab060f537a" - self.raw_data_file = "qm9_dataset_n100.hdf5.gz" + url = self.test_url + gz_data_file = { + "name": "qm9_dataset_n100.hdf5.gz", + "md5": "af3afda5c3265c9c096935ab060f537a", + } + hdf5_data_file = { + "name": "qm9_dataset_n100.hdf5", + "md5": "77df0e1df7a5ec5629be52181e82a7d7", + } + processed_data_file = { + "name": "qm9_dataset_n100_processed.npz", + "md5": "9d671b54f7b9d454db9a3dd7f4ef2020", + } + + # self.md5_raw_checksum = "af3afda5c3265c9c096935ab060f537a" + # self.raw_data_file = "qm9_dataset_n100.hdf5.gz" # define the name and checksum of the unzipped file - self.unzipped_data_file = "qm9_dataset_n100.hdf5" - self.md5_unzipped_checksum = "77df0e1df7a5ec5629be52181e82a7d7" - self.processed_data_file = "qm9_dataset_n100_processed.npz" - self.md5_processed_checksum = "9d671b54f7b9d454db9a3dd7f4ef2020" - logger.info("Using test dataset") + # self.unzipped_data_file = "qm9_dataset_n100.hdf5" + # self.md5_unzipped_checksum = "77df0e1df7a5ec5629be52181e82a7d7" - else: - self.url = self.full_url - self.md5_raw_checksum = "d172127848de114bd9cc47da2bc72566" - self.raw_data_file = "qm9_dataset.hdf5.gz" + # self.processed_data_file = "qm9_dataset_n100_processed.npz" + # self.md5_processed_checksum = "9d671b54f7b9d454db9a3dd7f4ef2020" - self.unzipped_data_file = "qm9_dataset.hdf5" - self.md5_unzipped_checksum = "0b22dc048f3361875889f832527438db" + logger.info("Using test dataset") - self.processed_data_file = "qm9_dataset_processed.npz" - self.md5_processed_checksum = "62d98cf38bcf02966e1fa2d9e44b3fa0" + else: + url = self.full_url + gz_data_file = { + "name": "qm9_dataset.hdf5.gz", + "md5": "d172127848de114bd9cc47da2bc72566", + } + + hdf5_data_file = { + "name": "qm9_dataset.hdf5", + "md5": "0b22dc048f3361875889f832527438db", + } + + processed_data_file = { + "name": "qm9_dataset_processed.npz", + "md5": "62d98cf38bcf02966e1fa2d9e44b3fa0", + } + + # self.md5_raw_checksum = "d172127848de114bd9cc47da2bc72566" + # self.raw_data_file = "qm9_dataset.hdf5.gz" + # + # self.unzipped_data_file = "qm9_dataset.hdf5" + # self.md5_unzipped_checksum = "0b22dc048f3361875889f832527438db" + # + # self.processed_data_file = "qm9_dataset_processed.npz" + # self.md5_processed_checksum = "62d98cf38bcf02966e1fa2d9e44b3fa0" logger.info("Using full dataset") + # to ensure that that we are consistent in our naming, we need to set all the names and checksums in the HDF5Dataset class constructor + super().__init__( + url=url, + gz_data_file=gz_data_file, + hdf5_data_file=hdf5_data_file, + processed_data_file=processed_data_file, + local_cache_dir=local_cache_dir, + force_download=force_download, + ) + @property def atomic_self_energies(self): from modelforge.potential.utils import AtomicSelfEnergies @@ -215,7 +246,7 @@ def properties_of_interest(self, properties_of_interest: List[str]) -> None: def _download(self) -> None: """ - Download the hdf5 file containing the data from Google Drive. + Download the hdf5 file containing the data from Dropbox. Examples -------- @@ -223,13 +254,15 @@ def _download(self) -> None: >>> data.download() # Downloads the dataset from Google Drive """ + # Right now this function needs to be defined for each dataset. + # once all datasets are moved to zenodo, we should only need a single function defined in the base class from modelforge.utils.remote import download_from_url download_from_url( url=self.url, - md5_checksum=self.md5_raw_checksum, + md5_checksum=self.gz_data_file["md5"], output_path=self.local_cache_dir, - output_filename=self.raw_data_file, + output_filename=self.gz_data_file["name"], force_download=self.force_download, ) # from modelforge.dataset.utils import _download_from_url From 446afd0f123028128ecca55a10fe500b36442df4 Mon Sep 17 00:00:00 2001 From: chrisiacovella Date: Thu, 4 Apr 2024 16:19:10 -0700 Subject: [PATCH 04/37] Adding tests for caching --- modelforge/dataset/dataset.py | 157 +++++++++++++++++-------------- modelforge/tests/test_dataset.py | 155 +++++++++++++++++++++++++++--- modelforge/tests/test_remote.py | 3 +- 3 files changed, 226 insertions(+), 89 deletions(-) diff --git a/modelforge/dataset/dataset.py b/modelforge/dataset/dataset.py index 71c6ffee..0ce604f9 100644 --- a/modelforge/dataset/dataset.py +++ b/modelforge/dataset/dataset.py @@ -280,6 +280,39 @@ def _ungzip_hdf5(self) -> None: ) as out_file: shutil.copyfileobj(gz_file, out_file) + def _file_validation(self, file_name: str, file_path: str, checksum: str) -> bool: + """ + Validates if the file exists, and if the calculated checksum matches the expected checksum. + + Parameters + ---------- + file_name : str + Name of the file to validate. + file_path : str + Path to the file. + checksum : str + Expected checksum of the file. + + Returns + ------- + bool + True if the file exists and the checksum matches, False otherwise. + """ + full_file_path = f"{file_path}/{file_name}" + if not os.path.exists(full_file_path): + log.debug(f"File {full_file_path} does not exist.") + return False + else: + from modelforge.utils.remote import calculate_md5_checksum + + calculated_checksum = calculate_md5_checksum(file_name, file_path) + if calculated_checksum != checksum: + log.warning( + f"Checksum mismatch for file {file_path}/{file_name}. Expected {calculated_checksum}, found {checksum}." + ) + return False + return True + def _from_hdf5(self) -> None: """ Processes and extracts data from an hdf5 file. @@ -300,16 +333,18 @@ def _from_hdf5(self) -> None: temp_hdf5_file = f"{self.local_cache_dir}/{self.hdf5_data_file['name']}" - log.debug("Reading in and processing hdf5 file ...") - from modelforge.utils.remote import calculate_md5_checksum - - # add in a check to make sure the checksum matches before loading the file - # this also appears in the dataset factory, but having this also here is good if used as a standalone function - checksum = calculate_md5_checksum( - self.hdf5_data_file["name"], self.local_cache_dir - ) + if self._file_validation( + self.hdf5_data_file["name"], + self.local_cache_dir, + self.hdf5_data_file["md5"], + ): + log.debug(f"Loading unzipped hdf5 file from {temp_hdf5_file}") + else: + from modelforge.utils.remote import calculate_md5_checksum - if checksum != self.hdf5_data_file["md5"]: + checksum = calculate_md5_checksum( + self.hdf5_data_file["name"], self.local_cache_dir + ) raise ValueError( f"Checksum mismatch for unzipped data file {temp_hdf5_file}. Found {checksum}, Expected {self.hdf5_data_file['md5']}" ) @@ -446,15 +481,23 @@ def _from_file_cache(self) -> None: >>> hdf5_data = HDF5Dataset("raw_data.hdf5", "processed_data.npz") >>> processed_data = hdf5_data._from_file_cache() """ - from modelforge.utils.remote import calculate_md5_checksum + if self._file_validation( + self.processed_data_file["name"], + self.local_cache_dir, + self.processed_data_file["md5"], + ): + log.debug(f"Loading processed data from {self.processed_data_file['name']}") - checksum = calculate_md5_checksum( - self.processed_data_file["name"], self.local_cache_dir - ) - if checksum != self.processed_data_file["md5"]: + else: + from modelforge.utils.remote import calculate_md5_checksum + + checksum = calculate_md5_checksum( + self.processed_data_file["name"], self.local_cache_dir + ) raise ValueError( f"Checksum mismatch for processed data file {self.processed_data_file}.Found {checksum}, expected {self.processed_data_file['md5']}" ) + log.debug(f"Loading processed data from {self.processed_data_file['name']}") self.numpy_data = np.load( f"{self.local_cache_dir}/{self.processed_data_file['name']}" @@ -517,69 +560,37 @@ def _load_or_process_data( The HDF5 dataset instance to use. """ - import os - from modelforge.utils.remote import calculate_md5_checksum - - # if the dataset was initialize with force_download, we will skip all other checking and just download and process - if data.force_download: + # check to see if we can load from the npz file. This also validates the checksum + if ( + data._file_validation( + data.processed_data_file["name"], + data.local_cache_dir, + data.processed_data_file["md5"], + ) + and not data.force_download + ): + data._from_file_cache() + # check to see if the hdf5 file exists and the checksum matches + elif ( + data._file_validation( + data.hdf5_data_file["name"], + data.local_cache_dir, + data.hdf5_data_file["md5"], + ) + and not data.force_download + ): + data._from_hdf5() + data._to_file_cache() + data._from_file_cache() + # if the npz or hdf5 files don't exist/match checksums, call download + # download will check if the gz file exists and matches the checksum + # or will use force_download. + else: data._download() data._ungzip_hdf5() data._from_hdf5() data._to_file_cache() - else: - file_loaded = False - if os.path.exists(data.processed_data_file): - checksum = calculate_md5_checksum( - data.processed_data_file, data.local_cache_dir - ) - if checksum != data.md5_processed_checksum: - - log.warning( - f"Checksum mismatch for processed data file {data.processed_data_file}. Re-processing." - ) - log.debug( - f"Checksum mismatch, found {checksum}, expected {data.md5_processed_checksum}" - ) - else: - data._from_file_cache() - file_loaded = True - - if os.path.exists(data.unzipped_data_file) and not file_loaded: - checksum = calculate_md5_checksum( - data.unzipped_data_file, data.local_cache_dir - ) - if checksum != data.md5_unzipped_checksum: - log.warning( - f"Checksum mismatch for unzipped data file {data.unzipped_data_file}. Re-processing." - ) - log.debug( - f"Checksum mismatch, found {checksum}, expected {data.md5_unzipped_checksum}" - ) - - # the download function automatically checks if the file is already downloaded and if the checksum matches - # if either are false, it will redownload the file - data._download() - data._ungzip_hdf5() - checksum = calculate_md5_checksum( - data.unzipped_data_file, data.local_cache_dir - ) - if checksum != data.md5_unzipped_checksum: - raise ValueError( - f"Checksum mismatch for unzipped data file {data.unzipped_data_file}. Expected {data.md5_unzipped_checksum}, found {checksum}. Please check the raw gzipped file or try running with force_download." - ) - - data._from_hdf5() - data._to_file_cache() - else: - data._from_hdf5() - data._to_file_cache() - else: - data._download() - data._ungzip_hdf5() - data._from_hdf5() - data._to_file_cache() - - data._from_file_cache() + data._from_file_cache() @staticmethod def create_dataset( diff --git a/modelforge/tests/test_dataset.py b/modelforge/tests/test_dataset.py index a4e3400f..602add8c 100644 --- a/modelforge/tests/test_dataset.py +++ b/modelforge/tests/test_dataset.py @@ -12,6 +12,12 @@ from ..utils import PropertyNames +@pytest.fixture(scope="session") +def prep_temp_dir(tmp_path_factory): + fn = tmp_path_factory.mktemp("dataset_test") + return fn + + @pytest.fixture( autouse=True, ) @@ -155,34 +161,151 @@ def test_different_properties_of_interest(dataset): @pytest.mark.parametrize("dataset", DATASETS) -def test_file_existence_after_initialization(dataset): +def test_file_existence_after_initialization(dataset, prep_temp_dir): """Test if files are created after dataset initialization.""" + + local_cache_dir = str(prep_temp_dir) + factory = DatasetFactory() - data = dataset(for_unit_testing=True) + data = dataset(for_unit_testing=True, local_cache_dir=local_cache_dir) - assert not os.path.exists(data.raw_data_file) - assert not os.path.exists(data.processed_data_file) + assert not os.path.exists(f"{local_cache_dir}/{data.gz_data_file['name']}") + assert not os.path.exists(f"{local_cache_dir}/{data.hdf5_data_file['name']}") + assert not os.path.exists(f"{local_cache_dir}/{data.processed_data_file['name']}") dataset = factory.create_dataset(data) - assert os.path.exists(data.raw_data_file) - assert os.path.exists(data.processed_data_file) + assert os.path.exists(f"{local_cache_dir}/{data.gz_data_file['name']}") + assert os.path.exists(f"{local_cache_dir}/{data.hdf5_data_file['name']}") + assert os.path.exists(f"{local_cache_dir}/{data.processed_data_file['name']}") + + +def test_caching(prep_temp_dir): + local_cache_dir = str(prep_temp_dir) + local_cache_dir = local_cache_dir + "/data_test" + from modelforge.dataset.qm9 import QM9Dataset + + data = QM9Dataset(for_unit_testing=True, local_cache_dir=local_cache_dir) + + # first test that no file exists + assert not os.path.exists(f"{local_cache_dir}/{data.gz_data_file['name']}") + # the _file_validation method also checks the path in addition to the checksum + assert ( + data._file_validation( + data.gz_data_file["name"], local_cache_dir, data.gz_data_file["md5"] + ) + == False + ) + + data._download() + # check that the file exists + assert os.path.exists(f"{local_cache_dir}/{data.gz_data_file['name']}") + # check that the file is there and has the right checksum + assert ( + data._file_validation( + data.gz_data_file["name"], local_cache_dir, data.gz_data_file["md5"] + ) + == True + ) + + # give a random checksum to see this is false + assert ( + data._file_validation( + data.gz_data_file["name"], local_cache_dir, "madeupcheckusm" + ) + == False + ) + # make sure that if we run again we don't fail + data._download() + # remove the file and check that it is downloaded again + os.remove(f"{local_cache_dir}/{data.gz_data_file['name']}") + data._download() + + # check that the file is unzipped + data._ungzip_hdf5() + assert os.path.exists(f"{local_cache_dir}/{data.hdf5_data_file['name']}") + assert ( + data._file_validation( + data.hdf5_data_file["name"], local_cache_dir, data.hdf5_data_file["md5"] + ) + == True + ) + data._from_hdf5() + + data._to_file_cache() + + assert os.path.exists(f"{local_cache_dir}/{data.processed_data_file['name']}") + assert ( + data._file_validation( + data.processed_data_file["name"], + local_cache_dir, + data.processed_data_file["md5"], + ) + == True + ) + + data._from_file_cache() @pytest.mark.parametrize("dataset", DATASETS) -def test_different_scenarios_of_file_availability(dataset): +def test_different_scenarios_of_file_availability(dataset, prep_temp_dir): """Test the behavior when raw and processed dataset files are removed.""" + + local_cache_dir = str(prep_temp_dir) + "/test_diff_secnarios" + factory = DatasetFactory() - data = dataset(for_unit_testing=True) + data = dataset(for_unit_testing=True, local_cache_dir=local_cache_dir) + # this will download the .gz, the .hdf5 and the .npz files factory.create_dataset(data) - os.remove(data.raw_data_file) + # first check if we remote the npz file, rerunning it will regenerated it + os.remove(f"{local_cache_dir}/{data.processed_data_file['name']}") factory.create_dataset(data) - os.remove(data.processed_data_file) + assert os.path.exists(f"{local_cache_dir}/{data.processed_data_file['name']}") + + # now remove the npz and hdf5 files, rerunning will generate it + + os.remove(f"{local_cache_dir}/{data.processed_data_file['name']}") + os.remove(f"{local_cache_dir}/{data.hdf5_data_file['name']}") factory.create_dataset(data) - assert os.path.exists(data.raw_data_file) - assert os.path.exists(data.processed_data_file) + + assert os.path.exists(f"{local_cache_dir}/{data.processed_data_file['name']}") + assert os.path.exists(f"{local_cache_dir}/{data.hdf5_data_file['name']}") + + # now remove the gz file; rerunning should NOT download, it will use the npz + os.remove(f"{local_cache_dir}/{data.gz_data_file['name']}") + + factory.create_dataset(data) + assert not os.path.exists(f"{local_cache_dir}/{data.gz_data_file['name']}") + + # now let us remove the hdf5 file, it should use the npz file + os.remove(f"{local_cache_dir}/{data.hdf5_data_file['name']}") + factory.create_dataset(data) + assert not os.path.exists(f"{local_cache_dir}/{data.hdf5_data_file['name']}") + + # now if we remove the npz, it will redownload the gz file and unzip it, then process it + os.remove(f"{local_cache_dir}/{data.processed_data_file['name']}") + factory.create_dataset(data) + assert os.path.exists(f"{local_cache_dir}/{data.gz_data_file['name']}") + assert os.path.exists(f"{local_cache_dir}/{data.hdf5_data_file['name']}") + assert os.path.exists(f"{local_cache_dir}/{data.processed_data_file['name']}") + + # now we will remove the gz file, and set force_download to True + # this should now regenerate the gz file, even though others are present + + data = dataset( + for_unit_testing=True, local_cache_dir=local_cache_dir, force_download=True + ) + factory.create_dataset(data) + assert os.path.exists(f"{local_cache_dir}/{data.gz_data_file['name']}") + assert os.path.exists(f"{local_cache_dir}/{data.hdf5_data_file['name']}") + assert os.path.exists(f"{local_cache_dir}/{data.processed_data_file['name']}") + + # now we will remove the gz file and run it again + os.remove(f"{local_cache_dir}/{data.gz_data_file['name']}") + factory.create_dataset(data) + assert os.path.exists(f"{local_cache_dir}/{data.gz_data_file['name']}") @pytest.mark.parametrize("dataset", DATASETS) @@ -284,13 +407,15 @@ def test_file_cache_methods(dataset): @pytest.mark.parametrize("dataset", DATASETS) -def test_dataset_downloader(dataset): +def test_dataset_downloader(dataset, prep_temp_dir): """ Test the DatasetDownloader functionality. """ - data = dataset(for_unit_testing=True) + local_cache_dir = str(prep_temp_dir) + + data = dataset(for_unit_testing=True, local_cache_dir=local_cache_dir) data._download() - assert os.path.exists(data.raw_data_file) + assert os.path.exists(f"{local_cache_dir}/{data.gz_data_file['name']}") @pytest.mark.parametrize("dataset", DATASETS) diff --git a/modelforge/tests/test_remote.py b/modelforge/tests/test_remote.py index 8e314b56..e6aaf4bd 100644 --- a/modelforge/tests/test_remote.py +++ b/modelforge/tests/test_remote.py @@ -59,8 +59,9 @@ def test_download_from_url(prep_temp_dir): assert os.path.isfile(file_name_path) # let us change the expected checksum to cause a failure + # this will see this as not matching and will redownload, + # but since the new file doesn't match it will raise an exception with pytest.raises(Exception): - url = "https://choderalab.com/modelforge.py" download_from_url( url, md5_checksum="checksum_garbage", From a92511cba0e8ed430ef27c0ef7c60599f831c9eb Mon Sep 17 00:00:00 2001 From: chrisiacovella Date: Fri, 5 Apr 2024 01:02:04 -0700 Subject: [PATCH 05/37] preliminary ani2x dataloader. --- modelforge/curation/ani2x_curation.py | 2 +- modelforge/dataset/ani2x.py | 232 ++++++++++++++++++++++++++ modelforge/dataset/dataset.py | 27 ++- modelforge/dataset/qm9.py | 37 +--- modelforge/dataset/utils.py | 22 --- modelforge/utils/prop.py | 5 +- 6 files changed, 260 insertions(+), 65 deletions(-) create mode 100644 modelforge/dataset/ani2x.py diff --git a/modelforge/curation/ani2x_curation.py b/modelforge/curation/ani2x_curation.py index 6db4d65a..d12cff50 100644 --- a/modelforge/curation/ani2x_curation.py +++ b/modelforge/curation/ani2x_curation.py @@ -9,7 +9,7 @@ class ANI2xCuration(DatasetCuration): Routines to fetch and process the ANI-2x dataset into a curated hdf5 file. The ANI-2x data set includes properties for small organic molecules that contain - H, C, N, O, S, F, and Cl. This dataset contains 9651712 conformers for 200,000 + H, C, N, O, S, F, and Cl. This dataset contains 9651712 conformers for nearly 200,000 molecules. This will fetch data generated with the wB97X/631Gd level of theory used in the original ANI-2x paper, calculated using Gaussian 09 diff --git a/modelforge/dataset/ani2x.py b/modelforge/dataset/ani2x.py new file mode 100644 index 00000000..587d539e --- /dev/null +++ b/modelforge/dataset/ani2x.py @@ -0,0 +1,232 @@ +from typing import List + +from .dataset import HDF5Dataset + + +class ANI2xDataset(HDF5Dataset): + """ + Data class for handling ANI2x data. + + This class provides utilities for processing the ANI2x dataset stored in the modelforge HDF5 format. + + The ANI-2x data set includes properties for small organic molecules that contain + H, C, N, O, S, F, and Cl. This dataset contains 9651712 conformers for nearly 200,000 molecules. + This will fetch data generated with the wB97X/631Gd level of theory used in the original ANI-2x paper, + calculated using Gaussian 09. See ani2x_curation.py for more details on the dataset curation. + + Citation: Devereux, C, Zubatyuk, R., Smith, J. et al. + "Extending the applicability of the ANI deep learning molecular potential to sulfur and halogens." + Journal of Chemical Theory and Computation 16.7 (2020): 4192-4202. + https://doi.org/10.1021/acs.jctc.0c00121 + + DOI for dataset: 10.5281/zenodo.10108941 + + Attributes + ---------- + dataset_name : str + Name of the dataset, default is "ANI2x". + for_unit_testing : bool + If set to True, a subset of the dataset is used for unit testing purposes; by default False. + local_cache_dir: str, optional + Path to the local cache directory, by default ".". + Examples + -------- + + """ + + from modelforge.utils import PropertyNames + + _property_names = PropertyNames( + Z="atomic_numbers", R="geometry", E="energies", F="forces" + ) + + _available_properties = [ + "geometry", + "atomic_numbers", + "energies", + "forces", + ] # All properties within the datafile, aside from SMILES/inchi. + + def __init__( + self, + dataset_name: str = "ANI2x", + for_unit_testing: bool = False, + local_cache_dir: str = ".", + force_download: bool = False, + overwrite: bool = False, + ) -> None: + """ + Initialize the QM9Data class. + + Parameters + ---------- + data_name : str, optional + Name of the dataset, by default "ANI2x". + for_unit_testing : bool, optional + If set to True, a subset of the dataset is used for unit testing purposes; by default False. + local_cache_dir: str, optional + Path to the local cache directory, by default ".". + force_download: bool, optional + If set to True, we will download the dataset even if it already exists; by default False. + + Examples + -------- + >>> data = ANI2xDataset() # Default dataset + >>> test_data = ANI2xDataset(for_unit_testing=True) # Testing subset + """ + + _default_properties_of_interest = [ + "geometry", + "atomic_numbers", + "energies", + "forces", + ] # NOTE: Default values + + self._properties_of_interest = _default_properties_of_interest + if for_unit_testing: + dataset_name = f"{dataset_name}_subset" + + self.dataset_name = dataset_name + self.for_unit_testing = for_unit_testing + + # self._ase = { + # "H": -1313.4668615546, + # "C": -99366.70745535441, + # "N": -143309.9379722722, + # "O": -197082.0671774158, + # "F": -261811.54555874597, + # } + from loguru import logger + + # We need to define the checksums for the various files that we will be dealing with to load up the data + # There are 3 files types that need name/checksum defined, of extensions hdf5.gz, hdf5, and npz. + + # note, need to change the end of the url to dl=1 instead of dl=0 (the default when you grab the share list), to ensure the same checksum each time we download + self.test_url = "https://www.dropbox.com/scl/fi/okv311e9yvh94owbiypcm/ani2x_dataset_n100.hdf5.gz?rlkey=pz7gnlncabtzr3b82lblr3yas&dl=1" + self.full_url = "https://www.dropbox.com/scl/fi/egg04dmtho7l1ghqiwn1z/ani2x_dataset.hdf5.gz?rlkey=wq5qjyph5q2k0bn6vza735n19&dl=1" + + if self.for_unit_testing: + url = self.test_url + gz_data_file = { + "name": "ani2x_dataset_n100.hdf5.gz", + "md5": "093fa23aeb8f8813abd1ec08e9ff83ad", + } + hdf5_data_file = { + "name": "ani2x_dataset_n100.hdf5", + "md5": "4f54caf79e4c946dc3d6d53722d2b966", + } + processed_data_file = { + "name": "ani2x_dataset_n100_processed.npz", + "md5": "c1481fe9a6b15fb07b961d15411c0ddd", + } + + logger.info("Using test dataset") + + else: + url = self.full_url + gz_data_file = { + "name": "ani2x_dataset.hdf5.gz", + "md5": "8daf9a7d8bbf9bcb1e9cea13b4df9270", + } + + hdf5_data_file = { + "name": "ani2x_dataset.hdf5", + "md5": "86bb855cb8df54e082506088e949518e", + } + + processed_data_file = { + "name": "ani2x_dataset_processed.npz", + "md5": "268438d8e1660728ba892bc7c3cd4339", + } + + logger.info("Using full dataset") + + # to ensure that that we are consistent in our naming, we need to set all the names and checksums in the HDF5Dataset class constructor + super().__init__( + url=url, + gz_data_file=gz_data_file, + hdf5_data_file=hdf5_data_file, + processed_data_file=processed_data_file, + local_cache_dir=local_cache_dir, + force_download=force_download, + ) + + @property + def atomic_self_energies(self): + from modelforge.potential.utils import AtomicSelfEnergies + + return AtomicSelfEnergies() + + @property + def properties_of_interest(self) -> List[str]: + """ + Getter for the properties of interest. + The order of this list determines also the order provided in the __getitem__ call + from the PytorchDataset. + + Returns + ------- + List[str] + List of properties of interest. + + """ + return self._properties_of_interest + + @property + def available_properties(self) -> List[str]: + """ + List of available properties in the dataset. + + Returns + ------- + List[str] + List of available properties in the dataset. + + Examples + -------- + + """ + return self._available_properties + + @properties_of_interest.setter + def properties_of_interest(self, properties_of_interest: List[str]) -> None: + """ + Setter for the properties of interest. + The order of this list determines also the order provided in the __getitem__ call + from the PytorchDataset + + Parameters + ---------- + properties_of_interest : List[str] + List of properties of interest. + + Examples + -------- + + """ + if not set(properties_of_interest).issubset(self._available_properties): + raise ValueError( + f"Properties of interest must be a subset of {self._available_properties}" + ) + self._properties_of_interest = properties_of_interest + + def _download(self) -> None: + """ + Download the hdf5 file containing the data from Dropbox. + + Examples + -------- + + + """ + # Right now this function needs to be defined for each dataset. + # once all datasets are moved to zenodo, we should only need a single function defined in the base class + from modelforge.utils.remote import download_from_url + + download_from_url( + url=self.url, + md5_checksum=self.gz_data_file["md5"], + output_path=self.local_cache_dir, + output_filename=self.gz_data_file["name"], + force_download=self.force_download, + ) diff --git a/modelforge/dataset/dataset.py b/modelforge/dataset/dataset.py index 0ce604f9..008b4ea6 100644 --- a/modelforge/dataset/dataset.py +++ b/modelforge/dataset/dataset.py @@ -59,12 +59,23 @@ def __init__( If set to True, properties are preloaded as PyTorch tensors. Default is False. """ - self.properties_of_interest = { - "atomic_numbers": torch.from_numpy(dataset[property_name.Z]), - "positions": torch.from_numpy(dataset[property_name.R]), - "E": torch.from_numpy(dataset[property_name.E]), - "Q": torch.from_numpy(dataset[property_name.Q]), - } + self.properties_of_interest = {} + + self.properties_of_interest["atomic_numbers"] = torch.from_numpy( + dataset[property_name.Z] + ) + self.properties_of_interest["positions"] = torch.from_numpy( + dataset[property_name.R] + ) + self.properties_of_interest["E"] = torch.from_numpy(dataset[property_name.E]) + if property_name.Q is not None: + self.properties_of_interest["Q"] = torch.from_numpy( + dataset[property_name.Q] + ) + if property_name.F is not None: + self.properties_of_interest["F"] = torch.from_numpy( + dataset[property_name.F] + ) self.number_of_records = len(dataset["atomic_subsystem_counts"]) self.number_of_atoms = len(dataset["atomic_numbers"]) @@ -360,7 +371,7 @@ def _from_hdf5(self) -> None: series_atom_data: Dict[str, List[np.ndarray]] = OrderedDict() # value shapes: (n_confs, n_atoms, *) - # intialize each relevant value in data dicts to empty list + # initialize each relevant value in data dicts to empty list for value in self.properties_of_interest: value_format = hf[next(iter(hf.keys()))][value].attrs["format"] if value_format == "single_rec": @@ -495,7 +506,7 @@ def _from_file_cache(self) -> None: self.processed_data_file["name"], self.local_cache_dir ) raise ValueError( - f"Checksum mismatch for processed data file {self.processed_data_file}.Found {checksum}, expected {self.processed_data_file['md5']}" + f"Checksum mismatch for processed data file {self.processed_data_file['name']}. Found {checksum}, expected {self.processed_data_file['md5']}" ) log.debug(f"Loading processed data from {self.processed_data_file['name']}") diff --git a/modelforge/dataset/qm9.py b/modelforge/dataset/qm9.py index 1e38e493..862fed27 100644 --- a/modelforge/dataset/qm9.py +++ b/modelforge/dataset/qm9.py @@ -27,7 +27,7 @@ class QM9Dataset(HDF5Dataset): from modelforge.utils import PropertyNames _property_names = PropertyNames( - "atomic_numbers", "geometry", "internal_energy_at_0K", "charges" + Z="atomic_numbers", R="geometry", E="internal_energy_at_0K", Q="charges" ) _available_properties = [ @@ -90,19 +90,9 @@ def __init__( if for_unit_testing: dataset_name = f"{dataset_name}_subset" - # super().__init__( - # f"{local_cache_dir}/{dataset_name}_cache.hdf5.gz", - # f"{local_cache_dir}/{dataset_name}_processed.npz", - # local_cache_dir=local_cache_dir, - # ) self.dataset_name = dataset_name self.for_unit_testing = for_unit_testing - # self.local_cache_dir = local_cache_dir - # note, need to change the end of the url to dl=1 instead of dl=0 (the default when you grab the share list), to ensure the same checksum each time we download - self.test_url = "https://www.dropbox.com/scl/fi/9jeselknatcw9xi0qp940/qm9_dataset_n100.hdf5.gz?rlkey=50of7gn2s12i65c6j06r73c97&dl=1" - - self.full_url = "https://www.dropbox.com/scl/fi/4wu7zlpuuixttp0u741rv/qm9_dataset.hdf5.gz?rlkey=nszkqt2t4kmghih5mt4ssppvo&dl=1" self._ase = { "H": -1313.4668615546, "C": -99366.70745535441, @@ -115,8 +105,11 @@ def __init__( # We need to define the checksums for the various files that we will be dealing with to load up the data # There are 3 files types that need name/checksum defined, of extensions hdf5.gz, hdf5, and npz. - if self.for_unit_testing: + # note, need to change the end of the url to dl=1 instead of dl=0 (the default when you grab the share list), to ensure the same checksum each time we download + self.test_url = "https://www.dropbox.com/scl/fi/9jeselknatcw9xi0qp940/qm9_dataset_n100.hdf5.gz?rlkey=50of7gn2s12i65c6j06r73c97&dl=1" + self.full_url = "https://www.dropbox.com/scl/fi/4wu7zlpuuixttp0u741rv/qm9_dataset.hdf5.gz?rlkey=nszkqt2t4kmghih5mt4ssppvo&dl=1" + if self.for_unit_testing: url = self.test_url gz_data_file = { "name": "qm9_dataset_n100.hdf5.gz", @@ -131,17 +124,6 @@ def __init__( "md5": "9d671b54f7b9d454db9a3dd7f4ef2020", } - # self.md5_raw_checksum = "af3afda5c3265c9c096935ab060f537a" - # self.raw_data_file = "qm9_dataset_n100.hdf5.gz" - - # define the name and checksum of the unzipped file - - # self.unzipped_data_file = "qm9_dataset_n100.hdf5" - # self.md5_unzipped_checksum = "77df0e1df7a5ec5629be52181e82a7d7" - - # self.processed_data_file = "qm9_dataset_n100_processed.npz" - # self.md5_processed_checksum = "9d671b54f7b9d454db9a3dd7f4ef2020" - logger.info("Using test dataset") else: @@ -161,15 +143,6 @@ def __init__( "md5": "62d98cf38bcf02966e1fa2d9e44b3fa0", } - # self.md5_raw_checksum = "d172127848de114bd9cc47da2bc72566" - # self.raw_data_file = "qm9_dataset.hdf5.gz" - # - # self.unzipped_data_file = "qm9_dataset.hdf5" - # self.md5_unzipped_checksum = "0b22dc048f3361875889f832527438db" - # - # self.processed_data_file = "qm9_dataset_processed.npz" - # self.md5_processed_checksum = "62d98cf38bcf02966e1fa2d9e44b3fa0" - logger.info("Using full dataset") # to ensure that that we are consistent in our naming, we need to set all the names and checksums in the HDF5Dataset class constructor diff --git a/modelforge/dataset/utils.py b/modelforge/dataset/utils.py index 4e4410c7..6242faec 100644 --- a/modelforge/dataset/utils.py +++ b/modelforge/dataset/utils.py @@ -465,28 +465,6 @@ def _download_from_gdrive(id: str, raw_dataset_file: str): gdown.download(url, raw_dataset_file, quiet=False) -def _download_from_url(url: str, raw_dataset_file: str): - """ - Downloads a dataset from a specified URLS. - - Parameters - ---------- - url : str - raw link address. - raw_dataset_file : str - Path to save the downloaded dataset. - - Examples - -------- - >>> _download_from_url(url, "data_file.hdf5.gz") - """ - import requests - - r = requests.get(url) - with open(raw_dataset_file, "wb") as f: - f.write(r.content) - - def _to_file_cache( data: OrderedDict[str, List[np.ndarray]], processed_dataset_file: str ) -> None: diff --git a/modelforge/utils/prop.py b/modelforge/utils/prop.py index e8da3925..cd14bcb4 100644 --- a/modelforge/utils/prop.py +++ b/modelforge/utils/prop.py @@ -1,6 +1,6 @@ from dataclasses import dataclass import torch -from typing import NamedTuple +from typing import NamedTuple, Optional from loguru import logger @@ -9,7 +9,8 @@ class PropertyNames: Z: str R: str E: str - Q: str + F: Optional[str] = None + Q: Optional[str] = None class SpeciesEnergies(NamedTuple): From e75d0f78ef426488906ed5a27154a506da2e907c Mon Sep 17 00:00:00 2001 From: chrisiacovella Date: Fri, 5 Apr 2024 11:55:25 -0700 Subject: [PATCH 06/37] If Q or F are not defined, we will initialize with zeros of the correct shape --- modelforge/dataset/ani2x.py | 2 +- modelforge/dataset/dataset.py | 17 +++++++++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/modelforge/dataset/ani2x.py b/modelforge/dataset/ani2x.py index 587d539e..8f75cd97 100644 --- a/modelforge/dataset/ani2x.py +++ b/modelforge/dataset/ani2x.py @@ -155,7 +155,7 @@ def __init__( def atomic_self_energies(self): from modelforge.potential.utils import AtomicSelfEnergies - return AtomicSelfEnergies() + return AtomicSelfEnergies(element_energies=self._ase) @property def properties_of_interest(self) -> List[str]: diff --git a/modelforge/dataset/dataset.py b/modelforge/dataset/dataset.py index 008b4ea6..949f4e5e 100644 --- a/modelforge/dataset/dataset.py +++ b/modelforge/dataset/dataset.py @@ -68,14 +68,31 @@ def __init__( dataset[property_name.R] ) self.properties_of_interest["E"] = torch.from_numpy(dataset[property_name.E]) + if property_name.Q is not None: self.properties_of_interest["Q"] = torch.from_numpy( dataset[property_name.Q] ) + else: + # this is a per atom property, so it will match atomic_numbers + self.properties_of_interest["Q"] = torch.zeros( + dataset[property_name.Z].shape + ) + if property_name.F is not None: self.properties_of_interest["F"] = torch.from_numpy( dataset[property_name.F] ) + else: + # a per atom property in each direction, so it will match geometry + self.properties_of_interest["F"] = torch.zeros( + dataset[property_name.R].shape + ) + + print("Z ", dataset[property_name.Z].shape) + print("R ", dataset[property_name.R].shape) + print("E ", dataset[property_name.E].shape) + print("Q ", dataset[property_name.Q].shape) self.number_of_records = len(dataset["atomic_subsystem_counts"]) self.number_of_atoms = len(dataset["atomic_numbers"]) From 09bd8633f7f60f93c5025a4c8f703bb9f8b335b3 Mon Sep 17 00:00:00 2001 From: chrisiacovella Date: Fri, 5 Apr 2024 14:11:47 -0700 Subject: [PATCH 07/37] Q didn't have the right shape when not define, now has the right shape. atomic self energies now are defined with energy units (but returned without units, in our base unit system when used in removing self-energy). ANI2x now loads --- modelforge/dataset/ani2x.py | 22 ++++++++++++++-------- modelforge/dataset/dataset.py | 14 +++++++------- modelforge/dataset/qm9.py | 12 +++++++----- modelforge/potential/utils.py | 19 +++++++++++++------ 4 files changed, 41 insertions(+), 26 deletions(-) diff --git a/modelforge/dataset/ani2x.py b/modelforge/dataset/ani2x.py index 8f75cd97..7fd64475 100644 --- a/modelforge/dataset/ani2x.py +++ b/modelforge/dataset/ani2x.py @@ -89,13 +89,19 @@ def __init__( self.dataset_name = dataset_name self.for_unit_testing = for_unit_testing - # self._ase = { - # "H": -1313.4668615546, - # "C": -99366.70745535441, - # "N": -143309.9379722722, - # "O": -197082.0671774158, - # "F": -261811.54555874597, - # } + from openff.units import unit + + # these come from the ANI-2x paper generated via linear fittingh of the data + # https://github.com/isayev/ASE_ANI/blob/master/ani_models/ani-2x_8x/sae_linfit.dat + self._ase = { + "H": -0.5978583943827134 * unit.hartree, + "C": -38.08933878049795 * unit.hartree, + "N": -54.711968298621066 * unit.hartree, + "O": -75.19106774742086 * unit.hartree, + "S": -398.1577125334925 * unit.hartree, + "F": -99.80348506781634 * unit.hartree, + "Cl": -460.1681939421027 * unit.hartree, + } from loguru import logger # We need to define the checksums for the various files that we will be dealing with to load up the data @@ -155,7 +161,7 @@ def __init__( def atomic_self_energies(self): from modelforge.potential.utils import AtomicSelfEnergies - return AtomicSelfEnergies(element_energies=self._ase) + return AtomicSelfEnergies(energies=self._ase) @property def properties_of_interest(self) -> List[str]: diff --git a/modelforge/dataset/dataset.py b/modelforge/dataset/dataset.py index da3d2c88..c6e89651 100644 --- a/modelforge/dataset/dataset.py +++ b/modelforge/dataset/dataset.py @@ -75,9 +75,9 @@ def __init__( dataset[property_name.Q] ) else: - # this is a per atom property, so it will match atomic_numbers + # this is a per atom property, so it will match the first dimension of the geometry self.properties_of_interest["Q"] = torch.zeros( - dataset[property_name.Z].shape + (dataset[property_name.R].shape[0], 1) ) if property_name.F is not None: @@ -90,11 +90,11 @@ def __init__( dataset[property_name.R].shape ) - print("Z ", dataset[property_name.Z].shape) - print("R ", dataset[property_name.R].shape) - print("E ", dataset[property_name.E].shape) - print("Q ", dataset[property_name.Q].shape) - + print("Z", self.properties_of_interest["atomic_numbers"].shape) + print("R", self.properties_of_interest["positions"].shape) + print("E", self.properties_of_interest["E"].shape) + print("Q", self.properties_of_interest["Q"].shape) + print("F", self.properties_of_interest["F"].shape) self.number_of_records = len(dataset["atomic_subsystem_counts"]) self.number_of_atoms = len(dataset["atomic_numbers"]) single_atom_start_idxs_by_rec = np.concatenate( diff --git a/modelforge/dataset/qm9.py b/modelforge/dataset/qm9.py index 862fed27..d59050b7 100644 --- a/modelforge/dataset/qm9.py +++ b/modelforge/dataset/qm9.py @@ -92,13 +92,15 @@ def __init__( self.dataset_name = dataset_name self.for_unit_testing = for_unit_testing + from openff.units import unit + # atomic self energies self._ase = { - "H": -1313.4668615546, - "C": -99366.70745535441, - "N": -143309.9379722722, - "O": -197082.0671774158, - "F": -261811.54555874597, + "H": -1313.4668615546 * unit.kilojoule_per_mole, + "C": -99366.70745535441 * unit.kilojoule_per_mole, + "N": -143309.9379722722 * unit.kilojoule_per_mole, + "O": -197082.0671774158 * unit.kilojoule_per_mole, + "F": -261811.54555874597 * unit.kilojoule_per_mole, } from loguru import logger diff --git a/modelforge/potential/utils.py b/modelforge/potential/utils.py index de380ef3..7d2811fc 100644 --- a/modelforge/potential/utils.py +++ b/modelforge/potential/utils.py @@ -6,6 +6,8 @@ from typing import Any from dataclasses import dataclass, field from loguru import logger as log +from modelforge.utils.units import * + @dataclass class NeuralNetworkData: @@ -390,7 +392,7 @@ class AtomicSelfEnergies: # We provide a dictionary with {str:float} of element name to atomic self-energy, # which can then be accessed by atomic index or element name - energies: Dict[str, float] = field(default_factory=dict) + energies: Dict[str, unit.Quantity] = field(default_factory=dict) # Example mapping, replace or extend as necessary atomic_number_to_element: Dict[int, str] = field( default_factory=lambda: { @@ -455,12 +457,17 @@ def __getitem__(self, key): element = self.atomic_number_to_element.get(key) if element is None: raise KeyError(f"Atomic number {key} not found.") - return self.energies.get(element) + if self.energies.get(element) is None: + return None + return self.energies.get(element).to(unit.kilojoule_per_mole, "chem").m elif isinstance(key, str): # Directly access by element symbol if key not in self.energies: raise KeyError(f"Element {key} not found.") - return self.energies[key] + if self.energies[key] is None: + return None + + return self.energies[key].to(unit.kilojoule_per_mole, "chem").m else: raise TypeError( "Key must be an integer (atomic number) or string (element name)." @@ -470,7 +477,7 @@ def __iter__(self) -> Iterator[Dict[str, float]]: """Iterate over the energies dictionary.""" for element, energy in self.energies.items(): atomic_number = self.element_to_atomic_number(element) - yield (atomic_number, energy) + yield (atomic_number, energy.to(unit.kilojoule_per_mole, "chem").m) def __len__(self) -> int: """Return the number of element-energy pairs.""" @@ -906,7 +913,7 @@ def calculate_radial_basis_centers( number_of_radial_basis_functions + 1, dtype=dtype, )[:-1] - log.info(f'{centers=}') + log.info(f"{centers=}") return centers def calculate_radial_scale_factor( @@ -976,7 +983,7 @@ def pair_list( return pair_indices.to(device) -from openff.units import unit +# from openff.units import unit def neighbor_list_with_cutoff( From df5e56b0879571b6d910a628aa4382c06edbd014 Mon Sep 17 00:00:00 2001 From: chrisiacovella Date: Fri, 5 Apr 2024 15:23:18 -0700 Subject: [PATCH 08/37] Fixed ASE test in test_utils.py --- modelforge/tests/test_utils.py | 7 ++++- scripts/training_ani2x.py | 48 ++++++++++++++++++++++++++++++++++ 2 files changed, 54 insertions(+), 1 deletion(-) create mode 100644 scripts/training_ani2x.py diff --git a/modelforge/tests/test_utils.py b/modelforge/tests/test_utils.py index 9cf107d7..25488316 100644 --- a/modelforge/tests/test_utils.py +++ b/modelforge/tests/test_utils.py @@ -7,10 +7,15 @@ def test_ase_dataclass(): from modelforge.potential.utils import AtomicSelfEnergies + from openff.units import unit # Example usage atomic_self_energies = AtomicSelfEnergies( - energies={"H": 13.6, "He": 24.6, "Li": 5.4} + energies={ + "H": 13.6 * unit.kilojoule_per_mole, + "He": 24.6 * unit.kilojoule_per_mole, + "Li": 5.4 * unit.kilojoule_per_mole, + } ) # Access by element name diff --git a/scripts/training_ani2x.py b/scripts/training_ani2x.py new file mode 100644 index 00000000..269c5f0c --- /dev/null +++ b/scripts/training_ani2x.py @@ -0,0 +1,48 @@ +# This is an example script that trains an implemented model on the QM9 dataset. +from lightning import Trainer +import torch + +# import the models implemented in modelforge, for now SchNet, PaiNN, ANI2x or PhysNet +from modelforge.potential import NeuralNetworkPotentialFactory +from modelforge.dataset.ani2x import ANI2xDataset +from modelforge.dataset.dataset import TorchDataModule +from modelforge.dataset.utils import RandomRecordSplittingStrategy +from pytorch_lightning.loggers import TensorBoardLogger + +# set up tensor board logger +logger = TensorBoardLogger("tb_logs", name="training") + +# Set up dataset +data = ANI2xDataset(force_download=False, for_unit_testing=False) + +dataset = TorchDataModule( + data, batch_size=512, splitting_strategy=RandomRecordSplittingStrategy() +) + +dataset.prepare_data(remove_self_energies=True, normalize=False) + +# Set up model +model = NeuralNetworkPotentialFactory.create_nnp("training", "ANI2x") +model = model.to(torch.float32) + +print(model) + +# set up traininer +from lightning.pytorch.callbacks.early_stopping import EarlyStopping + +trainer = Trainer( + max_epochs=10_000, + num_nodes=1, + devices=1, + accelerator="cpu", + logger=logger, # Add the logger here + callbacks=[ + EarlyStopping(monitor="val_loss", min_delta=0.05, patience=20, verbose=True) + ], +) + + +# Run training loop and validate +trainer.fit(model, dataset.train_dataloader(), dataset.val_dataloader()) + +# tensorboard --logdir tb_logs From d0d8d94c60a972458a05b2236a5809cdfaf84f97 Mon Sep 17 00:00:00 2001 From: chrisiacovella Date: Mon, 8 Apr 2024 11:03:54 -0700 Subject: [PATCH 09/37] Removed md5 checksum validation of npz files. Checksums cary based on python version used to generate them causing issues. --- modelforge/dataset/ani2x.py | 7 +++-- modelforge/dataset/dataset.py | 59 +++++++++++++++++++++++------------ modelforge/dataset/qm9.py | 13 ++++++-- 3 files changed, 54 insertions(+), 25 deletions(-) diff --git a/modelforge/dataset/ani2x.py b/modelforge/dataset/ani2x.py index 7fd64475..1d912660 100644 --- a/modelforge/dataset/ani2x.py +++ b/modelforge/dataset/ani2x.py @@ -53,7 +53,7 @@ def __init__( for_unit_testing: bool = False, local_cache_dir: str = ".", force_download: bool = False, - overwrite: bool = False, + regenerate_cache: bool = False, ) -> None: """ Initialize the QM9Data class. @@ -68,7 +68,9 @@ def __init__( Path to the local cache directory, by default ".". force_download: bool, optional If set to True, we will download the dataset even if it already exists; by default False. - + regenerate_cache: bool, optional + If set to True, we will regenerate the npz cache file even if it already exists, using + the data from the hdf5 file; by default False. Examples -------- >>> data = ANI2xDataset() # Default dataset @@ -155,6 +157,7 @@ def __init__( processed_data_file=processed_data_file, local_cache_dir=local_cache_dir, force_download=force_download, + regenerate_cache=regenerate_cache, ) @property diff --git a/modelforge/dataset/dataset.py b/modelforge/dataset/dataset.py index c6e89651..6629fafc 100644 --- a/modelforge/dataset/dataset.py +++ b/modelforge/dataset/dataset.py @@ -262,6 +262,7 @@ def __init__( processed_data_file: Dict[str, str], local_cache_dir: str, force_download: bool = False, + regenerate_cache: bool = False, ): """ Initializes the HDF5Dataset with paths to raw and processed data files. @@ -287,6 +288,7 @@ def __init__( self.processed_data_file = processed_data_file self.local_cache_dir = local_cache_dir self.force_download = force_download + self.regenerate_cache = regenerate_cache self.hdf5data: Optional[Dict[str, List[np.ndarray]]] = None self.numpy_data: Optional[np.ndarray] = None @@ -309,7 +311,9 @@ def _ungzip_hdf5(self) -> None: ) as out_file: shutil.copyfileobj(gz_file, out_file) - def _file_validation(self, file_name: str, file_path: str, checksum: str) -> bool: + def _file_validation( + self, file_name: str, file_path: str, checksum: str = None + ) -> bool: """ Validates if the file exists, and if the calculated checksum matches the expected checksum. @@ -320,7 +324,8 @@ def _file_validation(self, file_name: str, file_path: str, checksum: str) -> boo file_path : str Path to the file. checksum : str - Expected checksum of the file. + Expected checksum of the file. Default=None + If None, checksum will not be validated. Returns ------- @@ -331,7 +336,7 @@ def _file_validation(self, file_name: str, file_path: str, checksum: str) -> boo if not os.path.exists(full_file_path): log.debug(f"File {full_file_path} does not exist.") return False - else: + elif checksum is not None: from modelforge.utils.remote import calculate_md5_checksum calculated_checksum = calculate_md5_checksum(file_name, file_path) @@ -341,6 +346,8 @@ def _file_validation(self, file_name: str, file_path: str, checksum: str) -> boo ) return False return True + else: + return True def _from_hdf5(self) -> None: """ @@ -510,28 +517,39 @@ def _from_file_cache(self) -> None: >>> hdf5_data = HDF5Dataset("raw_data.hdf5", "processed_data.npz") >>> processed_data = hdf5_data._from_file_cache() """ + # if self._file_validation( + # self.processed_data_file["name"], + # self.local_cache_dir, + # self.processed_data_file["md5"], + # ): + # log.debug(f"Loading processed data from {self.processed_data_file['name']}") + # + # else: + # from modelforge.utils.remote import calculate_md5_checksum + # + # checksum = calculate_md5_checksum( + # self.processed_data_file["name"], self.local_cache_dir + # ) + # raise ValueError( + # f"Checksum mismatch for processed data file {self.processed_data_file['name']}. Found {checksum}, expected {self.processed_data_file['md5']}" + # ) + import os + + # skip validating the checksum, as the npz file checksum of otherwise identical data differs between python 3.11 and 3.9/10 if self._file_validation( - self.processed_data_file["name"], - self.local_cache_dir, - self.processed_data_file["md5"], + self.processed_data_file["name"], self.local_cache_dir, checksum=None ): - log.debug(f"Loading processed data from {self.processed_data_file['name']}") - - else: - from modelforge.utils.remote import calculate_md5_checksum - - checksum = calculate_md5_checksum( - self.processed_data_file["name"], self.local_cache_dir + log.debug( + f"Loading processed data from {self.local_cache_dir}/{self.processed_data_file['name']}" + ) + self.numpy_data = np.load( + f"{self.local_cache_dir}/{self.processed_data_file['name']}" ) + else: raise ValueError( - f"Checksum mismatch for processed data file {self.processed_data_file['name']}. Found {checksum}, expected {self.processed_data_file['md5']}" + f"Processed data file {self.local_cache_dir}/{self.processed_data_file['name']} not found." ) - log.debug(f"Loading processed data from {self.processed_data_file['name']}") - self.numpy_data = np.load( - f"{self.local_cache_dir}/{self.processed_data_file['name']}" - ) - def _to_file_cache( self, ) -> None: @@ -594,9 +612,10 @@ def _load_or_process_data( data._file_validation( data.processed_data_file["name"], data.local_cache_dir, - data.processed_data_file["md5"], + None, ) and not data.force_download + and not data.regenerate_cache ): data._from_file_cache() # check to see if the hdf5 file exists and the checksum matches diff --git a/modelforge/dataset/qm9.py b/modelforge/dataset/qm9.py index d59050b7..31e02646 100644 --- a/modelforge/dataset/qm9.py +++ b/modelforge/dataset/qm9.py @@ -27,7 +27,9 @@ class QM9Dataset(HDF5Dataset): from modelforge.utils import PropertyNames _property_names = PropertyNames( - Z="atomic_numbers", R="geometry", E="internal_energy_at_0K", Q="charges" + Z="atomic_numbers", + R="geometry", + E="internal_energy_at_0K", # Q="charges" ) _available_properties = [ @@ -57,7 +59,7 @@ def __init__( for_unit_testing: bool = False, local_cache_dir: str = ".", force_download: bool = False, - overwrite: bool = False, + regenerate_cache=False, ) -> None: """ Initialize the QM9Data class. @@ -72,7 +74,9 @@ def __init__( Path to the local cache directory, by default ".". force_download: bool, optional If set to True, we will download the dataset even if it already exists; by default False. - + regenerate_cache: bool, optional + If set to True, we will regenerate the npz cache file even if it already exists, using + previously downloaded files, if available; by default False. Examples -------- >>> data = QM9Dataset() # Default dataset @@ -123,6 +127,8 @@ def __init__( } processed_data_file = { "name": "qm9_dataset_n100_processed.npz", + # checksum of otherwise identical npz files are different if using 3.11 vs 3.9/10 + # we will therefore skip checking these files "md5": "9d671b54f7b9d454db9a3dd7f4ef2020", } @@ -155,6 +161,7 @@ def __init__( processed_data_file=processed_data_file, local_cache_dir=local_cache_dir, force_download=force_download, + regenerate_cache=regenerate_cache, ) @property From 7800a6b6e385828029ac9123613a43efbbd8dfd1 Mon Sep 17 00:00:00 2001 From: chrisiacovella Date: Mon, 8 Apr 2024 11:11:52 -0700 Subject: [PATCH 10/37] updating tests for caching. --- modelforge/tests/test_dataset.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/modelforge/tests/test_dataset.py b/modelforge/tests/test_dataset.py index b378aaad..ab5183dd 100644 --- a/modelforge/tests/test_dataset.py +++ b/modelforge/tests/test_dataset.py @@ -227,12 +227,14 @@ def test_caching(prep_temp_dir): data._to_file_cache() + # npz files saved with different versions of python lead to different checksums + # we will skip checking the checksums for these files, only seeing if they exist assert os.path.exists(f"{local_cache_dir}/{data.processed_data_file['name']}") assert ( data._file_validation( data.processed_data_file["name"], local_cache_dir, - data.processed_data_file["md5"], + None, ) == True ) @@ -375,8 +377,6 @@ def test_dataset_splitting(splitting_strategy, datasets_to_test): assert len(test_dataset) == 10 - - @pytest.mark.parametrize("dataset", DATASETS) def test_dataset_downloader(dataset, prep_temp_dir): """ @@ -389,7 +389,6 @@ def test_dataset_downloader(dataset, prep_temp_dir): assert os.path.exists(f"{local_cache_dir}/{data.gz_data_file['name']}") - def test_numpy_dataset_assignment(datasets_to_test): """ Test if the numpy_dataset attribute is correctly assigned after processing or loading. From bf58d7d6ed887a7d37be2bcaef4bf06335ddf45a Mon Sep 17 00:00:00 2001 From: chrisiacovella Date: Mon, 8 Apr 2024 13:01:25 -0700 Subject: [PATCH 11/37] Adding in forces to spice datasets. --- modelforge/curation/spice_114_curation.py | 6 + modelforge/curation/spice_2_curation.py | 25 +- modelforge/curation/spice_openff_curation.py | 6 + modelforge/dataset/spice2.py | 241 +++++++++++++++++++ scripts/dataset_curation.py | 30 ++- scripts/training_ani2x.py | 2 +- 6 files changed, 295 insertions(+), 15 deletions(-) create mode 100644 modelforge/dataset/spice2.py diff --git a/modelforge/curation/spice_114_curation.py b/modelforge/curation/spice_114_curation.py index 5bec48f9..890df121 100644 --- a/modelforge/curation/spice_114_curation.py +++ b/modelforge/curation/spice_114_curation.py @@ -67,6 +67,10 @@ def _init_dataset_parameters(self): "u_in": unit.hartree / unit.bohr, "u_out": unit.kilojoule_per_mole / unit.angstrom, }, + "dft_total_force": { + "u_in": unit.hartree / unit.bohr, + "u_out": unit.kilojoule_per_mole / unit.angstrom, + }, "mbis_charges": { "u_in": unit.elementary_charge, "u_out": unit.elementary_charge, @@ -136,6 +140,7 @@ def _init_record_entries_series(self): "geometry": "series_atom", "dft_total_energy": "series_mol", "dft_total_gradient": "series_atom", + "dft_total_force": "series_atom", "formation_energy": "series_mol", "mayer_indices": "series_atom", "mbis_charges": "series_atom", @@ -249,6 +254,7 @@ def _process_downloaded( ds_temp["total_charge"] = self._calculate_reference_charge( ds_temp["smiles"] ) + ds_temp["dft_total_force"] = -ds_temp["dft_total_gradient"] self.data.append(ds_temp) if self.convert_units: self._convert_units() diff --git a/modelforge/curation/spice_2_curation.py b/modelforge/curation/spice_2_curation.py index 7fb2e07f..0175ff23 100644 --- a/modelforge/curation/spice_2_curation.py +++ b/modelforge/curation/spice_2_curation.py @@ -9,8 +9,6 @@ class SPICE2Curation(DatasetCuration): """ Fetches the SPICE 2 dataset from MOLSSI QCArchive and processes it into a curated hdf5 file. - March 2004: Note this is still the preliminary release; a subset of calculations are still being performed. - The SPICE dataset contains onformations for a diverse set of small molecules, dimers, dipeptides, and solvated amino acids. It includes 15 elements, charged and uncharged molecules, and a wide range of covalent and non-covalent interactions. @@ -44,6 +42,9 @@ class SPICE2Curation(DatasetCuration): - 'SPICE Amino Acid Ligand v1.0 + SPICE 2 zenodo release:ls + https://zenodo.org/records/10835749 + Reference to original SPICE publication: Eastman, P., Behara, P.K., Dotson, D.L. et al. SPICE, A Dataset of Drug-like Molecules and Peptides for Training Machine Learning Potentials. @@ -112,7 +113,7 @@ def _init_dataset_parameters(self): "u_in": unit.hartree / unit.bohr, "u_out": unit.kilojoule_per_mole / unit.angstrom, }, - "dispersion_corrected_dft_total_gradient": { + "dft_total_force": { "u_in": unit.hartree / unit.bohr, "u_out": unit.kilojoule_per_mole / unit.angstrom, }, @@ -181,6 +182,7 @@ def _init_record_entries_series(self): "geometry": "series_atom", "dft_total_energy": "series_mol", "dft_total_gradient": "series_atom", + "dft_total_force": "series_atom", "formation_energy": "series_mol", "mbis_charges": "series_atom", "scf_dipole": "series_atom", @@ -638,6 +640,7 @@ def _process_downloaded( self.data[index][quantity_o] = np.array( val["properties"][quantity] ).reshape(1, -1, 3) + else: self.data[index][quantity_o] = np.vstack( ( @@ -645,6 +648,22 @@ def _process_downloaded( np.array(val["properties"][quantity]).reshape(1, -1, 3), ) ) + # we will store force along with gradient + quantity = "dft total gradient" + quantity_o = "dft_total_force" + if quantity_o not in self.data[index].keys(): + self.data[index][quantity_o] = -np.array( + val["properties"][quantity] + ).reshape(1, -1, 3) + else: + self.data[index][quantity_o] = np.vstack( + ( + self.data[index][quantity_o], + -np.array(val["properties"][quantity]).reshape( + 1, -1, 3 + ), + ) + ) quantity = "mbis charges" quantity_o = "mbis_charges" diff --git a/modelforge/curation/spice_openff_curation.py b/modelforge/curation/spice_openff_curation.py index 086f72a8..7dcdec4f 100644 --- a/modelforge/curation/spice_openff_curation.py +++ b/modelforge/curation/spice_openff_curation.py @@ -100,6 +100,10 @@ def _init_dataset_parameters(self): "u_in": unit.hartree / unit.bohr, "u_out": unit.kilojoule_per_mole / unit.angstrom, }, + "dft_total_force": { + "u_in": unit.hartree / unit.bohr, + "u_out": unit.kilojoule_per_mole / unit.angstrom, + }, "dispersion_corrected_dft_total_gradient": { "u_in": unit.hartree / unit.bohr, "u_out": unit.kilojoule_per_mole / unit.angstrom, @@ -169,6 +173,7 @@ def _init_record_entries_series(self): "geometry": "series_atom", "dft_total_energy": "series_mol", "dft_total_gradient": "series_atom", + "dft_total_force": "series_atom", "formation_energy": "series_mol", "mbis_charges": "series_atom", "scf_dipole": "series_atom", @@ -667,6 +672,7 @@ def _process_downloaded( datapoint["dft_total_gradient"] + datapoint["dispersion_correction_gradient"] ) + datapoint["dft_total_force"] = -datapoint["dft_total_gradient"] # we only want to write the dispersion corrected gradient to the file to avoid confusion datapoint.pop("dispersion_correction_gradient") diff --git a/modelforge/dataset/spice2.py b/modelforge/dataset/spice2.py new file mode 100644 index 00000000..88162b1f --- /dev/null +++ b/modelforge/dataset/spice2.py @@ -0,0 +1,241 @@ +from typing import List + +from .dataset import HDF5Dataset + + +class Spice2Dataset(HDF5Dataset): + """ + Data class for handling ANI2x data. + + This class provides utilities for processing the ANI2x dataset stored in the modelforge HDF5 format. + + The ANI-2x data set includes properties for small organic molecules that contain + H, C, N, O, S, F, and Cl. This dataset contains 9651712 conformers for nearly 200,000 molecules. + This will fetch data generated with the wB97X/631Gd level of theory used in the original ANI-2x paper, + calculated using Gaussian 09. See ani2x_curation.py for more details on the dataset curation. + + Citation: Devereux, C, Zubatyuk, R., Smith, J. et al. + "Extending the applicability of the ANI deep learning molecular potential to sulfur and halogens." + Journal of Chemical Theory and Computation 16.7 (2020): 4192-4202. + https://doi.org/10.1021/acs.jctc.0c00121 + + DOI for dataset: 10.5281/zenodo.10108941 + + Attributes + ---------- + dataset_name : str + Name of the dataset, default is "ANI2x". + for_unit_testing : bool + If set to True, a subset of the dataset is used for unit testing purposes; by default False. + local_cache_dir: str, optional + Path to the local cache directory, by default ".". + Examples + -------- + + """ + + from modelforge.utils import PropertyNames + + _property_names = PropertyNames( + Z="atomic_numbers", R="geometry", E="energies", F="forces" + ) + + _available_properties = [ + "geometry", + "atomic_numbers", + "energies", + "forces", + ] # All properties within the datafile, aside from SMILES/inchi. + + def __init__( + self, + dataset_name: str = "ANI2x", + for_unit_testing: bool = False, + local_cache_dir: str = ".", + force_download: bool = False, + regenerate_cache: bool = False, + ) -> None: + """ + Initialize the QM9Data class. + + Parameters + ---------- + data_name : str, optional + Name of the dataset, by default "ANI2x". + for_unit_testing : bool, optional + If set to True, a subset of the dataset is used for unit testing purposes; by default False. + local_cache_dir: str, optional + Path to the local cache directory, by default ".". + force_download: bool, optional + If set to True, we will download the dataset even if it already exists; by default False. + regenerate_cache: bool, optional + If set to True, we will regenerate the npz cache file even if it already exists, using + the data from the hdf5 file; by default False. + Examples + -------- + >>> data = ANI2xDataset() # Default dataset + >>> test_data = ANI2xDataset(for_unit_testing=True) # Testing subset + """ + + _default_properties_of_interest = [ + "geometry", + "atomic_numbers", + "energies", + "forces", + ] # NOTE: Default values + + self._properties_of_interest = _default_properties_of_interest + if for_unit_testing: + dataset_name = f"{dataset_name}_subset" + + self.dataset_name = dataset_name + self.for_unit_testing = for_unit_testing + + from openff.units import unit + + # these come from the ANI-2x paper generated via linear fittingh of the data + # https://github.com/isayev/ASE_ANI/blob/master/ani_models/ani-2x_8x/sae_linfit.dat + self._ase = { + "H": -0.5978583943827134 * unit.hartree, + "C": -38.08933878049795 * unit.hartree, + "N": -54.711968298621066 * unit.hartree, + "O": -75.19106774742086 * unit.hartree, + "S": -398.1577125334925 * unit.hartree, + "F": -99.80348506781634 * unit.hartree, + "Cl": -460.1681939421027 * unit.hartree, + } + from loguru import logger + + # We need to define the checksums for the various files that we will be dealing with to load up the data + # There are 3 files types that need name/checksum defined, of extensions hdf5.gz, hdf5, and npz. + + # note, need to change the end of the url to dl=1 instead of dl=0 (the default when you grab the share list), to ensure the same checksum each time we download + self.test_url = "https://www.dropbox.com/scl/fi/okv311e9yvh94owbiypcm/ani2x_dataset_n100.hdf5.gz?rlkey=pz7gnlncabtzr3b82lblr3yas&dl=1" + self.full_url = "https://www.dropbox.com/scl/fi/egg04dmtho7l1ghqiwn1z/ani2x_dataset.hdf5.gz?rlkey=wq5qjyph5q2k0bn6vza735n19&dl=1" + + if self.for_unit_testing: + url = self.test_url + gz_data_file = { + "name": "ani2x_dataset_n100.hdf5.gz", + "md5": "093fa23aeb8f8813abd1ec08e9ff83ad", + } + hdf5_data_file = { + "name": "ani2x_dataset_n100.hdf5", + "md5": "4f54caf79e4c946dc3d6d53722d2b966", + } + processed_data_file = { + "name": "ani2x_dataset_n100_processed.npz", + "md5": "c1481fe9a6b15fb07b961d15411c0ddd", + } + + logger.info("Using test dataset") + + else: + url = self.full_url + gz_data_file = { + "name": "ani2x_dataset.hdf5.gz", + "md5": "8daf9a7d8bbf9bcb1e9cea13b4df9270", + } + + hdf5_data_file = { + "name": "ani2x_dataset.hdf5", + "md5": "86bb855cb8df54e082506088e949518e", + } + + processed_data_file = { + "name": "ani2x_dataset_processed.npz", + "md5": "268438d8e1660728ba892bc7c3cd4339", + } + + logger.info("Using full dataset") + + # to ensure that that we are consistent in our naming, we need to set all the names and checksums in the HDF5Dataset class constructor + super().__init__( + url=url, + gz_data_file=gz_data_file, + hdf5_data_file=hdf5_data_file, + processed_data_file=processed_data_file, + local_cache_dir=local_cache_dir, + force_download=force_download, + regenerate_cache=regenerate_cache, + ) + + @property + def atomic_self_energies(self): + from modelforge.potential.utils import AtomicSelfEnergies + + return AtomicSelfEnergies(energies=self._ase) + + @property + def properties_of_interest(self) -> List[str]: + """ + Getter for the properties of interest. + The order of this list determines also the order provided in the __getitem__ call + from the PytorchDataset. + + Returns + ------- + List[str] + List of properties of interest. + + """ + return self._properties_of_interest + + @property + def available_properties(self) -> List[str]: + """ + List of available properties in the dataset. + + Returns + ------- + List[str] + List of available properties in the dataset. + + Examples + -------- + + """ + return self._available_properties + + @properties_of_interest.setter + def properties_of_interest(self, properties_of_interest: List[str]) -> None: + """ + Setter for the properties of interest. + The order of this list determines also the order provided in the __getitem__ call + from the PytorchDataset + + Parameters + ---------- + properties_of_interest : List[str] + List of properties of interest. + + Examples + -------- + + """ + if not set(properties_of_interest).issubset(self._available_properties): + raise ValueError( + f"Properties of interest must be a subset of {self._available_properties}" + ) + self._properties_of_interest = properties_of_interest + + def _download(self) -> None: + """ + Download the hdf5 file containing the data from Dropbox. + + Examples + -------- + + + """ + # Right now this function needs to be defined for each dataset. + # once all datasets are moved to zenodo, we should only need a single function defined in the base class + from modelforge.utils.remote import download_from_url + + download_from_url( + url=self.url, + md5_checksum=self.gz_data_file["md5"], + output_path=self.local_cache_dir, + output_filename=self.gz_data_file["name"], + force_download=self.force_download, + ) diff --git a/scripts/dataset_curation.py b/scripts/dataset_curation.py index d3ad9d1c..bfe23fab 100644 --- a/scripts/dataset_curation.py +++ b/scripts/dataset_curation.py @@ -3,6 +3,7 @@ def SPICE_2( output_file_dir: str, local_cache_dir: str, force_download: bool = False, + unit_testing_max_records=None, ): """ This Fetches the SPICE 2 dataset from MOLSSI QCArchive and processes it into a curated hdf5 file. @@ -63,7 +64,14 @@ def SPICE_2( output_file_dir=output_file_dir, local_cache_dir=local_cache_dir, ) - spice_2_data.process(force_download=force_download, n_threads=4) + if unit_testing_max_records is None: + spice_2_data.process(force_download=force_download, n_threads=4) + else: + spice_2_data.process( + force_download=force_download, + n_threads=4, + unit_testing_max_records=unit_testing_max_records, + ) def SPICE_114_OpenFF( @@ -330,15 +338,15 @@ def ANI2x( output_file_dir = f"{local_prefix}/hdf5_files" # # QM9 dataset -local_cache_dir = f"{local_prefix}/qm9_dataset" -hdf5_file_name = "qm9_dataset.hdf5" -QM9( - hdf5_file_name, - output_file_dir, - local_cache_dir, - force_download=False, - unit_testing_max_records=100, -) +# local_cache_dir = f"{local_prefix}/qm9_dataset" +# hdf5_file_name = "qm9_dataset.hdf5" +# QM9( +# hdf5_file_name, +# output_file_dir, +# local_cache_dir, +# force_download=False, +# unit_testing_max_records=100, +# ) # we will save all the files to a central location @@ -349,7 +357,7 @@ def ANI2x( # hdf5_file_name = "spice_2_dataset.hdf5" # # SPICE_2(hdf5_file_name, output_file_dir, local_cache_dir, force_download=False) -# + # # SPICE 1.1.4 OpenFF dataset # local_cache_dir = f"{local_prefix}/spice_openff_dataset" # hdf5_file_name = "spice_114_openff_dataset.hdf5" diff --git a/scripts/training_ani2x.py b/scripts/training_ani2x.py index 269c5f0c..c04bab0e 100644 --- a/scripts/training_ani2x.py +++ b/scripts/training_ani2x.py @@ -13,7 +13,7 @@ logger = TensorBoardLogger("tb_logs", name="training") # Set up dataset -data = ANI2xDataset(force_download=False, for_unit_testing=False) +data = ANI2xDataset(force_download=False, for_unit_testing=True) dataset = TorchDataModule( data, batch_size=512, splitting_strategy=RandomRecordSplittingStrategy() From bb531635ea4395a028f291dfe24da703880fcacb Mon Sep 17 00:00:00 2001 From: chrisiacovella Date: Mon, 8 Apr 2024 13:04:58 -0700 Subject: [PATCH 12/37] minor changes to data fetching script --- modelforge/tests/model_datasets.py | 38 ++++++++++++++++++++++++++ scripts/dataset_curation.py | 43 +++++++++++++++++++++--------- 2 files changed, 69 insertions(+), 12 deletions(-) create mode 100644 modelforge/tests/model_datasets.py diff --git a/modelforge/tests/model_datasets.py b/modelforge/tests/model_datasets.py new file mode 100644 index 00000000..ce6d0665 --- /dev/null +++ b/modelforge/tests/model_datasets.py @@ -0,0 +1,38 @@ +from modelforge.curation.model_dataset import ModelDataset + +dataset = ModelDataset( + hdf5_file_name="PURE_MM.hdf5", + output_file_dir="/Users/cri/Dropbox/data_experiment/", + local_cache_dir="/Users/cri/Dropbox/data_experiment/", + convert_units=True, +) +dataset.process( + input_data_path="/Users/cri/Dropbox/data_experiment/", + input_data_file="molecule_data.hdf5", + data_combination="PURE_MM", +) + +dataset = ModelDataset( + hdf5_file_name="PURE_ML.hdf5", + output_file_dir="/Users/cri/Dropbox/data_experiment/", + local_cache_dir="/Users/cri/Dropbox/data_experiment/", + convert_units=True, +) +dataset.process( + input_data_path="/Users/cri/Dropbox/data_experiment/", + input_data_file="molecule_data.hdf5", + data_combination="PURE_ML", +) + + +dataset = ModelDataset( + hdf5_file_name="MM_low_e_correction.hdf5", + output_file_dir="/Users/cri/Dropbox/data_experiment/", + local_cache_dir="/Users/cri/Dropbox/data_experiment/", + convert_units=True, +) +dataset.process( + input_data_path="/Users/cri/Dropbox/data_experiment/", + input_data_file="molecule_data.hdf5", + data_combination="PURE_MM_low_temp_correction", +) diff --git a/scripts/dataset_curation.py b/scripts/dataset_curation.py index bfe23fab..596dc146 100644 --- a/scripts/dataset_curation.py +++ b/scripts/dataset_curation.py @@ -8,7 +8,6 @@ def SPICE_2( """ This Fetches the SPICE 2 dataset from MOLSSI QCArchive and processes it into a curated hdf5 file. - March 2004: Note this is still the preliminary release; a subset of calculations are still being performed. It provides both forces and energies calculated at the ωB97M-D3(BJ)/def2-TZVPPD level of theory, using Psi4. @@ -53,6 +52,8 @@ def SPICE_2( note, this will check to ensure that all records on qcarchive exist in the local database, and will be downloaded if missing. If True, the entire dataset will be redownloaded. + unit_testing_max_records: int, optional, default=None + If set, only the first n records will be processed; this is useful for unit testing. Returns ------- @@ -69,8 +70,8 @@ def SPICE_2( else: spice_2_data.process( force_download=force_download, - n_threads=4, unit_testing_max_records=unit_testing_max_records, + n_threads=4, ) @@ -303,6 +304,7 @@ def ANI2x( output_file_dir: str, local_cache_dir: str, force_download: bool = False, + unit_testing_max_records=None, ): """ This fetches and processes the ANI2x dataset into a curated hdf5 file. @@ -326,7 +328,13 @@ def ANI2x( output_file_dir=output_file_dir, local_cache_dir=local_cache_dir, ) - ani2x.process(force_download=force_download) + if unit_testing_max_records is None: + ani2x.process(force_download=force_download) + else: + ani2x.process( + force_download=force_download, + unit_testing_max_records=unit_testing_max_records, + ) """ @@ -335,11 +343,26 @@ def ANI2x( # define the local path prefix local_prefix = "/Users/cri/Documents/Projects-msk/datasets" + +# we will save all the files to a central location output_file_dir = f"{local_prefix}/hdf5_files" +# ANI2x test dataset +# local_cache_dir = f"{local_prefix}/ani2x_dataset" +# hdf5_file_name = "ani2x_dataset.hdf5" +# +# ANI2x( +# hdf5_file_name, +# output_file_dir, +# local_cache_dir, +# force_download=False, +# # unit_testing_max_records=100, +# ) + # # QM9 dataset # local_cache_dir = f"{local_prefix}/qm9_dataset" -# hdf5_file_name = "qm9_dataset.hdf5" +# hdf5_file_name = "qm9_dataset_n100.hdf5" +# # QM9( # hdf5_file_name, # output_file_dir, @@ -348,15 +371,11 @@ def ANI2x( # unit_testing_max_records=100, # ) - -# we will save all the files to a central location -# output_file_dir = f"{local_prefix}/hdf5_files" -# # # SPICE 2 dataset -# local_cache_dir = f"{local_prefix}/spice2_dataset" -# hdf5_file_name = "spice_2_dataset.hdf5" -# -# SPICE_2(hdf5_file_name, output_file_dir, local_cache_dir, force_download=False) +local_cache_dir = f"{local_prefix}/spice2_dataset" +hdf5_file_name = "spice_2_dataset.hdf5" + +SPICE_2(hdf5_file_name, output_file_dir, local_cache_dir, force_download=False) # # SPICE 1.1.4 OpenFF dataset # local_cache_dir = f"{local_prefix}/spice_openff_dataset" From ec5cd736fb6c08f9ca6216b81a511908542fed0c Mon Sep 17 00:00:00 2001 From: chrisiacovella Date: Mon, 8 Apr 2024 13:31:11 -0700 Subject: [PATCH 13/37] Added curation script for model system. --- modelforge/curation/model_dataset.py | 394 +++++++++++++++++++++++++++ 1 file changed, 394 insertions(+) create mode 100644 modelforge/curation/model_dataset.py diff --git a/modelforge/curation/model_dataset.py b/modelforge/curation/model_dataset.py new file mode 100644 index 00000000..d3ceffe2 --- /dev/null +++ b/modelforge/curation/model_dataset.py @@ -0,0 +1,394 @@ +from modelforge.curation.curation_baseclass import DatasetCuration, dict_to_hdf5 +from modelforge.utils.units import * + +import numpy as np + +from typing import Optional, List +from loguru import logger + + +class ModelDataset(DatasetCuration): + """ + Routines to fetch and process the model dataset used for examining different approaches to generating + training data. + + + """ + + def __init__( + self, + hdf5_file_name: str, + output_file_dir: str, + local_cache_dir: str, + convert_units: bool = True, + seed=12345, + ): + super().__init__( + hdf5_file_name=hdf5_file_name, + output_file_dir=output_file_dir, + local_cache_dir=local_cache_dir, + convert_units=convert_units, + ) + self.seed = seed + + def _init_dataset_parameters(self): + self.qm_parameters = { + "geometry": {"u_in": unit.nanometer, "u_out": unit.nanometer}, + "energy": { + "u_in": unit.kilojoule_per_mole, + "u_out": unit.kilojoule_per_mole, + }, + } + + def _init_record_entries_series(self): + self._record_entries_series = { + "name": "single_rec", + "n_configs": "single_rec", + "atomic_numbers": "single_atom", + "geometry": "series_atom", + "energy": "series_mol", + } + + def _process_downloaded( + self, + local_path_dir: str, + filename: str, + model: str, + ): + file_path = f"{local_path_dir}/{filename}" + + import h5py + + data_temp = [] + with h5py.File(file_path, "r") as f: + molecule_names = list(f.keys()) + for molecule_name in molecule_names: + record_temp = {} + molecule = f[molecule_name] + for key in molecule.keys(): + temp = molecule[key][()] + if "u" in molecule[key].attrs: + temp = temp * unit(molecule[key].attrs["u"]) + record_temp[key] = temp + record_temp["name"] = molecule_name + data_temp.append(record_temp) + + self.data = [] + self.test_data_molecules = [] + self.test_data_conformers = [] + + # figure out how which molecules we have in our holdout set + # we will keep 10 % of the data for testing + n_molecules = len(data_temp) + from numpy.random import RandomState + + prng = RandomState(self.seed) + hold_out = prng.randint(n_molecules, size=(int(n_molecules * 0.1))) + + if model == "PURE_MM": + for i, record in enumerate(data_temp): + temp = {} + temp["name"] = record["name"] + temp["atomic_numbers"] = record["atomic_numbers"].reshape(-1, 1) + + temp_conf_holdout = {} + temp_conf_holdout["name"] = record["name"] + temp_conf_holdout["atomic_numbers"] = record["atomic_numbers"].reshape( + -1, 1 + ) + + if i in hold_out: + temp["energy"] = ( + np.vstack( + ( + np.array(record["MM_emin_ML_energy"].m).reshape(-1, 1), + record["MM_300_ML_energy"].m.reshape(-1, 1), + record["MM_100_ML_energy"].m.reshape(-1, 1), + ) + ) + * record["MM_emin_ML_energy"].u + ) + temp["geometry"] = ( + np.vstack( + ( + record["MM_emin_coords"].m.reshape(1, -1, 3), + record["MM_coords_300"].m, + record["MM_coords_100"].m, + ) + ) + * record["MM_emin_coords"].u + ) + temp["n_configs"] = temp["geometry"].m.shape[0] + self.test_data_molecules.append(temp) + else: + temp["energy"] = ( + np.vstack( + ( + np.array(record["MM_emin_ML_energy"].m).reshape(-1, 1), + record["MM_300_ML_energy"][0:9].m.reshape(-1, 1), + record["MM_100_ML_energy"][0:9].m.reshape(-1, 1), + ) + ) + * record["MM_emin_ML_energy"].u + ) + temp["geometry"] = ( + np.vstack( + ( + record["MM_emin_coords"].m.reshape(1, -1, 3), + record["MM_coords_300"][0:9].m, + record["MM_coords_100"][0:9].m, + ) + ) + * record["MM_emin_coords"].u + ) + temp["n_configs"] = temp["geometry"].m.shape[0] + + self.data.append(temp) + + temp_conf_holdout["energy"] = ( + np.vstack( + ( + record["MM_300_ML_energy"][9:10].m.reshape(-1, 1), + record["MM_100_ML_energy"][9:10].m.reshape(-1, 1), + ) + ) + * record["MM_300_ML_energy"].u + ) + temp_conf_holdout["geometry"] = ( + np.vstack( + ( + record["MM_coords_300"][9:10].m, + record["MM_coords_100"][9:10].m, + ) + ) + * record["MM_emin_coords"].u + ) + temp_conf_holdout["n_configs"] = temp_conf_holdout[ + "geometry" + ].m.shape[0] + self.test_data_conformers.append(temp_conf_holdout) + + if model == "PURE_MM_low_temp_correction": + for i, record in enumerate(data_temp): + temp = {} + temp["name"] = record["name"] + temp["atomic_numbers"] = record["atomic_numbers"].reshape(-1, 1) + + temp_conf_holdout = {} + temp_conf_holdout["name"] = record["name"] + temp_conf_holdout["atomic_numbers"] = record["atomic_numbers"].reshape( + -1, 1 + ) + + if i in hold_out: + temp["energy"] = ( + np.vstack( + ( + np.array(record["MM_emin_ML_energy"].m).reshape(-1, 1), + record["MM_300_ML_energy"].m.reshape(-1, 1), + record["MM_100_ML_energy"].m.reshape(-1, 1), + record["MM100_ML_emin_ML_energy"].m.reshape(-1, 1), + ) + ) + * record["MM_300_ML_energy"].u + ) + temp["geometry"] = ( + np.vstack( + ( + record["MM_emin_coords"].m.reshape(1, -1, 3), + record["MM_coords_300"].m, + record["MM_coords_100"].m, + record["MM100_ML_emin_coords"].m, + ) + ) + * record["MM_emin_coords"].u + ) + temp["n_configs"] = temp["geometry"].m.shape[0] + self.test_data_molecules.append(temp) + else: + temp["energy"] = ( + np.vstack( + ( + np.array(record["MM_emin_ML_energy"].m).reshape(-1, 1), + record["MM_300_ML_energy"][0:9].m.reshape(-1, 1), + record["MM_100_ML_energy"][0:9].m.reshape(-1, 1), + record["MM100_ML_emin_ML_energy"][0:9].m.reshape(-1, 1), + ) + ) + * record["MM_emin_ML_energy"].u + ) + temp["geometry"] = ( + np.vstack( + ( + record["MM_emin_coords"].m.reshape(1, -1, 3), + record["MM_coords_300"][0:9].m, + record["MM_coords_100"][0:9].m, + record["MM100_ML_emin_coords"][0:9].m, + ) + ) + * record["MM_emin_coords"].u + ) + temp["n_configs"] = temp["geometry"].m.shape[0] + self.data.append(temp) + + temp_conf_holdout["energy"] = ( + np.vstack( + ( + record["MM_300_ML_energy"][9:10].m.reshape(-1, 1), + record["MM_100_ML_energy"][9:10].m.reshape(-1, 1), + record["MM100_ML_emin_ML_energy"][9:10].m.reshape( + -1, 1 + ), + ) + ) + * record["MM_300_ML_energy"].u + ) + temp_conf_holdout["geometry"] = ( + np.vstack( + ( + record["MM_coords_300"][9:10].m, + record["MM_coords_100"][9:10].m, + record["MM100_ML_emin_coords"][9:10].m, + ) + ) + * record["MM_coords_300"].u + ) + temp_conf_holdout["n_configs"] = temp_conf_holdout[ + "geometry" + ].shape[0] + self.test_data_conformers.append(temp_conf_holdout) + + if model == "PURE_ML": + for i, record in enumerate(data_temp): + temp = {} + temp["name"] = record["name"] + temp["atomic_numbers"] = record["atomic_numbers"].reshape(-1, 1) + + temp_conf_holdout = {} + temp_conf_holdout["name"] = record["name"] + temp_conf_holdout["atomic_numbers"] = record["atomic_numbers"].reshape( + -1, 1 + ) + + if i in hold_out: + temp["energy"] = ( + np.vstack( + ( + np.array(record["ML_emin_ML_energy"].m).reshape(-1, 1), + record["ML_300_ML_energy"].m.reshape(-1, 1), + record["ML_100_ML_energy"].m.reshape(-1, 1), + ) + ) + * record["ML_emin_ML_energy"].u + ) + temp["geometry"] = ( + np.vstack( + ( + record["ML_emin_coords"].m.reshape(1, -1, 3), + record["ML_coords_300"].m, + record["ML_coords_100"].m, + ) + ) + * record["ML_emin_coords"].u + ) + temp["n_configs"] = temp["geometry"].shape[0] + self.test_data_molecules.append(temp) + else: + temp["energy"] = ( + np.vstack( + ( + np.array(record["ML_emin_ML_energy"].m).reshape(-1, 1), + record["ML_300_ML_energy"][0:9].m.reshape(-1, 1), + record["ML_100_ML_energy"][0:9].m.reshape(-1, 1), + ) + ) + * record["ML_emin_ML_energy"].u + ) + temp["geometry"] = ( + np.vstack( + ( + record["ML_emin_coords"].m.reshape(1, -1, 3), + record["ML_coords_300"][0:9].m, + record["ML_coords_100"][0:9].m, + ) + ) + * record["ML_emin_coords"].u + ) + temp["n_configs"] = temp["geometry"].m.shape[0] + self.data.append(temp) + + temp_conf_holdout["energy"] = ( + np.vstack( + ( + record["ML_300_ML_energy"][9:10].m.reshape(-1, 1), + record["ML_100_ML_energy"][9:10].m.reshape(-1, 1), + ) + ) + * record["ML_300_ML_energy"].u + ) + temp_conf_holdout["geometry"] = ( + np.vstack( + ( + record["ML_coords_300"][9:10].m, + record["ML_coords_100"][9:10].m, + ) + ) + * record["ML_coords_300"].u + ) + temp_conf_holdout["n_configs"] = temp_conf_holdout[ + "geometry" + ].shape[0] + self.test_data_conformers.append(temp_conf_holdout) + + def _generate_hdf5_file(self, data, output_file_path, filename): + full_file_path = f"{output_file_path}/{filename}" + logger.debug("Writing data HDF5 file.") + import os + + os.makedirs(output_file_path, exist_ok=True) + + dict_to_hdf5( + full_file_path, + data, + series_info=self._record_entries_series, + id_key="name", + ) + + def process( + self, + input_data_path="./", + input_data_file="molecule_data.hdf5", + data_combination="pure_MM", + ) -> None: + """ + Process the dataset into a curated hdf5 file. + + Parameters + ---------- + force_download : Optional[bool], optional + Force download of the dataset, by default False + unit_testing_max_records : Optional[int], optional + Maximum number of records to process, by default None + + """ + self.data_combination = data_combination + self._clear_data() + self._process_downloaded( + input_data_path, input_data_file, self.data_combination + ) + if self.convert_units: + self._convert_units() + + # for datapoint in self.data: + # print(datapoint["name"]) + + self._generate_hdf5_file(self.data, self.output_file_dir, self.hdf5_file_name) + + fileout = self.hdf5_file_name.replace(".hdf5", "_test_conformers.hdf5") + self._generate_hdf5_file( + self.test_data_conformers, self.output_file_dir, fileout + ) + fileout = self.hdf5_file_name.replace(".hdf5", "_test_molecules.hdf5") + self._generate_hdf5_file( + self.test_data_molecules, self.output_file_dir, fileout + ) From ce54cc889b62decd22181138977745b455155248 Mon Sep 17 00:00:00 2001 From: chrisiacovella Date: Mon, 8 Apr 2024 16:48:50 -0700 Subject: [PATCH 14/37] Adding in forces to spice datasets. --- modelforge/curation/model_dataset.py | 28 +++++++-- modelforge/curation/spice_2_curation.py | 2 +- modelforge/dataset/ani2x.py | 2 +- .../dataset/{spice2.py => model_dataset.py} | 60 +++++++------------ modelforge/potential/ani.py | 4 ++ scripts/training_ani2x.py | 4 +- 6 files changed, 51 insertions(+), 49 deletions(-) rename modelforge/dataset/{spice2.py => model_dataset.py} (76%) diff --git a/modelforge/curation/model_dataset.py b/modelforge/curation/model_dataset.py index d3ceffe2..fad3d5a0 100644 --- a/modelforge/curation/model_dataset.py +++ b/modelforge/curation/model_dataset.py @@ -356,9 +356,10 @@ def _generate_hdf5_file(self, data, output_file_path, filename): def process( self, - input_data_path="./", - input_data_file="molecule_data.hdf5", - data_combination="pure_MM", + # input_data_path="./", + # input_data_file="molecule_data.hdf5", + force_download=False, + data_combination="PURE_MM", ) -> None: """ Process the dataset into a curated hdf5 file. @@ -367,14 +368,29 @@ def process( ---------- force_download : Optional[bool], optional Force download of the dataset, by default False - unit_testing_max_records : Optional[int], optional - Maximum number of records to process, by default None + data_combination : str, optional + The type of data combination to use, by default "pure_MM" + Options, PURE_MM_low_temp_correction, PURE_MM, PURE_ML + """ + from modelforge.utils.remote import download_from_url + + # download the data + url = "https://www.dropbox.com/scl/fi/c23o54ckovnz6umd3why2/molecule_data.hdf5?rlkey=384kd8zo9w1iv34lzp3c2y3n3&dl=1" + checksum = "77a76f7005249aebe61b57a560a818f4" + + download_from_url( + url, + md5_checksum=checksum, + output_path=self.local_cache_dir, + output_filename="molecule_data.hdf5", + force_download=force_download, + ) self.data_combination = data_combination self._clear_data() self._process_downloaded( - input_data_path, input_data_file, self.data_combination + self.local_cache_dir, "molecule_data.hdf5", self.data_combination ) if self.convert_units: self._convert_units() diff --git a/modelforge/curation/spice_2_curation.py b/modelforge/curation/spice_2_curation.py index 0175ff23..bbef9738 100644 --- a/modelforge/curation/spice_2_curation.py +++ b/modelforge/curation/spice_2_curation.py @@ -42,7 +42,7 @@ class SPICE2Curation(DatasetCuration): - 'SPICE Amino Acid Ligand v1.0 - SPICE 2 zenodo release:ls + SPICE 2 zenodo release: https://zenodo.org/records/10835749 Reference to original SPICE publication: diff --git a/modelforge/dataset/ani2x.py b/modelforge/dataset/ani2x.py index 1d912660..e6a7f08c 100644 --- a/modelforge/dataset/ani2x.py +++ b/modelforge/dataset/ani2x.py @@ -56,7 +56,7 @@ def __init__( regenerate_cache: bool = False, ) -> None: """ - Initialize the QM9Data class. + Initialize the ANI2xDataset class. Parameters ---------- diff --git a/modelforge/dataset/spice2.py b/modelforge/dataset/model_dataset.py similarity index 76% rename from modelforge/dataset/spice2.py rename to modelforge/dataset/model_dataset.py index 88162b1f..3c0a7dff 100644 --- a/modelforge/dataset/spice2.py +++ b/modelforge/dataset/model_dataset.py @@ -3,23 +3,9 @@ from .dataset import HDF5Dataset -class Spice2Dataset(HDF5Dataset): +class ModelDataset(HDF5Dataset): """ - Data class for handling ANI2x data. - - This class provides utilities for processing the ANI2x dataset stored in the modelforge HDF5 format. - - The ANI-2x data set includes properties for small organic molecules that contain - H, C, N, O, S, F, and Cl. This dataset contains 9651712 conformers for nearly 200,000 molecules. - This will fetch data generated with the wB97X/631Gd level of theory used in the original ANI-2x paper, - calculated using Gaussian 09. See ani2x_curation.py for more details on the dataset curation. - - Citation: Devereux, C, Zubatyuk, R., Smith, J. et al. - "Extending the applicability of the ANI deep learning molecular potential to sulfur and halogens." - Journal of Chemical Theory and Computation 16.7 (2020): 4192-4202. - https://doi.org/10.1021/acs.jctc.0c00121 - - DOI for dataset: 10.5281/zenodo.10108941 + Data class for handling the model data generated for the AlkEthOH dataset. Attributes ---------- @@ -36,34 +22,33 @@ class Spice2Dataset(HDF5Dataset): from modelforge.utils import PropertyNames - _property_names = PropertyNames( - Z="atomic_numbers", R="geometry", E="energies", F="forces" - ) + _property_names = PropertyNames(Z="atomic_numbers", R="geometry", E="energy") _available_properties = [ "geometry", "atomic_numbers", - "energies", - "forces", + "energy", ] # All properties within the datafile, aside from SMILES/inchi. def __init__( self, - dataset_name: str = "ANI2x", - for_unit_testing: bool = False, + dataset_name: str = "ModelDataset", + # for_unit_testing: bool = False, + data_combination: str = "PURE_MM", local_cache_dir: str = ".", force_download: bool = False, regenerate_cache: bool = False, ) -> None: """ - Initialize the QM9Data class. + Initialize the ANI2xDataset class. Parameters ---------- data_name : str, optional Name of the dataset, by default "ANI2x". - for_unit_testing : bool, optional - If set to True, a subset of the dataset is used for unit testing purposes; by default False. + data_combination : str, optional + The type of data combination to use, by default "PURE_MM" + Options, PURE_MM_low_temp_correction, PURE_MM, PURE_ML local_cache_dir: str, optional Path to the local cache directory, by default ".". force_download: bool, optional @@ -73,24 +58,20 @@ def __init__( the data from the hdf5 file; by default False. Examples -------- - >>> data = ANI2xDataset() # Default dataset - >>> test_data = ANI2xDataset(for_unit_testing=True) # Testing subset + >>> data = ModelDataset() # Default dataset + >>> test_data = ModelDataset() """ _default_properties_of_interest = [ "geometry", "atomic_numbers", - "energies", - "forces", + "energy", ] # NOTE: Default values self._properties_of_interest = _default_properties_of_interest - if for_unit_testing: - dataset_name = f"{dataset_name}_subset" - - self.dataset_name = dataset_name - self.for_unit_testing = for_unit_testing + dataset_name = f"{dataset_name}_{data_combination}" + self.data_combination = data_combination from openff.units import unit # these come from the ANI-2x paper generated via linear fittingh of the data @@ -110,11 +91,12 @@ def __init__( # There are 3 files types that need name/checksum defined, of extensions hdf5.gz, hdf5, and npz. # note, need to change the end of the url to dl=1 instead of dl=0 (the default when you grab the share list), to ensure the same checksum each time we download - self.test_url = "https://www.dropbox.com/scl/fi/okv311e9yvh94owbiypcm/ani2x_dataset_n100.hdf5.gz?rlkey=pz7gnlncabtzr3b82lblr3yas&dl=1" - self.full_url = "https://www.dropbox.com/scl/fi/egg04dmtho7l1ghqiwn1z/ani2x_dataset.hdf5.gz?rlkey=wq5qjyph5q2k0bn6vza735n19&dl=1" + self.PURE_MM_url = "https://www.dropbox.com/scl/fi/0642s2ilwwmu4cyb36ttm/PURE_MM.hdf5?rlkey=n2hzkoqtchrdxtfzbybxeu2lm&dl=1" + self.PURE_ML_url = "https://www.dropbox.com/scl/fi/lx2ets7ghw7abjbhyhvtd/PURE_ML.hdf5?rlkey=s7t0dtlab2rmt7j9bd9utx1lx&dl=1" + self.PURE_MM_low_temp_correction_url = "https://www.dropbox.com/scl/fi/2rhc1m2ta420wgtgde9qz/MM_low_e_correction.hdf5?rlkey=bpt3t6uqo8194vl2l133nm001&dl=1" - if self.for_unit_testing: - url = self.test_url + if self.data_combination == "PURE_MM": + url = self.PURE_MM_url gz_data_file = { "name": "ani2x_dataset_n100.hdf5.gz", "md5": "093fa23aeb8f8813abd1ec08e9ff83ad", diff --git a/modelforge/potential/ani.py b/modelforge/potential/ani.py index 3f2d3398..480e2dec 100644 --- a/modelforge/potential/ani.py +++ b/modelforge/potential/ani.py @@ -20,6 +20,7 @@ def triu_index(num_species: int) -> torch.Tensor: ret = torch.zeros(num_species, num_species, dtype=torch.long) ret[species1, species2] = pair_index ret[species2, species1] = pair_index + return ret @@ -151,6 +152,7 @@ def forward(self, data: AniNeuralNetworkData) -> SpeciesAEV: # ----------------- Radial symmetry vector ---------------- # # compute radial aev + radial_feature_vector = self.radial_symmetry_functions(data.d_ij) # cutoff rcut_ij = self.cutoff_module(data.d_ij) @@ -209,6 +211,7 @@ def _postprocess_angular_aev( angular_aev = angular_terms_.new_zeros( (number_of_atoms * num_species_pairs, angular_sublength) ) + index = ( central_atom_index * num_species_pairs + self.triu_index[angular_species12[0], angular_species12[1]] @@ -227,6 +230,7 @@ def _postprocess_radial_aev( number_of_atoms = data.number_of_atoms radial_sublength = self.radial_symmetry_functions.radial_sublength radial_length = radial_sublength * self.nr_of_supported_elements + radial_aev = radial_feature_vector.new_zeros( ( number_of_atoms * self.nr_of_supported_elements, diff --git a/scripts/training_ani2x.py b/scripts/training_ani2x.py index c04bab0e..e58dc17f 100644 --- a/scripts/training_ani2x.py +++ b/scripts/training_ani2x.py @@ -4,7 +4,7 @@ # import the models implemented in modelforge, for now SchNet, PaiNN, ANI2x or PhysNet from modelforge.potential import NeuralNetworkPotentialFactory -from modelforge.dataset.ani2x import ANI2xDataset +from modelforge.dataset.ani2x_test import ANI2xTestDataset from modelforge.dataset.dataset import TorchDataModule from modelforge.dataset.utils import RandomRecordSplittingStrategy from pytorch_lightning.loggers import TensorBoardLogger @@ -13,7 +13,7 @@ logger = TensorBoardLogger("tb_logs", name="training") # Set up dataset -data = ANI2xDataset(force_download=False, for_unit_testing=True) +data = ANI2xTestDataset(force_download=False, for_unit_testing=True) dataset = TorchDataModule( data, batch_size=512, splitting_strategy=RandomRecordSplittingStrategy() From 83ab6e19c2eaac212cccbd3a60a8d22c62203a04 Mon Sep 17 00:00:00 2001 From: chrisiacovella Date: Tue, 9 Apr 2024 10:51:59 -0700 Subject: [PATCH 15/37] Added in model dataset. --- modelforge/dataset/model_dataset.py | 57 ++++++++++++++++++----------- scripts/training_ani2x.py | 4 +- 2 files changed, 38 insertions(+), 23 deletions(-) diff --git a/modelforge/dataset/model_dataset.py b/modelforge/dataset/model_dataset.py index 3c0a7dff..53069a31 100644 --- a/modelforge/dataset/model_dataset.py +++ b/modelforge/dataset/model_dataset.py @@ -48,7 +48,7 @@ def __init__( Name of the dataset, by default "ANI2x". data_combination : str, optional The type of data combination to use, by default "PURE_MM" - Options, PURE_MM_low_temp_correction, PURE_MM, PURE_ML + Options, MM_low_temp_correction, PURE_MM, PURE_ML local_cache_dir: str, optional Path to the local cache directory, by default ".". force_download: bool, optional @@ -69,7 +69,7 @@ def __init__( ] # NOTE: Default values self._properties_of_interest = _default_properties_of_interest - dataset_name = f"{dataset_name}_{data_combination}" + self.dataset_name = f"{dataset_name}_{data_combination}" self.data_combination = data_combination from openff.units import unit @@ -91,42 +91,57 @@ def __init__( # There are 3 files types that need name/checksum defined, of extensions hdf5.gz, hdf5, and npz. # note, need to change the end of the url to dl=1 instead of dl=0 (the default when you grab the share list), to ensure the same checksum each time we download - self.PURE_MM_url = "https://www.dropbox.com/scl/fi/0642s2ilwwmu4cyb36ttm/PURE_MM.hdf5?rlkey=n2hzkoqtchrdxtfzbybxeu2lm&dl=1" - self.PURE_ML_url = "https://www.dropbox.com/scl/fi/lx2ets7ghw7abjbhyhvtd/PURE_ML.hdf5?rlkey=s7t0dtlab2rmt7j9bd9utx1lx&dl=1" - self.PURE_MM_low_temp_correction_url = "https://www.dropbox.com/scl/fi/2rhc1m2ta420wgtgde9qz/MM_low_e_correction.hdf5?rlkey=bpt3t6uqo8194vl2l133nm001&dl=1" + self.PURE_MM_url = "https://www.dropbox.com/scl/fi/pq6d2px51o29pegi19z7m/PURE_MM.hdf5.gz?rlkey=9tjbdsvthj9f5zfar4zfb9joo&dl=1" + self.PURE_ML_url = "https://www.dropbox.com/scl/fi/6mf8recfxd10zf1za9xjq/PURE_ML.hdf5.gz?rlkey=2xvvrcd2nbeiw7ma70hq4nui4&dl=1" + self.MM_low_temp_correction_url = "https://www.dropbox.com/scl/fi/h7xowf0v63yszfstsftpc/MM_low_e_correction.hdf5.gz?rlkey=c8u5q212lv2ikre6pukzdakzp&dl=1" if self.data_combination == "PURE_MM": url = self.PURE_MM_url gz_data_file = { - "name": "ani2x_dataset_n100.hdf5.gz", - "md5": "093fa23aeb8f8813abd1ec08e9ff83ad", + "name": "PURE_MM_dataset.hdf5.gz", + "md5": "869441523f826fcc4af7e1ecaca13772", } hdf5_data_file = { - "name": "ani2x_dataset_n100.hdf5", - "md5": "4f54caf79e4c946dc3d6d53722d2b966", - } - processed_data_file = { - "name": "ani2x_dataset_n100_processed.npz", - "md5": "c1481fe9a6b15fb07b961d15411c0ddd", + "name": "PURE_MM_dataset.hdf5", + "md5": "3921bd738d963cc5d26d581faa9bbd36", } + processed_data_file = {"name": "PURE_MM_dataset_processed.npz", "md5": None} logger.info("Using test dataset") - else: - url = self.full_url + elif self.data_combination == "PURE_ML": + url = self.PURE_ML_url + gz_data_file = { + "name": "PURE_ML_dataset.hdf5.gz", + "md5": "ff0ab16f4503e2537ed4bb10a0a6f465", + } + + hdf5_data_file = { + "name": "PURE_ML_dataset.hdf5", + "md5": "a968d6ee74a0dbcede25c98aaa7a33e7", + } + + processed_data_file = { + "name": "PURE_ML_dataset_processed.npz", + "md5": None, + } + + logger.info("Using full dataset") + elif self.data_combination == "MM_low_temp_correction": + url = self.MM_low_temp_correction_url gz_data_file = { - "name": "ani2x_dataset.hdf5.gz", - "md5": "8daf9a7d8bbf9bcb1e9cea13b4df9270", + "name": "MM_LTC_dataset.hdf5.gz", + "md5": "0c7dbc7636afe845f128c57dbc99f581", } hdf5_data_file = { - "name": "ani2x_dataset.hdf5", - "md5": "86bb855cb8df54e082506088e949518e", + "name": "MM_LTC_dataset.hdf5", + "md5": "fb448ea4eaaafaadcce62a2123cb8c1f", } processed_data_file = { - "name": "ani2x_dataset_processed.npz", - "md5": "268438d8e1660728ba892bc7c3cd4339", + "name": "MM_LTC_dataset_processed.npz", + "md5": None, } logger.info("Using full dataset") diff --git a/scripts/training_ani2x.py b/scripts/training_ani2x.py index e58dc17f..269c5f0c 100644 --- a/scripts/training_ani2x.py +++ b/scripts/training_ani2x.py @@ -4,7 +4,7 @@ # import the models implemented in modelforge, for now SchNet, PaiNN, ANI2x or PhysNet from modelforge.potential import NeuralNetworkPotentialFactory -from modelforge.dataset.ani2x_test import ANI2xTestDataset +from modelforge.dataset.ani2x import ANI2xDataset from modelforge.dataset.dataset import TorchDataModule from modelforge.dataset.utils import RandomRecordSplittingStrategy from pytorch_lightning.loggers import TensorBoardLogger @@ -13,7 +13,7 @@ logger = TensorBoardLogger("tb_logs", name="training") # Set up dataset -data = ANI2xTestDataset(force_download=False, for_unit_testing=True) +data = ANI2xDataset(force_download=False, for_unit_testing=False) dataset = TorchDataModule( data, batch_size=512, splitting_strategy=RandomRecordSplittingStrategy() From 2abbe248ca85cc7188cd920f19624574ccf6de74 Mon Sep 17 00:00:00 2001 From: chrisiacovella Date: Wed, 10 Apr 2024 22:31:55 -0700 Subject: [PATCH 16/37] Adding in additional models --- modelforge/curation/spice_2_curation.py | 20 +- modelforge/curation/spice_openff_curation.py | 11 +- modelforge/dataset/ani1x.py | 259 ++++++++++++++++ modelforge/dataset/model_dataset.py | 6 +- modelforge/dataset/spice.py | 0 modelforge/dataset/spice114.py | 277 +++++++++++++++++ modelforge/dataset/spice114openff.py | 287 ++++++++++++++++++ modelforge/dataset/spice2.py | 302 +++++++++++++++++++ 8 files changed, 1143 insertions(+), 19 deletions(-) create mode 100644 modelforge/dataset/ani1x.py delete mode 100644 modelforge/dataset/spice.py create mode 100644 modelforge/dataset/spice114.py create mode 100644 modelforge/dataset/spice114openff.py create mode 100644 modelforge/dataset/spice2.py diff --git a/modelforge/curation/spice_2_curation.py b/modelforge/curation/spice_2_curation.py index bbef9738..f4860718 100644 --- a/modelforge/curation/spice_2_curation.py +++ b/modelforge/curation/spice_2_curation.py @@ -9,16 +9,16 @@ class SPICE2Curation(DatasetCuration): """ Fetches the SPICE 2 dataset from MOLSSI QCArchive and processes it into a curated hdf5 file. - The SPICE dataset contains onformations for a diverse set of small molecules, + The SPICE dataset contains conformations for a diverse set of small molecules, dimers, dipeptides, and solvated amino acids. It includes 15 elements, charged and uncharged molecules, and a wide range of covalent and non-covalent interactions. It provides both forces and energies calculated at the ωB97M-D3(BJ)/def2-TZVPPD level of theory, using Psi4 along with other useful quantities such as multipole moments and bond orders. This includes the following collections from qcarchive. Collections included in SPICE 1.1.4 are annotated with - along with the version used in SPICE 1.1.4; while the underlying molecules are the same in a given collection, - newer versions may have had some calculations redone, e.g., rerun calculations that failed or reru with - a patched version Psi4 + along with the version used in SPICE 1.1.4; while the underlying molecules are typically the same in a given collection, + newer versions may have had some calculations redone, e.g., rerun calculations that failed or rerun with + a newer version Psi4 - 'SPICE Solvated Amino Acids Single Points Dataset v1.1' * (SPICE 1.1.4 at v1.1) - 'SPICE Dipeptides Single Points Dataset v1.3' * (SPICE 1.1.4 at v1.2) @@ -105,10 +105,6 @@ def _init_dataset_parameters(self): "u_in": unit.hartree, "u_out": unit.kilojoule_per_mole, }, - "dispersion_corrected_dft_total_energy": { - "u_in": unit.hartree, - "u_out": unit.kilojoule_per_mole, - }, "dft_total_gradient": { "u_in": unit.hartree / unit.bohr, "u_out": unit.kilojoule_per_mole / unit.angstrom, @@ -237,6 +233,8 @@ def _fetch_singlepoint_from_qcarchive( ds = client.get_dataset(dataset_type=dataset_type, dataset_name=dataset_name) logger.debug(f"Fetching {dataset_name} from the QCArchive.") + ds.fetch_entry_names() + entry_names = ds.entry_names if unit_testing_max_records is None: unit_testing_max_records = len(entry_names) @@ -282,7 +280,7 @@ def _fetch_singlepoint_from_qcarchive( specification_names=[specification_name], force_refetch=force_download, ): - spice_db[record[0]] = record[2] + spice_db[record[0]] = record[2].dict() if pbar is not None: pbar.update(1) @@ -548,7 +546,7 @@ def _process_downloaded( spec_keys = list(spice_db.keys()) for key in spec_keys: - if spice_db[key].status.value == "complete": + if spice_db[key]["status"].value == "complete": non_error_keys.append(key) sorted_keys, original_keys, molecule_names = self._sort_keys(non_error_keys) @@ -619,7 +617,7 @@ def _process_downloaded( for key in tqdm(sorted_keys): name = molecule_names[key] - val = spice_db[original_keys[key]].dict() + val = spice_db[original_keys[key]] index = self.molecule_names[name] diff --git a/modelforge/curation/spice_openff_curation.py b/modelforge/curation/spice_openff_curation.py index 7dcdec4f..30bf2298 100644 --- a/modelforge/curation/spice_openff_curation.py +++ b/modelforge/curation/spice_openff_curation.py @@ -228,6 +228,7 @@ def _fetch_singlepoint_from_qcarchive( ds = client.get_dataset(dataset_type=dataset_type, dataset_name=dataset_name) + ds.fetch_entry_names() entry_names = ds.entry_names if unit_testing_max_records is None: unit_testing_max_records = len(entry_names) @@ -273,7 +274,7 @@ def _fetch_singlepoint_from_qcarchive( specification_names=[specification_name], force_refetch=force_download, ): - spice_db[record[0]] = record[2] + spice_db[record[0]] = record[2].dict() if pbar is not None: pbar.update(1) @@ -474,8 +475,8 @@ def _process_downloaded( ) as spice_db_spec6: for key in spec2_keys: if ( - spice_db_spec2[key].status.value == "complete" - and spice_db_spec6[key].status.value == "complete" + spice_db_spec2[key]["status"].value == "complete" + and spice_db_spec6[key]["status"].value == "complete" ): non_error_keys.append(key) @@ -547,7 +548,7 @@ def _process_downloaded( for key in tqdm(sorted_keys): name = key.split("-")[0] - val = spice_db[original_name[key]].dict() + val = spice_db[original_name[key]] index = self.molecule_names[name] @@ -614,7 +615,7 @@ def _process_downloaded( for key in tqdm(sorted_keys): name = key.split("-")[0] - val = spice_db[original_name[key]].dict() + val = spice_db[original_name[key]] index = self.molecule_names[name] # typecasting issue in there diff --git a/modelforge/dataset/ani1x.py b/modelforge/dataset/ani1x.py new file mode 100644 index 00000000..40886ad5 --- /dev/null +++ b/modelforge/dataset/ani1x.py @@ -0,0 +1,259 @@ +from typing import List + +from .dataset import HDF5Dataset + + +class ANI1xDataset(HDF5Dataset): + """ + Data class for handling ANI1x data. + + This dataset includes ~5 million density function theory calculations + for small organic molecules containing H, C, N, and O. + A subset of ~500k are computed with accurate coupled cluster methods. + + References: + + ANI-1x dataset: + Smith, J. S.; Nebgen, B.; Lubbers, N.; Isayev, O.; Roitberg, A. E. + Less Is More: Sampling Chemical Space with Active Learning. + J. Chem. Phys. 2018, 148 (24), 241733. + https://doi.org/10.1063/1.5023802 + https://arxiv.org/abs/1801.09319 + + ANI-1ccx dataset: + Smith, J. S.; Nebgen, B. T.; Zubatyuk, R.; Lubbers, N.; Devereux, C.; Barros, K.; Tretiak, S.; Isayev, O.; Roitberg, A. E. + Approaching Coupled Cluster Accuracy with a General-Purpose Neural Network Potential through Transfer Learning. N + at. Commun. 2019, 10 (1), 2903. + https://doi.org/10.1038/s41467-019-10827-4 + + wB97x/def2-TZVPP data: + Zubatyuk, R.; Smith, J. S.; Leszczynski, J.; Isayev, O. + Accurate and Transferable Multitask Prediction of Chemical Properties with an Atoms-in-Molecules Neural Network. + Sci. Adv. 2019, 5 (8), eaav6490. + https://doi.org/10.1126/sciadv.aav6490 + + + Dataset DOI: + https://doi.org/10.6084/m9.figshare.c.4712477.v1 + + Attributes + ---------- + dataset_name : str + Name of the dataset, default is "ANI2x". + for_unit_testing : bool + If set to True, a subset of the dataset is used for unit testing purposes; by default False. + local_cache_dir: str, optional + Path to the local cache directory, by default ".". + Examples + -------- + + """ + + from modelforge.utils import PropertyNames + + _property_names = PropertyNames( + Z="atomic_numbers", + R="geometry", + E="wb97x_dz.energy", + F="wb97x_dz.forces", + Q="wb97x_dz.cm5_charges", + ) + + _available_properties = [ + "geometry", + "atomic_numbers", + "wb97x_dz.energy", + "wb97x_dz.forces", + "wb97x_dz.cm5_charges", + ] # All properties within the datafile, aside from SMILES/inchi. + + def __init__( + self, + dataset_name: str = "ANI1x", + for_unit_testing: bool = False, + local_cache_dir: str = ".", + force_download: bool = False, + regenerate_cache: bool = False, + ) -> None: + """ + Initialize the ANI2xDataset class. + + Parameters + ---------- + data_name : str, optional + Name of the dataset, by default "ANI1x". + for_unit_testing : bool, optional + If set to True, a subset of the dataset is used for unit testing purposes; by default False. + local_cache_dir: str, optional + Path to the local cache directory, by default ".". + force_download: bool, optional + If set to True, we will download the dataset even if it already exists; by default False. + regenerate_cache: bool, optional + If set to True, we will regenerate the npz cache file even if it already exists, using + the data from the hdf5 file; by default False. + Examples + -------- + >>> data = ANI1xDataset() # Default dataset + >>> test_data = ANI2xDataset(for_unit_testing=True) # Testing subset + """ + + _default_properties_of_interest = [ + "geometry", + "atomic_numbers", + "wb97x_dz.energy", + "wb97x_dz.forces", + "wb97x_dz.cm5_charges", + ] # NOTE: Default values + + self._properties_of_interest = _default_properties_of_interest + if for_unit_testing: + dataset_name = f"{dataset_name}_subset" + + self.dataset_name = dataset_name + self.for_unit_testing = for_unit_testing + + from openff.units import unit + + # these come from the ANI-2x paper generated via linear fittingh of the data + # https://github.com/isayev/ASE_ANI/blob/master/ani_models/ani-2x_8x/sae_linfit.dat + self._ase = { + "H": -0.5978583943827134 * unit.hartree, + "C": -38.08933878049795 * unit.hartree, + "N": -54.711968298621066 * unit.hartree, + "O": -75.19106774742086 * unit.hartree, + "S": -398.1577125334925 * unit.hartree, + "F": -99.80348506781634 * unit.hartree, + "Cl": -460.1681939421027 * unit.hartree, + } + from loguru import logger + + # We need to define the checksums for the various files that we will be dealing with to load up the data + # There are 3 files types that need name/checksum defined, of extensions hdf5.gz, hdf5, and npz. + + # note, need to change the end of the url to dl=1 instead of dl=0 (the default when you grab the share list), to ensure the same checksum each time we download + self.test_url = "https://www.dropbox.com/scl/fi/rqjc6pcv9jjzoq08hc5ao/ani1x_dataset_n100.hdf5.gz?rlkey=kgg0xvq9aac5sp3or9oh61igj&dl=1" + self.full_url = " " + + if self.for_unit_testing: + url = self.test_url + gz_data_file = { + "name": "ani1x_dataset_n100.hdf5.gz", + "md5": "51e2491e3c5b7b5a432e2012892cfcbb", + } + hdf5_data_file = { + "name": "ani1x_dataset_n100.hdf5", + "md5": "f3c934b79f035ecc3addf88c027f5e46", + } + processed_data_file = { + "name": "ani1x_dataset_n100_processed.npz", + "md5": None, + } + + logger.info("Using test dataset") + + else: + url = self.full_url + gz_data_file = { + "name": "ani1x_dataset.hdf5.gz", + "md5": "", + } + + hdf5_data_file = { + "name": "ani1x_dataset.hdf5", + "md5": "", + } + + processed_data_file = {"name": "ani1x_dataset_processed.npz", "md5": None} + + logger.info("Using full dataset") + + # to ensure that that we are consistent in our naming, we need to set all the names and checksums in the HDF5Dataset class constructor + super().__init__( + url=url, + gz_data_file=gz_data_file, + hdf5_data_file=hdf5_data_file, + processed_data_file=processed_data_file, + local_cache_dir=local_cache_dir, + force_download=force_download, + regenerate_cache=regenerate_cache, + ) + + @property + def atomic_self_energies(self): + from modelforge.potential.utils import AtomicSelfEnergies + + return AtomicSelfEnergies(energies=self._ase) + + @property + def properties_of_interest(self) -> List[str]: + """ + Getter for the properties of interest. + The order of this list determines also the order provided in the __getitem__ call + from the PytorchDataset. + + Returns + ------- + List[str] + List of properties of interest. + + """ + return self._properties_of_interest + + @property + def available_properties(self) -> List[str]: + """ + List of available properties in the dataset. + + Returns + ------- + List[str] + List of available properties in the dataset. + + Examples + -------- + + """ + return self._available_properties + + @properties_of_interest.setter + def properties_of_interest(self, properties_of_interest: List[str]) -> None: + """ + Setter for the properties of interest. + The order of this list determines also the order provided in the __getitem__ call + from the PytorchDataset + + Parameters + ---------- + properties_of_interest : List[str] + List of properties of interest. + + Examples + -------- + + """ + if not set(properties_of_interest).issubset(self._available_properties): + raise ValueError( + f"Properties of interest must be a subset of {self._available_properties}" + ) + self._properties_of_interest = properties_of_interest + + def _download(self) -> None: + """ + Download the hdf5 file containing the data from Dropbox. + + Examples + -------- + + + """ + # Right now this function needs to be defined for each dataset. + # once all datasets are moved to zenodo, we should only need a single function defined in the base class + from modelforge.utils.remote import download_from_url + + download_from_url( + url=self.url, + md5_checksum=self.gz_data_file["md5"], + output_path=self.local_cache_dir, + output_filename=self.gz_data_file["name"], + force_download=self.force_download, + ) diff --git a/modelforge/dataset/model_dataset.py b/modelforge/dataset/model_dataset.py index 53069a31..800c3dd2 100644 --- a/modelforge/dataset/model_dataset.py +++ b/modelforge/dataset/model_dataset.py @@ -107,7 +107,7 @@ def __init__( } processed_data_file = {"name": "PURE_MM_dataset_processed.npz", "md5": None} - logger.info("Using test dataset") + logger.info("Using PURE MM dataset") elif self.data_combination == "PURE_ML": url = self.PURE_ML_url @@ -126,7 +126,7 @@ def __init__( "md5": None, } - logger.info("Using full dataset") + logger.info("Using PURE ML dataset") elif self.data_combination == "MM_low_temp_correction": url = self.MM_low_temp_correction_url gz_data_file = { @@ -144,7 +144,7 @@ def __init__( "md5": None, } - logger.info("Using full dataset") + logger.info("Using MM low temperature correction dataset") # to ensure that that we are consistent in our naming, we need to set all the names and checksums in the HDF5Dataset class constructor super().__init__( diff --git a/modelforge/dataset/spice.py b/modelforge/dataset/spice.py deleted file mode 100644 index e69de29b..00000000 diff --git a/modelforge/dataset/spice114.py b/modelforge/dataset/spice114.py new file mode 100644 index 00000000..ee7d835e --- /dev/null +++ b/modelforge/dataset/spice114.py @@ -0,0 +1,277 @@ +from typing import List + +from .dataset import HDF5Dataset + + +class SPICE114Dataset(HDF5Dataset): + """ + Data class for handling SPICE 1.1.4 dataset. + + The SPICE dataset contains 1.1 million conformations for a diverse set of small molecules, + dimers, dipeptides, and solvated amino acids. It includes 15 elements, charged and + uncharged molecules, and a wide range of covalent and non-covalent interactions. + It provides both forces and energies calculated at the ωB97M-D3(BJ)/def2-TZVPPD level of theory, + using Psi4 1.4.1 along with other useful quantities such as multipole moments and bond orders. + + Reference: + Eastman, P., Behara, P.K., Dotson, D.L. et al. SPICE, + A Dataset of Drug-like Molecules and Peptides for Training Machine Learning Potentials. + Sci Data 10, 11 (2023). https://doi.org/10.1038/s41597-022-01882-6 + + Dataset DOI: + https://doi.org/10.5281/zenodo.8222043 + + + + Attributes + ---------- + dataset_name : str + Name of the dataset, default is "ANI2x". + for_unit_testing : bool + If set to True, a subset of the dataset is used for unit testing purposes; by default False. + local_cache_dir: str, optional + Path to the local cache directory, by default ".". + Examples + -------- + + """ + + from modelforge.utils import PropertyNames + + _property_names = PropertyNames( + Z="atomic_numbers", + R="geometry", + E="dft_total_energy", + F="dft_total_force", + Q="mbis_charges", + ) + + _available_properties = [ + "geometry", + "atomic_numbers", + "dft_total_energy", + "dft_total_force", + "mbis_charges", + "mbis_multipoles", + "mbis_octopoles", + "formation_energy", + "scf_dipole", + "scf_quadrupole", + "total_charge", + "reference_energy", + ] # All properties within the datafile, aside from SMILES/inchi. + + def __init__( + self, + dataset_name: str = "SPICE114", + for_unit_testing: bool = False, + local_cache_dir: str = ".", + force_download: bool = False, + regenerate_cache: bool = False, + ) -> None: + """ + Initialize the SPICE2Dataset class. + + Parameters + ---------- + data_name : str, optional + Name of the dataset, by default "SPICE114". + for_unit_testing : bool, optional + If set to True, a subset of the dataset is used for unit testing purposes; by default False. + local_cache_dir: str, optional + Path to the local cache directory, by default ".". + force_download: bool, optional + If set to True, we will download the dataset even if it already exists; by default False. + regenerate_cache: bool, optional + If set to True, we will regenerate the npz cache file even if it already exists, using + the data from the hdf5 file; by default False. + Examples + -------- + >>> data = SPICE2Dataset() # Default dataset + >>> test_data = SPICE2Dataset(for_unit_testing=True) # Testing subset + """ + + _default_properties_of_interest = [ + "geometry", + "atomic_numbers", + "dft_total_energy", + "dft_total_force", + "mbis_charges", + ] # NOTE: Default values + + self._properties_of_interest = _default_properties_of_interest + if for_unit_testing: + dataset_name = f"{dataset_name}_subset" + + self.dataset_name = dataset_name + self.for_unit_testing = for_unit_testing + + from openff.units import unit + + # SPICE provides reference values that depend upon charge, as charged molecules are included in the dataset. + # The reference_energy (i.e., sum of the value of isolated atoms with appropriate charge considerations) + # are included in the dataset, along with the formation_energy, which is the difference between + # the dft_total_energy and the reference_energy. + + # To be able to use the dataset for training in a consistent way with the ANI datasets, we will only consider + # the ase values for the uncharged isolated atoms, if available. Ions will use the values Ca 2+, K 1+, Li 1+, Mg 2+, Na 1+. + # See spice_2_curation.py for more details. + + # We will need to address this further later to see how we best want to handle this; the ASE are just meant to bring everything + # roughly to the same scale, and values do not vary substantially by charge state. + + # Reference energies, in hartrees, computed with Psi4 1.5 wB97M-D3BJ/def2-TZVPPD. + + self._ase = { + "Br": -2574.1167240829964 * unit.hartree, + "C": -37.87264507233593 * unit.hartree, + "Ca": -676.9528465198214 * unit.hartree, # 2+ + "Cl": -460.1988762285739 * unit.hartree, + "F": -99.78611622985483 * unit.hartree, + "H": -0.498760510048753 * unit.hartree, + "I": -297.76228914445625 * unit.hartree, + "K": -599.8025677513111 * unit.hartree, # 1+ + "Li": -7.285254714046546 * unit.hartree, # 1+ + "Mg": -199.2688420040449 * unit.hartree, # 2+ + "N": -54.62327513368922 * unit.hartree, + "Na": -162.11366478783253 * unit.hartree, # 1+ + "O": -75.11317840410095 * unit.hartree, + "P": -341.3059197024934 * unit.hartree, + "S": -398.1599636677874 * unit.hartree, + } + from loguru import logger + + # We need to define the checksums for the various files that we will be dealing with to load up the data + # There are 3 files types that need name/checksum defined, of extensions hdf5.gz, hdf5, and npz. + + # note, need to change the end of the url to dl=1 instead of dl=0 (the default when you grab the share list), to ensure the same checksum each time we download + self.test_url = "https://www.dropbox.com/scl/fi/16g7n0f7qgzjhi02g3qce/spice_114_dataset_n100.hdf5.gz?rlkey=gyyc1cd3u8p64icpb450y44qv&dl=1" + self.full_url = " " + + if self.for_unit_testing: + url = self.test_url + gz_data_file = { + "name": "SPICE114_dataset_n100.hdf5.gz", + "md5": "ee7406aaf587340190e90e365ba9ba7b", + } + hdf5_data_file = { + "name": "SPICE114_dataset_n100.hdf5", + "md5": "88bd3fca0809ca47339c52edda155d6d", + } + # npz file checksums may vary with different versions of python/numpy + processed_data_file = { + "name": "SPICE114_dataset_n100_processed.npz", + "md5": None, + } + + logger.info("Using test dataset") + + else: + url = self.full_url + gz_data_file = { + "name": "SPICE114_dataset.hdf5.gz", + "md5": "", + } + + hdf5_data_file = { + "name": "SPICE114_dataset.hdf5", + "md5": "", + } + + processed_data_file = { + "name": "SPICE114_dataset_processed.npz", + "md5": None, + } + + logger.info("Using full dataset") + + # to ensure that that we are consistent in our naming, we need to set all the names and checksums in the HDF5Dataset class constructor + super().__init__( + url=url, + gz_data_file=gz_data_file, + hdf5_data_file=hdf5_data_file, + processed_data_file=processed_data_file, + local_cache_dir=local_cache_dir, + force_download=force_download, + regenerate_cache=regenerate_cache, + ) + + @property + def atomic_self_energies(self): + from modelforge.potential.utils import AtomicSelfEnergies + + return AtomicSelfEnergies(energies=self._ase) + + @property + def properties_of_interest(self) -> List[str]: + """ + Getter for the properties of interest. + The order of this list determines also the order provided in the __getitem__ call + from the PytorchDataset. + + Returns + ------- + List[str] + List of properties of interest. + + """ + return self._properties_of_interest + + @property + def available_properties(self) -> List[str]: + """ + List of available properties in the dataset. + + Returns + ------- + List[str] + List of available properties in the dataset. + + Examples + -------- + + """ + return self._available_properties + + @properties_of_interest.setter + def properties_of_interest(self, properties_of_interest: List[str]) -> None: + """ + Setter for the properties of interest. + The order of this list determines also the order provided in the __getitem__ call + from the PytorchDataset + + Parameters + ---------- + properties_of_interest : List[str] + List of properties of interest. + + Examples + -------- + + """ + if not set(properties_of_interest).issubset(self._available_properties): + raise ValueError( + f"Properties of interest must be a subset of {self._available_properties}" + ) + self._properties_of_interest = properties_of_interest + + def _download(self) -> None: + """ + Download the hdf5 file containing the data from Dropbox. + + Examples + -------- + + + """ + # Right now this function needs to be defined for each dataset. + # once all datasets are moved to zenodo, we should only need a single function defined in the base class + from modelforge.utils.remote import download_from_url + + download_from_url( + url=self.url, + md5_checksum=self.gz_data_file["md5"], + output_path=self.local_cache_dir, + output_filename=self.gz_data_file["name"], + force_download=self.force_download, + ) diff --git a/modelforge/dataset/spice114openff.py b/modelforge/dataset/spice114openff.py new file mode 100644 index 00000000..6cdab5f8 --- /dev/null +++ b/modelforge/dataset/spice114openff.py @@ -0,0 +1,287 @@ +from typing import List + +from .dataset import HDF5Dataset + + +class SPICE114OpenFFDataset(HDF5Dataset): + """ + Data class for handling SPICE 1.1.4 dataset at the OpenForceField level of theory. + + All QM datapoints retrieved were generated using B3LYP-D3BJ/DZVP level of theory. + This is the default theory used for force field development by the Open Force Field Initiative. + + This includes the following collections from qcarchive: + + "SPICE Solvated Amino Acids Single Points Dataset v1.1", + "SPICE Dipeptides Single Points Dataset v1.2", + "SPICE DES Monomers Single Points Dataset v1.1", + "SPICE DES370K Single Points Dataset v1.0", + "SPICE PubChem Set 1 Single Points Dataset v1.2", + "SPICE PubChem Set 2 Single Points Dataset v1.2", + "SPICE PubChem Set 3 Single Points Dataset v1.2", + "SPICE PubChem Set 4 Single Points Dataset v1.2", + "SPICE PubChem Set 5 Single Points Dataset v1.2", + "SPICE PubChem Set 6 Single Points Dataset v1.2", + + It does not include the following datasets that are part of the official 1.1.4 release of SPICE (calculated + at the ωB97M-D3(BJ)/def2-TZVPPD level of theory), as the openff level of theory was not used for these datasets: + + "SPICE Ion Pairs Single Points Dataset v1.1", + "SPICE DES370K Single Points Dataset Supplement v1.0", + + Reference to original SPICE publication: + Eastman, P., Behara, P.K., Dotson, D.L. et al. SPICE, + A Dataset of Drug-like Molecules and Peptides for Training Machine Learning Potentials. + Sci Data 10, 11 (2023). https://doi.org/10.1038/s41597-022-01882-6 + + + + Attributes + ---------- + dataset_name : str + Name of the dataset, default is "ANI2x". + for_unit_testing : bool + If set to True, a subset of the dataset is used for unit testing purposes; by default False. + local_cache_dir: str, optional + Path to the local cache directory, by default ".". + Examples + -------- + + """ + + from modelforge.utils import PropertyNames + + _property_names = PropertyNames( + Z="atomic_numbers", + R="geometry", + E="dft_total_energy", + F="dft_total_force", + Q="mbis_charges", + ) + + _available_properties = [ + "geometry", + "atomic_numbers", + "dft_total_energy", + "dft_total_force", + "mbis_charges", + "formation_energy", + "scf_dipole", + "total_charge", + "reference_energy", + ] # All properties within the datafile, aside from SMILES/inchi. + + def __init__( + self, + dataset_name: str = "SPICE114OpenFF", + for_unit_testing: bool = False, + local_cache_dir: str = ".", + force_download: bool = False, + regenerate_cache: bool = False, + ) -> None: + """ + Initialize the SPICE2Dataset class. + + Parameters + ---------- + data_name : str, optional + Name of the dataset, by default "SPICE114OpenFF". + for_unit_testing : bool, optional + If set to True, a subset of the dataset is used for unit testing purposes; by default False. + local_cache_dir: str, optional + Path to the local cache directory, by default ".". + force_download: bool, optional + If set to True, we will download the dataset even if it already exists; by default False. + regenerate_cache: bool, optional + If set to True, we will regenerate the npz cache file even if it already exists, using + the data from the hdf5 file; by default False. + Examples + -------- + >>> data = SPICE2Dataset() # Default dataset + >>> test_data = SPICE2Dataset(for_unit_testing=True) # Testing subset + """ + + _default_properties_of_interest = [ + "geometry", + "atomic_numbers", + "dft_total_energy", + "dft_total_force", + "mbis_charges", + ] # NOTE: Default values + + self._properties_of_interest = _default_properties_of_interest + if for_unit_testing: + dataset_name = f"{dataset_name}_subset" + + self.dataset_name = dataset_name + self.for_unit_testing = for_unit_testing + + from openff.units import unit + + # SPICE provides reference values that depend upon charge, as charged molecules are included in the dataset. + # The reference_energy (i.e., sum of the value of isolated atoms with appropriate charge considerations) + # are included in the dataset, along with the formation_energy, which is the difference between + # the dft_total_energy and the reference_energy. + + # To be able to use the dataset for training in a consistent way with the ANI datasets, we will only consider + # the ase values for the uncharged isolated atoms, if available. Ions will use the values Ca 2+, K 1+, Li 1+, Mg 2+, Na 1+. + # See spice_2_curation.py for more details. + + # We will need to address this further later to see how we best want to handle this; the ASE are just meant to bring everything + # roughly to the same scale, and values do not vary substantially by charge state. + + # Reference energies, in hartrees, computed with Psi4 1.5 wB97M-D3BJ/def2-TZVPPD. + + self._ase = { + "Br": -2574.1167240829964 * unit.hartree, + "C": -37.87264507233593 * unit.hartree, + "Ca": -676.9528465198214 * unit.hartree, # 2+ + "Cl": -460.1988762285739 * unit.hartree, + "F": -99.78611622985483 * unit.hartree, + "H": -0.498760510048753 * unit.hartree, + "I": -297.76228914445625 * unit.hartree, + "K": -599.8025677513111 * unit.hartree, # 1+ + "Li": -7.285254714046546 * unit.hartree, # 1+ + "Mg": -199.2688420040449 * unit.hartree, # 2+ + "N": -54.62327513368922 * unit.hartree, + "Na": -162.11366478783253 * unit.hartree, # 1+ + "O": -75.11317840410095 * unit.hartree, + "P": -341.3059197024934 * unit.hartree, + "S": -398.1599636677874 * unit.hartree, + } + from loguru import logger + + # We need to define the checksums for the various files that we will be dealing with to load up the data + # There are 3 files types that need name/checksum defined, of extensions hdf5.gz, hdf5, and npz. + + # note, need to change the end of the url to dl=1 instead of dl=0 (the default when you grab the share list), to ensure the same checksum each time we download + self.test_url = "https://www.dropbox.com/scl/fi/e4lw7gh00i0tyl2mbbv3h/spice_114_openff_dataset_n100.hdf5.gz?rlkey=grnyfuecwl7ur3qs6147h4awo&dl=1" + self.full_url = "https://www.dropbox.com/scl/fi/kmdk4d6hntga7bk7wdqs6/spice_114_openff_dataset.hdf5.gz?rlkey=2mf954dswat4sbpus6vhj9pvz&dl=1" + + if self.for_unit_testing: + url = self.test_url + gz_data_file = { + "name": "SPICE114OpenFF_dataset_n100.hdf5.gz", + "md5": "8a99718246c178b8f318025ffe0e5560", + } + hdf5_data_file = { + "name": "SPICE114OpenFF_dataset_n100.hdf5", + "md5": "53c0c6db27adf1f11c1d0952624ebdb4", + } + # npz file checksums may vary with different versions of python/numpy + processed_data_file = { + "name": "SPICE114OpenFF_dataset_n100_processed.npz", + "md5": None, + } + + logger.info("Using test dataset") + + else: + url = self.full_url + gz_data_file = { + "name": "SPICE114OpenFF_dataset.hdf5.gz", + "md5": "3aca534133ebff8dba9ff859c89e18d1", + } + + hdf5_data_file = { + "name": "SPICE114OpenFF_dataset.hdf5", + "md5": "d78e185ada6d1be26e6bc1a4bf6320fb", + } + + processed_data_file = { + "name": "SPICE114OpenFF_dataset_processed.npz", + "md5": None, + } + + logger.info("Using full dataset") + + # to ensure that that we are consistent in our naming, we need to set all the names and checksums in the HDF5Dataset class constructor + super().__init__( + url=url, + gz_data_file=gz_data_file, + hdf5_data_file=hdf5_data_file, + processed_data_file=processed_data_file, + local_cache_dir=local_cache_dir, + force_download=force_download, + regenerate_cache=regenerate_cache, + ) + + @property + def atomic_self_energies(self): + from modelforge.potential.utils import AtomicSelfEnergies + + return AtomicSelfEnergies(energies=self._ase) + + @property + def properties_of_interest(self) -> List[str]: + """ + Getter for the properties of interest. + The order of this list determines also the order provided in the __getitem__ call + from the PytorchDataset. + + Returns + ------- + List[str] + List of properties of interest. + + """ + return self._properties_of_interest + + @property + def available_properties(self) -> List[str]: + """ + List of available properties in the dataset. + + Returns + ------- + List[str] + List of available properties in the dataset. + + Examples + -------- + + """ + return self._available_properties + + @properties_of_interest.setter + def properties_of_interest(self, properties_of_interest: List[str]) -> None: + """ + Setter for the properties of interest. + The order of this list determines also the order provided in the __getitem__ call + from the PytorchDataset + + Parameters + ---------- + properties_of_interest : List[str] + List of properties of interest. + + Examples + -------- + + """ + if not set(properties_of_interest).issubset(self._available_properties): + raise ValueError( + f"Properties of interest must be a subset of {self._available_properties}" + ) + self._properties_of_interest = properties_of_interest + + def _download(self) -> None: + """ + Download the hdf5 file containing the data from Dropbox. + + Examples + -------- + + + """ + # Right now this function needs to be defined for each dataset. + # once all datasets are moved to zenodo, we should only need a single function defined in the base class + from modelforge.utils.remote import download_from_url + + download_from_url( + url=self.url, + md5_checksum=self.gz_data_file["md5"], + output_path=self.local_cache_dir, + output_filename=self.gz_data_file["name"], + force_download=self.force_download, + ) diff --git a/modelforge/dataset/spice2.py b/modelforge/dataset/spice2.py new file mode 100644 index 00000000..804f98d2 --- /dev/null +++ b/modelforge/dataset/spice2.py @@ -0,0 +1,302 @@ +from typing import List + +from .dataset import HDF5Dataset + + +class SPICE2Dataset(HDF5Dataset): + """ + Data class for handling SPICE 2 dataset. + + The SPICE dataset contains conformations for a diverse set of small molecules, + dimers, dipeptides, and solvated amino acids. It includes 15 elements, charged and + uncharged molecules, and a wide range of covalent and non-covalent interactions. + It provides both forces and energies calculated at the ωB97M-D3(BJ)/def2-TZVPPD level of theory, + using Psi4 along with other useful quantities such as multipole moments and bond orders. + + This includes the following collections from qcarchive. Collections included in SPICE 1.1.4 are annotated with + along with the version used in SPICE 1.1.4; while the underlying molecules are typically the same in a given collection, + newer versions may have had some calculations redone, e.g., rerun calculations that failed or rerun with + a newer version Psi4 + + - 'SPICE Solvated Amino Acids Single Points Dataset v1.1' * (SPICE 1.1.4 at v1.1) + - 'SPICE Dipeptides Single Points Dataset v1.3' * (SPICE 1.1.4 at v1.2) + - 'SPICE DES Monomers Single Points Dataset v1.1' * (SPICE 1.1.4 at v1.1) + - 'SPICE DES370K Single Points Dataset v1.0' * (SPICE 1.1.4 at v1.0) + - 'SPICE DES370K Single Points Dataset Supplement v1.1' * (SPICE 1.1.4 at v1.0) + - 'SPICE PubChem Set 1 Single Points Dataset v1.3' * (SPICE 1.1.4 at v1.2) + - 'SPICE PubChem Set 2 Single Points Dataset v1.3' * (SPICE 1.1.4 at v1.2) + - 'SPICE PubChem Set 3 Single Points Dataset v1.3' * (SPICE 1.1.4 at v1.2) + - 'SPICE PubChem Set 4 Single Points Dataset v1.3' * (SPICE 1.1.4 at v1.2) + - 'SPICE PubChem Set 5 Single Points Dataset v1.3' * (SPICE 1.1.4 at v1.2) + - 'SPICE PubChem Set 6 Single Points Dataset v1.3' * (SPICE 1.1.4 at v1.2) + - 'SPICE PubChem Set 7 Single Points Dataset v1.0' + - 'SPICE PubChem Set 8 Single Points Dataset v1.0' + - 'SPICE PubChem Set 9 Single Points Dataset v1.0' + - 'SPICE PubChem Set 10 Single Points Dataset v1.0' + - 'SPICE Ion Pairs Single Points Dataset v1.2' * (SPICE 1.1.4 at v1.1) + - 'SPICE PubChem Boron Silicon v1.0' + - 'SPICE Solvated PubChem Set 1 v1.0' + - 'SPICE Water Clusters v1.0' + - 'SPICE Amino Acid Ligand v1.0 + + + SPICE 2 zenodo release: + https://zenodo.org/records/10835749 + + Reference to original SPICE publication: + Eastman, P., Behara, P.K., Dotson, D.L. et al. SPICE, + A Dataset of Drug-like Molecules and Peptides for Training Machine Learning Potentials. + Sci Data 10, 11 (2023). https://doi.org/10.1038/s41597-022-01882-6 + + + Attributes + ---------- + dataset_name : str + Name of the dataset, default is "SPICE2". + for_unit_testing : bool + If set to True, a subset of the dataset is used for unit testing purposes; by default False. + local_cache_dir: str, optional + Path to the local cache directory, by default ".". + Examples + -------- + + """ + + from modelforge.utils import PropertyNames + + _property_names = PropertyNames( + Z="atomic_numbers", + R="geometry", + E="dft_total_energy", + F="dft_total_force", + Q="mbis_charges", + ) + + _available_properties = [ + "geometry", + "atomic_numbers", + "dft_total_energy", + "dft_total_force", + "mbis_charges", + "formation_energy", + "scf_dipole", + "total_charge", + "reference_energy", + ] # All properties within the datafile, aside from SMILES/inchi. + + def __init__( + self, + dataset_name: str = "SPICE2", + for_unit_testing: bool = False, + local_cache_dir: str = ".", + force_download: bool = False, + regenerate_cache: bool = False, + ) -> None: + """ + Initialize the SPICE2Dataset class. + + Parameters + ---------- + data_name : str, optional + Name of the dataset, by default "ANI2x". + for_unit_testing : bool, optional + If set to True, a subset of the dataset is used for unit testing purposes; by default False. + local_cache_dir: str, optional + Path to the local cache directory, by default ".". + force_download: bool, optional + If set to True, we will download the dataset even if it already exists; by default False. + regenerate_cache: bool, optional + If set to True, we will regenerate the npz cache file even if it already exists, using + the data from the hdf5 file; by default False. + Examples + -------- + >>> data = SPICE2Dataset() # Default dataset + >>> test_data = SPICE2Dataset(for_unit_testing=True) # Testing subset + """ + + _default_properties_of_interest = [ + "geometry", + "atomic_numbers", + "dft_total_energy", + "dft_total_force", + "mbis_charges", + ] # NOTE: Default values + + self._properties_of_interest = _default_properties_of_interest + if for_unit_testing: + dataset_name = f"{dataset_name}_subset" + + self.dataset_name = dataset_name + self.for_unit_testing = for_unit_testing + + from openff.units import unit + + # SPICE provides reference values that depend upon charge, as charged molecules are included in the dataset. + # The reference_energy (i.e., sum of the value of isolated atoms with appropriate charge considerations) + # are included in the dataset, along with the formation_energy, which is the difference between + # the dft_total_energy and the reference_energy. + + # To be able to use the dataset for training in a consistent way with the ANI datasets, we will only consider + # the ase values for the uncharged isolated atoms, if available. Ions will use the values Ca 2+, K 1+, Li 1+, Mg 2+, Na 1+. + # See spice_2_curation.py for more details. + + # We will need to address this further later to see how we best want to handle this; the ASE are just meant to bring everything + # roughly to the same scale, and values do not vary substantially by charge state. + + # Reference energies, in hartrees, computed with Psi4 1.5 wB97M-D3BJ/def2-TZVPPD. + + self._ase = { + "B": -24.671520535482145 * unit.hartree, + "Br": -2574.1167240829964 * unit.hartree, + "C": -37.87264507233593 * unit.hartree, + "Ca": -676.9528465198214 * unit.hartree, # 2+ + "Cl": -460.1988762285739 * unit.hartree, + "F": -99.78611622985483 * unit.hartree, + "H": -0.498760510048753 * unit.hartree, + "I": -297.76228914445625 * unit.hartree, + "K": -599.8025677513111 * unit.hartree, # 1+ + "Li": -7.285254714046546 * unit.hartree, # 1+ + "Mg": -199.2688420040449 * unit.hartree, # 2+ + "N": -54.62327513368922 * unit.hartree, + "Na": -162.11366478783253 * unit.hartree, # 1+ + "O": -75.11317840410095 * unit.hartree, + "P": -341.3059197024934 * unit.hartree, + "S": -398.1599636677874 * unit.hartree, + "Si": -289.4131352299586 * unit.hartree, + } + from loguru import logger + + # We need to define the checksums for the various files that we will be dealing with to load up the data + # There are 3 files types that need name/checksum defined, of extensions hdf5.gz, hdf5, and npz. + + # note, need to change the end of the url to dl=1 instead of dl=0 (the default when you grab the share list), to ensure the same checksum each time we download + self.test_url = "https://www.dropbox.com/scl/fi/08u7e400qvrq2aklxw2yo/spice_2_dataset_n100.hdf5.gz?rlkey=ifv7hfzqnwl2faef8xxr5ggj2&dl=1" + self.full_url = " " + + if self.for_unit_testing: + url = self.test_url + gz_data_file = { + "name": "SPICE2_dataset_n100.hdf5.gz", + "md5": "6f3f2931d4eb59f7a54f0a11c72bb604", + } + hdf5_data_file = { + "name": "SPICE2_dataset_n100.hdf5", + "md5": "ff89646eab99e31447be1697de8b7208", + } + # npz file checksums may vary with different versions of python/numpy + processed_data_file = { + "name": "SPICE2_dataset_n100_processed.npz", + "md5": None, + } + + logger.info("Using test dataset") + + else: + url = self.full_url + gz_data_file = { + "name": "SPICE2_dataset.hdf5.gz", + "md5": "", + } + + hdf5_data_file = { + "name": "SPICE2_dataset.hdf5", + "md5": "", + } + + processed_data_file = { + "name": "SPICE2_dataset_processed.npz", + "md5": None, + } + + logger.info("Using full dataset") + + # to ensure that that we are consistent in our naming, we need to set all the names and checksums in the HDF5Dataset class constructor + super().__init__( + url=url, + gz_data_file=gz_data_file, + hdf5_data_file=hdf5_data_file, + processed_data_file=processed_data_file, + local_cache_dir=local_cache_dir, + force_download=force_download, + regenerate_cache=regenerate_cache, + ) + + @property + def atomic_self_energies(self): + from modelforge.potential.utils import AtomicSelfEnergies + + return AtomicSelfEnergies(energies=self._ase) + + @property + def properties_of_interest(self) -> List[str]: + """ + Getter for the properties of interest. + The order of this list determines also the order provided in the __getitem__ call + from the PytorchDataset. + + Returns + ------- + List[str] + List of properties of interest. + + """ + return self._properties_of_interest + + @property + def available_properties(self) -> List[str]: + """ + List of available properties in the dataset. + + Returns + ------- + List[str] + List of available properties in the dataset. + + Examples + -------- + + """ + return self._available_properties + + @properties_of_interest.setter + def properties_of_interest(self, properties_of_interest: List[str]) -> None: + """ + Setter for the properties of interest. + The order of this list determines also the order provided in the __getitem__ call + from the PytorchDataset + + Parameters + ---------- + properties_of_interest : List[str] + List of properties of interest. + + Examples + -------- + + """ + if not set(properties_of_interest).issubset(self._available_properties): + raise ValueError( + f"Properties of interest must be a subset of {self._available_properties}" + ) + self._properties_of_interest = properties_of_interest + + def _download(self) -> None: + """ + Download the hdf5 file containing the data from Dropbox. + + Examples + -------- + + + """ + # Right now this function needs to be defined for each dataset. + # once all datasets are moved to zenodo, we should only need a single function defined in the base class + from modelforge.utils.remote import download_from_url + + download_from_url( + url=self.url, + md5_checksum=self.gz_data_file["md5"], + output_path=self.local_cache_dir, + output_filename=self.gz_data_file["name"], + force_download=self.force_download, + ) From e1c40a58da0badaa1647e21276785918053e6232 Mon Sep 17 00:00:00 2001 From: chrisiacovella Date: Thu, 11 Apr 2024 20:32:07 -0700 Subject: [PATCH 17/37] added metadata generation/validation for npz files. Added testing --- modelforge/dataset/ani1x.py | 11 +- modelforge/dataset/ani2x.py | 3 + modelforge/dataset/dataset.py | 147 +++++++++++++++++++++------ modelforge/dataset/spice114.py | 9 +- modelforge/dataset/spice114openff.py | 3 + modelforge/dataset/spice2.py | 9 +- modelforge/tests/test_dataset.py | 68 ++++++++++++- modelforge/utils/remote.py | 8 +- scripts/training_ani2x.py | 1 + 9 files changed, 212 insertions(+), 47 deletions(-) diff --git a/modelforge/dataset/ani1x.py b/modelforge/dataset/ani1x.py index 40886ad5..1a367def 100644 --- a/modelforge/dataset/ani1x.py +++ b/modelforge/dataset/ani1x.py @@ -56,7 +56,6 @@ class ANI1xDataset(HDF5Dataset): R="geometry", E="wb97x_dz.energy", F="wb97x_dz.forces", - Q="wb97x_dz.cm5_charges", ) _available_properties = [ @@ -102,7 +101,6 @@ def __init__( "atomic_numbers", "wb97x_dz.energy", "wb97x_dz.forces", - "wb97x_dz.cm5_charges", ] # NOTE: Default values self._properties_of_interest = _default_properties_of_interest @@ -132,13 +130,14 @@ def __init__( # note, need to change the end of the url to dl=1 instead of dl=0 (the default when you grab the share list), to ensure the same checksum each time we download self.test_url = "https://www.dropbox.com/scl/fi/rqjc6pcv9jjzoq08hc5ao/ani1x_dataset_n100.hdf5.gz?rlkey=kgg0xvq9aac5sp3or9oh61igj&dl=1" - self.full_url = " " + self.full_url = "https://www.dropbox.com/scl/fi/d98h9kt4pl40qeapqzu00/ani1x_dataset.hdf5.gz?rlkey=7q1o8hh9qzbxehsobjurcksit&dl=1" if self.for_unit_testing: url = self.test_url gz_data_file = { "name": "ani1x_dataset_n100.hdf5.gz", "md5": "51e2491e3c5b7b5a432e2012892cfcbb", + "length": 85445473, } hdf5_data_file = { "name": "ani1x_dataset_n100.hdf5", @@ -155,12 +154,13 @@ def __init__( url = self.full_url gz_data_file = { "name": "ani1x_dataset.hdf5.gz", - "md5": "", + "md5": "408cdcf9768ac96a8ae8ade9f078c51b", + "length": 4510287721, } hdf5_data_file = { "name": "ani1x_dataset.hdf5", - "md5": "", + "md5": "361b7c4b9a4dfeece70f0fe6a893e76a", } processed_data_file = {"name": "ani1x_dataset_processed.npz", "md5": None} @@ -255,5 +255,6 @@ def _download(self) -> None: md5_checksum=self.gz_data_file["md5"], output_path=self.local_cache_dir, output_filename=self.gz_data_file["name"], + length=self.gz_data_file["length"], force_download=self.force_download, ) diff --git a/modelforge/dataset/ani2x.py b/modelforge/dataset/ani2x.py index e6a7f08c..c2c7957f 100644 --- a/modelforge/dataset/ani2x.py +++ b/modelforge/dataset/ani2x.py @@ -118,6 +118,7 @@ def __init__( gz_data_file = { "name": "ani2x_dataset_n100.hdf5.gz", "md5": "093fa23aeb8f8813abd1ec08e9ff83ad", + "length": 22254528, } hdf5_data_file = { "name": "ani2x_dataset_n100.hdf5", @@ -135,6 +136,7 @@ def __init__( gz_data_file = { "name": "ani2x_dataset.hdf5.gz", "md5": "8daf9a7d8bbf9bcb1e9cea13b4df9270", + "length": 5085941907, } hdf5_data_file = { @@ -237,5 +239,6 @@ def _download(self) -> None: md5_checksum=self.gz_data_file["md5"], output_path=self.local_cache_dir, output_filename=self.gz_data_file["name"], + length=self.gz_data_file["length"], force_download=self.force_download, ) diff --git a/modelforge/dataset/dataset.py b/modelforge/dataset/dataset.py index 6629fafc..00ff5444 100644 --- a/modelforge/dataset/dataset.py +++ b/modelforge/dataset/dataset.py @@ -311,6 +311,72 @@ def _ungzip_hdf5(self) -> None: ) as out_file: shutil.copyfileobj(gz_file, out_file) + def _check_lists(self, list_1: List, list_2: List) -> bool: + """ + Check to see if all elements in the lists match and the length is the same. + + Note the order of the lists do not matter. + + Parameters + ---------- + list_1 : List + First list to compare + list_2 : List + Second list to compare + + Returns + ------- + bool + True if all elements of sub_list are in containing_list, False otherwise + """ + if len(list_1) != len(list_2): + return False + for a in list_1: + if a not in list_2: + return False + return True + + def _metadata_validation(self, file_name: str, file_path: str) -> bool: + """ + Validates the metadata file for the npz file. + + Parameters + ---------- + file_name : str + Name of the metadata file. + file_path : str + Path to the metadata file. + + Returns + ------- + bool + True if the metadata file exists, False otherwise. + """ + if not os.path.exists(f"{file_path}/{file_name}"): + log.debug(f"Metadata file {file_path}/{file_name} does not exist.") + return False + else: + import json + + with open(f"{file_path}/{file_name}", "r") as f: + self._npz_metadata = json.load(f) + + if not self._check_lists( + self._npz_metadata["data_keys"], self.properties_of_interest + ): + log.warning( + f"Data keys used to generate {file_path}/{file_name} ({self._npz_metadata['data_keys']}) do not match data loader ({self.properties_of_interest}) ." + ) + return False + + if self._npz_metadata["hdf5_checksum"] != self.hdf5_data_file["md5"]: + log.warning( + f"Checksum for hdf5 file used to generate npz file does not match current file in dataloader." + ) + return False + + return True + def _file_validation( self, file_name: str, file_path: str, checksum: str = None ) -> bool: @@ -514,37 +580,25 @@ def _from_file_cache(self) -> None: Examples -------- - >>> hdf5_data = HDF5Dataset("raw_data.hdf5", "processed_data.npz") - >>> processed_data = hdf5_data._from_file_cache() - """ - # if self._file_validation( - # self.processed_data_file["name"], - # self.local_cache_dir, - # self.processed_data_file["md5"], - # ): - # log.debug(f"Loading processed data from {self.processed_data_file['name']}") - # - # else: - # from modelforge.utils.remote import calculate_md5_checksum - # - # checksum = calculate_md5_checksum( - # self.processed_data_file["name"], self.local_cache_dir - # ) - # raise ValueError( - # f"Checksum mismatch for processed data file {self.processed_data_file['name']}. Found {checksum}, expected {self.processed_data_file['md5']}" - # ) - import os - + """ # skip validating the checksum, as the npz file checksum of otherwise identical data differs between python 3.11 and 3.9/10 + # we have a metadatafile we validate separately instead if self._file_validation( self.processed_data_file["name"], self.local_cache_dir, checksum=None ): - log.debug( - f"Loading processed data from {self.local_cache_dir}/{self.processed_data_file['name']}" - ) - self.numpy_data = np.load( - f"{self.local_cache_dir}/{self.processed_data_file['name']}" - ) + if self._metadata_validation( + self.processed_data_file["name"].replace(".npz", ".json"), + self.local_cache_dir, + ): + log.debug( + f"Loading processed data from {self.local_cache_dir}/{self.processed_data_file['name']} generated on {self._npz_metadata['date_generated']}" + ) + log.debug( + f"Properties of Interes in .npz file: {self._npz_metadata['data_keys']}" + ) + self.numpy_data = np.load( + f"{self.local_cache_dir}/{self.processed_data_file['name']}" + ) else: raise ValueError( f"Processed data file {self.local_cache_dir}/{self.processed_data_file['name']} not found." @@ -573,6 +627,25 @@ def _to_file_cache( n_confs=self.n_confs, **self.hdf5data, ) + import datetime + + # we will generate a simple metadata file to list which data keys were used to generate the npz file + # and the checksum of the hdf5 file used to create the npz + # we can also add in the date of generation so we can report on when the datafile was generated when we load the npz + metadata = { + "data_keys": list(self.hdf5data.keys()), + "hdf5_checksum": self.hdf5_data_file["md5"], + "hdf5_gz_checkusm": self.gz_data_file["md5"], + "date_generated": str(datetime.datetime.now()), + } + import json + + with open( + f"{self.local_cache_dir}/{self.processed_data_file['name'].replace('.npz', '.json')}", + "w", + ) as f: + json.dump(metadata, f) + del self.hdf5data @@ -607,16 +680,26 @@ def _load_or_process_data( The HDF5 dataset instance to use. """ - # check to see if we can load from the npz file. This also validates the checksum - if ( - data._file_validation( - data.processed_data_file["name"], + # For efficiency purposes, we first want to see if there is an npz file available before reprocessing the hdf5 + # file, expanding the gzziped archive or download the file. + # Saving to cache will create an npz file and metadata file. + # The metadata file will contain the keys used to generate the npz file, the checksum of the hdf5 and gz + # file used to generate the npz file. We will look at the metadata file and compare this to the + # variables saved in the HDF5Dataset class to determine if the npz file is valid. + # It is important to check the keys used to generate the npz file, as these are allowed to be changed by the user. + + if data._file_validation( + data.processed_data_file["name"], + data.local_cache_dir, + ) and ( + data._metadata_validation( + data.processed_data_file["name"].replace(".npz", ".json"), data.local_cache_dir, - None, ) and not data.force_download and not data.regenerate_cache ): + data._from_file_cache() # check to see if the hdf5 file exists and the checksum matches elif ( diff --git a/modelforge/dataset/spice114.py b/modelforge/dataset/spice114.py index ee7d835e..e32e9fa5 100644 --- a/modelforge/dataset/spice114.py +++ b/modelforge/dataset/spice114.py @@ -146,13 +146,14 @@ def __init__( # note, need to change the end of the url to dl=1 instead of dl=0 (the default when you grab the share list), to ensure the same checksum each time we download self.test_url = "https://www.dropbox.com/scl/fi/16g7n0f7qgzjhi02g3qce/spice_114_dataset_n100.hdf5.gz?rlkey=gyyc1cd3u8p64icpb450y44qv&dl=1" - self.full_url = " " + self.full_url = "https://www.dropbox.com/scl/fi/zfh4sq2kiz250bvd9oshr/spice_114_dataset.hdf5.gz?rlkey=q3sp7p8ir21o0y0224bt75aw7&dl=1" if self.for_unit_testing: url = self.test_url gz_data_file = { "name": "SPICE114_dataset_n100.hdf5.gz", "md5": "ee7406aaf587340190e90e365ba9ba7b", + "length": 72001865, } hdf5_data_file = { "name": "SPICE114_dataset_n100.hdf5", @@ -170,12 +171,13 @@ def __init__( url = self.full_url gz_data_file = { "name": "SPICE114_dataset.hdf5.gz", - "md5": "", + "md5": "ad4722574dd820d7eb7b7db64a763bb2", + "length": 11193528261, } hdf5_data_file = { "name": "SPICE114_dataset.hdf5", - "md5": "", + "md5": "943a3df3bef247c8cffbb55d913c0bba", } processed_data_file = { @@ -273,5 +275,6 @@ def _download(self) -> None: md5_checksum=self.gz_data_file["md5"], output_path=self.local_cache_dir, output_filename=self.gz_data_file["name"], + length=self.gz_data_file["length"], force_download=self.force_download, ) diff --git a/modelforge/dataset/spice114openff.py b/modelforge/dataset/spice114openff.py index 6cdab5f8..b0b1623a 100644 --- a/modelforge/dataset/spice114openff.py +++ b/modelforge/dataset/spice114openff.py @@ -163,6 +163,7 @@ def __init__( gz_data_file = { "name": "SPICE114OpenFF_dataset_n100.hdf5.gz", "md5": "8a99718246c178b8f318025ffe0e5560", + "length": 306289237, } hdf5_data_file = { "name": "SPICE114OpenFF_dataset_n100.hdf5", @@ -181,6 +182,7 @@ def __init__( gz_data_file = { "name": "SPICE114OpenFF_dataset.hdf5.gz", "md5": "3aca534133ebff8dba9ff859c89e18d1", + "length": 2540106767, } hdf5_data_file = { @@ -283,5 +285,6 @@ def _download(self) -> None: md5_checksum=self.gz_data_file["md5"], output_path=self.local_cache_dir, output_filename=self.gz_data_file["name"], + length=self.gz_data_file["length"], force_download=self.force_download, ) diff --git a/modelforge/dataset/spice2.py b/modelforge/dataset/spice2.py index 804f98d2..455d6d71 100644 --- a/modelforge/dataset/spice2.py +++ b/modelforge/dataset/spice2.py @@ -171,13 +171,14 @@ def __init__( # note, need to change the end of the url to dl=1 instead of dl=0 (the default when you grab the share list), to ensure the same checksum each time we download self.test_url = "https://www.dropbox.com/scl/fi/08u7e400qvrq2aklxw2yo/spice_2_dataset_n100.hdf5.gz?rlkey=ifv7hfzqnwl2faef8xxr5ggj2&dl=1" - self.full_url = " " + self.full_url = "https://www.dropbox.com/scl/fi/udoc3jj7wa7du8jgqiat0/spice_2_dataset.hdf5.gz?rlkey=csgwqa237m002n54jnld5pfgy&dl=1" if self.for_unit_testing: url = self.test_url gz_data_file = { "name": "SPICE2_dataset_n100.hdf5.gz", "md5": "6f3f2931d4eb59f7a54f0a11c72bb604", + "length": 315275240, # the number of bytes to be able to display the download progress bar correctly } hdf5_data_file = { "name": "SPICE2_dataset_n100.hdf5", @@ -195,12 +196,13 @@ def __init__( url = self.full_url gz_data_file = { "name": "SPICE2_dataset.hdf5.gz", - "md5": "", + "md5": "244a559a6062bbec5c9cb49af036ff7d", + "length": 5532866319, } hdf5_data_file = { "name": "SPICE2_dataset.hdf5", - "md5": "", + "md5": "9659a0f18050b9e7b122c0046b705480", } processed_data_file = { @@ -298,5 +300,6 @@ def _download(self) -> None: md5_checksum=self.gz_data_file["md5"], output_path=self.local_cache_dir, output_filename=self.gz_data_file["name"], + length=self.gz_data_file["length"], force_download=self.force_download, ) diff --git a/modelforge/tests/test_dataset.py b/modelforge/tests/test_dataset.py index ab5183dd..99b135f5 100644 --- a/modelforge/tests/test_dataset.py +++ b/modelforge/tests/test_dataset.py @@ -141,10 +141,15 @@ def test_different_properties_of_interest(dataset): assert isinstance(raw_data_item, dict) assert len(raw_data_item) == 7 - data.properties_of_interest = ["internal_energy_at_0K", "geometry"] + data.properties_of_interest = [ + "internal_energy_at_0K", + "geometry", + "atomic_numbers", + ] assert data.properties_of_interest == [ "internal_energy_at_0K", "geometry", + "atomic_numbers", ] dataset = factory.create_dataset(data) @@ -242,11 +247,59 @@ def test_caching(prep_temp_dir): data._from_file_cache() +def test_metadata_validation(prep_temp_dir): + local_cache_dir = str(prep_temp_dir) + + from modelforge.dataset.qm9 import QM9Dataset + + data = QM9Dataset(for_unit_testing=True, local_cache_dir=local_cache_dir) + + a = ["energy", "force", "atomic_numbers"] + b = ["energy", "atomic_numbers", "force"] + assert data._check_lists(a, b) == True + + a = ["energy", "force"] + + assert data._check_lists(a, b) == False + + a = ["energy", "force", "atomic_numbers", "charges"] + + assert data._check_lists(a, b) == False + + # we do not have a metadata files so this will fail + assert data._metadata_validation("qm9_test.json", local_cache_dir) == False + + metadata = { + "data_keys": ["atomic_numbers", "internal_energy_at_0K", "geometry", "charges"], + "hdf5_checksum": "77df0e1df7a5ec5629be52181e82a7d7", + "hdf5_gz_checkusm": "af3afda5c3265c9c096935ab060f537a", + "date_generated": "2024-04-11 14:05:14.297305", + } + + import json + + with open( + f"{local_cache_dir}/qm9_test.json", + "w", + ) as f: + json.dump(metadata, f) + + assert data._metadata_validation("qm9_test.json", local_cache_dir) == True + + metadata["hdf5_checksum"] = "wrong_checksum" + with open( + f"{local_cache_dir}/qm9_test.json", + "w", + ) as f: + json.dump(metadata, f) + assert data._metadata_validation("qm9_test.json", local_cache_dir) == False + + @pytest.mark.parametrize("dataset", DATASETS) def test_different_scenarios_of_file_availability(dataset, prep_temp_dir): """Test the behavior when raw and processed dataset files are removed.""" - local_cache_dir = str(prep_temp_dir) + "/test_diff_secnarios" + local_cache_dir = str(prep_temp_dir) + "/test_diff_scenarios" factory = DatasetFactory() data = dataset(for_unit_testing=True, local_cache_dir=local_cache_dir) @@ -254,12 +307,21 @@ def test_different_scenarios_of_file_availability(dataset, prep_temp_dir): # this will download the .gz, the .hdf5 and the .npz files factory.create_dataset(data) - # first check if we remote the npz file, rerunning it will regenerated it + # first check if we remove the npz file, rerunning it will regenerate it os.remove(f"{local_cache_dir}/{data.processed_data_file['name']}") factory.create_dataset(data) assert os.path.exists(f"{local_cache_dir}/{data.processed_data_file['name']}") + # now remove metadata file, rerunning will regenerate the npz file + os.remove( + f"{local_cache_dir}/{data.processed_data_file['name'].replace('npz', 'json')}" + ) + factory.create_dataset(data) + assert os.path.exists( + f"{local_cache_dir}/{data.processed_data_file['name'].replace('npz', 'json')}" + ) + # now remove the npz and hdf5 files, rerunning will generate it os.remove(f"{local_cache_dir}/{data.processed_data_file['name']}") diff --git a/modelforge/utils/remote.py b/modelforge/utils/remote.py index 7a68c337..0a652912 100644 --- a/modelforge/utils/remote.py +++ b/modelforge/utils/remote.py @@ -98,6 +98,7 @@ def download_from_url( md5_checksum: str, output_path: str, output_filename: str, + length: Optional[int] = None, force_download=False, ) -> str: @@ -125,14 +126,19 @@ def download_from_url( r = requests.get(url, stream=True) os.makedirs(output_path, exist_ok=True) - + if length is not None: + total = int(length / chunk_size) + 1 + else: + total = None with open(f"{output_path}/{output_filename}", "wb") as fd: for chunk in tqdm( r.iter_content(chunk_size=chunk_size), ascii=True, desc="downloading", + total=total, ): fd.write(chunk) + calculated_checksum = calculate_md5_checksum( file_name=output_filename, file_path=output_path ) diff --git a/scripts/training_ani2x.py b/scripts/training_ani2x.py index 269c5f0c..6c7f229d 100644 --- a/scripts/training_ani2x.py +++ b/scripts/training_ani2x.py @@ -5,6 +5,7 @@ # import the models implemented in modelforge, for now SchNet, PaiNN, ANI2x or PhysNet from modelforge.potential import NeuralNetworkPotentialFactory from modelforge.dataset.ani2x import ANI2xDataset + from modelforge.dataset.dataset import TorchDataModule from modelforge.dataset.utils import RandomRecordSplittingStrategy from pytorch_lightning.loggers import TensorBoardLogger From 0d2f2161580a45c0721171b8a47a57cf72170c98 Mon Sep 17 00:00:00 2001 From: chrisiacovella Date: Thu, 11 Apr 2024 22:19:47 -0700 Subject: [PATCH 18/37] typo --- modelforge/dataset/ani1x.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modelforge/dataset/ani1x.py b/modelforge/dataset/ani1x.py index 1a367def..5247caae 100644 --- a/modelforge/dataset/ani1x.py +++ b/modelforge/dataset/ani1x.py @@ -5,7 +5,7 @@ class ANI1xDataset(HDF5Dataset): """ - Data class for handling ANI1x data. + Data class for handling ANI1x dataset. This dataset includes ~5 million density function theory calculations for small organic molecules containing H, C, N, and O. From 1d4efbe35f06d0cca66a238bffc92799bb4adc10 Mon Sep 17 00:00:00 2001 From: Marcus Wieder <31651017+wiederm@users.noreply.github.com> Date: Fri, 19 Apr 2024 13:15:37 +0200 Subject: [PATCH 19/37] Update training_ani2x.py small changes to script --- scripts/training_ani2x.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/scripts/training_ani2x.py b/scripts/training_ani2x.py index 6c7f229d..f73e724f 100644 --- a/scripts/training_ani2x.py +++ b/scripts/training_ani2x.py @@ -1,8 +1,8 @@ -# This is an example script that trains an implemented model on the QM9 dataset. +# This is an example script that trains the ANI2x model on the ANI2x dataset. from lightning import Trainer import torch -# import the models implemented in modelforge, for now SchNet, PaiNN, ANI2x or PhysNet +# import the dataset and model factory from modelforge.potential import NeuralNetworkPotentialFactory from modelforge.dataset.ani2x import ANI2xDataset @@ -24,9 +24,6 @@ # Set up model model = NeuralNetworkPotentialFactory.create_nnp("training", "ANI2x") -model = model.to(torch.float32) - -print(model) # set up traininer from lightning.pytorch.callbacks.early_stopping import EarlyStopping From e242a2108cedfdc1bd67376255c0283bd545cee1 Mon Sep 17 00:00:00 2001 From: chrisiacovella Date: Mon, 22 Apr 2024 14:27:24 -0700 Subject: [PATCH 20/37] updating based on comments. --- modelforge/curation/ani1x_curation.py | 1 - modelforge/curation/ani2x_curation.py | 1 - modelforge/curation/curation_baseclass.py | 4 +- modelforge/curation/model_dataset.py | 410 ---------------------- modelforge/curation/qm9_curation.py | 2 +- modelforge/curation/spice_114_curation.py | 1 - modelforge/dataset/dataset.py | 6 +- modelforge/dataset/model_dataset.py | 238 ------------- modelforge/potential/models.py | 21 +- modelforge/potential/utils.py | 6 +- modelforge/tests/model_datasets.py | 38 -- modelforge/utils/units.py | 18 +- 12 files changed, 38 insertions(+), 708 deletions(-) delete mode 100644 modelforge/curation/model_dataset.py delete mode 100644 modelforge/dataset/model_dataset.py delete mode 100644 modelforge/tests/model_datasets.py diff --git a/modelforge/curation/ani1x_curation.py b/modelforge/curation/ani1x_curation.py index 169a5b4c..4b8975ce 100644 --- a/modelforge/curation/ani1x_curation.py +++ b/modelforge/curation/ani1x_curation.py @@ -1,5 +1,4 @@ from modelforge.curation.curation_baseclass import DatasetCuration -from modelforge.utils.units import * from typing import Optional from loguru import logger diff --git a/modelforge/curation/ani2x_curation.py b/modelforge/curation/ani2x_curation.py index d12cff50..81edc076 100644 --- a/modelforge/curation/ani2x_curation.py +++ b/modelforge/curation/ani2x_curation.py @@ -1,5 +1,4 @@ from modelforge.curation.curation_baseclass import DatasetCuration -from modelforge.utils.units import * from typing import Optional from loguru import logger diff --git a/modelforge/curation/curation_baseclass.py b/modelforge/curation/curation_baseclass.py index 16d8d40c..58ac54e5 100644 --- a/modelforge/curation/curation_baseclass.py +++ b/modelforge/curation/curation_baseclass.py @@ -2,7 +2,6 @@ from typing import Dict, List, Optional from loguru import logger -from modelforge.utils.units import * def dict_to_hdf5( @@ -151,6 +150,9 @@ def _convert_units(self): """ import pint + # this is needed for the "chem" context to convert hartrees to kj/mol + from modelforge.utils.units import chem_context + for datapoint in self.data: for key, val in datapoint.items(): if isinstance(val, pint.Quantity): diff --git a/modelforge/curation/model_dataset.py b/modelforge/curation/model_dataset.py deleted file mode 100644 index fad3d5a0..00000000 --- a/modelforge/curation/model_dataset.py +++ /dev/null @@ -1,410 +0,0 @@ -from modelforge.curation.curation_baseclass import DatasetCuration, dict_to_hdf5 -from modelforge.utils.units import * - -import numpy as np - -from typing import Optional, List -from loguru import logger - - -class ModelDataset(DatasetCuration): - """ - Routines to fetch and process the model dataset used for examining different approaches to generating - training data. - - - """ - - def __init__( - self, - hdf5_file_name: str, - output_file_dir: str, - local_cache_dir: str, - convert_units: bool = True, - seed=12345, - ): - super().__init__( - hdf5_file_name=hdf5_file_name, - output_file_dir=output_file_dir, - local_cache_dir=local_cache_dir, - convert_units=convert_units, - ) - self.seed = seed - - def _init_dataset_parameters(self): - self.qm_parameters = { - "geometry": {"u_in": unit.nanometer, "u_out": unit.nanometer}, - "energy": { - "u_in": unit.kilojoule_per_mole, - "u_out": unit.kilojoule_per_mole, - }, - } - - def _init_record_entries_series(self): - self._record_entries_series = { - "name": "single_rec", - "n_configs": "single_rec", - "atomic_numbers": "single_atom", - "geometry": "series_atom", - "energy": "series_mol", - } - - def _process_downloaded( - self, - local_path_dir: str, - filename: str, - model: str, - ): - file_path = f"{local_path_dir}/{filename}" - - import h5py - - data_temp = [] - with h5py.File(file_path, "r") as f: - molecule_names = list(f.keys()) - for molecule_name in molecule_names: - record_temp = {} - molecule = f[molecule_name] - for key in molecule.keys(): - temp = molecule[key][()] - if "u" in molecule[key].attrs: - temp = temp * unit(molecule[key].attrs["u"]) - record_temp[key] = temp - record_temp["name"] = molecule_name - data_temp.append(record_temp) - - self.data = [] - self.test_data_molecules = [] - self.test_data_conformers = [] - - # figure out how which molecules we have in our holdout set - # we will keep 10 % of the data for testing - n_molecules = len(data_temp) - from numpy.random import RandomState - - prng = RandomState(self.seed) - hold_out = prng.randint(n_molecules, size=(int(n_molecules * 0.1))) - - if model == "PURE_MM": - for i, record in enumerate(data_temp): - temp = {} - temp["name"] = record["name"] - temp["atomic_numbers"] = record["atomic_numbers"].reshape(-1, 1) - - temp_conf_holdout = {} - temp_conf_holdout["name"] = record["name"] - temp_conf_holdout["atomic_numbers"] = record["atomic_numbers"].reshape( - -1, 1 - ) - - if i in hold_out: - temp["energy"] = ( - np.vstack( - ( - np.array(record["MM_emin_ML_energy"].m).reshape(-1, 1), - record["MM_300_ML_energy"].m.reshape(-1, 1), - record["MM_100_ML_energy"].m.reshape(-1, 1), - ) - ) - * record["MM_emin_ML_energy"].u - ) - temp["geometry"] = ( - np.vstack( - ( - record["MM_emin_coords"].m.reshape(1, -1, 3), - record["MM_coords_300"].m, - record["MM_coords_100"].m, - ) - ) - * record["MM_emin_coords"].u - ) - temp["n_configs"] = temp["geometry"].m.shape[0] - self.test_data_molecules.append(temp) - else: - temp["energy"] = ( - np.vstack( - ( - np.array(record["MM_emin_ML_energy"].m).reshape(-1, 1), - record["MM_300_ML_energy"][0:9].m.reshape(-1, 1), - record["MM_100_ML_energy"][0:9].m.reshape(-1, 1), - ) - ) - * record["MM_emin_ML_energy"].u - ) - temp["geometry"] = ( - np.vstack( - ( - record["MM_emin_coords"].m.reshape(1, -1, 3), - record["MM_coords_300"][0:9].m, - record["MM_coords_100"][0:9].m, - ) - ) - * record["MM_emin_coords"].u - ) - temp["n_configs"] = temp["geometry"].m.shape[0] - - self.data.append(temp) - - temp_conf_holdout["energy"] = ( - np.vstack( - ( - record["MM_300_ML_energy"][9:10].m.reshape(-1, 1), - record["MM_100_ML_energy"][9:10].m.reshape(-1, 1), - ) - ) - * record["MM_300_ML_energy"].u - ) - temp_conf_holdout["geometry"] = ( - np.vstack( - ( - record["MM_coords_300"][9:10].m, - record["MM_coords_100"][9:10].m, - ) - ) - * record["MM_emin_coords"].u - ) - temp_conf_holdout["n_configs"] = temp_conf_holdout[ - "geometry" - ].m.shape[0] - self.test_data_conformers.append(temp_conf_holdout) - - if model == "PURE_MM_low_temp_correction": - for i, record in enumerate(data_temp): - temp = {} - temp["name"] = record["name"] - temp["atomic_numbers"] = record["atomic_numbers"].reshape(-1, 1) - - temp_conf_holdout = {} - temp_conf_holdout["name"] = record["name"] - temp_conf_holdout["atomic_numbers"] = record["atomic_numbers"].reshape( - -1, 1 - ) - - if i in hold_out: - temp["energy"] = ( - np.vstack( - ( - np.array(record["MM_emin_ML_energy"].m).reshape(-1, 1), - record["MM_300_ML_energy"].m.reshape(-1, 1), - record["MM_100_ML_energy"].m.reshape(-1, 1), - record["MM100_ML_emin_ML_energy"].m.reshape(-1, 1), - ) - ) - * record["MM_300_ML_energy"].u - ) - temp["geometry"] = ( - np.vstack( - ( - record["MM_emin_coords"].m.reshape(1, -1, 3), - record["MM_coords_300"].m, - record["MM_coords_100"].m, - record["MM100_ML_emin_coords"].m, - ) - ) - * record["MM_emin_coords"].u - ) - temp["n_configs"] = temp["geometry"].m.shape[0] - self.test_data_molecules.append(temp) - else: - temp["energy"] = ( - np.vstack( - ( - np.array(record["MM_emin_ML_energy"].m).reshape(-1, 1), - record["MM_300_ML_energy"][0:9].m.reshape(-1, 1), - record["MM_100_ML_energy"][0:9].m.reshape(-1, 1), - record["MM100_ML_emin_ML_energy"][0:9].m.reshape(-1, 1), - ) - ) - * record["MM_emin_ML_energy"].u - ) - temp["geometry"] = ( - np.vstack( - ( - record["MM_emin_coords"].m.reshape(1, -1, 3), - record["MM_coords_300"][0:9].m, - record["MM_coords_100"][0:9].m, - record["MM100_ML_emin_coords"][0:9].m, - ) - ) - * record["MM_emin_coords"].u - ) - temp["n_configs"] = temp["geometry"].m.shape[0] - self.data.append(temp) - - temp_conf_holdout["energy"] = ( - np.vstack( - ( - record["MM_300_ML_energy"][9:10].m.reshape(-1, 1), - record["MM_100_ML_energy"][9:10].m.reshape(-1, 1), - record["MM100_ML_emin_ML_energy"][9:10].m.reshape( - -1, 1 - ), - ) - ) - * record["MM_300_ML_energy"].u - ) - temp_conf_holdout["geometry"] = ( - np.vstack( - ( - record["MM_coords_300"][9:10].m, - record["MM_coords_100"][9:10].m, - record["MM100_ML_emin_coords"][9:10].m, - ) - ) - * record["MM_coords_300"].u - ) - temp_conf_holdout["n_configs"] = temp_conf_holdout[ - "geometry" - ].shape[0] - self.test_data_conformers.append(temp_conf_holdout) - - if model == "PURE_ML": - for i, record in enumerate(data_temp): - temp = {} - temp["name"] = record["name"] - temp["atomic_numbers"] = record["atomic_numbers"].reshape(-1, 1) - - temp_conf_holdout = {} - temp_conf_holdout["name"] = record["name"] - temp_conf_holdout["atomic_numbers"] = record["atomic_numbers"].reshape( - -1, 1 - ) - - if i in hold_out: - temp["energy"] = ( - np.vstack( - ( - np.array(record["ML_emin_ML_energy"].m).reshape(-1, 1), - record["ML_300_ML_energy"].m.reshape(-1, 1), - record["ML_100_ML_energy"].m.reshape(-1, 1), - ) - ) - * record["ML_emin_ML_energy"].u - ) - temp["geometry"] = ( - np.vstack( - ( - record["ML_emin_coords"].m.reshape(1, -1, 3), - record["ML_coords_300"].m, - record["ML_coords_100"].m, - ) - ) - * record["ML_emin_coords"].u - ) - temp["n_configs"] = temp["geometry"].shape[0] - self.test_data_molecules.append(temp) - else: - temp["energy"] = ( - np.vstack( - ( - np.array(record["ML_emin_ML_energy"].m).reshape(-1, 1), - record["ML_300_ML_energy"][0:9].m.reshape(-1, 1), - record["ML_100_ML_energy"][0:9].m.reshape(-1, 1), - ) - ) - * record["ML_emin_ML_energy"].u - ) - temp["geometry"] = ( - np.vstack( - ( - record["ML_emin_coords"].m.reshape(1, -1, 3), - record["ML_coords_300"][0:9].m, - record["ML_coords_100"][0:9].m, - ) - ) - * record["ML_emin_coords"].u - ) - temp["n_configs"] = temp["geometry"].m.shape[0] - self.data.append(temp) - - temp_conf_holdout["energy"] = ( - np.vstack( - ( - record["ML_300_ML_energy"][9:10].m.reshape(-1, 1), - record["ML_100_ML_energy"][9:10].m.reshape(-1, 1), - ) - ) - * record["ML_300_ML_energy"].u - ) - temp_conf_holdout["geometry"] = ( - np.vstack( - ( - record["ML_coords_300"][9:10].m, - record["ML_coords_100"][9:10].m, - ) - ) - * record["ML_coords_300"].u - ) - temp_conf_holdout["n_configs"] = temp_conf_holdout[ - "geometry" - ].shape[0] - self.test_data_conformers.append(temp_conf_holdout) - - def _generate_hdf5_file(self, data, output_file_path, filename): - full_file_path = f"{output_file_path}/{filename}" - logger.debug("Writing data HDF5 file.") - import os - - os.makedirs(output_file_path, exist_ok=True) - - dict_to_hdf5( - full_file_path, - data, - series_info=self._record_entries_series, - id_key="name", - ) - - def process( - self, - # input_data_path="./", - # input_data_file="molecule_data.hdf5", - force_download=False, - data_combination="PURE_MM", - ) -> None: - """ - Process the dataset into a curated hdf5 file. - - Parameters - ---------- - force_download : Optional[bool], optional - Force download of the dataset, by default False - data_combination : str, optional - The type of data combination to use, by default "pure_MM" - Options, PURE_MM_low_temp_correction, PURE_MM, PURE_ML - - - """ - from modelforge.utils.remote import download_from_url - - # download the data - url = "https://www.dropbox.com/scl/fi/c23o54ckovnz6umd3why2/molecule_data.hdf5?rlkey=384kd8zo9w1iv34lzp3c2y3n3&dl=1" - checksum = "77a76f7005249aebe61b57a560a818f4" - - download_from_url( - url, - md5_checksum=checksum, - output_path=self.local_cache_dir, - output_filename="molecule_data.hdf5", - force_download=force_download, - ) - self.data_combination = data_combination - self._clear_data() - self._process_downloaded( - self.local_cache_dir, "molecule_data.hdf5", self.data_combination - ) - if self.convert_units: - self._convert_units() - - # for datapoint in self.data: - # print(datapoint["name"]) - - self._generate_hdf5_file(self.data, self.output_file_dir, self.hdf5_file_name) - - fileout = self.hdf5_file_name.replace(".hdf5", "_test_conformers.hdf5") - self._generate_hdf5_file( - self.test_data_conformers, self.output_file_dir, fileout - ) - fileout = self.hdf5_file_name.replace(".hdf5", "_test_molecules.hdf5") - self._generate_hdf5_file( - self.test_data_molecules, self.output_file_dir, fileout - ) diff --git a/modelforge/curation/qm9_curation.py b/modelforge/curation/qm9_curation.py index 7ab611c7..e450c3af 100644 --- a/modelforge/curation/qm9_curation.py +++ b/modelforge/curation/qm9_curation.py @@ -1,5 +1,5 @@ from modelforge.curation.curation_baseclass import DatasetCuration -from modelforge.utils.units import * +from modelforge.utils.units import chem_context import numpy as np from typing import Optional, List diff --git a/modelforge/curation/spice_114_curation.py b/modelforge/curation/spice_114_curation.py index 890df121..da9e76f0 100644 --- a/modelforge/curation/spice_114_curation.py +++ b/modelforge/curation/spice_114_curation.py @@ -1,5 +1,4 @@ from modelforge.curation.curation_baseclass import DatasetCuration -from modelforge.utils.units import * from typing import Optional from loguru import logger diff --git a/modelforge/dataset/dataset.py b/modelforge/dataset/dataset.py index 6ed3813e..364ec9c6 100644 --- a/modelforge/dataset/dataset.py +++ b/modelforge/dataset/dataset.py @@ -90,11 +90,6 @@ def __init__( dataset[property_name.R].shape ) - print("Z", self.properties_of_interest["atomic_numbers"].shape) - print("R", self.properties_of_interest["positions"].shape) - print("E", self.properties_of_interest["E"].shape) - print("Q", self.properties_of_interest["Q"].shape) - print("F", self.properties_of_interest["F"].shape) self.number_of_records = len(dataset["atomic_subsystem_counts"]) self.number_of_atoms = len(dataset["atomic_numbers"]) single_atom_start_idxs_by_rec = np.concatenate( @@ -869,6 +864,7 @@ def prepare_data( # remove self energies self.subtract_self_energies(torch_dataset, self_energies) + # write the self energies that are removed from the dataset to disk import toml diff --git a/modelforge/dataset/model_dataset.py b/modelforge/dataset/model_dataset.py deleted file mode 100644 index 800c3dd2..00000000 --- a/modelforge/dataset/model_dataset.py +++ /dev/null @@ -1,238 +0,0 @@ -from typing import List - -from .dataset import HDF5Dataset - - -class ModelDataset(HDF5Dataset): - """ - Data class for handling the model data generated for the AlkEthOH dataset. - - Attributes - ---------- - dataset_name : str - Name of the dataset, default is "ANI2x". - for_unit_testing : bool - If set to True, a subset of the dataset is used for unit testing purposes; by default False. - local_cache_dir: str, optional - Path to the local cache directory, by default ".". - Examples - -------- - - """ - - from modelforge.utils import PropertyNames - - _property_names = PropertyNames(Z="atomic_numbers", R="geometry", E="energy") - - _available_properties = [ - "geometry", - "atomic_numbers", - "energy", - ] # All properties within the datafile, aside from SMILES/inchi. - - def __init__( - self, - dataset_name: str = "ModelDataset", - # for_unit_testing: bool = False, - data_combination: str = "PURE_MM", - local_cache_dir: str = ".", - force_download: bool = False, - regenerate_cache: bool = False, - ) -> None: - """ - Initialize the ANI2xDataset class. - - Parameters - ---------- - data_name : str, optional - Name of the dataset, by default "ANI2x". - data_combination : str, optional - The type of data combination to use, by default "PURE_MM" - Options, MM_low_temp_correction, PURE_MM, PURE_ML - local_cache_dir: str, optional - Path to the local cache directory, by default ".". - force_download: bool, optional - If set to True, we will download the dataset even if it already exists; by default False. - regenerate_cache: bool, optional - If set to True, we will regenerate the npz cache file even if it already exists, using - the data from the hdf5 file; by default False. - Examples - -------- - >>> data = ModelDataset() # Default dataset - >>> test_data = ModelDataset() - """ - - _default_properties_of_interest = [ - "geometry", - "atomic_numbers", - "energy", - ] # NOTE: Default values - - self._properties_of_interest = _default_properties_of_interest - self.dataset_name = f"{dataset_name}_{data_combination}" - - self.data_combination = data_combination - from openff.units import unit - - # these come from the ANI-2x paper generated via linear fittingh of the data - # https://github.com/isayev/ASE_ANI/blob/master/ani_models/ani-2x_8x/sae_linfit.dat - self._ase = { - "H": -0.5978583943827134 * unit.hartree, - "C": -38.08933878049795 * unit.hartree, - "N": -54.711968298621066 * unit.hartree, - "O": -75.19106774742086 * unit.hartree, - "S": -398.1577125334925 * unit.hartree, - "F": -99.80348506781634 * unit.hartree, - "Cl": -460.1681939421027 * unit.hartree, - } - from loguru import logger - - # We need to define the checksums for the various files that we will be dealing with to load up the data - # There are 3 files types that need name/checksum defined, of extensions hdf5.gz, hdf5, and npz. - - # note, need to change the end of the url to dl=1 instead of dl=0 (the default when you grab the share list), to ensure the same checksum each time we download - self.PURE_MM_url = "https://www.dropbox.com/scl/fi/pq6d2px51o29pegi19z7m/PURE_MM.hdf5.gz?rlkey=9tjbdsvthj9f5zfar4zfb9joo&dl=1" - self.PURE_ML_url = "https://www.dropbox.com/scl/fi/6mf8recfxd10zf1za9xjq/PURE_ML.hdf5.gz?rlkey=2xvvrcd2nbeiw7ma70hq4nui4&dl=1" - self.MM_low_temp_correction_url = "https://www.dropbox.com/scl/fi/h7xowf0v63yszfstsftpc/MM_low_e_correction.hdf5.gz?rlkey=c8u5q212lv2ikre6pukzdakzp&dl=1" - - if self.data_combination == "PURE_MM": - url = self.PURE_MM_url - gz_data_file = { - "name": "PURE_MM_dataset.hdf5.gz", - "md5": "869441523f826fcc4af7e1ecaca13772", - } - hdf5_data_file = { - "name": "PURE_MM_dataset.hdf5", - "md5": "3921bd738d963cc5d26d581faa9bbd36", - } - processed_data_file = {"name": "PURE_MM_dataset_processed.npz", "md5": None} - - logger.info("Using PURE MM dataset") - - elif self.data_combination == "PURE_ML": - url = self.PURE_ML_url - gz_data_file = { - "name": "PURE_ML_dataset.hdf5.gz", - "md5": "ff0ab16f4503e2537ed4bb10a0a6f465", - } - - hdf5_data_file = { - "name": "PURE_ML_dataset.hdf5", - "md5": "a968d6ee74a0dbcede25c98aaa7a33e7", - } - - processed_data_file = { - "name": "PURE_ML_dataset_processed.npz", - "md5": None, - } - - logger.info("Using PURE ML dataset") - elif self.data_combination == "MM_low_temp_correction": - url = self.MM_low_temp_correction_url - gz_data_file = { - "name": "MM_LTC_dataset.hdf5.gz", - "md5": "0c7dbc7636afe845f128c57dbc99f581", - } - - hdf5_data_file = { - "name": "MM_LTC_dataset.hdf5", - "md5": "fb448ea4eaaafaadcce62a2123cb8c1f", - } - - processed_data_file = { - "name": "MM_LTC_dataset_processed.npz", - "md5": None, - } - - logger.info("Using MM low temperature correction dataset") - - # to ensure that that we are consistent in our naming, we need to set all the names and checksums in the HDF5Dataset class constructor - super().__init__( - url=url, - gz_data_file=gz_data_file, - hdf5_data_file=hdf5_data_file, - processed_data_file=processed_data_file, - local_cache_dir=local_cache_dir, - force_download=force_download, - regenerate_cache=regenerate_cache, - ) - - @property - def atomic_self_energies(self): - from modelforge.potential.utils import AtomicSelfEnergies - - return AtomicSelfEnergies(energies=self._ase) - - @property - def properties_of_interest(self) -> List[str]: - """ - Getter for the properties of interest. - The order of this list determines also the order provided in the __getitem__ call - from the PytorchDataset. - - Returns - ------- - List[str] - List of properties of interest. - - """ - return self._properties_of_interest - - @property - def available_properties(self) -> List[str]: - """ - List of available properties in the dataset. - - Returns - ------- - List[str] - List of available properties in the dataset. - - Examples - -------- - - """ - return self._available_properties - - @properties_of_interest.setter - def properties_of_interest(self, properties_of_interest: List[str]) -> None: - """ - Setter for the properties of interest. - The order of this list determines also the order provided in the __getitem__ call - from the PytorchDataset - - Parameters - ---------- - properties_of_interest : List[str] - List of properties of interest. - - Examples - -------- - - """ - if not set(properties_of_interest).issubset(self._available_properties): - raise ValueError( - f"Properties of interest must be a subset of {self._available_properties}" - ) - self._properties_of_interest = properties_of_interest - - def _download(self) -> None: - """ - Download the hdf5 file containing the data from Dropbox. - - Examples - -------- - - - """ - # Right now this function needs to be defined for each dataset. - # once all datasets are moved to zenodo, we should only need a single function defined in the base class - from modelforge.utils.remote import download_from_url - - download_from_url( - url=self.url, - md5_checksum=self.gz_data_file["md5"], - output_path=self.local_cache_dir, - output_filename=self.gz_data_file["name"], - force_download=self.force_download, - ) diff --git a/modelforge/potential/models.py b/modelforge/potential/models.py index 9d618987..223d52b3 100644 --- a/modelforge/potential/models.py +++ b/modelforge/potential/models.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Dict, NamedTuple, Tuple, Type +from typing import TYPE_CHECKING, Any, Dict, NamedTuple, Tuple, Type, Mapping, Union import lightning as pl import torch @@ -543,6 +543,25 @@ def _forward( """ pass + def load_state_dict( + self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False + ): + # Prefix to remove + prefix = "model." + + # check if prefix is present + if any(key.startswith(prefix) for key in state_dict.keys()): + # Create a new dictionary without the prefix in the keys if prefix exists + new_d = { + key[len(prefix) :] if key.startswith(prefix) else key: value + for key, value in state_dict.items() + } + log.debug(f"Removed prefix: {prefix}") + else: + log.debug("No prefix found. No modifications to keys in state loading.") + + super().load_state_dict(new_d, strict=strict, assign=assign) + def load_pretrained_weights(self, path: str): """ Loads pretrained weights into the model from the specified path. diff --git a/modelforge/potential/utils.py b/modelforge/potential/utils.py index 7d2e5b6a..b6139eb3 100644 --- a/modelforge/potential/utils.py +++ b/modelforge/potential/utils.py @@ -5,8 +5,6 @@ import torch import torch.nn as nn from loguru import logger as log -from modelforge.utils.units import * - @dataclass @@ -463,6 +461,8 @@ class AtomicSelfEnergies: _ase_tensor_for_indexing = None def __getitem__(self, key): + from modelforge.utils.units import chem_context + if isinstance(key, int): # Convert atomic number to element symbol element = self.atomic_number_to_element.get(key) @@ -486,6 +486,8 @@ def __getitem__(self, key): def __iter__(self) -> Iterator[Dict[str, float]]: """Iterate over the energies dictionary.""" + from modelforge.utils.units import chem_context + for element, energy in self.energies.items(): atomic_number = self.element_to_atomic_number(element) yield (atomic_number, energy.to(unit.kilojoule_per_mole, "chem").m) diff --git a/modelforge/tests/model_datasets.py b/modelforge/tests/model_datasets.py deleted file mode 100644 index ce6d0665..00000000 --- a/modelforge/tests/model_datasets.py +++ /dev/null @@ -1,38 +0,0 @@ -from modelforge.curation.model_dataset import ModelDataset - -dataset = ModelDataset( - hdf5_file_name="PURE_MM.hdf5", - output_file_dir="/Users/cri/Dropbox/data_experiment/", - local_cache_dir="/Users/cri/Dropbox/data_experiment/", - convert_units=True, -) -dataset.process( - input_data_path="/Users/cri/Dropbox/data_experiment/", - input_data_file="molecule_data.hdf5", - data_combination="PURE_MM", -) - -dataset = ModelDataset( - hdf5_file_name="PURE_ML.hdf5", - output_file_dir="/Users/cri/Dropbox/data_experiment/", - local_cache_dir="/Users/cri/Dropbox/data_experiment/", - convert_units=True, -) -dataset.process( - input_data_path="/Users/cri/Dropbox/data_experiment/", - input_data_file="molecule_data.hdf5", - data_combination="PURE_ML", -) - - -dataset = ModelDataset( - hdf5_file_name="MM_low_e_correction.hdf5", - output_file_dir="/Users/cri/Dropbox/data_experiment/", - local_cache_dir="/Users/cri/Dropbox/data_experiment/", - convert_units=True, -) -dataset.process( - input_data_path="/Users/cri/Dropbox/data_experiment/", - input_data_file="molecule_data.hdf5", - data_combination="PURE_MM_low_temp_correction", -) diff --git a/modelforge/utils/units.py b/modelforge/utils/units.py index 9650e1b4..a93a992d 100644 --- a/modelforge/utils/units.py +++ b/modelforge/utils/units.py @@ -2,42 +2,42 @@ # define new context for converting energy (e.g., hartree) # to energy/mol (e.g., kJ/mol) - -c = unit.Context("chem") -c.add_transformation( +__all__ = ["chem_context"] +chem_context = unit.Context("chem") +chem_context.add_transformation( "[force] * [length]", "[force] * [length]/[substance]", lambda unit, x: x * unit.avogadro_constant, ) -c.add_transformation( +chem_context.add_transformation( "[force] * [length]/[substance]", "[force] * [length]", lambda unit, x: x / unit.avogadro_constant, ) -c.add_transformation( +chem_context.add_transformation( "[force] * [length]/[length]", "[force] * [length]/[substance]/[length]", lambda unit, x: x * unit.avogadro_constant, ) -c.add_transformation( +chem_context.add_transformation( "[force] * [length]/[substance]/[length]", "[force] * [length]/[length]", lambda unit, x: x / unit.avogadro_constant, ) -c.add_transformation( +chem_context.add_transformation( "[force] * [length]/[length]/[length]", "[force] * [length]/[substance]/[length]/[length]", lambda unit, x: x * unit.avogadro_constant, ) -c.add_transformation( +chem_context.add_transformation( "[force] * [length]/[substance]/[length]/[length]", "[force] * [length]/[length]/[length]", lambda unit, x: x / unit.avogadro_constant, ) -unit.add_context(c) +unit.add_context(chem_context) def print_modelforge_unit_system(): From de53603407a734cd8ec6620626f35676cf6d553c Mon Sep 17 00:00:00 2001 From: chrisiacovella Date: Mon, 22 Apr 2024 14:49:58 -0700 Subject: [PATCH 21/37] Accidentically deleted some imports. Readded those in. --- modelforge/curation/ani1x_curation.py | 1 + modelforge/curation/ani2x_curation.py | 1 + modelforge/curation/curation_baseclass.py | 1 + modelforge/curation/qm9_curation.py | 1 + modelforge/curation/spice_114_curation.py | 1 + modelforge/curation/spice_2_curation.py | 5 ++++- modelforge/curation/spice_openff_curation.py | 1 + 7 files changed, 10 insertions(+), 1 deletion(-) diff --git a/modelforge/curation/ani1x_curation.py b/modelforge/curation/ani1x_curation.py index 4b8975ce..2963c3ff 100644 --- a/modelforge/curation/ani1x_curation.py +++ b/modelforge/curation/ani1x_curation.py @@ -1,6 +1,7 @@ from modelforge.curation.curation_baseclass import DatasetCuration from typing import Optional from loguru import logger +from openff.units import unit class ANI1xCuration(DatasetCuration): diff --git a/modelforge/curation/ani2x_curation.py b/modelforge/curation/ani2x_curation.py index 81edc076..ea7c44da 100644 --- a/modelforge/curation/ani2x_curation.py +++ b/modelforge/curation/ani2x_curation.py @@ -1,6 +1,7 @@ from modelforge.curation.curation_baseclass import DatasetCuration from typing import Optional from loguru import logger +from openff.units import unit class ANI2xCuration(DatasetCuration): diff --git a/modelforge/curation/curation_baseclass.py b/modelforge/curation/curation_baseclass.py index 58ac54e5..2b5976fe 100644 --- a/modelforge/curation/curation_baseclass.py +++ b/modelforge/curation/curation_baseclass.py @@ -2,6 +2,7 @@ from typing import Dict, List, Optional from loguru import logger +from openff.units import unit def dict_to_hdf5( diff --git a/modelforge/curation/qm9_curation.py b/modelforge/curation/qm9_curation.py index e450c3af..38981744 100644 --- a/modelforge/curation/qm9_curation.py +++ b/modelforge/curation/qm9_curation.py @@ -4,6 +4,7 @@ from typing import Optional, List from loguru import logger +from openff.units import unit class QM9Curation(DatasetCuration): diff --git a/modelforge/curation/spice_114_curation.py b/modelforge/curation/spice_114_curation.py index da9e76f0..cc474290 100644 --- a/modelforge/curation/spice_114_curation.py +++ b/modelforge/curation/spice_114_curation.py @@ -1,6 +1,7 @@ from modelforge.curation.curation_baseclass import DatasetCuration from typing import Optional from loguru import logger +from openff.units import unit class SPICE114Curation(DatasetCuration): diff --git a/modelforge/curation/spice_2_curation.py b/modelforge/curation/spice_2_curation.py index f4860718..3250d944 100644 --- a/modelforge/curation/spice_2_curation.py +++ b/modelforge/curation/spice_2_curation.py @@ -1,8 +1,11 @@ from typing import List, Optional, Dict, Tuple -from modelforge.curation.curation_baseclass import * +from modelforge.curation.curation_baseclass import DatasetCuration from retry import retry from tqdm import tqdm +from openff.units import unit + +from loguru import logger class SPICE2Curation(DatasetCuration): diff --git a/modelforge/curation/spice_openff_curation.py b/modelforge/curation/spice_openff_curation.py index 30bf2298..432f3fe6 100644 --- a/modelforge/curation/spice_openff_curation.py +++ b/modelforge/curation/spice_openff_curation.py @@ -3,6 +3,7 @@ from modelforge.curation.curation_baseclass import * from retry import retry from tqdm import tqdm +from openff.units import unit class SPICEOpenFFCuration(DatasetCuration): From aed99d756bcf0f4034ec332f256626dfddcd520f Mon Sep 17 00:00:00 2001 From: chrisiacovella Date: Wed, 24 Apr 2024 19:23:50 -0700 Subject: [PATCH 22/37] removed euler angles; put in quaterions for rotational invariance tests. --- modelforge/tests/conftest.py | 163 ++++++++++++++++++++++++++++------- 1 file changed, 133 insertions(+), 30 deletions(-) diff --git a/modelforge/tests/conftest.py b/modelforge/tests/conftest.py index 1f4d8f18..4b9dacff 100644 --- a/modelforge/tests/conftest.py +++ b/modelforge/tests/conftest.py @@ -183,6 +183,128 @@ def methane() -> BatchData: import math +def generate_uniform_quaternion(u=None): + """ + Generates a uniform normalized quaternion. + + Adapted from numpy implementation in openmm-tools + https://github.com/choderalab/openmmtools/blob/main/openmmtools/mcmc.py + + Parameters + ---------- + u : torch.Tensor + Tensor of shape (3,). Optional, default is None. + If not provided, a random tensor is generated. + + References + ---------- + [1] K. Shoemake. Uniform random rotations. In D. Kirk, editor, + Graphics Gems III, pages 124-132. Academic, New York, 1992. + [2] Described briefly here: http://planning.cs.uiuc.edu/node198.html + """ + import torch + + if u is None: + u = torch.rand(3) + # import numpy for pi + import numpy as np + + q = torch.tensor( + [ + torch.sqrt(1 - u[0]) * torch.sin(2 * np.pi * u[1]), + torch.sqrt(1 - u[0]) * torch.cos(2 * np.pi * u[1]), + torch.sqrt(u[0]) * torch.sin(2 * np.pi * u[2]), + torch.sqrt(u[0]) * torch.cos(2 * np.pi * u[2]), + ] + ) + return q + + +def rotation_matrix_from_quaternion(quaternion): + """Compute a 3x3 rotation matrix from a given quaternion (4-vector). + + Adapted from the numpy implementation in openmm-tools + + https://github.com/choderalab/openmmtools/blob/main/openmmtools/mcmc.py + + Parameters + ---------- + q : torch.Tensor + Quaternion tensor of shape (4,). + + Returns + ------- + torch.Tensor + Rotation matrix tensor of shape (3, 3). + + References + ---------- + [1] http://en.wikipedia.org/wiki/Rotation_matrix#Quaternion + """ + + w, x, y, z = quaternion.unbind() + Nq = (quaternion**2).sum() # Squared norm. + if Nq > 0.0: + s = 2.0 / Nq + else: + s = 0.0 + + X = x * s + Y = y * s + Z = z * s + wX = w * X + wY = w * Y + wZ = w * Z + xX = x * X + xY = x * Y + xZ = x * Z + yY = y * Y + yZ = y * Z + zZ = z * Z + + rotation_matrix = torch.tensor( + [ + [1.0 - (yY + zZ), xY - wZ, xZ + wY], + [xY + wZ, 1.0 - (xX + zZ), yZ - wX], + [xZ - wY, yZ + wX, 1.0 - (xX + yY)], + ] + ) + return rotation_matrix + + +def apply_rotation_matrix(coordinates, rotation_matrix, use_center_of_mass=True): + """ + Rotate the coordinates using the rotation matrix. + + Parameters + ---------- + coordinates : torch.Tensor + The coordinates to rotate. + rotation_matrix : torch.Tensor + The rotation matrix. + use_center_of_mass : bool + If True, the coordinates are rotated around the center of mass, not the origin. + + Returns + ------- + torch.Tensor + The rotated coordinates. + """ + + if use_center_of_mass: + coordinates_com = torch.mean(coordinates, 0) + else: + coordinates_com = torch.zeros(3) + + coordinates_proposed = ( + torch.matmul( + rotation_matrix, (coordinates - coordinates_com).transpose(0, -1) + ).transpose(0, -1) + ) + coordinates_com + + return coordinates_proposed + + def equivariance_test_utils(): """ Generates random tensors for testing equivariance of a neural network. @@ -198,42 +320,23 @@ def equivariance_test_utils(): """ # Define translation function - #torch.manual_seed(12345) + # CRI: Let us manually seed the random number generator to ensure that we perfrom the same tests each time. + # While our tests of translation and rotation should ALWAYS pass regardless of the seed, + # if the code is correctly implemented, there may be instances where the tolerance we set is not + # sufficient to pass the test, and without the workflow being deterministic, it may be hard to + # debug if it is an underlying issue with the code or just the tolerance. + + torch.manual_seed(12345) x_translation = torch.randn( size=(1, 3), ) translation = lambda x: x + x_translation - # Define rotation function - alpha = torch.distributions.Uniform(-math.pi, math.pi).sample() - beta = torch.distributions.Uniform(-math.pi, math.pi).sample() - gamma = torch.distributions.Uniform(-math.pi, math.pi).sample() - - rz = torch.tensor( - [ - [math.cos(alpha), -math.sin(alpha), 0], - [math.sin(alpha), math.cos(alpha), 0], - [0, 0, 1], - ] - ) - - ry = torch.tensor( - [ - [math.cos(beta), 0, math.sin(beta)], - [0, 1, 0], - [-math.sin(beta), 0, math.cos(beta)], - ] - ) - - rx = torch.tensor( - [ - [1, 0, 0], - [0, math.cos(gamma), -math.sin(gamma)], - [0, math.sin(gamma), math.cos(gamma)], - ] - ) + # generate random quaternion and rotation matrix + q = generate_uniform_quaternion() + rotation_matrix = rotation_matrix_from_quaternion(q) - rotation = lambda x: x @ rz @ ry @ rx + rotation = lambda x: apply_rotation_matrix(x, rotation_matrix) # Define reflection function alpha = torch.distributions.Uniform(-math.pi, math.pi).sample() From 51f5148bfdb9dab282cd867443face3574788e3d Mon Sep 17 00:00:00 2001 From: chrisiacovella Date: Thu, 25 Apr 2024 12:11:16 -0700 Subject: [PATCH 23/37] Adding in skipping of training testing on Mac OS do to change of runners to Apple Silicon --- modelforge/tests/test_training.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/modelforge/tests/test_training.py b/modelforge/tests/test_training.py index f187c594..3920668c 100644 --- a/modelforge/tests/test_training.py +++ b/modelforge/tests/test_training.py @@ -1,8 +1,10 @@ from typing import Type import pytest +import sys +@pytest.mark.skipif(sys.platform == "darwin", reason="Does not run on macOS") def test_train_with_lightning(train_model, initialized_dataset): """ Test the forward pass for a given model and dataset. @@ -31,6 +33,7 @@ def test_train_with_lightning(train_model, initialized_dataset): ) +@pytest.mark.skipif(sys.platform == "darwin", reason="Does not run on macOS") def test_hypterparameter_tuning_with_ray(train_model, initialized_dataset): train_model.tune_with_ray( From fb15dc92fa7430d208d7dab9a498b856882e17a2 Mon Sep 17 00:00:00 2001 From: chrisiacovella Date: Tue, 30 Apr 2024 00:23:17 -0700 Subject: [PATCH 24/37] Modifying curation scripts to allow a fixed number of total conformers (e.g., limiting to max of 10 per record, for a total of 1000, for unit testing). --- modelforge/curation/ani1x_curation.py | 71 +++++++-- modelforge/curation/ani2x_curation.py | 80 ++++++++-- modelforge/curation/curation_baseclass.py | 44 +++++- modelforge/curation/qm9_curation.py | 72 +++++++-- modelforge/curation/spice_114_curation.py | 69 +++++++-- modelforge/curation/spice_openff_curation.py | 80 ++++++++-- modelforge/dataset/__init__.py | 9 +- modelforge/dataset/dataset.py | 2 +- modelforge/tests/conftest.py | 151 +++++++++++++++++-- modelforge/tests/test_curation.py | 66 ++++---- modelforge/tests/test_dataset.py | 98 ++++++------ scripts/dataset_curation.py | 143 +++++++++++++----- 12 files changed, 688 insertions(+), 197 deletions(-) diff --git a/modelforge/curation/ani1x_curation.py b/modelforge/curation/ani1x_curation.py index 2963c3ff..d62c57d2 100644 --- a/modelforge/curation/ani1x_curation.py +++ b/modelforge/curation/ani1x_curation.py @@ -223,7 +223,9 @@ def _process_downloaded( self, local_path_dir: str, name: str, - unit_testing_max_records: Optional[int] = None, + max_records: Optional[int] = None, + max_conformers_per_record: Optional[int] = None, + total_conformers: Optional[int] = None, ): """ Processes a downloaded dataset: extracts relevant information. @@ -234,8 +236,15 @@ def _process_downloaded( Path to the directory that contains the raw hdf5 datafile name: str, required Name of the raw hdf5 file, - unit_testing_max_records: int, optional, default=None - If set to an integer ('n') the routine will only process the first 'n' records; useful for unit tests. + max_records: int, optional, default=None + If set to an integer, 'n_r', the routine will only process the first 'n_r' records, useful for unit tests. + Can be used in conjunction with max_conformers_per_record and total_conformers. + max_conformers_per_record: int, optional, default=None + If set to an integer, 'n_c', the routine will only process the first 'n_c' conformers per record, useful for unit tests. + Can be used in conjunction with max_records and total_conformers. + total_conformers: int, optional, default=None + If set to an integer, 'n_t', the routine will only process the first 'n_t' conformers in total, useful for unit tests. + Can be used in conjunction with max_records and max_conformers_per_record. Examples -------- @@ -269,14 +278,32 @@ def _process_downloaded( } with h5py.File(input_file_name, "r") as hf: names = list(hf.keys()) - if unit_testing_max_records is None: + if max_records is None: n_max = len(names) - else: - n_max = unit_testing_max_records + elif max_records is not None: + n_max = max_records + + conformers_counter = 0 for i, name in tqdm(enumerate(names[0:n_max]), total=n_max): + if total_conformers is not None: + if conformers_counter >= total_conformers: + break + # Extract the total number of configurations for a given molecule - n_configs = hf[name]["coordinates"].shape[0] + + if max_conformers_per_record is not None: + conformers_per_molecule = min( + hf[name]["coordinates"].shape[0], max_conformers_per_record + ) + else: + conformers_per_molecule = hf[name]["coordinates"].shape[0] + + if total_conformers is not None: + if conformers_counter + conformers_per_molecule > total_conformers: + conformers_per_molecule = total_conformers - conformers_counter + + n_configs = conformers_per_molecule keys_list = list(hf[name].keys()) @@ -300,6 +327,8 @@ def _process_downloaded( if param_in in add_new_axis: temp = temp[..., newaxis] + temp = temp[0:conformers_per_molecule] + param_unit = param_data["u_in"] if param_unit is not None: ani1x_temp[param_out] = temp * param_unit @@ -307,6 +336,8 @@ def _process_downloaded( ani1x_temp[param_out] = temp self.data.append(ani1x_temp) + conformers_counter += conformers_per_molecule + if self.convert_units: self._convert_units() # From documentation: By default, objects inside group are iterated in alphanumeric order. @@ -318,7 +349,9 @@ def _process_downloaded( def process( self, force_download: bool = False, - unit_testing_max_records: Optional[int] = None, + max_records: Optional[int] = None, + max_conformers_per_record: Optional[int] = None, + total_conformers: Optional[int] = None, ) -> None: """ Downloads the dataset, extracts relevant information, and writes an hdf5 file. @@ -328,8 +361,15 @@ def process( force_download: bool, optional, default=False If the raw data_file is present in the local_cache_dir, the local copy will be used. If True, this will force the software to download the data again, even if present. - unit_testing_max_records: int, optional, default=None - If set to an integer, 'n', the routine will only process the first 'n' records, useful for unit tests. + max_records: int, optional, default=None + If set to an integer, 'n_r', the routine will only process the first 'n_r' records, useful for unit tests. + Can be used in conjunction with max_conformers_per_record and total_conformers. + max_conformers_per_record: int, optional, default=None + If set to an integer, 'n_c', the routine will only process the first 'n_c' conformers per record, useful for unit tests. + Can be used in conjunction with max_records and total_conformers. + total_conformers: int, optional, default=None + If set to an integer, 'n_t', the routine will only process the first 'n_t' conformers in total, useful for unit tests. + Can be used in conjunction with max_records and max_conformers_per_record. Examples -------- @@ -338,6 +378,11 @@ def process( >>> ani1_data.process() """ + if max_records is not None and total_conformers is not None: + raise ValueError( + "max_records and total_conformers cannot be set at the same time." + ) + from modelforge.utils.remote import download_from_figshare url = self.dataset_download_url @@ -356,7 +401,11 @@ def process( if self.name is None: raise Exception("Failed to retrieve name of file from figshare.") self._process_downloaded( - self.local_cache_dir, self.name, unit_testing_max_records + self.local_cache_dir, + self.name, + max_records=max_records, + max_conformers_per_record=max_conformers_per_record, + total_conformers=total_conformers, ) self._generate_hdf5() diff --git a/modelforge/curation/ani2x_curation.py b/modelforge/curation/ani2x_curation.py index ea7c44da..a7720e33 100644 --- a/modelforge/curation/ani2x_curation.py +++ b/modelforge/curation/ani2x_curation.py @@ -100,7 +100,9 @@ def _process_downloaded( self, local_path_dir: str, name: str, - unit_testing_max_records: Optional[int] = None, + max_records: Optional[int] = None, + max_conformers_per_record: Optional[int] = None, + total_conformers: Optional[int] = None, ): """ Processes a downloaded dataset: extracts relevant information. @@ -111,8 +113,16 @@ def _process_downloaded( Path to the directory that contains the raw hdf5 datafile name: str, required Name of the raw hdf5 file, - unit_testing_max_records: int, optional, default=None - If set to an integer ('n') the routine will only process the first 'n' records; useful for unit tests. + max_records: int, optional, default=None + If set to an integer, 'n_r', the routine will only process the first 'n_r' records, useful for unit tests. + Can be used in conjunction with max_conformers_per_record. + max_conformers_per_record: int, optional, default=None + If set to an integer, 'n_c', the routine will only process the first 'n_c' conformers per record, useful for unit tests. + Can be used in conjunction with max_records or total_conformers. + total_conformers: int, optional, default=None + If set to an integer, 'n_t', the routine will only process the first 'n_t' conformers in total, useful for unit tests. + Can be used in conjunction with max_conformers_per_record. + Examples -------- @@ -122,6 +132,9 @@ def _process_downloaded( input_file_name = f"{local_path_dir}/{name}" logger.debug(f"Processing {input_file_name}.") + + conformers_counter = 0 + with h5py.File(input_file_name, "r") as hf: # The ani2x hdf5 file groups molecules by number of atoms # we need to break up each of these groups into individual molecules @@ -145,16 +158,24 @@ def _process_downloaded( unique_molecules = np.unique(species, axis=0) - if unit_testing_max_records is None: + if max_records is None: n_max = unique_molecules.shape[0] else: - n_max = min(unit_testing_max_records, unique_molecules.shape[0]) - unit_testing_max_records -= n_max + n_max = min(max_records, unique_molecules.shape[0]) + max_records -= n_max + if n_max == 0: break + for i, molecule in tqdm( enumerate(unique_molecules[0:n_max]), total=n_max ): + # stop processing if we have reached the total number of conformers + + if total_conformers is not None: + if conformers_counter >= total_conformers: + break + ds_temp = {} # molecule represents an aray of atomic species, e.g., [ 8, 8 ] is O_2 # here we will create an array of shape( num_confomer, num_atoms) of bools @@ -174,27 +195,41 @@ def _process_downloaded( ds_temp["name"] = molecule_as_string ds_temp["atomic_numbers"] = molecule.reshape(-1, 1) - ds_temp["n_configs"] = int(np.sum(mask)) + conformers_per_molecule = int(np.sum(mask)) + if max_conformers_per_record is not None: + conformers_per_molecule = min( + conformers_per_molecule, max_conformers_per_record + ) + if total_conformers is not None: + conformers_per_molecule = min( + conformers_per_molecule, + total_conformers - conformers_counter, + ) + ds_temp["n_configs"] = conformers_per_molecule ds_temp["geometry"] = ( coordinates[mask] * self.qm_parameters["geometry"]["u_in"] - ) + )[0:conformers_per_molecule] ds_temp["energies"] = ( energies[mask].reshape(-1, 1) * self.qm_parameters["energies"]["u_in"] - ) + )[0:conformers_per_molecule] ds_temp["forces"] = ( forces[mask] * self.qm_parameters["forces"]["u_in"] - ) + )[0:conformers_per_molecule] self.data.append(ds_temp) + conformers_counter += conformers_per_molecule + if self.convert_units: self._convert_units() def process( self, force_download: bool = False, - unit_testing_max_records: Optional[int] = None, + max_records: Optional[int] = None, + max_conformers_per_record: Optional[int] = None, + total_conformers: Optional[int] = None, ) -> None: """ Downloads the dataset, extracts relevant information, and writes an hdf5 file. @@ -204,8 +239,16 @@ def process( force_download: bool, optional, default=False If the raw data_file is present in the local_cache_dir, the local copy will be used. If True, this will force the software to download the data again, even if present. - unit_testing_max_records: int, optional, default=None - If set to an integer, 'n', the routine will only process the first 'n' records, useful for unit tests. + max_records: int, optional, default=None + If set to an integer, 'n_r', the routine will only process the first 'n_r' records, useful for unit tests. + Can be used in conjunction with max_conformers_per_record. + max_conformers_per_record: int, optional, default=None + If set to an integer, 'n_c', the routine will only process the first 'n_c' conformers per record, useful for unit tests. + Can be used in conjunction with max_records or total_conformers. + total_conformers: int, optional, default=None + If set to an integer, 'n_t', the routine will only process the first 'n_t' conformers in total, useful for unit tests. + Can be used in conjunction with max_conformers_per_record. + Examples -------- @@ -214,6 +257,11 @@ def process( >>> ani2_data.process() """ + if max_records is not None and total_conformers is not None: + raise ValueError( + "max_records and total_conformers cannot be set at the same time." + ) + from modelforge.utils.remote import download_from_zenodo url = self.dataset_download_url @@ -247,7 +295,11 @@ def process( # process the rest of the dataset self._process_downloaded( - f"{self.local_cache_dir}/final_h5/", hdf5_filename, unit_testing_max_records + f"{self.local_cache_dir}/final_h5/", + hdf5_filename, + max_records, + max_conformers_per_record, + total_conformers, ) self._generate_hdf5() diff --git a/modelforge/curation/curation_baseclass.py b/modelforge/curation/curation_baseclass.py index 2b5976fe..3395769a 100644 --- a/modelforge/curation/curation_baseclass.py +++ b/modelforge/curation/curation_baseclass.py @@ -167,11 +167,42 @@ def _convert_units(self): f"could not convert {key} with unit {val.u} to {self.qm_parameters[key]['u_out']}" ) + @property + def total_conformers(self) -> int: + """ + Returns the total number of conformers in the dataset. + + Returns + ------- + int + Total number of conformers in the dataset. + + """ + total = 0 + for record in self.data: + total += record["n_configs"] + return total + + @property + def total_records(self) -> int: + """ + Returns the total number of records in the dataset. + + Returns + ------- + int + Total number of records in the dataset. + + """ + return len(self.data) + @abstractmethod def process( self, force_download: bool = False, - unit_testing_max_records: Optional[int] = None, + max_records: Optional[int] = None, + max_conformers_per_record: Optional[int] = None, + total_conformers: Optional[int] = None, ) -> None: """ Downloads the dataset, extracts relevant information, and writes an hdf5 file. @@ -181,8 +212,15 @@ def process( force_download: bool, optional, default=False If the raw data_file is present in the local_cache_dir, the local copy will be used. If True, this will force the software to download the data again, even if present. - unit_testing_max_records: int, optional, default=None - If set to an integer, 'n', the routine will only process the first 'n' records, useful for unit tests. + max_records: int, optional, default=None + If set to an integer, 'n_r', the routine will only process the first 'n_r' records, useful for unit tests. + Can be used in conjunction with max_conformers_per_record. + max_conformers_per_record: int, optional, default=None + If set to an integer, 'n_c', the routine will only process the first 'n_c' conformers per record, useful for unit tests. + Can be used in conjunction with max_records or total_conformers. + total_conformers: int, optional, default=None + If set to an integer, 'n_t', the routine will only process the first 'n_t' conformers in total, useful for unit tests. + Can be used in conjunction with max_conformers_per_record. Examples -------- diff --git a/modelforge/curation/qm9_curation.py b/modelforge/curation/qm9_curation.py index 38981744..e3a38e62 100644 --- a/modelforge/curation/qm9_curation.py +++ b/modelforge/curation/qm9_curation.py @@ -502,7 +502,9 @@ def _parse_xyzfile(self, file_name: str) -> dict: def _process_downloaded( self, local_path_dir: str, - unit_testing_max_records: Optional[int] = None, + max_records: Optional[int] = None, + max_conformers_per_record: Optional[int] = None, + total_conformers: Optional[int] = None, ): """ Processes a downloaded dataset: extracts relevant information into a list of dicts. @@ -511,8 +513,16 @@ def _process_downloaded( ---------- local_path_dir: str, required Path to the directory that contains the tar.bz2 file. - unit_testing_max_records: int, optional, default=None - If set to an integer, 'n', the routine will only process the first 'n' records, useful for unit tests. + max_records: int, optional, default=None + If set to an integer, 'n_r', the routine will only process the first 'n_r' records, useful for unit tests. + Can be used in conjunction with umax_conformers_per_record and total_conformers. + max_conformers_per_record: int, optional, default=None + If set to an integer, 'n_c', the routine will only process the first 'n_c' conformers per record, useful for unit tests. + Can be used in conjunction with max_records and total_conformers. + total_conformers: int, optional, default=None + If set to an integer, 'n_t', the routine will only process the first 'n_t' conformers in total, useful for unit tests. + Can be used in conjunction with max_records and max_conformers_per_record. + Examples -------- @@ -523,10 +533,33 @@ def _process_downloaded( # list the files in the directory to examine files = list_files(directory=local_path_dir, extension=".xyz") - if unit_testing_max_records is None: + + # qm9 only has a single conformer in it, so unit_test_max_records and unit_testing_max_conformers_per_record behave the same way + + if max_records is None and total_conformers is None: n_max = len(files) - else: - n_max = unit_testing_max_records + elif max_records is not None and total_conformers is None: + if max_records > len(files): + n_max = len(files) + logger.warning( + f"max_records ({max_records})is greater than the number of records in the dataset {len(files)}. Using {len(files)}." + ) + else: + n_max = max_records + elif max_records is None and total_conformers is not None: + if total_conformers > len(files): + n_max = len(files) + logger.warning( + f"total_conformers ({total_conformers}) is greater than the number of records in the dataset {len(files)}. Using {len(files)}." + ) + else: + n_max = total_conformers + + # we do not need to do anything check unit_testing_max_conformers_per_record because qm9 only has a single conformer per record + if max_conformers_per_record is not None: + logger.warning( + "max_conformers_per_record is not used for QM9 dataset as there is only one conformer per record. Using a value of 1" + ) for i, file in enumerate( tqdm(files[0:n_max], desc="processing", total=len(files)) @@ -546,7 +579,9 @@ def _process_downloaded( def process( self, force_download: bool = False, - unit_testing_max_records: Optional[int] = None, + max_records: Optional[int] = None, + max_conformers_per_record: Optional[int] = None, + total_conformers: Optional[int] = None, ) -> None: """ Downloads the dataset, extracts relevant information, and writes an hdf5 file. @@ -556,8 +591,18 @@ def process( force_download: bool, optional, default=False If the raw data_file is present in the local_cache_dir, the local copy will be used. If True, this will force the software to download the data again, even if present. - unit_testing_max_records: int, optional, default=None - If set to an integer, 'n', the routine will only process the first 'n' records, useful for unit tests. + max_records: int, optional, default=None + If set to an integer, 'n_r', the routine will only process the first 'n_r' records, useful for unit tests. + Can be used in conjunction with max_conformers_per_record and total_conformers. + max_conformers_per_record: int, optional, default=None + If set to an integer, 'n_c', the routine will only process the first 'n_c' conformers per record, useful for unit tests. + Can be used in conjunction with max_records and total_conformers. + total_conformers: int, optional, default=None + If set to an integer, 'n_t', the routine will only process the first 'n_t' conformers in total, useful for unit tests. + Can be used in conjunction with max_records and max_conformers_per_record. + + Note for qm9, only a single conformer is present per record, so max_records and total_conformers behave the same way, + and max_conformers_per_record does not alter the behavior (i.e., it is always 1). Examples -------- @@ -565,6 +610,11 @@ def process( >>> qm9_data.process() """ + if max_records is not None and total_conformers is not None: + raise ValueError( + "max_records and total_conformers cannot be set at the same time." + ) + from modelforge.utils.remote import download_from_figshare url = self.dataset_download_url @@ -597,7 +647,9 @@ def process( self._process_downloaded( f"{self.local_cache_dir}/qm9_xyz_files", - unit_testing_max_records, + max_records, + max_conformers_per_record, + total_conformers, ) # generate the hdf5 file diff --git a/modelforge/curation/spice_114_curation.py b/modelforge/curation/spice_114_curation.py index cc474290..28e4c70f 100644 --- a/modelforge/curation/spice_114_curation.py +++ b/modelforge/curation/spice_114_curation.py @@ -176,7 +176,9 @@ def _process_downloaded( self, local_path_dir: str, name: str, - unit_testing_max_records: Optional[int] = None, + max_records: Optional[int] = None, + max_conformers_per_record: Optional[int] = None, + total_conformers: Optional[int] = None, ): """ Processes a downloaded dataset: extracts relevant information. @@ -187,8 +189,15 @@ def _process_downloaded( Path to the directory that contains the raw hdf5 datafile name: str, required Name of the raw hdf5 file, - unit_testing_max_records: int, optional, default=None - If set to an integer ('n') the routine will only process the first 'n' records; useful for unit tests. + max_records: int, optional, default=None + If set to an integer, 'n_r', the routine will only process the first 'n_r' records, useful for unit tests. + Can be used in conjunction with max_conformers_per_record and total_conformers. + max_conformers_per_record: int, optional, default=None + If set to an integer, 'n_c', the routine will only process the first 'n_c' conformers per record, useful for unit tests. + Can be used in conjunction with max_records and total_conformers. + total_conformers: int, optional, default=None + If set to an integer, 'n_t', the routine will only process the first 'n_t' conformers in total, useful for unit tests. + Can be used in conjunction with max_records and max_conformers_per_record. Examples -------- @@ -201,14 +210,20 @@ def _process_downloaded( need_to_reshape = {"formation_energy": True, "dft_total_energy": True} with h5py.File(input_file_name, "r") as hf: names = list(hf.keys()) - if unit_testing_max_records is None: + if max_records is None: n_max = len(names) - else: - n_max = unit_testing_max_records + elif max_records is not None: + n_max = max_records + + conformers_counter = 0 for i, name in tqdm(enumerate(names[0:n_max]), total=n_max): + if total_conformers is not None: + if conformers_counter >= total_conformers: + break + # Extract the total number of conformations for a given molecule - n_configs = hf[name]["conformations"].shape[0] + conformers_per_record = hf[name]["conformations"].shape[0] keys_list = list(hf[name].keys()) @@ -220,7 +235,17 @@ def _process_downloaded( ds_temp["atomic_numbers"] = hf[name]["atomic_numbers"][()].reshape( -1, 1 ) - ds_temp["n_configs"] = n_configs + if max_conformers_per_record is not None: + conformers_per_record = min( + conformers_per_record, + max_conformers_per_record, + ) + if total_conformers is not None: + conformers_per_record = min( + conformers_per_record, total_conformers - conformers_counter + ) + + ds_temp["n_configs"] = conformers_per_record # param_in is the name of the entry, param_data contains input (u_in) and output (u_out) units for param_in, param_data in self.qm_parameters.items(): @@ -234,6 +259,7 @@ def _process_downloaded( if param_in in need_to_reshape: temp = temp.reshape(-1, 1) + temp = temp[0:conformers_per_record] param_unit = param_data["u_in"] if param_unit is not None: # check that units in the hdf5 file match those we have defined in self.qm_parameters @@ -256,6 +282,8 @@ def _process_downloaded( ) ds_temp["dft_total_force"] = -ds_temp["dft_total_gradient"] self.data.append(ds_temp) + conformers_counter += conformers_per_record + if self.convert_units: self._convert_units() @@ -268,7 +296,9 @@ def _process_downloaded( def process( self, force_download: bool = False, - unit_testing_max_records: Optional[int] = None, + max_records: Optional[int] = None, + max_conformers_per_record: Optional[int] = None, + total_conformers: Optional[int] = None, ) -> None: """ Downloads the dataset, extracts relevant information, and writes an hdf5 file. @@ -278,8 +308,15 @@ def process( force_download: bool, optional, default=False If the raw data_file is present in the local_cache_dir, the local copy will be used. If True, this will force the software to download the data again, even if present. - unit_testing_max_records: int, optional, default=None - If set to an integer, 'n', the routine will only process the first 'n' records, useful for unit tests. + max_records: int, optional, default=None + If set to an integer, 'n_r', the routine will only process the first 'n_r' records, useful for unit tests. + Can be used in conjunction with max_conformers_per_record. + max_conformers_per_record: int, optional, default=None + If set to an integer, 'n_c', the routine will only process the first 'n_c' conformers per record, useful for unit tests. + Can be used in conjunction with max_records or total_conformers. + total_conformers: int, optional, default=None + If set to an integer, 'n_t', the routine will only process the first 'n_t' conformers in total, useful for unit tests. + Can be used in conjunction with max_conformers_per_record. Examples -------- @@ -288,6 +325,10 @@ def process( >>> spice114_data.process() """ + if max_records is not None and total_conformers is not None: + raise ValueError( + "max_records and total_conformers cannot be set at the same time." + ) from modelforge.utils.remote import download_from_zenodo url = self.dataset_download_url @@ -306,7 +347,11 @@ def process( if self.name is None: raise Exception("Failed to retrieve name of file from zenodo.") self._process_downloaded( - self.local_cache_dir, self.name, unit_testing_max_records + self.local_cache_dir, + self.name, + max_records, + max_conformers_per_record, + total_conformers, ) self._generate_hdf5() diff --git a/modelforge/curation/spice_openff_curation.py b/modelforge/curation/spice_openff_curation.py index 432f3fe6..4059531a 100644 --- a/modelforge/curation/spice_openff_curation.py +++ b/modelforge/curation/spice_openff_curation.py @@ -189,7 +189,7 @@ def _fetch_singlepoint_from_qcarchive( local_database_name: str, local_path_dir: str, force_download: bool, - unit_testing_max_records: Optional[int] = None, + max_records: Optional[int] = None, pbar: Optional[tqdm] = None, ): """ @@ -231,8 +231,8 @@ def _fetch_singlepoint_from_qcarchive( ds.fetch_entry_names() entry_names = ds.entry_names - if unit_testing_max_records is None: - unit_testing_max_records = len(entry_names) + if max_records is None: + max_records = len(entry_names) with SqliteDict( f"{local_path_dir}/{local_database_name}", tablename=specification_name, @@ -243,10 +243,10 @@ def _fetch_singlepoint_from_qcarchive( db_keys = set(spice_db.keys()) to_fetch = [] if force_download: - for name in entry_names[0:unit_testing_max_records]: + for name in entry_names[0:max_records]: to_fetch.append(name) else: - for name in entry_names[0:unit_testing_max_records]: + for name in entry_names[0:max_records]: if name not in db_keys: to_fetch.append(name) if pbar is not None: @@ -439,6 +439,8 @@ def _process_downloaded( local_path_dir: str, filenames: List[str], dataset_names: List[str], + max_conformers_per_record: Optional[int] = None, + total_conformers: Optional[int] = None, ): """ Processes a downloaded dataset: extracts relevant information. @@ -687,10 +689,47 @@ def _process_downloaded( if self.convert_units: self._convert_units() + if total_conformers is not None or max_conformers_per_record is not None: + conformers_count = 0 + + for datapoint in self.data: + if total_conformers is not None: + if conformers_count >= total_conformers: + break + n_conformers = datapoint["n_configs"] + if max_conformers_per_record is not None: + n_conformers = min(n_conformers, max_conformers_per_record) + + if total_conformers is not None: + n_conformers = min( + n_conformers, total_conformers - conformers_count + ) + + datapoint["n_configs"] = n_conformers + datapoint["geometry"] = datapoint["geometry"][0:n_conformers] + datapoint["dft_total_energy"] = datapoint["dft_total_energy"][ + 0:n_conformers + ] + datapoint["dft_total_gradient"] = datapoint["dft_total_gradient"][ + 0:n_conformers + ] + datapoint["dft_total_force"] = datapoint["dft_total_force"][ + 0:n_conformers + ] + datapoint["formation_energy"] = datapoint["formation_energy"][ + 0:n_conformers + ] + datapoint["mbis_charges"] = datapoint["mbis_charges"][0:n_conformers] + datapoint["scf_dipole"] = datapoint["scf_dipole"][0:n_conformers] + + conformers_count += n_conformers + def process( self, force_download: bool = False, - unit_testing_max_records: Optional[int] = None, + max_records: Optional[int] = None, + max_conformers_per_record: Optional[int] = None, + total_conformers: Optional[int] = None, n_threads=6, ) -> None: """ @@ -701,11 +740,17 @@ def process( force_download: bool, optional, default=False If the raw data_file is present in the local_cache_dir, the local copy will be used. If True, this will force the software to download the data again, even if present. - unit_testing_max_records: int, optional, default=None - If set to an integer, 'n', the routine will only process the first 'n' records, useful for unit tests. - Note, that in SPICE, conformers are stored as separate records, and are combined within this routine. - As such the number of molecules in 'data' may be less than unit_testing_max_records, if the records fetched - are all conformers of the same molecule. + max_records: int, optional, default=None + If set to an integer, 'n_r', the routine will only process the first 'n_r' records, useful for unit tests. + Can be used in conjunction with max_conformers_per_record and total_conformers. + Note defining this will only fetch from the "SPICE PubChem Set 1 Single Points Dataset v1.2" + max_conformers_per_record: int, optional, default=None + If set to an integer, 'n_c', the routine will only process the first 'n_c' conformers per record, useful for unit tests. + Can be used in conjunction with max_records and total_conformers. + total_conformers: int, optional, default=None + If set to an integer, 'n_t', the routine will only process the first 'n_t' conformers in total, useful for unit tests. + Can be used in conjunction with max_records and max_conformers_per_record. + Note defining this will only fetch from the "SPICE PubChem Set 1 Single Points Dataset v1.2" n_threads, int, default=6 Number of concurrent threads for retrieving data from QCArchive Examples @@ -715,6 +760,11 @@ def process( >>> spice_openff_data.process() """ + # if max_records is not None and total_conformers is not None: + # raise ValueError( + # "max_records and total_conformers cannot be set at the same time." + # ) + from concurrent.futures import ThreadPoolExecutor, as_completed if self.release_version == "1.1.4": @@ -740,8 +790,8 @@ def process( specification_names = ["spec_2", "spec_6", "entry"] # if we specify the number of records, restrict to only the first subset - # so we don't do this 6 times. - if unit_testing_max_records is not None: + # so we don't download multiple collections. + if max_records is not None or total_conformers is not None: dataset_names = ["SPICE PubChem Set 1 Single Points Dataset v1.2"] threads = [] local_database_names = [] @@ -761,7 +811,7 @@ def process( local_database_name=local_database_name, local_path_dir=self.local_cache_dir, force_download=force_download, - unit_testing_max_records=unit_testing_max_records, + max_records=max_records, pbar=pbar, ) ) @@ -774,6 +824,8 @@ def process( self.local_cache_dir, local_database_names, dataset_names, + max_conformers_per_record=max_conformers_per_record, + total_conformers=total_conformers, ) self._generate_hdf5() diff --git a/modelforge/dataset/__init__.py b/modelforge/dataset/__init__.py index fd894289..6b476e09 100644 --- a/modelforge/dataset/__init__.py +++ b/modelforge/dataset/__init__.py @@ -2,4 +2,11 @@ from .qm9 import QM9Dataset from .dataset import DatasetFactory, TorchDataModule -_IMPLEMENTED_DATASETS = ["QM9"] +_IMPLEMENTED_DATASETS = [ + "QM9", + "ANI1X", + # "ANI2X", + # "SPICE114", + # "SPICE2", + # "SPICE114_OPENFF", +] diff --git a/modelforge/dataset/dataset.py b/modelforge/dataset/dataset.py index 364ec9c6..bbe633df 100644 --- a/modelforge/dataset/dataset.py +++ b/modelforge/dataset/dataset.py @@ -581,7 +581,7 @@ def _from_file_cache(self) -> None: f"Loading processed data from {self.local_cache_dir}/{self.processed_data_file['name']} generated on {self._npz_metadata['date_generated']}" ) log.debug( - f"Properties of Interes in .npz file: {self._npz_metadata['data_keys']}" + f"Properties of Interest in .npz file: {self._npz_metadata['data_keys']}" ) self.numpy_data = np.load( f"{self.local_cache_dir}/{self.processed_data_file['name']}" diff --git a/modelforge/tests/conftest.py b/modelforge/tests/conftest.py index 4b9dacff..d3b88aae 100644 --- a/modelforge/tests/conftest.py +++ b/modelforge/tests/conftest.py @@ -1,15 +1,23 @@ import torch import pytest from modelforge.dataset import TorchDataModule, _IMPLEMENTED_DATASETS +from modelforge.dataset.dataset import HDF5Dataset + from typing import Optional, Dict from modelforge.potential import NeuralNetworkPotentialFactory, _IMPLEMENTED_NNPS - +from dataclasses import dataclass _DATASETS_TO_TEST = [name for name in _IMPLEMENTED_DATASETS] _MODELS_TO_TEST = [name for name in _IMPLEMENTED_NNPS] from modelforge.potential.utils import BatchData +@pytest.fixture(scope="session") +def prep_temp_dir(tmp_path_factory): + fn = tmp_path_factory.mktemp("dataset_testing") + return fn + + @pytest.fixture(params=_MODELS_TO_TEST) def train_model(request): model_name = request.param @@ -30,32 +38,147 @@ def inference_model(request): return model +@dataclass +class DataSetContainer: + dataset: HDF5Dataset + name: str + expected_properties_of_interest: list + expected_E_random_split: float + expected_E_fcfs_split: float + + @pytest.fixture(params=_DATASETS_TO_TEST) -def datasets_to_test(request): +def datasets_to_test(request, prep_temp_dir): dataset_name = request.param if dataset_name == "QM9": - from modelforge.dataset import QM9Dataset - - dataset = QM9Dataset(for_unit_testing=True) - return dataset + from modelforge.dataset.qm9 import QM9Dataset + + datasetDC = DataSetContainer( + dataset=QM9Dataset( + for_unit_testing=True, local_cache_dir=str(prep_temp_dir) + ), + name=dataset_name, + expected_properties_of_interest=[ + "geometry", + "atomic_numbers", + "internal_energy_at_0K", + "charges", + ], + expected_E_random_split=-412509.93109875394, + expected_E_fcfs_split=-106277.4161215308, + ) + return datasetDC + elif dataset_name == "ANI1X": + from modelforge.dataset.ani1x import ANI1xDataset + + datasetDC = DataSetContainer( + dataset=ANI1xDataset( + for_unit_testing=True, local_cache_dir=str(prep_temp_dir) + ), + name=dataset_name, + expected_properties_of_interest=[ + "geometry", + "atomic_numbers", + "wb97x_dz.energy", + "wb97x_dz.forces", + ], + expected_E_random_split=-1739101.9014184382, + expected_E_fcfs_split=-1015736.8142089575, + ) + return datasetDC + elif dataset_name == "ANI2X": + from modelforge.dataset.ani2x import ANI2xDataset + + datasetDC = DataSetContainer( + dataset=ANI2xDataset( + for_unit_testing=True, local_cache_dir=str(prep_temp_dir) + ), + name=dataset_name, + expected_properties_of_interest=[ + "geometry", + "atomic_numbers", + "energies", + "forces", + ], + expected_E_random_split=-2614282.09174506, + expected_E_fcfs_split=-2096692.258327173, + ) + return datasetDC + elif dataset_name == "SPICE114": + from modelforge.dataset.spice114 import SPICE114Dataset + + datasetDC = DataSetContainer( + dataset=SPICE114Dataset( + for_unit_testing=True, local_cache_dir=str(prep_temp_dir) + ), + name=dataset_name, + expected_properties_of_interest=[ + "geometry", + "atomic_numbers", + "dft_total_energy", + "dft_total_force", + "mbis_charges", + ], + expected_E_random_split=-4289211.145285763, + expected_E_fcfs_split=-972574.265833225, + ) + return datasetDC + elif dataset_name == "SPICE2": + from modelforge.dataset.spice2 import SPICE2Dataset + + datasetDC = DataSetContainer( + dataset=SPICE2Dataset( + for_unit_testing=True, local_cache_dir=str(prep_temp_dir) + ), + name=dataset_name, + expected_properties_of_interest=[ + "geometry", + "atomic_numbers", + "dft_total_energy", + "dft_total_force", + "mbis_charges", + ], + expected_E_random_split=-2293275.9758066307, + expected_E_fcfs_split=-1517627.6999202403, + ) + return datasetDC + elif dataset_name == "SPICE114_OPENFF": + from modelforge.dataset.spice114openff import SPICE114OpenFFDataset + + datasetDC = DataSetContainer( + dataset=SPICE114OpenFFDataset( + for_unit_testing=True, local_cache_dir=str(prep_temp_dir) + ), + name=dataset_name, + expected_properties_of_interest=[ + "geometry", + "atomic_numbers", + "dft_total_energy", + "dft_total_force", + "mbis_charges", + ], + expected_E_random_split=-2011114.830087605, + expected_E_fcfs_split=-1516718.0904709378, + ) + return datasetDC else: raise NotImplementedError(f"Dataset {dataset_name} is not implemented.") @pytest.fixture(params=_DATASETS_TO_TEST) -def initialized_dataset(request): - dataset_name = request.param - if dataset_name == "QM9": - from modelforge.dataset import QM9Dataset - - dataset = QM9Dataset(for_unit_testing=True) - +def initialized_dataset(datasets_to_test): + # dataset_name = request.param + # if dataset_name == "QM9": + # from modelforge.dataset import QM9Dataset + # + # dataset = QM9Dataset(for_unit_testing=True) + dataset = datasets_to_test.dataset return initialize_dataset(dataset) @pytest.fixture(params=_DATASETS_TO_TEST) def batch(initialized_dataset, request): - """ + """py Fixture to obtain a single batch from an initialized dataset. This fixture depends on the `initialized_dataset` fixture for the dataset instance. diff --git a/modelforge/tests/test_curation.py b/modelforge/tests/test_curation.py index 6a44abc1..bc481ac3 100644 --- a/modelforge/tests/test_curation.py +++ b/modelforge/tests/test_curation.py @@ -326,9 +326,7 @@ def test_qm9_curation_parse_xyz(prep_temp_dir): assert data_dict_temp["energy_of_homo"] == [[-0.3877]] * unit.hartree assert data_dict_temp["energy_of_lumo"] == [[0.1171]] * unit.hartree assert data_dict_temp["lumo-homo_gap"] == [[0.5048]] * unit.hartree - assert ( - data_dict_temp["electronic_spatial_extent"] == [[35.3641]] * unit.angstrom**2 - ) + assert data_dict_temp["electronic_spatial_extent"] == [[35.3641]] * unit.angstrom**2 assert ( data_dict_temp["zero_point_vibrational_energy"] == [[0.044749]] * unit.hartree ) @@ -452,8 +450,24 @@ def test_qm9_local_archive(prep_temp_dir): qm9_data._clear_data() assert len(qm9_data.data) == 0 - qm9_data._process_downloaded(str(prep_temp_dir), unit_testing_max_records=5) + qm9_data._process_downloaded(str(prep_temp_dir), max_records=5) + + assert len(qm9_data.data) == 5 + qm9_data._clear_data() + qm9_data._process_downloaded(str(prep_temp_dir), total_conformers=5) + assert len(qm9_data.data) == 5 + assert qm9_data.total_conformers == 5 + # only one conformer per record so these should be the same + assert qm9_data.total_records == 5 + + qm9_data._clear_data() + qm9_data._process_downloaded( + str(prep_temp_dir), + max_records=2, + total_conformers=5, + ) + assert qm9_data.total_conformers == 5 assert len(qm9_data.data) == 5 @@ -500,9 +514,7 @@ def test_an1_process_download_short(prep_temp_dir): assert len(ani1_data.data) == 0 # test max records exclusion - ani1_data._process_downloaded( - str(local_data_path), hdf5_file, unit_testing_max_records=2 - ) + ani1_data._process_downloaded(str(local_data_path), hdf5_file, max_records=2) assert len(ani1_data.data) == 2 @@ -524,9 +536,7 @@ def test_an1_process_download_no_conversion(prep_temp_dir): file_name_path = str(local_data_path) + "/" + hdf5_file assert os.path.isfile(file_name_path) - ani1_data._process_downloaded( - str(local_data_path), hdf5_file, unit_testing_max_records=1 - ) + ani1_data._process_downloaded(str(local_data_path), hdf5_file, max_records=1) # @@ -867,9 +877,7 @@ def test_an1_process_download_unit_conversion(prep_temp_dir): file_name_path = str(local_data_path) + "/" + hdf5_file assert os.path.isfile(file_name_path) - ani1_data._process_downloaded( - str(local_data_path), hdf5_file, unit_testing_max_records=1 - ) + ani1_data._process_downloaded(str(local_data_path), hdf5_file, max_records=1) # @@ -1115,9 +1123,7 @@ def spice114_process_download_short(prep_temp_dir): assert len(spice_data.data) == 0 # test max records exclusion - spice_data._process_downloaded( - str(local_data_path), hdf5_file, unit_testing_max_records=1 - ) + spice_data._process_downloaded(str(local_data_path), hdf5_file, max_records=1) assert len(spice_data.data) == 1 assert spice_data.data[0]["atomic_numbers"].shape == (27, 1) @@ -1193,9 +1199,7 @@ def test_spice114_process_download_no_conversion(prep_temp_dir): file_name_path = str(local_data_path) + "/" + hdf5_file assert os.path.isfile(file_name_path) - spice_data._process_downloaded( - str(local_data_path), hdf5_file, unit_testing_max_records=1 - ) + spice_data._process_downloaded(str(local_data_path), hdf5_file, max_records=1) # @@ -1349,9 +1353,7 @@ def test_spice114_process_download_conversion(prep_temp_dir): file_name_path = str(local_data_path) + "/" + hdf5_file assert os.path.isfile(file_name_path) - spice_data._process_downloaded( - str(local_data_path), hdf5_file, unit_testing_max_records=1 - ) + spice_data._process_downloaded(str(local_data_path), hdf5_file, max_records=1) # @@ -1502,9 +1504,7 @@ def test_ani2x(prep_temp_dir): output_file_dir=local_path_dir, local_cache_dir=local_path_dir, ) - ani2x_dataset._process_downloaded( - local_data_path, filename, unit_testing_max_records=1 - ) + ani2x_dataset._process_downloaded(local_data_path, filename, max_records=1) assert len(ani2x_dataset.data) == 1 assert ani2x_dataset.data[0]["name"] == "[1_9]" @@ -1563,7 +1563,7 @@ def test_spice114_openff_test_fetching(prep_temp_dir): local_database_name=local_database_name, local_path_dir=local_path_dir, force_download=True, - unit_testing_max_records=2, + max_records=2, ) with SqliteDict( @@ -1588,7 +1588,7 @@ def test_spice114_openff_test_fetching(prep_temp_dir): local_database_name=local_database_name, local_path_dir=local_path_dir, force_download=True, - unit_testing_max_records=2, + max_records=2, pbar=pbar, ) @@ -1613,7 +1613,7 @@ def test_spice114_openff_test_fetching(prep_temp_dir): local_database_name=local_database_name, local_path_dir=local_path_dir, force_download=False, - unit_testing_max_records=2, + max_records=2, ) assert pbar.total == 0 @@ -1640,7 +1640,7 @@ def test_spice114_openff_test_fetching(prep_temp_dir): local_database_name=local_database_name, local_path_dir=local_path_dir, force_download=False, - unit_testing_max_records=10, + max_records=10, pbar=pbar, ) @@ -1741,7 +1741,7 @@ def test_spice114_openff_test_process_downloaded(prep_temp_dir): local_database_name=local_database_name, local_path_dir=local_path_dir, force_download=True, - unit_testing_max_records=2, + umax_records=2, ) spice_openff_data._process_downloaded( @@ -1775,9 +1775,7 @@ def test_spice_114_openff_process_datasets(prep_temp_dir): assert np.isclose(self_energy, -162.113665 * unit.hartree) assert charge == 1.0 * unit.elementary_charge - spice_openff_data.process( - force_download=True, unit_testing_max_records=10, n_threads=3 - ) + spice_openff_data.process(force_download=True, max_records=10, n_threads=3) # note that when we fetch the data, all the records are conformers of the same molecule # so we only end up with one molecule in data, but with 10 conformers @@ -1941,7 +1939,7 @@ def test_spice_2_process_datasets(prep_temp_dir): assert np.isclose(self_energy, -162.113665 * unit.hartree) assert charge == 1.0 * unit.elementary_charge - spice_2_data.process(force_download=True, unit_testing_max_records=10, n_threads=2) + spice_2_data.process(force_download=True, max_records=10, n_threads=2) # note that when we fetch the data, all the records are conformers of the same molecule # so we only end up with one molecule in data, but with 10 conformers diff --git a/modelforge/tests/test_dataset.py b/modelforge/tests/test_dataset.py index 86190ec7..894c7582 100644 --- a/modelforge/tests/test_dataset.py +++ b/modelforge/tests/test_dataset.py @@ -9,7 +9,7 @@ from modelforge.dataset import QM9Dataset DATASETS = [QM9Dataset] -from ..utils import PropertyNames +from modelforge.utils.prop import PropertyNames @pytest.fixture(scope="session") @@ -18,34 +18,34 @@ def prep_temp_dir(tmp_path_factory): return fn -@pytest.fixture( - autouse=True, -) -def cleanup_files(): - """Fixture to clean up temporary files before and after test execution.""" - - def _cleanup(): - for dataset in DATASETS: - dataset_name = dataset().dataset_name - - files = [ - f"{dataset_name}_cache.hdf5", - f"{dataset_name}_cache.hdf5.gz", - f"{dataset_name}_processed.npz", - f"{dataset_name}_subset_cache.hdf5", - f"{dataset_name}_subset_cache.hdf5.gz", - f"{dataset_name}_subset_processed.npz", - ] - for f in files: - try: - os.remove(f) - print(f"Deleted {f}") - except FileNotFoundError: - print(f"{f} not found") - - _cleanup() - yield # This is where the test will execute - _cleanup() +# @pytest.fixture( +# autouse=True, +# ) +# def cleanup_files(): +# """Fixture to clean up temporary files before and after test execution.""" +# +# def _cleanup(): +# for dataset in DATASETS: +# dataset_name = dataset().dataset_name +# +# files = [ +# f"{dataset_name}_cache.hdf5", +# f"{dataset_name}_cache.hdf5.gz", +# f"{dataset_name}_processed.npz", +# f"{dataset_name}_subset_cache.hdf5", +# f"{dataset_name}_subset_cache.hdf5.gz", +# f"{dataset_name}_subset_processed.npz", +# ] +# for f in files: +# try: +# os.remove(f) +# print(f"Deleted {f}") +# except FileNotFoundError: +# print(f"{f} not found") +# +# _cleanup() +# yield # This is where the test will execute +# _cleanup() def test_dataset_imported(): @@ -128,7 +128,7 @@ def test_dataset_basic_operations(): @pytest.mark.parametrize("dataset", DATASETS) def test_different_properties_of_interest(dataset): factory = DatasetFactory() - data = dataset(for_unit_testing=True) + data = dataset(for_unit_testing=True, regenerate_cache=True) assert data.properties_of_interest == [ "geometry", "atomic_numbers", @@ -139,7 +139,7 @@ def test_different_properties_of_interest(dataset): dataset = factory.create_dataset(data) raw_data_item = dataset[0] assert isinstance(raw_data_item, dict) - assert len(raw_data_item) == 6 # 6 properties are returned + assert len(raw_data_item) == 6 # 6 properties are returned data.properties_of_interest = [ "internal_energy_at_0K", @@ -370,7 +370,8 @@ def test_data_item_format(initialized_dataset): """Test the format of individual data items in the dataset.""" from typing import Dict - raw_data_item = initialized_dataset.torch_dataset[0] + dataset = initialized_dataset + raw_data_item = dataset.torch_dataset[0] assert isinstance(raw_data_item, Dict) assert isinstance(raw_data_item["atomic_numbers"], torch.Tensor) assert isinstance(raw_data_item["positions"], torch.Tensor) @@ -418,26 +419,35 @@ def test_dataset_splitting(splitting_strategy, datasets_to_test): """Test random_split on the the dataset.""" from modelforge.dataset import DatasetFactory - dataset = DatasetFactory.create_dataset(datasets_to_test) + dataset = DatasetFactory.create_dataset(datasets_to_test.dataset) train_dataset, val_dataset, test_dataset = splitting_strategy().split(dataset) - + print("local cache dir, ", datasets_to_test.dataset.local_cache_dir) energy = train_dataset[0]["E"].item() - assert np.isclose(energy, -412509.9375) or np.isclose(energy, -106277.4161215308) + if splitting_strategy == RandomRecordSplittingStrategy: + assert np.isclose(energy, datasets_to_test.expected_E_random_split) + elif splitting_strategy == FirstComeFirstServeSplittingStrategy: + assert np.isclose(energy, datasets_to_test.expected_E_fcfs_split) + train_dataset2, val_dataset2, test_dataset2 = splitting_strategy( + split=[0.6, 0.3, 0.1] + ).split(dataset) + + # since not all datasets will ultimately have 100 records, since some may include multiple conformers + # associated with each record, we will look at the ratio + total = len(train_dataset2) + len(val_dataset2) + len(test_dataset2) + assert np.isclose(len(train_dataset2) / total / 0.6, 1.0, rtol=0.1) + assert np.isclose(len(val_dataset2) / total / 0.3, 1.0, rtol=0.1) + assert np.isclose(len(test_dataset2) / total / 0.1, 1.0, rtol=0.1) + + # assert len(train_dataset) == 60 + # assert len(val_dataset) == 30 + # assert len(test_dataset) == 10 try: splitting_strategy(split=[0.2, 0.1, 0.1]) except AssertionError as excinfo: print(f"AssertionError raised: {excinfo}") - train_dataset, val_dataset, test_dataset = splitting_strategy( - split=[0.6, 0.3, 0.1] - ).split(dataset) - - assert len(train_dataset) == 60 - assert len(val_dataset) == 30 - assert len(test_dataset) == 10 - @pytest.mark.parametrize("dataset", DATASETS) def test_dataset_downloader(dataset, prep_temp_dir): @@ -457,7 +467,7 @@ def test_numpy_dataset_assignment(datasets_to_test): """ factory = DatasetFactory() - data = datasets_to_test + data = datasets_to_test.dataset factory._load_or_process_data(data) assert hasattr(data, "numpy_data") diff --git a/scripts/dataset_curation.py b/scripts/dataset_curation.py index 596dc146..a513a910 100644 --- a/scripts/dataset_curation.py +++ b/scripts/dataset_curation.py @@ -80,6 +80,9 @@ def SPICE_114_OpenFF( output_file_dir: str, local_cache_dir: str, force_download: bool = False, + max_records=None, + max_conformers_per_record=None, + total_conformers=None, ): """ This fetches the SPICE 1.1.4 dataset from MOLSSI QCArchive using the OpenFF level of theory. @@ -132,7 +135,14 @@ def SPICE_114_OpenFF( output_file_dir=output_file_dir, local_cache_dir=local_cache_dir, ) - spice_dataset.process(force_download=force_download) + spice_dataset.process( + force_download=force_download, + max_records=max_records, + max_conformers_per_record=max_conformers_per_record, + total_conformers=total_conformers, + ) + print(f"Total records: {spice_dataset.total_records}") + print(f"Total conformers: {spice_dataset.total_conformers}") def SPICE_114( @@ -140,6 +150,9 @@ def SPICE_114( output_file_dir: str, local_cache_dir: str, force_download: bool = False, + max_records=None, + max_conformers_per_record=None, + total_conformers=None, ): """ This fetches the SPICE 1.1.4 dataset from Zenodo and saves it as curated hdf5 file. @@ -184,7 +197,14 @@ def SPICE_114( output_file_dir=output_file_dir, local_cache_dir=local_cache_dir, ) - spice_114.process(force_download=force_download) + spice_114.process( + force_download=force_download, + max_records=max_records, + max_conformers_per_record=max_conformers_per_record, + total_conformers=total_conformers, + ) + print(f"Total records: {spice_114.total_records}") + print(f"Total conformers: {spice_114.total_conformers}") def QM9( @@ -192,7 +212,9 @@ def QM9( output_file_dir: str, local_cache_dir: str, force_download: bool = False, - unit_testing_max_records=None, + max_records=None, + max_conformers_per_record=None, + total_conformers=None, ): """ This fetches and process the QM9 dataset into a curated hdf5 file. @@ -228,13 +250,15 @@ def QM9( output_file_dir=output_file_dir, local_cache_dir=local_cache_dir, ) - if unit_testing_max_records is None: - qm9.process(force_download=force_download) - else: - qm9.process( - force_download=force_download, - unit_testing_max_records=unit_testing_max_records, - ) + + qm9.process( + force_download=force_download, + max_records=max_records, + max_conformers_per_record=max_conformers_per_record, + total_conformers=total_conformers, + ) + print(f"Total records: {qm9.total_records}") + print(f"Total conformers: {qm9.total_conformers}") def ANI1x( @@ -242,6 +266,9 @@ def ANI1x( output_file_dir: str, local_cache_dir: str, force_download: bool = False, + max_records=None, + max_conformers_per_record=None, + total_conformers=None, ): """ This fetches and processes the ANI1x dataset into a curated hdf5 file. @@ -296,7 +323,14 @@ def ANI1x( output_file_dir=output_file_dir, local_cache_dir=local_cache_dir, ) - ani1x.process(force_download=force_download) + ani1x.process( + force_download=force_download, + max_records=max_records, + max_conformers_per_record=max_conformers_per_record, + total_conformers=total_conformers, + ) + print(f"Total records: {ani1x.total_records}") + print(f"Total conformers: {ani1x.total_conformers}") def ANI2x( @@ -304,7 +338,9 @@ def ANI2x( output_file_dir: str, local_cache_dir: str, force_download: bool = False, - unit_testing_max_records=None, + max_records=None, + max_conformers_per_record=None, + total_conformers=None, ): """ This fetches and processes the ANI2x dataset into a curated hdf5 file. @@ -328,13 +364,15 @@ def ANI2x( output_file_dir=output_file_dir, local_cache_dir=local_cache_dir, ) - if unit_testing_max_records is None: - ani2x.process(force_download=force_download) - else: - ani2x.process( - force_download=force_download, - unit_testing_max_records=unit_testing_max_records, - ) + + ani2x.process( + force_download=force_download, + max_records=max_records, + max_conformers_per_record=max_conformers_per_record, + total_conformers=total_conformers, + ) + print(f"Total records: {ani2x.total_records}") + print(f"Total conformers: {ani2x.total_conformers}") """ @@ -342,52 +380,72 @@ def ANI2x( """ # define the local path prefix -local_prefix = "/Users/cri/Documents/Projects-msk/datasets" +local_prefix = "/home/cri/Documents/datasets" # we will save all the files to a central location output_file_dir = f"{local_prefix}/hdf5_files" # ANI2x test dataset # local_cache_dir = f"{local_prefix}/ani2x_dataset" -# hdf5_file_name = "ani2x_dataset.hdf5" +# hdf5_file_name = "ani2x_dataset_ntc_1000.hdf5" # # ANI2x( # hdf5_file_name, # output_file_dir, # local_cache_dir, # force_download=False, -# # unit_testing_max_records=100, +# # max_records=100, +# max_conformers_per_record=10, +# total_conformers=1000, # ) -# # QM9 dataset +# QM9 dataset # local_cache_dir = f"{local_prefix}/qm9_dataset" -# hdf5_file_name = "qm9_dataset_n100.hdf5" +# hdf5_file_name = "qm9_dataset_ntc_1000.hdf5" # # QM9( # hdf5_file_name, # output_file_dir, # local_cache_dir, # force_download=False, -# unit_testing_max_records=100, +# max_records=1000, +# max_conformers_per_record=1, +# total_conformers=1000, # ) # # SPICE 2 dataset -local_cache_dir = f"{local_prefix}/spice2_dataset" -hdf5_file_name = "spice_2_dataset.hdf5" - -SPICE_2(hdf5_file_name, output_file_dir, local_cache_dir, force_download=False) +# local_cache_dir = f"{local_prefix}/spice2_dataset" +# hdf5_file_name = "spice_2_dataset.hdf5" +# +# SPICE_2(hdf5_file_name, output_file_dir, local_cache_dir, force_download=False) # # SPICE 1.1.4 OpenFF dataset -# local_cache_dir = f"{local_prefix}/spice_openff_dataset" -# hdf5_file_name = "spice_114_openff_dataset.hdf5" -# -# SPICE_114_OpenFF(hdf5_file_name, output_file_dir, local_cache_dir, force_download=False) -# +local_cache_dir = f"{local_prefix}/spice_openff_dataset" +hdf5_file_name = "spice_114_openff_dataset_ntc_1000.hdf5" + +SPICE_114_OpenFF( + hdf5_file_name, + output_file_dir, + local_cache_dir, + force_download=False, + max_records=10000, + total_conformers=1000, + max_conformers_per_record=10, +) + # # SPICE 1.1.4 dataset # local_cache_dir = f"{local_prefix}/spice_114_dataset" -# hdf5_file_name = "spice_114_dataset.hdf5" +# hdf5_file_name = "spice_114_dataset_ntc_1000.hdf5" # -# SPICE_114(hdf5_file_name, output_file_dir, local_cache_dir, force_download=False) +# SPICE_114( +# hdf5_file_name, +# output_file_dir, +# local_cache_dir, +# force_download=False, +# # max_records=100, +# max_conformers_per_record=10, +# total_conformers=1000, +# ) # # # # QM9 dataset # local_cache_dir = f"{local_prefix}/qm9_dataset" @@ -397,10 +455,17 @@ def ANI2x( # # # ANI-1x dataset # local_cache_dir = f"{local_prefix}/ani1x_dataset" -# hdf5_file_name = "ani1x_dataset.hdf5" -# -# ANI1x(hdf5_file_name, output_file_dir, local_cache_dir, force_download=True) +# hdf5_file_name = "ani1x_dataset_ntc_1000.hdf5" # +# ANI1x( +# hdf5_file_name, +# output_file_dir, +# local_cache_dir, +# force_download=False, +# total_conformers=1000, +# max_conformers_per_record=10, +# ) + # # ANI-2x dataset # local_cache_dir = f"{local_prefix}/ani2x_dataset" # hdf5_file_name = "ani2x_dataset.hdf5" From 4ba6b70a1c8a70d8d3f9fe08cecb58fbd954ff14 Mon Sep 17 00:00:00 2001 From: chrisiacovella Date: Wed, 1 May 2024 17:24:55 -0700 Subject: [PATCH 25/37] updated training datasets to have consistent total conformers. Updated test_dataset.py tests --- modelforge/curation/spice_2_curation.py | 967 ++++------------- .../spice_2_from_qcarchive_curation.py | 977 ++++++++++++++++++ modelforge/curation/spice_openff_curation.py | 4 +- modelforge/dataset/__init__.py | 8 +- modelforge/dataset/ani1x.py | 14 +- modelforge/dataset/ani2x.py | 18 +- modelforge/dataset/qm9.py | 18 +- modelforge/dataset/spice114.py | 16 +- modelforge/dataset/spice114openff.py | 16 +- modelforge/dataset/spice2.py | 18 +- modelforge/potential/utils.py | 92 +- modelforge/tests/conftest.py | 27 +- modelforge/tests/test_curation.py | 2 +- modelforge/tests/test_dataset.py | 90 +- scripts/dataset_curation.py | 64 +- 15 files changed, 1418 insertions(+), 913 deletions(-) create mode 100644 modelforge/curation/spice_2_from_qcarchive_curation.py diff --git a/modelforge/curation/spice_2_curation.py b/modelforge/curation/spice_2_curation.py index 3250d944..afca4791 100644 --- a/modelforge/curation/spice_2_curation.py +++ b/modelforge/curation/spice_2_curation.py @@ -1,58 +1,27 @@ -from typing import List, Optional, Dict, Tuple - from modelforge.curation.curation_baseclass import DatasetCuration -from retry import retry -from tqdm import tqdm -from openff.units import unit - +from typing import Optional from loguru import logger +from openff.units import unit class SPICE2Curation(DatasetCuration): """ - Fetches the SPICE 2 dataset from MOLSSI QCArchive and processes it into a curated hdf5 file. + Routines to fetch the spice 2 dataset from zenodo and process into a curated hdf5 file. - The SPICE dataset contains conformations for a diverse set of small molecules, + Small-molecule/Protein Interaction Chemical Energies (SPICE). + The SPICE dataset containsconformations for a diverse set of small molecules, dimers, dipeptides, and solvated amino acids. It includes 15 elements, charged and uncharged molecules, and a wide range of covalent and non-covalent interactions. It provides both forces and energies calculated at the ωB97M-D3(BJ)/def2-TZVPPD level of theory, - using Psi4 along with other useful quantities such as multipole moments and bond orders. - - This includes the following collections from qcarchive. Collections included in SPICE 1.1.4 are annotated with - along with the version used in SPICE 1.1.4; while the underlying molecules are typically the same in a given collection, - newer versions may have had some calculations redone, e.g., rerun calculations that failed or rerun with - a newer version Psi4 - - - 'SPICE Solvated Amino Acids Single Points Dataset v1.1' * (SPICE 1.1.4 at v1.1) - - 'SPICE Dipeptides Single Points Dataset v1.3' * (SPICE 1.1.4 at v1.2) - - 'SPICE DES Monomers Single Points Dataset v1.1' * (SPICE 1.1.4 at v1.1) - - 'SPICE DES370K Single Points Dataset v1.0' * (SPICE 1.1.4 at v1.0) - - 'SPICE DES370K Single Points Dataset Supplement v1.1' * (SPICE 1.1.4 at v1.0) - - 'SPICE PubChem Set 1 Single Points Dataset v1.3' * (SPICE 1.1.4 at v1.2) - - 'SPICE PubChem Set 2 Single Points Dataset v1.3' * (SPICE 1.1.4 at v1.2) - - 'SPICE PubChem Set 3 Single Points Dataset v1.3' * (SPICE 1.1.4 at v1.2) - - 'SPICE PubChem Set 4 Single Points Dataset v1.3' * (SPICE 1.1.4 at v1.2) - - 'SPICE PubChem Set 5 Single Points Dataset v1.3' * (SPICE 1.1.4 at v1.2) - - 'SPICE PubChem Set 6 Single Points Dataset v1.3' * (SPICE 1.1.4 at v1.2) - - 'SPICE PubChem Set 7 Single Points Dataset v1.0' - - 'SPICE PubChem Set 8 Single Points Dataset v1.0' - - 'SPICE PubChem Set 9 Single Points Dataset v1.0' - - 'SPICE PubChem Set 10 Single Points Dataset v1.0' - - 'SPICE Ion Pairs Single Points Dataset v1.2' * (SPICE 1.1.4 at v1.1) - - 'SPICE PubChem Boron Silicon v1.0' - - 'SPICE Solvated PubChem Set 1 v1.0' - - 'SPICE Water Clusters v1.0' - - 'SPICE Amino Acid Ligand v1.0 - - - SPICE 2 zenodo release: - https://zenodo.org/records/10835749 - - Reference to original SPICE publication: + using Psi4 1.4.1 along with other useful quantities such as multipole moments and bond orders. + + Reference to the SPICE 1 dataset publication: Eastman, P., Behara, P.K., Dotson, D.L. et al. SPICE, A Dataset of Drug-like Molecules and Peptides for Training Machine Learning Potentials. Sci Data 10, 11 (2023). https://doi.org/10.1038/s41597-022-01882-6 + Dataset DOI: + https://doi.org/10.5281/zenodo.8222043 Parameters ---------- @@ -63,47 +32,33 @@ class SPICE2Curation(DatasetCuration): local_cache_dir: str, optional, default='./spice_dataset' Location to save downloaded dataset. convert_units: bool, optional, default=True - Convert from the source units (e.g., angstrom, bohr, hartree) + Convert from [e.g., angstrom, bohr, hartree] (i.e., source units) to [nanometer, kJ/mol] (i.e., target units) - release_version: str, optional, default='2' - Version of the SPICE dataset to fetch from the MOLSSI QCArchive. - Currently doesn't do anything + Examples -------- - >>> spice2_data = SPICE2Curation(hdf5_file_name='spice2_dataset.hdf5', - >>> local_cache_dir='~/datasets/spice2_dataset') - >>> spice2_data.process() + >>> spice114_data = SPICE114Curation(hdf5_file_name='spice114_dataset.hdf5', + >>> local_cache_dir='~/datasets/spice114_dataset') + >>> spice114_data.process() """ - def __init__( - self, - hdf5_file_name: str, - output_file_dir: str, - local_cache_dir: str, - convert_units: bool = True, - release_version: str = "2", - ): - super().__init__( - hdf5_file_name=hdf5_file_name, - output_file_dir=output_file_dir, - local_cache_dir=local_cache_dir, - convert_units=convert_units, - ) - self.release_version = release_version - def _init_dataset_parameters(self): - self.qcarchive_server = "ml.qcarchive.molssi.org" - - self.molecule_names = {} - - # dictionary of properties and their input units (i.e., those from QCArchive) - # and desired output units; unit conversion is performed if convert_units = True + self.dataset_download_url = ( + "https://zenodo.org/records/10975225/files/SPICE-2.0.1.hdf5" + ) + self.dataset_md5_checksum = "bfba2224b6540e1390a579569b475510" + # the spice dataset includes openff compatible unit definitions in the hdf5 file + # these values were used to generate this dictionary self.qm_parameters = { "geometry": { "u_in": unit.bohr, "u_out": unit.nanometer, }, + "formation_energy": { + "u_in": unit.hartree, + "u_out": unit.kilojoule_per_mole, + }, "dft_total_energy": { "u_in": unit.hartree, "u_out": unit.kilojoule_per_mole, @@ -124,25 +79,33 @@ def _init_dataset_parameters(self): "u_in": unit.elementary_charge, "u_out": unit.elementary_charge, }, - "scf_dipole": { + "mbis_dipoles": { "u_in": unit.elementary_charge * unit.bohr, "u_out": unit.elementary_charge * unit.nanometer, }, - "dispersion_correction_energy": { - "u_in": unit.hartree, - "u_out": unit.kilojoule_per_mole, + "mbis_quadrupoles": { + "u_in": unit.elementary_charge * unit.bohr**2, + "u_out": unit.elementary_charge * unit.nanometer**2, }, - "dispersion_correction_gradient": { - "u_in": unit.hartree / unit.bohr, - "u_out": unit.kilojoule_per_mole / unit.angstrom, + "mbis_octupoles": { + "u_in": unit.elementary_charge * unit.bohr**3, + "u_out": unit.elementary_charge * unit.nanometer**3, }, - "reference_energy": { - "u_in": unit.hartree, - "u_out": unit.kilojoule_per_mole, + "scf_dipole": { + "u_in": unit.elementary_charge * unit.bohr, + "u_out": unit.elementary_charge * unit.nanometer, }, - "formation_energy": { - "u_in": unit.hartree, - "u_out": unit.kilojoule_per_mole, + "scf_quadrupole": { + "u_in": unit.elementary_charge * unit.bohr**2, + "u_out": unit.elementary_charge * unit.nanometer**2, + }, + "mayer_indices": { + "u_in": None, + "u_out": None, + }, + "wiberg_lowdin_indices": { + "u_in": None, + "u_out": None, }, } @@ -158,7 +121,6 @@ def _init_record_entries_series(self): Options include: single_rec, e.g., name, n_configs, smiles single_atom, e.g., atomic_numbers (these are the same for all conformers) - single_mol, e.g., reference energy series_atom, e.g., charges series_mol, e.g., dft energy, dipole moment, etc. These ultimately appear under the "format" attribute in the hdf5 file. @@ -170,348 +132,53 @@ def _init_record_entries_series(self): self._record_entries_series = { "name": "single_rec", - "dataset_name": "single_rec", - "source": "single_rec", "atomic_numbers": "single_atom", - "total_charge": "single_rec", "n_configs": "single_rec", - "reference_energy": "single_rec", - "molecular_formula": "single_rec", - "canonical_isomeric_explicit_hydrogen_mapped_smiles": "single_rec", + "smiles": "single_rec", + "subset": "single_rec", + "total_charge": "single_rec", "geometry": "series_atom", "dft_total_energy": "series_mol", "dft_total_gradient": "series_atom", "dft_total_force": "series_atom", "formation_energy": "series_mol", + "mayer_indices": "series_atom", "mbis_charges": "series_atom", - "scf_dipole": "series_atom", + "mbis_dipoles": "series_atom", + "mbis_octupoles": "series_atom", + "mbis_quadrupoles": "series_atom", + "scf_dipole": "series_mol", + "scf_quadrupole": "series_mol", + "wiberg_lowdin_indices": "series_atom", } - # we will use the retry package to allow us to resume download if we lose connection to the server - @retry(delay=1, jitter=1, backoff=2, tries=50, logger=logger, max_delay=10) - def _fetch_singlepoint_from_qcarchive( - self, - dataset_name: str, - specification_name: str, - local_database_name: str, - local_path_dir: str, - force_download: bool, - unit_testing_max_records: Optional[int] = None, - pbar: Optional[tqdm] = None, - ): - """ - Fetches a singlepoint dataset from the MOLSSI QCArchive and stores it in a local sqlite database. - - Parameters - ---------- - dataset_name: str, required - Name of the dataset to fetch from the QCArchive - specification_name: str, required - Name of the specification to fetch from the QCArchive - local_database_name: str, required - Name of the local sqlite database to store the dataset - local_path_dir: str, required - Path to the directory to store the local sqlite database - force_download: bool, required - If True, this will force the software to download the data again, even if present. - unit_testing_max_records: Optional[int], optional, default=None - If set to an integer, 'n', the routine will only process the first 'n' records, useful for unit tests. - Note, conformers of the same molecule are saved in separate records, and thus the number of molecules - that end up in the 'data' list after _process_downloaded is called may be less than unit_testing_max_records. - pbar: Optional[tqdm], optional, default=None - Progress bar to track the download process. - - pbar - - Returns - ------- - - """ - from sqlitedict import SqliteDict - from loguru import logger - from qcportal import PortalClient - - dataset_type = "singlepoint" - client = PortalClient(self.qcarchive_server) - - ds = client.get_dataset(dataset_type=dataset_type, dataset_name=dataset_name) - logger.debug(f"Fetching {dataset_name} from the QCArchive.") - ds.fetch_entry_names() - - entry_names = ds.entry_names - if unit_testing_max_records is None: - unit_testing_max_records = len(entry_names) - with SqliteDict( - f"{local_path_dir}/{local_database_name}", - tablename=specification_name, - autocommit=True, - ) as spice_db: - # defining the db_keys as a set is faster for - # searching to see if a key exists - db_keys = set(spice_db.keys()) - to_fetch = [] - if force_download: - for name in entry_names[0:unit_testing_max_records]: - to_fetch.append(name) - else: - for name in entry_names[0:unit_testing_max_records]: - if name not in db_keys: - to_fetch.append(name) - if pbar is not None: - pbar.total = pbar.total + len(to_fetch) - pbar.refresh() - - # We need a different routine to fetch entries vs records with a give specification - if len(to_fetch) > 0: - if specification_name == "entry": - logger.debug( - f"Fetching {len(to_fetch)} entries from dataset {dataset_name}." - ) - for entry in ds.iterate_entries( - to_fetch, force_refetch=force_download - ): - spice_db[entry.dict()["name"]] = entry - if pbar is not None: - pbar.update(1) - - else: - logger.debug( - f"Fetching {len(to_fetch)} records for {specification_name} from dataset {dataset_name}." - ) - for record in ds.iterate_records( - to_fetch, - specification_names=[specification_name], - force_refetch=force_download, - ): - spice_db[record[0]] = record[2].dict() - if pbar is not None: - pbar.update(1) - - def _calculate_reference_energy_and_charge( - self, smiles: str - ) -> Tuple[unit.Quantity, unit.Quantity]: + def _calculate_reference_charge(self, smiles: str) -> unit.Quantity: """ - Calculate the reference energy for a given molecule, as defined by the SMILES string. - - This routine is taken from - https://github.com/openmm/spice-dataset/blob/f20d4887fa86d8875688d2dfe9bb2a2fc51dd98c/downloader/downloader.py - Reference energies for individual atoms are computed with Psi4 1.5 wB97M-D3BJ/def2-TZVPPD. + Calculate the total charge of a molecule from its SMILES string. Parameters ---------- smiles: str, required - SMILES string describing the molecule of interest. + SMILES string of the molecule. Returns ------- - Tuple[unit.Quantity, unit.Quantity] - Returns the reference energy of for the atoms in the molecule (in hartrees) - and the total charge of the molecule (in elementary charge). + total_charge: unit.Quantity """ from rdkit import Chem - import numpy as np - - # Reference energies, in hartrees, computed with Psi4 1.5 wB97M-D3BJ/def2-TZVPPD. - - atom_energy = { - "B": { - -1: -24.677421752684776, - 0: -24.671520535482145, - 1: -24.364648707125294, - }, - "Br": {-1: -2574.2451510945853, 0: -2574.1167240829964}, - "C": {-1: -37.91424135791358, 0: -37.87264507233593, 1: -37.45349214963933}, - "Ca": {2: -676.9528465198214}, - "Cl": {-1: -460.3350243496703, 0: -460.1988762285739}, - "F": {-1: -99.91298732343974, 0: -99.78611622985483}, - "H": {-1: -0.5027370838721259, 0: -0.4987605100487531, 1: 0.0}, - "I": {-1: -297.8813829975981, 0: -297.76228914445625}, - "K": {1: -599.8025677513111}, - "Li": {1: -7.285254714046546}, - "Mg": {2: -199.2688420040449}, - "N": { - -1: -54.602291095426494, - 0: -54.62327513368922, - 1: -54.08594142587869, - }, - "Na": {1: -162.11366478783253}, - "O": {-1: -75.17101657391741, 0: -75.11317840410095, 1: -74.60241514396725}, - "P": {0: -341.3059197024934, 1: -340.9258392474849}, - "S": {-1: -398.2405387031612, 0: -398.1599636677874, 1: -397.7746615977658}, - "Si": { - -1: -289.4540686037408, - 0: -289.4131352299586, - 1: -289.1189404777897, - }, - } - default_charge = {} - for symbol in atom_energy: - energies = [ - (energy, charge) for charge, energy in atom_energy[symbol].items() - ] - default_charge[symbol] = sorted(energies)[0][1] rdmol = Chem.MolFromSmiles(smiles, sanitize=False) total_charge = sum(atom.GetFormalCharge() for atom in rdmol.GetAtoms()) - symbol = [atom.GetSymbol() for atom in rdmol.GetAtoms()] - charge = [default_charge[s] for s in symbol] - delta = np.sign(total_charge - sum(charge)) - while delta != 0: - best_index = -1 - best_energy = None - for i in range(len(symbol)): - s = symbol[i] - e = atom_energy[s] - new_charge = charge[i] + delta - - if new_charge in e: - if best_index == -1 or e[new_charge] - e[charge[i]] < best_energy: - best_index = i - best_energy = e[new_charge] - e[charge[i]] - - charge[best_index] += delta - delta = np.sign(total_charge - sum(charge)) - - return ( - sum(atom_energy[s][c] for s, c in zip(symbol, charge)) * unit.hartree, - int(total_charge) * unit.elementary_charge, - ) - - def _check_name_format(self, name: str): - """ - Check if the name of the molecule conforms to the form {name}-{conformer_number}. - If not, we will return false - - Parameters - ---------- - name: str, required - Name of the molecule to check. - - Returns - ------- - bool - True if the name conforms to the form {name}-{conformer_number}, False otherwise. - - """ - import re - - if len(re.findall(r"-[0-9]+", name)) > 0: - # if re.match(r"^[a-zA-Z0-9-_()]+-[0-9]+$", name): - return True - else: - return False - - def _sort_keys( - self, non_error_keys: List[str] - ) -> Tuple[List[str], Dict[str, str], Dict[str, str]]: - """ - This will sort record identifiers such that conformers are listed in numerical order. - - This will, if necessarily also sanitize the key of the molecule, to ensure that we have the following - form {name}-{conformer_number}. In some cases, the original name has a hyphen which would causes issues - with simply splitting based upon a "-" to either get the name or the conformer number. - - The function is called by _process_downloaded. - - Parameters - ---------- - non_error_keys - List of keys that do not have errors that will be sorted. These need to be of the form of {name}-{conformer_number}. - - Returns - ------- - Tuple[List[str], Dict[str, str], Dict[str, str]] - List of sorted keys, dictionary that maps the sanitized key to the original key, and a dictionary that maps the - sorted keys to the molecule name (i.e., drops any conformer numbers from the end). - - """ - # we need to sanitize the names of the molecule, as - # some of the names have a dash in them, for example ALA-ALA-1 - # This will replace any hyphens in the name with an underscore. - # To be able to retain the original name, needed for accessing the record in the sqlite file - # we will create a simple dictionary that maps the sanitized name to the original name. - - non_error_keys_sanitized = [] - original_keys = {} - - for key in non_error_keys: - # check if we have a name of the form {name}-{conformer_number} - if self._check_name_format(key): - s = "_" - d = "-" - temp = key.split("-") - # replace all but the last hyphens with an underscore - temp_key = d.join([s.join(temp[0:-1]), temp[-1]]) - # if we do not have a conformer number at the end of the name, we will simply replace ANY hyphens with an underscore - else: - temp_key = key.replace("-", "_") - - non_error_keys_sanitized.append(temp_key) - original_keys[temp_key] = key - - # We will sort the keys such that conformers are listed in numerical order. - # This is not strictly necessary, but will help to better retain - # connection to the original QCArchive data, where in most cases conformer-id will directly correspond to - # the index of the final array constructed here. - # Note, if the calculation of an individual conformer failed on qcarchive, - # it will have been excluded from the non_error_keys list. As such, in such cases, - # the conformer id in the record name will no longer have a one-to-one correspondence with the - # index of the conformers in the combined arrays. This should not be an issue in terms of training, - # but could cause some confusion when interrogating a specific set of conformers geometries for a molecule. - - sorted_keys = [] - - # often names are of the format {name}-{conformer_number} - # we will first sort by name - # note, if we don't have a conformer number, this will still work - pre_sort = sorted(non_error_keys_sanitized, key=lambda x: (x.split("-")[0])) - - # then sort each molecule by conformer_number - # we'll do this by simple iteration through the list, and when we encounter a new molecule name, we'll sort the - # previous temporary list we generated. - current_val = pre_sort[0].split("-")[0] - temp_list = [] - - for val in pre_sort: - name = val.split("-")[0] - - if name == current_val: - temp_list.append(val) - else: - # we need to check to see if the name actually has a conformer id, otherwise this will fail - # we are going to assume the first entry in the list means the entire list has the right format - if self._check_name_format(temp_list[0]): - sorted_keys += sorted( - temp_list, key=lambda x: int(x.split("-")[-1]) - ) - else: - sorted_keys += temp_list - - # clear out the list and restart - temp_list = [] - current_val = name - temp_list.append(val) - - # sort the final batch - # we need to check to see if the name actually has a conformer id, otherwise this will fail - if self._check_name_format(temp_list[0]): - sorted_keys += sorted(temp_list, key=lambda x: int(x.split("-")[-1])) - - names = {} - - # store the name in a dictionary - for key in sorted_keys: - name = key.split("-")[0] - names[key] = name - - return sorted_keys, original_keys, names + return int(total_charge) * unit.elementary_charge def _process_downloaded( self, local_path_dir: str, - filenames: List[str], - dataset_sources: List[Dict], + name: str, + max_records: Optional[int] = None, + max_conformers_per_record: Optional[int] = None, + total_conformers: Optional[int] = None, ): """ Processes a downloaded dataset: extracts relevant information. @@ -520,255 +187,118 @@ def _process_downloaded( ---------- local_path_dir: str, required Path to the directory that contains the raw hdf5 datafile - filenames: List[str], required - Names of the raw sqlite files to process, - dataset_sources: List[Dict], required - List of Dicts, where each Dict provides the names of the sqlite file to process ( accessed with key 'name') - and specification where data is stored on qcarchive (accessible with key 'specifications'). + name: str, required + Name of the raw hdf5 file, + max_records: int, optional, default=None + If set to an integer, 'n_r', the routine will only process the first 'n_r' records, useful for unit tests. + Can be used in conjunction with max_conformers_per_record and total_conformers. + max_conformers_per_record: int, optional, default=None + If set to an integer, 'n_c', the routine will only process the first 'n_c' conformers per record, useful for unit tests. + Can be used in conjunction with max_records and total_conformers. + total_conformers: int, optional, default=None + If set to an integer, 'n_t', the routine will only process the first 'n_t' conformers in total, useful for unit tests. + Can be used in conjunction with max_records and max_conformers_per_record. + Examples + -------- """ + import h5py from tqdm import tqdm - import numpy as np - from sqlitedict import SqliteDict - from loguru import logger - import qcelemental as qcel - from numpy import newaxis - - for filename, dataset_info in zip(filenames, dataset_sources): - input_file_name = f"{local_path_dir}/{filename}" - dataset_name = dataset_info["name"] - specifications = dataset_info["specifications"] - spec = [s for i, s in enumerate(specifications) if s != "entry"] - - non_error_keys = [] - - # identify the set of molecules that do not have errors - with SqliteDict( - input_file_name, tablename=spec[0], autocommit=False - ) as spice_db: - spec_keys = list(spice_db.keys()) - - for key in spec_keys: - if spice_db[key]["status"].value == "complete": - non_error_keys.append(key) - - sorted_keys, original_keys, molecule_names = self._sort_keys(non_error_keys) - - # first read in molecules from entry - with SqliteDict( - input_file_name, tablename="entry", autocommit=False - ) as spice_db: - logger.debug(f"Processing {filename} entries.") - for key in tqdm(sorted_keys): - val = spice_db[original_keys[key]].dict() - name = molecule_names[key] - # if we haven't processed a molecule with this name yet - # we will add to the molecule_names dictionary - if name not in self.molecule_names.keys(): - self.molecule_names[name] = len(self.data) - - data_temp = {} - data_temp["name"] = name - data_temp["source"] = input_file_name.replace(".sqlite", "") - atomic_numbers = [] - for element in val["molecule"]["symbols"]: - atomic_numbers.append( - qcel.periodictable.to_atomic_number(element) - ) - data_temp["atomic_numbers"] = np.array(atomic_numbers).reshape( - -1, 1 - ) - data_temp["molecular_formula"] = val["molecule"]["identifiers"][ - "molecular_formula" - ] - data_temp[ - "canonical_isomeric_explicit_hydrogen_mapped_smiles" - ] = val["molecule"]["extras"][ - "canonical_isomeric_explicit_hydrogen_mapped_smiles" - ] - data_temp["n_configs"] = 1 - data_temp["geometry"] = val["molecule"]["geometry"].reshape( - 1, -1, 3 - ) - ( - data_temp["reference_energy"], - data_temp["total_charge"], - ) = self._calculate_reference_energy_and_charge( - data_temp[ - "canonical_isomeric_explicit_hydrogen_mapped_smiles" - ] - ) - data_temp["dataset_name"] = dataset_name - self.data.append(data_temp) - else: - # if we have already encountered this molecule we need to append to the data - # since we are using numpy we will use vstack to append to the arrays - index = self.molecule_names[name] - - self.data[index]["n_configs"] += 1 - self.data[index]["geometry"] = np.vstack( - ( - self.data[index]["geometry"], - val["molecule"]["geometry"].reshape(1, -1, 3), - ) - ) - - with SqliteDict( - input_file_name, tablename=spec[0], autocommit=False - ) as spice_db: - logger.debug(f"Processing {filename} {spec[0]}.") - - for key in tqdm(sorted_keys): - name = molecule_names[key] - val = spice_db[original_keys[key]] - - index = self.molecule_names[name] - - # note, we will use the convention of names being lowercase - # and spaces denoted by underscore - quantity = "dft total energy" - quantity_o = "dft_total_energy" - if quantity_o not in self.data[index].keys(): - self.data[index][quantity_o] = val["properties"][quantity] - else: - self.data[index][quantity_o] = np.vstack( - (self.data[index][quantity_o], val["properties"][quantity]) - ) - - quantity = "dft total gradient" - quantity_o = "dft_total_gradient" - if quantity_o not in self.data[index].keys(): - self.data[index][quantity_o] = np.array( - val["properties"][quantity] - ).reshape(1, -1, 3) - - else: - self.data[index][quantity_o] = np.vstack( - ( - self.data[index][quantity_o], - np.array(val["properties"][quantity]).reshape(1, -1, 3), - ) - ) - # we will store force along with gradient - quantity = "dft total gradient" - quantity_o = "dft_total_force" - if quantity_o not in self.data[index].keys(): - self.data[index][quantity_o] = -np.array( - val["properties"][quantity] - ).reshape(1, -1, 3) - else: - self.data[index][quantity_o] = np.vstack( - ( - self.data[index][quantity_o], - -np.array(val["properties"][quantity]).reshape( - 1, -1, 3 - ), - ) - ) - - quantity = "mbis charges" - quantity_o = "mbis_charges" - if quantity_o not in self.data[index].keys(): - if quantity in val["properties"].keys(): - self.data[index][quantity_o] = np.array( - val["properties"][quantity] - ).reshape(1, -1)[..., newaxis] - - else: - self.data[index][quantity_o] = np.vstack( - ( - self.data[index][quantity_o], - np.array(val["properties"][quantity]).reshape(1, -1)[ - ..., newaxis - ], - ) - ) - - quantity = "scf dipole" - quantity_o = "scf_dipole" - if quantity_o not in self.data[index].keys(): - self.data[index][quantity_o] = np.array( - val["properties"][quantity] - ).reshape(1, 3) - else: - self.data[index][quantity_o] = np.vstack( - ( - self.data[index][quantity_o], - np.array(val["properties"][quantity]).reshape(1, 3), - ) - ) - - # typecasting issue in there - - quantity = "dispersion correction energy" - quantity_o = "dispersion_correction_energy" - - if quantity_o not in self.data[index].keys(): - self.data[index][quantity_o] = np.array( - val["properties"][quantity] - ).reshape(1, 1) - else: - self.data[index][quantity_o] = np.vstack( - ( - self.data[index][quantity_o], - np.array(float(val["properties"][quantity])).reshape( - 1, 1 - ), - ), - ) - quantity = "dispersion correction gradient" - quantity_o = "dispersion_correction_gradient" - if quantity_o not in self.data[index].keys(): - self.data[index][quantity_o] = np.array( - val["properties"][quantity] - ).reshape(1, -1, 3) - else: - self.data[index][quantity_o] = np.vstack( - ( - self.data[index][quantity_o], - np.array(val["properties"][quantity]).reshape(1, -1, 3), - ) - ) - # assign units - for datapoint in self.data: - for key in datapoint.keys(): - if key in self.qm_parameters: - if not isinstance(datapoint[key], unit.Quantity): - datapoint[key] = ( - datapoint[key] * self.qm_parameters[key]["u_in"] - ) - # add in the formation energy defined as: - # dft_total_energy + dispersion_correction_energy - reference_energy - - # the dispersion corrected energy and gradient can be calculated from the raw data - datapoint["dft_total_energy"] = ( - datapoint["dft_total_energy"] - + datapoint["dispersion_correction_energy"] - ) - # we only want to write the dispersion corrected energy to the file to avoid confusion - datapoint.pop("dispersion_correction_energy") - datapoint["dft_total_gradient"] = ( - datapoint["dft_total_gradient"] - + datapoint["dispersion_correction_gradient"] - ) - # we only want to write the dispersion corrected gradient to the file to avoid confusion - datapoint.pop("dispersion_correction_gradient") + input_file_name = f"{local_path_dir}/{name}" - datapoint["formation_energy"] = ( - datapoint["dft_total_energy"] - - np.array(datapoint["reference_energy"].m * datapoint["n_configs"]) - * datapoint["reference_energy"].u - ) + need_to_reshape = {"formation_energy": True, "dft_total_energy": True} + with h5py.File(input_file_name, "r") as hf: + names = list(hf.keys()) + if max_records is None: + n_max = len(names) + elif max_records is not None: + n_max = max_records + + conformers_counter = 0 + + for i, name in tqdm(enumerate(names[0:n_max]), total=n_max): + if total_conformers is not None: + if conformers_counter >= total_conformers: + break + + # Extract the total number of conformations for a given molecule + conformers_per_record = hf[name]["conformations"].shape[0] + + keys_list = list(hf[name].keys()) + + # temp dictionary for ANI-1x and ANI-1ccx data + ds_temp = {} + + ds_temp["name"] = f"{name}" + ds_temp["smiles"] = hf[name]["smiles"][()][0].decode("utf-8") + ds_temp["atomic_numbers"] = hf[name]["atomic_numbers"][()].reshape( + -1, 1 + ) + if max_conformers_per_record is not None: + conformers_per_record = min( + conformers_per_record, + max_conformers_per_record, + ) + if total_conformers is not None: + conformers_per_record = min( + conformers_per_record, total_conformers - conformers_counter + ) + + ds_temp["n_configs"] = conformers_per_record + + # param_in is the name of the entry, param_data contains input (u_in) and output (u_out) units + for param_in, param_data in self.qm_parameters.items(): + # for consistency between datasets, we will all the particle positions "geometry" + param_out = param_in + if param_in == "geometry": + param_in = "conformations" + + if param_in in keys_list: + temp = hf[name][param_in][()] + if param_in in need_to_reshape: + temp = temp.reshape(-1, 1) + + temp = temp[0:conformers_per_record] + param_unit = param_data["u_in"] + if param_unit is not None: + # check that units in the hdf5 file match those we have defined in self.qm_parameters + try: + assert ( + hf[name][param_in].attrs["units"] + == param_data["u_in"] + ) + except: + msg1 = f'unit mismatch: units in hdf5 file: {hf[name][param_in].attrs["units"]},' + msg2 = f'units defined in curation class: {param_data["u_in"]}.' + + raise AssertionError(f"{msg1} {msg2}") + + ds_temp[param_out] = temp * param_unit + else: + ds_temp[param_out] = temp + ds_temp["total_charge"] = self._calculate_reference_charge( + ds_temp["smiles"] + ) + ds_temp["dft_total_force"] = -ds_temp["dft_total_gradient"] + self.data.append(ds_temp) + conformers_counter += conformers_per_record if self.convert_units: self._convert_units() + # From documentation: By default, objects inside group are iterated in alphanumeric order. + # However, if group is created with track_order=True, the insertion order for the group is remembered (tracked) + # in HDF5 file, and group contents are iterated in that order. + # As such, we shouldn't need to do sort the objects to ensure reproducibility. + # self.data = sorted(self.data, key=lambda x: x["name"]) + def process( self, force_download: bool = False, - unit_testing_max_records: Optional[int] = None, - n_threads=6, + max_records: Optional[int] = None, + max_conformers_per_record: Optional[int] = None, + total_conformers: Optional[int] = None, ) -> None: """ Downloads the dataset, extracts relevant information, and writes an hdf5 file. @@ -778,13 +308,16 @@ def process( force_download: bool, optional, default=False If the raw data_file is present in the local_cache_dir, the local copy will be used. If True, this will force the software to download the data again, even if present. - unit_testing_max_records: int, optional, default=None - If set to an integer, 'n', the routine will only process the first 'n' records, useful for unit tests. - Note, that in SPICE, conformers are stored as separate records, and are combined within this routine. - As such the number of molecules in 'data' may be less than unit_testing_max_records, if the records fetched - are all conformers of the same molecule. - n_threads, int, default=6 - Number of concurrent threads for retrieving data from QCArchive + max_records: int, optional, default=None + If set to an integer, 'n_r', the routine will only process the first 'n_r' records, useful for unit tests. + Can be used in conjunction with max_conformers_per_record. + max_conformers_per_record: int, optional, default=None + If set to an integer, 'n_c', the routine will only process the first 'n_c' conformers per record, useful for unit tests. + Can be used in conjunction with max_records or total_conformers. + total_conformers: int, optional, default=None + If set to an integer, 'n_t', the routine will only process the first 'n_t' conformers in total, useful for unit tests. + Can be used in conjunction with max_conformers_per_record. + Examples -------- >>> spice2_data = SPICE2Curation(hdf5_file_name='spice2_dataset.hdf5', @@ -792,137 +325,33 @@ def process( >>> spice2_data.process() """ - from concurrent.futures import ThreadPoolExecutor, as_completed - - if self.release_version == "2": - # The SPICE dataset is available in the MOLSSI QCArchive - # This will need to load from various datasets, as described on the spice-dataset github page - # see https://github.com/openmm/spice-dataset/blob/f20d4887fa86d8875688d2dfe9bb2a2fc51dd98c/downloader/downloader.py - - dataset_sources = [ - { - "name": "SPICE Solvated Amino Acids Single Points Dataset v1.1", - "specifications": ["entry", "spec_4"], - }, - { - "name": "SPICE Dipeptides Single Points Dataset v1.3", - "specifications": ["entry", "spec_3"], - }, - { - "name": "SPICE DES Monomers Single Points Dataset v1.1", - "specifications": ["entry", "spec_4"], - }, - { - "name": "SPICE DES370K Single Points Dataset v1.0", - "specifications": ["entry", "spec_3"], - }, - { - "name": "SPICE DES370K Single Points Dataset Supplement v1.1", - "specifications": ["entry", "spec_1"], - }, - { - "name": "SPICE PubChem Set 1 Single Points Dataset v1.3", - "specifications": ["entry", "spec_3"], - }, - { - "name": "SPICE PubChem Set 2 Single Points Dataset v1.3", - "specifications": ["entry", "spec_3"], - }, - { - "name": "SPICE PubChem Set 3 Single Points Dataset v1.3", - "specifications": ["entry", "spec_3"], - }, - { - "name": "SPICE PubChem Set 4 Single Points Dataset v1.3", - "specifications": ["entry", "spec_3"], - }, - { - "name": "SPICE PubChem Set 5 Single Points Dataset v1.3", - "specifications": ["entry", "spec_3"], - }, - { - "name": "SPICE PubChem Set 6 Single Points Dataset v1.3", - "specifications": ["entry", "spec_3"], - }, - { - "name": "SPICE PubChem Set 7 Single Points Dataset v1.0", - "specifications": ["entry", "wb97m-d3bj/def2-tzvppd"], - }, - { - "name": "SPICE PubChem Set 8 Single Points Dataset v1.0", - "specifications": ["entry", "wb97m-d3bj/def2-tzvppd"], - }, - { - "name": "SPICE PubChem Set 9 Single Points Dataset v1.0", - "specifications": ["entry", "wb97m-d3bj/def2-tzvppd"], - }, - { - "name": "SPICE PubChem Set 10 Single Points Dataset v1.0", - "specifications": ["entry", "wb97m-d3bj/def2-tzvppd"], - }, - { - "name": "SPICE Ion Pairs Single Points Dataset v1.2", - "specifications": ["entry", "spec_3"], - }, - { - "name": "SPICE PubChem Boron Silicon v1.0", - "specifications": ["entry", "wb97m-d3bj/def2-tzvppd"], - }, - { - "name": "SPICE Solvated PubChem Set 1 v1.0", - "specifications": ["entry", "wb97m-d3bj/def2-tzvppd"], - }, - { - "name": "SPICE Water Clusters v1.0", - "specifications": ["entry", "wb97m-d3bj/def2-tzvppd"], - }, - { - "name": "SPICE Amino Acid Ligand v1.0", - "specifications": ["entry", "wb97m-d3bj/def2-tzvppd"], - }, - ] - - # if we specify the number of records, restrict to a subset so we don't try to access multiple datasets - if unit_testing_max_records is not None: - dataset_sources = [ - { - "name": "SPICE PubChem Set 1 Single Points Dataset v1.3", - "specifications": ["entry", "spec_3"], - }, - ] - threads = [] - local_database_names = [] - - with tqdm() as pbar: - pbar.total = 0 - with ThreadPoolExecutor(max_workers=n_threads) as e: - for i, dataset_info in enumerate(dataset_sources): - dataset_name = dataset_info["name"] - specification_names = dataset_info["specifications"] - local_database_name = f"{dataset_name}.sqlite" - local_database_names.append(local_database_name) - for specification_name in specification_names: - threads.append( - e.submit( - self._fetch_singlepoint_from_qcarchive, - dataset_name=dataset_name, - specification_name=specification_name, - local_database_name=local_database_name, - local_path_dir=self.local_cache_dir, - force_download=force_download, - unit_testing_max_records=unit_testing_max_records, - pbar=pbar, - ) - ) - logger.debug(f"Data fetched.") + if max_records is not None and total_conformers is not None: + raise ValueError( + "max_records and total_conformers cannot be set at the same time." + ) + from modelforge.utils.remote import download_from_zenodo + + url = self.dataset_download_url + + # download the dataset + self.name = download_from_zenodo( + url=url, + md5_checksum=self.dataset_md5_checksum, + output_path=self.local_cache_dir, + force_download=force_download, + ) + self._clear_data() - self.molecule_names.clear() - logger.debug(f"Processing downloaded dataset.") + # process the rest of the dataset + if self.name is None: + raise Exception("Failed to retrieve name of file from zenodo.") self._process_downloaded( self.local_cache_dir, - local_database_names, - dataset_sources, + self.name, + max_records, + max_conformers_per_record, + total_conformers, ) self._generate_hdf5() diff --git a/modelforge/curation/spice_2_from_qcarchive_curation.py b/modelforge/curation/spice_2_from_qcarchive_curation.py new file mode 100644 index 00000000..abbd2a62 --- /dev/null +++ b/modelforge/curation/spice_2_from_qcarchive_curation.py @@ -0,0 +1,977 @@ +from typing import List, Optional, Dict, Tuple + +from modelforge.curation.curation_baseclass import DatasetCuration +from retry import retry +from tqdm import tqdm +from openff.units import unit + +from loguru import logger + + +class SPICE2Curation(DatasetCuration): + """ + Fetches the SPICE 2 dataset from MOLSSI QCArchive and processes it into a curated hdf5 file. + + The SPICE dataset contains conformations for a diverse set of small molecules, + dimers, dipeptides, and solvated amino acids. It includes 15 elements, charged and + uncharged molecules, and a wide range of covalent and non-covalent interactions. + It provides both forces and energies calculated at the ωB97M-D3(BJ)/def2-TZVPPD level of theory, + using Psi4 along with other useful quantities such as multipole moments and bond orders. + + This includes the following collections from qcarchive. Collections included in SPICE 1.1.4 are annotated with + along with the version used in SPICE 1.1.4; while the underlying molecules are typically the same in a given collection, + newer versions may have had some calculations redone, e.g., rerun calculations that failed or rerun with + a newer version Psi4 + + - 'SPICE Solvated Amino Acids Single Points Dataset v1.1' * (SPICE 1.1.4 at v1.1) + - 'SPICE Dipeptides Single Points Dataset v1.3' * (SPICE 1.1.4 at v1.2) + - 'SPICE DES Monomers Single Points Dataset v1.1' * (SPICE 1.1.4 at v1.1) + - 'SPICE DES370K Single Points Dataset v1.0' * (SPICE 1.1.4 at v1.0) + - 'SPICE DES370K Single Points Dataset Supplement v1.1' * (SPICE 1.1.4 at v1.0) + - 'SPICE PubChem Set 1 Single Points Dataset v1.3' * (SPICE 1.1.4 at v1.2) + - 'SPICE PubChem Set 2 Single Points Dataset v1.3' * (SPICE 1.1.4 at v1.2) + - 'SPICE PubChem Set 3 Single Points Dataset v1.3' * (SPICE 1.1.4 at v1.2) + - 'SPICE PubChem Set 4 Single Points Dataset v1.3' * (SPICE 1.1.4 at v1.2) + - 'SPICE PubChem Set 5 Single Points Dataset v1.3' * (SPICE 1.1.4 at v1.2) + - 'SPICE PubChem Set 6 Single Points Dataset v1.3' * (SPICE 1.1.4 at v1.2) + - 'SPICE PubChem Set 7 Single Points Dataset v1.0' + - 'SPICE PubChem Set 8 Single Points Dataset v1.0' + - 'SPICE PubChem Set 9 Single Points Dataset v1.0' + - 'SPICE PubChem Set 10 Single Points Dataset v1.0' + - 'SPICE Ion Pairs Single Points Dataset v1.2' * (SPICE 1.1.4 at v1.1) + - 'SPICE PubChem Boron Silicon v1.0' + - 'SPICE Solvated PubChem Set 1 v1.0' + - 'SPICE Water Clusters v1.0' + - 'SPICE Amino Acid Ligand v1.0 + + + SPICE 2 zenodo release: + https://zenodo.org/records/10835749 + + Reference to original SPICE publication: + Eastman, P., Behara, P.K., Dotson, D.L. et al. SPICE, + A Dataset of Drug-like Molecules and Peptides for Training Machine Learning Potentials. + Sci Data 10, 11 (2023). https://doi.org/10.1038/s41597-022-01882-6 + + + Parameters + ---------- + hdf5_file_name, str, required + name of the hdf5 file generated for the SPICE dataset + output_file_dir: str, optional, default='./' + Path to write the output hdf5 files. + local_cache_dir: str, optional, default='./spice_dataset' + Location to save downloaded dataset. + convert_units: bool, optional, default=True + Convert from the source units (e.g., angstrom, bohr, hartree) + to [nanometer, kJ/mol] (i.e., target units) + release_version: str, optional, default='2' + Version of the SPICE dataset to fetch from the MOLSSI QCArchive. + Currently doesn't do anything + Examples + -------- + >>> spice2_data = SPICE2Curation(hdf5_file_name='spice2_dataset.hdf5', + >>> local_cache_dir='~/datasets/spice2_dataset') + >>> spice2_data.process() + + """ + + def __init__( + self, + hdf5_file_name: str, + output_file_dir: str, + local_cache_dir: str, + convert_units: bool = True, + release_version: str = "2", + ): + super().__init__( + hdf5_file_name=hdf5_file_name, + output_file_dir=output_file_dir, + local_cache_dir=local_cache_dir, + convert_units=convert_units, + ) + self.release_version = release_version + + def _init_dataset_parameters(self): + self.qcarchive_server = "ml.qcarchive.molssi.org" + + self.molecule_names = {} + + # dictionary of properties and their input units (i.e., those from QCArchive) + # and desired output units; unit conversion is performed if convert_units = True + self.qm_parameters = { + "geometry": { + "u_in": unit.bohr, + "u_out": unit.nanometer, + }, + "dft_total_energy": { + "u_in": unit.hartree, + "u_out": unit.kilojoule_per_mole, + }, + "dft_total_gradient": { + "u_in": unit.hartree / unit.bohr, + "u_out": unit.kilojoule_per_mole / unit.angstrom, + }, + "dft_total_force": { + "u_in": unit.hartree / unit.bohr, + "u_out": unit.kilojoule_per_mole / unit.angstrom, + }, + "mbis_charges": { + "u_in": unit.elementary_charge, + "u_out": unit.elementary_charge, + }, + "total_charge": { + "u_in": unit.elementary_charge, + "u_out": unit.elementary_charge, + }, + "scf_dipole": { + "u_in": unit.elementary_charge * unit.bohr, + "u_out": unit.elementary_charge * unit.nanometer, + }, + "dispersion_correction_energy": { + "u_in": unit.hartree, + "u_out": unit.kilojoule_per_mole, + }, + "dispersion_correction_gradient": { + "u_in": unit.hartree / unit.bohr, + "u_out": unit.kilojoule_per_mole / unit.angstrom, + }, + "reference_energy": { + "u_in": unit.hartree, + "u_out": unit.kilojoule_per_mole, + }, + "formation_energy": { + "u_in": unit.hartree, + "u_out": unit.kilojoule_per_mole, + }, + } + + def _init_record_entries_series(self): + """ + Init the dictionary that defines the format of the data. + + For data efficiency, information for different conformers will be grouped together + To make it clear to the dataset loader which pieces of information are common to all + conformers or which quantities are series (i.e., have different values for each conformer). + These labels will also allow us to define whether a given entry is per-atom, per-molecule, + or is a scalar/string that applies to the entire record. + Options include: + single_rec, e.g., name, n_configs, smiles + single_atom, e.g., atomic_numbers (these are the same for all conformers) + single_mol, e.g., reference energy + series_atom, e.g., charges + series_mol, e.g., dft energy, dipole moment, etc. + These ultimately appear under the "format" attribute in the hdf5 file. + + Examples + >>> series = {'name': 'single_rec', 'atomic_numbers': 'single_atom', + ... 'n_configs': 'single_rec', 'geometry': 'series_atom', 'energy': 'series_mol'} + """ + + self._record_entries_series = { + "name": "single_rec", + "dataset_name": "single_rec", + "source": "single_rec", + "atomic_numbers": "single_atom", + "total_charge": "single_rec", + "n_configs": "single_rec", + "reference_energy": "single_rec", + "molecular_formula": "single_rec", + "canonical_isomeric_explicit_hydrogen_mapped_smiles": "single_rec", + "geometry": "series_atom", + "dft_total_energy": "series_mol", + "dft_total_gradient": "series_atom", + "dft_total_force": "series_atom", + "formation_energy": "series_mol", + "mbis_charges": "series_atom", + "scf_dipole": "series_atom", + } + + # we will use the retry package to allow us to resume download if we lose connection to the server + @retry(delay=1, jitter=1, backoff=2, tries=50, logger=logger, max_delay=10) + def _fetch_singlepoint_from_qcarchive( + self, + dataset_name: str, + specification_name: str, + local_database_name: str, + local_path_dir: str, + force_download: bool, + max_records: Optional[int] = None, + pbar: Optional[tqdm] = None, + ): + """ + Fetches a singlepoint dataset from the MOLSSI QCArchive and stores it in a local sqlite database. + + Parameters + ---------- + dataset_name: str, required + Name of the dataset to fetch from the QCArchive + specification_name: str, required + Name of the specification to fetch from the QCArchive + local_database_name: str, required + Name of the local sqlite database to store the dataset + local_path_dir: str, required + Path to the directory to store the local sqlite database + force_download: bool, required + If True, this will force the software to download the data again, even if present. + max_records: Optional[int], optional, default=None + If set to an integer, 'n', the routine will only process the first 'n' records, useful for unit tests. + Note, conformers of the same molecule are saved in separate records, and thus the number of molecules + that end up in the 'data' list after _process_downloaded is called may be less than unit_testing_max_records. + pbar: Optional[tqdm], optional, default=None + Progress bar to track the download process. + + pbar + + Returns + ------- + + """ + from sqlitedict import SqliteDict + from loguru import logger + from qcportal import PortalClient + + dataset_type = "singlepoint" + client = PortalClient(self.qcarchive_server) + + ds = client.get_dataset(dataset_type=dataset_type, dataset_name=dataset_name) + logger.debug(f"Fetching {dataset_name} from the QCArchive.") + ds.fetch_entry_names() + + entry_names = ds.entry_names + if max_records is None: + max_records = len(entry_names) + with SqliteDict( + f"{local_path_dir}/{local_database_name}", + tablename=specification_name, + autocommit=True, + ) as spice_db: + # defining the db_keys as a set is faster for + # searching to see if a key exists + db_keys = set(spice_db.keys()) + to_fetch = [] + if force_download: + for name in entry_names[0:max_records]: + to_fetch.append(name) + else: + for name in entry_names[0:max_records]: + if name not in db_keys: + to_fetch.append(name) + if pbar is not None: + pbar.total = pbar.total + len(to_fetch) + pbar.refresh() + + # We need a different routine to fetch entries vs records with a give specification + if len(to_fetch) > 0: + if specification_name == "entry": + logger.debug( + f"Fetching {len(to_fetch)} entries from dataset {dataset_name}." + ) + for entry in ds.iterate_entries( + to_fetch, force_refetch=force_download + ): + spice_db[entry.dict()["name"]] = entry + if pbar is not None: + pbar.update(1) + + else: + logger.debug( + f"Fetching {len(to_fetch)} records for {specification_name} from dataset {dataset_name}." + ) + for record in ds.iterate_records( + to_fetch, + specification_names=[specification_name], + force_refetch=force_download, + ): + spice_db[record[0]] = record[2].dict() + if pbar is not None: + pbar.update(1) + + def _calculate_reference_energy_and_charge( + self, smiles: str + ) -> Tuple[unit.Quantity, unit.Quantity]: + """ + Calculate the reference energy for a given molecule, as defined by the SMILES string. + + This routine is taken from + https://github.com/openmm/spice-dataset/blob/f20d4887fa86d8875688d2dfe9bb2a2fc51dd98c/downloader/downloader.py + Reference energies for individual atoms are computed with Psi4 1.5 wB97M-D3BJ/def2-TZVPPD. + + Parameters + ---------- + smiles: str, required + SMILES string describing the molecule of interest. + + Returns + ------- + Tuple[unit.Quantity, unit.Quantity] + Returns the reference energy of for the atoms in the molecule (in hartrees) + and the total charge of the molecule (in elementary charge). + """ + + from rdkit import Chem + import numpy as np + + # Reference energies, in hartrees, computed with Psi4 1.5 wB97M-D3BJ/def2-TZVPPD. + + atom_energy = { + "B": { + -1: -24.677421752684776, + 0: -24.671520535482145, + 1: -24.364648707125294, + }, + "Br": {-1: -2574.2451510945853, 0: -2574.1167240829964}, + "C": {-1: -37.91424135791358, 0: -37.87264507233593, 1: -37.45349214963933}, + "Ca": {2: -676.9528465198214}, + "Cl": {-1: -460.3350243496703, 0: -460.1988762285739}, + "F": {-1: -99.91298732343974, 0: -99.78611622985483}, + "H": {-1: -0.5027370838721259, 0: -0.4987605100487531, 1: 0.0}, + "I": {-1: -297.8813829975981, 0: -297.76228914445625}, + "K": {1: -599.8025677513111}, + "Li": {1: -7.285254714046546}, + "Mg": {2: -199.2688420040449}, + "N": { + -1: -54.602291095426494, + 0: -54.62327513368922, + 1: -54.08594142587869, + }, + "Na": {1: -162.11366478783253}, + "O": {-1: -75.17101657391741, 0: -75.11317840410095, 1: -74.60241514396725}, + "P": {0: -341.3059197024934, 1: -340.9258392474849}, + "S": {-1: -398.2405387031612, 0: -398.1599636677874, 1: -397.7746615977658}, + "Si": { + -1: -289.4540686037408, + 0: -289.4131352299586, + 1: -289.1189404777897, + }, + } + default_charge = {} + for symbol in atom_energy: + energies = [ + (energy, charge) for charge, energy in atom_energy[symbol].items() + ] + default_charge[symbol] = sorted(energies)[0][1] + + rdmol = Chem.MolFromSmiles(smiles, sanitize=False) + total_charge = sum(atom.GetFormalCharge() for atom in rdmol.GetAtoms()) + symbol = [atom.GetSymbol() for atom in rdmol.GetAtoms()] + charge = [default_charge[s] for s in symbol] + delta = np.sign(total_charge - sum(charge)) + while delta != 0: + best_index = -1 + best_energy = None + for i in range(len(symbol)): + s = symbol[i] + e = atom_energy[s] + new_charge = charge[i] + delta + + if new_charge in e: + if best_index == -1 or e[new_charge] - e[charge[i]] < best_energy: + best_index = i + best_energy = e[new_charge] - e[charge[i]] + + charge[best_index] += delta + delta = np.sign(total_charge - sum(charge)) + + return ( + sum(atom_energy[s][c] for s, c in zip(symbol, charge)) * unit.hartree, + int(total_charge) * unit.elementary_charge, + ) + + def _check_name_format(self, name: str): + """ + Check if the name of the molecule conforms to the form {name}-{conformer_number}. + If not, we will return false + + Parameters + ---------- + name: str, required + Name of the molecule to check. + + Returns + ------- + bool + True if the name conforms to the form {name}-{conformer_number}, False otherwise. + + """ + import re + + if len(re.findall(r"-[0-9]+", name)) > 0: + # if re.match(r"^[a-zA-Z0-9-_()]+-[0-9]+$", name): + return True + else: + return False + + def _sort_keys( + self, non_error_keys: List[str] + ) -> Tuple[List[str], Dict[str, str], Dict[str, str]]: + """ + This will sort record identifiers such that conformers are listed in numerical order. + + This will, if necessarily also sanitize the key of the molecule, to ensure that we have the following + form {name}-{conformer_number}. In some cases, the original name has a hyphen which would causes issues + with simply splitting based upon a "-" to either get the name or the conformer number. + + The function is called by _process_downloaded. + + Parameters + ---------- + non_error_keys + List of keys that do not have errors that will be sorted. These need to be of the form of {name}-{conformer_number}. + + Returns + ------- + Tuple[List[str], Dict[str, str], Dict[str, str]] + List of sorted keys, dictionary that maps the sanitized key to the original key, and a dictionary that maps the + sorted keys to the molecule name (i.e., drops any conformer numbers from the end). + + """ + # we need to sanitize the names of the molecule, as + # some of the names have a dash in them, for example ALA-ALA-1 + # This will replace any hyphens in the name with an underscore. + # To be able to retain the original name, needed for accessing the record in the sqlite file + # we will create a simple dictionary that maps the sanitized name to the original name. + + non_error_keys_sanitized = [] + original_keys = {} + + for key in non_error_keys: + # check if we have a name of the form {name}-{conformer_number} + if self._check_name_format(key): + s = "_" + d = "-" + temp = key.split("-") + # replace all but the last hyphens with an underscore + temp_key = d.join([s.join(temp[0:-1]), temp[-1]]) + # if we do not have a conformer number at the end of the name, we will simply replace ANY hyphens with an underscore + else: + temp_key = key.replace("-", "_") + + non_error_keys_sanitized.append(temp_key) + original_keys[temp_key] = key + + # We will sort the keys such that conformers are listed in numerical order. + # This is not strictly necessary, but will help to better retain + # connection to the original QCArchive data, where in most cases conformer-id will directly correspond to + # the index of the final array constructed here. + # Note, if the calculation of an individual conformer failed on qcarchive, + # it will have been excluded from the non_error_keys list. As such, in such cases, + # the conformer id in the record name will no longer have a one-to-one correspondence with the + # index of the conformers in the combined arrays. This should not be an issue in terms of training, + # but could cause some confusion when interrogating a specific set of conformers geometries for a molecule. + + sorted_keys = [] + + # often names are of the format {name}-{conformer_number} + # we will first sort by name + # note, if we don't have a conformer number, this will still work + pre_sort = sorted(non_error_keys_sanitized, key=lambda x: (x.split("-")[0])) + + # then sort each molecule by conformer_number + # we'll do this by simple iteration through the list, and when we encounter a new molecule name, we'll sort the + # previous temporary list we generated. + current_val = pre_sort[0].split("-")[0] + temp_list = [] + + for val in pre_sort: + name = val.split("-")[0] + + if name == current_val: + temp_list.append(val) + else: + # we need to check to see if the name actually has a conformer id, otherwise this will fail + # we are going to assume the first entry in the list means the entire list has the right format + if self._check_name_format(temp_list[0]): + sorted_keys += sorted( + temp_list, key=lambda x: int(x.split("-")[-1]) + ) + else: + sorted_keys += temp_list + + # clear out the list and restart + temp_list = [] + current_val = name + temp_list.append(val) + + # sort the final batch + # we need to check to see if the name actually has a conformer id, otherwise this will fail + if self._check_name_format(temp_list[0]): + sorted_keys += sorted(temp_list, key=lambda x: int(x.split("-")[-1])) + + names = {} + + # store the name in a dictionary + for key in sorted_keys: + name = key.split("-")[0] + names[key] = name + + return sorted_keys, original_keys, names + + def _process_downloaded( + self, + local_path_dir: str, + filenames: List[str], + dataset_sources: List[Dict], + max_conformers_per_record: Optional[int] = None, + total_conformers: Optional[int] = None, + ): + """ + Processes a downloaded dataset: extracts relevant information. + + Parameters + ---------- + local_path_dir: str, required + Path to the directory that contains the raw hdf5 datafile + filenames: List[str], required + Names of the raw sqlite files to process, + dataset_sources: List[Dict], required + List of Dicts, where each Dict provides the names of the sqlite file to process ( accessed with key 'name') + and specification where data is stored on qcarchive (accessible with key 'specifications'). + + """ + from tqdm import tqdm + import numpy as np + from sqlitedict import SqliteDict + from loguru import logger + import qcelemental as qcel + from numpy import newaxis + + for filename, dataset_info in zip(filenames, dataset_sources): + input_file_name = f"{local_path_dir}/{filename}" + dataset_name = dataset_info["name"] + specifications = dataset_info["specifications"] + spec = [s for i, s in enumerate(specifications) if s != "entry"] + + non_error_keys = [] + + # identify the set of molecules that do not have errors + with SqliteDict( + input_file_name, tablename=spec[0], autocommit=False + ) as spice_db: + spec_keys = list(spice_db.keys()) + + for key in spec_keys: + if spice_db[key]["status"].value == "complete": + non_error_keys.append(key) + + sorted_keys, original_keys, molecule_names = self._sort_keys(non_error_keys) + + # first read in molecules from entry + with SqliteDict( + input_file_name, tablename="entry", autocommit=False + ) as spice_db: + logger.debug(f"Processing {filename} entries.") + for key in tqdm(sorted_keys): + val = spice_db[original_keys[key]].dict() + name = molecule_names[key] + # if we haven't processed a molecule with this name yet + # we will add to the molecule_names dictionary + if name not in self.molecule_names.keys(): + self.molecule_names[name] = len(self.data) + + data_temp = {} + data_temp["name"] = name + data_temp["source"] = input_file_name.replace(".sqlite", "") + atomic_numbers = [] + for element in val["molecule"]["symbols"]: + atomic_numbers.append( + qcel.periodictable.to_atomic_number(element) + ) + data_temp["atomic_numbers"] = np.array(atomic_numbers).reshape( + -1, 1 + ) + data_temp["molecular_formula"] = val["molecule"]["identifiers"][ + "molecular_formula" + ] + data_temp[ + "canonical_isomeric_explicit_hydrogen_mapped_smiles" + ] = val["molecule"]["extras"][ + "canonical_isomeric_explicit_hydrogen_mapped_smiles" + ] + data_temp["n_configs"] = 1 + data_temp["geometry"] = val["molecule"]["geometry"].reshape( + 1, -1, 3 + ) + ( + data_temp["reference_energy"], + data_temp["total_charge"], + ) = self._calculate_reference_energy_and_charge( + data_temp[ + "canonical_isomeric_explicit_hydrogen_mapped_smiles" + ] + ) + data_temp["dataset_name"] = dataset_name + self.data.append(data_temp) + else: + # if we have already encountered this molecule we need to append to the data + # since we are using numpy we will use vstack to append to the arrays + index = self.molecule_names[name] + + self.data[index]["n_configs"] += 1 + self.data[index]["geometry"] = np.vstack( + ( + self.data[index]["geometry"], + val["molecule"]["geometry"].reshape(1, -1, 3), + ) + ) + + with SqliteDict( + input_file_name, tablename=spec[0], autocommit=False + ) as spice_db: + logger.debug(f"Processing {filename} {spec[0]}.") + + for key in tqdm(sorted_keys): + name = molecule_names[key] + val = spice_db[original_keys[key]] + + index = self.molecule_names[name] + + # note, we will use the convention of names being lowercase + # and spaces denoted by underscore + quantity = "dft total energy" + quantity_o = "dft_total_energy" + if quantity_o not in self.data[index].keys(): + self.data[index][quantity_o] = val["properties"][quantity] + else: + self.data[index][quantity_o] = np.vstack( + (self.data[index][quantity_o], val["properties"][quantity]) + ) + + quantity = "dft total gradient" + quantity_o = "dft_total_gradient" + if quantity_o not in self.data[index].keys(): + self.data[index][quantity_o] = np.array( + val["properties"][quantity] + ).reshape(1, -1, 3) + + else: + self.data[index][quantity_o] = np.vstack( + ( + self.data[index][quantity_o], + np.array(val["properties"][quantity]).reshape(1, -1, 3), + ) + ) + # we will store force along with gradient + quantity = "dft total gradient" + quantity_o = "dft_total_force" + if quantity_o not in self.data[index].keys(): + self.data[index][quantity_o] = -np.array( + val["properties"][quantity] + ).reshape(1, -1, 3) + else: + self.data[index][quantity_o] = np.vstack( + ( + self.data[index][quantity_o], + -np.array(val["properties"][quantity]).reshape( + 1, -1, 3 + ), + ) + ) + + quantity = "mbis charges" + quantity_o = "mbis_charges" + if quantity_o not in self.data[index].keys(): + if quantity in val["properties"].keys(): + self.data[index][quantity_o] = np.array( + val["properties"][quantity] + ).reshape(1, -1)[..., newaxis] + + else: + self.data[index][quantity_o] = np.vstack( + ( + self.data[index][quantity_o], + np.array(val["properties"][quantity]).reshape(1, -1)[ + ..., newaxis + ], + ) + ) + + quantity = "scf dipole" + quantity_o = "scf_dipole" + if quantity_o not in self.data[index].keys(): + self.data[index][quantity_o] = np.array( + val["properties"][quantity] + ).reshape(1, 3) + else: + self.data[index][quantity_o] = np.vstack( + ( + self.data[index][quantity_o], + np.array(val["properties"][quantity]).reshape(1, 3), + ) + ) + + # typecasting issue in there + + quantity = "dispersion correction energy" + quantity_o = "dispersion_correction_energy" + + if quantity_o not in self.data[index].keys(): + self.data[index][quantity_o] = np.array( + val["properties"][quantity] + ).reshape(1, 1) + else: + self.data[index][quantity_o] = np.vstack( + ( + self.data[index][quantity_o], + np.array(float(val["properties"][quantity])).reshape( + 1, 1 + ), + ), + ) + quantity = "dispersion correction gradient" + quantity_o = "dispersion_correction_gradient" + if quantity_o not in self.data[index].keys(): + self.data[index][quantity_o] = np.array( + val["properties"][quantity] + ).reshape(1, -1, 3) + else: + self.data[index][quantity_o] = np.vstack( + ( + self.data[index][quantity_o], + np.array(val["properties"][quantity]).reshape(1, -1, 3), + ) + ) + # assign units + for datapoint in self.data: + for key in datapoint.keys(): + if key in self.qm_parameters: + if not isinstance(datapoint[key], unit.Quantity): + datapoint[key] = ( + datapoint[key] * self.qm_parameters[key]["u_in"] + ) + # add in the formation energy defined as: + # dft_total_energy + dispersion_correction_energy - reference_energy + + # the dispersion corrected energy and gradient can be calculated from the raw data + datapoint["dft_total_energy"] = ( + datapoint["dft_total_energy"] + + datapoint["dispersion_correction_energy"] + ) + # we only want to write the dispersion corrected energy to the file to avoid confusion + datapoint.pop("dispersion_correction_energy") + + datapoint["dft_total_gradient"] = ( + datapoint["dft_total_gradient"] + + datapoint["dispersion_correction_gradient"] + ) + # we only want to write the dispersion corrected gradient to the file to avoid confusion + datapoint.pop("dispersion_correction_gradient") + + datapoint["formation_energy"] = ( + datapoint["dft_total_energy"] + - np.array(datapoint["reference_energy"].m * datapoint["n_configs"]) + * datapoint["reference_energy"].u + ) + + if self.convert_units: + self._convert_units() + + if total_conformers is not None or max_conformers_per_record is not None: + conformers_count = 0 + temp_data = [] + for datapoint in self.data: + if total_conformers is not None: + if conformers_count >= total_conformers: + break + n_conformers = datapoint["n_configs"] + if max_conformers_per_record is not None: + n_conformers = min(n_conformers, max_conformers_per_record) + + if total_conformers is not None: + n_conformers = min( + n_conformers, total_conformers - conformers_count + ) + + datapoint["n_configs"] = n_conformers + datapoint["geometry"] = datapoint["geometry"][0:n_conformers] + datapoint["dft_total_energy"] = datapoint["dft_total_energy"][ + 0:n_conformers + ] + datapoint["dft_total_gradient"] = datapoint["dft_total_gradient"][ + 0:n_conformers + ] + datapoint["dft_total_force"] = datapoint["dft_total_force"][ + 0:n_conformers + ] + datapoint["formation_energy"] = datapoint["formation_energy"][ + 0:n_conformers + ] + datapoint["mbis_charges"] = datapoint["mbis_charges"][0:n_conformers] + datapoint["scf_dipole"] = datapoint["scf_dipole"][0:n_conformers] + + temp_data.append(datapoint) + conformers_count += n_conformers + self.data = temp_data + + def process( + self, + force_download: bool = False, + max_records: Optional[int] = None, + max_conformers_per_record: Optional[int] = None, + total_conformers: Optional[int] = None, + n_threads=6, + ) -> None: + """ + Downloads the dataset, extracts relevant information, and writes an hdf5 file. + + Parameters + ---------- + force_download: bool, optional, default=False + If the raw data_file is present in the local_cache_dir, the local copy will be used. + If True, this will force the software to download the data again, even if present. + max_records: int, optional, default=None + If set to an integer, 'n_r', the routine will only process the first 'n_r' records, useful for unit tests. + Can be used in conjunction with max_conformers_per_record and total_conformers. + Note defining this will only fetch from the "SPICE PubChem Set 1 Single Points Dataset v1.2" + max_conformers_per_record: int, optional, default=None + If set to an integer, 'n_c', the routine will only process the first 'n_c' conformers per record, useful for unit tests. + Can be used in conjunction with max_records and total_conformers. + total_conformers: int, optional, default=None + If set to an integer, 'n_t', the routine will only process the first 'n_t' conformers in total, useful for unit tests. + Can be used in conjunction with max_records and max_conformers_per_record. + Note defining this will only fetch from the "SPICE PubChem Set 1 Single Points Dataset v1.2" + n_threads, int, default=6 + Number of concurrent threads for retrieving data from QCArchive + Examples + -------- + >>> spice2_data = SPICE2Curation(hdf5_file_name='spice2_dataset.hdf5', + >>> local_cache_dir='~/datasets/spice2_dataset') + >>> spice2_data.process() + + """ + from concurrent.futures import ThreadPoolExecutor, as_completed + + if self.release_version == "2": + # The SPICE dataset is available in the MOLSSI QCArchive + # This will need to load from various datasets, as described on the spice-dataset github page + # see https://github.com/openmm/spice-dataset/blob/f20d4887fa86d8875688d2dfe9bb2a2fc51dd98c/downloader/downloader.py + + dataset_sources = [ + { + "name": "SPICE Solvated Amino Acids Single Points Dataset v1.1", + "specifications": ["entry", "spec_4"], + }, + { + "name": "SPICE Dipeptides Single Points Dataset v1.3", + "specifications": ["entry", "spec_3"], + }, + { + "name": "SPICE DES Monomers Single Points Dataset v1.1", + "specifications": ["entry", "spec_4"], + }, + { + "name": "SPICE DES370K Single Points Dataset v1.0", + "specifications": ["entry", "spec_3"], + }, + { + "name": "SPICE DES370K Single Points Dataset Supplement v1.1", + "specifications": ["entry", "spec_1"], + }, + { + "name": "SPICE PubChem Set 1 Single Points Dataset v1.3", + "specifications": ["entry", "spec_3"], + }, + { + "name": "SPICE PubChem Set 2 Single Points Dataset v1.3", + "specifications": ["entry", "spec_3"], + }, + { + "name": "SPICE PubChem Set 3 Single Points Dataset v1.3", + "specifications": ["entry", "spec_3"], + }, + { + "name": "SPICE PubChem Set 4 Single Points Dataset v1.3", + "specifications": ["entry", "spec_3"], + }, + { + "name": "SPICE PubChem Set 5 Single Points Dataset v1.3", + "specifications": ["entry", "spec_3"], + }, + { + "name": "SPICE PubChem Set 6 Single Points Dataset v1.3", + "specifications": ["entry", "spec_3"], + }, + { + "name": "SPICE PubChem Set 7 Single Points Dataset v1.0", + "specifications": ["entry", "wb97m-d3bj/def2-tzvppd"], + }, + { + "name": "SPICE PubChem Set 8 Single Points Dataset v1.0", + "specifications": ["entry", "wb97m-d3bj/def2-tzvppd"], + }, + { + "name": "SPICE PubChem Set 9 Single Points Dataset v1.0", + "specifications": ["entry", "wb97m-d3bj/def2-tzvppd"], + }, + { + "name": "SPICE PubChem Set 10 Single Points Dataset v1.0", + "specifications": ["entry", "wb97m-d3bj/def2-tzvppd"], + }, + { + "name": "SPICE Ion Pairs Single Points Dataset v1.2", + "specifications": ["entry", "spec_3"], + }, + { + "name": "SPICE PubChem Boron Silicon v1.0", + "specifications": ["entry", "wb97m-d3bj/def2-tzvppd"], + }, + { + "name": "SPICE Solvated PubChem Set 1 v1.0", + "specifications": ["entry", "wb97m-d3bj/def2-tzvppd"], + }, + { + "name": "SPICE Water Clusters v1.0", + "specifications": ["entry", "wb97m-d3bj/def2-tzvppd"], + }, + { + "name": "SPICE Amino Acid Ligand v1.0", + "specifications": ["entry", "wb97m-d3bj/def2-tzvppd"], + }, + ] + + # if we specify the number of records, restrict to a subset so we don't try to access multiple datasets + if max_records is not None: + dataset_sources = [ + { + "name": "SPICE PubChem Set 1 Single Points Dataset v1.3", + "specifications": ["entry", "spec_3"], + }, + ] + threads = [] + local_database_names = [] + + with tqdm() as pbar: + pbar.total = 0 + with ThreadPoolExecutor(max_workers=n_threads) as e: + for i, dataset_info in enumerate(dataset_sources): + dataset_name = dataset_info["name"] + specification_names = dataset_info["specifications"] + local_database_name = f"{dataset_name}.sqlite" + local_database_names.append(local_database_name) + for specification_name in specification_names: + threads.append( + e.submit( + self._fetch_singlepoint_from_qcarchive, + dataset_name=dataset_name, + specification_name=specification_name, + local_database_name=local_database_name, + local_path_dir=self.local_cache_dir, + force_download=force_download, + max_records=max_records, + pbar=pbar, + ) + ) + logger.debug(f"Data fetched.") + self._clear_data() + self.molecule_names.clear() + logger.debug(f"Processing downloaded dataset.") + + self._process_downloaded( + self.local_cache_dir, + local_database_names, + dataset_sources, + max_conformers_per_record=max_conformers_per_record, + total_conformers=total_conformers, + ) + + self._generate_hdf5() diff --git a/modelforge/curation/spice_openff_curation.py b/modelforge/curation/spice_openff_curation.py index 4059531a..ecdcb7f4 100644 --- a/modelforge/curation/spice_openff_curation.py +++ b/modelforge/curation/spice_openff_curation.py @@ -691,7 +691,7 @@ def _process_downloaded( if total_conformers is not None or max_conformers_per_record is not None: conformers_count = 0 - + temp_data = [] for datapoint in self.data: if total_conformers is not None: if conformers_count >= total_conformers: @@ -722,7 +722,9 @@ def _process_downloaded( datapoint["mbis_charges"] = datapoint["mbis_charges"][0:n_conformers] datapoint["scf_dipole"] = datapoint["scf_dipole"][0:n_conformers] + temp_data.append(datapoint) conformers_count += n_conformers + self.data = temp_data def process( self, diff --git a/modelforge/dataset/__init__.py b/modelforge/dataset/__init__.py index 6b476e09..016f8c8a 100644 --- a/modelforge/dataset/__init__.py +++ b/modelforge/dataset/__init__.py @@ -5,8 +5,8 @@ _IMPLEMENTED_DATASETS = [ "QM9", "ANI1X", - # "ANI2X", - # "SPICE114", - # "SPICE2", - # "SPICE114_OPENFF", + "ANI2X", + "SPICE114", + "SPICE2", + "SPICE114_OPENFF", ] diff --git a/modelforge/dataset/ani1x.py b/modelforge/dataset/ani1x.py index 5247caae..95f4d09d 100644 --- a/modelforge/dataset/ani1x.py +++ b/modelforge/dataset/ani1x.py @@ -129,22 +129,22 @@ def __init__( # There are 3 files types that need name/checksum defined, of extensions hdf5.gz, hdf5, and npz. # note, need to change the end of the url to dl=1 instead of dl=0 (the default when you grab the share list), to ensure the same checksum each time we download - self.test_url = "https://www.dropbox.com/scl/fi/rqjc6pcv9jjzoq08hc5ao/ani1x_dataset_n100.hdf5.gz?rlkey=kgg0xvq9aac5sp3or9oh61igj&dl=1" + self.test_url = "https://www.dropbox.com/scl/fi/26expl20116cqacdk9l1t/ani1x_dataset_ntc_1000.hdf5.gz?rlkey=swciz9dfr7suia6nrsznwbk6i&st=ryqysch3&dl=1" self.full_url = "https://www.dropbox.com/scl/fi/d98h9kt4pl40qeapqzu00/ani1x_dataset.hdf5.gz?rlkey=7q1o8hh9qzbxehsobjurcksit&dl=1" if self.for_unit_testing: url = self.test_url gz_data_file = { - "name": "ani1x_dataset_n100.hdf5.gz", - "md5": "51e2491e3c5b7b5a432e2012892cfcbb", - "length": 85445473, + "name": "ani1x_dataset_nc_1000.hdf5.gz", + "md5": "f47a92bf4791607d9fc92a4cf16cd096", + "length": 1761417, } hdf5_data_file = { - "name": "ani1x_dataset_n100.hdf5", - "md5": "f3c934b79f035ecc3addf88c027f5e46", + "name": "ani1x_dataset_nc_1000.hdf5", + "md5": "776d38c18f3aa37b00360556cf8d78cc", } processed_data_file = { - "name": "ani1x_dataset_n100_processed.npz", + "name": "ani1x_dataset_nc_1000_processed.npz", "md5": None, } diff --git a/modelforge/dataset/ani2x.py b/modelforge/dataset/ani2x.py index c2c7957f..36617ce0 100644 --- a/modelforge/dataset/ani2x.py +++ b/modelforge/dataset/ani2x.py @@ -110,23 +110,23 @@ def __init__( # There are 3 files types that need name/checksum defined, of extensions hdf5.gz, hdf5, and npz. # note, need to change the end of the url to dl=1 instead of dl=0 (the default when you grab the share list), to ensure the same checksum each time we download - self.test_url = "https://www.dropbox.com/scl/fi/okv311e9yvh94owbiypcm/ani2x_dataset_n100.hdf5.gz?rlkey=pz7gnlncabtzr3b82lblr3yas&dl=1" + self.test_url = "https://www.dropbox.com/scl/fi/7zhgtcbaoyw3lnnwy3l4j/ani2x_dataset_ntc_1000.hdf5.gz?rlkey=uqcgl687lfxe4dmpe6tboje7e&st=ui5n77nj&dl=1" self.full_url = "https://www.dropbox.com/scl/fi/egg04dmtho7l1ghqiwn1z/ani2x_dataset.hdf5.gz?rlkey=wq5qjyph5q2k0bn6vza735n19&dl=1" if self.for_unit_testing: url = self.test_url gz_data_file = { - "name": "ani2x_dataset_n100.hdf5.gz", - "md5": "093fa23aeb8f8813abd1ec08e9ff83ad", - "length": 22254528, + "name": "ani2x_dataset_nc_1000.hdf5.gz", + "md5": "9f043115c38db3739f7c529f900c0e07", + "length": 174189, } hdf5_data_file = { - "name": "ani2x_dataset_n100.hdf5", - "md5": "4f54caf79e4c946dc3d6d53722d2b966", + "name": "ani2x_dataset_nc_1000.hdf5", + "md5": "bed8b011c080078c15c3e7d79dfa99a3", } processed_data_file = { - "name": "ani2x_dataset_n100_processed.npz", - "md5": "c1481fe9a6b15fb07b961d15411c0ddd", + "name": "ani2x_dataset_nc_1000_processed.npz", + "md5": None, } logger.info("Using test dataset") @@ -146,7 +146,7 @@ def __init__( processed_data_file = { "name": "ani2x_dataset_processed.npz", - "md5": "268438d8e1660728ba892bc7c3cd4339", + "md5": None, } logger.info("Using full dataset") diff --git a/modelforge/dataset/qm9.py b/modelforge/dataset/qm9.py index 31e02646..a73f10dd 100644 --- a/modelforge/dataset/qm9.py +++ b/modelforge/dataset/qm9.py @@ -112,24 +112,25 @@ def __init__( # There are 3 files types that need name/checksum defined, of extensions hdf5.gz, hdf5, and npz. # note, need to change the end of the url to dl=1 instead of dl=0 (the default when you grab the share list), to ensure the same checksum each time we download - self.test_url = "https://www.dropbox.com/scl/fi/9jeselknatcw9xi0qp940/qm9_dataset_n100.hdf5.gz?rlkey=50of7gn2s12i65c6j06r73c97&dl=1" + self.test_url = "https://www.dropbox.com/scl/fi/oe2tooxwrkget75zwrfey/qm9_dataset_ntc_1000.hdf5.gz?rlkey=6hfb8ge0pqf4tly15rmdsthmw&st=tusk38vt&dl=1" self.full_url = "https://www.dropbox.com/scl/fi/4wu7zlpuuixttp0u741rv/qm9_dataset.hdf5.gz?rlkey=nszkqt2t4kmghih5mt4ssppvo&dl=1" if self.for_unit_testing: url = self.test_url gz_data_file = { - "name": "qm9_dataset_n100.hdf5.gz", - "md5": "af3afda5c3265c9c096935ab060f537a", + "name": "qm9_dataset_nc_1000.hdf5.gz", + "md5": "dc8ada0d808d02c699daf2000aff1fe9", + "length": 1697917, } hdf5_data_file = { - "name": "qm9_dataset_n100.hdf5", - "md5": "77df0e1df7a5ec5629be52181e82a7d7", + "name": "qm9_dataset_nc_1000.hdf5", + "md5": "305a0602860f181fafa75f7c7e3e6de4", } processed_data_file = { - "name": "qm9_dataset_n100_processed.npz", + "name": "qm9_dataset_nc_1000_processed.npz", # checksum of otherwise identical npz files are different if using 3.11 vs 3.9/10 # we will therefore skip checking these files - "md5": "9d671b54f7b9d454db9a3dd7f4ef2020", + "md5": None, } logger.info("Using test dataset") @@ -139,6 +140,7 @@ def __init__( gz_data_file = { "name": "qm9_dataset.hdf5.gz", "md5": "d172127848de114bd9cc47da2bc72566", + "length": 267228348, } hdf5_data_file = { @@ -148,7 +150,7 @@ def __init__( processed_data_file = { "name": "qm9_dataset_processed.npz", - "md5": "62d98cf38bcf02966e1fa2d9e44b3fa0", + "md5": None, } logger.info("Using full dataset") diff --git a/modelforge/dataset/spice114.py b/modelforge/dataset/spice114.py index e32e9fa5..eaa230e5 100644 --- a/modelforge/dataset/spice114.py +++ b/modelforge/dataset/spice114.py @@ -115,7 +115,7 @@ def __init__( # To be able to use the dataset for training in a consistent way with the ANI datasets, we will only consider # the ase values for the uncharged isolated atoms, if available. Ions will use the values Ca 2+, K 1+, Li 1+, Mg 2+, Na 1+. - # See spice_2_curation.py for more details. + # See spice_2_from_qcarchive_curation.py for more details. # We will need to address this further later to see how we best want to handle this; the ASE are just meant to bring everything # roughly to the same scale, and values do not vary substantially by charge state. @@ -145,23 +145,23 @@ def __init__( # There are 3 files types that need name/checksum defined, of extensions hdf5.gz, hdf5, and npz. # note, need to change the end of the url to dl=1 instead of dl=0 (the default when you grab the share list), to ensure the same checksum each time we download - self.test_url = "https://www.dropbox.com/scl/fi/16g7n0f7qgzjhi02g3qce/spice_114_dataset_n100.hdf5.gz?rlkey=gyyc1cd3u8p64icpb450y44qv&dl=1" + self.test_url = "https://www.dropbox.com/scl/fi/vh05bql1fza8jyj7ibyk2/spice_114_dataset_ntc_1000.hdf5.gz?rlkey=dqx0eq0wcux0ez48n2et35dow&st=hafqe5sv&dl=1" self.full_url = "https://www.dropbox.com/scl/fi/zfh4sq2kiz250bvd9oshr/spice_114_dataset.hdf5.gz?rlkey=q3sp7p8ir21o0y0224bt75aw7&dl=1" if self.for_unit_testing: url = self.test_url gz_data_file = { - "name": "SPICE114_dataset_n100.hdf5.gz", - "md5": "ee7406aaf587340190e90e365ba9ba7b", - "length": 72001865, + "name": "SPICE114_dataset_nc_1000.hdf5.gz", + "md5": "f7027814a98a5393272c45b4cf97f4e9", + "length": 15166190, } hdf5_data_file = { - "name": "SPICE114_dataset_n100.hdf5", - "md5": "88bd3fca0809ca47339c52edda155d6d", + "name": "SPICE114_dataset_nc_1000.hdf5", + "md5": "885e17826e65011559ff6fae2b2b44e3", } # npz file checksums may vary with different versions of python/numpy processed_data_file = { - "name": "SPICE114_dataset_n100_processed.npz", + "name": "SPICE114_dataset_nc_1000_processed.npz", "md5": None, } diff --git a/modelforge/dataset/spice114openff.py b/modelforge/dataset/spice114openff.py index b0b1623a..c8161235 100644 --- a/modelforge/dataset/spice114openff.py +++ b/modelforge/dataset/spice114openff.py @@ -125,7 +125,7 @@ def __init__( # To be able to use the dataset for training in a consistent way with the ANI datasets, we will only consider # the ase values for the uncharged isolated atoms, if available. Ions will use the values Ca 2+, K 1+, Li 1+, Mg 2+, Na 1+. - # See spice_2_curation.py for more details. + # See spice_2_from_qcarchive_curation.py for more details. # We will need to address this further later to see how we best want to handle this; the ASE are just meant to bring everything # roughly to the same scale, and values do not vary substantially by charge state. @@ -155,23 +155,23 @@ def __init__( # There are 3 files types that need name/checksum defined, of extensions hdf5.gz, hdf5, and npz. # note, need to change the end of the url to dl=1 instead of dl=0 (the default when you grab the share list), to ensure the same checksum each time we download - self.test_url = "https://www.dropbox.com/scl/fi/e4lw7gh00i0tyl2mbbv3h/spice_114_openff_dataset_n100.hdf5.gz?rlkey=grnyfuecwl7ur3qs6147h4awo&dl=1" + self.test_url = "https://www.dropbox.com/scl/fi/bwk83meoeishr16g8s5bt/spice_114_openff_dataset_ntc_1000.hdf5.gz?rlkey=pd9seffp63xe5f1uenddh3ug1&st=3b9zdhct&dl=1" self.full_url = "https://www.dropbox.com/scl/fi/kmdk4d6hntga7bk7wdqs6/spice_114_openff_dataset.hdf5.gz?rlkey=2mf954dswat4sbpus6vhj9pvz&dl=1" if self.for_unit_testing: url = self.test_url gz_data_file = { - "name": "SPICE114OpenFF_dataset_n100.hdf5.gz", - "md5": "8a99718246c178b8f318025ffe0e5560", - "length": 306289237, + "name": "SPICE114OpenFF_dataset_nc_1000.hdf5.gz", + "md5": "d1cb65eff7fa7dc182188731e5ec6bf9", + "length": 2508982, } hdf5_data_file = { - "name": "SPICE114OpenFF_dataset_n100.hdf5", - "md5": "53c0c6db27adf1f11c1d0952624ebdb4", + "name": "SPICE114OpenFF_dataset_nc_1000.hdf5", + "md5": "cb7c69b9aca9e836642f78716aea665b", } # npz file checksums may vary with different versions of python/numpy processed_data_file = { - "name": "SPICE114OpenFF_dataset_n100_processed.npz", + "name": "SPICE114OpenFF_dataset_nc_1000_processed.npz", "md5": None, } diff --git a/modelforge/dataset/spice2.py b/modelforge/dataset/spice2.py index 455d6d71..1d9ee8a3 100644 --- a/modelforge/dataset/spice2.py +++ b/modelforge/dataset/spice2.py @@ -138,7 +138,7 @@ def __init__( # To be able to use the dataset for training in a consistent way with the ANI datasets, we will only consider # the ase values for the uncharged isolated atoms, if available. Ions will use the values Ca 2+, K 1+, Li 1+, Mg 2+, Na 1+. - # See spice_2_curation.py for more details. + # See spice_2_from_qcarchive_curation.py for more details. # We will need to address this further later to see how we best want to handle this; the ASE are just meant to bring everything # roughly to the same scale, and values do not vary substantially by charge state. @@ -170,23 +170,23 @@ def __init__( # There are 3 files types that need name/checksum defined, of extensions hdf5.gz, hdf5, and npz. # note, need to change the end of the url to dl=1 instead of dl=0 (the default when you grab the share list), to ensure the same checksum each time we download - self.test_url = "https://www.dropbox.com/scl/fi/08u7e400qvrq2aklxw2yo/spice_2_dataset_n100.hdf5.gz?rlkey=ifv7hfzqnwl2faef8xxr5ggj2&dl=1" + self.test_url = "https://www.dropbox.com/scl/fi/1jawffjrh17r796g76udi/spice_2_dataset_ntc_1000.hdf5.gz?rlkey=r0crabvyg7xdgapv2qk3hk6t9&st=0ro9na0c&dl=1" self.full_url = "https://www.dropbox.com/scl/fi/udoc3jj7wa7du8jgqiat0/spice_2_dataset.hdf5.gz?rlkey=csgwqa237m002n54jnld5pfgy&dl=1" if self.for_unit_testing: url = self.test_url gz_data_file = { - "name": "SPICE2_dataset_n100.hdf5.gz", - "md5": "6f3f2931d4eb59f7a54f0a11c72bb604", - "length": 315275240, # the number of bytes to be able to display the download progress bar correctly + "name": "SPICE2_dataset_nc_1000.hdf5.gz", + "md5": "04063f08a7ec93abfc661c22b12ceeb0", + "length": 26751220, # the number of bytes to be able to display the download progress bar correctly } hdf5_data_file = { - "name": "SPICE2_dataset_n100.hdf5", - "md5": "ff89646eab99e31447be1697de8b7208", + "name": "SPICE2_dataset_nc_1000.hdf5", + "md5": "0a2554d0dba4f289dd93670686e4842e", } # npz file checksums may vary with different versions of python/numpy processed_data_file = { - "name": "SPICE2_dataset_n100_processed.npz", + "name": "SPICE2_dataset_nc_1000_processed.npz", "md5": None, } @@ -197,7 +197,7 @@ def __init__( gz_data_file = { "name": "SPICE2_dataset.hdf5.gz", "md5": "244a559a6062bbec5c9cb49af036ff7d", - "length": 5532866319, + "length": 26313472231, } hdf5_data_file = { diff --git a/modelforge/potential/utils.py b/modelforge/potential/utils.py index 5139f2c4..95c7b8f8 100644 --- a/modelforge/potential/utils.py +++ b/modelforge/potential/utils.py @@ -473,6 +473,12 @@ class AtomicSelfEnergies: 48: "Cd", 49: "In", 50: "Sn", + 51: "Sb", + 52: "Te", + 53: "I", + 54: "Xe", + 55: "Cs", + 56: "Ba", # Add more elements as needed } ) @@ -960,32 +966,37 @@ def calculate_radial_scale_factor( class SAKERadialSymmetryFunction(RadialSymmetryFunction): def calculate_radial_basis_centers( - self, - _unitless_min_distance, - _unitless_max_distance, - number_of_radial_basis_functions, - dtype, + self, + _unitless_min_distance, + _unitless_max_distance, + number_of_radial_basis_functions, + dtype, ): # initialize means and betas according to the default values in PhysNet # https://pubs.acs.org/doi/10.1021/acs.jctc.9b00181 start_value = torch.exp( - torch.scalar_tensor(-_unitless_max_distance + _unitless_min_distance, dtype=dtype) + torch.scalar_tensor( + -_unitless_max_distance + _unitless_min_distance, dtype=dtype + ) + ) + centers = torch.linspace( + start_value, 1, number_of_radial_basis_functions, dtype=dtype ) - centers = torch.linspace(start_value, 1, number_of_radial_basis_functions, dtype=dtype) return centers def calculate_radial_scale_factor( - self, - _unitless_min_distance, - _unitless_max_distance, - number_of_radial_basis_functions, + self, + _unitless_min_distance, + _unitless_max_distance, + number_of_radial_basis_functions, ): start_value = torch.exp( torch.scalar_tensor(-_unitless_max_distance + _unitless_min_distance) ) radial_scale_factor = torch.tensor( - [(2 / number_of_radial_basis_functions * (1 - start_value)) ** -2] * number_of_radial_basis_functions + [(2 / number_of_radial_basis_functions * (1 - start_value)) ** -2] + * number_of_radial_basis_functions ) return radial_scale_factor @@ -995,20 +1006,26 @@ class SAKERadialBasisFunction(RadialBasisFunction): def __init__(self, max_distance, min_distance): super().__init__() self._unitless_min_distance = min_distance.to(unit.nanometer).m - self.alpha = (5.0 * unit.nanometer / (max_distance - min_distance)).to_base_units().m # check units + self.alpha = ( + (5.0 * unit.nanometer / (max_distance - min_distance)).to_base_units().m + ) # check units def compute( - self, - distances: torch.Tensor, - centers: torch.Tensor, - scale_factors: torch.Tensor, + self, + distances: torch.Tensor, + centers: torch.Tensor, + scale_factors: torch.Tensor, ) -> torch.Tensor: return torch.exp( - -scale_factors * - (torch.exp( - self.alpha * - (-distances.unsqueeze(-1) + self._unitless_min_distance)) - - centers) ** 2 + -scale_factors + * ( + torch.exp( + self.alpha + * (-distances.unsqueeze(-1) + self._unitless_min_distance) + ) + - centers + ) + ** 2 ) @@ -1112,7 +1129,11 @@ def neighbor_list_with_cutoff( def scatter_softmax( - src: torch.Tensor, index: torch.Tensor, dim: int, dim_size: Optional[int] = None, device: Optional[torch.device] = None + src: torch.Tensor, + index: torch.Tensor, + dim: int, + dim_size: Optional[int] = None, + device: Optional[torch.device] = None, ) -> torch.Tensor: """ Softmax operation over all values in :attr:`src` tensor that share indices @@ -1139,28 +1160,33 @@ def scatter_softmax( Adapted from: https://github.com/rusty1s/pytorch_scatter/blob/c31915e1c4ceb27b2e7248d21576f685dc45dd01/torch_scatter/composite/softmax.py """ if not torch.is_floating_point(src): - raise ValueError('`scatter_softmax` can only be computed over tensors ' - 'with floating point data types.') + raise ValueError( + "`scatter_softmax` can only be computed over tensors " + "with floating point data types." + ) assert dim >= 0, f"dim must be non-negative, got {dim}" - assert dim < src.dim(), f"dim must be less than the number of dimensions of src {src.dim()}, got {dim}" + assert ( + dim < src.dim() + ), f"dim must be less than the number of dimensions of src {src.dim()}, got {dim}" out_shape = [ - other_dim_size - if (other_dim != dim) - else dim_size - for (other_dim, other_dim_size) - in enumerate(src.shape) + other_dim_size if (other_dim != dim) else dim_size + for (other_dim, other_dim_size) in enumerate(src.shape) ] zeros = torch.zeros(out_shape, dtype=src.dtype, device=device) - max_value_per_index = zeros.scatter_reduce(dim, index, src, "amax", include_self=False) + max_value_per_index = zeros.scatter_reduce( + dim, index, src, "amax", include_self=False + ) max_per_src_element = max_value_per_index.gather(dim, index) recentered_scores = src - max_per_src_element recentered_scores_exp = recentered_scores.exp() - sum_per_index = torch.zeros(out_shape, dtype=src.dtype, device=device).scatter_add(dim, index, recentered_scores_exp) + sum_per_index = torch.zeros(out_shape, dtype=src.dtype, device=device).scatter_add( + dim, index, recentered_scores_exp + ) normalizing_constants = sum_per_index.gather(dim, index) return recentered_scores_exp.div(normalizing_constants) diff --git a/modelforge/tests/conftest.py b/modelforge/tests/conftest.py index d3b88aae..9396b829 100644 --- a/modelforge/tests/conftest.py +++ b/modelforge/tests/conftest.py @@ -64,7 +64,7 @@ def datasets_to_test(request, prep_temp_dir): "internal_energy_at_0K", "charges", ], - expected_E_random_split=-412509.93109875394, + expected_E_random_split=-622027.790147837, expected_E_fcfs_split=-106277.4161215308, ) return datasetDC @@ -82,7 +82,7 @@ def datasets_to_test(request, prep_temp_dir): "wb97x_dz.energy", "wb97x_dz.forces", ], - expected_E_random_split=-1739101.9014184382, + expected_E_random_split=-1652066.552014041, expected_E_fcfs_split=-1015736.8142089575, ) return datasetDC @@ -100,7 +100,7 @@ def datasets_to_test(request, prep_temp_dir): "energies", "forces", ], - expected_E_random_split=-2614282.09174506, + expected_E_random_split=-148410.43286007023, expected_E_fcfs_split=-2096692.258327173, ) return datasetDC @@ -119,7 +119,7 @@ def datasets_to_test(request, prep_temp_dir): "dft_total_force", "mbis_charges", ], - expected_E_random_split=-4289211.145285763, + expected_E_random_split=-1922185.3358204272, expected_E_fcfs_split=-972574.265833225, ) return datasetDC @@ -138,8 +138,8 @@ def datasets_to_test(request, prep_temp_dir): "dft_total_force", "mbis_charges", ], - expected_E_random_split=-2293275.9758066307, - expected_E_fcfs_split=-1517627.6999202403, + expected_E_random_split=-5844365.936898948, + expected_E_fcfs_split=-3418985.278140791, ) return datasetDC elif dataset_name == "SPICE114_OPENFF": @@ -157,7 +157,7 @@ def datasets_to_test(request, prep_temp_dir): "dft_total_force", "mbis_charges", ], - expected_E_random_split=-2011114.830087605, + expected_E_random_split=-2263605.616072006, expected_E_fcfs_split=-1516718.0904709378, ) return datasetDC @@ -165,7 +165,7 @@ def datasets_to_test(request, prep_temp_dir): raise NotImplementedError(f"Dataset {dataset_name} is not implemented.") -@pytest.fixture(params=_DATASETS_TO_TEST) +@pytest.fixture() def initialized_dataset(datasets_to_test): # dataset_name = request.param # if dataset_name == "QM9": @@ -250,8 +250,15 @@ def initialize_dataset( TorchDataModule Initialized TorchDataModule. """ - - data_module = TorchDataModule(dataset, split_file=split_file) + from modelforge.dataset.utils import FirstComeFirstServeSplittingStrategy + + # we need to use the first come first serve splitting strategy, as random is default + # using random would make it hard to validate the expected values in the tests + data_module = TorchDataModule( + dataset, + splitting_strategy=FirstComeFirstServeSplittingStrategy(), + split_file=split_file, + ) data_module.prepare_data() return data_module diff --git a/modelforge/tests/test_curation.py b/modelforge/tests/test_curation.py index bc481ac3..577f1985 100644 --- a/modelforge/tests/test_curation.py +++ b/modelforge/tests/test_curation.py @@ -11,7 +11,7 @@ from modelforge.curation.ani1x_curation import ANI1xCuration from modelforge.curation.spice_114_curation import SPICE114Curation from modelforge.curation.spice_openff_curation import SPICEOpenFFCuration -from modelforge.curation.spice_2_curation import SPICE2Curation +from modelforge.curation.spice_2_from_qcarchive_curation import SPICE2Curation from modelforge.curation.curation_baseclass import dict_to_hdf5 diff --git a/modelforge/tests/test_dataset.py b/modelforge/tests/test_dataset.py index 894c7582..5504a323 100644 --- a/modelforge/tests/test_dataset.py +++ b/modelforge/tests/test_dataset.py @@ -248,6 +248,9 @@ def test_caching(prep_temp_dir): def test_metadata_validation(prep_temp_dir): + """When we generate an .npz file, we also write out metadata in a .json file which is used + to validate if we can use .npz file, or we need to regenerate it.""" + local_cache_dir = str(prep_temp_dir) from modelforge.dataset.qm9 import QM9Dataset @@ -271,8 +274,8 @@ def test_metadata_validation(prep_temp_dir): metadata = { "data_keys": ["atomic_numbers", "internal_energy_at_0K", "geometry", "charges"], - "hdf5_checksum": "77df0e1df7a5ec5629be52181e82a7d7", - "hdf5_gz_checkusm": "af3afda5c3265c9c096935ab060f537a", + "hdf5_checksum": "305a0602860f181fafa75f7c7e3e6de4", + "hdf5_gz_checkusm": "dc8ada0d808d02c699daf2000aff1fe9", "date_generated": "2024-04-11 14:05:14.297305", } @@ -389,7 +392,7 @@ def test_dataset_generation(initialized_dataset): dataset = initialized_dataset train_dataloader = dataset.train_dataloader() val_dataloader = dataset.val_dataloader() - + test_dataloader = dataset.test_dataloader() try: dataset.test_dataloader() except AttributeError: @@ -397,23 +400,39 @@ def test_dataset_generation(initialized_dataset): pass # the dataloader automatically splits and batches the dataset - # for the training set it batches the 80 datapoints in - # a batch of 64 and a batch of 16 samples - assert len(train_dataloader) == 2 # nr of batches + # for the training set it batches the 800 training datapoints (of 1000 total) in 13 batches + # all with 64 points until the last which has 32 + + assert len(train_dataloader) == 13 # nr of batches batch_data = [v_ for v_ in train_dataloader] + val_data = [v_ for v_ in val_dataloader] + sum_batch = sum([len(b.metadata.atomic_subsystem_counts) for b in batch_data]) + sum_val = sum([len(b.metadata.atomic_subsystem_counts) for b in val_data]) + sum_test = sum([len(b.metadata.atomic_subsystem_counts) for b in test_dataloader]) + + assert sum_batch == 800 + assert sum_val == 100 + assert sum_test == 100 + assert len(batch_data[0].metadata.atomic_subsystem_counts) == 64 - assert len(batch_data[1].metadata.atomic_subsystem_counts) == 16 + assert len(batch_data[1].metadata.atomic_subsystem_counts) == 64 + assert len(batch_data[-1].metadata.atomic_subsystem_counts) == 32 from modelforge.dataset.utils import ( RandomRecordSplittingStrategy, + RandomSplittingStrategy, FirstComeFirstServeSplittingStrategy, ) @pytest.mark.parametrize( "splitting_strategy", - [RandomRecordSplittingStrategy, FirstComeFirstServeSplittingStrategy], + [ + RandomSplittingStrategy, + FirstComeFirstServeSplittingStrategy, + RandomRecordSplittingStrategy, + ], ) def test_dataset_splitting(splitting_strategy, datasets_to_test): """Test random_split on the the dataset.""" @@ -424,7 +443,7 @@ def test_dataset_splitting(splitting_strategy, datasets_to_test): print("local cache dir, ", datasets_to_test.dataset.local_cache_dir) energy = train_dataset[0]["E"].item() - if splitting_strategy == RandomRecordSplittingStrategy: + if splitting_strategy == RandomSplittingStrategy: assert np.isclose(energy, datasets_to_test.expected_E_random_split) elif splitting_strategy == FirstComeFirstServeSplittingStrategy: assert np.isclose(energy, datasets_to_test.expected_E_fcfs_split) @@ -433,16 +452,24 @@ def test_dataset_splitting(splitting_strategy, datasets_to_test): split=[0.6, 0.3, 0.1] ).split(dataset) - # since not all datasets will ultimately have 100 records, since some may include multiple conformers - # associated with each record, we will look at the ratio - total = len(train_dataset2) + len(val_dataset2) + len(test_dataset2) - assert np.isclose(len(train_dataset2) / total / 0.6, 1.0, rtol=0.1) - assert np.isclose(len(val_dataset2) / total / 0.3, 1.0, rtol=0.1) - assert np.isclose(len(test_dataset2) / total / 0.1, 1.0, rtol=0.1) + if ( + splitting_strategy == RandomSplittingStrategy + or splitting_strategy == FirstComeFirstServeSplittingStrategy + ): + total = len(train_dataset2) + len(val_dataset2) + len(test_dataset2) + print(len(train_dataset2), len(val_dataset2), len(test_dataset2), total) + assert np.isclose(len(train_dataset2) / total, 0.6, atol=0.01) + assert np.isclose(len(val_dataset2) / total, 0.3, atol=0.01) + assert np.isclose(len(test_dataset2) / total, 0.1, atol=0.01) + elif splitting_strategy == RandomRecordSplittingStrategy: + # for the random record splitting we need to have a larger tolerance + # as we are not guaranteed to get the exact split since the number of conformers per record is not fixed + total = len(train_dataset2) + len(val_dataset2) + len(test_dataset2) + + assert np.isclose(len(train_dataset2) / total, 0.6, atol=0.05) + assert np.isclose(len(val_dataset2) / total, 0.3, atol=0.05) + assert np.isclose(len(test_dataset2) / total, 0.1, atol=0.05) - # assert len(train_dataset) == 60 - # assert len(val_dataset) == 30 - # assert len(test_dataset) == 10 try: splitting_strategy(split=[0.2, 0.1, 0.1]) except AssertionError as excinfo: @@ -528,16 +555,31 @@ def test_self_energy(): assert dataset.dataset_statistics self_energies = dataset.dataset_statistics.atomic_self_energies # 5 elements present in the total QM9 dataset - # but only 4 in the reduced QM9 dataset - assert len(self_energies) == 4 + assert len(self_energies) == 5 + # value from DFT calculation # H: -1313.4668615546 - assert np.isclose(self_energies[1], -1584.5087457646314) + assert np.isclose( + self_energies[1], + -1577.0870687452618, + ) + # value from DFT calculation # C: -99366.70745535441 - assert np.isclose(self_energies[6], -99960.8894178209) + assert np.isclose( + self_energies[6], + -99977.40806211969, + ) + # value from DFT calculation # N: -143309.9379722722 - assert np.isclose(self_energies[7], -143754.02638655982) + assert np.isclose( + self_energies[7], + -143742.7416655554, + ) + # value from DFT calculation # O: -197082.0671774158 - assert np.isclose(self_energies[8], -197495.00132926635) + assert np.isclose( + self_energies[8], + -197492.33270235246, + ) dataset.prepare_data( remove_self_energies=True, normalize=False, regression_ase=True diff --git a/scripts/dataset_curation.py b/scripts/dataset_curation.py index a513a910..cbb19880 100644 --- a/scripts/dataset_curation.py +++ b/scripts/dataset_curation.py @@ -3,7 +3,9 @@ def SPICE_2( output_file_dir: str, local_cache_dir: str, force_download: bool = False, - unit_testing_max_records=None, + max_records=None, + max_conformers_per_record=None, + total_conformers=None, ): """ This Fetches the SPICE 2 dataset from MOLSSI QCArchive and processes it into a curated hdf5 file. @@ -65,14 +67,15 @@ def SPICE_2( output_file_dir=output_file_dir, local_cache_dir=local_cache_dir, ) - if unit_testing_max_records is None: - spice_2_data.process(force_download=force_download, n_threads=4) - else: - spice_2_data.process( - force_download=force_download, - unit_testing_max_records=unit_testing_max_records, - n_threads=4, - ) + + spice_2_data.process( + force_download=force_download, + max_records=max_records, + max_conformers_per_record=max_conformers_per_record, + total_conformers=total_conformers, + ) + print(f"Total records: {spice_2_data.total_records}") + print(f"Total conformers: {spice_2_data.total_conformers}") def SPICE_114_OpenFF( @@ -140,6 +143,7 @@ def SPICE_114_OpenFF( max_records=max_records, max_conformers_per_record=max_conformers_per_record, total_conformers=total_conformers, + n_threads=1, ) print(f"Total records: {spice_dataset.total_records}") print(f"Total conformers: {spice_dataset.total_conformers}") @@ -415,23 +419,39 @@ def ANI2x( # # SPICE 2 dataset # local_cache_dir = f"{local_prefix}/spice2_dataset" +# hdf5_file_name = "spice_2_dataset_ntc_1000.hdf5" +# +# SPICE_2( +# hdf5_file_name, +# output_file_dir, +# local_cache_dir, +# force_download=False, +# max_conformers_per_record=10, +# total_conformers=1000, +# ) + # hdf5_file_name = "spice_2_dataset.hdf5" # -# SPICE_2(hdf5_file_name, output_file_dir, local_cache_dir, force_download=False) +# SPICE_2( +# hdf5_file_name, +# output_file_dir, +# local_cache_dir, +# force_download=False, +# ) # # SPICE 1.1.4 OpenFF dataset -local_cache_dir = f"{local_prefix}/spice_openff_dataset" -hdf5_file_name = "spice_114_openff_dataset_ntc_1000.hdf5" - -SPICE_114_OpenFF( - hdf5_file_name, - output_file_dir, - local_cache_dir, - force_download=False, - max_records=10000, - total_conformers=1000, - max_conformers_per_record=10, -) +# local_cache_dir = f"{local_prefix}/spice_openff_dataset" +# hdf5_file_name = "spice_114_openff_dataset_ntc_1000.hdf5" +# +# SPICE_114_OpenFF( +# hdf5_file_name, +# output_file_dir, +# local_cache_dir, +# force_download=False, +# max_records=10000, +# total_conformers=1000, +# max_conformers_per_record=10, +# ) # # SPICE 1.1.4 dataset # local_cache_dir = f"{local_prefix}/spice_114_dataset" From 9b908f837d50e9557b95ca02ec1578688bd167a2 Mon Sep 17 00:00:00 2001 From: chrisiacovella Date: Wed, 1 May 2024 23:08:57 -0700 Subject: [PATCH 26/37] updated curation tests. --- modelforge/curation/ani1x_curation.py | 2 +- modelforge/curation/ani2x_curation.py | 2 +- modelforge/curation/spice_114_curation.py | 2 +- modelforge/tests/test_curation.py | 66 +++++++++++++++++++++-- 4 files changed, 65 insertions(+), 7 deletions(-) diff --git a/modelforge/curation/ani1x_curation.py b/modelforge/curation/ani1x_curation.py index d62c57d2..02784bea 100644 --- a/modelforge/curation/ani1x_curation.py +++ b/modelforge/curation/ani1x_curation.py @@ -379,7 +379,7 @@ def process( """ if max_records is not None and total_conformers is not None: - raise ValueError( + raise Exception( "max_records and total_conformers cannot be set at the same time." ) diff --git a/modelforge/curation/ani2x_curation.py b/modelforge/curation/ani2x_curation.py index a7720e33..2804cf7b 100644 --- a/modelforge/curation/ani2x_curation.py +++ b/modelforge/curation/ani2x_curation.py @@ -258,7 +258,7 @@ def process( """ if max_records is not None and total_conformers is not None: - raise ValueError( + raise Exception( "max_records and total_conformers cannot be set at the same time." ) diff --git a/modelforge/curation/spice_114_curation.py b/modelforge/curation/spice_114_curation.py index 28e4c70f..ec3f2d38 100644 --- a/modelforge/curation/spice_114_curation.py +++ b/modelforge/curation/spice_114_curation.py @@ -326,7 +326,7 @@ def process( """ if max_records is not None and total_conformers is not None: - raise ValueError( + raise Exception( "max_records and total_conformers cannot be set at the same time." ) from modelforge.utils.remote import download_from_zenodo diff --git a/modelforge/tests/test_curation.py b/modelforge/tests/test_curation.py index 577f1985..edde0463 100644 --- a/modelforge/tests/test_curation.py +++ b/modelforge/tests/test_curation.py @@ -12,6 +12,7 @@ from modelforge.curation.spice_114_curation import SPICE114Curation from modelforge.curation.spice_openff_curation import SPICEOpenFFCuration from modelforge.curation.spice_2_from_qcarchive_curation import SPICE2Curation +from modelforge.curation.spice_2_curation import SPICE2Curation as SPICE2CurationH5 from modelforge.curation.curation_baseclass import dict_to_hdf5 @@ -464,14 +465,13 @@ def test_qm9_local_archive(prep_temp_dir): qm9_data._clear_data() qm9_data._process_downloaded( str(prep_temp_dir), - max_records=2, total_conformers=5, ) assert qm9_data.total_conformers == 5 assert len(qm9_data.data) == 5 -def test_an1_process_download_short(prep_temp_dir): +def test_ani1_process_download_short(prep_temp_dir): # first check where we don't convert units ani1_data = ANI1xCuration( hdf5_file_name="test_dataset.hdf5", @@ -516,9 +516,28 @@ def test_an1_process_download_short(prep_temp_dir): # test max records exclusion ani1_data._process_downloaded(str(local_data_path), hdf5_file, max_records=2) assert len(ani1_data.data) == 2 + ani1_data._clear_data() + ani1_data._process_downloaded( + str(local_data_path), hdf5_file, max_records=2, max_conformers_per_record=1 + ) + assert ani1_data.total_conformers == 2 + + ani1_data._clear_data() + ani1_data._process_downloaded( + str(local_data_path), hdf5_file, max_records=3, max_conformers_per_record=2 + ) + assert ani1_data.total_conformers > 3 + + ani1_data._clear_data() + ani1_data._process_downloaded( + str(local_data_path), hdf5_file, total_conformers=5, max_conformers_per_record=2 + ) + assert ani1_data.total_conformers == 5 + with pytest.raises(Exception): + ani1_data.process(max_records=10, total_conformers=5) -def test_an1_process_download_no_conversion(prep_temp_dir): +def test_ani1_process_download_no_conversion(prep_temp_dir): from numpy import array, float32, uint8 from openff.units import unit @@ -1481,6 +1500,34 @@ def test_spice114_process_download_conversion(prep_temp_dir): * unit.parse_expression("elementary_charge * nanometer ** 2"), ) ) + spice_data._clear_data() + spice_data._process_downloaded( + str(local_data_path), hdf5_file, max_records=1, max_conformers_per_record=1 + ) + assert spice_data.total_conformers == 1 + + spice_data._clear_data() + spice_data._process_downloaded( + str(local_data_path), hdf5_file, max_records=2, max_conformers_per_record=1 + ) + assert spice_data.total_conformers == 2 + + spice_data._clear_data() + spice_data._process_downloaded( + str(local_data_path), hdf5_file, total_conformers=4, max_conformers_per_record=2 + ) + assert spice_data.total_conformers == 4 + assert spice_data.total_records == 2 + + spice_data._clear_data() + + spice_data._process_downloaded( + str(local_data_path), hdf5_file, max_records=1, max_conformers_per_record=1 + ) + assert spice_data.total_conformers == 1 + + with pytest.raises(Exception): + spice_data.process(max_records=2, total_conformers=1) def test_ani2x(prep_temp_dir): @@ -1538,6 +1585,17 @@ def test_ani2x(prep_temp_dir): [[0.0, 0.0, -0.08543934673070908], [0.0, 0.0, 0.009493260644376278]] ) ) + ani2x_dataset._clear_data() + ani2x_dataset._process_downloaded(local_data_path, filename, total_conformers=10) + assert ani2x_dataset.total_conformers == 10 + + ani2x_dataset._clear_data() + ani2x_dataset._process_downloaded( + local_data_path, filename, max_records=2, max_conformers_per_record=2 + ) + assert ani2x_dataset.total_conformers > 2 + with pytest.raises(Exception): + ani2x_dataset.process(max_records=2, total_conformers=1) def test_spice114_openff_test_fetching(prep_temp_dir): @@ -1741,7 +1799,7 @@ def test_spice114_openff_test_process_downloaded(prep_temp_dir): local_database_name=local_database_name, local_path_dir=local_path_dir, force_download=True, - umax_records=2, + max_records=2, ) spice_openff_data._process_downloaded( From 0764120c4fef8492a7ce648594a4fe30c8c416d7 Mon Sep 17 00:00:00 2001 From: chrisiacovella Date: Wed, 1 May 2024 23:27:44 -0700 Subject: [PATCH 27/37] merging --- modelforge/potential/processing.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/modelforge/potential/processing.py b/modelforge/potential/processing.py index 217c2059..0b06ce41 100644 --- a/modelforge/potential/processing.py +++ b/modelforge/potential/processing.py @@ -111,6 +111,12 @@ class AtomicSelfEnergies: 48: "Cd", 49: "In", 50: "Sn", + 51: "Sb", + 52: "Te", + 53: "I", + 54: "Xe", + 55: "Cs", + 56: "Ba", # Add more elements as needed } ) From 1af1f2e8b22ae6048bcd1d7c78b3cd432c558984 Mon Sep 17 00:00:00 2001 From: chrisiacovella Date: Wed, 1 May 2024 23:31:28 -0700 Subject: [PATCH 28/37] fixing formatting issue from merge. --- modelforge/potential/utils.py | 232 +++++++++++++++++----------------- 1 file changed, 116 insertions(+), 116 deletions(-) diff --git a/modelforge/potential/utils.py b/modelforge/potential/utils.py index c560b90b..1491e122 100644 --- a/modelforge/potential/utils.py +++ b/modelforge/potential/utils.py @@ -52,9 +52,9 @@ class NNPInput: total_charge: torch.Tensor def to( - self, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, + self, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, ): """Move all tensors in this instance to the specified device/dtype.""" @@ -134,7 +134,7 @@ class Metadata: F: torch.Tensor = torch.tensor([], dtype=torch.float32) def to( - self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None + self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None ): """Move all tensors in this instance to the specified device.""" if device: @@ -156,9 +156,9 @@ class BatchData: metadata: Metadata def to( - self, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, + self, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, ): self.nnp_input = self.nnp_input.to(device=device, dtype=dtype) self.metadata = self.metadata.to(device=device, dtype=dtype) @@ -177,7 +177,7 @@ def shared_config_prior(): def triple_by_molecule( - atom_pairs: torch.Tensor, + atom_pairs: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Input: indices for pairs of atoms that are close to each other. each pair only appear once, i.e. only one of the pairs (1, 2) and @@ -213,8 +213,8 @@ def cumsum_from_zero(input_: torch.Tensor) -> torch.Tensor: torch.tril_indices(m, m, -1, device=ai1.device).unsqueeze(1).expand(-1, n, -1) ) mask = ( - torch.arange(intra_pair_indices.shape[2], device=ai1.device) - < pair_sizes.unsqueeze(1) + torch.arange(intra_pair_indices.shape[2], device=ai1.device) + < pair_sizes.unsqueeze(1) ).flatten() sorted_local_index12 = intra_pair_indices.flatten(1, 2)[:, mask] sorted_local_index12 += cumsum_from_zero(counts).index_select(0, pair_indices) @@ -291,13 +291,13 @@ class Dense(nn.Linear): """ def __init__( - self, - in_features: int, - out_features: int, - bias: bool = True, - activation: Optional[nn.Module] = None, - weight_init: Callable = xavier_uniform_, - bias_init: Callable = zeros_, + self, + in_features: int, + out_features: int, + bias: bool = True, + activation: Optional[nn.Module] = None, + weight_init: Callable = xavier_uniform_, + bias_init: Callable = zeros_, ): """ Args: @@ -362,7 +362,7 @@ def forward(self, d_ij: torch.Tensor): """ # Compute values of cutoff function input_cut = 0.5 * ( - torch.cos(d_ij * np.pi / self.cutoff) + 1.0 + torch.cos(d_ij * np.pi / self.cutoff) + 1.0 ) # NOTE: ANI adds 0.5 instead of 1. # Remove contributions beyond the cutoff radius input_cut *= (d_ij < self.cutoff).float() @@ -372,7 +372,6 @@ def forward(self, d_ij: torch.Tensor): from typing import Dict - class ShiftedSoftplus(nn.Module): def __init__(self): super().__init__() @@ -407,13 +406,13 @@ class AngularSymmetryFunction(nn.Module): """ def __init__( - self, - max_distance: unit.Quantity, - min_distance: unit.Quantity, - number_of_gaussians_for_asf: int = 8, - angle_sections: int = 4, - trainable: bool = False, - dtype: Optional[torch.dtype] = None, + self, + max_distance: unit.Quantity, + min_distance: unit.Quantity, + number_of_gaussians_for_asf: int = 8, + angle_sections: int = 4, + trainable: bool = False, + dtype: Optional[torch.dtype] = None, ) -> None: """ Parameters @@ -537,21 +536,21 @@ def compute(self, distances, centers, scale_factors): class GaussianRadialBasisFunction(RadialBasisFunction): def compute( - self, - distances: torch.Tensor, - centers: torch.Tensor, - scale_factors: torch.Tensor, + self, + distances: torch.Tensor, + centers: torch.Tensor, + scale_factors: torch.Tensor, ) -> torch.Tensor: diff = distances - centers - return torch.exp((-1 * scale_factors) * diff ** 2) + return torch.exp((-1 * scale_factors) * diff**2) class DoubleExponentialRadialBasisFunction(RadialBasisFunction): def compute( - self, - distances: torch.Tensor, - centers: torch.Tensor, - scale_factors: torch.Tensor, + self, + distances: torch.Tensor, + centers: torch.Tensor, + scale_factors: torch.Tensor, ) -> torch.Tensor: diff = distances - centers return torch.exp(-torch.abs(diff / scale_factors)) @@ -559,13 +558,13 @@ def compute( class RadialSymmetryFunction(nn.Module): def __init__( - self, - number_of_radial_basis_functions: int, - max_distance: unit.Quantity, - min_distance: unit.Quantity = 0.0 * unit.nanometer, - dtype: Optional[torch.dtype] = None, - trainable: bool = False, - radial_basis_function: RadialBasisFunction = GaussianRadialBasisFunction(), + self, + number_of_radial_basis_functions: int, + max_distance: unit.Quantity, + min_distance: unit.Quantity = 0.0 * unit.nanometer, + dtype: Optional[torch.dtype] = None, + trainable: bool = False, + radial_basis_function: RadialBasisFunction = GaussianRadialBasisFunction(), ): """RadialSymmetryFunction class. @@ -630,11 +629,11 @@ def initialize_parameters(self): self.register_buffer("prefactor", torch.tensor([1.0])) def calculate_radial_basis_centers( - self, - _min_distance_in_nanometer, - _max_distance_in_nanometer, - number_of_radial_basis_functions, - dtype, + self, + _min_distance_in_nanometer, + _max_distance_in_nanometer, + number_of_radial_basis_functions, + dtype, ): # the default approach to calculate radial basis centers # can be overwritten by subclasses @@ -647,10 +646,10 @@ def calculate_radial_basis_centers( return centers def calculate_radial_scale_factor( - self, - _min_distance_in_nanometer, - _max_distance_in_nanometer, - number_of_radial_basis_functions, + self, + _min_distance_in_nanometer, + _max_distance_in_nanometer, + number_of_radial_basis_functions, ): # the default approach to calculate radial scale factors (each of them are scaled by the same value) # can be overwritten by subclasses @@ -683,13 +682,13 @@ def forward(self, d_ij: torch.Tensor) -> torch.Tensor: class SchnetRadialSymmetryFunction(RadialSymmetryFunction): def __init__( - self, - number_of_radial_basis_functions: int, - max_distance: unit.Quantity, - min_distance: unit.Quantity = 0.0 * unit.nanometer, - dtype: Optional[torch.dtype] = None, - trainable: bool = False, - radial_basis_function: RadialBasisFunction = GaussianRadialBasisFunction(), + self, + number_of_radial_basis_functions: int, + max_distance: unit.Quantity, + min_distance: unit.Quantity = 0.0 * unit.nanometer, + dtype: Optional[torch.dtype] = None, + trainable: bool = False, + radial_basis_function: RadialBasisFunction = GaussianRadialBasisFunction(), ): """RadialSymmetryFunction class. @@ -710,10 +709,10 @@ def __init__( self.prefactor = torch.tensor([1.0]) def calculate_radial_scale_factor( - self, - _min_distance_in_nanometer, - _max_distance_in_nanometer, - number_of_radial_basis_functions, + self, + _min_distance_in_nanometer, + _max_distance_in_nanometer, + number_of_radial_basis_functions, ): scale_factors = torch.linspace( _min_distance_in_nanometer, @@ -722,8 +721,8 @@ def calculate_radial_scale_factor( ) widths = ( - torch.abs(scale_factors[1] - scale_factors[0]) - * torch.ones_like(scale_factors) + torch.abs(scale_factors[1] - scale_factors[0]) + * torch.ones_like(scale_factors) ).to(self.dtype) scale_factors = 0.5 / torch.square_(widths) @@ -732,13 +731,13 @@ def calculate_radial_scale_factor( class AniRadialSymmetryFunction(RadialSymmetryFunction): def __init__( - self, - number_of_radial_basis_functions: int, - max_distance: unit.Quantity, - min_distance: unit.Quantity = 0.0 * unit.nanometer, - dtype: Optional[torch.dtype] = None, - trainable: bool = False, - radial_basis_function: RadialBasisFunction = GaussianRadialBasisFunction(), + self, + number_of_radial_basis_functions: int, + max_distance: unit.Quantity, + min_distance: unit.Quantity = 0.0 * unit.nanometer, + dtype: Optional[torch.dtype] = None, + trainable: bool = False, + radial_basis_function: RadialBasisFunction = GaussianRadialBasisFunction(), ): """RadialSymmetryFunction class. @@ -759,11 +758,11 @@ def __init__( self.prefactor = torch.tensor([0.25]) def calculate_radial_basis_centers( - self, - _min_distance_in_nanometer, - _max_distance_in_nanometer, - number_of_radial_basis_functions, - dtype, + self, + _min_distance_in_nanometer, + _max_distance_in_nanometer, + number_of_radial_basis_functions, + dtype, ): centers = torch.linspace( _min_distance_in_nanometer, @@ -775,10 +774,10 @@ def calculate_radial_basis_centers( return centers def calculate_radial_scale_factor( - self, - _min_distance_in_nanometer, - _max_distance_in_nanometer, - number_of_radial_basis_functions, + self, + _min_distance_in_nanometer, + _max_distance_in_nanometer, + number_of_radial_basis_functions, ): # ANI uses a predefined scaling factor scale_factors = torch.full((number_of_radial_basis_functions,), (19.7 * 100)) @@ -787,11 +786,11 @@ def calculate_radial_scale_factor( class SAKERadialSymmetryFunction(RadialSymmetryFunction): def calculate_radial_basis_centers( - self, - _min_distance_in_nanometer, - _max_distance_in_nanometer, - number_of_radial_basis_functions, - dtype, + self, + _min_distance_in_nanometer, + _max_distance_in_nanometer, + number_of_radial_basis_functions, + dtype, ): # initialize means and betas according to the default values in PhysNet # https://pubs.acs.org/doi/10.1021/acs.jctc.9b00181 @@ -808,10 +807,10 @@ def calculate_radial_basis_centers( return centers def calculate_radial_scale_factor( - self, - _min_distance_in_nanometer, - _max_distance_in_nanometer, - number_of_radial_basis_functions, + self, + _min_distance_in_nanometer, + _max_distance_in_nanometer, + number_of_radial_basis_functions, ): start_value = torch.exp( torch.scalar_tensor( @@ -819,7 +818,10 @@ def calculate_radial_scale_factor( ) ) # NOTE: this is defined in Angstrom radial_scale_factor = torch.tensor( - torch.full((number_of_radial_basis_functions, ), (2 / number_of_radial_basis_functions * (1 - start_value)) ** -2) + torch.full( + (number_of_radial_basis_functions,), + (2 / number_of_radial_basis_functions * (1 - start_value)) ** -2, + ) ) return radial_scale_factor @@ -828,7 +830,6 @@ class SAKERadialBasisFunction(RadialBasisFunction): def __init__(self, min_distance): super().__init__() -) # check units self._min_distance_in_nanometer = min_distance.to(unit.nanometer).m def compute( @@ -840,11 +841,10 @@ def compute( return torch.exp( -scale_factors * ( - torch.exp( - (-distances.unsqueeze(-1) + self._min_distance_in_nanometer) - * 10 - ) - - centers + torch.exp( + (-distances.unsqueeze(-1) + self._min_distance_in_nanometer) * 10 + ) + - centers ) ** 2 ) @@ -853,13 +853,13 @@ def compute( class PhysNetRadialSymmetryFunction(SAKERadialSymmetryFunction): def __init__( - self, - number_of_radial_basis_functions: int, - max_distance: unit.Quantity, - min_distance: unit.Quantity = 0.0 * unit.nanometer, - dtype: Optional[torch.dtype] = None, - trainable: bool = False, - radial_basis_function: Optional[SAKERadialBasisFunction] = None, + self, + number_of_radial_basis_functions: int, + max_distance: unit.Quantity, + min_distance: unit.Quantity = 0.0 * unit.nanometer, + dtype: Optional[torch.dtype] = None, + trainable: bool = False, + radial_basis_function: Optional[SAKERadialBasisFunction] = None, ): """RadialSymmetryFunction class. @@ -884,8 +884,8 @@ def __init__( def pair_list( - atomic_subsystem_indices: torch.Tensor, - only_unique_pairs: bool = False, + atomic_subsystem_indices: torch.Tensor, + only_unique_pairs: bool = False, ) -> torch.Tensor: """Compute all pairs of atoms and their distances. @@ -926,7 +926,7 @@ def pair_list( # filter pairs to only keep those belonging to the same molecule same_molecule_mask = ( - atomic_subsystem_indices[i_indices] == atomic_subsystem_indices[j_indices] + atomic_subsystem_indices[i_indices] == atomic_subsystem_indices[j_indices] ) # Apply mask to get final pair indices @@ -940,10 +940,10 @@ def pair_list( def neighbor_list_with_cutoff( - coordinates: torch.Tensor, # in nanometer - atomic_subsystem_indices: torch.Tensor, - cutoff: unit.Quantity, - only_unique_pairs: bool = False, + coordinates: torch.Tensor, # in nanometer + atomic_subsystem_indices: torch.Tensor, + cutoff: unit.Quantity, + only_unique_pairs: bool = False, ) -> torch.Tensor: """Compute all pairs of atoms and their distances. @@ -980,11 +980,11 @@ def neighbor_list_with_cutoff( def scatter_softmax( - src: torch.Tensor, - index: torch.Tensor, - dim: int, - dim_size: Optional[int] = None, - device: Optional[torch.device] = None, + src: torch.Tensor, + index: torch.Tensor, + dim: int, + dim_size: Optional[int] = None, + device: Optional[torch.device] = None, ) -> torch.Tensor: """ Softmax operation over all values in :attr:`src` tensor that share indices @@ -1018,7 +1018,7 @@ def scatter_softmax( assert dim >= 0, f"dim must be non-negative, got {dim}" assert ( - dim < src.dim() + dim < src.dim() ), f"dim must be less than the number of dimensions of src {src.dim()}, got {dim}" out_shape = [ From 8fa058bf2939b301f44984a45330a90d8ba1dfa4 Mon Sep 17 00:00:00 2001 From: chrisiacovella Date: Wed, 1 May 2024 23:47:06 -0700 Subject: [PATCH 29/37] AtomicSelfEnergies was moved from utils to processing; updated processing to have my improved class that uses units. --- modelforge/potential/processing.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/modelforge/potential/processing.py b/modelforge/potential/processing.py index 0b06ce41..6b447095 100644 --- a/modelforge/potential/processing.py +++ b/modelforge/potential/processing.py @@ -43,6 +43,8 @@ def forward( from dataclasses import dataclass, field from typing import Dict, Iterator +from openff.units import unit + @dataclass class AtomicSelfEnergies: @@ -57,7 +59,7 @@ class AtomicSelfEnergies: # We provide a dictionary with {str:float} of element name to atomic self-energy, # which can then be accessed by atomic index or element name - energies: Dict[str, float] = field(default_factory=dict) + energies: Dict[str, unit.Quantity] = field(default_factory=dict) # Example mapping, replace or extend as necessary atomic_number_to_element: Dict[int, str] = field( default_factory=lambda: { @@ -123,17 +125,24 @@ class AtomicSelfEnergies: _ase_tensor_for_indexing = None def __getitem__(self, key): + from modelforge.utils.units import chem_context + if isinstance(key, int): # Convert atomic number to element symbol element = self.atomic_number_to_element.get(key) if element is None: raise KeyError(f"Atomic number {key} not found.") - return self.energies.get(element) + if self.energies.get(element) is None: + return None + return self.energies.get(element).to(unit.kilojoule_per_mole, "chem").m elif isinstance(key, str): # Directly access by element symbol if key not in self.energies: raise KeyError(f"Element {key} not found.") - return self.energies[key] + if self.energies[key] is None: + return None + + return self.energies[key].to(unit.kilojoule_per_mole, "chem").m else: raise TypeError( "Key must be an integer (atomic number) or string (element name)." @@ -141,9 +150,11 @@ def __getitem__(self, key): def __iter__(self) -> Iterator[Dict[str, float]]: """Iterate over the energies dictionary.""" + from modelforge.utils.units import chem_context + for element, energy in self.energies.items(): atomic_number = self.element_to_atomic_number(element) - yield (atomic_number, energy) + yield (atomic_number, energy.to(unit.kilojoule_per_mole, "chem").m) def __len__(self) -> int: """Return the number of element-energy pairs.""" From d0797487d7ab6ccb9b172fd35bf4d4397d557204 Mon Sep 17 00:00:00 2001 From: chrisiacovella Date: Wed, 1 May 2024 23:53:29 -0700 Subject: [PATCH 30/37] Fixed imports for self energies having moved to processing. --- modelforge/dataset/ani1x.py | 2 +- modelforge/dataset/ani2x.py | 2 +- modelforge/dataset/dataset.py | 2 +- modelforge/dataset/spice114.py | 2 +- modelforge/dataset/spice114openff.py | 2 +- modelforge/dataset/spice2.py | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/modelforge/dataset/ani1x.py b/modelforge/dataset/ani1x.py index 95f4d09d..3f457cd6 100644 --- a/modelforge/dataset/ani1x.py +++ b/modelforge/dataset/ani1x.py @@ -180,7 +180,7 @@ def __init__( @property def atomic_self_energies(self): - from modelforge.potential.utils import AtomicSelfEnergies + from modelforge.potential.processing import AtomicSelfEnergies return AtomicSelfEnergies(energies=self._ase) diff --git a/modelforge/dataset/ani2x.py b/modelforge/dataset/ani2x.py index 36617ce0..82105c52 100644 --- a/modelforge/dataset/ani2x.py +++ b/modelforge/dataset/ani2x.py @@ -164,7 +164,7 @@ def __init__( @property def atomic_self_energies(self): - from modelforge.potential.utils import AtomicSelfEnergies + from modelforge.potential.processing import AtomicSelfEnergies return AtomicSelfEnergies(energies=self._ase) diff --git a/modelforge/dataset/dataset.py b/modelforge/dataset/dataset.py index bbe633df..9bc8ceed 100644 --- a/modelforge/dataset/dataset.py +++ b/modelforge/dataset/dataset.py @@ -14,7 +14,7 @@ from dataclasses import dataclass if TYPE_CHECKING: - from modelforge.potential.utils import BatchData, AtomicSelfEnergies + from modelforge.potential.processing import AtomicSelfEnergies @dataclass diff --git a/modelforge/dataset/spice114.py b/modelforge/dataset/spice114.py index eaa230e5..991b3150 100644 --- a/modelforge/dataset/spice114.py +++ b/modelforge/dataset/spice114.py @@ -200,7 +200,7 @@ def __init__( @property def atomic_self_energies(self): - from modelforge.potential.utils import AtomicSelfEnergies + from modelforge.potential.processing import AtomicSelfEnergies return AtomicSelfEnergies(energies=self._ase) diff --git a/modelforge/dataset/spice114openff.py b/modelforge/dataset/spice114openff.py index c8161235..751fc053 100644 --- a/modelforge/dataset/spice114openff.py +++ b/modelforge/dataset/spice114openff.py @@ -210,7 +210,7 @@ def __init__( @property def atomic_self_energies(self): - from modelforge.potential.utils import AtomicSelfEnergies + from modelforge.potential.processing import AtomicSelfEnergies return AtomicSelfEnergies(energies=self._ase) diff --git a/modelforge/dataset/spice2.py b/modelforge/dataset/spice2.py index 1d9ee8a3..f2d2d09d 100644 --- a/modelforge/dataset/spice2.py +++ b/modelforge/dataset/spice2.py @@ -225,7 +225,7 @@ def __init__( @property def atomic_self_energies(self): - from modelforge.potential.utils import AtomicSelfEnergies + from modelforge.potential.processing import AtomicSelfEnergies return AtomicSelfEnergies(energies=self._ase) From a2af7f15bc9d3845d5ef85d4dc8e87b91bb49808 Mon Sep 17 00:00:00 2001 From: chrisiacovella Date: Thu, 2 May 2024 00:01:20 -0700 Subject: [PATCH 31/37] Fixed imports for self energies having moved to processing. --- modelforge/tests/conftest.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/modelforge/tests/conftest.py b/modelforge/tests/conftest.py index 9396b829..81ab5d41 100644 --- a/modelforge/tests/conftest.py +++ b/modelforge/tests/conftest.py @@ -176,8 +176,7 @@ def initialized_dataset(datasets_to_test): return initialize_dataset(dataset) -@pytest.fixture(params=_DATASETS_TO_TEST) -def batch(initialized_dataset, request): +def batch(initialized_dataset): """py Fixture to obtain a single batch from an initialized dataset. From 1f6a2dc74cf28a3dae79cdf885662a25621279da Mon Sep 17 00:00:00 2001 From: chrisiacovella Date: Thu, 2 May 2024 08:05:10 -0700 Subject: [PATCH 32/37] restricting test_training to qm9 (not all NNPs are compatible with all datasets at the current moment). --- modelforge/tests/conftest.py | 4 ++-- modelforge/tests/test_training.py | 5 ++--- modelforge/tests/test_utils.py | 1 + 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/modelforge/tests/conftest.py b/modelforge/tests/conftest.py index 81ab5d41..c95b7864 100644 --- a/modelforge/tests/conftest.py +++ b/modelforge/tests/conftest.py @@ -189,10 +189,10 @@ def batch(initialized_dataset): # Fixture for initializing QM9Dataset @pytest.fixture -def qm9_dataset(): +def qm9_dataset(prep_temp_dir): from modelforge.dataset import QM9Dataset - dataset = QM9Dataset(for_unit_testing=True) + dataset = QM9Dataset(for_unit_testing=True, local_cache_dir=str(prep_temp_dir)) return dataset diff --git a/modelforge/tests/test_training.py b/modelforge/tests/test_training.py index f8ad1ea3..f1776b0d 100644 --- a/modelforge/tests/test_training.py +++ b/modelforge/tests/test_training.py @@ -9,7 +9,7 @@ @pytest.mark.skipif(ON_MACOS, reason="Skipping this test on MacOS GitHub Actions") -def test_train_with_lightning(train_model, initialized_dataset): +def test_train_with_lightning(train_model, qm9_dataset): """ Test the forward pass for a given model and dataset. @@ -37,9 +37,8 @@ def test_train_with_lightning(train_model, initialized_dataset): ) - @pytest.mark.skipif(ON_MACOS, reason="Skipping this test on MacOS GitHub Actions") -def test_hypterparameter_tuning_with_ray(train_model, initialized_dataset): +def test_hypterparameter_tuning_with_ray(train_model, qm9_dataset): train_model.tune_with_ray( train_dataloader=initialized_dataset.train_dataloader(), diff --git a/modelforge/tests/test_utils.py b/modelforge/tests/test_utils.py index 8fa6b440..480fdcaf 100644 --- a/modelforge/tests/test_utils.py +++ b/modelforge/tests/test_utils.py @@ -8,6 +8,7 @@ def test_ase_dataclass(): from modelforge.potential.processing import AtomicSelfEnergies + from openff.units import unit # Example usage atomic_self_energies = AtomicSelfEnergies( From cec96e77703787d3de97e10fd9c4355796830bd8 Mon Sep 17 00:00:00 2001 From: chrisiacovella Date: Thu, 2 May 2024 08:20:50 -0700 Subject: [PATCH 33/37] restricting test_training to qm9 (not all NNPs are compatible with all datasets at the current moment). --- modelforge/tests/test_training.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/modelforge/tests/test_training.py b/modelforge/tests/test_training.py index f1776b0d..bbe0aad6 100644 --- a/modelforge/tests/test_training.py +++ b/modelforge/tests/test_training.py @@ -32,8 +32,8 @@ def test_train_with_lightning(train_model, qm9_dataset): # Run training loop and validate trainer.fit( model, - initialized_dataset.train_dataloader(), - initialized_dataset.val_dataloader(), + qm9_dataset.train_dataloader(), + qm9_dataset.val_dataloader(), ) @@ -41,8 +41,8 @@ def test_train_with_lightning(train_model, qm9_dataset): def test_hypterparameter_tuning_with_ray(train_model, qm9_dataset): train_model.tune_with_ray( - train_dataloader=initialized_dataset.train_dataloader(), - val_dataloader=initialized_dataset.val_dataloader(), + train_dataloader=qm9_dataset.train_dataloader(), + val_dataloader=qm9_dataset.val_dataloader(), number_of_ray_workers=1, number_of_epochs=1, number_of_samples=1, From 73a36cb2b927502048a4f116caca09cfc0b65697 Mon Sep 17 00:00:00 2001 From: chrisiacovella Date: Thu, 2 May 2024 08:30:35 -0700 Subject: [PATCH 34/37] restricting test_training to qm9 (not all NNPs are compatible with all datasets at the current moment). --- modelforge/tests/conftest.py | 5 +++++ modelforge/tests/test_training.py | 12 ++++++------ 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/modelforge/tests/conftest.py b/modelforge/tests/conftest.py index c95b7864..3b20f4db 100644 --- a/modelforge/tests/conftest.py +++ b/modelforge/tests/conftest.py @@ -196,6 +196,11 @@ def qm9_dataset(prep_temp_dir): return dataset +@pytest.fixture +def initialized_qm9_dataset(qm9_dataset): + return initialize_dataset(qm9_dataset) + + # Fixture for generating simplified input data @pytest.fixture(params=["methane", "qm9_batch"]) def simplified_input_data(request, qm9_batch): diff --git a/modelforge/tests/test_training.py b/modelforge/tests/test_training.py index bbe0aad6..ef0916bf 100644 --- a/modelforge/tests/test_training.py +++ b/modelforge/tests/test_training.py @@ -9,7 +9,7 @@ @pytest.mark.skipif(ON_MACOS, reason="Skipping this test on MacOS GitHub Actions") -def test_train_with_lightning(train_model, qm9_dataset): +def test_train_with_lightning(train_model, initialized_qm9_dataset): """ Test the forward pass for a given model and dataset. @@ -32,17 +32,17 @@ def test_train_with_lightning(train_model, qm9_dataset): # Run training loop and validate trainer.fit( model, - qm9_dataset.train_dataloader(), - qm9_dataset.val_dataloader(), + initialized_qm9_dataset.train_dataloader(), + initialized_qm9_dataset.val_dataloader(), ) @pytest.mark.skipif(ON_MACOS, reason="Skipping this test on MacOS GitHub Actions") -def test_hypterparameter_tuning_with_ray(train_model, qm9_dataset): +def test_hypterparameter_tuning_with_ray(train_model, initialized_qm9_dataset): train_model.tune_with_ray( - train_dataloader=qm9_dataset.train_dataloader(), - val_dataloader=qm9_dataset.val_dataloader(), + train_dataloader=initialized_qm9_dataset.train_dataloader(), + val_dataloader=initialized_qm9_dataset.val_dataloader(), number_of_ray_workers=1, number_of_epochs=1, number_of_samples=1, From 485f8e7ad0d5240b25c47d32de06654e03dd6484 Mon Sep 17 00:00:00 2001 From: chrisiacovella Date: Thu, 2 May 2024 08:55:14 -0700 Subject: [PATCH 35/37] fixing testing issues. --- modelforge/tests/conftest.py | 1 + 1 file changed, 1 insertion(+) diff --git a/modelforge/tests/conftest.py b/modelforge/tests/conftest.py index 3b20f4db..b0c072eb 100644 --- a/modelforge/tests/conftest.py +++ b/modelforge/tests/conftest.py @@ -176,6 +176,7 @@ def initialized_dataset(datasets_to_test): return initialize_dataset(dataset) +@pytest.fixture() def batch(initialized_dataset): """py Fixture to obtain a single batch from an initialized dataset. From 555e9bf8a76ba56c48ff4d03a189f662c3071b2f Mon Sep 17 00:00:00 2001 From: chrisiacovella Date: Thu, 2 May 2024 11:40:51 -0700 Subject: [PATCH 36/37] CI tests keep stopping for unknown reasons. Reducing number of datasets tested in test_models.py to qm9 and ani2x test sets. --- modelforge/tests/conftest.py | 29 ++++++++++++++++++++++++++++- modelforge/tests/test_models.py | 8 +++++--- 2 files changed, 33 insertions(+), 4 deletions(-) diff --git a/modelforge/tests/conftest.py b/modelforge/tests/conftest.py index b0c072eb..468c4ddc 100644 --- a/modelforge/tests/conftest.py +++ b/modelforge/tests/conftest.py @@ -8,6 +8,7 @@ from dataclasses import dataclass _DATASETS_TO_TEST = [name for name in _IMPLEMENTED_DATASETS] +_DATASETS_TO_TEST_QM9_ANI2X = ["QM9", "ANI2X"] _MODELS_TO_TEST = [name for name in _IMPLEMENTED_NNPS] from modelforge.potential.utils import BatchData @@ -188,7 +189,32 @@ def batch(initialized_dataset): return batch -# Fixture for initializing QM9Dataset +@pytest.fixture(params=_DATASETS_TO_TEST_QM9_ANI2X) +def QM9_ANI2X_to_test(request, prep_temp_dir): + dataset_name = request.param + if dataset_name == "QM9": + from modelforge.dataset.qm9 import QM9Dataset + + return QM9Dataset(for_unit_testing=True, local_cache_dir=str(prep_temp_dir)) + + elif dataset_name == "ANI2X": + from modelforge.dataset.ani2x import ANI2xDataset + + return ANI2xDataset(for_unit_testing=True, local_cache_dir=str(prep_temp_dir)) + + +@pytest.fixture() +def initialized_QM9_ANI2X_dataset(QM9_ANI2X_to_test): + return initialize_dataset(QM9_ANI2X_to_test) + + +@pytest.fixture() +def batch_QM9_ANI2x(initialized_QM9_ANI2X_dataset): + batch = return_single_batch(initialized_QM9_ANI2X_dataset) + return batch + + +# Fixture for setting up QM9Dataset @pytest.fixture def qm9_dataset(prep_temp_dir): from modelforge.dataset import QM9Dataset @@ -197,6 +223,7 @@ def qm9_dataset(prep_temp_dir): return dataset +# fixture for initializing QM9Dataset @pytest.fixture def initialized_qm9_dataset(qm9_dataset): return initialize_dataset(qm9_dataset) diff --git a/modelforge/tests/test_models.py b/modelforge/tests/test_models.py index daa4e959..f64cce95 100644 --- a/modelforge/tests/test_models.py +++ b/modelforge/tests/test_models.py @@ -65,7 +65,7 @@ def test_energy_scaling_and_offset(): ) -def test_forward_pass(inference_model, batch): +def test_forward_pass(inference_model, batch_QM9_ANI2x): # this test sends a single batch from different datasets through the model nnp_input = batch.nnp_input @@ -78,7 +78,7 @@ def test_forward_pass(inference_model, batch): assert len(output) == nr_of_mols -def test_calculate_energies_and_forces(inference_model, batch): +def test_calculate_energies_and_forces(inference_model, batch_QM9_ANI2x): """ Test the calculation of energies and forces for a molecule. """ @@ -347,7 +347,9 @@ def test_casting(batch, inference_model): model(nnp_input) -def test_equivariant_energies_and_forces(batch, inference_model, equivariance_utils): +def test_equivariant_energies_and_forces( + batch_QM9_ANI2x, inference_model, equivariance_utils +): """ Test the calculation of energies and forces for a molecule. NOTE: test will be adapted once we have a trained model. From 0aa23ddd6b35625b5f2dc81ceacfa8ca1f36e53b Mon Sep 17 00:00:00 2001 From: chrisiacovella Date: Thu, 2 May 2024 12:04:52 -0700 Subject: [PATCH 37/37] CI tests keep stopping for unknown reasons. Reducing number of datasets tested in test_models.py to qm9 and ani2x test sets. --- modelforge/tests/test_models.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/modelforge/tests/test_models.py b/modelforge/tests/test_models.py index f64cce95..41e406af 100644 --- a/modelforge/tests/test_models.py +++ b/modelforge/tests/test_models.py @@ -68,6 +68,7 @@ def test_energy_scaling_and_offset(): def test_forward_pass(inference_model, batch_QM9_ANI2x): # this test sends a single batch from different datasets through the model + batch = batch_QM9_ANI2x nnp_input = batch.nnp_input nr_of_mols = nnp_input.atomic_subsystem_indices.unique().shape[0] @@ -82,6 +83,7 @@ def test_calculate_energies_and_forces(inference_model, batch_QM9_ANI2x): """ Test the calculation of energies and forces for a molecule. """ + batch = batch_QM9_ANI2x import torch nnp_input = batch.nnp_input @@ -319,10 +321,11 @@ def test_pairlist_on_dataset(initialized_dataset): assert shapePairlist[0] == 2 -def test_casting(batch, inference_model): +def test_casting(batch_QM9_ANI2x, inference_model): # test dtype casting import torch + batch = batch_QM9_ANI2x batch_ = batch.to(dtype=torch.float64) assert batch_.nnp_input.positions.dtype == torch.float64 batch_ = batch_.to(dtype=torch.float32) @@ -354,6 +357,7 @@ def test_equivariant_energies_and_forces( Test the calculation of energies and forces for a molecule. NOTE: test will be adapted once we have a trained model. """ + batch = batch_QM9_ANI2x import torch from dataclasses import replace