diff --git a/Data/Testcase/graph.pkl.gz b/Data/Testcase/graph.pkl.gz
new file mode 100644
index 0000000..c3d649a
Binary files /dev/null and b/Data/Testcase/graph.pkl.gz differ
diff --git a/Test/SynGraph/Cluster/__init__.py b/Test/SynGraph/Cluster/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/Test/SynGraph/Cluster/test_batch_cluster.py b/Test/SynGraph/Cluster/test_batch_cluster.py
new file mode 100644
index 0000000..4c15fd4
--- /dev/null
+++ b/Test/SynGraph/Cluster/test_batch_cluster.py
@@ -0,0 +1,109 @@
+import time
+import unittest
+from synutility.SynIO.data_type import load_from_pickle
+from synutility.SynGraph.Descriptor.graph_signature import GraphSignature
+from synutility.SynGraph.Cluster.batch_cluster import BatchCluster
+
+
+class TestBatchCluster(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.graphs = load_from_pickle("Data/Testcase/graph.pkl.gz")
+        cls.templates = None
+        for value in cls.graphs:
+            value["rc_sig"] = GraphSignature(value["RC"]).create_graph_signature()
+            value["its_sig"] = GraphSignature(value["ITS"]).create_graph_signature()
+
+    def test_initialization(self):
+        """Test initialization and verify that the attributes are set correctly."""
+        cluster = BatchCluster(["element", "charge"], ["*", 0], "bond_order")
+        self.assertEqual(cluster.nodeLabelNames, ["element", "charge"])
+        self.assertEqual(cluster.nodeLabelDefault, ["*", 0])
+        self.assertEqual(cluster.edgeAttribute, "bond_order")
+
+    def test_initialization_failure(self):
+        """Test that initialization fails when the lengths of node labels and defaults do not match."""
+        with self.assertRaises(ValueError):
+            BatchCluster(["element"], ["*", 0, 1], "bond_order")
+
+    def test_batch_dicts(self):
+        """Test that the batching function splits data correctly."""
+        batch_cluster = BatchCluster(["element", "charge"], ["*", 0], "bond_order")
+        input_list = [{"id": i} for i in range(10)]
+        batches = batch_cluster.batch_dicts(input_list, 3)
+        self.assertEqual(len(batches), 4)
+        self.assertEqual(len(batches[0]), 3)
+        self.assertEqual(len(batches[-1]), 1)
+
+    def test_lib_check_functionality(self):
+        """Test the lib_check method using directly comparable results."""
+        cluster = BatchCluster()
+        batch_1 = self.graphs[:50]
+        batch_2 = self.graphs[50:]
+        _, templates = cluster.fit(batch_1, None, "RC", "rc_sig")
+        for entry in batch_2:
+            _, templates = cluster.lib_check(entry, templates, "RC", "rc_sig")
+        self.assertEqual(len(templates), 30)
+
+    def test_cluster_integration(self):
+        """Test that the cluster method processes data entries correctly."""
+        cluster = BatchCluster()
+        expected_template_count = 30
+        _, updated_templates = cluster.cluster(self.graphs, [], "RC", "rc_sig")
+
+        self.assertEqual(
+            len(updated_templates),
+            expected_template_count,
+            f"Failed: expected {expected_template_count} templates, got {len(updated_templates)}",
+        )
+
+    def test_fit(self):
+        cluster = BatchCluster()
+        batch_sizes = [None, 10]
+        expected_template_count = 30
+
+        for batch_size in batch_sizes:
+            start_time = time.time()
+            _, updated_templates = cluster.fit(
+                self.graphs, self.templates, "RC", "rc_sig", batch_size=batch_size
+            )
+            elapsed_time = time.time() - start_time
+
+            self.assertEqual(
+                len(updated_templates),
+                expected_template_count,
+                f"Failed for batch_size={batch_size}: expected "
+                + f"{expected_template_count} templates, got {len(updated_templates)}",
+            )
+            print(
+                f"Test for batch_size={batch_size} completed in {elapsed_time:.2f} seconds."
+            )
+
+    def test_fit_gml(self):
+        cluster = BatchCluster()
+        batch_sizes = [None, 10]
+        expected_template_count = (
+            30  # Assuming this is the expected number of templates after processing
+        )
+
+        for batch_size in batch_sizes:
+            start_time = time.time()
+            _, updated_templates = cluster.fit(
+                self.graphs, self.templates, "RC", "rc_sig", batch_size=batch_size
+            )
+            elapsed_time = time.time() - start_time
+
+            self.assertEqual(
+                len(updated_templates),
+                expected_template_count,
+                f"Failed for batch_size={batch_size}: expected"
+                + f" {expected_template_count} templates, got {len(updated_templates)}",
+            )
+            print(
+                f"Test for batch_size={batch_size} completed in {elapsed_time:.2f} seconds."
+            )
+
+
+# To run the tests
+if __name__ == "__main__":
+    unittest.main()
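The tests above exercise `lib_check`'s classify-or-create behavior: a cheap signature lookup narrows the candidate templates, and an entry that matches no template founds a new class. Below is a minimal, dependency-free sketch of that pattern; the `classify` helper and the `sig` key are illustrative only, not part of the library (the real `lib_check` additionally verifies signature matches with an isomorphism check).

```python
# Toy sketch of the classify-or-create pattern the tests exercise.
def classify(entry, templates, key="sig"):
    for template in templates:
        if template[key] == entry[key]:  # cheap pre-filter by signature
            entry["class"] = template["class"]
            return entry, templates
    # No match: open a new class and keep a copy as its template
    entry["class"] = max((t["class"] for t in templates), default=-1) + 1
    templates.append(dict(entry))
    return entry, templates

templates = []
for entry in [{"sig": "a"}, {"sig": "b"}, {"sig": "a"}]:
    _, templates = classify(entry, templates)

print([t["class"] for t in templates])  # [0, 1] -> two classes discovered
```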
diff --git a/Test/SynGraph/Cluster/test_graph_cluster.py b/Test/SynGraph/Cluster/test_graph_cluster.py
new file mode 100644
index 0000000..498adef
--- /dev/null
+++ b/Test/SynGraph/Cluster/test_graph_cluster.py
@@ -0,0 +1,138 @@
+import time
+import unittest
+from synutility.SynIO.data_type import load_from_pickle
+from synutility.SynGraph.Cluster.graph_cluster import GraphCluster
+from synutility.SynGraph.Descriptor.graph_descriptors import GraphDescriptor
+
+
+class TestRCCluster(unittest.TestCase):
+
+    @classmethod
+    def setUpClass(cls):
+        # Load data once for all tests
+        cls.graphs = load_from_pickle("Data/Testcase/graph.pkl.gz")
+        for value in cls.graphs:
+            # value["RC"] = value["GraphRules"][2]
+            # value["ITS"] = value["ITSGraph"][2]
+            value = GraphDescriptor.get_descriptors(value)
+        cls.clusterer = GraphCluster()
+
+    def test_initialization(self):
+        """Test the initialization and configuration of the GraphCluster."""
+        self.assertIsInstance(self.clusterer.nodeLabelNames, list)
+        self.assertEqual(self.clusterer.edgeAttribute, "order")
+        self.assertEqual(
+            len(self.clusterer.nodeLabelNames), len(self.clusterer.nodeLabelDefault)
+        )
+
+    def test_auto_cluster(self):
+        """Test the iterative_cluster method functionality."""
+        rc = [value["RC"] for value in self.graphs]
+        cycles = [value["cycle"] for value in self.graphs]
+        signature = [value["signature_rc"] for value in self.graphs]
+        atom_count = [value["atom_count"] for value in self.graphs]
+        for att in [None, cycles, signature, atom_count]:
+            clusters, graph_to_cluster = self.clusterer.iterative_cluster(
+                rc,
+                att,
+                nodeMatch=self.clusterer.nodeMatch,
+                edgeMatch=self.clusterer.edgeMatch,
+            )
+            self.assertIsInstance(clusters, list)
+            self.assertIsInstance(graph_to_cluster, dict)
+            self.assertEqual(len(clusters), 30)
+
+    def test_auto_cluster_wrong_isomorphism(self):
+        rc = [value["RC"] for value in self.graphs]
+        cycles = [value["cycle"] for value in self.graphs]
+        signature = [value["signature_rc"] for value in self.graphs]
+        atom_count = [value["atom_count"] for value in self.graphs]
+
+        # cluster all
+        clusters, _ = self.clusterer.iterative_cluster(
+            rc, None, nodeMatch=None, edgeMatch=None
+        )
+        self.assertEqual(len(clusters), 8)  # label-blind matching over-merges classes
+
+        # cluster with cycle
+        clusters, _ = self.clusterer.iterative_cluster(
+            rc, cycles, nodeMatch=None, edgeMatch=None
+        )
+        self.assertEqual(len(clusters), 8)  # label-blind matching over-merges classes
+
+        # cluster with atom_count
+        clusters, _ = self.clusterer.iterative_cluster(
+            rc, atom_count, nodeMatch=None, edgeMatch=None
+        )
+        self.assertEqual(len(clusters), 27)  # close, but still merges some classes
+
+        # cluster with signature
+        clusters, _ = self.clusterer.iterative_cluster(
+            rc, signature, nodeMatch=None, edgeMatch=None
+        )
+        self.assertEqual(len(clusters), 30)  # signatures alone separate all classes here; not guaranteed in general
+
+    def test_fit(self):
+        """Test the fit method to ensure it correctly updates data entries with cluster indices."""
+
+        clustered_data = self.clusterer.fit(
+            self.graphs, rule_key="RC", attribute_key="atom_count"
+        )
+        max_class = 0
+        for item in clustered_data:
+            print(item["class"])
+            max_class = item["class"] if item["class"] >= max_class else max_class
+            # print(max_class)
+            self.assertIn("class", item)
+        self.assertEqual(max_class, 29)  # 30 classes, zero-indexed, so the maximum is 29
+
+    def test_fit_gml(self):
+        """Test the fit method to ensure it correctly updates data entries with cluster indices."""
+
+        clustered_data = self.clusterer.fit(
+            self.graphs, rule_key="rc", attribute_key="atom_count"
+        )
+        max_class = 0
+        for item in clustered_data:
+            print(item["class"])
+            max_class = item["class"] if item["class"] >= max_class else max_class
+            # print(max_class)
+            self.assertIn("class", item)
+        self.assertEqual(max_class, 29)  # 30 classes, zero-indexed, so the maximum is 29
+
+    def test_fit_time_compare(self):
+        attributes = {
+            "None": None,
+            "Cycles": "cycle",
+            "Signature": "signature_rc",
+            "Atom_count": "atom_count",
+        }
+
+        results = {}
+        for name, attr in attributes.items():
+            start_time = time.time()
+            clustered_data = self.clusterer.fit(
+                self.graphs, rule_key="RC", attribute_key=attr
+            )
+            elapsed_time = time.time() - start_time
+
+            # Optionally print out class information or verify correctness
+            max_class = max(item["class"] for item in clustered_data if "class" in item)
+
+            results[name] = elapsed_time
+
+            # Basic verification that 'class' is assigned and max class is as expected
+            self.assertTrue(all("class" in item for item in clustered_data))
+            self.assertEqual(
+                max_class, 29
+            )  # Ensure the maximum class index is as expected
+
+        # Compare results to check which attribute took the least/most time
+        min_time_attr = min(results, key=results.get)
+        max_time_attr = max(results, key=results.get)
+        self.assertIn(min_time_attr, ["Atom_count", "Signature"])
+        self.assertIn(max_time_attr, ["None", "Cycles"])
+
+
+if __name__ == "__main__":
+    unittest.main()
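`test_fit_time_compare` expects the attribute pre-filters (`atom_count`, `signature_rc`) to beat `None` and `cycle`: only entries that share an attribute value can end up in the same cluster, so the expensive pairwise isomorphism checks run within buckets rather than across the whole dataset. A rough, self-contained illustration of why bucket sizes drive the cost (the keys and counts are made up):

```python
from collections import defaultdict

def pairwise_checks(bucket_sizes):
    """Worst-case number of pairwise comparisons given bucket sizes."""
    return sum(n * (n - 1) // 2 for n in bucket_sizes)

items = [{"atom_count": c} for c in (4, 6, 4, 5, 6, 4)]

buckets = defaultdict(int)
for item in items:
    buckets[item["atom_count"]] += 1

print(pairwise_checks([len(items)]))      # 15 checks with no pre-filter
print(pairwise_checks(buckets.values()))  # 4 checks after bucketing by atom_count
```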
diff --git a/lint.sh b/lint.sh
index b17d8cb..a08658d 100755
--- a/lint.sh
+++ b/lint.sh
@@ -1,6 +1,6 @@
 #!/bin/bash
 flake8 . \
   --count --max-complexity=13 --max-line-length=120 \
-  --per-file-ignores="__init__.py:F401, chemical_reaction_visualizer.py:E501, test_reagent.py:E501, inference.py:F401" \
+  --per-file-ignores="__init__.py:F401, chemical_reaction_visualizer.py:E501, test_reagent.py:E501, inference.py:F401, morphism.py:F401" \
   --exclude venv,core_engine.py,rule_apply.py \
   --statistics
diff --git a/synutility/SynGraph/Cluster/__init__.py b/synutility/SynGraph/Cluster/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/synutility/SynGraph/Cluster/batch_cluster.py b/synutility/SynGraph/Cluster/batch_cluster.py
new file mode 100644
index 0000000..67ad3b4
--- /dev/null
+++ b/synutility/SynGraph/Cluster/batch_cluster.py
@@ -0,0 +1,222 @@
+import networkx as nx
+from operator import eq
+from typing import List, Dict, Any, Tuple, Optional, Callable
+from networkx.algorithms.isomorphism import generic_node_match, generic_edge_match
+from synutility.misc import stratified_random_sample
+from synutility.SynGraph.GML.parse_rule import strip_context
+from synutility.SynGraph.Cluster.graph_cluster import GraphCluster
+from synutility.SynGraph.GML.morphism import graph_isomorphism, rule_isomorphism
+
+
+class BatchCluster:
+    def __init__(
+        self,
+        node_label_names: List[str] = ["element", "charge"],
+        node_label_default: List[Any] = ["*", 0],
+        edge_attribute: str = "order",
+    ):
+        """
+        Initializes a BatchCluster instance, which uses isomorphism checks to
+        categorize new graphs or rules.
+
+        Parameters:
+        - node_label_names (List[str]): Names of the node attributes to use in
+        isomorphism checks.
+        - node_label_default (List[Any]): Default values for node attributes if they are
+        missing in the graph data.
+        - edge_attribute (str): The edge attribute to consider when checking isomorphism
+        between graphs.
+
+        Raises:
+        - ValueError: If the lengths of `node_label_names` and `node_label_default`
+        do not match.
+        """
+        if len(node_label_names) != len(node_label_default):
+            raise ValueError(
+                "The lengths of `node_label_names` and `node_label_default` must match."
+            )
+
+        self.nodeLabelNames = node_label_names
+        self.nodeLabelDefault = node_label_default
+        self.edgeAttribute = edge_attribute
+        self.nodeMatch = generic_node_match(
+            self.nodeLabelNames, self.nodeLabelDefault, [eq] * len(node_label_names)
+        )
+        self.edgeMatch = generic_edge_match(edge_attribute, 1, eq)
+
+    def lib_check(
+        self,
+        data: Dict,
+        templates: List[Dict],
+        rule_key: str = "gml",
+        attribute_key: str = "signature",
+        nodeMatch: Optional[Callable] = None,
+        edgeMatch: Optional[Callable] = None,
+    ) -> Tuple[Dict, List[Dict]]:
+        """
+        Checks and classifies a graph or rule against existing templates using either graph or rule isomorphism.
+
+        Parameters:
+        - data (Dict): A dictionary representing a graph or rule with its attributes and
+        classification.
+        - templates (List[Dict]): Dynamic templates used for categorization. If None, initializes to an empty list.
+        - rule_key (str): Key to access the graph or rule data within the dictionary.
+        - attribute_key (str): An attribute used to filter templates before the isomorphism check.
+        - nodeMatch (Optional[Callable]): A function to match nodes; defaults to the predefined generic_node_match.
+        - edgeMatch (Optional[Callable]): A function to match edges; defaults to the predefined generic_edge_match.
+
+        Returns:
+        - Tuple[Dict, List[Dict]]: The updated entry and the updated template list.
+        """
+        # Ensure that templates are not None
+        if templates is None:
+            templates = []
+
+        att = data.get(attribute_key)
+        sub_temp = [temp for temp in templates if temp.get(attribute_key) == att]
+
+        for template in sub_temp:
+            template_data = (
+                strip_context(template[rule_key])
+                if isinstance(template[rule_key], str)
+                else template[rule_key]
+            )
+            data_rule = (
+                strip_context(data[rule_key])
+                if isinstance(data[rule_key], str)
+                else data[rule_key]
+            )
+
+            if isinstance(data_rule, str):
+                iso_function = rule_isomorphism
+                apply_match_args = False
+            elif isinstance(data_rule, nx.Graph):
+                iso_function = graph_isomorphism
+                apply_match_args = True
+
+            if apply_match_args:
+                if iso_function(
+                    template_data,
+                    data_rule,
+                    nodeMatch or self.nodeMatch,
+                    edgeMatch or self.edgeMatch,
+                ):
+                    data["class"] = template["class"]
+                    break
+            else:
+                if iso_function(template_data, data_rule):
+                    data["class"] = template["class"]
+                    break
+        else:
+            new_class = max((temp["class"] for temp in templates), default=-1) + 1
+            data["class"] = new_class
+            templates.append(data.copy())  # Append a copy to avoid reference issues
+
+        return data, templates
+
+    @staticmethod
+    def batch_dicts(input_list, batch_size):
+        """
+        Splits a list of dictionaries into batches of a specified size.
+
+        Parameters:
+        - input_list (list of dict): The list of dictionaries to be batched.
+        - batch_size (int): The size of each batch.
+
+        Returns:
+        - list of list of dict: A list where each element is a batch (sublist) of dictionaries.
+
+        Raises:
+        - ValueError: If batch_size is less than 1.
+        """
+
+        # Validate batch_size to ensure it's a positive integer
+        if batch_size < 1:
+            raise ValueError("batch_size must be at least 1")
+
+        # Initialize an empty list to hold the batches
+        batches = []
+
+        # Iterate over the input list in steps of batch_size
+        for i in range(0, len(input_list), batch_size):
+            # Append a batch slice to the batches list
+            # fmt: off
+            batches.append(input_list[i: i + batch_size])
+            # fmt: on
+
+        return batches
+
+    def cluster(
+        self,
+        data: List[Dict],
+        templates: List[Dict],
+        rule_key: str = "gml",
+        attribute_key: str = "signature",
+    ) -> Tuple[List[Dict], List[Dict]]:
+        """
+        Processes a list of graph data entries, classifying each based on existing templates.
+
+        Parameters:
+        - data (List[Dict]): A list of dictionaries, each representing a graph or rule
+        to be classified.
+        - templates (List[Dict]): Dynamic templates used for categorization.
+
+        Returns:
+        - Tuple[List[Dict], List[Dict]]: A tuple containing the list of classified data
+        and the updated templates.
+        """
+        for entry in data:
+            _, templates = self.lib_check(entry, templates, rule_key, attribute_key)
+        return data, templates
+
+    def fit(
+        self,
+        data: List[Dict],
+        templates: List[Dict],
+        rule_key: str = "gml",
+        attribute_key: str = "signature",
+        batch_size: Optional[int] = None,
+    ) -> Tuple[List[Dict], List[Dict]]:
+        """
+        Processes and classifies data in batches. Uses GraphCluster for the initial
+        processing and stratified sampling to build templates when there is only one
+        batch and no initial templates are provided.
+
+        Parameters:
+        - data (List[Dict]): Data to process.
+        - templates (List[Dict]): Templates for categorization.
+        - rule_key (str): Key to access rule or graph data.
+        - attribute_key (str): Key to access attributes used for filtering.
+        - batch_size (Optional[int]): Size of batches for processing; if not provided, all data is processed at once.
+
+        Returns:
+        - Tuple[List[Dict], List[Dict]]: The processed data and the potentially updated templates.
+        """
+        if batch_size is not None:
+            batches = self.batch_dicts(data, batch_size)
+        else:
+            batches = [data]  # Process all at once if no batch size is provided
+
+        output_data, output_templates = [], templates if templates is not None else []
+        graph_cluster = GraphCluster()
+
+        if len(batches) == 1:
+            batch = batches[0]
+            if not templates:
+                output_data = graph_cluster.fit(batch, rule_key, attribute_key)
+                output_templates = stratified_random_sample(
+                    output_data, property_key="class", samples_per_class=1, seed=1
+                )
+            else:
+                output_data, output_templates = self.cluster(
+                    batch, output_templates, rule_key, attribute_key
+                )
+        else:
+            for batch in batches:
+                processed_data, new_templates = self.cluster(
+                    batch, output_templates, rule_key, attribute_key
+                )
+                output_data.extend(processed_data)
+                output_templates = new_templates
+
+        return output_data, output_templates
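A hedged usage sketch of the API above. The toy graphs, attribute values, and the `rc_sig` key are made up for illustration; with real data the signature would come from `GraphSignature`, as in the tests:

```python
import networkx as nx
from synutility.SynGraph.Cluster.batch_cluster import BatchCluster

def toy_rc(order):
    # Minimal stand-in for a reaction-center graph with the expected attributes
    g = nx.Graph()
    g.add_node(1, element="C", charge=0)
    g.add_node(2, element="O", charge=0)
    g.add_edge(1, 2, order=order)
    return g

data = [
    {"RC": toy_rc(1), "rc_sig": "sig-a"},
    {"RC": toy_rc(1), "rc_sig": "sig-a"},  # isomorphic to the first entry
    {"RC": toy_rc(2), "rc_sig": "sig-b"},  # different edge order and signature
]

cluster = BatchCluster(["element", "charge"], ["*", 0], "order")
classified, templates = cluster.fit(data, None, rule_key="RC", attribute_key="rc_sig")

print([entry["class"] for entry in classified])  # e.g. [0, 0, 1]
print(len(templates))                            # one template per discovered class
```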
diff --git a/synutility/SynGraph/Cluster/graph_cluster.py b/synutility/SynGraph/Cluster/graph_cluster.py
new file mode 100644
index 0000000..0ec5490
--- /dev/null
+++ b/synutility/SynGraph/Cluster/graph_cluster.py
@@ -0,0 +1,163 @@
+import networkx as nx
+from operator import eq
+from collections import OrderedDict
+from typing import List, Set, Dict, Any, Tuple, Optional, Callable
+from networkx.algorithms.isomorphism import generic_node_match, generic_edge_match
+
+from synutility.SynGraph.GML.parse_rule import strip_context
+from synutility.SynGraph.GML.morphism import graph_isomorphism, rule_isomorphism
+
+
+class GraphCluster:
+    def __init__(
+        self,
+        node_label_names: List[str] = ["element", "charge"],
+        node_label_default: List[Any] = ["*", 0],
+        edge_attribute: str = "order",
+    ):
+        """
+        Initializes the GraphCluster with customization options for node and edge
+        matching functions. This class is designed to facilitate clustering of graph nodes
+        and edges based on specified attributes and their matching criteria.
+
+        Parameters:
+        - node_label_names (List[str]): A list of node attribute names to be considered
+        for matching; each corresponds to a property of the nodes in the graph.
+        - node_label_default (List[Any]): Default values for each of the node attributes
+        specified in `node_label_names`. These are used where node attributes are missing.
+        The length and order of this list must match `node_label_names`.
+        - edge_attribute (str): The name of the edge attribute to consider for matching
+        edges. This attribute is used to assess edge similarity.
+
+        Raises:
+        - ValueError: If the lengths of `node_label_names` and `node_label_default` do not
+        match.
+        """
+        if len(node_label_names) != len(node_label_default):
+            raise ValueError(
+                "The lengths of `node_label_names` and `node_label_default` must match."
+            )
+
+        self.nodeLabelNames = node_label_names
+        self.nodeLabelDefault = node_label_default
+        self.edgeAttribute = edge_attribute
+        self.nodeMatch = generic_node_match(
+            self.nodeLabelNames, self.nodeLabelDefault, [eq for _ in node_label_names]
+        )
+        self.edgeMatch = generic_edge_match(self.edgeAttribute, 1, eq)
+
+    def iterative_cluster(
+        self,
+        rules: List[str],
+        attributes: Optional[List[Any]] = None,
+        nodeMatch: Optional[Callable] = None,
+        edgeMatch: Optional[Callable] = None,
+    ) -> Tuple[List[Set[int]], Dict[int, int]]:
+        """
+        Clusters rules based on their similarities, which could include structural or
+        attribute-based similarities depending on the given attributes.
+
+        Parameters:
+        - rules (List[str]): List of rules, either serialized rule strings or
+        NetworkX graphs.
+        - attributes (Optional[List[Any]]): Attributes associated with each rule for
+        preliminary comparison, e.g., labels or properties.
+        - nodeMatch (Optional[Callable]): Node matching function for graph isomorphism.
+        - edgeMatch (Optional[Callable]): Edge matching function for graph isomorphism.
+
+        Returns:
+        - Tuple[List[Set[int]], Dict[int, int]]: A tuple containing a list of sets
+        (clusters), where each set contains indices of rules in the same cluster,
+        and a dictionary mapping each rule index to its cluster index.
+        """
+        # Determine the appropriate isomorphism function based on rule type
+        if isinstance(rules[0], str):
+            iso_function = rule_isomorphism
+            apply_match_args = (
+                False  # rule_isomorphism does not use nodeMatch or edgeMatch
+            )
+        elif isinstance(rules[0], nx.Graph):
+            iso_function = graph_isomorphism
+            apply_match_args = True  # graph_isomorphism uses nodeMatch and edgeMatch
+
+        if attributes is None:
+            attributes_sorted = [1] * len(rules)
+        elif isinstance(attributes[0], (str, int, float)):
+            attributes_sorted = attributes
+        elif isinstance(attributes[0], (dict, OrderedDict)):
+            # Sort dict items so comparison is order-independent
+            attributes_sorted = [
+                OrderedDict(sorted(value.items())) for value in attributes
+            ]
+        else:
+            attributes_sorted = [sorted(value) for value in attributes]
+
+        visited = set()
+        clusters = []
+        rule_to_cluster = {}
+
+        for i, rule_i in enumerate(rules):
+            if i in visited:
+                continue
+            cluster = {i}
+            visited.add(i)
+            rule_to_cluster[i] = len(clusters)
+            # fmt: off
+            for j, rule_j in enumerate(rules[i + 1:], start=i + 1):
+                # fmt: on
+                if attributes_sorted[i] == attributes_sorted[j] and j not in visited:
+                    # Conditionally use matching functions
+                    if apply_match_args:
+                        is_isomorphic = iso_function(
+                            rule_i, rule_j, nodeMatch, edgeMatch
+                        )
+                    else:
+                        is_isomorphic = iso_function(rule_i, rule_j)
+
+                    if is_isomorphic:
+                        cluster.add(j)
+                        visited.add(j)
+                        rule_to_cluster[j] = len(clusters)
+
+            clusters.append(cluster)
+
+        return clusters, rule_to_cluster
+
+    def fit(
+        self,
+        data: List[Dict],
+        rule_key: str = "gml",
+        attribute_key: str = "signature",
+    ) -> List[Dict]:
+        """
+        Clusters the rules in `data` and assigns each entry a cluster index
+        based on similarity.
+
+        Parameters:
+        - data (List[Dict]): A list containing dictionaries, each representing a
+        rule along with metadata.
+        - rule_key (str): The key in the dictionaries under `data` where the rule data
+        is stored.
+        - attribute_key (str): The key in the dictionaries under `data` where rule
+        attributes are stored.
+
+        Returns:
+        - List[Dict]: Updated list of dictionaries with an added 'class' key for cluster
+        identification.
+        """
+        if isinstance(data[0][rule_key], str):
+            rules = [strip_context(entry[rule_key]) for entry in data]
+        else:
+            rules = [entry[rule_key] for entry in data]
+
+        attributes = (
+            [entry.get(attribute_key) for entry in data] if attribute_key else None
+        )
+        _, rule_to_cluster_dict = self.iterative_cluster(
+            rules, attributes, self.nodeMatch, self.edgeMatch
+        )
+
+        for index, entry in enumerate(data):
+            entry["class"] = rule_to_cluster_dict.get(index, None)
+
+        return data
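The default matchers built in `__init__` are worth spelling out, since the `test_auto_cluster_wrong_isomorphism` cases above show what happens without them: `nx.is_isomorphic` with no matchers compares bare structure only. A self-contained illustration of the attribute-aware matching (NetworkX only; the toy graphs are made up):

```python
from operator import eq
import networkx as nx
from networkx.algorithms.isomorphism import generic_node_match, generic_edge_match

# The same matchers GraphCluster builds: compare element/charge on nodes
# (missing values fall back to "*" and 0) and order on edges (default 1).
node_match = generic_node_match(["element", "charge"], ["*", 0], [eq, eq])
edge_match = generic_edge_match("order", 1, eq)

g1 = nx.Graph()
g1.add_node("a", element="C")           # no charge: defaults to 0 during matching
g1.add_node("b", element="O", charge=0)
g1.add_edge("a", "b", order=2)

g2 = nx.Graph()
g2.add_node(1, element="N", charge=0)   # different element
g2.add_node(2, element="O", charge=0)
g2.add_edge(1, 2, order=2)

# Structure-only check ignores labels entirely (the over-merged clustering above)
print(nx.is_isomorphic(g1, g2))  # True
# Attribute-aware check distinguishes C-O from N-O
print(nx.is_isomorphic(g1, g2, node_match=node_match, edge_match=edge_match))  # False
```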
+ """ + if isinstance(data[0][rule_key], str): + rules = [strip_context(entry[rule_key]) for entry in data] + else: + rules = [entry[rule_key] for entry in data] + + attributes = ( + [entry.get(attribute_key) for entry in data] if attribute_key else None + ) + _, rule_to_cluster_dict = self.iterative_cluster( + rules, attributes, self.nodeMatch, self.edgeMatch + ) + + for index, entry in enumerate(data): + entry["class"] = rule_to_cluster_dict.get(index, None) + + return data diff --git a/synutility/SynGraph/GML/morphism.py b/synutility/SynGraph/GML/morphism.py index fa8f341..421bf60 100644 --- a/synutility/SynGraph/GML/morphism.py +++ b/synutility/SynGraph/GML/morphism.py @@ -1,6 +1,55 @@ +import torch +from operator import eq +import networkx as nx +from typing import Callable, Optional +from networkx.algorithms.isomorphism import generic_node_match, generic_edge_match + + from mod import ruleGMLString +def graph_isomorphism( + graph_1: nx.Graph, + graph_2: nx.Graph, + node_match: Optional[Callable] = None, + edge_match: Optional[Callable] = None, + use_defaults: bool = False, +) -> bool: + """ + Determines if two graphs are isomorphic, considering provided node and edge matching + functions. Uses default matching settings if none are provided. + + Parameters: + - graph_1 (nx.Graph): The first graph to compare. + - graph_2 (nx.Graph): The second graph to compare. + - node_match (Optional[Callable]): The function used to match nodes. + Uses default if None. + - edge_match (Optional[Callable]): The function used to match edges. + Uses default if None. + + Returns: + - bool: True if the graphs are isomorphic, False otherwise. + """ + # Define default node and edge attributes and match settings + if use_defaults: + node_label_names = ["element", "charge"] + node_label_default = ["*", 0] + edge_attribute = "order" + + # Default node and edge match functions if not provided + if node_match is None: + node_match = generic_node_match( + node_label_names, node_label_default, [eq] * len(node_label_names) + ) + if edge_match is None: + edge_match = generic_edge_match(edge_attribute, 1, eq) + + # Perform the isomorphism check using NetworkX + return nx.is_isomorphic( + graph_1, graph_2, node_match=node_match, edge_match=edge_match + ) + + def rule_isomorphism( rule_1: str, rule_2: str, morphism_type: str = "isomorphic" ) -> bool: diff --git a/synutility/SynGraph/GML/parse_rule.py b/synutility/SynGraph/GML/parse_rule.py index 43a0271..08b107f 100644 --- a/synutility/SynGraph/GML/parse_rule.py +++ b/synutility/SynGraph/GML/parse_rule.py @@ -104,14 +104,75 @@ def filter_context(context_lines, relevant_nodes): return filtered_context -def strip_context(gml_text: str) -> str: +# def strip_context(gml_text: str) -> str: +# """ +# Filters the 'context' section of GML-like content to remove hydrogen nodes +# that do not appear in both 'left' and 'right' sections, along with their edges. +# Preserves the original structure and formatting of the GML. + +# Parameters: +# - gml_text (str): GML-like content describing a chemical reaction rule. + +# Returns: +# - str: The modified GML content with the filtered 'context' section. 
+# """ +# lines = gml_text.split("\n") + +# # Locate main sections: rule, left, context, right +# rule_start, rule_end = find_block(lines, "rule [") +# left_start, left_end = find_block(lines, "left [") +# context_start, context_end = find_block(lines, "context [") +# right_start, right_end = find_block(lines, "right [") + +# # If we cannot find proper structure, return original text +# if any( +# x is None +# for x in [ +# rule_start, +# rule_end, +# left_start, +# left_end, +# context_start, +# context_end, +# right_start, +# right_end, +# ] +# ): +# return gml_text + +# # fmt: off +# left_lines = lines[left_start: left_end + 1] +# context_lines = lines[context_start: context_end + 1] +# right_lines = lines[right_start: right_end + 1] +# # fmt: on + +# # Determine relevant nodes by intersection of nodes in left and right edges +# left_nodes = get_nodes_from_edges(left_lines) +# right_nodes = get_nodes_from_edges(right_lines) +# relevant_nodes = left_nodes.intersection(right_nodes) + +# # Filter the context section +# filtered_context = filter_context(context_lines, relevant_nodes) + +# # Rebuild the full GML text +# # Replace the original context lines with the filtered context lines +# # fmt: off +# new_lines = lines[:context_start] + filtered_context + lines[context_end + 1:] +# # fmt: on + +# return "\n".join(new_lines) + + +def strip_context(gml_text: str, remove_all: bool = True) -> str: """ - Filters the 'context' section of GML-like content to remove hydrogen nodes - that do not appear in both 'left' and 'right' sections, along with their edges. - Preserves the original structure and formatting of the GML. + Filters or clears the 'context' section of GML-like content based on the remove_all flag. + If remove_all is True, all edges in the 'context' section are removed. + If False, it removes hydrogen nodes that do not appear in both 'left' and 'right' sections, + along with their edges, while preserving the original structure and formatting of the GML. Parameters: - gml_text (str): GML-like content describing a chemical reaction rule. + - remove_all (bool): Flag to determine if all edges should be removed from the 'context'. Returns: - str: The modified GML content with the filtered 'context' section. 
@@ -141,21 +202,28 @@ def strip_context(gml_text: str) -> str: return gml_text # fmt: off - left_lines = lines[left_start: left_end + 1] context_lines = lines[context_start: context_end + 1] - right_lines = lines[right_start: right_end + 1] - # fmt: on # Determine relevant nodes by intersection of nodes in left and right edges - left_nodes = get_nodes_from_edges(left_lines) - right_nodes = get_nodes_from_edges(right_lines) + left_nodes = get_nodes_from_edges(lines[left_start: left_end + 1]) + right_nodes = get_nodes_from_edges(lines[right_start: right_end + 1]) + # fmt: on relevant_nodes = left_nodes.intersection(right_nodes) - # Filter the context section + # Filter the context section based on relevant nodes filtered_context = filter_context(context_lines, relevant_nodes) + if remove_all: + # Remove all edges from the context + # Retain only node lines and other structural lines + final_context = [] + for line in filtered_context: + if not EDGE_REGEX.search(line.strip()): + final_context.append(line) + filtered_context = final_context + # Rebuild the full GML text - # Replace the original context lines with the filtered context lines + # Replace the original context lines with the filtered or cleared context lines # fmt: off new_lines = lines[:context_start] + filtered_context + lines[context_end + 1:] # fmt: on
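To make the new flag concrete, a hedged example of calling the updated `strip_context`. The GML snippet is a hand-written toy rule for illustration; real rules come from rule serialization and may be formatted differently:

```python
from synutility.SynGraph.GML.parse_rule import strip_context

gml = """rule [
   ruleID "toy"
   left [
      edge [ source 1 target 2 label "-" ]
   ]
   context [
      node [ id 1 label "C" ]
      node [ id 2 label "O" ]
      edge [ source 1 target 2 label "-" ]
   ]
   right [
      edge [ source 1 target 2 label "=" ]
   ]
]"""

# Default (remove_all=True): the context keeps its node lines but loses
# every edge line, which is what the isomorphism checks above rely on.
print(strip_context(gml))

# remove_all=False: previous behavior, dropping only hydrogen nodes (and
# their edges) that are not present in both the left and right sections.
print(strip_context(gml, remove_all=False))
```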