Metal: Improved reduce and softmax #1819
base: main
Conversation
Co-authored-by: Christopher Fleetwood <[email protected]>
…c which may be suboptimal.
Force-pushed from b6d251e to 8186cee.
cc: @EricLBuehler for vis
Looks pretty good indeed, @ivarflakstad. What is required before this can be merged, and is there any reason why it's still marked as draft? (Also, you may want to merge the latest changes from main, as there are some conflicting files.)
Well, I guess it's ready now. If anyone wants to test it on some other models etc., feel free 🙇
If this is just a copy of the old reduce.metal file and is not used, let's just remove the file; we can always get it from the git history.
How related is this to the reduce/softmax changes? If it's somewhat orthogonal, maybe this should be in a separate PR?
It is only related in the sense that I wanted to see if I could further improve performance.
With a previous version of this PR it was a noticeable improvement, but I just tested, and with the current kernels it's only a ~1% difference.
In other words, it's not needed at the moment.
This looks amazing, @ivarflakstad! I tried Llama 3.2 3b on current main, but on this branch (e8499c8) I'm getting an error which seems to indicate some numerical issues.
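As a reference point for debugging numerical issues like the one above, the usual stable formulation of softmax subtracts the row maximum before exponentiating, so large logits don't overflow to inf/NaN. A minimal CPU sketch in plain Rust (not the kernel in this PR, just a reference to compare outputs against):

```rust
// Numerically stable softmax over a 1-D slice: subtracting the max before
// exponentiating keeps every exp() argument <= 0, so nothing overflows.
fn softmax_stable(xs: &[f32]) -> Vec<f32> {
    let max = xs.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = xs.iter().map(|&x| (x - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}

fn main() {
    // A naive exp-then-normalize version would overflow on logits this large.
    let out = softmax_stable(&[1000.0, 1000.0, 1000.0]);
    println!("{:?}", out);
    assert!((out.iter().sum::<f32>() - 1.0).abs() < 1e-6);
}
```

Comparing a Metal kernel's output against a reference like this on inputs with large or degenerate logits is a quick way to tell a missing max-subtraction apart from an ordinary precision difference.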
// NOTE: Should this be removed? Softmax impls live in candle-nn.
fn softmax(a: &Tensor) -> candle_core::Result<()> {
@LaurentMazare
I'll have to remove this bench since it is Metal-specific (softmax lives in candle-nn ops, so it can't be called from candle-core), but I think softmax warrants a benchmark somehow. Thoughts?
I think moving the benchmark to candle-nn would be good (do it with git mv so as to preserve history).
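For reference, `git mv` stages a rename that `git log --follow` can track across the move. A self-contained demo in a throwaway repo (file and commit names are hypothetical, not the actual paths in this PR):

```shell
set -e
# Demonstrate that a file moved with `git mv` keeps its history.
repo=$(mktemp -d)
cd "$repo"
git init -q
echo 'fn bench() {}' > softmax_bench.rs
git add softmax_bench.rs
git -c user.email=a@b -c user.name=a commit -q -m "add bench"
mkdir candle-nn
git mv softmax_bench.rs candle-nn/softmax_bench.rs
git -c user.email=a@b -c user.name=a commit -q -m "move bench"
# --follow walks through the rename, so both commits show up:
git log --follow --oneline -- candle-nn/softmax_bench.rs
```

The final command lists both the "move bench" and the original "add bench" commits; without `--follow`, history for the new path would start at the move.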
Nice, thanks!
Improvements in throughput on my machine (150 GiB/s) using f32 ops.