From 1bcc7334d2be3a67085bc4850e743a06746c40a7 Mon Sep 17 00:00:00 2001 From: Haris Orgn Date: Tue, 28 Nov 2023 09:38:08 +0200 Subject: [PATCH 1/8] update imports --- src/Neuroblox.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/Neuroblox.jl b/src/Neuroblox.jl index 26be7d77..c7b09df4 100644 --- a/src/Neuroblox.jl +++ b/src/Neuroblox.jl @@ -24,12 +24,12 @@ using DelayDiffEq using StatsBase: sample using Distributions -using ModelingToolkit: get_namespace, get_systems, renamespace, - namespace_equation, namespace_variables, namespace_parameters, namespace_expr, +using ModelingToolkit: get_namespace, get_systems, isparameter, + renamespace, namespace_equation, namespace_parameters, namespace_expr, AbstractODESystem import ModelingToolkit: inputs, nameof -using Symbolics: @register_symbolic +using Symbolics: @register_symbolic, getdefaultval using IfElse using DelimitedFiles: readdlm From 6346b8078ca777625de09986bf5820fa9fc05734 Mon Sep 17 00:00:00 2001 From: Haris Orgn Date: Tue, 28 Nov 2023 09:41:50 +0200 Subject: [PATCH 2/8] add `t_affect` kwarg for parameter callbacks --- src/Neurographs.jl | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/src/Neurographs.jl b/src/Neurographs.jl index fff1678e..a7b2d077 100644 --- a/src/Neurographs.jl +++ b/src/Neurographs.jl @@ -180,27 +180,36 @@ function graph_delays(g::MetaDiGraph) return bc.delays end -function system_from_graph(g::MetaDiGraph; name) +function system_from_graph(g::MetaDiGraph; name, t_affect=missing) bc = connector_from_graph(g) - return system_from_graph(g, bc; name) + return system_from_graph(g, bc; name, t_affect) end # Additional dispatch if extra parameters are passed for edge definitions -function system_from_graph(g::MetaDiGraph, p::Vector{Num}; name) +function system_from_graph(g::MetaDiGraph, p::Vector{Num}; name, t_affect=missing) bc = connector_from_graph(g) - return system_from_graph(g, bc, p; name) + return system_from_graph(g, bc, p; name, t_affect) end -function system_from_graph(g::MetaDiGraph, bc::BloxConnector; name) +function system_from_graph(g::MetaDiGraph, bc::BloxConnector; name, t_affect=missing) @variables t blox_syss = get_sys(g) - return compose(ODESystem(bc.eqs, t, [], params(bc); name, discrete_events = bc.events), blox_syss) + + connection_eqs = get_equations_with_state_lhs(bc) + cbs = get_callbacks(bc, t_affect) + + return compose(ODESystem(connection_eqs, t, [], params(bc); name, discrete_events = cbs), blox_syss) end -function system_from_graph(g::MetaDiGraph, bc::BloxConnector, p::Vector{Num}; name) +function system_from_graph(g::MetaDiGraph, bc::BloxConnector, p::Vector{Num}; name, t_affect=missing) @variables t blox_syss = get_sys(g) - return compose(ODESystem(bc.eqs, t, [], vcat(params(bc), p); name, discrete_events = bc.events), blox_syss) + + connection_eqs = get_equations_with_state_lhs(bc) + + cbs = get_callbacks(bc, t_affect) + + return compose(ODESystem(connection_eqs, t, [], vcat(params(bc), p); name, discrete_events = cbs), blox_syss) end function system_from_parts(parts::AbstractVector; name) From 803c0a2e99cb2f75e72d99aa59c95e457d00332e Mon Sep 17 00:00:00 2001 From: Haris Orgn Date: Tue, 28 Nov 2023 09:42:25 +0200 Subject: [PATCH 3/8] remove brackets from callback time --- src/blox/connections.jl | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/src/blox/connections.jl b/src/blox/connections.jl index d6836cbc..c2bb4e4a 100644 --- a/src/blox/connections.jl +++ b/src/blox/connections.jl @@ -119,7 +119,6 @@ function (bc::BloxConnector)( end - function (bc::BloxConnector)( bloxout::NeuralMassBlox, bloxin::NeuralMassBlox; @@ -133,8 +132,6 @@ function (bc::BloxConnector)( if haskey(kwargs, :learning_rule) lr = kwargs[:learning_rule] - # maybe_set_state_pre!(lr, - # maybe_set_state_post!(lr, bc.learning_rules[w] = lr end @@ -293,8 +290,7 @@ function (bc::BloxConnector)( end eq = sys_in.jcn ~ w*sys_out.spikes_window - - accumulate_equation!(bc, eq) + accumulate_equation!(bc, eq) end function (bc::BloxConnector)( @@ -308,8 +304,8 @@ function (bc::BloxConnector)( neurons_in = get_inh_neurons(str_in) t_event = get_event_time(kwargs, nameof(str_out), nameof(str_in)) - cb_matr = [t_event] => [sys_matr_in.H ~ IfElse.ifelse(sys_matr_out.ρ > sys_matr_in.ρ, 0, 1)] - cb_strios = [t_event] => [sys_strios_in.H ~ IfElse.ifelse(sys_matr_out.ρ > sys_matr_in.ρ, 0, 1)] + cb_matr = t_event => [sys_matr_in.H ~ IfElse.ifelse(sys_matr_out.ρ > sys_matr_in.ρ, 0, 1)] + cb_strios = t_event => [sys_strios_in.H ~ IfElse.ifelse(sys_matr_out.ρ > sys_matr_in.ρ, 0, 1)] push!(bc.events, cb_matr) push!(bc.events, cb_strios) @@ -317,7 +313,7 @@ function (bc::BloxConnector)( sys_neuron = get_namespaced_sys(neuron) # Large negative current added to shut down the Striatum spiking neurons. # Value is hardcoded for now, as it's more of a hack, not user option. - cb_neuron = [t_event] => [sys_neuron.I_bg ~ IfElse.ifelse(sys_matr_out.ρ > sys_matr_in.ρ, -2, 0)] + cb_neuron = t_event => [sys_neuron.I_bg ~ IfElse.ifelse(sys_matr_out.ρ > sys_matr_in.ρ, -2, 0)] push!(bc.events, cb_neuron) end end @@ -404,7 +400,7 @@ function (bc::BloxConnector)( sys_in = get_namespaced_sys(discr_in) t_event = get_event_time(kwargs, nameof(discr_out), nameof(discr_in)) - cb = [t_event] => [sys_in.H ~ IfElse.ifelse(sys_out.ρ > sys_in.ρ, 0, 1)] + cb = t_event => [sys_in.H ~ IfElse.ifelse(sys_out.ρ > sys_in.ρ, 0, 1)] push!(bc.events, cb) end From cbbdc2d64d609bc21f5e78698f37259cbf38834d Mon Sep 17 00:00:00 2001 From: Haris Orgn Date: Tue, 28 Nov 2023 09:43:13 +0200 Subject: [PATCH 4/8] add utils for parameter callbacks from `BloxConnector` --- src/blox/connections.jl | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/src/blox/connections.jl b/src/blox/connections.jl index c2bb4e4a..1c442923 100644 --- a/src/blox/connections.jl +++ b/src/blox/connections.jl @@ -24,6 +24,27 @@ function accumulate_equation!(bc::BloxConnector, eq) bc.eqs[idx] = bc.eqs[idx].lhs ~ bc.eqs[idx].rhs + eq.rhs end +get_equations_with_parameter_lhs(bc) = filter(eq -> isparameter(eq.lhs), bc.eqs) + +get_equations_with_state_lhs(bc) = filter(eq -> !isparameter(eq.lhs), bc.eqs) + +function get_callbacks(bc, t_affect=missing) + if !ismissing(t_affect) + cbs_params = t_affect => get_equations_with_parameter_lhs(bc) + + return vcat(cbs_params, bc.events) + else + return bc.events + end +end + +function generate_callbacks_for_parameter_lhs(bc) + eqs = get_equations_with_parameter_lhs(bc) + cbs = [bc.param_update_times[eq.lhs] => eq for eq in eqs] + + return cbs +end + function generate_weight_param(blox_out, blox_in; kwargs...) name_out = namespaced_nameof(blox_out) name_in = namespaced_nameof(blox_in) From 5e8b621b0e117835b3d281c19924aa8dbc4417d3 Mon Sep 17 00:00:00 2001 From: Haris Orgn Date: Tue, 28 Nov 2023 09:43:52 +0200 Subject: [PATCH 5/8] make `jcn` a parameter in `AbstractDiscrete` blox --- src/blox/discrete.jl | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/src/blox/discrete.jl b/src/blox/discrete.jl index 54b3c5ce..1eb05fcd 100644 --- a/src/blox/discrete.jl +++ b/src/blox/discrete.jl @@ -1,15 +1,15 @@ -abstract type AbstractAlgebraic <: AbstractBlox end +abstract type AbstractDiscrete <: AbstractBlox end -abstract type AbstractModulator <: AbstractAlgebraic end +abstract type AbstractModulator <: AbstractDiscrete end -struct Matrisome <: AbstractAlgebraic +struct Matrisome <: AbstractDiscrete odesystem namespace function Matrisome(; name, namespace=nothing) @variables t - sts = @variables ρ(t)=0.0 [irreducible=true] jcn(t)=0.0 [input=true] - ps = @parameters H=1 + sts = @variables ρ(t)=0.0 [irreducible=true] + ps = @parameters H=1 jcn=0.0 [input=true] eqs = [ ρ ~ H*jcn ] @@ -19,16 +19,16 @@ struct Matrisome <: AbstractAlgebraic end end -struct Striosome <: AbstractAlgebraic +struct Striosome <: AbstractDiscrete odesystem namespace function Striosome(; name, namespace=nothing) @variables t - sts = @variables ρ(t)=0.0 [irreducible=true] jcn(t)=0.0 [input=true] - ps = @parameters H=1 + sts = @variables ρ(t)=0.0 [irreducible=true] + ps = @parameters H=1 jcn=0.0 [input=true] eqs = [ - ρ ~ H*jcn + 0.1 + ρ ~ H*jcn ] sys = ODESystem(eqs, t, sts, ps; name) @@ -36,14 +36,14 @@ struct Striosome <: AbstractAlgebraic end end -struct TAN <: AbstractAlgebraic +struct TAN <: AbstractDiscrete odesystem namespace function TAN(; name, namespace=nothing, κ=0.2) @variables t - sts = @variables R(t)=κ [irreducible=true] jcn(t)=0.0 [input=true] - ps = @parameters κ=κ spikes_window=0.0 + sts = @variables R(t)=κ [irreducible=true] + ps = @parameters κ=κ spikes_window=0.0 jcn=0.0 [input=true] eqs = [ R ~ IfElse.ifelse(iszero(jcn), κ, κ/jcn) ] @@ -62,8 +62,8 @@ struct SNc <: AbstractModulator function SNc(; name, namespace=nothing, κ_DA=0.2, N_time_blocks=5, DA_reward=10) @variables t - sts = @variables R(t)=κ_DA jcn(t)=0.0 [input=true] - ps = @parameters κ=κ_DA + sts = @variables R(t)=κ_DA + ps = @parameters κ=κ_DA jcn=0.0 [input=true] eqs = [ R ~ IfElse.ifelse(iszero(jcn), κ, κ/jcn) ] From d19d055d1dd0f4806c00ddb9bb58489848ca10e9 Mon Sep 17 00:00:00 2001 From: Haris Orgn Date: Tue, 28 Nov 2023 09:44:12 +0200 Subject: [PATCH 6/8] add `t_affect` in `Agent` for param callbacks --- src/blox/reinforcement_learning.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/blox/reinforcement_learning.jl b/src/blox/reinforcement_learning.jl index c3ad472d..5d53d514 100644 --- a/src/blox/reinforcement_learning.jl +++ b/src/blox/reinforcement_learning.jl @@ -154,7 +154,8 @@ mutable struct Agent function Agent(g::MetaDiGraph; name, kwargs...) bc = connector_from_graph(g) - sys = system_from_graph(g, bc; name) + t_affect = haskey(kwargs, :t_block) ? kwargs[:t_block] : missing + sys = system_from_graph(g, bc; name, t_affect) ss = structural_simplify(sys; allow_parameter=false) u0 = haskey(kwargs, :u0) ? kwargs[:u0] : [] From 2abd506208a35ed5813973d1ee7a09b3fabfaaba Mon Sep 17 00:00:00 2001 From: Haris Orgn Date: Tue, 28 Nov 2023 09:44:25 +0200 Subject: [PATCH 7/8] fix names to not overlap with types --- test/jansen_rit_component_tests_new_timing.jl | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/test/jansen_rit_component_tests_new_timing.jl b/test/jansen_rit_component_tests_new_timing.jl index 6b8d58ba..a809737a 100644 --- a/test/jansen_rit_component_tests_new_timing.jl +++ b/test/jansen_rit_component_tests_new_timing.jl @@ -4,16 +4,16 @@ using Neuroblox, DifferentialEquations, DataFrames, Test, Distributions, Statist # Create Regions @named Str = jansen_ritC(τ=0.0022*τ_factor, H=20, λ=300, r=0.3) -@named GPe = jansen_ritC(τ=0.04*τ_factor, H=20, λ=400, r=0.1) -@named STN = jansen_ritC(τ=0.01*τ_factor, H=20, λ=500, r=0.1) -@named GPi = jansen_ritSC(τ=0.014*τ_factor, H=20, λ=400, r=0.1) +@named gpe = jansen_ritC(τ=0.04*τ_factor, H=20, λ=400, r=0.1) +@named stn = jansen_ritC(τ=0.01*τ_factor, H=20, λ=500, r=0.1) +@named gpi = jansen_ritSC(τ=0.014*τ_factor, H=20, λ=400, r=0.1) @named Th = jansen_ritSC(τ=0.002*τ_factor, H=10, λ=20, r=5) @named EI = jansen_ritSC(τ=0.01*τ_factor, H=20, λ=5, r=5) @named PY = jansen_ritSC(τ=0.001*τ_factor, H=20, λ=5, r=0.15) @named II = jansen_ritSC(τ=2.0*τ_factor, H=60, λ=5, r=5) # Connect Regions through Adjacency Matrix -blox = [Str, GPe, STN, GPi, Th, EI, PY, II] +blox = [Str, gpe, stn, gpi, Th, EI, PY, II] sys = [s.odesystem for s in blox] connect = [s.connector for s in blox] @@ -33,7 +33,7 @@ adj_matrix_lin = [0 0 0 0 0 0 0 0; sim_dur = 2000.0 # Simulate for 2 seconds mysys = structural_simplify(CBGTC_Circuit_lin) sol = simulate(mysys, [], (0.0, sim_dur), [], Vern7(); saveat=1) -@test sol[!, "GPi₊x(t)"][4] ≈ -2219.2560209502685 #updated to new value in ms +@test sol[!, "gpi₊x(t)"][4] ≈ -2219.2560209502685 #updated to new value in ms """ Testing new Jansen-Rit blox @@ -47,14 +47,14 @@ same thing as the old simulate call with AutoVern7(Rodas4() since there are no d # test new Jansen-Rit blox @named Str = JansenRit(τ=0.0022*τ_factor, H=20, λ=300, r=0.3) -@named GPe = JansenRit(τ=0.04*τ_factor, cortical=false) # all default subcortical except τ -@named STN = JansenRit(τ=0.01*τ_factor, H=20, λ=500, r=0.1) -@named GPi = JansenRit(cortical=false) # default parameters subcortical Jansen Rit blox +@named gpe = JansenRit(τ=0.04*τ_factor, cortical=false) # all default subcortical except τ +@named stn = JansenRit(τ=0.01*τ_factor, H=20, λ=500, r=0.1) +@named gpi = JansenRit(cortical=false) # default parameters subcortical Jansen Rit blox @named Th = JansenRit(τ=0.002*τ_factor, H=10, λ=20, r=5) @named EI = JansenRit(τ=0.01*τ_factor, H=20, λ=5, r=5) @named PY = JansenRit(cortical=true) # default parameters cortical Jansen Rit blox @named II = JansenRit(τ=2.0*τ_factor, H=60, λ=5, r=5) -blox = [Str, GPe, STN, GPi, Th, EI, PY, II] +blox = [Str, gpe, stn, gpi, Th, EI, PY, II] # test graphs g = MetaDiGraph() @@ -104,7 +104,7 @@ prob = DDEProblem(final_system_sys, alg = MethodOfSteps(Vern7()) sol_dde_no_delays = solve(prob, alg, saveat=1) sol2 = DataFrame(sol_dde_no_delays) -@test isapprox(sol2[!, "GPi₊x(t)"][500:1000], sol[!, "GPi₊x(t)"][500:1000], rtol=1e-8) +@test isapprox(sol2[!, "gpi₊x(t)"][500:1000], sol[!, "gpi₊x(t)"][500:1000], rtol=1e-8) # Alternative version using adjacency matrix @@ -123,4 +123,4 @@ prob = DDEProblem(final_system_sys, alg = MethodOfSteps(Vern7()) sol_dde_no_delays = solve(prob, alg, saveat=1) sol3 = DataFrame(sol_dde_no_delays) -@test isapprox(sol3[!, "GPi₊x(t)"][500:1000], sol[!, "GPi₊x(t)"][500:1000], rtol=1e-8) \ No newline at end of file +@test isapprox(sol3[!, "gpi₊x(t)"][500:1000], sol[!, "gpi₊x(t)"][500:1000], rtol=1e-8) \ No newline at end of file From ddfbb5410ee397552604b315d17522ea635810b0 Mon Sep 17 00:00:00 2001 From: Haris Orgn Date: Tue, 28 Nov 2023 09:44:40 +0200 Subject: [PATCH 8/8] update RL test --- test/reinforcement_learning.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/test/reinforcement_learning.jl b/test/reinforcement_learning.jl index ff93f743..147e45de 100644 --- a/test/reinforcement_learning.jl +++ b/test/reinforcement_learning.jl @@ -43,13 +43,14 @@ add_edge!(g, d[STR_R], d[SNcb], Dict(:weight => 1)) add_edge!(g, d[STR_L], d[AS]) add_edge!(g, d[STR_R], d[AS]) -agent = Agent(g; name=:ag) +agent = Agent(g; name=:ag, t_block = t_trial/5); ps = parameters(agent.odesystem) init_params = agent.problem.p map_idxs = Int.(ModelingToolkit.varmap_to_vars([ps[i] => i for i in eachindex(ps)], ps)) idxs_weight = findall(x -> occursin("w_", String(Symbol(x))), ps) idx_stim = findall(x -> occursin("stim₊", String(Symbol(x))), ps) -idxs_other_params = setdiff(eachindex(ps), vcat(idxs_weight, idx_stim)) +idx_jcn = findall(x -> occursin("jcn", String(Symbol(x))), ps) +idxs_other_params = setdiff(eachindex(ps), vcat(idxs_weight, idx_stim, idx_jcn)) env = ClassificationEnvironment(stim; name=:env, namespace=global_ns) run_experiment!(agent, env; alg=QNDF(), reltol=1e-9,abstol=1e-9)