Skip to content

Commit

Permalink
More LIF fixes for the decision making tutorial (#472)
Browse files Browse the repository at this point in the history
* more terms from connections to neurons

* accumulate both states and parameters (values) for spike affects

* match state with the correct parameter value in spike affect

* allow for duplicate parameters to be passed in functional affect using Pairs

* update comment & fix typo

* rename variable for clarity

* synchronize GraphDynamicsInterop with changes to the LIFExci / LIFInh neurons

---------

Co-authored-by: Mason Protter <mason.protter@icloud.com>
  • Loading branch information
harisorgn and MasonProtter authored Oct 30, 2024
1 parent 1c4d995 commit 1eba98e
Show file tree
Hide file tree
Showing 7 changed files with 91 additions and 55 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ DataFrames = "1.3"
Distributions = "0.25.102"
ExponentialUtilities = "1"
ForwardDiff = "0.10"
GraphDynamics = "0.1.4"
GraphDynamics = "0.1.5"
Graphs = "1"
Interpolations = "0.14, 0.15"
MetaGraphs = "0.7"
Expand Down
3 changes: 2 additions & 1 deletion src/GraphDynamicsInterop/GraphDynamicsInterop.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@ using GraphDynamics:
StateIndex,
ParamIndex,
event_times,
calculate_inputs
calculate_inputs,
connection_index

using Random:
Random,
Expand Down
55 changes: 36 additions & 19 deletions src/GraphDynamicsInterop/connection_interop.jl
Original file line number Diff line number Diff line change
Expand Up @@ -260,17 +260,12 @@ end

function (c::BasicConnection)(sys_src::Subsystem{LIFExciNeuron},
sys_dst::Union{Subsystem{LIFExciNeuron}, Subsystem{LIFInhNeuron}})
w = c.weight
(; S_AMPA, g_AMPA, V, V_E, g_NMDA, Mg) = sys_dst
(; S_NMDA) = sys_src
(; jcn = w * (S_AMPA * g_AMPA * (V - V_E) + S_NMDA * g_NMDA * (V - V_E) / (1 + Mg * exp(-0.062 * V) / 3.57)))
(; jcn = 0.0)
end

function (c::BasicConnection)(sys_src::Subsystem{LIFInhNeuron},
sys_dst::Union{Subsystem{LIFExciNeuron}, Subsystem{LIFInhNeuron}})
w = c.weight
(; S_GABA, g_GABA, V, V_I) = sys_dst
(;jcn = w * S_GABA * g_GABA * (V - V_I))
function (c::BasicConnection)(::Subsystem{LIFInhNeuron},
::Union{Subsystem{LIFExciNeuron}, Subsystem{LIFInhNeuron}})
(; jcn = 0.0)
end

struct SpikeAffectEventBuilder
Expand All @@ -284,6 +279,7 @@ struct SpikeAffectEvent{i_src, i_LIFInh, i_LIFExci}
j_dsts_inh::Vector{Int}
j_dsts_exci::Vector{Int}
end

function (ev::SpikeAffectEventBuilder)(index_map)
(i_src, j_src) = index_map[ev.idx_src]
i_inh, j_dsts_inh = let v = ev.idx_dsts_inh
Expand Down Expand Up @@ -315,14 +311,18 @@ end




function GraphDynamics.apply_discrete_event!(integrator,
states::NTuple{Len, Any},
params::NTuple{Len, Any},
_,
connection_matrices,
t,
ev::SpikeAffectEvent{i_src, i_dst_inh, i_dst_exci}
) where {i_src, i_dst_inh, i_dst_exci, Len}
(; j_src, j_dsts_inh, j_dsts_exci) = ev

nc = connection_index(BasicConnection, connection_matrices)

params_src = params[i_src][j_src]
@reset params_src.t_refract_end = t + params_src.t_refract_duration
@reset params_src.is_refractory = 1
Expand All @@ -334,22 +334,39 @@ function GraphDynamics.apply_discrete_event!(integrator,
states[i_src][:V, j_src] = params_src.V_reset
if (states_src isa SubsystemStates{LIFExciNeuron}) && (j_src j_dsts_exci)
# x is the rise variable for NMDA synapses and it only applies to self-recurrent connections
states[i_src][:x, j_src] += 1
w = connection_matrices[nc][i_src, i_src][j_src, j_src].weight
states[i_src][:x, j_src] += w
end

if states_src isa SubsystemStates{LIFExciNeuron}
!isnothing(i_dst_inh) && for j_dst j_dsts_inh
states[i_dst_inh][:S_AMPA, j_dst] += 1
if !isnothing(i_dst_inh)
M = connection_matrices[nc][i_src, i_dst_inh]
for j_dst j_dsts_inh
w = M[j_src, j_dst].weight
states[i_dst_inh][:S_AMPA, j_dst] += w
end
end
!isnothing(i_dst_exci) && for j_dst j_dsts_exci
states[i_dst_exci][:S_AMPA, j_dst] += 1
if !isnothing(i_dst_exci)
M = connection_matrices[nc][i_src, i_dst_exci]
for j_dst j_dsts_exci
w = M[j_src, j_dst].weight
states[i_dst_exci][:S_AMPA, j_dst] += w
end
end
elseif states_src isa SubsystemStates{LIFInhNeuron}
!isnothing(i_dst_inh) && for j_dst j_dsts_inh
states[i_dst_inh][:S_GABA, j_dst] += 1
if !isnothing(i_dst_inh)
M = connection_matrices[nc][i_src, i_dst_inh]
for j_dst j_dsts_inh
w = M[j_src, j_dst].weight
states[i_dst_inh][:S_GABA, j_dst] += w
end
end
!isnothing(i_dst_exci) && for j_dst j_dsts_exci
states[i_dst_exci][:S_GABA, j_dst] += 1
if !isnothing(i_dst_exci)
M = connection_matrices[nc][i_src, i_dst_exci]
for j_dst j_dsts_exci
w = M[j_src, j_dst].weight
states[i_dst_exci][:S_GABA, j_dst] += w
end
end
else
error("this should be unreachable")
Expand Down
32 changes: 27 additions & 5 deletions src/Neurographs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,17 +76,39 @@ end
generate_discrete_callbacks(blox, ::BloxConnector; t_block = missing) = []

function generate_discrete_callbacks(blox::Union{LIFExciNeuron, LIFInhNeuron}, bc::BloxConnector; t_block = missing)
spike_affect_states = get_spike_affect_states(bc)
spike_affects = get_spike_affects(bc)
name_blox = namespaced_nameof(blox)
sys = get_namespaced_sys(blox)

states_dest = get(spike_affect_states, name_blox, Num[])
states_affect, params_affect = get(spike_affects, name_blox, (Num[], Num[]))

sys = get_namespaced_sys(blox)
# HACK : MTK will complain if the parameter vector passed to a functional affect
# contains non-unique parameters. Here we sometimes need to pass duplicate parameters that
# affect states in the loop in LIF_spike_affect! .
# Passing parameters with Symbol aliases bypasses this issue and allows for duplicates.
affect_pairs = if unique(params_affect) == length(params_affect)
[p => Symbol(p) for p in params_affect]
else
map(params_affect) do p
if count(pi -> Symbol(pi) == Symbol(p), params_affect) > 1
p => Symbol(p, "_$(rand(1:1000))")
else
p => Symbol(p)
end
end
end

ps = vcat([
sys.V_reset => Symbol(sys.V_reset),
sys.t_refract_duration => Symbol(sys.t_refract_duration),
sys.t_refract_end => Symbol(sys.t_refract_end),
sys.is_refractory => Symbol(sys.is_refractory)
], affect_pairs)

cb = (sys.V > sys.θ) => (
LIF_spike_affect!,
vcat(sys.V, states_dest),
[sys.V_reset, sys.t_refract_duration, sys.t_refract_end, sys.is_refractory],
vcat(sys.V, states_affect),
ps,
[],
nothing
)
Expand Down
6 changes: 3 additions & 3 deletions src/blox/blox_utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -179,9 +179,9 @@ get_discrete_callbacks(bc::BloxConnector) = bc.discrete_callbacks
get_discrete_callbacks(blox::Union{CompositeBlox, AbstractComponent}) = (get_discrete_callbacks get_connector)(blox)
get_discrete_callbacks(blox) = []

get_spike_affect_states(bc::BloxConnector) = bc.spike_affect_states
get_spike_affect_states(blox::Union{CompositeBlox, AbstractComponent}) = (get_spike_affect_states get_connector)(blox)
get_spike_affect_states(blox) = Dict{Symbol, Vector{Num}}()
get_spike_affects(bc::BloxConnector) = bc.spike_affects
get_spike_affects(blox::Union{CompositeBlox, AbstractComponent}) = (get_spike_affects get_connector)(blox)
get_spike_affects(blox) = Dict{Symbol, Tuple{Vector{Num}, Vector{Num}}}()

get_weight_learning_rules(bc::BloxConnector) = bc.learning_rules
get_weight_learning_rules(blox::Union{CompositeBlox, AbstractComponent}) = (get_weight_learning_rules get_connector)(blox)
Expand Down
40 changes: 17 additions & 23 deletions src/blox/connections.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ mutable struct BloxConnector
weights::Vector{Num}
delays::Vector{Num}
discrete_callbacks
spike_affect_states::Dict{Symbol, Vector{Num}}
spike_affects::Dict{Symbol, Tuple{Vector{Num}, Vector{Num}}}
learning_rules
adjacency

Expand All @@ -14,16 +14,17 @@ mutable struct BloxConnector
weights = mapreduce(get_weight_parameters, vcat, bloxs)
delays = mapreduce(get_delay_parameters, vcat, bloxs)
discrete_callbacks = mapreduce(get_discrete_callbacks, vcat, bloxs)
# spike_affect_states holds a Dictionary that maps
# the name of a source Blox to the states of a destination Blox
# that are affected by a continuous callback of the source Blox.
# spike_affects holds a Dictionary that maps
# the name of a source Blox to a Tuple of (states, parameters) of a destination Blox.
# The states are affected by a discrete callback of the source Blox
# and the parameters determine the amount of this affect like `states .+= parameters`.
# Typically this is used when a source Blox spikes, so its Voltage state crosses a threshold,
# and this spike affects synaptic parameters of every destination Blox that it connects to.
spike_affect_states = mapreduce(get_spike_affect_states, merge, bloxs)
spike_affects = mapreduce(get_spike_affects, merge, bloxs)
learning_rules = mapreduce(get_weight_learning_rules, merge, bloxs)
adjacency = mapreduce(get_adjacency, merge, bloxs)

new(eqs, weights, delays, discrete_callbacks, spike_affect_states, learning_rules, adjacency)
new(eqs, weights, delays, discrete_callbacks, spike_affects, learning_rules, adjacency)
end
end

Expand All @@ -33,11 +34,13 @@ function accumulate_equation!(bc::BloxConnector, eq)
bc.eqs[idx] = bc.eqs[idx].lhs ~ bc.eqs[idx].rhs + eq.rhs
end

function accumulate_spike_affect_states!(bc::BloxConnector, name_blox_src, states_dst)
if haskey(bc.spike_affect_states, name_blox_src)
append!(bc.spike_affect_states[name_blox_src], states_dst)
function accumulate_spike_affects!(bc::BloxConnector, name_blox_src, states_affect, params_affect)
if haskey(bc.spike_affects, name_blox_src)
spike_affects = bc.spike_affects[name_blox_src]
append!(spike_affects[1], states_affect)
append!(spike_affects[2], params_affect)
else
bc.spike_affect_states[name_blox_src] = states_dst
bc.spike_affects[name_blox_src] = (states_affect, params_affect)
end
end

Expand Down Expand Up @@ -891,18 +894,16 @@ function (bc::BloxConnector)(
w = generate_weight_param(bloxout, bloxin; kwargs...)
push!(bc.weights, w)

eq = sys_in.jcn ~ w * sys_in.S_AMPA * sys_in.g_AMPA * (sys_in.V - sys_in.V_E) +
w * sys_out.S_NMDA * sys_in.g_NMDA * (sys_in.V - sys_in.V_E) /
eq = sys_in.jcn ~ w * sys_out.S_NMDA * sys_in.g_NMDA * (sys_in.V - sys_in.V_E) /
(1 + sys_in.Mg * exp(-0.062 * sys_in.V) / 3.57)

accumulate_equation!(bc, eq)

# Compare the unique namespaced names of both systems
if nameof(sys_out) == nameof(sys_in)
# x is the rise variable for NMDA synapses and it only applies to self-recurrent connections
accumulate_spike_affect_states!(bc, nameof(sys_out), [sys_in.S_AMPA, sys_in.x])
accumulate_spike_affects!(bc, nameof(sys_out), [sys_in.S_AMPA, sys_in.x], [w, w])
else
accumulate_spike_affect_states!(bc, nameof(sys_out), [sys_in.S_AMPA])
accumulate_spike_affects!(bc, nameof(sys_out), [sys_in.S_AMPA], [w])
end
end

Expand All @@ -917,11 +918,7 @@ function (bc::BloxConnector)(
w = generate_weight_param(bloxout, bloxin; kwargs...)
push!(bc.weights, w)

eq = sys_in.jcn ~ w * sys_in.S_GABA * sys_in.g_GABA * (sys_in.V - sys_in.V_I)

accumulate_equation!(bc, eq)

accumulate_spike_affect_states!(bc, nameof(sys_out), [sys_in.S_GABA])
accumulate_spike_affects!(bc, nameof(sys_out), [sys_in.S_GABA], [w])
end

function (bc::BloxConnector)(
Expand All @@ -931,9 +928,6 @@ function (bc::BloxConnector)(
)
sys_in = get_namespaced_sys(neuron)

w = generate_weight_param(stim, neuron; kwargs...)
push!(bc.weights, w)

t_spikes = generate_spike_times(stim)

cb = t_spikes => [sys_in.S_AMPA_ext ~ sys_in.S_AMPA_ext + 1]
Expand Down
8 changes: 5 additions & 3 deletions src/blox/neuron_models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -638,8 +638,10 @@ function LIF_spike_affect!(integ, u, p, ctx)

SciMLBase.add_tstop!(integ, t_refract_end)

c = 1
for i in eachindex(u)[2:end]
integ.u[u[i]] += 1
integ.u[u[i]] += integ.p[p[c + 4]]
c += 1
end
end

Expand Down Expand Up @@ -693,7 +695,7 @@ struct LIFInhNeuron <: AbstractInhNeuronBlox

sts = @variables V(t)=-52 S_AMPA(t)=0 S_GABA(t)=0 S_AMPA_ext(t)=0 jcn(t) [input=true] jcn_external(t) [input=true]
eqs = [
D(V) ~ (1 - is_refractory) * (- g_L * (V - V_L) - S_AMPA_ext * g_AMPA_ext * (V - V_E) - jcn) / C,
D(V) ~ (1 - is_refractory) * (- g_L * (V - V_L) - S_AMPA_ext * g_AMPA_ext * (V - V_E) - S_GABA * g_GABA * (V - V_I) - S_AMPA * g_AMPA * (V - V_E) - jcn) / C,
D(S_AMPA) ~ - S_AMPA / τ_AMPA,
D(S_GABA) ~ - S_GABA / τ_GABA,
D(S_AMPA_ext) ~ - S_AMPA_ext / τ_AMPA
Expand Down Expand Up @@ -761,7 +763,7 @@ struct LIFExciNeuron <: AbstractExciNeuronBlox

sts = @variables V(t)=-52 S_AMPA(t)=0 S_GABA(t)=0 S_NMDA(t)=0 x(t)=0 S_AMPA_ext(t)=0 jcn(t) [input=true]
eqs = [
D(V) ~ (1 - is_refractory) * (- g_L * (V - V_L) - S_AMPA_ext * g_AMPA_ext * (V - V_E) - jcn) / C,
D(V) ~ (1 - is_refractory) * (- g_L * (V - V_L) - S_AMPA_ext * g_AMPA_ext * (V - V_E) - S_GABA * g_GABA * (V - V_I) - S_AMPA * g_AMPA * (V - V_E) - jcn) / C,
D(S_AMPA) ~ - S_AMPA / τ_AMPA,
D(S_GABA) ~ - S_GABA / τ_GABA,
D(S_NMDA) ~ - S_NMDA / τ_NMDA_decay + α * x * (1 - S_NMDA),
Expand Down

0 comments on commit 1eba98e

Please sign in to comment.