[BugFix] Fix emb shuffle bug in node_infer.py (#185)
*Issue #, if available:*

*Description of changes:*


By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.

---------

Co-authored-by: Xiang Song <[email protected]>
classicsong and Xiang Song authored May 20, 2023
1 parent 8b742be commit 6c31846
Showing 11 changed files with 216 additions and 57 deletions.
6 changes: 3 additions & 3 deletions python/graphstorm/gconstruct/utils.py
@@ -227,7 +227,7 @@ def __call__(self, arrs, name):
em_arr[:] = arr[:]
return em_arr

def _save_maps(output_dir, fname, map_data):
def save_maps(output_dir, fname, map_data):
""" Save node id mapping or edge id mapping
Parameters
@@ -329,7 +329,7 @@ def partition_graph(g, node_data, edge_data, graph_name, num_partitions, output_

# the new_node_mapping contains per entity type on the ith row
# the original node id for the ith node.
_save_maps(output_dir, "node_mapping", new_node_mapping)
save_maps(output_dir, "node_mapping", new_node_mapping)
# the new_edge_mapping contains per edge type on the ith row
# the original edge id for the ith edge.
_save_maps(output_dir, "edge_mapping", new_edge_mapping)
save_maps(output_dir, "edge_mapping", new_edge_mapping)
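
For context, a minimal sketch of what `save_maps` presumably does, assumed from the way the unit tests later load `node_mapping.pt` with `th.load`; the rename from `_save_maps` to `save_maps` simply makes the helper part of the module's public surface so other modules and tests can reuse it.

```python
import os
import torch as th

def save_maps_sketch(output_dir, fname, map_data):
    """Hypothetical sketch: persist a node or edge id mapping as <fname>.pt.

    map_data is a tensor for a homogeneous graph, or a dict of tensors
    keyed by node/edge type for a heterogeneous graph.
    """
    os.makedirs(output_dir, exist_ok=True)
    th.save(map_data, os.path.join(output_dir, fname + ".pt"))
```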
41 changes: 23 additions & 18 deletions python/graphstorm/inference/ep_infer.py
@@ -83,12 +83,25 @@ def infer(self, loader, save_embed_path, save_prediction_path=None,
assert len(infer_data.eval_etypes) == 1, \
"GraphStorm only support single target edge type for training and inference"

target_ntypes = set()
for etype in infer_data.eval_etypes:
target_ntypes.add(etype[0])
target_ntypes.add(etype[2])
embs = {ntype: embs[ntype] for ntype in target_ntypes}
# do evaluation first
if do_eval:
test_start = time.time()
val_score, test_score = self.evaluator.evaluate(pred, pred, label, label, 0)
sys_tracker.check('run evaluation')
if self.rank == 0:
self.log_print_metrics(val_score=val_score,
test_score=test_score,
dur_eval=time.time() - test_start,
total_steps=0)

if save_embed_path is not None:
target_ntypes = set()
for etype in infer_data.eval_etypes:
target_ntypes.add(etype[0])
target_ntypes.add(etype[2])

# The order of the ntypes must be sorted
embs = {ntype: embs[ntype] for ntype in sorted(target_ntypes)}
device = th.device(f"cuda:{self.dev_id}") \
if self.dev_id >= 0 else th.device("cpu")
save_gsgnn_embeddings(save_embed_path, embs, self.rank,
@@ -102,27 +115,19 @@ def infer(self, loader, save_embed_path, save_prediction_path=None,
if edge_id_mapping_file is not None:
g = loader.data.g
etype = infer_data.eval_etypes[0]
pred_data = DistTensor((g.num_edges(etype), pred.shape[1]),
pred_shape = list(pred.shape)
pred_shape[0] = g.num_edges(etype)
pred_data = DistTensor(pred_shape,
dtype=pred.dtype, name='predict-'+'-'.join(etype),
part_policy=g.get_edge_partition_policy(etype),
# TODO: this makes the tensor persistent in memory.
persistent=True)
# edges that have predictions may be just a subset of the
# entire edge set.
pred_data[loader.target_eidx] = pred
pred_data[loader.target_eidx[etype]] = pred.cpu()

pred = shuffle_predict(pred_data, edge_id_mapping_file, self.rank,
pred = shuffle_predict(pred_data, edge_id_mapping_file, etype, self.rank,
th.distributed.get_world_size(), device=device)
save_prediction_results(pred, save_prediction_path, self.rank)
th.distributed.barrier()
sys_tracker.check('save predictions')

if do_eval:
test_start = time.time()
val_score, test_score = self.evaluator.evaluate(pred, pred, label, label, 0)
sys_tracker.check('run evaluation')
if self.rank == 0:
self.log_print_metrics(val_score=val_score,
test_score=test_score,
dur_eval=time.time() - test_start,
total_steps=0)
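
Condensed from the hunk above, a sketch of the edge-prediction path after this change; the names follow the diff, while the `DistTensor` signature is assumed from DGL's distributed API. Predictions covering only the target edges are scattered into a tensor sized for every edge of the single evaluation edge type before being shuffled with the edge id mapping.

```python
from dgl.distributed import DistTensor

def gather_edge_predictions(g, etype, pred, target_eidx):
    # pred covers only the edges in target_eidx, so allocate a DistTensor
    # that spans all edges of this canonical edge type.
    pred_shape = list(pred.shape)
    pred_shape[0] = g.num_edges(etype)
    pred_data = DistTensor(pred_shape,
                           dtype=pred.dtype,
                           name='predict-' + '-'.join(etype),
                           part_policy=g.get_edge_partition_policy(etype),
                           persistent=True)
    # Index by edge type and move the partial predictions to CPU, since
    # the distributed tensor lives in host memory.
    pred_data[target_eidx[etype]] = pred.cpu()
    return pred_data
```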
48 changes: 32 additions & 16 deletions python/graphstorm/inference/np_infer.py
@@ -90,43 +90,59 @@ def infer(self, loader, save_embed_path, save_prediction_path=None,
label = res[1] if do_eval else None
sys_tracker.check('compute embeddings')

embeddings = {ntype: embs[ntype]}
# do evaluation first
# do evaluation if any
if do_eval:
test_start = time.time()
val_score, test_score = self.evaluator.evaluate(pred, pred, label, label, 0)
sys_tracker.check('run evaluation')
if self.rank == 0:
self.log_print_metrics(val_score=val_score,
test_score=test_score,
dur_eval=time.time() - test_start,
total_steps=0)

if save_embed_path is not None:
if use_mini_batch_infer:
g = loader.data.g
ntype_emb = DistTensor((g.num_nodes(ntype), embs[ntype].shape[1]),
dtype=embs[ntype].dtype, name=f'gen-emb-{ntype}',
part_policy=g.get_node_partition_policy(ntype),
# TODO: this makes the tensor persistent in memory.
persistent=True)
# nodes that do prediction in mini-batch may be just a subset of the
# entire node set.
ntype_emb[loader.target_nidx[ntype]] = embs[ntype]
else:
ntype_emb = embs[ntype]
embeddings = {ntype: ntype_emb}

device = th.device(f"cuda:{self.dev_id}") \
if self.dev_id >= 0 else th.device("cpu")
save_gsgnn_embeddings(save_embed_path,
embeddings, self.rank, th.distributed.get_world_size(),
device=device,
node_id_mapping_file=node_id_mapping_file)
th.distributed.barrier()
sys_tracker.check('save embeddings')
sys_tracker.check('save embeddings')

if save_prediction_path is not None:
# shuffle pred results according to node_id_mapping_file
if node_id_mapping_file is not None:
g = loader.data.g

pred_data = DistTensor((g.num_nodes(ntype), pred.shape[1]),
pred_shape = list(pred.shape)
pred_shape[0] = g.num_nodes(ntype)
pred_data = DistTensor(pred_shape,
dtype=pred.dtype, name=f'predict-{ntype}',
part_policy=g.get_node_partition_policy(ntype),
# TODO: this makes the tensor persistent in memory.
persistent=True)
# nodes that have predictions may be just a subset of the
# entire node set.
pred_data[loader.target_nidx] = pred
pred = shuffle_predict(pred_data, node_id_mapping_file, self.rank,
pred_data[loader.target_nidx[ntype]] = pred.cpu()
pred = shuffle_predict(pred_data, node_id_mapping_file, ntype, self.rank,
th.distributed.get_world_size(), device=device)
save_prediction_results(pred, save_prediction_path, self.rank)
th.distributed.barrier()
sys_tracker.check('save predictions')

# do evaluation if any
if do_eval:
test_start = time.time()
val_score, test_score = self.evaluator.evaluate(pred, pred, label, label, 0)
sys_tracker.check('run evaluation')
if self.rank == 0:
self.log_print_metrics(val_score=val_score,
test_score=test_score,
dur_eval=time.time() - test_start,
total_steps=0)
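
The node-inference side follows the same pattern. A sketch of the mini-batch branch introduced above, with a hypothetical helper name and the argument names taken from the diff: only the target nodes have embeddings, so they are written into a graph-sized DistTensor and the untouched rows keep their initial zeros, which the updated check_infer.py relies on.

```python
from dgl.distributed import DistTensor

def gather_minibatch_embeddings(g, ntype, emb, target_nidx):
    # Mini-batch inference only computes embeddings for the target nodes,
    # so place them into a tensor that covers every node of this type.
    ntype_emb = DistTensor((g.num_nodes(ntype), emb.shape[1]),
                           dtype=emb.dtype,
                           name=f'gen-emb-{ntype}',
                           part_policy=g.get_node_partition_policy(ntype),
                           persistent=True)
    ntype_emb[target_nidx[ntype]] = emb
    return {ntype: ntype_emb}
```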
25 changes: 22 additions & 3 deletions python/graphstorm/model/utils.py
@@ -288,7 +288,6 @@ def _exchange_node_id_mapping(local_rank, world_size, device,
start_idx, end_idx = _get_data_range(i, world_size, num_embs)
data_tensors.append(
node_id_mapping[start_idx:end_idx].to(device))

else:
data_tensors = [th.empty((0,),
dtype=th.long,
@@ -413,13 +412,33 @@ def save_embeddings(model_path, embeddings, local_rank, world_size,
with open(os.path.join(model_path, "emb_info.json"), 'w', encoding='utf-8') as f:
f.write(json.dumps(emb_info))

def shuffle_predict(predictions, id_mapping_file,
def shuffle_predict(predictions, id_mapping_file, pred_type,
local_rank, world_size, device):
""" Shuffle prediction result according to id_mapping
Parameters
----------
predictions: dgl DistTensor
prediction results
id_mapping_file: str
Path to the file storing node id mapping or edge id mapping generated by the
graph partition algorithm.
pred_type: str or tuple
Node type or edge type of the prediction target.
local_rank : int
Local rank
world_size : int
World size in a distributed env.
device : torch device
Device used to do data shuffling.
"""
id_mapping = th.load(id_mapping_file) if local_rank == 0 else None
# In most of cases, id_mapping is a dict for heterogeneous graph.
# For homogeneous graph, it is just a tensor.
id_mapping = id_mapping[pred_type] if isinstance(id_mapping, dict) else id_mapping
local_id_mapping = _exchange_node_id_mapping(
local_rank, world_size, device, id_mapping, len(predictions))
local_rank, world_size, device, id_mapping,
len(predictions)).cpu() # predictions are stored in CPU
return predictions[local_id_mapping]

def save_prediction_results(predictions, prediction_path, local_rank):
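
A single-process sketch of the remapping that `shuffle_predict` performs, with the distributed exchange through `_exchange_node_id_mapping` elided and a hypothetical helper name; the new `pred_type` argument is what lets the function index into the per-type mapping dict produced for heterogeneous graphs.

```python
import torch as th

def shuffle_predict_sketch(predictions, id_mapping_file, pred_type):
    # The mapping saved at partition time is a dict keyed by node/edge type
    # for heterogeneous graphs, or a plain tensor for homogeneous graphs.
    id_mapping = th.load(id_mapping_file)
    if isinstance(id_mapping, dict):
        id_mapping = id_mapping[pred_type]
    # Reorder prediction rows according to the id mapping so the saved
    # results line up with the ids used before partitioning.
    return predictions[id_mapping]
```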
10 changes: 9 additions & 1 deletion python/graphstorm/run/gsgnn_ep/gsgnn_ep.py
@@ -127,7 +127,15 @@ def main(args):
model.prepare_input_encoder(train_data)
embeddings = do_full_graph_inference(model, train_data, fanout=config.eval_fanout,
task_tracker=tracker)
save_embeddings(config.save_embed_path, embeddings, gs.get_rank(),
# only save node embeddings of nodes with node types from target_etype
target_ntypes = set()
for etype in config.target_etype:
target_ntypes.add(etype[0])
target_ntypes.add(etype[2])

# The order of the ntypes must be sorted
embs = {ntype: embeddings[ntype] for ntype in sorted(target_ntypes)}
save_embeddings(config.save_embed_path, embs, gs.get_rank(),
th.distributed.get_world_size(),
device=device,
node_id_mapping_file=config.node_id_mapping_file)
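
The same fix appears in the training script. Shown in isolation as a sketch using the names from the diff: only node types that occur as the source or destination of a target edge type are saved, and `sorted()` gives every rank the same iteration order when the embeddings are written out.

```python
def select_target_embeddings(embeddings, target_etypes):
    # Canonical edge types are (src_ntype, relation, dst_ntype) tuples;
    # keep the node types on either end of each target edge type.
    target_ntypes = set()
    for etype in target_etypes:
        target_ntypes.add(etype[0])
        target_ntypes.add(etype[2])
    # Sorting makes the dict order deterministic across ranks.
    return {ntype: embeddings[ntype] for ntype in sorted(target_ntypes)}
```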
13 changes: 12 additions & 1 deletion tests/end2end-tests/check_infer.py
@@ -29,6 +29,8 @@
help="Path to embedding saved by trainer")
argparser.add_argument("--link_prediction", action='store_true',
help="Path to embedding saved by trainer")
argparser.add_argument("--mini-batch-infer", action='store_true',
help="Inference use minibatch inference.")
args = argparser.parse_args()
with open(os.path.join(args.train_embout, "emb_info.json"), 'r', encoding='utf-8') as f:
train_emb_info = json.load(f)
@@ -69,4 +71,13 @@

assert train_emb.shape[0] == infer_emb.shape[0]
assert train_emb.shape[1] == infer_emb.shape[1]
assert_almost_equal(train_emb.numpy(), infer_emb.numpy(), decimal=2)

if args.mini_batch_infer:
# When inference is done with minibatch inference, only node
# embeddings of the test set are computed.
for i in range(len(train_emb)):
if th.all(infer_emb[i] == 0.):
continue
assert_almost_equal(train_emb[i].numpy(), infer_emb[i].numpy(), decimal=4)
else:
assert_almost_equal(train_emb.numpy(), infer_emb.numpy(), decimal=2)
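
For reference, the updated comparison logic as a standalone sketch with a hypothetical helper name: under mini-batch inference only the test nodes receive embeddings, so all-zero rows are skipped and the remaining rows are compared against the full-graph training embeddings.

```python
import numpy as np
from numpy.testing import assert_almost_equal

def compare_embeddings(train_emb, infer_emb, mini_batch_infer):
    assert train_emb.shape == infer_emb.shape
    if mini_batch_infer:
        # Rows never written during mini-batch inference remain all-zero;
        # only rows that hold real embeddings are compared.
        for train_row, infer_row in zip(train_emb, infer_emb):
            if np.all(infer_row == 0.):
                continue
            assert_almost_equal(train_row, infer_row, decimal=4)
    else:
        assert_almost_equal(train_emb, infer_emb, decimal=2)
```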
2 changes: 1 addition & 1 deletion tests/end2end-tests/graphstorm-ec/mgpu_test.sh
@@ -75,7 +75,7 @@ best_epoch=$(grep "successfully save the model to" train_log.txt | tail -1 | tr
echo "The best model is saved in epoch $best_epoch"

echo "**************dataset: Generated multilabel MovieLens EC, do inference on saved model"
python3 -m graphstorm.run.gs_edge_classification --inference --workspace $GS_HOME/inference_scripts/ep_infer --num-trainers $NUM_INFO_TRAINERS --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_multi_label_ec/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_ec_infer.yaml --multilabel true --num-classes 6 --node-feat-name feat --use-mini-batch-infer false --save-embed-path /data/gsgnn_ec/infer-emb/ --restore-model-path /data/gsgnn_ec/epoch-$best_epoch/ | tee log.txt
python3 -m graphstorm.run.gs_edge_classification --inference --workspace $GS_HOME/inference_scripts/ep_infer --num-trainers $NUM_INFO_TRAINERS --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_multi_label_ec/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_ec_infer.yaml --multilabel true --num-classes 6 --node-feat-name feat --use-mini-batch-infer false --save-embed-path /data/gsgnn_ec/infer-emb/ --restore-model-path /data/gsgnn_ec/epoch-$best_epoch/ --save-prediction-path /data/gsgnn_ec/prediction/ | tee log.txt

error_and_exit ${PIPESTATUS[0]}

9 changes: 9 additions & 0 deletions tests/end2end-tests/graphstorm-nc/mgpu_test.sh
@@ -133,6 +133,15 @@ python3 $GS_HOME/tests/end2end-tests/check_infer.py --train_embout /data/gsgnn_n

error_and_exit $?

echo "**************dataset: Movielens, do inference on saved model with mini-batch-infer"
python3 -m graphstorm.run.gs_node_classification --inference --workspace $GS_HOME/inference_scripts/np_infer/ --num-trainers $NUM_INFERs --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_train_val_1p_4t/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_nc_infer.yaml --use-mini-batch-infer true --save-embed-path /data/gsgnn_nc_ml/mini-infer-emb/ --restore-model-path /data/gsgnn_nc_ml/epoch-$best_epoch/ --save-prediction-path /data/gsgnn_nc_ml/prediction/ | tee log.txt

error_and_exit ${PIPESTATUS[0]}

python3 $GS_HOME/tests/end2end-tests/check_infer.py --train_embout /data/gsgnn_nc_ml/emb/ --infer_embout /data/gsgnn_nc_ml/mini-infer-emb --mini-batch-infer

error_and_exit $?

echo "**************dataset: Movielens, do inference on saved model, decoder: dot with a single process"
python3 -m graphstorm.run.gs_node_classification --inference --workspace $GS_HOME/inference_scripts/np_infer/ --num-trainers 1 --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_train_val_1p_4t/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_nc_infer.yaml --use-mini-batch-infer false --save-embed-path /data/gsgnn_nc_ml/infer-emb-1p/ --restore-model-path /data/gsgnn_nc_ml/epoch-$best_epoch/ --save-prediction-path /data/gsgnn_nc_ml/prediction-1p/ | tee log.txt

51 changes: 42 additions & 9 deletions tests/unit-tests/test_utils.py
@@ -25,7 +25,7 @@
from graphstorm.model.utils import _get_data_range
from graphstorm.model.utils import _exchange_node_id_mapping
from graphstorm.model.utils import shuffle_predict
from graphstorm.gconstruct.utils import _save_maps
from graphstorm.gconstruct.utils import save_maps
from graphstorm import get_feat_size

from data_utils import generate_dummy_dist_graph
@@ -171,7 +171,7 @@ def run_dist_save_embeddings(model_path, emb, worker_rank,
th.distributed.destroy_process_group()

def run_dist_shuffle_predict(pred, worker_rank,
world_size, node_id_mapping_file, backend, conn):
world_size, node_id_mapping_file, type, backend, conn):
dist_init_method = 'tcp://{master_ip}:{master_port}'.format(
master_ip='127.0.0.1', master_port='12345')
th.distributed.init_process_group(backend=backend,
@@ -181,7 +181,7 @@ def run_dist_shuffle_predict(pred, worker_rank,
th.cuda.set_device(worker_rank)
device = 'cuda:%d' % worker_rank

pred = shuffle_predict(pred, node_id_mapping_file, worker_rank, world_size, device)
pred = shuffle_predict(pred, node_id_mapping_file, type, worker_rank, world_size, device)
conn.send(pred.detach().cpu().numpy())

if worker_rank == 0:
@@ -194,18 +194,18 @@
def test_shuffle_predict(num_embs, backend):
import tempfile

# single embedding
# node_mapping is tensor
with tempfile.TemporaryDirectory() as tmpdirname:
pred, nid_mapping = gen_predict_with_nid_mapping(num_embs)
_save_maps(tmpdirname, "node_mapping", nid_mapping)
save_maps(tmpdirname, "node_mapping", nid_mapping)
nid_mapping_file = os.path.join(tmpdirname, "node_mapping.pt")
ctx = mp.get_context('spawn')
conn1, conn2 = mp.Pipe()
p0 = ctx.Process(target=run_dist_shuffle_predict,
args=(pred, 0, 2, nid_mapping_file, backend, conn2))
args=(pred, 0, 2, nid_mapping_file, None, backend, conn2))
conn3, conn4 = mp.Pipe()
p1 = ctx.Process(target=run_dist_shuffle_predict,
args=(pred, 1, 2, nid_mapping_file, backend, conn4))
args=(pred, 1, 2, nid_mapping_file, None, backend, conn4))

p0.start()
p1.start()
@@ -226,6 +226,39 @@ def test_shuffle_predict(num_embs, backend):
# Load saved embeddings
assert_equal(pred[nid_mapping].numpy(), shuffled_pred)

# node mapping is a dict
with tempfile.TemporaryDirectory() as tmpdirname:
pred, nid_mapping = gen_predict_with_nid_mapping(num_embs)
nid_mapping = {"node": nid_mapping}
save_maps(tmpdirname, "node_mapping", nid_mapping)
nid_mapping_file = os.path.join(tmpdirname, "node_mapping.pt")
ctx = mp.get_context('spawn')
conn1, conn2 = mp.Pipe()
p0 = ctx.Process(target=run_dist_shuffle_predict,
args=(pred, 0, 2, nid_mapping_file, "node", backend, conn2))
conn3, conn4 = mp.Pipe()
p1 = ctx.Process(target=run_dist_shuffle_predict,
args=(pred, 1, 2, nid_mapping_file, "node", backend, conn4))

p0.start()
p1.start()
p0.join()
p1.join()
assert p0.exitcode == 0
assert p1.exitcode == 0

shuffled_pred_1 = conn1.recv()
shuffled_pred_2 = conn3.recv()
conn1.close()
conn2.close()
conn3.close()
conn4.close()

shuffled_pred = np.concatenate([shuffled_pred_1, shuffled_pred_2])

# Load saved embeddings
assert_equal(pred[nid_mapping["node"]].numpy(), shuffled_pred)

# TODO: Only test gloo now
# Will add test for nccl once we enable nccl
@pytest.mark.parametrize("num_embs", [16, 17])
Expand All @@ -236,7 +269,7 @@ def test_save_embeddings_with_id_mapping(num_embs, backend):
# single embedding
with tempfile.TemporaryDirectory() as tmpdirname:
emb, nid_mapping = gen_embedding_with_nid_mapping(num_embs)
_save_maps(tmpdirname, "node_mapping", nid_mapping)
save_maps(tmpdirname, "node_mapping", nid_mapping)
nid_mapping_file = os.path.join(tmpdirname, "node_mapping.pt")
ctx = mp.get_context('spawn')
p0 = ctx.Process(target=run_dist_save_embeddings,
@@ -272,7 +305,7 @@ def test_save_embeddings_with_id_mapping(num_embs, backend):
embs['n2'] = emb
nid_mappings['n2'] = nid_mapping

_save_maps(tmpdirname, "node_mapping", nid_mappings)
save_maps(tmpdirname, "node_mapping", nid_mappings)
nid_mapping_file = os.path.join(tmpdirname, "node_mapping.pt")
ctx = mp.get_context('spawn')
p0 = ctx.Process(target=run_dist_save_embeddings,
