Merge pull request #97 from delta-rs/62-fix-unit-tests-in-mnistrs-cifar10rs-and-imagenet_v2rs

resolves #62 add mnist and cifar10 tests
mjovanc authored Dec 11, 2024
2 parents 19f27b4 + 40e94f3 commit 2795707
Showing 4 changed files with 251 additions and 193 deletions.
1 change: 0 additions & 1 deletion Cargo.toml
@@ -9,4 +9,3 @@ resolver = "2"

[workspace.dependencies]
tokio = { version = "1.32.0", features = ["full"] }
ndarray = "0.15"
181 changes: 105 additions & 76 deletions delta/src/dataset/image/cifar10.rs
@@ -27,20 +27,20 @@
//! OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
//! OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
use crate::common::{Tensor};
use crate::common::Tensor;
use crate::dataset::base::{Dataset, ImageDatasetOps};
use crate::get_workspace_dir;
use flate2::read::GzDecoder;
use log::debug;
use ndarray::{IxDyn, Shape};
use std::collections::HashSet;
use std::fs;
use std::fs::File;
use std::future::Future;
use std::io::Read;
use std::path::Path;
use std::pin::Pin;
use ndarray::{IxDyn, Shape};
use tar::Archive;
use crate::dataset::base::{Dataset, ImageDatasetOps};
use crate::get_workspace_dir;

/// A struct representing the CIFAR10 dataset.
pub struct Cifar10Dataset {
@@ -155,7 +155,11 @@ impl Cifar10Dataset

for &file in files {
let (img, lbl) = Self::parse_file(
&format!("{}/.cache/dataset/cifar10/cifar-10-batches-bin/{}", env!("CARGO_MANIFEST_DIR"), file),
&format!(
"{}/.cache/dataset/cifar10/{}",
get_workspace_dir().display(),
file
),
total_examples / files.len(),
);
images.extend(img);
@@ -172,7 +176,10 @@
3,
])),
),
Tensor::new(labels, Shape::from(IxDyn(&[total_examples, Self::CIFAR10_NUM_CLASSES]))),
Tensor::new(
labels,
Shape::from(IxDyn(&[total_examples, Self::CIFAR10_NUM_CLASSES])),
),
)
}
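
In the hunk above, the cache path for `parse_file` now goes through `get_workspace_dir()` instead of `env!("CARGO_MANIFEST_DIR")`, so the downloaded dataset is cached once per workspace rather than per crate, and the batch files are read from `.cache/dataset/cifar10/` directly instead of the `cifar-10-batches-bin/` subdirectory. `get_workspace_dir` itself is not part of this diff; a rough sketch of what such a helper could look like, purely an assumption for illustration and not the crate's actual implementation:

```rust
use std::path::PathBuf;

/// Hypothetical sketch only -- NOT the delta crate's real `get_workspace_dir`.
/// Walk upward from the compiling crate's manifest directory until a
/// `Cargo.toml` containing a `[workspace]` table is found.
pub fn get_workspace_dir() -> PathBuf {
    let mut dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
    loop {
        let manifest = dir.join("Cargo.toml");
        if let Ok(contents) = std::fs::read_to_string(&manifest) {
            if contents.contains("[workspace]") {
                return dir;
            }
        }
        match dir.parent() {
            Some(parent) => dir = parent.to_path_buf(),
            // Fall back to the crate directory if no workspace root is found.
            None => return PathBuf::from(env!("CARGO_MANIFEST_DIR")),
        }
    }
}
```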

@@ -397,77 +404,99 @@ impl ImageDatasetOps for Cifar10Dataset
Self {
train: self.train.clone(),
test: self.test.clone(),
val: self.val.clone()
val: self.val.clone(),
}
}
}

// #[cfg(test)]
// mod tests {
// use super::*;
// use serial_test::serial;
// use tokio::runtime::Runtime;
//
// fn setup() {
// let workspace_dir = get_workspace_dir();
// let cache_path = format!("{}/.cache/dataset/cifar10", workspace_dir.display());
// if Path::new(&cache_path).exists() {
// fs::remove_dir_all(&cache_path).expect("Failed to delete cache directory");
// }
// }
//
// #[test]
// #[serial]
// fn test_download_and_extract() {
// setup();
// let rt = Runtime::new().unwrap();
// rt.block_on(async {
// Cifar10Dataset::download_and_extract().await;
// let workspace_dir = get_workspace_dir();
// let cache_path = format!("{}/.cache/dataset/cifar10/cifar-10-binary", workspace_dir.display());
// assert!(Path::new(&cache_path).exists(), "CIFAR-10 dataset should be downloaded and extracted");
// });
// }
//
// #[test]
// #[serial]
// fn test_parse_file() {
// // Ensure the dataset is downloaded before parsing
// test_download_and_extract();
//
// let (images, labels) = Cifar10Dataset::parse_file("path/to/data_batch_1.bin", 10000);
// assert_eq!(images.len(), 10000 * 32 * 32 * 3, "Images should have the correct length");
// assert_eq!(labels.len(), 10000 * 10, "Labels should have the correct length");
// }
//
// #[test]
// #[serial]
// fn test_load_data() {
// // Ensure the dataset is downloaded before loading data
// test_download_and_extract();
//
// let dataset = Cifar10Dataset::load_data(&["data_batch_1.bin"], 10000);
// assert_eq!(dataset.inputs.shape(), &[10000, 32, 32, 3], "Dataset inputs should have the correct shape");
// assert_eq!(dataset.labels.shape(), &[10000, 10], "Dataset labels should have the correct shape");
// }
//
// #[test]
// #[serial]
// fn test_load_train() {
// let rt = Runtime::new().unwrap();
// rt.block_on(async {
// let dataset = Cifar10Dataset::load_train().await;
// assert!(dataset.train.is_some(), "Training dataset should be loaded");
// });
// }
//
// #[test]
// #[serial]
// fn test_load_test() {
// let rt = Runtime::new().unwrap();
// rt.block_on(async {
// let dataset = Cifar10Dataset::load_test().await;
// assert!(dataset.test.is_some(), "Test dataset should be loaded");
// });
// }
// }
#[cfg(test)]
mod tests {
use super::*;
use ndarray::Dimension;
use serial_test::serial;

fn setup() {
let workspace_dir = get_workspace_dir();
let cache_path = format!("{}/.cache/dataset/cifar10", workspace_dir.display());
if Path::new(&cache_path).exists() {
fs::remove_dir_all(&cache_path).expect("Failed to delete cache directory");
}
}

#[tokio::test]
#[serial]
async fn test_download_and_extract() {
setup();
Cifar10Dataset::download_and_extract().await;
let workspace_dir = get_workspace_dir();
let cache_path = format!(
"{}/.cache/dataset/cifar10/data_batch_1.bin",
workspace_dir.display()
);
assert!(
Path::new(&cache_path).exists(),
"CIFAR-10 dataset should be downloaded and extracted"
);
}

#[test]
#[serial]
fn test_parse_file() {
// Ensure the dataset is downloaded before parsing
test_download_and_extract();
let workspace_dir = get_workspace_dir();
let cache_path = format!(
"{}/.cache/dataset/cifar10/data_batch_1.bin",
workspace_dir.display()
);

let (images, labels) = Cifar10Dataset::parse_file(&cache_path, 10000);
assert_eq!(
images.len(),
10000 * 32 * 32 * 3,
"Images should have the correct length"
);
assert_eq!(
labels.len(),
10000 * 10,
"Labels should have the correct length"
);
}

#[test]
#[serial]
fn test_load_data() {
// Ensure the dataset is downloaded before loading data
test_download_and_extract();

let dataset = Cifar10Dataset::load_data(&["data_batch_1.bin"], 10000);

// Compare the shape of inputs
assert_eq!(
dataset.inputs.shape().raw_dim().as_array_view().to_vec(),
&[10000, 32, 32, 3],
"Dataset inputs should have the correct shape"
);

// Compare the shape of labels
assert_eq!(
dataset.labels.shape().raw_dim().as_array_view().to_vec(),
&[10000, 10],
"Dataset labels should have the correct shape"
);
}

#[tokio::test]
#[serial]
async fn test_load_train() {
let dataset = Cifar10Dataset::load_train().await;
assert!(dataset.train.is_some(), "Training dataset should be loaded");
}

#[tokio::test]
#[serial]
async fn test_load_test() {
let dataset = Cifar10Dataset::load_test().await;
assert!(dataset.test.is_some(), "Test dataset should be loaded");
}
}
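
The re-enabled tests also drop the manual runtime setup used by the old commented-out versions (`Runtime::new()` plus `block_on`) in favor of `#[tokio::test]`, which builds the runtime around the async test body; both forms rely on tokio's `macros` and `rt` features, already covered by the workspace's `features = ["full"]`. A minimal sketch of the two styles, equivalent in effect for the `load_train` case and assuming the same `use super::*;` context as the module above (illustration only, not part of the commit):

```rust
// Old style, as in the removed commented-out tests: build the runtime by hand.
#[test]
fn load_train_with_manual_runtime() {
    let rt = tokio::runtime::Runtime::new().unwrap();
    rt.block_on(async {
        let dataset = Cifar10Dataset::load_train().await;
        assert!(dataset.train.is_some(), "Training dataset should be loaded");
    });
}

// New style, as in this commit: the attribute macro sets up the runtime.
#[tokio::test]
async fn load_train_with_attribute() {
    let dataset = Cifar10Dataset::load_train().await;
    assert!(dataset.train.is_some(), "Training dataset should be loaded");
}
```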
