-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_mnist
executable file
·78 lines (63 loc) · 1.67 KB
/
test_mnist
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
#!/bin/bash
#
# test_mnist - run a saved model against MNIST-style image data.
# The first argument must be an existing directory (the saved model);
# an optional second argument selects the mode, as listed in the usage text.

# Usage text, written to stderr when the first argument is not a directory.
INFO=$'
Usage:
\ttest_mnist <saved.model>
\ttest_mnist <saved.model> capture
\ttest_mnist <saved.model> <images dir>
'

# Guard: without a model directory there is nothing to do.
[[ -d "$1" ]] || {
  printf '%s\n' "$INFO" >&2
  exit 1
}

# Canonicalize "<script path>/.." (the directory holding this script) and
# put it at the front of PATH for every command started from here on.
SCRIPT_DIR=$(readlink -m "$0/..") || exit 1
PATH="$SCRIPT_DIR:$PATH"
export PATH
# Python program handed to python3 on stdin; sys.argv[1] is the saved model.
#
# What the program does:
#   - loads the MNIST test set via tensorflow.examples.tutorials.mnist
#     (read_data_sets is given the directory name "__mnist__"),
#   - reshapes the test images to [-1, 28, 28] and iterates over them in
#     batches of 1000,
#   - runs the loaded model on each batch, compares argmax of the output
#     with the labels, and writes the misclassified images to stdout,
#   - finally reports the overall accuracy through print_info.
#
# random_batches, write_array and print_info are not defined here; they
# presumably come from "from utils import *" -- verify against utils.py.
# loadsave.load is assumed to return a context manager exposing run, x and
# y -- TODO confirm against loadsave.py.
#
# The block bodies must be indented: Python rejects an unindented statement
# after "with ...:" or "for ...:" with an IndentationError.
TEST_PY='
import sys
import numpy as np
import tensorflow.examples.tutorials.mnist as tf_mnist
import loadsave
from utils import *
mnist = tf_mnist.input_data.read_data_sets("__mnist__")
with loadsave.load(sys.argv[1]) as s:
    data = np.reshape(mnist.test.images, [-1, 28, 28]), mnist.test.labels
    total_ok = 0
    for x, t in random_batches(data, 1000):
        y = s.run(s.y, feed_dict = {s.x: x})
        mistakes = np.not_equal(np.argmax(y, 1), t)
        total_ok += len(x) - np.sum(mistakes)
        write_array(sys.stdout.buffer, x[mistakes])
    print_info("Accuracy: %.2f%% (%d/%d)" % (
        100.0 * (total_ok / len(data[0])),
        total_ok,
        len(data[0])
    ))
'
#######################################
# Run the Python program stored in TEST_PY.
# Globals:   TEST_PY (read) - program text, supplied to python3 on stdin
# Arguments: all arguments are forwarded to the program after the "-"
#            script placeholder
# Outputs:   whatever the Python program writes to stdout/stderr
# Returns:   exit status of python3
#######################################
print_mistakes() {
  # A here-string delivers the program text plus one trailing newline.
  python3 - "$@" <<<"$TEST_PY"
}
# ---------------------------------------------------------------------------
# Mode dispatch.  $1 is the saved model; $2 selects what to do with it:
#   "capture"            pipe image_capture -> image_write -> model_run ->
#                        image_show
#   an existing dir      pipe image_read on that dir -> model_run -> od
#   any other non-empty  create directory $2 and pipe print_mistakes into
#                        image_write "$2/%05d.png"
#   empty / absent       run print_mistakes and discard its stdout
# image_capture, image_write, image_read, model_run and image_show are
# external commands, presumably located next to this script -- their exact
# semantics are not visible here.
# ---------------------------------------------------------------------------

if [[ "$2" == capture ]]; then
  # Only the last stage decides the outcome here on purpose: this pipeline
  # looks interactive, and when the final stage quits first the upstream
  # stages may be terminated by SIGPIPE, which must not count as a failure.
  image_capture 200,200 \
    | image_write 200,200 -resize 28x28 GRAY:- \
    | model_run "$1" x:0 y:0 \
    | image_show 10,1 || exit 1
  exit 0
fi

# For the batch pipelines below a failure in ANY stage must fail the script.
# Without pipefail, "a | b || exit 1" only looks at b, so e.g. a crashing
# model_run followed by a successful od would still end in "exit 0".
set -o pipefail

if [[ -n "$2" && -d "$2" ]]; then
  # Count the non-hidden entries with a glob instead of parsing ls output.
  shopt -s nullglob
  entries=("$2"/*)
  shopt -u nullglob
  if (( ${#entries[@]} < 1 )); then
    echo "Error: no images found in '$2'" 1>&2
    exit 1
  fi
  # od prints the model output as unsigned bytes, ten per line, with no
  # offset column and without collapsing repeated lines.
  image_read 28,28 "$2" \
    | model_run "$1" x:0 y:0 \
    | od -v -An -t u1 -w10 || exit 1
  exit 0
fi

if [[ -n "$2" ]]; then
  # $2 is not an existing directory: create it (mkdir fails, and the script
  # with it, if the path already exists as something else).
  mkdir -- "$2" || exit 1
  print_mistakes "$1" \
    | image_write 28,28 "$2/%05d.png" || exit 1
  exit 0
fi

# No second argument: the image data on stdout is thrown away; anything the
# evaluation writes to stderr still reaches the caller.
print_mistakes "$1" >/dev/null || exit 1
exit 0