Skip to content

Commit

Permalink
Mistral.rs GPTQ dev PR (#14)
Browse files Browse the repository at this point in the history
* Add i32 dtype for cpu and cuda, with kernels

* Fix cuda i32

* Fix cpu i32

* Add cuda map impls for i32

* Start to add to metal

* Add the kernels

* Oops

* Fix dtype cast in safetensors

* Oops

* Oops

* Add bf16 to i32 and vice versa casts
  • Loading branch information
EricLBuehler authored Aug 9, 2024
1 parent 412e9f4 commit 7ad6494
Show file tree
Hide file tree
Showing 35 changed files with 548 additions and 18 deletions.
4 changes: 3 additions & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,7 @@
],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true,
"rust-analyzer.cargo.features": ["cuda"]
"rust-analyzer.cargo.features": [
"cuda"
],
}
5 changes: 5 additions & 0 deletions candle-core/src/convert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,11 @@ impl Tensor {
f.write_u32::<LittleEndian>(v)?
}
}
DType::I32 => {
for v in vs.to_vec1::<i32>()? {
f.write_i32::<LittleEndian>(v)?
}
}
DType::I64 => {
for v in vs.to_vec1::<i64>()? {
f.write_i64::<LittleEndian>(v)?
Expand Down
11 changes: 11 additions & 0 deletions candle-core/src/cpu/kernels.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,17 @@ impl VecOps for u32 {
<Self as Ord>::max(self, other)
}
}
impl VecOps for i32 {
#[inline(always)]
fn min(self, other: Self) -> Self {
<Self as Ord>::min(self, other)
}

#[inline(always)]
fn max(self, other: Self) -> Self {
<Self as Ord>::max(self, other)
}
}
impl VecOps for i64 {
#[inline(always)]
fn min(self, other: Self) -> Self {
Expand Down
114 changes: 112 additions & 2 deletions candle-core/src/cpu_backend/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ const USE_IM2COL_CONV2D: bool = true;
pub enum CpuStorage {
U8(Vec<u8>),
U32(Vec<u32>),
I32(Vec<i32>),
I64(Vec<i64>),
BF16(Vec<bf16>),
F16(Vec<f16>),
Expand All @@ -33,6 +34,7 @@ pub enum CpuStorage {
pub enum CpuStorageRef<'a> {
U8(&'a [u8]),
U32(&'a [u32]),
I32(&'a [i32]),
I64(&'a [i64]),
BF16(&'a [bf16]),
F16(&'a [f16]),
Expand Down Expand Up @@ -2285,6 +2287,17 @@ impl CpuStorage {
.concat();
Self::U32(storages)
}
Self::I32(_) => {
let storages = storages
.iter()
.map(|s| match s {
Self::I32(s) => Ok(s.as_slice()),
_ => crate::bail!("dtype mismatch"),
})
.collect::<Result<Vec<_>>>()?
.concat();
Self::I32(storages)
}
Self::I64(_) => {
let storages = storages
.iter()
Expand Down Expand Up @@ -2352,6 +2365,7 @@ impl BackendStorage for CpuStorage {
match self {
Self::U8(_) => DType::U8,
Self::U32(_) => DType::U32,
Self::I32(_) => DType::I32,
Self::I64(_) => DType::I64,
Self::BF16(_) => DType::BF16,
Self::F16(_) => DType::F16,
Expand All @@ -2371,6 +2385,10 @@ impl BackendStorage for CpuStorage {
let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32));
Ok(Self::BF16(data))
}
(Self::I32(storage), DType::BF16) => {
let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32));
Ok(Self::BF16(data))
}
(Self::I64(storage), DType::BF16) => {
let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32));
Ok(Self::BF16(data))
Expand Down Expand Up @@ -2399,6 +2417,10 @@ impl BackendStorage for CpuStorage {
let data = unary_map(storage, layout, |v| f16::from_f32(v as f32));
Ok(Self::F16(data))
}
(Self::I32(storage), DType::F16) => {
let data = unary_map(storage, layout, |v| f16::from_f32(v as f32));
Ok(Self::F16(data))
}
(Self::I64(storage), DType::F16) => {
let data = unary_map(storage, layout, |v| f16::from_f32(v as f32));
Ok(Self::F16(data))
Expand Down Expand Up @@ -2427,6 +2449,10 @@ impl BackendStorage for CpuStorage {
let data = unary_map(storage, layout, |v| v as f32);
Ok(Self::F32(data))
}
(Self::I32(storage), DType::F32) => {
let data = unary_map(storage, layout, |v| v as f32);
Ok(Self::F32(data))
}
(Self::I64(storage), DType::F32) => {
let data = unary_map(storage, layout, |v| v as f32);
Ok(Self::F32(data))
Expand Down Expand Up @@ -2471,6 +2497,10 @@ impl BackendStorage for CpuStorage {
let data = unary_map(storage, layout, |v| v as u8);
Ok(Self::U8(data))
}
(Self::I32(storage), DType::U8) => {
let data = unary_map(storage, layout, |v| v as u8);
Ok(Self::U8(data))
}
(Self::I64(storage), DType::U8) => {
let data = unary_map(storage, layout, |v| v as u8);
Ok(Self::U8(data))
Expand All @@ -2483,6 +2513,10 @@ impl BackendStorage for CpuStorage {
let data = unary_map(storage, layout, |v| v);
Ok(Self::U32(data))
}
(Self::I32(storage), DType::U32) => {
let data = unary_map(storage, layout, |v| v as u32);
Ok(Self::U32(data))
}
(Self::I64(storage), DType::U32) => {
let data = unary_map(storage, layout, |v| v as u32);
Ok(Self::U32(data))
Expand All @@ -2503,6 +2537,38 @@ impl BackendStorage for CpuStorage {
let data = unary_map(storage, layout, |v| v as u32);
Ok(Self::U32(data))
}
(Self::U8(storage), DType::I32) => {
let data = unary_map(storage, layout, |v| v as i64);
Ok(Self::I64(data))
}
(Self::U32(storage), DType::I32) => {
let data = unary_map(storage, layout, |v| v as i64);
Ok(Self::I64(data))
}
(Self::I32(storage), DType::I32) => {
let data = unary_map(storage, layout, |v| v);
Ok(Self::I32(data))
}
(Self::I64(storage), DType::I32) => {
let data = unary_map(storage, layout, |v| v as i32);
Ok(Self::I32(data))
}
(Self::BF16(storage), DType::I32) => {
let data = unary_map(storage, layout, |v| v.to_f32() as i32);
Ok(Self::I32(data))
}
(Self::F16(storage), DType::I32) => {
let data = unary_map(storage, layout, |v| v.to_f32() as i32);
Ok(Self::I32(data))
}
(Self::F32(storage), DType::I32) => {
let data = unary_map(storage, layout, |v| v as i32);
Ok(Self::I32(data))
}
(Self::F64(storage), DType::I32) => {
let data = unary_map(storage, layout, |v| v as i32);
Ok(Self::I32(data))
}
(Self::U8(storage), DType::I64) => {
let data = unary_map(storage, layout, |v| v as i64);
Ok(Self::I64(data))
Expand All @@ -2511,6 +2577,10 @@ impl BackendStorage for CpuStorage {
let data = unary_map(storage, layout, |v| v as i64);
Ok(Self::I64(data))
}
(Self::I32(storage), DType::I64) => {
let data = unary_map(storage, layout, |v| v as i64);
Ok(Self::I64(data))
}
(Self::I64(storage), DType::I64) => {
let data = unary_map(storage, layout, |v| v);
Ok(Self::I64(data))
Expand Down Expand Up @@ -2539,6 +2609,10 @@ impl BackendStorage for CpuStorage {
let data = unary_map(storage, layout, |v| v as f64);
Ok(Self::F64(data))
}
(Self::I32(storage), DType::F64) => {
let data = unary_map(storage, layout, |v| v as f64);
Ok(Self::F64(data))
}
(Self::I64(storage), DType::F64) => {
let data = unary_map(storage, layout, |v| v as f64);
Ok(Self::F64(data))
Expand Down Expand Up @@ -2674,6 +2748,7 @@ impl BackendStorage for CpuStorage {
}
Self::U8(_) => Err(Error::UnsupportedDTypeForOp(DType::U8, "elu").bt()),
Self::U32(_) => Err(Error::UnsupportedDTypeForOp(DType::U32, "elu").bt()),
Self::I32(_) => Err(Error::UnsupportedDTypeForOp(DType::I32, "elu").bt()),
Self::I64(_) => Err(Error::UnsupportedDTypeForOp(DType::I64, "elu").bt()),
}
}
Expand All @@ -2699,6 +2774,7 @@ impl BackendStorage for CpuStorage {
}
Self::U8(_) => Err(Error::UnsupportedDTypeForOp(DType::U8, "elu").bt()),
Self::U32(_) => Err(Error::UnsupportedDTypeForOp(DType::U32, "elu").bt()),
Self::I32(_) => Err(Error::UnsupportedDTypeForOp(DType::I64, "elu").bt()),
Self::I64(_) => Err(Error::UnsupportedDTypeForOp(DType::I64, "elu").bt()),
}
}
Expand Down Expand Up @@ -2749,6 +2825,10 @@ impl BackendStorage for CpuStorage {
let data = unary_map(storage, layout, B::u32);
Ok(Self::U32(data))
}
Self::I32(storage) => {
let data = unary_map(storage, layout, B::i32);
Ok(Self::I32(data))
}
Self::I64(storage) => {
let data = unary_map(storage, layout, B::i64);
Ok(Self::I64(data))
Expand Down Expand Up @@ -2803,6 +2883,14 @@ impl BackendStorage for CpuStorage {
};
Ok(Self::U32(data))
}
(Self::I32(lhs), Self::I32(rhs)) => {
let data = if B::I32_VEC {
binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::i32, B::i32_vec)
} else {
binary_map(lhs_l, rhs_l, lhs, rhs, B::i32)
};
Ok(Self::I32(data))
}
(Self::I64(lhs), Self::I64(rhs)) => {
let data = if B::I64_VEC {
binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::i64, B::i64_vec)
Expand Down Expand Up @@ -2846,6 +2934,9 @@ impl BackendStorage for CpuStorage {
(Self::U32(src), Self::U32(dst)) => {
copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
}
(Self::I32(src), Self::I32(dst)) => {
copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
}
(Self::I64(src), Self::I64(dst)) => {
copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
}
Expand Down Expand Up @@ -2877,6 +2968,7 @@ impl BackendStorage for CpuStorage {
match (self, dst) {
(Self::U8(src), Self::U8(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
(Self::U32(src), Self::U32(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
(Self::I32(src), Self::I32(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
(Self::I64(src), Self::I64(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
(Self::BF16(src), Self::BF16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
(Self::F16(src), Self::F16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
Expand Down Expand Up @@ -2906,6 +2998,7 @@ impl BackendStorage for CpuStorage {
match self {
Self::U8(pred) => WCond(pred, layout).map(t, t_l, f, f_l),
Self::U32(pred) => WCond(pred, layout).map(t, t_l, f, f_l),
Self::I32(pred) => WCond(pred, layout).map(t, t_l, f, f_l),
Self::I64(pred) => WCond(pred, layout).map(t, t_l, f, f_l),
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "where-cond")),
}
Expand Down Expand Up @@ -3075,6 +3168,7 @@ impl BackendStorage for CpuStorage {
match ids {
Self::U8(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
Self::U32(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
Self::I32(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
Self::I64(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-select").bt()),
}
Expand All @@ -3084,6 +3178,7 @@ impl BackendStorage for CpuStorage {
match ids {
Self::U8(ids) => Gather { ids, ids_l, dim }.map(self, l),
Self::U32(ids) => Gather { ids, ids_l, dim }.map(self, l),
Self::I32(ids) => Gather { ids, ids_l, dim }.map(self, l),
Self::I64(ids) => Gather { ids, ids_l, dim }.map(self, l),
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "gather").bt()),
}
Expand All @@ -3101,6 +3196,7 @@ impl BackendStorage for CpuStorage {
match ids {
Self::U8(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
Self::U32(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
Self::I32(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
Self::I64(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "scatter-add").bt()),
}
Expand Down Expand Up @@ -3130,6 +3226,13 @@ impl BackendStorage for CpuStorage {
};
IndexAdd { ids, dim }.map(self, l, src, src_l)
}
Self::I32(ids) => {
let ids = match ids_l.contiguous_offsets() {
Some((a, b)) => &ids[a..b],
None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
};
IndexAdd { ids, dim }.map(self, l, src, src_l)
}
Self::I64(ids) => {
let ids = match ids_l.contiguous_offsets() {
Some((a, b)) => &ids[a..b],
Expand Down Expand Up @@ -3225,7 +3328,7 @@ impl BackendDevice for CpuDevice {
let elem_count = shape.elem_count();
let mut rng = rand::thread_rng();
match dtype {
DType::U8 | DType::U32 | DType::I64 => {
DType::U8 | DType::U32 | DType::I32 | DType::I64 => {
Err(Error::UnsupportedDTypeForOp(dtype, "rand_uniform").bt())
}
DType::BF16 => {
Expand Down Expand Up @@ -3271,7 +3374,7 @@ impl BackendDevice for CpuDevice {
let elem_count = shape.elem_count();
let mut rng = rand::thread_rng();
match dtype {
DType::U8 | DType::U32 | DType::I64 => {
DType::U8 | DType::U32 | DType::I32 | DType::I64 => {
Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal").bt())
}
DType::BF16 => {
Expand Down Expand Up @@ -3330,6 +3433,11 @@ impl BackendDevice for CpuDevice {
v.set_len(elem_count);
CpuStorage::U32(v)
}
DType::I32 => {
let mut v = Vec::with_capacity(elem_count);
v.set_len(elem_count);
CpuStorage::I32(v)
}
DType::I64 => {
let mut v = Vec::with_capacity(elem_count);
v.set_len(elem_count);
Expand Down Expand Up @@ -3364,6 +3472,7 @@ impl BackendDevice for CpuDevice {
let storage = match dtype {
DType::U8 => CpuStorage::U8(vec![1u8; elem_count]),
DType::U32 => CpuStorage::U32(vec![1u32; elem_count]),
DType::I32 => CpuStorage::I32(vec![1i32; elem_count]),
DType::I64 => CpuStorage::I64(vec![1i64; elem_count]),
DType::BF16 => CpuStorage::BF16(vec![bf16::ONE; elem_count]),
DType::F16 => CpuStorage::F16(vec![f16::ONE; elem_count]),
Expand All @@ -3378,6 +3487,7 @@ impl BackendDevice for CpuDevice {
let storage = match dtype {
DType::U8 => CpuStorage::U8(vec![0u8; elem_count]),
DType::U32 => CpuStorage::U32(vec![0u32; elem_count]),
DType::I32 => CpuStorage::I32(vec![0i32; elem_count]),
DType::I64 => CpuStorage::I64(vec![0i64; elem_count]),
DType::BF16 => CpuStorage::BF16(vec![bf16::ZERO; elem_count]),
DType::F16 => CpuStorage::F16(vec![f16::ZERO; elem_count]),
Expand Down
2 changes: 2 additions & 0 deletions candle-core/src/cpu_backend/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ pub trait Map1 {
match vs {
C::U8(vs) => Ok(C::U8(self.f(vs, layout)?)),
C::U32(vs) => Ok(C::U32(self.f(vs, layout)?)),
C::I32(vs) => Ok(C::I32(self.f(vs, layout)?)),
C::I64(vs) => Ok(C::I64(self.f(vs, layout)?)),
C::BF16(vs) => Ok(C::BF16(self.f(vs, layout)?)),
C::F16(vs) => Ok(C::F16(self.f(vs, layout)?)),
Expand All @@ -26,6 +27,7 @@ pub trait Map1Any {
match vs {
C::U8(vs) => Ok(self.f(vs, layout, C::U8)?),
C::U32(vs) => Ok(self.f(vs, layout, C::U32)?),
C::I32(vs) => Ok(self.f(vs, layout, C::I32)?),
C::I64(vs) => Ok(self.f(vs, layout, C::I64)?),
C::BF16(vs) => Ok(self.f(vs, layout, C::BF16)?),
C::F16(vs) => Ok(self.f(vs, layout, C::F16)?),
Expand Down
Loading

0 comments on commit 7ad6494

Please sign in to comment.