Skip to content

Commit

Permalink
Blacken code with good version
Browse files Browse the repository at this point in the history
  • Loading branch information
juanmc2005 committed Nov 13, 2023
1 parent d2f154d commit 8a6edc7
Show file tree
Hide file tree
Showing 8 changed files with 51 additions and 21 deletions.
2 changes: 1 addition & 1 deletion src/diart/blocks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class HyperParameter:
@staticmethod
def from_name(name: Text) -> "HyperParameter":
"""Create a HyperParameter object given its name.
Parameters
----------
name: str
Expand Down
4 changes: 3 additions & 1 deletion src/diart/blocks/diarization.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,9 @@ def __call__(
for wav, seg, emb in zip(waveforms, segmentations, embeddings):
# Add timestamps to segmentation
sw = SlidingWindow(
start=wav.extent.start, duration=seg_resolution, step=seg_resolution,
start=wav.extent.start,
duration=seg_resolution,
step=seg_resolution,
)
seg = SlidingWindowFeature(seg.cpu().numpy(), sw)

Expand Down
7 changes: 5 additions & 2 deletions src/diart/blocks/vad.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,8 @@ def set_timestamp_shift(self, shift: float):
self.timestamp_shift = shift

def __call__(
self, waveforms: Sequence[SlidingWindowFeature],
self,
waveforms: Sequence[SlidingWindowFeature],
) -> Sequence[tuple[Annotation, SlidingWindowFeature]]:
batch_size = len(waveforms)
msg = "Pipeline expected at least 1 input"
Expand All @@ -152,7 +153,9 @@ def __call__(
for wav, vad in zip(waveforms, voice_detection):
# Add timestamps to segmentation
sw = SlidingWindow(
start=wav.extent.start, duration=seg_resolution, step=seg_resolution,
start=wav.extent.start,
duration=seg_resolution,
step=seg_resolution,
)
vad = SlidingWindowFeature(vad.cpu().numpy(), sw)

Expand Down
25 changes: 19 additions & 6 deletions src/diart/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,9 @@ def __init__(
)

# Form batches
self.stream = self.stream.pipe(ops.buffer_with_count(count=self.batch_size),)
self.stream = self.stream.pipe(
ops.buffer_with_count(count=self.batch_size),
)

if self.do_profile:
self.stream = self.stream.pipe(
Expand Down Expand Up @@ -221,7 +223,8 @@ def __call__(self) -> Annotation:
ops.do(StreamingPlot(config.duration, config.latency)),
)
observable.subscribe(
on_error=self._handle_error, on_completed=self._handle_completion,
on_error=self._handle_error,
on_completed=self._handle_completion,
)
# FIXME if read() isn't blocking, the prediction returned is empty
self.source.read()
Expand Down Expand Up @@ -303,7 +306,10 @@ def get_file_paths(self) -> List[Path]:
return list(self.speech_path.iterdir())

def run_single(
self, pipeline: blocks.Pipeline, filepath: Path, progress_bar: ProgressBar,
self,
pipeline: blocks.Pipeline,
filepath: Path,
progress_bar: ProgressBar,
) -> Annotation:
"""Run a given pipeline on a given file.
Note that this method does NOT reset the
Expand All @@ -325,7 +331,10 @@ def run_single(
"""
padding = pipeline.config.get_file_padding(filepath)
source = src.FileAudioSource(
filepath, pipeline.config.sample_rate, padding, pipeline.config.step,
filepath,
pipeline.config.sample_rate,
padding,
pipeline.config.step,
)
pipeline.set_timestamp_shift(-padding[0])
inference = StreamingInference(
Expand All @@ -348,7 +357,9 @@ def run_single(
return pred

def evaluate(
self, predictions: List[Annotation], metric: BaseMetric,
self,
predictions: List[Annotation],
metric: BaseMetric,
) -> Union[pd.DataFrame, List[Annotation]]:
"""If a reference path was provided,
compute the diarization error rate of a list of predictions.
Expand Down Expand Up @@ -435,7 +446,9 @@ class Parallelize:
"""

def __init__(
self, benchmark: Benchmark, num_workers: int = 4,
self,
benchmark: Benchmark,
num_workers: int = 4,
):
self.benchmark = benchmark
self.num_workers = num_workers
Expand Down
18 changes: 11 additions & 7 deletions src/diart/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,12 +126,14 @@ def hard_map(

@staticmethod
def correlation(scores1: np.ndarray, scores2: np.ndarray) -> SpeakerMap:
score_matrix_per_frame = np.stack( # (local_speakers, num_frames, global_speakers)
[
scores1[:, speaker : speaker + 1] * scores2
for speaker in range(scores1.shape[1])
],
axis=0,
score_matrix_per_frame = (
np.stack( # (local_speakers, num_frames, global_speakers)
[
scores1[:, speaker : speaker + 1] * scores2
for speaker in range(scores1.shape[1])
],
axis=0,
)
)
# Calculate total speech "activations" per local speaker
local_speech_scores = np.sum(scores1, axis=0).reshape(-1, 1)
Expand Down Expand Up @@ -213,7 +215,9 @@ def _loose_check_valid(self, src: int, tgt: int) -> bool:
return self.is_source_speaker_mapped(src)

def valid_assignments(
self, strict: bool = False, as_array: bool = False,
self,
strict: bool = False,
as_array: bool = False,
) -> Union[Tuple[List[int], List[int]], Tuple[np.ndarray, np.ndarray]]:
source, target = [], []
val_type = "strict" if strict else "loose"
Expand Down
4 changes: 3 additions & 1 deletion src/diart/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,9 @@ def recreate_session(self):
onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
)
self.session = onnxruntime.InferenceSession(
self.path, sess_options=options, providers=[self.execution_provider],
self.path,
sess_options=options,
providers=[self.execution_provider],
)

def to(self, device: torch.device) -> ONNXModel:
Expand Down
7 changes: 5 additions & 2 deletions src/diart/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,8 @@ def initial() -> "OutputAccumulationState":
@property
def cropped_waveform(self) -> SlidingWindowFeature:
return SlidingWindowFeature(
self.waveform[: self.next_sample], self.waveform.sliding_window,
self.waveform[: self.next_sample],
self.waveform.sliding_window,
)

def to_tuple(
Expand All @@ -144,7 +145,9 @@ def to_tuple(


def accumulate_output(
duration: float, step: float, patch_collar: float = 0.05,
duration: float,
step: float,
patch_collar: float = 0.05,
) -> Operator:
"""Accumulate predictions and audio to infinity: O(N) space complexity.
Uses a pre-allocated buffer that doubles its size once full: O(logN) concat operations.
Expand Down
5 changes: 4 additions & 1 deletion src/diart/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,10 @@ def __init__(
self.server.set_fn_message_received(self._on_message_received)

def _on_message_received(
self, client: Dict[Text, Any], server: WebsocketServer, message: AnyStr,
self,
client: Dict[Text, Any],
server: WebsocketServer,
message: AnyStr,
):
# Only one client at a time is allowed
if self.client is None or self.client["id"] != client["id"]:
Expand Down

0 comments on commit 8a6edc7

Please sign in to comment.