Skip to content

Commit

Permalink
Merge pull request #31 from tulay/tm_dev
Browse files Browse the repository at this point in the history
enable configuring depictor styles
  • Loading branch information
OBrink authored Oct 7, 2022
2 parents f3e5c08 + ae38756 commit 9dc5f7f
Show file tree
Hide file tree
Showing 7 changed files with 206 additions and 37 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ $ bash Miniconda3-latest-Linux-x86_64.sh
$ echo -e "channels:\n - conda-forge\n - nodefaults" > ~/.condarc
$ conda update conda
$ conda install conda-libmamba-solver
$ conda config --set experimental_solver libmamba
$ conda create --name RanDepict python=3.8
$ conda activate RanDepict
# pypi has rdkit so not necessary to install it using conda
Expand Down
14 changes: 7 additions & 7 deletions RanDepict/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@
chemical structure depictions (random depiction styles and image augmentations).
Typical usage example:
from RanDepict import RandomDepictor
smiles = "CN1C=NC2=C1C(=O)N(C(=O)N2C)C"
with RandomDepictor() as depictor:
image = depictor(smiles)
Example:
--------
>>> from RanDepict import RandomDepictor
>>> smiles = "CN1C=NC2=C1C(=O)N(C(=O)N2C)C"
>>> with RandomDepictor() as depictor:
>>> image = depictor(smiles)
Have a look in the RanDepictNotebook.ipynb for more examples.
Expand All @@ -28,4 +28,4 @@
]


from .randepict import RandomDepictor, DepictionFeatureRanges, RandomMarkushStructureCreator
from .randepict import RandomDepictor, RandomDepictorConfig, DepictionFeatureRanges, RandomMarkushStructureCreator
132 changes: 112 additions & 20 deletions RanDepict/randepict.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from __future__ import annotations
import copy

import os
import pathlib
from pathlib import Path
import numpy as np
import io
from skimage import io as sk_io
Expand All @@ -10,7 +13,7 @@
import imgaug.augmenters as iaa
import random
from copy import deepcopy
from typing import Tuple, List, Dict, Any, Callable
from typing import Optional, Tuple, List, Dict, Any, Callable
import re

from rdkit import Chem
Expand All @@ -22,6 +25,9 @@
from rdkit.SimDivFilters.rdSimDivPickers import MaxMinPicker
from itertools import product

from omegaconf import OmegaConf, DictConfig # configuration package
from dataclasses import dataclass, field

from indigo import Indigo
from indigo.renderer import IndigoRenderer
from jpype import startJVM, getDefaultJVMPath
Expand All @@ -34,6 +40,41 @@
from scipy.ndimage import gaussian_filter
from scipy.ndimage import map_coordinates

@dataclass
class RandomDepictorConfig:
"""
Examples
--------
>>> c1 = RandomDepictorConfig(seed=24, styles=["cdk", "indigo"])
>>> c1
RandomDepictorConfig(seed=24, hand_drawn=False, augment=True, styles=['cdk', 'indigo'])
>>> c2 = RandomDepictorConfig(styles=["cdk", "indigo", "pikachu", "rdkit"])
>>> c2
RandomDepictorConfig(seed=42, hand_drawn=False, augment=True, styles=['cdk', 'indigo', 'pikachu', 'rdkit'])
"""
seed: int = 42
hand_drawn: bool = False
augment: bool = True
# unions of containers are not supported yet
# https://github.com/omry/omegaconf/issues/144
# styles: Union[str, List[str]] = field(default_factory=lambda: ["cdk", "indigo", "pikachu", "rdkit"])
styles: List[str] = field(default_factory=lambda: ["cdk", "indigo", "pikachu", "rdkit"])

@classmethod
def from_config(cls, dict_config: Optional[DictConfig] = None) -> 'RandomDepictorConfig':
return OmegaConf.structured(cls(**dict_config))

def __post_init__(self):
# Ensure styles are always List[str] when "cdk, indigo" is passed
if isinstance(self.styles, str):
self.styles = [v.strip() for v in self.styles.split(",")]
if len(self.styles) == 0:
raise ValueError("Empty list of styles was supplied.")
# Not sure if this is the best way in order to not repeat the list of styles
ss = set(self.__dataclass_fields__['styles'].default_factory())
if any([s not in ss for s in self.styles]):
raise ValueError(f"Use only {', '.join(ss)}")


class RandomDepictor:
"""
Expand All @@ -43,12 +84,38 @@ class RandomDepictor:
the RGB image with the given chemical structure.
"""

def __init__(self, seed: int = 42, hand_drawn: bool = False):
def __init__(self, seed: Optional[int] = None, hand_drawn: Optional[bool] = None, *, config: RandomDepictorConfig = None):
"""
Load the JVM only once, load superatom list (OSRA),
set context for multiprocessing
Parameters
----------
seed : int
seed for random number generator
hand_drawn : bool
Whether to augment with hand drawn features
config : Path object to configuration file in yaml format.
RandomDepictor section is expected.
Returns
-------
"""
self.HERE = pathlib.Path(__file__).resolve().parent.joinpath("assets")

