From 36709b94488740bc5066a7da5da3726e9d6a89eb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Germ=C3=A1n=20Abrevaya?= Date: Thu, 24 Oct 2024 10:21:26 -0300 Subject: [PATCH] DBS source (#443) * initial working version of DBS source * update square pulses functions * include stimulus function in the `DBS` struct * fix time issue in the smooth square pulse * add functions for computing transition times and values for stimulus signals * change DBS connection rule for Adam's model * fix start_time in square pulse * improve compute_transition_times for smooth pulses * make transitions detections more robust and efficient * add tests * remove offset from u(t) in DBS source * change `detect_transitions` to always return only the transitions indices * separate DBS connection rules by blox type * add DBS connection rules tests --------- Co-authored-by: Mason Protter Co-authored-by: haris organtzidis --- src/Neuroblox.jl | 3 +- src/blox/DBS_sources.jl | 130 +++++++++++++++++++++++++++++++++++++ src/blox/blox_utilities.jl | 2 +- src/blox/connections.jl | 53 +++++++++++++++ src/blox/neuron_models.jl | 6 +- test/dbs.jl | 63 ++++++++++++++++++ test/runtests.jl | 1 + 7 files changed, 254 insertions(+), 4 deletions(-) create mode 100644 src/blox/DBS_sources.jl create mode 100644 test/dbs.jl diff --git a/src/Neuroblox.jl b/src/Neuroblox.jl index f98c5cdb..16de702b 100644 --- a/src/Neuroblox.jl +++ b/src/Neuroblox.jl @@ -104,6 +104,7 @@ include("blox/DBS_Model_Blox_Adam_Brown.jl") include("blox/van_der_pol.jl") include("blox/ts_outputs.jl") include("blox/sources.jl") +include("blox/DBS_sources.jl") include("blox/rl_blox.jl") include("blox/winnertakeall.jl") include("blox/subcortical_blox.jl") @@ -222,7 +223,7 @@ export Matrisome, Striosome, Striatum, GPi, GPe, Thalamus, STN, TAN, SNc export HebbianPlasticity, HebbianModulationPlasticity export Agent, ClassificationEnvironment, GreedyPolicy, reset! export LearningBlox -export CosineSource, CosineBlox, NoisyCosineBlox, PhaseBlox, ImageStimulus, ExternalInput, PoissonSpikeTrain +export CosineSource, CosineBlox, NoisyCosineBlox, PhaseBlox, ImageStimulus, ExternalInput, PoissonSpikeTrain, DBS, detect_transitions, compute_transition_times, compute_transition_values export BandPassFilterBlox export OUBlox, OUCouplingBlox export phase_inter, phase_sin_blox, phase_cos_blox diff --git a/src/blox/DBS_sources.jl b/src/blox/DBS_sources.jl new file mode 100644 index 00000000..014d89b9 --- /dev/null +++ b/src/blox/DBS_sources.jl @@ -0,0 +1,130 @@ +struct DBS <: StimulusBlox + params::Vector{Num} + odesystem::ODESystem + namespace::Union{Symbol, Nothing} + stimulus::Function + + function DBS(; + name, + namespace=nothing, + frequency=130.0, + amplitude=100.0, + pulse_width=0.15, + offset=0.0, + start_time=0.0, + smooth=1e-4 + ) + + if smooth == 0 + stimulus = t -> square(t, frequency, amplitude, offset, start_time, pulse_width) + else + stimulus = t -> square(t, frequency, amplitude, offset, start_time, pulse_width, smooth) + end + + p = paramscoping( + frequency=frequency, + amplitude=amplitude, + pulse_width=pulse_width, + offset=offset, + start_time=start_time + ) + + sts = @variables u(t) [output = true] + + eqs = [u ~ stimulus(t)] + sys = System(eqs, t, sts, p; name=name) + + new(p, sys, namespace, stimulus) + end +end + +function sawtooth(t, f, offset) + f * (t - offset) - floor(f * (t - offset)) +end + +# Smoothed square pulses +function square(t, f, amplitude, offset, start_time, pulse_width, δ) + invδ = 1 / δ + pulse_width_fraction = pulse_width * f + threshold = 1 - 2 * pulse_width_fraction + amp_half = 0.5 * amplitude + start_time = start_time + 0.5 * pulse_width + + saw = sawtooth(t, f, start_time) + triangle_wave = 4 * abs(saw - 0.5) - 1 + y = amp_half * (1 + tanh(invδ * (triangle_wave - threshold))) + offset + + return y +end + +# Non-smoothed square pulses +function square(t, f, amplitude, offset, start_time, pulse_width) + + saw1 = sawtooth(t - start_time, f, pulse_width) + saw2 = sawtooth(t - start_time, f, 0) + saw3 = sawtooth(-start_time, f, pulse_width) + saw4 = sawtooth(-start_time, f, 0) + + y = amplitude * (saw1 - saw2 - saw3 + saw4) + offset + + return y +end + +function detect_transitions(t, signal::Vector{T}; atol=0) where T <: AbstractFloat + low = minimum(signal) + high = maximum(signal) + + # Get indexes when the signal is approximately equal to its low and high values + low_inds = isapprox.(signal, low; atol=atol) + high_inds = isapprox.(signal, high; atol=atol) + + # Detect each type of transitions + trans_inds_1 = diff(low_inds) .== 1 + trans_inds_2 = diff(low_inds) .== -1 + trans_inds_3 = diff(high_inds) .== 1 + trans_inds_4 = diff(high_inds) .== -1 + circshift!(trans_inds_1, -1) + circshift!(trans_inds_3, -1) + + # Combine all transition + transitions_inds = trans_inds_1 .| trans_inds_2 .| trans_inds_3 .| trans_inds_4 + pushfirst!(transitions_inds, false) + + return transitions_inds +end + +function compute_transition_times(stimulus::Function, f , dt, tspan, start_time, pulse_width; atol=0) + period = 1 / f + n_periods = floor((tspan[end] - start_time) / period) + + # Detect single pulse transition points + t = (start_time + 0.5 * period):dt:(start_time + 1.5 * period) + s = stimulus.(t) + transitions_inds = detect_transitions(t, s, atol=atol) + single_pulse = t[transitions_inds] + + # Calculate pulse times across all periods + period_offsets = (-1:n_periods+1) * period + pulses = single_pulse .+ period_offsets' + transition_times = vec(pulses) + + # Filter estimated times within the actual time range + inds = (transition_times .>= tspan[1]) .& (transition_times .<= tspan[end]) + + return transition_times[inds] +end + +function compute_transition_values(transition_times, t, signal) + + # Ensure transition_points are within the range of t, assuming both are ordered + @assert begin + t[1] <= transition_times[1] + transition_times[end] <= t[end] + end "Transition points must be within the range of t" + + # Find the indices of the closest time points + indices = searchsortedfirst.(Ref(t), transition_times) + transition_values = signal[indices] + + return transition_values +end \ No newline at end of file diff --git a/src/blox/blox_utilities.jl b/src/blox/blox_utilities.jl index a45e59d5..c385d347 100644 --- a/src/blox/blox_utilities.jl +++ b/src/blox/blox_utilities.jl @@ -67,7 +67,7 @@ get_parts(blox::CompositeBlox) = blox.parts get_components(blox::CompositeBlox) = mapreduce(x -> get_components(x), vcat, get_parts(blox)) get_components(blox::Vector{<:AbstractBlox}) = mapreduce(x -> get_components(x), vcat, blox) -get_components(blox) = [blox] +get_components(blox::Union{NeuralMassBlox, AbstractNeuronBlox}) = [blox] get_neuron_color(n::AbstractExciNeuronBlox) = "blue" get_neuron_color(n::AbstractInhNeuronBlox) = "red" diff --git a/src/blox/connections.jl b/src/blox/connections.jl index 4c2daadc..34358795 100644 --- a/src/blox/connections.jl +++ b/src/blox/connections.jl @@ -990,3 +990,56 @@ function (bc::BloxConnector)( accumulate_equation!(bc, eq) end + +function (bc::BloxConnector)( + bloxout::DBS, + bloxin::CompositeBlox; + kwargs... +) + components = get_components(bloxin) + for comp in components + bc(bloxout, comp; kwargs...) + end +end + +function (bc::BloxConnector)( + bloxout::DBS, + bloxin::AbstractNeuronBlox; + kwargs... +) + sys_dbs = get_namespaced_sys(bloxout) + sys_in = get_namespaced_sys(bloxin) + + w = generate_weight_param(bloxout, bloxin; kwargs...) + push!(bc.weights, w) + + eq = sys_in.I_in ~ w * sys_dbs.u + accumulate_equation!(bc, eq) +end + +function (bc::BloxConnector)( + bloxout::DBS, + bloxin::NeuralMassBlox; + kwargs... +) + sys_dbs = get_namespaced_sys(bloxout) + sys_in = get_namespaced_sys(bloxin) + + w = generate_weight_param(bloxout, bloxin; kwargs...) + push!(bc.weights, w) + + eq = sys_in.jcn ~ w * sys_dbs.u + accumulate_equation!(bc, eq) +end + +function (bc::BloxConnector)( + bloxout::DBS, + bloxin::HHNeuronExci_STN_Adam_Blox; + kwargs... +) + sys_dbs = get_namespaced_sys(bloxout) + sys_in = get_namespaced_sys(bloxin) + + eq = sys_in.DBS_in ~ - sys_in.V/sys_in.b + sys_dbs.u + accumulate_equation!(bc, eq) +end \ No newline at end of file diff --git a/src/blox/neuron_models.jl b/src/blox/neuron_models.jl index dae88d62..b4d67ef3 100644 --- a/src/blox/neuron_models.jl +++ b/src/blox/neuron_models.jl @@ -327,7 +327,7 @@ struct HHNeuronInhib_FSI_Adam_Blox <: AbstractInhNeuronBlox hD_inf(v) = 1/(1+exp((v+70)/6)) τₕD(v) = 150 G_asymp(v,a,b) = a*(1+tanh(v/b)) - + eqs = [ D(V)~(1/Cₘ)*(-G_Na*m_inf(V)^3*h*(V-E_Na)-G_K*n^2*(V-E_K)-G_L*(V-E_L)-G_D*mD^3*hD*(V-E_K)+I_bg*(sin(t*freq*2*pi/1000)+1)+I_syn+I_gap+I_asc+I_in+σ*χ), D(n)~(n_inf(V)-n)/τₙ(V), @@ -375,6 +375,8 @@ struct HHNeuronExci_STN_Adam_Blox <: AbstractExciNeuronBlox [input=true] I_asc(t) [input=true] + DBS_in(t) + [input=true] G(t)=0.0 [output = true] end @@ -405,7 +407,7 @@ struct HHNeuronExci_STN_Adam_Blox <: AbstractExciNeuronBlox αₕ(v) = 0.128*exp(-(v+50)/18) βₕ(v) = 4/(1+exp(-(v+27)/5)) - G_asymp(v,a,b) = a*(1+tanh(v/b)) + G_asymp(v,a,b) = a*(1+tanh(v/b + DBS_in)) eqs = [ D(V)~(1/Cₘ)*(-G_Na*m^3*h*(V-E_Na)-G_K*n^4*(V-E_K)-G_L*(V-E_L)+I_bg*(sin(t*freq*2*pi/1000)+1)+I_syn+I_asc+I_in+σ*χ), diff --git a/test/dbs.jl b/test/dbs.jl new file mode 100644 index 00000000..ae9150b3 --- /dev/null +++ b/test/dbs.jl @@ -0,0 +1,63 @@ +using Neuroblox +using Test + +@testset "Detection of stimulus transitions" begin + frequency = 0.130 + amplitude = 10.0 + pulse_width = 1 + smooth = 1e-3 + start_time = 5 + offset = -2.0 + dt = 1e-4 + tspan = (0,30) + t = tspan[1]:dt:tspan[2] + + @named dbs = DBS(namespace=:g, frequency=frequency, amplitude=amplitude, pulse_width=pulse_width, start_time=start_time, smooth=smooth, offset=offset) + stimulus = dbs.stimulus.(t) + transitions_inds = detect_transitions(t, stimulus; atol=0.05) + transition_times1 = t[transitions_inds] + transition_values1 = stimulus[transitions_inds] + transition_times2 = compute_transition_times(dbs.stimulus, frequency, dt, tspan, start_time, pulse_width; atol=0.05) + transition_values2 = compute_transition_values(transition_times2, t, stimulus) + @test all(isapprox.(transition_times1, transition_times2, rtol=1e-3)) + @test all(isapprox.(transition_values1, transition_values2, rtol=1e-2)) + + smooth = 1e-10 + @named dbs = DBS(namespace=:g, frequency=frequency, amplitude=amplitude, pulse_width=pulse_width, start_time=start_time, smooth=smooth, offset=offset) + transition_times_smoothed = compute_transition_times(dbs.stimulus, frequency, dt, tspan, start_time, pulse_width; atol=0.05) + smooth = 0 + @named dbs = DBS(namespace=:g, frequency=frequency, amplitude=amplitude, pulse_width=pulse_width, start_time=start_time, smooth=smooth, offset=offset) + transition_times_non_smooth = compute_transition_times(dbs.stimulus, frequency, dt, tspan, start_time, pulse_width; atol=0.05) + @test all(isapprox.(transition_times_smoothed, transition_times_non_smooth)) +end + +@testset "DBS connections" begin + # Test DBS -> single AbstractNeuronBlox + @named dbs = DBS() + @named n1 = HHNeuronExciBlox() + g = MetaDiGraph() + add_edge!(g, dbs => n1, weight = 1.0) + sys = system_from_graph(g; name=:test) + @test sys isa ODESystem + + # Test DBS -> Adam's STN + @named stn = HHNeuronExci_STN_Adam_Blox() + g = MetaDiGraph() + add_edge!(g, dbs => stn, weight = 1.0) + sys = system_from_graph(g; name=:test) + @test sys isa SDESystem + + # Test DBS -> NeuralMassBlox + @named mass = JansenRit() + g = MetaDiGraph() + add_edge!(g, dbs => mass, weight = 1.0) + sys = system_from_graph(g; name=:test) + @test sys isa ODESystem + + # Test DBS -> CompositeBlox + @named cb = CorticalBlox(namespace=:g, N_wta=2, N_exci=2, density=0.1, weight=1.0) + g = MetaDiGraph() + add_edge!(g, dbs => cb, weight = 1.0) + sys = system_from_graph(g; name=:test) + @test sys isa ODESystem +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 59c1ec1f..c205323b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -15,6 +15,7 @@ if GROUP == "All" || GROUP == "Advanced" @time @safetestset "Source Tests" begin include("source_components.jl") end @time @safetestset "Reinforcement Learning Tests" begin include("reinforcement_learning.jl") end @time @safetestset "Cort-Cort plasticity Tests" begin include("plasticity.jl") end + @time @safetestset "DBS" begin include("dbs.jl") end end