Skip to content

Commit

Permalink
Merge pull request #1054 from kitzeslab/issue_942_wrong_index
Browse files Browse the repository at this point in the history
Fix index checking
  • Loading branch information
sammlapp authored Sep 12, 2024
2 parents 270dcd4 + 19b9645 commit b4aae3e
Show file tree
Hide file tree
Showing 7 changed files with 158 additions and 11 deletions.
11 changes: 5 additions & 6 deletions opensoundscape/ml/cnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,13 @@
from opensoundscape.preprocess import io
from opensoundscape.ml.datasets import AudioFileDataset
from opensoundscape.ml.cnn_architectures import inception_v3
from opensoundscape.ml.loss import ResampleLoss
from opensoundscape.sample import collate_audio_samples, collate_audio_samples_to_dict
from opensoundscape.sample import collate_audio_samples
from opensoundscape.utils import identity
from opensoundscape.logging import wandb_table

from opensoundscape.ml.cam import CAM
import pytorch_grad_cam
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from lightning.pytorch.callbacks import ModelCheckpoint


from torchmetrics.classification import (
Expand Down Expand Up @@ -1450,6 +1448,7 @@ def train(
"""

### Input Validation ###
# Validation of class list
check_labels(train_df, self.classes)
if validation_df is not None:
check_labels(validation_df, self.classes)
Expand All @@ -1461,6 +1460,8 @@ def train(
"evaluated using the performance on the training set."
)

## Initialization ##

# Initialize attributes
self.log_interval = log_interval
self.save_interval = save_interval
Expand Down Expand Up @@ -2226,9 +2227,7 @@ class BaseClassifier(SpectrogramClassifier):
"""


def use_resample_loss(
model, train_df
): # TODO revisit how this work. Should be able to set loss_cls=ResampleLoss()
def use_resample_loss(model, train_df):
"""Modify a model to use ResampleLoss for multi-target training
ResampleLoss may perform better than BCE Loss for multitarget problems
Expand Down
26 changes: 23 additions & 3 deletions opensoundscape/ml/dataloaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
import numpy as np
import pandas as pd
import warnings
from pathlib import Path

from opensoundscape.utils import identity
from opensoundscape.utils import identity, _check_is_path
from opensoundscape.ml.safe_dataset import SafeDataset
from opensoundscape.ml.datasets import AudioFileDataset, AudioSplittingDataset
from opensoundscape.annotations import CategoricalLabels
Expand Down Expand Up @@ -97,6 +98,18 @@ def __init__(
), "Cannot specify both overlap_fraction and clip_overlap_fraction"
clip_overlap_fraction = overlap_fraction

# validate that file paths are correctly placed in the input index or list
if len(samples) > 0:
if isinstance(samples, pd.DataFrame): # samples is a pd.DataFrame
if isinstance(samples.index, pd.core.indexes.multi.MultiIndex):
# index is (file, start_time, end_time)
first_path = samples.index.values[0][0]
else: # index of df is just file path
first_path = samples.index.values[0]
else: # samples is a list of file path
first_path = samples[0]
_check_is_path(first_path)

# set up prediction Dataset, considering three possible cases:
# (c1) user provided multi-index df with file,start_time,end_time of clips
# (c2) user provided file list and wants clips to be split out automatically
Expand All @@ -105,8 +118,12 @@ def __init__(
type(samples) == pd.DataFrame
and type(samples.index) == pd.core.indexes.multi.MultiIndex
): # c1 user provided multi-index df with file,start_time,end_time of clips
# raise AssertionError if first item of multi-index is not str or Path
dataset = AudioFileDataset(samples=samples, preprocessor=preprocessor)
elif split_files_into_clips: # c2 user provided file list; split into
elif (
split_files_into_clips
): # c2 user provided file list; split each file into appropriate length clips
# raise AssertionError if first item is not str or Path
dataset = AudioSplittingDataset(
samples=samples,
preprocessor=preprocessor,
Expand All @@ -115,8 +132,11 @@ def __init__(
clip_step=clip_step,
final_clip=final_clip,
)
else: # c3 split_files_into_clips=False -> one sample & one prediction per file provided
else: # c3 samples is list of files and
# split_files_into_clips=False -> one sample & one prediction per file provided
# e.g., each file is a 5 second clip and the model expects 5 second clips
dataset = AudioFileDataset(samples=samples, preprocessor=preprocessor)

dataset.bypass_augmentations = bypass_augmentations

if len(dataset) < 1:
Expand Down
8 changes: 8 additions & 0 deletions opensoundscape/preprocess/preprocessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ def forward(
sample (instance of AudioSample class)
"""
# validate input
if break_on_key is not None:
assert (
break_on_key in self.pipeline
Expand Down Expand Up @@ -256,6 +257,13 @@ def _generate_sample(self, sample):
sample = AudioSample(path, start_time=start, duration=self.sample_duration)
elif isinstance(sample, pd.Series):
sample = AudioSample.from_series(sample)
elif isinstance(sample, pd.DataFrame):
raise AssertionError(
"sample must be AudioSample, tuple of (path, start_time), "
"or pd.Series with (path, start_time, end_time) as .name. "
f"was {type(sample)}. "
"Perhaps a dataset was accessed like dataset[[0,1,2]] instead of dataset[0]?"
)
else:
assert isinstance(sample, AudioSample), (
"sample must be AudioSample, tuple of (path, start_time), "
Expand Down
8 changes: 6 additions & 2 deletions opensoundscape/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
from pathlib import Path
import numpy as np
import pandas as pd
import pytz
import soundfile
import librosa
from matplotlib.colors import LinearSegmentedColormap
import torch
Expand Down Expand Up @@ -390,3 +388,9 @@ def set_seed(seed, verbose=False):
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)


def _check_is_path(path):
assert isinstance(path, str) or isinstance(
path, Path
), f"Expected str or Path, got {type(path)}. Did you set the index correctly?"
21 changes: 21 additions & 0 deletions tests/test_cnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,27 @@ def test_train_on_clip_df(train_df):
)


def test_train_bad_index(train_df):
    """
    AssertionError catches case where index is not one of the allowed formats
    """
    model = cnn.CNN("resnet18", [0, 1], sample_duration=2)
    # reset the index so that train_df index is integers (not an allowed format)
    train_df = make_clip_df(train_df.index.values, clip_duration=2).reset_index()
    # size the random labels from the dataframe itself rather than a hard-coded
    # count, so a change in the number of clips cannot fail the column
    # assignment before the index check under test is reached
    n_clips = len(train_df)
    train_df[0] = np.random.choice([0, 1], size=n_clips)
    train_df[1] = np.random.choice([0, 1], size=n_clips)
    with pytest.raises(AssertionError):
        model.train(
            train_df,
            train_df,
            save_path="tests/models/",
            epochs=1,
            batch_size=2,
            save_interval=10,
            num_workers=0,
        )


def test_predict_without_splitting(test_df):
model = cnn.CNN(architecture="resnet18", classes=[0, 1], sample_duration=5.0)
scores = model.predict(test_df, split_files_into_clips=False)
Expand Down
79 changes: 79 additions & 0 deletions tests/test_dataloaders.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import pytest
import numpy as np
import pandas as pd
from opensoundscape.preprocess.preprocessors import SpectrogramPreprocessor
from opensoundscape.ml.dataloaders import SafeAudioDataloader


@pytest.fixture()
def dataset_df():
    """Two-class label dataframe whose index holds audio file paths."""
    audio_file = "tests/audio/silence_10s.mp3"
    return pd.DataFrame(
        data=[[1, 0], [0, 1]],
        index=[audio_file, audio_file],
        columns=[0, 1],
    )


@pytest.fixture()
def bad_dataset_df():
    """Two-class label dataframe with an integer index instead of file paths."""
    rows = [[1, 0], [0, 1]]
    n_rows = len(rows)
    return pd.DataFrame(data=rows, index=range(n_rows), columns=[0, 1])


@pytest.fixture()
def dataset_df_multiindex():
    """Label dataframe indexed by (file, start_time, end_time)."""
    audio_file = "tests/audio/silence_10s.mp3"
    clips = pd.DataFrame(
        {
            "file": [audio_file, audio_file],
            "start_time": [0, 1],
            "end_time": [1, 2],
            "A": [0, 1],
            "B": [1, 0],
        }
    )
    return clips.set_index(["file", "start_time", "end_time"])


@pytest.fixture()
def bad_dataset_df_multiindex():
    """Clip dataframe where file, start_time, end_time are left as columns.

    The index is deliberately not set to (file, start_time, end_time), so the
    dataframe keeps its default integer index.
    """
    audio_file = "tests/audio/silence_10s.mp3"
    columns = {
        "file": [audio_file, audio_file],
        "start_time": [0, 1],
        "end_time": [1, 2],
        "A": [0, 1],
        "B": [1, 0],
    }
    return pd.DataFrame(columns)


# NOTE: an exact second definition of the `bad_dataset_df` fixture used to sit
# here; it only shadowed the identical fixture defined earlier in this module,
# so it was removed.
@pytest.fixture()
def pre():
    """SpectrogramPreprocessor with sample_duration=1, shared by the dataloader tests."""
    return SpectrogramPreprocessor(sample_duration=1)


def test_helpful_error_if_index_is_integer(bad_dataset_df, pre):
    """Creating the dataloader from an integer-indexed dataframe raises AssertionError."""
    expected_failure = pytest.raises(AssertionError)
    with expected_failure:
        SafeAudioDataloader(bad_dataset_df, pre)


def test_init(dataset_df, pre):
    """Creating the dataloader from a file-path-indexed dataframe does not raise."""
    SafeAudioDataloader(dataset_df, pre)


def test_init_multiindex(dataset_df_multiindex, pre):
    """Creating the dataloader from a (file, start_time, end_time)-indexed dataframe does not raise."""
    # use the multi-index fixture; requesting `dataset_df` here would only
    # repeat test_init and leave the multi-index case untested
    SafeAudioDataloader(dataset_df_multiindex, pre)


def test_catch_index_not_set(bad_dataset_df_multiindex, pre):
    """AssertionError is raised when file/start_time/end_time were never set as the index."""
    expected_failure = pytest.raises(AssertionError)
    with expected_failure:
        SafeAudioDataloader(bad_dataset_df_multiindex, pre)
16 changes: 16 additions & 0 deletions tests/test_preprocessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,22 @@ def test_trace_output(preprocessor, sample):
assert isinstance(sample.trace["load_audio"], Audio)


def test_catch_input_to_forward_is_dataframe(preprocessor):
    """forward() raises AssertionError when given a pd.DataFrame.

    This might occur if a user accessed a dataset with a list instead of a
    single index.
    see issue 803: https://github.com/kitzeslab/opensoundscape/issues/803
    """
    # this test was previously defined twice with identical bodies; the second
    # definition merely replaced the first, so only one copy is kept
    with pytest.raises(AssertionError):
        preprocessor.forward(pd.DataFrame({0: ["a"]}))


def test_audiopreprocessor(audiopreprocessor, sample):
"""should retain original sample rate"""
s = audiopreprocessor.forward(sample).data
Expand Down

0 comments on commit b4aae3e

Please sign in to comment.