forked from ternaus/robot-surgery-segmentation
-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathloss.py
61 lines (48 loc) · 2.04 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import torch
from torch import nn
from torch.nn import functional as F
import utils
import numpy as np
class LossBinary:
    """Binary segmentation loss: BCE-with-logits minus a weighted log soft-Jaccard.

    loss = BCEWithLogits(outputs, targets) - jaccard_weight * log(J), where J is
    the soft Jaccard index between sigmoid(outputs) and the mask ``targets == 1``.

    Reference: Vladimir Iglovikov, Sergey Mushinskiy, Vladimir Osin,
    "Satellite Imagery Feature Detection using Deep Convolutional Neural Network:
    A Kaggle Competition", arXiv:1706.06169.
    """

    def __init__(self, jaccard_weight=0):
        """
        Args:
            jaccard_weight: multiplier on the -log(Jaccard) term. A falsy value
                (the default 0) disables that term, leaving plain BCE.
        """
        self.nll_loss = nn.BCEWithLogitsLoss()
        self.jaccard_weight = jaccard_weight

    def __call__(self, outputs, targets):
        """Return the combined loss as a scalar tensor.

        Args:
            outputs: raw logits; BCEWithLogitsLoss consumes them directly and the
                Jaccard term applies sigmoid itself.
            targets: ground truth with the same shape as ``outputs`` (required by
                BCEWithLogitsLoss). Only entries equal to 1 count as positives in
                the Jaccard term.
        """
        bce = self.nll_loss(outputs, targets)
        if not self.jaccard_weight:
            return bce

        # Smoothing on both sides keeps the ratio finite, and equal to 1 when
        # prediction and target are both empty.
        smooth = 1e-15
        probs = torch.sigmoid(outputs)
        positives = targets.eq(1).float()
        overlap = torch.sum(probs * positives)
        total = torch.sum(probs) + torch.sum(positives)
        soft_jaccard = (overlap + smooth) / (total - overlap + smooth)
        return bce - self.jaccard_weight * torch.log(soft_jaccard)
class LossMulti:
    """Multi-class segmentation loss: NLL blended with a per-class soft-Jaccard term.

    total = (NLL + jaccard_weight * mean_c(1 - J_c)) / (1 + jaccard_weight)

    where J_c is the soft Jaccard index between exp(outputs[:, c]) and the mask
    ``targets == c``. Dividing by (1 + jaccard_weight) keeps the two terms a
    convex combination.
    """

    def __init__(self, jaccard_weight=0, class_weights=None, num_classes=1):
        """
        Args:
            jaccard_weight: total weight of the Jaccard term. A falsy value
                (the default 0) disables it, leaving plain NLL.
            class_weights: optional per-class rescaling weights for the NLL
                term. Any array-like (ndarray, list, tuple) is accepted and
                copied to float32.
            num_classes: number of classes iterated over in the Jaccard term.
        """
        if class_weights is not None:
            # np.array copies (like the former .astype did) and also accepts
            # plain Python sequences, not only ndarrays.
            nll_weight = utils.cuda(
                torch.from_numpy(np.array(class_weights, dtype=np.float32)))
        else:
            nll_weight = None
        # nn.NLLLoss handles (N, C, H, W) inputs directly; NLLLoss2d is a
        # deprecated alias for it that only adds a warning.
        self.nll_loss = nn.NLLLoss(weight=nll_weight)
        self.jaccard_weight = jaccard_weight
        self.num_classes = num_classes

    def __call__(self, outputs, targets):
        """Return the combined loss as a scalar tensor.

        Args:
            outputs: per-class scores along dim 1. They are expected to be
                log-probabilities (e.g. from log_softmax), since they feed
                NLLLoss and are exponentiated for the Jaccard term.
            targets: integer class indices, compared against each class id.
        """
        loss = self.nll_loss(outputs, targets)

        if self.jaccard_weight:
            eps = 1e-15
            cls_weight = self.jaccard_weight / self.num_classes
            for cls in range(self.num_classes):
                jaccard_target = (targets == cls).float()
                jaccard_output = outputs[:, cls].exp()
                intersection = (jaccard_output * jaccard_target).sum()
                union = jaccard_output.sum() + jaccard_target.sum()
                # eps in numerator AND denominator: an absent class that is also
                # predicted as empty is a perfect match (J = 1), not J = 0.
                jaccard = (intersection + eps) / (union - intersection + eps)
                loss = loss + (1 - jaccard) * cls_weight
            loss = loss / (1 + self.jaccard_weight)
        return loss