From 3a1778601ffd8da842f95a65b79c691ec5bddf18 Mon Sep 17 00:00:00 2001 From: Tieu Long Phan <125431507+TieuLongPhan@users.noreply.github.com> Date: Wed, 23 Oct 2024 11:35:45 +0200 Subject: [PATCH] GraphDescriptors (#6) * add graph_descriptors * update doc --- Test/SynGraph/__init__.py | 0 Test/SynGraph/test_graph_descriptors.py | 176 ++++++++++++++ doc/getting_started.rst | 67 +++++- synutility/SynGraph/__init__.py | 0 synutility/SynGraph/graph_descriptors.py | 282 +++++++++++++++++++++++ synutility/SynIO/data_type.py | 45 +++- synutility/misc.py | 43 ++++ 7 files changed, 610 insertions(+), 3 deletions(-) create mode 100644 Test/SynGraph/__init__.py create mode 100644 Test/SynGraph/test_graph_descriptors.py create mode 100644 synutility/SynGraph/__init__.py create mode 100644 synutility/SynGraph/graph_descriptors.py create mode 100644 synutility/misc.py diff --git a/Test/SynGraph/__init__.py b/Test/SynGraph/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/Test/SynGraph/test_graph_descriptors.py b/Test/SynGraph/test_graph_descriptors.py new file mode 100644 index 0000000..0b350e3 --- /dev/null +++ b/Test/SynGraph/test_graph_descriptors.py @@ -0,0 +1,176 @@ +import unittest +import networkx as nx +from synutility.SynGraph.graph_descriptors import GraphDescriptor + + +class TestGraphDescriptor(unittest.TestCase): + + def setUp(self): + # Creating different types of graphs + self.acyclic_graph = nx.balanced_tree( + r=2, h=3 + ) # Creates a balanced binary tree, which is acyclic + self.single_cyclic_graph = nx.cycle_graph(5) # Creates a cycle with 5 nodes + self.complex_cyclic_graph = ( + nx.house_x_graph() + ) # Known small graph with multiple cycles + self.empty_graph = nx.Graph() # Empty graph for testing + + # Reaction-center fixture: four atoms whose four edges close one 4-membered ring + self.graph = nx.Graph() + self.graph.add_node( + 11, + element="N", + charge=0, + hcount=1, + aromatic=False, + atom_map=11, + isomer="N", + partial_charge=-0.313, + hybridization="SP3", + in_ring=True, +
explicit_valence=3, + implicit_hcount=0, + neighbors=["C", "C"], + ) + self.graph.add_node( + 35, + element="H", + charge=0, + hcount=0, + aromatic=False, + atom_map=11, + isomer="N", + partial_charge=0, + hybridization="0", + in_ring=False, + explicit_valence=0, + implicit_hcount=0, + ) + self.graph.add_node( + 28, + element="C", + charge=0, + hcount=0, + aromatic=True, + atom_map=28, + isomer="N", + partial_charge=0.063, + hybridization="SP2", + in_ring=True, + explicit_valence=4, + implicit_hcount=0, + neighbors=["Br", "C", "C"], + ) + self.graph.add_node( + 29, + element="Br", + charge=0, + hcount=0, + aromatic=False, + atom_map=29, + isomer="N", + partial_charge=-0.047, + hybridization="SP3", + in_ring=False, + explicit_valence=1, + implicit_hcount=0, + neighbors=["C"], + ) + + # Adding edges with their attributes + self.graph.add_edge(11, 35, order=(1.0, 0), standard_order=1.0) + self.graph.add_edge(11, 28, order=(0, 1.0), standard_order=-1.0) + self.graph.add_edge(35, 29, order=(0, 1.0), standard_order=-1.0) + self.graph.add_edge(28, 29, order=(1.0, 0), standard_order=1.0) + # Input for get_descriptors: a list of dicts holding the graph under "RC" + self.data = [{"RC": self.graph}] + + def test_is_acyclic_graph(self): + self.assertTrue(GraphDescriptor.is_acyclic_graph(self.acyclic_graph)) + self.assertFalse(GraphDescriptor.is_acyclic_graph(self.single_cyclic_graph)) + self.assertFalse(GraphDescriptor.is_acyclic_graph(self.complex_cyclic_graph)) + self.assertFalse(GraphDescriptor.is_acyclic_graph(self.empty_graph)) + + def test_is_single_cyclic_graph(self): + self.assertFalse(GraphDescriptor.is_single_cyclic_graph(self.acyclic_graph)) + self.assertTrue( + GraphDescriptor.is_single_cyclic_graph(self.single_cyclic_graph) + ) + self.assertFalse( + GraphDescriptor.is_single_cyclic_graph(self.complex_cyclic_graph) + ) + self.assertFalse(GraphDescriptor.is_single_cyclic_graph(self.empty_graph)) + + def test_is_complex_cyclic_graph(self): +
self.assertFalse(GraphDescriptor.is_complex_cyclic_graph(self.acyclic_graph)) + self.assertFalse( + GraphDescriptor.is_complex_cyclic_graph(self.single_cyclic_graph) + ) + self.assertTrue( + GraphDescriptor.is_complex_cyclic_graph(self.complex_cyclic_graph) + ) + self.assertFalse(GraphDescriptor.is_complex_cyclic_graph(self.empty_graph)) + + def test_check_graph_type(self): + self.assertEqual( + GraphDescriptor.check_graph_type(self.acyclic_graph), "Acyclic" + ) + self.assertEqual( + GraphDescriptor.check_graph_type(self.single_cyclic_graph), "Single Cyclic" + ) + self.assertEqual( + GraphDescriptor.check_graph_type(self.complex_cyclic_graph), + "Combinatorial Cyclic", + ) + self.assertEqual( + GraphDescriptor.check_graph_type(self.empty_graph), "Empty Graph" + ) + + def test_get_cycle_member_rings(self): + self.assertEqual(GraphDescriptor.get_cycle_member_rings(self.acyclic_graph), []) + self.assertEqual( + GraphDescriptor.get_cycle_member_rings(self.single_cyclic_graph), [5] + ) + self.assertEqual( + GraphDescriptor.get_cycle_member_rings(self.complex_cyclic_graph), + [3, 3, 3, 3], + ) + self.assertEqual(GraphDescriptor.get_cycle_member_rings(self.empty_graph), []) + + def test_get_element_count(self): + # Expected results + expected_element_count = {"N": 1, "H": 1, "C": 1, "Br": 1} + + # Test get_element_count + self.assertEqual( + GraphDescriptor.get_element_count(self.graph), expected_element_count + ) + + def test_get_descriptors(self): + # Expected output after processing + expected_output = [ + { + "RC": self.graph, + "topo": "Single Cyclic", # the fixture is one ring through all four nodes + "cycle": [ + 4 + ], # a single 4-membered ring + "atom_count": {"N": 1, "H": 1, "C": 1, "Br": 1}, + "rtype": "Elementary", # "Single Cyclic" maps to "Elementary" + "rstep": 1, # len(entry["cycle"]) + } + ] + + # Run the descriptor function + GraphDescriptor.get_descriptors(self.data, "RC") + + # Validate that
the data has been enhanced correctly + for obtained, expected in zip(self.data, expected_output): + print("Hi", obtained) + print("hiii", expected) + self.assertDictEqual(obtained, expected) + + +if __name__ == "__main__": + unittest.main() diff --git a/doc/getting_started.rst b/doc/getting_started.rst index d6a3787..c25ee3d 100644 --- a/doc/getting_started.rst +++ b/doc/getting_started.rst @@ -1,13 +1,76 @@ +.. _getting-started-synutility: + =============== Getting Started =============== +Welcome to the **synutility** documentation. This guide will assist you in setting up **synutility**, a versatile toolkit designed to streamline and enhance your workflows. + +Introduction +------------ + +**synutility** is designed to provide tools and functions that simplify complex tasks and improve productivity. It is essential for professionals looking to optimize their processes efficiently. + Installation ============ +Installing **synutility** is straightforward using pip, the Python package installer. For a more controlled and conflict-free environment, we recommend setting up **synutility** within a virtual environment. + +Requirements +------------ + +Before installing **synutility**, ensure you have Python 3.11 or later installed on your machine. The use of a virtual environment is recommended to avoid conflicts with other Python packages. + +Creating a Virtual Environment (Optional) +----------------------------------------- + +Setting up a virtual environment for **synutility** is an optional but recommended step. It isolates the installation and dependencies of **synutility** from other Python projects, which helps prevent conflicts and maintain a clean workspace. + +1. **Create a virtual environment**: + + Use the following command to create a virtual environment named `synutility_env`. This step requires the `conda` package manager, which can be installed via the Anaconda distribution. + + .. code-block:: bash + + conda create -n synutility_env python=3.11 + +2.
**Activate the virtual environment**: + + Once the environment is created, activate it using: + + .. code-block:: bash + + conda activate synutility_env + +Installing **synutility** via Pip +--------------------------------- + +After setting up and activating your virtual environment, install **synutility** using pip by running the following command: + +.. code-block:: bash + + pip install synutility + +This command will download and install the latest version of **synutility** along with its required dependencies. + +Getting Started with **synutility** +=================================== + +After installation, you can begin using **synutility** by importing it into your Python scripts or interactive sessions. Here’s a quick example to get you started: + +.. code-block:: python + + import synutility + result = synutility.example_function('Hello, World!') + print(result) + +Further Resources +================= -Spliting Functions -================== +For more detailed documentation, usage examples, and API guides, visit the `official synutility documentation <https://tieulongphan.github.io/SynUtils>`_. +Support +------- +If you encounter any issues or require assistance, please refer to the community support forums or file an issue on the `synutility GitHub page <https://github.com/TieuLongPhan/SynUtils/issues>`_.
# Contents of synutility/SynGraph/graph_descriptors.py, the module this patch adds
# (the same patch also creates an empty synutility/SynGraph/__init__.py).
from collections import Counter, OrderedDict
from typing import Any, Dict, List, Union

