-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* initial working version of DBS source * update square pulses functions * include stimulus function in the `DBS` struct * fix time issue in the smooth square pulse * add functions for computing transition times and values for stimulus signals * change DBS connection rule for Adam's model * fix start_time in square pulse * improve compute_transition_times for smooth pulses * make transitions detections more robust and efficient * add tests * remove offset from u(t) in DBS source * change `detect_transitions` to always return only the transitions indices * separate DBS connection rules by blox type * add DBS connection rules tests --------- Co-authored-by: Mason Protter <mason.protter@icloud.com> Co-authored-by: haris organtzidis <organtzh@gmail.com>
- Loading branch information
1 parent
b7eaf5d
commit 36709b9
Showing
7 changed files
with
254 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
struct DBS <: StimulusBlox | ||
params::Vector{Num} | ||
odesystem::ODESystem | ||
namespace::Union{Symbol, Nothing} | ||
stimulus::Function | ||
|
||
function DBS(; | ||
name, | ||
namespace=nothing, | ||
frequency=130.0, | ||
amplitude=100.0, | ||
pulse_width=0.15, | ||
offset=0.0, | ||
start_time=0.0, | ||
smooth=1e-4 | ||
) | ||
|
||
if smooth == 0 | ||
stimulus = t -> square(t, frequency, amplitude, offset, start_time, pulse_width) | ||
else | ||
stimulus = t -> square(t, frequency, amplitude, offset, start_time, pulse_width, smooth) | ||
end | ||
|
||
p = paramscoping( | ||
frequency=frequency, | ||
amplitude=amplitude, | ||
pulse_width=pulse_width, | ||
offset=offset, | ||
start_time=start_time | ||
) | ||
|
||
sts = @variables u(t) [output = true] | ||
|
||
eqs = [u ~ stimulus(t)] | ||
sys = System(eqs, t, sts, p; name=name) | ||
|
||
new(p, sys, namespace, stimulus) | ||
end | ||
end | ||
|
||
function sawtooth(t, f, offset) | ||
f * (t - offset) - floor(f * (t - offset)) | ||
end | ||
|
||
# Smoothed square pulses | ||
function square(t, f, amplitude, offset, start_time, pulse_width, δ) | ||
invδ = 1 / δ | ||
pulse_width_fraction = pulse_width * f | ||
threshold = 1 - 2 * pulse_width_fraction | ||
amp_half = 0.5 * amplitude | ||
start_time = start_time + 0.5 * pulse_width | ||
|
||
saw = sawtooth(t, f, start_time) | ||
triangle_wave = 4 * abs(saw - 0.5) - 1 | ||
y = amp_half * (1 + tanh(invδ * (triangle_wave - threshold))) + offset | ||
|
||
return y | ||
end | ||
|
||
# Non-smoothed square pulses | ||
function square(t, f, amplitude, offset, start_time, pulse_width) | ||
|
||
saw1 = sawtooth(t - start_time, f, pulse_width) | ||
saw2 = sawtooth(t - start_time, f, 0) | ||
saw3 = sawtooth(-start_time, f, pulse_width) | ||
saw4 = sawtooth(-start_time, f, 0) | ||
|
||
y = amplitude * (saw1 - saw2 - saw3 + saw4) + offset | ||
|
||
return y | ||
end | ||
|
||
function detect_transitions(t, signal::Vector{T}; atol=0) where T <: AbstractFloat | ||
low = minimum(signal) | ||
high = maximum(signal) | ||
|
||
# Get indexes when the signal is approximately equal to its low and high values | ||
low_inds = isapprox.(signal, low; atol=atol) | ||
high_inds = isapprox.(signal, high; atol=atol) | ||
|
||
# Detect each type of transitions | ||
trans_inds_1 = diff(low_inds) .== 1 | ||
trans_inds_2 = diff(low_inds) .== -1 | ||
trans_inds_3 = diff(high_inds) .== 1 | ||
trans_inds_4 = diff(high_inds) .== -1 | ||
circshift!(trans_inds_1, -1) | ||
circshift!(trans_inds_3, -1) | ||
|
||
# Combine all transition | ||
transitions_inds = trans_inds_1 .| trans_inds_2 .| trans_inds_3 .| trans_inds_4 | ||
pushfirst!(transitions_inds, false) | ||
|
||
return transitions_inds | ||
end | ||
|
||
function compute_transition_times(stimulus::Function, f , dt, tspan, start_time, pulse_width; atol=0) | ||
period = 1 / f | ||
n_periods = floor((tspan[end] - start_time) / period) | ||
|
||
# Detect single pulse transition points | ||
t = (start_time + 0.5 * period):dt:(start_time + 1.5 * period) | ||
s = stimulus.(t) | ||
transitions_inds = detect_transitions(t, s, atol=atol) | ||
single_pulse = t[transitions_inds] | ||
|
||
# Calculate pulse times across all periods | ||
period_offsets = (-1:n_periods+1) * period | ||
pulses = single_pulse .+ period_offsets' | ||
transition_times = vec(pulses) | ||
|
||
# Filter estimated times within the actual time range | ||
inds = (transition_times .>= tspan[1]) .& (transition_times .<= tspan[end]) | ||
|
||
return transition_times[inds] | ||
end | ||
|
||
function compute_transition_values(transition_times, t, signal) | ||
|
||
# Ensure transition_points are within the range of t, assuming both are ordered | ||
@assert begin | ||
t[1] <= transition_times[1] | ||
transition_times[end] <= t[end] | ||
end "Transition points must be within the range of t" | ||
|
||
# Find the indices of the closest time points | ||
indices = searchsortedfirst.(Ref(t), transition_times) | ||
transition_values = signal[indices] | ||
|
||
return transition_values | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
using Neuroblox | ||
using Test | ||
|
||
@testset "Detection of stimulus transitions" begin | ||
frequency = 0.130 | ||
amplitude = 10.0 | ||
pulse_width = 1 | ||
smooth = 1e-3 | ||
start_time = 5 | ||
offset = -2.0 | ||
dt = 1e-4 | ||
tspan = (0,30) | ||
t = tspan[1]:dt:tspan[2] | ||
|
||
@named dbs = DBS(namespace=:g, frequency=frequency, amplitude=amplitude, pulse_width=pulse_width, start_time=start_time, smooth=smooth, offset=offset) | ||
stimulus = dbs.stimulus.(t) | ||
transitions_inds = detect_transitions(t, stimulus; atol=0.05) | ||
transition_times1 = t[transitions_inds] | ||
transition_values1 = stimulus[transitions_inds] | ||
transition_times2 = compute_transition_times(dbs.stimulus, frequency, dt, tspan, start_time, pulse_width; atol=0.05) | ||
transition_values2 = compute_transition_values(transition_times2, t, stimulus) | ||
@test all(isapprox.(transition_times1, transition_times2, rtol=1e-3)) | ||
@test all(isapprox.(transition_values1, transition_values2, rtol=1e-2)) | ||
|
||
smooth = 1e-10 | ||
@named dbs = DBS(namespace=:g, frequency=frequency, amplitude=amplitude, pulse_width=pulse_width, start_time=start_time, smooth=smooth, offset=offset) | ||
transition_times_smoothed = compute_transition_times(dbs.stimulus, frequency, dt, tspan, start_time, pulse_width; atol=0.05) | ||
smooth = 0 | ||
@named dbs = DBS(namespace=:g, frequency=frequency, amplitude=amplitude, pulse_width=pulse_width, start_time=start_time, smooth=smooth, offset=offset) | ||
transition_times_non_smooth = compute_transition_times(dbs.stimulus, frequency, dt, tspan, start_time, pulse_width; atol=0.05) | ||
@test all(isapprox.(transition_times_smoothed, transition_times_non_smooth)) | ||
end | ||
|
||
@testset "DBS connections" begin | ||
# Test DBS -> single AbstractNeuronBlox | ||
@named dbs = DBS() | ||
@named n1 = HHNeuronExciBlox() | ||
g = MetaDiGraph() | ||
add_edge!(g, dbs => n1, weight = 1.0) | ||
sys = system_from_graph(g; name=:test) | ||
@test sys isa ODESystem | ||
|
||
# Test DBS -> Adam's STN | ||
@named stn = HHNeuronExci_STN_Adam_Blox() | ||
g = MetaDiGraph() | ||
add_edge!(g, dbs => stn, weight = 1.0) | ||
sys = system_from_graph(g; name=:test) | ||
@test sys isa SDESystem | ||
|
||
# Test DBS -> NeuralMassBlox | ||
@named mass = JansenRit() | ||
g = MetaDiGraph() | ||
add_edge!(g, dbs => mass, weight = 1.0) | ||
sys = system_from_graph(g; name=:test) | ||
@test sys isa ODESystem | ||
|
||
# Test DBS -> CompositeBlox | ||
@named cb = CorticalBlox(namespace=:g, N_wta=2, N_exci=2, density=0.1, weight=1.0) | ||
g = MetaDiGraph() | ||
add_edge!(g, dbs => cb, weight = 1.0) | ||
sys = system_from_graph(g; name=:test) | ||
@test sys isa ODESystem | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters