Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Metal: Improved reduce and softmax #1819

Open
wants to merge 38 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
d590284
Improve reduce perf and add contiguous impl
ivarflakstad Jan 21, 2024
1f4c544
Improve arg reduce and add contiguous impl
ivarflakstad Jan 21, 2024
2056866
Improve softmax kernel. 33%-39% higher thrpt
ivarflakstad Jan 22, 2024
086b6ef
Merge branch 'main' into ivarflakstad/metal-reduce-2
ivarflakstad Jan 22, 2024
077e781
fmt
ivarflakstad Jan 22, 2024
8babfe0
Fixed all bugs. Improved code quality. Added tests.
ivarflakstad Jan 30, 2024
cf92c96
Stash for debugging
ivarflakstad Feb 7, 2024
3bc4fcb
Stash for debugging 2
ivarflakstad Feb 7, 2024
2ee1a0c
Fixing argmax bug and improve performance
ivarflakstad Feb 7, 2024
a3fc6c4
Fix test and add is_valid_simgroup_reduce_type trait
ivarflakstad Feb 9, 2024
db08d66
Online softmax. Improved threadgroup reduce. Tidying up a bit.
ivarflakstad Feb 11, 2024
1f63401
Remove redundant threadgroup_barrier from arg reduce
ivarflakstad Feb 11, 2024
dec2db6
Mostly tidying up. Some improvements
ivarflakstad Feb 13, 2024
237369c
Simplify indexed struct
ivarflakstad Feb 13, 2024
bcdbcd1
tidying
ivarflakstad Feb 13, 2024
286598a
Reuse operation operator instead of passing it in as a parameter
ivarflakstad Feb 13, 2024
14652a6
Fix how operators are applied to indexed<vec<T,N>>
ivarflakstad Feb 13, 2024
c0d80b4
Vectorized load. Scalar block reduce. Hitting max throughput for f32 …
ivarflakstad Feb 18, 2024
2c78cab
Vectorized load for online softmax. Involves a reinterpret_cast of sr…
ivarflakstad Feb 20, 2024
847b34f
Metal as_type casting vec<bfloat, N> -> vec<float, N/2> for simd and …
ivarflakstad Feb 21, 2024
70dd35f
Use constant for input instead of const device. Fix strided reduce.
ivarflakstad Feb 25, 2024
8ca2c03
Use contiguous reduce in tests
ivarflakstad Feb 25, 2024
a3a5994
Rename finalize -> to_scalar
ivarflakstad Feb 25, 2024
9b54e41
Support integer types max/min (switch with trait-inferred impl later)
ivarflakstad Feb 28, 2024
3633135
Was worried I was skipping work -> shuffling the 1D test cases
ivarflakstad Mar 7, 2024
a9e26b5
Add build.rs to avoid metal kernel jit compile overhead
ivarflakstad Mar 7, 2024
8186cee
Improve build. Extract utils
ivarflakstad Mar 8, 2024
568b543
Compile metal kernels for both macos and ios
ivarflakstad Sep 2, 2024
afa1ea1
Merge branch 'main' into ivarflakstad/metal-reduce-3
ivarflakstad Sep 2, 2024
c3e4da7
Fixed over xmas and then forgot about it
ivarflakstad Jan 13, 2025
5c59669
Merge branch 'main' into ivarflakstad/metal-reduce-3
ivarflakstad Jan 13, 2025
bcbbad6
Add calculate_reduce_threads util
ivarflakstad Jan 13, 2025
e8499c8
Remove old reduce.metal
ivarflakstad Jan 13, 2025
4c94925
Improve f16/bf16 softmax precision by accumulating in f32
ivarflakstad Jan 20, 2025
eb6985e
Remove build.rs (for now)
ivarflakstad Jan 20, 2025
b094d09
Move softmax bench to candle-nn
ivarflakstad Jan 20, 2025
14cbd5e
Remove redundant thread calc util fn
ivarflakstad Jan 20, 2025
c174d5b
Merge remote-tracking branch 'origin/main' into ivarflakstad/metal-re…
LaurentMazare Jan 21, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions candle-core/benches/bench_main.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
mod benchmarks;

use criterion::criterion_main;

