Add the i16 dtype #24

Closed
wants to merge 8 commits
Changes from 1 commit
5 changes: 5 additions & 0 deletions candle-core/src/convert.rs
@@ -130,6 +130,11 @@ impl Tensor {
f.write_u32::<LittleEndian>(v)?
}
}
DType::I16 => {
for v in vs.to_vec1::<i16>()? {
f.write_i16::<LittleEndian>(v)?
}
}
DType::I32 => {
for v in vs.to_vec1::<i32>()? {
f.write_i32::<LittleEndian>(v)?
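For context, the new arm serializes each i16 element in little-endian order via the byteorder crate, matching the surrounding u32/i32 arms. A standalone round-trip sketch of the same pattern (assumes only the byteorder crate; names here are illustrative, not candle's):

```rust
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};

fn main() -> std::io::Result<()> {
    // Write i16 values little-endian, as the new DType::I16 arm does.
    let values: Vec<i16> = vec![-1, 0, 257];
    let mut buf: Vec<u8> = Vec::new();
    for &v in &values {
        buf.write_i16::<LittleEndian>(v)?;
    }
    // -1 -> ff ff, 0 -> 00 00, 257 -> 01 01 (low byte first).
    assert_eq!(buf, [0xff, 0xff, 0x00, 0x00, 0x01, 0x01]);

    // Read them back to confirm the round-trip.
    let mut rdr = std::io::Cursor::new(buf);
    let back: Vec<i16> = (0..values.len())
        .map(|_| rdr.read_i16::<LittleEndian>())
        .collect::<std::io::Result<_>>()?;
    assert_eq!(back, values);
    Ok(())
}
```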
11 changes: 11 additions & 0 deletions candle-core/src/cpu/kernels.rs
@@ -151,6 +151,17 @@ impl VecOps for u32 {
<Self as Ord>::max(self, other)
}
}
impl VecOps for i16 {
#[inline(always)]
fn min(self, other: Self) -> Self {
<Self as Ord>::min(self, other)
}

#[inline(always)]
fn max(self, other: Self) -> Self {
<Self as Ord>::max(self, other)
}
}
impl VecOps for i32 {
#[inline(always)]
fn min(self, other: Self) -> Self {
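The i16 impl can delegate straight to Ord because integer comparison is total; the float impls in this file have to route through PartialOrd instead. A minimal self-contained sketch of the pattern (trait shape assumed; the real VecOps carries more methods):

```rust
/// Minimal stand-in for candle's VecOps (assumed shape).
trait VecOps: Copy {
    fn min(self, other: Self) -> Self;
    fn max(self, other: Self) -> Self;
}

impl VecOps for i16 {
    #[inline(always)]
    fn min(self, other: Self) -> Self {
        // Integers have a total order, so Ord::min suffices; no NaN cases.
        <Self as Ord>::min(self, other)
    }

    #[inline(always)]
    fn max(self, other: Self) -> Self {
        <Self as Ord>::max(self, other)
    }
}

fn main() {
    let xs: [i16; 4] = [3, -7, 12, 0];
    let mut lo = i16::MAX;
    let mut hi = i16::MIN;
    for &x in &xs {
        // Fully qualified to avoid ambiguity with Ord's min/max in scope.
        lo = VecOps::min(lo, x);
        hi = VecOps::max(hi, x);
    }
    assert_eq!((lo, hi), (-7, 12));
}
```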
124 changes: 121 additions & 3 deletions candle-core/src/cpu_backend/mod.rs

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions candle-core/src/cpu_backend/utils.rs
@@ -10,6 +10,7 @@ pub trait Map1 {
match vs {
C::U8(vs) => Ok(C::U8(self.f(vs, layout)?)),
C::U32(vs) => Ok(C::U32(self.f(vs, layout)?)),
C::I16(vs) => Ok(C::I16(self.f(vs, layout)?)),
C::I32(vs) => Ok(C::I32(self.f(vs, layout)?)),
C::I64(vs) => Ok(C::I64(self.f(vs, layout)?)),
C::BF16(vs) => Ok(C::BF16(self.f(vs, layout)?)),
@@ -27,6 +28,7 @@ pub trait Map1Any {
match vs {
C::U8(vs) => Ok(self.f(vs, layout, C::U8)?),
C::U32(vs) => Ok(self.f(vs, layout, C::U32)?),
C::I16(vs) => Ok(self.f(vs, layout, C::I16)?),
C::I32(vs) => Ok(self.f(vs, layout, C::I32)?),
C::I64(vs) => Ok(self.f(vs, layout, C::I64)?),
C::BF16(vs) => Ok(self.f(vs, layout, C::BF16)?),
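These one-line additions are the whole cost of a new dtype on the CPU map traits: the storage enum gains a variant and every dispatch match gains an arm. A reduced sketch of that shape (simplified stand-in types; the real traits also thread a Layout and return Result):

```rust
// Stand-in for candle's CpuStorage enum, trimmed to a few variants.
enum CpuStorage {
    U8(Vec<u8>),
    U32(Vec<u32>),
    I16(Vec<i16>),
    I32(Vec<i32>),
}

trait Map1 {
    // The per-dtype worker; generic over the element type.
    fn f<T: Copy>(&self, vs: &[T]) -> Vec<T>;

    // The dispatcher: one arm per dtype, each re-wrapping in its own variant.
    fn map(&self, vs: &CpuStorage) -> CpuStorage {
        match vs {
            CpuStorage::U8(vs) => CpuStorage::U8(self.f(vs)),
            CpuStorage::U32(vs) => CpuStorage::U32(self.f(vs)),
            CpuStorage::I16(vs) => CpuStorage::I16(self.f(vs)),
            CpuStorage::I32(vs) => CpuStorage::I32(self.f(vs)),
        }
    }
}

struct Identity;
impl Map1 for Identity {
    fn f<T: Copy>(&self, vs: &[T]) -> Vec<T> {
        vs.to_vec()
    }
}

fn main() {
    let s = CpuStorage::I16(vec![1, 2, 3]);
    let out = Identity.map(&s);
    assert!(matches!(out, CpuStorage::I16(v) if v == vec![1, 2, 3]));
}
```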
32 changes: 30 additions & 2 deletions candle-core/src/cuda_backend/device.rs
@@ -80,6 +80,14 @@ impl CudaDevice {
unsafe { func.launch(cfg, params) }.w()?;
CudaStorageSlice::U32(data)
}
DType::I16 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<i16>(elem_count) }.w()?;
let func = self.get_or_load_func("fill_i16", kernels::FILL)?;
let params = (&data, v as i16, elem_count);
unsafe { func.launch(cfg, params) }.w()?;
CudaStorageSlice::I16(data)
}
DType::I32 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<i32>(elem_count) }.w()?;
@@ -207,6 +215,10 @@ impl BackendDevice for CudaDevice {
let data = self.alloc_zeros::<u32>(elem_count).w()?;
CudaStorageSlice::U32(data)
}
DType::I16 => {
let data = self.alloc_zeros::<i16>(elem_count).w()?;
CudaStorageSlice::I16(data)
}
DType::I32 => {
let data = self.alloc_zeros::<i32>(elem_count).w()?;
CudaStorageSlice::I32(data)
@@ -244,7 +256,7 @@ impl BackendDevice for CudaDevice {
let slice = match dtype {
// TODO: Add support for F16 and BF16 though this is likely to require some upstream
// cudarc changes.
DType::U8 | DType::U32 | DType::I64 | DType::I32 | DType::F16 | DType::BF16 => {
DType::U8 | DType::U32 | DType::I64 | DType::I32 | DType::I16 | DType::F16 | DType::BF16 => {
Err(CudaError::UnsupportedDtype {
dtype,
op: "rand_uniform",
@@ -288,7 +300,7 @@ impl BackendDevice for CudaDevice {
elem_count
};
let slice = match dtype {
DType::U8 | DType::U32 | DType::I32 | DType::I64 | DType::F16 | DType::BF16 => {
DType::U8 | DType::U32 | DType::I16 | DType::I32 | DType::I64 | DType::F16 | DType::BF16 => {
Err(CudaError::UnsupportedDtype {
dtype,
op: "rand_normal",
@@ -330,6 +342,10 @@ impl BackendDevice for CudaDevice {
let data = self.alloc::<u32>(elem_count).w()?;
CudaStorageSlice::U32(data)
}
DType::I16 => {
let data = self.alloc::<i16>(elem_count).w()?;
CudaStorageSlice::I16(data)
}
DType::I32 => {
let data = self.alloc::<i32>(elem_count).w()?;
CudaStorageSlice::I32(data)
@@ -371,6 +387,10 @@ impl BackendDevice for CudaDevice {
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::U32(data)
}
CpuStorageRef::I16(storage) => {
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::I16(data)
}
CpuStorageRef::I32(storage) => {
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::I32(data)
@@ -412,6 +432,10 @@ impl BackendDevice for CudaDevice {
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::U32(data)
}
CpuStorage::I16(storage) => {
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::I16(data)
}
CpuStorage::I32(storage) => {
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::I32(data)
@@ -453,6 +477,10 @@ impl BackendDevice for CudaDevice {
let data = self.htod_copy(storage).w()?;
CudaStorageSlice::U32(data)
}
CpuStorage::I16(storage) => {
let data = self.htod_copy(storage).w()?;
CudaStorageSlice::I16(data)
}
CpuStorage::I32(storage) => {
let data = self.htod_copy(storage).w()?;
CudaStorageSlice::I32(data)
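With the allocation, fill, and host-to-device paths covered, I16 should behave like the other integer dtypes at the device level. A hedged usage sketch against candle_core's public API (cuda feature enabled; assumes this PR also wires i16 into WithDType so Tensor::new accepts it):

```rust
use candle_core::{DType, Device, Tensor};

fn main() -> candle_core::Result<()> {
    let dev = Device::new_cuda(0)?;

    // Exercises alloc_zeros::<i16> from the zeros path above.
    let zeros = Tensor::zeros((2, 3), DType::I16, &dev)?;

    // Exercises the htod_sync_copy path for CpuStorage::I16.
    let host = Tensor::new(&[1i16, -2, 3], &Device::Cpu)?;
    let on_gpu = host.to_device(&dev)?;

    println!("{zeros}\n{on_gpu}");
    Ok(())
}
```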
55 changes: 50 additions & 5 deletions candle-core/src/cuda_backend/mod.rs
@@ -47,6 +47,7 @@ impl SlicePtrOrNull<usize> {
pub enum CudaStorageSlice {
U8(CudaSlice<u8>),
U32(CudaSlice<u32>),
I16(CudaSlice<i16>),
I32(CudaSlice<i32>),
I64(CudaSlice<i64>),
BF16(CudaSlice<bf16>),
@@ -364,14 +365,17 @@ impl<'a> Map1 for IndexSelect<'a> {
CudaStorageSlice::U8(slice) => {
("is_u8", *slice.slice(ids_l.start_offset()..).device_ptr())
}
CudaStorageSlice::I16(slice) => {
("is_i16", *slice.slice(ids_l.start_offset()..).device_ptr())
}
CudaStorageSlice::I32(slice) => {
("is_i32", *slice.slice(ids_l.start_offset()..).device_ptr())
}
CudaStorageSlice::I64(slice) => {
("is_i64", *slice.slice(ids_l.start_offset()..).device_ptr())
}
_ => Err(CudaError::UnexpectedDType {
msg: "index_select ids should be u8/u32/i32/i64",
msg: "index_select ids should be u8/u32/i16/i32/i64",
expected: DType::U32,
got: self.0.dtype(),
})
@@ -431,14 +435,17 @@ impl<'a> Map1 for Gather<'a> {
("gather_u32", *slice.slice(ids_o1..ids_o2).device_ptr())
}
CudaStorageSlice::U8(slice) => ("gather_u8", *slice.slice(ids_o1..ids_o2).device_ptr()),
CudaStorageSlice::I16(slice) => {
("gather_i16", *slice.slice(ids_o1..ids_o2).device_ptr())
}
CudaStorageSlice::I32(slice) => {
("gather_i32", *slice.slice(ids_o1..ids_o2).device_ptr())
}
CudaStorageSlice::I64(slice) => {
("gather_i64", *slice.slice(ids_o1..ids_o2).device_ptr())
}
_ => Err(CudaError::UnexpectedDType {
msg: "gather ids should be u8/u32/i32/i64",
msg: "gather ids should be u8/u32/i16/i32/i64",
expected: DType::U32,
got: ids.dtype(),
})?,
@@ -484,11 +491,12 @@ impl<'a> Map2InPlace for IndexAdd<'a> {
};
let (name, ids) = match &ids.slice {
CudaStorageSlice::U32(slice) => ("ia_u32", *slice.slice(ids_o1..ids_o2).device_ptr()),
CudaStorageSlice::I16(slice) => ("ia_i16", *slice.slice(ids_o1..ids_o2).device_ptr()),
CudaStorageSlice::I32(slice) => ("ia_i32", *slice.slice(ids_o1..ids_o2).device_ptr()),
CudaStorageSlice::I64(slice) => ("ia_i64", *slice.slice(ids_o1..ids_o2).device_ptr()),
CudaStorageSlice::U8(slice) => ("ia_u8", *slice.slice(ids_o1..ids_o2).device_ptr()),
_ => Err(CudaError::UnexpectedDType {
msg: "index-add ids should be u8/u32/i32/i64",
msg: "index-add ids should be u8/u32/i16/i32/i64",
expected: DType::U32,
got: ids.dtype(),
})?,
@@ -533,11 +541,12 @@ impl<'a> Map2InPlace for ScatterAdd<'a> {
};
let (name, ids) = match &ids.slice {
CudaStorageSlice::U32(slice) => ("sa_u32", *slice.slice(ids_o1..ids_o2).device_ptr()),
CudaStorageSlice::I16(slice) => ("sa_i16", *slice.slice(ids_o1..ids_o2).device_ptr()),
CudaStorageSlice::I32(slice) => ("sa_i32", *slice.slice(ids_o1..ids_o2).device_ptr()),
CudaStorageSlice::I64(slice) => ("sa_i64", *slice.slice(ids_o1..ids_o2).device_ptr()),
CudaStorageSlice::U8(slice) => ("sa_u8", *slice.slice(ids_o1..ids_o2).device_ptr()),
_ => Err(CudaError::UnexpectedDType {
msg: "scatter-add ids should be u8/u32/i32/i64",
msg: "scatter-add ids should be u8/u32/i16/i32/i64",
expected: DType::U32,
got: ids.dtype(),
})?,
@@ -876,6 +885,10 @@ impl<'a> Map2 for WhereCond<'a> {
let ptr = *slice.slice(ids_l.start_offset()..).device_ptr();
(ptr, "where_u32")
}
CudaStorageSlice::I16(slice) => {
let ptr = *slice.slice(ids_l.start_offset()..).device_ptr();
(ptr, "where_i16")
}
CudaStorageSlice::I32(slice) => {
let ptr = *slice.slice(ids_l.start_offset()..).device_ptr();
(ptr, "where_i32")
@@ -885,7 +898,7 @@
(ptr, "where_i64")
}
_ => Err(CudaError::UnexpectedDType {
msg: "where conditions should be u8/u32/i64",
msg: "where conditions should be u8/u32/i16/i32/i64",
expected: DType::U32,
got: self.0.dtype(),
})
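With the where_i16 entry, an i16 mask tensor dispatches like the u8/u32/i64 ones. A hedged sketch (method name where_cond per candle_core; semantics assumed: nonzero mask entries select the first operand):

```rust
use candle_core::{Device, Tensor};

fn main() -> candle_core::Result<()> {
    let dev = Device::new_cuda(0)?;

    // An i16 mask now routes to the "where_i16" kernel instead of erroring.
    let mask = Tensor::new(&[1i16, 0, 1], &dev)?;
    let on_true = Tensor::new(&[10f32, 20., 30.], &dev)?;
    let on_false = Tensor::new(&[0f32, 0., 0.], &dev)?;

    let out = mask.where_cond(&on_true, &on_false)?; // -> [10, 0, 30]
    println!("{out}");
    Ok(())
}
```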
@@ -1039,6 +1052,7 @@ macro_rules! cuda_dtype {
}
cuda_dtype!(u8, U8);
cuda_dtype!(u32, U32);
cuda_dtype!(i16, I16);
cuda_dtype!(i32, I32);
cuda_dtype!(i64, I64);
cuda_dtype!(f16, F16);
@@ -1162,6 +1176,7 @@ impl BackendStorage for CudaStorage {
match self.slice {
CudaStorageSlice::U8(_) => DType::U8,
CudaStorageSlice::U32(_) => DType::U32,
CudaStorageSlice::I16(_) => DType::I16,
CudaStorageSlice::I32(_) => DType::I32,
CudaStorageSlice::I64(_) => DType::I64,
CudaStorageSlice::BF16(_) => DType::BF16,
@@ -1189,6 +1204,7 @@ impl BackendStorage for CudaStorage {
let inp = match &self.slice {
CudaStorageSlice::U8(inp) => *inp.slice(start_o..).device_ptr(),
CudaStorageSlice::U32(inp) => *inp.slice(start_o..).device_ptr(),
CudaStorageSlice::I16(inp) => *inp.slice(start_o..).device_ptr(),
CudaStorageSlice::I32(inp) => *inp.slice(start_o..).device_ptr(),
CudaStorageSlice::I64(inp) => *inp.slice(start_o..).device_ptr(),
CudaStorageSlice::BF16(inp) => *inp.slice(start_o..).device_ptr(),
@@ -1213,6 +1229,12 @@
unsafe { func.launch(cfg, params) }.w()?;
CudaStorageSlice::U32(out)
}
DType::I16 => {
let out = unsafe { dev.alloc::<i16>(el) }.w()?;
let params = (el, dims.len(), &ds, *inp, &out);
unsafe { func.launch(cfg, params) }.w()?;
CudaStorageSlice::I16(out)
}
DType::I32 => {
let out = unsafe { dev.alloc::<i32>(el) }.w()?;
let params = (el, dims.len(), &ds, *inp, &out);
@@ -1315,6 +1337,11 @@ impl BackendStorage for CudaStorage {
let cpu_storage = dev.dtoh_sync_copy(slice).w()?;
Ok(CpuStorage::U32(cpu_storage))
}
CudaStorageSlice::I16(slice) => {
let dev = slice.device();
let cpu_storage = dev.dtoh_sync_copy(slice).w()?;
Ok(CpuStorage::I16(cpu_storage))
}
CudaStorageSlice::I32(slice) => {
let dev = slice.device();
let cpu_storage = dev.dtoh_sync_copy(slice).w()?;
@@ -1587,6 +1614,7 @@ impl BackendStorage for CudaStorage {
S::F64(out)
}
(S::U32(_), S::U32(_)) => Err(CudaError::InternalError("conv2d does not support u32"))?,
(S::I16(_), S::I16(_)) => Err(CudaError::InternalError("conv2d does not support i16"))?,
(S::I32(_), S::I32(_)) => Err(CudaError::InternalError("conv2d does not support i32"))?,
(S::I64(_), S::I64(_)) => Err(CudaError::InternalError("conv2d does not support i64"))?,
_ => Err(CudaError::InternalError("dtype mismatch in conv2d"))?,
@@ -1854,6 +1882,11 @@ impl BackendStorage for CudaStorage {
*d.slice(dst_o..).device_ptr(),
"copy2d_u32",
),
(S::I16(s), S::I16(d)) => (
*s.slice(src_o..).device_ptr(),
*d.slice(dst_o..).device_ptr(),
"copy2d_i16",
),
(S::I32(s), S::I32(d)) => (
*s.slice(src_o..).device_ptr(),
*d.slice(dst_o..).device_ptr(),
@@ -1965,6 +1998,18 @@ impl BackendStorage for CudaStorage {
unsafe { func.launch(cfg, params) }.w()?
}
}
(CudaStorageSlice::I16(src), CudaStorageSlice::I16(dst)) => {
let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
if src_l.is_contiguous() {
dev.dtod_copy(&src, &mut dst).w()?
} else {
let func = dev.get_or_load_func("ucopy_i16", kernels::UNARY)?;
// SAFETY: Set later by running the kernel.
let params = (el_count, dims.len(), &ds, &src, &mut dst);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?
}
}
(CudaStorageSlice::I32(src), CudaStorageSlice::I32(dst)) => {
let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
if src_l.is_contiguous() {
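Taken together, the is_i16, gather_i16, ia_i16, and sa_i16 entries let i16 tensors serve as indices across the indexing ops. A hedged example for index_select (candle_core public API; signature assumed from the crate at this PR's vintage):

```rust
use candle_core::{Device, Tensor};

fn main() -> candle_core::Result<()> {
    let dev = Device::new_cuda(0)?;
    let t = Tensor::new(&[[1f32, 2.], [3., 4.], [5., 6.]], &dev)?;

    // i16 ids previously hit the UnexpectedDType error arm; they now
    // dispatch to the "is_i16" kernel.
    let ids = Tensor::new(&[2i16, 0], &dev)?;
    let rows = t.index_select(&ids, 0)?; // rows 2 then 0 -> [[5, 6], [1, 2]]
    println!("{rows}");
    Ok(())
}
```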
2 changes: 2 additions & 0 deletions candle-core/src/cuda_backend/utils.rs
@@ -19,6 +19,7 @@ pub trait Map1 {
let out = match s {
S::U8(s) => S::U8(self.f(s, d, l)?),
S::U32(s) => S::U32(self.f(s, d, l)?),
S::I16(s) => S::I16(self.f(s, d, l)?),
S::I32(s) => S::I32(self.f(s, d, l)?),
S::I64(s) => S::I64(self.f(s, d, l)?),
S::BF16(s) => S::BF16(self.f(s, d, l)?),
@@ -137,6 +138,7 @@ pub trait Map1Any {
let out = match s {
S::U8(s) => self.f(s, d, l, S::U8)?,
S::U32(s) => self.f(s, d, l, S::U32)?,
S::I16(s) => self.f(s, d, l, S::I16)?,
S::I32(s) => self.f(s, d, l, S::I32)?,
S::I64(s) => self.f(s, d, l, S::I64)?,
S::BF16(s) => self.f(s, d, l, S::BF16)?,
7 changes: 7 additions & 0 deletions candle-core/src/display.rs
@@ -55,6 +55,7 @@ impl std::fmt::Debug for Tensor {
match self.dtype() {
DType::U8 => self.fmt_dt::<u8>(f),
DType::U32 => self.fmt_dt::<u32>(f),
DType::I16 => self.fmt_dt::<i16>(f),
DType::I32 => self.fmt_dt::<i32>(f),
DType::I64 => self.fmt_dt::<i64>(f),
DType::BF16 => self.fmt_dt::<bf16>(f),
@@ -464,6 +465,12 @@ impl std::fmt::Display for Tensor {
tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
writeln!(f)?;
}
DType::I16 => {
let tf: IntFormatter<i16> = IntFormatter::new();
let max_w = tf.max_width(&to_display);
tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
writeln!(f)?;
}
DType::I32 => {
let tf: IntFormatter<i32> = IntFormatter::new();
let max_w = tf.max_width(&to_display);
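The Display arm reuses IntFormatter like the other integer dtypes, so i16 tensors print with columns sized by max_width. A small CPU-only sketch (assumes i16 is a supported element type for Tensor::new after this PR):

```rust
use candle_core::{Device, Tensor};

fn main() -> candle_core::Result<()> {
    let t = Tensor::new(&[[-1i16, 20], [300, -4000]], &Device::Cpu)?;
    println!("{t}");   // Display: IntFormatter<i16> sizes columns via max_width
    println!("{t:?}"); // Debug: routes through fmt_dt::<i16>
    Ok(())
}
```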