Skip to content

Commit

Permalink
Merge pull request #311 from Neuroblox/ho/jcn_param
Browse files Browse the repository at this point in the history
Make `jcn` a paremeter in `AbstractDiscrete` blox
  • Loading branch information
harisorgn authored Nov 28, 2023
2 parents 4b7c6a9 + ddfbb54 commit dcadc3e
Show file tree
Hide file tree
Showing 7 changed files with 76 additions and 48 deletions.
6 changes: 3 additions & 3 deletions src/Neuroblox.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,12 @@ using DelayDiffEq
using StatsBase: sample
using Distributions

using ModelingToolkit: get_namespace, get_systems, renamespace,
namespace_equation, namespace_variables, namespace_parameters, namespace_expr,
using ModelingToolkit: get_namespace, get_systems, isparameter,
renamespace, namespace_equation, namespace_parameters, namespace_expr,
AbstractODESystem
import ModelingToolkit: inputs, nameof

using Symbolics: @register_symbolic
using Symbolics: @register_symbolic, getdefaultval
using IfElse

using DelimitedFiles: readdlm
Expand Down
25 changes: 17 additions & 8 deletions src/Neurographs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -180,27 +180,36 @@ function graph_delays(g::MetaDiGraph)
return bc.delays
end

function system_from_graph(g::MetaDiGraph; name)
function system_from_graph(g::MetaDiGraph; name, t_affect=missing)
bc = connector_from_graph(g)
return system_from_graph(g, bc; name)
return system_from_graph(g, bc; name, t_affect)
end

# Additional dispatch if extra parameters are passed for edge definitions
function system_from_graph(g::MetaDiGraph, p::Vector{Num}; name)
function system_from_graph(g::MetaDiGraph, p::Vector{Num}; name, t_affect=missing)
bc = connector_from_graph(g)
return system_from_graph(g, bc, p; name)
return system_from_graph(g, bc, p; name, t_affect)
end

function system_from_graph(g::MetaDiGraph, bc::BloxConnector; name)
function system_from_graph(g::MetaDiGraph, bc::BloxConnector; name, t_affect=missing)
@variables t
blox_syss = get_sys(g)
return compose(ODESystem(bc.eqs, t, [], params(bc); name, discrete_events = bc.events), blox_syss)

connection_eqs = get_equations_with_state_lhs(bc)
cbs = get_callbacks(bc, t_affect)

return compose(ODESystem(connection_eqs, t, [], params(bc); name, discrete_events = cbs), blox_syss)
end

function system_from_graph(g::MetaDiGraph, bc::BloxConnector, p::Vector{Num}; name)
function system_from_graph(g::MetaDiGraph, bc::BloxConnector, p::Vector{Num}; name, t_affect=missing)
@variables t
blox_syss = get_sys(g)
return compose(ODESystem(bc.eqs, t, [], vcat(params(bc), p); name, discrete_events = bc.events), blox_syss)

connection_eqs = get_equations_with_state_lhs(bc)

cbs = get_callbacks(bc, t_affect)

return compose(ODESystem(connection_eqs, t, [], vcat(params(bc), p); name, discrete_events = cbs), blox_syss)
end

function system_from_parts(parts::AbstractVector; name)
Expand Down
35 changes: 26 additions & 9 deletions src/blox/connections.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,27 @@ function accumulate_equation!(bc::BloxConnector, eq)
bc.eqs[idx] = bc.eqs[idx].lhs ~ bc.eqs[idx].rhs + eq.rhs
end

get_equations_with_parameter_lhs(bc) = filter(eq -> isparameter(eq.lhs), bc.eqs)

get_equations_with_state_lhs(bc) = filter(eq -> !isparameter(eq.lhs), bc.eqs)

function get_callbacks(bc, t_affect=missing)
if !ismissing(t_affect)
cbs_params = t_affect => get_equations_with_parameter_lhs(bc)

return vcat(cbs_params, bc.events)
else
return bc.events
end
end

function generate_callbacks_for_parameter_lhs(bc)
eqs = get_equations_with_parameter_lhs(bc)
cbs = [bc.param_update_times[eq.lhs] => eq for eq in eqs]

