diff --git a/src/diart/blocks/base.py b/src/diart/blocks/base.py index 0a2559c1..c0bed63a 100644 --- a/src/diart/blocks/base.py +++ b/src/diart/blocks/base.py @@ -23,7 +23,7 @@ class HyperParameter: @staticmethod def from_name(name: Text) -> "HyperParameter": """Create a HyperParameter object given its name. - + Parameters ---------- name: str diff --git a/src/diart/blocks/diarization.py b/src/diart/blocks/diarization.py index 6c865fc8..151b4d36 100644 --- a/src/diart/blocks/diarization.py +++ b/src/diart/blocks/diarization.py @@ -193,7 +193,9 @@ def __call__( for wav, seg, emb in zip(waveforms, segmentations, embeddings): # Add timestamps to segmentation sw = SlidingWindow( - start=wav.extent.start, duration=seg_resolution, step=seg_resolution, + start=wav.extent.start, + duration=seg_resolution, + step=seg_resolution, ) seg = SlidingWindowFeature(seg.cpu().numpy(), sw) diff --git a/src/diart/blocks/vad.py b/src/diart/blocks/vad.py index d6925487..f299c94b 100644 --- a/src/diart/blocks/vad.py +++ b/src/diart/blocks/vad.py @@ -125,7 +125,8 @@ def set_timestamp_shift(self, shift: float): self.timestamp_shift = shift def __call__( - self, waveforms: Sequence[SlidingWindowFeature], + self, + waveforms: Sequence[SlidingWindowFeature], ) -> Sequence[tuple[Annotation, SlidingWindowFeature]]: batch_size = len(waveforms) msg = "Pipeline expected at least 1 input" @@ -152,7 +153,9 @@ def __call__( for wav, vad in zip(waveforms, voice_detection): # Add timestamps to segmentation sw = SlidingWindow( - start=wav.extent.start, duration=seg_resolution, step=seg_resolution, + start=wav.extent.start, + duration=seg_resolution, + step=seg_resolution, ) vad = SlidingWindowFeature(vad.cpu().numpy(), sw) diff --git a/src/diart/inference.py b/src/diart/inference.py index 74ba01f1..3eb72930 100644 --- a/src/diart/inference.py +++ b/src/diart/inference.py @@ -123,7 +123,9 @@ def __init__( ) # Form batches - self.stream = self.stream.pipe(ops.buffer_with_count(count=self.batch_size),) + self.stream = self.stream.pipe( + ops.buffer_with_count(count=self.batch_size), + ) if self.do_profile: self.stream = self.stream.pipe( @@ -221,7 +223,8 @@ def __call__(self) -> Annotation: ops.do(StreamingPlot(config.duration, config.latency)), ) observable.subscribe( - on_error=self._handle_error, on_completed=self._handle_completion, + on_error=self._handle_error, + on_completed=self._handle_completion, ) # FIXME if read() isn't blocking, the prediction returned is empty self.source.read() @@ -303,7 +306,10 @@ def get_file_paths(self) -> List[Path]: return list(self.speech_path.iterdir()) def run_single( - self, pipeline: blocks.Pipeline, filepath: Path, progress_bar: ProgressBar, + self, + pipeline: blocks.Pipeline, + filepath: Path, + progress_bar: ProgressBar, ) -> Annotation: """Run a given pipeline on a given file. Note that this method does NOT reset the @@ -325,7 +331,10 @@ def run_single( """ padding = pipeline.config.get_file_padding(filepath) source = src.FileAudioSource( - filepath, pipeline.config.sample_rate, padding, pipeline.config.step, + filepath, + pipeline.config.sample_rate, + padding, + pipeline.config.step, ) pipeline.set_timestamp_shift(-padding[0]) inference = StreamingInference( @@ -348,7 +357,9 @@ def run_single( return pred def evaluate( - self, predictions: List[Annotation], metric: BaseMetric, + self, + predictions: List[Annotation], + metric: BaseMetric, ) -> Union[pd.DataFrame, List[Annotation]]: """If a reference path was provided, compute the diarization error rate of a list of predictions. @@ -435,7 +446,9 @@ class Parallelize: """ def __init__( - self, benchmark: Benchmark, num_workers: int = 4, + self, + benchmark: Benchmark, + num_workers: int = 4, ): self.benchmark = benchmark self.num_workers = num_workers diff --git a/src/diart/mapping.py b/src/diart/mapping.py index dbfc978e..847d0f78 100644 --- a/src/diart/mapping.py +++ b/src/diart/mapping.py @@ -126,12 +126,14 @@ def hard_map( @staticmethod def correlation(scores1: np.ndarray, scores2: np.ndarray) -> SpeakerMap: - score_matrix_per_frame = np.stack( # (local_speakers, num_frames, global_speakers) - [ - scores1[:, speaker : speaker + 1] * scores2 - for speaker in range(scores1.shape[1]) - ], - axis=0, + score_matrix_per_frame = ( + np.stack( # (local_speakers, num_frames, global_speakers) + [ + scores1[:, speaker : speaker + 1] * scores2 + for speaker in range(scores1.shape[1]) + ], + axis=0, + ) ) # Calculate total speech "activations" per local speaker local_speech_scores = np.sum(scores1, axis=0).reshape(-1, 1) @@ -213,7 +215,9 @@ def _loose_check_valid(self, src: int, tgt: int) -> bool: return self.is_source_speaker_mapped(src) def valid_assignments( - self, strict: bool = False, as_array: bool = False, + self, + strict: bool = False, + as_array: bool = False, ) -> Union[Tuple[List[int], List[int]], Tuple[np.ndarray, np.ndarray]]: source, target = [], [] val_type = "strict" if strict else "loose" diff --git a/src/diart/models.py b/src/diart/models.py index a724d2c0..c23e5774 100644 --- a/src/diart/models.py +++ b/src/diart/models.py @@ -93,7 +93,9 @@ def recreate_session(self): onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL ) self.session = onnxruntime.InferenceSession( - self.path, sess_options=options, providers=[self.execution_provider], + self.path, + sess_options=options, + providers=[self.execution_provider], ) def to(self, device: torch.device) -> ONNXModel: diff --git a/src/diart/operators.py b/src/diart/operators.py index bfd51535..7ce13285 100644 --- a/src/diart/operators.py +++ b/src/diart/operators.py @@ -134,7 +134,8 @@ def initial() -> "OutputAccumulationState": @property def cropped_waveform(self) -> SlidingWindowFeature: return SlidingWindowFeature( - self.waveform[: self.next_sample], self.waveform.sliding_window, + self.waveform[: self.next_sample], + self.waveform.sliding_window, ) def to_tuple( @@ -144,7 +145,9 @@ def to_tuple( def accumulate_output( - duration: float, step: float, patch_collar: float = 0.05, + duration: float, + step: float, + patch_collar: float = 0.05, ) -> Operator: """Accumulate predictions and audio to infinity: O(N) space complexity. Uses a pre-allocated buffer that doubles its size once full: O(logN) concat operations. diff --git a/src/diart/sources.py b/src/diart/sources.py index 59bbe3a5..82051b2e 100644 --- a/src/diart/sources.py +++ b/src/diart/sources.py @@ -238,7 +238,10 @@ def __init__( self.server.set_fn_message_received(self._on_message_received) def _on_message_received( - self, client: Dict[Text, Any], server: WebsocketServer, message: AnyStr, + self, + client: Dict[Text, Any], + server: WebsocketServer, + message: AnyStr, ): # Only one client at a time is allowed if self.client is None or self.client["id"] != client["id"]: