Skip to content

Commit

Permalink
add support for compose transform
Browse files Browse the repository at this point in the history
  • Loading branch information
clement-pages committed Jan 9, 2025
1 parent da3356b commit 2a59bc7
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 19 deletions.
32 changes: 22 additions & 10 deletions pyannote/audio/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from pyannote.core import SlidingWindow
from pytorch_lightning.utilities.model_summary.model_summary import ModelSummary
from torch.utils.data import DataLoader
from torch_audiomentations.core.composition import Compose

from pyannote.audio import __version__
from pyannote.audio.core.io import Audio
Expand Down Expand Up @@ -679,16 +680,27 @@ def default_map_location(storage, loc):

TaskClass = getattr(task_module, task_class_name)

# instanciate task augmentation
augmentation = loaded_checkpoint["pyannote.audio"]["task"]["augmentation"]
if augmentation:
augmentation_module = import_module(augmentation["module"])
augmentation_class = augmentation["class"]
augmentation_kwargs = augmentation["kwargs"]
AugmentationClass = getattr(augmentation_module, augmentation_class)
augmentation = AugmentationClass(**augmentation_kwargs)

task_hparams["augmentation"] = augmentation
# instantiate task augmentation
def instantiate_transform(transform_data):
transform_module = import_module(transform_data["module"])
transform_class = transform_data["class"]
transform_kwargs = transform_data["kwargs"]
TransformClass = getattr(transform_module, transform_class)
return TransformClass(**transform_kwargs)

augmentation_data = loaded_checkpoint["pyannote.audio"]["task"]["augmentation"]
# BaseWaveformTransform case
if isinstance(augmentation_data, Dict):
task_hparams["augmentation"] = instantiate_transform(augmentation_data)

# Compose transform case
elif isinstance(augmentation_data , List):
transforms = []
for transform_data in augmentation_data:
transform = instantiate_transform(transform_data)
transforms.append(transform)

task_hparams["augmentation"] = Compose(transforms=transforms, output_type="dict")

# instanciate task metrics
metrics = loaded_checkpoint["pyannote.audio"]["task"]["metrics"]
Expand Down
25 changes: 16 additions & 9 deletions pyannote/audio/core/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from torch.utils.data import DataLoader, Dataset, IterableDataset
from torch_audiomentations import Identity
from torch_audiomentations.core.transforms_interface import BaseWaveformTransform
from torch_audiomentations.core.composition import BaseCompose
from torchmetrics import Metric, MetricCollection

from pyannote.audio.utils.loss import binary_cross_entropy, nll_loss
Expand Down Expand Up @@ -915,18 +916,24 @@ def on_save_checkpoint(self, checkpoint):
}

# save augmentation:
# TODO: add support for compose augmentation
def serialize_augmentation(augmentation) -> Dict:
return {
"module": augmentation.__class__.__module__,
"class": augmentation.__class__.__name__,
"kwargs": {
param: getattr(augmentation, param, None)
for param in inspect.signature(augmentation.__init__).parameters
}
}

if not self.augmentation:
checkpoint["pyannote.audio"]["task"]["augmentation"] = None
elif isinstance(self.augmentation, BaseWaveformTransform):
checkpoint["pyannote.audio"]["task"]["augmentation"] = [{
"module": self.augmentation.__class__.__module__,
"class": self.augmentation.__class__.__name__,
"kwargs": {
param: getattr(self.augmentation, param, None)
for param in inspect.signature(self.augmentation.__init__).parameters
}
}]
checkpoint["pyannote.audio"]["task"]["augmentation"] = serialize_augmentation(self.augmentation)
elif isinstance(self.augmentation, BaseCompose):
checkpoint["pyannote.audio"]["task"]["augmentation"] = []
for augmentation in self.augmentation.transforms:
checkpoint["pyannote.audio"]["task"]["augmentation"].append(serialize_augmentation(augmentation))

# save metrics:
if isinstance(self.metric, Metric):
Expand Down

0 comments on commit 2a59bc7

Please sign in to comment.