Make fp8 compatible with tensor parallelism #65
base: main
Conversation
Stack from ghstack (oldest at bottom):
ghstack-source-id: db07e928f48cb886a86e017755ec4372c0f7ec3e
ghstack-comment-id: 2566319697
Pull Request resolved: #65
def mul_tiled(a, *bs):
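For context, a tiled-multiplication helper along these lines might look roughly as follows. This is only a sketch, assuming that each column of a 2-D scale `b` applies to an equally sized tile of `a`'s columns; it is not necessarily the PR's exact implementation:

```python
import torch

def mul_tiled(a: torch.Tensor, *bs: torch.Tensor) -> torch.Tensor:
    # Sketch: multiply `a` by each scale `b`. If `b` is (m, n) while `a` is
    # (m, k) with 1 < n < k, view `a` as (m, n, k // n) so that each entry of
    # `b` scales one contiguous tile of `a`'s columns; otherwise plain
    # broadcasting covers the (m, 1), (1, k) and scalar cases.
    for b in bs:
        if b.ndim == 2 and b.shape[-1] not in (1, a.shape[-1]):
            m, n = b.shape
            a = (a.reshape(m, n, -1) * b.unsqueeze(-1)).reshape(m, -1)
        else:
            a = a * b
    return a
```

With sub-row-wise scaling, the global `b` would be (m, num_shards) and take the tiled branch, while each local shard of `b` is (m, 1) and only needs the broadcasting branch, which is what the comment below is getting at.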
I understand that you need to apply such a function in the test of pytorch/pytorch#143760 to "manually" do tiled multiplication/division to compute scaled results.

Here the "if b is m x n" case only appears with DTensor sub-row-wise scaling, in which case the local tensor of `b` always has shape m x 1. So is it correct that:

- on L38, with `local_map`, we can always assume no tiled multiplication is needed; and
- on L46, if you're willing to also use `local_map`, tiled multiplication can be avoided too (rough sketch below)?
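If I'm reading the suggestion right, a `local_map`-based version could look roughly like the following. This is only a sketch: it assumes `torch.distributed.tensor.experimental.local_map`, a 1-D mesh, row-wise (`Shard(0)`) placements, and an illustrative name (`scale_locally`), none of which are taken from the PR itself.

```python
from torch.distributed.tensor import Shard
from torch.distributed.tensor.experimental import local_map

# Apply the scaling on the local shards directly: with sub-row-wise scaling
# the local scale is always (m_local, 1), so a plain broadcasted multiply is
# enough and no tiled view is needed.
scale_locally = local_map(
    lambda t, s: t * s,
    out_placements=[Shard(0)],               # output stays row-wise sharded
    in_placements=([Shard(0)], [Shard(0)]),  # both inputs row-wise sharded
)

# scaled = scale_locally(tensor_dtensor, scale_dtensor)
```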
Thanks for taking a look!
I am learning to use DTensors and I thought it was more idiomatic to express the calculation on the "global" distributed tensor, rather than on the local shard. In order to do so, however, we need to know how many shards there are and reshape accordingly, which arguably isn't that pretty either.
I believe the "fundamental" reason for it is that we're stacking the different components in the wrong order. Here we first replace the matmuls with our custom function, and then we propagate DTensors through it (which means our function needs to know how to handle DTensors). However, I believe the ideal solution would be to first propagate DTensors through some regular matmuls, then take the resulting graph and swap the local matmuls with our function. The issue is that I don't really know how to achieve that, and our code was already written this way before we started supporting DTensors.
(There's also another open question: how to integrate this with async-TP.)
As for `local_map`, this is currently an unfortunate implementation detail. Ideally the scaling is supposed to be done internally by the `_scaled_mm` operator, and that is indeed what it does! However, because the row-wise scaled-mm is slow (when using slow accum), we use the tensor-wise (un)scaled-mm and do the scaling ourselves (rough sketch of the two paths below). If we were able to make the row-wise scaled-mm faster we could avoid `local_map` altogether.
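For illustration, the two paths might look roughly like this. This is a sketch only: it assumes the PyTorch 2.4+ signature of the private `torch._scaled_mm` op and an fp8-capable GPU, and all names and shapes are made up rather than taken from this PR.

```python
import torch

M, K, N = 128, 256, 64
device = "cuda"
a = torch.randn(M, K, device=device).to(torch.float8_e4m3fn)
b = torch.randn(N, K, device=device).to(torch.float8_e4m3fn)
a_scale = torch.rand(M, 1, device=device) + 0.5  # row-wise scale for a
b_scale = torch.rand(1, N, device=device) + 0.5  # column-wise scale for b.t()

# Path 1: row-wise scaled-mm, scales applied inside the kernel
# (slow today when use_fast_accum=False, as discussed above).
out_rowwise = torch._scaled_mm(a, b.t(), a_scale, b_scale,
                               out_dtype=torch.bfloat16, use_fast_accum=False)

# Path 2 (roughly what this PR does): a tensor-wise scaled-mm with "no-op"
# scalar scales, followed by applying the row-wise scales manually.
one = torch.ones(1, device=device)
out = torch._scaled_mm(a, b.t(), one, one,
                       out_dtype=torch.float32, use_fast_accum=False)
out_manual = (out * a_scale * b_scale).to(torch.bfloat16)
```

In the DTensor sub-row-wise case, that manual scaling step is where `mul_tiled` (or a `local_map`-wrapped multiply) comes in.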
Thanks for the analysis! Makes a lot of sense to me!