import networkx as nx

from synutility.SynIO.debug import setup_logging

logger = setup_logging()


class GraphDescriptor:
    """
    Topology descriptors for reaction-center graphs.

    Every analysis helper is a static method that receives the graph explicitly;
    the instance only stores the graph handed to the constructor.
    """

    def __init__(self, graph: nx.Graph):
        self.graph = graph

    @staticmethod
    def _require_nx_graph(G: Any) -> None:
        """
        Shared input guard for the topology checks.

        Raises:
        - TypeError: If G is not a networkx Graph.
        """
        if not isinstance(G, nx.Graph):
            raise TypeError("Input must be a networkx Graph object.")

    @staticmethod
    def is_graph_empty(graph: Union[nx.Graph, dict, list, Any]) -> bool:
        """
        Determine if a graph representation is empty.

        Parameters:
        - graph (Union[nx.Graph, dict, list, Any]): A NetworkX graph, a dictionary,
        a list of rows, or any object exposing an 'is_empty()' method.

        Returns:
        - bool: True if the NetworkX graph has no nodes, the dictionary has no keys,
        every row of the list has length 0 (an empty list counts as empty), or
        'is_empty()' says so; otherwise False.

        Raises:
        - TypeError: If the graph representation is not supported.
        """
        if isinstance(graph, nx.Graph):
            return graph.number_of_nodes() == 0
        elif isinstance(graph, dict):
            return len(graph) == 0
        elif isinstance(graph, list):
            return all(len(row) == 0 for row in graph)
        elif hasattr(graph, "is_empty"):
            return graph.is_empty()
        else:
            raise TypeError("Unsupported graph representation")

    @staticmethod
    def is_acyclic_graph(G: nx.Graph) -> bool:
        """
        Determines if the given graph is acyclic.

        The check is delegated to 'nx.is_tree', so a graph made of several
        disconnected acyclic pieces is NOT reported as acyclic.

        Parameters:
        - G (nx.Graph): The graph to be checked.

        Returns:
        - bool: True if the graph is a non-empty tree, False otherwise
        (including for an empty graph).

        Raises:
        - TypeError: If G is not a networkx Graph.
        """
        GraphDescriptor._require_nx_graph(G)
        if GraphDescriptor.is_graph_empty(G):
            return False

        return nx.is_tree(G)

    @staticmethod
    def is_single_cyclic_graph(G: nx.Graph) -> bool:
        """
        Determines if the given graph is a single cyclic graph,
        which means the graph has exactly one cycle
        and all nodes in the graph are part of that cycle.

        Parameters:
        - G (nx.Graph): The graph to be checked.

        Returns:
        - bool: True if the graph is single cyclic, False otherwise
        (including for empty or disconnected graphs).

        Raises:
        - TypeError: If G is not a networkx Graph.
        """
        GraphDescriptor._require_nx_graph(G)

        if GraphDescriptor.is_graph_empty(G) or not nx.is_connected(G):
            return False

        cycles = nx.cycle_basis(G)
        if not cycles:
            return False

        nodes_in_cycles = {node for cycle in cycles for node in cycle}
        # Every node lies on a cycle, and the graph has as many edges as nodes.
        return (
            nodes_in_cycles == set(G.nodes())
            and G.number_of_edges() == G.number_of_nodes()
        )

    @staticmethod
    def is_complex_cyclic_graph(G: nx.Graph) -> bool:
        """
        Determines if the given graph is a complex cyclic graph,
        which means all nodes are part of cycles,
        there are multiple cycles, and there are no acyclic parts.

        Parameters:
        - G (nx.Graph): The graph to be checked.

        Returns:
        - bool: True if the graph is complex cyclic, False otherwise
        (including for empty or disconnected graphs).

        Raises:
        - TypeError: If G is not a networkx Graph.
        """
        GraphDescriptor._require_nx_graph(G)

        if GraphDescriptor.is_graph_empty(G) or not nx.is_connected(G):
            return False

        # Computed once; an empty or one-cycle basis fails the same test.
        cycles = nx.minimum_cycle_basis(G)
        if len(cycles) <= 1:
            return False

        # Cycles of the basis may share nodes, so collect them into a set.
        nodes_in_cycles = {node for cycle in cycles for node in cycle}

        return nodes_in_cycles == set(G.nodes())

    @staticmethod
    def check_graph_type(G: nx.Graph) -> str:
        """
        Classifies the topology of the given graph.

        Parameters:
        - G (nx.Graph): The graph to be checked.

        Returns:
        - str: One of
            "Empty Graph"          - the graph has no nodes,
            "Acyclic"              - is_acyclic_graph(G) is True,
            "Single Cyclic"        - is_single_cyclic_graph(G) is True,
            "Combinatorial Cyclic" - is_complex_cyclic_graph(G) is True,
            "Complex Cyclic"       - none of the above (for example a disconnected
                                     graph, or rings with nodes outside any cycle).

        Raises:
        - TypeError: If the input G is not a networkx Graph.
        """
        GraphDescriptor._require_nx_graph(G)

        if GraphDescriptor.is_graph_empty(G):
            return "Empty Graph"
        if GraphDescriptor.is_acyclic_graph(G):
            return "Acyclic"
        elif GraphDescriptor.is_single_cyclic_graph(G):
            return "Single Cyclic"
        elif GraphDescriptor.is_complex_cyclic_graph(G):
            return "Combinatorial Cyclic"
        else:
            return "Complex Cyclic"

    @staticmethod
    def get_cycle_member_rings(G: nx.Graph) -> List[int]:
        """
        Returns the sizes (member rings) of the cycles in a minimum cycle basis
        of the graph, sorted in ascending order. Cycles of the basis may share
        nodes with each other.

        Parameters:
        - G (nx.Graph): The NetworkX graph to be analyzed.

        Returns:
        - List[int]: A sorted list of cycle sizes; empty if the graph has no cycle.

        Raises:
        - TypeError: If G is not a networkx Graph.
        """
        GraphDescriptor._require_nx_graph(G)

        return sorted(len(cycle) for cycle in nx.minimum_cycle_basis(G))

    @staticmethod
    def get_element_count(graph: nx.Graph) -> Dict[str, int]:
        """
        Counts the occurrences of each chemical element in the graph nodes and returns
        a dictionary with these counts sorted alphabetically for easy comparison.

        Parameters:
        - graph (nx.Graph): A NetworkX graph object where each node has an
        'element' attribute.

        Returns:
        - Dict[str, int]: An ordered dictionary with element symbols as keys and
        their counts as values, sorted alphabetically by element symbol.

        Raises:
        - KeyError: If a node has no 'element' attribute.
        """
        element_counts = Counter(
            node_data["element"] for _, node_data in graph.nodes(data=True)
        )

        return OrderedDict(sorted(element_counts.items()))

    @staticmethod
    def get_descriptors(data: List[Dict], reaction_centers: str = "RC") -> List[Dict]:
        """
        Enhance data with topology type and reaction type descriptors.

        Each entry is modified in place: the keys "topo", "cycle", "atom_count",
        "rtype" and "rstep" are added. Entries whose reaction-center value is
        neither a NetworkX graph nor a list with a graph at index 2 are logged
        and left untouched.

        Parameters:
        - data (List[Dict]): List of dictionaries containing reaction data.
        - reaction_centers (str): Key for accessing the reaction centers in the
        data dictionaries.

        Returns:
        - List[Dict]: The same list object, with the descriptors added.
        """
        for entry in data:
            rc_data = entry.get(reaction_centers)
            if isinstance(rc_data, list):
                try:
                    graph = rc_data[2]
                except IndexError:
                    logger.error(
                        f"No graph data available at index 2 for entry {entry}"
                    )
                    continue
            elif isinstance(rc_data, nx.Graph):
                graph = rc_data
            else:
                logger.error(
                    f"Unsupported data type for reaction centers in entry {entry}"
                )
                continue

            entry["topo"] = GraphDescriptor.check_graph_type(graph)
            entry["cycle"] = GraphDescriptor.get_cycle_member_rings(graph)
            entry["atom_count"] = GraphDescriptor.get_element_count(graph)

            # Reaction type follows from the topology type.
            if entry["topo"] in ["Single Cyclic", "Acyclic"]:
                entry["rtype"] = "Elementary"
            else:
                entry["rtype"] = "Complicated"

            # Acyclic graphs are recorded as one zero-sized ring; "Complex Cyclic"
            # graphs get a leading zero in front of their ring sizes.
            if entry["topo"] == "Acyclic":
                entry["cycle"] = [0]
            elif entry["topo"] == "Complex Cyclic":
                entry["cycle"] = [0] + entry["cycle"]

            # Step count is the number of recorded rings, so "Acyclic" yields 1.
            entry["rstep"] = len(entry["cycle"])

        return data


