Skip to content

Commit

Permalink
Support dilation in convolution op for MPS.
Browse files Browse the repository at this point in the history
  • Loading branch information
liuliu committed Dec 20, 2023
1 parent 7d7e335 commit bb296cd
Show file tree
Hide file tree
Showing 2 changed files with 165 additions and 3 deletions.
10 changes: 7 additions & 3 deletions lib/nnc/cmd/convolution/mps/ccv_nnc_conv_mps.m
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,8 @@ static int _ccv_nnc_conv_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
int* biasdim_r = biasdim;
int* biasstride_r = biasstride;
int indices[3];
const int dilationX = ccv_max(cmd.info.convolution.dilation[1], 1);
const int dilationY = ccv_max(cmd.info.convolution.dilation[0], 1);
MPSGraphExecutable* executable = ccv_nnc_mps_graph_executable_cache(key, indices, ^void (MPSGraph* graph, NSMutableArray<MPSGraphTensor*>* inputTensors, NSMutableArray<MPSGraphShapedType*>* inputShapedTypes, NSMutableArray<MPSGraphTensor*>* resultTensors) {
MPSGraphTensor* mps_input_a;
MPSGraphTensor* mps_a = ccv_nnc_mps_graph_tensor_input(graph, a, adim_r, astride_r, &mps_input_a);
Expand All @@ -376,7 +378,7 @@ static int _ccv_nnc_conv_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
[inputTensors addObject:mps_input_w];
MPSGraphShapedType* mps_w_shape = ccv_nnc_mps_graph_tensor_input_shape(w, wdim_r, wstride_r);
[inputShapedTypes addObject:mps_w_shape];
MPSGraphConvolution2DOpDescriptor* descriptor = [MPSGraphConvolution2DOpDescriptor descriptorWithStrideInX:hint.stride.dim[1] strideInY:hint.stride.dim[0] dilationRateInX:1 dilationRateInY:1 groups:cmd.info.convolution.groups paddingLeft:hint.border.begin[1] paddingRight:hint.border.end[1] paddingTop:hint.border.begin[0] paddingBottom:hint.border.end[0] paddingStyle:MPSGraphPaddingStyleExplicit dataLayout:ccv_nnc_mps_tensor_data_layout(a->info.format) weightsLayout:MPSGraphTensorNamedDataLayoutOIHW];
MPSGraphConvolution2DOpDescriptor* descriptor = [MPSGraphConvolution2DOpDescriptor descriptorWithStrideInX:hint.stride.dim[1] strideInY:hint.stride.dim[0] dilationRateInX:dilationX dilationRateInY:dilationY groups:cmd.info.convolution.groups paddingLeft:hint.border.begin[1] paddingRight:hint.border.end[1] paddingTop:hint.border.begin[0] paddingBottom:hint.border.end[0] paddingStyle:MPSGraphPaddingStyleExplicit dataLayout:ccv_nnc_mps_tensor_data_layout(a->info.format) weightsLayout:MPSGraphTensorNamedDataLayoutOIHW];
MPSGraphTensor* mps_b = [graph convolution2DWithSourceTensor:mps_a weightsTensor:mps_w descriptor:descriptor name:nil];
if (bias)
{
Expand Down Expand Up @@ -419,6 +421,8 @@ static int _ccv_nnc_conv_back(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
ccv_nnc_tensor_view_t* dw = output_size > 1 ? (ccv_nnc_tensor_view_t*)outputs[1] : 0; // weight_update
ccv_nnc_tensor_view_t* h = (ccv_nnc_tensor_view_t*)outputs[0]; // output gradients
ccv_nnc_tensor_view_t* db = output_size > 2 ? (ccv_nnc_tensor_view_t*)outputs[2] : 0;
const int dilationX = ccv_max(cmd.info.convolution.dilation[1], 1);
const int dilationY = ccv_max(cmd.info.convolution.dilation[0], 1);

@autoreleasepool {
MPSCommandBuffer* command_buffer = 0;
Expand Down Expand Up @@ -483,7 +487,7 @@ static int _ccv_nnc_conv_back(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
int i;
for (i = 0; i < h_nd; i++)
[h_shape addObject:@(h->info.dim[i])];
MPSGraphConvolution2DOpDescriptor* descriptor = [MPSGraphConvolution2DOpDescriptor descriptorWithStrideInX:hint.stride.dim[1] strideInY:hint.stride.dim[0] dilationRateInX:1 dilationRateInY:1 groups:cmd.info.convolution.groups paddingLeft:hint.border.begin[1] paddingRight:hint.border.end[1] paddingTop:hint.border.begin[0] paddingBottom:hint.border.end[0] paddingStyle:MPSGraphPaddingStyleExplicit dataLayout:ccv_nnc_mps_tensor_data_layout(g->info.format) weightsLayout:MPSGraphTensorNamedDataLayoutOIHW];
MPSGraphConvolution2DOpDescriptor* descriptor = [MPSGraphConvolution2DOpDescriptor descriptorWithStrideInX:hint.stride.dim[1] strideInY:hint.stride.dim[0] dilationRateInX:dilationX dilationRateInY:dilationY groups:cmd.info.convolution.groups paddingLeft:hint.border.begin[1] paddingRight:hint.border.end[1] paddingTop:hint.border.begin[0] paddingBottom:hint.border.end[0] paddingStyle:MPSGraphPaddingStyleExplicit dataLayout:ccv_nnc_mps_tensor_data_layout(g->info.format) weightsLayout:MPSGraphTensorNamedDataLayoutOIHW];
MPSGraphTensor* mps_h = [graph convolution2DDataGradientWithIncomingGradientTensor:mps_g
weightsTensor:mps_w
outputShape:h_shape
Expand Down Expand Up @@ -523,7 +527,7 @@ static int _ccv_nnc_conv_back(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
int i;
for (i = 0; i < dw_nd; i++)
[dw_shape addObject:@(dw->info.dim[i])];
MPSGraphConvolution2DOpDescriptor* dw_descriptor = [MPSGraphConvolution2DOpDescriptor descriptorWithStrideInX:hint.stride.dim[1] strideInY:hint.stride.dim[0] dilationRateInX:1 dilationRateInY:1 groups:cmd.info.convolution.groups paddingLeft:hint.border.begin[1] paddingRight:hint.border.end[1] paddingTop:hint.border.begin[0] paddingBottom:hint.border.end[0] paddingStyle:MPSGraphPaddingStyleExplicit dataLayout:ccv_nnc_mps_tensor_data_layout(g->info.format) weightsLayout:MPSGraphTensorNamedDataLayoutOIHW];
MPSGraphConvolution2DOpDescriptor* dw_descriptor = [MPSGraphConvolution2DOpDescriptor descriptorWithStrideInX:hint.stride.dim[1] strideInY:hint.stride.dim[0] dilationRateInX:dilationX dilationRateInY:dilationY groups:cmd.info.convolution.groups paddingLeft:hint.border.begin[1] paddingRight:hint.border.end[1] paddingTop:hint.border.begin[0] paddingBottom:hint.border.end[0] paddingStyle:MPSGraphPaddingStyleExplicit dataLayout:ccv_nnc_mps_tensor_data_layout(g->info.format) weightsLayout:MPSGraphTensorNamedDataLayoutOIHW];

MPSGraphTensor* mps_dw = [graph convolution2DWeightsGradientWithIncomingGradientTensor:mps_g
sourceTensor:mps_a
Expand Down
158 changes: 158 additions & 0 deletions test/int/nnc/mpsdnn.tests.c
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,74 @@ TEST_CASE("mps forward convolution in half precision")
ccv_nnc_tensor_free(ga);
}

TEST_CASE("mps forward convolution with dilation 2, 3")
{
GUARD_ELSE_RETURN(ccv_nnc_cmd_ok(CCV_NNC_CONVOLUTION_FORWARD, CCV_NNC_BACKEND_MPS));
ccv_nnc_tensor_t* a = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, BATCH_SIZE, INPUT_SIZE, INPUT_SIZE, INPUT_DIM), 0);
ccv_nnc_tensor_t* b = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, BATCH_SIZE, OUTPUT_SIZE, OUTPUT_SIZE, OUTPUT_DIM), 0);
ccv_nnc_cmd_t cmd = CMD_CONVOLUTION_FORWARD(1, OUTPUT_DIM, KERNEL_SIZE, KERNEL_SIZE, INPUT_DIM);
cmd.info.convolution.dilation[0] = 2;
cmd.info.convolution.dilation[1] = 3;
cmd.backend = CCV_NNC_BACKEND_CPU_REF;
assert(cmd.backend >= 0);
ccv_nnc_cmd_param_t modified_cmd = cmd.info;
modified_cmd.size.dim[0] = (cmd.info.size.dim[0] - 1) * ccv_max(cmd.info.convolution.dilation[0], 1) + 1;
modified_cmd.size.dim[1] = (cmd.info.size.dim[1] - 1) * ccv_max(cmd.info.convolution.dilation[1], 1) + 1;
ccv_nnc_hint_t hint = ccv_nnc_hint_auto(modified_cmd, a->info, b->info);
assert(ccv_nnc_hint_verify(hint, modified_cmd, a->info, b->info) == 0);
ccv_nnc_tensor_t* w = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, OUTPUT_DIM, KERNEL_SIZE, KERNEL_SIZE, INPUT_DIM), 0);
ccv_nnc_tensor_t* bias = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, OUTPUT_DIM), 0);
// configure the inlets.
dsfmt_t dsfmt;
dsfmt_init_gen_rand(&dsfmt, 0);
int i;
for (i = 0; i < INPUT_DIM * KERNEL_SIZE * KERNEL_SIZE * OUTPUT_DIM; i++)
w->data.f32[i] = dsfmt_genrand_open_close(&dsfmt) / (INPUT_DIM * KERNEL_SIZE * KERNEL_SIZE);
for (i = 0; i < INPUT_SIZE * INPUT_SIZE * INPUT_DIM * ccv_max(1, BATCH_SIZE); i++)
a->data.f32[i] = dsfmt_genrand_open_close(&dsfmt);
for (i = 0; i < OUTPUT_DIM; i++)
bias->data.f32[i] = (float)i / OUTPUT_DIM;
// Copy generated matrix values over to GPU.
ccv_nnc_tensor_t* ga = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 32F, BATCH_SIZE, INPUT_SIZE, INPUT_SIZE, INPUT_DIM), 0);
ccv_nnc_tensor_t* gw = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 32F, OUTPUT_DIM, KERNEL_SIZE, KERNEL_SIZE, INPUT_DIM), 0);
ccv_nnc_tensor_t* gwo = ccv_nnc_tensor_new(0, GPU_TENSOR_NCHW(000, 32F, OUTPUT_DIM, INPUT_DIM, KERNEL_SIZE, KERNEL_SIZE), 0);
ccv_nnc_tensor_t* gbias = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 32F, OUTPUT_DIM), 0);
ccv_nnc_cmd_t move = CMD_DATA_TRANSFER_FORWARD();
move.backend = CCV_NNC_BACKEND_MPS;
assert(move.backend >= 0);
ccv_nnc_cmd_exec(move, ccv_nnc_no_hint, 0, TENSOR_LIST(a, w, bias), TENSOR_LIST(ga, gw, gbias), 0);
ccv_nnc_cmd_exec(cmd, hint, 0, TENSOR_LIST(a, w, bias), TENSOR_LIST(b), 0);
ccv_nnc_tensor_t* gc = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 32F, BATCH_SIZE, OUTPUT_SIZE, OUTPUT_SIZE, OUTPUT_DIM), 0);

