Skip to content

Commit

Permalink
Add build.rs to avoid metal kernel jit compile overhead
Browse files Browse the repository at this point in the history
  • Loading branch information
ivarflakstad committed Mar 7, 2024
1 parent 3633135 commit a9e26b5
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 7 deletions.
61 changes: 61 additions & 0 deletions candle-metal-kernels/build.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
use std::process::Command;
use std::{env, str};
use std::path::PathBuf;

const METAL_SOURCES: [&str; 1] = ["reduce"];

fn main() -> Result<(), String> {
println!("cargo:rerun-if-changed=build.rs");
println!("cargo:rerun-if-changed=*.metal");
println!("cargo:rerun-if-changed=*.m");

let xcrun_output = Command::new("xcrun")
.args(["--sdk", "macosx", "--show-sdk-path"])
.output()
.expect("xcrun command failed to start");

let sdk_path = str::from_utf8(&xcrun_output.stdout)
.expect("Invalid UTF-8 from xcrun")
.replace('\n', "");

println!("cargo:rerun-if-changed={sdk_path}");
let current_dir = env::current_dir().expect("Failed to get current directory");
let out_dir = PathBuf::from(std::env::var("OUT_DIR").map_err(|_|"OUT_DIR not set")?);

let sources = current_dir
.join("src")
.to_str()
.unwrap()
.to_string();

// Compile metal to air
let mut compile_air_cmd = Command::new("xcrun");
compile_air_cmd
.arg("metal")
.arg(format!("-working-directory={}", out_dir.to_str().ok_or("")?))
.arg("-c")
.arg("-frecord-sources")
.arg("-w");
for metal_file in METAL_SOURCES {
compile_air_cmd.arg(format!("{sources}/{metal_file}.metal"));
}
compile_air_cmd.spawn().expect("Failed to compile air");

// Compile air to metallib
let metallib = out_dir.join("candle.metallib");
let mut compile_metallib_cmd = Command::new("xcrun");
compile_metallib_cmd
.arg("metal")
.arg("-o")
.arg(&metallib);

for metal_file in METAL_SOURCES {
compile_metallib_cmd.arg(out_dir.join(format!("{metal_file}.air")));
}

compile_metallib_cmd
.spawn()
.expect("Failed to compile metallib");

Ok(())
}
21 changes: 14 additions & 7 deletions candle-metal-kernels/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@ use std::collections::HashMap;
use std::ffi::c_void;
use std::sync::RwLock;

const CANDLE: &[u8] = include_bytes!(concat!(env!("OUT_DIR"), "/candle.metallib"));
const AFFINE: &str = include_str!("affine.metal");
const INDEXING: &str = include_str!("indexing.metal");
const UNARY: &str = include_str!("unary.metal");
const BINARY: &str = include_str!("binary.metal");
const TERNARY: &str = include_str!("ternary.metal");
const CAST: &str = include_str!("cast.metal");
const CONV: &str = include_str!("conv.metal");
const REDUCE: &str = include_str!("reduce.metal");
const RANDOM: &str = include_str!("random.metal");
const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
const QUANTIZED: &str = include_str!("quantized.metal");
Expand Down Expand Up @@ -114,13 +114,13 @@ macro_rules! set_params {

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Source {
Candle,
Affine,
Indexing,
Unary,
Binary,
Ternary,
Cast,
Reduce,
Mfa,
Conv,
Random,
Expand Down Expand Up @@ -243,11 +243,10 @@ impl Kernels {
Source::Ternary => TERNARY,
Source::Indexing => INDEXING,
Source::Cast => CAST,
Source::Reduce => REDUCE,
Source::Conv => CONV,
Source::Random => RANDOM,
Source::Quantized => QUANTIZED,
Source::Mfa => panic!("Invalid lib"),
_ => panic!("Invalid lib"),
}
}

Expand All @@ -263,6 +262,14 @@ impl Kernels {
Ok(lib.clone())
} else {
let lib = match source {
Source::Candle => {
let source_data = CANDLE;
device.new_library_with_data(source_data).map_err(|e| {
MetalKernelError::LoadLibraryError(format!(
"Candle metal requires macosx > 13.0 or higher, cannot load mfa: {e}"
))
})?
},
Source::Mfa => {
let source_data = MFA;
device.new_library_with_data(source_data).map_err(|e| {
Expand Down Expand Up @@ -569,7 +576,7 @@ pub fn call_reduce_contiguous(
} else {
(format!("{kernel_name}").leak(), 1)
};
let pipeline = kernels.load_pipeline(device, Source::Reduce, name)?;
let pipeline = kernels.load_pipeline(device, Source::Candle, name)?;

let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
Expand Down Expand Up @@ -628,7 +635,7 @@ pub fn call_reduce_strided(
) -> Result<(), MetalKernelError> {
let length: usize = shape.iter().product();
let work_per_threadgroup = length / out_length;
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
let pipeline = kernels.load_pipeline(device, Source::Candle, kernel_name)?;

let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
Expand Down Expand Up @@ -697,7 +704,7 @@ pub fn call_last_softmax(
(format!("{kernel_name}").leak(), 1)
};

let pipeline = kernels.load_pipeline(device, Source::Reduce, name)?;
let pipeline = kernels.load_pipeline(device, Source::Candle, name)?;
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);

Expand Down
2 changes: 2 additions & 0 deletions candle-metal-kernels/src/reduce.metal
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ constexpr ushort granularity() {
METAL_FUNC uint next_p2(uint x) {
return 1 << (32 - clz(x - 1));
}


METAL_FUNC uint prev_p2(uint x) {
return 1 << (31 - clz(x));
}
Expand Down

0 comments on commit a9e26b5

Please sign in to comment.