Skip to content

Commit

Permalink
Begin to remove ug
Browse files Browse the repository at this point in the history
  • Loading branch information
EricLBuehler committed Jan 7, 2025
1 parent e264723 commit b28fc7b
Show file tree
Hide file tree
Showing 2 changed files with 0 additions and 110 deletions.
107 changes: 0 additions & 107 deletions candle-core/src/custom_op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -375,110 +375,3 @@ impl Tensor {
)
}
}

/// In-place unary custom op backed by a kernel compiled for a GPU backend.
///
/// The struct always carries the op name; the compiled kernel handle is only
/// present when the `cuda` or `metal` feature is enabled.
pub struct UgIOp1 {
    /// Static name identifying the op.
    name: &'static str,
    /// Compiled kernel handle for CUDA builds.
    #[cfg(feature = "cuda")]
    func: cudarc::driver::CudaFunction,
    /// Compiled compute pipeline for Metal builds.
    #[cfg(feature = "metal")]
    func: metal::ComputePipelineState,
}

impl UgIOp1 {
    /// Creates the op, compiling `kernel` under `name` for the backend chosen at build time.
    ///
    /// * With the `cuda` feature: `device` is converted with `as_cuda_device` and the
    ///   kernel is compiled on it.
    /// * With the `metal` feature: same, through `as_metal_device`.
    /// * With neither feature: only `name` is stored; `kernel` and `device` are ignored.
    ///
    /// # Errors
    /// Propagates any error returned by the device conversion or by `compile`.
    // `kernel` and `device` are unused when no GPU backend feature is enabled.
    #[allow(unused)]
    pub fn new(
        name: &'static str,
        kernel: ug::lang::ssa::Kernel,
        device: &crate::Device,
    ) -> Result<Self> {
        #[cfg(feature = "cuda")]
        {
            let func = device.as_cuda_device()?.compile(name, kernel)?;
            Ok(Self { name, func })
        }
        #[cfg(feature = "metal")]
        {
            let func = device.as_metal_device()?.compile(name, kernel)?;
            Ok(Self { name, func })
        }
        #[cfg(not(any(feature = "cuda", feature = "metal")))]
        {
            Ok(Self { name })
        }
    }
}

impl InplaceOp1 for UgIOp1 {
    /// Returns the name the op was constructed with.
    fn name(&self) -> &'static str {
        self.name
    }

    /// CPU execution is not implemented for ug ops; this always returns an error.
    fn cpu_fwd(&self, _: &mut CpuStorage, _: &Layout) -> Result<()> {
        crate::bail!("ug ops are only supported on metal/cuda at the moment")
    }

    /// Encodes the compiled pipeline over the storage's buffer on the Metal backend.
    ///
    /// # Errors
    /// Returns an error if the storage dtype is not `F32` or if the device cannot
    /// provide a command buffer.
    #[cfg(feature = "metal")]
    fn metal_fwd(&self, sto: &mut MetalStorage, layout: &Layout) -> Result<()> {
        use crate::backend::BackendStorage;
        use candle_metal_kernels::utils::EncoderProvider;

        let elem_count = layout.shape().elem_count();
        if sto.dtype() != crate::DType::F32 {
            // TODO: support more dtypes.
            crate::bail!("input is not a f32 tensor")
        }
        let device = sto.device();
        let command_buffer = device.command_buffer()?;
        let command_buffer = &command_buffer;
        let encoder = command_buffer.encoder();
        let encoder = encoder.as_ref();
        encoder.set_compute_pipeline_state(&self.func);
        // Groups of 32 when the element count divides evenly, otherwise groups of one.
        let (g, b) = if elem_count % 32 == 0 {
            (elem_count / 32, 32)
        } else {
            (elem_count, 1)
        };
        // NOTE(review): the CUDA path launches `g` blocks of `b` threads, while
        // `dispatch_threads` is given `g` as the grid width — confirm both paths
        // cover the same number of elements.
        let grid_dims = metal::MTLSize {
            width: g as u64,
            height: 1,
            depth: 1,
        };
        let group_dims = candle_metal_kernels::utils::get_block_dims(b as u64, 1, 1);
        // NOTE(review): the buffer is bound at offset 0 and, unlike the CUDA path,
        // the layout is not checked for contiguity or a non-zero start offset —
        // confirm callers only pass contiguous, zero-offset tensors.
        candle_metal_kernels::utils::set_param(encoder, 0, (sto.buffer(), 0usize));

        encoder.use_resource(sto.buffer(), metal::MTLResourceUsage::Write);
        encoder.dispatch_threads(grid_dims, group_dims);

        Ok(())
    }

    /// Launches the compiled kernel over the contiguous `f32` slice described by `layout`.
    ///
    /// # Errors
    /// Returns an error if the storage is not an `f32` CUDA slice, if the layout is
    /// not contiguous, if the grid size does not fit in a `u32`, or if the launch fails.
    #[cfg(feature = "cuda")]
    fn cuda_fwd(&self, sto: &mut CudaStorage, layout: &Layout) -> Result<()> {
        use crate::cuda_backend::WrapErr;
        use cudarc::driver::LaunchAsync;

        let elem_count = layout.shape().elem_count();
        // TODO: support more dtypes.
        let sto = sto.as_cuda_slice::<f32>()?;
        let sto = match layout.contiguous_offsets() {
            None => crate::bail!("input has to be contiguous"),
            Some((o1, o2)) => sto.slice(o1..o2),
        };
        let params = (&sto,);
        // Blocks of 32 threads when the element count divides evenly, otherwise
        // one thread per block.
        let (g, b) = if elem_count % 32 == 0 {
            (elem_count / 32, 32u32)
        } else {
            (elem_count, 1u32)
        };
        let cfg = cudarc::driver::LaunchConfig {
            // Checked conversion: an `as` cast would silently truncate a grid
            // larger than `u32::MAX` and launch over only part of the data.
            grid_dim: (u32::try_from(g)?, 1, 1),
            block_dim: (b, 1, 1),
            shared_mem_bytes: 0,
        };
        // SAFETY: NOTE(review) — relies on the compiled kernel taking exactly one
        // f32 buffer argument and only touching the `elem_count` elements of the
        // slice above; confirm against the kernel handed to `UgIOp1::new`.
        unsafe { self.func.clone().launch(cfg, params) }.w()?;
        Ok(())
    }
}
3 changes: 0 additions & 3 deletions candle-core/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -186,9 +186,6 @@ pub enum Error {
#[error("Metal error {0}")]
Metal(#[from] MetalError),

#[error(transparent)]
Ug(#[from] ug::Error),

#[error(transparent)]
TryFromIntError(#[from] core::num::TryFromIntError),

Expand Down

0 comments on commit b28fc7b

Please sign in to comment.