diff --git a/src/diart/__init__.py b/src/diart/__init__.py
index acbcc037..4540f4cc 100644
--- a/src/diart/__init__.py
+++ b/src/diart/__init__.py
@@ -1 +1 @@
-from .blocks import OnlineSpeakerDiarization, PipelineConfig
+from .blocks import OnlineSpeakerDiarization, PipelineConfig, OnlineSpeakerDiarizationHook
diff --git a/src/diart/blocks/__init__.py b/src/diart/blocks/__init__.py
index a02d52dd..f9cd8723 100644
--- a/src/diart/blocks/__init__.py
+++ b/src/diart/blocks/__init__.py
@@ -13,5 +13,5 @@
     OverlapAwareSpeakerEmbedding,
 )
 from .segmentation import SpeakerSegmentation
-from .diarization import OnlineSpeakerDiarization, PipelineConfig
+from .diarization import OnlineSpeakerDiarization, PipelineConfig, OnlineSpeakerDiarizationHook
 from .utils import Binarize, Resample, AdjustVolume
diff --git a/src/diart/blocks/clustering.py b/src/diart/blocks/clustering.py
index 882001b9..2f08b7cc 100644
--- a/src/diart/blocks/clustering.py
+++ b/src/diart/blocks/clustering.py
@@ -59,11 +59,21 @@ def num_blocked_speakers(self) -> int:
     @property
     def inactive_centers(self) -> List[int]:
         return [
-            c
-            for c in range(self.max_speakers)
+            c for c in range(self.max_speakers)
             if c not in self.active_centers or c in self.blocked_centers
         ]
 
+    @property
+    def center_matrix(self) -> Optional[np.ndarray]:
+        """Rows of ``self.centers`` for active, non-blocked centers (``None`` if unset)."""
+        if self.centers is None:
+            return None
+        active = np.array([
+            c for c in range(self.max_speakers)
+            if c in self.active_centers and c not in self.blocked_centers
+        ], dtype=int)  # explicit dtype keeps an empty selection a valid index array
+        return self.centers[active]
+
     def get_next_center_position(self) -> Optional[int]:
         for center in range(self.max_speakers):
             if center not in self.active_centers and center not in self.blocked_centers:
diff --git a/src/diart/blocks/diarization.py b/src/diart/blocks/diarization.py
index c889f24d..5f76c149 100644
--- a/src/diart/blocks/diarization.py
+++ b/src/diart/blocks/diarization.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from typing import Optional, Any, Union, Tuple, Sequence
 
 import numpy as np
@@ -11,6 +13,7 @@
 from .segmentation import SpeakerSegmentation
 from .utils import Binarize
 from .. import models as m
+from ..features import TemporalFeatures
 
 
 class PipelineConfig:
@@ -85,9 +88,86 @@ def from_namespace(args: Any) -> 'PipelineConfig':
         )
 
 
