-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy patheval_sentence_level.py
49 lines (39 loc) · 1.59 KB
/
eval_sentence_level.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import argparse
import numpy as np
from scipy.stats import pearsonr, spearmanr
import pdb
"""
Script to evaluate outputs of machine translation quality estimation
systems for the sentence level, in the WMT 2019 format.
The system output and gold files should have one HTER value per line.
"""
if __name__ == '__main__':
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument('system', help='System file')
parser.add_argument('gold', help='Gold output file')
parser.add_argument('-v', action='store_true', dest='verbose',
help='Show all metrics (Pearson r, Spearman r, MAE, '
'RMSE). By default, it only computes Pearson r.')
args = parser.parse_args()
system = np.loadtxt(args.system)
# def zero_to_one(pred):
# if pred < 0:
# return 0
# elif pred > 1:
# return 1
# else:
# return pred
# system = np.array([zero_to_one(pred) for pred in system.tolist()])
gold = np.loadtxt(args.gold)
assert len(system) == len(gold), 'Number of gold and system values differ'
# pearsonr and spearmanr return (correlation, p_value)
pearson = pearsonr(gold, system)[0]
print('Pearson correlation: %.4f' % pearson)
if args.verbose:
spearman = spearmanr(gold, system)[0]
diff = gold - system
mae = np.abs(diff).mean()
rmse = (diff ** 2).mean() ** 0.5
print('Spearman correlation: %.4f' % spearman)
print('MAE: %.4f' % mae)
print('RMSE: %.4f' % rmse)