def check_graph_connectivity(graph: nx.Graph) -> str:
    """
    Check the connectivity of a NetworkX graph.

    This function assesses whether all nodes in the graph are connected by some path,
    applicable to undirected graphs.

    Parameters:
    - graph (nx.Graph): A NetworkX graph object.

    Returns:
    - str: Returns 'Connected' if the graph is connected, otherwise 'Disconnected'.

    Raises:
    - NetworkXNotImplemented: If graph is directed and does not support is_connected.
    - NetworkXPointlessConcept: If the graph has no nodes (raised by nx.is_connected).
    """
    # The label carries no trailing period, matching "Connected" and the docstring.
    return "Connected" if nx.is_connected(graph) else "Disconnected"
# Code carried by the last two file diffs of this patch:
#   - load_from_pickle_generator / collect_data, added to synutility/SynIO/data_type.py
#   - stratified_random_sample, the whole of the new synutility/misc.py
import os
import pickle
import random
from collections import defaultdict
from typing import Any, Dict, Generator, List, Optional


def load_from_pickle_generator(file_path: str) -> Generator[Any, None, None]:
    """
    A generator that yields items from a pickle file where each pickle load returns
    a list of dictionaries. The file may hold several pickled batches written one
    after another; they are read in order until the end of the file.

    Parameters:
    - file_path (str): The path to the pickle file to load.

    Yields:
    - Any: A single item from the batches stored in the pickle file.
    """
    # NOTE: unpickling can execute arbitrary code; only read files from a trusted source.
    with open(file_path, "rb") as file:
        while True:
            # Only the load is guarded: EOFError here means "no more batches".
            try:
                batch_items = pickle.load(file)
            except EOFError:
                break
            yield from batch_items


