Skip to content

Commit

Permalink
Add mfa_cast & mfa_add impl.
Browse files Browse the repository at this point in the history
This added mfa_add, which avoids the MPS two-step approach for tensors with
offsets (add and then export).

Added an implementation of cast but turned off by default.
  • Loading branch information
liuliu committed Aug 3, 2024
1 parent 410d77e commit 1b568f4
Show file tree
Hide file tree
Showing 10 changed files with 728 additions and 105 deletions.
269 changes: 181 additions & 88 deletions lib/nnc/cmd/blas/mps/ccv_nnc_add_mps.m

Large diffs are not rendered by default.

125 changes: 110 additions & 15 deletions lib/nnc/cmd/util/mps/ccv_nnc_util_mps.m
Original file line number Diff line number Diff line change
Expand Up @@ -362,30 +362,125 @@ static int _ccv_nnc_datatype_conversion(const ccv_nnc_cmd_t cmd, const ccv_nnc_h
assert(output_size <= input_size);
int i;
@autoreleasepool {
MPSCommandBuffer* command_buffer = ccv_nnc_stream_context_start_mps_command_buffer(stream_context);
for (i = 0; i < output_size; i++)
bool use_mfa = false;
const char *fallback_reason = NULL;
ccv_nnc_mfa_context_t* context = ccv_nnc_default_mfa_context();

if (!ccv_nnc_mfa_context_supported(context) || (ccv_nnc_flags() & CCV_NNC_DISABLE_METAL_FLASH_ATTENTION)) {
use_mfa = false;
fallback_reason = "Disabled.";
}
for (i = 0; i < output_size && use_mfa; i++)
{
const ccv_nnc_tensor_view_t* a = (ccv_nnc_tensor_view_t*)inputs[i];
ccv_nnc_tensor_view_t* b = (ccv_nnc_tensor_view_t*)outputs[i];
assert(a != b); // Cannot do inplace transform.
assert(a->info.format == b->info.format);
assert(CCV_TENSOR_GET_DEVICE_ID(a->info.type) == CCV_TENSOR_GET_DEVICE_ID(b->info.type));
MPSGraphTensorData* data_a = ccv_nnc_mps_graph_tensor_data(a, a->info.dim, a->stride);
if (CCV_IS_TENSOR_VIEW(a)) // Only allocate on-demand MPSGraph if a is a tensor view.

if (use_mfa) {
if (a->info.datatype != CCV_16F && a->info.datatype != CCV_32F) {
use_mfa = false;
fallback_reason = "Unsupported data type.";
break;
}
if (b->info.datatype != CCV_16F && b->info.datatype != CCV_32F) {
use_mfa = false;
fallback_reason = "Unsupported data type.";
break;
}
}

if (use_mfa) {
if (!CCV_IS_TENSOR_CONTIGUOUS(a) || !CCV_IS_TENSOR_CONTIGUOUS(b)) {
use_mfa = false;
fallback_reason = "Strided.";
}
}
}
if (use_mfa) {
mtl_command_batch_t* command_batch = ccv_nnc_stream_context_start_command_batch(stream_context);
for (i = 0; i < output_size; i++)
{
MPSGraph *graph = [MPSGraph new];
graph.options = MPSGraphOptionsSynchronizeResults;
MPSGraphTensor* mps_input_a;
MPSGraphTensor* mps_a = ccv_nnc_mps_graph_tensor_input(graph, a, a->info.dim, a->stride, &mps_input_a);
if (mps_a != mps_input_a)
ccv_nnc_mps_graph_result(graph, command_buffer, @{mps_input_a: data_a}, mps_a, b, b->info.dim, b->stride);
else
const ccv_nnc_tensor_view_t* a = (ccv_nnc_tensor_view_t*)inputs[i];
ccv_nnc_tensor_view_t* b = (ccv_nnc_tensor_view_t*)outputs[i];
uint32_t mtl_original_data_type = UINT32_MAX;
uint32_t mtl_data_type = UINT32_MAX;
if (use_mfa) {
switch (a->info.datatype) {
case CCV_16F: {
mtl_original_data_type = 16;
break;
}
case CCV_32F: {
mtl_original_data_type = 3;
break;
}
default: {
use_mfa = false;
fallback_reason = "Unsupported data type.";
break;
}
}
switch (b->info.datatype) {
case CCV_16F: {
mtl_data_type = 16;
break;
}
case CCV_32F: {
mtl_data_type = 3;
break;
}
default: {
use_mfa = false;
fallback_reason = "Unsupported data type.";
break;
}
}
}
const size_t length = ccv_nnc_tensor_count(a->info);
ccv_nnc_mfa_cast_params_t params = {
.original_data_type = mtl_original_data_type,
.data_type = mtl_data_type,
.length = (uint32_t)length,
};
ccv_nnc_mfa_prepare_cast(context, params);

mtl_buffer_t* tensors[3] = {
mpgetbuffer(inputs[i]), // gradient
mpgetbuffer(outputs[i]), // destination
NULL
};
size_t tensor_offsets[2] = {
a->dataof,
b->dataof
};
ccv_nnc_mfa_encode_cast(context, params, command_batch, tensors, tensor_offsets);
}
ccv_nnc_stream_context_finish_command_batch(stream_context, command_batch);
} else {
MPSCommandBuffer* command_buffer = ccv_nnc_stream_context_start_mps_command_buffer(stream_context);
for (i = 0; i < output_size; i++)
{
const ccv_nnc_tensor_view_t* a = (ccv_nnc_tensor_view_t*)inputs[i];
ccv_nnc_tensor_view_t* b = (ccv_nnc_tensor_view_t*)outputs[i];
MPSGraphTensorData* data_a = ccv_nnc_mps_graph_tensor_data(a, a->info.dim, a->stride);
if (CCV_IS_TENSOR_VIEW(a)) // Only allocate on-demand MPSGraph if a is a tensor view.
{
MPSGraph *graph = [MPSGraph new];
graph.options = MPSGraphOptionsSynchronizeResults;
MPSGraphTensor* mps_input_a;
MPSGraphTensor* mps_a = ccv_nnc_mps_graph_tensor_input(graph, a, a->info.dim, a->stride, &mps_input_a);
if (mps_a != mps_input_a)
ccv_nnc_mps_graph_result(graph, command_buffer, @{mps_input_a: data_a}, mps_a, b, b->info.dim, b->stride);
else
ccv_nnc_mps_export_data(data_a, command_buffer, b, b->info.dim, b->stride);
[graph release];
} else
ccv_nnc_mps_export_data(data_a, command_buffer, b, b->info.dim, b->stride);
[graph release];
} else
ccv_nnc_mps_export_data(data_a, command_buffer, b, b->info.dim, b->stride);
}
ccv_nnc_stream_context_finish_mps_command_buffer(stream_context, command_buffer);
}
ccv_nnc_stream_context_finish_mps_command_buffer(stream_context, command_buffer);
}
return CCV_NNC_EXEC_SUCCESS;
}
Expand Down
12 changes: 12 additions & 0 deletions lib/nnc/mfa/ccv_nnc_mfa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,18 @@ void mfa::cache<mfa::gemv::hash, mfa::gemv::pipeline>::prepare(mfa::context* con
_mfa_cache_prepare(&map, context, hash);
}

// Cache specialization for the cast kernel: compiles (or fetches a cached)
// cast pipeline keyed by `hash` and stores it in this cache's map.
template <>
void mfa::cache<mfa::cast::hash, mfa::cast::pipeline>::prepare(mfa::context* context, mfa::cast::hash hash)
{
  _mfa_cache_prepare(&map, context, hash);
}

// Cache specialization for the add kernel: compiles (or fetches a cached)
// add pipeline keyed by `hash` and stores it in this cache's map.
template <>
void mfa::cache<mfa::add::hash, mfa::add::pipeline>::prepare(mfa::context* context, mfa::add::hash hash)
{
  _mfa_cache_prepare(&map, context, hash);
}

mfa::context::context(MTL::Device* device)
{
auto* pool = NS::AutoreleasePool::alloc()->init();
Expand Down
4 changes: 4 additions & 0 deletions lib/nnc/mfa/ccv_nnc_mfa.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
#include "ccv_nnc_mfa_adam.hpp"
#include "ccv_nnc_mfa_cmul.hpp"
#include "ccv_nnc_mfa_gemv.hpp"
#include "ccv_nnc_mfa_cast.hpp"
#include "ccv_nnc_mfa_add.hpp"

#ifdef __cplusplus
#include "nnc/mfa/3rdparty/metal-cpp/Dispatch.hpp"
Expand Down Expand Up @@ -52,6 +54,8 @@ class context {
cache<adam::hash, adam::pipeline> adam_cache;
cache<cmul::hash, cmul::pipeline> cmul_cache;
cache<gemv::hash, gemv::pipeline> gemv_cache;
cache<cast::hash, cast::pipeline> cast_cache;
cache<add::hash, add::pipeline> add_cache;

MTL::Buffer* request_scratch(uint64_t size);
};
Expand Down
141 changes: 141 additions & 0 deletions lib/nnc/mfa/ccv_nnc_mfa_add.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
#include "ccv_nnc_mfa.hpp"
#include "ccv_nnc_mfa_hash.hpp"
#include <simd/simd.h>
using namespace ccv::nnc;

#include <string>

// MARK: - C

// Ensure a compiled compute pipeline for these add parameters exists in the
// context's cache. Must be called before ccv_nnc_mfa_encode_add.
void ccv_nnc_mfa_prepare_add(mfa::context* context, ccv_nnc_mfa_add_params_t params)
{
  mfa::add::hash key(params);
  context->add_cache.prepare(context, key);
}

// Encode one element-wise add dispatch into the command batch.
// `tensors` is a null-terminated list of exactly three buffers
// (src0, src1, dst) with matching byte offsets in `tensor_offsets`.
// The pipeline for `params` must already be cached via
// ccv_nnc_mfa_prepare_add; a missing entry is a fatal precondition failure.
void ccv_nnc_mfa_encode_add(ccv_nnc_mfa_context_t* context, ccv_nnc_mfa_add_params_t params, mtl_command_batch_t* command_batch, mtl_buffer_t** tensors, size_t* tensor_offsets)
{
  mfa::add::hash key(params);
  auto found = context->add_cache.map.find(key);
  if (found == context->add_cache.map.end()) {
    mfa::precondition_failure("add hash not cached.", __LINE__, __FILE__, __FUNCTION__);
  }
  auto* pipeline = found->second;

  auto encoder = command_batch->startCommand();

  // Bind every buffer up to the null terminator at its argument-table slot.
  int bound = 0;
  for (; tensors[bound] != nullptr; bound++) {
    encoder->setBuffer(tensors[bound], tensor_offsets[bound], NS::UInteger(bound));
  }
  CCV_NNC_MFA_PRECONDITION(bound == 3);

  encoder->setComputePipelineState(pipeline->add_pso.get());
  encoder->useResource(tensors[0], MTL::ResourceUsageRead);
  encoder->useResource(tensors[1], MTL::ResourceUsageRead);
  encoder->useResource(tensors[2], MTL::ResourceUsageWrite);

  const auto grid = pipeline->grid_size;
  CCV_NNC_MFA_PRECONDITION(grid.depth > 0);
  encoder->dispatchThreadgroups(grid, pipeline->group_size);
  command_batch->finishCommand(encoder);
}

// MARK: - C++

// Capture the two fields that uniquely identify a compiled add pipeline.
mfa::add::hash::hash(ccv_nnc_mfa_add_params_t params)
  : data_type(params.data_type),
    length(params.length) {}

// Two hashes are equal exactly when both identifying fields match.
bool mfa::add::hash::operator==(const mfa::add::hash& rhs) const {
  if (data_type != rhs.data_type) {
    return false;
  }
  return length == rhs.length;
}

// Debug formatter for an add hash, mirroring the other mfa::*::hash printers.
std::ostream& operator<<(std::ostream& os, const mfa::add::hash& hash) {
  return os << "mfa::add::hash {"
            << " .data_type = " << hash.data_type << ','
            << " .length = " << hash.length << " "
            << "}";
}

// Fold both identifying fields into one seed with the shared mfa combiners
// (64-bit combiner for the 64-bit data_type, 32-bit for the 32-bit length).
std::size_t std::hash<mfa::add::hash>::operator()(const mfa::add::hash& hash) const noexcept {
  using namespace mfa::hash;
  std::size_t seed = 0;
  combine_64(seed, hash.data_type);
  combine_32(seed, hash.length);
  return seed;
}

// Compile the Metal compute pipeline for the element-wise add kernel
// (dst = src0 + src1). The element type (half/float) and the element count
// are baked into the generated Metal source as a typedef and a `constant`,
// so each distinct (data_type, length) pair yields its own PSO (cached by
// mfa::cache keyed on add::hash). Aborts via the MFA precondition/error
// macros on unsupported types or Metal compilation failure.
mfa::add::pipeline::pipeline(mfa::context* context, mfa::add::hash hash) {
  // Only FP32 and FP16 are supported by this kernel.
  CCV_NNC_MFA_PRECONDITION((hash.data_type == MTL::DataTypeFloat) || (hash.data_type == MTL::DataTypeHalf))

  auto* pool = NS::AutoreleasePool::alloc()->init();

  // Kernel body: each thread adds one real4 vector (4 elements). `real4`
  // and `count` are supplied by the `defines` preamble built below.
  std::string shader = R"(
#include <metal_stdlib>
using namespace metal;
kernel void add(
device const real4 *src0 [[buffer(0)]],
device const real4 *src1 [[buffer(1)]],
device real4 *dst [[buffer(2)]],
uint3 tpig [[thread_position_in_grid]]
) {
const uint idx = tpig.x;
if (idx >= count)
return;
dst[idx] = src0[idx] + src1[idx];
}
)";

  // Map the requested element type onto the `real4` vector typedef.
  std::string defines = "";
  if (hash.data_type == MTL::DataTypeFloat) {
    defines += std::string("typedef float4 real4;");
    defines += "\n";
  } else {
    defines += std::string("typedef half4 real4;");
    defines += "\n";
  }

  // Bake the vectorized element count in as a compile-time constant; the
  // length must be divisible by 4 since the kernel reads real4 at a time.
  defines += "constant uint count = ";
  CCV_NNC_MFA_PRECONDITION(hash.length % 4 == 0)
  const unsigned int count = hash.length / 4;
  defines += std::to_string(count) + ";";
  defines += "\n";
  // 256 threads per threadgroup; enough groups to cover every real4.
  this->group_size = MTL::Size(256, 1, 1);
  const int num_blocks = (count + 255) / 256;
  this->grid_size = MTL::Size(num_blocks, 1, 1);

  // No function constants are actually set; an empty table is passed below.
  auto constants = NS::TransferPtr(MTL::FunctionConstantValues::alloc()->init());
  NS::SharedPtr<MTL::ComputePipelineState>* pso = &add_pso;

  std::string source = defines;
  // NOTE(review): at log level >= 4 only the `defines` preamble is printed
  // (the shader body is appended afterwards) — confirm this is intentional.
  if (METAL_LOG_LEVEL(context) >= 4) {
    std::cerr << source << std::endl;
  }
  source += shader;

  NS::Error *error = nullptr;
  auto swift_source = NS::String::string(source.c_str(),
NS::UTF8StringEncoding);
  auto library = NS::TransferPtr(context->device->newLibrary(swift_source, nullptr, &error));
  if (!library) {
    CCV_NNC_MFA_CHECK_ERROR(error)
  }

  auto swift_name = NS::String::string("add", NS::UTF8StringEncoding);
  auto function = NS::TransferPtr(library->newFunction(swift_name, constants.get(), &error));
  if (!function) {
    CCV_NNC_MFA_CHECK_ERROR(error)
  }

  *pso = NS::TransferPtr(context->device->newComputePipelineState(function.get(), &error));
  if (!*pso) {
    CCV_NNC_MFA_CHECK_ERROR(error)
  }

  pool->drain();
}
62 changes: 62 additions & 0 deletions lib/nnc/mfa/ccv_nnc_mfa_add.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
#ifndef GUARD_ccv_nnc_mfa_add_hpp
#define GUARD_ccv_nnc_mfa_add_hpp

