Add fanout for full-graph inference (#182)
*Issue #, if available:*

*Description of changes:*


By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.

---------

Co-authored-by: Israt Nisa <[email protected]>
Co-authored-by: Ubuntu <[email protected]>
Co-authored-by: Xiang Song <[email protected]>
4 people authored May 20, 2023
1 parent e497b50 commit 8b742be
Showing 19 changed files with 77 additions and 34 deletions.
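In short, the commit threads an explicit fanout through full-graph inference: the dataloaders gain a fanout property, do_full_graph_inference accepts a fanout argument, and dist_inference uses it per layer instead of always sampling all neighbors. The following is a minimal, illustrative sketch of the new calling pattern; the import paths and the train_data, target_idx, model, and tracker objects are assumed from a typical GraphStorm setup, not taken from this diff.

# Illustrative sketch only; object construction is assumed, not shown here.
from graphstorm.dataloading import GSgnnNodeDataLoader
from graphstorm.model import do_full_graph_inference

# The node dataloader now remembers its fanout (see dataloading.py below).
loader = GSgnnNodeDataLoader(train_data, target_idx, fanout=[15, 10],
                             batch_size=1024, device="cuda:0", train_task=False)

# New in this commit: inference code forwards the loader's fanout instead of
# hard-coding full-neighbor (-1) aggregation.
embs = do_full_graph_inference(model, loader.data, fanout=loader.fanout,
                               task_tracker=tracker)

# Passing fanout=None keeps the previous behavior (all neighbors per layer).
embs_full = do_full_graph_inference(model, loader.data, fanout=None,
                                    task_tracker=tracker)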
3 changes: 1 addition & 2 deletions examples/customized_models/HGT/hgt_nc.py
@@ -360,8 +360,7 @@ def main(args):
train_task=False)

# Run inference on the inference dataset and save the GNN embeddings in the specified path.
infer.infer(dataloader,
save_embed_path=config.save_embed_path,
infer.infer(dataloader, save_embed_path=config.save_embed_path,
save_prediction_path=config.save_prediction_path,
use_mini_batch_infer=True)

2 changes: 1 addition & 1 deletion python/graphstorm/config/argument.py
@@ -497,7 +497,7 @@ def eval_fanout(self):
return self._check_fanout(fanout, "Evaluation")
else:
# By default use -1 as full neighbor
return [-1] * len(self.fanout)
return [-1] * self.num_layers

@property
def hidden_size(self):
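The change above makes the default evaluation fanout depend on the number of GNN layers rather than on the length of the training fanout. A tiny standalone sketch of that fallback (plain Python, values made up):

# Default eval_fanout when the user does not set one: full-neighbor
# sampling (-1) for every GNN layer, regardless of the training fanout.
num_layers = 2
eval_fanout = [-1] * num_layers
assert eval_fanout == [-1, -1]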
32 changes: 31 additions & 1 deletion python/graphstorm/dataloading/dataloading.py
@@ -59,6 +59,7 @@ def __init__(self, dataset, target_idx, fanout, batch_size, device='cpu',
exclude_training_targets=False):
self._data = dataset
self._device = device
self._fanout = fanout
self._target_eidx = target_idx
if remove_target_edge_type:
assert reverse_edge_types_map is not None, \
@@ -121,6 +122,12 @@ def target_eidx(self):
"""
return self._target_eidx

@property
def fanout(self):
""" The fanout of each GNN layer
"""
return self._fanout

################ Minibatch DataLoader (Link Prediction) #######################

BUILTIN_LP_UNIFORM_NEG_SAMPLER = 'uniform'
@@ -163,6 +170,7 @@ def __init__(self, dataset, target_idx, fanout, batch_size, num_negative_edges,
train_task=True, reverse_edge_types_map=None, exclude_training_targets=False,
edge_mask_for_gnn_embeddings='train_mask'):
self._data = dataset
self._fanout = fanout
for etype in target_idx:
assert etype in dataset.g.canonical_etypes, \
"edge type {} does not exist in the graph".format(etype)
@@ -230,6 +238,12 @@ def data(self):
"""
return self._data

@property
def fanout(self):
""" The fanout of each GNN layer
"""
return self._fanout

class GSgnnLPJointNegDataLoader(GSgnnLinkPredictionDataLoader):
""" Link prediction dataloader with joint negative sampler
@@ -482,9 +496,12 @@ class GSgnnLinkPredictionTestDataLoader():
Batch size
num_negative_edges: int
The number of negative edges per positive edge
fanout: int
Evaluation fanout for computing node embedding
"""
def __init__(self, dataset, target_idx, batch_size, num_negative_edges):
def __init__(self, dataset, target_idx, batch_size, num_negative_edges, fanout=None):
self._data = dataset
self._fanout = fanout
for etype in target_idx:
assert etype in dataset.g.canonical_etypes, \
"edge type {} does not exist in the graph".format(etype)
@@ -534,6 +551,12 @@ def __next__(self):
# return pos, neg pairs
return cur_iter, self._neg_sample_type

@property
def fanout(self):
""" Get eval fanout
"""
return self._fanout

class GSgnnLinkPredictionJointTestDataLoader(GSgnnLinkPredictionTestDataLoader):
""" Link prediction minibatch dataloader for validation and test
with joint negative sampler
@@ -567,6 +590,7 @@ class GSgnnNodeDataLoader():
"""
def __init__(self, dataset, target_idx, fanout, batch_size, device, train_task=True):
self._data = dataset
self._fanout = fanout
self._target_nidx = target_idx
assert isinstance(target_idx, dict)
for ntype in target_idx:
@@ -605,3 +629,9 @@ def target_nidx(self):
""" The target node ids for prediction.
"""
return self._target_nidx

@property
def fanout(self):
""" The fanout of each GNN layer
"""
return self._fanout
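With the dataloaders now carrying their fanout, evaluation loaders can be built with the eval fanout and downstream code can simply read it back. A hedged sketch of the pattern used by the LP scripts later in this diff; the import path and the train_data and config objects are assumptions for illustration.

# Illustrative only; mirrors the gsgnn_lp.py changes further down.
from graphstorm.dataloading import GSgnnLinkPredictionTestDataLoader

val_dataloader = GSgnnLinkPredictionTestDataLoader(
    train_data, train_data.val_idxs,
    batch_size=config.eval_batch_size,
    num_negative_edges=config.num_negative_edges_eval,
    fanout=config.eval_fanout)      # new optional argument, defaults to None

# Inference/evaluation code can now recover the fanout from the loader.
assert val_dataloader.fanout == config.eval_fanout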
2 changes: 1 addition & 1 deletion python/graphstorm/inference/ep_infer.py
@@ -69,7 +69,7 @@ def infer(self, loader, save_embed_path, save_prediction_path=None,
do_eval = self.evaluator is not None
sys_tracker.check('start inferencing')
self._model.eval()
embs = do_full_graph_inference(self._model, loader.data,
embs = do_full_graph_inference(self._model, loader.data, fanout=loader.fanout,
task_tracker=self.task_tracker)
sys_tracker.check('compute embeddings')
res = edge_mini_batch_predict(self._model, embs, loader, return_label=do_eval)
2 changes: 1 addition & 1 deletion python/graphstorm/inference/lp_infer.py
@@ -69,7 +69,7 @@ def infer(self, data, loader, save_embed_path,
"""
sys_tracker.check('start inferencing')
self._model.eval()
embs = do_full_graph_inference(self._model, data,
embs = do_full_graph_inference(self._model, data, fanout=loader.fanout,
edge_mask=edge_mask_for_gnn_embeddings,
task_tracker=self.task_tracker)
sys_tracker.check('compute embeddings')
2 changes: 1 addition & 1 deletion python/graphstorm/inference/np_infer.py
@@ -83,7 +83,7 @@ def infer(self, loader, save_embed_path, save_prediction_path=None,

embs = {ntype: embs}
else:
embs = do_full_graph_inference(self._model, loader.data,
embs = do_full_graph_inference(self._model, loader.data, fanout=loader.fanout,
task_tracker=self.task_tracker)
res = node_mini_batch_predict(self._model, embs, loader, return_label=do_eval)
pred = res[0]
9 changes: 6 additions & 3 deletions python/graphstorm/model/gnn.py
@@ -602,7 +602,8 @@ def loss_func(self):
"""
return self._loss_fn

def do_full_graph_inference(model, data, batch_size=1024, edge_mask=None, task_tracker=None):
def do_full_graph_inference(model, data, batch_size=1024, fanout=None, edge_mask=None,
task_tracker=None):
""" Do fullgraph inference
It may use some of the edges indicated by `edge_mask` to compute GNN embeddings.
@@ -615,6 +616,8 @@ def do_full_graph_inference(model, data, batch_size=1024, edge_mask=None, task_t
The GraphStorm dataset
batch_size : int
The batch size for inferencing a GNN layer
fanout: list of int
The fanout for computing the GNN embeddings in a GNN layer.
edge_mask : str
The edge mask that indicates what edges are used to compute GNN embeddings.
task_tracker: GSTaskTrackerAbc
@@ -652,7 +655,7 @@ def get_input_embeds(input_nodes):
return {ntype: input_embeds[ntype][ids].to(device) \
for ntype, ids in input_nodes.items()}
embeddings = dist_inference(data.g, model.gnn_encoder, get_input_embeds,
batch_size, -1, edge_mask=edge_mask,
batch_size, fanout, edge_mask=edge_mask,
task_tracker=task_tracker)
model.train()
else:
@@ -666,7 +669,7 @@ def get_input_embeds(input_nodes):
feat_field=data.node_feat_field)
return model.node_input_encoder(feats, input_nodes)
embeddings = dist_inference(data.g, model.gnn_encoder, get_input_embeds,
batch_size, -1, edge_mask=edge_mask,
batch_size, fanout, edge_mask=edge_mask,
task_tracker=task_tracker)
model.train()
if get_rank() == 0:
6 changes: 3 additions & 3 deletions python/graphstorm/model/gnn_encoder_base.py
@@ -86,7 +86,7 @@ def dist_inference(g, gnn_encoder, get_input_embeds, batch_size, fanout,
The node features of the graph.
batch_size : int
The batch size for the GNN inference.
fanout : int
fanout : list of int
The fanout for computing the GNN embeddings in a GNN layer.
edge_mask : str
The edge mask indicates which edges are used to compute GNN embeddings.
@@ -118,7 +118,8 @@ def dist_inference(g, gnn_encoder, get_input_embeds, batch_size, fanout,
partition_book=g.get_partition_book(),
ntype=ntype, force_even=False)
# need to provide the fanout as a list, the number of layers is one obviously here
sampler = dgl.dataloading.MultiLayerNeighborSampler([fanout], mask=edge_mask)
fanout_i = [-1] if fanout is None or len(fanout) == 0 else [fanout[i]]
sampler = dgl.dataloading.MultiLayerNeighborSampler(fanout_i, mask=edge_mask)
dataloader = dgl.dataloading.DistNodeDataLoader(g, infer_nodes, sampler,
batch_size=batch_size,
shuffle=True,
@@ -128,7 +129,6 @@ def dist_inference(g, gnn_encoder, get_input_embeds, batch_size, fanout,
if task_tracker is not None:
task_tracker.keep_alive(report_step=iter_l)
block = blocks[0].to(device)

if not isinstance(input_nodes, dict):
# This happens on a homogeneous graph.
assert len(g.ntypes) == 1
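The new fanout_i line above selects a one-layer fanout for each GNN layer inside dist_inference. A small self-contained sketch of that selection logic (the helper name is ours, for illustration only):

# Per-layer fanout selection as added to dist_inference: use all neighbors
# (-1) when no fanout is given, otherwise the i-th layer's own fanout.
def layer_fanout(fanout, i):
    return [-1] if fanout is None or len(fanout) == 0 else [fanout[i]]

assert layer_fanout(None, 0) == [-1]        # no fanout -> full neighborhood
assert layer_fanout([], 0) == [-1]          # empty fanout -> full neighborhood
assert layer_fanout([15, 10], 1) == [10]    # layer 1 samples 10 neighbors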
3 changes: 2 additions & 1 deletion python/graphstorm/run/gsgnn_ep/gsgnn_ep.py
@@ -125,7 +125,8 @@ def main(args):
# The input layer can pre-compute node features in the preparing step if needed.
# For example pre-compute all BERT embeddings
model.prepare_input_encoder(train_data)
embeddings = do_full_graph_inference(model, train_data, task_tracker=tracker)
embeddings = do_full_graph_inference(model, train_data, fanout=config.eval_fanout,
task_tracker=tracker)
save_embeddings(config.save_embed_path, embeddings, gs.get_rank(),
th.distributed.get_world_size(),
device=device,
3 changes: 2 additions & 1 deletion python/graphstorm/run/gsgnn_ep/gsgnn_lm_ep.py
@@ -120,7 +120,8 @@ def main(args):
# The input layer can pre-compute node features in the preparing step if needed.
# For example pre-compute all BERT embeddings
model.prepare_input_encoder(train_data)
embeddings = do_full_graph_inference(model, train_data, task_tracker=tracker)
embeddings = do_full_graph_inference(model, train_data, fanout=config.eval_fanout,
task_tracker=tracker)
save_embeddings(config.save_embed_path, embeddings, gs.get_rank(),
th.distributed.get_world_size(),
device=device,
2 changes: 1 addition & 1 deletion python/graphstorm/run/gsgnn_lp/gsgnn_lm_lp.py
@@ -132,7 +132,7 @@ def main(args):
# For example pre-compute all BERT embeddings
model.prepare_input_encoder(train_data)
# TODO(zhengda) we may not want to only use training edges to generate GNN embeddings.
embeddings = do_full_graph_inference(model, train_data,
embeddings = do_full_graph_inference(model, train_data, fanout=config.eval_fanout,
edge_mask="train_mask", task_tracker=tracker)
save_embeddings(config.save_embed_path, embeddings, gs.get_rank(),
th.distributed.get_world_size(),
6 changes: 3 additions & 3 deletions python/graphstorm/run/gsgnn_lp/gsgnn_lp.py
@@ -102,9 +102,9 @@ def main(args):
'Supported test negative samplers include '
f'[{BUILTIN_LP_UNIFORM_NEG_SAMPLER}, {BUILTIN_LP_JOINT_NEG_SAMPLER}]')
val_dataloader = test_dataloader_cls(train_data, train_data.val_idxs,
config.eval_batch_size, config.num_negative_edges_eval)
config.eval_batch_size, config.num_negative_edges_eval, config.eval_fanout)
test_dataloader = test_dataloader_cls(train_data, train_data.test_idxs,
config.eval_batch_size, config.num_negative_edges_eval)
config.eval_batch_size, config.num_negative_edges_eval, config.eval_fanout)

# Preparing input layer for training or inference.
# The input layer can pre-compute node features in the preparing step if needed.
@@ -135,7 +135,7 @@ def main(args):
# For example pre-compute all BERT embeddings
model.prepare_input_encoder(train_data)
# TODO(zhengda) we may not want to only use training edges to generate GNN embeddings.
embeddings = do_full_graph_inference(model, train_data,
embeddings = do_full_graph_inference(model, train_data, fanout=config.eval_fanout,
edge_mask="train_mask", task_tracker=tracker)
save_embeddings(config.save_embed_path, embeddings, gs.get_rank(),
th.distributed.get_world_size(),
6 changes: 3 additions & 3 deletions python/graphstorm/run/gsgnn_lp/lp_infer_gnn.py
@@ -62,13 +62,13 @@ def main(args):

dataloader = test_dataloader_cls(infer_data, infer_data.test_idxs,
batch_size=config.eval_batch_size,
num_negative_edges=config.num_negative_edges_eval)
num_negative_edges=config.num_negative_edges_eval,
fanout=config.eval_fanout)
# Preparing input layer for training or inference.
# The input layer can pre-compute node features in the preparing step if needed.
# For example pre-compute all BERT embeddings
model.prepare_input_encoder(infer_data)
infer.infer(infer_data, dataloader,
save_embed_path=config.save_embed_path,
infer.infer(infer_data, dataloader, save_embed_path=config.save_embed_path,
node_id_mapping_file=config.node_id_mapping_file)

def generate_parser():
3 changes: 2 additions & 1 deletion python/graphstorm/run/gsgnn_np/gsgnn_np.py
@@ -116,7 +116,8 @@ def main(args):
# The input layer can pre-compute node features in the preparing step if needed.
# For example pre-compute all BERT embeddings
model.prepare_input_encoder(train_data)
embeddings = do_full_graph_inference(model, train_data, task_tracker=tracker)
embeddings = do_full_graph_inference(model, train_data, fanout=config.eval_fanout,
task_tracker=tracker)
save_embeddings(config.save_embed_path, embeddings, gs.get_rank(),
th.distributed.get_world_size(),
device=device,
7 changes: 4 additions & 3 deletions python/graphstorm/trainer/ep_trainer.py
@@ -162,7 +162,7 @@ def fit(self, train_loader, num_epochs,
if self.evaluator is not None and \
self.evaluator.do_eval(total_steps, epoch_end=False):
val_score = self.eval(model.module, val_loader, test_loader,
use_mini_batch_infer, total_steps)
use_mini_batch_infer, total_steps)

if self.evaluator.do_early_stop(val_score):
early_stop = True
@@ -196,7 +196,7 @@ def fit(self, train_loader, num_epochs,
val_score = None
if self.evaluator is not None and self.evaluator.do_eval(total_steps, epoch_end=True):
val_score = self.eval(model.module, val_loader, test_loader, use_mini_batch_infer,
total_steps)
total_steps)

if self.evaluator.do_early_stop(val_score):
early_stop = True
@@ -255,7 +255,8 @@ def eval(self, model, val_loader, test_loader, use_mini_batch_infer, total_steps
test_pred, test_label = edge_mini_batch_gnn_predict(model, test_loader,
return_label=True)
else:
emb = do_full_graph_inference(model, val_loader.data, task_tracker=self.task_tracker)
emb = do_full_graph_inference(model, val_loader.data, fanout=val_loader.fanout,
task_tracker=self.task_tracker)
val_pred, val_label = edge_mini_batch_predict(model, emb, val_loader,
return_label=True)
test_pred, test_label = edge_mini_batch_predict(model, emb, test_loader,
7 changes: 3 additions & 4 deletions python/graphstorm/trainer/lp_trainer.py
@@ -148,7 +148,6 @@ def fit(self, train_loader, num_epochs,
back_time += (time.time() - t3)

self.log_metric("Train loss", loss.item(), total_steps)

if i % 20 == 0 and self.rank == 0:
print("Epoch {:05d} | Batch {:03d} | Train Loss: {:.4f} | Time: {:.4f}".
format(epoch, i, loss.item(), time.time() - batch_tic))
@@ -223,8 +222,8 @@ def fit(self, train_loader, num_epochs,
self.save_model_results_to_file(self.evaluator.best_test_score,
save_perf_results_path)

def eval(self, model, data, val_loader, test_loader,
total_steps, edge_mask_for_gnn_embeddings):
def eval(self, model, data, val_loader, test_loader, total_steps,
edge_mask_for_gnn_embeddings):
""" do the model evaluation using validation and test sets
Parameters
@@ -249,7 +248,7 @@ def eval(self, model, data, val_loader, test_loader,
test_start = time.time()
sys_tracker.check('before prediction')
model.eval()
emb = do_full_graph_inference(model, data,
emb = do_full_graph_inference(model, data, fanout=val_loader.fanout,
edge_mask=edge_mask_for_gnn_embeddings,
task_tracker=self.task_tracker)
sys_tracker.check('compute embeddings')
7 changes: 4 additions & 3 deletions python/graphstorm/trainer/np_trainer.py
@@ -186,8 +186,8 @@ def fit(self, train_loader, num_epochs,

val_score = None
if self.evaluator is not None and self.evaluator.do_eval(total_steps, epoch_end=True):
val_score = self.eval(model.module, val_loader, test_loader,
use_mini_batch_infer, total_steps)
val_score = self.eval(model.module, val_loader, test_loader, use_mini_batch_infer,
total_steps)
if self.evaluator.do_early_stop(val_score):
early_stop = True

@@ -238,7 +238,8 @@ def eval(self, model, val_loader, test_loader, use_mini_batch_infer, total_steps
test_pred, _, test_label = node_mini_batch_gnn_predict(model, test_loader,
return_label=True)
else:
emb = do_full_graph_inference(model, val_loader.data, task_tracker=self.task_tracker)
emb = do_full_graph_inference(model, val_loader.data, fanout=val_loader.fanout,
task_tracker=self.task_tracker)
val_pred, val_label = node_mini_batch_predict(model, emb, val_loader,
return_label=True)
test_pred, test_label = node_mini_batch_predict(model, emb, test_loader,
1 change: 0 additions & 1 deletion tests/unit-tests/test_config.py
@@ -1351,7 +1351,6 @@ def test_gnn_info():
check_failure(config, "hidden_size") # lm model may not need hidden size
assert config.use_mini_batch_infer == True
check_failure(config, "fanout") # fanout must be provided if used
check_failure(config, "eval_fanout")

args = Namespace(yaml_config_file=os.path.join(Path(tmpdirname), 'gnn_test_error1.yaml'),
local_rank=0)
8 changes: 8 additions & 0 deletions tests/unit-tests/test_gnn.py
@@ -116,6 +116,14 @@ def require_cache_embed(self):
assert_almost_equal(embs[ntype][0:len(embs[ntype])].numpy(),
embs2[ntype][0:len(embs2[ntype])].numpy())

embs3 = do_full_graph_inference(model, data, fanout=None)
embs4 = do_full_graph_inference(model, data, fanout=[-1, -1])
assert len(embs3) == len(embs4)
for ntype in embs3:
assert ntype in embs4
assert_almost_equal(embs3[ntype][0:len(embs3[ntype])].numpy(),
embs4[ntype][0:len(embs4[ntype])].numpy())

target_nidx = {"n1": th.arange(g.number_of_nodes("n0"))}
dataloader1 = GSgnnNodeDataLoader(data, target_nidx, fanout=[],
batch_size=10, device="cuda:0", train_task=False)