+class OnlineSpeakerDiarizationHook:
+    def on_local_segmentation_batch(
+        self,
+        pipeline: OnlineSpeakerDiarization,
+        audio_batch: torch.Tensor,
+        segmentation_batch: TemporalFeatures
+    ):
+        """Called once per batch, right after segmentation is computed."""
+
+    def on_embedding_batch(
+        self,
+        pipeline: OnlineSpeakerDiarization,
+        audio_batch: torch.Tensor,
+        embedding_batch: torch.Tensor
+    ):
+        """Called once per batch, right after embeddings are computed."""
+
+    def on_local_segmentation(
+        self,
+        pipeline: OnlineSpeakerDiarization,
+        waveform: SlidingWindowFeature,
+        segmentation: SlidingWindowFeature
+    ):
+        """Called once per chunk with its segmentation, before clustering."""
+
+    def on_embeddings(
+        self,
+        pipeline: OnlineSpeakerDiarization,
+        waveform: SlidingWindowFeature,
+        embeddings: torch.Tensor
+    ):
+        """Called once per chunk with its speaker embeddings, before clustering."""
+
+    def on_before_clustering(
+        self,
+        pipeline: OnlineSpeakerDiarization,
+        waveform: SlidingWindowFeature
+    ):
+        """Called once per chunk, just before the clustering state is updated."""
+
+    def on_after_clustering(
+        self,
+        pipeline: OnlineSpeakerDiarization,
+        waveform: SlidingWindowFeature,
+        clustering: OnlineSpeakerClustering,
+        segmentation: SlidingWindowFeature
+    ):
+        """Called once per chunk with the updated clustering and permuted segmentation."""
+
+    def on_soft_prediction(
+        self,
+        pipeline: OnlineSpeakerDiarization,
+        waveform: SlidingWindowFeature,
+        segmentation: SlidingWindowFeature
+    ):
+        """Called once per chunk with the aggregated prediction, before binarization."""
+
+    def on_binary_prediction(
+        self,
+        pipeline: OnlineSpeakerDiarization,
+        waveform: SlidingWindowFeature,
+        diarization: Annotation
+    ):
+        """Called once per chunk with the binarized prediction that is returned."""
+
+    def on_before_reset(self, pipeline: OnlineSpeakerDiarization):
+        """Called at the start of ``reset()``, before the state is re-created."""
+
+    def on_after_reset(self, pipeline: OnlineSpeakerDiarization):
+        """Called at the end of ``reset()``, once the state has been re-created."""
+
+
 class OnlineSpeakerDiarization:
-    def __init__(self, config: Optional[PipelineConfig] = None):
+    def __init__(
+        self,
+        config: Optional[PipelineConfig] = None,
+        hooks: Optional[Sequence[OnlineSpeakerDiarizationHook]] = None,
+    ):
         self.config = PipelineConfig() if config is None else config
+        self.hooks = [] if hooks is None else hooks
 
         msg = f"Latency should be in the range [{self.config.step}, {self.config.duration}]"
         assert self.config.step <= self.config.latency <= self.config.duration, msg
@@ -111,11 +191,14 @@ def __init__(self, config: Optional[PipelineConfig] = None):
         self.binarize = Binarize(self.config.tau_active)
 
         # Internal state, handle with care
-        self.clustering = None
+        self.clustering: Optional[OnlineSpeakerClustering] = None
         self.chunk_buffer, self.pred_buffer = [], []
         self.reset()
 
     def reset(self):
+        for hook in self.hooks:
+            hook.on_before_reset(self)
+
         self.clustering = OnlineSpeakerClustering(
             self.config.tau_active,
             self.config.rho_update,
@@ -125,10 +208,13 @@ def reset(self):
         )
         self.chunk_buffer, self.pred_buffer = [], []
 
+        for hook in self.hooks:
+            hook.on_after_reset(self)
+
     def __call__(
         self,
         waveforms: Sequence[SlidingWindowFeature]
-    ) -> Sequence[Optional[Tuple[Annotation, SlidingWindowFeature]]]:
+    ) -> Sequence[Tuple[Annotation, SlidingWindowFeature]]:
         batch_size = len(waveforms)
         msg = "Pipeline expected at least 1 input"
         assert batch_size >= 1, msg
