The docs say that when mu_dtype=None, the dtype of mu is inferred from grads and updates.
However, this line actually casts jnp.bfloat16 and jnp.float16 to jnp.float32 when mu_dtype=None.
Example on GPU:

>>> jax.__version__
'0.4.4'
>>> x.astype(jnp.float16).dtype
dtype('float16')
>>> x.astype(jnp.float16).astype(None).dtype
dtype('float32')
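A minimal sketch of the pitfall and one possible guard, assuming the problematic cast has the form t.astype(mu_dtype). The helper name cast_if_given is hypothetical and is not an optax API. NumPy stands in for jax.numpy so the sketch runs without JAX. NumPy also treats astype(None) as a request for its default float dtype (float64), whereas JAX with 64-bit mode disabled yields float32, as in the example above.

```python
import numpy as np


def cast_if_given(x, dtype=None):
    """Cast x to dtype, or return x unchanged when dtype is None.

    Hypothetical helper: skipping the cast for None preserves the input
    dtype instead of promoting it to the framework's default float dtype.
    """
    return x if dtype is None else x.astype(dtype)


x = np.zeros(3, dtype=np.float16)

# Naive cast: None is interpreted as "default float dtype", so float16 is promoted.
naive = x.astype(None)
# Guarded cast: None means "keep whatever dtype the input already has".
guarded = cast_if_given(x, None)
# An explicit dtype is still honored.
explicit = cast_if_given(x, np.float32)

print(naive.dtype, guarded.dtype, explicit.dtype)
```

Whether the fix should skip the cast like this or instead infer mu_dtype from the incoming grads is a design choice; the sketch only illustrates the former.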