diff --git a/opensoundscape/ml/dataloaders.py b/opensoundscape/ml/dataloaders.py
index 32420000..6ae41402 100644
--- a/opensoundscape/ml/dataloaders.py
+++ b/opensoundscape/ml/dataloaders.py
@@ -63,6 +63,18 @@ def __init__(
             "(c) (file,start_time,end_time) as MultiIndex"
         )
 
+        # validate that file paths are correctly placed in the input index or list
+        if len(samples) > 0:
+            if isinstance(samples, pd.DataFrame):  # samples is a pd.DataFrame
+                if isinstance(samples.index, pd.core.indexes.multi.MultiIndex):
+                    # index is (file, start_time, end_time)
+                    first_path = samples.index.values[0][0]
+                else:  # index of df is just file path
+                    first_path = samples.index.values[0]
+            else:  # samples is a list of file path
+                first_path = samples[0]
+            _check_is_path(first_path)
+
         # set up prediction Dataset, considering three possible cases:
         # (c1) user provided multi-index df with file,start_time,end_time of clips
         # (c2) user provided file list and wants clips to be split out automatically
@@ -72,15 +84,11 @@ def __init__(
             and type(samples.index) == pd.core.indexes.multi.MultiIndex
         ):  # c1 user provided multi-index df with file,start_time,end_time of clips
             # raise AssertionError if first item of multi-index is not str or Path
-            if len(samples) > 0:
-                _check_is_path(samples.index.values[0][0])
             dataset = AudioFileDataset(samples=samples, preprocessor=preprocessor)
         elif (
             split_files_into_clips
         ):  # c2 user provided file list; split each file into appropriate length clips
             # raise AssertionError if first item is not str or Path
-            if len(samples) > 0:
-                _check_is_path(samples[0])
             dataset = AudioSplittingDataset(
                 samples=samples,
                 preprocessor=preprocessor,
@@ -90,10 +98,6 @@ def __init__(
         else:  # c3 samples is list of files and
             # split_files_into_clips=False -> one sample & one prediction per file provided
             # eg, each file is a 5 second clips and the model expects 5 second clips
-
-            # raise AssertionError if first item is not str or Path
-            if len(samples) > 0:
-                _check_is_path(samples[0])
             dataset = AudioFileDataset(samples=samples, preprocessor=preprocessor)
 
         dataset.bypass_augmentations = bypass_augmentations
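
For context, a minimal standalone sketch of the input handling that the hoisted check covers. This is illustrative only: the helper name `extract_first_path` is hypothetical (the diff inlines the equivalent logic and then calls the existing `_check_is_path` helper), and `pd.MultiIndex` is used here as the public alias of `pd.core.indexes.multi.MultiIndex`.

```python
# Hypothetical sketch mirroring the added validation block; not part of the diff.
from pathlib import Path
import pandas as pd


def extract_first_path(samples):
    """Pull the first file path out of `samples`, whatever its shape."""
    if isinstance(samples, pd.DataFrame):
        if isinstance(samples.index, pd.MultiIndex):
            # (file, start_time, end_time) MultiIndex: file is the first level
            return samples.index.values[0][0]
        # single-level index: each index value is a file path
        return samples.index.values[0]
    # otherwise treat samples as a list/array of file paths
    return samples[0]


# the three accepted input shapes
clip_df = pd.DataFrame(
    index=pd.MultiIndex.from_tuples(
        [("a.wav", 0.0, 5.0)], names=["file", "start_time", "end_time"]
    )
)
file_df = pd.DataFrame(index=["a.wav"])
file_list = [Path("a.wav")]

for samples in (clip_df, file_df, file_list):
    print(extract_first_path(samples))  # a.wav in each case
```

Before this change, the same `_check_is_path` guard was repeated inside each of the three dataset-construction branches; the diff hoists it above the branch so every input shape is validated once, up front, regardless of which dataset class is built.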