Skip to content

Commit

Permalink
Add conv3d support on MPS backend.
Browse files Browse the repository at this point in the history
  • Loading branch information
liuliu committed Jan 24, 2025
1 parent 819aa2a commit 19f4960
Show file tree
Hide file tree
Showing 4 changed files with 214 additions and 64 deletions.
6 changes: 3 additions & 3 deletions lib/nnc/ccv_nnc_easy.h
Original file line number Diff line number Diff line change
Expand Up @@ -254,12 +254,12 @@ static inline void ccv_nnc_tensor_view_get_dim(const ccv_nnc_tensor_view_t* cons
{
int x;
const int nd = ccv_nnc_tensor_nd(tv->info.dim);
const int offset = CCV_NNC_MAX_DIM + 2 - nd;
const int offset = ccv_max(CCV_NNC_MAX_DIM + 2 - nd, 0);
for (x = 0; x < offset; x++)
dim[x] = 1;
for (x = offset; x < CCV_NNC_MAX_DIM + 2; x++)
for (x = offset; x < ccv_max(CCV_NNC_MAX_DIM + 2, nd); x++)
dim[x] = tv->info.dim[x - offset];
dim[CCV_NNC_MAX_DIM + 2] = 0;
dim[ccv_max(CCV_NNC_MAX_DIM + 2, nd)] = 0;
}

static inline CCV_WARN_UNUSED(int) ccv_nnc_is_tensor_stride_packed(const int stride[CCV_NNC_MAX_DIM_ALLOC], const int dim[CCV_NNC_MAX_DIM_ALLOC])
Expand Down
86 changes: 58 additions & 28 deletions lib/nnc/cmd/convolution/mps/ccv_nnc_conv_mps.m
Original file line number Diff line number Diff line change
Expand Up @@ -28,26 +28,28 @@ static int _ccv_nnc_conv_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
assert(w->info.format == CCV_TENSOR_FORMAT_NCHW);
int biasdim[CCV_NNC_MAX_DIM_ALLOC] = {0};
int biasstride[CCV_NNC_MAX_DIM_ALLOC] = {0};
const int size_nd = ccv_nnc_tensor_nd(cmd.info.size.dim) - 1;
assert(size_nd == 2 || size_nd == 3);
if (bias)
{
assert(CCV_GET_DATA_TYPE(bias->info.datatype) != CCV_QX);
assert(ccv_nnc_tensor_nd(bias->info.dim) == 1);
int i;
for (i = 0; i < CCV_NNC_MAX_DIM + 2; i++)
for (i = 0; i < size_nd + 2; i++)
biasdim[i] = 1;
int c;
if (b->info.format == CCV_TENSOR_FORMAT_NCHW)
c = 1;
else if (b->info.format == CCV_TENSOR_FORMAT_NHWC)
c = CCV_NNC_MAX_DIM + 1;
c = size_nd + 1;
else
c = 0;
biasdim[c] = bias->info.dim[0];
if (CCV_IS_TENSOR_VIEW(bias))
{
for (i = 0; i < c; i++)
biasstride[i] = bias->info.dim[0] * bias->stride[0];
for (i = c; i < CCV_NNC_MAX_DIM + 2; i++)
for (i = c; i < size_nd + 2; i++)
biasstride[i] = bias->stride[0];
}
}
Expand Down Expand Up @@ -98,15 +100,15 @@ static int _ccv_nnc_conv_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
int w_batch_size;
int b_batch_size;
if (use_mfa) {
a_batch_size = a_nd < 4 ? 1 : adim[a_nd - 4];
a_batch_size = a_nd < size_nd + 2 ? 1 : adim[a_nd - size_nd - 2];
int i;
for (i = 0; i < a_nd - 4; i++)
for (i = 0; i < a_nd - size_nd - 2; i++)
a_batch_size *= adim[i];
w_batch_size = w_nd < 5 ? 1 : w->info.dim[w_nd - 5];
for (i = 0; i < w_nd - 5; i++)
w_batch_size = w_nd < size_nd + 3 ? 1 : w->info.dim[w_nd - size_nd - 3];
for (i = 0; i < w_nd - size_nd - 3; i++)
w_batch_size *= w->info.dim[i];
b_batch_size = b_nd < 4 ? 1 : b->info.dim[b_nd - 4];
for (i = 0; i < b_nd - 4; i++)
b_batch_size = b_nd < size_nd + 2 ? 1 : b->info.dim[b_nd - size_nd - 2];
for (i = 0; i < b_nd - size_nd - 2; i++)
b_batch_size *= b->info.dim[i];
assert(a_batch_size == b_batch_size || a_batch_size == 1);
assert(w_batch_size == a_batch_size || w_batch_size == 1);
Expand Down Expand Up @@ -151,17 +153,15 @@ static int _ccv_nnc_conv_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
// Height and width of the filter, not the image.
const int W = wdim[w_nd - 1];
const int H = wdim[w_nd - 2];
const int D = size_nd > 2 ? wdim[w_nd - 3] : 1;

if ((H != 1) || (W != 1)) {
if ((H != 1) || (W != 1) || (D != 1)) {
use_mfa = false;
fallback_reason = "Kernel size not 1x1.";
} else if (hint.stride.dim[1] != 1 || hint.stride.dim[0] != 1) {
} else if (hint.stride.dim[1] != 1 || hint.stride.dim[0] != 1 || (size_nd == 3 && hint.stride.dim[2] != 1)) {
use_mfa = false;
fallback_reason = "Strided filter.";
} else if (hint.border.begin[1] != 0 ||
hint.border.end[1] != 0 ||
hint.border.begin[0] != 0 ||
hint.border.end[0] != 0) {
} else if (hint.border.begin[1] != 0 || hint.border.end[1] != 0 || hint.border.begin[0] != 0 || hint.border.end[0] != 0 || (size_nd == 3 && (hint.border.begin[2] != 0 || hint.border.end[2] != 0))) {
use_mfa = false;
fallback_reason = "Padded.";
} else if (cmd.info.convolution.groups != 1) {
Expand Down Expand Up @@ -225,6 +225,7 @@ static int _ccv_nnc_conv_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
int O;
int H;
int W;
int D = 1;

// Bypass a compilation error from a header.
int I_dim;
Expand All @@ -234,26 +235,33 @@ static int _ccv_nnc_conv_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
I_dim = adim[a_nd - 1];
W = adim[a_nd - 2];
H = adim[a_nd - 3];
if (size_nd == 3)
D = adim[a_nd - 4];
} else if (a->info.format == CCV_TENSOR_FORMAT_NCHW) {
// IxHW -> KxM
W = adim[a_nd - 1];
H = adim[a_nd - 2];
I_dim = adim[a_nd - 3];
if (size_nd == 3)
{
D = adim[a_nd - 3];
I_dim = adim[a_nd - 4];
} else
I_dim = adim[a_nd - 3];
} else {
// This should never happen.
assert(false);
}

// OxI -> NxK
assert(I_dim == wdim[w_nd - 3]);
O = wdim[w_nd - 4];
assert(I_dim == wdim[w_nd - size_nd - 1]);
O = wdim[w_nd - size_nd - 2];

ccv_nnc_mfa_gemm_params_t params;
if (a->info.format == CCV_TENSOR_FORMAT_NHWC)
{
params = (ccv_nnc_mfa_gemm_params_t){
.data_type = mtl_data_type,
.M = (uint32_t)(H * W),
.M = (uint32_t)(H * W * D),
.N = (uint32_t)O,
.K = (uint32_t)I_dim,
.A_trans = 0,
Expand All @@ -263,16 +271,16 @@ static int _ccv_nnc_conv_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
.register_float = 0,

.batch_dimension = b_batch_size,
.batch_stride_a = a_batch_size > 1 ? H * W * I_dim : 0,
.batch_stride_a = a_batch_size > 1 ? H * W * D * I_dim : 0,
.batch_stride_b = w_batch_size > 1 ? O * I_dim : 0,
.batch_stride_c = b_batch_size > 1 ? H * W * O : 0,
.batch_stride_c = b_batch_size > 1 ? H * W * D * O : 0,
.batch_stride_d = 0,
};
} else {
params = (ccv_nnc_mfa_gemm_params_t){
.data_type = mtl_data_type,
.M = (uint32_t)O,
.N = (uint32_t)(H * W),
.N = (uint32_t)(H * W * D),
.K = (uint32_t)I_dim,
.A_trans = 0,
.B_trans = 0,
Expand All @@ -282,8 +290,8 @@ static int _ccv_nnc_conv_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint

.batch_dimension = b_batch_size,
.batch_stride_a = w_batch_size > 1 ? O * I_dim : 0,
.batch_stride_b = a_batch_size > 1 ? H * W * I_dim : 0,
.batch_stride_c = b_batch_size > 1 ? H * W * O : 0,
.batch_stride_b = a_batch_size > 1 ? H * W * D * I_dim : 0,
.batch_stride_c = b_batch_size > 1 ? H * W * D * O : 0,
.batch_stride_d = 0,
};
}
Expand Down Expand Up @@ -387,8 +395,9 @@ static int _ccv_nnc_conv_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
int* biasdim_r = biasdim;
int* biasstride_r = biasstride;
int indices[3];
const int dilationX = ccv_max(cmd.info.convolution.dilation[1], 1);
const int dilationY = ccv_max(cmd.info.convolution.dilation[0], 1);
const int dilationZ = size_nd == 2 ? 1 : ccv_max(cmd.info.convolution.dilation[size_nd - 3], 1);
const int dilationY = ccv_max(cmd.info.convolution.dilation[size_nd - 2], 1);
const int dilationX = ccv_max(cmd.info.convolution.dilation[size_nd - 1], 1);
MPSGraphExecutable* executable = ccv_nnc_mps_graph_executable_cache(key, indices, ^void (MPSGraph* graph, NSMutableArray<MPSGraphTensor*>* inputTensors, NSMutableArray<MPSGraphShapedType*>* inputShapedTypes, NSMutableArray<MPSGraphTensor*>* resultTensors) {
MPSGraphTensor* mps_input_a;
MPSGraphTensor* mps_a = ccv_nnc_mps_graph_tensor_input(graph, a, adim_r, astride_r, &mps_input_a);
Expand All @@ -400,8 +409,29 @@ static int _ccv_nnc_conv_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
[inputTensors addObject:mps_input_w];
MPSGraphShapedType* mps_w_shape = ccv_nnc_mps_graph_tensor_input_shape(w, wdim_r, wstride_r);
[inputShapedTypes addObject:mps_w_shape];
MPSGraphConvolution2DOpDescriptor* descriptor = [MPSGraphConvolution2DOpDescriptor descriptorWithStrideInX:hint.stride.dim[1] strideInY:hint.stride.dim[0] dilationRateInX:dilationX dilationRateInY:dilationY groups:cmd.info.convolution.groups paddingLeft:hint.border.begin[1] paddingRight:hint.border.end[1] paddingTop:hint.border.begin[0] paddingBottom:hint.border.end[0] paddingStyle:MPSGraphPaddingStyleExplicit dataLayout:ccv_nnc_mps_tensor_data_layout(a->info.format) weightsLayout:MPSGraphTensorNamedDataLayoutOIHW];
MPSGraphTensor* mps_b = [graph convolution2DWithSourceTensor:mps_a weightsTensor:mps_w descriptor:descriptor name:nil];
MPSGraphTensor* mps_b;
if (size_nd == 2)
{
MPSGraphConvolution2DOpDescriptor* descriptor = [MPSGraphConvolution2DOpDescriptor descriptorWithStrideInX:hint.stride.dim[1] strideInY:hint.stride.dim[0] dilationRateInX:dilationX dilationRateInY:dilationY groups:cmd.info.convolution.groups paddingLeft:hint.border.begin[1] paddingRight:hint.border.end[1] paddingTop:hint.border.begin[0] paddingBottom:hint.border.end[0] paddingStyle:MPSGraphPaddingStyleExplicit dataLayout:ccv_nnc_mps_tensor_data_layout(a->info.format) weightsLayout:MPSGraphTensorNamedDataLayoutOIHW];
mps_b = [graph convolution2DWithSourceTensor:mps_a weightsTensor:mps_w descriptor:descriptor name:nil];
} else if (size_nd == 3) {
MPSGraphTensorNamedDataLayout data_layout;
switch (a->info.format)
{
case CCV_TENSOR_FORMAT_NCHW:
data_layout = MPSGraphTensorNamedDataLayoutNCDHW;
break;
case CCV_TENSOR_FORMAT_NHWC:
data_layout = MPSGraphTensorNamedDataLayoutNDHWC;
break;
case CCV_TENSOR_FORMAT_CHWN:
assert(0 && "doesn't support CHWN");
}
MPSGraphConvolution3DOpDescriptor* descriptor = [MPSGraphConvolution3DOpDescriptor descriptorWithStrideInX:hint.stride.dim[size_nd - 1] strideInY:hint.stride.dim[size_nd - 2] strideInZ:hint.stride.dim[size_nd - 3] dilationRateInX:dilationX dilationRateInY:dilationY dilationRateInZ:dilationZ groups:cmd.info.convolution.groups paddingLeft:hint.border.begin[size_nd - 1] paddingRight:hint.border.end[size_nd - 1] paddingTop:hint.border.begin[size_nd - 2] paddingBottom:hint.border.end[size_nd - 2] paddingFront:hint.border.begin[size_nd - 3] paddingBack:hint.border.end[size_nd - 3] paddingStyle:MPSGraphPaddingStyleExplicit dataLayout:data_layout weightsLayout:MPSGraphTensorNamedDataLayoutOIDHW];
mps_b = [graph convolution3DWithSourceTensor:mps_a weightsTensor:mps_w descriptor:descriptor name:nil];
} else {
assert(0);
}
if (bias)
{
MPSGraphTensor* mps_input_bias;
Expand Down
62 changes: 30 additions & 32 deletions lib/nnc/cmd/util/mps/ccv_nnc_util_mps.m
Original file line number Diff line number Diff line change
Expand Up @@ -256,66 +256,64 @@ static int _ccv_nnc_format_transform(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint
ccv_nnc_tensor_view_get_stride(a, astride);
ccv_nnc_tensor_view_get_dim(&bt, bdim);
ccv_nnc_tensor_view_get_stride(&bt, bstride);
const int b_nd = ccv_nnc_tensor_nd(bdim);
int j;
if (a->info.format == CCV_TENSOR_FORMAT_NHWC)
{
if (bt.info.format == CCV_TENSOR_FORMAT_NCHW)
{
int c = bdim[1];
bdim[1] = bdim[2];
bdim[2] = bdim[3];
bdim[3] = c;
for (j = 1; j < b_nd - 1; j++)
bdim[j] = bdim[j + 1];
bdim[b_nd - 1] = c;
c = bstride[1];
bstride[1] = bstride[2];
bstride[2] = bstride[3];
bstride[3] = c;
for (j = 1; j < b_nd - 1; j++)
bstride[j] = bstride[j + 1];
bstride[b_nd - 1] = c;
} else {
assert(bt.info.format == CCV_TENSOR_FORMAT_CHWN);
int t;
CCV_SWAP(bdim[0], bdim[3], t);
CCV_SWAP(bstride[0], bstride[3], t);
CCV_SWAP(bdim[0], bdim[b_nd - 1], t);
CCV_SWAP(bstride[0], bstride[b_nd - 1], t);
}
} else if (a->info.format == CCV_TENSOR_FORMAT_NCHW) {
if (bt.info.format == CCV_TENSOR_FORMAT_NHWC)
{
int c = bdim[3];
bdim[3] = bdim[2];
bdim[2] = bdim[1];
int c = bdim[b_nd - 1];
for (j = b_nd - 1; j > 1; j--)
bdim[j] = bdim[j - 1];
bdim[1] = c;
c = bstride[3];
bstride[3] = bstride[2];
bstride[2] = bstride[1];
c = bstride[b_nd - 1];
for (j = b_nd - 1; j > 1; j--)
bstride[j] = bstride[j - 1];
bstride[1] = c;
} else {
assert(bt.info.format == CCV_TENSOR_FORMAT_CHWN);
int n = bdim[3];
bdim[3] = bdim[2];
bdim[2] = bdim[1];
bdim[1] = bdim[0];
int n = bdim[b_nd - 1];
for (j = b_nd - 1; j > 0; j--)
bdim[j] = bdim[j - 1];
bdim[0] = n;
n = bstride[3];
bstride[3] = bstride[2];
bstride[2] = bstride[1];
bstride[1] = bstride[0];
n = bstride[b_nd - 1];
for (j = b_nd - 1; j > 0; j--)
bstride[j] = bstride[j - 1];
bstride[0] = n;
}
} else if (a->info.format == CCV_TENSOR_FORMAT_CHWN) {
if (bt.info.format == CCV_TENSOR_FORMAT_NCHW)
{
int n = bdim[0];
bdim[0] = bdim[1];
bdim[1] = bdim[2];
bdim[2] = bdim[3];
bdim[3] = n;
for (j = 0; j < b_nd - 1; j++)
bdim[j] = bdim[j + 1];
bdim[b_nd - 1] = n;
n = bstride[0];
bstride[0] = bstride[1];
bstride[1] = bstride[2];
bstride[2] = bstride[3];
bstride[3] = n;
for (j = 0; j < b_nd - 1; j++)
bstride[j] = bstride[j + 1];
bstride[b_nd - 1] = n;
} else {
assert(bt.info.format == CCV_TENSOR_FORMAT_NHWC);
int t;
CCV_SWAP(bdim[0], bdim[3], t);
CCV_SWAP(bstride[0], bstride[3], t);
CCV_SWAP(bdim[0], bdim[b_nd - 1], t);
CCV_SWAP(bstride[0], bstride[b_nd - 1], t);
}
}
// Mark this as tensor view as we changed its stride and dim.
Expand Down
Loading

0 comments on commit 19f4960

Please sign in to comment.