return cbs
end

function generate_weight_param(blox_out, blox_in; kwargs...)
name_out = namespaced_nameof(blox_out)
name_in = namespaced_nameof(blox_in)
Expand Down Expand Up @@ -119,7 +140,6 @@ function (bc::BloxConnector)(

end


function (bc::BloxConnector)(
bloxout::NeuralMassBlox,
bloxin::NeuralMassBlox;
Expand All @@ -133,8 +153,6 @@ function (bc::BloxConnector)(

if haskey(kwargs, :learning_rule)
lr = kwargs[:learning_rule]
# maybe_set_state_pre!(lr,
# maybe_set_state_post!(lr,
bc.learning_rules[w] = lr
end

Expand Down Expand Up @@ -293,8 +311,7 @@ function (bc::BloxConnector)(
end

eq = sys_in.jcn ~ w*sys_out.spikes_window

accumulate_equation!(bc, eq)
accumulate_equation!(bc, eq)
end

function (bc::BloxConnector)(
Expand All @@ -308,16 +325,16 @@ function (bc::BloxConnector)(
neurons_in = get_inh_neurons(str_in)

t_event = get_event_time(kwargs, nameof(str_out), nameof(str_in))
cb_matr = [t_event] => [sys_matr_in.H ~ IfElse.ifelse(sys_matr_out.ρ > sys_matr_in.ρ, 0, 1)]
cb_strios = [t_event] => [sys_strios_in.H ~ IfElse.ifelse(sys_matr_out.ρ > sys_matr_in.ρ, 0, 1)]
cb_matr = t_event => [sys_matr_in.H ~ IfElse.ifelse(sys_matr_out.ρ > sys_matr_in.ρ, 0, 1)]
cb_strios = t_event => [sys_strios_in.H ~ IfElse.ifelse(sys_matr_out.ρ > sys_matr_in.ρ, 0, 1)]
push!(bc.events, cb_matr)
push!(bc.events, cb_strios)

for neuron in neurons_in
sys_neuron = get_namespaced_sys(neuron)
# Large negative current added to shut down the Striatum spiking neurons.
# Value is hardcoded for now, as it's more of a hack, not user option.
cb_neuron = [t_event] => [sys_neuron.I_bg ~ IfElse.ifelse(sys_matr_out.ρ > sys_matr_in.ρ, -2, 0)]
cb_neuron = t_event => [sys_neuron.I_bg ~ IfElse.ifelse(sys_matr_out.ρ > sys_matr_in.ρ, -2, 0)]
push!(bc.events, cb_neuron)
end
end
Expand Down Expand Up @@ -404,7 +421,7 @@ function (bc::BloxConnector)(
sys_in = get_namespaced_sys(discr_in)

t_event = get_event_time(kwargs, nameof(discr_out), nameof(discr_in))
cb = [t_event] => [sys_in.H ~ IfElse.ifelse(sys_out.ρ > sys_in.ρ, 0, 1)]
cb = t_event => [sys_in.H ~ IfElse.ifelse(sys_out.ρ > sys_in.ρ, 0, 1)]
push!(bc.events, cb)
end

Expand Down
28 changes: 14 additions & 14 deletions src/blox/discrete.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
abstract type AbstractAlgebraic <: AbstractBlox end
abstract type AbstractDiscrete <: AbstractBlox end

abstract type AbstractModulator <: AbstractAlgebraic end
abstract type AbstractModulator <: AbstractDiscrete end

struct Matrisome <: AbstractAlgebraic
struct Matrisome <: AbstractDiscrete
odesystem
namespace

function Matrisome(; name, namespace=nothing)
@variables t
sts = @variables ρ(t)=0.0 [irreducible=true] jcn(t)=0.0 [input=true]
ps = @parameters H=1
sts = @variables ρ(t)=0.0 [irreducible=true]
ps = @parameters H=1 jcn=0.0 [input=true]
eqs = [
ρ ~ H*jcn
]
Expand All @@ -19,31 +19,31 @@ struct Matrisome <: AbstractAlgebraic
end
end

struct Striosome <: AbstractAlgebraic
struct Striosome <: AbstractDiscrete
odesystem
namespace

function Striosome(; name, namespace=nothing)
@variables t
sts = @variables ρ(t)=0.0 [irreducible=true] jcn(t)=0.0 [input=true]
ps = @parameters H=1
sts = @variables ρ(t)=0.0 [irreducible=true]
ps = @parameters H=1 jcn=0.0 [input=true]
eqs = [
ρ ~ H*jcn + 0.1
ρ ~ H*jcn
]
sys = ODESystem(eqs, t, sts, ps; name)

new(sys, namespace)
end
end

struct TAN <: AbstractAlgebraic
struct TAN <: AbstractDiscrete
odesystem
namespace

function TAN(; name, namespace=nothing, κ=0.2)
@variables t
sts = @variables R(t)=κ [irreducible=true] jcn(t)=0.0 [input=true]
ps = @parameters κ=κ spikes_window=0.0
sts = @variables R(t)=κ [irreducible=true]
ps = @parameters κ=κ spikes_window=0.0 jcn=0.0 [input=true]
eqs = [
R ~ IfElse.ifelse(iszero(jcn), κ, κ/jcn)
]
Expand All @@ -62,8 +62,8 @@ struct SNc <: AbstractModulator

function SNc(; name, namespace=nothing, κ_DA=0.2, N_time_blocks=5, DA_reward=10)
@variables t
sts = @variables R(t)=κ_DA jcn(t)=0.0 [input=true]
ps = @parameters κ=κ_DA
sts = @variables R(t)=κ_DA
ps = @parameters κ=κ_DA jcn=0.0 [input=true]
eqs = [
R ~ IfElse.ifelse(iszero(jcn), κ, κ/jcn)
]
Expand Down
3 changes: 2 additions & 1 deletion src/blox/reinforcement_learning.jl
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,8 @@ mutable struct Agent
function Agent(g::MetaDiGraph; name, kwargs...)
bc = connector_from_graph(g)

sys = system_from_graph(g, bc; name)
t_affect = haskey(kwargs, :t_block) ? kwargs[:t_block] : missing
sys = system_from_graph(g, bc; name, t_affect)
ss = structural_simplify(sys; allow_parameter=false)

u0 = haskey(kwargs, :u0) ? kwargs[:u0] : []
Expand Down
22 changes: 11 additions & 11 deletions test/jansen_rit_component_tests_new_timing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,16 @@ using Neuroblox, DifferentialEquations, DataFrames, Test, Distributions, Statist

# Create Regions
@named Str = jansen_ritC=0.0022*τ_factor, H=20, λ=300, r=0.3)
@named GPe = jansen_ritC=0.04*τ_factor, H=20, λ=400, r=0.1)
@named STN = jansen_ritC=0.01*τ_factor, H=20, λ=500, r=0.1)
@named GPi = jansen_ritSC=0.014*τ_factor, H=20, λ=400, r=0.1)
@named gpe = jansen_ritC=0.04*τ_factor, H=20, λ=400, r=0.1)
@named stn = jansen_ritC=0.01*τ_factor, H=20, λ=500, r=0.1)
@named gpi = jansen_ritSC=0.014*τ_factor, H=20, λ=400, r=0.1)
@named Th = jansen_ritSC=0.002*τ_factor, H=10, λ=20, r=5)
@named EI = jansen_ritSC=0.01*τ_factor, H=20, λ=5, r=5)
@named PY = jansen_ritSC=0.001*τ_factor, H=20, λ=5, r=0.15)
@named II = jansen_ritSC=2.0*τ_factor, H=60, λ=5, r=5)

