diff --git a/.gitignore b/.gitignore index 14385be..1706557 100644 --- a/.gitignore +++ b/.gitignore @@ -3,11 +3,5 @@ *.csv */catboost_info/* *.ipynb -test.py -rebalance_test.py -split_comparison.py -fp.py - *.json -split_benchmark_process.py -synutility/SynChem/Reaction/misc.py +test_mod.py diff --git a/Data/test.pkl.gz b/Data/test.pkl.gz new file mode 100644 index 0000000..e1caa71 Binary files /dev/null and b/Data/test.pkl.gz differ diff --git a/Test/SynChem/Molecule/__init__.py b/Test/SynChem/Molecule/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/Test/SynChem/Molecule/test_standardize.py b/Test/SynChem/Molecule/test_standardize.py new file mode 100644 index 0000000..46c5644 --- /dev/null +++ b/Test/SynChem/Molecule/test_standardize.py @@ -0,0 +1,90 @@ +import unittest +from rdkit import Chem +from synutility.SynChem.Molecule.standardize import ( + normalize_molecule, + canonicalize_tautomer, + salts_remover, + uncharge_molecule, + fragments_remover, + remove_explicit_hydrogens, + remove_radicals_and_add_hydrogens, + remove_isotopes, + clear_stereochemistry, +) + + +class TestMoleculeFunctions(unittest.TestCase): + + def test_normalize_molecule(self): + smi = "[Na]OC(=O)c1ccc(C[S+2]([O-])([O-]))cc1" + expect = "O=C(O[Na])c1ccc(C[S](=O)=O)cc1" + mol = Chem.MolFromSmiles(smi) + normalized_mol = normalize_molecule(mol) + self.assertIsInstance(normalized_mol, Chem.Mol) + self.assertEqual(expect, Chem.MolToSmiles(normalized_mol)) + + def test_canonicalize_tautomer(self): + smi = "N=c1[nH]cc[nH]1" + expect = "Nc1ncc[nH]1" + mol = Chem.MolFromSmiles(smi) + tautomer = canonicalize_tautomer(mol) + self.assertIsInstance(tautomer, Chem.Mol) + self.assertEqual(expect, Chem.MolToSmiles(tautomer)) + + def test_salts_remover(self): + smi = "CC(=O).[Na+]" + expect = "CC=O" + mol = Chem.MolFromSmiles(smi) + remover = salts_remover(mol) + self.assertIsInstance(remover, Chem.Mol) + self.assertEqual(expect, Chem.MolToSmiles(remover)) + + def 
test_uncharge_molecule(self): + smi = "CC(=O)[O-]" + expect = "CC(=O)O" + mol = Chem.MolFromSmiles(smi) + uncharged_mol = uncharge_molecule(mol) + self.assertIsInstance(uncharged_mol, Chem.Mol) + self.assertEqual(expect, Chem.MolToSmiles(uncharged_mol)) + + def test_fragments_remover(self): + smi = "CC(=O)[O-].[Na+]" + expect = "CC(=O)[O-]" + mol = Chem.MolFromSmiles(smi) + remover = fragments_remover(mol) + self.assertIsInstance(remover, Chem.Mol) + self.assertEqual(expect, Chem.MolToSmiles(remover)) + + def test_remove_explicit_hydrogens(self): + smi = "[CH4]" + expect = "C" + mol = Chem.MolFromSmiles(smi) + remover = remove_explicit_hydrogens(mol) + self.assertIsInstance(remover, Chem.Mol) + self.assertEqual(expect, Chem.MolToSmiles(remover)) + + def test_remove_radicals(self): + smi = "[CH3]" + expect = "C" + mol = Chem.MolFromSmiles(smi) + remover = remove_radicals_and_add_hydrogens(mol) + self.assertIsInstance(remover, Chem.Mol) + self.assertEqual(expect, Chem.MolToSmiles(remover)) + + def test_remove_isotopes(self): + # Molecule with isotopic labeling + smiles = "[13CH3]C([2H])([2H])[17O][18OH]" + expect = "[H]C([H])(C)OO" + mol = Chem.MolFromSmiles(smiles) + result_mol = remove_isotopes(mol) + for atom in result_mol.GetAtoms(): + self.assertEqual(atom.GetIsotope(), 0, "Isotopes not properly removed") + self.assertEqual(Chem.MolToSmiles(result_mol), expect) + + def test_clear_stereochemistry(self): + # Molecule with defined stereochemistry + smiles = "C[C@H](O)[C@@H](O)C" + mol = Chem.MolFromSmiles(smiles) + result_mol = clear_stereochemistry(mol) + has_stereo = any(atom.HasProp("_CIPCode") for atom in result_mol.GetAtoms()) + self.assertFalse(has_stereo, "Stereochemistry not properly cleared") diff --git a/Test/SynChem/Reaction/test_cleanning.py b/Test/SynChem/Reaction/test_cleanning.py new file mode 100644 index 0000000..18c8c00 --- /dev/null +++ b/Test/SynChem/Reaction/test_cleanning.py @@ -0,0 +1,26 @@ +import unittest +from 
synutility.SynChem.Reaction.cleanning import Cleanning + + +class TestCleaning(unittest.TestCase): + + def setUp(self): + self.cleaner = Cleanning() + + def test_remove_duplicates(self): + input_smiles = ["CC>>CC", "CC>>CC"] + expected_output = ["CC>>CC"] + result = self.cleaner.remove_duplicates(input_smiles) + self.assertEqual( + result, expected_output, "Failed to remove duplicates correctly" + ) + + def test_clean_smiles(self): + input_smiles = ["CC>>CC", "CC>>CC", "CC>>CCC"] + expected_output = ["CC>>CC"] # Assuming 'CC>>CCC' is not balanced + result = self.cleaner.clean_smiles(input_smiles) + self.assertEqual(result, expected_output, "Failed to clean SMILES correctly") + + +if __name__ == "__main__": + unittest.main() diff --git a/Test/SynGraph/Descriptor/__init__.py b/Test/SynGraph/Descriptor/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/Test/SynGraph/test_graph_descriptors.py b/Test/SynGraph/Descriptor/test_graph_descriptors.py similarity index 98% rename from Test/SynGraph/test_graph_descriptors.py rename to Test/SynGraph/Descriptor/test_graph_descriptors.py index 0b350e3..0acdc42 100644 --- a/Test/SynGraph/test_graph_descriptors.py +++ b/Test/SynGraph/Descriptor/test_graph_descriptors.py @@ -1,6 +1,6 @@ import unittest import networkx as nx -from synutility.SynGraph.graph_descriptors import GraphDescriptor +from synutility.SynGraph.Descriptor.graph_descriptors import GraphDescriptor class TestGraphDescriptor(unittest.TestCase): diff --git a/Test/SynGraph/Descriptor/test_graph_signature.py b/Test/SynGraph/Descriptor/test_graph_signature.py new file mode 100644 index 0000000..dc6fb59 --- /dev/null +++ b/Test/SynGraph/Descriptor/test_graph_signature.py @@ -0,0 +1,44 @@ +import unittest +from synutility.SynIO.data_type import load_from_pickle +from synutility.SynGraph.Descriptor.graph_signature import GraphSignature + + +class TestGraphSignature(unittest.TestCase): + + def setUp(self): + # Create a sample graph for testing + data = 
load_from_pickle("Data/test.pkl.gz") + self.rc = data[0]["GraphRules"][2] + self.its = data[0]["ITSGraph"][2] + + def test_create_topology_signature(self): + signature = GraphSignature(self.rc) + self.assertEqual(signature.create_topology_signature(), "114") + + def test_create_node_signature(self): + signature = GraphSignature(self.rc) + self.assertEqual(signature.create_node_signature(), "BrCHN") + + def test_create_node_signature_condensed(self): + signature = GraphSignature(self.its) + self.assertEqual(signature.create_node_signature(), "BrC{23}ClHN{3}O{5}S") + + def test_create_edge_signature(self): + signature = GraphSignature(self.rc) + self.assertEqual( + signature.create_edge_signature(), "Br[-1]H/Br[1]C/C[-1]N/H[1]N" + ) + + def test_create_graph_signature(self): + # Ensure the graph signature combines the results correctly + signature = GraphSignature(self.rc) + node_signature = "BrCHN" + edge_signature = "Br[-1]H/Br[1]C/C[-1]N/H[1]N" + topo_signature = "114" + expected = f"{topo_signature}.{node_signature}.{edge_signature}" + self.assertEqual(signature.create_graph_signature(), expected) + + +# Running the tests +if __name__ == "__main__": + unittest.main() diff --git a/Test/SynGraph/Fingerprint/__init__.py b/Test/SynGraph/Fingerprint/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/Test/SynGraph/Fingerprint/test_graph_fps.py b/Test/SynGraph/Fingerprint/test_graph_fps.py new file mode 100644 index 0000000..4a84908 --- /dev/null +++ b/Test/SynGraph/Fingerprint/test_graph_fps.py @@ -0,0 +1,59 @@ +import unittest +import networkx as nx +from synutility.SynGraph.Fingerprint.graph_fps import GraphFP + + +class TestGraphFP(unittest.TestCase): + + def setUp(self): + """Set up a test graph for use in all test cases.""" + self.graph = nx.gnp_random_graph(10, 0.5, seed=42) + self.nBits = 512 + self.hash_alg = "sha256" + self.fp_class = GraphFP( + graph=self.graph, nBits=self.nBits, hash_alg=self.hash_alg + ) + + def test_spectrum_fp(self): + 
"""Test the spectrum-based fingerprint generation.""" + fingerprint = self.fp_class.fingerprint("spectrum") + self.assertEqual(len(fingerprint), self.nBits) + self.assertTrue(isinstance(fingerprint, str)) + + def test_adjacency_fp(self): + """Test the adjacency matrix-based fingerprint generation.""" + fingerprint = self.fp_class.fingerprint("adjacency") + self.assertEqual(len(fingerprint), self.nBits) + self.assertTrue(isinstance(fingerprint, str)) + + def test_degree_sequence_fp(self): + """Test the degree sequence-based fingerprint generation.""" + fingerprint = self.fp_class.fingerprint("degree") + self.assertEqual(len(fingerprint), self.nBits) + self.assertTrue(isinstance(fingerprint, str)) + + def test_motif_count_fp(self): + """Test the motif count-based fingerprint generation.""" + fingerprint = self.fp_class.fingerprint("motif") + self.assertEqual(len(fingerprint), self.nBits) + self.assertTrue(isinstance(fingerprint, str)) + + def test_iterative_deepening(self): + """Test the iterative deepening method.""" + short_fingerprint = "1010101010101010" + remaining_bits = self.nBits - len(short_fingerprint) + extended_fingerprint = self.fp_class.iterative_deepening(remaining_bits) + self.assertEqual(len(extended_fingerprint), remaining_bits) + self.assertTrue(isinstance(extended_fingerprint, str)) + + def test_fingerprint_length(self): + """Test that each method produces a fingerprint of exactly nBits.""" + methods = ["spectrum", "adjacency", "degree", "motif"] + for method in methods: + with self.subTest(method=method): + fingerprint = self.fp_class.fingerprint(method) + self.assertEqual(len(fingerprint), self.nBits) + + +if __name__ == "__main__": + unittest.main() diff --git a/Test/SynGraph/Fingerprint/test_hash_fps.py b/Test/SynGraph/Fingerprint/test_hash_fps.py new file mode 100644 index 0000000..b9bb3a3 --- /dev/null +++ b/Test/SynGraph/Fingerprint/test_hash_fps.py @@ -0,0 +1,52 @@ +import unittest +import networkx as nx +from 
synutility.SynGraph.Fingerprint.hash_fps import HashFPs +from synutility.SynIO.data_type import load_from_pickle + + +class TestHashFPs(unittest.TestCase): + def setUp(self): + """Set up a simple graph for testing.""" + self.graph = nx.cycle_graph(4) # Simple cycle graph with 4 nodes + self.hasher = HashFPs(self.graph, numBits=128, hash_alg="sha256") + + def test_hash_fps_default(self): + """Test the default hash generation without specifying start or end nodes.""" + result = self.hasher.hash_fps() + self.assertEqual(len(result), 128) + self.assertIsInstance(result, str) + self.assertTrue(all(c in "01" for c in result), "Hash must be binary") + + def test_hash_fps_path_specified(self): + """Test hash generation with specified start and end nodes.""" + result = self.hasher.hash_fps(start_node=0, end_node=1) + self.assertEqual(len(result), 128) + self.assertTrue(all(c in "01" for c in result), "Hash must be binary") + + def test_hash_fps_invalid_hash_algorithm(self): + """Test initialization with an invalid hash algorithm.""" + with self.assertRaises(ValueError): + HashFPs(self.graph, numBits=128, hash_alg="invalid256") + + def test_hash_fps_negative_numBits(self): + """Test initialization with negative numBits.""" + with self.assertRaises(ValueError): + HashFPs(self.graph, numBits=-1, hash_alg="sha256") + + def test_hash_fps_large_numBits(self): + """Test hash generation with a large numBits.""" + large_hasher = HashFPs(self.graph, numBits=1024, hash_alg="sha512") + result = large_hasher.hash_fps() + self.assertEqual(len(result), 1024) + self.assertTrue(all(c in "01" for c in result), "Hash must be binary") + + def test_fps_rc(self): + data = load_from_pickle("Data/test.pkl.gz") + graph = data[0]["GraphRules"][2] + hasher = HashFPs(graph, numBits=1024, hash_alg="sha256") + result = hasher.hash_fps() + self.assertEqual(len(result), 1024) + + +if __name__ == "__main__": + unittest.main() diff --git a/Test/SynGraph/Fingerprint/test_morgan_fps.py 
b/Test/SynGraph/Fingerprint/test_morgan_fps.py new file mode 100644 index 0000000..01304ee --- /dev/null +++ b/Test/SynGraph/Fingerprint/test_morgan_fps.py @@ -0,0 +1,39 @@ +import unittest +import networkx as nx +from synutility.SynGraph.Fingerprint.morgan_fps import MorganFPs +from synutility.SynIO.data_type import load_from_pickle + + +class TestMorganFPs(unittest.TestCase): + def setUp(self): + self.graph = nx.cycle_graph(5) # Creates a cycle graph for testing + self.morgan_fps = MorganFPs(self.graph, radius=2, nBits=128, hash_alg="sha256") + + def test_fingerprint_length(self): + """Test that the fingerprint is exactly the specified bit length.""" + fingerprint = self.morgan_fps.generate_fingerprint() + self.assertEqual(len(fingerprint), 128) + + def test_fingerprint_consistency(self): + """Test that the same graph with the same parameters produces the same fingerprint.""" + fingerprint1 = self.morgan_fps.generate_fingerprint() + fingerprint2 = self.morgan_fps.generate_fingerprint() + self.assertEqual(fingerprint1, fingerprint2) + + def test_fingerprint_variation_with_radius(self): + """Test that changing the radius changes the fingerprint.""" + new_morgan_fps = MorganFPs(self.graph, radius=1, nBits=128, hash_alg="sha256") + fingerprint1 = self.morgan_fps.generate_fingerprint() + fingerprint2 = new_morgan_fps.generate_fingerprint() + self.assertNotEqual(fingerprint1, fingerprint2) + + def test_fps_rc(self): + data = load_from_pickle("Data/test.pkl.gz") + graph = data[0]["GraphRules"][2] + hasher = MorganFPs(graph, radius=3, nBits=1024, hash_alg="sha256") + result = hasher.generate_fingerprint() + self.assertEqual(len(result), 1024) + + +if __name__ == "__main__": + unittest.main() diff --git a/Test/SynGraph/Fingerprint/test_path_fps.py b/Test/SynGraph/Fingerprint/test_path_fps.py new file mode 100644 index 0000000..0fc9d8f --- /dev/null +++ b/Test/SynGraph/Fingerprint/test_path_fps.py @@ -0,0 +1,40 @@ +import unittest +import networkx as nx +from 
synutility.SynGraph.Fingerprint.path_fps import PathFPs +from synutility.SynIO.data_type import load_from_pickle + + +class TestPathFPs(unittest.TestCase): + def setUp(self): + self.graph = nx.path_graph(5) # Creates a simple path graph + self.path_fps = PathFPs(self.graph, max_length=3, nBits=64, hash_alg="sha256") + + def test_fingerprint_length(self): + """Test that the fingerprint has the exact length specified by nBits.""" + fingerprint = self.path_fps.generate_fingerprint() + self.assertEqual(len(fingerprint), 64) + + def test_fingerprint_consistency(self): + """Test that the same graph with the same parameters produces the same + fingerprint.""" + fingerprint1 = self.path_fps.generate_fingerprint() + fingerprint2 = self.path_fps.generate_fingerprint() + self.assertEqual(fingerprint1, fingerprint2) + + def test_fingerprint_variation(self): + """Test that changing the parameters changes the fingerprint.""" + new_path_fps = PathFPs(self.graph, max_length=4, nBits=128, hash_alg="sha256") + fingerprint1 = self.path_fps.generate_fingerprint() + fingerprint2 = new_path_fps.generate_fingerprint() + self.assertNotEqual(fingerprint1, fingerprint2) + + def test_fps_rc(self): + data = load_from_pickle("Data/test.pkl.gz") + graph = data[0]["GraphRules"][2] + hasher = PathFPs(graph, max_length=5, nBits=1024, hash_alg="sha256") + result = hasher.generate_fingerprint() + self.assertEqual(len(result), 1024) + + +if __name__ == "__main__": + unittest.main() diff --git a/Test/SynIO/Format/__init__.py b/Test/SynIO/Format/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/Test/SynIO/Format/test_gml_to_nx.py b/Test/SynIO/Format/test_gml_to_nx.py new file mode 100644 index 0000000..b6c3f60 --- /dev/null +++ b/Test/SynIO/Format/test_gml_to_nx.py @@ -0,0 +1,130 @@ +import unittest +from synutility.SynIO.data_type import load_from_pickle +from synutility.SynIO.Format.gml_to_nx import GMLToNX +from synutility.SynIO.Format.nx_to_gml import NXToGML +from 
synutility.SynIO.Format.isomorphism import isomorphism_check + + +class TestGMLToNX(unittest.TestCase): + + def setUp(self) -> None: + data = load_from_pickle("Data/test.pkl.gz")[0] + self.ground_truth_its = data["ITSGraph"] + self.ground_truth_rc = data["GraphRules"] + self.rule_its = NXToGML.transform(self.ground_truth_its) + self.rule_rc = NXToGML.transform(self.ground_truth_rc) + self.parser = GMLToNX(gml_text="") + gml_formatted_str = ( + "rule [\n" + ' ruleID "Test"\n' + " left [\n" + ' edge [ source 11 target 35 label "-" ]\n' + ' edge [ source 28 target 29 label "-" ]\n' + " ]\n" + " context [\n" + ' node [ id 11 label "N" ]\n' + ' node [ id 35 label "H" ]\n' + ' node [ id 28 label "C" ]\n' + ' node [ id 29 label "Br" ]\n' + " ]\n" + " right [\n" + ' edge [ source 11 target 28 label "-" ]\n' + ' edge [ source 35 target 29 label "-" ]\n' + " ]\n" + "]" + ) + self.parser_gml = GMLToNX(gml_formatted_str) + + def test_parse_element(self): + """ + Test the parsing of nodes and edges from the provided GML string. + """ + # Manually parse elements for testing + self.parser_gml._parse_element('node [ id 11 label "N" ]', "context") + self.parser_gml._parse_element('edge [ source 11 target 35 label "-" ]', "left") + + expected_node = ("11", {"element": "N", "charge": 0, "atom_map": 11}) + expected_edge = (11, 35, {"order": 1.0}) + + actual_node = self.parser_gml.graphs["context"].nodes(data=True)[11] + actual_edge = self.parser_gml.graphs["left"][11][35] + + self.assertEqual( + expected_node[1], + actual_node, + "Node attributes do not match expected values.", + ) + self.assertEqual( + expected_edge[2], + actual_edge, + "Edge attributes do not match expected values.", + ) + + def test_synchronize_nodes(self): + """ + Test the synchronization of nodes across different graph sections after parsing. 
+ """ + # Simulate parsing nodes into the context graph + self.parser_gml.graphs["context"].add_node( + 11, element="N", charge=0, atom_map=11 + ) + self.parser_gml.graphs["context"].add_node( + 35, element="H", charge=0, atom_map=35 + ) + # Running synchronization + self.parser_gml._synchronize_nodes() + # Checking if nodes are present in left and right graphs + self.assertIn(11, self.parser_gml.graphs["left"]) + self.assertIn(35, self.parser_gml.graphs["right"]) + + def test_extract_simple_element(self): + """ + Test the extraction of an element without a charge. + """ + element, charge = self.parser._extract_element_and_charge("C") + self.assertEqual(element, "C") + self.assertEqual(charge, 0) + + def test_extract_element_with_positive_charge(self): + """ + Test the extraction of an element with a positive charge. + """ + element, charge = self.parser._extract_element_and_charge("Na+") + self.assertEqual(element, "Na") + self.assertEqual(charge, 1) + + def test_extract_element_with_negative_charge(self): + """ + Test the extraction of an element with a negative charge. + """ + element, charge = self.parser._extract_element_and_charge("Cl-") + self.assertEqual(element, "Cl") + self.assertEqual(charge, -1) + + def test_extract_element_with_multi_digit_charge(self): + """ + Test the extraction of an element with a multiple digit charge. + """ + element, charge = self.parser._extract_element_and_charge("Mg2+") + self.assertEqual(element, "Mg") + self.assertEqual(charge, 2) + + def test_extract_element_with_no_charge_number(self): + """ + Test the extraction where the charge number is implied as 1. 
+ """ + element, charge = self.parser._extract_element_and_charge("K+") + self.assertEqual(element, "K") + self.assertEqual(charge, 1) + + def test_transform(self): + self.graphs_its = GMLToNX(self.rule_its).transform() + self.graphs_rc = GMLToNX(self.rule_rc).transform() + for key, _ in enumerate(self.graphs_its): + self.assertTrue( + isomorphism_check(self.graphs_its[key], self.ground_truth_its[key]) + ) + for key, _ in enumerate(self.graphs_rc): + self.assertTrue( + isomorphism_check(self.graphs_rc[key], self.ground_truth_rc[key]) + ) diff --git a/Test/SynIO/Format/test_graph_to_mol.py b/Test/SynIO/Format/test_graph_to_mol.py new file mode 100644 index 0000000..d7f7479 --- /dev/null +++ b/Test/SynIO/Format/test_graph_to_mol.py @@ -0,0 +1,59 @@ +import unittest +from rdkit import Chem +import networkx as nx +from synutility.SynIO.Format.graph_to_mol import GraphToMol + + +class TestGraphToMol(unittest.TestCase): + def setUp(self): + # Define node and edge attributes mappings + self.node_attributes = {"element": "element", "charge": "charge"} + self.edge_attributes = {"order": "order"} + self.converter = GraphToMol(self.node_attributes, self.edge_attributes) + + def test_simple_molecule_conversion(self): + # Create a simple water molecule graph + graph = nx.Graph() + graph.add_node(0, element="O", charge=0) + graph.add_node(1, element="H", charge=0) + graph.add_node(2, element="H", charge=0) + graph.add_edges_from([(0, 1), (0, 2)], order=1) + + mol = self.converter.graph_to_mol(graph) + smiles = Chem.CanonSmiles(Chem.MolToSmiles(mol)) + self.assertEqual(smiles, "O") + + def test_bond_order_handling(self): + # Create a graph representing ethene (C=C) + graph = nx.Graph() + graph.add_node(0, element="C", charge=0) + graph.add_node(1, element="C", charge=0) + graph.add_edge(0, 1, order=2) + + mol = self.converter.graph_to_mol(graph) + self.assertEqual(Chem.MolToSmiles(mol), "C=C") + + def test_ignore_bond_order(self): + # Create a graph representing ethene (C=C) but 
ignore bond order + graph = nx.Graph() + graph.add_node(0, element="C", charge=0) + graph.add_node(1, element="C", charge=0) + graph.add_edge(0, 1, order=2) + + mol = self.converter.graph_to_mol(graph, ignore_bond_order=True) + self.assertEqual(Chem.MolToSmiles(mol), "CC") + + def test_molecule_with_charges(self): + # Create a graph representing a charged molecule [NH4+] + graph = nx.Graph() + graph.add_node(0, element="N", charge=1) + for i in range(1, 5): + graph.add_node(i, element="H", charge=0) + graph.add_edge(0, i, order=1) + + mol = self.converter.graph_to_mol(graph) + self.assertEqual(Chem.CanonSmiles(Chem.MolToSmiles(mol)), "[NH4+]") + + +if __name__ == "__main__": + unittest.main() diff --git a/Test/SynIO/Format/test_its_construction.py b/Test/SynIO/Format/test_its_construction.py new file mode 100644 index 0000000..0b34ebb --- /dev/null +++ b/Test/SynIO/Format/test_its_construction.py @@ -0,0 +1,51 @@ +import unittest +import networkx as nx +from synutility.SynIO.Format.its_construction import ITSConstruction + + +class TestITSConstruction(unittest.TestCase): + + def setUp(self): + # Create test graphs G and H with predefined attributes and edges + + # Ethylen C=C + self.G = nx.Graph() + self.G.add_node(1, element="C", aromatic=False, hcount=2, charge=0) + self.G.add_node(2, element="C", aromatic=False, hcount=2, charge=0) + self.G.add_edge(1, 2, order=2) + + # Ethan C-C + self.H = nx.Graph() + self.H.add_node(1, element="C", aromatic=False, hcount=3, charge=0) + self.H.add_node(2, element="C", aromatic=False, hcount=3, charge=0) + self.H.add_edge(1, 2, order=1) # Different order + + def test_ITSGraph(self): + ITS = ITSConstruction.ITSGraph(self.G, self.H) + self.assertTrue(isinstance(ITS, nx.Graph)) + self.assertEqual(len(ITS.nodes()), 2) + self.assertEqual(len(ITS.edges()), 1) + self.assertEqual(ITS[1][2]["order"], (2, 1)) + + def test_get_node_attributes_with_defaults(self): + attributes = ITSConstruction.get_node_attributes_with_defaults(self.G, 1) 
+ self.assertEqual(attributes, ("C", False, 2, 0, ["", ""])) + + def test_add_edges_to_ITS(self): + ITS = nx.Graph() + ITS.add_node(1, element="C", aromatic=False, hcount=3, charge=0) + ITS.add_node(2, element="C", aromatic=False, hcount=3, charge=0) + new_ITS = ITSConstruction.add_edges_to_ITS(ITS, self.G, self.H) + self.assertTrue(isinstance(new_ITS, nx.Graph)) + self.assertEqual(len(new_ITS.edges()), 1) + self.assertEqual(new_ITS[1][2]["order"], (2, 1)) + + def test_add_standard_order_attribute(self): + graph = nx.Graph() + graph.add_edge(1, 2, order=(1, 2)) + updated_graph = ITSConstruction.add_standard_order_attribute(graph) + self.assertEqual(updated_graph[1][2]["standard_order"], -1) + + +if __name__ == "__main__": + unittest.main() diff --git a/Test/SynIO/Format/test_mol_to_graph.py b/Test/SynIO/Format/test_mol_to_graph.py new file mode 100644 index 0000000..8dd765a --- /dev/null +++ b/Test/SynIO/Format/test_mol_to_graph.py @@ -0,0 +1,49 @@ +import unittest +from rdkit import Chem +import networkx as nx +from synutility.SynIO.Format.mol_to_graph import MolToGraph + + +class TestMolToGraph(unittest.TestCase): + def setUp(self): + self.converter = MolToGraph() + # Example molecule: Ethanol + self.ethanol_smiles = "CCO" + self.mol = Chem.MolFromSmiles(self.ethanol_smiles) + + def test_add_partial_charges(self): + MolToGraph.add_partial_charges(self.mol) + for atom in self.mol.GetAtoms(): + self.assertTrue(atom.HasProp("_GasteigerCharge")) + + def test_get_stereochemistry(self): + # Test with chiral molecule + chiral_smiles = "CC[C@@H](C)O" + chiral_mol = Chem.MolFromSmiles(chiral_smiles) + chiral_atom = chiral_mol.GetAtomWithIdx(2) # The chiral carbon + stereo = MolToGraph.get_stereochemistry(chiral_atom) + self.assertIn(stereo, ["R", "S"]) + + def test_get_bond_stereochemistry(self): + # Test with E-stilbene + e_stilbene_smiles = "C/C=C/C" + e_stilbene_mol = Chem.MolFromSmiles(e_stilbene_smiles) + double_bond = e_stilbene_mol.GetBondWithIdx(1) # The double 
bond + stereo = MolToGraph.get_bond_stereochemistry(double_bond) + self.assertIn(stereo, ["E", "Z", "N"]) + + def test_mol_to_graph(self): + graph = self.converter.mol_to_graph(self.mol) + self.assertIsInstance(graph, nx.Graph) + # Check for expected number of nodes and edges + self.assertEqual(len(graph.nodes), self.mol.GetNumAtoms()) + self.assertEqual(len(graph.edges), self.mol.GetNumBonds()) + # Check attributes of an arbitrary atom and bond + some_atom = list(graph.nodes(data=True))[0] + some_edge = list(graph.edges(data=True))[0] + self.assertIn("element", some_atom[1]) + self.assertIn("order", some_edge[2]) + + +if __name__ == "__main__": + unittest.main() diff --git a/Test/SynIO/Format/test_nx_to_gml.py b/Test/SynIO/Format/test_nx_to_gml.py new file mode 100644 index 0000000..9095836 --- /dev/null +++ b/Test/SynIO/Format/test_nx_to_gml.py @@ -0,0 +1,99 @@ +import unittest +import networkx as nx +from synutility.SynIO.data_type import load_from_pickle +from synutility.SynIO.Format.nx_to_gml import NXToGML + + +class TestRuleWriting(unittest.TestCase): + + def setUp(self) -> None: + self.data = load_from_pickle("Data/test.pkl.gz")[0] + + def test_charge_to_string(self): + self.assertEqual(NXToGML._charge_to_string(3), "3+") + self.assertEqual(NXToGML._charge_to_string(-2), "2-") + self.assertEqual(NXToGML._charge_to_string(0), "") + + def test_find_changed_nodes(self): + G1 = nx.Graph() + G1.add_node(1, element="C", charge=0) + G2 = nx.Graph() + G2.add_node(1, element="C", charge=1) + changed_nodes = NXToGML._find_changed_nodes(G1, G2, ["charge"]) + self.assertEqual(changed_nodes, [1]) + + def test_convert_graph_to_gml_context(self): + G = nx.Graph() + G.add_node(1, element="C") + G.add_node(2, element="H") + changed_node_ids = [2] + gml_str = NXToGML._convert_graph_to_gml(G, "context", changed_node_ids) + expected_str = ' context [\n node [ id 1 label "C" ]\n ]\n' + self.assertEqual(gml_str, expected_str) + + def test_convert_graph_to_gml_left_right(self): + 
G = nx.Graph() + G.add_node(1, element="C", charge=1) + G.add_node(2, element="H", charge=0) + G.add_edge(1, 2, order=2) + changed_node_ids = [1] + gml_str = NXToGML._convert_graph_to_gml(G, "left", changed_node_ids) + expected_str = ( + ' left [\n edge [ source 1 target 2 label "=" ]' + + '\n node [ id 1 label "C+" ]\n ]\n' + ) + self.assertEqual(gml_str, expected_str) + + def test_rules_grammar(self): + L, R, K = self.data["GraphRules"] + changed_node_ids = NXToGML._find_changed_nodes(L, R, ["charge"]) + rule_name = "test_rule" + gml_str = NXToGML._rule_grammar(L, R, K, rule_name, changed_node_ids) + expected_str = ( + "rule [\n" + ' ruleID "test_rule"\n' + " left [\n" + ' edge [ source 11 target 35 label "-" ]\n' + ' edge [ source 28 target 29 label "-" ]\n' + " ]\n" + " context [\n" + ' node [ id 11 label "N" ]\n' + ' node [ id 35 label "H" ]\n' + ' node [ id 28 label "C" ]\n' + ' node [ id 29 label "Br" ]\n' + " ]\n" + " right [\n" + ' edge [ source 11 target 28 label "-" ]\n' + ' edge [ source 35 target 29 label "-" ]\n' + " ]\n" + "]" + ) + self.assertEqual(gml_str, expected_str) + + def test_transform(self): + graph_rules = self.data["GraphRules"] + gml_str = NXToGML.transform(graph_rules, rule_name="test_rule", reindex=True) + expected_str = ( + "rule [\n" + ' ruleID "test_rule"\n' + " left [\n" + ' edge [ source 1 target 2 label "-" ]\n' + ' edge [ source 3 target 4 label "-" ]\n' + " ]\n" + " context [\n" + ' node [ id 1 label "N" ]\n' + ' node [ id 2 label "H" ]\n' + ' node [ id 3 label "C" ]\n' + ' node [ id 4 label "Br" ]\n' + " ]\n" + " right [\n" + ' edge [ source 1 target 3 label "-" ]\n' + ' edge [ source 2 target 4 label "-" ]\n' + " ]\n" + "]" + ) + self.assertEqual(gml_str, expected_str) + + +if __name__ == "__main__": + unittest.main() diff --git a/Test/SynIO/__init__.py b/Test/SynIO/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pyproject.toml b/pyproject.toml index 79853e9..2557ff3 100644 --- a/pyproject.toml +++ 
b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "synutility" -version = "0.0.8" +version = "0.0.9" authors = [ {name="Tieu Long Phan", email="tieu@bioinf.uni-leipzig.de"} ] @@ -20,12 +20,12 @@ dependencies = [ "scikit-learn>=1.4.0", "pandas>=1.5.3", "rdkit>=2024.3.3", - "networkx==3.3", - "seaborn==0.13.2", + "networkx>=3.3", + "seaborn>=0.13.2", ] [project.optional-dependencies] -all = ["drfp==0.3.6", "xgboost==2.1.1", "fgutils>=0.1.3", "rxn-chem-utils==1.5.0", "rxn-utils==2.0.0", "rxnmapper==0.3.0"] +all = ["drfp==0.3.6", "xgboost>=2.1.1", "fgutils>=0.1.3", "rxn-chem-utils==1.5.0", "rxn-utils==2.0.0", "rxnmapper==0.3.0"] [project.urls] homepage = "https://github.com/TieuLongPhan/SynUtils" diff --git a/requirements.txt b/requirements.txt index e9ad7ad..f4e805f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,4 +6,5 @@ drfp==0.3.6 fgutils>=0.1.3 rxn-chem-utils==1.5.0 rxn-utils==2.0.0 -rxnmapper==0.3.0 \ No newline at end of file +rxnmapper==0.3.0 +rdkit >= 2024.3.3 \ No newline at end of file diff --git a/synutility/SynChem/Molecule/__init__.py b/synutility/SynChem/Molecule/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/synutility/SynChem/Molecule/standardize.py b/synutility/SynChem/Molecule/standardize.py new file mode 100644 index 0000000..16569d7 --- /dev/null +++ b/synutility/SynChem/Molecule/standardize.py @@ -0,0 +1,137 @@ +from rdkit import Chem +from rdkit.Chem import rdmolops +from rdkit.Chem.MolStandardize import rdMolStandardize +from rdkit.Chem.SaltRemover import SaltRemover + + +def normalize_molecule(mol: Chem.Mol) -> Chem.Mol: + """ + Normalize a molecule using RDKit's Normalizer. + + Parameters: + - mol (Chem.Mol): RDKit Mol object to be normalized. + + Returns: + - Chem.Mol: Normalized RDKit Mol object. 
+ """ + normalizer = rdMolStandardize.Normalizer() + return normalizer.normalize(mol) + + +def canonicalize_tautomer(mol: Chem.Mol) -> Chem.Mol: + """ + Canonicalize the tautomer of a molecule using RDKit's TautomerCanonicalizer. + + Parameters: + - mol (Chem.Mol): RDKit Mol object. + + Returns: + - Chem.Mol: Mol object with canonicalized tautomer. + """ + tautomer_canonicalizer = rdMolStandardize.TautomerEnumerator() + return tautomer_canonicalizer.Canonicalize(mol) + + +def salts_remover(mol: Chem.Mol) -> Chem.Mol: + """ + Remove salt fragments from a molecule using RDKit's SaltRemover. + + Parameters: + - mol (Chem.Mol): RDKit Mol object. + + Returns: + - Chem.Mol: Mol object with salts removed. + """ + remover = SaltRemover() + return remover.StripMol(mol) + + +def uncharge_molecule(mol: Chem.Mol) -> Chem.Mol: + """ + Neutralize a molecule by removing counter-ions using RDKit's Uncharger. + + Parameters: + - mol (Chem.Mol): RDKit Mol object. + + Returns: + - Chem.Mol: Neutralized Mol object. + """ + uncharger = rdMolStandardize.Uncharger() + return uncharger.uncharge(mol) + + +def fragments_remover(mol: Chem.Mol) -> Chem.Mol: + """ + Remove small fragments from a molecule, keeping only the largest one. + + Parameters: + - mol (Chem.Mol): RDKit Mol object. + + Returns: + - Chem.Mol: Mol object with small fragments removed. + """ + frags = Chem.GetMolFrags(mol, asMols=True, sanitizeFrags=True) + return max(frags, default=None, key=lambda m: m.GetNumAtoms()) + + +def remove_explicit_hydrogens(mol: Chem.Mol) -> Chem.Mol: + """ + Remove explicit hydrogens from a molecule to leave only the heavy atoms. + + Parameters: + - mol (Chem.Mol): RDKit Mol object. + + Returns: + - Chem.Mol: Mol object with explicit hydrogens removed. + """ + return Chem.RemoveHs(mol) + + +def remove_radicals_and_add_hydrogens(mol: Chem.Mol) -> Chem.Mol: + """ + Remove radicals from a molecule by setting radical electrons to zero and adding hydrogens where needed. 
+ + Parameters: + - mol (Chem.Mol): RDKit Mol object. + + Returns: + - Chem.Mol: Mol object with radicals removed and necessary hydrogens added. + """ + mol = Chem.RemoveHs(mol) # Remove explicit hydrogens first + for atom in mol.GetAtoms(): + if atom.GetNumRadicalElectrons() > 0: + atom.SetNumExplicitHs( + atom.GetNumExplicitHs() + atom.GetNumRadicalElectrons() + ) + atom.SetNumRadicalElectrons(0) + mol = rdmolops.AddHs(mol) # Add hydrogens back + return remove_explicit_hydrogens(mol) + + +def remove_isotopes(mol: Chem.Mol) -> Chem.Mol: + """ + Remove isotopic information from a molecule. + + Parameters: + - mol (Chem.Mol): RDKit Mol object. + + Returns: + - Chem.Mol: Mol object with isotopes removed. + """ + for atom in mol.GetAtoms(): + atom.SetIsotope(0) + return mol + + +def clear_stereochemistry(mol: Chem.Mol) -> Chem.Mol: + """ + Clear all stereochemical information from a molecule. + + Parameters: + - mol (Chem.Mol): RDKit Mol object. + + Returns: + - Chem.Mol: Mol object with stereochemistry cleared. + """ + Chem.RemoveStereochemistry(mol) + return mol diff --git a/synutility/SynChem/Reaction/cleanning.py b/synutility/SynChem/Reaction/cleanning.py new file mode 100644 index 0000000..5e685a8 --- /dev/null +++ b/synutility/SynChem/Reaction/cleanning.py @@ -0,0 +1,59 @@ +from typing import List +from synutility.SynChem.Reaction.balance_check import BalanceReactionCheck +from synutility.SynChem.Reaction.standardize import Standardize + + +class Cleanning: + def __init__(self) -> None: + pass + + @staticmethod + def remove_duplicates(smiles_list: List[str]) -> List[str]: + """ + Removes duplicate SMILES strings from a list, maintaining the order of + first occurrences. Uses a set to track seen SMILES for efficiency. + + Parameters: + - smiles_list (List[str]): A list of SMILES strings representing + chemical reactions. + + Returns: + - List[str]: A list with unique SMILES strings, preserving the original order. 
+ """ + seen = set() + unique_smiles = [ + smiles for smiles in smiles_list if not (smiles in seen or seen.add(smiles)) + ] + return unique_smiles + + @staticmethod + def clean_smiles(smiles_list: List[str]) -> List[str]: + """ + Cleans a list of SMILES strings by standardizing them, checking their chemical + balance, and removing duplicates. Each SMILES is first checked for validity and + then standardized. Only balanced reactions are kept. + + Parameters: + - smiles_list (List[str]): A list of SMILES strings representing chemical reactions. + + Returns: + - List[str]: A list of cleaned and standardized SMILES strings. + """ + # Standardize and check balance in separate list comprehensions + standardizer = Standardize() + balance_checker = BalanceReactionCheck() + + standardized_smiles = [ + standardizer.standardize_rsmi(smiles, True) + for smiles in smiles_list + if smiles + ] + balanced_smiles = [ + smiles + for smiles in standardized_smiles + if balance_checker.rsmi_balance_check(smiles) + ] + + # Remove duplicates from the balanced SMILES list + clean_smiles = Cleanning.remove_duplicates(balanced_smiles) + return clean_smiles diff --git a/synutility/SynChem/Reaction/standardize.py b/synutility/SynChem/Reaction/standardize.py index e9ec452..2a48090 100644 --- a/synutility/SynChem/Reaction/standardize.py +++ b/synutility/SynChem/Reaction/standardize.py @@ -124,4 +124,5 @@ def fit( rsmi = self.remove_atom_mapping(rsmi) rsmi = self.standardize_rsmi(rsmi, not ignore_stereo) + rsmi = rsmi.replace("[HH]", "[H][H]") return rsmi diff --git a/synutility/SynGraph/Descriptor/__init__.py b/synutility/SynGraph/Descriptor/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/synutility/SynGraph/graph_descriptors.py b/synutility/SynGraph/Descriptor/graph_descriptors.py similarity index 99% rename from synutility/SynGraph/graph_descriptors.py rename to synutility/SynGraph/Descriptor/graph_descriptors.py index 5e0a871..f34cab0 100644 --- 
a/synutility/SynGraph/graph_descriptors.py +++ b/synutility/SynGraph/Descriptor/graph_descriptors.py @@ -7,8 +7,8 @@ class GraphDescriptor: - def __init__(self, graph: nx.Graph): - self.graph = graph + def __init__(self) -> None: + pass @staticmethod def is_graph_empty(graph: Union[nx.Graph, dict, list, Any]) -> bool: diff --git a/synutility/SynGraph/Descriptor/graph_signature.py b/synutility/SynGraph/Descriptor/graph_signature.py new file mode 100644 index 0000000..68a0173 --- /dev/null +++ b/synutility/SynGraph/Descriptor/graph_signature.py @@ -0,0 +1,128 @@ +import networkx as nx +from collections import Counter +from synutility.SynGraph.Descriptor.graph_descriptors import GraphDescriptor + + +class GraphSignature: + """ + Provides methods to generate canonical signatures for graph nodes, edges, and complete graphs, + useful for comparisons or identification in graph-based data structures. + + Attributes: + graph (nx.Graph): The graph for which signatures will be generated. + """ + + def __init__(self, graph: nx.Graph): + """ + Initializes the GraphSignature class with a specified graph. + + Parameters: + - graph (nx.Graph): A NetworkX graph instance. + """ + self.graph = graph + + def create_node_signature(self, condensed: bool = True) -> str: + """ + Generates a canonical node signature. If `condensed` is True, it condenses + consecutive occurrences of elements, formatting like 'Br{1}C{10}'. + + Parameters: + - condensed (bool): If True, condenses elements with counts. If False, keeps the original format. + + Returns: + - str: A concatenated string of sorted node elements, optionally with counts. 
+ """ + # Sort elements + elements = sorted(data["element"] for node, data in self.graph.nodes(data=True)) + + if condensed: + # Count occurrences and format with counts + element_counts = Counter(elements) + signature_parts = [] + for element, count in element_counts.items(): + if count > 1: + signature_parts.append(f"{element}{{{count}}}") + else: + signature_parts.append(element) + return "".join(signature_parts) + else: + # Return the original, uncompressed format + return "".join(elements) + + def create_edge_signature(self) -> str: + """ + Generates a canonical edge signature by formatting each edge with sorted node elements and a bond order, + separated by '/', with each edge represented as 'node1[standard_order]node2'. + + Returns: + - str: A concatenated and sorted string of edge representations. + """ + edge_signature_parts = [] + for u, v, data in self.graph.edges(data=True): + standard_order = int( + data.get("standard_order", 1) + ) # Default to 1 if missing + node1, node2 = sorted( + [self.graph.nodes[u]["element"], self.graph.nodes[v]["element"]] + ) + part = f"{node1}[{standard_order}]{node2}" + edge_signature_parts.append(part) + return "/".join(sorted(edge_signature_parts)) + + def create_topology_signature(self) -> str: + """ + Generates a topology signature for the graph based on its cyclic properties and structure. + The topology is classified and quantified by identifying cycles and other structural features. + + Returns: + - str: A string representing the numerical and qualitative topology signature of the graph. 
+ """ + des = GraphDescriptor() + topo = des.check_graph_type(self.graph) + cycle = des.get_cycle_member_rings(self.graph) + + topo_mapping = { + "Acyclic": 0, + "Single Cyclic": 1, + "Combinatorial Cyclic": 2, + "Complex Cyclic": 3, + } + + topo_code = topo_mapping.get(topo, 4) + + if topo_code == 0: + cycle = [0] # Represent acyclic graph with no cycles + elif topo_code == 3: + cycle = [0] + cycle # Add complexity prefix for complex cyclic graphs + + rstep = len(cycle) + cycle_str = "".join(map(str, cycle)) + return f"{rstep}{topo_code}{cycle_str}" + + def create_graph_signature( + self, + condensed: bool = True, + topology: bool = True, + nodes: bool = True, + edges: bool = True, + ) -> str: + """ + Combines node, edge, and topology signatures into a single comprehensive graph signature. + + Returns: + - str: A concatenated string representing the complete graph signature formatted as + 'topology_signature.node_signature.edge_signature'. + """ + if topology: + topo_signature = self.create_topology_signature() + else: + topo_signature = "" + if nodes: + node_signature = self.create_node_signature(condensed) + else: + node_signature = "" + if edges: + edge_signature = self.create_edge_signature() + else: + edge_signature = "" + return f"{topo_signature}.{node_signature}.{edge_signature}" diff --git a/synutility/SynGraph/Fingerprint/__init__.py b/synutility/SynGraph/Fingerprint/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/synutility/SynGraph/Fingerprint/graph_fps.py b/synutility/SynGraph/Fingerprint/graph_fps.py new file mode 100644 index 0000000..58b1214 --- /dev/null +++ b/synutility/SynGraph/Fingerprint/graph_fps.py @@ -0,0 +1,97 @@ +import networkx as nx +import hashlib +import numpy as np + + +class GraphFP: + def __init__( + self, graph: nx.Graph, nBits: int = 1024, hash_alg: str = "sha256" + ) -> None: + """ + Initialize the GraphFP class to create binary fingerprints based on various graph + characteristics. 
+ + Parameters: + - graph (nx.Graph): Graph on which to perform analysis. + - nBits (int): Size of the binary fingerprint in bits. + - hash_alg (str): Cryptographic hash function used for hashing. + """ + self.graph = graph + self.nBits = nBits + self.hash_alg = hash_alg + self.hash_function = getattr(hashlib, self.hash_alg) + + def fingerprint(self, method: str) -> str: + """ + Generate a binary string fingerprint of the graph using the specified method. + + Parameters: + - method (str): The method to use for fingerprinting + ('spectrum', 'adjacency', 'degree', 'motif') + + Returns: + - str: A binary string of length `nBits` that represents the fingerprint of + the graph. + """ + if method == "spectrum": + fp = self._spectrum_fp() + elif method == "adjacency": + fp = self._adjacency_fp() + elif method == "degree": + fp = self._degree_sequence_fp() + elif method == "motif": + fp = self._motif_count_fp() + else: + raise ValueError("Unsupported fingerprinting method.") + + # If the fingerprint is shorter than nBits, use iterative deepening + if len(fp) < self.nBits: + fp += self.iterative_deepening(self.nBits - len(fp)) + + return fp[: self.nBits] + + def _spectrum_fp(self) -> str: + # Graph spectrum (eigenvalues of the adjacency matrix) + eigenvalues = np.linalg.eigvals(nx.adjacency_matrix(self.graph).todense()) + sorted_eigenvalues = np.sort(eigenvalues)[: self.nBits] + eigen_str = "".join( + bin(int(abs(eig)))[2:].zfill(8) for eig in sorted_eigenvalues + ) + return eigen_str[: self.nBits] + + def _adjacency_fp(self) -> str: + # Adjacency matrix flattened + adj_matrix = nx.adjacency_matrix(self.graph).todense().flatten() + adj_str = "".join(str(int(x)) for x in adj_matrix) + return adj_str[: self.nBits] + + def _degree_sequence_fp(self) -> str: + # Degree sequence + degrees = sorted([d for n, d in self.graph.degree()], reverse=True) + degree_str = "".join(bin(d)[2:].zfill(8) for d in degrees) + return degree_str[: self.nBits] + + def _motif_count_fp(self) -> str: + 
# Motif counts (e.g., number of triangles) + triangles = sum(nx.triangles(self.graph).values()) // 3 + triangle_str = bin(triangles)[2:].zfill(self.nBits) + return triangle_str[: self.nBits] + + def iterative_deepening(self, remaining_bits: int) -> str: + """ + Extend the hash length using iterative hashing until the desired bit length is + achieved. + + Parameters: + - remaining_bits (int): Number of bits needed to complete the fingerprint + to `nBits`. + + Returns: + - str: Additional binary data to achieve the desired hash length. + """ + additional_data = "" + hash_obj = self.hash_function() + while len(additional_data) * 4 < remaining_bits: + hash_obj.update(additional_data.encode()) + additional_data += hash_obj.hexdigest() + return bin(int(additional_data, 16))[2:][:remaining_bits] diff --git a/synutility/SynGraph/Fingerprint/hash_fps.py b/synutility/SynGraph/Fingerprint/hash_fps.py new file mode 100644 index 0000000..3febabb --- /dev/null +++ b/synutility/SynGraph/Fingerprint/hash_fps.py @@ -0,0 +1,130 @@ +import networkx as nx +import hashlib +from typing import Optional, Any + + +class HashFPs: + def __init__( + self, graph: nx.Graph, numBits: int = 256, hash_alg: str = "sha256" + ) -> None: + """ + Initialize the HashFPs class with a graph and configuration settings. + + Parameters: + - graph (nx.Graph): The graph to be fingerprinted. + - numBits (int): Number of bits in the output binary hash. Default is 256 bits. + - hash_alg (str): The hash algorithm to use, such as 'sha256' or 'sha512'. + + Raises: + - ValueError: If `numBits` is non-positive or if `hash_alg` is not supported + by hashlib. 
+ """ + self.graph = graph + self.numBits = numBits + self.hash_alg = hash_alg + self.validate_parameters() + + def validate_parameters(self) -> None: + """Validate the initial parameters for errors.""" + if self.numBits <= 0: + raise ValueError("Number of bits must be positive") + if not hasattr(hashlib, self.hash_alg): + raise ValueError(f"Unsupported hash algorithm: {self.hash_alg}") + + def hash_fps( + self, + start_node: Optional[int] = None, + end_node: Optional[int] = None, + max_path_length: Optional[int] = None, + ) -> str: + """ + Generate a binary hash fingerprint of the graph based on its paths and cycles. + + Parameters: + - start_node (Optional[int]): The starting node index for path detection. + - end_node (Optional[int]): The ending node index for path detection. + - max_path_length (Optional[int]): The maximum length for paths to be considered. + + Returns: + - str: A binary string representing the truncated hash of the graph's structural + features. + """ + hash_object = self.initialize_hash() + features = self.extract_features(start_node, end_node, max_path_length) + full_hash_binary = self.finalize_hash(hash_object, features) + return full_hash_binary + + def initialize_hash(self) -> Any: + """Initialize and return the hash object based on the specified algorithm.""" + return getattr(hashlib, self.hash_alg)() + + def extract_features( + self, + start_node: Optional[int], + end_node: Optional[int], + max_path_length: Optional[int], + ) -> str: + """ + Extract features from the graph based on paths and cycles. + + Parameters: + - start_node (Optional[int]): The starting node for path detection. + - end_node (Optional[int]): The ending node for path detection. + - max_path_length (Optional[int]): Cutoff for path length during detection. + + Returns: + - str: A string of concatenated feature values. 
+ """ + cycles = list(nx.simple_cycles(self.graph)) + paths = [] + if start_node is not None and end_node is not None: + paths = list( + nx.all_simple_paths( + self.graph, + source=start_node, + target=end_node, + cutoff=max_path_length, + ) + ) + features = [len(c) for c in cycles] + [len(p) for p in paths] + return "".join(map(str, features)) + + def finalize_hash(self, hash_object: Any, features: str) -> str: + """ + Finalize the hash using the features extracted and return the hash as a binary + string. + + Parameters: + - hash_object (Any): The hash object. + - features (str): Concatenated string of graph features. + + Returns: + - str: The final binary string of the hash, truncated or extended to `numBits`. + """ + hash_object.update(features.encode()) + full_hash_binary = bin(int(hash_object.hexdigest(), 16))[2:] + if len(full_hash_binary) < self.numBits: + full_hash_binary += self.iterative_deepening( + hash_object, self.numBits - len(full_hash_binary) + ) + return full_hash_binary[: self.numBits] + + def iterative_deepening(self, hash_object: Any, remaining_bits: int) -> str: + """ + Extend hash length using iterative hashing until the desired bit length is + achieved. + + Parameters: + - hash_object (hashlib._Hash): The hash object for iterative deepening. + - remaining_bits (int): Number of bits needed to reach `numBits`. + + Returns: + - str: Additional binary data to achieve the desired hash length. 
+ """ + additional_data = "" + while ( + len(additional_data) * 4 < remaining_bits + ): # Each hex digit represents 4 bits + hash_object.update(additional_data.encode()) + additional_data += hash_object.hexdigest() + return bin(int(additional_data, 16))[2:][:remaining_bits] diff --git a/synutility/SynGraph/Fingerprint/morgan_fps.py b/synutility/SynGraph/Fingerprint/morgan_fps.py new file mode 100644 index 0000000..4bef1fa --- /dev/null +++ b/synutility/SynGraph/Fingerprint/morgan_fps.py @@ -0,0 +1,87 @@ +import networkx as nx +import hashlib +from typing import Any + + +class MorganFPs: + def __init__( + self, + graph: nx.Graph, + radius: int = 3, + nBits: int = 1024, + hash_alg: str = "sha256", + ): + """ + Initialize the MorganFPs class to generate fingerprints based on the Morgan + algorithm, approximating Extended Connectivity Fingerprints (ECFPs). + + Parameters: + - graph (nx.Graph): The graph to analyze. + - radius (int): The radius to consider for node neighborhood analysis. + - nBits (int): Total number of bits in the final fingerprint output. + - hash_alg (str): Hash algorithm to use for generating hashes of node + neighborhoods. + """ + self.graph = graph + self.radius = radius + self.nBits = nBits + self.hash_alg = hash_alg + self.hash_function = getattr(hashlib, self.hash_alg) + + def generate_fingerprint(self) -> str: + """ + Generate a binary string fingerprint of the graph based on the local environments + of nodes. Ensures the output is exactly `nBits` in length using iterative + deepening if necessary. + + Returns: + - str: A binary string of length `nBits` representing the fingerprint of the + graph. 
+ """ + fingerprint = "" + for node in self.graph.nodes(): + neighborhood = nx.single_source_shortest_path_length( + self.graph, node, cutoff=self.radius + ) + neighborhood_str = "-".join( + [ + f"{nbr}-{dist}" + for nbr, dist in sorted(neighborhood.items()) + if nbr != node + ] + ) + hash_obj = self.hash_function(neighborhood_str.encode()) + node_hash = bin(int(hash_obj.hexdigest(), 16))[2:].zfill( + hash_obj.digest_size * 8 + ) + if len(fingerprint) + len(node_hash) > self.nBits: + needed_bits = self.nBits - len(fingerprint) + node_hash = node_hash[:needed_bits] + fingerprint += node_hash + if len(fingerprint) == self.nBits: + return fingerprint + + if len(fingerprint) < self.nBits: + fingerprint += self.iterative_deepening( + hash_obj, self.nBits - len(fingerprint) + ) + return fingerprint + + def iterative_deepening(self, hash_object: Any, remaining_bits: int) -> str: + """ + Extend the hash length using iterative hashing until the desired bit length is + achieved. + + Parameters: + - hash_object (hashlib._Hash): The hash object used for iterative deepening. + - remaining_bits (int): Number of bits needed to complete the fingerprint to + `nBits`. + + Returns: + - str: Additional binary data to achieve the desired hash length. 
+ """ + additional_data = "" + while len(additional_data) * 4 < remaining_bits: + hash_object.update(additional_data.encode()) + additional_data += hash_object.hexdigest() + return bin(int(additional_data, 16))[2:][:remaining_bits] diff --git a/synutility/SynGraph/Fingerprint/path_fps.py b/synutility/SynGraph/Fingerprint/path_fps.py new file mode 100644 index 0000000..b5e4e92 --- /dev/null +++ b/synutility/SynGraph/Fingerprint/path_fps.py @@ -0,0 +1,82 @@ +import networkx as nx +import hashlib +from typing import Any + + +class PathFPs: + def __init__( + self, + graph: nx.Graph, + max_length: int = 10, + nBits: int = 1024, + hash_alg: str = "sha256", + ) -> None: + """ + Initialize the PathFPs class to create a binary fingerprint based on paths in a + graph. + + Parameters: + - graph (nx.Graph): Graph on which to perform analysis. + - max_length (int): Limit on path lengths considered in the fingerprint. + - nBits (int): Size of the binary fingerprint in bits. + - hash_alg (str): Cryptographic hash function used for path hashing. + - hash_function (Callable): Hash function initialized from hashlib. + """ + self.graph = graph + self.max_length = max_length + self.nBits = nBits + self.hash_alg = hash_alg + self.hash_function = getattr(hashlib, self.hash_alg) + + def generate_fingerprint(self) -> str: + """ + Generate a binary string fingerprint of the graph by hashing paths up to a certain + length and combining them. + + Returns: + - str: A binary string of length `nBits` that represents the fingerprint of the + graph. 
+ """ + fingerprint = "" + for node in self.graph.nodes(): + for target in self.graph.nodes(): + if node != target: + for path in nx.all_simple_paths( + self.graph, source=node, target=target, cutoff=self.max_length + ): + path_str = "-".join(map(str, path)) + hash_obj = self.hash_function(path_str.encode()) + path_hash = bin(int(hash_obj.hexdigest(), 16))[2:].zfill( + hash_obj.digest_size * 8 + ) + if len(fingerprint) + len(path_hash) > self.nBits: + needed_bits = self.nBits - len(fingerprint) + path_hash = path_hash[:needed_bits] + fingerprint += path_hash + if len(fingerprint) == self.nBits: + return fingerprint + + if len(fingerprint) < self.nBits: + fingerprint += self.iterative_deepening( + hash_obj, self.nBits - len(fingerprint) + ) + return fingerprint + + def iterative_deepening(self, hash_object: Any, remaining_bits: int) -> str: + """ + Extend the hash length using iterative hashing until the desired bit length is + achieved. + + Parameters: + - hash_object (hashlib._Hash): The hash object used for iterative deepening. + - remaining_bits (int): Number of bits needed to complete the fingerprint + to `nBits`. + + Returns: + - str: Additional binary data to achieve the desired hash length. 
+ """ + additional_data = "" + while len(additional_data) * 4 < remaining_bits: + hash_object.update(additional_data.encode()) + additional_data += hash_object.hexdigest() + return bin(int(additional_data, 16))[2:][:remaining_bits] diff --git a/synutility/SynIO/Format/__init__.py b/synutility/SynIO/Format/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/synutility/SynIO/Format/dg_to_gml.py b/synutility/SynIO/Format/dg_to_gml.py new file mode 100644 index 0000000..b2d079d --- /dev/null +++ b/synutility/SynIO/Format/dg_to_gml.py @@ -0,0 +1,122 @@ +import regex +from synutility.SynChem.Reaction.standardize import Standardize +from synutility.SynIO.debug import setup_logging +from mod import DGVertexMapper, smiles, Rule + +logger = setup_logging() + + +class DGToGML: + def __init__(self) -> None: + self.standardizer = Standardize() + pass + + @staticmethod + def getReactionSmiles(dg): + origSmiles = {} + for v in dg.vertices: + s = v.graph.smilesWithIds + s = regex.sub(":([0-9]+)]", ":o\\1]", s) + origSmiles[v.graph] = s + + res = {} + for e in dg.edges: + vms = DGVertexMapper(e, rightLimit=1, leftLimit=1) + eductSmiles = [origSmiles[g] for g in vms.left] + + for ev in vms.left.vertices: + s = eductSmiles[ev.graphIndex] + s = s.replace(f":o{ev.vertex.id}]", f":{ev.id}]") + eductSmiles[ev.graphIndex] = s + + strs = set() + for vm in DGVertexMapper(e, rightLimit=1, leftLimit=1): + productSmiles = [origSmiles[g] for g in vms.right] + for ev in vms.left.vertices: + pv = vm.map[ev] + if not pv: + continue + s = productSmiles[pv.graphIndex] + s = s.replace(f":o{pv.vertex.id}]", f":{ev.id}]") + productSmiles[pv.graphIndex] = s + count = vms.left.numVertices + for pv in vms.right.vertices: + ev = vm.map.inverse(pv) + if ev: + continue + s = productSmiles[pv.graphIndex] + s = s.replace(f":o{pv.vertex.id}]", f":{count}]") + count += 1 + productSmiles[pv.graphIndex] = s + left = ".".join(eductSmiles) + right = ".".join(productSmiles) + s = f"{left}>>{right}" + 
assert ":o" not in s + strs.add(s) + res[e] = list(sorted(strs)) + return res + + @staticmethod + def parseReactionSmiles(line: str) -> Rule: + sLeft, sRight = line.split(">>") + ssLeft = sLeft.split(".") + ssRight = sRight.split(".") + mLeft = [smiles(s, add=False) for s in ssLeft] + mRight = [smiles(s, add=False) for s in ssRight] + + def printGraph(g): + extFromInt = {} + for iExt in range(g.minExternalId, g.maxExternalId + 1): + v = g.getVertexFromExternalId(iExt) + if not v.isNull(): + extFromInt[v] = iExt + s = "" + for v in g.vertices: + assert v in extFromInt + s += '\t\tnode [ id %d label "%s" ]\n' % (extFromInt[v], v.stringLabel) + for e in g.edges: + s += '\t\tedge [ source %d target %d label "%s" ]\n' % ( + extFromInt[e.source], + extFromInt[e.target], + e.stringLabel, + ) + return s + + s = "rule [\n\tleft [\n" + for m in mLeft: + s += printGraph(m) + s += "\t]\n\tright [\n" + for m in mRight: + s += printGraph(m) + s += "\t]\n]\n" + return s, Rule.fromGMLString(s, add=False) + + def fit(self, dg, origSmiles): + """ + Matches the original SMILES to a list of generated reaction SMILES and + returns the parsed reaction. + + Parameters: + - dg (DataGenerator): The data generator instance containing the reactions. + - origSmiles (str): The original SMILES string to match. + + Returns: + - Parsed reaction if a match is found; otherwise, None. 
+ """ + try: + res = DGToGML.getReactionSmiles(dg) + smiles_list = [value for values in res.values() for value in values] + + smiles_standard = [ + self.standardizer.fit(rsmi, True, True) for rsmi in smiles_list + ] + origSmiles_standard = self.standardizer.fit(origSmiles, True, True) + + for index, value in enumerate(smiles_standard): + if value == origSmiles_standard: + return self.parseReactionSmiles(smiles_list[index]) + + return None + except Exception as e: + logger.error(f"An error occurred: {e}") + return None diff --git a/synutility/SynIO/Format/gml_to_nx.py b/synutility/SynIO/Format/gml_to_nx.py new file mode 100644 index 0000000..7e6d4f9 --- /dev/null +++ b/synutility/SynIO/Format/gml_to_nx.py @@ -0,0 +1,144 @@ +import networkx as nx +import re +from typing import Tuple +from synutility.SynIO.Format.its_construction import ITSConstruction + + +class GMLToNX: + def __init__(self, gml_text: str): + """ + Initializes a GMLToNX object that can parse GML-like text into separate + NetworkX graphs representing different stages or components of + a chemical reaction. + + Parameters: + - gml_text (str): The GML-like text content that will be parsed into graphs. + """ + self.gml_text = gml_text + self.graphs = {"left": nx.Graph(), "context": nx.Graph(), "right": nx.Graph()} + + def _parse_element(self, line: str, current_section: str): + """ + Private method to parse a line from the GML text, extracting nodes and edges + to populate the specified graph section. + + Parameters: + - line (str): The line of text from the GML data. + - current_section (str): The key of the graph section + ('left', 'context', 'right') being populated. 
+ """ + label_to_order = {"-": 1, ":": 1.5, "=": 2, "#": 3} + tokens = line.split() + if "node" in line: + node_id = int(tokens[tokens.index("id") + 1]) + label = tokens[tokens.index("label") + 1].strip('"') + element, charge = self._extract_element_and_charge(label) + node_attributes = { + "element": element, + "charge": charge, + "atom_map": node_id, + } + self.graphs[current_section].add_node(node_id, **node_attributes) + elif "edge" in line: + source = int(tokens[tokens.index("source") + 1]) + target = int(tokens[tokens.index("target") + 1]) + label = tokens[tokens.index("label") + 1].strip('"') + order = label_to_order.get(label, 0) + self.graphs[current_section].add_edge(source, target, order=float(order)) + + def _synchronize_nodes(self): + """ + Private method to ensure that all nodes present in the 'context' graph are also + present in the 'left' and 'right' graphs. + """ + context_nodes = self.graphs["context"].nodes(data=True) + for graph_key in ["left", "right"]: + for node, data in context_nodes: + self.graphs[graph_key].add_node(node, **data) + + def _extract_element_and_charge(self, label: str) -> Tuple[str, int]: + """ + Extracts the chemical element and its charge from a node label. This function is + designed to handle labels formatted in several ways, including just an element + symbol ("C"), an element with a charge and sign ("Na+"), or an element with a + multi-digit charge and a sign ("Mg2+"). The function uses regular expressions + to parse the label accurately and extract the needed information. + + Parameters: + - label (str): The label from which to extract information. Expected to be in + one of the following formats: + - "Element" (e.g., "C") + - "Element[charge][sign]" (e.g., "Na+", "Mg2+") + - "Element[charge][sign]" where charge is optional and if present, must be + followed by a sign (e.g., "Al3+", "K+"). The charge is not assumed if no sign + is present. 
+ + Returns: + - Tuple[str, int]: A tuple where the first element is the chemical element + symbol (str) extracted from the label, and the second element is an integer + representing the charge. The charge defaults to 0 if no charge information + is present in the label. + + Raises: + - ValueError: If the label does not conform to the expected formats, + which should not happen if labels are pre-validated. + + Note: + - The function assumes that the input label is well-formed according to the + described patterns. Labels without any recognizable pattern will default to + returning "X" as the element with a charge of 0, though this behavior + is conservative and primarily for error handling. + """ + # Regex to separate the element symbols from the optional charge and sign + match = re.match(r"([A-Za-z]+)(\d+)?([+-])?$", label) + if not match: + return ( + "X", + 0, + ) # Default case if regex fails to match, unlikely but safe to handle + + element = match.group(1) + charge_number = match.group(2) + charge_sign = match.group(3) + + if charge_number and charge_sign: + # If there's a number and a sign, combine them to form the charge + charge = int(charge_number) * (1 if charge_sign == "+" else -1) + elif charge_sign: + # If there is no number but there's a sign, it means the charge is 1 or -1 + charge = 1 if charge_sign == "+" else -1 + else: + # If no charge information is provided, assume a charge of 0 + charge = 0 + + return element, charge + + def transform(self) -> Tuple[nx.Graph, nx.Graph, nx.Graph]: + """ + Transforms the GML-like text into three distinct NetworkX graphs, each + representing different aspects of the reaction: 'left' for reactants, + 'context' for ITS, and 'right' for products. + + Returns: + - Tuple[nx.Graph, nx.Graph, nx.Graph]: A tuple containing the graphs for + reactants, products, and ITS. 
class GraphToMol:
    """
    Rebuilds an RDKit molecule from a NetworkX molecular graph.

    The names of the node attributes holding the element symbol and formal
    charge, and of the edge attribute holding the bond order, are supplied as
    mappings at construction time. Bond orders can optionally be ignored, in
    which case every bond is created as a single bond.
    """

    def __init__(
        self, node_attributes: Dict[str, str], edge_attributes: Dict[str, str]
    ):
        """
        Store the attribute-name mappings used during conversion.

        Parameters:
        - node_attributes (Dict[str, str]): Must provide the keys "element" and
        "charge"; their values are the node attribute names to read.
        - edge_attributes (Dict[str, str]): Must provide the key "order"; its
        value is the edge attribute name to read (unless bond orders are ignored).
        """
        self.node_attributes = node_attributes
        self.edge_attributes = edge_attributes

    def graph_to_mol(
        self, graph: nx.Graph, ignore_bond_order: bool = False
    ) -> Chem.Mol:
        """
        Build an RDKit molecule from the nodes and edges of a graph.

        Parameters:
        - graph (nx.Graph): Molecular graph to convert.
        - ignore_bond_order (bool): If True, every bond is added as a single
        bond regardless of the order stored on the edge.

        Returns:
        - Chem.Mol: The sanitized molecule (the editable ``RWMol`` instance that
        was built).
        """
        rw_mol = Chem.RWMol()

        # RDKit atom index assigned to each graph node.
        atom_index_of: Dict[int, int] = {}

        for node_id, attrs in graph.nodes(data=True):
            # Missing element defaults to carbon, missing charge to neutral.
            symbol = attrs.get(self.node_attributes["element"], "C")
            formal_charge = attrs.get(self.node_attributes["charge"], 0)
            rd_atom = Chem.Atom(symbol)
            rd_atom.SetFormalCharge(formal_charge)
            # The atom-map number is read from the literal "atom_class" key,
            # not through the node_attributes mapping.
            if "atom_class" in attrs:
                rd_atom.SetAtomMapNum(attrs["atom_class"])
            atom_index_of[node_id] = rw_mol.AddAtom(rd_atom)

        for start, end, attrs in graph.edges(data=True):
            if ignore_bond_order:
                order = 1
            else:
                # The sign of the stored order is discarded.
                order = abs(attrs.get(self.edge_attributes["order"], 1))
            rw_mol.AddBond(
                atom_index_of[start],
                atom_index_of[end],
                self.get_bond_type_from_order(order),
            )

        # Raises if the assembled molecule is not chemically valid.
        Chem.SanitizeMol(rw_mol)

        return rw_mol

    @staticmethod
    def get_bond_type_from_order(order: float) -> Chem.BondType:
        """
        Translate a numeric bond order into an RDKit ``BondType``.

        Orders 1, 2 and 3 map to SINGLE, DOUBLE and TRIPLE. Every other value
        (including 1.5) maps to AROMATIC.
        """
        known_orders = (
            (1, Chem.BondType.SINGLE),
            (2, Chem.BondType.DOUBLE),
            (3, Chem.BondType.TRIPLE),
        )
        for value, bond_type in known_orders:
            if order == value:
                return bond_type
        return Chem.BondType.AROMATIC
def isomorphism_check(
    graph_i: nx.Graph,
    graph_j: nx.Graph,
    node_label_names: List[str] = None,
    node_label_default: List[Any] = None,
    edge_attribute: str = "order",
    edge_default: Any = 1,
) -> bool:
    """
    Checks if two graphs are isomorphic based on specified node and edge attributes.

    Parameters:
    - graph_i (nx.Graph): First graph to be compared.
    - graph_j (nx.Graph): Second graph to be compared.
    - node_label_names (List[str]): Node attribute names that must match.
    Defaults to ["element", "charge"] when omitted.
    - node_label_default (List[Any]): Value assumed for each node attribute when
    a node lacks it, in the same order as node_label_names. Defaults to
    ["*", 0] when omitted.
    - edge_attribute (str): Edge attribute name that must match.
    - edge_default (Any): Value assumed when an edge lacks the attribute.

    Returns:
    - bool: True if the graphs are isomorphic considering the specified
    attributes, else False.
    """
    # None sentinels replace the former list defaults so that no mutable
    # object is shared between calls.
    if node_label_names is None:
        node_label_names = ["element", "charge"]
    if node_label_default is None:
        node_label_default = ["*", 0]

    # One equality operator per compared node attribute.
    node_label_operators = [eq] * len(node_label_names)

    node_match = generic_node_match(
        node_label_names, node_label_default, node_label_operators
    )
    edge_match = generic_edge_match(edge_attribute, edge_default, eq)

    return nx.is_isomorphic(
        graph_i, graph_j, node_match=node_match, edge_match=edge_match
    )
class ITSConstruction:
    """
    Builds an ITS graph (a combined reactant/product graph) from two graphs
    that share node identifiers.
    """

    @staticmethod
    def ITSGraph(
        G: nx.Graph,
        H: nx.Graph,
        ignore_aromaticity: bool = False,
        attributes_defaults: Dict[str, Any] = None,
        balance_its: bool = True,
    ) -> nx.Graph:
        """
        Creates a Combined Graph Representation (CGR) from two input graphs G and H.

        One of the two graphs is deep-copied as the node template and stripped
        of its edges. Every node present in both G and H receives a "typesGH"
        attribute holding its attribute tuple in G and in H. Edges are then
        added by `add_edges_to_ITS`.

        Parameters:
        - G (nx.Graph): The first input graph (reactant side).
        - H (nx.Graph): The second input graph (product side).
        - ignore_aromaticity (bool): Forwarded to `add_edges_to_ITS`.
        Defaults to False.
        - attributes_defaults (Dict[str, Any]): Attribute names and the default
        used when a node lacks them; see `get_node_attributes_with_defaults`.
        - balance_its (bool): If True, the graph with fewer nodes is used as the
        node template (G on a tie); if False, the graph with more nodes is used
        (G on a tie). Defaults to True.

        Returns:
        - nx.Graph: The Combined Graph Representation as a new graph instance.
        """
        g_size, h_size = len(G.nodes()), len(H.nodes())
        if (balance_its and g_size <= h_size) or (
            not balance_its and g_size >= h_size
        ):
            ITS = deepcopy(G)
        else:
            ITS = deepcopy(H)

        # Keep the nodes and their attributes, drop every edge.
        ITS.remove_edges_from(list(ITS.edges()))

        # Only nodes present in both graphs get a (G, H) attribute pair.
        typesDict = {
            v: (
                ITSConstruction.get_node_attributes_with_defaults(
                    G, v, attributes_defaults
                ),
                ITSConstruction.get_node_attributes_with_defaults(
                    H, v, attributes_defaults
                ),
            )
            for v in list(ITS.nodes())
            if v in G.nodes() and v in H.nodes()
        }
        nx.set_node_attributes(ITS, typesDict, "typesGH")

        return ITSConstruction.add_edges_to_ITS(ITS, G, H, ignore_aromaticity)

    @staticmethod
    def get_node_attribute(graph: nx.Graph, node: int, attribute: str, default):
        """
        Retrieves a specific attribute for a node in a graph, returning a default
        value if the node or the attribute is missing.

        Parameters:
        - graph (nx.Graph): The graph from which to retrieve the node attribute.
        - node (int): The node identifier.
        - attribute (str): The attribute to retrieve.
        - default: The value returned when the lookup raises KeyError.

        Returns:
        - The value of the node attribute, or the default value.
        """
        try:
            return graph.nodes[node][attribute]
        except KeyError:
            return default

    @staticmethod
    def get_node_attributes_with_defaults(
        graph: nx.Graph, node: int, attributes_defaults: Dict[str, Any] = None
    ) -> Tuple:
        """
        Retrieves node attributes from a graph, substituting defaults for any
        that are missing.

        Parameters:
        - graph (nx.Graph): The graph from which to retrieve node attributes.
        - node (int): The node identifier.
        - attributes_defaults (Dict[str, Any], optional): Attribute names mapped
        to their default values. When omitted, "element", "aromatic", "hcount",
        "charge" and "neighbors" are read with defaults "*", False, 0, 0 and
        ["", ""].

        Returns:
        - Tuple: The attribute values in the iteration order of
        attributes_defaults.
        """
        if attributes_defaults is None:
            attributes_defaults = {
                "element": "*",
                "aromatic": False,
                "hcount": 0,
                "charge": 0,
                "neighbors": ["", ""],
            }

        return tuple(
            ITSConstruction.get_node_attribute(graph, node, attr, default)
            for attr, default in attributes_defaults.items()
        )

    @staticmethod
    def add_edges_to_ITS(
        ITS: nx.Graph, G: nx.Graph, H: nx.Graph, ignore_aromaticity: bool = False
    ) -> nx.Graph:
        """
        Adds edges to a copy of the ITS graph based on the edges of G and H.

        Each edge receives an "order" tuple. For an edge first reached while
        walking G it is (order in G, order in H), with 0 for the H entry when H
        lacks the edge. For an edge reached only while walking H it is
        (0, order in H). Afterwards nodes carrying no attributes at all are
        removed and a "standard_order" attribute is added to every edge.

        Parameters:
        - ITS (nx.Graph): The initial combined graph representation (not modified).
        - G (nx.Graph): The first input graph.
        - H (nx.Graph): The second input graph.
        - ignore_aromaticity (bool): Forwarded to `add_standard_order_attribute`.
        Defaults to False.

        Returns:
        - nx.Graph: A new graph with the edges added.
        """
        new_ITS = deepcopy(ITS)

        # First pass walks G's edges, second pass walks H's edges; an edge
        # already added in the first pass is skipped in the second.
        for graph_from, graph_to, reverse in [(G, H, False), (H, G, True)]:
            for u, v in graph_from.edges():
                if new_ITS.has_edge(u, v):
                    continue
                if graph_to.has_edge(u, v):
                    edge_label = (
                        graph_from[u][v]["order"],
                        graph_to[u][v]["order"],
                    )
                elif graph_to.has_edge(v, u):
                    # Only reachable when the two lookups can disagree
                    # (i.e. direction-sensitive graphs).
                    if reverse:
                        edge_label = (
                            graph_from[v][u]["order"],
                            graph_to[v][u]["order"],
                        )
                    else:
                        edge_label = (
                            graph_from[u][v]["order"],
                            graph_to[v][u]["order"],
                        )
                elif reverse:
                    # Edge exists only in H: formed during the reaction.
                    edge_label = (0, graph_from[u][v]["order"])
                else:
                    # Edge exists only in G: broken during the reaction.
                    edge_label = (graph_from[u][v]["order"], 0)
                new_ITS.add_edge(u, v, order=edge_label)

        # Drop attribute-less nodes, e.g. ones created implicitly by add_edge.
        nodes_to_remove = [
            node for node in new_ITS.nodes() if not new_ITS.nodes[node]
        ]
        new_ITS.remove_nodes_from(nodes_to_remove)
        return ITSConstruction.add_standard_order_attribute(
            new_ITS, ignore_aromaticity
        )

    @staticmethod
    def _order_as_number(value):
        """Return `value` if it supports numeric addition, otherwise 0."""
        try:
            return value + 0
        except TypeError:
            return 0

    @staticmethod
    def add_standard_order_attribute(
        graph: nx.Graph, ignore_aromaticity: bool = False
    ) -> nx.Graph:
        """
        Adds a 'standard_order' attribute to each edge of a copy of the graph.

        For an edge whose 'order' attribute is a tuple, 'standard_order' is the
        first element minus the second element. Elements that are not numeric
        (e.g. '*') are treated as 0. Edges without a tuple 'order' get a
        'standard_order' of 0.

        Parameters:
        - graph (NetworkX.Graph): A graph whose edges carry an 'order' tuple.
        - ignore_aromaticity (bool): If True, differences with an absolute value
        below 1 are set to 0. Defaults to False.

        Returns:
        - NetworkX.Graph: A copy of the input graph (`graph.copy()`) with the
        'standard_order' attribute set on every edge.
        """
        new_graph = graph.copy()

        for u, v, data in new_graph.edges(data=True):
            standard_order = 0
            if "order" in data and isinstance(data["order"], tuple):
                # Non-numeric entries count as 0, as the contract promises.
                first_order = ITSConstruction._order_as_number(data["order"][0])
                second_order = ITSConstruction._order_as_number(data["order"][1])
                standard_order = first_order - second_order
                if ignore_aromaticity and abs(standard_order) < 1:
                    standard_order = 0
            new_graph[u][v]["standard_order"] = standard_order

        return new_graph
class MolToGraph:
    """
    Converts RDKit molecules into NetworkX graphs whose nodes are keyed by
    atom-map number and carry atom/bond descriptors as attributes.
    """

    def __init__(self):
        """Nothing to initialize; all functionality is exposed via static/class methods."""
        pass

    @staticmethod
    def add_partial_charges(mol: Chem.Mol) -> None:
        """
        Compute Gasteiger partial charges for the molecule in place.

        Parameters:
        - mol (Chem.Mol): An RDKit molecule object.
        """
        AllChem.ComputeGasteigerCharges(mol)

    @staticmethod
    def get_stereochemistry(atom: Chem.Atom) -> str:
        """
        Map the atom's RDKit chiral tag to a one-letter label.

        Parameters:
        - atom (Chem.Atom): An RDKit atom object.

        Returns:
        - str: 'S' for CHI_TETRAHEDRAL_CCW, 'R' for CHI_TETRAHEDRAL_CW, 'N'
        otherwise.

        NOTE(review): the chiral tag describes neighbor ordering, so these labels
        may differ from the CIP R/S descriptor; confirm the intended meaning.
        """
        tag = atom.GetChiralTag()
        if tag == Chem.ChiralType.CHI_TETRAHEDRAL_CCW:
            return "S"
        if tag == Chem.ChiralType.CHI_TETRAHEDRAL_CW:
            return "R"
        return "N"

    @staticmethod
    def get_bond_stereochemistry(bond: Chem.Bond) -> str:
        """
        Report the E/Z label of a double bond.

        Parameters:
        - bond (Chem.Bond): An RDKit bond object.

        Returns:
        - str: 'E' or 'Z' for a double bond with that stereo flag, 'N' for any
        other bond.
        """
        label = "N"
        if bond.GetBondType() == Chem.BondType.DOUBLE:
            stereo_flag = bond.GetStereo()
            if stereo_flag == Chem.BondStereo.STEREOE:
                label = "E"
            elif stereo_flag == Chem.BondStereo.STEREOZ:
                label = "Z"
        return label

    @staticmethod
    def has_atom_mapping(mol: Chem.Mol) -> bool:
        """
        Tell whether at least one atom carries the "molAtomMapNumber" property.

        Parameters:
        - mol (Chem.Mol): An RDKit molecule object.

        Returns:
        - bool: True if any atom has the property, False otherwise.
        """
        return any(atom.HasProp("molAtomMapNumber") for atom in mol.GetAtoms())

    @staticmethod
    def random_atom_mapping(mol: Chem.Mol) -> Chem.Mol:
        """
        Assign each atom a distinct atom-map number in random order.

        The numbers 1..N (N = number of atoms) are shuffled and written to the
        atoms' "molAtomMapNumber" property; the molecule is modified in place.

        Parameters:
        - mol (Chem.Mol): An RDKit molecule object.

        Returns:
        - Chem.Mol: The same molecule object, now carrying map numbers.
        """
        map_numbers = list(range(1, mol.GetNumAtoms() + 1))
        random.shuffle(map_numbers)
        for atom, number in zip(mol.GetAtoms(), map_numbers):
            atom.SetProp("molAtomMapNumber", str(number))
        return mol

    @classmethod
    def mol_to_graph(cls, mol: Chem.Mol, drop_non_aam: bool = False) -> nx.Graph:
        """
        Convert an RDKit molecule into a NetworkX graph.

        Gasteiger charges are computed on the molecule first. If no atom carries
        a map number, random map numbers are assigned. Each atom becomes a node
        keyed by its atom-map number; each bond whose two atoms were kept becomes
        an edge.

        Parameters:
        - mol (Chem.Mol): An RDKit molecule object (modified in place).
        - drop_non_aam (bool, optional): If True, atoms whose map number is 0 are
        skipped, together with the bonds touching them.

        Returns:
        - nx.Graph: A NetworkX graph representing the molecule.

        NOTE(review): with drop_non_aam=False and a partially mapped molecule,
        all unmapped atoms share node key 0; confirm this is intended.
        """
        cls.add_partial_charges(mol)
        graph = nx.Graph()
        # RDKit atom index -> node key (atom-map number) for atoms that were kept.
        index_to_class: Dict[int, int] = {}
        if cls.has_atom_mapping(mol) is False:
            mol = cls.random_atom_mapping(mol)

        for atom in mol.GetAtoms():
            atom_map = atom.GetAtomMapNum()
            if drop_non_aam and atom_map == 0:
                continue

            partial_charge = round(float(atom.GetProp("_GasteigerCharge")), 3)
            index_to_class[atom.GetIdx()] = atom_map
            neighbor_symbols = sorted(
                [neighbor.GetSymbol() for neighbor in atom.GetNeighbors()]
            )
            node_attributes = {
                "charge": atom.GetFormalCharge(),
                "hcount": atom.GetTotalNumHs(),
                "aromatic": atom.GetIsAromatic(),
                "element": atom.GetSymbol(),
                "atom_map": atom_map,
                "isomer": cls.get_stereochemistry(atom),
                "partial_charge": partial_charge,
                "hybridization": str(atom.GetHybridization()),
                "in_ring": atom.IsInRing(),
                "explicit_valence": atom.GetExplicitValence(),
                "implicit_hcount": atom.GetNumImplicitHs(),
                "neighbors": neighbor_symbols,
            }
            graph.add_node(atom_map, **node_attributes)

        for bond in mol.GetBonds():
            source = index_to_class.get(bond.GetBeginAtomIdx())
            target = index_to_class.get(bond.GetEndAtomIdx())
            # Skip bonds touching an atom that was dropped above.
            if source is None or target is None:
                continue
            edge_attributes = {
                "order": bond.GetBondTypeAsDouble(),
                "ez_isomer": cls.get_bond_stereochemistry(bond),
                "bond_type": str(bond.GetBondType()),
                "conjugated": bond.GetIsConjugated(),
                "in_ring": bond.IsInRing(),
            }
            graph.add_edge(source, target, **edge_attributes)

        return graph
class NXToGML:
    """
    Serializes a reaction rule given as (left, right, context) NetworkX graphs
    into a GML rule string.
    """

    def __init__(self) -> None:
        pass

    @staticmethod
    def _charge_to_string(charge):
        """
        Converts an integer charge into a string representation.

        Parameters:
        - charge (int): The charge value, which can be positive, negative, or zero.

        Returns:
        - str: "" for 0, "+" / "-" for +1 / -1, otherwise the magnitude followed
        by the sign (e.g. "2+", "3-").
        """
        if charge == 0:
            return ""
        sign = "+" if charge > 0 else "-"
        magnitude = abs(charge)
        return sign if magnitude == 1 else f"{magnitude}{sign}"

    @staticmethod
    def _find_changed_nodes(
        graph1: nx.Graph, graph2: nx.Graph, attributes: list = None
    ) -> list:
        """
        Identifies nodes with changes in specified attributes between two
        NetworkX graphs.

        Parameters:
        - graph1 (nx.Graph): The first NetworkX graph.
        - graph2 (nx.Graph): The second NetworkX graph.
        - attributes (list): Attribute names to compare. Defaults to ["charge"]
        when omitted.

        Returns:
        - list: Identifiers of nodes present in both graphs for which at least
        one listed attribute differs (a missing attribute compares as None).
        """
        # None sentinel instead of a shared mutable list default.
        if attributes is None:
            attributes = ["charge"]

        changed_nodes = []
        for node in graph1.nodes():
            if node not in graph2:
                continue
            attrs1, attrs2 = graph1.nodes[node], graph2.nodes[node]
            if any(attrs1.get(attr) != attrs2.get(attr) for attr in attributes):
                changed_nodes.append(node)

        return changed_nodes

    @staticmethod
    def _convert_graph_to_gml(
        graph: nx.Graph, section: str, changed_node_ids: List
    ) -> str:
        """
        Convert a NetworkX graph to the GML text of one rule section.

        For the "context" section only nodes that are NOT in changed_node_ids
        are written (no edges). For any other section all edges are written,
        followed by the nodes that ARE in changed_node_ids.

        Parameters:
        - graph (nx.Graph): The NetworkX graph to be converted.
        - section (str): The section name in the GML output, typically "left",
        "right", or "context".
        - changed_node_ids (List): Nodes whose attributes change across the rule.

        Returns:
        str: The GML string representation of the graph for the specified section.
        """
        # Unknown bond orders fall back to a single-bond label.
        order_to_label = {1: "-", 1.5: ":", 2: "=", 3: "#"}
        is_context = section == "context"
        parts = [f" {section} [\n"]

        if not is_context:
            for source, target, data in graph.edges(data=True):
                label = order_to_label.get(data.get("order", 1), "-")
                parts.append(
                    f' edge [ source {source} target {target} label "{label}" ]\n'
                )

        for node_id, data in graph.nodes(data=True):
            # Context lists the unchanged nodes, left/right list the changed ones.
            if (node_id in changed_node_ids) == is_context:
                continue
            element = data.get("element", "X")
            charge_str = NXToGML._charge_to_string(data.get("charge", 0))
            parts.append(f' node [ id {node_id} label "{element}{charge_str}" ]\n')

        parts.append(" ]\n")
        return "".join(parts)

    @staticmethod
    def _rule_grammar(
        L: nx.Graph, R: nx.Graph, K: nx.Graph, rule_name: str, changed_node_ids: List
    ) -> str:
        """
        Generate a GML string representation for a chemical rule, including its
        left, context, and right sections (written in that order).

        Parameters:
        - L (nx.Graph): The left graph.
        - R (nx.Graph): The right graph.
        - K (nx.Graph): The context graph.
        - rule_name (str): The name written as the rule's ruleID.
        - changed_node_ids (List): Nodes whose attributes change across the rule.

        Returns:
        - str: The GML string representation of the rule.
        """
        sections = (
            NXToGML._convert_graph_to_gml(L, "left", changed_node_ids),
            NXToGML._convert_graph_to_gml(K, "context", changed_node_ids),
            NXToGML._convert_graph_to_gml(R, "right", changed_node_ids),
        )
        return "rule [\n" + f' ruleID "{rule_name}"\n' + "".join(sections) + "]"

    @staticmethod
    def transform(
        graph_rules: Tuple[nx.Graph, nx.Graph, nx.Graph],
        rule_name: str = "Test",
        reindex: bool = False,
        attributes: List[str] = None,
    ) -> str:
        """
        Generate the GML string for a single rule given as a tuple of graphs.

        Parameters:
        - graph_rules (Tuple[nx.Graph, nx.Graph, nx.Graph]): The (L, R, K) graphs
        of the rule, in that order.
        - rule_name (str): The name written as the rule's ruleID.
        - reindex (bool): If true, node IDs are renumbered 1..n following the
        node order of the L graph, and the same mapping is applied to R and K.
        - attributes (List[str]): Node attributes used to decide which nodes
        changed between L and R. Defaults to ["charge"] when omitted.

        Returns:
        - str: The GML string representation of the rule.

        NOTE(review): with reindex=True, nodes of R or K that do not occur in L
        are not part of the mapping; confirm they cannot clash with new IDs.
        """
        # None sentinel instead of a shared mutable list default.
        if attributes is None:
            attributes = ["charge"]

        L, R, K = graph_rules
        if reindex:
            index_mapping = {
                old_id: new_id for new_id, old_id in enumerate(L.nodes(), 1)
            }
            L = nx.relabel_nodes(L, index_mapping)
            R = nx.relabel_nodes(R, index_mapping)
            K = nx.relabel_nodes(K, index_mapping)

        changed_node_ids = NXToGML._find_changed_nodes(L, R, attributes)
        return NXToGML._rule_grammar(L, R, K, rule_name, changed_node_ids)
def deduplicateGraphs(initial):
    """
    Keep the first representative of every isomorphism class in a list.

    Parameters:
    - initial (list): Graph objects exposing an `isomorphism(other)` method that
    returns 0 when the two graphs are not isomorphic.

    Returns:
    - list: The graphs of `initial`, in their original order, without any graph
    that is isomorphic to one kept earlier.
    """
    distinct = []
    for graph in initial:
        for kept in distinct:
            if graph.isomorphism(kept) != 0:
                # Already represented; do not keep this one.
                break
        else:
            distinct.append(graph)
    return distinct
def rule_apply(smiles_list, rule, print_output=False):
    """
    Apply a reaction rule to a set of molecules and return the derivation graph.

    The SMILES strings are loaded as graphs, de-duplicated by isomorphism and
    sorted by ascending vertex count before the rule is applied. Rule
    isomorphism checking during binding is switched off on the global `config`.

    Parameters:
    - smiles_list (list): List of SMILES strings.
    - rule (str): Reaction rule in GML string format.
    - print_output (bool): If True, an "out" directory is created and
    `dg.print()` is called.

    Returns:
    - dg (DG): The derivation graph after applying the rule, or None if any
    step raised an exception (the error is logged).
    """
    try:
        molecules = deduplicateGraphs(
            [smiles(entry, add=False) for entry in smiles_list]
        )
        molecules = sorted(molecules, key=lambda graph: graph.numVertices)

        compiled_rule = ruleGMLString(rule)

        dg = DG(graphDatabase=molecules)
        config.dg.doRuleIsomorphismDuringBinding = False
        dg.build().apply(molecules, compiled_rule, verbosity=8)

        if print_output:
            os.makedirs("out", exist_ok=True)
            dg.print()

        return dg
    except Exception as e:
        logger.error(f"An error occurred: {e}")
        return None