add manual_optimization method in Task
clement-pages committed Nov 26, 2024
1 parent 0e5b60e commit 4d97bca
Showing 7 changed files with 65 additions and 29 deletions.
51 changes: 51 additions & 0 deletions pyannote/audio/core/task.py
@@ -48,6 +48,7 @@
from torchmetrics import Metric, MetricCollection

from pyannote.audio.utils.loss import binary_cross_entropy, nll_loss
from pyannote.audio.utils.params import merge_dict
from pyannote.audio.utils.protocol import check_protocol

Subsets = list(Subset.__args__)
@@ -231,6 +232,9 @@ class Task(pl.LightningDataModule):
If True, data loaders will copy tensors into CUDA pinned
memory before returning them. See pytorch documentation
for more details. Defaults to False.
gradient : dict, optional
Keyword arguments controlling gradient clipping and accumulation.
Defaults to {"clip_val": 5.0, "clip_algorithm": "norm", "accumulate_batches": 1}.
augmentation : BaseWaveformTransform, optional
torch_audiomentations waveform transform, used by dataloader
during training.
@@ -245,6 +249,12 @@ class Task(pl.LightningDataModule):
"""

GRADIENT_DEFAULTS = {
"clip_val": 5.0,
"clip_algorithm": "norm",
"accumulate_batches": 1,
}

def __init__(
self,
protocol: Protocol,
@@ -255,6 +265,7 @@ def __init__(
batch_size: int = 32,
num_workers: Optional[int] = None,
pin_memory: bool = False,
gradient: Optional[dict] = None,
augmentation: Optional[BaseWaveformTransform] = None,
metric: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None,
):
@@ -302,6 +313,7 @@ def __init__(

self.num_workers = num_workers
self.pin_memory = pin_memory
self.gradient = merge_dict(self.GRADIENT_DEFAULTS, gradient)
self.augmentation = augmentation or Identity(output_type="dict")
self._metric = metric

@@ -810,6 +822,45 @@ def common_step(self, batch, batch_idx: int, stage: Literal["train", "val"]):
# can obviously be overridden for each task
def training_step(self, batch, batch_idx: int):
return self.common_step(batch, batch_idx, "train")

def manual_optimization(self, loss: torch.Tensor, batch_idx: int) -> torch.Tensor:
"""Process manual optimization for each optimizer
Parameters
----------
loss: torch.Tensor
Computed loss for current training step.
batch_idx: int
Batch index.
Returns
-------
scaled_loss: torch.Tensor
Loss scaled by `1 / Task.gradient["accumulate_batches"]`.
"""
optimizers = self.model.optimizers()
optimizers = optimizers if isinstance(optimizers, list) else [optimizers]

num_accumulate_batches = self.gradient["accumulate_batches"]
if batch_idx % num_accumulate_batches == 0:
for optimizer in optimizers:
optimizer.zero_grad()

# scale the loss so that the gradient magnitude matches what it would be
# with batches of size batch_size * num_accumulate_batches
scaled_loss = loss / num_accumulate_batches
self.model.manual_backward(scaled_loss)

if (batch_idx + 1) % num_accumulate_batches == 0:
for optimizer in optimizers:
self.model.clip_gradients(
optimizer,
gradient_clip_val=self.gradient["clip_val"],
gradient_clip_algorithm=self.gradient["clip_algorithm"],
)
optimizer.step()

return scaled_loss

def val__getitem__(self, idx):
# will become val_dataset.__getitem__ method
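The new manual_optimization method above combines gradient accumulation with per-optimizer clipping. As a rough illustration of that pattern outside of Lightning, the following sketch reproduces the same accumulate/scale/clip/step cycle in plain PyTorch; the toy model, data, and hyper-parameter values are assumptions made for the example, not part of this commit.

import torch
import torch.nn.functional as F

# toy model, optimizer and data -- illustrative only
model = torch.nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
gradient = {"clip_val": 5.0, "clip_algorithm": "norm", "accumulate_batches": 4}

batches = [(torch.randn(8, 10), torch.randn(8, 1)) for _ in range(8)]

for batch_idx, (x, y) in enumerate(batches):
    # zero gradients at the start of every accumulation window
    if batch_idx % gradient["accumulate_batches"] == 0:
        optimizer.zero_grad()

    loss = F.mse_loss(model(x), y)

    # scale the loss so accumulated gradients match those of a single
    # batch that is accumulate_batches times larger
    scaled_loss = loss / gradient["accumulate_batches"]
    scaled_loss.backward()

    # clip and step only once the accumulation window is complete
    # ("norm" clipping corresponds to torch.nn.utils.clip_grad_norm_)
    if (batch_idx + 1) % gradient["accumulate_batches"] == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), gradient["clip_val"])
        optimizer.step()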
2 changes: 2 additions & 0 deletions pyannote/audio/tasks/embedding/arcface.py
@@ -90,6 +90,7 @@ def __init__(
scale: float = 64.0,
num_workers: Optional[int] = None,
pin_memory: bool = False,
gradient: Optional[Dict] = None,
augmentation: Optional[BaseWaveformTransform] = None,
metric: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None,
):
@@ -106,6 +107,7 @@ def __init__(
min_duration=min_duration,
batch_size=self.batch_size,
num_workers=num_workers,
gradient=gradient,
pin_memory=pin_memory,
augmentation=augmentation,
metric=metric,
2 changes: 2 additions & 0 deletions pyannote/audio/tasks/segmentation/multilabel.py
@@ -101,6 +101,7 @@ def __init__(
batch_size: int = 32,
num_workers: Optional[int] = None,
pin_memory: bool = False,
gradient: Optional[Dict] = None,
augmentation: Optional[BaseWaveformTransform] = None,
metric: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None,
):
@@ -116,6 +117,7 @@ def __init__(
batch_size=batch_size,
num_workers=num_workers,
pin_memory=pin_memory,
gradient=gradient,
augmentation=augmentation,
metric=metric,
cache=cache,
2 changes: 2 additions & 0 deletions
@@ -111,6 +111,7 @@ def __init__(
batch_size: int = 32,
num_workers: Optional[int] = None,
pin_memory: bool = False,
gradient: Optional[Dict] = None,
augmentation: Optional[BaseWaveformTransform] = None,
metric: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None,
cache: Optional[Union[str, None]] = None,
@@ -122,6 +123,7 @@ def __init__(
batch_size=batch_size,
num_workers=num_workers,
pin_memory=pin_memory,
gradient=gradient,
augmentation=augmentation,
metric=metric,
cache=cache,
17 changes: 3 additions & 14 deletions pyannote/audio/tasks/segmentation/speaker_diarization.py
@@ -119,6 +119,7 @@ def __init__(
batch_size: int = 32,
num_workers: Optional[int] = None,
pin_memory: bool = False,
gradient: Optional[Dict] = None,
augmentation: Optional[BaseWaveformTransform] = None,
metric: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None,
max_num_speakers: Optional[
@@ -132,6 +133,7 @@ def __init__(
batch_size=batch_size,
num_workers=num_workers,
pin_memory=pin_memory,
gradient=gradient,
augmentation=augmentation,
metric=metric,
cache=cache,
@@ -466,20 +468,7 @@ def training_step(self, batch, batch_idx: int):
)

if not self.automatic_optimization:
optimizers = self.model.optimizers()
optimizers = optimizers if isinstance(optimizers, list) else [optimizers]
for optimizer in optimizers:
optimizer.zero_grad()

self.model.manual_backward(loss)

for optimizer in optimizers:
self.model.clip_gradients(
optimizer,
gradient_clip_val=5.0,
gradient_clip_algorithm="norm",
)
optimizer.step()
loss = self.manual_optimization(loss, batch_idx)

return {"loss": loss}

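Since SpeakerDiarization now forwards a `gradient` argument to `Task`, a partially specified dict should be enough: unspecified keys presumably fall back to `Task.GRADIENT_DEFAULTS` through `merge_dict`. A minimal usage sketch, assuming this commit's version of the library and a pyannote.database protocol registered under a hypothetical name:

from pyannote.database import get_protocol
from pyannote.audio.tasks import SpeakerDiarization

# hypothetical protocol name -- replace with one registered in your
# pyannote.database configuration
protocol = get_protocol("MyDatabase.SpeakerDiarization.MyProtocol")

# override only the accumulation factor; clip_val (5.0) and
# clip_algorithm ("norm") are expected to come from Task.GRADIENT_DEFAULTS
task = SpeakerDiarization(
    protocol,
    batch_size=32,
    gradient={"accumulate_batches": 4},
)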
2 changes: 2 additions & 0 deletions pyannote/audio/tasks/segmentation/voice_activity_detection.py
@@ -94,6 +94,7 @@ def __init__(
batch_size: int = 32,
num_workers: Optional[int] = None,
pin_memory: bool = False,
gradient: Optional[Dict] = None,
augmentation: Optional[BaseWaveformTransform] = None,
metric: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None,
):
@@ -104,6 +105,7 @@ def __init__(
batch_size=batch_size,
num_workers=num_workers,
pin_memory=pin_memory,
gradient=gradient,
augmentation=augmentation,
metric=metric,
cache=cache,
18 changes: 3 additions & 15 deletions pyannote/audio/tasks/separation/PixIT.py
@@ -165,6 +165,7 @@ def __init__(
batch_size: int = 32,
num_workers: Optional[int] = None,
pin_memory: bool = False,
gradient: Optional[Dict] = None,
augmentation: Optional[BaseWaveformTransform] = None,
metric: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None,
max_num_speakers: Optional[
@@ -185,6 +186,7 @@ def __init__(
batch_size=batch_size,
num_workers=num_workers,
pin_memory=pin_memory,
gradient=gradient,
augmentation=augmentation,
metric=metric,
cache=cache,
@@ -1009,22 +1011,8 @@ def training_step(self, batch, batch_idx: int):
logger=True,
)

# using multiple optimizers requires manual optimization
if not self.automatic_optimization:
optimizers = self.model.optimizers()
optimizers = optimizers if isinstance(optimizers, list) else [optimizers]
for optimizer in optimizers:
optimizer.zero_grad()

self.model.manual_backward(loss)

for optimizer in optimizers:
self.model.clip_gradients(
optimizer,
gradient_clip_val=self.model.gradient_clip_val,
gradient_clip_algorithm="norm",
)
optimizer.step()
loss = self.manual_optimization(loss, batch_idx)

return {"loss": loss}

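PixIT's training_step notes that using multiple optimizers requires manual optimization, which is where the list handling in `manual_optimization` matters: `self.model.optimizers()` may return a single optimizer or a list, and each entry is zeroed, clipped, and stepped in turn. A stripped-down illustration of that per-optimizer loop in plain PyTorch, with toy sub-modules chosen purely for the example:

import torch
import torch.nn.functional as F

# two toy sub-modules, each with its own optimizer -- illustrative only
encoder = torch.nn.Linear(10, 4)
head = torch.nn.Linear(4, 1)
optimizers = [
    torch.optim.Adam(encoder.parameters(), lr=1e-3),
    torch.optim.Adam(head.parameters(), lr=1e-3),
]

x, y = torch.randn(8, 10), torch.randn(8, 1)

for optimizer in optimizers:
    optimizer.zero_grad()

loss = F.mse_loss(head(encoder(x)), y)
loss.backward()

# clip and step each optimizer separately, mirroring the loop in
# Task.manual_optimization
for optimizer in optimizers:
    for group in optimizer.param_groups:
        torch.nn.utils.clip_grad_norm_(group["params"], max_norm=5.0)
    optimizer.step()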
