Fix index checking #1054

Merged · 9 commits · Sep 12, 2024
11 changes: 5 additions & 6 deletions opensoundscape/ml/cnn.py
@@ -27,15 +27,13 @@
from opensoundscape.preprocess import io
from opensoundscape.ml.datasets import AudioFileDataset
from opensoundscape.ml.cnn_architectures import inception_v3
from opensoundscape.ml.loss import ResampleLoss
from opensoundscape.sample import collate_audio_samples, collate_audio_samples_to_dict
from opensoundscape.sample import collate_audio_samples
from opensoundscape.utils import identity
from opensoundscape.logging import wandb_table

from opensoundscape.ml.cam import CAM
import pytorch_grad_cam
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from lightning.pytorch.callbacks import ModelCheckpoint


from torchmetrics.classification import (
@@ -1408,6 +1406,7 @@ def train(
"""

### Input Validation ###
# Validation of class list
check_labels(train_df, self.classes)
if validation_df is not None:
check_labels(validation_df, self.classes)
@@ -1419,6 +1418,8 @@
"evaluated using the performance on the training set."
)

## Initialization ##

# Initialize attributes
self.log_interval = log_interval
self.save_interval = save_interval
@@ -2184,9 +2185,7 @@ class BaseClassifier(SpectrogramClassifier):
"""


def use_resample_loss(
model, train_df
): # TODO revisit how this work. Should be able to set loss_cls=ResampleLoss()
def use_resample_loss(model, train_df):
"""Modify a model to use ResampleLoss for multi-target training

ResampleLoss may perform better than BCE Loss for multitarget problems
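For context, a minimal usage sketch of the simplified helper, assuming `CNN` and `use_resample_loss` are imported from `opensoundscape.ml.cnn` as defined in this file; the audio paths and class names below are placeholders, not part of the PR:

```python
import pandas as pd
from opensoundscape.ml.cnn import CNN, use_resample_loss

# hypothetical multi-target clip dataframe: (file, start_time, end_time) index,
# one 0/1 column per class; the audio paths are placeholders
train_df = pd.DataFrame(
    {
        "file": ["a.wav", "a.wav", "b.wav"],
        "start_time": [0.0, 2.0, 0.0],
        "end_time": [2.0, 4.0, 2.0],
        "A": [1, 0, 1],
        "B": [0, 1, 1],
    }
).set_index(["file", "start_time", "end_time"])

model = CNN("resnet18", classes=["A", "B"], sample_duration=2.0)
use_resample_loss(model, train_df)  # swap the default BCE loss for ResampleLoss
```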
26 changes: 23 additions & 3 deletions opensoundscape/ml/dataloaders.py
@@ -2,8 +2,9 @@
import numpy as np
import pandas as pd
import warnings
from pathlib import Path

from opensoundscape.utils import identity
from opensoundscape.utils import identity, _check_is_path
from opensoundscape.ml.safe_dataset import SafeDataset
from opensoundscape.ml.datasets import AudioFileDataset, AudioSplittingDataset
from opensoundscape.annotations import CategoricalLabels
@@ -97,6 +98,18 @@ def __init__(
), "Cannot specify both overlap_fraction and clip_overlap_fraction"
clip_overlap_fraction = overlap_fraction

# validate that file paths are correctly placed in the input index or list
if len(samples) > 0:
if isinstance(samples, pd.DataFrame): # samples is a pd.DataFrame
if isinstance(samples.index, pd.core.indexes.multi.MultiIndex):
# index is (file, start_time, end_time)
first_path = samples.index.values[0][0]
else: # index of df is just file path
first_path = samples.index.values[0]
else: # samples is a list of file paths
first_path = samples[0]
_check_is_path(first_path)

# set up prediction Dataset, considering three possible cases:
# (c1) user provided multi-index df with file,start_time,end_time of clips
# (c2) user provided file list and wants clips to be split out automatically
@@ -105,8 +118,12 @@
type(samples) == pd.DataFrame
and type(samples.index) == pd.core.indexes.multi.MultiIndex
): # c1 user provided multi-index df with file,start_time,end_time of clips
# raise AssertionError if first item of multi-index is not str or Path
dataset = AudioFileDataset(samples=samples, preprocessor=preprocessor)
elif split_files_into_clips: # c2 user provided file list; split into
elif (
split_files_into_clips
): # c2 user provided file list; split each file into appropriate length clips
# raise AssertionError if first item is not str or Path
dataset = AudioSplittingDataset(
samples=samples,
preprocessor=preprocessor,
@@ -115,8 +132,11 @@
clip_step=clip_step,
final_clip=final_clip,
)
else: # c3 split_files_into_clips=False -> one sample & one prediction per file provided
else: # c3 samples is list of files and
# split_files_into_clips=False -> one sample & one prediction per file provided
# e.g., each file is a 5-second clip and the model expects 5-second clips
dataset = AudioFileDataset(samples=samples, preprocessor=preprocessor)

dataset.bypass_augmentations = bypass_augmentations

if len(dataset) < 1:
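To illustrate the three cases and the new early check, a sketch of inputs the dataloader accepts versus the integer index it now rejects; it reuses the test fixture file path and assumes the two-argument constructor shown in the new tests below:

```python
import pandas as pd
from opensoundscape.preprocess.preprocessors import SpectrogramPreprocessor
from opensoundscape.ml.dataloaders import SafeAudioDataloader

pre = SpectrogramPreprocessor(sample_duration=1)
files = ["tests/audio/silence_10s.mp3"]

# c2: plain list of file paths; each file is split into clips automatically
SafeAudioDataloader(files, pre)

# c1: (file, start_time, end_time) multi-index dataframe of pre-defined clips
clip_df = pd.DataFrame(
    {"file": files, "start_time": [0.0], "end_time": [1.0]}
).set_index(["file", "start_time", "end_time"])
SafeAudioDataloader(clip_df, pre)

# rejected early: an integer index, e.g. after .reset_index()
# SafeAudioDataloader(clip_df.reset_index(), pre)  # AssertionError from _check_is_path
```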
8 changes: 8 additions & 0 deletions opensoundscape/preprocess/preprocessors.py
@@ -169,6 +169,7 @@ def forward(
sample (instance of AudioSample class)

"""
# validate input
if break_on_key is not None:
assert (
break_on_key in self.pipeline
@@ -256,6 +257,13 @@ def _generate_sample(self, sample):
sample = AudioSample(path, start_time=start, duration=self.sample_duration)
elif isinstance(sample, pd.Series):
sample = AudioSample.from_series(sample)
elif isinstance(sample, pd.DataFrame):
raise AssertionError(
"sample must be AudioSample, tuple of (path, start_time), "
"or pd.Series with (path, start_time, end_time) as .name. "
f"was {type(sample)}. "
"Perhaps a dataset was accessed like dataset[[0,1,2]] instead of dataset[0]?"
)
else:
assert isinstance(sample, AudioSample), (
"sample must be AudioSample, tuple of (path, start_time), "
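Roughly, the new branch turns the mistake described in the error message into an immediate, readable failure; a sketch mirroring the new preprocessor test, assuming a `SpectrogramPreprocessor` as used in the test fixtures:

```python
import pandas as pd
from opensoundscape.preprocess.preprocessors import SpectrogramPreprocessor

pre = SpectrogramPreprocessor(sample_duration=1)

# passing a DataFrame (e.g. the result of dataset[[0, 1, 2]]-style indexing)
# now raises a pointed AssertionError instead of an obscure downstream error
try:
    pre.forward(pd.DataFrame({0: ["a"]}))
except AssertionError as err:
    print(err)
```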
8 changes: 6 additions & 2 deletions opensoundscape/utils.py
@@ -6,8 +6,6 @@
from pathlib import Path
import numpy as np
import pandas as pd
import pytz
import soundfile
import librosa
from matplotlib.colors import LinearSegmentedColormap
import torch
@@ -390,3 +388,9 @@ def set_seed(seed, verbose=False):
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)


def _check_is_path(path):
assert isinstance(path, str) or isinstance(
path, Path
), f"Expected str or Path, got {type(path)}. Did you set the index correctly?"
21 changes: 21 additions & 0 deletions tests/test_cnn.py
@@ -472,6 +472,27 @@ def test_train_on_clip_df(train_df):
)


def test_train_bad_index(train_df):
"""
AssertionError catches case where index is not one of the allowed formats
"""
model = cnn.CNN("resnet18", [0, 1], sample_duration=2)
# reset the index so that train_df index is integers (not an allowed format)
train_df = make_clip_df(train_df.index.values, clip_duration=2).reset_index()
train_df[0] = np.random.choice([0, 1], size=10)
train_df[1] = np.random.choice([0, 1], size=10)
with pytest.raises(AssertionError):
model.train(
train_df,
train_df,
save_path="tests/models/",
epochs=1,
batch_size=2,
save_interval=10,
num_workers=0,
)


def test_predict_without_splitting(test_df):
model = cnn.CNN(architecture="resnet18", classes=[0, 1], sample_duration=5.0)
scores = model.predict(test_df, split_files_into_clips=False)
79 changes: 79 additions & 0 deletions tests/test_dataloaders.py
@@ -0,0 +1,79 @@
import pytest
import numpy as np
import pandas as pd
from opensoundscape.preprocess.preprocessors import SpectrogramPreprocessor
from opensoundscape.ml.dataloaders import SafeAudioDataloader


@pytest.fixture()
def dataset_df():
paths = ["tests/audio/silence_10s.mp3", "tests/audio/silence_10s.mp3"]
labels = [[1, 0], [0, 1]]
return pd.DataFrame(index=paths, data=labels, columns=[0, 1])


@pytest.fixture()
def bad_dataset_df():
labels = [[1, 0], [0, 1]]
return pd.DataFrame(index=range(len(labels)), data=labels, columns=[0, 1])


@pytest.fixture()
def dataset_df_multiindex():
paths = ["tests/audio/silence_10s.mp3", "tests/audio/silence_10s.mp3"]
start_times = [0, 1]
end_times = [1, 2]
return pd.DataFrame(
{
"file": paths,
"start_time": start_times,
"end_time": end_times,
"A": [0, 1],
"B": [1, 0],
}
).set_index(["file", "start_time", "end_time"])


@pytest.fixture()
def bad_dataset_df_multiindex():
paths = ["tests/audio/silence_10s.mp3", "tests/audio/silence_10s.mp3"]
start_times = [0, 1]
end_times = [1, 2]
return pd.DataFrame(
{
"file": paths,
"start_time": start_times,
"end_time": end_times,
"A": [0, 1],
"B": [1, 0],
}
) # .set_index(["file", "start_time", "end_time"])


@pytest.fixture()
def bad_dataset_df():
labels = [[1, 0], [0, 1]]
return pd.DataFrame(index=range(len(labels)), data=labels, columns=[0, 1])


@pytest.fixture()
def pre():
return SpectrogramPreprocessor(sample_duration=1)


def test_helpful_error_if_index_is_integer(bad_dataset_df, pre):
with pytest.raises(AssertionError):
SafeAudioDataloader(bad_dataset_df, pre)


def test_init(dataset_df, pre):
SafeAudioDataloader(dataset_df, pre)


def test_init_multiindex(dataset_df_multiindex, pre):
SafeAudioDataloader(dataset_df_multiindex, pre)


def test_catch_index_not_set(bad_dataset_df_multiindex, pre):
with pytest.raises(AssertionError):
SafeAudioDataloader(bad_dataset_df_multiindex, pre)
16 changes: 16 additions & 0 deletions tests/test_preprocessors.py
@@ -186,6 +186,22 @@ def test_trace_output(preprocessor, sample):
assert isinstance(sample.trace["load_audio"], Audio)


def test_catch_input_to_forward_is_dataframe(preprocessor):
# raises AssertionError if samples arg in forward is not pd.Series
# which might occur if user accessed a dataset with a list instead of a single index
# see issue 803: https://github.com/kitzeslab/opensoundscape/issues/803
with pytest.raises(AssertionError):
preprocessor.forward(pd.DataFrame({0: ["a"]}))


def test_catch_input_to_forward_is_dataframe(preprocessor):
# raises AssertionError if samples arg in forward is not pd.Series
# which might occur if user accessed a dataset with a list instead of a single index
# see issue 803: https://github.com/kitzeslab/opensoundscape/issues/803
with pytest.raises(AssertionError):
preprocessor.forward(pd.DataFrame({0: ["a"]}))


def test_audiopreprocessor(audiopreprocessor, sample):
"""should retain original sample rate"""
s = audiopreprocessor.forward(sample).data