From 1ee33bcfcd84203881dbf0558516bc81d0a161bd Mon Sep 17 00:00:00 2001 From: Tulay Muezzinoglu Date: Mon, 3 Oct 2022 13:16:44 -0700 Subject: [PATCH 1/7] run tests for all py envs --- tox.ini | 2 -- 1 file changed, 2 deletions(-) diff --git a/tox.ini b/tox.ini index 98bf6e6..52d85f3 100644 --- a/tox.ini +++ b/tox.ini @@ -6,8 +6,6 @@ requires = tox-conda setenv = CONDA_DLL_SEARCH_MODIFICATION_ENABLE = 1 whitelist_externals = python - -[testenv:py37] conda_deps = pytest rdkit From 0c85eabcc08857ad6a1dce73458f2b7c2febdb53 Mon Sep 17 00:00:00 2001 From: Tulay Muezzinoglu Date: Mon, 3 Oct 2022 13:17:37 -0700 Subject: [PATCH 2/7] no zero index variable --- RanDepict/randepict.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/RanDepict/randepict.py b/RanDepict/randepict.py index b911eaa..3acd85a 100644 --- a/RanDepict/randepict.py +++ b/RanDepict/randepict.py @@ -3030,7 +3030,7 @@ def distribute_elements_evenly( class RandomMarkushStructureCreator: - def __init__(self, *, variables_list=None, max_index=21): + def __init__(self, *, variables_list=None, max_index=20): """ RandomMarkushStructureCreator objects are instantiated with the desired inserted R group variables. Otherwise, "R", "X" and "Z" are used. @@ -3043,7 +3043,7 @@ def __init__(self, *, variables_list=None, max_index=21): else: self.r_group_variables = variables_list - self.potential_indices = range(max_index + 1) + self.potential_indices = range(1, max_index + 1) def generate_markush_structure_dataset(self, smiles_list: List[str]) -> List[str]: """ From 13698142ff57af6c4a6bfc6a533616af61b01487 Mon Sep 17 00:00:00 2001 From: Tulay Muezzinoglu Date: Mon, 3 Oct 2022 15:41:53 -0700 Subject: [PATCH 3/7] add configuration via yaml file --- RanDepict/randepict.py | 56 ++++++++++++++++++++++++++++++++++++------ docs/requirements.txt | 1 + setup.py | 5 ++-- tox.ini | 1 - 4 files changed, 51 insertions(+), 12 deletions(-) diff --git a/RanDepict/randepict.py b/RanDepict/randepict.py index 3acd85a..be01df0 100644 --- a/RanDepict/randepict.py +++ b/RanDepict/randepict.py @@ -1,5 +1,7 @@ +from __future__ import annotations + import os -import pathlib +from pathlib import Path import numpy as np import io from skimage import io as sk_io @@ -10,7 +12,7 @@ import imgaug.augmenters as iaa import random from copy import deepcopy -from typing import Tuple, List, Dict, Any, Callable +from typing import Optional, Tuple, List, Dict, Any, Callable import re from rdkit import Chem @@ -22,6 +24,9 @@ from rdkit.SimDivFilters.rdSimDivPickers import MaxMinPicker from itertools import product +from omegaconf import OmegaConf, DictConfig # configuration package +from dataclasses import dataclass + from indigo import Indigo from indigo.renderer import IndigoRenderer from jpype import startJVM, getDefaultJVMPath @@ -34,6 +39,15 @@ from scipy.ndimage import gaussian_filter from scipy.ndimage import map_coordinates +@dataclass +class RandomDepictorConfig: + seed: int = 42 + hand_drawn: bool = False + augment: bool = True + + @classmethod + def from_config(cls, dict_config: Optional[DictConfig] = None) -> 'RandomDepictorConfig': + return OmegaConf.structured(cls(dict_config)) class RandomDepictor: """ @@ -43,12 +57,38 @@ class RandomDepictor: the RGB image with the given chemical structure. """ - def __init__(self, seed: int = 42, hand_drawn: bool = False): + def __init__(self, seed: int = 42, hand_drawn: bool = False, *, config: Path = None): """ Load the JVM only once, load superatom list (OSRA), set context for multiprocessing + + Parameters + ---------- + seed : int + seed for random number generator + hand_drawn : bool + Whether to augment with hand drawn features + config : Path object to configuration file in yaml format. + RandomDepictor section is expected. + + Returns + ------- + """ - self.HERE = pathlib.Path(__file__).resolve().parent.joinpath("assets") + # TODO removing seed and hand_drawn args might break existing code + self.seed = seed + self.hand_drawn = hand_drawn + + self._config = RandomDepictorConfig() + if config: + try: + # TODO Needs documentation + self._config: RandomDepictorConfig = RandomDepictorConfig.from_config(OmegaConf.load(config)[self.__class__.__name__]) + except Exception as e: + print(f"Error loading from {config}. Make sure it has {self.__class__.__name__} section. {e}") + print("Using default config.") + + self.HERE = Path(__file__).resolve().parent.joinpath("assets") # Start the JVM to access Java classes try: @@ -67,8 +107,6 @@ def __init__(self, seed: int = 42, hand_drawn: bool = False): self.jar_path = self.HERE.joinpath("jar_files/cdk-2.8.jar") startJVM(self.jvmPath, "-ea", "-Djava.class.path=" + str(self.jar_path)) - self.seed = seed - self.hand_drawn = hand_drawn random.seed(self.seed) # Load list of superatoms for label generation @@ -105,14 +143,16 @@ def __call__( hand_drawn: bool = False, ): # Depict structure with random parameters + # TODO hand_drawn to this call is ignored. Decide which one to keep hand_drawn = self.hand_drawn if hand_drawn: depiction = self.random_depiction(smiles, shape) - + # TODO is call to hand_drawn_augment missing? else: depiction = self.random_depiction(smiles, shape) # Add augmentations - depiction = self.add_augmentations(depiction) + if self._config.augment: + depiction = self.add_augmentations(depiction) if grayscale: return self.to_grayscale_float_img(depiction) diff --git a/docs/requirements.txt b/docs/requirements.txt index 46e1d28..53257e4 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -11,3 +11,4 @@ rdkit-pypi imagecorruptions pillow>=8 pikachu-chem>=1.0.7 +omegaconf==2.2.3 diff --git a/setup.py b/setup.py index edd916b..fed5928 100644 --- a/setup.py +++ b/setup.py @@ -29,6 +29,7 @@ "imagecorruptions", "pillow>=8.2.0", "pikachu-chem>=1.0.7", + 'omegaconf', ], extras_require={ "dev": ["tox", "pytest"], @@ -36,8 +37,6 @@ package_data={"RanDepict": ["assets/*.*", "assets/*/*.*", "assets/*/*/*.*"]}, classifiers=[ "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.5", - "Programming Language :: Python :: 3.6", "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", @@ -46,5 +45,5 @@ "License :: OSI Approved :: MIT License", "Operating System :: OS Independent", ], - python_requires=">=3.5", + python_requires=">=3.7", ) diff --git a/tox.ini b/tox.ini index 52d85f3..fc9be28 100644 --- a/tox.ini +++ b/tox.ini @@ -8,7 +8,6 @@ setenv = whitelist_externals = python conda_deps = pytest - rdkit conda_channels = conda-forge commands = pytest --basetemp="{envtmpdir}" {posargs} From 688602914789253c802ac3e7a3441807f2faf9d6 Mon Sep 17 00:00:00 2001 From: Tulay Muezzinoglu Date: Mon, 3 Oct 2022 16:18:03 -0700 Subject: [PATCH 4/7] init to accept config class, add from_config --- RanDepict/randepict.py | 28 ++++++++++++++++++---------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/RanDepict/randepict.py b/RanDepict/randepict.py index be01df0..f485cfe 100644 --- a/RanDepict/randepict.py +++ b/RanDepict/randepict.py @@ -57,7 +57,7 @@ class RandomDepictor: the RGB image with the given chemical structure. """ - def __init__(self, seed: int = 42, hand_drawn: bool = False, *, config: Path = None): + def __init__(self, seed: int = 42, hand_drawn: bool = False, *, config: RandomDepictorConfig = None): """ Load the JVM only once, load superatom list (OSRA), set context for multiprocessing @@ -75,18 +75,14 @@ def __init__(self, seed: int = 42, hand_drawn: bool = False, *, config: Path = N ------- """ - # TODO removing seed and hand_drawn args might break existing code + # TODO remove seed and hand_drawn args but watch out b/c might break existing code self.seed = seed self.hand_drawn = hand_drawn - self._config = RandomDepictorConfig() - if config: - try: - # TODO Needs documentation - self._config: RandomDepictorConfig = RandomDepictorConfig.from_config(OmegaConf.load(config)[self.__class__.__name__]) - except Exception as e: - print(f"Error loading from {config}. Make sure it has {self.__class__.__name__} section. {e}") - print("Using default config.") + if config is None: + self._config = RandomDepictorConfig() + else: + self._config = config self.HERE = Path(__file__).resolve().parent.joinpath("assets") @@ -134,6 +130,18 @@ def __init__(self, seed: int = 42, hand_drawn: bool = False, *, config: Path = N set_start_method("spawn") except RuntimeError: pass + + @classmethod + def from_config(cls, config_file: Path) -> 'RandomDepictor': + try: + # TODO Needs documentation + config: RandomDepictorConfig = RandomDepictorConfig.from_config(OmegaConf.load(config_file)[cls.__name__]) + except Exception as e: + print(f"Error loading from {config}. Make sure it has {cls.__name__} section. {e}") + print("Using default config.") + config = RandomDepictorConfig() + return RandomDepictor(config=config) + def __call__( self, From f0457d93d6bada59c69467b71695c8cd86a49ac5 Mon Sep 17 00:00:00 2001 From: Tulay Muezzinoglu Date: Mon, 3 Oct 2022 22:43:19 -0700 Subject: [PATCH 5/7] fix from_config, enable style selection --- README.md | 1 + RanDepict/__init__.py | 14 ++++----- RanDepict/randepict.py | 66 +++++++++++++++++++++++++++++++---------- Tests/test_functions.py | 13 +++++--- setup.py | 3 +- 5 files changed, 69 insertions(+), 28 deletions(-) diff --git a/README.md b/README.md index d0b3ffd..c5f8956 100644 --- a/README.md +++ b/README.md @@ -51,6 +51,7 @@ $ bash Miniconda3-latest-Linux-x86_64.sh $ echo -e "channels:\n - conda-forge\n - nodefaults" > ~/.condarc $ conda update conda $ conda install conda-libmamba-solver +$ conda config --set experimental_solver libmamba $ conda create --name RanDepict python=3.8 $ conda activate RanDepict # pypi has rdkit so not necessary to install it using conda diff --git a/RanDepict/__init__.py b/RanDepict/__init__.py index 291f5a8..c7c7a5e 100644 --- a/RanDepict/__init__.py +++ b/RanDepict/__init__.py @@ -7,12 +7,12 @@ chemical structure depictions (random depiction styles and image augmentations). -Typical usage example: - -from RanDepict import RandomDepictor -smiles = "CN1C=NC2=C1C(=O)N(C(=O)N2C)C" -with RandomDepictor() as depictor: - image = depictor(smiles) +Example: +-------- +>>> from RanDepict import RandomDepictor +>>> smiles = "CN1C=NC2=C1C(=O)N(C(=O)N2C)C" +>>> with RandomDepictor() as depictor: +>>> image = depictor(smiles) Have a look in the RanDepictNotebook.ipynb for more examples. @@ -28,4 +28,4 @@ ] -from .randepict import RandomDepictor, DepictionFeatureRanges, RandomMarkushStructureCreator +from .randepict import RandomDepictor, RandomDepictorConfig, DepictionFeatureRanges, RandomMarkushStructureCreator diff --git a/RanDepict/randepict.py b/RanDepict/randepict.py index f485cfe..c54d224 100644 --- a/RanDepict/randepict.py +++ b/RanDepict/randepict.py @@ -25,7 +25,7 @@ from itertools import product from omegaconf import OmegaConf, DictConfig # configuration package -from dataclasses import dataclass +from dataclasses import dataclass, field from indigo import Indigo from indigo.renderer import IndigoRenderer @@ -41,13 +41,34 @@ @dataclass class RandomDepictorConfig: + """ + Examples + -------- + >>> c1 = RandomDepictorConfig(seed=24, styles=["cdk", "indigo"]) + >>> c1 + RandomDepictorConfig(seed=24, hand_drawn=False, augment=True, styles=['cdk', 'indigo']) + >>> c2 = RandomDepictorConfig(styles=["cdk", "indigo", "pikachu", "rdkit"]) + >>> c2 + RandomDepictorConfig(seed=42, hand_drawn=False, augment=True, styles=['cdk', 'indigo', 'pikachu', 'rdkit']) + """ seed: int = 42 hand_drawn: bool = False augment: bool = True + # unions of containers are not supported yet + # https://github.com/omry/omegaconf/issues/144 + # styles: Union[str, List[str]] = field(default_factory=lambda: ["cdk", "indigo", "pikachu", "rdkit"]) + styles: List[str] = field(default_factory=lambda: ["cdk", "indigo", "pikachu", "rdkit"]) @classmethod def from_config(cls, dict_config: Optional[DictConfig] = None) -> 'RandomDepictorConfig': - return OmegaConf.structured(cls(dict_config)) + return OmegaConf.structured(cls(**dict_config)) + + def __post_init__(self): + """Ensure styles are always List[str] even if "cdk, indigo" is passed""" + # TODO make sure styles are valid... + if isinstance(self.styles, str): + self.styles = [v.strip() for v in self.styles.split(",")] + class RandomDepictor: """ @@ -130,18 +151,25 @@ def __init__(self, seed: int = 42, hand_drawn: bool = False, *, config: RandomDe set_start_method("spawn") except RuntimeError: pass - + @classmethod def from_config(cls, config_file: Path) -> 'RandomDepictor': try: - # TODO Needs documentation + # TODO Needs documentation of config_file yaml format... + """ + # randepict.yaml + RandomDepictor: + seed: 42 + augment: False + styles: + - cdk + """ config: RandomDepictorConfig = RandomDepictorConfig.from_config(OmegaConf.load(config_file)[cls.__name__]) except Exception as e: - print(f"Error loading from {config}. Make sure it has {cls.__name__} section. {e}") + print(f"Error loading from {config_file}. Make sure it has {cls.__name__} section. {e}") print("Using default config.") config = RandomDepictorConfig() return RandomDepictor(config=config) - def __call__( self, @@ -1394,6 +1422,7 @@ def random_depiction( """ depiction_functions = self.get_depiction_functions(smiles) # If nothing is returned, try different function + # FIXME: depictions_functions could be an empty list for _ in range(3): if len(depiction_functions) != 0: # Pick random depiction function and call it @@ -1448,12 +1477,15 @@ def get_depiction_functions(self, smiles: str) -> List[Callable]: Returns: List[Callable]: List of depiction functions """ - depiction_functions = [ - self.depict_and_resize_rdkit, - self.depict_and_resize_indigo, - self.depict_and_resize_cdk, - self.depict_and_resize_pikachu, - ] + + depiction_functions_registry = { + 'rdkit': self.depict_and_resize_rdkit, + 'indigo': self.depict_and_resize_indigo, + 'cdk': self.depict_and_resize_cdk, + 'pikachu': self.depict_and_resize_pikachu, + } + depiction_functions = [depiction_functions_registry[k] for k in self._config.styles] + # Remove PIKAChU if there is an isotope if re.search("(\[\d\d\d?[A-Z])|(\[2H\])|(\[3H\])|(D)|(T)", smiles): depiction_functions.remove(self.depict_and_resize_pikachu) @@ -1466,11 +1498,13 @@ def get_depiction_functions(self, smiles: str) -> List[Callable]: depiction_functions.remove(self.depict_and_resize_pikachu) # "R", "X", "Z" are not depicted by RDKit # The same is valid for X,Y,Z and a number - if re.search("\[[RXZ]\]|\[[XYZ]\d+", smiles): - depiction_functions.remove(self.depict_and_resize_rdkit) + if self.depict_and_resize_rdkit in depiction_functions: + if re.search("\[[RXZ]\]|\[[XYZ]\d+", smiles): + depiction_functions.remove(self.depict_and_resize_rdkit) # "X", "R0" and indices above 32 are not depicted by Indigo - if re.search("\[R0\]|\[X\]|[4-9][0-9]+|3[3-9]", smiles): - depiction_functions.remove(self.depict_and_resize_indigo) + if self.depict_and_resize_indigo in depiction_functions: + if re.search("\[R0\]|\[X\]|[4-9][0-9]+|3[3-9]", smiles): + depiction_functions.remove(self.depict_and_resize_indigo) return depiction_functions def resize(self, image: np.array, shape: Tuple[int], HQ: bool = False) -> np.array: diff --git a/Tests/test_functions.py b/Tests/test_functions.py index 62fe128..ac41398 100644 --- a/Tests/test_functions.py +++ b/Tests/test_functions.py @@ -321,7 +321,9 @@ def test_get_depiction_functions_normal(self): self.depictor.depict_and_resize_cdk, self.depictor.depict_and_resize_pikachu, ] - assert observed == expected + # symmetric_difference + difference = set(observed) ^ set(expected) + assert not difference def test_get_depiction_functions_isotopes(self): # PIKAChU can't handle isotopes @@ -331,7 +333,8 @@ def test_get_depiction_functions_isotopes(self): self.depictor.depict_and_resize_indigo, self.depictor.depict_and_resize_cdk, ] - assert observed == expected + difference = set(observed) ^ set(expected) + assert not difference def test_get_depiction_functions_R(self): # RDKit depicts "R" without indices as '*' (which is not desired) @@ -341,7 +344,8 @@ def test_get_depiction_functions_R(self): self.depictor.depict_and_resize_cdk, self.depictor.depict_and_resize_pikachu, ] - assert observed == expected + difference = set(observed) ^ set(expected) + assert not difference def test_get_depiction_functions_X(self): # RDKit and Indigo don't depict "X" @@ -350,7 +354,8 @@ def test_get_depiction_functions_X(self): self.depictor.depict_and_resize_cdk, self.depictor.depict_and_resize_pikachu, ] - assert observed == expected + difference = set(observed) ^ set(expected) + assert not difference def test_smiles_to_mol_str(self): # Compare generated mol file str with reference string diff --git a/setup.py b/setup.py index fed5928..33ef5cd 100644 --- a/setup.py +++ b/setup.py @@ -29,7 +29,8 @@ "imagecorruptions", "pillow>=8.2.0", "pikachu-chem>=1.0.7", - 'omegaconf', + "omegaconf", + "typing-extensions;python_version<'3.8'" ], extras_require={ "dev": ["tox", "pytest"], From 5803be80a63761020cddf9f904600feaad0e7e29 Mon Sep 17 00:00:00 2001 From: Tulay Muezzinoglu Date: Thu, 6 Oct 2022 23:47:22 -0700 Subject: [PATCH 6/7] add tests for RandomDepictorConfig --- RanDepict/randepict.py | 24 +++++++++++++++++------- Tests/test_functions.py | 39 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 56 insertions(+), 7 deletions(-) diff --git a/RanDepict/randepict.py b/RanDepict/randepict.py index c54d224..9059ea0 100644 --- a/RanDepict/randepict.py +++ b/RanDepict/randepict.py @@ -1,4 +1,5 @@ from __future__ import annotations +import copy import os from pathlib import Path @@ -64,10 +65,15 @@ def from_config(cls, dict_config: Optional[DictConfig] = None) -> 'RandomDepicto return OmegaConf.structured(cls(**dict_config)) def __post_init__(self): - """Ensure styles are always List[str] even if "cdk, indigo" is passed""" - # TODO make sure styles are valid... + # Ensure styles are always List[str] when "cdk, indigo" is passed if isinstance(self.styles, str): self.styles = [v.strip() for v in self.styles.split(",")] + if len(self.styles) == 0: + raise ValueError("Empty list of styles was supplied.") + # Not sure if this is the best way in order to not repeat the list of styles + ss = set(self.__dataclass_fields__['styles'].default_factory()) + if any([s not in ss for s in self.styles]): + raise ValueError(f"Use only {', '.join(ss)}") class RandomDepictor: @@ -78,7 +84,7 @@ class RandomDepictor: the RGB image with the given chemical structure. """ - def __init__(self, seed: int = 42, hand_drawn: bool = False, *, config: RandomDepictorConfig = None): + def __init__(self, seed: Optional[int] = None, hand_drawn: Optional[bool] = None, *, config: RandomDepictorConfig = None): """ Load the JVM only once, load superatom list (OSRA), set context for multiprocessing @@ -96,14 +102,18 @@ def __init__(self, seed: int = 42, hand_drawn: bool = False, *, config: RandomDe ------- """ - # TODO remove seed and hand_drawn args but watch out b/c might break existing code - self.seed = seed - self.hand_drawn = hand_drawn if config is None: self._config = RandomDepictorConfig() else: - self._config = config + self._config = copy.deepcopy(config) + if seed is not None: + self._config.seed = seed + if hand_drawn is not None: + self._config.hand_drawn = hand_drawn + + self.seed = self._config.seed + self.hand_drawn = self._config.hand_drawn self.HERE = Path(__file__).resolve().parent.joinpath("assets") diff --git a/Tests/test_functions.py b/Tests/test_functions.py index ac41398..b4ce656 100644 --- a/Tests/test_functions.py +++ b/Tests/test_functions.py @@ -4,6 +4,8 @@ import pytest +from RanDepict import RandomDepictorConfig + class TestDepictionFeatureRanges: DFR = DepictionFeatureRanges() @@ -272,6 +274,43 @@ def test_correct_amount_of_FP_to_pick(self): assert list(picked_fingerprints) == ["A"] * 20 assert corrected_n == 3 +class TestRandomDepictorConstruction: + + def test_default_depicter(self): + """RandomDepictor default values should match that of the default RandomDepictorConfig""" + depicter = RandomDepictor() + config = RandomDepictorConfig() + assert depicter.seed == config.seed + assert depicter.hand_drawn == config.hand_drawn + + def test_init_param_override(self): + """Values passed to init should override defaults""" + config = RandomDepictorConfig() + assert config.seed != 21 + assert not config.hand_drawn + depicter = RandomDepictor(seed=21, hand_drawn=True) + assert depicter.seed == 21 + assert depicter.hand_drawn + + def test_config_override(self): + """Config passed to init should override defaults""" + config = RandomDepictorConfig(seed=21, hand_drawn=True) + depicter = RandomDepictor(config=config) + assert depicter.seed == 21 + assert depicter.hand_drawn + + @pytest.mark.xfail(raises=ValueError) + def test_invalid_style(self): + """Invalid style passed to config should raise exception""" + _ = RandomDepictorConfig(styles=["pio", "cdk"]) + + def test_empty_style_list(self): + """Empty style list passed to config should raise exception""" + with pytest.raises(ValueError) as excinfo: + _ = RandomDepictorConfig(styles=[]) + assert 'Empty list' in str(excinfo.value) + + class TestRandomDepictor: depictor = RandomDepictor() From ae387567320e22ae890d7d060028f57ad2bfb3d7 Mon Sep 17 00:00:00 2001 From: Tulay Muezzinoglu Date: Fri, 7 Oct 2022 00:58:04 -0700 Subject: [PATCH 7/7] change yaml section, add file config test --- RanDepict/randepict.py | 4 ++-- Tests/test_functions.py | 34 ++++++++++++++++++++++++++++++++++ 2 files changed, 36 insertions(+), 2 deletions(-) diff --git a/RanDepict/randepict.py b/RanDepict/randepict.py index 9059ea0..7419aac 100644 --- a/RanDepict/randepict.py +++ b/RanDepict/randepict.py @@ -168,13 +168,13 @@ def from_config(cls, config_file: Path) -> 'RandomDepictor': # TODO Needs documentation of config_file yaml format... """ # randepict.yaml - RandomDepictor: + RandomDepictorConfig: seed: 42 augment: False styles: - cdk """ - config: RandomDepictorConfig = RandomDepictorConfig.from_config(OmegaConf.load(config_file)[cls.__name__]) + config: RandomDepictorConfig = RandomDepictorConfig.from_config(OmegaConf.load(config_file)[RandomDepictorConfig.__name__]) except Exception as e: print(f"Error loading from {config_file}. Make sure it has {cls.__name__} section. {e}") print("Using default config.") diff --git a/Tests/test_functions.py b/Tests/test_functions.py index b4ce656..64e5497 100644 --- a/Tests/test_functions.py +++ b/Tests/test_functions.py @@ -1,6 +1,7 @@ from RanDepict import RandomDepictor, DepictionFeatureRanges, RandomMarkushStructureCreator from rdkit import DataStructs import numpy as np +from omegaconf import OmegaConf import pytest @@ -310,6 +311,39 @@ def test_empty_style_list(self): _ = RandomDepictorConfig(styles=[]) assert 'Empty list' in str(excinfo.value) + def test_omega_config_rdc(self): + """Can create RandomDepictorConfig from yaml""" + s = """ + # RandomDepictorConfig: + seed: 21 + augment: False + styles: + - cdk + """ + dict_config = OmegaConf.create(s) + rdc = RandomDepictorConfig.from_config(dict_config) + assert rdc.seed == 21 + assert not rdc.hand_drawn + assert not rdc.augment + assert len(rdc.styles) == 1 + assert 'cdk' in rdc.styles + + def test_omega_config_rd(self, tmp_path): + """Can create RandomDepictor from yaml""" + s = """ + RandomDepictorConfig: + seed: 21 + augment: False + styles: + - cdk + - indigo + """ + temp_config_file = tmp_path / "omg.yaml" + temp_config_file.write_text(s) + rd = RandomDepictor.from_config(config_file=temp_config_file) + assert rd.seed == 21 + assert not rd.hand_drawn + class TestRandomDepictor: depictor = RandomDepictor()