Skip to content

Commit

Permalink
fix input validation for SafeAudioDataloader
Browse files Browse the repository at this point in the history
  • Loading branch information
sammlapp committed Jan 24, 2024
1 parent 212e7d4 commit b89d4aa
Showing 1 changed file with 12 additions and 8 deletions.
20 changes: 12 additions & 8 deletions opensoundscape/ml/dataloaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,18 @@ def __init__(
"(c) (file,start_time,end_time) as MultiIndex"
)

# validate that file paths are correctly placed in the input index or list
if len(samples) > 0:
if isinstance(samples, pd.DataFrame): # samples is a pd.DataFrame
if isinstance(samples.index, pd.core.indexes.multi.MultiIndex):
# index is (file, start_time, end_time)
first_path = samples.index.values[0][0]
else: # index of df is just file path
first_path = samples.index.values[0]
else: # samples is a list of file path
first_path = samples[0]
_check_is_path(first_path)

# set up prediction Dataset, considering three possible cases:
# (c1) user provided multi-index df with file,start_time,end_time of clips
# (c2) user provided file list and wants clips to be split out automatically
Expand All @@ -72,15 +84,11 @@ def __init__(
and type(samples.index) == pd.core.indexes.multi.MultiIndex
): # c1 user provided multi-index df with file,start_time,end_time of clips
# raise AssertionError if first item of multi-index is not str or Path
if len(samples) > 0:
_check_is_path(samples.index.values[0][0])
dataset = AudioFileDataset(samples=samples, preprocessor=preprocessor)
elif (
split_files_into_clips
): # c2 user provided file list; split each file into appropriate length clips
# raise AssertionError if first item is not str or Path
if len(samples) > 0:
_check_is_path(samples[0])
dataset = AudioSplittingDataset(
samples=samples,
preprocessor=preprocessor,
Expand All @@ -90,10 +98,6 @@ def __init__(
else: # c3 samples is list of files and
# split_files_into_clips=False -> one sample & one prediction per file provided
# eg, each file is a 5 second clips and the model expects 5 second clips

# raise AssertionError if first item is not str or Path
if len(samples) > 0:
_check_is_path(samples[0])
dataset = AudioFileDataset(samples=samples, preprocessor=preprocessor)

dataset.bypass_augmentations = bypass_augmentations
Expand Down

0 comments on commit b89d4aa

Please sign in to comment.