Commit

Merge branch 'develop3' into develop2

FrenchKrab committed Mar 18, 2022
2 parents 372f66c + 66a873f commit 6602b33
Showing 23 changed files with 575 additions and 484 deletions.
15 changes: 7 additions & 8 deletions pyannote/audio/cli/train.py
@@ -27,20 +27,19 @@
import hydra
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf

# from pyannote.audio.core.callback import GraduallyUnfreeze
from pyannote.database import FileFinder, get_protocol
from pytorch_lightning.callbacks import (
EarlyStopping,
LearningRateMonitor,
ModelCheckpoint,
RichProgressBar,
)

from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.utilities.seed import seed_everything
from torch_audiomentations.utils.config import from_dict as get_augmentation

# from pyannote.audio.core.callback import GraduallyUnfreeze
from pyannote.database import FileFinder, get_protocol


@hydra.main(config_path="train_config", config_name="config")
def train(cfg: DictConfig) -> Optional[float]:
@@ -67,15 +66,15 @@ def train(cfg: DictConfig) -> Optional[float]:
# instantiate task and validation metric
task = instantiate(cfg.task, protocol, augmentation=augmentation)

# validation metric to monitor (and its direction: min or max)
monitor, direction = task.val_monitor

# instantiate model
fine_tuning = cfg.model["_target_"] == "pyannote.audio.cli.pretrained"
model = instantiate(cfg.model)
model.task = task
model.setup(stage="fit")

# validation metric to monitor (and its direction: min or max)
monitor, direction = task.val_monitor

# number of batches in one epoch
num_batches_per_epoch = model.task.train__len__() // model.task.batch_size

@@ -90,6 +89,7 @@ def configure_optimizers(self):
num_batches_per_epoch=num_batches_per_epoch,
)
return {"optimizer": optimizer, "lr_scheduler": lr_scheduler}

model.configure_optimizers = MethodType(configure_optimizers, model)

callbacks = [RichProgressBar(), LearningRateMonitor()]
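
The hunk above replaces the model's optimizer setup by binding a plain function onto the already-instantiated model. A minimal, self-contained sketch of that types.MethodType pattern, with an illustrative stand-in model and learning rate (neither is from this commit):

from types import MethodType

import torch

model = torch.nn.Linear(10, 1)  # stand-in for the instantiated pyannote model

def configure_optimizers(self):
    # `self` is whatever instance the function gets bound to below
    return torch.optim.Adam(self.parameters(), lr=1e-3)

# bind as an instance method, shadowing any class-level definition
model.configure_optimizers = MethodType(configure_optimizers, model)
optimizer = model.configure_optimizers()
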
@@ -155,4 +155,3 @@ def configure_optimizers(self):

if __name__ == "__main__":
train()

4 changes: 2 additions & 2 deletions pyannote/audio/core/model.py
@@ -34,6 +34,7 @@
import torch.nn as nn
import torch.optim
from huggingface_hub import cached_download, hf_hub_url
from pyannote.core import SlidingWindow
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.model_summary import ModelSummary
from semver import VersionInfo
@@ -42,7 +43,6 @@
from pyannote.audio import __version__
from pyannote.audio.core.io import Audio
from pyannote.audio.core.task import Problem, Resolution, Specifications, Task
from pyannote.core import SlidingWindow

CACHE_DIR = os.getenv(
"PYANNOTE_CACHE",
@@ -385,7 +385,7 @@ def setup(self, stage=None):
# setup custom loss function
self.task.setup_loss_func()
# setup custom validation metrics
validation_metric = self.task.setup_validation_metric()
validation_metric = self.task.metric
if validation_metric is not None:
self.validation_metric = validation_metric
self.validation_metric.to(self.device)
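
Task.metric (see task.py below) returns a torchmetrics MetricCollection; since metrics hold state tensors, the collection follows the usual torch device-placement API, which is why it is moved to self.device above. A quick sketch, assuming a torchmetrics release contemporary with this commit:

import torch
from torchmetrics import AUROC, MetricCollection

validation_metric = MetricCollection([AUROC(compute_on_step=False)])
validation_metric.to(torch.device("cpu"))  # or the model's GPU device
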
95 changes: 40 additions & 55 deletions pyannote/audio/core/task.py
@@ -1,6 +1,6 @@
# MIT License
#
# Copyright (c) 2020-2021 CNRS
# Copyright (c) 2020- CNRS
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -23,13 +23,18 @@

from __future__ import annotations

try:
from functools import cached_property
except ImportError:
from backports.cached_property import cached_property

import multiprocessing
import sys
import warnings
from dataclasses import dataclass
from enum import Enum
from numbers import Number
from typing import Dict, List, Optional, Sequence, Text, Tuple, Type, Union
from typing import Dict, List, Optional, Sequence, Text, Tuple, Union

import pytorch_lightning as pl
import torch
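
functools.cached_property is only available from Python 3.8 on, hence the fallback above to the backports.cached-property package. A small sketch of the behaviour that the logging_prefix and metric properties below rely on: the decorated method runs once per instance and its result is then served from a cache.

from functools import cached_property  # Python >= 3.8

class Expensive:
    @cached_property
    def value(self):
        print("computed once")
        return 42

e = Expensive()
e.value  # prints "computed once" and returns 42
e.value  # served from the instance cache, no recomputation
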
@@ -152,6 +157,9 @@ class Task(pl.LightningDataModule):
augmentation : BaseWaveformTransform, optional
torch_audiomentations waveform transform, used by dataloader
during training.
metric : optional
Validation metric(s). Can be anything supported by torchmetrics.MetricCollection.
Defaults to value returned by `default_metric` method.
Attributes
----------
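
Per the new docstring, the metric argument accepts anything torchmetrics.MetricCollection does. A hedged sketch of the three accepted shapes (constructor arguments follow torchmetrics releases current at the time of this commit; newer versions may differ):

from torchmetrics import AUROC, MetricCollection

MetricCollection(AUROC(compute_on_step=False))             # a single Metric
MetricCollection([AUROC(compute_on_step=False)])           # a sequence of Metrics
MetricCollection({"auroc": AUROC(compute_on_step=False)})  # a name-to-Metric dict
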
@@ -169,7 +177,7 @@ def __init__(
num_workers: int = None,
pin_memory: bool = False,
augmentation: BaseWaveformTransform = None,
metrics: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None,
metric: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None,
):
super().__init__()

@@ -205,20 +213,7 @@ def __init__(
self.num_workers = num_workers
self.pin_memory = pin_memory
self.augmentation = augmentation
self.metrics = metrics

@property
def default_validation_metric(
self,
) -> Union[Metric, Sequence[Metric], Dict[str, Metric]]:
"""Used to get validation metrics when none is provided by the user
Returns
-------
Union[Metric, Sequence[Metric], Dict[str, Metric]]
The default validation metric(s) of this task (something that can be plugged in a torchmetrics.MetricCollection).
"""
pass
self._metric = metric

def prepare_data(self):
"""Use this to download and prepare data
@@ -249,39 +244,6 @@ def setup(self, stage: Optional[str] = None):
def setup_loss_func(self):
pass

@property
def val_metric_prefix(self) -> str:
return f"{self.ACRONYM}@val_"

def get_default_val_metric_name(self, metric: Union[Metric, Type]) -> str:
prefix = self.val_metric_prefix
mn = Task.get_metric_name(metric)
return f"{prefix}{mn}"

@staticmethod
def get_metric_name(metric: Union[Metric, Type]) -> str:
if isinstance(metric, Metric):
return type(metric).__name__.lower()
elif isinstance(metric, Type):
return metric.__name__.lower()
else:
msg = "get_metric_name only accepts Metric and Type arguments."
raise ValueError(msg)

def setup_validation_metric(self) -> Metric:
metricsarg = self.metrics
if self.metrics is None:
metricsarg = self.default_validation_metric

# Convert metricargs to a list, if it is a single Metric
if isinstance(metricsarg, Metric):
metricsarg = [metricsarg]
# Convert metricsarg to a dict, now that it is a list
# If the metrics' names are not given, generate them automatically
if not isinstance(metricsarg, dict):
metricsarg = {Task.get_metric_name(m): m for m in metricsarg}
return MetricCollection(metricsarg, prefix=self.val_metric_prefix)

def train__iter__(self):
# will become train_dataset.__iter__ method
msg = f"Missing '{self.__class__.__name__}.train__iter__' method."
@@ -310,6 +272,18 @@ def train_dataloader(self) -> DataLoader:
collate_fn=self.collate_fn,
)

@cached_property
def logging_prefix(self):

prefix = f"{self.__class__.__name__}-"
if hasattr(self.protocol, "name"):
# "." has a special meaning for pytorch-lightning checkpointing
# so we remove dots from protocol names
name_without_dots = "".join(self.protocol.name.split("."))
prefix += f"{name_without_dots}-"

return prefix

def default_loss(
self, specifications: Specifications, target, prediction, weight=None
) -> torch.Tensor:
@@ -399,7 +373,7 @@ def common_step(self, batch, batch_idx: int, stage: Literal["train", "val"]):
# compute loss
loss = self.default_loss(self.specifications, y, y_pred, weight=weight)
self.model.log(
f"{self.ACRONYM}@{stage}_loss",
f"{self.logging_prefix}{stage.capitalize()}Loss",
loss,
on_step=False,
on_epoch=True,
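
A standalone sketch of the names produced by the new logging_prefix property together with the common_step change above; the task class and protocol names are illustrative only. Dots are stripped from the protocol name because "." has a special meaning for pytorch-lightning checkpointing:

class_name = "VoiceActivityDetection"                # illustrative task class
protocol_name = "AMI.SpeakerDiarization.only_words"  # illustrative protocol

prefix = f"{class_name}-"
prefix += "".join(protocol_name.split(".")) + "-"

print(prefix + "TrainLoss")
# VoiceActivityDetection-AMISpeakerDiarizationonly_words-TrainLoss
print(prefix + "ValLoss")
# the stage is capitalized, replacing the old "<acronym>@val_loss" style keys
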
@@ -443,6 +417,18 @@ def validation_step(self, batch, batch_idx: int):
def validation_epoch_end(self, outputs):
pass

def default_metric(self) -> Union[Metric, Sequence[Metric], Dict[str, Metric]]:
"""Default validation metric"""
msg = f"Missing '{self.__class__.__name__}.default_metric' method."
raise NotImplementedError(msg)

@cached_property
def metric(self) -> MetricCollection:
if self._metric is None:
self._metric = self.default_metric()

return MetricCollection(self._metric, prefix=self.logging_prefix)

@property
def val_monitor(self):
"""Quantity (and direction) to monitor
@@ -461,7 +447,6 @@ def val_monitor(self):
pytorch_lightning.callbacks.ModelCheckpoint
pytorch_lightning.callbacks.EarlyStopping
"""
if self.has_validation:
return f"{self.ACRONYM}@val_loss", "min"
else:
return None, "min"

name, metric = next(iter(self.metric.items()))
return name, "max" if metric.higher_is_better else "min"
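
A hedged sketch of how the new pieces compose: metric wraps the user-supplied or default metric(s) in a MetricCollection keyed by logging_prefix, and val_monitor now derives the checkpoint monitor from the collection's first entry instead of a hard-coded loss key. Names below are placeholders, and whether items() yields prefixed keys depends on the installed torchmetrics version:

from torchmetrics import AUROC, MetricCollection

collection = MetricCollection(
    {"auroc": AUROC(compute_on_step=False)},
    prefix="SomeTask-SomeProtocol-",  # placeholder logging prefix
)
name, first_metric = next(iter(collection.items()))
direction = "max" if first_metric.higher_is_better else "min"
# `name` and `direction` then feed pytorch-lightning callbacks in train.py,
# e.g. ModelCheckpoint(monitor=name, mode=direction)
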
13 changes: 7 additions & 6 deletions pyannote/audio/tasks/embedding/arcface.py
@@ -1,6 +1,6 @@
# MIT License
#
# Copyright (c) 2020-2021 CNRS
# Copyright (c) 2020- CNRS
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -26,11 +26,11 @@
from typing import Dict, Sequence, Union

import pytorch_metric_learning.losses
from pyannote.database import Protocol
from torch_audiomentations.core.transforms_interface import BaseWaveformTransform
from torchmetrics import Metric

from pyannote.audio.core.task import Task
from pyannote.database import Protocol

from .mixins import SupervisedRepresentationLearningTaskMixin

@@ -70,10 +70,11 @@ class SupervisedRepresentationLearningWithArcFace(
augmentation : BaseWaveformTransform, optional
torch_audiomentations waveform transform, used by dataloader
during training.
metric : optional
Validation metric(s). Can be anything supported by torchmetrics.MetricCollection.
Defaults to AUROC (area under the ROC curve).
"""

ACRONYM = "arcface"

#  TODO: add a ".metric" property that tells how speaker embedding trained with this approach
#  should be compared. could be a string like "cosine" or "euclidean" or a pdist/cdist-like
#  callable. this ".metric" property should be propagated all the way to Inference (via the model).
@@ -90,7 +91,7 @@ def __init__(
num_workers: int = None,
pin_memory: bool = False,
augmentation: BaseWaveformTransform = None,
metrics: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None,
metric: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None,
):

self.num_chunks_per_class = num_chunks_per_class
@@ -107,7 +108,7 @@
num_workers=num_workers,
pin_memory=pin_memory,
augmentation=augmentation,
metrics=metrics,
metric=metric,
)

def setup_loss_func(self):
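
A hypothetical instantiation with the renamed keyword; the protocol name and the metric choice are placeholders, not taken from this commit:

from pyannote.database import FileFinder, get_protocol
from torchmetrics import AUROC

from pyannote.audio.tasks import SupervisedRepresentationLearningWithArcFace

protocol = get_protocol(
    "VoxCeleb.SpeakerVerification.VoxCeleb1",  # placeholder protocol name
    preprocessors={"audio": FileFinder()},
)
task = SupervisedRepresentationLearningWithArcFace(
    protocol,
    metric=AUROC(compute_on_step=False),  # keyword renamed from `metrics`
)
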
23 changes: 4 additions & 19 deletions pyannote/audio/tasks/embedding/mixins.py
@@ -1,6 +1,6 @@
# MIT License
#
# Copyright (c) 2020-2021 CNRS
# Copyright (c) 2020- CNRS
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -33,7 +33,7 @@
from torchmetrics import AUROC, Metric
from tqdm import tqdm

from pyannote.audio.core.task import Problem, Resolution, Specifications, Task
from pyannote.audio.core.task import Problem, Resolution, Specifications
from pyannote.audio.utils.random import create_rng_for_worker


@@ -124,8 +124,7 @@ def setup(self, stage: Optional[str] = None):
if isinstance(self.protocol, SpeakerVerificationProtocol):
self._validation = list(self.protocol.development_trial())

@property
def default_validation_metric(
def default_metric(
self,
) -> Union[Metric, Sequence[Metric], Dict[str, Metric]]:
return AUROC(compute_on_step=False)
@@ -222,7 +221,7 @@ def training_step(self, batch, batch_idx: int):
loss = self.model.loss_func(self.model(X), y)

self.model.log(
f"{self.ACRONYM}@train_loss",
f"{self.logging_prefix}TrainLoss",
loss,
on_step=False,
on_epoch=True,
@@ -290,17 +289,3 @@ def validation_step(self, batch, batch_idx: int):

elif isinstance(self.protocol, SpeakerDiarizationProtocol):
pass

@property
def val_monitor(self):

if self.has_validation and self.metrics is None:

if isinstance(self.protocol, SpeakerVerificationProtocol):
return Task.get_default_val_metric_name(AUROC), "max"

elif isinstance(self.protocol, SpeakerDiarizationProtocol):
return None, "min"

else:
return None, "min"