Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/OBrink/RanDepict into main
Browse files Browse the repository at this point in the history
  • Loading branch information
OBrink committed Nov 30, 2022
2 parents 5da90e2 + f0ac416 commit cd29766
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 42 deletions.
14 changes: 9 additions & 5 deletions RanDepict/randepict.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from dataclasses import dataclass, field

from indigo import Indigo
from indigo import IndigoException
from indigo.renderer import IndigoRenderer
from jpype import startJVM, getDefaultJVMPath
from jpype import JClass, JVMNotFoundException, isJVMStarted
Expand Down Expand Up @@ -868,11 +869,14 @@ def depict_and_resize_indigo(
# Instantiate Indigo with random settings and IndigoRenderer
indigo, renderer = self.get_random_indigo_rendering_settings()
# Load molecule
if not self.has_r_group(smiles):
molecule = indigo.loadMolecule(smiles)
else:
mol_str = self.smiles_to_mol_str(smiles)
molecule = indigo.loadMolecule(mol_str)
try:
if not self.has_r_group(smiles):
molecule = indigo.loadMolecule(smiles)
else:
mol_str = self.smiles_to_mol_str(smiles)
molecule = indigo.loadMolecule(mol_str)
except IndigoException:
return None
# Kekulize in 67% of cases
if not self.random_choice(
[True, True, False], log_attribute="indigo_kekulized"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,10 @@ def batch_depict_save_with_and_without_aug_with_fingerprints(
images_per_structure: int,
output_dir: str,
ID_list: List[str],
indigo_proportion: float = 0.2,
rdkit_proportion: float = 0.3,
cdk_proportion: float = 0.5,
indigo_proportion: float = 0.15,
rdkit_proportion: float = 0.25,
pikachu_proportion: float = 0.25,
cdk_proportion: float = 0.35,
shape: Tuple[int, int] = (299, 299),
processes: int = 4,
seed: int = 42,
Expand All @@ -51,7 +52,7 @@ def batch_depict_save_with_and_without_aug_with_fingerprints(
indigo_proportion (float): Indigo proportion. Defaults to 0.15.
rdkit_proportion (float): RDKit proportion. Defaults to 0.25.
pikachu_proportion (float): PIKAChU proportion. Defaults to 0.25.
cdk_proportion (float): CDK proportion. Defaults to 0.35.
shape (Tuple[int, int]): [description]. Defaults to (299, 299).
shape (Tuple[int, int]): image shape. Defaults to (299, 299).
processes (int, optional): Number of threads. Defaults to 4.
"""
# Duplicate elements in smiles_list images_per_structure times
Expand All @@ -60,15 +61,21 @@ def batch_depict_save_with_and_without_aug_with_fingerprints(
# Generate corresponding amount of fingerprints
dataset_size = len(smiles_list)
FR = DepictionFeatureRanges()
fingerprint_tuples = FR.generate_fingerprints_for_dataset(
dataset_size,
indigo_proportion,
rdkit_proportion,
cdk_proportion,
aug_proportion=1,
)
with open("fingerprint_tuples.pkl", "wb") as fingerprint_file:
pickle.dump(fingerprint_tuples, fingerprint_file)
if "fingerprint_tuples.pkl" not in os.listdir(output_dir):
fingerprint_tuples = FR.generate_fingerprints_for_dataset(
dataset_size,
indigo_proportion,
rdkit_proportion,
pikachu_proportion,
cdk_proportion,
aug_proportion=1,
)
with open(os.path.join(output_dir, "fingerprint_tuples.pkl"), "wb") as fingerprint_file:
pickle.dump(fingerprint_tuples, fingerprint_file)
else:
with open(os.path.join(output_dir, "fingerprint_tuples.pkl"), "rb") as fingerprint_file:
fingerprint_tuples = pickle.load(fingerprint_file)


starmap_tuple_generator = (
(
Expand Down Expand Up @@ -124,21 +131,17 @@ def depict_save_from_fingerprint_with_and_without_aug(
Returns:
np.array: Chemical structure depiction
"""
self.output_dir = output_dir

# Generate chemical structure depiction
try:
depiction, augmented_depiction = self.depict_from_fingerprint(
smiles, fingerprints, schemes, shape, seed)
# Save at given_path:
output_file_path = os.path.join(output_dir, filename + ".png")
sk_io.imsave(output_file_path, img_as_ubyte(depiction))
# Save at given_path:
output_file_path = os.path.join(output_dir, filename + "_aug.png")
sk_io.imsave(output_file_path, img_as_ubyte(augmented_depiction))
except IndexError:
with open("error_log.txt", "a") as error_log:
error_message = f"Could not depict SMILES {smiles} due to IndexError.\n"
error_log.write(error_message)
depiction, augmented_depiction = self.depict_from_fingerprint(
smiles, fingerprints, schemes, shape, seed)
# Save at given_path:
output_file_path = os.path.join(output_dir, filename + ".png")
sk_io.imsave(output_file_path, img_as_ubyte(depiction))
# Save at given_path:
output_file_path = os.path.join(output_dir, filename + "_aug.png")
sk_io.imsave(output_file_path, img_as_ubyte(augmented_depiction))

def depict_from_fingerprint(
self,
Expand Down Expand Up @@ -174,12 +177,17 @@ def depict_from_fingerprint(
self.active_fingerprint = fingerprints[0]
self.active_scheme = schemes[0]
# Depict molecule
if "indigo" in list(schemes[0].keys())[0]:
depiction = depictor.depict_and_resize_indigo(smiles, shape)
elif "rdkit" in list(schemes[0].keys())[0]:
depiction = depictor.depict_and_resize_rdkit(smiles, shape)
elif "cdk" in list(schemes[0].keys())[0]:
depiction = depictor.depict_and_resize_cdk(smiles, shape)
try:
if "indigo" in list(schemes[0].keys())[0]:
depiction = depictor.depict_and_resize_indigo(smiles, shape)
elif "rdkit" in list(schemes[0].keys())[0]:
depiction = depictor.depict_and_resize_rdkit(smiles, shape)
elif "cdk" in list(schemes[0].keys())[0]:
depiction = depictor.depict_and_resize_cdk(smiles, shape)
elif "pikachu" in list(schemes[0].keys())[0]:
depiction = depictor.depict_and_resize_pikachu(smiles, shape)
except IndexError:
depiction = None

if depiction is False or depiction is None:
# For the rare case: Use CDK
Expand All @@ -189,7 +197,7 @@ def depict_from_fingerprint(
False,
)
depiction = depictor.depict_and_resize_cdk(smiles, shape)
with open('error_log.txt', 'a') as error_log:
with open(os.path.join(self.output_dir, 'error_log.txt'), 'a') as error_log:
error_log.write(f'Failed depicting SMILES: {smiles}\n')
error_log.write('It was depicted using CDK WITHOUT fingerprints.\n')
# Add augmentations
Expand All @@ -213,12 +221,13 @@ def main():
"""
input_file_path = sys.argv[1]
output_path = sys.argv[2]
shape = (int(sys.argv[3]), int(sys.argv[3]))
# Read input file
with open(input_file_path, 'r') as input_file:
ids: List = []
smiles: List = []
for line in input_file.readlines()[:]:
id, smi = line[:-1].split(',')
id, smi = line[:-1].split('\t')
ids.append(id)
smiles.append(smi)
# Generate balanced dataset of non-augmented and augmented depictions
Expand All @@ -228,13 +237,14 @@ def main():
images_per_structure=1,
output_dir=output_path,
ID_list=ids,
processes=15,
processes=8,
seed=42,
shape=shape,
)


if __name__ == '__main__':
if len(sys.argv) == 3:
if len(sys.argv) == 4:
main()
else:
print(f'Usage: {sys.argv[0]} ID_SMILES_dataset output_dir')
print(f'Usage: {sys.argv[0]} ID_SMILES_dataset output_dir im_size')

0 comments on commit cd29766

Please sign in to comment.