diff --git a/opensoundscape/ml/cnn.py b/opensoundscape/ml/cnn.py index 7a25c01c..6780b8cf 100644 --- a/opensoundscape/ml/cnn.py +++ b/opensoundscape/ml/cnn.py @@ -27,15 +27,13 @@ from opensoundscape.preprocess import io from opensoundscape.ml.datasets import AudioFileDataset from opensoundscape.ml.cnn_architectures import inception_v3 -from opensoundscape.ml.loss import ResampleLoss -from opensoundscape.sample import collate_audio_samples, collate_audio_samples_to_dict +from opensoundscape.sample import collate_audio_samples from opensoundscape.utils import identity from opensoundscape.logging import wandb_table from opensoundscape.ml.cam import CAM import pytorch_grad_cam from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget -from lightning.pytorch.callbacks import ModelCheckpoint from torchmetrics.classification import ( @@ -1408,6 +1406,7 @@ def train( """ ### Input Validation ### + # Validation of class list check_labels(train_df, self.classes) if validation_df is not None: check_labels(validation_df, self.classes) @@ -1419,6 +1418,8 @@ def train( "evaluated using the performance on the training set." ) + ## Initialization ## + # Initialize attributes self.log_interval = log_interval self.save_interval = save_interval @@ -2184,9 +2185,7 @@ class BaseClassifier(SpectrogramClassifier): """ -def use_resample_loss( - model, train_df -): # TODO revisit how this work. Should be able to set loss_cls=ResampleLoss() +def use_resample_loss(model, train_df): """Modify a model to use ResampleLoss for multi-target training ResampleLoss may perform better than BCE Loss for multitarget problems diff --git a/opensoundscape/ml/dataloaders.py b/opensoundscape/ml/dataloaders.py index abf318cd..4ed3e164 100644 --- a/opensoundscape/ml/dataloaders.py +++ b/opensoundscape/ml/dataloaders.py @@ -2,8 +2,9 @@ import numpy as np import pandas as pd import warnings +from pathlib import Path -from opensoundscape.utils import identity +from opensoundscape.utils import identity, _check_is_path from opensoundscape.ml.safe_dataset import SafeDataset from opensoundscape.ml.datasets import AudioFileDataset, AudioSplittingDataset from opensoundscape.annotations import CategoricalLabels @@ -97,6 +98,18 @@ def __init__( ), "Cannot specify both overlap_fraction and clip_overlap_fraction" clip_overlap_fraction = overlap_fraction + # validate that file paths are correctly placed in the input index or list + if len(samples) > 0: + if isinstance(samples, pd.DataFrame): # samples is a pd.DataFrame + if isinstance(samples.index, pd.core.indexes.multi.MultiIndex): + # index is (file, start_time, end_time) + first_path = samples.index.values[0][0] + else: # index of df is just file path + first_path = samples.index.values[0] + else: # samples is a list of file path + first_path = samples[0] + _check_is_path(first_path) + # set up prediction Dataset, considering three possible cases: # (c1) user provided multi-index df with file,start_time,end_time of clips # (c2) user provided file list and wants clips to be split out automatically @@ -105,8 +118,12 @@ def __init__( type(samples) == pd.DataFrame and type(samples.index) == pd.core.indexes.multi.MultiIndex ): # c1 user provided multi-index df with file,start_time,end_time of clips + # raise AssertionError if first item of multi-index is not str or Path dataset = AudioFileDataset(samples=samples, preprocessor=preprocessor) - elif split_files_into_clips: # c2 user provided file list; split into + elif ( + split_files_into_clips + ): # c2 user provided file list; split each file into appropriate length clips + # raise AssertionError if first item is not str or Path dataset = AudioSplittingDataset( samples=samples, preprocessor=preprocessor, @@ -115,8 +132,11 @@ def __init__( clip_step=clip_step, final_clip=final_clip, ) - else: # c3 split_files_into_clips=False -> one sample & one prediction per file provided + else: # c3 samples is list of files and + # split_files_into_clips=False -> one sample & one prediction per file provided + # eg, each file is a 5 second clips and the model expects 5 second clips dataset = AudioFileDataset(samples=samples, preprocessor=preprocessor) + dataset.bypass_augmentations = bypass_augmentations if len(dataset) < 1: diff --git a/opensoundscape/preprocess/preprocessors.py b/opensoundscape/preprocess/preprocessors.py index fcfc657b..e477331c 100644 --- a/opensoundscape/preprocess/preprocessors.py +++ b/opensoundscape/preprocess/preprocessors.py @@ -169,6 +169,7 @@ def forward( sample (instance of AudioSample class) """ + # validate input if break_on_key is not None: assert ( break_on_key in self.pipeline @@ -256,6 +257,13 @@ def _generate_sample(self, sample): sample = AudioSample(path, start_time=start, duration=self.sample_duration) elif isinstance(sample, pd.Series): sample = AudioSample.from_series(sample) + elif isinstance(sample, pd.DataFrame): + raise AssertionError( + "sample must be AudioSample, tuple of (path, start_time), " + "or pd.Series with (path, start_time, end_time) as .name. " + f"was {type(sample)}. " + "Perhaps a dataset was accessed like dataset[[0,1,2]] instead of dataset[0]?" + ) else: assert isinstance(sample, AudioSample), ( "sample must be AudioSample, tuple of (path, start_time), " diff --git a/opensoundscape/utils.py b/opensoundscape/utils.py index eb1105e2..cd603b7a 100644 --- a/opensoundscape/utils.py +++ b/opensoundscape/utils.py @@ -6,8 +6,6 @@ from pathlib import Path import numpy as np import pandas as pd -import pytz -import soundfile import librosa from matplotlib.colors import LinearSegmentedColormap import torch @@ -390,3 +388,9 @@ def set_seed(seed, verbose=False): random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) + + +def _check_is_path(path): + assert isinstance(path, str) or isinstance( + path, Path + ), f"Expected str or Path, got {type(path)}. Did you set the index correctly?" diff --git a/tests/test_cnn.py b/tests/test_cnn.py index 6722b11d..8c5cb857 100644 --- a/tests/test_cnn.py +++ b/tests/test_cnn.py @@ -472,6 +472,27 @@ def test_train_on_clip_df(train_df): ) +def test_train_bad_index(train_df): + """ + AssertionError catches case where index is not one of the allowed formats + """ + model = cnn.CNN("resnet18", [0, 1], sample_duration=2) + # reset the index so that train_df index is integers (not an allowed format) + train_df = make_clip_df(train_df.index.values, clip_duration=2).reset_index() + train_df[0] = np.random.choice([0, 1], size=10) + train_df[1] = np.random.choice([0, 1], size=10) + with pytest.raises(AssertionError): + model.train( + train_df, + train_df, + save_path="tests/models/", + epochs=1, + batch_size=2, + save_interval=10, + num_workers=0, + ) + + def test_predict_without_splitting(test_df): model = cnn.CNN(architecture="resnet18", classes=[0, 1], sample_duration=5.0) scores = model.predict(test_df, split_files_into_clips=False) diff --git a/tests/test_dataloaders.py b/tests/test_dataloaders.py new file mode 100644 index 00000000..62b750d0 --- /dev/null +++ b/tests/test_dataloaders.py @@ -0,0 +1,79 @@ +import pytest +import numpy as np +import pandas as pd +from opensoundscape.preprocess.preprocessors import SpectrogramPreprocessor +from opensoundscape.ml.dataloaders import SafeAudioDataloader + + +@pytest.fixture() +def dataset_df(): + paths = ["tests/audio/silence_10s.mp3", "tests/audio/silence_10s.mp3"] + labels = [[1, 0], [0, 1]] + return pd.DataFrame(index=paths, data=labels, columns=[0, 1]) + + +@pytest.fixture() +def bad_dataset_df(): + labels = [[1, 0], [0, 1]] + return pd.DataFrame(index=range(len(labels)), data=labels, columns=[0, 1]) + + +@pytest.fixture() +def dataset_df_multiindex(): + paths = ["tests/audio/silence_10s.mp3", "tests/audio/silence_10s.mp3"] + start_times = [0, 1] + end_times = [1, 2] + return pd.DataFrame( + { + "file": paths, + "start_time": start_times, + "end_time": end_times, + "A": [0, 1], + "B": [1, 0], + } + ).set_index(["file", "start_time", "end_time"]) + + +@pytest.fixture() +def bad_dataset_df_multiindex(): + paths = ["tests/audio/silence_10s.mp3", "tests/audio/silence_10s.mp3"] + start_times = [0, 1] + end_times = [1, 2] + return pd.DataFrame( + { + "file": paths, + "start_time": start_times, + "end_time": end_times, + "A": [0, 1], + "B": [1, 0], + } + ) # .set_index(["file", "start_time", "end_time"]) + + +@pytest.fixture() +def bad_dataset_df(): + labels = [[1, 0], [0, 1]] + return pd.DataFrame(index=range(len(labels)), data=labels, columns=[0, 1]) + + +@pytest.fixture() +def pre(): + return SpectrogramPreprocessor(sample_duration=1) + + +def test_helpful_error_if_index_is_integer(bad_dataset_df, pre): + with pytest.raises(AssertionError): + SafeAudioDataloader(bad_dataset_df, pre) + + +def test_init(dataset_df, pre): + SafeAudioDataloader(dataset_df, pre) + + +def test_init_multiindex(dataset_df, pre): + SafeAudioDataloader(dataset_df, pre) + + +def test_catch_index_not_set(bad_dataset_df_multiindex, pre): + with pytest.raises(AssertionError): + SafeAudioDataloader(bad_dataset_df_multiindex, pre) diff --git a/tests/test_preprocessors.py b/tests/test_preprocessors.py index 2ad3439d..afdfe4da 100644 --- a/tests/test_preprocessors.py +++ b/tests/test_preprocessors.py @@ -186,6 +186,22 @@ def test_trace_output(preprocessor, sample): assert isinstance(sample.trace["load_audio"], Audio) +def test_catch_input_to_forward_is_dataframe(preprocessor): + # raises AssertionError if samples arg in forward is not pd.Series + # which might occur if user accessed a dataset with a list instead of a single index + # see issue 803: https://github.com/kitzeslab/opensoundscape/issues/803 + with pytest.raises(AssertionError): + preprocessor.forward(pd.DataFrame({0: ["a"]})) + + +def test_catch_input_to_forward_is_dataframe(preprocessor): + # raises AssertionError if samples arg in forward is not pd.Series + # which might occur if user accessed a dataset with a list instead of a single index + # see issue 803: https://github.com/kitzeslab/opensoundscape/issues/803 + with pytest.raises(AssertionError): + preprocessor.forward(pd.DataFrame({0: ["a"]})) + + def test_audiopreprocessor(audiopreprocessor, sample): """should retain original sample rate""" s = audiopreprocessor.forward(sample).data