Skip to content

Commit

Permalink
Refactor cubic spline and hermite interpolation implementations
Browse files Browse the repository at this point in the history
The cubic spline and hermite interpolations in interpolation module have been refactored for better clarity and performance. This included significant changes to the interpolation logic and use of vectors in natural_cubic.rs and hermite.rs, reducing duplication and improving readability. Additional mathematical operations were also implemented in tridiagonal_matrix.rs.
  • Loading branch information
nakashima-hikaru committed Jun 16, 2024
1 parent 747629b commit 78eb4b1
Show file tree
Hide file tree
Showing 6 changed files with 97 additions and 68 deletions.
2 changes: 1 addition & 1 deletion crates/qlab-error/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::borrow::Cow;
use std::fmt;
use std::fmt::{Display, Formatter};
use std::fmt::{Debug, Display, Formatter};
use std::ops::Deref;
use thiserror::Error;

Expand Down
56 changes: 21 additions & 35 deletions crates/qlab-math/src/interpolation/spline/hermite.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,11 @@
use crate::interpolation::spline::Value;
use crate::interpolation::{find_index_at_left_boundary, Point2DWithSlope};
use nalgebra::Matrix4;
use nalgebra::Vector4;
use num_traits::Zero;
use qlab_error::InterpolationError;

struct Point3<V> {
pub x: V,
pub y: V,
pub dydx: V,
}

pub struct Hermite<V: Value> {
points: Vec<Point3<V>>,
points: Vec<Point2DWithSlope<V>>,
m: Matrix4<V>,
}

