diff --git a/README.md b/README.md index d0b3ffd..c5f8956 100644 --- a/README.md +++ b/README.md @@ -51,6 +51,7 @@ $ bash Miniconda3-latest-Linux-x86_64.sh $ echo -e "channels:\n - conda-forge\n - nodefaults" > ~/.condarc $ conda update conda $ conda install conda-libmamba-solver +$ conda config --set experimental_solver libmamba $ conda create --name RanDepict python=3.8 $ conda activate RanDepict # pypi has rdkit so not necessary to install it using conda diff --git a/RanDepict/__init__.py b/RanDepict/__init__.py index 291f5a8..c7c7a5e 100644 --- a/RanDepict/__init__.py +++ b/RanDepict/__init__.py @@ -7,12 +7,12 @@ chemical structure depictions (random depiction styles and image augmentations). -Typical usage example: - -from RanDepict import RandomDepictor -smiles = "CN1C=NC2=C1C(=O)N(C(=O)N2C)C" -with RandomDepictor() as depictor: - image = depictor(smiles) +Example: +-------- +>>> from RanDepict import RandomDepictor +>>> smiles = "CN1C=NC2=C1C(=O)N(C(=O)N2C)C" +>>> with RandomDepictor() as depictor: +...     image = depictor(smiles) Have a look in the RanDepictNotebook.ipynb for more examples. 
@@ -28,4 +28,4 @@ ] -from .randepict import RandomDepictor, DepictionFeatureRanges, RandomMarkushStructureCreator +from .randepict import RandomDepictor, RandomDepictorConfig, DepictionFeatureRanges, RandomMarkushStructureCreator diff --git a/RanDepict/randepict.py b/RanDepict/randepict.py index b911eaa..7419aac 100644 --- a/RanDepict/randepict.py +++ b/RanDepict/randepict.py @@ -1,5 +1,8 @@ +from __future__ import annotations +import copy + import os -import pathlib +from pathlib import Path import numpy as np import io from skimage import io as sk_io @@ -10,7 +13,7 @@ import imgaug.augmenters as iaa import random from copy import deepcopy -from typing import Tuple, List, Dict, Any, Callable +from typing import Optional, Tuple, List, Dict, Any, Callable import re from rdkit import Chem @@ -22,6 +25,9 @@ from rdkit.SimDivFilters.rdSimDivPickers import MaxMinPicker from itertools import product +from omegaconf import OmegaConf, DictConfig # configuration package +from dataclasses import dataclass, field + from indigo import Indigo from indigo.renderer import IndigoRenderer from jpype import startJVM, getDefaultJVMPath @@ -34,6 +40,41 @@ from scipy.ndimage import gaussian_filter from scipy.ndimage import map_coordinates +@dataclass +class RandomDepictorConfig: + """ + Examples + -------- + >>> c1 = RandomDepictorConfig(seed=24, styles=["cdk", "indigo"]) + >>> c1 + RandomDepictorConfig(seed=24, hand_drawn=False, augment=True, styles=['cdk', 'indigo']) + >>> c2 = RandomDepictorConfig(styles=["cdk", "indigo", "pikachu", "rdkit"]) + >>> c2 + RandomDepictorConfig(seed=42, hand_drawn=False, augment=True, styles=['cdk', 'indigo', 'pikachu', 'rdkit']) + """ + seed: int = 42 + hand_drawn: bool = False + augment: bool = True + # unions of containers are not supported yet + # https://github.com/omry/omegaconf/issues/144 + # styles: Union[str, List[str]] = field(default_factory=lambda: ["cdk", "indigo", "pikachu", "rdkit"]) + styles: List[str] = field(default_factory=lambda: 
["cdk", "indigo", "pikachu", "rdkit"]) + + @classmethod + def from_config(cls, dict_config: Optional[DictConfig] = None) -> 'RandomDepictorConfig': + return OmegaConf.structured(cls(**dict_config)) + + def __post_init__(self): + # Ensure styles are always List[str] when "cdk, indigo" is passed + if isinstance(self.styles, str): + self.styles = [v.strip() for v in self.styles.split(",")] + if len(self.styles) == 0: + raise ValueError("Empty list of styles was supplied.") + # Not sure if this is the best way in order to not repeat the list of styles + ss = set(self.__dataclass_fields__['styles'].default_factory()) + if any([s not in ss for s in self.styles]): + raise ValueError(f"Use only {', '.join(ss)}") + class RandomDepictor: """ @@ -43,12 +84,38 @@ class RandomDepictor: the RGB image with the given chemical structure. """ - def __init__(self, seed: int = 42, hand_drawn: bool = False): + def __init__(self, seed: Optional[int] = None, hand_drawn: Optional[bool] = None, *, config: RandomDepictorConfig = None): """ Load the JVM only once, load superatom list (OSRA), set context for multiprocessing + + Parameters + ---------- + seed : int + seed for random number generator + hand_drawn : bool + Whether to augment with hand drawn features + config : RandomDepictorConfig, optional + Configuration object; explicitly passed seed and hand_drawn override its values. 
+ + Returns + ------- + """ - self.HERE = pathlib.Path(__file__).resolve().parent.joinpath("assets") + + if config is None: + self._config = RandomDepictorConfig() + else: + self._config = copy.deepcopy(config) + if seed is not None: + self._config.seed = seed + if hand_drawn is not None: + self._config.hand_drawn = hand_drawn + + self.seed = self._config.seed + self.hand_drawn = self._config.hand_drawn + + self.HERE = Path(__file__).resolve().parent.joinpath("assets") # Start the JVM to access Java classes try: @@ -67,8 +134,6 @@ def __init__(self, seed: int = 42, hand_drawn: bool = False): self.jar_path = self.HERE.joinpath("jar_files/cdk-2.8.jar") startJVM(self.jvmPath, "-ea", "-Djava.class.path=" + str(self.jar_path)) - self.seed = seed - self.hand_drawn = hand_drawn random.seed(self.seed) # Load list of superatoms for label generation @@ -97,6 +162,25 @@ def __init__(self, seed: int = 42, hand_drawn: bool = False): except RuntimeError: pass + @classmethod + def from_config(cls, config_file: Path) -> 'RandomDepictor': + try: + # TODO Needs documentation of config_file yaml format... + """ + # randepict.yaml + RandomDepictorConfig: + seed: 42 + augment: False + styles: + - cdk + """ + config: RandomDepictorConfig = RandomDepictorConfig.from_config(OmegaConf.load(config_file)[RandomDepictorConfig.__name__]) + except Exception as e: + print(f"Error loading from {config_file}. Make sure it has {cls.__name__} section. {e}") + print("Using default config.") + config = RandomDepictorConfig() + return RandomDepictor(config=config) + def __call__( self, smiles: str, @@ -105,14 +189,16 @@ def __call__( hand_drawn: bool = False, ): # Depict structure with random parameters + # TODO hand_drawn to this call is ignored. Decide which one to keep hand_drawn = self.hand_drawn if hand_drawn: depiction = self.random_depiction(smiles, shape) - + # TODO is call to hand_drawn_augment missing? 
else: depiction = self.random_depiction(smiles, shape) # Add augmentations - depiction = self.add_augmentations(depiction) + if self._config.augment: + depiction = self.add_augmentations(depiction) if grayscale: return self.to_grayscale_float_img(depiction) @@ -1346,6 +1432,7 @@ def random_depiction( """ depiction_functions = self.get_depiction_functions(smiles) # If nothing is returned, try different function + # FIXME: depiction_functions could be an empty list for _ in range(3): if len(depiction_functions) != 0: # Pick random depiction function and call it @@ -1400,12 +1487,15 @@ def get_depiction_functions(self, smiles: str) -> List[Callable]: Returns: List[Callable]: List of depiction functions """ - depiction_functions = [ - self.depict_and_resize_rdkit, - self.depict_and_resize_indigo, - self.depict_and_resize_cdk, - self.depict_and_resize_pikachu, - ] + + depiction_functions_registry = { + 'rdkit': self.depict_and_resize_rdkit, + 'indigo': self.depict_and_resize_indigo, + 'cdk': self.depict_and_resize_cdk, + 'pikachu': self.depict_and_resize_pikachu, + } + depiction_functions = [depiction_functions_registry[k] for k in self._config.styles] + # Remove PIKAChU if there is an isotope if re.search("(\[\d\d\d?[A-Z])|(\[2H\])|(\[3H\])|(D)|(T)", smiles): depiction_functions.remove(self.depict_and_resize_pikachu) @@ -1418,11 +1508,13 @@ def get_depiction_functions(self, smiles: str) -> List[Callable]: depiction_functions.remove(self.depict_and_resize_pikachu) # "R", "X", "Z" are not depicted by RDKit # The same is valid for X,Y,Z and a number - if re.search("\[[RXZ]\]|\[[XYZ]\d+", smiles): - depiction_functions.remove(self.depict_and_resize_rdkit) + if self.depict_and_resize_rdkit in depiction_functions: + if re.search("\[[RXZ]\]|\[[XYZ]\d+", smiles): + depiction_functions.remove(self.depict_and_resize_rdkit) # "X", "R0" and indices above 32 are not depicted by Indigo - if re.search("\[R0\]|\[X\]|[4-9][0-9]+|3[3-9]", smiles): - 
depiction_functions.remove(self.depict_and_resize_indigo) + if self.depict_and_resize_indigo in depiction_functions: + if re.search("\[R0\]|\[X\]|[4-9][0-9]+|3[3-9]", smiles): + depiction_functions.remove(self.depict_and_resize_indigo) return depiction_functions def resize(self, image: np.array, shape: Tuple[int], HQ: bool = False) -> np.array: @@ -3030,7 +3122,7 @@ def distribute_elements_evenly( class RandomMarkushStructureCreator: - def __init__(self, *, variables_list=None, max_index=21): + def __init__(self, *, variables_list=None, max_index=20): """ RandomMarkushStructureCreator objects are instantiated with the desired inserted R group variables. Otherwise, "R", "X" and "Z" are used. @@ -3043,7 +3135,7 @@ def __init__(self, *, variables_list=None, max_index=21): else: self.r_group_variables = variables_list - self.potential_indices = range(max_index + 1) + self.potential_indices = range(1, max_index + 1) def generate_markush_structure_dataset(self, smiles_list: List[str]) -> List[str]: """ diff --git a/Tests/test_functions.py b/Tests/test_functions.py index 62fe128..64e5497 100644 --- a/Tests/test_functions.py +++ b/Tests/test_functions.py @@ -1,9 +1,12 @@ from RanDepict import RandomDepictor, DepictionFeatureRanges, RandomMarkushStructureCreator from rdkit import DataStructs import numpy as np +from omegaconf import OmegaConf import pytest +from RanDepict import RandomDepictorConfig + class TestDepictionFeatureRanges: DFR = DepictionFeatureRanges() @@ -272,6 +275,76 @@ def test_correct_amount_of_FP_to_pick(self): assert list(picked_fingerprints) == ["A"] * 20 assert corrected_n == 3 +class TestRandomDepictorConstruction: + + def test_default_depicter(self): + """RandomDepictor default values should match that of the default RandomDepictorConfig""" + depicter = RandomDepictor() + config = RandomDepictorConfig() + assert depicter.seed == config.seed + assert depicter.hand_drawn == config.hand_drawn + + def test_init_param_override(self): + """Values passed to 
init should override defaults""" + config = RandomDepictorConfig() + assert config.seed != 21 + assert not config.hand_drawn + depicter = RandomDepictor(seed=21, hand_drawn=True) + assert depicter.seed == 21 + assert depicter.hand_drawn + + def test_config_override(self): + """Config passed to init should override defaults""" + config = RandomDepictorConfig(seed=21, hand_drawn=True) + depicter = RandomDepictor(config=config) + assert depicter.seed == 21 + assert depicter.hand_drawn + + @pytest.mark.xfail(raises=ValueError) + def test_invalid_style(self): + """Invalid style passed to config should raise exception""" + _ = RandomDepictorConfig(styles=["pio", "cdk"]) + + def test_empty_style_list(self): + """Empty style list passed to config should raise exception""" + with pytest.raises(ValueError) as excinfo: + _ = RandomDepictorConfig(styles=[]) + assert 'Empty list' in str(excinfo.value) + + def test_omega_config_rdc(self): + """Can create RandomDepictorConfig from yaml""" + s = """ + # RandomDepictorConfig: + seed: 21 + augment: False + styles: + - cdk + """ + dict_config = OmegaConf.create(s) + rdc = RandomDepictorConfig.from_config(dict_config) + assert rdc.seed == 21 + assert not rdc.hand_drawn + assert not rdc.augment + assert len(rdc.styles) == 1 + assert 'cdk' in rdc.styles + + def test_omega_config_rd(self, tmp_path): + """Can create RandomDepictor from yaml""" + s = """ + RandomDepictorConfig: + seed: 21 + augment: False + styles: + - cdk + - indigo + """ + temp_config_file = tmp_path / "omg.yaml" + temp_config_file.write_text(s) + rd = RandomDepictor.from_config(config_file=temp_config_file) + assert rd.seed == 21 + assert not rd.hand_drawn + + class TestRandomDepictor: depictor = RandomDepictor() @@ -321,7 +394,9 @@ def test_get_depiction_functions_normal(self): self.depictor.depict_and_resize_cdk, self.depictor.depict_and_resize_pikachu, ] - assert observed == expected + # symmetric_difference + difference = set(observed) ^ set(expected) + assert not 
difference def test_get_depiction_functions_isotopes(self): # PIKAChU can't handle isotopes @@ -331,7 +406,8 @@ def test_get_depiction_functions_isotopes(self): self.depictor.depict_and_resize_indigo, self.depictor.depict_and_resize_cdk, ] - assert observed == expected + difference = set(observed) ^ set(expected) + assert not difference def test_get_depiction_functions_R(self): # RDKit depicts "R" without indices as '*' (which is not desired) @@ -341,7 +417,8 @@ def test_get_depiction_functions_R(self): self.depictor.depict_and_resize_cdk, self.depictor.depict_and_resize_pikachu, ] - assert observed == expected + difference = set(observed) ^ set(expected) + assert not difference def test_get_depiction_functions_X(self): # RDKit and Indigo don't depict "X" @@ -350,7 +427,8 @@ def test_get_depiction_functions_X(self): self.depictor.depict_and_resize_cdk, self.depictor.depict_and_resize_pikachu, ] - assert observed == expected + difference = set(observed) ^ set(expected) + assert not difference def test_smiles_to_mol_str(self): # Compare generated mol file str with reference string diff --git a/docs/requirements.txt b/docs/requirements.txt index 46e1d28..53257e4 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -11,3 +11,4 @@ rdkit-pypi imagecorruptions pillow>=8 pikachu-chem>=1.0.7 +omegaconf==2.2.3 diff --git a/setup.py b/setup.py index edd916b..33ef5cd 100644 --- a/setup.py +++ b/setup.py @@ -29,6 +29,8 @@ "imagecorruptions", "pillow>=8.2.0", "pikachu-chem>=1.0.7", + "omegaconf", + "typing-extensions;python_version<'3.8'" ], extras_require={ "dev": ["tox", "pytest"], @@ -36,8 +38,6 @@ package_data={"RanDepict": ["assets/*.*", "assets/*/*.*", "assets/*/*/*.*"]}, classifiers=[ "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.5", - "Programming Language :: Python :: 3.6", "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", @@ -46,5 +46,5 @@ "License :: 
OSI Approved :: MIT License", "Operating System :: OS Independent", ], - python_requires=">=3.5", + python_requires=">=3.7", ) diff --git a/tox.ini b/tox.ini index 98bf6e6..fc9be28 100644 --- a/tox.ini +++ b/tox.ini @@ -6,11 +6,8 @@ requires = tox-conda setenv = CONDA_DLL_SEARCH_MODIFICATION_ENABLE = 1 whitelist_externals = python - -[testenv:py37] conda_deps = pytest - rdkit conda_channels = conda-forge commands = pytest --basetemp="{envtmpdir}" {posargs}