diff --git a/StreamLearn/Algorithm/PAA/loss.py b/StreamLearn/Algorithm/PAA/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..6c8937e45952337de7be33365cb207640f670ab1 --- /dev/null +++ b/StreamLearn/Algorithm/PAA/loss.py @@ -0,0 +1,16 @@ +import torch +import torch.nn.functional as F + +def prototype_contrast_loss(feats, prototypes, temp=0.1): + """原型对比损失(区分新旧类)""" + feats = F.normalize(feats, dim=1) + prototypes = F.normalize(torch.tensor(list(prototypes.values())), dim=1) + logits = torch.mm(feats, prototypes.t()) / temp + labels = torch.arange(len(feats), device=feats.device) # 假设每个样本对应自身原型(需根据实际调整) + return F.cross_entropy(logits, labels) + +def consistency_loss(logits1, logits2): + """双分类器预测一致性损失""" + prob1 = F.softmax(logits1, dim=1) + prob2 = F.softmax(logits2, dim=1) + return torch.mean(torch.sum(torch.abs(prob1 - prob2), dim=1)) \ No newline at end of file