diff --git a/Test/SynGraph/Descriptor/test_graph_signature.py b/Test/SynGraph/Descriptor/test_graph_signature.py index 4f51349..37e096a 100644 --- a/Test/SynGraph/Descriptor/test_graph_signature.py +++ b/Test/SynGraph/Descriptor/test_graph_signature.py @@ -10,43 +10,92 @@ def setUp(self): data = load_from_pickle("Data/test.pkl.gz") self.rc = data[0]["GraphRules"][2] self.its = data[0]["ITSGraph"][2] + self.graph_signature = GraphSignature(self.its) - def test_create_topology_signature(self): - signature = GraphSignature(self.rc) - self.assertEqual( - signature.create_topology_signature( - topo="Single Cyclic", cycle=[4], rstep=1 - ), - "114", + def test_validate_graph(self): + """Test the validation of graph structure""" + # Test should pass if graph is valid (no exceptions) + try: + self.graph_signature._validate_graph() + except ValueError as e: + self.fail(f"Graph validation failed: {str(e)}") + + def test_edge_signature(self): + """Test edge signature creation""" + edge_signature = self.graph_signature.create_edge_signature( + include_neighbors=False, max_hop=1 + ) + + # Check that the edge signature is a non-empty string + self.assertIsInstance(edge_signature, str) + self.assertGreater(len(edge_signature), 0) + self.assertIn("Br0", edge_signature) + self.assertIn("{0.0,1.0}", edge_signature) + self.assertIn("H0", edge_signature) + + def test_edge_signature_with_neighbors(self): + """Test edge signature creation including neighbors""" + edge_signature_with_neighbors = self.graph_signature.create_edge_signature( + include_neighbors=True, max_hop=1 ) - def test_create_node_signature(self): - signature = GraphSignature(self.rc) - self.assertEqual(signature.create_node_signature(), "BrCHN") + # Check that the edge signature with neighbors includes node degrees and neighbors + self.assertIsInstance(edge_signature_with_neighbors, str) + self.assertGreater(len(edge_signature_with_neighbors), 0) + self.assertIn("d1", edge_signature_with_neighbors) # node degree for neighbor - def test_create_node_signature_condensed(self): - signature = GraphSignature(self.its) - self.assertEqual(signature.create_node_signature(), "BrC{23}ClHN{3}O{5}S") + def test_wl_hash(self): + """Test the Weisfeiler-Lehman hash generation""" + wl_hash = self.graph_signature.create_wl_hash(iterations=3) - def test_create_edge_signature(self): - signature = GraphSignature(self.rc) - self.assertEqual( - signature.create_edge_signature(), "Br[-1]H/Br[1]C/C[-1]N/H[1]N" + # Check that the WL hash is a valid hexadecimal string + self.assertIsInstance(wl_hash, str) + self.assertRegex(wl_hash, r"^[a-f0-9]{64}$") # SHA-256 hash format + + def test_graph_signature(self): + """Test the complete graph signature creation""" + complete_graph_signature = self.graph_signature.create_graph_signature( + include_wl_hash=True, include_neighbors=True, max_hop=1 ) - def test_create_graph_signature(self): - # Ensure the graph signature combines the results correctly - signature = GraphSignature(self.rc) - node_signature = "BrCHN" - edge_signature = "Br[-1]H/Br[1]C/C[-1]N/H[1]N" - topo_signature = "114" - expected = f"{topo_signature}.{node_signature}.{edge_signature}" - self.assertEqual( - signature.create_graph_signature(topo="Single Cyclic", cycle=[4], rstep=1), - expected, + # Check that the graph signature is a non-empty string + self.assertIsInstance(complete_graph_signature, str) + self.assertGreater(len(complete_graph_signature), 0) + + def test_invalid_node_attributes(self): + """Test for missing node attributes""" + self.rc.add_node(4) # Missing 'element' and 'charge' + + with self.assertRaises(ValueError) as context: + invalid_graph_signature = GraphSignature(self.rc) + invalid_graph_signature._validate_graph() + + self.assertIn( + "Node 4 is missing the 'element' attribute", str(context.exception) ) + def test_invalid_edge_order(self): + """Test for invalid edge 'order' attribute""" + self.its.add_edge( + 3, 4, order="invalid_order", state="steady" + ) # Invalid 'order' type + + with self.assertRaises(ValueError) as context: + invalid_graph_signature = GraphSignature(self.its) + invalid_graph_signature._validate_graph() + + self.assertIn("Edge (3, 4) has an invalid 'order'", str(context.exception)) + + def test_invalid_edge_state(self): + """Test for invalid edge 'state' attribute""" + self.its.add_edge(2, 4, order=1.0, state="invalid_state") # Invalid 'state' + + with self.assertRaises(ValueError) as context: + invalid_graph_signature = GraphSignature(self.its) + invalid_graph_signature._validate_graph() + + self.assertIn("Edge (2, 4) has an invalid 'state'", str(context.exception)) + -# Running the tests if __name__ == "__main__": unittest.main() diff --git a/pyproject.toml b/pyproject.toml index 24607e9..7ed9bcf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "synutility" -version = "0.0.12" +version = "0.0.13" authors = [ {name="Tieu Long Phan", email="tieu@bioinf.uni-leipzig.de"} ] @@ -22,6 +22,7 @@ dependencies = [ "rdkit>=2024.3.3", "networkx>=3.3", "seaborn>=0.13.2", + "requests>=3.4.0", ] [project.optional-dependencies] diff --git a/requirements.txt b/requirements.txt index 35d3b27..5369dc6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,4 +7,5 @@ rxn-chem-utils==1.5.0 rxn-utils==2.0.0 rxnmapper==0.3.0 rdkit >= 2024.3.3 -pandas>=2.2.0 \ No newline at end of file +pandas>=2.2.0 +requests>=3.4.0 \ No newline at end of file diff --git a/synutility/SynGraph/Descriptor/graph_signature.py b/synutility/SynGraph/Descriptor/graph_signature.py index 5f926a5..0f558c2 100644 --- a/synutility/SynGraph/Descriptor/graph_signature.py +++ b/synutility/SynGraph/Descriptor/graph_signature.py @@ -1,14 +1,12 @@ +import hashlib import networkx as nx -from collections import Counter class GraphSignature: """ - Provides methods to generate canonical signatures for graph nodes, edges, and complete graphs, - useful for comparisons or identification in graph-based data structures. - - Attributes: - graph (nx.Graph): The graph for which signatures will be generated. + Provides methods to generate canonical signatures for graph edges (with flexible 'order' and 'state' attributes, + and node degrees/neighbor information), various spectral invariants, adjacency matrix, and complete graphs. + Aims for high uniqueness without relying solely on isomorphism checks. """ def __init__(self, graph: nx.Graph): @@ -19,104 +17,220 @@ def __init__(self, graph: nx.Graph): - graph (nx.Graph): A NetworkX graph instance. """ self.graph = graph + self._validate_graph() - def create_node_signature(self, condensed: bool = True) -> str: + def _validate_graph(self): """ - Generates a canonical node signature. If `condensed` is True, it condenses - consecutive occurrences of elements, formatting like 'Br{1}C{10}'. - - Parameters: - - condensed (bool): If True, condenses elements with counts. If False, keeps the original format. + Validates that all nodes have the required attributes ('element' and 'charge'), + and all edges have the required 'order' attribute as int, float, or tuple of two floats, + and optionally the 'state' attribute. - Returns: - - str: A concatenated string of sorted node elements, optionally with counts. + Raises: + - ValueError: If any node is missing the 'element' or 'charge' attribute, + or if any edge is missing the 'order' attribute or has an invalid type. """ - # Sort elements - elements = sorted(data["element"] for _, data in self.graph.nodes(data=True)) - - if condensed: - # Count occurrences and format with counts - element_counts = Counter(elements) - signature_parts = [] - for element, count in element_counts.items(): - if count > 1: - signature_parts.append(f"{element}{{{count}}}") - else: - signature_parts.append(element) - return "".join(signature_parts) - else: - # Return the original, uncompressed format - return "".join(elements) - - def create_edge_signature(self) -> str: + for node, data in self.graph.nodes(data=True): + if "element" not in data: + raise ValueError(f"Node {node} is missing the 'element' attribute.") + if "charge" not in data: + raise ValueError(f"Node {node} is missing the 'charge' attribute.") + + for u, v, data in self.graph.edges(data=True): + if "order" not in data: + raise ValueError(f"Edge ({u}, {v}) is missing the 'order' attribute.") + order = data["order"] + if isinstance(order, tuple): + if len(order) != 2 or not all( + isinstance(o, (int, float)) for o in order + ): + raise ValueError( + f"Edge ({u}, {v}) has an invalid 'order'. It must be a tuple of two ints/floats." + ) + elif not isinstance(order, (int, float)): + raise ValueError( + f"Edge ({u}, {v}) has an invalid 'order'. It must be an int, float, or a tuple of two ints/floats." + ) + + # Optional: Validate 'state' attribute if present + state = data.get("state", "steady") # Default to 'steady' if missing + if state not in {"break", "form", "steady"}: + raise ValueError( + f"Edge ({u}, {v}) has an invalid 'state'. It must be 'break', 'form', or 'steady'." + ) + + def create_edge_signature( + self, include_neighbors: bool = False, max_hop: int = 2 + ) -> str: """ - Generates a canonical edge signature by formatting each edge with sorted node elements and a bond order, - separated by '/', with each edge represented as 'node1[standard_order]node2'. + Generates a canonical edge signature by formatting each edge with sorted node elements (including charge), + node degrees, bond order, bond state, and optionally including neighbor information and topological context. + + Parameters: + - include_neighbors (bool): Whether to include neighbors' details in the edge signature. + - max_hop (int): Maximum number of hops to include for neighbor-level structural information. Returns: - str: A concatenated and sorted string of edge representations. """ edge_signature_parts = [] + for u, v, data in self.graph.edges(data=True): - standard_order = int( - data.get("standard_order", 1) - ) # Default to 1 if missing - node1, node2 = sorted( - [self.graph.nodes[u]["element"], self.graph.nodes[v]["element"]] - ) - part = f"{node1}[{standard_order}]{node2}" - edge_signature_parts.append(part) + # Retrieve bond order (default to (1.0, 1.0) if missing) + order = data.get("order", (1.0, 1.0)) + + # Format order as a tuple (default or actual value) + if isinstance(order, tuple): + order_str = f"{{{order[0]:.1f},{order[1]:.1f}}}" + else: + order_str = f"{float(order):.1f}" + + # Get node elements and charges for both nodes + node1_element = self.graph.nodes[u].get( + "element", "X" + ) # Default to 'X' if missing + node1_charge = self.graph.nodes[u].get( + "charge", 0 + ) # Default to 0 if missing + node2_element = self.graph.nodes[v].get("element", "X") + node2_charge = self.graph.nodes[v].get("charge", 0) + + # Construct node representation with element and charge + node1 = f"{node1_element}{node1_charge}" + node2 = f"{node2_element}{node2_charge}" + + # Optionally include neighbors in the signature + if include_neighbors: + neighbors_u = sorted( + [ + f"{self.graph.nodes[neighbor].get('element', 'X')}{self.graph.nodes[neighbor].get('charge', 0)}" + + f"d{self.graph.degree(neighbor)}" + for neighbor in self.graph.neighbors(u) + ] + ) + neighbors_v = sorted( + [ + f"{self.graph.nodes[neighbor].get('element', 'X')}{self.graph.nodes[neighbor].get('charge', 0)}" + + f"d{self.graph.degree(neighbor)}" + for neighbor in self.graph.neighbors(v) + ] + ) + + # Represent neighbors within square brackets + node1_neighbors = "".join(neighbors_u) + node2_neighbors = "".join(neighbors_v) + node1 = f"{node1}[{node1_neighbors}]" + node2 = f"{node2}[{node2_neighbors}]" + + # Include k-hop neighborhood information + if max_hop > 1: + node1_neighbors_khop = self._get_khop_neighbors(u, max_hop) + node2_neighbors_khop = self._get_khop_neighbors(v, max_hop) + node1 += f"[{node1_neighbors_khop}]" + node2 += f"[{node2_neighbors_khop}]" + + # Sort nodes to ensure consistency in edge signature (avoid direction dependency) + node1, node2 = sorted([node1, node2]) + + # Format the edge signature and append it + edge_part = f"{node1}{order_str}{node2}" + edge_signature_parts.append(edge_part) + + # Sort all edge signatures to ensure consistency in the final representation return "/".join(sorted(edge_signature_parts)) - def create_topology_signature(self, topo, cycle, rstep) -> str: + def _get_khop_neighbors(self, node, max_hop): """ - Generates a topology signature for the graph based on its cyclic properties and structure. - The topology is classified and quantified by identifying cycles and other structural features. + Retrieves the k-hop neighborhood information for a given node. + + Parameters: + - node (int): The node for which to get neighborhood information. + - max_hop (int): Maximum number of hops for neighborhood exploration. Returns: - - str: A string representing the numerical and qualitative topology signature of the graph. + - str: A concatenated string representing the k-hop neighborhood information. """ + k_hop_neighbors = [] + current_hop_neighbors = [node] + for _ in range(max_hop): + next_hop_neighbors = [] + for n in current_hop_neighbors: + next_hop_neighbors.extend(list(self.graph.neighbors(n))) + # Filter out already seen nodes to avoid loops + next_hop_neighbors = set(next_hop_neighbors) - set(k_hop_neighbors) + k_hop_neighbors.extend(next_hop_neighbors) + current_hop_neighbors = next_hop_neighbors + + # Return sorted k-hop neighborhood info + return "".join( + sorted( + [ + f"{self.graph.nodes[neighbor].get('element', 'X')}{self.graph.nodes[neighbor].get('charge', 0)}" + for neighbor in k_hop_neighbors + ] + ) + ) - topo_mapping = { - "Acyclic": 0, - "Single Cyclic": 1, - "Combinatorial Cyclic": 2, - "Complex Cyclic": 3, - } + def create_wl_hash(self, iterations: int = 3) -> str: + """ + Generates a Weisfeiler-Lehman (WL) hash for the graph to capture its structural features. - topo_code = topo_mapping.get(topo, 4) + Parameters: + - iterations (int): Number of WL iterations to perform. - rstep = len(cycle) - cycle_str = "".join(map(str, cycle)) - return f"{rstep}{topo_code}{cycle_str}" + Returns: + - str: A hexadecimal hash representing the WL feature. + """ + # Initialize labels with both 'element' and 'charge' + labels = { + node: f"{data['element']}{data.get('charge', 0)}" + for node, data in self.graph.nodes(data=True) + } + for _ in range(iterations): + new_labels = {} + for node in self.graph.nodes(): + # Gather sorted labels of neighbors + neighbor_labels = sorted( + labels[neighbor] for neighbor in self.graph.neighbors(node) + ) + # Concatenate current label with neighbor labels + concatenated = labels[node] + "".join(neighbor_labels) + # Hash the concatenated string to obtain a new label + new_label = hashlib.sha256(concatenated.encode()).hexdigest() + new_labels[node] = new_label + labels = new_labels + # Aggregate all node labels into a sorted string and hash it + sorted_labels = sorted(labels.values()) + aggregated = "".join(sorted_labels) + graph_hash = hashlib.sha256(aggregated.encode()).hexdigest() + return graph_hash def create_graph_signature( self, - condensed: bool = True, - topology: bool = True, - nodes: bool = True, - edges: bool = True, - topo: str = None, - cycle: list = None, - rstep: int = None, + include_wl_hash: bool = True, + include_neighbors: bool = True, + max_hop: int = 1, ) -> str: """ - Combines node, edge, and topology signatures into a single comprehensive graph signature. + Combines edge, various spectral invariants, and WL hash into a single comprehensive graph signature. + + Parameters: + - include_wl_hash (bool): Whether to include the Weisfeiler-Lehman hash. + - include_spectral (bool): Whether to include spectral invariants. + - include_combined_hash (bool): Whether to include the combined hash. + - include_neighbors (bool): Whether to include neighbor information in edge signatures. Returns: - - str: A concatenated string representing the complete graph signature formatted as - 'topology_signature.node_signature.edge_signature'. + - str: A concatenated string representing the complete graph signature. """ - if topology: - topo_signature = self.create_topology_signature(topo, cycle, rstep) - else: - topo_signature = "" - if nodes: - node_signature = self.create_node_signature(condensed) - else: - node_signature = "" - if edges: - edge_signature = self.create_edge_signature() - else: - edge_signature = "" - return f"{topo_signature}.{node_signature}.{edge_signature}" + signatures = [] + + if include_wl_hash: + wl_signature = self.create_wl_hash() + signatures.append(f"{wl_signature}") + + edge_signature = self.create_edge_signature( + include_neighbors=include_neighbors, max_hop=max_hop + ) + signatures.append(f"{edge_signature}") + + return "|".join(signatures) diff --git a/synutility/SynIO/Format/smi_to_id.py b/synutility/SynIO/Format/smi_to_id.py new file mode 100644 index 0000000..bd7bb71 --- /dev/null +++ b/synutility/SynIO/Format/smi_to_id.py @@ -0,0 +1,119 @@ +import time +import requests +import urllib.parse +from typing import List +from joblib import Parallel, delayed + + +def smiles_to_iupac(smiles_string: str, timeout: int = 1): + """ + Converts a SMILES string to its corresponding IUPAC name(s) using the PubChem PUG REST API. + + Parameters: + - smiles_string (str): The SMILES string of the compound (e.g., "C=O" for formaldehyde). + - timeout (int, optional): The timeout in seconds for the request. Default is 1 second. + + Returns: + - list: A list of IUPAC names associated with the SMILES string. Returns an empty list if none found. + """ + # URL encode the SMILES string to handle special characters + encoded_smiles = urllib.parse.quote(smiles_string) + + # PubChem PUG REST API endpoint to retrieve properties + url = f"https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/smiles/{encoded_smiles}/property/IUPACName/JSON" + + retries = 3 # Number of retries in case of failure + delay = 2 # Delay between retries (in seconds) + + for attempt in range(retries): + try: + response = requests.get(url, timeout=timeout) # Adjust timeout for speed + response.raise_for_status() # Raise an HTTPError for bad responses + + data = response.json() + + # Extract the IUPAC name(s) from the response + properties = data.get("PropertyTable", {}).get("Properties", []) + + if not properties: + print(f"No properties found for SMILES: {smiles_string}") + return [] + + iupac_names = [ + prop.get("IUPACName") for prop in properties if prop.get("IUPACName") + ] + + if iupac_names: + return iupac_names + else: + print(f"No IUPAC name found for SMILES: {smiles_string}") + return [] + + except (requests.exceptions.RequestException, ValueError, KeyError) as e: + # If an error occurs, retry a few times + print( + f"Attempt {attempt + 1} failed for SMILES: {smiles_string}, Error: {e}" + ) + if attempt < retries - 1: + time.sleep(delay) # Wait before retrying + else: + print(f"Final failure for SMILES: {smiles_string}") + return [] + + return [] + + +def batch_process_smiles(smiles_batch: List[str], timeout=1): + """ + Processes a batch of SMILES strings to get IUPAC names. + + Parameters: + - smiles_batch (list): A list of SMILES strings to process. + - timeout (int): Timeout for requests (in seconds). + + Returns: + - list: A list of IUPAC name results for each SMILES in the batch. + """ + return [smiles_to_iupac(smiles, timeout) for smiles in smiles_batch] + + +def get_iupac_for_smiles_list( + smiles_list: List[str], batch_size=10, n_jobs=4, timeout=1 +): + """ + Convert a list of SMILES strings to their corresponding IUPAC names using the PubChem API with batch processing. + + Parameters: + smiles_list (list): A list of SMILES strings to be converted to IUPAC names. + batch_size (int): Number of SMILES strings to process in each batch. + n_jobs (int): Number of parallel jobs to run for batch processing. + timeout (int): Timeout for requests (in seconds). + + Returns: + dict: A dictionary with SMILES as keys and lists of IUPAC names as values. + """ + # Split the list into smaller batches + # fmt: off + batches = [ + smiles_list[i: i + batch_size] for i in range(0, len(smiles_list), batch_size) + ] + # fmt: on + + # Use joblib's Parallel and delayed to process batches in parallel + batch_results = Parallel(n_jobs=n_jobs)( + delayed(batch_process_smiles)(batch, timeout) for batch in batches + ) + + # Flatten the list of results and map to SMILES + flattened_results = [item for sublist in batch_results for item in sublist] + iupac_dict = dict(zip(smiles_list, flattened_results)) + + return iupac_dict + + +# Example of usage +smiles_list = ["CCO", "C=O", "CC(=O)O", "C1=CC=CC=C1", "C2H6O", "C4H10", "C5H12"] +iupac_results = get_iupac_for_smiles_list(smiles_list, batch_size=3, n_jobs=2) + +for smiles, iupac_names in iupac_results.items(): + print(f"SMILES: {smiles} => IUPAC Names: {iupac_names}") diff --git a/synutility/SynVis/graph_visualizer.py b/synutility/SynVis/graph_visualizer.py index 84179e7..1973e77 100644 --- a/synutility/SynVis/graph_visualizer.py +++ b/synutility/SynVis/graph_visualizer.py @@ -57,12 +57,13 @@ def plot_its( edge_color: str = "#000000", edge_weight: float = 2.0, show_atom_map: bool = False, - use_edge_color: bool = False, # + use_edge_color: bool = False, symbol_key: str = "element", bond_key: str = "order", aam_key: str = "atom_map", standard_order_key: str = "standard_order", font_size: int = 12, + rule: bool = False, # New option to remove edges with specific colors ): """ Plot an intermediate transition state (ITS) graph on a given Matplotlib axes with various customizations. @@ -76,13 +77,14 @@ def plot_its( - node_size (int): Size of the graph nodes. - edge_color (str): Default color code for the graph edges if not using conditional coloring. - edge_weight (float): Thickness of the graph edges. - - show_aam (bool): If True, displays atom mapping numbers alongside symbols. + - show_atom_map (bool): If True, displays atom mapping numbers alongside symbols. - use_edge_color (bool): If True, colors edges based on their 'standard_order' attribute. - symbol_key (str): Key to access the symbol attribute in the node's data. - bond_key (str): Key to access the bond type attribute in the edge's data. - aam_key (str): Key to access the atom mapping number in the node's data. - standard_order_key (str): Key to determine the edge color conditionally. - font_size (int): Font size for labels and edge labels. + - rule (bool): If True, removes edges with a specific color before plotting. Returns: - None @@ -109,6 +111,30 @@ def plot_its( else: edge_colors = edge_color + # If rule=True, remove edges with specific colors (red/green/black) + if rule: + # Get the edges that have the colors red, green, or black + edges_to_remove = [ + edge + for edge, color in zip(its.edges(), edge_colors) + if color in ["red", "green", "black"] + ] + its.remove_edges_from(edges_to_remove) + + # Recalculate edge_colors after removal of edges + if use_edge_color: + edge_colors = [ + ( + "red" + if data.get(standard_order_key, 0) > 0 + else "green" if data.get(standard_order_key, 0) < 0 else "black" + ) + for _, _, data in its.edges(data=True) + ] + else: + edge_colors = edge_color + + # Plot the remaining graph nx.draw_networkx_edges( its, positions, edge_color=edge_colors, width=edge_weight, ax=ax ) diff --git a/synutility/SynVis/rsmi_to_fig.py b/synutility/SynVis/rsmi_to_fig.py index a08520d..8b15e93 100644 --- a/synutility/SynVis/rsmi_to_fig.py +++ b/synutility/SynVis/rsmi_to_fig.py @@ -6,6 +6,7 @@ from synutility.SynVis.graph_visualizer import GraphVisualizer from synutility.SynIO.Format.chemical_conversion import rsmi_to_graph from synutility.SynAAM.its_construction import ITSConstruction +from synutility.SynIO.Format.gml_to_nx import GMLToNX vis_graph = GraphVisualizer() @@ -17,18 +18,34 @@ def three_graph_vis( orientation: str = "horizontal", show_titles: bool = True, show_atom_map: bool = False, + titles: Tuple[str, str, str] = ( + "Reactants", + "Imaginary Transition State", + "Products", + ), + add_gridbox: bool = False, + rule: bool = False, ) -> plt.Figure: """ - Visualize three related graphs (reactants, intermediate transition state, and products) + Visualize three related graphs (reactants, imaginary transition state, and products) side by side or vertically in a single figure. Parameters: - - input (Union[str, Tuple[nx.Graph, nx.Graph, nx.Graph]]): Either a reaction SMILES string - or a tuple of three NetworkX graphs (reactants, ITS, products). - - sanitize (bool, optional): If True, sanitizes the input molecule. Default is False. - - figsize (Tuple[int, int], optional): The size of the Matplotlib figure. Default is (18, 5). - - orientation (str, optional): Layout of the subplots; 'horizontal' or 'vertical'. Default is 'horizontal'. - - show_titles (bool, optional): If True, adds titles to each subplot. Default is True. + - input (Union[str, Tuple[nx.Graph, nx.Graph, nx.Graph]]): Either + a reaction SMILES stringor a tuple of three NetworkX graphs + (reactants, products, ITS). + - sanitize (bool, optional): If True, sanitizes the input molecule. + Default is False. + - figsize (Tuple[int, int], optional): The size of the Matplotlib figure. + Default is (18, 5). + - orientation (str, optional): Layout of the subplots; 'horizontal' or 'vertical'. + Default is 'horizontal'. + - show_titles (bool, optional): If True, adds titles to each subplot. + Default is True. + - titles (Tuple[str, str, str], optional): Custom titles for each subplot. + Default is ('Reactants', 'Imaginary Transition State', 'Products'). + - add_gridbox (bool, optional): If True, adds a gridbox cover for each subplot + (rectangular frame). Default is False. Returns: - plt.Figure: The Matplotlib figure containing the three subplots. @@ -63,11 +80,13 @@ def three_graph_vis( edge_width=2.0, ) if show_titles: - ax[0].set_title("Reactants") + ax[0].set_title(titles[0]) - vis_graph.plot_its(its, ax[1], use_edge_color=True, show_atom_map=show_atom_map) + vis_graph.plot_its( + its, ax[1], use_edge_color=True, show_atom_map=show_atom_map, rule=rule + ) if show_titles: - ax[1].set_title("Imaginary Transition State") + ax[1].set_title(titles[1]) vis_graph.plot_as_mol( p, @@ -78,9 +97,74 @@ def three_graph_vis( edge_width=2.0, ) if show_titles: - ax[2].set_title("Products") + ax[2].set_title(titles[2]) + + # Add gridbox frame around each subplot if requested + if add_gridbox: + for a in ax: + # Make sure the grid is on top of the plot + a.set_axisbelow(False) + + # Add a rectangular frame (gridbox) with thicker borders + for spine in a.spines.values(): + spine.set_visible(True) + spine.set_linewidth(2) + spine.set_color("black") + + # Make gridlines lighter and under the plot elements + a.grid( + True, + which="both", + axis="both", + linestyle="--", + color="gray", + alpha=0.5, + ) return fig except Exception as e: raise RuntimeError(f"An error occurred during visualization: {str(e)}") + + +def rule_visualize(gml, rule=True, titles=None): + """ + Visualizes a reaction network from GML data with optional edge filtering (rule) + and custom titles. + + Parameters: + - gml (str): GML format string representing the reaction data. + - rule (bool): If True, applies the rule to filter edges + (e.g., removing edges based on color). + - titles (list, optional): List of titles for the subplots. + Defaults to ['L', 'K', 'R']. + + Returns: + - plt.Figure: Matplotlib figure containing the visualized reaction network. + """ + try: + # Transform GML to NetworkX graphs + r, p, its = GMLToNX(gml).transform() + + # If no titles are provided, default to ['L', 'K', 'R'] + if titles is None: + titles = ["L", "K", "R"] + + # Ensure titles match the number of graphs (3) + if len(titles) != 3: + raise ValueError( + "The titles list must contain exactly three titles for the three graphs." + ) + + # Call the `three_graph_vis` function with the transformed graphs and rule filtering + return three_graph_vis( + (r, p, its), + add_gridbox=True, # Add the gridbox around the plot + titles=titles, # Pass the titles for the subplots + rule=rule, # Apply the rule filtering based on the value of `rule` + ) + + except Exception as e: + raise RuntimeError( + f"An error occurred during the visualization process: {str(e)}" + )