Skip to content

Commit

Permalink
Merge pull request #303 from HazyResearch/feature/audio
Browse files Browse the repository at this point in the history
Add audio demo
  • Loading branch information
seyuboglu authored Mar 12, 2023
2 parents f6fe43f + 812a3fb commit fa0e7dc
Show file tree
Hide file tree
Showing 19 changed files with 896 additions and 6 deletions.
430 changes: 430 additions & 0 deletions demo/audio-embed.ipynb

Large diffs are not rendered by default.

128 changes: 128 additions & 0 deletions meerkat/cells/audio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
from meerkat import env
from meerkat.tools.lazy_loader import LazyLoader
from meerkat.tools.utils import requires

torch = LazyLoader("torch")
torchaudio = LazyLoader("torchaudio")


class Audio:
def __init__(
self,
data,
sampling_rate: int,
bits: int = None,
) -> None:
if not env.is_package_installed("torch"):
raise ValueError(
f"{type(self)} requires torch. Follow these instructions "
"to install torch: https://pytorch.org/get-started/locally/."
)
self.data = torch.as_tensor(data)
self.sampling_rate = sampling_rate
self.bits = bits

def duration(self) -> float:
"""Return the duration of the audio in seconds."""
return len(self.data) / self.sampling_rate

@requires("torchaudio")
def resample(self, sampling_rate: int) -> "Audio":
"""Resample the audio with a new sampling rate.
Args:
sampling_rate: The new sampling rate.
Returns:
The resampled audio.
"""
if not env.is_package_installed("torchaudio"):
raise ValueError(
"resample requires torchaudio. Install with `pip install torchaudio`."
)

return Audio(
torchaudio.functional.resample(
self.data, self.sampling_rate, sampling_rate
),
sampling_rate,
)

def normalize(
self, lower: float = 0.0, upper: float = 1.0, eps: float = 1e-6
) -> "Audio":
"""Normalize the audio to a given range.
Args:
lower: The lower bound of the range.
upper: The upper bound of the range.
eps: The epsilon to used to avoid division by zero.
Returns:
The normalized audio.
"""
_min = torch.amin(self.data)
_max = torch.amax(self.data)
data = lower + (upper - lower) * (self.data - _min) / (_max - _min + eps)
return Audio(data=data, sampling_rate=self.sampling_rate)

def quantize(self, bits: int, epsilon: float = 1e-2) -> "Audio":
"""Linearly quantize a signal to a given number of bits.
The signal must be in the range [0, 1].
Args:
bits: The number of bits to quantize to.
epsilon: The epsilon to use for clipping the signal.
Returns:
The quantized audio.
"""
if self.bits is not None:
raise ValueError(
"Audio is already quantized. Use `.dequantize` to dequantize "
"the signal and then requantize."
)

if torch.any(self.data < 0) or torch.any(self.data > 1):
raise ValueError("Audio must be in the range [0, 1] to quantize.")

q_levels = 1 << bits
samples = (q_levels - epsilon) * self.data
samples += epsilon / 2
return Audio(samples.long(), sampling_rate=self.sampling_rate, bits=self.bits)

def dequantize(self) -> "Audio":
"""Dequantize a signal.
Returns:
The dequantized audio.
"""
if self.bits is None:
raise ValueError("Audio is not quantized.")

q_levels = 1 << self.bits
return Audio(
self.data.float() / (q_levels / 2) - 1, sampling_rate=self.sampling_rate
)

def __repr__(self) -> str:
return f"Audio({self.duration()} seconds @ {self.sampling_rate}Hz)"

def __eq__(self, other: "Audio") -> bool:
return (
self.data.shape == other.data.shape
and self.sampling_rate == other.sampling_rate
and torch.allclose(self.data, other.data)
)

def __getitem__(self, key: int) -> "Audio":
return Audio(self.data[key], self.sampling_rate)

def __len__(self) -> int:
return len(self.data)

def _repr_html_(self) -> str:
import IPython.display as ipd

return ipd.Audio(self.data, rate=self.sampling_rate)._repr_html_()
2 changes: 2 additions & 0 deletions meerkat/columns/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,8 @@ def _repr_pandas_(self, max_rows: int = None) -> pd.Series:
else:
col = pd.Series([self._repr_cell(idx) for idx in range(len(self))])

# TODO: if the objects have a _repr_html_ method, we should be able to use
# that instead of explicitly relying on the column having a formatter.
return (
col,
self.formatters["base"]
Expand Down
15 changes: 15 additions & 0 deletions meerkat/columns/deferred/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import meerkat.tools.docs as docs
from meerkat.block.deferred_block import DeferredOp
from meerkat.cells.audio import Audio
from meerkat.columns.abstract import Column
from meerkat.columns.deferred.base import DeferredCell, DeferredColumn
from meerkat.columns.scalar import ScalarColumn
Expand All @@ -27,6 +28,7 @@
PDFFormatterGroup,
TextFormatterGroup,
)
from meerkat.interactive.formatter.audio import DeferredAudioFormatterGroup
from meerkat.interactive.formatter.base import FormatterGroup
from meerkat.interactive.formatter.image import DeferredImageFormatterGroup

Expand Down Expand Up @@ -485,6 +487,14 @@ def load_text(path: Union[str, io.BytesIO]):
return f.read()


def load_audio(path: str) -> Audio:
import torchaudio

