-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathevaluation.py
40 lines (29 loc) · 1.09 KB
/
evaluation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import json
import argparse
from eval_utils import DSTEvaluator
# Default location of the slot-metadata JSON file. Being a relative path, it
# is resolved against the current working directory, not this script's folder.
SLOT_META_PATH = "./data/slot_meta.json"
def _evaluation(preds, labels, slot_meta):
evaluator = DSTEvaluator(slot_meta)
evaluator.init()
assert len(preds) == len(labels)
for k, l in labels.items():
p = preds.get(k)
if p is None:
raise Exception(f"{k} is not in the predictions!")
evaluator.update(l, p)
result = evaluator.compute()
print(result)
return result
def evaluation(gt_path, pred_path, slot_meta_path=None):
    """Load gold labels, predictions and slot metadata from JSON and score them.

    Args:
        gt_path: path to the JSON file holding the gold labels.
        pred_path: path to the JSON file holding the predictions.
        slot_meta_path: path to the slot-metadata JSON file. When omitted,
            ``SLOT_META_PATH`` is used.

    Returns:
        The result of ``_evaluation`` on the loaded data.
    """

    def _read_json(path):
        # ``with`` closes the handle even if parsing fails; JSON text is UTF-8
        # by spec, so don't depend on the platform's locale encoding.
        with open(path, encoding="utf-8") as f:
            return json.load(f)

    # Resolved at call time (not as a default argument) so the module-level
    # constant is looked up when the function runs.
    if slot_meta_path is None:
        slot_meta_path = SLOT_META_PATH
    slot_meta = _read_json(slot_meta_path)
    gts = _read_json(gt_path)
    preds = _read_json(pred_path)
    eval_result = _evaluation(preds, gts, slot_meta)
    return eval_result
if __name__ == "__main__":
    # Command-line entry point: both file paths are mandatory string flags.
    parser = argparse.ArgumentParser()
    for flag in ('--gt_path', '--pred_path'):
        parser.add_argument(flag, type=str, required=True)
    args = parser.parse_args()
    eval_result = evaluation(args.gt_path, args.pred_path)