Adding iOS Support to Mistral.rs #63

Merged: 1 commit, Jan 18, 2025
12 changes: 6 additions & 6 deletions candle-core/src/metal_backend/device.rs
@@ -6,7 +6,7 @@
 use std::ffi::c_void;
 use std::path::Path;
 use std::sync::{Arc, Mutex, RwLock};
 
-use super::MetalError;
+use super::{MetalError, METAL_SHARED_BUFFER_STORAGE_MODE};
 
 /// Unique identifier for Metal devices.
 #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
@@ -201,7 +201,7 @@ impl MetalDevice {
         name: &str,
     ) -> Result<Arc<Buffer>> {
         let size = (element_count * dtype.size_in_bytes()) as NSUInteger;
-        self.allocate_buffer(size, MTLResourceOptions::StorageModeShared, name)
+        self.allocate_buffer(size, METAL_SHARED_BUFFER_STORAGE_MODE, name)
     }
 
     /// Creates a new buffer (not necessarily zeroed).
@@ -210,7 +210,7 @@ impl MetalDevice {
     /// synchronization when the CPU memory is modified
     /// Used as a bridge to gather data back from the GPU
     pub fn new_buffer_managed(&self, size: NSUInteger) -> Result<Arc<Buffer>> {
-        self.allocate_buffer(size, MTLResourceOptions::StorageModeShared, "managed")
+        self.allocate_buffer(size, METAL_SHARED_BUFFER_STORAGE_MODE, "managed")
     }
 
     /// Creates a new buffer from data.
@@ -223,12 +223,12 @@ impl MetalDevice {
         let new_buffer = self.device.new_buffer_with_data(
             data.as_ptr() as *const c_void,
             size,
-            MTLResourceOptions::StorageModeShared,
+            METAL_SHARED_BUFFER_STORAGE_MODE,
         );
         let mut buffers = self.buffers.write().map_err(MetalError::from)?;
 
         let subbuffers = buffers
-            .entry((size, MTLResourceOptions::StorageModeShared))
+            .entry((size, METAL_SHARED_BUFFER_STORAGE_MODE))
             .or_insert(vec![]);
 
         let new_buffer = Arc::new(new_buffer);
@@ -239,7 +239,7 @@
     pub fn allocate_zeros(&self, size_in_bytes: usize) -> Result<Arc<Buffer>> {
         let buffer = self.allocate_buffer(
             size_in_bytes as NSUInteger,
-            MTLResourceOptions::StorageModeShared,
+            METAL_SHARED_BUFFER_STORAGE_MODE,
             "allocate_zeros",
         )?;
         let command_buffer = self.command_buffer()?;
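One subtlety in this file: `new_buffer_from_data` inserts the fresh buffer into a reuse pool keyed by `(size, resource options)`, so the pool key and the allocation must use the same mode, or a buffer allocated under one storage mode could be handed back to a caller expecting the other. A minimal sketch of that keying, with the pool type simplified (the real `MetalDevice` keeps this map behind an `RwLock`) and the PR's constant, defined in `candle-core/src/metal_backend/mod.rs` below, passed in as a parameter:

```rust
use std::collections::HashMap;
use std::sync::Arc;
use metal::{Buffer, MTLResourceOptions, NSUInteger};

// Hypothetical, simplified pool type for illustration only.
type BufferPool = HashMap<(NSUInteger, MTLResourceOptions), Vec<Arc<Buffer>>>;

// `mode` stands in for METAL_SHARED_BUFFER_STORAGE_MODE from the next file.
fn pool_slot(
    pool: &mut BufferPool,
    size: NSUInteger,
    mode: MTLResourceOptions,
) -> &mut Vec<Arc<Buffer>> {
    // Allocation and reuse lookups agree on the key, so Shared and Managed
    // buffers can never be confused across platforms.
    pool.entry((size, mode)).or_insert_with(Vec::new)
}
```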
10 changes: 9 additions & 1 deletion candle-core/src/metal_backend/mod.rs
@@ -8,6 +8,14 @@
 use std::collections::HashMap;
 use std::ffi::c_void;
 use std::sync::{Arc, Mutex, PoisonError, RwLock, TryLockError};
 
+#[cfg(target_os = "ios")]
+pub const METAL_SHARED_BUFFER_STORAGE_MODE: MTLResourceOptions =
+    MTLResourceOptions::StorageModeShared;
+
+#[cfg(not(target_os = "ios"))]
+pub const METAL_SHARED_BUFFER_STORAGE_MODE: MTLResourceOptions =
+    MTLResourceOptions::StorageModeManaged;
+
 mod device;
 pub use device::{DeviceId, MetalDevice};
@@ -2088,7 +2096,7 @@ impl BackendDevice for MetalDevice {
         let seed = Arc::new(Mutex::new(device.new_buffer_with_data(
             [299792458].as_ptr() as *const c_void,
             4,
-            MTLResourceOptions::StorageModeShared,
+            METAL_SHARED_BUFFER_STORAGE_MODE,
         )));
         let commands = device::Commands::new(command_queue)?;
         Ok(Self {
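The effect of the new constant is that every CPU-visible buffer picks its storage mode at compile time: iOS has unified memory and no `StorageModeManaged` (that mode is macOS-only), so iOS gets `Shared`, while other Apple targets use a `Managed` buffer. A hedged sketch of what a call site looks like with this pattern, with the device setup and buffer size made up for illustration:

```rust
use metal::{Device, MTLResourceOptions};

// Same selection the PR makes, reproduced locally so the sketch compiles
// on its own.
#[cfg(target_os = "ios")]
const MODE: MTLResourceOptions = MTLResourceOptions::StorageModeShared;
#[cfg(not(target_os = "ios"))]
const MODE: MTLResourceOptions = MTLResourceOptions::StorageModeManaged;

fn main() {
    // Illustrative only: grab the default GPU and allocate 4 KiB with the
    // compile-time-selected storage mode; callers stay platform-agnostic.
    let device = Device::system_default().expect("no Metal device found");
    let buffer = device.new_buffer(4096, MODE);
    println!("allocated {} bytes, options = {:?}", buffer.length(), MODE);
}
```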
4 changes: 4 additions & 0 deletions candle-metal-kernels/src/lib.rs
@@ -17,7 +17,11 @@
 const CONV: &str = include_str!("conv.metal");
 const FILL: &str = include_str!("fill.metal");
 const INDEXING: &str = include_str!("indexing.metal");
 // Current source: https://github.com/ivarflakstad/metal-flash-attention/tree/candle
+#[cfg(not(target_os = "ios"))]
 const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
+// Current source: https://github.com/philipturner/metal-flash-attention/releases/tag/v1.0.1
+#[cfg(target_os = "ios")]
+const MFA: &[u8] = include_bytes!("libMetalFlashAttention.ios.metallib");
 const MLX_GEMM: &str = include_str!("mlx_gemm.metal");
 const QUANTIZED: &str = include_str!("quantized.metal");
 const RANDOM: &str = include_str!("random.metal");
Binary file not shown.
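For context on why this file ships as a `.metallib` blob at all: flash attention is distributed precompiled, and the cfg-selected bytes are turned into a Metal library at runtime rather than compiled from shader source on device. A sketch of that load, assuming the metal-rs `new_library_with_data` API and using an empty placeholder for the embedded bytes:

```rust
use metal::{Device, Library};

// Placeholder for the cfg-selected constant above; a real build embeds the
// target-appropriate .metallib here via include_bytes!.
const MFA: &[u8] = &[];

fn load_mfa(device: &Device) -> Result<Library, String> {
    // Parses a precompiled .metallib from memory; no on-device shader
    // compilation is needed.
    device.new_library_with_data(MFA)
}
```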
23 changes: 14 additions & 9 deletions candle-nn/src/layer_norm.rs
@@ -78,7 +78,7 @@ impl From<f64> for LayerNormConfig {
 #[derive(Clone, Debug)]
 pub struct LayerNorm {
     weight: Tensor,
-    bias: Tensor,
+    bias: Option<Tensor>,
     remove_mean: bool,
     eps: f64,
 }
@@ -87,7 +87,7 @@ impl LayerNorm {
     pub fn new(weight: Tensor, bias: Tensor, eps: f64) -> Self {
         Self {
             weight,
-            bias,
+            bias: Some(bias),
             remove_mean: true,
             eps,
         }
@@ -96,7 +96,7 @@ impl LayerNorm {
     pub fn new_no_bias(weight: Tensor, eps: f64) -> Self {
         Self {
             weight: weight.clone(),
-            bias: Tensor::zeros_like(&weight).unwrap(),
+            bias: None,
             remove_mean: true,
             eps,
         }
@@ -105,7 +105,7 @@ impl LayerNorm {
     pub fn rms_norm(weight: Tensor, eps: f64) -> Self {
         Self {
             weight: weight.clone(),
-            bias: Tensor::zeros_like(&weight).unwrap(),
+            bias: None,
             remove_mean: false,
             eps,
         }
@@ -115,15 +115,17 @@ impl LayerNorm {
         &self.weight
     }
 
-    pub fn bias(&self) -> &Tensor {
-        &self.bias
+    pub fn bias(&self) -> Option<&Tensor> {
+        self.bias.as_ref()
     }
 }
 
 impl Module for LayerNorm {
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
         if x.is_contiguous() && self.remove_mean {
-            return crate::ops::layer_norm(x, &self.weight, &self.bias, self.eps as f32);
+            if let Some(bias) = self.bias.as_ref() {
+                return crate::ops::layer_norm(x, &self.weight, bias, self.eps as f32);
+            }
         }
         let x_dtype = x.dtype();
         let internal_dtype = match x_dtype {
@@ -141,7 +143,10 @@
         let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
         let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
         let x = x_normed.to_dtype(x_dtype)?.broadcast_mul(&self.weight)?;
-        x.broadcast_add(&self.bias)
+        match &self.bias {
+            None => Ok(x),
+            Some(bias) => x.broadcast_add(bias),
+        }
     }
 }
 
@@ -159,7 +164,7 @@ pub fn layer_norm<C: Into<LayerNormConfig>>(
     };
     Ok(LayerNorm {
         weight: weight.clone(),
-        bias: bias.unwrap_or(Tensor::zeros_like(&weight)?),
+        bias,
         remove_mean: config.remove_mean,
         eps: config.eps,
     })
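The API consequence of the `Option<Tensor>` change is easiest to see at a call site: a bias-less `LayerNorm` no longer materializes a zero tensor behind the scenes, and `bias()` now returns `Option<&Tensor>`. A hedged sketch of how downstream code adapts, with made-up tensor shapes and the crate paths as published for candle:

```rust
use candle_core::{Device, DType, Module, Result, Tensor};
use candle_nn::LayerNorm;

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let weight = Tensor::ones(8, DType::F32, &dev)?;
    let x = Tensor::randn(0f32, 1f32, (2, 8), &dev)?;

    // No hidden zeros_like allocation anymore; the bias is simply absent.
    let ln = LayerNorm::new_no_bias(weight, 1e-5);
    assert!(ln.bias().is_none()); // accessor now returns Option<&Tensor>

    let y = ln.forward(&x)?;
    println!("{:?}", y.shape());
    Ok(())
}
```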