-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fix strip context, add graph clustering with batch
- Loading branch information
1 parent
413d37d
commit 7fa09bd
Showing
10 changed files
with
761 additions
and
12 deletions.
There are no files selected for viewing
Binary file not shown.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
import time | ||
import unittest | ||
from synutility.SynIO.data_type import load_from_pickle | ||
from synutility.SynGraph.Descriptor.graph_signature import GraphSignature | ||
from synutility.SynGraph.Cluster.batch_cluster import BatchCluster | ||
|
||
|
||
class TestBatchCluster(unittest.TestCase):
    """Exercise ``BatchCluster`` against the graphs stored in ``Data/Testcase/graph.pkl.gz``."""

    # Template count asserted by every clustering test in this class.
    EXPECTED_TEMPLATE_COUNT = 30

    @classmethod
    def setUpClass(cls):
        """Load the fixture once and attach ``rc_sig``/``its_sig`` signatures to each entry."""
        cls.graphs = load_from_pickle("Data/Testcase/graph.pkl.gz")
        cls.templates = None  # handed to fit() as the initial template argument
        for value in cls.graphs:
            value["rc_sig"] = GraphSignature(value["RC"]).create_graph_signature()
            value["its_sig"] = GraphSignature(value["ITS"]).create_graph_signature()

    def _assert_fit_template_count(self, batch_sizes=(None, 10)):
        """Run ``fit`` once per batch size and check the resulting template count.

        Each batch size runs in its own ``subTest`` so a failure for one size does not
        hide the result for the others. Elapsed time is only printed, never asserted.
        """
        cluster = BatchCluster()
        expected = self.EXPECTED_TEMPLATE_COUNT
        for batch_size in batch_sizes:
            with self.subTest(batch_size=batch_size):
                # perf_counter is monotonic, unlike time.time, so the duration cannot go negative.
                start_time = time.perf_counter()
                _, updated_templates = cluster.fit(
                    self.graphs, self.templates, "RC", "rc_sig", batch_size=batch_size
                )
                elapsed_time = time.perf_counter() - start_time

                self.assertEqual(
                    len(updated_templates),
                    expected,
                    f"Failed for batch_size={batch_size}: expected "
                    f"{expected} templates, got {len(updated_templates)}",
                )
                print(
                    f"Test for batch_size={batch_size} completed in {elapsed_time:.2f} seconds."
                )

    def test_initialization(self):
        """Test initialization and verify if the attributes are set correctly."""
        cluster = BatchCluster(["element", "charge"], ["*", 0], "bond_order")
        self.assertEqual(cluster.nodeLabelNames, ["element", "charge"])
        self.assertEqual(cluster.nodeLabelDefault, ["*", 0])
        self.assertEqual(cluster.edgeAttribute, "bond_order")

    def test_initialization_failure(self):
        """Test initialization failure when lengths of node labels and defaults do not match."""
        with self.assertRaises(ValueError):
            BatchCluster(["element"], ["*", 0, 1], "bond_order")

    def test_batch_dicts(self):
        """Test the batching function: 10 items in batches of 3 give sizes 3, 3, 3, 1."""
        batch_cluster = BatchCluster(["element", "charge"], ["*", 0], "bond_order")
        input_list = [{"id": i} for i in range(10)]
        batches = batch_cluster.batch_dicts(input_list, 3)
        self.assertEqual(len(batches), 4)
        self.assertEqual(len(batches[0]), 3)
        self.assertEqual(len(batches[-1]), 1)

    def test_lib_check_functionality(self):
        """Fit on the first 50 graphs, then feed the rest through ``lib_check`` one at a time."""
        cluster = BatchCluster()
        batch_1 = self.graphs[:50]
        batch_2 = self.graphs[50:]
        _, templates = cluster.fit(batch_1, None, "RC", "rc_sig")
        for entry in batch_2:
            _, templates = cluster.lib_check(entry, templates, "RC", "rc_sig")
        self.assertEqual(len(templates), self.EXPECTED_TEMPLATE_COUNT)

    def test_cluster_integration(self):
        """Test the cluster method starting from an empty template list."""
        cluster = BatchCluster()
        expected_template_count = self.EXPECTED_TEMPLATE_COUNT
        _, updated_templates = cluster.cluster(self.graphs, [], "RC", "rc_sig")

        self.assertEqual(
            len(updated_templates),
            expected_template_count,
            f"Failed: expected {expected_template_count} templates, got {len(updated_templates)}",
        )

    def test_fit(self):
        """Test ``fit`` without batching and with a batch size of 10."""
        self._assert_fit_template_count()

    def test_fit_gml(self):
        """Currently exercises exactly the same inputs as ``test_fit``.

        NOTE(review): the name suggests a GML-encoded rule key was intended, but the
        original body was identical to ``test_fit`` - confirm and specialise if needed.
        """
        self._assert_fit_template_count()
