Skip to content

Commit

Permalink
adapt GraphDynamics code to Kuramoto cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
MasonProtter committed Jan 16, 2025
1 parent ce333e6 commit b7da64c
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 58 deletions.
2 changes: 2 additions & 0 deletions src/GraphDynamicsInterop/GraphDynamicsInterop.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ using ..Neuroblox:
Matrisome,
TAN,
SNc,
AbstractKuramotoOscillator,
KuramotoOscillator,
KuramotoOscillatorNoise,
CorticalBlox,
STN,
Thalamus,
Expand Down
7 changes: 4 additions & 3 deletions src/GraphDynamicsInterop/connection_interop.jl
Original file line number Diff line number Diff line change
Expand Up @@ -213,16 +213,17 @@ end

#----------------------------------------------
# Kuramoto
function get_connection(src::KuramotoOscillator, dst::KuramotoOscillator, kwargs)
function get_connection(src::AbstractKuramotoOscillator, dst::AbstractKuramotoOscillator, kwargs)
(;w_val, name) = generate_weight_param(src, dst, kwargs)
(;conn=BasicConnection(w_val), names=[name])
end

function (c::BasicConnection)(src::Subsystem{KuramotoOscillator}, dst::Subsystem{KuramotoOscillator})
function (c::BasicConnection)(src::Subsystem{<:AbstractKuramotoOscillator},
dst::Subsystem{<:AbstractKuramotoOscillator})
w = c.weight
x₀ = src.θ
xᵢ = dst.θ
w * sin(x₀ - xᵢ)
(;jcn = w * sin(x₀ - xᵢ))
end

#----------------------------------------------
Expand Down
61 changes: 32 additions & 29 deletions src/GraphDynamicsInterop/neuron_interop.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ function define_neurons()
(:pinhib, :PINGNeuronInhib)
(:VdP, :VanDerPol)
(:VdPN, :VanDerPolNoise)
(:ko, :KuramotoOscillator)
(:kon, :KuramotoOscillatorNoise)
]
sys = getproperty(Neuroblox, T)(;name)
system = structural_simplify(sys.system; fully_determined=false)
Expand Down Expand Up @@ -165,18 +167,6 @@ function define_neurons()
end
define_neurons() # it's useful when developing this module to have these in a function

#Maybe should just encorporate this into define_neurons()
# for T ∈ [:LIFExciNeuron, :LIFInhNeuron]
# @eval begin
# GraphDynamics.has_discrete_events(::Type{$T}) = true
# GraphDynamics.discrete_event_condition((; t_refract_end)::Subsystem{$T}, t) = t_refract_end == t
# function GraphDynamics.apply_discrete_event!(integrator, _, pview, neuron::Subsystem{$T}, _)
# params = get_params(neuron)
# pview[] = @set params.is_refractory = 0
# end
# end
# end