if config is None:
self._config = RandomDepictorConfig()
else:
self._config = copy.deepcopy(config)
if seed is not None:
self._config.seed = seed
if hand_drawn is not None:
self._config.hand_drawn = hand_drawn

self.seed = self._config.seed
self.hand_drawn = self._config.hand_drawn

self.HERE = Path(__file__).resolve().parent.joinpath("assets")

# Start the JVM to access Java classes
try:
Expand All @@ -67,8 +134,6 @@ def __init__(self, seed: int = 42, hand_drawn: bool = False):
self.jar_path = self.HERE.joinpath("jar_files/cdk-2.8.jar")
startJVM(self.jvmPath, "-ea", "-Djava.class.path=" + str(self.jar_path))

self.seed = seed
self.hand_drawn = hand_drawn
random.seed(self.seed)

# Load list of superatoms for label generation
Expand Down Expand Up @@ -97,6 +162,25 @@ def __init__(self, seed: int = 42, hand_drawn: bool = False):
except RuntimeError:
pass

@classmethod
def from_config(cls, config_file: Path) -> 'RandomDepictor':
try:
# TODO Needs documentation of config_file yaml format...
"""
# randepict.yaml
RandomDepictorConfig:
seed: 42
augment: False
styles:
- cdk
"""
config: RandomDepictorConfig = RandomDepictorConfig.from_config(OmegaConf.load(config_file)[RandomDepictorConfig.__name__])
except Exception as e:
print(f"Error loading from {config_file}. Make sure it has {cls.__name__} section. {e}")
print("Using default config.")
config = RandomDepictorConfig()
return RandomDepictor(config=config)

