Metal: Improve reduce and softmax kernels #1614
Conversation
Hi, thanks for this work. Could you provide a high-level overview of WHAT made it faster and why? I couldn't run it, because this crashes on M3 (as in a real crash: everything hangs and nothing responds). My attempt:
Could you provide more insights into your bench, and what it actually gives out? Ideally both the pure softmax op (I'm guessing 30/40%, but it'd be nice to have an idea of the throughputs) AND a real model use-case scenario, where it's likely to be a slimmer increase, but still. Overall I find the code much less readable, with lots of spread-out macros/functions each doing a different part of the work but being tied together subtly, which is a recipe for bugs imho. Edit for French bias: sorry, I'm French, I point out what's wrong by default. It's actually super nice to potentially get 30/40% throughput, as this op is a regular bottleneck for performance; it could really yield something in real settings!
haha no worries. Norwegians can be fairly direct as well ;) I actually found the code easier to read after introducing templates (nice to get rid of all the duplication).
benchmarks::matmul::benches,
benchmarks::random::benches,
benchmarks::where_cond::benches
//benchmarks::affine::benches,
Let's bring these back before merging
a.argmin(2).unwrap();
}
// TODO: Remove before merging. Softmax impls live in candle-nn, so this is a temporary workaround. |
👀
// Metal does not have simd_shuffle_down for bfloat16
// TODO: Check if volatile threadgroup memory reduction is faster than simd_shuffle_down for bfloat
bfloat simd_shuffle_down(bfloat value, ushort delta) {
    return static_cast<bfloat>(__metal_simd_shuffle_down(static_cast<float>(value), delta));
}
I'm surprised they don't have simd_shuffle_down for bf16.
@ivarflakstad and I dug into it - they don't!
And bfloat -> float simd shuffle -> bfloat is faster than the alternative method 🙃
Will continue work in #1819
Changes:
The main change is that instead of using shared memory reduction all the way down, we now do shared memory reduction only until we can safely use faster techniques (like simd):
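Roughly, the shape is something like this. This is a hedged sketch, not the PR's exact code: the helper names (`block_reduce_sum`, `simdgroup_reduce_sum`) and the use of sum as the reduction op are illustrative assumptions.

```metal
#include <metal_stdlib>
using namespace metal;

// Defined in the next snippet.
template<typename T, ushort BLOCKSIZE>
METAL_FUNC T simdgroup_reduce_sum(T value);

// Illustrative sketch (assumed names, sum reduction): reduce in
// threadgroup memory only while more than one simdgroup still holds
// live partial results, then finish inside a single simdgroup.
template<typename T, ushort BLOCKSIZE>
METAL_FUNC void block_reduce_sum(threadgroup T *shared, ushort tid) {
    // Shared-memory tree reduction, stopping once the stride drops
    // below the simdgroup width (32).
    for (ushort s = BLOCKSIZE / 2; s >= 32; s >>= 1) {
        if (tid < s) {
            shared[tid] += shared[tid + s];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }
    // At most 32 values remain, i.e. one simdgroup: finish in registers,
    // with no threadgroup memory traffic and no barriers.
    if (tid < 32) {
        T value = simdgroup_reduce_sum<T, BLOCKSIZE>(shared[tid]);
        if (tid == 0) {
            shared[0] = value;
        }
    }
}
```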
Simdgroup reduce looks like this:
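Something along these lines (a reconstruction of the general pattern, not the exact kernel; the compile-time `BLOCKSIZE` guards tie into the specialization note below):

```metal
// Reconstructed sketch (assumed, not copied from the PR). Every guard is
// a compile-time constant, so for a given BLOCKSIZE the compiler keeps
// only the shuffles that can actually combine live lanes.
template<typename T, ushort BLOCKSIZE>
METAL_FUNC T simdgroup_reduce_sum(T value) {
    if (BLOCKSIZE > 16) value += simd_shuffle_down(value, 16);
    if (BLOCKSIZE >  8) value += simd_shuffle_down(value,  8);
    if (BLOCKSIZE >  4) value += simd_shuffle_down(value,  4);
    if (BLOCKSIZE >  2) value += simd_shuffle_down(value,  2);
    if (BLOCKSIZE >  1) value += simd_shuffle_down(value,  1);
    return value; // lane 0 now holds the reduced value
}
```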
You may also have noticed that instead of using `block_dim` we are now using the constant `BLOCKSIZE`. It's the same thing, but using a switch to go from dynamic to constant allows the Metal compiler to create specialized (and faster) versions of the kernel; a sketch of that switch follows below. For example, if `2 < BLOCKSIZE < 4` then the generated bytecode will only contain the very last operation in `simdgroup_reduce`. This increases compilation time as well, so I think it's time to pre-compile a `candle.metallib`. From a software design perspective, I tried to unify the API so there is little difference between max, arg max, etc.
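To make the dynamic-to-constant switch concrete, here is a sketch of the dispatch mentioned above (the helper name `simd_reduce_dispatch` is made up for illustration, and it builds on the snippets above):

```metal
// Assumed shape of the dynamic -> constant dispatch. block_dim is only
// known at runtime, but inside each case BLOCKSIZE is a template
// constant, so the compiler emits a specialized reduction per
// threadgroup size. Sizes above 32 would first go through the
// shared-memory tree reduction sketched earlier.
template<typename T>
METAL_FUNC T simd_reduce_dispatch(T value, ushort block_dim) {
    switch (block_dim) {
        case 32: return simdgroup_reduce_sum<T, 32>(value);
        case 16: return simdgroup_reduce_sum<T, 16>(value);
        case  8: return simdgroup_reduce_sum<T,  8>(value);
        case  4: return simdgroup_reduce_sum<T,  4>(value);
        case  2: return simdgroup_reduce_sum<T,  2>(value);
        default: return value; // block_dim == 1: nothing to combine
    }
}
```

With `BLOCKSIZE = 2`, for instance, every guard except the last is false at compile time, so only the final `simd_shuffle_down(value, 1)` survives in the generated bytecode.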
Remaining work:
- Metal does not support `simd_shuffle_down` with bfloat. To solve this restriction I simply cast bfloat -> float, do the shuffle down, and then cast float -> bfloat again. I suspected it would be faster to use a volatile threadgroup memory reduction without barriers; it was not. The cast + shuffle technique is faster.
- Load more than one element per thread. Should be relatively simple (a possible shape is sketched below).
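For that second item, one possible shape could be the following. This is purely illustrative, since that part isn't implemented yet; `thread_partial_sum` is a made-up name, and sum again stands in for the actual reduction op.

```metal
// Assumed sketch for "load more than one element per thread": each thread
// first accumulates a strided slice of the input in registers, so the
// shared-memory / simd reduction only runs once over BLOCKSIZE partials.
template<typename T, ushort BLOCKSIZE>
METAL_FUNC T thread_partial_sum(
    const device T *src, // one row of the input
    uint el_count,       // number of elements in the row
    ushort tid           // thread index within the threadgroup
) {
    T sum = T(0);
    // Thread `tid` handles elements tid, tid + BLOCKSIZE, tid + 2*BLOCKSIZE, ...
    for (uint i = tid; i < el_count; i += BLOCKSIZE) {
        sum += src[i];
    }
    return sum; // becomes this thread's entry in threadgroup memory
}
```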
Results:
Phi 1.5 has gone from 18.2 t/s to 19.3 t/s on my M1 Pro.
For f32 and f16, softmax throughput has increased by 60%, and reduce by 60-68%.
For bfloat the increases are 55% and 45% respectively, because Metal does not have `simd_shuffle_down` support for bfloat yet.
Benchmarks: