Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

199 task audio io #202

Merged
merged 5 commits into from
Nov 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 62 additions & 1 deletion src/senselab/audio/data_structures/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import os
import uuid
import warnings
from typing import Dict, Generator, List, Tuple, Union
from typing import Dict, Generator, List, Optional, Tuple, Union

import numpy as np
import torch
Expand Down Expand Up @@ -151,6 +151,67 @@ def window_generator(self, window_size: int, step_size: int) -> Generator["Audio
yield window_audio
current_position += step_size

def save_to_file(
    self,
    file_path: Union[str, os.PathLike],
    format: Optional[str] = None,
    encoding: Optional[str] = None,
    bits_per_sample: Optional[int] = None,
    buffer_size: int = 4096,
    backend: Optional[str] = None,
    compression: Optional[Union[float, int]] = None,
) -> None:
    """Save the `Audio` object to a file using `torchaudio.save`.

    Args:
        file_path (Union[str, os.PathLike]): The path to save the audio file.
        format (Optional[str]): Audio format to use. Valid values include "wav", "ogg", and "flac".
            If None, the format is inferred from the file extension.
        encoding (Optional[str]): Encoding to use. Valid options include "PCM_S", "PCM_U", "PCM_F", "ULAW", "ALAW".
            This is effective for formats like "wav" and "flac".
        bits_per_sample (Optional[int]): Bit depth for the audio file. Valid values are 8, 16, 24, 32, and 64.
        buffer_size (int): Size of the buffer in bytes for processing file-like objects. Default is 4096.
        backend (Optional[str]): I/O backend to use. Valid options include "ffmpeg", "sox", and "soundfile".
            If None, a backend is automatically selected.
        compression (Optional[Union[float, int]]): Compression level for supported formats like "mp3",
            "flac", and "ogg".
            Refer to `torchaudio.save` documentation for specific compression levels.

    Raises:
        ValueError: If the `Audio` waveform is not 2D, or if the sampling rate is invalid.
        RuntimeError: If the output directory cannot be created or is not writable,
            or if there is an error saving the audio file.

    Note:
        - https://pytorch.org/audio/master/generated/torchaudio.save.html
    """
    if self.waveform.ndim != 2:
        raise ValueError("Waveform must be a 2D tensor with shape (num_channels, num_samples).")

    if self.sampling_rate <= 0:
        raise ValueError("Sampling rate must be a positive integer.")

    # An empty dirname means "save into the current directory"; normalize it so the
    # writability check below does not spuriously fail for bare filenames.
    output_dir = os.path.dirname(file_path) or "."

    # Create any missing parent directories *before* checking writability, so a
    # missing directory is created instead of being reported as "not writable".
    try:
        os.makedirs(output_dir, exist_ok=True)
    except OSError as e:
        raise RuntimeError(f"Output directory '{output_dir}' is not writable.") from e

    if not os.access(output_dir, os.W_OK):
        raise RuntimeError(f"Output directory '{output_dir}' is not writable.")

    try:
        torchaudio.save(
            uri=file_path,
            src=self.waveform,
            sample_rate=self.sampling_rate,
            channels_first=True,
            format=format,
            encoding=encoding,
            bits_per_sample=bits_per_sample,
            buffer_size=buffer_size,
            backend=backend,
            compression=compression,
        )
    except Exception as e:
        raise RuntimeError(f"Error saving audio to file: {e}") from e


def batch_audios(audios: List[Audio]) -> Tuple[torch.Tensor, Union[int, List[int]], List[Dict]]:
"""Batches the Audios together into a single Tensor, keeping individual Audio information separate.
Expand Down
3 changes: 3 additions & 0 deletions src/senselab/audio/tasks/input_output/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
"""This module provides some utilities for audio input and output."""

from .utils import read_audios, save_audios # noqa: F401
119 changes: 119 additions & 0 deletions src/senselab/audio/tasks/input_output/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
"""This module provides utilities for reading and saving (multiple) audio files using Pydra."""

import os
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import pydra

from senselab.audio.data_structures import Audio


def read_audios(
    file_paths: List[Union[str, os.PathLike]],
    cache_dir: Optional[Union[str, os.PathLike]] = None,
    plugin: str = "serial",
    plugin_args: Optional[Dict[str, Any]] = None,
) -> List[Audio]:
    """Read and process a list of audio files using Pydra workflow.

    Args:
        file_paths (List[Union[str, os.PathLike]]): List of paths to audio files.
        cache_dir (str, optional): Directory for caching intermediate results. Defaults to None.
        plugin (str, optional): Pydra plugin to use for workflow submission. Defaults to "serial".
        plugin_args (dict, optional): Additional arguments for the Pydra plugin. Defaults to None,
            which is treated as an empty dict.

    Returns:
        List[Audio]: A list of Audio objects containing the waveform and sample rate for each processed file.
    """
    # Avoid a mutable default argument: a `{}` default would be shared across calls.
    plugin_args = {} if plugin_args is None else plugin_args

    @pydra.mark.task
    def load_audio_file(file_path: Union[str, os.PathLike]) -> Any:  # noqa: ANN401
        """Load an audio file and return an Audio object.

        Args:
            file_path (str): Path to the audio file.

        Returns:
            Audio: An instance of the Audio class containing the waveform and sample rate.
        """
        return Audio.from_filepath(file_path)

    # Create the workflow: split over the input paths so each file is loaded as
    # an independent task.
    wf = pydra.Workflow(name="read_audio_files_workflow", input_spec=["x"], cache_dir=cache_dir)
    wf.split("x", x=file_paths)
    wf.add(load_audio_file(name="load_audio_file", file_path=wf.lzin.x))
    wf.set_output([("processed_files", wf.load_audio_file.lzout.out)])

    # Run the workflow
    with pydra.Submitter(plugin=plugin, **plugin_args) as sub:
        sub(wf)

    # Collect and return the results, one Audio per input path.
    outputs = wf.result()
    return [output.output.processed_files for output in outputs]


def save_audios(
    audio_tuples: Sequence[Tuple[Audio, Union[str, os.PathLike]]],
    save_params: Optional[Dict[str, Any]] = None,
    cache_dir: Optional[Union[str, os.PathLike]] = None,
    plugin: str = "serial",
    plugin_args: Optional[Dict[str, Any]] = None,
) -> None:
    """Save a list of Audio objects to specified files using Pydra workflow.

    Args:
        audio_tuples (Sequence[Tuple[Audio, Union[str, os.PathLike]]]): Sequence of tuples where each tuple contains
            an Audio object, its output path (str or os.PathLike).
        save_params (dict, optional): Additional parameters for saving audio files.
            Defaults to None, which is treated as an empty dict.
        cache_dir (str, optional): Directory for caching intermediate results.
            Defaults to None.
        plugin (str, optional): Pydra plugin to use for workflow submission.
            Defaults to "serial".
        plugin_args (dict, optional): Additional arguments for the Pydra plugin.
            Defaults to None, which is treated as an empty dict.

    Raises:
        RuntimeError: If any output directory in the provided paths does not exist or is not writable.
    """
    # Avoid mutable default arguments: `{}` defaults would be shared across calls.
    save_params = {} if save_params is None else save_params
    plugin_args = {} if plugin_args is None else plugin_args

    @pydra.mark.task
    def _extract_audio(audio_tuple: Tuple[Audio, Union[str, os.PathLike]]) -> Any:  # noqa: ANN401
        """Extract the Audio object from the tuple."""
        return audio_tuple[0]

    @pydra.mark.task
    def _extract_output_path(audio_tuple: Tuple[Audio, Union[str, os.PathLike]]) -> Union[str, os.PathLike]:
        """Extract the output path from the tuple."""
        return audio_tuple[1]

    @pydra.mark.task
    def _save_audio(audio: Audio, file_path: str, save_params: Dict[str, Any]) -> None:
        """Save an Audio object to a file.

        Args:
            audio (Audio): The Audio object to save.
            file_path (str): Path to save the audio file.
            save_params (dict): Additional parameters for saving audio files.
        """
        audio.save_to_file(file_path, **save_params)

    # Create the workflow: split over the (audio, path) tuples so each save is
    # an independent task.
    wf = pydra.Workflow(name="save_audio_files_workflow", input_spec=["audio_tuples"], cache_dir=cache_dir)
    wf.split("audio_tuples", audio_tuples=audio_tuples)
    wf.add(_extract_audio(name="_extract_audio", audio_tuple=wf.lzin.audio_tuples))
    wf.add(_extract_output_path(name="_extract_output_path", audio_tuple=wf.lzin.audio_tuples))
    wf.add(
        _save_audio(
            name="_save_audio",
            audio=wf._extract_audio.lzout.out,
            file_path=wf._extract_output_path.lzout.out,
            save_params=save_params,
        )
    )
    # Saving is a side effect; the workflow produces no outputs.
    wf.set_output([])

    # Run the workflow
    with pydra.Submitter(plugin=plugin, **plugin_args) as sub:
        sub(wf)
27 changes: 27 additions & 0 deletions src/tests/audio/data_structures/audio_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Module for testing Audio data structures."""

import tempfile
from pathlib import Path
from typing import List, Tuple

import pytest
Expand Down Expand Up @@ -34,6 +36,31 @@ def test_audio_creation(audio_fixture: str, audio_path: str, request: pytest.Fix
assert audio == audio_sample, "Audios are not exactly equivalent"


@pytest.mark.parametrize(
    "audio_fixture",
    ["mono_audio_sample", "stereo_audio_sample"],
)
def test_audio_save_to_file(audio_fixture: str, request: pytest.FixtureRequest) -> None:
    """Tests saving audio to file."""
    # Resolve the parametrized fixture into an Audio sample.
    sample = request.getfixturevalue(audio_fixture)

    # Save into a throwaway directory so the test leaves nothing behind.
    with tempfile.TemporaryDirectory() as tmp_dir:
        out_path = Path(tmp_dir) / "test_audio.wav"

        sample.save_to_file(file_path=out_path, format="wav", bits_per_sample=16)

        # The file must exist on disk after saving.
        assert out_path.exists(), "The audio file was not saved."

        # Round-trip: reload the file and compare against the in-memory sample.
        reloaded_waveform, reloaded_rate = torchaudio.load(out_path)
        assert torch.allclose(sample.waveform, reloaded_waveform, atol=1e-5), "Waveform data does not match."
        assert sample.sampling_rate == reloaded_rate, "Sampling rate does not match."


@pytest.mark.parametrize(
"audio_fixture, audio_path",
[
Expand Down
68 changes: 68 additions & 0 deletions src/tests/audio/tasks/input_output_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
"""Module for testing Audio data structures."""

import os
import tempfile
from typing import List

import numpy as np
import pytest
import torch

from senselab.audio.data_structures import Audio
from senselab.audio.tasks.input_output import read_audios, save_audios
from tests.audio.conftest import MONO_AUDIO_PATH, STEREO_AUDIO_PATH


@pytest.mark.parametrize(
    "audio_paths",
    [
        ([MONO_AUDIO_PATH]),
        ([STEREO_AUDIO_PATH]),
        ([MONO_AUDIO_PATH, STEREO_AUDIO_PATH]),  # Test multiple files
    ],
)
def test_read_audios(audio_paths: List[str | os.PathLike]) -> None:
    """Tests the read_audios function with actual mono and stereo audio files."""
    # Run the function under test against real audio files.
    loaded = read_audios(file_paths=audio_paths, plugin="serial")

    # One output Audio per input path.
    assert len(loaded) == len(audio_paths), "Incorrect number of processed files."

    # Compare each result against a direct load of the same file.
    for path, candidate in zip(audio_paths, loaded):
        reference = Audio.from_filepath(path)

        assert torch.equal(
            candidate.waveform, reference.waveform
        ), f"Waveform for file {path} does not match the expected."
        assert (
            candidate.sampling_rate == reference.sampling_rate
        ), f"Sampling rate for file {path} does not match the expected."


def test_save_audios() -> None:
    """Test the `save_audios` function."""
    # Work inside a temporary directory so the saved files are cleaned up.
    with tempfile.TemporaryDirectory() as out_dir:
        # Two small in-memory Audio samples with the required fields.
        first = Audio(
            waveform=np.array([0.1, 0.2, 0.3]),  # Replace with actual waveform data if needed
            sampling_rate=44100,
        )
        second = Audio(waveform=np.array([0.4, 0.5, 0.6]), sampling_rate=44100)

        # Pair each Audio with its destination path.
        targets = [
            (first, os.path.join(out_dir, "audio1.wav")),
            (second, os.path.join(out_dir, "audio2.wav")),
        ]

        save_audios(audio_tuples=targets)

        # Every destination file must exist and be non-empty.
        for _, file_path in targets:
            assert os.path.exists(file_path), f"File {file_path} was not created."
            assert os.path.getsize(file_path) > 0, f"File {file_path} is empty."
Loading
Loading