Skip to content

Commit

Permalink
refractor src
Browse files Browse the repository at this point in the history
  • Loading branch information
TieuLongPhan committed Dec 6, 2024
1 parent bdb77e4 commit 1bf0cd4
Show file tree
Hide file tree
Showing 18 changed files with 1,000 additions and 545 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,5 @@ Data/Temp/Benchmark/Complete/*
Data/Temp/Benchmark/Hier/*
Data/Temp/Benchmark/Raw/*
*.ipynb
*backup
bug.py
84 changes: 84 additions & 0 deletions Test/SynITS/test_hydrogen_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import unittest
import networkx as nx
from synutility.SynIO.data_type import load_from_pickle
from syntemp.SynITS.hydrogen_utils import (
check_explicit_hydrogen,
check_hcount_change,
get_cycle_member_rings,
get_priority,
)


class TestGraphFunctions(unittest.TestCase):
    """Tests for the helper functions imported from ``hydrogen_utils``."""

    def setUp(self):
        """Load the pickled hydrogen test cases used by the data-driven tests."""
        self.data = load_from_pickle("./Data/Testcase/hydrogen_test.pkl.gz")

    def test_check_explicit_hydrogen(self):
        """check_explicit_hydrogen returns the count and ids of hydrogen nodes."""
        # Note: explicit hydrogens usually only appear in reactants
        # (+H2 reactions).
        count_r, hydrogen_nodes_r = check_explicit_hydrogen(
            self.data[20]["ITSGraph"][0]
        )
        self.assertEqual(count_r, 2)
        self.assertEqual(hydrogen_nodes_r, [45, 46])

    def test_check_hcount_change(self):
        """check_hcount_change returns 2 for test case 20."""
        # NOTE(review): the same graph (index 0) is passed as both arguments;
        # confirm this is intended rather than a reactant/product pair.
        max_change = check_hcount_change(
            self.data[20]["ITSGraph"][0], self.data[20]["ITSGraph"][0]
        )
        self.assertEqual(max_change, 2)

    def test_get_cycle_member_rings_minimal(self):
        """get_cycle_member_rings with the 'minimal' option returns [4]."""
        member_rings = get_cycle_member_rings(self.data[1]["GraphRules"][2], "minimal")
        # Only a single entry, 4, is expected for this rule graph.
        self.assertEqual(member_rings, [4])

    def test_get_priority(self):
        """get_priority returns [0, 1] for three hand-built graphs."""
        # First graph: a three-node chain with element/hcount attributes.
        first_graph = nx.Graph()
        first_graph.add_nodes_from(
            [
                (1, {"element": "H", "hcount": 2}),
                (2, {"element": "C", "hcount": 1}),
                (3, {"element": "H", "hcount": 1}),
            ]
        )
        first_graph.add_edges_from([(1, 2), (2, 3)])

        # Second graph: same topology, hcount values of nodes 1 and 3 swapped.
        second_graph = nx.Graph()
        second_graph.add_nodes_from(
            [
                (1, {"element": "H", "hcount": 1}),
                (2, {"element": "C", "hcount": 1}),
                (3, {"element": "H", "hcount": 2}),
            ]
        )
        second_graph.add_edges_from([(1, 2), (2, 3)])

        # Third graph: two cycles sharing node 3, without node attributes.
        cyclic_graph = nx.Graph()
        cyclic_graph.add_edges_from(
            [
                (1, 2),
                (2, 3),
                (3, 4),
                (4, 1),  # closes the 4-cycle 1-2-3-4
                (3, 5),
                (5, 6),
                (6, 3),  # closes the 3-cycle 3-5-6
            ]
        )
        reaction_centers = [first_graph, second_graph, cyclic_graph]

        # Get priority indices
        priority_indices = get_priority(reaction_centers)

        self.assertEqual(priority_indices, [0, 1])


if __name__ == "__main__":
    # Run the tests when this file is executed directly.
    unittest.main()
30 changes: 8 additions & 22 deletions Test/SynITS/test_its_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from syntemp.SynITS.its_extraction import ITSExtraction
from syntemp.SynITS.its_construction import ITSConstruction

from synutility.SynIO.Format.smi_to_graph import rsmi_to_graph


class TestITSExtraction(unittest.TestCase):

Expand Down Expand Up @@ -29,31 +31,15 @@ def setUp(self):
]
self.mapper_names = ["local_mapper", "rxn_mapper", "graphormer"]

def test_graph_from_smiles(self):
graph = ITSExtraction.graph_from_smiles(self.smiles1)
self.assertEqual(len(graph.nodes()), 4)
self.assertEqual(len(graph.edges()), 3)

def test_check_equivariant_graph(self):
react_local_mapper, prod_local_mapper = self.mapped_smiles_list[0][
"local_mapper"
].split(">>")
G_local = ITSExtraction.graph_from_smiles(react_local_mapper)
H_local = ITSExtraction.graph_from_smiles(prod_local_mapper)
G_local, H_local = rsmi_to_graph(self.mapped_smiles_list[0]["local_mapper"])
ITS_local = ITSConstruction.ITSGraph(G_local, H_local)

react_rxn_mapper, prod_rxn_mapper = self.mapped_smiles_list[0][
"rxn_mapper"
].split(">>")
G_rxn = ITSExtraction.graph_from_smiles(react_rxn_mapper)
H_rxn = ITSExtraction.graph_from_smiles(prod_rxn_mapper)
G_rxn, H_rxn = rsmi_to_graph(self.mapped_smiles_list[0]["rxn_mapper"])
ITS_rxn = ITSConstruction.ITSGraph(G_rxn, H_rxn)

react_graphormer, prod_graphormer = self.mapped_smiles_list[0][
"graphormer"
].split(">>")
G_graphormer = ITSExtraction.graph_from_smiles(react_graphormer)
H_graphormer = ITSExtraction.graph_from_smiles(prod_graphormer)
G_graphormer, H_graphormer = rsmi_to_graph(
self.mapped_smiles_list[0]["graphormer"]
)
ITS_graphormer = ITSConstruction.ITSGraph(G_graphormer, H_graphormer)

classified, equivariant = ITSExtraction.check_equivariant_graph(
Expand Down Expand Up @@ -82,7 +68,7 @@ def test_parallel_process_smiles(self):
self.assertIsNotNone(results[0]["GraphRules"])

# Inequivalent AAM
self.assertEqual(results_wrong[0]["equivariant"], 0)
self.assertEqual(results_wrong[0]["equivariant"], -1) # -1 means early exit

def test_unsanitize_smiles(self):
test_2 = {
Expand Down
161 changes: 96 additions & 65 deletions Test/SynITS/test_its_hadjuster.py
Original file line number Diff line number Diff line change
@@ -1,65 +1,96 @@
# import unittest
# import networkx as nx
# from SynTemp.SynITS.its_hadjuster import ITSHAdjuster


# class TestITSHAdjuster(unittest.TestCase):

# def create_mock_graph(self, hcounts: dict) -> nx.Graph:
# """Utility function to create a mock graph with specified
# hydrogen counts for nodes."""
# graph = nx.Graph()
# for node_id, hcount in hcounts.items():
# graph.add_node(node_id, hcount=hcount)
# return graph

# def test_check_hcount_change(self):
# # Mock reactant and product graphs with specified hydrogen counts
# react_graph = self.create_mock_graph({1: 1, 2: 2})
# prod_graph = self.create_mock_graph({1: 0, 2: 3})

# # Expected: one hydrogen formation (node 1) and one hydrogen break (node 2)
# max_hydrogen_change = ITSHAdjuster.check_hcount_change(react_graph, prod_graph)
# self.assertEqual(max_hydrogen_change, 1)

# def test_add_hydrogen_nodes(self):
# # Mock reactant and product graphs with specified hydrogen counts
# react_graph = self.create_mock_graph({1: 1})
# prod_graph = self.create_mock_graph({1: 0})

# # Add hydrogen nodes to reactant and product graphs
# updated_react_graph, _ = ITSHAdjuster.add_hydrogen_nodes(
# react_graph, prod_graph
# )

# # Verify that hydrogen nodes have been added correctly
# self.assertIn(
# max(updated_react_graph.nodes), updated_react_graph.nodes
# ) # Hydrogen node added to reactant graph
# self.assertEqual(
# updated_react_graph.nodes[max(updated_react_graph.nodes)]["element"], "H"
# ) # Check element of added node

# def test_add_hydrogen_nodes_multiple(self):
# # Mock reactant and product graphs with specified hydrogen counts
# react_graph = self.create_mock_graph({1: 2, 2: 1})
# prod_graph = self.create_mock_graph({1: 0, 2: 2})

# # Generate updated graph pairs with multiple hydrogen nodes added
# updated_graph_pairs = ITSHAdjuster.add_hydrogen_nodes_multiple(
# react_graph, prod_graph
# )

# # Verify that multiple updated graph pairs are generated
# self.assertTrue(len(updated_graph_pairs) > 1) # Multiple permutations generated
# for react_graph, prod_graph in updated_graph_pairs:
# self.assertIn(
# max(react_graph.nodes), react_graph.nodes
# ) # Hydrogen node added to reactant graph
# self.assertIn(
# max(prod_graph.nodes), prod_graph.nodes
# ) # Hydrogen node added to product graph


# if __name__ == "__main__":
# unittest.main()
import unittest
import networkx as nx
from copy import deepcopy
from synutility.SynIO.data_type import load_from_pickle
from syntemp.SynITS.its_hadjuster import ITSHAdjuster


class TestITSHAdjuster(unittest.TestCase):
    """Tests for ``ITSHAdjuster`` driven by the pickled hydrogen test cases."""

    def setUp(self):
        """Load the pickled test cases before each test."""
        self.data = load_from_pickle("./Data/Testcase/hydrogen_test.pkl.gz")

    def test_process_single_graph_data_success(self):
        """process_single_graph_data yields nx.Graph entries for case 0."""
        processed_data = ITSHAdjuster.process_single_graph_data(
            self.data[0], "ITSGraph"
        )
        # assertIsInstance reports the offending type on failure.
        for value in processed_data["ITSGraph"]:
            self.assertIsInstance(value, nx.Graph)
        for value in processed_data["GraphRules"]:
            self.assertIsInstance(value, nx.Graph)

    def test_process_single_graph_data_fail(self):
        """process_single_graph_data yields None results for case 16."""
        processed_data = ITSHAdjuster.process_single_graph_data(
            self.data[16], "ITSGraph"
        )
        self.assertIsNone(processed_data["ITSGraph"])
        self.assertIsNone(processed_data["GraphRules"])

    def test_process_single_graph_data_empty_graph(self):
        """Input whose graph entries are all None yields None results."""
        empty_graph_data = {
            "ITSGraph": [None, None, None],
            "GraphRules": [None, None, None],
        }

        processed_data = ITSHAdjuster.process_single_graph_data(
            empty_graph_data, "ITSGraph"
        )

        # Both outputs are expected to be None for this input.
        self.assertIsNone(processed_data["ITSGraph"])
        self.assertIsNone(processed_data["GraphRules"])

    def test_process_single_graph_data_safe(self):
        """process_single_graph_data_safe yields None with a tiny job_timeout."""
        processed_data = ITSHAdjuster.process_single_graph_data_safe(
            self.data[0], "ITSGraph", job_timeout=0.0001
        )
        self.assertIsNone(processed_data["ITSGraph"])
        self.assertIsNone(processed_data["GraphRules"])

    def test_process_graph_data_parallel(self):
        """process_graph_data_parallel leaves 48 entries with an ITSGraph."""
        result = ITSHAdjuster().process_graph_data_parallel(
            self.data, "ITSGraph", n_jobs=1, verbose=0, get_priority_graph=True
        )
        # Keep only the entries whose ITSGraph is truthy (i.e. processed).
        result = [value for value in result if value["ITSGraph"]]
        self.assertEqual(len(result), 48)

    def test_process_graph_data_parallel_safe(self):
        """process_graph_data_parallel in safe mode fails every entry on timeout."""
        result = ITSHAdjuster().process_graph_data_parallel(
            self.data,
            "ITSGraph",
            n_jobs=1,
            verbose=0,
            get_priority_graph=True,
            safe=True,
            job_timeout=0.0001,  # such a low timeout makes every job fail
        )
        # Keep only the entries whose ITSGraph is truthy; none should remain.
        result = [value for value in result if value["ITSGraph"]]
        self.assertEqual(len(result), 0)

    def test_process_multiple_hydrogens(self):
        """process_multiple_hydrogens yields nx.Graph entries for case 0."""
        # Work on a deep copy so the loaded test data is not modified.
        graphs = deepcopy(self.data[0])
        react_graph, prod_graph, _ = graphs["ITSGraph"]

        result = ITSHAdjuster.process_multiple_hydrogens(
            graphs, react_graph, prod_graph, ignore_aromaticity=False, balance_its=True
        )

        for value in result["ITSGraph"]:
            self.assertIsInstance(value, nx.Graph)
        for value in result["GraphRules"]:
            self.assertIsInstance(value, nx.Graph)


if __name__ == "__main__":
    # Run the tests when this file is executed directly.
    unittest.main()
2 changes: 1 addition & 1 deletion Test/SynRule/test_rc_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def test_auto_cluster_wrong_isomorphism(self):
clusters, _ = self.clusterer.auto_cluster(
rc, signature, nodeMatch=None, edgeMatch=None
)
self.assertEqual(len(clusters), 36) # wrong value but almost correct
self.assertEqual(len(clusters), 37) # value observed empirically; no proof that it is correct

def test_fit(self):
"""Test the fit method to ensure it correctly updates data entries with cluster indices."""
Expand Down
4 changes: 2 additions & 2 deletions Test/test_auto_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,15 @@ def setUp(self) -> None:
def test_temp_extract(self):
(rules, _, _, _, _, _) = self.auto.temp_extract(self.data, lib_path=None)
self.assertIn("ruleID", rules[0][0])
self.assertEqual(len(rules[0]), 10)
self.assertEqual(len(rules[0]), 9)

def test_temp_extract_lib(self):
print(f"{root_dir}/Data/Testcase/Compose/SingleRule")
(rules, _, _, _, _, _) = self.auto.temp_extract(
self.data, lib_path=f"{root_dir}/Data/Testcase/Compose/SingleRule"
) # 1 rules exist
self.assertIn("ruleID", rules[0][0])
self.assertEqual(len(rules[0]), 8)
self.assertEqual(len(rules[0]), 7)


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ dependencies = [
"seaborn>=0.13.2",
"joblib>=1.3.2",
"synrbl>=0.0.25",
"synutility>=0.0.12"
"synutility>=0.0.13"
]

[project.urls]
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@ networkx>=3.3
seaborn>=0.13.2
joblib>=1.3.2
synrbl>=0.0.25
synutility>=0.0.12
synutility>=0.0.13
Loading

0 comments on commit 1bf0cd4

Please sign in to comment.