diff --git a/lib/nnc/ccv_nnc.h b/lib/nnc/ccv_nnc.h
index 17076e3fe..72b6b9c73 100644
--- a/lib/nnc/ccv_nnc.h
+++ b/lib/nnc/ccv_nnc.h
@@ -105,6 +105,7 @@ typedef struct {
 	struct {
 		int count; /**< [convolution.count] The number of filters for convolutional layer. */
 		int groups; /**< [convolution.groups] The number of groups for convolutional layer. */
+		int dilation[CCV_NNC_MAX_DIM_ALLOC]; /**< [convolution.dilation[]] The dilation factor for convolutional layer. Default to 1. */
 	} convolution;
 	struct {
 		int hidden_size; /**< [rnn.hidden_size] The number of features in the hidden state h. */
diff --git a/lib/nnc/cmd/blas/ccv_nnc_blas.c b/lib/nnc/cmd/blas/ccv_nnc_blas.c
index cb135de30..361d39533 100644
--- a/lib/nnc/cmd/blas/ccv_nnc_blas.c
+++ b/lib/nnc/cmd/blas/ccv_nnc_blas.c
@@ -234,3 +234,45 @@ REGISTER_COMMAND(CCV_NNC_SCALAR_MUL_BACKWARD)(ccv_nnc_cmd_registry_t* const regi
 #define CMD_SCALAR_MUL_FORWARD(_a) ccv_nnc_cmd(CCV_NNC_SCALAR_MUL_FORWARD, 0, (ccv_nnc_cmd_param_t){.size={.dim={1,1,1}},.blas={.a={_a,}}}, 0)
 //@REGISTER_EASY_COMMAND_MACRO(CCV_NNC_SCALAR_MUL_BACKWARD)
 #define CMD_SCALAR_MUL_BACKWARD(_a) ccv_nnc_cmd(CCV_NNC_SCALAR_MUL_BACKWARD, 0, (ccv_nnc_cmd_param_t){.size={.dim={1,1,1}},.blas={.a={_a,}}}, 0)
+
+static int _ccv_nnc_cmul_forw_bitmask(const ccv_nnc_cmd_param_t cmd, const int input_size, const int output_size, const uint64_t* const input_bitmasks, const int input_bitmask_size, const uint64_t* const output_bitmasks, const int output_bitmask_size)
+{
+	if ((input_bitmasks[0] & 3u) == ((1u << 0) | (1u << 1)) && output_bitmasks[0] == 1u)
+		return 1;
+	return 0;
+}
+
+static int _ccv_nnc_cmul_back_bitmask(const ccv_nnc_cmd_param_t cmd, const int input_size, const int output_size, const uint64_t* const input_bitmasks, const int input_bitmask_size, const uint64_t* const output_bitmasks, const int output_bitmask_size)
+{
+	// w.r.t. both x and y
+	if ((input_bitmasks[0] & 7u) == 7u && output_bitmasks[0] == ((1u << 0) | (1u << 1)))
+		return 1;
+	// w.r.t. x
+	if ((input_bitmasks[0] & 5u) == 5u && output_bitmasks[0] == ((1u << 0) | (0u << 1)))
+		return 1;
+	// w.r.t. y
+	if ((input_bitmasks[0] & 3u) == 3u && output_bitmasks[0] == ((0u << 0) | (1u << 1)))
+		return 1;
+	return 0;
+}
+
+REGISTER_COMMAND(CCV_NNC_CMUL_FORWARD)(ccv_nnc_cmd_registry_t* const registry)
+	FIND_BACKEND(ccv_nnc_cmul_cpu_ref.c)
+{
+	registry->bitmask = _ccv_nnc_cmul_forw_bitmask;
+	registry->tensor_auto = _ccv_nnc_broadcast_tensor_auto_forw;
+	registry->allow_inplace = _ccv_nnc_same_pos_inplace;
+}
+
+REGISTER_COMMAND(CCV_NNC_CMUL_BACKWARD)(ccv_nnc_cmd_registry_t* const registry)
+	FIND_BACKEND(ccv_nnc_cmul_cpu_ref.c)
+{
+	registry->flags = CCV_NNC_CMD_ATTR_NULL_IS_ONES;
+	registry->bitmask = _ccv_nnc_cmul_back_bitmask;
+	registry->tensor_auto = ccv_nnc_hint_tensor_auto_backward_from_inputs;
+}
+
+//@REGISTER_EASY_COMMAND_MACRO(CCV_NNC_CMUL_FORWARD)
+#define CMD_CMUL_FORWARD() ccv_nnc_cmd(CCV_NNC_CMUL_FORWARD, 0, (ccv_nnc_cmd_param_t){.size={.dim={1,1,1}}}, 0)
+//@REGISTER_EASY_COMMAND_MACRO(CCV_NNC_CMUL_BACKWARD)
+#define CMD_CMUL_BACKWARD() ccv_nnc_cmd(CCV_NNC_CMUL_BACKWARD, 0, (ccv_nnc_cmd_param_t){.size={.dim={1,1,1}}}, 0)
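For context on the two registrations above: the forward bitmask requires inputs 0 and 1 (the two operands) and output 0, while the backward bitmask accepts computing the gradient of either operand alone or of both, depending on which output bits are set. Below is a minimal sketch of driving the forward command through the generic executor with the CPU reference backend registered above; the tensor shape and setup are illustrative only, and the element-wise semantics of CMUL are left to ccv_nnc_cmul_cpu_ref.c rather than restated here.

	// Assumes the usual ccv_nnc.h / ccv_nnc_easy.h setup; sizes are illustrative.
	ccv_nnc_tensor_t* const x = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 8), 0);
	ccv_nnc_tensor_t* const y = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 8), 0);
	ccv_nnc_tensor_t* const z = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 8), 0);
	// ... fill x and y with data ...
	// Inputs 0 and 1 and output 0 are exactly what the forward bitmask demands.
	ccv_nnc_cmd_exec(CMD_CMUL_FORWARD(), ccv_nnc_no_hint, 0, TENSOR_LIST(x, y), TENSOR_LIST(z), 0);
	ccv_nnc_tensor_free(x);
	ccv_nnc_tensor_free(y);
	ccv_nnc_tensor_free(z);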
diff --git a/lib/nnc/cmd/convolution/ccv_nnc_conv_cpu_opt.c b/lib/nnc/cmd/convolution/ccv_nnc_conv_cpu_opt.c
index f972375ba..ca29511bd 100644
--- a/lib/nnc/cmd/convolution/ccv_nnc_conv_cpu_opt.c
+++ b/lib/nnc/cmd/convolution/ccv_nnc_conv_cpu_opt.c
@@ -36,6 +36,8 @@ static int _ccv_nnc_conv_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
 	assert(bdim[CCV_NNC_MAX_DIM] == cmd.info.convolution.count);
 	if (cmd.info.convolution.groups != 1)
 		return CCV_NNC_EXEC_INVALID;
+	if (cmd.info.convolution.dilation[0] > 1 || cmd.info.convolution.dilation[1] > 1)
+		return CCV_NNC_EXEC_INVALID;
 	int i;
 	// Make sure the weights dimension matches the network dimension
 	for (i = 1; i < CCV_NNC_MAX_DIM_ALLOC; i++)
diff --git a/lib/nnc/cmd/convolution/ccv_nnc_conv_cpu_ref.c b/lib/nnc/cmd/convolution/ccv_nnc_conv_cpu_ref.c
index 9a9d08630..4f4c5bbcb 100644
--- a/lib/nnc/cmd/convolution/ccv_nnc_conv_cpu_ref.c
+++ b/lib/nnc/cmd/convolution/ccv_nnc_conv_cpu_ref.c
@@ -37,11 +37,19 @@ static int _ccv_nnc_conv_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
 	ccv_nnc_tensor_view_get_stride(b, bstride);
 	assert(!bias || bias->info.dim[0] == cmd.info.convolution.count);
 	const int batch_size = (a_nd == CCV_NNC_MAX_DIM + 2) ? a->info.dim[0] : 1;
+	const int dilation[CCV_NNC_MAX_DIM] = {
+		ccv_max(cmd.info.convolution.dilation[0], 1),
+		ccv_max(cmd.info.convolution.dilation[1], 1)
+	};
 	if (a->info.format == CCV_TENSOR_FORMAT_NHWC)
 	{
 		// Make sure the weights dimension matches the network dimension
 		assert(w->info.dim[1] == cmd.info.size.dim[0]);
 		assert(w->info.dim[2] == cmd.info.size.dim[1]);
+		const int wdim[CCV_NNC_MAX_DIM] = {
+			(w->info.dim[1] - 1) * dilation[0] + 1,
+			(w->info.dim[2] - 1) * dilation[1] + 1
+		};
 		assert(w->info.dim[CCV_NNC_MAX_DIM + 1] * groups == adim[CCV_NNC_MAX_DIM]);
 		assert(b->info.format == CCV_TENSOR_FORMAT_NHWC);
 		const int channel_size = w->info.dim[CCV_NNC_MAX_DIM + 1];
@@ -59,25 +67,36 @@ static int _ccv_nnc_conv_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
 		// This block will be cause in each for-loop, therefore, you can use it to generate some temporary variables.
 		int i[CCV_NNC_MAX_DIM];
 		int n[CCV_NNC_MAX_DIM];
+		int d[CCV_NNC_MAX_DIM];
 		int m[CCV_NNC_MAX_DIM];
 		int j[CCV_NNC_MAX_DIM];
 		for (i[0] = 0; i[0] < bdim[0]; i[0]++)
 		{
-			SET_BORDER_OFFSET_SIZE_FOR(0, i, hint, w->info.dim + 1, adim, n, m);
+			SET_BORDER_OFFSET_SIZE_FOR(0, i, hint, wdim, adim, n, m);
+			m[0] = (m[0] + n[0] - 1) / dilation[0] + 1;
+			const int n0 = (n[0] + dilation[0] - 1) / dilation[0];
+			d[0] = n0 * dilation[0] - n[0];
+			n[0] = n0;
+			m[0] = m[0] - n[0];
 			float* wpu = wp + n[0] * w->info.dim[CCV_NNC_MAX_DIM] * channel_size;
 			for (i[1] = 0; i[1] < bdim[1]; i[1]++)
 			{
-				SET_BORDER_OFFSET_SIZE_FOR(1, i, hint, w->info.dim + 1, adim, n, m);
+				SET_BORDER_OFFSET_SIZE_FOR(1, i, hint, wdim, adim, n, m);
+				m[1] = (m[1] + n[1] - 1) / dilation[1] + 1;
+				const int n1 = (n[1] + dilation[1] - 1) / dilation[1];
+				d[1] = n1 * dilation[1] - n[1];
+				n[1] = n1;
+				m[1] = m[1] - n[1];
 				float p = biasval;
 				float* wpz = wpu + n[1] * channel_size;
-				float* apz = ap + ccv_max(i[1] * hint.stride.dim[1] - hint.border.begin[1], 0) * astride[CCV_NNC_MAX_DIM] + gidx * channel_size;
+				float* apz = ap + d[0] * astride[CCV_NNC_MAX_DIM - 1] + (ccv_max(i[1] * hint.stride.dim[1] - hint.border.begin[1], 0) + d[1]) * astride[CCV_NNC_MAX_DIM] + gidx * channel_size;
 				for (j[0] = 0; j[0] < m[0]; j[0]++)
 				{
 					for (j[1] = 0; j[1] < m[1]; j[1]++)
 						for (c = 0; c < channel_size; c++)
-							p += wpz[j[1] * channel_size + c] * apz[j[1] * astride[CCV_NNC_MAX_DIM] + c];
+							p += wpz[j[1] * channel_size + c] * apz[j[1] * dilation[1] * astride[CCV_NNC_MAX_DIM] + c];
 					wpz += w->info.dim[CCV_NNC_MAX_DIM] * channel_size;
-					apz += astride[CCV_NNC_MAX_DIM - 1];
+					apz += astride[CCV_NNC_MAX_DIM - 1] * dilation[0];
 				}
 				bp[i[1] * bstride[CCV_NNC_MAX_DIM]] = p;
 			}
@@ -89,6 +108,10 @@ static int _ccv_nnc_conv_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
 		// Make sure the weights dimension matches the network dimension
 		assert(w->info.dim[2] == cmd.info.size.dim[0]);
 		assert(w->info.dim[3] == cmd.info.size.dim[1]);
+		const int wdim[CCV_NNC_MAX_DIM] = {
+			(w->info.dim[2] - 1) * dilation[0] + 1,
+			(w->info.dim[3] - 1) * dilation[1] + 1
+		};
 		assert(w->info.dim[1] * groups == adim[0]);
 		assert(b->info.format == CCV_TENSOR_FORMAT_NCHW);
 		const int channel_size = w->info.dim[1];
@@ -107,25 +130,36 @@ static int _ccv_nnc_conv_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
 		// This block will be cause in each for-loop, therefore, you can use it to generate some temporary variables.
 		int i[CCV_NNC_MAX_DIM];
 		int n[CCV_NNC_MAX_DIM];
+		int d[CCV_NNC_MAX_DIM];
 		int m[CCV_NNC_MAX_DIM];
 		int j[CCV_NNC_MAX_DIM];
 		for (i[0] = 0; i[0] < bdim[1]; i[0]++)
 		{
-			SET_BORDER_OFFSET_SIZE_FOR(0, i, hint, w->info.dim + 2, adim + 1, n, m);
+			SET_BORDER_OFFSET_SIZE_FOR(0, i, hint, wdim, adim + 1, n, m);
+			m[0] = (m[0] + n[0] - 1) / dilation[0] + 1;
+			const int n0 = (n[0] + dilation[0] - 1) / dilation[0];
+			d[0] = n0 * dilation[0] - n[0];
+			n[0] = n0;
+			m[0] = m[0] - n[0];
 			float* wpu = wp + n[0] * w->info.dim[CCV_NNC_MAX_DIM + 1];
 			for (i[1] = 0; i[1] < bdim[2]; i[1]++)
 			{
-				SET_BORDER_OFFSET_SIZE_FOR(1, i, hint, w->info.dim + 2, adim + 1, n, m);
+				SET_BORDER_OFFSET_SIZE_FOR(1, i, hint, wdim, adim + 1, n, m);
+				m[1] = (m[1] + n[1] - 1) / dilation[1] + 1;
+				const int n1 = (n[1] + dilation[1] - 1) / dilation[1];
+				d[1] = n1 * dilation[1] - n[1];
+				n[1] = n1;
+				m[1] = m[1] - n[1];
 				float p = biasval;
 				float* wpz = wpu + n[1];
-				float* apz = ap + ccv_max(i[1] * hint.stride.dim[1] - hint.border.begin[1], 0) * astride[CCV_NNC_MAX_DIM + 1] + gidx * channel_size * astride[1];
+				float* apz = ap + d[0] * astride[CCV_NNC_MAX_DIM] + (ccv_max(i[1] * hint.stride.dim[1] - hint.border.begin[1], 0) + d[1]) * astride[CCV_NNC_MAX_DIM + 1] + gidx * channel_size * astride[1];
 				for (j[0] = 0; j[0] < m[0]; j[0]++)
 				{
 					for (j[1] = 0; j[1] < m[1]; j[1]++)
 						for (c = 0; c < channel_size; c++)
-							p += wpz[j[1] + c * hw] * apz[j[1] * astride[CCV_NNC_MAX_DIM + 1] + c * astride[1]];
+							p += wpz[j[1] + c * hw] * apz[j[1] * dilation[1] * astride[CCV_NNC_MAX_DIM + 1] + c * astride[1]];
 					wpz += w->info.dim[CCV_NNC_MAX_DIM + 1];
-					apz += astride[CCV_NNC_MAX_DIM];
+					apz += astride[CCV_NNC_MAX_DIM] * dilation[0];
 				}
 				bp[i[1]] = p;
 			}
@@ -173,6 +207,14 @@ static int _ccv_nnc_conv_back(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
 	const int group_size = cmd.info.convolution.count / groups;
 	const int channel_size = w ? w->info.dim[CCV_NNC_MAX_DIM + 1] : inputs[2]->info.dim[CCV_NNC_MAX_DIM + 1];
 	const int batch_size = (a_nd == CCV_NNC_MAX_DIM + 2) ? a->info.dim[0] : 1;
+	const int dilation[CCV_NNC_MAX_DIM] = {
+		ccv_max(cmd.info.convolution.dilation[0], 1),
+		ccv_max(cmd.info.convolution.dilation[1], 1)
+	};
+	const int wdim[CCV_NNC_MAX_DIM] = {
+		(w->info.dim[1] - 1) * dilation[0] + 1,
+		(w->info.dim[2] - 1) * dilation[1] + 1
+	};
 	if (w)
 	{
 		parallel_for(k, cmd.info.convolution.count) {
@@ -183,6 +225,7 @@ static int _ccv_nnc_conv_back(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
 			float biasval = 0;
 			int i[CCV_NNC_MAX_DIM];
 			int n[CCV_NNC_MAX_DIM];
+			int d[CCV_NNC_MAX_DIM];
 			int m[CCV_NNC_MAX_DIM];
 			int j[CCV_NNC_MAX_DIM];
 			int bidx;
@@ -192,24 +235,34 @@ static int _ccv_nnc_conv_back(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
 				const float* gp = g->data.f32 + bidx * gstride[0] + k;
 				for (i[0] = 0; i[0] < gdim[0]; i[0]++)
 				{
-					SET_BORDER_OFFSET_SIZE_FOR(0, i, hint, w->info.dim + 1, adim, n, m);
+					SET_BORDER_OFFSET_SIZE_FOR(0, i, hint, wdim, adim, n, m);
+					m[0] = (m[0] + n[0] - 1) / dilation[0] + 1;
+					const int n0 = (n[0] + dilation[0] - 1) / dilation[0];
+					d[0] = n0 * dilation[0] - n[0];
+					n[0] = n0;
+					m[0] = m[0] - n[0];
 					float* wpu = wp + n[0] * w->info.dim[CCV_NNC_MAX_DIM] * channel_size;
 					for (i[1] = 0; i[1] < gdim[1]; i[1]++)
 					{
-						SET_BORDER_OFFSET_SIZE_FOR(1, i, hint, w->info.dim + 1, adim, n, m);
+						SET_BORDER_OFFSET_SIZE_FOR(1, i, hint, wdim, adim, n, m);
+						m[1] = (m[1] + n[1] - 1) / dilation[1] + 1;
+						const int n1 = (n[1] + dilation[1] - 1) / dilation[1];
+						d[1] = n1 * dilation[1] - n[1];
+						n[1] = n1;
+						m[1] = m[1] - n[1];
 						const float v = gp[i[1] * gstride[CCV_NNC_MAX_DIM]];
 						if (v == 0) // shortcut if v is zero
 							continue;
 						biasval += v;
 						float* wpz = wpu + n[1] * channel_size;
-						const float* apz = ap + ccv_max(i[1] * hint.stride.dim[1] - hint.border.begin[1], 0) * astride[CCV_NNC_MAX_DIM] + gidx * channel_size;
+						const float* apz = ap + d[0] * astride[CCV_NNC_MAX_DIM - 1] + (ccv_max(i[1] * hint.stride.dim[1] - hint.border.begin[1], 0) + d[1]) * astride[CCV_NNC_MAX_DIM] + gidx * channel_size;
 						for (j[0] = 0; j[0] < m[0]; j[0]++)
 						{
 							for (j[1] = 0; j[1] < m[1]; j[1]++)
 								for (c = 0; c < channel_size; c++)
-									wpz[j[1] * channel_size + c] += v * apz[j[1] * astride[CCV_NNC_MAX_DIM] + c];
+									wpz[j[1] * channel_size + c] += v * apz[j[1] * dilation[1] * astride[CCV_NNC_MAX_DIM] + c];
 							wpz += w->info.dim[CCV_NNC_MAX_DIM] * channel_size;
-							apz += astride[CCV_NNC_MAX_DIM - 1];
+							apz += astride[CCV_NNC_MAX_DIM - 1] * dilation[0];
 						}
 					}
 					gp += gstride[CCV_NNC_MAX_DIM - 1];
@@ -248,27 +301,38 @@ static int _ccv_nnc_conv_back(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
 			// This block will be cause in each for-loop, therefore, you can use it to generate some temporary variables.
 			int i[CCV_NNC_MAX_DIM];
 			int n[CCV_NNC_MAX_DIM];
+			int d[CCV_NNC_MAX_DIM];
 			int m[CCV_NNC_MAX_DIM];
 			int j[CCV_NNC_MAX_DIM];
 			for (i[0] = 0; i[0] < gdim[0]; i[0]++)
 			{
-				SET_BORDER_OFFSET_SIZE_FOR(0, i, hint, w->info.dim + 1, hdim, n, m);
+				SET_BORDER_OFFSET_SIZE_FOR(0, i, hint, wdim, hdim, n, m);
+				m[0] = (m[0] + n[0] - 1) / dilation[0] + 1;
+				const int n0 = (n[0] + dilation[0] - 1) / dilation[0];
+				d[0] = n0 * dilation[0] - n[0];
+				n[0] = n0;
+				m[0] = m[0] - n[0];
 				const float* wpu = wp + n[0] * w->info.dim[CCV_NNC_MAX_DIM] * channel_size;
 				for (i[1] = 0; i[1] < gdim[1]; i[1]++)
 				{
-					SET_BORDER_OFFSET_SIZE_FOR(1, i, hint, w->info.dim + 1, hdim, n, m);
+					SET_BORDER_OFFSET_SIZE_FOR(1, i, hint, wdim, hdim, n, m);
+					m[1] = (m[1] + n[1] - 1) / dilation[1] + 1;
+					const int n1 = (n[1] + dilation[1] - 1) / dilation[1];
+					d[1] = n1 * dilation[1] - n[1];
+					n[1] = n1;
+					m[1] = m[1] - n[1];
 					const float v = gp[i[1] * gstride[CCV_NNC_MAX_DIM]];
 					if (v == 0) // shortcut if v is zero
 						continue;
 					const float* wpz = wpu + n[1] * channel_size;
-					float* hpz = hp + ccv_max(i[1] * hint.stride.dim[1] - hint.border.begin[1], 0) * hstride[CCV_NNC_MAX_DIM] + gidx * channel_size;
+					float* hpz = hp + d[0] * hstride[CCV_NNC_MAX_DIM - 1] + (ccv_max(i[1] * hint.stride.dim[1] - hint.border.begin[1], 0) + d[1]) * hstride[CCV_NNC_MAX_DIM] + gidx * channel_size;
 					for (j[0] = 0; j[0] < m[0]; j[0]++)
 					{
 						for (j[1] = 0; j[1] < m[1]; j[1]++)
 							for (c = 0; c < channel_size; c++)
-								hpz[j[1] * hstride[CCV_NNC_MAX_DIM] + c] += v * wpz[j[1] * channel_size + c];
+								hpz[j[1] * dilation[1] * hstride[CCV_NNC_MAX_DIM] + c] += v * wpz[j[1] * channel_size + c];
 							wpz += w->info.dim[CCV_NNC_MAX_DIM] * channel_size;
-							hpz += hstride[CCV_NNC_MAX_DIM - 1];
+							hpz += hstride[CCV_NNC_MAX_DIM - 1] * dilation[0];
 					}
 				}
 				gp += gstride[CCV_NNC_MAX_DIM - 1];
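The remapping that now follows each SET_BORDER_OFFSET_SIZE_FOR call above is easy to misread, so here is the same arithmetic restated for one axis as a standalone sketch; the helper name and the worked numbers are illustrative only and not part of the patch. SET_BORDER_OFFSET_SIZE_FOR is invoked with the dilated kernel extent wdim = (k - 1) * dilation + 1, so n counts the input positions of that extent clipped at the border and m counts the visible ones; the added lines then convert these into a first kernel tap index, a tap count, and an extra input offset d.

	/* A hypothetical restatement of the per-axis remapping used in the patch. */
	static void remap_dilated_window(int dilation, int* n, int* m, int* d)
	{
		*m = (*m + *n - 1) / dilation + 1;             // index of the last visible tap, plus one
		const int n0 = (*n + dilation - 1) / dilation; // first kernel tap that falls inside the input
		*d = n0 * dilation - *n;                       // rows from the first visible input row to that tap
		*n = n0;
		*m = *m - *n;                                  // number of taps actually accumulated
	}
	// Example: kernel 3, dilation 2 (taps at offsets 0, 2, 4), n = 1 clipped row, m = 4 visible rows
	// gives m = 2 taps starting at tap index n = 1, with input offset d = 1.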
diff --git a/lib/nnc/cmd/convolution/ccv_nnc_convolution.c b/lib/nnc/cmd/convolution/ccv_nnc_convolution.c
index 8e3408350..d20fa7e12 100644
--- a/lib/nnc/cmd/convolution/ccv_nnc_convolution.c
+++ b/lib/nnc/cmd/convolution/ccv_nnc_convolution.c
@@ -47,7 +47,11 @@ static void _ccv_nnc_conv_tensor_auto_forw(const ccv_nnc_cmd_param_t cmd, const
 	assert(count == cmd.convolution.count);
 	ccv_nnc_tensor_set_c(outputs, ccv_nnc_tensor_nd(inputs[0].dim), count);
 	ccv_nnc_tensor_set_n(outputs, ccv_nnc_tensor_get_n(inputs[0]));
-	ccv_nnc_hint_tensor_forward(cmd, inputs[0], hint, outputs);
+	ccv_nnc_cmd_param_t modified_cmd = cmd;
+	int i = 0;
+	for (i = 0; i < CCV_NNC_MAX_DIM; i++)
+		modified_cmd.size.dim[i] = (modified_cmd.size.dim[i] - 1) * ccv_max(cmd.convolution.dilation[i], 1) + 1;
+	ccv_nnc_hint_tensor_forward(modified_cmd, inputs[0], hint, outputs);
 }
 
 REGISTER_COMMAND(CCV_NNC_CONVOLUTION_FORWARD)(ccv_nnc_cmd_registry_t* const registry)
diff --git a/lib/nnc/cmd/convolution/gpu/ccv_nnc_conv_gpu_cudnn.cu b/lib/nnc/cmd/convolution/gpu/ccv_nnc_conv_gpu_cudnn.cu
index b241319c2..f2d6533e1 100644
--- a/lib/nnc/cmd/convolution/gpu/ccv_nnc_conv_gpu_cudnn.cu
+++ b/lib/nnc/cmd/convolution/gpu/ccv_nnc_conv_gpu_cudnn.cu
@@ -29,7 +29,7 @@ static int _ccv_nnc_conv_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
 	const ccv_nnc_cudnn_tensor_view_descriptor_t a = ccv_nnc_cudnn_get_tensor_view_descriptor(stream_context, (const ccv_nnc_tensor_view_t*)inputs[0]);
 	const ccv_nnc_cudnn_filter_descriptor_t w = ccv_nnc_cudnn_get_filter_descriptor(stream_context, (const ccv_nnc_tensor_t*)inputs[1]);
 	const ccv_nnc_cudnn_tensor_view_descriptor_t b = ccv_nnc_cudnn_get_tensor_view_descriptor(stream_context, (const ccv_nnc_tensor_view_t*)outputs[0]);
-	const ccv_nnc_cudnn_convolution_descriptor_t conv = ccv_nnc_cudnn_get_convolution_descriptor(stream_context, hint, inputs[1]->info.datatype);
+	const ccv_nnc_cudnn_convolution_descriptor_t conv = ccv_nnc_cudnn_get_convolution_descriptor(stream_context, cmd.info, hint, inputs[1]->info.datatype);
 	cudnnSetConvolutionGroupCount(conv.descriptor, cmd.info.convolution.groups);
 
 	cudnnConvolutionFwdAlgo_t algo;
@@ -124,7 +124,7 @@ static int _ccv_nnc_conv_forw_autotune(const ccv_nnc_cmd_t cmd, size_t max_works
 	const ccv_nnc_cudnn_tensor_view_descriptor_t a = ccv_nnc_cudnn_get_tensor_view_descriptor(stream_context, (const ccv_nnc_tensor_view_t*)inputs[0]);
 	const ccv_nnc_cudnn_filter_descriptor_t w = ccv_nnc_cudnn_get_filter_descriptor(stream_context, (const ccv_nnc_tensor_t*)inputs[1]);
 	const ccv_nnc_cudnn_tensor_view_descriptor_t b = ccv_nnc_cudnn_get_tensor_view_descriptor(stream_context, (const ccv_nnc_tensor_view_t*)outputs[0]);
-	const ccv_nnc_cudnn_convolution_descriptor_t conv = ccv_nnc_cudnn_get_convolution_descriptor(stream_context, hint, inputs[1]->info.datatype);
+	const ccv_nnc_cudnn_convolution_descriptor_t conv = ccv_nnc_cudnn_get_convolution_descriptor(stream_context, cmd.info, hint, inputs[1]->info.datatype);
 	cudnnSetConvolutionGroupCount(conv.descriptor, cmd.info.convolution.groups);
 	int count = 0;
 	cudnnConvolutionFwdAlgoPerf_t perfs[CCV_NNC_CMD_CUDNN_CONV_FWD_ALGO_COUNT];
@@ -210,7 +210,7 @@ static int _ccv_nnc_conv_back(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
 	const ccv_nnc_cudnn_tensor_view_descriptor_t g = ccv_nnc_cudnn_get_tensor_view_descriptor(stream_context, (const ccv_nnc_tensor_view_t*)inputs[0]);
 	const int is_w_nhwc = (output_size > 1 && outputs[1]) ? outputs[1]->info.format == CCV_TENSOR_FORMAT_NHWC : inputs[2]->info.format == CCV_TENSOR_FORMAT_NHWC;
 	const int w_datatype = (output_size > 1 && outputs[1]) ? outputs[1]->info.datatype : inputs[2]->info.datatype;
-	const ccv_nnc_cudnn_convolution_descriptor_t conv = ccv_nnc_cudnn_get_convolution_descriptor(stream_context, hint, (is_w_nhwc && w_datatype == CCV_16F) ? CCV_32F : w_datatype);
+	const ccv_nnc_cudnn_convolution_descriptor_t conv = ccv_nnc_cudnn_get_convolution_descriptor(stream_context, cmd.info, hint, (is_w_nhwc && w_datatype == CCV_16F) ? CCV_32F : w_datatype);
 	cudnnSetConvolutionGroupCount(conv.descriptor, cmd.info.convolution.groups);
 
 	static const float one = 1, zero = 0;
@@ -370,7 +370,7 @@ static int _ccv_nnc_conv_back_autotune(const ccv_nnc_cmd_t cmd, size_t max_works
 	int count = 0;
 	const int is_w_nhwc = (output_size > 1 && outputs[1]) ? outputs[1]->info.format == CCV_TENSOR_FORMAT_NHWC : inputs[2]->info.format == CCV_TENSOR_FORMAT_NHWC;
 	const int w_datatype = (output_size > 1 && outputs[1]) ? outputs[1]->info.datatype : inputs[2]->info.datatype;
-	const ccv_nnc_cudnn_convolution_descriptor_t conv = ccv_nnc_cudnn_get_convolution_descriptor(stream_context, hint, (is_w_nhwc && w_datatype == CCV_16F) ? CCV_32F : w_datatype);
+	const ccv_nnc_cudnn_convolution_descriptor_t conv = ccv_nnc_cudnn_get_convolution_descriptor(stream_context, cmd.info, hint, (is_w_nhwc && w_datatype == CCV_16F) ? CCV_32F : w_datatype);
 	cudnnSetConvolutionGroupCount(conv.descriptor, cmd.info.convolution.groups);
 	cudnnConvolutionBwdFilterAlgo_t filter_algorithm = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT;
 	if (output_size > 1 && outputs[1])
diff --git a/lib/nnc/gpu/ccv_nnc_compat.cu b/lib/nnc/gpu/ccv_nnc_compat.cu
index 577eb4cee..03550be95 100644
--- a/lib/nnc/gpu/ccv_nnc_compat.cu
+++ b/lib/nnc/gpu/ccv_nnc_compat.cu
@@ -1338,7 +1338,7 @@ void ccv_nnc_cudnn_deinit_filter_descriptor(const ccv_nnc_cudnn_filter_descripto
 	ccv_nnc_stream_context_return_filter_descriptor(filter_desc.stream_context, filter_desc.descriptor);
 }
 
-ccv_nnc_cudnn_convolution_descriptor_t ccv_nnc_cudnn_get_convolution_descriptor(const ccv_nnc_stream_context_t* const stream_context, const ccv_nnc_hint_t hint, const int datatype)
+ccv_nnc_cudnn_convolution_descriptor_t ccv_nnc_cudnn_get_convolution_descriptor(const ccv_nnc_stream_context_t* const stream_context, const ccv_nnc_cmd_param_t cmd, const ccv_nnc_hint_t hint, const int datatype)
 {
 	ccv_nnc_cudnn_convolution_descriptor_t convolution_desc = {
 		stream_context,
@@ -1351,11 +1351,13 @@ ccv_nnc_cudnn_convolution_descriptor_t ccv_nnc_cudnn_get_convolution_descriptor(
 	int v[CCV_NNC_MAX_DIM];
 	for (i = 0; i < CCV_NNC_MAX_DIM; i++)
 		v[i] = hint.stride.dim[i];
+	int u[CCV_NNC_MAX_DIM];
+	for (i = 0; i < CCV_NNC_MAX_DIM; i++)
+		u[i] = ccv_max(cmd.convolution.dilation[i], 1);
 	if (CCV_NNC_MAX_DIM == 2)
 	{
-		CUDNN_ENFORCE(cudnnSetConvolution2dDescriptor(convolution_desc.descriptor, p[0], p[1], v[0], v[1], 1, 1, CUDNN_CROSS_CORRELATION, ccv_nnc_cudnn_datatype(datatype)));
+		CUDNN_ENFORCE(cudnnSetConvolution2dDescriptor(convolution_desc.descriptor, p[0], p[1], v[0], v[1], u[0], u[1], CUDNN_CROSS_CORRELATION, ccv_nnc_cudnn_datatype(datatype)));
 	} else {
-		int u[CCV_NNC_MAX_DIM];
 		for (i = 0; i < CCV_NNC_MAX_DIM; i++)
 			u[i] = 1;
 		CUDNN_ENFORCE(cudnnSetConvolutionNdDescriptor(convolution_desc.descriptor, CCV_NNC_MAX_DIM, p, v, u, CUDNN_CROSS_CORRELATION, ccv_nnc_cudnn_datatype(datatype)));
diff --git a/lib/nnc/gpu/ccv_nnc_compat.h b/lib/nnc/gpu/ccv_nnc_compat.h
index 8203e8625..0f96a93ab 100644
--- a/lib/nnc/gpu/ccv_nnc_compat.h
+++ b/lib/nnc/gpu/ccv_nnc_compat.h
@@ -191,7 +191,7 @@ typedef struct {
 	const ccv_nnc_stream_context_t* stream_context;
 	cudnnConvolutionDescriptor_t descriptor;
 } ccv_nnc_cudnn_convolution_descriptor_t;
-ccv_nnc_cudnn_convolution_descriptor_t ccv_nnc_cudnn_get_convolution_descriptor(const ccv_nnc_stream_context_t* const stream_context, const ccv_nnc_hint_t hint, const int datatype);
+ccv_nnc_cudnn_convolution_descriptor_t ccv_nnc_cudnn_get_convolution_descriptor(const ccv_nnc_stream_context_t* const stream_context, const ccv_nnc_cmd_param_t cmd, const ccv_nnc_hint_t hint, const int datatype);
 void ccv_nnc_cudnn_deinit_convolution_descriptor(const ccv_nnc_cudnn_convolution_descriptor_t convolution_desc);
 
 #endif
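Before the tests, it helps to spell out the shape rule the changes above rely on: with dilation r, a kernel of size k touches an input span of (k - 1) * r + 1 (for example, 3 taps with dilation 2 span 5 columns). That is why _ccv_nnc_conv_tensor_auto_forw substitutes this effective size before calling ccv_nnc_hint_tensor_forward, and why the tests below do the same before asking for an auto hint. A caller-side sketch of that pattern, with illustrative sizes; a and b stand for already-allocated input and output tensors, as in the tests:

	// Hypothetical sizes: 64 filters over 64 channels with a 3x3 kernel, dilation 2.
	ccv_nnc_cmd_t cmd = CMD_CONVOLUTION_FORWARD(1, 64, 3, 3, 64);
	cmd.info.convolution.dilation[0] = 2;
	cmd.info.convolution.dilation[1] = 2;
	// Compute the hint against the effective (dilated) kernel extent.
	ccv_nnc_cmd_param_t modified_cmd = cmd.info;
	modified_cmd.size.dim[0] = (cmd.info.size.dim[0] - 1) * cmd.info.convolution.dilation[0] + 1; // 3 -> 5
	modified_cmd.size.dim[1] = (cmd.info.size.dim[1] - 1) * cmd.info.convolution.dilation[1] + 1; // 3 -> 5
	ccv_nnc_hint_t hint = ccv_nnc_hint_auto(modified_cmd, a->info, b->info);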
diff --git a/test/int/nnc/cudnn.tests.c b/test/int/nnc/cudnn.tests.c
index f0c2efb0a..24fa55dfb 100644
--- a/test/int/nnc/cudnn.tests.c
+++ b/test/int/nnc/cudnn.tests.c
@@ -277,6 +277,74 @@ TEST_CASE("cudnn forward convolution in half precision with palettize weights")
 	ccv_nnc_tensor_free(ga);
 }
 
+TEST_CASE("cudnn forward convolution with dilation 2, 3")
+{
+	GUARD_ELSE_RETURN(ccv_nnc_cmd_ok(CCV_NNC_CONVOLUTION_FORWARD, CCV_NNC_BACKEND_GPU_CUDNN));
+	ccv_nnc_tensor_t* a = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, BATCH_SIZE, INPUT_SIZE, INPUT_SIZE, INPUT_DIM), 0);
+	ccv_nnc_tensor_t* b = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, BATCH_SIZE, OUTPUT_SIZE, OUTPUT_SIZE, OUTPUT_DIM), 0);
+	ccv_nnc_cmd_t cmd = CMD_CONVOLUTION_FORWARD(1, OUTPUT_DIM, KERNEL_SIZE, KERNEL_SIZE, INPUT_DIM);
+	cmd.info.convolution.dilation[0] = 2;
+	cmd.info.convolution.dilation[1] = 3;
+	cmd.backend = CCV_NNC_BACKEND_CPU_REF;
+	assert(cmd.backend >= 0);
+	ccv_nnc_cmd_param_t modified_cmd = cmd.info;
+	modified_cmd.size.dim[0] = (cmd.info.size.dim[0] - 1) * ccv_max(cmd.info.convolution.dilation[0], 1) + 1;
+	modified_cmd.size.dim[1] = (cmd.info.size.dim[1] - 1) * ccv_max(cmd.info.convolution.dilation[1], 1) + 1;
+	ccv_nnc_hint_t hint = ccv_nnc_hint_auto(modified_cmd, a->info, b->info);
+	assert(ccv_nnc_hint_verify(hint, modified_cmd, a->info, b->info) == 0);
+	ccv_nnc_tensor_t* w = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, OUTPUT_DIM, KERNEL_SIZE, KERNEL_SIZE, INPUT_DIM), 0);
+	ccv_nnc_tensor_t* bias = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, OUTPUT_DIM), 0);
+	// configure the inlets.
+	dsfmt_t dsfmt;
+	dsfmt_init_gen_rand(&dsfmt, 0);
+	int i;
+	for (i = 0; i < INPUT_DIM * KERNEL_SIZE * KERNEL_SIZE * OUTPUT_DIM; i++)
+		w->data.f32[i] = dsfmt_genrand_open_close(&dsfmt) / (INPUT_DIM * KERNEL_SIZE * KERNEL_SIZE);
+	for (i = 0; i < INPUT_SIZE * INPUT_SIZE * INPUT_DIM * ccv_max(1, BATCH_SIZE); i++)
+		a->data.f32[i] = dsfmt_genrand_open_close(&dsfmt);
+	for (i = 0; i < OUTPUT_DIM; i++)
+		bias->data.f32[i] = (float)i / OUTPUT_DIM;
+	// Copy generated matrix values over to GPU.
+	ccv_nnc_tensor_t* ga = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 32F, BATCH_SIZE, INPUT_SIZE, INPUT_SIZE, INPUT_DIM), 0);
+	ccv_nnc_tensor_t* gw = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 32F, OUTPUT_DIM, KERNEL_SIZE, KERNEL_SIZE, INPUT_DIM), 0);
+	ccv_nnc_tensor_t* gwo = ccv_nnc_tensor_new(0, GPU_TENSOR_NCHW(000, 32F, OUTPUT_DIM, INPUT_DIM, KERNEL_SIZE, KERNEL_SIZE), 0);
+	ccv_nnc_tensor_t* gbias = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 32F, OUTPUT_DIM), 0);
+	ccv_nnc_cmd_t move = CMD_DATA_TRANSFER_FORWARD();
+	move.backend = CCV_NNC_BACKEND_GPU_REF;
+	assert(move.backend >= 0);
+	ccv_nnc_cmd_exec(move, ccv_nnc_no_hint, 0, TENSOR_LIST(a, w, bias), TENSOR_LIST(ga, gw, gbias), 0);
+	ccv_nnc_cmd_exec(cmd, hint, 0, TENSOR_LIST(a, w, bias), TENSOR_LIST(b), 0);
+	ccv_nnc_tensor_t* gc = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 32F, BATCH_SIZE, OUTPUT_SIZE, OUTPUT_SIZE, OUTPUT_DIM), 0);
+
+	ccv_nnc_cmd_t transform = CMD_FORMAT_TRANSFORM_FORWARD();
+	transform.backend = CCV_NNC_BACKEND_GPU_CUDNN;
+	assert(transform.backend >= 0);
+	ccv_nnc_stream_context_t* stream_context = ccv_nnc_stream_context_new(CCV_STREAM_CONTEXT_GPU);
+	ccv_nnc_cmd_exec(transform, ccv_nnc_no_hint, 0, TENSOR_LIST(gw), TENSOR_LIST(gwo), stream_context);
+	ccv_nnc_stream_context_wait(stream_context);
+	ccv_nnc_tensor_free(gw);
+
+	cmd.backend = CCV_NNC_BACKEND_GPU_CUDNN;
+	assert(cmd.backend >= 0);
+	cmd.algorithm = -1;
+	cmd = ccv_nnc_cmd_autotune(cmd, 1 * 1024 * 1024 * 1024, hint, 0, TENSOR_LIST(ga, gwo, gbias), TENSOR_LIST(gc), stream_context);
+	assert(CCV_NNC_EXEC_SUCCESS == ccv_nnc_cmd_exec(cmd, hint, 0, TENSOR_LIST(ga, gwo, gbias), TENSOR_LIST(gc), stream_context));
+	ccv_nnc_stream_context_wait(stream_context);
+	ccv_nnc_stream_context_free(stream_context);
+	ccv_nnc_tensor_t* c = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, BATCH_SIZE, OUTPUT_SIZE, OUTPUT_SIZE, OUTPUT_DIM), 0);
+	ccv_nnc_cmd_exec(move, ccv_nnc_no_hint, 0, TENSOR_LIST(gc), TENSOR_LIST(c), 0);
+	REQUIRE_ARRAY_EQ_WITH_TOLERANCE(float, b->data.f32, c->data.f32, BATCH_SIZE * OUTPUT_DIM * OUTPUT_SIZE * OUTPUT_SIZE, 1e-4, "output from cudnn should match from CPU");
+	ccv_nnc_tensor_free(c);
+	ccv_nnc_tensor_free(gc);
+	ccv_nnc_tensor_free(bias);
+	ccv_nnc_tensor_free(w);
+	ccv_nnc_tensor_free(b);
+	ccv_nnc_tensor_free(a);
+	ccv_nnc_tensor_free(gbias);
+	ccv_nnc_tensor_free(gwo);
+	ccv_nnc_tensor_free(ga);
+}
+
 TEST_CASE("cudnn backward convolution")
 {
 	GUARD_ELSE_RETURN(ccv_nnc_cmd_ok(CCV_NNC_CONVOLUTION_BACKWARD, CCV_NNC_BACKEND_GPU_CUDNN));
@@ -508,6 +576,80 @@ TEST_CASE("cudnn backward convolution in half precision")
 	ccv_nnc_tensor_free(cdbias16);
 }
 
+TEST_CASE("cudnn backward convolution with dilation 2, 3")
+{
+	GUARD_ELSE_RETURN(ccv_nnc_cmd_ok(CCV_NNC_CONVOLUTION_BACKWARD, CCV_NNC_BACKEND_GPU_CUDNN));
+	ccv_nnc_tensor_t* a = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, BATCH_SIZE, INPUT_SIZE, INPUT_SIZE, INPUT_DIM), 0);
+	ccv_nnc_tensor_t* h = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, BATCH_SIZE, INPUT_SIZE, INPUT_SIZE, INPUT_DIM), 0);
+	ccv_nnc_tensor_t* g = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, BATCH_SIZE, OUTPUT_SIZE, OUTPUT_SIZE, OUTPUT_DIM), 0);
+	ccv_nnc_cmd_t cmd = CMD_CONVOLUTION_BACKWARD(1, OUTPUT_DIM, KERNEL_SIZE, KERNEL_SIZE, INPUT_DIM);
+	cmd.info.convolution.dilation[0] = 2;
+	cmd.info.convolution.dilation[1] = 3;
+	cmd.backend = CCV_NNC_BACKEND_CPU_REF;
+	assert(cmd.backend >= 0);
+	ccv_nnc_cmd_param_t modified_cmd = cmd.info;
+	modified_cmd.size.dim[0] = (modified_cmd.size.dim[0] - 1) * cmd.info.convolution.dilation[0] + 1;
+	modified_cmd.size.dim[1] = (modified_cmd.size.dim[1] - 1) * cmd.info.convolution.dilation[1] + 1;
+	ccv_nnc_hint_t hint = ccv_nnc_hint_auto(modified_cmd, a->info, g->info);
+	assert(ccv_nnc_hint_verify(hint, modified_cmd, a->info, g->info) == 0);
+	ccv_nnc_tensor_t* w = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, OUTPUT_DIM, KERNEL_SIZE, KERNEL_SIZE, INPUT_DIM), 0);
+	ccv_nnc_tensor_t* dw = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, OUTPUT_DIM, KERNEL_SIZE, KERNEL_SIZE, INPUT_DIM), 0);
+	ccv_nnc_tensor_t* dbias = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, OUTPUT_DIM), 0);
+	// configure the inlets.
+	dsfmt_t dsfmt;
+	dsfmt_init_gen_rand(&dsfmt, 0);
+	int i;
+	for (i = 0; i < INPUT_DIM * KERNEL_SIZE * KERNEL_SIZE * OUTPUT_DIM; i++)
+		w->data.f32[i] = dsfmt_genrand_open_close(&dsfmt) / (INPUT_DIM * KERNEL_SIZE * KERNEL_SIZE);
+	for (i = 0; i < INPUT_SIZE * INPUT_SIZE * INPUT_DIM * ccv_max(1, BATCH_SIZE); i++)
+		a->data.f32[i] = dsfmt_genrand_open_close(&dsfmt);
+	for (i = 0; i < OUTPUT_SIZE * OUTPUT_SIZE * OUTPUT_DIM * ccv_max(1, BATCH_SIZE); i++)
+		g->data.f32[i] = dsfmt_genrand_open_close(&dsfmt) / OUTPUT_DIM; // (OUTPUT_SIZE * OUTPUT_SIZE * OUTPUT_DIM);
+	// Copy generated matrix values over to GPU.
+	ccv_nnc_tensor_t* ga = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 32F, BATCH_SIZE, INPUT_SIZE, INPUT_SIZE, INPUT_DIM), 0);
+	ccv_nnc_tensor_t* gg = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 32F, BATCH_SIZE, OUTPUT_SIZE, OUTPUT_SIZE, OUTPUT_DIM), 0);
+	ccv_nnc_tensor_t* gh = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 32F, BATCH_SIZE, INPUT_SIZE, INPUT_SIZE, INPUT_DIM), 0);
+	ccv_nnc_tensor_t* gw = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 32F, OUTPUT_DIM, KERNEL_SIZE, KERNEL_SIZE, INPUT_DIM), 0);
+	ccv_nnc_tensor_t* gbias = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 32F, OUTPUT_DIM), 0);
+	ccv_nnc_tensor_t* gdw = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 32F, OUTPUT_DIM, KERNEL_SIZE, KERNEL_SIZE, INPUT_DIM), 0);
+	ccv_nnc_tensor_t* gdbias = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 32F, OUTPUT_DIM), 0);
+	ccv_nnc_cmd_t move = CMD_DATA_TRANSFER_FORWARD();
+	move.backend = CCV_NNC_BACKEND_GPU_REF;
+	assert(move.backend >= 0);
+	ccv_nnc_cmd_exec(move, ccv_nnc_no_hint, 0, TENSOR_LIST(a, w, g), TENSOR_LIST(ga, gw, gg), 0);
+	ccv_nnc_cmd_exec(cmd, hint, 0, TENSOR_LIST(g, a, w), TENSOR_LIST(h, dw, dbias), 0);
+
+	cmd.backend = CCV_NNC_BACKEND_GPU_CUDNN;
+	assert(cmd.backend >= 0);
+	cmd.algorithm = -1;
+	ccv_nnc_stream_context_t* stream_context = ccv_nnc_stream_context_new(CCV_STREAM_CONTEXT_GPU);
+	cmd = ccv_nnc_cmd_autotune(cmd, 1 * 1024 * 1024 * 1024, hint, 0, TENSOR_LIST(gg, ga, gw), TENSOR_LIST(gh, gdw, gdbias), stream_context);
+	assert(CCV_NNC_EXEC_SUCCESS == ccv_nnc_cmd_exec(cmd, hint, 0, TENSOR_LIST(gg, ga, gw), TENSOR_LIST(gh, gdw, gdbias), stream_context));
+	ccv_nnc_stream_context_wait(stream_context);
+	ccv_nnc_stream_context_free(stream_context);
+	ccv_nnc_tensor_t* ch = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, BATCH_SIZE, INPUT_SIZE, INPUT_SIZE, INPUT_DIM), 0);
+	ccv_nnc_tensor_t* cdw = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, OUTPUT_DIM, KERNEL_SIZE, KERNEL_SIZE, INPUT_DIM), 0);
+	ccv_nnc_tensor_t* cdbias = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, OUTPUT_DIM), 0);
+	ccv_nnc_cmd_exec(move, ccv_nnc_no_hint, 0, TENSOR_LIST(gh, gdw, gdbias), TENSOR_LIST(ch, cdw, cdbias), 0);
+	REQUIRE_ARRAY_EQ_WITH_TOLERANCE(float, dw->data.f32, cdw->data.f32, INPUT_DIM * OUTPUT_DIM * KERNEL_SIZE * KERNEL_SIZE, 5e-1, "output from cudnn should match from CPU");
+	REQUIRE_ARRAY_EQ_WITH_TOLERANCE(float, dbias->data.f32, cdbias->data.f32, OUTPUT_DIM, 5e-1, "output from cudnn should match from CPU");
+	REQUIRE_ARRAY_EQ_WITH_TOLERANCE(float, h->data.f32, ch->data.f32, BATCH_SIZE * INPUT_DIM * INPUT_SIZE * INPUT_SIZE, 1e-4, "output from cudnn should match from CPU");
+	ccv_nnc_tensor_free(h);
+	ccv_nnc_tensor_free(gh);
+	ccv_nnc_tensor_free(w);
+	ccv_nnc_tensor_free(g);
+	ccv_nnc_tensor_free(a);
+	ccv_nnc_tensor_free(gbias);
+	ccv_nnc_tensor_free(gdbias);
+	ccv_nnc_tensor_free(gdw);
+	ccv_nnc_tensor_free(gw);
+	ccv_nnc_tensor_free(gg);
+	ccv_nnc_tensor_free(ga);
+	ccv_nnc_tensor_free(ch);
+	ccv_nnc_tensor_free(cdw);
+	ccv_nnc_tensor_free(cdbias);
+}
+
 TEST_CASE("compare batch norm with cudnn")
 {
 	GUARD_ELSE_RETURN(ccv_nnc_cmd_ok(CCV_NNC_BATCH_NORM_FORWARD, CCV_NNC_BACKEND_GPU_CUDNN) &&