
ArgumentError in Lux.jl with GPU-Accelerated Neural Network Using LuxCUDA and Zygote #1094

Open
aligurbu opened this issue Nov 19, 2024 · 4 comments · May be fixed by #1097

Comments

@aligurbu

When running a GPU-accelerated neural network model using Lux.jl, LuxCUDA.jl, and Zygote.jl, the program encounters an ArgumentError when attempting to compute gradients on GPU inputs with Zygote.gradient. The error message indicates a device mismatch, specifically an incompatibility between CPUDevice and CUDADevice.

The error arises when using a Chain neural network with GPU parameters and attempting to compute gradients using Zygote.

Reproducible Example:

using Lux
using LuxCUDA
using CUDA
using ComponentArrays
using Random
using Zygote

# Setup
const gpud = gpu_device()
rng = Random.default_rng()
Random.seed!(rng, 0)

# Neural network definition
nn = Chain(
    Dense(3, 20, σ),
    Dense(20, 10, σ),
    Dense(10, 1, tanh)
)

# Initialize parameters
parameters, layer_states = Lux.setup(rng, nn)
gpu_parameters = parameters |> ComponentArray |> gpud

# GPU function
gpu_NN(x) = nn(x, gpu_parameters, layer_states)[1]

# Data points
points = rand(rng, Float32, 3, 10)
gpu_points = CuArray(points)

# Gradient computation
CUDA.allowscalar() do
    for kk in axes(gpu_points, 2)
        r = gpu_points[:, kk]
        φ = gpu_NN(r)[1]
        ∇φ = Zygote.gradient(s -> gpu_NN(s)[1], r)[1] # Fails here
    end
end

Error Output:

ArgumentError: Objects are on devices with different types: CPUDevice and CUDADevice.
Stacktrace:
  [1] combine_devices(T1::Type{CPUDevice}, T2::Type{CUDADevice})
    @ MLDataDevices.Internal C:\Users\aligu\.julia\packages\MLDataDevices\hoL1S\src\internal.jl:127
  [2] macro expansion
    @ C:\Users\aligu\.julia\packages\MLDataDevices\hoL1S\src\internal.jl:205 [inlined]
  [3] unrolled_mapreduce
    @ C:\Users\aligu\.julia\packages\MLDataDevices\hoL1S\src\internal.jl:192 [inlined]
  [4] unrolled_mapreduce(f::typeof(get_device_type), op::typeof(MLDataDevices.Internal.combine_devices), itr::Tuple{…})
    @ MLDataDevices.Internal C:\Users\aligu\.julia\packages\MLDataDevices\hoL1S\src\internal.jl:183
  [5] get_device_type(x::Tuple{Base.ReshapedArray{…}, CuArray{…}})
    @ MLDataDevices.Internal C:\Users\aligu\.julia\packages\MLDataDevices\hoL1S\src\internal.jl:155
  [6] get_device_type(x::Tuple{Base.ReshapedArray{…}, CuArray{…}})
    @ MLDataDevices C:\Users\aligu\.julia\packages\MLDataDevices\hoL1S\src\public.jl:388
  [7] internal_operation_mode(xs::Tuple{Base.ReshapedArray{…}, CuArray{…}})
    @ LuxLib C:\Users\aligu\.julia\packages\LuxLib\wiiF1\src\traits.jl:210
  [8] ∇activation(Δ::Base.ReshapedArray{…}, out::CuArray{…}, act::typeof(tanh_fast), x::LuxLib.Utils.NotaNumber)
    @ LuxLib.Impl C:\Users\aligu\.julia\packages\LuxLib\wiiF1\src\impl\activation.jl:107
  [9] (::LuxLib.Impl.var"#78#81"{…})(Δ::Base.ReshapedArray{…})
    @ LuxLib.Impl C:\Users\aligu\.julia\packages\LuxLib\wiiF1\src\impl\dense.jl:51  
 [10] ZBack
    @ C:\Users\aligu\.julia\packages\Zygote\nyzjS\src\compiler\chainrules.jl:212 [inlined]
 [11] fused_dense
    @ C:\Users\aligu\.julia\packages\LuxLib\wiiF1\src\impl\dense.jl:11 [inlined]    
 [12] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Base.ReshapedArray{Float32, 2, ChainRules.OneElement{…}, Tuple{}})
    @ Zygote C:\Users\aligu\.julia\packages\Zygote\nyzjS\src\compiler\interface2.jl:0
 [13] fused_dense_bias_activation
    @ C:\Users\aligu\.julia\packages\LuxLib\wiiF1\src\api\dense.jl:35 [inlined]     
 [14] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Base.ReshapedArray{Float32, 2, ChainRules.OneElement{…}, Tuple{}})
    @ Zygote C:\Users\aligu\.julia\packages\Zygote\nyzjS\src\compiler\interface2.jl:0
 [15] Dense
    @ C:\Users\aligu\.julia\packages\Lux\gmUbf\src\layers\basic.jl:343 [inlined]    
 [16] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{ChainRules.OneElement{…}, Nothing})
    @ Zygote C:\Users\aligu\.julia\packages\Zygote\nyzjS\src\compiler\interface2.jl:0
 [17] apply
    @ C:\Users\aligu\.julia\packages\LuxCore\SN4dl\src\LuxCore.jl:155 [inlined]     
 [18] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{ChainRules.OneElement{…}, Nothing})
    @ Zygote C:\Users\aligu\.julia\packages\Zygote\nyzjS\src\compiler\interface2.jl:0
 [19] applychain
    @ C:\Users\aligu\.julia\packages\Lux\gmUbf\src\layers\containers.jl:0 [inlined] 
 [20] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{ChainRules.OneElement{…}, Nothing})
    @ Zygote C:\Users\aligu\.julia\packages\Zygote\nyzjS\src\compiler\interface2.jl:0
 [21] Chain
    @ C:\Users\aligu\.julia\packages\Lux\gmUbf\src\layers\containers.jl:480 [inlined]
 [22] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{ChainRules.OneElement{…}, Nothing})
    @ Zygote C:\Users\aligu\.julia\packages\Zygote\nyzjS\src\compiler\interface2.jl:0
 [23] gpu_NN
    @ d:\Repositories\LearningJulia\MySandBoxJulia\scripts\Lux_jl\Testing.jl:162 [inlined]
 [24] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::ChainRules.OneElement{Float32, 1, Tuple{…}, Tuple{…}})
    @ Zygote C:\Users\aligu\.julia\packages\Zygote\nyzjS\src\compiler\interface2.jl:0
 [25] #16
    @ d:\Repositories\LearningJulia\MySandBoxJulia\scripts\Lux_jl\Testing.jl:173 [inlined]
 [26] (::Zygote.Pullback{Tuple{var"#16#18", CuArray{…}}, Tuple{Zygote.Pullback{…}, Zygote.Pullback{…}}})(Δ::Float32)
    @ Zygote C:\Users\aligu\.julia\packages\Zygote\nyzjS\src\compiler\interface2.jl:0
 [27] (::Zygote.var"#78#79"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Float32)
    @ Zygote C:\Users\aligu\.julia\packages\Zygote\nyzjS\src\compiler\interface.jl:91
 [28] gradient(f::Function, args::CuArray{Float32, 1, CUDA.DeviceMemory})
    @ Zygote C:\Users\aligu\.julia\packages\Zygote\nyzjS\src\compiler\interface.jl:148
 [29] (::var"#15#17")()
    @ Main d:\Repositories\LearningJulia\MySandBoxJulia\scripts\Lux_jl\Testing.jl:173
 [30] task_local_storage(body::var"#15#17", key::Symbol, val::GPUArraysCore.ScalarIndexing)
    @ Base .\task.jl:297
 [31] allowscalar(f::Function)
    @ GPUArraysCore C:\Users\aligu\.julia\packages\GPUArraysCore\GMsgk\src\GPUArraysCore.jl:183
 [32] top-level scope
    @ d:\Repositories\LearningJulia\MySandBoxJulia\scripts\Lux_jl\Testing.jl:169    
Some type information was truncated. Use `show(err)` to see complete types.

ExceptionStack output:
ExceptionStack.txt

Environment:

julia> versioninfo()
Julia Version 1.10.4
Commit 48d4fd4843 (2024-06-04 10:41 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Windows (x86_64-w64-mingw32)
  CPU: 48 × AMD Ryzen Threadripper PRO 5965WX 24-Cores
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, znver3)
Threads: 48 default, 0 interactive, 24 GC (on 48 virtual cores)
Environment:
  JULIA_EDITOR = code
  JULIA_NUM_THREADS = 48

Lux v1.2.3
LuxCUDA v0.3.3
CUDA v5.5.2
Zygote v0.6.73
ComponentArrays v0.15.17

julia> CUDA.versioninfo()
CUDA runtime 12.6, artifact installation
CUDA driver 12.5
NVIDIA driver 556.18.0

CUDA libraries:
- CUBLAS: 12.6.3
- CURAND: 10.3.7
- CUFFT: 11.3.0
- CUSOLVER: 11.7.1
- CUSPARSE: 12.5.4
- CUPTI: 2024.3.2 (API 24.0.0)
- NVML: 12.0.0+556.18

Julia packages:
- CUDA: 5.5.2
- CUDA_Driver_jll: 0.10.3+0
- CUDA_Runtime_jll: 0.15.3+0

Toolchain:
- Julia: 1.10.4
- LLVM: 15.0.7

2 devices:
  0: NVIDIA RTX A4000 (sm_86, 11.093 GiB / 15.992 GiB available)
  1: NVIDIA RTX A4000 (sm_86, 11.094 GiB / 15.992 GiB available)
@avik-pal
Member

This seems to come from ChainRules.OneElement, which is not GPU compatible. Why do you need the `allowscalar`?
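
A minimal sketch of one possible workaround, assuming the diagnosis above and the setup from the reproducible example (the `gpu_NN_sum` helper is illustrative, not from this thread): sum the single-element output instead of indexing it with `[1]`, so Zygote never builds a `OneElement` cotangent and `allowscalar` is not needed.

# Illustrative sketch: avoid `[1]` on the network output, which is what
# produces the ChainRules.OneElement cotangent in the stacktrace above.
gpu_NN_sum(x) = sum(first(nn(x, gpu_parameters, layer_states)))

for kk in axes(gpu_points, 2)
    r = gpu_points[:, kk]                      # column slice, no scalar indexing
    φ = gpu_NN_sum(r)                          # equals gpu_NN(r)[1] for the 1-element output
    ∇φ = only(Zygote.gradient(gpu_NN_sum, r))  # gradient stays a CuArray
end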

@aligurbu
Author

I did not want to use `allowscalar` in the first place. I tried:

∇φ = Zygote.gradient(s -> gpu_NN(s)[1], gpu_points)[1]

But I got this error message:

ERROR: Scalar indexing is disallowed.
Invocation of getindex resulted in scalar indexing of a GPU array.
This is typically caused by calling an iterating implementation of a method.        
Such implementations *do not* execute on the GPU, but very slowly on the CPU,       
and therefore should be avoided.

If you want to allow scalar iteration, use `allowscalar` or `@allowscalar`
to enable scalar iteration globally or for the operations in question.

I could not figure out how to compute the gradient and Laplacian of a neural network on the GPU.
I also tried the following lines of code to compute the Laplacian:

∇²φ = ForwardDiff.hessian(s -> gpu_NN(s)[1], gpu_points) |> diag |> sum
∇²φ = Zygote.hessian(s -> gpu_NN(s)[1], gpu_points) |> diag |> sum
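
For the CPU case, a minimal untested sketch of the same Laplacian computed without indexing the output, assuming the setup from the reproducible example (the `cpu_NN` helper is illustrative; `tr` from LinearAlgebra is equivalent to `diag |> sum`):

using ForwardDiff
using LinearAlgebra                 # for tr

cpu_NN(x) = sum(first(nn(x, parameters, layer_states)))  # scalar output, CPU parameters

r = points[:, 1]                    # one sample point on the CPU
H = ForwardDiff.hessian(cpu_NN, r)  # 3×3 Hessian via forward mode
∇²φ = tr(H)                         # Laplacian = trace of the Hessian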

@avik-pal
Copy link
Member

If you just want the Hessian, the correct way to do so would be https://lux.csail.mit.edu/stable/manual/nested_autodiff (these are GPU compatible).

But if you want to use the Laplacian inside a loss function (i.e. you want to compute 3rd-order derivatives), we currently don't support that (at least not in a way I can recommend you use). The best way forward (in the near future) would be compiling the model with Reactant (https://lux.csail.mit.edu/stable/manual/compiling_lux_models)
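
For reference, a rough sketch of the second-order pattern that the linked manual describes: an input-gradient term inside a loss, differentiated with respect to the parameters via Zygote. It assumes the setup from the reproducible example; the `StatefulLuxLayer` usage and exact requirements should be checked against that manual, and the loss itself is illustrative only.

using ForwardDiff

function loss_with_grad_penalty(ps)
    smodel = StatefulLuxLayer{true}(nn, ps, layer_states)        # wraps parameters and states
    ∇φ = ForwardDiff.gradient(x -> sum(smodel(x)), points[:, 1]) # inner (forward-mode) gradient wrt the input
    return sum(abs2, ∇φ)                                         # penalize the input-gradient norm
end

Zygote.gradient(loss_with_grad_penalty, parameters)              # outer (reverse-mode) gradient wrt the parameters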

@aligurbu
Author

That is right. I need to compute the neural network function, its gradient, and its Laplacian to evaluate the loss function used to optimize the parameters.
A couple of days ago I thought I had managed to train this neural network on the CPU, but it was very slow, so I started working on a GPU implementation. Now I cannot make the CPU version work either. 😭
Anyway, thanks a lot for your help and time.

@avik-pal avik-pal linked a pull request Nov 25, 2024 that will close this issue