diff --git a/candle-core/src/quantized/ggml_file.rs b/candle-core/src/quantized/ggml_file.rs
index 7cde3eaa7..4cbddfb9c 100644
--- a/candle-core/src/quantized/ggml_file.rs
+++ b/candle-core/src/quantized/ggml_file.rs
@@ -192,6 +192,9 @@ pub fn qtensor_from_ggml(
         GgmlDType::QI8 => {
             from_raw_data::<k_quants::BlockQI8>(raw_data, size_in_bytes, dims, device)
         }
+        GgmlDType::Q2b1 => {
+            from_raw_data::<k_quants::BlockQ2b1>(raw_data, size_in_bytes, dims, device)
+        }
         _ => crate::bail!("quantized type {ggml_dtype:?} is not supported yet"),
     }
 }
diff --git a/candle-core/src/quantized/k_quants.rs b/candle-core/src/quantized/k_quants.rs
index e9c237986..5dd959577 100644
--- a/candle-core/src/quantized/k_quants.rs
+++ b/candle-core/src/quantized/k_quants.rs
@@ -20,6 +20,7 @@ pub const QK5_1: usize = 32;
 pub const QK8_0: usize = 32;
 pub const QK8_1: usize = 32;
 pub const Q2B_0: usize = 32;
+pub const Q2B_1: usize = 32;
 pub const QI8: usize = 32;
 
 pub trait GgmlType: Sized + Clone + Send + Sync {
@@ -168,7 +169,14 @@ const _: () = assert!(Q2B_0 / 4 == std::mem::size_of::<BlockQ2b0>());
 
 #[derive(Debug, Clone, PartialEq)]
 #[repr(C)]
-pub struct BlockQI8{
+pub struct BlockQ2b1 {
+    pub(crate) qs: [u8; Q2B_0 / 4], // Every single 2-bit represents {-1, 0, 1}
+}
+const _: () = assert!(Q2B_0 / 4 == std::mem::size_of::<BlockQ2b1>());
+
+#[derive(Debug, Clone, PartialEq)]
+#[repr(C)]
+pub struct BlockQI8 {
     pub(crate) qs: [i8; QI8],
 }
 const _: () = assert!(std::mem::size_of::<BlockQI8>() == QI8);
@@ -1857,7 +1865,6 @@ impl GgmlType for BlockQ8K {
     }
 }
 
-
 impl GgmlType for BlockQ2b0 {
     const DTYPE: GgmlDType = GgmlDType::Q2b0;
     const BLCK_SIZE: usize = Q2B_0;
@@ -1869,7 +1876,7 @@ impl GgmlType for BlockQ2b0 {
         Self::vec_dot_unopt(n, xs, ys)
     }
-    
+
     fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
         if n % Q2B_0 != 0 {
             crate::bail!("vec_dot_q2b0_q8k: {n} is not divisible by {Q2B_0}");
         }
@@ -1881,20 +1888,29 @@ impl GgmlType for BlockQ2b0 {
             let qs = x.qs[i];
             let qd = x.qd[i];
             let mut y_cache = [0i32; 8];
-            y_cache.copy_from_slice(&y.qs[i * 8..(i + 1) * 8].iter().map(|&x| x as i32).collect::<Vec<i32>>()[..]);
-
-            let pos_sum: i32 = (0..8).map(|bit| {
-                let mask = 1 << bit;
-                let is_active = ((qs & mask) >> bit) as i32;
-                is_active * y_cache[bit]
-            }).sum();
-
-            let neg_sum: i32 = (0..8).map(|bit| {
-                let mask = 1 << bit;
-                let is_active = ((qd & mask) >> bit) as i32;
-                is_active * y_cache[bit]
-            }).sum();
-
+            y_cache.copy_from_slice(
+                &y.qs[i * 8..(i + 1) * 8]
+                    .iter()
+                    .map(|&x| x as i32)
+                    .collect::<Vec<i32>>()[..],
+            );
+
+            let pos_sum: i32 = (0..8)
+                .map(|bit| {
+                    let mask = 1 << bit;
+                    let is_active = ((qs & mask) >> bit) as i32;
+                    is_active * y_cache[bit]
+                })
+                .sum();
+
+            let neg_sum: i32 = (0..8)
+                .map(|bit| {
+                    let mask = 1 << bit;
+                    let is_active = ((qd & mask) >> bit) as i32;
+                    is_active * y_cache[bit]
+                })
+                .sum();
+
             isum += pos_sum - neg_sum;
         }
         sumf += isum as f32;
@@ -1904,7 +1920,11 @@ fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> {
         if xs.len() % Q2B_0 != 0 {
-            crate::bail!("quantize_row_q2b0: size mismatch {} not divisible by {}", xs.len(), Q2B_0);
+            crate::bail!(
+                "quantize_row_q2b0: size mismatch {} not divisible by {}",
+                xs.len(),
+                Q2B_0
+            );
         }
 
         for (block, x) in ys.iter_mut().zip(xs.chunks_exact(Q2B_0)) {
@@ -1928,7 +1948,11 @@ fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> {
         if ys.len() % Q2B_0 != 0 {
-            crate::bail!("dequantize_row_q2b0: size mismatch {} not divisible by {}", ys.len(), Q2B_0);
+            crate::bail!(
+                "dequantize_row_q2b0: size mismatch {} not divisible by {}",
+                ys.len(),
+                Q2B_0
+            );
         }
 
         for (block, y) in xs.iter().zip(ys.chunks_exact_mut(Q2B_0)) {
@@ -1951,6 +1975,145 @@ impl GgmlType for BlockQ2b0 {
     }
 }
 
+const fn build_decode_q2b1_lut_i8() -> [[i8; 4]; 256] {
+    let mut table = [[0i8; 4]; 256];
+    let mut i = 0;
+    while i < 256 {
+        let byte = i as u8;
+        let mut dec = [0i8; 4];
+        let mut b = 0;
+        while b < 4 {
+            let code = (byte >> (2 * b)) & 0b11;
+            dec[b as usize] = match code {
+                0b00 => 0,
+                0b01 => 1,
+                0b10 => -1,
+                0b11 => 0,
+                _ => unreachable!(),
+            };
+            b += 1;
+        }
+        table[i] = dec;
+        i += 1;
+    }
+    table
+}
+
+static LUT_DECODE_Q2B1_I8: [[i8; 4]; 256] = build_decode_q2b1_lut_i8();
+const fn build_decode_q2b1_lut_f32() -> [[f32; 4]; 256] {
+    let mut table = [[0.0_f32; 4]; 256];
+    let mut i = 0;
+    while i < 256 {
+        let byte = i as u8;
+        let mut dec = [0.0_f32; 4];
+        let mut b = 0;
+        while b < 4 {
+            let code = (byte >> (2 * b)) & 0b11;
+            dec[b as usize] = match code {
+                0b00 => 0.0,
+                0b01 => 1.0,
+                0b10 => -1.0,
+                0b11 => 0.0,
+                _ => unreachable!(),
+            };
+            b += 1;
+        }
+        table[i] = dec;
+        i += 1;
+    }
+    table
+}
+
+static LUT_DECODE_Q2B1_F32: [[f32; 4]; 256] = build_decode_q2b1_lut_f32();
+impl GgmlType for BlockQ2b1 {
+    const DTYPE: GgmlDType = GgmlDType::Q2b1;
+    const BLCK_SIZE: usize = Q2B_0;
+    type VecDotType = BlockQI8;
+
+    fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
+        #[cfg(target_feature = "neon")]
+        return super::neon::vec_dot_q2b1_qi8(n, xs, ys);
+        Self::vec_dot_unopt(n, xs, ys)
+    }
+
+    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
+        if n % Q2B_0 != 0 {
+            crate::bail!("vec_dot_q2b1_qi8: n = {n} is not divisible by {Q2B_0}");
+        }
+
+        let mut sumf = 0.0;
+
+        for (block_x, block_y) in xs.iter().zip(ys.iter()) {
+            let mut isum = 0i32;
+
+            for i in 0..(Q2B_0 / 4) {
+                let enc_x = block_x.qs[i];
+                let y_slice = &block_y.qs[i * 4..(i + 1) * 4];
+
+                let dec_x = &LUT_DECODE_Q2B1_I8[enc_x as usize];
+
+                for b in 0..4 {
+                    let x_val = dec_x[b] as i32;
+                    let y_val = y_slice[b] as i32;
+                    isum += x_val * y_val;
+                }
+            }
+            sumf += isum as f32;
+        }
+
+        Ok(sumf)
+    }
+
+    fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> {
+        if xs.len() % Q2B_0 != 0 {
+            crate::bail!(
+                "quantize_row_q2b1: size {} is not divisible by {}",
+                xs.len(),
+                Q2B_0
+            );
+        }
+
+        for (block, chunk) in ys.iter_mut().zip(xs.chunks_exact(Q2B_0)) {
+            for (i, subchunk) in chunk.chunks_exact(4).enumerate() {
+                let mut encoded: u8 = 0;
+                for (b, &val) in subchunk.iter().enumerate() {
+                    let bits = if val > 0.0 {
+                        0b01
+                    } else if val < 0.0 {
+                        0b10
+                    } else {
+                        0b00
+                    };
+                    encoded |= bits << (2 * b);
+                }
+                block.qs[i] = encoded;
+            }
+        }
+
+        Ok(())
+    }
+
+    fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> {
+        if ys.len() % Q2B_0 != 0 {
+            crate::bail!(
+                "dequantize_row_q2b1: size {} is not divisible by {}",
+                ys.len(),
+                Q2B_0
+            );
+        }
+
+        for (block, out_chunk) in xs.iter().zip(ys.chunks_exact_mut(Q2B_0)) {
+            for (i, subchunk) in out_chunk.chunks_exact_mut(4).enumerate() {
+                let enc = block.qs[i];
+                let dec = &LUT_DECODE_Q2B1_F32[enc as usize];
+                subchunk.copy_from_slice(dec);
+            }
+        }
+
+        Ok(())
+    }
+}
+
 impl GgmlType for BlockQI8 {
     const DTYPE: GgmlDType = GgmlDType::QI8;
     const BLCK_SIZE: usize = QI8;
@@ -1990,7 +2153,7 @@ impl GgmlType for BlockQI8 {
         for (i, ys) in ys.iter_mut().enumerate() {
             let xs = &xs[i * Self::BLCK_SIZE..(i + 1) * Self::BLCK_SIZE];
             for (y, &x) in ys.qs.iter_mut().zip(xs.iter()) {
-                *y = x as i8; 
+                *y = x as i8;
             }
         }
         Ok(())
diff --git a/candle-core/src/quantized/metal.rs b/candle-core/src/quantized/metal.rs
index 4ed1941ed..426721db8 100644
--- a/candle-core/src/quantized/metal.rs
+++ b/candle-core/src/quantized/metal.rs
@@ -107,6 +107,10 @@ impl QMetalStorage {
                 let vec: Vec<crate::quantized::BlockQ2b0> = read_to_vec(&buffer, block_len);
                 crate::quantized::BlockQ2b0::to_float(&vec, &mut out)?;
             }
+            GgmlDType::Q2b1 => {
+                let vec: Vec<crate::quantized::BlockQ2b1> = read_to_vec(&buffer, block_len);
+                crate::quantized::BlockQ2b1::to_float(&vec, &mut out)?;
+            }
             GgmlDType::QI8 => {
                 let vec: Vec<crate::quantized::BlockQI8> = read_to_vec(&buffer, block_len);
                 crate::quantized::BlockQI8::to_float(&vec, &mut out)?;
@@ -234,6 +238,7 @@ impl From<GgmlDType> for candle_metal_kernels::GgmlDType {
             GgmlDType::F16 => candle_metal_kernels::GgmlDType::F16,
             GgmlDType::F32 => candle_metal_kernels::GgmlDType::F32,
             GgmlDType::Q2b0 => candle_metal_kernels::GgmlDType::Q2b0,
+            GgmlDType::Q2b1 => candle_metal_kernels::GgmlDType::Q2b1,
             GgmlDType::QI8 => todo!(),
         }
     }
diff --git a/candle-core/src/quantized/mod.rs b/candle-core/src/quantized/mod.rs
index 728466e87..4912f0f33 100644
--- a/candle-core/src/quantized/mod.rs
+++ b/candle-core/src/quantized/mod.rs
@@ -147,6 +147,7 @@ pub enum GgmlDType {
     Q6K,
     Q8K,
     Q2b0,
+    Q2b1,
     QI8,
 }
 
@@ -169,6 +170,7 @@ impl GgmlDType {
             15 => Self::Q8K,
             40 => Self::Q2b0,
             41 => Self::QI8,
+            42 => Self::Q2b1,
             _ => crate::bail!("unknown dtype for tensor {u}"),
         };
         Ok(dtype)
@@ -192,6 +194,7 @@ impl GgmlDType {
             Self::Q8K => 15,
             Self::Q2b0 => 40,
             Self::QI8 => 41,
+            Self::Q2b1 => 42,
         }
     }
 
@@ -214,6 +217,7 @@ impl GgmlDType {
             Self::Q8K => Box::new(vec![BlockQ8K::zeros(); elem_count / BlockQ8K::BLCK_SIZE]),
             Self::Q2b0 => Box::new(vec![BlockQ2b0::zeros(); elem_count / BlockQ2b0::BLCK_SIZE]),
             Self::QI8 => Box::new(vec![BlockQI8::zeros(); elem_count / BlockQI8::BLCK_SIZE]),
+            Self::Q2b1 => Box::new(vec![BlockQ2b1::zeros(); elem_count / BlockQ2b1::BLCK_SIZE]),
         }
     }
     /// The type size for blocks in bytes.
@@ -237,6 +241,7 @@ impl GgmlDType {
             Self::Q8K => std::mem::size_of::<BlockQ8K>(),
             Self::Q2b0 => std::mem::size_of::<BlockQ2b0>(),
             Self::QI8 => std::mem::size_of::<BlockQI8>(),
+            Self::Q2b1 => std::mem::size_of::<BlockQ2b1>(),
         }
     }
 
@@ -252,6 +257,7 @@ impl GgmlDType {
             Self::Q8_0 => k_quants::QK8_0,
             Self::Q8_1 => k_quants::QK8_1,
             Self::Q2b0 => k_quants::Q2B_0,
+            Self::Q2b1 => k_quants::Q2B_1,
             Self::QI8 => k_quants::QI8,
             Self::Q2K | Self::Q3K | Self::Q4K | Self::Q5K | Self::Q6K | Self::Q8K => k_quants::QK_K,
         }
diff --git a/candle-core/src/quantized/neon.rs b/candle-core/src/quantized/neon.rs
index 838755594..53de2b900 100644
--- a/candle-core/src/quantized/neon.rs
+++ b/candle-core/src/quantized/neon.rs
@@ -1,6 +1,10 @@
-use super::{k_quants::{
-    BlockQ2K, BlockQ2b0, BlockQ3K, BlockQ4K, BlockQ4_0, BlockQ5K, BlockQ6K, BlockQ8K, BlockQ8_0, Q2B_0, QK8_0, QK_K
-}, BlockQI8};
+use super::{
+    k_quants::{
+        BlockQ2K, BlockQ2b0, BlockQ2b1, BlockQ3K, BlockQ4K, BlockQ4_0, BlockQ5K, BlockQ6K,
+        BlockQ8K, BlockQ8_0, Q2B_0, QK8_0, QK_K,
+    },
+    BlockQI8,
+};
 use crate::Result;
 use byteorder::{ByteOrder, LittleEndian};
 
@@ -11,6 +15,7 @@ use core::arch::arm::*;
 #[allow(unused_imports)]
 #[cfg(target_arch = "aarch64")]
 use core::arch::aarch64::*;
+use std::ptr;
 
 #[inline(always)]
 unsafe fn vdotq_s32(a: int8x16_t, b: int8x16_t) -> int32x4_t {
@@ -575,6 +580,78 @@ pub(crate) fn vec_dot_q2b0_qi8(n: usize, xs: &[BlockQ2b0], ys: &[BlockQI8]) -> c
     Ok(sumf)
 }
 
+static LUT_DECODE_Q2B1_I8: [[i8; 4]; 256] = {
+    const fn build_decode_table() -> [[i8; 4]; 256] {
+        let mut table = [[0i8; 4]; 256];
+        let mut i = 0;
+        while i < 256 {
+            let byte = i as u8;
+            let mut dec = [0i8; 4];
+            let mut b = 0;
+            while b < 4 {
+                let code = (byte >> (2 * b)) & 0b11;
+                dec[b as usize] = match code {
+                    0b00 => 0,
+                    0b01 => 1,
+                    0b10 => -1,
+                    0b11 => 0,
+                    _ => 0,
+                };
+                b += 1;
+            }
+            table[i] = dec;
+            i += 1;
+        }
+        table
+    }
+    build_decode_table()
+};
+
+unsafe fn decode_q2b1_16(input: &[u8]) -> int8x16_t {
+    debug_assert_eq!(input.len(), 4, "input must be 4 bytes long");
+    let mut tmp = [0i8; 16];
+
+    for (i, &byte) in input.iter().enumerate() {
+        let decoded4 = LUT_DECODE_Q2B1_I8[byte as usize];
+        tmp[i * 4..i * 4 + 4].copy_from_slice(&decoded4);
+    }
+
+    vld1q_s8(tmp.as_ptr())
+}
+
+#[inline(always)]
+pub fn vec_dot_q2b1_qi8(n: usize, xs: &[BlockQ2b1], ys: &[BlockQI8]) -> crate::Result<f32> {
+    let blocks = n / 32;
+
+    let mut total_sum = 0i32;
+
+    unsafe {
+        for i in 0..blocks {
+            let x_block = &xs[i];
+            let y_block = &ys[i];
+
+            let x_dec_lo = decode_q2b1_16(&x_block.qs[0..4]);
+            let x_dec_hi = decode_q2b1_16(&x_block.qs[4..8]);
+
+            let y_lo = vld1q_s8(y_block.qs[0..16].as_ptr());
+            let y_hi = vld1q_s8(y_block.qs[16..32].as_ptr());
+
+            let mut acc0 = vdupq_n_s32(0);
+            let mut acc1 = vdupq_n_s32(0);
+
+            acc0 = vaddq_s32(acc0, vdotq_s32(x_dec_lo, y_lo));
+            acc1 = vaddq_s32(acc1, vdotq_s32(x_dec_hi, y_hi));
+
+            let sum0 = vaddvq_s32(acc0);
+            let sum1 = vaddvq_s32(acc1);
+
+            total_sum += sum0 + sum1;
+        }
+    }
+
+    Ok(total_sum as f32)
+}
+
 #[inline(always)]
 pub(crate) fn vec_dot_q2k_q8k(n: usize, xs: &[BlockQ2K], ys: &[BlockQ8K]) -> Result<f32> {
     if n % QK_K != 0 {
diff --git a/candle-examples/examples/quantized-bitnet/main.rs b/candle-examples/examples/quantized-bitnet/main.rs
index 2b024b50b..a40f3a874 100644
--- a/candle-examples/examples/quantized-bitnet/main.rs
+++ b/candle-examples/examples/quantized-bitnet/main.rs
@@ -5,9 +5,9 @@ extern crate intel_mkl_src;
 extern crate accelerate_src;
 
 use clap::{Parser, ValueEnum};
-use tracing_subscriber::fmt::time::FormatTime;
 use std::io::Write;
-use tokenizers::{Tokenizer, AddedToken};
+use tokenizers::{AddedToken, Tokenizer};
+use tracing_subscriber::fmt::time::FormatTime;
 
 use candle::quantized::{ggml_file, gguf_file};
 use candle::Tensor;
@@ -46,7 +46,7 @@ enum Which {
     Llama3_8b1_58,
 }
 
-impl Which { 
+impl Which {
     fn tokenizer_repo(&self) -> &'static str {
         match self {
             Self::Falcon3_1bInstruct1_58 => "nebuxcloud/Falcon3-1B-Instruct-1.58bit-GGUF",
@@ -303,9 +303,7 @@ fn main() -> anyhow::Result<()> {
     let mut pre_prompt_tokens = vec![];
     for prompt_index in 0.. {
         let prompt_str = match &prompt {
-            Prompt::One(prompt) => {
-                prompt.clone()
-            }
+            Prompt::One(prompt) => prompt.clone(),
             Prompt::Interactive | Prompt::Chat => {
                 let is_interactive = matches!(prompt, Prompt::Interactive);
                 print!("> ");
@@ -318,11 +316,11 @@ fn main() -> anyhow::Result<()> {
                         prompt.pop();
                     }
                 }
-
-                prompt
+
+                prompt.clone()
             }
         };
-        
+
         print!("{}", &prompt_str);
         let tokens = tos
             .tokenizer()
@@ -383,13 +381,13 @@ fn main() -> anyhow::Result<()> {
     }
 
     let eos_tokens = match args.which {
-        Which::Falcon3_10b1_58 |
-        Which::Falcon3_10bInstruct1_58 |
-        Which::Falcon3_7bInstruct1_58 |
-        Which::Falcon3_7b1_58 |
-        Which::Falcon3_3bInstruct1_58 |
-        Which::Falcon3_3b1_58 |
-        Which::Falcon3_1bInstruct1_58 => {
+        Which::Falcon3_10b1_58
+        | Which::Falcon3_10bInstruct1_58
+        | Which::Falcon3_7bInstruct1_58
+        | Which::Falcon3_7b1_58
+        | Which::Falcon3_3bInstruct1_58
+        | Which::Falcon3_3b1_58
+        | Which::Falcon3_1bInstruct1_58 => {
             vec!["<|endoftext|>"]
         }
         Which::Llama3_8b1_58 => {
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index ddf0c6bb2..7e683596c 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -2165,6 +2165,7 @@ pub enum GgmlDType {
     F16,
     F32,
     Q2b0,
+    Q2b1,
 }
 
 #[allow(clippy::too_many_arguments)]
@@ -2230,7 +2231,7 @@ pub fn call_quantized_matmul_mv_t(
             let align = 4;
             (nth0, nth1, align)
         }
-        GgmlDType::Q2b0 => {
+        GgmlDType::Q2b1 | GgmlDType::Q2b0 => {
             let nth0 = 8;
             let nth1 = 8;
             let align = 8;
@@ -2287,7 +2288,8 @@ pub fn call_quantized_matmul_mv_t(
         GgmlDType::Q8K => "kernel_mul_mv_q8_K_f32",
         GgmlDType::F16 => "kernel_mul_mv_f16_f32",
         GgmlDType::F32 => "kernel_mul_mv_f32_f32",
-        GgmlDType::Q2b0 => "kernel_mul_mv_q2b0_f32"
+        GgmlDType::Q2b0 => "kernel_mul_mv_q2b0_f32",
+        GgmlDType::Q2b1 => "kernel_mul_mv_q2b1_f32"
     };
     let pipeline = kernels.load_pipeline(device, Source::Quantized, name)?;
 
diff --git a/candle-metal-kernels/src/quantized.metal b/candle-metal-kernels/src/quantized.metal
index 6972d3825..73a374455 100644
--- a/candle-metal-kernels/src/quantized.metal
+++ b/candle-metal-kernels/src/quantized.metal
@@ -48,6 +48,11 @@ typedef struct {
     uint8_t qd[Q2B_0 / 8]; // Every single bit represents negative values, is a vector of {0, 1}
 } block_q2b_0;
 
+#define Q2B_1 32
+typedef struct {
+    uint8_t qs[Q2B_1 / 4]; // Every single 2-bit represents {-1, 0, 1}
+} block_q2b_1;
+
 #define QI8 32
 typedef struct {
     int8_t qs[QI8]; // quants
@@ -3594,6 +3599,153 @@ kernel void kernel_mul_mv_q2b0_f32(
     kernel_mul_mv_q2b0_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
 }
 
+#define NB_Q2B_1 8
+constant float code_lut[4] = { 0.0f, 1.0f, -1.0f, 0.0f };
+
+inline void kernel_mul_mv_q2b1_f32_impl(
+    device const void * src0,
+    device const float * src1,
+    device float * dst,
+    constant int64_t & ne00,
+    constant int64_t & ne01,
+    constant int64_t & ne02,
+    constant int64_t & ne10,
+    constant int64_t & ne12,
+    constant int64_t & ne0,
+    constant int64_t & ne1,
+    constant uint & r2,
+    constant uint & r3,
+    uint3 tgpig [[threadgroup_position_in_grid]],
+    uint tiisg [[thread_index_in_simdgroup]],
+    uint sgitg [[simdgroup_index_in_threadgroup]]
+) {
+    // These come from your headers or #defines
+    const int nr = N_DST;        // number of "rows" each thread processes
+    const int nsg = N_SIMDGROUP; // number of simdgroups per dimension
+    const int nw = N_SIMDWIDTH;  // simd width
+
+    const int nb = ne00 / Q2B_0; // number of Q2B_0 blocks in a row of X
+
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+
+    // Each simdgroup processes 'nr' rows, so figure out which "chunk" we do:
+    const int first_row = (r0 * nsg + sgitg) * nr;
+
+    // Flatten z index using ne12
+    const uint i12 = im % ne12;
+    const uint i13 = im / ne12;
+
+    // Compute offset into src0 (the quantized blocks array)
+    const uint offset0 =
+        first_row * nb
+        + (i12 / r2) * (nb * ne01)
+        + (i13 / r3) * (nb * ne01 * ne02);
+
+    // Pointer to the first quantized block
+    device const block_q2b_1 * x = (device const block_q2b_1 *)src0 + offset0;
+
+    // Pointer to the appropriate row of src1
+    device const float * y = src1
+        + r1 * ne10        // stride in y dimension
+        + im * ne00 * ne1; // stride in z dimension
+
+    // Accumulators for partial sums, one per row
+    float sumf[nr];
+    for (int i = 0; i < nr; i++) {
+        sumf[i] = 0.0f;
+    }
+
+    // Figure out which quarter of the thread is ours
+    const int ix = tiisg / 4;
+    const int il = tiisg % 4;
+
+    // This pointer yb will move through src1 in steps of NB_Q2B_1*nw
+    device const float * yb = y + ix * Q2B_0 + NB_Q2B_1 * il;
+
+    // Main loop: each thread processes some subset of 'nb' blocks
+    for (int ib = ix; ib < nb; ib += (nw / 4)) {
+
+        // Load 8 floats (NB_Q2B_0) into local array to keep them in registers
+        float yl[NB_Q2B_0];
+        {
+            // Compiler usually unrolls such a small loop automatically
+            // but you can force it:
+            #pragma unroll 8
+            for (int i = 0; i < NB_Q2B_0; i++) {
+                yl[i] = yb[i];
+            }
+        }
+
+        // For each row in [0..nr), compute partial dot-product
+        // with quantized data from 'x + ib + row * nb'
+        for (int row = 0; row < nr; row++) {
+            device const block_q2b_1 * bq = x + ib + row * nb;
+
+            float sumq = 0.0f;
+
+            // Each Q2B_0 = 8 bits, but we do them in steps of 2
+            // 'startBit' is the first bit for the code.
+            // We unroll this loop as well.
+            const int startBit = NB_Q2B_1 * il;
+            #pragma unroll 8
+            for (int iBit = 0; iBit < NB_Q2B_0; iBit++) {
+                const int bit = startBit + iBit;
+                const int bByte = bit >> 2;      // bit / 4
+                const int shift = 2 * (bit & 3); // (bit % 4)*2
+                const int code = (bq->qs[bByte] >> shift) & 0x3;
+
+                // Use the LUT to get +1 / -1 / 0
+                sumq += code_lut[code] * yl[iBit];
+            }
+
+            sumf[row] += sumq;
+        }
+
+        // Advance yb to the next group of 8 floats
+        yb += NB_Q2B_1 * nw;
+    }
+
+    // Reduction across the simdgroup: each row's sum -> simd_sum(...)
+    // Then store to output if we're the "first lane" (tiisg == 0)
+    for (int row = 0; row < nr; row++) {
+        const float tot = simd_sum(sumf[row]);
+        if (tiisg == 0 && (first_row + row) < ne01) {
+            dst[r1 * ne0 + im * ne0 * ne1 + (first_row + row)] = tot;
+        }
+    }
+}
+
+
+[[host_name("kernel_mul_mv_q2b1_f32")]]
+kernel void kernel_mul_mv_q2b1_f32(
+    device const void * src0,
+    device const float * src1,
+    device float * dst,
+    constant int64_t & ne00,
+    constant int64_t & ne01,
+    constant int64_t & ne02,
+    constant uint64_t & nb00,
+    constant uint64_t & nb01,
+    constant uint64_t & nb02,
+    constant int64_t & ne10,
+    constant int64_t & ne11,
+    constant int64_t & ne12,
+    constant uint64_t & nb10,
+    constant uint64_t & nb11,
+    constant uint64_t & nb12,
+    constant int64_t & ne0,
+    constant int64_t & ne1,
+    constant uint & r2,
+    constant uint & r3,
+    uint3 tgpig[[threadgroup_position_in_grid]],
+    uint tiisg[[thread_index_in_simdgroup]],
+    uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_q2b1_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+}
+
 //============================= templates and their specializations =============================
 
 // NOTE: this is not dequantizing - we are simply fitting the template
diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs
index 34fc8b57e..0fca20e2d 100644
--- a/candle-transformers/src/models/mod.rs
+++ b/candle-transformers/src/models/mod.rs
@@ -75,6 +75,7 @@ pub mod quantized_blip;
 pub mod quantized_blip_text;
 pub mod quantized_llama;
 pub mod quantized_llama2_c;
+pub mod quantized_llama_bitnet;
 pub mod quantized_metavoice;
 pub mod quantized_mistral;
 pub mod quantized_mixformer;
@@ -110,4 +111,3 @@ pub mod whisper;
 pub mod with_tracing;
 pub mod wuerstchen;
 pub mod yi;
-pub mod quantized_llama_bitnet;
\ No newline at end of file
diff --git a/candle-transformers/src/models/quantized_llama_bitnet.rs b/candle-transformers/src/models/quantized_llama_bitnet.rs
index bfc3113ce..ed745f622 100644
--- a/candle-transformers/src/models/quantized_llama_bitnet.rs
+++ b/candle-transformers/src/models/quantized_llama_bitnet.rs
@@ -59,15 +59,22 @@ impl BitQMatMul {
         let inner = candle::quantized::QMatMul::from_qtensor(qtensor)?;
         let span = tracing::span!(tracing::Level::TRACE, "qmatmul");
         let weight_scale = weight_scale.dequantize(&weight_scale.device())?;
-        Ok(Self { inner, span, weight_scale })
+        Ok(Self {
+            inner,
+            span,
+            weight_scale,
+        })
     }
 
     fn activation_quant(&self, x: &Tensor) -> Result<(Tensor, Tensor)> {
-        let scale = x.abs()?.max_keepdim(D::Minus1)?.clamp(1e-5, f32::INFINITY)?;
+        let scale = x
+            .abs()?
+            .max_keepdim(D::Minus1)?
+            .clamp(1e-5, f32::INFINITY)?;
         let scale = (127.0 / scale)?;
         let y = (x.broadcast_mul(&scale))?.round()?.clamp(-128., 127.)?;
-        
+
         Ok((y, scale))
     }
 
@@ -75,12 +82,10 @@ impl BitQMatMul {
         let (x, xscale) = self.activation_quant(x)?;
         let _enter = self.span.enter();
         let scale = self.weight_scale.broadcast_mul(&xscale)?;
-        self.inner.forward(&x)?
-            .broadcast_div(&scale)
+        self.inner.forward(&x)?.broadcast_div(&scale)
     }
 }
 
-
 #[derive(Debug, Clone)]
 struct Mlp {
     feed_forward_w1: BitQMatMul,
@@ -337,11 +342,14 @@ impl ModelWeights {
             let attention_wo_ws = ct.remove(&format!("{prefix}.attention.wo.weight_scale"))?;
             let mlp_or_moe = {
                 let feed_forward_w1 = ct.remove(&format!("{prefix}.feed_forward.w1.weight"))?;
-                let feed_forward_w1_ws = ct.remove(&format!("{prefix}.feed_forward.w1.weight_scale"))?;
+                let feed_forward_w1_ws =
+                    ct.remove(&format!("{prefix}.feed_forward.w1.weight_scale"))?;
                 let feed_forward_w2 = ct.remove(&format!("{prefix}.feed_forward.w2.weight"))?;
-                let feed_forward_w2_ws = ct.remove(&format!("{prefix}.feed_forward.w2.weight_scale"))?;
+                let feed_forward_w2_ws =
+                    ct.remove(&format!("{prefix}.feed_forward.w2.weight_scale"))?;
                 let feed_forward_w3 = ct.remove(&format!("{prefix}.feed_forward.w3.weight"))?;
-                let feed_forward_w3_ws = ct.remove(&format!("{prefix}.feed_forward.w3.weight_scale"))?;
+                let feed_forward_w3_ws =
+                    ct.remove(&format!("{prefix}.feed_forward.w3.weight_scale"))?;
                 MlpOrMoe::Mlp(Mlp {
                     feed_forward_w1: BitQMatMul::from_qtensor(feed_forward_w1, feed_forward_w1_ws)?,
                     feed_forward_w2: BitQMatMul::from_qtensor(feed_forward_w2, feed_forward_w2_ws)?,
@@ -431,13 +439,21 @@ impl ModelWeights {
         for layer_idx in 0..block_count {
             let prefix = format!("blk.{layer_idx}");
             let attention_wq = ct.tensor(reader, &format!("{prefix}.attn_q.weight"), device)?;
-            let attention_wq_ws = ct.tensor(reader, &format!("{prefix}.attn_q.weight_scale"), device)?;
+            let attention_wq_ws =
+                ct.tensor(reader, &format!("{prefix}.attn_q.weight_scale"), device)?;
             let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"), device)?;
-            let attention_wk_ws = ct.tensor(reader, &format!("{prefix}.attn_k.weight_scale"), device)?;
+            let attention_wk_ws =
+                ct.tensor(reader, &format!("{prefix}.attn_k.weight_scale"), device)?;
             let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"), device)?;
-            let attention_wv_ws = ct.tensor(reader, &format!("{prefix}.attn_v.weight_scale"), device)?;
-            let attention_wo = ct.tensor(reader, &format!("{prefix}.attn_output.weight"), device)?;
-            let attention_wo_ws = ct.tensor(reader, &format!("{prefix}.attn_output.weight_scale"), device)?;
+            let attention_wv_ws =
+                ct.tensor(reader, &format!("{prefix}.attn_v.weight_scale"), device)?;
+            let attention_wo =
+                ct.tensor(reader, &format!("{prefix}.attn_output.weight"), device)?;
+            let attention_wo_ws = ct.tensor(
+                reader,
+                &format!("{prefix}.attn_output.weight_scale"),
+                device,
+            )?;
             let mlp_or_moe = if n_expert <= 1 {
                 let feed_forward_w1 =
                     ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"), device)?;
@@ -463,21 +479,36 @@ impl ModelWeights {
                 for i in 0..n_expert {
                     let feed_forward_w1 =
                         ct.tensor(reader, &format!("{prefix}.ffn_gate.{i}.weight"), device)?;
-                    let feed_forward_w1_ws =
-                        ct.tensor(reader, &format!("{prefix}.ffn_gate.{i}.weight_scale"), device)?;
+                    let feed_forward_w1_ws = ct.tensor(
+                        reader,
+                        &format!("{prefix}.ffn_gate.{i}.weight_scale"),
+                        device,
+                    )?;
                     let feed_forward_w2 =
                         ct.tensor(reader, &format!("{prefix}.ffn_down.{i}.weight"), device)?;
-                    let feed_forward_w2_ws =
-                        ct.tensor(reader, &format!("{prefix}.ffn_down.{i}.weight_scale"), device)?;
+                    let feed_forward_w2_ws = ct.tensor(
+                        reader,
+                        &format!("{prefix}.ffn_down.{i}.weight_scale"),
+                        device,
+                    )?;
                     let feed_forward_w3 =
                         ct.tensor(reader, &format!("{prefix}.ffn_up.{i}.weight"), device)?;
                     let feed_forward_w3_ws =
                         ct.tensor(reader, &format!("{prefix}.ffn_up.{i}.weight_scale"), device)?;
&format!("{prefix}.ffn_up.{i}.weight_scale"), device)?; - + experts.push(Mlp { - feed_forward_w1: BitQMatMul::from_qtensor(feed_forward_w1, feed_forward_w1_ws)?, - feed_forward_w2: BitQMatMul::from_qtensor(feed_forward_w2, feed_forward_w2_ws)?, - feed_forward_w3: BitQMatMul::from_qtensor(feed_forward_w3, feed_forward_w3_ws)?, + feed_forward_w1: BitQMatMul::from_qtensor( + feed_forward_w1, + feed_forward_w1_ws, + )?, + feed_forward_w2: BitQMatMul::from_qtensor( + feed_forward_w2, + feed_forward_w2_ws, + )?, + feed_forward_w3: BitQMatMul::from_qtensor( + feed_forward_w3, + feed_forward_w3_ws, + )?, }) } MlpOrMoe::MoE { @@ -567,4 +598,4 @@ impl ModelWeights { let _enter = self.span_output.enter(); self.output.forward(&x) } -} \ No newline at end of file +} diff --git a/tensor-tools/src/main.rs b/tensor-tools/src/main.rs index 4a1e9bc03..78ba68840 100644 --- a/tensor-tools/src/main.rs +++ b/tensor-tools/src/main.rs @@ -14,7 +14,13 @@ enum QuantizationMode { } impl QuantizationMode { - fn quantize(&self, name: &str, tensor: QTensor, dtype: GgmlDType, bitnet_mode: bool) -> Result { + fn quantize( + &self, + name: &str, + tensor: QTensor, + dtype: GgmlDType, + bitnet_mode: bool, + ) -> Result { match self { Self::Llama => { // Same behavior as the llama.cpp quantization. @@ -50,6 +56,8 @@ enum Quantization { Q8_1, #[value(name = "q2b0")] Q2b0, + #[value(name = "q2b1")] + Q2b1, #[value(name = "qi8")] QI8, Q2k, @@ -81,6 +89,7 @@ impl Quantization { Quantization::F32 => GgmlDType::F32, Quantization::Q2b0 => GgmlDType::Q2b0, Quantization::QI8 => GgmlDType::QI8, + Quantization::Q2b1 => GgmlDType::Q2b1, } } } @@ -441,11 +450,11 @@ fn unpack_bitnet_weights(tensor: &Tensor) -> Result { } use core::num; +use rayon::prelude::*; +use serde_json::Value; use std::collections::HashMap; use std::fs::File; use std::path::PathBuf; -use rayon::prelude::*; -use serde_json::Value; fn permute(weights: &Tensor, n_head: usize, n_head_kv: Option) -> Result { let n_head = match n_head_kv { @@ -464,7 +473,7 @@ fn permute(weights: &Tensor, n_head: usize, n_head_kv: Option) -> Result< let permuted = weights .reshape(new_shape)? - .transpose(1, 2)? + .transpose(1, 2)? 
         .reshape(weights.shape())?;
 
     Ok(permuted)
@@ -481,7 +490,9 @@ fn run_quantize_safetensors(
     let dtype = q.dtype();
     let block_size = dtype.block_size();
 
-    let metadata_file = in_files.iter().find(|f| f.to_string_lossy().ends_with("config.json"));
+    let metadata_file = in_files
+        .iter()
+        .find(|f| f.to_string_lossy().ends_with("config.json"));
 
     let mut qtensors = Vec::new();
 
@@ -489,7 +500,6 @@ fn run_quantize_safetensors(
     let mut num_key_value_heads = 0;
     let mut architecture = String::new();
 
-
     let gguf_metadata = if let Some(metadata_file) = metadata_file {
         let metadata_content = std::fs::read_to_string(metadata_file)?;
         let metadata: serde_json::Value = serde_json::from_str(&metadata_content).unwrap();
@@ -499,16 +509,41 @@ fn run_quantize_safetensors(
         architecture = metadata["model_type"].as_str().unwrap().to_string();
 
         vec![
-            ("llama.attention.head_count", gguf_file::Value::from_u32(num_attention_heads as u32)),
-            ("llama.attention.head_count_kv", gguf_file::Value::from_u32(metadata["num_key_value_heads"].as_u64().unwrap() as u32)),
-            ("llama.block_count", gguf_file::Value::from_u32(metadata["num_hidden_layers"].as_u64().unwrap() as u32)),
-            ("llama.embedding_length", gguf_file::Value::from_u32(metadata["hidden_size"].as_u64().unwrap() as u32)),
-            ("llama.attention.layer_norm_rms_epsilon", gguf_file::Value::from_f32(metadata["rms_norm_eps"].as_f64().unwrap() as f32)),
-            ("llama.rope.dimension_count", gguf_file::Value::from_u32(
-                (metadata["hidden_size"].as_u64().unwrap() as u32) / (metadata["num_attention_heads"].as_u64().unwrap() as u32),
-            )),
-            ("llama.rope.freq_base", gguf_file::Value::from_f32(metadata["rope_theta"].as_f64().unwrap() as f32)),
-            ("general.architecture", gguf_file::Value::from_string(architecture.clone())),
+            (
+                "llama.attention.head_count",
+                gguf_file::Value::from_u32(num_attention_heads as u32),
+            ),
+            (
+                "llama.attention.head_count_kv",
+                gguf_file::Value::from_u32(metadata["num_key_value_heads"].as_u64().unwrap() as u32),
+            ),
+            (
+                "llama.block_count",
+                gguf_file::Value::from_u32(metadata["num_hidden_layers"].as_u64().unwrap() as u32),
+            ),
+            (
+                "llama.embedding_length",
+                gguf_file::Value::from_u32(metadata["hidden_size"].as_u64().unwrap() as u32),
+            ),
+            (
+                "llama.attention.layer_norm_rms_epsilon",
+                gguf_file::Value::from_f32(metadata["rms_norm_eps"].as_f64().unwrap() as f32),
+            ),
+            (
+                "llama.rope.dimension_count",
+                gguf_file::Value::from_u32(
+                    (metadata["hidden_size"].as_u64().unwrap() as u32)
+                        / (metadata["num_attention_heads"].as_u64().unwrap() as u32),
+                ),
+            ),
+            (
+                "llama.rope.freq_base",
+                gguf_file::Value::from_f32(metadata["rope_theta"].as_f64().unwrap() as f32),
+            ),
+            (
+                "general.architecture",
+                gguf_file::Value::from_string(architecture.clone()),
+            ),
         ]
     } else {
         vec![]
@@ -531,14 +566,13 @@ fn run_quantize_safetensors(
         let mut tensor = tensor;
 
         if should_quantize && bitnet_mode {
-            let is_bitnet_weight =
-                name.contains("self_attn.v_proj") ||
-                name.contains("self_attn.q_proj") ||
-                name.contains("self_attn.o_proj") ||
-                name.contains("self_attn.k_proj") ||
-                name.contains("mlp.down_proj") ||
-                name.contains("mlp.up_proj") ||
-                name.contains("mlp.gate_proj");
+            let is_bitnet_weight = name.contains("self_attn.v_proj")
+                || name.contains("self_attn.q_proj")
+                || name.contains("self_attn.o_proj")
+                || name.contains("self_attn.k_proj")
+                || name.contains("mlp.down_proj")
+                || name.contains("mlp.up_proj")
+                || name.contains("mlp.gate_proj");
 
             if is_bitnet_weight {
                 println!(" unpacking {name} {tensor:?} {should_quantize}");
@@ -555,10 +589,18 @@ fn run_quantize_safetensors(
         match architecture.as_str() {
             "llama" => {
                 if name.ends_with("self_attn.q_proj.weight") {
-                    tensor = permute(&tensor, num_attention_heads as usize, Some(num_attention_heads as usize))?;
+                    tensor = permute(
+                        &tensor,
+                        num_attention_heads as usize,
+                        Some(num_attention_heads as usize),
+                    )?;
                 }
                 if name.ends_with("self_attn.k_proj.weight") {
-                    tensor = permute(&tensor, num_attention_heads as usize, Some(num_key_value_heads as usize))?;
+                    tensor = permute(
+                        &tensor,
+                        num_attention_heads as usize,
+                        Some(num_key_value_heads as usize),
+                    )?;
                 }
             }
             _ => {}
@@ -716,7 +758,15 @@ fn main() -> anyhow::Result<()> {
             bitnet_quantization,
             mode,
             bitnet_mode,
-        } => run_quantize(&in_file, out_file, quantization, mode, bitnet_quantization, bitnet_mode, &device)?,
+        } => run_quantize(
+            &in_file,
+            out_file,
+            quantization,
+            mode,
+            bitnet_quantization,
+            bitnet_mode,
+            &device,
+        )?,
         Command::Dequantize { in_file, out_file } => run_dequantize(in_file, out_file, &device)?,
     }
     Ok(())
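
For reference, the q2b1 format added in this patch packs four ternary weights into each byte, two bits per value: 0b01 encodes +1, 0b10 encodes -1, and 0b00/0b11 both decode to 0, so a 32-element block occupies 8 bytes with no per-block scale. The standalone Rust sketch below mirrors the per-byte encode/decode rules used by `BlockQ2b1::from_float` / `to_float`; the helper names (`encode_q2b1`, `decode_q2b1`) are illustrative only and are not part of the patch.

    // Standalone illustration of the q2b1 packing rules used in this patch.
    // Encode: sign(x) -> 2-bit code (01 = +1, 10 = -1, 00 = 0), four codes per byte.
    fn encode_q2b1(vals: &[f32; 4]) -> u8 {
        let mut byte = 0u8;
        for (i, &v) in vals.iter().enumerate() {
            let code: u8 = if v > 0.0 {
                0b01
            } else if v < 0.0 {
                0b10
            } else {
                0b00
            };
            byte |= code << (2 * i);
        }
        byte
    }

    // Decode: 2-bit code -> {-1.0, 0.0, 1.0}; the unused 0b11 code also maps to 0.
    fn decode_q2b1(byte: u8) -> [f32; 4] {
        let mut out = [0.0f32; 4];
        for i in 0..4 {
            out[i] = match (byte >> (2 * i)) & 0b11 {
                0b01 => 1.0,
                0b10 => -1.0,
                _ => 0.0,
            };
        }
        out
    }

    fn main() {
        // Four ternary-ish weights round-trip through one packed byte.
        let weights: [f32; 4] = [0.7, -1.2, 0.0, -0.3];
        let packed = encode_q2b1(&weights);
        assert_eq!(decode_q2b1(packed), [1.0, -1.0, 0.0, -1.0]);
        println!("packed byte: {packed:#010b}");
    }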