numpy version has error checking jax #244

Open
esheldon opened this issue Oct 8, 2024 · 3 comments

esheldon commented Oct 8, 2024

pip install --upgrade muygpys[hnswlib]
[ins] In [1]: import MuyGPyS
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[1], line 1
----> 1 import MuyGPyS

File ~/miniforge3/lib/python3.10/site-packages/MuyGPyS/__init__.py:12
      8 import importlib.metadata
     10 __version__ = importlib.metadata.version(__package__)
---> 12 from MuyGPyS._src.config import (
     13     config as config,
     14     jax_config as jax_config,
     15     MPI as MPI,
     16 )

File ~/miniforge3/lib/python3.10/site-packages/MuyGPyS/_src/config.py:82
     77     config.state.jax_enabled = val
     80 # JAX and GPU states
---> 82 enable_jax = config.define_bool_state(
     83     name="muygpys_jax_enabled",
     84     default=False,
     85     help="Enable use of jax implementations of math functions.",
     86     update_global_hook=_update_jax_global,
     87     update_thread_local_hook=_update_jax_thread_local,
     88 )
     91 def _update_gpu_global(val):
     92     config.state.gpu_enabled = val

AttributeError: 'MuyGPySConfig' object has no attribute 'define_bool_state'

esheldon commented Oct 8, 2024

This appears to be caused by this import:

from jax._src.config import Config as JaxConfig

I do have jax installed, and JaxConfig does not have define_bool_state.

Is this a version compatibility issue?
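
A quick way to test this diagnosis is to check whether the installed JAX still exposes define_bool_state on that Config class. The sketch below is a minimal check, assuming the same private module path as the import above (jax._src.config is internal to JAX and may change between releases). The helper names are hypothetical and are not part of MuyGPyS or JAX.

```python
import importlib.metadata


def jax_config_has_define_bool_state():
    """Return True/False if JAX imports, or None if it is missing or fails to import."""
    try:
        # Same private import MuyGPyS relies on; not a stable public JAX API.
        from jax._src.config import Config as JaxConfig
    except (ImportError, RuntimeError):
        # JAX not installed, or its import failed (e.g. a jaxlib mismatch).
        return None
    # hasattr on the class only looks the attribute up; it does not call it.
    return hasattr(JaxConfig, "define_bool_state")


def installed_jax_version():
    """Return the installed jax distribution version string, or None."""
    try:
        return importlib.metadata.version("jax")
    except importlib.metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    print("jax version:", installed_jax_version())
    print("Config.define_bool_state present:", jax_config_has_define_bool_state())
```

Running it as a script prints the installed jax version and whether the attribute is present; False on the second line would be consistent with the AttributeError in the traceback.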

esheldon commented Oct 8, 2024

Downgrading to jax 0.4.24 fixed this.
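
Since 0.4.24 is the only version this thread confirms as working, an environment can be checked against that pin before importing MuyGPyS. The sketch below is a hypothetical helper, not something MuyGPyS ships: it parses the leading numeric release components of a version string and reports whether it is newer than the known-good pin. It does not claim that every newer release breaks, because the thread does not identify which JAX release first dropped define_bool_state. For full PEP 440 comparison semantics, packaging.version.Version is the usual tool; this sketch stays in the standard library.

```python
import re

# The JAX release the reporter confirmed works with MuyGPyS in this issue.
KNOWN_GOOD_JAX = (0, 4, 24)


def parse_release(version):
    """Extract the leading numeric release tuple, e.g. '0.4.31.dev1' -> (0, 4, 31)."""
    parts = []
    for piece in version.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break  # non-numeric component such as 'dev1'
        parts.append(int(match.group()))
        if match.end() != len(piece):
            break  # stop at a suffix such as 'rc1' in '24rc1'
    return tuple(parts)


def newer_than_known_good(version):
    """True if `version` is a later release than the known-good JAX pin."""
    return parse_release(version) > KNOWN_GOOD_JAX
```

For example, a True result for the installed version (as reported by importlib.metadata.version("jax")) would suggest pinning back to the reported version with pip install "jax==0.4.24" until a fixed MuyGPyS release is available.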

bwpriest commented Oct 8, 2024

Thanks @esheldon for investigating. There is a known incompatibility with recent versions of JAX in Python >= 3.9 arising from their config objects. We can fix this in a future release, but in the meantime thank you for identifying a compatible version of JAX.