Tapes and tracking gradients
`NoneTape` and `OwnedTape` are the two concrete implementations of the `Tape` trait. `NoneTape` just ignores all backward ops & gradients during the forward pass, while `OwnedTape` stores backward ops inside a `GradientTape`. By default, when you create a tensor it has a `NoneTape` in it.
These types indicate whether the tensor CURRENTLY has a tape.
To actually insert a tape and start tracking gradients/backward ops, you have to call `.trace()` or `.traced()`. These copy the tensor and insert an `OwnedTape` in:
let t: Tensor1D<5, NoneTape> = Tensor1D::new([1.0, 2.0, 3.0, 4.0, 5.0]);
let t: Tensor1D<5, OwnedTape> = t.trace(); // start tracking gradients!
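If you don't need to keep the untraced tensor around, `.traced()` can be called on it directly. A minimal sketch, assuming `.traced()` takes ownership of the tensor rather than copying it (as its name suggests):
// same as above, but without keeping a NoneTape copy of the data around
let t: Tensor1D<5, OwnedTape> = Tensor1D::new([1.0, 2.0, 3.0, 4.0, 5.0]).traced();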
Therefore, to track gradients, you must call `.trace()` before passing the tensor into the module. Specifically:
let module: Linear<3, 10> = Default::default();
let x = Tensor1D::new([1.0, 2.0, 3.0]);
// GRADIENTS ARE NOT TRACKED here because `x` has a `NoneTape`!
// we can't actually call .backward() since there won't be an `OwnedTape`
let y = module.forward(x.clone());
// now gradients are being tracked because `x` has an `OwnedTape`!
let y = module.forward(x.trace());
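For completeness, here is a hedged sketch of finishing the backward pass from here. It assumes a scalar-reducing op like `.mean()` and a `.backward()` method on tensors that carry an `OwnedTape` (both exist in dfdx at the time of writing, but treat the exact names as an assumption):
let loss = y.mean(); // reduce to a scalar so there is something to differentiate
let gradients = loss.backward(); // works because `loss` carries the `OwnedTape` from `x.trace()`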
When calling `Module::forward`, the module implementation determines what to do when there is or isn't a tape. However, at the time of writing, all the modules keep whatever tape is in the input in the output.
For most of the modules in nn, the input tensor has a generic `Tape` parameter. This means the same forward implementation is called for both `NoneTape`s and `OwnedTape`s.
If you want to implement `Module::forward` differently based on what the tape is, you can! For example, the following is an implementation of a module where no gradient tracking is enabled; this is the version of forward that gets called when the input has no tape!
impl Module<Tensor1D<5, NoneTape>> for MyModule {
    type Output = Tensor1D<5, NoneTape>;
    fn forward(&self, x: Tensor1D<5, NoneTape>) -> Self::Output {
        ...
    }
}
And here's the same thing but when gradients are being tracked:
impl Module<Tensor1D<5, OwnedTape>> for MyModule {
    type Output = Tensor1D<5, OwnedTape>;
    fn forward(&self, x: Tensor1D<5, OwnedTape>) -> Self::Output {
        ...
    }
}
For completeness here is the generic version:
impl<TAPE: Tape> Module<Tensor1D<5, TAPE>> for MyModule {
    type Output = Tensor1D<5, TAPE>;
    fn forward(&self, x: Tensor1D<5, TAPE>) -> Self::Output {
        ...
    }
}
A majority of the functions in tensor_ops accept a generic tensor (`T: Tensor<Dtype = f32>`). This means they don't care what the tape actually is. Instead, they use the trait function from `Tape` to add a backward op.
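For example, an existing generic op like `relu()` can be called with either tape type and keeps whatever tape the input had. A usage sketch, assuming `relu` and the tensor types are in scope (e.g. from the crate's prelude):
let a: Tensor1D<5, NoneTape> = relu(Tensor1D::new([-1.0, 0.0, 1.0, 2.0, 3.0]));
let b: Tensor1D<5, OwnedTape> = relu(Tensor1D::new([-1.0, 0.0, 1.0, 2.0, 3.0]).traced());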
However, you can write a tensor op that DOES care about what the tape is. For example, the `dropout()` function doesn't do anything if there's no tape. Here's how that function looks:
pub fn dropout<T: Tensor<Dtype = f32>, R: Rng>(t: T, p: f32, rng: &mut R) -> T {
    if !T::Tape::OWNS_TAPE {
        // This is the branch where `t` doesn't own the tape, so we don't have to drop out anything.
        t
    } else {
        // `t` owns the tape in this branch, so apply dropout randomly.
        ...
    }
}
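And a hypothetical usage sketch of hitting both branches. The `StdRng` setup is an assumption for illustration; any type implementing `Rng` works:
use rand::{rngs::StdRng, SeedableRng};

let mut rng = StdRng::seed_from_u64(0);
let x = Tensor1D::new([1.0, 2.0, 3.0, 4.0, 5.0]);
// `x` has a `NoneTape`, so this hits the first branch and returns the tensor unchanged
let eval_out = dropout(x.clone(), 0.5, &mut rng);
// `x.trace()` has an `OwnedTape`, so this hits the second branch and applies dropout randomly
let train_out = dropout(x.trace(), 0.5, &mut rng);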