-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathevaluation.py
66 lines (57 loc) · 2.25 KB
/
evaluation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import random
import torch
from fastNLP import logger, DataSet, Tester, AccuracyMetric, SequentialSampler
from revmux.utils import set_random_seed, set_input_target
def evaluation_test(testing_time, test_data, model, task_name, set_batch_size=None,
input_fields=['input_ids', 'attention_mask', 'decoder_input_ids',
'decoder_attention_mask', 'target_ids']):
set_random_seed(4141)
idx_list = [_ for _ in range(len(test_data))]
test_data_list = []
for _ in range(testing_time):
random.shuffle(idx_list)
test_data_list.append(DataSet())
for idx in idx_list:
test_data_list[-1].append(test_data[idx])
test_data_list[-1] = set_input_target(test_data_list[-1], input_fields=input_fields)
logger.info(f'successfully append test data {testing_time} times.')
if set_batch_size is not None:
batch_size = set_batch_size
else:
if task_name in ['rte']:
batch_size = 16
else:
batch_size = 32
model.mode = 'normal'
model.eval()
total_acc = 0.
for t in range(testing_time):
logger.info(f'{"=" * 30}')
logger.info(f'Testing {t + 1}-th time ({len(test_data_list[t])} testing examples):')
tester = Tester(
test_data_list[t],
model,
AccuracyMetric(pred='logits', target='labels'),
batch_size=batch_size,
device=[_ for _ in range(torch.cuda.device_count())],
sampler=SequentialSampler(),
verbose=10,
)
logger.info(f'Tester.verbose: {tester.verbose}')
eval_results = tester.test()
acc = eval_results['AccuracyMetric']['acc']
total_acc += acc
logger.info(f'average acc: {round(total_acc / (t + 1), 6)}')
model.mode = 'teacher_only'
tester = Tester(
test_data,
model,
AccuracyMetric(pred='logits', target='labels'),
batch_size=batch_size,
device=[_ for _ in range(torch.cuda.device_count())],
sampler=SequentialSampler(),
verbose=10
)
eval_results = tester.test()
teacher_acc = eval_results['AccuracyMetric']['acc']
logger.info(f'Accuracy of Teacher-Only mode is: {round(teacher_acc, 4)}')