From b847438b9cd417f3eb224c8f86611d96ff944c06 Mon Sep 17 00:00:00 2001 From: TieuLongPhan Date: Wed, 23 Oct 2024 09:27:39 +0200 Subject: [PATCH 01/10] update --- synutility/SynIO/data_type.py | 42 ++++++++++++++++++++++++++++++++++- 1 file changed, 41 insertions(+), 1 deletion(-) diff --git a/synutility/SynIO/data_type.py b/synutility/SynIO/data_type.py index e2e7d4a..e78bfa0 100644 --- a/synutility/SynIO/data_type.py +++ b/synutility/SynIO/data_type.py @@ -3,7 +3,7 @@ import numpy as np from numpy import ndarray from joblib import dump, load -from typing import List, Dict, Any +from typing import List, Dict, Any, Generator from synutility.SynIO.debug import setup_logging logger = setup_logging() @@ -211,3 +211,43 @@ def load_dict_from_json(file_path: str) -> dict: except Exception as e: logger.error(e) return None + + +def load_from_pickle_generator(file_path: str) -> Generator[Any, None, None]: + """ + A generator that yields items from a pickle file where each pickle load returns a list + of dictionaries. + + Paremeters: + - file_path (str): The path to the pickle file to load. + + - Yields: + Any: Yields a single item from the list of dictionaries stored in the pickle file. + """ + with open(file_path, "rb") as file: + while True: + try: + batch_items = pickle.load(file) + for item in batch_items: + yield item + except EOFError: + break +def collect_data(num_batches: int, temp_dir: str, file_template: str) -> List[Any]: + """ + Collects and aggregates data from multiple pickle files into a single list. + + Paremeters: + - num_batches (int): The number of batch files to process. + - temp_dir (str): The directory where the batch files are stored. + - file_template (str): The template string for batch file names, expecting an integer + formatter. + + Returns: + List[Any]: A list of aggregated data items from all batch files. + """ + collected_data: List[Any] = [] + for i in range(num_batches): + file_path = os.path.join(temp_dir, file_template.format(i)) + for item in load_from_pickle_generator(file_path): + collected_data.append(item) + return collected_data \ No newline at end of file From 4d858d94e0cdd04e26adb6b5eb8452e9bf6fd424 Mon Sep 17 00:00:00 2001 From: TieuLongPhan Date: Mon, 18 Nov 2024 09:33:28 +0100 Subject: [PATCH 02/10] test mod compatible --- .github/workflows/test-and-lint.yml | 3 +- Test/SynGraph/Transform/__init__.py | 0 Test/SynGraph/Transform/test_core_engine.py | 114 ++++++++++++++ lint.sh | 2 +- requirements.txt | 4 +- synutility/SynChem/Reaction/standardize.py | 31 +++- synutility/SynGraph/Transform/__init__.py | 0 synutility/SynGraph/Transform/core_engine.py | 156 +++++++++++++++++++ synutility/SynIO/data_type.py | 48 ++++++ 9 files changed, 353 insertions(+), 5 deletions(-) create mode 100644 Test/SynGraph/Transform/__init__.py create mode 100644 Test/SynGraph/Transform/test_core_engine.py create mode 100644 synutility/SynGraph/Transform/__init__.py create mode 100644 synutility/SynGraph/Transform/core_engine.py diff --git a/.github/workflows/test-and-lint.yml b/.github/workflows/test-and-lint.yml index 3fb87fb..bb09300 100644 --- a/.github/workflows/test-and-lint.yml +++ b/.github/workflows/test-and-lint.yml @@ -7,7 +7,7 @@ name: Test & Lint on: push: - branches: [ "dev" ] + branches: [ "dev", "maintain" ] pull_request: branches: [ "main" ] @@ -34,6 +34,7 @@ jobs: run: | conda create --name synutils-env python=3.11 -y conda activate synutils-env + conda install -c jakobandersen -c conda-forge mod pip install flake8 pytest if [ -f requirements.txt ]; then pip install -r requirements.txt; fi shell: bash -l {0} diff --git a/Test/SynGraph/Transform/__init__.py b/Test/SynGraph/Transform/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/Test/SynGraph/Transform/test_core_engine.py b/Test/SynGraph/Transform/test_core_engine.py new file mode 100644 index 0000000..0a8cf95 --- /dev/null +++ b/Test/SynGraph/Transform/test_core_engine.py @@ -0,0 +1,114 @@ +import os +import pytest +import unittest +import tempfile +from synutility.SynGraph.Transform.core_engine import CoreEngine + + +@pytest.mark.skip(reason="Temporarily disabled for demonstration purposes") +class TestCoreEngine(unittest.TestCase): + def setUp(self): + # Create a temporary directory + self.temp_dir = tempfile.TemporaryDirectory() + + # Path for the rule file + self.rule_file_path = os.path.join(self.temp_dir.name, "test_rule.gml") + + # Define rule content + self.rule_content = """ + rule [ + ruleID "1" + left [ + edge [ source 1 target 2 label "=" ] + edge [ source 3 target 4 label "-" ] + ] + context [ + node [ id 1 label "C" ] + node [ id 2 label "C" ] + node [ id 3 label "H" ] + node [ id 4 label "H" ] + ] + right [ + edge [ source 1 target 2 label "-" ] + edge [ source 1 target 3 label "-" ] + edge [ source 2 target 4 label "-" ] + ] + ] + """ + + # Write rule content to the temporary file + with open(self.rule_file_path, "w") as rule_file: + rule_file.write(self.rule_content) + + # Initialize SMILES strings for testing + self.initial_smiles_fw = ["CC=CC", "[HH]"] + self.initial_smiles_bw = ["CCCC"] + + def tearDown(self): + # Clean up temporary directory + self.temp_dir.cleanup() + + def test_perform_reaction_forward(self): + # Test the perform_reaction method with forward reaction type + result = CoreEngine.perform_reaction( + rule_file_path=self.rule_file_path, + initial_smiles=self.initial_smiles_fw, + prediction_type="forward", + print_results=False, + verbosity=0, + ) + print(result) + # Check if result is a list of strings and has content + self.assertIsInstance( + result, list, "Expected a list of reaction SMILES strings." + ) + self.assertTrue( + len(result) > 0, "Result should contain reaction SMILES strings." + ) + + self.assertEqual(result[0], "CC=CC.[HH]>>CCCC") + + # Check if the result SMILES format matches expected output format + for reaction_smiles in result: + self.assertIn(">>", reaction_smiles, "Reaction SMILES format is incorrect.") + parts = reaction_smiles.split(">>") + self.assertEqual( + parts[0], + ".".join(self.initial_smiles_fw), + "Base SMILES are not correctly formatted.", + ) + self.assertTrue(len(parts[1]) > 0, "Product SMILES should be non-empty.") + + def test_perform_reaction_backward(self): + # Test the perform_reaction method with backward reaction type + result = CoreEngine.perform_reaction( + rule_file_path=self.rule_file_path, + initial_smiles=self.initial_smiles_bw, + prediction_type="backward", + print_results=False, + verbosity=0, + ) + # Check if result is a list of strings and has content + self.assertIsInstance( + result, list, "Expected a list of reaction SMILES strings." + ) + self.assertTrue( + len(result) > 0, "Result should contain reaction SMILES strings." + ) + self.assertEqual(result[0], "C=CCC.[H][H]>>CCCC") + self.assertEqual(result[1], "[H][H].C(C)=CC>>CCCC") + + # Check if the result SMILES format matches expected output format + for reaction_smiles in result: + self.assertIn(">>", reaction_smiles, "Reaction SMILES format is incorrect.") + parts = reaction_smiles.split(">>") + self.assertTrue(len(parts[0]) > 0, "Product SMILES should be non-empty.") + self.assertEqual( + parts[1], + ".".join(self.initial_smiles_bw), + "Base SMILES are not correctly formatted.", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/lint.sh b/lint.sh index 034fdf8..e2178df 100755 --- a/lint.sh +++ b/lint.sh @@ -2,5 +2,5 @@ flake8 . --count --max-complexity=13 --max-line-length=120 \ --per-file-ignores="__init__.py:F401, chemical_reaction_visualizer.py:E501, test_reagent.py:E501" \ - --exclude venv \ + --exclude venv,core_engine.py \ --statistics diff --git a/requirements.txt b/requirements.txt index f4e805f..35d3b27 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,10 +1,10 @@ scikit-learn>=1.4.0 xgboost==2.1.1 -pandas==1.5.3 seaborn==0.13.2 drfp==0.3.6 fgutils>=0.1.3 rxn-chem-utils==1.5.0 rxn-utils==2.0.0 rxnmapper==0.3.0 -rdkit >= 2024.3.3 \ No newline at end of file +rdkit >= 2024.3.3 +pandas>=2.2.0 \ No newline at end of file diff --git a/synutility/SynChem/Reaction/standardize.py b/synutility/SynChem/Reaction/standardize.py index 2a48090..08753ca 100644 --- a/synutility/SynChem/Reaction/standardize.py +++ b/synutility/SynChem/Reaction/standardize.py @@ -1,5 +1,5 @@ from rdkit import Chem -from typing import List, Optional +from typing import List, Optional, Tuple class Standardize: @@ -126,3 +126,32 @@ def fit( rsmi = self.standardize_rsmi(rsmi, not ignore_stereo) rsmi = rsmi.replace("[HH]", "[H][H]") return rsmi + + @staticmethod + def categorize_reactions( + reactions: List[str], target_reaction: str + ) -> Tuple[List[str], List[str]]: + """ + Sorts a list of reaction SMILES strings into two groups based on + their match with a specified target reaction. The categorization process + distinguishes between reactions that align with the target reaction + and those that do not. + + Parameters: + - reactions (List[str]): The array of reaction SMILES strings to be categorized. + - target_reaction (str): The SMILES string of the target reaction + used as the benchmark for categorization. + + Returns: + - Tuple[List[str], List[str]]: A pair of lists, where the first contains + reactions matching the target and the second + comprises non-matching reactions. + """ + match, not_match = [], [] + target_reaction = Standardize.standardize_rsmi(target_reaction, stereo=False) + for reaction_smiles in reactions: + if reaction_smiles == target_reaction: + match.append(reaction_smiles) + else: + not_match.append(reaction_smiles) + return match, not_match diff --git a/synutility/SynGraph/Transform/__init__.py b/synutility/SynGraph/Transform/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/synutility/SynGraph/Transform/core_engine.py b/synutility/SynGraph/Transform/core_engine.py new file mode 100644 index 0000000..9b1c776 --- /dev/null +++ b/synutility/SynGraph/Transform/core_engine.py @@ -0,0 +1,156 @@ +from typing import List +from synutility.SynIO.data_type import load_gml_as_text +from rdkit import Chem +from copy import deepcopy +import torch +from mod import * + + +class CoreEngine: + """ + The MØDModeling class encapsulates functionalities for reaction modeling using the MØD + toolkit. It provides methods for forward and backward prediction based on templates + library. + """ + + @staticmethod + def generate_reaction_smiles( + temp_results: List[str], base_smiles: str, is_forward: bool = True + ) -> List[str]: + """ + Constructs reaction SMILES strings from intermediate results using a base SMILES + string, indicating whether the process is a forward or backward reaction. This + function iterates over a list of intermediate SMILES strings, combines them with + the base SMILES, and formats them into complete reaction SMILES strings. + + Parameters: + - temp_results (List[str]): Intermediate SMILES strings resulting from partial + reactions or combinations. + - base_smiles (str): The SMILES string representing the starting point of the + reaction, either as reactants or products, depending on the reaction direction. + - is_forward (bool, optional): Flag to determine the direction of the reaction; + 'True' for forward reactions where 'base_smiles' are reactants, and 'False' for + backward reactions where 'base_smiles' are products. Defaults to True. + + Returns: + - List[str]: A list of complete reaction SMILES strings, formatted according to + the specified reaction direction. + """ + results = [] + for comb in temp_results: + joined_smiles = ".".join(comb) + reaction_smiles = ( + f"{base_smiles}>>{joined_smiles}" + if is_forward + else f"{joined_smiles}>>{base_smiles}" + ) + results.append(reaction_smiles) + return results + + @staticmethod + def perform_reaction( + rule_file_path: str, + initial_smiles: List[str], + prediction_type: str = "forward", + print_results: bool = False, + verbosity: int = 0, + ) -> List[str]: + """ + Applies a specified reaction rule, loaded from a GML file, to a set of initial + molecules represented by SMILES strings. The reaction can be simulated in forward + or backward direction and repeated multiple times. + + Parameters: + - rule_file_path (str): Path to the GML file containing the reaction rule. + - initial_smiles (List[str]): Initial molecules represented as SMILES strings. + - type (str, optional): Direction of the reaction ('forward' for forward, + 'backward' for backward). Defaults to 'forward'. + - print_results (bool): Print results in latex or not. Defaults to False. + + Returns: + - List[str]: SMILES strings of the resulting molecules or reactions. + """ + + # Determine the rule inversion based on reaction type + invert_rule = prediction_type == "backward" + # Convert SMILES strings to molecule objects, avoiding duplicate conversions + initial_molecules = [smiles(smile, add=False) for smile in (initial_smiles)] + + def deduplicateGraphs(initial): + res = [] + for cand in initial: + for a in res: + if cand.isomorphism(a) != 0: + res.append(a) # the one we had already + break + else: + # didn't find any isomorphic, use the new one + res.append(cand) + return res + + initial_molecules = deduplicateGraphs(initial_molecules) + + initial_molecules = sorted( + initial_molecules, key=lambda molecule: molecule.numVertices, reverse=False + ) + # Load the reaction rule from the GML file + gml_content = load_gml_as_text(rule_file_path) + reaction_rule = ruleGMLString(gml_content, invert=invert_rule, add=False) + # Initialize the derivation graph and execute the strategy + dg = DG(graphDatabase=initial_molecules) + config.dg.doRuleIsomorphismDuringBinding = False + dg.build().apply(initial_molecules, reaction_rule, verbosity=verbosity) + if print_results: + dg.print() + + temp_results = [] + for e in dg.edges: + productSmiles = [v.graph.smiles for v in e.targets] + temp_results.append(productSmiles) + + if len(temp_results) == 0: + dg = DG(graphDatabase=initial_molecules) + # dg.build().execute(strategy, verbosity=8) + config.dg.doRuleIsomorphismDuringBinding = False + dg.build().apply( + initial_molecules, reaction_rule, verbosity=verbosity, onlyProper=False + ) + temp_results, small_educt = [], [] + for edge in dg.edges: + temp_results.append([vertex.graph.smiles for vertex in edge.targets]) + small_educt.extend([vertex.graph.smiles for vertex in edge.sources]) + + small_educt_set = [ + Chem.CanonSmiles(smile) for smile in small_educt if smile is not None + ] + + reagent = deepcopy(initial_smiles) + for value in small_educt_set: + if value in reagent: + reagent.remove(value) + + # Update solutions with reagents and normalize SMILES + for solution in temp_results: + solution.extend(reagent) + for i, smile in enumerate(solution): + try: + mol = Chem.MolFromSmiles(smile) + if mol: # Only convert if mol creation was successful + solution[i] = Chem.MolToSmiles(mol) + except Exception as e: + print(f"Error processing SMILES {smile}: {str(e)}") + + reaction_processing_map = { + "forward": lambda smiles: CoreEngine.generate_reaction_smiles( + temp_results, ".".join(initial_smiles), is_forward=True + ), + "backward": lambda smiles: CoreEngine.generate_reaction_smiles( + temp_results, ".".join(initial_smiles), is_forward=False + ), + } + + # Use the reaction type to select the appropriate processing function and apply it + if prediction_type in reaction_processing_map: + return reaction_processing_map[prediction_type](initial_smiles) + else: + return "" diff --git a/synutility/SynIO/data_type.py b/synutility/SynIO/data_type.py index ea648a5..accc219 100644 --- a/synutility/SynIO/data_type.py +++ b/synutility/SynIO/data_type.py @@ -254,3 +254,51 @@ def collect_data(num_batches: int, temp_dir: str, file_template: str) -> List[An for item in load_from_pickle_generator(file_path): collected_data.append(item) return collected_data + + +def merge_dicts( + list1: List[Dict[str, Any]], + list2: List[Dict[str, Any]], + key: str, + intersection: bool = True, +) -> List[Dict[str, Any]]: + """ + Merges two lists of dictionaries based on a specified key, with an option to + either merge only dictionaries with matching key values (intersection) or + all dictionaries (union). + + Parameters: + - list1 (List[Dict[str, Any]]): The first list of dictionaries. + - list2 (List[Dict[str, Any]]): The second list of dictionaries. + - key (str): The key used to match and merge dictionaries from both lists. + - intersection (bool): If True, only merge dictionaries with matching key values; + if False, merge all dictionaries, combining those with matching key values. + + Returns: + - List[Dict[str, Any]]: A list of dictionaries with merged contents from both + input lists according to the specified merging strategy. + """ + dict1 = {item[key]: item for item in list1} + dict2 = {item[key]: item for item in list2} + + if intersection: + # Intersection of keys: only keys present in both dictionaries are merged + merged_list = [] + for item1 in list1: + r_id = item1.get(key) + if r_id in dict2: + merged_item = {**item1, **dict2[r_id]} + merged_list.append(merged_item) + return merged_list + else: + # Union of keys: all keys from both dictionaries are merged + merged_dict = {} + all_keys = set(dict1) | set(dict2) + for k in all_keys: + if k in dict1 and k in dict2: + merged_dict[k] = {**dict1[k], **dict2[k]} + elif k in dict1: + merged_dict[k] = dict1[k] + else: + merged_dict[k] = dict2[k] + return list(merged_dict.values()) From d64875e459f9ed9bbe9fce416f302eed626bbe32 Mon Sep 17 00:00:00 2001 From: TieuLongPhan Date: Mon, 18 Nov 2024 12:12:03 +0100 Subject: [PATCH 03/10] add graph visualizer --- .coverage | Bin 0 -> 77824 bytes .gitignore | 1 + README.md | 1 + synutility/SynIO/Format/graph_to_mol.py | 17 +- synutility/SynIO/Format/mol_to_graph.py | 157 +++++++++------ synutility/SynIO/Format/smi_to_graph.py | 68 +++++++ synutility/SynVis/graph_visualizer.py | 245 ++++++++++++++++++++++++ synutility/SynVis/pdf_writer.py | 137 +++++++++++++ synutility/SynVis/rsmi_to_fig.py | 76 ++++++++ 9 files changed, 640 insertions(+), 62 deletions(-) create mode 100644 .coverage create mode 100644 synutility/SynIO/Format/smi_to_graph.py create mode 100644 synutility/SynVis/graph_visualizer.py create mode 100644 synutility/SynVis/pdf_writer.py create mode 100644 synutility/SynVis/rsmi_to_fig.py diff --git a/.coverage b/.coverage new file mode 100644 index 0000000000000000000000000000000000000000..deb3c516f54c700bf12b88255b0c49224df7de27 GIT binary patch literal 77824 zcmeI532+qGna5wxw0bm0cO!)EYXBi35E{gN>Hq}DfP|1h7#4=ns3i@WLrl+bcx}+& zI8G&Y>?A&7uPxhEDIe=~?9KYf;!BC0wUgaoJ5G$_#Kztv#t9)nVva>d?Dt-4q;4JH zRQFWM_)USn>0{o%zu)`b*;ix3T3-QkPZH)!vFZk0$}79NOdT zlETh5pI>tNf~`KUC*ljXJ0l%ZeDA0%wc>^j)qU2Y%cJ z2Nw**u@3zZOT&-JF$bf87C^o;8tB5A1HVv&BcZW=Io59@p#f($CC z)kF^Y<}5-#0F^%eZ`K?#UNdyaYovWMc6D9Ez8EsBodTqmC2IPK5(6DhjxAQ*y9Hp%C9|ps!;q;o&G5 zu@AvN`k@3U0ZM=ppaduZN`Mle1SkPYfD)htD1obwfSEBd7Cinpi60^HB>bTtN`Mle z1SkPYfD)htC;>`<5}*Vq0ZM=p_+k=pn9O-%{1T6uB4f_U7`O;ve%ZXT`S>D#llTb| zKM~J-F&(5wO$kr}lmI0_2~Yx*03|>PPy&0)1u)snIo5$mfF#(# zB%Z~)jTifrFY(_{O9@Z{lmI0_2~Yx*03|>PPy&C^KnYL+lmI0_2~Yx*03|>PPy&>|mx+KFU}M`DS4Suy$*vY3JdW5C z>~uwZQq;8~PPy&zT>#jvB%Nwa64)p)s9*A z`|WqyZ?#A50ehpp+&;l>v7NVlYkHPW;ehFf5}*Vq0ZM=ppadv^tCc{?W|pxQZthvydo=0mrYh8n7_J#s3C)aS z6WHuccezW`T@@y7BVN0a5u3m|o67nYJn)w{j=u41vDk=LGzw9K_U{ipzulCw0ag@l z==sy1{`!a4PA@SdRAp~~RhbQUow)wG-xQ5o4-1plA4a(^GasJ&LBpR*)9PVOX?=wm zT|9LJo$9yknENkL*}~Vs>fAcmvG+vMDXV=QPGnshvYaxdXRL*VnQJTVX3+0Nl!^8y zJtUz1&-xzV*1!f~4U={a*wSn2@=*8jyWe)b_D5IFnY(XgxmvurmKnbqY|-kTv;F-S zAE`v$6{wrdSq1ABtlFQobBANcd{n@7&qddp(Tis;<(+)Mu@W|Buk88F@#D`gWgInl zsiyAq3r8>ZvEx?2(xerKdv{n+;kl3J&CTbR!0+d1vcU*1ilDGne%oK;19A zmQxK23#uP>d{`%T|8kp%&c1Q!*XHh*E77~>8D|x2Tvm1X078Y|{=uEuXs?LQ%lYW_ zf~u3p8Pt(8F|!i(m{r-sr1xIx|FIncb4S;VwBG*33|9eLgbF5oDcG{56&EdN{kvy- zp7QO&pO?+t61;s0!!5>kF%uiVD2cHiEL?Q`C&wN`T}R)4F{O6`D(rUVl4v)@T24`zdJs)&adPF~ zN=Z%4YI$`9SN`e>C%+I5E_dO6PK1N_%*-vo$GU)-Fdyu+`4tve+l^m-j$eXq;^yJ4 z^B7?+*!;PNE6_{m<#T3}Z4NA_&*?c5e(?ceHr_V-a6WAN;Mo2u+bq0oR?m?`AJnIp z!}qfC!{*-OXV9sCwe%iHM{cIe;+P5RvuE}^apakI`mEz!uvFwaY~IV95+83q5@k-A zxiVNMlrj7au(>npZaDP);}6c>SAU(`PZ?z09QRVWpJAO=+d8o2G6!An&LooDi=z>XDDdfxlP zrIVNXE?vB|_solD-Z}B{HtE7QpR7E~a)mg-LWY}+?PMlr64(WkdOYa$Lmw5GdkEit#y?-=x;N|{%jx9O1xbK@KIk2fIr$0+XZ~pDXOB{N% z5KcH}!hmyg+wT6pb2;dpdHlBXQmuxobqS3fv*@r6Cs^ek8^%leq>e(ioAdYS3I z?ZmzadUq+G3F~t+5A)rf=-q#ISTkU8LdL@vuQz3X4aB!8@cdDdWHvjuX&D<#XnUg%j~6k#PuMvxTxV*WLB2Cm$Hc!_p*v`_64g zu6v^S$>Wcw<~d-^632edyw}`;7NMCIG`pMYMc1KI%cdPOiFVlRwD(M$+1)3i>zBUd zu;KJ8=* zEnAFUKla9pFYPgLX1u}7a3*X`OcKLdtpbD3|8M5+MeyvuPsJ1B$Kp|#19(gPnfQwM zg7}R1een_TLGd2(+u}FGuZTB@yTy>$A$r6uVxzcLtPv~3`Ql8mM4Tk%h!aF!ugd(9p$P`ip8_X(v#-HK;!XJaVg`<61Zvz-~k=8c{4H1&BSclM9jvG#56S#)7VJNh7H6tG!V0XJu&t5 z#MIRhvu+(RYu6I9W(_gdTtiH4EitQC6SHa+F)LRRQ&U6CiWS5xUrx-jWyDlh6H`@1 zOl2i86&1uRT}sT7CB!UVOw6K1#4KD$%z_2P%%4xpym`dTolDG|ImFDKP0XxW#FUp4 zGjk>}E*CLnWyH*wK}=~WG1I3LGi@3%B_+g6ok~n`F)>9vn&Okp80lP421X%aCL zClcdy5>rq>OnyEwd3nU-<`R>WLriuyF`<5}*Vq0ZQQNC4kTW)A9e+o4?dEN`Mle1SkPY zfD)htC;>`<5}*Vq0ZIT9!1w>t@jr$@Ka>C^KnYL+lmI0_2~Yx*03|>PPy&>|)k^@M z|Iagjir^mqsltB<75saAkheMRbCbO@w+svPuZ#L(ferR%=%*=yK9r_4`5E}CphF330>KP^Q1tHME#u^jyg3)&T zZc}$8e4%LMm=%2GEQt2SE2?e+zI5w{!CU$T{rEi@*5<=wBU zv?uF-yKMYjuS(iYVNXzQgLl3vZ+!KJ!jfAGw)=t-crxF({1~pL@@d|fmH+2{W3JXG z^}BM(du&xN%$*ZtV{RieXM|WF6mItf2j0&+E8)I%c;xuYfM*&V93y;F%B%9~;Ti30 ztRYb(RpC(@W@GI}>a37tZ`jur35AuISDcc4?Lki@8pd%yJ;KI16EE31cp&&y&Qsdh zSbc&-RX7W`u(A2Z$Sm)W0xnh3rcO3?t$|5zkUZqI(@Hi)JkbceR6Fdu4rkHCglfs~ z((YDIxK+_|8W-sVnbxncg8Mu?4h5~kGR@q2<^$?ffsyq!`Q^d@-OSZyNoR3$3a;3!`g z@=M;RUs6hxHp<5t?9ZT<${*3Vij{=jhz1HGTB>kg|W`El`n z^R40{bGGTeuCpMN2UF3`%c?M+wC@&<| zHzn!0q*dY~)Rla#`I#@sFzCyCn8uBnHu+?w`jG-HQfsSHd2hOwjmL$FR)#3x1nzF8nulE%%$DYc_2>DtX~Y_`52 zh|gKo)~lPV#l}s!LBcCxV$U*m8ZG zv^4=~JX9IM86BfUU}%J{nRJ+xSTXH*EK7eLje>5_v!Q8+0{w%`%L!@b9rE-~2V71| zHSv(6e=_BAO6r-1Orx7SZ47Fr9-vI6CzheAaHSzM5`y_85Fq&U;kf(T9EjJxIs7$&*h)U z_wtSK*1m&|BKuSJMYf;V)?5ETKbBl$`IV)U`;7ZGSDbVxsewJlUeAs*KWeTt zy=`h|E-*dJeDq^TO!L`Z#G+WWuAYFlhkMy5QkdzwtNt3&?R&kfI~Vv^K`?zoI{%4gD)4fniJ)dMyewR zq&Di3x}008j!JtrDCvuc%Vd;Fs51+M5?6k;aqi31omH3Pr|y1rWa!SS%K%g>1=2yu zov@%7-qGtw)9u*h7^?e=yL2b2sM6@vm-63pjVp!J(W+L<3Ypla3AbXYQw!F zB@tOQD5i32=QvO@uwfk3h~S1%>s5M=UTTTr6aCw)kanek93gF0Kz1cV z@{)ncj|j`*XC68&plEnar49_0du|TImK*J!N}xFj1ZEp0ph`HzLc(5M32PcDH6T*? z4bKb-}8qy&3$1tJJ z8W7S~t;aKEn2;|byS<^H90@BI1VE$Mu_9solG7@(pBw;oQZ)$atLMiOQir&|3S_qFDoA6=j1mx1Wx||-bBKdCQ?yoqjK11oET`0U z&{EwhV-zK|FWZ-ZioS{Gkt!-LnioU*`g%sH^i>J3S_BEFjv843xGe>$idN5TmlWJB zyHv3Kz7|&?)T-|L1{Oj>Mow@j6XZ?tA(%^hf62}TAg6CEtMaFDUvRjbijjXlXz43` zarnbkB6ygo^{jOssOc*fW2vcqnycHV(fFx5#;RLqI_H3zG3}W`P_084s>Fh`K}_F_ z)RkP*Tp$&g1$z1-Uz=egoYTnp%M+e+s_U^jXC^x5RAK>_Zjx8Zxps2fpmH0MzbxS- z)g^9*ZnZ0fiQ2ceQcyGI9L~rRr}ApobkJ!Zwf$HT@Ue9 z$64bI;mbKU!j+w}OY4>DVm4B+6{pm#g&zu=_|yFDe4gWJ#}@kq`@OJ% zekcJ-fD)htC;>`<5}*Vq0ZITN_h#st8+qat-{`2mhjM!}2)DAR;-ATE9lkW86 zqr_3~0o}Y2(x0nqNT^Oba}!N zAASGIm`<5}*XWhy?WCyQghh zhxgq#>wmt=NC8z3xY@(TY8b|fQdW@D+N>;8S$J}lCIOTQ`$ zRP2+CzOi-89$=mRhX<%!XxtZ6l!o;H*Xe&y*k~%63mVrZCSFy!vPS>YA+%n=)n)N3 z)NMwWW}G&@3wPiS)}rlfta9|2|GejGM;yBfjaC1n#6~1P@FWWL^>Ecjmy4n6`-WYK zxytAV#0)8oa^+o>(GPA=scEiTT$VV1%E|M=Yd%*bE+W(y;8hx3e+`jSg?EL~b<2oz aoU59MnAw~xSR^!g8v7;*}vle literal 0 HcmV?d00001 diff --git a/.gitignore b/.gitignore index 1706557..8e6144a 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,4 @@ *.ipynb *.json test_mod.py +test_format.py diff --git a/README.md b/README.md index 24c34aa..e4f8676 100644 --- a/README.md +++ b/README.md @@ -102,6 +102,7 @@ git pull ## Contributing - [Tieu-Long Phan](https://tieulongphan.github.io/) +- [Klaus Weinbauer](https://github.com/klausweinbauer) - [Phuoc-Chung Nguyen Van](https://github.com/phuocchung123) ## Deployment timeline diff --git a/synutility/SynIO/Format/graph_to_mol.py b/synutility/SynIO/Format/graph_to_mol.py index 0cbbddb..243ce52 100644 --- a/synutility/SynIO/Format/graph_to_mol.py +++ b/synutility/SynIO/Format/graph_to_mol.py @@ -12,7 +12,13 @@ class GraphToMol: """ def __init__( - self, node_attributes: Dict[str, str], edge_attributes: Dict[str, str] + self, + node_attributes: Dict[str, str] = { + "element": "element", + "charge": "charge", + "atom_map": "atom_map", + }, + edge_attributes: Dict[str, str] = {"order": "order"}, ): """ Initializes the GraphToMol converter with mappings for node and edge attributes @@ -22,7 +28,7 @@ def __init__( self.edge_attributes = edge_attributes def graph_to_mol( - self, graph: nx.Graph, ignore_bond_order: bool = False + self, graph: nx.Graph, ignore_bond_order: bool = False, sanitize: bool = True ) -> Chem.Mol: """ Converts a NetworkX graph into an RDKit molecule object by interpreting node and @@ -52,8 +58,8 @@ def graph_to_mol( charge = data.get(self.node_attributes["charge"], 0) atom = Chem.Atom(element) atom.SetFormalCharge(charge) - if "atom_class" in data: # Set atom map number if available - atom.SetAtomMapNum(data["atom_class"]) + if "atom_map" in data: # Set atom map number if available + atom.SetAtomMapNum(data["atom_map"]) idx = mol.AddAtom(atom) node_to_idx[node] = idx @@ -69,7 +75,8 @@ def graph_to_mol( mol.AddBond(node_to_idx[u], node_to_idx[v], bond_type) # Attempt to sanitize the molecule to ensure its chemical validity - Chem.SanitizeMol(mol) + if sanitize: + Chem.SanitizeMol(mol) return mol diff --git a/synutility/SynIO/Format/mol_to_graph.py b/synutility/SynIO/Format/mol_to_graph.py index 556cbab..af5cd51 100644 --- a/synutility/SynIO/Format/mol_to_graph.py +++ b/synutility/SynIO/Format/mol_to_graph.py @@ -1,7 +1,7 @@ from rdkit import Chem from rdkit.Chem import AllChem import networkx as nx -from typing import Dict +from typing import Any, Dict, Optional import random from synutility.SynIO.debug import setup_logging @@ -11,11 +11,15 @@ class MolToGraph: """ A class for converting molecules from SMILES strings to graph representations using - RDKit and NetworkX. + RDKit and NetworkX. It supports creating both lightweight and detailed + graph representations with customizable atom and bond attributes, + allowing for exclusion of atoms without atom mapping numbers. """ - def __init__(self): - """Initialize the MolToGraphConverter class.""" + def __init__(self) -> None: + """ + Initialize the MolToGraph class. + """ pass @staticmethod @@ -43,11 +47,11 @@ def get_stereochemistry(atom: Chem.Atom) -> str: - str: The stereochemistry ('R', 'S', or 'N' for non-chiral). """ chiral_tag = atom.GetChiralTag() - if chiral_tag == Chem.ChiralType.CHI_TETRAHEDRAL_CCW: - return "S" - elif chiral_tag == Chem.ChiralType.CHI_TETRAHEDRAL_CW: - return "R" - return "N" + return ( + "S" + if chiral_tag == Chem.ChiralType.CHI_TETRAHEDRAL_CCW + else "R" if chiral_tag == Chem.ChiralType.CHI_TETRAHEDRAL_CW else "N" + ) @staticmethod def get_bond_stereochemistry(bond: Chem.Bond) -> str: @@ -75,28 +79,19 @@ def has_atom_mapping(mol: Chem.Mol) -> bool: """ Check if the given molecule has any atom mapping numbers. - Atom mapping numbers are used in chemical reactions to track the correspondence - between atoms in reactants and products. - Parameters: - mol (Chem.Mol): An RDKit molecule object. Returns: - bool: True if any atom in the molecule has a mapping number, False otherwise. """ - for atom in mol.GetAtoms(): - if atom.HasProp("molAtomMapNumber"): - return True - return False + return any(atom.HasProp("molAtomMapNumber") for atom in mol.GetAtoms()) @staticmethod def random_atom_mapping(mol: Chem.Mol) -> Chem.Mol: """ Assigns a random atom mapping number to each atom in the given molecule. - This method iterates over all atoms in the molecule and assigns a random - mapping number between 1 and the total number of atoms to each atom. - Parameters: - mol (Chem.Mol): An RDKit molecule object. @@ -110,66 +105,114 @@ def random_atom_mapping(mol: Chem.Mol) -> Chem.Mol: return mol @classmethod - def mol_to_graph(cls, mol: Chem.Mol, drop_non_aam: bool = False) -> nx.Graph: + def mol_to_graph( + cls, + mol: Chem.Mol, + drop_non_aam: Optional[bool] = False, + light_weight: Optional[bool] = False, + ) -> nx.Graph: """ Converts an RDKit molecule object to a NetworkX graph with specified atom and bond - attributes. - Optionally excludes atoms without atom mapping numbers if drop_non_aam is True. + attributes. Optionally excludes atoms without atom mapping numbers + if drop_non_aam is True. Parameters: - mol (Chem.Mol): An RDKit molecule object. - drop_non_aam (bool, optional): If True, nodes without atom mapping numbers will - be dropped. + be dropped. + - light_weight (bool, optional): If True, creates a graph with minimal attributes. Returns: - nx.Graph: A NetworkX graph representing the molecule. """ + if light_weight: + return cls._create_light_weight_graph(mol, drop_non_aam) + else: + return cls._create_detailed_graph(mol, drop_non_aam) - cls.add_partial_charges(mol) - + @classmethod + def _create_light_weight_graph(cls, mol: Chem.Mol, drop_non_aam: bool) -> nx.Graph: graph = nx.Graph() - index_to_class: Dict[int, int] = {} - if cls.has_atom_mapping(mol) is False: - mol = cls.random_atom_mapping(mol) - for atom in mol.GetAtoms(): atom_map = atom.GetAtomMapNum() - if drop_non_aam and atom_map == 0: continue - gasteiger_charge = round(float(atom.GetProp("_GasteigerCharge")), 3) - index_to_class[atom.GetIdx()] = atom_map graph.add_node( atom_map, - charge=atom.GetFormalCharge(), - hcount=atom.GetTotalNumHs(), - aromatic=atom.GetIsAromatic(), element=atom.GetSymbol(), + aromatic=atom.GetIsAromatic(), + hcount=atom.GetTotalNumHs(), + charge=atom.GetFormalCharge(), + neighbors=[neighbor.GetSymbol() for neighbor in atom.GetNeighbors()], atom_map=atom_map, - isomer=cls.get_stereochemistry(atom), - partial_charge=gasteiger_charge, - hybridization=str(atom.GetHybridization()), - in_ring=atom.IsInRing(), - # explicit_valence=atom.GetExplicitValence(), - implicit_hcount=atom.GetNumImplicitHs(), - neighbors=sorted( - [neighbor.GetSymbol() for neighbor in atom.GetNeighbors()] - ), ) + for bond in atom.GetBonds(): + neighbor = bond.GetOtherAtom(atom) + neighbor_map = neighbor.GetAtomMapNum() + if not drop_non_aam or neighbor_map != 0: + graph.add_edge( + atom_map, neighbor_map, order=bond.GetBondTypeAsDouble() + ) + return graph - for bond in mol.GetBonds(): - begin_atom_class = index_to_class.get(bond.GetBeginAtomIdx()) - end_atom_class = index_to_class.get(bond.GetEndAtomIdx()) - if begin_atom_class is None or end_atom_class is None: + @classmethod + def _create_detailed_graph(cls, mol: Chem.Mol, drop_non_aam: bool) -> nx.Graph: + cls.add_partial_charges(mol) # Compute charges if not already present + graph = nx.Graph() + index_to_class = {} + if not cls.has_atom_mapping(mol): + mol = cls.random_atom_mapping(mol) + + for atom in mol.GetAtoms(): + atom_map = atom.GetAtomMapNum() + if drop_non_aam and atom_map == 0: continue - graph.add_edge( - begin_atom_class, - end_atom_class, - order=bond.GetBondTypeAsDouble(), - ez_isomer=cls.get_bond_stereochemistry(bond), - bond_type=str(bond.GetBondType()), - conjugated=bond.GetIsConjugated(), - in_ring=bond.IsInRing(), - ) + props = cls._gather_atom_properties(atom) + index_to_class[atom.GetIdx()] = atom_map + graph.add_node(atom_map, **props) + + for bond in mol.GetBonds(): + begin_atom_map = index_to_class.get(bond.GetBeginAtomIdx()) + end_atom_map = index_to_class.get(bond.GetEndAtomIdx()) + if begin_atom_map and end_atom_map: + graph.add_edge( + begin_atom_map, end_atom_map, **cls._gather_bond_properties(bond) + ) return graph + + @staticmethod + def _gather_atom_properties(atom: Chem.Atom) -> Dict[str, Any]: + """Collect all relevant properties from an atom to use + as graph node attributes.""" + gasteiger_charge = ( + round(float(atom.GetProp("_GasteigerCharge")), 3) + if atom.HasProp("_GasteigerCharge") + else 0.0 + ) + return { + "charge": atom.GetFormalCharge(), + "hcount": atom.GetTotalNumHs(), + "aromatic": atom.GetIsAromatic(), + "element": atom.GetSymbol(), + "atom_map": atom.GetAtomMapNum(), + "isomer": MolToGraph.get_stereochemistry(atom), + "partial_charge": gasteiger_charge, + "hybridization": str(atom.GetHybridization()), + "in_ring": atom.IsInRing(), + "implicit_hcount": atom.GetNumImplicitHs(), + "neighbors": sorted( + neighbor.GetSymbol() for neighbor in atom.GetNeighbors() + ), + } + + @staticmethod + def _gather_bond_properties(bond: Chem.Bond) -> Dict[str, Any]: + """Collect all relevant properties from a bond to use as graph edge attributes.""" + return { + "order": bond.GetBondTypeAsDouble(), + "ez_isomer": MolToGraph.get_bond_stereochemistry(bond), + "bond_type": str(bond.GetBondType()), + "conjugated": bond.GetIsConjugated(), + "in_ring": bond.IsInRing(), + } diff --git a/synutility/SynIO/Format/smi_to_graph.py b/synutility/SynIO/Format/smi_to_graph.py new file mode 100644 index 0000000..8cd4386 --- /dev/null +++ b/synutility/SynIO/Format/smi_to_graph.py @@ -0,0 +1,68 @@ +import networkx as nx +from rdkit import Chem +from typing import Optional, Tuple + +from synutility.SynIO.debug import setup_logging +from synutility.SynIO.Format.mol_to_graph import MolToGraph + + +logger = setup_logging + + +def smiles_to_graph( + smiles: str, drop_non_aam: bool, light_weight: bool, sanitize: bool +) -> Optional[nx.Graph]: + """ + Helper function to convert SMILES string to a graph using MolToGraph class. + + Parameters: + - smiles (str): SMILES representation of the molecule. + - drop_non_aam (bool): Whether to drop nodes without atom mapping. + - light_weight (bool): Whether to create a light-weight graph. + - sanitize (bool): Whether to sanitize the molecule during conversion. + + Returns: + - nx.Graph or None: The networkx graph representation of the molecule, or None if conversion fails. + """ + try: + mol = Chem.MolFromSmiles(smiles, sanitize) + if mol: + return MolToGraph().mol_to_graph(mol, drop_non_aam, light_weight) + else: + logger.warning(f"Failed to parse SMILES: {smiles}") + except Exception as e: + logger.error(f"Error converting SMILES to graph: {smiles}, Error: {str(e)}") + return None + + +def rsmi_to_graph( + rsmi: str, + drop_non_aam: bool = True, + light_weight: bool = True, + sanitize: bool = True, +) -> Tuple[Optional[nx.Graph], Optional[nx.Graph]]: + """ + Converts reactant and product SMILES strings from a reaction SMILES (RSMI) format + to graph representations. + + Parameters: + - rsmi (str): Reaction SMILES string in "reactants>>products" format. + - drop_non_aam (bool, optional): If True, nodes without atom mapping numbers + will be dropped. + - light_weight (bool, optional): If True, creates a light-weight graph. + - sanitize (bool, optional): If True, sanitizes molecules during conversion. + + Returns: + - Tuple[Optional[nx.Graph], Optional[nx.Graph]]: A tuple containing t + he graph representations of the reactants and products. + """ + try: + reactants_smiles, products_smiles = rsmi.split(">>") + r_graph = smiles_to_graph( + reactants_smiles, drop_non_aam, light_weight, sanitize + ) + p_graph = smiles_to_graph(products_smiles, drop_non_aam, light_weight, sanitize) + return (r_graph, p_graph) + except ValueError: + logger.error(f"Invalid RSMI format: {rsmi}") + return (None, None) diff --git a/synutility/SynVis/graph_visualizer.py b/synutility/SynVis/graph_visualizer.py new file mode 100644 index 0000000..0ab106b --- /dev/null +++ b/synutility/SynVis/graph_visualizer.py @@ -0,0 +1,245 @@ +from rdkit import Chem +from rdkit.Chem import rdDepictor +from typing import Dict, Optional +import networkx as nx +from synutility.SynIO.Format.graph_to_mol import GraphToMol +import matplotlib.pyplot as plt + + +class GraphVisualizer: + def __init__( + self, + node_attributes: Dict[str, str] = { + "element": "element", + "charge": "charge", + "atom_map": "atom_map", + }, + edge_attributes: Dict[str, str] = {"order": "order"}, + ): + self.node_attributes = node_attributes + self.edge_attributes = edge_attributes + + def _get_its_as_mol(self, its: nx.Graph) -> Optional[Chem.Mol]: + """ + Convert a graph representation of an intermediate transition state into an RDKit molecule. + + Parameters: + - its (nx.Graph): The graph to convert. + + Returns: + - Chem.Mol or None: The RDKit molecule if conversion is successful, None otherwise. + """ + _its = its.copy() + for n in _its.nodes(): + _its.nodes[n]["atom_map"] = n # + for u, v in _its.edges(): + _its[u][v]["order"] = 1 + return GraphToMol(self.node_attributes, self.edge_attributes).graph_to_mol( + _its, False, False + ) # Ensure this function is defined correctly elsewhere + + def plot_its( + self, + its: nx.Graph, + ax: plt.Axes, + use_mol_coords: bool = True, + title: Optional[str] = None, + node_color: str = "#FFFFFF", + node_size: int = 500, + edge_color: str = "#000000", + edge_weight: float = 2.0, + show_atom_map: bool = False, + use_edge_color: bool = False, # + symbol_key: str = "element", + bond_key: str = "order", + aam_key: str = "atom_map", + standard_order_key: str = "standard_order", + font_size: int = 12, + ): + """ + Plot an intermediate transition state (ITS) graph on a given Matplotlib axes with various customizations. + + Parameters: + - its (nx.Graph): The graph representing the intermediate transition state. + - ax (plt.Axes): The matplotlib axes to draw the graph on. + - use_mol_coords (bool): Use molecular coordinates for node positions if True, else use a spring layout. + - title (Optional[str]): Title for the graph. If None, no title is set. + - node_color (str): Color code for the graph nodes. + - node_size (int): Size of the graph nodes. + - edge_color (str): Default color code for the graph edges if not using conditional coloring. + - edge_weight (float): Thickness of the graph edges. + - show_aam (bool): If True, displays atom mapping numbers alongside symbols. + - use_edge_color (bool): If True, colors edges based on their 'standard_order' attribute. + - symbol_key (str): Key to access the symbol attribute in the node's data. + - bond_key (str): Key to access the bond type attribute in the edge's data. + - aam_key (str): Key to access the atom mapping number in the node's data. + - standard_order_key (str): Key to determine the edge color conditionally. + - font_size (int): Font size for labels and edge labels. + + Returns: + - None + """ + bond_char = {None: "∅", 0: "∅", 1: "—", 2: "=", 3: "≡"} + + positions = self._calculate_positions(its, use_mol_coords) + + ax.axis("equal") + ax.axis("off") + if title: + ax.set_title(title) + + # Conditional edge coloring based on 'standard_order' + if use_edge_color: + edge_colors = [ + ( + "green" + if data.get(standard_order_key, 0) > 0 + else "red" if data.get(standard_order_key, 0) < 0 else "black" + ) + for _, _, data in its.edges(data=True) + ] + else: + edge_colors = edge_color + + nx.draw_networkx_edges( + its, positions, edge_color=edge_colors, width=edge_weight, ax=ax + ) + nx.draw_networkx_nodes( + its, positions, node_color=node_color, node_size=node_size, ax=ax + ) + + # Adjust labels to optionally show atom mapping numbers + labels = { + n: ( + f"{d[symbol_key]} ({d.get(aam_key, '')})" + if show_atom_map + else f"{d[symbol_key]}" + ) + for n, d in its.nodes(data=True) + } + edge_labels = self._determine_edge_labels(its, bond_char, bond_key) + + nx.draw_networkx_labels( + its, positions, labels=labels, font_size=font_size, ax=ax + ) + nx.draw_networkx_edge_labels( + its, positions, edge_labels=edge_labels, font_size=font_size, ax=ax + ) + + def _calculate_positions(self, its: nx.Graph, use_mol_coords: bool) -> dict: + if use_mol_coords: + mol = self._get_its_as_mol(its) + positions = {} + rdDepictor.Compute2DCoords(mol) + for i, atom in enumerate(mol.GetAtoms()): + aam = atom.GetAtomMapNum() + apos = mol.GetConformer().GetAtomPosition(i) + positions[aam] = [apos.x, apos.y] + else: + positions = nx.spring_layout(its) + return positions + + def _determine_edge_labels( + self, its: nx.Graph, bond_char: dict, bond_key: str + ) -> dict: + edge_labels = {} + for u, v, data in its.edges(data=True): + bond_codes = data.get(bond_key, (0, 0)) + bc1, bc2 = bond_char.get(bond_codes[0], "∅"), bond_char.get( + bond_codes[1], "∅" + ) + if bc1 != bc2: + edge_labels[(u, v)] = f"({bc1},{bc2})" + return edge_labels + + def plot_as_mol( + self, + g: nx.Graph, + ax: plt.Axes, + use_mol_coords: bool = True, + node_color: str = "#FFFFFF", + node_size: int = 500, + edge_color: str = "#000000", + edge_width: float = 2.0, + label_color: str = "#000000", + font_size: int = 12, + show_atom_map: bool = False, + bond_char: Dict[Optional[int], str] = None, + symbol_key: str = "element", + bond_key: str = "order", + aam_key: str = "atom_map", + ) -> None: + """ + Plots a molecular graph on a given Matplotlib axes using either molecular coordinates + or a networkx layout. + + Parameters: + - g (nx.Graph): The molecular graph to be plotted. + - ax (plt.Axes): Matplotlib axes where the graph will be plotted. + - use_mol_coords (bool, optional): Use molecular coordinates if True, else use networkx layout. + - node_color (str, optional): Color code for the nodes. + - node_size (int, optional): Size of the nodes. + - edge_color (str, optional): Color code for the edges. + - label_color (str, optional): Color for node labels. + - font_size (int, optional): Font size for labels. + - bond_char (Dict[Optional[int], str], optional): Dictionary mapping bond types to characters. + - symbol_key (str, optional): Node attribute key for element symbols. + - bond_key (str, optional): Edge attribute key for bond types. + + Returns: + - None + """ + + # Set default bond characters if not provided + if bond_char is None: + bond_char = {None: "∅", 1: "—", 2: "=", 3: "≡"} + + # Determine positions based on use_mol_coords flag + if use_mol_coords: + mol = GraphToMol(self.node_attributes, self.edge_attributes).graph_to_mol( + g, False + ) # This function needs to be defined or imported + positions = {} + rdDepictor.Compute2DCoords(mol) + for atom in mol.GetAtoms(): + aidx = atom.GetIdx() + atom_map = atom.GetAtomMapNum() + apos = mol.GetConformer().GetAtomPosition(aidx) + positions[atom_map] = [apos.x, apos.y] + else: + positions = nx.spring_layout(g) # Optionally provide a layout configuration + + ax.axis("equal") + ax.axis("off") + + # Drawing elements on the plot + nx.draw_networkx_edges( + g, positions, edge_color=edge_color, width=edge_width, ax=ax + ) + nx.draw_networkx_nodes( + g, positions, node_color=node_color, node_size=node_size, ax=ax + ) + + # Preparing labels + labels = {} + for n, d in g.nodes(data=True): + label = f"{d.get(symbol_key, '')}" + if show_atom_map: + label += f" ({d.get(aam_key, '')})" + labels[n] = label + edge_labels = { + (u, v): bond_char.get(d[bond_key], "∅") for u, v, d in g.edges(data=True) + } + + # Drawing labels + nx.draw_networkx_labels( + g, + positions, + labels=labels, + font_color=label_color, + font_size=font_size, + ax=ax, + ) + nx.draw_networkx_edge_labels( + g, positions, edge_labels=edge_labels, font_color=label_color, ax=ax + ) diff --git a/synutility/SynVis/pdf_writer.py b/synutility/SynVis/pdf_writer.py new file mode 100644 index 0000000..d07851e --- /dev/null +++ b/synutility/SynVis/pdf_writer.py @@ -0,0 +1,137 @@ +import matplotlib.pyplot as plt +from matplotlib.backends.backend_pdf import PdfPages +from typing import List, Callable, Union, Tuple, Optional +import tqdm + + +class PdfWriter: + """ + A utility class to create PDF reports with plots from a list of figures or dynamically generated plots. + + Parameters: + - file (str): The file name of the output PDF. + - plot_fn (Optional[Callable], optional): Function to create a plot for a single data entry or row. + Expected interface: `plot_fn(data_entry, axis, **kwargs)`. Default is None. + - plot_per_row (bool, optional): If True, calls `plot_fn` for an entire row instead of individual subplots. + Default is False. + - max_pages (int, optional): Maximum number of pages to create. Default is 999. + - rows (int, optional): Number of plot rows per page. Default is 7. + - cols (int, optional): Number of plot columns per page. Default is 2. + - pagesize (Tuple[float, float], optional): Size of a single page (in inches). Default is (21, 29.7). + - width_ratios (Optional[List[float]], optional): Column width ratios. Default is None. + - show_progress (bool, optional): If True, displays a progress bar using `tqdm`. Default is True. + """ + + def __init__( + self, + file: str, + plot_fn: Optional[Callable] = None, + plot_per_row: bool = False, + max_pages: int = 999, + rows: int = 7, + cols: int = 2, + pagesize: Tuple[float, float] = (21, 29.7), + width_ratios: Optional[List[float]] = None, + show_progress: bool = True, + ): + self.pdf_pages = PdfPages(file) + self.plot_fn = plot_fn + self.plot_per_row = plot_per_row + self.max_pages = max_pages + self.rows = rows + self.cols = cols + self.pagesize = pagesize + self.width_ratios = width_ratios + self.show_progress = show_progress + + def plot(self, data: Union[List[plt.Figure], List], **kwargs): + """ + Generate plots from data or save pre-generated figures to the PDF. + + Parameters: + - data (Union[List[matplotlib.figure.Figure], List]): Input data or list of figures. + If a list of figures, they are saved directly. Otherwise, the `plot_fn` is called for each data entry. + - **kwargs: Additional keyword arguments passed to `plot_fn`. + + Returns: + - None + """ + # Case 1: Pre-generated figures + if all(isinstance(item, plt.Figure) for item in data): + for fig in tqdm.tqdm( + data, disable=not self.show_progress, desc="Saving Figures" + ): + self.save_figure(fig) + return + + # Case 2: Generate plots dynamically using `plot_fn` + if self.plot_fn is None: + raise ValueError( + "plot_fn must be provided when input is not a list of figures." + ) + + if not isinstance(data, list): + raise ValueError( + "Data must be a list or a list of matplotlib.figure.Figure." + ) + + plots_per_page = self.rows if self.plot_per_row else self.rows * self.cols + max_plots = self.max_pages * plots_per_page + step = max(len(data) / max_plots, 1) + pages = int((len(data) / step + plots_per_page - 1) // plots_per_page) + + for p in tqdm.tqdm( + range(pages), disable=not self.show_progress, desc="Generating Pages" + ): + fig, ax = plt.subplots( + self.rows, + self.cols, + figsize=self.pagesize, + squeeze=False, + gridspec_kw=( + {"width_ratios": self.width_ratios} if self.width_ratios else None + ), + ) + done = False + for r in range(self.rows): + if self.plot_per_row: + _idx = int((p * self.rows + r) * step) + if _idx >= len(data): + done = True + break + self.plot_fn(data[_idx], ax[r, :], index=_idx, **kwargs) + else: + for c in range(self.cols): + _idx = int((p * plots_per_page + r * self.cols + c) * step) + if _idx >= len(data): + done = True + break + self.plot_fn(data[_idx], ax[r, c], index=_idx, **kwargs) + plt.tight_layout() + self.pdf_pages.savefig(fig, bbox_inches="tight", pad_inches=1) + plt.close(fig) + if done: + break + + def save_figure(self, figure: plt.Figure): + """ + Save a pre-generated matplotlib figure directly to the PDF. + + Parameters: + - figure (matplotlib.figure.Figure): The figure to save. + + Returns: + - None + """ + if not isinstance(figure, plt.Figure): + raise ValueError("Input must be a matplotlib.figure.Figure.") + self.pdf_pages.savefig(figure, bbox_inches="tight", pad_inches=1) + + def close(self): + """ + Close the PDF file, ensuring all pages are written. + + Returns: + - None + """ + self.pdf_pages.close() diff --git a/synutility/SynVis/rsmi_to_fig.py b/synutility/SynVis/rsmi_to_fig.py new file mode 100644 index 0000000..13a52fa --- /dev/null +++ b/synutility/SynVis/rsmi_to_fig.py @@ -0,0 +1,76 @@ +import networkx as nx +import matplotlib.pyplot as plt +from typing import Union, Tuple + + +from synutility.SynVis.graph_visualizer import GraphVisualizer +from synutility.SynIO.Format.smi_to_graph import rsmi_to_graph +from synutility.SynIO.Format.its_construction import ITSConstruction + +vis_graph = GraphVisualizer() + + +def three_graph_vis( + input: Union[str, Tuple[nx.Graph, nx.Graph, nx.Graph]], + sanitize: bool = False, + figsize: Tuple[int, int] = (18, 5), + orientation: str = "horizontal", + show_titles: bool = True, + show_atom_map: bool = False, +) -> plt.Figure: + """ + Visualize three related graphs (reactants, intermediate transition state, and products) + side by side or vertically in a single figure. + + Parameters: + - input (Union[str, Tuple[nx.Graph, nx.Graph, nx.Graph]]): Either a reaction SMILES string + or a tuple of three NetworkX graphs (reactants, ITS, products). + - sanitize (bool, optional): If True, sanitizes the input molecule. Default is False. + - figsize (Tuple[int, int], optional): The size of the Matplotlib figure. Default is (18, 5). + - orientation (str, optional): Layout of the subplots; 'horizontal' or 'vertical'. Default is 'horizontal'. + - show_titles (bool, optional): If True, adds titles to each subplot. Default is True. + + Returns: + - plt.Figure: The Matplotlib figure containing the three subplots. + """ + try: + # Parse input to determine graphs + if isinstance(input, str): + r, p = rsmi_to_graph(input, light_weight=True, sanitize=sanitize) + its = ITSConstruction().ITSGraph(r, p) + elif isinstance(input, tuple) and len(input) == 3: + r, p, its = input + else: + raise ValueError( + "Input must be a reaction SMILES string or a tuple of three graphs (r, p, its)." + ) + + # Set up subplots + if orientation == "horizontal": + fig, ax = plt.subplots(1, 3, figsize=figsize) + elif orientation == "vertical": + fig, ax = plt.subplots(3, 1, figsize=figsize) + else: + raise ValueError("Orientation must be 'horizontal' or 'vertical'.") + + # Plot the graphs + vis_graph.plot_as_mol( + r, ax[0], show_atom_map=show_atom_map, font_size=12, node_size=800, edge_width=2.0 + ) + if show_titles: + ax[0].set_title("Reactants") + + vis_graph.plot_its(its, ax[1], use_edge_color=True, show_atom_map=show_atom_map) + if show_titles: + ax[1].set_title("Imaginary Transition State") + + vis_graph.plot_as_mol( + p, ax[2], show_atom_map=show_atom_map, font_size=12, node_size=800, edge_width=2.0 + ) + if show_titles: + ax[2].set_title("Products") + + return fig + + except Exception as e: + raise RuntimeError(f"An error occurred during visualization: {str(e)}") From d4da32936d93a1fd404d8a0d81b96bdded8d5dd9 Mon Sep 17 00:00:00 2001 From: TieuLongPhan Date: Mon, 18 Nov 2024 12:19:05 +0100 Subject: [PATCH 04/10] fix lint --- synutility/SynVis/graph_visualizer.py | 4 ++-- synutility/SynVis/rsmi_to_fig.py | 14 ++++++++++++-- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/synutility/SynVis/graph_visualizer.py b/synutility/SynVis/graph_visualizer.py index 0ab106b..453f078 100644 --- a/synutility/SynVis/graph_visualizer.py +++ b/synutility/SynVis/graph_visualizer.py @@ -47,13 +47,13 @@ def plot_its( node_color: str = "#FFFFFF", node_size: int = 500, edge_color: str = "#000000", - edge_weight: float = 2.0, + edge_weight: float = 2.0, show_atom_map: bool = False, use_edge_color: bool = False, # symbol_key: str = "element", bond_key: str = "order", aam_key: str = "atom_map", - standard_order_key: str = "standard_order", + standard_order_key: str = "standard_order", font_size: int = 12, ): """ diff --git a/synutility/SynVis/rsmi_to_fig.py b/synutility/SynVis/rsmi_to_fig.py index 13a52fa..690d326 100644 --- a/synutility/SynVis/rsmi_to_fig.py +++ b/synutility/SynVis/rsmi_to_fig.py @@ -55,7 +55,12 @@ def three_graph_vis( # Plot the graphs vis_graph.plot_as_mol( - r, ax[0], show_atom_map=show_atom_map, font_size=12, node_size=800, edge_width=2.0 + r, + ax[0], + show_atom_map=show_atom_map, + font_size=12, + node_size=800, + edge_width=2.0, ) if show_titles: ax[0].set_title("Reactants") @@ -65,7 +70,12 @@ def three_graph_vis( ax[1].set_title("Imaginary Transition State") vis_graph.plot_as_mol( - p, ax[2], show_atom_map=show_atom_map, font_size=12, node_size=800, edge_width=2.0 + p, + ax[2], + show_atom_map=show_atom_map, + font_size=12, + node_size=800, + edge_width=2.0, ) if show_titles: ax[2].set_title("Products") From 1a628b3620b24e24fdcdf2643346dbeb743bdfef Mon Sep 17 00:00:00 2001 From: TieuLongPhan Date: Mon, 18 Nov 2024 12:42:34 +0100 Subject: [PATCH 05/10] add copy right for FGUtils --- synutility/SynVis/graph_visualizer.py | 6 ++++++ synutility/SynVis/pdf_writer.py | 6 ++++++ 2 files changed, 12 insertions(+) diff --git a/synutility/SynVis/graph_visualizer.py b/synutility/SynVis/graph_visualizer.py index 453f078..acf7426 100644 --- a/synutility/SynVis/graph_visualizer.py +++ b/synutility/SynVis/graph_visualizer.py @@ -1,3 +1,9 @@ +""" +This module comprises several functions adapted from the work of Klaus Weinbauer. +The original code can be found at his GitHub repository: https://github.com/klausweinbauer/FGUtils. +Adaptations were made to enhance functionality and integrate with other system components. +""" + from rdkit import Chem from rdkit.Chem import rdDepictor from typing import Dict, Optional diff --git a/synutility/SynVis/pdf_writer.py b/synutility/SynVis/pdf_writer.py index d07851e..6dfc4bc 100644 --- a/synutility/SynVis/pdf_writer.py +++ b/synutility/SynVis/pdf_writer.py @@ -1,3 +1,9 @@ +""" +This module comprises several functions adapted from the work of Klaus Weinbauer. +The original code can be found at his GitHub repository: https://github.com/klausweinbauer/FGUtils. +Adaptations were made to enhance functionality and integrate with other system components. +""" + import matplotlib.pyplot as plt from matplotlib.backends.backend_pdf import PdfPages from typing import List, Callable, Union, Tuple, Optional From 439cf41913ad8e22983b4e94cd5d583b928c4277 Mon Sep 17 00:00:00 2001 From: TieuLongPhan Date: Mon, 18 Nov 2024 12:52:49 +0100 Subject: [PATCH 06/10] prepare release --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 557f751..b15c061 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "synutility" -version = "0.0.10" +version = "0.0.11" authors = [ {name="Tieu Long Phan", email="tieu@bioinf.uni-leipzig.de"} ] From 6d06033d33b4599b97ea12198574eb67e1480f1e Mon Sep 17 00:00:00 2001 From: TieuLongPhan Date: Mon, 18 Nov 2024 14:46:50 +0100 Subject: [PATCH 07/10] add partial map expansion --- .github/workflows/test-and-lint.yml | 2 +- Test/SynAAM/__init__.py | 0 Test/SynAAM/test_normalize_aam.py | 63 ++++++++ Test/SynAAM/test_partial_expand.py | 23 +++ Test/SynGraph/Transform/test_core_engine.py | 2 - lint.sh | 2 +- pytest.sh | 3 + synutility/SynAAM/__init__.py | 0 synutility/SynAAM/misc.py | 142 ++++++++++++++++++ synutility/SynAAM/normalize_aam.py | 134 +++++++++++++++++ synutility/SynAAM/partial_expand.py | 81 ++++++++++ .../Transform}/rule_apply.py | 51 ++++++- synutility/SynIO/data_type.py | 21 +++ 13 files changed, 519 insertions(+), 5 deletions(-) create mode 100644 Test/SynAAM/__init__.py create mode 100644 Test/SynAAM/test_normalize_aam.py create mode 100644 Test/SynAAM/test_partial_expand.py create mode 100755 pytest.sh create mode 100644 synutility/SynAAM/__init__.py create mode 100644 synutility/SynAAM/misc.py create mode 100644 synutility/SynAAM/normalize_aam.py create mode 100644 synutility/SynAAM/partial_expand.py rename synutility/{SynMOD => SynGraph/Transform}/rule_apply.py (50%) diff --git a/.github/workflows/test-and-lint.yml b/.github/workflows/test-and-lint.yml index bb09300..af56d18 100644 --- a/.github/workflows/test-and-lint.yml +++ b/.github/workflows/test-and-lint.yml @@ -51,5 +51,5 @@ jobs: - name: Test with pytest run: | conda activate synutils-env - pytest Test + ./pytest.sh shell: bash -l {0} diff --git a/Test/SynAAM/__init__.py b/Test/SynAAM/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/Test/SynAAM/test_normalize_aam.py b/Test/SynAAM/test_normalize_aam.py new file mode 100644 index 0000000..5ccd071 --- /dev/null +++ b/Test/SynAAM/test_normalize_aam.py @@ -0,0 +1,63 @@ +import unittest +import networkx as nx +from synutility.SynAAM.normalize_aam import NormalizeAAM + + +class TestNormalizeAAM(unittest.TestCase): + def setUp(self): + """Set up for testing.""" + self.normalizer = NormalizeAAM() + + def test_fix_atom_mapping(self): + """Test that atom mappings are incremented correctly.""" + input_smiles = "[C:0]([H:1])([H:2])[H:3]" + expected_smiles = "[C:1]([H:2])([H:3])[H:4]" + self.assertEqual( + self.normalizer.fix_atom_mapping(input_smiles), expected_smiles + ) + + def test_fix_rsmi(self): + """Test that RSMI atom mappings are incremented correctly + for both reactants and products.""" + input_rsmi = "[C:0]>>[C:1]" + expected_rsmi = "[C:1]>>[C:2]" + self.assertEqual(self.normalizer.fix_rsmi(input_rsmi), expected_rsmi) + + def test_extract_subgraph(self): + """Test extraction of a subgraph based on specified indices.""" + g = nx.complete_graph(5) + indices = [0, 1, 2] + subgraph = self.normalizer.extract_subgraph(g, indices) + self.assertEqual(len(subgraph.nodes()), 3) + self.assertTrue(all(node in subgraph for node in indices)) + + def test_reset_indices_and_atom_map(self): + """Test resetting of indices and atom map in a subgraph.""" + g = nx.path_graph(5) + for i in range(5): + g.nodes[i]["atom_map"] = i + 1 + reset_graph = self.normalizer.reset_indices_and_atom_map(g) + expected_atom_maps = {1: 1, 2: 2, 3: 3, 4: 4, 5: 5} + for node in reset_graph: + self.assertEqual( + reset_graph.nodes[node]["atom_map"], expected_atom_maps[node] + ) + + def test_reaction_smiles_processing(self): + """Test that the reaction SMILES string is processed to meet expected output.""" + input_rsmi = ( + "[C:2]([C:3]([H:9])([H:10])[H:11])([H:8])=[C:1]([C:0]([H:6])([H:5])" + + "[H:4])[H:7].[H:12][H:13]>>[C:3]([C:2]([C:1]([C:0]([H:6])([H:5])" + + "[H:4])([H:12])[H:7])([H:8])[H:13])([H:9])([H:10])[H:11]" + ) + expected_output = ( + "[CH3:1][CH:2]=[CH:3][CH3:4].[H:5][H:6]>>[CH3:1][CH:2]([CH:3]" + + "([CH3:4])[H:6])[H:5]" + ) + result = self.normalizer.fit(input_rsmi) + self.assertEqual(result, expected_output) + + +# Run the unittest +if __name__ == "__main__": + unittest.main() diff --git a/Test/SynAAM/test_partial_expand.py b/Test/SynAAM/test_partial_expand.py new file mode 100644 index 0000000..ebe7594 --- /dev/null +++ b/Test/SynAAM/test_partial_expand.py @@ -0,0 +1,23 @@ +import unittest +from synutility.SynAAM.partial_expand import PartialExpand + + +class TestPartialExpand(unittest.TestCase): + def test_expand(self): + """ + Test the expand function of the PartialExpand class with a given RSMI. + """ + # Input RSMI + input_rsmi = "[CH3][CH:1]=[CH2:2].[H:3][H:4]>>[CH3][CH:1]([H:3])[CH2:2][H:4]" + # Expected output + expected_rsmi = ( + "[CH2:1]=[CH:2][CH3:3].[H:4][H:5]>>[CH2:1]([CH:2]([CH3:3])[H:5])[H:4]" + ) + # Perform the expansion + output_rsmi = PartialExpand.expand(input_rsmi) + # Assert the result matches the expected output + self.assertEqual(output_rsmi, expected_rsmi) + + +if __name__ == "__main__": + unittest.main() diff --git a/Test/SynGraph/Transform/test_core_engine.py b/Test/SynGraph/Transform/test_core_engine.py index 0a8cf95..059dd98 100644 --- a/Test/SynGraph/Transform/test_core_engine.py +++ b/Test/SynGraph/Transform/test_core_engine.py @@ -1,11 +1,9 @@ import os -import pytest import unittest import tempfile from synutility.SynGraph.Transform.core_engine import CoreEngine -@pytest.mark.skip(reason="Temporarily disabled for demonstration purposes") class TestCoreEngine(unittest.TestCase): def setUp(self): # Create a temporary directory diff --git a/lint.sh b/lint.sh index e2178df..c088240 100755 --- a/lint.sh +++ b/lint.sh @@ -2,5 +2,5 @@ flake8 . --count --max-complexity=13 --max-line-length=120 \ --per-file-ignores="__init__.py:F401, chemical_reaction_visualizer.py:E501, test_reagent.py:E501" \ - --exclude venv,core_engine.py \ + --exclude venv,core_engine.py,rule_apply.py \ --statistics diff --git a/pytest.sh b/pytest.sh new file mode 100755 index 0000000..fd9a7e4 --- /dev/null +++ b/pytest.sh @@ -0,0 +1,3 @@ +#!/bin/bash + +pytest Test/SynChem Test/SynAAM Test/SynGraph Test/SynIO Test/SynSplit Test/SynSplit \ No newline at end of file diff --git a/synutility/SynAAM/__init__.py b/synutility/SynAAM/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/synutility/SynAAM/misc.py b/synutility/SynAAM/misc.py new file mode 100644 index 0000000..b08fd07 --- /dev/null +++ b/synutility/SynAAM/misc.py @@ -0,0 +1,142 @@ +import networkx as nx + + +def get_rc( + ITS: nx.Graph, + element_key: str = "element", + bond_key: str = "order", + standard_key: str = "standard_order", +) -> nx.Graph: + """ + Extracts the reaction center (RC) graph from a given ITS graph by identifying edges + where the bond order changes, indicating a reaction event. + + Parameters: + ITS (nx.Graph): The ITS graph to extract the RC from. + element_key (str): Node attribute key for atom symbols. Defaults to 'element'. + bond_key (str): Edge attribute key for bond order. Defaults to 'order'. + standard_key (str): Edge attribute key for standard order information. Defaults to 'standard_order'. + + Returns: + nx.Graph: A new graph representing the reaction center of the ITS. + """ + rc = nx.Graph() + for n1, n2, data in ITS.edges(data=True): + if data[bond_key][0] != data[bond_key][1]: + rc.add_node(n1, **{element_key: ITS.nodes[n1][element_key]}) + rc.add_node(n2, **{element_key: ITS.nodes[n2][element_key]}) + rc.add_edge( + n1, + n2, + **{bond_key: data[bond_key], standard_key: data[standard_key]}, + ) + return rc + + +def its_decompose(its_graph: nx.Graph, nodes_share="typesGH", edges_share="order"): + """ + Decompose an ITS graph into two separate graphs G and H based on shared + node and edge attributes. + + Parameters: + - its_graph (nx.Graph): The integrated transition state (ITS) graph. + - nodes_share (str): Node attribute key that stores tuples with node attributes + or G and H. + - edges_share (str): Edge attribute key that stores tuples with edge attributes + for G and H. + + Returns: + - Tuple[nx.Graph, nx.Graph]: A tuple containing the two graphs G and H. + """ + G = nx.Graph() + H = nx.Graph() + + # Decompose nodes + for node, data in its_graph.nodes(data=True): + if nodes_share in data: + node_attr_g, node_attr_h = data[nodes_share] + # Unpack node attributes for G + G.add_node( + node, + element=node_attr_g[0], + aromatic=node_attr_g[1], + hcount=node_attr_g[2], + charge=node_attr_g[3], + neighbors=node_attr_g[4], + atom_map=node, + ) + # Unpack node attributes for H + H.add_node( + node, + element=node_attr_h[0], + aromatic=node_attr_h[1], + hcount=node_attr_h[2], + charge=node_attr_h[3], + neighbors=node_attr_h[4], + atom_map=node, + ) + + # Decompose edges + for u, v, data in its_graph.edges(data=True): + if edges_share in data: + order_g, order_h = data[edges_share] + if order_g > 0: # Assuming 0 means no edge in G + G.add_edge(u, v, order=order_g) + if order_h > 0: # Assuming 0 means no edge in H + H.add_edge(u, v, order=order_h) + + return G, H + + +def compare_graphs( + graph1: nx.Graph, + graph2: nx.Graph, + node_attrs: list = ["element", "aromatic", "hcount", "charge", "neighbors"], + edge_attrs: list = ["order"], +) -> bool: + """ + Compare two graphs based on specified node and edge attributes. + + Parameters: + - graph1 (nx.Graph): The first graph to compare. + - graph2 (nx.Graph): The second graph to compare. + - node_attrs (list): A list of node attribute names to include in the comparison. + - edge_attrs (list): A list of edge attribute names to include in the comparison. + + Returns: + - bool: True if both graphs are identical with respect to the specified attributes, + otherwise False. + """ + # Compare node sets + if set(graph1.nodes()) != set(graph2.nodes()): + return False + + # Compare nodes based on attributes + for node in graph1.nodes(): + if node not in graph2: + return False + node_data1 = {attr: graph1.nodes[node].get(attr, None) for attr in node_attrs} + node_data2 = {attr: graph2.nodes[node].get(attr, None) for attr in node_attrs} + if node_data1 != node_data2: + return False + + # Compare edge sets with sorted tuples + if set(tuple(sorted(edge)) for edge in graph1.edges()) != set( + tuple(sorted(edge)) for edge in graph2.edges() + ): + return False + + # Compare edges based on attributes + for edge in graph1.edges(): + # Sort the edge for consistent comparison + sorted_edge = tuple(sorted(edge)) + if sorted_edge not in graph2.edges(): + return False + edge_data1 = {attr: graph1.edges[edge].get(attr, None) for attr in edge_attrs} + edge_data2 = { + attr: graph2.edges[sorted_edge].get(attr, None) for attr in edge_attrs + } + if edge_data1 != edge_data2: + return False + + return True diff --git a/synutility/SynAAM/normalize_aam.py b/synutility/SynAAM/normalize_aam.py new file mode 100644 index 0000000..6e8c2d8 --- /dev/null +++ b/synutility/SynAAM/normalize_aam.py @@ -0,0 +1,134 @@ +import re +import networkx as nx +from rdkit import Chem +from typing import List + +from synutility.SynIO.Format.smi_to_graph import rsmi_to_graph +from synutility.SynIO.Format.graph_to_mol import GraphToMol +from synutility.SynIO.Format.its_construction import ITSConstruction +from synutility.SynAAM.misc import its_decompose, get_rc + + +class NormalizeAAM: + """ + Provides functionalities to normalize atom mappings in SMILES representations, + extract and process reaction centers from ITS graphs, and convert between + graph representations and molecular models. + """ + + def __init__(self) -> None: + """ + Initializes the NormalizeAAM class. + """ + pass + + @staticmethod + def increment(match: re.Match) -> str: + """ + Helper function to increment a matched atom mapping number by 1. + + Parameters: + match (re.Match): A regex match object containing the atom mapping number. + + Returns: + str: The incremented atom mapping number as a string. + """ + return str(int(match.group()) + 1) + + @staticmethod + def fix_atom_mapping(smiles: str) -> str: + """ + Increments each atom mapping number in a SMILES string by 1. + + Parameters: + smiles (str): The SMILES string with atom mapping numbers. + + Returns: + str: The SMILES string with updated atom mapping numbers. + """ + pattern = re.compile(r"(?<=:)\d+") + return pattern.sub(NormalizeAAM.increment, smiles) + + @staticmethod + def fix_rsmi(rsmi: str) -> str: + """ + Adjusts atom mapping numbers in both reactant and product parts of a reaction SMILES (RSMI). + + Parameters: + rsmi (str): The reaction SMILES string. + + Returns: + str: The RSMI with updated atom mappings for both reactants and products. + """ + r, p = rsmi.split(">>") + return f"{NormalizeAAM.fix_atom_mapping(r)}>>{NormalizeAAM.fix_atom_mapping(p)}" + + @staticmethod + def extract_subgraph(graph: nx.Graph, indices: List[int]) -> nx.Graph: + """ + Extracts a subgraph from a given graph based on a list of node indices. + + Parameters: + graph (nx.Graph): The original graph from which to extract the subgraph. + indices (List[int]): A list of node indices that define the subgraph. + + Returns: + nx.Graph: The extracted subgraph. + """ + return graph.subgraph(indices).copy() + + def reset_indices_and_atom_map( + self, subgraph: nx.Graph, aam_key: str = "atom_map" + ) -> nx.Graph: + """ + Resets the node indices and the atom_map of the subgraph to be continuous from 1 onwards. + + Parameters: + subgraph (nx.Graph): The subgraph with possibly non-continuous indices. + aam_key (str): The attribute key for atom mapping. Defaults to 'atom_map'. + + Returns: + nx.Graph: A new subgraph with continuous indices and adjusted atom_map. + """ + new_graph = nx.Graph() + node_id_mapping = { + old_id: new_id for new_id, old_id in enumerate(subgraph.nodes(), 1) + } + for old_id, new_id in node_id_mapping.items(): + node_data = subgraph.nodes[old_id].copy() + node_data[aam_key] = new_id + new_graph.add_node(new_id, **node_data) + for u, v, data in subgraph.edges(data=True): + new_graph.add_edge(node_id_mapping[u], node_id_mapping[v], **data) + return new_graph + + def fit(self, rsmi: str, fix_aam_indice: bool = True) -> str: + """ + Processes a reaction SMILES (RSMI) to adjust atom mappings, extract reaction centers, + decompose into separate reactant and product graphs, and generate the corresponding SMILES. + + Parameters: + rsmi (str): The reaction SMILES string to be processed. + fix_aam_indice (bool): Whether to fix the atom mapping numbers. Defaults to True. + + Returns: + str: The resulting reaction SMILES string with updated atom mappings. + """ + if fix_aam_indice: + rsmi = self.fix_rsmi(rsmi) + r_graph, p_graph = rsmi_to_graph(rsmi, light_weight=True, sanitize=False) + its = ITSConstruction().ITSGraph(r_graph, p_graph) + rc = get_rc(its) + keep_indice = [ + indice + for indice, data in its.nodes(data=True) + if indice not in rc.nodes() and data["element"] != "H" + ] + keep_indice.extend(rc.nodes()) + subgraph = self.extract_subgraph(its, keep_indice) + subgraph = self.reset_indices_and_atom_map(subgraph) + r_graph, p_graph = its_decompose(subgraph) + r_mol, p_mol = GraphToMol().graph_to_mol( + r_graph, sanitize=False + ), GraphToMol().graph_to_mol(p_graph, sanitize=False) + return f"{Chem.MolToSmiles(r_mol)}>>{Chem.MolToSmiles(p_mol)}" diff --git a/synutility/SynAAM/partial_expand.py b/synutility/SynAAM/partial_expand.py new file mode 100644 index 0000000..0269e7b --- /dev/null +++ b/synutility/SynAAM/partial_expand.py @@ -0,0 +1,81 @@ +from synutility.SynIO.Format.nx_to_gml import NXToGML +from synutility.SynIO.Format.smi_to_graph import rsmi_to_graph +from synutility.SynIO.Format.its_construction import ITSConstruction + +from synutility.SynAAM.misc import its_decompose, get_rc +from synutility.SynAAM.normalize_aam import NormalizeAAM + +from synutility.SynChem.Reaction.standardize import Standardize + +from synutility.SynGraph.Transform.rule_apply import rule_apply, getReactionSmiles + + +class PartialExpand: + """ + A class for partially expanding reaction SMILES (RSMI) by applying transformation + rules based on the reaction center (RC) graph. + + Methods: + expand(rsmi: str) -> str: + Expands a reaction SMILES string and returns the transformed RSMI. + """ + + def __init__(self) -> None: + """ + Initializes the PartialExpand class. + """ + pass + + @staticmethod + def expand(rsmi: str) -> str: + """ + Expands a reaction SMILES string by identifying the reaction center (RC), + applying transformation rules, and standardizing the atom mappings. + + Parameters: + - rsmi (str): The input reaction SMILES string. + + Returns: + - str: The transformed reaction SMILES string. + """ + try: + # Convert RSMI to reactant and product graphs + r_graph, p_graph = rsmi_to_graph(rsmi, light_weight=True, sanitize=False) + + # Construct ITS (Intermediate Transition State) graph + its = ITSConstruction().ITSGraph(r_graph, p_graph) + + # Extract the reaction center (RC) graph + rc = get_rc(its) + + # Decompose the RC into reactant and product graphs + r_graph, p_graph = its_decompose(rc) + + # Transform the graph to a GML rule + rule = NXToGML().transform((r_graph, p_graph, rc)) + + # Standardize the input reaction SMILES + original_rsmi = Standardize().fit(rsmi) + + # Extract reactants from the standardized RSMI + reactants = original_rsmi.split(">>")[0].split(".") + + # Apply the transformation rule to the reactants + transformed_graph = rule_apply(reactants, rule) + + # Extract the transformed reaction SMILES + transformed_rsmi = list(getReactionSmiles(transformed_graph).values())[0][0] + + # Normalize atom mappings in the transformed RSMI + normalized_rsmi = NormalizeAAM().fit(transformed_rsmi) + + return normalized_rsmi + + except Exception as e: + print(f"An error occurred during RSMI expansion: {e}") + return rsmi + + +if __name__ == "__main__": + rsmi = "[CH3][CH:1]=[CH2:2].[H:3][H:4]>>[CH3][CH:1]([H:3])[CH2:2][H:4]" + print(PartialExpand.expand(rsmi)) diff --git a/synutility/SynMOD/rule_apply.py b/synutility/SynGraph/Transform/rule_apply.py similarity index 50% rename from synutility/SynMOD/rule_apply.py rename to synutility/SynGraph/Transform/rule_apply.py index 233faaa..0dee2d3 100644 --- a/synutility/SynMOD/rule_apply.py +++ b/synutility/SynGraph/Transform/rule_apply.py @@ -1,6 +1,8 @@ import os +import regex from synutility.SynIO.debug import setup_logging -from mod import smiles, ruleGMLString, DG, config +import torch +from mod import smiles, ruleGMLString, DG, config, DGVertexMapper logger = setup_logging() @@ -56,3 +58,50 @@ def rule_apply(smiles_list, rule, print_output=False): return dg except Exception as e: logger.error(f"An error occurred: {e}") + + +def getReactionSmiles(dg): + origSmiles = {} + for v in dg.vertices: + s = v.graph.smilesWithIds + s = regex.sub(":([0-9]+)]", ":o\\1]", s) + origSmiles[v.graph] = s + + res = {} + for e in dg.edges: + vms = DGVertexMapper(e, rightLimit=1, leftLimit=1) + # vms = DGVertexMapper(e) + eductSmiles = [origSmiles[g] for g in vms.left] + + for ev in vms.left.vertices: + s = eductSmiles[ev.graphIndex] + s = s.replace(f":o{ev.vertex.id}]", f":{ev.id}]") + eductSmiles[ev.graphIndex] = s + + strs = set() + for vm in DGVertexMapper(e, rightLimit=1, leftLimit=1): + # for vm in DGVertexMapper(e): + productSmiles = [origSmiles[g] for g in vms.right] + for ev in vms.left.vertices: + pv = vm.map[ev] + if not pv: + continue + s = productSmiles[pv.graphIndex] + s = s.replace(f":o{pv.vertex.id}]", f":{ev.id}]") + productSmiles[pv.graphIndex] = s + count = vms.left.numVertices + for pv in vms.right.vertices: + ev = vm.map.inverse(pv) + if ev: + continue + s = productSmiles[pv.graphIndex] + s = s.replace(f":o{pv.vertex.id}]", f":{count}]") + count += 1 + productSmiles[pv.graphIndex] = s + left = ".".join(eductSmiles) + right = ".".join(productSmiles) + s = f"{left}>>{right}" + assert ":o" not in s + strs.add(s) + res[e] = list(sorted(strs)) + return res diff --git a/synutility/SynIO/data_type.py b/synutility/SynIO/data_type.py index accc219..1a50e44 100644 --- a/synutility/SynIO/data_type.py +++ b/synutility/SynIO/data_type.py @@ -102,6 +102,27 @@ def load_gml_as_text(gml_file_path): return None +def save_text_as_gml(gml_text, file_path): + """ + Save a GML text string to a file. + + Parameters: + - gml_text (str): The GML content as a text string. + - file_path (str): The file path where the GML text will be saved. + + Returns: + - bool: True if saving was successful, False otherwise. + """ + try: + with open(file_path, "w") as file: + file.write(gml_text) + print(f"GML text successfully saved to {file_path}") + return True + except Exception as e: + print(f"An error occurred while saving the GML text: {e}") + return False + + def save_compressed(array: ndarray, filename: str) -> None: """ Saves a NumPy array in a compressed format using .npz extension. From 9db18e81befacc58d93d65bed0b9f29e7db69c52 Mon Sep 17 00:00:00 2001 From: TieuLongPhan Date: Mon, 18 Nov 2024 14:56:58 +0100 Subject: [PATCH 08/10] add testcase for partial expansion ver1 --- Test/SynAAM/test_partial_expand.py | 9 +++++++++ synutility/SynAAM/partial_expand.py | 1 + 2 files changed, 10 insertions(+) diff --git a/Test/SynAAM/test_partial_expand.py b/Test/SynAAM/test_partial_expand.py index ebe7594..0675bb9 100644 --- a/Test/SynAAM/test_partial_expand.py +++ b/Test/SynAAM/test_partial_expand.py @@ -18,6 +18,15 @@ def test_expand(self): # Assert the result matches the expected output self.assertEqual(output_rsmi, expected_rsmi) + def test_expand_2(self): + input_rsmi = "CC[CH2:3][Cl:1].[NH2:2][H:4]>>CC[CH2:3][NH2:2].[Cl:1][H:4]" + output_rsmi = PartialExpand.expand(input_rsmi) + expected_rsmi = ( + "[CH3:1][CH2:2][CH2:3][Cl:4].[NH2:5][H:6]" + + ">>[CH3:1][CH2:2][CH2:3][NH2:5].[Cl:4][H:6]" + ) + self.assertEqual(output_rsmi, expected_rsmi) + if __name__ == "__main__": unittest.main() diff --git a/synutility/SynAAM/partial_expand.py b/synutility/SynAAM/partial_expand.py index 0269e7b..e26d249 100644 --- a/synutility/SynAAM/partial_expand.py +++ b/synutility/SynAAM/partial_expand.py @@ -78,4 +78,5 @@ def expand(rsmi: str) -> str: if __name__ == "__main__": rsmi = "[CH3][CH:1]=[CH2:2].[H:3][H:4]>>[CH3][CH:1]([H:3])[CH2:2][H:4]" + rsmi = "CC[CH2:3][Cl:1].[NH2:2][H:4]>>CC[CH2:3][NH2:2].[Cl:1][H:4]" print(PartialExpand.expand(rsmi)) From d78aa0bf763a4cf0879d90644fcecec402f1c033 Mon Sep 17 00:00:00 2001 From: TieuLongPhan Date: Mon, 25 Nov 2024 14:59:38 +0100 Subject: [PATCH 09/10] update new features --- Test/SynAAM/test_aam_validator.py | 97 +++++++ .../test_its_construction.py | 2 +- Test/SynIO/Format/test_chemcal_conversion.py | 94 +++++++ synutility/SynAAM/aam_validator.py | 254 ++++++++++++++++++ .../Format => SynAAM}/its_construction.py | 0 synutility/SynAAM/misc.py | 125 ++++++++- synutility/SynAAM/normalize_aam.py | 4 +- synutility/SynAAM/partial_expand.py | 4 +- synutility/SynGraph/Morphism/__init__.py | 0 synutility/SynGraph/Morphism/misc.py | 29 ++ ...smi_to_graph.py => chemical_conversion.py} | 66 ++++- synutility/SynIO/Format/gml_to_nx.py | 4 +- synutility/SynIO/Format/mol_to_graph.py | 4 +- synutility/SynVis/graph_visualizer.py | 25 +- synutility/SynVis/rsmi_to_fig.py | 4 +- 15 files changed, 680 insertions(+), 32 deletions(-) create mode 100644 Test/SynAAM/test_aam_validator.py rename Test/{SynIO/Format => SynAAM}/test_its_construction.py (96%) create mode 100644 Test/SynIO/Format/test_chemcal_conversion.py create mode 100644 synutility/SynAAM/aam_validator.py rename synutility/{SynIO/Format => SynAAM}/its_construction.py (100%) create mode 100644 synutility/SynGraph/Morphism/__init__.py create mode 100644 synutility/SynGraph/Morphism/misc.py rename synutility/SynIO/Format/{smi_to_graph.py => chemical_conversion.py} (55%) diff --git a/Test/SynAAM/test_aam_validator.py b/Test/SynAAM/test_aam_validator.py new file mode 100644 index 0000000..94a7099 --- /dev/null +++ b/Test/SynAAM/test_aam_validator.py @@ -0,0 +1,97 @@ +import unittest +from synutility.SynAAM.aam_validator import AAMValidator + + +class TestAMMValidator(unittest.TestCase): + + def setUp(self): + self.true_pair = ( + ( + "[CH:8]=1[S:9][CH:10]=[C:6]([C:5]#[C:4][CH2:3][N:2]([C:11]2=[CH:12]" + + "[CH:13]=[CH:14][CH:15]=[CH:16]2)[CH3:1])[CH:7]=1.[OH2:17]>>[C:5]([N:2]" + + "([CH3:1])[C:11]1=[CH:12][CH:13]=[CH:14][CH:15]=[CH:16]1)([C:6]2=" + + "[CH:10][S:9][CH:8]=[CH:7]2)=[CH:4][CH:3]=[O:17]" + ), + ( + "[OH2:17].[cH:12]1[cH:13][cH:14][cH:15][cH:16][c:11]1[N:2]([CH3:1])" + + "[CH2:3][C:4]#[C:5][c:6]1[cH:10][s:9][cH:8][cH:7]1>>[cH:12]1[cH:13]" + + "[cH:14][cH:15][cH:16][c:11]1[N:2]([CH3:1])[C:5](=[CH:4][CH:3]=[O:17])" + + "[c:6]1[cH:10][s:9][cH:8][cH:7]1" + ), + ) + self.false_pair = ( + ( + "[CH:8]=1[S:9][CH:10]=[C:6]([C:5]#[C:4][CH2:3][N:2]([C:11]2=[CH:12]" + + "[CH:13]=[CH:14][CH:15]=[CH:16]2)[CH3:1])[CH:7]=1.[OH2:17]>>[C:5]" + + "([N:2]([CH3:1])[C:11]1=[CH:12][CH:13]=[CH:14][CH:15]=[CH:16]1)" + + "([C:6]2=[CH:10][S:9][CH:8]=[CH:7]2)=[CH:4][CH:3]=[O:17]" + ), + ( + "[CH3:1][N:2]([CH2:3][C:4]#[C:5][c:7]1[cH:8][cH:9][s:10][cH:11]1)" + + "[c:12]1[cH:13][cH:14][cH:15][cH:16][cH:17]1.[OH2:6]>>[CH3:1][N:2]" + + "([C:3](=[CH:4][CH:5]=[O:6])[c:7]1[cH:8][cH:9][s:10][cH:11]1)" + + "[c:12]1[cH:13][cH:14][cH:15][cH:16][cH:17]1" + ), + ) + self.tautomer = ( + "[CH3:1][C:2](=[O:3])[OH:4].[CH3:5][CH2:6][OH:7]>>[CH3:1][C:2](=[O:3])" + + "[O:7][CH2:6][CH3:5].[OH2:4]", + "[CH3:1][C:2](=[O:3])[OH:4].[CH3:5][CH2:6][OH:7]>>" + + "[CH3:1][C:2](=[O:4])[O:7][CH2:6][CH3:5].[OH2:3]", + ) + + self.data_dict_1 = {"ref": self.true_pair[0], "map": self.true_pair[1]} + self.data_dict_2 = {"ref": self.false_pair[0], "map": self.false_pair[1]} + self.data_dict_3 = {"ref": self.tautomer[0], "map": self.tautomer[1]} + self.data = [self.data_dict_1, self.data_dict_2, self.data_dict_3] + + def test_smiles_check(self): + + self.assertTrue( + AAMValidator.smiles_check( + *self.true_pair, check_method="RC", ignore_aromaticity=False + ) + ) + self.assertFalse( + AAMValidator.smiles_check( + *self.false_pair, check_method="RC", ignore_aromaticity=False + ) + ) + + def test_smiles_check_tautomer(self): + self.assertFalse( + AAMValidator.smiles_check( + self.tautomer[0], + self.tautomer[1], + check_method="RC", + ignore_aromaticity=False, + ) + ) + + self.assertTrue( + AAMValidator.smiles_check_tautomer( + self.tautomer[0], + self.tautomer[1], + check_method="RC", + ignore_aromaticity=True, + ) + ) + + def test_validate_smiles_dataframe(self): + + results = AAMValidator.validate_smiles( + data=self.data, + ground_truth_col="ref", + mapped_cols=["map"], + check_method="RC", + ignore_aromaticity=False, + n_jobs=2, + verbose=0, + ignore_tautomers=False, + ) + self.assertEqual(results[0]["accuracy"], 66.67) + self.assertEqual(results[0]["success_rate"], 100) + + +if __name__ == "__main__": + unittest.main() diff --git a/Test/SynIO/Format/test_its_construction.py b/Test/SynAAM/test_its_construction.py similarity index 96% rename from Test/SynIO/Format/test_its_construction.py rename to Test/SynAAM/test_its_construction.py index 0b34ebb..6d21afe 100644 --- a/Test/SynIO/Format/test_its_construction.py +++ b/Test/SynAAM/test_its_construction.py @@ -1,6 +1,6 @@ import unittest import networkx as nx -from synutility.SynIO.Format.its_construction import ITSConstruction +from synutility.SynAAM.its_construction import ITSConstruction class TestITSConstruction(unittest.TestCase): diff --git a/Test/SynIO/Format/test_chemcal_conversion.py b/Test/SynIO/Format/test_chemcal_conversion.py new file mode 100644 index 0000000..89df407 --- /dev/null +++ b/Test/SynIO/Format/test_chemcal_conversion.py @@ -0,0 +1,94 @@ +import unittest +import networkx as nx + +from synutility.SynChem.Reaction.standardize import Standardize +from synutility.SynIO.Format.chemical_conversion import ( + smiles_to_graph, + rsmi_to_graph, + graph_to_rsmi, + smart_to_gml, + gml_to_smart, +) + +from synutility.SynGraph.Morphism.misc import rule_isomorphism + + +class TestChemicalConversions(unittest.TestCase): + + def setUp(self) -> None: + self.rsmi = "[CH2:1]([H:4])[CH2:2][OH:3]>>[CH2:1]=[CH2:2].[H:4][OH:3]" + self.gml = ( + "rule [\n" + ' ruleID "rule"\n' + " left [\n" + ' edge [ source 1 target 4 label "-" ]\n' + ' edge [ source 1 target 2 label "-" ]\n' + ' edge [ source 2 target 3 label "-" ]\n' + " ]\n" + " context [\n" + ' node [ id 1 label "C" ]\n' + ' node [ id 4 label "H" ]\n' + ' node [ id 2 label "C" ]\n' + ' node [ id 3 label "O" ]\n' + " ]\n" + " right [\n" + ' edge [ source 1 target 2 label "=" ]\n' + ' edge [ source 4 target 3 label "-" ]\n' + " ]\n" + "]" + ) + + self.std = Standardize() + + def test_smiles_to_graph_valid(self): + # Test converting a valid SMILES to a graph + result = smiles_to_graph("[CH3:1][CH2:2][OH:3]", False, True, True) + self.assertIsInstance(result, nx.Graph) + self.assertEqual(result.number_of_nodes(), 3) + + def test_smiles_to_graph_invalid(self): + # Test converting an invalid SMILES string to a graph + result = smiles_to_graph("invalid_smiles", True, False, False) + self.assertIsNone(result) + + def test_rsmi_to_graph_valid(self): + # Test converting valid reaction SMILES to graphs for reactants and products + reactants_graph, products_graph = rsmi_to_graph(self.rsmi, sanitize=True) + self.assertIsInstance(reactants_graph, nx.Graph) + self.assertEqual(reactants_graph.number_of_nodes(), 3) + self.assertIsInstance(products_graph, nx.Graph) + self.assertEqual(products_graph.number_of_nodes(), 3) + + reactants_graph, products_graph = rsmi_to_graph(self.rsmi, sanitize=False) + self.assertIsInstance(reactants_graph, nx.Graph) + self.assertEqual(reactants_graph.number_of_nodes(), 4) + self.assertIsInstance(products_graph, nx.Graph) + self.assertEqual(products_graph.number_of_nodes(), 4) + + def test_rsmi_to_graph_invalid(self): + # Test handling of invalid RSMI format + result = rsmi_to_graph("invalid_format") + self.assertEqual((None, None), result) + + def test_graph_to_rsmi(self): + r, p = rsmi_to_graph(self.rsmi, sanitize=False) + rsmi = graph_to_rsmi(r, p) + self.assertIsInstance(rsmi, str) + self.assertEqual(self.std.fit(rsmi, False), self.std.fit(self.rsmi, False)) + + def test_smart_to_gml(self): + result = smart_to_gml(self.rsmi, core=False, sanitize=False, reindex=False) + self.assertIsInstance(result, str) + self.assertEqual(result, self.gml) + + result = smart_to_gml(self.rsmi, core=False, sanitize=False, reindex=True) + self.assertTrue(rule_isomorphism(result, self.gml)) + + def test_gml_to_smart(self): + smarts, _ = gml_to_smart(self.gml) + self.assertIsInstance(smarts, str) + self.assertEqual(self.std.fit(smarts, False), self.std.fit(self.rsmi, False)) + + +if __name__ == "__main__": + unittest.main() diff --git a/synutility/SynAAM/aam_validator.py b/synutility/SynAAM/aam_validator.py new file mode 100644 index 0000000..53e3b4c --- /dev/null +++ b/synutility/SynAAM/aam_validator.py @@ -0,0 +1,254 @@ +import pandas as pd +import networkx as nx +from operator import eq +from itertools import combinations +from joblib import Parallel, delayed +from typing import Dict, List, Tuple, Union, Optional +from networkx.algorithms.isomorphism import generic_node_match, generic_edge_match + +from synutility.SynAAM.its_construction import ITSConstruction +from synutility.SynIO.Format.chemical_conversion import rsmi_to_graph +from synutility.SynAAM.misc import get_rc, enumerate_tautomers, mapping_success_rate + + +class AAMValidator: + def __init__(self): + """Initializes the AAMValidator class.""" + pass + + @staticmethod + def check_equivariant_graph( + its_graphs: List[nx.Graph], + ) -> Tuple[List[Tuple[int, int]], int]: + """ + Checks for isomorphism among a list of ITS graphs and + identifies all pairs of isomorphic graphs. + + Parameters: + - its_graphs (List[nx.Graph]): A list of ITS graphs. + + Returns: + - List[Tuple[int, int]]: A list of tuples representing + pairs of indices of isomorphic graphs. + - int: The count of unique isomorphic graph pairs found. + """ + nodeLabelNames = ["typesGH"] + nodeLabelDefault = ["*", False, 0, 0, ()] + nodeLabelOperator = [eq, eq, eq, eq, eq] + nodeMatch = generic_node_match( + nodeLabelNames, nodeLabelDefault, nodeLabelOperator + ) + edgeMatch = generic_edge_match("order", 1, eq) + + classified = [] + for i, j in combinations(range(len(its_graphs)), 2): + if nx.is_isomorphic( + its_graphs[i], its_graphs[j], node_match=nodeMatch, edge_match=edgeMatch + ): + classified.append((i, j)) + + return classified, len(classified) + + @staticmethod + def smiles_check( + mapped_smile: str, + ground_truth: str, + check_method: str = "RC", # or 'ITS' + ignore_aromaticity: bool = False, + ) -> bool: + """ + Checks the equivalence of mapped SMILES against ground truth + using reaction center (RC) or ITS graph method. + + Parameters: + - mapped_smile (str): The mapped SMILES string. + - ground_truth (str): The ground truth SMILES string. + - check_method (str): The method used for validation ('RC' or 'ITS'). + - ignore_aromaticity (bool): Flag to ignore aromaticity in ITS graph construction. + + Returns: + - bool: True if the mapped SMILES is equivalent to the ground truth, + False otherwise. + """ + its_graphs = [] + rc_graphs = [] + try: + for rsmi in [mapped_smile, ground_truth]: + G, H = rsmi_to_graph( + rsmi=rsmi, sanitize=True, drop_non_aam=True, light_weight=True + ) + + ITS = ITSConstruction.ITSGraph(G, H, ignore_aromaticity) + its_graphs.append(ITS) + rc = get_rc(ITS) + rc_graphs.append(rc) + + _, equivariant = AAMValidator.check_equivariant_graph( + rc_graphs if check_method == "RC" else its_graphs + ) + return equivariant == 1 + + except Exception as e: + print("An error occurred:", str(e)) + return False + + @staticmethod + def smiles_check_tautomer( + mapped_smile: str, + ground_truth: str, + check_method: str = "RC", # or 'ITS' + ignore_aromaticity: bool = False, + ) -> Optional[bool]: + """ + Determines if a given mapped SMILE string is equivalent to any tautomer of + a ground truth SMILES string using a specified comparison method. + + Parameters: + - mapped_smile (str): The mapped SMILES string to check against the tautomers of + the ground truth. + - ground_truth (str): The reference SMILES string for generating possible + tautomers. + - check_method (str): The method used for checking equivalence. Default is 'RC'. + Possible values are 'RC' for reaction center or 'ITS'. + - ignore_aromaticity (bool): Flag to ignore differences in aromaticity between + the mapped SMILE and the tautomers.Default is False. + + Returns: + - Optional[bool]: True if the mapped SMILE matches any of the enumerated tautomers + of the ground truth according to the specified check method. + Returns False if no match is found. + Returns None if an error occurs during processing. + + Raises: + - Exception: If an error occurs during the tautomer enumeration + or the comparison process. + """ + try: + ground_truth_tautomers = enumerate_tautomers(ground_truth) + return any( + AAMValidator.smiles_check( + mapped_smile, t, check_method, ignore_aromaticity + ) + for t in ground_truth_tautomers + ) + except Exception as e: + print(f"An error occurred: {e}") + return None + + @staticmethod + def check_pair( + mapping: Dict[str, str], + mapped_col: str, + ground_truth_col: str, + check_method: str = "RC", + ignore_aromaticity: bool = False, + ignore_tautomers: bool = True, + ) -> bool: + """ + Checks the equivalence between the mapped and ground truth + values within a given mapping dictionary, using a specified check method. + The check can optionally ignore aromaticity. + + Parameters: + - mapping (Dict[str, str]): A dictionary containing the data entries to check. + - mapped_col (str): The key in the mapping dictionary corresponding + to the mapped value. + - ground_truth_col (str): The key in the mapping dictionary corresponding + to the ground truth value. + - check_method (str, optional): The method used for checking the equivalence. + Defaults to 'RC'. + - ignore_aromaticity (bool, optional): Flag to indicate whether aromaticity + should be ignored during the check. Defaults to False. + - ignore_tautomers (bool, optional): Flag to indicate whether tautomers + should be ignored during the check. Defaults to False. + + Returns: + - bool: The result of the check, indicating whether the mapped value is + equivalent to the ground truth according to the specified method + and considerations regarding aromaticity. + """ + if ignore_tautomers: + return AAMValidator.smiles_check( + mapping[mapped_col], + mapping[ground_truth_col], + check_method, + ignore_aromaticity, + ) + else: + return AAMValidator.smiles_check_tautomer( + mapping[mapped_col], + mapping[ground_truth_col], + check_method, + ignore_aromaticity, + ) + + @staticmethod + def validate_smiles( + data: Union[pd.DataFrame, List[Dict[str, str]]], + ground_truth_col: str = "ground_truth", + mapped_cols: List[str] = ["rxn_mapper", "graphormer", "local_mapper"], + check_method: str = "RC", + ignore_aromaticity: bool = False, + n_jobs: int = 1, + verbose: int = 0, + ignore_tautomers=True, + ) -> List[Dict[str, Union[str, float, List[bool]]]]: + """ + Validates collections of mapped SMILES against their ground truths for + multiple mappers and calculates the accuracy. + + Parameters: + - data (Union[pd.DataFrame, List[Dict[str, str]]]): + The input data containing mapped and ground truth SMILES. + - id_col (str): The name of the column or key containing the reaction ID. + - ground_truth_col (str): The name of the column or key containing + the ground truth SMILES. + - mapped_cols (List[str]): The list of columns or keys containing + the mapped SMILES for different mappers. + - check_method (str): The method used for validation ('RC' or 'ITS'). + - ignore_aromaticity (bool): Flag to ignore aromaticity in ITS graph construction. + - n_jobs (int): The number of parallel jobs to run. + - verbose (int): The verbosity level for joblib's parallel execution. + + Returns: + - List[Dict[str, Union[str, float, List[bool]]]]: A list of dictionaries, each + containing the mapper name, accuracy, and individual results for each SMILES pair. + """ + + validation_results = [] + + for mapped_col in mapped_cols: + + if isinstance(data, pd.DataFrame): + mappings = data.to_dict("records") + elif isinstance(data, list): + mappings = data + else: + raise ValueError( + "Data must be either a pandas DataFrame or a list of dictionaries." + ) + + results = Parallel(n_jobs=n_jobs, verbose=verbose)( + delayed(AAMValidator.check_pair)( + mapping, + mapped_col, + ground_truth_col, + check_method, + ignore_aromaticity, + ignore_tautomers, + ) + for mapping in mappings + ) + accuracy = sum(results) / len(mappings) if mappings else 0 + mapped_data = [value[mapped_col] for value in mappings] + + validation_results.append( + { + "mapper": mapped_col, + "accuracy": round(100 * accuracy, 2), + "results": results, + "success_rate": mapping_success_rate(mapped_data), + } + ) + + return validation_results diff --git a/synutility/SynIO/Format/its_construction.py b/synutility/SynAAM/its_construction.py similarity index 100% rename from synutility/SynIO/Format/its_construction.py rename to synutility/SynAAM/its_construction.py diff --git a/synutility/SynAAM/misc.py b/synutility/SynAAM/misc.py index b08fd07..3efdcfa 100644 --- a/synutility/SynAAM/misc.py +++ b/synutility/SynAAM/misc.py @@ -1,9 +1,14 @@ +import re import networkx as nx +from rdkit import Chem +from rdkit.Chem.MolStandardize import rdMolStandardize + +from typing import Optional, List def get_rc( ITS: nx.Graph, - element_key: str = "element", + element_key: list = ["element", "charge", "typesGH"], bond_key: str = "order", standard_key: str = "standard_order", ) -> nx.Graph: @@ -12,23 +17,27 @@ def get_rc( where the bond order changes, indicating a reaction event. Parameters: - ITS (nx.Graph): The ITS graph to extract the RC from. - element_key (str): Node attribute key for atom symbols. Defaults to 'element'. - bond_key (str): Edge attribute key for bond order. Defaults to 'order'. - standard_key (str): Edge attribute key for standard order information. Defaults to 'standard_order'. + - ITS (nx.Graph): The ITS graph to extract the RC from. + - element_key (list): List of node attribute keys for atom properties. + Defaults to ['element', 'charge', 'typesGH']. + - bond_key (str): Edge attribute key for bond order. Defaults to 'order'. + - standard_key (str): Edge attribute key for standard order information. + Defaults to 'standard_order'. Returns: - nx.Graph: A new graph representing the reaction center of the ITS. + - nx.Graph: A new graph representing the reaction center of the ITS. """ rc = nx.Graph() for n1, n2, data in ITS.edges(data=True): - if data[bond_key][0] != data[bond_key][1]: - rc.add_node(n1, **{element_key: ITS.nodes[n1][element_key]}) - rc.add_node(n2, **{element_key: ITS.nodes[n2][element_key]}) + if data.get(bond_key, [None, None])[0] != data.get(bond_key, [None, None])[1]: + rc.add_node( + n1, **{k: ITS.nodes[n1][k] for k in element_key if k in ITS.nodes[n1]} + ) + rc.add_node( + n2, **{k: ITS.nodes[n2][k] for k in element_key if k in ITS.nodes[n2]} + ) rc.add_edge( - n1, - n2, - **{bond_key: data[bond_key], standard_key: data[standard_key]}, + n1, n2, **{bond_key: data[bond_key], standard_key: data[standard_key]} ) return rc @@ -140,3 +149,95 @@ def compare_graphs( return False return True + + +def enumerate_tautomers(reaction_smiles: str) -> Optional[List[str]]: + """ + Enumerates possible tautomers for reactants while canonicalizing the products in a + reaction SMILES string. This function first splits the reaction SMILES string into + reactants and products. It then generates all possible tautomers for the reactants and + canonicalizes the product molecule. The function returns a list of reaction SMILES + strings for each tautomer of the reactants combined with the canonical product. + + Parameters: + - reaction_smiles (str): A SMILES string of the reaction formatted as + 'reactants>>products'. + + Returns: + - List[str] | None: A list of SMILES strings for the reaction, with each string + representing a different + - tautomer of the reactants combined with the canonicalized products. Returns None if + an error occurs or if invalid SMILES strings are provided. + + Raises: + - ValueError: If the provided SMILES strings cannot be converted to molecule objects, + indicating invalid input. + """ + try: + # Split the input reaction SMILES string into reactants and products + reactants_smiles, products_smiles = reaction_smiles.split(">>") + + # Convert SMILES strings to molecule objects + reactants_mol = Chem.MolFromSmiles(reactants_smiles) + products_mol = Chem.MolFromSmiles(products_smiles) + + if reactants_mol is None or products_mol is None: + raise ValueError( + "Invalid SMILES string provided for reactants or products." + ) + + # Initialize tautomer enumerator + + enumerator = rdMolStandardize.TautomerEnumerator() + + # Enumerate tautomers for the reactants and canonicalize the products + try: + reactants_can = enumerator.Enumerate(reactants_mol) + except Exception as e: + print(f"An error occurred: {e}") + reactants_can = [reactants_mol] + products_can = products_mol + + # Convert molecule objects back to SMILES strings + reactants_can_smiles = [Chem.MolToSmiles(i) for i in reactants_can] + products_can_smiles = Chem.MolToSmiles(products_can) + + # Combine each reactant tautomer with the canonical product in SMILES format + rsmi_list = [i + ">>" + products_can_smiles for i in reactants_can_smiles] + if len(rsmi_list) == 0: + return [reaction_smiles] + else: + # rsmi_list.remove(reaction_smiles) + rsmi_list.insert(0, reaction_smiles) + return rsmi_list + + except Exception as e: + print(f"An error occurred: {e}") + return [reaction_smiles] + + +def mapping_success_rate(list_mapping_data): + """ + Calculate the success rate of entries containing atom mappings in a list of data + strings. + + Parameters: + - list_mapping_in_data (list of str): List containing strings to be searched for atom + mappings. + + Returns: + - float: The success rate of finding atom mappings in the list as a percentage. + + Raises: + - ValueError: If the input list is empty. + """ + atom_map_pattern = re.compile(r":\d+") + if not list_mapping_data: + raise ValueError("The input list is empty, cannot calculate success rate.") + + success = sum( + 1 for entry in list_mapping_data if re.search(atom_map_pattern, entry) + ) + rate = 100 * (success / len(list_mapping_data)) + + return round(rate, 2) diff --git a/synutility/SynAAM/normalize_aam.py b/synutility/SynAAM/normalize_aam.py index 6e8c2d8..950548b 100644 --- a/synutility/SynAAM/normalize_aam.py +++ b/synutility/SynAAM/normalize_aam.py @@ -3,9 +3,9 @@ from rdkit import Chem from typing import List -from synutility.SynIO.Format.smi_to_graph import rsmi_to_graph +from synutility.SynIO.Format.chemical_conversion import rsmi_to_graph from synutility.SynIO.Format.graph_to_mol import GraphToMol -from synutility.SynIO.Format.its_construction import ITSConstruction +from synutility.SynAAM.its_construction import ITSConstruction from synutility.SynAAM.misc import its_decompose, get_rc diff --git a/synutility/SynAAM/partial_expand.py b/synutility/SynAAM/partial_expand.py index e26d249..ea8aee7 100644 --- a/synutility/SynAAM/partial_expand.py +++ b/synutility/SynAAM/partial_expand.py @@ -1,6 +1,6 @@ from synutility.SynIO.Format.nx_to_gml import NXToGML -from synutility.SynIO.Format.smi_to_graph import rsmi_to_graph -from synutility.SynIO.Format.its_construction import ITSConstruction +from synutility.SynIO.Format.chemical_conversion import rsmi_to_graph +from synutility.SynAAM.its_construction import ITSConstruction from synutility.SynAAM.misc import its_decompose, get_rc from synutility.SynAAM.normalize_aam import NormalizeAAM diff --git a/synutility/SynGraph/Morphism/__init__.py b/synutility/SynGraph/Morphism/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/synutility/SynGraph/Morphism/misc.py b/synutility/SynGraph/Morphism/misc.py new file mode 100644 index 0000000..2a1bb04 --- /dev/null +++ b/synutility/SynGraph/Morphism/misc.py @@ -0,0 +1,29 @@ +from mod import ruleGMLString + + +def rule_isomorphism(rule_1: str, rule_2: str) -> bool: + """ + Determines if two rule representations, given in GML format, are isomorphic. + + This function converts two GML strings into `ruleGMLString` objects and checks + if these two objects are isomorphic. Isomorphism here is determined by the method + `isomorphism` of the `ruleGMLString` class, which should return `1` for isomorphic + structures and `0` otherwise. + + Parameters: + - rule_1 (str): The GML string representation of the first rule. + - rule_2 (str): The GML string representation of the second rule. + + Returns: + - bool: `True` if the two rules are isomorphic; `False` otherwise. + + Raises: + - Any exceptions thrown by the `ruleGMLString` initialization or methods should + be documented here, if there are any known potential issues. + """ + # Create ruleGMLString objects from the GML strings + rule_obj_1 = ruleGMLString(rule_1) + rule_obj_2 = ruleGMLString(rule_2) + + # Check for isomorphism and return the result + return rule_obj_1.isomorphism(rule_obj_2) == 1 diff --git a/synutility/SynIO/Format/smi_to_graph.py b/synutility/SynIO/Format/chemical_conversion.py similarity index 55% rename from synutility/SynIO/Format/smi_to_graph.py rename to synutility/SynIO/Format/chemical_conversion.py index 8cd4386..6755eb7 100644 --- a/synutility/SynIO/Format/smi_to_graph.py +++ b/synutility/SynIO/Format/chemical_conversion.py @@ -4,9 +4,14 @@ from synutility.SynIO.debug import setup_logging from synutility.SynIO.Format.mol_to_graph import MolToGraph +from synutility.SynIO.Format.graph_to_mol import GraphToMol +from synutility.SynAAM.its_construction import ITSConstruction +from synutility.SynIO.Format.nx_to_gml import NXToGML +from synutility.SynIO.Format.gml_to_nx import GMLToNX +from synutility.SynAAM.misc import get_rc, its_decompose -logger = setup_logging +logger = setup_logging() def smiles_to_graph( @@ -22,7 +27,8 @@ def smiles_to_graph( - sanitize (bool): Whether to sanitize the molecule during conversion. Returns: - - nx.Graph or None: The networkx graph representation of the molecule, or None if conversion fails. + - nx.Graph or None: The networkx graph representation of the molecule, + or None if conversion fails. """ try: mol = Chem.MolFromSmiles(smiles, sanitize) @@ -66,3 +72,59 @@ def rsmi_to_graph( except ValueError: logger.error(f"Invalid RSMI format: {rsmi}") return (None, None) + + +def graph_to_rsmi(r: nx.Graph, p: nx.Graph) -> str: + """ + Converts graph representations of reactants and products to a reaction SMILES string. + + Parameters: + - r (nx.Graph): Graph of the reactants. + - p (nx.Graph): Graph of the products. + + Returns: + - str: Reaction SMILES string. + """ + r = GraphToMol().graph_to_mol(r) + p = GraphToMol().graph_to_mol(p) + return f"{Chem.MolToSmiles(r)}>>{Chem.MolToSmiles(p)}" + + +def smart_to_gml( + smart: str, + core: bool = True, + sanitize: bool = False, + rule_name: str = "rule", + reindex: bool = False, +) -> str: + """ + Converts a SMARTS string to GML format, optionally focusing on the reaction core. + + Parameters: + - smart (str): The SMARTS string representing the reaction. + - core (bool): Whether to extract and focus on the reaction core. Defaults to True. + + Returns: + - str: The GML representation of the reaction. + """ + r, p = rsmi_to_graph(smart, sanitize=sanitize) + its = ITSConstruction.ITSGraph(r, p) + if core: + its = get_rc(its) + r, p = its_decompose(its) + gml = NXToGML().transform((r, p, its), reindex=reindex, rule_name=rule_name) + return gml + + +def gml_to_smart(gml: str) -> str: + """ + Converts a GML string back to a SMARTS string by interpreting the graph structures. + + Parameters: + - gml (str): The GML string to convert. + + Returns: + - str: The corresponding SMARTS string. + """ + r, p, rc = GMLToNX(gml).transform() + return graph_to_rsmi(r, p), rc diff --git a/synutility/SynIO/Format/gml_to_nx.py b/synutility/SynIO/Format/gml_to_nx.py index 7e6d4f9..47a70a3 100644 --- a/synutility/SynIO/Format/gml_to_nx.py +++ b/synutility/SynIO/Format/gml_to_nx.py @@ -1,7 +1,7 @@ import networkx as nx import re from typing import Tuple -from synutility.SynIO.Format.its_construction import ITSConstruction +from synutility.SynAAM.its_construction import ITSConstruction class GMLToNX: @@ -90,7 +90,7 @@ def _extract_element_and_charge(self, label: str) -> Tuple[str, int]: is conservative and primarily for error handling. """ # Regex to separate the element symbols from the optional charge and sign - match = re.match(r"([A-Za-z]+)(\d+)?([+-])?$", label) + match = re.match(r"([A-Za-z*]+)(\d+)?([+-])?$", label) if not match: return ( "X", diff --git a/synutility/SynIO/Format/mol_to_graph.py b/synutility/SynIO/Format/mol_to_graph.py index af5cd51..3c44a28 100644 --- a/synutility/SynIO/Format/mol_to_graph.py +++ b/synutility/SynIO/Format/mol_to_graph.py @@ -143,7 +143,9 @@ def _create_light_weight_graph(cls, mol: Chem.Mol, drop_non_aam: bool) -> nx.Gra aromatic=atom.GetIsAromatic(), hcount=atom.GetTotalNumHs(), charge=atom.GetFormalCharge(), - neighbors=[neighbor.GetSymbol() for neighbor in atom.GetNeighbors()], + neighbors=sorted( + neighbor.GetSymbol() for neighbor in atom.GetNeighbors() + ), atom_map=atom_map, ) for bond in atom.GetBonds(): diff --git a/synutility/SynVis/graph_visualizer.py b/synutility/SynVis/graph_visualizer.py index acf7426..84179e7 100644 --- a/synutility/SynVis/graph_visualizer.py +++ b/synutility/SynVis/graph_visualizer.py @@ -4,12 +4,14 @@ Adaptations were made to enhance functionality and integrate with other system components. """ +import networkx as nx from rdkit import Chem from rdkit.Chem import rdDepictor + +import matplotlib.pyplot as plt from typing import Dict, Optional -import networkx as nx + from synutility.SynIO.Format.graph_to_mol import GraphToMol -import matplotlib.pyplot as plt class GraphVisualizer: @@ -42,7 +44,7 @@ def _get_its_as_mol(self, its: nx.Graph) -> Optional[Chem.Mol]: _its[u][v]["order"] = 1 return GraphToMol(self.node_attributes, self.edge_attributes).graph_to_mol( _its, False, False - ) # Ensure this function is defined correctly elsewhere + ) def plot_its( self, @@ -85,7 +87,7 @@ def plot_its( Returns: - None """ - bond_char = {None: "∅", 0: "∅", 1: "—", 2: "=", 3: "≡"} + bond_char = {None: "∅", 0: "∅", 1: "—", 2: "=", 3: "≡", 1.5: ":"} positions = self._calculate_positions(its, use_mol_coords) @@ -98,9 +100,9 @@ def plot_its( if use_edge_color: edge_colors = [ ( - "green" + "red" if data.get(standard_order_key, 0) > 0 - else "red" if data.get(standard_order_key, 0) < 0 else "black" + else "green" if data.get(standard_order_key, 0) < 0 else "black" ) for _, _, data in its.edges(data=True) ] @@ -198,7 +200,7 @@ def plot_as_mol( # Set default bond characters if not provided if bond_char is None: - bond_char = {None: "∅", 1: "—", 2: "=", 3: "≡"} + bond_char = {None: "∅", 1: "—", 2: "=", 3: "≡", 1.5: ":"} # Determine positions based on use_mol_coords flag if use_mol_coords: @@ -229,7 +231,14 @@ def plot_as_mol( # Preparing labels labels = {} for n, d in g.nodes(data=True): - label = f"{d.get(symbol_key, '')}" + charge = d.get("charge", 0) + if charge == 0: + charge = "" + elif charge > 0: + charge = f"{charge}+" if charge > 1 else "+" + else: + charge = f"{-charge}-" if charge < -1 else "-" + label = f"{d.get(symbol_key, '')}{charge}" if show_atom_map: label += f" ({d.get(aam_key, '')})" labels[n] = label diff --git a/synutility/SynVis/rsmi_to_fig.py b/synutility/SynVis/rsmi_to_fig.py index 690d326..a08520d 100644 --- a/synutility/SynVis/rsmi_to_fig.py +++ b/synutility/SynVis/rsmi_to_fig.py @@ -4,8 +4,8 @@ from synutility.SynVis.graph_visualizer import GraphVisualizer -from synutility.SynIO.Format.smi_to_graph import rsmi_to_graph -from synutility.SynIO.Format.its_construction import ITSConstruction +from synutility.SynIO.Format.chemical_conversion import rsmi_to_graph +from synutility.SynAAM.its_construction import ITSConstruction vis_graph = GraphVisualizer() From 4b10781437adafa5181944e7f9a62caf624f008f Mon Sep 17 00:00:00 2001 From: TieuLongPhan Date: Wed, 27 Nov 2024 14:46:29 +0100 Subject: [PATCH 10/10] prepare release --- .gitignore | 1 + Test/SynAAM/test_inference.py | 25 +++ Test/SynAAM/test_normalize_aam.py | 2 +- Test/SynAAM/test_partial_expand.py | 23 +- Test/SynGraph/Transform/test_multi_step.py | 55 +++++ lint.sh | 2 +- pyproject.toml | 2 +- synutility/SynAAM/inference.py | 73 ++++++ synutility/SynAAM/normalize_aam.py | 53 ++++- synutility/SynAAM/partial_expand.py | 102 ++++++--- synutility/SynGraph/Transform/core_engine.py | 58 ++--- synutility/SynGraph/Transform/multi_step.py | 223 +++++++++++++++++++ synutility/SynGraph/Transform/rule_apply.py | 115 ++++------ synutility/SynIO/Format/dg_to_gml.py | 2 + synutility/SynMOD/__init__.py | 0 15 files changed, 599 insertions(+), 137 deletions(-) create mode 100644 Test/SynAAM/test_inference.py create mode 100644 Test/SynGraph/Transform/test_multi_step.py create mode 100644 synutility/SynAAM/inference.py create mode 100644 synutility/SynGraph/Transform/multi_step.py delete mode 100644 synutility/SynMOD/__init__.py diff --git a/.gitignore b/.gitignore index 8e6144a..288988f 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,4 @@ *.json test_mod.py test_format.py +*dev_zone diff --git a/Test/SynAAM/test_inference.py b/Test/SynAAM/test_inference.py new file mode 100644 index 0000000..d83fe14 --- /dev/null +++ b/Test/SynAAM/test_inference.py @@ -0,0 +1,25 @@ +import unittest +from synutility.SynIO.Format.chemical_conversion import smart_to_gml +from synutility.SynAAM.inference import aam_infer + + +class TestAAMInference(unittest.TestCase): + + def setUp(self): + + self.rsmi = "BrCc1ccc(Br)cc1.COCCO>>Br.COCCOCc1ccc(Br)cc1" + self.gml = smart_to_gml("[Br:1][CH3:2].[OH:3][H:4]>>[Br:1][H:4].[CH3:2][OH:3]") + self.expect = ( + "[Br:1][CH2:2][C:3]1=[CH:4][CH:6]=[C:7]([Br:8])[CH:9]" + + "=[CH:5]1.[CH3:10][O:11][CH2:12][CH2:13][O:14][H:15]>>" + + "[Br:1][H:15].[CH2:2]([C:3]1=[CH:4][CH:6]=[C:7]([Br:8])" + + "[CH:9]=[CH:5]1)[O:14][CH2:13][CH2:12][O:11][CH3:10]" + ) + + def test_aam_infer(self): + result = aam_infer(self.rsmi, self.gml) + self.assertEqual(result[0], self.expect) + + +if __name__ == "__main__": + unittest.main() diff --git a/Test/SynAAM/test_normalize_aam.py b/Test/SynAAM/test_normalize_aam.py index 5ccd071..c3c022e 100644 --- a/Test/SynAAM/test_normalize_aam.py +++ b/Test/SynAAM/test_normalize_aam.py @@ -21,7 +21,7 @@ def test_fix_rsmi(self): for both reactants and products.""" input_rsmi = "[C:0]>>[C:1]" expected_rsmi = "[C:1]>>[C:2]" - self.assertEqual(self.normalizer.fix_rsmi(input_rsmi), expected_rsmi) + self.assertEqual(self.normalizer.fix_aam_rsmi(input_rsmi), expected_rsmi) def test_extract_subgraph(self): """Test extraction of a subgraph based on specified indices.""" diff --git a/Test/SynAAM/test_partial_expand.py b/Test/SynAAM/test_partial_expand.py index 0675bb9..fedc468 100644 --- a/Test/SynAAM/test_partial_expand.py +++ b/Test/SynAAM/test_partial_expand.py @@ -1,4 +1,8 @@ import unittest +from synutility.SynAAM.aam_validator import AAMValidator +from synutility.SynAAM.its_construction import ITSConstruction +from synutility.SynIO.Format.chemical_conversion import rsmi_to_graph + from synutility.SynAAM.partial_expand import PartialExpand @@ -16,7 +20,7 @@ def test_expand(self): # Perform the expansion output_rsmi = PartialExpand.expand(input_rsmi) # Assert the result matches the expected output - self.assertEqual(output_rsmi, expected_rsmi) + self.assertTrue(AAMValidator.smiles_check(output_rsmi, expected_rsmi, "ITS")) def test_expand_2(self): input_rsmi = "CC[CH2:3][Cl:1].[NH2:2][H:4]>>CC[CH2:3][NH2:2].[Cl:1][H:4]" @@ -25,7 +29,22 @@ def test_expand_2(self): "[CH3:1][CH2:2][CH2:3][Cl:4].[NH2:5][H:6]" + ">>[CH3:1][CH2:2][CH2:3][NH2:5].[Cl:4][H:6]" ) - self.assertEqual(output_rsmi, expected_rsmi) + self.assertTrue(AAMValidator.smiles_check(output_rsmi, expected_rsmi, "ITS")) + + def test_graph_expand(self): + input_rsmi = "BrCc1ccc(Br)cc1.COCCO>>Br.COCCOCc1ccc(Br)cc1" + expect = ( + "[Br:1][CH2:2][C:3]1=[CH:4][CH:6]=[C:7]([Br:8])[CH:9]=[CH:5]1." + + "[CH3:10][O:11][CH2:12][CH2:13][O:14][H:15]>>[Br:1][H:15]" + + ".[CH2:2]([C:3]1=[CH:4][CH:6]=[C:7]([Br:8])[CH:9]=[CH:5]1)" + + "[O:14][CH2:13][CH2:12][O:11][CH3:10]" + ) + r, p = rsmi_to_graph( + "[Br:1][CH3:2].[OH:3][H:4]>>[Br:1][H:4].[CH3:2][OH:3]", sanitize=False + ) + its = ITSConstruction().ITSGraph(r, p) + output = PartialExpand.graph_expand(its, input_rsmi) + self.assertTrue(AAMValidator.smiles_check(output, expect)) if __name__ == "__main__": diff --git a/Test/SynGraph/Transform/test_multi_step.py b/Test/SynGraph/Transform/test_multi_step.py new file mode 100644 index 0000000..cad6bf3 --- /dev/null +++ b/Test/SynGraph/Transform/test_multi_step.py @@ -0,0 +1,55 @@ +import unittest +from synutility.SynIO.Format.chemical_conversion import smart_to_gml +from synutility.SynGraph.Transform.multi_step import ( + perform_multi_step_reaction, + remove_reagent_from_smiles, + calculate_max_depth, + find_all_paths, +) + + +class TestMultiStep(unittest.TestCase): + def setUp(self) -> None: + smarts = [ + "[CH2:4]([CH:5]=[O:6])[H:7]>>[CH2:4]=[CH:5][O:6][H:7]", + ( + "[CH2:2]=[O:3].[CH2:4]=[CH:5][O:6][H:7]>>[CH2:2]([O:3][H:7])[CH2:4]" + + "[CH:5]=[O:6]" + ), + "[CH2:4]([CH:5]=[O:6])[H:8]>>[CH2:4]=[CH:5][O:6][H:8]", + ( + "[CH2:2]([OH:3])[CH:4]=[CH:5][O:6][H:8]>>[CH2:2]=[CH:4][CH:5]=[O:6]" + + ".[OH:3][H:8]" + ), + ] + self.gml = [smart_to_gml(value) for value in smarts] + self.order = [0, 1, 0, -1] + self.rsmi = "CC=O.CC=O.CCC=O>>CC=O.CC=C(C)C=O.O" + + def test_remove_reagent_from_smiles(self): + rsmi = remove_reagent_from_smiles(self.rsmi) + self.assertEqual(rsmi, "CC=O.CCC=O>>CC=C(C)C=O.O") + + def test_perform_multi_step_reaction(self): + results, _ = perform_multi_step_reaction(self.gml, self.order, self.rsmi) + self.assertEqual(len(results), 4) + + def test_calculate_max_depth(self): + _, reaction_tree = perform_multi_step_reaction(self.gml, self.order, self.rsmi) + max_depth = calculate_max_depth(reaction_tree) + self.assertEqual(max_depth, 4) + + def test_find_all_paths(self): + results, reaction_tree = perform_multi_step_reaction( + self.gml, self.order, self.rsmi + ) + target_products = sorted(self.rsmi.split(">>")[1].split(".")) + max_depth = len(results) + all_paths = find_all_paths(reaction_tree, target_products, self.rsmi, max_depth) + self.assertEqual(len(all_paths), 1) + real_path = all_paths[0][1:] # remove the original reaction + self.assertEqual(len(real_path), 4) + + +if __name__ == "__main__": + unittest.main() diff --git a/lint.sh b/lint.sh index c088240..b17d8cb 100755 --- a/lint.sh +++ b/lint.sh @@ -1,6 +1,6 @@ #!/bin/bash flake8 . --count --max-complexity=13 --max-line-length=120 \ - --per-file-ignores="__init__.py:F401, chemical_reaction_visualizer.py:E501, test_reagent.py:E501" \ + --per-file-ignores="__init__.py:F401, chemical_reaction_visualizer.py:E501, test_reagent.py:E501, inference.py:F401" \ --exclude venv,core_engine.py,rule_apply.py \ --statistics diff --git a/pyproject.toml b/pyproject.toml index b15c061..24607e9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "synutility" -version = "0.0.11" +version = "0.0.12" authors = [ {name="Tieu Long Phan", email="tieu@bioinf.uni-leipzig.de"} ] diff --git a/synutility/SynAAM/inference.py b/synutility/SynAAM/inference.py new file mode 100644 index 0000000..24ad997 --- /dev/null +++ b/synutility/SynAAM/inference.py @@ -0,0 +1,73 @@ +import torch +from typing import List, Any +from synutility.SynIO.Format.dg_to_gml import DGToGML +from synutility.SynAAM.normalize_aam import NormalizeAAM +from synutility.SynChem.Reaction.standardize import Standardize +from synutility.SynGraph.Transform.rule_apply import rule_apply + +std = Standardize() + + +def aam_infer(rsmi: str, gml: Any) -> List[str]: + """ + Infers a set of normalized SMILES from a reaction SMILES string and a graph model (GML). + + This function takes a reaction SMILES string (rsmi) and a graph model (gml), applies the + reaction transformation using the graph model, normalizes and standardizes the resulting + SMILES, and returns a list of SMILES that match the original reaction's structure after + normalization and standardization. + + Steps: + 1. The reactants in the reaction SMILES string are separated. + 2. The transformation is applied to the reactants using the provided graph model (gml). + 3. The resulting SMILES are transformed to a canonical form. + 4. The resulting SMILES are normalized and standardized. + 5. The function returns the normalized SMILES that match the original reaction SMILES. + + Parameters: + - rsmi (str): The reaction SMILES string in the form "reactants >> products". + - gml (Any): A graph model or data structure used for applying the reaction transformation. + + Returns: + - List[str]: A list of valid, normalized, and standardized SMILES strings that match the original reaction SMILES. + """ + # Split the input reaction SMILES into reactants and products + smiles = rsmi.split(">>")[0].split(".") + + # Apply the reaction transformation based on the graph model (GML) + dg = rule_apply(smiles, gml) + + # Get the transformed reaction SMILES from the graph + transformed_rsmi = list(DGToGML.getReactionSmiles(dg).values()) + transformed_rsmi = [value[0] for value in transformed_rsmi] + + # Normalize the transformed SMILES + normalized_rsmi = [] + for value in transformed_rsmi: + try: + value = NormalizeAAM().fit(value) + normalized_rsmi.append(value) + except Exception as e: + print(e) + continue + + # Standardize the normalized SMILES + curated_smiles = [] + for value in normalized_rsmi: + try: + curated_smiles.append(std.fit(value)) + except Exception as e: + print(e) + curated_smiles.append(None) + continue + + # Standardize the original SMILES for comparison + org_smiles = std.fit(rsmi) + + # Filter out the SMILES that match the original reaction SMILES + final = [] + for key, value in enumerate(curated_smiles): + if value == org_smiles: + final.append(normalized_rsmi[key]) + + return final diff --git a/synutility/SynAAM/normalize_aam.py b/synutility/SynAAM/normalize_aam.py index 950548b..f837493 100644 --- a/synutility/SynAAM/normalize_aam.py +++ b/synutility/SynAAM/normalize_aam.py @@ -50,7 +50,7 @@ def fix_atom_mapping(smiles: str) -> str: return pattern.sub(NormalizeAAM.increment, smiles) @staticmethod - def fix_rsmi(rsmi: str) -> str: + def fix_aam_rsmi(rsmi: str) -> str: """ Adjusts atom mapping numbers in both reactant and product parts of a reaction SMILES (RSMI). @@ -63,6 +63,54 @@ def fix_rsmi(rsmi: str) -> str: r, p = rsmi.split(">>") return f"{NormalizeAAM.fix_atom_mapping(r)}>>{NormalizeAAM.fix_atom_mapping(p)}" + @staticmethod + def fix_rsmi_kekulize(rsmi: str) -> str: + """ + Filters the reactants and products of a reaction SMILES string. + + Parameters: + - rsmi (str): A string representing the reaction SMILES in the form of "reactants >> products". + + Returns: + - str: A filtered reaction SMILES string where invalid reactants/products are removed. + """ + # Split the reaction into reactants and products + reactants, products = rsmi.split(">>") + + # Filter valid reactants and products + filtered_reactants = NormalizeAAM.fix_kekulize(reactants) + filtered_products = NormalizeAAM.fix_kekulize(products) + + # Return the filtered reaction SMILES + return f"{filtered_reactants}>>{filtered_products}" + + @staticmethod + def fix_kekulize(smiles: str) -> str: + """ + Filters and returns valid SMILES strings from a string of SMILES, joined by '.'. + + This function processes a string of SMILES separated by periods (e.g., "CCO.CC=O"), + filters out invalid SMILES, and returns a string of valid SMILES joined by periods. + + Parameters: + - smiles (str): A string containing SMILES strings separated by periods ('.'). + + Returns: + - str: A string of valid SMILES, joined by periods ('.'). + """ + smiles_list = smiles.split(".") # Split SMILES by period + valid_smiles = [] # List to store valid SMILES strings + + for smile in smiles_list: + mol = Chem.MolFromSmiles(smile, sanitize=False) + if mol: # Check if molecule is valid + valid_smiles.append( + Chem.MolToSmiles( + mol, canonical=True, kekuleSmiles=True, allHsExplicit=True + ) + ) + return ".".join(valid_smiles) # Return valid SMILES joined by '.' + @staticmethod def extract_subgraph(graph: nx.Graph, indices: List[int]) -> nx.Graph: """ @@ -114,8 +162,9 @@ def fit(self, rsmi: str, fix_aam_indice: bool = True) -> str: Returns: str: The resulting reaction SMILES string with updated atom mappings. """ + rsmi = self.fix_rsmi_kekulize(rsmi) if fix_aam_indice: - rsmi = self.fix_rsmi(rsmi) + rsmi = self.fix_aam_rsmi(rsmi) r_graph, p_graph = rsmi_to_graph(rsmi, light_weight=True, sanitize=False) its = ITSConstruction().ITSGraph(r_graph, p_graph) rc = get_rc(its) diff --git a/synutility/SynAAM/partial_expand.py b/synutility/SynAAM/partial_expand.py index ea8aee7..71c5964 100644 --- a/synutility/SynAAM/partial_expand.py +++ b/synutility/SynAAM/partial_expand.py @@ -1,13 +1,13 @@ +import networkx as nx from synutility.SynIO.Format.nx_to_gml import NXToGML from synutility.SynIO.Format.chemical_conversion import rsmi_to_graph -from synutility.SynAAM.its_construction import ITSConstruction from synutility.SynAAM.misc import its_decompose, get_rc -from synutility.SynAAM.normalize_aam import NormalizeAAM - +from synutility.SynAAM.its_construction import ITSConstruction from synutility.SynChem.Reaction.standardize import Standardize +from synutility.SynAAM.inference import aam_infer -from synutility.SynGraph.Transform.rule_apply import rule_apply, getReactionSmiles +std = Standardize() class PartialExpand: @@ -15,68 +15,102 @@ class PartialExpand: A class for partially expanding reaction SMILES (RSMI) by applying transformation rules based on the reaction center (RC) graph. + This class provides methods for expanding a given RSMI by identifying the + reaction center (RC), applying transformation rules, and standardizing atom mappings + to generate a full AAM RSMI. + Methods: - expand(rsmi: str) -> str: - Expands a reaction SMILES string and returns the transformed RSMI. + - expand(rsmi: str) -> str: + Expands a reaction SMILES string by identifying the reaction center (RC), + applying transformation rules, and standardizing atom mappings. + + - graph_expand(partial_its: nx.Graph, rsmi: str) -> str: + Expands a reaction SMILES string using an Imaginary Transition State + (ITS) graph and applies the transformation rule based on the reaction center (RC). """ def __init__(self) -> None: """ Initializes the PartialExpand class. + + This constructor currently does not initialize any instance-specific attributes. """ pass @staticmethod - def expand(rsmi: str) -> str: + def graph_expand(partial_its: nx.Graph, rsmi: str) -> str: """ - Expands a reaction SMILES string by identifying the reaction center (RC), - applying transformation rules, and standardizing the atom mappings. + Expands a reaction SMILES string by applying transformation rules using an + ITS graph based on the reaction center (RC) graph. + + This method extracts the reaction center (RC) from the ITS graph, decomposes it + into reactant and product graphs, generates a GML rule for transformation, + and applies the rule to the RSMI string. Parameters: - - rsmi (str): The input reaction SMILES string. + - partial_its (nx.Graph): The Intermediate Transition State (ITS) graph. + - rsmi (str): The input reaction SMILES string to be expanded. Returns: - - str: The transformed reaction SMILES string. + - str: The transformed reaction SMILES string after applying the + transformation rules. """ - try: - # Convert RSMI to reactant and product graphs - r_graph, p_graph = rsmi_to_graph(rsmi, light_weight=True, sanitize=False) + # Extract the reaction center (RC) graph from the ITS graph + rc = get_rc(partial_its) - # Construct ITS (Intermediate Transition State) graph - its = ITSConstruction().ITSGraph(r_graph, p_graph) + # Decompose the RC into reactant and product graphs + r_graph, p_graph = its_decompose(rc) + + # Transform the graph into a GML rule + rule = NXToGML().transform((r_graph, p_graph, rc)) - # Extract the reaction center (RC) graph - rc = get_rc(its) + # Apply the transformation rule to the RSMI + transformed_rsmi = aam_infer(rsmi, rule)[0] - # Decompose the RC into reactant and product graphs - r_graph, p_graph = its_decompose(rc) + return transformed_rsmi - # Transform the graph to a GML rule - rule = NXToGML().transform((r_graph, p_graph, rc)) + @staticmethod + def expand(rsmi: str) -> str: + """ + Expands a reaction SMILES string by identifying the reaction center (RC), + applying transformation rules, and standardizing the atom mappings. - # Standardize the input reaction SMILES - original_rsmi = Standardize().fit(rsmi) + This method constructs the Intermediate Transition State (ITS) graph from the + input RSMI, applies the reaction transformation rules using `graph_expand`, + and returns the transformed reaction SMILES string. - # Extract reactants from the standardized RSMI - reactants = original_rsmi.split(">>")[0].split(".") + Parameters: + - rsmi (str): The input reaction SMILES string to be expanded. - # Apply the transformation rule to the reactants - transformed_graph = rule_apply(reactants, rule) + Returns: + - str: The transformed reaction SMILES string after applying the + transformation rules. - # Extract the transformed reaction SMILES - transformed_rsmi = list(getReactionSmiles(transformed_graph).values())[0][0] + Raises: + - Exception: If an error occurs during the expansion process, the original RSMI + is returned. + """ + try: + # Convert RSMI to reactant and product graphs + r_graph, p_graph = rsmi_to_graph(rsmi, light_weight=True, sanitize=False) - # Normalize atom mappings in the transformed RSMI - normalized_rsmi = NormalizeAAM().fit(transformed_rsmi) + # Construct the ITS graph from the reactant and product graphs + its = ITSConstruction().ITSGraph(r_graph, p_graph) - return normalized_rsmi + # Standardize smiles + rsmi = std.fit(rsmi) + # Apply graph expansion and return the result + return PartialExpand.graph_expand(its, rsmi) except Exception as e: + # Log the error and return the original RSMI if something goes wrong print(f"An error occurred during RSMI expansion: {e}") - return rsmi + return None if __name__ == "__main__": rsmi = "[CH3][CH:1]=[CH2:2].[H:3][H:4]>>[CH3][CH:1]([H:3])[CH2:2][H:4]" rsmi = "CC[CH2:3][Cl:1].[NH2:2][H:4]>>CC[CH2:3][NH2:2].[Cl:1][H:4]" print(PartialExpand.expand(rsmi)) +# self.rsmi = "BrCc1ccc(Br)cc1.COCCO>>Br.COCCOCc1ccc(Br)cc1" +# self.gml = smart_to_gml("[Br:1][CH3:2].[OH:3][H:4]>>[Br:1][H:4].[CH3:2][OH:3]") diff --git a/synutility/SynGraph/Transform/core_engine.py b/synutility/SynGraph/Transform/core_engine.py index 9b1c776..7f237a5 100644 --- a/synutility/SynGraph/Transform/core_engine.py +++ b/synutility/SynGraph/Transform/core_engine.py @@ -1,7 +1,9 @@ -from typing import List -from synutility.SynIO.data_type import load_gml_as_text from rdkit import Chem -from copy import deepcopy +from pathlib import Path +from typing import List, Union +from collections import Counter +from synutility.SynIO.data_type import load_gml_as_text + import torch from mod import * @@ -49,7 +51,7 @@ def generate_reaction_smiles( @staticmethod def perform_reaction( - rule_file_path: str, + rule_file_path: Union[str, str], initial_smiles: List[str], prediction_type: str = "forward", print_results: bool = False, @@ -94,7 +96,16 @@ def deduplicateGraphs(initial): initial_molecules, key=lambda molecule: molecule.numVertices, reverse=False ) # Load the reaction rule from the GML file - gml_content = load_gml_as_text(rule_file_path) + rule_path = Path(rule_file_path) + + try: + if rule_path.is_file(): + gml_content = load_gml_as_text(rule_file_path) + else: + gml_content = rule_file_path + except Exception as e: + # print(f"An error occurred while loading the GML file: {e}") + gml_content = rule_file_path reaction_rule = ruleGMLString(gml_content, invert=invert_rule, add=False) # Initialize the derivation graph and execute the strategy dg = DG(graphDatabase=initial_molecules) @@ -107,8 +118,10 @@ def deduplicateGraphs(initial): for e in dg.edges: productSmiles = [v.graph.smiles for v in e.targets] temp_results.append(productSmiles) + # print(productSmiles) if len(temp_results) == 0: + # print(1) dg = DG(graphDatabase=initial_molecules) # dg.build().execute(strategy, verbosity=8) config.dg.doRuleIsomorphismDuringBinding = False @@ -118,27 +131,22 @@ def deduplicateGraphs(initial): temp_results, small_educt = [], [] for edge in dg.edges: temp_results.append([vertex.graph.smiles for vertex in edge.targets]) - small_educt.extend([vertex.graph.smiles for vertex in edge.sources]) - - small_educt_set = [ - Chem.CanonSmiles(smile) for smile in small_educt if smile is not None - ] - - reagent = deepcopy(initial_smiles) - for value in small_educt_set: - if value in reagent: - reagent.remove(value) - - # Update solutions with reagents and normalize SMILES - for solution in temp_results: + small_educt.append([vertex.graph.smiles for vertex in edge.sources]) + + for key, solution in enumerate(temp_results): + educt = small_educt[key] + small_educt_counts = Counter( + Chem.CanonSmiles(smile) for smile in educt if smile is not None + ) + reagent_counts = Counter([Chem.CanonSmiles(s) for s in initial_smiles]) + reagent_counts.subtract(small_educt_counts) + reagent = [ + smile + for smile, count in reagent_counts.items() + for _ in range(count) + if count > 0 + ] solution.extend(reagent) - for i, smile in enumerate(solution): - try: - mol = Chem.MolFromSmiles(smile) - if mol: # Only convert if mol creation was successful - solution[i] = Chem.MolToSmiles(mol) - except Exception as e: - print(f"Error processing SMILES {smile}: {str(e)}") reaction_processing_map = { "forward": lambda smiles: CoreEngine.generate_reaction_smiles( diff --git a/synutility/SynGraph/Transform/multi_step.py b/synutility/SynGraph/Transform/multi_step.py new file mode 100644 index 0000000..b25f2a4 --- /dev/null +++ b/synutility/SynGraph/Transform/multi_step.py @@ -0,0 +1,223 @@ +from collections import Counter +from typing import List, Dict, Tuple +from synutility.SynChem.Reaction.standardize import Standardize +from synutility.SynGraph.Transform.core_engine import CoreEngine + +std = Standardize() + + +def remove_reagent_from_smiles(rsmi: str) -> str: + """ + Removes common molecules from the reactants and products in a SMILES reaction string. + + This function identifies the molecules that appear on both sides of the reaction + (reactants and products) and removes one occurrence of each common molecule from + both sides. + + Parameters: + - rsmi (str): A SMILES string representing a chemical reaction in the form: + 'reactant1.reactant2...>>product1.product2...' + + Returns: + - str: A new SMILES string with the common molecules removed, in the form: + 'reactant1.reactant2...>>product1.product2...' + + Example: + >>> remove_reagent_from_smiles('CC=O.CC=O.CCC=O>>CC=CO.CC=O.CC=O') + 'CCC=O>>CC=CO' + """ + + # Split the input SMILES string into reactants and products + reactants, products = rsmi.split(">>") + + # Split the reactants and products by '.' to separate molecules + reactant_molecules = reactants.split(".") + product_molecules = products.split(".") + + # Count the occurrences of each molecule in reactants and products + reactant_count = Counter(reactant_molecules) + product_count = Counter(product_molecules) + + # Find common molecules between reactants and products + common_molecules = set(reactant_count) & set(product_count) + + # Remove common molecules by the minimum occurrences in both reactants and products + for molecule in common_molecules: + common_occurrences = min(reactant_count[molecule], product_count[molecule]) + + # Decrease the count by the common occurrences + reactant_count[molecule] -= common_occurrences + product_count[molecule] -= common_occurrences + + # Rebuild the lists of reactant and product molecules after removal + filtered_reactant_molecules = [ + molecule for molecule, count in reactant_count.items() for _ in range(count) + ] + filtered_product_molecules = [ + molecule for molecule, count in product_count.items() for _ in range(count) + ] + + # Join the remaining molecules back into SMILES strings + new_reactants = ".".join(filtered_reactant_molecules) + new_products = ".".join(filtered_product_molecules) + + # Return the updated reaction string + return f"{new_reactants}>>{new_products}" + + +def perform_multi_step_reaction( + gml_list: List[str], order: List[int], rsmi: str +) -> Tuple[List[List[str]], Dict[str, List[str]]]: + """ + Applies a sequence of multi-step reactions to a starting SMILES string. The function + processes each reaction step in a specified order, and returns both the intermediate + and final products, as well as a mapping of reactant SMILES to their + corresponding products. + + Parameters: + - gml_list (List[str]): A list of reaction rules (in GML format) to be applied. + Each element corresponds to a reaction step. + - order (List[int]): A list of integers that defines the order in which the + reaction steps should be applied. Each integer is an index referring to the position + of a reaction rule in the `gml_list`. + - rsmi (str): The starting reaction SMILES string, representing the reactants for the + first reaction. + + Returns: + - Tuple[List[List[str]], Dict[str, List[str]]]: + - A list of lists of SMILES strings, where each inner list contains the + RSMI generated at each reaction step. + - A dictionary mapping each RSMI string to the resulting products after applying + the reaction rules. The keys are the input RSMIs, and the values are the + resulting product SMILES strings. + """ + + # Initialize CoreEngine for reaction processing + core = CoreEngine() + # Initialize a dictionary to hold reaction results + reaction_results = {} + + # List to store the results of each reaction step + all_steps: List[List[str]] = [] + result: List[str] = [rsmi] # Initial result is the input SMILES string + + # Loop over the reaction steps in the specified order + for i, j in enumerate(order): + # Get the reaction SMILES (RSMI) for the current step + current_step_gml = gml_list[j] + new_result: List[str] = [] # List to hold products for this step + + # Apply the reaction for each current reactant SMILES + for current_rsmi in result: + smi_lst = ( + current_rsmi.split(">>")[0].split( + "." + ) # Split reactants at the first step + if i == 0 + else current_rsmi.split(">>")[1].split( + "." + ) # Split products for subsequent steps + ) + + # Perform the reaction using the CoreEngine + o = core.perform_reaction(current_step_gml, smi_lst) + + # Apply standardization on the products + o = [std.fit(i) for i in o] + + # Collect the new results (products) from this reaction step + new_result.extend(o) + + # Record the reaction results in the dictionary, mapping input RSMI to output products + if len(o) > 0: + reaction_results[current_rsmi] = o + + # Update the result list for the next step + result = new_result + + # Append the results of this step to the overall steps list + all_steps.append(result) + + # Return the results: a list of all steps and a dictionary of reaction results + return all_steps, reaction_results + + +def calculate_max_depth(reaction_tree, current_node=None, depth=0): + """ + Calculate the maximum depth of a reaction tree. + + Parameters: + - reaction_tree (dict): A dictionary where keys are reaction SMILES (RSMI) + and values are lists of product reactions. + - current_node (str): The current node in the tree being explored (reaction SMILES). + - depth (int): The current depth of the tree. + + Returns: + - int: The maximum depth of the tree. + """ + # If current_node is None, start from the root node (first key in the reaction tree) + if current_node is None: + current_node = list(reaction_tree.keys())[0] + + # Get the products of the current node (reaction) + products = reaction_tree.get(current_node, []) + + # If no products, we are at a leaf node, return the current depth + if not products: + return depth + + # Recursively calculate the depth for each product and return the maximum + max_subtree_depth = max( + calculate_max_depth(reaction_tree, product, depth + 1) for product in products + ) + return max_subtree_depth + + +def find_all_paths( + reaction_tree, + target_products, + current_node, + target_depth, + current_depth=0, + path=None, +): + """ + Recursively find all paths from the root to the maximum depth in the reaction tree. + + Parameters: + - reaction_tree (dict): A dictionary of reaction SMILES with products. + - current_node (str): The current node (reaction SMILES). + - target_depth (int): The depth at which the product matches the root's product. + - current_depth (int): The current depth of the search. + - path (list): The current path in the tree. + + Returns: + - List of all paths to the max depth. + """ + if path is None: + path = [] + + # Add the current node (reaction SMILES) to the path + path.append(current_node) + + # If we have reached the target depth, check the product + if current_depth == target_depth: + # Extract products of the current node + products = sorted(current_node.split(">>")[1].split(".")) + return [path] if products == target_products else [] + + # If we haven't reached the target depth, recurse on the products + paths = [] + for product in reaction_tree.get(current_node, []): + paths.extend( + find_all_paths( + reaction_tree, + target_products, + product, + target_depth, + current_depth + 1, + path.copy(), + ) + ) + + return paths diff --git a/synutility/SynGraph/Transform/rule_apply.py b/synutility/SynGraph/Transform/rule_apply.py index 0dee2d3..9836e0a 100644 --- a/synutility/SynGraph/Transform/rule_apply.py +++ b/synutility/SynGraph/Transform/rule_apply.py @@ -1,107 +1,80 @@ import os -import regex +from typing import List from synutility.SynIO.debug import setup_logging import torch -from mod import smiles, ruleGMLString, DG, config, DGVertexMapper +from mod import smiles, ruleGMLString, DG, config logger = setup_logging() def deduplicateGraphs(initial): - """ - Removes duplicate graphs from a list based on graph isomorphism. + res = [] + for cand in initial: + for a in res: + if cand.isomorphism(a) != 0: + res.append(a) # the one we had already + break + else: + # didn't find any isomorphic, use the new one + res.append(cand) + return res - Parameters: - - initial (list): List of graph objects. - Returns: - - List of unique graph objects. +def rule_apply( + smiles_list: List[str], rule: str, verbose: int = 0, print_output: bool = False +) -> DG: """ - unique_graphs = [] - for candidate in initial: - # Check if candidate is isomorphic to any graph already in unique_graphs - if not any(candidate.isomorphism(existing) != 0 for existing in unique_graphs): - unique_graphs.append(candidate) - return unique_graphs + Applies a reaction rule to a list of SMILES strings and optionally prints + the derivation graph. - -def rule_apply(smiles_list, rule, print_output=False): - """ - Applies a reaction rule to a list of SMILES and optionally prints the output. + This function first converts the SMILES strings into molecular graphs, + deduplicates them, sorts them based on the number of vertices, and + then applies the provided reaction rule in the GML string format. + The resulting derivation graph (DG) is returned. Parameters: - - smiles_list (list): List of SMILES strings. - - rule (str): Reaction rule in GML string format. - - print_output (bool): If True, output will be printed to a directory. + - smiles_list (List[str]): A list of SMILES strings representing the molecules + to which the reaction rule will be applied. + - rule (str): The reaction rule in GML string format. This rule will be applied to the + molecules represented by the SMILES strings. + - verbose (int, optional): The verbosity level for logging or debugging. + Default is 0 (no verbosity). + - print_output (bool, optional): If True, the derivation graph will be printed + to the "out" directory. Default is False. Returns: - - dg (DG): The derivation graph after applying the rule. + - DG: The derivation graph (DG) after applying the reaction rule to the + initial molecules. + + Raises: + - Exception: If an error occurs during the process of applying the rule, + an exception is raised. """ try: + # Convert SMILES strings to molecular graphs and deduplicate initial_molecules = [smiles(smile, add=False) for smile in smiles_list] initial_molecules = deduplicateGraphs(initial_molecules) + + # Sort molecules based on the number of vertices initial_molecules = sorted( initial_molecules, key=lambda molecule: molecule.numVertices, reverse=False ) + # Convert the reaction rule from GML string format to a reaction rule object reaction_rule = ruleGMLString(rule) + # Create the derivation graph and apply the reaction rule dg = DG(graphDatabase=initial_molecules) config.dg.doRuleIsomorphismDuringBinding = False - dg.build().apply(initial_molecules, reaction_rule, verbosity=8) + dg.build().apply(initial_molecules, reaction_rule, verbosity=verbose) - # Optionally print the output + # Optionally print the output to a directory if print_output: os.makedirs("out", exist_ok=True) dg.print() return dg + except Exception as e: logger.error(f"An error occurred: {e}") - - -def getReactionSmiles(dg): - origSmiles = {} - for v in dg.vertices: - s = v.graph.smilesWithIds - s = regex.sub(":([0-9]+)]", ":o\\1]", s) - origSmiles[v.graph] = s - - res = {} - for e in dg.edges: - vms = DGVertexMapper(e, rightLimit=1, leftLimit=1) - # vms = DGVertexMapper(e) - eductSmiles = [origSmiles[g] for g in vms.left] - - for ev in vms.left.vertices: - s = eductSmiles[ev.graphIndex] - s = s.replace(f":o{ev.vertex.id}]", f":{ev.id}]") - eductSmiles[ev.graphIndex] = s - - strs = set() - for vm in DGVertexMapper(e, rightLimit=1, leftLimit=1): - # for vm in DGVertexMapper(e): - productSmiles = [origSmiles[g] for g in vms.right] - for ev in vms.left.vertices: - pv = vm.map[ev] - if not pv: - continue - s = productSmiles[pv.graphIndex] - s = s.replace(f":o{pv.vertex.id}]", f":{ev.id}]") - productSmiles[pv.graphIndex] = s - count = vms.left.numVertices - for pv in vms.right.vertices: - ev = vm.map.inverse(pv) - if ev: - continue - s = productSmiles[pv.graphIndex] - s = s.replace(f":o{pv.vertex.id}]", f":{count}]") - count += 1 - productSmiles[pv.graphIndex] = s - left = ".".join(eductSmiles) - right = ".".join(productSmiles) - s = f"{left}>>{right}" - assert ":o" not in s - strs.add(s) - res[e] = list(sorted(strs)) - return res + raise diff --git a/synutility/SynIO/Format/dg_to_gml.py b/synutility/SynIO/Format/dg_to_gml.py index b2d079d..321a74b 100644 --- a/synutility/SynIO/Format/dg_to_gml.py +++ b/synutility/SynIO/Format/dg_to_gml.py @@ -22,6 +22,7 @@ def getReactionSmiles(dg): res = {} for e in dg.edges: vms = DGVertexMapper(e, rightLimit=1, leftLimit=1) + # vms = DGVertexMapper(e) eductSmiles = [origSmiles[g] for g in vms.left] for ev in vms.left.vertices: @@ -31,6 +32,7 @@ def getReactionSmiles(dg): strs = set() for vm in DGVertexMapper(e, rightLimit=1, leftLimit=1): + # for vm in DGVertexMapper(e): productSmiles = [origSmiles[g] for g in vms.right] for ev in vms.left.vertices: pv = vm.map[ev] diff --git a/synutility/SynMOD/__init__.py b/synutility/SynMOD/__init__.py deleted file mode 100644 index e69de29..0000000