ccv_nnc_cmd_t transform = CMD_FORMAT_TRANSFORM_FORWARD();
transform.backend = CCV_NNC_BACKEND_MPS;
assert(transform.backend >= 0);
ccv_nnc_stream_context_t* stream_context = ccv_nnc_stream_context_new(CCV_STREAM_CONTEXT_GPU);
ccv_nnc_cmd_exec(transform, ccv_nnc_no_hint, 0, TENSOR_LIST(gw), TENSOR_LIST(gwo), stream_context);
ccv_nnc_stream_context_wait(stream_context);
ccv_nnc_tensor_free(gw);

cmd.backend = CCV_NNC_BACKEND_MPS;
assert(cmd.backend >= 0);
cmd.algorithm = -1;
cmd = ccv_nnc_cmd_autotune(cmd, 1 * 1024 * 1024 * 1024, hint, 0, TENSOR_LIST(ga, gwo, gbias), TENSOR_LIST(gc), stream_context);
assert(CCV_NNC_EXEC_SUCCESS == ccv_nnc_cmd_exec(cmd, hint, 0, TENSOR_LIST(ga, gwo, gbias), TENSOR_LIST(gc), stream_context));
ccv_nnc_stream_context_wait(stream_context);
ccv_nnc_stream_context_free(stream_context);
ccv_nnc_tensor_t* c = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, BATCH_SIZE, OUTPUT_SIZE, OUTPUT_SIZE, OUTPUT_DIM), 0);
ccv_nnc_cmd_exec(move, ccv_nnc_no_hint, 0, TENSOR_LIST(gc), TENSOR_LIST(c), 0);
REQUIRE_ARRAY_EQ_WITH_TOLERANCE(float, b->data.f32, c->data.f32, OUTPUT_DIM * OUTPUT_SIZE * OUTPUT_SIZE, 1e-5, "output from mps should match from CPU");
ccv_nnc_tensor_free(c);
ccv_nnc_tensor_free(gc);
ccv_nnc_tensor_free(bias);
ccv_nnc_tensor_free(w);
ccv_nnc_tensor_free(b);
ccv_nnc_tensor_free(a);
ccv_nnc_tensor_free(gbias);
ccv_nnc_tensor_free(gwo);
ccv_nnc_tensor_free(ga);
}

