Skip to content

Commit

Permalink
DBS source (#443)
Browse files Browse the repository at this point in the history
* initial working version of DBS source

* update square pulses functions

* include stimulus function in the `DBS` struct

* fix time issue in the smooth square pulse

* add functions for computing transition times and values for stimulus signals

* change DBS connection rule for Adam's model

* fix start_time in square pulse

* improve compute_transition_times for smooth pulses

* make transitions detections more robust and efficient

* add tests

* remove offset from u(t) in DBS source

* change `detect_transitions` to always return only the transitions indices

* separate DBS connection rules by blox type

* add DBS connection rules tests

---------

Co-authored-by: Mason Protter <mason.protter@icloud.com>
Co-authored-by: haris organtzidis <organtzh@gmail.com>
  • Loading branch information
3 people authored Oct 24, 2024
1 parent b7eaf5d commit 36709b9
Show file tree
Hide file tree
Showing 7 changed files with 254 additions and 4 deletions.
3 changes: 2 additions & 1 deletion src/Neuroblox.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ include("blox/DBS_Model_Blox_Adam_Brown.jl")
include("blox/van_der_pol.jl")
include("blox/ts_outputs.jl")
include("blox/sources.jl")
include("blox/DBS_sources.jl")
include("blox/rl_blox.jl")
include("blox/winnertakeall.jl")
include("blox/subcortical_blox.jl")
Expand Down Expand Up @@ -222,7 +223,7 @@ export Matrisome, Striosome, Striatum, GPi, GPe, Thalamus, STN, TAN, SNc
export HebbianPlasticity, HebbianModulationPlasticity
export Agent, ClassificationEnvironment, GreedyPolicy, reset!
export LearningBlox
export CosineSource, CosineBlox, NoisyCosineBlox, PhaseBlox, ImageStimulus, ExternalInput, PoissonSpikeTrain
export CosineSource, CosineBlox, NoisyCosineBlox, PhaseBlox, ImageStimulus, ExternalInput, PoissonSpikeTrain, DBS, detect_transitions, compute_transition_times, compute_transition_values
export BandPassFilterBlox
export OUBlox, OUCouplingBlox
export phase_inter, phase_sin_blox, phase_cos_blox
Expand Down
130 changes: 130 additions & 0 deletions src/blox/DBS_sources.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
struct DBS <: StimulusBlox
params::Vector{Num}
odesystem::ODESystem
namespace::Union{Symbol, Nothing}
stimulus::Function

function DBS(;
name,
namespace=nothing,
frequency=130.0,
amplitude=100.0,
pulse_width=0.15,
offset=0.0,
start_time=0.0,
smooth=1e-4
)

if smooth == 0
stimulus = t -> square(t, frequency, amplitude, offset, start_time, pulse_width)
else
stimulus = t -> square(t, frequency, amplitude, offset, start_time, pulse_width, smooth)
end

p = paramscoping(
frequency=frequency,
amplitude=amplitude,
pulse_width=pulse_width,
offset=offset,
start_time=start_time
)

sts = @variables u(t) [output = true]

eqs = [u ~ stimulus(t)]
sys = System(eqs, t, sts, p; name=name)

new(p, sys, namespace, stimulus)
end
end

function sawtooth(t, f, offset)
f * (t - offset) - floor(f * (t - offset))
end

# Smoothed square pulses
function square(t, f, amplitude, offset, start_time, pulse_width, δ)
invδ = 1 / δ
pulse_width_fraction = pulse_width * f
threshold = 1 - 2 * pulse_width_fraction
amp_half = 0.5 * amplitude
start_time = start_time + 0.5 * pulse_width

saw = sawtooth(t, f, start_time)
triangle_wave = 4 * abs(saw - 0.5) - 1
y = amp_half * (1 + tanh(invδ * (triangle_wave - threshold))) + offset

return y
end

# Non-smoothed square pulses
function square(t, f, amplitude, offset, start_time, pulse_width)

saw1 = sawtooth(t - start_time, f, pulse_width)
saw2 = sawtooth(t - start_time, f, 0)
saw3 = sawtooth(-start_time, f, pulse_width)
saw4 = sawtooth(-start_time, f, 0)

y = amplitude * (saw1 - saw2 - saw3 + saw4) + offset

return y
end

function detect_transitions(t, signal::Vector{T}; atol=0) where T <: AbstractFloat
low = minimum(signal)
high = maximum(signal)

# Get indexes when the signal is approximately equal to its low and high values
low_inds = isapprox.(signal, low; atol=atol)
high_inds = isapprox.(signal, high; atol=atol)

# Detect each type of transitions
trans_inds_1 = diff(low_inds) .== 1
trans_inds_2 = diff(low_inds) .== -1
trans_inds_3 = diff(high_inds) .== 1
trans_inds_4 = diff(high_inds) .== -1
circshift!(trans_inds_1, -1)
circshift!(trans_inds_3, -1)

# Combine all transition
transitions_inds = trans_inds_1 .| trans_inds_2 .| trans_inds_3 .| trans_inds_4
pushfirst!(transitions_inds, false)

return transitions_inds
end

function compute_transition_times(stimulus::Function, f , dt, tspan, start_time, pulse_width; atol=0)
period = 1 / f
n_periods = floor((tspan[end] - start_time) / period)

# Detect single pulse transition points
t = (start_time + 0.5 * period):dt:(start_time + 1.5 * period)
s = stimulus.(t)
transitions_inds = detect_transitions(t, s, atol=atol)
single_pulse = t[transitions_inds]

# Calculate pulse times across all periods
period_offsets = (-1:n_periods+1) * period
pulses = single_pulse .+ period_offsets'
transition_times = vec(pulses)

# Filter estimated times within the actual time range
inds = (transition_times .>= tspan[1]) .& (transition_times .<= tspan[end])

return transition_times[inds]
end

function compute_transition_values(transition_times, t, signal)

# Ensure transition_points are within the range of t, assuming both are ordered
@assert begin
t[1] <= transition_times[1]
transition_times[end] <= t[end]
end "Transition points must be within the range of t"

# Find the indices of the closest time points
indices = searchsortedfirst.(Ref(t), transition_times)
transition_values = signal[indices]

return transition_values
end
2 changes: 1 addition & 1 deletion src/blox/blox_utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ get_parts(blox::CompositeBlox) = blox.parts

get_components(blox::CompositeBlox) = mapreduce(x -> get_components(x), vcat, get_parts(blox))
get_components(blox::Vector{<:AbstractBlox}) = mapreduce(x -> get_components(x), vcat, blox)
get_components(blox) = [blox]
get_components(blox::Union{NeuralMassBlox, AbstractNeuronBlox}) = [blox]

get_neuron_color(n::AbstractExciNeuronBlox) = "blue"
get_neuron_color(n::AbstractInhNeuronBlox) = "red"
Expand Down
53 changes: 53 additions & 0 deletions src/blox/connections.jl
Original file line number Diff line number Diff line change
Expand Up @@ -990,3 +990,56 @@ function (bc::BloxConnector)(

accumulate_equation!(bc, eq)
end

function (bc::BloxConnector)(
bloxout::DBS,
bloxin::CompositeBlox;
kwargs...
)
components = get_components(bloxin)
for comp in components
bc(bloxout, comp; kwargs...)
end
end

function (bc::BloxConnector)(
bloxout::DBS,
bloxin::AbstractNeuronBlox;
kwargs...
)
sys_dbs = get_namespaced_sys(bloxout)
sys_in = get_namespaced_sys(bloxin)

w = generate_weight_param(bloxout, bloxin; kwargs...)
push!(bc.weights, w)

eq = sys_in.I_in ~ w * sys_dbs.u
accumulate_equation!(bc, eq)
end

function (bc::BloxConnector)(
bloxout::DBS,
bloxin::NeuralMassBlox;
kwargs...
)
sys_dbs = get_namespaced_sys(bloxout)
sys_in = get_namespaced_sys(bloxin)

w = generate_weight_param(bloxout, bloxin; kwargs...)
push!(bc.weights, w)

eq = sys_in.jcn ~ w * sys_dbs.u
accumulate_equation!(bc, eq)
end

function (bc::BloxConnector)(
bloxout::DBS,
bloxin::HHNeuronExci_STN_Adam_Blox;
kwargs...
)
sys_dbs = get_namespaced_sys(bloxout)
sys_in = get_namespaced_sys(bloxin)

eq = sys_in.DBS_in ~ - sys_in.V/sys_in.b + sys_dbs.u
accumulate_equation!(bc, eq)
end
6 changes: 4 additions & 2 deletions src/blox/neuron_models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ struct HHNeuronInhib_FSI_Adam_Blox <: AbstractInhNeuronBlox
hD_inf(v) = 1/(1+exp((v+70)/6))
τₕD(v) = 150
G_asymp(v,a,b) = a*(1+tanh(v/b))

eqs = [
D(V)~(1/Cₘ)*(-G_Na*m_inf(V)^3*h*(V-E_Na)-G_K*n^2*(V-E_K)-G_L*(V-E_L)-G_D*mD^3*hD*(V-E_K)+I_bg*(sin(t*freq*2*pi/1000)+1)+I_syn+I_gap+I_asc+I_in+σ*χ),
D(n)~(n_inf(V)-n)/τₙ(V),
Expand Down Expand Up @@ -375,6 +375,8 @@ struct HHNeuronExci_STN_Adam_Blox <: AbstractExciNeuronBlox
[input=true]
I_asc(t)
[input=true]
DBS_in(t)
[input=true]
G(t)=0.0
[output = true]
end
Expand Down Expand Up @@ -405,7 +407,7 @@ struct HHNeuronExci_STN_Adam_Blox <: AbstractExciNeuronBlox
αₕ(v) = 0.128*exp(-(v+50)/18)
βₕ(v) = 4/(1+exp(-(v+27)/5))

G_asymp(v,a,b) = a*(1+tanh(v/b))
G_asymp(v,a,b) = a*(1+tanh(v/b + DBS_in))

eqs = [
D(V)~(1/Cₘ)*(-G_Na*m^3*h*(V-E_Na)-G_K*n^4*(V-E_K)-G_L*(V-E_L)+I_bg*(sin(t*freq*2*pi/1000)+1)+I_syn+I_asc+I_in+σ*χ),
Expand Down
63 changes: 63 additions & 0 deletions test/dbs.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
using Neuroblox
using Test

@testset "Detection of stimulus transitions" begin
frequency = 0.130
amplitude = 10.0
pulse_width = 1
smooth = 1e-3
start_time = 5
offset = -2.0
dt = 1e-4
tspan = (0,30)
t = tspan[1]:dt:tspan[2]

@named dbs = DBS(namespace=:g, frequency=frequency, amplitude=amplitude, pulse_width=pulse_width, start_time=start_time, smooth=smooth, offset=offset)
stimulus = dbs.stimulus.(t)
transitions_inds = detect_transitions(t, stimulus; atol=0.05)
transition_times1 = t[transitions_inds]
transition_values1 = stimulus[transitions_inds]
transition_times2 = compute_transition_times(dbs.stimulus, frequency, dt, tspan, start_time, pulse_width; atol=0.05)
transition_values2 = compute_transition_values(transition_times2, t, stimulus)
@test all(isapprox.(transition_times1, transition_times2, rtol=1e-3))
@test all(isapprox.(transition_values1, transition_values2, rtol=1e-2))

smooth = 1e-10
@named dbs = DBS(namespace=:g, frequency=frequency, amplitude=amplitude, pulse_width=pulse_width, start_time=start_time, smooth=smooth, offset=offset)
transition_times_smoothed = compute_transition_times(dbs.stimulus, frequency, dt, tspan, start_time, pulse_width; atol=0.05)
smooth = 0
@named dbs = DBS(namespace=:g, frequency=frequency, amplitude=amplitude, pulse_width=pulse_width, start_time=start_time, smooth=smooth, offset=offset)
transition_times_non_smooth = compute_transition_times(dbs.stimulus, frequency, dt, tspan, start_time, pulse_width; atol=0.05)
@test all(isapprox.(transition_times_smoothed, transition_times_non_smooth))
end

@testset "DBS connections" begin
# Test DBS -> single AbstractNeuronBlox
@named dbs = DBS()
@named n1 = HHNeuronExciBlox()
g = MetaDiGraph()
add_edge!(g, dbs => n1, weight = 1.0)
sys = system_from_graph(g; name=:test)
@test sys isa ODESystem

# Test DBS -> Adam's STN
@named stn = HHNeuronExci_STN_Adam_Blox()
g = MetaDiGraph()
add_edge!(g, dbs => stn, weight = 1.0)
sys = system_from_graph(g; name=:test)
@test sys isa SDESystem

# Test DBS -> NeuralMassBlox
@named mass = JansenRit()
g = MetaDiGraph()
add_edge!(g, dbs => mass, weight = 1.0)
sys = system_from_graph(g; name=:test)
@test sys isa ODESystem

# Test DBS -> CompositeBlox
@named cb = CorticalBlox(namespace=:g, N_wta=2, N_exci=2, density=0.1, weight=1.0)
g = MetaDiGraph()
add_edge!(g, dbs => cb, weight = 1.0)
sys = system_from_graph(g; name=:test)
@test sys isa ODESystem
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ if GROUP == "All" || GROUP == "Advanced"
@time @safetestset "Source Tests" begin include("source_components.jl") end
@time @safetestset "Reinforcement Learning Tests" begin include("reinforcement_learning.jl") end
@time @safetestset "Cort-Cort plasticity Tests" begin include("plasticity.jl") end
@time @safetestset "DBS" begin include("dbs.jl") end
end


Expand Down

0 comments on commit 36709b9

Please sign in to comment.