Skip to content

Commit

Permalink
Apply suggestions from code review
Browse files Browse the repository at this point in the history
  • Loading branch information
juanmc2005 authored Nov 9, 2023
1 parent bd2313f commit a1aabf7
Showing 1 changed file with 15 additions and 10 deletions.
25 changes: 15 additions & 10 deletions src/diart/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@
from requests import HTTPError

try:
import pyannote.audio.pipelines.utils as pyannote_loader
from pyannote.audio import Inference, Model
from pyannote.audio.pipelines.speaker_verification import (
PretrainedSpeakerEmbedding,
)
from pyannote.audio.utils.powerset import Powerset

_has_pyannote = True
Expand Down Expand Up @@ -42,12 +45,17 @@ def __init__(self, model_info, hf_token: Union[Text, bool, None] = True):
self.model_info = model_info
self.hf_token = hf_token

def __call__(self) -> nn.Module:
model = pyannote_loader.get_model(self.model_info, self.hf_token)
specs = getattr(model, "specifications", None)
if specs is not None and specs.powerset:
model = PowersetAdapter(model)
return model
def __call__(self) -> Callable:
try:
model = Model.from_pretrained(self.model_info, use_auth_token=self.hf_token)
specs = getattr(model, "specifications", None)
if specs is not None and specs.powerset:
model = PowersetAdapter(model)
return model
except HTTPError:
return PretrainedSpeakerEmbedding(
self.model_info, use_auth_token=self.hf_token
)


class LazyModel(ABC):
Expand Down Expand Up @@ -145,9 +153,6 @@ def duration(self) -> float:
self.load()
return self.model.specifications.duration

def forward(self, waveform: torch.Tensor) -> torch.Tensor:
return self.model(waveform)


class EmbeddingModel(LazyModel):
"""Minimal interface for an embedding model."""
Expand Down

0 comments on commit a1aabf7

Please sign in to comment.