Skip to content

Commit

Permalink
fix: ensure round number of frames
Browse files Browse the repository at this point in the history
  • Loading branch information
hbredin committed Jan 9, 2025
1 parent 0b7f933 commit 3c167f2
Showing 1 changed file with 18 additions and 10 deletions.
28 changes: 18 additions & 10 deletions pyannote/audio/pipelines/speech_separation.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,12 @@
from typing import Callable, Optional, Text, Tuple, Union

import numpy as np
from scipy.ndimage import binary_dilation
import torch
from einops import rearrange
from pyannote.core import Annotation, SlidingWindow, SlidingWindowFeature
from pyannote.metrics.diarization import GreedyDiarizationErrorRate
from pyannote.pipeline.parameter import Categorical, ParamDict, Uniform
from scipy.ndimage import binary_dilation

from pyannote.audio import Audio, Inference, Model, Pipeline
from pyannote.audio.core.io import AudioFile
Expand Down Expand Up @@ -595,7 +595,7 @@ def apply(
count,
)
discrete_diarization = self.to_diarization(discrete_diarization, count)
# remove file-wise inactive speakers from the diarization
# remove file-wise inactive speakers from the diarization
active_speakers = np.sum(discrete_diarization, axis=0) > 0
# shape: (num_speakers, )
discrete_diarization.data = discrete_diarization.data[:, active_speakers]
Expand All @@ -615,10 +615,12 @@ def apply(

_, num_sources = sources.data.shape

# In some cases, maximum num of simultaneous speakers is greater than num of clusters,
# implying a num of speakers in the diarization greater than num of sources after calling
# In some cases, maximum num of simultaneous speakers is greater than num of clusters,
# implying a num of speakers in the diarization greater than num of sources after calling
# to_diarization(). So we add dummy sources to match the number of speakers in diarization.
sources.data = np.pad(sources.data, ((0, 0), (0, max(0, num_speakers - num_sources))))
sources.data = np.pad(
sources.data, ((0, 0), (0, max(0, num_speakers - num_sources)))
)

# remove sources corresponding to file-wise inactive speakers
sources.data = sources.data[:, active_speakers]
Expand All @@ -628,18 +630,24 @@ def apply(
if self.separation.leakage_removal:
asr_collar_frames = int(
self._segmentation.model.num_frames(
self.separation.asr_collar * self._audio.sample_rate
round(self.separation.asr_collar * self._audio.sample_rate)
)
)
if asr_collar_frames > 0:
dilated_speaker_activations = np.zeros_like(discrete_diarization.data)
for i in range(num_speakers):
speaker_activation = discrete_diarization.data.T[i]
non_silent = speaker_activation != 0
dilated_non_silent = binary_dilation(non_silent, [True] * (2 * asr_collar_frames))
dilated_speaker_activations.T[i] = dilated_non_silent.astype(np.int8)

dilated_speaker_activations = SlidingWindowFeature(dilated_speaker_activations, discrete_diarization.sliding_window)
dilated_non_silent = binary_dilation(
non_silent, [True] * (2 * asr_collar_frames)
)
dilated_speaker_activations.T[i] = dilated_non_silent.astype(
np.int8
)

dilated_speaker_activations = SlidingWindowFeature(
dilated_speaker_activations, discrete_diarization.sliding_window
)
sources.data = (
sources.data * dilated_speaker_activations.align(sources).data
)
Expand Down

0 comments on commit 3c167f2

Please sign in to comment.