def collect_data(num_batches: int, temp_dir: str, file_template: str) -> List[Any]:
    """
    Collects and aggregates data from multiple pickle files into a single list.

    Parameters:
    - num_batches (int): The number of batch files to process; files are read
    for indices 0 .. num_batches - 1.
    - temp_dir (str): The directory where the batch files are stored.
    - file_template (str): The template string for batch file names, expecting an
    integer formatter (it is filled with 'file_template.format(i)').

    Returns:
    - List[Any]: A list of aggregated data items from all batch files, in file order.
    """
    collected_data: List[Any] = []
    for i in range(num_batches):
        file_path = os.path.join(temp_dir, file_template.format(i))
        collected_data.extend(load_from_pickle_generator(file_path))
    return collected_data


def stratified_random_sample(
    data: List[Dict],
    property_key: str,
    samples_per_class: int,
    seed: Optional[int] = None,
) -> List[Dict]:
    """
    Stratifies and samples data from a list of dictionaries based on a specified property.

    Parameters:
    - data (List[Dict]): The data to sample from, a list of dictionaries.
    - property_key (str): The key in the dictionaries to stratify by. Items lacking
    the key are grouped under None.
    - samples_per_class (int): The number of samples to take from each class.
    - seed (Optional[int]): Seed for reproducible sampling. When given, a private
    random generator is used, so the module-level 'random' state is left untouched;
    when None, the module-level generator is used.

    Returns:
    - List[Dict]: The sampled dictionaries, class by class in order of first appearance.

    Raises:
    - ValueError: If any class holds fewer than samples_per_class items.
    """
    # A private generator gives the same draws as random.seed(seed) + random.sample
    # without reseeding the process-wide generator.
    sample = random.sample if seed is None else random.Random(seed).sample

    # Group data by the specified property, in order of first appearance.
    stratified_data: Dict[Any, List[Dict]] = defaultdict(list)
    for item in data:
        stratified_data[item.get(property_key)].append(item)

    # Sample data from each group.
    sampled_data: List[Dict] = []
    for key, items in stratified_data.items():
        if len(items) < samples_per_class:
            raise ValueError(
                f"Not enough data to sample {samples_per_class} items for class {key}"
            )
        sampled_data.extend(sample(items, samples_per_class))

    return sampled_data