From 0d1f1c980485d0a68d4d172f5a6588c875156027 Mon Sep 17 00:00:00 2001 From: agchesebro <76024790+agchesebro@users.noreply.github.com> Date: Wed, 15 Jan 2025 16:59:32 -0500 Subject: [PATCH 01/15] Initial VdP moving --- src/blox/neural_mass.jl | 25 ++++++++++++++++++++++++- src/blox/van_der_pol.jl | 11 ----------- 2 files changed, 24 insertions(+), 12 deletions(-) delete mode 100644 src/blox/van_der_pol.jl diff --git a/src/blox/neural_mass.jl b/src/blox/neural_mass.jl index 0b6166e0..a3830fed 100644 --- a/src/blox/neural_mass.jl +++ b/src/blox/neural_mass.jl @@ -570,4 +570,27 @@ struct QIF_PING_NGNMM <: NeuralMassBlox new(p, sys, namespace) end -end \ No newline at end of file +end + +struct VanderPol <: NeuralMassBlox + params + system + namespace + + function VanderPol(; + name, + namespace=nothing, + θ=1.0, + ϕ=0.1) + p = paramscoping(θ=θ, ϕ=ϕ) + θ, ϕ = p + sts = @variables x(t)=0.0 [output=true] y(t)=0.0 jcn(t) [input=true] + @brownian ξ + + eqs = [D(x) ~ y, + D(y) ~ θ*(1-x^2)*y - x + ϕ*ξ + jcn] + + sys = System(eqs, t, sts, p; name=name) + new(p, sys, namespace) + end +end diff --git a/src/blox/van_der_pol.jl b/src/blox/van_der_pol.jl deleted file mode 100644 index 9d99d268..00000000 --- a/src/blox/van_der_pol.jl +++ /dev/null @@ -1,11 +0,0 @@ -function van_der_pol(;name, θ=1.0,ϕ=0.1) - params = @parameters θ=θ ϕ=ϕ - sts = @variables x(t) y(t) - - eqs = [D(x) ~ y, - D(y) ~ θ*(1-x^2)*y - x] - - noiseeqs = [ϕ,ϕ] - - return SDESystem(eqs,noiseeqs,t,sts,params; name=name) -end From b5ff31d9deff4b162bfa3cb02a645eecb7dc7ee2 Mon Sep 17 00:00:00 2001 From: agchesebro <76024790+agchesebro@users.noreply.github.com> Date: Wed, 15 Jan 2025 17:16:27 -0500 Subject: [PATCH 02/15] Split noise vs no noise VdP and add constructor --- src/blox/neural_mass.jl | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/src/blox/neural_mass.jl b/src/blox/neural_mass.jl index a3830fed..016cb644 100644 --- a/src/blox/neural_mass.jl +++ b/src/blox/neural_mass.jl @@ -572,12 +572,46 @@ struct QIF_PING_NGNMM <: NeuralMassBlox end end +function van_der_pol(;name, + namespace=nothing, + θ=1.0, + ϕ=0.1, + noise=false) + if noise + return VanderPolNoise(name=name, namespace=namespace, θ=θ, ϕ=ϕ) + else + return VanderPol(name=name, namespace=namespace, θ=θ) + end +end + + struct VanderPol <: NeuralMassBlox params system namespace function VanderPol(; + name, + namespace=nothing, + θ=1.0) + p = paramscoping(θ=θ, ϕ=ϕ) + θ, ϕ = p + sts = @variables x(t)=0.0 [output=true] y(t)=0.0 jcn(t) [input=true] + + eqs = [D(x) ~ y, + D(y) ~ θ*(1-x^2)*y - x + jcn] + + sys = System(eqs, t, sts, p; name=name) + new(p, sys, namespace) + end +end + +struct VanderPolNoise <: NeuralMassBlox + params + system + namespace + + function VanderPolNoise(; name, namespace=nothing, θ=1.0, From b051eac8bca6af4988c79be995d18f1f525199c2 Mon Sep 17 00:00:00 2001 From: agchesebro <76024790+agchesebro@users.noreply.github.com> Date: Wed, 15 Jan 2025 17:41:48 -0500 Subject: [PATCH 03/15] Cleanup and add test --- Project.toml | 1 + src/Neuroblox.jl | 1 - src/blox/neural_mass.jl | 4 ++-- test/components.jl | 19 +++++++++++++++++++ 4 files changed, 22 insertions(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index 0ff65922..d68a1b09 100644 --- a/Project.toml +++ b/Project.toml @@ -6,6 +6,7 @@ version = "0.5.7" [deps] Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" +CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0" ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" Colors = "5ae59095-9a9b-59fe-a467-6f913c188581" Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" diff --git a/src/Neuroblox.jl b/src/Neuroblox.jl index 97f40932..10363fe2 100644 --- a/src/Neuroblox.jl +++ b/src/Neuroblox.jl @@ -104,7 +104,6 @@ include("blox/cortical.jl") include("blox/canonicalmicrocircuit.jl") include("blox/neuron_models.jl") include("blox/DBS_Model_Blox_Adam_Brown.jl") -include("blox/van_der_pol.jl") include("blox/ts_outputs.jl") include("blox/sources.jl") include("blox/DBS_sources.jl") diff --git a/src/blox/neural_mass.jl b/src/blox/neural_mass.jl index 016cb644..71608146 100644 --- a/src/blox/neural_mass.jl +++ b/src/blox/neural_mass.jl @@ -594,8 +594,8 @@ struct VanderPol <: NeuralMassBlox name, namespace=nothing, θ=1.0) - p = paramscoping(θ=θ, ϕ=ϕ) - θ, ϕ = p + p = paramscoping(θ=θ) + θ = p[1] sts = @variables x(t)=0.0 [output=true] y(t)=0.0 jcn(t) [input=true] eqs = [D(x) ~ y, diff --git a/test/components.jl b/test/components.jl index b2970893..4e3934db 100644 --- a/test/components.jl +++ b/test/components.jl @@ -786,6 +786,25 @@ end @test sol.retcode == ReturnCode.Success end +@testset "VdP" begin + Random.seed!(1234) + @named vdp = van_der_pol() + g = MetaDiGraph() + add_blox!(g, vdp) + @named sys = system_from_graph(g) + prob = ODEProblem(sys, [0.0, 0.1], (0.0, 20.0), []) + sol = solve(prob,Tsit5()) + @test sol.retcode == ReturnCode.Success + + @named vdp = van_der_pol(noise=true) + g = MetaDiGraph() + add_blox!(g, vdp) + @named sys = system_from_graph(g) + prob = SDEProblem(sys, [0.0, 0.1], (0.0, 20.0), []) + sol = solve(prob, RKMil()) + @test sol.retcode == ReturnCode.Success +end + @testset "DBS circuit firing rates" begin @testset "Striatum_MSN_Adam" begin Random.seed!(1234) From 9a351b7ff33ccfc3eb181132cf99c1c3a2755b2c Mon Sep 17 00:00:00 2001 From: Mason Protter Date: Thu, 16 Jan 2025 00:38:54 +0100 Subject: [PATCH 04/15] remove old VdP test --- test/components.jl | 8 -------- 1 file changed, 8 deletions(-) diff --git a/test/components.jl b/test/components.jl index 4e3934db..bb948d0f 100644 --- a/test/components.jl +++ b/test/components.jl @@ -275,14 +275,6 @@ end @test norm.(R[length(R)]) < 0.1 end -@testset "Van der Pol" begin - @named VdP = van_der_pol() - - prob_vdp = SDEProblem(complete(VdP),[0.1,0.1],[0.0, 20.0],[]) - sol = solve(prob_vdp,EM(),dt=0.1) - @test sol.retcode == SciMLBase.ReturnCode.Success -end - """ stochastic.jl test From 12967a55d7e5a68d4e237a88d325a0f30ceca707 Mon Sep 17 00:00:00 2001 From: agchesebro <76024790+agchesebro@users.noreply.github.com> Date: Wed, 15 Jan 2025 20:38:46 -0500 Subject: [PATCH 05/15] Remove CairoMakie from deps also some renaming for aesthetics --- Project.toml | 1 - src/blox/neural_mass.jl | 10 +++++----- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/Project.toml b/Project.toml index d68a1b09..0ff65922 100644 --- a/Project.toml +++ b/Project.toml @@ -6,7 +6,6 @@ version = "0.5.7" [deps] Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" -CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0" ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" Colors = "5ae59095-9a9b-59fe-a467-6f913c188581" Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" diff --git a/src/blox/neural_mass.jl b/src/blox/neural_mass.jl index 71608146..c7a0101e 100644 --- a/src/blox/neural_mass.jl +++ b/src/blox/neural_mass.jl @@ -578,14 +578,14 @@ function van_der_pol(;name, ϕ=0.1, noise=false) if noise - return VanderPolNoise(name=name, namespace=namespace, θ=θ, ϕ=ϕ) + return VanDerPolNoise(name=name, namespace=namespace, θ=θ, ϕ=ϕ) else - return VanderPol(name=name, namespace=namespace, θ=θ) + return VanDerPol(name=name, namespace=namespace, θ=θ) end end -struct VanderPol <: NeuralMassBlox +struct VanDerPol <: NeuralMassBlox params system namespace @@ -606,12 +606,12 @@ struct VanderPol <: NeuralMassBlox end end -struct VanderPolNoise <: NeuralMassBlox +struct VanDerPolNoise <: NeuralMassBlox params system namespace - function VanderPolNoise(; + function VanDerPolNoise(; name, namespace=nothing, θ=1.0, From f3a0db878ec136b8440e610f84c9b625c13c4a55 Mon Sep 17 00:00:00 2001 From: agchesebro <76024790+agchesebro@users.noreply.github.com> Date: Wed, 15 Jan 2025 21:06:09 -0500 Subject: [PATCH 06/15] ah yes well I'm dumb --- src/blox/neural_mass.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/blox/neural_mass.jl b/src/blox/neural_mass.jl index c7a0101e..de344930 100644 --- a/src/blox/neural_mass.jl +++ b/src/blox/neural_mass.jl @@ -590,7 +590,7 @@ struct VanDerPol <: NeuralMassBlox system namespace - function VanderPol(; + function VanDerPol(; name, namespace=nothing, θ=1.0) From 3b37df998dfb6a90bab7b12d51c2959fe8877e1a Mon Sep 17 00:00:00 2001 From: Mason Protter Date: Thu, 16 Jan 2025 13:53:55 +0100 Subject: [PATCH 07/15] integrate Van der Pol with GraphDynamics and test it --- .../GraphDynamicsInterop.jl | 5 +- src/GraphDynamicsInterop/neuron_interop.jl | 5 +- test/GraphDynamicsTests/runtests.jl | 1 + test/GraphDynamicsTests/test_suite.jl | 56 ++++++++++++++----- 4 files changed, 48 insertions(+), 19 deletions(-) diff --git a/src/GraphDynamicsInterop/GraphDynamicsInterop.jl b/src/GraphDynamicsInterop/GraphDynamicsInterop.jl index a087494c..4f4aefe4 100644 --- a/src/GraphDynamicsInterop/GraphDynamicsInterop.jl +++ b/src/GraphDynamicsInterop/GraphDynamicsInterop.jl @@ -47,7 +47,10 @@ using ..Neuroblox: LIFInhCircuitBlox, PINGNeuronExci, PINGNeuronInhib, - AbstractPINGNeuron + AbstractPINGNeuron, + VanDerPol, + VanDerPolNoise, + van_der_pol using GraphDynamics: GraphDynamics, diff --git a/src/GraphDynamicsInterop/neuron_interop.jl b/src/GraphDynamicsInterop/neuron_interop.jl index 2879e4dc..77103167 100644 --- a/src/GraphDynamicsInterop/neuron_interop.jl +++ b/src/GraphDynamicsInterop/neuron_interop.jl @@ -33,6 +33,8 @@ function define_neurons() (:lif_inh, :LIFInhNeuron) (:pexci, :PINGNeuronExci) (:pinhib, :PINGNeuronInhib) + (:VdP, :VanDerPol) + (:VdPN, :VanDerPolNoise) ] sys = getproperty(Neuroblox, T)(;name) system = structural_simplify(sys.system; fully_determined=false) @@ -42,9 +44,6 @@ function define_neurons() states = [s for s ∈ unknowns(system) if !MTK.isinput(s)] inputs = [s for s ∈ unknowns(system) if MTK.isinput(s)] - # states_unwrapped = map(x -> x.f, states) - # inputs_unwrapped = map(x -> x.f, inputs) - p_syms = map(Symbol, params) s_syms = map(x -> tosymbol(x; escape=false), states) input_syms = map(x -> tosymbol(x; escape=false), inputs) diff --git a/test/GraphDynamicsTests/runtests.jl b/test/GraphDynamicsTests/runtests.jl index 0d957460..85d578d9 100644 --- a/test/GraphDynamicsTests/runtests.jl +++ b/test/GraphDynamicsTests/runtests.jl @@ -12,6 +12,7 @@ if GROUP == "All" || GROUP == "GraphDynamics1" end if GROUP == "All" || GROUP == "GraphDynamics2" + vdp_test() kuramato_test() wta_tests() dbs_circuit_components() diff --git a/test/GraphDynamicsTests/test_suite.jl b/test/GraphDynamicsTests/test_suite.jl index b71201bb..9ff59d9b 100644 --- a/test/GraphDynamicsTests/test_suite.jl +++ b/test/GraphDynamicsTests/test_suite.jl @@ -23,6 +23,7 @@ using Base.Iterators: map as imap using GraphDynamics.SymbolicIndexingInterface function test_compare_du_and_sols(::Type{ODEProblem}, g, tspan; + u0map=[], param_map=[], rtol, parallel=true, mtk=true, alg=nothing) if g isa Tuple @@ -34,7 +35,7 @@ function test_compare_du_and_sols(::Type{ODEProblem}, g, tspan; @named gsys = system_from_graph(gl; graphdynamics=true) state_names = variable_symbols(gsys) sol_grp, du_grp = let sys = gsys - prob = ODEProblem(sys, [], tspan) + prob = ODEProblem(sys, u0map, tspan, param_map) (; f, u0, p) = prob du = similar(u0) f(du, u0, p, 1.0) @@ -52,7 +53,7 @@ function test_compare_du_and_sols(::Type{ODEProblem}, g, tspan; if mtk sol_mtk, du_mtk = let @named sys = system_from_graph(gr) - prob = ODEProblem(sys, [], tspan) + prob = ODEProblem(sys, u0map, tspan, param_map) (; f, u0, p) = prob du = similar(u0) f(du, u0, p, 1.0) @@ -80,7 +81,7 @@ function test_compare_du_and_sols(::Type{ODEProblem}, g, tspan; end if parallel sol_grp_p, du_grp_p = let sys = gsys - prob = ODEProblem(sys, [], tspan, scheduler=StaticScheduler()) + prob = ODEProblem(sys, u0map, tspan, param_map, scheduler=StaticScheduler()) (; f, u0, p) = prob du = similar(u0) f(du, u0, p, 1.0) @@ -169,7 +170,7 @@ function neuron_and_neural_mass_comparison_tests() HarmonicOscillator(name=:ho1) HarmonicOscillator(name=:ho2) JansenRit(name=:jr1) - JansenRit(name=:jr2)], + JansenRit(name=:jr2)] ) if length(unknowns(LIFNeuron(;name=:_).system)) > 3 @warn "excluding LIFNeurons from test" @@ -183,15 +184,17 @@ function neuron_and_neural_mass_comparison_tests() add_blox!.((g,), neurons) for i ∈ eachindex(neurons) for j ∈ eachindex(neurons) - if (neurons[i] isa NeuralMassBlox && neurons[j] isa AbstractNeuronBlox) - nothing # Neuroblox doesn't support this currently - elseif neurons[i] isa QIFNeuron && neurons[j] isa QIFNeuron - add_edge!(g, i, j, Dict(:weight => 2*randn(), :connection_rule => "psp")) - elseif neurons[i] isa IFNeuron || neurons[j] isa IFNeuron - add_edge!(g, i, j, Dict(:weight => -rand(), :connection_rule => "basic")) - else - add_edge!(g, i, j, Dict(:weight => 2*randn(), :connection_rule => "basic")) - end + if i != j + if (neurons[i] isa NeuralMassBlox && neurons[j] isa AbstractNeuronBlox) + nothing # Neuroblox doesn't support this currently + elseif neurons[i] isa QIFNeuron && neurons[j] isa QIFNeuron + add_edge!(g, i, j, Dict(:weight => 2*randn(), :connection_rule => "psp")) + elseif neurons[i] isa IFNeuron || neurons[j] isa IFNeuron + add_edge!(g, i, j, Dict(:weight => -rand(), :connection_rule => "basic")) + else + add_edge!(g, i, j, Dict(:weight => 2*randn(), :connection_rule => "basic")) + end + end end end @@ -222,6 +225,28 @@ function basic_hh_network_tests() end end +function vdp_test() + @testset "VdP" begin + Random.seed!(1234) + @named vdp = van_der_pol() + g = MetaDiGraph() + add_blox!(g, vdp) + test_compare_du_and_sols(ODEProblem, g, (0.0, 1.0); u0map=[vdp.x => 0.0, vdp.y=>0.1], rtol=1e-10, alg=Vern7()) + + @named vdpn = van_der_pol(noise=true) + @named vdpn2 = van_der_pol(noise=true) + g = MetaDiGraph() + add_blox!(g, vdpn) + add_blox!(g, vdpn2) + add_edge!(g, 1, 2, :weight, 1.0) + + prob = test_compare_du_and_sols(SDEProblem, g, (0.0, 1.0); + u0map=[vdpn.x => 0.0, vdpn.y=>1.1], rtol=1e-10, alg=RKMil(), seed=123) + end +end + + + function test_compare_du_and_sols_ensemble(::Type{SDEProblem}, graph, tspan; rtol, mtk=true, alg=nothing, trajectories=50_000) Random.seed!(1234) if graph isa Tuple @@ -295,6 +320,7 @@ function test_compare_du_and_sols_ensemble(::Type{SDEProblem}, graph, tspan; rto end function test_compare_du_and_sols(::Type{SDEProblem}, graph, tspan; rtol, mtk=true, alg=nothing, seed=1234, + u0map=[], param_map=[], sol_comparison_broken=false, f_comparison_broken=false, g_comparison_broken=false) Random.seed!(seed) if graph isa Tuple @@ -306,7 +332,7 @@ function test_compare_du_and_sols(::Type{SDEProblem}, graph, tspan; rtol, mtk=tr @named gsys = system_from_graph(graph_l; graphdynamics=true) state_names = variable_symbols(gsys) sol_grp, du_grp, dnoise_grp = let sys = gsys - prob = SDEProblem(sys, [], tspan, [], seed=seed) + prob = SDEProblem(sys, u0map, tspan, param_map, seed=seed) (; f, g, u0, p) = prob du = similar(u0) f(du, u0, p, 1.1) @@ -324,7 +350,7 @@ function test_compare_du_and_sols(::Type{SDEProblem}, graph, tspan; rtol, mtk=tr end if mtk sol_mtk, du_mtk, dnoise_mtk = let neuron_net = system_from_graph(graph_r; name=:neuron_net) - prob = SDEProblem(neuron_net, [], tspan, [], seed=seed) + prob = SDEProblem(neuron_net, u0map, tspan, param_map, seed=seed) (; f, g, u0, p) = prob du = similar(u0) f(du, u0, p, 1.1) From 423d65ce4609f72236692dd62c50006dc72e75db Mon Sep 17 00:00:00 2001 From: agchesebro <76024790+agchesebro@users.noreply.github.com> Date: Thu, 16 Jan 2025 10:49:07 -0500 Subject: [PATCH 08/15] Update Kuramoto to new system --- src/blox/neural_mass.jl | 52 +++++++++++++++++++++++++++++------------ 1 file changed, 37 insertions(+), 15 deletions(-) diff --git a/src/blox/neural_mass.jl b/src/blox/neural_mass.jl index de344930..2b3442e6 100644 --- a/src/blox/neural_mass.jl +++ b/src/blox/neural_mass.jl @@ -441,7 +441,7 @@ struct Generic2dOscillator <: NeuralMassBlox end """ - KuramotoOscillator(name, namespace, ...) + kuramoto_oscillator(name, namespace, ...) Simple implementation of the Kuramoto oscillator as described in the original paper [1]. Useful for general models of synchronization and oscillatory behavior. @@ -481,33 +481,55 @@ Citations: 2024 Jun 14;199:106565. doi: 10.1016/j.nbd.2024.106565. Epub ahead of print. PMID: 38880431. """ +function kuramoto_oscillator(; name, + namespace=nothing, + ω=249.0, + ζ=5.92, + noise=false) + noise ? return KuramotoOscillatorNoise(name=name, namespace=namespace, ω=ω, ζ=ζ) : return KuramotoOscillator(name=name, namespace=namespace, ω=ω) +end + struct KuramotoOscillator <: NeuralMassBlox params system namespace function KuramotoOscillator(; + name, + namespace=nothing, + ω=249.0 + ) + p = paramscoping(ω=ω) + ω = p[1] + + sts = @variables θ(t)=0.0 [output = true] jcn(t) [input=true] + eqs = [D(θ) ~ ω + jcn] + sys = System(eqs, t, sts, p; name=name) + new(p, sys, namespace) + end +end + +struct KuramotoOscillatorNoise <: NeuralMassBlox + params + system + namespace + + function KuramotoOscillatorNoise(; name, namespace=nothing, ω=249.0, - ζ=5.92, - include_noise=false + ζ=5.92 ) + p = paramscoping(ω=ω, ζ=ζ) ω, ζ = p - if include_noise - sts = @variables θ(t)=0.0 [output = true] jcn(t) [input=true] - @brownian w - eqs = [D(θ) ~ ω + ζ * w + jcn] - sys = System(eqs, t, sts, p; name=name) - new(p, sys, namespace) - else - sts = @variables θ(t)=0.0 [output = true] jcn(t) [input=true] - eqs = [D(θ) ~ ω + jcn] - sys = System(eqs, t, sts, p; name=name) - new(p, sys, namespace) - end + sts = @variables θ(t)=0.0 [output = true] jcn(t) [input=true] + @brownian w + eqs = [D(θ) ~ ω + ζ * w + jcn] + sys = System(eqs, t, sts, p; name=name) + new(p, sys, namespace) + end end From d7c61d62f40a1202e749c4dd39686e33c48a9444 Mon Sep 17 00:00:00 2001 From: agchesebro <76024790+agchesebro@users.noreply.github.com> Date: Thu, 16 Jan 2025 10:50:26 -0500 Subject: [PATCH 09/15] New Kuramoto test --- test/components.jl | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/test/components.jl b/test/components.jl index bb948d0f..3d0130cf 100644 --- a/test/components.jl +++ b/test/components.jl @@ -195,8 +195,8 @@ end end @testset "Kuramoto Oscillator" begin - @named K01 = KuramotoOscillator(ω=2.0) - @named K02 = KuramotoOscillator(ω=5.0) + @named K01 = kuramoto_oscillator(ω=2.0) + @named K02 = kuramoto_oscillator(ω=5.0) adj = [0 1; 1 0] g = MetaDiGraph() @@ -209,6 +209,21 @@ end prob = ODEProblem(sys, [], (0.0, sim_dur), []) sol = solve(prob, AutoVern7(Rodas4()), saveat=0.1) @test sol.retcode == ReturnCode.Success + + @named K01 = kuramoto_oscillator(ω=2.0, noise=true) + @named K02 = kuramoto_oscillator(ω=5.0, noise=true) + + adj = [0 1; 1 0] + g = MetaDiGraph() + add_blox!.(Ref(g), [K01, K02]) + create_adjacency_edges!(g, adj) + + @named sys = system_from_graph(g) + + sim_dur = 1e2 + prob = SDEProblem(sys, [], (0.0, sim_dur), []) + sol = solve(prob, RKMil(), saveat=0.1) + @test sol.retcode == ReturnCode.Success end @testset "Noisy Kuramoto Oscillator" begin From f3d53397013195dd04301d12e238fb9ab12bb830 Mon Sep 17 00:00:00 2001 From: agchesebro <76024790+agchesebro@users.noreply.github.com> Date: Thu, 16 Jan 2025 11:01:37 -0500 Subject: [PATCH 10/15] `return` not parsed in ternary operator Plus function export and shortened test --- src/Neuroblox.jl | 2 +- src/blox/neural_mass.jl | 5 +++-- test/components.jl | 4 ++-- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/Neuroblox.jl b/src/Neuroblox.jl index 10363fe2..9b2b8b63 100644 --- a/src/Neuroblox.jl +++ b/src/Neuroblox.jl @@ -224,7 +224,7 @@ end export Neuron export JansenRitSPM12, next_generation, qif_neuron, if_neuron, hh_neuron_excitatory, - hh_neuron_inhibitory, van_der_pol, Generic2dOscillator + hh_neuron_inhibitory, van_der_pol, Generic2dOscillator, kuramoto_oscillator export HHNeuronExciBlox, HHNeuronInhibBlox, IFNeuron, LIFNeuron, QIFNeuron, IzhikevichNeuron, LIFExciNeuron, LIFInhNeuron, CanonicalMicroCircuitBlox, WinnerTakeAllBlox, CorticalBlox, SuperCortical, HHNeuronInhib_MSN_Adam_Blox, HHNeuronInhib_FSI_Adam_Blox, HHNeuronExci_STN_Adam_Blox, HHNeuronInhib_GPe_Adam_Blox, Striatum_MSN_Adam, Striatum_FSI_Adam, GPe_Adam, STN_Adam, LIFExciCircuitBlox, LIFInhCircuitBlox diff --git a/src/blox/neural_mass.jl b/src/blox/neural_mass.jl index 2b3442e6..eebd9af3 100644 --- a/src/blox/neural_mass.jl +++ b/src/blox/neural_mass.jl @@ -486,7 +486,8 @@ function kuramoto_oscillator(; name, ω=249.0, ζ=5.92, noise=false) - noise ? return KuramotoOscillatorNoise(name=name, namespace=namespace, ω=ω, ζ=ζ) : return KuramotoOscillator(name=name, namespace=namespace, ω=ω) + + noise ? KuramotoOscillatorNoise(name=name, namespace=namespace, ω=ω, ζ=ζ) : KuramotoOscillator(name=name, namespace=namespace, ω=ω) end struct KuramotoOscillator <: NeuralMassBlox @@ -520,7 +521,7 @@ struct KuramotoOscillatorNoise <: NeuralMassBlox ω=249.0, ζ=5.92 ) - + p = paramscoping(ω=ω, ζ=ζ) ω, ζ = p diff --git a/test/components.jl b/test/components.jl index 3d0130cf..cf52cf04 100644 --- a/test/components.jl +++ b/test/components.jl @@ -205,7 +205,7 @@ end @named sys = system_from_graph(g) - sim_dur = 1e2 + sim_dur = 2e1 prob = ODEProblem(sys, [], (0.0, sim_dur), []) sol = solve(prob, AutoVern7(Rodas4()), saveat=0.1) @test sol.retcode == ReturnCode.Success @@ -220,7 +220,7 @@ end @named sys = system_from_graph(g) - sim_dur = 1e2 + sim_dur = 2e1 prob = SDEProblem(sys, [], (0.0, sim_dur), []) sol = solve(prob, RKMil(), saveat=0.1) @test sol.retcode == ReturnCode.Success From 56a18fe24adb0820d1e590527857eb2f3f9e08fe Mon Sep 17 00:00:00 2001 From: agchesebro <76024790+agchesebro@users.noreply.github.com> Date: Thu, 16 Jan 2025 11:19:56 -0500 Subject: [PATCH 11/15] Remove old test --- test/components.jl | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/test/components.jl b/test/components.jl index cf52cf04..00f885c1 100644 --- a/test/components.jl +++ b/test/components.jl @@ -226,23 +226,6 @@ end @test sol.retcode == ReturnCode.Success end -@testset "Noisy Kuramoto Oscillator" begin - @named K01 = KuramotoOscillator(ω=2.0, ζ=5.0, include_noise=true) - @named K02 = KuramotoOscillator(ω=5.0, ζ=2.0, include_noise=true) - - adj = [0 1; 1 0] - g = MetaDiGraph() - add_blox!.(Ref(g), [K01, K02]) - create_adjacency_edges!(g, adj) - - @named sys = system_from_graph(g) - - sim_dur = 1e2 - prob = SDEProblem(sys, [0.1, 0.2], (0.0, sim_dur), []) - sol = solve(prob, RKMil(), saveat=0.1) - @test sol.retcode == ReturnCode.Success -end - @testset "Canonical Micro Circuit network" begin # connect multiple canonical micro circuits according to Figure 4 in Bastos et al. 2015 global_ns = :g # global namespace From ce333e6883b4554139dfd85d60a8b122ed0e5bd5 Mon Sep 17 00:00:00 2001 From: Mason Protter Date: Thu, 16 Jan 2025 22:00:11 +0100 Subject: [PATCH 12/15] fix bug in noisy Kuramoto connection --- src/blox/connections.jl | 4 ++-- src/blox/neural_mass.jl | 6 ++++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/blox/connections.jl b/src/blox/connections.jl index 143e342c..64ce0299 100644 --- a/src/blox/connections.jl +++ b/src/blox/connections.jl @@ -486,8 +486,8 @@ function Connector( end function Connector( - blox_src::KuramotoOscillator, - blox_dest::KuramotoOscillator; + blox_src::AbstractKuramotoOscillator, + blox_dest::AbstractKuramotoOscillator; kwargs... ) sys_src = get_namespaced_sys(blox_src) diff --git a/src/blox/neural_mass.jl b/src/blox/neural_mass.jl index eebd9af3..47659396 100644 --- a/src/blox/neural_mass.jl +++ b/src/blox/neural_mass.jl @@ -490,7 +490,9 @@ function kuramoto_oscillator(; name, noise ? KuramotoOscillatorNoise(name=name, namespace=namespace, ω=ω, ζ=ζ) : KuramotoOscillator(name=name, namespace=namespace, ω=ω) end -struct KuramotoOscillator <: NeuralMassBlox +abstract type AbstractKuramotoOscillator <: NeuralMassBlox end + +struct KuramotoOscillator <: AbstractKuramotoOscillator params system namespace @@ -510,7 +512,7 @@ struct KuramotoOscillator <: NeuralMassBlox end end -struct KuramotoOscillatorNoise <: NeuralMassBlox +struct KuramotoOscillatorNoise <: AbstractKuramotoOscillator params system namespace From b7da64c55b8bcddd99f5f93b49ef2a4eaa2d8a16 Mon Sep 17 00:00:00 2001 From: Mason Protter Date: Thu, 16 Jan 2025 22:10:09 +0100 Subject: [PATCH 13/15] adapt GraphDynamics code to Kuramoto cleanup --- .../GraphDynamicsInterop.jl | 2 + .../connection_interop.jl | 7 ++- src/GraphDynamicsInterop/neuron_interop.jl | 61 ++++++++++--------- test/GraphDynamicsTests/runtests.jl | 2 +- test/GraphDynamicsTests/test_suite.jl | 52 ++++++++-------- 5 files changed, 66 insertions(+), 58 deletions(-) diff --git a/src/GraphDynamicsInterop/GraphDynamicsInterop.jl b/src/GraphDynamicsInterop/GraphDynamicsInterop.jl index 4f4aefe4..e6e8ba7d 100644 --- a/src/GraphDynamicsInterop/GraphDynamicsInterop.jl +++ b/src/GraphDynamicsInterop/GraphDynamicsInterop.jl @@ -30,7 +30,9 @@ using ..Neuroblox: Matrisome, TAN, SNc, + AbstractKuramotoOscillator, KuramotoOscillator, + KuramotoOscillatorNoise, CorticalBlox, STN, Thalamus, diff --git a/src/GraphDynamicsInterop/connection_interop.jl b/src/GraphDynamicsInterop/connection_interop.jl index 914525d4..07a75bdf 100644 --- a/src/GraphDynamicsInterop/connection_interop.jl +++ b/src/GraphDynamicsInterop/connection_interop.jl @@ -213,16 +213,17 @@ end #---------------------------------------------- # Kuramoto -function get_connection(src::KuramotoOscillator, dst::KuramotoOscillator, kwargs) +function get_connection(src::AbstractKuramotoOscillator, dst::AbstractKuramotoOscillator, kwargs) (;w_val, name) = generate_weight_param(src, dst, kwargs) (;conn=BasicConnection(w_val), names=[name]) end -function (c::BasicConnection)(src::Subsystem{KuramotoOscillator}, dst::Subsystem{KuramotoOscillator}) +function (c::BasicConnection)(src::Subsystem{<:AbstractKuramotoOscillator}, + dst::Subsystem{<:AbstractKuramotoOscillator}) w = c.weight x₀ = src.θ xᵢ = dst.θ - w * sin(x₀ - xᵢ) + (;jcn = w * sin(x₀ - xᵢ)) end #---------------------------------------------- diff --git a/src/GraphDynamicsInterop/neuron_interop.jl b/src/GraphDynamicsInterop/neuron_interop.jl index 77103167..2f64e5cf 100644 --- a/src/GraphDynamicsInterop/neuron_interop.jl +++ b/src/GraphDynamicsInterop/neuron_interop.jl @@ -35,6 +35,8 @@ function define_neurons() (:pinhib, :PINGNeuronInhib) (:VdP, :VanDerPol) (:VdPN, :VanDerPolNoise) + (:ko, :KuramotoOscillator) + (:kon, :KuramotoOscillatorNoise) ] sys = getproperty(Neuroblox, T)(;name) system = structural_simplify(sys.system; fully_determined=false) @@ -165,18 +167,6 @@ function define_neurons() end define_neurons() # it's useful when developing this module to have these in a function -#Maybe should just encorporate this into define_neurons() -# for T ∈ [:LIFExciNeuron, :LIFInhNeuron] -# @eval begin -# GraphDynamics.has_discrete_events(::Type{$T}) = true -# GraphDynamics.discrete_event_condition((; t_refract_end)::Subsystem{$T}, t) = t_refract_end == t -# function GraphDynamics.apply_discrete_event!(integrator, _, pview, neuron::Subsystem{$T}, _) -# params = get_params(neuron) -# pview[] = @set params.is_refractory = 0 -# end -# end -# end - issupported(::PoissonSpikeTrain) = true components(p::PoissonSpikeTrain) = (p,) function to_subsystem(s::PoissonSpikeTrain) @@ -189,24 +179,37 @@ GraphDynamics.apply_subsystem_differential!(_, ::Subsystem{PoissonSpikeTrain}, _ GraphDynamics.subsystem_differential_requires_inputs(::Type{PoissonSpikeTrain}) = false -#------------------------- -# Kuramoto -issupported(::KuramotoOscillator) = true -function to_subsystem(o::KuramotoOscillator) - states = SubsystemStates{KuramotoOscillator}((;θ=0.0,)) - params = SubsystemParams{KuramotoOscillator}((;ω=getdefault(o.system.ω), ζ=getdefault(o.system.ζ))) - Subsystem(states, params) -end +# #------------------------- +# # Kuramoto +# issupported(::KuramotoOscillator) = true +# function to_subsystem(o::KuramotoOscillator) +# states = SubsystemStates{KuramotoOscillator}((;θ=0.0,)) +# params = SubsystemParams{KuramotoOscillator}((;ω=getdefault(o.system.ω))) +# Subsystem(states, params) +# end -GraphDynamics.initialize_input(s::Subsystem{KuramotoOscillator}) = 0.0 -function GraphDynamics.subsystem_differential(s::Subsystem{KuramotoOscillator}, jcn, t) - SubsystemStates{KuramotoOscillator}((; #=D=#θ= s.ω + jcn)) -end -GraphDynamics.isstochastic(::Type{KuramotoOscillator}) = true -function GraphDynamics.apply_subsystem_noise!(vstates, (;ζ,)::Subsystem{KuramotoOscillator}, t) - vstates[1] = ζ - nothing -end +# GraphDynamics.initialize_input(s::Subsystem{KuramotoOscillator}) = (;jcn = 0.0) +# function GraphDynamics.subsystem_differential(s::Subsystem{KuramotoOscillator}, (; jcn), t) +# SubsystemStates{KuramotoOscillator}((; #=D=#θ= s.ω + jcn)) +# end + + +# issupported(::KuramotoOscillatorNoise) = true +# function to_subsystem(o::KuramotoOscillatorNoise) +# states = SubsystemStates{KuramotoOscillatorNoise}((;θ=0.0,)) +# params = SubsystemParams{KuramotoOscillatorNoise}((;ω=getdefault(o.system.ω), ζ=getdefault(o.system.ζ))) +# Subsystem(states, params) +# end + +# GraphDynamics.initialize_input(s::Subsystem{KuramotoOscillatorNoise}) = (;jcn=0.0) +# function GraphDynamics.subsystem_differential(s::Subsystem{KuramotoOscillatorNoise}, (;jcn), t) +# SubsystemStates{KuramotoOscillatorNoise}((; #=D=#θ= s.ω + jcn)) +# end +# GraphDynamics.isstochastic(::Type{KuramotoOscillatorNoise}) = true +# function GraphDynamics.apply_subsystem_noise!(vstates, (;ζ,)::Subsystem{KuramotoOscillatorNoise}, t) +# vstates[1] = ζ +# nothing +# end #------------------------- diff --git a/test/GraphDynamicsTests/runtests.jl b/test/GraphDynamicsTests/runtests.jl index 85d578d9..7569e328 100644 --- a/test/GraphDynamicsTests/runtests.jl +++ b/test/GraphDynamicsTests/runtests.jl @@ -13,7 +13,7 @@ end if GROUP == "All" || GROUP == "GraphDynamics2" vdp_test() - kuramato_test() + kuramoto_test() wta_tests() dbs_circuit_components() dbs_circuit() diff --git a/test/GraphDynamicsTests/test_suite.jl b/test/GraphDynamicsTests/test_suite.jl index 9ff59d9b..b9fceeae 100644 --- a/test/GraphDynamicsTests/test_suite.jl +++ b/test/GraphDynamicsTests/test_suite.jl @@ -247,13 +247,13 @@ end -function test_compare_du_and_sols_ensemble(::Type{SDEProblem}, graph, tspan; rtol, mtk=true, alg=nothing, trajectories=50_000) - Random.seed!(1234) +function test_compare_du_and_sols_ensemble(::Type{SDEProblem}, graph, tspan; rtol, mtk=true, alg=nothing, trajectories=100_000) + # Random.seed!(1234) if graph isa Tuple (graph_l, graph_r) = graph else - graph_l = g - graph_r = g + graph_l = graph + graph_r = graph end @named gsys = system_from_graph(graph_l; graphdynamics=true) @@ -267,8 +267,6 @@ function test_compare_du_and_sols_ensemble(::Type{SDEProblem}, graph, tspan; rto dnoise = zero(u0) g(dnoise, u0, p, 1.1) - @test solve(prob, ImplicitEM(), saveat = 0.01,reltol=1e-4,abstol=1e-4).retcode == ReturnCode.Success - ens_prob = EnsembleProblem(prob) sols = solve(ens_prob, alg, EnsembleThreads(); trajectories) @@ -313,6 +311,7 @@ function test_compare_du_and_sols_ensemble(::Type{SDEProblem}, graph, tspan; rto end @test sort(du_grp) ≈ sort(du_mtk) #due to the MTK getu bug, we'll compare the sorted versions @test sort(dnoise_grp) ≈ sort(dnoise_mtk) #due to the MTK getu bug, we'll compare the sorted versions + @debug "" norm(mean(sol_grp_ens) .- mean(sol_mtk_ens)) / norm(mean(sol_grp_ens)) @test mean(sol_grp_ens) ≈ mean(sol_mtk_ens) rtol=rtol @test std(sol_grp_ens) ≈ std(sol_mtk_ens) rtol=rtol end @@ -438,27 +437,30 @@ function ngei_test() end end -function kuramato_test() - @testset "Kuramoto" begin - N = 2 - # Define the natural distribution of oscillator frequencies - Ω = 249 - σ = 26.317 - ks_blocks = [KuramotoOscillator(name=Symbol("KO$i"), - ω=rand(Normal(Ω, σ)), - ζ=5.920, - include_noise=true) for i in 1:N] - # Create a graph and add all the oscillators to it - g = MetaDiGraph() - add_blox!.(Ref(g), ks_blocks) +function kuramoto_test() + @testset "Kuramoto Oscillator" begin + @testset "Non-noisy" begin + @named K01 = kuramoto_oscillator(ω=2.0) + @named K02 = kuramoto_oscillator(ω=5.0) - # Connect all oscillators to each other - for i in 1:N - for j in 1:N - add_edge!(g, i, j, Dict(:weight => 1.0)) - end + adj = [0 1; 1 0] + g = MetaDiGraph() + add_blox!.(Ref(g), [K01, K02]) + create_adjacency_edges!(g, adj) + + test_compare_du_and_sols(ODEProblem, g, (0.0, 2.0); rtol=1e-10, alg=AutoVern7(Rodas4())) + end + @testset "Noisy" begin + @named K01 = kuramoto_oscillator(ω=2.0, noise=true) + @named K02 = kuramoto_oscillator(ω=5.0, noise=true) + + adj = [0 1; 1 0] + g = MetaDiGraph() + add_blox!.(Ref(g), [K01, K02]) + create_adjacency_edges!(g, adj) + + test_compare_du_and_sols(SDEProblem, g, (0.0, 2.0); rtol=1e-10, alg=RKMil()) end - test_compare_du_and_sols(SDEProblem, g, (0.0, 2.0), rtol=0.05, alg=RKMil()) end end From 178674106a23187795def4f3ee94a4e274a28ad6 Mon Sep 17 00:00:00 2001 From: Mason Protter Date: Fri, 17 Jan 2025 17:16:05 +0100 Subject: [PATCH 14/15] make neural-mass housekeeping PR non-breaking (#532) --- .../GraphDynamicsInterop.jl | 8 +- .../connection_interop.jl | 6 +- src/GraphDynamicsInterop/neuron_interop.jl | 49 ++++--- src/Neuroblox.jl | 2 +- src/blox/connections.jl | 4 +- src/blox/neural_mass.jl | 132 +++++++----------- test/GraphDynamicsTests/test_suite.jl | 14 +- test/components.jl | 78 ++++++----- 8 files changed, 130 insertions(+), 163 deletions(-) diff --git a/src/GraphDynamicsInterop/GraphDynamicsInterop.jl b/src/GraphDynamicsInterop/GraphDynamicsInterop.jl index e6e8ba7d..4a6e8021 100644 --- a/src/GraphDynamicsInterop/GraphDynamicsInterop.jl +++ b/src/GraphDynamicsInterop/GraphDynamicsInterop.jl @@ -30,9 +30,9 @@ using ..Neuroblox: Matrisome, TAN, SNc, - AbstractKuramotoOscillator, + Noisy, + NonNoisy, KuramotoOscillator, - KuramotoOscillatorNoise, CorticalBlox, STN, Thalamus, @@ -50,9 +50,7 @@ using ..Neuroblox: PINGNeuronExci, PINGNeuronInhib, AbstractPINGNeuron, - VanDerPol, - VanDerPolNoise, - van_der_pol + VanDerPol using GraphDynamics: GraphDynamics, diff --git a/src/GraphDynamicsInterop/connection_interop.jl b/src/GraphDynamicsInterop/connection_interop.jl index 07a75bdf..2344a193 100644 --- a/src/GraphDynamicsInterop/connection_interop.jl +++ b/src/GraphDynamicsInterop/connection_interop.jl @@ -213,13 +213,13 @@ end #---------------------------------------------- # Kuramoto -function get_connection(src::AbstractKuramotoOscillator, dst::AbstractKuramotoOscillator, kwargs) +function get_connection(src::KuramotoOscillator, dst::KuramotoOscillator, kwargs) (;w_val, name) = generate_weight_param(src, dst, kwargs) (;conn=BasicConnection(w_val), names=[name]) end -function (c::BasicConnection)(src::Subsystem{<:AbstractKuramotoOscillator}, - dst::Subsystem{<:AbstractKuramotoOscillator}) +function (c::BasicConnection)(src::Subsystem{<:KuramotoOscillator}, + dst::Subsystem{<:KuramotoOscillator}) w = c.weight x₀ = src.θ xᵢ = dst.θ diff --git a/src/GraphDynamicsInterop/neuron_interop.jl b/src/GraphDynamicsInterop/neuron_interop.jl index 2f64e5cf..65ab85e4 100644 --- a/src/GraphDynamicsInterop/neuron_interop.jl +++ b/src/GraphDynamicsInterop/neuron_interop.jl @@ -15,30 +15,29 @@ end function define_neurons() - for (name, T) ∈ [(:hhne, :HHNeuronExciBlox) - (:hhni, :HHNeuronInhibBlox) - (:hhni_msn_adam, :HHNeuronInhib_MSN_Adam_Blox) - (:hhni_fsi_adam, :HHNeuronInhib_FSI_Adam_Blox) - (:hhne_stn_adam, :HHNeuronExci_STN_Adam_Blox) - (:hhni_GPe_adam, :HHNeuronInhib_GPe_Adam_Blox) - (:ngei, :NextGenerationEIBlox) - (:wc, :WilsonCowan) - (:ho, :HarmonicOscillator) - (:jr, :JansenRit) # Note! Regular JansenRit can support delays, and I have not yet implemented this! - (:if, :IFNeuron) - (:lif, :LIFNeuron) - (:qif, :QIFNeuron) - (:izh, :IzhikevichNeuron) - (:lif_exci, :LIFExciNeuron) - (:lif_inh, :LIFInhNeuron) - (:pexci, :PINGNeuronExci) - (:pinhib, :PINGNeuronInhib) - (:VdP, :VanDerPol) - (:VdPN, :VanDerPolNoise) - (:ko, :KuramotoOscillator) - (:kon, :KuramotoOscillatorNoise) - ] - sys = getproperty(Neuroblox, T)(;name) + for (name, T) ∈ [(:hhne, HHNeuronExciBlox) + (:hhni, HHNeuronInhibBlox) + (:hhni_msn_adam, HHNeuronInhib_MSN_Adam_Blox) + (:hhni_fsi_adam, HHNeuronInhib_FSI_Adam_Blox) + (:hhne_stn_adam, HHNeuronExci_STN_Adam_Blox) + (:hhni_GPe_adam, HHNeuronInhib_GPe_Adam_Blox) + (:ngei, NextGenerationEIBlox) + (:wc, WilsonCowan) + (:ho, HarmonicOscillator) + (:jr, JansenRit) # Note! Regular JansenRit can support delays, and I have not yet implemented this! + (:if, IFNeuron) + (:lif, LIFNeuron) + (:qif, QIFNeuron) + (:izh, IzhikevichNeuron) + (:lif_exci, LIFExciNeuron) + (:lif_inh, LIFInhNeuron) + (:pexci, PINGNeuronExci) + (:pinhib, PINGNeuronInhib) + (:VdP, VanDerPol{NonNoisy}) + (:VdPN, VanDerPol{Noisy}) + (:ko, KuramotoOscillator{NonNoisy}) + (:kon, KuramotoOscillator{Noisy})] + sys = T(;name) system = structural_simplify(sys.system; fully_determined=false) params = get_ps(system) t = Symbol(get_iv(system)) @@ -138,7 +137,7 @@ function define_neurons() end end end - if !isempty(get_discrete_events(system)) && T ∉ (:LIFExciNeuron, :LIFInhNeuron) + if !isempty(get_discrete_events(system)) && T ∉ (LIFExciNeuron, LIFInhNeuron) cb = only(collect(get_discrete_events(system))) # currently only support single events cb_eq = r(cb.condition) if cb_eq.f ∉ (<, >, <=, >=) diff --git a/src/Neuroblox.jl b/src/Neuroblox.jl index 9b2b8b63..b469e209 100644 --- a/src/Neuroblox.jl +++ b/src/Neuroblox.jl @@ -224,7 +224,7 @@ end export Neuron export JansenRitSPM12, next_generation, qif_neuron, if_neuron, hh_neuron_excitatory, - hh_neuron_inhibitory, van_der_pol, Generic2dOscillator, kuramoto_oscillator + hh_neuron_inhibitory, VanDerPol, Generic2dOscillator, kuramoto_oscillator export HHNeuronExciBlox, HHNeuronInhibBlox, IFNeuron, LIFNeuron, QIFNeuron, IzhikevichNeuron, LIFExciNeuron, LIFInhNeuron, CanonicalMicroCircuitBlox, WinnerTakeAllBlox, CorticalBlox, SuperCortical, HHNeuronInhib_MSN_Adam_Blox, HHNeuronInhib_FSI_Adam_Blox, HHNeuronExci_STN_Adam_Blox, HHNeuronInhib_GPe_Adam_Blox, Striatum_MSN_Adam, Striatum_FSI_Adam, GPe_Adam, STN_Adam, LIFExciCircuitBlox, LIFInhCircuitBlox diff --git a/src/blox/connections.jl b/src/blox/connections.jl index 64ce0299..143e342c 100644 --- a/src/blox/connections.jl +++ b/src/blox/connections.jl @@ -486,8 +486,8 @@ function Connector( end function Connector( - blox_src::AbstractKuramotoOscillator, - blox_dest::AbstractKuramotoOscillator; + blox_src::KuramotoOscillator, + blox_dest::KuramotoOscillator; kwargs... ) sys_src = get_namespaced_sys(blox_src) diff --git a/src/blox/neural_mass.jl b/src/blox/neural_mass.jl index 47659396..e102986d 100644 --- a/src/blox/neural_mass.jl +++ b/src/blox/neural_mass.jl @@ -1,3 +1,6 @@ +struct Noisy end +struct NonNoisy end + mutable struct NextGenerationBlox <: NeuralMassBlox C::Num Δ::Num @@ -441,7 +444,7 @@ struct Generic2dOscillator <: NeuralMassBlox end """ - kuramoto_oscillator(name, namespace, ...) + KuramotoOscillator(name, namespace, ...) Simple implementation of the Kuramoto oscillator as described in the original paper [1]. Useful for general models of synchronization and oscillatory behavior. @@ -465,9 +468,10 @@ end where \$W_i\$ is a Wiener process and \$\\zeta_i\$ is the noise strength. -Arguments: +Keyword arguments: - name: Name given to ODESystem object within the blox. - namespace: Additional namespace above name if needed for inheritance. +- `include_noise` (default `false`) determines if brownian noise is included in the dynamics of the blox. - Other parameters: See reference for full list. Note that parameters are scaled so that units of time are in milliseconds. Default parameter values are taken from [2]. @@ -481,58 +485,38 @@ Citations: 2024 Jun 14;199:106565. doi: 10.1016/j.nbd.2024.106565. Epub ahead of print. PMID: 38880431. """ -function kuramoto_oscillator(; name, - namespace=nothing, - ω=249.0, - ζ=5.92, - noise=false) - - noise ? KuramotoOscillatorNoise(name=name, namespace=namespace, ω=ω, ζ=ζ) : KuramotoOscillator(name=name, namespace=namespace, ω=ω) -end - -abstract type AbstractKuramotoOscillator <: NeuralMassBlox end - -struct KuramotoOscillator <: AbstractKuramotoOscillator +struct KuramotoOscillator{IsNoisy} <: NeuralMassBlox params system namespace - function KuramotoOscillator(; - name, - namespace=nothing, - ω=249.0 - ) - p = paramscoping(ω=ω) - ω = p[1] - - sts = @variables θ(t)=0.0 [output = true] jcn(t) [input=true] - eqs = [D(θ) ~ ω + jcn] - sys = System(eqs, t, sts, p; name=name) - new(p, sys, namespace) + function KuramotoOscillator(; name, + namespace=nothing, + ω=249.0, + ζ=5.92, + include_noise=false) + if include_noise + KuramotoOscillator{Noisy}(;name, namespace, ω, ζ) + else + KuramotoOscillator{NonNoisy}(;name, namespace, ω) + end end -end - -struct KuramotoOscillatorNoise <: AbstractKuramotoOscillator - params - system - namespace - - function KuramotoOscillatorNoise(; - name, - namespace=nothing, - ω=249.0, - ζ=5.92 - ) - + function KuramotoOscillator{Noisy}(;name, namespace=nothing, ω=249.0, ζ=5.92) p = paramscoping(ω=ω, ζ=ζ) ω, ζ = p - sts = @variables θ(t)=0.0 [output = true] jcn(t) [input=true] @brownian w - eqs = [D(θ) ~ ω + ζ * w + jcn] + eqs = [D(θ) ~ ω + jcn + ζ*w] sys = System(eqs, t, sts, p; name=name) - new(p, sys, namespace) - + new{Noisy}(p, sys, namespace) + end + function KuramotoOscillator{NonNoisy}(;name, namespace=nothing, ω=249.0) + p = paramscoping(ω=ω) + ω = p[1] + sts = @variables θ(t)=0.0 [output = true] jcn(t) [input=true] + eqs = [D(θ) ~ ω + jcn] + sys = System(eqs, t, sts, p; name=name) + new{NonNoisy}(p, sys, namespace) end end @@ -597,50 +581,19 @@ struct QIF_PING_NGNMM <: NeuralMassBlox end end -function van_der_pol(;name, - namespace=nothing, - θ=1.0, - ϕ=0.1, - noise=false) - if noise - return VanDerPolNoise(name=name, namespace=namespace, θ=θ, ϕ=ϕ) - else - return VanDerPol(name=name, namespace=namespace, θ=θ) - end -end - - -struct VanDerPol <: NeuralMassBlox +struct VanDerPol{IsNoisy} <: NeuralMassBlox params system namespace - function VanDerPol(; - name, - namespace=nothing, - θ=1.0) - p = paramscoping(θ=θ) - θ = p[1] - sts = @variables x(t)=0.0 [output=true] y(t)=0.0 jcn(t) [input=true] - - eqs = [D(x) ~ y, - D(y) ~ θ*(1-x^2)*y - x + jcn] - - sys = System(eqs, t, sts, p; name=name) - new(p, sys, namespace) + function VanDerPol(; name, namespace=nothing, θ=1.0, ϕ=0.1, include_noise=false) + if include_noise + VanDerPol{Noisy}(;name, namespace, θ, ϕ) + else + VanDerPol{NonNoisy}(;name, namespace, θ) + end end -end - -struct VanDerPolNoise <: NeuralMassBlox - params - system - namespace - - function VanDerPolNoise(; - name, - namespace=nothing, - θ=1.0, - ϕ=0.1) + function VanDerPol{Noisy}(; name, namespace=nothing, θ=1.0, ϕ=0.1) p = paramscoping(θ=θ, ϕ=ϕ) θ, ϕ = p sts = @variables x(t)=0.0 [output=true] y(t)=0.0 jcn(t) [input=true] @@ -650,6 +603,17 @@ struct VanDerPolNoise <: NeuralMassBlox D(y) ~ θ*(1-x^2)*y - x + ϕ*ξ + jcn] sys = System(eqs, t, sts, p; name=name) - new(p, sys, namespace) + new{Noisy}(p, sys, namespace) + end + function VanDerPol{NonNoisy}(; name, namespace=nothing, θ=1.0) + p = paramscoping(θ=θ) + θ = p[1] + sts = @variables x(t)=0.0 [output=true] y(t)=0.0 jcn(t) [input=true] + + eqs = [D(x) ~ y, + D(y) ~ θ*(1-x^2)*y - x + jcn] + + sys = System(eqs, t, sts, p; name=name) + new{NonNoisy}(p, sys, namespace) end end diff --git a/test/GraphDynamicsTests/test_suite.jl b/test/GraphDynamicsTests/test_suite.jl index b9fceeae..a7adf705 100644 --- a/test/GraphDynamicsTests/test_suite.jl +++ b/test/GraphDynamicsTests/test_suite.jl @@ -228,13 +228,13 @@ end function vdp_test() @testset "VdP" begin Random.seed!(1234) - @named vdp = van_der_pol() + @named vdp = VanDerPol() g = MetaDiGraph() add_blox!(g, vdp) test_compare_du_and_sols(ODEProblem, g, (0.0, 1.0); u0map=[vdp.x => 0.0, vdp.y=>0.1], rtol=1e-10, alg=Vern7()) - @named vdpn = van_der_pol(noise=true) - @named vdpn2 = van_der_pol(noise=true) + @named vdpn = VanDerPol(include_noise=true) + @named vdpn2 = VanDerPol(include_noise=true) g = MetaDiGraph() add_blox!(g, vdpn) add_blox!(g, vdpn2) @@ -440,8 +440,8 @@ end function kuramoto_test() @testset "Kuramoto Oscillator" begin @testset "Non-noisy" begin - @named K01 = kuramoto_oscillator(ω=2.0) - @named K02 = kuramoto_oscillator(ω=5.0) + @named K01 = KuramotoOscillator(ω=2.0) + @named K02 = KuramotoOscillator(ω=5.0) adj = [0 1; 1 0] g = MetaDiGraph() @@ -451,8 +451,8 @@ function kuramoto_test() test_compare_du_and_sols(ODEProblem, g, (0.0, 2.0); rtol=1e-10, alg=AutoVern7(Rodas4())) end @testset "Noisy" begin - @named K01 = kuramoto_oscillator(ω=2.0, noise=true) - @named K02 = kuramoto_oscillator(ω=5.0, noise=true) + @named K01 = KuramotoOscillator(ω=2.0, include_noise=true) + @named K02 = KuramotoOscillator(ω=5.0, include_noise=true) adj = [0 1; 1 0] g = MetaDiGraph() diff --git a/test/components.jl b/test/components.jl index 00f885c1..cbdc5d82 100644 --- a/test/components.jl +++ b/test/components.jl @@ -195,35 +195,37 @@ end end @testset "Kuramoto Oscillator" begin - @named K01 = kuramoto_oscillator(ω=2.0) - @named K02 = kuramoto_oscillator(ω=5.0) - adj = [0 1; 1 0] - g = MetaDiGraph() - add_blox!.(Ref(g), [K01, K02]) - create_adjacency_edges!(g, adj) + sim_dur = 2e1 + @testset "Non-noisy" begin + @named K01 = KuramotoOscillator(ω=2.0) + @named K02 = KuramotoOscillator(ω=5.0) - @named sys = system_from_graph(g) + g = MetaDiGraph() + add_blox!.(Ref(g), [K01, K02]) + create_adjacency_edges!(g, adj) - sim_dur = 2e1 - prob = ODEProblem(sys, [], (0.0, sim_dur), []) - sol = solve(prob, AutoVern7(Rodas4()), saveat=0.1) - @test sol.retcode == ReturnCode.Success + @named sys = system_from_graph(g) - @named K01 = kuramoto_oscillator(ω=2.0, noise=true) - @named K02 = kuramoto_oscillator(ω=5.0, noise=true) + prob = ODEProblem(sys, [], (0.0, sim_dur), []) + sol = solve(prob, AutoVern7(Rodas4()), saveat=0.1) + @test sol.retcode == ReturnCode.Success + end - adj = [0 1; 1 0] - g = MetaDiGraph() - add_blox!.(Ref(g), [K01, K02]) - create_adjacency_edges!(g, adj) + @testset "Noisy" begin + @named K01 = KuramotoOscillator(ω=2.0, include_noise=true) + @named K02 = KuramotoOscillator(ω=5.0, include_noise=true) - @named sys = system_from_graph(g) + g = MetaDiGraph() + add_blox!.(Ref(g), [K01, K02]) + create_adjacency_edges!(g, adj) - sim_dur = 2e1 - prob = SDEProblem(sys, [], (0.0, sim_dur), []) - sol = solve(prob, RKMil(), saveat=0.1) - @test sol.retcode == ReturnCode.Success + @named sys = system_from_graph(g) + + prob = SDEProblem(sys, [], (0.0, sim_dur), []) + sol = solve(prob, RKMil(), saveat=0.1) + @test sol.retcode == ReturnCode.Success + end end @testset "Canonical Micro Circuit network" begin @@ -778,21 +780,25 @@ end @testset "VdP" begin Random.seed!(1234) - @named vdp = van_der_pol() - g = MetaDiGraph() - add_blox!(g, vdp) - @named sys = system_from_graph(g) - prob = ODEProblem(sys, [0.0, 0.1], (0.0, 20.0), []) - sol = solve(prob,Tsit5()) - @test sol.retcode == ReturnCode.Success + @testset "Non-noisy" begin + @named vdp = VanDerPol() + g = MetaDiGraph() + add_blox!(g, vdp) + @named sys = system_from_graph(g) + prob = ODEProblem(sys, [0.0, 0.1], (0.0, 20.0), []) + sol = solve(prob,Tsit5()) + @test sol.retcode == ReturnCode.Success + end - @named vdp = van_der_pol(noise=true) - g = MetaDiGraph() - add_blox!(g, vdp) - @named sys = system_from_graph(g) - prob = SDEProblem(sys, [0.0, 0.1], (0.0, 20.0), []) - sol = solve(prob, RKMil()) - @test sol.retcode == ReturnCode.Success + @testset "Noisy" begin + @named vdp = VanDerPol(include_noise=true) + g = MetaDiGraph() + add_blox!(g, vdp) + @named sys = system_from_graph(g) + prob = SDEProblem(sys, [0.0, 0.1], (0.0, 20.0), []) + sol = solve(prob, RKMil()) + @test sol.retcode == ReturnCode.Success + end end @testset "DBS circuit firing rates" begin From a47122751bebbdec9b81999dcadf4b4c4916e66e Mon Sep 17 00:00:00 2001 From: Mason Protter Date: Fri, 17 Jan 2025 17:19:00 +0100 Subject: [PATCH 15/15] remove unneeded comments --- src/GraphDynamicsInterop/neuron_interop.jl | 35 ---------------------- 1 file changed, 35 deletions(-) diff --git a/src/GraphDynamicsInterop/neuron_interop.jl b/src/GraphDynamicsInterop/neuron_interop.jl index 65ab85e4..17dc45a1 100644 --- a/src/GraphDynamicsInterop/neuron_interop.jl +++ b/src/GraphDynamicsInterop/neuron_interop.jl @@ -13,7 +13,6 @@ function recursive_getdefault(x::Union{MTK.Num, MTK.BasicSymbolic}) substitute(def_x, defs) end - function define_neurons() for (name, T) ∈ [(:hhne, HHNeuronExciBlox) (:hhni, HHNeuronInhibBlox) @@ -177,40 +176,6 @@ GraphDynamics.initialize_input(s::Subsystem{PoissonSpikeTrain}) = (;) GraphDynamics.apply_subsystem_differential!(_, ::Subsystem{PoissonSpikeTrain}, _, _) = nothing GraphDynamics.subsystem_differential_requires_inputs(::Type{PoissonSpikeTrain}) = false - -# #------------------------- -# # Kuramoto -# issupported(::KuramotoOscillator) = true -# function to_subsystem(o::KuramotoOscillator) -# states = SubsystemStates{KuramotoOscillator}((;θ=0.0,)) -# params = SubsystemParams{KuramotoOscillator}((;ω=getdefault(o.system.ω))) -# Subsystem(states, params) -# end - -# GraphDynamics.initialize_input(s::Subsystem{KuramotoOscillator}) = (;jcn = 0.0) -# function GraphDynamics.subsystem_differential(s::Subsystem{KuramotoOscillator}, (; jcn), t) -# SubsystemStates{KuramotoOscillator}((; #=D=#θ= s.ω + jcn)) -# end - - -# issupported(::KuramotoOscillatorNoise) = true -# function to_subsystem(o::KuramotoOscillatorNoise) -# states = SubsystemStates{KuramotoOscillatorNoise}((;θ=0.0,)) -# params = SubsystemParams{KuramotoOscillatorNoise}((;ω=getdefault(o.system.ω), ζ=getdefault(o.system.ζ))) -# Subsystem(states, params) -# end - -# GraphDynamics.initialize_input(s::Subsystem{KuramotoOscillatorNoise}) = (;jcn=0.0) -# function GraphDynamics.subsystem_differential(s::Subsystem{KuramotoOscillatorNoise}, (;jcn), t) -# SubsystemStates{KuramotoOscillatorNoise}((; #=D=#θ= s.ω + jcn)) -# end -# GraphDynamics.isstochastic(::Type{KuramotoOscillatorNoise}) = true -# function GraphDynamics.apply_subsystem_noise!(vstates, (;ζ,)::Subsystem{KuramotoOscillatorNoise}, t) -# vstates[1] = ζ -# nothing -# end - - #------------------------- # Matrisome issupported(::Matrisome) = true