
Commit a929ca2

Fixed bug in contrast normalization backpropagation code which caused wrong gradients to be computed near image borders.

akrizhevsky committed Jul 17, 2012
1 parent 660ea54 commit a929ca2
Showing 6 changed files with 46 additions and 32 deletions.
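
The substance of the fix: pooling windows that overlap the image border are clipped, but the old code normalized each average-pooled output (and its gradient) by the full window area, subsX * subsX, rather than by the area of the clipped region. A minimal 1-D NumPy sketch of the difference (illustrative only, not code from this repository):

    import numpy as np

    img = np.ones(6, dtype=np.float32)  # constant image: every local mean should be 1.0
    size, start = 3, -1                 # window size and offset, as in -_size/2

    for px in range(len(img)):
        lo = max(0, px + start)                 # clip the window at the borders
        hi = min(len(img), px + start + size)
        s = img[lo:hi].sum()
        print(px, s / size, s / (hi - lo))      # old: always /3; fixed: /2 at the borders

At the borders the old normalization yields 0.67 instead of 1.0, and the gradients flowing back through those outputs were shrunk the same way; the fix divides by the true region size, as the diffs below show.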
17 changes: 10 additions & 7 deletions data.py
@@ -45,7 +45,7 @@ def __init__(self, data_dir, batch_range=None, init_epoch=1, init_batchnum=None,
         self.batch_meta = self.get_batch_meta(data_dir)
         self.data_dic = None
         self.test = test
-        self.batch_range_idx = batch_range.index(init_batchnum)
+        self.batch_idx = batch_range.index(init_batchnum)
 
     def get_next_batch(self):
         if self.data_dic is None or len(self.batch_range) > 1:
@@ -84,13 +84,13 @@ def get_data_dims(self):
         return self.batch_meta['num_vis']
 
     def advance_batch(self):
-        self.batch_range_idx = self.get_next_batch_idx()
-        self.curr_batchnum = self.batch_range[self.batch_range_idx]
-        if self.batch_range_idx == 0: # we wrapped
+        self.batch_idx = self.get_next_batch_idx()
+        self.curr_batchnum = self.batch_range[self.batch_idx]
+        if self.batch_idx == 0: # we wrapped
             self.curr_epoch += 1
 
     def get_next_batch_idx(self):
-        return (self.batch_range_idx + 1) % len(self.batch_range)
+        return (self.batch_idx + 1) % len(self.batch_range)
 
     def get_next_batch_num(self):
         return self.batch_range[self.get_next_batch_idx()]
@@ -150,7 +150,7 @@ def __init__(self, data_dim):
         self.batch_meta = {'num_vis': data_dim, 'data_in_rows':True}
         self.curr_epoch = 1
         self.curr_batchnum = 1
-        self.batch_range_idx = 0
+        self.batch_idx = 0
 
     def get_next_batch(self):
         epoch, batchnum = self.curr_epoch, self.curr_batchnum
@@ -170,7 +170,7 @@ def __init__(self, data_dim, num_classes=10, num_cases=512):
         self.num_classes = num_classes
         self.curr_epoch = 1
         self.curr_batchnum = 1
-        self.batch_range_idx=0
+        self.batch_idx=0
 
     def get_num_classes(self):
         return self.num_classes
@@ -197,6 +197,9 @@ def get_next_batch(self):
         return epoch, batchnum, self.data_dic[batchnum - self.batch_range[0]]
 
 class LabeledDataProvider(DataProvider):
+    def __init__(self, data_dir, batch_range=None, init_epoch=1, init_batchnum=None, dp_params={}, test=False):
+        DataProvider.__init__(self, data_dir, batch_range, init_epoch, init_batchnum, dp_params, test)
+
     def get_num_classes(self):
         return len(self.batch_meta['label_names'])
14 changes: 9 additions & 5 deletions gpumodel.py
@@ -225,12 +225,16 @@ def get_test_error(self):
         while True:
             data = next_data
             self.start_batch(data, train=False)
-            if not self.test_one and data[1] < self.test_batch_range[-1]: # load next batch
+            load_next = not self.test_one and data[1] < self.test_batch_range[-1]
+            if load_next: # load next batch
                 next_data = self.get_next_batch(train=False)
-                test_outputs += [self.finish_batch()]
-            else:
-                test_outputs += [self.finish_batch()]
-                break
+            test_outputs += [self.finish_batch()]
+            if self.test_only: # Print the individual batch results for safety
+                print "batch %d: %s" % (data[1], str(test_outputs[-1]))
+            if not load_next:
+                break
             sys.stdout.flush()
 
         return self.aggregate_test_outputs(test_outputs)
 
     def set_var(self, var_name, var_val):
22 changes: 10 additions & 12 deletions include/cudaconv2/conv_util.cuh
@@ -69,19 +69,15 @@ void convResponseNormCrossMap(NVMatrix& images, NVMatrix& denoms, NVMatrix& targ
                               float powScale, bool blocked);
 
 class AvgPooler {
-private:
-    float _num;
 public:
-    AvgPooler(float num) : _num(num) {
-    }
     __device__ inline float operator()(const float a, const float b) const {
         return a + b;
     }
     __device__ inline float getBaseValue() const {
         return 0;
     }
-    __device__ inline float output(const float a) const {
-        return a / _num;
+    __device__ inline float output(const float a, const int regionSize) const {
+        return a / regionSize;
     }
 };
 
@@ -93,7 +89,7 @@ public:
     __device__ inline float getBaseValue() const {
         return -2e38;
     }
-    __device__ inline float output(const float a) const {
+    __device__ inline float output(const float a, const int regionSize) const {
         return a;
     }
 };
@@ -106,7 +102,7 @@ public:
     __device__ inline float getBaseValue() const {
         return 0.0f;
     }
-    __device__ inline float output(const float a) const {
+    __device__ inline float output(const float a, const int regionSize) const {
         return a;
     }
 };
@@ -166,6 +162,7 @@ __global__ void kLocalPool(float* imgs, float* target, const int imgSize, const
     const int loopStartX = MAX(0, startImgPxX);
     const int loopEndY = MIN(imgSize, startImgPxY + subsX);
     const int loopEndX = MIN(imgSize, startImgPxX + subsX);
+    const int regionSize = (loopEndY - loopStartY) * (loopEndX - loopStartX);
     for (int y = loopStartY; y < loopEndY; y++) {
         for (int x = loopStartX; x < loopEndX; x++) {
             const int imgPx = y * imgSize + x;
@@ -186,7 +183,7 @@ __global__ void kLocalPool(float* imgs, float* target, const int imgSize, const
         if (!checkCaseBounds || imgIdx + i * B_X < numImages) {
             #pragma unroll
             for (int f = 0; f < filtersPerThread; f++) {
-                target[f * numOutputs * numImages + i * B_X] = agg.output(prod[f][i]);
+                target[f * numOutputs * numImages + i * B_X] = agg.output(prod[f][i], regionSize);
             }
         }
     }
@@ -259,7 +256,7 @@ __global__ void kLocalPool2(float* imgs, float* target, const int imgSize, const
     const int loopStartX = MAX(startImgPxX, 0);
     const int loopEndY = MIN(imgSize, endImgPxY + 3);
     const int loopEndX = MIN(imgSize, endImgPxX + 3);
-    
+
     const int imgIdx = blockImgIdx + threadIdx.x;
 
     imgs += (blockFilterIdx + loadY) * imgPixels * numImages + blockImgIdx + loadX;
@@ -273,7 +270,7 @@ __global__ void kLocalPool2(float* imgs, float* target, const int imgSize, const
                 prod[f][i] = agg.getBaseValue();
             }
         }
-    
+    int regionSize = 0;
     for (int y = loopStartY; y < loopEndY; y++) {
         const bool isInY = y >= myStartImgPxY && y < myEndImgPxY ;
         for (int x = loopStartX; x < loopEndX; x++) {
@@ -303,6 +300,7 @@ __global__ void kLocalPool2(float* imgs, float* target, const int imgSize, const
                     }
                 }
             }
+            ++regionSize;
         }
         __syncthreads();
 
@@ -314,7 +312,7 @@ __global__ void kLocalPool2(float* imgs, float* target, const int imgSize, const
         if (!checkCaseBounds || imgIdx + i * B_X < numImages) {
             #pragma unroll
             for (int f = 0; f < filtersPerThread; f++) {
-                target[f * numOutputs * numImages + i * B_X] = agg.output(prod[f][i]);
+                target[f * numOutputs * numImages + i * B_X] = agg.output(prod[f][i], regionSize);
             }
         }
     }
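
This header change moves the normalizer out of AvgPooler's constructor and into output(), which the kernels now call with a region size computed from the clipped loop bounds, so each output is divided by the number of pixels actually summed; MaxPooler and ProbPooler accept the argument and ignore it. A rough Python analogue of the reworked contract (the real functors are __device__ classes in conv_util.cuh; this is only a sketch):

    class AvgPooler:
        def base(self): return 0.0
        def __call__(self, a, b): return a + b                      # accumulate
        def output(self, a, region_size): return a / region_size   # normalize by clipped area

    class MaxPooler:
        def base(self): return -2e38
        def __call__(self, a, b): return max(a, b)
        def output(self, a, region_size): return a                  # region size is irrelevant

    def pool1d(img, agg, size, start):
        out = []
        for px in range(len(img)):
            lo, hi = max(0, px + start), min(len(img), px + start + size)
            acc = agg.base()
            for x in range(lo, hi):
                acc = agg(acc, img[x])
            out.append(agg.output(acc, hi - lo))  # pass the true region size, as kLocalPool does
        return out

kLocalPool can compute regionSize once up front from its loop bounds; kLocalPool2 instead counts it incrementally (++regionSize) because each thread discovers its window's valid pixels as it iterates.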
18 changes: 13 additions & 5 deletions src/cudaconv2/conv_util.cu
@@ -1423,18 +1423,27 @@ __global__ void kLocalAvgUndo(float* avgGrads, float* target, const int imgSize,
             }
         }
     }
 
-    if (blockPxX >= startX && blockPxX < startX + strideX * (outputsX-1) + subsX 
+    if (blockPxX >= startX && blockPxX < startX + strideX * (outputsX-1) + subsX
         && blockPxY >= startX && blockPxY < startX + strideX * (outputsX-1) + subsX) {
 
         for (int my = startOutputY; my < endOutputY; my++) {
+            const float regionStartY = fmaxf(0, startX + my * strideX);
+            const float regionEndY = fminf(imgSize, startX + my * strideX + subsX);
+            const float regionSizeY = regionEndY - regionStartY;
             for (int mx = startOutputX; mx < endOutputX; mx++) {
                 const int outputIdx = my * outputsX + mx;
+                const float regionStartX = fmaxf(0, startX + mx * strideX);
+                const float regionEndX = fminf(imgSize, startX + mx * strideX + subsX);
+                const float regionSizeX = regionEndX - regionStartX;
+                // It's important to do the division here, because pushing division into the below
+                // loops makes the code 4x slower.
+                const float regionSizeInv = 1.0f / (regionSizeX * regionSizeY);
                 #pragma unroll
                 for (int i = 0; i < imgsPerThread; i++) {
                     if (!checkCaseBounds || imgIdx + i * B_X < numImages) {
                         #pragma unroll
                         for (int f = 0; f < filtersPerThread; f++) {
-                            prod[f][i] += avgGrads[(f * B_Y * numOutputs + outputIdx) * numImages + i * B_X];
+                            prod[f][i] += avgGrads[(f * B_Y * numOutputs + outputIdx) * numImages + i * B_X] * regionSizeInv;
                         }
                     }
                 }
@@ -1448,7 +1457,7 @@ __global__ void kLocalAvgUndo(float* avgGrads, float* target, const int imgSize,
             if (!checkCaseBounds || imgIdx + i * B_X < numImages) {
                 #pragma unroll
                 for (int f = 0; f < filtersPerThread; f++) {
-                    target[f * B_Y * imgPixels * numImages + i * B_X] = prod[f][i] / (subsX * subsX);
+                    target[f * B_Y * imgPixels * numImages + i * B_X] = prod[f][i];
                 }
             }
         }
@@ -1458,7 +1467,7 @@ __global__ void kLocalAvgUndo(float* avgGrads, float* target, const int imgSize,
             if (!checkCaseBounds || imgIdx + i * B_X < numImages) {
                 #pragma unroll
                 for (int f = 0; f < filtersPerThread; f++) {
-                    target[f * B_Y * imgPixels * numImages + i * B_X] = scaleTargets * target[f * B_Y * imgPixels * numImages + i * B_X] + scaleOutputs * prod[f][i] / (subsX * subsX);
+                    target[f * B_Y * imgPixels * numImages + i * B_X] = scaleTargets * target[f * B_Y * imgPixels * numImages + i * B_X] + scaleOutputs * prod[f][i];
                 }
            }
         }
@@ -2335,7 +2344,6 @@ void convResponseNormUndo(NVMatrix& outGrads, NVMatrix& denoms, NVMatrix& inputs
             }
         }
     }
-
     } else {
         int imgsPerThread = numImages % 64 == 0 ? 2 : 1;
         bool checkCaseBounds = numImages % (32*imgsPerThread) != 0;
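
This is the backpropagation half of the fix. In kLocalAvgUndo, each image pixel accumulates the gradient of every pooling output whose window covers it; the old code divided the accumulated sum by subsX * subsX once at the end, while the fix scales each contribution by the inverse of that particular window's clipped area. The inverse is precomputed per output because, as the in-diff comment notes, pushing the division into the inner loops made the kernel 4x slower. A 1-D NumPy reference for the corrected accumulation, a sketch rather than the kernel itself:

    import numpy as np

    def avg_pool_undo(grads, img_len, size, start, stride):
        target = np.zeros(img_len, dtype=np.float32)
        for o, g in enumerate(grads):                  # one gradient per pooling output
            lo = max(0, start + o * stride)            # this output's clipped window
            hi = min(img_len, start + o * stride + size)
            target[lo:hi] += g / (hi - lo)             # divide by the clipped area, not size*size
        return target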
4 changes: 2 additions & 2 deletions src/layer.cu
@@ -712,7 +712,7 @@ AvgPoolLayer::AvgPoolLayer(ConvNet* convNet, PyObject* paramsDict) : PoolLayer(c
 }
 
 void AvgPoolLayer::fpropActs(int inpIdx, float scaleTargets, PASS_TYPE passType) {
-    convLocalPool(*_inputs[0], getActs(), _channels, _sizeX, _start, _stride, _outputsX, AvgPooler(_sizeX*_sizeX));
+    convLocalPool(*_inputs[0], getActs(), _channels, _sizeX, _start, _stride, _outputsX, AvgPooler());
 }
 
 void AvgPoolLayer::bpropActs(NVMatrix& v, int inpIdx, float scaleTargets, PASS_TYPE passType) {
@@ -896,7 +896,7 @@ ContrastNormLayer::ContrastNormLayer(ConvNet* convNet, PyObject* paramsDict) : R
 
 void ContrastNormLayer::fpropActs(int inpIdx, float scaleTargets, PASS_TYPE passType) {
     NVMatrix& images = *_inputs[0];
-    convLocalPool(images, _meanDiffs, _channels, _size, -_size/2, 1, _imgSize, AvgPooler(_size*_size));
+    convLocalPool(images, _meanDiffs, _channels, _size, -_size/2, 1, _imgSize, AvgPooler());
     _meanDiffs.add(images, -1, 1);
     convContrastNorm(images, _meanDiffs, _denoms, getActs(), _channels, _size, _scale, _pow);
 }
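
These call sites are where the pooling change connects to the commit message: ContrastNormLayer builds its mean-subtracted activations (_meanDiffs) through the same AvgPooler path, so border means, and therefore the gradients backpropagated through them, carried the old full-window normalization error. The call sites now simply drop the precomputed _size*_size argument. A hedged 1-D NumPy sketch of that forward step:

    import numpy as np

    def mean_diffs(img, size):
        start = -(size // 2)
        means = np.empty_like(img, dtype=np.float32)
        for px in range(len(img)):
            lo = max(0, px + start)
            hi = min(len(img), px + start + size)
            means[px] = img[lo:hi].mean()   # mean over the clipped window
        return img - means                  # images minus their local means, cf. _meanDiffs.add(images, -1, 1)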
3 changes: 2 additions & 1 deletion src/nvmatrix/nvmatrix.cu
@@ -275,7 +275,6 @@ void NVMatrix::initRandom(unsigned long long seed) {
         rndDevStates[d] = NULL;
         CUDA_CALL(cudaMalloc((void **)&rndDevStates[d], NUM_RND_STREAMS * sizeof(curandState)));
         pthread_mutex_unlock(_rndMutex);
-        printf("initialized random for %d\n", d);
         kSetupCurand<<<NUM_RND_BLOCKS, NUM_RND_THREADS_PER_BLOCK>>>(getCurandState(), 1 + seed*2); // so there's no chance it'll be correlated with the other one
         cutilCheckMsg("initRandom: Kernel execution failed");
     }
@@ -518,6 +517,8 @@ NVMatrix& NVMatrix::reshaped(int numRows, int numCols) {
 void NVMatrix::copy(NVMatrix &dest, int srcStartRow, int srcEndRow,
                     int srcStartCol, int srcEndCol,
                     int destStartRow, int destStartCol) const {
+    srcEndRow = srcEndRow < 0 ? _numRows : srcEndRow;
+    srcEndCol = srcEndCol < 0 ? _numCols : srcEndCol;
     NVMatrix* srcSlice = &slice(srcStartRow, srcEndRow, srcStartCol, srcEndCol);
     NVMatrix* destSlice = &dest.slice(destStartRow, destStartRow + srcEndRow - srcStartRow, destStartCol, destStartCol + srcEndCol - srcStartCol);
     srcSlice->apply(NVMatrixOps::Identity(), *destSlice);
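
The two added lines in NVMatrix::copy resolve the negative "to the end" sentinels before the destination bounds are computed: destStartRow + srcEndRow - srcStartRow is only meaningful once srcEndRow is an actual row count, so letting a negative value flow through that arithmetic would select the wrong destination slice. A hypothetical pure-Python stand-in for the convention (assuming NumPy-style arrays, not the library's API):

    def copy(src, dest, sr, er, sc, ec, dr, dc):
        er = src.shape[0] if er < 0 else er    # negative end index means "through the end"
        ec = src.shape[1] if ec < 0 else ec
        dest[dr:dr + (er - sr), dc:dc + (ec - sc)] = src[sr:er, sc:ec]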
