From 49c17af517f413b5874f1588f256ad4ddc485cb8 Mon Sep 17 00:00:00 2001 From: Ronald Xu Date: Fri, 10 Jan 2025 03:45:47 +0000 Subject: [PATCH] address comments --- python/graphstorm/model/edge_decoder.py | 16 ++- python/graphstorm/model/node_decoder.py | 6 + tests/unit-tests/data_utils.py | 26 ++++ tests/unit-tests/test_decoder.py | 167 ++++++++++++------------ 4 files changed, 126 insertions(+), 89 deletions(-) diff --git a/python/graphstorm/model/edge_decoder.py b/python/graphstorm/model/edge_decoder.py index 18dfb612fb..b2ece2150d 100644 --- a/python/graphstorm/model/edge_decoder.py +++ b/python/graphstorm/model/edge_decoder.py @@ -134,6 +134,9 @@ class DenseBiDecoder(GSEdgeDecoder): implementation. Default: None. use_bias: bool Whether the edge decoder uses a bias parameter. Default: True. + + .. versionchanged:: 0.4.0 + Add a new argument "use_bias" so users can control whether decoders have bias. """ def __init__(self, in_units, @@ -337,6 +340,9 @@ class EdgeRegression(GSEdgeDecoder): implementation. Default: None. use_bias: bool Whether the edge decoder uses a bias parameter. Default: True. + + .. versionchanged:: 0.4.0 + Add a new argument "use_bias" so users can control whether decoders have bias. """ def __init__(self, h_dim, @@ -510,6 +516,9 @@ class MLPEdgeDecoder(GSEdgeDecoder): implementation. Default: None. use_bias: bool Whether the edge decoder uses a bias parameter. Default: True. + + .. versionchanged:: 0.4.0 + Add a new argument "use_bias" so users can control whether decoders have bias. """ def __init__(self, h_dim, @@ -575,8 +584,8 @@ def _compute_logits(self, g, h): The dictionary containing the embeddings Returns ------- - th.Tensor - Output of forward + out + Output of forward. """ with g.local_scope(): u, v = g.edges(etype=self.target_etype) @@ -724,6 +733,9 @@ class MLPEFeatEdgeDecoder(MLPEdgeDecoder): class implementation. Default: None. use_bias: bool Whether the edge decoder uses a bias parameter. Default: True. + + .. 
versionchanged:: 0.4.0 + Add a new argument "use_bias" so users can control whether decoders have bias. """ def __init__(self, h_dim, diff --git a/python/graphstorm/model/node_decoder.py b/python/graphstorm/model/node_decoder.py index 4c2f9fd4c9..1e735f664b 100644 --- a/python/graphstorm/model/node_decoder.py +++ b/python/graphstorm/model/node_decoder.py @@ -40,6 +40,9 @@ class EntityClassifier(GSLayer): implementation. Default: None. use_bias: bool Whether the node decoder uses a bias parameter. Default: True. + + .. versionchanged:: 0.4.0 + Add a new argument "use_bias" so users can control whether decoders have bias. """ def __init__(self, in_dim, @@ -171,6 +174,9 @@ class EntityRegression(GSLayer): implementation. Default: None. use_bias: bool Whether the node decoder uses a bias parameter. Default: True. + + .. versionchanged:: 0.4.0 + Add a new argument "use_bias" so users can control whether decoders have bias. """ def __init__(self, h_dim, diff --git a/tests/unit-tests/data_utils.py b/tests/unit-tests/data_utils.py index 00b2b6be27..8648102975 100644 --- a/tests/unit-tests/data_utils.py +++ b/tests/unit-tests/data_utils.py @@ -46,6 +46,32 @@ def generate_mask(idx, length): th_mask = th.tensor(mask, dtype=th.bool) return th_mask +def generate_dummy_constant_graph(in_units): + """ + Generate a dummy heterogeneous graph to test edge decoders. + + Returns + ------- + g: a heterogeneous graph. + + h: constant (all-ones) node embeddings. + + edge_type: the canonical edge type ("n0", "r0", "n1") + """ + u = th.tensor([0, 0]) + v = th.tensor([1, 2]) + edge_type = ("n0", "r0", "n1") + g = dgl.heterograph({ + edge_type: (u, v) + }) + + h = { + "n0": th.ones(g.num_nodes("n0"), in_units), + "n1": th.ones(g.num_nodes("n1"), in_units) + } + + return g, h, edge_type + def generate_dummy_hetero_graph_for_efeat_gnn(is_random=True): """ generate a dummy heterogeneous graph to test the get_edge_feat_size() method. 
diff --git a/tests/unit-tests/test_decoder.py b/tests/unit-tests/test_decoder.py index 3028d8330f..ea3895f250 100644 --- a/tests/unit-tests/test_decoder.py +++ b/tests/unit-tests/test_decoder.py @@ -44,7 +44,7 @@ from numpy.testing import assert_equal -from data_utils import generate_dummy_hetero_graph +from data_utils import generate_dummy_hetero_graph, generate_dummy_constant_graph def _check_scores(score, pos_score, neg_scores, etype, num_neg, batch_size): # pos scores @@ -990,80 +990,72 @@ def test_MLPEFeatEdgeDecoder(h_dim, feat_dim, out_dim, num_ffn_layers): @pytest.mark.parametrize("num_classes", [4, 8]) def test_DenseBiDecoder(in_units, num_classes): - u = th.tensor([0, 0]) - v = th.tensor([1, 2]) - edge_type = ("n0", "r0", "n1") - g = dgl.heterograph({ - edge_type: (u, v) - }) + g, h, edge_type = generate_dummy_constant_graph(in_units) - h = { - "n0": th.ones(g.num_nodes("n0"), in_units), - "n1": th.ones(g.num_nodes("n1"), in_units) - } - - # Test bias doesn't exist on combine_basis nn.Linear - decoder = DenseBiDecoder( - in_units=in_units, - num_classes=num_classes, - multilabel=False, - target_etype=edge_type, - use_bias=False - ) - assert not decoder.combine_basis.bias - - # Test classification by tricking decoder - decoder = DenseBiDecoder( - in_units=in_units, - num_classes=num_classes, - multilabel=False, - target_etype=edge_type, - use_bias=True - ) - - assert decoder.in_dims == in_units - assert decoder.out_dims == num_classes - assert not hasattr(decoder, "regression_head") - - decoder.eval() with th.no_grad(): - INCREMENT_VALUE = 10 # Trick the decoder to predict a specific class - - # Test classification when bias = 0 + # We trick the decoder to predict TARGET_CLASS by adding + # INCREMENT_VALUE at TARGET_CLASS' index in the combine_basis weight matrix. + # Because basis matrices are set to identity, the outputs of the bases are all + # equal. 
So, by modifying the weight of combine_basis, we force the decoder + # to predict TARGET_CLASS. + INCREMENT_VALUE = 100 TARGET_CLASS = 2 - th.nn.init.ones_(decoder.basis_para) - th.nn.init.ones_(decoder.combine_basis.weight) - th.nn.init.zeros_(decoder.combine_basis.bias) + decoder = DenseBiDecoder( + in_units=in_units, + num_classes=num_classes, + multilabel=False, + target_etype=edge_type, + use_bias=False + ) + decoder.eval() + + # Test bias doesn't exist on combine_basis nn.Linear + assert decoder.combine_basis.bias is None + + for i in range(decoder.num_basis): + decoder.basis_para[i, :, :] = th.eye(in_units) + th.nn.init.eye_(decoder.combine_basis.weight) decoder.combine_basis.weight[TARGET_CLASS][0] += INCREMENT_VALUE # Trick decoder prediction = decoder.predict(g, h) assert th.all(prediction == TARGET_CLASS) + assert prediction.shape[0] == g.num_edges() + # Test classification with nonzero bias + # Same approach as above, but this time we modify the bias instead of the + # combine_basis weight to force the decoder to predict TARGET_CLASS. 
TARGET_CLASS = 3 - th.nn.init.ones_(decoder.basis_para) - th.nn.init.ones_(decoder.combine_basis.weight) + decoder = DenseBiDecoder( + in_units=in_units, + num_classes=num_classes, + multilabel=False, + target_etype=edge_type, + use_bias=True + ) + decoder.eval() + + assert decoder.combine_basis.bias is not None + + assert decoder.in_dims == in_units + assert decoder.out_dims == num_classes + assert not hasattr(decoder, "regression_head") + + for i in range(decoder.num_basis): + decoder.basis_para[i, :, :] = th.eye(in_units) + th.nn.init.eye_(decoder.combine_basis.weight) th.nn.init.zeros_(decoder.combine_basis.bias) decoder.combine_basis.bias[TARGET_CLASS] += INCREMENT_VALUE # Trick decoder prediction = decoder.predict(g, h) assert th.all(prediction == TARGET_CLASS) + assert prediction.shape[0] == g.num_edges() @pytest.mark.parametrize("in_dim", [16, 64]) @pytest.mark.parametrize("out_dim", [1, 8]) def test_EdgeRegression(in_dim, out_dim): - u = th.tensor([0, 0]) - v = th.tensor([1, 2]) - edge_type = ("n0", "r0", "n1") - g = dgl.heterograph({ - edge_type: (u, v) - }) - - h = { - "n0": th.ones(g.num_nodes("n0"), in_dim), - "n1": th.ones(g.num_nodes("n1"), in_dim) - } + g, h, edge_type = generate_dummy_constant_graph(in_dim) # Test bias doesn't exist on linear layer decoder = EdgeRegression( @@ -1073,7 +1065,7 @@ def test_EdgeRegression(in_dim, out_dim): use_bias=False ) - assert not decoder.linear.bias + assert decoder.linear.bias is None # Test cases when bias exists (zero and nonzero) decoder = EdgeRegression( @@ -1085,9 +1077,12 @@ def test_EdgeRegression(in_dim, out_dim): assert decoder.in_dims == in_dim assert decoder.out_dims == out_dim + assert decoder.linear.bias is not None decoder.eval() with th.no_grad(): + # By setting weights to identity matrices and bias to 0, we should get all 1s + # in the output because our inputs are all 1s. 
th.nn.init.eye_(decoder.linear.weight) th.nn.init.eye_(decoder.regression_head.weight) th.nn.init.zeros_(decoder.linear.bias) @@ -1096,8 +1091,9 @@ def test_EdgeRegression(in_dim, out_dim): # Test regression output, should be all 1s because of identity matrix weights and 0 bias. prediction = decoder.predict(g, h) assert th.all(prediction == 1) + assert prediction.shape == (g.num_edges(), out_dim) - # Test non-zero bias, should be all equal to TEST_BIAS_VALUE+1. + # Test non-zero bias, the output should be all equal to TEST_BIAS_VALUE+1. TEST_BIAS_VALUE = 7 th.nn.init.constant_(decoder.linear.bias, TEST_BIAS_VALUE) th.nn.init.eye_(decoder.linear.weight) @@ -1106,23 +1102,14 @@ def test_EdgeRegression(in_dim, out_dim): prediction = decoder.predict(g, h) assert th.all(prediction == TEST_BIAS_VALUE+1) + assert prediction.shape == (g.num_edges(), out_dim) @pytest.mark.parametrize("in_dim", [16, 64]) @pytest.mark.parametrize("out_dim", [4, 8]) @pytest.mark.parametrize("num_ffn_layers", [0, 2]) def test_MLPEdgeDecoder(in_dim, out_dim, num_ffn_layers): - u = th.tensor([0, 0]) - v = th.tensor([1, 2]) - edge_type = ("n0", "r0", "n1") - g = dgl.heterograph({ - edge_type: (u, v) - }) - - h = { - "n0": th.ones(g.num_nodes("n0"), in_dim), - "n1": th.ones(g.num_nodes("n1"), in_dim) - } + g, h, edge_type = generate_dummy_constant_graph(in_dim) # Test classification # Test bias doesn't exist on decoder @@ -1137,7 +1124,6 @@ def test_MLPEdgeDecoder(in_dim, out_dim, num_ffn_layers): assert not hasattr(decoder, "bias") assert not hasattr(decoder, "regression_head") - # Test classification by tricking decoder decoder = MLPEdgeDecoder( h_dim=in_dim, out_dim=out_dim, @@ -1155,7 +1141,12 @@ def test_MLPEdgeDecoder(in_dim, out_dim, num_ffn_layers): decoder.eval() with th.no_grad(): - INCREMENT_VALUE = 10 # Trick the decoder to predict a specific class + # We trick the decoder to predict TARGET_CLASS by adding + # INCREMENT_VALUE at TARGET_CLASS' index in the decoder weight matrix. 
+ # Because all layers of the MLP are set to identity, the outputs will + # be the same as the inputs, all 1s. So, by modifying the weight of the + # decoder, we force it to predict TARGET_CLASS. + INCREMENT_VALUE = 10 # Test classification when bias = 0 TARGET_CLASS = 2 @@ -1167,6 +1158,7 @@ def test_MLPEdgeDecoder(in_dim, out_dim, num_ffn_layers): prediction = decoder.predict(g, h) assert th.all(prediction == TARGET_CLASS) + assert prediction.shape[0] == g.num_edges() # Test classification with nonzero bias TARGET_CLASS = 3 @@ -1178,6 +1170,7 @@ def test_MLPEdgeDecoder(in_dim, out_dim, num_ffn_layers): prediction = decoder.predict(g, h) assert th.all(prediction == TARGET_CLASS) + assert prediction.shape[0] == g.num_edges() # Test regression @@ -1221,6 +1214,7 @@ def test_MLPEdgeDecoder(in_dim, out_dim, num_ffn_layers): th.nn.init.zeros_(decoder.regression_head.bias) prediction = decoder.predict(g, h) assert th.all(prediction == 1) + assert prediction.shape == (g.num_edges(), 1) # Test non-zero bias, should be all equal to TEST_BIAS_VALUE+1 because input is 1s. 
# Set up MLP for testing @@ -1233,12 +1227,13 @@ def test_MLPEdgeDecoder(in_dim, out_dim, num_ffn_layers): th.nn.init.constant_(decoder.bias, TEST_BIAS_VALUE) prediction = decoder.predict(g, h) assert th.all(prediction == TEST_BIAS_VALUE+1) + assert prediction.shape == (g.num_edges(), 1) @pytest.mark.parametrize("in_dim", [16, 64]) @pytest.mark.parametrize("out_dim", [4, 8]) @pytest.mark.parametrize("feat_dim", [8, 32]) @pytest.mark.parametrize("num_ffn_layers", [0, 2]) -def test_MLPEFeatEdgeDecoder_hardcoded(in_dim, out_dim, feat_dim, num_ffn_layers): +def test_MLPEFeatDecoder_Constant_Inputs(in_dim, out_dim, feat_dim, num_ffn_layers): def prepareMLPEFeatEdgeDecoder(decoder: MLPEFeatEdgeDecoder): th.nn.init.eye_(decoder.nn_decoder) @@ -1248,17 +1243,7 @@ def prepareMLPEFeatEdgeDecoder(decoder: MLPEFeatEdgeDecoder): th.nn.init.eye_(decoder.combine_decoder) th.nn.init.eye_(decoder.decoder) - u = th.tensor([0, 0]) - v = th.tensor([1, 2]) - edge_type = ("n0", "r0", "n1") - g = dgl.heterograph({ - edge_type: (u, v) - }) - - h = { - "n0": th.ones(g.num_nodes("n0"), in_dim), - "n1": th.ones(g.num_nodes("n1"), in_dim) - } + g, h, edge_type = generate_dummy_constant_graph(in_dim) efeat = {edge_type: th.ones(g.num_edges(edge_type), feat_dim)} @@ -1276,7 +1261,6 @@ def prepareMLPEFeatEdgeDecoder(decoder: MLPEFeatEdgeDecoder): assert not hasattr(decoder, "bias") assert not hasattr(decoder, "regression_head") - # Test classification by tricking decoder decoder = MLPEFeatEdgeDecoder( h_dim=in_dim, feat_dim=feat_dim, @@ -1295,6 +1279,11 @@ def prepareMLPEFeatEdgeDecoder(decoder: MLPEFeatEdgeDecoder): decoder.eval() with th.no_grad(): + # We trick the decoder to predict TARGET_CLASS by adding + # INCREMENT_VALUE at TARGET_CLASS' index in the decoder weight matrix. + # Because all layers of the MLP, edge feature decoder, and nn decoder + # are set to identity, the outputs will be the same as the inputs, all 1s. 
+ # So, by modifying the weight of the decoder, we force it to predict TARGET_CLASS. INCREMENT_VALUE = 10 # Trick the decoder to predict a specific class # Test classification when bias = 0 @@ -1304,6 +1293,7 @@ def prepareMLPEFeatEdgeDecoder(decoder: MLPEFeatEdgeDecoder): prediction = decoder.predict(g, h, efeat) assert th.all(prediction == TARGET_CLASS) + assert prediction.shape[0] == g.num_edges() # Test classification with nonzero bias TARGET_CLASS = 3 @@ -1312,6 +1302,7 @@ def prepareMLPEFeatEdgeDecoder(decoder: MLPEFeatEdgeDecoder): prediction = decoder.predict(g, h, efeat) assert th.all(prediction == TARGET_CLASS) + assert prediction.shape[0] == g.num_edges() # Test regression @@ -1354,6 +1345,7 @@ def prepareMLPEFeatEdgeDecoder(decoder: MLPEFeatEdgeDecoder): th.nn.init.zeros_(decoder.regression_head.bias) prediction = decoder.predict(g, h, efeat) assert th.all(prediction == 1) + assert prediction.shape == (g.num_edges(), 1) # Test non-zero bias, should be all equal to TEST_BIAS_VALUE+1 because input is 1s. 
# Set up MLP for testing @@ -1364,6 +1356,7 @@ def prepareMLPEFeatEdgeDecoder(decoder: MLPEFeatEdgeDecoder): th.nn.init.constant_(decoder.bias, TEST_BIAS_VALUE) prediction = decoder.predict(g, h, efeat) assert th.all(prediction == TEST_BIAS_VALUE+1) + assert prediction.shape == (g.num_edges(), 1) @pytest.mark.parametrize("in_dim", [16, 64]) @pytest.mark.parametrize("out_dim", [1, 8]) @@ -1469,10 +1462,10 @@ def test_EntityClassifier(in_dim, num_classes): test_MLPEFeatEdgeDecoder(16,8,2,0) test_MLPEFeatEdgeDecoder(16,32,2,2) - test_MLPEFeatEdgeDecoder_hardcoded(16,4,8,0) - test_MLPEFeatEdgeDecoder_hardcoded(16,8,32,0) - test_MLPEFeatEdgeDecoder_hardcoded(64,8,8,2) - test_MLPEFeatEdgeDecoder_hardcoded(64,4,32,2) + test_MLPEFeatDecoder_Constant_Inputs(16,4,8,0) + test_MLPEFeatDecoder_Constant_Inputs(16,8,32,0) + test_MLPEFeatDecoder_Constant_Inputs(64,8,8,2) + test_MLPEFeatDecoder_Constant_Inputs(64,4,32,2) test_DenseBiDecoder(16, 4) test_DenseBiDecoder(16, 8)