How to fine-tune segmentation model with pyannote 2.0 #736
-
Hi, first of all thanks for this great repository! I'm trying the pyannote develop branch but running into trouble with this fine-tuning recipe:

```python
from pyannote.audio import Model

model = Model.from_pretrained('pyannote/segmentation')
model.task = Segmentation(ami)
model.freeze_up_to('sincnet')
trainer.fit(model)
```

My code is the following:

```python
In [1]: from pyannote.database import get_protocol
   ...: from pyannote.audio import Model
   ...: from pyannote.audio.tasks import Segmentation
   ...: from pytorch_lightning import Trainer

In [2]: protocol_vad = get_protocol("protocol.SpeakerDiarization.vad")

In [3]: model = Model.from_pretrained("pyannote/segmentation")
   ...: model.task = Segmentation(protocol=protocol_vad)
   ...: model.freeze_up_to("sincnet")
   ...: trainer = Trainer()
   ...: trainer.fit(model)
```

And I got this error.
Whereas this one worked:

```python
In [1]: from pyannote.database import get_protocol
   ...: from pyannote.audio import Model
   ...: from pyannote.audio.tasks import VoiceActivityDetection
   ...: from pytorch_lightning import Trainer

In [2]: protocol_vad = get_protocol("protocol.SpeakerDiarization.vad")

In [3]: model = Model.from_pretrained("pyannote/segmentation")
   ...: model.task = VoiceActivityDetection(protocol_vad)
   ...: model.freeze_up_to("sincnet")
   ...: trainer = Trainer()
   ...: trainer.fit(model, model.task)
```
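For what it's worth, an untested variant I have not ruled out: passing the task explicitly to `trainer.fit()` for `Segmentation` as well, mirroring the `VoiceActivityDetection` call above that does work (a guess, not a confirmed fix):

```python
from pyannote.database import get_protocol
from pyannote.audio import Model
from pyannote.audio.tasks import Segmentation
from pytorch_lightning import Trainer

protocol_vad = get_protocol("protocol.SpeakerDiarization.vad")

model = Model.from_pretrained("pyannote/segmentation")
model.task = Segmentation(protocol=protocol_vad)
model.freeze_up_to("sincnet")

trainer = Trainer()
# untested guess: pass the task explicitly, mirroring the
# VoiceActivityDetection call above that does work
trainer.fit(model, model.task)
```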
-
Thanks for your contribution.
-
Thank you for your reply and for dealing with the bug. I have one more question. The notebook tutorial contains:

```python
# to make things faster, we run the inference once and for all...
validation_files = list(protocol.development())
for file in validation_files:
    file['vad_scores'] = inference(file)

# ... and tell the pipeline to load VAD scores directly from files
pipeline = VoiceActivityDetectionPipeline(scores="vad_scores")
```

ref: https://github.com/pyannote/pyannote-audio/blob/develop/tutorials/voice_activity_detection.ipynb

But `pyannote.audio.pipelines.VoiceActivityDetection` does not have a `scores` argument in its call method (nor does its superclass).
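In the meantime, here is a minimal sketch of what I can get running on develop without the `scores` argument, assuming the pipeline now takes a segmentation model and recomputes scores itself; the `segmentation` argument and the hyper-parameter names are my reading of the develop source, not the tutorial:

```python
from pyannote.database import get_protocol
from pyannote.audio.pipelines import VoiceActivityDetection

protocol = get_protocol("protocol.SpeakerDiarization.vad")
dev_file = next(iter(protocol.development()))

# assumption: on develop, the pipeline takes a pretrained segmentation
# model via `segmentation` and computes scores internally, instead of
# reading precomputed scores from a file key as the tutorial suggests
pipeline = VoiceActivityDetection(segmentation="pyannote/segmentation")

# assumption: these are the pipeline's tunable hyper-parameter names
pipeline = pipeline.instantiate({
    "onset": 0.5,
    "offset": 0.5,
    "min_duration_on": 0.0,
    "min_duration_off": 0.0,
})

vad = pipeline(dev_file)  # a pyannote.core.Annotation of speech regions
```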
-
Sorry for asking so many questions. An error occurred while fine-tuning the embedding model:

```python
In [1]: from pyannote.database import get_protocol
   ...: from pyannote.audio import Model
   ...: from pyannote.audio.tasks import SpeakerEmbedding
   ...: from pytorch_lightning import Trainer

In [2]: protocol = get_protocol("protocol.SpeakerDiarization.emb")

In [3]: model = Model.from_pretrained("pyannote/embedding")
   ...: model.task = SpeakerEmbedding(protocol=protocol)
   ...: model.freeze_up_to("sincnet")
   ...: trainer = Trainer()
   ...: trainer.fit(model)
```
```
/home/vscode/.cache/pypoetry/virtualenvs/diarization-pyannote2-jfLww0d6-py3.8/lib/python3.8/site-packages/pyannote/audio/core/model.py:826: UserWarning: Model has been trained with a task-dependent loss function. Either use the 'task' argument to force setting up the loss function or set 'strict' to False to load the model without its loss function and prevent this warning from appearing.
warnings.warn(msg)
/home/vscode/.cache/pypoetry/virtualenvs/diarization-pyannote2-jfLww0d6-py3.8/lib/python3.8/site-packages/pytorch_lightning/core/saving.py:209: UserWarning: Found keys that are not in the model state dict but in the checkpoint: ['loss_func.W']
rank_zero_warn(
/home/vscode/.cache/pypoetry/virtualenvs/diarization-pyannote2-jfLww0d6-py3.8/lib/python3.8/site-packages/pytorch_lightning/core/memory.py:202: LightningDeprecationWarning: Argument `mode` in `ModelSummary` is deprecated in v1.4 and will be removed in v1.6. Use `max_depth=-1` to replicate `mode=full` behaviour.
rank_zero_deprecation(
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
Loading protocol.SpeakerDiarization.emb training labels: 2file [00:00, 31.75file/s]
/home/vscode/.cache/pypoetry/virtualenvs/diarization-pyannote2-jfLww0d6-py3.8/lib/python3.8/site-packages/torchmetrics/utilities/prints.py:37: UserWarning: Metric `AUROC` will save all targets and predictions in buffer. For large datasets this may lead to large memory footprint.
warnings.warn(*args, **kwargs)
| Name | Type | Params | In sizes | Out sizes
------------------------------------------------------------------------------------
0 | sincnet | SincNet | 42.6 K | [32, 1, 32000] | [32, 60, 115]
1 | tdnns | ModuleList | 2.8 M | ? | ?
2 | stats_pool | StatsPool | 0 | [32, 1500, 101] | [32, 3000]
3 | embedding | Linear | 1.5 M | [32, 3000] | [32, 512]
4 | loss_func | ArcFaceLoss | 1.5 K | ? | ?
5 | validation_metric | AUROC | 0 | ? | ?
------------------------------------------------------------------------------------
4.3 M Trainable params
42.6 K Non-trainable params
4.3 M Total params
17.392 Total estimated model params size (MB)
Validation sanity check: 0it [00:00, ?it/s]
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-4-211d28b237a5> in <module>
3 model.freeze_up_to("sincnet")
4 trainer = Trainer()
----> 5 trainer.fit(model)
~/.cache/pypoetry/virtualenvs/diarization-pyannote2-jfLww0d6-py3.8/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py in fit(self, model, train_dataloaders, val_dataloaders, datamodule, train_dataloader)
551 self.checkpoint_connector.resume_start()
552
--> 553 self._run(model)
554
555 assert self.state.stopped
~/.cache/pypoetry/virtualenvs/diarization-pyannote2-jfLww0d6-py3.8/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py in _run(self, model)
916
917 # dispatch `start_training` or `start_evaluating` or `start_predicting`
--> 918 self._dispatch()
919
920 # plugin will finalized fitting (e.g. ddp_spawn will load trained model)
~/.cache/pypoetry/virtualenvs/diarization-pyannote2-jfLww0d6-py3.8/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py in _dispatch(self)
984 self.accelerator.start_predicting(self)
985 else:
--> 986 self.accelerator.start_training(self)
987
988 def run_stage(self):
~/.cache/pypoetry/virtualenvs/diarization-pyannote2-jfLww0d6-py3.8/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py in start_training(self, trainer)
90
91 def start_training(self, trainer: "pl.Trainer") -> None:
---> 92 self.training_type_plugin.start_training(trainer)
93
94 def start_evaluating(self, trainer: "pl.Trainer") -> None:
~/.cache/pypoetry/virtualenvs/diarization-pyannote2-jfLww0d6-py3.8/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py in start_training(self, trainer)
159 def start_training(self, trainer: "pl.Trainer") -> None:
160 # double dispatch to initiate the training loop
--> 161 self._results = trainer.run_stage()
162
163 def start_evaluating(self, trainer: "pl.Trainer") -> None:
~/.cache/pypoetry/virtualenvs/diarization-pyannote2-jfLww0d6-py3.8/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py in run_stage(self)
994 if self.predicting:
995 return self._run_predict()
--> 996 return self._run_train()
997
998 def _pre_training_routine(self):
~/.cache/pypoetry/virtualenvs/diarization-pyannote2-jfLww0d6-py3.8/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py in _run_train(self)
1029 self.progress_bar_callback.disable()
1030
-> 1031 self._run_sanity_check(self.lightning_module)
1032
1033 # enable train mode
~/.cache/pypoetry/virtualenvs/diarization-pyannote2-jfLww0d6-py3.8/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py in _run_sanity_check(self, ref_model)
1109
1110 # reload dataloaders
-> 1111 self._evaluation_loop.reload_evaluation_dataloaders()
1112
1113 # run eval step
~/.cache/pypoetry/virtualenvs/diarization-pyannote2-jfLww0d6-py3.8/lib/python3.8/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py in reload_evaluation_dataloaders(self)
171 self.trainer.reset_test_dataloader(model)
172 elif self.trainer.val_dataloaders is None or self.trainer._should_reload_dl_epoch:
--> 173 self.trainer.reset_val_dataloader(model)
174
175 def on_evaluation_start(self, *args: Any, **kwargs: Any) -> None:
~/.cache/pypoetry/virtualenvs/diarization-pyannote2-jfLww0d6-py3.8/lib/python3.8/site-packages/pytorch_lightning/trainer/data_loading.py in reset_val_dataloader(self, model)
435 has_step = is_overridden("validation_step", model)
436 if has_loader and has_step:
--> 437 self.num_val_batches, self.val_dataloaders = self._reset_eval_dataloader(model, "val")
438
439 def reset_test_dataloader(self, model) -> None:
~/.cache/pypoetry/virtualenvs/diarization-pyannote2-jfLww0d6-py3.8/lib/python3.8/site-packages/pytorch_lightning/trainer/data_loading.py in _reset_eval_dataloader(self, model, mode)
396 if len(dataloaders) != 0:
397 for i, dataloader in enumerate(dataloaders):
--> 398 num_batches = len(dataloader) if has_len(dataloader) else float("inf")
399 self._worker_check(dataloader, f"{mode} dataloader {i}")
400
~/.cache/pypoetry/virtualenvs/diarization-pyannote2-jfLww0d6-py3.8/lib/python3.8/site-packages/pytorch_lightning/utilities/data.py in has_len(dataloader)
61 # try getting the length
62 if len(dataloader) == 0:
---> 63 raise ValueError("`Dataloader` returned 0 length. Please make sure that it returns at least 1 batch")
64 has_len = True
65 except TypeError:
ValueError: `Dataloader` returned 0 length. Please make sure that it returns at least 1 batch
```

```python
In [4]: len(list(protocol.development()))
Out[4]: 1

In [5]: len(list(protocol.train()))
Out[5]: 2
```
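In case it helps with reproduction: a quick sanity check, under my (unconfirmed) assumption that the zero-length validation dataloader comes from the single development file being shorter than the task's chunk duration, so it yields zero validation chunks:

```python
# assumption: a development file shorter than the task's chunk duration
# produces zero validation chunks, hence a zero-length dataloader
from pyannote.audio import Audio

audio = Audio()
for file in protocol.development():
    # get_duration() returns the file duration in seconds
    print(file["uri"], audio.get_duration(file))
```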