Focal Loss (FL) works well for class-imbalance problems. This repository explains it in detail and implements it in PyTorch; for the full write-up, please see my blog.
Two existing implementations were used as references. As far as I can tell, the version here is the simplest, and its results match mmdetection's. For the most up-to-date implementation, please refer to the source code.
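For reference, both implementations below compute the standard focal loss introduced by Lin et al. (RetinaNet), applied per class in one-vs-rest (sigmoid) form:

```latex
\mathrm{FL}(p_t) = -\alpha_t \, (1 - p_t)^{\gamma} \, \log(p_t),
\qquad
p_t = \begin{cases} p & y = 1 \\ 1 - p & y = 0 \end{cases},
\qquad
\alpha_t = \begin{cases} \alpha & y = 1 \\ 1 - \alpha & y = 0 \end{cases}
```

Here `p` is the sigmoid of the raw score, `gamma` down-weights easy examples, and `alpha` balances positives against negatives.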
import torch
import torch.nn as nn
import torch.nn.functional as F


class FocalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, output, target):
        if output.dim() != target.dim():
            # convert class-index target to one-hot
            target = torch.zeros_like(output).scatter_(1, target.unsqueeze(1), 1)
        # convert logits to pseudo probability
        p = output.sigmoid()
        # alpha_t: alpha for positives, 1 - alpha for negatives
        a = target * self.alpha + (1 - target) * (1 - self.alpha)
        # 1 - p_t: probability mass on the wrong side
        pt = 1 - (target * p + (1 - target) * (1 - p))
        focal_weight = a * pt.pow(self.gamma)
        focal_loss = F.binary_cross_entropy(p, target, weight=focal_weight,
                                            reduction=self.reduction)
        return focal_loss
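As a quick sanity check, with gamma = 0 and alpha = 0.5 the focal weight becomes a constant 0.5, so the loss should collapse to half of plain binary cross-entropy. The sketch below re-implements the same element-wise formula with raw tensor ops (my own throwaway variable names, not part of the class above) and verifies that identity:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 3)            # batch of 4, 3 classes
labels = torch.tensor([0, 2, 1, 2])
onehot = torch.zeros_like(logits).scatter_(1, labels.unsqueeze(1), 1)

p = logits.sigmoid()
alpha, gamma = 0.5, 0.0
a = onehot * alpha + (1 - onehot) * (1 - alpha)   # constant 0.5 here
pt = 1 - (onehot * p + (1 - onehot) * (1 - p))
focal = F.binary_cross_entropy(p, onehot, weight=a * pt.pow(gamma))

plain = 0.5 * F.binary_cross_entropy(p, onehot)
assert torch.allclose(focal, plain)   # gamma=0, alpha=0.5 => 0.5 * BCE
```

With gamma > 0 the weight `pt ** gamma` shrinks toward zero for well-classified examples, which is the whole point of the loss.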
import numpy as np
from scipy.misc import derivative  # deprecated / removed in recent SciPy


def lgb_focal_loss(y_pred, dtrain, alpha, gamma):
    # custom objective for LightGBM; assumes y_pred already has shape
    # (n_samples, n_classes) of raw (pre-sigmoid) scores
    targets = dtrain.get_label().astype(int).reshape(-1)
    n_classes = len(np.unique(targets))
    # encode targets as one-hot
    t = np.eye(n_classes)[targets]

    def focal_loss(x):
        # sigmoid => pseudo probability
        p = 1 / (1 + np.exp(-x))
        p = np.clip(p, 1e-7, 1 - 1e-7)  # avoid log(0)
        a = t * alpha + (1 - t) * (1 - alpha)
        g_coef = (1 - (t * p + (1 - t) * (1 - p))) ** gamma
        bce = t * np.log(p) + (1 - t) * np.log(1 - p)
        # keep the loss element-wise (no sum over classes): each entry
        # depends only on its own logit, so the finite differences below
        # recover per-logit gradients
        return -(a * g_coef * bce)

    grad = derivative(focal_loss, y_pred, n=1, dx=1e-6)
    hess = derivative(focal_loss, y_pred, n=2, dx=1e-6)
    return grad, hess
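Because `derivative` shifts every element of the input at once, the inner loss must stay element-wise for the finite difference to mean anything per logit. The self-contained check below (my own helper name `focal_elementwise`, not from the code above) uses a manual central difference and compares it against the closed-form gradient at alpha = 0.5, gamma = 0, where the focal loss reduces to half of BCE and its gradient to `0.5 * (p - t)`:

```python
import numpy as np

def focal_elementwise(x, t, alpha=0.5, gamma=0.0):
    # element-wise focal loss on logits x against one-hot targets t
    p = 1.0 / (1.0 + np.exp(-x))
    a = t * alpha + (1 - t) * (1 - alpha)
    g = (1 - (t * p + (1 - t) * (1 - p))) ** gamma
    bce = t * np.log(p) + (1 - t) * np.log(1 - p)
    return -(a * g * bce)

rng = np.random.default_rng(0)
x = rng.normal(size=(5, 3))
labels = rng.integers(0, 3, size=5)
t = np.eye(3)[labels]

# central finite difference; valid because each loss element depends
# only on its own logit
eps = 1e-6
grad = (focal_elementwise(x + eps, t) - focal_elementwise(x - eps, t)) / (2 * eps)

p = 1.0 / (1.0 + np.exp(-x))
analytic = 0.5 * (p - t)          # closed form at alpha=0.5, gamma=0
assert np.allclose(grad, analytic, atol=1e-5)
```

Had the inner loss summed over classes, the same difference quotient would return one number per row (the sum of that row's partials) instead of one gradient per logit, which is why the `.sum(axis=1)` must not be there.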