
Fixed over xmas and then forgot about it
ivarflakstad committed Jan 13, 2025
1 parent afa1ea1 commit c3e4da7
Showing 7 changed files with 223 additions and 554 deletions.
3 changes: 1 addition & 2 deletions candle-core/benches/benchmarks/reduce.rs
@@ -12,7 +12,7 @@ fn run_arg_min(a: &Tensor) {
a.argmin_keepdim(2).unwrap();
}

-// TODO: Remove before merging. Softmax impls live in candle-nn, so this is a temporary workaround.
+// NOTE: Should this be removed? Softmax impls live in candle-nn.
fn softmax(a: &Tensor) -> candle_core::Result<()> {
use candle_core::{backend::BackendStorage, DType};
let (storage, layout) = a.storage_and_layout();
@@ -121,7 +121,6 @@ fn run_reduce<T: candle_core::FloatDType>(
strided: bool,
) {
let b = 1;
-
let m = 1024;
let k = 1024;

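The NOTE above refers to the higher-level implementation in candle-nn. For context, a minimal sketch of that entry point (assuming candle-nn as a dependency; not part of this commit):

```rust
use candle_core::{Device, Tensor};

fn main() -> candle_core::Result<()> {
    let a = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
    // The candle-nn implementation the benchmark comment points at.
    let sm = candle_nn::ops::softmax_last_dim(&a)?;
    println!("{sm}");
    Ok(())
}
```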
3 changes: 1 addition & 2 deletions candle-core/src/metal_backend/device.rs
@@ -2,7 +2,6 @@ use crate::{DType, Result};
use candle_metal_kernels::Kernels;
use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
use std::collections::HashMap;
-use std::ffi::c_void;
use std::path::Path;
use std::sync::{Arc, Mutex, RwLock, RwLockWriteGuard};

@@ -175,7 +174,7 @@ impl MetalDevice {
pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Result<Arc<Buffer>> {
let size = core::mem::size_of_val(data) as NSUInteger;
let new_buffer = self.device.new_buffer_with_data(
-data.as_ptr() as *const c_void,
+data.as_ptr().cast(),
size,
MTLResourceOptions::StorageModeManaged,
);
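The device.rs change above replaces an explicit `as *const c_void` with `pointer::cast`, letting the target type be inferred from the callee and dropping the now-unused `c_void` import. A minimal standalone sketch of the equivalence (`take_raw` is a hypothetical stand-in for the Metal call):

```rust
use std::ffi::c_void;

// Hypothetical stand-in for a function taking a raw untyped pointer.
fn take_raw(ptr: *const c_void, size: usize) {
    let _ = (ptr, size);
}

fn main() {
    let data = [1u32, 2, 3];
    let size = core::mem::size_of_val(&data);
    // Before: take_raw(data.as_ptr() as *const c_void, size);
    // After: the target pointer type is inferred from take_raw's parameter.
    take_raw(data.as_ptr().cast(), size);
}
```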
16 changes: 9 additions & 7 deletions candle-core/src/metal_backend/mod.rs
@@ -278,7 +278,14 @@ impl BackendStorage for MetalStorage {
}
}

-if layout.is_contiguous() {
+for &dim_idx in sum_dims.iter() {
+dims.push(src_dims[dim_idx]);
+stride.push(src_stride[dim_idx]);
+}
+
+let reduction_shape = Shape::from(dims.clone());
+
+if layout.is_contiguous() && reduction_shape.is_contiguous(&stride) {
let (name, check_empty, return_index) = match (op, self.dtype) {
(ReduceOp::Sum, DType::F32) => ("fast_sum_f32", false, false),
(ReduceOp::Min, DType::F32) => ("fast_min_f32", true, false),
@@ -326,7 +333,7 @@
&command_buffer,
&device.kernels,
name,
-layout.shape().elem_count(),
+&src_dims,
dst_el,
src,
&buffer,
@@ -336,11 +343,6 @@
return Ok(Self::new(buffer, device, dst_el, dtype));
}

-for &dim_idx in sum_dims.iter() {
-dims.push(src_dims[dim_idx]);
-stride.push(src_stride[dim_idx]);
-}
-
let (name, check_empty, return_index) = match (op, self.dtype) {
(ReduceOp::Sum, DType::F32) => ("fast_sum_f32_strided", false, false),
(ReduceOp::Min, DType::F32) => ("fast_min_f32_strided", true, false),
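This hunk is the substance of the fix: the fast non-strided kernels assume the reduced dimensions form one contiguous block, so checking `layout.is_contiguous()` alone was not sufficient; the reduced dims/strides are now collected first and checked as well. A standalone sketch of a row-major contiguity check in the spirit of `Shape::is_contiguous` (not candle's actual implementation), showing why reducing a non-innermost dimension must take the strided path:

```rust
// True when iterating `dims` with `stride` touches one dense block.
fn is_contiguous(dims: &[usize], stride: &[usize]) -> bool {
    let mut expected = 1;
    for (&d, &s) in dims.iter().zip(stride.iter()).rev() {
        if d > 1 && s != expected {
            return false;
        }
        expected *= d;
    }
    true
}

fn main() {
    // Row-major 4x8 matrix: strides [8, 1].
    // Reducing dim 1 reads stride-1 runs: the fast kernel applies.
    assert!(is_contiguous(&[8], &[1]));
    // Reducing dim 0 reads with stride 8: fall through to *_strided.
    assert!(!is_contiguous(&[4], &[8]));
}
```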
45 changes: 22 additions & 23 deletions candle-metal-kernels/src/lib.rs
@@ -561,35 +561,39 @@ pub fn call_reduce_contiguous(
ep: impl EncoderProvider,
kernels: &Kernels,
kernel_name: &'static str,
-length: usize,
+shape: &[usize],
out_length: usize,
input: BufferOffset,
output: &Buffer,
) -> Result<(), MetalKernelError> {
+let length = shape.iter().product::<usize>();
+let num_dims = shape.len();
let work_per_threadgroup = length / out_length;
-
-let (name, granularity) = if work_per_threadgroup % 4 == 0 {
-(format!("{kernel_name}x4").leak(), 4)
-} else if work_per_threadgroup % 2 == 0 {
-(format!("{kernel_name}x2").leak(), 2)
-} else {
-(format!("{kernel_name}").leak(), 1)
-};
-let pipeline = kernels.load_pipeline(device, Source::Candle, name)?;
+let pipeline = kernels.load_pipeline(device, Source::Candle, kernel_name)?;

let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline);

-set_params!(encoder, (length, work_per_threadgroup, &input, output));
+set_params!(
+encoder,
+(
+length,
+num_dims,
+shape,
+work_per_threadgroup,
+&input,
+output
+)
+);

let thread_group_count = MTLSize {
width: out_length as u64,
height: 1,
depth: 1,
};

-let work_split = work_per_threadgroup / (2 * granularity);
+let work_split = work_per_threadgroup / 2;
let mut w = 2;
while w < work_split {
w *= 2;
@@ -605,6 +609,7 @@
height: 1,
depth: 1,
};
+
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
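For reference, the selection logic deleted above (reconstructed verbatim from the removed lines) picked a vectorized kernel variant by formatting its name and leaking each String, presumably to satisfy a `&'static str` pipeline-name parameter, a small allocation leaked on every dispatch; the commit drops the variants and passes `kernel_name` through unchanged:

```rust
// The removed pattern: a fresh leaked allocation on every dispatch.
fn variant_name(kernel_name: &str, work_per_threadgroup: usize) -> &'static str {
    if work_per_threadgroup % 4 == 0 {
        format!("{kernel_name}x4").leak()
    } else if work_per_threadgroup % 2 == 0 {
        format!("{kernel_name}x2").leak()
    } else {
        format!("{kernel_name}").leak()
    }
}

fn main() {
    assert_eq!(variant_name("fast_sum_f32", 1024), "fast_sum_f32x4");
    assert_eq!(variant_name("fast_sum_f32", 6), "fast_sum_f32x2");
    assert_eq!(variant_name("fast_sum_f32", 7), "fast_sum_f32");
}
```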
@@ -624,6 +629,7 @@ pub fn call_reduce_strided(
output: &Buffer,
) -> Result<(), MetalKernelError> {
let length: usize = shape.iter().product();
+let num_dims = shape.len();
let work_per_threadgroup = length / out_length;
let pipeline = kernels.load_pipeline(device, Source::Candle, kernel_name)?;

@@ -634,7 +640,8 @@
set_params!(
encoder,
(
-shape.len(),
+length,
+num_dims,
shape,
strides,
work_per_threadgroup,
@@ -684,15 +691,8 @@ pub fn call_last_softmax(
output: &Buffer,
) -> Result<(), MetalKernelError> {
let work_per_threadgroup = elements;
-let (name, granularity) = if work_per_threadgroup % 4 == 0 {
-(format!("{kernel_name}x4").leak(), 4)
-} else if work_per_threadgroup % 2 == 0 {
-(format!("{kernel_name}x2").leak(), 2)
-} else {
-(format!("{kernel_name}").leak(), 1)
-};

-let pipeline = kernels.load_pipeline(device, Source::Candle, name)?;
+let pipeline = kernels.load_pipeline(device, Source::Candle, kernel_name)?;
let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline);
@@ -710,12 +710,11 @@
depth: 1,
};

-let work_split = work_per_threadgroup / (2 * granularity);
+let work_split = work_per_threadgroup / 2;
let mut w = 2;
while w < work_split {
w *= 2;
}
-w *= 2;

let width = std::cmp::min(
pipeline.max_total_threads_per_threadgroup(),
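Both `call_reduce_contiguous` and `call_last_softmax` now size the threadgroup the same way: round `work_per_threadgroup / 2` up to a power of two (minimum 2), then clamp to the pipeline's limit; the old code divided by `2 * granularity` and doubled `w` once more after the loop. A standalone sketch of the new computation (the helper name is ours):

```rust
// Smallest power of two >= work_per_threadgroup / 2, floored at 2,
// clamped to the pipeline's max_total_threads_per_threadgroup.
fn thread_width(work_per_threadgroup: usize, max_total_threads: usize) -> usize {
    let work_split = work_per_threadgroup / 2;
    let mut w = 2;
    while w < work_split {
        w *= 2;
    }
    std::cmp::min(max_total_threads, w)
}

fn main() {
    assert_eq!(thread_width(1024, 1024), 512); // exact power of two
    assert_eq!(thread_width(6, 1024), 4);      // 3 rounds up to 4
    assert_eq!(thread_width(4096, 1024), 1024); // clamped to the limit
}
```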
