Skip to content

Commit

Permalink
cache discrete event conditions
Browse files Browse the repository at this point in the history
  • Loading branch information
MasonProtter committed Nov 13, 2024
1 parent b6ec333 commit 77cd87e
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 19 deletions.
42 changes: 26 additions & 16 deletions src/graph_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -277,28 +277,32 @@ end
#----------------------------------------------------------
# Infra. for discrete events.
#----------------------------------------------------------

function discrete_condition(u, t, integrator)
(;params_partitioned, state_types_val, connection_matrices) = integrator.p
(;params_partitioned, state_types_val, connection_matrices, discrete_event_cache) = integrator.p
states_partitioned = to_vec_o_states(u.x, state_types_val)
_discrete_condition!(states_partitioned, params_partitioned, t, connection_matrices)
_discrete_condition!(states_partitioned, params_partitioned, t, connection_matrices, discrete_event_cache)
end

using GraphDynamics.OhMyThreads: tmapreduce
tany(f, coll; kwargs...) = tmapreduce(f, |, coll; kwargs...)

@generated function _discrete_condition!(states_partitioned ::NTuple{Len, Any},
params_partitioned ::NTuple{Len, Any},
t,
connection_matrices::ConnectionMatrices{NConn},) where {Len, NConn}
quote
connection_matrices::ConnectionMatrices{NConn},
discrete_event_cache ::NTuple{Len, Any}) where {Len, NConn}
quote
trigger = false
@nexprs $Len i -> begin
if has_discrete_events(eltype(states_partitioned[i]))
for j eachindex(states_partitioned[i])
F = ForeachConnectedSubsystem{i}(j, states_partitioned, params_partitioned, connection_matrices)
discrete_event_condition(Subsystem(states_partitioned[i][j], params_partitioned[i][j]), t, F) && return true
sys = Subsystem(states_partitioned[i][j], params_partitioned[i][j])
cond = discrete_event_condition(sys, t, F)
trigger |= cond
discrete_event_cache[i][j] = cond
end
end
end
trigger && return true
@nexprs $NConn nc -> begin
@nexprs $Len i -> begin
@nexprs $Len k -> begin
Expand All @@ -318,27 +322,33 @@ tany(f, coll; kwargs...) = tmapreduce(f, |, coll; kwargs...)
end

function discrete_affect!(integrator)
(;params_partitioned, state_types_val, connection_matrices) = integrator.p
(;params_partitioned, state_types_val, connection_matrices, discrete_event_cache) = integrator.p
state_data = integrator.u.x
states_partitioned = to_vec_o_states(state_data, state_types_val)
_discrete_affect!(integrator, states_partitioned, params_partitioned, connection_matrices, integrator.t)
_discrete_affect!(integrator,
states_partitioned,
params_partitioned,
connection_matrices,
discrete_event_cache,
integrator.t)
end

@generated function _discrete_affect!(integrator,
states_partitioned ::NTuple{Len, Any},
params_partitioned ::NTuple{Len, Any},
connection_matrices::ConnectionMatrices{NConn},
discrete_event_cache ::NTuple{Len, Any},
t) where {Len, NConn}
quote
@nexprs $Len i -> begin
# First we apply events to the states
if has_discrete_events(eltype(states_partitioned[i]))
@inbounds for j eachindex(states_partitioned[i])
sys = Subsystem(states_partitioned[i][j], params_partitioned[i][j])
sview = @view states_partitioned[i][j]
pview = @view params_partitioned[i][j]
F = ForeachConnectedSubsystem{i}(j, states_partitioned, params_partitioned, connection_matrices)
if discrete_event_condition(sys, t, F)
if discrete_event_cache[i][j]
sys = Subsystem(states_partitioned[i][j], params_partitioned[i][j])
sview = @view states_partitioned[i][j]
pview = @view params_partitioned[i][j]
F = ForeachConnectedSubsystem{i}(j, states_partitioned, params_partitioned, connection_matrices)
if discrete_events_require_inputs(sys)
input = calculate_inputs(Val(i), j, states_partitioned, params_partitioned, connection_matrices)
apply_discrete_event!(integrator, sview, pview, sys, F, input)
Expand All @@ -347,6 +357,7 @@ end
end
end
end
discrete_event_cache[i] .= false
end
# Then we do the connection events
@nexprs $NConn nc -> begin
Expand All @@ -361,7 +372,6 @@ end
end
end


function _discrete_connection_affect!(::Val{i}, ::Val{k}, ::Val{nc}, t,
states_partitioned::NTuple{Len, Any},
params_partitioned::NTuple{Len, Any},
Expand Down
16 changes: 13 additions & 3 deletions src/problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ function SciMLBase.ODEProblem(g::ODEGraphSystem, u0map, tspan, param_map=[];
nt = _problem(g, tspan; scheduler, allow_nonconcrete, u0map, param_map)
(; f, u, tspan, p, callback) = nt
tstops = vcat(tstops, nt.tstops)
prob = ODEProblem(f, u, tspan, p; callback, tstops, kwargs...)
prob = ODEProblem{true, SciMLBase.FullSpecialize}(f, u, tspan, p; callback, tstops, kwargs...)
for (k, v) u0map
setu(prob, k)(prob, v)
end
Expand All @@ -35,11 +35,12 @@ function SciMLBase.SDEProblem(g::SDEGraphSystem, u0map, tspan, param_map=[];
prob
end

Base.@kwdef struct GraphSystemParameters{PP, CM, S, STV}
Base.@kwdef struct GraphSystemParameters{PP, CM, S, STV, DEC}
params_partitioned::PP
connection_matrices::CM
scheduler::S
state_types_val::STV
discrete_event_cache::DEC
end

function _problem(g::GraphSystem, tspan; scheduler, allow_nonconcrete, u0map, param_map)
Expand Down Expand Up @@ -107,11 +108,20 @@ function _problem(g::GraphSystem, tspan; scheduler, allow_nonconcrete, u0map, pa
error(ArgumentError("The provided subsystem states do not have a concrete eltype. All partitions must contain the same eltype. Got `eltype(u) = $(eltype(u))`."))
end

discrete_event_cache = ntuple(length(states_partitioned)) do i
len = has_discrete_events(eltype(states_partitioned[i])) ? length(states_partitioned[i]) : 0
falses(len)
end

ce = nce > 0 ? VectorContinuousCallback(continuous_condition, continuous_affect!, nce) : nothing
de = nde > 0 ? DiscreteCallback(discrete_condition, discrete_affect!) : nothing
callback = CallbackSet(ce, de, composite_discrete_callbacks(composite_discrete_events_partitioned))
f = GraphSystemFunction(graph_ode!, g)
p = GraphSystemParameters(; params_partitioned, connection_matrices, scheduler, state_types_val)
p = GraphSystemParameters(; params_partitioned,
connection_matrices,
scheduler,
state_types_val,
discrete_event_cache)
(; f, u, tspan, p, callback, tstops)
end

Expand Down

0 comments on commit 77cd87e

Please sign in to comment.