knn retrieval evaluation #770
base: main
Changes from 3 commits
b9bad61
ae6b91f
597b9a8
63a8091
2e462f0
942e8c6
884876b
29b60eb
build_index.py (filename inferred from the launch command below)
@@ -0,0 +1,128 @@
import torch as th
import time
import graphstorm as gs
from graphstorm.utils import is_distributed
Review comment: It is better to group the graphstorm-related imports together, and to import system/builtin libraries like os, time, etc. first. (A regrouped sketch follows.)
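A sketch of how the imports already in this file could be regrouped along those lines (only the grouping and ordering are new):

# Standard library
import time
from collections import defaultdict

# Third-party
import dgl
import faiss
import numpy as np
import torch as th

# GraphStorm
import graphstorm as gs
from graphstorm.config import GSConfig, get_argument_parser
from graphstorm.dataloading import GSgnnNodeDataLoader, GSgnnNodeTrainData
from graphstorm.model.utils import load_gsgnn_embeddings
from graphstorm.utils import is_distributed, setup_device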
import faiss
import dgl
import numpy as np
from collections import defaultdict
from graphstorm.config import get_argument_parser
from graphstorm.config import GSConfig
from graphstorm.dataloading import GSgnnNodeDataLoader
from graphstorm.dataloading import GSgnnNodeTrainData
from graphstorm.utils import setup_device
from graphstorm.model.utils import load_gsgnn_embeddings
def calculate_recall(pred, ground_truth):

Review comment: Can you give a description of how you compute recall in the function doc?
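A docstring along these lines would address the comment (a sketch describing what the function below actually computes):

def calculate_recall(pred, ground_truth):
    """Compute the recall of retrieved node IDs against the ground truth.

    Recall is |pred & ground_truth| / |ground_truth|, i.e. the fraction of
    ground-truth neighbors that appear in the retrieved set. For example,
    calculate_recall({1, 2, 3}, {2, 3, 4}) returns 2/3.

    Parameters
    ----------
    pred : set or iterable of int
        Retrieved (predicted) node IDs.
    ground_truth : set of int
        Ground-truth neighbor node IDs; assumed non-empty.

    Returns
    -------
    float
        Recall in [0, 1].
    """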
    # Convert pred to a set if it is not already one
    if not isinstance(pred, set):
        pred = set(pred)
    overlap = len(pred & ground_truth)
    #if overlap > 0:
    #    return 1
    #else:
    #    return 0
    return overlap / len(ground_truth)

Review comment: Remove the commented-out code.
def main(config_args):
    """ main function
    """
    config = GSConfig(config_args)
    embs = load_gsgnn_embeddings(config.save_embed_path)

    index_dimension = embs[config.target_ntype].size(1)
    # Number of clusters (higher values lead to better recall but slower search)
    #nlist = 750
    #quantizer = faiss.IndexFlatL2(index_dimension) # Use Flat index for quantization
    #index = faiss.IndexIVFFlat(quantizer, index_dimension, nlist, faiss.METRIC_INNER_PRODUCT)
    #index.train(embs[config.target_ntype])
    index = faiss.IndexFlatIP(index_dimension)
    index.add(embs[config.target_ntype])
    #print(scores.abs().mean())

Review comment: Remove the commented codes.

Review comment: Remove the leftover debug print as well.
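One hedged aside on the two index lines above: the stock FAISS Python bindings expect C-contiguous float32 numpy arrays, while load_gsgnn_embeddings (added by this PR, see below) returns torch tensors, so an explicit conversion may be needed unless faiss.contrib.torch_utils is imported. A sketch:

    # Sketch: convert the torch embedding matrix into the float32,
    # C-contiguous numpy layout FAISS expects (assumes CPU tensors).
    xb = np.ascontiguousarray(
        embs[config.target_ntype].cpu().numpy().astype(np.float32))
    index = faiss.IndexFlatIP(xb.shape[1])
    index.add(xb)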
    gs.initialize(ip_config=config.ip_config, backend=config.backend)
    device = setup_device(config.local_rank)
    #index = faiss.index_cpu_to_all_gpus(faiss.IndexFlatL2(embedding_size))
    # Define the training dataset
    train_data = GSgnnNodeTrainData(
        config.graph_name,
        config.part_config,
        train_ntypes=config.target_ntype,
        eval_ntypes=config.eval_target_ntype,
        label_field=None,
        node_feat_field=None,
    )
    #for i in range(embs[config.target_ntype].shape[0]):
    #    print(embs[config.target_ntype][i,:].sum(), train_data.g.ndata['bert_h'][i].sum())
    #    breakpoint()
    #    embs[config.target_ntype][i,:] = train_data.g.ndata['bert_h'][i]
    #print(train_data.g.ndata['bert_h'][0,:], embs[config.target_ntype][0,:])
    #print(train_data.g.ndata['bert_h'])
    # TODO: devise a dataloader that can exclude targets and add train_mask like LP Loader
    test_dataloader = GSgnnNodeDataLoader(
        train_data,
        train_data.train_idxs,
        fanout=[-1],
        batch_size=config.eval_batch_size,
        device=device,
        train_task=False,
    )
    dataloader_iter = iter(test_dataloader)
    len_dataloader = max_num_batch = len(test_dataloader)
    tensor = th.tensor([len_dataloader], device=device)
    if is_distributed():
        # Take the max across ranks so every rank runs the same number of
        # batches and collective calls stay in sync.
        th.distributed.all_reduce(tensor, op=th.distributed.ReduceOp.MAX)
Review comment: Do we need to make it distributed?
    max_num_batch = tensor[0]
    recall = []
    max_ = []
    for iter_l in range(max_num_batch):
        ground_truth = defaultdict(set)
        input_nodes, seeds, blocks = next(dataloader_iter)
        #block_graph = dgl.block_to_graph(blocks[0])
        src_id = blocks[0].srcdata[dgl.NID].tolist()
        dst_id = blocks[0].dstdata[dgl.NID].tolist()
        #print(blocks[0].edges(form='uv', etype='also_buy'))
        #breakpoint()
        #print(dgl.NID)
        if 'also_buy' in blocks[0].etypes:

Review comment: Is this implemented specifically for amazon review? (A generalization sketch follows the ground-truth construction below.)
            #src, dst = block_graph.edges(form='uv', etype='also_buy')
            src, dst = blocks[0].edges(form='uv', etype='also_buy')
            for s, d in zip(src.tolist(), dst.tolist()):
                ground_truth[dst_id[d]].add(src_id[s])
                #ground_truth[src_id[s]].add(dst_id[d])
        if 'also_buy-rev' in blocks[0].etypes:
            #src, dst = block_graph.edges(form='uv', etype='also_buy-rev')
            src, dst = blocks[0].edges(form='uv', etype='also_buy-rev')
            for s, d in zip(src.tolist(), dst.tolist()):
                ground_truth[dst_id[d]].add(src_id[s])
                #ground_truth[src_id[s]].add(dst_id[d])
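In response to the reviewer's question above, a possible generalization that avoids hard-coding the Amazon-review edge types (a sketch; it assumes config.eval_etype yields canonical edge-type tuples such as ('item', 'also_buy', 'item'), and that reverse relations carry the '-rev' suffix as configured in reverse_edge_types_map):

# Sketch: derive the evaluation relations from the config instead of
# hard-coding 'also_buy'; a drop-in replacement for the two if-blocks above.
eval_rels = {canonical_etype[1] for canonical_etype in config.eval_etype}
eval_rels |= {rel + '-rev' for rel in set(eval_rels)}
for rel in blocks[0].etypes:
    if rel in eval_rels:
        src, dst = blocks[0].edges(form='uv', etype=rel)
        for s, d in zip(src.tolist(), dst.tolist()):
            ground_truth[dst_id[d]].add(src_id[s])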
        query_idx = list(ground_truth.keys())
        #print(ground_truth)
        #breakpoint()
        # Search for the 100 + 1 nearest neighbors: the top hit is the query
        # node itself, which is dropped below before computing recall@100.
        dists, knn_ids = index.search(embs[config.target_ntype][query_idx], 100 + 1)
        #knn_result = knn_ids.tolist()
        for idx, query in enumerate(query_idx):
            recall.append(calculate_recall(knn_ids[idx, 1:], ground_truth[query]))
            max_.append(query)
        #print(recall)
    if gs.get_rank() == 0:
        #print(query_idx, knn_ids)
        #print(max_num_batch, len(recall), np.mean(recall))
        print(f'recall@100: {np.mean(recall)}')
def generate_parser():
    """Generate an argument parser"""
    parser = get_argument_parser()
    return parser


if __name__ == "__main__":
    arg_parser = generate_parser()

    args = arg_parser.parse_args()
    print(args)
    main(args)
embedding_config.yaml (filename inferred from the --cf flag in the launch command below)
@@ -0,0 +1,46 @@
gsf:
  basic:
    backend: gloo
    verbose: false
    save_perf_results_path: null
  gnn:
    model_encoder_type: mlp
    fanout: "5,5"
    node_feat_name:
      - item:bert_h
    num_layers: 2
    hidden_size: 768
    use_mini_batch_infer: true
  input:
    restore_model_path: null
  output:
    save_model_path: null
    save_embed_path: /shared_data/graphstorm/examples/peft_llm_gnn/results/lp/Video_Games
  hyperparam:
    dropout: 0.
    lr: 0.001
    num_epochs: 1
    batch_size: 512
    eval_batch_size: 512
    wd_l2norm: 0.00001
    no_validation: false
  rgcn:
    num_bases: -1
    use_self_loop: true
    lp_decoder_type: dot_product
    sparse_optimizer_lr: 1e-2
    use_node_embeddings: false
  link_prediction:
    num_negative_edges: 1
    num_negative_edges_eval: 100
    contrastive_loss_temperature: 0.1
    lp_loss_func: contrastive
    lp_embed_normalizer: l2_norm
    train_negative_sampler: inbatch_joint
    target_ntype: item
    eval_etype:
      - "item,also_buy,item"
    train_etype:
      - "item,also_buy,item"
    exclude_training_targets: true
    reverse_edge_types_map: ["item,also_buy,also_buy-rev,item"]
Launch script (filename not shown in this diff)
@@ -0,0 +1,18 @@
WORKSPACE=/shared_data/graphstorm/examples/knn_retriever/
DATASPACE=/shared_data/graphstorm/examples/peft_llm_gnn/
dataset=amazon_review
domain=$1

python -m graphstorm.run.launch \
    --workspace "$WORKSPACE" \
    --part-config "$DATASPACE"/datasets/"$dataset"_"$domain"/"$dataset".json \
    --ip-config "$DATASPACE"/ip_list.txt \
    --num-trainers 1 \
    --num-servers 1 \
    --num-samplers 0 \
    --ssh-port 22 \
    --do-nid-remap False \
    build_index.py \
    --cf "$WORKSPACE"/embedding_config.yaml \
    --save-model-path "$DATASPACE"/model/lp/"$domain"/ \
    --save-embed-path "$DATASPACE"/results/lp/"$domain"/
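For reference, the script takes the domain as its single positional argument. An invocation consistent with the config above would be the following (the script's filename is not shown in this diff, so the name here is illustrative):

bash launch_knn_eval.sh Video_Games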
graphstorm/model/utils.py (path inferred from the relative imports and from build_index.py importing graphstorm.model.utils.load_gsgnn_embeddings)
@@ -27,6 +27,7 @@
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
import dgl
import pandas as pd

from ..config import GRAPHSTORM_LP_EMB_L2_NORMALIZATION
from ..gconstruct.file_io import stream_dist_tensors_to_hdf5

@@ -1065,6 +1066,32 @@ def save_full_node_embeddings(g, save_embed_path,
    save_shuffled_node_embeddings(shuffled_embs, save_embed_path, save_embed_format)

Review comment: Can you check if load_pytorch_embedding is useful?

def load_gsgnn_embeddings(emb_path):
    '''Load embeddings saved by `save_full_node_embeddings` into a dict of
    in-memory tensors, one entry per node type, with rows ordered by node ID.
    '''
    with open(os.path.join(emb_path, "emb_info.json"), 'r', encoding='utf-8') as f:
        emb_info = json.load(f)
    embs = {}
    for ntype in emb_info["emb_name"]:
        path = os.path.join(emb_path, ntype)
        ntype_emb_files = os.listdir(path)
        # Sort both lists so embedding and nid shards pair up by part ID;
        # os.listdir() returns entries in arbitrary order.
        nid_files = sorted(fname for fname in ntype_emb_files
                           if fname.startswith("embed_nids-") and fname.endswith("pt"))
        emb_files = sorted(fname for fname in ntype_emb_files
                           if fname.startswith("embed-") and fname.endswith("pt"))
        num_parts = len(emb_files)
        embeddings_list = []
        nid_list = []
        for i in range(num_parts):
            embeddings_list.append(th.load(os.path.join(path, emb_files[i])))
            nid_list.append(th.load(os.path.join(path, nid_files[i])))
        # Concatenate the shards and scatter rows back into node-ID order
        embeddings_tensor = th.cat(embeddings_list, dim=0)
        nids_tensor = th.cat(nid_list, dim=0)
        result_tensor = th.zeros_like(embeddings_tensor)
        result_tensor[nids_tensor] = embeddings_tensor
        embs[ntype] = result_tensor
    return embs
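A usage sketch matching how build_index.py calls this helper, with the save_embed_path from the config above:

embs = load_gsgnn_embeddings(
    "/shared_data/graphstorm/examples/peft_llm_gnn/results/lp/Video_Games")
print(embs["item"].shape)  # (num_item_nodes, hidden_size); hidden_size is 768 above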
def save_embeddings(emb_path, embeddings, rank, world_size,
                    device=th.device('cpu'), node_id_mapping_file=None,
Review comment: Please add a license header to each Python file.
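A sketch of such a header, assuming the new files follow the Apache-2.0 header convention used elsewhere in the repository (the year is illustrative):

"""
    Copyright 2024 Contributors

    Licensed under the Apache License, Version 2.0 (the "License");
    you may not use this file except in compliance with the License.
    You may obtain a copy of the License at

        http://www.apache.org/licenses/LICENSE-2.0

    Unless required by applicable law or agreed to in writing, software
    distributed under the License is distributed on an "AS IS" BASIS,
    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    See the License for the specific language governing permissions and
    limitations under the License.
"""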