Skip to content

Commit

Permalink
Clean up useless embedding model subclasses
Browse files Browse the repository at this point in the history
  • Loading branch information
juanmc2005 committed Nov 10, 2023
1 parent 143bcc1 commit 472c6cc
Showing 1 changed file with 29 additions and 20 deletions.
49 changes: 29 additions & 20 deletions src/diart/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,11 +174,17 @@ def from_pyannote(
return PyannoteSegmentationModel(model, use_hf_token)

@staticmethod
def from_onnx(model_path: Union[str, Path]) -> "SegmentationModel":
return ONNXSegmentationModel(model_path)
def from_onnx(
model_path: Union[str, Path],
input_name: str = "waveform",
output_name: str = "segmentation",
) -> "SegmentationModel":
return ONNXSegmentationModel(model_path, input_name, output_name)

@staticmethod
def from_pretrained(model, use_hf_token: Union[Text, bool, None] = True) -> "SegmentationModel":
def from_pretrained(
model, use_hf_token: Union[Text, bool, None] = True
) -> "SegmentationModel":
if isinstance(model, str) or isinstance(model, Path):
if Path(model).name.endswith(".onnx"):
return SegmentationModel.from_onnx(model)
Expand Down Expand Up @@ -223,9 +229,14 @@ def duration(self) -> float:


class ONNXSegmentationModel(SegmentationModel):
def __init__(self, model_path: Union[str, Path]):
def __init__(
self,
model_path: Union[str, Path],
input_name: str = "waveform",
output_name: str = "segmentation",
):
model_path = Path(model_path)
loader = ONNXLoader(model_path, input_names=["waveform"], output_name="segmentation")
loader = ONNXLoader(model_path, [input_name], output_name)
super().__init__(loader)
with open(model_path.parent / f"{model_path.stem}.yml", "r") as metadata_file:
metadata = yaml.load(metadata_file, yaml.SafeLoader)
Expand Down Expand Up @@ -264,14 +275,23 @@ def from_pyannote(
wrapper: EmbeddingModel
"""
assert _has_pyannote, "No pyannote.audio installation found"
return PyannoteEmbeddingModel(model, use_hf_token)
loader = PyannoteLoader(model, use_hf_token)
return EmbeddingModel(loader)

@staticmethod
def from_onnx(model_path: Union[str, Path]) -> "EmbeddingModel":
return ONNXEmbeddingModel(model_path)
def from_onnx(
model_path: Union[str, Path],
input_names: List[str] | None = None,
output_name: str = "embedding",
) -> "EmbeddingModel":
input_names = input_names or ["waveform", "weights"]
loader = ONNXLoader(model_path, input_names, output_name)
return EmbeddingModel(loader)

@staticmethod
def from_pretrained(model, use_hf_token: Union[Text, bool, None] = True) -> "EmbeddingModel":
def from_pretrained(
model, use_hf_token: Union[Text, bool, None] = True
) -> "EmbeddingModel":
if isinstance(model, str) or isinstance(model, Path):
if Path(model).name.endswith(".onnx"):
return EmbeddingModel.from_onnx(model)
Expand All @@ -295,14 +315,3 @@ def __call__(
if isinstance(embeddings, np.ndarray):
embeddings = torch.from_numpy(embeddings)
return embeddings


class PyannoteEmbeddingModel(EmbeddingModel):
def __init__(self, model_info, hf_token: Union[Text, bool, None] = True):
super().__init__(PyannoteLoader(model_info, hf_token))


class ONNXEmbeddingModel(EmbeddingModel):
def __init__(self, model_path: Union[str, Path]):
loader = ONNXLoader(Path(model_path), input_names=["waveform", "weights"], output_name="embedding")
super().__init__(loader)

0 comments on commit 472c6cc

Please sign in to comment.