diff --git a/.gitignore b/.gitignore index d92525e..12531b3 100644 --- a/.gitignore +++ b/.gitignore @@ -15,3 +15,5 @@ Data/Temp/Benchmark/Complete/* Data/Temp/Benchmark/Hier/* Data/Temp/Benchmark/Raw/* *.ipynb +*backup +bug.py diff --git a/Test/SynITS/test_hydrogen_utils.py b/Test/SynITS/test_hydrogen_utils.py new file mode 100644 index 0000000..32fc7df --- /dev/null +++ b/Test/SynITS/test_hydrogen_utils.py @@ -0,0 +1,84 @@ +import unittest +import networkx as nx +from synutility.SynIO.data_type import load_from_pickle +from syntemp.SynITS.hydrogen_utils import ( + check_explicit_hydrogen, + check_hcount_change, + get_cycle_member_rings, + get_priority, +) + + +class TestGraphFunctions(unittest.TestCase): + + def setUp(self): + # Create a test graph for the tests + self.data = load_from_pickle("./Data/Testcase/hydrogen_test.pkl.gz") + + def test_check_explicit_hydrogen(self): + # Test the check_explicit_hydrogen function + # Note, usually only appear in reactants (+H2 reactions) + count_r, hydrogen_nodes_r = check_explicit_hydrogen( + self.data[20]["ITSGraph"][0] + ) + self.assertEqual(count_r, 2) + self.assertEqual(hydrogen_nodes_r, [45, 46]) + + def test_check_hcount_change(self): + # Test the check_hcount_change function + max_change = check_hcount_change( + self.data[20]["ITSGraph"][0], self.data[20]["ITSGraph"][0] + ) + self.assertEqual(max_change, 2) + + def test_get_cycle_member_rings_minimal(self): + # Test get_cycle_member_rings with 'minimal' cycles + member_rings = get_cycle_member_rings(self.data[1]["GraphRules"][2], "minimal") + self.assertEqual(member_rings, [4]) # Cycles of size 4 and 3 + + def test_get_priority(self): + # Create a test graph for the tests + self.graph = nx.Graph() + self.graph.add_nodes_from( + [ + (1, {"element": "H", "hcount": 2}), + (2, {"element": "C", "hcount": 1}), + (3, {"element": "H", "hcount": 1}), + ] + ) + self.graph.add_edges_from([(1, 2), (2, 3)]) + + # Create another graph for `check_hcount_change` tests + self.prod_graph = nx.Graph() + self.prod_graph.add_nodes_from( + [ + (1, {"element": "H", "hcount": 1}), + (2, {"element": "C", "hcount": 1}), + (3, {"element": "H", "hcount": 2}), + ] + ) + self.prod_graph.add_edges_from([(1, 2), (2, 3)]) + + # Create a more complex graph for cycle tests + self.complex_graph = nx.Graph() + self.complex_graph.add_edges_from( + [ + (1, 2), + (2, 3), + (3, 4), + (4, 1), # A simple square cycle + (3, 5), + (5, 6), + (6, 3), # Another cycle + ] + ) + reaction_centers = [self.graph, self.prod_graph, self.complex_graph] + + # Get priority indices + priority_indices = get_priority(reaction_centers) + + self.assertEqual(priority_indices, [0, 1]) + + +if __name__ == "__main__": + unittest.main() diff --git a/Test/SynITS/test_its_extraction.py b/Test/SynITS/test_its_extraction.py index 8e83e8d..d625956 100644 --- a/Test/SynITS/test_its_extraction.py +++ b/Test/SynITS/test_its_extraction.py @@ -2,6 +2,8 @@ from syntemp.SynITS.its_extraction import ITSExtraction from syntemp.SynITS.its_construction import ITSConstruction +from synutility.SynIO.Format.smi_to_graph import rsmi_to_graph + class TestITSExtraction(unittest.TestCase): @@ -29,31 +31,15 @@ def setUp(self): ] self.mapper_names = ["local_mapper", "rxn_mapper", "graphormer"] - def test_graph_from_smiles(self): - graph = ITSExtraction.graph_from_smiles(self.smiles1) - self.assertEqual(len(graph.nodes()), 4) - self.assertEqual(len(graph.edges()), 3) - def test_check_equivariant_graph(self): - react_local_mapper, prod_local_mapper = self.mapped_smiles_list[0][ - "local_mapper" - ].split(">>") - G_local = ITSExtraction.graph_from_smiles(react_local_mapper) - H_local = ITSExtraction.graph_from_smiles(prod_local_mapper) + G_local, H_local = rsmi_to_graph(self.mapped_smiles_list[0]["local_mapper"]) ITS_local = ITSConstruction.ITSGraph(G_local, H_local) - - react_rxn_mapper, prod_rxn_mapper = self.mapped_smiles_list[0][ - "rxn_mapper" - ].split(">>") - G_rxn = ITSExtraction.graph_from_smiles(react_rxn_mapper) - H_rxn = ITSExtraction.graph_from_smiles(prod_rxn_mapper) + G_rxn, H_rxn = rsmi_to_graph(self.mapped_smiles_list[0]["rxn_mapper"]) ITS_rxn = ITSConstruction.ITSGraph(G_rxn, H_rxn) - react_graphormer, prod_graphormer = self.mapped_smiles_list[0][ - "graphormer" - ].split(">>") - G_graphormer = ITSExtraction.graph_from_smiles(react_graphormer) - H_graphormer = ITSExtraction.graph_from_smiles(prod_graphormer) + G_graphormer, H_graphormer = rsmi_to_graph( + self.mapped_smiles_list[0]["graphormer"] + ) ITS_graphormer = ITSConstruction.ITSGraph(G_graphormer, H_graphormer) classified, equivariant = ITSExtraction.check_equivariant_graph( @@ -82,7 +68,7 @@ def test_parallel_process_smiles(self): self.assertIsNotNone(results[0]["GraphRules"]) # Inequivalent AAM - self.assertEqual(results_wrong[0]["equivariant"], 0) + self.assertEqual(results_wrong[0]["equivariant"], -1) # -1 mean exit early def test_unsanitize_smiles(self): test_2 = { diff --git a/Test/SynITS/test_its_hadjuster.py b/Test/SynITS/test_its_hadjuster.py index 1f87044..87c95cf 100644 --- a/Test/SynITS/test_its_hadjuster.py +++ b/Test/SynITS/test_its_hadjuster.py @@ -1,65 +1,96 @@ -# import unittest -# import networkx as nx -# from SynTemp.SynITS.its_hadjuster import ITSHAdjuster - - -# class TestITSHAdjuster(unittest.TestCase): - -# def create_mock_graph(self, hcounts: dict) -> nx.Graph: -# """Utility function to create a mock graph with specified -# hydrogen counts for nodes.""" -# graph = nx.Graph() -# for node_id, hcount in hcounts.items(): -# graph.add_node(node_id, hcount=hcount) -# return graph - -# def test_check_hcount_change(self): -# # Mock reactant and product graphs with specified hydrogen counts -# react_graph = self.create_mock_graph({1: 1, 2: 2}) -# prod_graph = self.create_mock_graph({1: 0, 2: 3}) - -# # Expected: one hydrogen formation (node 1) and one hydrogen break (node 2) -# max_hydrogen_change = ITSHAdjuster.check_hcount_change(react_graph, prod_graph) -# self.assertEqual(max_hydrogen_change, 1) - -# def test_add_hydrogen_nodes(self): -# # Mock reactant and product graphs with specified hydrogen counts -# react_graph = self.create_mock_graph({1: 1}) -# prod_graph = self.create_mock_graph({1: 0}) - -# # Add hydrogen nodes to reactant and product graphs -# updated_react_graph, _ = ITSHAdjuster.add_hydrogen_nodes( -# react_graph, prod_graph -# ) - -# # Verify that hydrogen nodes have been added correctly -# self.assertIn( -# max(updated_react_graph.nodes), updated_react_graph.nodes -# ) # Hydrogen node added to reactant graph -# self.assertEqual( -# updated_react_graph.nodes[max(updated_react_graph.nodes)]["element"], "H" -# ) # Check element of added node - -# def test_add_hydrogen_nodes_multiple(self): -# # Mock reactant and product graphs with specified hydrogen counts -# react_graph = self.create_mock_graph({1: 2, 2: 1}) -# prod_graph = self.create_mock_graph({1: 0, 2: 2}) - -# # Generate updated graph pairs with multiple hydrogen nodes added -# updated_graph_pairs = ITSHAdjuster.add_hydrogen_nodes_multiple( -# react_graph, prod_graph -# ) - -# # Verify that multiple updated graph pairs are generated -# self.assertTrue(len(updated_graph_pairs) > 1) # Multiple permutations generated -# for react_graph, prod_graph in updated_graph_pairs: -# self.assertIn( -# max(react_graph.nodes), react_graph.nodes -# ) # Hydrogen node added to reactant graph -# self.assertIn( -# max(prod_graph.nodes), prod_graph.nodes -# ) # Hydrogen node added to product graph - - -# if __name__ == "__main__": -# unittest.main() +import unittest +import networkx as nx +from copy import deepcopy +from synutility.SynIO.data_type import load_from_pickle +from syntemp.SynITS.its_hadjuster import ITSHAdjuster + + +class TestITSHAdjuster(unittest.TestCase): + + def setUp(self): + """Setup before each test.""" + # Create sample graphs + self.data = load_from_pickle("./Data/Testcase/hydrogen_test.pkl.gz") + + def test_process_single_graph_data_success(self): + """Test the process_single_graph_data method.""" + processed_data = ITSHAdjuster.process_single_graph_data( + self.data[0], "ITSGraph" + ) + for value in processed_data["ITSGraph"]: + self.assertTrue(isinstance(value, nx.Graph)) + for value in processed_data["GraphRules"]: + self.assertTrue(isinstance(value, nx.Graph)) + + def test_process_single_graph_data_fail(self): + """Test the process_single_graph_data method.""" + processed_data = ITSHAdjuster.process_single_graph_data( + self.data[16], "ITSGraph" + ) + self.assertIsNone(processed_data["ITSGraph"]) + self.assertIsNone(processed_data["GraphRules"]) + + def test_process_single_graph_data_empty_graph(self): + """Test that an empty graph results in empty ITSGraph and GraphRules.""" + empty_graph_data = { + "ITSGraph": [None, None, None], + "GraphRules": [None, None, None], + } + + processed_data = ITSHAdjuster.process_single_graph_data( + empty_graph_data, "ITSGraph" + ) + + # Ensure the result is None or empty as expected for an empty graph + self.assertIsNone(processed_data["ITSGraph"]) + self.assertIsNone(processed_data["GraphRules"]) + + def test_process_single_graph_data_safe(self): + """Test the process_single_graph_data method.""" + processed_data = ITSHAdjuster.process_single_graph_data_safe( + self.data[0], "ITSGraph", job_timeout=0.0001 + ) + self.assertIsNone(processed_data["ITSGraph"]) + self.assertIsNone(processed_data["GraphRules"]) + + def test_process_graph_data_parallel(self): + """Test the process_graph_data_parallel method.""" + result = ITSHAdjuster().process_graph_data_parallel( + self.data, "ITSGraph", n_jobs=1, verbose=0, get_priority_graph=True + ) + result = [value for value in result if value["ITSGraph"]] + # Check if the result matches the input data structure + self.assertEqual(len(result), 48) + + def test_process_graph_data_parallel_safe(self): + """Test the process_graph_data_parallel method.""" + result = ITSHAdjuster().process_graph_data_parallel( + self.data, + "ITSGraph", + n_jobs=1, + verbose=0, + get_priority_graph=True, + safe=True, + job_timeout=0.0001, # lower timeout will fail all process + ) + result = [value for value in result if value["ITSGraph"]] + # Check if the result matches the input data structure + self.assertEqual(len(result), 0) + + def test_process_multiple_hydrogens(self): + """Test the process_multiple_hydrogens method.""" + graphs = deepcopy(self.data[0]) + react_graph, prod_graph, _ = graphs["ITSGraph"] + + result = ITSHAdjuster.process_multiple_hydrogens( + graphs, react_graph, prod_graph, ignore_aromaticity=False, balance_its=True + ) + + for value in result["ITSGraph"]: + self.assertTrue(isinstance(value, nx.Graph)) + for value in result["GraphRules"]: + self.assertTrue(isinstance(value, nx.Graph)) + + +if __name__ == "__main__": + unittest.main() diff --git a/Test/SynRule/test_rc_cluster.py b/Test/SynRule/test_rc_cluster.py index 351b0e6..7095f0d 100644 --- a/Test/SynRule/test_rc_cluster.py +++ b/Test/SynRule/test_rc_cluster.py @@ -70,7 +70,7 @@ def test_auto_cluster_wrong_isomorphism(self): clusters, _ = self.clusterer.auto_cluster( rc, signature, nodeMatch=None, edgeMatch=None ) - self.assertEqual(len(clusters), 36) # wrong value but almost correct + self.assertEqual(len(clusters), 37) # correct by some magic. No proof for this def test_fit(self): """Test the fit method to ensure it correctly updates data entries with cluster indices.""" diff --git a/Test/test_auto_template.py b/Test/test_auto_template.py index b6de1bc..fce9350 100644 --- a/Test/test_auto_template.py +++ b/Test/test_auto_template.py @@ -28,7 +28,7 @@ def setUp(self) -> None: def test_temp_extract(self): (rules, _, _, _, _, _) = self.auto.temp_extract(self.data, lib_path=None) self.assertIn("ruleID", rules[0][0]) - self.assertEqual(len(rules[0]), 10) + self.assertEqual(len(rules[0]), 9) def test_temp_extract_lib(self): print(f"{root_dir}/Data/Testcase/Compose/SingleRule") @@ -36,7 +36,7 @@ def test_temp_extract_lib(self): self.data, lib_path=f"{root_dir}/Data/Testcase/Compose/SingleRule" ) # 1 rules exist self.assertIn("ruleID", rules[0][0]) - self.assertEqual(len(rules[0]), 8) + self.assertEqual(len(rules[0]), 7) if __name__ == "__main__": diff --git a/pyproject.toml b/pyproject.toml index c986b91..c2e2630 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,7 +32,7 @@ dependencies = [ "seaborn>=0.13.2", "joblib>=1.3.2", "synrbl>=0.0.25", - "synutility>=0.0.12" + "synutility>=0.0.13" ] [project.urls] diff --git a/requirements.txt b/requirements.txt index 4a52be1..9c2079d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,4 +13,4 @@ networkx>=3.3 seaborn>=0.13.2 joblib>=1.3.2 synrbl>=0.0.25 -synutility>=0.0.12 \ No newline at end of file +synutility>=0.0.13 \ No newline at end of file diff --git a/syntemp/SynITS/hydrogen_utils.py b/syntemp/SynITS/hydrogen_utils.py new file mode 100644 index 0000000..312f57a --- /dev/null +++ b/syntemp/SynITS/hydrogen_utils.py @@ -0,0 +1,133 @@ +import networkx as nx +from typing import List, Any +from synutility.SynGraph.Descriptor.graph_descriptors import GraphDescriptor + + +def check_explicit_hydrogen(graph: nx.Graph) -> tuple: + """ + Counts the explicit hydrogen nodes in the given graph and collects their IDs. + + Parameters: + - graph (nx.Graph): The graph to inspect. + + Returns: + tuple: A tuple containing the number of hydrogen nodes and a list of their node IDs. + """ + hydrogen_nodes = [ + node_id + for node_id, attr in graph.nodes(data=True) + if attr.get("element") == "H" + ] + return len(hydrogen_nodes), hydrogen_nodes + + +def check_hcount_change(react_graph: nx.Graph, prod_graph: nx.Graph) -> int: + """ + Computes the maximum change in hydrogen count ('hcount') between corresponding nodes + in the reactant and product graphs. It considers both hydrogen formation and breakage. + + Parameters: + - react_graph (nx.Graph): The graph representing reactants. + - prod_graph (nx.Graph): The graph representing products. + + Returns: + int: The maximum hydrogen change observed across all nodes. + """ + # max_hydrogen_change = 0 + hcount_break, _ = check_explicit_hydrogen(react_graph) + hcount_form, _ = check_explicit_hydrogen(prod_graph) + + for node_id, attrs in react_graph.nodes(data=True): + react_hcount = attrs.get("hcount", 0) + if node_id in prod_graph: + prod_hcount = prod_graph.nodes[node_id].get("hcount", 0) + else: + prod_hcount = 0 + + if react_hcount >= prod_hcount: + hcount_break += react_hcount - prod_hcount + else: + hcount_form += prod_hcount - react_hcount + + max_hydrogen_change = max(hcount_break, hcount_form) + + return max_hydrogen_change + + +def get_cycle_member_rings(G: nx.Graph, type="minimal") -> List[int]: + """ + Identifies all cycles in the given graph using cycle bases to ensure no overlap + and returns a list of the sizes of these cycles (member rings), + sorted in ascending order. + + Parameters: + - G (nx.Graph): The NetworkX graph to be analyzed. + + Returns: + - List[int]: A sorted list of cycle sizes (member rings) found in the graph. + """ + if not isinstance(G, nx.Graph): + raise TypeError("Input must be a networkx Graph object.") + + if type == "minimal": + cycles = nx.minimum_cycle_basis(G) + else: + cycles = nx.cycle_basis(G) + member_rings = [len(cycle) for cycle in cycles] + + member_rings.sort() + + return member_rings + + +def get_priority(reaction_centers: List[Any]) -> List[int]: + """ + Evaluate reaction centers for specific graph characteristics, selecting indices based + on the shortest reaction paths and maximum ring sizes, and adjusting for certain + graph types by modifying the ring information. + + Parameters: + - reaction_centers: List[Any], a list of reaction centers where each center should be + capable of being analyzed for graph type and ring sizes. + + Returns: + - List[int]: A list of indices from the original list of reaction centers that meet + the criteria of having the shortest reaction steps and/or the largest ring sizes. + Returns indices with minimum reaction steps if no indices meet both criteria. + """ + # Extract topology types and ring sizes from reaction centers + topo_type = [ + GraphDescriptor.check_graph_type(center) for center in reaction_centers + ] + cyclic = [ + get_cycle_member_rings(center, "fundamental") for center in reaction_centers + ] + + # Adjust ring information based on the graph type + for index, graph_type in enumerate(topo_type): + if graph_type in ["Acyclic", "Complex Cyclic"]: + cyclic[index] = [0] + cyclic[index] + + # Determine minimum reaction steps + reaction_steps = [len(rings) for rings in cyclic] + min_reaction_step = min(reaction_steps) + + # Filter indices with the minimum reaction steps + indices_shortest = [ + i for i, steps in enumerate(reaction_steps) if steps == min_reaction_step + ] + + # Filter indices with the maximum ring size + max_size = max( + max(rings) for rings in cyclic if rings + ) # Safeguard against empty sublists + prior_indices = [i for i, rings in enumerate(cyclic) if max(rings) == max_size] + + # Combine criteria for final indices + final_indices = [index for index in prior_indices if index in indices_shortest] + + # Fallback to shortest indices if no indices meet both criteria + if not final_indices: + return indices_shortest + + return final_indices diff --git a/syntemp/SynITS/its_arbitrary.py b/syntemp/SynITS/its_arbitrary.py new file mode 100644 index 0000000..a2bae55 --- /dev/null +++ b/syntemp/SynITS/its_arbitrary.py @@ -0,0 +1,308 @@ +import networkx as nx +from syntemp.SynITS.its_extraction import ITSExtraction +from syntemp.SynITS.its_hadjuster import ITSHAdjuster +from syntemp.SynRule.rules_extraction import RuleExtraction +from synutility.SynAAM.its_construction import ITSConstruction +from synutility.SynIO.Format.smi_to_graph import rsmi_to_graph +from synutility.SynIO.debug import setup_logging +from typing import Dict, List, Tuple +from syntemp.SynRule.rc_cluster import RCCluster +from synutility.SynGraph.Descriptor.graph_signature import GraphSignature +from syntemp.SynITS.hydrogen_utils import ( + check_hcount_change, +) +from joblib import Parallel, delayed + +logger = setup_logging() + + +class ITSArbitrary: + def __init__(self): + pass + + @staticmethod + def process_equivalent_map( + react_graph: nx.Graph, + prod_graph: nx.Graph, + ignore_aromaticity: bool, + balance_its: bool, + ) -> Tuple[List[nx.Graph], List[nx.Graph]]: + """ + Process equivalent maps by adding hydrogen nodes and constructing ITS graphs. + + Parameters: + - react_graph (nx.Graph): The reactant graph. + - prod_graph (nx.Graph): The product graph. + - ignore_aromaticity (bool): Whether to ignore aromaticity in graph construction. + - balance_its (bool): Whether to balance the ITS graph. + + Returns: + - Tuple of (List[nx.Graph], List[nx.Graph]): Lists of reaction graphs and + ITS graphs. + """ + hcount_change = check_hcount_change(react_graph, prod_graph) + if hcount_change == 0: + its_list = [ITSConstruction().ITSGraph(react_graph, prod_graph)] + rc_list = [ + RuleExtraction.extract_reaction_rules( + react_graph, prod_graph, i, False, 1 + )[2] + for i in its_list + ] + return list(rc_list), list(its_list) + + combinations_solution = ITSHAdjuster.add_hydrogen_nodes_multiple( + react_graph, prod_graph + ) + + # Create ITS graphs for each combination solution + its_list = [ + ITSConstruction.ITSGraph( + i[0], i[1], ignore_aromaticity, balance_its=balance_its + ) + for i in combinations_solution + ] + + # Extract reaction rules for each ITS graph + rc_list = [ + RuleExtraction.extract_reaction_rules(react_graph, prod_graph, i, False, 1)[ + 2 + ] + for i in its_list + ] + + # Filter valid reaction graphs and ITS graphs + valid_rc_its = [ + (rc, its) + for rc, its in zip(rc_list, its_list) + if rc is not None and isinstance(rc, nx.Graph) and rc.number_of_nodes() > 0 + ] + + # Unzip valid results + rc_list, its_list = zip(*valid_rc_its) if valid_rc_its else ([], []) + return list(rc_list), list(its_list) + + @staticmethod + def process_non_equivalent_map( + data: Dict[str, str], + mapped_key: List[str], + sanitize: bool, + ignore_aromaticity: bool, + balance_its: bool, + ) -> Tuple[List[nx.Graph], List[nx.Graph]]: + """ + Process non-equivalent maps and construct their corresponding reaction and + ITS graphs. + + Parameters: + - data (Dict[str, str]): Dictionary of mapped SMILES strings. + - mapped_key (List[str]): List of mapper names to process. + - sanitize (bool): Whether to sanitize the molecule(s). + - ignore_aromaticity (bool): Whether to ignore aromaticity in graph construction. + - balance_its (bool): Whether to balance the ITS graph. + + Returns: + - Tuple of (List[nx.Graph], List[nx.Graph]): Lists of reaction graphs and + ITS graphs. + """ + rc_list, its_list = [], [] + for mapper in mapped_key: + try: + # Convert SMILES to graphs + G, H = rsmi_to_graph( + data[mapper], + drop_non_aam=True, + light_weight=True, + sanitize=sanitize, + ) + # Process equivalent maps + rc, its = ITSArbitrary.process_equivalent_map( + G, H, ignore_aromaticity, balance_its + ) + # print(rc) + rc_list.extend(rc) + its_list.extend(its) + except Exception as e: + logger.warning(f"Error processing {mapper}: {e}") + return rc_list, its_list + + @staticmethod + def get_unique_graphs_for_clusters( + graphs: List[nx.Graph], cluster_indices: List[int] + ) -> List[nx.Graph]: + """ + Get a unique graph for each cluster from a list of graphs. + + Parameters: + - graphs (List[nx.Graph]): List of networkx graphs. + - cluster_indices (List[int]): List of indices that represent cluster assignments + for each graph. + + Returns: + - List[nx.Graph]: List of unique graphs, one per cluster. + """ + # Create a dictionary to store graphs by cluster index + cluster_graphs = {} + + for idx, cluster_id in enumerate(cluster_indices): + # Add graph to the appropriate cluster + if cluster_id not in cluster_graphs: + cluster_graphs[cluster_id] = [] + cluster_graphs[cluster_id].append(graphs[idx]) + + # Now, select one unique graph per cluster (e.g., the first graph) + unique_graphs = [] + for cluster_id, graphs_in_cluster in cluster_graphs.items(): + unique_graphs.append(graphs_in_cluster[0]) + + return unique_graphs + + @staticmethod + def its_expand( + data: Dict[str, str], + mapped_key: List[str], + check_method: str = "RC", + id_column: str = "R-id", + ignore_aromaticity: bool = False, + confident_mapper: str = "graphormer", + sanitize: bool = True, + balance_its: bool = True, + ) -> Tuple[List[nx.Graph], List[nx.Graph]]: + """ + Expand ITS graphs by checking equivalence and processing accordingly. + + Parameters: + - data (Dict[str, str]): Dictionary of mapped SMILES strings. + - mapped_key (List[str]): List of mapper names to process. + - check_method (str): Method to check for isomorphism, "RC" or "ITS". + - id_column (str): Column name for reaction ID. + - ignore_aromaticity (bool): Whether to ignore aromaticity. + - confident_mapper (str): Mapper to use when confident. + - symbol (str): Reaction symbol separator. + - sanitize (bool): Whether to sanitize molecules. + - balance_its (bool): Whether to balance ITS graphs. + + Returns: + - Tuple of (List[nx.Graph], List[nx.Graph]): Lists of reaction graphs + and ITS graphs. + """ + try: + # Process the mapped SMILES strings and check equivalence + good, _ = ITSExtraction.process_mapped_smiles( + data, + mapped_key, + check_method, + id_column, + ignore_aromaticity, + confident_mapper, + sanitize, + ) + + # Ensure equivalence check is valid + if "equivariant" not in good: + raise ValueError( + "Equivalence check result 'equivariant' not found in the response." + ) + # print(good) + # Process based on equivalence check + if good["equivariant"] != (len(mapped_key) - 1): + # print(1) + rc_list, its_list = ITSArbitrary.process_non_equivalent_map( + data, mapped_key, sanitize, ignore_aromaticity, balance_its + ) + else: + r, p = good[confident_mapper][0], good[confident_mapper][1] + rc_list, its_list = ITSArbitrary.process_equivalent_map( + r, p, ignore_aromaticity, balance_its + ) + + sig = [GraphSignature(i).create_graph_signature() for i in rc_list] + cluster_indices = RCCluster().fit_graphs(rc_list, sig) + + new_rc = ITSArbitrary.get_unique_graphs_for_clusters( + rc_list, cluster_indices + ) + new_its = ITSArbitrary.get_unique_graphs_for_clusters( + its_list, cluster_indices + ) + return new_rc, new_its + + except Exception as e: + # Log error and re-raise exception + logger.error(f"Error in ITSArbitrary.its_expand: {str(e)}") + return [], [] + + def parallel_its_expand( + self, + data: List[Dict], + mapped_key: List[str], + check_method: str = "RC", + id_column: str = "R-id", + ignore_aromaticity: bool = False, + confident_mapper: str = "graphormer", + sanitize: bool = True, + balance_its: bool = True, + n_jobs: int = 1, + verbose: int = 0, + ) -> Tuple[List[Dict], List[Dict]]: + """ + Expands ITS graphs in parallel for a list of reaction data. + + Parameters: + - data (List[Dict]): List of dictionaries containing mapped reaction data. + - mapped_key (List[str]): List of keys for processing each reaction data. + - check_method (str): Method for checking graph equivalence ("RC" or "ITS"). + Default is "RC". + - id_column (str): The column in the dictionary representing reaction ID. + - ignore_aromaticity (bool): Whether to ignore aromaticity when + constructing chemical graphs. + - confident_mapper (str): Mapper to use when confident. + Default: "graphormer". + - sanitize (bool): Whether to sanitize the molecules during graph construction. + - balance_its (bool): Whether to balance ITS graphs. + - n_jobs (int): Number of parallel jobs (default: 1). + - verbose (int): Verbosity level for parallel processing. + + Returns: + - Tuple of (List[Dict], List[Dict]): Updated list of dictionaries with expanded + RC and ITS graphs for each reaction. + """ + + logger.info(f"Starting parallel ITS graph expansion with {n_jobs} jobs.") + + try: + results = Parallel(n_jobs=n_jobs, verbose=verbose)( + delayed(self.its_expand)( + graph_data, + mapped_key, + check_method, + id_column, + ignore_aromaticity, + confident_mapper, + sanitize, + balance_its, + ) + for graph_data in data + ) + except Exception as e: + logger.error(f"Error occurred during parallel processing: {e}") + return [], [] # In case of failure, return empty lists for RC and ITS + + # Process and store the results into the data dictionary + for key, result in enumerate(results): + try: + rc, its = ( + result # Deconstruct the tuple (RC graph list, ITS graph list) + ) + data[key]["RC"] = rc + data[key]["ITS"] = its + logger.info( + f"Processed reaction {data[key].get(id_column)} successfully." + ) + except Exception as e: + logger.warning(f"Error processing reaction at index {key}: {e}") + data[key]["RC"] = [] + data[key]["ITS"] = [] + + # Return the updated data with RC and ITS graphs + return data diff --git a/syntemp/SynITS/its_decomposer.py b/syntemp/SynITS/its_decomposer.py deleted file mode 100644 index 0e54aa9..0000000 --- a/syntemp/SynITS/its_decomposer.py +++ /dev/null @@ -1,108 +0,0 @@ -import networkx as nx - - -def its_decompose(its_graph: nx.Graph, nodes_share="typesGH", edges_share="order"): - """ - Decompose an ITS graph into two separate graphs G and H based on shared - node and edge attributes. - - Parameters: - - its_graph (nx.Graph): The integrated transition state (ITS) graph. - - nodes_share (str): Node attribute key that stores tuples with node attributes - or G and H. - - edges_share (str): Edge attribute key that stores tuples with edge attributes - for G and H. - - Returns: - - Tuple[nx.Graph, nx.Graph]: A tuple containing the two graphs G and H. - """ - G = nx.Graph() - H = nx.Graph() - - # Decompose nodes - for node, data in its_graph.nodes(data=True): - if nodes_share in data: - node_attr_g, node_attr_h = data[nodes_share] - # Unpack node attributes for G - G.add_node( - node, - element=node_attr_g[0], - aromatic=node_attr_g[1], - hcount=node_attr_g[2], - charge=node_attr_g[3], - neighbors=node_attr_g[4], - ) - # Unpack node attributes for H - H.add_node( - node, - element=node_attr_h[0], - aromatic=node_attr_h[1], - hcount=node_attr_h[2], - charge=node_attr_h[3], - neighbors=node_attr_h[4], - ) - - # Decompose edges - for u, v, data in its_graph.edges(data=True): - if edges_share in data: - order_g, order_h = data[edges_share] - if order_g > 0: # Assuming 0 means no edge in G - G.add_edge(u, v, order=order_g) - if order_h > 0: # Assuming 0 means no edge in H - H.add_edge(u, v, order=order_h) - - return G, H - - -def compare_graphs( - graph1: nx.Graph, - graph2: nx.Graph, - node_attrs: list = ["element", "aromatic", "hcount", "charge", "neighbors"], - edge_attrs: list = ["order"], -) -> bool: - """ - Compare two graphs based on specified node and edge attributes. - - Parameters: - - graph1 (nx.Graph): The first graph to compare. - - graph2 (nx.Graph): The second graph to compare. - - node_attrs (list): A list of node attribute names to include in the comparison. - - edge_attrs (list): A list of edge attribute names to include in the comparison. - - Returns: - - bool: True if both graphs are identical with respect to the specified attributes, - otherwise False. - """ - # Compare node sets - if set(graph1.nodes()) != set(graph2.nodes()): - return False - - # Compare nodes based on attributes - for node in graph1.nodes(): - if node not in graph2: - return False - node_data1 = {attr: graph1.nodes[node].get(attr, None) for attr in node_attrs} - node_data2 = {attr: graph2.nodes[node].get(attr, None) for attr in node_attrs} - if node_data1 != node_data2: - return False - - # Compare edge sets with sorted tuples - if set(tuple(sorted(edge)) for edge in graph1.edges()) != set( - tuple(sorted(edge)) for edge in graph2.edges() - ): - return False - - # Compare edges based on attributes - for edge in graph1.edges(): - # Sort the edge for consistent comparison - sorted_edge = tuple(sorted(edge)) - if sorted_edge not in graph2.edges(): - return False - edge_data1 = {attr: graph1.edges[edge].get(attr, None) for attr in edge_attrs} - edge_data2 = { - attr: graph2.edges[sorted_edge].get(attr, None) for attr in edge_attrs - } - if edge_data1 != edge_data2: - return False - - return True diff --git a/syntemp/SynITS/its_extraction.py b/syntemp/SynITS/its_extraction.py index 4e84622..82b4a62 100644 --- a/syntemp/SynITS/its_extraction.py +++ b/syntemp/SynITS/its_extraction.py @@ -1,5 +1,4 @@ import networkx as nx -from rdkit import Chem from operator import eq from copy import deepcopy from joblib import Parallel, delayed @@ -7,12 +6,13 @@ from networkx.algorithms.isomorphism import generic_node_match, generic_edge_match from synutility.SynIO.debug import setup_logging -from synutility.SynIO.Format.mol_to_graph import MolToGraph from synutility.SynAAM.its_construction import ITSConstruction +from synutility.SynIO.Format.smi_to_graph import rsmi_to_graph from synutility.SynChem.Reaction.standardize import Standardize +from synutility.SynGraph.Descriptor.graph_signature import GraphSignature -from syntemp.SynRule.rules_extraction import RuleExtraction +from syntemp.SynRule.rules_extraction import RuleExtraction logger = setup_logging() @@ -21,23 +21,6 @@ class ITSExtraction: def __init__(self): pass - @staticmethod - def graph_from_smiles(smiles: str, sanitize: bool = True) -> nx.Graph: - """ - Constructs a graph representation from a SMILES string. - - Parameters: - - smiles (str): A SMILES string representing a molecule or a set of molecules. - - sanitize (bool): Whether to sanitize the molecule(s). - - Returns: - - nx.Graph: A graph representation of the molecule(s). - """ - - mol = Chem.MolFromSmiles(smiles, sanitize=sanitize) - graph = MolToGraph().mol_to_graph(mol, drop_non_aam=True) - return graph - @staticmethod def check_equivariant_graph( its_graphs: List[nx.Graph], @@ -52,6 +35,11 @@ def check_equivariant_graph( - List[Tuple[int, int]]: A list of tuples representing pairs of indices of isomorphic graphs. """ + # If there's only one graph, no comparison is possible + if len(its_graphs) == 1: + return [], 0 + + # Define node and edge matchers for graph isomorphism check nodeLabelNames = ["typesGH"] nodeLabelDefault = [()] nodeLabelOperator = [eq] @@ -60,14 +48,18 @@ def check_equivariant_graph( ) edgeMatch = generic_edge_match("order", 1, eq) + # List to store classified isomorphic pairs classified = [] + # Compare each graph to the first one for i in range(1, len(its_graphs)): - # Compare the first graph with each subsequent graph - if nx.is_isomorphic( + if not nx.is_isomorphic( its_graphs[0], its_graphs[i], node_match=nodeMatch, edge_match=edgeMatch ): - classified.append((0, i)) + return [], -1 # Early exit if no isomorphism is found + + classified.append((0, i)) + return classified, len(classified) @staticmethod @@ -78,7 +70,6 @@ def process_mapped_smiles( id_column: str = "R-id", ignore_aromaticity: bool = False, confident_mapper: str = "graphormer", - symbol: str = ">>", sanitize: bool = True, ) -> Dict[str, any]: """ @@ -114,56 +105,56 @@ def process_mapped_smiles( """ threshold = len(mapper_names) - 1 graphs_by_map = {id_column: mapped_smiles.get(id_column, "N/A")} + for mapper in mapper_names: + graphs_by_map[mapper] = (None, None, None) rules_by_map = {id_column: mapped_smiles.get(id_column, "N/A")} its_graphs = [] rules_graphs = [] - + sig = [] for mapper in mapper_names: try: - reactants_side, products_side = mapped_smiles[mapper].split(symbol) - - # Get reactants graph G - G = ITSExtraction.graph_from_smiles(reactants_side, sanitize) - - # Get products graph H - H = ITSExtraction.graph_from_smiles(products_side, sanitize) + # reactants_side, products_side = mapped_smiles[mapper].split(symbol) + G, H = rsmi_to_graph( + mapped_smiles[mapper], + drop_non_aam=True, + light_weight=True, + sanitize=sanitize, + ) # Construct the ITS graph ITS = ITSConstruction.ITSGraph(G, H, ignore_aromaticity) its_graphs.append(ITS) + graph_rules = RuleExtraction.extract_reaction_rules( + G, H, ITS, extend=False + ) + rc = graph_rules[2] + sig_rc = GraphSignature(rc).create_graph_signature() + if len(sig) > 0: + if sig[-1] != sig_rc: + rules_graphs = [] + break # Store graphs and ITS graphs_by_map[mapper] = (G, H, ITS) # Extract reaction rules - rules_by_map[mapper] = RuleExtraction.extract_reaction_rules( - G, H, ITS, extend=False - ) + rules_by_map[mapper] = graph_rules _, _, rules = rules_by_map[mapper] rules_graphs.append(rules) except Exception as e: logger.info(f"Error processing {mapper}: {e}") + rules_graphs = [] + break - # Fallback: Create a one-node graph for ITS and Rules - one_node_graph = nx.Graph() - one_node_graph.add_node(0) # Create a graph with a single node - - # Use the one-node graph for ITS and Rules - its_graphs.append(one_node_graph) - graphs_by_map[mapper] = (one_node_graph, one_node_graph, one_node_graph) - rules_by_map[mapper] = (one_node_graph, one_node_graph, one_node_graph) - rules_graphs.append(one_node_graph) - if len(rules_graphs) > 1: + if len(rules_graphs) == len(mapper_names): if check_method == "RC": _, equivariant = ITSExtraction.check_equivariant_graph(rules_graphs) elif check_method == "ITS": _, equivariant = ITSExtraction.check_equivariant_graph(its_graphs) else: - equivariant = 0 - # graphs_by_map['check_equivariant'] = classified + equivariant = -1 graphs_by_map["equivariant"] = equivariant - graphs_by_map_correct = deepcopy(graphs_by_map) graphs_by_map_incorrect = deepcopy(graphs_by_map) is_empty_graph_present = any( @@ -198,14 +189,13 @@ def process_mapped_smiles( def parallel_process_smiles( mapped_smiles_list: List[Dict[str, str]], mapper_names: List[str], - n_jobs: int = -1, - verbose: int = 10, + n_jobs: int = 1, + verbose: int = 0, id_column: str = "R-id", check_method="RC", export_full=False, ignore_aromaticity: bool = False, confident_mapper: str = "graphormer", - symbol: str = ">>", sanitize: bool = True, ) -> List[Dict[str, any],]: """ @@ -241,7 +231,6 @@ def parallel_process_smiles( id_column, ignore_aromaticity, confident_mapper, - symbol, sanitize, ) for mapped_smiles in mapped_smiles_list diff --git a/syntemp/SynITS/its_hadjuster.py b/syntemp/SynITS/its_hadjuster.py index adb7d5e..0699260 100644 --- a/syntemp/SynITS/its_hadjuster.py +++ b/syntemp/SynITS/its_hadjuster.py @@ -1,16 +1,18 @@ import itertools import networkx as nx +from operator import eq from copy import deepcopy from multiprocessing import Pool from joblib import Parallel, delayed from typing import Dict, List, Tuple, Iterable +from networkx.algorithms.isomorphism import generic_node_match, generic_edge_match from synutility.SynIO.debug import setup_logging from synutility.SynAAM.its_construction import ITSConstruction +from synutility.SynGraph.Descriptor.graph_signature import GraphSignature -from syntemp.SynITS.its_extraction import ITSExtraction from syntemp.SynRule.rules_extraction import RuleExtraction -from syntemp.SynUtils.graph_utils import ( +from syntemp.SynITS.hydrogen_utils import ( check_hcount_change, check_explicit_hydrogen, get_priority, @@ -21,39 +23,92 @@ class ITSHAdjuster: + """ + A class for infering hydrogen to complete reaction center or ITS graph. + """ + + @staticmethod + def check_equivariant_graph( + its_graphs: List[nx.Graph], + ) -> Tuple[List[Tuple[int, int]], int]: + """ + Checks for isomorphism among a list of ITS graphs. + + Parameters: + - its_graphs (List[nx.Graph]): A list of ITS graphs. + + Returns: + - List[Tuple[int, int]]: A list of tuples representing pairs of indices of + isomorphic graphs. + """ + nodeLabelNames = ["typesGH"] + nodeLabelDefault = [()] + nodeLabelOperator = [eq] + nodeMatch = generic_node_match( + nodeLabelNames, nodeLabelDefault, nodeLabelOperator + ) + edgeMatch = generic_edge_match("order", 1, eq) + + classified = [] + + for i in range(1, len(its_graphs)): + # Compare the first graph with each subsequent graph + if nx.is_isomorphic( + its_graphs[0], its_graphs[i], node_match=nodeMatch, edge_match=edgeMatch + ): + classified.append((0, i)) + return classified, len(classified) + + @staticmethod + def update_graph_data( + graph_data: Dict, + react_graph: nx.Graph, + prod_graph: nx.Graph, + its: nx.Graph, + ) -> Dict: + """ + Updates the graph data dictionary with new ITS and GraphRules based on the provided graphs. + + Parameters: + - graph_data (Dict): Existing graph data dictionary. + - react_graph (nx.Graph): Reactant graph. + - prod_graph (nx.Graph): Product graph. + - its (nx.Graph): Imaginary Transition State graph + + Returns: + - Dict: Updated graph data dictionary with new ITS and GraphRules. + """ + graph_data["ITSGraph"] = (react_graph, prod_graph, its) + graph_data["GraphRules"] = RuleExtraction.extract_reaction_rules( + react_graph, prod_graph, its, extend=False, n_knn=1 + ) + return graph_data @staticmethod def process_single_graph_data( graph_data: Dict, column: str, - return_all: bool = False, ignore_aromaticity: bool = False, balance_its: bool = True, - get_random_results=False, - fast_process: bool = False, + get_priority_graph: bool = False, ) -> Dict: """ - Processes a single dictionary containing graph information by applying - modifications based on hcount changes. - Optionally handles aromaticity and provides different return behaviors based on - the `return_all` flag. + Processes a single graph data dictionary, applying modifications based on + hydrogen count changes. Optionally handles aromaticity and + returns modified graph data. Parameters: - - graph_data (Dict): A dictionary containing essential graph information. This - includes nodes, edges, and other graph-specific data. - - column (str): The key in the dictionary where the graph tuple is stored, - typically pointing to the specific data structure to be modified. - - return_all (bool): A flag that determines the nature of the output. If True, the - function returns all modified data, otherwise it returns only the most relevant - changes. The default value is False. - - ignore_aromaticity (bool): A flag to indicate whether aromaticity should be - ignored during the graph processing. Ignoring aromaticity may affect the ITS - construction. The default value is False. + - graph_data (Dict): Dictionary containing graph. + - column (str): The key in the dictionary where the graph tuple is stored. + - ignore_aromaticity (bool): Flag to indicate if aromaticity should be ignored. + Default is False. + - balance_its (bool): Flag to balance the ITS. Default is True. + - get_priority_graph (bool): Flag to determine if priority graphs should be + considered. Default is False. Returns: - - Dict: An updated dictionary that includes the new internal topology structure - (ITS) and any applicable GraphRules, reflecting the modifications made based on - hydrogen counts and aromaticity considerations. + - Dict: Updated graph data dictionary, reflecting changes based on + hydrogen counts and aromaticity. """ graphs = deepcopy(graph_data) react_graph, prod_graph, its = graphs[column] @@ -63,7 +118,6 @@ def process_single_graph_data( ) if is_empty_graph_present: - # Update graph data if any graph is empty graphs["ITSGraph"], graphs["GraphRules"] = None, None return graphs @@ -72,32 +126,15 @@ def process_single_graph_data( graph_data = ITSHAdjuster.update_graph_data( graphs, react_graph, prod_graph, its ) - elif hcount_change < 5: + else: graph_data = ITSHAdjuster.process_multiple_hydrogens( graphs, react_graph, prod_graph, - its, ignore_aromaticity, - return_all, balance_its, - get_random_results, + get_priority_graph, ) - else: - if fast_process: - graphs["ITSGraph"], graphs["GraphRules"] = None, None - return graphs - else: - graph_data = ITSHAdjuster.process_high_hcount_change( - graphs, - react_graph, - prod_graph, - its, - ignore_aromaticity, - return_all, - balance_its, - get_random_results, - ) if graph_data["GraphRules"] is not None: is_empty_rc_present = any( (not isinstance(graph, nx.Graph) or graph.number_of_nodes() == 0) @@ -108,172 +145,33 @@ def process_single_graph_data( graph_data["GraphRules"] = None return graph_data - @staticmethod - def update_graph_data(graph_data, react_graph, prod_graph, its, ignore=False): - """ - Update graph data dictionary with new ITS and GraphRules based on the graphs - provided. - - Parameters: - - graph_data (Dict): Existing graph data. - - react_graph (nx.Graph), prod_graph (nx.Graph), its: Graphs and ITS to use. - - Returns: - Dict: Updated graph data dictionary. - """ - graph_data["ITSGraph"] = (react_graph, prod_graph, its) - graph_data["GraphRules"] = RuleExtraction.extract_reaction_rules( - react_graph, prod_graph, its, extend=False, n_knn=1 - ) - if ignore: - graph_data["ITSGraph"], graph_data["GraphRules"] = None, None - return graph_data - - @staticmethod - def process_multiple_hydrogens( - graph_data, - react_graph, - prod_graph, - its, - ignore_aromaticity, - return_all, - balance_its, - get_random_results=False, - ): - """ - Handles cases with hydrogen count changes between 2 and 4, inclusive. - Manages the creation of multiple hydrogen node scenarios and evaluates their - equivalence. - - Parameters: - - graph_data, react_graph, prod_graph, its, ignore_aromaticity, - return_all as described. - - Returns: - - Dict: Updated graph data. - """ - combinations_solution = ITSHAdjuster.add_hydrogen_nodes_multiple( - react_graph, prod_graph - ) - its_list = [ - ITSConstruction.ITSGraph( - i[0], i[1], ignore_aromaticity, balance_its=balance_its - ) - for i in combinations_solution - ] - _, equivariant = ITSExtraction.check_equivariant_graph(its_list) - pairwise_combinations = len(its_list) - 1 - if equivariant == pairwise_combinations: - graph_data = ITSHAdjuster.update_graph_data( - graph_data, *combinations_solution[0], its_list[0] - ) - else: - graph_data = ITSHAdjuster.process_high_hcount_change( - graph_data, - react_graph, - prod_graph, - its, - ignore_aromaticity, - return_all, - balance_its, - get_random_results, - ) - return graph_data - - @staticmethod - def process_high_hcount_change( - graph_data, - react_graph, - prod_graph, - its, - ignore_aromaticity, - return_all, - balance_its: bool = True, - get_random_results=False, - ): - """ - Handles cases with hydrogen count changes of 5 or more. - Similar to `process_multiple_hydrogens` but tailored for higher counts. - - Parameters: - - Same as `process_multiple_hydrogens`. - - Returns: - - Dict: Updated graph data. - """ - combinations_solution = ITSHAdjuster.add_hydrogen_nodes_multiple( - react_graph, prod_graph - ) - - its_list = [ - ITSConstruction.ITSGraph( - i[0], i[1], ignore_aromaticity, balance_its=balance_its - ) - for i in combinations_solution - ] - reaction_centers = [ - RuleExtraction.extract_reaction_rules(react_graph, prod_graph, i)[2] - for i in its_list - ] - - filtered_reaction_centers = [ - rc - for rc in reaction_centers - if rc is not None and isinstance(rc, nx.Graph) and rc.number_of_nodes() > 0 - ] - - filtered_combinations_solution = [ - comb - for rc, comb in zip(reaction_centers, combinations_solution) - if rc is not None and isinstance(rc, nx.Graph) and rc.number_of_nodes() > 0 - ] - - # Update the original lists with the filtered results - reaction_centers, combinations_solution = ( - filtered_reaction_centers, - filtered_combinations_solution, - ) - - priority_indices = get_priority(reaction_centers) - rc_list = [reaction_centers[i] for i in priority_indices] - its_list = [its_list[i] for i in priority_indices] - combinations_solution = [combinations_solution[i] for i in priority_indices] - _, equivariant = ITSExtraction.check_equivariant_graph(rc_list) - pairwise_combinations = len(its_list) - 1 - - if equivariant == pairwise_combinations: - - graph_data = ITSHAdjuster.update_graph_data( - graph_data, *combinations_solution[0], its_list[0] - ) - - else: - if get_random_results is True: - graph_data = ITSHAdjuster.update_graph_data( - graph_data, *combinations_solution[0], its_list[0] - ) - - else: - if return_all: - graph_data = ITSHAdjuster.update_graph_data( - graph_data, react_graph, prod_graph, its - ) - else: - graph_data["ITSGraph"], graph_data["GraphRules"] = None, None - return graph_data - @staticmethod def process_single_graph_data_safe( graph_data: Dict, column: str, - return_all: bool = False, ignore_aromaticity: bool = False, balance_its: bool = True, - get_random_results=False, - fast_process: bool = False, job_timeout: int = 1, + get_priority_graph: bool = False, ) -> Dict: - # pool = multiprocessing.pool.ThreadPool(1) + """ + Processes a single graph data dictionary asynchronously, handling potential + timeouts during processing. + + Parameters: + - graph_data (Dict): Dictionary containing graph data. + - column (str): Key to access the graph data. + - ignore_aromaticity (bool): Flag to ignore aromaticity during processing. + Default is False. + - balance_its (bool): Flag to balance the ITS. Default is True. + - job_timeout (int): Timeout in seconds for the asynchronous task. + Default is 1 second. + - get_priority_graph (bool): Flag to include priority graph processing. + Default is False. + + Returns: + - Dict: Processed graph data dictionary. + """ pool = Pool(processes=1) try: async_result = pool.apply_async( @@ -281,11 +179,9 @@ def process_single_graph_data_safe( ( graph_data, column, - return_all, ignore_aromaticity, balance_its, - get_random_results, - fast_process, + get_priority_graph, ), ) graph_data = async_result.get(job_timeout) @@ -304,99 +200,153 @@ def process_single_graph_data_safe( pool.join() return graph_data - @staticmethod def process_graph_data_parallel( + self, graph_data_list: List[Dict], column: str, n_jobs: int, verbose: int, - return_all: bool = False, ignore_aromaticity: bool = False, balance_its: bool = True, - get_random_results: bool = False, - fast_process: bool = False, job_timeout: int = 5, + safe: bool = False, + get_priority_graph: bool = False, ) -> List[Dict]: """ - Processes a list of dictionaries containing graph information in parallel. + Processes a list of graph data dictionaries in parallel, utilizing multiple jobs + for faster processing. Parameters: - - graph_data_list (List[Dict]): A list of dictionaries containing graph - information. - - column (str): The key in the dictionary where the graph tuple is stored. - - n_jobs (int): The number of concurrent jobs. - - verbose (int): The verbosity level. + - graph_data_list (List[Dict]): List of dictionaries containing graph data. + - column (str): Key where the graph data is stored. + - n_jobs (int): Number of parallel jobs to run. + - verbose (int): Verbosity level for the parallel process. + - ignore_aromaticity (bool): Flag to ignore aromaticity during processing. + Default is False. + - balance_its (bool): Flag to balance ITS. + Default is True. + - job_timeout (int): Timeout for job processing in seconds. + Default is 5 seconds. + - safe (bool): Flag to use safe parallel processing (timeout). + Default is False. + - get_priority_graph (bool): Flag to prioritize graphs. + Default is False. Returns: - - List[Dict]: A list of dictionaries with the updated graph data. + - List[Dict]: A list of processed graph data dictionaries. """ - processed_data = Parallel(n_jobs=n_jobs, verbose=verbose)( - delayed(ITSHAdjuster.process_single_graph_data_safe)( - graph_data, - column, - return_all, - ignore_aromaticity, - balance_its, - get_random_results, - fast_process, - job_timeout, + if safe: + processed_data = Parallel(n_jobs=n_jobs, verbose=verbose)( + delayed(self.process_single_graph_data_safe)( + graph_data, + column, + ignore_aromaticity, + balance_its, + job_timeout, + get_priority_graph, + ) + for graph_data in graph_data_list + ) + else: + processed_data = Parallel(n_jobs=n_jobs, verbose=verbose)( + delayed(self.process_single_graph_data)( + graph_data, + column, + ignore_aromaticity, + balance_its, + get_priority_graph, + ) + for graph_data in graph_data_list ) - for graph_data in graph_data_list - ) return processed_data @staticmethod - def add_hydrogen_nodes_multiple_utils( - graph: nx.Graph, - node_id_pairs: Iterable[Tuple[int, int]], - atom_map_update: bool = True, - ) -> nx.Graph: + def process_multiple_hydrogens( + graph_data: Dict, + react_graph: nx.Graph, + prod_graph: nx.Graph, + ignore_aromaticity: bool, + balance_its: bool, + get_priority_graph: bool = False, + ) -> Dict: """ - Creates and returns a new graph with added hydrogen nodes based on the input graph - and node ID pairs. + Handles cases where hydrogen counts change significantly between the reactant + and product graphs. Adjusts hydrogen nodes accordingly and evaluates equivalence. Parameters: - - graph (nx.Graph): The base graph to which the nodes will be added. - - node_id_pairs (Iterable[Tuple[int, int]]): Pairs of node IDs (original node, new - hydrogen node) to link with hydrogen. - - atom_map_update (bool): If True, update the 'atom_map' attribute with the new - hydrogen node ID; otherwise, retain the original node's 'atom_map'. + - graph_data (Dict): Dictionary of graph data. + - react_graph (nx.Graph): Reactant graph. + - prod_graph (nx.Graph): Product graph. + - ignore_aromaticity (bool): Flag to ignore aromaticity. Default is False. + - balance_its (bool): Flag to balance ITS. Default is True. + - get_priority_graph (bool): Flag to prioritize graph processing. Default is False. Returns: - - nx.Graph: A new graph instance with the added hydrogen nodes. + - Dict: Updated graph data after handling hydrogen node adjustments. """ - new_graph = deepcopy(graph) - for node_id, new_hydrogen_node_id in node_id_pairs: - atom_map_val = ( - new_hydrogen_node_id - if atom_map_update - else new_graph.nodes[node_id].get("atom_map", 0) + combinations_solution = ITSHAdjuster.add_hydrogen_nodes_multiple( + react_graph, prod_graph + ) + # print("length original:", len(combinations_solution)) + its_list = [ + ITSConstruction.ITSGraph( + i[0], i[1], ignore_aromaticity, balance_its=balance_its ) - new_graph.add_node( - new_hydrogen_node_id, - charge=0, - hcount=0, - aromatic=False, - element="H", - atom_map=atom_map_val, - isomer="N", - partial_charge=0, - hybridization=0, - in_ring=False, - explicit_valence=0, - implicit_hcount=0, + for i in combinations_solution + ] + # rc_list = [get_rc(i) for i in its_list] + rc_list = [ + RuleExtraction.extract_reaction_rules(react_graph, prod_graph, i, False, 1)[ + 2 + ] + for i in its_list + ] + rc_list = [ + rc + for rc in rc_list + if rc is not None and isinstance(rc, nx.Graph) and rc.number_of_nodes() > 0 + ] + rc_sig = [GraphSignature(i).create_graph_signature() for i in rc_list] + combinations_solution = [ + comb + for rc, comb in zip(rc_list, combinations_solution) + if rc is not None and isinstance(rc, nx.Graph) and rc.number_of_nodes() > 0 + ] + its_list = [ + ITSConstruction.ITSGraph( + i[0], i[1], ignore_aromaticity, balance_its=balance_its ) - new_graph.add_edge( - node_id, - new_hydrogen_node_id, - order=1.0, - ez_isomer="N", - bond_type="SINGLE", - conjugated=False, - in_ring=False, + for i in combinations_solution + ] + if len(rc_sig) != len(set(rc_sig)): + _, equivariant = ITSHAdjuster.check_equivariant_graph(rc_list) + else: + equivariant = 0 + + pairwise_combinations = len(rc_list) - 1 + if equivariant == pairwise_combinations: + graph_data = ITSHAdjuster.update_graph_data( + graph_data, *combinations_solution[0], its_list[0] ) - return new_graph + else: + graph_data["ITSGraph"], graph_data["GraphRules"] = None, None + if get_priority_graph: + priority_indices = get_priority(rc_list) + rc_list = [rc_list[i] for i in priority_indices] + rc_sig = [rc_sig[i] for i in priority_indices] + its_list = [its_list[i] for i in priority_indices] + combinations_solution = [ + combinations_solution[i] for i in priority_indices + ] + if len(rc_sig) != len(set(rc_sig)): + _, equivariant = ITSHAdjuster.check_equivariant_graph(rc_list) + pairwise_combinations = len(rc_list) - 1 + if equivariant == pairwise_combinations: + graph_data = ITSHAdjuster.update_graph_data( + graph_data, *combinations_solution[0], its_list[0] + ) + return graph_data @staticmethod def add_hydrogen_nodes_multiple( @@ -475,3 +425,55 @@ def add_hydrogen_nodes_multiple( ) updated_graphs.append((current_react_graph, current_prod_graph)) return updated_graphs + + @staticmethod + def add_hydrogen_nodes_multiple_utils( + graph: nx.Graph, + node_id_pairs: Iterable[Tuple[int, int]], + atom_map_update: bool = True, + ) -> nx.Graph: + """ + Creates and returns a new graph with added hydrogen nodes based on the input graph + and node ID pairs. + + Parameters: + - graph (nx.Graph): The base graph to which the nodes will be added. + - node_id_pairs (Iterable[Tuple[int, int]]): Pairs of node IDs (original node, new + hydrogen node) to link with hydrogen. + - atom_map_update (bool): If True, update the 'atom_map' attribute with the new + hydrogen node ID; otherwise, retain the original node's 'atom_map'. + + Returns: + - nx.Graph: A new graph instance with the added hydrogen nodes. + """ + new_graph = deepcopy(graph) + for node_id, new_hydrogen_node_id in node_id_pairs: + atom_map_val = ( + new_hydrogen_node_id + if atom_map_update + else new_graph.nodes[node_id].get("atom_map", 0) + ) + new_graph.add_node( + new_hydrogen_node_id, + charge=0, + hcount=0, + aromatic=False, + element="H", + atom_map=atom_map_val, + isomer="N", + partial_charge=0, + hybridization=0, + in_ring=False, + explicit_valence=0, + implicit_hcount=0, + ) + new_graph.add_edge( + node_id, + new_hydrogen_node_id, + order=1.0, + ez_isomer="N", + bond_type="SINGLE", + conjugated=False, + in_ring=False, + ) + return new_graph diff --git a/syntemp/SynRule/rc_cluster.py b/syntemp/SynRule/rc_cluster.py index eb1a2e5..47ec0a4 100644 --- a/syntemp/SynRule/rc_cluster.py +++ b/syntemp/SynRule/rc_cluster.py @@ -147,3 +147,37 @@ def fit( value["class"] = cluster_indices[key] return new_data + + def fit_graphs( + self, graphs_data: List[nx.Graph], attribute: List[str] + ) -> Tuple[List[int], List[Dict[str, Any]]]: + """ + Automatically clusters the input graphs based on provided attributes and + determines the cluster indices for each graph. The method may use existing + templates for clustering or generate new ones. + + Parameters: + - graphs_data (List[nx.Graph]): A list of NetworkX graph objects to cluster. + - attribute (List[str]): A list of attributes associated with each graph, + used for clustering comparisons. + + Returns: + - Tuple[List[int], List[Dict[str, Any]]]: + - A list of cluster indices (List[int]) corresponding to each graph in + the input list, indicating which cluster each graph belongs to. + - A list of updated or newly generated templates (List[Dict[str, Any]]), + representing the cluster information or other relevant data for each + group of graphs. + """ + + # Perform clustering without predefined templates + _, graph_to_cluster_dict = self.auto_cluster( + graphs_data, attribute, self.nodeMatch, self.edgeMatch + ) + + # Generate the cluster indices based on the clustering result + cluster_indices = [ + graph_to_cluster_dict.get(i, None) for i in range(len(graphs_data)) + ] + + return cluster_indices diff --git a/syntemp/SynUtils/graph_utils.py b/syntemp/SynUtils/graph_utils.py index d95c392..bc57003 100644 --- a/syntemp/SynUtils/graph_utils.py +++ b/syntemp/SynUtils/graph_utils.py @@ -128,7 +128,7 @@ def check_graph_type(G: nx.Graph) -> str: return "Complex Cyclic" -def get_cycle_member_rings(G: nx.Graph) -> List[int]: +def get_cycle_member_rings(G: nx.Graph, type="minimal") -> List[int]: """ Identifies all cycles in the given graph using cycle bases to ensure no overlap and returns a list of the sizes of these cycles (member rings), @@ -144,7 +144,10 @@ def get_cycle_member_rings(G: nx.Graph) -> List[int]: raise TypeError("Input must be a networkx Graph object.") # Find cycle basis for the graph which gives non-overlapping cycles - cycles = nx.minimum_cycle_basis(G) + if type == "minimal": + cycles = nx.minimum_cycle_basis(G) + else: + cycles = nx.cycle_basis(G) # Determine the size of each cycle (member ring) member_rings = [len(cycle) for cycle in cycles] @@ -347,7 +350,9 @@ def get_priority(reaction_centers: List[Any]) -> List[int]: """ # Extract topology types and ring sizes from reaction centers topo_type = [check_graph_type(center) for center in reaction_centers] - cyclic = [get_cycle_member_rings(center) for center in reaction_centers] + cyclic = [ + get_cycle_member_rings(center, "fundamental") for center in reaction_centers + ] # Adjust ring information based on the graph type for index, graph_type in enumerate(topo_type): diff --git a/syntemp/__main__.py b/syntemp/__main__.py index 63a832d..79ba9b9 100644 --- a/syntemp/__main__.py +++ b/syntemp/__main__.py @@ -24,7 +24,7 @@ def parse_arguments(): "--rsmi", type=str, default="reactions", help="Reaction SMILES column" ) parser.add_argument( - "--n_jobs", type=int, default=1, help="Number of jobs to run in parallel" + "--n_jobs", type=int, default=4, help="Number of jobs to run in parallel" ) parser.add_argument("--verbose", type=int, default=2, help="Verbosity level") parser.add_argument( @@ -37,9 +37,6 @@ def parse_arguments(): parser.add_argument( "--fix_hydrogen", action="store_true", help="Enable fixing hydrogen" ) - parser.add_argument( - "--refinement_its", action="store_true", help="Refine non-equivalent ITS" - ) parser.add_argument( "--fast_process", action="store_true", @@ -97,7 +94,6 @@ def main(): safe_mode=args.safe_mode, save_dir=args.save_dir, fix_hydrogen=args.fix_hydrogen, - refinement_its=args.refinement_its, rerun_aam=args.rerun_aam, log_file=args.log_file, log_level=args.log_level, diff --git a/syntemp/auto_template.py b/syntemp/auto_template.py index c2df82d..b4a1d09 100644 --- a/syntemp/auto_template.py +++ b/syntemp/auto_template.py @@ -153,15 +153,14 @@ def temp_extract( # Step 3: Extract ITS graphs and categorize them self.logger.info("Extract ITS graphs and categorize them.") its_correct, its_incorrect, uncertain_hydrogen = extract_its( - aam_data, - self.mapper_types, - self.batch_size, - self.verbose, - self.n_jobs, - self.fix_hydrogen, - self.save_dir, - get_random_results=self.get_random_hydrogen, - fast_process=self.fast_process, + data=aam_data, + mapper_types=self.mapper_types, + batch_size=self.batch_size, + verbose=self.verbose, + n_jobs=self.n_jobs, + fix_hydrogen=self.fix_hydrogen, + save_dir=self.save_dir, + ) # Step 4: Extract rules from the correct ITS graphs diff --git a/syntemp/pipeline.py b/syntemp/pipeline.py index 4d1b997..0693013 100644 --- a/syntemp/pipeline.py +++ b/syntemp/pipeline.py @@ -159,10 +159,7 @@ def extract_its( fix_hydrogen: bool = True, save_dir: Optional[str] = None, data_name: str = "", - symbol: str = ">>", - get_random_results: bool = False, - fast_process: bool = False, - job_timeout: int = 5, + job_timeout: int = 60, ) -> List[dict]: """ Executes the extraction of ITS graphs from reaction data in batches, @@ -210,19 +207,16 @@ def extract_its( verbose=verbose, export_full=False, check_method="RC", - symbol=symbol, ) if fix_hydrogen: if i == 1 or (i % 10 == 0 and i >= 10): logger.info(f"Fixing hydrogen for batch {i + 1}/{num_batches}.") - batch_processed = ITSHAdjuster.process_graph_data_parallel( + batch_processed = ITSHAdjuster().process_graph_data_parallel( batch_correct, "ITSGraph", - n_jobs=n_jobs, + n_jobs=1, verbose=verbose, - get_random_results=get_random_results, - fast_process=fast_process, job_timeout=job_timeout, )