eval.py (forked from snap-stanford/stark)
import argparse
import json
import os
import os.path as osp

import numpy as np
import pandas as pd
import torch
from tqdm import tqdm

from stark_qa import load_qa, load_skb
from models import get_model
from stark_qa.tools.args import load_args, merge_args

def parse_args():
    parser = argparse.ArgumentParser()

    # Dataset and model selection
    parser.add_argument("--dataset", default="amazon", choices=["amazon", "prime", "mag"])
    parser.add_argument("--model", default="VSS", choices=["VSS", "MultiVSS", "LLMReranker"])
    parser.add_argument("--split", default="test", choices=["train", "val", "test", "human_generated_eval"])

    # Path settings
    parser.add_argument("--emb_dir", type=str, required=True)
    parser.add_argument("--output_dir", type=str, required=True)

    # Evaluation settings
    parser.add_argument("--test_ratio", type=float, default=1.0)

    # MultiVSS-specific settings
    parser.add_argument("--chunk_size", type=int, default=None)
    parser.add_argument("--multi_vss_topk", type=int, default=None)
    parser.add_argument("--aggregate", type=str, default="max")

    # VSS, MultiVSS, and LLMReranker settings
    parser.add_argument("--emb_model", type=str, default="text-embedding-ada-002")

    # LLMReranker-specific settings
    parser.add_argument("--llm_model", type=str, default="gpt-4-1106-preview", help="the LLM used to rerank candidates")
    parser.add_argument("--llm_topk", type=int, default=10)
    parser.add_argument("--max_retry", type=int, default=3)

    # Prediction saving settings
    parser.add_argument("--save_pred", action="store_true")
    parser.add_argument("--save_topk", type=int, default=500, help="top-k predicted indices to save")

    return parser.parse_args()

if __name__ == "__main__":
    args = parse_args()

    # Merge command-line arguments with the per-dataset defaults from the config file.
    default_args = load_args(
        json.load(open("config/default_args.json", "r"))[args.dataset]
    )
    args = merge_args(args, default_args)

    # Resolve the embedding directories for queries, documents (nodes), and chunks.
    query_emb_suffix = f"_{args.split}" if args.split == "human_generated_eval" else ""
    args.query_emb_dir = osp.join(args.emb_dir, args.dataset, args.emb_model, f"query{query_emb_suffix}")
    args.node_emb_dir = osp.join(args.emb_dir, args.dataset, args.emb_model, "doc")
    args.chunk_emb_dir = osp.join(args.emb_dir, args.dataset, args.emb_model, "chunk")

    # Name the output directory after the reranker LLM (for LLMReranker) or the embedding model.
    suffix = args.llm_model if args.model == "LLMReranker" else args.emb_model
    output_dir = osp.join(args.output_dir, "eval", args.dataset, args.model, suffix)
    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(args.query_emb_dir, exist_ok=True)
    os.makedirs(args.chunk_emb_dir, exist_ok=True)
    os.makedirs(args.node_emb_dir, exist_ok=True)
    json.dump(vars(args), open(osp.join(output_dir, "args.json"), "w"), indent=4)

    eval_csv_path = osp.join(output_dir, f"eval_results_{args.split}.csv")
    final_eval_path = (
        osp.join(output_dir, f"eval_metrics_{args.split}.json")
        if args.test_ratio == 1.0
        else osp.join(output_dir, f"eval_metrics_{args.split}_{args.test_ratio}.json")
    )

    # Load the semi-structured knowledge base, the QA dataset, and the retrieval model.
    kb = load_skb(args.dataset)
    qa_dataset = load_qa(args.dataset, human_generated_eval=args.split == "human_generated_eval")
    model = get_model(args, kb)

    split_idx = qa_dataset.get_idx_split(test_ratio=args.test_ratio)

    eval_metrics = [
        "mrr",
        "map",
        "rprecision",
        "recall@5",
        "recall@10",
        "recall@20",
        "recall@50",
        "recall@100",
        "hit@1",
        "hit@3",
        "hit@5",
        "hit@10",
        "hit@20",
        "hit@50",
    ]
    eval_csv = pd.DataFrame(columns=["idx", "query_id", "pred_rank"] + eval_metrics)

    # Resume from an existing results CSV, skipping queries that were already evaluated.
    existing_idx = []
    if osp.exists(eval_csv_path):
        eval_csv = pd.read_csv(eval_csv_path)
        existing_idx = eval_csv["idx"].tolist()

    indices = split_idx[args.split].tolist()
    for idx in tqdm(indices):
        if idx in existing_idx:
            continue
        query, query_id, answer_ids, meta_info = qa_dataset[idx]

        # Score the candidates for this query and evaluate against the ground-truth answers.
        pred_dict = model.forward(query, query_id)
        answer_ids = torch.LongTensor(answer_ids)
        result = model.evaluate(pred_dict, answer_ids, metrics=eval_metrics)

        result["idx"], result["query_id"] = idx, query_id
        # Keep the top save_topk candidate ids, sorted by predicted score in descending order.
        result["pred_rank"] = torch.LongTensor(list(pred_dict.keys()))[
            torch.argsort(torch.tensor(list(pred_dict.values())), descending=True)[
                : args.save_topk
            ]
        ].tolist()

        eval_csv = pd.concat([eval_csv, pd.DataFrame([result])], ignore_index=True)
        if args.save_pred:
            eval_csv.to_csv(eval_csv_path, index=False)

        # Print running averages over the queries evaluated so far.
        for metric in eval_metrics:
            print(
                f"{metric}: {np.mean(eval_csv[eval_csv['idx'].isin(indices)][metric])}"
            )

    if args.save_pred:
        eval_csv.to_csv(eval_csv_path, index=False)

    # Aggregate the per-query metrics over this split and save the final results.
    final_metrics = (
        eval_csv[eval_csv["idx"].isin(indices)][eval_metrics].mean().to_dict()
    )
    json.dump(final_metrics, open(final_eval_path, "w"), indent=4)
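
# Example invocation (a sketch, not part of the original file): the embedding and
# output directories below are placeholders and depend on your local setup. The
# script also expects config/default_args.json to contain an entry for the chosen
# dataset.
#
#   python eval.py --dataset amazon --model VSS \
#       --emb_dir emb/ --output_dir output/ --save_pred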