TEST_CASE("compare softmax with mps")
{
GUARD_ELSE_RETURN(ccv_nnc_cmd_ok(CCV_NNC_SOFTMAX_FORWARD, CCV_NNC_BACKEND_MPS));
Expand Down Expand Up @@ -2277,6 +2345,96 @@ TEST_CASE("mps backward convolution in nhwc format")
ccv_nnc_tensor_free(cdbias);
}

TEST_CASE("mps backward convolution in nchw format with dilation 2, 3")
{
GUARD_ELSE_RETURN(ccv_nnc_cmd_ok(CCV_NNC_CONVOLUTION_BACKWARD, CCV_NNC_BACKEND_MPS));
ccv_nnc_tensor_t* a = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, BATCH_SIZE, INPUT_SIZE, INPUT_SIZE, INPUT_DIM), 0);
ccv_nnc_tensor_t* h = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, BATCH_SIZE, INPUT_SIZE, INPUT_SIZE, INPUT_DIM), 0);
ccv_nnc_tensor_t* g = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, BATCH_SIZE, OUTPUT_SIZE, OUTPUT_SIZE, OUTPUT_DIM), 0);
ccv_nnc_cmd_t cmd = CMD_CONVOLUTION_BACKWARD(1, OUTPUT_DIM, KERNEL_SIZE, KERNEL_SIZE, INPUT_DIM);
cmd.info.convolution.dilation[0] = 2;
cmd.info.convolution.dilation[1] = 3;
cmd.backend = CCV_NNC_BACKEND_CPU_REF;
assert(cmd.backend >= 0);
ccv_nnc_cmd_param_t modified_cmd = cmd.info;
modified_cmd.size.dim[0] = (cmd.info.size.dim[0] - 1) * ccv_max(cmd.info.convolution.dilation[0], 1) + 1;
modified_cmd.size.dim[1] = (cmd.info.size.dim[1] - 1) * ccv_max(cmd.info.convolution.dilation[1], 1) + 1;
ccv_nnc_hint_t hint = ccv_nnc_hint_auto(modified_cmd, a->info, g->info);
assert(ccv_nnc_hint_verify(hint, modified_cmd, a->info, g->info) == 0);
ccv_nnc_tensor_t* w = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, OUTPUT_DIM, KERNEL_SIZE, KERNEL_SIZE, INPUT_DIM), 0);
ccv_nnc_tensor_t* dw = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, OUTPUT_DIM, KERNEL_SIZE, KERNEL_SIZE, INPUT_DIM), 0);
ccv_nnc_tensor_t* dbias = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 1, 1, 1, OUTPUT_DIM), 0);
// configure the inlets.
dsfmt_t dsfmt;
dsfmt_init_gen_rand(&dsfmt, 0);
int i;
for (i = 0; i < INPUT_DIM * KERNEL_SIZE * KERNEL_SIZE * OUTPUT_DIM; i++)
w->data.f32[i] = dsfmt_genrand_open_close(&dsfmt) / (INPUT_DIM * KERNEL_SIZE * KERNEL_SIZE);
for (i = 0; i < INPUT_SIZE * INPUT_SIZE * INPUT_DIM * ccv_max(1, BATCH_SIZE); i++)
a->data.f32[i] = dsfmt_genrand_open_close(&dsfmt);
for (i = 0; i < OUTPUT_SIZE * OUTPUT_SIZE * OUTPUT_DIM * ccv_max(1, BATCH_SIZE); i++)
g->data.f32[i] = dsfmt_genrand_open_close(&dsfmt) / OUTPUT_DIM; // (OUTPUT_SIZE * OUTPUT_SIZE * OUTPUT_DIM);
// Copy generated matrix values over to GPU.
ccv_nnc_tensor_t* ga = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 32F, BATCH_SIZE, INPUT_SIZE, INPUT_SIZE, INPUT_DIM), 0);
ccv_nnc_tensor_t* gg = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 32F, BATCH_SIZE, OUTPUT_SIZE, OUTPUT_SIZE, OUTPUT_DIM), 0);
ccv_nnc_tensor_t* gh = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 32F, BATCH_SIZE, INPUT_SIZE, INPUT_SIZE, INPUT_DIM), 0);
ccv_nnc_tensor_t* gw = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 32F, OUTPUT_DIM, KERNEL_SIZE, KERNEL_SIZE, INPUT_DIM), 0);
ccv_nnc_tensor_t* gbias = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 32F, 1, 1, 1, OUTPUT_DIM), 0);
ccv_nnc_tensor_t* gdw = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 32F, OUTPUT_DIM, KERNEL_SIZE, KERNEL_SIZE, INPUT_DIM), 0);
ccv_nnc_tensor_t* gdbias = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 32F, 1, 1, 1, OUTPUT_DIM), 0);
ccv_nnc_tensor_t* gao = ccv_nnc_tensor_new(0, GPU_TENSOR_NCHW(000, 32F, BATCH_SIZE, INPUT_DIM, INPUT_SIZE, INPUT_SIZE), 0);
ccv_nnc_tensor_t* ggo = ccv_nnc_tensor_new(0, GPU_TENSOR_NCHW(000, 32F, BATCH_SIZE, OUTPUT_DIM, OUTPUT_SIZE, OUTPUT_SIZE), 0);
ccv_nnc_tensor_t* gho = ccv_nnc_tensor_new(0, GPU_TENSOR_NCHW(000, 32F, BATCH_SIZE, INPUT_DIM, INPUT_SIZE, INPUT_SIZE), 0);
ccv_nnc_tensor_t* gwo = ccv_nnc_tensor_new(0, GPU_TENSOR_NCHW(000, 32F, OUTPUT_DIM, INPUT_DIM, KERNEL_SIZE, KERNEL_SIZE), 0);
ccv_nnc_tensor_t* gbiaso = ccv_nnc_tensor_new(0, GPU_TENSOR_NCHW(000, 32F, 1, OUTPUT_DIM, 1, 1), 0);
ccv_nnc_tensor_t* gdwo = ccv_nnc_tensor_new(0, GPU_TENSOR_NCHW(000, 32F, OUTPUT_DIM, INPUT_DIM, KERNEL_SIZE, KERNEL_SIZE), 0);
ccv_nnc_tensor_t* gdbiaso = ccv_nnc_tensor_new(0, GPU_TENSOR_NCHW(000, 32F, 1, OUTPUT_DIM, 1, 1), 0);
ccv_nnc_cmd_exec(CMD_DATA_TRANSFER_FORWARD(), ccv_nnc_no_hint, 0, TENSOR_LIST(a, w, g), TENSOR_LIST(ga, gw, gg), 0);
ccv_nnc_cmd_exec(cmd, hint, 0, TENSOR_LIST(g, a, w), TENSOR_LIST(h, dw, dbias), 0);
ccv_nnc_cmd_exec(CMD_FORMAT_TRANSFORM_FORWARD(), ccv_nnc_no_hint, 0, TENSOR_LIST(ga, gw, gg), TENSOR_LIST(gao, gwo, ggo), 0);
cmd.backend = CCV_NNC_BACKEND_MPS;