criterion_main!(
benchmarks::affine::benches,
benchmarks::matmul::benches,
benchmarks::random::benches,
benchmarks::reduce::benches,
benchmarks::where_cond::benches,
benchmarks::conv_transpose2d::benches,
benchmarks::qmatmul::benches,
Expand Down
1 change: 1 addition & 0 deletions candle-core/benches/benchmarks/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ pub(crate) mod conv_transpose2d;
pub(crate) mod matmul;
pub(crate) mod qmatmul;
pub(crate) mod random;
pub(crate) mod reduce;
pub(crate) mod unary;
pub(crate) mod where_cond;

Expand Down
158 changes: 158 additions & 0 deletions candle-core/benches/benchmarks/reduce.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
use candle_core::{DType, Device, Tensor};
use criterion::{black_box, criterion_group, Criterion, Throughput};
use half::{bf16, f16};
use std::time::Instant;

fn run_sum(a: &Tensor) {
a.sum_keepdim(2).unwrap();
}
fn run_arg_min(a: &Tensor) {
a.argmin_keepdim(2).unwrap();
}

fn criterion_benchmark(c: &mut Criterion) {
let handler = BenchDeviceHandler::new().unwrap();
let (lo, up) = (-1000.0f32, 1000.0f32);
for device in handler.devices {
run_reduce(c, &device, (lo, up), false);
run_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), false);
run_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), false);

run_arg_reduce(c, &device, (lo, up), false);
run_arg_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), false);
run_arg_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), false);

run_reduce(c, &device, (lo, up), true);
run_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), true);
run_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), true);

run_arg_reduce(c, &device, (lo, up), true);
run_arg_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), true);
run_arg_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), true);
}
}

fn run_reduce<T: candle_core::FloatDType>(
c: &mut Criterion,
device: &Device,
(lo, up): (T, T),
strided: bool,
) {
let b = 1;
let m = 1024;
let k = 1024;

let a = if strided {
Tensor::rand(lo, up, (b, m, k), &device)
.unwrap()
.transpose(0, 2)
.unwrap()
} else {
Tensor::rand(lo, up, (b, m, k), &device).unwrap()
};

let flops = b * m * k * T::DTYPE.size_in_bytes();

let name = match T::DTYPE {
DType::F32 => {
if strided {
"reduce_f32_strided"
} else {
"reduce_f32"
}
}
DType::F16 => {
if strided {
"reduce_f16_strided"
} else {
"reduce_f16"
}
}
DType::BF16 => {
if strided {
"reduce_bf16_strided"
} else {
"reduce_bf16"
}
}
_ => "unknown",
};

let mut group = c.benchmark_group(device.bench_name(name));
group.throughput(Throughput::Bytes(flops as u64));
group.bench_function("iter", move |b| {
b.iter_custom(|iters| {
let start = Instant::now();
for _i in 0..iters {
run_sum(black_box(&a));
}
device.sync().unwrap();
start.elapsed()
})
});
group.finish();
}

fn run_arg_reduce<T: candle_core::FloatDType>(
c: &mut Criterion,
device: &Device,
(lo, up): (T, T),
strided: bool,
) {
let b = 1;
let m = 1024;
let k = 1024;

let a = if strided {
Tensor::rand(lo, up, (b, m, k), &device)
.unwrap()
.transpose(0, 2)
.unwrap()
} else {
Tensor::rand(lo, up, (b, m, k), &device).unwrap()
};

let flops = b * m * k * T::DTYPE.size_in_bytes();

let name = match T::DTYPE {
DType::F32 => {
if strided {
"arg_reduce_f32_strided"
} else {
"arg_reduce_f32"
}
}
DType::F16 => {
if strided {
"arg_reduce_f16_strided"
} else {
"arg_reduce_f16"
}
}
DType::BF16 => {
if strided {
"arg_reduce_bf16_strided"
} else {
"arg_reduce_bf16"
}
}
_ => "unknown",
};

let mut group = c.benchmark_group(device.bench_name(name));
group.throughput(Throughput::Bytes(flops as u64));
group.bench_function("iter", move |b| {
b.iter_custom(|iters| {
let start = Instant::now();
for _i in 0..iters {
run_arg_min(black_box(&a));
}
device.sync().unwrap();
start.elapsed()
})
});
group.finish();
}

criterion_group!(benches, criterion_benchmark);
3 changes: 1 addition & 2 deletions candle-core/src/metal_backend/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ use crate::{DType, Result};
use candle_metal_kernels::Kernels;
use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
use std::collections::HashMap;
use std::ffi::c_void;
use std::path::Path;
use std::sync::{Arc, Mutex, RwLock};

