forked from Snowdar/asv-subtools
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcomputeEER-like-Bosaris.py
executable file
·122 lines (97 loc) · 3.04 KB
/
computeEER-like-Bosaris.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright xmuspeech (Author:Snowdar 2019-01-10)
# It differs slightly from the Kaldi method.
# Here the EER is estimated from the two sweep points on either side of the FAR/FRR crossing:
# whichever point has the smaller |FAR - FRR| gap is chosen, and its FAR and FRR are averaged.
import sys
import argparse
def get_args():
    """Build the command-line parser and parse ``sys.argv``.

    The full command line is echoed to stdout before parsing, so it is printed
    even if parsing then fails.

    Returns:
        argparse.Namespace with string attributes ``trials_path`` and ``score_path``.
    """
    parser = argparse.ArgumentParser(
        description="Compute EER.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        conflict_handler='resolve',
    )

    # Both inputs are required positional paths.
    positional_specs = (
        ("trials_path", "The path of trials."),
        ("score_path", "The path of the scores."),
    )
    for arg_name, help_text in positional_specs:
        parser.add_argument(arg_name, metavar=arg_name, type=str, help=help_text)

    print(' '.join(sys.argv))
    return parser.parse_args()
def load_data(data_path, n):
    """Read a whitespace-separated text file into a list of field lists.

    Args:
        data_path: path of the file to read.
        n: number of whitespace-separated fields every non-blank line must have.

    Returns:
        One list of string fields per non-blank line, in file order.

    Prints an error and exits with status 1 if a non-blank line does not have
    exactly ``n`` fields.
    """
    rows = []
    print("Load data from " + data_path + "...")
    with open(data_path, 'r') as f:
        # Iterate the file lazily instead of materializing it with readlines().
        for line in f:
            # str.split() with no argument also strips surrounding whitespace.
            fields = line.split()
            if not fields:
                # Tolerate blank lines (e.g. a trailing empty line) instead of
                # aborting the whole run as if the file were malformed.
                continue
            if len(fields) != n:
                print('[Error] The %s file has no %s fields' % (data_path, n))
                # sys.exit, not the site-provided exit(), which may be absent
                # (e.g. under `python -S` or in frozen interpreters).
                sys.exit(1)
            rows.append(fields)
    return rows
def abs(x):
    """Return ``-x`` when ``x`` is negative, otherwise ``x`` unchanged.

    NOTE: this shadows the built-in ``abs`` for the rest of this module.
    """
    return -x if x < 0 else x
def compute_eer(allScores):
    """Estimate the equal error rate (EER) and the threshold where it occurs.

    Trials are swept in ascending score order. At each step, every trial up to
    and including the current one is treated as rejected. That gives a
    false-alarm rate (FAR, accepted nontargets) and a false-reject rate
    (FRR, rejected targets).

    At the first step where FAR no longer exceeds FRR, two points are compared:
    that step and the one immediately before it. Whichever has the smaller
    |FAR - FRR| gap is kept, and the EER is the mean of its FAR and FRR.

    Args:
        allScores: sequence of ``[score, label]`` items, where ``score`` is
            numeric and ``label`` is ``"target"`` or ``"nontarget"``.
            The items are not modified.

    Returns:
        ``(eer, threshold)``. ``eer`` is a fraction in [0, 1]; ``threshold`` is
        the score of the chosen sweep point.

    Prints an error and exits with status 1 on an unknown label, or when there
    is not at least one target and one nontarget trial.
    """
    # Build a private (score, is_target) list so the caller's data is not
    # rewritten in place.
    labelled = []
    num_p = 0
    num_n = 0
    for x in allScores:
        if x[1] == "target":
            labelled.append((x[0], 1))
            num_p += 1
        elif x[1] == "nontarget":
            labelled.append((x[0], 0))
            num_n += 1
        else:
            print("[Error in compute_eer()] %s is not target or nontarget in score" % (x[1]))
            sys.exit(1)
    if num_p == 0 or num_n == 0:
        # Both rates would divide by zero otherwise.
        print("[Error in compute_eer()] need at least one target and one nontarget trial, "
              "got %d target(s) and %d nontarget(s)" % (num_p, num_n))
        sys.exit(1)

    # Ascending by score. Equal scores put nontargets (0) before targets (1),
    # matching the original list-based sort.
    labelled.sort()

    num_fa = num_n  # before the sweep every nontarget is accepted
    num_fr = 0      # ... and no target is rejected
    prev = None     # (far, frr, threshold) of the previous sweep point
    for score, is_target in labelled:
        if is_target:
            num_fr += 1
        else:
            num_fa -= 1
        # float() keeps these true divisions even under Python 2.
        far = float(num_fa) / num_n
        frr = float(num_fr) / num_p
        if far > frr:
            prev = (far, frr, score)
            continue
        # FAR has dropped to or below FRR: the crossing lies between `prev` and
        # here. With no previous point, the current one is the only candidate.
        if prev is None or abs(far - frr) <= abs(prev[0] - prev[1]):
            return (far + frr) / 2, score
        return (prev[0] + prev[1]) / 2, prev[2]
    # Not reached: at the last trial num_fa == 0, so far <= frr and the loop returns.
def main():
    """Load the trials and score files, then print the EER and its threshold.

    Both files must have three whitespace-separated fields per line. A score
    line's first two fields are looked up among the trials lines' first two
    fields. The trials line's third field supplies the label; the score line's
    third field is parsed as the score.
    """
    # Imported here so the error handler below works; the file header only
    # imports sys and argparse.
    import traceback

    args = get_args()
    try:
        trials = load_data(args.trials_path, 3)
        scores = load_data(args.score_path, 3)
        # Key on the field pair itself. The previous string concatenation made
        # distinct pairs collide, e.g. ("ab", "c") and ("a", "bc").
        label_dict = {}
        for x in trials:
            label_dict[(x[0], x[1])] = x[2]
        allScores = [[float(x[2]), label_dict[(x[0], x[1])]] for x in scores]
        eer, threshold = compute_eer(allScores)
        print("EER% {:.3f} (threshold = {:.5f})".format(eer * 100, threshold))
    except SystemExit:
        # load_data()/compute_eer() already printed an error before exiting;
        # propagate the requested exit status without a traceback dump.
        raise
    except BaseException as e:
        # Look for BaseException so we catch KeyboardInterrupt, which is
        # what we get when a background thread dies.
        if not isinstance(e, KeyboardInterrupt):
            traceback.print_exc()
        sys.exit(1)
if __name__ == "__main__":
    # Runs only when executed as a script, not when imported as a module.
    main()