edata/DeepLearning-MachineLearning-Note
focal_loss_pytorch.py

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FocalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2, reduction='mean'):
        """
        Focal Loss, see the paper:
        http://openaccess.thecvf.com/content_ICCV_2017/papers/Lin_Focal_Loss_for_ICCV_2017_paper.pdf
        :param alpha: float, weight of positive targets (negatives get 1 - alpha); see the paper above
        :param gamma: float, focusing parameter; see the paper above
        :param reduction: str, 'mean'|'sum'|'none', reduction type
        """
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, output, target):
        """
        :param output: Tensor of raw logits, shape (batch_size, class_nums)
        :param target: Tensor of class indices, shape (batch_size,),
                       or a one-hot float Tensor with the same shape as output
        :return: Tensor
        """
        if output.dim() != target.dim():
            # target holds class indices (one dim fewer than output): convert to one-hot
            target = torch.zeros_like(output).scatter_(1, target.long().unsqueeze(1), 1.0)
        target = target.type_as(output)
        # convert logits to pseudo-probabilities
        p = output.sigmoid()
        # alpha_t: alpha for positive targets, 1 - alpha for negative targets
        a = target * self.alpha + (1 - target) * (1 - self.alpha)
        # 1 - p_t, where p_t = p for positive targets and 1 - p for negative targets
        one_minus_pt = 1 - (target * p + (1 - target) * (1 - p))
        # the weight is detached, so gradients flow only through the BCE term
        focal_weight = (a * one_minus_pt.pow(self.gamma)).detach()
        # same value as F.binary_cross_entropy(p, target, ...), but numerically more stable
        focal_loss = F.binary_cross_entropy_with_logits(
            output, target, weight=focal_weight, reduction=self.reduction)
        return focal_loss
```
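For reference, this is the per-element quantity the class computes, written in the notation of the cited paper (Lin et al., ICCV 2017), with logit x, p = σ(x), and binary target y ∈ {0, 1}. In the code, `a` corresponds to α_t, the tensor raised to `gamma` is 1 − p_t, and the binary cross-entropy term supplies −log p_t:

```latex
p_t = \begin{cases} p, & y = 1 \\ 1 - p, & y = 0 \end{cases}
\qquad
\alpha_t = \begin{cases} \alpha, & y = 1 \\ 1 - \alpha, & y = 0 \end{cases}

\mathrm{FL}(p_t) = -\alpha_t \, (1 - p_t)^{\gamma} \log(p_t)
```

Because the weight α_t (1 − p_t)^γ is detached before being passed to the BCE call, the loss value matches FL(p_t), but the gradient is that of a weighted BCE with a constant weight rather than the exact derivative of FL(p_t).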
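The effect of the modulating factor (1 − p_t)^γ can be checked numerically. The following is a minimal pure-Python sketch (no PyTorch). The helper names `focal_term` and `cross_entropy_term` are hypothetical and operate on a single true-class probability p_t rather than on logits:

```python
import math


def focal_term(pt, alpha_t=1.0, gamma=2.0):
    """Per-example focal loss -alpha_t * (1 - pt)**gamma * log(pt) for true-class probability pt."""
    return -alpha_t * (1.0 - pt) ** gamma * math.log(pt)


def cross_entropy_term(pt):
    """Plain cross-entropy -log(pt) for true-class probability pt."""
    return -math.log(pt)


# Compare CE and FL for hard (low p_t) and easy (high p_t) examples.
for pt in (0.1, 0.5, 0.9, 0.968):
    ce = cross_entropy_term(pt)
    fl = focal_term(pt, gamma=2.0)
    print(f"p_t={pt:.3f}  CE={ce:.4f}  FL={fl:.6f}  CE/FL={ce / fl:.1f}")
```

With α_t = 1 the ratio CE/FL is exactly 1 / (1 − p_t)^γ. For example, with γ = 2 an easy example with p_t = 0.9 contributes 100× less loss than under plain cross-entropy, while a hard example with p_t = 0.1 is barely down-weighted. Setting γ = 0 recovers α-weighted cross-entropy.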