diff --git a/candle-core/src/custom_op.rs b/candle-core/src/custom_op.rs
index c0d97d670..3a85dba9f 100644
--- a/candle-core/src/custom_op.rs
+++ b/candle-core/src/custom_op.rs
@@ -375,110 +375,3 @@ impl Tensor {
         )
     }
 }
-
-pub struct UgIOp1 {
-    name: &'static str,
-    #[cfg(feature = "cuda")]
-    func: cudarc::driver::CudaFunction,
-    #[cfg(feature = "metal")]
-    func: metal::ComputePipelineState,
-}
-
-impl UgIOp1 {
-    #[allow(unused)]
-    pub fn new(
-        name: &'static str,
-        kernel: ug::lang::ssa::Kernel,
-        device: &crate::Device,
-    ) -> Result<Self> {
-        #[cfg(feature = "cuda")]
-        {
-            let device = device.as_cuda_device()?;
-            let func = device.compile(name, kernel)?;
-            Ok(Self { name, func })
-        }
-        #[cfg(feature = "metal")]
-        {
-            let device = device.as_metal_device()?;
-            let func = device.compile(name, kernel)?;
-            Ok(Self { name, func })
-        }
-        #[cfg(not(any(feature = "cuda", feature = "metal")))]
-        {
-            Ok(Self { name })
-        }
-    }
-}
-
-impl InplaceOp1 for UgIOp1 {
-    fn name(&self) -> &'static str {
-        self.name
-    }
-
-    fn cpu_fwd(&self, _: &mut CpuStorage, _: &Layout) -> Result<()> {
-        crate::bail!("ug ops are only supported on metal/cuda at the moment")
-    }
-
-    #[cfg(feature = "metal")]
-    fn metal_fwd(&self, sto: &mut MetalStorage, layout: &Layout) -> Result<()> {
-        use crate::backend::BackendStorage;
-        use candle_metal_kernels::utils::EncoderProvider;
-
-        let elem_count = layout.shape().elem_count();
-        if sto.dtype() != crate::DType::F32 {
-            // TODO: support more dtypes.
-            crate::bail!("input is not a f32 tensor")
-        }
-        let device = sto.device();
-        println!("here");
-        let command_buffer = device.command_buffer()?;
-        let command_buffer = &command_buffer;
-        let encoder = command_buffer.encoder();
-        let encoder = encoder.as_ref();
-        encoder.set_compute_pipeline_state(&self.func);
-        let (g, b) = if elem_count % 32 == 0 {
-            (elem_count / 32, 32)
-        } else {
-            (elem_count, 1)
-        };
-        let grid_dims = metal::MTLSize {
-            width: g as u64,
-            height: 1,
-            depth: 1,
-        };
-        let group_dims = candle_metal_kernels::utils::get_block_dims(b as u64, 1, 1);
-        candle_metal_kernels::utils::set_param(encoder, 0, (sto.buffer(), 0usize));
-
-        encoder.use_resource(sto.buffer(), metal::MTLResourceUsage::Write);
-        encoder.dispatch_threads(grid_dims, group_dims);
-
-        Ok(())
-    }
-
-    #[cfg(feature = "cuda")]
-    fn cuda_fwd(&self, sto: &mut CudaStorage, layout: &Layout) -> Result<()> {
-        use crate::cuda_backend::WrapErr;
-        use cudarc::driver::LaunchAsync;
-
-        let elem_count = layout.shape().elem_count();
-        // TODO: support more dtypes.
-        let sto = sto.as_cuda_slice::<f32>()?;
-        let sto = match layout.contiguous_offsets() {
-            None => crate::bail!("input has to be contiguous"),
-            Some((o1, o2)) => sto.slice(o1..o2),
-        };
-        let params = (&sto,);
-        let (g, b) = if elem_count % 32 == 0 {
-            (elem_count / 32, 32)
-        } else {
-            (elem_count, 1)
-        };
-        let cfg = cudarc::driver::LaunchConfig {
-            grid_dim: (g as u32, 1, 1),
-            block_dim: (b as u32, 1, 1),
-            shared_mem_bytes: 0,
-        };
-        unsafe { self.func.clone().launch(cfg, params) }.w()?;
-        Ok(())
-    }
-}
diff --git a/candle-core/src/error.rs b/candle-core/src/error.rs
index 506f85afc..ff75ce65b 100644
--- a/candle-core/src/error.rs
+++ b/candle-core/src/error.rs
@@ -186,9 +186,6 @@ pub enum Error {
     #[error("Metal error {0}")]
     Metal(#[from] MetalError),
 
-    #[error(transparent)]
-    Ug(#[from] ug::Error),
-
     #[error(transparent)]
     TryFromIntError(#[from] core::num::TryFromIntError),