diff --git a/pyannote/audio/cli/train.py b/pyannote/audio/cli/train.py index b8547e446..e7ab2225a 100644 --- a/pyannote/audio/cli/train.py +++ b/pyannote/audio/cli/train.py @@ -27,20 +27,19 @@ import hydra from hydra.utils import instantiate from omegaconf import DictConfig, OmegaConf + +# from pyannote.audio.core.callback import GraduallyUnfreeze +from pyannote.database import FileFinder, get_protocol from pytorch_lightning.callbacks import ( EarlyStopping, LearningRateMonitor, ModelCheckpoint, RichProgressBar, ) - from pytorch_lightning.loggers import TensorBoardLogger from pytorch_lightning.utilities.seed import seed_everything from torch_audiomentations.utils.config import from_dict as get_augmentation -# from pyannote.audio.core.callback import GraduallyUnfreeze -from pyannote.database import FileFinder, get_protocol - @hydra.main(config_path="train_config", config_name="config") def train(cfg: DictConfig) -> Optional[float]: @@ -67,15 +66,15 @@ def train(cfg: DictConfig) -> Optional[float]: # instantiate task and validation metric task = instantiate(cfg.task, protocol, augmentation=augmentation) - # validation metric to monitor (and its direction: min or max) - monitor, direction = task.val_monitor - # instantiate model fine_tuning = cfg.model["_target_"] == "pyannote.audio.cli.pretrained" model = instantiate(cfg.model) model.task = task model.setup(stage="fit") + # validation metric to monitor (and its direction: min or max) + monitor, direction = task.val_monitor + # number of batches in one epoch num_batches_per_epoch = model.task.train__len__() // model.task.batch_size @@ -90,6 +89,7 @@ def configure_optimizers(self): num_batches_per_epoch=num_batches_per_epoch, ) return {"optimizer": optimizer, "lr_scheduler": lr_scheduler} + model.configure_optimizers = MethodType(configure_optimizers, model) callbacks = [RichProgressBar(), LearningRateMonitor()] @@ -155,4 +155,3 @@ def configure_optimizers(self): if __name__ == "__main__": train() - diff --git a/pyannote/audio/core/model.py b/pyannote/audio/core/model.py index 4e9bfe092..c36b6927b 100644 --- a/pyannote/audio/core/model.py +++ b/pyannote/audio/core/model.py @@ -34,6 +34,7 @@ import torch.nn as nn import torch.optim from huggingface_hub import cached_download, hf_hub_url +from pyannote.core import SlidingWindow from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.model_summary import ModelSummary from semver import VersionInfo @@ -42,7 +43,6 @@ from pyannote.audio import __version__ from pyannote.audio.core.io import Audio from pyannote.audio.core.task import Problem, Resolution, Specifications, Task -from pyannote.core import SlidingWindow CACHE_DIR = os.getenv( "PYANNOTE_CACHE", @@ -385,7 +385,7 @@ def setup(self, stage=None): # setup custom loss function self.task.setup_loss_func() # setup custom validation metrics - validation_metric = self.task.setup_validation_metric() + validation_metric = self.task.metric if validation_metric is not None: self.validation_metric = validation_metric self.validation_metric.to(self.device) diff --git a/pyannote/audio/core/task.py b/pyannote/audio/core/task.py index 34903f781..0233a87d8 100644 --- a/pyannote/audio/core/task.py +++ b/pyannote/audio/core/task.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2020-2021 CNRS +# Copyright (c) 2020- CNRS # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -23,13 +23,18 @@ from __future__ 
import annotations +try: + from functools import cached_property +except ImportError: + from backports.cached_property import cached_property + import multiprocessing import sys import warnings from dataclasses import dataclass from enum import Enum from numbers import Number -from typing import Dict, List, Optional, Sequence, Text, Tuple, Type, Union +from typing import Dict, List, Optional, Sequence, Text, Tuple, Union import pytorch_lightning as pl import torch @@ -152,6 +157,9 @@ class Task(pl.LightningDataModule): augmentation : BaseWaveformTransform, optional torch_audiomentations waveform transform, used by dataloader during training. + metric : optional + Validation metric(s). Can be anything supported by torchmetrics.MetricCollection. + Defaults to value returned by `default_metric` method. Attributes ---------- @@ -169,7 +177,7 @@ def __init__( num_workers: int = None, pin_memory: bool = False, augmentation: BaseWaveformTransform = None, - metrics: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None, + metric: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None, ): super().__init__() @@ -205,20 +213,7 @@ def __init__( self.num_workers = num_workers self.pin_memory = pin_memory self.augmentation = augmentation - self.metrics = metrics - - @property - def default_validation_metric( - self, - ) -> Union[Metric, Sequence[Metric], Dict[str, Metric]]: - """Used to get validation metrics when none is provided by the user - - Returns - ------- - Union[Metric, Sequence[Metric], Dict[str, Metric]] - The default validation metric(s) of this task (something that can be plugged in a torchmetrics.MetricCollection). - """ - pass + self._metric = metric def prepare_data(self): """Use this to download and prepare data @@ -249,39 +244,6 @@ def setup(self, stage: Optional[str] = None): def setup_loss_func(self): pass - @property - def val_metric_prefix(self) -> str: - return f"{self.ACRONYM}@val_" - - def get_default_val_metric_name(self, metric: Union[Metric, Type]) -> str: - prefix = self.val_metric_prefix - mn = Task.get_metric_name(metric) - return f"{prefix}{mn}" - - @staticmethod - def get_metric_name(metric: Union[Metric, Type]) -> str: - if isinstance(metric, Metric): - return type(metric).__name__.lower() - elif isinstance(metric, Type): - return metric.__name__.lower() - else: - msg = "get_metric_name only accepts Metric and Type arguments." - raise ValueError(msg) - - def setup_validation_metric(self) -> Metric: - metricsarg = self.metrics - if self.metrics is None: - metricsarg = self.default_validation_metric - - # Convert metricargs to a list, if it is a single Metric - if isinstance(metricsarg, Metric): - metricsarg = [metricsarg] - # Convert metricsarg to a dict, now that it is a list - # If the metrics' names are not given, generate them automatically - if not isinstance(metricsarg, dict): - metricsarg = {Task.get_metric_name(m): m for m in metricsarg} - return MetricCollection(metricsarg, prefix=self.val_metric_prefix) - def train__iter__(self): # will become train_dataset.__iter__ method msg = f"Missing '{self.__class__.__name__}.train__iter__' method." @@ -310,6 +272,18 @@ def train_dataloader(self) -> DataLoader: collate_fn=self.collate_fn, ) + @cached_property + def logging_prefix(self): + + prefix = f"{self.__class__.__name__}-" + if hasattr(self.protocol, "name"): + # "." 
has a special meaning for pytorch-lightning checkpointing + # so we remove dots from protocol names + name_without_dots = "".join(self.protocol.name.split(".")) + prefix += f"{name_without_dots}-" + + return prefix + def default_loss( self, specifications: Specifications, target, prediction, weight=None ) -> torch.Tensor: @@ -399,7 +373,7 @@ def common_step(self, batch, batch_idx: int, stage: Literal["train", "val"]): # compute loss loss = self.default_loss(self.specifications, y, y_pred, weight=weight) self.model.log( - f"{self.ACRONYM}@{stage}_loss", + f"{self.logging_prefix}{stage.capitalize()}Loss", loss, on_step=False, on_epoch=True, @@ -443,6 +417,18 @@ def validation_step(self, batch, batch_idx: int): def validation_epoch_end(self, outputs): pass + def default_metric(self) -> Union[Metric, Sequence[Metric], Dict[str, Metric]]: + """Default validation metric""" + msg = f"Missing '{self.__class__.__name__}.default_metric' method." + raise NotImplementedError(msg) + + @cached_property + def metric(self) -> MetricCollection: + if self._metric is None: + self._metric = self.default_metric() + + return MetricCollection(self._metric, prefix=self.logging_prefix) + @property def val_monitor(self): """Quantity (and direction) to monitor @@ -461,7 +447,6 @@ def val_monitor(self): pytorch_lightning.callbacks.ModelCheckpoint pytorch_lightning.callbacks.EarlyStopping """ - if self.has_validation: - return f"{self.ACRONYM}@val_loss", "min" - else: - return None, "min" + + name, metric = next(iter(self.metric.items())) + return name, "max" if metric.higher_is_better else "min" diff --git a/pyannote/audio/tasks/embedding/arcface.py b/pyannote/audio/tasks/embedding/arcface.py index bcf9dea31..bb2cb1f6c 100644 --- a/pyannote/audio/tasks/embedding/arcface.py +++ b/pyannote/audio/tasks/embedding/arcface.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2020-2021 CNRS +# Copyright (c) 2020- CNRS # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -26,11 +26,11 @@ from typing import Dict, Sequence, Union import pytorch_metric_learning.losses +from pyannote.database import Protocol from torch_audiomentations.core.transforms_interface import BaseWaveformTransform from torchmetrics import Metric from pyannote.audio.core.task import Task -from pyannote.database import Protocol from .mixins import SupervisedRepresentationLearningTaskMixin @@ -70,10 +70,11 @@ class SupervisedRepresentationLearningWithArcFace( augmentation : BaseWaveformTransform, optional torch_audiomentations waveform transform, used by dataloader during training. + metric : optional + Validation metric(s). Can be anything supported by torchmetrics.MetricCollection. + Defaults to AUROC (area under the ROC curve). """ - ACRONYM = "arcface" - #  TODO: add a ".metric" property that tells how speaker embedding trained with this approach #  should be compared. could be a string like "cosine" or "euclidean" or a pdist/cdist-like #  callable. this ".metric" property should be propagated all the way to Inference (via the model). 
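Throughout the PR, the `metrics` constructor argument becomes `metric` and accepts anything a torchmetrics.MetricCollection can wrap: a single Metric, a sequence of Metrics, or a {name: Metric} dict. A minimal usage sketch (the protocol name below is a placeholder, and any of the patched tasks would work the same way):

from torchmetrics import AUROC

from pyannote.audio.tasks import VoiceActivityDetection
from pyannote.database import get_protocol

# "MyDatabase.SpeakerDiarization.MyProtocol" is a placeholder name:
# substitute any protocol registered with pyannote.database.
protocol = get_protocol("MyDatabase.SpeakerDiarization.MyProtocol")

# explicit metric; when `metric` is omitted, the cached `Task.metric` property
# falls back to `default_metric()` and wraps the result in a MetricCollection
# prefixed by `logging_prefix` (task class name + protocol name without dots).
vad = VoiceActivityDetection(
    protocol, duration=2.0, metric=AUROC(compute_on_step=False)
)
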
@@ -90,7 +91,7 @@ def __init__( num_workers: int = None, pin_memory: bool = False, augmentation: BaseWaveformTransform = None, - metrics: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None, + metric: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None, ): self.num_chunks_per_class = num_chunks_per_class @@ -107,7 +108,7 @@ def __init__( num_workers=num_workers, pin_memory=pin_memory, augmentation=augmentation, - metrics=metrics, + metric=metric, ) def setup_loss_func(self): diff --git a/pyannote/audio/tasks/embedding/mixins.py b/pyannote/audio/tasks/embedding/mixins.py index 2d4e3315c..13d65edb1 100644 --- a/pyannote/audio/tasks/embedding/mixins.py +++ b/pyannote/audio/tasks/embedding/mixins.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2020-2021 CNRS +# Copyright (c) 2020- CNRS # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -33,7 +33,7 @@ from torchmetrics import AUROC, Metric from tqdm import tqdm -from pyannote.audio.core.task import Problem, Resolution, Specifications, Task +from pyannote.audio.core.task import Problem, Resolution, Specifications from pyannote.audio.utils.random import create_rng_for_worker @@ -124,8 +124,7 @@ def setup(self, stage: Optional[str] = None): if isinstance(self.protocol, SpeakerVerificationProtocol): self._validation = list(self.protocol.development_trial()) - @property - def default_validation_metric( + def default_metric( self, ) -> Union[Metric, Sequence[Metric], Dict[str, Metric]]: return AUROC(compute_on_step=False) @@ -222,7 +221,7 @@ def training_step(self, batch, batch_idx: int): loss = self.model.loss_func(self.model(X), y) self.model.log( - f"{self.ACRONYM}@train_loss", + f"{self.logging_prefix}TrainLoss", loss, on_step=False, on_epoch=True, @@ -290,17 +289,3 @@ def validation_step(self, batch, batch_idx: int): elif isinstance(self.protocol, SpeakerDiarizationProtocol): pass - - @property - def val_monitor(self): - - if self.has_validation and self.metrics is None: - - if isinstance(self.protocol, SpeakerVerificationProtocol): - return Task.get_default_val_metric_name(AUROC), "max" - - elif isinstance(self.protocol, SpeakerDiarizationProtocol): - return None, "min" - - else: - return None, "min" diff --git a/pyannote/audio/tasks/segmentation/mixins.py b/pyannote/audio/tasks/segmentation/mixins.py index be1632acd..69c1ed182 100644 --- a/pyannote/audio/tasks/segmentation/mixins.py +++ b/pyannote/audio/tasks/segmentation/mixins.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2020-2021 CNRS +# Copyright (c) 2020- CNRS # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -121,14 +121,10 @@ def setup(self, stage: Optional[str] = None): random.shuffle(self._validation) - @property - def default_validation_metric( + def default_metric( self, ) -> Union[Metric, Sequence[Metric], Dict[str, Metric]]: - """Setup default validation metric - - Use macro-average of area under the ROC curve - """ + """Returns macro-average of the area under the ROC curve""" num_classes = len(self.specifications.classes) return AUROC(num_classes, pos_label=1, average="macro", compute_on_step=False) @@ -212,9 +208,10 @@ def prepare_chunk( duration: float = None, stage: Literal["train", "val"] = "train", use_annotations: bool = True, - ) -> Tuple[np.ndarray, np.ndarray, List[Text]]: + ) -> dict: """Extract audio chunk and corresponding 
frame-wise labels + Parameters ---------- file : AudioFile @@ -506,6 +503,7 @@ def validation_step(self, batch, batch_idx: int): # y_pred = (batch_size, num_frames, num_classes) # postprocess + # TODO: remove this because metrics should take care of postprocessing y_pred = self.validation_postprocess(y, y_pred) # - remove warm-up frames @@ -525,8 +523,8 @@ def validation_step(self, batch, batch_idx: int): # preds: shape (batch_size, num_frames, 1), type float # torchmetrics expects: - # target: shape (N,), type binary - # preds: shape (N,), type float + # target: shape (batch_size,), type binary + # preds: shape (batch_size,), type float self.model.validation_metric( preds.reshape(-1), @@ -538,8 +536,8 @@ def validation_step(self, batch, batch_idx: int): # preds: shape (batch_size, num_frames, num_classes), type float # torchmetrics expects - # target: shape (N, ), type binary - # preds: shape (N, ), type float + # target: shape (batch_size, num_classes, ...), type binary + # preds: shape (batch_size, num_classes, ...), type float self.model.validation_metric( torch.transpose(preds, 1, 2), @@ -547,13 +545,6 @@ def validation_step(self, batch, batch_idx: int): ) elif self.specifications.problem == Problem.MONO_LABEL_CLASSIFICATION: - # target: shape (batch_size, num_frames, num_classes), type binary - # preds: shape (batch_size, num_frames, num_classes), type float - - # torchmetrics expects: - # target: shape (N, ), type int - # preds: shape (N, num_classes), type float - # TODO: implement when pyannote.audio gets its first mono-label segmentation task raise NotImplementedError() @@ -631,14 +622,7 @@ def validation_step(self, batch, batch_idx: int): plt.tight_layout() self.model.logger.experiment.add_figure( - f"{self.ACRONYM}@val_samples", fig, self.model.current_epoch + f"{self.logging_prefix}ValSamples", fig, self.model.current_epoch ) plt.close(fig) - - @property - def val_monitor(self): - if self.has_validation and self.metrics is None: - return self.get_default_val_metric_name(AUROC), "max" - else: - return None, "min" diff --git a/pyannote/audio/tasks/segmentation/overlapped_speech_detection.py b/pyannote/audio/tasks/segmentation/overlapped_speech_detection.py index 07423d883..15b7024f4 100644 --- a/pyannote/audio/tasks/segmentation/overlapped_speech_detection.py +++ b/pyannote/audio/tasks/segmentation/overlapped_speech_detection.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2020-2021 CNRS +# Copyright (c) 2020- CNRS # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -24,12 +24,12 @@ from typing import Dict, Sequence, Text, Tuple, Union import numpy as np +from pyannote.database import Protocol from torch_audiomentations.core.transforms_interface import BaseWaveformTransform from torchmetrics import Metric from pyannote.audio.core.task import Problem, Resolution, Specifications, Task from pyannote.audio.tasks.segmentation.mixins import SegmentationTaskMixin -from pyannote.database import Protocol class OverlappedSpeechDetection(SegmentationTaskMixin, Task): @@ -84,10 +84,11 @@ class OverlappedSpeechDetection(SegmentationTaskMixin, Task): augmentation : BaseWaveformTransform, optional torch_audiomentations waveform transform, used by dataloader during training. + metric : optional + Validation metric(s). Can be anything supported by torchmetrics.MetricCollection. + Defaults to AUROC (area under the ROC curve). 
""" - ACRONYM = "osd" - OVERLAP_DEFAULTS = {"probability": 0.5, "snr_min": 0.0, "snr_max": 10.0} def __init__( @@ -102,7 +103,7 @@ def __init__( num_workers: int = None, pin_memory: bool = False, augmentation: BaseWaveformTransform = None, - metrics: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None, + metric: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None, ): super().__init__( @@ -113,7 +114,7 @@ def __init__( num_workers=num_workers, pin_memory=pin_memory, augmentation=augmentation, - metrics=metrics, + metric=metric, ) self.specifications = Specifications( diff --git a/pyannote/audio/tasks/segmentation/segmentation.py b/pyannote/audio/tasks/segmentation/segmentation.py index 460500324..ed540be5b 100644 --- a/pyannote/audio/tasks/segmentation/segmentation.py +++ b/pyannote/audio/tasks/segmentation/segmentation.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2020-2021 CNRS +# Copyright (c) 2020- CNRS # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -33,6 +33,12 @@ from pyannote.audio.core.task import Problem, Resolution, Specifications, Task from pyannote.audio.tasks.segmentation.mixins import SegmentationTaskMixin +from pyannote.audio.torchmetrics import ( + DiarizationErrorRate, + FalseAlarmRate, + MissedDetectionRate, + SpeakerConfusionRate, +) from pyannote.audio.utils.loss import binary_cross_entropy, mse_loss from pyannote.audio.utils.permutation import permutate @@ -81,8 +87,13 @@ class Segmentation(SegmentationTaskMixin, Task): augmentation : BaseWaveformTransform, optional torch_audiomentations waveform transform, used by dataloader during training. + loss : {"bce", "mse"}, optional + Permutation-invariant segmentation loss. Defaults to "bce". vad_loss : {"bce", "mse"}, optional Add voice activity detection loss. + metric : optional + Validation metric(s). Can be anything supported by torchmetrics.MetricCollection. + Defaults to AUROC (area under the ROC curve). Reference ---------- @@ -91,8 +102,6 @@ class Segmentation(SegmentationTaskMixin, Task): Proc. 
Interspeech 2021 """ - ACRONYM = "seg" - OVERLAP_DEFAULTS = {"probability": 0.5, "snr_min": 0.0, "snr_max": 10.0} def __init__( @@ -110,7 +119,7 @@ def __init__( augmentation: BaseWaveformTransform = None, loss: Literal["bce", "mse"] = "bce", vad_loss: Literal["bce", "mse"] = None, - metrics: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None, + metric: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None, ): super().__init__( @@ -121,7 +130,7 @@ def __init__( num_workers=num_workers, pin_memory=pin_memory, augmentation=augmentation, - metrics=metrics, + metric=metric, ) self.max_num_speakers = max_num_speakers @@ -344,7 +353,7 @@ def training_step(self, batch, batch_idx: int): seg_loss = self.segmentation_loss(permutated_prediction, target, weight=weight) self.model.log( - f"{self.ACRONYM}@train_seg_loss", + f"{self.logging_prefix}TrainSegLoss", seg_loss, on_step=False, on_epoch=True, @@ -361,7 +370,7 @@ def training_step(self, batch, batch_idx: int): ) self.model.log( - f"{self.ACRONYM}@train_vad_loss", + f"{self.logging_prefix}TrainVADLoss", vad_loss, on_step=False, on_epoch=True, @@ -372,18 +381,26 @@ def training_step(self, batch, batch_idx: int): loss = seg_loss + vad_loss self.model.log( - f"{self.ACRONYM}@train_loss", + f"{self.logging_prefix}TrainLoss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True, ) + return {"loss": loss} - def validation_postprocess(self, y, y_pred): - permutated_y_pred, _ = permutate(y, y_pred) - return permutated_y_pred + def default_metric( + self, + ) -> Union[Metric, Sequence[Metric], Dict[str, Metric]]: + """Returns diarization error rate and its components""" + return [ + DiarizationErrorRate(), + SpeakerConfusionRate(), + MissedDetectionRate(), + FalseAlarmRate(), + ] def main(protocol: str, subset: str = "test", model: str = "pyannote/segmentation"): diff --git a/pyannote/audio/tasks/segmentation/speaker_change_detection.py b/pyannote/audio/tasks/segmentation/speaker_change_detection.py index 476fc327b..0329d361e 100644 --- a/pyannote/audio/tasks/segmentation/speaker_change_detection.py +++ b/pyannote/audio/tasks/segmentation/speaker_change_detection.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2020-2021 CNRS +# Copyright (c) 2020- CNRS # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -24,12 +24,12 @@ import numpy as np import scipy.signal +from pyannote.database import Protocol from torch_audiomentations.core.transforms_interface import BaseWaveformTransform from torchmetrics import Metric from pyannote.audio.core.task import Problem, Resolution, Specifications, Task from pyannote.audio.tasks.segmentation.mixins import SegmentationTaskMixin -from pyannote.database import Protocol class SpeakerChangeDetection(SegmentationTaskMixin, Task): @@ -77,10 +77,11 @@ class SpeakerChangeDetection(SegmentationTaskMixin, Task): augmentation : BaseWaveformTransform, optional torch_audiomentations waveform transform, used by dataloader during training. + metric : optional + Validation metric(s). Can be anything supported by torchmetrics.MetricCollection. + Defaults to AUROC (area under the ROC curve). 
""" - ACRONYM = "scd" - def __init__( self, protocol: Protocol, @@ -93,7 +94,7 @@ def __init__( num_workers: int = None, pin_memory: bool = False, augmentation: BaseWaveformTransform = None, - metrics: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None, + metric: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None, ): super().__init__( @@ -104,7 +105,7 @@ def __init__( num_workers=num_workers, pin_memory=pin_memory, augmentation=augmentation, - metrics=metrics, + metric=metric, ) self.balance = balance diff --git a/pyannote/audio/tasks/segmentation/speaker_tracking.py b/pyannote/audio/tasks/segmentation/speaker_tracking.py index 5b06f2350..9daf48627 100644 --- a/pyannote/audio/tasks/segmentation/speaker_tracking.py +++ b/pyannote/audio/tasks/segmentation/speaker_tracking.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2020-2021 CNRS +# Copyright (c) 2020- CNRS # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -23,12 +23,12 @@ from typing import Dict, List, Optional, Sequence, Text, Tuple, Union import numpy as np +from pyannote.database import Protocol from torch_audiomentations.core.transforms_interface import BaseWaveformTransform from torchmetrics import Metric from pyannote.audio.core.task import Problem, Resolution, Specifications, Task from pyannote.audio.tasks.segmentation.mixins import SegmentationTaskMixin -from pyannote.database import Protocol class SpeakerTracking(SegmentationTaskMixin, Task): @@ -71,10 +71,11 @@ class SpeakerTracking(SegmentationTaskMixin, Task): augmentation : BaseWaveformTransform, optional torch_audiomentations waveform transform, used by dataloader during training. + metric : optional + Validation metric(s). Can be anything supported by torchmetrics.MetricCollection. + Defaults to AUROC (area under the ROC curve). 
""" - ACRONYM = "spk" - def __init__( self, protocol: Protocol, @@ -86,7 +87,7 @@ def __init__( num_workers: int = None, pin_memory: bool = False, augmentation: BaseWaveformTransform = None, - metrics: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None, + metric: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None, ): super().__init__( @@ -97,7 +98,7 @@ def __init__( num_workers=num_workers, pin_memory=pin_memory, augmentation=augmentation, - metrics=metrics, + metric=metric, ) self.balance = balance diff --git a/pyannote/audio/tasks/segmentation/unsupervised_segmentation.py b/pyannote/audio/tasks/segmentation/unsupervised_segmentation.py index 65805e548..8e6669f58 100644 --- a/pyannote/audio/tasks/segmentation/unsupervised_segmentation.py +++ b/pyannote/audio/tasks/segmentation/unsupervised_segmentation.py @@ -9,7 +9,6 @@ from pytorch_lightning.utilities.types import STEP_OUTPUT from torch.utils.data import DataLoader from torch.utils.data._utils.collate import default_collate -from torch.utils.tensorboard import SummaryWriter from torch_audiomentations.core.transforms_interface import BaseWaveformTransform from torchmetrics import Metric from typing_extensions import Literal @@ -18,7 +17,16 @@ from pyannote.audio.core.model import Model from pyannote.audio.core.task import Task, ValDataset from pyannote.audio.tasks import Segmentation -from pyannote.audio.torchmetrics import AUDER +from pyannote.audio.torchmetrics.functional.audio.diarization_error_rate import ( + diarization_error_rate, +) + + +class PseudoLabelPostprocess: + def process( + self, pseudo_y: torch.Tensor, y: torch.Tensor, x: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + raise NotImplementedError() class UnsupervisedSegmentation(Segmentation, Task): @@ -29,6 +37,9 @@ def __init__( fake_in_train=True, # generate fake truth in training mode fake_in_val=True, # generate fake truth in val mode augmentation_model: BaseWaveformTransform = None, + pl_postprocess: Union[ + PseudoLabelPostprocess, List[PseudoLabelPostprocess] + ] = None, # supervised params duration: float = 2.0, max_num_speakers: int = None, @@ -42,7 +53,7 @@ def __init__( augmentation: BaseWaveformTransform = None, loss: Literal["bce", "mse"] = "bce", vad_loss: Literal["bce", "mse"] = None, - metrics: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None, + metric: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None, ): super().__init__( # Mixin params @@ -60,7 +71,7 @@ def __init__( weight=weight, loss=loss, vad_loss=vad_loss, - metrics=metrics, + metric=metric, ) self.teacher = model @@ -68,6 +79,11 @@ def __init__( self.fake_in_val = fake_in_val self.augmentation_model = augmentation_model + if isinstance(pl_postprocess, PseudoLabelPostprocess): + self.pl_postprocess = [pl_postprocess] + else: + self.pl_postprocess = pl_postprocess + self.teacher.eval() def get_model_output(self, model: Model, waveforms: torch.Tensor): @@ -80,17 +96,33 @@ def get_model_output(self, model: Model, waveforms: torch.Tensor): result = torch.round(result).type(torch.int8) return result + def use_pseudolabels(self, stage: Literal["train", "val"]): + return (stage == "train" and self.fake_in_train) or ( + stage == "val" and self.fake_in_val + ) + def collate_fn(self, batch): collated_batch = default_collate(batch) # Generate annotations y with teacher if they are not provided - if "y" not in collated_batch: - teacher_input = collated_batch["X"] + if self.use_pseudolabels("train"): + x = collated_batch["X"] + teacher_input = x if 
self.augmentation_model is not None: teacher_input = self.augmentation_model( collated_batch["X"], sample_rate=self.model.hparams.sample_rate ) - collated_batch["y"] = self.get_model_output(self.teacher, teacher_input) + pseudo_y = self.get_model_output(self.teacher, teacher_input) + + y = None + if "y" in collated_batch: + y = collated_batch["y"] + if self.pl_postprocess is not None: + for pp in self.pl_postprocess: + pseudo_y, x = pp.process(pseudo_y, y, x) + + collated_batch["y"] = pseudo_y + collated_batch["X"] = x if self.augmentation is not None: collated_batch["X"] = self.augmentation( @@ -102,7 +134,7 @@ def collate_fn_val(self, batch): collated_batch = default_collate(batch) # Generate annotations y with teacher if they are not provided - if "y" not in collated_batch: + if self.use_pseudolabels("val"): teacher_input = collated_batch["X"] collated_batch["y"] = self.get_model_output(self.teacher, teacher_input) @@ -139,11 +171,8 @@ def prepare_chunk( ... """ - use_annotations = (stage == "train" and not self.fake_in_train) or ( - stage == "val" and not self.fake_in_val - ) sample = super().prepare_chunk( - file, chunk, duration=duration, stage=stage, use_annotations=use_annotations + file, chunk, duration=duration, stage=stage, use_annotations=True ) return sample @@ -160,44 +189,57 @@ def val_dataloader(self) -> Optional[DataLoader]: else: return None - def validation_epoch_end(self, outputs): - super().validation_epoch_end(outputs) - - # TODO : remove (temp debug) - for key, metric in self.model.validation_metric.items(): - # print(key) - if isinstance(metric, AUDER): - ders = ( - metric.false_alarm + metric.missed_detection + metric.confusion - ) / metric.total - # print(ders) - # print(metric.linspace) - SAMPLE = 100 - data = [] - bins = [-0.0001] - linspace = np.linspace( - metric.threshold_min, metric.threshold_max, metric.steps - ) - for i in range(metric.steps): - data += [linspace[i]] * int(ders[i] * SAMPLE) - if i > 0: - bins += [(linspace[i] + linspace[i - 1]) / 2.0] - bins += [1.0001] - values = torch.tensor(data).reshape(-1) - # print(bins) - # print(data) - experiment: SummaryWriter = self.model.logger.experiment - - experiment.add_histogram( - "der_curve", - global_step=self.model.current_epoch, - values=values, - bins=bins, - ) - # fig, ax = plt.subplots() # Create a figure containing a single axes. 
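Both shipped implementations below (DiscardPercentDer and DiscardThresholdDer) filter training chunks by comparing teacher pseudo-labels against the reference labels. The hook is straightforward to extend; a hypothetical sketch (not part of this patch) that keeps the half of each batch with the lowest DER, reusing the `_compute_ders` helper defined just below:

from typing import Tuple

import torch


class DiscardAboveMedianDer(PseudoLabelPostprocess):
    """Hypothetical postprocessor: keep chunks whose pseudo-label DER
    is at most the batch median."""

    def process(
        self, pseudo_y: torch.Tensor, y: torch.Tensor, x: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        ders = _compute_ders(pseudo_y, y, x)  # per-chunk DER, shape (batch_size,)
        keep = ders <= ders.median()
        return pseudo_y[keep], x[keep]
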
- # ax.plot(metric.linspace, ders.cpu()) - # experiment.add_figure("testplot", fig, global_step=0) +def _compute_ders( + pseudo_y: torch.Tensor, y: torch.Tensor, x: torch.Tensor +) -> Tuple[torch.Tensor]: + batch_size = pseudo_y.shape[0] + ders = torch.zeros(batch_size) + + tm_pseudo_y = pseudo_y.swapaxes(1, 2) + tm_true_y = y.swapaxes(1, 2) + for i in range(batch_size): + ders[i] = diarization_error_rate( + tm_pseudo_y[i][None, :, :], tm_true_y[i][None, :, :] + ) + + return ders + + +class DiscardPercentDer(PseudoLabelPostprocess): + def __init__(self, ratio_to_discard: float = 0.1) -> None: + self.ratio_to_discard = ratio_to_discard + + def process( + self, pseudo_y: torch.Tensor, y: torch.Tensor, x: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + batch_size = pseudo_y.shape[0] + ders = _compute_ders(pseudo_y, y, x) + sorted_ders, sorted_indices = torch.sort(ders) + + to_discard_count = min( + batch_size, max(1, round(batch_size * self.ratio_to_discard)) + ) + pseudo_y = pseudo_y[sorted_indices][:-to_discard_count, :, :] + x = x[sorted_indices][:-to_discard_count, :, :] + + return pseudo_y, x + + +class DiscardThresholdDer(PseudoLabelPostprocess): + def __init__(self, threshold: float = 0.5) -> None: + self.threshold = threshold + + def process( + self, pseudo_y: torch.Tensor, y: torch.Tensor, x: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + ders = _compute_ders(pseudo_y, y, x) + + filter = torch.where(ders < self.threshold) + pseudo_y = pseudo_y[filter] + x = x[filter] + + return pseudo_y, x class TeacherUpdate(Callback): diff --git a/pyannote/audio/tasks/segmentation/voice_activity_detection.py b/pyannote/audio/tasks/segmentation/voice_activity_detection.py index 52743c094..123451968 100644 --- a/pyannote/audio/tasks/segmentation/voice_activity_detection.py +++ b/pyannote/audio/tasks/segmentation/voice_activity_detection.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2020-2021 CNRS +# Copyright (c) 2020- CNRS # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -23,12 +23,12 @@ from typing import Dict, Sequence, Text, Tuple, Union import numpy as np +from pyannote.database import Protocol from torch_audiomentations.core.transforms_interface import BaseWaveformTransform from torchmetrics import Metric from pyannote.audio.core.task import Problem, Resolution, Specifications, Task from pyannote.audio.tasks.segmentation.mixins import SegmentationTaskMixin -from pyannote.database import Protocol class VoiceActivityDetection(SegmentationTaskMixin, Task): @@ -70,10 +70,11 @@ class VoiceActivityDetection(SegmentationTaskMixin, Task): augmentation : BaseWaveformTransform, optional torch_audiomentations waveform transform, used by dataloader during training. + metric : optional + Validation metric(s). Can be anything supported by torchmetrics.MetricCollection. + Defaults to AUROC (area under the ROC curve). 
""" - ACRONYM = "vad" - def __init__( self, protocol: Protocol, @@ -85,7 +86,7 @@ def __init__( num_workers: int = None, pin_memory: bool = False, augmentation: BaseWaveformTransform = None, - metrics: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None, + metric: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None, ): super().__init__( @@ -96,7 +97,7 @@ def __init__( num_workers=num_workers, pin_memory=pin_memory, augmentation=augmentation, - metrics=metrics, + metric=metric, ) self.balance = balance diff --git a/pyannote/audio/torchmetrics/__init__.py b/pyannote/audio/torchmetrics/__init__.py index 94efbb4b0..27513524b 100644 --- a/pyannote/audio/torchmetrics/__init__.py +++ b/pyannote/audio/torchmetrics/__init__.py @@ -1,15 +1,36 @@ -from pyannote.audio.torchmetrics.audio.auder import AUDER -from pyannote.audio.torchmetrics.audio.der import ( - DER, - ConfusionMetric, - FalseAlarmMetric, - MissedDetectionMetric, +# MIT License +# +# Copyright (c) 2022- CNRS +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + + +from .audio.diarization_error_rate import ( + DiarizationErrorRate, + FalseAlarmRate, + MissedDetectionRate, + SpeakerConfusionRate, ) __all__ = [ - "AUDER", - "DER", - "FalseAlarmMetric", - "MissedDetectionMetric", - "ConfusionMetric", + "DiarizationErrorRate", + "FalseAlarmRate", + "MissedDetectionRate", + "SpeakerConfusionRate", ] diff --git a/pyannote/audio/torchmetrics/audio/__init__.py b/pyannote/audio/torchmetrics/audio/__init__.py index e69de29bb..e500a2655 100644 --- a/pyannote/audio/torchmetrics/audio/__init__.py +++ b/pyannote/audio/torchmetrics/audio/__init__.py @@ -0,0 +1,36 @@ +# MIT License +# +# Copyright (c) 2022- CNRS +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + + +from .diarization_error_rate import ( + DiarizationErrorRate, + FalseAlarmRate, + MissedDetectionRate, + SpeakerConfusionRate, +) + +__all__ = [ + "DiarizationErrorRate", + "SpeakerConfusionRate", + "MissedDetectionRate", + "FalseAlarmRate", +] diff --git a/pyannote/audio/torchmetrics/audio/auder.py b/pyannote/audio/torchmetrics/audio/auder.py deleted file mode 100644 index 2e6a265c5..000000000 --- a/pyannote/audio/torchmetrics/audio/auder.py +++ /dev/null @@ -1,74 +0,0 @@ -import torch -from torch import Tensor -from torchmetrics import Metric - -from pyannote.audio.torchmetrics.functional.audio.auder import ( - _auder_compute, - _auder_update, -) - - -class AUDER(Metric): - """Area Under the Diarization Error Rate. - Approximates the area under the curve of the DER when varying its threshold value. - - Expects preds and target tensors of the shape (NUM_BATCH, NUM_CLASSES, NUM_FRAMES) in its update. - - Note that this is only a reliable metric if num_frames == the total number of frames of the diarized audio. - """ - - higher_is_better = True - is_differentiable = False - - def __init__(self, steps=31, threshold_min=0.0, threshold_max=1.0, unit_area=True): - super().__init__() - - if threshold_max < threshold_min: - raise ValueError( - f"Illegal value : threshold_max ({threshold_max}) < threshold_min ({threshold_min})" - ) - - self.threshold_min = threshold_min - self.threshold_max = threshold_max - self.unit_area = unit_area - self.steps = steps - - self.add_state( - "false_alarm", - default=torch.zeros(self.steps, dtype=torch.float), - dist_reduce_fx="sum", - ) - self.add_state( - "missed_detection", - torch.zeros(self.steps, dtype=torch.float), - dist_reduce_fx="sum", - ) - self.add_state( - "confusion", - torch.zeros(self.steps, dtype=torch.float), - dist_reduce_fx="sum", - ) - self.add_state( - "total", torch.zeros(self.steps, dtype=torch.float), dist_reduce_fx="sum" - ) - - def update( - self, - preds: Tensor, - target: Tensor, - ): - fa, md, conf, total = _auder_update( - preds, target, self.steps, self.threshold_min, self.threshold_max - ) - self.false_alarm += fa - self.missed_detection += md - self.confusion += conf - self.total += total - - def compute(self): - dx = ( - (self.threshold_max - self.threshold_min) if not self.unit_area else 1.0 - ) / (self.steps - 1) - return _auder_compute( - self.false_alarm, self.missed_detection, self.confusion, self.total, dx - ) diff --git a/pyannote/audio/torchmetrics/audio/der.py b/pyannote/audio/torchmetrics/audio/der.py deleted file mode 100644 index c3fe19c72..000000000 --- a/pyannote/audio/torchmetrics/audio/der.py +++ /dev/null @@ -1,60 +0,0 @@ -from torch import Tensor, tensor -from torchmetrics import Metric - -from pyannote.audio.torchmetrics.functional.audio.der import _der_compute, _der_update - - -class DER(Metric): - """ - Compute Diarization Error Rate on discretized annotations. - - Expects preds and target tensors of the shape (NUM_BATCH, NUM_CLASSES, NUM_FRAMES) in its update. - - Note that this is only a reliable metric if num_frames == the total number of frames of the diarized audio. 
- """ - - higher_is_better = True - is_differentiable = False - - def __init__(self, threshold: float = 0.5): - super().__init__() - - self.threshold = threshold - - self.add_state("false_alarm", default=tensor(0.0), dist_reduce_fx="sum") - self.add_state("missed_detection", default=tensor(0.0), dist_reduce_fx="sum") - self.add_state("confusion", default=tensor(0.0), dist_reduce_fx="sum") - self.add_state("total", default=tensor(0.0), dist_reduce_fx="sum") - - def update( - self, - preds: Tensor, - target: Tensor, - ) -> None: - false_alarm, missed_detection, confusion, total = _der_update( - preds, target, self.threshold - ) - self.false_alarm += false_alarm - self.missed_detection += missed_detection - self.confusion += confusion - self.total += total - - def compute(self): - return _der_compute( - self.false_alarm, self.missed_detection, self.confusion, self.total - ) - - -class ConfusionMetric(DER): - def compute(self): - return self.confusion / self.total - - -class FalseAlarmMetric(DER): - def compute(self): - return self.false_alarm / self.total - - -class MissedDetectionMetric(DER): - def compute(self): - return self.missed_detection / self.total diff --git a/pyannote/audio/torchmetrics/audio/diarization_error_rate.py b/pyannote/audio/torchmetrics/audio/diarization_error_rate.py new file mode 100644 index 000000000..c294eb3a3 --- /dev/null +++ b/pyannote/audio/torchmetrics/audio/diarization_error_rate.py @@ -0,0 +1,119 @@ +# MIT License +# +# Copyright (c) 2022- CNRS +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +import torch +from torchmetrics import Metric + +from pyannote.audio.torchmetrics.functional.audio.diarization_error_rate import ( + _der_compute, + _der_update, +) + + +class DiarizationErrorRate(Metric): + """Diarization error rate + + Parameters + ---------- + threshold : float, optional + Threshold used to binarize predictions. Defaults to 0.5. + + Notes + ----- + While pyannote.audio conventions is to store speaker activations with + (batch_size, num_frames, num_speakers)-shaped tensors, this torchmetrics metric + expects them to be shaped as (batch_size, num_speakers, num_frames) tensors. 
+    """
+
+    higher_is_better = False
+    is_differentiable = False
+
+    def __init__(self, threshold: float = 0.5):
+        super().__init__()
+
+        self.threshold = threshold
+
+        self.add_state("false_alarm", default=torch.tensor(0.0), dist_reduce_fx="sum")
+        self.add_state(
+            "missed_detection", default=torch.tensor(0.0), dist_reduce_fx="sum"
+        )
+        self.add_state(
+            "speaker_confusion", default=torch.tensor(0.0), dist_reduce_fx="sum"
+        )
+        self.add_state("speech_total", default=torch.tensor(0.0), dist_reduce_fx="sum")
+
+    def update(
+        self,
+        preds: torch.Tensor,
+        target: torch.Tensor,
+    ) -> None:
+        """Compute and accumulate components of diarization error rate
+
+        Parameters
+        ----------
+        preds : torch.Tensor
+            (batch_size, num_speakers, num_frames)-shaped continuous predictions.
+        target : torch.Tensor
+            (batch_size, num_speakers, num_frames)-shaped (0 or 1) targets.
+
+        Returns
+        -------
+        false_alarm : torch.Tensor
+        missed_detection : torch.Tensor
+        speaker_confusion : torch.Tensor
+        speech_total : torch.Tensor
+            Diarization error rate components accumulated over the whole batch.
+        """
+
+        false_alarm, missed_detection, speaker_confusion, speech_total = _der_update(
+            preds, target, threshold=self.threshold
+        )
+        self.false_alarm += false_alarm
+        self.missed_detection += missed_detection
+        self.speaker_confusion += speaker_confusion
+        self.speech_total += speech_total
+
+    def compute(self):
+        return _der_compute(
+            self.false_alarm,
+            self.missed_detection,
+            self.speaker_confusion,
+            self.speech_total,
+        )
+
+
+class SpeakerConfusionRate(DiarizationErrorRate):
+    def compute(self):
+        # TODO: handle corner case where speech_total == 0
+        return self.speaker_confusion / self.speech_total
+
+
+class FalseAlarmRate(DiarizationErrorRate):
+    def compute(self):
+        # TODO: handle corner case where speech_total == 0
+        return self.false_alarm / self.speech_total
+
+
+class MissedDetectionRate(DiarizationErrorRate):
+    def compute(self):
+        # TODO: handle corner case where speech_total == 0
+        return self.missed_detection / self.speech_total
diff --git a/pyannote/audio/torchmetrics/functional/__init__.py b/pyannote/audio/torchmetrics/functional/__init__.py
index e69de29bb..67b544284 100644
--- a/pyannote/audio/torchmetrics/functional/__init__.py
+++ b/pyannote/audio/torchmetrics/functional/__init__.py
@@ -0,0 +1,21 @@
+# MIT License
+#
+# Copyright (c) 2022- CNRS
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
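The new DiarizationErrorRate metric accumulates its four components across update() calls, so compute() returns a rate aggregated over the whole validation set rather than an average of per-batch rates. Unlike the deleted DER class, it also sets higher_is_better = False, which matches an error rate and is what lets val_monitor pick the right optimization direction. A usage sketch with random tensors (illustration only), following the (batch_size, num_speakers, num_frames) convention documented above:

import torch

from pyannote.audio.torchmetrics import DiarizationErrorRate

der = DiarizationErrorRate(threshold=0.5)

target = (torch.rand(4, 3, 100) > 0.5).float()  # binary reference activations
preds = torch.rand(4, 3, 100)                   # continuous predictions

der.update(preds, target)  # accumulates false alarm, missed detection, confusion, total
print(der.compute())       # (false_alarm + missed_detection + confusion) / speech_total
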
diff --git a/pyannote/audio/torchmetrics/functional/audio/__init__.py b/pyannote/audio/torchmetrics/functional/audio/__init__.py index e69de29bb..67b544284 100644 --- a/pyannote/audio/torchmetrics/functional/audio/__init__.py +++ b/pyannote/audio/torchmetrics/functional/audio/__init__.py @@ -0,0 +1,21 @@ +# MIT License +# +# Copyright (c) 2022- CNRS +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. diff --git a/pyannote/audio/torchmetrics/functional/audio/auder.py b/pyannote/audio/torchmetrics/functional/audio/auder.py deleted file mode 100644 index 0d4bbc8cf..000000000 --- a/pyannote/audio/torchmetrics/functional/audio/auder.py +++ /dev/null @@ -1,60 +0,0 @@ -from typing import Tuple - -import numpy as np -import torch -from torch import Tensor - -from pyannote.audio.torchmetrics.functional.audio.der import ( - _check_valid_tensors, - _der_update, -) - - -def _auder_update( - preds: Tensor, - target: Tensor, - steps: int, - tmin: float, - tmax: float, -) -> Tuple[Tensor, Tensor, Tensor, Tensor]: - _check_valid_tensors(preds, target) - - false_alarm = torch.zeros(steps, dtype=torch.float, device=preds.device) - missed_detection = torch.zeros(steps, dtype=torch.float, device=preds.device) - confusion = torch.zeros(steps, dtype=torch.float, device=preds.device) - total = torch.zeros(steps, dtype=torch.float, device=preds.device) - - linspace = np.linspace(tmin, tmax, steps) - for i in range(steps): - threshold = linspace[i] - - der_fa, der_md, der_conf, der_total = _der_update(preds, target, threshold) - false_alarm[i] += der_fa - missed_detection[i] += der_md - confusion[i] += der_conf - total[i] += der_total - return false_alarm, missed_detection, confusion, total - - -def _auder_compute( - false_alarm: Tensor, - missed_detection: Tensor, - confusion: Tensor, - total: Tensor, - dx: float, -) -> Tensor: - ders = (false_alarm + missed_detection + confusion) / total - return torch.trapezoid(ders, dx=dx) - - -def auder( - preds: Tensor, - target: Tensor, - steps: int, - tmin: float, - tmax: float, - unit_area: bool, -): - fa, md, conf, total = _auder_update(preds, target, steps, tmin, tmax) - dx = ((tmax - tmin) if not unit_area else 1.0) / (steps - 1) - return _auder_update(fa, md, conf, total, dx) diff --git a/pyannote/audio/torchmetrics/functional/audio/der.py b/pyannote/audio/torchmetrics/functional/audio/der.py deleted file mode 100644 index 91fb51a35..000000000 --- a/pyannote/audio/torchmetrics/functional/audio/der.py +++ /dev/null @@ -1,76 +0,0 @@ -from typing import Tuple - 
-import torch -from torch import Tensor - -from pyannote.audio.utils.permutation import permutate - - -def _check_valid_tensors(preds: Tensor, target: Tensor): - """Check both tensors have shape (NUM_BATCH, NUM_CLASSES, NUM_FRAMES) with the same NUM_BATCH and NUM_FRAMES.""" - if len(preds.shape) != 3 or len(target.shape) != 3: - msg = f"Wrong shape ({tuple(target.shape)} or {tuple(preds.shape)}), expected (NUM_BATCH, NUM_CLASSES, NUM_FRAMES)." - raise ValueError(msg) - - batch_size, _, num_samples = target.shape - batch_size_, _, num_samples_ = preds.shape - if batch_size != batch_size_ or num_samples != num_samples_: - msg = f"Shape mismatch: {tuple(target.shape)} vs. {tuple(preds.shape)}. Both tensors should have the same NUM_BATCH and NUM_FRAMES." - raise ValueError(msg) - - -def _der_update( - preds: Tensor, target: Tensor, threshold: float -) -> Tuple[Tensor, Tensor, Tensor, Tensor]: - """Compute the false alarm, missed detection, confusion and total values. - - Parameters - ---------- - preds : torch.Tensor - preds torch.tensor of shape (B,C,F) - target : torch.Tensor - preds torch.tensor of shape (B,C,F) (must only contain 0s and 1s) - threshold : float - threshold to discretize preds - - Returns - ------- - Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] - Tensors with 1 item for false alarm, missed detection, confusion, and total - """ - - _check_valid_tensors(preds, target) - - preds_bin = (preds > threshold).float() - - # convert to/from pyannote's tensor ordering (batch,frames,class) (instead of (batch,class,frames)) - hypothesis, _ = permutate( - torch.transpose(target, 1, 2), torch.transpose(preds_bin, 1, 2) - ) - hypothesis = torch.transpose(hypothesis, 1, 2) - - detection_error = torch.sum(hypothesis, 1) - torch.sum(target, 1) - false_alarm = torch.maximum(detection_error, torch.zeros_like(detection_error)) - missed_detection = torch.maximum( - -detection_error, torch.zeros_like(detection_error) - ) - - confusion = torch.sum((hypothesis != target) * hypothesis, 1) - false_alarm - - false_alarm = torch.sum(false_alarm) - missed_detection = torch.sum(missed_detection) - confusion = torch.sum(confusion) - total = 1.0 * torch.sum(target) - - return false_alarm, missed_detection, confusion, total - - -def _der_compute( - false_alarm: Tensor, missed_detection: Tensor, confusion: Tensor, total: Tensor -) -> Tensor: - return (false_alarm + missed_detection + confusion) / total - - -def der(preds: Tensor, target: Tensor, threshold: float = 0.5): - fa, md, conf, total = _der_update(preds, target) - return _der_compute(fa, md, conf, total) diff --git a/pyannote/audio/torchmetrics/functional/audio/diarization_error_rate.py b/pyannote/audio/torchmetrics/functional/audio/diarization_error_rate.py new file mode 100644 index 000000000..77b4f3f3c --- /dev/null +++ b/pyannote/audio/torchmetrics/functional/audio/diarization_error_rate.py @@ -0,0 +1,128 @@ +# MIT License +# +# Copyright (c) 2022- CNRS +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. 
+# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + + +from typing import Tuple + +import torch + +from pyannote.audio.utils.permutation import permutate + + +def _der_update( + preds: torch.Tensor, target: torch.Tensor, threshold: float = 0.5 +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Compute components of diarization error rate + + Parameters + ---------- + preds : torch.Tensor + (batch_size, num_speakers, num_frames)-shaped continuous predictions. + target : torch.Tensor + (batch_size, num_speakers, num_frames)-shaped (0 or 1) targets. + threshold : float, optional + Threshold used to binarize predictions. Defaults to 0.5. + + Returns + ------- + false_alarm : torch.Tensor + missed_detection : torch.Tensor + speaker_confusion : torch.Tensor + speech_total : torch.Tensor + Diarization error rate components accumulated over the whole batch. + """ + + # TODO: consider doing the permutation before the binarization + # in order to improve robustness to mis-calibration. + preds_bin = (preds > threshold).float() + + # convert to/from "permutate" expected shapes + hypothesis, _ = permutate( + torch.transpose(target, 1, 2), torch.transpose(preds_bin, 1, 2) + ) + hypothesis = torch.transpose(hypothesis, 1, 2) + + detection_error = torch.sum(hypothesis, 1) - torch.sum(target, 1) + false_alarm = torch.maximum(detection_error, torch.zeros_like(detection_error)) + missed_detection = torch.maximum( + -detection_error, torch.zeros_like(detection_error) + ) + + speaker_confusion = torch.sum((hypothesis != target) * hypothesis, 1) - false_alarm + + false_alarm = torch.sum(false_alarm) + missed_detection = torch.sum(missed_detection) + speaker_confusion = torch.sum(speaker_confusion) + speech_total = 1.0 * torch.sum(target) + + return false_alarm, missed_detection, speaker_confusion, speech_total + + +def _der_compute( + false_alarm: torch.Tensor, + missed_detection: torch.Tensor, + speaker_confusion: torch.Tensor, + speech_total: torch.Tensor, +) -> torch.Tensor: + """Compute diarization error rate from its components + + Parameters + ---------- + false_alarm : torch.Tensor + missed_detection : torch.Tensor + speaker_confusion : torch.Tensor + speech_total : torch.Tensor + Diarization error rate components, in number of frames. + + Returns + ------- + der : torch.Tensor + Diarization error rate. + """ + + # TODO: handle corner case where speech_total == 0 + return (false_alarm + missed_detection + speaker_confusion) / speech_total + + +def diarization_error_rate( + preds: torch.Tensor, target: torch.Tensor, threshold: float = 0.5 +) -> torch.Tensor: + """Compute diarization error rate + + Parameters + ---------- + preds : torch.Tensor + (batch_size, num_speakers, num_frames)-shaped continuous predictions. + target : torch.Tensor + (batch_size, num_speakers, num_frames)-shaped (0 or 1) targets. + threshold : float, optional + Threshold to binarize predictions. Defaults to 0.5. 
+ + Returns + ------- + der : torch.Tensor + Aggregated diarization error rate + """ + false_alarm, missed_detection, speaker_confusion, speech_total = _der_update( + preds, target, threshold=threshold + ) + return _der_compute(false_alarm, missed_detection, speaker_confusion, speech_total) diff --git a/tutorials/add_your_own_task.ipynb b/tutorials/add_your_own_task.ipynb index 8022cbdc4..b2053f459 100644 --- a/tutorials/add_your_own_task.ipynb +++ b/tutorials/add_your_own_task.ipynb @@ -153,8 +153,6 @@ "class SoundEventDetection(Task):\n", " \"\"\"Sound event detection\"\"\"\n", "\n", - " ACRONYM = \"sed\"\n", - "\n", " def __init__(\n", " self,\n", " protocol: Protocol,\n",
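The pieces connect as follows: Task.metric wraps the user-supplied or default metric(s) in a MetricCollection prefixed by logging_prefix, Model.setup registers it as validation_metric, and val_monitor returns the name and direction of the collection's first entry, so training scripts no longer depend on per-task ACRONYM strings. A sketch of the consuming side, assuming a `task` built as in train.py above (the exact monitored key depends on the task class, protocol name, and chosen metric):

from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

monitor, direction = task.val_monitor
# e.g. monitor == "Segmentation-MyDatabaseSpeakerDiarizationMyProtocol-DiarizationErrorRate"
# and direction == "min", since DiarizationErrorRate sets higher_is_better = False

checkpoint = ModelCheckpoint(monitor=monitor, mode=direction, save_top_k=1)
early_stopping = EarlyStopping(monitor=monitor, mode=direction, patience=10)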