// Parameters identifying one element-wise add dispatch (dst = src0 + src1).
typedef struct {
  uint64_t data_type; // MTL::DataType raw value; only float or half is supported.
  uint32_t length;    // Total element count; must be divisible by 4 (real4 kernel).
} ccv_nnc_mfa_add_params_t;

#ifdef __cplusplus
#include "nnc/mfa/3rdparty/metal-cpp/Dispatch.hpp"
#include "nnc/mfa/3rdparty/metal-cpp/Metal.hpp"
#include <simd/simd.h>

namespace ccv {
namespace nnc {
namespace mfa {
namespace add {

// Cache key for a compiled add pipeline: one PSO per (data_type, length).
class hash {
public:
  uint64_t data_type;
  uint32_t length;

  hash(ccv_nnc_mfa_add_params_t);

  bool operator==(const hash& rhs) const;
};

// A compiled add compute pipeline plus its precomputed dispatch geometry.
class pipeline {
public:
  NS::SharedPtr<MTL::ComputePipelineState> add_pso;

  MTL::Size grid_size;  // Threadgroup count covering length/4 vectors.
  MTL::Size group_size; // Threads per threadgroup.

  pipeline(context* context, hash hash);
};

} // namespace add
} // namespace mfa
} // namespace nnc
} // namespace ccv

std::ostream& operator<<(std::ostream& os, const ccv::nnc::mfa::add::hash& hash);

// std::hash specialization so add::hash can key the pipeline cache map.
template<>
struct std::hash<ccv::nnc::mfa::add::hash>
{
  std::size_t operator()(const ccv::nnc::mfa::add::hash& hash) const noexcept;
};

extern "C" {
#endif // __cplusplus

// Compile (or fetch from cache) the pipeline for `params`.
void ccv_nnc_mfa_prepare_add(ccv_nnc_mfa_context_t* context, ccv_nnc_mfa_add_params_t params);
// Encode one add dispatch; expects three buffers (src0, src1, dst) in
// `tensors` (null-terminated) with byte offsets in `tensor_offsets`.
void ccv_nnc_mfa_encode_add(ccv_nnc_mfa_context_t* context, ccv_nnc_mfa_add_params_t params, mtl_command_batch_t* command_batch, mtl_buffer_t** tensors, size_t* tensor_offsets);

#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus

#endif
Loading

0 comments on commit 1b568f4

Please sign in to comment.