Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GCN, GAT and GraphSAGE tests and docs #31

Merged
merged 1 commit into from
Dec 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions tests/models/test_gat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import torch
import torch.nn.functional as F
from redkg.models.gat import GAT


def test_GAT():
"""
Test that the GAT model can be instantiated and run.
"""
model = GAT(
in_channels=100,
hidden_channels=200,
out_channels=50,
num_layers=3,
activation=F.relu,
dropout_rate=0.5,
heads=2
)

x = torch.randn(16, 100) # random node feature matrix of shape [num_nodes, in_channels]

# edge_index: COO format graph adjacency matrix, shape [2, num_edges]
edge_index = torch.randint(high=16, size=(2, 48), dtype=torch.long)

output = model(x, edge_index)

# Check the output shape
assert output.shape == (16, 50)

# Check the forward pass
assert not torch.isnan(output).any()
30 changes: 30 additions & 0 deletions tests/models/test_gcn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import torch
import torch.nn.functional as F
from redkg.models.gcn import GCN


def test_GCN():
"""
Test that the GCN model can be instantiated and run.
"""
model = GCN(
in_channels=100,
hidden_channels=200,
out_channels=50,
num_layers=3,
activation=F.relu,
dropout_rate=0.5
)

x = torch.randn(16, 100) # random node feature matrix of shape [num_nodes, in_channels]

# edge_index: COO format graph adjacency matrix, shape [2, num_edges]
edge_index = torch.randint(high=16, size=(2, 48), dtype=torch.long)

output = model(x, edge_index)

# Check the output shape
assert output.shape == (16, 50)

# Check the forward pass
assert not torch.isnan(output).any()
30 changes: 30 additions & 0 deletions tests/models/test_graphsage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import torch
import torch.nn.functional as F
from redkg.models.graphsage import GraphSAGE


def test_GraphSAGE():
"""
Test that the GraphSAGE model can be instantiated and run.
"""
model = GraphSAGE(
in_channels=100,
hidden_channels=200,
out_channels=50,
num_layers=3,
activation=F.relu,
dropout_rate=0.5
)

x = torch.randn(16, 100) # random node feature matrix of shape [num_nodes, in_channels]

# edge_index: COO format graph adjacency matrix, shape [2, num_edges]
edge_index = torch.randint(high=16, size=(2, 48), dtype=torch.long)

output = model(x, edge_index)

# Check the output shape
assert output.shape == (16, 50)

# Check the forward pass
assert not torch.isnan(output).any()
31 changes: 31 additions & 0 deletions tests/test_negative_samples.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from redkg.negative_samples import common_neighbors, generate_negative_samples
import torch


def test_common_neighbors():
"""
Test that the common neighbors are correctly computed.
"""

# node connectivity: 0-1, 1-2, 2-3, 3-0
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]])

neighbors = common_neighbors(edge_index, num_nodes=4)

assert neighbors == {0: {1, 3}, 1: {0, 2}, 2: {1, 3}, 3: {0, 2}}


def test_generate_negative_samples():
"""
Test that the negative samples are correctly generated.
"""

# node connectivity: 0-1, 1-2, 2-3, 3-0
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]])

negative_samples = generate_negative_samples(edge_index, num_nodes=4, num_neg_samples=1)

for ns in negative_samples:
assert len(ns) == 2
assert ns[0] != ns[1]
assert ns not in edge_index.tolist()
63 changes: 63 additions & 0 deletions tests/test_subgraphs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import torch
import pytest
from redkg.generate_subgraphs import generate_subgraphs, generate_subgraphs_dataset


def test_generate_subgraphs():
"""
Test that the generated subgraphs have the correct number of nodes and links.
"""
dataset = {
'nodes': [{'id': i} for i in range(10)],
'links': [{'source': i, 'target': i + 1} for i in range(9)]
}
num_subgraphs = 5
min_nodes = 2
max_nodes = 5

result = generate_subgraphs(dataset, num_subgraphs, min_nodes, max_nodes)

assert len(result) == num_subgraphs
for subgraph in result:
assert min_nodes <= len(subgraph['nodes']) <= max_nodes
for link in subgraph['links']:
assert link['source'] in [node['id'] for node in subgraph['nodes']]
assert link['target'] in [node['id'] for node in subgraph['nodes']]


@pytest.fixture
def large_dataset():
"""
Return a mock dataset with 20 nodes and 5 features.
"""
class MockDataset:
def __init__(self):
self.node_mapping = {i: i for i in range(20)}
self.x = torch.randn(20, 5)
self.y = torch.randn(20, 1)

return MockDataset()


def test_generate_subgraphs_dataset(large_dataset):
"""
Test that the generated subgraphs have the correct number of nodes and links.
"""
subgraphs = [
{
'nodes': [{'id': i} for i in range(5)],
'links': [{'source': i, 'target': i + 1} for i in range(4)]
},
{
'nodes': [{'id': i + 5} for i in range(5)],
'links': [{'source': i + 5, 'target': i + 6} for i in range(4)]
}
]

result = generate_subgraphs_dataset(subgraphs, large_dataset)

assert len(result) == len(subgraphs)
for data in result:
assert data.x.shape == large_dataset.x.shape
assert data.y.shape == large_dataset.y.shape
assert data.edge_index.shape[0] == 2
Loading