Skip to content

Commit

Permalink
Add initial f8 e4m3 dtype (#31)
Browse files Browse the repository at this point in the history
* Add initial f8 e4m3 type

* Fixes

* Update deps

* Implement CudaDType

* Add some cast kernels

* Add copy2d, other impls

* Fix zeros impl

* Error checking for metal

* Use isnanf
  • Loading branch information
EricLBuehler authored Oct 17, 2024
1 parent 20a57c4 commit fa4902f
Show file tree
Hide file tree
Showing 33 changed files with 662 additions and 347 deletions.
3 changes: 3 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,7 @@
"rust-analyzer.cargo.features": [
"cuda",
],
"files.associations": {
"cmath": "cpp"
},
}
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ fancy-regex = "0.13.0"
gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
hf-hub = "0.3.0"
half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
float8 = { version = "0.1.0", features = ["num-traits", "rand_distr"] }
hound = "3.5.1"
image = { version = "0.25.2", default-features = false, features = ["jpeg", "png"] }
imageproc = { version = "0.24.0", default-features = false }
Expand Down
3 changes: 2 additions & 1 deletion candle-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ metal = { workspace = true, optional = true}
cudarc = { workspace = true, optional = true }
gemm = { workspace = true }
half = { workspace = true }
float8 = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
libc = { workspace = true, optional = true }
memmap2 = { workspace = true }
Expand All @@ -39,7 +40,7 @@ criterion = { workspace = true }

[features]
default = []
cuda = ["cudarc", "dep:candle-kernels"]
cuda = ["cudarc", "dep:candle-kernels", "float8/cuda"]
cudnn = ["cuda", "cudarc/cudnn"]
mkl = ["dep:libc", "dep:intel-mkl-src"]
accelerate = ["dep:libc", "dep:accelerate-src"]
Expand Down
6 changes: 6 additions & 0 deletions candle-core/src/convert.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
//! Implement conversion traits for tensors
use crate::{DType, Device, Error, Tensor, WithDType};
use float8::F8E4M3;
use half::{bf16, f16, slice::HalfFloatSliceExt};
use std::convert::TryFrom;

Expand Down Expand Up @@ -149,6 +150,11 @@ impl Tensor {
let vs = vs.to_vec1::<u8>()?;
f.write_all(&vs)?;
}
DType::F8E4M3 => {
for v in vs.to_vec1::<F8E4M3>()? {
f.write_u8(v.to_bits())?
}
}
}
Ok(())
}
Expand Down
133 changes: 133 additions & 0 deletions candle-core/src/cpu_backend/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::ops::Deref;
use crate::backend::{BackendDevice, BackendStorage};
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{DType, Error, IntDType, Layout, Result, Shape, WithDType};
use float8::F8E4M3;
use half::{bf16, f16};
use rayon::prelude::*;

Expand All @@ -29,6 +30,7 @@ pub enum CpuStorage {
F16(Vec<f16>),
F32(Vec<f32>),
F64(Vec<f64>),
F8E4M3(Vec<F8E4M3>),
}

#[derive(Debug, Clone)]
Expand All @@ -42,6 +44,7 @@ pub enum CpuStorageRef<'a> {
F16(&'a [f16]),
F32(&'a [f32]),
F64(&'a [f64]),
F8E4M3(&'a [F8E4M3]),
}

#[derive(Debug, Clone)]
Expand Down Expand Up @@ -2366,6 +2369,17 @@ impl CpuStorage {
.concat();
Self::F64(storages)
}
Self::F8E4M3(_) => {
let storages = storages
.iter()
.map(|s| match s {
Self::F8E4M3(s) => Ok(s.as_slice()),
_ => crate::bail!("dtype mismatch"),
})
.collect::<Result<Vec<_>>>()?
.concat();
Self::F8E4M3(storages)
}
};
Ok(s)
}
Expand All @@ -2385,6 +2399,7 @@ impl BackendStorage for CpuStorage {
Self::F16(_) => DType::F16,
Self::F32(_) => DType::F32,
Self::F64(_) => DType::F64,
Self::F8E4M3(_) => DType::F8E4M3,
}
}

