From 5a2c126c31f23e958feb9d168ac9d5a6e4f47cdc Mon Sep 17 00:00:00 2001 From: Otto Brinkhaus Date: Thu, 29 Jun 2023 16:44:34 +0200 Subject: [PATCH 1/2] add depiction of explicite hydrogen atoms --- RanDepict/__init__.py | 2 +- RanDepict/cdk_functionalities.py | 77 +- RanDepict/randepict.py | 19 +- .../random_markush_structure_generator.py | 56 +- Tests/test_functions.py | 27 +- docs/tutorial.ipynb | 4850 ++++++++++------- examples/RanDepictNotebook.ipynb | 4850 ++++++++++------- setup.py | 2 +- 8 files changed, 5761 insertions(+), 4122 deletions(-) diff --git a/RanDepict/__init__.py b/RanDepict/__init__.py index 138d2ba..a572a05 100644 --- a/RanDepict/__init__.py +++ b/RanDepict/__init__.py @@ -21,7 +21,7 @@ """ -__version__ = "1.1.7" +__version__ = "1.1.8" __all__ = [ "RanDepict", diff --git a/RanDepict/cdk_functionalities.py b/RanDepict/cdk_functionalities.py index df61ede..91e1fb0 100644 --- a/RanDepict/cdk_functionalities.py +++ b/RanDepict/cdk_functionalities.py @@ -59,20 +59,29 @@ def cdk_depict( depiction = self._cdk_render_molecule(molecule, has_R_group, shape) return depiction - def _cdk_mol_block_to_cxsmiles(self, mol_block: str) -> str: + def _cdk_mol_block_to_cxsmiles( + self, + mol_block: str, + ignore_explicite_hydrogens: bool = True, + ) -> str: """ This function takes a mol block str and returns the corresponding CXSMILES with coordinates using the CDK. Args: mol_block (str): mol block str + ignore_explicite_hydrogens (bool, optional): whether or not to ignore H Returns: str: CXSMILES """ atom_container = self._cdk_mol_block_to_iatomcontainer(mol_block) - smi_gen = JClass("org.openscience.cdk.smiles.SmilesGenerator") - flavor = JClass("org.openscience.cdk.smiles.SmiFlavor") + if ignore_explicite_hydrogens: + cdk_base = "org.openscience.cdk." + manipulator = JClass(cdk_base + "tools.manipulator.AtomContainerManipulator") + atom_container = manipulator.copyAndSuppressedHydrogens(atom_container) + smi_gen = JClass(cdk_base + "smiles.SmilesGenerator") + flavor = JClass(cdk_base + "smiles.SmiFlavor") smi_gen = smi_gen(flavor.CxSmilesWithCoords) cxsmiles = smi_gen.create(atom_container) return cxsmiles @@ -134,6 +143,68 @@ def _cdk_iatomcontainer_to_mol_block(self, i_atom_container) -> str: mol_str = string_writer.toString() return str(mol_str) + def _cdk_add_explicite_hydrogen_to_molblock(self, mol_block: str) -> str: + """ + This function takes a mol block and returns the mol block with explicit + hydrogen atoms. + + Args: + mol_block (str): mol block that describes a molecule + + Returns: + str: The same mol block with explicit hydrogen atoms + """ + i_atom_container = self._cdk_mol_block_to_iatomcontainer(mol_block) + cdk_base = "org.openscience.cdk." + manipulator = JClass(cdk_base + "tools.manipulator.AtomContainerManipulator") + manipulator.convertImplicitToExplicitHydrogens(i_atom_container) + mol_block = self._cdk_iatomcontainer_to_mol_block(i_atom_container) + return mol_block + + def _cdk_add_explicite_hydrogen_to_smiles(self, smiles: str) -> str: + """ + This function takes a SMILES str and uses CDK to add explicite hydrogen atoms. + It returns an adapted version of the SMILES str. + + Args: + smiles (str): SMILES representation of a molecule + + Returns: + smiles (str): SMILES representation of a molecule with explicite H + """ + i_atom_container = self._cdk_smiles_to_IAtomContainer(smiles) + cdk_base = "org.openscience.cdk." + manipulator = JClass(cdk_base + "tools.manipulator.AtomContainerManipulator") + manipulator.convertImplicitToExplicitHydrogens(i_atom_container) + smi_flavor = JClass("org.openscience.cdk.smiles.SmiFlavor").Absolute + smiles_generator = JClass("org.openscience.cdk.smiles.SmilesGenerator")( + smi_flavor + ) + smiles = smiles_generator.create(i_atom_container) + return str(smiles) + + def _cdk_remove_explicite_hydrogen_from_smiles(self, smiles: str) -> str: + """ + This function takes a SMILES str and uses CDK to remove explicite hydrogen atoms. + It returns an adapted version of the SMILES str. + + Args: + smiles (str): SMILES representation of a molecule + + Returns: + smiles (str): SMILES representation of a molecule with explicite H + """ + i_atom_container = self._cdk_smiles_to_IAtomContainer(smiles) + cdk_base = "org.openscience.cdk." + manipulator = JClass(cdk_base + "tools.manipulator.AtomContainerManipulator") + i_atom_container = manipulator.copyAndSuppressedHydrogens(i_atom_container) + smi_flavor = JClass("org.openscience.cdk.smiles.SmiFlavor").Absolute + smiles_generator = JClass("org.openscience.cdk.smiles.SmilesGenerator")( + smi_flavor + ) + smiles = smiles_generator.create(i_atom_container) + return str(smiles) + def _cdk_get_depiction_generator(self, molecule, has_R_group: bool = False): """ This function defines random rendering options for the structure diff --git a/RanDepict/randepict.py b/RanDepict/randepict.py index 3c5657f..58eb1a0 100644 --- a/RanDepict/randepict.py +++ b/RanDepict/randepict.py @@ -195,8 +195,13 @@ def random_depiction( Returns: np.array: Chemical structure depiction """ + orig_styles = self._config.styles + # TODO: add this to depiction feature fingerprint + if self.random_choice([True] + [False] * 5): + smiles = self._cdk_add_explicite_hydrogen_to_smiles(smiles) + self._config.styles = [style for style in orig_styles if style != 'pikachu'] depiction_functions = self.get_depiction_functions(smiles) - + self._config.styles = orig_styles for _ in range(3): if len(depiction_functions) != 0: # Pick random depiction function and call it @@ -267,11 +272,17 @@ def random_depiction_with_coordinates( orig_styles = self._config.styles self._config.styles = [style for style in orig_styles if style != 'pikachu'] depiction_functions = self.get_depiction_functions(smiles) + fun = self.random_choice(depiction_functions) self._config.styles = orig_styles + # TODO: add this to depiction feature fingerprint + if self.random_choice([True] + [False] * 5): + smiles = self._cdk_add_explicite_hydrogen_to_smiles(smiles) mol_block = self._smiles_to_mol_block(smiles, - self.random_choice(['rdkit', 'indigo', 'cdk'])) - cxsmiles = self._cdk_mol_block_to_cxsmiles(mol_block) - fun = self.random_choice(depiction_functions) + self.random_choice(['rdkit', + 'indigo', + 'cdk'])) + cxsmiles = self._cdk_mol_block_to_cxsmiles(mol_block, + ignore_explicite_hydrogens=True) depiction = fun(mol_block=mol_block, shape=shape) if augment: depiction = self.add_augmentations(depiction) diff --git a/RanDepict/random_markush_structure_generator.py b/RanDepict/random_markush_structure_generator.py index c710a16..e749997 100644 --- a/RanDepict/random_markush_structure_generator.py +++ b/RanDepict/random_markush_structure_generator.py @@ -1,5 +1,3 @@ -from jpype import JClass -# import sys from typing import List from .randepict import RandomDepictor @@ -49,7 +47,7 @@ def insert_R_group_var(self, smiles: str, num: int) -> str: Returns: smiles (str): input SMILES with $num inserted R group variables """ - smiles = self.add_explicite_hydrogen_to_smiles(smiles) + smiles = self.depictor._cdk_add_explicite_hydrogen_to_smiles(smiles) potential_replacement_positions = self.get_valid_replacement_positions(smiles) r_groups = [] # Replace C or H in SMILES with * @@ -66,7 +64,7 @@ def insert_R_group_var(self, smiles: str, num: int) -> str: break # Remove explicite hydrogen again and get absolute SMILES smiles = "".join(smiles) - smiles = self.remove_explicite_hydrogen_from_smiles(smiles) + smiles = self.depictor._cdk_remove_explicite_hydrogen_from_smiles(smiles) # Replace * with R groups for r_group in r_groups: smiles = smiles.replace("*", r_group, 1) @@ -136,53 +134,3 @@ def get_valid_replacement_positions(self, smiles: str) -> List[int]: ]: replacement_positions.append(index - 1) return replacement_positions - - def add_explicite_hydrogen_to_smiles(self, smiles: str) -> str: - """ - This function takes a SMILES str and uses CDK to add explicite hydrogen atoms. - It returns an adapted version of the SMILES str. - - Args: - smiles (str): SMILES representation of a molecule - - Returns: - smiles (str): SMILES representation of a molecule with explicite H - """ - i_atom_container = self.depictor._cdk_smiles_to_IAtomContainer(smiles) - - # Add explicite hydrogen atoms - cdk_base = "org.openscience.cdk." - manipulator = JClass(cdk_base + "tools.manipulator.AtomContainerManipulator") - manipulator.convertImplicitToExplicitHydrogens(i_atom_container) - - # Create absolute SMILES - smi_flavor = JClass("org.openscience.cdk.smiles.SmiFlavor").Absolute - smiles_generator = JClass("org.openscience.cdk.smiles.SmilesGenerator")( - smi_flavor - ) - smiles = smiles_generator.create(i_atom_container) - return str(smiles) - - def remove_explicite_hydrogen_from_smiles(self, smiles: str) -> str: - """ - This function takes a SMILES str and uses CDK to remove explicite hydrogen atoms. - It returns an adapted version of the SMILES str. - - Args: - smiles (str): SMILES representation of a molecule - - Returns: - smiles (str): SMILES representation of a molecule with explicite H - """ - i_atom_container = self.depictor._cdk_smiles_to_IAtomContainer(smiles) - # Remove explicite hydrogen atoms - cdk_base = "org.openscience.cdk." - manipulator = JClass(cdk_base + "tools.manipulator.AtomContainerManipulator") - i_atom_container = manipulator.copyAndSuppressedHydrogens(i_atom_container) - # Create absolute SMILES - smi_flavor = JClass("org.openscience.cdk.smiles.SmiFlavor").Absolute - smiles_generator = JClass("org.openscience.cdk.smiles.SmilesGenerator")( - smi_flavor - ) - smiles = smiles_generator.create(i_atom_container) - return str(smiles) diff --git a/Tests/test_functions.py b/Tests/test_functions.py index 50203fa..c055923 100644 --- a/Tests/test_functions.py +++ b/Tests/test_functions.py @@ -412,6 +412,20 @@ def test_get_depiction_functions_normal(self): difference = set(observed) ^ set(expected) assert not difference + def test_add_explicite_hydrogen_to_smiles(self): + # Assert that hydrogen atoms are added + input_smiles = "CCC" + expected_output = "C([H])([H])([H])C([H])([H])C([H])([H])[H]" + observed_output = self.depictor._cdk_add_explicite_hydrogen_to_smiles(input_smiles) + assert expected_output == observed_output + + def test_remove_explicite_hydrogen_to_smiles(self): + # Assert that hydrogen atoms are removed + input_smiles = "C([H])([H])([H])C([H])([H])C([H])([H])[H]" + expected_output = "CCC" + observed_output = self.depictor._cdk_remove_explicite_hydrogen_from_smiles(input_smiles) + assert expected_output == observed_output + def test_get_depiction_functions_isotopes(self): # PIKAChU can't handle isotopes observed = self.depictor.get_depiction_functions("[13CH3]N1C=NC2=C1C(=O)N(C(=O)N2C)C") @@ -516,19 +530,6 @@ def test_insert_R_group_var_can_be_depicted(self): depiction = self.depictor.random_depiction(output_smiles) assert type(depiction) == np.ndarray - def test_add_explicite_hydrogen_to_smiles(self): - # Assert that hydrogen atoms are added - input_smiles = "CCC" - expected_output = "C([H])([H])([H])C([H])([H])C([H])([H])[H]" - observed_output = self.markush_creator.add_explicite_hydrogen_to_smiles(input_smiles) - assert expected_output == observed_output - - def test_remove_explicite_hydrogen_to_smiles(self): - # Assert that hydrogen atoms are removed - input_smiles = "C([H])([H])([H])C([H])([H])C([H])([H])[H]" - expected_output = "CCC" - observed_output = self.markush_creator.remove_explicite_hydrogen_from_smiles(input_smiles) - assert expected_output == observed_output def test_get_valid_replacement_positions_simple_chain(self): # Simple example case diff --git a/docs/tutorial.ipynb b/docs/tutorial.ipynb index 0fe35f4..a8a3cb6 100644 --- a/docs/tutorial.ipynb +++ b/docs/tutorial.ipynb @@ -10,7 +10,7 @@ { "data": { "text/plain": [ - "'1.1.7'" + "'1.1.8'" ] }, "execution_count": 1, @@ -59,14 +59,14 @@ "text/html": [ "\n", " \n", "
\n", - " \n", - " \n", - " \n", + "
\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + " \n", + "
\n", + "
\n", + "
\n", + "

0

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

1

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

2

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

3

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

4

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

5

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

6

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

7

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

8

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

9

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

10

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

11

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

12

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

13

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

14

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

15

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

16

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

17

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

18

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

19

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "smiles = \"CN1C=NC2=C1C(=O)N(C(=O)N2C)CC(=O)O\"\n", + "with RandomDepictor(hand_drawn=True) as depictor:\n", + " random_augmented_images = []\n", + " for _ in range(20):\n", + " random_augmented_images.append(depictor(smiles))\n", + " \n", + "\n", + "ipyplot.plot_images(random_augmented_images, max_images=20, img_width=100)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## " + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Create and save a batch of images\n", + "\n", + "After calling an instance of RandomDepictor, simply call the method depict_save().\n", + "\n", + "Args:\n", + "\n", + "- smiles_list (List[str]): List of SMILES str\n", + "- images_per_structure (int): Amount of images to create per SMILES str\n", + "- output_dir (str): Output directory \n", + "- augment (bool): Boolean that indicates whether or not to use augmentations\n", + "- ID_list (List[str]): List of IDs (should be as long as smiles_list)\n", + "- shape (Tuple[int, int], optional): image shape. Defaults to (299, 299).\n", + "- processes (int, optional): Number of parallel threads. Defaults to 4.\n", + "- seed (int, optional): Seed for pseudo-random decisions. Defaults to 42." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "# Make sure the output directories exist\n", + "if not os.path.exists('not_augmented'):\n", + " os.mkdir('not_augmented')\n", " \n", + "if not os.path.exists('augmented'):\n", + " os.mkdir('augmented')\n", "\n", - "ipyplot.plot_images(random_images, max_images=20, img_width=100)" + "# Depict and save two batches of images\n", + "smiles = \"CN1C=NC2=C1C(=O)N(C(=O)N2C)CC(=O)O\"\n", + "with RandomDepictor(42) as depictor:\n", + " depictor.batch_depict_save([smiles], 20, 'not_augmented', False, ['caffeine'], (299, 299), 5)\n", + " depictor.batch_depict_save([smiles], 20, 'augmented', True, ['caffeine'], (299, 299), 5)\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Warning! Rogue electron.\n", + "R1_0\n", + "Warning! Rogue electron.\n", + "R_13\n", + "Warning! Rogue electron.\n", + "R1_0\n", + "Warning! Rogue electron.\n", + "R_13\n", + "Warning! Rogue electron.\n", + "R1_0\n", + "Warning! Rogue electron.\n", + "R_13\n", + "Warning! Rogue electron.\n", + "R1_0\n", + "Warning! Rogue electron.\n", + "R_13\n", + "Warning! Rogue electron.\n", + "R1_0\n", + "Warning! Rogue electron.\n", + "R_13\n", + "Warning! Rogue electron.\n", + "R1_0\n", + "Warning! Rogue electron.\n", + "R_13\n", + "Warning! Rogue electron.\n", + "R1_0\n", + "Warning! Rogue electron.\n", + "R_13\n", + "Warning! Rogue electron.\n", + "R1_0\n", + "Warning! Rogue electron.\n", + "R_13\n" + ] + } + ], + "source": [ + "if not os.path.exists('kohulan'):\n", + " os.mkdir(\"kohulan\")\n", + "smiles = \"CN1C=NC2=C1C(=O)N(C(=O)N2C)C\"\n", + "r_smiles = \"[R1]N1C=NC2=C1[X](=O)N(C(=O)N2C)[R]\"\n", + "seed = 233\n", + "r_seed = 1\n", + "with RandomDepictor(1) as depictor:\n", + " depictor.depict_save(smiles, 1, 'kohulan', False, 'caffeine_299_299', (299, 299), seed=seed)\n", + " depictor.depict_save(smiles, 1, 'kohulan', True, 'caffeine_aug_299_299', (299, 299), seed=seed)\n", + " depictor.depict_save(smiles, 1, 'kohulan', False, 'caffeine_512_512', (512, 512), seed=seed)\n", + " depictor.depict_save(smiles, 1, 'kohulan', True, 'caffeine_aug_512_512', (512, 512), seed=seed)\n", + " depictor.depict_save(r_smiles, 1, 'kohulan', False, 'caffeine_R_299_299', (299, 299), seed=r_seed)\n", + " depictor.depict_save(r_smiles, 1, 'kohulan', True, 'caffeine_R_aug_299_299', (299, 299), seed=r_seed)\n", + " depictor.depict_save(r_smiles, 1, 'kohulan', False, 'caffeine_R_512_512', (512, 512), seed=r_seed)\n", + " depictor.depict_save(r_smiles, 1, 'kohulan', True, 'caffeine_R_aug_512_512', (512, 512), seed=r_seed)" ] }, { @@ -3609,29 +4489,53 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Create random hand-drawn like depictions\n", + "## Create a batch of images while ensuring diversity using feature fingerprints\n", "\n", - "After calling an instance of RandomDepictor, this instance can simply be called as a function in order to generate a chemical structure depiction using CDK, RDKit, Indigo or PIKAChU (randomly chosen) and apply random augmentations and random background addition." + "\n", + "After calling an instance of RandomDepictor, simply call the method batch_depict_with_fingerprints().\n", + "\n", + "Args:\n", + "\n", + "- smiles_list: List[str]\n", + "- images_per_structure: int\n", + "- indigo_proportion: float = 0.15\n", + "- rdkit_proportion: float = 0.25\n", + "- pikachu_proportion: float = 0.25\n", + "- cdk_proportion: float = 0.35\n", + "- aug_proportion: float = 0.5\n", + "- shape: Tuple[int, int] = (299, 299)\n", + "- processes: int = 4\n", + "- seed: int = 42\n", + "\n", + "* Note: Have a look at examples/generate_depiction_grids_with_fingerprints.py to see how this function was used to generate the grid figures from our publication." ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 9, "metadata": {}, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "c:\\Users\\Otto Brinkhaus\\anaconda3\\envs\\RanDepict\\lib\\site-packages\\ipyplot\\_utils.py:97: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray.\n", + " return np.asarray(seq)\n" + ] + }, { "data": { "text/html": [ "\n", " \n", "
\n", - " \n", - " \n", - " \n", + "
\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + " \n", + "
\n", + "
\n", + "
\n", + "

0

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

1

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

2

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

3

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

4

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

5

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

6

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

7

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

8

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

9

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

10

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

11

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

12

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

13

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

14

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

15

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

16

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

17

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

18

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "

19

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "smiles = \"CN1C=NC2=C1C(=O)N(C(=O)N2C)CC(=O)O\"\n", + "with RandomDepictor(hand_drawn=True) as depictor:\n", + " random_augmented_images = []\n", + " for _ in range(20):\n", + " random_augmented_images.append(depictor(smiles))\n", + " \n", + "\n", + "ipyplot.plot_images(random_augmented_images, max_images=20, img_width=100)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## " + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Create and save a batch of images\n", + "\n", + "After calling an instance of RandomDepictor, simply call the method depict_save().\n", + "\n", + "Args:\n", + "\n", + "- smiles_list (List[str]): List of SMILES str\n", + "- images_per_structure (int): Amount of images to create per SMILES str\n", + "- output_dir (str): Output directory \n", + "- augment (bool): Boolean that indicates whether or not to use augmentations\n", + "- ID_list (List[str]): List of IDs (should be as long as smiles_list)\n", + "- shape (Tuple[int, int], optional): image shape. Defaults to (299, 299).\n", + "- processes (int, optional): Number of parallel threads. Defaults to 4.\n", + "- seed (int, optional): Seed for pseudo-random decisions. Defaults to 42." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "# Make sure the output directories exist\n", + "if not os.path.exists('not_augmented'):\n", + " os.mkdir('not_augmented')\n", " \n", + "if not os.path.exists('augmented'):\n", + " os.mkdir('augmented')\n", "\n", - "ipyplot.plot_images(random_images, max_images=20, img_width=100)" + "# Depict and save two batches of images\n", + "smiles = \"CN1C=NC2=C1C(=O)N(C(=O)N2C)CC(=O)O\"\n", + "with RandomDepictor(42) as depictor:\n", + " depictor.batch_depict_save([smiles], 20, 'not_augmented', False, ['caffeine'], (299, 299), 5)\n", + " depictor.batch_depict_save([smiles], 20, 'augmented', True, ['caffeine'], (299, 299), 5)\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Warning! Rogue electron.\n", + "R1_0\n", + "Warning! Rogue electron.\n", + "R_13\n", + "Warning! Rogue electron.\n", + "R1_0\n", + "Warning! Rogue electron.\n", + "R_13\n", + "Warning! Rogue electron.\n", + "R1_0\n", + "Warning! Rogue electron.\n", + "R_13\n", + "Warning! Rogue electron.\n", + "R1_0\n", + "Warning! Rogue electron.\n", + "R_13\n", + "Warning! Rogue electron.\n", + "R1_0\n", + "Warning! Rogue electron.\n", + "R_13\n", + "Warning! Rogue electron.\n", + "R1_0\n", + "Warning! Rogue electron.\n", + "R_13\n", + "Warning! Rogue electron.\n", + "R1_0\n", + "Warning! Rogue electron.\n", + "R_13\n", + "Warning! Rogue electron.\n", + "R1_0\n", + "Warning! Rogue electron.\n", + "R_13\n" + ] + } + ], + "source": [ + "if not os.path.exists('kohulan'):\n", + " os.mkdir(\"kohulan\")\n", + "smiles = \"CN1C=NC2=C1C(=O)N(C(=O)N2C)C\"\n", + "r_smiles = \"[R1]N1C=NC2=C1[X](=O)N(C(=O)N2C)[R]\"\n", + "seed = 233\n", + "r_seed = 1\n", + "with RandomDepictor(1) as depictor:\n", + " depictor.depict_save(smiles, 1, 'kohulan', False, 'caffeine_299_299', (299, 299), seed=seed)\n", + " depictor.depict_save(smiles, 1, 'kohulan', True, 'caffeine_aug_299_299', (299, 299), seed=seed)\n", + " depictor.depict_save(smiles, 1, 'kohulan', False, 'caffeine_512_512', (512, 512), seed=seed)\n", + " depictor.depict_save(smiles, 1, 'kohulan', True, 'caffeine_aug_512_512', (512, 512), seed=seed)\n", + " depictor.depict_save(r_smiles, 1, 'kohulan', False, 'caffeine_R_299_299', (299, 299), seed=r_seed)\n", + " depictor.depict_save(r_smiles, 1, 'kohulan', True, 'caffeine_R_aug_299_299', (299, 299), seed=r_seed)\n", + " depictor.depict_save(r_smiles, 1, 'kohulan', False, 'caffeine_R_512_512', (512, 512), seed=r_seed)\n", + " depictor.depict_save(r_smiles, 1, 'kohulan', True, 'caffeine_R_aug_512_512', (512, 512), seed=r_seed)" ] }, { @@ -3609,29 +4489,53 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Create random hand-drawn like depictions\n", + "## Create a batch of images while ensuring diversity using feature fingerprints\n", "\n", - "After calling an instance of RandomDepictor, this instance can simply be called as a function in order to generate a chemical structure depiction using CDK, RDKit, Indigo or PIKAChU (randomly chosen) and apply random augmentations and random background addition." + "\n", + "After calling an instance of RandomDepictor, simply call the method batch_depict_with_fingerprints().\n", + "\n", + "Args:\n", + "\n", + "- smiles_list: List[str]\n", + "- images_per_structure: int\n", + "- indigo_proportion: float = 0.15\n", + "- rdkit_proportion: float = 0.25\n", + "- pikachu_proportion: float = 0.25\n", + "- cdk_proportion: float = 0.35\n", + "- aug_proportion: float = 0.5\n", + "- shape: Tuple[int, int] = (299, 299)\n", + "- processes: int = 4\n", + "- seed: int = 42\n", + "\n", + "* Note: Have a look at examples/generate_depiction_grids_with_fingerprints.py to see how this function was used to generate the grid figures from our publication." ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 9, "metadata": {}, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "c:\\Users\\Otto Brinkhaus\\anaconda3\\envs\\RanDepict\\lib\\site-packages\\ipyplot\\_utils.py:97: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray.\n", + " return np.asarray(seq)\n" + ] + }, { "data": { "text/html": [ "\n", " \n", "
\n", - " \n", - " \n", - "