From 1e8fa1504d8999fc4e0f0cad152b30f64ea28082 Mon Sep 17 00:00:00 2001 From: FrenchKrab Date: Fri, 11 Mar 2022 13:49:44 +0100 Subject: [PATCH 01/20] fix higher_is_better value in (AU)DER MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Hervé BREDIN --- pyannote/audio/torchmetrics/audio/auder.py | 2 +- pyannote/audio/torchmetrics/audio/der.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyannote/audio/torchmetrics/audio/auder.py b/pyannote/audio/torchmetrics/audio/auder.py index 2e6a265c5..4249a5342 100644 --- a/pyannote/audio/torchmetrics/audio/auder.py +++ b/pyannote/audio/torchmetrics/audio/auder.py @@ -17,7 +17,7 @@ class AUDER(Metric): Note that this is only a reliable metric if num_frames == the total number of frames of the diarized audio. """ - higher_is_better = True + higher_is_better = False is_differentiable = False def __init__(self, steps=31, threshold_min=0.0, threshold_max=1.0, unit_area=True): diff --git a/pyannote/audio/torchmetrics/audio/der.py b/pyannote/audio/torchmetrics/audio/der.py index c3fe19c72..14b944aee 100644 --- a/pyannote/audio/torchmetrics/audio/der.py +++ b/pyannote/audio/torchmetrics/audio/der.py @@ -13,7 +13,7 @@ class DER(Metric): Note that this is only a reliable metric if num_frames == the total number of frames of the diarized audio. """ - higher_is_better = True + higher_is_better = False is_differentiable = False def __init__(self, threshold: float = 0.5): From 4a618ee6c3973bb13481dfc5107c0672c446f4df Mon Sep 17 00:00:00 2001 From: Masashi <32604391+matakanobu@users.noreply.github.com> Date: Mon, 14 Mar 2022 17:15:38 +0900 Subject: [PATCH 02/20] chore: relax speechbrain embeddings naming convention --- pyannote/audio/pipelines/speaker_verification.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyannote/audio/pipelines/speaker_verification.py b/pyannote/audio/pipelines/speaker_verification.py index 84c83ffc2..72d8edec9 100644 --- a/pyannote/audio/pipelines/speaker_verification.py +++ b/pyannote/audio/pipelines/speaker_verification.py @@ -265,7 +265,7 @@ def PretrainedSpeakerEmbedding(embedding: PipelineModel, device: torch.device = >>> embeddings = get_embedding(waveforms, masks=masks) """ - if isinstance(embedding, str) and embedding.split("/")[0] == "speechbrain": + if isinstance(embedding, str) and "speechbrain" in embedding: return SpeechBrainPretrainedSpeakerEmbedding(embedding, device=device) else: return PyannoteAudioPretrainedSpeakerEmbedding(embedding, device=device) From 738d97b819661155f2fb9a49801aa6fddda25834 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Herv=C3=A9=20BREDIN?= Date: Mon, 14 Mar 2022 11:30:37 +0100 Subject: [PATCH 03/20] wip: trying to simplify things a bit --- pyannote/audio/cli/train.py | 15 ++-- pyannote/audio/core/model.py | 4 +- pyannote/audio/core/task.py | 87 +++++++------------ pyannote/audio/tasks/embedding/arcface.py | 11 ++- pyannote/audio/tasks/embedding/mixins.py | 29 ++----- pyannote/audio/tasks/segmentation/mixins.py | 34 ++------ .../overlapped_speech_detection.py | 11 ++- .../audio/tasks/segmentation/segmentation.py | 17 ++-- .../segmentation/speaker_change_detection.py | 11 ++- .../tasks/segmentation/speaker_tracking.py | 11 ++- .../segmentation/voice_activity_detection.py | 11 ++- 11 files changed, 103 insertions(+), 138 deletions(-) diff --git a/pyannote/audio/cli/train.py b/pyannote/audio/cli/train.py index b8547e446..e7ab2225a 100644 --- a/pyannote/audio/cli/train.py +++ b/pyannote/audio/cli/train.py @@ 
-27,20 +27,19 @@ import hydra from hydra.utils import instantiate from omegaconf import DictConfig, OmegaConf + +# from pyannote.audio.core.callback import GraduallyUnfreeze +from pyannote.database import FileFinder, get_protocol from pytorch_lightning.callbacks import ( EarlyStopping, LearningRateMonitor, ModelCheckpoint, RichProgressBar, ) - from pytorch_lightning.loggers import TensorBoardLogger from pytorch_lightning.utilities.seed import seed_everything from torch_audiomentations.utils.config import from_dict as get_augmentation -# from pyannote.audio.core.callback import GraduallyUnfreeze -from pyannote.database import FileFinder, get_protocol - @hydra.main(config_path="train_config", config_name="config") def train(cfg: DictConfig) -> Optional[float]: @@ -67,15 +66,15 @@ def train(cfg: DictConfig) -> Optional[float]: # instantiate task and validation metric task = instantiate(cfg.task, protocol, augmentation=augmentation) - # validation metric to monitor (and its direction: min or max) - monitor, direction = task.val_monitor - # instantiate model fine_tuning = cfg.model["_target_"] == "pyannote.audio.cli.pretrained" model = instantiate(cfg.model) model.task = task model.setup(stage="fit") + # validation metric to monitor (and its direction: min or max) + monitor, direction = task.val_monitor + # number of batches in one epoch num_batches_per_epoch = model.task.train__len__() // model.task.batch_size @@ -90,6 +89,7 @@ def configure_optimizers(self): num_batches_per_epoch=num_batches_per_epoch, ) return {"optimizer": optimizer, "lr_scheduler": lr_scheduler} + model.configure_optimizers = MethodType(configure_optimizers, model) callbacks = [RichProgressBar(), LearningRateMonitor()] @@ -155,4 +155,3 @@ def configure_optimizers(self): if __name__ == "__main__": train() - diff --git a/pyannote/audio/core/model.py b/pyannote/audio/core/model.py index 4e9bfe092..c36b6927b 100644 --- a/pyannote/audio/core/model.py +++ b/pyannote/audio/core/model.py @@ -34,6 +34,7 @@ import torch.nn as nn import torch.optim from huggingface_hub import cached_download, hf_hub_url +from pyannote.core import SlidingWindow from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.model_summary import ModelSummary from semver import VersionInfo @@ -42,7 +43,6 @@ from pyannote.audio import __version__ from pyannote.audio.core.io import Audio from pyannote.audio.core.task import Problem, Resolution, Specifications, Task -from pyannote.core import SlidingWindow CACHE_DIR = os.getenv( "PYANNOTE_CACHE", @@ -385,7 +385,7 @@ def setup(self, stage=None): # setup custom loss function self.task.setup_loss_func() # setup custom validation metrics - validation_metric = self.task.setup_validation_metric() + validation_metric = self.task.metric if validation_metric is not None: self.validation_metric = validation_metric self.validation_metric.to(self.device) diff --git a/pyannote/audio/core/task.py b/pyannote/audio/core/task.py index 34903f781..9c40efa7a 100644 --- a/pyannote/audio/core/task.py +++ b/pyannote/audio/core/task.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2020-2021 CNRS +# Copyright (c) 2020- CNRS # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -23,13 +23,18 @@ from __future__ import annotations +try: + from functools import cached_property +except ImportError: + from backports.cached_property import cached_property + import multiprocessing import sys 
import warnings from dataclasses import dataclass from enum import Enum from numbers import Number -from typing import Dict, List, Optional, Sequence, Text, Tuple, Type, Union +from typing import Dict, List, Optional, Sequence, Text, Tuple, Union import pytorch_lightning as pl import torch @@ -152,6 +157,9 @@ class Task(pl.LightningDataModule): augmentation : BaseWaveformTransform, optional torch_audiomentations waveform transform, used by dataloader during training. + metric : optional + Validation metric(s). Can be anything supported by torchmetrics.MetricCollection. + Defaults to value returned by `default_metric` method. Attributes ---------- @@ -169,7 +177,7 @@ def __init__( num_workers: int = None, pin_memory: bool = False, augmentation: BaseWaveformTransform = None, - metrics: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None, + metric: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None, ): super().__init__() @@ -205,20 +213,7 @@ def __init__( self.num_workers = num_workers self.pin_memory = pin_memory self.augmentation = augmentation - self.metrics = metrics - - @property - def default_validation_metric( - self, - ) -> Union[Metric, Sequence[Metric], Dict[str, Metric]]: - """Used to get validation metrics when none is provided by the user - - Returns - ------- - Union[Metric, Sequence[Metric], Dict[str, Metric]] - The default validation metric(s) of this task (something that can be plugged in a torchmetrics.MetricCollection). - """ - pass + self._metric = metric def prepare_data(self): """Use this to download and prepare data @@ -249,39 +244,6 @@ def setup(self, stage: Optional[str] = None): def setup_loss_func(self): pass - @property - def val_metric_prefix(self) -> str: - return f"{self.ACRONYM}@val_" - - def get_default_val_metric_name(self, metric: Union[Metric, Type]) -> str: - prefix = self.val_metric_prefix - mn = Task.get_metric_name(metric) - return f"{prefix}{mn}" - - @staticmethod - def get_metric_name(metric: Union[Metric, Type]) -> str: - if isinstance(metric, Metric): - return type(metric).__name__.lower() - elif isinstance(metric, Type): - return metric.__name__.lower() - else: - msg = "get_metric_name only accepts Metric and Type arguments." - raise ValueError(msg) - - def setup_validation_metric(self) -> Metric: - metricsarg = self.metrics - if self.metrics is None: - metricsarg = self.default_validation_metric - - # Convert metricargs to a list, if it is a single Metric - if isinstance(metricsarg, Metric): - metricsarg = [metricsarg] - # Convert metricsarg to a dict, now that it is a list - # If the metrics' names are not given, generate them automatically - if not isinstance(metricsarg, dict): - metricsarg = {Task.get_metric_name(m): m for m in metricsarg} - return MetricCollection(metricsarg, prefix=self.val_metric_prefix) - def train__iter__(self): # will become train_dataset.__iter__ method msg = f"Missing '{self.__class__.__name__}.train__iter__' method." @@ -443,6 +405,24 @@ def validation_step(self, batch, batch_idx: int): def validation_epoch_end(self, outputs): pass + def default_metric(self) -> Union[Metric, Sequence[Metric], Dict[str, Metric]]: + """Default validation metric""" + msg = f"Missing '{self.__class__.__name__}.default_metric' method." + raise NotImplementedError(msg) + + @cached_property + def metric(self) -> MetricCollection: + if self._metric is None: + self._metric = self.default_metric() + + prefix = f"{self.__class__.__name__}-" + if hasattr(self.protocol, "name"): + # "." 
has a special meaning for pytorch-lightning checkpointing + # so we replace any encountered "." by "_" in protocol names + prefix += f"{self.protocol.name.replace('.', '_')}-" + + return MetricCollection(self._metric, prefix=prefix) + @property def val_monitor(self): """Quantity (and direction) to monitor @@ -461,7 +441,6 @@ def val_monitor(self): pytorch_lightning.callbacks.ModelCheckpoint pytorch_lightning.callbacks.EarlyStopping """ - if self.has_validation: - return f"{self.ACRONYM}@val_loss", "min" - else: - return None, "min" + + name, metric = next(iter(self.metric.items())) + return name, "max" if metric.higher_is_better else "min" diff --git a/pyannote/audio/tasks/embedding/arcface.py b/pyannote/audio/tasks/embedding/arcface.py index bcf9dea31..74e09e052 100644 --- a/pyannote/audio/tasks/embedding/arcface.py +++ b/pyannote/audio/tasks/embedding/arcface.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2020-2021 CNRS +# Copyright (c) 2020- CNRS # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -26,11 +26,11 @@ from typing import Dict, Sequence, Union import pytorch_metric_learning.losses +from pyannote.database import Protocol from torch_audiomentations.core.transforms_interface import BaseWaveformTransform from torchmetrics import Metric from pyannote.audio.core.task import Task -from pyannote.database import Protocol from .mixins import SupervisedRepresentationLearningTaskMixin @@ -70,6 +70,9 @@ class SupervisedRepresentationLearningWithArcFace( augmentation : BaseWaveformTransform, optional torch_audiomentations waveform transform, used by dataloader during training. + metric : optional + Validation metric(s). Can be anything supported by torchmetrics.MetricCollection. + Defaults to AUROC (area under the ROC curve). 
""" ACRONYM = "arcface" @@ -90,7 +93,7 @@ def __init__( num_workers: int = None, pin_memory: bool = False, augmentation: BaseWaveformTransform = None, - metrics: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None, + metric: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None, ): self.num_chunks_per_class = num_chunks_per_class @@ -107,7 +110,7 @@ def __init__( num_workers=num_workers, pin_memory=pin_memory, augmentation=augmentation, - metrics=metrics, + metric=metric, ) def setup_loss_func(self): diff --git a/pyannote/audio/tasks/embedding/mixins.py b/pyannote/audio/tasks/embedding/mixins.py index 015137845..170b1af65 100644 --- a/pyannote/audio/tasks/embedding/mixins.py +++ b/pyannote/audio/tasks/embedding/mixins.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2020-2021 CNRS +# Copyright (c) 2020- CNRS # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -25,16 +25,16 @@ import torch import torch.nn.functional as F -from torchmetrics import AUROC, Metric -from tqdm import tqdm - -from pyannote.audio.core.task import Problem, Resolution, Specifications, Task -from pyannote.audio.utils.random import create_rng_for_worker from pyannote.core import Segment from pyannote.database.protocol import ( SpeakerDiarizationProtocol, SpeakerVerificationProtocol, ) +from torchmetrics import AUROC, Metric +from tqdm import tqdm + +from pyannote.audio.core.task import Problem, Resolution, Specifications +from pyannote.audio.utils.random import create_rng_for_worker class SupervisedRepresentationLearningTaskMixin: @@ -124,8 +124,7 @@ def setup(self, stage: Optional[str] = None): if isinstance(self.protocol, SpeakerVerificationProtocol): self._validation = list(self.protocol.development_trial()) - @property - def default_validation_metric( + def default_metric( self, ) -> Union[Metric, Sequence[Metric], Dict[str, Metric]]: return AUROC(compute_on_step=False) @@ -290,17 +289,3 @@ def validation_step(self, batch, batch_idx: int): elif isinstance(self.protocol, SpeakerDiarizationProtocol): pass - - @property - def val_monitor(self): - - if self.has_validation and self.metrics is None: - - if isinstance(self.protocol, SpeakerVerificationProtocol): - return Task.get_default_val_metric_name(AUROC), "max" - - elif isinstance(self.protocol, SpeakerDiarizationProtocol): - return None, "min" - - else: - return None, "min" diff --git a/pyannote/audio/tasks/segmentation/mixins.py b/pyannote/audio/tasks/segmentation/mixins.py index d45267748..0902d9267 100644 --- a/pyannote/audio/tasks/segmentation/mixins.py +++ b/pyannote/audio/tasks/segmentation/mixins.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2020-2021 CNRS +# Copyright (c) 2020- CNRS # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -28,13 +28,13 @@ import matplotlib.pyplot as plt import numpy as np import torch +from pyannote.core import Annotation, Segment, SlidingWindow, SlidingWindowFeature from torchmetrics import AUROC, Metric from typing_extensions import Literal from pyannote.audio.core.io import Audio, AudioFile from pyannote.audio.core.task import Problem from pyannote.audio.utils.random import create_rng_for_worker -from pyannote.core import Annotation, Segment, SlidingWindow, SlidingWindowFeature class SegmentationTaskMixin: @@ -121,14 +121,10 @@ def setup(self, stage: Optional[str] = None): 
random.shuffle(self._validation) - @property - def default_validation_metric( + def default_metric( self, ) -> Union[Metric, Sequence[Metric], Dict[str, Metric]]: - """Setup default validation metric - - Use macro-average of area under the ROC curve - """ + """Returns macro-average of the area under the ROC curve""" num_classes = len(self.specifications.classes) return AUROC(num_classes, pos_label=1, average="macro", compute_on_step=False) @@ -489,8 +485,8 @@ def validation_step(self, batch, batch_idx: int): # preds: shape (batch_size, num_frames, 1), type float # torchmetrics expects: - # target: shape (N,), type binary - # preds: shape (N,), type float + # target: shape (batch_size,), type binary + # preds: shape (batch_size,), type float self.model.validation_metric( preds.reshape(-1), @@ -502,8 +498,8 @@ def validation_step(self, batch, batch_idx: int): # preds: shape (batch_size, num_frames, num_classes), type float # torchmetrics expects - # target: shape (N, ), type binary - # preds: shape (N, ), type float + # target: shape (batch_size, num_classes, ...), type binary + # preds: shape (batch_size, num_classes, ...), type float self.model.validation_metric( torch.transpose(preds, 1, 2), @@ -511,13 +507,6 @@ def validation_step(self, batch, batch_idx: int): ) elif self.specifications.problem == Problem.MONO_LABEL_CLASSIFICATION: - # target: shape (batch_size, num_frames, num_classes), type binary - # preds: shape (batch_size, num_frames, num_classes), type float - - # torchmetrics expects: - # target: shape (N, ), type int - # preds: shape (N, num_classes), type float - # TODO: implement when pyannote.audio gets its first mono-label segmentation task raise NotImplementedError() @@ -599,10 +588,3 @@ def validation_step(self, batch, batch_idx: int): ) plt.close(fig) - - @property - def val_monitor(self): - if self.has_validation and self.metrics is None: - return self.get_default_val_metric_name(AUROC), "max" - else: - return None, "min" diff --git a/pyannote/audio/tasks/segmentation/overlapped_speech_detection.py b/pyannote/audio/tasks/segmentation/overlapped_speech_detection.py index 07423d883..4f01b7480 100644 --- a/pyannote/audio/tasks/segmentation/overlapped_speech_detection.py +++ b/pyannote/audio/tasks/segmentation/overlapped_speech_detection.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2020-2021 CNRS +# Copyright (c) 2020- CNRS # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -24,12 +24,12 @@ from typing import Dict, Sequence, Text, Tuple, Union import numpy as np +from pyannote.database import Protocol from torch_audiomentations.core.transforms_interface import BaseWaveformTransform from torchmetrics import Metric from pyannote.audio.core.task import Problem, Resolution, Specifications, Task from pyannote.audio.tasks.segmentation.mixins import SegmentationTaskMixin -from pyannote.database import Protocol class OverlappedSpeechDetection(SegmentationTaskMixin, Task): @@ -84,6 +84,9 @@ class OverlappedSpeechDetection(SegmentationTaskMixin, Task): augmentation : BaseWaveformTransform, optional torch_audiomentations waveform transform, used by dataloader during training. + metric : optional + Validation metric(s). Can be anything supported by torchmetrics.MetricCollection. + Defaults to AUROC (area under the ROC curve). 
""" ACRONYM = "osd" @@ -102,7 +105,7 @@ def __init__( num_workers: int = None, pin_memory: bool = False, augmentation: BaseWaveformTransform = None, - metrics: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None, + metric: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None, ): super().__init__( @@ -113,7 +116,7 @@ def __init__( num_workers=num_workers, pin_memory=pin_memory, augmentation=augmentation, - metrics=metrics, + metric=metric, ) self.specifications = Specifications( diff --git a/pyannote/audio/tasks/segmentation/segmentation.py b/pyannote/audio/tasks/segmentation/segmentation.py index a1fb8bbf6..ee638946e 100644 --- a/pyannote/audio/tasks/segmentation/segmentation.py +++ b/pyannote/audio/tasks/segmentation/segmentation.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2020-2021 CNRS +# Copyright (c) 2020- CNRS # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -25,6 +25,8 @@ import numpy as np import torch +from pyannote.core import SlidingWindow +from pyannote.database import Protocol from torch_audiomentations.core.transforms_interface import BaseWaveformTransform from torchmetrics import Metric from typing_extensions import Literal @@ -33,8 +35,6 @@ from pyannote.audio.tasks.segmentation.mixins import SegmentationTaskMixin from pyannote.audio.utils.loss import binary_cross_entropy, mse_loss from pyannote.audio.utils.permutation import permutate -from pyannote.core import SlidingWindow -from pyannote.database import Protocol class Segmentation(SegmentationTaskMixin, Task): @@ -81,8 +81,13 @@ class Segmentation(SegmentationTaskMixin, Task): augmentation : BaseWaveformTransform, optional torch_audiomentations waveform transform, used by dataloader during training. + loss : {"bce", "mse"}, optional + Permutation-invariant segmentation loss. Defaults to "bce". vad_loss : {"bce", "mse"}, optional Add voice activity detection loss. + metric : optional + Validation metric(s). Can be anything supported by torchmetrics.MetricCollection. + Defaults to AUROC (area under the ROC curve). 
Reference ---------- @@ -110,7 +115,7 @@ def __init__( augmentation: BaseWaveformTransform = None, loss: Literal["bce", "mse"] = "bce", vad_loss: Literal["bce", "mse"] = None, - metrics: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None, + metric: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None, ): super().__init__( @@ -121,7 +126,7 @@ def __init__( num_workers=num_workers, pin_memory=pin_memory, augmentation=augmentation, - metrics=metrics, + metric=metric, ) self.max_num_speakers = max_num_speakers @@ -387,13 +392,13 @@ def validation_postprocess(self, y, y_pred): def main(protocol: str, subset: str = "test", model: str = "pyannote/segmentation"): """Evaluate a segmentation model""" + from pyannote.database import FileFinder, get_protocol from rich.progress import Progress from pyannote.audio import Inference from pyannote.audio.pipelines.utils import get_devices from pyannote.audio.utils.metric import DiscreteDiarizationErrorRate from pyannote.audio.utils.signal import binarize - from pyannote.database import FileFinder, get_protocol (device,) = get_devices(needs=1) metric = DiscreteDiarizationErrorRate() diff --git a/pyannote/audio/tasks/segmentation/speaker_change_detection.py b/pyannote/audio/tasks/segmentation/speaker_change_detection.py index 476fc327b..8e9fdd572 100644 --- a/pyannote/audio/tasks/segmentation/speaker_change_detection.py +++ b/pyannote/audio/tasks/segmentation/speaker_change_detection.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2020-2021 CNRS +# Copyright (c) 2020- CNRS # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -24,12 +24,12 @@ import numpy as np import scipy.signal +from pyannote.database import Protocol from torch_audiomentations.core.transforms_interface import BaseWaveformTransform from torchmetrics import Metric from pyannote.audio.core.task import Problem, Resolution, Specifications, Task from pyannote.audio.tasks.segmentation.mixins import SegmentationTaskMixin -from pyannote.database import Protocol class SpeakerChangeDetection(SegmentationTaskMixin, Task): @@ -77,6 +77,9 @@ class SpeakerChangeDetection(SegmentationTaskMixin, Task): augmentation : BaseWaveformTransform, optional torch_audiomentations waveform transform, used by dataloader during training. + metric : optional + Validation metric(s). Can be anything supported by torchmetrics.MetricCollection. + Defaults to AUROC (area under the ROC curve). 
""" ACRONYM = "scd" @@ -93,7 +96,7 @@ def __init__( num_workers: int = None, pin_memory: bool = False, augmentation: BaseWaveformTransform = None, - metrics: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None, + metric: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None, ): super().__init__( @@ -104,7 +107,7 @@ def __init__( num_workers=num_workers, pin_memory=pin_memory, augmentation=augmentation, - metrics=metrics, + metric=metric, ) self.balance = balance diff --git a/pyannote/audio/tasks/segmentation/speaker_tracking.py b/pyannote/audio/tasks/segmentation/speaker_tracking.py index 5b06f2350..c572fcb26 100644 --- a/pyannote/audio/tasks/segmentation/speaker_tracking.py +++ b/pyannote/audio/tasks/segmentation/speaker_tracking.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2020-2021 CNRS +# Copyright (c) 2020- CNRS # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -23,12 +23,12 @@ from typing import Dict, List, Optional, Sequence, Text, Tuple, Union import numpy as np +from pyannote.database import Protocol from torch_audiomentations.core.transforms_interface import BaseWaveformTransform from torchmetrics import Metric from pyannote.audio.core.task import Problem, Resolution, Specifications, Task from pyannote.audio.tasks.segmentation.mixins import SegmentationTaskMixin -from pyannote.database import Protocol class SpeakerTracking(SegmentationTaskMixin, Task): @@ -71,6 +71,9 @@ class SpeakerTracking(SegmentationTaskMixin, Task): augmentation : BaseWaveformTransform, optional torch_audiomentations waveform transform, used by dataloader during training. + metric : optional + Validation metric(s). Can be anything supported by torchmetrics.MetricCollection. + Defaults to AUROC (area under the ROC curve). 
""" ACRONYM = "spk" @@ -86,7 +89,7 @@ def __init__( num_workers: int = None, pin_memory: bool = False, augmentation: BaseWaveformTransform = None, - metrics: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None, + metric: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None, ): super().__init__( @@ -97,7 +100,7 @@ def __init__( num_workers=num_workers, pin_memory=pin_memory, augmentation=augmentation, - metrics=metrics, + metric=metric, ) self.balance = balance diff --git a/pyannote/audio/tasks/segmentation/voice_activity_detection.py b/pyannote/audio/tasks/segmentation/voice_activity_detection.py index 52743c094..d2c254ee1 100644 --- a/pyannote/audio/tasks/segmentation/voice_activity_detection.py +++ b/pyannote/audio/tasks/segmentation/voice_activity_detection.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2020-2021 CNRS +# Copyright (c) 2020- CNRS # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -23,12 +23,12 @@ from typing import Dict, Sequence, Text, Tuple, Union import numpy as np +from pyannote.database import Protocol from torch_audiomentations.core.transforms_interface import BaseWaveformTransform from torchmetrics import Metric from pyannote.audio.core.task import Problem, Resolution, Specifications, Task from pyannote.audio.tasks.segmentation.mixins import SegmentationTaskMixin -from pyannote.database import Protocol class VoiceActivityDetection(SegmentationTaskMixin, Task): @@ -70,6 +70,9 @@ class VoiceActivityDetection(SegmentationTaskMixin, Task): augmentation : BaseWaveformTransform, optional torch_audiomentations waveform transform, used by dataloader during training. + metric : optional + Validation metric(s). Can be anything supported by torchmetrics.MetricCollection. + Defaults to AUROC (area under the ROC curve). """ ACRONYM = "vad" @@ -85,7 +88,7 @@ def __init__( num_workers: int = None, pin_memory: bool = False, augmentation: BaseWaveformTransform = None, - metrics: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None, + metric: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None, ): super().__init__( @@ -96,7 +99,7 @@ def __init__( num_workers=num_workers, pin_memory=pin_memory, augmentation=augmentation, - metrics=metrics, + metric=metric, ) self.balance = balance From 37ad9f606ae36d7a7af76e8b3701c0e5dc8a4d03 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Herv=C3=A9=20BREDIN?= Date: Mon, 14 Mar 2022 13:26:44 +0100 Subject: [PATCH 04/20] chore: reorganizing/renaming things a bit Also: getting rid of AUDER unless we really need it... 
--- pyannote/audio/torchmetrics/__init__.py | 43 ++++-- pyannote/audio/torchmetrics/audio/__init__.py | 36 +++++ pyannote/audio/torchmetrics/audio/auder.py | 74 ---------- pyannote/audio/torchmetrics/audio/der.py | 60 -------- .../audio/diarization_error_rate.py | 119 ++++++++++++++++ .../audio/torchmetrics/functional/__init__.py | 21 +++ .../torchmetrics/functional/audio/__init__.py | 21 +++ .../torchmetrics/functional/audio/auder.py | 60 -------- .../torchmetrics/functional/audio/der.py | 76 ----------- .../audio/diarization_error_rate.py | 128 ++++++++++++++++++ 10 files changed, 357 insertions(+), 281 deletions(-) delete mode 100644 pyannote/audio/torchmetrics/audio/auder.py delete mode 100644 pyannote/audio/torchmetrics/audio/der.py create mode 100644 pyannote/audio/torchmetrics/audio/diarization_error_rate.py delete mode 100644 pyannote/audio/torchmetrics/functional/audio/auder.py delete mode 100644 pyannote/audio/torchmetrics/functional/audio/der.py create mode 100644 pyannote/audio/torchmetrics/functional/audio/diarization_error_rate.py diff --git a/pyannote/audio/torchmetrics/__init__.py b/pyannote/audio/torchmetrics/__init__.py index 94efbb4b0..27513524b 100644 --- a/pyannote/audio/torchmetrics/__init__.py +++ b/pyannote/audio/torchmetrics/__init__.py @@ -1,15 +1,36 @@ -from pyannote.audio.torchmetrics.audio.auder import AUDER -from pyannote.audio.torchmetrics.audio.der import ( - DER, - ConfusionMetric, - FalseAlarmMetric, - MissedDetectionMetric, +# MIT License +# +# Copyright (c) 2022- CNRS +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+ + +from .audio.diarization_error_rate import ( + DiarizationErrorRate, + FalseAlarmRate, + MissedDetectionRate, + SpeakerConfusionRate, ) __all__ = [ - "AUDER", - "DER", - "FalseAlarmMetric", - "MissedDetectionMetric", - "ConfusionMetric", + "DiarizationErrorRate", + "FalseAlarmRate", + "MissedDetectionRate", + "SpeakerConfusionRate", ] diff --git a/pyannote/audio/torchmetrics/audio/__init__.py b/pyannote/audio/torchmetrics/audio/__init__.py index e69de29bb..e500a2655 100644 --- a/pyannote/audio/torchmetrics/audio/__init__.py +++ b/pyannote/audio/torchmetrics/audio/__init__.py @@ -0,0 +1,36 @@ +# MIT License +# +# Copyright (c) 2022- CNRS +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + + +from .diarization_error_rate import ( + DiarizationErrorRate, + FalseAlarmRate, + MissedDetectionRate, + SpeakerConfusionRate, +) + +__all__ = [ + "DiarizationErrorRate", + "SpeakerConfusionRate", + "MissedDetectionRate", + "FalseAlarmRate", +] diff --git a/pyannote/audio/torchmetrics/audio/auder.py b/pyannote/audio/torchmetrics/audio/auder.py deleted file mode 100644 index 4249a5342..000000000 --- a/pyannote/audio/torchmetrics/audio/auder.py +++ /dev/null @@ -1,74 +0,0 @@ -import torch -from torch import Tensor -from torchmetrics import Metric - -from pyannote.audio.torchmetrics.functional.audio.auder import ( - _auder_compute, - _auder_update, -) - - -class AUDER(Metric): - """Area Under the Diarization Error Rate. - Approximates the area under the curve of the DER when varying its threshold value. - - Expects preds and target tensors of the shape (NUM_BATCH, NUM_CLASSES, NUM_FRAMES) in its update. - - Note that this is only a reliable metric if num_frames == the total number of frames of the diarized audio. 
- """ - - higher_is_better = False - is_differentiable = False - - def __init__(self, steps=31, threshold_min=0.0, threshold_max=1.0, unit_area=True): - super().__init__() - - if threshold_max < threshold_min: - raise ValueError( - f"Illegal value : threshold_max ({threshold_max}) < threshold_min ({threshold_min})" - ) - - self.threshold_min = threshold_min - self.threshold_max = threshold_max - self.unit_area = unit_area - self.steps = steps - - self.add_state( - "false_alarm", - default=torch.zeros(self.steps, dtype=torch.float), - dist_reduce_fx="sum", - ) - self.add_state( - "missed_detection", - torch.zeros(self.steps, dtype=torch.float), - dist_reduce_fx="sum", - ) - self.add_state( - "confusion", - torch.zeros(self.steps, dtype=torch.float), - dist_reduce_fx="sum", - ) - self.add_state( - "total", torch.zeros(self.steps, dtype=torch.float), dist_reduce_fx="sum" - ) - - def update( - self, - preds: Tensor, - target: Tensor, - ): - fa, md, conf, total = _auder_update( - preds, target, self.steps, self.threshold_min, self.threshold_max - ) - self.false_alarm += fa - self.missed_detection += md - self.confusion += conf - self.total += total - - def compute(self): - dx = ( - (self.threshold_max - self.threshold_min) if not self.unit_area else 1.0 - ) / (self.steps - 1) - return _auder_compute( - self.false_alarm, self.missed_detection, self.confusion, self.total, dx - ) diff --git a/pyannote/audio/torchmetrics/audio/der.py b/pyannote/audio/torchmetrics/audio/der.py deleted file mode 100644 index 14b944aee..000000000 --- a/pyannote/audio/torchmetrics/audio/der.py +++ /dev/null @@ -1,60 +0,0 @@ -from torch import Tensor, tensor -from torchmetrics import Metric - -from pyannote.audio.torchmetrics.functional.audio.der import _der_compute, _der_update - - -class DER(Metric): - """ - Compute Diarization Error Rate on discretized annotations. - - Expects preds and target tensors of the shape (NUM_BATCH, NUM_CLASSES, NUM_FRAMES) in its update. - - Note that this is only a reliable metric if num_frames == the total number of frames of the diarized audio. 
-    """
-
-    higher_is_better = False
-    is_differentiable = False
-
-    def __init__(self, threshold: float = 0.5):
-        super().__init__()
-
-        self.threshold = threshold
-
-        self.add_state("false_alarm", default=tensor(0.0), dist_reduce_fx="sum")
-        self.add_state("missed_detection", default=tensor(0.0), dist_reduce_fx="sum")
-        self.add_state("confusion", default=tensor(0.0), dist_reduce_fx="sum")
-        self.add_state("total", default=tensor(0.0), dist_reduce_fx="sum")
-
-    def update(
-        self,
-        preds: Tensor,
-        target: Tensor,
-    ) -> None:
-        false_alarm, missed_detection, confusion, total = _der_update(
-            preds, target, self.threshold
-        )
-        self.false_alarm += false_alarm
-        self.missed_detection += missed_detection
-        self.confusion += confusion
-        self.total += total
-
-    def compute(self):
-        return _der_compute(
-            self.false_alarm, self.missed_detection, self.confusion, self.total
-        )
-
-
-class ConfusionMetric(DER):
-    def compute(self):
-        return self.confusion / self.total
-
-
-class FalseAlarmMetric(DER):
-    def compute(self):
-        return self.false_alarm / self.total
-
-
-class MissedDetectionMetric(DER):
-    def compute(self):
-        return self.missed_detection / self.total
diff --git a/pyannote/audio/torchmetrics/audio/diarization_error_rate.py b/pyannote/audio/torchmetrics/audio/diarization_error_rate.py
new file mode 100644
index 000000000..c294eb3a3
--- /dev/null
+++ b/pyannote/audio/torchmetrics/audio/diarization_error_rate.py
@@ -0,0 +1,119 @@
+# MIT License
+#
+# Copyright (c) 2022- CNRS
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+import torch
+from torchmetrics import Metric
+
+from pyannote.audio.torchmetrics.functional.audio.diarization_error_rate import (
+    _der_compute,
+    _der_update,
+)
+
+
+class DiarizationErrorRate(Metric):
+    """Diarization error rate
+
+    Parameters
+    ----------
+    threshold : float, optional
+        Threshold used to binarize predictions. Defaults to 0.5.
+
+    Notes
+    -----
+    While the pyannote.audio convention is to store speaker activations with
+    (batch_size, num_frames, num_speakers)-shaped tensors, this torchmetrics metric
+    expects them to be shaped as (batch_size, num_speakers, num_frames) tensors.
+    """
+
+    higher_is_better = False
+    is_differentiable = False
+
+    def __init__(self, threshold: float = 0.5):
+        super().__init__()
+
+        self.threshold = threshold
+
+        self.add_state("false_alarm", default=torch.tensor(0.0), dist_reduce_fx="sum")
+        self.add_state(
+            "missed_detection", default=torch.tensor(0.0), dist_reduce_fx="sum"
+        )
+        self.add_state(
+            "speaker_confusion", default=torch.tensor(0.0), dist_reduce_fx="sum"
+        )
+        self.add_state("speech_total", default=torch.tensor(0.0), dist_reduce_fx="sum")
+
+    def update(
+        self,
+        preds: torch.Tensor,
+        target: torch.Tensor,
+    ) -> None:
+        """Compute and accumulate components of diarization error rate
+
+        Parameters
+        ----------
+        preds : torch.Tensor
+            (batch_size, num_speakers, num_frames)-shaped continuous predictions.
+        target : torch.Tensor
+            (batch_size, num_speakers, num_frames)-shaped (0 or 1) targets.
+
+        Note
+        ----
+        This method returns None: the four diarization error rate components
+        (false_alarm, missed_detection, speaker_confusion, and speech_total)
+        are accumulated in the metric state and only aggregated into an
+        actual error rate when `compute` is eventually called.
+
+        """
+
+        false_alarm, missed_detection, speaker_confusion, speech_total = _der_update(
+            preds, target, threshold=self.threshold
+        )
+        self.false_alarm += false_alarm
+        self.missed_detection += missed_detection
+        self.speaker_confusion += speaker_confusion
+        self.speech_total += speech_total
+
+    def compute(self):
+        return _der_compute(
+            self.false_alarm,
+            self.missed_detection,
+            self.speaker_confusion,
+            self.speech_total,
+        )
+
+
+class SpeakerConfusionRate(DiarizationErrorRate):
+    def compute(self):
+        # TODO: handle corner case where speech_total == 0
+        return self.speaker_confusion / self.speech_total
+
+
+class FalseAlarmRate(DiarizationErrorRate):
+    def compute(self):
+        # TODO: handle corner case where speech_total == 0
+        return self.false_alarm / self.speech_total
+
+
+class MissedDetectionRate(DiarizationErrorRate):
+    def compute(self):
+        # TODO: handle corner case where speech_total == 0
+        return self.missed_detection / self.speech_total
diff --git a/pyannote/audio/torchmetrics/functional/__init__.py b/pyannote/audio/torchmetrics/functional/__init__.py
index e69de29bb..67b544284 100644
--- a/pyannote/audio/torchmetrics/functional/__init__.py
+++ b/pyannote/audio/torchmetrics/functional/__init__.py
@@ -0,0 +1,21 @@
+# MIT License
+#
+# Copyright (c) 2022- CNRS
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
diff --git a/pyannote/audio/torchmetrics/functional/audio/__init__.py b/pyannote/audio/torchmetrics/functional/audio/__init__.py index e69de29bb..67b544284 100644 --- a/pyannote/audio/torchmetrics/functional/audio/__init__.py +++ b/pyannote/audio/torchmetrics/functional/audio/__init__.py @@ -0,0 +1,21 @@ +# MIT License +# +# Copyright (c) 2022- CNRS +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. diff --git a/pyannote/audio/torchmetrics/functional/audio/auder.py b/pyannote/audio/torchmetrics/functional/audio/auder.py deleted file mode 100644 index 0d4bbc8cf..000000000 --- a/pyannote/audio/torchmetrics/functional/audio/auder.py +++ /dev/null @@ -1,60 +0,0 @@ -from typing import Tuple - -import numpy as np -import torch -from torch import Tensor - -from pyannote.audio.torchmetrics.functional.audio.der import ( - _check_valid_tensors, - _der_update, -) - - -def _auder_update( - preds: Tensor, - target: Tensor, - steps: int, - tmin: float, - tmax: float, -) -> Tuple[Tensor, Tensor, Tensor, Tensor]: - _check_valid_tensors(preds, target) - - false_alarm = torch.zeros(steps, dtype=torch.float, device=preds.device) - missed_detection = torch.zeros(steps, dtype=torch.float, device=preds.device) - confusion = torch.zeros(steps, dtype=torch.float, device=preds.device) - total = torch.zeros(steps, dtype=torch.float, device=preds.device) - - linspace = np.linspace(tmin, tmax, steps) - for i in range(steps): - threshold = linspace[i] - - der_fa, der_md, der_conf, der_total = _der_update(preds, target, threshold) - false_alarm[i] += der_fa - missed_detection[i] += der_md - confusion[i] += der_conf - total[i] += der_total - return false_alarm, missed_detection, confusion, total - - -def _auder_compute( - false_alarm: Tensor, - missed_detection: Tensor, - confusion: Tensor, - total: Tensor, - dx: float, -) -> Tensor: - ders = (false_alarm + missed_detection + confusion) / total - return torch.trapezoid(ders, dx=dx) - - -def auder( - preds: Tensor, - target: Tensor, - steps: int, - tmin: float, - tmax: float, - unit_area: bool, -): - fa, md, conf, total = _auder_update(preds, target, steps, tmin, tmax) - dx = ((tmax - tmin) if not unit_area else 1.0) / (steps - 1) - return _auder_update(fa, md, conf, total, dx) diff --git a/pyannote/audio/torchmetrics/functional/audio/der.py b/pyannote/audio/torchmetrics/functional/audio/der.py deleted file mode 100644 index 91fb51a35..000000000 --- a/pyannote/audio/torchmetrics/functional/audio/der.py +++ /dev/null @@ -1,76 +0,0 @@ -from typing import Tuple - 
-import torch -from torch import Tensor - -from pyannote.audio.utils.permutation import permutate - - -def _check_valid_tensors(preds: Tensor, target: Tensor): - """Check both tensors have shape (NUM_BATCH, NUM_CLASSES, NUM_FRAMES) with the same NUM_BATCH and NUM_FRAMES.""" - if len(preds.shape) != 3 or len(target.shape) != 3: - msg = f"Wrong shape ({tuple(target.shape)} or {tuple(preds.shape)}), expected (NUM_BATCH, NUM_CLASSES, NUM_FRAMES)." - raise ValueError(msg) - - batch_size, _, num_samples = target.shape - batch_size_, _, num_samples_ = preds.shape - if batch_size != batch_size_ or num_samples != num_samples_: - msg = f"Shape mismatch: {tuple(target.shape)} vs. {tuple(preds.shape)}. Both tensors should have the same NUM_BATCH and NUM_FRAMES." - raise ValueError(msg) - - -def _der_update( - preds: Tensor, target: Tensor, threshold: float -) -> Tuple[Tensor, Tensor, Tensor, Tensor]: - """Compute the false alarm, missed detection, confusion and total values. - - Parameters - ---------- - preds : torch.Tensor - preds torch.tensor of shape (B,C,F) - target : torch.Tensor - preds torch.tensor of shape (B,C,F) (must only contain 0s and 1s) - threshold : float - threshold to discretize preds - - Returns - ------- - Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] - Tensors with 1 item for false alarm, missed detection, confusion, and total - """ - - _check_valid_tensors(preds, target) - - preds_bin = (preds > threshold).float() - - # convert to/from pyannote's tensor ordering (batch,frames,class) (instead of (batch,class,frames)) - hypothesis, _ = permutate( - torch.transpose(target, 1, 2), torch.transpose(preds_bin, 1, 2) - ) - hypothesis = torch.transpose(hypothesis, 1, 2) - - detection_error = torch.sum(hypothesis, 1) - torch.sum(target, 1) - false_alarm = torch.maximum(detection_error, torch.zeros_like(detection_error)) - missed_detection = torch.maximum( - -detection_error, torch.zeros_like(detection_error) - ) - - confusion = torch.sum((hypothesis != target) * hypothesis, 1) - false_alarm - - false_alarm = torch.sum(false_alarm) - missed_detection = torch.sum(missed_detection) - confusion = torch.sum(confusion) - total = 1.0 * torch.sum(target) - - return false_alarm, missed_detection, confusion, total - - -def _der_compute( - false_alarm: Tensor, missed_detection: Tensor, confusion: Tensor, total: Tensor -) -> Tensor: - return (false_alarm + missed_detection + confusion) / total - - -def der(preds: Tensor, target: Tensor, threshold: float = 0.5): - fa, md, conf, total = _der_update(preds, target) - return _der_compute(fa, md, conf, total) diff --git a/pyannote/audio/torchmetrics/functional/audio/diarization_error_rate.py b/pyannote/audio/torchmetrics/functional/audio/diarization_error_rate.py new file mode 100644 index 000000000..77b4f3f3c --- /dev/null +++ b/pyannote/audio/torchmetrics/functional/audio/diarization_error_rate.py @@ -0,0 +1,128 @@ +# MIT License +# +# Copyright (c) 2022- CNRS +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. 
+# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + + +from typing import Tuple + +import torch + +from pyannote.audio.utils.permutation import permutate + + +def _der_update( + preds: torch.Tensor, target: torch.Tensor, threshold: float = 0.5 +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Compute components of diarization error rate + + Parameters + ---------- + preds : torch.Tensor + (batch_size, num_speakers, num_frames)-shaped continuous predictions. + target : torch.Tensor + (batch_size, num_speakers, num_frames)-shaped (0 or 1) targets. + threshold : float, optional + Threshold used to binarize predictions. Defaults to 0.5. + + Returns + ------- + false_alarm : torch.Tensor + missed_detection : torch.Tensor + speaker_confusion : torch.Tensor + speech_total : torch.Tensor + Diarization error rate components accumulated over the whole batch. + """ + + # TODO: consider doing the permutation before the binarization + # in order to improve robustness to mis-calibration. + preds_bin = (preds > threshold).float() + + # convert to/from "permutate" expected shapes + hypothesis, _ = permutate( + torch.transpose(target, 1, 2), torch.transpose(preds_bin, 1, 2) + ) + hypothesis = torch.transpose(hypothesis, 1, 2) + + detection_error = torch.sum(hypothesis, 1) - torch.sum(target, 1) + false_alarm = torch.maximum(detection_error, torch.zeros_like(detection_error)) + missed_detection = torch.maximum( + -detection_error, torch.zeros_like(detection_error) + ) + + speaker_confusion = torch.sum((hypothesis != target) * hypothesis, 1) - false_alarm + + false_alarm = torch.sum(false_alarm) + missed_detection = torch.sum(missed_detection) + speaker_confusion = torch.sum(speaker_confusion) + speech_total = 1.0 * torch.sum(target) + + return false_alarm, missed_detection, speaker_confusion, speech_total + + +def _der_compute( + false_alarm: torch.Tensor, + missed_detection: torch.Tensor, + speaker_confusion: torch.Tensor, + speech_total: torch.Tensor, +) -> torch.Tensor: + """Compute diarization error rate from its components + + Parameters + ---------- + false_alarm : torch.Tensor + missed_detection : torch.Tensor + speaker_confusion : torch.Tensor + speech_total : torch.Tensor + Diarization error rate components, in number of frames. + + Returns + ------- + der : torch.Tensor + Diarization error rate. + """ + + # TODO: handle corner case where speech_total == 0 + return (false_alarm + missed_detection + speaker_confusion) / speech_total + + +def diarization_error_rate( + preds: torch.Tensor, target: torch.Tensor, threshold: float = 0.5 +) -> torch.Tensor: + """Compute diarization error rate + + Parameters + ---------- + preds : torch.Tensor + (batch_size, num_speakers, num_frames)-shaped continuous predictions. + target : torch.Tensor + (batch_size, num_speakers, num_frames)-shaped (0 or 1) targets. + threshold : float, optional + Threshold to binarize predictions. Defaults to 0.5. 
+ + Returns + ------- + der : torch.Tensor + Aggregated diarization error rate + """ + false_alarm, missed_detection, speaker_confusion, speech_total = _der_update( + preds, target, threshold=threshold + ) + return _der_compute(false_alarm, missed_detection, speaker_confusion, speech_total) From 1dc198dba860602cbc54938ca79b714001f58357 Mon Sep 17 00:00:00 2001 From: FrenchKrab Date: Mon, 14 Mar 2022 15:59:04 +0100 Subject: [PATCH 05/20] feat: add DER torchmetrics (#909) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Hervé BREDIN --- pyannote/audio/torchmetrics/__init__.py | 36 +++++ pyannote/audio/torchmetrics/audio/__init__.py | 36 +++++ .../audio/diarization_error_rate.py | 119 ++++++++++++++++ .../audio/torchmetrics/functional/__init__.py | 21 +++ .../torchmetrics/functional/audio/__init__.py | 21 +++ .../audio/diarization_error_rate.py | 128 ++++++++++++++++++ 6 files changed, 361 insertions(+) create mode 100644 pyannote/audio/torchmetrics/__init__.py create mode 100644 pyannote/audio/torchmetrics/audio/__init__.py create mode 100644 pyannote/audio/torchmetrics/audio/diarization_error_rate.py create mode 100644 pyannote/audio/torchmetrics/functional/__init__.py create mode 100644 pyannote/audio/torchmetrics/functional/audio/__init__.py create mode 100644 pyannote/audio/torchmetrics/functional/audio/diarization_error_rate.py diff --git a/pyannote/audio/torchmetrics/__init__.py b/pyannote/audio/torchmetrics/__init__.py new file mode 100644 index 000000000..27513524b --- /dev/null +++ b/pyannote/audio/torchmetrics/__init__.py @@ -0,0 +1,36 @@ +# MIT License +# +# Copyright (c) 2022- CNRS +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+ + +from .audio.diarization_error_rate import ( + DiarizationErrorRate, + FalseAlarmRate, + MissedDetectionRate, + SpeakerConfusionRate, +) + +__all__ = [ + "DiarizationErrorRate", + "FalseAlarmRate", + "MissedDetectionRate", + "SpeakerConfusionRate", +] diff --git a/pyannote/audio/torchmetrics/audio/__init__.py b/pyannote/audio/torchmetrics/audio/__init__.py new file mode 100644 index 000000000..e500a2655 --- /dev/null +++ b/pyannote/audio/torchmetrics/audio/__init__.py @@ -0,0 +1,36 @@ +# MIT License +# +# Copyright (c) 2022- CNRS +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + + +from .diarization_error_rate import ( + DiarizationErrorRate, + FalseAlarmRate, + MissedDetectionRate, + SpeakerConfusionRate, +) + +__all__ = [ + "DiarizationErrorRate", + "SpeakerConfusionRate", + "MissedDetectionRate", + "FalseAlarmRate", +] diff --git a/pyannote/audio/torchmetrics/audio/diarization_error_rate.py b/pyannote/audio/torchmetrics/audio/diarization_error_rate.py new file mode 100644 index 000000000..c294eb3a3 --- /dev/null +++ b/pyannote/audio/torchmetrics/audio/diarization_error_rate.py @@ -0,0 +1,119 @@ +# MIT License +# +# Copyright (c) 2022- CNRS +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+
+import torch
+from torchmetrics import Metric
+
+from pyannote.audio.torchmetrics.functional.audio.diarization_error_rate import (
+    _der_compute,
+    _der_update,
+)
+
+
+class DiarizationErrorRate(Metric):
+    """Diarization error rate
+
+    Parameters
+    ----------
+    threshold : float, optional
+        Threshold used to binarize predictions. Defaults to 0.5.
+
+    Notes
+    -----
+    While the pyannote.audio convention is to store speaker activations in
+    (batch_size, num_frames, num_speakers)-shaped tensors, this torchmetrics metric
+    expects them to be shaped as (batch_size, num_speakers, num_frames) tensors.
+    """
+
+    higher_is_better = False
+    is_differentiable = False
+
+    def __init__(self, threshold: float = 0.5):
+        super().__init__()
+
+        self.threshold = threshold
+
+        self.add_state("false_alarm", default=torch.tensor(0.0), dist_reduce_fx="sum")
+        self.add_state(
+            "missed_detection", default=torch.tensor(0.0), dist_reduce_fx="sum"
+        )
+        self.add_state(
+            "speaker_confusion", default=torch.tensor(0.0), dist_reduce_fx="sum"
+        )
+        self.add_state("speech_total", default=torch.tensor(0.0), dist_reduce_fx="sum")
+
+    def update(
+        self,
+        preds: torch.Tensor,
+        target: torch.Tensor,
+    ) -> None:
+        """Compute and accumulate components of diarization error rate
+
+        Parameters
+        ----------
+        preds : torch.Tensor
+            (batch_size, num_speakers, num_frames)-shaped continuous predictions.
+        target : torch.Tensor
+            (batch_size, num_speakers, num_frames)-shaped (0 or 1) targets.
+
+        Notes
+        -----
+        This method returns nothing: the four diarization error rate components
+        (false alarm, missed detection, speaker confusion, and total speech)
+        are accumulated over the whole batch into the metric state. Call
+        compute() to obtain the aggregated diarization error rate.
+        """
+
+        false_alarm, missed_detection, speaker_confusion, speech_total = _der_update(
+            preds, target, threshold=self.threshold
+        )
+        self.false_alarm += false_alarm
+        self.missed_detection += missed_detection
+        self.speaker_confusion += speaker_confusion
+        self.speech_total += speech_total
+
+    def compute(self):
+        return _der_compute(
+            self.false_alarm,
+            self.missed_detection,
+            self.speaker_confusion,
+            self.speech_total,
+        )
+
+
+class SpeakerConfusionRate(DiarizationErrorRate):
+    def compute(self):
+        # TODO: handle corner case where speech_total == 0
+        return self.speaker_confusion / self.speech_total
+
+
+class FalseAlarmRate(DiarizationErrorRate):
+    def compute(self):
+        # TODO: handle corner case where speech_total == 0
+        return self.false_alarm / self.speech_total
+
+
+class MissedDetectionRate(DiarizationErrorRate):
+    def compute(self):
+        # TODO: handle corner case where speech_total == 0
+        return self.missed_detection / self.speech_total
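A minimal usage sketch of the metrics added in this patch (the random tensors
are purely illustrative; shapes follow the (batch_size, num_speakers,
num_frames) convention documented above, and the functional helper is the one
added at the end of this patch):

    import torch

    from pyannote.audio.torchmetrics import DiarizationErrorRate
    from pyannote.audio.torchmetrics.functional.audio.diarization_error_rate import (
        diarization_error_rate,
    )

    # toy batch: 2 chunks, 3 speakers, 100 frames
    preds = torch.rand(2, 3, 100)                   # continuous speaker activations
    target = (torch.rand(2, 3, 100) > 0.5).float()  # binary reference

    # stateful API: accumulate components batch by batch, aggregate on compute()
    metric = DiarizationErrorRate(threshold=0.5)
    metric.update(preds, target)
    der = metric.compute()

    # one-shot functional API, same result for a single batch
    der = diarization_error_rate(preds, target, threshold=0.5)

    # pyannote.audio models output (batch_size, num_frames, num_speakers),
    # so transpose with torch.transpose(activations, 1, 2) before updating

Because every state is registered with dist_reduce_fx="sum", the stateful
variant aggregates correctly across batches and across processes in
distributed validation.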
diff --git a/pyannote/audio/torchmetrics/functional/__init__.py b/pyannote/audio/torchmetrics/functional/__init__.py
new file mode 100644
index 000000000..67b544284
--- /dev/null
+++ b/pyannote/audio/torchmetrics/functional/__init__.py
@@ -0,0 +1,21 @@
+# MIT License
+#
+# Copyright (c) 2022- CNRS
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
diff --git a/pyannote/audio/torchmetrics/functional/audio/__init__.py b/pyannote/audio/torchmetrics/functional/audio/__init__.py
new file mode 100644
index 000000000..67b544284
--- /dev/null
+++ b/pyannote/audio/torchmetrics/functional/audio/__init__.py
@@ -0,0 +1,21 @@
+# MIT License
+#
+# Copyright (c) 2022- CNRS
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
diff --git a/pyannote/audio/torchmetrics/functional/audio/diarization_error_rate.py b/pyannote/audio/torchmetrics/functional/audio/diarization_error_rate.py
new file mode 100644
index 000000000..77b4f3f3c
--- /dev/null
+++ b/pyannote/audio/torchmetrics/functional/audio/diarization_error_rate.py
@@ -0,0 +1,128 @@
+# MIT License
+#
+# Copyright (c) 2022- CNRS
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+ + +from typing import Tuple + +import torch + +from pyannote.audio.utils.permutation import permutate + + +def _der_update( + preds: torch.Tensor, target: torch.Tensor, threshold: float = 0.5 +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Compute components of diarization error rate + + Parameters + ---------- + preds : torch.Tensor + (batch_size, num_speakers, num_frames)-shaped continuous predictions. + target : torch.Tensor + (batch_size, num_speakers, num_frames)-shaped (0 or 1) targets. + threshold : float, optional + Threshold used to binarize predictions. Defaults to 0.5. + + Returns + ------- + false_alarm : torch.Tensor + missed_detection : torch.Tensor + speaker_confusion : torch.Tensor + speech_total : torch.Tensor + Diarization error rate components accumulated over the whole batch. + """ + + # TODO: consider doing the permutation before the binarization + # in order to improve robustness to mis-calibration. + preds_bin = (preds > threshold).float() + + # convert to/from "permutate" expected shapes + hypothesis, _ = permutate( + torch.transpose(target, 1, 2), torch.transpose(preds_bin, 1, 2) + ) + hypothesis = torch.transpose(hypothesis, 1, 2) + + detection_error = torch.sum(hypothesis, 1) - torch.sum(target, 1) + false_alarm = torch.maximum(detection_error, torch.zeros_like(detection_error)) + missed_detection = torch.maximum( + -detection_error, torch.zeros_like(detection_error) + ) + + speaker_confusion = torch.sum((hypothesis != target) * hypothesis, 1) - false_alarm + + false_alarm = torch.sum(false_alarm) + missed_detection = torch.sum(missed_detection) + speaker_confusion = torch.sum(speaker_confusion) + speech_total = 1.0 * torch.sum(target) + + return false_alarm, missed_detection, speaker_confusion, speech_total + + +def _der_compute( + false_alarm: torch.Tensor, + missed_detection: torch.Tensor, + speaker_confusion: torch.Tensor, + speech_total: torch.Tensor, +) -> torch.Tensor: + """Compute diarization error rate from its components + + Parameters + ---------- + false_alarm : torch.Tensor + missed_detection : torch.Tensor + speaker_confusion : torch.Tensor + speech_total : torch.Tensor + Diarization error rate components, in number of frames. + + Returns + ------- + der : torch.Tensor + Diarization error rate. + """ + + # TODO: handle corner case where speech_total == 0 + return (false_alarm + missed_detection + speaker_confusion) / speech_total + + +def diarization_error_rate( + preds: torch.Tensor, target: torch.Tensor, threshold: float = 0.5 +) -> torch.Tensor: + """Compute diarization error rate + + Parameters + ---------- + preds : torch.Tensor + (batch_size, num_speakers, num_frames)-shaped continuous predictions. + target : torch.Tensor + (batch_size, num_speakers, num_frames)-shaped (0 or 1) targets. + threshold : float, optional + Threshold to binarize predictions. Defaults to 0.5. 
+ + Returns + ------- + der : torch.Tensor + Aggregated diarization error rate + """ + false_alarm, missed_detection, speaker_confusion, speech_total = _der_update( + preds, target, threshold=threshold + ) + return _der_compute(false_alarm, missed_detection, speaker_confusion, speech_total) From 9c54772f023d89e84a7494fd977c6ff8f4da25a6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Herv=C3=A9=20BREDIN?= Date: Mon, 14 Mar 2022 16:25:50 +0100 Subject: [PATCH 06/20] feat: switch to DER metrics for validating segmentation task --- pyannote/audio/tasks/segmentation/mixins.py | 1 + .../audio/tasks/segmentation/segmentation.py | 19 ++++++++++++++++--- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/pyannote/audio/tasks/segmentation/mixins.py b/pyannote/audio/tasks/segmentation/mixins.py index 0902d9267..bef518adf 100644 --- a/pyannote/audio/tasks/segmentation/mixins.py +++ b/pyannote/audio/tasks/segmentation/mixins.py @@ -466,6 +466,7 @@ def validation_step(self, batch, batch_idx: int): # y_pred = (batch_size, num_frames, num_classes) # postprocess + # TODO: remove this because metrics should take care of postprocessing y_pred = self.validation_postprocess(y, y_pred) # - remove warm-up frames diff --git a/pyannote/audio/tasks/segmentation/segmentation.py b/pyannote/audio/tasks/segmentation/segmentation.py index ee638946e..ff89c98c8 100644 --- a/pyannote/audio/tasks/segmentation/segmentation.py +++ b/pyannote/audio/tasks/segmentation/segmentation.py @@ -33,6 +33,12 @@ from pyannote.audio.core.task import Problem, Resolution, Specifications, Task from pyannote.audio.tasks.segmentation.mixins import SegmentationTaskMixin +from pyannote.audio.torchmetrics import ( + DiarizationErrorRate, + FalseAlarmRate, + MissedDetectionRate, + SpeakerConfusionRate, +) from pyannote.audio.utils.loss import binary_cross_entropy, mse_loss from pyannote.audio.utils.permutation import permutate @@ -384,9 +390,16 @@ def training_step(self, batch, batch_idx: int): ) return {"loss": loss} - def validation_postprocess(self, y, y_pred): - permutated_y_pred, _ = permutate(y, y_pred) - return permutated_y_pred + def default_metric( + self, + ) -> Union[Metric, Sequence[Metric], Dict[str, Metric]]: + """Returns diarization error rate and its components""" + return [ + DiarizationErrorRate(), + SpeakerConfusionRate(), + MissedDetectionRate(), + FalseAlarmRate(), + ] def main(protocol: str, subset: str = "test", model: str = "pyannote/segmentation"): From 3ab066e2f23fc13bc2364a3d6e9e23ec37f478c9 Mon Sep 17 00:00:00 2001 From: FrenchKrab Date: Tue, 15 Mar 2022 10:15:50 +0100 Subject: [PATCH 07/20] cleaner way to check for pseudolabel use --- .../segmentation/unsupervised_segmentation.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/pyannote/audio/tasks/segmentation/unsupervised_segmentation.py b/pyannote/audio/tasks/segmentation/unsupervised_segmentation.py index f8ef2d541..693cf56ea 100644 --- a/pyannote/audio/tasks/segmentation/unsupervised_segmentation.py +++ b/pyannote/audio/tasks/segmentation/unsupervised_segmentation.py @@ -80,11 +80,16 @@ def get_model_output(self, model: Model, waveforms: torch.Tensor): result = torch.round(result).type(torch.int8) return result + def use_pseudolabels(self, stage: Literal["train", "val"]): + return (stage == "train" and self.fake_in_train) or ( + stage == "val" and self.fake_in_val + ) + def collate_fn(self, batch): collated_batch = default_collate(batch) # Generate annotations y with teacher if they are not provided - if "y" not in 
collated_batch: + if self.use_pseudolabels("train"): teacher_input = collated_batch["X"] if self.augmentation_model is not None: teacher_input = self.augmentation_model( @@ -102,7 +107,7 @@ def collate_fn_val(self, batch): collated_batch = default_collate(batch) # Generate annotations y with teacher if they are not provided - if "y" not in collated_batch: + if self.use_pseudolabels("val"): teacher_input = collated_batch["X"] collated_batch["y"] = self.get_model_output(self.teacher, teacher_input) @@ -139,11 +144,8 @@ def prepare_chunk( ... """ - use_annotations = (stage == "train" and not self.fake_in_train) or ( - stage == "val" and not self.fake_in_val - ) sample = super().prepare_chunk( - file, chunk, duration=duration, stage=stage, use_annotations=use_annotations + file, chunk, duration=duration, stage=stage, use_annotations=True ) return sample From c3d6758dc8a7b2d43e21e032c553733dff58ce2b Mon Sep 17 00:00:00 2001 From: FrenchKrab Date: Tue, 15 Mar 2022 10:21:03 +0100 Subject: [PATCH 08/20] remove AUDER histogram debug method in UnsupervisedSegmentation --- .../segmentation/unsupervised_segmentation.py | 41 ------------------- 1 file changed, 41 deletions(-) diff --git a/pyannote/audio/tasks/segmentation/unsupervised_segmentation.py b/pyannote/audio/tasks/segmentation/unsupervised_segmentation.py index 693cf56ea..e6e71fea5 100644 --- a/pyannote/audio/tasks/segmentation/unsupervised_segmentation.py +++ b/pyannote/audio/tasks/segmentation/unsupervised_segmentation.py @@ -9,7 +9,6 @@ from pytorch_lightning.utilities.types import STEP_OUTPUT from torch.utils.data import DataLoader from torch.utils.data._utils.collate import default_collate -from torch.utils.tensorboard import SummaryWriter from torch_audiomentations.core.transforms_interface import BaseWaveformTransform from torchmetrics import Metric from typing_extensions import Literal @@ -18,7 +17,6 @@ from pyannote.audio.core.model import Model from pyannote.audio.core.task import Task, ValDataset from pyannote.audio.tasks import Segmentation -from pyannote.audio.torchmetrics import AUDER class UnsupervisedSegmentation(Segmentation, Task): @@ -162,45 +160,6 @@ def val_dataloader(self) -> Optional[DataLoader]: else: return None - def validation_epoch_end(self, outputs): - super().validation_epoch_end(outputs) - - # TODO : remove (temp debug) - for key, metric in self.model.validation_metric.items(): - # print(key) - if isinstance(metric, AUDER): - ders = ( - metric.false_alarm + metric.missed_detection + metric.confusion - ) / metric.total - # print(ders) - # print(metric.linspace) - SAMPLE = 100 - data = [] - bins = [-0.0001] - linspace = np.linspace( - metric.threshold_min, metric.threshold_max, metric.steps - ) - for i in range(metric.steps): - data += [linspace[i]] * int(ders[i] * SAMPLE) - if i > 0: - bins += [(linspace[i] + linspace[i - 1]) / 2.0] - bins += [1.0001] - values = torch.tensor(data).reshape(-1) - # print(bins) - # print(data) - experiment: SummaryWriter = self.model.logger.experiment - - experiment.add_histogram( - "der_curve", - global_step=self.model.current_epoch, - values=values, - bins=bins, - ) - - # fig, ax = plt.subplots() # Create a figure containing a single axes. 
-                # ax.plot(metric.linspace, ders.cpu())
-                # experiment.add_figure("testplot", fig, global_step=0)
 
 
 class TeacherUpdate(Callback):
     def __init__(
         self,
From cf6bd3299db30d073e4d741c49fb5e2e20259036 Mon Sep 17 00:00:00 2001
From: FrenchKrab
Date: Tue, 15 Mar 2022 10:23:03 +0100
Subject: [PATCH 09/20] update UnsupervisedSegmentation to take a metric arg
 (instead of metrics)

---
 .../audio/tasks/segmentation/unsupervised_segmentation.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/pyannote/audio/tasks/segmentation/unsupervised_segmentation.py b/pyannote/audio/tasks/segmentation/unsupervised_segmentation.py
index e6e71fea5..aba325e93 100644
--- a/pyannote/audio/tasks/segmentation/unsupervised_segmentation.py
+++ b/pyannote/audio/tasks/segmentation/unsupervised_segmentation.py
@@ -40,7 +40,7 @@ def __init__(
         augmentation: BaseWaveformTransform = None,
         loss: Literal["bce", "mse"] = "bce",
         vad_loss: Literal["bce", "mse"] = None,
-        metrics: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None,
+        metric: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None,
     ):
         super().__init__(
             # Mixin params
@@ -58,7 +58,7 @@ def __init__(
             weight=weight,
             loss=loss,
             vad_loss=vad_loss,
-            metrics=metrics,
+            metric=metric,
         )
 
         self.teacher = model
From 76c1eda82ba6809c650b76f9aabcc9ffd66492b1 Mon Sep 17 00:00:00 2001
From: FrenchKrab
Date: Tue, 15 Mar 2022 13:44:01 +0100
Subject: [PATCH 10/20] fix: fix SegmentationTaskMixin.prepare_chunk typing
 (#901)

---
 pyannote/audio/tasks/segmentation/mixins.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/pyannote/audio/tasks/segmentation/mixins.py b/pyannote/audio/tasks/segmentation/mixins.py
index 6878e2baf..94c18ca8e 100644
--- a/pyannote/audio/tasks/segmentation/mixins.py
+++ b/pyannote/audio/tasks/segmentation/mixins.py
@@ -23,7 +23,7 @@
 import math
 import random
 import warnings
-from typing import List, Optional, Text, Tuple
+from typing import List, Optional, Text
 
 import matplotlib.pyplot as plt
 import numpy as np
@@ -158,7 +158,7 @@ def prepare_chunk(
         chunk: Segment,
         duration: float = None,
         stage: Literal["train", "val"] = "train",
-    ) -> Tuple[np.ndarray, np.ndarray, List[Text]]:
+    ) -> dict:
         """Extract audio chunk and corresponding frame-wise labels
 
         Parameters
From 01904eb4ca41e89f04c81b88d652146e79b1347f Mon Sep 17 00:00:00 2001
From: FrenchKrab
Date: Tue, 15 Mar 2022 16:49:04 +0100
Subject: [PATCH 11/20] * feat: add support for custom validation metrics
 (#913)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* BREAKING: use task and protocol full names in logs and checkpoint paths

* BREAKING: use diarization error rate for validation of segmentation

Co-authored-by: Hervé BREDIN
---
 pyannote/audio/cli/train.py                   | 15 +++--
 pyannote/audio/core/model.py                  |  4 +-
 pyannote/audio/core/task.py                   | 53 ++++++++++----
 pyannote/audio/tasks/embedding/arcface.py     | 14 +++--
 pyannote/audio/tasks/embedding/mixins.py      | 37 ++++-------
 pyannote/audio/tasks/segmentation/mixins.py   | 63 ++++++++-----------
 .../overlapped_speech_detection.py            | 14 +++--
 .../audio/tasks/segmentation/segmentation.py  | 63 +++++++++++++------
 .../segmentation/speaker_change_detection.py  | 14 +++--
 .../tasks/segmentation/speaker_tracking.py    | 14 +++--
 .../segmentation/voice_activity_detection.py  | 14 +++--
 tutorials/add_your_own_task.ipynb             |  2 -
 12 files changed, 177 insertions(+), 130 deletions(-)

diff --git a/pyannote/audio/cli/train.py b/pyannote/audio/cli/train.py
index b8547e446..e7ab2225a 100644
---
a/pyannote/audio/cli/train.py +++ b/pyannote/audio/cli/train.py @@ -27,20 +27,19 @@ import hydra from hydra.utils import instantiate from omegaconf import DictConfig, OmegaConf + +# from pyannote.audio.core.callback import GraduallyUnfreeze +from pyannote.database import FileFinder, get_protocol from pytorch_lightning.callbacks import ( EarlyStopping, LearningRateMonitor, ModelCheckpoint, RichProgressBar, ) - from pytorch_lightning.loggers import TensorBoardLogger from pytorch_lightning.utilities.seed import seed_everything from torch_audiomentations.utils.config import from_dict as get_augmentation -# from pyannote.audio.core.callback import GraduallyUnfreeze -from pyannote.database import FileFinder, get_protocol - @hydra.main(config_path="train_config", config_name="config") def train(cfg: DictConfig) -> Optional[float]: @@ -67,15 +66,15 @@ def train(cfg: DictConfig) -> Optional[float]: # instantiate task and validation metric task = instantiate(cfg.task, protocol, augmentation=augmentation) - # validation metric to monitor (and its direction: min or max) - monitor, direction = task.val_monitor - # instantiate model fine_tuning = cfg.model["_target_"] == "pyannote.audio.cli.pretrained" model = instantiate(cfg.model) model.task = task model.setup(stage="fit") + # validation metric to monitor (and its direction: min or max) + monitor, direction = task.val_monitor + # number of batches in one epoch num_batches_per_epoch = model.task.train__len__() // model.task.batch_size @@ -90,6 +89,7 @@ def configure_optimizers(self): num_batches_per_epoch=num_batches_per_epoch, ) return {"optimizer": optimizer, "lr_scheduler": lr_scheduler} + model.configure_optimizers = MethodType(configure_optimizers, model) callbacks = [RichProgressBar(), LearningRateMonitor()] @@ -155,4 +155,3 @@ def configure_optimizers(self): if __name__ == "__main__": train() - diff --git a/pyannote/audio/core/model.py b/pyannote/audio/core/model.py index 4e9bfe092..c36b6927b 100644 --- a/pyannote/audio/core/model.py +++ b/pyannote/audio/core/model.py @@ -34,6 +34,7 @@ import torch.nn as nn import torch.optim from huggingface_hub import cached_download, hf_hub_url +from pyannote.core import SlidingWindow from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.model_summary import ModelSummary from semver import VersionInfo @@ -42,7 +43,6 @@ from pyannote.audio import __version__ from pyannote.audio.core.io import Audio from pyannote.audio.core.task import Problem, Resolution, Specifications, Task -from pyannote.core import SlidingWindow CACHE_DIR = os.getenv( "PYANNOTE_CACHE", @@ -385,7 +385,7 @@ def setup(self, stage=None): # setup custom loss function self.task.setup_loss_func() # setup custom validation metrics - validation_metric = self.task.setup_validation_metric() + validation_metric = self.task.metric if validation_metric is not None: self.validation_metric = validation_metric self.validation_metric.to(self.device) diff --git a/pyannote/audio/core/task.py b/pyannote/audio/core/task.py index b42c9c90a..0233a87d8 100644 --- a/pyannote/audio/core/task.py +++ b/pyannote/audio/core/task.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2020-2021 CNRS +# Copyright (c) 2020- CNRS # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -23,24 +23,30 @@ from __future__ import annotations +try: + from functools import cached_property +except ImportError: + from 
backports.cached_property import cached_property + import multiprocessing import sys import warnings from dataclasses import dataclass from enum import Enum from numbers import Number -from typing import List, Optional, Text, Tuple, Union +from typing import Dict, List, Optional, Sequence, Text, Tuple, Union import pytorch_lightning as pl import torch +from pyannote.database import Protocol from torch.utils.data import DataLoader, Dataset, IterableDataset from torch.utils.data._utils.collate import default_collate from torch_audiomentations.core.transforms_interface import BaseWaveformTransform +from torchmetrics import Metric, MetricCollection from typing_extensions import Literal from pyannote.audio.utils.loss import binary_cross_entropy, nll_loss from pyannote.audio.utils.protocol import check_protocol -from pyannote.database import Protocol # Type of machine learning problem @@ -151,6 +157,9 @@ class Task(pl.LightningDataModule): augmentation : BaseWaveformTransform, optional torch_audiomentations waveform transform, used by dataloader during training. + metric : optional + Validation metric(s). Can be anything supported by torchmetrics.MetricCollection. + Defaults to value returned by `default_metric` method. Attributes ---------- @@ -168,6 +177,7 @@ def __init__( num_workers: int = None, pin_memory: bool = False, augmentation: BaseWaveformTransform = None, + metric: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None, ): super().__init__() @@ -203,6 +213,7 @@ def __init__( self.num_workers = num_workers self.pin_memory = pin_memory self.augmentation = augmentation + self._metric = metric def prepare_data(self): """Use this to download and prepare data @@ -233,9 +244,6 @@ def setup(self, stage: Optional[str] = None): def setup_loss_func(self): pass - def setup_validation_metric(self): - pass - def train__iter__(self): # will become train_dataset.__iter__ method msg = f"Missing '{self.__class__.__name__}.train__iter__' method." @@ -264,6 +272,18 @@ def train_dataloader(self) -> DataLoader: collate_fn=self.collate_fn, ) + @cached_property + def logging_prefix(self): + + prefix = f"{self.__class__.__name__}-" + if hasattr(self.protocol, "name"): + # "." has a special meaning for pytorch-lightning checkpointing + # so we remove dots from protocol names + name_without_dots = "".join(self.protocol.name.split(".")) + prefix += f"{name_without_dots}-" + + return prefix + def default_loss( self, specifications: Specifications, target, prediction, weight=None ) -> torch.Tensor: @@ -353,7 +373,7 @@ def common_step(self, batch, batch_idx: int, stage: Literal["train", "val"]): # compute loss loss = self.default_loss(self.specifications, y, y_pred, weight=weight) self.model.log( - f"{self.ACRONYM}@{stage}_loss", + f"{self.logging_prefix}{stage.capitalize()}Loss", loss, on_step=False, on_epoch=True, @@ -397,6 +417,18 @@ def validation_step(self, batch, batch_idx: int): def validation_epoch_end(self, outputs): pass + def default_metric(self) -> Union[Metric, Sequence[Metric], Dict[str, Metric]]: + """Default validation metric""" + msg = f"Missing '{self.__class__.__name__}.default_metric' method." 
+ raise NotImplementedError(msg) + + @cached_property + def metric(self) -> MetricCollection: + if self._metric is None: + self._metric = self.default_metric() + + return MetricCollection(self._metric, prefix=self.logging_prefix) + @property def val_monitor(self): """Quantity (and direction) to monitor @@ -415,7 +447,6 @@ def val_monitor(self): pytorch_lightning.callbacks.ModelCheckpoint pytorch_lightning.callbacks.EarlyStopping """ - if self.has_validation: - return f"{self.ACRONYM}@val_loss", "min" - else: - return None, "min" + + name, metric = next(iter(self.metric.items())) + return name, "max" if metric.higher_is_better else "min" diff --git a/pyannote/audio/tasks/embedding/arcface.py b/pyannote/audio/tasks/embedding/arcface.py index 33ac394e1..bb2cb1f6c 100644 --- a/pyannote/audio/tasks/embedding/arcface.py +++ b/pyannote/audio/tasks/embedding/arcface.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2020-2021 CNRS +# Copyright (c) 2020- CNRS # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -23,11 +23,14 @@ from __future__ import annotations +from typing import Dict, Sequence, Union + import pytorch_metric_learning.losses +from pyannote.database import Protocol from torch_audiomentations.core.transforms_interface import BaseWaveformTransform +from torchmetrics import Metric from pyannote.audio.core.task import Task -from pyannote.database import Protocol from .mixins import SupervisedRepresentationLearningTaskMixin @@ -67,10 +70,11 @@ class SupervisedRepresentationLearningWithArcFace( augmentation : BaseWaveformTransform, optional torch_audiomentations waveform transform, used by dataloader during training. + metric : optional + Validation metric(s). Can be anything supported by torchmetrics.MetricCollection. + Defaults to AUROC (area under the ROC curve). """ - ACRONYM = "arcface" - #  TODO: add a ".metric" property that tells how speaker embedding trained with this approach #  should be compared. could be a string like "cosine" or "euclidean" or a pdist/cdist-like #  callable. this ".metric" property should be propagated all the way to Inference (via the model). @@ -87,6 +91,7 @@ def __init__( num_workers: int = None, pin_memory: bool = False, augmentation: BaseWaveformTransform = None, + metric: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None, ): self.num_chunks_per_class = num_chunks_per_class @@ -103,6 +108,7 @@ def __init__( num_workers=num_workers, pin_memory=pin_memory, augmentation=augmentation, + metric=metric, ) def setup_loss_func(self): diff --git a/pyannote/audio/tasks/embedding/mixins.py b/pyannote/audio/tasks/embedding/mixins.py index a212404de..13d65edb1 100644 --- a/pyannote/audio/tasks/embedding/mixins.py +++ b/pyannote/audio/tasks/embedding/mixins.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2020-2021 CNRS +# Copyright (c) 2020- CNRS # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -21,20 +21,20 @@ # SOFTWARE. 
import math -from typing import Optional +from typing import Dict, Optional, Sequence, Union import torch import torch.nn.functional as F -from torchmetrics import AUROC -from tqdm import tqdm - -from pyannote.audio.core.task import Problem, Resolution, Specifications -from pyannote.audio.utils.random import create_rng_for_worker from pyannote.core import Segment from pyannote.database.protocol import ( SpeakerDiarizationProtocol, SpeakerVerificationProtocol, ) +from torchmetrics import AUROC, Metric +from tqdm import tqdm + +from pyannote.audio.core.task import Problem, Resolution, Specifications +from pyannote.audio.utils.random import create_rng_for_worker class SupervisedRepresentationLearningTaskMixin: @@ -124,7 +124,9 @@ def setup(self, stage: Optional[str] = None): if isinstance(self.protocol, SpeakerVerificationProtocol): self._validation = list(self.protocol.development_trial()) - def setup_validation_metric(self): + def default_metric( + self, + ) -> Union[Metric, Sequence[Metric], Dict[str, Metric]]: return AUROC(compute_on_step=False) def train__iter__(self): @@ -219,7 +221,7 @@ def training_step(self, batch, batch_idx: int): loss = self.model.loss_func(self.model(X), y) self.model.log( - f"{self.ACRONYM}@train_loss", + f"{self.logging_prefix}TrainLoss", loss, on_step=False, on_epoch=True, @@ -277,8 +279,7 @@ def validation_step(self, batch, batch_idx: int): y_true = batch["y"] self.model.validation_metric(y_pred, y_true) - self.model.log( - f"{self.ACRONYM}@val_auroc", + self.model.log_dict( self.model.validation_metric, on_step=False, on_epoch=True, @@ -288,17 +289,3 @@ def validation_step(self, batch, batch_idx: int): elif isinstance(self.protocol, SpeakerDiarizationProtocol): pass - - @property - def val_monitor(self): - - if self.has_validation: - - if isinstance(self.protocol, SpeakerVerificationProtocol): - return f"{self.ACRONYM}@val_auroc", "max" - - elif isinstance(self.protocol, SpeakerDiarizationProtocol): - return None, "min" - - else: - return None, "min" diff --git a/pyannote/audio/tasks/segmentation/mixins.py b/pyannote/audio/tasks/segmentation/mixins.py index 94c18ca8e..e2e6ed60a 100644 --- a/pyannote/audio/tasks/segmentation/mixins.py +++ b/pyannote/audio/tasks/segmentation/mixins.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2020-2021 CNRS +# Copyright (c) 2020- CNRS # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -23,17 +23,18 @@ import math import random import warnings -from typing import List, Optional, Text +from typing import Dict, List, Optional, Sequence, Text, Union import matplotlib.pyplot as plt import numpy as np -from torchmetrics import AUROC +import torch +from pyannote.core import Annotation, Segment, SlidingWindow, SlidingWindowFeature +from torchmetrics import AUROC, Metric from typing_extensions import Literal from pyannote.audio.core.io import Audio, AudioFile from pyannote.audio.core.task import Problem from pyannote.audio.utils.random import create_rng_for_worker -from pyannote.core import Annotation, Segment, SlidingWindow, SlidingWindowFeature class SegmentationTaskMixin: @@ -120,20 +121,12 @@ def setup(self, stage: Optional[str] = None): random.shuffle(self._validation) - def setup_validation_metric(self): - """Setup default validation metric - - Use macro-average of area under the ROC curve - """ - - if self.specifications.problem in [ - Problem.BINARY_CLASSIFICATION, - Problem.MULTI_LABEL_CLASSIFICATION, - ]: - 
num_classes = 1
-        else:
-            num_classes = len(self.specifications.classes)
+    def default_metric(
+        self,
+    ) -> Union[Metric, Sequence[Metric], Dict[str, Metric]]:
+        """Returns macro-average of the area under the ROC curve"""
 
+        num_classes = len(self.specifications.classes)
         return AUROC(num_classes, pos_label=1, average="macro", compute_on_step=False)
 
     def prepare_y(self, one_hot_y: np.ndarray) -> np.ndarray:
@@ -473,6 +466,7 @@ def validation_step(self, batch, batch_idx: int):
         # y_pred = (batch_size, num_frames, num_classes)
 
         # postprocess
+        # TODO: remove this because metrics should take care of postprocessing
         y_pred = self.validation_postprocess(y, y_pred)
 
         # - remove warm-up frames
@@ -492,34 +486,32 @@ def validation_step(self, batch, batch_idx: int):
         # preds: shape (batch_size, num_frames, 1), type float
 
         # torchmetrics expects:
-        # target: shape (N,), type binary
-        # preds: shape (N,), type float
+        # target: shape (batch_size * num_frames,), type binary
+        # preds: shape (batch_size * num_frames,), type float
 
-        self.model.validation_metric(preds.reshape(-1), target.reshape(-1))
+        self.model.validation_metric(
+            preds.reshape(-1),
+            target.reshape(-1),
+        )
 
         elif self.specifications.problem == Problem.MULTI_LABEL_CLASSIFICATION:
             # target: shape (batch_size, num_frames, num_classes), type binary
             # preds: shape (batch_size, num_frames, num_classes), type float
 
             # torchmetrics expects
-            # target: shape (N, ), type binary
-            # preds: shape (N, ), type float
+            # target: shape (batch_size, num_classes, ...), type binary
+            # preds: shape (batch_size, num_classes, ...), type float
 
-            self.model.validation_metric(preds.reshape(-1), target.reshape(-1))
+            self.model.validation_metric(
+                torch.transpose(preds, 1, 2),
+                torch.transpose(target, 1, 2),
+            )
 
         elif self.specifications.problem == Problem.MONO_LABEL_CLASSIFICATION:
-            # target: shape (batch_size, num_frames, num_classes), type binary
-            # preds: shape (batch_size, num_frames, num_classes), type float
-
-            # torchmetrics expects:
-            # target: shape (N, ), type int
-            # preds: shape (N, num_classes), type float
-
             # TODO: implement when pyannote.audio gets its first mono-label segmentation task
             raise NotImplementedError()
 
-        self.model.log(
-            f"{self.ACRONYM}@val_auroc",
+        self.model.log_dict(
             self.model.validation_metric,
             on_step=False,
             on_epoch=True,
@@ -593,12 +585,7 @@ def validation_step(self, batch, batch_idx: int):
         plt.tight_layout()
 
         self.model.logger.experiment.add_figure(
-            f"{self.ACRONYM}@val_samples", fig, self.model.current_epoch
+            f"{self.logging_prefix}ValSamples", fig, self.model.current_epoch
         )
 
         plt.close(fig)
-
-    @property
-    def val_monitor(self):
-        """Maximize validation area under ROC curve"""
-        return f"{self.ACRONYM}@val_auroc", "max"
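Validation now hands the whole metric collection to log_dict, so each
component is logged under its own name. A standalone sketch of how
MetricCollection fans out updates and prefixes results (the prefix below is a
placeholder for the task/protocol-derived one):

    import torch
    from torchmetrics import MetricCollection

    from pyannote.audio.torchmetrics import (
        DiarizationErrorRate,
        FalseAlarmRate,
        MissedDetectionRate,
        SpeakerConfusionRate,
    )

    metrics = MetricCollection(
        [
            DiarizationErrorRate(),
            SpeakerConfusionRate(),
            MissedDetectionRate(),
            FalseAlarmRate(),
        ],
        prefix="Segmentation-MyProtocol-",  # placeholder prefix
    )

    preds = torch.rand(2, 3, 100)
    target = (torch.rand(2, 3, 100) > 0.5).float()
    metrics.update(preds, target)
    metrics.compute()
    # -> {"Segmentation-MyProtocol-DiarizationErrorRate": tensor(...), ...}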
diff --git a/pyannote/audio/tasks/segmentation/overlapped_speech_detection.py b/pyannote/audio/tasks/segmentation/overlapped_speech_detection.py
index acad96c87..15b7024f4 100644
--- a/pyannote/audio/tasks/segmentation/overlapped_speech_detection.py
+++ b/pyannote/audio/tasks/segmentation/overlapped_speech_detection.py
@@ -1,6 +1,6 @@
 # MIT License
 #
-# Copyright (c) 2020-2021 CNRS
+# Copyright (c) 2020- CNRS
 #
 # Permission is hereby granted, free of charge, to any person obtaining a copy
 # of this software and associated documentation files (the "Software"), to deal
@@ -21,14 +21,15 @@
 # SOFTWARE.
-from typing import Text, Tuple, Union
+from typing import Dict, Sequence, Text, Tuple, Union
 
 import numpy as np
+from pyannote.database import Protocol
 from torch_audiomentations.core.transforms_interface import BaseWaveformTransform
+from torchmetrics import Metric
 
 from pyannote.audio.core.task import Problem, Resolution, Specifications, Task
 from pyannote.audio.tasks.segmentation.mixins import SegmentationTaskMixin
-from pyannote.database import Protocol
 
 
 class OverlappedSpeechDetection(SegmentationTaskMixin, Task):
@@ -83,10 +84,11 @@ class OverlappedSpeechDetection(SegmentationTaskMixin, Task):
     augmentation : BaseWaveformTransform, optional
         torch_audiomentations waveform transform, used by dataloader
         during training.
+    metric : optional
+        Validation metric(s). Can be anything supported by torchmetrics.MetricCollection.
+        Defaults to AUROC (area under the ROC curve).
     """
 
-    ACRONYM = "osd"
-
     OVERLAP_DEFAULTS = {"probability": 0.5, "snr_min": 0.0, "snr_max": 10.0}
 
     def __init__(
@@ -101,6 +103,7 @@ def __init__(
         num_workers: int = None,
         pin_memory: bool = False,
         augmentation: BaseWaveformTransform = None,
+        metric: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None,
     ):
 
         super().__init__(
@@ -111,6 +114,7 @@ def __init__(
             num_workers=num_workers,
             pin_memory=pin_memory,
             augmentation=augmentation,
+            metric=metric,
         )
 
         self.specifications = Specifications(
diff --git a/pyannote/audio/tasks/segmentation/segmentation.py b/pyannote/audio/tasks/segmentation/segmentation.py
index 76f620586..579878247 100644
--- a/pyannote/audio/tasks/segmentation/segmentation.py
+++ b/pyannote/audio/tasks/segmentation/segmentation.py
@@ -1,6 +1,6 @@
 # MIT License
 #
-# Copyright (c) 2020-2021 CNRS
+# Copyright (c) 2020- CNRS
 #
 # Permission is hereby granted, free of charge, to any person obtaining a copy
 # of this software and associated documentation files (the "Software"), to deal
@@ -21,19 +21,26 @@
 # SOFTWARE.
 
 from collections import Counter
-from typing import Optional, Text, Tuple, Union
+from typing import Dict, Optional, Sequence, Text, Tuple, Union
 
 import numpy as np
 import torch
+from pyannote.core import SlidingWindow
+from pyannote.database import Protocol
 from torch_audiomentations.core.transforms_interface import BaseWaveformTransform
+from torchmetrics import Metric
 from typing_extensions import Literal
 
 from pyannote.audio.core.task import Problem, Resolution, Specifications, Task
 from pyannote.audio.tasks.segmentation.mixins import SegmentationTaskMixin
+from pyannote.audio.torchmetrics import (
+    DiarizationErrorRate,
+    FalseAlarmRate,
+    MissedDetectionRate,
+    SpeakerConfusionRate,
+)
 from pyannote.audio.utils.loss import binary_cross_entropy, mse_loss
 from pyannote.audio.utils.permutation import permutate
-from pyannote.core import SlidingWindow
-from pyannote.database import Protocol
 
 
 class Segmentation(SegmentationTaskMixin, Task):
@@ -80,8 +87,13 @@ class Segmentation(SegmentationTaskMixin, Task):
     augmentation : BaseWaveformTransform, optional
         torch_audiomentations waveform transform, used by dataloader
        during training.
+    loss : {"bce", "mse"}, optional
+        Permutation-invariant segmentation loss. Defaults to "bce".
     vad_loss : {"bce", "mse"}, optional
        Add voice activity detection loss.
+    metric : optional
+        Validation metric(s). Can be anything supported by torchmetrics.MetricCollection.
+        Defaults to diarization error rate and its components.
 
     Reference
     ----------
     Hervé Bredin and Antoine Laurent
     "End-to-end speaker segmentation for overlap-aware resegmentation."
     Proc.
Interspeech 2021 """ - ACRONYM = "seg" - OVERLAP_DEFAULTS = {"probability": 0.5, "snr_min": 0.0, "snr_max": 10.0} def __init__( @@ -109,6 +119,7 @@ def __init__( augmentation: BaseWaveformTransform = None, loss: Literal["bce", "mse"] = "bce", vad_loss: Literal["bce", "mse"] = None, + metric: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None, ): super().__init__( @@ -119,6 +130,7 @@ def __init__( num_workers=num_workers, pin_memory=pin_memory, augmentation=augmentation, + metric=metric, ) self.max_num_speakers = max_num_speakers @@ -146,7 +158,10 @@ def setup(self, stage: Optional[str] = None): start = file["annotated"][0].start end = file["annotated"][-1].end window = SlidingWindow( - start=start, end=end, duration=self.duration, step=1.0, + start=start, + end=end, + duration=self.duration, + step=1.0, ) for chunk in window: num_speakers.append(len(file["annotation"].crop(chunk).labels())) @@ -322,7 +337,8 @@ def training_step(self, batch, batch_idx: int): # frames weight weight_key = getattr(self, "weight", None) weight = batch.get( - weight_key, torch.ones(batch_size, num_frames, 1, device=self.model.device), + weight_key, + torch.ones(batch_size, num_frames, 1, device=self.model.device), ) # (batch_size, num_frames, 1) @@ -335,7 +351,7 @@ def training_step(self, batch, batch_idx: int): seg_loss = self.segmentation_loss(permutated_prediction, target, weight=weight) self.model.log( - f"{self.ACRONYM}@train_seg_loss", + f"{self.logging_prefix}TrainSegLoss", seg_loss, on_step=False, on_epoch=True, @@ -352,7 +368,7 @@ def training_step(self, batch, batch_idx: int): ) self.model.log( - f"{self.ACRONYM}@train_vad_loss", + f"{self.logging_prefix}TrainVADLoss", vad_loss, on_step=False, on_epoch=True, @@ -363,29 +379,38 @@ def training_step(self, batch, batch_idx: int): loss = seg_loss + vad_loss self.model.log( - f"{self.ACRONYM}@train_loss", + f"{self.logging_prefix}TrainLoss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True, ) + return {"loss": loss} - def validation_postprocess(self, y, y_pred): - permutated_y_pred, _ = permutate(y, y_pred) - return permutated_y_pred + def default_metric( + self, + ) -> Union[Metric, Sequence[Metric], Dict[str, Metric]]: + """Returns diarization error rate and its components""" + return [ + DiarizationErrorRate(), + SpeakerConfusionRate(), + MissedDetectionRate(), + FalseAlarmRate(), + ] def main(protocol: str, subset: str = "test", model: str = "pyannote/segmentation"): """Evaluate a segmentation model""" + from pyannote.database import FileFinder, get_protocol + from rich.progress import Progress + from pyannote.audio import Inference from pyannote.audio.pipelines.utils import get_devices - from pyannote.audio.utils.signal import binarize from pyannote.audio.utils.metric import DiscreteDiarizationErrorRate - from pyannote.database import get_protocol, FileFinder - from rich.progress import Progress + from pyannote.audio.utils.signal import binarize (device,) = get_devices(needs=1) metric = DiscreteDiarizationErrorRate() @@ -400,9 +425,7 @@ def main(protocol: str, subset: str = "test", model: str = "pyannote/segmentatio def progress_hook(completed: int, total: int): progress.update(file_task, completed=completed / total) - inference = Inference( - model, device=device, progress_hook=progress_hook - ) + inference = Inference(model, device=device, progress_hook=progress_hook) for file in files: progress.update(file_task, description=file["uri"]) diff --git a/pyannote/audio/tasks/segmentation/speaker_change_detection.py 
b/pyannote/audio/tasks/segmentation/speaker_change_detection.py index f11ba1e9d..0329d361e 100644 --- a/pyannote/audio/tasks/segmentation/speaker_change_detection.py +++ b/pyannote/audio/tasks/segmentation/speaker_change_detection.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2020-2021 CNRS +# Copyright (c) 2020- CNRS # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -20,15 +20,16 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -from typing import Text, Tuple, Union +from typing import Dict, Sequence, Text, Tuple, Union import numpy as np import scipy.signal +from pyannote.database import Protocol from torch_audiomentations.core.transforms_interface import BaseWaveformTransform +from torchmetrics import Metric from pyannote.audio.core.task import Problem, Resolution, Specifications, Task from pyannote.audio.tasks.segmentation.mixins import SegmentationTaskMixin -from pyannote.database import Protocol class SpeakerChangeDetection(SegmentationTaskMixin, Task): @@ -76,10 +77,11 @@ class SpeakerChangeDetection(SegmentationTaskMixin, Task): augmentation : BaseWaveformTransform, optional torch_audiomentations waveform transform, used by dataloader during training. + metric : optional + Validation metric(s). Can be anything supported by torchmetrics.MetricCollection. + Defaults to AUROC (area under the ROC curve). """ - ACRONYM = "scd" - def __init__( self, protocol: Protocol, @@ -92,6 +94,7 @@ def __init__( num_workers: int = None, pin_memory: bool = False, augmentation: BaseWaveformTransform = None, + metric: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None, ): super().__init__( @@ -102,6 +105,7 @@ def __init__( num_workers=num_workers, pin_memory=pin_memory, augmentation=augmentation, + metric=metric, ) self.balance = balance diff --git a/pyannote/audio/tasks/segmentation/speaker_tracking.py b/pyannote/audio/tasks/segmentation/speaker_tracking.py index c726f03d6..9daf48627 100644 --- a/pyannote/audio/tasks/segmentation/speaker_tracking.py +++ b/pyannote/audio/tasks/segmentation/speaker_tracking.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2020-2021 CNRS +# Copyright (c) 2020- CNRS # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -20,14 +20,15 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -from typing import List, Optional, Text, Tuple, Union +from typing import Dict, List, Optional, Sequence, Text, Tuple, Union import numpy as np +from pyannote.database import Protocol from torch_audiomentations.core.transforms_interface import BaseWaveformTransform +from torchmetrics import Metric from pyannote.audio.core.task import Problem, Resolution, Specifications, Task from pyannote.audio.tasks.segmentation.mixins import SegmentationTaskMixin -from pyannote.database import Protocol class SpeakerTracking(SegmentationTaskMixin, Task): @@ -70,10 +71,11 @@ class SpeakerTracking(SegmentationTaskMixin, Task): augmentation : BaseWaveformTransform, optional torch_audiomentations waveform transform, used by dataloader during training. + metric : optional + Validation metric(s). Can be anything supported by torchmetrics.MetricCollection. + Defaults to AUROC (area under the ROC curve). 
""" - ACRONYM = "spk" - def __init__( self, protocol: Protocol, @@ -85,6 +87,7 @@ def __init__( num_workers: int = None, pin_memory: bool = False, augmentation: BaseWaveformTransform = None, + metric: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None, ): super().__init__( @@ -95,6 +98,7 @@ def __init__( num_workers=num_workers, pin_memory=pin_memory, augmentation=augmentation, + metric=metric, ) self.balance = balance diff --git a/pyannote/audio/tasks/segmentation/voice_activity_detection.py b/pyannote/audio/tasks/segmentation/voice_activity_detection.py index 87f9c92f5..123451968 100644 --- a/pyannote/audio/tasks/segmentation/voice_activity_detection.py +++ b/pyannote/audio/tasks/segmentation/voice_activity_detection.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2020-2021 CNRS +# Copyright (c) 2020- CNRS # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -20,14 +20,15 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -from typing import Text, Tuple, Union +from typing import Dict, Sequence, Text, Tuple, Union import numpy as np +from pyannote.database import Protocol from torch_audiomentations.core.transforms_interface import BaseWaveformTransform +from torchmetrics import Metric from pyannote.audio.core.task import Problem, Resolution, Specifications, Task from pyannote.audio.tasks.segmentation.mixins import SegmentationTaskMixin -from pyannote.database import Protocol class VoiceActivityDetection(SegmentationTaskMixin, Task): @@ -69,10 +70,11 @@ class VoiceActivityDetection(SegmentationTaskMixin, Task): augmentation : BaseWaveformTransform, optional torch_audiomentations waveform transform, used by dataloader during training. + metric : optional + Validation metric(s). Can be anything supported by torchmetrics.MetricCollection. + Defaults to AUROC (area under the ROC curve). 
""" - ACRONYM = "vad" - def __init__( self, protocol: Protocol, @@ -84,6 +86,7 @@ def __init__( num_workers: int = None, pin_memory: bool = False, augmentation: BaseWaveformTransform = None, + metric: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None, ): super().__init__( @@ -94,6 +97,7 @@ def __init__( num_workers=num_workers, pin_memory=pin_memory, augmentation=augmentation, + metric=metric, ) self.balance = balance diff --git a/tutorials/add_your_own_task.ipynb b/tutorials/add_your_own_task.ipynb index 8022cbdc4..b2053f459 100644 --- a/tutorials/add_your_own_task.ipynb +++ b/tutorials/add_your_own_task.ipynb @@ -153,8 +153,6 @@ "class SoundEventDetection(Task):\n", " \"\"\"Sound event detection\"\"\"\n", "\n", - " ACRONYM = \"sed\"\n", - "\n", " def __init__(\n", " self,\n", " protocol: Protocol,\n", From 68bddc8d0f82c5cd922f13aaae61664b26160a65 Mon Sep 17 00:00:00 2001 From: FrenchKrab Date: Tue, 15 Mar 2022 10:24:12 +0100 Subject: [PATCH 12/20] add oracle pseudolabel filtering postprocessors --- .../segmentation/unsupervised_segmentation.py | 85 ++++++++++++++++++- 1 file changed, 83 insertions(+), 2 deletions(-) diff --git a/pyannote/audio/tasks/segmentation/unsupervised_segmentation.py b/pyannote/audio/tasks/segmentation/unsupervised_segmentation.py index aba325e93..2e6fcfbf3 100644 --- a/pyannote/audio/tasks/segmentation/unsupervised_segmentation.py +++ b/pyannote/audio/tasks/segmentation/unsupervised_segmentation.py @@ -17,6 +17,16 @@ from pyannote.audio.core.model import Model from pyannote.audio.core.task import Task, ValDataset from pyannote.audio.tasks import Segmentation +from pyannote.audio.torchmetrics.functional.audio.diarization_error_rate import ( + diarization_error_rate, +) + + +class PseudoLabelPostprocess: + def process( + self, pseudo_y: torch.Tensor, y: torch.Tensor, x: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + raise NotImplementedError() class UnsupervisedSegmentation(Segmentation, Task): @@ -27,6 +37,9 @@ def __init__( fake_in_train=True, # generate fake truth in training mode fake_in_val=True, # generate fake truth in val mode augmentation_model: BaseWaveformTransform = None, + pl_postprocess: Union[ + PseudoLabelPostprocess, List[PseudoLabelPostprocess] + ] = None, # supervised params duration: float = 2.0, max_num_speakers: int = None, @@ -66,6 +79,11 @@ def __init__( self.fake_in_val = fake_in_val self.augmentation_model = augmentation_model + if isinstance(pl_postprocess, PseudoLabelPostprocess): + self.pl_postprocess = [pl_postprocess] + else: + self.pl_postprocess = pl_postprocess + self.teacher.eval() def get_model_output(self, model: Model, waveforms: torch.Tensor): @@ -88,12 +106,23 @@ def collate_fn(self, batch): # Generate annotations y with teacher if they are not provided if self.use_pseudolabels("train"): - teacher_input = collated_batch["X"] + x = collated_batch["X"] + teacher_input = x if self.augmentation_model is not None: teacher_input = self.augmentation_model( collated_batch["X"], sample_rate=self.model.hparams.sample_rate ) - collated_batch["y"] = self.get_model_output(self.teacher, teacher_input) + pseudo_y = self.get_model_output(self.teacher, teacher_input) + + y = None + if "y" in collated_batch: + y = collated_batch["y"] + if self.pl_postprocess is not None: + for pp in self.pl_postprocess: + pseudo_y, x = pp.process(pseudo_y, y, x) + + collated_batch["y"] = pseudo_y + collated_batch["X"] = x if self.augmentation is not None: collated_batch["X"] = self.augmentation( @@ -161,6 +190,58 @@ def val_dataloader(self) 
From 68bddc8d0f82c5cd922f13aaae61664b26160a65 Mon Sep 17 00:00:00 2001
From: FrenchKrab
Date: Tue, 15 Mar 2022 10:24:12 +0100
Subject: [PATCH 12/20] add oracle pseudolabel filtering postprocessors

---
 .../segmentation/unsupervised_segmentation.py | 85 ++++++++++++++++++-
 1 file changed, 83 insertions(+), 2 deletions(-)

diff --git a/pyannote/audio/tasks/segmentation/unsupervised_segmentation.py b/pyannote/audio/tasks/segmentation/unsupervised_segmentation.py
index aba325e93..2e6fcfbf3 100644
--- a/pyannote/audio/tasks/segmentation/unsupervised_segmentation.py
+++ b/pyannote/audio/tasks/segmentation/unsupervised_segmentation.py
@@ -17,6 +17,16 @@
 from pyannote.audio.core.model import Model
 from pyannote.audio.core.task import Task, ValDataset
 from pyannote.audio.tasks import Segmentation
+from pyannote.audio.torchmetrics.functional.audio.diarization_error_rate import (
+    diarization_error_rate,
+)
+
+
+class PseudoLabelPostprocess:
+    def process(
+        self, pseudo_y: torch.Tensor, y: torch.Tensor, x: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        raise NotImplementedError()
 
 
 class UnsupervisedSegmentation(Segmentation, Task):
@@ -27,6 +37,9 @@ def __init__(
         fake_in_train=True,  # generate fake truth in training mode
         fake_in_val=True,  # generate fake truth in val mode
         augmentation_model: BaseWaveformTransform = None,
+        pl_postprocess: Union[
+            PseudoLabelPostprocess, List[PseudoLabelPostprocess]
+        ] = None,
         # supervised params
         duration: float = 2.0,
         max_num_speakers: int = None,
@@ -66,6 +79,11 @@ def __init__(
         self.fake_in_val = fake_in_val
         self.augmentation_model = augmentation_model
 
+        if isinstance(pl_postprocess, PseudoLabelPostprocess):
+            self.pl_postprocess = [pl_postprocess]
+        else:
+            self.pl_postprocess = pl_postprocess
+
         self.teacher.eval()
 
     def get_model_output(self, model: Model, waveforms: torch.Tensor):
@@ -88,12 +106,23 @@ def collate_fn(self, batch):
         collated_batch = default_collate(batch)
 
         # Generate annotations y with teacher if they are not provided
         if self.use_pseudolabels("train"):
-            teacher_input = collated_batch["X"]
+            x = collated_batch["X"]
+            teacher_input = x
             if self.augmentation_model is not None:
                 teacher_input = self.augmentation_model(
                     collated_batch["X"], sample_rate=self.model.hparams.sample_rate
                 )
-            collated_batch["y"] = self.get_model_output(self.teacher, teacher_input)
+            pseudo_y = self.get_model_output(self.teacher, teacher_input)
+
+            y = None
+            if "y" in collated_batch:
+                y = collated_batch["y"]
+            if self.pl_postprocess is not None:
+                for pp in self.pl_postprocess:
+                    pseudo_y, x = pp.process(pseudo_y, y, x)
+
+            collated_batch["y"] = pseudo_y
+            collated_batch["X"] = x
 
         if self.augmentation is not None:
             collated_batch["X"] = self.augmentation(
@@ -161,6 +190,58 @@ def val_dataloader(self) -> Optional[DataLoader]:
             return None
 
 
+def _compute_ders(
+    pseudo_y: torch.Tensor, y: torch.Tensor, x: torch.Tensor
+) -> torch.Tensor:
+    batch_size = pseudo_y.shape[0]
+    ders = torch.zeros(batch_size)
+
+    tm_pseudo_y = pseudo_y.swapaxes(1, 2)
+    tm_true_y = y.swapaxes(1, 2)
+    for i in range(batch_size):
+        ders[i] = diarization_error_rate(
+            tm_pseudo_y[i][None, :, :], tm_true_y[i][None, :, :]
+        )
+
+    return ders
+
+
+class DiscardPercentDer(PseudoLabelPostprocess):
+    def __init__(self, ratio_to_discard: float = 0.1) -> None:
+        self.ratio_to_discard = ratio_to_discard
+
+    def process(
+        self, pseudo_y: torch.Tensor, y: torch.Tensor, x: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        batch_size = pseudo_y.shape[0]
+        ders = _compute_ders(pseudo_y, y, x)
+        sorted_ders, sorted_indices = torch.sort(ders)
+
+        to_discard_count = min(
+            batch_size, max(1, round(batch_size * self.ratio_to_discard))
+        )
+        pseudo_y = pseudo_y[sorted_indices][:-to_discard_count, :, :]
+        x = x[sorted_indices][:-to_discard_count, :, :]
+
+        return pseudo_y, x
+
+
+class DiscardThresholdDer(PseudoLabelPostprocess):
+    def __init__(self, threshold: float = 0.5) -> None:
+        self.threshold = threshold
+
+    def process(
+        self, pseudo_y: torch.Tensor, y: torch.Tensor, x: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        ders = _compute_ders(pseudo_y, y, x)
+
+        keep = torch.where(ders < self.threshold)
+        pseudo_y = pseudo_y[keep]
+        x = x[keep]
+
+        return pseudo_y, x
+
+
 class TeacherUpdate(Callback):
     def __init__(
         self,
From 3e7f2ac85b5016f9be5ddf8dbf8c41873e3d64c7 Mon Sep 17 00:00:00 2001
From: FrenchKrab
Date: Wed, 16 Mar 2022 09:28:19 +0100
Subject: [PATCH 13/20] fix: fix SegmentationTaskMixin.prepare_chunk typing
 (#901)

---
 pyannote/audio/tasks/segmentation/mixins.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/pyannote/audio/tasks/segmentation/mixins.py b/pyannote/audio/tasks/segmentation/mixins.py
index 98a8fc0ae..bba608fd4 100644
--- a/pyannote/audio/tasks/segmentation/mixins.py
+++ b/pyannote/audio/tasks/segmentation/mixins.py
@@ -208,9 +208,10 @@ def prepare_chunk(
         duration: float = None,
         stage: Literal["train", "val"] = "train",
         use_annotations: bool = True,
-    ) -> Tuple[np.ndarray, np.ndarray, List[Text]]:
+    ) -> dict:
         """Extract audio chunk and corresponding frame-wise labels
+
         Parameters
         ----------
         file : AudioFile
From 51d8b7755cbbca5f10040612838af93f12fc8667 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Herv=C3=A9=20BREDIN?=
Date: Tue, 15 Mar 2022 14:34:39 +0100
Subject: [PATCH 14/20] chore: use same logging conventions for training loss

---
 pyannote/audio/core/task.py                   | 21 ++++++++++++-------
 pyannote/audio/tasks/embedding/arcface.py     |  2 --
 pyannote/audio/tasks/embedding/mixins.py      |  2 +-
 pyannote/audio/tasks/segmentation/mixins.py   |  2 +-
 .../overlapped_speech_detection.py            |  2 --
 .../audio/tasks/segmentation/segmentation.py  |  9 ++++----
 .../segmentation/speaker_change_detection.py  |  2 --
 .../tasks/segmentation/speaker_tracking.py    |  2 --
 .../segmentation/voice_activity_detection.py  |  2 --
 tutorials/add_your_own_task.ipynb             |  2 --
 10 files changed, 19 insertions(+), 27 deletions(-)

diff --git a/pyannote/audio/core/task.py b/pyannote/audio/core/task.py
index 9c40efa7a..db6572690 100644
--- a/pyannote/audio/core/task.py
+++ b/pyannote/audio/core/task.py
@@ -272,6 +272,17 @@ def train_dataloader(self) -> DataLoader:
             collate_fn=self.collate_fn,
         )
 
+    @cached_property
+    def logging_prefix(self):
+
+        prefix = f"{self.__class__.__name__}-"
+        if hasattr(self.protocol, "name"):
+            # 
"." has a special meaning for pytorch-lightning checkpointing + # so we replace any encountered "." by "_" in protocol names + prefix += f"{self.protocol.name.replace('.', '_')}-" + + return prefix + def default_loss( self, specifications: Specifications, target, prediction, weight=None ) -> torch.Tensor: @@ -361,7 +372,7 @@ def common_step(self, batch, batch_idx: int, stage: Literal["train", "val"]): # compute loss loss = self.default_loss(self.specifications, y, y_pred, weight=weight) self.model.log( - f"{self.ACRONYM}@{stage}_loss", + f"{self.logging_prefix}{stage}-loss", loss, on_step=False, on_epoch=True, @@ -415,13 +426,7 @@ def metric(self) -> MetricCollection: if self._metric is None: self._metric = self.default_metric() - prefix = f"{self.__class__.__name__}-" - if hasattr(self.protocol, "name"): - # "." has a special meaning for pytorch-lightning checkpointing - # so we replace any encountered "." by "_" in protocol names - prefix += f"{self.protocol.name.replace('.', '_')}-" - - return MetricCollection(self._metric, prefix=prefix) + return MetricCollection(self._metric, prefix=self.logging_prefix) @property def val_monitor(self): diff --git a/pyannote/audio/tasks/embedding/arcface.py b/pyannote/audio/tasks/embedding/arcface.py index 74e09e052..bb2cb1f6c 100644 --- a/pyannote/audio/tasks/embedding/arcface.py +++ b/pyannote/audio/tasks/embedding/arcface.py @@ -75,8 +75,6 @@ class SupervisedRepresentationLearningWithArcFace( Defaults to AUROC (area under the ROC curve). """ - ACRONYM = "arcface" - #  TODO: add a ".metric" property that tells how speaker embedding trained with this approach #  should be compared. could be a string like "cosine" or "euclidean" or a pdist/cdist-like #  callable. this ".metric" property should be propagated all the way to Inference (via the model). diff --git a/pyannote/audio/tasks/embedding/mixins.py b/pyannote/audio/tasks/embedding/mixins.py index 170b1af65..b5d668c93 100644 --- a/pyannote/audio/tasks/embedding/mixins.py +++ b/pyannote/audio/tasks/embedding/mixins.py @@ -221,7 +221,7 @@ def training_step(self, batch, batch_idx: int): loss = self.model.loss_func(self.model(X), y) self.model.log( - f"{self.ACRONYM}@train_loss", + f"{self.logging_prefix}train-loss", loss, on_step=False, on_epoch=True, diff --git a/pyannote/audio/tasks/segmentation/mixins.py b/pyannote/audio/tasks/segmentation/mixins.py index bba608fd4..53600a0ff 100644 --- a/pyannote/audio/tasks/segmentation/mixins.py +++ b/pyannote/audio/tasks/segmentation/mixins.py @@ -622,7 +622,7 @@ def validation_step(self, batch, batch_idx: int): plt.tight_layout() self.model.logger.experiment.add_figure( - f"{self.ACRONYM}@val_samples", fig, self.model.current_epoch + f"{self.logging_prefix}val-samples", fig, self.model.current_epoch ) plt.close(fig) diff --git a/pyannote/audio/tasks/segmentation/overlapped_speech_detection.py b/pyannote/audio/tasks/segmentation/overlapped_speech_detection.py index 4f01b7480..15b7024f4 100644 --- a/pyannote/audio/tasks/segmentation/overlapped_speech_detection.py +++ b/pyannote/audio/tasks/segmentation/overlapped_speech_detection.py @@ -89,8 +89,6 @@ class OverlappedSpeechDetection(SegmentationTaskMixin, Task): Defaults to AUROC (area under the ROC curve). 
""" - ACRONYM = "osd" - OVERLAP_DEFAULTS = {"probability": 0.5, "snr_min": 0.0, "snr_max": 10.0} def __init__( diff --git a/pyannote/audio/tasks/segmentation/segmentation.py b/pyannote/audio/tasks/segmentation/segmentation.py index a0fcf0c17..fc23805e6 100644 --- a/pyannote/audio/tasks/segmentation/segmentation.py +++ b/pyannote/audio/tasks/segmentation/segmentation.py @@ -102,8 +102,6 @@ class Segmentation(SegmentationTaskMixin, Task): Proc. Interspeech 2021 """ - ACRONYM = "seg" - OVERLAP_DEFAULTS = {"probability": 0.5, "snr_min": 0.0, "snr_max": 10.0} def __init__( @@ -355,7 +353,7 @@ def training_step(self, batch, batch_idx: int): seg_loss = self.segmentation_loss(permutated_prediction, target, weight=weight) self.model.log( - f"{self.ACRONYM}@train_seg_loss", + f"{self.logging_prefix}train-loss-seg", seg_loss, on_step=False, on_epoch=True, @@ -372,7 +370,7 @@ def training_step(self, batch, batch_idx: int): ) self.model.log( - f"{self.ACRONYM}@train_vad_loss", + f"{self.logging_prefix}train-loss-vad", vad_loss, on_step=False, on_epoch=True, @@ -383,13 +381,14 @@ def training_step(self, batch, batch_idx: int): loss = seg_loss + vad_loss self.model.log( - f"{self.ACRONYM}@train_loss", + f"{self.logging_prefix}train-loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True, ) + return {"loss": loss} def default_metric( diff --git a/pyannote/audio/tasks/segmentation/speaker_change_detection.py b/pyannote/audio/tasks/segmentation/speaker_change_detection.py index 8e9fdd572..0329d361e 100644 --- a/pyannote/audio/tasks/segmentation/speaker_change_detection.py +++ b/pyannote/audio/tasks/segmentation/speaker_change_detection.py @@ -82,8 +82,6 @@ class SpeakerChangeDetection(SegmentationTaskMixin, Task): Defaults to AUROC (area under the ROC curve). """ - ACRONYM = "scd" - def __init__( self, protocol: Protocol, diff --git a/pyannote/audio/tasks/segmentation/speaker_tracking.py b/pyannote/audio/tasks/segmentation/speaker_tracking.py index c572fcb26..9daf48627 100644 --- a/pyannote/audio/tasks/segmentation/speaker_tracking.py +++ b/pyannote/audio/tasks/segmentation/speaker_tracking.py @@ -76,8 +76,6 @@ class SpeakerTracking(SegmentationTaskMixin, Task): Defaults to AUROC (area under the ROC curve). """ - ACRONYM = "spk" - def __init__( self, protocol: Protocol, diff --git a/pyannote/audio/tasks/segmentation/voice_activity_detection.py b/pyannote/audio/tasks/segmentation/voice_activity_detection.py index d2c254ee1..123451968 100644 --- a/pyannote/audio/tasks/segmentation/voice_activity_detection.py +++ b/pyannote/audio/tasks/segmentation/voice_activity_detection.py @@ -75,8 +75,6 @@ class VoiceActivityDetection(SegmentationTaskMixin, Task): Defaults to AUROC (area under the ROC curve). 
""" - ACRONYM = "vad" - def __init__( self, protocol: Protocol, diff --git a/tutorials/add_your_own_task.ipynb b/tutorials/add_your_own_task.ipynb index 8022cbdc4..b2053f459 100644 --- a/tutorials/add_your_own_task.ipynb +++ b/tutorials/add_your_own_task.ipynb @@ -153,8 +153,6 @@ "class SoundEventDetection(Task):\n", " \"\"\"Sound event detection\"\"\"\n", "\n", - " ACRONYM = \"sed\"\n", - "\n", " def __init__(\n", " self,\n", " protocol: Protocol,\n", From 75e97b9e36988ccbebb9061135cca2bddd3c154c Mon Sep 17 00:00:00 2001 From: FrenchKrab Date: Wed, 16 Mar 2022 09:30:54 +0100 Subject: [PATCH 15/20] * feat: add support for custom validation metrics (#913) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * BREAKING: use task and protocol full names in logs and checkpoint paths * BREAKING: use diarization error rate for validation of segmentation Co-authored-by: Hervé BREDIN --- pyannote/audio/core/task.py | 7 ++++--- pyannote/audio/tasks/embedding/mixins.py | 2 +- pyannote/audio/tasks/segmentation/mixins.py | 2 +- pyannote/audio/tasks/segmentation/segmentation.py | 6 +++--- 4 files changed, 9 insertions(+), 8 deletions(-) diff --git a/pyannote/audio/core/task.py b/pyannote/audio/core/task.py index db6572690..0233a87d8 100644 --- a/pyannote/audio/core/task.py +++ b/pyannote/audio/core/task.py @@ -278,8 +278,9 @@ def logging_prefix(self): prefix = f"{self.__class__.__name__}-" if hasattr(self.protocol, "name"): # "." has a special meaning for pytorch-lightning checkpointing - # so we replace any encountered "." by "_" in protocol names - prefix += f"{self.protocol.name.replace('.', '_')}-" + # so we remove dots from protocol names + name_without_dots = "".join(self.protocol.name.split(".")) + prefix += f"{name_without_dots}-" return prefix @@ -372,7 +373,7 @@ def common_step(self, batch, batch_idx: int, stage: Literal["train", "val"]): # compute loss loss = self.default_loss(self.specifications, y, y_pred, weight=weight) self.model.log( - f"{self.logging_prefix}{stage}-loss", + f"{self.logging_prefix}{stage.capitalize()}Loss", loss, on_step=False, on_epoch=True, diff --git a/pyannote/audio/tasks/embedding/mixins.py b/pyannote/audio/tasks/embedding/mixins.py index b5d668c93..13d65edb1 100644 --- a/pyannote/audio/tasks/embedding/mixins.py +++ b/pyannote/audio/tasks/embedding/mixins.py @@ -221,7 +221,7 @@ def training_step(self, batch, batch_idx: int): loss = self.model.loss_func(self.model(X), y) self.model.log( - f"{self.logging_prefix}train-loss", + f"{self.logging_prefix}TrainLoss", loss, on_step=False, on_epoch=True, diff --git a/pyannote/audio/tasks/segmentation/mixins.py b/pyannote/audio/tasks/segmentation/mixins.py index 53600a0ff..69c1ed182 100644 --- a/pyannote/audio/tasks/segmentation/mixins.py +++ b/pyannote/audio/tasks/segmentation/mixins.py @@ -622,7 +622,7 @@ def validation_step(self, batch, batch_idx: int): plt.tight_layout() self.model.logger.experiment.add_figure( - f"{self.logging_prefix}val-samples", fig, self.model.current_epoch + f"{self.logging_prefix}ValSamples", fig, self.model.current_epoch ) plt.close(fig) diff --git a/pyannote/audio/tasks/segmentation/segmentation.py b/pyannote/audio/tasks/segmentation/segmentation.py index fc23805e6..ed540be5b 100644 --- a/pyannote/audio/tasks/segmentation/segmentation.py +++ b/pyannote/audio/tasks/segmentation/segmentation.py @@ -353,7 +353,7 @@ def training_step(self, batch, batch_idx: int): seg_loss = self.segmentation_loss(permutated_prediction, target, weight=weight) self.model.log( - 
f"{self.logging_prefix}train-loss-seg", + f"{self.logging_prefix}TrainSegLoss", seg_loss, on_step=False, on_epoch=True, @@ -370,7 +370,7 @@ def training_step(self, batch, batch_idx: int): ) self.model.log( - f"{self.logging_prefix}train-loss-vad", + f"{self.logging_prefix}TrainVADLoss", vad_loss, on_step=False, on_epoch=True, @@ -381,7 +381,7 @@ def training_step(self, batch, batch_idx: int): loss = seg_loss + vad_loss self.model.log( - f"{self.logging_prefix}train-loss", + f"{self.logging_prefix}TrainLoss", loss, on_step=False, on_epoch=True, From 49a45170c0ce48728cb3eb861572e6021fe1f232 Mon Sep 17 00:00:00 2001 From: FrenchKrab Date: Wed, 16 Mar 2022 13:34:57 +0100 Subject: [PATCH 16/20] cleaner way to check for pseudolabel use --- .../segmentation/unsupervised_segmentation.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/pyannote/audio/tasks/segmentation/unsupervised_segmentation.py b/pyannote/audio/tasks/segmentation/unsupervised_segmentation.py index c81e040e6..95ed00ec4 100644 --- a/pyannote/audio/tasks/segmentation/unsupervised_segmentation.py +++ b/pyannote/audio/tasks/segmentation/unsupervised_segmentation.py @@ -75,11 +75,16 @@ def get_model_output(self, model: Model, waveforms: torch.Tensor): result = torch.round(result).type(torch.int8) return result + def use_pseudolabels(self, stage: Literal["train", "val"]): + return (stage == "train" and self.fake_in_train) or ( + stage == "val" and self.fake_in_val + ) + def collate_fn(self, batch): collated_batch = default_collate(batch) # Generate annotations y with teacher if they are not provided - if "y" not in collated_batch: + if self.use_pseudolabels("train"): teacher_input = collated_batch["X"] if self.augmentation_model is not None: teacher_input = self.augmentation_model( @@ -97,7 +102,7 @@ def collate_fn_val(self, batch): collated_batch = default_collate(batch) # Generate annotations y with teacher if they are not provided - if "y" not in collated_batch: + if self.use_pseudolabels("val"): teacher_input = collated_batch["X"] collated_batch["y"] = self.get_model_output(self.teacher, teacher_input) @@ -134,11 +139,8 @@ def prepare_chunk( ... 
""" - use_annotations = (stage == "train" and not self.fake_in_train) or ( - stage == "val" and not self.fake_in_val - ) sample = super().prepare_chunk( - file, chunk, duration=duration, stage=stage, use_annotations=use_annotations + file, chunk, duration=duration, stage=stage, use_annotations=True ) return sample From 4c3fea13ef14b9785898bcc551eacc0306e20a39 Mon Sep 17 00:00:00 2001 From: FrenchKrab Date: Tue, 15 Mar 2022 10:23:03 +0100 Subject: [PATCH 17/20] update UnsupervisedSegmention to take a metric arg --- .../audio/tasks/segmentation/unsupervised_segmentation.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pyannote/audio/tasks/segmentation/unsupervised_segmentation.py b/pyannote/audio/tasks/segmentation/unsupervised_segmentation.py index 95ed00ec4..84ae2661d 100644 --- a/pyannote/audio/tasks/segmentation/unsupervised_segmentation.py +++ b/pyannote/audio/tasks/segmentation/unsupervised_segmentation.py @@ -1,4 +1,4 @@ -from typing import Any, List, Optional, OrderedDict, Text, Tuple, Union +from typing import Any, Dict, List, Optional, OrderedDict, Sequence, Text, Tuple, Union import numpy as np import pytorch_lightning as pl @@ -10,6 +10,7 @@ from torch.utils.data import DataLoader from torch.utils.data._utils.collate import default_collate from torch_audiomentations.core.transforms_interface import BaseWaveformTransform +from torchmetrics import Metric from typing_extensions import Literal from pyannote.audio.core.io import AudioFile @@ -39,6 +40,7 @@ def __init__( augmentation: BaseWaveformTransform = None, loss: Literal["bce", "mse"] = "bce", vad_loss: Literal["bce", "mse"] = None, + metric: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None, ): super().__init__( # Mixin params @@ -56,6 +58,7 @@ def __init__( weight=weight, loss=loss, vad_loss=vad_loss, + metric=metric, ) self.teacher = model From 82ec5ab22e60e004f2e712ad5e7e33a89ca8a998 Mon Sep 17 00:00:00 2001 From: FrenchKrab Date: Wed, 16 Mar 2022 13:51:32 +0100 Subject: [PATCH 18/20] add pseudolabel filtering capabilities to UnsupervisedSegmentation --- .../segmentation/unsupervised_segmentation.py | 30 +++++++++++++++++-- 1 file changed, 28 insertions(+), 2 deletions(-) diff --git a/pyannote/audio/tasks/segmentation/unsupervised_segmentation.py b/pyannote/audio/tasks/segmentation/unsupervised_segmentation.py index 84ae2661d..d72896c9e 100644 --- a/pyannote/audio/tasks/segmentation/unsupervised_segmentation.py +++ b/pyannote/audio/tasks/segmentation/unsupervised_segmentation.py @@ -19,6 +19,13 @@ from pyannote.audio.tasks import Segmentation +class PseudoLabelPostprocess: + def process( + self, pseudo_y: torch.Tensor, y: torch.Tensor, x: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + raise NotImplementedError() + + class UnsupervisedSegmentation(Segmentation, Task): def __init__( self, @@ -27,6 +34,9 @@ def __init__( fake_in_train=True, # generate fake truth in training mode fake_in_val=True, # generate fake truth in val mode augmentation_model: BaseWaveformTransform = None, + pl_postprocess: Union[ + PseudoLabelPostprocess, List[PseudoLabelPostprocess] + ] = None, # supervised params duration: float = 2.0, max_num_speakers: int = None, @@ -66,6 +76,11 @@ def __init__( self.fake_in_val = fake_in_val self.augmentation_model = augmentation_model + if isinstance(pl_postprocess, PseudoLabelPostprocess): + self.pl_postprocess = [pl_postprocess] + else: + self.pl_postprocess = pl_postprocess + self.teacher.eval() def get_model_output(self, model: Model, waveforms: torch.Tensor): @@ -88,12 
From 82ec5ab22e60e004f2e712ad5e7e33a89ca8a998 Mon Sep 17 00:00:00 2001
From: FrenchKrab
Date: Wed, 16 Mar 2022 13:51:32 +0100
Subject: [PATCH 18/20] add pseudolabel filtering capabilities to
 UnsupervisedSegmentation

---
 .../segmentation/unsupervised_segmentation.py | 30 +++++++++++++++++--
 1 file changed, 28 insertions(+), 2 deletions(-)

diff --git a/pyannote/audio/tasks/segmentation/unsupervised_segmentation.py b/pyannote/audio/tasks/segmentation/unsupervised_segmentation.py
index 84ae2661d..d72896c9e 100644
--- a/pyannote/audio/tasks/segmentation/unsupervised_segmentation.py
+++ b/pyannote/audio/tasks/segmentation/unsupervised_segmentation.py
@@ -19,6 +19,13 @@
 from pyannote.audio.tasks import Segmentation
 
 
+class PseudoLabelPostprocess:
+    def process(
+        self, pseudo_y: torch.Tensor, y: torch.Tensor, x: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        raise NotImplementedError()
+
+
 class UnsupervisedSegmentation(Segmentation, Task):
     def __init__(
         self,
@@ -27,6 +34,9 @@ def __init__(
         fake_in_train=True,  # generate fake truth in training mode
         fake_in_val=True,  # generate fake truth in val mode
         augmentation_model: BaseWaveformTransform = None,
+        pl_postprocess: Union[
+            PseudoLabelPostprocess, List[PseudoLabelPostprocess]
+        ] = None,
         # supervised params
         duration: float = 2.0,
         max_num_speakers: int = None,
@@ -66,6 +76,11 @@ def __init__(
         self.fake_in_val = fake_in_val
         self.augmentation_model = augmentation_model
 
+        if isinstance(pl_postprocess, PseudoLabelPostprocess):
+            self.pl_postprocess = [pl_postprocess]
+        else:
+            self.pl_postprocess = pl_postprocess
+
         self.teacher.eval()
 
     def get_model_output(self, model: Model, waveforms: torch.Tensor):
@@ -88,12 +103,23 @@ def collate_fn(self, batch):
 
         # Generate annotations y with teacher if they are not provided
         if self.use_pseudolabels("train"):
-            teacher_input = collated_batch["X"]
+            x = collated_batch["X"]
+            teacher_input = x
             if self.augmentation_model is not None:
                 teacher_input = self.augmentation_model(
                     collated_batch["X"], sample_rate=self.model.hparams.sample_rate
                 )
-            collated_batch["y"] = self.get_model_output(self.teacher, teacher_input)
+            pseudo_y = self.get_model_output(self.teacher, teacher_input)
+
+            y = None
+            if "y" in collated_batch:
+                y = collated_batch["y"]
+            if self.pl_postprocess is not None:
+                for pp in self.pl_postprocess:
+                    pseudo_y, x = pp.process(pseudo_y, y, x)
+
+            collated_batch["y"] = pseudo_y
+            collated_batch["X"] = x
 
         if self.augmentation is not None:
             collated_batch["X"] = self.augmentation(

From 4682676d4609175717c17f4cc6c27e53cfdb23bd Mon Sep 17 00:00:00 2001
From: FrenchKrab
Date: Wed, 16 Mar 2022 13:52:16 +0100
Subject: [PATCH 19/20] add oracle pseudolabel filter postprocessors

Discard the x% of pseudo-labeled chunks with the highest DER, and discard
chunks whose DER exceeds a threshold.
---
 .../segmentation/unsupervised_segmentation.py | 55 +++++++++++++++++++
 1 file changed, 55 insertions(+)

diff --git a/pyannote/audio/tasks/segmentation/unsupervised_segmentation.py b/pyannote/audio/tasks/segmentation/unsupervised_segmentation.py
index d72896c9e..7cbb4c669 100644
--- a/pyannote/audio/tasks/segmentation/unsupervised_segmentation.py
+++ b/pyannote/audio/tasks/segmentation/unsupervised_segmentation.py
@@ -17,6 +17,9 @@
 from pyannote.audio.core.model import Model
 from pyannote.audio.core.task import Task, ValDataset
 from pyannote.audio.tasks import Segmentation
+from pyannote.audio.torchmetrics.functional.audio.diarization_error_rate import (
+    diarization_error_rate,
+)
 
 
@@ -187,6 +190,58 @@ def val_dataloader(self) -> Optional[DataLoader]:
         return None
 
 
+def _compute_ders(
+    pseudo_y: torch.Tensor, y: torch.Tensor, x: torch.Tensor
+) -> Tuple[torch.Tensor]:
+    batch_size = pseudo_y.shape[0]
+    ders = torch.zeros(batch_size)
+
+    tm_pseudo_y = pseudo_y.swapaxes(1, 2)
+    tm_true_y = y.swapaxes(1, 2)
+    for i in range(batch_size):
+        ders[i] = diarization_error_rate(
+            tm_pseudo_y[i][None, :, :], tm_true_y[i][None, :, :]
+        )
+
+    return ders
+
+
+class DiscardPercentDer(PseudoLabelPostprocess):
+    def __init__(self, ratio_to_discard: float = 0.1) -> None:
+        self.ratio_to_discard = ratio_to_discard
+
+    def process(
+        self, pseudo_y: torch.Tensor, y: torch.Tensor, x: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        batch_size = pseudo_y.shape[0]
+        ders = _compute_ders(pseudo_y, y, x)
+        sorted_ders, sorted_indices = torch.sort(ders)
+
+        to_discard_count = min(
+            batch_size, max(1, round(batch_size * self.ratio_to_discard))
+        )
+        pseudo_y = pseudo_y[sorted_indices][:-to_discard_count, :, :]
+        x = x[sorted_indices][:-to_discard_count, :, :]
+
+        return pseudo_y, x
+
+
+class DiscardThresholdDer(PseudoLabelPostprocess):
+    def __init__(self, threshold: float = 0.5) -> None:
+        self.threshold = threshold
+
+    def process(
+        self, pseudo_y: torch.Tensor, y: torch.Tensor, x: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        ders = _compute_ders(pseudo_y, y, x)
+
+        filter = torch.where(ders < self.threshold)
+        pseudo_y = pseudo_y[filter]
+        x = x[filter]
+
+        return pseudo_y, x
+
+
 class TeacherUpdate(Callback):
     def __init__(
         self,
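Both filters are oracle postprocessors: they score each pseudo-labeled chunk against the reference labels with diarization_error_rate, then either drop a fixed fraction of the worst chunks or drop every chunk whose DER reaches a threshold. A sketch of plugging them into the task, under the same teacher/protocol assumptions as above:

# Usage sketch; the filters are applied in order inside collate_fn.
task = UnsupervisedSegmentation(
    teacher,
    protocol,
    pl_postprocess=[
        DiscardPercentDer(ratio_to_discard=0.2),  # drop the 20% of chunks with highest DER
        DiscardThresholdDer(threshold=0.5),       # then drop chunks with DER >= 0.5
    ],
)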
From 66a873f4fcbbed3a941d8b4f995da733b9eaed7d Mon Sep 17 00:00:00 2001
From: FrenchKrab
Date: Fri, 18 Mar 2022 09:59:29 +0100
Subject: [PATCH 20/20] fix devastating TeacherUpdate mess-up (exclusive to
 branches develop2/3)

---
 pyannote/audio/tasks/segmentation/unsupervised_segmentation.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/pyannote/audio/tasks/segmentation/unsupervised_segmentation.py b/pyannote/audio/tasks/segmentation/unsupervised_segmentation.py
index 2e6fcfbf3..8e6669f58 100644
--- a/pyannote/audio/tasks/segmentation/unsupervised_segmentation.py
+++ b/pyannote/audio/tasks/segmentation/unsupervised_segmentation.py
@@ -260,6 +260,7 @@ def __init__(
 
     def enqueue_teacher(self, teacher: OrderedDict[str, torch.Tensor]):
         if len(self.last_weights) >= self.average_of:
+            self.last_weights.pop(0)
         self.last_weights.append(teacher)
 
     def get_updated_weights(