-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathcollect_results.py
79 lines (61 loc) · 3.09 KB
/
collect_results.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import os
import json
import numpy as np
def _parse_result_line(line, label):
    """Parse the result list that follows ``label`` on a summary log line.

    Args:
        line: A single log line that contains ``label``.
        label: Marker text preceding the results, e.g. ``'Val Mean&Var'``.

    Returns:
        A list built from the evaluated text that follows the label.

    Raises:
        ValueError: If ``label`` does not occur in ``line``.
    """
    line = line.strip()
    if label not in line:
        raise ValueError('Expected {!r} in line: {!r}'.format(label, line))
    # Skip the label plus the three characters after it and drop the last
    # character of the line; the remainder is the text to evaluate.
    # NOTE(review): presumably those characters are a separator and the
    # enclosing brackets -- confirm against the code that writes the logs.
    start = line.index(label) + len(label) + 3
    payload = line[start:-1]
    # NOTE(review): eval() executes whatever the log file contains. It is
    # kept because the logs may hold non-literal reprs; only run this on
    # log files you trust, or switch to ast.literal_eval if the logs are
    # known to contain plain Python literals only.
    return list(eval(payload))


def process_each_file(file_path):
    """Read the validation and test summaries from the end of a log file.

    The last two non-blank lines of the file must be the ``Val Mean&Var``
    line followed by the ``Test Mean&Var`` line.

    Args:
        file_path: Path to the log file.

    Returns:
        A ``(val_res, test_res)`` tuple of lists parsed from those lines.

    Raises:
        ValueError: If the file has fewer than two non-blank lines or a
            summary line lacks its expected label.
    """
    with open(file_path, 'r') as text_file:
        stripped = (raw.strip() for raw in text_file)
        # Ignore blank lines so a trailing newline or empty line at the end
        # of the log does not shift which lines are treated as summaries.
        lines = [line for line in stripped if line]
    if len(lines) < 2:
        raise ValueError(
            'Expected at least two non-blank lines in {}'.format(file_path))
    val_line, test_line = lines[-2:]
    val_res = _parse_result_line(val_line, 'Val Mean&Var')
    test_res = _parse_result_line(test_line, 'Test Mean&Var')
    return val_res, test_res
def _summarize_runs(res_list, precision):
    """Regroup per-file results by position and average them across files.

    Args:
        res_list: One entry per log file; each entry is a sequence of
            two-element pairs as returned by ``process_each_file``.
        precision: Number of decimals kept in the per-position averages.

    Returns:
        A ``(firsts, seconds, averaged)`` tuple. ``firsts`` and ``seconds``
        are lists with one row per position, each row holding the first /
        second pair element from every file. ``averaged`` is a list of
        ``(first_avg, second_avg)`` pairs, one per position, rounded to
        ``precision`` decimals.
    """
    # zip(*...) groups the pairs found at the same position in every file;
    # it stops at the shortest entry if the files disagree in length.
    grouped = list(zip(*res_list))
    firsts = [[first for (first, _) in items] for items in grouped]
    seconds = [[second for (_, second) in items] for items in grouped]
    first_avg = np.around(np.mean(firsts, axis=1), decimals=precision)
    second_avg = np.around(np.mean(seconds, axis=1), decimals=precision)
    return firsts, seconds, list(zip(first_avg, second_avg))


if __name__ == '__main__':
    import sys

    # Default log directory. Other directories this script has been used on:
    #   ./logs/ToyCircle_C/MMD_LSAE/
    #   ./logs/RMNIST/MMD_LSAE/
    # An optional first command-line argument overrides the default.
    file_root = sys.argv[1] if len(sys.argv) > 1 else './logs/ToyCircle/MMD_LSAE/'
    precision = 2

    val_res_list, test_res_list = [], []
    for file_name in sorted(os.listdir(file_root)):
        # Skip anything that is not a train*.txt log.
        if not (file_name.startswith('train') and file_name.endswith('.txt')):
            continue
        file_path = os.path.join(file_root, file_name)
        print('Reading file:{}'.format(file_path))
        val_res, test_res = process_each_file(file_path)
        val_res_list.append(val_res)
        test_res_list.append(test_res)

    # Without any parsed file the averaging below would fail with an
    # unrelated NumPy axis error, so stop with a clear message instead.
    if not val_res_list:
        raise SystemExit(
            'No train*.txt result files found in {}'.format(file_root))

    val_mean, val_var, val_res_each_domain = _summarize_runs(
        val_res_list, precision)
    test_mean, test_var, test_res_each_domain = _summarize_runs(
        test_res_list, precision)

    print('=========================================')
    print('Total Val results:{:.2f} ± {:.2f} %'.format(np.mean(val_mean), np.mean(val_var)))
    print('Total Val each domain:{}'.format(val_res_each_domain))
    print('Total Test results:{:.2f} ± {:.2f} %'.format(np.mean(test_mean), np.mean(test_var)))
    # Rows of test_mean are positions and columns are files, so averaging
    # over axis 0 gives one value per file (reported as "each seed").
    test_mean_each_seed = np.around(np.mean(test_mean, axis=0), decimals=precision)
    print('Total Test each seed:{}'.format(test_mean_each_seed))
    print('Total Test each domain:{}'.format(test_res_each_domain))