Skip to content

Commit

Permalink
Add kernel to dequantize 5-bit palettized weights.
Browse files Browse the repository at this point in the history
  • Loading branch information
liuliu committed Jul 22, 2024
1 parent 885188a commit 7b58fae
Showing 1 changed file with 61 additions and 10 deletions.
71 changes: 61 additions & 10 deletions lib/nnc/mfa/ccv_nnc_mfa_depalettize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,38 @@ mfa::depalettize::pipeline::pipeline(mfa::context* context, mfa::depalettize::ha
auto* pool = NS::AutoreleasePool::alloc()->init();

std::string shader;
if (hash.qbits == 6) {
if (hash.qbits == 5) {
shader = R"(
#include <metal_stdlib>
using namespace metal;
// Expands 5-bit palettized weights: every 5 packed bytes hold eight 5-bit
// palette indices, which are looked up and written out as two real4 vectors.
//
// Per block (selected by tgid.y) the source buffer holds `palette_size`
// entries of type `real`, followed by `number_in_blocks * 5` bytes of packed
// indices, so `number_in_blocks` here counts 5-byte groups per block.
//
// `real`, `real4`, `palette_size`, `number_in_blocks` and `threadgroup_size`
// are not declared in this kernel; they are expected to come from definitions
// prepended to this shader source before it is compiled.
kernel void depalettize(
// Packed input: for each block, the palette followed by the 5-bit indices.
device uchar *source [[buffer(0)]],
// Output: two real4 values (eight reals) per 5-byte group.
device real4 *destination [[buffer(1)]],
// tgid.x selects the threadgroup within a block, tgid.y selects the block.
uint3 tgid [[threadgroup_position_in_grid]],
// Index of this thread within its threadgroup.
ushort lid [[thread_index_in_threadgroup]]
) {
// Start of this block's data: skip tgid.y whole blocks (palette + packed indices).
device const uchar *ui0 = source + (sizeof(real) * palette_size + number_in_blocks * 5) * tgid.y;
// Stage the block's palette in threadgroup memory; the first palette_size
// threads each copy one entry.
threadgroup real palette[palette_size];
if (lid < palette_size) {
palette[lid] = ((device real*)ui0)[lid];
}
// All palette writes must be visible before any thread performs lookups.
threadgroup_barrier(mem_flags::mem_threadgroup);
// Index of the 5-byte group this thread decodes within the block.
// NOTE(review): there is no bounds check against number_in_blocks here, so the
// dispatch presumably covers exactly number_in_blocks threads along x - confirm
// against the host-side grid sizing.
const uint x = tgid.x * threadgroup_size + lid;
// Packed indices begin right after the palette.
device const uchar *ui1 = (device const uchar*)(ui0 + sizeof(real) * palette_size);
// Load the five bytes that carry this thread's eight indices (40 bits).
const uchar u0 = ui1[x * 5];
const uchar u1 = ui1[x * 5 + 1];
const uchar u2 = ui1[x * 5 + 2];
const uchar u3 = ui1[x * 5 + 3];
const uchar u4 = ui1[x * 5 + 4];
// Indices are packed most-significant-bit first across the byte stream:
//   i0 = u0[7:3]            i1 = u0[2:0] : u1[7:6]
//   i2 = u1[5:1]            i3 = u1[0]   : u2[7:4]
const real4 d0 = real4(palette[u0 >> 3], palette[((u0 & 7) << 2) | (u1 >> 6)], palette[(u1 >> 1) & 31], palette[((u1 & 1) << 4) | (u2 >> 4)]);
//   i4 = u2[3:0] : u3[7]    i5 = u3[6:2]
//   i6 = u3[1:0] : u4[7:5]  i7 = u4[4:0]
const real4 d1 = real4(palette[((u2 & 15) << 1) | (u3 >> 7)], palette[(u3 >> 2) & 31], palette[((u3 & 3) << 3) | (u4 >> 5)], palette[u4 & 31]);
// Each 5-byte group yields two consecutive real4 outputs; blocks are laid out
// back to back in the destination.
destination[(number_in_blocks * tgid.y + x) * 2] = d0;
destination[(number_in_blocks * tgid.y + x) * 2 + 1] = d1;
}
)";
} else if (hash.qbits == 6) {
shader = R"(
#include <metal_stdlib>
using namespace metal;
Expand Down Expand Up @@ -175,14 +206,30 @@ kernel void depalettize(
defines += "\n";
}

uint16_t threadgroup_size = 256;
defines += "constant ushort threadgroup_size = ";
defines += std::to_string(threadgroup_size) + ";";
defines += "\n";
this->group_size = MTL::Size(threadgroup_size, 1, 1);
CCV_NNC_MFA_PRECONDITION(hash.qbits == 8 || hash.qbits == 6);

if (hash.qbits == 6) {
CCV_NNC_MFA_PRECONDITION(hash.qbits == 8 || hash.qbits == 6 || hash.qbits == 5);

if (hash.qbits == 5) {
const uint16_t threadgroup_size = 128;
defines += "constant ushort threadgroup_size = ";
defines += std::to_string(threadgroup_size) + ";";
defines += "\n";
this->group_size = MTL::Size(threadgroup_size, 1, 1);
CCV_NNC_MFA_PRECONDITION((hash.length % hash.number_in_blocks) == 0);
defines += "constant ushort palette_size = 32;\n";

defines += "constant uint number_in_blocks = ";
defines += std::to_string(hash.number_in_blocks / 8) + ";";
defines += "\n";
const int num_blocks = hash.length / hash.number_in_blocks;
CCV_NNC_MFA_PRECONDITION((hash.number_in_blocks % (128 * 8)) == 0);
const int repeat_4 = hash.number_in_blocks / (128 * 8);
this->grid_size = MTL::Size(repeat_4, num_blocks, 1);
} else if (hash.qbits == 6) {
const uint16_t threadgroup_size = 256;
defines += "constant ushort threadgroup_size = ";
defines += std::to_string(threadgroup_size) + ";";
defines += "\n";
this->group_size = MTL::Size(threadgroup_size, 1, 1);
CCV_NNC_MFA_PRECONDITION((hash.length % hash.number_in_blocks) == 0);
defines += "constant ushort palette_size = 64;\n";

Expand All @@ -194,8 +241,12 @@ kernel void depalettize(
const int repeat_4 = hash.number_in_blocks / (256 * 4);
this->grid_size = MTL::Size(repeat_4, num_blocks, 1);
} else if (hash.qbits == 8) {
const uint16_t threadgroup_size = 256;
defines += "constant ushort threadgroup_size = ";
defines += std::to_string(threadgroup_size) + ";";
defines += "\n";
this->group_size = MTL::Size(threadgroup_size, 1, 1);
defines += "constant ushort palette_size = 256;\n";

defines += "constant uint number_in_blocks = ";
defines += std::to_string(hash.number_in_blocks / 4) + ";";
defines += "\n";
Expand Down

0 comments on commit 7b58fae

Please sign in to comment.