Type Instability in Lux Neural Network with CUDA and ComponentArrays #1109

Open

alexiscltrn opened this issue Nov 28, 2024 · 2 comments

@alexiscltrn

I am trying to implement a neural network architecture inspired by the design outlined in this paper. Here's the code I am working with:

using Lux, LuxCUDA, Random, ComponentArrays

function make_chain(input_dim, output_dim, hidden_dim, activation)
    return Chain(
        Parallel(.*, Dense(input_dim => hidden_dim, activation), Dense(input_dim => hidden_dim, activation)),
        SkipConnection(Chain(Parallel(.*, Dense(hidden_dim => hidden_dim, activation), Dense(hidden_dim => hidden_dim, activation)), Parallel(.*, Dense(hidden_dim => hidden_dim, activation), Dense(hidden_dim => hidden_dim, activation))), .+),
        SkipConnection(Chain(Parallel(.*, Dense(hidden_dim => hidden_dim, activation), Dense(hidden_dim => hidden_dim, activation)), Parallel(.*, Dense(hidden_dim => hidden_dim, activation), Dense(hidden_dim => hidden_dim, activation))), .+),
        Dense(hidden_dim => output_dim),
    )
end

chain = make_chain(4, 3, 128, tanh)

ps, st = Lux.setup(Random.default_rng(), chain)

gpu = Lux.gpu_device()
ps = ps |> ComponentArray |> gpu
st = st |> gpu

x = rand(Float32, 4, 100) |> gpu 

@code_warntype chain(x, ps, st)

Problem Description

When running the @code_warntype analysis on the chain function call, I observe that the output type is non-concrete. Here is a snippet of the warning:

Body::Tuple{Any, NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4), <:Tuple{@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}}, NamedTuple{(:layer_1, :layer_2), <:Tuple{@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}}, Any}}, NamedTuple{(:layer_1, :layer_2), <:Tuple{Any, Any}}, @NamedTuple{}}}}

The issue occurs specifically when ps is stored as a ComponentArray. The same call is type-stable on the CPU, and even on the GPU if ps is kept as the NamedTuple returned by Lux.setup instead of being wrapped in a ComponentArray. However, I need this setup to work with Optimization.jl, which requires using ComponentArray for parameter handling.
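As a minimal sketch of the comparison (assuming the chain, x, and gpu bindings from the reproducer above; Test.@inferred is just an alternative to @code_warntype for flagging the instability):

using Test  # for @inferred

ps, st = Lux.setup(Random.default_rng(), chain)

# Case 1: parameters kept as the NamedTuple from Lux.setup, moved to the GPU
ps_nt = ps |> gpu
st_gpu = st |> gpu
@inferred chain(x, ps_nt, st_gpu)   # reported to infer a concrete return type

# Case 2: parameters wrapped in a ComponentArray before moving to the GPU
ps_ca = ps |> ComponentArray |> gpu
@inferred chain(x, ps_ca, st_gpu)   # throws: the return type is non-concrete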

Here is a snippet of my environment:

[b0b7db55] ComponentArrays v0.15.19
[b2108857] Lux v1.4.0
[d0bbae9a] LuxCUDA v0.3.3
[f1d291b0] MLUtils v0.4.4
[7f7a1694] Optimization v4.0.5
[9a3f8284] Random v1.11.0

Request

Could you help identify the source of this type instability? Is there a workaround to make the output type concrete while maintaining the architecture as defined? Any insights or debugging tips would be greatly appreciated.

@avik-pal
Member

avik-pal commented Dec 3, 2024

I would recommend using https://github.com/JuliaDebug/Cthulhu.jl to debug this.
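A minimal sketch of how one might start with Cthulhu here (assuming the reproducer above is already loaded; @descend is Cthulhu's interactive entry point):

using Cthulhu

# Interactively descend into the inferred code for the unstable call;
# non-concrete return types in the output point at where inference gives up.
@descend chain(x, ps, st)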

@avik-pal
Member

avik-pal commented Dec 3, 2024

It is possible that this type instability originates from recursion, in which case the best option is to use https://lux.csail.mit.edu/stable/api/Lux/utilities#Lux.@compact instead of the built-in layers.
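A rough sketch of what one multiplicative block from the reproducer might look like with @compact (the gated_block, d1, and d2 names are illustrative, not part of the original model; the full chain would compose several of these plus the skip connections):

using Lux, Random

# One "gate" block (Dense .* Dense) written with Lux.@compact.
function gated_block(in_dim, out_dim, activation)
    return @compact(
        d1 = Dense(in_dim => out_dim, activation),
        d2 = Dense(in_dim => out_dim, activation)
    ) do x
        @return d1(x) .* d2(x)
    end
end

block = gated_block(4, 128, tanh)
ps, st = Lux.setup(Random.default_rng(), block)
y, _ = block(rand(Float32, 4, 100), ps, st)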
