Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
seyuboglu committed Mar 12, 2023
2 parents 75470ff + fa0e7dc commit 49c679b
Show file tree
Hide file tree
Showing 29 changed files with 958 additions and 21 deletions.
430 changes: 430 additions & 0 deletions demo/audio-embed.ipynb

Large diffs are not rendered by default.

128 changes: 128 additions & 0 deletions meerkat/cells/audio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
from meerkat import env
from meerkat.tools.lazy_loader import LazyLoader
from meerkat.tools.utils import requires

torch = LazyLoader("torch")
torchaudio = LazyLoader("torchaudio")


class Audio:
    """A single audio clip: a tensor of samples plus its sampling rate.

    Attributes are plain instance fields:

    * ``data``: the samples, converted with ``torch.as_tensor``.
    * ``sampling_rate``: the number of samples per second.
    * ``bits``: the bit depth of a quantized signal, or ``None`` when the
      signal is not quantized.
    """

    def __init__(
        self,
        data,
        sampling_rate: int,
        bits: int = None,
    ) -> None:
        """Wrap ``data`` as an audio clip.

        Args:
            data: Anything accepted by ``torch.as_tensor``.
            sampling_rate: The sampling rate of ``data``.
            bits: The bit depth if ``data`` holds quantized samples,
                otherwise ``None``.

        Raises:
            ValueError: If torch is not installed.
        """
        if not env.is_package_installed("torch"):
            raise ValueError(
                f"{type(self)} requires torch. Follow these instructions "
                "to install torch: https://pytorch.org/get-started/locally/."
            )
        self.data = torch.as_tensor(data)
        self.sampling_rate = sampling_rate
        self.bits = bits

    def duration(self) -> float:
        """Return the duration of the audio in seconds.

        Computed as ``len(self.data) / self.sampling_rate``, i.e. the size of
        the first dimension is taken as the number of samples.
        """
        return len(self.data) / self.sampling_rate

    @requires("torchaudio")
    def resample(self, sampling_rate: int) -> "Audio":
        """Resample the audio with a new sampling rate.

        Args:
            sampling_rate: The new sampling rate.

        Returns:
            The resampled audio.

        Raises:
            ValueError: If torchaudio is not installed.
        """
        if not env.is_package_installed("torchaudio"):
            raise ValueError(
                "resample requires torchaudio. Install with `pip install torchaudio`."
            )

        return Audio(
            torchaudio.functional.resample(
                self.data, self.sampling_rate, sampling_rate
            ),
            sampling_rate,
        )

    def normalize(
        self, lower: float = 0.0, upper: float = 1.0, eps: float = 1e-6
    ) -> "Audio":
        """Normalize the audio to a given range.

        Args:
            lower: The lower bound of the range.
            upper: The upper bound of the range.
            eps: The epsilon used to avoid division by zero.

        Returns:
            The normalized audio.
        """
        _min = torch.amin(self.data)
        _max = torch.amax(self.data)
        data = lower + (upper - lower) * (self.data - _min) / (_max - _min + eps)
        return Audio(data=data, sampling_rate=self.sampling_rate)

    def quantize(self, bits: int, epsilon: float = 1e-2) -> "Audio":
        """Linearly quantize a signal to a given number of bits.

        The signal must be in the range [0, 1].

        Args:
            bits: The number of bits to quantize to.
            epsilon: The epsilon to use for clipping the signal.

        Returns:
            The quantized audio, with ``bits`` recorded on the result.

        Raises:
            ValueError: If the audio is already quantized, or if any sample
                lies outside [0, 1].
        """
        if self.bits is not None:
            raise ValueError(
                "Audio is already quantized. Use `.dequantize` to dequantize "
                "the signal and then requantize."
            )

        if torch.any(self.data < 0) or torch.any(self.data > 1):
            raise ValueError("Audio must be in the range [0, 1] to quantize.")

        q_levels = 1 << bits
        samples = (q_levels - epsilon) * self.data + epsilon / 2
        # Record the requested bit depth on the result. `self.bits` is always
        # None here, so forwarding it would leave the output marked as
        # unquantized and make `dequantize` unusable.
        return Audio(samples.long(), sampling_rate=self.sampling_rate, bits=bits)

    def dequantize(self) -> "Audio":
        """Dequantize a signal.

        Returns:
            The dequantized audio, computed as ``data / (2**bits / 2) - 1``.

        Raises:
            ValueError: If the audio is not quantized.
        """
        if self.bits is None:
            raise ValueError("Audio is not quantized.")

        q_levels = 1 << self.bits
        # NOTE(review): this maps codes into [-1, 1), whereas `quantize`
        # requires input in [0, 1]; a round trip therefore needs
        # `.normalize()` before re-quantizing. Confirm this is intended.
        return Audio(
            self.data.float() / (q_levels / 2) - 1, sampling_rate=self.sampling_rate
        )

    def __repr__(self) -> str:
        return f"Audio({self.duration()} seconds @ {self.sampling_rate}Hz)"

    def __eq__(self, other: "Audio") -> bool:
        """Compare shape, sampling rate, and (approximately) sample values."""
        if not isinstance(other, Audio):
            # Let Python fall back to the other operand / identity comparison
            # instead of failing on a missing `.data` attribute.
            return NotImplemented
        return (
            self.data.shape == other.data.shape
            and self.sampling_rate == other.sampling_rate
            and torch.allclose(self.data, other.data)
        )

    def __getitem__(self, key: int) -> "Audio":
        """Index into the samples, keeping the sampling rate and bit depth."""
        # Indexing does not change the sample encoding, so a slice of a
        # quantized clip is still quantized.
        return Audio(self.data[key], self.sampling_rate, bits=self.bits)

    def __len__(self) -> int:
        return len(self.data)

    def _repr_html_(self) -> str:
        """Render the clip as an IPython audio widget (HTML)."""
        import IPython.display as ipd

        return ipd.Audio(self.data, rate=self.sampling_rate)._repr_html_()
