diff --git a/demo/audio-embed.ipynb b/demo/audio-embed.ipynb new file mode 100644 index 000000000..3947aa5f2 --- /dev/null +++ b/demo/audio-embed.ipynb @@ -0,0 +1,430 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "env: PYTORCH_ENABLE_MPS_FALLBACK=1\n" + ] + } + ], + "source": [ + "\"\"\"A notebook to visualize audio embeddings of the HuggingFace Speech Commands dataset.\"\"\"\n", + "%set_env PYTORCH_ENABLE_MPS_FALLBACK=1\n", + "\n", + "import meerkat as mk\n", + "\n", + "%load_ext autoreload\n", + "%autoreload 2\n", + "\n", + "# Set your device here\n", + "device = \"mps\"" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Building (may take up to a minute)... 48.28s\r" + ] + }, + { + "data": { + "text/plain": [ + "(APIInfo(api=, port=5010, server=, name='127.0.0.1', shared=False, process=None, _url=None),\n", + " FrontendInfo(package_manager='npm', port=8003, name='localhost', shared=False, process=, _url=None))" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "mk.gui.start(api_port=5010, frontend_port=8001, dev=False, skip_build=False)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load the Dataset\n", + "In this demo, we will be working with [`music_genres_small`](https://huggingface.co/datasets/lewtun/music_genres_small) dataset on HuggingFace." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/sabrieyuboglu/opt/miniconda3/envs/meerkat/lib/python3.9/site-packages/huggingface_hub/utils/_deprecation.py:97: FutureWarning: Deprecated argument(s) used in 'dataset_info': token. 
Will not be supported from version '0.12'.\n", + " warnings.warn(message, FutureWarning)\n" + ] + }, + { + "data": { + "text/html": [ + "
[03/11/23 20:45:39] WARNING  [_create_builder_config()] [datasets.builder: 463] :: Using custom data builder.py:463\n",
+       "                             configuration lewtun--music_genres_small-2686d03f87ff3ace                             \n",
+       "
\n" + ], + "text/plain": [ + "\u001b[2;36m[03/11/23 20:45:39]\u001b[0m\u001b[2;36m \u001b[0m\u001b[31mWARNING \u001b[0m \u001b[1m[\u001b[0m\u001b[1;35m_create_builder_config\u001b[0m\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m\u001b[1m]\u001b[0m \u001b[1m[\u001b[0mdatasets.builder: \u001b[1;36m463\u001b[0m\u001b[1m]\u001b[0m :: Using custom data \u001b]8;id=593159;file:///Users/sabrieyuboglu/opt/miniconda3/envs/meerkat/lib/python3.9/site-packages/datasets/builder.py\u001b\\\u001b[2mbuilder.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=759472;file:///Users/sabrieyuboglu/opt/miniconda3/envs/meerkat/lib/python3.9/site-packages/datasets/builder.py#463\u001b\\\u001b[2m463\u001b[0m\u001b]8;;\u001b\\\n", + "\u001b[2;36m \u001b[0m configuration lewtun--music_genres_small-2686d03f87ff3ace \u001b[2m \u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
                    WARNING  [download_and_prepare()] [datasets.builder: 641] :: Reusing dataset     builder.py:641\n",
+       "                             parquet                                                                               \n",
+       "                             (/Users/sabrieyuboglu/.cache/huggingface/datasets/lewtun___parquet/lewt               \n",
+       "                             un--music_genres_small-2686d03f87ff3ace/0.0.0/2a3b91fbd88a2c90d1dbbb32b               \n",
+       "                             460cf621d31bd5b05b934492fdef7d8d6f236ec)                                              \n",
+       "
\n" + ], + "text/plain": [ + "\u001b[2;36m \u001b[0m\u001b[2;36m \u001b[0m\u001b[31mWARNING \u001b[0m \u001b[1m[\u001b[0m\u001b[1;35mdownload_and_prepare\u001b[0m\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m\u001b[1m]\u001b[0m \u001b[1m[\u001b[0mdatasets.builder: \u001b[1;36m641\u001b[0m\u001b[1m]\u001b[0m :: Reusing dataset \u001b]8;id=774455;file:///Users/sabrieyuboglu/opt/miniconda3/envs/meerkat/lib/python3.9/site-packages/datasets/builder.py\u001b\\\u001b[2mbuilder.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=811242;file:///Users/sabrieyuboglu/opt/miniconda3/envs/meerkat/lib/python3.9/site-packages/datasets/builder.py#641\u001b\\\u001b[2m641\u001b[0m\u001b]8;;\u001b\\\n", + "\u001b[2;36m \u001b[0m parquet \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m \u001b[1m(\u001b[0m\u001b[35m/Users/sabrieyuboglu/.cache/huggingface/datasets/lewtun___parquet/lewt\u001b[0m \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m \u001b[35mun--music_genres_small-2686d03f87ff3ace/0.0.0/\u001b[0m\u001b[95m2a3b91fbd88a2c90d1dbbb32b\u001b[0m \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m \u001b[95m460cf621d31bd5b05b934492fdef7d8d6f236ec\u001b[0m\u001b[1m)\u001b[0m \u001b[2m \u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/json": { + "ascii": false, + "bar_format": null, + "colour": null, + "elapsed": 0.010290145874023438, + "initial": 0, + "n": 0, + "ncols": null, + "nrows": null, + "postfix": null, + "prefix": "", + "rate": null, + "total": 1, + "unit": "it", + "unit_divisor": 1000, + "unit_scale": false + }, + "application/vnd.jupyter.widget-view+json": { + "model_id": "82c89d26ed254771a82b00dbae167e75", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/1 [00:00)\n", + " warnings.warn(f\"Unable to check if column is a valid primary key: {e}\")\n" + ] + } + ], + "source": [ + "dataset = mk.get(name=\"lewtun/music_genres_small\", registry=\"huggingface\")" + ] + }, + { + "cell_type": 
"code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "from typing import Any, Dict\n", + "import datasets as hf_datasets\n", + "\n", + "from meerkat.interactive.formatter import AudioFormatterGroup\n", + "from meerkat.cells.audio import Audio\n", + "\n", + "df = dataset[\"train\"].view()\n", + "\n", + "# The audio column is a dictionary containing the bytes.\n", + "# Extract the bytes lazily.\n", + "# The byte string is actually the fastest way to display the audio,\n", + "# because the encoding is already done.\n", + "df[\"audio\"] = df[\"audio\"].defer(lambda x: x[\"bytes\"])\n", + "\n", + "# Set the formatter for this column.\n", + "df[\"audio\"].formatters = AudioFormatterGroup()" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "df" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Encode the dataset with Wav2Vec2\n", + "Encode the dataset with Wav2Vec2. This will take a few minutes.\n", + "\n", + "You can also optionally download the embeddings from huggingface. See the code for how to do this" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "ename": "MemoryError", + "evalue": "Cannot allocate write+execute memory for ffi.callback(). You might be running on a system that prevents this. 
For more information, see https://cffi.readthedocs.io/en/latest/using.html#callbacks", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mMemoryError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn [7], line 15\u001b[0m\n\u001b[1;32m 12\u001b[0m \u001b[39mreturn\u001b[39;00m audio\u001b[39m.\u001b[39mresample(sampling_rate)\u001b[39m.\u001b[39mdata\n\u001b[1;32m 14\u001b[0m df_embed \u001b[39m=\u001b[39m df[[\u001b[39m\"\u001b[39m\u001b[39msong_id\u001b[39m\u001b[39m\"\u001b[39m, \u001b[39m\"\u001b[39m\u001b[39maudio\u001b[39m\u001b[39m\"\u001b[39m]]\n\u001b[0;32m---> 15\u001b[0m df_embed[\u001b[39m\"\u001b[39m\u001b[39maudio\u001b[39m\u001b[39m\"\u001b[39m] \u001b[39m=\u001b[39m df_embed[\u001b[39m\"\u001b[39;49m\u001b[39maudio\u001b[39;49m\u001b[39m\"\u001b[39;49m]\u001b[39m.\u001b[39;49mdefer(to_mk_audio)\n\u001b[1;32m 16\u001b[0m df_embed[\u001b[39m\"\u001b[39m\u001b[39maudio_tensor\u001b[39m\u001b[39m\"\u001b[39m] \u001b[39m=\u001b[39m df_embed[\u001b[39m\"\u001b[39m\u001b[39maudio\u001b[39m\u001b[39m\"\u001b[39m]\u001b[39m.\u001b[39mdefer(to_array)\n", + "File \u001b[0;32m~/code/meerkat/meerkat/interactive/graph/marking.py:74\u001b[0m, in \u001b[0;36munmarked.__call__..decorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 71\u001b[0m \u001b[39m@wraps\u001b[39m(func)\n\u001b[1;32m 72\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mdecorate_context\u001b[39m(\u001b[39m*\u001b[39margs, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs):\n\u001b[1;32m 73\u001b[0m \u001b[39mwith\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mclone():\n\u001b[0;32m---> 74\u001b[0m \u001b[39mreturn\u001b[39;00m reactive(func, nested_return\u001b[39m=\u001b[39;49m\u001b[39mFalse\u001b[39;49;00m)(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n", + "File 
\u001b[0;32m~/code/meerkat/meerkat/interactive/graph/reactivity.py:204\u001b[0m, in \u001b[0;36mreactive..__reactive..wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 202\u001b[0m \u001b[39m# Call the function on the args and kwargs\u001b[39;00m\n\u001b[1;32m 203\u001b[0m \u001b[39mwith\u001b[39;00m unmarked():\n\u001b[0;32m--> 204\u001b[0m result \u001b[39m=\u001b[39m fn(\u001b[39m*\u001b[39;49munpacked_args, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49munpacked_kwargs)\n\u001b[1;32m 206\u001b[0m \u001b[39m# TODO: Check if result is equal to one of the inputs.\u001b[39;00m\n\u001b[1;32m 207\u001b[0m \u001b[39m# If it is, we need to copy it.\u001b[39;00m\n\u001b[1;32m 209\u001b[0m \u001b[39mif\u001b[39;00m _is_unmarked_context \u001b[39mor\u001b[39;00m _force_no_react \u001b[39mor\u001b[39;00m \u001b[39mnot\u001b[39;00m any_inputs_marked:\n\u001b[1;32m 210\u001b[0m \u001b[39m# If we are in an unmarked context, then we don't need to create\u001b[39;00m\n\u001b[1;32m 211\u001b[0m \u001b[39m# any nodes in the graph.\u001b[39;00m\n\u001b[1;32m 212\u001b[0m \u001b[39m# `fn` should be run as normal.\u001b[39;00m\n", + "File \u001b[0;32m~/code/meerkat/meerkat/interactive/graph/reactivity.py:137\u001b[0m, in \u001b[0;36mreactive..__reactive..wrapper.._fn_outer_wrapper.._fn_wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 135\u001b[0m \u001b[39m@wraps\u001b[39m(_fn)\n\u001b[1;32m 136\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m_fn_wrapper\u001b[39m(\u001b[39m*\u001b[39margs, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs):\n\u001b[0;32m--> 137\u001b[0m \u001b[39mreturn\u001b[39;00m _fn(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n", + "File \u001b[0;32m~/code/meerkat/meerkat/mixins/deferable.py:31\u001b[0m, in \u001b[0;36mDeferrableMixin.defer\u001b[0;34m(self, function, is_batched_fn, batch_size, inputs, outputs, output_type, materialize)\u001b[0m\n\u001b[1;32m 20\u001b[0m 
\u001b[39m@doc\u001b[39m(defer, data\u001b[39m=\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mself\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[1;32m 21\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mdefer\u001b[39m(\n\u001b[1;32m 22\u001b[0m \u001b[39mself\u001b[39m,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 29\u001b[0m materialize: \u001b[39mbool\u001b[39m \u001b[39m=\u001b[39m \u001b[39mTrue\u001b[39;00m,\n\u001b[1;32m 30\u001b[0m ) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m Union[\u001b[39m\"\u001b[39m\u001b[39mDataFrame\u001b[39m\u001b[39m\"\u001b[39m, \u001b[39m\"\u001b[39m\u001b[39mDeferredColumn\u001b[39m\u001b[39m\"\u001b[39m]:\n\u001b[0;32m---> 31\u001b[0m \u001b[39mreturn\u001b[39;00m defer(\n\u001b[1;32m 32\u001b[0m data\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m,\n\u001b[1;32m 33\u001b[0m function\u001b[39m=\u001b[39;49mfunction,\n\u001b[1;32m 34\u001b[0m is_batched_fn\u001b[39m=\u001b[39;49mis_batched_fn,\n\u001b[1;32m 35\u001b[0m batch_size\u001b[39m=\u001b[39;49mbatch_size,\n\u001b[1;32m 36\u001b[0m inputs\u001b[39m=\u001b[39;49minputs,\n\u001b[1;32m 37\u001b[0m outputs\u001b[39m=\u001b[39;49moutputs,\n\u001b[1;32m 38\u001b[0m output_type\u001b[39m=\u001b[39;49moutput_type,\n\u001b[1;32m 39\u001b[0m materialize\u001b[39m=\u001b[39;49mmaterialize,\n\u001b[1;32m 40\u001b[0m )\n", + "File \u001b[0;32m~/code/meerkat/meerkat/ops/map.py:298\u001b[0m, in \u001b[0;36mdefer\u001b[0;34m(data, function, is_batched_fn, batch_size, inputs, outputs, output_type, materialize)\u001b[0m\n\u001b[1;32m 286\u001b[0m op \u001b[39m=\u001b[39m DeferredOp(\n\u001b[1;32m 287\u001b[0m fn\u001b[39m=\u001b[39mfunction,\n\u001b[1;32m 288\u001b[0m args\u001b[39m=\u001b[39margs,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 293\u001b[0m materialize_inputs\u001b[39m=\u001b[39mmaterialize,\n\u001b[1;32m 294\u001b[0m )\n\u001b[1;32m 296\u001b[0m block \u001b[39m=\u001b[39m DeferredBlock\u001b[39m.\u001b[39mfrom_block_data(data\u001b[39m=\u001b[39mop)\n\u001b[0;32m--> 298\u001b[0m 
first_row \u001b[39m=\u001b[39m op\u001b[39m.\u001b[39;49m_get(\u001b[39m0\u001b[39;49m) \u001b[39mif\u001b[39;00m \u001b[39mlen\u001b[39m(op) \u001b[39m>\u001b[39m \u001b[39m0\u001b[39m \u001b[39melse\u001b[39;00m \u001b[39mNone\u001b[39;00m\n\u001b[1;32m 300\u001b[0m \u001b[39mif\u001b[39;00m outputs \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39mand\u001b[39;00m \u001b[39misinstance\u001b[39m(first_row, Dict):\n\u001b[1;32m 301\u001b[0m \u001b[39m# support for splitting a dict into multiple columns without specifying outputs\u001b[39;00m\n\u001b[1;32m 302\u001b[0m outputs \u001b[39m=\u001b[39m {output_key: output_key \u001b[39mfor\u001b[39;00m output_key \u001b[39min\u001b[39;00m first_row}\n", + "File \u001b[0;32m~/code/meerkat/meerkat/block/deferred_block.py:248\u001b[0m, in \u001b[0;36mDeferredOp._get\u001b[0;34m(self, index, indexed_inputs, materialize)\u001b[0m\n\u001b[1;32m 246\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(index, \u001b[39mint\u001b[39m):\n\u001b[1;32m 247\u001b[0m \u001b[39mif\u001b[39;00m materialize:\n\u001b[0;32m--> 248\u001b[0m output \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mfn(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 249\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mreturn_index \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[1;32m 250\u001b[0m output \u001b[39m=\u001b[39m output[\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mreturn_index]\n", + "Cell \u001b[0;32mIn [7], line 8\u001b[0m, in \u001b[0;36mto_mk_audio\u001b[0;34m(audio)\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mto_mk_audio\u001b[39m(audio: \u001b[39mbytes\u001b[39m) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m Audio:\n\u001b[0;32m----> 8\u001b[0m audio_dict \u001b[39m=\u001b[39m 
hf_datasets\u001b[39m.\u001b[39;49mAudio()\u001b[39m.\u001b[39;49mdecode_example({\u001b[39m\"\u001b[39;49m\u001b[39mpath\u001b[39;49m\u001b[39m\"\u001b[39;49m: \u001b[39mNone\u001b[39;49;00m, \u001b[39m\"\u001b[39;49m\u001b[39mbytes\u001b[39;49m\u001b[39m\"\u001b[39;49m: audio})\n\u001b[1;32m 9\u001b[0m \u001b[39mreturn\u001b[39;00m Audio(data\u001b[39m=\u001b[39maudio_dict[\u001b[39m\"\u001b[39m\u001b[39marray\u001b[39m\u001b[39m\"\u001b[39m], sampling_rate\u001b[39m=\u001b[39maudio_dict[\u001b[39m\"\u001b[39m\u001b[39msampling_rate\u001b[39m\u001b[39m\"\u001b[39m])\n", + "File \u001b[0;32m~/opt/miniconda3/envs/meerkat/lib/python3.9/site-packages/datasets/features/audio.py:154\u001b[0m, in \u001b[0;36mAudio.decode_example\u001b[0;34m(self, value, token_per_repo_id)\u001b[0m\n\u001b[1;32m 152\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 153\u001b[0m \u001b[39mif\u001b[39;00m file:\n\u001b[0;32m--> 154\u001b[0m array, sampling_rate \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_decode_non_mp3_file_like(file)\n\u001b[1;32m 155\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 156\u001b[0m array, sampling_rate \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_decode_non_mp3_path_like(path, token_per_repo_id\u001b[39m=\u001b[39mtoken_per_repo_id)\n", + "File \u001b[0;32m~/opt/miniconda3/envs/meerkat/lib/python3.9/site-packages/datasets/features/audio.py:273\u001b[0m, in \u001b[0;36mAudio._decode_non_mp3_file_like\u001b[0;34m(self, file, format)\u001b[0m\n\u001b[1;32m 268\u001b[0m \u001b[39mif\u001b[39;00m version\u001b[39m.\u001b[39mparse(sf\u001b[39m.\u001b[39m__libsndfile_version__) \u001b[39m<\u001b[39m version\u001b[39m.\u001b[39mparse(\u001b[39m\"\u001b[39m\u001b[39m1.0.30\u001b[39m\u001b[39m\"\u001b[39m):\n\u001b[1;32m 269\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mRuntimeError\u001b[39;00m(\n\u001b[1;32m 270\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mDecoding .opus files requires 
\u001b[39m\u001b[39m'\u001b[39m\u001b[39mlibsndfile\u001b[39m\u001b[39m'\u001b[39m\u001b[39m>=1.0.30, \u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 271\u001b[0m \u001b[39m+\u001b[39m \u001b[39m\"\u001b[39m\u001b[39mit can be installed via conda: `conda install -c conda-forge libsndfile>=1.0.30`\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 272\u001b[0m )\n\u001b[0;32m--> 273\u001b[0m array, sampling_rate \u001b[39m=\u001b[39m sf\u001b[39m.\u001b[39;49mread(file)\n\u001b[1;32m 274\u001b[0m array \u001b[39m=\u001b[39m array\u001b[39m.\u001b[39mT\n\u001b[1;32m 275\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mmono:\n", + "File \u001b[0;32m~/opt/miniconda3/envs/meerkat/lib/python3.9/site-packages/soundfile.py:285\u001b[0m, in \u001b[0;36mread\u001b[0;34m(file, frames, start, stop, dtype, always_2d, fill_value, out, samplerate, channels, format, subtype, endian, closefd)\u001b[0m\n\u001b[1;32m 199\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mread\u001b[39m(file, frames\u001b[39m=\u001b[39m\u001b[39m-\u001b[39m\u001b[39m1\u001b[39m, start\u001b[39m=\u001b[39m\u001b[39m0\u001b[39m, stop\u001b[39m=\u001b[39m\u001b[39mNone\u001b[39;00m, dtype\u001b[39m=\u001b[39m\u001b[39m'\u001b[39m\u001b[39mfloat64\u001b[39m\u001b[39m'\u001b[39m, always_2d\u001b[39m=\u001b[39m\u001b[39mFalse\u001b[39;00m,\n\u001b[1;32m 200\u001b[0m fill_value\u001b[39m=\u001b[39m\u001b[39mNone\u001b[39;00m, out\u001b[39m=\u001b[39m\u001b[39mNone\u001b[39;00m, samplerate\u001b[39m=\u001b[39m\u001b[39mNone\u001b[39;00m, channels\u001b[39m=\u001b[39m\u001b[39mNone\u001b[39;00m,\n\u001b[1;32m 201\u001b[0m \u001b[39mformat\u001b[39m\u001b[39m=\u001b[39m\u001b[39mNone\u001b[39;00m, subtype\u001b[39m=\u001b[39m\u001b[39mNone\u001b[39;00m, endian\u001b[39m=\u001b[39m\u001b[39mNone\u001b[39;00m, closefd\u001b[39m=\u001b[39m\u001b[39mTrue\u001b[39;00m):\n\u001b[1;32m 202\u001b[0m \u001b[39m\"\"\"Provide audio data from a sound file as NumPy array.\u001b[39;00m\n\u001b[1;32m 
203\u001b[0m \n\u001b[1;32m 204\u001b[0m \u001b[39m By default, the whole file is read from the beginning, but the\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 283\u001b[0m \n\u001b[1;32m 284\u001b[0m \u001b[39m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 285\u001b[0m \u001b[39mwith\u001b[39;00m SoundFile(file, \u001b[39m'\u001b[39;49m\u001b[39mr\u001b[39;49m\u001b[39m'\u001b[39;49m, samplerate, channels,\n\u001b[1;32m 286\u001b[0m subtype, endian, \u001b[39mformat\u001b[39;49m, closefd) \u001b[39mas\u001b[39;00m f:\n\u001b[1;32m 287\u001b[0m frames \u001b[39m=\u001b[39m f\u001b[39m.\u001b[39m_prepare_read(start, stop, frames)\n\u001b[1;32m 288\u001b[0m data \u001b[39m=\u001b[39m f\u001b[39m.\u001b[39mread(frames, dtype, always_2d, fill_value, out)\n", + "File \u001b[0;32m~/opt/miniconda3/envs/meerkat/lib/python3.9/site-packages/soundfile.py:658\u001b[0m, in \u001b[0;36mSoundFile.__init__\u001b[0;34m(self, file, mode, samplerate, channels, subtype, endian, format, closefd)\u001b[0m\n\u001b[1;32m 655\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_mode \u001b[39m=\u001b[39m mode\n\u001b[1;32m 656\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_info \u001b[39m=\u001b[39m _create_info_struct(file, mode, samplerate, channels,\n\u001b[1;32m 657\u001b[0m \u001b[39mformat\u001b[39m, subtype, endian)\n\u001b[0;32m--> 658\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_file \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_open(file, mode_int, closefd)\n\u001b[1;32m 659\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mset\u001b[39m(mode)\u001b[39m.\u001b[39missuperset(\u001b[39m'\u001b[39m\u001b[39mr+\u001b[39m\u001b[39m'\u001b[39m) \u001b[39mand\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mseekable():\n\u001b[1;32m 660\u001b[0m \u001b[39m# Move write position to 0 (like in Python file objects)\u001b[39;00m\n\u001b[1;32m 661\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mseek(\u001b[39m0\u001b[39m)\n", + "File 
\u001b[0;32m~/opt/miniconda3/envs/meerkat/lib/python3.9/site-packages/soundfile.py:1209\u001b[0m, in \u001b[0;36mSoundFile._open\u001b[0;34m(self, file, mode_int, closefd)\u001b[0m\n\u001b[1;32m 1207\u001b[0m file_ptr \u001b[39m=\u001b[39m _snd\u001b[39m.\u001b[39msf_open_fd(file, mode_int, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_info, closefd)\n\u001b[1;32m 1208\u001b[0m \u001b[39melif\u001b[39;00m _has_virtual_io_attrs(file, mode_int):\n\u001b[0;32m-> 1209\u001b[0m file_ptr \u001b[39m=\u001b[39m _snd\u001b[39m.\u001b[39msf_open_virtual(\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_init_virtual_io(file),\n\u001b[1;32m 1210\u001b[0m mode_int, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_info, _ffi\u001b[39m.\u001b[39mNULL)\n\u001b[1;32m 1211\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 1212\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mTypeError\u001b[39;00m(\u001b[39m\"\u001b[39m\u001b[39mInvalid file: \u001b[39m\u001b[39m{0!r}\u001b[39;00m\u001b[39m\"\u001b[39m\u001b[39m.\u001b[39mformat(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mname))\n", + "File \u001b[0;32m~/opt/miniconda3/envs/meerkat/lib/python3.9/site-packages/soundfile.py:1229\u001b[0m, in \u001b[0;36mSoundFile._init_virtual_io\u001b[0;34m(self, file)\u001b[0m\n\u001b[1;32m 1226\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m_init_virtual_io\u001b[39m(\u001b[39mself\u001b[39m, file):\n\u001b[1;32m 1227\u001b[0m \u001b[39m\"\"\"Initialize callback functions for sf_open_virtual().\"\"\"\u001b[39;00m\n\u001b[1;32m 1228\u001b[0m \u001b[39m@_ffi\u001b[39;49m\u001b[39m.\u001b[39;49mcallback(\u001b[39m\"\u001b[39;49m\u001b[39msf_vio_get_filelen\u001b[39;49m\u001b[39m\"\u001b[39;49m)\n\u001b[0;32m-> 1229\u001b[0m \u001b[39mdef\u001b[39;49;00m \u001b[39mvio_get_filelen\u001b[39;49m(user_data):\n\u001b[1;32m 1230\u001b[0m curr \u001b[39m=\u001b[39;49m file\u001b[39m.\u001b[39;49mtell()\n\u001b[1;32m 1231\u001b[0m file\u001b[39m.\u001b[39;49mseek(\u001b[39m0\u001b[39;49m, SEEK_END)\n", + 
"\u001b[0;31mMemoryError\u001b[0m: Cannot allocate write+execute memory for ffi.callback(). You might be running on a system that prevents this. For more information, see https://cffi.readthedocs.io/en/latest/using.html#callbacks" + ] + } + ], + "source": [ + "# Make a column that returns the audio tensors @ sampling rage 16kHz.\n", + "from meerkat.cells.audio import Audio\n", + "import datasets as hf_datasets\n", + "\n", + "sampling_rate = 16000\n", + "\n", + "def to_mk_audio(audio: bytes) -> Audio:\n", + " audio_dict = hf_datasets.Audio().decode_example({\"path\": None, \"bytes\": audio})\n", + " return Audio(data=audio_dict[\"array\"], sampling_rate=audio_dict[\"sampling_rate\"])\n", + "\n", + "def to_array(audio: Audio):\n", + " return audio.resample(sampling_rate).data\n", + "\n", + "df_embed = df[[\"song_id\", \"audio\"]]\n", + "df_embed[\"audio\"] = df_embed[\"audio\"].defer(to_mk_audio)\n", + "df_embed[\"audio_tensor\"] = df_embed[\"audio\"].defer(to_array)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from transformers import AutoProcessor, Wav2Vec2Model\n", + "\n", + "\n", + "processor = AutoProcessor.from_pretrained(\"facebook/wav2vec2-base-960h\", device=device)\n", + "model = Wav2Vec2Model.from_pretrained(\"facebook/wav2vec2-base-960h\").to(device)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def embed(audio_tensor: torch.Tensor):\n", + " audio_tensor = audio_tensor.type(torch.float32).to(device)\n", + " inputs = processor(audio_tensor, sampling_rate=sampling_rate, return_tensors=\"pt\", device=device)\n", + " inputs[\"input_values\"] = inputs[\"input_values\"].to(device)\n", + " with torch.no_grad():\n", + " outputs = model(**inputs)\n", + " last_hidden_states = outputs.last_hidden_state\n", + " return last_hidden_states.mean(dim=1).squeeze().cpu()\n", + "\n", + "df_embed[\"embeddings\"] = 
df_embed[\"audio_tensor\"].map(embed, use_ray=False, pbar=True)\n", + "df_embed[\"embeddings\"] = df_embed[\"embeddings\"].to(\"cpu\")" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "# Optionally download embeddings from huggingface\n", + "df_embed = mk.DataFrame.read(\"https://huggingface.co/datasets/meerkat-ml/meerkat-dataframes/resolve/main/music_genres_small-wav2vec2-embedded.mk.tar.gz\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Make the interface" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "plot_df = df.merge(df_embed, on=\"song_id\")" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "# compute umap of embeddings\n", + "from umap import UMAP\n", + "\n", + "umap = UMAP(n_components=2)\n", + "umap = umap.fit_transform(plot_df[\"embeddings\"])\n", + "plot_df[\"umap_1\"] = umap[:, 0]\n", + "plot_df[\"umap_2\"] = umap[:, 1]\n" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plot_df = plot_df.mark()\n", + "plot = mk.gui.plotly.ScatterPlot(df=plot_df, x=\"umap_1\", y=\"umap_2\",)\n", + "\n", + "# Because we're using the reactive decorator, the filter function will re-run whenever\n", + "# plot.selected changes. 
This will update the gallery to only show the selected points.\n", + "@mk.gui.reactive\n", + "def filter(selected: list, df: mk.DataFrame):\n", + " return df[df.primary_key.isin(selected)]\n", + "\n", + "filtered_df = filter(plot.selected, plot_df)\n", + "table = mk.gui.Table(filtered_df)\n", + "\n", + "mk.gui.html.flexcol(\n", + " [plot, table],\n", + " classes=\"h-[1200px]\",\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3.9.12 ('meerkat')", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.12" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "fab435020311317453f49d9f6bde54424ac28707d23828314f0465c42622dc69" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/meerkat/cells/audio.py b/meerkat/cells/audio.py new file mode 100644 index 000000000..74ade88bd --- /dev/null +++ b/meerkat/cells/audio.py @@ -0,0 +1,128 @@ +from meerkat import env +from meerkat.tools.lazy_loader import LazyLoader +from meerkat.tools.utils import requires + +torch = LazyLoader("torch") +torchaudio = LazyLoader("torchaudio") + + +class Audio: + def __init__( + self, + data, + sampling_rate: int, + bits: int = None, + ) -> None: + if not env.is_package_installed("torch"): + raise ValueError( + f"{type(self)} requires torch. Follow these instructions " + "to install torch: https://pytorch.org/get-started/locally/." 
+ ) + self.data = torch.as_tensor(data) + self.sampling_rate = sampling_rate + self.bits = bits + + def duration(self) -> float: + """Return the duration of the audio in seconds.""" + return len(self.data) / self.sampling_rate + + @requires("torchaudio") + def resample(self, sampling_rate: int) -> "Audio": + """Resample the audio with a new sampling rate. + + Args: + sampling_rate: The new sampling rate. + + Returns: + The resampled audio. + """ + if not env.is_package_installed("torchaudio"): + raise ValueError( + "resample requires torchaudio. Install with `pip install torchaudio`." + ) + + return Audio( + torchaudio.functional.resample( + self.data, self.sampling_rate, sampling_rate + ), + sampling_rate, + ) + + def normalize( + self, lower: float = 0.0, upper: float = 1.0, eps: float = 1e-6 + ) -> "Audio": + """Normalize the audio to a given range. + + Args: + lower: The lower bound of the range. + upper: The upper bound of the range. + eps: The epsilon to used to avoid division by zero. + + Returns: + The normalized audio. + """ + _min = torch.amin(self.data) + _max = torch.amax(self.data) + data = lower + (upper - lower) * (self.data - _min) / (_max - _min + eps) + return Audio(data=data, sampling_rate=self.sampling_rate) + + def quantize(self, bits: int, epsilon: float = 1e-2) -> "Audio": + """Linearly quantize a signal to a given number of bits. + + The signal must be in the range [0, 1]. + + Args: + bits: The number of bits to quantize to. + epsilon: The epsilon to use for clipping the signal. + + Returns: + The quantized audio. + """ + if self.bits is not None: + raise ValueError( + "Audio is already quantized. Use `.dequantize` to dequantize " + "the signal and then requantize." 
+ ) + + if torch.any(self.data < 0) or torch.any(self.data > 1): + raise ValueError("Audio must be in the range [0, 1] to quantize.") + + q_levels = 1 << bits + samples = (q_levels - epsilon) * self.data + samples += epsilon / 2 + return Audio(samples.long(), sampling_rate=self.sampling_rate, bits=self.bits) + + def dequantize(self) -> "Audio": + """Dequantize a signal. + + Returns: + The dequantized audio. + """ + if self.bits is None: + raise ValueError("Audio is not quantized.") + + q_levels = 1 << self.bits + return Audio( + self.data.float() / (q_levels / 2) - 1, sampling_rate=self.sampling_rate + ) + + def __repr__(self) -> str: + return f"Audio({self.duration()} seconds @ {self.sampling_rate}Hz)" + + def __eq__(self, other: "Audio") -> bool: + return ( + self.data.shape == other.data.shape + and self.sampling_rate == other.sampling_rate + and torch.allclose(self.data, other.data) + ) + + def __getitem__(self, key: int) -> "Audio": + return Audio(self.data[key], self.sampling_rate) + + def __len__(self) -> int: + return len(self.data) + + def _repr_html_(self) -> str: + import IPython.display as ipd + + return ipd.Audio(self.data, rate=self.sampling_rate)._repr_html_() diff --git a/meerkat/columns/abstract.py b/meerkat/columns/abstract.py index 606261406..c06b8f85f 100644 --- a/meerkat/columns/abstract.py +++ b/meerkat/columns/abstract.py @@ -323,6 +323,8 @@ def _repr_pandas_(self, max_rows: int = None) -> pd.Series: else: col = pd.Series([self._repr_cell(idx) for idx in range(len(self))]) + # TODO: if the objects have a _repr_html_ method, we should be able to use + # that instead of explicitly relying on the column having a formatter. 
return ( col, self.formatters["base"] diff --git a/meerkat/columns/deferred/file.py b/meerkat/columns/deferred/file.py index b1550c27c..4b6337b3f 100644 --- a/meerkat/columns/deferred/file.py +++ b/meerkat/columns/deferred/file.py @@ -18,6 +18,7 @@ import meerkat.tools.docs as docs from meerkat.block.deferred_block import DeferredOp +from meerkat.cells.audio import Audio from meerkat.columns.abstract import Column from meerkat.columns.deferred.base import DeferredCell, DeferredColumn from meerkat.columns.scalar import ScalarColumn @@ -27,6 +28,7 @@ PDFFormatterGroup, TextFormatterGroup, ) +from meerkat.interactive.formatter.audio import DeferredAudioFormatterGroup from meerkat.interactive.formatter.base import FormatterGroup from meerkat.interactive.formatter.image import DeferredImageFormatterGroup @@ -485,6 +487,14 @@ def load_text(path: Union[str, io.BytesIO]): return f.read() +def load_audio(path: str) -> Audio: + import torchaudio + + data, sampling_rate = torchaudio.load(path) + data = data.squeeze() + return Audio(data, sampling_rate=sampling_rate) + + FILE_TYPES = { "image": { "loader": load_image, @@ -512,6 +522,11 @@ def load_text(path: Union[str, io.BytesIO]): "formatters": CodeFormatterGroup, "exts": [".py", ".js", ".css", ".json", ".java", ".cpp", ".c", ".h", ".hpp"], }, + "audio": { + "loader": load_audio, + "formatters": DeferredAudioFormatterGroup, + "exts": [".wav", ".mp3"], + }, } diff --git a/meerkat/dataframe.py b/meerkat/dataframe.py index 20b76611c..01643fa7d 100644 --- a/meerkat/dataframe.py +++ b/meerkat/dataframe.py @@ -683,7 +683,7 @@ def from_huggingface(cls, *args, **kwargs): if isinstance(dataset, dict): return dict( map( - lambda t: (t[0], cls.from_arrow(t[1]._data)), + lambda t: (t[0], cls.from_arrow(t[1]._data.table)), dataset.items(), ) ) diff --git a/meerkat/datasets/__init__.py b/meerkat/datasets/__init__.py index 4da09bb7f..dc8fec7d3 100644 --- a/meerkat/datasets/__init__.py +++ b/meerkat/datasets/__init__.py @@ -15,6 +15,7 @@ 
from .pascal import pascal from .registry import datasets from .rfw import rfw +from .torchaudio import yesno __all__ = [ "celeba", @@ -29,6 +30,7 @@ "rfw", "ngoa", "coco", + "yesno", ] DOWNLOAD_MODES = ["force", "extract", "reuse", "skip"] diff --git a/meerkat/datasets/torchaudio/__init__.py b/meerkat/datasets/torchaudio/__init__.py index 66b62e28d..07ca9ec4b 100644 --- a/meerkat/datasets/torchaudio/__init__.py +++ b/meerkat/datasets/torchaudio/__init__.py @@ -1,12 +1,71 @@ +import os + import pandas as pd import meerkat as mk from meerkat.tools.lazy_loader import LazyLoader +from meerkat.tools.utils import deprecated + +from ..abstract import DatasetBuilder +from ..info import DatasetInfo +from ..registry import datasets torch = LazyLoader("torch") torchaudio = LazyLoader("torchaudio") +@datasets.register() +class yesno(DatasetBuilder): + """YESNO dataset. + + Reference: + https://www.openslr.org/1/ + """ + + info = DatasetInfo( + name="yesno", + full_name="YesNo", + description=( + "This dataset contains 60 .wav files, sampled at 8 kHz. " + "All were recorded by the same male speaker, in Hebrew. " + "In each file, the individual says 8 words; each word is either the " + "Hebrew for 'yes' or 'no', so each file is a random sequence of 8 yes-es " + "or noes. There is no separate transcription provided; the sequence is " + "encoded in the filename, with 1 for yes and 0 for no." 
+ ), + homepage="https://www.openslr.org/1/", + tags=["audio", "classification"], + ) + + VERSIONS = ["release1"] + + def download(self): + os.makedirs(self.dataset_dir, exist_ok=True) + torchaudio.datasets.YESNO(root=self.dataset_dir, download=True) + + def is_downloaded(self) -> bool: + return super().is_downloaded() and os.path.exists( + os.path.join(self.dataset_dir, "waves_yesno") + ) + + def build(self): + dataset = torchaudio.datasets.YESNO(root=self.dataset_dir, download=False) + df = mk.DataFrame( + { + "id": dataset._walker, + "audio": mk.files( + pd.Series(dataset._walker) + ".wav", base_dir=dataset._path + ), + "labels": torch.tensor( + [[int(c) for c in fileid.split("_")] for fileid in dataset._walker] + ), + } + ) + + return df + + +@deprecated("mk.get('yesno')") def get_yesno(dataset_dir: str, download: bool = True): """Load YESNO as a Meerkat DataFrame. diff --git a/meerkat/interactive/app/package-lock.json b/meerkat/interactive/app/package-lock.json index 46dc4c0e0..c7875c796 100644 --- a/meerkat/interactive/app/package-lock.json +++ b/meerkat/interactive/app/package-lock.json @@ -1,12 +1,12 @@ { "name": "@meerkat-ml/meerkat", - "version": "0.4.0", + "version": "0.4.2", "lockfileVersion": 2, "requires": true, "packages": { "": { "name": "@meerkat-ml/meerkat", - "version": "0.4.0", + "version": "0.4.2", "dependencies": { "@datastructures-js/priority-queue": "^6.1.2", "@magidoc/plugin-svelte-marked": "^3.2.1", @@ -21,6 +21,7 @@ "d3-selection": "^3.0.0", "flowbite": "^1.5.2", "flowbite-svelte": "^0.29.11", + "howler": "^2.2.3", "layercake": "^7.0.0", "marked": "^4.2.12", "monaco-editor": "^0.34.1", @@ -6975,6 +6976,11 @@ "node": ">=12.0.0" } }, + "node_modules/howler": { + "version": "2.2.3", + "resolved": "https://registry.npmjs.org/howler/-/howler-2.2.3.tgz", + "integrity": "sha512-QM0FFkw0LRX1PR8pNzJVAY25JhIWvbKMBFM4gqk+QdV+kPXOhleWGCB6AiAF/goGjIHK2e/nIElplvjQwhr0jg==" + }, "node_modules/html-encoding-sniffer": { "version": "3.0.0", "resolved": 
"https://registry.npmjs.org/html-encoding-sniffer/-/html-encoding-sniffer-3.0.0.tgz", @@ -11684,6 +11690,7 @@ }, "node_modules/vega-embed/node_modules/yallist": { "version": "4.0.0", + "extraneous": true, "inBundle": true, "license": "ISC" }, @@ -17976,6 +17983,11 @@ "resolved": "https://registry.npmjs.org/highlight.js/-/highlight.js-11.6.0.tgz", "integrity": "sha512-ig1eqDzJaB0pqEvlPVIpSSyMaO92bH1N2rJpLMN/nX396wTpDA4Eq0uK+7I/2XG17pFaaKE0kjV/XPeGt7Evjw==" }, + "howler": { + "version": "2.2.3", + "resolved": "https://registry.npmjs.org/howler/-/howler-2.2.3.tgz", + "integrity": "sha512-QM0FFkw0LRX1PR8pNzJVAY25JhIWvbKMBFM4gqk+QdV+kPXOhleWGCB6AiAF/goGjIHK2e/nIElplvjQwhr0jg==" + }, "html-encoding-sniffer": { "version": "3.0.0", "resolved": "https://registry.npmjs.org/html-encoding-sniffer/-/html-encoding-sniffer-3.0.0.tgz", @@ -21407,7 +21419,8 @@ "dependencies": { "yallist": { "version": "4.0.0", - "bundled": true + "bundled": true, + "extraneous": true } } }, diff --git a/meerkat/interactive/app/package.json b/meerkat/interactive/app/package.json index df3d540b5..644bd3f99 100644 --- a/meerkat/interactive/app/package.json +++ b/meerkat/interactive/app/package.json @@ -77,6 +77,7 @@ "d3-selection": "^3.0.0", "flowbite": "^1.5.2", "flowbite-svelte": "^0.29.11", + "howler": "^2.2.3", "layercake": "^7.0.0", "marked": "^4.2.12", "monaco-editor": "^0.34.1", diff --git a/meerkat/interactive/app/src/lib/component/core/audio/Audio.svelte b/meerkat/interactive/app/src/lib/component/core/audio/Audio.svelte new file mode 100644 index 000000000..3bfb26486 --- /dev/null +++ b/meerkat/interactive/app/src/lib/component/core/audio/Audio.svelte @@ -0,0 +1,40 @@ + + +
+ {#if playing} + {#if $currentMedia.paused} + + {:else if $currentMedia.ended} + + {:else} + + {/if} + {:else} + + {/if} +
diff --git a/meerkat/interactive/app/src/lib/component/core/audio/__init__.py b/meerkat/interactive/app/src/lib/component/core/audio/__init__.py new file mode 100644 index 000000000..340f02b7d --- /dev/null +++ b/meerkat/interactive/app/src/lib/component/core/audio/__init__.py @@ -0,0 +1,6 @@ +from meerkat.interactive.app.src.lib.component.abstract import Component + + +class Audio(Component): + data: str + classes: str = "" diff --git a/meerkat/interactive/app/src/lib/component/core/icon/Icon.svelte b/meerkat/interactive/app/src/lib/component/core/icon/Icon.svelte index 641ea7b6f..e52fae623 100644 --- a/meerkat/interactive/app/src/lib/component/core/icon/Icon.svelte +++ b/meerkat/interactive/app/src/lib/component/core/icon/Icon.svelte @@ -12,7 +12,8 @@ Magic, BoxFill, CheckSquare, - TrashFill + TrashFill, + Soundwave } from 'svelte-bootstrap-icons'; export let name: string = 'Globe2'; @@ -29,7 +30,8 @@ FileEarmarkPdf: FileEarmarkPdf, BoxFill: BoxFill, CheckSquare: CheckSquare, - TrashFill: TrashFill + TrashFill: TrashFill, + Soundwave: Soundwave }; diff --git a/meerkat/interactive/app/src/lib/index.js b/meerkat/interactive/app/src/lib/index.js index 41157358c..c3aa3f1ac 100644 --- a/meerkat/interactive/app/src/lib/index.js +++ b/meerkat/interactive/app/src/lib/index.js @@ -16,6 +16,7 @@ export { default as Row } from './component/contrib/row/Row.svelte'; export { default as FMFilter } from "./component/contrib/fm_filter/FMFilter.svelte"; export { default as GlobalStats } from './component/contrib/global_stats/GlobalStats.svelte'; /** Core Components */ +export { default as Audio } from './component/core/audio/Audio.svelte'; export { default as Button } from './component/core/button/Button.svelte'; export { default as Carousel } from './component/core/carousel/Carousel.svelte'; export { default as Chat } from './component/core/chat/Chat.svelte'; diff --git a/meerkat/interactive/app/src/lib/shared/media/MediaPlayer.svelte 
b/meerkat/interactive/app/src/lib/shared/media/MediaPlayer.svelte new file mode 100644 index 000000000..1f55af333 --- /dev/null +++ b/meerkat/interactive/app/src/lib/shared/media/MediaPlayer.svelte @@ -0,0 +1,27 @@ + + +{#if $currentMedia} +
(showAudio = true)} + on:mouseleave={() => (showAudio = false)} + on:focus={() => (showAudio = true)} + > + +
+ +{/if} diff --git a/meerkat/interactive/app/src/lib/shared/media/current.ts b/meerkat/interactive/app/src/lib/shared/media/current.ts new file mode 100644 index 000000000..c83f8e9eb --- /dev/null +++ b/meerkat/interactive/app/src/lib/shared/media/current.ts @@ -0,0 +1,23 @@ +import { writable, type Writable } from 'svelte/store'; + +export class PlayingMedia { + data: any; + currentTime: number = 0; + duration: number = 0; + paused: boolean = true; + ended: boolean = false; + + + constructor(data: any) { + this.data = data; + } +} + +export const currentMedia: Writable = writable(null); + + +export const setCurrentMedia = (id: string) => { + let media = new PlayingMedia(id); + currentMedia.set(media); + return media; +}; \ No newline at end of file diff --git a/meerkat/interactive/app/src/routes/+layout.svelte b/meerkat/interactive/app/src/routes/+layout.svelte index 6987ff8f9..85fbf7400 100644 --- a/meerkat/interactive/app/src/routes/+layout.svelte +++ b/meerkat/interactive/app/src/routes/+layout.svelte @@ -2,12 +2,16 @@ import '$lib/app.css'; import ComponentContext from '$lib/ComponentContext.svelte'; import Modals from '$lib/shared/common/Modals.svelte'; + import MediaPlayer from '$lib/shared/media/MediaPlayer.svelte'; + + + diff --git a/meerkat/interactive/formatter/__init__.py b/meerkat/interactive/formatter/__init__.py index a1411d57f..10229f6d7 100644 --- a/meerkat/interactive/formatter/__init__.py +++ b/meerkat/interactive/formatter/__init__.py @@ -1,3 +1,4 @@ +from .audio import AudioFormatter, AudioFormatterGroup from .base import Formatter, deferred_formatter_group from .boolean import BooleanFormatter, BooleanFormatterGroup from .code import CodeFormatter, CodeFormatterGroup @@ -27,6 +28,8 @@ "TensorFormatterGroup", "BooleanFormatter", "BooleanFormatterGroup", + "AudioFormatter", + "AudioFormatterGroup", ] diff --git a/meerkat/interactive/formatter/audio.py b/meerkat/interactive/formatter/audio.py new file mode 100644 index 000000000..a16236c49 --- 
/dev/null
+++ b/meerkat/interactive/formatter/audio.py
@@ -0,0 +1,110 @@
+import base64
+import os
+from io import BytesIO
+from typing import Any, Dict, Union
+
+from scipy.io.wavfile import write
+
+from meerkat.cells.audio import Audio as AudioCell
+from meerkat.columns.deferred.base import DeferredCell
+from meerkat.interactive.formatter.icon import IconFormatter
+
+from ..app.src.lib.component.core.audio import Audio
+from .base import BaseFormatter, FormatterGroup
+
+
+class AudioFormatter(BaseFormatter):
+    component_class = Audio
+    data_prop: str = "data"
+
+    def __init__(self, downsampling_factor: int = 1, classes: str = ""):
+        """
+        Args:
+            downsampling_factor: Keep only every k-th audio sample (the
+                sampling rate is divided by this factor accordingly).
+            classes: The CSS classes to apply to the audio.
+        """
+        self.downsampling_factor = downsampling_factor
+        self.classes = classes
+
+    def encode(self, cell: Union[AudioCell, bytes]) -> str:
+        """Encodes audio as a base64 string.
+
+        Args:
+            cell: The audio to encode. A ``DeferredCell`` is
+                materialized first. Raw ``bytes`` are assumed to
+                already contain encoded audio and are embedded in
+                the returned data URI as-is; otherwise the samples
+                are written to WAV and base64-encoded.
+        """
+        if isinstance(cell, DeferredCell):
+            cell = cell()
+
+        # If the audio is already encoded, return it.
+ if isinstance(cell, bytes): + return "data:audio/mpeg;base64,{im_base_64}".format( + im_base_64=base64.b64encode(cell).decode() + ) + + arr = cell.data + sampling_rate = cell.sampling_rate + + import torch + + if isinstance(arr, torch.Tensor): + arr = arr.cpu().numpy() + + arr = arr[:: self.downsampling_factor] + sampling_rate = sampling_rate // self.downsampling_factor + + with BytesIO() as buffer: + write(buffer, sampling_rate, arr) + return "data:audio/mpeg;base64,{im_base_64}".format( + im_base_64=base64.b64encode(buffer.getvalue()).decode() + ) + + @property + def props(self) -> Dict[str, Any]: + return {"classes": self.classes} + + def html(self, cell: AudioCell) -> str: + encoded = self.encode(cell) + return ( + "" + ) + + def _get_state(self) -> Dict[str, Any]: + return {"classes": self.classes} + + def _set_state(self, state: Dict[str, Any]): + self.classes = state["classes"] + + +class AudioFormatterGroup(FormatterGroup): + formatter_class: type = AudioFormatter + + def __init__(self, classes: str = ""): + super().__init__( + icon=IconFormatter(name="Soundwave"), + base=self.formatter_class(), + ) + + +class DeferredAudioFormatter(AudioFormatter): + component_class: type = Audio + data_prop: str = "data" + + def encode(self, audio: DeferredCell) -> str: + if hasattr(audio, "absolute_path"): + absolute_path = audio.absolute_path + if isinstance(absolute_path, os.PathLike): + absolute_path = str(absolute_path) + if isinstance(absolute_path, str) and absolute_path.startswith("http"): + return audio.absolute_path + + audio = audio() + return super().encode(audio) + + +class DeferredAudioFormatterGroup(AudioFormatterGroup): + formatter_class: type = DeferredAudioFormatter diff --git a/meerkat/tools/utils.py b/meerkat/tools/utils.py index 4683d2c97..56752e5d4 100644 --- a/meerkat/tools/utils.py +++ b/meerkat/tools/utils.py @@ -15,6 +15,7 @@ import yaml from yaml.constructor import ConstructorError +from meerkat import env from meerkat.tools.lazy_loader import 
LazyLoader torch = LazyLoader("torch") @@ -114,6 +115,29 @@ def new_func(*args, **kwargs): return _decorator +def requires(*packages): + """Use this decorator to identify which packages must be installed. + + It will raise an error if the function is called and these packages + are not available. + """ + + def _decorator(func): + @wraps(func) + def wrapped(*args, **kwargs): + for package in packages: + fn_str = f"{func.__qualname__}()" + if not env.package_available(package): + raise ImportError( + f"Missing package `{package}` which is required for {fn_str}." + ) + return func(*args, **kwargs) + + return wrapped + + return _decorator + + class WeakMapping(Mapping): def __init__(self): self.refs: Dict[Any, weakref.ReferenceType] = {}