Skip to content

Commit

Permalink
Blacken code
Browse files Browse the repository at this point in the history
  • Loading branch information
juanmc2005 committed Nov 13, 2023
1 parent b2547d6 commit d2f154d
Show file tree
Hide file tree
Showing 8 changed files with 29 additions and 57 deletions.
12 changes: 8 additions & 4 deletions src/diart/blocks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ def from_name(name: Text) -> "HyperParameter":


class PipelineConfig(ABC):
"""Configuration containing the required parameters to build and run a pipeline"""
"""Configuration containing the required
parameters to build and run a pipeline"""

@property
@abstractmethod
Expand All @@ -66,7 +67,8 @@ def step(self) -> float:
@abstractmethod
def latency(self) -> float:
"""The algorithmic latency of the pipeline (in seconds).
At time `t` of the audio stream, the pipeline will output predictions for time `t - latency`.
At time `t` of the audio stream, the pipeline will
output predictions for time `t - latency`.
"""
pass

Expand Down Expand Up @@ -118,7 +120,8 @@ def set_timestamp_shift(self, shift: float):
def __call__(
self, waveforms: Sequence[SlidingWindowFeature]
) -> Sequence[Tuple[Any, SlidingWindowFeature]]:
"""Runs the next steps of the pipeline given a list of consecutive audio chunks.
"""Runs the next steps of the pipeline
given a list of consecutive audio chunks.
Parameters
----------
Expand All @@ -128,6 +131,7 @@ def __call__(
Returns
-------
Sequence[Tuple[Any, SlidingWindowFeature]]
For each input waveform, a tuple containing the pipeline output and its respective audio
For each input waveform, a tuple containing
the pipeline output and its respective audio
"""
pass
4 changes: 1 addition & 3 deletions src/diart/blocks/diarization.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,9 +193,7 @@ def __call__(
for wav, seg, emb in zip(waveforms, segmentations, embeddings):
# Add timestamps to segmentation
sw = SlidingWindow(
start=wav.extent.start,
duration=seg_resolution,
step=seg_resolution,
start=wav.extent.start, duration=seg_resolution, step=seg_resolution,
)
seg = SlidingWindowFeature(seg.cpu().numpy(), sw)

Expand Down
7 changes: 2 additions & 5 deletions src/diart/blocks/vad.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,7 @@ def set_timestamp_shift(self, shift: float):
self.timestamp_shift = shift

def __call__(
self,
waveforms: Sequence[SlidingWindowFeature],
self, waveforms: Sequence[SlidingWindowFeature],
) -> Sequence[tuple[Annotation, SlidingWindowFeature]]:
batch_size = len(waveforms)
msg = "Pipeline expected at least 1 input"
Expand All @@ -153,9 +152,7 @@ def __call__(
for wav, vad in zip(waveforms, voice_detection):
# Add timestamps to segmentation
sw = SlidingWindow(
start=wav.extent.start,
duration=seg_resolution,
step=seg_resolution,
start=wav.extent.start, duration=seg_resolution, step=seg_resolution,
)
vad = SlidingWindowFeature(vad.cpu().numpy(), sw)

Expand Down
25 changes: 6 additions & 19 deletions src/diart/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,7 @@ def __init__(
)

# Form batches
self.stream = self.stream.pipe(
ops.buffer_with_count(count=self.batch_size),
)
self.stream = self.stream.pipe(ops.buffer_with_count(count=self.batch_size),)

if self.do_profile:
self.stream = self.stream.pipe(
Expand Down Expand Up @@ -223,8 +221,7 @@ def __call__(self) -> Annotation:
ops.do(StreamingPlot(config.duration, config.latency)),
)
observable.subscribe(
on_error=self._handle_error,
on_completed=self._handle_completion,
on_error=self._handle_error, on_completed=self._handle_completion,
)
# FIXME if read() isn't blocking, the prediction returned is empty
self.source.read()
Expand Down Expand Up @@ -306,10 +303,7 @@ def get_file_paths(self) -> List[Path]:
return list(self.speech_path.iterdir())

def run_single(
self,
pipeline: blocks.Pipeline,
filepath: Path,
progress_bar: ProgressBar,
self, pipeline: blocks.Pipeline, filepath: Path, progress_bar: ProgressBar,
) -> Annotation:
"""Run a given pipeline on a given file.
Note that this method does NOT reset the
Expand All @@ -331,10 +325,7 @@ def run_single(
"""
padding = pipeline.config.get_file_padding(filepath)
source = src.FileAudioSource(
filepath,
pipeline.config.sample_rate,
padding,
pipeline.config.step,
filepath, pipeline.config.sample_rate, padding, pipeline.config.step,
)
pipeline.set_timestamp_shift(-padding[0])
inference = StreamingInference(
Expand All @@ -357,9 +348,7 @@ def run_single(
return pred

def evaluate(
self,
predictions: List[Annotation],
metric: BaseMetric,
self, predictions: List[Annotation], metric: BaseMetric,
) -> Union[pd.DataFrame, List[Annotation]]:
"""If a reference path was provided,
compute the diarization error rate of a list of predictions.
Expand Down Expand Up @@ -446,9 +435,7 @@ class Parallelize:
"""

def __init__(
self,
benchmark: Benchmark,
num_workers: int = 4,
self, benchmark: Benchmark, num_workers: int = 4,
):
self.benchmark = benchmark
self.num_workers = num_workers
Expand Down
18 changes: 7 additions & 11 deletions src/diart/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,14 +126,12 @@ def hard_map(

@staticmethod
def correlation(scores1: np.ndarray, scores2: np.ndarray) -> SpeakerMap:
score_matrix_per_frame = (
np.stack( # (local_speakers, num_frames, global_speakers)
[
scores1[:, speaker : speaker + 1] * scores2
for speaker in range(scores1.shape[1])
],
axis=0,
)
score_matrix_per_frame = np.stack( # (local_speakers, num_frames, global_speakers)
[
scores1[:, speaker : speaker + 1] * scores2
for speaker in range(scores1.shape[1])
],
axis=0,
)
# Calculate total speech "activations" per local speaker
local_speech_scores = np.sum(scores1, axis=0).reshape(-1, 1)
Expand Down Expand Up @@ -215,9 +213,7 @@ def _loose_check_valid(self, src: int, tgt: int) -> bool:
return self.is_source_speaker_mapped(src)

def valid_assignments(
self,
strict: bool = False,
as_array: bool = False,
self, strict: bool = False, as_array: bool = False,
) -> Union[Tuple[List[int], List[int]], Tuple[np.ndarray, np.ndarray]]:
source, target = [], []
val_type = "strict" if strict else "loose"
Expand Down
8 changes: 2 additions & 6 deletions src/diart/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,7 @@

try:
from pyannote.audio import Model
from pyannote.audio.pipelines.speaker_verification import (
PretrainedSpeakerEmbedding,
)
from pyannote.audio.pipelines.speaker_verification import PretrainedSpeakerEmbedding
from pyannote.audio.utils.powerset import Powerset

_has_pyannote = True
Expand Down Expand Up @@ -95,9 +93,7 @@ def recreate_session(self):
onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
)
self.session = onnxruntime.InferenceSession(
self.path,
sess_options=options,
providers=[self.execution_provider],
self.path, sess_options=options, providers=[self.execution_provider],
)

def to(self, device: torch.device) -> ONNXModel:
Expand Down
7 changes: 2 additions & 5 deletions src/diart/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,7 @@ def initial() -> "OutputAccumulationState":
@property
def cropped_waveform(self) -> SlidingWindowFeature:
return SlidingWindowFeature(
self.waveform[: self.next_sample],
self.waveform.sliding_window,
self.waveform[: self.next_sample], self.waveform.sliding_window,
)

def to_tuple(
Expand All @@ -145,9 +144,7 @@ def to_tuple(


def accumulate_output(
duration: float,
step: float,
patch_collar: float = 0.05,
duration: float, step: float, patch_collar: float = 0.05,
) -> Operator:
"""Accumulate predictions and audio to infinity: O(N) space complexity.
Uses a pre-allocated buffer that doubles its size once full: O(logN) concat operations.
Expand Down
5 changes: 1 addition & 4 deletions src/diart/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,10 +238,7 @@ def __init__(
self.server.set_fn_message_received(self._on_message_received)

def _on_message_received(
self,
client: Dict[Text, Any],
server: WebsocketServer,
message: AnyStr,
self, client: Dict[Text, Any], server: WebsocketServer, message: AnyStr,
):
# Only one client at a time is allowed
if self.client is None or self.client["id"] != client["id"]:
Expand Down

0 comments on commit d2f154d

Please sign in to comment.