Skip to content

Commit

Permalink
added method for a save wavform as wav format
Browse files Browse the repository at this point in the history
  • Loading branch information
m15kh committed Nov 8, 2024
1 parent e9dae1a commit 2f4e8b5
Showing 1 changed file with 28 additions and 1 deletion.
29 changes: 28 additions & 1 deletion src/diart/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
from pyannote.metrics.base import BaseMetric
from rx.core import Observer
from tqdm import tqdm
import soundfile as sf


from . import blocks
from . import operators as dops
Expand Down Expand Up @@ -159,7 +161,7 @@ def _close_chronometer(self):
def attach_hooks(
self, *hooks: Callable[[Tuple[Annotation, SlidingWindowFeature]], None]
):
"""Attach hooks to the pipeline.
"""Attach hooks to the pipeline.
Parameters
----------
Expand All @@ -168,6 +170,31 @@ def attach_hooks(
"""
self.stream = self.stream.pipe(*[ops.do_action(hook) for hook in hooks])


def save_waveform_hook(self, save_path: str = "output"):
"""Create a hook function to save waveform data.
Parameters
----------
save_path: str
The directory path where waveforms will be saved.
"""
count = 0

def save_waveform(results):
nonlocal count
prediction, waveform = results

if prediction:
filename = f"{save_path}/waveform{count}.wav"
sf.write(filename, waveform.data, samplerate=16000)
print(f"Waveform saved to {filename}")
count += 1

return save_waveform



def attach_observers(self, *observers: Observer):
"""Attach rx observers to the pipeline.
Expand Down

0 comments on commit 2f4e8b5

Please sign in to comment.