|
||
|
||
# Run the suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,138 @@ | ||
import time | ||
import unittest | ||
from synutility.SynIO.data_type import load_from_pickle | ||
from synutility.SynGraph.Cluster.graph_cluster import GraphCluster | ||
from synutility.SynGraph.Descriptor.graph_descriptors import GraphDescriptor | ||
|
||
|
||
class TestRCCluster(unittest.TestCase):
    """Exercise ``GraphCluster`` against the graphs stored in ``Data/Testcase/graph.pkl.gz``."""

    # Highest class index asserted after fit(): 30 classes numbered from 0.
    EXPECTED_MAX_CLASS = 29

    @classmethod
    def setUpClass(cls):
        """Load the fixture once, compute descriptors per entry and build the clusterer."""
        cls.graphs = load_from_pickle("Data/Testcase/graph.pkl.gz")
        for value in cls.graphs:
            # The return value is deliberately not stored: rebinding the loop variable never
            # updated cls.graphs, so these tests rely on get_descriptors enriching each entry
            # in place (presumably adding "cycle", "signature_rc" and "atom_count" - TODO confirm).
            GraphDescriptor.get_descriptors(value)
        cls.clusterer = GraphCluster()

    def _collect_inputs(self):
        """Return the ``RC`` graphs plus the per-graph attribute lists, keyed by a short label."""
        rc = [value["RC"] for value in self.graphs]
        attributes = {
            "none": None,
            "cycle": [value["cycle"] for value in self.graphs],
            "signature_rc": [value["signature_rc"] for value in self.graphs],
            "atom_count": [value["atom_count"] for value in self.graphs],
        }
        return rc, attributes

    def _assert_fit_assigns_classes(self, rule_key):
        """Fit with ``atom_count`` as attribute and check every item gets a ``class`` index."""
        clustered_data = self.clusterer.fit(
            self.graphs, rule_key=rule_key, attribute_key="atom_count"
        )
        max_class = 0
        for item in clustered_data:
            # Check for the key before reading it so a missing key is reported as a
            # test failure rather than surfacing as a KeyError.
            self.assertIn("class", item)
            max_class = max(max_class, item["class"])
        self.assertEqual(max_class, self.EXPECTED_MAX_CLASS)

    def test_initialization(self):
        """Test the initialization and configuration of the clusterer."""
        self.assertIsInstance(self.clusterer.nodeLabelNames, list)
        self.assertEqual(self.clusterer.edgeAttribute, "order")
        self.assertEqual(
            len(self.clusterer.nodeLabelNames), len(self.clusterer.nodeLabelDefault)
        )

    def test_auto_cluster(self):
        """Test ``iterative_cluster`` with the clusterer's own node and edge matchers."""
        rc, attributes = self._collect_inputs()
        for name, att in attributes.items():
            with self.subTest(attribute=name):
                clusters, graph_to_cluster = self.clusterer.iterative_cluster(
                    rc,
                    att,
                    nodeMatch=self.clusterer.nodeMatch,
                    edgeMatch=self.clusterer.edgeMatch,
                )
                self.assertIsInstance(clusters, list)
                self.assertIsInstance(graph_to_cluster, dict)
                self.assertEqual(len(clusters), 30)

    def test_auto_cluster_wrong_isomorphism(self):
        """Pin the cluster counts obtained when no node or edge matcher is supplied."""
        rc, attributes = self._collect_inputs()
        expected_counts = (
            ("none", 8),  # wrong value
            ("cycle", 8),  # wrong value
            ("atom_count", 27),  # wrong value but almost correct
            ("signature_rc", 30),  # correct by some magic. No proof for this
        )
        for name, expected in expected_counts:
            with self.subTest(attribute=name):
                clusters, _ = self.clusterer.iterative_cluster(
                    rc, attributes[name], nodeMatch=None, edgeMatch=None
                )
                self.assertEqual(len(clusters), expected)

    def test_fit(self):
        """Test the fit method to ensure it correctly updates data entries with cluster indices."""
        self._assert_fit_assigns_classes("RC")

    def test_fit_gml(self):
        """Same check as ``test_fit`` but with the lowercase ``"rc"`` rule key.

        NOTE(review): presumably ``"rc"`` holds a GML encoding of the rule - confirm
        against the fixture schema.
        """
        self._assert_fit_assigns_classes("rc")

    def test_fit_time_compare(self):
        """Time ``fit`` per attribute key and compare the fastest and slowest settings.

        NOTE(review): the final two assertions depend on wall-clock ordering and may be
        flaky on a loaded machine.
        """
        attributes = {
            "None": None,
            "Cycles": "cycle",
            "Signature": "signature_rc",
            "Atom_count": "atom_count",
        }

        results = {}
        for name, attr in attributes.items():
            # perf_counter is monotonic, unlike time.time, so durations stay comparable.
            start_time = time.perf_counter()
            clustered_data = self.clusterer.fit(
                self.graphs, rule_key="RC", attribute_key=attr
            )
            results[name] = time.perf_counter() - start_time

            # Basic verification that 'class' is assigned and max class is as expected
            self.assertTrue(all("class" in item for item in clustered_data))
            max_class = max(item["class"] for item in clustered_data)
            self.assertEqual(max_class, self.EXPECTED_MAX_CLASS)

        # Compare results to check which attribute took the least/most time.
        # The labels must match the keys of `attributes` exactly (case-sensitive).
        min_time_attr = min(results, key=results.get)
        max_time_attr = max(results, key=results.get)
        self.assertIn(min_time_attr, ["Atom_count", "Signature"])
        self.assertIn(max_time_attr, ["None", "Cycles"])
|
||
|
||
# Run the suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
#!/bin/bash

# Lint the repository with flake8 and print a per-error-code summary.
#   --per-file-ignores  F401 (unused import) is tolerated in __init__.py, inference.py and
#                       morphism.py; E501 (line too long) in chemical_reaction_visualizer.py
#                       and test_reagent.py.
#   --exclude           venv, core_engine.py and rule_apply.py are not linted at all.
# Only one --per-file-ignores option is passed: the earlier value without morphism.py
# was the superseded side of the change and would otherwise be given twice.
flake8 . --count --max-complexity=13 --max-line-length=120 \
  --per-file-ignores="__init__.py:F401, chemical_reaction_visualizer.py:E501, test_reagent.py:E501, inference.py:F401, morphism.py:F401" \
  --exclude venv,core_engine.py,rule_apply.py \
  --statistics
Empty file.
Oops, something went wrong.