diff --git a/StreamLearn/Algorithm/PAA/paa_proto.py b/StreamLearn/Algorithm/PAA/paa_proto.py new file mode 100644 index 0000000000000000000000000000000000000000..af9af978e9856baca1ba5b10e37cda175b2b724e --- /dev/null +++ b/StreamLearn/Algorithm/PAA/paa_proto.py @@ -0,0 +1,20 @@ +import torch +import numpy as np + +class OnlinePrototypeLearning: + def __init__(self, ema_alpha=0.9): + self.ema_alpha = ema_alpha + self.prototypes = None # 存储类原型:shape=[num_classes, feat_dim] + + def update_prototypes(self, feats, labels): + """根据当前batch的特征和标签更新原型""" + feats = feats.detach().cpu().numpy() + labels = labels.detach().cpu().numpy() + for c in np.unique(labels): + mask = (labels == c) + class_feats = feats[mask] + if self.prototypes is None: + self.prototypes = {c: np.mean(class_feats, axis=0)} + else: + # EMA更新:原型 = alpha*旧原型 + (1-alpha)*新均值 + self.prototypes[c] = self.ema_alpha * self.prototypes[c] + (1 - self.ema_alpha) * np.mean(class_feats, axis=0) \ No newline at end of file