From 4aa9601a112f6ec671ef1b60b210279aa93ae909 Mon Sep 17 00:00:00 2001 From: Tieu Long Phan <125431507+TieuLongPhan@users.noreply.github.com> Date: Mon, 28 Oct 2024 09:50:24 +0100 Subject: [PATCH] prepare release (#8) --- pyproject.toml | 10 +-- requirements.txt | 4 +- syntemp/SynITS/its_decomposer.py | 108 +++++++++++++++++++++++++++++++ syntemp/SynITS/its_hadjuster.py | 2 +- syntemp/pipeline.py | 2 +- 5 files changed, 117 insertions(+), 9 deletions(-) create mode 100644 syntemp/SynITS/its_decomposer.py diff --git a/pyproject.toml b/pyproject.toml index d901bdc..74107ab 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "syntemp" -version = "0.0.3" +version = "0.0.4" authors = [ {name="Tieu Long Phan", email="tieu@bioinf.uni-leipzig.de"} ] @@ -27,10 +27,10 @@ dependencies = [ "chytorch==1.60", "chytorch-rxnmap==1.4", "torchdata==0.7.1", - "rdkit==2023.9.5", - "networkx==3.3", - "seaborn==0.13.2", - "joblib==1.3.2", + "rdkit>=2023.9.5", + "networkx>=3.3", + "seaborn>=0.13.2", + "joblib>=1.3.2", "synrbl>=0.0.24", ] diff --git a/requirements.txt b/requirements.txt index 1a8bea2..6cb1475 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,8 +8,8 @@ chython==1.75 chytorch==1.60 chytorch-rxnmap==1.4 torchdata==0.7.1 -rdkit==2023.9.5 -networkx==3.3 +rdkit>=2023.9.5 +networkx>=3.3 seaborn==0.13.2 joblib==1.3.2 synrbl>=0.0.24 \ No newline at end of file diff --git a/syntemp/SynITS/its_decomposer.py b/syntemp/SynITS/its_decomposer.py new file mode 100644 index 0000000..0e54aa9 --- /dev/null +++ b/syntemp/SynITS/its_decomposer.py @@ -0,0 +1,108 @@ +import networkx as nx + + +def its_decompose(its_graph: nx.Graph, nodes_share="typesGH", edges_share="order"): + """ + Decompose an ITS graph into two separate graphs G and H based on shared + node and edge attributes. + + Parameters: + - its_graph (nx.Graph): The integrated transition state (ITS) graph. + - nodes_share (str): Node attribute key that stores tuples with node attributes + or G and H. + - edges_share (str): Edge attribute key that stores tuples with edge attributes + for G and H. + + Returns: + - Tuple[nx.Graph, nx.Graph]: A tuple containing the two graphs G and H. + """ + G = nx.Graph() + H = nx.Graph() + + # Decompose nodes + for node, data in its_graph.nodes(data=True): + if nodes_share in data: + node_attr_g, node_attr_h = data[nodes_share] + # Unpack node attributes for G + G.add_node( + node, + element=node_attr_g[0], + aromatic=node_attr_g[1], + hcount=node_attr_g[2], + charge=node_attr_g[3], + neighbors=node_attr_g[4], + ) + # Unpack node attributes for H + H.add_node( + node, + element=node_attr_h[0], + aromatic=node_attr_h[1], + hcount=node_attr_h[2], + charge=node_attr_h[3], + neighbors=node_attr_h[4], + ) + + # Decompose edges + for u, v, data in its_graph.edges(data=True): + if edges_share in data: + order_g, order_h = data[edges_share] + if order_g > 0: # Assuming 0 means no edge in G + G.add_edge(u, v, order=order_g) + if order_h > 0: # Assuming 0 means no edge in H + H.add_edge(u, v, order=order_h) + + return G, H + + +def compare_graphs( + graph1: nx.Graph, + graph2: nx.Graph, + node_attrs: list = ["element", "aromatic", "hcount", "charge", "neighbors"], + edge_attrs: list = ["order"], +) -> bool: + """ + Compare two graphs based on specified node and edge attributes. + + Parameters: + - graph1 (nx.Graph): The first graph to compare. + - graph2 (nx.Graph): The second graph to compare. + - node_attrs (list): A list of node attribute names to include in the comparison. + - edge_attrs (list): A list of edge attribute names to include in the comparison. + + Returns: + - bool: True if both graphs are identical with respect to the specified attributes, + otherwise False. + """ + # Compare node sets + if set(graph1.nodes()) != set(graph2.nodes()): + return False + + # Compare nodes based on attributes + for node in graph1.nodes(): + if node not in graph2: + return False + node_data1 = {attr: graph1.nodes[node].get(attr, None) for attr in node_attrs} + node_data2 = {attr: graph2.nodes[node].get(attr, None) for attr in node_attrs} + if node_data1 != node_data2: + return False + + # Compare edge sets with sorted tuples + if set(tuple(sorted(edge)) for edge in graph1.edges()) != set( + tuple(sorted(edge)) for edge in graph2.edges() + ): + return False + + # Compare edges based on attributes + for edge in graph1.edges(): + # Sort the edge for consistent comparison + sorted_edge = tuple(sorted(edge)) + if sorted_edge not in graph2.edges(): + return False + edge_data1 = {attr: graph1.edges[edge].get(attr, None) for attr in edge_attrs} + edge_data2 = { + attr: graph2.edges[sorted_edge].get(attr, None) for attr in edge_attrs + } + if edge_data1 != edge_data2: + return False + + return True diff --git a/syntemp/SynITS/its_hadjuster.py b/syntemp/SynITS/its_hadjuster.py index f34de12..a634a51 100644 --- a/syntemp/SynITS/its_hadjuster.py +++ b/syntemp/SynITS/its_hadjuster.py @@ -313,7 +313,7 @@ def process_graph_data_parallel( balance_its: bool = True, get_random_results: bool = False, fast_process: bool = False, - job_timeout: int = 1, + job_timeout: int = 5, ) -> List[Dict]: """ Processes a list of dictionaries containing graph information in parallel. diff --git a/syntemp/pipeline.py b/syntemp/pipeline.py index d4ffb47..24815cb 100644 --- a/syntemp/pipeline.py +++ b/syntemp/pipeline.py @@ -161,7 +161,7 @@ def extract_its( symbol: str = ">>", get_random_results: bool = False, fast_process: bool = False, - job_timeout: int = 1, + job_timeout: int = 5, ) -> List[dict]: """ Executes the extraction of ITS graphs from reaction data in batches,