From 3d6330eb2ebb30da73b123210145b05dba245041 Mon Sep 17 00:00:00 2001 From: Simon Sorg Date: Mon, 16 Oct 2023 20:42:33 +0200 Subject: [PATCH 01/23] bump pyannote to 3.0 --- requirements.txt | 2 +- setup.cfg | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index e0d93213..316a5b20 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,7 +9,7 @@ pandas>=1.4.2 torch>=1.12.1 torchvision>=0.14.0 torchaudio>=2.0.2 -pyannote.audio>=2.1.1 +pyannote.audio>=3.0.0 pyannote.core>=4.5 pyannote.database>=4.1.1 pyannote.metrics>=3.2 diff --git a/setup.cfg b/setup.cfg index f38a612e..1b17badf 100644 --- a/setup.cfg +++ b/setup.cfg @@ -31,7 +31,7 @@ install_requires= torch>=1.12.1 torchvision>=0.14.0 torchaudio>=2.0.2 - pyannote.audio>=2.1.1 + pyannote.audio>=3.0.0 pyannote.core>=4.5 pyannote.database>=4.1.1 pyannote.metrics>=3.2 From 4e96df3e7fdf321fb02f029238b36e4e0278c597 Mon Sep 17 00:00:00 2001 From: Simon Sorg Date: Mon, 16 Oct 2023 20:42:50 +0200 Subject: [PATCH 02/23] add wespeaker inference --- src/diart/blocks/diarization.py | 2 +- src/diart/models.py | 107 +++++++++++++++++++++++++++++--- 2 files changed, 101 insertions(+), 8 deletions(-) diff --git a/src/diart/blocks/diarization.py b/src/diart/blocks/diarization.py index fab83c36..b4c0ce1f 100644 --- a/src/diart/blocks/diarization.py +++ b/src/diart/blocks/diarization.py @@ -22,7 +22,7 @@ class SpeakerDiarizationConfig(base.PipelineConfig): def __init__( self, segmentation: m.SegmentationModel | None = None, - embedding: m.EmbeddingModel | None = None, + embedding: m.EmbeddingModel | m.WeSpeakerSpeakerEmbeddingInference | None = None, duration: float | None = None, step: float = 0.5, latency: float | Literal["max", "min"] | None = None, diff --git a/src/diart/models.py b/src/diart/models.py index 5577a097..4489d815 100644 --- a/src/diart/models.py +++ b/src/diart/models.py @@ -1,16 +1,21 @@ from abc import ABC, abstractmethod -from typing import Optional, Text, Union, Callable +from typing import Optional, Text, Union, Callable, Mapping, TYPE_CHECKING import torch import torch.nn as nn try: import pyannote.audio.pipelines.utils as pyannote_loader + from pyannote.audio import Inference, Model + from pyannote.audio.pipelines.speaker_verification import WeSpeakerPretrainedSpeakerEmbedding _has_pyannote = True except ImportError: _has_pyannote = False +if TYPE_CHECKING: + PipelineInference = Union[WeSpeakerPretrainedSpeakerEmbedding, Model, Text, Mapping] + class PyannoteLoader: def __init__(self, model_info, hf_token: Union[Text, bool, None] = True): @@ -22,6 +27,15 @@ def __call__(self) -> nn.Module: return pyannote_loader.get_model(self.model_info, self.hf_token) +class PyannoteWeSpeakerSpeakerEmbeddingLoader: + def __init__(self, inference_info): + super().__init__() + self.inference_info = inference_info + + def __call__(self) -> WeSpeakerPretrainedSpeakerEmbedding: + return WeSpeakerPretrainedSpeakerEmbedding(self.inference_info) + + class LazyModel(nn.Module, ABC): def __init__(self, loader: Callable[[], nn.Module]): super().__init__() @@ -45,6 +59,30 @@ def __call__(self, *args, **kwargs): return super().__call__(*args, **kwargs) +class LazyWeSpeakerSpeakerEmbedding(WeSpeakerPretrainedSpeakerEmbedding, ABC): + def __init__(self, loader: Callable[[], WeSpeakerPretrainedSpeakerEmbedding]): + self.get_inference = loader + self.inference: Optional[WeSpeakerPretrainedSpeakerEmbedding] = None + # __init__ at end because in there we call .to() which requires .load() + super().__init__() 
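A note on the pattern this patch introduces: every wrapper defers model construction until first use, so importing diart or building a pipeline config stays cheap. A minimal sketch of the idea — the class and attribute names here are illustrative, not part of the diff:

    from typing import Callable, Optional

    import torch.nn as nn

    class LazyProxy:
        """Hold a zero-argument loader; build the model on first use."""

        def __init__(self, loader: Callable[[], nn.Module]):
            self.loader = loader
            self.model: Optional[nn.Module] = None

        def load(self):
            # Idempotent: only the first call instantiates the model.
            if self.model is None:
                self.model = self.loader()

        def __call__(self, *args, **kwargs):
            self.load()
            return self.model(*args, **kwargs)

This is also why `__init__` above runs `super().__init__()` last: the parent constructor calls `.to()`, which only works once the loader has been stored.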
+ + def is_in_memory(self) -> bool: + """Return whether the model has been loaded into memory""" + return self.inference is not None + + def load(self): + if not self.is_in_memory(): + self.inference = self.get_inference() + + def to(self, *args, **kwargs) -> WeSpeakerPretrainedSpeakerEmbedding: + self.load() + return super().to(*args, **kwargs) + + def __call__(self, *args, **kwargs): + self.load() + return super().__call__(*args, **kwargs) + + class SegmentationModel(LazyModel): """ Minimal interface for a segmentation model. @@ -52,7 +90,7 @@ class SegmentationModel(LazyModel): @staticmethod def from_pyannote( - model, use_hf_token: Union[Text, bool, None] = True + model, use_hf_token: Union[Text, bool, None] = True ) -> "SegmentationModel": """ Returns a `SegmentationModel` wrapping a pyannote model. @@ -122,7 +160,7 @@ class EmbeddingModel(LazyModel): @staticmethod def from_pyannote( - model, use_hf_token: Union[Text, bool, None] = True + model, use_hf_token: Union[Text, bool, None] = True ) -> "EmbeddingModel": """ Returns an `EmbeddingModel` wrapping a pyannote model. @@ -145,7 +183,7 @@ def from_pyannote( @abstractmethod def forward( - self, waveform: torch.Tensor, weights: Optional[torch.Tensor] = None + self, waveform: torch.Tensor, weights: Optional[torch.Tensor] = None ) -> torch.Tensor: """ Forward pass of an embedding model with optional weights. @@ -168,8 +206,63 @@ def __init__(self, model_info, hf_token: Union[Text, bool, None] = True): super().__init__(PyannoteLoader(model_info, hf_token)) def forward( - self, - waveform: torch.Tensor, - weights: Optional[torch.Tensor] = None, + self, + waveform: torch.Tensor, + weights: Optional[torch.Tensor] = None, ) -> torch.Tensor: return self.model(waveform, weights=weights) + + +class WeSpeakerSpeakerEmbeddingInference(LazyWeSpeakerSpeakerEmbedding): + """Minimal interface for a we speaker embedding inference.""" + + def __init__(self, loader: Callable[[], WeSpeakerPretrainedSpeakerEmbedding]): + super().__init__(loader) + + @staticmethod + def from_pyannote(inference, + ) -> "WeSpeakerSpeakerEmbeddingInference": + """ + Returns an `EmbeddingModel` wrapping a pyannote model. + + Parameters + ---------- + inference: pyannote.audio.pipelines.speaker_verification.WeSpeakerPretrainedSpeakerEmbedding + The pyannote.audio inference to fetch. + + Returns + ------- + wrapper: EmbeddingModel + """ + assert _has_pyannote, "No pyannote.audio installation found" + return PyannoteWeSpeakerSpeakerEmbeddingInference(inference) + + @abstractmethod + def forward( + self, waveform: torch.Tensor, weights: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """ + Forward pass of an embedding model with optional weights. + + Parameters + ---------- + waveform: torch.Tensor, shape (batch, channels, samples) + weights: Optional[torch.Tensor], shape (batch, frames) + Temporal weights for each sample in the batch. Defaults to no weights. 
+ + Returns + ------- + speaker_embeddings: torch.Tensor, shape (batch, embedding_dim) + """ + pass + + +class PyannoteWeSpeakerSpeakerEmbeddingInference(WeSpeakerSpeakerEmbeddingInference): + def __init__(self, wespeaker_info): + super().__init__(PyannoteWeSpeakerSpeakerEmbeddingLoader(wespeaker_info)) + + def forward( + self, waveform: torch.Tensor, weights: Optional[torch.Tensor] = None + ) -> torch.Tensor: + self.load() + return torch.from_numpy(self.inference(waveform, weights)) From 4c5a66e455ecd81810220379a06e452934578ce4 Mon Sep 17 00:00:00 2001 From: Simon Sorg Date: Wed, 18 Oct 2023 10:20:05 +0200 Subject: [PATCH 03/23] add weights normalization, cpu for numpy conversion --- src/diart/models.py | 30 +++++++++++------------------- 1 file changed, 11 insertions(+), 19 deletions(-) diff --git a/src/diart/models.py b/src/diart/models.py index 4489d815..9b613c64 100644 --- a/src/diart/models.py +++ b/src/diart/models.py @@ -237,32 +237,24 @@ def from_pyannote(inference, assert _has_pyannote, "No pyannote.audio installation found" return PyannoteWeSpeakerSpeakerEmbeddingInference(inference) - @abstractmethod - def forward( - self, waveform: torch.Tensor, weights: Optional[torch.Tensor] = None - ) -> torch.Tensor: - """ - Forward pass of an embedding model with optional weights. - - Parameters - ---------- - waveform: torch.Tensor, shape (batch, channels, samples) - weights: Optional[torch.Tensor], shape (batch, frames) - Temporal weights for each sample in the batch. Defaults to no weights. - - Returns - ------- - speaker_embeddings: torch.Tensor, shape (batch, embedding_dim) - """ - pass + def eval(self): + """Mock method to match pytorch api""" + return class PyannoteWeSpeakerSpeakerEmbeddingInference(WeSpeakerSpeakerEmbeddingInference): def __init__(self, wespeaker_info): super().__init__(PyannoteWeSpeakerSpeakerEmbeddingLoader(wespeaker_info)) - def forward( + def __call__( self, waveform: torch.Tensor, weights: Optional[torch.Tensor] = None ) -> torch.Tensor: self.load() + # Normalize weights + weights -= weights.min(dim=1, keepdim=True).values + weights /= weights.max(dim=1, keepdim=True).values + weights.nan_to_num_(0.0) + # Move to cpu for numpy conversion + weights = weights.to("cpu") + waveform = waveform.to("cpu") return torch.from_numpy(self.inference(waveform, weights)) From c89152ff9a499ccdebb214e12f676f4729cb1b21 Mon Sep 17 00:00:00 2001 From: Simon Sorg Date: Fri, 27 Oct 2023 13:26:01 +0200 Subject: [PATCH 04/23] unify api --- src/diart/blocks/diarization.py | 2 +- src/diart/mapping.py | 5 +- src/diart/models.py | 146 ++++++++------------------------ 3 files changed, 41 insertions(+), 112 deletions(-) diff --git a/src/diart/blocks/diarization.py b/src/diart/blocks/diarization.py index b4c0ce1f..fab83c36 100644 --- a/src/diart/blocks/diarization.py +++ b/src/diart/blocks/diarization.py @@ -22,7 +22,7 @@ class SpeakerDiarizationConfig(base.PipelineConfig): def __init__( self, segmentation: m.SegmentationModel | None = None, - embedding: m.EmbeddingModel | m.WeSpeakerSpeakerEmbeddingInference | None = None, + embedding: m.EmbeddingModel | None = None, duration: float | None = None, step: float = 0.5, latency: float | Literal["max", "min"] | None = None, diff --git a/src/diart/mapping.py b/src/diart/mapping.py index 847d0f78..ddf0b0d1 100644 --- a/src/diart/mapping.py +++ b/src/diart/mapping.py @@ -13,7 +13,10 @@ def invalid_tensor(self, shape: Union[Tuple, int]) -> np.ndarray: return np.ones(shape) * self.invalid_value def optimal_assignments(self, matrix: np.ndarray) -> 
List[int]: - return list(lsap(matrix, self.maximize)[1]) + try: + return list(lsap(matrix, self.maximize)[1]) + except ValueError: + print(matrix) def mapped_indices(self, matrix: np.ndarray, axis: int) -> List[int]: # Entries full of invalid_value are not mapped diff --git a/src/diart/models.py b/src/diart/models.py index 9b613c64..97f73e6d 100644 --- a/src/diart/models.py +++ b/src/diart/models.py @@ -3,18 +3,23 @@ import torch import torch.nn as nn +from requests import HTTPError try: import pyannote.audio.pipelines.utils as pyannote_loader from pyannote.audio import Inference, Model - from pyannote.audio.pipelines.speaker_verification import WeSpeakerPretrainedSpeakerEmbedding + from pyannote.audio.pipelines.speaker_verification import ( + WeSpeakerPretrainedSpeakerEmbedding, + ) _has_pyannote = True except ImportError: _has_pyannote = False if TYPE_CHECKING: - PipelineInference = Union[WeSpeakerPretrainedSpeakerEmbedding, Model, Text, Mapping] + from pyannote.audio.pipelines.speaker_verification import ( + WeSpeakerPretrainedSpeakerEmbedding, + ) class PyannoteLoader: @@ -23,20 +28,14 @@ def __init__(self, model_info, hf_token: Union[Text, bool, None] = True): self.model_info = model_info self.hf_token = hf_token - def __call__(self) -> nn.Module: - return pyannote_loader.get_model(self.model_info, self.hf_token) + def __call__(self) -> Union[nn.Module, WeSpeakerPretrainedSpeakerEmbedding]: + try: + return pyannote_loader.get_model(self.model_info) + except HTTPError: + return WeSpeakerPretrainedSpeakerEmbedding(self.model_info) -class PyannoteWeSpeakerSpeakerEmbeddingLoader: - def __init__(self, inference_info): - super().__init__() - self.inference_info = inference_info - - def __call__(self) -> WeSpeakerPretrainedSpeakerEmbedding: - return WeSpeakerPretrainedSpeakerEmbedding(self.inference_info) - - -class LazyModel(nn.Module, ABC): +class LazyModel(ABC): def __init__(self, loader: Callable[[], nn.Module]): super().__init__() self.get_model = loader @@ -52,35 +51,17 @@ def load(self): def to(self, *args, **kwargs) -> nn.Module: self.load() - return super().to(*args, **kwargs) + return self.model.to(*args, **kwargs) def __call__(self, *args, **kwargs): self.load() - return super().__call__(*args, **kwargs) - - -class LazyWeSpeakerSpeakerEmbedding(WeSpeakerPretrainedSpeakerEmbedding, ABC): - def __init__(self, loader: Callable[[], WeSpeakerPretrainedSpeakerEmbedding]): - self.get_inference = loader - self.inference: Optional[WeSpeakerPretrainedSpeakerEmbedding] = None - # __init__ at end because in there we call .to() which requires .load() - super().__init__() - - def is_in_memory(self) -> bool: - """Return whether the model has been loaded into memory""" - return self.inference is not None - - def load(self): - if not self.is_in_memory(): - self.inference = self.get_inference() - - def to(self, *args, **kwargs) -> WeSpeakerPretrainedSpeakerEmbedding: - self.load() - return super().to(*args, **kwargs) + return self.model(*args, **kwargs) - def __call__(self, *args, **kwargs): + def eval(self) -> "LazyModel": self.load() - return super().__call__(*args, **kwargs) + if not isinstance(self.model, WeSpeakerPretrainedSpeakerEmbedding): + self.model.eval() + return self class SegmentationModel(LazyModel): @@ -90,7 +71,7 @@ class SegmentationModel(LazyModel): @staticmethod def from_pyannote( - model, use_hf_token: Union[Text, bool, None] = True + model, use_hf_token: Union[Text, bool, None] = True ) -> "SegmentationModel": """ Returns a `SegmentationModel` wrapping a pyannote model. 
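The loader change at the top of this patch is the crux of the unified API: `PyannoteLoader.__call__` now resolves a model identifier to a regular pyannote model and, when the hub lookup raises `HTTPError`, falls back to the WeSpeaker wrapper. Call sites no longer need to know which backend they got. Roughly, under the imports shown above (the model identifiers are illustrative):

    from requests import HTTPError

    def resolve(model_info):
        # Mirrors PyannoteLoader.__call__: try the hub first, then WeSpeaker.
        try:
            return pyannote_loader.get_model(model_info)
        except HTTPError:
            return WeSpeakerPretrainedSpeakerEmbedding(model_info)

    segmentation = resolve("pyannote/segmentation")                 # nn.Module
    embedding = resolve("pyannote/wespeaker-voxceleb-resnet34-LM")  # WeSpeaker wrapper

Downstream, `LazyModel.eval()` forwards to the underlying model only when it is not the WeSpeaker wrapper, which is exactly what the new isinstance check encodes.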
@@ -160,7 +141,7 @@ class EmbeddingModel(LazyModel): @staticmethod def from_pyannote( - model, use_hf_token: Union[Text, bool, None] = True + model, use_hf_token: Union[Text, bool, None] = True ) -> "EmbeddingModel": """ Returns an `EmbeddingModel` wrapping a pyannote model. @@ -181,80 +162,25 @@ def from_pyannote( assert _has_pyannote, "No pyannote.audio installation found" return PyannoteEmbeddingModel(model, use_hf_token) - @abstractmethod - def forward( - self, waveform: torch.Tensor, weights: Optional[torch.Tensor] = None - ) -> torch.Tensor: - """ - Forward pass of an embedding model with optional weights. - - Parameters - ---------- - waveform: torch.Tensor, shape (batch, channels, samples) - weights: Optional[torch.Tensor], shape (batch, frames) - Temporal weights for each sample in the batch. Defaults to no weights. - - Returns - ------- - speaker_embeddings: torch.Tensor, shape (batch, embedding_dim) - """ - pass - class PyannoteEmbeddingModel(EmbeddingModel): def __init__(self, model_info, hf_token: Union[Text, bool, None] = True): super().__init__(PyannoteLoader(model_info, hf_token)) - def forward( - self, - waveform: torch.Tensor, - weights: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - return self.model(waveform, weights=weights) - - -class WeSpeakerSpeakerEmbeddingInference(LazyWeSpeakerSpeakerEmbedding): - """Minimal interface for a we speaker embedding inference.""" - - def __init__(self, loader: Callable[[], WeSpeakerPretrainedSpeakerEmbedding]): - super().__init__(loader) - - @staticmethod - def from_pyannote(inference, - ) -> "WeSpeakerSpeakerEmbeddingInference": - """ - Returns an `EmbeddingModel` wrapping a pyannote model. - - Parameters - ---------- - inference: pyannote.audio.pipelines.speaker_verification.WeSpeakerPretrainedSpeakerEmbedding - The pyannote.audio inference to fetch. 
- - Returns - ------- - wrapper: EmbeddingModel - """ - assert _has_pyannote, "No pyannote.audio installation found" - return PyannoteWeSpeakerSpeakerEmbeddingInference(inference) - - def eval(self): - """Mock method to match pytorch api""" - return - - -class PyannoteWeSpeakerSpeakerEmbeddingInference(WeSpeakerSpeakerEmbeddingInference): - def __init__(self, wespeaker_info): - super().__init__(PyannoteWeSpeakerSpeakerEmbeddingLoader(wespeaker_info)) - def __call__( - self, waveform: torch.Tensor, weights: Optional[torch.Tensor] = None + self, waveform: torch.Tensor, weights: Optional[torch.Tensor] = None ) -> torch.Tensor: - self.load() - # Normalize weights - weights -= weights.min(dim=1, keepdim=True).values - weights /= weights.max(dim=1, keepdim=True).values - weights.nan_to_num_(0.0) - # Move to cpu for numpy conversion - weights = weights.to("cpu") - waveform = waveform.to("cpu") - return torch.from_numpy(self.inference(waveform, weights)) + if not isinstance(self.model, WeSpeakerPretrainedSpeakerEmbedding): + return super().__call__(waveform, weights) + else: + self.load() + # Normalize weights + if weights is not None: + weights -= weights.min(dim=1, keepdim=True).values + weights /= weights.max(dim=1, keepdim=True).values + weights.nan_to_num_(0.0) + # Move to cpu for numpy conversion + weights = weights.to("cpu") + # Move to cpu for numpy conversion + waveform = waveform.to("cpu") + return torch.from_numpy(self.model(waveform, weights)) From ccd7d12a45dd46721b1be45be69e176897a523a1 Mon Sep 17 00:00:00 2001 From: Simon Sorg Date: Fri, 27 Oct 2023 18:43:30 +0200 Subject: [PATCH 05/23] remove try catch --- src/diart/mapping.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/diart/mapping.py b/src/diart/mapping.py index ddf0b0d1..847d0f78 100644 --- a/src/diart/mapping.py +++ b/src/diart/mapping.py @@ -13,10 +13,7 @@ def invalid_tensor(self, shape: Union[Tuple, int]) -> np.ndarray: return np.ones(shape) * self.invalid_value def optimal_assignments(self, matrix: np.ndarray) -> List[int]: - try: - return list(lsap(matrix, self.maximize)[1]) - except ValueError: - print(matrix) + return list(lsap(matrix, self.maximize)[1]) def mapped_indices(self, matrix: np.ndarray, axis: int) -> List[int]: # Entries full of invalid_value are not mapped From aa2e03b56792b1a7db72ee9db002ee2011e455c0 Mon Sep 17 00:00:00 2001 From: Simon Sorg Date: Fri, 27 Oct 2023 18:59:36 +0200 Subject: [PATCH 06/23] always normalize --- src/diart/models.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/src/diart/models.py b/src/diart/models.py index 97f73e6d..01875a44 100644 --- a/src/diart/models.py +++ b/src/diart/models.py @@ -59,7 +59,7 @@ def __call__(self, *args, **kwargs): def eval(self) -> "LazyModel": self.load() - if not isinstance(self.model, WeSpeakerPretrainedSpeakerEmbedding): + if isinstance(self.model, nn.Module): self.model.eval() return self @@ -170,17 +170,19 @@ def __init__(self, model_info, hf_token: Union[Text, bool, None] = True): def __call__( self, waveform: torch.Tensor, weights: Optional[torch.Tensor] = None ) -> torch.Tensor: - if not isinstance(self.model, WeSpeakerPretrainedSpeakerEmbedding): + # Normalize weights + if weights is not None: + weights -= weights.min(dim=1, keepdim=True).values + weights /= weights.max(dim=1, keepdim=True).values + weights.nan_to_num_(0.0) + + if isinstance(self.model, nn.Module): return super().__call__(waveform, weights) + else: - self.load() - # Normalize weights if weights is not None: - 
weights -= weights.min(dim=1, keepdim=True).values - weights /= weights.max(dim=1, keepdim=True).values - weights.nan_to_num_(0.0) # Move to cpu for numpy conversion weights = weights.to("cpu") # Move to cpu for numpy conversion waveform = waveform.to("cpu") - return torch.from_numpy(self.model(waveform, weights)) + return torch.from_numpy(super().__call__(waveform, weights)) From d0788bbc043e8c1f6431b8d20fc39e4ff67d88d9 Mon Sep 17 00:00:00 2001 From: Simon Sorg Date: Fri, 27 Oct 2023 19:14:29 +0200 Subject: [PATCH 07/23] use PretrainedSpeakerEmbedding in Loader --- src/diart/models.py | 30 +++++++++--------------------- 1 file changed, 9 insertions(+), 21 deletions(-) diff --git a/src/diart/models.py b/src/diart/models.py index 01875a44..78e3d249 100644 --- a/src/diart/models.py +++ b/src/diart/models.py @@ -10,6 +10,7 @@ from pyannote.audio import Inference, Model from pyannote.audio.pipelines.speaker_verification import ( WeSpeakerPretrainedSpeakerEmbedding, + PretrainedSpeakerEmbedding, ) _has_pyannote = True @@ -28,15 +29,17 @@ def __init__(self, model_info, hf_token: Union[Text, bool, None] = True): self.model_info = model_info self.hf_token = hf_token - def __call__(self) -> Union[nn.Module, WeSpeakerPretrainedSpeakerEmbedding]: + def __call__(self) -> Union[Model, PretrainedSpeakerEmbedding]: try: - return pyannote_loader.get_model(self.model_info) + return Model.from_pretrained(self.model_info, use_auth_token=self.hf_token) except HTTPError: - return WeSpeakerPretrainedSpeakerEmbedding(self.model_info) + return PretrainedSpeakerEmbedding( + self.model_info, use_auth_token=self.hf_token + ) class LazyModel(ABC): - def __init__(self, loader: Callable[[], nn.Module]): + def __init__(self, loader: Callable[[], Callable]): super().__init__() self.get_model = loader self.model: Optional[nn.Module] = None @@ -102,21 +105,6 @@ def sample_rate(self) -> int: def duration(self) -> float: pass - @abstractmethod - def forward(self, waveform: torch.Tensor) -> torch.Tensor: - """ - Forward pass of the segmentation model. 
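For reference, the per-row min-max rescaling used in these patches maps every weight row onto [0, 1]; a constant row produces 0/0 = NaN, which `nan_to_num` then zeroes. A small worked example with invented values:

    import torch

    w = torch.tensor([[0.2, 0.5, 0.8],
                      [0.4, 0.4, 0.4]])  # second row is constant
    mins = w.min(dim=1, keepdim=True).values
    maxs = w.max(dim=1, keepdim=True).values
    normed = ((w - mins) / (maxs - mins)).nan_to_num(0.0)
    # -> [[0.0, 0.5, 1.0],
    #     [0.0, 0.0, 0.0]]

Patch 08 below rewrites the in-place variant above into exactly this out-of-place form, computing min and max before any subtraction.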
- - Parameters - ---------- - waveform: torch.Tensor, shape (batch, channels, samples) - - Returns - ------- - speaker_segmentation: torch.Tensor, shape (batch, frames, speakers) - """ - pass - class PyannoteSegmentationModel(SegmentationModel): def __init__(self, model_info, hf_token: Union[Text, bool, None] = True): @@ -132,8 +120,8 @@ def duration(self) -> float: self.load() return self.model.specifications.duration - def forward(self, waveform: torch.Tensor) -> torch.Tensor: - return self.model(waveform) + def __call__(self, waveform: torch.Tensor) -> torch.Tensor: + return super().__call__(waveform) class EmbeddingModel(LazyModel): From 714eae366269965e092ad6d54f81c5438d50dbe3 Mon Sep 17 00:00:00 2001 From: Juan Coria Date: Sat, 28 Oct 2023 14:06:13 +0200 Subject: [PATCH 08/23] Fix min-max normalization equation --- src/diart/models.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/diart/models.py b/src/diart/models.py index 78e3d249..be774b44 100644 --- a/src/diart/models.py +++ b/src/diart/models.py @@ -160,8 +160,9 @@ def __call__( ) -> torch.Tensor: # Normalize weights if weights is not None: - weights -= weights.min(dim=1, keepdim=True).values - weights /= weights.max(dim=1, keepdim=True).values + min_values = weights.min(dim=1, keepdim=True).values + max_values = weights.max(dim=1, keepdim=True).values + weights = (weights - min_values) / (max_values - min_values) weights.nan_to_num_(0.0) if isinstance(self.model, nn.Module): From 58293a8292c9cec84af0d466c0b2d8fcfcbae66c Mon Sep 17 00:00:00 2001 From: Simon Sorg Date: Mon, 30 Oct 2023 14:51:40 +0100 Subject: [PATCH 09/23] fix: remove imports --- src/diart/models.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/diart/models.py b/src/diart/models.py index be774b44..7df9ac09 100644 --- a/src/diart/models.py +++ b/src/diart/models.py @@ -1,15 +1,13 @@ from abc import ABC, abstractmethod -from typing import Optional, Text, Union, Callable, Mapping, TYPE_CHECKING +from typing import Optional, Text, Union, Callable, TYPE_CHECKING import torch import torch.nn as nn from requests import HTTPError try: - import pyannote.audio.pipelines.utils as pyannote_loader from pyannote.audio import Inference, Model from pyannote.audio.pipelines.speaker_verification import ( - WeSpeakerPretrainedSpeakerEmbedding, PretrainedSpeakerEmbedding, ) @@ -18,8 +16,9 @@ _has_pyannote = False if TYPE_CHECKING: + from pyannote.audio import Model from pyannote.audio.pipelines.speaker_verification import ( - WeSpeakerPretrainedSpeakerEmbedding, + PretrainedSpeakerEmbedding, ) From d8b2c4ff55c50aacf0ff0bef73103e16b2425d2f Mon Sep 17 00:00:00 2001 From: Juan Coria Date: Mon, 30 Oct 2023 15:42:14 +0100 Subject: [PATCH 10/23] Change embedding model return type to Callable Co-authored-by: Simon <80467011+sorgfresser@users.noreply.github.com> --- src/diart/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diart/models.py b/src/diart/models.py index 7df9ac09..fec892d4 100644 --- a/src/diart/models.py +++ b/src/diart/models.py @@ -28,7 +28,7 @@ def __init__(self, model_info, hf_token: Union[Text, bool, None] = True): self.model_info = model_info self.hf_token = hf_token - def __call__(self) -> Union[Model, PretrainedSpeakerEmbedding]: + def __call__(self) -> Callable: try: return Model.from_pretrained(self.model_info, use_auth_token=self.hf_token) except HTTPError: From 1ea830969aa36787fa516949b1429f9ccef3632b Mon Sep 17 00:00:00 2001 From: Simon Sorg Date: Mon, 30 Oct 2023 15:53:16 
+0100 Subject: [PATCH 11/23] fix: remove type checking --- src/diart/models.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/src/diart/models.py b/src/diart/models.py index fec892d4..56fcdeed 100644 --- a/src/diart/models.py +++ b/src/diart/models.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Optional, Text, Union, Callable, TYPE_CHECKING +from typing import Optional, Text, Union, Callable import torch import torch.nn as nn @@ -15,11 +15,6 @@ except ImportError: _has_pyannote = False -if TYPE_CHECKING: - from pyannote.audio import Model - from pyannote.audio.pipelines.speaker_verification import ( - PretrainedSpeakerEmbedding, - ) class PyannoteLoader: From 653984cb7a099412d00b1c6607fcd83cd073f704 Mon Sep 17 00:00:00 2001 From: Simon Sorg Date: Mon, 30 Oct 2023 17:06:35 +0100 Subject: [PATCH 12/23] remove from active if NaN embeddings --- src/diart/blocks/clustering.py | 4 ++++ src/diart/models.py | 1 - 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/src/diart/blocks/clustering.py b/src/diart/blocks/clustering.py index b7217c0a..860c1395 100644 --- a/src/diart/blocks/clustering.py +++ b/src/diart/blocks/clustering.py @@ -140,6 +140,10 @@ def identify( long_speakers = np.where(np.mean(segmentation.data, axis=0) >= self.rho_update)[ 0 ] + # Remove speakers that have NaN embeddings + no_nan_embeddings = np.where(~np.isnan(embeddings).any(axis=1))[0] + active_speakers = np.intersect1d(active_speakers, no_nan_embeddings) + num_local_speakers = segmentation.data.shape[1] if self.centers is None: diff --git a/src/diart/models.py b/src/diart/models.py index 56fcdeed..91b8b8b1 100644 --- a/src/diart/models.py +++ b/src/diart/models.py @@ -16,7 +16,6 @@ _has_pyannote = False - class PyannoteLoader: def __init__(self, model_info, hf_token: Union[Text, bool, None] = True): super().__init__() From 539e401ee428a5164de4f5a60996c8a6a9b7def6 Mon Sep 17 00:00:00 2001 From: Juan Coria Date: Mon, 30 Oct 2023 17:29:01 +0100 Subject: [PATCH 13/23] Fix wrong typing of model in `LazyModel` --- src/diart/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diart/models.py b/src/diart/models.py index 91b8b8b1..bc994c59 100644 --- a/src/diart/models.py +++ b/src/diart/models.py @@ -35,7 +35,7 @@ class LazyModel(ABC): def __init__(self, loader: Callable[[], Callable]): super().__init__() self.get_model = loader - self.model: Optional[nn.Module] = None + self.model: Optional[Callable] = None def is_in_memory(self) -> bool: """Return whether the model has been loaded into memory""" From e3c6eef2059f2b41e9a5ee0001cd7fcd2c19f032 Mon Sep 17 00:00:00 2001 From: Simon Sorg Date: Fri, 3 Nov 2023 17:18:09 +0100 Subject: [PATCH 14/23] add docstrings --- src/diart/models.py | 25 ++++++++++++++++++++----- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/src/diart/models.py b/src/diart/models.py index bc994c59..4a670ff8 100644 --- a/src/diart/models.py +++ b/src/diart/models.py @@ -114,6 +114,15 @@ def duration(self) -> float: return self.model.specifications.duration def __call__(self, waveform: torch.Tensor) -> torch.Tensor: + """ + Call the forward pass of the segmentation model. 
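Worth spelling out the NaN guard that patch 12 adds to `identify()`: any local speaker whose embedding contains a NaN is dropped from the active set before assignment, so NaNs never reach the cost matrices handled by `mapping.py`. A standalone illustration with invented arrays:

    import numpy as np

    embeddings = np.array([[0.1, 0.2],
                           [np.nan, 0.3],
                           [0.5, 0.6]])
    active_speakers = np.array([0, 1, 2])
    no_nan = np.where(~np.isnan(embeddings).any(axis=1))[0]    # -> [0, 2]
    active_speakers = np.intersect1d(active_speakers, no_nan)  # speaker 1 dropped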
+ Parameters + ---------- + waveform: torch.Tensor, shape (batch, channels, samples) + Returns + ------- + speaker_segmentation: torch.Tensor, shape (batch, frames, speakers) + """ return super().__call__(waveform) @@ -151,6 +160,17 @@ def __init__(self, model_info, hf_token: Union[Text, bool, None] = True): def __call__( self, waveform: torch.Tensor, weights: Optional[torch.Tensor] = None ) -> torch.Tensor: + """ + Call the forward pass of an embedding model with optional weights. + Parameters + ---------- + waveform: torch.Tensor, shape (batch, channels, samples) + weights: Optional[torch.Tensor], shape (batch, frames) + Temporal weights for each sample in the batch. Defaults to no weights. + Returns + ------- + speaker_embeddings: torch.Tensor, shape (batch, embedding_dim) + """ # Normalize weights if weights is not None: min_values = weights.min(dim=1, keepdim=True).values @@ -162,9 +182,4 @@ def __call__( return super().__call__(waveform, weights) else: - if weights is not None: - # Move to cpu for numpy conversion - weights = weights.to("cpu") - # Move to cpu for numpy conversion - waveform = waveform.to("cpu") return torch.from_numpy(super().__call__(waveform, weights)) From f57f9d2a45c4dc8d32003f6a8b3718bb8be8e946 Mon Sep 17 00:00:00 2001 From: Juan Coria Date: Fri, 3 Nov 2023 21:02:01 +0100 Subject: [PATCH 15/23] Simplify EmbeddingModel.__call__() --- src/diart/models.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/diart/models.py b/src/diart/models.py index 4a670ff8..6542d33f 100644 --- a/src/diart/models.py +++ b/src/diart/models.py @@ -178,8 +178,7 @@ def __call__( weights = (weights - min_values) / (max_values - min_values) weights.nan_to_num_(0.0) - if isinstance(self.model, nn.Module): - return super().__call__(waveform, weights) - - else: - return torch.from_numpy(super().__call__(waveform, weights)) + embeddings = super().__call__(waveform, weights) + if isinstance(embeddings, np.ndarray): + embeddings = torch.from_numpy(embeddings) + return embeddings From fd9f8e64676538c12f21ac6064a2eaf9820defcb Mon Sep 17 00:00:00 2001 From: Juan Coria Date: Fri, 3 Nov 2023 21:02:58 +0100 Subject: [PATCH 16/23] Add numpy import --- src/diart/models.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diart/models.py b/src/diart/models.py index 6542d33f..d11c0c08 100644 --- a/src/diart/models.py +++ b/src/diart/models.py @@ -1,6 +1,7 @@ from abc import ABC, abstractmethod from typing import Optional, Text, Union, Callable +import numpy as np import torch import torch.nn as nn from requests import HTTPError From e5d31e0f478320978c953c570e048919f4fe1888 Mon Sep 17 00:00:00 2001 From: Simon Sorg Date: Mon, 6 Nov 2023 15:53:01 +0100 Subject: [PATCH 17/23] add normalize boolean --- src/diart/argdoc.py | 1 + src/diart/blocks/diarization.py | 4 +++- src/diart/blocks/embedding.py | 19 ++++++++++++++++--- src/diart/console/benchmark.py | 5 +++++ src/diart/console/serve.py | 5 +++++ src/diart/console/stream.py | 5 +++++ src/diart/console/tune.py | 5 +++++ src/diart/models.py | 7 ------- 8 files changed, 40 insertions(+), 11 deletions(-) diff --git a/src/diart/argdoc.py b/src/diart/argdoc.py index e89caa28..cab7ecf4 100644 --- a/src/diart/argdoc.py +++ b/src/diart/argdoc.py @@ -15,3 +15,4 @@ OUTPUT = "Directory to store the system's output in RTTM format" HF_TOKEN = "Huggingface authentication token for hosted models ('true' | 'false' | ). 
If 'true', it will use the token from huggingface-cli login" SAMPLE_RATE = "Sample rate of the audio stream" +NORMALIZE_EMBEDDING_WEIGHTS = "Normalize embedding weights to be in the range [0, 1]. Useful for WeSpeaker embeddings" diff --git a/src/diart/blocks/diarization.py b/src/diart/blocks/diarization.py index fab83c36..aec7da65 100644 --- a/src/diart/blocks/diarization.py +++ b/src/diart/blocks/diarization.py @@ -32,6 +32,7 @@ def __init__( gamma: float = 3, beta: float = 10, max_speakers: int = 20, + normalize_embedding_weights: bool = False, device: torch.device | None = None, **kwargs, ): @@ -62,7 +63,7 @@ def __init__( self.gamma = gamma self.beta = beta self.max_speakers = max_speakers - + self.normalize_embedding_weights = normalize_embedding_weights self.device = device or torch.device( "cuda" if torch.cuda.is_available() else "cpu" ) @@ -105,6 +106,7 @@ def __init__(self, config: SpeakerDiarizationConfig | None = None): self._config.gamma, self._config.beta, norm=1, + normalize=self._config.normalize_embedding_weights, device=self._config.device, ) self.pred_aggregation = DelayedAggregation( diff --git a/src/diart/blocks/embedding.py b/src/diart/blocks/embedding.py index 5cd7c39e..b95dcb37 100644 --- a/src/diart/blocks/embedding.py +++ b/src/diart/blocks/embedding.py @@ -77,12 +77,16 @@ class OverlappedSpeechPenalty: beta: float, optional Temperature parameter (actually 1/beta) to lower joint speaker activations. Defaults to 10. + normalize: bool, optional + Whether to normalize the weights to be in the range [0, 1]. + Defaults to False. """ - def __init__(self, gamma: float = 3, beta: float = 10): + def __init__(self, gamma: float = 3, beta: float = 10, normalize: bool = False): self.gamma = gamma self.beta = beta self.formatter = TemporalFeatureFormatter() + self.normalize = normalize def __call__(self, segmentation: TemporalFeatures) -> TemporalFeatures: weights = self.formatter.cast(segmentation) # shape (batch, frames, speakers) @@ -90,6 +94,11 @@ def __call__(self, segmentation: TemporalFeatures) -> TemporalFeatures: probs = torch.softmax(self.beta * weights, dim=-1) weights = torch.pow(weights, self.gamma) * torch.pow(probs, self.gamma) weights[weights < 1e-8] = 1e-8 + if self.normalize: + min_values = weights.min(dim=1, keepdim=True).values + max_values = weights.max(dim=1, keepdim=True).values + weights = (weights - min_values) / (max_values - min_values) + weights.nan_to_num_(1e-8) return self.formatter.restore_type(weights) @@ -134,6 +143,8 @@ class OverlapAwareSpeakerEmbedding: norm: float or torch.Tensor of shape (batch, speakers, 1) where batch is optional The target norm for the embeddings. It can be different for each speaker. Defaults to 1. + normalize: bool, optional + Whether to normalize the embeddings to be in the range [0, 1]. device: Optional[torch.device] The device on which to run the embedding model. Defaults to GPU if available or CPU if not. 
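To condense what `OverlappedSpeechPenalty` computes after this patch: activations are sharpened against overlapped speech, floored at 1e-8, and optionally min-max rescaled per speaker for backends that consume the weights as a mask (WeSpeaker, per the argdoc entry above). A self-contained restatement of the transformation:

    import torch

    def overlapped_speech_penalty(w, gamma=3.0, beta=10.0, normalize=False):
        # w: (batch, frames, speakers) speaker activations
        probs = torch.softmax(beta * w, dim=-1)  # penalize jointly active frames
        w = torch.pow(w, gamma) * torch.pow(probs, gamma)
        w[w < 1e-8] = 1e-8
        if normalize:
            # Rescale each speaker's weights over time to [0, 1]
            mins = w.min(dim=1, keepdim=True).values
            maxs = w.max(dim=1, keepdim=True).values
            w = ((w - mins) / (maxs - mins)).nan_to_num(1e-8)
        return w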
@@ -145,10 +156,11 @@ def __init__( gamma: float = 3, beta: float = 10, norm: Union[float, torch.Tensor] = 1, + normalize: bool = False, device: Optional[torch.device] = None, ): self.embedding = SpeakerEmbedding(model, device) - self.osp = OverlappedSpeechPenalty(gamma, beta) + self.osp = OverlappedSpeechPenalty(gamma, beta, normalize) self.normalize = EmbeddingNormalization(norm) @staticmethod @@ -158,10 +170,11 @@ def from_pyannote( beta: float = 10, norm: Union[float, torch.Tensor] = 1, use_hf_token: Union[Text, bool, None] = True, + normalize: bool = False, device: Optional[torch.device] = None, ): model = EmbeddingModel.from_pyannote(model, use_hf_token) - return OverlapAwareSpeakerEmbedding(model, gamma, beta, norm, device) + return OverlapAwareSpeakerEmbedding(model, gamma, beta, norm, normalize, device) def __call__( self, waveform: TemporalFeatures, segmentation: TemporalFeatures diff --git a/src/diart/console/benchmark.py b/src/diart/console/benchmark.py index b5a296d1..e5808d14 100644 --- a/src/diart/console/benchmark.py +++ b/src/diart/console/benchmark.py @@ -99,6 +99,11 @@ def run(): type=str, help=f"{argdoc.HF_TOKEN}. Defaults to 'true' (required by pyannote)", ) + parser.add_argument( + "--normalize-embedding-weights", + action="store_true", + help=f"{argdoc.NORMALIZE_EMBEDDING_WEIGHTS}. Defaults to False", + ) args = parser.parse_args() # Resolve device diff --git a/src/diart/console/serve.py b/src/diart/console/serve.py index bc002e42..1b0645c3 100644 --- a/src/diart/console/serve.py +++ b/src/diart/console/serve.py @@ -80,6 +80,11 @@ def run(): type=str, help=f"{argdoc.HF_TOKEN}. Defaults to 'true' (required by pyannote)", ) + parser.add_argument( + "--normalize-embedding-weights", + action="store_true", + help=f"{argdoc.NORMALIZE_EMBEDDING_WEIGHTS}. Defaults to False", + ) args = parser.parse_args() # Resolve device diff --git a/src/diart/console/stream.py b/src/diart/console/stream.py index 713f3e99..527da88e 100644 --- a/src/diart/console/stream.py +++ b/src/diart/console/stream.py @@ -91,6 +91,11 @@ def run(): type=str, help=f"{argdoc.HF_TOKEN}. Defaults to 'true' (required by pyannote)", ) + parser.add_argument( + "--normalize-embedding-weights", + action="store_true", + help=f"{argdoc.NORMALIZE_EMBEDDING_WEIGHTS}. Defaults to False", + ) args = parser.parse_args() # Resolve device diff --git a/src/diart/console/tune.py b/src/diart/console/tune.py index ec243348..534c4b4b 100644 --- a/src/diart/console/tune.py +++ b/src/diart/console/tune.py @@ -108,6 +108,11 @@ def run(): type=str, help=f"{argdoc.HF_TOKEN}. Defaults to 'true' (required by pyannote)", ) + parser.add_argument( + "--normalize-embedding-weights", + action="store_true", + help=f"{argdoc.NORMALIZE_EMBEDDING_WEIGHTS}. 
Defaults to False", + ) args = parser.parse_args() # Resolve device diff --git a/src/diart/models.py b/src/diart/models.py index d11c0c08..92dc9972 100644 --- a/src/diart/models.py +++ b/src/diart/models.py @@ -172,13 +172,6 @@ def __call__( ------- speaker_embeddings: torch.Tensor, shape (batch, embedding_dim) """ - # Normalize weights - if weights is not None: - min_values = weights.min(dim=1, keepdim=True).values - max_values = weights.max(dim=1, keepdim=True).values - weights = (weights - min_values) / (max_values - min_values) - weights.nan_to_num_(0.0) - embeddings = super().__call__(waveform, weights) if isinstance(embeddings, np.ndarray): embeddings = torch.from_numpy(embeddings) From 661588c1dd22b2a42281f65fada65da9e38cbdf7 Mon Sep 17 00:00:00 2001 From: Juan Coria Date: Mon, 6 Nov 2023 20:06:36 +0100 Subject: [PATCH 18/23] Update requirements.txt --- requirements.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 316a5b20..2d3e4611 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,7 +9,8 @@ pandas>=1.4.2 torch>=1.12.1 torchvision>=0.14.0 torchaudio>=2.0.2 -pyannote.audio>=3.0.0 +pyannote.audio>=2.1.1 +requests>=2.31.0 pyannote.core>=4.5 pyannote.database>=4.1.1 pyannote.metrics>=3.2 From c708aa063a7eee42b4d0386abad14ea1fd2fe0b4 Mon Sep 17 00:00:00 2001 From: Juan Coria Date: Mon, 6 Nov 2023 20:06:48 +0100 Subject: [PATCH 19/23] Update setup.cfg --- setup.cfg | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index 1b17badf..314cb1e2 100644 --- a/setup.cfg +++ b/setup.cfg @@ -31,7 +31,8 @@ install_requires= torch>=1.12.1 torchvision>=0.14.0 torchaudio>=2.0.2 - pyannote.audio>=3.0.0 + pyannote.audio>=2.1.1 + requests>=2.31.0 pyannote.core>=4.5 pyannote.database>=4.1.1 pyannote.metrics>=3.2 From 22aa9f1d281a55879bbc91f689020d2885b150d7 Mon Sep 17 00:00:00 2001 From: Juan Coria Date: Mon, 6 Nov 2023 20:08:56 +0100 Subject: [PATCH 20/23] Apply suggestions from code review --- src/diart/argdoc.py | 2 +- src/diart/blocks/embedding.py | 14 +++++++------- src/diart/models.py | 5 +++-- 3 files changed, 11 insertions(+), 10 deletions(-) diff --git a/src/diart/argdoc.py b/src/diart/argdoc.py index cab7ecf4..ecc0206c 100644 --- a/src/diart/argdoc.py +++ b/src/diart/argdoc.py @@ -15,4 +15,4 @@ OUTPUT = "Directory to store the system's output in RTTM format" HF_TOKEN = "Huggingface authentication token for hosted models ('true' | 'false' | ). If 'true', it will use the token from huggingface-cli login" SAMPLE_RATE = "Sample rate of the audio stream" -NORMALIZE_EMBEDDING_WEIGHTS = "Normalize embedding weights to be in the range [0, 1]. Useful for WeSpeaker embeddings" +NORMALIZE_EMBEDDING_WEIGHTS = "Rescale embedding weights (min-max normalization) to be in the range [0, 1]. This is useful in some models without weighted statistics pooling that rely on masking, like WeSpeaker or ECAPA-TDNN" diff --git a/src/diart/blocks/embedding.py b/src/diart/blocks/embedding.py index b95dcb37..d6905e2d 100644 --- a/src/diart/blocks/embedding.py +++ b/src/diart/blocks/embedding.py @@ -78,7 +78,7 @@ class OverlappedSpeechPenalty: Temperature parameter (actually 1/beta) to lower joint speaker activations. Defaults to 10. normalize: bool, optional - Whether to normalize the weights to be in the range [0, 1]. + Whether to min-max normalize weights to be in the range [0, 1]. Defaults to False. 
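With the flag wired into all four console entry points above, the new behavior is opt-in from the shell as well; for instance (models left to their defaults, invocation shape only):

    diart.stream microphone --normalize-embedding-weights

The same switch reaches the Python API as `SpeakerDiarizationConfig(normalize_embedding_weights=True)`.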
""" @@ -143,8 +143,8 @@ class OverlapAwareSpeakerEmbedding: norm: float or torch.Tensor of shape (batch, speakers, 1) where batch is optional The target norm for the embeddings. It can be different for each speaker. Defaults to 1. - normalize: bool, optional - Whether to normalize the embeddings to be in the range [0, 1]. + normalize_weights: bool, optional + Whether to min-max normalize embedding weights to be in the range [0, 1]. device: Optional[torch.device] The device on which to run the embedding model. Defaults to GPU if available or CPU if not. @@ -156,11 +156,11 @@ def __init__( gamma: float = 3, beta: float = 10, norm: Union[float, torch.Tensor] = 1, - normalize: bool = False, + normalize_weights: bool = False, device: Optional[torch.device] = None, ): self.embedding = SpeakerEmbedding(model, device) - self.osp = OverlappedSpeechPenalty(gamma, beta, normalize) + self.osp = OverlappedSpeechPenalty(gamma, beta, normalize_weights) self.normalize = EmbeddingNormalization(norm) @staticmethod @@ -170,11 +170,11 @@ def from_pyannote( beta: float = 10, norm: Union[float, torch.Tensor] = 1, use_hf_token: Union[Text, bool, None] = True, - normalize: bool = False, + normalize_weights: bool = False, device: Optional[torch.device] = None, ): model = EmbeddingModel.from_pyannote(model, use_hf_token) - return OverlapAwareSpeakerEmbedding(model, gamma, beta, norm, normalize, device) + return OverlapAwareSpeakerEmbedding(model, gamma, beta, norm, normalize_weights, device) def __call__( self, waveform: TemporalFeatures, segmentation: TemporalFeatures diff --git a/src/diart/models.py b/src/diart/models.py index 92dc9972..399fd782 100644 --- a/src/diart/models.py +++ b/src/diart/models.py @@ -46,9 +46,10 @@ def load(self): if not self.is_in_memory(): self.model = self.get_model() - def to(self, *args, **kwargs) -> nn.Module: + def to(self, device: torch.device) -> "LazyModel": self.load() - return self.model.to(*args, **kwargs) + self.model = self.model.to(device) + return self def __call__(self, *args, **kwargs): self.load() From 576992bf9839cc5be41ac31200739eb7f77823df Mon Sep 17 00:00:00 2001 From: Juan Coria Date: Mon, 6 Nov 2023 20:16:30 +0100 Subject: [PATCH 21/23] Fix wrong kwarg name --- src/diart/blocks/diarization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diart/blocks/diarization.py b/src/diart/blocks/diarization.py index aec7da65..579782de 100644 --- a/src/diart/blocks/diarization.py +++ b/src/diart/blocks/diarization.py @@ -106,7 +106,7 @@ def __init__(self, config: SpeakerDiarizationConfig | None = None): self._config.gamma, self._config.beta, norm=1, - normalize=self._config.normalize_embedding_weights, + normalize_weights=self._config.normalize_embedding_weights, device=self._config.device, ) self.pred_aggregation = DelayedAggregation( From 0644c70e83a17c2984dd2f6ff4a48f043127dcdc Mon Sep 17 00:00:00 2001 From: Simon Sorg Date: Thu, 9 Nov 2023 08:09:59 +0100 Subject: [PATCH 22/23] add abstract __call__ --- src/diart/models.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/src/diart/models.py b/src/diart/models.py index 399fd782..f6422075 100644 --- a/src/diart/models.py +++ b/src/diart/models.py @@ -100,6 +100,19 @@ def sample_rate(self) -> int: def duration(self) -> float: pass + @abstractmethod + def __call__(self, waveform: torch.Tensor) -> torch.Tensor: + """ + Call the forward pass of the segmentation model. 
+ Parameters + ---------- + waveform: torch.Tensor, shape (batch, channels, samples) + Returns + ------- + speaker_segmentation: torch.Tensor, shape (batch, frames, speakers) + """ + pass + class PyannoteSegmentationModel(SegmentationModel): def __init__(self, model_info, hf_token: Union[Text, bool, None] = True): @@ -154,6 +167,23 @@ def from_pyannote( assert _has_pyannote, "No pyannote.audio installation found" return PyannoteEmbeddingModel(model, use_hf_token) + @abstractmethod + def __call__( + self, waveform: torch.Tensor, weights: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """ + Call the forward pass of an embedding model with optional weights. + Parameters + ---------- + waveform: torch.Tensor, shape (batch, channels, samples) + weights: Optional[torch.Tensor], shape (batch, frames) + Temporal weights for each sample in the batch. Defaults to no weights. + Returns + ------- + speaker_embeddings: torch.Tensor, shape (batch, embedding_dim) + """ + pass + class PyannoteEmbeddingModel(EmbeddingModel): def __init__(self, model_info, hf_token: Union[Text, bool, None] = True): From 791d68831fb47ba9f051d0dfd54e17a9a5edf768 Mon Sep 17 00:00:00 2001 From: Simon Sorg Date: Thu, 9 Nov 2023 08:11:36 +0100 Subject: [PATCH 23/23] move __call__ to parent class --- src/diart/models.py | 40 +++++----------------------------------- 1 file changed, 5 insertions(+), 35 deletions(-) diff --git a/src/diart/models.py b/src/diart/models.py index f6422075..3e6d7a09 100644 --- a/src/diart/models.py +++ b/src/diart/models.py @@ -100,7 +100,6 @@ def sample_rate(self) -> int: def duration(self) -> float: pass - @abstractmethod def __call__(self, waveform: torch.Tensor) -> torch.Tensor: """ Call the forward pass of the segmentation model. @@ -111,7 +110,7 @@ def __call__(self, waveform: torch.Tensor) -> torch.Tensor: ------- speaker_segmentation: torch.Tensor, shape (batch, frames, speakers) """ - pass + return super().__call__(waveform) class PyannoteSegmentationModel(SegmentationModel): @@ -128,18 +127,6 @@ def duration(self) -> float: self.load() return self.model.specifications.duration - def __call__(self, waveform: torch.Tensor) -> torch.Tensor: - """ - Call the forward pass of the segmentation model. - Parameters - ---------- - waveform: torch.Tensor, shape (batch, channels, samples) - Returns - ------- - speaker_segmentation: torch.Tensor, shape (batch, frames, speakers) - """ - return super().__call__(waveform) - class EmbeddingModel(LazyModel): """Minimal interface for an embedding model.""" @@ -167,7 +154,6 @@ def from_pyannote( assert _has_pyannote, "No pyannote.audio installation found" return PyannoteEmbeddingModel(model, use_hf_token) - @abstractmethod def __call__( self, waveform: torch.Tensor, weights: Optional[torch.Tensor] = None ) -> torch.Tensor: @@ -182,28 +168,12 @@ def __call__( ------- speaker_embeddings: torch.Tensor, shape (batch, embedding_dim) """ - pass + embeddings = super().__call__(waveform, weights) + if isinstance(embeddings, np.ndarray): + embeddings = torch.from_numpy(embeddings) + return embeddings class PyannoteEmbeddingModel(EmbeddingModel): def __init__(self, model_info, hf_token: Union[Text, bool, None] = True): super().__init__(PyannoteLoader(model_info, hf_token)) - - def __call__( - self, waveform: torch.Tensor, weights: Optional[torch.Tensor] = None - ) -> torch.Tensor: - """ - Call the forward pass of an embedding model with optional weights. 
- Parameters - ---------- - waveform: torch.Tensor, shape (batch, channels, samples) - weights: Optional[torch.Tensor], shape (batch, frames) - Temporal weights for each sample in the batch. Defaults to no weights. - Returns - ------- - speaker_embeddings: torch.Tensor, shape (batch, embedding_dim) - """ - embeddings = super().__call__(waveform, weights) - if isinstance(embeddings, np.ndarray): - embeddings = torch.from_numpy(embeddings) - return embeddings
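Taken together, the series leaves diart's public API intact while letting the embedding slot accept pyannote 3.x WeSpeaker checkpoints. A sketch of end-to-end usage after patch 23 — import paths follow the package layout in the diffs, and the model identifiers are illustrative:

    import diart.models as m
    from diart import SpeakerDiarization, SpeakerDiarizationConfig

    segmentation = m.SegmentationModel.from_pyannote("pyannote/segmentation")
    embedding = m.EmbeddingModel.from_pyannote("pyannote/wespeaker-voxceleb-resnet34-LM")

    config = SpeakerDiarizationConfig(
        segmentation=segmentation,
        embedding=embedding,
        normalize_embedding_weights=True,  # per argdoc: useful for masking-based embeddings
    )
    pipeline = SpeakerDiarization(config)

Both wrappers are lazy, so construction is cheap; the checkpoints are only fetched once the pipeline first touches the models.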