Expand Down Expand Up @@ -241,7 +240,7 @@ impl MetalDevice {
pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Result<Arc<Buffer>> {
let size = core::mem::size_of_val(data) as NSUInteger;
let new_buffer = self.device.new_buffer_with_data(
data.as_ptr() as *const c_void,
data.as_ptr().cast(),
size,
MTLResourceOptions::StorageModeManaged,
);
Expand Down
66 changes: 63 additions & 3 deletions candle-core/src/metal_backend/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,7 @@ impl BackendStorage for MetalStorage {

fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
let device = self.device.clone();

let src_stride = layout.stride();
let src_dims = layout.shape().dims();
// Source dims and strides with the sum dims at the end.
Expand All @@ -278,13 +279,72 @@ impl BackendStorage for MetalStorage {
stride.push(src_stride[dim_idx]);
}
}

for &dim_idx in sum_dims.iter() {
dims.push(src_dims[dim_idx]);
stride.push(src_stride[dim_idx]);
}

// The reduction loop requires the shared array to be properly initialized and for
// this we want the number of threads to be a power of two.
let reduction_shape = Shape::from(dims.clone());

if layout.is_contiguous() && reduction_shape.is_contiguous(&stride) {
let (name, check_empty, return_index) = match (op, self.dtype) {
(ReduceOp::Sum, DType::F32) => ("fast_sum_f32", false, false),
(ReduceOp::Min, DType::F32) => ("fast_min_f32", true, false),
(ReduceOp::Max, DType::F32) => ("fast_max_f32", true, false),
(ReduceOp::ArgMin, DType::F32) => ("fast_argmin_f32", true, true),
(ReduceOp::ArgMax, DType::F32) => ("fast_argmax_f32", true, true),
(ReduceOp::Sum, DType::U32) => ("fast_sum_u32", false, false),
(ReduceOp::Min, DType::U32) => ("fast_min_u32", true, false),
(ReduceOp::Max, DType::U32) => ("fast_max_u32", true, false),
(ReduceOp::ArgMin, DType::U32) => ("fast_argmin_u32", true, true),
(ReduceOp::ArgMax, DType::U32) => ("fast_argmax_u32", true, true),
(ReduceOp::Sum, DType::F16) => ("fast_sum_f16", false, false),
(ReduceOp::Min, DType::F16) => ("fast_min_f16", true, false),
(ReduceOp::Max, DType::F16) => ("fast_max_f16", true, false),
(ReduceOp::ArgMin, DType::F16) => ("fast_argmin_f16", true, true),
(ReduceOp::ArgMax, DType::F16) => ("fast_argmax_f16", true, true),
(ReduceOp::Sum, DType::BF16) => ("fast_sum_bf16", false, false),
(ReduceOp::Min, DType::BF16) => ("fast_min_bf16", true, false),
(ReduceOp::Max, DType::BF16) => ("fast_max_bf16", true, false),
(ReduceOp::ArgMin, DType::BF16) => ("fast_argmin_bf16", true, true),
(ReduceOp::ArgMax, DType::BF16) => ("fast_argmax_bf16", true, true),
(ReduceOp::Sum, DType::I64) => ("fast_sum_i64", false, false),
(ReduceOp::Min, DType::I64) => ("fast_min_i64", true, false),
(ReduceOp::Max, DType::I64) => ("fast_max_i64", true, false),
(ReduceOp::ArgMin, DType::I64) => ("fast_argmin_i64", true, true),
(ReduceOp::ArgMax, DType::I64) => ("fast_argmax_i64", true, true),
(ReduceOp::Sum, DType::U8) => ("fast_sum_u8", false, false),
(ReduceOp::Min, DType::U8) => ("fast_min_u8", true, false),
(ReduceOp::Max, DType::U8) => ("fast_max_u8", true, false),
(ReduceOp::ArgMin, DType::U8) => ("fast_argmin_u8", true, true),
(ReduceOp::ArgMax, DType::U8) => ("fast_argmax_u8", true, true),
(k, dtype) => {
crate::bail!("Metal contiguous reduce op {k:?} {dtype:?} not implemented")
}
};
if check_empty && layout.shape().elem_count() == 0 {
Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
}
let dtype = if return_index { DType::U32 } else { self.dtype };
let buffer = device.new_buffer(dst_el, dtype, "reduce")?;
let command_buffer = self.device.command_buffer()?;
let src = buffer_o(&self.buffer, layout, self.dtype);
candle_metal_kernels::call_reduce_contiguous(
&device.device,
&command_buffer,
&device.kernels,
name,
&src_dims,
dst_el,
src,
&buffer,
)
.map_err(MetalError::from)?;

return Ok(Self::new(buffer, device, dst_el, dtype));
}

let (name, check_empty, return_index) = match (op, self.dtype) {
(ReduceOp::Sum, DType::F32) => ("fast_sum_f32_strided", false, false),
(ReduceOp::Min, DType::F32) => ("fast_min_f32_strided", true, false),
Expand Down Expand Up @@ -316,7 +376,7 @@ impl BackendStorage for MetalStorage {
(ReduceOp::Max, DType::U8) => ("fast_max_u8_strided", true, false),
(ReduceOp::ArgMin, DType::U8) => ("fast_argmin_u8_strided", true, true),
(ReduceOp::ArgMax, DType::U8) => ("fast_argmax_u8_strided", true, true),
(k, dtype) => crate::bail!("Metal reduce op {k:?} {dtype:?} not implemented"),
(k, dtype) => crate::bail!("Metal strided reduce op {k:?} {dtype:?} not implemented"),
};
if check_empty && layout.shape().elem_count() == 0 {
Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
Expand Down
Loading
Loading