diff --git a/pyannote/audio/pipelines/speech_separation.py b/pyannote/audio/pipelines/speech_separation.py
index 4172601c3..3cd24ad52 100644
--- a/pyannote/audio/pipelines/speech_separation.py
+++ b/pyannote/audio/pipelines/speech_separation.py
@@ -30,7 +30,7 @@ from typing import Callable, Optional, Text, Tuple, Union
 
 import numpy as np
-from scipy.ndimage import binary_dilation
+from scipy.ndimage import binary_dilation, binary_opening
 import torch
 from einops import rearrange
 from pyannote.core import Annotation, SlidingWindow, SlidingWindowFeature
@@ -161,12 +161,14 @@ def __init__(
         if self._segmentation.model.specifications[0].powerset:
             self.segmentation = ParamDict(
+                min_duration_on=Uniform(0.0, 1.0),
                 min_duration_off=Uniform(0.0, 1.0),
             )
         else:
             self.segmentation = ParamDict(
                 threshold=Uniform(0.1, 0.9),
+                min_duration_on=Uniform(0.0, 1.0),
                 min_duration_off=Uniform(0.0, 1.0),
             )
@@ -600,6 +602,19 @@ def apply(
         # shape: (num_speakers, )
         discrete_diarization.data = discrete_diarization.data[:, active_speakers]
         num_frames, num_speakers = discrete_diarization.data.shape
+
+        # remove active regions shorter than min_duration_on
+        # (binary opening along the time axis, independently per speaker)
+        min_frames_on = int(
+            self._segmentation.model.num_frames(
+                self.segmentation.min_duration_on * self._audio.sample_rate
+            )
+        )
+        if min_frames_on > 0:
+            discrete_diarization.data = binary_opening(
+                discrete_diarization.data, structure=np.array([[True] * min_frames_on]).T
+            )
+
         hook("discrete_diarization", discrete_diarization)
 
         clustered_separations = self.reconstruct(separations, hard_clusters, count)
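
For illustration (not part of the patch): a minimal, self-contained sketch of what the `binary_opening` step added above does to a discrete diarization matrix. The toy array and the `min_frames_on = 2` value are made up for this example; only the `scipy.ndimage.binary_opening` call and the `(min_frames_on, 1)` structuring element mirror the diff.

```python
import numpy as np
from scipy.ndimage import binary_opening

# Toy discrete diarization matrix: (num_frames, num_speakers), True = active.
# Speaker 1 has a single-frame blip that should be filtered out.
data = np.array(
    [
        [1, 0],
        [1, 0],
        [0, 1],  # isolated 1-frame activation
        [0, 0],
        [1, 0],
        [1, 0],
        [1, 0],
    ],
    dtype=bool,
)

min_frames_on = 2  # hypothetical value; runs shorter than this are dropped

# Column-shaped structuring element of shape (min_frames_on, 1), as in the
# diff: the opening acts along the time axis only, independently for each
# speaker column.
structure = np.array([[True] * min_frames_on]).T
opened = binary_opening(data, structure=structure)

print(opened.astype(int))
# [[1 0]
#  [1 0]
#  [0 0]   <- 1-frame blip removed
#  [0 0]
#  [1 0]
#  [1 0]
#  [1 0]]
```

Opening is an erosion followed by a dilation with the same structuring element, so active runs shorter than `min_frames_on` frames are deleted while runs of at least `min_frames_on` frames come back unchanged; the column-shaped structure keeps each speaker's activity filtered independently along the time axis.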