add pytorch for MNIST
chenhsuanlin committed Dec 14, 2017
1 parent 7f3ea9d commit 5d70ba6
Showing 19 changed files with 698 additions and 6 deletions.
121 changes: 121 additions & 0 deletions MNIST-pytorch/data.py
@@ -0,0 +1,121 @@
import numpy as np
import scipy.linalg
import os,time
import torch

import warp,util

# load MNIST data
def loadMNIST(fname):
	if not os.path.exists(fname):
		# download and preprocess MNIST dataset
		from tensorflow.examples.tutorials.mnist import input_data
		mnist = input_data.read_data_sets("MNIST_data/",one_hot=True)
		trainData,validData,testData = {},{},{}
		trainData["image"] = mnist.train.images.reshape([-1,28,28]).astype(np.float32)
		validData["image"] = mnist.validation.images.reshape([-1,28,28]).astype(np.float32)
		testData["image"] = mnist.test.images.reshape([-1,28,28]).astype(np.float32)
		trainData["label"] = np.argmax(mnist.train.labels.astype(np.float32),axis=1)
		validData["label"] = np.argmax(mnist.validation.labels.astype(np.float32),axis=1)
		testData["label"] = np.argmax(mnist.test.labels.astype(np.float32),axis=1)
		os.makedirs(os.path.dirname(fname),exist_ok=True)
		np.savez(fname,train=trainData,valid=validData,test=testData)
		os.system("rm -rf MNIST_data")
	MNIST = np.load(fname)
	trainData = MNIST["train"].item()
	validData = MNIST["valid"].item()
	testData = MNIST["test"].item()
	return trainData,validData,testData
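
For reference, a minimal usage sketch of the loader above (the split sizes are those of the standard TensorFlow MNIST tutorial splits):

```python
# load (and cache) the dataset, then inspect the splits
trainData,validData,testData = loadMNIST("data/MNIST.npz")
print(trainData["image"].shape)  # (55000, 28, 28), float32 in [0,1]
print(trainData["label"].shape)  # (55000,), class indices 0-9
print(validData["image"].shape)  # (5000, 28, 28)
print(testData["image"].shape)   # (10000, 28, 28)
```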

# generate random warp perturbations on the canonical 4 corners
def genPerturbations(opt):
	X = np.tile(opt.canon4pts[:,0],[opt.batchSize,1])
	Y = np.tile(opt.canon4pts[:,1],[opt.batchSize,1])
	O = np.zeros([opt.batchSize,4],dtype=np.float32)
	I = np.ones([opt.batchSize,4],dtype=np.float32)
	# random displacement = per-corner noise + shared translation noise
	dX = np.random.randn(opt.batchSize,4)*opt.pertScale \
		+np.random.randn(opt.batchSize,1)*opt.transScale
	dY = np.random.randn(opt.batchSize,4)*opt.pertScale \
		+np.random.randn(opt.batchSize,1)*opt.transScale
	dX,dY = dX.astype(np.float32),dY.astype(np.float32)
	# fit warp parameters to the generated corner displacements
	if opt.warpType=="homography":
		# exact solve: 8 equations (4 corners x 2 coordinates) for 8 DoF
		A = np.concatenate([np.stack([X,Y,I,O,O,O,-X*(X+dX),-Y*(X+dX)],axis=-1),
							np.stack([O,O,O,X,Y,I,-X*(Y+dY),-Y*(Y+dY)],axis=-1)],axis=1)
		b = np.expand_dims(np.concatenate([X+dX,Y+dY],axis=1),axis=-1)
		pPert = np.matmul(np.linalg.inv(A),b).squeeze()
		pPert -= np.array([1,0,0,0,1,0,0,0])
	else:
		if opt.warpType=="translation":
			J = np.concatenate([np.stack([I,O],axis=-1),
								np.stack([O,I],axis=-1)],axis=1)
		if opt.warpType=="similarity":
			J = np.concatenate([np.stack([X,Y,I,O],axis=-1),
								np.stack([-Y,X,O,I],axis=-1)],axis=1)
		if opt.warpType=="affine":
			J = np.concatenate([np.stack([X,Y,I,O,O,O],axis=-1),
								np.stack([O,O,O,X,Y,I],axis=-1)],axis=1)
		# least-squares fit via the normal equations: p = (J^T J)^-1 J^T d
		dXY = np.expand_dims(np.concatenate([dX,dY],axis=1),axis=-1)
		Jtransp = np.transpose(J,axes=[0,2,1])
		pPert = np.matmul(np.linalg.inv(np.matmul(Jtransp,J)),np.matmul(Jtransp,dXY)).squeeze()
	pInit = util.toTorch(pPert)
	return pInit
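
To see what the normal-equation fit does in the simplest case, here is a self-contained check for the translation warp: with J built from [[I,O],[O,I]], the least-squares solution is just the mean corner displacement per batch element (a standalone sketch, independent of opt):

```python
import numpy as np

B = 2
O = np.zeros([B,4],dtype=np.float32)
I = np.ones([B,4],dtype=np.float32)
dX = np.random.randn(B,4).astype(np.float32)
dY = np.random.randn(B,4).astype(np.float32)
J = np.concatenate([np.stack([I,O],axis=-1),
                    np.stack([O,I],axis=-1)],axis=1)        # [B,8,2]
d = np.expand_dims(np.concatenate([dX,dY],axis=1),axis=-1)  # [B,8,1]
Jt = np.transpose(J,axes=[0,2,1])
p = np.matmul(np.linalg.inv(np.matmul(Jt,J)),np.matmul(Jt,d)).squeeze()
# the fitted translation equals the mean displacement over the 4 corners
assert np.allclose(p,np.stack([dX.mean(axis=1),dY.mean(axis=1)],axis=1),atol=1e-5)
```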

# make training batch
def makeBatch(opt,data):
	N = len(data["image"])
	randIdx = np.random.randint(N,size=[opt.batchSize])
	batch = {
		"image": util.toTorch(data["image"][randIdx]),
		"label": util.toTorch(data["label"][randIdx]),
	}
	return batch

# evaluation on test set
def evalTest(opt,data,geometric,classifier):
	geometric.eval()
	classifier.eval()
	N = len(data["image"])
	batchN = int(np.ceil(N/opt.batchSize))
	warped = [{},{}]
	count = 0
	for b in range(batchN):
		# use some dummy data (index 0) as batch filler if necessary
		if b!=batchN-1:
			realIdx = np.arange(opt.batchSize*b,opt.batchSize*(b+1))
		else:
			realIdx = np.arange(opt.batchSize*b,N)
		idx = np.zeros([opt.batchSize],dtype=int)
		idx[:len(realIdx)] = realIdx
		# make evaluation batch
		image = util.toTorch(data["image"][idx])
		label = util.toTorch(data["label"][idx])
		image.data.unsqueeze_(dim=1)
		# generate perturbation
		pInit = genPerturbations(opt)
		pInitMtrx = warp.vec2mtrx(opt,pInit)
		imagePert = warp.transformImage(opt,image,pInitMtrx)
		imageWarpAll = geometric(opt,image,pInit) if opt.netType=="IC-STN" else geometric(opt,imagePert)
		imageWarp = imageWarpAll[-1]
		output = classifier(opt,imageWarp)
		_,pred = output.max(dim=1)
		# only the first len(realIdx) entries are real samples; exclude the filler
		count += int(util.toNumpy((pred==label)[:len(realIdx)].sum()))
		if opt.netType=="STN" or opt.netType=="IC-STN":
			imgPert = util.toNumpy(imagePert)
			imgWarp = util.toNumpy(imageWarp)
			for i in range(len(realIdx)):
				l = data["label"][idx[i]]
				if l not in warped[0]: warped[0][l] = []
				if l not in warped[1]: warped[1][l] = []
				warped[0][l].append(imgPert[i])
				warped[1][l].append(imgWarp[i])
	accuracy = float(count)/N
	if opt.netType=="STN" or opt.netType=="IC-STN":
		mean = [np.array([np.mean(warped[0][l],axis=0) for l in warped[0]]),
				np.array([np.mean(warped[1][l],axis=0) for l in warped[1]])]
		var = [np.array([np.var(warped[0][l],axis=0) for l in warped[0]]),
				np.array([np.var(warped[1][l],axis=0) for l in warped[1]])]
	else: mean,var = None,None
	geometric.train()
	classifier.train()
	return accuracy,mean,var
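
The dummy-fill trick keeps every batch at a fixed opt.batchSize so network shapes never change; a tiny standalone sketch of the indexing (the values are hypothetical):

```python
import numpy as np

N,batchSize = 10,4
batchN = int(np.ceil(N/batchSize))   # 3 batches: [0-3], [4-7], [8-9]+filler
b = batchN-1                         # the last, partial batch
realIdx = np.arange(batchSize*b,N)   # [8 9]
idx = np.zeros([batchSize],dtype=int)
idx[:len(realIdx)] = realIdx         # [8 9 0 0] -> last two slots are filler
```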
74 changes: 74 additions & 0 deletions MNIST-pytorch/debug.py
@@ -0,0 +1,74 @@
import numpy as np
import time,os,sys
import argparse
import util

print(util.toYellow("======================================================="))
print(util.toYellow("train.py (training on MNIST)"))
print(util.toYellow("======================================================="))

import torch
import data,graph,warp,util
import options

print(util.toMagenta("setting configurations..."))
opt = options.set(training=True)

# create directories for model output
util.mkdir("models_{0}".format(opt.group))

print(util.toMagenta("building network..."))
with torch.cuda.device(0):
	classifier = graph.FullCNN(opt)
	# ------ define loss ------
	loss = torch.nn.CrossEntropyLoss()
	# ------ optimizer ------
	optimList = [{ "params": classifier.parameters(), "lr": opt.lrC }]
	optim = torch.optim.SGD(optimList)

# load data
print(util.toMagenta("loading MNIST dataset..."))
trainData,validData,testData = data.loadMNIST("data/MNIST.npz")

print(util.toYellow("======= TRAINING START ======="))
timeStart = time.time()
# start session
with torch.cuda.device(0):
	classifier.train()
	print(util.toMagenta("start training..."))

	# training loop
	for i in range(opt.fromIt,opt.toIt):
		# make training batch
		batch = data.makeBatch(opt,trainData)
		image = batch["image"].unsqueeze(dim=1)
		label = batch["label"]
		# generate (zero) perturbation
		pInit = util.toTorch(np.zeros([opt.batchSize,opt.warpDim],dtype=np.float32))
		pInitMtrx = warp.vec2mtrx(opt,pInit)
		# apply perturbation to image
		imagePert = warp.transformImage(opt,image,pInitMtrx)
		# print((imagePert-image).data.cpu().mean(),(imagePert-image).data.cpu().var())
		# img = image.data.cpu().numpy()
		# imgpert = imagePert.data.cpu().numpy()
		# util.imsave("temp1.png",img[0].squeeze())
		# util.imsave("temp2.png",imgpert[0].squeeze())
		# for i in range(100): util.imsave("temp/{0}_1.png".format(i),1-img[i].squeeze())
		# for i in range(100): util.imsave("temp/{0}_2.png".format(i),1-imgpert[i].squeeze())
		# print(imagePert[0,0],image[0,0])
		# assert(False)
		# forward/backprop through network
		optim.zero_grad()
		output = classifier(opt,imagePert)
		train_loss = loss(output,label)
		train_loss.backward()
		# run one step
		optim.step()
		if (i+1)%100==0:
			print("it. {0}/{1} loss={2}, time={3}"
				.format(util.toCyan("{0}".format(i+1)),
						opt.toIt,
						util.toRed("{0:.4f}".format(train_loss.data[0])),
						util.toGreen("{0:.2f}".format(time.time()-timeStart))))
	assert(False)
157 changes: 157 additions & 0 deletions MNIST-pytorch/graph.py
@@ -0,0 +1,157 @@
import numpy as np
import torch
import time
import data,warp,util

# build classification network
class FullCNN(torch.nn.Module):
	def __init__(self,opt):
		super(FullCNN,self).__init__()
		self.inDim = 1
		def conv2Layer(outDim):
			conv = torch.nn.Conv2d(self.inDim,outDim,kernel_size=[3,3],stride=1,padding=0)
			self.inDim = outDim
			return conv
		def linearLayer(outDim):
			fc = torch.nn.Linear(self.inDim,outDim)
			self.inDim = outDim
			return fc
		def maxpoolLayer(): return torch.nn.MaxPool2d([2,2],stride=2)
		self.conv2Layers = torch.nn.Sequential(
			conv2Layer(3),torch.nn.ReLU(True),
			conv2Layer(6),torch.nn.ReLU(True),maxpoolLayer(),
			conv2Layer(9),torch.nn.ReLU(True),
			conv2Layer(12),torch.nn.ReLU(True)
		)
		self.inDim *= 8**2  # feature maps are 8x8 at this point
		self.linearLayers = torch.nn.Sequential(
			linearLayer(48),torch.nn.ReLU(True),
			linearLayer(opt.labelN)
		)
		initialize(opt,self,opt.stdC)
	def forward(self,opt,image):
		feat = image
		feat = self.conv2Layers(feat).view(opt.batchSize,-1)
		feat = self.linearLayers(feat)
		output = feat
		return output
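
As a sanity check on the flattened dimension (a standalone sketch): 3x3 convs without padding shrink 28→26→24, the pool halves to 12, then 12→10→8, so 12 channels x 8 x 8 = 768 = self.inDim * 8**2.

```python
import torch

x = torch.randn(4,1,28,28)
convs = torch.nn.Sequential(
	torch.nn.Conv2d(1,3,3),torch.nn.ReLU(True),
	torch.nn.Conv2d(3,6,3),torch.nn.ReLU(True),torch.nn.MaxPool2d(2,2),
	torch.nn.Conv2d(6,9,3),torch.nn.ReLU(True),
	torch.nn.Conv2d(9,12,3),torch.nn.ReLU(True))
print(convs(x).shape)  # torch.Size([4, 12, 8, 8]) -> 12*8*8 = 768 features
```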

# build classification network
class CNN(torch.nn.Module):
	def __init__(self,opt):
		super(CNN,self).__init__()
		self.inDim = 1
		def conv2Layer(outDim):
			conv = torch.nn.Conv2d(self.inDim,outDim,kernel_size=[9,9],stride=1,padding=0)
			self.inDim = outDim
			return conv
		def linearLayer(outDim):
			fc = torch.nn.Linear(self.inDim,outDim)
			self.inDim = outDim
			return fc
		def maxpoolLayer(): return torch.nn.MaxPool2d([2,2],stride=2)
		self.conv2Layers = torch.nn.Sequential(
			conv2Layer(3),torch.nn.ReLU(True)
		)
		self.inDim *= 20**2
		self.linearLayers = torch.nn.Sequential(
			linearLayer(opt.labelN)
		)
		initialize(opt,self,opt.stdC)
	def forward(self,opt,image):
		feat = image
		feat = self.conv2Layers(feat).view(opt.batchSize,-1)
		feat = self.linearLayers(feat)
		output = feat
		return output

# an identity class to skip geometric predictors
class Identity(torch.nn.Module):
	def __init__(self): super(Identity,self).__init__()
	def forward(self,opt,feat): return [feat]

# build Spatial Transformer Network
class STN(torch.nn.Module):
	def __init__(self,opt):
		super(STN,self).__init__()
		self.inDim = 1
		def conv2Layer(outDim):
			conv = torch.nn.Conv2d(self.inDim,outDim,kernel_size=[7,7],stride=1,padding=0)
			self.inDim = outDim
			return conv
		def linearLayer(outDim):
			fc = torch.nn.Linear(self.inDim,outDim)
			self.inDim = outDim
			return fc
		def maxpoolLayer(): return torch.nn.MaxPool2d([2,2],stride=2)
		self.conv2Layers = torch.nn.Sequential(
			conv2Layer(4),torch.nn.ReLU(True),
			conv2Layer(8),torch.nn.ReLU(True),maxpoolLayer()
		)
		self.inDim *= 8**2
		self.linearLayers = torch.nn.Sequential(
			linearLayer(48),torch.nn.ReLU(True),
			linearLayer(opt.warpDim)
		)
		initialize(opt,self,opt.stdGP,last0=True)
	def forward(self,opt,image):
		imageWarpAll = [image]
		feat = image
		feat = self.conv2Layers(feat).view(opt.batchSize,-1)
		feat = self.linearLayers(feat)
		p = feat
		pMtrx = warp.vec2mtrx(opt,p)
		imageWarp = warp.transformImage(opt,image,pMtrx)
		imageWarpAll.append(imageWarp)
		return imageWarpAll

# build Inverse Compositional STN
class ICSTN(torch.nn.Module):
	def __init__(self,opt):
		super(ICSTN,self).__init__()
		self.inDim = 1
		def conv2Layer(outDim):
			conv = torch.nn.Conv2d(self.inDim,outDim,kernel_size=[7,7],stride=1,padding=0)
			self.inDim = outDim
			return conv
		def linearLayer(outDim):
			fc = torch.nn.Linear(self.inDim,outDim)
			self.inDim = outDim
			return fc
		def maxpoolLayer(): return torch.nn.MaxPool2d([2,2],stride=2)
		self.conv2Layers = torch.nn.Sequential(
			conv2Layer(4),torch.nn.ReLU(True),
			conv2Layer(8),torch.nn.ReLU(True),maxpoolLayer()
		)
		self.inDim *= 8**2
		self.linearLayers = torch.nn.Sequential(
			linearLayer(48),torch.nn.ReLU(True),
			linearLayer(opt.warpDim)
		)
		initialize(opt,self,opt.stdGP,last0=True)
	def forward(self,opt,image,p):
		imageWarpAll = []
		for l in range(opt.warpN):
			pMtrx = warp.vec2mtrx(opt,p)
			# always resample from the original image with the accumulated warp
			imageWarp = warp.transformImage(opt,image,pMtrx)
			imageWarpAll.append(imageWarp)
			feat = imageWarp
			feat = self.conv2Layers(feat).view(opt.batchSize,-1)
			feat = self.linearLayers(feat)
			dp = feat
			# compose the predicted update with the current warp parameters
			p = warp.compose(opt,p,dp)
		pMtrx = warp.vec2mtrx(opt,p)
		imageWarp = warp.transformImage(opt,image,pMtrx)
		imageWarpAll.append(imageWarp)
		return imageWarpAll
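
The loop above is the inverse-compositional recurrence: predict an update dp from the currently warped image, fold it into p, and always resample from the original image so interpolation error does not accumulate across steps. A toy standalone illustration for a pure-translation warp (composeTranslation is a hypothetical stand-in; the real composition lives in warp.compose):

```python
import torch

def composeTranslation(p,dp):
	# composing two translations just adds their offsets
	return p+dp

p = torch.zeros(4,2)             # batch of 4 warps, parameters [tx,ty]
for _ in range(3):               # stands in for the opt.warpN recurrent steps
	dp = 0.01*torch.randn(4,2)   # stands in for the network's predicted update
	p = composeTranslation(p,dp)
print(p)                         # accumulated warp, applied once to the original image
```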

# initialize weights/biases
def initialize(opt,model,stddev,last0=False):
	for m in model.conv2Layers:
		if isinstance(m,torch.nn.Conv2d):
			m.weight.data.normal_(0,stddev)
			m.bias.data.normal_(0,stddev)
	for m in model.linearLayers:
		if isinstance(m,torch.nn.Linear):
			# last0: zero-initialize the final layer so the predicted warp starts at identity
			m.weight.data.normal_(0,0.0 if last0 and m is model.linearLayers[-1] else stddev)
			m.bias.data.normal_(0,0.0 if last0 and m is model.linearLayers[-1] else stddev)
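
Note that normal_ with zero std fills the tensor with the mean, so last0=True makes the geometric predictor output exactly zero at initialization, i.e. the identity warp. A quick standalone check (assuming the same std=0 trick as above):

```python
import torch

fc = torch.nn.Linear(48,8)
fc.weight.data.normal_(0,0.0)
fc.bias.data.normal_(0,0.0)
print(fc(torch.randn(2,48)))  # all zeros -> zero warp update at initialization
```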

