Skip to content

Commit

Permalink
fix strip context, add graph clustering with batch
Browse files Browse the repository at this point in the history
  • Loading branch information
TieuLongPhan committed Dec 20, 2024
1 parent 413d37d commit 7fa09bd
Show file tree
Hide file tree
Showing 10 changed files with 761 additions and 12 deletions.
Binary file added Data/Testcase/graph.pkl.gz
Binary file not shown.
Empty file.
109 changes: 109 additions & 0 deletions Test/SynGraph/Cluster/test_batch_cluster.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import time
import unittest
from synutility.SynIO.data_type import load_from_pickle
from synutility.SynGraph.Descriptor.graph_signature import GraphSignature
from synutility.SynGraph.Cluster.batch_cluster import BatchCluster


class TestBatchCluster(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.graphs = load_from_pickle("Data/Testcase/graph.pkl.gz")
cls.templates = None
for value in cls.graphs:
value["rc_sig"] = GraphSignature(value["RC"]).create_graph_signature()
value["its_sig"] = GraphSignature(value["ITS"]).create_graph_signature()

def test_initialization(self):
"""Test initialization and verify if the attributes are set correctly."""
cluster = BatchCluster(["element", "charge"], ["*", 0], "bond_order")
self.assertEqual(cluster.nodeLabelNames, ["element", "charge"])
self.assertEqual(cluster.nodeLabelDefault, ["*", 0])
self.assertEqual(cluster.edgeAttribute, "bond_order")

def test_initialization_failure(self):
"""Test initialization failure when lengths of node labels and defaults do not match."""
with self.assertRaises(ValueError):
BatchCluster(["element"], ["*", 0, 1], "bond_order")

def test_batch_dicts(self):
"""Test the batching function to split data correctly."""
batch_cluster = BatchCluster(["element", "charge"], ["*", 0], "bond_order")
input_list = [{"id": i} for i in range(10)]
batches = batch_cluster.batch_dicts(input_list, 3)
self.assertEqual(len(batches), 4)
self.assertEqual(len(batches[0]), 3)
self.assertEqual(len(batches[-1]), 1)

def test_lib_check_functionality(self):
"""Test the lib_check method using directly comparable results."""
cluster = BatchCluster()
batch_1 = self.graphs[:50]
batch_2 = self.graphs[50:]
_, templates = cluster.fit(batch_1, None, "RC", "rc_sig")
for entry in batch_2:
_, templates = cluster.lib_check(entry, templates, "RC", "rc_sig")
self.assertEqual(len(templates), 30)

def test_cluster_integration(self):
"""Test the cluster method to ensure it processes data entries correctly."""
cluster = BatchCluster()
expected_template_count = 30
_, updated_templates = cluster.cluster(self.graphs, [], "RC", "rc_sig")

self.assertEqual(
len(updated_templates),
expected_template_count,
f"Failed: expected {expected_template_count} templates, got {len(updated_templates)}",
)

def test_fit(self):
cluster = BatchCluster()
batch_sizes = [None, 10]
expected_template_count = 30

for batch_size in batch_sizes:
start_time = time.time()
_, updated_templates = cluster.fit(
self.graphs, self.templates, "RC", "rc_sig", batch_size=batch_size
)
elapsed_time = time.time() - start_time

self.assertEqual(
len(updated_templates),
expected_template_count,
f"Failed for batch_size={batch_size}: expected "
+ f"{expected_template_count} templates, got {len(updated_templates)}",
)
print(
f"Test for batch_size={batch_size} completed in {elapsed_time:.2f} seconds."
)

def test_fit_gml(self):
cluster = BatchCluster()
batch_sizes = [None, 10]
expected_template_count = (
30 # Assuming this is the expected number of templates after processing
)

for batch_size in batch_sizes:
start_time = time.time()
_, updated_templates = cluster.fit(
self.graphs, self.templates, "RC", "rc_sig", batch_size=batch_size
)
elapsed_time = time.time() - start_time

self.assertEqual(
len(updated_templates),
expected_template_count,
f"Failed for batch_size={batch_size}: expected"
+ f" {expected_template_count} templates, got {len(updated_templates)}",
)
print(
f"Test for batch_size={batch_size} completed in {elapsed_time:.2f} seconds."
)


# To run the tests
if __name__ == "__main__":
unittest.main()
138 changes: 138 additions & 0 deletions Test/SynGraph/Cluster/test_graph_cluster.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
import time
import unittest
from synutility.SynIO.data_type import load_from_pickle
from synutility.SynGraph.Cluster.graph_cluster import GraphCluster
from synutility.SynGraph.Descriptor.graph_descriptors import GraphDescriptor


class TestRCCluster(unittest.TestCase):

@classmethod
def setUpClass(cls):
# Load data once for all tests
cls.graphs = load_from_pickle("Data/Testcase/graph.pkl.gz")
for value in cls.graphs:
# value["RC"] = value["GraphRules"][2]
# value["ITS"] = value["ITSGraph"][2]
value = GraphDescriptor.get_descriptors(value)
cls.clusterer = GraphCluster()

def test_initialization(self):
"""Test the initialization and configuration of the RCCluster."""
self.assertIsInstance(self.clusterer.nodeLabelNames, list)
self.assertEqual(self.clusterer.edgeAttribute, "order")
self.assertEqual(
len(self.clusterer.nodeLabelNames), len(self.clusterer.nodeLabelDefault)
)

def test_auto_cluster(self):
"""Test the auto_cluster method functionality."""
rc = [value["RC"] for value in self.graphs]
cycles = [value["cycle"] for value in self.graphs]
signature = [value["signature_rc"] for value in self.graphs]
atom_count = [value["atom_count"] for value in self.graphs]
for att in [None, cycles, signature, atom_count]:
clusters, graph_to_cluster = self.clusterer.iterative_cluster(
rc,
att,
nodeMatch=self.clusterer.nodeMatch,
edgeMatch=self.clusterer.edgeMatch,
)
self.assertIsInstance(clusters, list)
self.assertIsInstance(graph_to_cluster, dict)
self.assertEqual(len(clusters), 30)

def test_auto_cluster_wrong_isomorphism(self):
rc = [value["RC"] for value in self.graphs]
cycles = [value["cycle"] for value in self.graphs]
signature = [value["signature_rc"] for value in self.graphs]
atom_count = [value["atom_count"] for value in self.graphs]

# cluster all
clusters, _ = self.clusterer.iterative_cluster(
rc, None, nodeMatch=None, edgeMatch=None
)
self.assertEqual(len(clusters), 8) # wrong value

# cluster with cycle
clusters, _ = self.clusterer.iterative_cluster(
rc, cycles, nodeMatch=None, edgeMatch=None
)
self.assertEqual(len(clusters), 8) # wrong value

# cluster with atom_count
clusters, _ = self.clusterer.iterative_cluster(
rc, atom_count, nodeMatch=None, edgeMatch=None
)
self.assertEqual(len(clusters), 27) # wrong value but almost correct

# cluster with signature
clusters, _ = self.clusterer.iterative_cluster(
rc, signature, nodeMatch=None, edgeMatch=None
)
self.assertEqual(len(clusters), 30) # correct by some magic. No proof for this

def test_fit(self):
"""Test the fit method to ensure it correctly updates data entries with cluster indices."""

clustered_data = self.clusterer.fit(
self.graphs, rule_key="RC", attribute_key="atom_count"
)
max_class = 0
for item in clustered_data:
print(item["class"])
max_class = item["class"] if item["class"] >= max_class else max_class
# print(max_class)
self.assertIn("class", item)
self.assertEqual(max_class, 29) # 30 classes start from 0 so max is 29

def test_fit_gml(self):
"""Test the fit method to ensure it correctly updates data entries with cluster indices."""

clustered_data = self.clusterer.fit(
self.graphs, rule_key="rc", attribute_key="atom_count"
)
max_class = 0
for item in clustered_data:
print(item["class"])
max_class = item["class"] if item["class"] >= max_class else max_class
# print(max_class)
self.assertIn("class", item)
self.assertEqual(max_class, 29) # 30 classes start from 0 so max is 29

def test_fit_time_compare(self):
attributes = {
"None": None,
"Cycles": "cycle",
"Signature": "signature_rc",
"Atom_count": "atom_count",
}

results = {}
for name, attr in attributes.items():
start_time = time.time()
clustered_data = self.clusterer.fit(
self.graphs, rule_key="RC", attribute_key=attr
)
elapsed_time = time.time() - start_time

# Optionally print out class information or verify correctness
max_class = max(item["class"] for item in clustered_data if "class" in item)

results[name] = elapsed_time

# Basic verification that 'class' is assigned and max class is as expected
self.assertTrue(all("class" in item for item in clustered_data))
self.assertEqual(
max_class, 29
) # Ensure the maximum class index is as expected

# Compare results to check which attribute took the least/most time
min_time_attr = min(results, key=results.get)
max_time_attr = max(results, key=results.get)
self.assertIn(min_time_attr, ["atom_count", "Signature"])
self.assertIn(max_time_attr, ["None", "Cycles"])


if __name__ == "__main__":
unittest.main()
2 changes: 1 addition & 1 deletion lint.sh
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/bin/bash

flake8 . --count --max-complexity=13 --max-line-length=120 \
--per-file-ignores="__init__.py:F401, chemical_reaction_visualizer.py:E501, test_reagent.py:E501, inference.py:F401" \
--per-file-ignores="__init__.py:F401, chemical_reaction_visualizer.py:E501, test_reagent.py:E501, inference.py:F401, morphism.py:F401" \
--exclude venv,core_engine.py,rule_apply.py \
--statistics
Empty file.
Loading

0 comments on commit 7fa09bd

Please sign in to comment.