diff --git a/examples/customized_models/HGT/hgt_nc.py b/examples/customized_models/HGT/hgt_nc.py
index 1400be8856..9418102f5d 100644
--- a/examples/customized_models/HGT/hgt_nc.py
+++ b/examples/customized_models/HGT/hgt_nc.py
@@ -360,8 +360,7 @@ def main(args):
                                      train_task=False)

     # Run inference on the inference dataset and save the GNN embeddings in the specified path.
-    infer.infer(dataloader,
-                save_embed_path=config.save_embed_path,
+    infer.infer(dataloader, save_embed_path=config.save_embed_path,
                 save_prediction_path=config.save_prediction_path,
                 use_mini_batch_infer=True)

diff --git a/python/graphstorm/config/argument.py b/python/graphstorm/config/argument.py
index 07d6913d49..7e06563b80 100644
--- a/python/graphstorm/config/argument.py
+++ b/python/graphstorm/config/argument.py
@@ -497,7 +497,7 @@ def eval_fanout(self):
             return self._check_fanout(fanout, "Evaluation")
         else:
             # By default use -1 as full neighbor
-            return [-1] * len(self.fanout)
+            return [-1] * self.num_layers

     @property
     def hidden_size(self):
diff --git a/python/graphstorm/dataloading/dataloading.py b/python/graphstorm/dataloading/dataloading.py
index c27803c467..20298eb228 100644
--- a/python/graphstorm/dataloading/dataloading.py
+++ b/python/graphstorm/dataloading/dataloading.py
@@ -59,6 +59,7 @@ def __init__(self, dataset, target_idx, fanout, batch_size, device='cpu',
                  exclude_training_targets=False):
         self._data = dataset
         self._device = device
+        self._fanout = fanout
         self._target_eidx = target_idx
         if remove_target_edge_type:
             assert reverse_edge_types_map is not None, \
@@ -121,6 +122,12 @@ def target_eidx(self):
         """
         return self._target_eidx

+    @property
+    def fanout(self):
+        """ The fanout of each GNN layer.
+        """
+        return self._fanout
+
 ################ Minibatch DataLoader (Link Prediction) #######################

 BUILTIN_LP_UNIFORM_NEG_SAMPLER = 'uniform'
@@ -163,6 +170,7 @@ def __init__(self, dataset, target_idx, fanout, batch_size, num_negative_edges,
                  train_task=True, reverse_edge_types_map=None, exclude_training_targets=False,
                  edge_mask_for_gnn_embeddings='train_mask'):
         self._data = dataset
+        self._fanout = fanout
         for etype in target_idx:
             assert etype in dataset.g.canonical_etypes, \
                 "edge type {} does not exist in the graph".format(etype)
@@ -230,6 +238,12 @@ def data(self):
         """
         return self._data

+    @property
+    def fanout(self):
+        """ The fanout of each GNN layer.
+        """
+        return self._fanout
+
 class GSgnnLPJointNegDataLoader(GSgnnLinkPredictionDataLoader):
     """ Link prediction dataloader with joint negative sampler
@@ -482,9 +496,12 @@ class GSgnnLinkPredictionTestDataLoader():
         Batch size
     num_negative_edges: int
         The number of negative edges per positive edge
+    fanout: list of int
+        Evaluation fanout for computing node embeddings
     """
-    def __init__(self, dataset, target_idx, batch_size, num_negative_edges):
+    def __init__(self, dataset, target_idx, batch_size, num_negative_edges, fanout=None):
         self._data = dataset
+        self._fanout = fanout
         for etype in target_idx:
             assert etype in dataset.g.canonical_etypes, \
                 "edge type {} does not exist in the graph".format(etype)
@@ -534,6 +551,12 @@ def __next__(self):
         # return pos, neg pairs
         return cur_iter, self._neg_sample_type

+    @property
+    def fanout(self):
+        """ The evaluation fanout.
+        """
+        return self._fanout
+
 class GSgnnLinkPredictionJointTestDataLoader(GSgnnLinkPredictionTestDataLoader):
     """ Link prediction minibatch dataloader for validation and test
         with joint negative sampler
@@ -567,6 +590,7 @@ class GSgnnNodeDataLoader():
     """
     def __init__(self, dataset, target_idx, fanout, batch_size, device, train_task=True):
         self._data = dataset
+        self._fanout = fanout
         self._target_nidx = target_idx
         assert isinstance(target_idx, dict)
         for ntype in target_idx:
@@ -605,3 +629,9 @@ def target_nidx(self):
         """ The target node ids for prediction.
         """
         return self._target_nidx
+
+    @property
+    def fanout(self):
+        """ The fanout of each GNN layer.
+        """
+        return self._fanout
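Note: the `eval_fanout` default above previously derived its length from `self.fanout`, which need not be set in an evaluation-only config; deriving it from `num_layers` removes that dependency. A minimal, self-contained sketch of the corrected behavior (`_ConfigSketch` is an illustrative stand-in, not GraphStorm's API):

```python
# Hypothetical stand-in for the eval_fanout default; illustrative only.
class _ConfigSketch:
    def __init__(self, num_layers, eval_fanout=None):
        self.num_layers = num_layers
        self._eval_fanout = eval_fanout

    @property
    def eval_fanout(self):
        if self._eval_fanout is not None:
            return self._eval_fanout
        # One -1 ("use all neighbors") per GNN layer; no longer depends
        # on the training fanout being set.
        return [-1] * self.num_layers

assert _ConfigSketch(num_layers=2).eval_fanout == [-1, -1]
assert _ConfigSketch(num_layers=3, eval_fanout=[15, 10, 5]).eval_fanout == [15, 10, 5]
```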
diff --git a/python/graphstorm/inference/ep_infer.py b/python/graphstorm/inference/ep_infer.py
index 595ff19a39..df5a106767 100644
--- a/python/graphstorm/inference/ep_infer.py
+++ b/python/graphstorm/inference/ep_infer.py
@@ -69,7 +69,7 @@ def infer(self, loader, save_embed_path, save_prediction_path=None,
         do_eval = self.evaluator is not None
         sys_tracker.check('start inferencing')
         self._model.eval()
-        embs = do_full_graph_inference(self._model, loader.data,
+        embs = do_full_graph_inference(self._model, loader.data, fanout=loader.fanout,
                                        task_tracker=self.task_tracker)
         sys_tracker.check('compute embeddings')
         res = edge_mini_batch_predict(self._model, embs, loader, return_label=do_eval)
diff --git a/python/graphstorm/inference/lp_infer.py b/python/graphstorm/inference/lp_infer.py
index fb6bd099f7..17e4e44d19 100644
--- a/python/graphstorm/inference/lp_infer.py
+++ b/python/graphstorm/inference/lp_infer.py
@@ -69,7 +69,7 @@ def infer(self, data, loader, save_embed_path,
         """
         sys_tracker.check('start inferencing')
         self._model.eval()
-        embs = do_full_graph_inference(self._model, data,
+        embs = do_full_graph_inference(self._model, data, fanout=loader.fanout,
                                        edge_mask=edge_mask_for_gnn_embeddings,
                                        task_tracker=self.task_tracker)
         sys_tracker.check('compute embeddings')
diff --git a/python/graphstorm/inference/np_infer.py b/python/graphstorm/inference/np_infer.py
index 606bc12635..fbb4d469e3 100644
--- a/python/graphstorm/inference/np_infer.py
+++ b/python/graphstorm/inference/np_infer.py
@@ -83,7 +83,7 @@ def infer(self, loader, save_embed_path, save_prediction_path=None,
                 embs = {ntype: embs}

         else:
-            embs = do_full_graph_inference(self._model, loader.data,
+            embs = do_full_graph_inference(self._model, loader.data, fanout=loader.fanout,
                                            task_tracker=self.task_tracker)
             res = node_mini_batch_predict(self._model, embs, loader, return_label=do_eval)
             pred = res[0]
diff --git a/python/graphstorm/model/gnn.py b/python/graphstorm/model/gnn.py
index 513b97e665..f7c8cf4fd2 100644
--- a/python/graphstorm/model/gnn.py
+++ b/python/graphstorm/model/gnn.py
@@ -602,7 +602,8 @@ def loss_func(self):
         """
         return self._loss_fn

-def do_full_graph_inference(model, data, batch_size=1024, edge_mask=None, task_tracker=None):
+def do_full_graph_inference(model, data, batch_size=1024, fanout=None, edge_mask=None,
+                            task_tracker=None):
     """ Do fullgraph inference

     It may use some of the edges indicated by `edge_mask` to compute GNN embeddings.
@@ -615,6 +616,8 @@ def do_full_graph_inference(model, data, batch_size=1024, edge_mask=None, task_t
         The GraphStorm dataset
     batch_size : int
         The batch size for inferencing a GNN layer
+    fanout: list of int
+        The fanout for computing the GNN embeddings in a GNN layer.
     edge_mask : str
         The edge mask that indicates what edges are used to compute GNN embeddings.
     task_tracker: GSTaskTrackerAbc
@@ -652,7 +655,7 @@ def get_input_embeds(input_nodes):
             return {ntype: input_embeds[ntype][ids].to(device) \
                 for ntype, ids in input_nodes.items()}
         embeddings = dist_inference(data.g, model.gnn_encoder, get_input_embeds,
-                                    batch_size, -1, edge_mask=edge_mask,
+                                    batch_size, fanout, edge_mask=edge_mask,
                                     task_tracker=task_tracker)
         model.train()
     else:
@@ -666,7 +669,7 @@ def get_input_embeds(input_nodes):
                                 feat_field=data.node_feat_field)
             return model.node_input_encoder(feats, input_nodes)
         embeddings = dist_inference(data.g, model.gnn_encoder, get_input_embeds,
-                                    batch_size, -1, edge_mask=edge_mask,
+                                    batch_size, fanout, edge_mask=edge_mask,
                                     task_tracker=task_tracker)
         model.train()
     if get_rank() == 0:
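With the new keyword argument, callers control the neighborhood used when computing embeddings instead of the previously hard-coded full-neighbor `-1`. A usage sketch (assumes a trained `model` and a GraphStorm dataset `data` already exist; the import path mirrors the file touched above but is an assumption, not verified here):

```python
# Assumed import path, based on python/graphstorm/model/gnn.py above.
from graphstorm.model.gnn import do_full_graph_inference

# Old behavior: full-neighbor aggregation at every layer.
embs_full = do_full_graph_inference(model, data, fanout=None)

# New: sample at most 15 neighbors at layer 0 and 10 at layer 1,
# e.g. to match the dataloader's evaluation fanout.
embs_sampled = do_full_graph_inference(model, data, fanout=[15, 10])
```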
diff --git a/python/graphstorm/model/gnn_encoder_base.py b/python/graphstorm/model/gnn_encoder_base.py
index 07aed32d78..f0ede32f59 100644
--- a/python/graphstorm/model/gnn_encoder_base.py
+++ b/python/graphstorm/model/gnn_encoder_base.py
@@ -86,7 +86,7 @@ def dist_inference(g, gnn_encoder, get_input_embeds, batch_size, fanout,
         The node features of the graph.
     batch_size : int
         The batch size for the GNN inference.
-    fanout : int
+    fanout : list of int
         The fanout for computing the GNN embeddings in a GNN layer.
     edge_mask : str
         The edge mask indicates which edges are used to compute GNN embeddings.
@@ -118,7 +118,8 @@ def dist_inference(g, gnn_encoder, get_input_embeds, batch_size, fanout,
                                                  partition_book=g.get_partition_book(),
                                                  ntype=ntype, force_even=False)
             # need to provide the fanout as a list, the number of layers is one obviously here
-            sampler = dgl.dataloading.MultiLayerNeighborSampler([fanout], mask=edge_mask)
+            fanout_i = [-1] if fanout is None or len(fanout) == 0 else [fanout[i]]
+            sampler = dgl.dataloading.MultiLayerNeighborSampler(fanout_i, mask=edge_mask)
             dataloader = dgl.dataloading.DistNodeDataLoader(g, infer_nodes, sampler,
                                                             batch_size=batch_size,
                                                             shuffle=True,
@@ -128,7 +129,6 @@ def dist_inference(g, gnn_encoder, get_input_embeds, batch_size, fanout,
         if task_tracker is not None:
             task_tracker.keep_alive(report_step=iter_l)
         block = blocks[0].to(device)
-
         if not isinstance(input_nodes, dict):
             # This happens on a homogeneous graph.
             assert len(g.ntypes) == 1
diff --git a/python/graphstorm/run/gsgnn_ep/gsgnn_ep.py b/python/graphstorm/run/gsgnn_ep/gsgnn_ep.py
index 81c8bd6c59..6f63a57280 100644
--- a/python/graphstorm/run/gsgnn_ep/gsgnn_ep.py
+++ b/python/graphstorm/run/gsgnn_ep/gsgnn_ep.py
@@ -125,7 +125,8 @@ def main(args):
         # The input layer can pre-compute node features in the preparing step if needed.
         # For example pre-compute all BERT embeddings
         model.prepare_input_encoder(train_data)
-        embeddings = do_full_graph_inference(model, train_data, task_tracker=tracker)
+        embeddings = do_full_graph_inference(model, train_data, fanout=config.eval_fanout,
+                                             task_tracker=tracker)
         save_embeddings(config.save_embed_path, embeddings, gs.get_rank(),
                         th.distributed.get_world_size(),
                         device=device,
diff --git a/python/graphstorm/run/gsgnn_ep/gsgnn_lm_ep.py b/python/graphstorm/run/gsgnn_ep/gsgnn_lm_ep.py
index 06710e9f48..7edc5c702c 100644
--- a/python/graphstorm/run/gsgnn_ep/gsgnn_lm_ep.py
+++ b/python/graphstorm/run/gsgnn_ep/gsgnn_lm_ep.py
@@ -120,7 +120,8 @@ def main(args):
         # The input layer can pre-compute node features in the preparing step if needed.
         # For example pre-compute all BERT embeddings
         model.prepare_input_encoder(train_data)
-        embeddings = do_full_graph_inference(model, train_data, task_tracker=tracker)
+        embeddings = do_full_graph_inference(model, train_data, fanout=config.eval_fanout,
+                                             task_tracker=tracker)
         save_embeddings(config.save_embed_path, embeddings, gs.get_rank(),
                         th.distributed.get_world_size(),
                         device=device,
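The per-layer selection added to `dist_inference` above is the core of this change: layer `i` samples `fanout[i]` neighbors, and a missing or empty fanout falls back to full-neighbor aggregation. A pure-Python mimic of that logic (no DGL required; `per_layer_fanout` is an illustrative helper, not part of GraphStorm):

```python
def per_layer_fanout(fanout, num_layers):
    # Mirrors the new line: [-1] if fanout is None or len(fanout) == 0 else [fanout[i]]
    return [[-1] if fanout is None or len(fanout) == 0 else [fanout[i]]
            for i in range(num_layers)]

assert per_layer_fanout(None, 2) == [[-1], [-1]]      # default: full neighbors
assert per_layer_fanout([], 2) == [[-1], [-1]]        # empty list, same fallback
assert per_layer_fanout([15, 10], 2) == [[15], [10]]  # a one-layer sampler per layer
```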
diff --git a/python/graphstorm/run/gsgnn_lp/gsgnn_lm_lp.py b/python/graphstorm/run/gsgnn_lp/gsgnn_lm_lp.py
index ca19c9cc72..a7302e403c 100644
--- a/python/graphstorm/run/gsgnn_lp/gsgnn_lm_lp.py
+++ b/python/graphstorm/run/gsgnn_lp/gsgnn_lm_lp.py
@@ -132,7 +132,7 @@ def main(args):
         # For example pre-compute all BERT embeddings
         model.prepare_input_encoder(train_data)
         # TODO(zhengda) we may not want to only use training edges to generate GNN embeddings.
-        embeddings = do_full_graph_inference(model, train_data,
+        embeddings = do_full_graph_inference(model, train_data, fanout=config.eval_fanout,
                                              edge_mask="train_mask", task_tracker=tracker)
         save_embeddings(config.save_embed_path, embeddings, gs.get_rank(),
                         th.distributed.get_world_size(),
diff --git a/python/graphstorm/run/gsgnn_lp/gsgnn_lp.py b/python/graphstorm/run/gsgnn_lp/gsgnn_lp.py
index ee5f1ace83..475f824a37 100644
--- a/python/graphstorm/run/gsgnn_lp/gsgnn_lp.py
+++ b/python/graphstorm/run/gsgnn_lp/gsgnn_lp.py
@@ -102,9 +102,9 @@ def main(args):
                          'Supported test negative samplers include '
                          f'[{BUILTIN_LP_UNIFORM_NEG_SAMPLER}, {BUILTIN_LP_JOINT_NEG_SAMPLER}]')
         val_dataloader = test_dataloader_cls(train_data, train_data.val_idxs,
-            config.eval_batch_size, config.num_negative_edges_eval)
+            config.eval_batch_size, config.num_negative_edges_eval, config.eval_fanout)
         test_dataloader = test_dataloader_cls(train_data, train_data.test_idxs,
-            config.eval_batch_size, config.num_negative_edges_eval)
+            config.eval_batch_size, config.num_negative_edges_eval, config.eval_fanout)

     # Preparing input layer for training or inference.
     # The input layer can pre-compute node features in the preparing step if needed.
@@ -135,7 +135,7 @@ def main(args):
         # For example pre-compute all BERT embeddings
         model.prepare_input_encoder(train_data)
         # TODO(zhengda) we may not want to only use training edges to generate GNN embeddings.
-        embeddings = do_full_graph_inference(model, train_data,
+        embeddings = do_full_graph_inference(model, train_data, fanout=config.eval_fanout,
                                              edge_mask="train_mask", task_tracker=tracker)
         save_embeddings(config.save_embed_path, embeddings, gs.get_rank(),
                         th.distributed.get_world_size(),
diff --git a/python/graphstorm/run/gsgnn_lp/lp_infer_gnn.py b/python/graphstorm/run/gsgnn_lp/lp_infer_gnn.py
index 0bc0938e75..70b40116fd 100644
--- a/python/graphstorm/run/gsgnn_lp/lp_infer_gnn.py
+++ b/python/graphstorm/run/gsgnn_lp/lp_infer_gnn.py
@@ -62,13 +62,13 @@ def main(args):
     dataloader = test_dataloader_cls(infer_data, infer_data.test_idxs,
                                      batch_size=config.eval_batch_size,
-                                     num_negative_edges=config.num_negative_edges_eval)
+                                     num_negative_edges=config.num_negative_edges_eval,
+                                     fanout=config.eval_fanout)

     # Preparing input layer for training or inference.
     # The input layer can pre-compute node features in the preparing step if needed.
     # For example pre-compute all BERT embeddings
     model.prepare_input_encoder(infer_data)
-    infer.infer(infer_data, dataloader,
-                save_embed_path=config.save_embed_path,
+    infer.infer(infer_data, dataloader, save_embed_path=config.save_embed_path,
                 node_id_mapping_file=config.node_id_mapping_file)

 def generate_parser():
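Both the training and inference scripts now pass the evaluation fanout into the link-prediction test dataloaders, which expose it through the new `fanout` property. A wiring sketch (assumes `train_data` and `config` objects as in `gsgnn_lp.py`; the import path is an assumption, not verified):

```python
# Assumed import path, based on python/graphstorm/dataloading/dataloading.py above.
from graphstorm.dataloading import GSgnnLinkPredictionTestDataLoader

val_dataloader = GSgnnLinkPredictionTestDataLoader(
    train_data, train_data.val_idxs,
    batch_size=config.eval_batch_size,
    num_negative_edges=config.num_negative_edges_eval,
    fanout=config.eval_fanout)

# Downstream inference code reads the fanout back from the loader:
assert val_dataloader.fanout == config.eval_fanout
```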
diff --git a/python/graphstorm/run/gsgnn_np/gsgnn_np.py b/python/graphstorm/run/gsgnn_np/gsgnn_np.py
index 5b1ef3c00e..dc8dd8e98c 100644
--- a/python/graphstorm/run/gsgnn_np/gsgnn_np.py
+++ b/python/graphstorm/run/gsgnn_np/gsgnn_np.py
@@ -116,7 +116,8 @@ def main(args):
         # The input layer can pre-compute node features in the preparing step if needed.
         # For example pre-compute all BERT embeddings
         model.prepare_input_encoder(train_data)
-        embeddings = do_full_graph_inference(model, train_data, task_tracker=tracker)
+        embeddings = do_full_graph_inference(model, train_data, fanout=config.eval_fanout,
+                                             task_tracker=tracker)
         save_embeddings(config.save_embed_path, embeddings, gs.get_rank(),
                         th.distributed.get_world_size(),
                         device=device,
diff --git a/python/graphstorm/trainer/ep_trainer.py b/python/graphstorm/trainer/ep_trainer.py
index 57f7208c04..aa4c5ee07d 100644
--- a/python/graphstorm/trainer/ep_trainer.py
+++ b/python/graphstorm/trainer/ep_trainer.py
@@ -162,7 +162,7 @@ def fit(self, train_loader, num_epochs,
                 if self.evaluator is not None and \
                     self.evaluator.do_eval(total_steps, epoch_end=False):
                     val_score = self.eval(model.module, val_loader, test_loader,
-                                        use_mini_batch_infer, total_steps)
+                                          use_mini_batch_infer, total_steps)

                     if self.evaluator.do_early_stop(val_score):
                         early_stop = True
@@ -196,7 +196,7 @@ def fit(self, train_loader, num_epochs,

             val_score = None
             if self.evaluator is not None and self.evaluator.do_eval(total_steps, epoch_end=True):
                 val_score = self.eval(model.module, val_loader, test_loader, use_mini_batch_infer,
-                                    total_steps)
+                                      total_steps)
                 if self.evaluator.do_early_stop(val_score):
                     early_stop = True
@@ -255,7 +255,8 @@ def eval(self, model, val_loader, test_loader, use_mini_batch_infer, total_steps
             test_pred, test_label = edge_mini_batch_gnn_predict(model, test_loader,
                                                                 return_label=True)
         else:
-            emb = do_full_graph_inference(model, val_loader.data, task_tracker=self.task_tracker)
+            emb = do_full_graph_inference(model, val_loader.data, fanout=val_loader.fanout,
+                                          task_tracker=self.task_tracker)
             val_pred, val_label = edge_mini_batch_predict(model, emb, val_loader,
                                                           return_label=True)
             test_pred, test_label = edge_mini_batch_predict(model, emb, test_loader,
diff --git a/python/graphstorm/trainer/lp_trainer.py b/python/graphstorm/trainer/lp_trainer.py
index 1f1b2999ac..edadb154b5 100644
--- a/python/graphstorm/trainer/lp_trainer.py
+++ b/python/graphstorm/trainer/lp_trainer.py
@@ -148,7 +148,6 @@ def fit(self, train_loader, num_epochs,
                 back_time += (time.time() - t3)

                 self.log_metric("Train loss", loss.item(), total_steps)
-
                 if i % 20 == 0 and self.rank == 0:
                     print("Epoch {:05d} | Batch {:03d} | Train Loss: {:.4f} | Time: {:.4f}".
                           format(epoch, i, loss.item(), time.time() - batch_tic))
@@ -223,8 +222,8 @@ def fit(self, train_loader, num_epochs,
             self.save_model_results_to_file(self.evaluator.best_test_score,
                                             save_perf_results_path)

-    def eval(self, model, data, val_loader, test_loader,
-             total_steps, edge_mask_for_gnn_embeddings):
+    def eval(self, model, data, val_loader, test_loader, total_steps,
+             edge_mask_for_gnn_embeddings):
         """ do the model evaluation using validiation and test sets

         Parameters
@@ -249,7 +248,7 @@ def eval(self, model, data, val_loader, test_loader,
         test_start = time.time()
         sys_tracker.check('before prediction')
         model.eval()
-        emb = do_full_graph_inference(model, data,
+        emb = do_full_graph_inference(model, data, fanout=val_loader.fanout,
                                       edge_mask=edge_mask_for_gnn_embeddings,
                                       task_tracker=self.task_tracker)
         sys_tracker.check('compute embeddings')
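The trainer `eval` paths above now forward the validation loader's fanout instead of silently using full neighbors, so validation scores reflect the configured evaluation sampling. A self-contained behavioral sketch with stand-in classes (none of these names are GraphStorm's API):

```python
class _LoaderSketch:
    """Stand-in for a GraphStorm dataloader carrying data and fanout."""
    def __init__(self, data, fanout):
        self.data = data
        self._fanout = fanout

    @property
    def fanout(self):
        return self._fanout

def _full_graph_inference(data, fanout=None):
    # Stand-in for do_full_graph_inference; just records the fanout used.
    return {"data": data, "fanout": fanout}

val_loader = _LoaderSketch(data="dataset", fanout=[20, 20])
emb = _full_graph_inference(val_loader.data, fanout=val_loader.fanout)
assert emb["fanout"] == [20, 20]
```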
diff --git a/python/graphstorm/trainer/np_trainer.py b/python/graphstorm/trainer/np_trainer.py
index 83b98cfd60..5565d71e69 100644
--- a/python/graphstorm/trainer/np_trainer.py
+++ b/python/graphstorm/trainer/np_trainer.py
@@ -186,8 +186,8 @@ def fit(self, train_loader, num_epochs,

             val_score = None
             if self.evaluator is not None and self.evaluator.do_eval(total_steps, epoch_end=True):
-                val_score = self.eval(model.module, val_loader, test_loader,
-                                      use_mini_batch_infer, total_steps)
+                val_score = self.eval(model.module, val_loader, test_loader, use_mini_batch_infer,
+                                      total_steps)
                 if self.evaluator.do_early_stop(val_score):
                     early_stop = True
@@ -238,7 +238,8 @@ def eval(self, model, val_loader, test_loader, use_mini_batch_infer, total_steps
             test_pred, _, test_label = node_mini_batch_gnn_predict(model, test_loader,
                                                                    return_label=True)
         else:
-            emb = do_full_graph_inference(model, val_loader.data, task_tracker=self.task_tracker)
+            emb = do_full_graph_inference(model, val_loader.data, fanout=val_loader.fanout,
+                                          task_tracker=self.task_tracker)
             val_pred, val_label = node_mini_batch_predict(model, emb, val_loader,
                                                           return_label=True)
             test_pred, test_label = node_mini_batch_predict(model, emb, test_loader,
diff --git a/tests/unit-tests/test_config.py b/tests/unit-tests/test_config.py
index 79d59459c7..46c093e4c6 100644
--- a/tests/unit-tests/test_config.py
+++ b/tests/unit-tests/test_config.py
@@ -1351,7 +1351,6 @@ def test_gnn_info():
         check_failure(config, "hidden_size") # lm model may not need hidden size
         assert config.use_mini_batch_infer == True
         check_failure(config, "fanout") # fanout must be provided if used
-        check_failure(config, "eval_fanout")

         args = Namespace(yaml_config_file=os.path.join(Path(tmpdirname), 'gnn_test_error1.yaml'),
                          local_rank=0)
diff --git a/tests/unit-tests/test_gnn.py b/tests/unit-tests/test_gnn.py
index 71f86f64d2..0ee175346f 100644
--- a/tests/unit-tests/test_gnn.py
+++ b/tests/unit-tests/test_gnn.py
@@ -116,6 +116,14 @@ def require_cache_embed(self):
         assert_almost_equal(embs[ntype][0:len(embs[ntype])].numpy(),
                             embs2[ntype][0:len(embs2[ntype])].numpy())

+    embs3 = do_full_graph_inference(model, data, fanout=None)
+    embs4 = do_full_graph_inference(model, data, fanout=[-1, -1])
+    assert len(embs3) == len(embs4)
+    for ntype in embs3:
+        assert ntype in embs4
+        assert_almost_equal(embs3[ntype][0:len(embs3[ntype])].numpy(),
+                            embs4[ntype][0:len(embs4[ntype])].numpy())
+
     target_nidx = {"n1": th.arange(g.number_of_nodes("n0"))}
     dataloader1 = GSgnnNodeDataLoader(data, target_nidx, fanout=[],
                                       batch_size=10, device="cuda:0", train_task=False)