Adding a new element-wise (unary) tensor operation
https://github.com/coreylowman/dfdx/pull/397 by @Narsil is a good example of all the pieces you'll need.
The first step is to create a new module under `src/tensor_ops`; in the PR above, the `gelu` module was added:
```
src/
    tensor_ops/
        <your_new_op>/
            mod.rs
            cpu_kernel.rs
            cuda_kernel.rs
            <your_new_op>.cu
```
`mod.rs` will contain the wrappers around the kernels that make it easy to use your op with any device. Another way to think about it is that `mod.rs` contains an abstraction layer over the different devices. Additionally, `mod.rs` will contain device-agnostic code to test your new op.
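For example, a device-agnostic test might look like the sketch below. It assumes the `TestDevice` alias from dfdx's test utilities, and the commented-out assertions are placeholders for values you would compute by hand for your op:

```rust
#[cfg(test)]
mod tests {
    use crate::{tensor::*, tensor_ops::*, tests::TestDevice};

    #[test]
    fn test_gelu() {
        let dev: TestDevice = Default::default();
        let x = dev.tensor([-2.0f32, -1.0, 0.0, 1.0, 2.0]);
        // forward pass through the op, tracking gradients
        let r = x.trace().gelu();
        // assert_close(&r.array(), &[/* hand-computed forward values */]);
        // backward pass to exercise the derivative
        let g = r.mean().backward();
        // assert_close(&g.get(&x).array(), &[/* hand-computed gradients */]);
    }
}
```

Because `TestDevice` is an alias, the same test exercises both the CPU and CUDA kernels depending on which feature is enabled.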
For element-wise operations, you just need to implement the trait `UnaryDerivative<f32>` for your op's kernel struct (created in `mod.rs`). This trait requires you to implement both the forward function and the derivative of your operation. See the PR for a good example.
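As a concrete illustration, here is a minimal sketch of a `cpu_kernel.rs` for a hypothetical squaring op, assuming the trait's two methods are `f` (forward) and `df` (derivative) and that the import path mirrors the CUDA one; check the linked PR for the exact names:

```rust
use crate::tensor_ops::cpu_kernels::UnaryDerivative;

// Hypothetical op: f(x) = x^2, so f'(x) = 2x.
impl UnaryDerivative<f32> for super::SquareKernelOp {
    fn f(&self, x: &f32) -> f32 {
        x * x
    }

    fn df(&self, x: &f32) -> f32 {
        2.0 * x
    }
}
```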
CUDA kernels are a bit more complicated. There are two files needed:
- `cuda_kernel.rs`, which sets up some hooks that tell dfdx how to use the second file, which is
- `<your_new_op>.cu`, which contains the actual CUDA kernel code.
Let's look at `cuda_kernel.rs` first:
```rust
use crate::tensor_ops::cuda_kernels::UnaryOpCudaKernel;

unsafe impl cudarc::driver::AsKernelParam for super::GeLUKernelOp {}

impl UnaryOpCudaKernel for super::GeLUKernelOp {
    const PTX_SRC: &'static str = include_str!(concat!(env!("OUT_DIR"), "/gelu.ptx"));
    const MODULE_NAME: &'static str = "gelu";
    const FWD_FN_NAME: &'static str = "gelu_forward";
    const BWD_FN_NAME: &'static str = "gelu_backward";
}
```
The line `unsafe impl cudarc::driver::AsKernelParam for super::GeLUKernelOp {}` means we can pass the kernel operation to the CUDA kernel. This impl is marked as `unsafe` because it interacts with FFI, so it can cause undefined behavior if you don't set up the `GeLUKernelOp` struct correctly.
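Concretely, the simplest safe setup is a zero-sized unit struct declared in `mod.rs`, as in the linked PR, so there is no data layout to mismatch on the CUDA side (sketch; the exact derives may vary):

```rust
// Zero-sized op struct: nothing to (mis)interpret across the FFI boundary.
#[derive(Debug, Default, Copy, Clone)]
pub struct GeLUKernelOp;
```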
Next we have the very weird line that uses some special Rust macros:

```rust
const PTX_SRC: &'static str = include_str!(concat!(env!("OUT_DIR"), "/gelu.ptx"));
```
- `env!()` pastes in an environment variable at build time. `OUT_DIR` is a directory under `target` where build script output goes.
- `concat!()` concatenates two strings together at compile time.
- `"gelu.ptx"` is the name of the compiled CUDA kernel that will be created by `build.rs`. It's basically `<your_new_op>.cu` with `ptx` replacing the `cu` extension.
- `include_str!()` pastes the contents of a file into your binary at compile time.
So all that to say: `PTX_SRC` will contain the source code of the compiled CUDA kernel!
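To make the expansion concrete, here is roughly what the compiler sees after the macros run (the path is illustrative; the hash and profile directory depend on your build):

```rust
// After env!() and concat!() are evaluated, the line is equivalent to:
//
//     const PTX_SRC: &'static str =
//         include_str!(".../target/debug/build/dfdx-<hash>/out/gelu.ptx");
//
// include_str!() then embeds that file's contents into the binary,
// so PTX_SRC ends up holding the PTX text itself.
```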
```rust
const FWD_FN_NAME: &'static str = "gelu_forward";
const BWD_FN_NAME: &'static str = "gelu_backward";
```

These are the actual function names that dfdx needs to load from the compiled CUDA kernel. They should be identical to the function names in the `.cu` file!
See the PR for a clearer example, but note that the functions you define there need to take the exact same parameters as the ones in the linked PR.
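To give a feel for the shape of the `.cu` file, here is a heavily hedged sketch for the same hypothetical squaring op; the parameter lists below are illustrative only, so copy the exact signatures from the linked PR rather than from here:

```cuda
// Sketch of <your_new_op>.cu for a hypothetical "square" op.
// The struct mirrors the zero-sized Rust kernel-op struct.
struct SquareKernelOp {};

extern "C" __global__ void square_forward(
    const SquareKernelOp op, const size_t numel, const float *inp, float *out
) {
    unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= numel) {
        return;
    }
    out[i] = inp[i] * inp[i]; // f(x) = x^2
}

extern "C" __global__ void square_backward(
    const SquareKernelOp op, const size_t numel, const float *inp,
    float *grad_inp, const float *grad_out
) {
    unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= numel) {
        return;
    }
    grad_inp[i] += 2.0f * inp[i] * grad_out[i]; // f'(x) = 2x, chain rule
}
```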
Adding the `nn` layer is very easy: you can just copy/paste the macro invocation under `src/nn/activations.rs` as shown in the PR, and everything will be set up for you.
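For reference, the invocation is a one-liner along these lines (a sketch assuming the `activation_impls!` macro used in that file at the time of the PR; check `src/nn/activations.rs` for the exact shape):

```rust
// One line per activation; the macro generates the nn layer boilerplate.
activation_impls!(GeLU, try_gelu, #[doc="Calls [gelu()]."]);
```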