-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy patheval_word_level.py
94 lines (75 loc) · 3.2 KB
/
eval_word_level.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
"""
Script to evaluate outputs of machine translation quality estimation
systems for the word level, in the WMT 2019 format.
"""
# The docstring must be the first statement of the module so that it becomes
# __doc__; the argument parser below uses __doc__ as its --help description.
import argparse
def read_tags(filename, only_gaps=False, only_words=False):
    """Load the tag sequences stored in *filename*, one list per line.

    Each line is split on whitespace into a list of tag strings.  When
    ``only_gaps`` is true, only the tags at even positions (0, 2, 4, ...)
    are kept; otherwise, when ``only_words`` is true, only the tags at odd
    positions (1, 3, 5, ...) are kept.  ``only_gaps`` wins if both flags
    are set.

    NOTE(review): the even/odd split presumably relies on the WMT 2019
    layout where gap and word tags alternate starting with a gap -- confirm
    against the data files.

    Returns a list holding one list of tags per input line (an empty line
    yields an empty list).
    """
    if only_gaps:
        keep = slice(None, None, 2)   # even positions
    elif only_words:
        keep = slice(1, None, 2)      # odd positions
    else:
        keep = slice(None)            # every tag
    with open(filename, 'r') as tag_file:
        return [line.split()[keep] for line in tag_file]
def _count_confusion(system_tags, gold_tags):
    """Count system/gold tag agreement, treating 'OK' as the positive class.

    Both arguments are lists of per-sentence tag lists.  Any tag other than
    'OK' is treated as the negative class.

    Returns a tuple (tp, tn, fp, fn):
      tp -- system 'OK' and gold 'OK'
      fp -- system 'OK' and gold tag differs
      tn -- system tag is not 'OK' and equals the gold tag
      fn -- system tag is not 'OK' and differs from the gold tag

    Raises ValueError if the inputs differ in number of lines, or if any pair
    of lines differs in number of tags (zip() would otherwise truncate).
    """
    if len(system_tags) != len(gold_tags):
        raise ValueError('Number of lines in system and gold file differ')
    tp = tn = fp = fn = 0
    for i, (sys_sentence, gold_sentence) in enumerate(zip(system_tags,
                                                          gold_tags), 1):
        if len(sys_sentence) != len(gold_sentence):
            raise ValueError(
                'Number of tags in system and gold file differ in line %d' % i)
        for sys_tag, gold_tag in zip(sys_sentence, gold_sentence):
            if sys_tag == 'OK':
                if sys_tag == gold_tag:
                    tp += 1
                else:
                    fp += 1
            elif sys_tag == gold_tag:
                tn += 1
            else:
                fn += 1
    return tp, tn, fp, fn


def _f1(precision, recall):
    """Return the harmonic mean of precision and recall (0.0 if both are 0)."""
    if precision + recall:
        return 2 * precision * recall / (precision + recall)
    return 0.


def _matthews_corrcoef(tp, tn, fp, fn):
    """Return the Matthews correlation coefficient for the given counts.

    When any marginal total is zero the denominator is zero; the usual
    convention is to report 0.0 in that case instead of raising
    ZeroDivisionError (e.g. when every tag is 'OK').
    """
    denominator = ((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)) ** 0.5
    if not denominator:
        return 0.
    return ((tp * tn) - (fp * fn)) / denominator


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('system', help='System file')
    parser.add_argument('gold', help='Gold output file')
    parser.add_argument('-v', help='Verbose, show additional metrics',
                        action='store_true', dest='verbose')
    tag_type_group = parser.add_mutually_exclusive_group()
    tag_type_group.add_argument('-w', help='Only evaluate word tags',
                                action='store_true', dest='only_words')
    tag_type_group.add_argument('-g', help='Only evaluate gap tags',
                                action='store_true', dest='only_gaps')
    args = parser.parse_args()

    system_tags = read_tags(args.system, args.only_gaps, args.only_words)
    gold_tags = read_tags(args.gold, args.only_gaps, args.only_words)

    # Validate with parser.error() instead of assert: asserts are removed
    # under `python -O`, after which mismatched inputs would be silently
    # truncated by zip() and produce wrong scores.
    try:
        tp, tn, fp, fn = _count_confusion(system_tags, gold_tags)
    except ValueError as err:
        parser.error(str(err))

    num_sys_ok = tp + fp
    num_gold_ok = tp + fn
    num_sys_bad = tn + fn
    num_gold_bad = tn + fp

    # With no predictions of a class, precision defaults to 1; with no gold
    # instances of a class, recall defaults to 0.
    precision_ok = tp / num_sys_ok if num_sys_ok else 1.
    recall_ok = tp / num_gold_ok if num_gold_ok else 0.
    precision_bad = tn / num_sys_bad if num_sys_bad else 1.
    recall_bad = tn / num_gold_bad if num_gold_bad else 0.

    f1_ok = _f1(precision_ok, recall_ok)
    f1_bad = _f1(precision_bad, recall_bad)
    f1_mult = f1_ok * f1_bad

    print('F1 OK: %.4f' % f1_ok)
    print('F1 BAD: %.4f' % f1_bad)
    print('F1 Mult: %.4f' % f1_mult)

    if args.verbose:
        mcc = _matthews_corrcoef(tp, tn, fp, fn)
        print('Matthews correlation: %.4f' % mcc)