When I install optax, I am no longer able to use the GPU #1144

Open
Alessandro-Castelli opened this issue Nov 27, 2024 · 10 comments

@Alessandro-Castelli

Alessandro-Castelli commented Nov 27, 2024

I have jax 0.4.23. When I install optax with the command pip install optax, I get an error message saying 'An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.', followed by [CpuDevice(id=0)].

This error only occurs after I install optax. Which version of optax is compatible with my version of jax and with tensorflow 2.9.0?

Name: jax
Version: 0.4.23
Summary: Differentiate, compile, and transform Numpy code.
Home-page: https://github.com/google/jax
Author: JAX team
Author-email: jax-dev@google.com
License: Apache-2.0
Location: /home/acastelli/miniconda3/envs/af/lib/python3.9/site-packages
Requires: importlib-metadata, ml-dtypes, numpy, opt-einsum, scipy
Required-by:

Name: jaxlib
Version: 0.4.23+cuda11.cudnn86
Summary: XLA library for JAX
Home-page: https://github.com/google/jax
Author: JAX team
Author-email: jax-dev@google.com
License: Apache-2.0
Location: /home/acastelli/miniconda3/envs/af/lib/python3.9/site-packages
Requires: ml-dtypes, numpy, scipy
Required-by:

@vroulet
Collaborator

vroulet commented Nov 27, 2024

Hello @Alessandro-Castelli,

pip install jax==0.4.23 jaxlib==0.4.23 optax==0.2.2 will not modify the jax or jaxlib versions.
If you used pip install optax, the jax and jaxlib versions are automatically bumped (because of the requirements set in optax 0.2.3). The real culprit is probably not optax but tensorflow, which has maximum-version requirements that often conflict with other packages. In particular, tensorflow 2.9.0 does not seem to be available from pip (see the output of pip index versions tensorflow on Python 3.9).
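
To see whether an install really bumped jax or jaxlib, one option is to snapshot the installed versions before and after running pip and compare them. The following is a minimal sketch using only the standard library; the helper names `installed_versions` and `changed_versions` are made up for illustration and are not part of optax or jax.

```python
from importlib import metadata


def installed_versions(names):
    """Return {name: version or None} for the given distribution names."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None  # not installed in this environment
    return versions


def changed_versions(before, after):
    """Return {name: (old, new)} for every package whose version differs."""
    return {
        name: (before.get(name), after.get(name))
        for name in sorted(set(before) | set(after))
        if before.get(name) != after.get(name)
    }


if __name__ == "__main__":
    # Run once before `pip install optax` and once after, then compare.
    print(installed_versions(["jax", "jaxlib", "optax", "chex", "tensorflow"]))
```

Saving the dictionary printed before the install and passing both snapshots to `changed_versions` would list exactly which packages pip touched.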

@Alessandro-Castelli
Author

Alessandro-Castelli commented Nov 28, 2024

Thank you, @vroulet .
Yes, maybe the problem lies in the TensorFlow version. How can I fix it?

Available versions: 2.18.0, 2.17.1, 2.17.0, 2.16.2, 2.16.1, 2.15.1, 2.15.0.post1, 2.15.0, 2.14.1, 2.14.0, 2.13.1, 2.13.0, 2.12.1, 2.12.0, 2.11.1, 2.11.0, 2.10.1, 2.10.0, 2.9.3, 2.9.2, 2.9.1, 2.9.0, 2.8.4, 2.8.3, 2.8.2, 2.8.1, 2.8.0, 2.7.4, 2.7.3, 2.7.2, 2.7.1, 2.7.0, 2.6.5, 2.6.4, 2.6.3, 2.6.2, 2.6.1, 2.6.0, 2.5.3, 2.5.2, 2.5.1, 2.5.0

@Alessandro-Castelli
Author

I tried creating two separate conda environments: one where I use TensorFlow to download the dataset and another where I install PennyLane, JAX, JAXlib, and Optax to train the model, but the error still occurs.

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu. cpu

@Alessandro-Castelli
Author

Alessandro-Castelli commented Nov 28, 2024

At this point, I think that the problem is optax.

Name: jax
Version: 0.4.23
Summary: Differentiate, compile, and transform Numpy code.
Home-page: https://github.com/google/jax
Author: JAX team
Author-email: jax-dev@google.com
License: Apache-2.0
Location: /home/acastelli/miniconda3/envs/af/lib/python3.9/site-packages
Requires: importlib-metadata, ml-dtypes, numpy, opt-einsum, scipy
Required-by: chex, optax

Name: jaxlib
Version: 0.4.23+cuda11.cudnn86
Summary: XLA library for JAX
Home-page: https://github.com/google/jax
Author: JAX team
Author-email: jax-dev@google.com
License: Apache-2.0
Location: /home/acastelli/miniconda3/envs/af/lib/python3.9/site-packages
Requires: ml-dtypes, numpy, scipy
Required-by: chex, optax

Name: optax
Version: 0.1.5
Summary: A gradient processing and optimisation library in JAX.
Home-page: https://github.com/deepmind/optax
Author: DeepMind
Author-email: optax-dev@google.com
License: Apache 2.0
Location: /home/acastelli/miniconda3/envs/af/lib/python3.9/site-packages
Requires: absl-py, chex, jax, jaxlib, numpy
Required-by:

@vroulet
Collaborator

vroulet commented Nov 28, 2024

> At this point, I think that the problem is optax.

I really don't think so. Just look at the code in optax. It's quite a lightweight library, unrelated to any CUDA GPU functionality. Some dependency versions could have been bumped, but the above versions of optax, jax, and jaxlib seem fine.

I cannot reproduce the error you're mentioning, as tensorflow 2.9 does not seem to be available to me locally, and in any case I don't have a GPU. The error clearly points to jaxlib, not optax.

@Alessandro-Castelli
Author

Alessandro-Castelli commented Nov 28, 2024

@vroulet I’ll explain why I think it’s optax. Basically, in my initial code, I was using jax and jaxlib 0.4.23, pennylane, and tensorflow 2.9.0, and I didn’t have any issues installing those versions. At some point, I needed more powerful optimizers like Adam to do some tasks, and that’s when I started using optax. Only from that moment, I encountered the issue:

2024-11-28 17:39:56.095832: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
cpu

Maybe it’s optax that doesn’t get along with tensorflow.

Another thing is that when I install optax, it updates jax and jaxlib to version 0.4.30, so maybe that’s the problem.
I really don’t know, I’ve tried many combinations of versions but I just can’t get it to work.
Do you have any suggestions?
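
The working jaxlib in this thread is reported as 0.4.23+cuda11.cudnn86, and the install is said to bump jax and jaxlib to 0.4.30. One possible reading, not confirmed in the thread, is that the upgraded jaxlib is no longer a CUDA build. A rough way to check is to look for the `+cuda...` local tag in the installed version string. This is only a heuristic sketch: the helper names are hypothetical, and it is an assumption that a missing tag means a CPU-only build, since other packaging schemes may not use the tag at all.

```python
from importlib import metadata


def has_cuda_local_tag(version):
    """True if the version carries a '+cuda...' local tag, e.g. '0.4.23+cuda11.cudnn86'."""
    _, _, local = version.partition("+")
    return local.lower().startswith("cuda")


def jaxlib_build_report():
    """Describe the installed jaxlib, if any, based on its version string alone."""
    try:
        version = metadata.version("jaxlib")
    except metadata.PackageNotFoundError:
        return "jaxlib is not installed"
    kind = "CUDA-tagged" if has_cuda_local_tag(version) else "no CUDA tag"
    return f"jaxlib {version}: {kind}"


if __name__ == "__main__":
    print(jaxlib_build_report())
```

The version string is only a hint; calling `jax.devices()` in the affected environment remains the direct check of what JAX can actually see.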

@fabianp
Member

fabianp commented Nov 29, 2024

I feel the pain, I've been there - versioning between cuda/jax/tensorflow is a mess.

I would suggest having a different virtualenv for jax-based and TF-based projects if you can ....
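
As a rough illustration of the separate-environment idea, the standard library's `venv` module can create an isolated environment from Python itself. This is a sketch only: the helper name `make_env` and the directory names are made up, and a conda environment (as used elsewhere in this thread) would serve the same purpose.

```python
import venv
from pathlib import Path


def make_env(path, with_pip=True):
    """Create a virtual environment at `path` and return its pyvenv.cfg path."""
    venv.create(path, with_pip=with_pip)
    return Path(path) / "pyvenv.cfg"


if __name__ == "__main__":
    # One environment per framework, so their pins cannot clash.
    for name in ("jax-env", "tf-env"):
        print(make_env(Path.cwd() / name))
```

Each environment then gets its own pinned install, so tensorflow's requirements never interact with the jax and jaxlib versions.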

@Alessandro-Castelli
Author

Hello @fabianp, I tried to do it, but I think that the real problem is the versioning between JAX and Optax. I tried many different Optax versions, but I didn't resolve my problem.

@fabianp
Member

fabianp commented Nov 29, 2024

Have you tried installing optax with --no-deps so it doesn't try to modify the other packages?

@Alessandro-Castelli
Author

Yes, but Optax has additional dependencies, and following this approach doesn't seem to work for Optax.
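
When a package is installed with --no-deps, its other requirements have to be installed by hand. The pip show output earlier in the thread lists absl-py, chex, jax, jaxlib, and numpy for optax 0.1.5. The sketch below reads a distribution's declared requirements from its installed metadata and reports the ones that are not installed; the helper names are hypothetical, the name parsing is deliberately simplistic, and version specifiers are not checked.

```python
import re
from importlib import metadata


def requirement_name(requirement):
    """Extract the distribution name from a requirement string such as 'chex>=1.0'."""
    match = re.match(r"[A-Za-z0-9][A-Za-z0-9._-]*", requirement.strip())
    return match.group(0) if match else ""


def missing_requirements(dist_name):
    """List the declared (non-extra) requirements of `dist_name` that are not installed."""
    try:
        declared = metadata.requires(dist_name) or []
    except metadata.PackageNotFoundError:
        return []  # the distribution itself is not installed
    missing = []
    for req in declared:
        if re.search(r"extra\s*==", req):
            continue  # optional extras are not needed for a basic install
        name = requirement_name(req)
        if not name:
            continue
        try:
            metadata.version(name)
        except metadata.PackageNotFoundError:
            missing.append(req)
    return missing


if __name__ == "__main__":
    # After `pip install --no-deps optax`, show what still needs installing.
    print(missing_requirements("optax"))
```

Anything it reports could then be installed one package at a time, again with --no-deps, so that jax and jaxlib stay at their pinned versions.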


3 participants