diff --git a/.readthedocs.yaml b/.readthedocs.yaml
new file mode 100644
index 00000000..9465ee3e
--- /dev/null
+++ b/.readthedocs.yaml
@@ -0,0 +1,16 @@
+version: 2
+
+build:
+ os: "ubuntu-22.04"
+ tools:
+ python: "3.10"
+
+python:
+ install:
+ - requirements: docs/requirements.txt
+ # Install diart before building the docs
+ - method: pip
+ path: .
+
+sphinx:
+ configuration: docs/conf.py
\ No newline at end of file
diff --git a/README.md b/README.md
index f9d87b91..f9a89dbc 100644
--- a/README.md
+++ b/README.md
@@ -1,13 +1,15 @@
-
+
+
+
-
+🌿 Build AI-powered real-time audio applications in a breeze 🌿
-
+
@@ -23,16 +25,16 @@
🎙️ Stream audio
|
-
- 🤖 Add your model
+
+ 🧠 Models
- |
+
- 📈 Tune hyper-parameters
+ 📈 Tuning
-
+ |
- 🧠🔗 Build pipelines
+ 🧠🔗 Pipelines
|
@@ -42,45 +44,69 @@
💬 Research
- |
-
- 📗 Citation
-
- |
-
- 👨‍💻 Reproducibility
-
-
+
+
+
+## ⚡ Quick introduction
+
+Diart is a Python framework to build AI-powered real-time audio applications.
+Its key feature is the ability to recognize different speakers in real time with state-of-the-art performance,
+a task commonly known as "speaker diarization".
+
+The pipeline `diart.SpeakerDiarization` combines a speaker segmentation and a speaker embedding model
+to power an incremental clustering algorithm that gets more accurate as the conversation progresses:
+
+
+
+With diart you can also create your own custom AI pipeline, benchmark it,
+tune its hyper-parameters, and even serve it on the web using websockets.
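+
+For example, here is a minimal sketch of live diarization from a microphone, assembled from the snippets shown later in this README (each step is explained in the sections below):
+
+```python
+from diart import SpeakerDiarization
+from diart.sources import MicrophoneAudioSource
+from diart.inference import StreamingInference
+from diart.sinks import RTTMWriter
+
+pipeline = SpeakerDiarization()  # uses the default pre-trained models
+mic = MicrophoneAudioSource()
+inference = StreamingInference(pipeline, mic)
+inference.attach_observers(RTTMWriter(mic.uri, "/output/file.rttm"))
+prediction = inference()
+```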
+
+**We provide pre-trained pipelines for:**
+
+- Speaker Diarization
+- Voice Activity Detection
+- Transcription ([coming soon](https://github.com/juanmc2005/diart/pull/144))
+- [Speaker-Aware Transcription](https://betterprogramming.pub/color-your-captions-streamlining-live-transcriptions-with-diart-and-openais-whisper-6203350234ef) ([coming soon](https://github.com/juanmc2005/diart/pull/147))
+
## 💾 Installation
-1) Create environment:
+**1) Make sure your system has the following dependencies:**
+
+```
+ffmpeg < 4.4
+portaudio == 19.6.X
+libsndfile >= 1.2.2
+```
+
+Alternatively, we provide an `environment.yml` file for a pre-configured conda environment:
```shell
conda env create -f diart/environment.yml
conda activate diart
```
-2) Install the package:
+**2) Install the package:**
```shell
pip install diart
```
### Get access to 🎹 pyannote models
-By default, diart is based on [pyannote.audio](https://github.com/pyannote/pyannote-audio) models stored in the [huggingface](https://huggingface.co/) hub.
-To allow diart to use them, you need to follow these steps:
+By default, diart is based on [pyannote.audio](https://github.com/pyannote/pyannote-audio) models from the [huggingface](https://huggingface.co/) hub.
+In order to use them, please follow these steps:
1) [Accept user conditions](https://huggingface.co/pyannote/segmentation) for the `pyannote/segmentation` model
-2) [Accept user conditions](https://huggingface.co/pyannote/embedding) for the `pyannote/embedding` model
-3) Install [huggingface-cli](https://huggingface.co/docs/huggingface_hub/quick-start#install-the-hub-library) and [log in](https://huggingface.co/docs/huggingface_hub/quick-start#login) with your user access token (or provide it manually in diart CLI or API).
+2) [Accept user conditions](https://huggingface.co/pyannote/segmentation-3.0) for the newest `pyannote/segmentation-3.0` model
+3) [Accept user conditions](https://huggingface.co/pyannote/embedding) for the `pyannote/embedding` model
+4) Install [huggingface-cli](https://huggingface.co/docs/huggingface_hub/quick-start#install-the-hub-library) and [log in](https://huggingface.co/docs/huggingface_hub/quick-start#login) with your user access token (or provide it manually in diart CLI or API).
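+
+If you prefer to provide the token manually, a sketch of how to pass it through the API (`YOUR_HF_TOKEN` is a placeholder for your own access token; the CLI exposes an equivalent HF token option, see `diart.stream -h`):
+
+```python
+import diart.models as m
+
+# use_hf_token accepts True (reuse the huggingface-cli login), False, or the token itself
+segmentation = m.SegmentationModel.from_pretrained("pyannote/segmentation", use_hf_token="YOUR_HF_TOKEN")
+```
+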
## 🎙️ Stream audio
@@ -100,7 +126,8 @@ A live conversation:
diart.stream microphone
```
-See `diart.stream -h` for more options.
+By default, diart runs a speaker diarization pipeline, equivalent to setting `--pipeline SpeakerDiarization`,
+but you can also set it to `--pipeline VoiceActivityDetection`. See `diart.stream -h` for more options.
### From python
@@ -119,57 +146,78 @@ inference.attach_observers(RTTMWriter(mic.uri, "/output/file.rttm"))
prediction = inference()
```
-For inference and evaluation on a dataset we recommend to use `Benchmark` (see notes on [reproducibility](#-reproducibility)).
+For inference and evaluation on a dataset, we recommend using `Benchmark` (see notes on [reproducibility](#reproducibility)).
-## 🤖 Add your model
+## 🧠 Models
-Third-party models can be integrated by subclassing `SegmentationModel` and `EmbeddingModel` (both PyTorch `nn.Module`):
+You can use other segmentation and embedding models with the `--segmentation` and `--embedding` CLI arguments.
+Or in Python:
+
+```python
+import diart.models as m
+
+segmentation = m.SegmentationModel.from_pretrained("model_name")
+embedding = m.EmbeddingModel.from_pretrained("model_name")
+```
+
+### Pre-trained models
+
+Below is a list of all the models currently supported by diart:
+
+| Model Name | Model Type | CPU Time* | GPU Time* |
+|---------------------------------------------------------------------------------------------------------------------------|--------------|-----------|-----------|
+| [🤗](https://huggingface.co/pyannote/segmentation) `pyannote/segmentation` (default) | segmentation | 12ms | 8ms |
+| [🤗](https://huggingface.co/pyannote/segmentation-3.0) `pyannote/segmentation-3.0` | segmentation | 11ms | 8ms |
+| [🤗](https://huggingface.co/pyannote/embedding) `pyannote/embedding` (default) | embedding | 26ms | 12ms |
+| [🤗](https://huggingface.co/hbredin/wespeaker-voxceleb-resnet34-LM) `hbredin/wespeaker-voxceleb-resnet34-LM` (ONNX) | embedding | 48ms | 15ms |
+| [🤗](https://huggingface.co/pyannote/wespeaker-voxceleb-resnet34-LM) `pyannote/wespeaker-voxceleb-resnet34-LM` (PyTorch) | embedding | 150ms | 29ms |
+| [🤗](https://huggingface.co/speechbrain/spkrec-xvect-voxceleb) `speechbrain/spkrec-xvect-voxceleb` | embedding | 41ms | 15ms |
+| [🤗](https://huggingface.co/speechbrain/spkrec-ecapa-voxceleb) `speechbrain/spkrec-ecapa-voxceleb` | embedding | 41ms | 14ms |
+| [🤗](https://huggingface.co/speechbrain/spkrec-ecapa-voxceleb-mel-spec) `speechbrain/spkrec-ecapa-voxceleb-mel-spec` | embedding | 42ms | 14ms |
+| [🤗](https://huggingface.co/speechbrain/spkrec-resnet-voxceleb) `speechbrain/spkrec-resnet-voxceleb` | embedding | 41ms | 16ms |
+| [🤗](https://huggingface.co/nvidia/speakerverification_en_titanet_large) `nvidia/speakerverification_en_titanet_large` | embedding | 91ms | 16ms |
+
+The latency of segmentation models is measured in a VAD pipeline (5s chunks).
+
+The latency of embedding models is measured in a diarization pipeline using `pyannote/segmentation` (also 5s chunks).
+
+\* CPU: AMD Ryzen 9 - GPU: RTX 4060 Max-Q
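+
+For example, a sketch of plugging two models from the table above into a diarization pipeline (this particular pairing is just an illustration, any other segmentation/embedding pair from the table works the same way):
+
+```python
+import diart.models as m
+from diart import SpeakerDiarization, SpeakerDiarizationConfig
+
+segmentation = m.SegmentationModel.from_pretrained("pyannote/segmentation-3.0")
+embedding = m.EmbeddingModel.from_pretrained("speechbrain/spkrec-ecapa-voxceleb")
+
+config = SpeakerDiarizationConfig(segmentation=segmentation, embedding=embedding)
+pipeline = SpeakerDiarization(config)
+```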
+
+### Custom models
+
+Third-party models can be integrated by providing a loader function:
```python
from diart import SpeakerDiarization, SpeakerDiarizationConfig
from diart.models import EmbeddingModel, SegmentationModel
-from diart.sources import MicrophoneAudioSource
-from diart.inference import StreamingInference
-
-def model_loader():
+def segmentation_loader():
+ # It should take a waveform and return a segmentation tensor
return load_pretrained_model("my_model.ckpt")
+def embedding_loader():
+ # It should take (waveform, weights) and return per-speaker embeddings
+ return load_pretrained_model("my_other_model.ckpt")
-class MySegmentationModel(SegmentationModel):
- def __init__(self):
- super().__init__(model_loader)
-
- @property
- def sample_rate(self) -> int:
- return 16000
-
- @property
- def duration(self) -> float:
- return 2 # seconds
-
- def forward(self, waveform):
- # self.model is created lazily
- return self.model(waveform)
-
-
-class MyEmbeddingModel(EmbeddingModel):
- def __init__(self):
- super().__init__(model_loader)
-
- def forward(self, waveform, weights):
- # self.model is created lazily
- return self.model(waveform, weights)
-
-
+segmentation = SegmentationModel(segmentation_loader)
+embedding = EmbeddingModel(embedding_loader)
config = SpeakerDiarizationConfig(
- segmentation=MySegmentationModel(),
- embedding=MyEmbeddingModel()
+ segmentation=segmentation,
+ embedding=embedding,
)
pipeline = SpeakerDiarization(config)
-mic = MicrophoneAudioSource()
-inference = StreamingInference(pipeline, mic)
-prediction = inference()
+```
+
+If you have an ONNX model, you can use `from_onnx()`:
+
+```python
+from diart.models import EmbeddingModel
+
+embedding = EmbeddingModel.from_onnx(
+ model_path="my_model.onnx",
+ input_names=["x", "w"], # defaults to ["waveform", "weights"]
+ output_name="output", # defaults to "embedding"
+)
```
## 📈 Tune hyper-parameters
@@ -195,7 +243,7 @@ optimizer(num_iter=100)
This will write results to an sqlite database in `/output/dir`.
-### Distributed optimization
+### Distributed tuning
For bigger datasets, it is sometimes more convenient to run multiple optimization processes in parallel.
To do this, create a study on a [recommended DBMS](https://optuna.readthedocs.io/en/stable/tutorial/10_key_features/004_distributed.html#sphx-glr-tutorial-10-key-features-004-distributed-py) (e.g. MySQL or PostgreSQL) making sure that the study and database names match:
@@ -239,8 +287,8 @@ import diart.operators as dops
from diart.sources import MicrophoneAudioSource
from diart.blocks import SpeakerSegmentation, OverlapAwareSpeakerEmbedding
-segmentation = SpeakerSegmentation.from_pyannote("pyannote/segmentation")
-embedding = OverlapAwareSpeakerEmbedding.from_pyannote("pyannote/embedding")
+segmentation = SpeakerSegmentation.from_pretrained("pyannote/segmentation")
+embedding = OverlapAwareSpeakerEmbedding.from_pretrained("pyannote/embedding")
mic = MicrophoneAudioSource()
stream = mic.stream.pipe(
@@ -269,7 +317,7 @@ Diart is also compatible with the WebSocket protocol to serve pipelines on the w
### From the command line
-```commandline
+```shell
diart.serve --host 0.0.0.0 --port 7007
diart.client microphone --host <server-address> --port 7007
```
@@ -296,16 +344,21 @@ prediction = inference()
## 💬 Powered by research
-Diart is the official implementation of the paper *[Overlap-aware low-latency online speaker diarization based on end-to-end local segmentation](/paper.pdf)* by [Juan Manuel Coria](https://juanmc2005.github.io/), [HervΓ© Bredin](https://herve.niderb.fr), [Sahar Ghannay](https://saharghannay.github.io/) and [Sophie Rosset](https://perso.limsi.fr/rosset/).
+Diart is the official implementation of the paper
+[Overlap-aware low-latency online speaker diarization based on end-to-end local segmentation](https://github.com/juanmc2005/diart/blob/main/paper.pdf)
+by [Juan Manuel Coria](https://juanmc2005.github.io/),
+[HervΓ© Bredin](https://herve.niderb.fr),
+[Sahar Ghannay](https://saharghannay.github.io/)
+and [Sophie Rosset](https://perso.limsi.fr/rosset/).
> We propose to address online speaker diarization as a combination of incremental clustering and local diarization applied to a rolling buffer updated every 500ms. Every single step of the proposed pipeline is designed to take full advantage of the strong ability of a recently proposed end-to-end overlap-aware segmentation to detect and separate overlapping speakers. In particular, we propose a modified version of the statistics pooling layer (initially introduced in the x-vector architecture) to give less weight to frames where the segmentation model predicts simultaneous speakers. Furthermore, we derive cannot-link constraints from the initial segmentation step to prevent two local speakers from being wrongfully merged during the incremental clustering step. Finally, we show how the latency of the proposed approach can be adjusted between 500ms and 5s to match the requirements of a particular use case, and we provide a systematic analysis of the influence of latency on the overall performance (on AMI, DIHARD and VoxConverse).
-
+
-## 📗 Citation
+### Citation
If you found diart useful, please make sure to cite our paper:
@@ -320,9 +373,12 @@ If you found diart useful, please make sure to cite our paper:
}
```
-## 👨‍💻 Reproducibility
+### Reproducibility
+
+![Results table](https://github.com/juanmc2005/diart/blob/main/table1.png?raw=true)
-![Results table](/table1.png)
+**Important:** We highly recommend installing `pyannote.audio<3.1` to reproduce these results.
+For more information, see [this issue](https://github.com/juanmc2005/diart/issues/214).
Diart aims to be lightweight and capable of real-time streaming in practical scenarios.
Its performance is very close to what is reported in the paper (and sometimes even a bit better).
@@ -352,11 +408,11 @@ from diart.models import SegmentationModel
benchmark = Benchmark("/wav/dir", "/rttm/dir")
-name = "pyannote/segmentation@Interspeech2021"
-segmentation = SegmentationModel.from_pyannote(name)
+model_name = "pyannote/segmentation@Interspeech2021"
+model = SegmentationModel.from_pretrained(model_name)
config = SpeakerDiarizationConfig(
- # Set the model used in the paper
- segmentation=segmentation,
+ # Set the segmentation model used in the paper
+ segmentation=model,
step=0.5,
latency=0.5,
tau_active=0.555,
@@ -374,9 +430,13 @@ if __name__ == "__main__": # Needed for multiprocessing
This pre-calculates model outputs in batches, so it runs a lot faster.
See `diart.benchmark -h` for more options.
-For convenience and to facilitate future comparisons, we also provide the [expected outputs](/expected_outputs) of the paper implementation in RTTM format for every entry of Table 1 and Figure 5. This includes the VBx offline topline as well as our proposed online approach with latencies 500ms, 1s, 2s, 3s, 4s, and 5s.
+For convenience and to facilitate future comparisons, we also provide the
+[expected outputs](https://github.com/juanmc2005/diart/tree/main/expected_outputs)
+of the paper implementation in RTTM format for every entry of Table 1 and Figure 5.
+This includes the VBx offline topline as well as our proposed online approach with
+latencies 500ms, 1s, 2s, 3s, 4s, and 5s.
-![Figure 5](/figure5.png)
+![Figure 5](https://github.com/juanmc2005/diart/blob/main/figure5.png?raw=true)
## 📑 License
@@ -405,4 +465,4 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
```
-Logo generated by DesignEvo free logo designer
+Logo generated by DesignEvo free logo designer
diff --git a/demo.gif b/demo.gif
index 1d3496a6..a82ab116 100644
Binary files a/demo.gif and b/demo.gif differ
diff --git a/docs/Makefile b/docs/Makefile
new file mode 100644
index 00000000..d4bb2cbb
--- /dev/null
+++ b/docs/Makefile
@@ -0,0 +1,20 @@
+# Minimal makefile for Sphinx documentation
+#
+
+# You can set these variables from the command line, and also
+# from the environment for the first two.
+SPHINXOPTS ?=
+SPHINXBUILD ?= sphinx-build
+SOURCEDIR = .
+BUILDDIR = _build
+
+# Put it first so that "make" without argument is like "make help".
+help:
+ @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
+
+.PHONY: help Makefile
+
+# Catch-all target: route all unknown targets to Sphinx using the new
+# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
+%: Makefile
+ @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
diff --git a/docs/_static/logo.png b/docs/_static/logo.png
new file mode 100644
index 00000000..2cba5e7b
Binary files /dev/null and b/docs/_static/logo.png differ
diff --git a/docs/conf.py b/docs/conf.py
new file mode 100644
index 00000000..1c291773
--- /dev/null
+++ b/docs/conf.py
@@ -0,0 +1,65 @@
+# Configuration file for the Sphinx documentation builder.
+#
+# For the full list of built-in configuration values, see the documentation:
+# https://www.sphinx-doc.org/en/master/usage/configuration.html
+
+# -- Project information -----------------------------------------------------
+# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
+
+project = "diart"
+copyright = "2023, Juan Manuel Coria"
+author = "Juan Manuel Coria"
+release = "v0.9"
+
+# -- General configuration ---------------------------------------------------
+# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
+
+extensions = [
+ "autoapi.extension",
+ "sphinx.ext.coverage",
+ "sphinx.ext.napoleon",
+ "sphinx_mdinclude",
+]
+
+autoapi_dirs = ["../src/diart"]
+autoapi_options = [
+ "members",
+ "undoc-members",
+ "show-inheritance",
+ "show-module-summary",
+ "special-members",
+ "imported-members",
+]
+
+templates_path = ["_templates"]
+exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
+
+# -- Options for autodoc ----------------------------------------------------
+# https://www.sphinx-doc.org/en/master/usage/extensions/autodoc.html#configuration
+
+# Automatically extract typehints when specified and place them in
+# descriptions of the relevant function/method.
+autodoc_typehints = "description"
+
+# Don't show class signature with the class' name.
+autodoc_class_signature = "separated"
+
+# -- Options for HTML output -------------------------------------------------
+# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
+
+html_theme = "furo"
+html_static_path = ["_static"]
+html_logo = "_static/logo.png"
+html_title = "diart documentation"
+
+
+def skip_submodules(app, what, name, obj, skip, options):
+ return (
+ name.endswith("__init__")
+ or name.startswith("diart.console")
+ or name.startswith("diart.argdoc")
+ )
+
+
+def setup(sphinx):
+ sphinx.connect("autoapi-skip-member", skip_submodules)
diff --git a/docs/index.rst b/docs/index.rst
new file mode 100644
index 00000000..c1ff218b
--- /dev/null
+++ b/docs/index.rst
@@ -0,0 +1,11 @@
+Get started with diart
+======================
+
+.. mdinclude:: ../README.md
+
+
+Useful Links
+============
+
+.. toctree::
+ :maxdepth: 1
diff --git a/docs/make.bat b/docs/make.bat
new file mode 100644
index 00000000..32bb2452
--- /dev/null
+++ b/docs/make.bat
@@ -0,0 +1,35 @@
+@ECHO OFF
+
+pushd %~dp0
+
+REM Command file for Sphinx documentation
+
+if "%SPHINXBUILD%" == "" (
+ set SPHINXBUILD=sphinx-build
+)
+set SOURCEDIR=.
+set BUILDDIR=_build
+
+%SPHINXBUILD% >NUL 2>NUL
+if errorlevel 9009 (
+ echo.
+ echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
+ echo.installed, then set the SPHINXBUILD environment variable to point
+ echo.to the full path of the 'sphinx-build' executable. Alternatively you
+ echo.may add the Sphinx directory to PATH.
+ echo.
+ echo.If you don't have Sphinx installed, grab it from
+ echo.https://www.sphinx-doc.org/
+ exit /b 1
+)
+
+if "%1" == "" goto help
+
+%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
+goto end
+
+:help
+%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
+
+:end
+popd
diff --git a/docs/requirements.txt b/docs/requirements.txt
new file mode 100644
index 00000000..44f4f63f
--- /dev/null
+++ b/docs/requirements.txt
@@ -0,0 +1,4 @@
+sphinx==6.2.1
+sphinx-autoapi==3.0.0
+sphinx-mdinclude==0.5.3
+furo==2023.9.10
\ No newline at end of file
diff --git a/environment.yml b/environment.yml
index f62b4274..5ca22f0b 100644
--- a/environment.yml
+++ b/environment.yml
@@ -3,7 +3,7 @@ channels:
- conda-forge
- defaults
dependencies:
- - python=3.8
+ - python=3.10
- portaudio=19.6.*
- pysoundfile=0.12.*
- ffmpeg[version='<4.4']
diff --git a/logo.jpg b/logo.jpg
index b115d96a..317c129e 100644
Binary files a/logo.jpg and b/logo.jpg differ
diff --git a/pipeline.gif b/pipeline.gif
new file mode 100644
index 00000000..44107df0
Binary files /dev/null and b/pipeline.gif differ
diff --git a/requirements.txt b/requirements.txt
index e0d93213..2d3e4611 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -10,6 +10,7 @@ torch>=1.12.1
torchvision>=0.14.0
torchaudio>=2.0.2
pyannote.audio>=2.1.1
+requests>=2.31.0
pyannote.core>=4.5
pyannote.database>=4.1.1
pyannote.metrics>=3.2
diff --git a/setup.cfg b/setup.cfg
index f38a612e..66b255f2 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -1,8 +1,8 @@
[metadata]
name=diart
-version=0.8.0
+version=0.9.0
author=Juan Manuel Coria
-description=Streaming speaker diarization in real-time
+description=A Python framework to build AI-powered real-time audio applications
long_description=file: README.md
long_description_content_type=text/markdown
keywords=speaker diarization, streaming, online, real time, rxpy
@@ -32,6 +32,7 @@ install_requires=
torchvision>=0.14.0
torchaudio>=2.0.2
pyannote.audio>=2.1.1
+ requests>=2.31.0
pyannote.core>=4.5
pyannote.database>=4.1.1
pyannote.metrics>=3.2
diff --git a/src/diart/argdoc.py b/src/diart/argdoc.py
index e89caa28..d9d92b9b 100644
--- a/src/diart/argdoc.py
+++ b/src/diart/argdoc.py
@@ -15,3 +15,4 @@
OUTPUT = "Directory to store the system's output in RTTM format"
HF_TOKEN = "Huggingface authentication token for hosted models ('true' | 'false' | <token>). If 'true', it will use the token from huggingface-cli login"
SAMPLE_RATE = "Sample rate of the audio stream"
+NORMALIZE_EMBEDDING_WEIGHTS = "Rescale embedding weights (min-max normalization) to be in the range [0, 1]. This is useful in some models without weighted statistics pooling that rely on masking, like Nvidia's NeMo or ECAPA-TDNN"
diff --git a/src/diart/blocks/base.py b/src/diart/blocks/base.py
index f6ca3a33..c0bed63a 100644
--- a/src/diart/blocks/base.py
+++ b/src/diart/blocks/base.py
@@ -11,12 +11,28 @@
@dataclass
class HyperParameter:
+ """Represents a pipeline hyper-parameter that can be tuned by diart"""
+
name: Text
+ """Name of the hyper-parameter (e.g. tau_active)"""
low: float
+ """Lowest value that this parameter can take"""
high: float
+ """Highest value that this parameter can take"""
@staticmethod
def from_name(name: Text) -> "HyperParameter":
+ """Create a HyperParameter object given its name.
+
+ Parameters
+ ----------
+ name: str
+ Name of the hyper-parameter
+
+ Returns
+ -------
+ HyperParameter
+ """
if name == "tau_active":
return TauActive
if name == "rho_update":
@@ -32,24 +48,34 @@ def from_name(name: Text) -> "HyperParameter":
class PipelineConfig(ABC):
+ """Configuration containing the required
+ parameters to build and run a pipeline"""
+
@property
@abstractmethod
def duration(self) -> float:
+ """The duration of an input audio chunk (in seconds)"""
pass
@property
@abstractmethod
def step(self) -> float:
+ """The step between two consecutive input audio chunks (in seconds)"""
pass
@property
@abstractmethod
def latency(self) -> float:
+ """The algorithmic latency of the pipeline (in seconds).
+ At time `t` of the audio stream, the pipeline will
+ output predictions for time `t - latency`.
+ """
pass
@property
@abstractmethod
def sample_rate(self) -> int:
+ """The sample rate of the input audio stream"""
pass
def get_file_padding(self, filepath: FilePath) -> Tuple[float, float]:
@@ -60,6 +86,8 @@ def get_file_padding(self, filepath: FilePath) -> Tuple[float, float]:
class Pipeline(ABC):
+ """Represents a streaming audio pipeline"""
+
@staticmethod
@abstractmethod
def get_config_class() -> type:
@@ -92,4 +120,18 @@ def set_timestamp_shift(self, shift: float):
def __call__(
self, waveforms: Sequence[SlidingWindowFeature]
) -> Sequence[Tuple[Any, SlidingWindowFeature]]:
+ """Runs the next steps of the pipeline
+ given a list of consecutive audio chunks.
+
+ Parameters
+ ----------
+ waveforms: Sequence[SlidingWindowFeature]
+ Consecutive chunk waveforms for the pipeline to ingest
+
+ Returns
+ -------
+ Sequence[Tuple[Any, SlidingWindowFeature]]
+ For each input waveform, a tuple containing
+ the pipeline output and its respective audio
+ """
pass
diff --git a/src/diart/blocks/clustering.py b/src/diart/blocks/clustering.py
index b7217c0a..860c1395 100644
--- a/src/diart/blocks/clustering.py
+++ b/src/diart/blocks/clustering.py
@@ -140,6 +140,10 @@ def identify(
long_speakers = np.where(np.mean(segmentation.data, axis=0) >= self.rho_update)[
0
]
+ # Remove speakers that have NaN embeddings
+ no_nan_embeddings = np.where(~np.isnan(embeddings).any(axis=1))[0]
+ active_speakers = np.intersect1d(active_speakers, no_nan_embeddings)
+
num_local_speakers = segmentation.data.shape[1]
if self.centers is None:
diff --git a/src/diart/blocks/diarization.py b/src/diart/blocks/diarization.py
index fab83c36..151b4d36 100644
--- a/src/diart/blocks/diarization.py
+++ b/src/diart/blocks/diarization.py
@@ -23,7 +23,7 @@ def __init__(
self,
segmentation: m.SegmentationModel | None = None,
embedding: m.EmbeddingModel | None = None,
- duration: float | None = None,
+ duration: float = 5,
step: float = 0.5,
latency: float | Literal["max", "min"] | None = None,
tau_active: float = 0.6,
@@ -32,7 +32,9 @@ def __init__(
gamma: float = 3,
beta: float = 10,
max_speakers: int = 20,
+ normalize_embedding_weights: bool = False,
device: torch.device | None = None,
+ sample_rate: int = 16000,
**kwargs,
):
# Default segmentation model is pyannote/segmentation
@@ -46,7 +48,7 @@ def __init__(
)
self._duration = duration
- self._sample_rate: int | None = None
+ self._sample_rate = sample_rate
# Latency defaults to the step duration
self._step = step
@@ -62,16 +64,13 @@ def __init__(
self.gamma = gamma
self.beta = beta
self.max_speakers = max_speakers
-
+ self.normalize_embedding_weights = normalize_embedding_weights
self.device = device or torch.device(
"cuda" if torch.cuda.is_available() else "cpu"
)
@property
def duration(self) -> float:
- # Default duration is the one given by the segmentation model
- if self._duration is None:
- self._duration = self.segmentation.duration
return self._duration
@property
@@ -84,9 +83,6 @@ def latency(self) -> float:
@property
def sample_rate(self) -> int:
- # Expected sample rate is given by the segmentation model
- if self._sample_rate is None:
- self._sample_rate = self.segmentation.sample_rate
return self._sample_rate
@@ -105,6 +101,7 @@ def __init__(self, config: SpeakerDiarizationConfig | None = None):
self._config.gamma,
self._config.beta,
norm=1,
+ normalize_weights=self._config.normalize_embedding_weights,
device=self._config.device,
)
self.pred_aggregation = DelayedAggregation(
@@ -160,6 +157,18 @@ def reset(self):
def __call__(
self, waveforms: Sequence[SlidingWindowFeature]
) -> Sequence[tuple[Annotation, SlidingWindowFeature]]:
+ """Diarize the next audio chunks of an audio stream.
+
+ Parameters
+ ----------
+ waveforms: Sequence[SlidingWindowFeature]
+ A sequence of consecutive audio chunks from an audio stream.
+
+ Returns
+ -------
+ Sequence[tuple[Annotation, SlidingWindowFeature]]
+ Speaker diarization of each chunk alongside their corresponding audio.
+ """
batch_size = len(waveforms)
msg = "Pipeline expected at least 1 input"
assert batch_size >= 1, msg
@@ -175,9 +184,8 @@ def __call__(
# Extract segmentation and embeddings
segmentations = self.segmentation(batch) # shape (batch, frames, speakers)
- embeddings = self.embedding(
- batch, segmentations
- ) # shape (batch, speakers, emb_dim)
+ # embeddings has shape (batch, speakers, emb_dim)
+ embeddings = self.embedding(batch, segmentations)
seg_resolution = waveforms[0].extent.duration / segmentations.shape[1]
diff --git a/src/diart/blocks/embedding.py b/src/diart/blocks/embedding.py
index 5cd7c39e..c6f72cd8 100644
--- a/src/diart/blocks/embedding.py
+++ b/src/diart/blocks/embedding.py
@@ -3,6 +3,7 @@
import torch
from einops import rearrange
+from .. import functional as F
from ..features import TemporalFeatures, TemporalFeatureFormatter
from ..models import EmbeddingModel
@@ -19,12 +20,12 @@ def __init__(self, model: EmbeddingModel, device: Optional[torch.device] = None)
self.weights_formatter = TemporalFeatureFormatter()
@staticmethod
- def from_pyannote(
+ def from_pretrained(
model,
use_hf_token: Union[Text, bool, None] = True,
device: Optional[torch.device] = None,
) -> "SpeakerEmbedding":
- emb_model = EmbeddingModel.from_pyannote(model, use_hf_token)
+ emb_model = EmbeddingModel.from_pretrained(model, use_hf_token)
return SpeakerEmbedding(emb_model, device)
def __call__(
@@ -68,7 +69,13 @@ def __call__(
class OverlappedSpeechPenalty:
- """
+ """Applies a penalty on overlapping speech and low-confidence regions to speaker segmentation scores.
+
+ .. note::
+ For more information, see `"Overlap-Aware Low-Latency Online Speaker Diarization
+ based on End-to-End Local Segmentation" <https://github.com/juanmc2005/diart/blob/main/paper.pdf>`_
+ (Section 2.2.1 Segmentation-driven speaker embedding). This block implements Equation 2.
+
Parameters
----------
gamma: float, optional
@@ -77,19 +84,26 @@ class OverlappedSpeechPenalty:
beta: float, optional
Temperature parameter (actually 1/beta) to lower joint speaker activations.
Defaults to 10.
+ normalize: bool, optional
+ Whether to min-max normalize weights to be in the range [0, 1].
+ Defaults to False.
"""
- def __init__(self, gamma: float = 3, beta: float = 10):
+ def __init__(self, gamma: float = 3, beta: float = 10, normalize: bool = False):
self.gamma = gamma
self.beta = beta
self.formatter = TemporalFeatureFormatter()
+ self.normalize = normalize
def __call__(self, segmentation: TemporalFeatures) -> TemporalFeatures:
weights = self.formatter.cast(segmentation) # shape (batch, frames, speakers)
- with torch.no_grad():
- probs = torch.softmax(self.beta * weights, dim=-1)
- weights = torch.pow(weights, self.gamma) * torch.pow(probs, self.gamma)
- weights[weights < 1e-8] = 1e-8
+ with torch.inference_mode():
+ weights = F.overlapped_speech_penalty(weights, self.gamma, self.beta)
+ if self.normalize:
+ min_values = weights.min(dim=1, keepdim=True).values
+ max_values = weights.max(dim=1, keepdim=True).values
+ weights = (weights - min_values) / (max_values - min_values)
+ weights.nan_to_num_(1e-8)
return self.formatter.restore_type(weights)
@@ -101,19 +115,8 @@ def __init__(self, norm: Union[float, torch.Tensor] = 1):
self.norm = self.norm.unsqueeze(0)
def __call__(self, embeddings: torch.Tensor) -> torch.Tensor:
- # Add batch dimension if missing
- if embeddings.ndim == 2:
- embeddings = embeddings.unsqueeze(0)
- if isinstance(self.norm, torch.Tensor):
- batch_size1, num_speakers1, _ = self.norm.shape
- batch_size2, num_speakers2, _ = embeddings.shape
- assert batch_size1 == batch_size2 and num_speakers1 == num_speakers2
- with torch.no_grad():
- norm_embs = (
- self.norm
- * embeddings
- / torch.norm(embeddings, p=2, dim=-1, keepdim=True)
- )
+ with torch.inference_mode():
+ norm_embs = F.normalize_embeddings(embeddings, self.norm)
return norm_embs
@@ -134,6 +137,8 @@ class OverlapAwareSpeakerEmbedding:
norm: float or torch.Tensor of shape (batch, speakers, 1) where batch is optional
The target norm for the embeddings. It can be different for each speaker.
Defaults to 1.
+ normalize_weights: bool, optional
+ Whether to min-max normalize embedding weights to be in the range [0, 1].
device: Optional[torch.device]
The device on which to run the embedding model.
Defaults to GPU if available or CPU if not.
@@ -145,23 +150,27 @@ def __init__(
gamma: float = 3,
beta: float = 10,
norm: Union[float, torch.Tensor] = 1,
+ normalize_weights: bool = False,
device: Optional[torch.device] = None,
):
self.embedding = SpeakerEmbedding(model, device)
- self.osp = OverlappedSpeechPenalty(gamma, beta)
+ self.osp = OverlappedSpeechPenalty(gamma, beta, normalize_weights)
self.normalize = EmbeddingNormalization(norm)
@staticmethod
- def from_pyannote(
+ def from_pretrained(
model,
gamma: float = 3,
beta: float = 10,
norm: Union[float, torch.Tensor] = 1,
use_hf_token: Union[Text, bool, None] = True,
+ normalize_weights: bool = False,
device: Optional[torch.device] = None,
):
- model = EmbeddingModel.from_pyannote(model, use_hf_token)
- return OverlapAwareSpeakerEmbedding(model, gamma, beta, norm, device)
+ model = EmbeddingModel.from_pretrained(model, use_hf_token)
+ return OverlapAwareSpeakerEmbedding(
+ model, gamma, beta, norm, normalize_weights, device
+ )
def __call__(
self, waveform: TemporalFeatures, segmentation: TemporalFeatures
diff --git a/src/diart/blocks/segmentation.py b/src/diart/blocks/segmentation.py
index e946c748..16915639 100644
--- a/src/diart/blocks/segmentation.py
+++ b/src/diart/blocks/segmentation.py
@@ -18,12 +18,12 @@ def __init__(self, model: SegmentationModel, device: Optional[torch.device] = No
self.formatter = TemporalFeatureFormatter()
@staticmethod
- def from_pyannote(
+ def from_pretrained(
model,
use_hf_token: Union[Text, bool, None] = True,
device: Optional[torch.device] = None,
) -> "SpeakerSegmentation":
- seg_model = SegmentationModel.from_pyannote(model, use_hf_token)
+ seg_model = SegmentationModel.from_pretrained(model, use_hf_token)
return SpeakerSegmentation(seg_model, device)
def __call__(self, waveform: TemporalFeatures) -> TemporalFeatures:
diff --git a/src/diart/blocks/vad.py b/src/diart/blocks/vad.py
index 0edd3e0b..f299c94b 100644
--- a/src/diart/blocks/vad.py
+++ b/src/diart/blocks/vad.py
@@ -27,11 +27,12 @@ class VoiceActivityDetectionConfig(base.PipelineConfig):
def __init__(
self,
segmentation: m.SegmentationModel | None = None,
- duration: float | None = None,
+ duration: float = 5,
step: float = 0.5,
latency: float | Literal["max", "min"] | None = None,
tau_active: float = 0.6,
device: torch.device | None = None,
+ sample_rate: int = 16000,
**kwargs,
):
# Default segmentation model is pyannote/segmentation
@@ -41,7 +42,7 @@ def __init__(
self._duration = duration
self._step = step
- self._sample_rate: int | None = None
+ self._sample_rate = sample_rate
# Latency defaults to the step duration
self._latency = latency
@@ -57,9 +58,6 @@ def __init__(
@property
def duration(self) -> float:
- # Default duration is the one given by the segmentation model
- if self._duration is None:
- self._duration = self.segmentation.duration
return self._duration
@property
@@ -72,9 +70,6 @@ def latency(self) -> float:
@property
def sample_rate(self) -> int:
- # Expected sample rate is given by the segmentation model
- if self._sample_rate is None:
- self._sample_rate = self.segmentation.sample_rate
return self._sample_rate
diff --git a/src/diart/console/benchmark.py b/src/diart/console/benchmark.py
index b5a296d1..f5dccd0a 100644
--- a/src/diart/console/benchmark.py
+++ b/src/diart/console/benchmark.py
@@ -43,6 +43,7 @@ def run():
parser.add_argument(
"--duration",
type=float,
+ default=5,
help=f"{argdoc.DURATION}. Defaults to training segmentation duration",
)
parser.add_argument(
@@ -99,6 +100,11 @@ def run():
type=str,
help=f"{argdoc.HF_TOKEN}. Defaults to 'true' (required by pyannote)",
)
+ parser.add_argument(
+ "--normalize-embedding-weights",
+ action="store_true",
+ help=f"{argdoc.NORMALIZE_EMBEDDING_WEIGHTS}. Defaults to False",
+ )
args = parser.parse_args()
# Resolve device
@@ -106,8 +112,8 @@ def run():
# Resolve models
hf_token = utils.parse_hf_token_arg(args.hf_token)
- args.segmentation = m.SegmentationModel.from_pyannote(args.segmentation, hf_token)
- args.embedding = m.EmbeddingModel.from_pyannote(args.embedding, hf_token)
+ args.segmentation = m.SegmentationModel.from_pretrained(args.segmentation, hf_token)
+ args.embedding = m.EmbeddingModel.from_pretrained(args.embedding, hf_token)
pipeline_class = utils.get_pipeline_class(args.pipeline)
diff --git a/src/diart/console/serve.py b/src/diart/console/serve.py
index bc002e42..e52980dd 100644
--- a/src/diart/console/serve.py
+++ b/src/diart/console/serve.py
@@ -36,6 +36,7 @@ def run():
parser.add_argument(
"--duration",
type=float,
+ default=5,
help=f"{argdoc.DURATION}. Defaults to training segmentation duration",
)
parser.add_argument(
@@ -80,6 +81,11 @@ def run():
type=str,
help=f"{argdoc.HF_TOKEN}. Defaults to 'true' (required by pyannote)",
)
+ parser.add_argument(
+ "--normalize-embedding-weights",
+ action="store_true",
+ help=f"{argdoc.NORMALIZE_EMBEDDING_WEIGHTS}. Defaults to False",
+ )
args = parser.parse_args()
# Resolve device
@@ -87,8 +93,8 @@ def run():
# Resolve models
hf_token = utils.parse_hf_token_arg(args.hf_token)
- args.segmentation = m.SegmentationModel.from_pyannote(args.segmentation, hf_token)
- args.embedding = m.EmbeddingModel.from_pyannote(args.embedding, hf_token)
+ args.segmentation = m.SegmentationModel.from_pretrained(args.segmentation, hf_token)
+ args.embedding = m.EmbeddingModel.from_pretrained(args.embedding, hf_token)
# Resolve pipeline
pipeline_class = utils.get_pipeline_class(args.pipeline)
diff --git a/src/diart/console/stream.py b/src/diart/console/stream.py
index 713f3e99..32ea8761 100644
--- a/src/diart/console/stream.py
+++ b/src/diart/console/stream.py
@@ -39,6 +39,7 @@ def run():
parser.add_argument(
"--duration",
type=float,
+ default=5,
help=f"{argdoc.DURATION}. Defaults to training segmentation duration",
)
parser.add_argument(
@@ -91,6 +92,11 @@ def run():
type=str,
help=f"{argdoc.HF_TOKEN}. Defaults to 'true' (required by pyannote)",
)
+ parser.add_argument(
+ "--normalize-embedding-weights",
+ action="store_true",
+ help=f"{argdoc.NORMALIZE_EMBEDDING_WEIGHTS}. Defaults to False",
+ )
args = parser.parse_args()
# Resolve device
@@ -98,8 +104,8 @@ def run():
# Resolve models
hf_token = utils.parse_hf_token_arg(args.hf_token)
- args.segmentation = m.SegmentationModel.from_pyannote(args.segmentation, hf_token)
- args.embedding = m.EmbeddingModel.from_pyannote(args.embedding, hf_token)
+ args.segmentation = m.SegmentationModel.from_pretrained(args.segmentation, hf_token)
+ args.embedding = m.EmbeddingModel.from_pretrained(args.embedding, hf_token)
# Resolve pipeline
pipeline_class = utils.get_pipeline_class(args.pipeline)
diff --git a/src/diart/console/tune.py b/src/diart/console/tune.py
index ec243348..ba9ac7e9 100644
--- a/src/diart/console/tune.py
+++ b/src/diart/console/tune.py
@@ -46,6 +46,7 @@ def run():
parser.add_argument(
"--duration",
type=float,
+ default=5,
help=f"{argdoc.DURATION}. Defaults to training segmentation duration",
)
parser.add_argument(
@@ -108,6 +109,11 @@ def run():
type=str,
help=f"{argdoc.HF_TOKEN}. Defaults to 'true' (required by pyannote)",
)
+ parser.add_argument(
+ "--normalize-embedding-weights",
+ action="store_true",
+ help=f"{argdoc.NORMALIZE_EMBEDDING_WEIGHTS}. Defaults to False",
+ )
args = parser.parse_args()
# Resolve device
@@ -115,8 +121,8 @@ def run():
# Resolve models
hf_token = utils.parse_hf_token_arg(args.hf_token)
- args.segmentation = m.SegmentationModel.from_pyannote(args.segmentation, hf_token)
- args.embedding = m.EmbeddingModel.from_pyannote(args.embedding, hf_token)
+ args.segmentation = m.SegmentationModel.from_pretrained(args.segmentation, hf_token)
+ args.embedding = m.EmbeddingModel.from_pretrained(args.embedding, hf_token)
# Retrieve pipeline class
pipeline_class = utils.get_pipeline_class(args.pipeline)
diff --git a/src/diart/functional.py b/src/diart/functional.py
new file mode 100644
index 00000000..af15f8d1
--- /dev/null
+++ b/src/diart/functional.py
@@ -0,0 +1,27 @@
+from __future__ import annotations
+
+import torch
+
+
+def overlapped_speech_penalty(
+ segmentation: torch.Tensor, gamma: float = 3, beta: float = 10
+):
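+ """Compute frame-wise embedding weights that penalize overlapping speech and low-confidence activations (Equation 2 of the paper)."""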
+ # segmentation has shape (batch, frames, speakers)
+ probs = torch.softmax(beta * segmentation, dim=-1)
+ weights = torch.pow(segmentation, gamma) * torch.pow(probs, gamma)
+ weights[weights < 1e-8] = 1e-8
+ return weights
+
+
+def normalize_embeddings(
+ embeddings: torch.Tensor, norm: float | torch.Tensor = 1
+) -> torch.Tensor:
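+ """Scale embeddings so that each one has the target L2 norm, adding a batch dimension if missing."""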
+ # embeddings has shape (batch, speakers, feat) or (speakers, feat)
+ if embeddings.ndim == 2:
+ embeddings = embeddings.unsqueeze(0)
+ if isinstance(norm, torch.Tensor):
+ batch_size1, num_speakers1, _ = norm.shape
+ batch_size2, num_speakers2, _ = embeddings.shape
+ assert batch_size1 == batch_size2 and num_speakers1 == num_speakers2
+ emb_norm = torch.norm(embeddings, p=2, dim=-1, keepdim=True)
+ return norm * embeddings / emb_norm
diff --git a/src/diart/inference.py b/src/diart/inference.py
index 3eb72930..a09593d3 100644
--- a/src/diart/inference.py
+++ b/src/diart/inference.py
@@ -524,7 +524,12 @@ def __call__(
num_audio_files = len(audio_file_paths)
# Workaround for multiprocessing with GPU
- torch.multiprocessing.set_start_method("spawn")
+ try:
+ torch.multiprocessing.set_start_method("spawn")
+ except RuntimeError:
+ # This may fail if the start method was set before
+ pass
+
# For Windows support
freeze_support()
diff --git a/src/diart/models.py b/src/diart/models.py
index 5577a097..3a479e4c 100644
--- a/src/diart/models.py
+++ b/src/diart/models.py
@@ -1,15 +1,42 @@
-from abc import ABC, abstractmethod
-from typing import Optional, Text, Union, Callable
+from __future__ import annotations
+from abc import ABC
+from pathlib import Path
+from typing import Optional, Text, Union, Callable, List
+
+import numpy as np
import torch
import torch.nn as nn
+from requests import HTTPError
try:
- import pyannote.audio.pipelines.utils as pyannote_loader
+ from pyannote.audio import Model
+ from pyannote.audio.pipelines.speaker_verification import PretrainedSpeakerEmbedding
+ from pyannote.audio.utils.powerset import Powerset
- _has_pyannote = True
+ IS_PYANNOTE_AVAILABLE = True
except ImportError:
- _has_pyannote = False
+ IS_PYANNOTE_AVAILABLE = False
+
+try:
+ import onnxruntime as ort
+
+ IS_ONNX_AVAILABLE = True
+except ImportError:
+ IS_ONNX_AVAILABLE = False
+
+
+class PowersetAdapter(nn.Module):
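+ """Converts the powerset output of a pyannote segmentation model into multilabel speaker activations."""
+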
+ def __init__(self, segmentation_model: nn.Module):
+ super().__init__()
+ self.model = segmentation_model
+ specs = self.model.specifications
+ max_speakers_per_frame = specs.powerset_max_classes
+ max_speakers_per_chunk = len(specs.classes)
+ self.powerset = Powerset(max_speakers_per_chunk, max_speakers_per_frame)
+
+ def forward(self, waveform: torch.Tensor) -> torch.Tensor:
+ return self.powerset.to_multilabel(self.model(waveform))
class PyannoteLoader:
@@ -18,15 +45,75 @@ def __init__(self, model_info, hf_token: Union[Text, bool, None] = True):
self.model_info = model_info
self.hf_token = hf_token
- def __call__(self) -> nn.Module:
- return pyannote_loader.get_model(self.model_info, self.hf_token)
+ def __call__(self) -> Callable:
+ try:
+ model = Model.from_pretrained(self.model_info, use_auth_token=self.hf_token)
+ specs = getattr(model, "specifications", None)
+ if specs is not None and specs.powerset:
+ model = PowersetAdapter(model)
+ return model
+ except HTTPError:
+ pass
+ except ModuleNotFoundError:
+ pass
+ return PretrainedSpeakerEmbedding(self.model_info, use_auth_token=self.hf_token)
+
+
+class ONNXLoader:
+ def __init__(self, path: str | Path, input_names: List[str], output_name: str):
+ super().__init__()
+ self.path = Path(path)
+ self.input_names = input_names
+ self.output_name = output_name
+
+ def __call__(self) -> ONNXModel:
+ return ONNXModel(self.path, self.input_names, self.output_name)
+
+class ONNXModel:
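+ """Wraps an ONNX Runtime session behind a torch-like interface: it is callable and supports `to(device)` by recreating the session with the matching execution provider."""
+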
+ def __init__(self, path: Path, input_names: List[str], output_name: str):
+ super().__init__()
+ self.path = path
+ self.input_names = input_names
+ self.output_name = output_name
+ self.device = torch.device("cpu")
+ self.session = None
+ self.recreate_session()
-class LazyModel(nn.Module, ABC):
- def __init__(self, loader: Callable[[], nn.Module]):
+ @property
+ def execution_provider(self) -> str:
+ device = "CUDA" if self.device.type == "cuda" else "CPU"
+ return f"{device}ExecutionProvider"
+
+ def recreate_session(self):
+ options = ort.SessionOptions()
+ options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
+ self.session = ort.InferenceSession(
+ self.path,
+ sess_options=options,
+ providers=[self.execution_provider],
+ )
+
+ def to(self, device: torch.device) -> ONNXModel:
+ if device.type != self.device.type:
+ self.device = device
+ self.recreate_session()
+ return self
+
+ def __call__(self, *args) -> torch.Tensor:
+ inputs = {
+ name: arg.cpu().numpy().astype(np.float32)
+ for name, arg in zip(self.input_names, args)
+ }
+ output = self.session.run([self.output_name], inputs)[0]
+ return torch.from_numpy(output).float().to(args[0].device)
+
+
+class LazyModel(ABC):
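+ """Loads the underlying model lazily, on first call (or when `load`, `to` or `eval` is invoked)."""
+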
+ def __init__(self, loader: Callable[[], Callable]):
super().__init__()
self.get_model = loader
- self.model: Optional[nn.Module] = None
+ self.model: Optional[Callable] = None
def is_in_memory(self) -> bool:
"""Return whether the model has been loaded into memory"""
@@ -36,13 +123,20 @@ def load(self):
if not self.is_in_memory():
self.model = self.get_model()
- def to(self, *args, **kwargs) -> nn.Module:
+ def to(self, device: torch.device) -> LazyModel:
self.load()
- return super().to(*args, **kwargs)
+ self.model = self.model.to(device)
+ return self
def __call__(self, *args, **kwargs):
self.load()
- return super().__call__(*args, **kwargs)
+ return self.model(*args, **kwargs)
+
+ def eval(self) -> LazyModel:
+ self.load()
+ if isinstance(self.model, nn.Module):
+ self.model.eval()
+ return self
class SegmentationModel(LazyModel):
@@ -70,51 +164,38 @@ def from_pyannote(
-------
wrapper: SegmentationModel
"""
- assert _has_pyannote, "No pyannote.audio installation found"
- return PyannoteSegmentationModel(model, use_hf_token)
+ assert IS_PYANNOTE_AVAILABLE, "No pyannote.audio installation found"
+ return SegmentationModel(PyannoteLoader(model, use_hf_token))
- @property
- @abstractmethod
- def sample_rate(self) -> int:
- pass
+ @staticmethod
+ def from_onnx(
+ model_path: Union[str, Path],
+ input_name: str = "waveform",
+ output_name: str = "segmentation",
+ ) -> "SegmentationModel":
+ assert IS_ONNX_AVAILABLE, "No ONNX installation found"
+ return SegmentationModel(ONNXLoader(model_path, [input_name], output_name))
- @property
- @abstractmethod
- def duration(self) -> float:
- pass
+ @staticmethod
+ def from_pretrained(
+ model, use_hf_token: Union[Text, bool, None] = True
+ ) -> "SegmentationModel":
+ if isinstance(model, str) or isinstance(model, Path):
+ if Path(model).name.endswith(".onnx"):
+ return SegmentationModel.from_onnx(model)
+ return SegmentationModel.from_pyannote(model, use_hf_token)
- @abstractmethod
- def forward(self, waveform: torch.Tensor) -> torch.Tensor:
+ def __call__(self, waveform: torch.Tensor) -> torch.Tensor:
"""
- Forward pass of the segmentation model.
-
+ Call the forward pass of the segmentation model.
Parameters
----------
waveform: torch.Tensor, shape (batch, channels, samples)
-
Returns
-------
speaker_segmentation: torch.Tensor, shape (batch, frames, speakers)
"""
- pass
-
-
-class PyannoteSegmentationModel(SegmentationModel):
- def __init__(self, model_info, hf_token: Union[Text, bool, None] = True):
- super().__init__(PyannoteLoader(model_info, hf_token))
-
- @property
- def sample_rate(self) -> int:
- self.load()
- return self.model.audio.sample_rate
-
- @property
- def duration(self) -> float:
- self.load()
- return self.model.specifications.duration
-
- def forward(self, waveform: torch.Tensor) -> torch.Tensor:
- return self.model(waveform)
+ return super().__call__(waveform)
class EmbeddingModel(LazyModel):
@@ -140,36 +221,45 @@ def from_pyannote(
-------
wrapper: EmbeddingModel
"""
- assert _has_pyannote, "No pyannote.audio installation found"
- return PyannoteEmbeddingModel(model, use_hf_token)
+ assert IS_PYANNOTE_AVAILABLE, "No pyannote.audio installation found"
+ loader = PyannoteLoader(model, use_hf_token)
+ return EmbeddingModel(loader)
+
+ @staticmethod
+ def from_onnx(
+ model_path: Union[str, Path],
+ input_names: List[str] | None = None,
+ output_name: str = "embedding",
+ ) -> "EmbeddingModel":
+ assert IS_ONNX_AVAILABLE, "No ONNX installation found"
+ input_names = input_names or ["waveform", "weights"]
+ loader = ONNXLoader(model_path, input_names, output_name)
+ return EmbeddingModel(loader)
- @abstractmethod
- def forward(
+ @staticmethod
+ def from_pretrained(
+ model, use_hf_token: Union[Text, bool, None] = True
+ ) -> "EmbeddingModel":
+ if isinstance(model, str) or isinstance(model, Path):
+ if Path(model).name.endswith(".onnx"):
+ return EmbeddingModel.from_onnx(model)
+ return EmbeddingModel.from_pyannote(model, use_hf_token)
+
+ def __call__(
self, waveform: torch.Tensor, weights: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""
- Forward pass of an embedding model with optional weights.
-
+ Call the forward pass of an embedding model with optional weights.
Parameters
----------
waveform: torch.Tensor, shape (batch, channels, samples)
weights: Optional[torch.Tensor], shape (batch, frames)
Temporal weights for each sample in the batch. Defaults to no weights.
-
Returns
-------
speaker_embeddings: torch.Tensor, shape (batch, embedding_dim)
"""
- pass
-
-
-class PyannoteEmbeddingModel(EmbeddingModel):
- def __init__(self, model_info, hf_token: Union[Text, bool, None] = True):
- super().__init__(PyannoteLoader(model_info, hf_token))
-
- def forward(
- self,
- waveform: torch.Tensor,
- weights: Optional[torch.Tensor] = None,
- ) -> torch.Tensor:
- return self.model(waveform, weights=weights)
+ embeddings = super().__call__(waveform, weights)
+ if isinstance(embeddings, np.ndarray):
+ embeddings = torch.from_numpy(embeddings)
+ return embeddings
diff --git a/src/diart/sinks.py b/src/diart/sinks.py
index ed4e2ea0..1ae3adf9 100644
--- a/src/diart/sinks.py
+++ b/src/diart/sinks.py
@@ -8,8 +8,6 @@
from rx.core import Observer
from typing_extensions import Literal
-from . import utils
-
class WindowClosedException(Exception):
pass