From e1dd0f07a0c6e5d51d4b82c4808d61ba2f7370a3 Mon Sep 17 00:00:00 2001 From: Mason Protter Date: Thu, 10 Oct 2024 23:16:08 +0200 Subject: [PATCH] Fixes to support solvers which do intermediate ForwardDiff --- Project.toml | 2 +- src/GraphDynamics.jl | 9 ++++++--- src/subsystems.jl | 19 +++++++++++++++---- 3 files changed, 22 insertions(+), 8 deletions(-) diff --git a/Project.toml b/Project.toml index 3549b8e..5a0e0b3 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "GraphDynamics" uuid = "bcd5d0fe-e6b7-4ef1-9848-780c183c7f4c" -version = "0.1.3" +version = "0.1.4" [deps] Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" diff --git a/src/GraphDynamics.jl b/src/GraphDynamics.jl index 07bd3c3..976e7f3 100644 --- a/src/GraphDynamics.jl +++ b/src/GraphDynamics.jl @@ -2,7 +2,8 @@ module GraphDynamics macro public(ex) if VERSION >= v"1.11.0-DEV.469" - args = ex isa Symbol ? (ex,) : Base.isexpr(ex, :tuple) ? ex.args : error("malformed input to `@public`: $ex") + args = ex isa Symbol ? (ex,) : Base.isexpr(ex, :tuple) ? ex.args : + error("malformed input to `@public`: $ex") esc(Expr(:public, args...)) else nothing @@ -189,7 +190,9 @@ function initialize_input end When a `Subsystem` is connected to multiple other subsystems, all of the inputs sent to that `Subsystem` via the connections must be `combine`'d together into one input representing the accumulation of all of the inputs. `combine` is the function used to accumulate these inputs together at each step. Defaults to addition, but can have methods added to it for more exotic input types. """ combine(x::Number, y::Number) = x + y -combine(x::NamedTuple, y::NamedTuple) = typeof(x)(combine.(Tuple(x), Tuple(y))) +function combine(x::NamedTuple{names}, y::NamedTuple{names}) where {names} + NamedTuple{names}(combine.(Tuple(x), Tuple(y))) +end """ @@ -223,7 +226,7 @@ end """ event_times(::T) = () -add methods to this function if a subsystem or connection type has a discrete event that triggers at pre-defined times. This will be used to add `tstops` to the `ODEProblem` or `SDEProblem` automatically. +add methods to this function if a subsystem or connection type has a discrete event that triggers at pre-defined times. This will be used to add `tstops` to the `ODEProblem` or `SDEProblem` automatically during `GraphSystem` construction. This is vital for discrete events which only trigger at a specific time. """ event_times(::Any) = () diff --git a/src/subsystems.jl b/src/subsystems.jl index 0b5c3cb..2a67466 100644 --- a/src/subsystems.jl +++ b/src/subsystems.jl @@ -201,19 +201,30 @@ Base.eltype(::Type{<:Subsystem{<:Any, T}}) where {T} = T struct VectorOfSubsystemStates{States, Mat <: AbstractMatrix} <: AbstractVector{States} data::Mat end -VectorOfSubsystemStates{States}(v::Mat) where {States, Mat} = VectorOfSubsystemStates{States, Mat}(v) +function VectorOfSubsystemStates{SubsystemStates{Name, T, NamedTuple{snames, Tup}}}( + v::AbstractMatrix{U} + ) where {Name, T, U, snames, Tup} + V = promote_type(T,U) + States = SubsystemStates{Name, V, NamedTuple{snames, NTuple{length(snames), V}}} + VectorOfSubsystemStates{States, typeof(v)}(v) +end Base.size(v::VectorOfSubsystemStates{States}) where {States} = (size(v.data, 2),) @propagate_inbounds function Base.getindex(v::VectorOfSubsystemStates{States}, idx::Integer) where {States <: SubsystemStates} l = length(States) - #@boundscheck checkbounds(v.data, 1:l, idx) + @boundscheck checkbounds(v.data, 1:l, idx) @inbounds States(view(v.data, 1:l, idx)) end + +@noinline function sym_not_found_error(::Type{SubsystemStates{Name, T, NamedTuple{names}}}, s::Symbol) where {Name, T, names} + error("SubsystemStates{$Name} does not have a field $s, valid fields are $names") +end + @propagate_inbounds function Base.getindex(v::VectorOfSubsystemStates{States}, s::Symbol, idx::Integer) where {States <: SubsystemStates} i = state_ind(States, s) if isnothing(i) - error("Something helpful") + sym_not_found_error(States, s) end v.data[i, idx] end @@ -231,7 +242,7 @@ end idx::Integer) where {States <: SubsystemStates} i = state_ind(States, s) if isnothing(i) - error("Something helpful") + sym_not_found_error(States, s) end v.data[i, idx] = val end