1 Star 0 Fork 1

atari/BiSeNet

加入 Gitee
与超过 1400万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
ohem_ce_loss.py 1.41 KB
一键复制 编辑 原始数据 按行查看 历史
CoinCheung 提交于 2022-06-25 22:28 +08:00 . Dev (#240)
#!/usr/bin/python
# -*- encoding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
# import ohem_cpp
# class OhemCELoss(nn.Module):
#
# def __init__(self, thresh, lb_ignore=255):
# super(OhemCELoss, self).__init__()
# self.score_thresh = thresh
# self.lb_ignore = lb_ignore
# self.criteria = nn.CrossEntropyLoss(ignore_index=lb_ignore, reduction='mean')
#
# def forward(self, logits, labels):
# n_min = labels[labels != self.lb_ignore].numel() // 16
# labels = ohem_cpp.score_ohem_label(
# logits, labels, self.lb_ignore, self.score_thresh, n_min).detach()
# loss = self.criteria(logits, labels)
# return loss
class OhemCELoss(nn.Module):
    """Cross-entropy loss with online hard example mining (OHEM).

    The per-element cross-entropy is computed first; only "hard" elements,
    i.e. those whose loss exceeds ``-log(thresh)``, contribute to the final
    mean.  Because cross-entropy equals ``-log(p_true)``, this keeps the
    elements whose predicted probability for the true class is below
    ``thresh``.  If fewer than ``n_min`` elements are hard, the ``n_min``
    largest losses are averaged instead, where ``n_min`` is one sixteenth of
    the number of non-ignored labels.

    Args:
        thresh: Probability threshold; stored internally as ``-log(thresh)``.
        lb_ignore: Label value excluded from the loss (default 255).
    """

    def __init__(self, thresh, lb_ignore=255):
        super().__init__()
        # Kept as a CPU scalar so the module can be built on machines without
        # CUDA; it is moved to the loss's device in ``forward``.
        self.thresh = -torch.log(
            torch.tensor(thresh, requires_grad=False, dtype=torch.float))
        self.lb_ignore = lb_ignore
        self.criteria = nn.CrossEntropyLoss(
            ignore_index=lb_ignore, reduction='none')

    def forward(self, logits, labels):
        """Return the mean loss over the hard examples as a scalar tensor.

        Args:
            logits: Raw class scores accepted by ``nn.CrossEntropyLoss``
                (presumably ``(N, C, H, W)`` for segmentation -- confirm
                against the caller).
            labels: Integer class targets matching ``logits``; entries equal
                to ``lb_ignore`` are ignored.

        Returns:
            A scalar tensor.  When there is nothing to average (no hard
            element and ``n_min == 0``, e.g. every label is ignored) a zero
            that is still attached to the graph is returned instead of NaN.
        """
        # Lower bound on how many elements take part in the mean.
        n_min = labels[labels != self.lb_ignore].numel() // 16
        loss = self.criteria(logits, labels).view(-1)

        # Follow the input's device instead of assuming CUDA.
        thresh = self.thresh.to(loss.device)
        loss_hard = loss[loss > thresh]
        if loss_hard.numel() < n_min:
            # Too few hard examples: fall back to the n_min largest losses.
            loss_hard, _ = loss.topk(n_min)

        if loss_hard.numel() == 0:
            # torch.mean of an empty tensor is NaN, which would poison
            # training; return a graph-connected zero instead.
            return loss.sum() * 0.0
        return torch.mean(loss_hard)
if __name__ == '__main__':
    # Intentionally a no-op: running this module as a script does nothing.
    pass
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/atari/BiSeNet.git
git@gitee.com:atari/BiSeNet.git
atari
BiSeNet
BiSeNet
master

搜索帮助