From 883d752bd07754d259d0e447a0f9ed52c30cf4c3 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 24 Dec 2023 02:44:53 -0500 Subject: [PATCH] Properly support scalars --- Project.toml | 2 +- src/common_interface/algorithms.jl | 2 +- src/common_interface/solve.jl | 6 +++++- test/kinsol_nonlinear_solve.jl | 17 +++++++++++++++++ 4 files changed, 24 insertions(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index fa63f85..480d39b 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Sundials" uuid = "c3572dad-4567-51f8-b174-8c6c989267f4" authors = ["Chris Rackauckas "] -version = "4.23.0" +version = "4.23.1" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" diff --git a/src/common_interface/algorithms.jl b/src/common_interface/algorithms.jl index e50d9b2..d1f1e60 100644 --- a/src/common_interface/algorithms.jl +++ b/src/common_interface/algorithms.jl @@ -3,7 +3,7 @@ # Abstract Types abstract type SundialsODEAlgorithm{Method, LinearSolver} <: DiffEqBase.AbstractODEAlgorithm end abstract type SundialsDAEAlgorithm{LinearSolver} <: DiffEqBase.AbstractDAEAlgorithm end -abstract type SundialsNonlinearSolveAlgorithm{LinearSolver} end +abstract type SundialsNonlinearSolveAlgorithm{LinearSolver} <: SciMLBase.AbstractNonlinearAlgorithm end SciMLBase.alg_order(alg::Union{SundialsODEAlgorithm, SundialsDAEAlgorithm}) = alg.max_order diff --git a/src/common_interface/solve.jl b/src/common_interface/solve.jl index 6f7b086..f5ea33f 100644 --- a/src/common_interface/solve.jl +++ b/src/common_interface/solve.jl @@ -80,7 +80,11 @@ function DiffEqBase.__solve(prob::Union{ f!(resid, u) retcode = interpret_sundials_retcode(flag) - DiffEqBase.build_solution(prob, alg, u, resid; retcode = retcode) + if prob.u0 isa Number + DiffEqBase.build_solution(prob, alg, u[1], resid[1]; retcode = retcode) + else + DiffEqBase.build_solution(prob, alg, u, resid; retcode = retcode) + end end function DiffEqBase.__init(prob::DiffEqBase.AbstractODEProblem{uType, tupType, isinplace}, diff --git a/test/kinsol_nonlinear_solve.jl b/test/kinsol_nonlinear_solve.jl index 35cd788..addd5ed 100644 --- a/test/kinsol_nonlinear_solve.jl +++ b/test/kinsol_nonlinear_solve.jl @@ -37,3 +37,20 @@ prob_oop = NonlinearProblem{false}(f_oop, u0) f_oop(sol.u, nothing) @test maximum(abs, du) < 1e-6 end + +# Scalar +f_scalar(u, p) = 2 - 2u +u0 = 0.0 +prob_scalar = NonlinearProblem{false}(f_scalar, u0) + +@testset "linear_solver = $(linear_solver) | globalization_strategy = $(globalization_strategy)" for linear_solver in (:Dense, + :LapackDense, :GMRES, :FGMRES, :PCG, :TFQMR), globalization_strategy in (:LineSearch, :None) + local sol + alg = KINSOL(; linear_solver, globalization_strategy) + sol = solve(prob_scalar, alg; abstol) + @test SciMLBase.successful_retcode(sol.retcode) + @test sol.u isa Number + + resid = f_scalar(sol.u, nothing) + @test abs(resid) < 1e-6 +end