From d2f154d201d8265a427395dc8ca7ff4acb2e6ba2 Mon Sep 17 00:00:00 2001
From: juanmc2005
Date: Mon, 13 Nov 2023 20:52:29 +0100
Subject: [PATCH] Blacken code

---
 src/diart/blocks/base.py        | 12 ++++++++----
 src/diart/blocks/diarization.py |  4 +---
 src/diart/blocks/vad.py         |  7 ++-----
 src/diart/inference.py          | 25 ++++++-------------------
 src/diart/mapping.py            | 18 +++++++-----------
 src/diart/models.py             |  8 ++------
 src/diart/operators.py          |  7 ++-----
 src/diart/sources.py            |  5 +----
 8 files changed, 29 insertions(+), 57 deletions(-)

diff --git a/src/diart/blocks/base.py b/src/diart/blocks/base.py
index edcaf66d..0a2559c1 100644
--- a/src/diart/blocks/base.py
+++ b/src/diart/blocks/base.py
@@ -48,7 +48,8 @@ def from_name(name: Text) -> "HyperParameter":
 
 
 class PipelineConfig(ABC):
-    """Configuration containing the required parameters to build and run a pipeline"""
+    """Configuration containing the required
+    parameters to build and run a pipeline"""
 
     @property
     @abstractmethod
@@ -66,7 +67,8 @@ def step(self) -> float:
     @abstractmethod
     def latency(self) -> float:
         """The algorithmic latency of the pipeline (in seconds).
-        At time `t` of the audio stream, the pipeline will output predictions for time `t - latency`.
+        At time `t` of the audio stream, the pipeline will
+        output predictions for time `t - latency`.
         """
         pass
 
@@ -118,7 +120,8 @@ def set_timestamp_shift(self, shift: float):
     def __call__(
         self, waveforms: Sequence[SlidingWindowFeature]
     ) -> Sequence[Tuple[Any, SlidingWindowFeature]]:
-        """Runs the next steps of the pipeline given a list of consecutive audio chunks.
+        """Runs the next steps of the pipeline
+        given a list of consecutive audio chunks.
 
         Parameters
         ----------