6 changes: 6 additions & 0 deletions meerkat/columns/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,8 @@ def _repr_pandas_(self, max_rows: int = None) -> pd.Series:
else:
col = pd.Series([self._repr_cell(idx) for idx in range(len(self))])

# TODO: if the objects have a _repr_html_ method, we should be able to use
# that instead of explicitly relying on the column having a formatter.
return (
col,
self.formatters["base"]
Expand Down Expand Up @@ -649,6 +651,10 @@ def to_numpy(self) -> np.ndarray:
f"Cannot convert column of type {type(self)} to Numpy array."
)

def __array__(self, dtype=None, copy=None) -> np.ndarray:
    """Convert the data to a numpy array (NumPy ``__array__`` protocol).

    Args:
        dtype: Optional target dtype. NumPy passes it positionally when the
            caller requests one, e.g. ``np.asarray(col, dtype=...)``.
        copy: Optional copy flag passed by NumPy 2. When truthy, the result
            is guaranteed not to be the object returned by ``to_numpy``.

    Returns:
        The array produced by ``self.to_numpy()``, converted to ``dtype`` if
        one was requested.
    """
    array = self.to_numpy()
    if dtype is not None and array.dtype != np.dtype(dtype):
        # `astype` always allocates, which also satisfies copy=True.
        return array.astype(dtype)
    if copy:
        return array.copy()
    return array

def to_json(self) -> dict:
"""Convert the column to a JSON object.
Expand Down
15 changes: 15 additions & 0 deletions meerkat/columns/deferred/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import meerkat.tools.docs as docs
from meerkat.block.deferred_block import DeferredOp
from meerkat.cells.audio import Audio
from meerkat.columns.abstract import Column
from meerkat.columns.deferred.base import DeferredCell, DeferredColumn
from meerkat.columns.scalar import ScalarColumn
Expand All @@ -27,6 +28,7 @@
PDFFormatterGroup,
TextFormatterGroup,
)
from meerkat.interactive.formatter.audio import DeferredAudioFormatterGroup
from meerkat.interactive.formatter.base import FormatterGroup
from meerkat.interactive.formatter.image import DeferredImageFormatterGroup

Expand Down Expand Up @@ -485,6 +487,14 @@ def load_text(path: Union[str, io.BytesIO]):
return f.read()


def load_audio(path: str) -> Audio:
    """Load an audio file from ``path`` into an ``Audio`` cell.

    Args:
        path: Location of the audio file, passed to ``torchaudio.load``.

    Returns:
        An ``Audio`` holding the decoded samples and the file's sampling rate.
    """
    import torchaudio

    waveform, sampling_rate = torchaudio.load(path)
    # torchaudio returns the waveform channels-first by default. Drop only a
    # singleton channel axis: a bare `squeeze()` would also collapse a
    # one-frame clip into a 0-dim tensor, which breaks `len()`/`duration()`.
    waveform = waveform.squeeze(0)
    return Audio(waveform, sampling_rate=sampling_rate)


FILE_TYPES = {
"image": {
"loader": load_image,
Expand Down Expand Up @@ -512,6 +522,11 @@ def load_text(path: Union[str, io.BytesIO]):
"formatters": CodeFormatterGroup,
"exts": [".py", ".js", ".css", ".json", ".java", ".cpp", ".c", ".h", ".hpp"],
},
"audio": {
"loader": load_audio,
"formatters": DeferredAudioFormatterGroup,
"exts": [".wav", ".mp3"],
},
}


Expand Down
4 changes: 4 additions & 0 deletions meerkat/columns/object/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Sequence

import cytoolz as tz
import numpy as np
import pandas as pd
from PIL.Image import Image
from yaml.representer import Representer
Expand Down Expand Up @@ -73,3 +74,6 @@ def _get_default_formatters(self):

def to_pandas(self, allow_objects: bool = False) -> pd.Series:
    """Collect every cell of the column into a pandas Series.

    Cells are fetched one at a time by integer position. ``allow_objects``
    is accepted but not consulted by this implementation.
    """
    cells = []
    for position in range(len(self)):
        cells.append(self[int(position)])
    return pd.Series(cells)

def to_numpy(self):
    """Build a NumPy array from the column's underlying ``data``."""
    values = self.data
    return np.array(values)
14 changes: 11 additions & 3 deletions meerkat/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from meerkat.errors import ConversionError
from meerkat.interactive.graph.marking import is_unmarked_context, unmarked
from meerkat.interactive.graph.reactivity import reactive
from meerkat.interactive.graph.store import Store
from meerkat.interactive.modification import DataFrameModification
from meerkat.interactive.node import NodeMixin
from meerkat.mixins.cloneable import CloneableMixin
Expand Down Expand Up @@ -542,8 +543,12 @@ def __setitem__(self, posidx, value):
# but those modifications will be cleared by the endpoint before it is
# run.
if self.has_inode():
columns = self.columns
if isinstance(columns, Store):
# df modifications expects a list, not a Store[list]
columns = columns.value
# Add a modification if it's on the graph
mod = DataFrameModification(id=self.id, scope=self.columns)
mod = DataFrameModification(id=self.id, scope=columns)
mod.add_to_queue()

def __call__(self):
Expand All @@ -560,11 +565,14 @@ def set(self, value: DataFrame):
self._set_state(value._get_state())

if self.has_inode():
columns = self.columns
if isinstance(columns, Store):
columns = columns.value
# Add a modification if it's on the graph
# TODO: think about what the scope should be.
# How does `scope` relate to the skip_fn mechanism
# in Operation?
mod = DataFrameModification(id=self.inode.id, scope=self.columns)
mod = DataFrameModification(id=self.inode.id, scope=columns)
mod.add_to_queue()

def consolidate(self):
Expand Down Expand Up @@ -675,7 +683,7 @@ def from_huggingface(cls, *args, **kwargs):
if isinstance(dataset, dict):
return dict(
map(
lambda t: (t[0], cls.from_arrow(t[1]._data)),
lambda t: (t[0], cls.from_arrow(t[1]._data.table)),
dataset.items(),
)
)
Expand Down
2 changes: 2 additions & 0 deletions meerkat/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from .pascal import pascal
from .registry import datasets
from .rfw import rfw
from .torchaudio import yesno

__all__ = [
"celeba",
Expand All @@ -29,6 +30,7 @@
"rfw",
"ngoa",
"coco",
"yesno",
]

DOWNLOAD_MODES = ["force", "extract", "reuse", "skip"]
Expand Down
9 changes: 4 additions & 5 deletions meerkat/datasets/coco/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,16 +44,15 @@ def build(self):
"rb",
)
)
breakpoint()

df = mk.DataFrame(dct["images"])
df["split"] = [split] * len(df)
df["split"] = split
dfs.append(df)

df = mk.concat(dfs, axis=0)

path = df["split"] + "2014/" + df["file_name"]
df["image"] = mk.ImageColumn.from_filepaths(path, base_dir=self.var_dataset_dir)
df["image"] = mk.files(path, base_dir=self.var_dataset_dir)

df.data.reorder(
["id", "image"] + [c for c in df.columns if c not in ["id", "image"]]
Expand Down Expand Up @@ -130,13 +129,13 @@ def build_coco_2014_df(dataset_dir: str, download: bool = False):
)

df = mk.DataFrame(dct["images"])
df["split"] = [split] * len(df)
df["split"] = split
dfs.append(df)

df = mk.concat(dfs, axis=0)

path = df["split"] + "2014/" + df["file_name"]
df["image"] = mk.ImageColumn.from_filepaths(path, base_dir=dataset_dir)
df["image"] = mk.files(path, base_dir=dataset_dir)

df.data.reorder(
["id", "image"] + [c for c in df.columns if c not in ["id", "image"]]
Expand Down
7 changes: 5 additions & 2 deletions meerkat/datasets/lvis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,13 +85,16 @@ def build(self):
)

image_df = mk.DataFrame(dct["images"])
image_df["split"] = [split] * len(image_df)
image_df["split"] = split
image_dfs.append(image_df)

annot_df = mk.DataFrame(dct["annotations"])
if not self.include_segmentations:
annot_df.remove_column("segmentation")
annot_df["bbox"] = np.array(annot_df["bbox"])
# Creating a numpy array from the raw data is much faster than
# iterating through the column.
# TODO: Consider adding the __array__ protocol for the abstract column.
annot_df["bbox"] = np.array(annot_df["bbox"].data)
annot_dfs.append(annot_df)

cat_dfs.append(mk.DataFrame(dct["categories"]))
Expand Down
59 changes: 59 additions & 0 deletions meerkat/datasets/torchaudio/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,71 @@
import os

import pandas as pd

import meerkat as mk
from meerkat.tools.lazy_loader import LazyLoader
from meerkat.tools.utils import deprecated

from ..abstract import DatasetBuilder
from ..info import DatasetInfo
from ..registry import datasets

torch = LazyLoader("torch")
torchaudio = LazyLoader("torchaudio")


@datasets.register()
class yesno(DatasetBuilder):
    """Builder for the YESNO spoken-word dataset.

    Files are fetched through ``torchaudio.datasets.YESNO`` and exposed as a
    Meerkat DataFrame with ``id``, ``audio`` and ``labels`` columns.

    Reference:
        https://www.openslr.org/1/
    """

    info = DatasetInfo(
        name="yesno",
        full_name="YesNo",
        description=(
            "This dataset contains 60 .wav files, sampled at 8 kHz. "
            "All were recorded by the same male speaker, in Hebrew. "
            "In each file, the individual says 8 words; each word is either the "
            "Hebrew for 'yes' or 'no', so each file is a random sequence of 8 yes-es "
            "or noes. There is no separate transcription provided; the sequence is "
            "encoded in the filename, with 1 for yes and 0 for no."
        ),
        homepage="https://www.openslr.org/1/",
        tags=["audio", "classification"],
    )

    VERSIONS = ["release1"]

    def download(self):
        """Create the dataset directory and let torchaudio download into it."""
        os.makedirs(self.dataset_dir, exist_ok=True)
        torchaudio.datasets.YESNO(root=self.dataset_dir, download=True)

    def is_downloaded(self) -> bool:
        """Check the base condition and that ``waves_yesno`` exists on disk."""
        return super().is_downloaded() and os.path.exists(
            os.path.join(self.dataset_dir, "waves_yesno")
        )

    def build(self):
        """Assemble the DataFrame from the already-downloaded files.

        Each file id is an underscore-separated sequence of digits; it is
        split into a row of integers to form the ``labels`` tensor.
        """
        source = torchaudio.datasets.YESNO(root=self.dataset_dir, download=False)
        file_ids = source._walker

        wav_names = pd.Series(file_ids) + ".wav"
        audio = mk.files(wav_names, base_dir=source._path)

        label_rows = []
        for file_id in file_ids:
            label_rows.append([int(flag) for flag in file_id.split("_")])
        labels = torch.tensor(label_rows)

        return mk.DataFrame({"id": file_ids, "audio": audio, "labels": labels})


@deprecated("mk.get('yesno')")
def get_yesno(dataset_dir: str, download: bool = True):
"""Load YESNO as a Meerkat DataFrame.
Expand Down
Loading

0 comments on commit 49c679b

Please sign in to comment.