
Commit

Merge branch 'main' into examples-sm-pipeline
thvasilo authored Jan 13, 2025
2 parents 4752935 + faec9ab commit dfdd65b
Showing 18 changed files with 2,661 additions and 40 deletions.
@@ -38,6 +38,7 @@ Full argument list of the ``gconstruct.construct_graph`` command
* **-\-add-reverse-edges**: boolean value that decides whether to add reverse edges to the given graph. Adding this argument sets it to true; otherwise, it defaults to false. It is **strongly** suggested to include this argument for graph construction, as some nodes in the original data may have no in-degree and thus cannot update their representations by aggregating messages from their neighbors. Adding this argument helps prevent this issue.
* **-\-output-format**: the format of the constructed graph; options are ``DGL`` and ``DistDGL``. Default is ``DistDGL``. It also accepts multiple graph formats at the same time, separated by a space, for example ``--output-format "DGL DistDGL"``. The output format is explained in the :ref:`Output <gcon-output-format>` section above.
* **-\-num-parts**: an integer value that specifies the number of graph partitions to produce. This is only valid if the output format is ``DistDGL``.
* **-\-part-method**: the partition method to use during partitioning. Supported methods are ``metis`` and ``random``.
* **-\-skip-nonexist-edges**: boolean value that decides whether to skip edges whose endpoint nodes don't exist. Default is true.
* **-\-ext-mem-workspace**: the directory where the tool can store intermediate data during graph construction. A high-speed SSD is suggested as the external memory workspace.
* **-\-ext-mem-feat-size**: the minimal number of feature dimensions a feature must have to be stored in external memory. Default is 64. An example invocation combining these arguments is shown below.
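
For concreteness, here is a hedged example of how these flags combine in one invocation. The ``--conf-file``, ``--output-dir``, and ``--graph-name`` arguments are assumptions about the tool's standard required inputs, not part of the excerpt above; treat this as a sketch rather than a verified command.

```
python -m graphstorm.gconstruct.construct_graph \
    --conf-file my_graph_config.json \
    --output-dir /data/constructed_graph \
    --graph-name example_graph \
    --num-parts 4 \
    --part-method metis \
    --add-reverse-edges \
    --skip-nonexist-edges \
    --ext-mem-workspace /mnt/ssd/workspace \
    --ext-mem-feat-size 64
```

Since ``--output-format`` is omitted, the constructed graph defaults to ``DistDGL``, which is what makes ``--num-parts 4`` valid here.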
2 changes: 1 addition & 1 deletion graphstorm-processing/pyproject.toml
@@ -53,7 +53,7 @@ classifiers = [


[build-system]
requires = ["poetry-core>=1.0.8"]
requires = ["poetry-core<2.0.0"]
build-backend = "poetry.core.masonry.api"

[tool.black]
6 changes: 3 additions & 3 deletions python/graphstorm/config/argument.py
@@ -1752,15 +1752,15 @@ def log_report_frequency(self):
###################### Model training related ######################
@property
def decoder_bias(self):
""" Node decoder bias. decoder_bias must be a boolean. Default is False.
""" Decoder bias. decoder_bias must be a boolean. Default is True.
"""
# pylint: disable=no-member
if hasattr(self, "_decoder_bias"):
assert self._decoder_bias in [True, False], \
"decoder_bias should be in [True, False]"
return self._decoder_bias
- # By default, node decoder bias is False
- return False
+ # By default, decoder bias is True
+ return True

@property
def dropout(self):
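
The practical effect of this change: a training configuration that never sets ``decoder_bias`` now resolves to ``True`` instead of ``False``, so decoders gain a bias term by default. A minimal standalone sketch of the fallback logic (the class below is a stand-in for illustration, not GraphStorm's actual config object):

```python
class ConfigSketch:
    """Stand-in config: only the decoder_bias fallback logic is modeled."""

    def __init__(self, decoder_bias=None):
        if decoder_bias is not None:
            self._decoder_bias = decoder_bias

    @property
    def decoder_bias(self):
        # Mirrors the property above: an explicit setting wins,
        # otherwise fall back to the new default of True.
        if hasattr(self, "_decoder_bias"):
            assert self._decoder_bias in [True, False], \
                "decoder_bias should be in [True, False]"
            return self._decoder_bias
        return True


print(ConfigSketch().decoder_bias)                    # True (new default)
print(ConfigSketch(decoder_bias=False).decoder_bias)  # False (explicit opt-out)
```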
1 change: 1 addition & 0 deletions python/graphstorm/gconstruct/construct_graph.py
@@ -928,6 +928,7 @@ def process_graph(args):
help="The number of graph partitions. " + \
"This is only valid if the output format is DistDGL.")
argparser.add_argument("--part-method", type=str, default='metis',
+ choices=['metis', 'random'],
help="The partition method. Currently, we support 'metis' and 'random'.")
argparser.add_argument("--skip-nonexist-edges", action='store_true',
help="Skip edges that whose endpoint nodes don't exist.")
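
Adding ``choices`` moves validation of ``--part-method`` to argument-parsing time: an unsupported value now fails immediately with a descriptive error instead of surfacing later during partitioning. A minimal standalone parser illustrating the behavior (only this one argument is mirrored):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--part-method", type=str, default='metis',
                    choices=['metis', 'random'],
                    help="The partition method. Currently, we support 'metis' and 'random'.")

print(parser.parse_args([]).part_method)                           # 'metis'
print(parser.parse_args(["--part-method", "random"]).part_method)  # 'random'
# parser.parse_args(["--part-method", "spectral"]) exits with:
# error: argument --part-method: invalid choice: 'spectral'
#        (choose from 'metis', 'random')
```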
21 changes: 14 additions & 7 deletions python/graphstorm/gsf.py
@@ -439,7 +439,8 @@ def create_builtin_reconstruct_efeat_decoder(g, decoder_input_dim, config, train
decoder = EdgeRegression(decoder_input_dim,
target_etype=target_etype,
out_dim=feat_dim,
- dropout=dropout)
+ dropout=dropout,
+ use_bias=config.decoder_bias)

loss_func = RegressionLossFunc()
return decoder, loss_func
@@ -617,14 +618,16 @@ def create_builtin_edge_decoder(g, decoder_input_dim, config, train_task):
dropout_rate=dropout,
regression=False,
target_etype=target_etype,
- norm=config.decoder_norm)
+ norm=config.decoder_norm,
+ use_bias=config.decoder_bias)
elif decoder_type == "MLPDecoder":
decoder = MLPEdgeDecoder(decoder_input_dim,
num_classes,
multilabel=config.multilabel,
target_etype=target_etype,
num_ffn_layers=config.num_ffn_layers_in_decoder,
- norm=config.decoder_norm)
+ norm=config.decoder_norm,
+ use_bias=config.decoder_bias)
elif decoder_type == "MLPEFeatEdgeDecoder":
decoder_edge_feat = config.decoder_edge_feat
assert decoder_edge_feat is not None, \
@@ -648,7 +651,8 @@ def create_builtin_edge_decoder(g, decoder_input_dim, config, train_task):
target_etype=target_etype,
dropout=config.dropout,
num_ffn_layers=config.num_ffn_layers_in_decoder,
- norm=config.decoder_norm)
+ norm=config.decoder_norm,
+ use_bias=config.decoder_bias)
else:
assert False, f"decoder {decoder_type} is not supported."

@@ -680,15 +684,17 @@ def create_builtin_edge_decoder(g, decoder_input_dim, config, train_task):
target_etype=target_etype,
dropout_rate=dropout,
regression=True,
- norm=config.decoder_norm)
+ norm=config.decoder_norm,
+ use_bias=config.decoder_bias)
elif decoder_type == "MLPDecoder":
decoder = MLPEdgeDecoder(decoder_input_dim,
1,
multilabel=False,
target_etype=target_etype,
regression=True,
num_ffn_layers=config.num_ffn_layers_in_decoder,
- norm=config.decoder_norm)
+ norm=config.decoder_norm,
+ use_bias=config.decoder_bias)
elif decoder_type == "MLPEFeatEdgeDecoder":
decoder_edge_feat = config.decoder_edge_feat
assert decoder_edge_feat is not None, \
Expand All @@ -713,7 +719,8 @@ def create_builtin_edge_decoder(g, decoder_input_dim, config, train_task):
dropout=config.dropout,
regression=True,
num_ffn_layers=config.num_ffn_layers_in_decoder,
- norm=config.decoder_norm)
+ norm=config.decoder_norm,
+ use_bias=config.decoder_bias)
else:
assert False, "decoder not supported"
loss_func = RegressionLossFunc()
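
All of these call sites follow one pattern: instead of a hard-coded bias, each decoder receives ``use_bias=config.decoder_bias``. A hedged construction sketch using the ``EdgeRegression`` signature visible in this diff (the edge type and dimensions below are made-up illustration values):

```python
from graphstorm.model import EdgeRegression  # export added by this commit

decoder = EdgeRegression(h_dim=128,
                         target_etype=("user", "rates", "movie"),
                         out_dim=1,
                         dropout=0.1,
                         use_bias=True)  # gsf.py passes config.decoder_bias here
```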
3 changes: 2 additions & 1 deletion python/graphstorm/model/__init__.py
@@ -61,7 +61,8 @@
LinkPredictWeightedRotatEDecoder,
LinkPredictTransEDecoder,
LinkPredictContrastiveTransEDecoder,
- LinkPredictWeightedTransEDecoder)
+ LinkPredictWeightedTransEDecoder,
+ EdgeRegression)

from .gnn_encoder_base import GraphConvEncoder

57 changes: 46 additions & 11 deletions python/graphstorm/model/edge_decoder.py
@@ -132,6 +132,11 @@ class DenseBiDecoder(GSEdgeDecoder):
norm: str
Normalization methods. Not used, but reserved for complex DenseBiDecoder child class
implementation. Default: None.
+ use_bias: bool
+     Whether the edge decoder uses a bias parameter. Default: True.
+ .. versionchanged:: 0.4.0
+     Add a new argument "use_bias" so users can control whether decoders have bias.
"""
def __init__(self,
in_units,
@@ -141,7 +146,8 @@ def __init__(self,
num_basis=2,
dropout_rate=0.0,
regression=False,
- norm=None):
+ norm=None,
+ use_bias=True):
super().__init__()

self.in_units = in_units
@@ -157,6 +163,7 @@
assert isinstance(target_etype, tuple) and len(target_etype) == 3, \
"Target etype must be a tuple of a canonical etype."
self.target_etype = target_etype
+ self.use_bias = use_bias

self._init_model()

@@ -171,7 +178,7 @@ def _init_model(self):
self.dropout = nn.Dropout(self.dropout)
self.basis_para = nn.Parameter(
th.randn(self.num_basis, self.in_units, self.in_units))
- self.combine_basis = nn.Linear(self.num_basis, basis_out, bias=False)
+ self.combine_basis = nn.Linear(self.num_basis, basis_out, bias=self.use_bias)
self.reset_parameters()

if self.regression:
@@ -331,13 +338,16 @@ class EdgeRegression(GSEdgeDecoder):
norm: str, optional
Normalization methods. Not used, but reserved for complex edge regression
implementation. Default: None.
+ use_bias: bool
+     Whether the edge decoder uses a bias parameter. Default: True.
"""
def __init__(self,
h_dim,
target_etype,
out_dim=1,
dropout=0,
- norm=None):
+ norm=None,
+ use_bias=True):
super(EdgeRegression, self).__init__()
self._h_dim = h_dim
self._out_dim = out_dim
@@ -349,6 +359,7 @@ def __init__(self,
"Target etype must be a tuple of a canonical etype," \
f"e.g., (src_ntype, etype, dst_ntype), but got {target_etype}."
self._target_etype = target_etype
+ self._use_bias = use_bias

self._init_model()

@@ -360,7 +371,7 @@ def _init_model(self):
if self._norm is not None:
logging.warning("Embedding normalization (batch norm or layer norm) "
"is not supported in EdgeRegression")
- self.linear = nn.Linear(h_dim * 2, h_dim, bias=True)
+ self.linear = nn.Linear(h_dim * 2, h_dim, bias=self._use_bias)
self.relu = nn.ReLU()
self.dropout = nn.Dropout(self._dropout)
self.regression_head = nn.Linear(h_dim, out_dim, bias=True)
@@ -500,6 +511,11 @@ class MLPEdgeDecoder(GSEdgeDecoder):
norm: str
Normalization methods. Not used, but reserved for complex MLPEdgeDecoder child class
implementation. Default: None.
+ use_bias: bool
+     Whether the edge decoder uses a bias parameter. Default: True.
+ .. versionchanged:: 0.4.0
+     Add a new argument "use_bias" so users can control whether decoders have bias.
"""
def __init__(self,
h_dim,
@@ -510,7 +526,8 @@ def __init__(self,
dropout=0,
regression=False,
num_ffn_layers=0,
- norm=None):
+ norm=None,
+ use_bias=True):
super(MLPEdgeDecoder, self).__init__()
self.h_dim = h_dim
self.multilabel = multilabel
Expand All @@ -526,6 +543,7 @@ def __init__(self,
assert isinstance(target_etype, tuple) and len(target_etype) == 3, \
"Target etype must be a tuple of a canonical etype."
self.target_etype = target_etype
+ self.use_bias = use_bias

self._init_model()

@@ -543,6 +561,8 @@ def _init_model(self):

# Here we assume the source and destination nodes have the same dimension.
self.decoder = nn.Parameter(th.randn(self.h_dim * 2, self.out_dim))
+ if self.use_bias:
+     self.bias = nn.Parameter(th.zeros(self.out_dim))
assert self.num_hidden_layers == 1, "More than one layer is not supported"
nn.init.xavier_uniform_(self.decoder,
gain=nn.init.calculate_gain('relu'))
@@ -561,8 +581,8 @@ def _compute_logits(self, g, h):
The dictionary containing the embeddings
Returns
-------
- th.Tensor
-     Output of forward
+ out: th.Tensor
+     Output of forward.
"""
with g.local_scope():
u, v = g.edges(etype=self.target_etype)
Expand All @@ -574,6 +594,8 @@ def _compute_logits(self, g, h):
if self.num_ffn_layers > 0:
h = self.ngnn_mlp(h)
out = th.matmul(h, self.decoder)
+ if self.use_bias:
+     out = out + self.bias
return out

# pylint: disable=unused-argument
@@ -706,6 +728,11 @@ class MLPEFeatEdgeDecoder(MLPEdgeDecoder):
norm: str
Normalization methods. Not used, but reserved for complex MLPEFeatEdgeDecoder child
class implementation. Default: None.
+ use_bias: bool
+     Whether the edge decoder uses a bias parameter. Default: True.
+ .. versionchanged:: 0.4.0
+     Add a new argument "use_bias" so users can control whether decoders have bias.
"""
def __init__(self,
h_dim,
@@ -716,7 +743,8 @@ def __init__(self,
dropout=0,
regression=False,
num_ffn_layers=0,
- norm=None):
+ norm=None,
+ use_bias=True):
self.feat_dim = feat_dim
super(MLPEFeatEdgeDecoder, self).__init__(h_dim=h_dim,
out_dim=out_dim,
@@ -725,7 +753,8 @@ def __init__(self,
dropout=dropout,
regression=regression,
num_ffn_layers=num_ffn_layers,
- norm=norm)
+ norm=norm,
+ use_bias=use_bias)

def _init_model(self):
""" Init decoder model
@@ -747,6 +776,8 @@ def _init_model(self):
# combine output of nn_decoder and feat_decoder
self.combine_decoder = nn.Parameter(th.randn(self.h_dim * 2, self.h_dim))
self.decoder = nn.Parameter(th.randn(self.h_dim, self.out_dim))
+ if self.use_bias:
+     self.bias = nn.Parameter(th.zeros(self.out_dim))
self.dropout = nn.Dropout(self.dropout)

self.nn_decoder_norm = None
@@ -784,10 +815,12 @@ def _compute_logits(self, g, h, e_h):
The minibatch graph
h: dict of Tensors
The dictionary containing the embeddings
+ e_h: dict of Tensor
+     The input edge embeddings in the format of {(src_ntype, etype, dst_ntype): emb}.
Returns
-------
- th.Tensor
-     Output of forward
+ out: Tensor
+     Output of forward.
"""
assert e_h is not None, "edge feature is required"
with g.local_scope():
@@ -819,6 +852,8 @@ def _compute_logits(self, g, h, e_h):
combine_h = self.combine_norm(combine_h)
combine_h = self.relu(combine_h)
out = th.matmul(combine_h, self.decoder)
+ if self.use_bias:
+     out = out + self.bias

return out

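
Across ``MLPEdgeDecoder`` and ``MLPEFeatEdgeDecoder`` the new parameter reduces to the same final step: the projection ``th.matmul(h, self.decoder)`` gains a learnable additive offset when ``use_bias`` is enabled. In plain tensor terms (stand-in shapes, not the decoders' real dimensions):

```python
import torch as th
import torch.nn as nn

h_dim, out_dim, num_edges = 16, 3, 4

h = th.randn(num_edges, h_dim * 2)            # concatenated src/dst embeddings
decoder = nn.Parameter(th.randn(h_dim * 2, out_dim))
bias = nn.Parameter(th.zeros(out_dim))        # created only when use_bias=True

out = th.matmul(h, decoder)
use_bias = True
if use_bias:
    out = out + bias                          # the same addition as in _compute_logits
print(out.shape)                              # torch.Size([4, 3])
```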
14 changes: 10 additions & 4 deletions python/graphstorm/model/node_decoder.py
@@ -39,15 +39,18 @@ class EntityClassifier(GSLayer):
Normalization methods. Not used, but reserved for complex node classifier
implementation. Default: None.
use_bias: bool
- Whether the node decoder uses a bias parameter. Default: False.
+ Whether the node decoder uses a bias parameter. Default: True.
+ .. versionchanged:: 0.4.0
+     Add a new argument "use_bias" so users can control whether decoders have bias.
"""
def __init__(self,
in_dim,
num_classes,
multilabel,
dropout=0,
norm=None,
- use_bias=False):
+ use_bias=True):
super(EntityClassifier, self).__init__()
self._in_dim = in_dim
self._num_classes = num_classes
@@ -170,14 +173,17 @@ class EntityRegression(GSLayer):
Normalization methods. Not used, but reserved for complex node regression
implementation. Default: None.
use_bias: bool
- Whether the node decoder uses a bias parameter. Default: False.
+ Whether the node decoder uses a bias parameter. Default: True.
+ .. versionchanged:: 0.4.0
+     Add a new argument "use_bias" so users can control whether decoders have bias.
"""
def __init__(self,
h_dim,
dropout=0,
out_dim=1,
norm=None,
- use_bias=False):
+ use_bias=True):
super(EntityRegression, self).__init__()
self._h_dim = h_dim
self._out_dim = out_dim
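
With the default flipped to ``True``, node decoders constructed without an explicit ``use_bias`` now include the bias term; pass ``use_bias=False`` to restore the pre-0.4.0 behavior. A short usage sketch (dimensions are arbitrary illustration values):

```python
from graphstorm.model.node_decoder import EntityClassifier, EntityRegression

clf = EntityClassifier(in_dim=128, num_classes=10, multilabel=False)  # biased by default now
reg = EntityRegression(h_dim=128, out_dim=1, use_bias=False)          # explicit opt-out
```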