From 0113ab2d56398224157615a70325802602a3ec91 Mon Sep 17 00:00:00 2001
From: Juan Coria
Date: Thu, 19 Oct 2023 18:20:57 +0200
Subject: [PATCH] Remove `PipelineConfig.from_dict()` (#189)

* Unify hparam naming. Clean some typing annotations

* Remove config.from_dict(). Add --duration argument to CLI

* Update README.md accordingly
---
 README.md                       |  2 +-
 src/diart/blocks/base.py        | 10 +---
 src/diart/blocks/diarization.py | 83 +++++++++------------------------
 src/diart/blocks/vad.py         | 60 +++++++-----------------
 src/diart/console/benchmark.py  | 24 ++++++++--
 src/diart/console/client.py     |  4 +-
 src/diart/console/serve.py      | 24 ++++++++--
 src/diart/console/stream.py     | 23 +++++++--
 src/diart/console/tune.py       | 26 +++++++++--
 src/diart/utils.py              |  8 +---
 10 files changed, 124 insertions(+), 140 deletions(-)

diff --git a/README.md b/README.md
index caef5045..3d31bc6d 100644
--- a/README.md
+++ b/README.md
@@ -340,7 +340,7 @@ To obtain the best results, make sure to use the following hyper-parameters:
 `diart.benchmark` and `diart.inference.Benchmark` can run, evaluate and measure the real-time latency of the pipeline. For instance, for a DIHARD III configuration:
 
 ```shell
-diart.benchmark /wav/dir --reference /rttm/dir --tau=0.555 --rho=0.422 --delta=1.517 --segmentation pyannote/segmentation@Interspeech2021
+diart.benchmark /wav/dir --reference /rttm/dir --tau-active=0.555 --rho-update=0.422 --delta-new=1.517 --segmentation pyannote/segmentation@Interspeech2021
 ```
 
 or using the inference API:

diff --git a/src/diart/blocks/base.py b/src/diart/blocks/base.py
index 6536a3f7..f6ca3a33 100644
--- a/src/diart/blocks/base.py
+++ b/src/diart/blocks/base.py
@@ -1,8 +1,7 @@
-from typing import Any, Tuple, Sequence, Text
-from dataclasses import dataclass
 from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import Any, Tuple, Sequence, Text
 
-import numpy as np
 from pyannote.core import SlidingWindowFeature
 from pyannote.metrics.base import BaseMetric
 
@@ -53,11 +52,6 @@ def latency(self) -> float:
     def sample_rate(self) -> int:
         pass
 
-    @staticmethod
-    @abstractmethod
-    def from_dict(data: Any) -> "PipelineConfig":
-        pass
-
     def get_file_padding(self, filepath: FilePath) -> Tuple[float, float]:
         file_duration = AudioLoader(self.sample_rate, mono=True).get_duration(filepath)
         right = utils.get_padding_right(self.latency, self.step)
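With the abstract `from_dict()` removed from `PipelineConfig`, configurations are now created directly through their constructors. A minimal sketch of the new construction path (values are illustrative; assumes the default pyannote models can be resolved):

```python
# Sketch: direct construction replaces PipelineConfig.from_dict().
# Keyword names now match the CLI flags one-to-one (tau_active, rho_update, delta_new).
from diart.blocks.diarization import SpeakerDiarizationConfig

config = SpeakerDiarizationConfig(
    step=0.5,
    latency=0.5,
    tau_active=0.555,
    rho_update=0.422,
    delta_new=1.517,
)
```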
diff --git a/src/diart/blocks/diarization.py b/src/diart/blocks/diarization.py
index 3cf4e333..fab83c36 100644
--- a/src/diart/blocks/diarization.py
+++ b/src/diart/blocks/diarization.py
@@ -1,4 +1,6 @@
-from typing import Optional, Tuple, Sequence, Union, Any
+from __future__ import annotations
+
+from typing import Sequence
 
 import numpy as np
 import torch
@@ -14,40 +16,37 @@
 from .segmentation import SpeakerSegmentation
 from .utils import Binarize
 from .. import models as m
-from .. import utils
 
 
 class SpeakerDiarizationConfig(base.PipelineConfig):
     def __init__(
         self,
-        segmentation: Optional[m.SegmentationModel] = None,
-        embedding: Optional[m.EmbeddingModel] = None,
-        duration: Optional[float] = None,
+        segmentation: m.SegmentationModel | None = None,
+        embedding: m.EmbeddingModel | None = None,
+        duration: float | None = None,
         step: float = 0.5,
-        latency: Optional[Union[float, Literal["max", "min"]]] = None,
+        latency: float | Literal["max", "min"] | None = None,
         tau_active: float = 0.6,
         rho_update: float = 0.3,
         delta_new: float = 1,
         gamma: float = 3,
         beta: float = 10,
         max_speakers: int = 20,
-        device: Optional[torch.device] = None,
+        device: torch.device | None = None,
         **kwargs,
     ):
         # Default segmentation model is pyannote/segmentation
-        self.segmentation = segmentation
-        if self.segmentation is None:
-            self.segmentation = m.SegmentationModel.from_pyannote(
-                "pyannote/segmentation"
-            )
-
-        self._duration = duration
-        self._sample_rate: Optional[int] = None
+        self.segmentation = segmentation or m.SegmentationModel.from_pyannote(
+            "pyannote/segmentation"
+        )
 
         # Default embedding model is pyannote/embedding
-        self.embedding = embedding
-        if self.embedding is None:
-            self.embedding = m.EmbeddingModel.from_pyannote("pyannote/embedding")
+        self.embedding = embedding or m.EmbeddingModel.from_pyannote(
+            "pyannote/embedding"
+        )
+
+        self._duration = duration
+        self._sample_rate: int | None = None
 
         # Latency defaults to the step duration
         self._step = step
@@ -64,48 +63,8 @@ def __init__(
         self.beta = beta
         self.max_speakers = max_speakers
 
-        self.device = device
-        if self.device is None:
-            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-    @staticmethod
-    def from_dict(data: Any) -> "SpeakerDiarizationConfig":
-        # Check for explicit device, otherwise check for 'cpu' bool, otherwise pass None
-        device = utils.get(data, "device", None)
-        if device is None:
-            device = torch.device("cpu") if utils.get(data, "cpu", False) else None
-
-        # Instantiate models
-        hf_token = utils.parse_hf_token_arg(utils.get(data, "hf_token", True))
-        segmentation = utils.get(data, "segmentation", "pyannote/segmentation")
-        segmentation = m.SegmentationModel.from_pyannote(segmentation, hf_token)
-        embedding = utils.get(data, "embedding", "pyannote/embedding")
-        embedding = m.EmbeddingModel.from_pyannote(embedding, hf_token)
-
-        # Hyper-parameters and their aliases
-        tau = utils.get(data, "tau_active", None)
-        if tau is None:
-            tau = utils.get(data, "tau", 0.6)
-        rho = utils.get(data, "rho_update", None)
-        if rho is None:
-            rho = utils.get(data, "rho", 0.3)
-        delta = utils.get(data, "delta_new", None)
-        if delta is None:
-            delta = utils.get(data, "delta", 1)
-
-        return SpeakerDiarizationConfig(
-            segmentation=segmentation,
-            embedding=embedding,
-            duration=utils.get(data, "duration", None),
-            step=utils.get(data, "step", 0.5),
-            latency=utils.get(data, "latency", None),
-            tau_active=tau,
-            rho_update=rho,
-            delta_new=delta,
-            gamma=utils.get(data, "gamma", 3),
-            beta=utils.get(data, "beta", 10),
-            max_speakers=utils.get(data, "max_speakers", 20),
-            device=device,
+        self.device = device or torch.device(
+            "cuda" if torch.cuda.is_available() else "cpu"
         )
 
     @property
@@ -132,7 +91,7 @@ def sample_rate(self) -> int:
 
 
 class SpeakerDiarization(base.Pipeline):
-    def __init__(self, config: Optional[SpeakerDiarizationConfig] = None):
+    def __init__(self, config: SpeakerDiarizationConfig | None = None):
         self._config = SpeakerDiarizationConfig() if config is None else config
 
         msg = f"Latency should be in the range [{self._config.step}, {self._config.duration}]"
@@ -200,7 +159,7 @@ def reset(self):
     def __call__(
         self, waveforms: Sequence[SlidingWindowFeature]
-    ) -> Sequence[Tuple[Annotation, SlidingWindowFeature]]:
+    ) -> Sequence[tuple[Annotation, SlidingWindowFeature]]:
         batch_size = len(waveforms)
         msg = "Pipeline expected at least 1 input"
         assert batch_size >= 1, msg
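The `**kwargs` parameter kept in the constructor is what makes the new `get_config_class()(**vars(args))` call sites work: `vars(args)` carries CLI-only keys the config has no use for, and `**kwargs` absorbs them instead of raising `TypeError`. A small sketch of that behavior (dictionary contents are hypothetical; assumes the default models can be resolved):

```python
from diart.blocks.diarization import SpeakerDiarizationConfig

# "reference" is a CLI-only key; it falls into **kwargs and is ignored.
cli_args = {"step": 0.5, "latency": 0.5, "tau_active": 0.555, "reference": "/rttm/dir"}
config = SpeakerDiarizationConfig(**cli_args)
assert config.tau_active == 0.555
```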
diff --git a/src/diart/blocks/vad.py b/src/diart/blocks/vad.py
index 04fe5608..0edd3e0b 100644
--- a/src/diart/blocks/vad.py
+++ b/src/diart/blocks/vad.py
@@ -1,4 +1,6 @@
-from typing import Any, Optional, Union, Sequence, Tuple
+from __future__ import annotations
+
+from typing import Sequence
 
 import numpy as np
 import torch
@@ -13,8 +15,8 @@
 from pyannote.metrics.detection import DetectionErrorRate
 from typing_extensions import Literal
 
-from .aggregation import DelayedAggregation
 from . import base
+from .aggregation import DelayedAggregation
 from .segmentation import SpeakerSegmentation
 from .utils import Binarize
 from .. import models as m
@@ -24,24 +26,22 @@ class VoiceActivityDetectionConfig(base.PipelineConfig):
     def __init__(
         self,
-        segmentation: Optional[m.SegmentationModel] = None,
-        duration: Optional[float] = None,
+        segmentation: m.SegmentationModel | None = None,
+        duration: float | None = None,
         step: float = 0.5,
-        latency: Optional[Union[float, Literal["max", "min"]]] = None,
+        latency: float | Literal["max", "min"] | None = None,
         tau_active: float = 0.6,
-        device: Optional[torch.device] = None,
+        device: torch.device | None = None,
         **kwargs,
     ):
         # Default segmentation model is pyannote/segmentation
-        self.segmentation = segmentation
-        if self.segmentation is None:
-            self.segmentation = m.SegmentationModel.from_pyannote(
-                "pyannote/segmentation"
-            )
+        self.segmentation = segmentation or m.SegmentationModel.from_pyannote(
+            "pyannote/segmentation"
+        )
 
         self._duration = duration
         self._step = step
-        self._sample_rate: Optional[int] = None
+        self._sample_rate: int | None = None
 
         # Latency defaults to the step duration
         self._latency = latency
@@ -51,9 +51,9 @@ def __init__(
             self._latency = self._duration
 
         self.tau_active = tau_active
-        self.device = device
-        if self.device is None:
-            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.device = device or torch.device(
+            "cuda" if torch.cuda.is_available() else "cpu"
+        )
 
     @property
     def duration(self) -> float:
@@ -77,35 +77,9 @@ def sample_rate(self) -> int:
         self._sample_rate = self.segmentation.sample_rate
         return self._sample_rate
 
-    @staticmethod
-    def from_dict(data: Any) -> "VoiceActivityDetectionConfig":
-        # Check for explicit device, otherwise check for 'cpu' bool, otherwise pass None
-        device = utils.get(data, "device", None)
-        if device is None:
-            device = torch.device("cpu") if utils.get(data, "cpu", False) else None
-
-        # Instantiate segmentation model
-        hf_token = utils.parse_hf_token_arg(utils.get(data, "hf_token", True))
-        segmentation = utils.get(data, "segmentation", "pyannote/segmentation")
-        segmentation = m.SegmentationModel.from_pyannote(segmentation, hf_token)
-
-        # Tau active and its alias
-        tau = utils.get(data, "tau_active", None)
-        if tau is None:
-            tau = utils.get(data, "tau", 0.6)
-
-        return VoiceActivityDetectionConfig(
-            segmentation=segmentation,
-            duration=utils.get(data, "duration", None),
-            step=utils.get(data, "step", 0.5),
-            latency=utils.get(data, "latency", None),
-            tau_active=tau,
-            device=device,
-        )
-
 
 class VoiceActivityDetection(base.Pipeline):
-    def __init__(self, config: Optional[VoiceActivityDetectionConfig] = None):
+    def __init__(self, config: VoiceActivityDetectionConfig | None = None):
        self._config = VoiceActivityDetectionConfig() if config is None else config
 
         msg = f"Latency should be in the range [{self._config.step}, {self._config.duration}]"
@@ -158,7 +132,7 @@ def set_timestamp_shift(self, shift: float):
     def __call__(
         self,
         waveforms: Sequence[SlidingWindowFeature],
-    ) -> Sequence[Tuple[Annotation, SlidingWindowFeature]]:
+    ) -> Sequence[tuple[Annotation, SlidingWindowFeature]]:
         batch_size = len(waveforms)
         msg = "Pipeline expected at least 1 input"
         assert batch_size >= 1, msg
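The VAD pipeline follows the same pattern; a sketch of constructing it with and without an explicit config (assumes the default segmentation model can be loaded):

```python
from diart.blocks.vad import VoiceActivityDetection, VoiceActivityDetectionConfig

# Explicit config using the unified hyper-parameter name.
config = VoiceActivityDetectionConfig(tau_active=0.6, latency=0.5)
vad = VoiceActivityDetection(config)

# Passing no config falls back to a default VoiceActivityDetectionConfig().
vad_default = VoiceActivityDetection()
```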
diff --git a/src/diart/console/benchmark.py b/src/diart/console/benchmark.py
index 70b4c3d9..b5a296d1 100644
--- a/src/diart/console/benchmark.py
+++ b/src/diart/console/benchmark.py
@@ -2,7 +2,10 @@
 from pathlib import Path
 
 import pandas as pd
+import torch
+
 from diart import argdoc
+from diart import models as m
 from diart import utils
 from diart.inference import Benchmark, Parallelize
 
@@ -37,6 +40,11 @@ def run():
         type=Path,
         help="Optional. Directory with RTTM files CONVERSATION.rttm. Names must match audio files",
     )
+    parser.add_argument(
+        "--duration",
+        type=float,
+        help=f"{argdoc.DURATION}. Defaults to training segmentation duration",
+    )
     parser.add_argument(
         "--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5"
     )
@@ -44,13 +52,13 @@
         "--latency", default=0.5, type=float, help=f"{argdoc.LATENCY}. Defaults to 0.5"
     )
     parser.add_argument(
-        "--tau", default=0.5, type=float, help=f"{argdoc.TAU}. Defaults to 0.5"
+        "--tau-active", default=0.5, type=float, help=f"{argdoc.TAU}. Defaults to 0.5"
     )
     parser.add_argument(
-        "--rho", default=0.3, type=float, help=f"{argdoc.RHO}. Defaults to 0.3"
+        "--rho-update", default=0.3, type=float, help=f"{argdoc.RHO}. Defaults to 0.3"
     )
     parser.add_argument(
-        "--delta", default=1, type=float, help=f"{argdoc.DELTA}. Defaults to 1"
+        "--delta-new", default=1, type=float, help=f"{argdoc.DELTA}. Defaults to 1"
     )
     parser.add_argument(
         "--gamma", default=3, type=float, help=f"{argdoc.GAMMA}. Defaults to 3"
@@ -93,6 +101,14 @@ def run():
     )
     args = parser.parse_args()
 
+    # Resolve device
+    args.device = torch.device("cpu") if args.cpu else None
+
+    # Resolve models
+    hf_token = utils.parse_hf_token_arg(args.hf_token)
+    args.segmentation = m.SegmentationModel.from_pyannote(args.segmentation, hf_token)
+    args.embedding = m.EmbeddingModel.from_pyannote(args.embedding, hf_token)
+
     pipeline_class = utils.get_pipeline_class(args.pipeline)
 
     benchmark = Benchmark(
@@ -104,7 +120,7 @@
         batch_size=args.batch_size,
     )
 
-    config = pipeline_class.get_config_class().from_dict(vars(args))
+    config = pipeline_class.get_config_class()(**vars(args))
     if args.num_workers > 0:
         benchmark = Parallelize(benchmark, args.num_workers)
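The renamed flags line up with the constructor keywords because argparse converts dashes to underscores when deriving attribute names, so the old alias resolution (`tau` vs. `tau_active`) is no longer needed. A self-contained sketch; the same pattern applies to the `serve`, `stream`, and `tune` consoles below:

```python
import argparse

parser = argparse.ArgumentParser()
# "--tau-active" is stored as args.tau_active, the exact keyword the config expects.
parser.add_argument("--tau-active", default=0.5, type=float)
args = parser.parse_args(["--tau-active", "0.555"])
assert vars(args) == {"tau_active": 0.555}
```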
diff --git a/src/diart/console/client.py b/src/diart/console/client.py
index b656298a..b3de36db 100644
--- a/src/diart/console/client.py
+++ b/src/diart/console/client.py
@@ -3,12 +3,12 @@
 from threading import Thread
 from typing import Text, Optional
 
-import numpy as np
 import rx.operators as ops
+from websocket import WebSocket
+
 from diart import argdoc
 from diart import sources as src
 from diart import utils
-from websocket import WebSocket
 
 
 def send_audio(ws: WebSocket, source: Text, step: float, sample_rate: int):

diff --git a/src/diart/console/serve.py b/src/diart/console/serve.py
index d8c059c3..bc002e42 100644
--- a/src/diart/console/serve.py
+++ b/src/diart/console/serve.py
@@ -1,7 +1,10 @@
 import argparse
 from pathlib import Path
 
+import torch
+
 from diart import argdoc
+from diart import models as m
 from diart import sources as src
 from diart import utils
 from diart.inference import StreamingInference
@@ -30,6 +33,11 @@ def run():
         type=str,
         help=f"{argdoc.EMBEDDING}. Defaults to pyannote/embedding",
     )
+    parser.add_argument(
+        "--duration",
+        type=float,
+        help=f"{argdoc.DURATION}. Defaults to training segmentation duration",
+    )
     parser.add_argument(
         "--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5"
     )
@@ -37,13 +45,13 @@
         "--latency", default=0.5, type=float, help=f"{argdoc.LATENCY}. Defaults to 0.5"
     )
     parser.add_argument(
-        "--tau", default=0.5, type=float, help=f"{argdoc.TAU}. Defaults to 0.5"
+        "--tau-active", default=0.5, type=float, help=f"{argdoc.TAU}. Defaults to 0.5"
     )
     parser.add_argument(
-        "--rho", default=0.3, type=float, help=f"{argdoc.RHO}. Defaults to 0.3"
+        "--rho-update", default=0.3, type=float, help=f"{argdoc.RHO}. Defaults to 0.3"
     )
     parser.add_argument(
-        "--delta", default=1, type=float, help=f"{argdoc.DELTA}. Defaults to 1"
+        "--delta-new", default=1, type=float, help=f"{argdoc.DELTA}. Defaults to 1"
     )
     parser.add_argument(
         "--gamma", default=3, type=float, help=f"{argdoc.GAMMA}. Defaults to 3"
@@ -74,9 +82,17 @@
     )
     args = parser.parse_args()
 
+    # Resolve device
+    args.device = torch.device("cpu") if args.cpu else None
+
+    # Resolve models
+    hf_token = utils.parse_hf_token_arg(args.hf_token)
+    args.segmentation = m.SegmentationModel.from_pyannote(args.segmentation, hf_token)
+    args.embedding = m.EmbeddingModel.from_pyannote(args.embedding, hf_token)
+
     # Resolve pipeline
     pipeline_class = utils.get_pipeline_class(args.pipeline)
-    config = pipeline_class.get_config_class().from_dict(vars(args))
+    config = pipeline_class.get_config_class()(**vars(args))
     pipeline = pipeline_class(config)
 
     # Create websocket audio source
diff --git a/src/diart/console/stream.py b/src/diart/console/stream.py
index 87f3a0a1..713f3e99 100644
--- a/src/diart/console/stream.py
+++ b/src/diart/console/stream.py
@@ -1,7 +1,10 @@
 import argparse
 from pathlib import Path
 
+import torch
+
 from diart import argdoc
+from diart import models as m
 from diart import sources as src
 from diart import utils
 from diart.inference import StreamingInference
@@ -34,7 +37,9 @@ def run():
         help=f"{argdoc.EMBEDDING}. Defaults to pyannote/embedding",
     )
     parser.add_argument(
-        "--duration", type=float, help=f"{argdoc.DURATION}. Defaults to training segmentation duration"
+        "--duration",
+        type=float,
+        help=f"{argdoc.DURATION}. Defaults to training segmentation duration",
     )
     parser.add_argument(
         "--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5"
     )
@@ -43,13 +48,13 @@
         "--latency", default=0.5, type=float, help=f"{argdoc.LATENCY}. Defaults to 0.5"
     )
     parser.add_argument(
-        "--tau", default=0.5, type=float, help=f"{argdoc.TAU}. Defaults to 0.5"
+        "--tau-active", default=0.5, type=float, help=f"{argdoc.TAU}. Defaults to 0.5"
     )
     parser.add_argument(
-        "--rho", default=0.3, type=float, help=f"{argdoc.RHO}. Defaults to 0.3"
+        "--rho-update", default=0.3, type=float, help=f"{argdoc.RHO}. Defaults to 0.3"
     )
     parser.add_argument(
-        "--delta", default=1, type=float, help=f"{argdoc.DELTA}. Defaults to 1"
+        "--delta-new", default=1, type=float, help=f"{argdoc.DELTA}. Defaults to 1"
     )
     parser.add_argument(
         "--gamma", default=3, type=float, help=f"{argdoc.GAMMA}. Defaults to 3"
@@ -88,9 +93,17 @@
     )
     args = parser.parse_args()
 
+    # Resolve device
+    args.device = torch.device("cpu") if args.cpu else None
+
+    # Resolve models
+    hf_token = utils.parse_hf_token_arg(args.hf_token)
+    args.segmentation = m.SegmentationModel.from_pyannote(args.segmentation, hf_token)
+    args.embedding = m.EmbeddingModel.from_pyannote(args.embedding, hf_token)
+
     # Resolve pipeline
     pipeline_class = utils.get_pipeline_class(args.pipeline)
-    config = pipeline_class.get_config_class().from_dict(vars(args))
+    config = pipeline_class.get_config_class()(**vars(args))
     pipeline = pipeline_class(config)
 
     # Manage audio source
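For reference, this is roughly the kind of call `diart.stream` makes internally once the config is built; a sketch only, and the `MicrophoneAudioSource` signature (whether it takes a sample rate) differs across diart versions:

```python
from diart.blocks.diarization import SpeakerDiarization, SpeakerDiarizationConfig
from diart.inference import StreamingInference
from diart import sources as src

config = SpeakerDiarizationConfig(step=0.5, latency=0.5)
pipeline = SpeakerDiarization(config)
mic = src.MicrophoneAudioSource(config.sample_rate)  # signature assumed for this version
inference = StreamingInference(pipeline, mic)
prediction = inference()  # blocks until the stream ends, returns an Annotation
```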
diff --git a/src/diart/console/tune.py b/src/diart/console/tune.py
index 4c969efa..ec243348 100644
--- a/src/diart/console/tune.py
+++ b/src/diart/console/tune.py
@@ -2,11 +2,14 @@
 from pathlib import Path
 
 import optuna
+import torch
+from optuna.samplers import TPESampler
+
 from diart import argdoc
+from diart import models as m
 from diart import utils
 from diart.blocks.base import HyperParameter
 from diart.optim import Optimizer
-from optuna.samplers import TPESampler
 
 
 def run():
@@ -40,6 +43,11 @@
         type=str,
         help=f"{argdoc.EMBEDDING}. Defaults to pyannote/embedding",
     )
+    parser.add_argument(
+        "--duration",
+        type=float,
+        help=f"{argdoc.DURATION}. Defaults to training segmentation duration",
+    )
     parser.add_argument(
         "--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5"
     )
@@ -47,13 +55,13 @@
         "--latency", default=0.5, type=float, help=f"{argdoc.LATENCY}. Defaults to 0.5"
     )
     parser.add_argument(
-        "--tau", default=0.5, type=float, help=f"{argdoc.TAU}. Defaults to 0.5"
+        "--tau-active", default=0.5, type=float, help=f"{argdoc.TAU}. Defaults to 0.5"
     )
     parser.add_argument(
-        "--rho", default=0.3, type=float, help=f"{argdoc.RHO}. Defaults to 0.3"
+        "--rho-update", default=0.3, type=float, help=f"{argdoc.RHO}. Defaults to 0.3"
     )
     parser.add_argument(
-        "--delta", default=1, type=float, help=f"{argdoc.DELTA}. Defaults to 1"
+        "--delta-new", default=1, type=float, help=f"{argdoc.DELTA}. Defaults to 1"
     )
     parser.add_argument(
         "--gamma", default=3, type=float, help=f"{argdoc.GAMMA}. Defaults to 3"
@@ -102,11 +110,19 @@
     )
     args = parser.parse_args()
 
+    # Resolve device
+    args.device = torch.device("cpu") if args.cpu else None
+
+    # Resolve models
+    hf_token = utils.parse_hf_token_arg(args.hf_token)
+    args.segmentation = m.SegmentationModel.from_pyannote(args.segmentation, hf_token)
+    args.embedding = m.EmbeddingModel.from_pyannote(args.embedding, hf_token)
+
     # Retrieve pipeline class
     pipeline_class = utils.get_pipeline_class(args.pipeline)
 
     # Create the base configuration for each trial
-    base_config = pipeline_class.get_config_class().from_dict(vars(args))
+    base_config = pipeline_class.get_config_class()(**vars(args))
 
     # Create hyper-parameters to optimize
     possible_hparams = pipeline_class.hyper_parameters()

diff --git a/src/diart/utils.py b/src/diart/utils.py
index f0eb4751..ca27d022 100644
--- a/src/diart/utils.py
+++ b/src/diart/utils.py
@@ -1,13 +1,13 @@
 import base64
 import time
-from typing import Optional, Text, Union, Any, Dict
+from typing import Optional, Text, Union
 
 import matplotlib.pyplot as plt
 import numpy as np
 from pyannote.core import Annotation, Segment, SlidingWindowFeature, notebook
 
-from .progress import ProgressBar
 from . import blocks
+from .progress import ProgressBar
 
 
 class Chronometer:
@@ -53,10 +53,6 @@ def parse_hf_token_arg(hf_token: Union[bool, Text]) -> Union[bool, Text]:
     return hf_token
 
 
-def get(data: Dict[Text, Any], key: Text, default: Any) -> Any:
-    return data[key] if key in data else default
-
-
 def encode_audio(waveform: np.ndarray) -> Text:
     data = waveform.astype(np.float32).tobytes()
     return base64.b64encode(data).decode("utf-8")
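The removed `utils.get()` helper duplicated the behavior of the built-in `dict.get`, which is why it can go away with no replacement once `from_dict()` is gone; the equivalence, for the record:

```python
data = {"step": 0.5}

# utils.get(data, key, default) returned data[key] if key in data else default,
# which is exactly what dict.get does.
assert data.get("step", 0.25) == 0.5
assert data.get("latency", None) is None
```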