
Metal: Improve reduce and softmax kernels #1614

Closed
wants to merge 6 commits

Conversation

ivarflakstad (Member) commented Jan 22, 2024

Changes:
The main change is that instead of using shared-memory reduction all the way down:

// Old approach: reduce in threadgroup memory at every step, with a full
// barrier between steps (arg-max variant shown, tracking indices too).
for (uint s = block_dim / 2; s > 0; s >>= 1) {
    if (tid < s && shared_memory[tid + s] > shared_memory[tid]) {
        shared_indices[tid] = shared_indices[tid + s];
        shared_memory[tid] = shared_memory[tid + s];
    }
    // mem_threadgroup (rather than mem_none) makes the threadgroup-memory
    // writes visible to the other threads before the next iteration.
    threadgroup_barrier(mem_flags::mem_threadgroup);
}

we now do shared-memory reduction only until the remaining values can be handled by a single SIMD group, at which point we can safely switch to faster techniques (like SIMD shuffles):

#pragma clang loop unroll(full)
for (uint s = BLOCKSIZE / 2; s >= 64; s >>= 1) {
    if (tid < s) {
        shared[tid] = op(shared[tid], shared[tid + s]);
    }
    // mem_threadgroup ensures the threadgroup-memory writes are visible
    // before the next step.
    threadgroup_barrier(mem_flags::mem_threadgroup);
}

if (tid < 32) {
    // Last shared-memory reduce can be done without the tid < s check:
    // all remaining active threads belong to a single SIMD group.
    if (BLOCKSIZE >= 64) {
        shared[tid] = op(shared[tid], shared[tid + 32]);
        simdgroup_barrier(mem_flags::mem_threadgroup);
    }
    // The remaining 32 values are reduced with simdgroup_reduce.
    shared[tid] = simdgroup_reduce<ReductionOp, BLOCKSIZE>(shared[tid]);
}

Simdgroup reduce looks like this:

template<typename ReductionOp, ushort BLOCKSIZE, typename T>
METAL_FUNC T simdgroup_reduce(T value) {
    ReductionOp op;
    if (BLOCKSIZE >= 32) value = op(value, simd_shuffle_down(value, 16));
    if (BLOCKSIZE >= 16) value = op(value, simd_shuffle_down(value,  8));
    if (BLOCKSIZE >=  8) value = op(value, simd_shuffle_down(value,  4));
    if (BLOCKSIZE >=  4) value = op(value, simd_shuffle_down(value,  2));
    if (BLOCKSIZE >=  2) value = op(value, simd_shuffle_down(value,  1));
    return value;
}
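
For concreteness, a reduction operator that fits this template could look like the following. This is a minimal sketch; the operator structs the PR actually defines are not shown in this thread, so the name and shape here are assumptions:

// Hypothetical example operator (name assumed, not from the PR):
struct Max {
    template<typename T>
    T operator()(T a, T b) const {
        return a > b ? a : b;
    }
};

// Usage, e.g. in a kernel instantiated with BLOCKSIZE == 1024:
// value = simdgroup_reduce<Max, 1024>(value);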

You may also have noticed that instead of using block_dim we are now using the constant BLOCKSIZE. It's the same value, but using a switch to go from the dynamic block_dim to a compile-time constant allows the Metal compiler to create specialized (and faster) versions of the kernel. For example, if 2 <= BLOCKSIZE < 4 (i.e. BLOCKSIZE == 2, since block sizes are powers of two), the generated bytecode will only contain the very last simd_shuffle_down(value, 1) operation in simdgroup_reduce.
This increases compilation time as well, so I think it's time to pre-compile a candle.metallib.
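
A minimal sketch of what such a dynamic-to-constant switch can look like, reusing the simdgroup_reduce above. The helper name is an assumption for illustration, not the PR's actual code:

// Hypothetical dispatch helper: map the runtime block_dim onto a
// compile-time BLOCKSIZE so the compiler emits a specialized
// instantiation per case.
template<typename ReductionOp, typename T>
METAL_FUNC T dispatch_simdgroup_reduce(uint block_dim, T value) {
    switch (block_dim) {
        case 1024: return simdgroup_reduce<ReductionOp, 1024>(value);
        case  512: return simdgroup_reduce<ReductionOp,  512>(value);
        case  256: return simdgroup_reduce<ReductionOp,  256>(value);
        case  128: return simdgroup_reduce<ReductionOp,  128>(value);
        default:   return simdgroup_reduce<ReductionOp,   64>(value);
    }
}

Pre-compiling with the standard xcrun metal / xcrun metallib toolchain would move that extra compilation cost offline.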

From a software-design perspective I tried to unify the API so that there is little difference between, e.g., max and argmax.

Remaining work:

  • Metal does not support simd_shuffle_down with bfloat. To work around this restriction I simply cast bfloat -> float, do the shuffle down, and then cast float -> bfloat again. I suspected a volatile threadgroup memory reduction without barriers would be faster, but it was not: the cast + shuffle technique is faster.

  • Load more than one element per thread. Should be relatively simple; a sketch of the idea follows below.
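
A minimal sketch of what that per-thread loading could look like, assuming the kernel receives the source buffer, the block's start offset, and a per-block element count (all names here are assumptions, not the PR's code):

// Hypothetical sketch: each thread accumulates a strided chunk of the
// input in a register before the shared-memory reduction starts, so a
// BLOCKSIZE-thread block can reduce far more than BLOCKSIZE elements.
template<typename ReductionOp, ushort BLOCKSIZE, typename T>
METAL_FUNC void load_per_thread(
    const device T *src,
    threadgroup T *shared,
    uint block_start,
    uint el_per_block,
    uint tid,
    T init  // identity of the reduction op, e.g. -INFINITY for max
) {
    ReductionOp op;
    T acc = init;
    for (uint i = tid; i < el_per_block; i += BLOCKSIZE) {
        acc = op(acc, src[block_start + i]);
    }
    shared[tid] = acc;
    threadgroup_barrier(mem_flags::mem_threadgroup);
}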

Results:
Phi 1.5 has gone from 18.2 t/s to 19.3 t/s on my M1 Pro.

For f32 and f16, softmax throughput has increased by 60% and reduce by 60-68%.

For bfloat the increases are 55% and 45% respectively, because Metal does not have simd_shuffle_down support for bfloat yet.

Benchmarks:
  • Softmax
  • Reduce
  • Arg reduce

@ivarflakstad ivarflakstad self-assigned this Jan 22, 2024
@ivarflakstad ivarflakstad force-pushed the ivarflakstad/metal-reduce-2 branch from 884ef6e to 077e781 on January 22, 2024 at 20:25
Narsil (Collaborator) commented Jan 23, 2024

Hi, thanks for this work.

Could you provide a high-level overview of WHAT made it faster and why?

I couldn't run it, because this crashes on M3 (as in a real crash: everything hangs and nothing responds).

My attempt:

cargo run --example phi --release  --features metal  -- --prompt "What is Deep Learning?" --sample-len 10  --model 1.5

Could you provide more insight into your bench and what it actually reports? Ideally both the pure softmax op (I'm guessing 30-40%, but it would be nice to have the actual throughputs) AND a real model use case scenario, where it's likely to be a slimmer increase but still worthwhile.

Overall I find the code much less readable: lots of spread-out macros/functions, each doing a different part of the work but tied together subtly, which is a recipe for bugs imho.
For instance here, since it crashes the entire computer, it's kind of hard to figure out where (especially in the macros).
#pragma unroll should be available, I think, if it's only about unrolling loops.

Edit for French bias: Sorry, I'm French, I point out what's wrong by default. It's actually super nice to potentially get 30-40% throughput, as this op is a regular performance bottleneck; it could really yield something in real settings!

ivarflakstad (Member, Author) commented Jan 23, 2024

haha no worries. Norwegians can be fairly direct as well ;)
And finding out it's broken is pretty gosh darn important hehe.

I actually found the code easier to read after introducing templates (nice to get rid of all the \ line continuations), but that's probably just because I wrote it.
I already have some improvements coming which will 1. make it not broken and 2. improve readability.

Yes, #pragma unroll is available. You can see I tried a variant in get_strided_index with #pragma clang loop unroll(full).
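
For reference, a sketch of that variant. It follows the usual shape of a strided-index helper, but treat the exact signature as an assumption rather than the file's actual contents:

// Sketch (signature assumed): convert a linear element index into a
// strided buffer offset, asking the compiler to fully unroll the
// per-dimension loop.
METAL_FUNC uint get_strided_index(
    uint idx,
    constant size_t &num_dims,
    constant size_t *dims,
    constant size_t *strides
) {
    uint strided_i = 0;
    #pragma clang loop unroll(full)
    for (uint d = 0; d < num_dims; d++) {
        uint dim_idx = num_dims - 1 - d;
        strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
        idx /= dims[dim_idx];
    }
    return strided_i;
}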

benchmarks::matmul::benches,
benchmarks::random::benches,
benchmarks::where_cond::benches
//benchmarks::affine::benches,
Contributor commented on the benchmark list above:

Let's bring these back before merging

a.argmin(2).unwrap();
}

// TODO: Remove before merging. Softmax impls live in candle-nn, so this is a temporary workaround.
Contributor commented:
👀

// Metal does not have simd_shuffle_down for bfloat16
// TODO: Check if volatile threadgroup memory reduction is faster than simd_shuffle_down for bfloat
bfloat simd_shuffle_down(bfloat value, ushort delta) {
    return static_cast<bfloat>(__metal_simd_shuffle_down(static_cast<float>(value), delta));
}
Contributor commented:

I'm surprised they don't have simd_shuffle_down for bf16.

Contributor commented:

myself & @ivarflakstad dug into it - they don't!

Member (Author) commented:

And bfloat -> float simd shuffle -> bfloat is faster than the alternative method 🙃

ivarflakstad (Member, Author) commented:

Will continue work in #1819
