-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* update * test mod compatible * add graph visualizer * fix lint * add copy right for FGUtils * prepare release * add partial map expansion * add testcase for partial expansion ver1
- Loading branch information
1 parent
8c17232
commit cfb522a
Showing
28 changed files
with
1,542 additions
and
70 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,3 +5,4 @@ | |
*.ipynb | ||
*.json | ||
test_mod.py | ||
test_format.py |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
import unittest | ||
import networkx as nx | ||
from synutility.SynAAM.normalize_aam import NormalizeAAM | ||
|
||
|
||
class TestNormalizeAAM(unittest.TestCase): | ||
def setUp(self): | ||
"""Set up for testing.""" | ||
self.normalizer = NormalizeAAM() | ||
|
||
def test_fix_atom_mapping(self): | ||
"""Test that atom mappings are incremented correctly.""" | ||
input_smiles = "[C:0]([H:1])([H:2])[H:3]" | ||
expected_smiles = "[C:1]([H:2])([H:3])[H:4]" | ||
self.assertEqual( | ||
self.normalizer.fix_atom_mapping(input_smiles), expected_smiles | ||
) | ||
|
||
def test_fix_rsmi(self): | ||
"""Test that RSMI atom mappings are incremented correctly | ||
for both reactants and products.""" | ||
input_rsmi = "[C:0]>>[C:1]" | ||
expected_rsmi = "[C:1]>>[C:2]" | ||
self.assertEqual(self.normalizer.fix_rsmi(input_rsmi), expected_rsmi) | ||
|
||
def test_extract_subgraph(self): | ||
"""Test extraction of a subgraph based on specified indices.""" | ||
g = nx.complete_graph(5) | ||
indices = [0, 1, 2] | ||
subgraph = self.normalizer.extract_subgraph(g, indices) | ||
self.assertEqual(len(subgraph.nodes()), 3) | ||
self.assertTrue(all(node in subgraph for node in indices)) | ||
|
||
def test_reset_indices_and_atom_map(self): | ||
"""Test resetting of indices and atom map in a subgraph.""" | ||
g = nx.path_graph(5) | ||
for i in range(5): | ||
g.nodes[i]["atom_map"] = i + 1 | ||
reset_graph = self.normalizer.reset_indices_and_atom_map(g) | ||
expected_atom_maps = {1: 1, 2: 2, 3: 3, 4: 4, 5: 5} | ||
for node in reset_graph: | ||
self.assertEqual( | ||
reset_graph.nodes[node]["atom_map"], expected_atom_maps[node] | ||
) | ||
|
||
def test_reaction_smiles_processing(self): | ||
"""Test that the reaction SMILES string is processed to meet expected output.""" | ||
input_rsmi = ( | ||
"[C:2]([C:3]([H:9])([H:10])[H:11])([H:8])=[C:1]([C:0]([H:6])([H:5])" | ||
+ "[H:4])[H:7].[H:12][H:13]>>[C:3]([C:2]([C:1]([C:0]([H:6])([H:5])" | ||
+ "[H:4])([H:12])[H:7])([H:8])[H:13])([H:9])([H:10])[H:11]" | ||
) | ||
expected_output = ( | ||
"[CH3:1][CH:2]=[CH:3][CH3:4].[H:5][H:6]>>[CH3:1][CH:2]([CH:3]" | ||
+ "([CH3:4])[H:6])[H:5]" | ||
) | ||
result = self.normalizer.fit(input_rsmi) | ||
self.assertEqual(result, expected_output) | ||
|
||
|
||
# Run the unittest | ||
if __name__ == "__main__": | ||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
import unittest | ||
from synutility.SynAAM.partial_expand import PartialExpand | ||
|
||
|
||
class TestPartialExpand(unittest.TestCase): | ||
def test_expand(self): | ||
""" | ||
Test the expand function of the PartialExpand class with a given RSMI. | ||
""" | ||
# Input RSMI | ||
input_rsmi = "[CH3][CH:1]=[CH2:2].[H:3][H:4]>>[CH3][CH:1]([H:3])[CH2:2][H:4]" | ||
# Expected output | ||
expected_rsmi = ( | ||
"[CH2:1]=[CH:2][CH3:3].[H:4][H:5]>>[CH2:1]([CH:2]([CH3:3])[H:5])[H:4]" | ||
) | ||
# Perform the expansion | ||
output_rsmi = PartialExpand.expand(input_rsmi) | ||
# Assert the result matches the expected output | ||
self.assertEqual(output_rsmi, expected_rsmi) | ||
|
||
def test_expand_2(self): | ||
input_rsmi = "CC[CH2:3][Cl:1].[NH2:2][H:4]>>CC[CH2:3][NH2:2].[Cl:1][H:4]" | ||
output_rsmi = PartialExpand.expand(input_rsmi) | ||
expected_rsmi = ( | ||
"[CH3:1][CH2:2][CH2:3][Cl:4].[NH2:5][H:6]" | ||
+ ">>[CH3:1][CH2:2][CH2:3][NH2:5].[Cl:4][H:6]" | ||
) | ||
self.assertEqual(output_rsmi, expected_rsmi) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
import os | ||
import unittest | ||
import tempfile | ||
from synutility.SynGraph.Transform.core_engine import CoreEngine | ||
|
||
|
||
class TestCoreEngine(unittest.TestCase): | ||
def setUp(self): | ||
# Create a temporary directory | ||
self.temp_dir = tempfile.TemporaryDirectory() | ||
|
||
# Path for the rule file | ||
self.rule_file_path = os.path.join(self.temp_dir.name, "test_rule.gml") | ||
|
||
# Define rule content | ||
self.rule_content = """ | ||
rule [ | ||
ruleID "1" | ||
left [ | ||
edge [ source 1 target 2 label "=" ] | ||
edge [ source 3 target 4 label "-" ] | ||
] | ||
context [ | ||
node [ id 1 label "C" ] | ||
node [ id 2 label "C" ] | ||
node [ id 3 label "H" ] | ||
node [ id 4 label "H" ] | ||
] | ||
right [ | ||
edge [ source 1 target 2 label "-" ] | ||
edge [ source 1 target 3 label "-" ] | ||
edge [ source 2 target 4 label "-" ] | ||
] | ||
] | ||
""" | ||
|
||
# Write rule content to the temporary file | ||
with open(self.rule_file_path, "w") as rule_file: | ||
rule_file.write(self.rule_content) | ||
|
||
# Initialize SMILES strings for testing | ||
self.initial_smiles_fw = ["CC=CC", "[HH]"] | ||
self.initial_smiles_bw = ["CCCC"] | ||
|
||
def tearDown(self): | ||
# Clean up temporary directory | ||
self.temp_dir.cleanup() | ||
|
||
def test_perform_reaction_forward(self): | ||
# Test the perform_reaction method with forward reaction type | ||
result = CoreEngine.perform_reaction( | ||
rule_file_path=self.rule_file_path, | ||
initial_smiles=self.initial_smiles_fw, | ||
prediction_type="forward", | ||
print_results=False, | ||
verbosity=0, | ||
) | ||
print(result) | ||
# Check if result is a list of strings and has content | ||
self.assertIsInstance( | ||
result, list, "Expected a list of reaction SMILES strings." | ||
) | ||
self.assertTrue( | ||
len(result) > 0, "Result should contain reaction SMILES strings." | ||
) | ||
|
||
self.assertEqual(result[0], "CC=CC.[HH]>>CCCC") | ||
|
||
# Check if the result SMILES format matches expected output format | ||
for reaction_smiles in result: | ||
self.assertIn(">>", reaction_smiles, "Reaction SMILES format is incorrect.") | ||
parts = reaction_smiles.split(">>") | ||
self.assertEqual( | ||
parts[0], | ||
".".join(self.initial_smiles_fw), | ||
"Base SMILES are not correctly formatted.", | ||
) | ||
self.assertTrue(len(parts[1]) > 0, "Product SMILES should be non-empty.") | ||
|
||
def test_perform_reaction_backward(self): | ||
# Test the perform_reaction method with backward reaction type | ||
result = CoreEngine.perform_reaction( | ||
rule_file_path=self.rule_file_path, | ||
initial_smiles=self.initial_smiles_bw, | ||
prediction_type="backward", | ||
print_results=False, | ||
verbosity=0, | ||
) | ||
# Check if result is a list of strings and has content | ||
self.assertIsInstance( | ||
result, list, "Expected a list of reaction SMILES strings." | ||
) | ||
self.assertTrue( | ||
len(result) > 0, "Result should contain reaction SMILES strings." | ||
) | ||
self.assertEqual(result[0], "C=CCC.[H][H]>>CCCC") | ||
self.assertEqual(result[1], "[H][H].C(C)=CC>>CCCC") | ||
|
||
# Check if the result SMILES format matches expected output format | ||
for reaction_smiles in result: | ||
self.assertIn(">>", reaction_smiles, "Reaction SMILES format is incorrect.") | ||
parts = reaction_smiles.split(">>") | ||
self.assertTrue(len(parts[0]) > 0, "Product SMILES should be non-empty.") | ||
self.assertEqual( | ||
parts[1], | ||
".".join(self.initial_smiles_bw), | ||
"Base SMILES are not correctly formatted.", | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,7 +4,7 @@ build-backend = "hatchling.build" | |
|
||
[project] | ||
name = "synutility" | ||
version = "0.0.10" | ||
version = "0.0.11" | ||
authors = [ | ||
{name="Tieu Long Phan", email="[email protected]"} | ||
] | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
#!/bin/bash | ||
|
||
pytest Test/SynChem Test/SynAAM Test/SynGraph Test/SynIO Test/SynSplit Test/SynSplit |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,10 @@ | ||
scikit-learn>=1.4.0 | ||
xgboost==2.1.1 | ||
pandas==1.5.3 | ||
seaborn==0.13.2 | ||
drfp==0.3.6 | ||
fgutils>=0.1.3 | ||
rxn-chem-utils==1.5.0 | ||
rxn-utils==2.0.0 | ||
rxnmapper==0.3.0 | ||
rdkit >= 2024.3.3 | ||
rdkit >= 2024.3.3 | ||
pandas>=2.2.0 |
Empty file.
Oops, something went wrong.