assert(cmd.backend >= 0);
cmd.algorithm = -1;
ccv_nnc_stream_context_t* stream_context = ccv_nnc_stream_context_new(CCV_STREAM_CONTEXT_GPU);

cmd = ccv_nnc_cmd_autotune(cmd, 1 * 1024 * 1024 * 1024, hint, 0, TENSOR_LIST(ggo, gao, gwo), TENSOR_LIST(gho, gdwo, gdbiaso), stream_context);
assert(CCV_NNC_EXEC_SUCCESS == ccv_nnc_cmd_exec(cmd, hint, 0, TENSOR_LIST(ggo, gao, gwo), TENSOR_LIST(gho, gdwo, gdbiaso), stream_context));
ccv_nnc_stream_context_wait(stream_context);
ccv_nnc_stream_context_free(stream_context);
ccv_nnc_tensor_t* ch = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, BATCH_SIZE, INPUT_SIZE, INPUT_SIZE, INPUT_DIM), 0);
ccv_nnc_tensor_t* cdw = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, OUTPUT_DIM, KERNEL_SIZE, KERNEL_SIZE, INPUT_DIM), 0);
ccv_nnc_tensor_t* cdbias = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 1, OUTPUT_DIM, 1, 1), 0);

ccv_nnc_cmd_exec(CMD_FORMAT_TRANSFORM_FORWARD(), ccv_nnc_no_hint, 0, TENSOR_LIST(gho, gdwo, gdbiaso), TENSOR_LIST(gh, gdw, gdbias), 0);
ccv_nnc_cmd_exec(CMD_DATA_TRANSFER_FORWARD(), ccv_nnc_no_hint, 0, TENSOR_LIST(gh, gdw, gdbias), TENSOR_LIST(ch, cdw, cdbias), 0);

