Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Version 0.9 #217

Merged
merged 10 commits into from
Nov 19, 2023
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions .readthedocs.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
version: 2

build:
os: "ubuntu-22.04"
tools:
python: "3.10"

python:
install:
- requirements: docs/requirements.txt
# Install diart before building the docs
- method: pip
path: .

sphinx:
configuration: docs/conf.py
218 changes: 139 additions & 79 deletions README.md

Large diffs are not rendered by default.

Binary file modified demo.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
20 changes: 20 additions & 0 deletions docs/Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Minimal makefile for Sphinx documentation
#

# You can set these variables from the command line, and also
# from the environment for the first two.
SPHINXOPTS ?=
SPHINXBUILD ?= sphinx-build
SOURCEDIR = .
BUILDDIR = _build

# Put it first so that "make" without argument is like "make help".
help:
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

.PHONY: help Makefile

# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
Binary file added docs/_static/logo.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
65 changes: 65 additions & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# Configuration file for the Sphinx documentation builder.
#
# For the full list of built-in configuration values, see the documentation:
# https://www.sphinx-doc.org/en/master/usage/configuration.html

# -- Project information -----------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information

project = "diart"
copyright = "2023, Juan Manuel Coria"
author = "Juan Manuel Coria"
release = "v0.9"

# -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration

extensions = [
"autoapi.extension",
"sphinx.ext.coverage",
"sphinx.ext.napoleon",
"sphinx_mdinclude",
]

autoapi_dirs = ["../src/diart"]
autoapi_options = [
"members",
"undoc-members",
"show-inheritance",
"show-module-summary",
"special-members",
"imported-members",
]

templates_path = ["_templates"]
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]

# -- Options for autodoc ----------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/extensions/autodoc.html#configuration

# Automatically extract typehints when specified and place them in
# descriptions of the relevant function/method.
autodoc_typehints = "description"

# Don't show class signature with the class' name.
autodoc_class_signature = "separated"

# -- Options for HTML output -------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output

html_theme = "furo"
html_static_path = ["_static"]
html_logo = "_static/logo.png"
html_title = "diart documentation"


def skip_submodules(app, what, name, obj, skip, options):
return (
name.endswith("__init__")
or name.startswith("diart.console")
or name.startswith("diart.argdoc")
)


def setup(sphinx):
sphinx.connect("autoapi-skip-member", skip_submodules)
11 changes: 11 additions & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
Get started with diart
======================

.. mdinclude:: ../README.md


Useful Links
============

.. toctree::
:maxdepth: 1
35 changes: 35 additions & 0 deletions docs/make.bat
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
@ECHO OFF

pushd %~dp0

REM Command file for Sphinx documentation

if "%SPHINXBUILD%" == "" (
set SPHINXBUILD=sphinx-build
)
set SOURCEDIR=.
set BUILDDIR=_build

%SPHINXBUILD% >NUL 2>NUL
if errorlevel 9009 (
echo.
echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
echo.installed, then set the SPHINXBUILD environment variable to point
echo.to the full path of the 'sphinx-build' executable. Alternatively you
echo.may add the Sphinx directory to PATH.
echo.
echo.If you don't have Sphinx installed, grab it from
echo.https://www.sphinx-doc.org/
exit /b 1
)

if "%1" == "" goto help

%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
goto end

:help
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%

:end
popd
4 changes: 4 additions & 0 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
sphinx==6.2.1
sphinx-autoapi==3.0.0
sphinx-mdinclude==0.5.3
furo==2023.9.10
2 changes: 1 addition & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ channels:
- conda-forge
- defaults
dependencies:
- python=3.8
- python=3.10
- portaudio=19.6.*
- pysoundfile=0.12.*
- ffmpeg[version='<4.4']
Expand Down
Binary file modified logo.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added pipeline.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ torch>=1.12.1
torchvision>=0.14.0
torchaudio>=2.0.2
pyannote.audio>=2.1.1
requests>=2.31.0
pyannote.core>=4.5
pyannote.database>=4.1.1
pyannote.metrics>=3.2
Expand Down
5 changes: 3 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
[metadata]
name=diart
version=0.8.0
version=0.9.0
author=Juan Manuel Coria
description=Streaming speaker diarization in real-time
description=A python framework to build AI for real-time speech
long_description=file: README.md
long_description_content_type=text/markdown
keywords=speaker diarization, streaming, online, real time, rxpy
Expand Down Expand Up @@ -32,6 +32,7 @@ install_requires=
torchvision>=0.14.0
torchaudio>=2.0.2
pyannote.audio>=2.1.1
requests>=2.31.0
pyannote.core>=4.5
pyannote.database>=4.1.1
pyannote.metrics>=3.2
Expand Down
1 change: 1 addition & 0 deletions src/diart/argdoc.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@
OUTPUT = "Directory to store the system's output in RTTM format"
HF_TOKEN = "Huggingface authentication token for hosted models ('true' | 'false' | <token>). If 'true', it will use the token from huggingface-cli login"
SAMPLE_RATE = "Sample rate of the audio stream"
NORMALIZE_EMBEDDING_WEIGHTS = "Rescale embedding weights (min-max normalization) to be in the range [0, 1]. This is useful in some models without weighted statistics pooling that rely on masking, like WeSpeaker or ECAPA-TDNN"
juanmc2005 marked this conversation as resolved.
Show resolved Hide resolved
42 changes: 42 additions & 0 deletions src/diart/blocks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,28 @@

@dataclass
class HyperParameter:
"""Represents a pipeline hyper-parameter that can be tuned by diart"""

name: Text
"""Name of the hyper-parameter (e.g. tau_active)"""
low: float
"""Lowest value that this parameter can take"""
high: float
"""Highest value that this parameter can take"""