Expand Down Expand Up @@ -41,11 +35,11 @@ impl<V: Value> Hermite<V> {

let mut points = Vec::new();
for &(x, y, dydx) in raw_points {
let point = Point3 { x, y, dydx };
if point.x < temp {
let point = Point2DWithSlope::new(x, y, dydx);
if point.coordinate.x < temp {
return Err(InterpolationError::PointOrderError);
}
temp = point.x;
temp = point.coordinate.x;
points.push(point);
}
let m = Matrix4::new(
Expand Down Expand Up @@ -88,29 +82,21 @@ impl<V: Value> Hermite<V> {
/// # Panics
/// Will panic if partial comparison of points fail.
pub fn try_value(&self, x: V) -> Result<V, InterpolationError<V>> {
match self
.points
.binary_search_by(|point| point.x.partial_cmp(&x).unwrap())
{
Ok(pos) => Ok(self.points[pos].y),
Err(pos) => {
if pos.is_zero() {
return Err(InterpolationError::OutOfLowerBound(x));
}
if pos > self.points.len() {
return Err(InterpolationError::OutOfUpperBound(x));
}
let pos = pos - 1;
let point = &self.points[pos];
let next_point = &self.points[pos + 1];
let h = next_point.x - point.x;
let delta = (x - point.x) / h;
let delta2 = delta * delta;
let delta3 = delta2 * delta;
let d = Vector4::new(delta3, delta2, delta, V::from_i8(1).unwrap());
let f = Vector4::new(point.y, next_point.y, point.dydx * h, next_point.dydx * h);
Ok((d.transpose() * self.m * f).x)
}
}
let pos = find_index_at_left_boundary(&self.points, x)?;

let point = &self.points[pos];
let next_point = &self.points[pos + 1];
let h = next_point.coordinate.x - point.coordinate.x;
let delta = (x - point.coordinate.x) / h;
let delta2 = delta * delta;
let delta3 = delta2 * delta;
let d = Vector4::new(delta3, delta2, delta, V::one());
let f = Vector4::new(
point.coordinate.y,
next_point.coordinate.y,
point.dydx * h,
next_point.dydx * h,
);
Ok((d.transpose() * self.m * f).x)
}
}
76 changes: 45 additions & 31 deletions crates/qlab-math/src/interpolation/spline/natural_cubic.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use crate::interpolation::spline::Value;
use crate::interpolation::{find_index_at_left_boundary, Interpolator, Point2DWithSlope};
use crate::linear_algebra::tridiagonal_matrix::TridiagonalMatrix;
use nalgebra::{DVector, Dim, Dyn, VecStorage, U1};
use qlab_error::InterpolationError;

#[derive(Default)]
Expand Down Expand Up @@ -35,41 +36,54 @@ impl<V: Value> Interpolator<NaturalCubic<V>, V> for NaturalCubic<V> {
raw_points.len(),
));
}
let mut du = Vec::with_capacity(raw_points.len() - 1);
let mut d = Vec::with_capacity(raw_points.len());
let mut dl = Vec::with_capacity(raw_points.len() - 1);
for i in 0..raw_points.len() {
if i == 0 {
du.push(V::zero());
d.push(V::one());
} else if i + 1 == raw_points.len() {
d.push(V::one());
dl.push(V::zero());
} else {
let h = raw_points[i].0 - raw_points[i - 1].0;
let h_next = raw_points[i + 1].0 - raw_points[i].0;
du.push(h_next / V::from_i8(6).unwrap());
d.push((h + h_next) / V::from_i8(3).unwrap());
dl.push(h / V::from_i8(6).unwrap());
let mut m = Vec::with_capacity(raw_points.len() - 2);
let mut b_upper_diagonals = Vec::with_capacity(raw_points.len() - 3);
let mut b_diagonals = Vec::with_capacity(raw_points.len() - 2);
let mut b_lower_diagonals = Vec::with_capacity(raw_points.len() - 3);
let mut c_upper_diagonals = Vec::with_capacity(raw_points.len() - 3);
let mut c_diagonals = Vec::with_capacity(raw_points.len() - 2);
let mut c_lower_diagonals = Vec::with_capacity(raw_points.len() - 3);
for i in 1..raw_points.len() - 1 {
let h = raw_points[i].0 - raw_points[i - 1].0;
let h_next = raw_points[i + 1].0 - raw_points[i].0;
if i != 1 {
b_lower_diagonals.push(h / V::from_i8(6).unwrap());
c_lower_diagonals.push(h.recip());
}
}

let mut b = Vec::with_capacity(raw_points.len());
for i in 0..raw_points.len() {
if i == 0 || i + 1 == raw_points.len() {
b.push(V::zero());
if i + 2 != raw_points.len() {
b_upper_diagonals.push(h_next / V::from_i8(6).unwrap());
c_upper_diagonals.push(h_next.recip());
}
if i == 1 {
m.push(raw_points[i - 1].1 / h);
} else if i + 2 == raw_points.len() {
m.push(raw_points[i + 1].1 / h_next);
} else {
b.push(
(raw_points[i + 1].1 - raw_points[i].1)
/ (raw_points[i + 1].0 - raw_points[i].0)
- (raw_points[i].1 - raw_points[i - 1].1)
/ (raw_points[i].0 - raw_points[i - 1].0),
);
m.push(V::zero());
}
b_diagonals.push((h + h_next) / V::from_i8(3).unwrap());
c_diagonals.push(-(h.recip() + h_next.recip()));
}

let matrix = TridiagonalMatrix::try_new(du, d, dl).unwrap();
let derivatives = matrix.solve(&b);
let b =
TridiagonalMatrix::try_new(b_upper_diagonals, b_diagonals, b_lower_diagonals).unwrap();
let c =
TridiagonalMatrix::try_new(c_upper_diagonals, c_diagonals, c_lower_diagonals).unwrap();
let mut y = Vec::with_capacity(raw_points.len() - 2);
for raw_point in raw_points.iter().take(raw_points.len() - 1).skip(1) {
y.push(raw_point.1);
}
let rhs = VecStorage::new(Dyn::from_usize(raw_points.len() - 2), U1, (c * y).unwrap());
let mut rhs = DVector::from_data(rhs);
let m = VecStorage::new(Dyn::from_usize(raw_points.len() - 2), U1, m);
let m = DVector::from_data(m);
rhs += m;
let y2 = b.solve(rhs.as_slice());
let mut derivatives = Vec::with_capacity(raw_points.len());
derivatives.push(V::zero());
for val in y2 {
derivatives.push(val);
}
derivatives.push(V::zero());

let mut temp = raw_points[0].0;
let mut points = Vec::new();
Expand Down
27 changes: 27 additions & 0 deletions crates/qlab-math/src/linear_algebra/tridiagonal_matrix.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::value::Value;
use std::ops::Mul;

#[derive(Debug)]
pub enum MatrixValidationError {
Expand Down Expand Up @@ -33,6 +34,9 @@ impl<V: Value> TridiagonalMatrix<V> {

// Solve Ax = b.
pub fn solve(self, b: &[V]) -> Vec<V> {
if self.size == 1 {
return vec![b[0] / self.diagonal[0]];
}
// shape validation is already done at construction phase
solve_with_thomas_algorithm_unchecked(
self.size,
Expand All @@ -44,6 +48,29 @@ impl<V: Value> TridiagonalMatrix<V> {
}
}

impl<V: Value> Mul<Vec<V>> for TridiagonalMatrix<V> {
type Output = Option<Vec<V>>;

fn mul(self, rhs: Vec<V>) -> Self::Output {
if rhs.len() != self.size {
return None;
}
let mut ret = Vec::with_capacity(self.size);
for i in 0..self.size {
let mut temp = self.diagonal[i] * rhs[i];
if i + 1 < self.size {
temp += self.upper_diagonal[i] * rhs[i + 1];
}
if i > 0 {
temp += self.lower_diagonal[i - 1] * rhs[i - 1];
}

ret.push(temp);
}
Some(ret)
}
}

fn solve_with_thomas_algorithm_unchecked<V: Value>(
matrix_size: usize,
lower_diagonal: &[V],
Expand Down
2 changes: 2 additions & 0 deletions crates/qlab-termstructure/src/yield_curve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@ impl<V: Value, D: DayCount, I: Interpolator<I, V>> YieldCurve<D, V, I> {
let y1 = self.yield_curve(t1)?;
Ok((t1 * y1 - t2 * y2).exp())
}

// Calculates continuous yield at the specified time.
fn yield_curve(&self, t: V) -> QLabResult<V> {
Ok(self.interpolator.try_value(t)?)
}
Expand Down
2 changes: 1 addition & 1 deletion crates/qlab/tests/calculate_bond.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,5 +50,5 @@ fn main() {
.discounted_value(spot_settle_date, &yield_curve)
.unwrap();
println!("{}", bond_20_yr.bond_id());
println!("{val}"); // 1314.5664389486494
println!("{val}"); // 1314.5577192000126
}

0 comments on commit 78eb4b1

Please sign in to comment.