ArgumentError in Lux.jl with GPU-Accelerated Neural Network Using LuxCUDA and Zygote #1094
Comments
This seems to come from `ChainRules.OneElement`, which is not GPU compatible at all. Why do you need the `allowscalar`?
I did not want to use it, but without it I got this error message:
I could not figure out how to compute the gradient and Laplacian of a neural network on the GPU.
If you just want the Hessian, the correct way to do so is https://lux.csail.mit.edu/stable/manual/nested_autodiff (these approaches are GPU compatible). But if you want to use the Laplacian inside a loss function (i.e. you need third-order derivatives), we currently don't support that, at least not in a way I can recommend. The best way forward (in the near future) would be compiling the model with Reactant (https://lux.csail.mit.edu/stable/manual/compiling_lux_models).
That is right. I need to compute the neural network output, its gradient, and its Laplacian to evaluate the loss function used to optimize the parameters.
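For reference, the quantities this loss needs are the gradient ∇u(x) and the Laplacian Δu(x) = tr(∇²u(x)) of the scalar network output u with respect to its input x. Below is a minimal CPU-only sketch using ForwardDiff.jl on a hypothetical stand-in function `u` (not a Lux model). It only illustrates what is being computed; it is not the GPU-compatible nested-AD approach from the Lux manual linked above, and it does not address differentiating through the Laplacian during training.

```julia
using ForwardDiff, LinearAlgebra

# Hypothetical scalar field standing in for a network output u(x); not Lux code.
u(x) = sum(abs2, x) + x[1] * x[2]

x = [1.0, 2.0]

# Gradient ∇u(x) = [2x1 + x2, 2x2 + x1], i.e. [4.0, 5.0] at this point.
g = ForwardDiff.gradient(u, x)

# Laplacian Δu(x) = trace of the Hessian [[2, 1], [1, 2]], i.e. 4.0.
lap = tr(ForwardDiff.hessian(u, x))
```

Forming the full Hessian and taking its trace is the simplest way to state the Laplacian, but its cost grows with the input dimension; in a loss function the same quantity would also have to be differentiated again with respect to the parameters, which is the higher-order case described above as unsupported.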
When running a GPU-accelerated neural network model using Lux.jl, LuxCUDA.jl, and Zygote.jl, the program encounters an `ArgumentError` when attempting to compute gradients on GPU inputs using `Zygote.gradient`. The error message indicates a device mismatch, specifically an incompatibility between `CPUDevice` and `CUDADevice`. The error arises when using a `Chain` neural network with GPU parameters and attempting to compute gradients using Zygote.
Reproducible Example:
Error Output:
ExceptionStack output:
ExceptionStack.txt
Environment:
Lux v1.2.3
LuxCUDA v0.3.3
CUDA v5.5.2
Zygote v0.6.73
ComponentArrays v0.15.17