Skip to content

Commit

Permalink
Improve build. Extract utils
Browse files Browse the repository at this point in the history
  • Loading branch information
ivarflakstad committed Mar 8, 2024
1 parent a9e26b5 commit 8186cee
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 67 deletions.
84 changes: 64 additions & 20 deletions candle-metal-kernels/build.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
use std::path::PathBuf;
use std::process::Command;
use std::{env, str};
use std::path::PathBuf;

const METAL_SOURCES: [&str; 1] = ["reduce"];

fn main() -> Result<(), String> {
println!("cargo:rerun-if-changed=build.rs");
println!("cargo:rerun-if-changed=*.metal");
println!("cargo:rerun-if-changed=*.m");
println!("cargo:rerun-if-changed=reduce.metal");
println!("cargo:rerun-if-changed=utils.metal");

let xcrun_output = Command::new("xcrun")
.args(["--sdk", "macosx", "--show-sdk-path"])
Expand All @@ -18,44 +18,88 @@ fn main() -> Result<(), String> {
.expect("Invalid UTF-8 from xcrun")
.replace('\n', "");

println!("cargo:rerun-if-changed={sdk_path}");
let current_dir = env::current_dir().expect("Failed to get current directory");
let out_dir = PathBuf::from(std::env::var("OUT_DIR").map_err(|_|"OUT_DIR not set")?);

let sources = current_dir
.join("src")
.to_str()
.unwrap()
.to_string();
let out_dir = PathBuf::from(std::env::var("OUT_DIR").map_err(|_| "OUT_DIR not set")?);
let working_directory = out_dir.to_string_lossy().to_string();
let sources = current_dir.join("src");

// Compile metal to air
let mut compile_air_cmd = Command::new("xcrun");
compile_air_cmd
.arg("metal")
.arg(format!("-working-directory={}", out_dir.to_str().ok_or("")?))
.arg(format!("-working-directory={working_directory}"))
.arg("-Wall")
.arg("-Wextra")
.arg("-O3")
.arg("-c")
.arg("-frecord-sources")
.arg("-w");
for metal_file in METAL_SOURCES {
compile_air_cmd.arg(format!("{sources}/{metal_file}.metal"));
compile_air_cmd.arg(sources.join(format!("{metal_file}.metal")));
}
compile_air_cmd.arg(sources.join("utils.metal"));
compile_air_cmd.spawn().expect("Failed to compile air");

let mut child = compile_air_cmd.spawn().expect("Failed to compile air");

match child.try_wait() {
Ok(Some(status)) => {
if !status.success() {
panic!(
"Compiling metal -> air failed. Exit with status: {}",
status
)
}
}
Ok(None) => {
let status = child
.wait()
.expect("Compiling metal -> air failed while waiting for result");
if !status.success() {
panic!(
"Compiling metal -> air failed. Exit with status: {}",
status
)
}
}
Err(e) => panic!("Compiling metal -> air failed: {:?}", e),
}

// Compile air to metallib
let metallib = out_dir.join("candle.metallib");
let mut compile_metallib_cmd = Command::new("xcrun");
compile_metallib_cmd
.arg("metal")
.arg("-o")
.arg(&metallib);
compile_metallib_cmd.arg("metal").arg("-o").arg(&metallib);

for metal_file in METAL_SOURCES {
compile_metallib_cmd.arg(out_dir.join(format!("{metal_file}.air")));
}
compile_metallib_cmd.arg(out_dir.join("utils.air"));

compile_metallib_cmd
let mut child = compile_metallib_cmd
.spawn()
.expect("Failed to compile metallib");
.expect("Failed to compile air -> metallib");

match child.try_wait() {
Ok(Some(status)) => {
if !status.success() {
panic!(
"Compiling air -> metallib failed. Exit with status: {}",
status
)
}
}
Ok(None) => {
let status = child
.wait()
.expect("Compiling air -> metallib failed while waiting for result");
if !status.success() {
panic!(
"Compiling air -> metallib failed. Exit with status: {}",
status
)
}
}
Err(e) => panic!("Compiling air -> metallib failed: {:?}", e),
}

Ok(())
}
2 changes: 1 addition & 1 deletion candle-metal-kernels/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ impl Kernels {
"Candle metal requires macosx > 13.0 or higher, cannot load mfa: {e}"
))
})?
},
}
Source::Mfa => {
let source_data = MFA;
device.new_library_with_data(source_data).map_err(|e| {
Expand Down
47 changes: 1 addition & 46 deletions candle-metal-kernels/src/reduce.metal
Original file line number Diff line number Diff line change
@@ -1,37 +1,8 @@
#include <metal_stdlib>
#include <metal_limits>
#include "utils.metal"
using namespace metal;


METAL_FUNC uint nonzero(uint n) {
return n == 0 ? 1 : n;
}
template<uint N>
constexpr uint nonzero() {
return N == 0 ? 1 : N;
}

template<typename T>
constexpr ushort granularity() {
return nonzero<vec_elements<T>::value>();
}

METAL_FUNC uint next_p2(uint x) {
return 1 << (32 - clz(x - 1));
}


METAL_FUNC uint prev_p2(uint x) {
return 1 << (31 - clz(x));
}

constant uint MAX_SHARED_MEM = 32767;

template<typename T>
METAL_FUNC uint max_shared_mem(uint n) {
return min(n, prev_p2(MAX_SHARED_MEM / sizeof(T)));
}

struct Divide {
template<typename T>
METAL_FUNC T operator()(T a, T b) { return a / b; }
Expand Down Expand Up @@ -66,21 +37,6 @@ struct Exp {
#endif
};

METAL_FUNC uint get_strided_index(
uint idx,
constant const uint &num_dims,
constant const size_t *dims,
constant const size_t *strides
) {
uint strided_i = 0;
for (uint d = 0; d < num_dims; d++) {
uint dim_idx = num_dims - 1 - d;
strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
idx /= dims[dim_idx];
}
return strided_i;
}

// Keeps track of the index of the value in the reduction operation (argmin, argmax, etc.)
// and the value itself. The index is also used to break ties in the reduction operation.
// There are two specializations of the indexed class, one for scalar values and one for vector values.
Expand Down Expand Up @@ -1167,7 +1123,6 @@ struct finalize_softmax<T, BLOCKSIZE, typename metal::enable_if_t<is_scalar_v<T>
}
};


template<typename T, ushort BLOCKSIZE>
struct finalize_softmax<T, BLOCKSIZE, typename metal::enable_if_t<is_vector_v<T>>> {
using ST = make_scalar_t<T>;
Expand Down
47 changes: 47 additions & 0 deletions candle-metal-kernels/src/utils.metal
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
#pragma once
#include <metal_stdlib>
using namespace metal;

METAL_FUNC uint nonzero(uint n) {
return n == 0 ? 1 : n;
}

template<uint N>
constexpr uint nonzero() {
return N == 0 ? 1 : N;
}

template<typename T>
constexpr ushort granularity() {
return nonzero<vec_elements<T>::value>();
}

METAL_FUNC uint next_p2(uint x) {
return 1 << (32 - clz(x - 1));
}

METAL_FUNC uint prev_p2(uint x) {
return 1 << (31 - clz(x));
}

constant uint MAX_SHARED_MEM = 32767;

template<typename T>
METAL_FUNC uint max_shared_mem(uint n) {
return min(n, prev_p2(MAX_SHARED_MEM / sizeof(T)));
}

METAL_FUNC uint get_strided_index(
uint idx,
constant const uint &num_dims,
constant const size_t *dims,
constant const size_t *strides
) {
uint strided_i = 0;
for (uint d = 0; d < num_dims; d++) {
uint dim_idx = num_dims - 1 - d;
strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
idx /= dims[dim_idx];
}
return strided_i;
}

0 comments on commit 8186cee

Please sign in to comment.