diff --git a/docs/walk_through.rst b/docs/walk_through.rst
index 14cb9806..e985775c 100644
--- a/docs/walk_through.rst
+++ b/docs/walk_through.rst
@@ -203,23 +203,23 @@ To explain what happens in eager mode during backward, we have the following imp
         @staticmethod
         def forward(x0):
             x1 = torch.cos(x0)
-            return x1
+            return x1, x0
 
         @staticmethod
         def setup_context(ctx, inputs, output):
-            x, = inputs
-            print(f"saving tensor of size {x.shape}")
-            ctx.save_for_backward(x)
+            x1, x0 = output
+            print(f"saving tensor of size {x0.shape}")
+            ctx.save_for_backward(x0)
 
         @staticmethod
         def backward(ctx, grad_output):
-            x, = ctx.saved_tensors
-            result = (-torch.sin(x)) * grad_output
+            x0, = ctx.saved_tensors
+            result = (-torch.sin(x0)) * grad_output
             return result
 
     # Wrap Cosine in a function so that it is clearer what the output is
     def cosine(x):
-        y = Cosine.apply(x)
+        y, x = Cosine.apply(x)
         return y
 
     def naive_two_cosine(x0):
@@ -250,13 +250,13 @@ If we have the computation graph ahead-of-time, we can optimize the computation
         def forward(x0):
             x1 = torch.cos(x0)
             x2 = torch.cos(x1)
-            return x2
+            return x2, x0
 
         @staticmethod
         def setup_context(ctx, inputs, output):
-            x, = inputs
-            print(f"saving tensor of size {x.shape}")
-            ctx.save_for_backward(x)
+            x2, x0 = output
+            print(f"saving tensor of size {x0.shape}")
+            ctx.save_for_backward(x0)
 
         @staticmethod
         def backward(ctx, grad_x2):
@@ -268,8 +268,8 @@ If we have the computation graph ahead-of-time, we can optimize the computation
             return grad_x0
 
     def optimized_two_cosine(x):
-        y = OptimizedTwoCosine.apply(x)
-        return y
+        x2, x0 = OptimizedTwoCosine.apply(x)
+        return x2
 
 Running the above function with an input that requires grad, we can see that only one tensor is saved:
 
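
As a quick, standalone check of the pattern this patch switches to (saving the tensor from ``output`` rather than from ``inputs``), the sketch below can be run on its own. It assumes the new-style ``torch.autograd.Function`` API with a separate ``setup_context`` (PyTorch 2.0+). Because ``forward`` now returns two tensors, ``backward`` in this sketch accepts one incoming gradient per output; the parameter names ``grad_x1`` and ``grad_x0`` are illustrative only and not part of the patch::

    import torch

    class Cosine(torch.autograd.Function):
        @staticmethod
        def forward(x0):
            x1 = torch.cos(x0)
            # x0 is returned as an extra output so it can be saved from `output`
            return x1, x0

        @staticmethod
        def setup_context(ctx, inputs, output):
            x1, x0 = output
            print(f"saving tensor of size {x0.shape}")
            ctx.save_for_backward(x0)

        @staticmethod
        def backward(ctx, grad_x1, grad_x0):
            # one incoming gradient per forward output; the gradient for the
            # pass-through x0 output is not needed here
            x0, = ctx.saved_tensors
            return (-torch.sin(x0)) * grad_x1

    x = torch.randn(3, requires_grad=True)
    y, _ = Cosine.apply(x)  # the extra output is discarded, as in cosine()
    y.sum().backward()
    print(torch.allclose(x.grad, -torch.sin(x.detach())))  # expected: True

Running this should print a single ``saving tensor of size torch.Size([3])`` line during the forward pass and ``True`` for the gradient check.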