@staticmethod
def from_name(name: Text) -> "HyperParameter":
"""Create a HyperParameter object given its name.

Parameters
----------
name: str
Name of the hyper-parameter

Returns
-------
HyperParameter
"""
if name == "tau_active":
return TauActive
if name == "rho_update":
Expand All @@ -32,24 +48,34 @@ def from_name(name: Text) -> "HyperParameter":


class PipelineConfig(ABC):
"""Configuration containing the required
parameters to build and run a pipeline"""

@property
@abstractmethod
def duration(self) -> float:
"""The duration of an input audio chunk (in seconds)"""
pass

@property
@abstractmethod
def step(self) -> float:
"""The step between two consecutive input audio chunks (in seconds)"""
pass

@property
@abstractmethod
def latency(self) -> float:
"""The algorithmic latency of the pipeline (in seconds).
At time `t` of the audio stream, the pipeline will
output predictions for time `t - latency`.
"""
pass

@property
@abstractmethod
def sample_rate(self) -> int:
"""The sample rate of the input audio stream"""
pass

def get_file_padding(self, filepath: FilePath) -> Tuple[float, float]:
Expand All @@ -60,6 +86,8 @@ def get_file_padding(self, filepath: FilePath) -> Tuple[float, float]:


class Pipeline(ABC):
"""Represents a streaming audio pipeline"""

@staticmethod
@abstractmethod
def get_config_class() -> type:
Expand Down Expand Up @@ -92,4 +120,18 @@ def set_timestamp_shift(self, shift: float):
def __call__(
self, waveforms: Sequence[SlidingWindowFeature]
) -> Sequence[Tuple[Any, SlidingWindowFeature]]:
"""Runs the next steps of the pipeline
given a list of consecutive audio chunks.

Parameters
----------
waveforms: Sequence[SlidingWindowFeature]
Consecutive chunk waveforms for the pipeline to ingest

Returns
-------
Sequence[Tuple[Any, SlidingWindowFeature]]
For each input waveform, a tuple containing
the pipeline output and its respective audio
"""
pass
4 changes: 4 additions & 0 deletions src/diart/blocks/clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,10 @@ def identify(
long_speakers = np.where(np.mean(segmentation.data, axis=0) >= self.rho_update)[
0
]
# Remove speakers that have NaN embeddings
no_nan_embeddings = np.where(~np.isnan(embeddings).any(axis=1))[0]
active_speakers = np.intersect1d(active_speakers, no_nan_embeddings)

num_local_speakers = segmentation.data.shape[1]

if self.centers is None:
Expand Down
32 changes: 20 additions & 12 deletions src/diart/blocks/diarization.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def __init__(
self,
segmentation: m.SegmentationModel | None = None,
embedding: m.EmbeddingModel | None = None,
duration: float | None = None,
duration: float = 5,
step: float = 0.5,
latency: float | Literal["max", "min"] | None = None,
tau_active: float = 0.6,
Expand All @@ -32,7 +32,9 @@ def __init__(
gamma: float = 3,
beta: float = 10,
max_speakers: int = 20,
normalize_embedding_weights: bool = False,
device: torch.device | None = None,
sample_rate: int = 16000,
**kwargs,
):
# Default segmentation model is pyannote/segmentation
Expand All @@ -46,7 +48,7 @@ def __init__(
)

self._duration = duration
self._sample_rate: int | None = None
self._sample_rate = sample_rate

# Latency defaults to the step duration
self._step = step
Expand All @@ -62,16 +64,13 @@ def __init__(
self.gamma = gamma
self.beta = beta
self.max_speakers = max_speakers

self.normalize_embedding_weights = normalize_embedding_weights
self.device = device or torch.device(
"cuda" if torch.cuda.is_available() else "cpu"
)

@property
def duration(self) -> float:
# Default duration is the one given by the segmentation model
if self._duration is None:
self._duration = self.segmentation.duration
return self._duration

@property
Expand All @@ -84,9 +83,6 @@ def latency(self) -> float:

@property
def sample_rate(self) -> int:
# Expected sample rate is given by the segmentation model
if self._sample_rate is None:
self._sample_rate = self.segmentation.sample_rate
return self._sample_rate


Expand All @@ -105,6 +101,7 @@ def __init__(self, config: SpeakerDiarizationConfig | None = None):
self._config.gamma,
self._config.beta,
norm=1,
normalize_weights=self._config.normalize_embedding_weights,
device=self._config.device,
)
self.pred_aggregation = DelayedAggregation(
Expand Down Expand Up @@ -160,6 +157,18 @@ def reset(self):
def __call__(
self, waveforms: Sequence[SlidingWindowFeature]
) -> Sequence[tuple[Annotation, SlidingWindowFeature]]:
"""Diarize the next audio chunks of an audio stream.

Parameters
----------
waveforms: Sequence[SlidingWindowFeature]
A sequence of consecutive audio chunks from an audio stream.

Returns
-------
Sequence[tuple[Annotation, SlidingWindowFeature]]
Speaker diarization of each chunk alongside their corresponding audio.
"""
batch_size = len(waveforms)
msg = "Pipeline expected at least 1 input"
assert batch_size >= 1, msg
Expand All @@ -175,9 +184,8 @@ def __call__(

# Extract segmentation and embeddings
segmentations = self.segmentation(batch) # shape (batch, frames, speakers)
embeddings = self.embedding(
batch, segmentations
) # shape (batch, speakers, emb_dim)
# embeddings has shape (batch, speakers, emb_dim)
embeddings = self.embedding(batch, segmentations)

seg_resolution = waveforms[0].extent.duration / segmentations.shape[1]

Expand Down
Loading