From 626307addc06bfb9f56e843bd0a7ae9e41de65f6 Mon Sep 17 00:00:00 2001 From: Tom Pollard Date: Tue, 9 Jul 2024 15:58:43 -0400 Subject: [PATCH] Add test for csv_to_wfdb(). --- tests/io/test_convert.py | 71 ++++++++++++++++++++++++++++++++++++++-- 1 file changed, 68 insertions(+), 3 deletions(-) diff --git a/tests/io/test_convert.py b/tests/io/test_convert.py index aa7ba78a..cac9a7a2 100644 --- a/tests/io/test_convert.py +++ b/tests/io/test_convert.py @@ -1,14 +1,21 @@ +import os +import shutil +import unittest + import numpy as np from wfdb.io.record import rdrecord from wfdb.io.convert.edf import read_edf +from wfdb.io.convert.csv import csv_to_wfdb -class TestConvert: +class TestEdfToWfdb: + """ + Tests for the io.convert.edf module. + """ def test_edf_uniform(self): """ EDF format conversion to MIT for uniform sample rates. - """ # Uniform sample rates record_MIT = rdrecord("sample-data/n16").__dict__ @@ -60,7 +67,6 @@ def test_edf_uniform(self): def test_edf_non_uniform(self): """ EDF format conversion to MIT for non-uniform sample rates. - """ # Non-uniform sample rates record_MIT = rdrecord("sample-data/wave_4").__dict__ @@ -108,3 +114,62 @@ def test_edf_non_uniform(self): target_results = len(fields) * [True] assert np.array_equal(test_results, target_results) + + +class TestCsvToWfdb(unittest.TestCase): + """ + Tests for the io.convert.csv module. + """ + def setUp(self): + """ + Create a temporary directory containing data for testing. + + Load 100.dat file for comparison to 100.csv file. + """ + self.test_dir = 'test_output' + os.makedirs(self.test_dir, exist_ok=True) + + self.record_100_csv = 'sample-data/100.csv' + self.record_100_dat = rdrecord('sample-data/100', physical=True) + + def tearDown(self): + """ + Remove the temporary directory after the test. + """ + if os.path.exists(self.test_dir): + shutil.rmtree(self.test_dir) + + def test_write_dir(self): + """ + Call the function with the write_dir argument. + """ + csv_to_wfdb( + file_name=self.record_100_csv, + fs=360, + units='mV', + write_dir=self.test_dir + ) + + # Check if the output files are created in the specified directory + base_name = os.path.splitext(os.path.basename(self.record_100_csv))[0] + expected_dat_file = os.path.join(self.test_dir, f'{base_name}.dat') + expected_hea_file = os.path.join(self.test_dir, f'{base_name}.hea') + + self.assertTrue(os.path.exists(expected_dat_file)) + self.assertTrue(os.path.exists(expected_hea_file)) + + # Check that newly written file matches the 100.dat file + record_write = rdrecord(os.path.join(self.test_dir, base_name)) + + self.assertEqual(record_write.fs, 360) + self.assertEqual(record_write.fs, self.record_100_dat.fs) + self.assertEqual(record_write.units, ['mV', 'mV']) + self.assertEqual(record_write.units, self.record_100_dat.units) + self.assertEqual(record_write.sig_name, ['MLII', 'V5']) + self.assertEqual(record_write.sig_name, self.record_100_dat.sig_name) + self.assertEqual(record_write.p_signal.size, 1300000) + self.assertEqual(record_write.p_signal.size, self.record_100_dat.p_signal.size) + + +if __name__ == '__main__': + unittest.main()