forked from tomgoldstein/loss-landscape
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathevaluation.py
66 lines (58 loc) · 2.37 KB
/
evaluation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
"""
The calculation to be performed at each point (modified model), evaluating
the loss value, accuracy and eigenvalues of the Hessian matrix
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
from torch.autograd.variable import Variable
def eval_loss(net, criterion, loader, use_cuda=False):
    """
    Evaluate the loss value and accuracy of 'net' on the data provided by the loader.

    The model is put in eval mode and the whole pass runs without gradient
    tracking. Each batch loss is weighted by its batch size before averaging,
    so (assuming the criterion uses a mean reduction) a smaller final batch
    does not skew the result.

    Args:
        net: the neural net model; net(inputs) is expected to return one
            score per class along dim 1
        criterion: loss function; an nn.CrossEntropyLoss or nn.MSELoss
            instance. With MSELoss the softmax of the outputs is compared
            against one-hot encoded targets.
        loader: iterable yielding (inputs, targets) batches, where targets
            hold integer class indices
        use_cuda: use cuda or not (moves the model and every batch to the GPU)

    Returns:
        loss value (average per sample) and accuracy (in percent)

    Raises:
        ValueError: if criterion is neither CrossEntropyLoss nor MSELoss
        ZeroDivisionError: if the loader yields no samples
    """
    use_cross_entropy = isinstance(criterion, nn.CrossEntropyLoss)
    if not use_cross_entropy and not isinstance(criterion, nn.MSELoss):
        # Fail early with a clear message instead of dividing by zero below.
        raise ValueError(
            'eval_loss only supports nn.CrossEntropyLoss and nn.MSELoss, got %s'
            % type(criterion).__name__)

    correct = 0
    total_loss = 0
    total = 0  # number of samples

    if use_cuda:
        net.cuda()
    net.eval()

    with torch.no_grad():
        for inputs, targets in loader:
            batch_size = inputs.size(0)
            total += batch_size
            if use_cuda:
                inputs, targets = inputs.cuda(), targets.cuda()

            outputs = net(inputs)
            if use_cross_entropy:
                loss = criterion(outputs, targets)
            else:
                # MSE is measured between class probabilities and one-hot
                # targets. The one-hot tensor copies the shape, dtype and
                # device of the outputs, so any number of classes works.
                outputs = F.softmax(outputs, dim=1)
                one_hot_targets = torch.zeros_like(outputs)
                one_hot_targets.scatter_(1, targets.view(-1, 1), 1.0)
                loss = criterion(outputs, one_hot_targets)

            # Undo the per-batch averaging so uneven batches are weighted fairly.
            total_loss += loss.item() * batch_size
            _, predicted = torch.max(outputs, 1)
            correct += predicted.eq(targets).sum().item()

    return total_loss / total, 100. * correct / total