Expand Down Expand Up @@ -2427,6 +2442,10 @@ impl BackendStorage for CpuStorage {
let data = unary_map(storage, layout, bf16::from_f64);
Ok(Self::BF16(data))
}
(Self::F8E4M3(storage), DType::BF16) => {
let data = unary_map(storage, layout, |v| bf16::from_f32(v.to_f32()));
Ok(Self::BF16(data))
}
(Self::U8(storage), DType::F16) => {
let data = unary_map(storage, layout, |v| f16::from_f32(v as f32));
Ok(Self::F16(data))
Expand Down Expand Up @@ -2463,6 +2482,10 @@ impl BackendStorage for CpuStorage {
let data = unary_map(storage, layout, f16::from_f64);
Ok(Self::F16(data))
}
(Self::F8E4M3(storage), DType::F16) => {
let data = unary_map(storage, layout, |v| f16::from_f32(v.to_f32()));
Ok(Self::F16(data))
}
(Self::U8(storage), DType::F32) => {
let data = unary_map(storage, layout, |v| v as f32);
Ok(Self::F32(data))
Expand Down Expand Up @@ -2499,6 +2522,10 @@ impl BackendStorage for CpuStorage {
let data = unary_map(storage, layout, |v| v as f32);
Ok(Self::F32(data))
}
(Self::F8E4M3(storage), DType::F32) => {
let data = unary_map(storage, layout, |v| v.to_f32());
Ok(Self::F32(data))
}
(Self::U8(storage), DType::U8) => {
let data = unary_map(storage, layout, |v| v);
Ok(Self::U8(data))
Expand Down Expand Up @@ -2535,6 +2562,10 @@ impl BackendStorage for CpuStorage {
let data = unary_map(storage, layout, |v| v as u8);
Ok(Self::U8(data))
}
(Self::F8E4M3(storage), DType::U8) => {
let data = unary_map(storage, layout, |v| v.to_f32() as u8);
Ok(Self::U8(data))
}
(Self::U8(storage), DType::U32) => {
let data = unary_map(storage, layout, |v| v as u32);
Ok(Self::U32(data))
Expand Down Expand Up @@ -2571,6 +2602,10 @@ impl BackendStorage for CpuStorage {
let data = unary_map(storage, layout, |v| v as u32);
Ok(Self::U32(data))
}
(Self::F8E4M3(storage), DType::U32) => {
let data = unary_map(storage, layout, |v| v.to_f32() as u32);
Ok(Self::U32(data))
}
(Self::U8(storage), DType::I16) => {
let data = unary_map(storage, layout, |v| v as i64);
Ok(Self::I64(data))
Expand Down Expand Up @@ -2607,6 +2642,10 @@ impl BackendStorage for CpuStorage {
let data = unary_map(storage, layout, |v| v as i16);
Ok(Self::I16(data))
}
(Self::F8E4M3(storage), DType::I16) => {
let data = unary_map(storage, layout, |v| v.to_f32() as i16);
Ok(Self::I16(data))
}
(Self::U8(storage), DType::I32) => {
let data = unary_map(storage, layout, |v| v as i64);
Ok(Self::I64(data))
Expand Down Expand Up @@ -2643,6 +2682,10 @@ impl BackendStorage for CpuStorage {
let data = unary_map(storage, layout, |v| v as i32);
Ok(Self::I32(data))
}
(Self::F8E4M3(storage), DType::I32) => {
let data = unary_map(storage, layout, |v| v.to_f32() as i32);
Ok(Self::I32(data))
}
(Self::U8(storage), DType::I64) => {
let data = unary_map(storage, layout, |v| v as i64);
Ok(Self::I64(data))
Expand Down Expand Up @@ -2679,6 +2722,10 @@ impl BackendStorage for CpuStorage {
let data = unary_map(storage, layout, |v| v as i64);
Ok(Self::I64(data))
}
(Self::F8E4M3(storage), DType::I64) => {
let data = unary_map(storage, layout, |v| v.to_f32() as i64);
Ok(Self::I64(data))
}
(Self::U8(storage), DType::F64) => {
let data = unary_map(storage, layout, |v| v as f64);
Ok(Self::F64(data))
Expand Down Expand Up @@ -2715,6 +2762,50 @@ impl BackendStorage for CpuStorage {
let data = unary_map(storage, layout, |v| v);
Ok(Self::F64(data))
}
(Self::F8E4M3(storage), DType::F64) => {
let data = unary_map(storage, layout, |v| v.to_f64());
Ok(Self::F64(data))
}
(Self::U8(storage), DType::F8E4M3) => {
let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v as f32));
Ok(Self::F8E4M3(data))
}
(Self::U32(storage), DType::F8E4M3) => {
let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v as f32));
Ok(Self::F8E4M3(data))
}
(Self::I16(storage), DType::F8E4M3) => {
let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v as f32));
Ok(Self::F8E4M3(data))
}
(Self::I32(storage), DType::F8E4M3) => {
let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v as f32));
Ok(Self::F8E4M3(data))
}
(Self::I64(storage), DType::F8E4M3) => {
let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v as f32));
Ok(Self::F8E4M3(data))
}
(Self::BF16(storage), DType::F8E4M3) => {
let data = unary_map(storage, layout, |v| F8E4M3::from(v.to_f32()));
Ok(Self::F8E4M3(data))
}
(Self::F16(storage), DType::F8E4M3) => {
let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v.to_f32()));
Ok(Self::F8E4M3(data))
}
(Self::F32(storage), DType::F8E4M3) => {
let data = unary_map(storage, layout, F8E4M3::from_f32);
Ok(Self::F8E4M3(data))
}
(Self::F64(storage), DType::F8E4M3) => {
let data = unary_map(storage, layout, F8E4M3::from_f64);
Ok(Self::F8E4M3(data))
}
(Self::F8E4M3(storage), DType::F8E4M3) => {
let data = unary_map(storage, layout, |v| v);
Ok(Self::F8E4M3(data))
}
}
}

