1 Star 0 Fork 1

edata/DeepLearning-MachineLearning-Note

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
贡献代码
同步代码
取消
提示: 由于 Git 不支持空文件夾,创建文件夹后会生成空的 .keep 文件
Loading...
README

Focal Loss(FL)理解和实现

FL对处理样本不平衡问题效果不错,这里对它做一个详细的解释,并且用pytorch实现,详细的解释请查看我的博客

  • Pytorch Focal Loss
  • Light Focal Loss
  • Catboost Focal Loss

代码解释

参考两个代码,目前来看,这里的实现应该是最简单的,效果与mmdetection一样,最新实现请参考源代码。

  • pytorch 实现
class FocalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, output, target):
        if output.dim != target.dim:
            # convert target to onehot
            target = torch.zeros_like(output).scatter_(1, target.unsqueeze(1), 1)
        # convert output to presudo probability
        p = output.sigmoid()
        a = target * self.alpha + (1 - target) * (1 - self.alpha)
        pt = 1 - (target * p + (1 - target) * (1 - p))
        focal_weight = a * pt.pow(self.gamma)

        focal_loss = F.binary_cross_entropy(p, target, weight=focal_weight, reduction=self.reduction)
        return focal_loss
  • lightgbm实现
def lgb_focal_loss(y_pred, dtrain, alpha, gamma):
    targets = dtrain.label.reshape(-1)
    n_classes = len(np.unique(targets))
    # encode targets to one hot
    t = np.eye(n_classes)[targets]

    def focal_loss(x):
        # sigmoid => pseudo probability
        p = 1 / (1 + np.exp(-x))
        a = t * alpha + (1 - t) * (1 - alpha)
        g_coef = (1 - (t * p + (1 - t) * (1 - p))) ** gamma
        bce = t * np.log(p) + (1 - t) * np.log(1 - p)
        return - (a * g_coef * bce).sum(axis=1)

    grad = derivative(focal_loss, y_pred, n=1, dx=1e-6)
    hess = derivative(focal_loss, y_pred, n=2, dx=1e-6)
    return grad, hess
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/edata-code/DeepLearning-MachineLearning-Note.git
git@gitee.com:edata-code/DeepLearning-MachineLearning-Note.git
edata-code
DeepLearning-MachineLearning-Note
DeepLearning-MachineLearning-Note
master

搜索帮助