-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
bdb77e4
commit 1bf0cd4
Showing
18 changed files
with
1,000 additions
and
545 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,3 +15,5 @@ Data/Temp/Benchmark/Complete/* | |
Data/Temp/Benchmark/Hier/* | ||
Data/Temp/Benchmark/Raw/* | ||
*.ipynb | ||
*backup | ||
bug.py |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
import unittest | ||
import networkx as nx | ||
from synutility.SynIO.data_type import load_from_pickle | ||
from syntemp.SynITS.hydrogen_utils import ( | ||
check_explicit_hydrogen, | ||
check_hcount_change, | ||
get_cycle_member_rings, | ||
get_priority, | ||
) | ||
|
||
|
||
class TestGraphFunctions(unittest.TestCase): | ||
|
||
def setUp(self): | ||
# Create a test graph for the tests | ||
self.data = load_from_pickle("./Data/Testcase/hydrogen_test.pkl.gz") | ||
|
||
def test_check_explicit_hydrogen(self): | ||
# Test the check_explicit_hydrogen function | ||
# Note, usually only appear in reactants (+H2 reactions) | ||
count_r, hydrogen_nodes_r = check_explicit_hydrogen( | ||
self.data[20]["ITSGraph"][0] | ||
) | ||
self.assertEqual(count_r, 2) | ||
self.assertEqual(hydrogen_nodes_r, [45, 46]) | ||
|
||
def test_check_hcount_change(self): | ||
# Test the check_hcount_change function | ||
max_change = check_hcount_change( | ||
self.data[20]["ITSGraph"][0], self.data[20]["ITSGraph"][0] | ||
) | ||
self.assertEqual(max_change, 2) | ||
|
||
def test_get_cycle_member_rings_minimal(self): | ||
# Test get_cycle_member_rings with 'minimal' cycles | ||
member_rings = get_cycle_member_rings(self.data[1]["GraphRules"][2], "minimal") | ||
self.assertEqual(member_rings, [4]) # Cycles of size 4 and 3 | ||
|
||
def test_get_priority(self): | ||
# Create a test graph for the tests | ||
self.graph = nx.Graph() | ||
self.graph.add_nodes_from( | ||
[ | ||
(1, {"element": "H", "hcount": 2}), | ||
(2, {"element": "C", "hcount": 1}), | ||
(3, {"element": "H", "hcount": 1}), | ||
] | ||
) | ||
self.graph.add_edges_from([(1, 2), (2, 3)]) | ||
|
||
# Create another graph for `check_hcount_change` tests | ||
self.prod_graph = nx.Graph() | ||
self.prod_graph.add_nodes_from( | ||
[ | ||
(1, {"element": "H", "hcount": 1}), | ||
(2, {"element": "C", "hcount": 1}), | ||
(3, {"element": "H", "hcount": 2}), | ||
] | ||
) | ||
self.prod_graph.add_edges_from([(1, 2), (2, 3)]) | ||
|
||
# Create a more complex graph for cycle tests | ||
self.complex_graph = nx.Graph() | ||
self.complex_graph.add_edges_from( | ||
[ | ||
(1, 2), | ||
(2, 3), | ||
(3, 4), | ||
(4, 1), # A simple square cycle | ||
(3, 5), | ||
(5, 6), | ||
(6, 3), # Another cycle | ||
] | ||
) | ||
reaction_centers = [self.graph, self.prod_graph, self.complex_graph] | ||
|
||
# Get priority indices | ||
priority_indices = get_priority(reaction_centers) | ||
|
||
self.assertEqual(priority_indices, [0, 1]) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,65 +1,96 @@ | ||
# import unittest | ||
# import networkx as nx | ||
# from SynTemp.SynITS.its_hadjuster import ITSHAdjuster | ||
|
||
|
||
# class TestITSHAdjuster(unittest.TestCase): | ||
|
||
# def create_mock_graph(self, hcounts: dict) -> nx.Graph: | ||
# """Utility function to create a mock graph with specified | ||
# hydrogen counts for nodes.""" | ||
# graph = nx.Graph() | ||
# for node_id, hcount in hcounts.items(): | ||
# graph.add_node(node_id, hcount=hcount) | ||
# return graph | ||
|
||
# def test_check_hcount_change(self): | ||
# # Mock reactant and product graphs with specified hydrogen counts | ||
# react_graph = self.create_mock_graph({1: 1, 2: 2}) | ||
# prod_graph = self.create_mock_graph({1: 0, 2: 3}) | ||
|
||
# # Expected: one hydrogen formation (node 1) and one hydrogen break (node 2) | ||
# max_hydrogen_change = ITSHAdjuster.check_hcount_change(react_graph, prod_graph) | ||
# self.assertEqual(max_hydrogen_change, 1) | ||
|
||
# def test_add_hydrogen_nodes(self): | ||
# # Mock reactant and product graphs with specified hydrogen counts | ||
# react_graph = self.create_mock_graph({1: 1}) | ||
# prod_graph = self.create_mock_graph({1: 0}) | ||
|
||
# # Add hydrogen nodes to reactant and product graphs | ||
# updated_react_graph, _ = ITSHAdjuster.add_hydrogen_nodes( | ||
# react_graph, prod_graph | ||
# ) | ||
|
||
# # Verify that hydrogen nodes have been added correctly | ||
# self.assertIn( | ||
# max(updated_react_graph.nodes), updated_react_graph.nodes | ||
# ) # Hydrogen node added to reactant graph | ||
# self.assertEqual( | ||
# updated_react_graph.nodes[max(updated_react_graph.nodes)]["element"], "H" | ||
# ) # Check element of added node | ||
|
||
# def test_add_hydrogen_nodes_multiple(self): | ||
# # Mock reactant and product graphs with specified hydrogen counts | ||
# react_graph = self.create_mock_graph({1: 2, 2: 1}) | ||
# prod_graph = self.create_mock_graph({1: 0, 2: 2}) | ||
|
||
# # Generate updated graph pairs with multiple hydrogen nodes added | ||
# updated_graph_pairs = ITSHAdjuster.add_hydrogen_nodes_multiple( | ||
# react_graph, prod_graph | ||
# ) | ||
|
||
# # Verify that multiple updated graph pairs are generated | ||
# self.assertTrue(len(updated_graph_pairs) > 1) # Multiple permutations generated | ||
# for react_graph, prod_graph in updated_graph_pairs: | ||
# self.assertIn( | ||
# max(react_graph.nodes), react_graph.nodes | ||
# ) # Hydrogen node added to reactant graph | ||
# self.assertIn( | ||
# max(prod_graph.nodes), prod_graph.nodes | ||
# ) # Hydrogen node added to product graph | ||
|
||
|
||
# if __name__ == "__main__": | ||
# unittest.main() | ||
import unittest | ||
import networkx as nx | ||
from copy import deepcopy | ||
from synutility.SynIO.data_type import load_from_pickle | ||
from syntemp.SynITS.its_hadjuster import ITSHAdjuster | ||
|
||
|
||
class TestITSHAdjuster(unittest.TestCase): | ||
|
||
def setUp(self): | ||
"""Setup before each test.""" | ||
# Create sample graphs | ||
self.data = load_from_pickle("./Data/Testcase/hydrogen_test.pkl.gz") | ||
|
||
def test_process_single_graph_data_success(self): | ||
"""Test the process_single_graph_data method.""" | ||
processed_data = ITSHAdjuster.process_single_graph_data( | ||
self.data[0], "ITSGraph" | ||
) | ||
for value in processed_data["ITSGraph"]: | ||
self.assertTrue(isinstance(value, nx.Graph)) | ||
for value in processed_data["GraphRules"]: | ||
self.assertTrue(isinstance(value, nx.Graph)) | ||
|
||
def test_process_single_graph_data_fail(self): | ||
"""Test the process_single_graph_data method.""" | ||
processed_data = ITSHAdjuster.process_single_graph_data( | ||
self.data[16], "ITSGraph" | ||
) | ||
self.assertIsNone(processed_data["ITSGraph"]) | ||
self.assertIsNone(processed_data["GraphRules"]) | ||
|
||
def test_process_single_graph_data_empty_graph(self): | ||
"""Test that an empty graph results in empty ITSGraph and GraphRules.""" | ||
empty_graph_data = { | ||
"ITSGraph": [None, None, None], | ||
"GraphRules": [None, None, None], | ||
} | ||
|
||
processed_data = ITSHAdjuster.process_single_graph_data( | ||
empty_graph_data, "ITSGraph" | ||
) | ||
|
||
# Ensure the result is None or empty as expected for an empty graph | ||
self.assertIsNone(processed_data["ITSGraph"]) | ||
self.assertIsNone(processed_data["GraphRules"]) | ||
|
||
def test_process_single_graph_data_safe(self): | ||
"""Test the process_single_graph_data method.""" | ||
processed_data = ITSHAdjuster.process_single_graph_data_safe( | ||
self.data[0], "ITSGraph", job_timeout=0.0001 | ||
) | ||
self.assertIsNone(processed_data["ITSGraph"]) | ||
self.assertIsNone(processed_data["GraphRules"]) | ||
|
||
def test_process_graph_data_parallel(self): | ||
"""Test the process_graph_data_parallel method.""" | ||
result = ITSHAdjuster().process_graph_data_parallel( | ||
self.data, "ITSGraph", n_jobs=1, verbose=0, get_priority_graph=True | ||
) | ||
result = [value for value in result if value["ITSGraph"]] | ||
# Check if the result matches the input data structure | ||
self.assertEqual(len(result), 48) | ||
|
||
def test_process_graph_data_parallel_safe(self): | ||
"""Test the process_graph_data_parallel method.""" | ||
result = ITSHAdjuster().process_graph_data_parallel( | ||
self.data, | ||
"ITSGraph", | ||
n_jobs=1, | ||
verbose=0, | ||
get_priority_graph=True, | ||
safe=True, | ||
job_timeout=0.0001, # lower timeout will fail all process | ||
) | ||
result = [value for value in result if value["ITSGraph"]] | ||
# Check if the result matches the input data structure | ||
self.assertEqual(len(result), 0) | ||
|
||
def test_process_multiple_hydrogens(self): | ||
"""Test the process_multiple_hydrogens method.""" | ||
graphs = deepcopy(self.data[0]) | ||
react_graph, prod_graph, _ = graphs["ITSGraph"] | ||
|
||
result = ITSHAdjuster.process_multiple_hydrogens( | ||
graphs, react_graph, prod_graph, ignore_aromaticity=False, balance_its=True | ||
) | ||
|
||
for value in result["ITSGraph"]: | ||
self.assertTrue(isinstance(value, nx.Graph)) | ||
for value in result["GraphRules"]: | ||
self.assertTrue(isinstance(value, nx.Graph)) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,4 +13,4 @@ networkx>=3.3 | |
seaborn>=0.13.2 | ||
joblib>=1.3.2 | ||
synrbl>=0.0.25 | ||
synutility>=0.0.12 | ||
synutility>=0.0.13 |
Oops, something went wrong.