From da96d3d8339d75f0510d9ca736bdf732fa4e5767 Mon Sep 17 00:00:00 2001 From: Mason Protter Date: Fri, 11 Oct 2024 12:44:07 +0200 Subject: [PATCH] Explicitly mark which blox are supported by GraphDynamics (#454) * explicitly mark which blox are supported by GraphDynamics * remove guesses from algebraic states in discrete blox --------- Co-authored-by: harisorgn --- Project.toml | 2 +- docs/src/tutorials/ping_network.jl | 2 +- .../GraphDynamicsInterop.jl | 12 +++++ .../connection_interop.jl | 45 ++++++------------- src/GraphDynamicsInterop/neuron_interop.jl | 7 +++ src/blox/discrete.jl | 10 ++--- test/GraphDynamicsTests/test_suite.jl | 17 +++---- 7 files changed, 46 insertions(+), 49 deletions(-) diff --git a/Project.toml b/Project.toml index 32b12a12..10617220 100644 --- a/Project.toml +++ b/Project.toml @@ -53,7 +53,7 @@ Distributions = "0.25.102" ExponentialUtilities = "1" Flux = "0.14" ForwardDiff = "0.10" -GraphDynamics = "0.1.2" +GraphDynamics = "0.1.4" Graphs = "1" Interpolations = "0.14, 0.15" MetaGraphs = "0.7" diff --git a/docs/src/tutorials/ping_network.jl b/docs/src/tutorials/ping_network.jl index bac107f1..fdd3df33 100644 --- a/docs/src/tutorials/ping_network.jl +++ b/docs/src/tutorials/ping_network.jl @@ -93,7 +93,7 @@ for ni1 ∈ inhib end # ## Simulate the network -# Now that we have the neurons and the graph, we can simulate the network. We use the `system_from_graph` function to create a system of ODEs from the graph and then solve it using the DifferentialEquations.jl package. +# Now that we have the neurons and the graph, we can simulate the network. We use the `system_from_graph` function to create a system of ODEs from the graph and then solve it using the DifferentialEquations.jl package, but for performance scaling reasons we will use the experimental option `graphdynamics=true` which uses a separate compilation backend called [GraphDynamics.jl](https://github.com/Neuroblox/GraphDynamics.jl). The GraphDynamics.jl backend is still experimental, and may not yet support all of the standard Neuroblox features, such as those seen in the Spectral DCM tutorial. tspan = (0.0, 300.0) ## Time span for the simulation - run for 300ms to match the Börgers et al. [1] Figure 1. @named sys = system_from_graph(g, graphdynamics=true) ## Use GraphDynamics.jl otherwise this can be a very slow simulation diff --git a/src/GraphDynamicsInterop/GraphDynamicsInterop.jl b/src/GraphDynamicsInterop/GraphDynamicsInterop.jl index 36725f64..473c518a 100644 --- a/src/GraphDynamicsInterop/GraphDynamicsInterop.jl +++ b/src/GraphDynamicsInterop/GraphDynamicsInterop.jl @@ -314,6 +314,17 @@ Could not create a ordered subsystem layout in $(N_tries) attempts, this is like """) end +function check_all_supported_blox(g::MetaDiGraph) + unsupported_blox = filter(vertices(g)) do i + blox = get_blox(g, i) + !issupported(blox) + end + if !isempty(unsupported_blox) + v = unique(typeof.(unsupported_blox)) + error("Got unsupported Blox. The GraphDynamics backend is not compatible with blox of type $(join(v, ", "))") + end +end + """ graphsystem_from_graph(g::MetaDiGraph; sparsity_heuristic=1.0, sparse_length_cutoff=0) @@ -327,6 +338,7 @@ of connections, but only if the matrix is also longer than `sparse_length_cutoff situations where tiny matrices like (e.g. 5x5) get stored as sparse arrays rather than dense arrays. """ function graphsystem_from_graph(_g::MetaDiGraph; sparsity_heuristic=1.0, sparse_length_cutoff=0) + check_all_supported_blox(_g) g = flat_graph(_g) total_eltype = mapreduce(promote_type, vertices(g)) do i diff --git a/src/GraphDynamicsInterop/connection_interop.jl b/src/GraphDynamicsInterop/connection_interop.jl index f5c485ee..7261c138 100644 --- a/src/GraphDynamicsInterop/connection_interop.jl +++ b/src/GraphDynamicsInterop/connection_interop.jl @@ -191,10 +191,7 @@ function ((;w, w_gap, w_gap_rev)::HHConnection_GAP)(HH_src::Subsystem{HHNeuronIn end ##---------------------------------------------- -# struct NGEI_HHConnection -# w::Float64 -# end -# Base.zero(::Type{<:NGEI_HHConnection}) = NGEI_HHConnection(0.0) +# Next Generation EI function get_connection(asc_src::NextGenerationEIBlox, HH_dst::Union{HHNeuronExciBlox, HHNeuronInhibBlox}, kwargs) @@ -214,12 +211,7 @@ end #---------------------------------------------- -# Kuramoto Kuramoto -# struct KK_Conn{T} -# w::T -# end -# Base.zero(::Type{KK_Conn{T}}) where {T} = KK_Conn(zero(T)) - +# Kuramoto function get_connection(src::KuramotoOscillator, dst::KuramotoOscillator, kwargs) (;w_val, name) = generate_weight_param(src, dst, kwargs) (;conn=BasicConnection(w_val), names=[name]) @@ -359,22 +351,6 @@ function GraphDynamics.apply_discrete_event!(integrator, !isnothing(i_dst_exci) && for j_dst ∈ j_dsts_exci states[i_dst_exci][:S_GABA, j_dst] += 1 end - - - - # !isnothing(i_dst_inh) && GraphDynamics.tforeach(j_dsts_inh) do j_dst # for j_dst ∈ j_dsts_inh - # states[i_dst_inh][:S_AMPA, j_dst] += 1 - # end - # !isnothing(i_dst_exci) && GraphDynamics.tforeach(j_dsts_exci) do j_dst # for j_dst ∈ j_dsts_exci - # states[i_dst_exci][:S_AMPA, j_dst] += 1 - # end - # elseif states_src isa SubsystemStates{LIFInhNeuron} - # !isnothing(i_dst_inh) && GraphDynamics.tforeach(j_dsts_inh) do j_dst #for j_dst ∈ j_dsts_inh - # states[i_dst_inh][:S_GABA, j_dst] += 1 - # end - # !isnothing(i_dst_exci) && GraphDynamics.tforeach(j_dsts_exci) do j_dst #for j_dst ∈ j_dsts_exci - # states[i_dst_exci][:S_GABA, j_dst] += 1 - # end else error("this should be unreachable") end @@ -419,6 +395,7 @@ end components(blox::Union{LIFExciCircuitBlox, LIFInhCircuitBlox}) = blox.parts +issupported(::Union{LIFExciCircuitBlox, LIFInhCircuitBlox}) = true function blox_wiring_rule!(g, blox::Union{LIFExciCircuitBlox, LIFInhCircuitBlox}, v, kwargs) neurons = components(blox) for i ∈ eachindex(neurons) @@ -486,6 +463,7 @@ function blox_wiring_rule!(h, wta_src::WinnerTakeAllBlox, wta_dst::WinnerTakeAll end end +issupported(::WinnerTakeAllBlox) = true function blox_wiring_rule!(h, wta::WinnerTakeAllBlox, v, kwargs) i_inh = v[1] inh = wta.parts[1] @@ -507,6 +485,7 @@ end ##---------------------------------------------- # CorticalBlox +issupported(::CorticalBlox) = true components(c::CorticalBlox) = c.parts outer_nameof(c::CorticalBlox) = split(String(namespaced_nameof(c)), '₊') function blox_wiring_rule!(h, c::CorticalBlox, v, kwargs) @@ -560,6 +539,7 @@ end #---------------------------------------------- # Striatum_MSN_Adam +issupported(::Striatum_MSN_Adam) = true components(s::Striatum_MSN_Adam) = s.parts function blox_wiring_rule!(h, s::Striatum_MSN_Adam, v, kwargs) n_inh = s.parts @@ -577,6 +557,7 @@ end #---------------------------------------------- # Striatum_FSI_Adam +issupported(::Striatum_FSI_Adam) = true components(s::Striatum_FSI_Adam) = s.parts function blox_wiring_rule!(h, s::Striatum_FSI_Adam, v, kwargs) n_inh = s.parts @@ -626,6 +607,7 @@ end #---------------------------------------------- # GPe_Adam +issupported(::GPe_Adam) = true components(gpe::GPe_Adam) = gpe.parts function blox_wiring_rule!(h, gpe::GPe_Adam, v, kwargs) n_inh = gpe.parts @@ -643,6 +625,7 @@ end #---------------------------------------------- # STN_Adam +issupported(::STN_Adam) = true components(stn::STN_Adam) = stn.parts function blox_wiring_rule!(h, stn::STN_Adam, v, kwargs) n_inh = stn.parts @@ -894,7 +877,7 @@ GraphDynamics.must_run_before(::Type{Striosome}, ::Type{<:Union{TAN, SNc}}) = tr #---------------------------------------------- # Striatum - Striatum - +issupported(::Striatum) = true components(sta::Striatum) = sta.parts function blox_wiring_rule!(h, str::Striatum, v_src, kwargs) # no internal wiring @@ -1041,15 +1024,15 @@ end function (c::PINGConnection)(blox_src::Subsystem{PINGNeuronExci}, blox_dst::Subsystem{PINGNeuronInhib}) (; w, V_E) = c - (;s) = blox_src - (;V) = blox_dst + (;s) = blox_src + (;V) = blox_dst (; jcn = w * s * (V_E - V)) end function (c::PINGConnection)(blox_src::Subsystem{PINGNeuronInhib}, blox_dst::Subsystem{<:AbstractPINGNeuron}) (; w, V_I) = c - (;s) = blox_src - (;V) = blox_dst + (;s) = blox_src + (;V) = blox_dst (; jcn = w * s * (V_I - V)) end diff --git a/src/GraphDynamicsInterop/neuron_interop.jl b/src/GraphDynamicsInterop/neuron_interop.jl index 1b0c9fef..3126cc32 100644 --- a/src/GraphDynamicsInterop/neuron_interop.jl +++ b/src/GraphDynamicsInterop/neuron_interop.jl @@ -61,6 +61,7 @@ function define_neurons() input_init = NamedTuple{(input_syms...,)}(ntuple(i -> 0.0, length(inputs))) @eval begin + issupported(::$T) = true GraphDynamics.initialize_input(s::Subsystem{$T}) = $input_init function GraphDynamics.subsystem_differential((; $(p_and_s_syms...),)::Subsystem{$T}, ($(input_syms...),), t) Dneuron = SubsystemStates{$T}( @@ -150,6 +151,7 @@ for T ∈ [:LIFExciNeuron, :LIFInhNeuron] end end +issupported(::PoissonSpikeTrain) = true components(p::PoissonSpikeTrain) = (p,) function to_subsystem(s::PoissonSpikeTrain) states = SubsystemStates{PoissonSpikeTrain, Float64, @NamedTuple{}}((;)) @@ -163,6 +165,7 @@ GraphDynamics.subsystem_differential_requires_inputs(::Type{PoissonSpikeTrain}) #------------------------- # Kuramoto +issupported(::KuramotoOscillator) = true function to_subsystem(o::KuramotoOscillator) states = SubsystemStates{KuramotoOscillator}((;θ=0.0,)) params = SubsystemParams{KuramotoOscillator}((;ω=getdefault(o.odesystem.ω), ζ=getdefault(o.odesystem.ζ))) @@ -182,6 +185,7 @@ end #------------------------- # Matrisome +issupported(::Matrisome) = true components(m::Matrisome) = (m,) GraphDynamics.initialize_input(s::Subsystem{Matrisome}) = 0.0 function GraphDynamics.apply_subsystem_differential!(_, m::Subsystem{Matrisome}, jcn, t) @@ -208,6 +212,7 @@ end #------------------------- # Striosome +issupported(::Striosome) = true components(s::Striosome) = (s,) GraphDynamics.initialize_input(s::Subsystem{Striosome}) = 0.0 GraphDynamics.subsystem_differential(s::Subsystem{Striosome}, _, _) = SubsystemStates{Striosome}((;)) @@ -225,6 +230,7 @@ end #------------------------- # TAN +issupported(::TAN) = true components(t::TAN) = (t,) GraphDynamics.initialize_input(s::Subsystem{TAN}) = 0.0 function GraphDynamics.apply_subsystem_differential!(_, s::Subsystem{TAN}, jcn, t) @@ -242,6 +248,7 @@ end #------------------------- # SNc +issupported(::SNc) = true components(s::SNc) = (s,) GraphDynamics.initialize_input(s::Subsystem{SNc}) = (;jcn,) GraphDynamics.subsystem_differential(s::Subsystem{SNc}, _, _) = SubsystemStates{SNc}((;)) diff --git a/src/blox/discrete.jl b/src/blox/discrete.jl index 257c197c..aca753f4 100644 --- a/src/blox/discrete.jl +++ b/src/blox/discrete.jl @@ -58,7 +58,7 @@ struct TAN <: AbstractDiscrete namespace function TAN(; name, namespace=nothing, κ=100, λ=1) - sts = @variables R(t)=κ + sts = @variables R(t) ps = @parameters κ=κ λ=λ jcn=0 [input=true] eqs = [ R ~ min(κ, κ/(λ*jcn + sqrt(eps()))) @@ -78,11 +78,11 @@ struct SNc <: AbstractModulator t_event function SNc(; name, namespace=nothing, κ_DA=1, N_time_blocks=5, DA_reward=10, λ_DA=0.33, t_event=90.0) - sts = @variables R(t)=κ_DA R_(t)=κ_DA - ps = @parameters κ=κ_DA λ_DA=λ_DA jcn=0 [input=true] jcn_=0.0 #HACK: jcn_ stores the value of jcn at time t_event that can be accessed after the simulation + sts = @variables R(t) R_(t) + ps = @parameters κ=κ_DA λ=λ_DA jcn=0 [input=true] jcn_=0.0 #HACK: jcn_ stores the value of jcn at time t_event that can be accessed after the simulation eqs = [ - R ~ min(κ_DA, κ_DA/(λ_DA*jcn + sqrt(eps()))), - R_ ~ min(κ_DA, κ_DA/(λ_DA*jcn_ + sqrt(eps()))) + R ~ min(κ, κ/(λ*jcn + sqrt(eps()))), + R_ ~ min(κ, κ/(λ*jcn_ + sqrt(eps()))) ] R_cb = [[t_event + sqrt(eps(t_event))] => [jcn_ ~ jcn]] diff --git a/test/GraphDynamicsTests/test_suite.jl b/test/GraphDynamicsTests/test_suite.jl index 77f571ec..7179ece3 100644 --- a/test/GraphDynamicsTests/test_suite.jl +++ b/test/GraphDynamicsTests/test_suite.jl @@ -237,8 +237,7 @@ function test_compare_du_and_sols_ensemble(::Type{SDEProblem}, graph, tspan; rto dnoise = zero(u0) g(dnoise, u0, p, 1.1) - # This is broken because ImplicitEM uses ForwardDiff which doesnt' work currently - @test_broken solve(prob, ImplicitEM(), saveat = 0.01,reltol=1e-4,abstol=1e-4).retcode == ReturnCode.Success + @test solve(prob, ImplicitEM(), saveat = 0.01,reltol=1e-4,abstol=1e-4).retcode == ReturnCode.Success ens_prob = EnsembleProblem(prob) sols = solve(ens_prob, alg, EnsembleThreads(); trajectories) @@ -309,8 +308,7 @@ function test_compare_du_and_sols(::Type{SDEProblem}, graph, tspan; rtol, mtk=tr dnoise = zero(u0) g(dnoise, u0, p, 1.1) - # This is broken because ImplicitEM uses ForwardDiff which doesnt' work currently - @test_broken solve(prob, ImplicitEM(), saveat = 0.01,reltol=1e-4,abstol=1e-4).retcode == ReturnCode.Success + @test solve(prob, ImplicitEM(), saveat = 0.01,reltol=1e-4,abstol=1e-4).retcode == ReturnCode.Success sol = solve(prob, alg, saveat = 0.01) @test sol.retcode == ReturnCode.Success @@ -333,7 +331,6 @@ function test_compare_du_and_sols(::Type{SDEProblem}, graph, tspan; rtol, mtk=tr sol_reordered = map(state_names) do name sol[name][end] end - println() sol_reordered, collect(du), collect(dnoise) end @debug "" norm(sol_grp .- sol_mtk) / norm(sol_grp) @@ -509,7 +506,7 @@ function wta_tests() weight = 1.0 density = 0.25 - @testset "WinnerTakeAll network" begin + @testset "WinnerTakeAll network 1" begin g1 = let g = MetaDiGraph() @named wta1 = WinnerTakeAllBlox(;I_bg=I_bg_1, N_exci=N_exci_1, namespace) @named wta2 = WinnerTakeAllBlox(;I_bg=I_bg_2, N_exci=N_exci_2, namespace) @@ -538,7 +535,7 @@ function wta_tests() end test_compare_du_and_sols(ODEProblem, (g1, g2), tspan; rtol=1e-9, alg=Tsit5()) end - @testset "WinnerTakeAll network" begin + @testset "WinnerTakeAll network 2" begin density_1_2 = 0.5 connection_matrix_1_2 = rand(Bernoulli(density_1_2), N_exci_1, N_exci_2) g1 = let g = MetaDiGraph() @@ -629,16 +626,14 @@ function dbs_circuit() end function discrete() - #@testset "Discrete blox" begin - let - g = MetaDiGraph() + @testset "Discrete blox" begin + g = MetaDiGraph() @named n = HHNeuronExciBlox() @named m = Matrisome(t_event=8.0) @named t = TAN() add_blox!.((g,), (n, m, t)) add_edge!(g, 1, 2, :weight, 1.0) add_edge!(g, 3, 2, Dict(:weight => 0.1, :t_event=>5.0)) - test_compare_du_and_sols(ODEProblem, g, (0.0, 20.0), rtol=1e-5, alg=Tsit5()) end end