Skip to content

Commit

Permalink
Add cmul in Metal.
Browse files Browse the repository at this point in the history
  • Loading branch information
liuliu committed Dec 31, 2023
1 parent 7f8e9e8 commit 378ae93
Show file tree
Hide file tree
Showing 11 changed files with 1,251 additions and 676 deletions.
2 changes: 1 addition & 1 deletion lib/nnc/cmd/blas/ccv_nnc_blas.c
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ static int _ccv_nnc_cmul_back_bitmask(const ccv_nnc_cmd_param_t cmd, const int i
}

REGISTER_COMMAND(CCV_NNC_CMUL_FORWARD)(ccv_nnc_cmd_registry_t* const registry)
FIND_BACKEND(ccv_nnc_cmul_cpu_ref.c, gpu/ccv_nnc_cmul_gpu_ref.cu)
FIND_BACKEND(ccv_nnc_cmul_cpu_ref.c, gpu/ccv_nnc_cmul_gpu_ref.cu, mps/ccv_nnc_cmul_mps.m)
{
registry->bitmask = _ccv_nnc_cmul_forw_bitmask;
registry->tensor_auto = _ccv_nnc_broadcast_tensor_auto_forw;
Expand Down
10 changes: 5 additions & 5 deletions lib/nnc/cmd/blas/gpu/ccv_nnc_cmul_gpu_ref.cu
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ __global__ void _ccv_nnc_cmul_kernel(const size_t count, const NUM1* const a, co
}

template<typename NUM1, typename NUM2, typename NUM3>
__global__ void _ccv_nnc_cmul_kernel_4d_0(const int astride2, const int astride1, const int astride0, const int bstride2, const int bstride1, const int bstride0, const int cstride2, const int cstride1, const int cstride0, const int dim2, const int dim1, const int dim0, const NUM1* const a, const NUM2* const b, NUM3* const c)
__global__ void _ccv_nnc_cmul_kernel_4d_0(const int astride2, const int astride1, const int astride0, const int bstride2, const int bstride1, const int bstride0, const int cstride2, const int cstride1, const int cstride0, const int dim1, const int dim0, const NUM1* const a, const NUM2* const b, NUM3* const c)
{
const int z = blockIdx.z * blockDim.z + threadIdx.z;
const int y = blockIdx.y * blockDim.y + threadIdx.y;
Expand Down Expand Up @@ -161,13 +161,13 @@ static int _ccv_nnc_cmul_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
{
if (a->info.datatype == CCV_32F && c->info.datatype == CCV_32F)
{
_ccv_nnc_cmul_kernel_4d_0<<<dim3((cdim[0] / 2 + 63) / 64, (cdim[1] + 7) / 8, cdim[2] * cdim[3]), dim3(64, 8, 1), 0, stream>>>(astride[3], astride[2], astride[1], bstride[3], bstride[2], bstride[1], cstride[3], cstride[2], cstride[1], cdim[2], cdim[1], cdim[0] / 2, a->data.f32, b->data.f32, c->data.f32);
_ccv_nnc_cmul_kernel_4d_0<<<dim3((cdim[0] / 2 + 63) / 64, (cdim[1] + 7) / 8, cdim[2] * cdim[3]), dim3(64, 8, 1), 0, stream>>>(astride[3], astride[2], astride[1], bstride[3], bstride[2], bstride[1], cstride[3], cstride[2], cstride[1], cdim[1], cdim[0] / 2, a->data.f32, b->data.f32, c->data.f32);
} else if (a->info.datatype == CCV_32F && c->info.datatype == CCV_16F) {
_ccv_nnc_cmul_kernel_4d_0<<<dim3((cdim[0] / 2 + 63) / 64, (cdim[1] + 7) / 8, cdim[2] * cdim[3]), dim3(64, 8, 1), 0, stream>>>(astride[3], astride[2], astride[1], bstride[3], bstride[2], bstride[1], cstride[3], cstride[2], cstride[1], cdim[2], cdim[1], cdim[0] / 2, a->data.f32, b->data.f32, (__half*)c->data.f16);
_ccv_nnc_cmul_kernel_4d_0<<<dim3((cdim[0] / 2 + 63) / 64, (cdim[1] + 7) / 8, cdim[2] * cdim[3]), dim3(64, 8, 1), 0, stream>>>(astride[3], astride[2], astride[1], bstride[3], bstride[2], bstride[1], cstride[3], cstride[2], cstride[1], cdim[1], cdim[0] / 2, a->data.f32, b->data.f32, (__half*)c->data.f16);
} else if (a->info.datatype == CCV_16F && c->info.datatype == CCV_32F) {
_ccv_nnc_cmul_kernel_4d_0<<<dim3((cdim[0] / 2 + 63) / 64, (cdim[1] + 7) / 8, cdim[2] * cdim[3]), dim3(64, 8, 1), 0, stream>>>(astride[3], astride[2], astride[1], bstride[3], bstride[2], bstride[1], cstride[3], cstride[2], cstride[1], cdim[2], cdim[1], cdim[0] / 2, (__half*)a->data.f16, (__half*)b->data.f16, c->data.f32);
_ccv_nnc_cmul_kernel_4d_0<<<dim3((cdim[0] / 2 + 63) / 64, (cdim[1] + 7) / 8, cdim[2] * cdim[3]), dim3(64, 8, 1), 0, stream>>>(astride[3], astride[2], astride[1], bstride[3], bstride[2], bstride[1], cstride[3], cstride[2], cstride[1], cdim[1], cdim[0] / 2, (__half*)a->data.f16, (__half*)b->data.f16, c->data.f32);
} else if (a->info.datatype == CCV_16F && c->info.datatype == CCV_16F) {
_ccv_nnc_cmul_kernel_4d_0<<<dim3((cdim[0] / 2 + 63) / 64, (cdim[1] + 7) / 8, cdim[2] * cdim[3]), dim3(64, 8, 1), 0, stream>>>(astride[3], astride[2], astride[1], bstride[3], bstride[2], bstride[1], cstride[3], cstride[2], cstride[1], cdim[2], cdim[1], cdim[0] / 2, (__half*)a->data.f16, (__half*)b->data.f16, (__half*)c->data.f16);
_ccv_nnc_cmul_kernel_4d_0<<<dim3((cdim[0] / 2 + 63) / 64, (cdim[1] + 7) / 8, cdim[2] * cdim[3]), dim3(64, 8, 1), 0, stream>>>(astride[3], astride[2], astride[1], bstride[3], bstride[2], bstride[1], cstride[3], cstride[2], cstride[1], cdim[1], cdim[0] / 2, (__half*)a->data.f16, (__half*)b->data.f16, (__half*)c->data.f16);
}
} else if (nd == 3) {
if (a->info.datatype == CCV_32F && c->info.datatype == CCV_32F)
Expand Down
168 changes: 168 additions & 0 deletions lib/nnc/cmd/blas/mps/ccv_nnc_cmul_mps.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
#include <ccv.h>
#include <ccv_internal.h>
#include <nnc/ccv_nnc.h>
#include <nnc/ccv_nnc_easy.h>
#include <nnc/ccv_nnc_internal.h>
#include <nnc/mps/ccv_nnc_mps.h>

// Forward complex multiplication (cmul) on Metal, dispatched through the MFA
// (Metal Flash Attention) kernel library. Treats the innermost dimension of
// a and b as interleaved (real, imag) pairs and writes the element-wise
// complex product into c, with limited broadcasting support (up to 4 squeezed
// dimensions).
//
// inputs[0] (a), inputs[1] (b): contiguous GPU tensors, CCV_16F or CCV_32F,
//   same datatype as each other and as the output.
// outputs[0] (c): contiguous GPU tensor receiving the product.
// Returns CCV_NNC_EXEC_SUCCESS. There is no MPSGraph fallback: if the MFA
// path is unavailable the function asserts.
static int _ccv_nnc_cmul_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint, const int flags, ccv_nnc_tensor_t* const* const inputs, const int input_size, ccv_nnc_tensor_t* const* const outputs, const int output_size, ccv_nnc_stream_context_t* const stream_context)
{
	assert(input_size == 2);
	const ccv_nnc_tensor_t* const a = inputs[0];
	assert(CCV_IS_TENSOR_CONTIGUOUS(a));
	const ccv_nnc_tensor_t* const b = inputs[1];
	assert(CCV_IS_TENSOR_CONTIGUOUS(b));
	assert(output_size == 1);
	ccv_nnc_tensor_t* const c = outputs[0];
	assert(CCV_IS_TENSOR_CONTIGUOUS(c));
	@autoreleasepool {
		bool use_mfa = true;
		const char *fallback_reason = NULL;
		ccv_nnc_mfa_context_t* context = ccv_nnc_default_mfa_context();

		if (!ccv_nnc_mfa_context_supported(context) || (ccv_nnc_flags() & CCV_NNC_DISABLE_METAL_FLASH_ATTENTION)) {
			use_mfa = false;
			fallback_reason = "Disabled.";
		}

		// Raw MTLDataType value handed to MFA (MTLDataTypeFloat == 3,
		// MTLDataTypeHalf == 16); UINT32_MAX marks "not yet determined".
		uint32_t mtl_data_type = UINT32_MAX;
		if (use_mfa) {
			// MFA cmul requires a uniform datatype across all three tensors.
			const int is_same_dtype =
				(a->info.datatype == b->info.datatype) &&
				(a->info.datatype == c->info.datatype);
			if (!is_same_dtype) {
				use_mfa = false;
				fallback_reason = "Mixed precision.";
			}

			switch (a->info.datatype) {
				case CCV_16F: {
					mtl_data_type = 16;
					break;
				}
				case CCV_32F: {
					mtl_data_type = 3;
					break;
				}
				default: {
					use_mfa = false;
					fallback_reason = "Unsupported data type.";
					break;
				}
			}
		}

		if (use_mfa) {
			if (!CCV_IS_TENSOR_CONTIGUOUS(a) ||
				!CCV_IS_TENSOR_CONTIGUOUS(b) ||
				!CCV_IS_TENSOR_CONTIGUOUS(c))
			{
				use_mfa = false;
				fallback_reason = "Strided.";
			}
		}
		if (use_mfa) {
			ccv_nnc_mfa_cmul_params_t params = {
				.data_type = mtl_data_type,
				.astride = {0, 0, 0},
				.bstride = {0, 0, 0},
				.cstride = {0, 0, 0},
				.dim = {0, 0, 0, 0}
			};
			const size_t count = ccv_nnc_tensor_count(c->info);
			if (ccv_nnc_tensor_count(a->info) == count && ccv_nnc_tensor_count(b->info) == count) {
				// No broadcasting: encode as one flat 1-D problem.
				params.dim[0] = count;
			} else {
				// Broadcasting path: squeeze out size-1 output axes, derive
				// contiguous strides, and zero the stride of any broadcast axis.
				int i;
				int nd = ccv_nnc_tensor_nd(a->info.dim);
				// Fixed: these used `=` (assignment) instead of `==`, which made
				// the asserts vacuously true and clobbered nd under !NDEBUG.
				assert(nd == ccv_nnc_tensor_nd(b->info.dim));
				assert(nd == ccv_nnc_tensor_nd(c->info.dim));
				int adim[CCV_NNC_MAX_DIM_ALLOC];
				int bdim[CCV_NNC_MAX_DIM_ALLOC];
				int cdim[CCV_NNC_MAX_DIM_ALLOC];
				int squeezed_dims = 0;
				// Collect non-degenerate axes innermost-first.
				for (i = nd - 1; i >= 0; i--)
				{
					if (c->info.dim[i] == 1)
						continue;
					adim[squeezed_dims] = a->info.dim[i];
					bdim[squeezed_dims] = b->info.dim[i];
					cdim[squeezed_dims] = c->info.dim[i];
					squeezed_dims += 1;
				}
				nd = squeezed_dims;
				assert(nd <= 4); // params only carry 4 dims / 3 outer strides.
				// Zero-initialize so axes beyond `nd` hand well-defined (unused)
				// strides to the kernel instead of uninitialized stack values.
				int astride[CCV_NNC_MAX_DIM_ALLOC] = {0};
				int bstride[CCV_NNC_MAX_DIM_ALLOC] = {0};
				int cstride[CCV_NNC_MAX_DIM_ALLOC] = {0};
				astride[0] = 1;
				bstride[0] = 1;
				cstride[0] = 1;
				for (i = 1; i < nd; i++)
				{
					astride[i] = adim[i - 1] * astride[i - 1];
					bstride[i] = bdim[i - 1] * bstride[i - 1];
					cstride[i] = cdim[i - 1] * cstride[i - 1];
				}
				// A broadcast axis (size 1 in one input) advances with stride 0.
				for (i = 0; i < nd; i++)
				{
					if (cdim[i] == adim[i] && cdim[i] == bdim[i])
						continue;
					if (cdim[i] == adim[i])
					{
						assert(bdim[i] == 1);
						bstride[i] = 0;
					} else {
						assert(cdim[i] == bdim[i]);
						assert(adim[i] == 1);
						astride[i] = 0;
					}
				}
				// Only cdim[0..nd-1] are initialized; pad the rest with 0 rather
				// than reading uninitialized entries.
				for (i = 0; i < 4; i++)
					params.dim[i] = i < nd ? cdim[i] : 0;
				params.astride[0] = astride[1];
				params.astride[1] = astride[2];
				params.astride[2] = astride[3];
				params.bstride[0] = bstride[1];
				params.bstride[1] = bstride[2];
				params.bstride[2] = bstride[3];
				params.cstride[0] = cstride[1];
				params.cstride[1] = cstride[2];
				params.cstride[2] = cstride[3];
			}
			ccv_nnc_mfa_prepare_cmul(context, params);

			mtl_command_batch_t* command_batch = ccv_nnc_stream_context_start_command_batch(stream_context);
			mtl_buffer_t* tensors[4] = {
				mpgetbuffer(inputs[0]), // a (first factor)
				mpgetbuffer(inputs[1]), // b (second factor)
				mpgetbuffer(outputs[0]), // c (product)
				NULL,
			};
			size_t tensor_offsets[3] = {
				a->dataof,
				b->dataof,
				c->dataof
			};
			ccv_nnc_mfa_encode_cmul(context, params, command_batch, tensors, tensor_offsets);
			ccv_nnc_stream_context_finish_command_batch(stream_context, command_batch);
		} else {
			// No MPSGraph fallback is implemented for cmul; MFA is mandatory.
			assert(0);
		}
	}
	return CCV_NNC_EXEC_SUCCESS;
}

// Register the MFA-backed forward cmul kernel as the MPS (Metal) backend.
REGISTER_COMMAND_BACKEND(CCV_NNC_CMUL_FORWARD, CCV_NNC_BACKEND_MPS)(ccv_nnc_cmd_backend_registry_t* const registry)
{
	// Single algorithm, executed by _ccv_nnc_cmul_forw above.
	registry->exec = _ccv_nnc_cmul_forw;
	registry->algorithms = 1;
	// Operates on GPU-resident tensors in half or single precision.
	registry->tensor_memory = CCV_TENSOR_GPU_MEMORY;
	registry->tensor_datatypes = CCV_32F | CCV_16F;
	registry->tensor_formats = CCV_TENSOR_FORMAT_NHWC | CCV_TENSOR_FORMAT_NCHW | CCV_TENSOR_FORMAT_CHWN;
}
Loading

0 comments on commit 378ae93

Please sign in to comment.