diff --git a/RanDepict/randepict.py b/RanDepict/randepict.py index 9c71f87..ef5dec6 100644 --- a/RanDepict/randepict.py +++ b/RanDepict/randepict.py @@ -28,6 +28,7 @@ from dataclasses import dataclass, field from indigo import Indigo +from indigo import IndigoException from indigo.renderer import IndigoRenderer from jpype import startJVM, getDefaultJVMPath from jpype import JClass, JVMNotFoundException, isJVMStarted @@ -868,11 +869,14 @@ def depict_and_resize_indigo( # Instantiate Indigo with random settings and IndigoRenderer indigo, renderer = self.get_random_indigo_rendering_settings() # Load molecule - if not self.has_r_group(smiles): - molecule = indigo.loadMolecule(smiles) - else: - mol_str = self.smiles_to_mol_str(smiles) - molecule = indigo.loadMolecule(mol_str) + try: + if not self.has_r_group(smiles): + molecule = indigo.loadMolecule(smiles) + else: + mol_str = self.smiles_to_mol_str(smiles) + molecule = indigo.loadMolecule(mol_str) + except IndigoException: + return None # Kekulize in 67% of cases if not self.random_choice( [True, True, False], log_attribute="indigo_kekulized" diff --git a/examples/generate_fingerprint_based_dataset_with_and_without_augmentations.py b/examples/generate_fingerprint_based_dataset_with_and_without_augmentations.py index 85246b2..baaeb17 100644 --- a/examples/generate_fingerprint_based_dataset_with_and_without_augmentations.py +++ b/examples/generate_fingerprint_based_dataset_with_and_without_augmentations.py @@ -27,9 +27,10 @@ def batch_depict_save_with_and_without_aug_with_fingerprints( images_per_structure: int, output_dir: str, ID_list: List[str], - indigo_proportion: float = 0.2, - rdkit_proportion: float = 0.3, - cdk_proportion: float = 0.5, + indigo_proportion: float = 0.15, + rdkit_proportion: float = 0.25, + pikachu_proportion: float = 0.25, + cdk_proportion: float = 0.35, shape: Tuple[int, int] = (299, 299), processes: int = 4, seed: int = 42, @@ -51,7 +52,7 @@ def batch_depict_save_with_and_without_aug_with_fingerprints( indigo_proportion (float): Indigo proportion. Defaults to 0.15. rdkit_proportion (float): RDKit proportion. Defaults to 0.3. cdk_proportion (float): CDK proportion. Defaults to 0.55. - shape (Tuple[int, int]): [description]. Defaults to (299, 299). + shape (Tuple[int, int]): image shape. Defaults to (299, 299). processes (int, optional): Number of threads. Defaults to 4. """ # Duplicate elements in smiles_list images_per_structure times @@ -60,15 +61,21 @@ def batch_depict_save_with_and_without_aug_with_fingerprints( # Generate corresponding amount of fingerprints dataset_size = len(smiles_list) FR = DepictionFeatureRanges() - fingerprint_tuples = FR.generate_fingerprints_for_dataset( - dataset_size, - indigo_proportion, - rdkit_proportion, - cdk_proportion, - aug_proportion=1, - ) - with open("fingerprint_tuples.pkl", "wb") as fingerprint_file: - pickle.dump(fingerprint_tuples, fingerprint_file) + if "fingerprint_tuples.pkl" not in os.listdir(output_dir): + fingerprint_tuples = FR.generate_fingerprints_for_dataset( + dataset_size, + indigo_proportion, + rdkit_proportion, + pikachu_proportion, + cdk_proportion, + aug_proportion=1, + ) + with open(os.path.join(output_dir, "fingerprint_tuples.pkl"), "wb") as fingerprint_file: + pickle.dump(fingerprint_tuples, fingerprint_file) + else: + with open(os.path.join(output_dir, "fingerprint_tuples.pkl"), "rb") as fingerprint_file: + fingerprint_tuples = pickle.load(fingerprint_file) + starmap_tuple_generator = ( ( @@ -124,21 +131,17 @@ def depict_save_from_fingerprint_with_and_without_aug( Returns: np.array: Chemical structure depiction """ + self.output_dir = output_dir # Generate chemical structure depiction - try: - depiction, augmented_depiction = self.depict_from_fingerprint( - smiles, fingerprints, schemes, shape, seed) - # Save at given_path: - output_file_path = os.path.join(output_dir, filename + ".png") - sk_io.imsave(output_file_path, img_as_ubyte(depiction)) - # Save at given_path: - output_file_path = os.path.join(output_dir, filename + "_aug.png") - sk_io.imsave(output_file_path, img_as_ubyte(augmented_depiction)) - except IndexError: - with open("error_log.txt", "a") as error_log: - error_message = f"Could not depict SMILES {smiles} due to IndexError.\n" - error_log.write(error_message) + depiction, augmented_depiction = self.depict_from_fingerprint( + smiles, fingerprints, schemes, shape, seed) + # Save at given_path: + output_file_path = os.path.join(output_dir, filename + ".png") + sk_io.imsave(output_file_path, img_as_ubyte(depiction)) + # Save at given_path: + output_file_path = os.path.join(output_dir, filename + "_aug.png") + sk_io.imsave(output_file_path, img_as_ubyte(augmented_depiction)) def depict_from_fingerprint( self, @@ -174,12 +177,17 @@ def depict_from_fingerprint( self.active_fingerprint = fingerprints[0] self.active_scheme = schemes[0] # Depict molecule - if "indigo" in list(schemes[0].keys())[0]: - depiction = depictor.depict_and_resize_indigo(smiles, shape) - elif "rdkit" in list(schemes[0].keys())[0]: - depiction = depictor.depict_and_resize_rdkit(smiles, shape) - elif "cdk" in list(schemes[0].keys())[0]: - depiction = depictor.depict_and_resize_cdk(smiles, shape) + try: + if "indigo" in list(schemes[0].keys())[0]: + depiction = depictor.depict_and_resize_indigo(smiles, shape) + elif "rdkit" in list(schemes[0].keys())[0]: + depiction = depictor.depict_and_resize_rdkit(smiles, shape) + elif "cdk" in list(schemes[0].keys())[0]: + depiction = depictor.depict_and_resize_cdk(smiles, shape) + elif "pikachu" in list(schemes[0].keys())[0]: + depiction = depictor.depict_and_resize_pikachu(smiles, shape) + except IndexError: + depiction = None if depiction is False or depiction is None: # For the rare case: Use CDK @@ -189,7 +197,7 @@ def depict_from_fingerprint( False, ) depiction = depictor.depict_and_resize_cdk(smiles, shape) - with open('error_log.txt', 'a') as error_log: + with open(os.path.join(self.output_dir, 'error_log.txt'), 'a') as error_log: error_log.write(f'Failed depicting SMILES: {smiles}\n') error_log.write('It was depicted using CDK WITHOUT fingerprints.\n') # Add augmentations @@ -213,12 +221,13 @@ def main(): """ input_file_path = sys.argv[1] output_path = sys.argv[2] + shape = (int(sys.argv[3]), int(sys.argv[3])) # Read input file with open(input_file_path, 'r') as input_file: ids: List = [] smiles: List = [] for line in input_file.readlines()[:]: - id, smi = line[:-1].split(',') + id, smi = line[:-1].split('\t') ids.append(id) smiles.append(smi) # Generate balanced dataset of non-augmented and augmented depictions @@ -228,13 +237,14 @@ def main(): images_per_structure=1, output_dir=output_path, ID_list=ids, - processes=15, + processes=8, seed=42, + shape=shape, ) if __name__ == '__main__': - if len(sys.argv) == 3: + if len(sys.argv) == 4: main() else: - print(f'Usage: {sys.argv[0]} ID_SMILES_dataset output_dir') + print(f'Usage: {sys.argv[0]} ID_SMILES_dataset output_dir im_size')