diff --git a/lib/nnc/mfa/ccv_nnc_mfa_depalettize.cpp b/lib/nnc/mfa/ccv_nnc_mfa_depalettize.cpp index f615e0777..0b0dc5626 100644 --- a/lib/nnc/mfa/ccv_nnc_mfa_depalettize.cpp +++ b/lib/nnc/mfa/ccv_nnc_mfa_depalettize.cpp @@ -84,7 +84,38 @@ mfa::depalettize::pipeline::pipeline(mfa::context* context, mfa::depalettize::ha auto* pool = NS::AutoreleasePool::alloc()->init(); std::string shader; - if (hash.qbits == 6) { + if (hash.qbits == 5) { + shader = R"( +#include +using namespace metal; + +kernel void depalettize( + device uchar *source [[buffer(0)]], + device real4 *destination [[buffer(1)]], + + uint3 tgid [[threadgroup_position_in_grid]], + ushort lid [[thread_index_in_threadgroup]] +) { + device const uchar *ui0 = source + (sizeof(real) * palette_size + number_in_blocks * 5) * tgid.y; + threadgroup real palette[palette_size]; + if (lid < palette_size) { + palette[lid] = ((device real*)ui0)[lid]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + const uint x = tgid.x * threadgroup_size + lid; + device const uchar *ui1 = (device const uchar*)(ui0 + sizeof(real) * palette_size); + const uchar u0 = ui1[x * 5]; + const uchar u1 = ui1[x * 5 + 1]; + const uchar u2 = ui1[x * 5 + 2]; + const uchar u3 = ui1[x * 5 + 3]; + const uchar u4 = ui1[x * 5 + 4]; + const real4 d0 = real4(palette[u0 >> 3], palette[((u0 & 7) << 2) | (u1 >> 6)], palette[(u1 >> 1) & 31], palette[((u1 & 1) << 4) | (u2 >> 4)]); + const real4 d1 = real4(palette[((u2 & 15) << 1) | (u3 >> 7)], palette[(u3 >> 2) & 31], palette[((u3 & 3) << 3) | (u4 >> 5)], palette[u4 & 31]); + destination[(number_in_blocks * tgid.y + x) * 2] = d0; + destination[(number_in_blocks * tgid.y + x) * 2 + 1] = d1; +} + )"; + } else if (hash.qbits == 6) { shader = R"( #include using namespace metal; @@ -175,14 +206,30 @@ kernel void depalettize( defines += "\n"; } - uint16_t threadgroup_size = 256; - defines += "constant ushort threadgroup_size = "; - defines += std::to_string(threadgroup_size) + ";"; - defines += "\n"; - this->group_size = MTL::Size(threadgroup_size, 1, 1); - CCV_NNC_MFA_PRECONDITION(hash.qbits == 8 || hash.qbits == 6); - - if (hash.qbits == 6) { + CCV_NNC_MFA_PRECONDITION(hash.qbits == 8 || hash.qbits == 6 || hash.qbits == 5); + + if (hash.qbits == 5) { + const uint16_t threadgroup_size = 128; + defines += "constant ushort threadgroup_size = "; + defines += std::to_string(threadgroup_size) + ";"; + defines += "\n"; + this->group_size = MTL::Size(threadgroup_size, 1, 1); + CCV_NNC_MFA_PRECONDITION((hash.length % hash.number_in_blocks) == 0); + defines += "constant ushort palette_size = 32;\n"; + + defines += "constant uint number_in_blocks = "; + defines += std::to_string(hash.number_in_blocks / 8) + ";"; + defines += "\n"; + const int num_blocks = hash.length / hash.number_in_blocks; + CCV_NNC_MFA_PRECONDITION((hash.number_in_blocks % (128 * 8)) == 0); + const int repeat_4 = hash.number_in_blocks / (128 * 8); + this->grid_size = MTL::Size(repeat_4, num_blocks, 1); + } else if (hash.qbits == 6) { + const uint16_t threadgroup_size = 256; + defines += "constant ushort threadgroup_size = "; + defines += std::to_string(threadgroup_size) + ";"; + defines += "\n"; + this->group_size = MTL::Size(threadgroup_size, 1, 1); CCV_NNC_MFA_PRECONDITION((hash.length % hash.number_in_blocks) == 0); defines += "constant ushort palette_size = 64;\n"; @@ -194,8 +241,12 @@ kernel void depalettize( const int repeat_4 = hash.number_in_blocks / (256 * 4); this->grid_size = MTL::Size(repeat_4, num_blocks, 1); } else if (hash.qbits == 8) { + const uint16_t threadgroup_size = 256; + defines += "constant ushort threadgroup_size = "; + defines += std::to_string(threadgroup_size) + ";"; + defines += "\n"; + this->group_size = MTL::Size(threadgroup_size, 1, 1); defines += "constant ushort palette_size = 256;\n"; - defines += "constant uint number_in_blocks = "; defines += std::to_string(hash.number_in_blocks / 4) + ";"; defines += "\n";