@@ -142,7 +228,12 @@ def __call__(
 
         # Extract segmentation and embeddings
         segmentations = self.segmentation(batch)  # shape (batch, frames, speakers)
+        for hook in self.hooks:
+            hook.on_local_segmentation_batch(self, batch, segmentations)
+
         embeddings = self.embedding(batch, segmentations)  # shape (batch, speakers, emb_dim)
+        for hook in self.hooks:
+            hook.on_embedding_batch(self, batch, embeddings)
 
         seg_resolution = waveforms[0].extent.duration / segmentations.shape[1]
 
@@ -156,8 +247,17 @@ def __call__(
             )
             seg = SlidingWindowFeature(seg.cpu().numpy(), sw)
 
+            for hook in self.hooks:
+                hook.on_local_segmentation(self, wav, seg)
+            for hook in self.hooks:
+                hook.on_embeddings(self, wav, emb)
+            for hook in self.hooks:
+                hook.on_before_clustering(self, wav)
+
             # Update clustering state and permute segmentation
             permuted_seg = self.clustering(seg, emb)
+            for hook in self.hooks:
+                hook.on_after_clustering(self, wav, self.clustering, permuted_seg)
 
             # Update sliding buffer
             self.chunk_buffer.append(wav)
@@ -166,7 +266,13 @@ def __call__(
             # Aggregate buffer outputs for this time step
             agg_waveform = self.audio_aggregation(self.chunk_buffer)
             agg_prediction = self.pred_aggregation(self.pred_buffer)
-            outputs.append((self.binarize(agg_prediction), agg_waveform))
+            for hook in self.hooks:
+                hook.on_soft_prediction(self, agg_waveform, agg_prediction)
+
+            bin_prediction = self.binarize(agg_prediction)
+            outputs.append((bin_prediction, agg_waveform))
+            for hook in self.hooks:
+                hook.on_binary_prediction(self, agg_waveform, bin_prediction)
 
             # Make place for new chunks in buffer if required
             if len(self.chunk_buffer) == self.pred_aggregation.num_overlapping_windows:
diff --git a/src/diart/inference.py b/src/diart/inference.py
index 747f7f17..2f60ce2a 100644
--- a/src/diart/inference.py
+++ b/src/diart/inference.py
@@ -1,7 +1,7 @@
 import logging
 from pathlib import Path
 from traceback import print_exc
-from typing import Union, Text, Optional, Callable, Tuple, List
+from typing import Union, Text, Optional, Any, Tuple, List, Sequence
 
 import diart.operators as dops
 import diart.sources as src
@@ -10,7 +10,7 @@
 import rx
 import rx.operators as ops
 from diart.blocks import OnlineSpeakerDiarization, Resample
-from diart.sinks import DiarizationPredictionAccumulator, RTTMWriter, RealTimePlot, WindowClosedException
+from diart.sinks import DiarizationPredictionAccumulator, RealTimePlot, WindowClosedException
 from diart.utils import Chronometer
 from pyannote.core import Annotation, SlidingWindowFeature
 from pyannote.database.util import load_rttm
@@ -19,6 +19,26 @@
 from tqdm import tqdm
 
 
+class RealTimeInferenceHook:
+    def on_audio_chunk(self, waveform: SlidingWindowFeature):
+        """Called for every audio chunk, before chunks are grouped into batches."""
+
+    def on_audio_batch(self, batch: Sequence[SlidingWindowFeature]):
+        """Called for every batch of chunks, before it is sent to the pipeline."""
+
+    def on_prediction_batch(self, batch: Sequence[Tuple[Annotation, SlidingWindowFeature]]):
+        """Called with the pipeline outputs for a whole batch."""
+
+    def on_prediction(self, prediction: Annotation, waveform: SlidingWindowFeature):
+        """Called for every single prediction, before it is accumulated."""
+
+    def on_completion(self, prediction: Annotation):
+        """Called when the stream completes, with the accumulated prediction."""
+
+    def on_error(self, error: BaseException):
+        """Called when the stream terminates with an error."""
+
+
 class RealTimeInference:
     """
     Performs inference in real time given a pipeline and an audio source.
@@ -60,12 +80,15 @@ def __init__(
         show_progress: bool = True,
         progress_desc: Optional[Text] = None,
         leave_progress_bar: bool = False,
+        hooks: Optional[Sequence[RealTimeInferenceHook]] = None,
     ):
         self.pipeline = pipeline
         self.source = source
         self.batch_size = batch_size
         self.do_profile = do_profile
         self.do_plot = do_plot
+        # TODO replace profiling and progress bar with hooks
+        self.hooks = [] if hooks is None else hooks
         self.accumulator = DiarizationPredictionAccumulator(source.uri)
         self._chrono = Chronometer("chunk" if self.batch_size == 1 else "batch")
         self._observers = []
@@ -94,7 +117,9 @@ def __init__(
         # Add rx operators to manage the inputs and outputs of the pipeline
         self.stream = self.stream.pipe(
             dops.rearrange_audio_stream(chunk_duration, step_duration, sample_rate),
+            ops.do_action(lambda wav: self._call_hooks("on_audio_chunk", wav)),
             ops.buffer_with_count(count=self.batch_size),
+            ops.do_action(lambda batch: self._call_hooks("on_audio_batch", batch)),
         )
 
         if self.do_profile:
@@ -107,7 +132,9 @@ def __init__(
             self.stream = self.stream.pipe(ops.map(pipeline))
 
         self.stream = self.stream.pipe(
+            ops.do_action(lambda pred_batch: self._call_hooks("on_prediction_batch", pred_batch)),
             ops.flat_map(lambda results: rx.from_iterable(results)),
+            ops.do_action(lambda pred: self._call_hooks("on_prediction", *pred)),
             ops.do(self.accumulator),
         )
 
@@ -120,6 +147,11 @@ def __init__(
                 ops.do_action(on_next=lambda _: self._pbar.update())
             )
 
+    def _call_hooks(self, method: Text, *args: Any):
+        """Call ``method`` on every attached hook with parameters ``args``"""
+        for hook in self.hooks:
+            getattr(hook, method)(*args)
+
     def _close_pbar(self):
         if self._pbar is not None:
             self._pbar.close()
@@ -130,16 +162,6 @@ def _close_chronometer(self):
                 self._chrono.stop(do_count=False)
             self._chrono.report()
 
-    def attach_hooks(self, *hooks: Callable[[Tuple[Annotation, SlidingWindowFeature]], None]):
-        """Attach hooks to the pipeline.
-
-        Parameters
-        ----------
-        *hooks: (Tuple[Annotation, SlidingWindowFeature]) -> None
-            Hook functions to consume emitted annotations and audio.
-        """
-        self.stream = self.stream.pipe(*[ops.do_action(hook) for hook in hooks])
-
     def attach_observers(self, *observers: Observer):
         """Attach rx observers to the pipeline.
 
@@ -155,6 +177,7 @@ def _handle_error(self, error: BaseException):
         # Compensate for Rx not always calling on_error
         for sink in self._observers:
             sink.on_error(error)
+        self._call_hooks("on_error", error)
         # Always close the source in case of bad termination
         self.source.close()
         # Special treatment for a user interruption (counted as normal termination)
@@ -167,6 +190,7 @@ def _handle_error(self, error: BaseException):
         self._close_chronometer()
 
     def _handle_completion(self):
+        self._call_hooks("on_completion", self.accumulator.get_prediction())
         # Close progress and chronometer states
         self._close_pbar()
         self._close_chronometer()
@@ -201,6 +225,34 @@ def __call__(self) -> Annotation:
         return self.accumulator.get_prediction()
 
 
+class BenchmarkHook:
+    def on_before_dataset(self, num_files: int):
+        """Called once before any file is processed, with the number of files."""
+
+    def on_before_file(
+        self,
+        source: src.FileAudioSource,
+        pipeline: OnlineSpeakerDiarization
+    ):
+        """Called just before a file is streamed through the pipeline."""
+
+    def on_after_file(
+        self,
+        source: src.FileAudioSource,
+        pipeline: OnlineSpeakerDiarization,
+        prediction: Annotation,
+        reference: Optional[Annotation],
+    ):
+        """Called after a file is processed; ``reference`` is ``None`` if unavailable."""
+
+    def on_after_dataset(
+        self,
+        predictions: Sequence[Annotation],
+        references: Optional[Sequence[Annotation]],
+    ):
+        """Called once after all files; ``references`` is ``None`` if none were loaded."""
+
+
 class Benchmark:
     """
     Run an online speaker diarization pipeline on a set of audio files in batches.
@@ -242,6 +294,8 @@ def __init__(
         show_progress: bool = True,
         show_report: bool = True,
         batch_size: int = 32,
+        hooks: Optional[Sequence[BenchmarkHook]] = None,
+        inference_hooks: Optional[Sequence[RealTimeInferenceHook]] = None,
     ):
         self.speech_path = Path(speech_path).expanduser()
         assert self.speech_path.is_dir(), "Speech path must be a directory"
@@ -263,6 +317,9 @@ def __init__(
         self.show_progress = show_progress
         self.show_report = show_report
         self.batch_size = batch_size
+        # TODO implement writing to disk and evaluation/reporting as hooks
+        self.hooks = [] if hooks is None else hooks
+        self.inference_hooks = inference_hooks
 
     def __call__(self, pipeline: OnlineSpeakerDiarization) -> Union[pd.DataFrame, List[Annotation]]:
         """Run a given pipeline on a set of audio files.
@@ -285,7 +342,12 @@ def __call__(self, pipeline: OnlineSpeakerDiarization) -> Union[pd.DataFrame, Li
         pipeline.reset()
         audio_file_paths = list(self.speech_path.iterdir())
         num_audio_files = len(audio_file_paths)
-        predictions = []
+        predictions, references = [], []
+
+        # Hook calls
+        for hook in self.hooks:
+            hook.on_before_dataset(num_audio_files)
+
         for i, filepath in enumerate(audio_file_paths):
             stream_padding = pipeline.config.latency - pipeline.config.step
             block_size = int(np.rint(pipeline.config.step * pipeline.config.sample_rate))
@@ -299,9 +361,26 @@ def __call__(self, pipeline: OnlineSpeakerDiarization) -> Union[pd.DataFrame, Li
                 show_progress=self.show_progress,
                 progress_desc=f"Streaming {source.uri} ({i + 1}/{num_audio_files})",
                 leave_progress_bar=False,
+                hooks=self.inference_hooks,
             )
+
+            # Hook calls
+            for hook in self.hooks:
+                hook.on_before_file(source, pipeline)
+
             pred = inference()
             pred.uri = source.uri
+
+            # Load reference file for hook calls if possible
+            ref = None
+            if self.reference_path is not None:
+                ref = load_rttm(self.reference_path / f"{source.uri}.rttm").popitem()[1]
+                references.append(ref)
+
+            # Hook calls
+            for hook in self.hooks:
+                hook.on_after_file(source, pipeline, pred, ref)
+
             predictions.append(pred)
 
             if self.output_path is not None:
@@ -311,6 +390,10 @@ def __call__(self, pipeline: OnlineSpeakerDiarization) -> Union[pd.DataFrame, Li
             # Reset internal state for the next file
             pipeline.reset()
 
+        # Hook calls
+        for hook in self.hooks:
+            hook.on_after_dataset(predictions, references if references else None)
+
         # Run evaluation
         if self.reference_path is not None:
             metric = DiarizationErrorRate(collar=0, skip_overlap=False)
diff --git a/src/diart/optim.py b/src/diart/optim.py
index ede516b7..f2981de5 100644
--- a/src/diart/optim.py
+++ b/src/diart/optim.py
@@ -1,7 +1,7 @@
 from collections import OrderedDict
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Sequence, Text, Optional, Union
+from typing import Sequence, Text, Optional, Union, Any, Dict
 
 from optuna import TrialPruned, Study, create_study
 from optuna.samplers import TPESampler
@@ -75,13 +75,21 @@ def __init__(
             raise ValueError(msg)
 
     @property
-    def best_performance(self):
+    def best_performance(self) -> float:
         return self.study.best_value
 
     @property
-    def best_hparams(self):
+    def best_hparams(self) -> Dict[Text, float]:
         return self.study.best_params
 
+    @property
+    def best_config(self) -> PipelineConfig:
+        """New configuration built from ``base_config`` overridden by ``best_hparams``."""
+        config = dict(vars(self.base_config))  # copy: vars() returns the live __dict__
+        for name, value in self.best_hparams.items():
+            config[name] = value
+        return PipelineConfig(**config)
+
     def _callback(self, study: Study, trial: FrozenTrial):
         if self._progress is None:
             return
diff --git a/src/diart/sinks.py b/src/diart/sinks.py
index f0797030..2c99ec1b 100644
--- a/src/diart/sinks.py
+++ b/src/diart/sinks.py
@@ -9,6 +9,9 @@
 from typing_extensions import Literal
 
 
+# TODO sinks could also implement inference hooks
+
+
 class WindowClosedException(Exception):
     pass
 