From e001d34c0c6ebd73bd452c7fd2152e45c1a21db5 Mon Sep 17 00:00:00 2001 From: zwl <15878005+zwl90188229@user.noreply.gitee.com> Date: Tue, 15 Jul 2025 12:49:33 +0000 Subject: [PATCH] add StreamLearn/Algorithm/PAA/paa_proto.py. Signed-off-by: zwl <15878005+zwl90188229@user.noreply.gitee.com> --- StreamLearn/Algorithm/PAA/paa_proto.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) create mode 100644 StreamLearn/Algorithm/PAA/paa_proto.py diff --git a/StreamLearn/Algorithm/PAA/paa_proto.py b/StreamLearn/Algorithm/PAA/paa_proto.py new file mode 100644 index 0000000..af9af97 --- /dev/null +++ b/StreamLearn/Algorithm/PAA/paa_proto.py @@ -0,0 +1,20 @@ +import torch +import numpy as np + +class OnlinePrototypeLearning: + def __init__(self, ema_alpha=0.9): + self.ema_alpha = ema_alpha + self.prototypes = None # 存储类原型:shape=[num_classes, feat_dim] + + def update_prototypes(self, feats, labels): + """根据当前batch的特征和标签更新原型""" + feats = feats.detach().cpu().numpy() + labels = labels.detach().cpu().numpy() + for c in np.unique(labels): + mask = (labels == c) + class_feats = feats[mask] + if self.prototypes is None: + self.prototypes = {c: np.mean(class_feats, axis=0)} + else: + # EMA更新:原型 = alpha*旧原型 + (1-alpha)*新均值 + self.prototypes[c] = self.ema_alpha * self.prototypes[c] + (1 - self.ema_alpha) * np.mean(class_feats, axis=0) \ No newline at end of file -- Gitee