Skip to content

Commit

Permalink
Add compatibility with pyannote 3.0 embedding wrappers (#188)
Browse files Browse the repository at this point in the history
* bump pyannote to 3.0

* add wespeaker inference

* add weights normalization, cpu for numpy conversion

* unify api

* remove try catch

* always normalize

* use PretrainedSpeakerEmbedding in Loader

* Fix min-max normalization equation

* fix: remove imports

* Change embedding model return type to Callable

Co-authored-by: Simon <80467011+sorgfresser@users.noreply.github.com>

* fix: remove type checking

* remove from active if NaN embeddings

* Fix wrong typing of model in `LazyModel`

* add docstrings

* Simplify EmbeddingModel.__call__()

* Add numpy import

* add normalize boolean

* Update requirements.txt

* Update setup.cfg

* Apply suggestions from code review

* Fix wrong kwarg name

* add abstract __call__

* move __call__ to parent class

---------

Co-authored-by: Juan Coria <juanmc2005@hotmail.com>
  • Loading branch information
sorgfresser and juanmc2005 authored Nov 9, 2023
1 parent 0279f21 commit 14910e1
Show file tree
Hide file tree
Showing 11 changed files with 81 additions and 35 deletions.
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ torch>=1.12.1
torchvision>=0.14.0
torchaudio>=2.0.2
pyannote.audio>=2.1.1
requests>=2.31.0
pyannote.core>=4.5
pyannote.database>=4.1.1
pyannote.metrics>=3.2
Expand Down
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ install_requires=
torchvision>=0.14.0
torchaudio>=2.0.2
pyannote.audio>=2.1.1
requests>=2.31.0
pyannote.core>=4.5
pyannote.database>=4.1.1
pyannote.metrics>=3.2
Expand Down
1 change: 1 addition & 0 deletions src/diart/argdoc.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@
OUTPUT = "Directory to store the system's output in RTTM format"
HF_TOKEN = "Huggingface authentication token for hosted models ('true' | 'false' | <token>). If 'true', it will use the token from huggingface-cli login"
SAMPLE_RATE = "Sample rate of the audio stream"
NORMALIZE_EMBEDDING_WEIGHTS = "Rescale embedding weights (min-max normalization) to be in the range [0, 1]. This is useful in some models without weighted statistics pooling that rely on masking, like WeSpeaker or ECAPA-TDNN"
4 changes: 4 additions & 0 deletions src/diart/blocks/clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,10 @@ def identify(
long_speakers = np.where(np.mean(segmentation.data, axis=0) >= self.rho_update)[
0
]
# Remove speakers that have NaN embeddings
no_nan_embeddings = np.where(~np.isnan(embeddings).any(axis=1))[0]
active_speakers = np.intersect1d(active_speakers, no_nan_embeddings)

num_local_speakers = segmentation.data.shape[1]

if self.centers is None:
Expand Down
4 changes: 3 additions & 1 deletion src/diart/blocks/diarization.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def __init__(
gamma: float = 3,
beta: float = 10,
max_speakers: int = 20,
normalize_embedding_weights: bool = False,
device: torch.device | None = None,
**kwargs,
):
Expand Down Expand Up @@ -62,7 +63,7 @@ def __init__(
self.gamma = gamma
self.beta = beta
self.max_speakers = max_speakers

self.normalize_embedding_weights = normalize_embedding_weights
self.device = device or torch.device(
"cuda" if torch.cuda.is_available() else "cpu"
)
Expand Down Expand Up @@ -105,6 +106,7 @@ def __init__(self, config: SpeakerDiarizationConfig | None = None):
self._config.gamma,
self._config.beta,
norm=1,
normalize_weights=self._config.normalize_embedding_weights,
device=self._config.device,
)
self.pred_aggregation = DelayedAggregation(
Expand Down
19 changes: 16 additions & 3 deletions src/diart/blocks/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,19 +77,28 @@ class OverlappedSpeechPenalty:
beta: float, optional
Temperature parameter (actually 1/beta) to lower joint speaker activations.
Defaults to 10.
normalize: bool, optional
Whether to min-max normalize weights to be in the range [0, 1].
Defaults to False.
"""

def __init__(self, gamma: float = 3, beta: float = 10):
def __init__(self, gamma: float = 3, beta: float = 10, normalize: bool = False):
self.gamma = gamma
self.beta = beta
self.formatter = TemporalFeatureFormatter()
self.normalize = normalize

def __call__(self, segmentation: TemporalFeatures) -> TemporalFeatures:
weights = self.formatter.cast(segmentation) # shape (batch, frames, speakers)
with torch.no_grad():
probs = torch.softmax(self.beta * weights, dim=-1)
weights = torch.pow(weights, self.gamma) * torch.pow(probs, self.gamma)
weights[weights < 1e-8] = 1e-8
if self.normalize:
min_values = weights.min(dim=1, keepdim=True).values
max_values = weights.max(dim=1, keepdim=True).values
weights = (weights - min_values) / (max_values - min_values)
weights.nan_to_num_(1e-8)
return self.formatter.restore_type(weights)


Expand Down Expand Up @@ -134,6 +143,8 @@ class OverlapAwareSpeakerEmbedding:
norm: float or torch.Tensor of shape (batch, speakers, 1) where batch is optional
The target norm for the embeddings. It can be different for each speaker.
Defaults to 1.
normalize_weights: bool, optional
Whether to min-max normalize embedding weights to be in the range [0, 1].
device: Optional[torch.device]
The device on which to run the embedding model.
Defaults to GPU if available or CPU if not.
Expand All @@ -145,10 +156,11 @@ def __init__(
gamma: float = 3,
beta: float = 10,
norm: Union[float, torch.Tensor] = 1,
normalize_weights: bool = False,
device: Optional[torch.device] = None,
):
self.embedding = SpeakerEmbedding(model, device)
self.osp = OverlappedSpeechPenalty(gamma, beta)
self.osp = OverlappedSpeechPenalty(gamma, beta, normalize_weights)
self.normalize = EmbeddingNormalization(norm)

@staticmethod
Expand All @@ -158,10 +170,11 @@ def from_pyannote(
beta: float = 10,
norm: Union[float, torch.Tensor] = 1,
use_hf_token: Union[Text, bool, None] = True,
normalize_weights: bool = False,
device: Optional[torch.device] = None,
):
model = EmbeddingModel.from_pyannote(model, use_hf_token)
return OverlapAwareSpeakerEmbedding(model, gamma, beta, norm, device)
return OverlapAwareSpeakerEmbedding(model, gamma, beta, norm, normalize_weights, device)

def __call__(
self, waveform: TemporalFeatures, segmentation: TemporalFeatures
Expand Down
5 changes: 5 additions & 0 deletions src/diart/console/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,11 @@ def run():
type=str,
help=f"{argdoc.HF_TOKEN}. Defaults to 'true' (required by pyannote)",
)
parser.add_argument(
"--normalize-embedding-weights",
action="store_true",
help=f"{argdoc.NORMALIZE_EMBEDDING_WEIGHTS}. Defaults to False",
)
args = parser.parse_args()

# Resolve device
Expand Down
5 changes: 5 additions & 0 deletions src/diart/console/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,11 @@ def run():
type=str,
help=f"{argdoc.HF_TOKEN}. Defaults to 'true' (required by pyannote)",
)
parser.add_argument(
"--normalize-embedding-weights",
action="store_true",
help=f"{argdoc.NORMALIZE_EMBEDDING_WEIGHTS}. Defaults to False",
)
args = parser.parse_args()

# Resolve device
Expand Down
5 changes: 5 additions & 0 deletions src/diart/console/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,11 @@ def run():
type=str,
help=f"{argdoc.HF_TOKEN}. Defaults to 'true' (required by pyannote)",
)
parser.add_argument(
"--normalize-embedding-weights",
action="store_true",
help=f"{argdoc.NORMALIZE_EMBEDDING_WEIGHTS}. Defaults to False",
)
args = parser.parse_args()

# Resolve device
Expand Down
5 changes: 5 additions & 0 deletions src/diart/console/tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,11 @@ def run():
type=str,
help=f"{argdoc.HF_TOKEN}. Defaults to 'true' (required by pyannote)",
)
parser.add_argument(
"--normalize-embedding-weights",
action="store_true",
help=f"{argdoc.NORMALIZE_EMBEDDING_WEIGHTS}. Defaults to False",
)
args = parser.parse_args()

# Resolve device
Expand Down
66 changes: 35 additions & 31 deletions src/diart/models.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
from abc import ABC, abstractmethod
from typing import Optional, Text, Union, Callable

import numpy as np
import torch
import torch.nn as nn
from requests import HTTPError

try:
import pyannote.audio.pipelines.utils as pyannote_loader
from pyannote.audio import Inference, Model
from pyannote.audio.pipelines.speaker_verification import (
PretrainedSpeakerEmbedding,
)

_has_pyannote = True
except ImportError:
Expand All @@ -18,15 +23,20 @@ def __init__(self, model_info, hf_token: Union[Text, bool, None] = True):
self.model_info = model_info
self.hf_token = hf_token

def __call__(self) -> nn.Module:
return pyannote_loader.get_model(self.model_info, self.hf_token)
def __call__(self) -> Callable:
try:
return Model.from_pretrained(self.model_info, use_auth_token=self.hf_token)
except HTTPError:
return PretrainedSpeakerEmbedding(
self.model_info, use_auth_token=self.hf_token
)


class LazyModel(nn.Module, ABC):
def __init__(self, loader: Callable[[], nn.Module]):
class LazyModel(ABC):
def __init__(self, loader: Callable[[], Callable]):
super().__init__()
self.get_model = loader
self.model: Optional[nn.Module] = None
self.model: Optional[Callable] = None

def is_in_memory(self) -> bool:
"""Return whether the model has been loaded into memory"""
Expand All @@ -36,13 +46,20 @@ def load(self):
if not self.is_in_memory():
self.model = self.get_model()

def to(self, *args, **kwargs) -> nn.Module:
def to(self, device: torch.device) -> "LazyModel":
self.load()
return super().to(*args, **kwargs)
self.model = self.model.to(device)
return self

def __call__(self, *args, **kwargs):
self.load()
return super().__call__(*args, **kwargs)
return self.model(*args, **kwargs)

def eval(self) -> "LazyModel":
self.load()
if isinstance(self.model, nn.Module):
self.model.eval()
return self


class SegmentationModel(LazyModel):
Expand Down Expand Up @@ -83,20 +100,17 @@ def sample_rate(self) -> int:
def duration(self) -> float:
pass

@abstractmethod
def forward(self, waveform: torch.Tensor) -> torch.Tensor:
def __call__(self, waveform: torch.Tensor) -> torch.Tensor:
"""
Forward pass of the segmentation model.
Call the forward pass of the segmentation model.
Parameters
----------
waveform: torch.Tensor, shape (batch, channels, samples)
Returns
-------
speaker_segmentation: torch.Tensor, shape (batch, frames, speakers)
"""
pass
return super().__call__(waveform)


class PyannoteSegmentationModel(SegmentationModel):
Expand All @@ -113,9 +127,6 @@ def duration(self) -> float:
self.load()
return self.model.specifications.duration

def forward(self, waveform: torch.Tensor) -> torch.Tensor:
return self.model(waveform)


class EmbeddingModel(LazyModel):
"""Minimal interface for an embedding model."""
Expand Down Expand Up @@ -143,33 +154,26 @@ def from_pyannote(
assert _has_pyannote, "No pyannote.audio installation found"
return PyannoteEmbeddingModel(model, use_hf_token)

@abstractmethod
def forward(
def __call__(
self, waveform: torch.Tensor, weights: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""
Forward pass of an embedding model with optional weights.
Call the forward pass of an embedding model with optional weights.
Parameters
----------
waveform: torch.Tensor, shape (batch, channels, samples)
weights: Optional[torch.Tensor], shape (batch, frames)
Temporal weights for each sample in the batch. Defaults to no weights.
Returns
-------
speaker_embeddings: torch.Tensor, shape (batch, embedding_dim)
"""
pass
embeddings = super().__call__(waveform, weights)
if isinstance(embeddings, np.ndarray):
embeddings = torch.from_numpy(embeddings)
return embeddings


class PyannoteEmbeddingModel(EmbeddingModel):
def __init__(self, model_info, hf_token: Union[Text, bool, None] = True):
super().__init__(PyannoteLoader(model_info, hf_token))

def forward(
self,
waveform: torch.Tensor,
weights: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return self.model(waveform, weights=weights)

0 comments on commit 14910e1

Please sign in to comment.