data, sampling_rate = torchaudio.load(path)
data = data.squeeze()
return Audio(data, sampling_rate=sampling_rate)


FILE_TYPES = {
"image": {
"loader": load_image,
Expand Down Expand Up @@ -512,6 +522,11 @@ def load_text(path: Union[str, io.BytesIO]):
"formatters": CodeFormatterGroup,
"exts": [".py", ".js", ".css", ".json", ".java", ".cpp", ".c", ".h", ".hpp"],
},
"audio": {
"loader": load_audio,
"formatters": DeferredAudioFormatterGroup,
"exts": [".wav", ".mp3"],
},
}


Expand Down
2 changes: 1 addition & 1 deletion meerkat/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -683,7 +683,7 @@ def from_huggingface(cls, *args, **kwargs):
if isinstance(dataset, dict):
return dict(
map(
lambda t: (t[0], cls.from_arrow(t[1]._data)),
lambda t: (t[0], cls.from_arrow(t[1]._data.table)),
dataset.items(),
)
)
Expand Down
2 changes: 2 additions & 0 deletions meerkat/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from .pascal import pascal
from .registry import datasets
from .rfw import rfw
from .torchaudio import yesno

__all__ = [
"celeba",
Expand All @@ -29,6 +30,7 @@
"rfw",
"ngoa",
"coco",
"yesno",
]

DOWNLOAD_MODES = ["force", "extract", "reuse", "skip"]
Expand Down
59 changes: 59 additions & 0 deletions meerkat/datasets/torchaudio/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,71 @@
import os

import pandas as pd

import meerkat as mk
from meerkat.tools.lazy_loader import LazyLoader
from meerkat.tools.utils import deprecated

from ..abstract import DatasetBuilder
from ..info import DatasetInfo
from ..registry import datasets

torch = LazyLoader("torch")
torchaudio = LazyLoader("torchaudio")


@datasets.register()
class yesno(DatasetBuilder):
"""YESNO dataset.
Reference:
https://www.openslr.org/1/
"""

info = DatasetInfo(
name="yesno",
full_name="YesNo",
description=(
"This dataset contains 60 .wav files, sampled at 8 kHz. "
"All were recorded by the same male speaker, in Hebrew. "
"In each file, the individual says 8 words; each word is either the "
"Hebrew for 'yes' or 'no', so each file is a random sequence of 8 yes-es "
"or noes. There is no separate transcription provided; the sequence is "
"encoded in the filename, with 1 for yes and 0 for no."
),
homepage="https://www.openslr.org/1/",
tags=["audio", "classification"],
)

VERSIONS = ["release1"]

def download(self):
os.makedirs(self.dataset_dir, exist_ok=True)
torchaudio.datasets.YESNO(root=self.dataset_dir, download=True)

def is_downloaded(self) -> bool:
return super().is_downloaded() and os.path.exists(
os.path.join(self.dataset_dir, "waves_yesno")
)

def build(self):
dataset = torchaudio.datasets.YESNO(root=self.dataset_dir, download=False)
df = mk.DataFrame(
{
"id": dataset._walker,
"audio": mk.files(
pd.Series(dataset._walker) + ".wav", base_dir=dataset._path
),
"labels": torch.tensor(
[[int(c) for c in fileid.split("_")] for fileid in dataset._walker]
),
}
)

return df


@deprecated("mk.get('yesno')")
def get_yesno(dataset_dir: str, download: bool = True):
"""Load YESNO as a Meerkat DataFrame.
Expand Down
19 changes: 16 additions & 3 deletions meerkat/interactive/app/package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions meerkat/interactive/app/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@
"d3-selection": "^3.0.0",
"flowbite": "^1.5.2",
"flowbite-svelte": "^0.29.11",
"howler": "^2.2.3",
"layercake": "^7.0.0",
"marked": "^4.2.12",
"monaco-editor": "^0.34.1",
Expand Down
40 changes: 40 additions & 0 deletions meerkat/interactive/app/src/lib/component/core/audio/Audio.svelte
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
<script lang="ts">
import { PlayFill, Pause } from 'svelte-bootstrap-icons';
import { setCurrentMedia, currentMedia } from '$lib/shared/media/current';
export let data: string;
export let classes: string = '';
const startPlay = () => {
setCurrentMedia(data);
if ($currentMedia) {
$currentMedia.paused = false;
$currentMedia.currentTime = 0;
}
};
// FIXME: This should be based of the primary key of the cell.
$: playing = $currentMedia && $currentMedia.data === data;
</script>

<div
class={'h-full w-10 bg-slate-100 rounded-sm flex items-center justify-center self-center border' +
classes}
class:border-violet-600={playing}
>
{#if playing}
{#if $currentMedia.paused}
<button on:click={() => ($currentMedia.paused = false)}>
<PlayFill class="text-slate-600" />
</button>
{:else if $currentMedia.ended}
<button on:click={startPlay}> <PlayFill class="text-slate-600" /> </button>
{:else}
<button on:click={() => ($currentMedia.paused = true)}>
<Pause class="text-slate-600" />
</button>
{/if}
{:else}
<button on:click={startPlay}> <PlayFill class="text-slate-600" /> </button>
{/if}
</div>
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from meerkat.interactive.app.src.lib.component.abstract import Component


class Audio(Component):
data: str
classes: str = ""
Loading

0 comments on commit fa0e7dc

Please sign in to comment.