From 496265e3973ff48ba4dd80728967c40e6bcbbd1e Mon Sep 17 00:00:00 2001 From: Paul Rock Date: Thu, 14 Dec 2023 03:00:39 +0300 Subject: [PATCH] Tasts package for basic models implemented, some additional logic tested too --- tests/models/test_gat.py | 31 +++++++++++++++++ tests/models/test_gcn.py | 30 ++++++++++++++++ tests/models/test_graphsage.py | 30 ++++++++++++++++ tests/test_negative_samples.py | 31 +++++++++++++++++ tests/test_subgraphs.py | 63 ++++++++++++++++++++++++++++++++++ 5 files changed, 185 insertions(+) create mode 100644 tests/models/test_gat.py create mode 100644 tests/models/test_gcn.py create mode 100644 tests/models/test_graphsage.py create mode 100644 tests/test_negative_samples.py create mode 100644 tests/test_subgraphs.py diff --git a/tests/models/test_gat.py b/tests/models/test_gat.py new file mode 100644 index 0000000..feea0d8 --- /dev/null +++ b/tests/models/test_gat.py @@ -0,0 +1,31 @@ +import torch +import torch.nn.functional as F +from redkg.models.gat import GAT + + +def test_GAT(): + """ + Test that the GAT model can be instantiated and run. + """ + model = GAT( + in_channels=100, + hidden_channels=200, + out_channels=50, + num_layers=3, + activation=F.relu, + dropout_rate=0.5, + heads=2 + ) + + x = torch.randn(16, 100) # random node feature matrix of shape [num_nodes, in_channels] + + # edge_index: COO format graph adjacency matrix, shape [2, num_edges] + edge_index = torch.randint(high=16, size=(2, 48), dtype=torch.long) + + output = model(x, edge_index) + + # Check the output shape + assert output.shape == (16, 50) + + # Check the forward pass + assert not torch.isnan(output).any() diff --git a/tests/models/test_gcn.py b/tests/models/test_gcn.py new file mode 100644 index 0000000..699aa17 --- /dev/null +++ b/tests/models/test_gcn.py @@ -0,0 +1,30 @@ +import torch +import torch.nn.functional as F +from redkg.models.gcn import GCN + + +def test_GCN(): + """ + Test that the GCN model can be instantiated and run. + """ + model = GCN( + in_channels=100, + hidden_channels=200, + out_channels=50, + num_layers=3, + activation=F.relu, + dropout_rate=0.5 + ) + + x = torch.randn(16, 100) # random node feature matrix of shape [num_nodes, in_channels] + + # edge_index: COO format graph adjacency matrix, shape [2, num_edges] + edge_index = torch.randint(high=16, size=(2, 48), dtype=torch.long) + + output = model(x, edge_index) + + # Check the output shape + assert output.shape == (16, 50) + + # Check the forward pass + assert not torch.isnan(output).any() diff --git a/tests/models/test_graphsage.py b/tests/models/test_graphsage.py new file mode 100644 index 0000000..69d7232 --- /dev/null +++ b/tests/models/test_graphsage.py @@ -0,0 +1,30 @@ +import torch +import torch.nn.functional as F +from redkg.models.graphsage import GraphSAGE + + +def test_GraphSAGE(): + """ + Test that the GraphSAGE model can be instantiated and run. + """ + model = GraphSAGE( + in_channels=100, + hidden_channels=200, + out_channels=50, + num_layers=3, + activation=F.relu, + dropout_rate=0.5 + ) + + x = torch.randn(16, 100) # random node feature matrix of shape [num_nodes, in_channels] + + # edge_index: COO format graph adjacency matrix, shape [2, num_edges] + edge_index = torch.randint(high=16, size=(2, 48), dtype=torch.long) + + output = model(x, edge_index) + + # Check the output shape + assert output.shape == (16, 50) + + # Check the forward pass + assert not torch.isnan(output).any() diff --git a/tests/test_negative_samples.py b/tests/test_negative_samples.py new file mode 100644 index 0000000..aa869ab --- /dev/null +++ b/tests/test_negative_samples.py @@ -0,0 +1,31 @@ +from redkg.negative_samples import common_neighbors, generate_negative_samples +import torch + + +def test_common_neighbors(): + """ + Test that the common neighbors are correctly computed. + """ + + # node connectivity: 0-1, 1-2, 2-3, 3-0 + edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]]) + + neighbors = common_neighbors(edge_index, num_nodes=4) + + assert neighbors == {0: {1, 3}, 1: {0, 2}, 2: {1, 3}, 3: {0, 2}} + + +def test_generate_negative_samples(): + """ + Test that the negative samples are correctly generated. + """ + + # node connectivity: 0-1, 1-2, 2-3, 3-0 + edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]]) + + negative_samples = generate_negative_samples(edge_index, num_nodes=4, num_neg_samples=1) + + for ns in negative_samples: + assert len(ns) == 2 + assert ns[0] != ns[1] + assert ns not in edge_index.tolist() diff --git a/tests/test_subgraphs.py b/tests/test_subgraphs.py new file mode 100644 index 0000000..98e5b39 --- /dev/null +++ b/tests/test_subgraphs.py @@ -0,0 +1,63 @@ +import torch +import pytest +from redkg.generate_subgraphs import generate_subgraphs, generate_subgraphs_dataset + + +def test_generate_subgraphs(): + """ + Test that the generated subgraphs have the correct number of nodes and links. + """ + dataset = { + 'nodes': [{'id': i} for i in range(10)], + 'links': [{'source': i, 'target': i + 1} for i in range(9)] + } + num_subgraphs = 5 + min_nodes = 2 + max_nodes = 5 + + result = generate_subgraphs(dataset, num_subgraphs, min_nodes, max_nodes) + + assert len(result) == num_subgraphs + for subgraph in result: + assert min_nodes <= len(subgraph['nodes']) <= max_nodes + for link in subgraph['links']: + assert link['source'] in [node['id'] for node in subgraph['nodes']] + assert link['target'] in [node['id'] for node in subgraph['nodes']] + + +@pytest.fixture +def large_dataset(): + """ + Return a mock dataset with 20 nodes and 5 features. + """ + class MockDataset: + def __init__(self): + self.node_mapping = {i: i for i in range(20)} + self.x = torch.randn(20, 5) + self.y = torch.randn(20, 1) + + return MockDataset() + + +def test_generate_subgraphs_dataset(large_dataset): + """ + Test that the generated subgraphs have the correct number of nodes and links. + """ + subgraphs = [ + { + 'nodes': [{'id': i} for i in range(5)], + 'links': [{'source': i, 'target': i + 1} for i in range(4)] + }, + { + 'nodes': [{'id': i + 5} for i in range(5)], + 'links': [{'source': i + 5, 'target': i + 6} for i in range(4)] + } + ] + + result = generate_subgraphs_dataset(subgraphs, large_dataset) + + assert len(result) == len(subgraphs) + for data in result: + assert data.x.shape == large_dataset.x.shape + assert data.y.shape == large_dataset.y.shape + assert data.edge_index.shape[0] == 2