# Connect Regions through Adjacency Matrix
blox = [Str, GPe, STN, GPi, Th, EI, PY, II]
blox = [Str, gpe, stn, gpi, Th, EI, PY, II]
sys = [s.odesystem for s in blox]
connect = [s.connector for s in blox]

Expand All @@ -33,7 +33,7 @@ adj_matrix_lin = [0 0 0 0 0 0 0 0;
sim_dur = 2000.0 # Simulate for 2 seconds
mysys = structural_simplify(CBGTC_Circuit_lin)
sol = simulate(mysys, [], (0.0, sim_dur), [], Vern7(); saveat=1)
@test sol[!, "GPi₊x(t)"][4] -2219.2560209502685 #updated to new value in ms
@test sol[!, "gpi₊x(t)"][4] -2219.2560209502685 #updated to new value in ms

"""
Testing new Jansen-Rit blox
Expand All @@ -47,14 +47,14 @@ same thing as the old simulate call with AutoVern7(Rodas4() since there are no d

# test new Jansen-Rit blox
@named Str = JansenRit=0.0022*τ_factor, H=20, λ=300, r=0.3)
@named GPe = JansenRit=0.04*τ_factor, cortical=false) # all default subcortical except τ
@named STN = JansenRit=0.01*τ_factor, H=20, λ=500, r=0.1)
@named GPi = JansenRit(cortical=false) # default parameters subcortical Jansen Rit blox
@named gpe = JansenRit=0.04*τ_factor, cortical=false) # all default subcortical except τ
@named stn = JansenRit=0.01*τ_factor, H=20, λ=500, r=0.1)
@named gpi = JansenRit(cortical=false) # default parameters subcortical Jansen Rit blox
@named Th = JansenRit=0.002*τ_factor, H=10, λ=20, r=5)
@named EI = JansenRit=0.01*τ_factor, H=20, λ=5, r=5)
@named PY = JansenRit(cortical=true) # default parameters cortical Jansen Rit blox
@named II = JansenRit=2.0*τ_factor, H=60, λ=5, r=5)
blox = [Str, GPe, STN, GPi, Th, EI, PY, II]
blox = [Str, gpe, stn, gpi, Th, EI, PY, II]

# test graphs
g = MetaDiGraph()
Expand Down Expand Up @@ -104,7 +104,7 @@ prob = DDEProblem(final_system_sys,
alg = MethodOfSteps(Vern7())
sol_dde_no_delays = solve(prob, alg, saveat=1)
sol2 = DataFrame(sol_dde_no_delays)
@test isapprox(sol2[!, "GPi₊x(t)"][500:1000], sol[!, "GPi₊x(t)"][500:1000], rtol=1e-8)
@test isapprox(sol2[!, "gpi₊x(t)"][500:1000], sol[!, "gpi₊x(t)"][500:1000], rtol=1e-8)


# Alternative version using adjacency matrix
Expand All @@ -123,4 +123,4 @@ prob = DDEProblem(final_system_sys,
alg = MethodOfSteps(Vern7())
sol_dde_no_delays = solve(prob, alg, saveat=1)
sol3 = DataFrame(sol_dde_no_delays)
@test isapprox(sol3[!, "GPi₊x(t)"][500:1000], sol[!, "GPi₊x(t)"][500:1000], rtol=1e-8)
@test isapprox(sol3[!, "gpi₊x(t)"][500:1000], sol[!, "gpi₊x(t)"][500:1000], rtol=1e-8)
5 changes: 3 additions & 2 deletions test/reinforcement_learning.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,14 @@ add_edge!(g, d[STR_R], d[SNcb], Dict(:weight => 1))
add_edge!(g, d[STR_L], d[AS])
add_edge!(g, d[STR_R], d[AS])

agent = Agent(g; name=:ag)
agent = Agent(g; name=:ag, t_block = t_trial/5);
ps = parameters(agent.odesystem)
init_params = agent.problem.p
map_idxs = Int.(ModelingToolkit.varmap_to_vars([ps[i] => i for i in eachindex(ps)], ps))
idxs_weight = findall(x -> occursin("w_", String(Symbol(x))), ps)
idx_stim = findall(x -> occursin("stim₊", String(Symbol(x))), ps)
idxs_other_params = setdiff(eachindex(ps), vcat(idxs_weight, idx_stim))
idx_jcn = findall(x -> occursin("jcn", String(Symbol(x))), ps)
idxs_other_params = setdiff(eachindex(ps), vcat(idxs_weight, idx_stim, idx_jcn))

env = ClassificationEnvironment(stim; name=:env, namespace=global_ns)
run_experiment!(agent, env; alg=QNDF(), reltol=1e-9,abstol=1e-9)
Expand Down

0 comments on commit dcadc3e

Please sign in to comment.