Metal: Improved reduce and softmax #1819
Open
ivarflakstad wants to merge 38 commits into main from ivarflakstad/metal-reduce-3
+1,502 −334
Showing changes from 30 of the 38 commits.
Commits:
d590284 Improve reduce perf and add contiguous impl (ivarflakstad)
1f4c544 Improve arg reduce and add contiguous impl (ivarflakstad)
2056866 Improve softmax kernel. 33%-39% higher thrpt (ivarflakstad)
086b6ef Merge branch 'main' into ivarflakstad/metal-reduce-2 (ivarflakstad)
077e781 fmt (ivarflakstad)
8babfe0 Fixed all bugs. Improved code quality. Added tests. (ivarflakstad)
cf92c96 Stash for debugging (ivarflakstad)
3bc4fcb Stash for debugging 2 (ivarflakstad)
2ee1a0c Fixing argmax bug and improve performance (ivarflakstad)
a3fc6c4 Fix test and add is_valid_simgroup_reduce_type trait (ivarflakstad)
db08d66 Online softmax. Improved threadgroup reduce. Tidying up a bit. (ivarflakstad)
1f63401 Remove redundant threadgroup_barrier from arg reduce (ivarflakstad)
dec2db6 Mostly tidying up. Some improvements (ivarflakstad)
237369c Simplify indexed struct (ivarflakstad)
bcdbcd1 tidying (ivarflakstad)
286598a Reuse operation operator instead of passing it in as a parameter (ivarflakstad)
14652a6 Fix how operators are applied to indexed<vec<T,N>> (ivarflakstad)
c0d80b4 Vectorized load. Scalar block reduce. Hitting max throughput for f32 … (ivarflakstad)
2c78cab Vectorized load for online softmax. Involves a reinterpret_cast of sr… (ivarflakstad)
847b34f Metal as_type casting vec<bfloat, N> -> vec<float, N/2> for simd and … (ivarflakstad)
70dd35f Use constant for input instead of const device. Fix strided reduce. (ivarflakstad)
8ca2c03 Use contiguous reduce in tests (ivarflakstad)
a3a5994 Rename finalize -> to_scalar (ivarflakstad)
9b54e41 Support integer types max/min (switch with trait-inferred impl later) (ivarflakstad)
3633135 Was worried I was skipping work -> shuffling the 1D test cases (ivarflakstad)
a9e26b5 Add build.rs to avoid metal kernel jit compile overhead (ivarflakstad)
8186cee Improve build. Extract utils (ivarflakstad)
568b543 Compile metal kernels for both macos and ios (ivarflakstad)
afa1ea1 Merge branch 'main' into ivarflakstad/metal-reduce-3 (ivarflakstad)
c3e4da7 Fixed over xmas and then forgot about it (ivarflakstad)
5c59669 Merge branch 'main' into ivarflakstad/metal-reduce-3 (ivarflakstad)
bcbbad6 Add calculate_reduce_threads util (ivarflakstad)
e8499c8 Remove old reduce.metal (ivarflakstad)
4c94925 Improve f16/bf16 softmax precision by accumulating in f32 (ivarflakstad)
eb6985e Remove build.rs (for now) (ivarflakstad)
b094d09 Move softmax bench to candle-nn (ivarflakstad)
14cbd5e Remove redundant thread calc util fn (ivarflakstad)
c174d5b Merge remote-tracking branch 'origin/main' into ivarflakstad/metal-re… (LaurentMazare)
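The commit messages above name the main kernel-side changes. As a point of reference, here is a minimal CPU-side Rust sketch of the online-softmax recurrence mentioned in commits db08d66 and 4c94925: a single pass that keeps a running maximum and a rescaled running sum, accumulating in f32 even for f16 input. It is only an illustration of the idea, not the PR's Metal kernel, and the function name is made up for the example.

```rust
// Illustrative only: scalar outline of online softmax with f32 accumulation.
// The Metal kernels in this PR parallelize the same recurrence across a
// threadgroup; this sketch just shows what is being accumulated.
use half::f16;

fn online_softmax_f16(input: &[f16], output: &mut [f16]) {
    let mut max = f32::NEG_INFINITY; // running maximum
    let mut denom = 0.0f32; // running sum of exp(x - max), rescaled on max updates
    for &x in input {
        let x = x.to_f32();
        let new_max = max.max(x);
        // Rescale the partial sum whenever the running max increases.
        denom = denom * (max - new_max).exp() + (x - new_max).exp();
        max = new_max;
    }
    for (o, &x) in output.iter_mut().zip(input.iter()) {
        *o = f16::from_f32((x.to_f32() - max).exp() / denom);
    }
}
```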
@@ -0,0 +1,239 @@
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
use candle_core::{DType, Device, Storage, Tensor};
use criterion::{black_box, criterion_group, Criterion, Throughput};
use half::{bf16, f16};
use std::ops::Deref;
use std::time::Instant;

fn run_sum(a: &Tensor) {
    a.sum_keepdim(2).unwrap();
}
fn run_arg_min(a: &Tensor) {
    a.argmin_keepdim(2).unwrap();
}

// NOTE: Should this be removed? Softmax impls live in candle-nn.
fn softmax(a: &Tensor) -> candle_core::Result<()> {
    use candle_core::{backend::BackendStorage, DType};
    let (storage, layout) = a.storage_and_layout();

    let device = a.device();

    if let (Device::Metal(device), Storage::Metal(storage)) = (device, storage.deref()) {
        let command_buffer = device.command_buffer()?;
        let kernels = device.kernels();
        let name = match a.dtype() {
            DType::F32 => "softmax_f32",
            DType::F16 => "softmax_f16",
            DType::BF16 => "softmax_bf16",
            dtype => candle_core::bail!("softmax-last-dim is not implemented for {dtype:?}"),
        };

        let n = layout.stride().len();
        if !(layout.is_contiguous() && layout.stride()[n - 1] == 1) {
            candle_core::bail!("Non contiguous softmax-last-dim is not implemented");
        }

        let last_dim = layout.dims()[layout.shape().rank() - 1];
        let elem_count = layout.shape().elem_count();
        let output = device.new_buffer(elem_count, storage.dtype(), "softmax")?;
        candle_metal_kernels::call_last_softmax(
            device.metal_device(),
            &command_buffer,
            kernels,
            name,
            elem_count,
            last_dim,
            storage.buffer(),
            layout.start_offset() * storage.dtype().size_in_bytes(),
            &output,
        )
        .unwrap();
    }
    Ok(())
}

fn criterion_benchmark(c: &mut Criterion) {
    let handler = BenchDeviceHandler::new().unwrap();
    let (lo, up) = (-1000.0f32, 1000.0f32);
    for device in handler.devices {
        run_softmax(c, &device, (lo, up));
        run_softmax(c, &device, (f16::from_f32(lo), f16::from_f32(up)));
        run_softmax(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)));

        run_reduce(c, &device, (lo, up), false);
        run_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), false);
        run_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), false);

        run_arg_reduce(c, &device, (lo, up), false);
        run_arg_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), false);
        run_arg_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), false);

        run_reduce(c, &device, (lo, up), true);
        run_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), true);
        run_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), true);

        run_arg_reduce(c, &device, (lo, up), true);
        run_arg_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), true);
        run_arg_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), true);
    }
}

fn run_softmax<T: candle_core::FloatDType>(c: &mut Criterion, device: &Device, (lo, up): (T, T)) {
    if !device.is_metal() {
        return;
    }

    let b = 1;
    let m = 1024;
    let k = 1024;
    let a = Tensor::rand(lo, up, (b, m, k), &device).unwrap();

    let flops = b * m * k * T::DTYPE.size_in_bytes();

    let name = match T::DTYPE {
        DType::F32 => "softmax_f32",
        DType::F16 => "softmax_f16",
        DType::BF16 => "softmax_bf16",
        _ => "softmax",
    };
    softmax(&a).unwrap();

    let mut group = c.benchmark_group(device.bench_name(name));
    group.throughput(Throughput::Bytes(flops as u64));
    group.bench_function("iter", move |b| {
        b.iter_custom(|iters| {
            let start = Instant::now();
            for _i in 0..iters {
                softmax(black_box(&a)).unwrap();
            }
            device.sync().unwrap();
            start.elapsed()
        })
    });
    group.finish();
}

fn run_reduce<T: candle_core::FloatDType>(
    c: &mut Criterion,
    device: &Device,
    (lo, up): (T, T),
    strided: bool,
) {
    let b = 1;
    let m = 1024;
    let k = 1024;

    let a = if strided {
        Tensor::rand(lo, up, (b, m, k), &device)
            .unwrap()
            .transpose(0, 2)
            .unwrap()
    } else {
        Tensor::rand(lo, up, (b, m, k), &device).unwrap()
    };

    let flops = b * m * k * T::DTYPE.size_in_bytes();

    let name = match T::DTYPE {
        DType::F32 => {
            if strided {
                "reduce_f32_strided"
            } else {
                "reduce_f32"
            }
        }
        DType::F16 => {
            if strided {
                "reduce_f16_strided"
            } else {
                "reduce_f16"
            }
        }
        DType::BF16 => {
            if strided {
                "reduce_bf16_strided"
            } else {
                "reduce_bf16"
            }
        }
        _ => "unknown",
    };

    let mut group = c.benchmark_group(device.bench_name(name));
    group.throughput(Throughput::Bytes(flops as u64));
    group.bench_function("iter", move |b| {
        b.iter_custom(|iters| {
            let start = Instant::now();
            for _i in 0..iters {
                run_sum(black_box(&a));
            }
            device.sync().unwrap();
            start.elapsed()
        })
    });
    group.finish();
}

fn run_arg_reduce<T: candle_core::FloatDType>(
    c: &mut Criterion,
    device: &Device,
    (lo, up): (T, T),
    strided: bool,
) {
    let b = 1;
    let m = 1024;
    let k = 1024;

    let a = if strided {
        Tensor::rand(lo, up, (b, m, k), &device)
            .unwrap()
            .transpose(0, 2)
            .unwrap()
    } else {
        Tensor::rand(lo, up, (b, m, k), &device).unwrap()
    };

    let flops = b * m * k * T::DTYPE.size_in_bytes();

    let name = match T::DTYPE {
        DType::F32 => {
            if strided {
                "arg_reduce_f32_strided"
            } else {
                "arg_reduce_f32"
            }
        }
        DType::F16 => {
            if strided {
                "arg_reduce_f16_strided"
            } else {
                "arg_reduce_f16"
            }
        }
        DType::BF16 => {
            if strided {
                "arg_reduce_bf16_strided"
            } else {
                "arg_reduce_bf16"
            }
        }
        _ => "unknown",
    };

    let mut group = c.benchmark_group(device.bench_name(name));
    group.throughput(Throughput::Bytes(flops as u64));
    group.bench_function("iter", move |b| {
        b.iter_custom(|iters| {
            let start = Instant::now();
            for _i in 0..iters {
                run_arg_min(black_box(&a));
            }
            device.sync().unwrap();
            start.elapsed()
        })
    });
    group.finish();
}

criterion_group!(benches, criterion_benchmark);
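The listing registers the benchmark with criterion_group! but defines no main entry point; in this bench layout the groups are typically collected by a separate bench entry file. A rough sketch of that wiring, where the benchmarks::reduce module path is an assumption rather than something shown in this diff:

```rust
// bench_main.rs (sketch): wires the benchmark groups into Criterion's runner.
// The `benchmarks::reduce` path is hypothetical; the actual module name for
// this file is not visible in the diff above.
mod benchmarks;

use criterion::criterion_main;

criterion_main!(benchmarks::reduce::benches);
```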
@LaurentMazare I'll have to remove this bench since this is metal specific (softmax lives in candle-nn ops so can't call it from candle-core), but I think softmax warrants a benchmark somehow. Thoughts?
I think moving the benchmark to candle-nn would be good (do it with git mv so as to preserve history).
Moved. Moving only part of a file with git alone seemed more complex than it was valuable.
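After the move, the benchmark no longer needs the Metal-specific helper above: from candle-nn it can call the public softmax op directly, which on Metal devices goes through the custom kernel path. A hedged sketch of what the bench body might call instead, assuming candle_nn::ops::softmax_last_dim is the op in question and the Criterion scaffolding stays as in the listing:

```rust
// Sketch: device-agnostic benchmark body once the file lives in candle-nn.
// softmax_last_dim returns a new Tensor; the result is discarded, as in the
// original run_sum/run_arg_min helpers.
fn run_softmax_op(a: &candle_core::Tensor) {
    candle_nn::ops::softmax_last_dim(a).unwrap();
}
```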