unify function
youkaichao committed Nov 27, 2023
1 parent 8f1b2e2 commit e24e9c9
Showing 1 changed file with 13 additions and 13 deletions.
26 changes: 13 additions & 13 deletions docs/walk_through.rst
@@ -203,23 +203,23 @@ To explain what happens in eager mode during backward, we have the following implementation
     @staticmethod
     def forward(x0):
         x1 = torch.cos(x0)
-        return x1
+        return x1, x0
 
     @staticmethod
     def setup_context(ctx, inputs, output):
-        x, = inputs
-        print(f"saving tensor of size {x.shape}")
-        ctx.save_for_backward(x)
+        x1, x0 = output
+        print(f"saving tensor of size {x0.shape}")
+        ctx.save_for_backward(x0)
 
     @staticmethod
     def backward(ctx, grad_output):
-        x, = ctx.saved_tensors
-        result = (-torch.sin(x)) * grad_output
+        x0, = ctx.saved_tensors
+        result = (-torch.sin(x0)) * grad_output
         return result
 
 # Wrap Cosine in a function so that it is clearer what the output is
 def cosine(x):
-    y = Cosine.apply(x)
+    y, x = Cosine.apply(x)
     return y
 
 def naive_two_cosine(x0):
@@ -250,13 +250,13 @@ If we have the computation graph ahead-of-time, we can optimize the computation
     def forward(x0):
         x1 = torch.cos(x0)
         x2 = torch.cos(x1)
-        return x2
+        return x2, x0
 
     @staticmethod
     def setup_context(ctx, inputs, output):
-        x, = inputs
-        print(f"saving tensor of size {x.shape}")
-        ctx.save_for_backward(x)
+        x2, x0 = output
+        print(f"saving tensor of size {x0.shape}")
+        ctx.save_for_backward(x0)
 
     @staticmethod
     def backward(ctx, grad_x2):
@@ -268,8 +268,8 @@ If we have the computation graph ahead-of-time, we can optimize the computation
         return grad_x0
 
 def optimized_two_cosine(x):
-    y = OptimizedTwoCosine.apply(x)
-    return y
+    x2, x0 = OptimizedTwoCosine.apply(x)
+    return x2
 
 Running the above function with an input that requires grad, we can see that only one tensor is saved:
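
For reference, here is a minimal, self-contained sketch of how the changed code can be exercised end to end. It is a reconstruction, not the exact code in docs/walk_through.rst: it assumes PyTorch >= 2.0 (the separate forward/setup_context style), gives backward one gradient argument per forward output now that forward returns two tensors, and fills in an illustrative recomputation-based backward for OptimizedTwoCosine, whose real body is collapsed in the diff above. Running it, the naive double application of Cosine prints "saving tensor" twice, while the optimized version prints it once; the saving comes from recomputing cos(x0) inside backward instead of keeping the intermediate value alive.

import torch

class Cosine(torch.autograd.Function):
    @staticmethod
    def forward(x0):
        x1 = torch.cos(x0)
        # also return the input, so setup_context can pick what to save from the outputs
        return x1, x0

    @staticmethod
    def setup_context(ctx, inputs, output):
        x1, x0 = output
        print(f"saving tensor of size {x0.shape}")
        ctx.save_for_backward(x0)

    @staticmethod
    def backward(ctx, grad_x1, grad_x0):
        # one gradient per forward output; d cos(x0) / d x0 = -sin(x0)
        x0, = ctx.saved_tensors
        grad = (-torch.sin(x0)) * grad_x1
        return grad if grad_x0 is None else grad + grad_x0

def cosine(x):
    y, x = Cosine.apply(x)
    return y

class OptimizedTwoCosine(torch.autograd.Function):
    @staticmethod
    def forward(x0):
        x1 = torch.cos(x0)
        x2 = torch.cos(x1)
        return x2, x0

    @staticmethod
    def setup_context(ctx, inputs, output):
        x2, x0 = output
        print(f"saving tensor of size {x0.shape}")
        ctx.save_for_backward(x0)  # only the original input is saved

    @staticmethod
    def backward(ctx, grad_x2, grad_x0):
        # illustrative backward: recompute x1 = cos(x0) instead of saving it,
        # trading a little extra compute for less saved-activation memory
        x0, = ctx.saved_tensors
        x1 = torch.cos(x0)
        grad_x1 = (-torch.sin(x1)) * grad_x2
        grad = (-torch.sin(x0)) * grad_x1
        return grad if grad_x0 is None else grad + grad_x0

def optimized_two_cosine(x):
    x2, x0 = OptimizedTwoCosine.apply(x)
    return x2

x = torch.randn(4, requires_grad=True)
cosine(cosine(x)).sum().backward()        # two "saving tensor" lines
x.grad = None
optimized_two_cosine(x).sum().backward()  # one "saving tensor" line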

