Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
RonaldBXu committed Jan 10, 2025
1 parent addf5e0 commit 49c17af
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 89 deletions.
16 changes: 14 additions & 2 deletions python/graphstorm/model/edge_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,9 @@ class DenseBiDecoder(GSEdgeDecoder):
implementation. Default: None.
use_bias: bool
Whether the edge decoder uses a bias parameter. Default: True.
.. versionchanged:: 0.4.0
Add a new argument "use_bias" so users can control whether decoders have bias.
"""
def __init__(self,
in_units,
Expand Down Expand Up @@ -337,6 +340,9 @@ class EdgeRegression(GSEdgeDecoder):
implementation. Default: None.
use_bias: bool
Whether the edge decoder uses a bias parameter. Default: True.
.. versionchanged:: 0.4.0
Add a new argument "use_bias" so users can control whether decoders have bias.
"""
def __init__(self,
h_dim,
Expand Down Expand Up @@ -510,6 +516,9 @@ class MLPEdgeDecoder(GSEdgeDecoder):
implementation. Default: None.
use_bias: bool
Whether the edge decoder uses a bias parameter. Default: True.
.. versionchanged:: 0.4.0
Add a new argument "use_bias" so users can control whether decoders have bias.
"""
def __init__(self,
h_dim,
Expand Down Expand Up @@ -575,8 +584,8 @@ def _compute_logits(self, g, h):
The dictionary containing the embeddings
Returns
-------
th.Tensor
Output of forward
out
Output of forward.
"""
with g.local_scope():
u, v = g.edges(etype=self.target_etype)
Expand Down Expand Up @@ -724,6 +733,9 @@ class MLPEFeatEdgeDecoder(MLPEdgeDecoder):
class implementation. Default: None.
use_bias: bool
Whether the edge decoder uses a bias parameter. Default: True.
.. versionchanged:: 0.4.0
Add a new argument "use_bias" so users can control whether decoders have bias.
"""
def __init__(self,
h_dim,
Expand Down
6 changes: 6 additions & 0 deletions python/graphstorm/model/node_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ class EntityClassifier(GSLayer):
implementation. Default: None.
use_bias: bool
Whether the node decoder uses a bias parameter. Default: True.
.. versionchanged:: 0.4.0
Add a new argument "use_bias" so users can control whether decoders have bias.
"""
def __init__(self,
in_dim,
Expand Down Expand Up @@ -171,6 +174,9 @@ class EntityRegression(GSLayer):
implementation. Default: None.
use_bias: bool
Whether the node decoder uses a bias parameter. Default: True.
.. versionchanged:: 0.4.0
Add a new argument "use_bias" so users can control whether decoders have bias.
"""
def __init__(self,
h_dim,
Expand Down
26 changes: 26 additions & 0 deletions tests/unit-tests/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,32 @@ def generate_mask(idx, length):
th_mask = th.tensor(mask, dtype=th.bool)
return th_mask

def generate_dummy_constant_graph(in_units):
"""
Generate a dummy heterogeneous graph to test edge decoder.
Return
-------
g: a heterogeneous graph.
h: node embeddings.
edge_type: graph schema ("n0", "r0", "n1")
"""
u = th.tensor([0, 0])
v = th.tensor([1, 2])
edge_type = ("n0", "r0", "n1")
g = dgl.heterograph({
edge_type: (u, v)
})

h = {
"n0": th.ones(g.num_nodes("n0"), in_units),
"n1": th.ones(g.num_nodes("n1"), in_units)
}

return g, h, edge_type

def generate_dummy_hetero_graph_for_efeat_gnn(is_random=True):
"""
generate a dummy heterogeneous graph to test the get_edge_feat_size() method.
Expand Down
Loading

0 comments on commit 49c17af

Please sign in to comment.