issupported(::PoissonSpikeTrain) = true
components(p::PoissonSpikeTrain) = (p,)
function to_subsystem(s::PoissonSpikeTrain)
Expand All @@ -189,24 +179,37 @@ GraphDynamics.apply_subsystem_differential!(_, ::Subsystem{PoissonSpikeTrain}, _
GraphDynamics.subsystem_differential_requires_inputs(::Type{PoissonSpikeTrain}) = false


#-------------------------
# Kuramoto
issupported(::KuramotoOscillator) = true
function to_subsystem(o::KuramotoOscillator)
states = SubsystemStates{KuramotoOscillator}((;θ=0.0,))
params = SubsystemParams{KuramotoOscillator}((;ω=getdefault(o.system.ω), ζ=getdefault(o.system.ζ)))
Subsystem(states, params)
end
# #-------------------------
# # Kuramoto
# issupported(::KuramotoOscillator) = true
# function to_subsystem(o::KuramotoOscillator)
# states = SubsystemStates{KuramotoOscillator}((;θ=0.0,))
# params = SubsystemParams{KuramotoOscillator}((;ω=getdefault(o.system.ω)))
# Subsystem(states, params)
# end

GraphDynamics.initialize_input(s::Subsystem{KuramotoOscillator}) = 0.0
function GraphDynamics.subsystem_differential(s::Subsystem{KuramotoOscillator}, jcn, t)
SubsystemStates{KuramotoOscillator}((; #=D=#θ= s.ω + jcn))
end
GraphDynamics.isstochastic(::Type{KuramotoOscillator}) = true
function GraphDynamics.apply_subsystem_noise!(vstates, (;ζ,)::Subsystem{KuramotoOscillator}, t)
vstates[1] = ζ
nothing
end
# GraphDynamics.initialize_input(s::Subsystem{KuramotoOscillator}) = (;jcn = 0.0)
# function GraphDynamics.subsystem_differential(s::Subsystem{KuramotoOscillator}, (; jcn), t)
# SubsystemStates{KuramotoOscillator}((; #=D=#θ= s.ω + jcn))
# end


# issupported(::KuramotoOscillatorNoise) = true
# function to_subsystem(o::KuramotoOscillatorNoise)
# states = SubsystemStates{KuramotoOscillatorNoise}((;θ=0.0,))
# params = SubsystemParams{KuramotoOscillatorNoise}((;ω=getdefault(o.system.ω), ζ=getdefault(o.system.ζ)))
# Subsystem(states, params)
# end

# GraphDynamics.initialize_input(s::Subsystem{KuramotoOscillatorNoise}) = (;jcn=0.0)
# function GraphDynamics.subsystem_differential(s::Subsystem{KuramotoOscillatorNoise}, (;jcn), t)
# SubsystemStates{KuramotoOscillatorNoise}((; #=D=#θ= s.ω + jcn))
# end
# GraphDynamics.isstochastic(::Type{KuramotoOscillatorNoise}) = true
# function GraphDynamics.apply_subsystem_noise!(vstates, (;ζ,)::Subsystem{KuramotoOscillatorNoise}, t)
# vstates[1] = ζ
# nothing
# end


#-------------------------
Expand Down
2 changes: 1 addition & 1 deletion test/GraphDynamicsTests/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ end

if GROUP == "All" || GROUP == "GraphDynamics2"
vdp_test()
kuramato_test()
kuramoto_test()
wta_tests()
dbs_circuit_components()
dbs_circuit()
Expand Down
52 changes: 27 additions & 25 deletions test/GraphDynamicsTests/test_suite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -247,13 +247,13 @@ end



function test_compare_du_and_sols_ensemble(::Type{SDEProblem}, graph, tspan; rtol, mtk=true, alg=nothing, trajectories=50_000)
Random.seed!(1234)
function test_compare_du_and_sols_ensemble(::Type{SDEProblem}, graph, tspan; rtol, mtk=true, alg=nothing, trajectories=100_000)
# Random.seed!(1234)
if graph isa Tuple
(graph_l, graph_r) = graph
else
graph_l = g
graph_r = g
graph_l = graph
graph_r = graph
end

@named gsys = system_from_graph(graph_l; graphdynamics=true)
Expand All @@ -267,8 +267,6 @@ function test_compare_du_and_sols_ensemble(::Type{SDEProblem}, graph, tspan; rto
dnoise = zero(u0)
g(dnoise, u0, p, 1.1)

@test solve(prob, ImplicitEM(), saveat = 0.01,reltol=1e-4,abstol=1e-4).retcode == ReturnCode.Success

ens_prob = EnsembleProblem(prob)
sols = solve(ens_prob, alg, EnsembleThreads(); trajectories)

Expand Down Expand Up @@ -313,6 +311,7 @@ function test_compare_du_and_sols_ensemble(::Type{SDEProblem}, graph, tspan; rto
end
@test sort(du_grp) sort(du_mtk) #due to the MTK getu bug, we'll compare the sorted versions
@test sort(dnoise_grp) sort(dnoise_mtk) #due to the MTK getu bug, we'll compare the sorted versions
@debug "" norm(mean(sol_grp_ens) .- mean(sol_mtk_ens)) / norm(mean(sol_grp_ens))
@test mean(sol_grp_ens) mean(sol_mtk_ens) rtol=rtol
@test std(sol_grp_ens) std(sol_mtk_ens) rtol=rtol
end
Expand Down Expand Up @@ -438,27 +437,30 @@ function ngei_test()
end
end

function kuramato_test()
@testset "Kuramoto" begin
N = 2
# Define the natural distribution of oscillator frequencies
Ω = 249
σ = 26.317
ks_blocks = [KuramotoOscillator(name=Symbol("KO$i"),
ω=rand(Normal(Ω, σ)),
ζ=5.920,
include_noise=true) for i in 1:N]
# Create a graph and add all the oscillators to it
g = MetaDiGraph()
add_blox!.(Ref(g), ks_blocks)
function kuramoto_test()
@testset "Kuramoto Oscillator" begin
@testset "Non-noisy" begin
@named K01 = kuramoto_oscillator=2.0)
@named K02 = kuramoto_oscillator=5.0)

# Connect all oscillators to each other
for i in 1:N
for j in 1:N
add_edge!(g, i, j, Dict(:weight => 1.0))
end
adj = [0 1; 1 0]
g = MetaDiGraph()
add_blox!.(Ref(g), [K01, K02])
create_adjacency_edges!(g, adj)

test_compare_du_and_sols(ODEProblem, g, (0.0, 2.0); rtol=1e-10, alg=AutoVern7(Rodas4()))
end
@testset "Noisy" begin
@named K01 = kuramoto_oscillator=2.0, noise=true)
@named K02 = kuramoto_oscillator=5.0, noise=true)

adj = [0 1; 1 0]
g = MetaDiGraph()
add_blox!.(Ref(g), [K01, K02])
create_adjacency_edges!(g, adj)

test_compare_du_and_sols(SDEProblem, g, (0.0, 2.0); rtol=1e-10, alg=RKMil())
end
test_compare_du_and_sols(SDEProblem, g, (0.0, 2.0), rtol=0.05, alg=RKMil())
end
end

Expand Down

0 comments on commit b7da64c

Please sign in to comment.