Commit 5d70ba6 (1 parent: 7f3ea9d)
Showing 19 changed files with 698 additions and 6 deletions.
data.py
@@ -0,0 +1,121 @@
import numpy as np
import scipy.linalg
import os,time
import torch

import warp,util

# load MNIST data
def loadMNIST(fname):
    if not os.path.exists(fname):
        # download and preprocess MNIST dataset
        from tensorflow.examples.tutorials.mnist import input_data
        mnist = input_data.read_data_sets("MNIST_data/",one_hot=True)
        trainData,validData,testData = {},{},{}
        trainData["image"] = mnist.train.images.reshape([-1,28,28]).astype(np.float32)
        validData["image"] = mnist.validation.images.reshape([-1,28,28]).astype(np.float32)
        testData["image"] = mnist.test.images.reshape([-1,28,28]).astype(np.float32)
        trainData["label"] = np.argmax(mnist.train.labels.astype(np.float32),axis=1)
        validData["label"] = np.argmax(mnist.validation.labels.astype(np.float32),axis=1)
        testData["label"] = np.argmax(mnist.test.labels.astype(np.float32),axis=1)
        os.makedirs(os.path.dirname(fname),exist_ok=True)
        np.savez(fname,train=trainData,valid=validData,test=testData)
        os.system("rm -rf MNIST_data")
    MNIST = np.load(fname,allow_pickle=True)  # saved entries are pickled dicts
    trainData = MNIST["train"].item()
    validData = MNIST["valid"].item()
    testData = MNIST["test"].item()
    return trainData,validData,testData

# generate random warp perturbations for a training batch
def genPerturbations(opt):
    X = np.tile(opt.canon4pts[:,0],[opt.batchSize,1])
    Y = np.tile(opt.canon4pts[:,1],[opt.batchSize,1])
    O = np.zeros([opt.batchSize,4],dtype=np.float32)
    I = np.ones([opt.batchSize,4],dtype=np.float32)
    # random displacements of the four canonical corners, plus a shared translation
    dX = np.random.randn(opt.batchSize,4)*opt.pertScale \
        +np.random.randn(opt.batchSize,1)*opt.transScale
    dY = np.random.randn(opt.batchSize,4)*opt.pertScale \
        +np.random.randn(opt.batchSize,1)*opt.transScale
    dX,dY = dX.astype(np.float32),dY.astype(np.float32)
    # fit warp parameters to the generated displacements
    if opt.warpType=="homography":
        # solve the exact 8-DoF linear system A p = b from the 4 correspondences
        A = np.concatenate([np.stack([X,Y,I,O,O,O,-X*(X+dX),-Y*(X+dX)],axis=-1),
                            np.stack([O,O,O,X,Y,I,-X*(Y+dY),-Y*(Y+dY)],axis=-1)],axis=1)
        b = np.expand_dims(np.concatenate([X+dX,Y+dY],axis=1),axis=-1)
        pPert = np.matmul(np.linalg.inv(A),b).squeeze()
        pPert -= np.array([1,0,0,0,1,0,0,0])  # subtract identity so p=0 means no warp
    else:
        if opt.warpType=="translation":
            J = np.concatenate([np.stack([I,O],axis=-1),
                                np.stack([O,I],axis=-1)],axis=1)
        if opt.warpType=="similarity":
            J = np.concatenate([np.stack([X,Y,I,O],axis=-1),
                                np.stack([-Y,X,O,I],axis=-1)],axis=1)
        if opt.warpType=="affine":
            J = np.concatenate([np.stack([X,Y,I,O,O,O],axis=-1),
                                np.stack([O,O,O,X,Y,I],axis=-1)],axis=1)
        # least-squares fit: p = (J^T J)^{-1} J^T dXY
        dXY = np.expand_dims(np.concatenate([dX,dY],axis=1),axis=-1)
        Jtransp = np.transpose(J,axes=[0,2,1])
        pPert = np.matmul(np.linalg.inv(np.matmul(Jtransp,J)),np.matmul(Jtransp,dXY)).squeeze()
    pInit = util.toTorch(pPert)
    return pInit

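# Illustrative note (not called anywhere): the non-homography branch above is an
# ordinary least-squares fit, p = (J^T J)^{-1} J^T dXY. For warpType=="translation",
# J has four [1,0] rows and four [0,1] rows, so J^T J = 4*I_2 and J^T dXY sums the
# corner displacements; the fitted warp is simply [mean(dX), mean(dY)].
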
# make training batch
def makeBatch(opt,data):
    N = len(data["image"])
    randIdx = np.random.randint(N,size=[opt.batchSize])
    batch = {
        "image": util.toTorch(data["image"][randIdx]),
        "label": util.toTorch(data["label"][randIdx]),
    }
    return batch

# evaluation on test set
def evalTest(opt,data,geometric,classifier):
    geometric.eval()
    classifier.eval()
    N = len(data["image"])
    batchN = int(np.ceil(N/opt.batchSize))
    warped = [{},{}]
    count = 0
    for b in range(batchN):
        # use some dummy data (index 0) as batch filler if necessary
        if b!=batchN-1:
            realIdx = np.arange(opt.batchSize*b,opt.batchSize*(b+1))
        else:
            realIdx = np.arange(opt.batchSize*b,N)
        idx = np.zeros([opt.batchSize],dtype=int)
        idx[:len(realIdx)] = realIdx
        # make evaluation batch
        image = util.toTorch(data["image"][idx])
        label = util.toTorch(data["label"][idx])
        image.data.unsqueeze_(dim=1)
        # generate perturbation
        pInit = genPerturbations(opt)
        pInitMtrx = warp.vec2mtrx(opt,pInit)
        imagePert = warp.transformImage(opt,image,pInitMtrx)
        imageWarpAll = geometric(opt,image,pInit) if opt.netType=="IC-STN" else geometric(opt,imagePert)
        imageWarp = imageWarpAll[-1]
        output = classifier(opt,imageWarp)
        _,pred = output.max(dim=1)
        # count correct predictions over real samples only (exclude batch fillers)
        count += int(util.toNumpy((pred==label)[:len(realIdx)].sum()))
        if opt.netType=="STN" or opt.netType=="IC-STN":
            imgPert = util.toNumpy(imagePert)
            imgWarp = util.toNumpy(imageWarp)
            for i in range(len(realIdx)):
                l = data["label"][idx[i]]
                if l not in warped[0]: warped[0][l] = []
                if l not in warped[1]: warped[1][l] = []
                warped[0][l].append(imgPert[i])
                warped[1][l].append(imgWarp[i])
    accuracy = float(count)/N
    if opt.netType=="STN" or opt.netType=="IC-STN":
        # per-class mean/variance of perturbed vs. aligned images
        mean = [np.array([np.mean(warped[0][l],axis=0) for l in warped[0]]),
                np.array([np.mean(warped[1][l],axis=0) for l in warped[1]])]
        var = [np.array([np.var(warped[0][l],axis=0) for l in warped[0]]),
               np.array([np.var(warped[1][l],axis=0) for l in warped[1]])]
    else: mean,var = None,None
    geometric.train()
    classifier.train()
    return accuracy,mean,var
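
Taken together, these helpers form the data pipeline. A minimal usage sketch, assuming options.set (as used in train.py below) supplies the fields referenced above (batchSize, warpDim, warpType, pertScale, transScale, canon4pts):

import data,options
opt = options.set(training=True)
trainData,validData,testData = data.loadMNIST("data/MNIST.npz")
batch = data.makeBatch(opt,trainData)  # {"image": [batchSize,28,28], "label": [batchSize]}
pInit = data.genPerturbations(opt)     # random warp parameters, shape [batchSize,warpDim]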
train.py
@@ -0,0 +1,74 @@
import numpy as np
import time,os,sys
import argparse
import util

print(util.toYellow("======================================================="))
print(util.toYellow("train.py (training on MNIST)"))
print(util.toYellow("======================================================="))

import torch
import data,graph,warp,util
import options

print(util.toMagenta("setting configurations..."))
opt = options.set(training=True)

# create directories for model output
util.mkdir("models_{0}".format(opt.group))

print(util.toMagenta("building network...")) | ||
with torch.cuda.device(0): | ||
classifier = graph.FullCNN(opt) | ||
# ------ define loss ------ | ||
loss = torch.nn.CrossEntropyLoss() | ||
# ------ optimizer ------ | ||
optimList = [{ "params": classifier.parameters(), "lr": opt.lrC }] | ||
optim = torch.optim.SGD(optimList) | ||
|
||
# load data | ||
print(util.toMagenta("loading MNIST dataset...")) | ||
trainData,validData,testData = data.loadMNIST("data/MNIST.npz") | ||
|
||
print(util.toYellow("======= TRAINING START =======")) | ||
timeStart = time.time() | ||
# start session | ||
with torch.cuda.device(0): | ||
classifier.train() | ||
print(util.toMagenta("start training...")) | ||
|
||
# training loop | ||
for i in range(opt.fromIt,opt.toIt): | ||
# make training batch | ||
batch = data.makeBatch(opt,trainData) | ||
image = batch["image"].unsqueeze(dim=1) | ||
label = batch["label"] | ||
# generate perturbation | ||
pInit = util.toTorch(np.zeros([opt.batchSize,opt.warpDim],dtype=np.float32)) | ||
pInitMtrx = warp.vec2mtrx(opt,pInit) | ||
# forward/backprop through network | ||
optim.zero_grad() | ||
imagePert = warp.transformImage(opt,image,pInitMtrx) | ||
# print((imagePert-image).data.cpu().mean(),(imagePert-image).data.cpu().var()) | ||
# img = image.data.cpu().numpy() | ||
# imgpert = imagePert.data.cpu().numpy() | ||
# util.imsave("temp1.png",img[0].squeeze()) | ||
# util.imsave("temp2.png",imgpert[0].squeeze()) | ||
# for i in range(100): util.imsave("temp/{0}_1.png".format(i),1-img[i].squeeze()) | ||
# for i in range(100): util.imsave("temp/{0}_2.png".format(i),1-imgpert[i].squeeze()) | ||
# print(imagePert[0,0],image[0,0]) | ||
# assert(False) | ||
# forward/backprop through network | ||
optim.zero_grad() | ||
output = classifier(opt,imagePert) | ||
train_loss = loss(output,label) | ||
train_loss.backward() | ||
# run one step | ||
optim.step() | ||
if (i+1)%100==0: | ||
print("it. {0}/{1} loss={3}, time={2}" | ||
.format(util.toCyan("{0}".format(i+1)), | ||
opt.toIt, | ||
util.toGreen("{0:.2f}".format(time.time()-timeStart)), | ||
util.toRed("{0:.4f}".format(train_loss.data[0])))) | ||
assert(False) |
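
The warp module itself is not shown in this commit view; the calls above assume it provides vec2mtrx (parameter vector to 3x3 homogeneous matrix), transformImage (resampling the image under that matrix), and compose (chaining two parameter vectors). A hypothetical sketch of vec2mtrx for the translation case only, to illustrate the assumed contract:

import torch

def vec2mtrx(opt,p):
    # lift p = [dx,dy] (shape [B,2]) to a batch of 3x3 homogeneous matrices
    O,I = torch.zeros_like(p[:,0]),torch.ones_like(p[:,0])
    dx,dy = p[:,0],p[:,1]
    return torch.stack([torch.stack([I,O,dx],dim=-1),
                        torch.stack([O,I,dy],dim=-1),
                        torch.stack([O,O,I],dim=-1)],dim=1)  # [B,3,3]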
graph.py
@@ -0,0 +1,157 @@
import numpy as np
import torch
import time
import data,warp,util

# build classification network (full CNN baseline)
class FullCNN(torch.nn.Module):
    def __init__(self,opt):
        super(FullCNN,self).__init__()
        self.inDim = 1
        def conv2Layer(outDim):
            conv = torch.nn.Conv2d(self.inDim,outDim,kernel_size=[3,3],stride=1,padding=0)
            self.inDim = outDim
            return conv
        def linearLayer(outDim):
            fc = torch.nn.Linear(self.inDim,outDim)
            self.inDim = outDim
            return fc
        def maxpoolLayer(): return torch.nn.MaxPool2d([2,2],stride=2)
        self.conv2Layers = torch.nn.Sequential(
            conv2Layer(3),torch.nn.ReLU(True),
            conv2Layer(6),torch.nn.ReLU(True),maxpoolLayer(),
            conv2Layer(9),torch.nn.ReLU(True),
            conv2Layer(12),torch.nn.ReLU(True)
        )
        self.inDim *= 8**2  # spatial size after convs/pooling: 28->26->24->12->10->8
        self.linearLayers = torch.nn.Sequential(
            linearLayer(48),torch.nn.ReLU(True),
            linearLayer(opt.labelN)
        )
        initialize(opt,self,opt.stdC)
    def forward(self,opt,image):
        feat = image
        feat = self.conv2Layers(feat).view(opt.batchSize,-1)
        feat = self.linearLayers(feat)
        output = feat
        return output

# build classification network (single conv layer)
class CNN(torch.nn.Module):
    def __init__(self,opt):
        super(CNN,self).__init__()
        self.inDim = 1
        def conv2Layer(outDim):
            conv = torch.nn.Conv2d(self.inDim,outDim,kernel_size=[9,9],stride=1,padding=0)
            self.inDim = outDim
            return conv
        def linearLayer(outDim):
            fc = torch.nn.Linear(self.inDim,outDim)
            self.inDim = outDim
            return fc
        def maxpoolLayer(): return torch.nn.MaxPool2d([2,2],stride=2)
        self.conv2Layers = torch.nn.Sequential(
            conv2Layer(3),torch.nn.ReLU(True)
        )
        self.inDim *= 20**2  # spatial size after the 9x9 conv: 28->20
        self.linearLayers = torch.nn.Sequential(
            linearLayer(opt.labelN)
        )
        initialize(opt,self,opt.stdC)
    def forward(self,opt,image):
        feat = image
        feat = self.conv2Layers(feat).view(opt.batchSize,-1)
        feat = self.linearLayers(feat)
        output = feat
        return output

# an identity class to skip geometric predictors
class Identity(torch.nn.Module):
    def __init__(self): super(Identity,self).__init__()
    def forward(self,opt,feat): return [feat]

# build Spatial Transformer Network
class STN(torch.nn.Module):
    def __init__(self,opt):
        super(STN,self).__init__()
        self.inDim = 1
        def conv2Layer(outDim):
            conv = torch.nn.Conv2d(self.inDim,outDim,kernel_size=[7,7],stride=1,padding=0)
            self.inDim = outDim
            return conv
        def linearLayer(outDim):
            fc = torch.nn.Linear(self.inDim,outDim)
            self.inDim = outDim
            return fc
        def maxpoolLayer(): return torch.nn.MaxPool2d([2,2],stride=2)
        self.conv2Layers = torch.nn.Sequential(
            conv2Layer(4),torch.nn.ReLU(True),
            conv2Layer(8),torch.nn.ReLU(True),maxpoolLayer()
        )
        self.inDim *= 8**2  # spatial size after convs/pooling: 28->22->16->8
        self.linearLayers = torch.nn.Sequential(
            linearLayer(48),torch.nn.ReLU(True),
            linearLayer(opt.warpDim)
        )
        initialize(opt,self,opt.stdGP,last0=True)
    def forward(self,opt,image):
        imageWarpAll = [image]
        feat = image
        feat = self.conv2Layers(feat).view(opt.batchSize,-1)
        feat = self.linearLayers(feat)
        p = feat  # predicted warp parameters (single shot)
        pMtrx = warp.vec2mtrx(opt,p)
        imageWarp = warp.transformImage(opt,image,pMtrx)
        imageWarpAll.append(imageWarp)
        return imageWarpAll

# build Inverse Compositional STN
class ICSTN(torch.nn.Module):
    def __init__(self,opt):
        super(ICSTN,self).__init__()
        self.inDim = 1
        def conv2Layer(outDim):
            conv = torch.nn.Conv2d(self.inDim,outDim,kernel_size=[7,7],stride=1,padding=0)
            self.inDim = outDim
            return conv
        def linearLayer(outDim):
            fc = torch.nn.Linear(self.inDim,outDim)
            self.inDim = outDim
            return fc
        def maxpoolLayer(): return torch.nn.MaxPool2d([2,2],stride=2)
        self.conv2Layers = torch.nn.Sequential(
            conv2Layer(4),torch.nn.ReLU(True),
            conv2Layer(8),torch.nn.ReLU(True),maxpoolLayer()
        )
        self.inDim *= 8**2  # spatial size after convs/pooling: 28->22->16->8
        self.linearLayers = torch.nn.Sequential(
            linearLayer(48),torch.nn.ReLU(True),
            linearLayer(opt.warpDim)
        )
        initialize(opt,self,opt.stdGP,last0=True)
    def forward(self,opt,image,p):
        imageWarpAll = []
        for l in range(opt.warpN):
            pMtrx = warp.vec2mtrx(opt,p)
            imageWarp = warp.transformImage(opt,image,pMtrx)
            imageWarpAll.append(imageWarp)
            feat = imageWarp
            feat = self.conv2Layers(feat).view(opt.batchSize,-1)
            feat = self.linearLayers(feat)
            dp = feat  # predicted incremental warp update
            p = warp.compose(opt,p,dp)
        pMtrx = warp.vec2mtrx(opt,p)
        imageWarp = warp.transformImage(opt,image,pMtrx)
        imageWarpAll.append(imageWarp)
        return imageWarpAll
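
# Note: the recurrence above is the inverse-compositional update. Each pass
# predicts a small correction dp from the currently warped image and folds it
# into the running parameters (p <- compose(p,dp)), always resampling from the
# original image, rather than re-predicting the full warp in one shot as the
# STN class above does.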
# initialize weights/biases
def initialize(opt,model,stddev,last0=False):
    for m in model.conv2Layers:
        if isinstance(m,torch.nn.Conv2d):
            m.weight.data.normal_(0,stddev)
            m.bias.data.normal_(0,stddev)
    for m in model.linearLayers:
        if isinstance(m,torch.nn.Linear):
            if last0 and m is model.linearLayers[-1]:
                # zero-init the last layer of geometric predictors so the
                # initial predicted warp is the identity
                m.weight.data.zero_()
                m.bias.data.zero_()
            else:
                m.weight.data.normal_(0,stddev)
                m.bias.data.normal_(0,stddev)
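
A minimal sketch of how these classes plug into the evaluation path in data.py, assuming opt carries netType, warpDim, warpN, stdC, stdGP, and labelN as used above:

import data,graph,options
opt = options.set(training=True)
geometric = graph.ICSTN(opt) if opt.netType=="IC-STN" else \
            graph.STN(opt) if opt.netType=="STN" else graph.Identity()
classifier = graph.FullCNN(opt)
trainData,validData,testData = data.loadMNIST("data/MNIST.npz")
accuracy,mean,var = data.evalTest(opt,testData,geometric,classifier)
print("test accuracy: {0:.4f}".format(accuracy))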