Skip to content

Commit

Permalink
Remove PipelineConfig.from_dict() (#189)
Browse files Browse the repository at this point in the history
* Unify hparam naming. Clean some typing annotations

* Remove config.from_dict(). Add --duration argument to CLI

* Update README.md accordingly
  • Loading branch information
juanmc2005 authored Oct 19, 2023
1 parent 45f8ad9 commit 0113ab2
Show file tree
Hide file tree
Showing 10 changed files with 124 additions and 140 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ To obtain the best results, make sure to use the following hyper-parameters:
`diart.benchmark` and `diart.inference.Benchmark` can run, evaluate and measure the real-time latency of the pipeline. For instance, for a DIHARD III configuration:

```shell
diart.benchmark /wav/dir --reference /rttm/dir --tau=0.555 --rho=0.422 --delta=1.517 --segmentation pyannote/segmentation@Interspeech2021
diart.benchmark /wav/dir --reference /rttm/dir --tau-active=0.555 --rho-update=0.422 --delta-new=1.517 --segmentation pyannote/segmentation@Interspeech2021
```

or using the inference API:
Expand Down
10 changes: 2 additions & 8 deletions src/diart/blocks/base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from typing import Any, Tuple, Sequence, Text
from dataclasses import dataclass
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Tuple, Sequence, Text

import numpy as np
from pyannote.core import SlidingWindowFeature
from pyannote.metrics.base import BaseMetric

Expand Down Expand Up @@ -53,11 +52,6 @@ def latency(self) -> float:
def sample_rate(self) -> int:
pass

@staticmethod
@abstractmethod
def from_dict(data: Any) -> "PipelineConfig":
pass

def get_file_padding(self, filepath: FilePath) -> Tuple[float, float]:
file_duration = AudioLoader(self.sample_rate, mono=True).get_duration(filepath)
right = utils.get_padding_right(self.latency, self.step)
Expand Down
83 changes: 21 additions & 62 deletions src/diart/blocks/diarization.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import Optional, Tuple, Sequence, Union, Any
from __future__ import annotations

from typing import Sequence

import numpy as np
import torch
Expand All @@ -14,40 +16,37 @@
from .segmentation import SpeakerSegmentation
from .utils import Binarize
from .. import models as m
from .. import utils


class SpeakerDiarizationConfig(base.PipelineConfig):
def __init__(
self,
segmentation: Optional[m.SegmentationModel] = None,
embedding: Optional[m.EmbeddingModel] = None,
duration: Optional[float] = None,
segmentation: m.SegmentationModel | None = None,
embedding: m.EmbeddingModel | None = None,
duration: float | None = None,
step: float = 0.5,
latency: Optional[Union[float, Literal["max", "min"]]] = None,
latency: float | Literal["max", "min"] | None = None,
tau_active: float = 0.6,
rho_update: float = 0.3,
delta_new: float = 1,
gamma: float = 3,
beta: float = 10,
max_speakers: int = 20,
device: Optional[torch.device] = None,
device: torch.device | None = None,
**kwargs,
):
# Default segmentation model is pyannote/segmentation
self.segmentation = segmentation
if self.segmentation is None:
self.segmentation = m.SegmentationModel.from_pyannote(
"pyannote/segmentation"
)

self._duration = duration
self._sample_rate: Optional[int] = None
self.segmentation = segmentation or m.SegmentationModel.from_pyannote(
"pyannote/segmentation"
)

# Default embedding model is pyannote/embedding
self.embedding = embedding
if self.embedding is None:
self.embedding = m.EmbeddingModel.from_pyannote("pyannote/embedding")
self.embedding = embedding or m.EmbeddingModel.from_pyannote(
"pyannote/embedding"
)

self._duration = duration
self._sample_rate: int | None = None

# Latency defaults to the step duration
self._step = step
Expand All @@ -64,48 +63,8 @@ def __init__(
self.beta = beta
self.max_speakers = max_speakers

self.device = device
if self.device is None:
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

@staticmethod
def from_dict(data: Any) -> "SpeakerDiarizationConfig":
# Check for explicit device, otherwise check for 'cpu' bool, otherwise pass None
device = utils.get(data, "device", None)
if device is None:
device = torch.device("cpu") if utils.get(data, "cpu", False) else None

# Instantiate models
hf_token = utils.parse_hf_token_arg(utils.get(data, "hf_token", True))
segmentation = utils.get(data, "segmentation", "pyannote/segmentation")
segmentation = m.SegmentationModel.from_pyannote(segmentation, hf_token)
embedding = utils.get(data, "embedding", "pyannote/embedding")
embedding = m.EmbeddingModel.from_pyannote(embedding, hf_token)

# Hyper-parameters and their aliases
tau = utils.get(data, "tau_active", None)
if tau is None:
tau = utils.get(data, "tau", 0.6)
rho = utils.get(data, "rho_update", None)
if rho is None:
rho = utils.get(data, "rho", 0.3)
delta = utils.get(data, "delta_new", None)
if delta is None:
delta = utils.get(data, "delta", 1)

return SpeakerDiarizationConfig(
segmentation=segmentation,
embedding=embedding,
duration=utils.get(data, "duration", None),
step=utils.get(data, "step", 0.5),
latency=utils.get(data, "latency", None),
tau_active=tau,
rho_update=rho,
delta_new=delta,
gamma=utils.get(data, "gamma", 3),
beta=utils.get(data, "beta", 10),
max_speakers=utils.get(data, "max_speakers", 20),
device=device,
self.device = device or torch.device(
"cuda" if torch.cuda.is_available() else "cpu"
)

@property
Expand All @@ -132,7 +91,7 @@ def sample_rate(self) -> int:


class SpeakerDiarization(base.Pipeline):
def __init__(self, config: Optional[SpeakerDiarizationConfig] = None):
def __init__(self, config: SpeakerDiarizationConfig | None = None):
self._config = SpeakerDiarizationConfig() if config is None else config

msg = f"Latency should be in the range [{self._config.step}, {self._config.duration}]"
Expand Down Expand Up @@ -200,7 +159,7 @@ def reset(self):

def __call__(
self, waveforms: Sequence[SlidingWindowFeature]
) -> Sequence[Tuple[Annotation, SlidingWindowFeature]]:
) -> Sequence[tuple[Annotation, SlidingWindowFeature]]:
batch_size = len(waveforms)
msg = "Pipeline expected at least 1 input"
assert batch_size >= 1, msg
Expand Down
60 changes: 17 additions & 43 deletions src/diart/blocks/vad.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import Any, Optional, Union, Sequence, Tuple
from __future__ import annotations

from typing import Sequence

import numpy as np
import torch
Expand All @@ -13,8 +15,8 @@
from pyannote.metrics.detection import DetectionErrorRate
from typing_extensions import Literal

from .aggregation import DelayedAggregation
from . import base
from .aggregation import DelayedAggregation
from .segmentation import SpeakerSegmentation
from .utils import Binarize
from .. import models as m
Expand All @@ -24,24 +26,22 @@
class VoiceActivityDetectionConfig(base.PipelineConfig):
def __init__(
self,
segmentation: Optional[m.SegmentationModel] = None,
duration: Optional[float] = None,
segmentation: m.SegmentationModel | None = None,
duration: float | None = None,
step: float = 0.5,
latency: Optional[Union[float, Literal["max", "min"]]] = None,
latency: float | Literal["max", "min"] | None = None,
tau_active: float = 0.6,
device: Optional[torch.device] = None,
device: torch.device | None = None,
**kwargs,
):
# Default segmentation model is pyannote/segmentation
self.segmentation = segmentation
if self.segmentation is None:
self.segmentation = m.SegmentationModel.from_pyannote(
"pyannote/segmentation"
)
self.segmentation = segmentation or m.SegmentationModel.from_pyannote(
"pyannote/segmentation"
)

self._duration = duration
self._step = step
self._sample_rate: Optional[int] = None
self._sample_rate: int | None = None

# Latency defaults to the step duration
self._latency = latency
Expand All @@ -51,9 +51,9 @@ def __init__(
self._latency = self._duration

self.tau_active = tau_active
self.device = device
if self.device is None:
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.device = device or torch.device(
"cuda" if torch.cuda.is_available() else "cpu"
)

@property
def duration(self) -> float:
Expand All @@ -77,35 +77,9 @@ def sample_rate(self) -> int:
self._sample_rate = self.segmentation.sample_rate
return self._sample_rate

@staticmethod
def from_dict(data: Any) -> "VoiceActivityDetectionConfig":
# Check for explicit device, otherwise check for 'cpu' bool, otherwise pass None
device = utils.get(data, "device", None)
if device is None:
device = torch.device("cpu") if utils.get(data, "cpu", False) else None

# Instantiate segmentation model
hf_token = utils.parse_hf_token_arg(utils.get(data, "hf_token", True))
segmentation = utils.get(data, "segmentation", "pyannote/segmentation")
segmentation = m.SegmentationModel.from_pyannote(segmentation, hf_token)

# Tau active and its alias
tau = utils.get(data, "tau_active", None)
if tau is None:
tau = utils.get(data, "tau", 0.6)

return VoiceActivityDetectionConfig(
segmentation=segmentation,
duration=utils.get(data, "duration", None),
step=utils.get(data, "step", 0.5),
latency=utils.get(data, "latency", None),
tau_active=tau,
device=device,
)


class VoiceActivityDetection(base.Pipeline):
def __init__(self, config: Optional[VoiceActivityDetectionConfig] = None):
def __init__(self, config: VoiceActivityDetectionConfig | None = None):
self._config = VoiceActivityDetectionConfig() if config is None else config

msg = f"Latency should be in the range [{self._config.step}, {self._config.duration}]"
Expand Down Expand Up @@ -158,7 +132,7 @@ def set_timestamp_shift(self, shift: float):
def __call__(
self,
waveforms: Sequence[SlidingWindowFeature],
) -> Sequence[Tuple[Annotation, SlidingWindowFeature]]:
) -> Sequence[tuple[Annotation, SlidingWindowFeature]]:
batch_size = len(waveforms)
msg = "Pipeline expected at least 1 input"
assert batch_size >= 1, msg
Expand Down
24 changes: 20 additions & 4 deletions src/diart/console/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@
from pathlib import Path

import pandas as pd
import torch

from diart import argdoc
from diart import models as m
from diart import utils
from diart.inference import Benchmark, Parallelize

Expand Down Expand Up @@ -37,20 +40,25 @@ def run():
type=Path,
help="Optional. Directory with RTTM files CONVERSATION.rttm. Names must match audio files",
)
parser.add_argument(
"--duration",
type=float,
help=f"{argdoc.DURATION}. Defaults to training segmentation duration",
)
parser.add_argument(
"--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5"
)
parser.add_argument(
"--latency", default=0.5, type=float, help=f"{argdoc.LATENCY}. Defaults to 0.5"
)
parser.add_argument(
"--tau", default=0.5, type=float, help=f"{argdoc.TAU}. Defaults to 0.5"
"--tau-active", default=0.5, type=float, help=f"{argdoc.TAU}. Defaults to 0.5"
)
parser.add_argument(
"--rho", default=0.3, type=float, help=f"{argdoc.RHO}. Defaults to 0.3"
"--rho-update", default=0.3, type=float, help=f"{argdoc.RHO}. Defaults to 0.3"
)
parser.add_argument(
"--delta", default=1, type=float, help=f"{argdoc.DELTA}. Defaults to 1"
"--delta-new", default=1, type=float, help=f"{argdoc.DELTA}. Defaults to 1"
)
parser.add_argument(
"--gamma", default=3, type=float, help=f"{argdoc.GAMMA}. Defaults to 3"
Expand Down Expand Up @@ -93,6 +101,14 @@ def run():
)
args = parser.parse_args()

# Resolve device
args.device = torch.device("cpu") if args.cpu else None

# Resolve models
hf_token = utils.parse_hf_token_arg(args.hf_token)
args.segmentation = m.SegmentationModel.from_pyannote(args.segmentation, hf_token)
args.embedding = m.EmbeddingModel.from_pyannote(args.embedding, hf_token)

pipeline_class = utils.get_pipeline_class(args.pipeline)

benchmark = Benchmark(
Expand All @@ -104,7 +120,7 @@ def run():
batch_size=args.batch_size,
)

config = pipeline_class.get_config_class().from_dict(vars(args))
config = pipeline_class.get_config_class()(**vars(args))
if args.num_workers > 0:
benchmark = Parallelize(benchmark, args.num_workers)

Expand Down
4 changes: 2 additions & 2 deletions src/diart/console/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
from threading import Thread
from typing import Text, Optional

import numpy as np
import rx.operators as ops
from websocket import WebSocket

from diart import argdoc
from diart import sources as src
from diart import utils
from websocket import WebSocket


def send_audio(ws: WebSocket, source: Text, step: float, sample_rate: int):
Expand Down
Loading

0 comments on commit 0113ab2

Please sign in to comment.