diff --git a/StreamLearn/legacy/README.md b/StreamLearn/legacy/README.md index 8484b7adaaeeb42ec5426b2fcf51ef04584f19d5..58398827adc928d0b96eb73e563af0d4134dc0f8 100644 --- a/StreamLearn/legacy/README.md +++ b/StreamLearn/legacy/README.md @@ -11,6 +11,7 @@ - **PAA**:课题一的PAA算法,实现在线类增量持续学习。 - **无监督分类**:课题二中的矩阵略图近似优化算法,针对带噪声的CIFAR-10图像分类任务。 - **SAFC**:课题三中的增量学习算法,分为准备阶段(P阶段)和适应阶段(A阶段)。 +- **Adaprompt**:课程三中的面向数据分布变化的流数据增量学习,针对于cifar10图像分类的分布偏移场景。 - **流数据调度**:课题四中的流数据学习算法调度模拟环境,优化流数据任务的资源利用率和时效性。 - **类别增量学习**:课题五中的增量学习算法,支持动态添加类别。 @@ -1004,6 +1005,47 @@ def train_and_evaluate_stream_HamOS(args): 测试主文件为test_HamOS.py +### 3.5 Adaprompt: 面向分布外泛化的测试时适应。 +针对零样本提示调优中因数据偏差和模型偏差导致的性能下降问题,我们提出了 ADAPROMPT 算法。该算法通过三大模块解决这些问题:一是提示集成,融合多个手工提示的预测结果以规避单一提示的最差情况,缓解数据偏差;二是测试时提示调优,利用无标签测试数据微调所有提示,使其适应数据偏差;三是置信感知缓冲区,存储高置信度、类别平衡的样本用于更新提示,减少模型偏差带来的误差累积。最终在多个基准数据集上实现了优于现有方法的性能,且时间成本较低。 + +该算法主要包括`分布偏移数据集构造`,`ODS算法实现`,`性能测试`三个部分,相关代码参见目录: +- StreamLearn/Dataset/TTADataset.py +- StreamLearn/Algorithm/Adaprompt/Adaprompt.py + +首先,测试代码的配置文件为:StreamLearn/Config/Adaprompt.py,用户可修改测试文件进行不同复合数据分布变化的测试。 +数据集下载地址:[CIFAR10-C](https://zenodo.org/records/2535967),训练参数下载地址:[百度网盘](https://pan.baidu.com/s/1mxADnKpv73X-Tu1uR8fkXg),提取码: qaj2。 + +其次,按照以下方式构造流式包含复合数据分布变化(包含协变量分布和标记分布偏移)的 CIFAR10 数据集。 +```python +import StreamLearn.Dataset.TTADataset as datasets +dataset = datasets.CIFAR10C( + root=args.stream.dataset_dir, + batch_size=args.stream.batch_size, + severities=args.stream.severities, + corruptions=args.stream.corruptions, + seed=args.seed, +) +``` +其中,`root`为数据根目录、`batch_size`控制数据流批大小、`severities`与`corruptions`控制协变量分布偏移的程度与类型、`seed`为数据集生成随机种子。 + +继而,调用 Adaprompt 算法,复用已训练完毕的模型在复合分布偏移的数据流中进行自适应学习。 +```python +model=get_coop(clip_arch, test_set, device, n_ctx, ctx_init, learned_cls=False,multi=False): +``` +通过get_coop函数获得clip模型的image_encoder和text_encoder. + +最后,将数据流中的样本输入算法中进行预测。 +```python +# Adaprompt算法测试 +pred = model(X).detach().cpu() +``` + +测试的使用方法为: +```bash +cd stream-learn +python -m StreamLearn.Simulator.test.main --tasks adaprompt +``` + ## 课题四 ### 4.1 分布式交互学习流数据存储管理系统 GEAR