Tapes and tracking gradients

NoneTape (& OwnedTape)

These two structs are concrete implementations of the Tape trait. NoneTape just ignores all backward ops & gradients during the forward pass, while OwnedTape stores backward ops inside a GradientTape. By default, when you create a tensor it has a NoneTape in it.

They indicate whether the tensor CURRENTLY has a tape.
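For example, a freshly created tensor carries a NoneTape (the tape type parameter is spelled out here only to make that visible; by default it is NoneTape, as noted above):

// nothing is recorded for backprop yet, since the tape defaults to NoneTape
let t: Tensor1D<3> = Tensor1D::new([1.0, 2.0, 3.0]);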

Tracking gradients

To actually insert a tape and start tracking gradients/backward ops, you have to call .trace() or .traced(). These copy the tensor and insert an OwnedTape into it:

let t: Tensor1D<3, NoneTape> = Tensor1D::new([1.0, 2.0, 3.0]);
let t: Tensor1D<3, OwnedTape> = t.trace(); // start tracking gradients!

Therefore, to track gradients, you must call .trace() before passing the tensor into the module. Specifically:

let module: Linear<3, 10> = Default::default();
let x = Tensor1D::new([1.0, 2.0, 3.0]);

// GRADIENTS ARE NOT TRACKED here because `x` has a `NoneTape`!
// we can't actually call .backward() since there won't be an `OwnedTape`
let y = module.forward(x.clone());

// now gradients are being tracked because `x.trace()` inserts an `OwnedTape`!
let y = module.forward(x.trace());

When calling Module::forward, the module implementation determines what to do when there is or isn't a tape. However, at the time of writing, all the modules pass whatever tape is in the input through to the output.
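As a quick illustration of the tape following the data (the Tensor1D<10, _> output types are only spelled out here to make the tape visible; they are what Linear<3, 10> produces):

let module: Linear<3, 10> = Default::default();
let x = Tensor1D::new([1.0, 2.0, 3.0]);

// input has a NoneTape, so the output has a NoneTape too
let y: Tensor1D<10, NoneTape> = module.forward(x.clone());

// input has an OwnedTape (inserted by .trace()), so the output carries the OwnedTape
let y: Tensor1D<10, OwnedTape> = module.forward(x.trace());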

Tape in Module::forward implementations

For most of the modules in nn, the input tensor has a generic Tape. This means the same forward implementation is called for both NoneTape and OwnedTape.

If you want to implement Module::forward differently based on what the tape is - you can!

For example, the following is an implementation of a module where no gradient tracking is enabled - this is the version of forward that gets called when the input has no tape:

impl Module<Tensor1D<5, NoneTape>> for MyModule {
    type Output = Tensor1D<5, NoneTape>;
    fn forward(&self, x: Tensor1D<5, NoneTape>) -> Self::Output {
         ...
    }
}

And here's the same thing but when gradients are being tracked:

impl Module<Tensor1D<5, OwnedTape>> for MyModule {
    type Output = Tensor1D<5, OwnedTape>;
    fn forward(&self, x: Tensor1D<5, OwnedTape>) -> Self::Output {
         ...
    }
}

For completeness here is the generic version:

impl<TAPE: Tape> Module<Tensor1D<5, TAPE>> for MyModule {
    type Output = Tensor1D<5, TAPE>;
    fn forward(&self, x: Tensor1D<5, TAPE>) -> Self::Output {
         ...
    }
}

Tape in tensor_ops

A majority of the functions in tensor_ops accept a generic tensor (T: Tensor<Dtype = f32>). This means they don't care what the tape actually is. Instead they use the Tape trait's methods to record backward ops.
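For example, using relu() from tensor_ops as an illustrative generic op (any op with this kind of signature behaves the same way), the output simply carries whatever tape the input had:

// with a NoneTape, the backward op handed to the Tape trait is simply discarded
let a: Tensor1D<3, NoneTape> = relu(Tensor1D::new([-1.0, 0.0, 1.0]));

// with an OwnedTape, the backward op is stored for a later backward pass
let b: Tensor1D<3, OwnedTape> = relu(Tensor1D::new([-1.0, 0.0, 1.0]).trace());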

However, you can write a tensor op that DOES care about what the tape is. For example, the dropout() function doesn't do anything if there's no tape. Here's how that function looks:

pub fn dropout<T: Tensor<Dtype = f32>, R: Rng>(t: T, p: f32, rng: &mut R) -> T {
    if !T::Tape::OWNS_TAPE {
        // This is the branch where `t` doesn't own the tape, so we don't have to drop out anything.
        t
    } else {
       // `t` owns the tape in this branch, so apply dropout randomly.
       ...
    }
}
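
Calling dropout() looks the same in both cases; only the tape decides whether anything happens. A minimal usage sketch (the rand-based rng setup is just for illustration):

use rand::{rngs::StdRng, SeedableRng};

let mut rng = StdRng::seed_from_u64(0);
let t = Tensor1D::new([1.0, 2.0, 3.0]);

// NoneTape: the no-op branch above is taken, `t` passes through unchanged
let a = dropout(t.clone(), 0.5, &mut rng);

// OwnedTape: the other branch is taken and dropout is applied randomly
let b = dropout(t.trace(), 0.5, &mut rng);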