REQUIRE_ARRAY_EQ_WITH_TOLERANCE(float, dw->data.f32, cdw->data.f32, INPUT_DIM * OUTPUT_DIM * KERNEL_SIZE * KERNEL_SIZE, 5e-1, "output from mps should match from CPU");
REQUIRE_ARRAY_EQ_WITH_TOLERANCE(float, dbias->data.f32, cdbias->data.f32, OUTPUT_DIM, 5e-1, "output from mps should match from CPU");
REQUIRE_ARRAY_EQ_WITH_TOLERANCE(float, h->data.f32, ch->data.f32, BATCH_SIZE * INPUT_DIM * INPUT_SIZE * INPUT_SIZE, 1e-4, "output from mps should match from CPU");
ccv_nnc_tensor_free(gao);
ccv_nnc_tensor_free(ggo);
ccv_nnc_tensor_free(gho);
ccv_nnc_tensor_free(gwo);
ccv_nnc_tensor_free(gbiaso);
ccv_nnc_tensor_free(gdwo);
ccv_nnc_tensor_free(gdbiaso);
ccv_nnc_tensor_free(h);
ccv_nnc_tensor_free(gh);
ccv_nnc_tensor_free(w);
ccv_nnc_tensor_free(g);
ccv_nnc_tensor_free(a);
ccv_nnc_tensor_free(gbias);
ccv_nnc_tensor_free(gdbias);
ccv_nnc_tensor_free(gdw);
ccv_nnc_tensor_free(gw);
ccv_nnc_tensor_free(gg);
ccv_nnc_tensor_free(ga);
ccv_nnc_tensor_free(ch);
ccv_nnc_tensor_free(cdw);
ccv_nnc_tensor_free(cdbias);
}

TEST_CASE("compare group norm gradient with mps")
{
GUARD_ELSE_RETURN(ccv_nnc_cmd_ok(CCV_NNC_GROUP_NORM_FORWARD, CCV_NNC_BACKEND_MPS) &&
Expand Down

0 comments on commit bb296cd

Please sign in to comment.