diff --git a/README.md b/README.md
index 625ef47535334f66a3310522626210cb72a4299e..d7421aae0d23d36acf556df4c32c1dd1c9a50375 100644
--- a/README.md
+++ b/README.md
@@ -125,4 +125,48 @@ model.fit(dataset)
 # MEMO算法测试
 y = model.predict(X=dataset._test_data)
 print(model.evaluate(y, dataset._test_targets))
+```
+
+## Distribution-Adaptive Learning Algorithm for Data Streams
+
+The algorithm consists of three parts: `distribution-shift dataset construction`, `ODS algorithm implementation`, and `performance testing`. The relevant code is in:
+- StreamLearn/Dataset/TTADataset.py
+- StreamLearn/Algorithm/TTA/ODS.py
+- StreamLearn/tests/test_ODS.py
+
+First, the configuration file for the test code is StreamLearn/Config/ODS.py; users can edit it to test different compound distribution shifts.
+Dataset download: [CIFAR10-C](https://zenodo.org/records/2535967). Pretrained checkpoint download: [Baidu Netdisk](https://pan.baidu.com/s/1mxADnKpv73X-Tu1uR8fkXg), extraction code: qaj2.
+
+Next, construct a streaming CIFAR10 dataset with compound distribution shift (covariate shift plus label shift) as follows.
+```python
+import StreamLearn.Dataset.TTADataset as datasets
+dataset = datasets.CIFAR10CB(
+    root=args.stream.dataset_dir,
+    batch_size=args.stream.batch_size,
+    severities=args.stream.severities,
+    corruptions=args.stream.corruptions,
+    bind_class=args.stream.bind_class,
+    bind_ratio=args.stream.bind_ratio,
+    seed=args.seed,
+)
+```
+Here, `root` is the data root directory, `batch_size` controls the batch size of the data stream, `severities` and `corruptions` control the severity and type of covariate shift, `bind_class` and `bind_ratio` control the classes and ratio of the label-distribution shift, and `seed` is the random seed for dataset generation.
+
+Then, invoke the ODS algorithm, which reuses the already-trained model to adapt on the stream with compound distribution shift.
+```python
+args.method.model = net
+estimator = ODS.ODS(args.method)
+```
+Here, `net` holds the deep learning model, and `args.method` holds the hyperparameters required by the ODS algorithm.
+
+Finally, feed the samples from the data stream into the algorithm for prediction.
+```python
+# ODS algorithm test
+pred = estimator.predict(X).detach().cpu()
+```
+
+The complete test code is in StreamLearn/tests/test_ODS.py; usage:
+```bash
+cd stream-learn
+python StreamLearn/tests/test_ODS.py --data PATH_TO_DATA --checkpoint PATH_TO_CHECKPOINT
+```
\ No newline at end of file
diff --git a/StreamLearn/Algorithm/TTA/ODS.py b/StreamLearn/Algorithm/TTA/ODS.py
new file mode 100644
index 0000000000000000000000000000000000000000..024c5816e07c28528e980436d571d726a26f318e
--- /dev/null
+++ b/StreamLearn/Algorithm/TTA/ODS.py
@@ -0,0 +1,289 @@
+from copy import deepcopy
+import os
+import StreamLearn.Algorithm.TTA.iabn as iabn
+import StreamLearn.Algorithm.TTA.memory as memory
+import torch
+import torch.nn as nn
+import torch.jit
+import torch.optim as optim
+import torch.nn.functional as F
+import numpy as np
+from scipy.optimize import fsolve
+from StreamLearn.Base.SemiEstimator import StreamEstimator
+
+__all__ = ["ODS"]
+
+class LabelDistributionQueue:
+    def __init__(self, num_class, capacity=None):
+        if capacity is None: capacity = num_class * 20
+        self.queue_length = capacity
+        self.queue = torch.zeros(self.queue_length)
+        self.pointer = 0
+        self.num_class = num_class
+        self.size = 0
+
+    def update(self, tgt_preds):
+        tgt_preds = tgt_preds.detach().cpu()
+        batch_sz = tgt_preds.shape[0]
+        self.size += batch_sz
+        if self.pointer + batch_sz > self.queue_length:  # Deal with wrap-around when queue_length % batch_size != 0
+            rem_space = self.queue_length - self.pointer
+            self.queue[self.pointer: self.queue_length] = (tgt_preds[:rem_space] + 1)
+            self.queue[0: batch_sz - rem_space] = (tgt_preds[rem_space:] + 1)
+        else:
+            self.queue[self.pointer: self.pointer + batch_sz] = (tgt_preds + 1)
+        self.pointer = (self.pointer + batch_sz) % self.queue_length
+
+    def get(self,):
+        bincounts = torch.bincount(self.queue.long(), minlength=self.num_class + 1).float() / self.queue_length
+        bincounts = bincounts[1:]
+        if bincounts.sum() == 0: bincounts[:] = 1
+        # log_q = torch.log(bincounts + 1e-12).detach().cuda()
+        return bincounts
+
+    def full(self):
+        return self.size >= self.queue_length
+
+class AffinityMatrix:
+    def __init__(self,
**kwargs): + pass + def __call__(X, **kwargs): + raise NotImplementedError + def is_psd(self, mat): + eigenvalues = torch.eig(mat)[0][:, 0].sort(descending=True)[0] + return eigenvalues, float((mat == mat.t()).all() and (eigenvalues >= 0).all()) + def symmetrize(self, mat): + return 1 / 2 * (mat + mat.t()) + +class rbf_affinity(AffinityMatrix): + def __init__(self, sigma: float, **kwargs): + self.sigma = sigma + self.k = kwargs['knn'] + def __call__(self, X): + N = X.size(0) + dist = torch.norm(X.unsqueeze(0) - X.unsqueeze(1), dim=-1, p=2) # [N, N] + n_neighbors = min(self.k, N) + kth_dist = dist.topk(k=n_neighbors, dim=-1, largest=False).values[:, -1] # compute k^th distance for each point, [N, knn + 1] + sigma = kth_dist.mean() + rbf = torch.exp(- dist ** 2 / (2 * sigma ** 2)) + # mask = torch.eye(X.size(0)).to(X.device) + # rbf = rbf * (1 - mask) + return rbf + +def entropy_energy(Y, unary, pairwise, bound_lambda): + E = (unary * Y - bound_lambda * pairwise * Y + Y * torch.log(Y.clip(1e-20))).sum() + return E + +def laplacian_optimization(unary, kernel, bound_lambda=1, max_steps=100): + E_list = [] + oldE = float('inf') + Y = (-unary).softmax(-1) # [N, K] + for i in range(max_steps): + pairwise = bound_lambda * kernel.matmul(Y) # [N, K] + exponent = -unary + pairwise + Y = exponent.softmax(-1) + E = entropy_energy(Y, unary, pairwise, bound_lambda).item() + E_list.append(E) + if (i > 1 and (abs(E - oldE) <= 1e-8 * abs(oldE))): + break + else: + oldE = E + return Y + +def LAME_optimize(logits, feats, knn, force_symmetry=True): + affinity = eval(f'rbf_affinity')(sigma=1.0, knn=knn) + probas = F.softmax(logits, 1) + # --- Get unary and terms and kernel --- + unary = - torch.log(probas + 1e-10) # [N, K] + feats = F.normalize(feats, p=2, dim=-1) # [N, d] + kernel = affinity(feats) # [N, N] + if force_symmetry: kernel = 1/2 * (kernel + kernel.t()) + # --- Perform optim --- + Y = laplacian_optimization(unary, kernel) + return Y + +class ODS(StreamEstimator): + def __init__(self, args): + super().__init__() + self.args = args + + model = args.model + self.base = deepcopy(args.model) + self.model = configure_model(model, args.optim.bn_momentum, args.use_learnt_stats) + self.optimizer = optim.Adam( + model.parameters(), + lr=float(args.optim.lr), + weight_decay=float(args.optim.wd) + ) + self.num_class = args.num_class + self.capacity = args.capacity + self.reset() + + self.memory = None + self.memory_labels = None + self.mem = memory.PBRS(capacity=args.capacity, num_class=args.num_class) + self.queue_dist = LabelDistributionQueue(num_class=args.num_class, capacity=args.capacity) + self.memory_position = [[] for i in range(self.num_class)] + self.memory_counts = [0 for i in range(self.num_class)] + self.random = np.random.default_rng(seed=0) + + self.exp_info = [] + + def storage_sample(self, x, y): + if self.memory is None: + self.memory = x.unsqueeze(0) + self.memory_labels = [y.item()] + return True + + if self.memory.shape[0] < self.capacity: + self.memory = torch.cat([self.memory, x.unsqueeze(0)]) + self.memory_labels.append(y.item()) + self.memory_position[y].append(self.memory.shape[0] - 1) + return True + + siz = [len(arr) for arr in self.memory_position] + mx_siz = np.max(siz) + mx_cls = [i for i in range(self.num_class) if mx_siz == len(self.memory_position[i])] + if y not in mx_cls: + rep_cls = self.random.choice(mx_cls) + rep_pos = self.random.choice(self.memory_position[rep_cls]) + self.memory_position[rep_cls].remove(rep_pos) + self.memory[rep_pos, ...] 
= x + self.memory_labels[rep_pos] = y.item() + self.memory_position[y].append(rep_pos) + return True + else: + rep_cls = y + index = self.random.integers(low=0, high=self.memory_counts[rep_cls], endpoint=False) + if index >= siz[rep_cls]: return False + rep_pos = self.memory_position[rep_cls][index] + self.memory[rep_pos, ...] = x + self.memory_labels[rep_pos] = y.item() + return True + + def storage(self, x, outputs, domain): + assert(x.shape[0] == outputs.shape[0]) + x = x.detach().cpu() + outputs = outputs.detach().cpu() + entropys = softmax_entropy(outputs) + predicts = outputs.max(1)[1] + for i in range(x.shape[0]): + y = predicts[i] + self.memory_counts[y] += 1 + self.storage_sample(x[i, ...], y) + + def sample(self, batch_sample): + if self.memory is None: return None + return self.memory.cuda(), self.memory_labels + + + def forward(self, x): + outputs = self.forward_and_adapt(x, self.model, self.optimizer) + return outputs + + def reset(self): pass + + @torch.enable_grad() # ensure grads in possible no grad context for testing + def forward_and_adapt(self, x, model, optimizer): + # forward + with torch.no_grad(): + # Get model output & features + model.eval() + outputs = model(x) + model_feats = extract_feats(model) + base_logits = self.base(x) + optim_dist = LAME_optimize(base_logits, model_feats, self.args.optim.knn) + probas = torch.sqrt(F.softmax(outputs, 1) * optim_dist) + unary = - torch.log(probas + 1e-10) + refine_outputs = (-unary).softmax(-1) + outputs = refine_outputs + self.storage(x, outputs, None) + self.save() + + model.train() + feats, labels = self.sample(x.cuda()) + self.queue_dist.update(torch.tensor(labels)) + weight = 1.0 - self.queue_dist.get() + 0.1 + weight = weight / weight.sum() + if feats is not None: + weight = weight[labels].cuda() + loss = softmax_entropy(model(feats)) * weight + loss = loss.sum(0) / weight.sum(0) + loss.backward() + optimizer.step() + optimizer.zero_grad() + return outputs + + def save(self, ): pass + def print(self, ): pass + + def fit(self, stream_dataset): + raise NotImplementedError("Test-time adaptation methods do not support fit method.") + + def predict(self, X): + """ + Predict y for input X. + + :param X: input. + """ + return self.forward(X) + + def evaluate(self, y_pred, y_true): + """ + Evaluate stream algorithm on a stream dataset. + + :param y_pred: predict y. + :param y_true: ground-truth y. 
+ """ + accs = torch.sum(y_pred == y_true).item() / y_true.size(0) + return accs + +def extract_feats(model): + return model.feats.mean((-2, -1)) if hasattr(model, "feats") else model.module.feats.mean((-2, -1)) + +@torch.jit.script +def softmax_entropy(x: torch.Tensor) -> torch.Tensor: + """Entropy of softmax distribution from logits.""" + return -(x.softmax(1) * x.log_softmax(1)).sum(1) + +def configure_model(model, bn_momentum, use_learnt_stats): + """Configure model for use with note.""" + # iabn.convert_iabn(model) + for param in model.parameters(): # initially turn off requires_grad for all + param.requires_grad = False + for module in model.modules(): + if isinstance(module, nn.BatchNorm1d) or isinstance(module, nn.BatchNorm2d): + if use_learnt_stats: + module.track_running_stats = True + module.momentum = bn_momentum + else: + # With below, this module always uses the test batch statistics (no momentum) + module.track_running_stats = False + module.running_mean = None + module.running_var = None + module.weight.requires_grad_(True) + module.bias.requires_grad_(True) + if isinstance(module, iabn.InstanceAwareBatchNorm2d) or isinstance(module, iabn.InstanceAwareBatchNorm1d): + for param in module.parameters(): + param.requires_grad = True + return model + +def setup(model, args): + base = deepcopy(model) + model = configure_model(model, args.optim.bn_momentum, args.use_learnt_stats) + optimizer = optim.Adam( + model.parameters(), + lr=float(args.optim.lr), + weight_decay=float(args.optim.wd) + ) + ods_model = ODS( + args, + base, + model, + optimizer, + num_class=args.num_class, + capacity=args.queue_size, + ) + ods_model.reset() + return ods_model \ No newline at end of file diff --git a/StreamLearn/Algorithm/TTA/iabn.py b/StreamLearn/Algorithm/TTA/iabn.py new file mode 100644 index 0000000000000000000000000000000000000000..b2c63f760e760e5c844f4a659a1017bcaab88d90 --- /dev/null +++ b/StreamLearn/Algorithm/TTA/iabn.py @@ -0,0 +1,133 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import copy + +IABN_K = 4.0 +SKIP_THRESHOLD = 1 + +def convert_iabn(module, **kwargs): + module_output = module + if isinstance(module, nn.BatchNorm2d) or isinstance(module, nn.BatchNorm1d): + IABN = InstanceAwareBatchNorm2d if isinstance(module, nn.BatchNorm2d) else InstanceAwareBatchNorm1d + module_output = IABN( + num_channels=module.num_features, + k=IABN_K, + eps=module.eps, + momentum=module.momentum, + affine=module.affine, + ) + + module_output._bn = copy.deepcopy(module) + + for name, child in module.named_children(): + module_output.add_module( + name, convert_iabn(child, **kwargs) + ) + del module + return module_output + + +class InstanceAwareBatchNorm2d(nn.Module): + def __init__(self, num_channels, k=3.0, eps=1e-5, momentum=0.1, affine=True): + super(InstanceAwareBatchNorm2d, self).__init__() + self.num_channels = num_channels + self.eps = eps + self.k=k + self.affine = affine + self._bn = nn.BatchNorm2d(num_channels, eps=eps, + momentum=momentum, affine=affine) + + def _softshrink(self, x, lbd): + x_p = F.relu(x - lbd, inplace=True) + x_n = F.relu(-(x + lbd), inplace=True) + y = x_p - x_n + return y + + def forward(self, x): + b, c, h, w = x.size() + sigma2, mu = torch.var_mean(x, dim=[2, 3], keepdim=True, unbiased=True) #IN + + if self.training: + _ = self._bn(x) + sigma2_b, mu_b = torch.var_mean(x, dim=[0, 2, 3], keepdim=True, unbiased=True) + else: + if self._bn.track_running_stats == False and self._bn.running_mean is None and 
self._bn.running_var is None: # use batch stats + sigma2_b, mu_b = torch.var_mean(x, dim=[0, 2, 3], keepdim=True, unbiased=True) + else: + mu_b = self._bn.running_mean.view(1, c, 1, 1) + sigma2_b = self._bn.running_var.view(1, c, 1, 1) + + + if h*w <= SKIP_THRESHOLD: + mu_adj = mu_b + sigma2_adj = sigma2_b + else: + s_mu = torch.sqrt((sigma2_b + self.eps) / (h * w)) + s_sigma2 = (sigma2_b + self.eps) * np.sqrt(2 / (h * w - 1)) + + mu_adj = mu_b + self._softshrink(mu - mu_b, self.k * s_mu) + + sigma2_adj = sigma2_b + self._softshrink(sigma2 - sigma2_b, self.k * s_sigma2) + + sigma2_adj = F.relu(sigma2_adj) #non negative + + x_n = (x - mu_adj) * torch.rsqrt(sigma2_adj + self.eps) + if self.affine: + weight = self._bn.weight.view(c, 1, 1) + bias = self._bn.bias.view(c, 1, 1) + x_n = x_n * weight + bias + return x_n + + +class InstanceAwareBatchNorm1d(nn.Module): + def __init__(self, num_channels, k=3.0, eps=1e-5, momentum=0.1, affine=True): + super(InstanceAwareBatchNorm1d, self).__init__() + self.num_channels = num_channels + self.k = k + self.eps = eps + self.affine = affine + self._bn = nn.BatchNorm1d(num_channels, eps=eps, + momentum=momentum, affine=affine) + + def _softshrink(self, x, lbd): + x_p = F.relu(x - lbd, inplace=True) + x_n = F.relu(-(x + lbd), inplace=True) + y = x_p - x_n + return y + + def forward(self, x): + b, c, l = x.size() + sigma2, mu = torch.var_mean(x, dim=[2], keepdim=True, unbiased=True) + if self.training: + _ = self._bn(x) + sigma2_b, mu_b = torch.var_mean(x, dim=[0, 2], keepdim=True, unbiased=True) + else: + if self._bn.track_running_stats == False and self._bn.running_mean is None and self._bn.running_var is None: # use batch stats + sigma2_b, mu_b = torch.var_mean(x, dim=[0, 2], keepdim=True, unbiased=True) + else: + mu_b = self._bn.running_mean.view(1, c, 1) + sigma2_b = self._bn.running_var.view(1, c, 1) + + if l <= SKIP_THRESHOLD: + mu_adj = mu_b + sigma2_adj = sigma2_b + + else: + s_mu = torch.sqrt((sigma2_b + self.eps) / l) ## + s_sigma2 = (sigma2_b + self.eps) * np.sqrt(2 / (l - 1)) + + mu_adj = mu_b + self._softshrink(mu - mu_b, self.k * s_mu) + sigma2_adj = sigma2_b + self._softshrink(sigma2 - sigma2_b, self.k * s_sigma2) + sigma2_adj = F.relu(sigma2_adj) + + + x_n = (x - mu_adj) * torch.rsqrt(sigma2_adj + self.eps) + + if self.affine: + weight = self._bn.weight.view(c, 1) + bias = self._bn.bias.view(c, 1) + x_n = x_n * weight + bias + + return x_n \ No newline at end of file diff --git a/StreamLearn/Algorithm/TTA/memory.py b/StreamLearn/Algorithm/TTA/memory.py new file mode 100644 index 0000000000000000000000000000000000000000..e342a93396068c27cb7d6615bb0d3e3e3a861766 --- /dev/null +++ b/StreamLearn/Algorithm/TTA/memory.py @@ -0,0 +1,147 @@ +import random +import copy +import torch +import torch.nn.functional as F +import numpy as np + +class FIFO(): + def __init__(self, capacity): + self.data = [[], []] + self.capacity = capacity + pass + + def get_memory(self): + return self.data + + def get_occupancy(self): + return len(self.data[0]) + + def add_instance(self, instance): + assert (len(instance) == 2) + + if self.get_occupancy() >= self.capacity: + self.remove_instance() + + for i, dim in enumerate(self.data): + dim.append(instance[i]) + + def remove_instance(self): + for dim in self.data: + dim.pop(0) + pass + +class Reservoir(): # Time uniform + + def __init__(self, capacity): + super(Reservoir, self).__init__(capacity) + self.data = [[], [], []] + self.capacity = capacity + self.counter = 0 + + def get_memory(self): + return self.data + + def 
get_occupancy(self): + return len(self.data[0]) + + + def add_instance(self, instance): + assert (len(instance) == 3) + is_add = True + self.counter+=1 + + if self.get_occupancy() >= self.capacity: + is_add = self.remove_instance() + + if is_add: + for i, dim in enumerate(self.data): + dim.append(instance[i]) + + + def remove_instance(self): + + + m = self.get_occupancy() + n = self.counter + u = random.uniform(0, 1) + if u <= m / n: + tgt_idx = random.randrange(0, m) # target index to remove + for dim in self.data: + dim.pop(tgt_idx) + else: + return False + return True + +class PBRS(): + + def __init__(self, capacity, num_class): + self.data = [[[], []] for _ in range(num_class)] #feat, pseudo_cls, domain, cls, loss + self.counter = [0] * num_class + self.marker = [''] * num_class + self.capacity = capacity + self.num_class = num_class + pass + + def get_memory(self): + data = self.data + tmp_data = [[], []] + for data_per_cls in data: + feats, cls = data_per_cls + tmp_data[0].extend(feats) + tmp_data[1].extend(cls) + + return tmp_data + + def get_occupancy(self): + occupancy = 0 + for data_per_cls in self.data: + occupancy += len(data_per_cls[0]) + return occupancy + + def get_occupancy_per_class(self): + occupancy_per_class = [0] * self.num_class + for i, data_per_cls in enumerate(self.data): + occupancy_per_class[i] = len(data_per_cls[0]) + return occupancy_per_class + + def add_instance(self, instance): + assert (len(instance) == 2) + cls = instance[1] + self.counter[cls] += 1 + is_add = True + + if self.get_occupancy() >= self.capacity: + is_add = self.remove_instance(cls) + + if is_add: + for i, dim in enumerate(self.data[cls]): + dim.append(instance[i]) + + def get_largest_indices(self): + + occupancy_per_class = self.get_occupancy_per_class() + max_value = max(occupancy_per_class) + largest_indices = [] + for i, oc in enumerate(occupancy_per_class): + if oc == max_value: + largest_indices.append(i) + return largest_indices + + def remove_instance(self, cls): + largest_indices = self.get_largest_indices() + if cls not in largest_indices: # instance is stored in the place of another instance that belongs to the largest class + largest = random.choice(largest_indices) # select only one largest class + tgt_idx = random.randrange(0, len(self.data[largest][0])) # target index to remove + for dim in self.data[largest]: + dim.pop(tgt_idx) + else:# replaces a randomly selected stored instance of the same class + m_c = self.get_occupancy_per_class()[cls] + n_c = self.counter[cls] + u = random.uniform(0, 1) + if u <= m_c / n_c: + tgt_idx = random.randrange(0, len(self.data[cls][0])) # target index to remove + for dim in self.data[cls]: + dim.pop(tgt_idx) + else: + return False + return True \ No newline at end of file diff --git a/StreamLearn/Config/ODS.py b/StreamLearn/Config/ODS.py new file mode 100644 index 0000000000000000000000000000000000000000..232cfeb94a9ebfda5f2a478a44c9d7390278faad --- /dev/null +++ b/StreamLearn/Config/ODS.py @@ -0,0 +1,39 @@ +import argparse +import StreamLearn.Dataset.TTADataset as datasets +import StreamLearn.Network.ResNetTTA as models + +__all__ = ["configs"] + +configs = argparse.Namespace() +configs.method = argparse.Namespace() +configs.stream = argparse.Namespace() +configs.model = argparse.Namespace() + +configs.method.name = "ODS" +configs.method.method = "ODS" +configs.method.optim = argparse.Namespace() +configs.method.optim.lr = 1e-4 +configs.method.optim.wd = 0.0 +configs.method.optim.iabn_k = 4 +configs.method.optim.bn_momentum = 0.01 
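+# The entries below are ODS-specific hyperparameters: `knn` sets the neighborhood size of
+# the LAME affinity kernel, `capacity` bounds the label-distribution queue and the
+# class-balanced sample memory, and `use_learnt_stats` keeps the pretrained BN running
+# statistics (updated with momentum `bn_momentum`) instead of recomputing them from each
+# test batch.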
+configs.method.optim.knn = 5 +configs.method.queue_size = 64 +configs.method.use_learnt_stats = True +configs.method.capacity = 1024 + +configs.stream.name = "CIFAR10-10" +configs.stream.dataset_dir = "" +configs.stream.batch_size = 64 +configs.stream.dataset = datasets.CIFAR10CB +configs.stream.num_class = 10 +configs.stream.severities = [5] +configs.stream.corruptions = [ + "gaussian_noise", "shot_noise", "impulse_noise", + "defocus_blur", "glass_blur", "motion_blur", "zoom_blur", + "snow", "frost", "fog", "brightness", "contrast", + "elastic_transform", "pixelate", "jpeg_compression"] +configs.stream.bind_class = [0, 0, 0, 1, 1, 1, 1, 8, 8, 8, 8, 9, 9, 9, 9] +configs.stream.bind_ratio = 10 + +configs.model.archtecture = models.ResNet18 +configs.model.checkpoint_dir = "" \ No newline at end of file diff --git a/StreamLearn/Dataset/TTADataset.py b/StreamLearn/Dataset/TTADataset.py new file mode 100644 index 0000000000000000000000000000000000000000..94756f5f8080d17e5ea771f8af94187ec86c1c6f --- /dev/null +++ b/StreamLearn/Dataset/TTADataset.py @@ -0,0 +1,294 @@ +import numpy as np +import logging, os +import torch +import math +import torchvision +import torchvision.datasets as datasets +import torchvision.transforms as transforms +from StreamLearn.Dataset.StreamDataset import StreamDataset + +__all__ = ['CIFAR10C', 'CIFAR100C', "get_cifar_loader", "CIFAR10CB", "CIFAR100CB"] + +def sparse2coarse(targets): + """Convert Pytorch CIFAR100 sparse targets to coarse targets. + Usage: + trainset = torchvision.datasets.CIFAR100(path) + trainset.targets = sparse2coarse(trainset.targets) + """ + coarse_labels = np.array([ 4, 1, 14, 8, 0, 6, 7, 7, 18, 3, + 3, 14, 9, 18, 7, 11, 3, 9, 7, 11, + 6, 11, 5, 10, 7, 6, 13, 15, 3, 15, + 0, 11, 1, 10, 12, 14, 16, 9, 11, 5, + 5, 19, 8, 8, 15, 13, 14, 17, 18, 10, + 16, 4, 17, 4, 2, 0, 17, 4, 18, 17, + 10, 3, 2, 12, 12, 16, 12, 1, 9, 19, + 2, 10, 0, 1, 16, 12, 9, 13, 15, 13, + 16, 19, 2, 4, 6, 19, 5, 5, 8, 19, + 18, 1, 2, 15, 6, 0, 17, 8, 14, 13]) + return coarse_labels[targets] + +class CIFAR10C(StreamDataset): + NAME = "CIFAR-10-C" + CORRUPTIONS = ("shot_noise", "motion_blur", "snow", "pixelate", + "gaussian_noise", "defocus_blur", "brightness", "fog", + "zoom_blur", "frost", "glass_blur", "impulse_noise", "contrast", + "jpeg_compression", "elastic_transform") + + def __init__(self, + root="/data", + batch_size=64, + seed=998244353, + corruptions=CORRUPTIONS, + severities=1, + ): + super().__init__() + self.root, self.seed = root, seed + self.batch_size = batch_size + self.rand = np.random.RandomState(seed=seed) + self._epoch, self._epochs = 0, 0 + self._data, self._label = None, None + self.severities = severities if isinstance(severities, list) else [severities] + self.corruptions = corruptions + + def run(self, ): + for i_c, corruption in enumerate(self.corruptions): + for i_s, severity in enumerate(self.severities): + yield self._generate(i_c, i_s) + + def _generate(self, i_c, i_s): + corruption = self.corruptions[i_c] + severity = self.severities[i_s] + # Read & shuffle data + data, label = self.load_corruptions_cifar(severity=severity, corruption=corruption) + rand_index = self.rand.permutation(data.shape[0]) + data, label = data[rand_index, ...], label[rand_index, ...] 
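+        # Emit the shuffled split as a stream: after logging the class ratio, batches of
+        # `batch_size` are yielded as (data, label, tag, class_ratio) tuples.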
+ # print(label.numpy().tolist()) + # Calculate class ratio + ratio = np.array([(torch.sum(label == i) / len(label)).item() for i in range(self.num_classes())]) + logging.info(" > Class Ratio: ") + logging.info(ratio) + # Yield data + n_batches = math.ceil(data.shape[0] / self.batch_size) + for counter in range(n_batches): + data_curr = data[counter * self.batch_size:(counter + 1) * self.batch_size] + label_curr = label[counter * self.batch_size:(counter + 1) * self.batch_size] + tag_cur = f"{corruption}-{severity}-{counter}" + yield data_curr, label_curr, tag_cur, ratio + + + def num_classes(self,): return 10 + def download(self,): pass + + def load_cifar(self,): + dataset = datasets.CIFAR10(root=self.root, train=False, transform=transforms.Compose([transforms.ToTensor()]), download=True) + return self.load_dataset(dataset) + + def load_corruptions_cifar(self, severity, corruption): + assert 0 <= severity <= 5 + if severity > 0: assert corruption in self.CORRUPTIONS + if severity == 0: return self.load_cifar() + n_total_cifar = 10000 + if not os.path.exists(self.root): raise FileNotFoundError("The root of datasets is not found.") + data_dir = os.path.join(self.root, self.NAME) + # Load labels + label_path = os.path.join(data_dir, "labels.npy") + if not os.path.isfile(label_path): raise FileNotFoundError("Labels are missing.") + labels = np.load(label_path) + labels = labels[: n_total_cifar] + # if len(set(labels)) == 100: labels = sparse2coarse(labels) + # Load images + data_path = os.path.join(data_dir, f"{corruption}.npy") + if not os.path.isfile(data_path): raise FileNotFoundError("Data {corruption} is missing.") + images_all = np.load(data_path) + images = images_all[(severity - 1) * n_total_cifar: severity * n_total_cifar] + # To torch tensors + images = np.transpose(images, (0, 3, 1, 2)) + images = images.astype(np.float32) / 255 + images = torch.tensor(images) + labels = torch.tensor(labels) + return images, labels + +class CIFAR10CB(CIFAR10C): + NAME = "CIFAR-10-C" + CORRUPTIONS = ("shot_noise", "motion_blur", "snow", "pixelate", + "gaussian_noise", "defocus_blur", "brightness", "fog", + "zoom_blur", "frost", "glass_blur", "impulse_noise", "contrast", + "jpeg_compression", "elastic_transform") + + def __init__(self, + root="/data", + batch_size=64, + seed=998244353, + corruptions=CORRUPTIONS, + severities=1, + bind_class=[], + bind_ratio=[], + ): + super().__init__(root=root, batch_size=batch_size, seed=seed, corruptions=corruptions, severities=severities) + self.bind_class = bind_class + self.bind_ratio = bind_ratio + + def _generate(self, i_c, i_s): + corruption = self.corruptions[i_c] + severity = self.severities[i_s] + # Read + data, label = self.load_corruptions_cifar(severity=severity, corruption=corruption) + # Shuffle data + rand_index = self.rand.permutation(data.shape[0]) + data, label = data[rand_index, ...], label[rand_index, ...] + if i_c < len(self.bind_class): + index = self.bind_class[i_c] + ratio = self.bind_ratio + ratio = ratio / (ratio + self.num_classes() - 1) + n_samples = data.shape[0] + proba = torch.where(label == index, 1.0, 1.0 / n_samples) + proba = proba / proba.sum() + n_sampling = int(data.shape[0] / self.num_classes() / ratio) + index = self.rand.choice(n_samples, n_sampling, p=proba.cpu().numpy(), replace=False) + select = np.array([False] * n_samples) + select[index] = True + data = data[select, ...] + label = label[select, ...] 
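+            # Label shift: samples of the bound class keep weight 1 while the rest get weight
+            # 1/n, so after subsampling the bound class is roughly `bind_ratio` times more
+            # frequent than each remaining class.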
+ # data = torch.cat([data[select, ...],data[~select, ...]], 0) + # label = torch.cat([label[select, ...], label[~select, ...]], 0) + + # Calculate class ratio + ratio = [torch.sum(label == i) / len(label) for i in range(self.num_classes())] + logging.info(" > Class Ratio: ") + logging.info(ratio) + + # Yield data + n_batches = math.ceil(data.shape[0] / self.batch_size) + for counter in range(n_batches): + data_curr = data[counter * self.batch_size:(counter + 1) * self.batch_size] + label_curr = label[counter * self.batch_size:(counter + 1) * self.batch_size] + tag_cur = f"{corruption}-{severity}-{counter}" + yield data_curr, label_curr, tag_cur, ratio + +class CIFAR100C(CIFAR10C): + NAME = "CIFAR-100-C" + CORRUPTIONS = ("shot_noise", "motion_blur", "snow", "pixelate", + "gaussian_noise", "defocus_blur", "brightness", "fog", + "zoom_blur", "frost", "glass_blur", "impulse_noise", "contrast", + "jpeg_compression", "elastic_transform") + + def __init__(self, + root="/data", + batch_size=64, + seed=998244353, + corruptions=CORRUPTIONS, + severities=1, + ): + super().__init__(root, batch_size, seed) + self._epoch, self._epochs = 0, 0 + self._data, self._label = None, None + self.severities = severities if isinstance(severities, list) else [severities] + self.corruptions = corruptions + + def num_classes(self,): return 100 + + def load_cifar(self,): + dataset = datasets.CIFAR100(root=self.root, train=False, transform=transforms.Compose([transforms.ToTensor()]), download=True) + return self.load_dataset(dataset) + +class CIFAR100CB(CIFAR100C): + NAME = "CIFAR-100-C" + CORRUPTIONS = ("shot_noise", "motion_blur", "snow", "pixelate", + "gaussian_noise", "defocus_blur", "brightness", "fog", + "zoom_blur", "frost", "glass_blur", "impulse_noise", "contrast", + "jpeg_compression", "elastic_transform") + + def __init__(self, + root="/data", + batch_size=64, + seed=998244353, + corruptions=CORRUPTIONS, + severities=1, + bind_class=[], + bind_ratio=[], + ): + super().__init__(root=root, batch_size=batch_size, seed=seed, corruptions=corruptions, severities=severities) + self.bind_class = bind_class + self.bind_ratio = bind_ratio + + def _generate(self, i_c, i_s): + corruption = self.corruptions[i_c] + severity = self.severities[i_s] + # Read + data, label = self.load_corruptions_cifar(severity=severity, corruption=corruption) + clabel = torch.tensor(sparse2coarse(label)) + num_cclasses = 20 + # Shuffle data + rand_index = self.rand.permutation(data.shape[0]) + data, label = data[rand_index, ...], label[rand_index, ...] + # Random Data + if i_c < len(self.bind_class): + index = self.bind_class[i_c] + index = list(map(int, index.split(","))) + m = len(index) + ratio = self.bind_ratio + n_samples = data.shape[0] + + large_proba = 1.0 + small_proba = 1.0 / n_samples + proba = torch.ones_like(clabel) * small_proba + for idx in index: + proba_large = torch.where(clabel == idx, large_proba, 0.0) + proba += proba_large + proba = proba / proba.sum() + + n_sampling = int(n_samples / num_cclasses / ratio * (m * ratio + num_cclasses - m)) + index = self.rand.choice(n_samples, n_sampling, p=proba.cpu().numpy(), replace=False) + select = np.array([False] * n_samples) + select[index] = True + data = data[select, ...] + label = label[select, ...] + # data = torch.cat([data[select, ...],data[~select, ...]], 0) + # label = torch.cat([label[select, ...], label[~select, ...]], 0) + # data, label = data[index, ...], label[index, ...] 
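+            # CIFAR-100 variant: each bind_class entry is a comma-separated string of coarse
+            # (superclass) indices, and the listed superclasses become roughly `bind_ratio`
+            # times more frequent than the remaining coarse classes.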
+ + # Calculate class ratio + ratio = [torch.sum(label == i) / len(label) for i in range(self.num_classes())] + logging.info(" > Class Ratio: ") + logging.info(ratio) + + # Yield data + n_batches = math.ceil(data.shape[0] / self.batch_size) + for counter in range(n_batches): + data_curr = data[counter * self.batch_size:(counter + 1) * self.batch_size] + label_curr = label[counter * self.batch_size:(counter + 1) * self.batch_size] + tag_cur = f"{corruption}-{severity}-{counter}" + yield data_curr, label_curr, tag_cur, ratio + +def get_cifar_loader(dataset, root="/data", batch_size=256): + assert dataset in ["CIFAR10", "CIFAR100"] + NORM_VAL = { + "CIFAR10": ((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), + "CIFAR100": ((0.5071, 0.4865, 0.4409), (0.2673, 0.2564, 0.2762)) + } + transform_train = transforms.Compose([ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + # transforms.Normalize(NORM_VAL[args.dataset]), + ]) + transform_test = transforms.Compose([ + transforms.ToTensor(), + # transforms.Normalize(NORM_VAL[args.dataset]), + ]) + + trainset = getattr(datasets, dataset)(root=root, train=True, download=True, transform=transform_train) + testset = getattr(datasets, dataset)(root=root, train=False, download=True, transform=transform_test) + # if dataset == "CIFAR100": + # trainset.targets = sparse2coarse(trainset.targets) + # testset.targets = sparse2coarse(testset.targets) + trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2) + testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2) + + CLAS_VAL = { + "CIFAR10": 10, + "CIFAR100": 100, + } + n_classes = CLAS_VAL[dataset] + + return trainloader, testloader, n_classes \ No newline at end of file diff --git a/StreamLearn/Network/ResNetTTA.py b/StreamLearn/Network/ResNetTTA.py new file mode 100644 index 0000000000000000000000000000000000000000..3521fe4d76f7bd06437541d5e59138136a6dd95c --- /dev/null +++ b/StreamLearn/Network/ResNetTTA.py @@ -0,0 +1,133 @@ +'''ResNet in PyTorch. + +For Pre-activation ResNet, see 'preact_resnet.py'. + +Reference: +[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun + Deep Residual Learning for Image Recognition. 
arXiv:1512.03385 +''' +import torch +import torch.nn as nn +import torch.nn.functional as F + +__all__ = ["ResNet18", "ResNet34", "ResNet50", "ResNet101", "ResNet152"] +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, in_planes, planes, stride=1): + super(BasicBlock, self).__init__() + self.conv1 = nn.Conv2d( + in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, + stride=1, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + + self.shortcut = nn.Sequential() + if stride != 1 or in_planes != self.expansion*planes: + self.shortcut = nn.Sequential( + nn.Conv2d(in_planes, self.expansion*planes, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(self.expansion*planes) + ) + + def forward(self, x): + out = F.relu(self.bn1(self.conv1(x))) + out = self.bn2(self.conv2(out)) + out += self.shortcut(x) + out = F.relu(out) + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, in_planes, planes, stride=1): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, + stride=stride, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, self.expansion * + planes, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(self.expansion*planes) + + self.shortcut = nn.Sequential() + if stride != 1 or in_planes != self.expansion*planes: + self.shortcut = nn.Sequential( + nn.Conv2d(in_planes, self.expansion*planes, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(self.expansion*planes) + ) + + def forward(self, x): + out = F.relu(self.bn1(self.conv1(x))) + out = F.relu(self.bn2(self.conv2(out))) + out = self.bn3(self.conv3(out)) + out += self.shortcut(x) + out = F.relu(out) + return out + +class ResNet(nn.Module): + def __init__(self, block, num_blocks, num_classes=10): + super(ResNet, self).__init__() + self.in_planes = 64 + + self.conv1 = nn.Conv2d(3, 64, kernel_size=3, + stride=1, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) + self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) + self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) + self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) + self.linear = nn.Linear(512*block.expansion, num_classes) + self.feats = None + + def _make_layer(self, block, planes, num_blocks, stride): + strides = [stride] + [1]*(num_blocks-1) + layers = [] + for stride in strides: + layers.append(block(self.in_planes, planes, stride)) + self.in_planes = planes * block.expansion + return nn.Sequential(*layers) + + def forward(self, x): + out = F.relu(self.bn1(self.conv1(x))) + out = self.layer1(out) + out = self.layer2(out) + out = self.layer3(out) + out = self.layer4(out) + out = F.avg_pool2d(out, 4) + self.feats = out.detach() + out = out.view(out.size(0), -1) + out = self.linear(out) + return out + + +def ResNet18(num_classes=10): + return ResNet(BasicBlock, [2, 2, 2, 2], num_classes) + + +def ResNet34(num_classes=10): + return ResNet(BasicBlock, [3, 4, 6, 3], num_classes) + + +def ResNet50(num_classes=10): + return ResNet(Bottleneck, [3, 4, 6, 3], num_classes) + + +def ResNet101(num_classes=10): + return ResNet(Bottleneck, [3, 4, 23, 3], num_classes) + + +def ResNet152(num_classes=10): + return 
ResNet(Bottleneck, [3, 8, 36, 3], num_classes) + + +def test(): + net = ResNet18() + y = net(torch.randn(1, 3, 32, 32)) + print(y.size()) + +# test() diff --git a/StreamLearn/tests/test_ODS.py b/StreamLearn/tests/test_ODS.py new file mode 100644 index 0000000000000000000000000000000000000000..52d8049e2c86476235ec507eaf26c0f6d9b760e9 --- /dev/null +++ b/StreamLearn/tests/test_ODS.py @@ -0,0 +1,81 @@ +import sys +sys.path.append('./') +import time +import torch +import numpy as np +import argparse +import random +import torch.backends.cudnn as cudnn +import StreamLearn.Algorithm.TTA.iabn as iabn +import StreamLearn.Algorithm.TTA.ODS as ODS +import StreamLearn.Network.ResNetTTA as models +import StreamLearn.Dataset.TTADataset as datasets +from StreamLearn.Config.ODS import configs + +parser = argparse.ArgumentParser(description='test PyTorch TTA methods') +parser.add_argument('--data', type=str, default="/home/anony/storage/data/", help="path of the datasets") +parser.add_argument('--checkpoint', type=str, default="/home/anony/CIFAR10.pth", help="path of the checkpoint") +parser.add_argument('--seed', default=0, type=int, help='random seed') +args = parser.parse_args() +args.method = configs.method +args.stream = configs.stream +args.model = configs.model +args.model.checkpoint_dir = args.checkpoint +args.stream.dataset_dir = args.data + +# Set cuda & device +seed = args.seed +device = 'cuda' if torch.cuda.is_available() else 'cpu' +cudnn.benchmark = True +torch.manual_seed(seed) +torch.cuda.manual_seed(seed) +np.random.seed(seed) +random.seed(seed) +torch.backends.cudnn.benchmark = False +torch.backends.cudnn.deterministic = True +torch.cuda.manual_seed_all(seed) + +# Build stream dataset +dataset = args.stream.dataset( + root=args.stream.dataset_dir, + batch_size=args.stream.batch_size, + severities=args.stream.severities, + corruptions=args.stream.corruptions, + bind_class=args.stream.bind_class, + bind_ratio=args.stream.bind_ratio, + seed=args.seed, +) +n_classes = dataset.num_classes() +args.method.num_class = args.stream.num_class + +# Build model +net = args.model.archtecture(num_classes=n_classes) +net = net.to(device) +if device == 'cuda': net = torch.nn.DataParallel(net) +iabn.convert_iabn(net) +checkpoint = torch.load(args.model.checkpoint_dir) +net.load_state_dict(checkpoint['net']) + +# Wrap TTA models +args.method.model = net +estimator = ODS.ODS(args.method) + +# Evaluate on stream datasets +for i_c, corruption in enumerate(args.stream.corruptions): + for i_s, severity in enumerate(args.stream.severities): + preds, gts = [], [] + wall_time = 0 + for X, y, T, ratio in dataset._generate(i_c, i_s): + torch.cuda.synchronize() + start_time = time.time() + pred = estimator.predict(X.to(device)).detach().cpu() + torch.cuda.synchronize() + wall_time = time.time() - start_time + wall_time + preds.append(pred.detach().cpu()) + gts.append(y.cpu()) + preds, gts = torch.cat(preds), torch.cat(gts) + preds = preds.max(1)[1] + accs = estimator.evaluate(preds, gts) + print(f" > [{corruption}-{severity}]\tAccuracy = {accs:.2%}\tTime = {wall_time:.2f}") + +
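+# Each (corruption, severity) pair is scored as an independent stream segment: predictions
+# are accumulated batch by batch, argmax-ed, and compared against the ground-truth labels.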