Expand Down Expand Up @@ -2828,6 +2919,10 @@ impl BackendStorage for CpuStorage {
let data = unary_map(storage, layout, |v| v.powf(e));
Ok(Self::F64(data))
}
Self::F8E4M3(storage) => {
let data = unary_map(storage, layout, |v| v.powf(F8E4M3::from_f64(e)));
Ok(Self::F8E4M3(data))
}
Self::U8(_) => Err(Error::UnsupportedDTypeForOp(DType::U8, "elu").bt()),
Self::U32(_) => Err(Error::UnsupportedDTypeForOp(DType::U32, "elu").bt()),
Self::I16(_) => Err(Error::UnsupportedDTypeForOp(DType::I16, "elu").bt()),
Expand Down Expand Up @@ -2855,6 +2950,10 @@ impl BackendStorage for CpuStorage {
let data = unary_map(storage, layout, |v| elu(v, alpha));
Ok(Self::F64(data))
}
Self::F8E4M3(storage) => {
let data = unary_map(storage, layout, |v| elu(v, F8E4M3::from_f64(alpha)));
Ok(Self::F8E4M3(data))
}
Self::U8(_) => Err(Error::UnsupportedDTypeForOp(DType::U8, "elu").bt()),
Self::U32(_) => Err(Error::UnsupportedDTypeForOp(DType::U32, "elu").bt()),
Self::I16(_) => Err(Error::UnsupportedDTypeForOp(DType::I16, "elu").bt()),
Expand Down Expand Up @@ -2901,6 +3000,15 @@ impl BackendStorage for CpuStorage {
Ok(Self::F64(data))
}
}
Self::F8E4M3(storage) => {
if B::F8E4M3_VEC {
let data = unary_map_vec(storage, layout, B::f8e4m3, B::f8e4m3_vec);
Ok(Self::F8E4M3(data))
} else {
let data = unary_map(storage, layout, B::f8e4m3);
Ok(Self::F8E4M3(data))
}
}
Self::U8(storage) => {
let data = unary_map(storage, layout, B::u8);
Ok(Self::U8(data))
Expand Down Expand Up @@ -3455,6 +3563,15 @@ impl BackendDevice for CpuDevice {
}
Ok(CpuStorage::F16(data))
}
DType::F8E4M3 => {
let mut data = Vec::with_capacity(elem_count);
let uniform =
rand::distributions::Uniform::new(F8E4M3::from_f64(min), F8E4M3::from_f64(max));
for _i in 0..elem_count {
data.push(rng.sample::<F8E4M3, _>(uniform))
}
Ok(CpuStorage::F8E4M3(data))
}
DType::F32 => {
let mut data = Vec::with_capacity(elem_count);
let uniform = rand::distributions::Uniform::new(min as f32, max as f32);
Expand Down Expand Up @@ -3501,6 +3618,15 @@ impl BackendDevice for CpuDevice {
}
Ok(CpuStorage::F16(data))
}
DType::F8E4M3 => {
let mut data = Vec::with_capacity(elem_count);
let normal = rand_distr::Normal::new(F8E4M3::from_f64(mean), F8E4M3::from_f64(std))
.map_err(Error::wrap)?;
for _i in 0..elem_count {
data.push(normal.sample(&mut rng))
}
Ok(CpuStorage::F8E4M3(data))
}
DType::F32 => {
let mut data = Vec::with_capacity(elem_count);
let normal =
Expand Down Expand Up @@ -3574,6 +3700,11 @@ impl BackendDevice for CpuDevice {
v.set_len(elem_count);
CpuStorage::F64(v)
}
DType::F8E4M3 => {
let mut v = Vec::with_capacity(elem_count);
v.set_len(elem_count);
CpuStorage::F8E4M3(v)
}
};
Ok(storage)
}
Expand All @@ -3588,6 +3719,7 @@ impl BackendDevice for CpuDevice {
DType::I64 => CpuStorage::I64(vec![1i64; elem_count]),
DType::BF16 => CpuStorage::BF16(vec![bf16::ONE; elem_count]),
DType::F16 => CpuStorage::F16(vec![f16::ONE; elem_count]),
DType::F8E4M3 => CpuStorage::F8E4M3(vec![F8E4M3::ONE; elem_count]),
DType::F32 => CpuStorage::F32(vec![1f32; elem_count]),
DType::F64 => CpuStorage::F64(vec![1f64; elem_count]),
};
Expand All @@ -3604,6 +3736,7 @@ impl BackendDevice for CpuDevice {
DType::I64 => CpuStorage::I64(vec![0i64; elem_count]),
DType::BF16 => CpuStorage::BF16(vec![bf16::ZERO; elem_count]),
DType::F16 => CpuStorage::F16(vec![f16::ZERO; elem_count]),
DType::F8E4M3 => CpuStorage::F8E4M3(vec![F8E4M3::ZERO; elem_count]),
DType::F32 => CpuStorage::F32(vec![0f32; elem_count]),
DType::F64 => CpuStorage::F64(vec![0f64; elem_count]),
};
Expand Down
6 changes: 6 additions & 0 deletions candle-core/src/cpu_backend/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ pub trait Map1 {
C::F16(vs) => Ok(C::F16(self.f(vs, layout)?)),
C::F32(vs) => Ok(C::F32(self.f(vs, layout)?)),
C::F64(vs) => Ok(C::F64(self.f(vs, layout)?)),
C::F8E4M3(vs) => Ok(C::F8E4M3(self.f(vs, layout)?)),
}
}
}
Expand All @@ -35,6 +36,7 @@ pub trait Map1Any {
C::F16(vs) => Ok(self.f(vs, layout, C::F16)?),
C::F32(vs) => Ok(self.f(vs, layout, C::F32)?),
C::F64(vs) => Ok(self.f(vs, layout, C::F64)?),
C::F8E4M3(vs) => Ok(self.f(vs, layout, C::F8E4M3)?),
}
}
}
Expand All @@ -52,6 +54,7 @@ pub trait Map2 {
(C::F16(v1), C::F16(v2)) => Ok(C::F16(self.f(v1, l1, v2, l2)?)),
(C::F32(v1), C::F32(v2)) => Ok(C::F32(self.f(v1, l1, v2, l2)?)),
(C::F64(v1), C::F64(v2)) => Ok(C::F64(self.f(v1, l1, v2, l2)?)),
(C::F8E4M3(v1), C::F8E4M3(v2)) => Ok(C::F8E4M3(self.f(v1, l1, v2, l2)?)),
_ => Err(Error::DTypeMismatchBinaryOp {
lhs: v1.dtype(),
rhs: v2.dtype(),
Expand Down Expand Up @@ -96,6 +99,7 @@ pub trait Map3 {
(C::F16(v1), C::F16(v2), C::F16(v3)) => Ok(self.f(v1, l1, v2, l2, v3, l3, s)?),
(C::F32(v1), C::F32(v2), C::F32(v3)) => Ok(self.f(v1, l1, v2, l2, v3, l3, s)?),
(C::F64(v1), C::F64(v2), C::F64(v3)) => Ok(self.f(v1, l1, v2, l2, v3, l3, s)?),
(C::F8E4M3(v1), C::F8E4M3(v2), C::F8E4M3(v3)) => Ok(self.f(v1, l1, v2, l2, v3, l3, s)?),
_ => Err(Error::DTypeMismatchBinaryOp3 {
lhs: v1.dtype(),
rhs: v2.dtype(),
Expand Down Expand Up @@ -129,6 +133,7 @@ pub trait Map2Alpha {
(C::F16(v1), C::F16(v2)) => Ok(C::F16(self.f(v1, l1, v2, l2, s)?)),
(C::F32(v1), C::F32(v2)) => Ok(C::F32(self.f(v1, l1, v2, l2, s)?)),
(C::F64(v1), C::F64(v2)) => Ok(C::F64(self.f(v1, l1, v2, l2, s)?)),
(C::F8E4M3(v1), C::F8E4M3(v2)) => Ok(C::F8E4M3(self.f(v1, l1, v2, l2, s)?)),
_ => Err(Error::DTypeMismatchBinaryOp {
lhs: v1.dtype(),
rhs: v2.dtype(),
Expand All @@ -152,6 +157,7 @@ pub trait Map2U8 {
(C::F16(v1), C::F16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
(C::F32(v1), C::F32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
(C::F64(v1), C::F64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
(C::F8E4M3(v1), C::F8E4M3(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
_ => Err(Error::DTypeMismatchBinaryOp {
lhs: v1.dtype(),
rhs: v2.dtype(),
Expand Down
Loading

0 comments on commit fa4902f

Please sign in to comment.