diff --git a/python/graphstorm/config/argument.py b/python/graphstorm/config/argument.py
index ddb933fc13..9efcc1082e 100644
--- a/python/graphstorm/config/argument.py
+++ b/python/graphstorm/config/argument.py
@@ -1752,15 +1752,15 @@ def log_report_frequency(self):
     ###################### Model training related ######################
     @property
     def decoder_bias(self):
-        """ Node decoder bias. decoder_bias must be a boolean. Default is False.
+        """ Decoder bias. decoder_bias must be a boolean. Default is True.
         """
         # pylint: disable=no-member
        if hasattr(self, "_decoder_bias"):
             assert self._decoder_bias in [True, False], \
                 "decoder_bias should be in [True, False]"
             return self._decoder_bias
-        # By default, node decoder bias is False
-        return False
+        # By default, decoder bias is True
+        return True
 
     @property
     def dropout(self):
diff --git a/python/graphstorm/gsf.py b/python/graphstorm/gsf.py
index fa1ed4189c..9202025e8a 100644
--- a/python/graphstorm/gsf.py
+++ b/python/graphstorm/gsf.py
@@ -439,7 +439,8 @@ def create_builtin_reconstruct_efeat_decoder(g, decoder_input_dim, config, train
         decoder = EdgeRegression(decoder_input_dim,
                                  target_etype=target_etype,
                                  out_dim=feat_dim,
-                                 dropout=dropout)
+                                 dropout=dropout,
+                                 use_bias=config.decoder_bias)
 
     loss_func = RegressionLossFunc()
     return decoder, loss_func
@@ -617,14 +618,16 @@ def create_builtin_edge_decoder(g, decoder_input_dim, config, train_task):
                                dropout_rate=dropout,
                                regression=False,
                                target_etype=target_etype,
-                               norm=config.decoder_norm)
+                               norm=config.decoder_norm,
+                               use_bias=config.decoder_bias)
     elif decoder_type == "MLPDecoder":
         decoder = MLPEdgeDecoder(decoder_input_dim,
                                  num_classes,
                                  multilabel=config.multilabel,
                                  target_etype=target_etype,
                                  num_ffn_layers=config.num_ffn_layers_in_decoder,
-                                 norm=config.decoder_norm)
+                                 norm=config.decoder_norm,
+                                 use_bias=config.decoder_bias)
     elif decoder_type == "MLPEFeatEdgeDecoder":
         decoder_edge_feat = config.decoder_edge_feat
         assert decoder_edge_feat is not None, \
@@ -648,7 +651,8 @@ def create_builtin_edge_decoder(g, decoder_input_dim, config, train_task):
                 target_etype=target_etype,
                 dropout=config.dropout,
                 num_ffn_layers=config.num_ffn_layers_in_decoder,
-                norm=config.decoder_norm)
+                norm=config.decoder_norm,
+                use_bias=config.decoder_bias)
         else:
             assert False, f"decoder {decoder_type} is not supported."
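For reference (not part of the patch itself): the pattern the gsf.py changes above follow is to thread the config-level `decoder_bias` flag into every decoder constructor, where it ultimately toggles either an nn.Linear bias or an explicit bias parameter. A minimal sketch of that wiring, using a hypothetical ToyEdgeDecoder rather than real GraphStorm code:

    # Illustrative sketch only; the class and its names are hypothetical.
    import torch.nn as nn

    class ToyEdgeDecoder(nn.Module):
        def __init__(self, in_dim, num_classes, use_bias=True):
            super().__init__()
            # bias=use_bias toggles the layer's additive bias parameter
            self.proj = nn.Linear(in_dim, num_classes, bias=use_bias)

        def forward(self, h):
            return self.proj(h)

    # mirrors decoder = ...(..., use_bias=config.decoder_bias) above
    decoder = ToyEdgeDecoder(in_dim=16, num_classes=4, use_bias=True)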
@@ -680,7 +684,8 @@ def create_builtin_edge_decoder(g, decoder_input_dim, config, train_task):
                                target_etype=target_etype,
                                dropout_rate=dropout,
                                regression=True,
-                               norm=config.decoder_norm)
+                               norm=config.decoder_norm,
+                               use_bias=config.decoder_bias)
     elif decoder_type == "MLPDecoder":
         decoder = MLPEdgeDecoder(decoder_input_dim,
                                  1,
@@ -688,7 +693,8 @@ def create_builtin_edge_decoder(g, decoder_input_dim, config, train_task):
                                  target_etype=target_etype,
                                  regression=True,
                                  num_ffn_layers=config.num_ffn_layers_in_decoder,
-                                 norm=config.decoder_norm)
+                                 norm=config.decoder_norm,
+                                 use_bias=config.decoder_bias)
     elif decoder_type == "MLPEFeatEdgeDecoder":
         decoder_edge_feat = config.decoder_edge_feat
         assert decoder_edge_feat is not None, \
@@ -713,7 +719,8 @@ def create_builtin_edge_decoder(g, decoder_input_dim, config, train_task):
                 dropout=config.dropout,
                 regression=True,
                 num_ffn_layers=config.num_ffn_layers_in_decoder,
-                norm=config.decoder_norm)
+                norm=config.decoder_norm,
+                use_bias=config.decoder_bias)
         else:
             assert False, "decoder not supported"
     loss_func = RegressionLossFunc()
diff --git a/python/graphstorm/model/__init__.py b/python/graphstorm/model/__init__.py
index 6630a58b8c..1702114209 100644
--- a/python/graphstorm/model/__init__.py
+++ b/python/graphstorm/model/__init__.py
@@ -61,7 +61,8 @@
                            LinkPredictWeightedRotatEDecoder,
                            LinkPredictTransEDecoder,
                            LinkPredictContrastiveTransEDecoder,
-                           LinkPredictWeightedTransEDecoder)
+                           LinkPredictWeightedTransEDecoder,
+                           EdgeRegression)
 
 from .gnn_encoder_base import GraphConvEncoder
diff --git a/python/graphstorm/model/edge_decoder.py b/python/graphstorm/model/edge_decoder.py
index 5065283991..bbc30a7f3c 100644
--- a/python/graphstorm/model/edge_decoder.py
+++ b/python/graphstorm/model/edge_decoder.py
@@ -132,6 +132,11 @@ class DenseBiDecoder(GSEdgeDecoder):
     norm: str
         Normalization methods. Not used, but reserved for complex DenseBiDecoder child
         class implementation. Default: None.
+    use_bias: bool
+        Whether the edge decoder uses a bias parameter. Default: True.
+
+    .. versionchanged:: 0.4.0
+        Added the argument "use_bias" so users can control whether decoders have bias.
     """
     def __init__(self,
                  in_units,
@@ -141,7 +146,8 @@ def __init__(self,
                  num_basis=2,
                  dropout_rate=0.0,
                  regression=False,
-                 norm=None):
+                 norm=None,
+                 use_bias=True):
         super().__init__()
         self.in_units = in_units
@@ -157,6 +163,7 @@ def __init__(self,
         assert isinstance(target_etype, tuple) and len(target_etype) == 3, \
             "Target etype must be a tuple of a canonical etype."
         self.target_etype = target_etype
+        self.use_bias = use_bias
 
         self._init_model()
 
@@ -171,7 +178,7 @@ def _init_model(self):
         self.dropout = nn.Dropout(self.dropout)
         self.basis_para = nn.Parameter(
             th.randn(self.num_basis, self.in_units, self.in_units))
-        self.combine_basis = nn.Linear(self.num_basis, basis_out, bias=False)
+        self.combine_basis = nn.Linear(self.num_basis, basis_out, bias=self.use_bias)
         self.reset_parameters()
 
         if self.regression:
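A note on what the `combine_basis` change above affects: DenseBiDecoder scores each edge with one bilinear form per basis and then mixes the per-basis scores through `combine_basis`, so the new bias term lands directly on the class logits. A rough sketch of that scoring path, with the einsum spelling and the all-ones inputs assumed for illustration:

    # Rough sketch of the bilinear-basis scoring; simplified, not the real class.
    import torch as th
    import torch.nn as nn

    in_units, num_basis, num_classes, num_edges = 16, 2, 4, 3
    basis_para = th.randn(num_basis, in_units, in_units)
    combine_basis = nn.Linear(num_basis, num_classes, bias=True)  # use_bias=True

    ue = th.ones(num_edges, in_units)  # source-node embeddings
    ve = th.ones(num_edges, in_units)  # destination-node embeddings
    # one bilinear score per (edge, basis): ue^T @ W_b @ ve
    scores = th.einsum('ai,bij,aj->ab', ue, basis_para, ve)
    logits = combine_basis(scores)     # the bias is added to the logits here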
""" def __init__(self, h_dim, target_etype, out_dim=1, dropout=0, - norm=None): + norm=None, + use_bias=True): super(EdgeRegression, self).__init__() self._h_dim = h_dim self._out_dim = out_dim @@ -349,6 +359,7 @@ def __init__(self, "Target etype must be a tuple of a canonical etype," \ f"e.g., (src_ntype, etype, dst_ntype), but got {target_etype}." self._target_etype = target_etype + self._use_bias = use_bias self._init_model() @@ -360,7 +371,7 @@ def _init_model(self): if self._norm is not None: logging.warning("Embedding normalization (batch norm or layer norm) " "is not supported in EdgeRegression") - self.linear = nn.Linear(h_dim * 2, h_dim, bias=True) + self.linear = nn.Linear(h_dim * 2, h_dim, bias=self._use_bias) self.relu = nn.ReLU() self.dropout = nn.Dropout(self._dropout) self.regression_head = nn.Linear(h_dim, out_dim, bias=True) @@ -500,6 +511,11 @@ class MLPEdgeDecoder(GSEdgeDecoder): norm: str Normalization methods. Not used, but reserved for complex MLPEdgeDecoder child class implementation. Default: None. + use_bias: bool + Whether the edge decoder uses a bias parameter. Default: True. + + .. versionchanged:: 0.4.0 + Add a new argument "use_bias" so users can control whether decoders have bias. """ def __init__(self, h_dim, @@ -510,7 +526,8 @@ def __init__(self, dropout=0, regression=False, num_ffn_layers=0, - norm=None): + norm=None, + use_bias=True): super(MLPEdgeDecoder, self).__init__() self.h_dim = h_dim self.multilabel = multilabel @@ -526,6 +543,7 @@ def __init__(self, assert isinstance(target_etype, tuple) and len(target_etype) == 3, \ "Target etype must be a tuple of a canonical etype." self.target_etype = target_etype + self.use_bias = use_bias self._init_model() @@ -543,6 +561,8 @@ def _init_model(self): # Here we assume the source and destination nodes have the same dimension. self.decoder = nn.Parameter(th.randn(self.h_dim * 2, self.out_dim)) + if self.use_bias: + self.bias = nn.Parameter(th.zeros(self.out_dim)) assert self.num_hidden_layers == 1, "More than one layers not supported" nn.init.xavier_uniform_(self.decoder, gain=nn.init.calculate_gain('relu')) @@ -561,8 +581,8 @@ def _compute_logits(self, g, h): The dictionary containing the embeddings Returns ------- - th.Tensor - Output of forward + out: th.Tensor + Output of forward. """ with g.local_scope(): u, v = g.edges(etype=self.target_etype) @@ -574,6 +594,8 @@ def _compute_logits(self, g, h): if self.num_ffn_layers > 0: h = self.ngnn_mlp(h) out = th.matmul(h, self.decoder) + if self.use_bias: + out = out + self.bias return out # pylint: disable=unused-argument @@ -706,6 +728,11 @@ class MLPEFeatEdgeDecoder(MLPEdgeDecoder): norm: str Normalization methods. Not used, but reserved for complex MLPEFeatEdgeDecoder child class implementation. Default: None. + use_bias: bool + Whether the edge decoder uses a bias parameter. Default: True. + + .. versionchanged:: 0.4.0 + Add a new argument "use_bias" so users can control whether decoders have bias. 
""" def __init__(self, h_dim, @@ -716,7 +743,8 @@ def __init__(self, dropout=0, regression=False, num_ffn_layers=0, - norm=None): + norm=None, + use_bias=True): self.feat_dim = feat_dim super(MLPEFeatEdgeDecoder, self).__init__(h_dim=h_dim, out_dim=out_dim, @@ -725,7 +753,8 @@ def __init__(self, dropout=dropout, regression=regression, num_ffn_layers=num_ffn_layers, - norm=norm) + norm=norm, + use_bias=use_bias) def _init_model(self): """ Init decoder model @@ -747,6 +776,8 @@ def _init_model(self): # combine output of nn_decoder and feat_decoder self.combine_decoder = nn.Parameter(th.randn(self.h_dim * 2, self.h_dim)) self.decoder = nn.Parameter(th.randn(self.h_dim, self.out_dim)) + if self.use_bias: + self.bias = nn.Parameter(th.zeros(self.out_dim)) self.dropout = nn.Dropout(self.dropout) self.nn_decoder_norm = None @@ -784,10 +815,12 @@ def _compute_logits(self, g, h, e_h): The minibatch graph h: dict of Tensors The dictionary containing the embeddings + e_h: dict of Tensor + The input edge embeddings in the format of {(src_ntype, etype, dst_ntype): emb}. Returns ------- - th.Tensor - Output of forward + out: Tensor + Output of forward. """ assert e_h is not None, "edge feature is required" with g.local_scope(): @@ -819,6 +852,8 @@ def _compute_logits(self, g, h, e_h): combine_h = self.combine_norm(combine_h) combine_h = self.relu(combine_h) out = th.matmul(combine_h, self.decoder) + if self.use_bias: + out = out + self.bias return out diff --git a/python/graphstorm/model/node_decoder.py b/python/graphstorm/model/node_decoder.py index 48bd95d6fa..1e735f664b 100644 --- a/python/graphstorm/model/node_decoder.py +++ b/python/graphstorm/model/node_decoder.py @@ -39,7 +39,10 @@ class EntityClassifier(GSLayer): Normalization methods. Not used, but reserved for complex node classifier implementation. Default: None. use_bias: bool - Whether the node decoder uses a bias parameter. Default: False. + Whether the node decoder uses a bias parameter. Default: True. + + .. versionchanged:: 0.4.0 + Add a new argument "use_bias" so users can control whether decoders have bias. """ def __init__(self, in_dim, @@ -47,7 +50,7 @@ def __init__(self, multilabel, dropout=0, norm=None, - use_bias=False): + use_bias=True): super(EntityClassifier, self).__init__() self._in_dim = in_dim self._num_classes = num_classes @@ -170,14 +173,17 @@ class EntityRegression(GSLayer): Normalization methods. Not used, but reserved for complex node regression implementation. Default: None. use_bias: bool - Whether the node decoder uses a bias parameter. Default: False. + Whether the node decoder uses a bias parameter. Default: True. + + .. versionchanged:: 0.4.0 + Add a new argument "use_bias" so users can control whether decoders have bias. """ def __init__(self, h_dim, dropout=0, out_dim=1, norm=None, - use_bias=False): + use_bias=True): super(EntityRegression, self).__init__() self._h_dim = h_dim self._out_dim = out_dim diff --git a/tests/unit-tests/data_utils.py b/tests/unit-tests/data_utils.py index 00b2b6be27..8648102975 100644 --- a/tests/unit-tests/data_utils.py +++ b/tests/unit-tests/data_utils.py @@ -46,6 +46,32 @@ def generate_mask(idx, length): th_mask = th.tensor(mask, dtype=th.bool) return th_mask +def generate_dummy_constant_graph(in_units): + """ + Generate a dummy heterogeneous graph to test edge decoder. + + Return + ------- + g: a heterogeneous graph. + + h: node embeddings. 
diff --git a/tests/unit-tests/data_utils.py b/tests/unit-tests/data_utils.py
index 00b2b6be27..8648102975 100644
--- a/tests/unit-tests/data_utils.py
+++ b/tests/unit-tests/data_utils.py
@@ -46,6 +46,32 @@ def generate_mask(idx, length):
     th_mask = th.tensor(mask, dtype=th.bool)
     return th_mask
 
+def generate_dummy_constant_graph(in_units):
+    """
+    Generate a dummy heterogeneous graph to test edge decoders.
+
+    Returns
+    -------
+    g: a heterogeneous graph.
+
+    h: node embeddings.
+
+    edge_type: the canonical edge type, ("n0", "r0", "n1")
+    """
+    u = th.tensor([0, 0])
+    v = th.tensor([1, 2])
+    edge_type = ("n0", "r0", "n1")
+    g = dgl.heterograph({
+        edge_type: (u, v)
+    })
+
+    h = {
+        "n0": th.ones(g.num_nodes("n0"), in_units),
+        "n1": th.ones(g.num_nodes("n1"), in_units)
+    }
+
+    return g, h, edge_type
+
 def generate_dummy_hetero_graph_for_efeat_gnn(is_random=True):
     """ generate a dummy heterogeneous graph to test the get_edge_feat_size() method.
diff --git a/tests/unit-tests/test_decoder.py b/tests/unit-tests/test_decoder.py
index 0a033fc2dc..ea3895f250 100644
--- a/tests/unit-tests/test_decoder.py
+++ b/tests/unit-tests/test_decoder.py
@@ -30,7 +30,10 @@
                               LinkPredictRotatEDecoder,
                               LinkPredictContrastiveRotatEDecoder,
                               LinkPredictTransEDecoder,
-                              LinkPredictContrastiveTransEDecoder)
+                              LinkPredictContrastiveTransEDecoder,
+                              DenseBiDecoder,
+                              EdgeRegression,
+                              MLPEdgeDecoder)
 from graphstorm.dataloading import (BUILTIN_LP_UNIFORM_NEG_SAMPLER,
                                     BUILTIN_LP_JOINT_NEG_SAMPLER)
 from graphstorm.eval.utils import (calc_distmult_pos_score,
@@ -41,7 +44,7 @@
 
 from numpy.testing import assert_equal
 
-from data_utils import generate_dummy_hetero_graph
+from data_utils import generate_dummy_hetero_graph, generate_dummy_constant_graph
 
 def _check_scores(score, pos_score, neg_scores, etype, num_neg, batch_size):
     # pos scores
@@ -983,6 +986,378 @@ def test_MLPEFeatEdgeDecoder(h_dim, feat_dim, out_dim, num_ffn_layers):
         pred = out.argmax(dim=1)
         assert_almost_equal(prediction.cpu().numpy(), pred.cpu().numpy())
 
+@pytest.mark.parametrize("in_units", [16, 64])
+@pytest.mark.parametrize("num_classes", [4, 8])
+def test_DenseBiDecoder(in_units, num_classes):
+
+    g, h, edge_type = generate_dummy_constant_graph(in_units)
+
+    with th.no_grad():
+        # We trick the decoder into predicting TARGET_CLASS by adding
+        # INCREMENT_VALUE at TARGET_CLASS' index in the combine_basis weight matrix.
+        # Because the basis matrices are set to identity, the outputs of the bases
+        # are all equal. So, by modifying the weight of combine_basis, we force
+        # the decoder to predict TARGET_CLASS.
+        INCREMENT_VALUE = 100
+        TARGET_CLASS = 2
+        decoder = DenseBiDecoder(
+            in_units=in_units,
+            num_classes=num_classes,
+            multilabel=False,
+            target_etype=edge_type,
+            use_bias=False
+        )
+        decoder.eval()
+
+        # Test that no bias exists on the combine_basis nn.Linear
+        assert decoder.combine_basis.bias is None
+
+        for i in range(decoder.num_basis):
+            decoder.basis_para[i, :, :] = th.eye(in_units)
+        th.nn.init.eye_(decoder.combine_basis.weight)
+        decoder.combine_basis.weight[TARGET_CLASS][0] += INCREMENT_VALUE  # Trick decoder
+
+        prediction = decoder.predict(g, h)
+        assert th.all(prediction == TARGET_CLASS)
+        assert prediction.shape[0] == g.num_edges()
+
+
+        # Test classification with nonzero bias
+        # Same approach as above, but this time we modify the bias instead of
+        # combine_basis to force the decoder to predict TARGET_CLASS.
+        TARGET_CLASS = 3
+        decoder = DenseBiDecoder(
+            in_units=in_units,
+            num_classes=num_classes,
+            multilabel=False,
+            target_etype=edge_type,
+            use_bias=True
+        )
+        decoder.eval()
+
+        assert decoder.combine_basis.bias is not None
+
+        assert decoder.in_dims == in_units
+        assert decoder.out_dims == num_classes
+        assert not hasattr(decoder, "regression_head")
+
+        for i in range(decoder.num_basis):
+            decoder.basis_para[i, :, :] = th.eye(in_units)
+        th.nn.init.eye_(decoder.combine_basis.weight)
+        th.nn.init.zeros_(decoder.combine_basis.bias)
+        decoder.combine_basis.bias[TARGET_CLASS] += INCREMENT_VALUE  # Trick decoder
+
+        prediction = decoder.predict(g, h)
+        assert th.all(prediction == TARGET_CLASS)
+        assert prediction.shape[0] == g.num_edges()
+
+@pytest.mark.parametrize("in_dim", [16, 64])
+@pytest.mark.parametrize("out_dim", [1, 8])
+def test_EdgeRegression(in_dim, out_dim):
+
+    g, h, edge_type = generate_dummy_constant_graph(in_dim)
+
+    # Test that no bias exists on the linear layer
+    decoder = EdgeRegression(
+        h_dim=in_dim,
+        out_dim=out_dim,
+        target_etype=edge_type,
+        use_bias=False
+    )
+
+    assert decoder.linear.bias is None
+
+    # Test cases when bias exists (zero and nonzero)
+    decoder = EdgeRegression(
+        h_dim=in_dim,
+        out_dim=out_dim,
+        target_etype=edge_type,
+        use_bias=True
+    )
+
+    assert decoder.in_dims == in_dim
+    assert decoder.out_dims == out_dim
+    assert decoder.linear.bias is not None
+
+    decoder.eval()
+    with th.no_grad():
+        # By setting the weights to identity matrices and the bias to 0, we should
+        # get all 1s in the output because our inputs are all 1s.
+        th.nn.init.eye_(decoder.linear.weight)
+        th.nn.init.eye_(decoder.regression_head.weight)
+        th.nn.init.zeros_(decoder.linear.bias)
+        th.nn.init.zeros_(decoder.regression_head.bias)
+
+        # Test the regression output; it should be all 1s because of the identity
+        # weights and zero bias.
+        prediction = decoder.predict(g, h)
+        assert th.all(prediction == 1)
+        assert prediction.shape == (g.num_edges(), out_dim)
+
+        # Test nonzero bias; the output should all equal TEST_BIAS_VALUE + 1.
+        TEST_BIAS_VALUE = 7
+        th.nn.init.constant_(decoder.linear.bias, TEST_BIAS_VALUE)
+        th.nn.init.eye_(decoder.linear.weight)
+        th.nn.init.eye_(decoder.regression_head.weight)
+        th.nn.init.zeros_(decoder.regression_head.bias)
+
+        prediction = decoder.predict(g, h)
+        assert th.all(prediction == TEST_BIAS_VALUE+1)
+        assert prediction.shape == (g.num_edges(), out_dim)
+
+@pytest.mark.parametrize("in_dim", [16, 64])
+@pytest.mark.parametrize("out_dim", [4, 8])
+@pytest.mark.parametrize("num_ffn_layers", [0, 2])
+def test_MLPEdgeDecoder(in_dim, out_dim, num_ffn_layers):
+
+    g, h, edge_type = generate_dummy_constant_graph(in_dim)
+
+    # Test classification
+    # Test that no bias exists on the decoder
+    decoder = MLPEdgeDecoder(
+        h_dim=in_dim,
+        out_dim=out_dim,
+        multilabel=False,
+        target_etype=edge_type,
+        num_ffn_layers=num_ffn_layers,
+        use_bias=False
+    )
+    assert not hasattr(decoder, "bias")
+    assert not hasattr(decoder, "regression_head")
+
+    decoder = MLPEdgeDecoder(
+        h_dim=in_dim,
+        out_dim=out_dim,
+        multilabel=False,
+        target_etype=edge_type,
+        num_ffn_layers=num_ffn_layers,
+        use_bias=True
+    )
+
+    assert decoder.in_dims == in_dim
+    assert decoder.out_dims == out_dim
+    assert hasattr(decoder, "bias")
+    assert not hasattr(decoder, "regression_head")
+    assert decoder.use_bias
+
+    decoder.eval()
+    with th.no_grad():
+        # We trick the decoder into predicting TARGET_CLASS by adding
+        # INCREMENT_VALUE at TARGET_CLASS' index in the decoder weight matrix.
+        # Because all layers of the MLP are set to identity, the outputs will
+        # be the same as the inputs, all 1s. So, by modifying the weight of the
+        # decoder, we force it to predict TARGET_CLASS.
+        INCREMENT_VALUE = 10
+
+        # Test classification when bias = 0
+        TARGET_CLASS = 2
+        # Set up the MLP for testing
+        for layer in decoder.ngnn_mlp.ngnn_gnn:
+            th.nn.init.eye_(layer)
+        th.nn.init.eye_(decoder.decoder)
+        decoder.decoder[0][TARGET_CLASS] += INCREMENT_VALUE  # Trick decoder
+
+        prediction = decoder.predict(g, h)
+        assert th.all(prediction == TARGET_CLASS)
+        assert prediction.shape[0] == g.num_edges()
+
+        # Test classification with nonzero bias
+        TARGET_CLASS = 3
+        # Set up the MLP for testing
+        for layer in decoder.ngnn_mlp.ngnn_gnn:
+            th.nn.init.eye_(layer)
+        th.nn.init.eye_(decoder.decoder)
+        decoder.bias[TARGET_CLASS] += INCREMENT_VALUE  # Trick decoder
+
+        prediction = decoder.predict(g, h)
+        assert th.all(prediction == TARGET_CLASS)
+        assert prediction.shape[0] == g.num_edges()
+
+
+    # Test regression
+    # Test that no bias exists on the decoder
+    decoder = MLPEdgeDecoder(
+        h_dim=in_dim,
+        out_dim=out_dim,
+        multilabel=False,
+        target_etype=edge_type,
+        num_ffn_layers=num_ffn_layers,
+        use_bias=False,
+        regression=True
+    )
+    assert not hasattr(decoder, "bias")
+    assert hasattr(decoder, "regression_head")
+
+    decoder = MLPEdgeDecoder(
+        h_dim=in_dim,
+        out_dim=out_dim,
+        multilabel=False,
+        target_etype=edge_type,
+        num_ffn_layers=num_ffn_layers,
+        use_bias=True,
+        regression=True
+    )
+
+    assert decoder.in_dims == in_dim
+    assert decoder.out_dims == 1
+    assert hasattr(decoder, "bias")
+    assert hasattr(decoder, "regression_head")
+    assert decoder.use_bias
+
+    decoder.eval()
+    with th.no_grad():
+        # Test the regression output; it should be all 1s because of the identity
+        # weights and the all-1s input.
+        # Set up the MLP for testing
+        for layer in decoder.ngnn_mlp.ngnn_gnn:
+            th.nn.init.eye_(layer)
+        th.nn.init.eye_(decoder.decoder)
+        th.nn.init.eye_(decoder.regression_head.weight)
+        th.nn.init.zeros_(decoder.regression_head.bias)
+        prediction = decoder.predict(g, h)
+        assert th.all(prediction == 1)
+        assert prediction.shape == (g.num_edges(), 1)
+
+        # Test nonzero bias; the output should all equal TEST_BIAS_VALUE + 1
+        # because the input is all 1s.
+        # Set up the MLP for testing
+        for layer in decoder.ngnn_mlp.ngnn_gnn:
+            th.nn.init.eye_(layer)
+        th.nn.init.eye_(decoder.decoder)
+        TEST_BIAS_VALUE = 6
+        th.nn.init.eye_(decoder.regression_head.weight)
+        th.nn.init.zeros_(decoder.regression_head.bias)
+        th.nn.init.constant_(decoder.bias, TEST_BIAS_VALUE)
+        prediction = decoder.predict(g, h)
+        assert th.all(prediction == TEST_BIAS_VALUE+1)
+        assert prediction.shape == (g.num_edges(), 1)
+
+@pytest.mark.parametrize("in_dim", [16, 64])
+@pytest.mark.parametrize("out_dim", [4, 8])
+@pytest.mark.parametrize("feat_dim", [8, 32])
+@pytest.mark.parametrize("num_ffn_layers", [0, 2])
+def test_MLPEFeatDecoder_Constant_Inputs(in_dim, out_dim, feat_dim, num_ffn_layers):
+
+    def prepareMLPEFeatEdgeDecoder(decoder: MLPEFeatEdgeDecoder):
+        th.nn.init.eye_(decoder.nn_decoder)
+        th.nn.init.eye_(decoder.feat_decoder)
+        for layer in decoder.ngnn_mlp.ngnn_gnn:
+            th.nn.init.eye_(layer)
+        th.nn.init.eye_(decoder.combine_decoder)
+        th.nn.init.eye_(decoder.decoder)
+
+    g, h, edge_type = generate_dummy_constant_graph(in_dim)
+
+    efeat = {edge_type: th.ones(g.num_edges(edge_type), feat_dim)}
+
+    # Test classification
+    # Test that no bias exists on the decoder
+    decoder = MLPEFeatEdgeDecoder(
+        h_dim=in_dim,
+        feat_dim=feat_dim,
+        out_dim=out_dim,
+        multilabel=False,
+        target_etype=edge_type,
+        num_ffn_layers=num_ffn_layers,
+        use_bias=False
+    )
+    assert not hasattr(decoder, "bias")
+    assert not hasattr(decoder, "regression_head")
+
+    decoder = MLPEFeatEdgeDecoder(
+        h_dim=in_dim,
+        feat_dim=feat_dim,
+        out_dim=out_dim,
+        multilabel=False,
+        target_etype=edge_type,
+        num_ffn_layers=num_ffn_layers,
+        use_bias=True
+    )
+
+    assert decoder.in_dims == in_dim
+    assert decoder.out_dims == out_dim
+    assert hasattr(decoder, "bias")
+    assert not hasattr(decoder, "regression_head")
+    assert decoder.use_bias
+
+    decoder.eval()
+    with th.no_grad():
+        # We trick the decoder into predicting TARGET_CLASS by adding
+        # INCREMENT_VALUE at TARGET_CLASS' index in the decoder weight matrix.
+        # Because all layers of the MLP, the edge feature decoder, and the nn
+        # decoder are set to identity, the outputs will be the same as the inputs,
+        # all 1s. So, by modifying the weight of the decoder, we force it to
+        # predict TARGET_CLASS.
+        INCREMENT_VALUE = 10  # Trick the decoder into predicting a specific class
+
+        # Test classification when bias = 0
+        TARGET_CLASS = 2
+        prepareMLPEFeatEdgeDecoder(decoder)
+        decoder.decoder[0][TARGET_CLASS] += INCREMENT_VALUE  # Trick decoder
+
+        prediction = decoder.predict(g, h, efeat)
+        assert th.all(prediction == TARGET_CLASS)
+        assert prediction.shape[0] == g.num_edges()
+
+        # Test classification with nonzero bias
+        TARGET_CLASS = 3
+        prepareMLPEFeatEdgeDecoder(decoder)
+        decoder.bias[TARGET_CLASS] += INCREMENT_VALUE  # Trick decoder
+
+        prediction = decoder.predict(g, h, efeat)
+        assert th.all(prediction == TARGET_CLASS)
+        assert prediction.shape[0] == g.num_edges()
+
+
+    # Test regression
+    # Test that no bias exists on the decoder
+    decoder = MLPEFeatEdgeDecoder(
+        h_dim=in_dim,
+        feat_dim=feat_dim,
+        out_dim=out_dim,
+        multilabel=False,
+        target_etype=edge_type,
+        num_ffn_layers=num_ffn_layers,
+        use_bias=False,
+        regression=True
+    )
+    assert not hasattr(decoder, "bias")
+    assert hasattr(decoder, "regression_head")
+
+    decoder = MLPEFeatEdgeDecoder(
+        h_dim=in_dim,
+        feat_dim=feat_dim,
+        out_dim=out_dim,
+        multilabel=False,
+        target_etype=edge_type,
+        num_ffn_layers=num_ffn_layers,
+        use_bias=True,
+        regression=True
+    )
+
+    assert decoder.in_dims == in_dim
+    assert decoder.out_dims == 1
+    assert hasattr(decoder, "bias")
+    assert hasattr(decoder, "regression_head")
+    assert decoder.use_bias
+
+    decoder.eval()
+    with th.no_grad():
+        # Test the regression output; it should be all 1s because of the identity
+        # weights and the all-1s input.
+        prepareMLPEFeatEdgeDecoder(decoder)
+        th.nn.init.eye_(decoder.regression_head.weight)
+        th.nn.init.zeros_(decoder.regression_head.bias)
+        prediction = decoder.predict(g, h, efeat)
+        assert th.all(prediction == 1)
+        assert prediction.shape == (g.num_edges(), 1)
+
+        # Test nonzero bias; the output should all equal TEST_BIAS_VALUE + 1
+        # because the input is all 1s.
+        # Set up the MLP for testing
+        prepareMLPEFeatEdgeDecoder(decoder)
+        TEST_BIAS_VALUE = 6
+        th.nn.init.eye_(decoder.regression_head.weight)
+        th.nn.init.zeros_(decoder.regression_head.bias)
+        th.nn.init.constant_(decoder.bias, TEST_BIAS_VALUE)
+        prediction = decoder.predict(g, h, efeat)
+        assert th.all(prediction == TEST_BIAS_VALUE+1)
+        assert prediction.shape == (g.num_edges(), 1)
+
 @pytest.mark.parametrize("in_dim", [16, 64])
 @pytest.mark.parametrize("out_dim", [1, 8])
 def test_EntityRegression(in_dim, out_dim):
@@ -1086,3 +1461,23 @@ def test_EntityClassifier(in_dim, num_classes):
 
     test_MLPEFeatEdgeDecoder(16,8,2,0)
     test_MLPEFeatEdgeDecoder(16,32,2,2)
+
+    test_MLPEFeatDecoder_Constant_Inputs(16,4,8,0)
+    test_MLPEFeatDecoder_Constant_Inputs(16,8,32,0)
+    test_MLPEFeatDecoder_Constant_Inputs(64,8,8,2)
+    test_MLPEFeatDecoder_Constant_Inputs(64,4,32,2)
+
+    test_DenseBiDecoder(16, 4)
+    test_DenseBiDecoder(16, 8)
+    test_DenseBiDecoder(64, 4)
+    test_DenseBiDecoder(64, 8)
+
+    test_EdgeRegression(16, 1)
+    test_EdgeRegression(16, 8)
+    test_EdgeRegression(64, 1)
+    test_EdgeRegression(64, 8)
+
+    test_MLPEdgeDecoder(16, 4, 0)
+    test_MLPEdgeDecoder(16, 8, 2)
+    test_MLPEdgeDecoder(64, 4, 2)
+    test_MLPEdgeDecoder(64, 8, 0)
diff --git a/tests/unit-tests/test_gsf.py b/tests/unit-tests/test_gsf.py
index e692777bc6..f48ae9e2f4 100644
--- a/tests/unit-tests/test_gsf.py
+++ b/tests/unit-tests/test_gsf.py
@@ -205,6 +205,7 @@ def test_create_builtin_edge_decoder():
             "num_ffn_layers_in_decoder": 0,
             "multilabel_weights": None,
             "imbalance_class_weights": None,
+            "decoder_bias": True,
         }
     )
     decoder, loss_func = create_builtin_edge_decoder(g, decoder_input_dim, config, train_task)
@@ -225,6 +226,7 @@ def test_create_builtin_edge_decoder():
             "num_ffn_layers_in_decoder": 0,
             "alpha": None,
             "gamma": None,
+            "decoder_bias": True,
         }
     )
     decoder, loss_func = create_builtin_edge_decoder(g, decoder_input_dim, config, train_task)
@@ -246,6 +248,7 @@ def test_create_builtin_edge_decoder():
             "num_ffn_layers_in_decoder": 0,
             "alpha": 0.3,
             "gamma": 3.,
+            "decoder_bias": True,
         }
     )
     decoder, loss_func = create_builtin_edge_decoder(g, decoder_input_dim, config, train_task)
@@ -263,6 +266,7 @@ def test_create_builtin_edge_decoder():
             "decoder_type": "DenseBiDecoder",
             "num_decoder_basis": 2,
             "decoder_norm": None,
+            "decoder_bias": False,
         }
     )
     decoder, loss_func = create_builtin_edge_decoder(g, decoder_input_dim, config, train_task)
@@ -278,6 +282,7 @@ def test_create_builtin_edge_decoder():
             "decoder_type": "MLPDecoder",
             "num_ffn_layers_in_decoder": 0,
             "decoder_norm": None,
+            "decoder_bias": False,
         }
     )
     decoder, loss_func = create_builtin_edge_decoder(g, decoder_input_dim, config, train_task)
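Putting the pieces together, typical use of the new flag with the helper added above might look like the following sketch (shapes and imports follow the tests; treat it as illustrative rather than canonical):

    import torch as th
    from graphstorm.model import MLPEdgeDecoder
    from data_utils import generate_dummy_constant_graph  # tests/unit-tests helper

    g, h, etype = generate_dummy_constant_graph(16)
    decoder = MLPEdgeDecoder(h_dim=16, out_dim=4, multilabel=False,
                             target_etype=etype, use_bias=True)  # the new default
    with th.no_grad():
        preds = decoder.predict(g, h)   # one class prediction per target edge
    assert preds.shape[0] == g.num_edges()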