Commit 3d0380a

🧹 Remove kornia dependency
jejon committed Oct 15, 2024
1 parent 3dae140 commit 3d0380a
Showing 5 changed files with 2 additions and 208 deletions.
1 change: 0 additions & 1 deletion pyproject.toml
@@ -16,7 +16,6 @@ dependencies = [
"opendatasets>=0.1.22",
"rarfile>=4.1",
"scipy>=1.9.3",
"kornia>=0.7.0",
"pydicom>=2.4.3",
"matplotlib>=3.7.4",
"seaborn>=0.13.0",
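With kornia gone from the required dependencies, a fresh install of landmarker should no longer pull it in. A minimal smoke test of that assumption (run in a clean environment; the expected version string comes from the `__init__.py` diff below):

```python
import sys

import landmarker

# Importing the package must no longer drag kornia in transitively.
assert "kornia" not in sys.modules
print(landmarker.__version__)  # expected: "0.1.2-alpha-v2"
```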
2 changes: 1 addition & 1 deletion src/landmarker/__init__.py
@@ -2,7 +2,7 @@
Landmarker
"""

__version__ = "0.1.2-alpha"
__version__ = "0.1.2-alpha-v2"

__all__ = [
"data",
2 changes: 0 additions & 2 deletions src/landmarker/losses/__init__.py
@@ -4,7 +4,6 @@

from .losses import (
    AdaptiveWingLoss,
-    EuclideanDistanceJSDivergenceReg,
    EuclideanDistanceVarianceReg,
    GaussianHeatmapL2Loss,
    GeneralizedNormalHeatmapLoss,
@@ -17,7 +16,6 @@

__all__ = [
"GeneralizedNormalHeatmapLoss",
"EuclideanDistanceJSDivergenceReg",
"EuclideanDistanceVarianceReg",
"MultivariateGaussianNLLLoss",
"WingLoss",
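Dropping `EuclideanDistanceJSDivergenceReg` from both the module and `__all__` is a breaking change for downstream imports. A hedged sketch of how calling code might cope (the fallback choice of `EuclideanDistanceVarianceReg` is an assumption for illustration, not a stated migration path):

```python
try:
    # Removed in this commit; present only in older releases.
    from landmarker.losses import EuclideanDistanceJSDivergenceReg as RegLoss
except ImportError:
    # One of the remaining regularized losses; swap in whatever fits your setup.
    from landmarker.losses import EuclideanDistanceVarianceReg as RegLoss
```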
107 changes: 0 additions & 107 deletions src/landmarker/losses/losses.py
@@ -1,18 +1,10 @@
"""Heatmap loss functions."""

from typing import Optional

import numpy as np
import torch
from kornia.losses import js_div_loss_2d
from torch import nn
from torch.nn import functional as F

from landmarker.heatmap.generator import (
GaussianHeatmapGenerator,
HeatmapGenerator,
LaplacianHeatmapGenerator,
)
from landmarker.models.utils import LogSoftmaxND


@@ -174,105 +166,6 @@ def forward(self, pred: torch.Tensor, cov: torch.Tensor, target: torch.Tensor) -
        return loss


-class EuclideanDistanceJSDivergenceReg(nn.Module):
-    r"""
-    Euclidean distance loss with Jensen-Shannon divergence regularization. The regularization term
-    imposes a Gaussian distribution on the predicted heatmaps and calculates the Jensen-Shannon
-    divergence between the predicted and the target heatmap. The Jensen-Shannon divergence is a
-    method to measure the similarity between two probability distributions. It is a symmetrized
-    and smoothed version of the Kullback-Leibler divergence, and is defined as the average of the
-    Kullback-Leibler divergence between the two distributions and a mixture M of the two
-    distributions:
-    :math:`JSD(P||Q) = 0.5 * KL(P||M) + 0.5 * KL(Q||M)` where :math:`M = 0.5 * (P + Q)`
-    Generalization of regularization term proposed by Nibali et al. (2018), which only considers
-    Multivariate Gaussian distributions, to a generalized Gaussian distribution. (However, now
-    we only consider the Gaussian and the Laplace distribution.)
-    source: Numerical Coordinate Regression with Convolutional Neural Networks - Nibali et al.
-    (2018)
-    Args:
-        alpha (float, optional): Weight of the regularization term. Defaults to 1.0.
-        sigma_t (float, optional): Target sigma value. Defaults to 1.0.
-        rotation_t (float, optional): Target rotation value. Defaults to 0.0.
-        nb_landmarks (int, optional): Number of landmarks. Defaults to 1.
-        heatmap_fun (str, optional): Specifies the type of heatmap function to use. Defaults to
-            'gaussian'. Possible values are 'gaussian' and 'laplacian'.
-        heatmap_size (tuple[int, int], optional): Size of the heatmap. Defaults to (512, 512).
-        gamma (Optional[float], optional): Gamma value for the Laplace distribution. Defaults to
-            1.0.
-        reduction (str, optional): Specifies the reduction to apply to the output. Defaults to
-            'mean'.
-        eps (float, optional): Epsilon value to avoid division by zero. Defaults to 1e-6.
-    """
-
-    # TODO: Implement generalized Gaussian distribution. (Currently only Gaussian and Laplace)
-
-    def __init__(
-        self,
-        alpha: float = 1.0,
-        sigma_t: float | torch.Tensor = 1.0,
-        rotation_t: float | torch.Tensor = 0.0,
-        nb_landmarks: int = 1,
-        heatmap_fun: str = "gaussian",
-        heatmap_size: tuple[int, ...] = (512, 512),
-        gamma: Optional[float] = 1.0,
-        reduction: str = "mean",
-        eps: float = 1e-6,
-    ) -> None:
-        super().__init__()
-        self.alpha = alpha
-        self.reduction = reduction
-        if reduction not in ["mean", "sum", "none"]:
-            raise ValueError(f"Invalid reduction: {reduction}")
-        self.eps = eps
-        self.heatmap_fun: HeatmapGenerator
-        self.spatial_dims = len(heatmap_size)
-        if self.spatial_dims != 2:
-            raise ValueError("Only 2D heatmaps are supported.")
-        if heatmap_fun == "gaussian":
-            self.heatmap_fun = GaussianHeatmapGenerator(
-                nb_landmarks=nb_landmarks,
-                sigmas=sigma_t,
-                rotation=rotation_t,
-                heatmap_size=heatmap_size,
-                gamma=gamma,
-            )
-        elif heatmap_fun == "laplacian":
-            self.heatmap_fun = LaplacianHeatmapGenerator(
-                nb_landmarks=nb_landmarks,
-                sigmas=sigma_t,
-                rotation=rotation_t,
-                heatmap_size=heatmap_size,
-                gamma=gamma,
-            )
-        else:
-            raise ValueError(f"Invalid heatmap function: {heatmap_fun}")
-
-    def forward(
-        self, pred: torch.Tensor, pred_heatmap: torch.Tensor, target: torch.Tensor
-    ) -> torch.Tensor:
-        """
-        Args:
-            pred (torch.Tensor): Predicted coordinates.
-            pred_heatmap (torch.Tensor): Predicted heatmap.
-            target (torch.Tensor): Target coordinates.
-        """
-        heatmap_t = self.heatmap_fun(target)
-        # Normalize heatmaps to sum to 1
-        heatmap_t = (heatmap_t + self.eps) / (heatmap_t + self.eps).sum(dim=(-2, -1), keepdim=True)
-        pred_heatmap = (pred_heatmap + self.eps) / (
-            pred_heatmap.sum(dim=(-2, -1), keepdim=True) + self.eps
-        )
-        reg = js_div_loss_2d(pred_heatmap, heatmap_t, reduction="none")
-        loss = _euclidean_distance(pred, target) + self.alpha * reg
-        if self.reduction == "mean":
-            return torch.mean(loss)
-        if self.reduction == "sum":
-            return torch.sum(loss)
-        return loss
-

class MultivariateGaussianNLLLoss(nn.Module):
    """
    Negative log-likelihood loss for multivariate Gaussian distributions. The loss function is
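The deleted class was the only call site of kornia's `js_div_loss_2d`, which is what lets the dependency go. For reference, the Jensen-Shannon term it computed follows directly from the docstring's formula and fits in a few lines of plain PyTorch. A minimal sketch, assuming `(B, C, H, W)` heatmaps already normalized to sum to 1 over the spatial dimensions, as the removed `forward` ensured; the function name `js_div_2d` is hypothetical. With `reduction="none"` it returns one value per landmark channel, shape `(B, C)`, matching the deleted tests' `pred.shape[:-1]` expectation:

```python
import torch


def js_div_2d(p: torch.Tensor, q: torch.Tensor, reduction: str = "none",
              eps: float = 1e-24) -> torch.Tensor:
    """JSD(P||Q) = 0.5 * KL(P||M) + 0.5 * KL(Q||M) with M = 0.5 * (P + Q)."""
    m = 0.5 * (p + q)

    def kl(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        # KL divergence summed over the spatial dimensions (H, W).
        return (a * (torch.log(a + eps) - torch.log(b + eps))).sum(dim=(-2, -1))

    loss = 0.5 * kl(p, m) + 0.5 * kl(q, m)
    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    return loss
```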
98 changes: 1 addition & 97 deletions tests/test_loss.py
@@ -3,12 +3,8 @@
import torch
import torch.nn as nn

-from src.landmarker.heatmap.generator import (
-    GaussianHeatmapGenerator,
-    LaplacianHeatmapGenerator,
-)
+from src.landmarker.heatmap.generator import GaussianHeatmapGenerator
from src.landmarker.losses.losses import (
-    EuclideanDistanceJSDivergenceReg,
    EuclideanDistanceVarianceReg,
    GeneralizedNormalHeatmapLoss,
    MultivariateGaussianNLLLoss,
@@ -349,98 +345,6 @@ def test_euclidean_distance_variance_reg_3d():
        assert True


-def test_euclidean_distance_js_divergence_reg():
-    """Test the EuclideanDistanceJSDivergenceReg class."""
-    reduction = "mean"
-    # pred = torch.rand(1, 3, 2) * 64
-    # target = torch.rand(1, 3, 2) * 64
-    pred = torch.ones((1, 3, 2)) * (64 // 2)
-    target = torch.ones((1, 3, 2)) * (64 // 2 - 5)
-
-    heatmap_generator = GaussianHeatmapGenerator(3, sigmas=3, heatmap_size=(64, 64), gamma=1.0)
-    heatmap = heatmap_generator(pred)
-    heatmap_target = heatmap_generator(target)
-
-    loss_fn = EuclideanDistanceJSDivergenceReg(
-        reduction=reduction, heatmap_size=(64, 64), gamma=1.0, sigma_t=3, rotation_t=0.0
-    )
-    expected_output_shape = torch.Size([])
-
-    loss = loss_fn(target, heatmap_target, target)
-
-    assert loss.shape == expected_output_shape
-    assert torch.allclose(loss, torch.zeros_like(loss), atol=1e-5)
-
-    loss = loss_fn(pred, heatmap, target)
-    assert torch.all(loss > 0)
-
-    reduction = "mean"
-    # pred = torch.rand(1, 3, 2) * 64
-    # target = torch.rand(1, 3, 2) * 64
-    pred = torch.ones((1, 3, 2)) * (64 // 2)
-    target = torch.ones((1, 3, 2)) * (64 // 2 - 5)
-
-    heatmap_generator = GaussianHeatmapGenerator(3, sigmas=3, heatmap_size=(64, 64), gamma=1.0)
-    heatmap = heatmap_generator(pred)
-    heatmap_target = heatmap_generator(target)
-
-    loss_fn = EuclideanDistanceJSDivergenceReg(
-        reduction=reduction, heatmap_size=(64, 64), gamma=1.0, sigma_t=3, rotation_t=0.0
-    )
-    expected_output_shape = torch.Size([])
-
-    loss = loss_fn(target, heatmap_target, target)
-
-    assert loss.shape == expected_output_shape
-    assert torch.allclose(loss, torch.zeros_like(loss), atol=1e-5)
-
-    loss = loss_fn(pred, heatmap, target)
-    assert torch.all(loss > 0)
-
-    reduction = "none"
-    # pred = torch.rand(1, 3, 2) * 64
-    # target = torch.rand(1, 3, 2) * 64
-    pred = torch.ones((1, 3, 2)) * (64 // 2)
-    target = torch.ones((1, 3, 2)) * (64 // 2 - 5)
-
-    heatmap_generator = LaplacianHeatmapGenerator(3, sigmas=3, heatmap_size=(64, 64), gamma=1.0)
-    heatmap = heatmap_generator(pred)
-    heatmap_target = heatmap_generator(target)
-
-    loss_fn = EuclideanDistanceJSDivergenceReg(
-        reduction=reduction,
-        heatmap_size=(64, 64),
-        gamma=1.0,
-        sigma_t=3,
-        rotation_t=0.0,
-        heatmap_fun="laplacian",
-    )
-    expected_output_shape = pred.shape[:-1]
-
-    loss = loss_fn(target, heatmap_target, target)
-
-    assert loss.shape == expected_output_shape
-    assert torch.allclose(loss, torch.zeros_like(loss), atol=1e-5)
-
-    loss = loss_fn(pred, heatmap, target)
-    assert torch.all(loss > 0)
-
-
-def test_euclidean_distance_js_divergence_reg_3d():
-    """Test the EuclideanDistanceJSDivergenceReg class. For 3D inputs."""
-    try:
-        EuclideanDistanceJSDivergenceReg(
-            reduction="mean",
-            heatmap_size=(64, 64, 64),
-            gamma=1.0,
-            sigma_t=3,
-            rotation_t=0.0,
-        )
-        assert False
-    except ValueError:
-        assert True


def test_star_loss():
"""Test the StarLoss class.""" ""
reduction = "mean"
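For completeness, a hedged sketch of how the deleted zero-loss check could be reproduced against a local replacement such as the `js_div_2d` sketch above. Shapes follow the removed test (three landmarks on a 64x64 grid); it assumes landmarker is installed as a package, whereas the tests import via `src.`:

```python
import torch

from landmarker.heatmap.generator import GaussianHeatmapGenerator

target = torch.ones((1, 3, 2)) * (64 // 2 - 5)
generator = GaussianHeatmapGenerator(3, sigmas=3, heatmap_size=(64, 64), gamma=1.0)

# Normalize each heatmap channel to a probability distribution over (H, W).
heatmap = generator(target)
heatmap = heatmap / heatmap.sum(dim=(-2, -1), keepdim=True)

# The JS divergence of a distribution with itself is zero.
loss = js_div_2d(heatmap, heatmap, reduction="mean")
assert torch.allclose(loss, torch.zeros_like(loss), atol=1e-5)
```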
