Skip to content

Commit

Permalink
Fixes to support solvers which do intermediate ForwardDiff
Browse files Browse the repository at this point in the history
  • Loading branch information
MasonProtter committed Oct 10, 2024
1 parent 2fe7ec8 commit e1dd0f0
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 8 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "GraphDynamics"
uuid = "bcd5d0fe-e6b7-4ef1-9848-780c183c7f4c"
version = "0.1.3"
version = "0.1.4"

[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
Expand Down
9 changes: 6 additions & 3 deletions src/GraphDynamics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ module GraphDynamics

macro public(ex)
if VERSION >= v"1.11.0-DEV.469"
args = ex isa Symbol ? (ex,) : Base.isexpr(ex, :tuple) ? ex.args : error("malformed input to `@public`: $ex")
args = ex isa Symbol ? (ex,) : Base.isexpr(ex, :tuple) ? ex.args :
error("malformed input to `@public`: $ex")
esc(Expr(:public, args...))
else
nothing
Expand Down Expand Up @@ -189,7 +190,9 @@ function initialize_input end
When a `Subsystem` is connected to multiple other subsystems, all of the inputs sent to that `Subsystem` via the connections must be `combine`'d together into one input representing the accumulation of all of the inputs. `combine` is the function used to accumulate these inputs together at each step. Defaults to addition, but can have methods added to it for more exotic input types.
"""
combine(x::Number, y::Number) = x + y
combine(x::NamedTuple, y::NamedTuple) = typeof(x)(combine.(Tuple(x), Tuple(y)))
function combine(x::NamedTuple{names}, y::NamedTuple{names}) where {names}
NamedTuple{names}(combine.(Tuple(x), Tuple(y)))
end


"""
Expand Down Expand Up @@ -223,7 +226,7 @@ end
"""
event_times(::T) = ()
add methods to this function if a subsystem or connection type has a discrete event that triggers at pre-defined times. This will be used to add `tstops` to the `ODEProblem` or `SDEProblem` automatically.
add methods to this function if a subsystem or connection type has a discrete event that triggers at pre-defined times. This will be used to add `tstops` to the `ODEProblem` or `SDEProblem` automatically during `GraphSystem` construction. This is vital for discrete events which only trigger at a specific time.
"""
event_times(::Any) = ()

Expand Down
19 changes: 15 additions & 4 deletions src/subsystems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -201,19 +201,30 @@ Base.eltype(::Type{<:Subsystem{<:Any, T}}) where {T} = T
struct VectorOfSubsystemStates{States, Mat <: AbstractMatrix} <: AbstractVector{States}
data::Mat
end
VectorOfSubsystemStates{States}(v::Mat) where {States, Mat} = VectorOfSubsystemStates{States, Mat}(v)
function VectorOfSubsystemStates{SubsystemStates{Name, T, NamedTuple{snames, Tup}}}(
v::AbstractMatrix{U}
) where {Name, T, U, snames, Tup}
V = promote_type(T,U)
States = SubsystemStates{Name, V, NamedTuple{snames, NTuple{length(snames), V}}}
VectorOfSubsystemStates{States, typeof(v)}(v)
end

Base.size(v::VectorOfSubsystemStates{States}) where {States} = (size(v.data, 2),)

@propagate_inbounds function Base.getindex(v::VectorOfSubsystemStates{States}, idx::Integer) where {States <: SubsystemStates}
l = length(States)
#@boundscheck checkbounds(v.data, 1:l, idx)
@boundscheck checkbounds(v.data, 1:l, idx)
@inbounds States(view(v.data, 1:l, idx))
end

@noinline function sym_not_found_error(::Type{SubsystemStates{Name, T, NamedTuple{names}}}, s::Symbol) where {Name, T, names}
error("SubsystemStates{$Name} does not have a field $s, valid fields are $names")
end

@propagate_inbounds function Base.getindex(v::VectorOfSubsystemStates{States}, s::Symbol, idx::Integer) where {States <: SubsystemStates}
i = state_ind(States, s)
if isnothing(i)
error("Something helpful")
sym_not_found_error(States, s)
end
v.data[i, idx]
end
Expand All @@ -231,7 +242,7 @@ end
idx::Integer) where {States <: SubsystemStates}
i = state_ind(States, s)
if isnothing(i)
error("Something helpful")
sym_not_found_error(States, s)
end
v.data[i, idx] = val
end
Expand Down

0 comments on commit e1dd0f0

Please sign in to comment.