Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Neural mass housekeeping #530

Open
wants to merge 15 commits into
base: master
Choose a base branch
from
Open
5 changes: 4 additions & 1 deletion src/GraphDynamicsInterop/GraphDynamicsInterop.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ using ..Neuroblox:
Matrisome,
TAN,
SNc,
Noisy,
NonNoisy,
KuramotoOscillator,
CorticalBlox,
STN,
Expand All @@ -47,7 +49,8 @@ using ..Neuroblox:
LIFInhCircuitBlox,
PINGNeuronExci,
PINGNeuronInhib,
AbstractPINGNeuron
AbstractPINGNeuron,
VanDerPol

using GraphDynamics:
GraphDynamics,
Expand Down
5 changes: 3 additions & 2 deletions src/GraphDynamicsInterop/connection_interop.jl
Original file line number Diff line number Diff line change
Expand Up @@ -218,11 +218,12 @@ function get_connection(src::KuramotoOscillator, dst::KuramotoOscillator, kwargs
(;conn=BasicConnection(w_val), names=[name])
end

function (c::BasicConnection)(src::Subsystem{KuramotoOscillator}, dst::Subsystem{KuramotoOscillator})
function (c::BasicConnection)(src::Subsystem{<:KuramotoOscillator},
dst::Subsystem{<:KuramotoOscillator})
w = c.weight
x₀ = src.θ
xᵢ = dst.θ
w * sin(x₀ - xᵢ)
(;jcn = w * sin(x₀ - xᵢ))
end

#----------------------------------------------
Expand Down
82 changes: 24 additions & 58 deletions src/GraphDynamicsInterop/neuron_interop.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,38 +13,37 @@ function recursive_getdefault(x::Union{MTK.Num, MTK.BasicSymbolic})
substitute(def_x, defs)
end


function define_neurons()
for (name, T) ∈ [(:hhne, :HHNeuronExciBlox)
(:hhni, :HHNeuronInhibBlox)
(:hhni_msn_adam, :HHNeuronInhib_MSN_Adam_Blox)
(:hhni_fsi_adam, :HHNeuronInhib_FSI_Adam_Blox)
(:hhne_stn_adam, :HHNeuronExci_STN_Adam_Blox)
(:hhni_GPe_adam, :HHNeuronInhib_GPe_Adam_Blox)
(:ngei, :NextGenerationEIBlox)
(:wc, :WilsonCowan)
(:ho, :HarmonicOscillator)
(:jr, :JansenRit) # Note! Regular JansenRit can support delays, and I have not yet implemented this!
(:if, :IFNeuron)
(:lif, :LIFNeuron)
(:qif, :QIFNeuron)
(:izh, :IzhikevichNeuron)
(:lif_exci, :LIFExciNeuron)
(:lif_inh, :LIFInhNeuron)
(:pexci, :PINGNeuronExci)
(:pinhib, :PINGNeuronInhib)
]
sys = getproperty(Neuroblox, T)(;name)
for (name, T) ∈ [(:hhne, HHNeuronExciBlox)
(:hhni, HHNeuronInhibBlox)
(:hhni_msn_adam, HHNeuronInhib_MSN_Adam_Blox)
(:hhni_fsi_adam, HHNeuronInhib_FSI_Adam_Blox)
(:hhne_stn_adam, HHNeuronExci_STN_Adam_Blox)
(:hhni_GPe_adam, HHNeuronInhib_GPe_Adam_Blox)
(:ngei, NextGenerationEIBlox)
(:wc, WilsonCowan)
(:ho, HarmonicOscillator)
(:jr, JansenRit) # Note! Regular JansenRit can support delays, and I have not yet implemented this!
(:if, IFNeuron)
(:lif, LIFNeuron)
(:qif, QIFNeuron)
(:izh, IzhikevichNeuron)
(:lif_exci, LIFExciNeuron)
(:lif_inh, LIFInhNeuron)
(:pexci, PINGNeuronExci)
(:pinhib, PINGNeuronInhib)
(:VdP, VanDerPol{NonNoisy})
(:VdPN, VanDerPol{Noisy})
(:ko, KuramotoOscillator{NonNoisy})
(:kon, KuramotoOscillator{Noisy})]
sys = T(;name)
system = structural_simplify(sys.system; fully_determined=false)
params = get_ps(system)
t = Symbol(get_iv(system))

states = [s for s ∈ unknowns(system) if !MTK.isinput(s)]
inputs = [s for s ∈ unknowns(system) if MTK.isinput(s)]

# states_unwrapped = map(x -> x.f, states)
# inputs_unwrapped = map(x -> x.f, inputs)

p_syms = map(Symbol, params)
s_syms = map(x -> tosymbol(x; escape=false), states)
input_syms = map(x -> tosymbol(x; escape=false), inputs)
Expand Down Expand Up @@ -137,7 +136,7 @@ function define_neurons()
end
end
end
if !isempty(get_discrete_events(system)) && T ∉ (:LIFExciNeuron, :LIFInhNeuron)
if !isempty(get_discrete_events(system)) && T ∉ (LIFExciNeuron, LIFInhNeuron)
cb = only(collect(get_discrete_events(system))) # currently only support single events
cb_eq = r(cb.condition)
if cb_eq.f ∉ (<, >, <=, >=)
Expand Down Expand Up @@ -166,18 +165,6 @@ function define_neurons()
end
define_neurons() # it's useful when developing this module to have these in a function

#Maybe should just encorporate this into define_neurons()
# for T ∈ [:LIFExciNeuron, :LIFInhNeuron]
# @eval begin
# GraphDynamics.has_discrete_events(::Type{$T}) = true
# GraphDynamics.discrete_event_condition((; t_refract_end)::Subsystem{$T}, t) = t_refract_end == t
# function GraphDynamics.apply_discrete_event!(integrator, _, pview, neuron::Subsystem{$T}, _)
# params = get_params(neuron)
# pview[] = @set params.is_refractory = 0
# end
# end
# end

issupported(::PoissonSpikeTrain) = true
components(p::PoissonSpikeTrain) = (p,)
function to_subsystem(s::PoissonSpikeTrain)
Expand All @@ -189,27 +176,6 @@ GraphDynamics.initialize_input(s::Subsystem{PoissonSpikeTrain}) = (;)
GraphDynamics.apply_subsystem_differential!(_, ::Subsystem{PoissonSpikeTrain}, _, _) = nothing
GraphDynamics.subsystem_differential_requires_inputs(::Type{PoissonSpikeTrain}) = false


#-------------------------
# Kuramoto
issupported(::KuramotoOscillator) = true
function to_subsystem(o::KuramotoOscillator)
states = SubsystemStates{KuramotoOscillator}((;θ=0.0,))
params = SubsystemParams{KuramotoOscillator}((;ω=getdefault(o.system.ω), ζ=getdefault(o.system.ζ)))
Subsystem(states, params)
end

GraphDynamics.initialize_input(s::Subsystem{KuramotoOscillator}) = 0.0
function GraphDynamics.subsystem_differential(s::Subsystem{KuramotoOscillator}, jcn, t)
SubsystemStates{KuramotoOscillator}((; #=D=#θ= s.ω + jcn))
end
GraphDynamics.isstochastic(::Type{KuramotoOscillator}) = true
function GraphDynamics.apply_subsystem_noise!(vstates, (;ζ,)::Subsystem{KuramotoOscillator}, t)
vstates[1] = ζ
nothing
end


#-------------------------
# Matrisome
issupported(::Matrisome) = true
Expand Down
3 changes: 1 addition & 2 deletions src/Neuroblox.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,6 @@ include("blox/cortical.jl")
include("blox/canonicalmicrocircuit.jl")
include("blox/neuron_models.jl")
include("blox/DBS_Model_Blox_Adam_Brown.jl")
include("blox/van_der_pol.jl")
include("blox/ts_outputs.jl")
include("blox/sources.jl")
include("blox/DBS_sources.jl")
Expand Down Expand Up @@ -225,7 +224,7 @@ end

export Neuron
export JansenRitSPM12, next_generation, qif_neuron, if_neuron, hh_neuron_excitatory,
hh_neuron_inhibitory, van_der_pol, Generic2dOscillator
hh_neuron_inhibitory, VanDerPol, Generic2dOscillator, kuramoto_oscillator
export HHNeuronExciBlox, HHNeuronInhibBlox, IFNeuron, LIFNeuron, QIFNeuron, IzhikevichNeuron, LIFExciNeuron, LIFInhNeuron,
CanonicalMicroCircuitBlox, WinnerTakeAllBlox, CorticalBlox, SuperCortical, HHNeuronInhib_MSN_Adam_Blox, HHNeuronInhib_FSI_Adam_Blox, HHNeuronExci_STN_Adam_Blox,
HHNeuronInhib_GPe_Adam_Blox, Striatum_MSN_Adam, Striatum_FSI_Adam, GPe_Adam, STN_Adam, LIFExciCircuitBlox, LIFInhCircuitBlox
Expand Down
90 changes: 68 additions & 22 deletions src/blox/neural_mass.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
struct Noisy end
struct NonNoisy end

mutable struct NextGenerationBlox <: NeuralMassBlox
C::Num
Δ::Num
Expand Down Expand Up @@ -465,9 +468,10 @@ end

where \$W_i\$ is a Wiener process and \$\\zeta_i\$ is the noise strength.

Arguments:
Keyword arguments:
- name: Name given to ODESystem object within the blox.
- namespace: Additional namespace above name if needed for inheritance.
- `include_noise` (default `false`) determines if brownian noise is included in the dynamics of the blox.
- Other parameters: See reference for full list. Note that parameters are scaled so that units of time are in milliseconds.
Default parameter values are taken from [2].

Expand All @@ -481,34 +485,39 @@ Citations:
2024 Jun 14;199:106565. doi: 10.1016/j.nbd.2024.106565. Epub ahead of print. PMID: 38880431.

"""
struct KuramotoOscillator <: NeuralMassBlox
struct KuramotoOscillator{IsNoisy} <: NeuralMassBlox
params
system
namespace

function KuramotoOscillator(;
name,
namespace=nothing,
ω=249.0,
ζ=5.92,
include_noise=false
)
p = paramscoping(ω=ω, ζ=ζ)
ω, ζ = p

function KuramotoOscillator(; name,
namespace=nothing,
ω=249.0,
ζ=5.92,
include_noise=false)
if include_noise
sts = @variables θ(t)=0.0 [output = true] jcn(t) [input=true]
@brownian w
eqs = [D(θ) ~ ω + ζ * w + jcn]
sys = System(eqs, t, sts, p; name=name)
new(p, sys, namespace)
KuramotoOscillator{Noisy}(;name, namespace, ω, ζ)
else
sts = @variables θ(t)=0.0 [output = true] jcn(t) [input=true]
eqs = [D(θ) ~ ω + jcn]
sys = System(eqs, t, sts, p; name=name)
new(p, sys, namespace)
KuramotoOscillator{NonNoisy}(;name, namespace, ω)
end
end
function KuramotoOscillator{Noisy}(;name, namespace=nothing, ω=249.0, ζ=5.92)
p = paramscoping(ω=ω, ζ=ζ)
ω, ζ = p
sts = @variables θ(t)=0.0 [output = true] jcn(t) [input=true]
@brownian w
eqs = [D(θ) ~ ω + jcn + ζ*w]
sys = System(eqs, t, sts, p; name=name)
new{Noisy}(p, sys, namespace)
end
function KuramotoOscillator{NonNoisy}(;name, namespace=nothing, ω=249.0)
p = paramscoping(ω=ω)
ω = p[1]
sts = @variables θ(t)=0.0 [output = true] jcn(t) [input=true]
eqs = [D(θ) ~ ω + jcn]
sys = System(eqs, t, sts, p; name=name)
new{NonNoisy}(p, sys, namespace)
end
end

struct PYR_Izh <: NeuralMassBlox
Expand Down Expand Up @@ -570,4 +579,41 @@ struct QIF_PING_NGNMM <: NeuralMassBlox

new(p, sys, namespace)
end
end
end

struct VanDerPol{IsNoisy} <: NeuralMassBlox
params
system
namespace

function VanDerPol(; name, namespace=nothing, θ=1.0, ϕ=0.1, include_noise=false)
if include_noise
VanDerPol{Noisy}(;name, namespace, θ, ϕ)
else
VanDerPol{NonNoisy}(;name, namespace, θ)
end
end
function VanDerPol{Noisy}(; name, namespace=nothing, θ=1.0, ϕ=0.1)
p = paramscoping(θ=θ, ϕ=ϕ)
θ, ϕ = p
sts = @variables x(t)=0.0 [output=true] y(t)=0.0 jcn(t) [input=true]
@brownian ξ

eqs = [D(x) ~ y,
D(y) ~ θ*(1-x^2)*y - x + ϕ*ξ + jcn]

sys = System(eqs, t, sts, p; name=name)
new{Noisy}(p, sys, namespace)
end
function VanDerPol{NonNoisy}(; name, namespace=nothing, θ=1.0)
p = paramscoping(θ=θ)
θ = p[1]
sts = @variables x(t)=0.0 [output=true] y(t)=0.0 jcn(t) [input=true]

eqs = [D(x) ~ y,
D(y) ~ θ*(1-x^2)*y - x + jcn]

sys = System(eqs, t, sts, p; name=name)
new{NonNoisy}(p, sys, namespace)
end
end
11 changes: 0 additions & 11 deletions src/blox/van_der_pol.jl

This file was deleted.

3 changes: 2 additions & 1 deletion test/GraphDynamicsTests/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ if GROUP == "All" || GROUP == "GraphDynamics1"
end

if GROUP == "All" || GROUP == "GraphDynamics2"
kuramato_test()
vdp_test()
kuramoto_test()
wta_tests()
dbs_circuit_components()
dbs_circuit()
Expand Down
Loading
Loading