From 9e31a192642b4048e3df75173efddebaf663fef2 Mon Sep 17 00:00:00 2001
From: Rodrigo
Date: Sat, 14 Sep 2024 21:08:54 -0300
Subject: [PATCH] Add the i16 dtype (2) (#26)

* Add the i16 dtype

* Added I16 and I32 to fix the missing arms issue (candle-onnx/eval)

* Update rust-ci.yml

* Update ci_cuda.yaml

* fmt adjustment

* Revert "Update rust-ci.yml"

This reverts commit f659d36aed9e6e7ab7377c408a3859b8c8b94908.

* Revert "Update ci_cuda.yaml"

This reverts commit 62a4b3977e24bc7ac60195a6ae4363df36127125.
---
 candle-core/src/convert.rs              |   5 +
 candle-core/src/cpu/kernels.rs          |  11 +++
 candle-core/src/cpu_backend/mod.rs      | 124 +++++++++++++++++++++++-
 candle-core/src/cpu_backend/utils.rs    |   2 +
 candle-core/src/cuda_backend/device.rs  |  64 +++++++++---
 candle-core/src/cuda_backend/mod.rs     |  55 ++++++++++-
 candle-core/src/cuda_backend/utils.rs   |   2 +
 candle-core/src/display.rs              |   7 ++
 candle-core/src/dtype.rs                |  19 +++-
 candle-core/src/metal_backend/mod.rs    |  69 +++++++++++++
 candle-core/src/npy.rs                  |   6 ++
 candle-core/src/op.rs                   |  56 +++++++++++
 candle-core/src/safetensors.rs          |   4 +
 candle-core/src/sort.rs                 |   1 +
 candle-core/tests/tensor_tests.rs       |   6 +-
 candle-kernels/src/affine.cu            |   1 +
 candle-kernels/src/binary.cu            |  12 +++
 candle-kernels/src/cast.cu              |  14 +++
 candle-kernels/src/cuda_utils.cuh       |   2 +
 candle-kernels/src/fill.cu              |   2 +
 candle-kernels/src/indexing.cu          |  52 ++++++++++
 candle-kernels/src/reduce.cu            |   1 +
 candle-kernels/src/sort.cu              |   1 +
 candle-kernels/src/ternary.cu           |  12 +++
 candle-kernels/src/unary.cu             |   1 +
 candle-metal-kernels/src/binary.metal   |   2 +
 candle-metal-kernels/src/cast.metal     |  18 ++++
 candle-metal-kernels/src/indexing.metal |  26 ++++-
 candle-metal-kernels/src/lib.rs         |   7 ++
 candle-metal-kernels/src/reduce.metal   |   6 ++
 candle-metal-kernels/src/sort.metal     |   1 +
 candle-metal-kernels/src/ternary.metal  |  15 +++
 candle-metal-kernels/src/unary.metal    |   3 +
 candle-onnx/src/eval.rs                 |   4 +-
 candle-pyo3/src/lib.rs                  |   2 +
 35 files changed, 586 insertions(+), 27 deletions(-)

diff --git a/candle-core/src/convert.rs b/candle-core/src/convert.rs
index b29ff346f6..3e19d970c3 100644
--- a/candle-core/src/convert.rs
+++ b/candle-core/src/convert.rs
@@ -130,6 +130,11 @@ impl Tensor {
                     f.write_u32::<LittleEndian>(v)?
                 }
             }
+            DType::I16 => {
+                for v in vs.to_vec1::<i16>()? {
+                    f.write_i16::<LittleEndian>(v)?
+                }
+            }
             DType::I32 => {
                 for v in vs.to_vec1::<i32>()? {
                     f.write_i32::<LittleEndian>(v)?
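Note for reviewers, not part of the patch itself: end to end, the new dtype is meant to behave like the existing integer dtypes. A minimal sketch of the public API surface this enables (CPU device assumed; it mirrors the tensor_tests additions further down):

    use candle_core::{DType, Device, Tensor};

    fn main() -> candle_core::Result<()> {
        let dev = Device::Cpu;
        // Ones of the new dtype, read back as i16 (see candle-core/tests/tensor_tests.rs).
        let t = Tensor::ones((2, 3), DType::I16, &dev)?;
        assert_eq!(t.to_vec2::<i16>()?, [[1i16, 1, 1], [1, 1, 1]]);
        // Round-trip through the cast arms added in cpu_backend/mod.rs below.
        let back = t.to_dtype(DType::F32)?.to_dtype(DType::I16)?;
        assert_eq!(back.to_vec2::<i16>()?, [[1i16, 1, 1], [1, 1, 1]]);
        Ok(())
    }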
diff --git a/candle-core/src/cpu/kernels.rs b/candle-core/src/cpu/kernels.rs
index fd6da1f1ff..f81ad625d3 100644
--- a/candle-core/src/cpu/kernels.rs
+++ b/candle-core/src/cpu/kernels.rs
@@ -151,6 +151,17 @@ impl VecOps for u32 {
         <Self as Ord>::max(self, other)
     }
 }
+impl VecOps for i16 {
+    #[inline(always)]
+    fn min(self, other: Self) -> Self {
+        <Self as Ord>::min(self, other)
+    }
+
+    #[inline(always)]
+    fn max(self, other: Self) -> Self {
+        <Self as Ord>::max(self, other)
+    }
+}
 impl VecOps for i32 {
     #[inline(always)]
     fn min(self, other: Self) -> Self {
diff --git a/candle-core/src/cpu_backend/mod.rs b/candle-core/src/cpu_backend/mod.rs
index 54d2da7d12..24ce83581c 100644
--- a/candle-core/src/cpu_backend/mod.rs
+++ b/candle-core/src/cpu_backend/mod.rs
@@ -22,6 +22,7 @@ const USE_IM2COL_CONV2D: bool = true;
 pub enum CpuStorage {
     U8(Vec<u8>),
     U32(Vec<u32>),
+    I16(Vec<i16>),
     I32(Vec<i32>),
     I64(Vec<i64>),
     BF16(Vec<bf16>),
@@ -34,6 +35,7 @@ pub enum CpuStorageRef<'a> {
     U8(&'a [u8]),
     U32(&'a [u32]),
+    I16(&'a [i16]),
     I32(&'a [i32]),
     I64(&'a [i64]),
     BF16(&'a [bf16]),
@@ -2287,6 +2289,17 @@ impl CpuStorage {
                     .concat();
                 Self::U32(storages)
             }
+            Self::I16(_) => {
+                let storages = storages
+                    .iter()
+                    .map(|s| match s {
+                        Self::I16(s) => Ok(s.as_slice()),
+                        _ => crate::bail!("dtype mismatch"),
+                    })
+                    .collect::<Result<Vec<_>>>()?
+                    .concat();
+                Self::I16(storages)
+            }
             Self::I32(_) => {
                 let storages = storages
                     .iter()
@@ -2365,6 +2378,7 @@ impl BackendStorage for CpuStorage {
         match self {
             Self::U8(_) => DType::U8,
             Self::U32(_) => DType::U32,
+            Self::I16(_) => DType::I16,
             Self::I32(_) => DType::I32,
             Self::I64(_) => DType::I64,
             Self::BF16(_) => DType::BF16,
@@ -2385,6 +2399,10 @@ impl BackendStorage for CpuStorage {
                 let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32));
                 Ok(Self::BF16(data))
             }
+            (Self::I16(storage), DType::BF16) => {
+                let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32));
+                Ok(Self::BF16(data))
+            }
             (Self::I32(storage), DType::BF16) => {
                 let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32));
                 Ok(Self::BF16(data))
@@ -2417,6 +2435,10 @@ impl BackendStorage for CpuStorage {
                 let data = unary_map(storage, layout, |v| f16::from_f32(v as f32));
                 Ok(Self::F16(data))
             }
+            (Self::I16(storage), DType::F16) => {
+                let data = unary_map(storage, layout, |v| f16::from_f32(v as f32));
+                Ok(Self::F16(data))
+            }
             (Self::I32(storage), DType::F16) => {
                 let data = unary_map(storage, layout, |v| f16::from_f32(v as f32));
                 Ok(Self::F16(data))
@@ -2449,6 +2471,10 @@ impl BackendStorage for CpuStorage {
                 let data = unary_map(storage, layout, |v| v as f32);
                 Ok(Self::F32(data))
             }
+            (Self::I16(storage), DType::F32) => {
+                let data = unary_map(storage, layout, |v| v as f32);
+                Ok(Self::F32(data))
+            }
             (Self::I32(storage), DType::F32) => {
                 let data = unary_map(storage, layout, |v| v as f32);
                 Ok(Self::F32(data))
@@ -2497,6 +2523,10 @@ impl BackendStorage for CpuStorage {
                 let data = unary_map(storage, layout, |v| v as u8);
                 Ok(Self::U8(data))
             }
+            (Self::I16(storage), DType::U8) => {
+                let data = unary_map(storage, layout, |v| v as u8);
+                Ok(Self::U8(data))
+            }
             (Self::I32(storage), DType::U8) => {
                 let data = unary_map(storage, layout, |v| v as u8);
                 Ok(Self::U8(data))
@@ -2513,6 +2543,10 @@ impl BackendStorage for CpuStorage {
                 let data = unary_map(storage, layout, |v| v);
                 Ok(Self::U32(data))
             }
+            (Self::I16(storage), DType::U32) => {
+                let data = unary_map(storage, layout, |v| v as u32);
+                Ok(Self::U32(data))
+            }
             (Self::I32(storage), DType::U32) => {
                 let data = unary_map(storage, layout, |v| v as u32);
                 Ok(Self::U32(data))
@@ -2537,6 +2571,42 @@ impl BackendStorage for CpuStorage {
                 let data = unary_map(storage, layout, |v| v as u32);
                 Ok(Self::U32(data))
             }
+            (Self::U8(storage), DType::I16) => {
+                let data = unary_map(storage, layout, |v| v as i16);
+                Ok(Self::I16(data))
+            }
+            (Self::U32(storage), DType::I16) => {
+                let data = unary_map(storage, layout, |v| v as i16);
+                Ok(Self::I16(data))
+            }
+            (Self::I16(storage), DType::I16) => {
+                let data = unary_map(storage, layout, |v| v);
+                Ok(Self::I16(data))
+            }
+            (Self::I32(storage), DType::I16) => {
+                let data = unary_map(storage, layout, |v| v as i16);
+                Ok(Self::I16(data))
+            }
+            (Self::I64(storage), DType::I16) => {
+                let data = unary_map(storage, layout, |v| v as i16);
+                Ok(Self::I16(data))
+            }
+            (Self::BF16(storage), DType::I16) => {
+                let data = unary_map(storage, layout, |v| v.to_f32() as i16);
+                Ok(Self::I16(data))
+            }
+            (Self::F16(storage), DType::I16) => {
+                let data = unary_map(storage, layout, |v| v.to_f32() as i16);
+                Ok(Self::I16(data))
+            }
+            (Self::F32(storage), DType::I16) => {
+                let data = unary_map(storage, layout, |v| v as i16);
+                Ok(Self::I16(data))
+            }
+            (Self::F64(storage), DType::I16) => {
+                let data = unary_map(storage, layout, |v| v as i16);
+                Ok(Self::I16(data))
+            }
             (Self::U8(storage), DType::I32) => {
                 let data = unary_map(storage, layout, |v| v as i64);
                 Ok(Self::I64(data))
@@ -2545,6 +2615,10 @@ impl BackendStorage for CpuStorage {
                 let data = unary_map(storage, layout, |v| v as i64);
                 Ok(Self::I64(data))
             }
+            (Self::I16(storage), DType::I32) => {
+                let data = unary_map(storage, layout, |v| v as i32);
+                Ok(Self::I32(data))
+            }
             (Self::I32(storage), DType::I32) => {
                 let data = unary_map(storage, layout, |v| v);
                 Ok(Self::I32(data))
@@ -2577,6 +2651,10 @@ impl BackendStorage for CpuStorage {
                 let data = unary_map(storage, layout, |v| v as i64);
                 Ok(Self::I64(data))
             }
+            (Self::I16(storage), DType::I64) => {
+                let data = unary_map(storage, layout, |v| v as i64);
+                Ok(Self::I64(data))
+            }
             (Self::I32(storage), DType::I64) => {
                 let data = unary_map(storage, layout, |v| v as i64);
                 Ok(Self::I64(data))
@@ -2609,6 +2687,10 @@ impl BackendStorage for CpuStorage {
                 let data = unary_map(storage, layout, |v| v as f64);
                 Ok(Self::F64(data))
             }
+            (Self::I16(storage), DType::F64) => {
+                let data = unary_map(storage, layout, |v| v as f64);
+                Ok(Self::F64(data))
+            }
             (Self::I32(storage), DType::F64) => {
                 let data = unary_map(storage, layout, |v| v as f64);
                 Ok(Self::F64(data))
@@ -2748,6 +2830,7 @@ impl BackendStorage for CpuStorage {
             }
             Self::U8(_) => Err(Error::UnsupportedDTypeForOp(DType::U8, "elu").bt()),
             Self::U32(_) => Err(Error::UnsupportedDTypeForOp(DType::U32, "elu").bt()),
+            Self::I16(_) => Err(Error::UnsupportedDTypeForOp(DType::I16, "elu").bt()),
             Self::I32(_) => Err(Error::UnsupportedDTypeForOp(DType::I32, "elu").bt()),
             Self::I64(_) => Err(Error::UnsupportedDTypeForOp(DType::I64, "elu").bt()),
         }
@@ -2774,7 +2857,8 @@ impl BackendStorage for CpuStorage {
             }
             Self::U8(_) => Err(Error::UnsupportedDTypeForOp(DType::U8, "elu").bt()),
             Self::U32(_) => Err(Error::UnsupportedDTypeForOp(DType::U32, "elu").bt()),
-            Self::I32(_) => Err(Error::UnsupportedDTypeForOp(DType::I64, "elu").bt()),
+            Self::I16(_) => Err(Error::UnsupportedDTypeForOp(DType::I16, "elu").bt()),
+            Self::I32(_) => Err(Error::UnsupportedDTypeForOp(DType::I32, "elu").bt()),
             Self::I64(_) => Err(Error::UnsupportedDTypeForOp(DType::I64, "elu").bt()),
         }
     }
@@ -2825,6 +2909,10 @@ impl BackendStorage for CpuStorage {
                 let data = unary_map(storage, layout, B::u32);
                 Ok(Self::U32(data))
             }
+            Self::I16(storage) => {
+                let data =
unary_map(storage, layout, B::i16); + Ok(Self::I16(data)) + } Self::I32(storage) => { let data = unary_map(storage, layout, B::i32); Ok(Self::I32(data)) @@ -2883,6 +2971,14 @@ impl BackendStorage for CpuStorage { }; Ok(Self::U32(data)) } + (Self::I16(lhs), Self::I16(rhs)) => { + let data = if B::I16_VEC { + binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::i16, B::i16_vec) + } else { + binary_map(lhs_l, rhs_l, lhs, rhs, B::i16) + }; + Ok(Self::I16(data)) + } (Self::I32(lhs), Self::I32(rhs)) => { let data = if B::I32_VEC { binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::i32, B::i32_vec) @@ -2934,6 +3030,9 @@ impl BackendStorage for CpuStorage { (Self::U32(src), Self::U32(dst)) => { copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o) } + (Self::I16(src), Self::I16(dst)) => { + copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o) + } (Self::I32(src), Self::I32(dst)) => { copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o) } @@ -2968,6 +3067,7 @@ impl BackendStorage for CpuStorage { match (self, dst) { (Self::U8(src), Self::U8(dst)) => copy_strided_src_(src, dst, dst_offset, src_l), (Self::U32(src), Self::U32(dst)) => copy_strided_src_(src, dst, dst_offset, src_l), + (Self::I16(src), Self::I16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l), (Self::I32(src), Self::I32(dst)) => copy_strided_src_(src, dst, dst_offset, src_l), (Self::I64(src), Self::I64(dst)) => copy_strided_src_(src, dst, dst_offset, src_l), (Self::BF16(src), Self::BF16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l), @@ -2998,6 +3098,7 @@ impl BackendStorage for CpuStorage { match self { Self::U8(pred) => WCond(pred, layout).map(t, t_l, f, f_l), Self::U32(pred) => WCond(pred, layout).map(t, t_l, f, f_l), + Self::I16(pred) => WCond(pred, layout).map(t, t_l, f, f_l), Self::I32(pred) => WCond(pred, layout).map(t, t_l, f, f_l), Self::I64(pred) => WCond(pred, layout).map(t, t_l, f, f_l), _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "where-cond")), @@ -3169,6 +3270,7 @@ impl BackendStorage for CpuStorage { match ids { Self::U8(ids) => IndexSelect { ids, ids_l, dim }.map(self, l), Self::U32(ids) => IndexSelect { ids, ids_l, dim }.map(self, l), + Self::I16(ids) => IndexSelect { ids, ids_l, dim }.map(self, l), Self::I32(ids) => IndexSelect { ids, ids_l, dim }.map(self, l), Self::I64(ids) => IndexSelect { ids, ids_l, dim }.map(self, l), _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-select").bt()), @@ -3179,6 +3281,7 @@ impl BackendStorage for CpuStorage { match ids { Self::U8(ids) => Gather { ids, ids_l, dim }.map(self, l), Self::U32(ids) => Gather { ids, ids_l, dim }.map(self, l), + Self::I16(ids) => Gather { ids, ids_l, dim }.map(self, l), Self::I32(ids) => Gather { ids, ids_l, dim }.map(self, l), Self::I64(ids) => Gather { ids, ids_l, dim }.map(self, l), _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "gather").bt()), @@ -3197,6 +3300,7 @@ impl BackendStorage for CpuStorage { match ids { Self::U8(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l), Self::U32(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l), + Self::I16(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l), Self::I32(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l), Self::I64(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l), _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "scatter-add").bt()), @@ -3227,6 +3331,13 @@ impl BackendStorage for CpuStorage { }; IndexAdd { ids, dim }.map(self, l, src, src_l) } + Self::I16(ids) => { + let ids = match ids_l.contiguous_offsets() 
{ + Some((a, b)) => &ids[a..b], + None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?, + }; + IndexAdd { ids, dim }.map(self, l, src, src_l) + } Self::I32(ids) => { let ids = match ids_l.contiguous_offsets() { Some((a, b)) => &ids[a..b], @@ -3323,7 +3434,7 @@ impl BackendDevice for CpuDevice { let elem_count = shape.elem_count(); let mut rng = rand::thread_rng(); match dtype { - DType::U8 | DType::U32 | DType::I32 | DType::I64 => { + DType::U8 | DType::U32 | DType::I16 | DType::I32 | DType::I64 => { Err(Error::UnsupportedDTypeForOp(dtype, "rand_uniform").bt()) } DType::BF16 => { @@ -3369,7 +3480,7 @@ impl BackendDevice for CpuDevice { let elem_count = shape.elem_count(); let mut rng = rand::thread_rng(); match dtype { - DType::U8 | DType::U32 | DType::I32 | DType::I64 => { + DType::U8 | DType::U32 | DType::I16 | DType::I32 | DType::I64 => { Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal").bt()) } DType::BF16 => { @@ -3428,6 +3539,11 @@ impl BackendDevice for CpuDevice { v.set_len(elem_count); CpuStorage::U32(v) } + DType::I16 => { + let mut v = Vec::with_capacity(elem_count); + v.set_len(elem_count); + CpuStorage::I16(v) + } DType::I32 => { let mut v = Vec::with_capacity(elem_count); v.set_len(elem_count); @@ -3467,6 +3583,7 @@ impl BackendDevice for CpuDevice { let storage = match dtype { DType::U8 => CpuStorage::U8(vec![1u8; elem_count]), DType::U32 => CpuStorage::U32(vec![1u32; elem_count]), + DType::I16 => CpuStorage::I16(vec![1i16; elem_count]), DType::I32 => CpuStorage::I32(vec![1i32; elem_count]), DType::I64 => CpuStorage::I64(vec![1i64; elem_count]), DType::BF16 => CpuStorage::BF16(vec![bf16::ONE; elem_count]), @@ -3482,6 +3599,7 @@ impl BackendDevice for CpuDevice { let storage = match dtype { DType::U8 => CpuStorage::U8(vec![0u8; elem_count]), DType::U32 => CpuStorage::U32(vec![0u32; elem_count]), + DType::I16 => CpuStorage::I16(vec![0i16; elem_count]), DType::I32 => CpuStorage::I32(vec![0i32; elem_count]), DType::I64 => CpuStorage::I64(vec![0i64; elem_count]), DType::BF16 => CpuStorage::BF16(vec![bf16::ZERO; elem_count]), diff --git a/candle-core/src/cpu_backend/utils.rs b/candle-core/src/cpu_backend/utils.rs index 297ccd3de6..20f362e8c4 100644 --- a/candle-core/src/cpu_backend/utils.rs +++ b/candle-core/src/cpu_backend/utils.rs @@ -10,6 +10,7 @@ pub trait Map1 { match vs { C::U8(vs) => Ok(C::U8(self.f(vs, layout)?)), C::U32(vs) => Ok(C::U32(self.f(vs, layout)?)), + C::I16(vs) => Ok(C::I16(self.f(vs, layout)?)), C::I32(vs) => Ok(C::I32(self.f(vs, layout)?)), C::I64(vs) => Ok(C::I64(self.f(vs, layout)?)), C::BF16(vs) => Ok(C::BF16(self.f(vs, layout)?)), @@ -27,6 +28,7 @@ pub trait Map1Any { match vs { C::U8(vs) => Ok(self.f(vs, layout, C::U8)?), C::U32(vs) => Ok(self.f(vs, layout, C::U32)?), + C::I16(vs) => Ok(self.f(vs, layout, C::I16)?), C::I32(vs) => Ok(self.f(vs, layout, C::I32)?), C::I64(vs) => Ok(self.f(vs, layout, C::I64)?), C::BF16(vs) => Ok(self.f(vs, layout, C::BF16)?), diff --git a/candle-core/src/cuda_backend/device.rs b/candle-core/src/cuda_backend/device.rs index 9dd4477639..ccca8c039c 100644 --- a/candle-core/src/cuda_backend/device.rs +++ b/candle-core/src/cuda_backend/device.rs @@ -80,6 +80,14 @@ impl CudaDevice { unsafe { func.launch(cfg, params) }.w()?; CudaStorageSlice::U32(data) } + DType::I16 => { + // SAFETY: Set later by running the fill kernel. 
+                let data = unsafe { self.alloc::<i16>(elem_count) }.w()?;
+                let func = self.get_or_load_func("fill_i16", kernels::FILL)?;
+                let params = (&data, v as i16, elem_count);
+                unsafe { func.launch(cfg, params) }.w()?;
+                CudaStorageSlice::I16(data)
+            }
             DType::I32 => {
                 // SAFETY: Set later by running the fill kernel.
                 let data = unsafe { self.alloc::<i32>(elem_count) }.w()?;
@@ -207,6 +215,10 @@ impl BackendDevice for CudaDevice {
                 let data = self.alloc_zeros::<u32>(elem_count).w()?;
                 CudaStorageSlice::U32(data)
             }
+            DType::I16 => {
+                let data = self.alloc_zeros::<i16>(elem_count).w()?;
+                CudaStorageSlice::I16(data)
+            }
             DType::I32 => {
                 let data = self.alloc_zeros::<i32>(elem_count).w()?;
                 CudaStorageSlice::I32(data)
@@ -244,13 +256,17 @@ impl BackendDevice for CudaDevice {
         let slice = match dtype {
             // TODO: Add support for F16 and BF16 though this is likely to require some upstream
             // cudarc changes.
-            DType::U8 | DType::U32 | DType::I64 | DType::I32 | DType::F16 | DType::BF16 => {
-                Err(CudaError::UnsupportedDtype {
-                    dtype,
-                    op: "rand_uniform",
-                })
-                .w()?
-            }
+            DType::U8
+            | DType::U32
+            | DType::I64
+            | DType::I32
+            | DType::I16
+            | DType::F16
+            | DType::BF16 => Err(CudaError::UnsupportedDtype {
+                dtype,
+                op: "rand_uniform",
+            })
+            .w()?,
             DType::F32 => {
                 let mut data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
                 curand.0.fill_with_uniform(&mut data).w()?;
@@ -288,13 +304,17 @@ impl BackendDevice for CudaDevice {
             elem_count
         };
         let slice = match dtype {
-            DType::U8 | DType::U32 | DType::I32 | DType::I64 | DType::F16 | DType::BF16 => {
-                Err(CudaError::UnsupportedDtype {
-                    dtype,
-                    op: "rand_normal",
-                })
-                .w()?
-            }
+            DType::U8
+            | DType::U32
+            | DType::I16
+            | DType::I32
+            | DType::I64
+            | DType::F16
+            | DType::BF16 => Err(CudaError::UnsupportedDtype {
+                dtype,
+                op: "rand_normal",
+            })
+            .w()?,
             DType::F32 => {
                 let mut data = unsafe { self.alloc::<f32>(elem_count_round) }.w()?;
                 curand
@@ -330,6 +350,10 @@ impl BackendDevice for CudaDevice {
                 let data = self.alloc::<u32>(elem_count).w()?;
                 CudaStorageSlice::U32(data)
             }
+            DType::I16 => {
+                let data = self.alloc::<i16>(elem_count).w()?;
+                CudaStorageSlice::I16(data)
+            }
             DType::I32 => {
                 let data = self.alloc::<i32>(elem_count).w()?;
                 CudaStorageSlice::I32(data)
@@ -371,6 +395,10 @@ impl BackendDevice for CudaDevice {
                 let data = self.htod_sync_copy(storage).w()?;
                 CudaStorageSlice::U32(data)
             }
+            CpuStorageRef::I16(storage) => {
+                let data = self.htod_sync_copy(storage).w()?;
+                CudaStorageSlice::I16(data)
+            }
             CpuStorageRef::I32(storage) => {
                 let data = self.htod_sync_copy(storage).w()?;
                 CudaStorageSlice::I32(data)
@@ -412,6 +440,10 @@ impl BackendDevice for CudaDevice {
                 let data = self.htod_sync_copy(storage).w()?;
                 CudaStorageSlice::U32(data)
             }
+            CpuStorage::I16(storage) => {
+                let data = self.htod_sync_copy(storage).w()?;
+                CudaStorageSlice::I16(data)
+            }
             CpuStorage::I32(storage) => {
                 let data = self.htod_sync_copy(storage).w()?;
                 CudaStorageSlice::I32(data)
@@ -453,6 +485,10 @@ impl BackendDevice for CudaDevice {
                 let data = self.htod_copy(storage).w()?;
                 CudaStorageSlice::U32(data)
             }
+            CpuStorage::I16(storage) => {
+                let data = self.htod_copy(storage).w()?;
+                CudaStorageSlice::I16(data)
+            }
             CpuStorage::I32(storage) => {
                 let data = self.htod_copy(storage).w()?;
                 CudaStorageSlice::I32(data)
diff --git a/candle-core/src/cuda_backend/mod.rs b/candle-core/src/cuda_backend/mod.rs
index 231e24715c..1a394d4b58 100644
--- a/candle-core/src/cuda_backend/mod.rs
+++ b/candle-core/src/cuda_backend/mod.rs
@@ -47,6 +47,7 @@ impl SlicePtrOrNull {
 pub enum CudaStorageSlice {
     U8(CudaSlice<u8>),
     U32(CudaSlice<u32>),
+    I16(CudaSlice<i16>),
     I32(CudaSlice<i32>),
     I64(CudaSlice<i64>),
     BF16(CudaSlice<bf16>),
@@ -364,6 +365,9 @@ impl<'a> Map1 for IndexSelect<'a> {
             CudaStorageSlice::U8(slice) => {
                 ("is_u8", *slice.slice(ids_l.start_offset()..).device_ptr())
             }
+            CudaStorageSlice::I16(slice) => {
+                ("is_i16", *slice.slice(ids_l.start_offset()..).device_ptr())
+            }
             CudaStorageSlice::I32(slice) => {
                 ("is_i32", *slice.slice(ids_l.start_offset()..).device_ptr())
             }
             CudaStorageSlice::I64(slice) => {
                 ("is_i64", *slice.slice(ids_l.start_offset()..).device_ptr())
             }
             _ => Err(CudaError::UnexpectedDType {
-                msg: "index_select ids should be u8/u32/i32/i64",
+                msg: "index_select ids should be u8/u32/i16/i32/i64",
                 expected: DType::U32,
                 got: self.0.dtype(),
             })
@@ -431,6 +435,9 @@ impl<'a> Map1 for Gather<'a> {
                 ("gather_u32", *slice.slice(ids_o1..ids_o2).device_ptr())
             }
             CudaStorageSlice::U8(slice) => ("gather_u8", *slice.slice(ids_o1..ids_o2).device_ptr()),
+            CudaStorageSlice::I16(slice) => {
+                ("gather_i16", *slice.slice(ids_o1..ids_o2).device_ptr())
+            }
             CudaStorageSlice::I32(slice) => {
                 ("gather_i32", *slice.slice(ids_o1..ids_o2).device_ptr())
             }
             CudaStorageSlice::I64(slice) => {
                 ("gather_i64", *slice.slice(ids_o1..ids_o2).device_ptr())
             }
             _ => Err(CudaError::UnexpectedDType {
-                msg: "gather ids should be u8/u32/i32/i64",
+                msg: "gather ids should be u8/u32/i16/i32/i64",
                 expected: DType::U32,
                 got: ids.dtype(),
             })?,
@@ -484,11 +491,12 @@ impl<'a> Map2InPlace for IndexAdd<'a> {
         };
         let (name, ids) = match &ids.slice {
             CudaStorageSlice::U32(slice) => ("ia_u32", *slice.slice(ids_o1..ids_o2).device_ptr()),
+            CudaStorageSlice::I16(slice) => ("ia_i16", *slice.slice(ids_o1..ids_o2).device_ptr()),
             CudaStorageSlice::I32(slice) => ("ia_i32", *slice.slice(ids_o1..ids_o2).device_ptr()),
             CudaStorageSlice::I64(slice) => ("ia_i64", *slice.slice(ids_o1..ids_o2).device_ptr()),
             CudaStorageSlice::U8(slice) => ("ia_u8", *slice.slice(ids_o1..ids_o2).device_ptr()),
             _ => Err(CudaError::UnexpectedDType {
-                msg: "index-add ids should be u8/u32/i32/i64",
+                msg: "index-add ids should be u8/u32/i16/i32/i64",
                 expected: DType::U32,
                 got: ids.dtype(),
             })?,
@@ -533,11 +541,12 @@ impl<'a> Map2InPlace for ScatterAdd<'a> {
         };
         let (name, ids) = match &ids.slice {
             CudaStorageSlice::U32(slice) => ("sa_u32", *slice.slice(ids_o1..ids_o2).device_ptr()),
+            CudaStorageSlice::I16(slice) => ("sa_i16", *slice.slice(ids_o1..ids_o2).device_ptr()),
             CudaStorageSlice::I32(slice) => ("sa_i32", *slice.slice(ids_o1..ids_o2).device_ptr()),
             CudaStorageSlice::I64(slice) => ("sa_i64", *slice.slice(ids_o1..ids_o2).device_ptr()),
             CudaStorageSlice::U8(slice) => ("sa_u8", *slice.slice(ids_o1..ids_o2).device_ptr()),
             _ => Err(CudaError::UnexpectedDType {
-                msg: "scatter-add ids should be u8/u32/i32/i64",
+                msg: "scatter-add ids should be u8/u32/i16/i32/i64",
                 expected: DType::U32,
                 got: ids.dtype(),
             })?,
@@ -876,6 +885,10 @@ impl<'a> Map2 for WhereCond<'a> {
                 let ptr = *slice.slice(ids_l.start_offset()..).device_ptr();
                 (ptr, "where_u32")
             }
+            CudaStorageSlice::I16(slice) => {
+                let ptr = *slice.slice(ids_l.start_offset()..).device_ptr();
+                (ptr, "where_i16")
+            }
             CudaStorageSlice::I32(slice) => {
                 let ptr = *slice.slice(ids_l.start_offset()..).device_ptr();
                 (ptr, "where_i32")
             }
             CudaStorageSlice::I64(slice) => {
                 let ptr = *slice.slice(ids_l.start_offset()..).device_ptr();
                 (ptr, "where_i64")
             }
             _ => Err(CudaError::UnexpectedDType {
-                msg: "where conditions should be u8/u32/i64",
+                msg: "where conditions should be u8/u32/i16/i32/i64",
                 expected: DType::U32,
                 got: self.0.dtype(),
             })
@@ -1039,6 +1052,7 @@ macro_rules! cuda_dtype {
cuda_dtype { } cuda_dtype!(u8, U8); cuda_dtype!(u32, U32); +cuda_dtype!(i16, I16); cuda_dtype!(i32, I32); cuda_dtype!(i64, I64); cuda_dtype!(f16, F16); @@ -1162,6 +1176,7 @@ impl BackendStorage for CudaStorage { match self.slice { CudaStorageSlice::U8(_) => DType::U8, CudaStorageSlice::U32(_) => DType::U32, + CudaStorageSlice::I16(_) => DType::I16, CudaStorageSlice::I32(_) => DType::I32, CudaStorageSlice::I64(_) => DType::I64, CudaStorageSlice::BF16(_) => DType::BF16, @@ -1189,6 +1204,7 @@ impl BackendStorage for CudaStorage { let inp = match &self.slice { CudaStorageSlice::U8(inp) => *inp.slice(start_o..).device_ptr(), CudaStorageSlice::U32(inp) => *inp.slice(start_o..).device_ptr(), + CudaStorageSlice::I16(inp) => *inp.slice(start_o..).device_ptr(), CudaStorageSlice::I32(inp) => *inp.slice(start_o..).device_ptr(), CudaStorageSlice::I64(inp) => *inp.slice(start_o..).device_ptr(), CudaStorageSlice::BF16(inp) => *inp.slice(start_o..).device_ptr(), @@ -1213,6 +1229,12 @@ impl BackendStorage for CudaStorage { unsafe { func.launch(cfg, params) }.w()?; CudaStorageSlice::U32(out) } + DType::I16 => { + let out = unsafe { dev.alloc::(el) }.w()?; + let params = (el, dims.len(), &ds, *inp, &out); + unsafe { func.launch(cfg, params) }.w()?; + CudaStorageSlice::I16(out) + } DType::I32 => { let out = unsafe { dev.alloc::(el) }.w()?; let params = (el, dims.len(), &ds, *inp, &out); @@ -1315,6 +1337,11 @@ impl BackendStorage for CudaStorage { let cpu_storage = dev.dtoh_sync_copy(slice).w()?; Ok(CpuStorage::U32(cpu_storage)) } + CudaStorageSlice::I16(slice) => { + let dev = slice.device(); + let cpu_storage = dev.dtoh_sync_copy(slice).w()?; + Ok(CpuStorage::I16(cpu_storage)) + } CudaStorageSlice::I32(slice) => { let dev = slice.device(); let cpu_storage = dev.dtoh_sync_copy(slice).w()?; @@ -1587,6 +1614,7 @@ impl BackendStorage for CudaStorage { S::F64(out) } (S::U32(_), S::U32(_)) => Err(CudaError::InternalError("conv2d does not support u32"))?, + (S::I16(_), S::I16(_)) => Err(CudaError::InternalError("conv2d does not support i16"))?, (S::I32(_), S::I32(_)) => Err(CudaError::InternalError("conv2d does not support i32"))?, (S::I64(_), S::I64(_)) => Err(CudaError::InternalError("conv2d does not support i64"))?, _ => Err(CudaError::InternalError("dtype mismatch in conv2d"))?, @@ -1854,6 +1882,11 @@ impl BackendStorage for CudaStorage { *d.slice(dst_o..).device_ptr(), "copy2d_u32", ), + (S::I16(s), S::I16(d)) => ( + *s.slice(src_o..).device_ptr(), + *d.slice(dst_o..).device_ptr(), + "copy2d_i16", + ), (S::I32(s), S::I32(d)) => ( *s.slice(src_o..).device_ptr(), *d.slice(dst_o..).device_ptr(), @@ -1965,6 +1998,18 @@ impl BackendStorage for CudaStorage { unsafe { func.launch(cfg, params) }.w()? } } + (CudaStorageSlice::I16(src), CudaStorageSlice::I16(dst)) => { + let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset); + if src_l.is_contiguous() { + dev.dtod_copy(&src, &mut dst).w()? + } else { + let func = dev.get_or_load_func("ucopy_i16", kernels::UNARY)?; + // SAFETY: Set later by running the kernel. + let params = (el_count, dims.len(), &ds, &src, &mut dst); + // SAFETY: ffi. + unsafe { func.launch(cfg, params) }.w()? 
+                }
+            }
             (CudaStorageSlice::I32(src), CudaStorageSlice::I32(dst)) => {
                 let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
                 if src_l.is_contiguous() {
diff --git a/candle-core/src/cuda_backend/utils.rs b/candle-core/src/cuda_backend/utils.rs
index ae009b26ab..df06756d78 100644
--- a/candle-core/src/cuda_backend/utils.rs
+++ b/candle-core/src/cuda_backend/utils.rs
@@ -19,6 +19,7 @@ pub trait Map1 {
         let out = match s {
             S::U8(s) => S::U8(self.f(s, d, l)?),
             S::U32(s) => S::U32(self.f(s, d, l)?),
+            S::I16(s) => S::I16(self.f(s, d, l)?),
             S::I32(s) => S::I32(self.f(s, d, l)?),
             S::I64(s) => S::I64(self.f(s, d, l)?),
             S::BF16(s) => S::BF16(self.f(s, d, l)?),
@@ -137,6 +138,7 @@ pub trait Map1Any {
         let out = match s {
             S::U8(s) => self.f(s, d, l, S::U8)?,
             S::U32(s) => self.f(s, d, l, S::U32)?,
+            S::I16(s) => self.f(s, d, l, S::I16)?,
             S::I32(s) => self.f(s, d, l, S::I32)?,
             S::I64(s) => self.f(s, d, l, S::I64)?,
             S::BF16(s) => self.f(s, d, l, S::BF16)?,
diff --git a/candle-core/src/display.rs b/candle-core/src/display.rs
index 5fb370b696..50e0129aeb 100644
--- a/candle-core/src/display.rs
+++ b/candle-core/src/display.rs
@@ -55,6 +55,7 @@ impl std::fmt::Debug for Tensor {
         match self.dtype() {
             DType::U8 => self.fmt_dt::<u8>(f),
             DType::U32 => self.fmt_dt::<u32>(f),
+            DType::I16 => self.fmt_dt::<i16>(f),
             DType::I32 => self.fmt_dt::<i32>(f),
             DType::I64 => self.fmt_dt::<i64>(f),
             DType::BF16 => self.fmt_dt::<bf16>(f),
@@ -464,6 +465,12 @@ impl std::fmt::Display for Tensor {
                 tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
                 writeln!(f)?;
             }
+            DType::I16 => {
+                let tf: IntFormatter<i16> = IntFormatter::new();
+                let max_w = tf.max_width(&to_display);
+                tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
+                writeln!(f)?;
+            }
             DType::I32 => {
                 let tf: IntFormatter<i32> = IntFormatter::new();
                 let max_w = tf.max_width(&to_display);
diff --git a/candle-core/src/dtype.rs b/candle-core/src/dtype.rs
index c6a0800b24..42d3b1eef9 100644
--- a/candle-core/src/dtype.rs
+++ b/candle-core/src/dtype.rs
@@ -10,6 +10,8 @@ pub enum DType {
     U8,
     // Unsigned 32 bits integer.
     U32,
+    // Signed 16 bits integer.
+    I16,
     // Signed 32 bits integer.
     I32,
     // Signed 64 bits integer.
@@ -41,6 +43,7 @@ impl std::str::FromStr for DType { match s { "u8" => Ok(Self::U8), "u32" => Ok(Self::U32), + "i16" => Ok(Self::I16), "i32" => Ok(Self::I32), "i64" => Ok(Self::I64), "bf16" => Ok(Self::BF16), @@ -58,6 +61,7 @@ impl DType { match self { Self::U8 => "u8", Self::U32 => "u32", + Self::I16 => "i16", Self::I32 => "i32", Self::I64 => "i64", Self::BF16 => "bf16", @@ -72,6 +76,7 @@ impl DType { match self { Self::U8 => 1, Self::U32 => 4, + Self::I16 => 2, Self::I32 => 4, Self::I64 => 8, Self::BF16 => 2, @@ -83,14 +88,14 @@ impl DType { pub fn is_int(&self) -> bool { match self { - Self::U8 | Self::U32 | Self::I32 | Self::I64 => true, + Self::U8 | Self::U32 | Self::I16 | Self::I32 | Self::I64 => true, Self::BF16 | Self::F16 | Self::F32 | Self::F64 => false, } } pub fn is_float(&self) -> bool { match self { - Self::U8 | Self::U32 | Self::I32 | Self::I64 => false, + Self::U8 | Self::U32 | Self::I16 | Self::I32 | Self::I64 => false, Self::BF16 | Self::F16 | Self::F32 | Self::F64 => true, } } @@ -174,6 +179,7 @@ use half::{bf16, f16}; with_dtype!(u8, U8, |v: f64| v as u8, |v: u8| v as f64); with_dtype!(u32, U32, |v: f64| v as u32, |v: u32| v as f64); +with_dtype!(i16, I16, |v: f64| v as i16, |v: i16| v as f64); with_dtype!(i32, I32, |v: f64| v as i32, |v: i32| v as f64); with_dtype!(i64, I64, |v: f64| v as i64, |v: i64| v as f64); with_dtype!(f16, F16, f16::from_f64, f16::to_f64); @@ -186,6 +192,15 @@ pub trait IntDType: WithDType { fn as_usize(&self) -> usize; } +impl IntDType for i16 { + fn is_true(&self) -> bool { + *self != 0 + } + fn as_usize(&self) -> usize { + *self as usize + } +} + impl IntDType for i32 { fn is_true(&self) -> bool { *self != 0 diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs index 19bac09e15..bf641eb86a 100644 --- a/candle-core/src/metal_backend/mod.rs +++ b/candle-core/src/metal_backend/mod.rs @@ -96,6 +96,7 @@ impl BackendStorage for MetalStorage { match self.dtype { DType::U8 => Ok(CpuStorage::U8(self.to_cpu()?)), DType::U32 => Ok(CpuStorage::U32(self.to_cpu()?)), + DType::I16 => Ok(CpuStorage::I16(self.to_cpu()?)), DType::I32 => Ok(CpuStorage::I32(self.to_cpu()?)), DType::I64 => Ok(CpuStorage::I64(self.to_cpu()?)), DType::F16 => Ok(CpuStorage::F16(self.to_cpu()?)), @@ -305,6 +306,11 @@ impl BackendStorage for MetalStorage { (ReduceOp::Max, DType::BF16) => ("fast_max_bf16_strided", true, false), (ReduceOp::ArgMin, DType::BF16) => ("fast_argmin_bf16_strided", true, true), (ReduceOp::ArgMax, DType::BF16) => ("fast_argmax_bf16_strided", true, true), + (ReduceOp::Sum, DType::I16) => ("fast_sum_i16_strided", false, false), + (ReduceOp::Min, DType::I16) => ("fast_min_i16_strided", true, false), + (ReduceOp::Max, DType::I16) => ("fast_max_i16_strided", true, false), + (ReduceOp::ArgMin, DType::I16) => ("fast_argmin_i16_strided", true, true), + (ReduceOp::ArgMax, DType::I16) => ("fast_argmax_i16_strided", true, true), (ReduceOp::Sum, DType::I32) => ("fast_sum_i32_strided", false, false), (ReduceOp::Min, DType::I32) => ("fast_min_i32_strided", true, false), (ReduceOp::Max, DType::I32) => ("fast_max_i32_strided", true, false), @@ -369,6 +375,7 @@ impl BackendStorage for MetalStorage { (DType::U32, DType::BF16) => "cast_u32_bf16", (DType::U32, DType::F16) => "cast_u32_f16", (DType::U32, DType::F32) => "cast_u32_f32", + (DType::U32, DType::I16) => "cast_u32_i16", (DType::U32, DType::I32) => "cast_u32_i32", (DType::U32, DType::I64) => "cast_u32_i64", (DType::U32, DType::U8) => "cast_u32_u8", @@ -376,17 +383,25 @@ impl 
BackendStorage for MetalStorage { (DType::U8, DType::BF16) => "cast_u8_bf16", (DType::U8, DType::F16) => "cast_u8_f16", (DType::U8, DType::F32) => "cast_u8_f32", + (DType::U8, DType::I16) => "cast_u8_i16", (DType::U8, DType::I32) => "cast_u8_i32", (DType::U8, DType::I64) => "cast_u8_i64", (DType::U8, DType::U32) => "cast_u8_u32", (DType::F32, DType::BF16) => "cast_f32_bf16", (DType::F32, DType::F16) => "cast_f32_f16", + (DType::F32, DType::I16) => "cast_f32_i16", (DType::F32, DType::I32) => "cast_f32_i32", (DType::F32, DType::I64) => "cast_f32_i64", (DType::F32, DType::U32) => "cast_f32_u32", (DType::F32, DType::U8) => "cast_f32_u8", + (DType::I16, DType::BF16) => "cast_i16_bf16", + (DType::I16, DType::F16) => "cast_i16_f16", + (DType::I16, DType::F32) => "cast_i16_f32", + (DType::I16, DType::U32) => "cast_i16_u32", + (DType::I16, DType::U8) => "cast_i16_u8", + (DType::I32, DType::BF16) => "cast_i32_bf16", (DType::I32, DType::F16) => "cast_i32_f16", (DType::I32, DType::F32) => "cast_i32_f32", @@ -401,6 +416,7 @@ impl BackendStorage for MetalStorage { (DType::F16, DType::BF16) => "cast_f16_bf16", (DType::F16, DType::F32) => "cast_f16_f32", + (DType::F16, DType::I16) => "cast_f16_i16", (DType::F16, DType::I32) => "cast_f16_i32", (DType::F16, DType::I64) => "cast_f16_i64", (DType::F16, DType::U32) => "cast_f16_u32", @@ -408,6 +424,7 @@ impl BackendStorage for MetalStorage { (DType::BF16, DType::F16) => "cast_bf16_f16", (DType::BF16, DType::F32) => "cast_bf16_f32", + (DType::BF16, DType::I16) => "cast_bf16_i16", (DType::BF16, DType::I32) => "cast_bf16_i32", (DType::BF16, DType::I64) => "cast_bf16_i64", (DType::BF16, DType::U32) => "cast_bf16_u32", @@ -431,14 +448,17 @@ impl BackendStorage for MetalStorage { let kernel_name = match (self.dtype, dtype) { (DType::U32, DType::F32) => "cast_u32_f32_strided", (DType::U32, DType::U8) => "cast_u32_u8_strided", + (DType::U32, DType::I16) => "cast_u32_i16_strided", (DType::U32, DType::I32) => "cast_u32_i32_strided", (DType::U32, DType::I64) => "cast_u32_i64_strided", (DType::U8, DType::U32) => "cast_u8_u32_strided", (DType::U8, DType::F32) => "cast_u8_f32_strided", + (DType::U8, DType::I16) => "cast_u8_i16_strided", (DType::U8, DType::I32) => "cast_u8_i32_strided", (DType::U8, DType::I64) => "cast_u8_i64_strided", (DType::F32, DType::F16) => "cast_f32_f16_strided", (DType::F16, DType::F32) => "cast_f16_f32_strided", + (DType::I16, DType::F32) => "cast_i16_f32_strided", (DType::I32, DType::F32) => "cast_i32_f32_strided", (DType::I64, DType::F32) => "cast_i64_f32_strided", (DType::F32, DType::BF16) => "cast_f32_bf16_strided", @@ -534,6 +554,7 @@ impl BackendStorage for MetalStorage { ("usign", DType::F16) => contiguous_tiled::sign::HALF, ("usign", DType::F32) => contiguous_tiled::sign::FLOAT, ("usign", DType::BF16) => contiguous_tiled::sign::BFLOAT, + ("usign", DType::I16) => contiguous_tiled::sign::I16, ("usign", DType::I32) => contiguous_tiled::sign::I32, ("usign", DType::I64) => contiguous_tiled::sign::I64, (name, dtype) => { @@ -613,6 +634,7 @@ impl BackendStorage for MetalStorage { ("usign", DType::F16) => contiguous::sign::HALF, ("usign", DType::F32) => contiguous::sign::FLOAT, ("usign", DType::BF16) => contiguous::sign::BFLOAT, + ("usign", DType::I16) => contiguous::sign::I16, ("usign", DType::I32) => contiguous::sign::I32, ("usign", DType::I64) => contiguous::sign::I64, (name, dtype) => { @@ -745,6 +767,7 @@ impl BackendStorage for MetalStorage { (DType::U32, DType::F32) => "where_u32_f32", (DType::U8, DType::BF16) => "where_u8_bf16", (DType::U8, 
DType::F16) => "where_u8_f16", + (DType::U8, DType::I16) => "where_u8_i16", (DType::U8, DType::I32) => "where_u8_i32", (DType::U8, DType::I64) => "where_u8_i64", (DType::U8, DType::U32) => "where_u8_u32", @@ -1283,6 +1306,9 @@ impl BackendStorage for MetalStorage { (DType::U32, DType::F32) => "sa_u32_f32", (DType::U32, DType::F16) => "sa_u32_f16", (DType::U32, DType::BF16) => "sa_u32_bf16", + (DType::I16, DType::F32) => "sa_i16_f32", + (DType::I16, DType::F16) => "sa_i16_f16", + (DType::I16, DType::BF16) => "sa_i16_bf16", (DType::I32, DType::F32) => "sa_i32_f32", (DType::I32, DType::F16) => "sa_i32_f16", (DType::I32, DType::BF16) => "sa_i32_bf16", @@ -1334,6 +1360,10 @@ impl BackendStorage for MetalStorage { (DType::U32, DType::F16) => "is_u32_f16", (DType::U32, DType::BF16) => "is_u32_bf16", + (DType::I16, DType::F32) => "is_i16_f32", + (DType::I16, DType::F16) => "is_i16_f16", + (DType::I16, DType::BF16) => "is_i16_bf16", + (DType::I32, DType::F32) => "is_i32_f32", (DType::I32, DType::F16) => "is_i32_f16", (DType::I32, DType::BF16) => "is_i32_bf16", @@ -1383,6 +1413,14 @@ impl BackendStorage for MetalStorage { return Err(crate::Error::RequiresContiguous { op: "index-add" }.bt()); }; let name = match (ids.dtype, self.dtype) { + (DType::I16, DType::BF16) => "ia_i16_bf16", + (DType::I16, DType::F16) => "ia_i16_f16", + (DType::I16, DType::F32) => "ia_i16_f32", + (DType::I16, DType::I32) => "ia_i16_i32", + (DType::I16, DType::I64) => "ia_i16_i64", + (DType::I16, DType::U32) => "ia_i16_u32", + (DType::I16, DType::U8) => "ia_i16_u8", + (DType::I32, DType::BF16) => "ia_i32_bf16", (DType::I32, DType::F16) => "ia_i32_f16", (DType::I32, DType::F32) => "ia_i32_f32", @@ -1394,6 +1432,7 @@ impl BackendStorage for MetalStorage { (DType::I64, DType::BF16) => "ia_i64_bf16", (DType::I64, DType::F16) => "ia_i64_f16", (DType::I64, DType::F32) => "ia_i64_f32", + (DType::I64, DType::I16) => "ia_i64_i16", (DType::I64, DType::I32) => "ia_i64_i32", (DType::I64, DType::I64) => "ia_i64_i64", (DType::I64, DType::U32) => "ia_i64_u32", @@ -1402,6 +1441,7 @@ impl BackendStorage for MetalStorage { (DType::U32, DType::BF16) => "ia_u32_bf16", (DType::U32, DType::F16) => "ia_u32_f16", (DType::U32, DType::F32) => "ia_u32_f32", + (DType::U32, DType::I16) => "ia_u32_i16", (DType::U32, DType::I32) => "ia_u32_i32", (DType::U32, DType::I64) => "ia_u32_i64", (DType::U32, DType::U32) => "ia_u32_u32", @@ -1410,6 +1450,7 @@ impl BackendStorage for MetalStorage { (DType::U8, DType::BF16) => "ia_u8_bf16", (DType::U8, DType::F16) => "ia_u8_f16", (DType::U8, DType::F32) => "ia_u8_f32", + (DType::U8, DType::I16) => "ia_u8_i16", (DType::U8, DType::I32) => "ia_u8_i32", (DType::U8, DType::I64) => "ia_u8_i64", (DType::U8, DType::U32) => "ia_u8_u32", @@ -1577,6 +1618,7 @@ impl BackendStorage for MetalStorage { DType::F32 => candle_metal_kernels::copy2d::FLOAT, DType::F16 => candle_metal_kernels::copy2d::HALF, DType::BF16 => candle_metal_kernels::copy2d::BFLOAT, + DType::I16 => candle_metal_kernels::copy2d::I16, DType::I32 => candle_metal_kernels::copy2d::I32, DType::I64 => candle_metal_kernels::copy2d::I64, DType::U32 => candle_metal_kernels::copy2d::U32, @@ -1624,6 +1666,7 @@ impl BackendStorage for MetalStorage { DType::F32 => candle_metal_kernels::unary::strided::copy::FLOAT, DType::F16 => candle_metal_kernels::unary::strided::copy::HALF, DType::BF16 => candle_metal_kernels::unary::strided::copy::BFLOAT, + DType::I16 => candle_metal_kernels::unary::strided::copy::I16, DType::I32 => candle_metal_kernels::unary::strided::copy::I32, 
DType::I64 => candle_metal_kernels::unary::strided::copy::I64, DType::U32 => candle_metal_kernels::unary::strided::copy::U32, @@ -1716,6 +1759,17 @@ impl MetalStorage { ("ge", DType::BF16) => (contiguous::ge::BFLOAT, DType::U8), ("gt", DType::BF16) => (contiguous::gt::BFLOAT, DType::U8), + ("add", DType::I16) => (contiguous::add::I16, self.dtype), + ("sub", DType::I16) => (contiguous::sub::I16, self.dtype), + ("mul", DType::I16) => (contiguous::mul::I16, self.dtype), + ("div", DType::I16) => (contiguous::div::I16, self.dtype), + ("eq", DType::I16) => (contiguous::eq::I16, DType::U8), + ("ne", DType::I16) => (contiguous::ne::I16, DType::U8), + ("le", DType::I16) => (contiguous::le::I16, DType::U8), + ("lt", DType::I16) => (contiguous::lt::I16, DType::U8), + ("ge", DType::I16) => (contiguous::ge::I16, DType::U8), + ("gt", DType::I16) => (contiguous::gt::I16, DType::U8), + ("add", DType::I32) => (contiguous::add::I32, self.dtype), ("sub", DType::I32) => (contiguous::sub::I32, self.dtype), ("mul", DType::I32) => (contiguous::mul::I32, self.dtype), @@ -1820,6 +1874,19 @@ impl MetalStorage { ("ge", DType::BF16) => (strided::ge::BFLOAT, DType::U8), ("gt", DType::BF16) => (strided::gt::BFLOAT, DType::U8), + ("badd", DType::I16) => (strided::add::I16, self.dtype), + ("bsub", DType::I16) => (strided::sub::I16, self.dtype), + ("bmul", DType::I16) => (strided::mul::I16, self.dtype), + ("bdiv", DType::I16) => (strided::div::I16, self.dtype), + ("bminimum", DType::I16) => (strided::min::I16, self.dtype), + ("bmaximum", DType::I16) => (strided::max::I16, self.dtype), + ("eq", DType::I16) => (strided::eq::I16, DType::U8), + ("ne", DType::I16) => (strided::ne::I16, DType::U8), + ("le", DType::I16) => (strided::le::I16, DType::U8), + ("lt", DType::I16) => (strided::lt::I16, DType::U8), + ("ge", DType::I16) => (strided::ge::I16, DType::U8), + ("gt", DType::I16) => (strided::gt::I16, DType::U8), + ("badd", DType::I32) => (strided::add::I32, self.dtype), ("bsub", DType::I32) => (strided::sub::I32, self.dtype), ("bmul", DType::I32) => (strided::mul::I32, self.dtype), @@ -1989,6 +2056,7 @@ impl BackendDevice for MetalDevice { let (count, buffer) = match T::cpu_storage_ref(s) { CpuStorageRef::U8(storage) => (storage.len(), self.new_buffer_with_data(storage)), CpuStorageRef::U32(storage) => (storage.len(), self.new_buffer_with_data(storage)), + CpuStorageRef::I16(storage) => (storage.len(), self.new_buffer_with_data(storage)), CpuStorageRef::I32(storage) => (storage.len(), self.new_buffer_with_data(storage)), CpuStorageRef::I64(storage) => (storage.len(), self.new_buffer_with_data(storage)), CpuStorageRef::BF16(storage) => (storage.len(), self.new_buffer_with_data(storage)), @@ -2003,6 +2071,7 @@ impl BackendDevice for MetalDevice { let (count, buffer) = match storage { CpuStorage::U8(storage) => (storage.len(), self.new_buffer_with_data(storage)), CpuStorage::U32(storage) => (storage.len(), self.new_buffer_with_data(storage)), + CpuStorage::I16(storage) => (storage.len(), self.new_buffer_with_data(storage)), CpuStorage::I32(storage) => (storage.len(), self.new_buffer_with_data(storage)), CpuStorage::I64(storage) => (storage.len(), self.new_buffer_with_data(storage)), CpuStorage::BF16(storage) => (storage.len(), self.new_buffer_with_data(storage)), diff --git a/candle-core/src/npy.rs b/candle-core/src/npy.rs index b321a619f8..33a4f4c728 100644 --- a/candle-core/src/npy.rs +++ b/candle-core/src/npy.rs @@ -85,6 +85,7 @@ impl Header { DType::F16 => "f2", DType::F32 => "f4", DType::F64 => "f8", + DType::I16 => "i2", 
 DType::I32 => "i4",
 DType::I64 => "i8",
 DType::U32 => "u4",
@@ -235,6 +236,11 @@ impl Tensor {
                 reader.read_u32_into::<LittleEndian>(&mut data_t)?;
                 Tensor::from_vec(data_t, shape, &Device::Cpu)
             }
+            DType::I16 => {
+                let mut data_t = vec![0i16; elem_count];
+                reader.read_i16_into::<LittleEndian>(&mut data_t)?;
+                Tensor::from_vec(data_t, shape, &Device::Cpu)
+            }
             DType::I32 => {
                 let mut data_t = vec![0i32; elem_count];
                 reader.read_i32_into::<LittleEndian>(&mut data_t)?;
diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs
index 75931ee2fe..3786a82aaf 100644
--- a/candle-core/src/op.rs
+++ b/candle-core/src/op.rs
@@ -189,6 +189,7 @@ pub trait UnaryOpT {
     fn f64(v1: f64) -> f64;
     fn u8(v1: u8) -> u8;
     fn u32(v1: u32) -> u32;
+    fn i16(v1: i16) -> i16;
     fn i32(v1: i32) -> i32;
     fn i64(v1: i64) -> i64;
@@ -214,6 +215,7 @@ pub trait BinaryOpT {
     fn f64(v1: f64, v2: f64) -> f64;
     fn u8(v1: u8, v2: u8) -> u8;
     fn u32(v1: u32, v2: u32) -> u32;
+    fn i16(v1: i16, v2: i16) -> i16;
     fn i32(v1: i32, v2: i32) -> i32;
     fn i64(v1: i64, v2: i64) -> i64;
@@ -233,6 +235,8 @@ pub trait BinaryOpT {
     fn i64_vec(_xs1: &[i64], _xs2: &[i64], _ys: &mut [i64]) {}
     const I32_VEC: bool = false;
     fn i32_vec(_xs1: &[i32], _xs2: &[i32], _ys: &mut [i32]) {}
+    const I16_VEC: bool = false;
+    fn i16_vec(_xs1: &[i16], _xs2: &[i16], _ys: &mut [i16]) {}
 }

 pub(crate) struct Add;
@@ -292,6 +296,10 @@ macro_rules! bin_op {
                 $e(v1, v2)
             }
             #[inline(always)]
+            fn i16(v1: i16, v2: i16) -> i16 {
+                $e(v1, v2)
+            }
+            #[inline(always)]
             fn i32(v1: i32, v2: i32) -> i32 {
                 $e(v1, v2)
             }
@@ -391,6 +399,10 @@ macro_rules! unary_op {
             fn i32(_: i32) -> i32 {
                 todo!("no unary function for i32")
             }
+            #[inline(always)]
+            fn i16(_: i16) -> i16 {
+                todo!("no unary function for i16")
+            }
         }
     };
@@ -431,6 +443,10 @@ macro_rules! unary_op {
             fn i32(_: i32) -> i32 {
                 todo!("no unary function for i32")
             }
+            #[inline(always)]
+            fn i16(_: i16) -> i16 {
+                todo!("no unary function for i16")
+            }

             #[cfg(feature = "mkl")]
             const F32_VEC: bool = true;
@@ -534,6 +550,10 @@ impl UnaryOpT for Gelu {
     fn i32(_: i32) -> i32 {
         0
     }
+    #[inline(always)]
+    fn i16(_: i16) -> i16 {
+        0
+    }
     const KERNEL: &'static str = "ugelu";

     #[cfg(feature = "mkl")]
@@ -611,6 +631,10 @@ impl UnaryOpT for Erf {
     fn i32(_: i32) -> i32 {
         0
     }
+    #[inline(always)]
+    fn i16(_: i16) -> i16 {
+        0
+    }
 }

 /// Silu operation
@@ -649,6 +673,10 @@ impl UnaryOpT for Silu {
     fn i32(_: i32) -> i32 {
         0
     }
+    #[inline(always)]
+    fn i16(_: i16) -> i16 {
+        0
+    }
     const KERNEL: &'static str = "usilu";

     #[cfg(feature = "mkl")]
@@ -724,6 +752,10 @@ impl UnaryOpT for Abs {
     fn i32(v: i32) -> i32 {
         v.abs()
     }
+    #[inline(always)]
+    fn i16(v: i16) -> i16 {
+        v.abs()
+    }
 }

 impl UnaryOpT for Ceil {
@@ -762,6 +794,10 @@ impl UnaryOpT for Ceil {
     fn i32(v: i32) -> i32 {
         v
     }
+    #[inline(always)]
+    fn i16(v: i16) -> i16 {
+        v
+    }
 }

 impl UnaryOpT for Floor {
@@ -800,6 +836,10 @@ impl UnaryOpT for Floor {
     fn i32(v: i32) -> i32 {
         v
     }
+    #[inline(always)]
+    fn i16(v: i16) -> i16 {
+        v
+    }
 }

 impl UnaryOpT for Round {
@@ -838,6 +878,10 @@ impl UnaryOpT for Round {
     fn i32(v: i32) -> i32 {
         v
     }
+    #[inline(always)]
+    fn i16(v: i16) -> i16 {
+        v
+    }
 }

 impl UnaryOpT for GeluErf {
@@ -876,6 +920,10 @@ impl UnaryOpT for GeluErf {
     fn i32(_: i32) -> i32 {
         0
     }
+    #[inline(always)]
+    fn i16(_: i16) -> i16 {
+        0
+    }
 }

 impl UnaryOpT for Relu {
@@ -914,6 +962,10 @@ impl UnaryOpT for Relu {
     fn i32(v: i32) -> i32 {
         v
     }
+    #[inline(always)]
+    fn i16(v: i16) -> i16 {
+        v
+    }
 }

 /// `BackpropOp` is a wrapper around `Option<Op>`.
 /// The main goal is to ensure that dependencies are
@@ -1016,4 +1068,8 @@ impl UnaryOpT for Sign {
     fn i32(v: i32) -> i32 {
         (v > 0) as i32 - (v < 0) as i32
     }
+    #[inline(always)]
+    fn i16(v: i16) -> i16 {
+        (v > 0) as i16 - (v < 0) as i16
+    }
 }
diff --git a/candle-core/src/safetensors.rs b/candle-core/src/safetensors.rs
index 162928ec7d..12436a0903 100644
--- a/candle-core/src/safetensors.rs
+++ b/candle-core/src/safetensors.rs
@@ -11,6 +11,7 @@ impl From<DType> for st::Dtype {
         DType::U8 => st::Dtype::U8,
         DType::U32 => st::Dtype::U32,
         DType::I64 => st::Dtype::I64,
+        DType::I16 => st::Dtype::I16,
         DType::I32 => st::Dtype::I32,
         DType::BF16 => st::Dtype::BF16,
         DType::F16 => st::Dtype::F16,
@@ -188,6 +189,7 @@ impl Tensor {
         match dtype {
             DType::U8 => convert_slice::<u8>(data, shape, device),
             DType::U32 => convert_slice::<u32>(data, shape, device),
+            DType::I16 => convert_slice::<i16>(data, shape, device),
             DType::I32 => convert_slice::<i32>(data, shape, device),
             DType::I64 => convert_slice::<i64>(data, shape, device),
             DType::BF16 => convert_slice::<half::bf16>(data, shape, device),
@@ -206,6 +208,7 @@ fn convert(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
             convert_with_cast_::<u16, u32, _>(view, device, conv)
         }
         st::Dtype::U32 => convert_::<u32>(view, device),
+        st::Dtype::I16 => convert_::<i16>(view, device),
         st::Dtype::I32 => convert_::<i32>(view, device),
         st::Dtype::I64 => convert_::<i64>(view, device),
         st::Dtype::BF16 => convert_::<half::bf16>(view, device),
@@ -222,6 +225,7 @@ fn convert_back(tensor: &Tensor) -> Result<Vec<u8>> {
     match tensor.dtype() {
         DType::U8 => Ok(convert_back_::<u8>(tensor.to_vec1()?)),
         DType::U32 => Ok(convert_back_::<u32>(tensor.to_vec1()?)),
+        DType::I16 => Ok(convert_back_::<i16>(tensor.to_vec1()?)),
         DType::I32 => Ok(convert_back_::<i32>(tensor.to_vec1()?)),
         DType::I64 => Ok(convert_back_::<i64>(tensor.to_vec1()?)),
         DType::F16 => Ok(convert_back_::<half::f16>(tensor.to_vec1()?)),
diff --git a/candle-core/src/sort.rs b/candle-core/src/sort.rs
index 9d9fd59634..14e3417138 100644
--- a/candle-core/src/sort.rs
+++ b/candle-core/src/sort.rs
@@ -65,6 +65,7 @@ impl crate::CustomOp1 for ArgSort {
         let sort_indexes = match storage {
             crate::CpuStorage::U8(vs) => self.asort(vs, layout),
             crate::CpuStorage::U32(vs) => self.asort(vs, layout),
+            crate::CpuStorage::I16(vs) => self.asort(vs, layout),
             crate::CpuStorage::I32(vs) => self.asort(vs, layout),
             crate::CpuStorage::I64(vs) => self.asort(vs, layout),
             crate::CpuStorage::BF16(vs) => self.asort(vs, layout),
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs
index bff8f36042..ede9f3c708 100644
--- a/candle-core/tests/tensor_tests.rs
+++ b/candle-core/tests/tensor_tests.rs
@@ -17,6 +17,10 @@ fn ones(device: &Device) -> Result<()> {
         Tensor::ones((2, 3), DType::U32, device)?.to_vec2::<u32>()?,
         [[1, 1, 1], [1, 1, 1]],
     );
+    assert_eq!(
+        Tensor::ones((2, 3), DType::I16, device)?.to_vec2::<i16>()?,
+        [[1, 1, 1], [1, 1, 1]],
+    );
     assert_eq!(
         Tensor::ones((2, 3), DType::I32, device)?.to_vec2::<i32>()?,
         [[1, 1, 1], [1, 1, 1]],
     );
@@ -809,7 +813,7 @@ fn index_select(device: &Device) -> Result<()> {
             [9.0, 10.0, 11.0]
         ]
     );
-    for dtype in [DType::U8, DType::U32, DType::I32, DType::I64] {
+    for dtype in [DType::U8, DType::U32, DType::I16, DType::I32, DType::I64] {
         let ids = ids.to_dtype(dtype)?;
         let hs = t.index_select(&ids, 1)?;
         assert_eq!(
diff --git a/candle-kernels/src/affine.cu b/candle-kernels/src/affine.cu
index c3ff5b8753..301bcd5a64 100644
--- a/candle-kernels/src/affine.cu
+++ b/candle-kernels/src/affine.cu
@@ -40,5 +40,6 @@ AFFINE_OP(float, affine_f32)
 AFFINE_OP(double, affine_f64)
 AFFINE_OP(uint8_t, affine_u8)
 AFFINE_OP(uint32_t, affine_u32)
+AFFINE_OP(int16_t, affine_i16) AFFINE_OP(int32_t, affine_i32) AFFINE_OP(int64_t, affine_i64) diff --git a/candle-kernels/src/binary.cu b/candle-kernels/src/binary.cu index f534fc76ad..99ab23b875 100644 --- a/candle-kernels/src/binary.cu +++ b/candle-kernels/src/binary.cu @@ -35,36 +35,42 @@ BINARY_OP(float, badd_f32, x + y) BINARY_OP(double, badd_f64, x + y); BINARY_OP(uint8_t, badd_u8, x + y); BINARY_OP(uint32_t, badd_u32, x + y); +BINARY_OP(int16_t, badd_i16, x + y); BINARY_OP(int32_t, badd_i32, x + y); BINARY_OP(int64_t, badd_i64, x + y); BINARY_OP(float, bdiv_f32, x / y) BINARY_OP(double, bdiv_f64, x / y); BINARY_OP(uint8_t, bdiv_u8, x / y); BINARY_OP(uint32_t, bdiv_u32, x / y); +BINARY_OP(int16_t, bdiv_i16, x / y); BINARY_OP(int32_t, bdiv_i32, x / y); BINARY_OP(int64_t, bdiv_i64, x / y); BINARY_OP(float, bmul_f32, x * y) BINARY_OP(double, bmul_f64, x * y); BINARY_OP(uint8_t, bmul_u8, x * y); BINARY_OP(uint32_t, bmul_u32, x * y); +BINARY_OP(int16_t, bmul_i16, x * y); BINARY_OP(int32_t, bmul_i32, x * y); BINARY_OP(int64_t, bmul_i64, x * y); BINARY_OP(float, bsub_f32, x - y) BINARY_OP(double, bsub_f64, x - y); BINARY_OP(uint8_t, bsub_u8, x - y); BINARY_OP(uint32_t, bsub_u32, x - y); +BINARY_OP(int16_t, bsub_i16, x - y); BINARY_OP(int32_t, bsub_i32, x - y); BINARY_OP(int64_t, bsub_i64, x - y); BINARY_OP(float, bminimum_f32, ming(x, y)); BINARY_OP(double, bminimum_f64, ming(x, y)); BINARY_OP(uint8_t, bminimum_u8, ming(x, y)); BINARY_OP(uint32_t, bminimum_u32, ming(x, y)); +BINARY_OP(int16_t, bminimum_i16, ming(x, y)); BINARY_OP(int32_t, bminimum_i32, ming(x, y)); BINARY_OP(int64_t, bminimum_i64, ming(x, y)); BINARY_OP(float, bmaximum_f32, maxg(x, y)); BINARY_OP(double, bmaximum_f64, maxg(x, y)); BINARY_OP(uint8_t, bmaximum_u8, maxg(x, y)); BINARY_OP(uint32_t, bmaximum_u32, maxg(x, y)); +BINARY_OP(int16_t, bmaximum_i16, maxg(x, y)); BINARY_OP(int32_t, bmaximum_i32, maxg(x, y)); BINARY_OP(int64_t, bmaximum_i64, maxg(x, y)); @@ -72,6 +78,7 @@ BINARY_OP_OUT(float, uint8_t, eq_f32, x == y) BINARY_OP_OUT(double, uint8_t, eq_f64, x == y) BINARY_OP_OUT(uint8_t, uint8_t, eq_u8, x == y) BINARY_OP_OUT(uint32_t, uint8_t, eq_u32, x == y) +BINARY_OP_OUT(int16_t, uint8_t, eq_i16, x == y) BINARY_OP_OUT(int32_t, uint8_t, eq_i32, x == y) BINARY_OP_OUT(int64_t, uint8_t, eq_i64, x == y) @@ -79,6 +86,7 @@ BINARY_OP_OUT(float, uint8_t, ne_f32, x != y) BINARY_OP_OUT(double, uint8_t, ne_f64, x != y) BINARY_OP_OUT(uint8_t, uint8_t, ne_u8, x != y) BINARY_OP_OUT(uint32_t, uint8_t, ne_u32, x != y) +BINARY_OP_OUT(int16_t, uint8_t, ne_i16, x != y) BINARY_OP_OUT(int32_t, uint8_t, ne_i32, x != y) BINARY_OP_OUT(int64_t, uint8_t, ne_i64, x != y) @@ -86,6 +94,7 @@ BINARY_OP_OUT(float, uint8_t, lt_f32, x < y) BINARY_OP_OUT(double, uint8_t, lt_f64, x < y) BINARY_OP_OUT(uint8_t, uint8_t, lt_u8, x < y) BINARY_OP_OUT(uint32_t, uint8_t, lt_u32, x < y) +BINARY_OP_OUT(int16_t, uint8_t, lt_i16, x < y) BINARY_OP_OUT(int32_t, uint8_t, lt_i32, x < y) BINARY_OP_OUT(int64_t, uint8_t, lt_i64, x < y) @@ -93,6 +102,7 @@ BINARY_OP_OUT(float, uint8_t, le_f32, x <= y) BINARY_OP_OUT(double, uint8_t, le_f64, x <= y) BINARY_OP_OUT(uint8_t, uint8_t, le_u8, x <= y) BINARY_OP_OUT(uint32_t, uint8_t, le_u32, x <= y) +BINARY_OP_OUT(int16_t, uint8_t, le_i16, x <= y) BINARY_OP_OUT(int32_t, uint8_t, le_i32, x <= y) BINARY_OP_OUT(int64_t, uint8_t, le_i64, x <= y) @@ -100,6 +110,7 @@ BINARY_OP_OUT(float, uint8_t, gt_f32, x > y) BINARY_OP_OUT(double, uint8_t, gt_f64, x > y) BINARY_OP_OUT(uint8_t, uint8_t, gt_u8, x > y) BINARY_OP_OUT(uint32_t, uint8_t, 
gt_u32, x > y) +BINARY_OP_OUT(int16_t, uint8_t, gt_i16, x > y) BINARY_OP_OUT(int32_t, uint8_t, gt_i32, x > y) BINARY_OP_OUT(int64_t, uint8_t, gt_i64, x > y) @@ -107,5 +118,6 @@ BINARY_OP_OUT(float, uint8_t, ge_f32, x >= y) BINARY_OP_OUT(double, uint8_t, ge_f64, x >= y) BINARY_OP_OUT(uint8_t, uint8_t, ge_u8, x >= y) BINARY_OP_OUT(uint32_t, uint8_t, ge_u32, x >= y) +BINARY_OP_OUT(int16_t, uint8_t, ge_i16, x >= y) BINARY_OP_OUT(int32_t, uint8_t, ge_i32, x >= y) BINARY_OP_OUT(int64_t, uint8_t, ge_i64, x >= y) diff --git a/candle-kernels/src/cast.cu b/candle-kernels/src/cast.cu index f92ac0cbf9..e288bf1812 100644 --- a/candle-kernels/src/cast.cu +++ b/candle-kernels/src/cast.cu @@ -120,11 +120,13 @@ CAST_OP(uint32_t, uint32_t, cast_u32_u32) CAST_OP(uint32_t, uint8_t, cast_u32_u8 ) CAST_OP(uint32_t, int64_t, cast_u32_i64 ) CAST_OP(uint32_t, int32_t, cast_u32_i32 ) +CAST_OP(uint32_t, int16_t, cast_u32_i16 ) CAST_OP(uint32_t, float, cast_u32_f32) CAST_OP(uint32_t, double, cast_u32_f64) CAST_OP(uint8_t, uint32_t, cast_u8_u32) CAST_OP(uint8_t, uint8_t, cast_u8_u8 ) +CAST_OP(uint8_t, int16_t, cast_u8_i16 ) CAST_OP(uint8_t, int32_t, cast_u8_i32 ) CAST_OP(uint8_t, int64_t, cast_u8_i64 ) CAST_OP(uint8_t, float, cast_u8_f32) @@ -132,6 +134,7 @@ CAST_OP(uint8_t, double, cast_u8_f64) CAST_OP(int64_t, uint32_t, cast_i64_u32) CAST_OP(int64_t, uint8_t, cast_i64_u8 ) +CAST_OP(int64_t, int16_t, cast_i64_i16 ) CAST_OP(int64_t, int32_t, cast_i64_i32 ) CAST_OP(int64_t, int64_t, cast_i64_i64 ) CAST_OP(int64_t, float, cast_i64_f32) @@ -141,11 +144,21 @@ CAST_OP(int32_t, uint32_t, cast_i32_u32) CAST_OP(int32_t, uint8_t, cast_i32_u8 ) CAST_OP(int32_t, int64_t, cast_i32_i64 ) CAST_OP(int32_t, int32_t, cast_i32_i32 ) +CAST_OP(int32_t, int16_t, cast_i32_i16 ) CAST_OP(int32_t, float, cast_i32_f32) CAST_OP(int32_t, double, cast_i32_f64) +CAST_OP(int16_t, uint32_t, cast_i16_u32) +CAST_OP(int16_t, uint8_t, cast_i16_u8 ) +CAST_OP(int16_t, int64_t, cast_i16_i64 ) +CAST_OP(int16_t, int32_t, cast_i16_i32 ) +CAST_OP(int16_t, int16_t, cast_i16_i16 ) +CAST_OP(int16_t, float, cast_i16_f32) +CAST_OP(int16_t, double, cast_i16_f64) + CAST_OP(float, uint8_t, cast_f32_u8 ) CAST_OP(float, uint32_t, cast_f32_u32) +CAST_OP(float, int16_t, cast_f32_i16 ) CAST_OP(float, int32_t, cast_f32_i32 ) CAST_OP(float, int64_t, cast_f32_i64 ) CAST_OP(float, float, cast_f32_f32) @@ -153,6 +166,7 @@ CAST_OP(float, double, cast_f32_f64) CAST_OP(double, uint8_t, cast_f64_u8 ) CAST_OP(double, uint32_t, cast_f64_u32) +CAST_OP(double, int16_t, cast_f64_i16 ) CAST_OP(double, int32_t, cast_f64_i32 ) CAST_OP(double, int64_t, cast_f64_i64 ) CAST_OP(double, float, cast_f64_f32) diff --git a/candle-kernels/src/cuda_utils.cuh b/candle-kernels/src/cuda_utils.cuh index 08aa2b089a..df1497f672 100644 --- a/candle-kernels/src/cuda_utils.cuh +++ b/candle-kernels/src/cuda_utils.cuh @@ -181,6 +181,8 @@ __device__ __forceinline__ double absg(double a) { return fabs(a); } __device__ __forceinline__ float copysigng(float a, float b) { return copysignf(a, b); } __device__ __forceinline__ double copysigng(double a, double b) { return copysign(a, b); } +__device__ __forceinline__ int16_t ming(int16_t a, int16_t b) { return min(a, b); } +__device__ __forceinline__ int16_t maxg(int16_t a, int16_t b) { return max(a, b); } __device__ __forceinline__ int32_t ming(int32_t a, int32_t b) { return min(a, b); } __device__ __forceinline__ int32_t maxg(int32_t a, int32_t b) { return max(a, b); } __device__ __forceinline__ int64_t ming(int64_t a, int64_t b) { return min(a, b); } diff --git 
a/candle-kernels/src/fill.cu b/candle-kernels/src/fill.cu index 42bfddfd9f..0654c2631b 100644 --- a/candle-kernels/src/fill.cu +++ b/candle-kernels/src/fill.cu @@ -9,6 +9,7 @@ __device__ void fill_with(T *buf, T value, const size_t numel) { } extern "C" __global__ void fill_u8(uint8_t *buf, uint8_t value, const size_t numel) { fill_with(buf, value, numel); } extern "C" __global__ void fill_u32(uint32_t *buf, uint32_t value, const size_t numel) { fill_with(buf, value, numel); } +extern "C" __global__ void fill_i16(int16_t *buf, int16_t value, const size_t numel) { fill_with(buf, value, numel); } extern "C" __global__ void fill_i32(int32_t *buf, int32_t value, const size_t numel) { fill_with(buf, value, numel); } extern "C" __global__ void fill_i64(int64_t *buf, int64_t value, const size_t numel) { fill_with(buf, value, numel); } extern "C" __global__ void fill_f32(float *buf, float value, const size_t numel) { fill_with(buf, value, numel); } @@ -35,6 +36,7 @@ COPY2D_OP(float, copy2d_f32) COPY2D_OP(double, copy2d_f64) COPY2D_OP(uint8_t, copy2d_u8) COPY2D_OP(uint32_t, copy2d_u32) +COPY2D_OP(int16_t, copy2d_i16) COPY2D_OP(int32_t, copy2d_i32) COPY2D_OP(int64_t, copy2d_i64) diff --git a/candle-kernels/src/indexing.cu b/candle-kernels/src/indexing.cu index 2f3df4de1b..df0e3a071d 100644 --- a/candle-kernels/src/indexing.cu +++ b/candle-kernels/src/indexing.cu @@ -147,18 +147,22 @@ extern "C" __global__ void FN_NAME( \ #if __CUDA_ARCH__ >= 800 +IS_OP(__nv_bfloat16, int16_t, is_i16_bf16) IS_OP(__nv_bfloat16, int32_t, is_i32_bf16) IS_OP(__nv_bfloat16, int64_t, is_i64_bf16) IS_OP(__nv_bfloat16, uint32_t, is_u32_bf16) IS_OP(__nv_bfloat16, uint8_t, is_u8_bf16) +GATHER_OP(__nv_bfloat16, int16_t, gather_i16_bf16) GATHER_OP(__nv_bfloat16, int32_t, gather_i32_bf16) GATHER_OP(__nv_bfloat16, int64_t, gather_i64_bf16) GATHER_OP(__nv_bfloat16, uint32_t, gather_u32_bf16) GATHER_OP(__nv_bfloat16, uint8_t, gather_u8_bf16) +IA_OP(__nv_bfloat16, int16_t, ia_i16_bf16) IA_OP(__nv_bfloat16, int32_t, ia_i32_bf16) IA_OP(__nv_bfloat16, int64_t, ia_i64_bf16) IA_OP(__nv_bfloat16, uint32_t, ia_u32_bf16) IA_OP(__nv_bfloat16, uint8_t, ia_u8_bf16) +SA_OP(__nv_bfloat16, int16_t, sa_i16_bf16) SA_OP(__nv_bfloat16, int32_t, sa_i32_bf16) SA_OP(__nv_bfloat16, int64_t, sa_i64_bf16) SA_OP(__nv_bfloat16, uint32_t, sa_u32_bf16) @@ -166,28 +170,41 @@ SA_OP(__nv_bfloat16, uint8_t, sa_u8_bf16) #endif #if __CUDA_ARCH__ >= 530 +IS_OP(__half, int16_t, is_i16_f16) IS_OP(__half, int32_t, is_i32_f16) IS_OP(__half, int64_t, is_i64_f16) IS_OP(__half, uint32_t, is_u32_f16) IS_OP(__half, uint8_t, is_u8_f16) +GATHER_OP(__half, int16_t, gather_i16_f16) GATHER_OP(__half, int32_t, gather_i32_f16) GATHER_OP(__half, int64_t, gather_i64_f16) GATHER_OP(__half, uint32_t, gather_u32_f16) GATHER_OP(__half, uint8_t, gather_u8_f16) +IA_OP(__half, int16_t, ia_i16_f16) IA_OP(__half, int32_t, ia_i32_f16) IA_OP(__half, int64_t, ia_i64_f16) IA_OP(__half, uint32_t, ia_u32_f16) IA_OP(__half, uint8_t, ia_u8_f16) +SA_OP(__half, int16_t, sa_i16_f16) SA_OP(__half, int32_t, sa_i32_f16) SA_OP(__half, int64_t, sa_i64_f16) SA_OP(__half, uint32_t, sa_u32_f16) SA_OP(__half, uint8_t, sa_u8_f16) #endif +IS_OP(float, int16_t, is_i16_f32) +IS_OP(double, int16_t, is_i16_f64) +IS_OP(uint8_t, int16_t, is_i16_u8) +IS_OP(uint32_t, int16_t, is_i16_u32) +IS_OP(int16_t, int16_t, is_i16_i16) +IS_OP(int32_t, int16_t, is_i16_i32) +IS_OP(int64_t, int16_t, is_i16_i64) + IS_OP(float, int32_t, is_i32_f32) IS_OP(double, int32_t, is_i32_f64) IS_OP(uint8_t, int32_t, is_i32_u8) IS_OP(uint32_t, 
int32_t, is_i32_u32) +IS_OP(int16_t, int32_t, is_i32_i16) IS_OP(int32_t, int32_t, is_i32_i32) IS_OP(int64_t, int32_t, is_i32_i64) @@ -197,10 +214,12 @@ IS_OP(uint8_t, int64_t, is_i64_u8) IS_OP(uint32_t, int64_t, is_i64_u32) IS_OP(int64_t, int64_t, is_i64_i64) IS_OP(int32_t, int64_t, is_i64_i32) +IS_OP(int16_t, int64_t, is_i64_i16) IS_OP(float, uint32_t, is_u32_f32) IS_OP(double, uint32_t, is_u32_f64) IS_OP(uint8_t, uint32_t, is_u32_u8) +IS_OP(int16_t, uint32_t, is_u32_i16) IS_OP(int32_t, uint32_t, is_u32_i32) IS_OP(int64_t, uint32_t, is_u32_i64) IS_OP(uint32_t, uint32_t, is_u32_u32) @@ -209,13 +228,23 @@ IS_OP(float, uint8_t, is_u8_f32) IS_OP(double, uint8_t, is_u8_f64) IS_OP(uint8_t, uint8_t, is_u8_u8) IS_OP(uint32_t, uint8_t, is_u8_u32) +IS_OP(int16_t, uint8_t, is_u8_i16) IS_OP(int32_t, uint8_t, is_u8_i32) IS_OP(int64_t, uint8_t, is_u8_i64) +GATHER_OP(float, int16_t, gather_i16_f32) +GATHER_OP(double, int16_t, gather_i16_f64) +GATHER_OP(uint8_t, int16_t, gather_i16_u8) +GATHER_OP(uint32_t, int16_t, gather_i16_u32) +GATHER_OP(int16_t, int16_t, gather_i16_i16) +GATHER_OP(int32_t, int16_t, gather_i16_i32) +GATHER_OP(int64_t, int16_t, gather_i16_i64) + GATHER_OP(float, int32_t, gather_i32_f32) GATHER_OP(double, int32_t, gather_i32_f64) GATHER_OP(uint8_t, int32_t, gather_i32_u8) GATHER_OP(uint32_t, int32_t, gather_i32_u32) +GATHER_OP(int16_t, int32_t, gather_i32_i16) GATHER_OP(int32_t, int32_t, gather_i32_i32) GATHER_OP(int64_t, int32_t, gather_i32_i64) @@ -225,10 +254,12 @@ GATHER_OP(uint8_t, int64_t, gather_i64_u8) GATHER_OP(uint32_t, int64_t, gather_i64_u32) GATHER_OP(int64_t, int64_t, gather_i64_i64) GATHER_OP(int32_t, int64_t, gather_i64_i32) +GATHER_OP(int16_t, int64_t, gather_i64_i16) GATHER_OP(float, uint32_t, gather_u32_f32) GATHER_OP(double, uint32_t, gather_u32_f64) GATHER_OP(uint8_t, uint32_t, gather_u32_u8) +GATHER_OP(int16_t, uint32_t, gather_u32_i16) GATHER_OP(int32_t, uint32_t, gather_u32_i32) GATHER_OP(int64_t, uint32_t, gather_u32_i64) GATHER_OP(uint32_t, uint32_t, gather_u32_u32) @@ -237,9 +268,16 @@ GATHER_OP(float, uint8_t, gather_u8_f32) GATHER_OP(double, uint8_t, gather_u8_f64) GATHER_OP(uint8_t, uint8_t, gather_u8_u8) GATHER_OP(uint32_t, uint8_t, gather_u8_u32) +GATHER_OP(int16_t, uint8_t, gather_u8_i16) GATHER_OP(int32_t, uint8_t, gather_u8_i32) GATHER_OP(int64_t, uint8_t, gather_u8_i64) +IA_OP(float, int16_t, ia_i16_f32) +IA_OP(double, int16_t, ia_i16_f64) +IA_OP(uint8_t, int16_t, ia_i16_u8) +IA_OP(int16_t, int16_t, ia_i16_i16) +IA_OP(uint16_t, int16_t, ia_i16_u16) + IA_OP(float, int32_t, ia_i32_f32) IA_OP(double, int32_t, ia_i32_f64) IA_OP(uint8_t, int32_t, ia_i32_u8) @@ -252,10 +290,12 @@ IA_OP(uint8_t, int64_t, ia_i64_u8) IA_OP(int64_t, int64_t, ia_i64_i64) IA_OP(uint32_t, int64_t, ia_i64_u32) IA_OP(int32_t, int64_t, ia_i64_i32) +IA_OP(int16_t, int64_t, ia_i64_i16) IA_OP(float, uint32_t, ia_u32_f32) IA_OP(double, uint32_t, ia_u32_f64) IA_OP(uint8_t, uint32_t, ia_u32_u8) +IA_OP(int16_t, uint32_t, ia_u32_i16) IA_OP(int32_t, uint32_t, ia_u32_i32) IA_OP(int64_t, uint32_t, ia_u32_i64) IA_OP(uint32_t, uint32_t, ia_u32_u32) @@ -264,18 +304,28 @@ IA_OP(float, uint8_t, ia_u8_f32) IA_OP(double, uint8_t, ia_u8_f64) IA_OP(uint8_t, uint8_t, ia_u8_u8) IA_OP(uint32_t, uint8_t, ia_u8_u32) +IA_OP(int16_t, uint8_t, ia_u8_i16) IA_OP(int32_t, uint8_t, ia_u8_i32) IA_OP(int64_t, uint8_t, ia_u8_i64) +SA_OP(float, int16_t, sa_i16_f32) +SA_OP(double, int16_t, sa_i16_f64) +SA_OP(uint8_t, int16_t, sa_i16_u8) +SA_OP(int16_t, int16_t, sa_i16_i16) +SA_OP(int32_t, int16_t, sa_i16_i32) 
+SA_OP(uint32_t, int16_t, sa_i16_u32) + SA_OP(float, int32_t, sa_i32_f32) SA_OP(double, int32_t, sa_i32_f64) SA_OP(uint8_t, int32_t, sa_i32_u8) +SA_OP(int16_t, int32_t, sa_i32_i16) SA_OP(int32_t, int32_t, sa_i32_i32) SA_OP(uint32_t, int32_t, sa_i32_u32) SA_OP(float, int64_t, sa_i64_f32) SA_OP(double, int64_t, sa_i64_f64) SA_OP(uint8_t, int64_t, sa_i64_u8) +SA_OP(int16_t, int64_t, sa_i64_i16) SA_OP(int32_t, int64_t, sa_i64_i32) SA_OP(int64_t, int64_t, sa_i64_i64) SA_OP(uint32_t, int64_t, sa_i64_u32) @@ -283,6 +333,7 @@ SA_OP(uint32_t, int64_t, sa_i64_u32) SA_OP(float, uint32_t, sa_u32_f32) SA_OP(double, uint32_t, sa_u32_f64) SA_OP(uint8_t, uint32_t, sa_u32_u8) +SA_OP(int16_t, uint32_t, sa_u32_i16) SA_OP(int32_t, uint32_t, sa_u32_i32) SA_OP(int64_t, uint32_t, sa_u32_i64) SA_OP(uint32_t, uint32_t, sa_u32_u32) @@ -291,5 +342,6 @@ SA_OP(float, uint8_t, sa_u8_f32) SA_OP(double, uint8_t, sa_u8_f64) SA_OP(uint8_t, uint8_t, sa_u8_u8) SA_OP(uint32_t, uint8_t, sa_u8_u32) +SA_OP(int16_t, uint8_t, sa_u8_i16) SA_OP(int32_t, uint8_t, sa_u8_i32) SA_OP(int64_t, uint8_t, sa_u8_i64) diff --git a/candle-kernels/src/reduce.cu b/candle-kernels/src/reduce.cu index 9a1354a8dc..fe2e30160a 100644 --- a/candle-kernels/src/reduce.cu +++ b/candle-kernels/src/reduce.cu @@ -606,6 +606,7 @@ ROPE_OP(double, rope_f64, rope_i_f64, rope_thd_f64) FAST_OP(float, fast_min_f32, fast_max_f32, fast_argmin_f32, fast_argmax_f32, fast_sum_f32) FAST_OP(double, fast_min_f64, fast_max_f64, fast_argmin_f64, fast_argmax_f64, fast_sum_f64) FAST_OP(uint32_t, fast_min_u32, fast_max_u32, fast_argmin_u32, fast_argmax_u32, fast_sum_u32) +FAST_OP(int16_t, fast_min_i16, fast_max_i16, fast_argmin_i16, fast_argmax_i16, fast_sum_i16) FAST_OP(int32_t, fast_min_i32, fast_max_i32, fast_argmin_i32, fast_argmax_i32, fast_sum_i32) FAST_OP(int64_t, fast_min_i64, fast_max_i64, fast_argmin_i64, fast_argmax_i64, fast_sum_i64) FAST_OP(uint8_t, fast_min_u8, fast_max_u8, fast_argmin_u8, fast_argmax_u8, fast_sum_u8) diff --git a/candle-kernels/src/sort.cu b/candle-kernels/src/sort.cu index 7fecf8413e..f2b2e9d458 100644 --- a/candle-kernels/src/sort.cu +++ b/candle-kernels/src/sort.cu @@ -85,5 +85,6 @@ ASORT_OP(float, f32) ASORT_OP(double, f64) ASORT_OP(uint8_t, u8) ASORT_OP(uint32_t, u32) +ASORT_OP(int16_t, i16) ASORT_OP(int32_t, i32) ASORT_OP(int64_t, i64) diff --git a/candle-kernels/src/ternary.cu b/candle-kernels/src/ternary.cu index 4617c08fbe..18beede021 100644 --- a/candle-kernels/src/ternary.cu +++ b/candle-kernels/src/ternary.cu @@ -33,6 +33,7 @@ extern "C" __global__ void FN_NAME( \ } \ #if __CUDA_ARCH__ >= 800 +WHERE_OP(__nv_bfloat16, int16_t, where_i16_bf16) WHERE_OP(__nv_bfloat16, int32_t, where_i32_bf16) WHERE_OP(__nv_bfloat16, int64_t, where_i64_bf16) WHERE_OP(__nv_bfloat16, uint32_t, where_u32_bf16) @@ -40,12 +41,21 @@ WHERE_OP(__nv_bfloat16, uint8_t, where_u8_bf16) #endif #if __CUDA_ARCH__ >= 530 +WHERE_OP(__half, int16_t, where_i16_f16) WHERE_OP(__half, int32_t, where_i32_f16) WHERE_OP(__half, int64_t, where_i64_f16) WHERE_OP(__half, uint32_t, where_u32_f16) WHERE_OP(__half, uint8_t, where_u8_f16) #endif +WHERE_OP(float, int16_t, where_i16_f32) +WHERE_OP(double, int16_t, where_i16_f64) +WHERE_OP(uint8_t, int16_t, where_i16_u8) +WHERE_OP(uint32_t, int16_t, where_i16_u32) +WHERE_OP(int16_t, int16_t, where_i16_i16) +WHERE_OP(int32_t, int16_t, where_i16_i32) +WHERE_OP(int64_t, int16_t, where_i16_i64) + WHERE_OP(float, int32_t, where_i32_f32) WHERE_OP(double, int32_t, where_i32_f64) WHERE_OP(uint8_t, int32_t, where_i32_u8) @@ -62,6 +72,7 @@ 
WHERE_OP(float, uint32_t, where_u32_f32) WHERE_OP(double, uint32_t, where_u32_f64) WHERE_OP(uint8_t, uint32_t, where_u32_u8) WHERE_OP(uint32_t, uint32_t, where_u32_u32) +WHERE_OP(int16_t, uint32_t, where_u32_i16) WHERE_OP(int32_t, uint32_t, where_u32_i32) WHERE_OP(int64_t, uint32_t, where_u32_i64) @@ -69,5 +80,6 @@ WHERE_OP(float, uint8_t, where_u8_f32) WHERE_OP(double, uint8_t, where_u8_f64) WHERE_OP(uint8_t, uint8_t, where_u8_u8) WHERE_OP(uint32_t, uint8_t, where_u8_u32) +WHERE_OP(int16_t, uint8_t, where_u8_i16) WHERE_OP(int32_t, uint8_t, where_u8_i32) WHERE_OP(int64_t, uint8_t, where_u8_i64) diff --git a/candle-kernels/src/unary.cu b/candle-kernels/src/unary.cu index 21d3d995c0..bfd60de0b1 100644 --- a/candle-kernels/src/unary.cu +++ b/candle-kernels/src/unary.cu @@ -153,6 +153,7 @@ UNARY_OP(__half, usigmoid_f16, sigmoid_fwd(x)) UNARY_OP(uint8_t, ucopy_u8, x) UNARY_OP(uint32_t, ucopy_u32, x) +UNARY_OP(int16_t, ucopy_i16, x) UNARY_OP(int32_t, ucopy_i32, x) UNARY_OP(int64_t, ucopy_i64, x) UNARY_OP(float, ucopy_f32, x) diff --git a/candle-metal-kernels/src/binary.metal b/candle-metal-kernels/src/binary.metal index a9b8129c3a..4c558c2cdb 100644 --- a/candle-metal-kernels/src/binary.metal +++ b/candle-metal-kernels/src/binary.metal @@ -59,6 +59,7 @@ BINARY(FN, float, float, NAME##_f32, NAME##_f32_strided); \ BINARY(FN, half, half, NAME##_f16, NAME##_f16_strided); \ BINARY(FN, uint32_t, uint32_t, NAME##_u32, NAME##_u32_strided); \ BINARY(FN, uint8_t, uint8_t, NAME##_u8, NAME##_u8_strided); \ +BINARY(FN, int16_t, int16_t, NAME##_i16, NAME##_i16_strided); \ BINARY(FN, int32_t, int32_t, NAME##_i32, NAME##_i32_strided); #define BINARY_OP_OUT(NAME, FN) \ @@ -66,6 +67,7 @@ BINARY(FN, float, uint8_t, NAME##_f32, NAME##_f32_strided); \ BINARY(FN, half, uint8_t, NAME##_f16, NAME##_f16_strided); \ BINARY(FN, uint32_t, uint8_t, NAME##_u32, NAME##_u32_strided); \ BINARY(FN, uint8_t, uint8_t, NAME##_u8, NAME##_u8_strided); \ +BINARY(FN, int16_t, uint8_t, NAME##_i16, NAME##_i16_strided); \ BINARY(FN, int32_t, uint8_t, NAME##_i32, NAME##_i32_strided); #define INT64_BINARY_OP(NAME, FN) \ diff --git a/candle-metal-kernels/src/cast.metal b/candle-metal-kernels/src/cast.metal index c8122ccf0a..5a8324bf11 100644 --- a/candle-metal-kernels/src/cast.metal +++ b/candle-metal-kernels/src/cast.metal @@ -77,6 +77,7 @@ CAST(cast_u32_f32, cast_u32_f32_strided, uint32_t, float) CAST(cast_u32_u8, cast_u32_u8_strided, uint32_t, uint8_t) CAST(cast_u32_f16, cast_u32_f16_strided, uint32_t, half) CAST(cast_u32_i32, cast_u32_i32_strided, uint32_t, int32_t) +CAST(cast_u32_i16, cast_u32_i16_strided, uint32_t, int16_t) #if __METAL_VERSION__ >= 220 CAST(cast_u32_i64, cast_u32_i64_strided, uint32_t, int64_t) #endif @@ -89,6 +90,7 @@ CAST(cast_u8_u32, cast_u8_u32_strided, uint8_t, uint32_t) CAST(cast_u8_f32, cast_u8_f32_strided, uint8_t, float) CAST(cast_u8_f16, cast_u8_f16_strided, uint8_t, half) CAST(cast_u8_i32, cast_u8_i32_strided, uint8_t, int64_t) +CAST(cast_u8_i16, cast_u8_i16_strided, uint8_t, int16_t) #if __METAL_VERSION__ >= 220 CAST(cast_u8_i64, cast_u8_i64_strided, uint8_t, int64_t) #endif @@ -100,6 +102,7 @@ CAST(cast_u8_bf16, cast_u8_bf16_strided, uint8_t, bfloat) CAST(cast_f16_f32, cast_f16_f32_strided, half, float) CAST(cast_f16_u8, cast_f16_u8_strided, half, uint8_t) CAST(cast_f16_u32, cast_f16_u32_strided, half, uint32_t) +CAST(cast_f16_i16, cast_f16_i16_strided, half, int16_t) CAST(cast_f16_i32, cast_f16_i32_strided, half, int64_t) CAST(cast_f16_i64, cast_f16_i64_strided, half, int64_t) #if defined(__HAVE_BFLOAT__) 
@@ -111,6 +114,7 @@ CAST(cast_i64_f32, cast_i64_f32_strided, int64_t, float) CAST(cast_i64_u8, cast_i64_u8_strided, int64_t, uint8_t) CAST(cast_i64_u32, cast_i64_u32_strided, int64_t, uint32_t) CAST(cast_i64_i32, cast_i64_i32_strided, int64_t, int32_t) +CAST(cast_i64_i16, cast_i64_i16_strided, int64_t, int16_t) CAST(cast_i64_f16, cast_i64_f16_strided, int64_t, half) #if defined(__HAVE_BFLOAT__) CAST_THROUGH(cast_i64_bf16, cast_i64_bf16_strided, int64_t, bfloat, float) @@ -121,15 +125,28 @@ CAST(cast_i32_f32, cast_i32_f32_strided, int32_t, float) CAST(cast_i32_u8, cast_i32_u8_strided, int32_t, uint8_t) CAST(cast_i32_u32, cast_i32_u32_strided, int32_t, uint32_t) CAST(cast_i32_i64, cast_i32_i64_strided, int32_t, int64_t) +CAST(cast_i32_i16, cast_i32_i16_strided, int32_t, int16_t) CAST(cast_i32_f16, cast_i32_f16_strided, int32_t, half) #if defined(__HAVE_BFLOAT__) CAST_THROUGH(cast_i32_bf16, cast_i32_bf16_strided, int64_t, bfloat, float) #endif +// i16 +CAST(cast_i16_f32, cast_i16_f32_strided, int16_t, float) +CAST(cast_i16_u8, cast_i16_u8_strided, int16_t, uint8_t) +CAST(cast_i16_u32, cast_i16_u32_strided, int16_t, uint32_t) +CAST(cast_i16_i32, cast_i16_i32_strided, int16_t, int32_t) +CAST(cast_i16_i64, cast_i16_i64_strided, int16_t, int64_t) +CAST(cast_i16_f16, cast_i16_f16_strided, int16_t, half) +#if defined(__HAVE_BFLOAT__) +CAST_THROUGH(cast_i16_bf16, cast_i16_bf16_strided, int16_t, bfloat, float) +#endif + // f32 CAST(cast_f32_f16, cast_f32_f16_strided, float, half) CAST(cast_f32_u32, cast_f32_u32_strided, float, uint32_t) CAST(cast_f32_u8, cast_f32_u8_strided, float, uint8_t) +CAST(cast_f32_i16, cast_f32_i16_strided, float, int16_t) CAST(cast_f32_i32, cast_f32_i32_strided, float, int32_t) CAST(cast_f32_i64, cast_f32_i64_strided, float, int64_t) #if defined(__HAVE_BFLOAT__) @@ -139,6 +156,7 @@ CAST(cast_f32_bf16, cast_f32_bf16_strided, float, bfloat) // bf16 #if defined(__HAVE_BFLOAT__) CAST(cast_bf16_u32, cast_bf16_u32_strided, bfloat, uint32_t) +CAST(cast_bf16_i16, cast_bf16_i16_strided, bfloat, int16_t) CAST(cast_bf16_i32, cast_bf16_i32_strided, bfloat, int32_t) CAST(cast_bf16_i64, cast_bf16_i64_strided, bfloat, int64_t) CAST(cast_bf16_f32, cast_bf16_f32_strided, bfloat, float) diff --git a/candle-metal-kernels/src/indexing.metal b/candle-metal-kernels/src/indexing.metal index eaa78d7b73..f01d4795d8 100644 --- a/candle-metal-kernels/src/indexing.metal +++ b/candle-metal-kernels/src/indexing.metal @@ -199,6 +199,12 @@ INDEX_OP(is_i32_f16, int32_t, half) INDEX_OP(is_i32_bf16, int32_t, bfloat) #endif +INDEX_OP(is_i16_f32, int16_t, float) +INDEX_OP(is_i16_f16, int16_t, half) +#if defined(__HAVE_BFLOAT__) +INDEX_OP(is_i16_bf16, int16_t, bfloat) +#endif + INDEX_OP(is_u32_f32, uint32_t, float) INDEX_OP(is_u32_f16, uint32_t, half) #if defined(__HAVE_BFLOAT__) @@ -219,10 +225,12 @@ GATHER_OP(gather_u32_bf16, uint, bfloat) SCATTER_ADD_OP(sa_u32_f32, uint32_t, float) SCATTER_ADD_OP(sa_u8_f32, uint8_t, float) +SCATTER_ADD_OP(sa_i16_f32, int16_t, float) SCATTER_ADD_OP(sa_i32_f32, int32_t, float) SCATTER_ADD_OP(sa_i64_f32, int64_t, float) SCATTER_ADD_OP(sa_u32_f16, uint32_t, half) SCATTER_ADD_OP(sa_u8_f16, uint8_t, half) +SCATTER_ADD_OP(sa_i16_f16, int16_t, half) SCATTER_ADD_OP(sa_i32_f16, int32_t, half) SCATTER_ADD_OP(sa_i64_f16, int64_t, half) #if defined(__HAVE_BFLOAT__) @@ -234,6 +242,7 @@ SCATTER_ADD_OP(sa_i64_bf16, int64_t, bfloat) // i64 INDEX_ADD_OP(ia_i64_f16, int64_t, half) INDEX_ADD_OP(ia_i64_f32, int64_t, float) +INDEX_ADD_OP(ia_i64_i16, int64_t, int16_t) INDEX_ADD_OP(ia_i64_i32, 
int64_t, int32_t) INDEX_ADD_OP(ia_i64_i64, int64_t, int64_t) INDEX_ADD_OP(ia_i64_u32, int64_t, uint32_t) @@ -242,7 +251,7 @@ INDEX_ADD_OP(ia_i64_u8, int64_t, uint8_t) INDEX_ADD_OP(ia_i64_bf16, int64_t, bfloat) #endif -// i64 +// i32 INDEX_ADD_OP(ia_i32_f16, int32_t, half) INDEX_ADD_OP(ia_i32_f32, int32_t, float) INDEX_ADD_OP(ia_i32_i64, int32_t, int64_t) @@ -253,9 +262,23 @@ INDEX_ADD_OP(ia_i32_u8, int32_t, uint8_t) INDEX_ADD_OP(ia_i32_bf16, int32_t, bfloat) #endif +// i16 +INDEX_ADD_OP(ia_i16_f16, int16_t, half) +INDEX_ADD_OP(ia_i16_f32, int16_t, float) +INDEX_ADD_OP(ia_i16_i16, int16_t, int16_t) +INDEX_ADD_OP(ia_i16_i32, int16_t, int32_t) +INDEX_ADD_OP(ia_i16_i64, int16_t, int64_t) +INDEX_ADD_OP(ia_i16_u32, int16_t, uint32_t) +INDEX_ADD_OP(ia_i16_u8, int16_t, uint8_t) +#if defined(__HAVE_BFLOAT__) +INDEX_ADD_OP(ia_i16_bf16, int16_t, bfloat) +#endif + + // u32 INDEX_ADD_OP(ia_u32_f16, uint32_t, half) INDEX_ADD_OP(ia_u32_f32, uint32_t, float) +INDEX_ADD_OP(ia_u32_i16, uint32_t, int16_t) INDEX_ADD_OP(ia_u32_i32, uint32_t, int32_t) INDEX_ADD_OP(ia_u32_i64, uint32_t, int64_t) INDEX_ADD_OP(ia_u32_u32, uint32_t, uint32_t) @@ -267,6 +290,7 @@ INDEX_ADD_OP(ia_u32_bf16, uint32_t, bfloat) // u8 INDEX_ADD_OP(ia_u8_f16, uint8_t, half) INDEX_ADD_OP(ia_u8_f32, uint8_t, float) +INDEX_ADD_OP(ia_u8_i16, uint8_t, int16_t) INDEX_ADD_OP(ia_u8_i32, uint8_t, int32_t) INDEX_ADD_OP(ia_u8_i64, uint8_t, int64_t) INDEX_ADD_OP(ia_u8_u32, uint8_t, uint32_t) diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index d5e5b8eb66..ea1656193f 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -47,6 +47,7 @@ pub mod copy2d { pub const BFLOAT: Kernel = Kernel("copy2d_bf16"); pub const I64: Kernel = Kernel("copy2d_i64"); pub const I32: Kernel = Kernel("copy2d_i32"); + pub const I16: Kernel = Kernel("copy2d_i16"); pub const U32: Kernel = Kernel("copy2d_u32"); pub const U8: Kernel = Kernel("copy2d_u8"); } @@ -64,6 +65,7 @@ macro_rules! ops{ pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bf16")); pub const I64: Kernel = Kernel(concat!(stringify!($name), "_i64")); pub const I32: Kernel = Kernel(concat!(stringify!($name), "_i32")); + pub const I16: Kernel = Kernel(concat!(stringify!($name), "_i16")); pub const U32: Kernel = Kernel(concat!(stringify!($name), "_u32")); pub const U8: Kernel = Kernel(concat!(stringify!($name), "_u8")); } @@ -75,6 +77,7 @@ macro_rules! ops{ pub const BFLOAT: Kernel = Kernel("copy_bf16"); pub const I64: Kernel = Kernel("copy_i64"); pub const I32: Kernel = Kernel("copy_i32"); + pub const I16: Kernel = Kernel("copy_i16"); pub const U32: Kernel = Kernel("copy_u32"); pub const U8: Kernel = Kernel("copy_u8"); } @@ -90,6 +93,7 @@ macro_rules! ops{ pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bf16_tiled")); pub const I64: Kernel = Kernel(concat!(stringify!($name), "_i64_tiled")); pub const I32: Kernel = Kernel(concat!(stringify!($name), "_i32_tiled")); + pub const I16: Kernel = Kernel(concat!(stringify!($name), "_i16_tiled")); pub const U32: Kernel = Kernel(concat!(stringify!($name), "_u32_tiled")); pub const U8: Kernel = Kernel(concat!(stringify!($name), "_u8_tiled")); } @@ -101,6 +105,7 @@ macro_rules! 
ops{ pub const BFLOAT: Kernel = Kernel("copy_bf16_tiled"); pub const I64: Kernel = Kernel("copy_i64_tiled"); pub const I32: Kernel = Kernel("copy_i32_tiled"); + pub const I16: Kernel = Kernel("copy_i16_tiled"); pub const U32: Kernel = Kernel("copy_u32_tiled"); pub const U8: Kernel = Kernel("copy_u8_tiled"); } @@ -116,6 +121,7 @@ macro_rules! ops{ pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bf16_strided")); pub const I64: Kernel = Kernel(concat!(stringify!($name), "_i64_strided")); pub const I32: Kernel = Kernel(concat!(stringify!($name), "_i32_strided")); + pub const I16: Kernel = Kernel(concat!(stringify!($name), "_i16_strided")); pub const U32: Kernel = Kernel(concat!(stringify!($name), "_u32_strided")); pub const U8: Kernel = Kernel(concat!(stringify!($name), "_u8_strided")); } @@ -127,6 +133,7 @@ macro_rules! ops{ pub const BFLOAT: Kernel = Kernel("copy_bf16_strided"); pub const I64: Kernel = Kernel("copy_i64_strided"); pub const I32: Kernel = Kernel("copy_i32_strided"); + pub const I16: Kernel = Kernel("copy_i16_strided"); pub const U32: Kernel = Kernel("copy_u32_strided"); pub const U8: Kernel = Kernel("copy_u8_strided"); } diff --git a/candle-metal-kernels/src/reduce.metal b/candle-metal-kernels/src/reduce.metal index 484fa0a1b1..56ef56f7e0 100644 --- a/candle-metal-kernels/src/reduce.metal +++ b/candle-metal-kernels/src/reduce.metal @@ -608,6 +608,12 @@ REDUCE(MAX(x, y), fast_max_i32_strided, int32_t, INT_MIN) ARGMIN(fast_argmin_i32_strided, int32_t, INT_MAX) ARGMAX(fast_argmax_i32_strided, int32_t, INT_MIN) +REDUCE(x + y, fast_sum_i16_strided, int16_t, 0) +REDUCE(MIN(x, y), fast_min_i16_strided, int16_t, INT_MAX) +REDUCE(MAX(x, y), fast_max_i16_strided, int16_t, INT_MIN) +ARGMIN(fast_argmin_i16_strided, int16_t, INT_MAX) +ARGMAX(fast_argmax_i16_strided, int16_t, INT_MIN) + #if defined(__HAVE_BFLOAT__) REDUCE(x + y, fast_sum_bf16, bfloat, 0) REDUCE(x + y, fast_sum_bf16_strided, half, 0) diff --git a/candle-metal-kernels/src/sort.metal b/candle-metal-kernels/src/sort.metal index b7cf71bb58..9f001d8fb6 100644 --- a/candle-metal-kernels/src/sort.metal +++ b/candle-metal-kernels/src/sort.metal @@ -89,6 +89,7 @@ ARGSORT(half, f16) ARGSORT(uint8_t, u8) ARGSORT(uint32_t, u32) ARGSORT(int32_t, i32) +ARGSORT(int16_t, i16) #if __METAL_VERSION__ >= 220 ARGSORT(int64_t, i64) diff --git a/candle-metal-kernels/src/ternary.metal b/candle-metal-kernels/src/ternary.metal index 0e043332fe..98aacd0036 100644 --- a/candle-metal-kernels/src/ternary.metal +++ b/candle-metal-kernels/src/ternary.metal @@ -76,6 +76,7 @@ WHERE_OP(uint8_t, int64_t, where_i64_u8) WHERE_OP(uint32_t, int64_t, where_i64_u32) WHERE_OP(int64_t, int64_t, where_i64_i64) WHERE_OP(int64_t, int32_t, where_i64_i32) +WHERE_OP(int64_t, int16_t, where_i64_i16) #if defined(__HAVE_BFLOAT__) WHERE_OP(bfloat, int64_t, where_i64_bf16) #endif @@ -94,6 +95,20 @@ WHERE_OP(int32_t, int32_t, where_i32_i32) WHERE_OP(bfloat, int32_t, where_i32_bf16) #endif +WHERE_OP(int64_t, uint8_t, where_u8_i16) +WHERE_OP(int64_t, uint32_t, where_u32_i16) + +WHERE_OP(half, int16_t, where_i16_f16) +WHERE_OP(float, int16_t, where_i16_f32) +WHERE_OP(uint8_t, int16_t, where_i16_u8) +WHERE_OP(uint32_t, int16_t, where_i16_u32) +WHERE_OP(int64_t, int16_t, where_i16_i64) +WHERE_OP(int32_t, int16_t, where_i16_i32) +WHERE_OP(int16_t, int16_t, where_i16_i16) +#if defined(__HAVE_BFLOAT__) +WHERE_OP(bfloat, int16_t, where_i16_bf16) +#endif + #if defined(__HAVE_BFLOAT__) WHERE_OP(bfloat, uint8_t, where_u8_bf16) WHERE_OP(bfloat, uint32_t, where_u32_bf16) diff 
--git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal index 0c5a2736ee..a76c311a3a 100644 --- a/candle-metal-kernels/src/unary.metal +++ b/candle-metal-kernels/src/unary.metal @@ -172,6 +172,9 @@ COPY2D(copy2d_i64, int64_t) UNARY(id, int32_t, copy_i32, copy_i32_strided) COPY2D(copy2d_i32, int32_t) +UNARY(id, int16_t, copy_i16, copy_i16_strided) +COPY2D(copy2d_i16, int16_t) + #if defined(__HAVE_BFLOAT__) BFLOAT_UNARY_OP(cos) BFLOAT_UNARY_OP(sin) diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs index 5b66a743c3..d8fcc77769 100644 --- a/candle-onnx/src/eval.rs +++ b/candle-onnx/src/eval.rs @@ -712,6 +712,8 @@ fn simple_eval_( let output = match start.dtype() { DType::U8 => arange_step!(u8), DType::U32 => arange_step!(u32), + DType::I16 => arange_step!(i16), + DType::I32 => arange_step!(i32), DType::I64 => arange_step!(i64), DType::BF16 => arange_step!(f32), DType::F16 => arange_step!(f32), @@ -1305,7 +1307,7 @@ fn simple_eval_( let input = get(&node.input[0])?; let dt = input.dtype(); match dt { - DType::U8 | DType::U32 | DType::I64 => { + DType::U8 | DType::U32 | DType::I64 | DType::I16 | DType::I32 => { bail!( "unsupported dtype {}, only float types are allowed for LeakyRelu", dt.as_str() diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index 55b5542ed8..d2179d577f 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -151,6 +151,7 @@ macro_rules! pydtype { }; } +pydtype!(i16, |v| v); pydtype!(i32, |v| v); pydtype!(i64, |v| v); pydtype!(u8, |v| v); @@ -201,6 +202,7 @@ trait MapDType { match t.dtype() { DType::U8 => self.f::(t), DType::U32 => self.f::(t), + DType::I16 => self.f::(t), DType::I32 => self.f::(t), DType::I64 => self.f::(t), DType::BF16 => self.f::(t),