From 55e6f944826669ea2a7b111457a81c21b81f2813 Mon Sep 17 00:00:00 2001 From: zwl <15878005+zwl90188229@user.noreply.gitee.com> Date: Tue, 15 Jul 2025 12:41:34 +0000 Subject: [PATCH] add StreamLearn/Algorithm/PAA/PAA.py. Signed-off-by: zwl <15878005+zwl90188229@user.noreply.gitee.com> --- StreamLearn/Algorithm/PAA/PAA.py | 79 ++++++++++++++++++++++++++++++++ 1 file changed, 79 insertions(+) create mode 100644 StreamLearn/Algorithm/PAA/PAA.py diff --git a/StreamLearn/Algorithm/PAA/PAA.py b/StreamLearn/Algorithm/PAA/PAA.py new file mode 100644 index 0000000..7a07eac --- /dev/null +++ b/StreamLearn/Algorithm/PAA/PAA.py @@ -0,0 +1,79 @@ +from StreamLearn.Base.SemiEstimator import StreamAlgorithm +from .paa_model import PAAModel +from .paa_buffer import PAABuffer +from .paa_proto import OnlinePrototypeLearning +from .loss import prototype_contrast_loss, consistency_loss +import torch + +class PAA(StreamAlgorithm): + def __init__(self, args): + super().__init__() + self.args = args + self.name = 'PAA' + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + # 初始化模型、缓冲区、原型学习 + self.model = PAAModel(args.num_classes) + self.model.to(self.device) + self.buffer = PAABuffer(args.buffer_size) + self.proto_learner = OnlinePrototypeLearning(args.ema_alpha) + + # 优化器和损失 + self.optimizer = torch.optim.Adam(self.model.parameters(), lr=args.lr_paa) + self.criterion_ce = torch.nn.CrossEntropyLoss() + + def stream_fit(self, data_batch): + """单次流数据训练(核心逻辑)""" + images = data_batch['images'].to(self.device) + labels = data_batch['labels'].to(self.device) + + # 前向传播:特征 + 双分类预测 + feats, logits1, logits2 = self.model(images) + + # 1. 缓冲区去噪 + clean_data = self.buffer.denoise( + feats.detach().cpu(), + labels.detach().cpu(), + self.args.lambda1, + self.args.lambda2, + self.args.knn_k + ) + clean_feats = torch.tensor(clean_data['feats']).to(self.device) + clean_labels = torch.tensor(clean_data['labels']).to(self.device) + + # 2. 计算损失(CE + 一致性 + 原型对比) + loss_ce = self.criterion_ce(logits1, labels) + loss_cons = consistency_loss(logits1, logits2) + loss_proto = prototype_contrast_loss(clean_feats, self.proto_learner.prototypes, self.args.temp1) + total_loss = loss_ce + loss_cons + loss_proto + + # 3. 反向传播更新模型 + self.optimizer.zero_grad() + total_loss.backward() + self.optimizer.step() + + # 4. 更新原型和缓冲区 + self.proto_learner.update_prototypes(clean_feats, clean_labels) + self.buffer.update(clean_data) + + def stream_evaluate(self, data_batch): + """评估准确率""" + with torch.no_grad(): + images = data_batch['images'].to(self.device) + labels = data_batch['labels'].to(self.device) + _, logits1, _ = self.model(images) + preds = torch.argmax(logits1, dim=1) + acc = (preds == labels).float().mean() + return acc.item() + + def fit(self, stream_dataset): + """完整流训练(循环run_time次)""" + self.metrics = [] + for _ in range(self.args.run_time): + data_batch = stream_dataset.sample_m() # 采样batch(需匹配数据集接口) + acc = self.stream_evaluate(data_batch) + self.stream_fit(data_batch) + self.metrics.append(acc) + + def test(self, data): + print(f'PAA: {self.metrics}') \ No newline at end of file -- Gitee