Skip to content

Commit

Permalink
Q2B1: Add new quant with optimized performance
Browse files Browse the repository at this point in the history
  • Loading branch information
JoseCarlosGarcia95 committed Jan 1, 2025
1 parent 49d44b5 commit eefc336
Show file tree
Hide file tree
Showing 11 changed files with 579 additions and 92 deletions.
3 changes: 3 additions & 0 deletions candle-core/src/quantized/ggml_file.rs
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,9 @@ pub fn qtensor_from_ggml(
GgmlDType::QI8 => {
from_raw_data::<k_quants::BlockQI8>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q2b1 => {
from_raw_data::<k_quants::BlockQ2b1>(raw_data, size_in_bytes, dims, device)
}
_ => crate::bail!("quantized type {ggml_dtype:?} is not supported yet"),
}
}
Expand Down
203 changes: 183 additions & 20 deletions candle-core/src/quantized/k_quants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ pub const QK5_1: usize = 32;
pub const QK8_0: usize = 32;
pub const QK8_1: usize = 32;
pub const Q2B_0: usize = 32;
pub const Q2B_1: usize = 32;
pub const QI8: usize = 32;

pub trait GgmlType: Sized + Clone + Send + Sync {
Expand Down Expand Up @@ -168,7 +169,14 @@ const _: () = assert!(Q2B_0 / 4 == std::mem::size_of::<BlockQ2b0>());

#[derive(Debug, Clone, PartialEq)]
#[repr(C)]
pub struct BlockQI8{
pub struct BlockQ2b1 {
pub(crate) qs: [u8; Q2B_0 / 4], // Every single 2-bit represents {-1, 0, 1}
}
const _: () = assert!(Q2B_0 / 4 == std::mem::size_of::<BlockQ2b1>());

#[derive(Debug, Clone, PartialEq)]
#[repr(C)]
pub struct BlockQI8 {
pub(crate) qs: [i8; QI8],
}
const _: () = assert!(std::mem::size_of::<BlockQI8>() == QI8);
Expand Down Expand Up @@ -1857,7 +1865,6 @@ impl GgmlType for BlockQ8K {
}
}


impl GgmlType for BlockQ2b0 {
const DTYPE: GgmlDType = GgmlDType::Q2b0;
const BLCK_SIZE: usize = Q2B_0;
Expand All @@ -1869,7 +1876,7 @@ impl GgmlType for BlockQ2b0 {

Self::vec_dot_unopt(n, xs, ys)
}

fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
if n % Q2B_0 != 0 {
crate::bail!("vec_dot_q2b0_q8k: {n} is not divisible by {Q2B_0}");
Expand All @@ -1881,20 +1888,29 @@ impl GgmlType for BlockQ2b0 {
let qs = x.qs[i];
let qd = x.qd[i];
let mut y_cache = [0i32; 8];
y_cache.copy_from_slice(&y.qs[i * 8..(i + 1) * 8].iter().map(|&x| x as i32).collect::<Vec<_>>()[..]);

let pos_sum: i32 = (0..8).map(|bit| {
let mask = 1 << bit;
let is_active = ((qs & mask) >> bit) as i32;
is_active * y_cache[bit]
}).sum();

let neg_sum: i32 = (0..8).map(|bit| {
let mask = 1 << bit;
let is_active = ((qd & mask) >> bit) as i32;
is_active * y_cache[bit]
}).sum();

y_cache.copy_from_slice(
&y.qs[i * 8..(i + 1) * 8]
.iter()
.map(|&x| x as i32)
.collect::<Vec<_>>()[..],
);

let pos_sum: i32 = (0..8)
.map(|bit| {
let mask = 1 << bit;
let is_active = ((qs & mask) >> bit) as i32;
is_active * y_cache[bit]
})
.sum();

let neg_sum: i32 = (0..8)
.map(|bit| {
let mask = 1 << bit;
let is_active = ((qd & mask) >> bit) as i32;
is_active * y_cache[bit]
})
.sum();

isum += pos_sum - neg_sum;
}
sumf += isum as f32;
Expand All @@ -1904,7 +1920,11 @@ impl GgmlType for BlockQ2b0 {

fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> {
if xs.len() % Q2B_0 != 0 {
crate::bail!("quantize_row_q2b0: size mismatch {} not divisible by {}", xs.len(), Q2B_0);
crate::bail!(
"quantize_row_q2b0: size mismatch {} not divisible by {}",
xs.len(),
Q2B_0
);
}

for (block, x) in ys.iter_mut().zip(xs.chunks_exact(Q2B_0)) {
Expand All @@ -1928,7 +1948,11 @@ impl GgmlType for BlockQ2b0 {

fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> {
if ys.len() % Q2B_0 != 0 {
crate::bail!("dequantize_row_q2b0: size mismatch {} not divisible by {}", ys.len(), Q2B_0);
crate::bail!(
"dequantize_row_q2b0: size mismatch {} not divisible by {}",
ys.len(),
Q2B_0
);
}

for (block, y) in xs.iter().zip(ys.chunks_exact_mut(Q2B_0)) {
Expand All @@ -1951,6 +1975,145 @@ impl GgmlType for BlockQ2b0 {
}
}

const fn build_decode_q2b1_lut_i8() -> [[i8; 4]; 256] {
let mut table = [[0i8; 4]; 256];
let mut i = 0;
while i < 256 {
let byte = i as u8;
let mut dec = [0i8; 4];
let mut b = 0;
while b < 4 {
let code = (byte >> (2 * b)) & 0b11;
dec[b as usize] = match code {
0b00 => 0,
0b01 => 1,
0b10 => -1,
0b11 => 0,
_ => unreachable!(),
};
b += 1;
}
table[i] = dec;
i += 1;
}
table
}

static LUT_DECODE_Q2B1_I8: [[i8; 4]; 256] = build_decode_q2b1_lut_i8();
const fn build_decode_q2b1_lut_f32() -> [[f32; 4]; 256] {
let mut table = [[0.0_f32; 4]; 256];
let mut i = 0;
while i < 256 {
let byte = i as u8;
let mut dec = [0.0_f32; 4];
let mut b = 0;
while b < 4 {
let code = (byte >> (2 * b)) & 0b11;
dec[b as usize] = match code {
0b00 => 0.0,
0b01 => 1.0,
0b10 => -1.0,
0b11 => 0.0,
_ => unreachable!(),
};
b += 1;
}
table[i] = dec;
i += 1;
}
table
}

static LUT_DECODE_Q2B1_F32: [[f32; 4]; 256] = build_decode_q2b1_lut_f32();
impl GgmlType for BlockQ2b1 {
const DTYPE: GgmlDType = GgmlDType::Q2b1;
const BLCK_SIZE: usize = Q2B_0;
type VecDotType = BlockQI8;

fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
#[cfg(target_feature = "neon")]
return super::neon::vec_dot_q2b1_qi8(n, xs, ys);
Self::vec_dot_unopt(n, xs, ys)
}

fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
if n % Q2B_0 != 0 {
crate::bail!("vec_dot_q2b1_qi8: n = {n} is not divisible by {Q2B_0}");
}

let mut sumf = 0.0;

for (block_x, block_y) in xs.iter().zip(ys.iter()) {
let mut isum = 0i32;

for i in 0..(Q2B_0 / 4) {
let enc_x = block_x.qs[i];
let y_slice = &block_y.qs[i * 4..(i + 1) * 4];

let dec_x = &LUT_DECODE_Q2B1_I8[enc_x as usize];

for b in 0..4 {
let x_val = dec_x[b] as i32;
let y_val = y_slice[b] as i32;
isum += x_val * y_val;
}
}
sumf += isum as f32;
}

Ok(sumf)
}

fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> {
if xs.len() % Q2B_0 != 0 {
crate::bail!(
"quantize_row_q2b1: size {} is not divisible by {}",
xs.len(),
Q2B_0
);
}

for (block, chunk) in ys.iter_mut().zip(xs.chunks_exact(Q2B_0)) {
for (i, subchunk) in chunk.chunks_exact(4).enumerate() {
let mut encoded: u8 = 0;
for (b, &val) in subchunk.iter().enumerate() {
let bits = if val > 0.0 {
0b01
} else if val < 0.0 {
0b10
} else {
0b00
};
encoded |= bits << (2 * b);
}
block.qs[i] = encoded;
}
}

Ok(())
}

fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> {
if ys.len() % Q2B_0 != 0 {
crate::bail!(
"dequantize_row_q2b1: size {} is not divisible by {}",
ys.len(),
Q2B_0
);
}

for (block, out_chunk) in xs.iter().zip(ys.chunks_exact_mut(Q2B_0)) {
for (i, subchunk) in out_chunk.chunks_exact_mut(4).enumerate() {
let enc = block.qs[i];
let dec = &LUT_DECODE_Q2B1_F32[enc as usize];
subchunk.copy_from_slice(dec);
}
}

Ok(())
}
}

impl GgmlType for BlockQI8 {
const DTYPE: GgmlDType = GgmlDType::QI8;
const BLCK_SIZE: usize = QI8;
Expand Down Expand Up @@ -1990,7 +2153,7 @@ impl GgmlType for BlockQI8 {
for (i, ys) in ys.iter_mut().enumerate() {
let xs = &xs[i * Self::BLCK_SIZE..(i + 1) * Self::BLCK_SIZE];
for (y, &x) in ys.qs.iter_mut().zip(xs.iter()) {
*y = x as i8;
*y = x as i8;
}
}
Ok(())
Expand Down
5 changes: 5 additions & 0 deletions candle-core/src/quantized/metal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,10 @@ impl QMetalStorage {
let vec: Vec<crate::quantized::BlockQ2b0> = read_to_vec(&buffer, block_len);
crate::quantized::BlockQ2b0::to_float(&vec, &mut out)?;
}
GgmlDType::Q2b1 => {
let vec: Vec<crate::quantized::BlockQ2b1> = read_to_vec(&buffer, block_len);
crate::quantized::BlockQ2b1::to_float(&vec, &mut out)?;
}
GgmlDType::QI8 => {
let vec: Vec<crate::quantized::BlockQI8> = read_to_vec(&buffer, block_len);
crate::quantized::BlockQI8::to_float(&vec, &mut out)?;
Expand Down Expand Up @@ -234,6 +238,7 @@ impl From<GgmlDType> for candle_metal_kernels::GgmlDType {
GgmlDType::F16 => candle_metal_kernels::GgmlDType::F16,
GgmlDType::F32 => candle_metal_kernels::GgmlDType::F32,
GgmlDType::Q2b0 => candle_metal_kernels::GgmlDType::Q2b0,
GgmlDType::Q2b1 => candle_metal_kernels::GgmlDType::Q2b1,
GgmlDType::QI8 => todo!(),
}
}
Expand Down
6 changes: 6 additions & 0 deletions candle-core/src/quantized/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ pub enum GgmlDType {
Q6K,
Q8K,
Q2b0,
Q2b1,
QI8,
}

Expand All @@ -169,6 +170,7 @@ impl GgmlDType {
15 => Self::Q8K,
40 => Self::Q2b0,
41 => Self::QI8,
42 => Self::Q2b1,
_ => crate::bail!("unknown dtype for tensor {u}"),
};
Ok(dtype)
Expand All @@ -192,6 +194,7 @@ impl GgmlDType {
Self::Q8K => 15,
Self::Q2b0 => 40,
Self::QI8 => 41,
Self::Q2b1 => 42,
}
}

Expand All @@ -214,6 +217,7 @@ impl GgmlDType {
Self::Q8K => Box::new(vec![BlockQ8K::zeros(); elem_count / BlockQ8K::BLCK_SIZE]),
Self::Q2b0 => Box::new(vec![BlockQ2b0::zeros(); elem_count / BlockQ2b0::BLCK_SIZE]),
Self::QI8 => Box::new(vec![BlockQI8::zeros(); elem_count / BlockQI8::BLCK_SIZE]),
Self::Q2b1 => Box::new(vec![BlockQ2b1::zeros(); elem_count / BlockQ2b1::BLCK_SIZE]),
}
}
/// The type size for blocks in bytes.
Expand All @@ -237,6 +241,7 @@ impl GgmlDType {
Self::Q8K => std::mem::size_of::<BlockQ8K>(),
Self::Q2b0 => std::mem::size_of::<BlockQ2b0>(),
Self::QI8 => std::mem::size_of::<BlockQI8>(),
Self::Q2b1 => std::mem::size_of::<BlockQ2b1>(),
}
}

Expand All @@ -252,6 +257,7 @@ impl GgmlDType {
Self::Q8_0 => k_quants::QK8_0,
Self::Q8_1 => k_quants::QK8_1,
Self::Q2b0 => k_quants::Q2B_0,
Self::Q2b1 => k_quants::Q2B_1,
Self::QI8 => k_quants::QI8,
Self::Q2K | Self::Q3K | Self::Q4K | Self::Q5K | Self::Q6K | Self::Q8K => k_quants::QK_K,
}
Expand Down
Loading

0 comments on commit eefc336

Please sign in to comment.