From 2a59bc7b0dcce819edd97986e8000435f8e4ed23 Mon Sep 17 00:00:00 2001 From: clement-pages Date: Thu, 9 Jan 2025 09:51:22 +0100 Subject: [PATCH] add support for compose transform --- pyannote/audio/core/model.py | 32 ++++++++++++++++++++++---------- pyannote/audio/core/task.py | 25 ++++++++++++++++--------- 2 files changed, 38 insertions(+), 19 deletions(-) diff --git a/pyannote/audio/core/model.py b/pyannote/audio/core/model.py index 6dfe1a8bf..fbc45c249 100644 --- a/pyannote/audio/core/model.py +++ b/pyannote/audio/core/model.py @@ -40,6 +40,7 @@ from pyannote.core import SlidingWindow from pytorch_lightning.utilities.model_summary.model_summary import ModelSummary from torch.utils.data import DataLoader +from torch_audiomentations.core.composition import Compose from pyannote.audio import __version__ from pyannote.audio.core.io import Audio @@ -679,16 +680,27 @@ def default_map_location(storage, loc): TaskClass = getattr(task_module, task_class_name) - # instanciate task augmentation - augmentation = loaded_checkpoint["pyannote.audio"]["task"]["augmentation"] - if augmentation: - augmentation_module = import_module(augmentation["module"]) - augmentation_class = augmentation["class"] - augmentation_kwargs = augmentation["kwargs"] - AugmentationClass = getattr(augmentation_module, augmentation_class) - augmentation = AugmentationClass(**augmentation_kwargs) - - task_hparams["augmentation"] = augmentation + # instantiate task augmentation + def instantiate_transform(transform_data): + transform_module = import_module(transform_data["module"]) + transform_class = transform_data["class"] + transform_kwargs = transform_data["kwargs"] + TransformClass = getattr(transform_module, transform_class) + return TransformClass(**transform_kwargs) + + augmentation_data = loaded_checkpoint["pyannote.audio"]["task"]["augmentation"] + # BaseWaveformTransform case + if isinstance(augmentation_data, Dict): + task_hparams["augmentation"] = instantiate_transform(augmentation_data) + + 
# Compose transform case + elif isinstance(augmentation_data, List): + transforms = [] + for transform_data in augmentation_data: + transform = instantiate_transform(transform_data) + transforms.append(transform) + + task_hparams["augmentation"] = Compose(transforms=transforms, output_type="dict") # instanciate task metrics metrics = loaded_checkpoint["pyannote.audio"]["task"]["metrics"] diff --git a/pyannote/audio/core/task.py b/pyannote/audio/core/task.py index 605e3a975..89706b0b3 100644 --- a/pyannote/audio/core/task.py +++ b/pyannote/audio/core/task.py @@ -46,6 +46,7 @@ from torch.utils.data import DataLoader, Dataset, IterableDataset from torch_audiomentations import Identity from torch_audiomentations.core.transforms_interface import BaseWaveformTransform +from torch_audiomentations.core.composition import BaseCompose from torchmetrics import Metric, MetricCollection from pyannote.audio.utils.loss import binary_cross_entropy, nll_loss @@ -915,18 +916,24 @@ def on_save_checkpoint(self, checkpoint): } # save augmentation: - # TODO: add support for compose augmentation + def serialize_augmentation(augmentation) -> Dict: + return { + "module": augmentation.__class__.__module__, + "class": augmentation.__class__.__name__, + "kwargs": { + param: getattr(augmentation, param, None) + for param in inspect.signature(augmentation.__init__).parameters + } + } + + if not self.augmentation: checkpoint["pyannote.audio"]["task"]["augmentation"] = None elif isinstance(self.augmentation, BaseWaveformTransform): - checkpoint["pyannote.audio"]["task"]["augmentation"] = [{ - "module": self.augmentation.__class__.__module__, - "class": self.augmentation.__class__.__name__, - "kwargs": { - param: getattr(self.augmentation, param, None) - for param in inspect.signature(self.augmentation.__init__).parameters - } - }] + checkpoint["pyannote.audio"]["task"]["augmentation"] = serialize_augmentation(self.augmentation) + elif isinstance(self.augmentation, BaseCompose): + 
checkpoint["pyannote.audio"]["task"]["augmentation"] = [] + for augmentation in self.augmentation.transforms: + checkpoint["pyannote.audio"]["task"]["augmentation"].append(serialize_augmentation(augmentation)) # save metrics: if isinstance(self.metric, Metric):