With #10, I get the following timings with NumPy on my Apple M1 Max:
    $ time python gpt2.py "Alan Turing theorized that computers would one day become" -n 40
    generating: 100%|███████████████████████████████| 40/40 [00:18<00:00, 2.13it/s]
     the most powerful machines on the planet.
    The computer is a machine that can perform complex calculations, and it can perform these calculations in a way that is very similar to the human brain.
    python gpt2.py "Alan Turing theorized that computers would one day become" -n  115.74s user 1.71s system 559% cpu 20.993 total
And Jax:
    $ time python gpt2.py "Alan Turing theorized that computers would one day become" -n 40
    generating: 100%|███████████████████████████████| 40/40 [00:21<00:00, 1.85it/s]
     the most powerful machines on the planet.
    The computer is a machine that can perform complex calculations, and it can perform these calculations in a way that is very similar to the human brain.
    python gpt2.py "Alan Turing theorized that computers would one day become" -n  28.86s user 1.91s system 127% cpu 24.115 total
So Jax is slower. According to htop, Jax uses roughly 1.3 CPU cores, while NumPy uses almost 6. Is NumPy automatically parallel on macOS?
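As a sanity check, the htop readings can be cross-checked against the `time` output above: average core usage is (user + system) divided by wall-clock time. A minimal sketch, with the numbers copied from the two runs (the helper name `cores_used` is only for illustration):

```python
def cores_used(user_s: float, system_s: float, wall_s: float) -> float:
    """Average number of CPU cores kept busy over the whole run."""
    return (user_s + system_s) / wall_s


# Values taken from the `time` lines of the two runs above.
numpy_cores = cores_used(115.74, 1.71, 20.993)
jax_cores = cores_used(28.86, 1.91, 24.115)

print(f"NumPy: {numpy_cores:.2f} cores, Jax: {jax_cores:.2f} cores")
# prints "NumPy: 5.59 cores, Jax: 1.28 cores"
```

Both figures agree with the 559% and 127% CPU reported by `time`, so the difference in core usage holds over the whole run.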
Curious, I would've expected jax to be faster given that it executes asynchronously (which should effectively make this line `out_heads = [attention(q, k, v, causal_mask) for q, k, v in zip(*qkv_heads)]` parallel, while numpy would execute sequentially since each call is eager and blocking).
I'm not sure how jax handles multiple CPUs. I know you can manually set multiple CPUs with the environment variable `export XLA_FLAGS="--xla_force_host_platform_device_count=8"`, but that didn't yield a speedup for me.
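For reference, here is a rough sketch of setting that flag from Python and checking how many CPU devices jax then reports. The helper `host_device_count` is hypothetical, and the flag has to be set before jax initializes its backend. As far as I understand, the flag only makes XLA present the host as several separate devices; it does not by itself spread a single computation across them, which might be why it gave no speedup here.

```python
import os

# Set the flag before jax is imported, since XLA reads it when the
# CPU backend is initialized. Note this overwrites any existing XLA_FLAGS.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"


def host_device_count():
    """Return the number of CPU devices jax sees, or None if jax is missing."""
    try:
        import jax
    except ImportError:
        return None
    return len(jax.devices("cpu"))


print(host_device_count())
```

Work would still need to be mapped onto those devices explicitly before the extra devices do anything.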
Here is my Conda environment: