diff --git a/candle-metal-kernels/build.rs b/candle-metal-kernels/build.rs index 59f10a52da..12740e6da2 100644 --- a/candle-metal-kernels/build.rs +++ b/candle-metal-kernels/build.rs @@ -1,13 +1,13 @@ +use std::path::PathBuf; use std::process::Command; use std::{env, str}; -use std::path::PathBuf; const METAL_SOURCES: [&str; 1] = ["reduce"]; fn main() -> Result<(), String> { println!("cargo:rerun-if-changed=build.rs"); - println!("cargo:rerun-if-changed=*.metal"); - println!("cargo:rerun-if-changed=*.m"); + println!("cargo:rerun-if-changed=reduce.metal"); + println!("cargo:rerun-if-changed=utils.metal"); let xcrun_output = Command::new("xcrun") .args(["--sdk", "macosx", "--show-sdk-path"]) @@ -18,44 +18,88 @@ fn main() -> Result<(), String> { .expect("Invalid UTF-8 from xcrun") .replace('\n', ""); - println!("cargo:rerun-if-changed={sdk_path}"); let current_dir = env::current_dir().expect("Failed to get current directory"); - let out_dir = PathBuf::from(std::env::var("OUT_DIR").map_err(|_|"OUT_DIR not set")?); - - let sources = current_dir - .join("src") - .to_str() - .unwrap() - .to_string(); + let out_dir = PathBuf::from(std::env::var("OUT_DIR").map_err(|_| "OUT_DIR not set")?); + let working_directory = out_dir.to_string_lossy().to_string(); + let sources = current_dir.join("src"); // Compile metal to air let mut compile_air_cmd = Command::new("xcrun"); compile_air_cmd .arg("metal") - .arg(format!("-working-directory={}", out_dir.to_str().ok_or("")?)) + .arg(format!("-working-directory={working_directory}")) + .arg("-Wall") + .arg("-Wextra") + .arg("-O3") .arg("-c") - .arg("-frecord-sources") .arg("-w"); for metal_file in METAL_SOURCES { - compile_air_cmd.arg(format!("{sources}/{metal_file}.metal")); + compile_air_cmd.arg(sources.join(format!("{metal_file}.metal"))); } + compile_air_cmd.arg(sources.join("utils.metal")); compile_air_cmd.spawn().expect("Failed to compile air"); + let mut child = compile_air_cmd.spawn().expect("Failed to compile air"); + + match child.try_wait() { + Ok(Some(status)) => { + if !status.success() { + panic!( + "Compiling metal -> air failed. Exit with status: {}", + status + ) + } + } + Ok(None) => { + let status = child + .wait() + .expect("Compiling metal -> air failed while waiting for result"); + if !status.success() { + panic!( + "Compiling metal -> air failed. Exit with status: {}", + status + ) + } + } + Err(e) => panic!("Compiling metal -> air failed: {:?}", e), + } + // Compile air to metallib let metallib = out_dir.join("candle.metallib"); let mut compile_metallib_cmd = Command::new("xcrun"); - compile_metallib_cmd - .arg("metal") - .arg("-o") - .arg(&metallib); + compile_metallib_cmd.arg("metal").arg("-o").arg(&metallib); for metal_file in METAL_SOURCES { compile_metallib_cmd.arg(out_dir.join(format!("{metal_file}.air"))); } + compile_metallib_cmd.arg(out_dir.join("utils.air")); - compile_metallib_cmd + let mut child = compile_metallib_cmd .spawn() - .expect("Failed to compile metallib"); + .expect("Failed to compile air -> metallib"); + + match child.try_wait() { + Ok(Some(status)) => { + if !status.success() { + panic!( + "Compiling air -> metallib failed. Exit with status: {}", + status + ) + } + } + Ok(None) => { + let status = child + .wait() + .expect("Compiling air -> metallib failed while waiting for result"); + if !status.success() { + panic!( + "Compiling air -> metallib failed. Exit with status: {}", + status + ) + } + } + Err(e) => panic!("Compiling air -> metallib failed: {:?}", e), + } Ok(()) } diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 5b1995dfa2..7ff1781ba9 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -269,7 +269,7 @@ impl Kernels { "Candle metal requires macosx > 13.0 or higher, cannot load mfa: {e}" )) })? - }, + } Source::Mfa => { let source_data = MFA; device.new_library_with_data(source_data).map_err(|e| { diff --git a/candle-metal-kernels/src/reduce.metal b/candle-metal-kernels/src/reduce.metal index 6c7786535a..a08b768067 100644 --- a/candle-metal-kernels/src/reduce.metal +++ b/candle-metal-kernels/src/reduce.metal @@ -1,37 +1,8 @@ #include #include +#include "utils.metal" using namespace metal; - -METAL_FUNC uint nonzero(uint n) { - return n == 0 ? 1 : n; -} -template -constexpr uint nonzero() { - return N == 0 ? 1 : N; -} - -template -constexpr ushort granularity() { - return nonzero::value>(); -} - -METAL_FUNC uint next_p2(uint x) { - return 1 << (32 - clz(x - 1)); -} - - -METAL_FUNC uint prev_p2(uint x) { - return 1 << (31 - clz(x)); -} - -constant uint MAX_SHARED_MEM = 32767; - -template -METAL_FUNC uint max_shared_mem(uint n) { - return min(n, prev_p2(MAX_SHARED_MEM / sizeof(T))); -} - struct Divide { template METAL_FUNC T operator()(T a, T b) { return a / b; } @@ -66,21 +37,6 @@ struct Exp { #endif }; -METAL_FUNC uint get_strided_index( - uint idx, - constant const uint &num_dims, - constant const size_t *dims, - constant const size_t *strides -) { - uint strided_i = 0; - for (uint d = 0; d < num_dims; d++) { - uint dim_idx = num_dims - 1 - d; - strided_i += (idx % dims[dim_idx]) * strides[dim_idx]; - idx /= dims[dim_idx]; - } - return strided_i; -} - // Keeps track of the index of the value in the reduction operation (argmin, argmax, etc.) // and the value itself. The index is also used to break ties in the reduction operation. // There are two specializations of the indexed class, one for scalar values and one for vector values. @@ -1167,7 +1123,6 @@ struct finalize_softmax } }; - template struct finalize_softmax>> { using ST = make_scalar_t; diff --git a/candle-metal-kernels/src/utils.metal b/candle-metal-kernels/src/utils.metal new file mode 100644 index 0000000000..8ee6b4ad76 --- /dev/null +++ b/candle-metal-kernels/src/utils.metal @@ -0,0 +1,47 @@ +#pragma once +#include +using namespace metal; + +METAL_FUNC uint nonzero(uint n) { + return n == 0 ? 1 : n; +} + +template +constexpr uint nonzero() { + return N == 0 ? 1 : N; +} + +template +constexpr ushort granularity() { + return nonzero::value>(); +} + +METAL_FUNC uint next_p2(uint x) { + return 1 << (32 - clz(x - 1)); +} + +METAL_FUNC uint prev_p2(uint x) { + return 1 << (31 - clz(x)); +} + +constant uint MAX_SHARED_MEM = 32767; + +template +METAL_FUNC uint max_shared_mem(uint n) { + return min(n, prev_p2(MAX_SHARED_MEM / sizeof(T))); +} + +METAL_FUNC uint get_strided_index( + uint idx, + constant const uint &num_dims, + constant const size_t *dims, + constant const size_t *strides +) { + uint strided_i = 0; + for (uint d = 0; d < num_dims; d++) { + uint dim_idx = num_dims - 1 - d; + strided_i += (idx % dims[dim_idx]) * strides[dim_idx]; + idx /= dims[dim_idx]; + } + return strided_i; +}