PyTorch implementation of focal loss for multi-class semantic segmentation.
To use the alpha-balanced form of focal loss, you need to do two things:
- Prepare an alpha value (weight) for each class.
- Swap which line is commented out in the code, as below:
focal_loss = self.alpha[targets] * (1 - pt)**self.gamma * ce_loss
#focal_loss = (1 - pt) ** self.gamma * ce_loss
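For reference, the two lines above match the usual definitions of focal loss. Here p_t is the predicted probability of the true class at a pixel, so ce_loss = -log(p_t) and pt = exp(-ce_loss), and alpha_t is the weight of that pixel's true class (self.alpha[targets]). This reading assumes ce_loss is the unreduced, per-pixel cross-entropy, which the snippet does not show.

$$
\mathrm{CE}(p_t) = -\log(p_t), \qquad p_t = \exp\big(-\mathrm{CE}(p_t)\big)
$$

$$
\mathrm{FL}(p_t) = -(1 - p_t)^{\gamma}\,\log(p_t) = (1 - p_t)^{\gamma}\,\mathrm{CE}(p_t)
$$

$$
\mathrm{FL}_{\alpha}(p_t) = -\alpha_t\,(1 - p_t)^{\gamma}\,\log(p_t) = \alpha_t\,(1 - p_t)^{\gamma}\,\mathrm{CE}(p_t)
$$

With gamma = 0 and no alpha, the focal loss reduces to plain cross-entropy.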
...
fn_loss = FocalLoss()
pred = model(x)
loss = fn_loss(pred, target)
...
...
class_weights = [...]  # one alpha value per class
fn_loss = FocalLoss(alpha=class_weights)
pred = model(x)
loss = fn_loss(pred, target)
...
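The usage snippets above call a FocalLoss class whose body is not shown here. The sketch below is one possible PyTorch module consistent with them, not the original code. The constructor arguments gamma (default 2.0) and reduction are assumptions, and alpha is converted to a tensor so that it can be indexed by the per-pixel target map. Instead of toggling commented-out lines, this version applies the alpha weighting only when alpha is given.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FocalLoss(nn.Module):
    """Hypothetical sketch of a multi-class focal loss for segmentation.

    pred:   raw logits, shape (N, C, H, W)
    target: class index per pixel, shape (N, H, W), dtype long
    alpha:  optional per-class weights (length C); None -> plain focal loss
    gamma, reduction: assumed parameters, not taken from the original code
    """

    def __init__(self, gamma=2.0, alpha=None, reduction="mean"):
        super().__init__()
        self.gamma = gamma
        # Keep alpha as a tensor so alpha[target] works with a target map.
        self.alpha = None if alpha is None else torch.as_tensor(alpha, dtype=torch.float32)
        self.reduction = reduction

    def forward(self, pred, target):
        # Unreduced cross-entropy per pixel, i.e. -log(p_t), shape (N, H, W).
        ce_loss = F.cross_entropy(pred, target, reduction="none")
        # Probability the model assigns to the true class of each pixel.
        pt = torch.exp(-ce_loss)
        focal_loss = (1 - pt) ** self.gamma * ce_loss
        if self.alpha is not None:
            # alpha_t: look up the weight of each pixel's true class.
            alpha = self.alpha.to(device=pred.device, dtype=pred.dtype)
            focal_loss = alpha[target] * focal_loss
        if self.reduction == "mean":
            return focal_loss.mean()
        if self.reduction == "sum":
            return focal_loss.sum()
        return focal_loss  # reduction="none": per-pixel loss map


# Dummy data standing in for model(x) and the ground-truth mask.
torch.manual_seed(0)
pred = torch.randn(2, 3, 4, 4)            # 2 images, 3 classes, 4x4 pixels
target = torch.randint(0, 3, (2, 4, 4))   # class index per pixel

fn_loss = FocalLoss()
loss = fn_loss(pred, target)

class_weights = [0.25, 0.5, 0.25]         # arbitrary example alphas, one per class
fn_loss_alpha = FocalLoss(alpha=class_weights)
loss_alpha = fn_loss_alpha(pred, target)
```

Because 0 <= 1 - p_t <= 1, each pixel's focal term is never larger than its cross-entropy, so well-classified pixels are down-weighted while hard ones keep most of their loss.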
Please also visit my colleague's GitHub!
https://github.com/jinsoo9595