def __call__(
self,
smiles: str,
Expand All @@ -105,14 +189,16 @@ def __call__(
hand_drawn: bool = False,
):
# Depict structure with random parameters
# TODO hand_drawn to this call is ignored. Decide which one to keep
hand_drawn = self.hand_drawn
if hand_drawn:
depiction = self.random_depiction(smiles, shape)

# TODO is call to hand_drawn_augment missing?
else:
depiction = self.random_depiction(smiles, shape)
# Add augmentations
depiction = self.add_augmentations(depiction)
if self._config.augment:
depiction = self.add_augmentations(depiction)

if grayscale:
return self.to_grayscale_float_img(depiction)
Expand Down Expand Up @@ -1346,6 +1432,7 @@ def random_depiction(
"""
depiction_functions = self.get_depiction_functions(smiles)
# If nothing is returned, try different function
# FIXME: depictions_functions could be an empty list
for _ in range(3):
if len(depiction_functions) != 0:
# Pick random depiction function and call it
Expand Down Expand Up @@ -1400,12 +1487,15 @@ def get_depiction_functions(self, smiles: str) -> List[Callable]:
Returns:
List[Callable]: List of depiction functions
"""
depiction_functions = [
self.depict_and_resize_rdkit,
self.depict_and_resize_indigo,
self.depict_and_resize_cdk,
self.depict_and_resize_pikachu,
]

depiction_functions_registry = {
'rdkit': self.depict_and_resize_rdkit,
'indigo': self.depict_and_resize_indigo,
'cdk': self.depict_and_resize_cdk,
'pikachu': self.depict_and_resize_pikachu,
}
depiction_functions = [depiction_functions_registry[k] for k in self._config.styles]

# Remove PIKAChU if there is an isotope
if re.search("(\[\d\d\d?[A-Z])|(\[2H\])|(\[3H\])|(D)|(T)", smiles):
depiction_functions.remove(self.depict_and_resize_pikachu)
Expand All @@ -1418,11 +1508,13 @@ def get_depiction_functions(self, smiles: str) -> List[Callable]:
depiction_functions.remove(self.depict_and_resize_pikachu)
# "R", "X", "Z" are not depicted by RDKit
# The same is valid for X,Y,Z and a number
if re.search("\[[RXZ]\]|\[[XYZ]\d+", smiles):
depiction_functions.remove(self.depict_and_resize_rdkit)
if self.depict_and_resize_rdkit in depiction_functions:
if re.search("\[[RXZ]\]|\[[XYZ]\d+", smiles):
depiction_functions.remove(self.depict_and_resize_rdkit)
# "X", "R0" and indices above 32 are not depicted by Indigo
if re.search("\[R0\]|\[X\]|[4-9][0-9]+|3[3-9]", smiles):
depiction_functions.remove(self.depict_and_resize_indigo)
if self.depict_and_resize_indigo in depiction_functions:
if re.search("\[R0\]|\[X\]|[4-9][0-9]+|3[3-9]", smiles):
depiction_functions.remove(self.depict_and_resize_indigo)
return depiction_functions

def resize(self, image: np.array, shape: Tuple[int], HQ: bool = False) -> np.array:
Expand Down Expand Up @@ -3030,7 +3122,7 @@ def distribute_elements_evenly(


class RandomMarkushStructureCreator:
def __init__(self, *, variables_list=None, max_index=21):
def __init__(self, *, variables_list=None, max_index=20):
"""
RandomMarkushStructureCreator objects are instantiated with the desired
inserted R group variables. Otherwise, "R", "X" and "Z" are used.
Expand All @@ -3043,7 +3135,7 @@ def __init__(self, *, variables_list=None, max_index=21):
else:
self.r_group_variables = variables_list

self.potential_indices = range(max_index + 1)
self.potential_indices = range(1, max_index + 1)

def generate_markush_structure_dataset(self, smiles_list: List[str]) -> List[str]:
"""
Expand Down
86 changes: 82 additions & 4 deletions Tests/test_functions.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from RanDepict import RandomDepictor, DepictionFeatureRanges, RandomMarkushStructureCreator
from rdkit import DataStructs
import numpy as np
from omegaconf import OmegaConf

import pytest

from RanDepict import RandomDepictorConfig

class TestDepictionFeatureRanges:
DFR = DepictionFeatureRanges()

Expand Down Expand Up @@ -272,6 +275,76 @@ def test_correct_amount_of_FP_to_pick(self):
assert list(picked_fingerprints) == ["A"] * 20
assert corrected_n == 3

class TestRandomDepictorConstruction:

def test_default_depicter(self):
"""RandomDepictor default values should match that of the default RandomDepictorConfig"""
depicter = RandomDepictor()
config = RandomDepictorConfig()
assert depicter.seed == config.seed
assert depicter.hand_drawn == config.hand_drawn

def test_init_param_override(self):
"""Values passed to init should override defaults"""
config = RandomDepictorConfig()
assert config.seed != 21
assert not config.hand_drawn
depicter = RandomDepictor(seed=21, hand_drawn=True)
assert depicter.seed == 21
assert depicter.hand_drawn

def test_config_override(self):
"""Config passed to init should override defaults"""
config = RandomDepictorConfig(seed=21, hand_drawn=True)
depicter = RandomDepictor(config=config)
assert depicter.seed == 21
assert depicter.hand_drawn

@pytest.mark.xfail(raises=ValueError)
def test_invalid_style(self):
"""Invalid style passed to config should raise exception"""
_ = RandomDepictorConfig(styles=["pio", "cdk"])

def test_empty_style_list(self):
"""Empty style list passed to config should raise exception"""
with pytest.raises(ValueError) as excinfo:
_ = RandomDepictorConfig(styles=[])
assert 'Empty list' in str(excinfo.value)

def test_omega_config_rdc(self):
"""Can create RandomDepictorConfig from yaml"""
s = """
# RandomDepictorConfig:
seed: 21
augment: False
styles:
- cdk
"""
dict_config = OmegaConf.create(s)
rdc = RandomDepictorConfig.from_config(dict_config)
assert rdc.seed == 21
assert not rdc.hand_drawn
assert not rdc.augment
assert len(rdc.styles) == 1
assert 'cdk' in rdc.styles

def test_omega_config_rd(self, tmp_path):
"""Can create RandomDepictor from yaml"""
s = """
RandomDepictorConfig:
seed: 21
augment: False
styles:
- cdk
- indigo
"""
temp_config_file = tmp_path / "omg.yaml"
temp_config_file.write_text(s)
rd = RandomDepictor.from_config(config_file=temp_config_file)
assert rd.seed == 21
assert not rd.hand_drawn


class TestRandomDepictor:
depictor = RandomDepictor()

Expand Down Expand Up @@ -321,7 +394,9 @@ def test_get_depiction_functions_normal(self):
self.depictor.depict_and_resize_cdk,
self.depictor.depict_and_resize_pikachu,
]
assert observed == expected
# symmetric_difference
difference = set(observed) ^ set(expected)
assert not difference

def test_get_depiction_functions_isotopes(self):
# PIKAChU can't handle isotopes
Expand All @@ -331,7 +406,8 @@ def test_get_depiction_functions_isotopes(self):
self.depictor.depict_and_resize_indigo,
self.depictor.depict_and_resize_cdk,
]
assert observed == expected
difference = set(observed) ^ set(expected)
assert not difference

def test_get_depiction_functions_R(self):
# RDKit depicts "R" without indices as '*' (which is not desired)
Expand All @@ -341,7 +417,8 @@ def test_get_depiction_functions_R(self):
self.depictor.depict_and_resize_cdk,
self.depictor.depict_and_resize_pikachu,
]
assert observed == expected
difference = set(observed) ^ set(expected)
assert not difference

def test_get_depiction_functions_X(self):
# RDKit and Indigo don't depict "X"
Expand All @@ -350,7 +427,8 @@ def test_get_depiction_functions_X(self):
self.depictor.depict_and_resize_cdk,
self.depictor.depict_and_resize_pikachu,
]
assert observed == expected
difference = set(observed) ^ set(expected)
assert not difference

def test_smiles_to_mol_str(self):
# Compare generated mol file str with reference string
Expand Down
Loading

0 comments on commit 9dc5f7f

Please sign in to comment.