From 7981294c8237b77767373b794169ac770ce4ef41 Mon Sep 17 00:00:00 2001 From: Tieu Long Phan <125431507+TieuLongPhan@users.noreply.github.com> Date: Wed, 16 Oct 2024 14:40:13 +0200 Subject: [PATCH 1/2] Prepare new release (#3) * update rc_split * format package --- Test/{ => SynChem}/Fingerprint/__init__.py | 0 .../Fingerprint/test_fp_calculator.py | 2 +- .../Fingerprint/test_smiles_featurizer.py | 2 +- .../Fingerprint/test_transformation_fp.py | 2 +- Test/{Partition => SynChem}/__init__.py | 0 Test/{Visualization => SynSplit}/__init__.py | 0 .../test_data_split.py} | 33 +- .../test_random_split.py} | 6 +- Test/SynSplit/test_rc_split.py | 95 +++++ .../test_stratified_reduction_split.py} | 10 +- .../test_stratified_split.py} | 12 +- .../Fingerprint => Test/SynVis}/__init__.py | 0 .../test_embedding.py | 2 +- lint.sh | 2 +- pyproject.toml | 5 +- .../Fingerprint}/__init__.py | 0 .../Fingerprint/fp_calculator.py | 4 +- .../Fingerprint/smiles_featurizer.py | 0 .../Fingerprint/transformation_fp.py | 4 +- .../{Visualization => SynChem}/__init__.py | 0 synutility/SynIO/__init__.py | 0 synutility/SynIO/data_type.py | 213 ++++++++++ synutility/{utils.py => SynIO/debug.py} | 119 +----- synutility/SynSplit/__init__.py | 0 .../data_split.py} | 82 +++- .../random_split.py} | 4 +- synutility/SynSplit/rc_split.py | 85 ++++ .../stratified_reduction_partition.py | 6 +- .../stratified_split.py} | 4 +- synutility/SynVis/__init__.py | 0 .../SynVis/chemical_graph_visualizer.py | 378 ++++++++++++++++++ .../SynVis/chemical_reaction_visualizer.py | 133 ++++++ .../chemical_space.py | 0 .../{Visualization => SynVis}/embedding.py | 0 34 files changed, 1028 insertions(+), 175 deletions(-) rename Test/{ => SynChem}/Fingerprint/__init__.py (100%) rename Test/{ => SynChem}/Fingerprint/test_fp_calculator.py (97%) rename Test/{ => SynChem}/Fingerprint/test_smiles_featurizer.py (97%) rename Test/{ => SynChem}/Fingerprint/test_transformation_fp.py (96%) rename Test/{Partition => SynChem}/__init__.py (100%) rename Test/{Visualization => SynSplit}/__init__.py (100%) rename Test/{Partition/test_data_partition.py => SynSplit/test_data_split.py} (73%) rename Test/{Partition/test_random_partition.py => SynSplit/test_random_split.py} (90%) create mode 100644 Test/SynSplit/test_rc_split.py rename Test/{Partition/test_stratified_reduction_partition.py => SynSplit/test_stratified_reduction_split.py} (88%) rename Test/{Partition/test_stratified_parition.py => SynSplit/test_stratified_split.py} (90%) rename {synutility/Fingerprint => Test/SynVis}/__init__.py (100%) rename Test/{Visualization => SynVis}/test_embedding.py (96%) rename synutility/{Partition => SynChem/Fingerprint}/__init__.py (100%) rename synutility/{ => SynChem}/Fingerprint/fp_calculator.py (96%) rename synutility/{ => SynChem}/Fingerprint/smiles_featurizer.py (100%) rename synutility/{ => SynChem}/Fingerprint/transformation_fp.py (96%) rename synutility/{Visualization => SynChem}/__init__.py (100%) create mode 100644 synutility/SynIO/__init__.py create mode 100644 synutility/SynIO/data_type.py rename synutility/{utils.py => SynIO/debug.py} (51%) create mode 100644 synutility/SynSplit/__init__.py rename synutility/{Partition/data_partition.py => SynSplit/data_split.py} (58%) rename synutility/{Partition/random_parition.py => SynSplit/random_split.py} (94%) create mode 100644 synutility/SynSplit/rc_split.py rename synutility/{Partition => SynSplit}/stratified_reduction_partition.py (96%) rename synutility/{Partition/stratified_partition.py => SynSplit/stratified_split.py} 
(95%) create mode 100644 synutility/SynVis/__init__.py create mode 100644 synutility/SynVis/chemical_graph_visualizer.py create mode 100644 synutility/SynVis/chemical_reaction_visualizer.py rename synutility/{Visualization => SynVis}/chemical_space.py (100%) rename synutility/{Visualization => SynVis}/embedding.py (100%) diff --git a/Test/Fingerprint/__init__.py b/Test/SynChem/Fingerprint/__init__.py similarity index 100% rename from Test/Fingerprint/__init__.py rename to Test/SynChem/Fingerprint/__init__.py diff --git a/Test/Fingerprint/test_fp_calculator.py b/Test/SynChem/Fingerprint/test_fp_calculator.py similarity index 97% rename from Test/Fingerprint/test_fp_calculator.py rename to Test/SynChem/Fingerprint/test_fp_calculator.py index ebf812c..ab567a2 100644 --- a/Test/Fingerprint/test_fp_calculator.py +++ b/Test/SynChem/Fingerprint/test_fp_calculator.py @@ -2,7 +2,7 @@ import pandas as pd import numpy as np -from synutility.Fingerprint.fp_calculator import FPCalculator +from synutility.SynChem.Fingerprint.fp_calculator import FPCalculator class TestFPCalculator(unittest.TestCase): diff --git a/Test/Fingerprint/test_smiles_featurizer.py b/Test/SynChem/Fingerprint/test_smiles_featurizer.py similarity index 97% rename from Test/Fingerprint/test_smiles_featurizer.py rename to Test/SynChem/Fingerprint/test_smiles_featurizer.py index 305220f..d490e56 100644 --- a/Test/Fingerprint/test_smiles_featurizer.py +++ b/Test/SynChem/Fingerprint/test_smiles_featurizer.py @@ -3,7 +3,7 @@ from rdkit.Chem import MACCSkeys import numpy as np -from synutility.Fingerprint.smiles_featurizer import SmilesFeaturizer +from synutility.SynChem.Fingerprint.smiles_featurizer import SmilesFeaturizer class TestSmilesFeaturizer(unittest.TestCase): diff --git a/Test/Fingerprint/test_transformation_fp.py b/Test/SynChem/Fingerprint/test_transformation_fp.py similarity index 96% rename from Test/Fingerprint/test_transformation_fp.py rename to Test/SynChem/Fingerprint/test_transformation_fp.py index 4a11c6a..9974229 100644 --- a/Test/Fingerprint/test_transformation_fp.py +++ b/Test/SynChem/Fingerprint/test_transformation_fp.py @@ -2,7 +2,7 @@ import numpy as np from rdkit.DataStructs import cDataStructs -from synutility.Fingerprint.transformation_fp import TransformationFP +from synutility.SynChem.Fingerprint.transformation_fp import TransformationFP class TestTransformationFP(unittest.TestCase): diff --git a/Test/Partition/__init__.py b/Test/SynChem/__init__.py similarity index 100% rename from Test/Partition/__init__.py rename to Test/SynChem/__init__.py diff --git a/Test/Visualization/__init__.py b/Test/SynSplit/__init__.py similarity index 100% rename from Test/Visualization/__init__.py rename to Test/SynSplit/__init__.py diff --git a/Test/Partition/test_data_partition.py b/Test/SynSplit/test_data_split.py similarity index 73% rename from Test/Partition/test_data_partition.py rename to Test/SynSplit/test_data_split.py index f7f455d..18bbccc 100644 --- a/Test/Partition/test_data_partition.py +++ b/Test/SynSplit/test_data_split.py @@ -1,9 +1,9 @@ import unittest import pandas as pd -from synutility.Partition.data_partition import DataPartition +from synutility.SynSplit.data_split import DataSplit -class TestDataPartition(unittest.TestCase): +class TestDataSplit(unittest.TestCase): def setUp(self): # Create a mock dataset @@ -34,7 +34,7 @@ def setUp(self): def test_random_partition(self): # Test random partition method with real RandomPartition class - partitioner = DataPartition( + partitioner = DataSplit( 
data=self.data, test_size=self.test_size, class_column=self.class_column, @@ -51,7 +51,7 @@ def test_random_partition(self): def test_stratified_partition(self): # Test stratified partition method with real StratifiedPartition class - partitioner = DataPartition( + partitioner = DataSplit( data=self.data, test_size=self.test_size, class_column=self.class_column, @@ -74,7 +74,30 @@ def test_stratified_partition(self): def test_stratified_class_reduction_partition(self): # Test stratified class reduction partition method with real StratifiedReductionPartition class - partitioner = DataPartition( + partitioner = DataSplit( + data=self.data, + test_size=self.test_size, + class_column=self.class_column, + method="stratified_class_reduction", + random_state=self.random_state, + drop_class_ratio=0.4, + ) + train_data, test_data, removed_data = partitioner.fit() + + # Check that the partitioning happened correctly + self.assertEqual( + len(train_data) + len(test_data) + len(removed_data), len(self.data) + ) + self.assertAlmostEqual( + len(test_data) / len(self.data), self.test_size, delta=0.05 + ) + + # Ensure that some data was removed based on the reduction logic + self.assertGreater(len(removed_data), 0) + + def test_reaction_center_split(self): + # Test reaction center split method with the real RCSplit class + partitioner = DataSplit( data=self.data, test_size=self.test_size, class_column=self.class_column, diff --git a/Test/Partition/test_random_partition.py b/Test/SynSplit/test_random_split.py similarity index 90% rename from Test/Partition/test_random_partition.py rename to Test/SynSplit/test_random_split.py index 2c4c572..f304572 100644 --- a/Test/Partition/test_random_partition.py +++ b/Test/SynSplit/test_random_split.py @@ -1,10 +1,10 @@ import unittest import pandas as pd -from synutility.Partition.random_parition import RandomPartition +from synutility.SynSplit.random_split import RandomSplit -class TestRandomPartition(unittest.TestCase): +class TestRandomSplit(unittest.TestCase): def setUp(self): # Sample data setup self.data = pd.DataFrame( @@ -15,7 +15,7 @@ def setUp(self): self.random_state = 42 # Instantiate RandomPartition - self.random_partition = RandomPartition( + self.random_partition = RandomSplit( data=self.data, test_size=self.test_size, class_column=self.class_column, diff --git a/Test/SynSplit/test_rc_split.py b/Test/SynSplit/test_rc_split.py new file mode 100644 index 0000000..ef676d5 --- /dev/null +++ b/Test/SynSplit/test_rc_split.py @@ -0,0 +1,95 @@ +import unittest +import numpy as np +import pandas as pd +from synutility.SynSplit.rc_split import RCSplit + + +class TestRCSplit(unittest.TestCase): + + def setUp(self): + """Set up a simple dataset for testing.""" + np.random.seed(42) + # Create a DataFrame with 100 entries and a skewed distribution of classes + data = { + "class": np.random.choice( + ["A", "B", "C", "D", "E"], p=[0.2, 0.3, 0.3, 0.1, 0.1], size=100 + ), + "feature": np.random.rand(100), + } + self.df = pd.DataFrame(data) + + self.test_size = 0.2 # 20% of data as test set + self.random_state = 42 + self.class_column = "class" + + def test_fit_function(self): + """Test the fit function to ensure it divides the dataset correctly + according to the test_size.""" + splitter = RCSplit( + data=self.df, + test_size=self.test_size, + class_column=self.class_column, + random_state=self.random_state, + ) + train_set, test_set = splitter.fit() + + # Check the proportion of the test set + actual_test_size = len(test_set) /
len(self.df) + self.assertAlmostEqual( + actual_test_size, + self.test_size, + delta=0.05, + msg="Test set size is not within the acceptable range.", + ) + + # Verify that test set does not exceed the test_size even slightly + self.assertTrue( + actual_test_size <= self.test_size, + "Test set size exceeds the specified test_size.", + ) + + # Check if there's no overlap in the indices between train and test sets + self.assertEqual( + len(set(train_set.index).intersection(set(test_set.index))), + 0, + "Train and test sets have overlapping indices.", + ) + + self.assertEqual( + len( + set(train_set[self.class_column]).intersection( + set(test_set[self.class_column]) + ) + ), + 0, + "Train and test sets share class labels.", + ) + + # def test_randomness_of_split(self): + # """Test if different random states produce different splits.""" + # splitter1 = RCSplit( + # data=self.df, + # test_size=0.1, + # class_column=self.class_column, + # random_state=42, + # ) + # splitter2 = RCSplit( + # data=self.df, + # test_size=0.1, + # class_column=self.class_column, + # random_state=1, + # ) + + # _, test_set1 = splitter1.fit() + # _, test_set2 = splitter2.fit() + + # # Assert that the two test sets with different random seeds are not the same + # self.assertNotEqual( + # test_set1.index.tolist(), + # test_set2.index.tolist(), + # "Test sets are identical for different random states.", + # ) + + +if __name__ == "__main__": + unittest.main() diff --git a/Test/Partition/test_stratified_reduction_partition.py b/Test/SynSplit/test_stratified_reduction_split.py similarity index 88% rename from Test/Partition/test_stratified_reduction_partition.py rename to Test/SynSplit/test_stratified_reduction_split.py index 9e4881e..9969a6a 100644 --- a/Test/Partition/test_stratified_reduction_partition.py +++ b/Test/SynSplit/test_stratified_reduction_split.py @@ -1,11 +1,11 @@ import unittest import pandas as pd -from synutility.Partition.stratified_reduction_partition import ( - StratifiedReductionPartition, +from synutility.SynSplit.stratified_reduction_partition import ( + StratifiedReductionSplit, ) -class TestStratifiedReductionPartition(unittest.TestCase): +class TestStratifiedReductionSplit(unittest.TestCase): def setUp(self): # Sample data for testing @@ -30,7 +30,7 @@ def setUp(self): ] } ) - self.srp = StratifiedReductionPartition( + self.srp = StratifiedReductionSplit( data=data, test_size=0.2, drop_class_ratio=0.1, @@ -47,7 +47,7 @@ def test_initialization(self): def test_random_select(self): test_dict = {"A": 0.5, "B": 0.3, "C": 0.2} - selected = StratifiedReductionPartition.random_select( + selected = StratifiedReductionSplit.random_select( test_dict, 0.2, random_state=42 ) self.assertIn("C", selected) diff --git a/Test/Partition/test_stratified_parition.py b/Test/SynSplit/test_stratified_split.py similarity index 90% rename from Test/Partition/test_stratified_parition.py rename to Test/SynSplit/test_stratified_split.py index 68ae2ff..a0683a2 100644 --- a/Test/Partition/test_stratified_parition.py +++ b/Test/SynSplit/test_stratified_split.py @@ -2,10 +2,10 @@ import pandas as pd import numpy as np -from synutility.Partition.stratified_partition import StratifiedPartition +from synutility.SynSplit.stratified_split import StratifiedSplit -class TestStratifiedPartition(unittest.TestCase): +class TestStratifiedSplit(unittest.TestCase): def setUp(self): # Sample data setup self.data = pd.DataFrame( @@ -21,7 +21,7 @@ def setUp(self): self.random_state = 42 # Instantiate StratifiedPartition -
self.strat_partition = StratifiedPartition( + self.strat_partition = StratifiedSplit( data=self.data, test_size=self.test_size, class_column=self.class_column, @@ -36,7 +36,7 @@ def test_constructor_and_attribute_assignment(self): def test_constructor_raises_for_invalid_class_column(self): with self.assertRaises(ValueError): - StratifiedPartition( + StratifiedSplit( data=pd.DataFrame({"feature": [1, 2, 3]}), test_size=0.25, class_column="nonexistent", @@ -45,14 +45,14 @@ def test_constructor_raises_for_invalid_class_column(self): def test_constructor_raises_for_invalid_test_size(self): with self.assertRaises(ValueError): - StratifiedPartition( + StratifiedSplit( data=self.data, test_size=-0.1, class_column=self.class_column, random_state=42, ) with self.assertRaises(ValueError): - StratifiedPartition( + StratifiedSplit( data=self.data, test_size=1.5, class_column=self.class_column, diff --git a/synutility/Fingerprint/__init__.py b/Test/SynVis/__init__.py similarity index 100% rename from synutility/Fingerprint/__init__.py rename to Test/SynVis/__init__.py diff --git a/Test/Visualization/test_embedding.py b/Test/SynVis/test_embedding.py similarity index 96% rename from Test/Visualization/test_embedding.py rename to Test/SynVis/test_embedding.py index 44f2658..c9e02a0 100644 --- a/Test/Visualization/test_embedding.py +++ b/Test/SynVis/test_embedding.py @@ -1,7 +1,7 @@ import unittest import numpy as np -from synutility.Visualization.embedding import Embedding +from synutility.SynVis.embedding import Embedding class TestEmbedding(unittest.TestCase): diff --git a/lint.sh b/lint.sh index 922e801..7eb3775 100755 --- a/lint.sh +++ b/lint.sh @@ -1,6 +1,6 @@ #!/bin/bash flake8 . --count --max-complexity=13 --max-line-length=120 \ - --per-file-ignores="__init__.py:F401" \ + --per-file-ignores="__init__.py:F401, chemical_reaction_visualizer.py:E501" \ --exclude venv \ --statistics diff --git a/pyproject.toml b/pyproject.toml index 0f4e18f..291d4f2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,14 +18,15 @@ classifiers = [ ] dependencies = [ "scikit-learn==1.5.1", - "xgboost==2.1.1", "pandas>=1.5.3", "rdkit>=2024.3.3", "networkx==3.3", "seaborn==0.13.2", - "drfp==0.3.6", ] +[project.optional-dependencies] +all = ["drfp==0.3.6", "xgboost==2.1.1"] + [project.urls] homepage = "https://github.com/TieuLongPhan/SynUtils" source = "https://github.com/TieuLongPhan/SynUtils" diff --git a/synutility/Partition/__init__.py b/synutility/SynChem/Fingerprint/__init__.py similarity index 100% rename from synutility/Partition/__init__.py rename to synutility/SynChem/Fingerprint/__init__.py diff --git a/synutility/Fingerprint/fp_calculator.py b/synutility/SynChem/Fingerprint/fp_calculator.py similarity index 96% rename from synutility/Fingerprint/fp_calculator.py rename to synutility/SynChem/Fingerprint/fp_calculator.py index ab717f6..93f1923 100644 --- a/synutility/Fingerprint/fp_calculator.py +++ b/synutility/SynChem/Fingerprint/fp_calculator.py @@ -3,8 +3,8 @@ from drfp import DrfpEncoder from joblib import Parallel, delayed from typing import Optional -from synutility.utils import setup_logging -from synutility.Fingerprint.transformation_fp import TransformationFP +from synutility.SynIO.debug import setup_logging +from synutility.SynChem.Fingerprint.transformation_fp import TransformationFP class FPCalculator: diff --git a/synutility/Fingerprint/smiles_featurizer.py b/synutility/SynChem/Fingerprint/smiles_featurizer.py similarity index 100% rename from synutility/Fingerprint/smiles_featurizer.py rename to 
synutility/SynChem/Fingerprint/smiles_featurizer.py diff --git a/synutility/Fingerprint/transformation_fp.py b/synutility/SynChem/Fingerprint/transformation_fp.py similarity index 96% rename from synutility/Fingerprint/transformation_fp.py rename to synutility/SynChem/Fingerprint/transformation_fp.py index 9818762..cb79e1f 100644 --- a/synutility/Fingerprint/transformation_fp.py +++ b/synutility/SynChem/Fingerprint/transformation_fp.py @@ -1,7 +1,7 @@ import numpy as np from typing import Union, Any from rdkit.DataStructs import cDataStructs -from synutility.Fingerprint.smiles_featurizer import SmilesFeaturizer +from synutility.SynChem.Fingerprint.smiles_featurizer import SmilesFeaturizer class TransformationFP: @@ -39,7 +39,7 @@ def fit( fp_type: str, abs: bool, return_array: bool = True, - **kwargs: Any + **kwargs: Any, ) -> Union[np.ndarray, cDataStructs.ExplicitBitVect]: """ Generates a reaction fingerprint for a given reaction represented by a SMILES string. diff --git a/synutility/Visualization/__init__.py b/synutility/SynChem/__init__.py similarity index 100% rename from synutility/Visualization/__init__.py rename to synutility/SynChem/__init__.py diff --git a/synutility/SynIO/__init__.py b/synutility/SynIO/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/synutility/SynIO/data_type.py b/synutility/SynIO/data_type.py new file mode 100644 index 0000000..e2e7d4a --- /dev/null +++ b/synutility/SynIO/data_type.py @@ -0,0 +1,213 @@ +import json +import pickle +import numpy as np +from numpy import ndarray +from joblib import dump, load +from typing import List, Dict, Any +from synutility.SynIO.debug import setup_logging + +logger = setup_logging() + + +def save_database(database: list[dict], pathname: str = "./Data/database.json") -> None: + """ + Save a database (a list of dictionaries) to a JSON file. + + Args: + database: The database to be saved. + pathname: The path where the database will be saved. + Defaults to './Data/database.json'. + + Raises: + TypeError: If the database is not a list of dictionaries. + ValueError: If there is an error writing the file. + """ + if not all(isinstance(item, dict) for item in database): + raise TypeError("Database should be a list of dictionaries.") + + try: + with open(pathname, "w") as f: + json.dump(database, f) + except IOError as e: + raise ValueError(f"Error writing to file {pathname}: {e}") + + +def load_database(pathname: str = "./Data/database.json") -> List[Dict]: + """ + Load a database (a list of dictionaries) from a JSON file. + + Args: + pathname: The path from where the database will be loaded. + Defaults to './Data/database.json'. + + Returns: + The loaded database. + + Raises: + ValueError: If there is an error reading the file. + """ + try: + with open(pathname, "r") as f: + database = json.load(f) # Load the JSON data from the file + return database + except IOError as e: + raise ValueError(f"Error reading from file {pathname}: {e}") + + +def save_to_pickle(data: List[Dict[str, Any]], filename: str) -> None: + """ + Save a list of dictionaries to a pickle file. + + Parameters: + data (List[Dict[str, Any]]): A list of dictionaries to be saved. + filename (str): The name of the file where the data will be saved. + """ + with open(filename, "wb") as file: + pickle.dump(data, file) + + +def load_from_pickle(filename: str) -> List[Any]: + """ + Load data from a pickle file. + + Parameters: + filename (str): The name of the pickle file to load data from. + + Returns: + List[Any]: The data loaded from the pickle file.
+ """ + with open(filename, "rb") as file: + return pickle.load(file) + + +def load_gml_as_text(gml_file_path): + """ + Load the contents of a GML file as a text string. + + Parameters: + - gml_file_path (str): The file path to the GML file. + + Returns: + - str: The text content of the GML file. + """ + try: + with open(gml_file_path, "r") as file: + return file.read() + except FileNotFoundError: + print(f"File not found: {gml_file_path}") + return None + except Exception as e: + print(f"An error occurred: {e}") + return None + + +def save_compressed(array: ndarray, filename: str) -> None: + """ + Saves a NumPy array in a compressed format using .npz extension. + + Parameters: + - array (ndarray): The NumPy array to be saved. + - filename (str): The file path or name to save the array to, + with a '.npz' extension. + + Returns: + - None: This function does not return any value. + """ + np.savez_compressed(filename, array=array) + + +def load_compressed(filename: str) -> ndarray: + """ + Loads a NumPy array from a compressed .npz file. + + Parameters: + - filename (str): The path of the .npz file to load. + + Returns: + - ndarray: The loaded NumPy array. + + Raises: + - KeyError: If the .npz file does not contain an array with the key 'array'. + """ + with np.load(filename) as data: + if "array" in data: + return data["array"] + else: + raise KeyError( + "The .npz file does not contain" + " an array with the key 'array'." + ) + + +def save_model(model: Any, filename: str) -> None: + """ + Save a machine learning model to a file using joblib. + + Parameters: + - model (Any): The machine learning model to save. + - filename (str): The path to the file where the model will be saved. + """ + dump(model, filename) + logger.info(f"Model saved successfully to {filename}") + + +def load_model(filename: str) -> Any: + """ + Load a machine learning model from a file using joblib. + + Parameters: + - filename (str): The path to the file from which the model will be loaded. + + Returns: + - Any: The loaded machine learning model. + """ + model = load(filename) + logger.info(f"Model loaded successfully from {filename}") + return model + + +def save_dict_to_json(data: dict, file_path: str) -> None: + """ + Save a dictionary to a JSON file. + + Parameters: + ----------- + data : dict + The dictionary to be saved. + + file_path : str + The path to the file where the dictionary should be saved. + Make sure the file has a .json extension. + + Returns: + -------- + None + """ + with open(file_path, "w") as json_file: + json.dump(data, json_file, indent=4) + + logger.info(f"Dictionary successfully saved to {file_path}") + + +def load_dict_from_json(file_path: str) -> dict: + """ + Load a dictionary from a JSON file. + + Parameters: + ----------- + file_path : str + The path to the JSON file from which to load the dictionary. + Make sure the file has a .json extension. + + Returns: + -------- + dict + The dictionary loaded from the JSON file. 
+ """ + try: + with open(file_path, "r") as json_file: + data = json.load(json_file) + logger.info(f"Dictionary successfully loaded from {file_path}") + return data + except Exception as e: + logger.error(e) + return None diff --git a/synutility/utils.py b/synutility/SynIO/debug.py similarity index 51% rename from synutility/utils.py rename to synutility/SynIO/debug.py index a2398e3..822c40a 100644 --- a/synutility/utils.py +++ b/synutility/SynIO/debug.py @@ -1,12 +1,7 @@ import os -import json import logging import warnings -import numpy as np -from numpy import ndarray from rdkit import rdBase -from joblib import dump, load -from typing import Any def setup_logging(log_level: str = "INFO", log_filename: str = None) -> logging.Logger: @@ -46,7 +41,7 @@ def setup_logging(log_level: str = "INFO", log_filename: str = None) -> logging. if log_filename: os.makedirs(os.path.dirname(log_filename), exist_ok=True) logging.basicConfig( - level=numeric_level, format=log_format, filename=log_filename, filemode="w" + level=numeric_level, format=log_format, filename=log_filename, filemode="a" ) else: logging.basicConfig(level=numeric_level, format=log_format) @@ -93,115 +88,3 @@ def configure_warnings_and_logs( # Enable RDKit error and warning logs if they were previously disabled rdBase.EnableLog("rdApp.error") rdBase.EnableLog("rdApp.warning") - - -def save_compressed(array: ndarray, filename: str) -> None: - """ - Saves a NumPy array in a compressed format using .npz extension. - - Parameters: - - array (ndarray): The NumPy array to be saved. - - filename (str): The file path or name to save the array to, - with a '.npz' extension. - - Returns: - - None: This function does not return any value. - """ - np.savez_compressed(filename, array=array) - - -def load_compressed(filename: str) -> ndarray: - """ - Loads a NumPy array from a compressed .npz file. - - Parameters: - - filename (str): The path of the .npz file to load. - - Returns: - - ndarray: The loaded NumPy array. - - Raises: - - KeyError: If the .npz file does not contain an array with the key 'array'. - """ - with np.load(filename) as data: - if "array" in data: - return data["array"] - else: - raise KeyError( - "The .npz file does not contain" + " an array with the key 'array'." - ) - - -def save_model(model: Any, filename: str) -> None: - """ - Save a machine learning model to a file using joblib. - - Parameters: - - model (Any): The machine learning model to save. - - filename (str): The path to the file where the model will be saved. - """ - dump(model, filename) - logging.info(f"Model saved successfully to {filename}") - - -def load_model(filename: str) -> Any: - """ - Load a machine learning model from a file using joblib. - - Parameters: - - filename (str): The path to the file from which the model will be loaded. - - Returns: - - Any: The loaded machine learning model. - """ - model = load(filename) - logging.info(f"Model loaded successfully from {filename}") - return model - - -def save_dict_to_json(data: dict, file_path: str) -> None: - """ - Save a dictionary to a JSON file. - - Parameters: - ----------- - data : dict - The dictionary to be saved. - - file_path : str - The path to the file where the dictionary should be saved. - Make sure the file has a .json extension. 
- - Returns: - -------- - None - """ - with open(file_path, "w") as json_file: - json.dump(data, json_file, indent=4) - - logging.info(f"Dictionary successfully saved to {file_path}") - - -def load_dict_from_json(file_path: str) -> dict: - """ - Load a dictionary from a JSON file. - - Parameters: - ----------- - file_path : str - The path to the JSON file from which to load the dictionary. - Make sure the file has a .json extension. - - Returns: - -------- - dict - The dictionary loaded from the JSON file. - """ - try: - with open(file_path, "r") as json_file: - data = json.load(json_file) - logging.info(f"Dictionary successfully loaded from {file_path}") - return data - except Exception as e: - logging.error(e) - return None diff --git a/synutility/SynSplit/__init__.py b/synutility/SynSplit/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/synutility/Partition/data_partition.py b/synutility/SynSplit/data_split.py similarity index 58% rename from synutility/Partition/data_partition.py rename to synutility/SynSplit/data_split.py index a95df13..67fe29f 100644 --- a/synutility/Partition/data_partition.py +++ b/synutility/SynSplit/data_split.py @@ -1,16 +1,18 @@ from typing import Tuple, Union import pandas as pd -from synutility.Partition.random_parition import RandomPartition -from synutility.Partition.stratified_partition import StratifiedPartition -from synutility.Partition.stratified_reduction_partition import ( - StratifiedReductionPartition, +from synutility.SynSplit.random_split import RandomSplit +from synutility.SynSplit.stratified_split import StratifiedSplit +from synutility.SynSplit.stratified_reduction_partition import ( + StratifiedReductionSplit, ) -from synutility.utils import setup_logging +from synutility.SynSplit.rc_split import RCSplit +from synutility.SynIO.debug import setup_logging +from synutility.SynIO.data_type import load_database logger = setup_logging() -class DataPartition: +class DataSplit: """ Class for partitioning a dataset into training and testing sets using various partitioning methods. @@ -18,18 +20,20 @@ class DataPartition: Attributes: - data (pd.DataFrame): The dataset to be partitioned. - test_size (float): The proportion of the dataset to include in the test split, - between 0.0 and 1.0. + between 0.0 and 1.0. - class_column (str): Column name in the dataset that contains class labels. - method (str): Method used for data partitioning; valid options include - 'random', 'stratified_target', and 'stratified_class_reduction'. + 'random', 'stratified_target', and 'stratified_class_reduction'. - drop_class_ratio (float): The maximum cumulative proportion of classes to remove - from the training data, used with 'stratified_class_reduction' method. Default is 0.1. + from the training data, used with 'stratified_class_reduction' method. + Default is 0.1. - random_state (int): Seed for random number generation to ensure reproducibility. + - keep_data (bool): Whether to return the data removed from the training set or not. """ def __init__( self, - data: pd.DataFrame, + data, test_size: float, class_column: str, method: str, @@ -38,27 +42,56 @@ def __init__( keep_data: bool = True, ) -> None: """ - Initializes the DataPartition instance with the specified parameters. + Initializes the DataSplit instance with the specified parameters. Parameters: - - data (pd.DataFrame): Dataset to be partitioned. + - data: Dataset to be partitioned, can be a pd.DataFrame or a path to + a json.gz file. 
- test_size (float): Proportion of the dataset to include in the test split. - class_column (str): Column name containing class labels. - method (str): Partitioning method ('random', 'stratified_target', - 'stratified_class_reduction'). + 'stratified_class_reduction'). - random_state (int): Random seed for reproducibility. - drop_class_ratio (float, optional): Ratio of minority classes to remove - (default 0.1). - - keep_data (bool): return also data_remove or not + (default 0.1). + - keep_data (bool): Whether to also return the data removed from the training set. """ - self.data = data + if isinstance(data, pd.DataFrame): + self.data = data + elif isinstance(data, str) and data.endswith(".json.gz"): + self.data = load_database(data) + else: + logger.error( + "Unsupported data format. Please provide " + + "a pandas DataFrame or a .json.gz file path." + ) + raise ValueError( + "Unsupported data format. Please provide " + + "a pandas DataFrame or a .json.gz file path." + ) + self.test_size = test_size self.class_column = class_column - self.method = method + self.method = self._validate_method(method) self.drop_class_ratio = drop_class_ratio self.random_state = random_state self.keep_data = keep_data + def _validate_method(self, method: str) -> str: + """Validate the partitioning method.""" + valid_methods = [ + "random", + "stratified_target", + "stratified_class_reduction", + "rc_split", + ] + if method not in valid_methods: + raise ValueError( + f"Invalid method '{method}'." + + f" Valid options are: {', '.join(valid_methods)}" + ) + return method + def fit( self, ) -> Union[ @@ -79,7 +112,7 @@ def fit( """ if self.method == "random": logger.info("Partition data using random approach") - splitter = RandomPartition( + splitter = RandomSplit( data=self.data, test_size=self.test_size, class_column=self.class_column, @@ -89,7 +122,7 @@ def fit( return splitter.fit() elif self.method == "stratified_target": logger.info("Partition data using stratify approach") - splitter = StratifiedPartition( + splitter = StratifiedSplit( data=self.data, test_size=self.test_size, class_column=self.class_column, @@ -98,7 +131,7 @@ def fit( return splitter.fit() elif self.method == "stratified_class_reduction": logger.info("Partition data using stratify reduction approach") - splitter = StratifiedReductionPartition( + splitter = StratifiedReductionSplit( data=self.data, test_size=self.test_size, class_column=self.class_column, @@ -107,6 +140,15 @@ def fit( keep_data=self.keep_data, ) + return splitter.fit() + elif self.method == "rc_split": + logger.info("Partition data using reaction center approach") + splitter = RCSplit( + data=self.data, + test_size=self.test_size, + class_column=self.class_column, + random_state=self.random_state, + ) return splitter.fit() else: logger.error( diff --git a/synutility/Partition/random_parition.py b/synutility/SynSplit/random_split.py similarity index 94% rename from synutility/Partition/random_parition.py rename to synutility/SynSplit/random_split.py index 088089a..b454813 100644 --- a/synutility/Partition/random_parition.py +++ b/synutility/SynSplit/random_split.py @@ -2,7 +2,7 @@ from sklearn.model_selection import train_test_split -class RandomPartition: +class RandomSplit: """ Class for partitioning data into training and test sets based on specified class columns. @@ -17,7 +17,7 @@ def __init__( self, data: DataFrame, test_size: float, class_column: str, random_state: int ) -> None: """ - Initializes the RandomPartition with the given data and parameters.
+ Initializes the RandomSplit with the given data and parameters. Parameters: - data (DataFrame): The dataset to be partitioned. diff --git a/synutility/SynSplit/rc_split.py b/synutility/SynSplit/rc_split.py new file mode 100644 index 0000000..4dfaedc --- /dev/null +++ b/synutility/SynSplit/rc_split.py @@ -0,0 +1,85 @@ +import numpy as np +from pandas import DataFrame +import random + + +class RCSplit: + """ + Class for partitioning data into training and test sets based on a specified class column, + ensuring that the cumulative percentage of data points in the test set does not exceed the specified test_size. + + Attributes: + data (DataFrame): The dataset to partition. + test_size (float): The maximum proportion of the dataset to include in the test split. + class_column (str): The name of the column in `data` that represents the class labels. + random_state (int): Controls the shuffling applied to the data before applying the split. + """ + + def __init__( + self, data: DataFrame, test_size: float, class_column: str, random_state: int + ) -> None: + """ + Initializes the RCSplit with the given data and parameters. + """ + self.data = data + self.test_size = test_size + self.class_column = class_column + self.random_state = random_state + random.seed(random_state) # Set random seed for Python's random module + np.random.seed(random_state) # Set random seed for NumPy's random functions + + @staticmethod + def random_select(data: dict, threshold: float, random_state: int = 42) -> list: + """ + Selects keys from a dictionary randomly without exceeding a specified + cumulative threshold. Continues to search for fitting classes even if one class exceeds the threshold. + + Parameters: + - data (dict): A dictionary where keys are identifiers and values are weights. + - threshold (float): The cumulative weight not to exceed. + - random_state (int): Seed for the random number generator. + + Returns: + list: A list of selected keys. + """ + random.seed(random_state) + keys = list(data.keys()) + random.shuffle(keys) # Shuffle keys to ensure randomness in selection order + selected_keys = [] + cumulative_sum = 0.0 + skipped_keys = [] # Temporarily hold keys that are too big at first glance + + for key in keys: + if cumulative_sum + data[key] <= threshold: + cumulative_sum += data[key] + selected_keys.append(key) + else: + skipped_keys.append(key) + + # After initial pass, try adding any skipped keys that might fit now + random.shuffle( + skipped_keys + ) # Reshuffle the skipped keys before trying to add them + for key in skipped_keys: + if cumulative_sum + data[key] <= threshold: + cumulative_sum += data[key] + selected_keys.append(key) + + return selected_keys + + def fit(self) -> tuple[DataFrame, DataFrame]: + """ + Partitions the dataset into training and testing datasets. 
+ """ + class_counts = ( + self.data[self.class_column].value_counts(normalize=True).to_dict() + ) + print(class_counts) + test_class_list = RCSplit.random_select( + class_counts, self.test_size, self.random_state + ) + + train_set = self.data[~self.data[self.class_column].isin(test_class_list)] + test_set = self.data[self.data[self.class_column].isin(test_class_list)] + + return train_set, test_set diff --git a/synutility/Partition/stratified_reduction_partition.py b/synutility/SynSplit/stratified_reduction_partition.py similarity index 96% rename from synutility/Partition/stratified_reduction_partition.py rename to synutility/SynSplit/stratified_reduction_partition.py index 4aacbeb..ed6ffd2 100644 --- a/synutility/Partition/stratified_reduction_partition.py +++ b/synutility/SynSplit/stratified_reduction_partition.py @@ -2,10 +2,10 @@ import pandas as pd from typing import Tuple from sklearn.preprocessing import LabelEncoder -from synutility.Partition.stratified_partition import StratifiedPartition +from synutility.SynSplit.stratified_split import StratifiedSplit -class StratifiedReductionPartition: +class StratifiedReductionSplit: """ A class for partitioning a dataset into training and test sets with the additional functionality of selectively reducing the presence of certain classes based @@ -82,7 +82,7 @@ def fit(self) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: selected classes removed, the test data, and the subset of the training data that contains the removed classes. """ - splitter = StratifiedPartition( + splitter = StratifiedSplit( self.data, self.test_size, self.class_column, self.random_state ) data_train, data_test = splitter.fit() diff --git a/synutility/Partition/stratified_partition.py b/synutility/SynSplit/stratified_split.py similarity index 95% rename from synutility/Partition/stratified_partition.py rename to synutility/SynSplit/stratified_split.py index 09925d3..b2ba03d 100644 --- a/synutility/Partition/stratified_partition.py +++ b/synutility/SynSplit/stratified_split.py @@ -3,7 +3,7 @@ from sklearn.model_selection import train_test_split -class StratifiedPartition: +class StratifiedSplit: """ Class for partitioning data into training and test sets using stratified sampling based on the specified class column to maintain the proportion of classes in each subset. @@ -19,7 +19,7 @@ def __init__( self, data: DataFrame, test_size: float, class_column: str, random_state: int ) -> None: """ - Initializes the StratifiedPartition instance with the given data and parameters. + Initializes the StratifiedSplit instance with the given data and parameters. Parameters: data (DataFrame): The dataset to be partitioned. 
diff --git a/synutility/SynVis/__init__.py b/synutility/SynVis/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/synutility/SynVis/chemical_graph_visualizer.py b/synutility/SynVis/chemical_graph_visualizer.py new file mode 100644 index 0000000..f4abbd5 --- /dev/null +++ b/synutility/SynVis/chemical_graph_visualizer.py @@ -0,0 +1,378 @@ +import random +import networkx as nx +import matplotlib.pyplot as plt +import seaborn as sns +from typing import Optional, Dict +import logging + +color_scheme = { + "H": "#FFFFFF", # White + "C": "#909090", # Gray + "N": "#3050F8", # Blue + "O": "#FF0D0D", # Red + "F": "#90E050", # Green + "Cl": "#1FF01F", # Green + "Br": "#A62929", # Dark Red/Brown + "I": "#940094", # Purple + "P": "#FF8000", # Orange + "S": "#FFFF30", # Yellow + "Na": "#E0E0E0", # Light Gray + "K": "#8F40D4", # Light Purple + "Ca": "#3DFF00", # Light Green + "Mg": "#8AFF00", # Light Green + "Fe": "#B7410E", # Rust Red + "Zn": "#7D80B0", # Light Blue + "Cu": "#C88033", # Copper Red/Orange + "Ag": "#C0C0C0", # Light Gray + "Au": "#FFD123", # Gold Yellow + "Hg": "#B8B8D0", # Silver Gray + "Pb": "#575961", # Dark Gray + "Al": "#BFA6A6", # Light Gray + "Si": "#F0C8A0", # Light Brown + "B": "#FFA1A1", # Pink + "As": "#BD80E3", # Light Gray + "Sb": "#9E63B5", # Dark Gray + "Se": "#FFA100", # Light Pink + "Te": "#D47A00", # Gray + "Cd": "#FFD98F", # Light Blue/Gray + "Ti": "#BFC2C7", # Light Gray + "V": "#A6A6AB", # Light Gray/Blue + "Cr": "#8A99C7", # Steel Gray + "Mn": "#9C7AC7", # Gray + "Co": "#FF7A00", # Light Pink + "Ni": "#4DFF4D", # Light Green +} + + +class ChemicalGraphVisualizer: + def __init__( + self, + seed: Optional[int] = None, + element_colors: Optional[Dict[str, str]] = color_scheme, + ): + """ + Initialize the visualizer with optional seed and color scheme. + + Parameters: + seed (int, optional): Seed for random number generator for reproducibility. + element_colors (dict, optional): Dictionary mapping elements to their color codes. + """ + # Define a popular color scheme in chemistry if not provided + if element_colors is None: + self.element_colors = { + "H": "#FFFFFF", # White + "C": "#909090", # Gray + "N": "#3050F8", # Blue + "O": "#FF0D0D", # Red + "F": "#90E050", # Green + "Cl": "#1FF01F", # Green + # Additional elements can be added here + } + else: + self.element_colors = element_colors + self.seed = seed + + def graph_vis( + self, + G: nx.Graph, + node_size: int = 100, + visualize_edge_weight: bool = False, + edge_font_size: int = 10, + show_node_labels: bool = False, + node_label_font_size: int = 12, + ax: Optional[plt.Axes] = None, + ) -> None: + """ + Visualize a NetworkX graph with standard representation. + + Parameters: + G (nx.Graph): The graph to visualize. + node_size (int): The size of the nodes. + visualize_edge_weight (bool): Whether to display edge weights. + edge_font_size (int): Font size for edge labels. + show_node_labels (bool): Whether to show labels on the nodes. + node_label_font_size (int): Font size for node labels. 
+ """ + # Set random seed for reproducibility + if self.seed is not None: + random.seed(self.seed) + + # Get colors for each node + node_colors = [ + self.element_colors.get(G.nodes[node]["element"], "#000000") + for node in G.nodes() + ] + + # Draw the graph + pos = nx.spring_layout(G, seed=self.seed) # Use spring layout + + if ax is None: + ax = plt.gca() # Get current axes if not provided + + if show_node_labels: + node_labels = {node: G.nodes[node]["element"] for node in G.nodes()} + nx.draw( + G, + pos, + ax=ax, + with_labels=True, + labels=node_labels, + node_color=node_colors, + node_size=node_size, + font_size=node_label_font_size, + # font_weight="semi-bold", + ) + else: + nx.draw( + G, + pos, + ax=ax, + with_labels=False, + node_color=node_colors, + node_size=node_size, + # font_weight="bold", + ) + + # Get edge labels if needed + if visualize_edge_weight: + edge_labels = {(u, v): G.edges[u, v]["order"] for u, v in G.edges()} + nx.draw_networkx_edge_labels( + G, pos, ax=ax, edge_labels=edge_labels, font_size=edge_font_size + ) + + def its_vis( + self, + G: nx.Graph, + node_size: int = 100, + show_node_labels: bool = False, + node_label_font_size: int = 12, + ax: Optional[plt.Axes] = None, + ) -> None: + """ + Visualize a NetworkX graph with edge colors indicating bond changes. + + Parameters: + G (nx.Graph): The graph to visualize. + node_size (int): The size of the nodes. + show_node_labels (bool): Whether to show labels on the nodes. + node_label_font_size (int): Font size for node labels. + """ + # Set random seed for reproducibility + if self.seed is not None: + random.seed(self.seed) + + # Draw the graph + pos = nx.spring_layout(G, seed=self.seed) # Use spring layout + # Get colors for each node + node_colors = [ + self.element_colors.get(G.nodes[node]["element"], "#000000") + for node in G.nodes() + ] + + # Determine edge colors based on 'order' + edge_colors = [] + for u, v, data in G.edges(data=True): + order = data.get("standard_order", 0) + # order = tuple(0 if isinstance(x, str) else float(x) for x in order) + if order == 0: + edge_colors.append("black") # Normal bond + elif order < 0: + edge_colors.append("blue") # Increasing bond + else: + edge_colors.append("red") # Breaking bond + + if ax is None: + ax = plt.gca() # Get current axes if not provided + + if show_node_labels: + node_labels = {node: G.nodes[node]["element"] for node in G.nodes()} + nx.draw( + G, + pos, + ax=ax, + with_labels=True, + labels=node_labels, + node_color=node_colors, + node_size=node_size, + font_size=node_label_font_size, + # font_weight="bold", + edge_color=edge_colors, + ) + else: + nx.draw( + G, + pos, + ax=ax, + with_labels=False, + node_color=node_colors, + node_size=node_size, + font_weight="bold", + edge_color=edge_colors, + ) + + def vis_three_graph( + self, + graph_tuple, + figsize=(15, 5), + left_graph_title="Reactants", + k_graph_title="ITS Graph", + right_graph_title="Products", + show_node_labels=True, + title_fontsize=24, + title_weight="bold", + save_path=None, + display_inline=False, + log=False, + ): + """ + Visualize reactants, ITS graph, and products in one figure. + + Parameters: + graph_tuple (tuple): Tuple of NetworkX graphs (reactants, ITS graph, products). + figsize (tuple): Figure size in inches (width, height). + left_graph_title (str): Title for the left subplot. + k_graph_title (str): Title for the middle subplot. + right_graph_title (str): Title for the right subplot. + show_node_labels (bool): If True, show node labels on the graphs. 
+ title_fontsize (int): Font size for subplot titles. + title_weight (str): Font weight for subplot titles. + save_path (str, optional): Path to save the figure to file. + display_inline (bool): If True, display the figure inline in the notebook. + log (bool): If True, enable logging of function progress. + """ + if log: + logging.basicConfig(level=logging.INFO) + + try: + # Unpack the tuple + reactants_graph, products_graph, its_graph = graph_tuple + + # Create a figure with subplots + fig, axs = plt.subplots(1, 3, figsize=figsize) + + # Visualize each graph on its respective subplot + self.graph_vis( + reactants_graph, ax=axs[0], show_node_labels=show_node_labels + ) + self.its_vis(its_graph, ax=axs[1], show_node_labels=show_node_labels) + self.graph_vis(products_graph, ax=axs[2], show_node_labels=show_node_labels) + + # Set titles for subplots + axs[0].set_title( + left_graph_title, fontsize=title_fontsize, weight=title_weight + ) + axs[1].set_title( + k_graph_title, fontsize=title_fontsize, weight=title_weight + ) + axs[2].set_title( + right_graph_title, fontsize=title_fontsize, weight=title_weight + ) + + plt.tight_layout() + + if save_path is not None: + plt.savefig(save_path, dpi=600) + if log: + logging.info(f"Figure saved to {save_path}") + + if display_inline: + plt.show() + else: + plt.close(fig) + + return fig + except Exception as e: + if log: + logging.error("Failed to visualize graphs: ", exc_info=True) + raise RuntimeError("Error in graph visualization: ") from e + + def visualize_all( + self, + graph_tuple_row1, + graph_tuple_row2, + figsize=(15, 10), + titles_row1=("A. Reactant Graph", "B. ITS Graph", "C. Products"), + titles_row2=("D. L Graph", "E. K Graph", "F. R Graph"), + show_node_labels=True, + show_grid=True, + grid_style="--", + title_fontsize=24, + title_weight="bold", + save_path=None, + display_inline=False, + log=False, + ): + """ + Visualize two rows of graphs, each with three graphs, optionally displaying a + grid. + + Parameters: + graph_tuple_row1 (tuple): Tuple of NetworkX graphs for the first row + (reactants, products, ITS graph). + graph_tuple_row2 (tuple): Tuple of NetworkX graphs for the second row (L, R, + K). + figsize (tuple): Figure size in inches (width, height). + titles_row1 (tuple): Titles for the first row subplots. + titles_row2 (tuple): Titles for the second row subplots. + show_node_labels (bool): If True, show node labels on the graphs. + show_grid (bool): If True, display grid lines on the plots. + grid_style (str): Style of the grid lines. + title_fontsize (int): Font size for subplot titles. + title_weight (str): Font weight for subplot titles. + save_path (str, optional): Path to save the figure to file. + display_inline (bool): If True, display the figure inline in the notebook. + log (bool): If True, enable logging of function progress.
+ """ + if log: + logging.basicConfig(level=logging.INFO) + + try: + sns.set_theme(style="darkgrid") # Set the Seaborn style + reactants_graph, products_graph, its_graph = graph_tuple_row1 + l_graph, r_graph, k_graph = graph_tuple_row2 + + # Create a figure with subplots + fig, axs = plt.subplots(2, 3, figsize=figsize) + + # Visualize each graph on its respective subplot (first row) + self.graph_vis( + reactants_graph, ax=axs[0, 0], show_node_labels=show_node_labels + ) + self.its_vis(its_graph, ax=axs[0, 1], show_node_labels=show_node_labels) + self.graph_vis( + products_graph, ax=axs[0, 2], show_node_labels=show_node_labels + ) + + # Visualize each graph on its respective subplot (second row) + self.graph_vis(l_graph, ax=axs[1, 0], show_node_labels=show_node_labels) + self.its_vis(k_graph, ax=axs[1, 1], show_node_labels=show_node_labels) + self.graph_vis(r_graph, ax=axs[1, 2], show_node_labels=show_node_labels) + + # Set titles and enable grid for subplots + for ax, title in zip(axs[0], titles_row1): + ax.set_title(title, fontsize=title_fontsize, weight=title_weight) + if show_grid: + ax.grid(True, linestyle=grid_style, which="both") + + for ax, title in zip(axs[1], titles_row2): + ax.set_title(title, fontsize=title_fontsize, weight=title_weight) + if show_grid: + ax.grid(True, linestyle=grid_style, which="both") + + plt.tight_layout() + + if save_path is not None: + plt.savefig(save_path, dpi=600) + if log: + logging.info(f"Figure saved to {save_path}") + + if display_inline: + plt.show() + else: + plt.close(fig) + + return fig + except Exception as e: + if log: + logging.error("Failed to visualize graphs: ", exc_info=True) + raise RuntimeError("Error in graph visualization: ") from e diff --git a/synutility/SynVis/chemical_reaction_visualizer.py b/synutility/SynVis/chemical_reaction_visualizer.py new file mode 100644 index 0000000..a26ed13 --- /dev/null +++ b/synutility/SynVis/chemical_reaction_visualizer.py @@ -0,0 +1,133 @@ +from typing import List, Dict +from IPython.display import display, HTML, SVG +from rdkit.Chem.Draw import rdMolDraw2D +from rdkit.Chem import rdChemReactions + + +class ChemicalReactionVisualizer: + @staticmethod + def create_html_table_with_svgs( + svg_list: List[str], + titles: List[str], + num_cols: int = 2, + orientation: str = "vertical", + title_size: int = 16, + ) -> HTML: + """ + Creates an HTML table to display SVG images with titles in + a structured 'subplot-like' layout. + + Parameters: + - svg_list (List[str]): List of SVG content strings. + - titles (List[str]): Corresponding titles for each SVG image. + - num_cols (int): Defines the number of columns for the + 'vertical' layout or rows for 'horizontal' layout. + - orientation (str): Layout orientation of images ('vertical' or 'horizontal'). + - title_size (int): Font size of the titles displayed above each image. + + Returns: + - HTML: HTML object to be displayed within an IPython notebook environment. + """ + html = "
<table style='width: 100%; border-collapse: collapse;'>"
+        if orientation == "vertical":
+            # Fill the table row by row, num_cols titled images per row.
+            for i in range(0, len(svg_list), num_cols):
+                html += "<tr>"
+                for j in range(num_cols):
+                    if i + j < len(svg_list):
+                        html += (
+                            f"<td style='text-align: center; padding: 8px;'>"
+                            f"<div style='font-size: {title_size}px; font-weight: bold;'>"
+                            f"{titles[i+j]}</div>{svg_list[i+j]}</td>"
+                        )
+                html += "</tr>"
+        else:
+            # Horizontal layout: all titled images side by side in one row.
+            html += "<tr>"
+            for i in range(len(svg_list)):
+                html += (
+                    f"<td style='text-align: center; padding: 8px;'>"
+                    f"<div style='font-size: {title_size}px; font-weight: bold;'>"
+                    f"{titles[i]}</div>{svg_list[i]}</td>"
+                )
+            html += "</tr>"
+        html += "</table>"
+        return HTML(html)