@@ -128,6 +131,7 @@ def __call__(
         Returns
         -------
         Sequence[Tuple[Any, SlidingWindowFeature]]
-            For each input waveform, a tuple containing the pipeline output and its respective audio
+            For each input waveform, a tuple containing
+            the pipeline output and its respective audio
         """
         pass
diff --git a/src/diart/blocks/diarization.py b/src/diart/blocks/diarization.py
index 151b4d36..6c865fc8 100644
--- a/src/diart/blocks/diarization.py
+++ b/src/diart/blocks/diarization.py
@@ -193,9 +193,7 @@ def __call__(
         for wav, seg, emb in zip(waveforms, segmentations, embeddings):
             # Add timestamps to segmentation
             sw = SlidingWindow(
-                start=wav.extent.start,
-                duration=seg_resolution,
-                step=seg_resolution,
+                start=wav.extent.start, duration=seg_resolution, step=seg_resolution,
             )
             seg = SlidingWindowFeature(seg.cpu().numpy(), sw)
 
diff --git a/src/diart/blocks/vad.py b/src/diart/blocks/vad.py
index f299c94b..d6925487 100644
--- a/src/diart/blocks/vad.py
+++ b/src/diart/blocks/vad.py
@@ -125,8 +125,7 @@ def set_timestamp_shift(self, shift: float):
         self.timestamp_shift = shift
 
     def __call__(
-        self,
-        waveforms: Sequence[SlidingWindowFeature],
+        self, waveforms: Sequence[SlidingWindowFeature],
     ) -> Sequence[tuple[Annotation, SlidingWindowFeature]]:
         batch_size = len(waveforms)
         msg = "Pipeline expected at least 1 input"
@@ -153,9 +152,7 @@ def __call__(
         for wav, vad in zip(waveforms, voice_detection):
             # Add timestamps to segmentation
             sw = SlidingWindow(
-                start=wav.extent.start,
-                duration=seg_resolution,
-                step=seg_resolution,
+                start=wav.extent.start, duration=seg_resolution, step=seg_resolution,
             )
             vad = SlidingWindowFeature(vad.cpu().numpy(), sw)
 
diff --git a/src/diart/inference.py b/src/diart/inference.py
index 3eb72930..74ba01f1 100644
--- a/src/diart/inference.py
+++ b/src/diart/inference.py
@@ -123,9 +123,7 @@ def __init__(
         )
 
         # Form batches
-        self.stream = self.stream.pipe(
-            ops.buffer_with_count(count=self.batch_size),
-        )
+        self.stream = self.stream.pipe(ops.buffer_with_count(count=self.batch_size),)
 
         if self.do_profile:
             self.stream = self.stream.pipe(
@@ -223,8 +221,7 @@ def __call__(self) -> Annotation:
                 ops.do(StreamingPlot(config.duration, config.latency)),
             )
         observable.subscribe(
-            on_error=self._handle_error,
-            on_completed=self._handle_completion,
+            on_error=self._handle_error, on_completed=self._handle_completion,
         )
         # FIXME if read() isn't blocking, the prediction returned is empty
         self.source.read()
@@ -306,10 +303,7 @@ def get_file_paths(self) -> List[Path]:
         return list(self.speech_path.iterdir())
 
     def run_single(
-        self,
-        pipeline: blocks.Pipeline,
-        filepath: Path,
-        progress_bar: ProgressBar,
+        self, pipeline: blocks.Pipeline, filepath: Path, progress_bar: ProgressBar,
     ) -> Annotation:
         """Run a given pipeline on a given file.
         Note that this method does NOT reset the
@@ -331,10 +325,7 @@ def run_single(
         """
         padding = pipeline.config.get_file_padding(filepath)
         source = src.FileAudioSource(
-            filepath,
-            pipeline.config.sample_rate,
-            padding,
-            pipeline.config.step,
+            filepath, pipeline.config.sample_rate, padding, pipeline.config.step,
         )
         pipeline.set_timestamp_shift(-padding[0])
         inference = StreamingInference(
@@ -357,9 +348,7 @@ def run_single(
         return pred
 
     def evaluate(
-        self,
-        predictions: List[Annotation],
-        metric: BaseMetric,
+        self, predictions: List[Annotation], metric: BaseMetric,
     ) -> Union[pd.DataFrame, List[Annotation]]:
         """If a reference path was provided, compute the diarization error rate
         of a list of predictions.
@@ -446,9 +435,7 @@ class Parallelize:
     """
 
     def __init__(
-        self,
-        benchmark: Benchmark,
-        num_workers: int = 4,
+        self, benchmark: Benchmark, num_workers: int = 4,
     ):
         self.benchmark = benchmark
         self.num_workers = num_workers
diff --git a/src/diart/mapping.py b/src/diart/mapping.py
index 847d0f78..dbfc978e 100644
--- a/src/diart/mapping.py
+++ b/src/diart/mapping.py
@@ -126,14 +126,12 @@ def hard_map(
 
     @staticmethod
    def correlation(scores1: np.ndarray, scores2: np.ndarray) -> SpeakerMap:
-        score_matrix_per_frame = (
-            np.stack(  # (local_speakers, num_frames, global_speakers)
-                [
-                    scores1[:, speaker : speaker + 1] * scores2
-                    for speaker in range(scores1.shape[1])
-                ],
-                axis=0,
-            )
+        score_matrix_per_frame = np.stack(  # (local_speakers, num_frames, global_speakers)
+            [
+                scores1[:, speaker : speaker + 1] * scores2
+                for speaker in range(scores1.shape[1])
+            ],
+            axis=0,
         )
         # Calculate total speech "activations" per local speaker
         local_speech_scores = np.sum(scores1, axis=0).reshape(-1, 1)
@@ -215,9 +213,7 @@ def _loose_check_valid(self, src: int, tgt: int) -> bool:
         return self.is_source_speaker_mapped(src)
 
     def valid_assignments(
-        self,
-        strict: bool = False,
-        as_array: bool = False,
+        self, strict: bool = False, as_array: bool = False,
     ) -> Union[Tuple[List[int], List[int]], Tuple[np.ndarray, np.ndarray]]:
         source, target = [], []
         val_type = "strict" if strict else "loose"
diff --git a/src/diart/models.py b/src/diart/models.py
index 49bfd17a..a724d2c0 100644
--- a/src/diart/models.py
+++ b/src/diart/models.py
@@ -12,9 +12,7 @@
 
 try:
     from pyannote.audio import Model
-    from pyannote.audio.pipelines.speaker_verification import (
-        PretrainedSpeakerEmbedding,
-    )
+    from pyannote.audio.pipelines.speaker_verification import PretrainedSpeakerEmbedding
     from pyannote.audio.utils.powerset import Powerset
 
     _has_pyannote = True
@@ -95,9 +93,7 @@ def recreate_session(self):
             onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
         )
         self.session = onnxruntime.InferenceSession(
-            self.path,
-            sess_options=options,
-            providers=[self.execution_provider],
+            self.path, sess_options=options, providers=[self.execution_provider],
         )
 
     def to(self, device: torch.device) -> ONNXModel:
diff --git a/src/diart/operators.py b/src/diart/operators.py
index 7ce13285..bfd51535 100644
--- a/src/diart/operators.py
+++ b/src/diart/operators.py
@@ -134,8 +134,7 @@ def initial() -> "OutputAccumulationState":
     @property
     def cropped_waveform(self) -> SlidingWindowFeature:
         return SlidingWindowFeature(
-            self.waveform[: self.next_sample],
-            self.waveform.sliding_window,
+            self.waveform[: self.next_sample], self.waveform.sliding_window,
         )
 
     def to_tuple(
@@ -145,9 +144,7 @@ def to_tuple(
 
 
 def accumulate_output(
-    duration: float,
-    step: float,
-    patch_collar: float = 0.05,
+    duration: float, step: float, patch_collar: float = 0.05,
 ) -> Operator:
     """Accumulate predictions and audio to infinity: O(N) space complexity.
     Uses a pre-allocated buffer that doubles its size once full: O(logN) concat operations.
diff --git a/src/diart/sources.py b/src/diart/sources.py
index 82051b2e..59bbe3a5 100644
--- a/src/diart/sources.py
+++ b/src/diart/sources.py
@@ -238,10 +238,7 @@ def __init__(
         self.server.set_fn_message_received(self._on_message_received)
 
     def _on_message_received(
-        self,
-        client: Dict[Text, Any],
-        server: WebsocketServer,
-        message: AnyStr,
+        self, client: Dict[Text, Any], server: WebsocketServer, message: AnyStr,
     ):
         # Only one client at a time is allowed
         if self.client is None or self.client["id"] != client["id"]: