From 7ad6494dc04024532dfdacdb2f9a09908c272a6b Mon Sep 17 00:00:00 2001
From: Eric Buehler <65165915+EricLBuehler@users.noreply.github.com>
Date: Fri, 9 Aug 2024 13:54:32 -0400
Subject: [PATCH] Mistral.rs GPTQ dev PR (#14)

* Add i32 dtype for cpu and cuda, with kernels

* Fix cuda i32

* Fix cpu i32

* Add cuda map impls for i32

* Start to add to metal

* Add the kernels

* Oops

* Fix dtype cast in safetensors

* Oops

* Oops

* Add bf16 to i32 and vice versa casts
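
Quick usage sketch of the new dtype (mirrors the `ones` assertion added to
candle-core/tests/tensor_tests.rs below; the `?` operator assumes a
surrounding function returning candle_core::Result):

    use candle_core::{DType, Device, Tensor};

    // Create an i32 tensor and read it back.
    let t = Tensor::ones((2, 3), DType::I32, &Device::Cpu)?;
    assert_eq!(t.to_vec2::<i32>()?, [[1, 1, 1], [1, 1, 1]]);
    // Round-trip through one of the newly added cast paths.
    let f = t.to_dtype(DType::F32)?;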
---
 .vscode/settings.json                   |   4 +-
 candle-core/src/convert.rs              |   5 ++
 candle-core/src/cpu/kernels.rs          |  11 +++
 candle-core/src/cpu_backend/mod.rs      | 114 +++++++++++++++++++++++-
 candle-core/src/cpu_backend/utils.rs    |   2 +
 candle-core/src/cuda_backend/device.rs  |  32 ++++++-
 candle-core/src/cuda_backend/mod.rs     |  53 ++++++++++-
 candle-core/src/cuda_backend/utils.rs   |   2 +
 candle-core/src/display.rs              |   7 ++
 candle-core/src/dtype.rs                |  19 +++-
 candle-core/src/metal_backend/mod.rs    |  69 ++++++++++++++
 candle-core/src/npy.rs                  |   6 ++
 candle-core/src/op.rs                   |  56 ++++++++++++
 candle-core/src/safetensors.rs          |   8 +-
 candle-core/src/sort.rs                 |   1 +
 candle-core/tests/tensor_tests.rs       |   6 +-
 candle-kernels/src/affine.cu            |   1 +
 candle-kernels/src/binary.cu            |  12 +++
 candle-kernels/src/cast.cu              |  18 ++++
 candle-kernels/src/cuda_utils.cuh       |   2 +
 candle-kernels/src/fill.cu              |   2 +
 candle-kernels/src/indexing.cu          |  46 ++++++++++
 candle-kernels/src/reduce.cu            |   1 +
 candle-kernels/src/sort.cu              |   1 +
 candle-kernels/src/ternary.cu           |  10 +++
 candle-kernels/src/unary.cu             |   1 +
 candle-metal-kernels/src/binary.metal   |   6 +-
 candle-metal-kernels/src/cast.metal     |  16 ++++
 candle-metal-kernels/src/indexing.metal |  22 +++++
 candle-metal-kernels/src/lib.rs         |   7 ++
 candle-metal-kernels/src/reduce.metal   |   6 ++
 candle-metal-kernels/src/sort.metal     |   1 +
 candle-metal-kernels/src/ternary.metal  |  14 +++
 candle-metal-kernels/src/unary.metal    |   3 +
 candle-pyo3/src/lib.rs                  |   2 +
 35 files changed, 548 insertions(+), 18 deletions(-)

diff --git a/.vscode/settings.json b/.vscode/settings.json
index 646783a968..e510b688c4 100644
--- a/.vscode/settings.json
+++ b/.vscode/settings.json
@@ -8,5 +8,7 @@
     ],
     "python.testing.unittestEnabled": false,
     "python.testing.pytestEnabled": true,
-    "rust-analyzer.cargo.features": ["cuda"]
+    "rust-analyzer.cargo.features": [
+        "cuda"
+    ],
 }
\ No newline at end of file
diff --git a/candle-core/src/convert.rs b/candle-core/src/convert.rs
index 5ea5612a7c..b29ff346f6 100644
--- a/candle-core/src/convert.rs
+++ b/candle-core/src/convert.rs
@@ -130,6 +130,11 @@ impl Tensor {
                     f.write_u32::<LittleEndian>(v)?
                 }
             }
+            DType::I32 => {
+                for v in vs.to_vec1::<i32>()? {
+                    f.write_i32::<LittleEndian>(v)?
+                }
+            }
             DType::I64 => {
                 for v in vs.to_vec1::<i64>()? {
                     f.write_i64::<LittleEndian>(v)?
diff --git a/candle-core/src/cpu/kernels.rs b/candle-core/src/cpu/kernels.rs
index 527646d62b..fe0e241622 100644
--- a/candle-core/src/cpu/kernels.rs
+++ b/candle-core/src/cpu/kernels.rs
@@ -144,6 +144,17 @@ impl VecOps for u32 {
         <Self as Ord>::max(self, other)
     }
 }
+impl VecOps for i32 {
+    #[inline(always)]
+    fn min(self, other: Self) -> Self {
+        <Self as Ord>::min(self, other)
+    }
+
+    #[inline(always)]
+    fn max(self, other: Self) -> Self {
+        <Self as Ord>::max(self, other)
+    }
+}
 impl VecOps for i64 {
     #[inline(always)]
     fn min(self, other: Self) -> Self {
diff --git a/candle-core/src/cpu_backend/mod.rs b/candle-core/src/cpu_backend/mod.rs
index 3c4c47305a..2bb8991c0c 100644
--- a/candle-core/src/cpu_backend/mod.rs
+++ b/candle-core/src/cpu_backend/mod.rs
@@ -22,6 +22,7 @@ const USE_IM2COL_CONV2D: bool = true;
 pub enum CpuStorage {
     U8(Vec<u8>),
     U32(Vec<u32>),
+    I32(Vec<i32>),
     I64(Vec<i64>),
     BF16(Vec<bf16>),
     F16(Vec<f16>),
@@ -33,6 +34,7 @@ pub enum CpuStorage {
 pub enum CpuStorageRef<'a> {
     U8(&'a [u8]),
     U32(&'a [u32]),
+    I32(&'a [i32]),
     I64(&'a [i64]),
     BF16(&'a [bf16]),
     F16(&'a [f16]),
@@ -2285,6 +2287,17 @@ impl CpuStorage {
                     .concat();
                 Self::U32(storages)
             }
+            Self::I32(_) => {
+                let storages = storages
+                    .iter()
+                    .map(|s| match s {
+                        Self::I32(s) => Ok(s.as_slice()),
+                        _ => crate::bail!("dtype mismatch"),
+                    })
+                    .collect::<Result<Vec<_>>>()?
+                    .concat();
+                Self::I32(storages)
+            }
             Self::I64(_) => {
                 let storages = storages
                     .iter()
@@ -2352,6 +2365,7 @@ impl BackendStorage for CpuStorage {
         match self {
             Self::U8(_) => DType::U8,
             Self::U32(_) => DType::U32,
+            Self::I32(_) => DType::I32,
             Self::I64(_) => DType::I64,
             Self::BF16(_) => DType::BF16,
             Self::F16(_) => DType::F16,
@@ -2371,6 +2385,10 @@ impl BackendStorage for CpuStorage {
                 let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32));
                 Ok(Self::BF16(data))
             }
+            (Self::I32(storage), DType::BF16) => {
+                let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32));
+                Ok(Self::BF16(data))
+            }
             (Self::I64(storage), DType::BF16) => {
                 let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32));
                 Ok(Self::BF16(data))
@@ -2399,6 +2417,10 @@ impl BackendStorage for CpuStorage {
                 let data = unary_map(storage, layout, |v| f16::from_f32(v as f32));
                 Ok(Self::F16(data))
             }
+            (Self::I32(storage), DType::F16) => {
+                let data = unary_map(storage, layout, |v| f16::from_f32(v as f32));
+                Ok(Self::F16(data))
+            }
             (Self::I64(storage), DType::F16) => {
                 let data = unary_map(storage, layout, |v| f16::from_f32(v as f32));
                 Ok(Self::F16(data))
@@ -2427,6 +2449,10 @@ impl BackendStorage for CpuStorage {
                 let data = unary_map(storage, layout, |v| v as f32);
                 Ok(Self::F32(data))
             }
+            (Self::I32(storage), DType::F32) => {
+                let data = unary_map(storage, layout, |v| v as f32);
+                Ok(Self::F32(data))
+            }
             (Self::I64(storage), DType::F32) => {
                 let data = unary_map(storage, layout, |v| v as f32);
                 Ok(Self::F32(data))
@@ -2471,6 +2497,10 @@ impl BackendStorage for CpuStorage {
                 let data = unary_map(storage, layout, |v| v as u8);
                 Ok(Self::U8(data))
             }
+            (Self::I32(storage), DType::U8) => {
+                let data = unary_map(storage, layout, |v| v as u8);
+                Ok(Self::U8(data))
+            }
             (Self::I64(storage), DType::U8) => {
                 let data = unary_map(storage, layout, |v| v as u8);
                 Ok(Self::U8(data))
@@ -2483,6 +2513,10 @@ impl BackendStorage for CpuStorage {
                 let data = unary_map(storage, layout, |v| v);
                 Ok(Self::U32(data))
             }
+            (Self::I32(storage), DType::U32) => {
+                let data = unary_map(storage, layout, |v| v as u32);
+                Ok(Self::U32(data))
+            }
             (Self::I64(storage), DType::U32) => {
                 let data = unary_map(storage, layout, |v| v as u32);
                 Ok(Self::U32(data))
@@ -2503,6 +2537,38 @@ impl BackendStorage for CpuStorage {
                 let data = unary_map(storage, layout, |v| v as u32);
                 Ok(Self::U32(data))
             }
+            (Self::U8(storage), DType::I32) => {
+                let data = unary_map(storage, layout, |v| v as i32);
+                Ok(Self::I32(data))
+            }
+            (Self::U32(storage), DType::I32) => {
+                let data = unary_map(storage, layout, |v| v as i32);
+                Ok(Self::I32(data))
+            }
+            (Self::I32(storage), DType::I32) => {
+                let data = unary_map(storage, layout, |v| v);
+                Ok(Self::I32(data))
+            }
+            (Self::I64(storage), DType::I32) => {
+                let data = unary_map(storage, layout, |v| v as i32);
+                Ok(Self::I32(data))
+            }
+            (Self::BF16(storage), DType::I32) => {
+                let data = unary_map(storage, layout, |v| v.to_f32() as i32);
+                Ok(Self::I32(data))
+            }
+            (Self::F16(storage), DType::I32) => {
+                let data = unary_map(storage, layout, |v| v.to_f32() as i32);
+                Ok(Self::I32(data))
+            }
+            (Self::F32(storage), DType::I32) => {
+                let data = unary_map(storage, layout, |v| v as i32);
+                Ok(Self::I32(data))
+            }
+            (Self::F64(storage), DType::I32) => {
+                let data = unary_map(storage, layout, |v| v as i32);
+                Ok(Self::I32(data))
+            }
             (Self::U8(storage), DType::I64) => {
                 let data = unary_map(storage, layout, |v| v as i64);
                 Ok(Self::I64(data))
@@ -2511,6 +2577,10 @@ impl BackendStorage for CpuStorage {
                 let data = unary_map(storage, layout, |v| v as i64);
                 Ok(Self::I64(data))
             }
+            (Self::I32(storage), DType::I64) => {
+                let data = unary_map(storage, layout, |v| v as i64);
+                Ok(Self::I64(data))
+            }
             (Self::I64(storage), DType::I64) => {
                 let data = unary_map(storage, layout, |v| v);
                 Ok(Self::I64(data))
@@ -2539,6 +2609,10 @@ impl BackendStorage for CpuStorage {
                 let data = unary_map(storage, layout, |v| v as f64);
                 Ok(Self::F64(data))
             }
+            (Self::I32(storage), DType::F64) => {
+                let data = unary_map(storage, layout, |v| v as f64);
+                Ok(Self::F64(data))
+            }
             (Self::I64(storage), DType::F64) => {
                 let data = unary_map(storage, layout, |v| v as f64);
                 Ok(Self::F64(data))
@@ -2674,6 +2748,7 @@ impl BackendStorage for CpuStorage {
             }
             Self::U8(_) => Err(Error::UnsupportedDTypeForOp(DType::U8, "elu").bt()),
             Self::U32(_) => Err(Error::UnsupportedDTypeForOp(DType::U32, "elu").bt()),
+            Self::I32(_) => Err(Error::UnsupportedDTypeForOp(DType::I32, "elu").bt()),
             Self::I64(_) => Err(Error::UnsupportedDTypeForOp(DType::I64, "elu").bt()),
         }
     }
@@ -2699,6 +2774,7 @@ impl BackendStorage for CpuStorage {
             }
             Self::U8(_) => Err(Error::UnsupportedDTypeForOp(DType::U8, "elu").bt()),
             Self::U32(_) => Err(Error::UnsupportedDTypeForOp(DType::U32, "elu").bt()),
+            Self::I32(_) => Err(Error::UnsupportedDTypeForOp(DType::I32, "elu").bt()),
             Self::I64(_) => Err(Error::UnsupportedDTypeForOp(DType::I64, "elu").bt()),
         }
     }
@@ -2749,6 +2825,10 @@ impl BackendStorage for CpuStorage {
                 let data = unary_map(storage, layout, B::u32);
                 Ok(Self::U32(data))
             }
+            Self::I32(storage) => {
+                let data = unary_map(storage, layout, B::i32);
+                Ok(Self::I32(data))
+            }
             Self::I64(storage) => {
                 let data = unary_map(storage, layout, B::i64);
                 Ok(Self::I64(data))
@@ -2803,6 +2883,14 @@ impl BackendStorage for CpuStorage {
                 };
                 Ok(Self::U32(data))
             }
+            (Self::I32(lhs), Self::I32(rhs)) => {
+                let data = if B::I32_VEC {
+                    binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::i32, B::i32_vec)
+                } else {
+                    binary_map(lhs_l, rhs_l, lhs, rhs, B::i32)
+                };
+                Ok(Self::I32(data))
+            }
             (Self::I64(lhs), Self::I64(rhs)) => {
                 let data = if B::I64_VEC {
                     binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::i64, B::i64_vec)
@@ -2846,6 +2934,9 @@ impl BackendStorage for CpuStorage {
             (Self::U32(src), Self::U32(dst)) => {
                 copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
             }
+            (Self::I32(src), Self::I32(dst)) => {
+                copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
+            }
             (Self::I64(src), Self::I64(dst)) => {
                 copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
             }
@@ -2877,6 +2968,7 @@ impl BackendStorage for CpuStorage {
         match (self, dst) {
             (Self::U8(src), Self::U8(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
             (Self::U32(src), Self::U32(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
+            (Self::I32(src), Self::I32(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
             (Self::I64(src), Self::I64(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
             (Self::BF16(src), Self::BF16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
             (Self::F16(src), Self::F16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
@@ -2906,6 +2998,7 @@ impl BackendStorage for CpuStorage {
         match self {
             Self::U8(pred) => WCond(pred, layout).map(t, t_l, f, f_l),
             Self::U32(pred) => WCond(pred, layout).map(t, t_l, f, f_l),
+            Self::I32(pred) => WCond(pred, layout).map(t, t_l, f, f_l),
             Self::I64(pred) => WCond(pred, layout).map(t, t_l, f, f_l),
             _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "where-cond")),
         }
@@ -3075,6 +3168,7 @@ impl BackendStorage for CpuStorage {
         match ids {
             Self::U8(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
             Self::U32(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
+            Self::I32(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
             Self::I64(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
             _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-select").bt()),
         }
@@ -3084,6 +3178,7 @@ impl BackendStorage for CpuStorage {
         match ids {
             Self::U8(ids) => Gather { ids, ids_l, dim }.map(self, l),
             Self::U32(ids) => Gather { ids, ids_l, dim }.map(self, l),
+            Self::I32(ids) => Gather { ids, ids_l, dim }.map(self, l),
             Self::I64(ids) => Gather { ids, ids_l, dim }.map(self, l),
             _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "gather").bt()),
         }
@@ -3101,6 +3196,7 @@ impl BackendStorage for CpuStorage {
         match ids {
             Self::U8(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
             Self::U32(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
+            Self::I32(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
             Self::I64(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
             _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "scatter-add").bt()),
         }
@@ -3130,6 +3226,13 @@ impl BackendStorage for CpuStorage {
                 };
                 IndexAdd { ids, dim }.map(self, l, src, src_l)
             }
+            Self::I32(ids) => {
+                let ids = match ids_l.contiguous_offsets() {
+                    Some((a, b)) => &ids[a..b],
+                    None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
+                };
+                IndexAdd { ids, dim }.map(self, l, src, src_l)
+            }
             Self::I64(ids) => {
                 let ids = match ids_l.contiguous_offsets() {
                     Some((a, b)) => &ids[a..b],
@@ -3225,7 +3328,7 @@ impl BackendDevice for CpuDevice {
         let elem_count = shape.elem_count();
         let mut rng = rand::thread_rng();
         match dtype {
-            DType::U8 | DType::U32 | DType::I64 => {
+            DType::U8 | DType::U32 | DType::I32 | DType::I64 => {
                 Err(Error::UnsupportedDTypeForOp(dtype, "rand_uniform").bt())
             }
             DType::BF16 => {
@@ -3271,7 +3374,7 @@ impl BackendDevice for CpuDevice {
         let elem_count = shape.elem_count();
         let mut rng = rand::thread_rng();
         match dtype {
-            DType::U8 | DType::U32 | DType::I64 => {
+            DType::U8 | DType::U32 | DType::I32 | DType::I64 => {
                 Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal").bt())
             }
             DType::BF16 => {
@@ -3330,6 +3433,11 @@ impl BackendDevice for CpuDevice {
                 v.set_len(elem_count);
                 CpuStorage::U32(v)
             }
+            DType::I32 => {
+                let mut v = Vec::with_capacity(elem_count);
+                v.set_len(elem_count);
+                CpuStorage::I32(v)
+            }
             DType::I64 => {
                 let mut v = Vec::with_capacity(elem_count);
                 v.set_len(elem_count);
@@ -3364,6 +3472,7 @@ impl BackendDevice for CpuDevice {
         let storage = match dtype {
             DType::U8 => CpuStorage::U8(vec![1u8; elem_count]),
             DType::U32 => CpuStorage::U32(vec![1u32; elem_count]),
+            DType::I32 => CpuStorage::I32(vec![1i32; elem_count]),
             DType::I64 => CpuStorage::I64(vec![1i64; elem_count]),
             DType::BF16 => CpuStorage::BF16(vec![bf16::ONE; elem_count]),
             DType::F16 => CpuStorage::F16(vec![f16::ONE; elem_count]),
@@ -3378,6 +3487,7 @@ impl BackendDevice for CpuDevice {
         let storage = match dtype {
             DType::U8 => CpuStorage::U8(vec![0u8; elem_count]),
             DType::U32 => CpuStorage::U32(vec![0u32; elem_count]),
+            DType::I32 => CpuStorage::I32(vec![0i32; elem_count]),
             DType::I64 => CpuStorage::I64(vec![0i64; elem_count]),
             DType::BF16 => CpuStorage::BF16(vec![bf16::ZERO; elem_count]),
             DType::F16 => CpuStorage::F16(vec![f16::ZERO; elem_count]),
diff --git a/candle-core/src/cpu_backend/utils.rs b/candle-core/src/cpu_backend/utils.rs
index 9d38729145..297ccd3de6 100644
--- a/candle-core/src/cpu_backend/utils.rs
+++ b/candle-core/src/cpu_backend/utils.rs
@@ -10,6 +10,7 @@ pub trait Map1 {
         match vs {
             C::U8(vs) => Ok(C::U8(self.f(vs, layout)?)),
             C::U32(vs) => Ok(C::U32(self.f(vs, layout)?)),
+            C::I32(vs) => Ok(C::I32(self.f(vs, layout)?)),
             C::I64(vs) => Ok(C::I64(self.f(vs, layout)?)),
             C::BF16(vs) => Ok(C::BF16(self.f(vs, layout)?)),
             C::F16(vs) => Ok(C::F16(self.f(vs, layout)?)),
@@ -26,6 +27,7 @@ pub trait Map1Any {
         match vs {
             C::U8(vs) => Ok(self.f(vs, layout, C::U8)?),
             C::U32(vs) => Ok(self.f(vs, layout, C::U32)?),
+            C::I32(vs) => Ok(self.f(vs, layout, C::I32)?),
             C::I64(vs) => Ok(self.f(vs, layout, C::I64)?),
             C::BF16(vs) => Ok(self.f(vs, layout, C::BF16)?),
             C::F16(vs) => Ok(self.f(vs, layout, C::F16)?),
diff --git a/candle-core/src/cuda_backend/device.rs b/candle-core/src/cuda_backend/device.rs
index 352bae9442..9e0b64067b 100644
--- a/candle-core/src/cuda_backend/device.rs
+++ b/candle-core/src/cuda_backend/device.rs
@@ -79,6 +79,14 @@ impl CudaDevice {
                 unsafe { func.launch(cfg, params) }.w()?;
                 CudaStorageSlice::U32(data)
             }
+            DType::I32 => {
+                // SAFETY: Set later by running the fill kernel.
+                let data = unsafe { self.alloc::<i32>(elem_count) }.w()?;
+                let func = self.get_or_load_func("fill_i32", kernels::FILL)?;
+                let params = (&data, v as i32, elem_count);
+                unsafe { func.launch(cfg, params) }.w()?;
+                CudaStorageSlice::I32(data)
+            }
             DType::I64 => {
                 // SAFETY: Set later by running the fill kernel.
                 let data = unsafe { self.alloc::<i64>(elem_count) }.w()?;
@@ -192,6 +200,10 @@ impl BackendDevice for CudaDevice {
                 let data = self.alloc_zeros::<u32>(elem_count).w()?;
                 CudaStorageSlice::U32(data)
             }
+            DType::I32 => {
+                let data = self.alloc_zeros::<i32>(elem_count).w()?;
+                CudaStorageSlice::I32(data)
+            }
             DType::I64 => {
                 let data = self.alloc_zeros::<i64>(elem_count).w()?;
                 CudaStorageSlice::I64(data)
@@ -225,7 +237,7 @@ impl BackendDevice for CudaDevice {
         let slice = match dtype {
             // TODO: Add support for F16 and BF16 though this is likely to require some upstream
             // cudarc changes.
-            DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => {
+            DType::U8 | DType::U32 | DType::I32 | DType::I64 | DType::F16 | DType::BF16 => {
                 Err(CudaError::UnsupportedDtype {
                     dtype,
                     op: "rand_uniform",
@@ -269,7 +281,7 @@ impl BackendDevice for CudaDevice {
             elem_count
         };
         let slice = match dtype {
-            DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => {
+            DType::U8 | DType::U32 | DType::I32 | DType::I64 | DType::F16 | DType::BF16 => {
                 Err(CudaError::UnsupportedDtype {
                     dtype,
                     op: "rand_normal",
@@ -311,6 +323,10 @@ impl BackendDevice for CudaDevice {
                 let data = self.alloc::<u32>(elem_count).w()?;
                 CudaStorageSlice::U32(data)
             }
+            DType::I32 => {
+                let data = self.alloc::<i32>(elem_count).w()?;
+                CudaStorageSlice::I32(data)
+            }
             DType::I64 => {
                 let data = self.alloc::<i64>(elem_count).w()?;
                 CudaStorageSlice::I64(data)
@@ -348,6 +364,10 @@ impl BackendDevice for CudaDevice {
                 let data = self.htod_sync_copy(storage).w()?;
                 CudaStorageSlice::U32(data)
             }
+            CpuStorageRef::I32(storage) => {
+                let data = self.htod_sync_copy(storage).w()?;
+                CudaStorageSlice::I32(data)
+            }
             CpuStorageRef::I64(storage) => {
                 let data = self.htod_sync_copy(storage).w()?;
                 CudaStorageSlice::I64(data)
@@ -385,6 +405,10 @@ impl BackendDevice for CudaDevice {
                 let data = self.htod_sync_copy(storage).w()?;
                 CudaStorageSlice::U32(data)
             }
+            CpuStorage::I32(storage) => {
+                let data = self.htod_sync_copy(storage).w()?;
+                CudaStorageSlice::I32(data)
+            }
             CpuStorage::I64(storage) => {
                 let data = self.htod_sync_copy(storage).w()?;
                 CudaStorageSlice::I64(data)
@@ -422,6 +446,10 @@ impl BackendDevice for CudaDevice {
                 let data = self.htod_copy(storage).w()?;
                 CudaStorageSlice::U32(data)
             }
+            CpuStorage::I32(storage) => {
+                let data = self.htod_copy(storage).w()?;
+                CudaStorageSlice::I32(data)
+            }
             CpuStorage::I64(storage) => {
                 let data = self.htod_copy(storage).w()?;
                 CudaStorageSlice::I64(data)
diff --git a/candle-core/src/cuda_backend/mod.rs b/candle-core/src/cuda_backend/mod.rs
index 827e22e797..aedcbdd7cb 100644
--- a/candle-core/src/cuda_backend/mod.rs
+++ b/candle-core/src/cuda_backend/mod.rs
@@ -47,6 +47,7 @@ impl SlicePtrOrNull<usize> {
 pub enum CudaStorageSlice {
     U8(CudaSlice<u8>),
     U32(CudaSlice<u32>),
+    I32(CudaSlice<i32>),
     I64(CudaSlice<i64>),
     BF16(CudaSlice<bf16>),
     F16(CudaSlice<f16>),
@@ -361,11 +362,14 @@ impl<'a> Map1 for IndexSelect<'a> {
             CudaStorageSlice::U8(slice) => {
                 ("is_u8", *slice.slice(ids_l.start_offset()..).device_ptr())
             }
+            CudaStorageSlice::I32(slice) => {
+                ("is_i32", *slice.slice(ids_l.start_offset()..).device_ptr())
+            }
             CudaStorageSlice::I64(slice) => {
                 ("is_i64", *slice.slice(ids_l.start_offset()..).device_ptr())
             }
             _ => Err(CudaError::UnexpectedDType {
-                msg: "index_select ids should be u8 or u32",
+                msg: "index_select ids should be u8/u32/i32/i64",
                 expected: DType::U32,
                 got: self.0.dtype(),
             })
@@ -425,11 +429,14 @@ impl<'a> Map1 for Gather<'a> {
                 ("gather_u32", *slice.slice(ids_o1..ids_o2).device_ptr())
             }
             CudaStorageSlice::U8(slice) => ("gather_u8", *slice.slice(ids_o1..ids_o2).device_ptr()),
+            CudaStorageSlice::I32(slice) => {
+                ("gather_i32", *slice.slice(ids_o1..ids_o2).device_ptr())
+            }
             CudaStorageSlice::I64(slice) => {
                 ("gather_i64", *slice.slice(ids_o1..ids_o2).device_ptr())
             }
             _ => Err(CudaError::UnexpectedDType {
-                msg: "gather ids should be u8/u32/i64",
+                msg: "gather ids should be u8/u32/i32/i64",
                 expected: DType::U32,
                 got: ids.dtype(),
             })?,
@@ -475,10 +482,11 @@ impl<'a> Map2InPlace for IndexAdd<'a> {
         };
         let (name, ids) = match &ids.slice {
             CudaStorageSlice::U32(slice) => ("ia_u32", *slice.slice(ids_o1..ids_o2).device_ptr()),
+            CudaStorageSlice::I32(slice) => ("ia_i32", *slice.slice(ids_o1..ids_o2).device_ptr()),
             CudaStorageSlice::I64(slice) => ("ia_i64", *slice.slice(ids_o1..ids_o2).device_ptr()),
             CudaStorageSlice::U8(slice) => ("ia_u8", *slice.slice(ids_o1..ids_o2).device_ptr()),
             _ => Err(CudaError::UnexpectedDType {
-                msg: "index-add ids should be u8/u32/i64",
+                msg: "index-add ids should be u8/u32/i32/i64",
                 expected: DType::U32,
                 got: ids.dtype(),
             })?,
@@ -523,10 +531,11 @@ impl<'a> Map2InPlace for ScatterAdd<'a> {
         };
         let (name, ids) = match &ids.slice {
             CudaStorageSlice::U32(slice) => ("sa_u32", *slice.slice(ids_o1..ids_o2).device_ptr()),
+            CudaStorageSlice::I32(slice) => ("sa_i32", *slice.slice(ids_o1..ids_o2).device_ptr()),
             CudaStorageSlice::I64(slice) => ("sa_i64", *slice.slice(ids_o1..ids_o2).device_ptr()),
             CudaStorageSlice::U8(slice) => ("sa_u8", *slice.slice(ids_o1..ids_o2).device_ptr()),
             _ => Err(CudaError::UnexpectedDType {
-                msg: "scatter-add ids should be u8/u32/i64",
+                msg: "scatter-add ids should be u8/u32/i32/i64",
                 expected: DType::U32,
                 got: ids.dtype(),
             })?,
@@ -865,6 +874,10 @@ impl<'a> Map2 for WhereCond<'a> {
                 let ptr = *slice.slice(ids_l.start_offset()..).device_ptr();
                 (ptr, "where_u32")
             }
+            CudaStorageSlice::I32(slice) => {
+                let ptr = *slice.slice(ids_l.start_offset()..).device_ptr();
+                (ptr, "where_i32")
+            }
             CudaStorageSlice::I64(slice) => {
                 let ptr = *slice.slice(ids_l.start_offset()..).device_ptr();
                 (ptr, "where_i64")
@@ -1024,6 +1037,7 @@ macro_rules! cuda_dtype {
 }
 cuda_dtype!(u8, U8);
 cuda_dtype!(u32, U32);
+cuda_dtype!(i32, I32);
 cuda_dtype!(i64, I64);
 cuda_dtype!(f16, F16);
 cuda_dtype!(bf16, BF16);
@@ -1146,6 +1160,7 @@ impl BackendStorage for CudaStorage {
         match self.slice {
             CudaStorageSlice::U8(_) => DType::U8,
             CudaStorageSlice::U32(_) => DType::U32,
+            CudaStorageSlice::I32(_) => DType::I32,
             CudaStorageSlice::I64(_) => DType::I64,
             CudaStorageSlice::BF16(_) => DType::BF16,
             CudaStorageSlice::F16(_) => DType::F16,
@@ -1172,6 +1187,7 @@ impl BackendStorage for CudaStorage {
         let inp = match &self.slice {
             CudaStorageSlice::U8(inp) => *inp.slice(start_o..).device_ptr(),
             CudaStorageSlice::U32(inp) => *inp.slice(start_o..).device_ptr(),
+            CudaStorageSlice::I32(inp) => *inp.slice(start_o..).device_ptr(),
             CudaStorageSlice::I64(inp) => *inp.slice(start_o..).device_ptr(),
             CudaStorageSlice::BF16(inp) => *inp.slice(start_o..).device_ptr(),
             CudaStorageSlice::F16(inp) => *inp.slice(start_o..).device_ptr(),
@@ -1195,6 +1211,12 @@ impl BackendStorage for CudaStorage {
                 unsafe { func.launch(cfg, params) }.w()?;
                 CudaStorageSlice::U32(out)
             }
+            DType::I32 => {
+                let out = unsafe { dev.alloc::<i32>(el) }.w()?;
+                let params = (el, dims.len(), &ds, *inp, &out);
+                unsafe { func.launch(cfg, params) }.w()?;
+                CudaStorageSlice::I32(out)
+            }
             DType::I64 => {
                 let out = unsafe { dev.alloc::<i64>(el) }.w()?;
                 let params = (el, dims.len(), &ds, *inp, &out);
@@ -1291,6 +1313,11 @@ impl BackendStorage for CudaStorage {
                 let cpu_storage = dev.dtoh_sync_copy(slice).w()?;
                 Ok(CpuStorage::U32(cpu_storage))
             }
+            CudaStorageSlice::I32(slice) => {
+                let dev = slice.device();
+                let cpu_storage = dev.dtoh_sync_copy(slice).w()?;
+                Ok(CpuStorage::I32(cpu_storage))
+            }
             CudaStorageSlice::I64(slice) => {
                 let dev = slice.device();
                 let cpu_storage = dev.dtoh_sync_copy(slice).w()?;
@@ -1557,6 +1584,7 @@ impl BackendStorage for CudaStorage {
                 S::F64(out)
             }
             (S::U32(_), S::U32(_)) => Err(CudaError::InternalError("conv2d does not support u32"))?,
+            (S::I32(_), S::I32(_)) => Err(CudaError::InternalError("conv2d does not support i32"))?,
             (S::I64(_), S::I64(_)) => Err(CudaError::InternalError("conv2d does not support i64"))?,
             _ => Err(CudaError::InternalError("dtype mismatch in conv2d"))?,
         };
@@ -1879,6 +1907,11 @@ impl BackendStorage for CudaStorage {
                 *d.slice(dst_o..).device_ptr(),
                 "copy2d_u32",
             ),
+            (S::I32(s), S::I32(d)) => (
+                *s.slice(src_o..).device_ptr(),
+                *d.slice(dst_o..).device_ptr(),
+                "copy2d_i32",
+            ),
             (S::I64(s), S::I64(d)) => (
                 *s.slice(src_o..).device_ptr(),
                 *d.slice(dst_o..).device_ptr(),
@@ -1985,6 +2018,18 @@ impl BackendStorage for CudaStorage {
                     unsafe { func.launch(cfg, params) }.w()?
                 }
             }
+            (CudaStorageSlice::I32(src), CudaStorageSlice::I32(dst)) => {
+                let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
+                if src_l.is_contiguous() {
+                    dev.dtod_copy(&src, &mut dst).w()?
+                } else {
+                    let func = dev.get_or_load_func("ucopy_i32", kernels::UNARY)?;
+                    // SAFETY: Set later by running the kernel.
+                    let params = (el_count, dims.len(), &ds, &src, &mut dst);
+                    // SAFETY: ffi.
+                    unsafe { func.launch(cfg, params) }.w()?
+                }
+            }
             (CudaStorageSlice::I64(src), CudaStorageSlice::I64(dst)) => {
                 let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
                 if src_l.is_contiguous() {
diff --git a/candle-core/src/cuda_backend/utils.rs b/candle-core/src/cuda_backend/utils.rs
index c1210727ad..ae009b26ab 100644
--- a/candle-core/src/cuda_backend/utils.rs
+++ b/candle-core/src/cuda_backend/utils.rs
@@ -19,6 +19,7 @@ pub trait Map1 {
         let out = match s {
             S::U8(s) => S::U8(self.f(s, d, l)?),
             S::U32(s) => S::U32(self.f(s, d, l)?),
+            S::I32(s) => S::I32(self.f(s, d, l)?),
             S::I64(s) => S::I64(self.f(s, d, l)?),
             S::BF16(s) => S::BF16(self.f(s, d, l)?),
             S::F16(s) => S::F16(self.f(s, d, l)?),
@@ -136,6 +137,7 @@ pub trait Map1Any {
         let out = match s {
             S::U8(s) => self.f(s, d, l, S::U8)?,
             S::U32(s) => self.f(s, d, l, S::U32)?,
+            S::I32(s) => self.f(s, d, l, S::I32)?,
             S::I64(s) => self.f(s, d, l, S::I64)?,
             S::BF16(s) => self.f(s, d, l, S::BF16)?,
             S::F16(s) => self.f(s, d, l, S::F16)?,
diff --git a/candle-core/src/display.rs b/candle-core/src/display.rs
index 7e6e3cf8f1..5fb370b696 100644
--- a/candle-core/src/display.rs
+++ b/candle-core/src/display.rs
@@ -55,6 +55,7 @@ impl std::fmt::Debug for Tensor {
         match self.dtype() {
             DType::U8 => self.fmt_dt::<u8>(f),
             DType::U32 => self.fmt_dt::<u32>(f),
+            DType::I32 => self.fmt_dt::<i32>(f),
             DType::I64 => self.fmt_dt::<i64>(f),
             DType::BF16 => self.fmt_dt::<bf16>(f),
             DType::F16 => self.fmt_dt::<f16>(f),
@@ -463,6 +464,12 @@ impl std::fmt::Display for Tensor {
                 tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
                 writeln!(f)?;
             }
+            DType::I32 => {
+                let tf: IntFormatter<i32> = IntFormatter::new();
+                let max_w = tf.max_width(&to_display);
+                tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
+                writeln!(f)?;
+            }
             DType::I64 => {
                 let tf: IntFormatter<i64> = IntFormatter::new();
                 let max_w = tf.max_width(&to_display);
diff --git a/candle-core/src/dtype.rs b/candle-core/src/dtype.rs
index de6cddc3a3..c6a0800b24 100644
--- a/candle-core/src/dtype.rs
+++ b/candle-core/src/dtype.rs
@@ -10,6 +10,8 @@ pub enum DType {
     U8,
     // Unsigned 32 bits integer.
     U32,
+    // Signed 32 bits integer.
+    I32,
     // Signed 64 bits integer.
     I64,
     // Brain floating-point using half precision (16 bits).
@@ -39,6 +41,7 @@ impl std::str::FromStr for DType {
         match s {
             "u8" => Ok(Self::U8),
             "u32" => Ok(Self::U32),
+            "i32" => Ok(Self::I32),
             "i64" => Ok(Self::I64),
             "bf16" => Ok(Self::BF16),
             "f16" => Ok(Self::F16),
@@ -55,6 +58,7 @@ impl DType {
         match self {
             Self::U8 => "u8",
             Self::U32 => "u32",
+            Self::I32 => "i32",
             Self::I64 => "i64",
             Self::BF16 => "bf16",
             Self::F16 => "f16",
@@ -68,6 +72,7 @@ impl DType {
         match self {
             Self::U8 => 1,
             Self::U32 => 4,
+            Self::I32 => 4,
             Self::I64 => 8,
             Self::BF16 => 2,
             Self::F16 => 2,
@@ -78,14 +83,14 @@ impl DType {
 
     pub fn is_int(&self) -> bool {
         match self {
-            Self::U8 | Self::U32 | Self::I64 => true,
+            Self::U8 | Self::U32 | Self::I32 | Self::I64 => true,
             Self::BF16 | Self::F16 | Self::F32 | Self::F64 => false,
         }
     }
 
     pub fn is_float(&self) -> bool {
         match self {
-            Self::U8 | Self::U32 | Self::I64 => false,
+            Self::U8 | Self::U32 | Self::I32 | Self::I64 => false,
             Self::BF16 | Self::F16 | Self::F32 | Self::F64 => true,
         }
     }
@@ -169,6 +174,7 @@ use half::{bf16, f16};
 
 with_dtype!(u8, U8, |v: f64| v as u8, |v: u8| v as f64);
 with_dtype!(u32, U32, |v: f64| v as u32, |v: u32| v as f64);
+with_dtype!(i32, I32, |v: f64| v as i32, |v: i32| v as f64);
 with_dtype!(i64, I64, |v: f64| v as i64, |v: i64| v as f64);
 with_dtype!(f16, F16, f16::from_f64, f16::to_f64);
 with_dtype!(bf16, BF16, bf16::from_f64, bf16::to_f64);
@@ -180,6 +186,15 @@ pub trait IntDType: WithDType {
     fn as_usize(&self) -> usize;
 }
 
+impl IntDType for i32 {
+    fn is_true(&self) -> bool {
+        *self != 0
+    }
+    fn as_usize(&self) -> usize {
+        *self as usize
+    }
+}
+
 impl IntDType for i64 {
     fn is_true(&self) -> bool {
         *self != 0
diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs
index 039bb62d95..cdcec45f75 100644
--- a/candle-core/src/metal_backend/mod.rs
+++ b/candle-core/src/metal_backend/mod.rs
@@ -96,6 +96,7 @@ impl BackendStorage for MetalStorage {
         match self.dtype {
             DType::U8 => Ok(CpuStorage::U8(self.to_cpu()?)),
             DType::U32 => Ok(CpuStorage::U32(self.to_cpu()?)),
+            DType::I32 => Ok(CpuStorage::I32(self.to_cpu()?)),
             DType::I64 => Ok(CpuStorage::I64(self.to_cpu()?)),
             DType::F16 => Ok(CpuStorage::F16(self.to_cpu()?)),
             DType::BF16 => Ok(CpuStorage::BF16(self.to_cpu()?)),
@@ -304,6 +305,11 @@ impl BackendStorage for MetalStorage {
             (ReduceOp::Max, DType::BF16) => ("fast_max_bf16_strided", true, false),
             (ReduceOp::ArgMin, DType::BF16) => ("fast_argmin_bf16_strided", true, true),
             (ReduceOp::ArgMax, DType::BF16) => ("fast_argmax_bf16_strided", true, true),
+            (ReduceOp::Sum, DType::I32) => ("fast_sum_i32_strided", false, false),
+            (ReduceOp::Min, DType::I32) => ("fast_min_i32_strided", true, false),
+            (ReduceOp::Max, DType::I32) => ("fast_max_i32_strided", true, false),
+            (ReduceOp::ArgMin, DType::I32) => ("fast_argmin_i32_strided", true, true),
+            (ReduceOp::ArgMax, DType::I32) => ("fast_argmax_i32_strided", true, true),
             (ReduceOp::Sum, DType::I64) => ("fast_sum_i64_strided", false, false),
             (ReduceOp::Min, DType::I64) => ("fast_min_i64_strided", true, false),
             (ReduceOp::Max, DType::I64) => ("fast_max_i64_strided", true, false),
@@ -363,21 +369,30 @@ impl BackendStorage for MetalStorage {
                 (DType::U32, DType::BF16) => "cast_u32_bf16",
                 (DType::U32, DType::F16) => "cast_u32_f16",
                 (DType::U32, DType::F32) => "cast_u32_f32",
+                (DType::U32, DType::I32) => "cast_u32_i32",
                 (DType::U32, DType::I64) => "cast_u32_i64",
                 (DType::U32, DType::U8) => "cast_u32_u8",
 
                 (DType::U8, DType::BF16) => "cast_u8_bf16",
                 (DType::U8, DType::F16) => "cast_u8_f16",
                 (DType::U8, DType::F32) => "cast_u8_f32",
+                (DType::U8, DType::I32) => "cast_u8_i32",
                 (DType::U8, DType::I64) => "cast_u8_i64",
                 (DType::U8, DType::U32) => "cast_u8_u32",
 
                 (DType::F32, DType::BF16) => "cast_f32_bf16",
                 (DType::F32, DType::F16) => "cast_f32_f16",
+                (DType::F32, DType::I32) => "cast_f32_i32",
                 (DType::F32, DType::I64) => "cast_f32_i64",
                 (DType::F32, DType::U32) => "cast_f32_u32",
                 (DType::F32, DType::U8) => "cast_f32_u8",
 
+                (DType::I32, DType::BF16) => "cast_i32_bf16",
+                (DType::I32, DType::F16) => "cast_i32_f16",
+                (DType::I32, DType::F32) => "cast_i32_f32",
+                (DType::I32, DType::U32) => "cast_i32_u32",
+                (DType::I32, DType::U8) => "cast_i32_u8",
+
                 (DType::I64, DType::BF16) => "cast_i64_bf16",
                 (DType::I64, DType::F16) => "cast_i64_f16",
                 (DType::I64, DType::F32) => "cast_i64_f32",
@@ -386,12 +401,14 @@ impl BackendStorage for MetalStorage {
 
                 (DType::F16, DType::BF16) => "cast_f16_bf16",
                 (DType::F16, DType::F32) => "cast_f16_f32",
+                (DType::F16, DType::I32) => "cast_f16_i32",
                 (DType::F16, DType::I64) => "cast_f16_i64",
                 (DType::F16, DType::U32) => "cast_f16_u32",
                 (DType::F16, DType::U8) => "cast_f16_u8",
 
                 (DType::BF16, DType::F16) => "cast_bf16_f16",
                 (DType::BF16, DType::F32) => "cast_bf16_f32",
+                (DType::BF16, DType::I32) => "cast_bf16_i32",
                 (DType::BF16, DType::I64) => "cast_bf16_i64",
                 (DType::BF16, DType::U32) => "cast_bf16_u32",
                 (DType::BF16, DType::U8) => "cast_bf16_u8",
@@ -414,12 +431,15 @@ impl BackendStorage for MetalStorage {
             let kernel_name = match (self.dtype, dtype) {
                 (DType::U32, DType::F32) => "cast_u32_f32_strided",
                 (DType::U32, DType::U8) => "cast_u32_u8_strided",
+                (DType::U32, DType::I32) => "cast_u32_i32_strided",
                 (DType::U32, DType::I64) => "cast_u32_i64_strided",
                 (DType::U8, DType::U32) => "cast_u8_u32_strided",
                 (DType::U8, DType::F32) => "cast_u8_f32_strided",
+                (DType::U8, DType::I32) => "cast_u8_i32_strided",
                 (DType::U8, DType::I64) => "cast_u8_i64_strided",
                 (DType::F32, DType::F16) => "cast_f32_f16_strided",
                 (DType::F16, DType::F32) => "cast_f16_f32_strided",
+                (DType::I32, DType::F32) => "cast_i32_f32_strided",
                 (DType::I64, DType::F32) => "cast_i64_f32_strided",
                 (DType::F32, DType::BF16) => "cast_f32_bf16_strided",
                 (DType::BF16, DType::F32) => "cast_bf16_f32_strided",
@@ -514,6 +534,7 @@ impl BackendStorage for MetalStorage {
                     ("usign", DType::F16) => contiguous_tiled::sign::HALF,
                     ("usign", DType::F32) => contiguous_tiled::sign::FLOAT,
                     ("usign", DType::BF16) => contiguous_tiled::sign::BFLOAT,
+                    ("usign", DType::I32) => contiguous_tiled::sign::I32,
                     ("usign", DType::I64) => contiguous_tiled::sign::I64,
                     (name, dtype) => {
                         crate::bail!(
@@ -592,6 +613,7 @@ impl BackendStorage for MetalStorage {
                     ("usign", DType::F16) => contiguous::sign::HALF,
                     ("usign", DType::F32) => contiguous::sign::FLOAT,
                     ("usign", DType::BF16) => contiguous::sign::BFLOAT,
+                    ("usign", DType::I32) => contiguous::sign::I32,
                     ("usign", DType::I64) => contiguous::sign::I64,
                     (name, dtype) => {
                         crate::bail!("Metal contiguous unary {name} {dtype:?} not implemented")
@@ -723,6 +745,7 @@ impl BackendStorage for MetalStorage {
             (DType::U32, DType::F32) => "where_u32_f32",
             (DType::U8, DType::BF16) => "where_u8_bf16",
             (DType::U8, DType::F16) => "where_u8_f16",
+            (DType::U8, DType::I32) => "where_u8_i32",
             (DType::U8, DType::I64) => "where_u8_i64",
             (DType::U8, DType::U32) => "where_u8_u32",
             (DType::U8, DType::U8) => "where_u8_u8",
@@ -1259,6 +1282,9 @@ impl BackendStorage for MetalStorage {
             (DType::U32, DType::F32) => "sa_u32_f32",
             (DType::U32, DType::F16) => "sa_u32_f16",
             (DType::U32, DType::BF16) => "sa_u32_bf16",
+            (DType::I32, DType::F32) => "sa_i32_f32",
+            (DType::I32, DType::F16) => "sa_i32_f16",
+            (DType::I32, DType::BF16) => "sa_i32_bf16",
             (DType::I64, DType::F32) => "sa_i64_f32",
             (DType::I64, DType::F16) => "sa_i64_f16",
             (DType::I64, DType::BF16) => "sa_i64_bf16",
@@ -1307,6 +1333,10 @@ impl BackendStorage for MetalStorage {
             (DType::U32, DType::F16) => "is_u32_f16",
             (DType::U32, DType::BF16) => "is_u32_bf16",
 
+            (DType::I32, DType::F32) => "is_i32_f32",
+            (DType::I32, DType::F16) => "is_i32_f16",
+            (DType::I32, DType::BF16) => "is_i32_bf16",
+
             (DType::I64, DType::F32) => "is_i64_f32",
             (DType::I64, DType::F16) => "is_i64_f16",
             (DType::I64, DType::BF16) => "is_i64_bf16",
@@ -1352,9 +1382,18 @@ impl BackendStorage for MetalStorage {
             return Err(crate::Error::RequiresContiguous { op: "index-add" }.bt());
         };
         let name = match (ids.dtype, self.dtype) {
+            (DType::I32, DType::BF16) => "ia_i32_bf16",
+            (DType::I32, DType::F16) => "ia_i32_f16",
+            (DType::I32, DType::F32) => "ia_i32_f32",
+            (DType::I32, DType::I32) => "ia_i32_i32",
+            (DType::I32, DType::I64) => "ia_i32_i64",
+            (DType::I32, DType::U32) => "ia_i32_u32",
+            (DType::I32, DType::U8) => "ia_i32_u8",
+
             (DType::I64, DType::BF16) => "ia_i64_bf16",
             (DType::I64, DType::F16) => "ia_i64_f16",
             (DType::I64, DType::F32) => "ia_i64_f32",
+            (DType::I64, DType::I32) => "ia_i64_i32",
             (DType::I64, DType::I64) => "ia_i64_i64",
             (DType::I64, DType::U32) => "ia_i64_u32",
             (DType::I64, DType::U8) => "ia_i64_u8",
@@ -1362,6 +1401,7 @@ impl BackendStorage for MetalStorage {
             (DType::U32, DType::BF16) => "ia_u32_bf16",
             (DType::U32, DType::F16) => "ia_u32_f16",
             (DType::U32, DType::F32) => "ia_u32_f32",
+            (DType::U32, DType::I32) => "ia_u32_i32",
             (DType::U32, DType::I64) => "ia_u32_i64",
             (DType::U32, DType::U32) => "ia_u32_u32",
             (DType::U32, DType::U8) => "ia_u32_u8",
@@ -1369,6 +1409,7 @@ impl BackendStorage for MetalStorage {
             (DType::U8, DType::BF16) => "ia_u8_bf16",
             (DType::U8, DType::F16) => "ia_u8_f16",
             (DType::U8, DType::F32) => "ia_u8_f32",
+            (DType::U8, DType::I32) => "ia_u8_i32",
             (DType::U8, DType::I64) => "ia_u8_i64",
             (DType::U8, DType::U32) => "ia_u8_u32",
             (DType::U8, DType::U8) => "ia_u8_u8",
@@ -1579,6 +1620,7 @@ impl BackendStorage for MetalStorage {
                 DType::F32 => candle_metal_kernels::copy2d::FLOAT,
                 DType::F16 => candle_metal_kernels::copy2d::HALF,
                 DType::BF16 => candle_metal_kernels::copy2d::BFLOAT,
+                DType::I32 => candle_metal_kernels::copy2d::I32,
                 DType::I64 => candle_metal_kernels::copy2d::I64,
                 DType::U32 => candle_metal_kernels::copy2d::U32,
                 DType::U8 => candle_metal_kernels::copy2d::U8,
@@ -1625,6 +1667,7 @@ impl BackendStorage for MetalStorage {
                 DType::F32 => candle_metal_kernels::unary::strided::copy::FLOAT,
                 DType::F16 => candle_metal_kernels::unary::strided::copy::HALF,
                 DType::BF16 => candle_metal_kernels::unary::strided::copy::BFLOAT,
+                DType::I32 => candle_metal_kernels::unary::strided::copy::I32,
                 DType::I64 => candle_metal_kernels::unary::strided::copy::I64,
                 DType::U32 => candle_metal_kernels::unary::strided::copy::U32,
                 DType::U8 => candle_metal_kernels::unary::strided::copy::U8,
@@ -1716,6 +1759,17 @@ impl MetalStorage {
                 ("ge", DType::BF16) => (contiguous::ge::BFLOAT, DType::U8),
                 ("gt", DType::BF16) => (contiguous::gt::BFLOAT, DType::U8),
 
+                ("add", DType::I32) => (contiguous::add::I32, self.dtype),
+                ("sub", DType::I32) => (contiguous::sub::I32, self.dtype),
+                ("mul", DType::I32) => (contiguous::mul::I32, self.dtype),
+                ("div", DType::I32) => (contiguous::div::I32, self.dtype),
+                ("eq", DType::I32) => (contiguous::eq::I32, DType::U8),
+                ("ne", DType::I32) => (contiguous::ne::I32, DType::U8),
+                ("le", DType::I32) => (contiguous::le::I32, DType::U8),
+                ("lt", DType::I32) => (contiguous::lt::I32, DType::U8),
+                ("ge", DType::I32) => (contiguous::ge::I32, DType::U8),
+                ("gt", DType::I32) => (contiguous::gt::I32, DType::U8),
+
                 ("add", DType::I64) => (contiguous::add::I64, self.dtype),
                 ("sub", DType::I64) => (contiguous::sub::I64, self.dtype),
                 ("mul", DType::I64) => (contiguous::mul::I64, self.dtype),
@@ -1809,6 +1863,19 @@ impl MetalStorage {
                 ("ge", DType::BF16) => (strided::ge::BFLOAT, DType::U8),
                 ("gt", DType::BF16) => (strided::gt::BFLOAT, DType::U8),
 
+                ("badd", DType::I32) => (strided::add::I32, self.dtype),
+                ("bsub", DType::I32) => (strided::sub::I32, self.dtype),
+                ("bmul", DType::I32) => (strided::mul::I32, self.dtype),
+                ("bdiv", DType::I32) => (strided::div::I32, self.dtype),
+                ("bminimum", DType::I32) => (strided::min::I32, self.dtype),
+                ("bmaximum", DType::I32) => (strided::max::I32, self.dtype),
+                ("eq", DType::I32) => (strided::eq::I32, DType::U8),
+                ("ne", DType::I32) => (strided::ne::I32, DType::U8),
+                ("le", DType::I32) => (strided::le::I32, DType::U8),
+                ("lt", DType::I32) => (strided::lt::I32, DType::U8),
+                ("ge", DType::I32) => (strided::ge::I32, DType::U8),
+                ("gt", DType::I32) => (strided::gt::I32, DType::U8),
+
                 ("badd", DType::I64) => (strided::add::I64, self.dtype),
                 ("bsub", DType::I64) => (strided::sub::I64, self.dtype),
                 ("bmul", DType::I64) => (strided::mul::I64, self.dtype),
@@ -1964,6 +2031,7 @@ impl BackendDevice for MetalDevice {
         let (count, buffer) = match T::cpu_storage_ref(s) {
             CpuStorageRef::U8(storage) => (storage.len(), self.new_buffer_with_data(storage)),
             CpuStorageRef::U32(storage) => (storage.len(), self.new_buffer_with_data(storage)),
+            CpuStorageRef::I32(storage) => (storage.len(), self.new_buffer_with_data(storage)),
             CpuStorageRef::I64(storage) => (storage.len(), self.new_buffer_with_data(storage)),
             CpuStorageRef::BF16(storage) => (storage.len(), self.new_buffer_with_data(storage)),
             CpuStorageRef::F16(storage) => (storage.len(), self.new_buffer_with_data(storage)),
@@ -1977,6 +2045,7 @@ impl BackendDevice for MetalDevice {
         let (count, buffer) = match storage {
             CpuStorage::U8(storage) => (storage.len(), self.new_buffer_with_data(storage)),
             CpuStorage::U32(storage) => (storage.len(), self.new_buffer_with_data(storage)),
+            CpuStorage::I32(storage) => (storage.len(), self.new_buffer_with_data(storage)),
             CpuStorage::I64(storage) => (storage.len(), self.new_buffer_with_data(storage)),
             CpuStorage::BF16(storage) => (storage.len(), self.new_buffer_with_data(storage)),
             CpuStorage::F16(storage) => (storage.len(), self.new_buffer_with_data(storage)),
diff --git a/candle-core/src/npy.rs b/candle-core/src/npy.rs
index 83e4f6527f..b321a619f8 100644
--- a/candle-core/src/npy.rs
+++ b/candle-core/src/npy.rs
@@ -85,6 +85,7 @@ impl Header {
             DType::F16 => "f2",
             DType::F32 => "f4",
             DType::F64 => "f8",
+            DType::I32 => "i4",
             DType::I64 => "i8",
             DType::U32 => "u4",
             DType::U8 => "u1",
@@ -234,6 +235,11 @@ impl Tensor {
                 reader.read_u32_into::<LittleEndian>(&mut data_t)?;
                 Tensor::from_vec(data_t, shape, &Device::Cpu)
             }
+            DType::I32 => {
+                let mut data_t = vec![0i32; elem_count];
+                reader.read_i32_into::<LittleEndian>(&mut data_t)?;
+                Tensor::from_vec(data_t, shape, &Device::Cpu)
+            }
             DType::I64 => {
                 let mut data_t = vec![0i64; elem_count];
                 reader.read_i64_into::<LittleEndian>(&mut data_t)?;
diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs
index 49ba44be89..75931ee2fe 100644
--- a/candle-core/src/op.rs
+++ b/candle-core/src/op.rs
@@ -189,6 +189,7 @@ pub trait UnaryOpT {
     fn f64(v1: f64) -> f64;
     fn u8(v1: u8) -> u8;
     fn u32(v1: u32) -> u32;
+    fn i32(v1: i32) -> i32;
     fn i64(v1: i64) -> i64;
 
     // There is no very good way to represent optional function in traits so we go for an explicit
@@ -213,6 +214,7 @@ pub trait BinaryOpT {
     fn f64(v1: f64, v2: f64) -> f64;
     fn u8(v1: u8, v2: u8) -> u8;
     fn u32(v1: u32, v2: u32) -> u32;
+    fn i32(v1: i32, v2: i32) -> i32;
     fn i64(v1: i64, v2: i64) -> i64;
 
     const BF16_VEC: bool = false;
@@ -229,6 +231,8 @@ pub trait BinaryOpT {
     fn u32_vec(_xs1: &[u32], _xs2: &[u32], _ys: &mut [u32]) {}
     const I64_VEC: bool = false;
     fn i64_vec(_xs1: &[i64], _xs2: &[i64], _ys: &mut [i64]) {}
+    const I32_VEC: bool = false;
+    fn i32_vec(_xs1: &[i32], _xs2: &[i32], _ys: &mut [i32]) {}
 }
 
 pub(crate) struct Add;
@@ -288,6 +292,10 @@ macro_rules! bin_op {
                 $e(v1, v2)
             }
             #[inline(always)]
+            fn i32(v1: i32, v2: i32) -> i32 {
+                $e(v1, v2)
+            }
+            #[inline(always)]
             fn i64(v1: i64, v2: i64) -> i64 {
                 $e(v1, v2)
             }
@@ -379,6 +387,10 @@ macro_rules! unary_op {
             fn i64(_: i64) -> i64 {
                 todo!("no unary function for i64")
             }
+            #[inline(always)]
+            fn i32(_: i32) -> i32 {
+                todo!("no unary function for i32")
+            }
         }
     };
 
@@ -415,6 +427,10 @@ macro_rules! unary_op {
             fn i64(_: i64) -> i64 {
                 todo!("no unary function for i64")
             }
+            #[inline(always)]
+            fn i32(_: i32) -> i32 {
+                todo!("no unary function for i32")
+            }
 
             #[cfg(feature = "mkl")]
             const F32_VEC: bool = true;
@@ -514,6 +530,10 @@ impl UnaryOpT for Gelu {
     fn i64(_: i64) -> i64 {
         0
     }
+    #[inline(always)]
+    fn i32(_: i32) -> i32 {
+        0
+    }
     const KERNEL: &'static str = "ugelu";
 
     #[cfg(feature = "mkl")]
@@ -587,6 +607,10 @@ impl UnaryOpT for Erf {
     fn i64(_: i64) -> i64 {
         0
     }
+    #[inline(always)]
+    fn i32(_: i32) -> i32 {
+        0
+    }
 }
 
 /// Silu operation
@@ -621,6 +645,10 @@ impl UnaryOpT for Silu {
     fn i64(_: i64) -> i64 {
         0
     }
+    #[inline(always)]
+    fn i32(_: i32) -> i32 {
+        0
+    }
     const KERNEL: &'static str = "usilu";
 
     #[cfg(feature = "mkl")]
@@ -692,6 +720,10 @@ impl UnaryOpT for Abs {
     fn i64(v: i64) -> i64 {
         v.abs()
     }
+    #[inline(always)]
+    fn i32(v: i32) -> i32 {
+        v.abs()
+    }
 }
 
 impl UnaryOpT for Ceil {
@@ -726,6 +758,10 @@ impl UnaryOpT for Ceil {
     fn i64(v: i64) -> i64 {
         v
     }
+    #[inline(always)]
+    fn i32(v: i32) -> i32 {
+        v
+    }
 }
 
 impl UnaryOpT for Floor {
@@ -760,6 +796,10 @@ impl UnaryOpT for Floor {
     fn i64(v: i64) -> i64 {
         v
     }
+    #[inline(always)]
+    fn i32(v: i32) -> i32 {
+        v
+    }
 }
 
 impl UnaryOpT for Round {
@@ -794,6 +834,10 @@ impl UnaryOpT for Round {
     fn i64(v: i64) -> i64 {
         v
     }
+    #[inline(always)]
+    fn i32(v: i32) -> i32 {
+        v
+    }
 }
 
 impl UnaryOpT for GeluErf {
@@ -828,6 +872,10 @@ impl UnaryOpT for GeluErf {
     fn i64(_: i64) -> i64 {
         0
     }
+    #[inline(always)]
+    fn i32(_: i32) -> i32 {
+        0
+    }
 }
 
 impl UnaryOpT for Relu {
@@ -862,6 +910,10 @@ impl UnaryOpT for Relu {
     fn i64(v: i64) -> i64 {
         v
     }
+    #[inline(always)]
+    fn i32(v: i32) -> i32 {
+        v
+    }
 }
 
 /// `BackpropOp` is a wrapper around `Option<Op>`. The main goal is to ensure that dependencies are
@@ -960,4 +1012,8 @@ impl UnaryOpT for Sign {
     fn i64(v: i64) -> i64 {
         (v > 0) as i64 - (v < 0) as i64
     }
+    #[inline(always)]
+    fn i32(v: i32) -> i32 {
+        (v > 0) as i32 - (v < 0) as i32
+    }
 }
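
With the trait plumbing above in place, elementwise binary and unary ops dispatch on I32 storage just like I64. A minimal usage sketch against the public candle_core API (the function name is illustrative):

    use candle_core::{Device, Result, Tensor};

    fn i32_elementwise_demo() -> Result<()> {
        let dev = Device::Cpu;
        let a = Tensor::new(&[1i32, -2, 3], &dev)?;
        let b = Tensor::new(&[4i32, 5, -6], &dev)?;
        let sum = (&a + &b)?; // BinaryOpT::i32 (Add)
        let abs = a.abs()?;   // UnaryOpT::i32 (Abs)
        assert_eq!(sum.to_vec1::<i32>()?, vec![5, 3, -3]);
        assert_eq!(abs.to_vec1::<i32>()?, vec![1, 2, 3]);
        Ok(())
    }
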
diff --git a/candle-core/src/safetensors.rs b/candle-core/src/safetensors.rs
index 5ea1f192b3..162928ec7d 100644
--- a/candle-core/src/safetensors.rs
+++ b/candle-core/src/safetensors.rs
@@ -11,6 +11,7 @@ impl From<DType> for st::Dtype {
             DType::U8 => st::Dtype::U8,
             DType::U32 => st::Dtype::U32,
             DType::I64 => st::Dtype::I64,
+            DType::I32 => st::Dtype::I32,
             DType::BF16 => st::Dtype::BF16,
             DType::F16 => st::Dtype::F16,
             DType::F32 => st::Dtype::F32,
@@ -187,6 +188,7 @@ impl Tensor {
         match dtype {
             DType::U8 => convert_slice::<u8>(data, shape, device),
             DType::U32 => convert_slice::<u32>(data, shape, device),
+            DType::I32 => convert_slice::<i32>(data, shape, device),
             DType::I64 => convert_slice::<i64>(data, shape, device),
             DType::BF16 => convert_slice::<half::bf16>(data, shape, device),
             DType::F16 => convert_slice::<half::f16>(data, shape, device),
@@ -204,10 +206,7 @@ fn convert(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
             convert_with_cast_::<u16, u32, _>(view, device, conv)
         }
         st::Dtype::U32 => convert_::<u32>(view, device),
-        st::Dtype::I32 => {
-            let conv = |x| Ok(i64::from(x));
-            convert_with_cast_::<i32, i64, _>(view, device, conv)
-        }
+        st::Dtype::I32 => convert_::<i32>(view, device),
         st::Dtype::I64 => convert_::<i64>(view, device),
         st::Dtype::BF16 => convert_::<half::bf16>(view, device),
         st::Dtype::F16 => convert_::<half::f16>(view, device),
@@ -223,6 +222,7 @@ fn convert_back(tensor: &Tensor) -> Result<Vec<u8>> {
     match tensor.dtype() {
         DType::U8 => Ok(convert_back_::<u8>(tensor.to_vec1()?)),
         DType::U32 => Ok(convert_back_::<u32>(tensor.to_vec1()?)),
+        DType::I32 => Ok(convert_back_::<i32>(tensor.to_vec1()?)),
         DType::I64 => Ok(convert_back_::<i64>(tensor.to_vec1()?)),
         DType::F16 => Ok(convert_back_::<half::f16>(tensor.to_vec1()?)),
         DType::BF16 => Ok(convert_back_::<half::bf16>(tensor.to_vec1()?)),
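
Since st::Dtype::I32 now maps to a native DType::I32 instead of being widened to i64 on load, I32 tensors round-trip through safetensors unchanged. A sketch using the candle_core::safetensors helpers (the file name is illustrative):

    use std::collections::HashMap;
    use candle_core::{DType, Device, Result, Tensor};

    fn roundtrip_i32() -> Result<()> {
        let dev = Device::Cpu;
        let mut tensors = HashMap::new();
        tensors.insert("t".to_string(), Tensor::new(&[1i32, 2, 3], &dev)?);
        candle_core::safetensors::save(&tensors, "i32.safetensors")?;
        let loaded = candle_core::safetensors::load("i32.safetensors", &dev)?;
        // Previously this came back as I64; it now stays I32.
        assert_eq!(loaded["t"].dtype(), DType::I32);
        Ok(())
    }
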
diff --git a/candle-core/src/sort.rs b/candle-core/src/sort.rs
index 614a37fe65..92ad1d5adc 100644
--- a/candle-core/src/sort.rs
+++ b/candle-core/src/sort.rs
@@ -65,6 +65,7 @@ impl crate::CustomOp1 for ArgSort {
         let sort_indexes = match storage {
             crate::CpuStorage::U8(vs) => self.asort(vs, layout),
             crate::CpuStorage::U32(vs) => self.asort(vs, layout),
+            crate::CpuStorage::I32(vs) => self.asort(vs, layout),
             crate::CpuStorage::I64(vs) => self.asort(vs, layout),
             crate::CpuStorage::BF16(vs) => self.asort(vs, layout),
             crate::CpuStorage::F16(vs) => self.asort(vs, layout),
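
A short sketch of the now-supported CPU argsort on I32 values (the returned indices are U32):

    use candle_core::{Device, Result, Tensor};

    fn argsort_i32() -> Result<()> {
        let t = Tensor::new(&[3i32, 1, 2], &Device::Cpu)?;
        let idx = t.arg_sort_last_dim(true)?; // ascending
        assert_eq!(idx.to_vec1::<u32>()?, vec![1, 2, 0]);
        Ok(())
    }
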
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs
index 6dd21cc1ab..bff8f36042 100644
--- a/candle-core/tests/tensor_tests.rs
+++ b/candle-core/tests/tensor_tests.rs
@@ -17,6 +17,10 @@ fn ones(device: &Device) -> Result<()> {
         Tensor::ones((2, 3), DType::U32, device)?.to_vec2::<u32>()?,
         [[1, 1, 1], [1, 1, 1]],
     );
+    assert_eq!(
+        Tensor::ones((2, 3), DType::I32, device)?.to_vec2::<i32>()?,
+        [[1, 1, 1], [1, 1, 1]],
+    );
     assert_eq!(
         Tensor::ones((2, 3), DType::I64, device)?.to_vec2::<i64>()?,
         [[1, 1, 1], [1, 1, 1]],
@@ -805,7 +809,7 @@ fn index_select(device: &Device) -> Result<()> {
             [9.0, 10.0, 11.0]
         ]
     );
-    for dtype in [DType::U8, DType::U32, DType::I64] {
+    for dtype in [DType::U8, DType::U32, DType::I32, DType::I64] {
         let ids = ids.to_dtype(dtype)?;
         let hs = t.index_select(&ids, 1)?;
         assert_eq!(
diff --git a/candle-kernels/src/affine.cu b/candle-kernels/src/affine.cu
index 540d0819f5..c3ff5b8753 100644
--- a/candle-kernels/src/affine.cu
+++ b/candle-kernels/src/affine.cu
@@ -40,4 +40,5 @@ AFFINE_OP(float, affine_f32)
 AFFINE_OP(double, affine_f64)
 AFFINE_OP(uint8_t, affine_u8)
 AFFINE_OP(uint32_t, affine_u32)
+AFFINE_OP(int32_t, affine_i32)
 AFFINE_OP(int64_t, affine_i64)
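
The high-level entry point is Tensor::affine, which computes x * mul + add; with affine_i32 registered it can run on I32 CUDA tensors as well. A sketch, assuming a CUDA device is available:

    use candle_core::{Device, Result, Tensor};

    fn affine_i32() -> Result<()> {
        let dev = Device::new_cuda(0)?;
        let t = Tensor::new(&[1i32, 2, 3], &dev)?;
        let y = t.affine(2.0, 1.0)?; // dispatches to affine_i32
        assert_eq!(y.to_vec1::<i32>()?, vec![3, 5, 7]);
        Ok(())
    }
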
diff --git a/candle-kernels/src/binary.cu b/candle-kernels/src/binary.cu
index d44e3b20ee..f534fc76ad 100644
--- a/candle-kernels/src/binary.cu
+++ b/candle-kernels/src/binary.cu
@@ -35,65 +35,77 @@ BINARY_OP(float, badd_f32, x + y)
 BINARY_OP(double, badd_f64, x + y);
 BINARY_OP(uint8_t, badd_u8, x + y);
 BINARY_OP(uint32_t, badd_u32, x + y);
+BINARY_OP(int32_t, badd_i32, x + y);
 BINARY_OP(int64_t, badd_i64, x + y);
 BINARY_OP(float, bdiv_f32, x / y)
 BINARY_OP(double, bdiv_f64, x / y);
 BINARY_OP(uint8_t, bdiv_u8, x / y);
 BINARY_OP(uint32_t, bdiv_u32, x / y);
+BINARY_OP(int32_t, bdiv_i32, x / y);
 BINARY_OP(int64_t, bdiv_i64, x / y);
 BINARY_OP(float, bmul_f32, x * y)
 BINARY_OP(double, bmul_f64, x * y);
 BINARY_OP(uint8_t, bmul_u8, x * y);
 BINARY_OP(uint32_t, bmul_u32, x * y);
+BINARY_OP(int32_t, bmul_i32, x * y);
 BINARY_OP(int64_t, bmul_i64, x * y);
 BINARY_OP(float, bsub_f32, x - y)
 BINARY_OP(double, bsub_f64, x - y);
 BINARY_OP(uint8_t, bsub_u8, x - y);
 BINARY_OP(uint32_t, bsub_u32, x - y);
+BINARY_OP(int32_t, bsub_i32, x - y);
 BINARY_OP(int64_t, bsub_i64, x - y);
 BINARY_OP(float, bminimum_f32, ming(x, y));
 BINARY_OP(double, bminimum_f64, ming(x, y));
 BINARY_OP(uint8_t, bminimum_u8, ming(x, y));
 BINARY_OP(uint32_t, bminimum_u32, ming(x, y));
+BINARY_OP(int32_t, bminimum_i32, ming(x, y));
 BINARY_OP(int64_t, bminimum_i64, ming(x, y));
 BINARY_OP(float, bmaximum_f32, maxg(x, y));
 BINARY_OP(double, bmaximum_f64, maxg(x, y));
 BINARY_OP(uint8_t, bmaximum_u8, maxg(x, y));
 BINARY_OP(uint32_t, bmaximum_u32, maxg(x, y));
+BINARY_OP(int32_t, bmaximum_i32, maxg(x, y));
 BINARY_OP(int64_t, bmaximum_i64, maxg(x, y));
 
 BINARY_OP_OUT(float, uint8_t, eq_f32, x == y)
 BINARY_OP_OUT(double, uint8_t, eq_f64, x == y)
 BINARY_OP_OUT(uint8_t, uint8_t, eq_u8, x == y)
 BINARY_OP_OUT(uint32_t, uint8_t, eq_u32, x == y)
+BINARY_OP_OUT(int32_t, uint8_t, eq_i32, x == y)
 BINARY_OP_OUT(int64_t, uint8_t, eq_i64, x == y)
 
 BINARY_OP_OUT(float, uint8_t, ne_f32, x != y)
 BINARY_OP_OUT(double, uint8_t, ne_f64, x != y)
 BINARY_OP_OUT(uint8_t, uint8_t, ne_u8, x != y)
 BINARY_OP_OUT(uint32_t, uint8_t, ne_u32, x != y)
+BINARY_OP_OUT(int32_t, uint8_t, ne_i32, x != y)
 BINARY_OP_OUT(int64_t, uint8_t, ne_i64, x != y)
 
 BINARY_OP_OUT(float, uint8_t, lt_f32, x < y)
 BINARY_OP_OUT(double, uint8_t, lt_f64, x < y)
 BINARY_OP_OUT(uint8_t, uint8_t, lt_u8, x < y)
 BINARY_OP_OUT(uint32_t, uint8_t, lt_u32, x < y)
+BINARY_OP_OUT(int32_t, uint8_t, lt_i32, x < y)
 BINARY_OP_OUT(int64_t, uint8_t, lt_i64, x < y)
 
 BINARY_OP_OUT(float, uint8_t, le_f32, x <= y)
 BINARY_OP_OUT(double, uint8_t, le_f64, x <= y)
 BINARY_OP_OUT(uint8_t, uint8_t, le_u8, x <= y)
 BINARY_OP_OUT(uint32_t, uint8_t, le_u32, x <= y)
+BINARY_OP_OUT(int32_t, uint8_t, le_i32, x <= y)
 BINARY_OP_OUT(int64_t, uint8_t, le_i64, x <= y)
 
 BINARY_OP_OUT(float, uint8_t, gt_f32, x > y)
 BINARY_OP_OUT(double, uint8_t, gt_f64, x > y)
 BINARY_OP_OUT(uint8_t, uint8_t, gt_u8, x > y)
 BINARY_OP_OUT(uint32_t, uint8_t, gt_u32, x > y)
+BINARY_OP_OUT(int32_t, uint8_t, gt_i32, x > y)
 BINARY_OP_OUT(int64_t, uint8_t, gt_i64, x > y)
 
 BINARY_OP_OUT(float, uint8_t, ge_f32, x >= y)
 BINARY_OP_OUT(double, uint8_t, ge_f64, x >= y)
 BINARY_OP_OUT(uint8_t, uint8_t, ge_u8, x >= y)
 BINARY_OP_OUT(uint32_t, uint8_t, ge_u32, x >= y)
+BINARY_OP_OUT(int32_t, uint8_t, ge_i32, x >= y)
 BINARY_OP_OUT(int64_t, uint8_t, ge_i64, x >= y)
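
On the Rust side the BINARY_OP_OUT kernels surface through the comparison methods, which produce a U8 mask; e.g. for I32 inputs (a sketch):

    use candle_core::{Device, Result, Tensor};

    fn cmp_i32() -> Result<()> {
        let dev = Device::cuda_if_available(0)?;
        let a = Tensor::new(&[1i32, 5, 3], &dev)?;
        let b = Tensor::new(&[2i32, 5, 1], &dev)?;
        assert_eq!(a.lt(&b)?.to_vec1::<u8>()?, vec![1, 0, 0]); // lt_i32
        assert_eq!(a.eq(&b)?.to_vec1::<u8>()?, vec![0, 1, 0]); // eq_i32
        Ok(())
    }
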
diff --git a/candle-kernels/src/cast.cu b/candle-kernels/src/cast.cu
index 90f5e7ba48..f92ac0cbf9 100644
--- a/candle-kernels/src/cast.cu
+++ b/candle-kernels/src/cast.cu
@@ -83,6 +83,8 @@ CAST_OP(double,   __nv_bfloat16, cast_f64_bf16)
 CAST_THROUGH_OP(__nv_bfloat16, uint8_t, float, cast_bf16_u8)
 CAST_THROUGH_OP(__nv_bfloat16, __half,   float, cast_bf16_f16)
 CAST_THROUGH_OP(__half,   __nv_bfloat16, float, cast_f16_bf16)
+CAST_THROUGH_OP(int32_t,   __nv_bfloat16, float, cast_i32_bf16)
+CAST_THROUGH_OP(__nv_bfloat16, int32_t, float, cast_bf16_i32)
 #else
 #include <cuda.h>
 #if CUDA_VERSION >= 11000
@@ -94,6 +96,8 @@ CAST_THROUGH_OP(__nv_bfloat16, double,  float, cast_bf16_f64)
 CAST_THROUGH_OP(__half,   __nv_bfloat16, float, cast_f16_bf16)
 CAST_THROUGH_OP(double,   __nv_bfloat16, float, cast_f64_bf16)
 CAST_THROUGH_OP(uint8_t,   __nv_bfloat16, float, cast_u8_bf16)
+CAST_THROUGH_OP(int32_t,   __nv_bfloat16, float, cast_i32_bf16)
+CAST_THROUGH_OP(__nv_bfloat16, int32_t, float, cast_bf16_i32)
 #endif
 #endif
 
@@ -108,34 +112,48 @@ CAST_OP(uint8_t,  __half, cast_u8_f16 )
 CAST_OP(uint32_t, __half, cast_u32_f16)
 CAST_OP(float,    __half, cast_f32_f16)
 CAST_OP(double,   __half, cast_f64_f16)
+CAST_OP(int32_t,  __half, cast_i32_f16 )
+CAST_THROUGH_OP(__half, int32_t,  float, cast_f16_i32)
 #endif
 
 CAST_OP(uint32_t, uint32_t, cast_u32_u32)
 CAST_OP(uint32_t, uint8_t,  cast_u32_u8 )
 CAST_OP(uint32_t, int64_t,  cast_u32_i64 )
+CAST_OP(uint32_t, int32_t,  cast_u32_i32 )
 CAST_OP(uint32_t, float,    cast_u32_f32)
 CAST_OP(uint32_t, double,   cast_u32_f64)
 
 CAST_OP(uint8_t, uint32_t, cast_u8_u32)
 CAST_OP(uint8_t, uint8_t,  cast_u8_u8 )
+CAST_OP(uint8_t, int32_t,  cast_u8_i32 )
 CAST_OP(uint8_t, int64_t,  cast_u8_i64 )
 CAST_OP(uint8_t, float,    cast_u8_f32)
 CAST_OP(uint8_t, double,   cast_u8_f64)
 
 CAST_OP(int64_t, uint32_t, cast_i64_u32)
 CAST_OP(int64_t, uint8_t,  cast_i64_u8 )
+CAST_OP(int64_t, int32_t,  cast_i64_i32 )
 CAST_OP(int64_t, int64_t,  cast_i64_i64 )
 CAST_OP(int64_t, float,    cast_i64_f32)
 CAST_OP(int64_t, double,   cast_i64_f64)
 
+CAST_OP(int32_t, uint32_t, cast_i32_u32)
+CAST_OP(int32_t, uint8_t,  cast_i32_u8 )
+CAST_OP(int32_t, int64_t,  cast_i32_i64 )
+CAST_OP(int32_t, int32_t,  cast_i32_i32 )
+CAST_OP(int32_t, float,    cast_i32_f32)
+CAST_OP(int32_t, double,   cast_i32_f64)
+
 CAST_OP(float, uint8_t,  cast_f32_u8 )
 CAST_OP(float, uint32_t, cast_f32_u32)
+CAST_OP(float, int32_t,  cast_f32_i32 )
 CAST_OP(float, int64_t,  cast_f32_i64 )
 CAST_OP(float, float,    cast_f32_f32)
 CAST_OP(float, double,   cast_f32_f64)
 
 CAST_OP(double, uint8_t,  cast_f64_u8 )
 CAST_OP(double, uint32_t, cast_f64_u32)
+CAST_OP(double, int32_t,  cast_f64_i32 )
 CAST_OP(double, int64_t,  cast_f64_i64 )
 CAST_OP(double, float,    cast_f64_f32)
 CAST_OP(double, double,   cast_f64_f64)
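
These kernels back Tensor::to_dtype on CUDA. Note that i32<->bf16 goes through an f32 intermediate (CAST_THROUGH_OP), so integer magnitudes above 2^24 may round. A sketch:

    use candle_core::{DType, Device, Result, Tensor};

    fn cast_i32() -> Result<()> {
        let dev = Device::cuda_if_available(0)?;
        let t = Tensor::new(&[1i32, -2, 3], &dev)?;
        let f = t.to_dtype(DType::F32)?;    // cast_i32_f32
        let back = f.to_dtype(DType::I32)?; // cast_f32_i32
        assert_eq!(back.to_vec1::<i32>()?, vec![1, -2, 3]);
        Ok(())
    }
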
diff --git a/candle-kernels/src/cuda_utils.cuh b/candle-kernels/src/cuda_utils.cuh
index f7a2506d0e..08aa2b089a 100644
--- a/candle-kernels/src/cuda_utils.cuh
+++ b/candle-kernels/src/cuda_utils.cuh
@@ -181,6 +181,8 @@ __device__ __forceinline__ double absg(double a) { return fabs(a); }
 __device__ __forceinline__ float copysigng(float a, float b) { return copysignf(a, b); }
 __device__ __forceinline__ double copysigng(double a, double b) { return copysign(a, b); }
 
+__device__ __forceinline__ int32_t ming(int32_t a, int32_t b) { return min(a, b); }
+__device__ __forceinline__ int32_t maxg(int32_t a, int32_t b) { return max(a, b); }
 __device__ __forceinline__ int64_t ming(int64_t a, int64_t b) { return min(a, b); }
 __device__ __forceinline__ int64_t maxg(int64_t a, int64_t b) { return max(a, b); }
 __device__ __forceinline__ uint32_t ming(uint32_t a, uint32_t b) { return min(a, b); }
diff --git a/candle-kernels/src/fill.cu b/candle-kernels/src/fill.cu
index ca448d989f..42bfddfd9f 100644
--- a/candle-kernels/src/fill.cu
+++ b/candle-kernels/src/fill.cu
@@ -9,6 +9,7 @@ __device__ void fill_with(T *buf, T value, const size_t numel) {
 }
 extern "C" __global__ void fill_u8(uint8_t *buf, uint8_t value, const size_t numel) { fill_with(buf, value, numel); }
 extern "C" __global__ void fill_u32(uint32_t *buf, uint32_t value, const size_t numel) { fill_with(buf, value, numel); }
+extern "C" __global__ void fill_i32(int32_t *buf, int32_t value, const size_t numel) { fill_with(buf, value, numel); }
 extern "C" __global__ void fill_i64(int64_t *buf, int64_t value, const size_t numel) { fill_with(buf, value, numel); }
 extern "C" __global__ void fill_f32(float *buf, float value, const size_t numel) { fill_with(buf, value, numel); }
 extern "C" __global__ void fill_f64(double *buf, double value, const size_t numel) { fill_with(buf, value, numel); }
@@ -34,6 +35,7 @@ COPY2D_OP(float, copy2d_f32)
 COPY2D_OP(double, copy2d_f64)
 COPY2D_OP(uint8_t, copy2d_u8)
 COPY2D_OP(uint32_t, copy2d_u32)
+COPY2D_OP(int32_t, copy2d_i32)
 COPY2D_OP(int64_t, copy2d_i64)
 
 #if __CUDA_ARCH__ >= 530
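
fill_i32 and copy2d_i32 back constant initialization and strided copies; e.g. Tensor::full (and ones/zeros) with an i32 value (a sketch):

    use candle_core::{Device, Result, Tensor};

    fn fill_i32() -> Result<()> {
        let dev = Device::cuda_if_available(0)?;
        let t = Tensor::full(7i32, (2, 3), &dev)?; // fill_i32
        assert_eq!(t.to_vec2::<i32>()?, vec![vec![7; 3]; 2]);
        Ok(())
    }
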
diff --git a/candle-kernels/src/indexing.cu b/candle-kernels/src/indexing.cu
index 8af2954d13..2f3df4de1b 100644
--- a/candle-kernels/src/indexing.cu
+++ b/candle-kernels/src/indexing.cu
@@ -147,44 +147,61 @@ extern "C" __global__ void FN_NAME(  \
 
 
 #if __CUDA_ARCH__ >= 800
+IS_OP(__nv_bfloat16, int32_t, is_i32_bf16)
 IS_OP(__nv_bfloat16, int64_t, is_i64_bf16)
 IS_OP(__nv_bfloat16, uint32_t, is_u32_bf16)
 IS_OP(__nv_bfloat16, uint8_t, is_u8_bf16)
+GATHER_OP(__nv_bfloat16, int32_t, gather_i32_bf16)
 GATHER_OP(__nv_bfloat16, int64_t, gather_i64_bf16)
 GATHER_OP(__nv_bfloat16, uint32_t, gather_u32_bf16)
 GATHER_OP(__nv_bfloat16, uint8_t, gather_u8_bf16)
+IA_OP(__nv_bfloat16, int32_t, ia_i32_bf16)
 IA_OP(__nv_bfloat16, int64_t, ia_i64_bf16)
 IA_OP(__nv_bfloat16, uint32_t, ia_u32_bf16)
 IA_OP(__nv_bfloat16, uint8_t, ia_u8_bf16)
+SA_OP(__nv_bfloat16, int32_t, sa_i32_bf16)
 SA_OP(__nv_bfloat16, int64_t, sa_i64_bf16)
 SA_OP(__nv_bfloat16, uint32_t, sa_u32_bf16)
 SA_OP(__nv_bfloat16, uint8_t, sa_u8_bf16)
 #endif
 
 #if __CUDA_ARCH__ >= 530
+IS_OP(__half, int32_t, is_i32_f16)
 IS_OP(__half, int64_t, is_i64_f16)
 IS_OP(__half, uint32_t, is_u32_f16)
 IS_OP(__half, uint8_t, is_u8_f16)
+GATHER_OP(__half, int32_t, gather_i32_f16)
 GATHER_OP(__half, int64_t, gather_i64_f16)
 GATHER_OP(__half, uint32_t, gather_u32_f16)
 GATHER_OP(__half, uint8_t, gather_u8_f16)
+IA_OP(__half, int32_t, ia_i32_f16)
 IA_OP(__half, int64_t, ia_i64_f16)
 IA_OP(__half, uint32_t, ia_u32_f16)
 IA_OP(__half, uint8_t, ia_u8_f16)
+SA_OP(__half, int32_t, sa_i32_f16)
 SA_OP(__half, int64_t, sa_i64_f16)
 SA_OP(__half, uint32_t, sa_u32_f16)
 SA_OP(__half, uint8_t, sa_u8_f16)
 #endif
 
+IS_OP(float, int32_t, is_i32_f32)
+IS_OP(double, int32_t, is_i32_f64)
+IS_OP(uint8_t, int32_t, is_i32_u8)
+IS_OP(uint32_t, int32_t, is_i32_u32)
+IS_OP(int32_t, int32_t, is_i32_i32)
+IS_OP(int64_t, int32_t, is_i32_i64)
+
 IS_OP(float, int64_t, is_i64_f32)
 IS_OP(double, int64_t, is_i64_f64)
 IS_OP(uint8_t, int64_t, is_i64_u8)
 IS_OP(uint32_t, int64_t, is_i64_u32)
 IS_OP(int64_t, int64_t, is_i64_i64)
+IS_OP(int32_t, int64_t, is_i64_i32)
 
 IS_OP(float, uint32_t, is_u32_f32)
 IS_OP(double, uint32_t, is_u32_f64)
 IS_OP(uint8_t, uint32_t, is_u32_u8)
+IS_OP(int32_t, uint32_t, is_u32_i32)
 IS_OP(int64_t, uint32_t, is_u32_i64)
 IS_OP(uint32_t, uint32_t, is_u32_u32)
 
@@ -192,17 +209,27 @@ IS_OP(float, uint8_t, is_u8_f32)
 IS_OP(double, uint8_t, is_u8_f64)
 IS_OP(uint8_t, uint8_t, is_u8_u8)
 IS_OP(uint32_t, uint8_t, is_u8_u32)
+IS_OP(int32_t, uint8_t, is_u8_i32)
 IS_OP(int64_t, uint8_t, is_u8_i64)
 
+GATHER_OP(float, int32_t, gather_i32_f32)
+GATHER_OP(double, int32_t, gather_i32_f64)
+GATHER_OP(uint8_t, int32_t, gather_i32_u8)
+GATHER_OP(uint32_t, int32_t, gather_i32_u32)
+GATHER_OP(int32_t, int32_t, gather_i32_i32)
+GATHER_OP(int64_t, int32_t, gather_i32_i64)
+
 GATHER_OP(float, int64_t, gather_i64_f32)
 GATHER_OP(double, int64_t, gather_i64_f64)
 GATHER_OP(uint8_t, int64_t, gather_i64_u8)
 GATHER_OP(uint32_t, int64_t, gather_i64_u32)
 GATHER_OP(int64_t, int64_t, gather_i64_i64)
+GATHER_OP(int32_t, int64_t, gather_i64_i32)
 
 GATHER_OP(float, uint32_t, gather_u32_f32)
 GATHER_OP(double, uint32_t, gather_u32_f64)
 GATHER_OP(uint8_t, uint32_t, gather_u32_u8)
+GATHER_OP(int32_t, uint32_t, gather_u32_i32)
 GATHER_OP(int64_t, uint32_t, gather_u32_i64)
 GATHER_OP(uint32_t, uint32_t, gather_u32_u32)
 
@@ -210,17 +237,26 @@ GATHER_OP(float, uint8_t, gather_u8_f32)
 GATHER_OP(double, uint8_t, gather_u8_f64)
 GATHER_OP(uint8_t, uint8_t, gather_u8_u8)
 GATHER_OP(uint32_t, uint8_t, gather_u8_u32)
+GATHER_OP(int32_t, uint8_t, gather_u8_i32)
 GATHER_OP(int64_t, uint8_t, gather_u8_i64)
 
+IA_OP(float, int32_t, ia_i32_f32)
+IA_OP(double, int32_t, ia_i32_f64)
+IA_OP(uint8_t, int32_t, ia_i32_u8)
+IA_OP(int32_t, int32_t, ia_i32_i32)
+IA_OP(uint32_t, int32_t, ia_i32_u32)
+
 IA_OP(float, int64_t, ia_i64_f32)
 IA_OP(double, int64_t, ia_i64_f64)
 IA_OP(uint8_t, int64_t, ia_i64_u8)
 IA_OP(int64_t, int64_t, ia_i64_i64)
 IA_OP(uint32_t, int64_t, ia_i64_u32)
+IA_OP(int32_t, int64_t, ia_i64_i32)
 
 IA_OP(float, uint32_t, ia_u32_f32)
 IA_OP(double, uint32_t, ia_u32_f64)
 IA_OP(uint8_t, uint32_t, ia_u32_u8)
+IA_OP(int32_t, uint32_t, ia_u32_i32)
 IA_OP(int64_t, uint32_t, ia_u32_i64)
 IA_OP(uint32_t, uint32_t, ia_u32_u32)
 
@@ -228,17 +264,26 @@ IA_OP(float, uint8_t, ia_u8_f32)
 IA_OP(double, uint8_t, ia_u8_f64)
 IA_OP(uint8_t, uint8_t, ia_u8_u8)
 IA_OP(uint32_t, uint8_t, ia_u8_u32)
+IA_OP(int32_t, uint8_t, ia_u8_i32)
 IA_OP(int64_t, uint8_t, ia_u8_i64)
 
+SA_OP(float, int32_t, sa_i32_f32)
+SA_OP(double, int32_t, sa_i32_f64)
+SA_OP(uint8_t, int32_t, sa_i32_u8)
+SA_OP(int32_t, int32_t, sa_i32_i32)
+SA_OP(uint32_t, int32_t, sa_i32_u32)
+
 SA_OP(float, int64_t, sa_i64_f32)
 SA_OP(double, int64_t, sa_i64_f64)
 SA_OP(uint8_t, int64_t, sa_i64_u8)
+SA_OP(int32_t, int64_t, sa_i64_i32)
 SA_OP(int64_t, int64_t, sa_i64_i64)
 SA_OP(uint32_t, int64_t, sa_i64_u32)
 
 SA_OP(float, uint32_t, sa_u32_f32)
 SA_OP(double, uint32_t, sa_u32_f64)
 SA_OP(uint8_t, uint32_t, sa_u32_u8)
+SA_OP(int32_t, uint32_t, sa_u32_i32)
 SA_OP(int64_t, uint32_t, sa_u32_i64)
 SA_OP(uint32_t, uint32_t, sa_u32_u32)
 
@@ -246,4 +291,5 @@ SA_OP(float, uint8_t, sa_u8_f32)
 SA_OP(double, uint8_t, sa_u8_f64)
 SA_OP(uint8_t, uint8_t, sa_u8_u8)
 SA_OP(uint32_t, uint8_t, sa_u8_u32)
+SA_OP(int32_t, uint8_t, sa_u8_i32)
 SA_OP(int64_t, uint8_t, sa_u8_i64)
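
These four kernel families back index_select (IS), gather (GATHER), index_add (IA) and scatter_add (SA); each now accepts I32 index tensors alongside u8/u32/i64. A sketch:

    use candle_core::{Device, Result, Tensor};

    fn indexing_i32() -> Result<()> {
        let dev = Device::cuda_if_available(0)?;
        let t = Tensor::new(&[[1f32, 2.], [3., 4.], [5., 6.]], &dev)?;
        let ids = Tensor::new(&[2i32, 0], &dev)?; // I32 indices
        let rows = t.index_select(&ids, 0)?;      // is_i32_f32
        assert_eq!(rows.to_vec2::<f32>()?, vec![vec![5., 6.], vec![1., 2.]]);
        Ok(())
    }
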
diff --git a/candle-kernels/src/reduce.cu b/candle-kernels/src/reduce.cu
index aaac24a146..9a1354a8dc 100644
--- a/candle-kernels/src/reduce.cu
+++ b/candle-kernels/src/reduce.cu
@@ -606,5 +606,6 @@ ROPE_OP(double, rope_f64, rope_i_f64, rope_thd_f64)
 FAST_OP(float, fast_min_f32, fast_max_f32, fast_argmin_f32, fast_argmax_f32, fast_sum_f32)
 FAST_OP(double, fast_min_f64, fast_max_f64, fast_argmin_f64, fast_argmax_f64, fast_sum_f64)
 FAST_OP(uint32_t, fast_min_u32, fast_max_u32, fast_argmin_u32, fast_argmax_u32, fast_sum_u32)
+FAST_OP(int32_t, fast_min_i32, fast_max_i32, fast_argmin_i32, fast_argmax_i32, fast_sum_i32)
 FAST_OP(int64_t, fast_min_i64, fast_max_i64, fast_argmin_i64, fast_argmax_i64, fast_sum_i64)
 FAST_OP(uint8_t, fast_min_u8, fast_max_u8, fast_argmin_u8, fast_argmax_u8, fast_sum_u8)
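
FAST_OP instantiates sum/min/max/argmin/argmax per dtype, so the usual reductions now run on I32 CUDA tensors; a sketch (assuming these calls route to the fast_* kernels):

    use candle_core::{Device, Result, Tensor};

    fn reduce_i32() -> Result<()> {
        let dev = Device::cuda_if_available(0)?;
        let t = Tensor::new(&[3i32, 1, 4, 1, 5], &dev)?;
        assert_eq!(t.sum_all()?.to_scalar::<i32>()?, 14); // fast_sum_i32
        assert_eq!(t.argmax(0)?.to_scalar::<u32>()?, 4);  // fast_argmax_i32
        Ok(())
    }
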
diff --git a/candle-kernels/src/sort.cu b/candle-kernels/src/sort.cu
index 08f1f9fc29..7fecf8413e 100644
--- a/candle-kernels/src/sort.cu
+++ b/candle-kernels/src/sort.cu
@@ -85,4 +85,5 @@ ASORT_OP(float, f32)
 ASORT_OP(double, f64)
 ASORT_OP(uint8_t, u8)
 ASORT_OP(uint32_t, u32)
+ASORT_OP(int32_t, i32)
 ASORT_OP(int64_t, i64)
diff --git a/candle-kernels/src/ternary.cu b/candle-kernels/src/ternary.cu
index aaa8a881fb..4617c08fbe 100644
--- a/candle-kernels/src/ternary.cu
+++ b/candle-kernels/src/ternary.cu
@@ -33,17 +33,26 @@ extern "C" __global__ void FN_NAME(  \
 } \
 
 #if __CUDA_ARCH__ >= 800
+WHERE_OP(__nv_bfloat16, int32_t, where_i32_bf16)
 WHERE_OP(__nv_bfloat16, int64_t, where_i64_bf16)
 WHERE_OP(__nv_bfloat16, uint32_t, where_u32_bf16)
 WHERE_OP(__nv_bfloat16, uint8_t, where_u8_bf16)
 #endif
 
 #if __CUDA_ARCH__ >= 530
+WHERE_OP(__half, int32_t, where_i32_f16)
 WHERE_OP(__half, int64_t, where_i64_f16)
 WHERE_OP(__half, uint32_t, where_u32_f16)
 WHERE_OP(__half, uint8_t, where_u8_f16)
 #endif
 
+WHERE_OP(float, int32_t, where_i32_f32)
+WHERE_OP(double, int32_t, where_i32_f64)
+WHERE_OP(uint8_t, int32_t, where_i32_u8)
+WHERE_OP(uint32_t, int32_t, where_i32_u32)
+WHERE_OP(int32_t, int32_t, where_i32_i32)
+WHERE_OP(int64_t, int32_t, where_i32_i64)
+
 WHERE_OP(float, int64_t, where_i64_f32)
 WHERE_OP(double, int64_t, where_i64_f64)
 WHERE_OP(uint8_t, int64_t, where_i64_u8)
@@ -54,10 +62,12 @@ WHERE_OP(float, uint32_t, where_u32_f32)
 WHERE_OP(double, uint32_t, where_u32_f64)
 WHERE_OP(uint8_t, uint32_t, where_u32_u8)
 WHERE_OP(uint32_t, uint32_t, where_u32_u32)
+WHERE_OP(int32_t, uint32_t, where_u32_i32)
 WHERE_OP(int64_t, uint32_t, where_u32_i64)
 
 WHERE_OP(float, uint8_t, where_u8_f32)
 WHERE_OP(double, uint8_t, where_u8_f64)
 WHERE_OP(uint8_t, uint8_t, where_u8_u8)
 WHERE_OP(uint32_t, uint8_t, where_u8_u32)
+WHERE_OP(int32_t, uint8_t, where_u8_i32)
 WHERE_OP(int64_t, uint8_t, where_u8_i64)
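
WHERE_OP takes the value type first and the condition/index type second, with kernels named where_{index}_{value}. Through Tensor::where_cond this enables I32 conditions (and I32 values); a sketch:

    use candle_core::{Device, Result, Tensor};

    fn where_i32() -> Result<()> {
        let dev = Device::cuda_if_available(0)?;
        let cond = Tensor::new(&[1i32, 0, 1], &dev)?; // I32 condition
        let a = Tensor::new(&[10f32, 20., 30.], &dev)?;
        let b = Tensor::new(&[-1f32, -2., -3.], &dev)?;
        let out = cond.where_cond(&a, &b)?; // where_i32_f32
        assert_eq!(out.to_vec1::<f32>()?, vec![10., -2., 30.]);
        Ok(())
    }
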
diff --git a/candle-kernels/src/unary.cu b/candle-kernels/src/unary.cu
index c82a88375d..21d3d995c0 100644
--- a/candle-kernels/src/unary.cu
+++ b/candle-kernels/src/unary.cu
@@ -153,6 +153,7 @@ UNARY_OP(__half, usigmoid_f16, sigmoid_fwd(x))
 
 UNARY_OP(uint8_t, ucopy_u8, x)
 UNARY_OP(uint32_t, ucopy_u32, x)
+UNARY_OP(int32_t, ucopy_i32, x)
 UNARY_OP(int64_t, ucopy_i64, x)
 UNARY_OP(float, ucopy_f32, x)
 UNARY_OP(double, ucopy_f64, x)
diff --git a/candle-metal-kernels/src/binary.metal b/candle-metal-kernels/src/binary.metal
index e83498e40d..a9b8129c3a 100644
--- a/candle-metal-kernels/src/binary.metal
+++ b/candle-metal-kernels/src/binary.metal
@@ -58,13 +58,15 @@ kernel void FN_NAME_STRIDED( \
 BINARY(FN, float, float, NAME##_f32, NAME##_f32_strided); \
 BINARY(FN, half, half, NAME##_f16, NAME##_f16_strided); \
 BINARY(FN, uint32_t, uint32_t, NAME##_u32, NAME##_u32_strided); \
-BINARY(FN, uint8_t, uint8_t, NAME##_u8, NAME##_u8_strided);
+BINARY(FN, uint8_t, uint8_t, NAME##_u8, NAME##_u8_strided); \
+BINARY(FN, int32_t, int32_t, NAME##_i32, NAME##_i32_strided);
 
 #define BINARY_OP_OUT(NAME, FN) \
 BINARY(FN, float, uint8_t, NAME##_f32, NAME##_f32_strided); \
 BINARY(FN, half, uint8_t, NAME##_f16, NAME##_f16_strided); \
 BINARY(FN, uint32_t, uint8_t, NAME##_u32, NAME##_u32_strided); \
-BINARY(FN, uint8_t, uint8_t, NAME##_u8, NAME##_u8_strided);
+BINARY(FN, uint8_t, uint8_t, NAME##_u8, NAME##_u8_strided); \
+BINARY(FN, int32_t, uint8_t, NAME##_i32, NAME##_i32_strided);
 
 #define INT64_BINARY_OP(NAME, FN) \
 BINARY(FN, int64_t, int64_t, NAME##_i64, NAME##_i64_strided);
diff --git a/candle-metal-kernels/src/cast.metal b/candle-metal-kernels/src/cast.metal
index 2af3fdceb0..c8122ccf0a 100644
--- a/candle-metal-kernels/src/cast.metal
+++ b/candle-metal-kernels/src/cast.metal
@@ -76,6 +76,7 @@ kernel void FN_NAME_STRIDED( \
 CAST(cast_u32_f32, cast_u32_f32_strided, uint32_t, float)
 CAST(cast_u32_u8, cast_u32_u8_strided, uint32_t, uint8_t)
 CAST(cast_u32_f16, cast_u32_f16_strided, uint32_t, half)
+CAST(cast_u32_i32, cast_u32_i32_strided, uint32_t, int32_t)
 #if __METAL_VERSION__ >= 220
 CAST(cast_u32_i64, cast_u32_i64_strided, uint32_t, int64_t)
 #endif
@@ -87,6 +88,7 @@ CAST(cast_u32_bf16, cast_u32_bf16_strided, uint32_t, bfloat)
 CAST(cast_u8_u32, cast_u8_u32_strided, uint8_t, uint32_t)
 CAST(cast_u8_f32, cast_u8_f32_strided, uint8_t, float)
 CAST(cast_u8_f16, cast_u8_f16_strided, uint8_t, half)
+CAST(cast_u8_i32, cast_u8_i32_strided, uint8_t, int32_t)
 #if __METAL_VERSION__ >= 220
 CAST(cast_u8_i64, cast_u8_i64_strided, uint8_t, int64_t)
 #endif
@@ -98,6 +100,7 @@ CAST(cast_u8_bf16, cast_u8_bf16_strided, uint8_t, bfloat)
 CAST(cast_f16_f32, cast_f16_f32_strided, half, float)
 CAST(cast_f16_u8, cast_f16_u8_strided, half, uint8_t)
 CAST(cast_f16_u32, cast_f16_u32_strided, half, uint32_t)
+CAST(cast_f16_i32, cast_f16_i32_strided, half, int32_t)
 CAST(cast_f16_i64, cast_f16_i64_strided, half, int64_t)
 #if defined(__HAVE_BFLOAT__)
 CAST_THROUGH(cast_f16_bf16, cast_f16_bf16_strided, half, bfloat, float)
@@ -107,15 +110,27 @@ CAST_THROUGH(cast_f16_bf16, cast_f16_bf16_strided, half, bfloat, float)
 CAST(cast_i64_f32, cast_i64_f32_strided, int64_t, float)
 CAST(cast_i64_u8, cast_i64_u8_strided, int64_t, uint8_t)
 CAST(cast_i64_u32, cast_i64_u32_strided, int64_t, uint32_t)
+CAST(cast_i64_i32, cast_i64_i32_strided, int64_t, int32_t)
 CAST(cast_i64_f16, cast_i64_f16_strided, int64_t, half)
 #if defined(__HAVE_BFLOAT__)
 CAST_THROUGH(cast_i64_bf16, cast_i64_bf16_strided, int64_t, bfloat, float)
 #endif
 
+// i32
+CAST(cast_i32_f32, cast_i32_f32_strided, int32_t, float)
+CAST(cast_i32_u8, cast_i32_u8_strided, int32_t, uint8_t)
+CAST(cast_i32_u32, cast_i32_u32_strided, int32_t, uint32_t)
+CAST(cast_i32_i64, cast_i32_i64_strided, int32_t, int64_t)
+CAST(cast_i32_f16, cast_i32_f16_strided, int32_t, half)
+#if defined(__HAVE_BFLOAT__)
+CAST_THROUGH(cast_i32_bf16, cast_i32_bf16_strided, int32_t, bfloat, float)
+#endif
+
 // f32
 CAST(cast_f32_f16, cast_f32_f16_strided, float, half)
 CAST(cast_f32_u32, cast_f32_u32_strided, float, uint32_t)
 CAST(cast_f32_u8, cast_f32_u8_strided, float, uint8_t)
+CAST(cast_f32_i32, cast_f32_i32_strided, float, int32_t)
 CAST(cast_f32_i64, cast_f32_i64_strided, float, int64_t)
 #if defined(__HAVE_BFLOAT__)
 CAST(cast_f32_bf16, cast_f32_bf16_strided, float, bfloat)
@@ -124,6 +139,7 @@ CAST(cast_f32_bf16, cast_f32_bf16_strided, float, bfloat)
 // bf16
 #if defined(__HAVE_BFLOAT__)
 CAST(cast_bf16_u32, cast_bf16_u32_strided, bfloat, uint32_t)
+CAST(cast_bf16_i32, cast_bf16_i32_strided, bfloat, int32_t)
 CAST(cast_bf16_i64, cast_bf16_i64_strided, bfloat, int64_t)
 CAST(cast_bf16_f32, cast_bf16_f32_strided, bfloat, float)
 CAST_THROUGH(cast_bf16_u8, cast_bf16_u8_strided, bfloat, uint8_t, float)
diff --git a/candle-metal-kernels/src/indexing.metal b/candle-metal-kernels/src/indexing.metal
index 9eee97ca0a..eaa78d7b73 100644
--- a/candle-metal-kernels/src/indexing.metal
+++ b/candle-metal-kernels/src/indexing.metal
@@ -193,6 +193,12 @@ INDEX_OP(is_i64_f16, int64_t, half)
 INDEX_OP(is_i64_bf16, int64_t, bfloat)
 #endif
 
+INDEX_OP(is_i32_f32, int32_t, float)
+INDEX_OP(is_i32_f16, int32_t, half)
+#if defined(__HAVE_BFLOAT__)
+INDEX_OP(is_i32_bf16, int32_t, bfloat)
+#endif
+
 INDEX_OP(is_u32_f32, uint32_t, float)
 INDEX_OP(is_u32_f16, uint32_t, half)
 #if defined(__HAVE_BFLOAT__)
@@ -213,9 +219,11 @@ GATHER_OP(gather_u32_bf16, uint, bfloat)
 
 SCATTER_ADD_OP(sa_u32_f32, uint32_t, float)
 SCATTER_ADD_OP(sa_u8_f32, uint8_t, float)
+SCATTER_ADD_OP(sa_i32_f32, int32_t, float)
 SCATTER_ADD_OP(sa_i64_f32, int64_t, float)
 SCATTER_ADD_OP(sa_u32_f16, uint32_t, half)
 SCATTER_ADD_OP(sa_u8_f16, uint8_t, half)
+SCATTER_ADD_OP(sa_i32_f16, int32_t, half)
 SCATTER_ADD_OP(sa_i64_f16, int64_t, half)
 #if defined(__HAVE_BFLOAT__)
 SCATTER_ADD_OP(sa_u32_bf16, uint32_t, bfloat)
@@ -226,6 +234,7 @@ SCATTER_ADD_OP(sa_i64_bf16, int64_t, bfloat)
 // i64
 INDEX_ADD_OP(ia_i64_f16, int64_t, half)
 INDEX_ADD_OP(ia_i64_f32, int64_t, float)
+INDEX_ADD_OP(ia_i64_i32, int64_t, int32_t)
 INDEX_ADD_OP(ia_i64_i64, int64_t, int64_t)
 INDEX_ADD_OP(ia_i64_u32, int64_t, uint32_t)
 INDEX_ADD_OP(ia_i64_u8, int64_t, uint8_t)
@@ -233,9 +242,21 @@ INDEX_ADD_OP(ia_i64_u8, int64_t, uint8_t)
 INDEX_ADD_OP(ia_i64_bf16, int64_t, bfloat)
 #endif
 
+// i32
+INDEX_ADD_OP(ia_i32_f16, int32_t, half)
+INDEX_ADD_OP(ia_i32_f32, int32_t, float)
+INDEX_ADD_OP(ia_i32_i64, int32_t, int64_t)
+INDEX_ADD_OP(ia_i32_i32, int32_t, int32_t)
+INDEX_ADD_OP(ia_i32_u32, int32_t, uint32_t)
+INDEX_ADD_OP(ia_i32_u8, int32_t, uint8_t)
+#if defined(__HAVE_BFLOAT__)
+INDEX_ADD_OP(ia_i32_bf16, int32_t, bfloat)
+#endif
+
 // u32
 INDEX_ADD_OP(ia_u32_f16, uint32_t, half)
 INDEX_ADD_OP(ia_u32_f32, uint32_t, float)
+INDEX_ADD_OP(ia_u32_i32, uint32_t, int32_t)
 INDEX_ADD_OP(ia_u32_i64, uint32_t, int64_t)
 INDEX_ADD_OP(ia_u32_u32, uint32_t, uint32_t)
 INDEX_ADD_OP(ia_u32_u8, uint32_t, uint8_t)
@@ -246,6 +267,7 @@ INDEX_ADD_OP(ia_u32_bf16, uint32_t, bfloat)
 // u8
 INDEX_ADD_OP(ia_u8_f16, uint8_t, half)
 INDEX_ADD_OP(ia_u8_f32, uint8_t, float)
+INDEX_ADD_OP(ia_u8_i32, uint8_t, int32_t)
 INDEX_ADD_OP(ia_u8_i64, uint8_t, int64_t)
 INDEX_ADD_OP(ia_u8_u32, uint8_t, uint32_t)
 INDEX_ADD_OP(ia_u8_u8, uint8_t, uint8_t)
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index a97d327468..d6e6dd69b8 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -46,6 +46,7 @@ pub mod copy2d {
     pub const HALF: Kernel = Kernel("copy2d_f16");
     pub const BFLOAT: Kernel = Kernel("copy2d_bf16");
     pub const I64: Kernel = Kernel("copy2d_i64");
+    pub const I32: Kernel = Kernel("copy2d_i32");
     pub const U32: Kernel = Kernel("copy2d_u32");
     pub const U8: Kernel = Kernel("copy2d_u8");
 }
@@ -62,6 +63,7 @@ macro_rules! ops{
             pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_f16"));
             pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bf16"));
             pub const I64: Kernel = Kernel(concat!(stringify!($name), "_i64"));
+            pub const I32: Kernel = Kernel(concat!(stringify!($name), "_i32"));
             pub const U32: Kernel = Kernel(concat!(stringify!($name), "_u32"));
             pub const U8: Kernel = Kernel(concat!(stringify!($name), "_u8"));
         }
@@ -72,6 +74,7 @@ macro_rules! ops{
                 pub const HALF: Kernel = Kernel("copy_f16");
                 pub const BFLOAT: Kernel = Kernel("copy_bf16");
                 pub const I64: Kernel = Kernel("copy_i64");
+                pub const I32: Kernel = Kernel("copy_i32");
                 pub const U32: Kernel = Kernel("copy_u32");
                 pub const U8: Kernel = Kernel("copy_u8");
             }
@@ -86,6 +89,7 @@ macro_rules! ops{
             pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_f16_tiled"));
             pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bf16_tiled"));
             pub const I64: Kernel = Kernel(concat!(stringify!($name), "_i64_tiled"));
+            pub const I32: Kernel = Kernel(concat!(stringify!($name), "_i32_tiled"));
             pub const U32: Kernel = Kernel(concat!(stringify!($name), "_u32_tiled"));
             pub const U8: Kernel = Kernel(concat!(stringify!($name), "_u8_tiled"));
         }
@@ -96,6 +100,7 @@ macro_rules! ops{
                 pub const HALF: Kernel = Kernel("copy_f16_tiled");
                 pub const BFLOAT: Kernel = Kernel("copy_bf16_tiled");
                 pub const I64: Kernel = Kernel("copy_i64_tiled");
+                pub const I32: Kernel = Kernel("copy_i32_tiled");
                 pub const U32: Kernel = Kernel("copy_u32_tiled");
                 pub const U8: Kernel = Kernel("copy_u8_tiled");
             }
@@ -110,6 +115,7 @@ macro_rules! ops{
             pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_f16_strided"));
             pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bf16_strided"));
             pub const I64: Kernel = Kernel(concat!(stringify!($name), "_i64_strided"));
+            pub const I32: Kernel = Kernel(concat!(stringify!($name), "_i32_strided"));
             pub const U32: Kernel = Kernel(concat!(stringify!($name), "_u32_strided"));
             pub const U8: Kernel = Kernel(concat!(stringify!($name), "_u8_strided"));
         }
@@ -120,6 +126,7 @@ macro_rules! ops{
                 pub const HALF: Kernel = Kernel("copy_f16_strided");
                 pub const BFLOAT: Kernel = Kernel("copy_bf16_strided");
                 pub const I64: Kernel = Kernel("copy_i64_strided");
+                pub const I32: Kernel = Kernel("copy_i32_strided");
                 pub const U32: Kernel = Kernel("copy_u32_strided");
                 pub const U8: Kernel = Kernel("copy_u8_strided");
             }
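
For orientation, a minimal sketch of the shape the ops! macro expands to (assuming Kernel is a thin newtype over the exported kernel name); the new I32 constants are what let the Metal backend resolve names like "copy_i32_strided":

    // Illustrative sketch of the generated pattern, not the actual expansion.
    #[derive(Clone, Copy)]
    pub struct Kernel(pub &'static str);

    pub mod copy_strided {
        use super::Kernel;
        pub const I32: Kernel = Kernel("copy_i32_strided");
        pub const I64: Kernel = Kernel("copy_i64_strided");
    }

    fn main() {
        // The backend picks the constant matching the tensor dtype.
        assert_eq!(copy_strided::I32.0, "copy_i32_strided");
    }
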
diff --git a/candle-metal-kernels/src/reduce.metal b/candle-metal-kernels/src/reduce.metal
index e009ca1d6a..484fa0a1b1 100644
--- a/candle-metal-kernels/src/reduce.metal
+++ b/candle-metal-kernels/src/reduce.metal
@@ -602,6 +602,12 @@ ARGMIN(fast_argmin_i64_strided, int64_t, INT_MAX)
 ARGMAX(fast_argmax_i64_strided, int64_t, INT_MIN)
 #endif
 
+REDUCE(x + y, fast_sum_i32_strided, int32_t, 0)
+REDUCE(MIN(x, y), fast_min_i32_strided, int32_t, INT_MAX)
+REDUCE(MAX(x, y), fast_max_i32_strided, int32_t, INT_MIN)
+ARGMIN(fast_argmin_i32_strided, int32_t, INT_MAX)
+ARGMAX(fast_argmax_i32_strided, int32_t, INT_MIN)
+
 #if defined(__HAVE_BFLOAT__)
 REDUCE(x + y, fast_sum_bf16, bfloat, 0)
 REDUCE(x + y, fast_sum_bf16_strided, half, 0)
diff --git a/candle-metal-kernels/src/sort.metal b/candle-metal-kernels/src/sort.metal
index d71ab82234..b7cf71bb58 100644
--- a/candle-metal-kernels/src/sort.metal
+++ b/candle-metal-kernels/src/sort.metal
@@ -88,6 +88,7 @@ ARGSORT(float, f32)
 ARGSORT(half, f16)
 ARGSORT(uint8_t, u8)
 ARGSORT(uint32_t, u32)
+ARGSORT(int32_t, i32)
 
 #if __METAL_VERSION__ >= 220
 ARGSORT(int64_t, i64)
diff --git a/candle-metal-kernels/src/ternary.metal b/candle-metal-kernels/src/ternary.metal
index fe04f2378f..0e043332fe 100644
--- a/candle-metal-kernels/src/ternary.metal
+++ b/candle-metal-kernels/src/ternary.metal
@@ -75,11 +75,25 @@ WHERE_OP(float, int64_t, where_i64_f32)
 WHERE_OP(uint8_t, int64_t, where_i64_u8)
 WHERE_OP(uint32_t, int64_t, where_i64_u32)
 WHERE_OP(int64_t, int64_t, where_i64_i64)
+WHERE_OP(int32_t, int64_t, where_i64_i32)
 #if defined(__HAVE_BFLOAT__)
 WHERE_OP(bfloat, int64_t, where_i64_bf16)
 #endif
 #endif
 
+WHERE_OP(int32_t, uint8_t, where_u8_i32)
+WHERE_OP(int32_t, uint32_t, where_u32_i32)
+
+WHERE_OP(half, int32_t, where_i32_f16)
+WHERE_OP(float, int32_t, where_i32_f32)
+WHERE_OP(uint8_t, int32_t, where_i32_u8)
+WHERE_OP(uint32_t, int32_t, where_i32_u32)
+WHERE_OP(int64_t, int32_t, where_i32_i64)
+WHERE_OP(int32_t, int32_t, where_i32_i32)
+#if defined(__HAVE_BFLOAT__)
+WHERE_OP(bfloat, int32_t, where_i32_bf16)
+#endif
+
 #if defined(__HAVE_BFLOAT__)
 WHERE_OP(bfloat, uint8_t, where_u8_bf16)
 WHERE_OP(bfloat, uint32_t, where_u32_bf16)
diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal
index a82bfdbdd6..0c5a2736ee 100644
--- a/candle-metal-kernels/src/unary.metal
+++ b/candle-metal-kernels/src/unary.metal
@@ -169,6 +169,9 @@ UNARY(id, int64_t, copy_i64, copy_i64_strided)
 COPY2D(copy2d_i64, int64_t)
 #endif
 
+UNARY(id, int32_t, copy_i32, copy_i32_strided)
+COPY2D(copy2d_i32, int32_t)
+
 #if defined(__HAVE_BFLOAT__)
 BFLOAT_UNARY_OP(cos)
 BFLOAT_UNARY_OP(sin)
diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs
index 0da2c70028..55b5542ed8 100644
--- a/candle-pyo3/src/lib.rs
+++ b/candle-pyo3/src/lib.rs
@@ -151,6 +151,7 @@ macro_rules! pydtype {
     };
 }
 
+pydtype!(i32, |v| v);
 pydtype!(i64, |v| v);
 pydtype!(u8, |v| v);
 pydtype!(u32, |v| v);
@@ -200,6 +201,7 @@ trait MapDType {
         match t.dtype() {
             DType::U8 => self.f::<u8>(t),
             DType::U32 => self.f::<u32>(t),
+            DType::I32 => self.f::<i32>(t),
             DType::I64 => self.f::<i64>(t),
             DType::BF16 => self.f::<bf16>(t),
             DType::F16 => self.f::<f16>(t),