Skip to content

Commit

Permalink
Explicitly mark which blox are supported by GraphDynamics (#454)
Browse files Browse the repository at this point in the history
* explicitly mark which blox are supported by GraphDynamics

* remove guesses from algebraic states in discrete blox

---------

Co-authored-by: harisorgn <organtzh@gmail.com>
  • Loading branch information
MasonProtter and harisorgn authored Oct 11, 2024
1 parent 64396d5 commit da96d3d
Show file tree
Hide file tree
Showing 7 changed files with 46 additions and 49 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ Distributions = "0.25.102"
ExponentialUtilities = "1"
Flux = "0.14"
ForwardDiff = "0.10"
GraphDynamics = "0.1.2"
GraphDynamics = "0.1.4"
Graphs = "1"
Interpolations = "0.14, 0.15"
MetaGraphs = "0.7"
Expand Down
2 changes: 1 addition & 1 deletion docs/src/tutorials/ping_network.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ for ni1 ∈ inhib
end

# ## Simulate the network
# Now that we have the neurons and the graph, we can simulate the network. We use the `system_from_graph` function to create a system of ODEs from the graph and then solve it using the DifferentialEquations.jl package.
# Now that we have the neurons and the graph, we can simulate the network. We use the `system_from_graph` function to create a system of ODEs from the graph and then solve it using the DifferentialEquations.jl package, but for performance scaling reasons we will use the experimental option `graphdynamics=true` which uses a separate compilation backend called [GraphDynamics.jl](https://github.com/Neuroblox/GraphDynamics.jl). The GraphDynamics.jl backend is still experimental, and may not yet support all of the standard Neuroblox features, such as those seen in the Spectral DCM tutorial.

tspan = (0.0, 300.0) ## Time span for the simulation - run for 300ms to match the Börgers et al. [1] Figure 1.
@named sys = system_from_graph(g, graphdynamics=true) ## Use GraphDynamics.jl otherwise this can be a very slow simulation
Expand Down
12 changes: 12 additions & 0 deletions src/GraphDynamicsInterop/GraphDynamicsInterop.jl
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,17 @@ Could not create a ordered subsystem layout in $(N_tries) attempts, this is like
""")
end

function check_all_supported_blox(g::MetaDiGraph)
unsupported_blox = filter(vertices(g)) do i
blox = get_blox(g, i)
!issupported(blox)
end
if !isempty(unsupported_blox)
v = unique(typeof.(unsupported_blox))
error("Got unsupported Blox. The GraphDynamics backend is not compatible with blox of type $(join(v, ", "))")
end
end


"""
graphsystem_from_graph(g::MetaDiGraph; sparsity_heuristic=1.0, sparse_length_cutoff=0)
Expand All @@ -327,6 +338,7 @@ of connections, but only if the matrix is also longer than `sparse_length_cutoff
situations where tiny matrices like (e.g. 5x5) get stored as sparse arrays rather than dense arrays.
"""
function graphsystem_from_graph(_g::MetaDiGraph; sparsity_heuristic=1.0, sparse_length_cutoff=0)
check_all_supported_blox(_g)
g = flat_graph(_g)

total_eltype = mapreduce(promote_type, vertices(g)) do i
Expand Down
45 changes: 14 additions & 31 deletions src/GraphDynamicsInterop/connection_interop.jl
Original file line number Diff line number Diff line change
Expand Up @@ -191,10 +191,7 @@ function ((;w, w_gap, w_gap_rev)::HHConnection_GAP)(HH_src::Subsystem{HHNeuronIn
end

##----------------------------------------------
# struct NGEI_HHConnection
# w::Float64
# end
# Base.zero(::Type{<:NGEI_HHConnection}) = NGEI_HHConnection(0.0)
# Next Generation EI
function get_connection(asc_src::NextGenerationEIBlox,
HH_dst::Union{HHNeuronExciBlox, HHNeuronInhibBlox},
kwargs)
Expand All @@ -214,12 +211,7 @@ end


#----------------------------------------------
# Kuramoto Kuramoto
# struct KK_Conn{T}
# w::T
# end
# Base.zero(::Type{KK_Conn{T}}) where {T} = KK_Conn(zero(T))

# Kuramoto
function get_connection(src::KuramotoOscillator, dst::KuramotoOscillator, kwargs)
(;w_val, name) = generate_weight_param(src, dst, kwargs)
(;conn=BasicConnection(w_val), names=[name])
Expand Down Expand Up @@ -359,22 +351,6 @@ function GraphDynamics.apply_discrete_event!(integrator,
!isnothing(i_dst_exci) && for j_dst j_dsts_exci
states[i_dst_exci][:S_GABA, j_dst] += 1
end



# !isnothing(i_dst_inh) && GraphDynamics.tforeach(j_dsts_inh) do j_dst # for j_dst ∈ j_dsts_inh
# states[i_dst_inh][:S_AMPA, j_dst] += 1
# end
# !isnothing(i_dst_exci) && GraphDynamics.tforeach(j_dsts_exci) do j_dst # for j_dst ∈ j_dsts_exci
# states[i_dst_exci][:S_AMPA, j_dst] += 1
# end
# elseif states_src isa SubsystemStates{LIFInhNeuron}
# !isnothing(i_dst_inh) && GraphDynamics.tforeach(j_dsts_inh) do j_dst #for j_dst ∈ j_dsts_inh
# states[i_dst_inh][:S_GABA, j_dst] += 1
# end
# !isnothing(i_dst_exci) && GraphDynamics.tforeach(j_dsts_exci) do j_dst #for j_dst ∈ j_dsts_exci
# states[i_dst_exci][:S_GABA, j_dst] += 1
# end
else
error("this should be unreachable")
end
Expand Down Expand Up @@ -419,6 +395,7 @@ end

components(blox::Union{LIFExciCircuitBlox, LIFInhCircuitBlox}) = blox.parts

issupported(::Union{LIFExciCircuitBlox, LIFInhCircuitBlox}) = true
function blox_wiring_rule!(g, blox::Union{LIFExciCircuitBlox, LIFInhCircuitBlox}, v, kwargs)
neurons = components(blox)
for i eachindex(neurons)
Expand Down Expand Up @@ -486,6 +463,7 @@ function blox_wiring_rule!(h, wta_src::WinnerTakeAllBlox, wta_dst::WinnerTakeAll
end
end

issupported(::WinnerTakeAllBlox) = true
function blox_wiring_rule!(h, wta::WinnerTakeAllBlox, v, kwargs)
i_inh = v[1]
inh = wta.parts[1]
Expand All @@ -507,6 +485,7 @@ end

##----------------------------------------------
# CorticalBlox
issupported(::CorticalBlox) = true
components(c::CorticalBlox) = c.parts
outer_nameof(c::CorticalBlox) = split(String(namespaced_nameof(c)), '')
function blox_wiring_rule!(h, c::CorticalBlox, v, kwargs)
Expand Down Expand Up @@ -560,6 +539,7 @@ end

#----------------------------------------------
# Striatum_MSN_Adam
issupported(::Striatum_MSN_Adam) = true
components(s::Striatum_MSN_Adam) = s.parts
function blox_wiring_rule!(h, s::Striatum_MSN_Adam, v, kwargs)
n_inh = s.parts
Expand All @@ -577,6 +557,7 @@ end

#----------------------------------------------
# Striatum_FSI_Adam
issupported(::Striatum_FSI_Adam) = true
components(s::Striatum_FSI_Adam) = s.parts
function blox_wiring_rule!(h, s::Striatum_FSI_Adam, v, kwargs)
n_inh = s.parts
Expand Down Expand Up @@ -626,6 +607,7 @@ end

#----------------------------------------------
# GPe_Adam
issupported(::GPe_Adam) = true
components(gpe::GPe_Adam) = gpe.parts
function blox_wiring_rule!(h, gpe::GPe_Adam, v, kwargs)
n_inh = gpe.parts
Expand All @@ -643,6 +625,7 @@ end

#----------------------------------------------
# STN_Adam
issupported(::STN_Adam) = true
components(stn::STN_Adam) = stn.parts
function blox_wiring_rule!(h, stn::STN_Adam, v, kwargs)
n_inh = stn.parts
Expand Down Expand Up @@ -894,7 +877,7 @@ GraphDynamics.must_run_before(::Type{Striosome}, ::Type{<:Union{TAN, SNc}}) = tr

#----------------------------------------------
# Striatum - Striatum

issupported(::Striatum) = true
components(sta::Striatum) = sta.parts
function blox_wiring_rule!(h, str::Striatum, v_src, kwargs)
# no internal wiring
Expand Down Expand Up @@ -1041,15 +1024,15 @@ end

function (c::PINGConnection)(blox_src::Subsystem{PINGNeuronExci}, blox_dst::Subsystem{PINGNeuronInhib})
(; w, V_E) = c
(;s) = blox_src
(;V) = blox_dst
(;s) = blox_src
(;V) = blox_dst
(; jcn = w * s * (V_E - V))
end

function (c::PINGConnection)(blox_src::Subsystem{PINGNeuronInhib}, blox_dst::Subsystem{<:AbstractPINGNeuron})
(; w, V_I) = c
(;s) = blox_src
(;V) = blox_dst
(;s) = blox_src
(;V) = blox_dst
(; jcn = w * s * (V_I - V))
end

7 changes: 7 additions & 0 deletions src/GraphDynamicsInterop/neuron_interop.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ function define_neurons()
input_init = NamedTuple{(input_syms...,)}(ntuple(i -> 0.0, length(inputs)))

@eval begin
issupported(::$T) = true
GraphDynamics.initialize_input(s::Subsystem{$T}) = $input_init
function GraphDynamics.subsystem_differential((; $(p_and_s_syms...),)::Subsystem{$T}, ($(input_syms...),), t)
Dneuron = SubsystemStates{$T}(
Expand Down Expand Up @@ -150,6 +151,7 @@ for T ∈ [:LIFExciNeuron, :LIFInhNeuron]
end
end

issupported(::PoissonSpikeTrain) = true
components(p::PoissonSpikeTrain) = (p,)
function to_subsystem(s::PoissonSpikeTrain)
states = SubsystemStates{PoissonSpikeTrain, Float64, @NamedTuple{}}((;))
Expand All @@ -163,6 +165,7 @@ GraphDynamics.subsystem_differential_requires_inputs(::Type{PoissonSpikeTrain})

#-------------------------
# Kuramoto
issupported(::KuramotoOscillator) = true
function to_subsystem(o::KuramotoOscillator)
states = SubsystemStates{KuramotoOscillator}((;θ=0.0,))
params = SubsystemParams{KuramotoOscillator}((;ω=getdefault(o.odesystem.ω), ζ=getdefault(o.odesystem.ζ)))
Expand All @@ -182,6 +185,7 @@ end

#-------------------------
# Matrisome
issupported(::Matrisome) = true
components(m::Matrisome) = (m,)
GraphDynamics.initialize_input(s::Subsystem{Matrisome}) = 0.0
function GraphDynamics.apply_subsystem_differential!(_, m::Subsystem{Matrisome}, jcn, t)
Expand All @@ -208,6 +212,7 @@ end

#-------------------------
# Striosome
issupported(::Striosome) = true
components(s::Striosome) = (s,)
GraphDynamics.initialize_input(s::Subsystem{Striosome}) = 0.0
GraphDynamics.subsystem_differential(s::Subsystem{Striosome}, _, _) = SubsystemStates{Striosome}((;))
Expand All @@ -225,6 +230,7 @@ end

#-------------------------
# TAN
issupported(::TAN) = true
components(t::TAN) = (t,)
GraphDynamics.initialize_input(s::Subsystem{TAN}) = 0.0
function GraphDynamics.apply_subsystem_differential!(_, s::Subsystem{TAN}, jcn, t)
Expand All @@ -242,6 +248,7 @@ end

#-------------------------
# SNc
issupported(::SNc) = true
components(s::SNc) = (s,)
GraphDynamics.initialize_input(s::Subsystem{SNc}) = (;jcn,)
GraphDynamics.subsystem_differential(s::Subsystem{SNc}, _, _) = SubsystemStates{SNc}((;))
Expand Down
10 changes: 5 additions & 5 deletions src/blox/discrete.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ struct TAN <: AbstractDiscrete
namespace

function TAN(; name, namespace=nothing, κ=100, λ=1)
sts = @variables R(t)=κ
sts = @variables R(t)
ps = @parameters κ=κ λ=λ jcn=0 [input=true]
eqs = [
R ~ min(κ, κ/*jcn + sqrt(eps())))
Expand All @@ -78,11 +78,11 @@ struct SNc <: AbstractModulator
t_event

function SNc(; name, namespace=nothing, κ_DA=1, N_time_blocks=5, DA_reward=10, λ_DA=0.33, t_event=90.0)
sts = @variables R(t)=κ_DA R_(t)=κ_DA
ps = @parameters κ=κ_DA λ_DA=λ_DA jcn=0 [input=true] jcn_=0.0 #HACK: jcn_ stores the value of jcn at time t_event that can be accessed after the simulation
sts = @variables R(t) R_(t)
ps = @parameters κ=κ_DA λ=λ_DA jcn=0 [input=true] jcn_=0.0 #HACK: jcn_ stores the value of jcn at time t_event that can be accessed after the simulation
eqs = [
R ~ min(κ_DA, κ_DA/(λ_DA*jcn + sqrt(eps()))),
R_ ~ min(κ_DA, κ_DA/(λ_DA*jcn_ + sqrt(eps())))
R ~ min(κ, κ/*jcn + sqrt(eps()))),
R_ ~ min(κ, κ/*jcn_ + sqrt(eps())))
]

R_cb = [[t_event + sqrt(eps(t_event))] => [jcn_ ~ jcn]]
Expand Down
17 changes: 6 additions & 11 deletions test/GraphDynamicsTests/test_suite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -237,8 +237,7 @@ function test_compare_du_and_sols_ensemble(::Type{SDEProblem}, graph, tspan; rto
dnoise = zero(u0)
g(dnoise, u0, p, 1.1)

# This is broken because ImplicitEM uses ForwardDiff which doesnt' work currently
@test_broken solve(prob, ImplicitEM(), saveat = 0.01,reltol=1e-4,abstol=1e-4).retcode == ReturnCode.Success
@test solve(prob, ImplicitEM(), saveat = 0.01,reltol=1e-4,abstol=1e-4).retcode == ReturnCode.Success

ens_prob = EnsembleProblem(prob)
sols = solve(ens_prob, alg, EnsembleThreads(); trajectories)
Expand Down Expand Up @@ -309,8 +308,7 @@ function test_compare_du_and_sols(::Type{SDEProblem}, graph, tspan; rtol, mtk=tr
dnoise = zero(u0)
g(dnoise, u0, p, 1.1)

# This is broken because ImplicitEM uses ForwardDiff which doesnt' work currently
@test_broken solve(prob, ImplicitEM(), saveat = 0.01,reltol=1e-4,abstol=1e-4).retcode == ReturnCode.Success
@test solve(prob, ImplicitEM(), saveat = 0.01,reltol=1e-4,abstol=1e-4).retcode == ReturnCode.Success

sol = solve(prob, alg, saveat = 0.01)
@test sol.retcode == ReturnCode.Success
Expand All @@ -333,7 +331,6 @@ function test_compare_du_and_sols(::Type{SDEProblem}, graph, tspan; rtol, mtk=tr
sol_reordered = map(state_names) do name
sol[name][end]
end
println()
sol_reordered, collect(du), collect(dnoise)
end
@debug "" norm(sol_grp .- sol_mtk) / norm(sol_grp)
Expand Down Expand Up @@ -509,7 +506,7 @@ function wta_tests()
weight = 1.0
density = 0.25

@testset "WinnerTakeAll network" begin
@testset "WinnerTakeAll network 1" begin
g1 = let g = MetaDiGraph()
@named wta1 = WinnerTakeAllBlox(;I_bg=I_bg_1, N_exci=N_exci_1, namespace)
@named wta2 = WinnerTakeAllBlox(;I_bg=I_bg_2, N_exci=N_exci_2, namespace)
Expand Down Expand Up @@ -538,7 +535,7 @@ function wta_tests()
end
test_compare_du_and_sols(ODEProblem, (g1, g2), tspan; rtol=1e-9, alg=Tsit5())
end
@testset "WinnerTakeAll network" begin
@testset "WinnerTakeAll network 2" begin
density_1_2 = 0.5
connection_matrix_1_2 = rand(Bernoulli(density_1_2), N_exci_1, N_exci_2)
g1 = let g = MetaDiGraph()
Expand Down Expand Up @@ -629,16 +626,14 @@ function dbs_circuit()
end

function discrete()
#@testset "Discrete blox" begin
let
g = MetaDiGraph()
@testset "Discrete blox" begin
g = MetaDiGraph()
@named n = HHNeuronExciBlox()
@named m = Matrisome(t_event=8.0)
@named t = TAN()
add_blox!.((g,), (n, m, t))
add_edge!(g, 1, 2, :weight, 1.0)
add_edge!(g, 3, 2, Dict(:weight => 0.1, :t_event=>5.0))

test_compare_du_and_sols(ODEProblem, g, (0.0, 20.0), rtol=1e-5, alg=Tsit5())
end
end
Expand Down

0 comments on commit da96d3d

Please sign in to comment.