Commit
Fix import errors and make errors more informative.
Miles B Silva committed Nov 22, 2024
1 parent d460b6c commit 3d0a27c
Showing 19 changed files with 130 additions and 35 deletions.
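Every file in this commit applies the same refactor: a model object instantiated directly in a function signature is replaced by `Optional[...] = None` plus a lazy in-body default. Python evaluates default arguments once, at import time, so a failing model constructor (a missing optional dependency, an unreachable Hugging Face Hub, an invalid token) used to break the import of the whole module — the "import errors" in the commit title. A minimal sketch of the pattern, using a hypothetical `Model` class rather than any real senselab type:

```python
from typing import List, Optional


class Model:
    """Stand-in for a senselab model wrapper (hypothetical, for illustration only)."""

    def __init__(self, path_or_uri: str, revision: str = "main") -> None:
        self.path_or_uri = path_or_uri
        self.revision = revision


# Before: Model(...) runs when this module is imported, so any error in
# the constructor breaks the import itself.
def process_eager(data: List[str], model: Model = Model("org/checkpoint")) -> List[str]:
    return data


# After: the default is resolved lazily, on the first call that needs it.
def process_lazy(data: List[str], model: Optional[Model] = None) -> List[str]:
    if model is None:
        model = Model("org/checkpoint")
    return data
```

Narrowing the accepted parameter type (e.g. `Optional[PyannoteAudioModel]` instead of `SenselabModel`) also lets type checkers flag unsupported models before the runtime `isinstance` checks raise.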
7 changes: 5 additions & 2 deletions src/senselab/audio/tasks/speaker_diarization/api.py
@@ -4,12 +4,12 @@

from senselab.audio.data_structures import Audio
from senselab.audio.tasks.speaker_diarization.pyannote import diarize_audios_with_pyannote
- from senselab.utils.data_structures import DeviceType, PyannoteAudioModel, ScriptLine, SenselabModel
+ from senselab.utils.data_structures import DeviceType, PyannoteAudioModel, ScriptLine


def diarize_audios(
audios: List[Audio],
- model: SenselabModel = PyannoteAudioModel(path_or_uri="pyannote/speaker-diarization-3.1", revision="main"),
+ model: Optional[PyannoteAudioModel] = None,
num_speakers: Optional[int] = None,
min_speakers: Optional[int] = None,
max_speakers: Optional[int] = None,
@@ -29,6 +29,9 @@ def diarize_audios(
Returns:
List[List[ScriptLine]]: The list of script lines with speaker labels.
"""
+ if model is None:
+     model = PyannoteAudioModel(path_or_uri="pyannote/speaker-diarization-3.1", revision="main")
+
if isinstance(model, PyannoteAudioModel):
return diarize_audios_with_pyannote(
audios=audios,
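A minimal usage sketch of the updated entry point; the `Audio.from_filepath` loader is an assumption for illustration and not part of this diff:

```python
from senselab.audio.data_structures import Audio
from senselab.audio.tasks.speaker_diarization.api import diarize_audios
from senselab.utils.data_structures import PyannoteAudioModel

audio = Audio.from_filepath("meeting.wav")  # loading helper assumed

# No model argument: pyannote/speaker-diarization-3.1 is now built on
# first call instead of at import time.
script_lines = diarize_audios(audios=[audio])

# Passing a model explicitly behaves exactly as before.
model = PyannoteAudioModel(path_or_uri="pyannote/speaker-diarization-3.1", revision="main")
script_lines = diarize_audios(audios=[audio], model=model, num_speakers=2)
```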
4 changes: 3 additions & 1 deletion src/senselab/audio/tasks/speaker_diarization/pyannote.py
@@ -46,7 +46,7 @@ def _get_pyannote_diarization_pipeline(

def diarize_audios_with_pyannote(
audios: List[Audio],
- model: PyannoteAudioModel = PyannoteAudioModel(path_or_uri="pyannote/speaker-diarization-3.1", revision="main"),
+ model: Optional[PyannoteAudioModel] = None,
device: Optional[DeviceType] = None,
num_speakers: Optional[int] = None,
min_speakers: Optional[int] = None,
@@ -65,6 +65,8 @@ def diarize_audios_with_pyannote(
Returns:
List[ScriptLine]: A list of ScriptLine objects containing the diarization results.
"""
+ if model is None:
+     model = PyannoteAudioModel(path_or_uri="pyannote/speaker-diarization-3.1", revision="main")

def _annotation_to_script_lines(annotation: Annotation) -> List[ScriptLine]:
"""Convert a Pyannote annotation to a list of script lines.
7 changes: 5 additions & 2 deletions src/senselab/audio/tasks/speaker_embeddings/api.py
@@ -6,12 +6,12 @@

from senselab.audio.data_structures import Audio
from senselab.audio.tasks.speaker_embeddings.speechbrain import SpeechBrainEmbeddings
- from senselab.utils.data_structures import DeviceType, SenselabModel, SpeechBrainModel
+ from senselab.utils.data_structures import DeviceType, SpeechBrainModel


def extract_speaker_embeddings_from_audios(
audios: List[Audio],
- model: SenselabModel = SpeechBrainModel(path_or_uri="speechbrain/spkrec-ecapa-voxceleb", revision="main"),
+ model: Optional[SpeechBrainModel] = None,
device: Optional[DeviceType] = None,
) -> List[torch.Tensor]:
"""Compute the speaker embedding of audio signals.
@@ -35,6 +35,9 @@ def extract_speaker_embeddings_from_audios(
>>> print(embeddings[0].shape)
torch.Size([192])
"""
+ if model is None:
+     model = SpeechBrainModel(path_or_uri="speechbrain/spkrec-ecapa-voxceleb", revision="main")
+
if isinstance(model, SpeechBrainModel):
return SpeechBrainEmbeddings.extract_speechbrain_speaker_embeddings_from_audios(
audios=audios, model=model, device=device
5 changes: 4 additions & 1 deletion src/senselab/audio/tasks/speaker_embeddings/speechbrain.py
@@ -48,7 +48,7 @@ def _get_speechbrain_model(
def extract_speechbrain_speaker_embeddings_from_audios(
cls,
audios: List[Audio],
- model: SpeechBrainModel = SpeechBrainModel(path_or_uri="speechbrain/spkrec-ecapa-voxceleb", revision="main"),
+ model: Optional[SpeechBrainModel] = None,
device: Optional[DeviceType] = None,
) -> List[torch.Tensor]:
"""Compute the speaker embeddings of audio signals.
@@ -67,6 +67,9 @@ def extract_speechbrain_speaker_embeddings_from_audios(
- Optimizing the computation by working in batches
- Double-checking the input size of classifier.encode_batch
"""
+ if model is None:
+     model = SpeechBrainModel(path_or_uri="speechbrain/spkrec-ecapa-voxceleb", revision="main")
+
classifier = cls._get_speechbrain_model(model=model, device=device)
# 16khz comes from the model cards of ecapa-tdnn, resnet, and xvector
expected_sample_rate = 16000
@@ -18,7 +18,7 @@

def verify_speaker(
audios: List[Tuple[Audio, Audio]],
- model: SpeechBrainModel = SpeechBrainModel(path_or_uri="speechbrain/spkrec-ecapa-voxceleb", revision="main"),
+ model: Optional[SpeechBrainModel] = None,
device: Optional[DeviceType] = None,
threshold: float = 0.25,
) -> List[Tuple[float, bool]]:
@@ -38,6 +38,8 @@ def verify_speaker(
between the two samples, and the prediction is a boolean
indicating if the two samples are from the same speaker.
"""
+ if model is None:
+     model = SpeechBrainModel(path_or_uri="speechbrain/spkrec-ecapa-voxceleb", revision="main")
device = _select_device_and_dtype(compatible_devices=[DeviceType.CPU, DeviceType.CUDA])[0]

scores_and_predictions = []
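A usage sketch for the updated `verify_speaker`; the import path and the audio loader are assumptions, only the signature change above comes from this diff:

```python
from senselab.audio.data_structures import Audio
from senselab.audio.tasks.speaker_verification import verify_speaker  # exact module path assumed

enrollment = Audio.from_filepath("speaker_a_take1.wav")  # loading helper assumed
probe = Audio.from_filepath("speaker_a_take2.wav")

# Each (score, prediction) pair gives the cosine similarity between the
# two samples and whether it clears the default 0.25 threshold.
for score, same_speaker in verify_speaker(audios=[(enrollment, probe)]):
    print(f"score={score:.3f}, same speaker: {same_speaker}")
```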
7 changes: 5 additions & 2 deletions src/senselab/audio/tasks/speech_enhancement/api.py
@@ -4,12 +4,12 @@

from senselab.audio.data_structures import Audio
from senselab.audio.tasks.speech_enhancement.speechbrain import SpeechBrainEnhancer
- from senselab.utils.data_structures import DeviceType, SenselabModel, SpeechBrainModel
+ from senselab.utils.data_structures import DeviceType, SpeechBrainModel


def enhance_audios(
audios: List[Audio],
- model: SenselabModel = SpeechBrainModel(path_or_uri="speechbrain/sepformer-wham16k-enhancement", revision="main"),
+ model: Optional[SpeechBrainModel] = None,
device: Optional[DeviceType] = None,
) -> List[Audio]:
"""Enhances all audios using the given model.
@@ -23,6 +23,9 @@ def enhance_audios(
Returns:
List[Audio]: The list of enhanced audio objects.
"""
+ if model is None:
+     model = SpeechBrainModel(path_or_uri="speechbrain/sepformer-wham16k-enhancement", revision="main")
+
if isinstance(model, SpeechBrainModel):
return SpeechBrainEnhancer.enhance_audios_with_speechbrain(audios=audios, model=model, device=device)
else:
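Usage sketch (audio loader assumed):

```python
from senselab.audio.data_structures import Audio
from senselab.audio.tasks.speech_enhancement.api import enhance_audios
from senselab.utils.data_structures import DeviceType

noisy = Audio.from_filepath("noisy_recording.wav")  # loading helper assumed

# The sepformer-wham16k-enhancement default is now created lazily, so
# importing this module no longer requires the model to resolve.
enhanced = enhance_audios(audios=[noisy], device=DeviceType.CPU)
```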
7 changes: 4 additions & 3 deletions src/senselab/audio/tasks/speech_enhancement/speechbrain.py
@@ -46,9 +46,7 @@ def _get_speechbrain_model(
def enhance_audios_with_speechbrain(
cls,
audios: List[Audio],
- model: SpeechBrainModel = SpeechBrainModel(
-     path_or_uri="speechbrain/sepformer-wham16k-enhancement", revision="main"
- ),
+ model: Optional[SpeechBrainModel] = None,
device: Optional[DeviceType] = None,
batch_size: int = 16,
) -> List[Audio]:
@@ -63,6 +61,9 @@ def enhance_audios_with_speechbrain(
Returns:
List[Audio]: The list of enhanced audio objects.
"""
+ if model is None:
+     model = SpeechBrainModel(path_or_uri="speechbrain/sepformer-wham16k-enhancement", revision="main")
+
# Take the start time of the model initialization
start_time_model = time.time()
enhancer, device, _ = cls._get_speechbrain_model(model=model, device=device)
6 changes: 4 additions & 2 deletions src/senselab/audio/tasks/speech_to_text/huggingface.py
@@ -68,7 +68,7 @@ def _get_hf_asr_pipeline(
def transcribe_audios_with_transformers(
cls,
audios: List[Audio],
- model: HFModel = HFModel(path_or_uri="openai/whisper-tiny"),
+ model: Optional[HFModel] = None,
language: Optional[Language] = None,
return_timestamps: Optional[str] = "word",
max_new_tokens: int = 128,
@@ -80,7 +80,7 @@ def transcribe_audios_with_transformers(
Args:
audios (List[Audio]): The list of audio objects to be transcribed.
- model (HFModel): The Hugging Face model used for transcription.
+ model (HFModel): The Hugging Face model used for transcription (default is `openai/whisper-tiny`).
language (Optional[Language]): The language of the audio (default is None).
return_timestamps (Optional[str]): The level of timestamp details (default is "word").
max_new_tokens (int): The maximum number of new tokens (default is 128).
@@ -91,6 +91,8 @@
Returns:
List[ScriptLine]: The list of script lines.
"""
+ if model is None:
+     model = HFModel(path_or_uri="openai/whisper-tiny")

def _audio_to_huggingface_dict(audio: Audio) -> Dict:
"""Convert an Audio object to a dictionary that can be used by the transformers pipeline.
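A sketch of calling the updated classmethod; the enclosing class name (`HuggingFaceASR`) and the audio loader are assumptions based on the structure shown above:

```python
from senselab.audio.data_structures import Audio
from senselab.audio.tasks.speech_to_text.huggingface import HuggingFaceASR  # class name assumed
from senselab.utils.data_structures import HFModel

audio = Audio.from_filepath("speech.wav")  # loading helper assumed

# Defaults to openai/whisper-tiny, resolved inside the call.
lines = HuggingFaceASR.transcribe_audios_with_transformers(audios=[audio])

# Pinning a different checkpoint still works as before.
lines = HuggingFaceASR.transcribe_audios_with_transformers(
    audios=[audio], model=HFModel(path_or_uri="openai/whisper-small")
)
```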
5 changes: 4 additions & 1 deletion src/senselab/audio/tasks/text_to_speech/api.py
@@ -10,7 +10,7 @@

def synthesize_texts(
texts: List[str],
- model: SenselabModel = HFModel(path_or_uri="suno/bark", revision="main"),
+ model: Optional[SenselabModel] = None,
language: Optional[Language] = None,
device: Optional[DeviceType] = None,
targets: Optional[List[Audio | Tuple[Audio, str]]] = None,
@@ -43,6 +43,9 @@ def synthesize_texts(
Returns:
List[Audio]: The list of synthesized audio objects.
"""
+ if model is None:
+     model = HFModel(path_or_uri="suno/bark", revision="main")
+
if targets is not None:
assert len(targets) == len(texts), ValueError("Provided targets should be same length as texts")

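Usage sketch for the updated `synthesize_texts`:

```python
from senselab.audio.tasks.text_to_speech.api import synthesize_texts
from senselab.utils.data_structures import HFModel

# suno/bark is instantiated lazily when no model is passed.
audios = synthesize_texts(texts=["Hello from senselab."])

# Equivalent explicit call.
audios = synthesize_texts(
    texts=["Hello from senselab."],
    model=HFModel(path_or_uri="suno/bark", revision="main"),
)
```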
6 changes: 4 additions & 2 deletions src/senselab/audio/tasks/text_to_speech/huggingface.py
@@ -46,7 +46,7 @@ def _get_hf_tts_pipeline(
def synthesize_texts_with_transformers(
cls,
texts: List[str],
- model: HFModel = HFModel(path_or_uri="suno/bark", revision="main"),
+ model: Optional[HFModel] = None,
device: Optional[DeviceType] = None,
forward_params: Optional[Dict[str, Any]] = None,
) -> List[Audio]:
@@ -57,7 +57,7 @@ def synthesize_texts_with_transformers(
Args:
texts (List[str]): The list of text strings to be synthesized.
- model (HFModel): The Hugging Face model used for synthesis.
+ model (HFModel): The Hugging Face model used for synthesis (default is `suno/bark`).
device (Optional[DeviceType]): The device to run the model on (default is None).
forward_params (Optional[Dict[str, Any]]): Additional parameters to pass to the forward function.
@@ -68,6 +68,8 @@
- Add speaker embeddings as they do in here:
https://huggingface.co/docs/transformers/tasks/text-to-speech
"""
+ if model is None:
+     model = HFModel(path_or_uri="suno/bark", revision="main")
pipe = HuggingFaceTTS._get_hf_tts_pipeline(model=model, device=device)

synthesized_audios = pipe(texts, forward_params=forward_params)
12 changes: 9 additions & 3 deletions src/senselab/audio/tasks/text_to_speech/marstts.py
@@ -18,14 +18,14 @@ class Mars5TTS:
@classmethod
def _get_torch_tts_model(
cls,
- model: TorchModel = TorchModel(path_or_uri="Camb-ai/mars5-tts", revision="master"),
+ model: Optional[TorchModel] = None,
language: Optional[Language] = Language(language_code="en"),
device: Optional[DeviceType] = None,
) -> Tuple[torch.nn.Module, type]:
"""Get or create a Torch-based Mars5TTS model.
Args:
- model (TorchModel): The Torch model (default is "Camb-ai/mars5-tts").
+ model (TorchModel): The Torch model (currently only supports "Camb-ai/mars5-tts").
language (Optional[Language]): The language of the text (default is Language(language_code="en")).
The only supported language is "en" for now.
device (DeviceType): The device to run the model on (default is None). Supported devices are CPU and CUDA.
@@ -34,6 +34,9 @@
model: The Torch-based Mars5TTS model.
config_class: The configuration class used by the model.
"""
+ if model is None:
+     model = TorchModel(path_or_uri="Camb-ai/mars5-tts", revision="master")
+
if model.path_or_uri != "Camb-ai/mars5-tts" or model.revision != "master":
raise NotImplementedError("Only the 'Camb-ai/mars5-tts' model is supported for now.")
if language == Language(language_code="en"):
@@ -57,7 +60,7 @@ def synthesize_texts_with_mars5tts(
cls,
texts: List[str],
targets: List[Tuple[Audio, str]],
- model: TorchModel = TorchModel(path_or_uri="Camb-ai/mars5-tts", revision="master"),
+ model: Optional[TorchModel] = None,
language: Optional[Language] = None,
device: Optional[DeviceType] = None,
deep_clone: bool = True,
@@ -95,6 +98,9 @@ def synthesize_texts_with_mars5tts(
The original repo of the model is: https://github.com/Camb-ai/MARS5-TTS.
"""
+ if model is None:
+     model = TorchModel(path_or_uri="Camb-ai/mars5-tts", revision="master")
+
# Take the start time of the model initialization
start_time_model = time.time()
my_model, config_class = cls._get_torch_tts_model(model, language, device)
8 changes: 6 additions & 2 deletions src/senselab/audio/tasks/text_to_speech/styletts2.py
@@ -32,7 +32,7 @@ class StyleTTS2:
@classmethod
def _get_style_tts_2_model(
cls,
- model: TorchModel = TorchModel(path_or_uri="wilke0818/StyleTTS2-TorchHub", revision="main"),
+ model: Optional[TorchModel] = None,
language: Optional[Language] = None,
device: Optional[DeviceType] = None,
pretrain_data: Optional[Literal["LibriTTS", "LJSpeech"]] = "LibriTTS",
@@ -54,6 +54,8 @@
Returns:
model: The Torch-based StyleTTS2 model.
"""
+ if model is None:
+     model = TorchModel(path_or_uri="wilke0818/StyleTTS2-TorchHub", revision="main")
if language == Language(language_code="en"):
model_name: str = "styletts2" # This is the default model they have for English.
else:
@@ -80,7 +82,7 @@ def synthesize_texts_with_style_tts_2(
texts: List[str],
target_audios: List[Audio],
target_transcripts: List[Optional[str]],
- model: TorchModel = TorchModel(path_or_uri="wilke0818/StyleTTS2-TorchHub", revision="main"),
+ model: Optional[TorchModel] = None,
language: Optional[Language] = None,
device: Optional[DeviceType] = None,
pretrain_data: Optional[Literal["LibriTTS", "LJSpeech"]] = "LibriTTS",
@@ -132,6 +134,8 @@ def synthesize_texts_with_style_tts_2(
The original repo of the model is: https://github.com/yl4579/StyleTTS2.
"""
+ if model is None:
+     model = TorchModel(path_or_uri="wilke0818/StyleTTS2-TorchHub", revision="main")
nltk.download("punkt")
nltk.download("punkt_tab")
# Take the start time of the model initialization
14 changes: 10 additions & 4 deletions src/senselab/audio/tasks/voice_activity_detection/api.py
@@ -4,29 +4,35 @@

from senselab.audio.data_structures import Audio
from senselab.audio.tasks.speaker_diarization.pyannote import diarize_audios_with_pyannote
- from senselab.utils.data_structures import DeviceType, PyannoteAudioModel, ScriptLine, SenselabModel
+ from senselab.utils.data_structures import DeviceType, PyannoteAudioModel, ScriptLine


def detect_human_voice_activity_in_audios(
audios: List[Audio],
- model: SenselabModel = PyannoteAudioModel(path_or_uri="pyannote/speaker-diarization-3.1", revision="main"),
+ model: Optional[PyannoteAudioModel] = None,
device: Optional[DeviceType] = None,
) -> List[List[ScriptLine]]:
"""Diarizes all audios using the given model.
Args:
audios (List[Audio]): The list of audio objects to be processed.
- model (SenselabModel): The model used for voice activity detection.
+ model (Optional[PyannoteAudioModel]): The model used for voice activity detection
+     (default is `pyannote/speaker-diarization-3.1`).
device (Optional[DeviceType]): The device to run the model on (default is None).
Returns:
List[List[ScriptLine]]: The list of script lines with voice label.
"""
+ if model is None:
+     model = PyannoteAudioModel(path_or_uri="pyannote/speaker-diarization-3.1", revision="main")
+
if isinstance(model, PyannoteAudioModel):
results = diarize_audios_with_pyannote(audios=audios, model=model, device=device)
for sample in results:
for chunk in sample:
chunk.speaker = "VOICE"
return results
else:
raise NotImplementedError("Only Pyannote models are supported for now.")
raise NotImplementedError(
"Only Pyannote models are supported for now. We aim to support more models in the future."
)
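Usage sketch (audio loader assumed); note that diarization output is relabeled with the single `"VOICE"` label:

```python
from senselab.audio.data_structures import Audio
from senselab.audio.tasks.voice_activity_detection.api import detect_human_voice_activity_in_audios

audio = Audio.from_filepath("conversation.wav")  # loading helper assumed

results = detect_human_voice_activity_in_audios(audios=[audio])
for line in results[0]:
    print(line.speaker)  # always "VOICE" per the relabeling above
```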
8 changes: 5 additions & 3 deletions src/senselab/audio/tasks/voice_cloning/api.py
@@ -4,13 +4,13 @@

from senselab.audio.data_structures import Audio
from senselab.audio.tasks.voice_cloning.knnvc import KNNVC
- from senselab.utils.data_structures import DeviceType, SenselabModel, TorchModel
+ from senselab.utils.data_structures import DeviceType, TorchModel


def clone_voices(
source_audios: List[Audio],
target_audios: List[Audio],
- model: SenselabModel = TorchModel(path_or_uri="bshall/knn-vc", revision="master"),
+ model: Optional[TorchModel] = None,
device: Optional[DeviceType] = None,
**kwargs: Any, # noqa:ANN401
) -> List[Audio]:
@@ -28,7 +28,7 @@ def clone_voices(
(e.g., words) will remain the same, but the voice sounds like the target.
target_audios (List[Audio]): A list of audio samples whose voices will be extracted
and used to replace the voices in the corresponding source audio samples.
- model (SenselabModel, optional): The model to use for voice cloning. Currently,
+ model (TorchModel, optional): The model to use for voice cloning. Currently,
only KNNVC (K-Nearest Neighbors Voice Conversion) is supported, encapsulated
by the `TorchModel` class. `TorchModel` is a child class of `SenselabModel`
and specifies the model and revision for cloning. Defaults to
@@ -53,6 +53,8 @@ def clone_voices(
Todo:
Add logging with timestamps.
"""
+ if model is None:
+     model = TorchModel(path_or_uri="bshall/knn-vc", revision="master")
if len(source_audios) != len(target_audios):
raise ValueError("The list of source and target audios must have the same length.")

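Usage sketch (audio loader assumed); source and target lists must have the same length and are paired element-wise:

```python
from senselab.audio.data_structures import Audio
from senselab.audio.tasks.voice_cloning.api import clone_voices

sources = [Audio.from_filepath("what_was_said.wav")]       # loading helper assumed
targets = [Audio.from_filepath("how_it_should_sound.wav")]

# The bshall/knn-vc default is built lazily inside the call.
cloned = clone_voices(source_audios=sources, target_audios=targets)
```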
4 changes: 3 additions & 1 deletion src/senselab/audio/tasks/voice_cloning/knnvc.py
@@ -53,7 +53,7 @@ def clone_voices_with_knn_vc(
cls,
source_audios: List[Audio],
target_audios: List[Audio],
- model: TorchModel = TorchModel(path_or_uri="bshall/knn-vc", revision="master"),
+ model: Optional[TorchModel] = None,
prematched_vocoder: bool = True,
topk: int = 4,
device: Optional[DeviceType] = None,
@@ -75,6 +75,8 @@
Raises:
ValueError: If the audio files are not mono or if the sampling rates are not supported.
"""
+ if model is None:
+     model = TorchModel(path_or_uri="bshall/knn-vc", revision="master")
if not isinstance(prematched_vocoder, bool):
raise TypeError("prematched_vocoder must be a boolean.")
