From 6c318461bd93873b444e10e0cb1b7510017bb917 Mon Sep 17 00:00:00 2001 From: "xiang song(charlie.song)" Date: Sat, 20 May 2023 16:26:29 -0700 Subject: [PATCH] [BugFix ]Fix emb shuffle bug in node_infer.py (#185) *Issue #, if available:* *Description of changes:* By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice. --------- Co-authored-by: Xiang Song --- python/graphstorm/gconstruct/utils.py | 6 +-- python/graphstorm/inference/ep_infer.py | 41 ++++++++------- python/graphstorm/inference/np_infer.py | 48 +++++++++++------ python/graphstorm/model/utils.py | 25 +++++++-- python/graphstorm/run/gsgnn_ep/gsgnn_ep.py | 10 +++- tests/end2end-tests/check_infer.py | 13 ++++- .../end2end-tests/graphstorm-ec/mgpu_test.sh | 2 +- .../end2end-tests/graphstorm-nc/mgpu_test.sh | 9 ++++ tests/unit-tests/test_utils.py | 51 +++++++++++++++---- tools/partition_graph.py | 23 +++++++-- training_scripts/gsgnn_lp/arxiv_lp.yaml | 45 ++++++++++++++++ 11 files changed, 216 insertions(+), 57 deletions(-) create mode 100644 training_scripts/gsgnn_lp/arxiv_lp.yaml diff --git a/python/graphstorm/gconstruct/utils.py b/python/graphstorm/gconstruct/utils.py index 6f9405e07b..7630a183bd 100644 --- a/python/graphstorm/gconstruct/utils.py +++ b/python/graphstorm/gconstruct/utils.py @@ -227,7 +227,7 @@ def __call__(self, arrs, name): em_arr[:] = arr[:] return em_arr -def _save_maps(output_dir, fname, map_data): +def save_maps(output_dir, fname, map_data): """ Save node id mapping or edge id mapping Parameters @@ -329,7 +329,7 @@ def partition_graph(g, node_data, edge_data, graph_name, num_partitions, output_ # the new_node_mapping contains per entity type on the ith row # the original node id for the ith node. - _save_maps(output_dir, "node_mapping", new_node_mapping) + save_maps(output_dir, "node_mapping", new_node_mapping) # the new_edge_mapping contains per edge type on the ith row # the original edge id for the ith edge. - _save_maps(output_dir, "edge_mapping", new_edge_mapping) + save_maps(output_dir, "edge_mapping", new_edge_mapping) diff --git a/python/graphstorm/inference/ep_infer.py b/python/graphstorm/inference/ep_infer.py index df5a106767..8aa74abf24 100644 --- a/python/graphstorm/inference/ep_infer.py +++ b/python/graphstorm/inference/ep_infer.py @@ -83,12 +83,25 @@ def infer(self, loader, save_embed_path, save_prediction_path=None, assert len(infer_data.eval_etypes) == 1, \ "GraphStorm only support single target edge type for training and inference" - target_ntypes = set() - for etype in infer_data.eval_etypes: - target_ntypes.add(etype[0]) - target_ntypes.add(etype[2]) - embs = {ntype: embs[ntype] for ntype in target_ntypes} + # do evaluation first + if do_eval: + test_start = time.time() + val_score, test_score = self.evaluator.evaluate(pred, pred, label, label, 0) + sys_tracker.check('run evaluation') + if self.rank == 0: + self.log_print_metrics(val_score=val_score, + test_score=test_score, + dur_eval=time.time() - test_start, + total_steps=0) + if save_embed_path is not None: + target_ntypes = set() + for etype in infer_data.eval_etypes: + target_ntypes.add(etype[0]) + target_ntypes.add(etype[2]) + + # The order of the ntypes must be sorted + embs = {ntype: embs[ntype] for ntype in sorted(target_ntypes)} device = th.device(f"cuda:{self.dev_id}") \ if self.dev_id >= 0 else th.device("cpu") save_gsgnn_embeddings(save_embed_path, embs, self.rank, @@ -102,27 +115,19 @@ def infer(self, loader, save_embed_path, save_prediction_path=None, if edge_id_mapping_file is not None: g = loader.data.g etype = infer_data.eval_etypes[0] - pred_data = DistTensor((g.num_edges(etype), pred.shape[1]), + pred_shape = list(pred.shape) + pred_shape[0] = g.num_edges(etype) + pred_data = DistTensor(pred_shape, dtype=pred.dtype, name='predict-'+'-'.join(etype), part_policy=g.get_edge_partition_policy(etype), # TODO: this makes the tensor persistent in memory. persistent=True) # edges that have predictions may be just a subset of the # entire edge set. - pred_data[loader.target_eidx] = pred + pred_data[loader.target_eidx[etype]] = pred.cpu() - pred = shuffle_predict(pred_data, edge_id_mapping_file, self.rank, + pred = shuffle_predict(pred_data, edge_id_mapping_file, etype, self.rank, th.distributed.get_world_size(), device=device) save_prediction_results(pred, save_prediction_path, self.rank) th.distributed.barrier() sys_tracker.check('save predictions') - - if do_eval: - test_start = time.time() - val_score, test_score = self.evaluator.evaluate(pred, pred, label, label, 0) - sys_tracker.check('run evaluation') - if self.rank == 0: - self.log_print_metrics(val_score=val_score, - test_score=test_score, - dur_eval=time.time() - test_start, - total_steps=0) diff --git a/python/graphstorm/inference/np_infer.py b/python/graphstorm/inference/np_infer.py index fbb4d469e3..c5122136a4 100644 --- a/python/graphstorm/inference/np_infer.py +++ b/python/graphstorm/inference/np_infer.py @@ -90,8 +90,33 @@ def infer(self, loader, save_embed_path, save_prediction_path=None, label = res[1] if do_eval else None sys_tracker.check('compute embeddings') - embeddings = {ntype: embs[ntype]} + # do evaluation first + # do evaluation if any + if do_eval: + test_start = time.time() + val_score, test_score = self.evaluator.evaluate(pred, pred, label, label, 0) + sys_tracker.check('run evaluation') + if self.rank == 0: + self.log_print_metrics(val_score=val_score, + test_score=test_score, + dur_eval=time.time() - test_start, + total_steps=0) + if save_embed_path is not None: + if use_mini_batch_infer: + g = loader.data.g + ntype_emb = DistTensor((g.num_nodes(ntype), embs[ntype].shape[1]), + dtype=embs[ntype].dtype, name=f'gen-emb-{ntype}', + part_policy=g.get_node_partition_policy(ntype), + # TODO: this makes the tensor persistent in memory. + persistent=True) + # nodes that do prediction in mini-batch may be just a subset of the + # entire node set. + ntype_emb[loader.target_nidx[ntype]] = embs[ntype] + else: + ntype_emb = embs[ntype] + embeddings = {ntype: ntype_emb} + device = th.device(f"cuda:{self.dev_id}") \ if self.dev_id >= 0 else th.device("cpu") save_gsgnn_embeddings(save_embed_path, @@ -99,34 +124,25 @@ def infer(self, loader, save_embed_path, save_prediction_path=None, device=device, node_id_mapping_file=node_id_mapping_file) th.distributed.barrier() - sys_tracker.check('save embeddings') + sys_tracker.check('save embeddings') if save_prediction_path is not None: # shuffle pred results according to node_id_mapping_file if node_id_mapping_file is not None: g = loader.data.g - pred_data = DistTensor((g.num_nodes(ntype), pred.shape[1]), + pred_shape = list(pred.shape) + pred_shape[0] = g.num_nodes(ntype) + pred_data = DistTensor(pred_shape, dtype=pred.dtype, name=f'predict-{ntype}', part_policy=g.get_node_partition_policy(ntype), # TODO: this makes the tensor persistent in memory. persistent=True) # nodes that have predictions may be just a subset of the # entire node set. - pred_data[loader.target_nidx] = pred - pred = shuffle_predict(pred_data, node_id_mapping_file, self.rank, + pred_data[loader.target_nidx[ntype]] = pred.cpu() + pred = shuffle_predict(pred_data, node_id_mapping_file, ntype, self.rank, th.distributed.get_world_size(), device=device) save_prediction_results(pred, save_prediction_path, self.rank) th.distributed.barrier() sys_tracker.check('save predictions') - - # do evaluation if any - if do_eval: - test_start = time.time() - val_score, test_score = self.evaluator.evaluate(pred, pred, label, label, 0) - sys_tracker.check('run evaluation') - if self.rank == 0: - self.log_print_metrics(val_score=val_score, - test_score=test_score, - dur_eval=time.time() - test_start, - total_steps=0) diff --git a/python/graphstorm/model/utils.py b/python/graphstorm/model/utils.py index 8531dd6740..612e7f3f7e 100644 --- a/python/graphstorm/model/utils.py +++ b/python/graphstorm/model/utils.py @@ -288,7 +288,6 @@ def _exchange_node_id_mapping(local_rank, world_size, device, start_idx, end_idx = _get_data_range(i, world_size, num_embs) data_tensors.append( node_id_mapping[start_idx:end_idx].to(device)) - else: data_tensors = [th.empty((0,), dtype=th.long, @@ -413,13 +412,33 @@ def save_embeddings(model_path, embeddings, local_rank, world_size, with open(os.path.join(model_path, "emb_info.json"), 'w', encoding='utf-8') as f: f.write(json.dumps(emb_info)) -def shuffle_predict(predictions, id_mapping_file, +def shuffle_predict(predictions, id_mapping_file, pred_type, local_rank, world_size, device): """ Shuffle prediction result according to id_mapping + + Parameters + ---------- + predictions: dgl DistTensor + prediction results + id_mapping_file: str + Path to the file storing node id mapping or edge id mapping generated by the + graph partition algorithm. + pred_type: str or tuple + Node type or edge type of the prediction target. + local_rank : int + Local rank + world_size : int + World size in a distributed env. + device : torch device + Device used to do data shuffling. """ id_mapping = th.load(id_mapping_file) if local_rank == 0 else None + # In most of cases, id_mapping is a dict for heterogeneous graph. + # For homogeneous graph, it is just a tensor. + id_mapping = id_mapping[pred_type] if isinstance(id_mapping, dict) else id_mapping local_id_mapping = _exchange_node_id_mapping( - local_rank, world_size, device, id_mapping, len(predictions)) + local_rank, world_size, device, id_mapping, + len(predictions)).cpu() # predictions are stored in CPU return predictions[local_id_mapping] def save_prediction_results(predictions, prediction_path, local_rank): diff --git a/python/graphstorm/run/gsgnn_ep/gsgnn_ep.py b/python/graphstorm/run/gsgnn_ep/gsgnn_ep.py index 6f63a57280..c8d6488495 100644 --- a/python/graphstorm/run/gsgnn_ep/gsgnn_ep.py +++ b/python/graphstorm/run/gsgnn_ep/gsgnn_ep.py @@ -127,7 +127,15 @@ def main(args): model.prepare_input_encoder(train_data) embeddings = do_full_graph_inference(model, train_data, fanout=config.eval_fanout, task_tracker=tracker) - save_embeddings(config.save_embed_path, embeddings, gs.get_rank(), + # only save node embeddings of nodes with node types from target_etype + target_ntypes = set() + for etype in config.target_etype: + target_ntypes.add(etype[0]) + target_ntypes.add(etype[2]) + + # The order of the ntypes must be sorted + embs = {ntype: embeddings[ntype] for ntype in sorted(target_ntypes)} + save_embeddings(config.save_embed_path, embs, gs.get_rank(), th.distributed.get_world_size(), device=device, node_id_mapping_file=config.node_id_mapping_file) diff --git a/tests/end2end-tests/check_infer.py b/tests/end2end-tests/check_infer.py index 5e15c5b0e2..810c251f2d 100644 --- a/tests/end2end-tests/check_infer.py +++ b/tests/end2end-tests/check_infer.py @@ -29,6 +29,8 @@ help="Path to embedding saved by trainer") argparser.add_argument("--link_prediction", action='store_true', help="Path to embedding saved by trainer") + argparser.add_argument("--mini-batch-infer", action='store_true', + help="Inference use minibatch inference.") args = argparser.parse_args() with open(os.path.join(args.train_embout, "emb_info.json"), 'r', encoding='utf-8') as f: train_emb_info = json.load(f) @@ -69,4 +71,13 @@ assert train_emb.shape[0] == infer_emb.shape[0] assert train_emb.shape[1] == infer_emb.shape[1] - assert_almost_equal(train_emb.numpy(), infer_emb.numpy(), decimal=2) + + if args.mini_batch_infer: + # When inference is done with minibatch inference, only node + # embeddings of the test set are computed. + for i in range(len(train_emb)): + if th.all(infer_emb[i] == 0.): + continue + assert_almost_equal(train_emb[i].numpy(), infer_emb[i].numpy(), decimal=4) + else: + assert_almost_equal(train_emb.numpy(), infer_emb.numpy(), decimal=2) diff --git a/tests/end2end-tests/graphstorm-ec/mgpu_test.sh b/tests/end2end-tests/graphstorm-ec/mgpu_test.sh index b3fa6ba237..08c9c650e8 100644 --- a/tests/end2end-tests/graphstorm-ec/mgpu_test.sh +++ b/tests/end2end-tests/graphstorm-ec/mgpu_test.sh @@ -75,7 +75,7 @@ best_epoch=$(grep "successfully save the model to" train_log.txt | tail -1 | tr echo "The best model is saved in epoch $best_epoch" echo "**************dataset: Generated multilabel MovieLens EC, do inference on saved model" -python3 -m graphstorm.run.gs_edge_classification --inference --workspace $GS_HOME/inference_scripts/ep_infer --num-trainers $NUM_INFO_TRAINERS --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_multi_label_ec/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_ec_infer.yaml --multilabel true --num-classes 6 --node-feat-name feat --use-mini-batch-infer false --save-embed-path /data/gsgnn_ec/infer-emb/ --restore-model-path /data/gsgnn_ec/epoch-$best_epoch/ | tee log.txt +python3 -m graphstorm.run.gs_edge_classification --inference --workspace $GS_HOME/inference_scripts/ep_infer --num-trainers $NUM_INFO_TRAINERS --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_multi_label_ec/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_ec_infer.yaml --multilabel true --num-classes 6 --node-feat-name feat --use-mini-batch-infer false --save-embed-path /data/gsgnn_ec/infer-emb/ --restore-model-path /data/gsgnn_ec/epoch-$best_epoch/ --save-prediction-path /data/gsgnn_ec/prediction/ | tee log.txt error_and_exit ${PIPESTATUS[0]} diff --git a/tests/end2end-tests/graphstorm-nc/mgpu_test.sh b/tests/end2end-tests/graphstorm-nc/mgpu_test.sh index f22fa2d07c..bcf14bfb4c 100644 --- a/tests/end2end-tests/graphstorm-nc/mgpu_test.sh +++ b/tests/end2end-tests/graphstorm-nc/mgpu_test.sh @@ -133,6 +133,15 @@ python3 $GS_HOME/tests/end2end-tests/check_infer.py --train_embout /data/gsgnn_n error_and_exit $? +echo "**************dataset: Movielens, do inference on saved model with mini-batch-infer" +python3 -m graphstorm.run.gs_node_classification --inference --workspace $GS_HOME/inference_scripts/np_infer/ --num-trainers $NUM_INFERs --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_train_val_1p_4t/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_nc_infer.yaml --use-mini-batch-infer true --save-embed-path /data/gsgnn_nc_ml/mini-infer-emb/ --restore-model-path /data/gsgnn_nc_ml/epoch-$best_epoch/ --save-prediction-path /data/gsgnn_nc_ml/prediction/ | tee log.txt + +error_and_exit ${PIPESTATUS[0]} + +python3 $GS_HOME/tests/end2end-tests/check_infer.py --train_embout /data/gsgnn_nc_ml/emb/ --infer_embout /data/gsgnn_nc_ml/mini-infer-emb --mini-batch-infer + +error_and_exit $? + echo "**************dataset: Movielens, do inference on saved model, decoder: dot with a single process" python3 -m graphstorm.run.gs_node_classification --inference --workspace $GS_HOME/inference_scripts/np_infer/ --num-trainers 1 --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_train_val_1p_4t/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_nc_infer.yaml --use-mini-batch-infer false --save-embed-path /data/gsgnn_nc_ml/infer-emb-1p/ --restore-model-path /data/gsgnn_nc_ml/epoch-$best_epoch/ --save-prediction-path /data/gsgnn_nc_ml/prediction-1p/ | tee log.txt diff --git a/tests/unit-tests/test_utils.py b/tests/unit-tests/test_utils.py index 63dd0495b1..f32d80451f 100644 --- a/tests/unit-tests/test_utils.py +++ b/tests/unit-tests/test_utils.py @@ -25,7 +25,7 @@ from graphstorm.model.utils import _get_data_range from graphstorm.model.utils import _exchange_node_id_mapping from graphstorm.model.utils import shuffle_predict -from graphstorm.gconstruct.utils import _save_maps +from graphstorm.gconstruct.utils import save_maps from graphstorm import get_feat_size from data_utils import generate_dummy_dist_graph @@ -171,7 +171,7 @@ def run_dist_save_embeddings(model_path, emb, worker_rank, th.distributed.destroy_process_group() def run_dist_shuffle_predict(pred, worker_rank, - world_size, node_id_mapping_file, backend, conn): + world_size, node_id_mapping_file, type, backend, conn): dist_init_method = 'tcp://{master_ip}:{master_port}'.format( master_ip='127.0.0.1', master_port='12345') th.distributed.init_process_group(backend=backend, @@ -181,7 +181,7 @@ def run_dist_shuffle_predict(pred, worker_rank, th.cuda.set_device(worker_rank) device = 'cuda:%d' % worker_rank - pred = shuffle_predict(pred, node_id_mapping_file, worker_rank, world_size, device) + pred = shuffle_predict(pred, node_id_mapping_file, type, worker_rank, world_size, device) conn.send(pred.detach().cpu().numpy()) if worker_rank == 0: @@ -194,18 +194,18 @@ def run_dist_shuffle_predict(pred, worker_rank, def test_shuffle_predict(num_embs, backend): import tempfile - # single embedding + # node_mapping is tensor with tempfile.TemporaryDirectory() as tmpdirname: pred, nid_mapping = gen_predict_with_nid_mapping(num_embs) - _save_maps(tmpdirname, "node_mapping", nid_mapping) + save_maps(tmpdirname, "node_mapping", nid_mapping) nid_mapping_file = os.path.join(tmpdirname, "node_mapping.pt") ctx = mp.get_context('spawn') conn1, conn2 = mp.Pipe() p0 = ctx.Process(target=run_dist_shuffle_predict, - args=(pred, 0, 2, nid_mapping_file, backend, conn2)) + args=(pred, 0, 2, nid_mapping_file, None, backend, conn2)) conn3, conn4 = mp.Pipe() p1 = ctx.Process(target=run_dist_shuffle_predict, - args=(pred, 1, 2, nid_mapping_file, backend, conn4)) + args=(pred, 1, 2, nid_mapping_file, None, backend, conn4)) p0.start() p1.start() @@ -226,6 +226,39 @@ def test_shuffle_predict(num_embs, backend): # Load saved embeddings assert_equal(pred[nid_mapping].numpy(), shuffled_pred) + # node mapping is a dict + with tempfile.TemporaryDirectory() as tmpdirname: + pred, nid_mapping = gen_predict_with_nid_mapping(num_embs) + nid_mapping = {"node": nid_mapping} + save_maps(tmpdirname, "node_mapping", nid_mapping) + nid_mapping_file = os.path.join(tmpdirname, "node_mapping.pt") + ctx = mp.get_context('spawn') + conn1, conn2 = mp.Pipe() + p0 = ctx.Process(target=run_dist_shuffle_predict, + args=(pred, 0, 2, nid_mapping_file, "node", backend, conn2)) + conn3, conn4 = mp.Pipe() + p1 = ctx.Process(target=run_dist_shuffle_predict, + args=(pred, 1, 2, nid_mapping_file, "node", backend, conn4)) + + p0.start() + p1.start() + p0.join() + p1.join() + assert p0.exitcode == 0 + assert p1.exitcode == 0 + + shuffled_pred_1 = conn1.recv() + shuffled_pred_2 = conn3.recv() + conn1.close() + conn2.close() + conn3.close() + conn4.close() + + shuffled_pred = np.concatenate([shuffled_pred_1, shuffled_pred_2]) + + # Load saved embeddings + assert_equal(pred[nid_mapping["node"]].numpy(), shuffled_pred) + # TODO: Only test gloo now # Will add test for nccl once we enable nccl @pytest.mark.parametrize("num_embs", [16, 17]) @@ -236,7 +269,7 @@ def test_save_embeddings_with_id_mapping(num_embs, backend): # single embedding with tempfile.TemporaryDirectory() as tmpdirname: emb, nid_mapping = gen_embedding_with_nid_mapping(num_embs) - _save_maps(tmpdirname, "node_mapping", nid_mapping) + save_maps(tmpdirname, "node_mapping", nid_mapping) nid_mapping_file = os.path.join(tmpdirname, "node_mapping.pt") ctx = mp.get_context('spawn') p0 = ctx.Process(target=run_dist_save_embeddings, @@ -272,7 +305,7 @@ def test_save_embeddings_with_id_mapping(num_embs, backend): embs['n2'] = emb nid_mappings['n2'] = nid_mapping - _save_maps(tmpdirname, "node_mapping", nid_mappings) + save_maps(tmpdirname, "node_mapping", nid_mappings) nid_mapping_file = os.path.join(tmpdirname, "node_mapping.pt") ctx = mp.get_context('spawn') p0 = ctx.Process(target=run_dist_save_embeddings, diff --git a/tools/partition_graph.py b/tools/partition_graph.py index 4750f9c347..252133ead5 100644 --- a/tools/partition_graph.py +++ b/tools/partition_graph.py @@ -25,6 +25,7 @@ from graphstorm.data import OGBTextFeatDataset from graphstorm.data import MovieLens100kNCDataset from graphstorm.data import ConstructedGraphDataset +from graphstorm.gconstruct.utils import save_maps if __name__ == '__main__': argparser = argparse.ArgumentParser("Partition DGL graphs for node and edge classification " @@ -297,8 +298,20 @@ else: balance_ntypes = None - dgl.distributed.partition_graph(g, args.dataset, args.num_parts, args.output, - part_method=args.part_method, - balance_ntypes=balance_ntypes, - balance_edges=args.balance_edges, - num_trainers_per_machine=args.num_trainers_per_machine) + mapping = \ + dgl.distributed.partition_graph(g, args.dataset, args.num_parts, args.output, + part_method=args.part_method, + balance_ntypes=balance_ntypes, + balance_edges=args.balance_edges, + num_trainers_per_machine=args.num_trainers_per_machine, + return_mapping=True) + + new_node_mapping, new_edge_mapping = mapping + + # the new_node_mapping contains per entity type on the ith row + # the original node id for the ith node. + save_maps(args.output, "node_mapping", new_node_mapping) + # the new_edge_mapping contains per edge type on the ith row + # the original edge id for the ith edge. + save_maps(args.output, "edge_mapping", new_edge_mapping) + diff --git a/training_scripts/gsgnn_lp/arxiv_lp.yaml b/training_scripts/gsgnn_lp/arxiv_lp.yaml new file mode 100644 index 0000000000..30916cb666 --- /dev/null +++ b/training_scripts/gsgnn_lp/arxiv_lp.yaml @@ -0,0 +1,45 @@ +--- +version: 1.0 +gsf: + basic: + model_encoder_type: rgcn + backend: gloo + num_gpus: 4 + verbose: false + gnn: + fanout: "3" + n_layers: 1 + n_hidden: 128 + mini_batch_infer: true + input: + restore_model_path: null + output: + save_model_path: null + save_embed_path: null + save_model_per_iters: 1000 + hyperparam: + dropout: 0. + lr: 0.001 + bert_tune_lr: 0.0001 + num_epochs: 3 + batch_size: 128 + eval_batch_size: 1024 + bert_infer_bs: 128 + wd_l2norm: 0 + no_validation: false + rgcn: + n_bases: -1 + use_self_loop: true + lp_decoder_type: "dot_product" + sparse_optimizer_lr: 1e-2 + use_node_embeddings: false + link_prediction: + num_negative_edges: 4 + num_negative_edges_eval: 100 + train_negative_sampler: joint + eval_etype: + - "node,interacts,node" + train_etype: + - "node,interacts,node" + exclude_training_targets: false + reverse_edge_types_map: [] \ No newline at end of file