From b286f5f791afc66622adcfd78f1e9fd4f6e9d8ee Mon Sep 17 00:00:00 2001 From: juanmc2005 Date: Thu, 9 Nov 2023 23:43:26 +0100 Subject: [PATCH] Clean up useless embedding model subclasses --- src/diart/models.py | 51 ++++++++++++++++++++++++++------------------- 1 file changed, 30 insertions(+), 21 deletions(-) diff --git a/src/diart/models.py b/src/diart/models.py index 52c65360..5c360721 100644 --- a/src/diart/models.py +++ b/src/diart/models.py @@ -12,7 +12,7 @@ from requests import HTTPError try: - from pyannote.audio import Inference, Model + from pyannote.audio import Model from pyannote.audio.pipelines.speaker_verification import ( PretrainedSpeakerEmbedding, ) @@ -148,11 +148,17 @@ def from_pyannote( return PyannoteSegmentationModel(model, use_hf_token) @staticmethod - def from_onnx(model_path: Union[str, Path]) -> "SegmentationModel": - return ONNXSegmentationModel(model_path) + def from_onnx( + model_path: Union[str, Path], + input_name: str = "waveform", + output_name: str = "segmentation", + ) -> "SegmentationModel": + return ONNXSegmentationModel(model_path, input_name, output_name) @staticmethod - def from_pretrained(model, use_hf_token: Union[Text, bool, None] = True) -> "SegmentationModel": + def from_pretrained( + model, use_hf_token: Union[Text, bool, None] = True + ) -> "SegmentationModel": if isinstance(model, str) or isinstance(model, Path): if Path(model).name.endswith(".onnx"): return SegmentationModel.from_onnx(model) @@ -197,9 +203,14 @@ def duration(self) -> float: class ONNXSegmentationModel(SegmentationModel): - def __init__(self, model_path: Union[str, Path]): + def __init__( + self, + model_path: Union[str, Path], + input_name: str = "waveform", + output_name: str = "segmentation", + ): model_path = Path(model_path) - loader = ONNXLoader(model_path, input_names=["waveform"], output_name="segmentation") + loader = ONNXLoader(model_path, [input_name], output_name) super().__init__(loader) with open(model_path.parent / 
f"{model_path.stem}.yml", "r") as metadata_file: metadata = yaml.load(metadata_file, yaml.SafeLoader) @@ -238,14 +249,23 @@ def from_pyannote( wrapper: EmbeddingModel """ assert _has_pyannote, "No pyannote.audio installation found" - return PyannoteEmbeddingModel(model, use_hf_token) + loader = PyannoteLoader(model, use_hf_token) + return EmbeddingModel(loader) @staticmethod - def from_onnx(model_path: Union[str, Path]) -> "EmbeddingModel": - return ONNXEmbeddingModel(model_path) + def from_onnx( + model_path: Union[str, Path], + input_names: Union[List[str], None] = None, + output_name: str = "embedding", + ) -> "EmbeddingModel": + input_names = input_names or ["waveform", "weights"] + loader = ONNXLoader(model_path, input_names, output_name) + return EmbeddingModel(loader) @staticmethod - def from_pretrained(model, use_hf_token: Union[Text, bool, None] = True) -> "EmbeddingModel": + def from_pretrained( + model, use_hf_token: Union[Text, bool, None] = True + ) -> "EmbeddingModel": if isinstance(model, str) or isinstance(model, Path): if Path(model).name.endswith(".onnx"): return EmbeddingModel.from_onnx(model) @@ -269,14 +289,3 @@ def __call__( if isinstance(embeddings, np.ndarray): embeddings = torch.from_numpy(embeddings) return embeddings - - -class PyannoteEmbeddingModel(EmbeddingModel): - def __init__(self, model_info, hf_token: Union[Text, bool, None] = True): - super().__init__(PyannoteLoader(model_info, hf_token)) - - -class ONNXEmbeddingModel(EmbeddingModel): - def __init__(self, model_path: Union[str, Path]): - loader = ONNXLoader(Path(model_path), input_names=["waveform", "weights"], output_name="embedding") - super().__init__(loader)