Skip to content

Commit

Permalink
add min_duration_on to separation pipeline hparams
Browse files Browse the repository at this point in the history
  • Loading branch information
clement-pages committed Dec 17, 2024
1 parent 0b7f933 commit 8211074
Showing 1 changed file with 16 additions and 1 deletion.
17 changes: 16 additions & 1 deletion pyannote/audio/pipelines/speech_separation.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from typing import Callable, Optional, Text, Tuple, Union

import numpy as np
from scipy.ndimage import binary_dilation
from scipy.ndimage import binary_dilation, binary_opening
import torch
from einops import rearrange
from pyannote.core import Annotation, SlidingWindow, SlidingWindowFeature
Expand Down Expand Up @@ -161,12 +161,14 @@ def __init__(

if self._segmentation.model.specifications[0].powerset:
self.segmentation = ParamDict(
min_duration_on=Uniform(0.0, 1.0),
min_duration_off=Uniform(0.0, 1.0),
)

else:
self.segmentation = ParamDict(
threshold=Uniform(0.1, 0.9),
min_duration_on=Uniform(0.0, 1.0),
min_duration_off=Uniform(0.0, 1.0),
)

Expand Down Expand Up @@ -600,6 +602,19 @@ def apply(
# shape: (num_speakers, )
discrete_diarization.data = discrete_diarization.data[:, active_speakers]
num_frames, num_speakers = discrete_diarization.data.shape

# filter out too short segments
min_frames_on = int(
self._segmentation.model.num_frames(
self.segmentation.min_duration_on * self._audio.sample_rate
)
)
print(min_frames_on)
if min_frames_on > 0:
discrete_diarization.data = binary_opening(
discrete_diarization.data, structure=np.array([[True] * min_frames_on]).T
)

hook("discrete_diarization", discrete_diarization)

clustered_separations = self.reconstruct(separations, hard_clusters, count)
Expand Down

0 comments on commit 8211074

Please sign in to comment.