代码拉取完成,页面将自动刷新
import torch
import torch.nn as nn
import torch.nn.functional as F
class FocalLoss(nn.Module):
def __init__(self, alpha=0.25, gamma=2, reduction='mean'):
"""
Focal Loss, ref papre:
http://openaccess.thecvf.com/content_ICCV_2017/papers/Lin_Focal_Loss_for_ICCV_2017_paper.pdf
:param alpha: float, weight, detail ref paper above
:param gamma: float, detail ref paper above
:param reduction: str, 'mean'|'sum'|'none', reduction type
"""
super(FocalLoss, self).__init__()
self.alpha = alpha
self.gamma = gamma
self.reduction = reduction
def forward(self, output, target):
"""
:param output: Tensor, shape (batch_size, class_nums, )
:param target: Tensor, shape (batch_size, )
:return: Tensor
"""
if output.dim != target.dim:
# convert target to onehot
target = torch.zeros_like(output).scatter_(1, target.unsqueeze(1), 1)
# convert output to presudo probability
p = output.sigmoid()
a = target * self.alpha + (1 - target) * (1 - self.alpha)
pt = 1 - (target * p + (1 - target) * (1 - p))
focal_weight = (a * pt.pow(self.gamma)).detach()
focal_loss = F.binary_cross_entropy(p, target, weight=focal_weight, reduction=self.reduction)
return focal_loss
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。