From 4254401c5aac7ff028daa7d7b6fe548616e0ad7c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Herv=C3=A9=20BREDIN?= Date: Tue, 26 Nov 2024 13:50:41 +0100 Subject: [PATCH] feat: add support for models stored in pipeline subfolders (#1794) --- CHANGELOG.md | 15 ++++++++++ pyannote/audio/core/model.py | 14 ++++----- pyannote/audio/core/pipeline.py | 50 +++++++++++++++++++++++++++++---- 3 files changed, 65 insertions(+), 14 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b6b0a4233..74091f22b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,20 @@ ## develop +### TL;DR + +#### Quality of life improvements + +Models can now be stored alongside their pipelines in the same repository, streamlining the gating mechanism: +- accept `pyannote/speaker-diarization-x.x` pipeline user agreement +- ~~accept `pyannote/segmentation-3.0` model user agreement~~ +- ~~accept `pyannote/wespeaker-voxceleb-resnet34-LM` model user agreement~~ +- load pipeline with `Pipeline.from_pretrained("pyannote/speaker-diarization-3.1", use_auth_token=True)` + +#### Improve speech separation quality + +Clipping and speaker/source alignment issues in the speech separation pipeline have been fixed. 
+ ### Breaking changes - BREAKING(task): drop support for `multilabel` training in `SpeakerDiarization` task @@ -11,6 +25,7 @@ ### New features +- improve(hub): add support for pipeline repos that also include underlying models - feat(clustering): add support for `k-means` clustering - feat(model): add `wav2vec_frozen` option to freeze/unfreeze `wav2vec` in `SSeRiouSS` architecture - feat(task): add support for manual optimization in `SpeakerDiarization` task diff --git a/pyannote/audio/core/model.py b/pyannote/audio/core/model.py index 8af802293..514a38ad0 100644 --- a/pyannote/audio/core/model.py +++ b/pyannote/audio/core/model.py @@ -532,6 +532,7 @@ def from_pretrained( map_location=None, hparams_file: Union[Path, Text] = None, strict: bool = True, + subfolder: Optional[str] = None, use_auth_token: Union[Text, None] = None, cache_dir: Union[Path, Text] = CACHE_DIR, **kwargs, @@ -542,7 +543,7 @@ def from_pretrained( ---------- checkpoint : Path or str Path to checkpoint, or a remote URL, or a model identifier from - the huggingface.co model hub. + the hf.co model hub. map_location: optional Same role as in torch.load(). Defaults to `lambda storage, loc: storage`. @@ -559,8 +560,10 @@ def from_pretrained( strict : bool, optional Whether to strictly enforce that the keys in checkpoint match the keys returned by this module’s state dict. Defaults to True. + subfolder : str, optional + Folder inside the hf.co model repo. 
use_auth_token : str, optional - When loading a private huggingface.co model, set `use_auth_token` + When loading a private hf.co model, set `use_auth_token` to True or to a string containing your hugginface.co authentication token that can be obtained by running `huggingface-cli login` cache_dir: Path or str, optional @@ -606,18 +609,13 @@ def from_pretrained( path_for_pl = hf_hub_download( model_id, HF_PYTORCH_WEIGHTS_NAME, + subfolder=subfolder, repo_type="model", revision=revision, library_name="pyannote", library_version=__version__, cache_dir=cache_dir, - # force_download=False, - # proxies=None, - # etag_timeout=10, - # resume_download=False, use_auth_token=use_auth_token, - # local_files_only=False, - # legacy_cache_layout=False, ) except RepositoryNotFoundError: print( diff --git a/pyannote/audio/core/pipeline.py b/pyannote/audio/core/pipeline.py index 8842ab1d4..7b005610b 100644 --- a/pyannote/audio/core/pipeline.py +++ b/pyannote/audio/core/pipeline.py @@ -47,6 +47,49 @@ PIPELINE_PARAMS_NAME = "config.yaml" +def expand_subfolders( + config, hf_model_id, use_auth_token: Optional[Text] = None +) -> None: + """Expand $model subfolders in config + + Processes `config` dictionary recursively and replaces "$model/{subfolder}" + values with {"checkpoint": hf_model_id, + "subfolder": {subfolder}, + "use_auth_token": use_auth_token} + + Parameters + ---------- + config : dict + hf_model_id : str + Parent Huggingface model identifier + use_auth_token : str, optional + Token used for loading from the root folder. 
+ """ + if isinstance(config, dict): + for key, value in config.items(): + if isinstance(value, str) and value.startswith("$model/"): + subfolder = "/".join(value.split("/")[1:]) + config[key] = { + "checkpoint": hf_model_id, + "subfolder": subfolder, + "use_auth_token": use_auth_token, + } + else: + expand_subfolders(value, hf_model_id, use_auth_token=use_auth_token) + + elif isinstance(config, list): + for idx, value in enumerate(config): + if isinstance(value, str) and value.startswith("$model/"): + subfolder = "/".join(value.split("/")[1:]) + config[idx] = { + "checkpoint": hf_model_id, + "subfolder": subfolder, + "use_auth_token": use_auth_token, + } + else: + expand_subfolders(value, hf_model_id, use_auth_token=use_auth_token) + + class Pipeline(_Pipeline): @classmethod def from_pretrained( @@ -95,13 +138,7 @@ def from_pretrained( library_name="pyannote", library_version=__version__, cache_dir=cache_dir, - # force_download=False, - # proxies=None, - # etag_timeout=10, - # resume_download=False, use_auth_token=use_auth_token, - # local_files_only=False, - # legacy_cache_layout=False, ) except RepositoryNotFoundError: @@ -122,6 +159,7 @@ def from_pretrained( with open(config_yml, "r") as fp: config = yaml.load(fp, Loader=yaml.SafeLoader) + expand_subfolders(config, model_id, use_auth_token=use_auth_token) if "version" in config: check_version(