Skip to content

Commit

Permalink
add tests, make outlier filter properly start and end sections of acq…
Browse files Browse the repository at this point in the history
…uisition
  • Loading branch information
mat-kie committed Dec 5, 2024
1 parent 781d8a4 commit 3d67281
Show file tree
Hide file tree
Showing 3 changed files with 161 additions and 17 deletions.
3 changes: 2 additions & 1 deletion src/model/acquisition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use super::bluetooth::HeartrateMessage;
use crate::model::hrv::{HrvSessionData, HrvStatistics};
use anyhow::Result;
use log::{trace, warn};
#[cfg(test)]
use mockall::automock;
use serde::{Deserialize, Deserializer, Serialize};
use std::fmt::Debug;
Expand All @@ -16,7 +17,7 @@ use time::{Duration, OffsetDateTime};
///
/// Defines the interface for managing acquisition-related data, including runtime measurements,
/// HRV statistics, and stored acquisitions.
#[automock]
#[cfg_attr(test, automock)]
pub trait AcquisitionModelApi: Debug + Send + Sync {
/// Retrieves the start time of the current acquisition.
///
Expand Down
96 changes: 82 additions & 14 deletions src/model/hrv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -245,32 +245,54 @@ impl HrvSessionData {
outlier_filter: f64,
window_size: usize,
) -> (Vec<f64>, Vec<Duration>) {
let predicate = |rr_window: &[f64], window_size: usize| {
let median_rr = rr_window[window_size / 2];
let mean_rr = rr_window.iter().sum::<f64>() / window_size as f64;
let deviation = (median_rr - mean_rr).abs() * 0.5;
let half_window = window_size / 2;

deviation < outlier_filter
// Helper function to check if a value is an outlier
let is_outlier = |idx: usize, values: &[f64]| {
let mut start = idx.saturating_sub(half_window);
let mut end = start + window_size;
if end >= values.len() {
end = values.len();
start = end.saturating_sub(window_size);
}

let window = &values[start..end];
let mean = window
.iter()
.enumerate()
.filter(|(i, _)| start + i != idx)
.map(|(_, &v)| v)
.sum::<f64>()
/ (window.len() - 1) as f64;

let deviation = (values[idx] - mean).abs();

deviation > outlier_filter
};

if let Some(rr_time) = opt_rr_time {
// Process both RR intervals and timestamps
rr_intervals
.windows(window_size)
.zip(rr_time.windows(window_size))
.filter_map(|(rr_window, time_window)| {
if predicate(rr_window, window_size) {
Some((rr_window[window_size / 2], time_window[window_size / 2]))
.iter()
.zip(rr_time)
.enumerate()
.filter_map(|(i, (&rr, &time))| {
if !is_outlier(i, rr_intervals) {
Some((rr, time))
} else {
None
}
})
.unzip()
} else {
// Process only RR intervals
(
rr_intervals
.windows(window_size)
.filter_map(|rr_window| {
if predicate(rr_window, window_size) {
Some(rr_window[window_size / 2])
.iter()
.enumerate()
.filter_map(|(i, &rr)| {
if !is_outlier(i, rr_intervals) {
Some(rr)
} else {
None
}
Expand Down Expand Up @@ -337,4 +359,50 @@ mod tests {
runtime.add_measurement(&hr_msg, &Duration::milliseconds(500));
assert!(runtime.has_sufficient_data());
}

#[test]
fn test_hrv_statistics_new() {
let rr_intervals = vec![800.0, 810.0, 790.0, 805.0];
let hr_values = vec![75.0, 76.0, 74.0, 75.5];
let hrv_stats = HrvStatistics::new(&rr_intervals, &hr_values).unwrap();
assert!(hrv_stats.rmssd > 0.0);
assert!(hrv_stats.sdrr > 0.0);
assert!(hrv_stats.sd1 > 0.0);
assert!(hrv_stats.sd2 > 0.0);
assert!(hrv_stats.avg_hr > 0.0);
}

#[test]
fn test_hrv_session_data_from_acquisition() {
let hr_msg = HeartrateMessage::new(&[0b10000, 80, 255, 0]);
let data = vec![
(Duration::milliseconds(0), hr_msg),
(Duration::milliseconds(1000), hr_msg),
(Duration::milliseconds(2000), hr_msg),
(Duration::milliseconds(3000), hr_msg),
];
let session_data = HrvSessionData::from_acquisition(&data, None, 50.0).unwrap();
assert!(session_data.has_sufficient_data());
assert!(session_data.hrv_stats.is_some());
}

#[test]
fn test_apply_outlier_filter() {
let rr_intervals = vec![800.0, 810.0, 790.0, 805.0, 900.0, 805.0, 810.0];
let (filtered_rr, _) = HrvSessionData::apply_outlier_filter(&rr_intervals, None, 50.0, 5);
assert_eq!(filtered_rr.len(), 6); // The outlier (900.0) should be filtered out
}

#[test]
fn test_get_poincare() {
let session_data = HrvSessionData {
rr_intervals: vec![800.0, 810.0, 790.0, 805.0],
..Default::default()
};
let poincare_points = session_data.get_poincare();
assert_eq!(poincare_points.len(), 3);
assert_eq!(poincare_points[0], [800.0, 810.0]);
assert_eq!(poincare_points[1], [810.0, 790.0]);
assert_eq!(poincare_points[2], [790.0, 805.0]);
}
}
79 changes: 77 additions & 2 deletions src/model/storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,21 @@
use std::sync::Arc;

use mockall::automock;
use serde::{de::DeserializeOwned, Deserialize, Deserializer, Serialize, Serializer};
use tokio::sync::{RwLock, RwLockReadGuard};

use super::acquisition::AcquisitionModelApi;

#[cfg(test)]
use super::acquisition::MockAcquisitionModelApi;

#[cfg(test)]
use mockall::automock;
/// Trait defining the interface for storage models.
///
/// This trait allows for managing a collection of acquisition models,
/// providing methods to access, store, and delete acquisitions.
#[automock(type AcqModelType = MockAcquisitionModelApi;)]
#[cfg_attr(test, automock(type AcqModelType = MockAcquisitionModelApi;))]
pub trait StorageModelApi: Sync + Send {
/// The type of acquisition model being stored, which must implement `AcquisitionModelApi`,
/// `Serialize`, and `DeserializeOwned`.
Expand Down Expand Up @@ -176,3 +179,75 @@ impl<T: ?Sized> ModelHandle<T> {
self.data.blocking_read()
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::model::acquisition::MockAcquisitionModelApi as MockAcquisitionModel;
impl Serialize for MockAcquisitionModel {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let s = serde_json::to_string(self).unwrap();
serializer.serialize_str(&s)
}
}

impl<'a> Deserialize<'a> for MockAcquisitionModel {
fn deserialize<D>(_deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'a>,
{
Ok(MockAcquisitionModel::default())
}
}

#[tokio::test]
async fn test_store_acquisition() {
let mut storage = StorageModel::<MockAcquisitionModel>::default();
let acq = Arc::new(RwLock::new(MockAcquisitionModel::default()));

storage.store_acquisition(acq.clone());

assert_eq!(storage.acquisitions.len(), 1);
assert_eq!(storage.handles.len(), 1);
}

#[tokio::test]
async fn test_delete_acquisition() {
let mut storage = StorageModel::<MockAcquisitionModel>::default();
let acq1 = Arc::new(RwLock::new(MockAcquisitionModel::default()));
let acq2 = Arc::new(RwLock::new(MockAcquisitionModel::default()));

storage.store_acquisition(acq1.clone());
storage.store_acquisition(acq2.clone());

storage.delete_acquisition(0);

assert_eq!(storage.acquisitions.len(), 1);
assert_eq!(storage.handles.len(), 1);
}

#[tokio::test]
async fn test_get_acquisitions() {
let mut storage = StorageModel::<MockAcquisitionModel>::default();
let acq = Arc::new(RwLock::new(MockAcquisitionModel::default()));

storage.store_acquisition(acq.clone());

let acquisitions = storage.get_acquisitions();
assert_eq!(acquisitions.len(), 1);
}

#[tokio::test]
async fn test_get_mut_acquisitions() {
let mut storage = StorageModel::<MockAcquisitionModel>::default();
let acq = Arc::new(RwLock::new(MockAcquisitionModel::default()));

storage.store_acquisition(acq.clone());

let acquisitions = storage.get_mut_acquisitions();
assert_eq!(acquisitions.len(), 1);
}
}

0 comments on commit 3d67281

Please sign in to comment.