[TPU] XLA fails to fuse embedding lookup / array indexing #20899
Can you paste the HLO snippet around the embedding operation with and without your change? You can get HLO IR using this
@patrick-toulme The jaxprs are linked in this equinox issue. I have attached the HLO dumps for both. Note, however, that there's a lot of duplication between the files, but I've zipped up everything for completeness' sake. Let me know if you want any other details.
Here is `jnp.take` after all passes.
This is naive indexing after all passes.
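For anyone reproducing this, a minimal sketch of how the after-all-passes HLO for both variants can be obtained from Python, assuming JAX's ahead-of-time API (`jax.jit(...).lower(...).compile().as_text()`). The table and sequence sizes are made up for illustration, and the output reflects whatever backend the code runs on (TPU in this issue, CPU elsewhere). Setting `XLA_FLAGS=--xla_dump_to=<dir>` before starting Python dumps every module to disk instead.

```python
import jax
import jax.numpy as jnp

# Hypothetical sizes; the issue does not state the real vocabulary/embedding sizes.
VOCAB, DIM, SEQ = 1000, 64, 128

weight = jnp.ones((VOCAB, DIM), dtype=jnp.float32)
tokens = jnp.arange(SEQ, dtype=jnp.int32) % VOCAB

def naive_lookup(w, idx):
    # One scalar lookup per token, vmapped over the sequence.
    return jax.vmap(lambda i: w[i])(idx)

def take_lookup(w, idx):
    # The whole sequence looked up in a single jnp.take along axis 0.
    return jnp.take(w, idx, axis=0)

def optimized_hlo(fn, *args):
    # lower() yields the pre-optimization module; compile().as_text()
    # returns the HLO after XLA's optimization passes for this backend.
    return jax.jit(fn).lower(*args).compile().as_text()

naive_hlo = optimized_hlo(naive_lookup, weight, tokens)
take_hlo = optimized_hlo(take_lookup, weight, tokens)
print(take_hlo[:2000])
```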
The main difference seems to be that in `jnp.take` the indices are assumed to be in bounds, whereas naive indexing handles them differently.
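If bounds handling is the relevant difference, one way to probe it is to make the out-of-bounds mode explicit in both forms and then compare their compiled HLO. The sketch below assumes the `mode` argument of `jnp.take` and of `.at[...].get(...)`; the function names are hypothetical. `"promise_in_bounds"` tells XLA it may skip bounds handling, so results for out-of-range indices are unspecified.

```python
import jax
import jax.numpy as jnp

weight = jnp.arange(12.0).reshape(4, 3)
idx = jnp.array([0, 3, 1], dtype=jnp.int32)  # all in bounds for 4 rows

@jax.jit
def take_clip(w, i):
    # jnp.take with an explicit clamping mode instead of the default.
    return jnp.take(w, i, axis=0, mode="clip")

@jax.jit
def index_clip(w, i):
    # Indexing through .at[].get, which also accepts an explicit mode.
    return w.at[i].get(mode="clip")

@jax.jit
def index_promise(w, i):
    # Promise XLA that every index is in bounds (no clamping or filling).
    return w.at[i].get(mode="promise_in_bounds")

print(take_clip(weight, idx))
print(index_promise(weight, idx))
```

Each variant can then be passed through `jax.jit(fn).lower(weight, idx).compile().as_text()` to see whether the gather and its surrounding fusions change with the mode.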
The `jnp.take` version is broken into four kernels, while naive indexing is broken into two.
https://github.com/patrick-kidger/equinox/blob/7ee4ca944d75c33d1403122f7ccf141bc390a55e/equinox/nn/_embedding.py#L100
I'm using `equinox`, and internally `eqx.nn.Embedding` just indexes naively (as shown in the link above). However, this is subpar: XLA is unable to fuse `vmap(embed_layer)` calls and instead performs hundreds of thousands of dynamic-slice updates over the `weight` array. Zooming in on the trace, we see the same block pattern repeated thousands of times.
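A minimal sketch of the pattern described above, assuming a layer whose `__call__` indexes the weight table with a single scalar token id, as the linked `eqx.nn.Embedding` code does. `NaiveEmbedding` is a hypothetical stand-in, not the equinox class, and the sizes are made up.

```python
import jax
import jax.numpy as jnp

class NaiveEmbedding:
    """Hypothetical stand-in for an embedding layer that looks up one
    scalar token id per call by plain indexing."""

    def __init__(self, num_embeddings, embedding_size, key):
        self.weight = jax.random.normal(key, (num_embeddings, embedding_size))

    def __call__(self, x):
        # x is a scalar int: a single row lookup.
        return self.weight[x]

embed = NaiveEmbedding(1000, 64, jax.random.PRNGKey(0))
tokens = jnp.arange(256, dtype=jnp.int32) % 1000

# vmap the per-token layer over the sequence, i.e. vmap(embed_layer).
# On TPU this is the pattern that showed up as many small blocks in the trace.
embedded = jax.jit(jax.vmap(embed))(tokens)
print(embedded.shape)  # (256, 64)
```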
Instead, we can force XLA to fuse by using `jnp.take` for the lookup, which fixes the issue and yields a ~25% improvement in throughput.
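The exact replacement code is not reproduced here, so the following is only one possible form, sketched under assumptions: a hypothetical `TakeEmbedding` layer that receives the whole token sequence and looks it up with a single `jnp.take` along axis 0, instead of being vmapped per token. It checks that the values match the naive vmapped lookup.

```python
import jax
import jax.numpy as jnp

class TakeEmbedding:
    """Hypothetical layer that looks up a whole sequence of token ids
    at once with jnp.take rather than one scalar id per call."""

    def __init__(self, num_embeddings, embedding_size, key):
        self.weight = jax.random.normal(key, (num_embeddings, embedding_size))

    def __call__(self, tokens):
        # tokens: int array (seq_len,) -> (seq_len, embedding_size)
        return jnp.take(self.weight, tokens, axis=0)

embed = TakeEmbedding(1000, 64, jax.random.PRNGKey(0))
tokens = jax.random.randint(jax.random.PRNGKey(1), (256,), 0, 1000)

@jax.jit
def embed_take(tok):
    return embed(tok)

@jax.jit
def embed_naive(tok):
    # Reference: naive per-token indexing, vmapped over the sequence.
    return jax.vmap(lambda t: embed.weight[t])(tok)

out_take = embed_take(tokens)
out_naive = embed_naive(tokens)
print(jnp.array_equal(out_take, out_naive))  # True
```

The values are identical; only the HLO XLA receives, and therefore how it fuses, differs.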
Here's a simple Colab repro that records two TensorBoard traces. Note that the blocks for the naive lookup are very small, so you may have to zoom into the trace to see them.
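The Colab itself isn't included here; a minimal sketch of recording one trace per variant with `jax.profiler.trace` could look like this. The log directory layout, sizes, and step count are assumptions, and viewing the result requires TensorBoard with the profile plugin.

```python
import os
import tempfile

import jax
import jax.numpy as jnp

weight = jax.random.normal(jax.random.PRNGKey(0), (1000, 64))
tokens = jax.random.randint(jax.random.PRNGKey(1), (256,), 0, 1000)

embed_naive = jax.jit(lambda w, idx: jax.vmap(lambda i: w[i])(idx))
embed_take = jax.jit(lambda w, idx: jnp.take(w, idx, axis=0))

# Compile before tracing so the traces contain only execution.
embed_naive(weight, tokens).block_until_ready()
embed_take(weight, tokens).block_until_ready()

def record_trace(fn, log_dir, steps=10):
    # Everything executed inside the context manager is written to log_dir.
    with jax.profiler.trace(log_dir):
        out = None
        for _ in range(steps):
            out = fn(weight, tokens)
        out.block_until_ready()

log_root = tempfile.mkdtemp()
record_trace(embed_naive, os.path.join(log_root, "naive"))
record_trace(embed_take, os.path.join(log_root, "take"))
print(log_root)  # point TensorBoard's --logdir here
```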
Why does XLA fail to fuse/parallelize naive indexing compared to `jnp.take`? Why does the jaxpr generated by `jnp.take` contain a `pjit`, while the one for naive indexing does not? If those ops are equivalent, surely XLA should be able to optimize them the same way? 🤔
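To see where the `pjit` sits, the two jaxprs can be printed side by side. A minimal sketch, assuming `jax.make_jaxpr` and made-up sizes; the `pjit` is most likely a nested `jit` boundary inside `jnp.take`'s own implementation (recent JAX versions may print it as `jit`), and the primitive inside it can be compared with what naive indexing emits.

```python
import jax
import jax.numpy as jnp

weight = jnp.zeros((1000, 64))
tokens = jnp.arange(256, dtype=jnp.int32)

def naive(w, idx):
    # Per-token indexing, vmapped over the sequence.
    return jax.vmap(lambda i: w[i])(idx)

def take(w, idx):
    # Single jnp.take over axis 0.
    return jnp.take(w, idx, axis=0)

naive_jaxpr = jax.make_jaxpr(naive)(weight, tokens)
take_jaxpr = jax.make_jaxpr(take)(weight, tokens)

def top_level_primitives(closed_jaxpr):
    # Names of the outermost equations only; nested jaxprs are not expanded.
    return [eqn.primitive.name for eqn in closed_jaxpr.jaxpr.eqns]

print(top_level_primitives(naive_jaxpr))
print(top_level_primitives(take_jaxpr))
print(take_jaxpr)  # the printout includes any nested jaxpr inline
```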