Skip to content

Commit

Permalink
modernized MarginCriterion
Browse files Browse the repository at this point in the history
  • Loading branch information
szagoruyko authored and soumith committed Oct 27, 2014
1 parent a1cddee commit 9176533
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 15 deletions.
24 changes: 9 additions & 15 deletions MarginCriterion.lua
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,15 @@ local MarginCriterion, parent = torch.class('nn.MarginCriterion', 'nn.Criterion'

function MarginCriterion:__init(margin)
parent.__init(self)
margin=margin or 1
self.margin = margin
self.gradInput = torch.Tensor(1)
end

function MarginCriterion:updateOutput(input,y)
self.output=math.max(0, self.margin- y* input[1])
return self.output
self.sizeAverage = true
self.margin = margin or 1
end

function MarginCriterion:updateGradInput(input, y)
if (y*input[1])<self.margin then
self.gradInput[1]=-y
else
self.gradInput[1]=0;
end
return self.gradInput
function MarginCriterion:updateOutput(input, target)
return input.nn.MarginCriterion_updateOutput(self, input, target)
end

function MarginCriterion:updateGradInput(input, target)
input.nn.MarginCriterion_updateGradInput(self, input, target)
return self.gradInput
end
56 changes: 56 additions & 0 deletions generic/MarginCriterion.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
#ifndef TH_GENERIC_FILE
#define TH_GENERIC_FILE "generic/MarginCriterion.c"
#else

static int nn_(MarginCriterion_updateOutput)(lua_State *L)
{
THTensor *input = luaT_checkudata(L, 2, torch_Tensor);
THTensor *target = luaT_checkudata(L, 3, torch_Tensor);
int sizeAverage = luaT_getfieldcheckboolean(L, 1, "sizeAverage");
real margin = luaT_getfieldchecknumber(L, 1, "margin");
real sum;

sum = 0;
TH_TENSOR_APPLY2(real, input, real, target,
real z = (margin - *input_data* *target_data);
sum += z>0 ? z : 0;)

if(sizeAverage)
sum /= THTensor_(nElement)(input);

lua_pushnumber(L, sum);
lua_setfield(L, 1, "output");

lua_pushnumber(L, sum);
return 1;
}

static int nn_(MarginCriterion_updateGradInput)(lua_State *L)
{
THTensor *input = luaT_checkudata(L, 2, torch_Tensor);
THTensor *target = luaT_checkudata(L, 3, torch_Tensor);
int sizeAverage = luaT_getfieldcheckboolean(L, 1, "sizeAverage");
THTensor *gradInput = luaT_getfieldcheckudata(L, 1, "gradInput", torch_Tensor);
real margin = luaT_getfieldchecknumber(L, 1, "margin");
real norm = (sizeAverage ? 1./((real)THTensor_(nElement)(input)) : 1.);

THTensor_(resizeAs)(gradInput, input);
TH_TENSOR_APPLY3(real, gradInput, real, input, real, target,
*gradInput_data = (*input_data * *target_data) < margin ? -norm* *target_data : 0;)
return 1;
}

static const struct luaL_Reg nn_(MarginCriterion__) [] = {
{"MarginCriterion_updateOutput", nn_(MarginCriterion_updateOutput)},
{"MarginCriterion_updateGradInput", nn_(MarginCriterion_updateGradInput)},
{NULL, NULL}
};

static void nn_(MarginCriterion_init)(lua_State *L)
{
luaT_pushmetatable(L, torch_Tensor);
luaT_registeratname(L, nn_(MarginCriterion__), "nn");
lua_pop(L,1);
}

#endif
5 changes: 5 additions & 0 deletions init.c
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@
#include "generic/MSECriterion.c"
#include "THGenerateFloatTypes.h"

#include "generic/MarginCriterion.c"
#include "THGenerateFloatTypes.h"

#include "generic/AbsCriterion.c"
#include "THGenerateFloatTypes.h"

Expand Down Expand Up @@ -129,6 +132,7 @@ int luaopen_libnn(lua_State *L)
nn_FloatHardTanh_init(L);
nn_FloatLogSoftMax_init(L);
nn_FloatMSECriterion_init(L);
nn_FloatMarginCriterion_init(L);
nn_FloatAbsCriterion_init(L);
nn_FloatDistKLDivCriterion_init(L);
nn_FloatLogSigmoid_init(L);
Expand Down Expand Up @@ -166,6 +170,7 @@ int luaopen_libnn(lua_State *L)
nn_DoubleHardTanh_init(L);
nn_DoubleLogSoftMax_init(L);
nn_DoubleMSECriterion_init(L);
nn_DoubleMarginCriterion_init(L);
nn_DoubleAbsCriterion_init(L);
nn_DoubleDistKLDivCriterion_init(L);
nn_DoubleLogSigmoid_init(L);
Expand Down
7 changes: 7 additions & 0 deletions test/test.lua
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,13 @@ function nntest.MSECriterion()
criterionJacobianTest1D(cri, input, target)
end

function nntest.MarginCriterion()
local input = torch.rand(100)
local target = input:clone():add(torch.rand(100))
local cri = nn.MarginCriterion()
criterionJacobianTest1D(cri, input, target)
end

function nntest.WeightedMSECriterion()
local input = torch.rand(100)
local target = input:clone():add(torch.rand(100))
Expand Down

0 comments on commit 9176533

Please sign in to comment.