
Adding a new element-wise (unary) tensor operation

Corey Lowman edited this page Jan 24, 2023 · 2 revisions

Pull request https://github.com/coreylowman/dfdx/pull/397 by @Narsil, which adds the GeLU op, is a good example of all the pieces you'll need.

src/tensor_ops modifications

The first step is to create a new module under src/tensor_ops. In the PR above, this is the gelu module:

src/
  tensor_ops/
    <your_new_op>/
      mod.rs
      cpu_kernel.rs
      cuda_kernel.rs
      <your_new_op>.cu

mod.rs contains the wrappers around the kernels that make your op usable with any device. In other words, it is an abstraction layer over the different devices.

mod.rs also contains device-agnostic tests for your new op.
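To make the layering concrete, here is a minimal, self-contained sketch of the idea. The trait UnaryKernel, the Cpu device and the square op are hypothetical stand-ins, not dfdx's real types: one trait implemented per device, one device-agnostic wrapper, and one test body that works for any device implementing the kernel.

```rust
// Hypothetical, simplified mirror of the mod.rs layering.
// None of these names are dfdx's real API.

/// A device-specific implementation of the unary op `Op`.
pub trait UnaryKernel<Op> {
    fn forward(&self, op: &Op, inp: &[f32]) -> Vec<f32>;
}

/// The op's kernel struct (in the PR this role is played by GeLUKernelOp).
#[derive(Debug, Default, Clone, Copy)]
pub struct SquareKernelOp;

/// Stand-in for a CPU device.
pub struct Cpu;

/// The CPU kernel; a CUDA device would get its own impl of the same trait.
impl UnaryKernel<SquareKernelOp> for Cpu {
    fn forward(&self, _op: &SquareKernelOp, inp: &[f32]) -> Vec<f32> {
        inp.iter().map(|x| x * x).collect()
    }
}

/// Device-agnostic wrapper: callable with any device that implements the kernel.
pub fn square<D: UnaryKernel<SquareKernelOp>>(dev: &D, inp: &[f32]) -> Vec<f32> {
    dev.forward(&SquareKernelOp, inp)
}

/// Device-agnostic test body: the same check runs on every device.
pub fn check_square<D: UnaryKernel<SquareKernelOp>>(dev: &D) -> bool {
    square(dev, &[-2.0, 0.0, 3.0]) == vec![4.0, 0.0, 9.0]
}
```

Because check_square only depends on the trait bound, the same test body can be pointed at any device that provides the kernel.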

CPU Kernel

For element-wise operations, you only need to implement the trait UnaryDerivative<f32> for your op's kernel struct (defined in mod.rs).

This trait requires you to implement both the forward function and the derivative of your operation.

See the PR for a good example.
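For GeLU, a sketch of both functions using the common tanh approximation might look like the following. The UnaryDerivative trait here is a local stand-in, and its method names (f for the forward function, df for the derivative) are assumptions; check the PR for the real signature.

```rust
use std::f32::consts::PI;

/// Local stand-in for dfdx's UnaryDerivative trait (method names assumed):
/// `f` is the forward function, `df` its derivative.
pub trait UnaryDerivative<E> {
    fn f(&self, x: &E) -> E;
    fn df(&self, x: &E) -> E;
}

#[derive(Debug, Default, Clone, Copy)]
pub struct GeLUKernelOp;

/// Cubic coefficient in the tanh approximation of GeLU.
const C: f32 = 0.044715;

impl UnaryDerivative<f32> for GeLUKernelOp {
    fn f(&self, x: &f32) -> f32 {
        let x = *x;
        // 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + C * x^3)))
        let s = (2.0_f32 / PI).sqrt();
        let t = (s * (x + C * x.powi(3))).tanh();
        0.5 * x * (1.0 + t)
    }

    fn df(&self, x: &f32) -> f32 {
        let x = *x;
        let s = (2.0_f32 / PI).sqrt();
        let t = (s * (x + C * x.powi(3))).tanh();
        // d/dx tanh(u) = (1 - tanh(u)^2) * du/dx, with du/dx = s * (1 + 3 C x^2)
        let dt = (1.0 - t * t) * s * (1.0 + 3.0 * C * x * x);
        // Product rule on 0.5 * x * (1 + t)
        0.5 * (1.0 + t) + 0.5 * x * dt
    }
}
```

Comparing df against a finite difference of f, (f(x + h) - f(x - h)) / 2h, is a cheap way to catch mistakes in the derivative.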

CUDA Kernel

CUDA kernels are a bit more involved and need two files:

  1. cuda_kernel.rs, which sets up the hooks that tell dfdx how to load and call the kernels in the second file.
  2. <your_new_op>.cu, which contains the actual CUDA kernel code.

Let's look at cuda_kernel.rs first:

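use crate::tensor_ops::cuda_kernels::UnaryOpCudaKernel;

// Lets GeLUKernelOp be passed as a parameter when launching the kernel.
unsafe impl cudarc::driver::AsKernelParam for super::GeLUKernelOp {}

impl UnaryOpCudaKernel for super::GeLUKernelOp {
    // PTX compiled from gelu.cu by build.rs, embedded at compile time.
    const PTX_SRC: &'static str = include_str!(concat!(env!("OUT_DIR"), "/gelu.ptx"));
    // Name of the module the PTX is loaded under.
    const MODULE_NAME: &'static str = "gelu";
    // Kernel function names; these must match the names in gelu.cu.
    const FWD_FN_NAME: &'static str = "gelu_forward";
    const BWD_FN_NAME: &'static str = "gelu_backward";
}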
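These hooks are associated constants on the op type, so generic code can look them up per op without knowing which op it is handling. A self-contained sketch, using a local copy of the trait (an assumption, not the real definition) and a placeholder string instead of real PTX:

```rust
/// Local stand-in for UnaryOpCudaKernel with the same associated constants.
pub trait UnaryOpCudaKernel {
    const PTX_SRC: &'static str;
    const MODULE_NAME: &'static str;
    const FWD_FN_NAME: &'static str;
    const BWD_FN_NAME: &'static str;
}

pub struct GeLUKernelOp;

impl UnaryOpCudaKernel for GeLUKernelOp {
    // Placeholder instead of include_str!(...) so the sketch compiles anywhere.
    const PTX_SRC: &'static str = "// compiled PTX would go here";
    const MODULE_NAME: &'static str = "gelu";
    const FWD_FN_NAME: &'static str = "gelu_forward";
    const BWD_FN_NAME: &'static str = "gelu_backward";
}

/// Generic over the op: reads which module and functions to load
/// purely from the trait constants.
pub fn kernel_names<K: UnaryOpCudaKernel>() -> (&'static str, [&'static str; 2]) {
    (K::MODULE_NAME, [K::FWD_FN_NAME, K::BWD_FN_NAME])
}
```

Adding a new op therefore means adding one impl with new constants; the generic launching code stays untouched.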

AsKernelParam

unsafe impl cudarc::driver::AsKernelParam for super::GeLUKernelOp {} lets the kernel op struct be passed as a parameter to the CUDA kernel. The impl is unsafe because the struct's bytes cross the FFI boundary unchecked: if GeLUKernelOp is not laid out the way the .cu code expects, the result is undefined behavior.
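As a hypothetical illustration of what "set up correctly" means for an op that carries parameters (GeLUKernelOp itself has none): #[repr(C)] gives the Rust struct C's field order, size and alignment, so it can match a struct AffineKernelOp { float scale; float shift; }; declared in the .cu file. AffineKernelOp is a made-up example, not part of dfdx.

```rust
use std::mem::{align_of, size_of};

/// Hypothetical op with parameters. Without #[repr(C)], Rust is free to
/// reorder fields, so the layout would not be guaranteed to match the
/// C struct the CUDA kernel reads.
#[repr(C)]
#[derive(Debug, Default, Clone, Copy)]
pub struct AffineKernelOp {
    pub scale: f32, // float scale; on the .cu side
    pub shift: f32, // float shift; on the .cu side
}
```

With two f32 fields the struct is 8 bytes with 4-byte alignment, the same as the C declaration above.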

PTX_SRC

Next comes a rather unusual line that combines several built-in Rust macros:

    const PTX_SRC: &'static str = include_str!(concat!(env!("OUT_DIR"), "/gelu.ptx"));
  1. env!() expands to the value of an environment variable at compile time. OUT_DIR is set by Cargo for crates with a build script; it points to a directory under target/ where build.rs writes its output.
  2. concat!() concatenates literals (here, the OUT_DIR path and "/gelu.ptx") into a single string at compile time.
  3. "/gelu.ptx" names the compiled CUDA kernel that build.rs creates. The file name is <your_new_op>.cu with the .cu extension replaced by .ptx.
  4. include_str!() embeds the contents of that file into your binary as a &'static str at compile time.

In short, PTX_SRC contains the compiled PTX of your CUDA kernel, embedded in the binary as a string.
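The .ptx file is generated by build.rs. As a rough sketch (an assumption, not dfdx's actual build script), the core step is running nvcc --ptx on each .cu file and writing the result into OUT_DIR. The hypothetical helper ptx_command below only builds the command, so it runs without a CUDA toolkit installed:

```rust
use std::path::{Path, PathBuf};
use std::process::Command;

/// Hypothetical helper: builds (but does not run) an nvcc invocation that
/// compiles `cu_file` to `<out_dir>/<stem>.ptx`, e.g. gelu.cu -> gelu.ptx.
/// The real build.rs may pass additional flags.
pub fn ptx_command(cu_file: &Path, out_dir: &Path) -> (Command, PathBuf) {
    let stem = cu_file.file_stem().expect("kernel file needs a name");
    let ptx_file = out_dir.join(stem).with_extension("ptx");
    let mut cmd = Command::new("nvcc");
    // --ptx: compile the .cu file to device-only PTX; -o: output path.
    cmd.arg("--ptx").arg(cu_file).arg("-o").arg(&ptx_file);
    (cmd, ptx_file)
}
```

Inside build.rs you would read the directory with std::env::var("OUT_DIR"), call cmd.status() and check that it succeeded. The output name is exactly what the concat!(env!("OUT_DIR"), "/gelu.ptx") line expects.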

FWD_FN_NAME & BWD_FN_NAME

    const FWD_FN_NAME: &'static str = "gelu_forward";
    const BWD_FN_NAME: &'static str = "gelu_backward";

These are the names of the functions in your .cu file that dfdx loads from the compiled PTX, so they must be identical to the function names in the .cu file. Declare those functions extern "C" so the compiler does not mangle their names.
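One way to see why the names must match: in the compiled PTX, each kernel appears as an entry point such as .visible .entry gelu_forward(. The hypothetical helper ptx_entry_names below lists those entry names, which is a quick sanity check when dfdx cannot find your function. It is a rough text scan, not a PTX parser.

```rust
/// Hypothetical helper: returns the kernel names declared with `.entry`
/// in a PTX module, e.g. "gelu_forward" from ".visible .entry gelu_forward(".
pub fn ptx_entry_names(ptx: &str) -> Vec<&str> {
    ptx.lines()
        // Keep only the text after the first ".entry " on each line.
        .filter_map(|line| line.split(".entry ").nth(1))
        // The name ends at the opening parenthesis or at whitespace.
        .map(|rest| {
            rest.trim_start()
                .split(|c: char| c == '(' || c.is_whitespace())
                .next()
                .unwrap_or("")
        })
        .collect()
}
```

If an entry shows up as a C++-mangled name (starting with _Z) instead of gelu_forward, the function was not declared extern "C".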

Actual CUDA kernel

See the PR for a complete example. The forward and backward functions must take exactly the same parameters, in the same order, as the ones in the linked PR, because dfdx launches every unary kernel with the same argument list.
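As a rough mental model (not the PR's code), the two kernels compute the following, emulated here on the CPU in Rust with one "thread" per element. The parameter lists are illustrative only; the real ones must come from the PR's .cu file. Note the bounds check: the grid is rounded up to whole blocks, so some threads in the last block have no element to work on.

```rust
/// Number of blocks needed so that every element gets a thread (rounds up).
pub fn launch_blocks(numel: usize, block_dim: usize) -> usize {
    assert!(block_dim > 0, "block_dim must be positive");
    (numel + block_dim - 1) / block_dim
}

/// Emulates the forward kernel: out[i] = f(inp[i]) for every element.
pub fn forward(f: impl Fn(f32) -> f32, inp: &[f32], out: &mut [f32], block_dim: usize) {
    let numel = inp.len();
    for block in 0..launch_blocks(numel, block_dim) {
        for thread in 0..block_dim {
            // Global index, like blockIdx.x * blockDim.x + threadIdx.x in CUDA.
            let i = block * block_dim + thread;
            // Threads past the end of the tensor do nothing.
            if i >= numel {
                continue;
            }
            out[i] = f(inp[i]);
        }
    }
}

/// Emulates the backward kernel: chain rule, accumulated into grad_inp.
pub fn backward(
    df: impl Fn(f32) -> f32,
    inp: &[f32],
    grad_inp: &mut [f32],
    grad_out: &[f32],
    block_dim: usize,
) {
    let numel = inp.len();
    for block in 0..launch_blocks(numel, block_dim) {
        for thread in 0..block_dim {
            let i = block * block_dim + thread;
            if i >= numel {
                continue;
            }
            grad_inp[i] += df(inp[i]) * grad_out[i];
        }
    }
}
```

The forward pass only needs f, while the backward pass needs df, which is why the CPU trait asks for both the function and its derivative.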

src/nn modifications

Adding the nn layer is easy: copy and paste the macro invocation in src/nn/activations.rs as shown in the PR, and everything is set up for you.
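To see why a single macro invocation is enough, here is a hypothetical macro in the same spirit. Its name, arguments and the generated forward method are made up for illustration; copy the real invocation from the PR.

```rust
/// Hypothetical macro: generates a unit struct for an activation and an
/// element-wise forward method that applies `$func` to each input value.
macro_rules! activation_impl {
    ($struct_name:ident, $func:expr) => {
        #[derive(Debug, Default, Clone, Copy)]
        pub struct $struct_name;

        impl $struct_name {
            /// Applies the activation element-wise.
            pub fn forward(&self, input: &[f32]) -> Vec<f32> {
                input.iter().map(|&x| ($func)(x)).collect()
            }
        }
    };
}

// One line per activation: the macro writes the struct and its impl.
activation_impl!(ReLU, |x: f32| x.max(0.0));
activation_impl!(Square, |x: f32| x * x);
```

All the boilerplate lives in the macro, so a new activation costs one line.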