Commit

fix suggestions from qinguo
youkaichao committed Nov 27, 2023
1 parent 9ab161b commit 19c3ca0
Showing 1 changed file with 3 additions and 8 deletions.
11 changes: 3 additions & 8 deletions docs/walk_through.rst
@@ -325,7 +325,7 @@ And the following computation graph shows the details of a transformed implement
:width: 1200
:alt: AOT mode autograd

-We can only save one value, and recompute the first ``cos`` function to get another value for backward.
+We can only save one value, and recompute the first ``cos`` function to get another value for backward. Note that the additional computation does not imply more computation time: modern devices like GPUs are usually memory-bound, i.e., memory access time dominates the computation time, so a bit of extra computation does not matter.
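
A minimal hand-written sketch of this recompute-in-backward idea (an editorial illustration, assuming the running example is ``f(x) = cos(cos(x))``, not code generated by AOTAutograd): only ``x`` is saved, and the first ``cos`` is recomputed inside ``backward``.

.. code-block:: python

    import torch

    class RecomputeCos(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            y1 = torch.cos(x)
            y2 = torch.cos(y1)
            ctx.save_for_backward(x)  # save one tensor instead of both x and y1
            return y2

        @staticmethod
        def backward(ctx, grad_output):
            (x,) = ctx.saved_tensors
            y1 = torch.cos(x)  # recompute the first cos instead of saving it
            # d/dx cos(cos(x)) = sin(cos(x)) * sin(x)
            return grad_output * torch.sin(y1) * torch.sin(x)

    x = torch.randn(4, requires_grad=True)
    RecomputeCos.apply(x).sum().backward()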

AOTAutograd does the above transformation automatically. In essence, it dynamically generates a function like the following:

@@ -354,12 +354,7 @@ AOTAutograd does the above transformation automatically. In essence, it dynamica
This way, the saved tensors are made explicit, and the ``AOT_transformed_function`` accepts exactly the same inputs as the original function, while producing exactly the same output and having exactly the same backward behavior as the original function.

-By varying the amount of ``saved_tensors``, we can:
-
-- Save more tensors for backward, so that the backward computation is less heavy.
-- Save fewer tensors for backward, so that the memory footprint of the forward pass is smaller.
-
-Usually people go the second way, i.e., saving memory by doing more computation in the backward pass. And AOTAutograd will automatically select the optimal way to save memory. To be specific, it uses a `max-flow min-cut <https://en.wikipedia.org/wiki/Minimum_cut>`_ algorithm to cut the joint graph into a forward graph and a backward graph. More discussion can be found `at this thread <https://dev-discuss.pytorch.org/t/min-cut-optimal-recomputation-i-e-activation-checkpointing-with-aotautograd/467>`_.
+By varying the amount of ``saved_tensors``, we can save fewer tensors for backward, so that the memory footprint of the forward pass is smaller. And AOTAutograd will automatically select the optimal way to save memory. To be specific, it uses a `max-flow min-cut <https://en.wikipedia.org/wiki/Minimum_cut>`_ algorithm to cut the joint graph into a forward graph and a backward graph. More discussion can be found `at this thread <https://dev-discuss.pytorch.org/t/min-cut-optimal-recomputation-i-e-activation-checkpointing-with-aotautograd/467>`_.
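
A hedged sketch of how to see this partitioning in action, assuming the ``functorch.compile`` helpers (``aot_function`` and ``min_cut_rematerialization_partition``) shipped with recent PyTorch releases: the two "compilers" below do nothing but print the forward and backward graphs they receive.

.. code-block:: python

    import torch
    from functorch.compile import aot_function, min_cut_rematerialization_partition

    def f(x):
        return torch.cos(torch.cos(x))

    def inspect(name):
        # A trivial "compiler" that prints the graph it receives and runs it as-is.
        def compiler(gm, example_inputs):
            print(f"=== {name} graph ===")
            gm.graph.print_tabular()
            return gm.forward
        return compiler

    compiled_f = aot_function(
        f,
        fw_compiler=inspect("forward"),
        bw_compiler=inspect("backward"),
        partition_fn=min_cut_rematerialization_partition,
    )

    x = torch.randn(4, requires_grad=True)
    compiled_f(x).sum().backward()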

That is basically how AOT Autograd works!

@@ -368,7 +363,7 @@ Backend: compile and optimize computation graph

Finally, after ``Dynamo`` separates PyTorch code from Python code, and after ``AOTAutograd`` generates the backward computation graph from the forward computation graph, we enter the world of pure computation graphs.

-This is how the ``backend`` argument in ``torch.compile`` comes into play. It takes the above computation graphs as input, and generates optimized code that can execute the above computation graphs.
+This is how the ``backend`` argument in ``torch.compile`` comes into play. It takes the above computation graphs as input, and generates optimized code that can execute those computation graphs on different devices.

In general, a backend will try every optimization technique it knows on the computation graphs. Each optimization technique is called a ``pass``. Some optimization passes from the PyTorch built-in backend, namely the ``Inductor`` backend, can be found `here <https://github.com/pytorch/pytorch/tree/main/torch/_inductor/fx_passes>`_.
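
An illustrative sketch of the ``backend`` argument (not tied to any particular pass): it runs the running example through the built-in ``Inductor`` backend, and through a trivial custom backend that applies no passes at all, merely printing the graph it receives and executing it as-is.

.. code-block:: python

    import torch

    def f(x):
        return torch.cos(torch.cos(x))

    # The built-in Inductor backend (also the default of torch.compile):
    opt_f = torch.compile(f, backend="inductor")
    opt_f(torch.randn(4))

    # A custom backend is just a callable that receives the captured FX graph
    # and example inputs, and returns a callable used to execute the graph.
    def my_backend(gm: torch.fx.GraphModule, example_inputs):
        gm.graph.print_tabular()  # inspect the computation graph the backend gets
        return gm.forward         # apply no optimization passes: run the graph as-is

    debug_f = torch.compile(f, backend=my_backend)
    debug_f(torch.randn(4))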

