From 78eb4b13b009c173cb75ffeeea6d0645486c7b22 Mon Sep 17 00:00:00 2001 From: hnakashima Date: Sun, 16 Jun 2024 21:42:10 +0900 Subject: [PATCH] Refactor cubic spline and hermite interpolation implementations The cubic spline and hermite interpolations in interpolation module have been refactored for better clarity and performance. This included significant changes to the interpolation logic and use of vectors in natural_cubic.rs and hermite.rs, reducing duplication and improving readability. Additional mathematical operations were also implemented in tridiagonal_matrix.rs. --- crates/qlab-error/src/lib.rs | 2 +- .../src/interpolation/spline/hermite.rs | 56 +++++--------- .../src/interpolation/spline/natural_cubic.rs | 76 +++++++++++-------- .../src/linear_algebra/tridiagonal_matrix.rs | 27 +++++++ crates/qlab-termstructure/src/yield_curve.rs | 2 + crates/qlab/tests/calculate_bond.rs | 2 +- 6 files changed, 97 insertions(+), 68 deletions(-) diff --git a/crates/qlab-error/src/lib.rs b/crates/qlab-error/src/lib.rs index 62e7500..f340605 100644 --- a/crates/qlab-error/src/lib.rs +++ b/crates/qlab-error/src/lib.rs @@ -1,6 +1,6 @@ use std::borrow::Cow; use std::fmt; -use std::fmt::{Display, Formatter}; +use std::fmt::{Debug, Display, Formatter}; use std::ops::Deref; use thiserror::Error; diff --git a/crates/qlab-math/src/interpolation/spline/hermite.rs b/crates/qlab-math/src/interpolation/spline/hermite.rs index 7edc080..035818d 100644 --- a/crates/qlab-math/src/interpolation/spline/hermite.rs +++ b/crates/qlab-math/src/interpolation/spline/hermite.rs @@ -1,17 +1,11 @@ use crate::interpolation::spline::Value; +use crate::interpolation::{find_index_at_left_boundary, Point2DWithSlope}; use nalgebra::Matrix4; use nalgebra::Vector4; -use num_traits::Zero; use qlab_error::InterpolationError; -struct Point3 { - pub x: V, - pub y: V, - pub dydx: V, -} - pub struct Hermite { - points: Vec>, + points: Vec>, m: Matrix4, } @@ -41,11 +35,11 @@ impl Hermite { let mut points = Vec::new(); for &(x, y, dydx) in raw_points { - let point = Point3 { x, y, dydx }; - if point.x < temp { + let point = Point2DWithSlope::new(x, y, dydx); + if point.coordinate.x < temp { return Err(InterpolationError::PointOrderError); } - temp = point.x; + temp = point.coordinate.x; points.push(point); } let m = Matrix4::new( @@ -88,29 +82,21 @@ impl Hermite { /// # Panics /// Will panic if partial comparison of points fail. pub fn try_value(&self, x: V) -> Result> { - match self - .points - .binary_search_by(|point| point.x.partial_cmp(&x).unwrap()) - { - Ok(pos) => Ok(self.points[pos].y), - Err(pos) => { - if pos.is_zero() { - return Err(InterpolationError::OutOfLowerBound(x)); - } - if pos > self.points.len() { - return Err(InterpolationError::OutOfUpperBound(x)); - } - let pos = pos - 1; - let point = &self.points[pos]; - let next_point = &self.points[pos + 1]; - let h = next_point.x - point.x; - let delta = (x - point.x) / h; - let delta2 = delta * delta; - let delta3 = delta2 * delta; - let d = Vector4::new(delta3, delta2, delta, V::from_i8(1).unwrap()); - let f = Vector4::new(point.y, next_point.y, point.dydx * h, next_point.dydx * h); - Ok((d.transpose() * self.m * f).x) - } - } + let pos = find_index_at_left_boundary(&self.points, x)?; + + let point = &self.points[pos]; + let next_point = &self.points[pos + 1]; + let h = next_point.coordinate.x - point.coordinate.x; + let delta = (x - point.coordinate.x) / h; + let delta2 = delta * delta; + let delta3 = delta2 * delta; + let d = Vector4::new(delta3, delta2, delta, V::one()); + let f = Vector4::new( + point.coordinate.y, + next_point.coordinate.y, + point.dydx * h, + next_point.dydx * h, + ); + Ok((d.transpose() * self.m * f).x) } } diff --git a/crates/qlab-math/src/interpolation/spline/natural_cubic.rs b/crates/qlab-math/src/interpolation/spline/natural_cubic.rs index 822353f..7888398 100644 --- a/crates/qlab-math/src/interpolation/spline/natural_cubic.rs +++ b/crates/qlab-math/src/interpolation/spline/natural_cubic.rs @@ -1,6 +1,7 @@ use crate::interpolation::spline::Value; use crate::interpolation::{find_index_at_left_boundary, Interpolator, Point2DWithSlope}; use crate::linear_algebra::tridiagonal_matrix::TridiagonalMatrix; +use nalgebra::{DVector, Dim, Dyn, VecStorage, U1}; use qlab_error::InterpolationError; #[derive(Default)] @@ -35,41 +36,54 @@ impl Interpolator, V> for NaturalCubic { raw_points.len(), )); } - let mut du = Vec::with_capacity(raw_points.len() - 1); - let mut d = Vec::with_capacity(raw_points.len()); - let mut dl = Vec::with_capacity(raw_points.len() - 1); - for i in 0..raw_points.len() { - if i == 0 { - du.push(V::zero()); - d.push(V::one()); - } else if i + 1 == raw_points.len() { - d.push(V::one()); - dl.push(V::zero()); - } else { - let h = raw_points[i].0 - raw_points[i - 1].0; - let h_next = raw_points[i + 1].0 - raw_points[i].0; - du.push(h_next / V::from_i8(6).unwrap()); - d.push((h + h_next) / V::from_i8(3).unwrap()); - dl.push(h / V::from_i8(6).unwrap()); + let mut m = Vec::with_capacity(raw_points.len() - 2); + let mut b_upper_diagonals = Vec::with_capacity(raw_points.len() - 3); + let mut b_diagonals = Vec::with_capacity(raw_points.len() - 2); + let mut b_lower_diagonals = Vec::with_capacity(raw_points.len() - 3); + let mut c_upper_diagonals = Vec::with_capacity(raw_points.len() - 3); + let mut c_diagonals = Vec::with_capacity(raw_points.len() - 2); + let mut c_lower_diagonals = Vec::with_capacity(raw_points.len() - 3); + for i in 1..raw_points.len() - 1 { + let h = raw_points[i].0 - raw_points[i - 1].0; + let h_next = raw_points[i + 1].0 - raw_points[i].0; + if i != 1 { + b_lower_diagonals.push(h / V::from_i8(6).unwrap()); + c_lower_diagonals.push(h.recip()); } - } - - let mut b = Vec::with_capacity(raw_points.len()); - for i in 0..raw_points.len() { - if i == 0 || i + 1 == raw_points.len() { - b.push(V::zero()); + if i + 2 != raw_points.len() { + b_upper_diagonals.push(h_next / V::from_i8(6).unwrap()); + c_upper_diagonals.push(h_next.recip()); + } + if i == 1 { + m.push(raw_points[i - 1].1 / h); + } else if i + 2 == raw_points.len() { + m.push(raw_points[i + 1].1 / h_next); } else { - b.push( - (raw_points[i + 1].1 - raw_points[i].1) - / (raw_points[i + 1].0 - raw_points[i].0) - - (raw_points[i].1 - raw_points[i - 1].1) - / (raw_points[i].0 - raw_points[i - 1].0), - ); + m.push(V::zero()); } + b_diagonals.push((h + h_next) / V::from_i8(3).unwrap()); + c_diagonals.push(-(h.recip() + h_next.recip())); } - - let matrix = TridiagonalMatrix::try_new(du, d, dl).unwrap(); - let derivatives = matrix.solve(&b); + let b = + TridiagonalMatrix::try_new(b_upper_diagonals, b_diagonals, b_lower_diagonals).unwrap(); + let c = + TridiagonalMatrix::try_new(c_upper_diagonals, c_diagonals, c_lower_diagonals).unwrap(); + let mut y = Vec::with_capacity(raw_points.len() - 2); + for raw_point in raw_points.iter().take(raw_points.len() - 1).skip(1) { + y.push(raw_point.1); + } + let rhs = VecStorage::new(Dyn::from_usize(raw_points.len() - 2), U1, (c * y).unwrap()); + let mut rhs = DVector::from_data(rhs); + let m = VecStorage::new(Dyn::from_usize(raw_points.len() - 2), U1, m); + let m = DVector::from_data(m); + rhs += m; + let y2 = b.solve(rhs.as_slice()); + let mut derivatives = Vec::with_capacity(raw_points.len()); + derivatives.push(V::zero()); + for val in y2 { + derivatives.push(val); + } + derivatives.push(V::zero()); let mut temp = raw_points[0].0; let mut points = Vec::new(); diff --git a/crates/qlab-math/src/linear_algebra/tridiagonal_matrix.rs b/crates/qlab-math/src/linear_algebra/tridiagonal_matrix.rs index 308d1fb..81b613b 100644 --- a/crates/qlab-math/src/linear_algebra/tridiagonal_matrix.rs +++ b/crates/qlab-math/src/linear_algebra/tridiagonal_matrix.rs @@ -1,4 +1,5 @@ use crate::value::Value; +use std::ops::Mul; #[derive(Debug)] pub enum MatrixValidationError { @@ -33,6 +34,9 @@ impl TridiagonalMatrix { // Solve Ax = b. pub fn solve(self, b: &[V]) -> Vec { + if self.size == 1 { + return vec![b[0] / self.diagonal[0]]; + } // shape validation is already done at construction phase solve_with_thomas_algorithm_unchecked( self.size, @@ -44,6 +48,29 @@ impl TridiagonalMatrix { } } +impl Mul> for TridiagonalMatrix { + type Output = Option>; + + fn mul(self, rhs: Vec) -> Self::Output { + if rhs.len() != self.size { + return None; + } + let mut ret = Vec::with_capacity(self.size); + for i in 0..self.size { + let mut temp = self.diagonal[i] * rhs[i]; + if i + 1 < self.size { + temp += self.upper_diagonal[i] * rhs[i + 1]; + } + if i > 0 { + temp += self.lower_diagonal[i - 1] * rhs[i - 1]; + } + + ret.push(temp); + } + Some(ret) + } +} + fn solve_with_thomas_algorithm_unchecked( matrix_size: usize, lower_diagonal: &[V], diff --git a/crates/qlab-termstructure/src/yield_curve.rs b/crates/qlab-termstructure/src/yield_curve.rs index 5fbcfd4..92d303c 100644 --- a/crates/qlab-termstructure/src/yield_curve.rs +++ b/crates/qlab-termstructure/src/yield_curve.rs @@ -103,6 +103,8 @@ impl> YieldCurve { let y1 = self.yield_curve(t1)?; Ok((t1 * y1 - t2 * y2).exp()) } + + // Calculates continuous yield at the specified time. fn yield_curve(&self, t: V) -> QLabResult { Ok(self.interpolator.try_value(t)?) } diff --git a/crates/qlab/tests/calculate_bond.rs b/crates/qlab/tests/calculate_bond.rs index 890bce0..636ea78 100644 --- a/crates/qlab/tests/calculate_bond.rs +++ b/crates/qlab/tests/calculate_bond.rs @@ -50,5 +50,5 @@ fn main() { .discounted_value(spot_settle_date, &yield_curve) .unwrap(); println!("{}", bond_20_yr.bond_id()); - println!("{val}"); // 1314.5664389486494 + println!("{val}"); // 1314.5577192000126 }