From 14910e191ce5b8db1dee2154b5c75463302466b8 Mon Sep 17 00:00:00 2001 From: Simon <80467011+sorgfresser@users.noreply.github.com> Date: Thu, 9 Nov 2023 18:25:22 +0100 Subject: [PATCH] Add compatibility with pyannote 3.0 embedding wrappers (#188) * bump pyannote to 3.0 * add wespeaker inference * add weights normalization, cpu for numpy conversion * unify api * remove try catch * always normalize * use PretrainedSpeakerEmbedding in Loader * Fix min-max normalization equation * fix: remove imports * Change embedding model return type to Callable Co-authored-by: Simon <80467011+sorgfresser@users.noreply.github.com> * fix: remove type checking * remove from active if NaN embeddings * Fix wrong typing of model in `LazyModel` * add docstrings * Simplify EmbeddingModel.__call__() * Add numpy import * add normalize boolean * Update requirements.txt * Update setup.cfg * Apply suggestions from code review * Fix wrong kwarg name * add abstract __call__ * move __call__ to parent class --------- Co-authored-by: Juan Coria --- requirements.txt | 1 + setup.cfg | 1 + src/diart/argdoc.py | 1 + src/diart/blocks/clustering.py | 4 ++ src/diart/blocks/diarization.py | 4 +- src/diart/blocks/embedding.py | 19 ++++++++-- src/diart/console/benchmark.py | 5 +++ src/diart/console/serve.py | 5 +++ src/diart/console/stream.py | 5 +++ src/diart/console/tune.py | 5 +++ src/diart/models.py | 66 +++++++++++++++++---------------- 11 files changed, 81 insertions(+), 35 deletions(-) diff --git a/requirements.txt b/requirements.txt index e0d93213..2d3e4611 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,6 +10,7 @@ torch>=1.12.1 torchvision>=0.14.0 torchaudio>=2.0.2 pyannote.audio>=2.1.1 +requests>=2.31.0 pyannote.core>=4.5 pyannote.database>=4.1.1 pyannote.metrics>=3.2 diff --git a/setup.cfg b/setup.cfg index f38a612e..314cb1e2 100644 --- a/setup.cfg +++ b/setup.cfg @@ -32,6 +32,7 @@ install_requires= torchvision>=0.14.0 torchaudio>=2.0.2 pyannote.audio>=2.1.1 + requests>=2.31.0 pyannote.core>=4.5 pyannote.database>=4.1.1 pyannote.metrics>=3.2 diff --git a/src/diart/argdoc.py b/src/diart/argdoc.py index e89caa28..ecc0206c 100644 --- a/src/diart/argdoc.py +++ b/src/diart/argdoc.py @@ -15,3 +15,4 @@ OUTPUT = "Directory to store the system's output in RTTM format" HF_TOKEN = "Huggingface authentication token for hosted models ('true' | 'false' | ). If 'true', it will use the token from huggingface-cli login" SAMPLE_RATE = "Sample rate of the audio stream" +NORMALIZE_EMBEDDING_WEIGHTS = "Rescale embedding weights (min-max normalization) to be in the range [0, 1]. This is useful in some models without weighted statistics pooling that rely on masking, like WeSpeaker or ECAPA-TDNN" diff --git a/src/diart/blocks/clustering.py b/src/diart/blocks/clustering.py index b7217c0a..860c1395 100644 --- a/src/diart/blocks/clustering.py +++ b/src/diart/blocks/clustering.py @@ -140,6 +140,10 @@ def identify( long_speakers = np.where(np.mean(segmentation.data, axis=0) >= self.rho_update)[ 0 ] + # Remove speakers that have NaN embeddings + no_nan_embeddings = np.where(~np.isnan(embeddings).any(axis=1))[0] + active_speakers = np.intersect1d(active_speakers, no_nan_embeddings) + num_local_speakers = segmentation.data.shape[1] if self.centers is None: diff --git a/src/diart/blocks/diarization.py b/src/diart/blocks/diarization.py index fab83c36..579782de 100644 --- a/src/diart/blocks/diarization.py +++ b/src/diart/blocks/diarization.py @@ -32,6 +32,7 @@ def __init__( gamma: float = 3, beta: float = 10, max_speakers: int = 20, + normalize_embedding_weights: bool = False, device: torch.device | None = None, **kwargs, ): @@ -62,7 +63,7 @@ def __init__( self.gamma = gamma self.beta = beta self.max_speakers = max_speakers - + self.normalize_embedding_weights = normalize_embedding_weights self.device = device or torch.device( "cuda" if torch.cuda.is_available() else "cpu" ) @@ -105,6 +106,7 @@ def __init__(self, config: SpeakerDiarizationConfig | None = None): self._config.gamma, self._config.beta, norm=1, + normalize_weights=self._config.normalize_embedding_weights, device=self._config.device, ) self.pred_aggregation = DelayedAggregation( diff --git a/src/diart/blocks/embedding.py b/src/diart/blocks/embedding.py index 5cd7c39e..d6905e2d 100644 --- a/src/diart/blocks/embedding.py +++ b/src/diart/blocks/embedding.py @@ -77,12 +77,16 @@ class OverlappedSpeechPenalty: beta: float, optional Temperature parameter (actually 1/beta) to lower joint speaker activations. Defaults to 10. + normalize: bool, optional + Whether to min-max normalize weights to be in the range [0, 1]. + Defaults to False. """ - def __init__(self, gamma: float = 3, beta: float = 10): + def __init__(self, gamma: float = 3, beta: float = 10, normalize: bool = False): self.gamma = gamma self.beta = beta self.formatter = TemporalFeatureFormatter() + self.normalize = normalize def __call__(self, segmentation: TemporalFeatures) -> TemporalFeatures: weights = self.formatter.cast(segmentation) # shape (batch, frames, speakers) @@ -90,6 +94,11 @@ def __call__(self, segmentation: TemporalFeatures) -> TemporalFeatures: probs = torch.softmax(self.beta * weights, dim=-1) weights = torch.pow(weights, self.gamma) * torch.pow(probs, self.gamma) weights[weights < 1e-8] = 1e-8 + if self.normalize: + min_values = weights.min(dim=1, keepdim=True).values + max_values = weights.max(dim=1, keepdim=True).values + weights = (weights - min_values) / (max_values - min_values) + weights.nan_to_num_(1e-8) return self.formatter.restore_type(weights) @@ -134,6 +143,8 @@ class OverlapAwareSpeakerEmbedding: norm: float or torch.Tensor of shape (batch, speakers, 1) where batch is optional The target norm for the embeddings. It can be different for each speaker. Defaults to 1. + normalize_weights: bool, optional + Whether to min-max normalize embedding weights to be in the range [0, 1]. device: Optional[torch.device] The device on which to run the embedding model. Defaults to GPU if available or CPU if not. @@ -145,10 +156,11 @@ def __init__( gamma: float = 3, beta: float = 10, norm: Union[float, torch.Tensor] = 1, + normalize_weights: bool = False, device: Optional[torch.device] = None, ): self.embedding = SpeakerEmbedding(model, device) - self.osp = OverlappedSpeechPenalty(gamma, beta) + self.osp = OverlappedSpeechPenalty(gamma, beta, normalize_weights) self.normalize = EmbeddingNormalization(norm) @staticmethod @@ -158,10 +170,11 @@ def from_pyannote( beta: float = 10, norm: Union[float, torch.Tensor] = 1, use_hf_token: Union[Text, bool, None] = True, + normalize_weights: bool = False, device: Optional[torch.device] = None, ): model = EmbeddingModel.from_pyannote(model, use_hf_token) - return OverlapAwareSpeakerEmbedding(model, gamma, beta, norm, device) + return OverlapAwareSpeakerEmbedding(model, gamma, beta, norm, normalize_weights, device) def __call__( self, waveform: TemporalFeatures, segmentation: TemporalFeatures diff --git a/src/diart/console/benchmark.py b/src/diart/console/benchmark.py index b5a296d1..e5808d14 100644 --- a/src/diart/console/benchmark.py +++ b/src/diart/console/benchmark.py @@ -99,6 +99,11 @@ def run(): type=str, help=f"{argdoc.HF_TOKEN}. Defaults to 'true' (required by pyannote)", ) + parser.add_argument( + "--normalize-embedding-weights", + action="store_true", + help=f"{argdoc.NORMALIZE_EMBEDDING_WEIGHTS}. Defaults to False", + ) args = parser.parse_args() # Resolve device diff --git a/src/diart/console/serve.py b/src/diart/console/serve.py index bc002e42..1b0645c3 100644 --- a/src/diart/console/serve.py +++ b/src/diart/console/serve.py @@ -80,6 +80,11 @@ def run(): type=str, help=f"{argdoc.HF_TOKEN}. Defaults to 'true' (required by pyannote)", ) + parser.add_argument( + "--normalize-embedding-weights", + action="store_true", + help=f"{argdoc.NORMALIZE_EMBEDDING_WEIGHTS}. Defaults to False", + ) args = parser.parse_args() # Resolve device diff --git a/src/diart/console/stream.py b/src/diart/console/stream.py index 713f3e99..527da88e 100644 --- a/src/diart/console/stream.py +++ b/src/diart/console/stream.py @@ -91,6 +91,11 @@ def run(): type=str, help=f"{argdoc.HF_TOKEN}. Defaults to 'true' (required by pyannote)", ) + parser.add_argument( + "--normalize-embedding-weights", + action="store_true", + help=f"{argdoc.NORMALIZE_EMBEDDING_WEIGHTS}. Defaults to False", + ) args = parser.parse_args() # Resolve device diff --git a/src/diart/console/tune.py b/src/diart/console/tune.py index ec243348..534c4b4b 100644 --- a/src/diart/console/tune.py +++ b/src/diart/console/tune.py @@ -108,6 +108,11 @@ def run(): type=str, help=f"{argdoc.HF_TOKEN}. Defaults to 'true' (required by pyannote)", ) + parser.add_argument( + "--normalize-embedding-weights", + action="store_true", + help=f"{argdoc.NORMALIZE_EMBEDDING_WEIGHTS}. Defaults to False", + ) args = parser.parse_args() # Resolve device diff --git a/src/diart/models.py b/src/diart/models.py index 5577a097..3e6d7a09 100644 --- a/src/diart/models.py +++ b/src/diart/models.py @@ -1,11 +1,16 @@ from abc import ABC, abstractmethod from typing import Optional, Text, Union, Callable +import numpy as np import torch import torch.nn as nn +from requests import HTTPError try: - import pyannote.audio.pipelines.utils as pyannote_loader + from pyannote.audio import Inference, Model + from pyannote.audio.pipelines.speaker_verification import ( + PretrainedSpeakerEmbedding, + ) _has_pyannote = True except ImportError: @@ -18,15 +23,20 @@ def __init__(self, model_info, hf_token: Union[Text, bool, None] = True): self.model_info = model_info self.hf_token = hf_token - def __call__(self) -> nn.Module: - return pyannote_loader.get_model(self.model_info, self.hf_token) + def __call__(self) -> Callable: + try: + return Model.from_pretrained(self.model_info, use_auth_token=self.hf_token) + except HTTPError: + return PretrainedSpeakerEmbedding( + self.model_info, use_auth_token=self.hf_token + ) -class LazyModel(nn.Module, ABC): - def __init__(self, loader: Callable[[], nn.Module]): +class LazyModel(ABC): + def __init__(self, loader: Callable[[], Callable]): super().__init__() self.get_model = loader - self.model: Optional[nn.Module] = None + self.model: Optional[Callable] = None def is_in_memory(self) -> bool: """Return whether the model has been loaded into memory""" @@ -36,13 +46,20 @@ def load(self): if not self.is_in_memory(): self.model = self.get_model() - def to(self, *args, **kwargs) -> nn.Module: + def to(self, device: torch.device) -> "LazyModel": self.load() - return super().to(*args, **kwargs) + self.model = self.model.to(device) + return self def __call__(self, *args, **kwargs): self.load() - return super().__call__(*args, **kwargs) + return self.model(*args, **kwargs) + + def eval(self) -> "LazyModel": + self.load() + if isinstance(self.model, nn.Module): + self.model.eval() + return self class SegmentationModel(LazyModel): @@ -83,20 +100,17 @@ def sample_rate(self) -> int: def duration(self) -> float: pass - @abstractmethod - def forward(self, waveform: torch.Tensor) -> torch.Tensor: + def __call__(self, waveform: torch.Tensor) -> torch.Tensor: """ - Forward pass of the segmentation model. - + Call the forward pass of the segmentation model. Parameters ---------- waveform: torch.Tensor, shape (batch, channels, samples) - Returns ------- speaker_segmentation: torch.Tensor, shape (batch, frames, speakers) """ - pass + return super().__call__(waveform) class PyannoteSegmentationModel(SegmentationModel): @@ -113,9 +127,6 @@ def duration(self) -> float: self.load() return self.model.specifications.duration - def forward(self, waveform: torch.Tensor) -> torch.Tensor: - return self.model(waveform) - class EmbeddingModel(LazyModel): """Minimal interface for an embedding model.""" @@ -143,33 +154,26 @@ def from_pyannote( assert _has_pyannote, "No pyannote.audio installation found" return PyannoteEmbeddingModel(model, use_hf_token) - @abstractmethod - def forward( + def __call__( self, waveform: torch.Tensor, weights: Optional[torch.Tensor] = None ) -> torch.Tensor: """ - Forward pass of an embedding model with optional weights. - + Call the forward pass of an embedding model with optional weights. Parameters ---------- waveform: torch.Tensor, shape (batch, channels, samples) weights: Optional[torch.Tensor], shape (batch, frames) Temporal weights for each sample in the batch. Defaults to no weights. - Returns ------- speaker_embeddings: torch.Tensor, shape (batch, embedding_dim) """ - pass + embeddings = super().__call__(waveform, weights) + if isinstance(embeddings, np.ndarray): + embeddings = torch.from_numpy(embeddings) + return embeddings class PyannoteEmbeddingModel(EmbeddingModel): def __init__(self, model_info, hf_token: Union[Text, bool, None] = True): super().__init__(PyannoteLoader(model_info, hf_token)) - - def forward( - self, - waveform: torch.Tensor, - weights: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - return self.model(waveform, weights=weights)