From b338cfd40d4cbfef8f98c787ac664323ff8c8d96 Mon Sep 17 00:00:00 2001
From: zhu-zehan <2356966412@qq.com>
Date: Thu, 14 Aug 2025 22:00:50 +0800
Subject: [PATCH] add compressed_and_private_decentralized_learning

---
 .../README.md                                 |  45 +++++
 .../data_load.py                              | 161 ++++++++++++++++++
 .../dataset/.keep                             |   0
 .../model_load.py                             | 148 ++++++++++++++++
 ...antified_private_decentralized_learning.py | 149 ++++++++++++++++
 ...arsified_private_decentralized_learning.py | 153 +++++++++++++++++
 ...arsified_private_decentralized_learning.py | 153 +++++++++++++++++
 7 files changed, 809 insertions(+)
 create mode 100644 examples/community/compressed_and_private_decentralized_learning/README.md
 create mode 100644 examples/community/compressed_and_private_decentralized_learning/data_load.py
 create mode 100644 examples/community/compressed_and_private_decentralized_learning/dataset/.keep
 create mode 100644 examples/community/compressed_and_private_decentralized_learning/model_load.py
 create mode 100644 examples/community/compressed_and_private_decentralized_learning/quantified_private_decentralized_learning.py
 create mode 100644 examples/community/compressed_and_private_decentralized_learning/random_sparsified_private_decentralized_learning.py
 create mode 100644 examples/community/compressed_and_private_decentralized_learning/top_k_sparsified_private_decentralized_learning.py

diff --git a/examples/community/compressed_and_private_decentralized_learning/README.md b/examples/community/compressed_and_private_decentralized_learning/README.md
new file mode 100644
index 0000000..22f2bae
--- /dev/null
+++ b/examples/community/compressed_and_private_decentralized_learning/README.md
@@ -0,0 +1,45 @@
+# Differentially Private Decentralized Learning with Communication Compression
+This project implements differentially private decentralized learning on the MindSpore framework, and compresses communication via quantization, random sparsification, and Top-K sparsification, improving the communication efficiency of distributed private training.
+
+
+## Requirements
+- mindspore >= 1.9: the algorithms rely on MindSpore's collective communication library
+- openmpi >= 5.0.1: training runs as multiple parallel processes, and the mpirun launch command depends on the OpenMPI library
+
+
+## Script Description
+```
+├── README.md
+├── dataset                                               // directory holding the datasets
+├── model_load.py                                         // model definition and loading
+├── data_load.py                                          // dataset loading
+├── quantified_private_decentralized_learning.py          // main script for DP decentralized learning with communication quantization (quant_ddl)
+├── random_sparsified_private_decentralized_learning.py   // main script for DP decentralized learning with random sparsification (randspar_ddl)
+└── top_k_sparsified_private_decentralized_learning.py    // main script for DP decentralized learning with Top-K sparsification (topkspar_ddl)
+```
+
+
+## Imports
+
+```python
+from model_load import resnet18
+from data_load import create_dataset
+```
+
+
+## Launching the Scripts
+1. Start 8 processes/nodes for distributed training and run the quant_ddl algorithm:
+```shell
+mpirun --allow-run-as-root -n 8 python ./quantified_private_decentralized_learning.py
+```
+
+2. Start 8 processes/nodes for distributed training and run the randspar_ddl algorithm:
+```shell
+mpirun --allow-run-as-root -n 8 python ./random_sparsified_private_decentralized_learning.py
+```
+
+3. Start 8 processes/nodes for distributed training and run the topkspar_ddl algorithm:
+```shell
+mpirun --allow-run-as-root -n 8 python ./top_k_sparsified_private_decentralized_learning.py
+```
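+
+
+## How the compression step works
+All three training scripts share one communication pattern: each worker adds Gaussian noise to its flattened local gradient (for differential privacy), compresses the residual between the noisy gradient and a local error-feedback accumulator, and all-reduces only the compressed vector. The snippet below is a minimal, self-contained sketch of the quantization variant (quant_ddl); the toy tensor size is arbitrary, the names mirror the variables in the training scripts, and the float32 -> float16 -> float32 round trip stands in for the quantizer:
+
+```python
+import numpy as np
+import mindspore as ms
+import mindspore.ops as ops
+
+grads = ms.Tensor(np.random.randn(8).astype(np.float32))  # flattened local gradient (toy size)
+local_s = ops.zeros_like(grads)                           # error-feedback accumulator
+sigma, clip = 0.008, 5.0                                  # noise multiplier and gradient-norm bound
+
+noise = ms.Tensor(np.random.normal(0, sigma * clip, grads.shape), ms.float32)
+noisy = grads + noise                                              # differential-privacy perturbation
+residual = noisy - local_s                                         # part that still has to be communicated
+compressed = ops.cast(ops.cast(residual, ms.float16), ms.float32)  # lossy 16-bit quantization
+local_s += 0.5 * compressed                                        # fold half of the update back (error feedback)
+# `compressed` is the vector each worker would now all-reduce.
+```
+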
diff --git a/examples/community/compressed_and_private_decentralized_learning/data_load.py b/examples/community/compressed_and_private_decentralized_learning/data_load.py
new file mode 100644
index 0000000..47c6485
--- /dev/null
+++ b/examples/community/compressed_and_private_decentralized_learning/data_load.py
@@ -0,0 +1,161 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import os
+import cv2
+import numpy as np
+import mindspore as ms
+import mindspore.dataset as ds
+import mindspore.communication as comm
+
+
+def create_dataset(dataset_type, dataset_path, batch_size, train):
+    """Create a sharded dataset for the current rank.
+
+    Every rank loads its own shard (num_shards = world size, shard_id = rank),
+    so the workers train on disjoint partitions of the data.
+    """
+    rank_id = comm.get_rank()
+    rank_size = comm.get_group_size()
+
+    # ============================================================================
+    # CIFAR10
+    # ============================================================================
+    if dataset_type == 'CIFAR10':
+        if train:
+            dataset = ds.Cifar10Dataset(dataset_path, num_shards=rank_size, shard_id=rank_id, usage='train')
+        else:
+            dataset = ds.Cifar10Dataset(dataset_path, num_shards=rank_size, shard_id=rank_id, usage='test')
+
+        image_transforms = [
+            ds.vision.Rescale(1.0 / 255.0, 0),
+            ds.vision.HWC2CHW()
+        ]
+        label_transform = ds.transforms.TypeCast(ms.int32)
+        dataset = dataset.map(image_transforms, 'image')
+        dataset = dataset.map(label_transform, 'label')
+
+    # ============================================================================
+    # SVHN
+    # ============================================================================
+    elif dataset_type == 'SVHN':
+        if train:
+            dataset = ds.SVHNDataset(dataset_path, num_shards=rank_size, shard_id=rank_id, usage='train')
+        else:
+            dataset = ds.SVHNDataset(dataset_path, num_shards=rank_size, shard_id=rank_id, usage='test')
+
+        image_transforms = [
+            ds.vision.Rescale(1.0 / 255.0, 0),
+            # ds.vision.Normalize(mean=(0.1307,), std=(0.3081,)),
+            ds.vision.HWC2CHW()
+        ]
+        label_transform = ds.transforms.TypeCast(ms.int32)
+        dataset = dataset.map(image_transforms, 'image')
+        dataset = dataset.map(label_transform, 'label')
+
+    # ============================================================================
+    # MedMNIST
+    # ============================================================================
+    elif dataset_type == 'MedMNIST':
+        dataset_name = "pathmnist"  # num_classes = 9
+        data_file = os.path.join(dataset_path, f"{dataset_name}.npz")
+        data = np.load(data_file)
+
+        if train:
+            train_images = data["train_images"]
+            train_labels = data["train_labels"].reshape(-1)
+            my_accessible = MyAccessible(train_images, train_labels)
+            dataset = ds.GeneratorDataset(source=my_accessible, column_names=["image", "label"],
+                                          num_shards=rank_size, shard_id=rank_id)
+        else:
+            test_images = data["test_images"]
+            test_labels = data["test_labels"].reshape(-1)
+            my_accessible = MyAccessible(test_images, test_labels)
+            dataset = ds.GeneratorDataset(source=my_accessible, column_names=["image", "label"],
+                                          num_shards=rank_size, shard_id=rank_id)
+
+        image_transforms = [
+            ds.vision.Rescale(1.0 / 255.0, 0),
+            # ds.vision.Normalize(mean=(0.1307,), std=(0.3081,)),
+            ds.vision.HWC2CHW()
+        ]
+        label_transform = ds.transforms.TypeCast(ms.int32)
+        dataset = dataset.map(image_transforms, 'image')
+        dataset = dataset.map(label_transform, 'label')
+
+    # ============================================================================
+    # Flower102
+    # ============================================================================
+    elif dataset_type == 'Flower102':
+        if train:
+            dataset = ds.Flowers102Dataset(dataset_dir=dataset_path,
+                                           task="Classification",
+                                           usage="train",
+                                           decode=False)
+
+            image_transforms = [
+                ds.vision.Decode(),
+                ds.vision.Resize((224, 224)),
+                ds.vision.Rescale(1.0 / 255.0, 0),
+                ds.vision.HWC2CHW()
+            ]
+            label_transform = ds.transforms.TypeCast(ms.int32)
+            dataset = dataset.map(image_transforms, 'image')
+            dataset = dataset.map(label_transform, 'label')
+
+        else:
+            dataset = ds.Flowers102Dataset(dataset_dir=dataset_path,
+                                           task="Classification",
+                                           usage="test",
+                                           decode=True)
+
+            # Test images are already decoded (decode=True), but they still need
+            # the same resize/rescale pipeline so that they can be batched.
+            image_transforms = [
+                ds.vision.Resize((224, 224)),
+                ds.vision.Rescale(1.0 / 255.0, 0),
+                ds.vision.HWC2CHW()
+            ]
+            label_transform = ds.transforms.TypeCast(ms.int32)
+            dataset = dataset.map(image_transforms, 'image')
+            dataset = dataset.map(label_transform, 'label')
+
+    dataset = dataset.batch(batch_size)
+    return dataset
+
+
+class MyAccessible:
+    """Random-access wrapper fed to GeneratorDataset (used for MedMNIST)."""
+
+    def __init__(self, data, label):
+        self._data = data
+        self._label = label
+
+    def __getitem__(self, index):
+        return self._data[index], self._label[index]
+
+    def __len__(self):
+        return len(self._data)
+
+
+# Standalone generator helper for loading images from a folder; it is not used
+# by create_dataset and is kept for reference.
+def generator(indices, image_folder, labels):
+    for idx in indices:
+        image_path = os.path.join(image_folder, "image_{}.jpg".format(str(idx).zfill(5)))
+        image = cv2.imread(image_path)
+        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+        label = labels[idx]
+        yield image, label
+
+
+if __name__ == '__main__':
+    ms.set_context(mode=ms.GRAPH_MODE)
+    ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.DATA_PARALLEL, gradients_mean=True)
+    comm.init()
+    ms.set_seed(1)
+
+    dataset_path = "./dataset/CIFAR10"
+    dataset = create_dataset('CIFAR10', dataset_path, batch_size=32, train=True)
+    print(dataset.output_shapes())
+    print(dataset.get_dataset_size())
+    print(dataset.num_classes())
+
+
diff --git a/examples/community/compressed_and_private_decentralized_learning/dataset/.keep b/examples/community/compressed_and_private_decentralized_learning/dataset/.keep
new file mode 100644
index 0000000..e69de29
diff --git a/examples/community/compressed_and_private_decentralized_learning/model_load.py b/examples/community/compressed_and_private_decentralized_learning/model_load.py
new file mode 100644
index 0000000..9368e02
--- /dev/null
+++ b/examples/community/compressed_and_private_decentralized_learning/model_load.py
@@ -0,0 +1,148 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import mindspore.nn as nn


+class BasicBlock(nn.Cell):
+    """Basic residual block for ResNet-18 and ResNet-34."""
+
+    expansion = 1
+
+    def __init__(self, in_channels, out_channels, stride=1):
+        super(BasicBlock, self).__init__()
+
+        self.residual_branch = nn.SequentialCell(
+            nn.Conv2d(in_channels,
+                      out_channels,
+                      kernel_size=3,
+                      stride=stride,
+                      pad_mode='pad',
+                      padding=1,
+                      has_bias=False), nn.BatchNorm2d(out_channels),
+            nn.ReLU(),
+            nn.Conv2d(out_channels,
+                      out_channels * BasicBlock.expansion,
+                      kernel_size=3,
+                      pad_mode='pad',
+                      padding=1,
+                      has_bias=False),
+            nn.BatchNorm2d(out_channels * BasicBlock.expansion))
+
+        self.shortcut = nn.SequentialCell()
+
+        if stride != 1 or in_channels != BasicBlock.expansion * out_channels:
+            self.shortcut = nn.SequentialCell(
+                nn.Conv2d(in_channels,
+                          out_channels * BasicBlock.expansion,
+                          kernel_size=1,
+                          stride=stride,
+                          has_bias=False),
+                nn.BatchNorm2d(out_channels * BasicBlock.expansion))
+
+        self.relu = nn.ReLU()
+
+    def construct(self, x):
+        return self.relu(self.residual_branch(x) + self.shortcut(x))
+
+
+class ResNet(nn.Cell):
+    """ResNet with a 3x3 stem and no max-pool, suited to small (e.g. 32x32) inputs."""
+
+    def __init__(self, block, layers, num_classes=100, inter_layer=False):
+        super(ResNet, self).__init__()
+        self.inter_layer = inter_layer
+        self.in_channels = 64
+
+        self.conv1 = nn.SequentialCell(
+            nn.Conv2d(3, 64, kernel_size=3, pad_mode='pad', padding=1, has_bias=False),
+            nn.BatchNorm2d(64), nn.ReLU())
+
+        self.stage2 = self._make_layer(block, 64, layers[0], 1)
+        self.stage3 = self._make_layer(block, 128, layers[1], 2)
+        self.stage4 = self._make_layer(block, 256, layers[2], 2)
+        self.stage5 = self._make_layer(block, 512, layers[3], 2)
+        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
+        self.fc = nn.Dense(512 * block.expansion, num_classes)
+
+    def _make_layer(self, block, out_channels, num_blocks, stride):
+        """Build one ResNet stage ('layer' here means a stage such as conv2_x,
+        not a single network layer); one stage may contain more than one
+        residual block.
+
+        Args:
+            block: block type, basic block or bottleneck block
+            out_channels: output channel number of this stage
+            num_blocks: how many residual blocks this stage contains
+            stride: the stride of the first block of this stage
+
+        Return:
+            a resnet stage as an nn.SequentialCell
+        """
+
+        strides = [stride] + [1] * (num_blocks - 1)
+        layers = []
+        for stride in strides:
+            layers.append(block(self.in_channels, out_channels, stride))
+            self.in_channels = out_channels * block.expansion
+
+        return nn.SequentialCell(*layers)
+
+    def construct(self, x):
+        x = self.conv1(x)
+
+        if self.inter_layer:
+            x1 = self.stage2(x)
+            x2 = self.stage3(x1)
+            x3 = self.stage4(x2)
+            x4 = self.stage5(x3)
+            x = self.avg_pool(x4)
+            x = x.view(x.shape[0], -1)
+            x = self.fc(x)
+
+            return [x1, x2, x3, x4, x]
+        else:
+            x = self.stage2(x)
+            x = self.stage3(x)
+            x = self.stage4(x)
+            x = self.stage5(x)
+            x = self.avg_pool(x)
+            x = x.view(x.shape[0], -1)
+            x = self.fc(x)
+
+            return x
+
+
+def _resnet(arch, block, layers, pretrained, progress, **kwargs):
+    model = ResNet(block, layers, **kwargs)
+    # `pretrained` and `progress` are kept for interface compatibility; pretrained
+    # weights are not bundled with this example. A MindSpore checkpoint could be
+    # loaded here with mindspore.load_checkpoint()/mindspore.load_param_into_net().
+    return model
+
+
+def resnet18(pretrained=False, progress=True, **kwargs):
+    return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
+                   **kwargs)
+
+
+def resnet34(pretrained=False, progress=True, **kwargs):
+    return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
+                   **kwargs)
+
+
+if __name__ == '__main__':
+    model = resnet18(num_classes=10)
+    print(model)
+
+
+    
\ No newline at end of file
diff --git a/examples/community/compressed_and_private_decentralized_learning/quantified_private_decentralized_learning.py b/examples/community/compressed_and_private_decentralized_learning/quantified_private_decentralized_learning.py
new file mode 100644
index 0000000..f692881
--- /dev/null
+++ b/examples/community/compressed_and_private_decentralized_learning/quantified_private_decentralized_learning.py
@@ -0,0 +1,149 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import numpy as np
+import mindspore as ms
+import mindspore.ops as ops
+from mindspore import nn
+import mindspore.communication as comm
+
+from model_load import resnet18
+from data_load import create_dataset
+
+
+# ============================================================================
+# Helper functions
+# ============================================================================
+
+# Flatten a tuple of tensors into a single 1-D vector
+def flatten(tensors):
+    if len(tensors) == 1:
+        flat = tensors[0].view(-1)
+    else:
+        flat = ops.cat([t.view(-1) for t in tensors], axis=0)
+    return flat
+
+
+# Split a flat vector back into tensors shaped like the reference tensors
+def unflatten(flat, tensors):
+    outputs = []
+    offset = 0
+    for tensor in tensors:
+        numel = tensor.numel()
+        outputs.append(flat.narrow(0, offset, numel).view_as(tensor))
+        offset += numel
+    return tuple(outputs)
+
+
+# Draw i.i.d. Gaussian noise for the differential-privacy perturbation
+def generate_gaussian_noise(length, mean, std):
+    gaussian_noise = ms.Tensor(np.random.normal(mean, std, length), ms.float32)
+    return gaussian_noise
+
+
+# Forward pass returning the loss and the network output
+def forward_fn(data, label):
+    output = model(data)
+    loss = criterion(output, label)
+    return loss, output
+
+
+# ============================================================================
+# Distributed and Communication Initialization
+# ============================================================================
+ms.set_context(mode=ms.GRAPH_MODE)
+ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.DATA_PARALLEL, gradients_mean=True)
+comm.init()
+ms.set_seed(1)
+np.random.seed(1)
+gpu = comm.get_rank()
+num_workers = comm.get_group_size()
+
+
+if __name__ == "__main__":
+    """
+    Users only need to modify the following region for their own training task.
+    """
+    # ==================================================================================================
+    epochs = 10
+    lr = 0.01
+    sigma = 0.008
+    train_batch_size = 64
+    test_batch_size = 100
+
+    # Model and Optimizer Initialization
+    model = resnet18(num_classes=10)
+    criterion = nn.CrossEntropyLoss()
+    optimizer = nn.SGD(model.trainable_params(), learning_rate=lr)
+
+    # Estimated bound on the gradient l2-norm; it scales the Gaussian noise
+    # (gradients are not explicitly clipped in this example)
+    gradient_norm_bound_estimate = 5
+
+    # Data Load
+    dataset_path = "./dataset/CIFAR10"
+    train_dataset = create_dataset('CIFAR10', dataset_path, batch_size=train_batch_size, train=True)
+    test_dataset = create_dataset('CIFAR10', dataset_path, batch_size=test_batch_size, train=False)
+    # ==================================================================================================
+    """
+    Users only need to modify the above region for their own training task.
+    """
+
+    # Define Gradient Computation and Communication
+    grad_fn = ms.value_and_grad(forward_fn, None, model.trainable_params(), has_aux=True)
+    grad_reducer = nn.DistributedGradReducer(optimizer.parameters)
+
+    # Training and Testing Pipeline
+    for epoch in range(epochs):
+        # Train pipeline
+        model.set_train()
+        for iteration, (inputs, targets) in enumerate(train_dataset):
+            (loss, _), grads = grad_fn(inputs, targets)  # loss and gradients
+            flatten_grads = flatten(grads)
+
+            # Lazily size the compression state on the first iteration
+            if epoch == 0 and iteration == 0:
+                length = flatten_grads.numel()
+                local_s = ms.Tensor(np.zeros(length), ms.float32)
+                global_s = ms.Tensor(np.zeros(length), ms.float32)
+
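+            # ----------------------------------------------------------------------------
+            # Quantized private communication step (quant_ddl):
+            #   1. perturb the flattened gradient with Gaussian noise whose std is
+            #      sigma * gradient_norm_bound_estimate (differential privacy);
+            #   2. compress the residual between the noisy gradient and the local
+            #      error-feedback accumulator local_s with a float32 -> float16 ->
+            #      float32 round trip, halving the communicated payload;
+            #   3. all-reduce only the compressed residual and fold half of it into
+            #      the local/global accumulators, so quantization error is corrected
+            #      in later iterations rather than lost.
+            # ----------------------------------------------------------------------------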
+            standard_vr = sigma * gradient_norm_bound_estimate
+            noise = generate_gaussian_noise(flatten_grads.numel(), 0, standard_vr)
+            flatten_grads += noise
+            before_quantify = flatten_grads.copy() - local_s.copy()
+            local_v = ops.cast(ops.cast(before_quantify, ms.float16), ms.float32)
+            local_s += 0.5 * local_v.copy()
+            unflattened_local_v = unflatten(local_v, grads)
+            all_reduced_local_v = grad_reducer(unflattened_local_v)
+            global_v = global_s.copy() + flatten(all_reduced_local_v)
+            global_s += 0.5 * flatten(all_reduced_local_v)
+            unflattened_global_v = unflatten(global_v, grads)
+            optimizer(unflattened_global_v)  # apply the aggregated update
+
+        print('Epoch: {}, Training Loss: {:.4f}'.format(epoch, loss.item()))
+
+        # Test pipeline
+        model.set_train(False)
+        total = 0
+        correct = 0
+        for step, (inputs, targets) in enumerate(test_dataset):
+            outputs = model(inputs)
+            _, predicted = ops.max(outputs, 1)
+            total += len(targets)
+            correct += (predicted == targets).sum().item()
+        print('Epoch: {}, Accuracy: {:.4f}'.format(epoch, 100 * correct / total))
+
+
diff --git a/examples/community/compressed_and_private_decentralized_learning/random_sparsified_private_decentralized_learning.py b/examples/community/compressed_and_private_decentralized_learning/random_sparsified_private_decentralized_learning.py
new file mode 100644
index 0000000..8a207ed
--- /dev/null
+++ b/examples/community/compressed_and_private_decentralized_learning/random_sparsified_private_decentralized_learning.py
@@ -0,0 +1,153 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import numpy as np
+import mindspore as ms
+import mindspore.ops as ops
+from mindspore import nn
+import mindspore.communication as comm
+
+from model_load import resnet18
+from data_load import create_dataset
+
+
+# ============================================================================
+# Helper functions
+# ============================================================================
+
+# Flatten a tuple of tensors into a single 1-D vector
+def flatten(tensors):
+    if len(tensors) == 1:
+        flat = tensors[0].view(-1)
+    else:
+        flat = ops.cat([t.view(-1) for t in tensors], axis=0)
+    return flat
+
+
+# Split a flat vector back into tensors shaped like the reference tensors
+def unflatten(flat, tensors):
+    outputs = []
+    offset = 0
+    for tensor in tensors:
+        numel = tensor.numel()
+        outputs.append(flat.narrow(0, offset, numel).view_as(tensor))
+        offset += numel
+    return tuple(outputs)
+
+
+# Draw i.i.d. Gaussian noise for the differential-privacy perturbation
+def generate_gaussian_noise(length, mean, std):
+    gaussian_noise = ms.Tensor(np.random.normal(mean, std, length), ms.float32)
+    return gaussian_noise
+
+
+# Forward pass returning the loss and the network output
+def forward_fn(data, label):
+    output = model(data)
+    loss = criterion(output, label)
+    return loss, output
+
+
+# ============================================================================
+# Distributed and Communication Initialization
+# ============================================================================
+ms.set_context(mode=ms.GRAPH_MODE)
+ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.DATA_PARALLEL, gradients_mean=True)
+comm.init()
+ms.set_seed(1)
+np.random.seed(1)
+gpu = comm.get_rank()
+num_workers = comm.get_group_size()
+
+
+if __name__ == "__main__":
+    """
+    Users only need to modify the following region for their own training task.
+    """
+    # ==================================================================================================
+    epochs = 10
+    lr = 0.01
+    sigma = 0.008
+    train_batch_size = 64
+    test_batch_size = 100
+
+    # Model and Optimizer Initialization
+    model = resnet18(num_classes=10)
+    criterion = nn.CrossEntropyLoss()
+    optimizer = nn.SGD(model.trainable_params(), learning_rate=lr)
+
+    # Estimated bound on the gradient l2-norm; it scales the Gaussian noise
+    # (gradients are not explicitly clipped in this example)
+    gradient_norm_bound_estimate = 5
+
+    # Data Load
+    dataset_path = "./dataset/CIFAR10"
+    train_dataset = create_dataset('CIFAR10', dataset_path, batch_size=train_batch_size, train=True)
+    test_dataset = create_dataset('CIFAR10', dataset_path, batch_size=test_batch_size, train=False)
+    # ==================================================================================================
+    """
+    Users only need to modify the above region for their own training task.
+    """
+
+    # Define Gradient Computation and Communication
+    grad_fn = ms.value_and_grad(forward_fn, None, model.trainable_params(), has_aux=True)
+    grad_reducer = nn.DistributedGradReducer(optimizer.parameters)
+
+    # Training and Testing Pipeline
+    for epoch in range(epochs):
+        # Train pipeline
+        model.set_train()
+        for iteration, (inputs, targets) in enumerate(train_dataset):
+            (loss, _), grads = grad_fn(inputs, targets)  # loss and gradients
+            flatten_grads = flatten(grads)
+
+            # Lazily size the compression state on the first iteration
+            if epoch == 0 and iteration == 0:
+                length = flatten_grads.numel()
+                local_s = ms.Tensor(np.zeros(length), ms.float32)
+                global_s = ms.Tensor(np.zeros(length), ms.float32)
+
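+            # ----------------------------------------------------------------------------
+            # Randomly sparsified private communication step (randspar_ddl): after the
+            # Gaussian perturbation, a uniformly random 10% of the coordinates of the
+            # residual (noisy gradient minus the error-feedback accumulator local_s)
+            # are zeroed out before the all-reduce; the error-feedback accumulators
+            # re-inject what the sparsifier dropped in later iterations.
+            # ----------------------------------------------------------------------------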
+            standard_vr = sigma * gradient_norm_bound_estimate
+            noise = generate_gaussian_noise(flatten_grads.numel(), 0, standard_vr)
+            flatten_grads += noise
+            before_sparsify = flatten_grads.copy() - local_s.copy()
+            num_zeros = int(length * 0.1)
+            zero_indices_np = np.random.choice(length, num_zeros, replace=False)
+            zero_indices_ms = ms.Tensor(zero_indices_np, ms.int32)
+            updates = ms.Tensor(np.zeros(num_zeros, dtype=np.float32), ms.float32)
+            local_v = ops.tensor_scatter_elements(before_sparsify, zero_indices_ms, updates, axis=0)
+            local_s += 0.5 * local_v.copy()
+            unflattened_local_v = unflatten(local_v, grads)
+            all_reduced_local_v = grad_reducer(unflattened_local_v)
+            global_v = global_s.copy() + flatten(all_reduced_local_v)
+            global_s += 0.5 * flatten(all_reduced_local_v)
+            unflattened_global_v = unflatten(global_v, grads)
+            optimizer(unflattened_global_v)  # apply the aggregated update
+
+        print('Epoch: {}, Training Loss: {:.4f}'.format(epoch, loss.item()))
+
+        # Test pipeline
+        model.set_train(False)
+        total = 0
+        correct = 0
+        for step, (inputs, targets) in enumerate(test_dataset):
+            outputs = model(inputs)
+            _, predicted = ops.max(outputs, 1)
+            total += len(targets)
+            correct += (predicted == targets).sum().item()
+        print('Epoch: {}, Accuracy: {:.4f}'.format(epoch, 100 * correct / total))
+
+
diff --git a/examples/community/compressed_and_private_decentralized_learning/top_k_sparsified_private_decentralized_learning.py b/examples/community/compressed_and_private_decentralized_learning/top_k_sparsified_private_decentralized_learning.py
new file mode 100644
index 0000000..521a108
--- /dev/null
+++ b/examples/community/compressed_and_private_decentralized_learning/top_k_sparsified_private_decentralized_learning.py
@@ -0,0 +1,153 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import numpy as np
+import mindspore as ms
+import mindspore.ops as ops
+from mindspore import nn
+import mindspore.communication as comm
+
+from model_load import resnet18
+from data_load import create_dataset
+
+
+# ============================================================================
+# Helper functions
+# ============================================================================
+
+# Flatten a tuple of tensors into a single 1-D vector
+def flatten(tensors):
+    if len(tensors) == 1:
+        flat = tensors[0].view(-1)
+    else:
+        flat = ops.cat([t.view(-1) for t in tensors], axis=0)
+    return flat
+
+
+# Split a flat vector back into tensors shaped like the reference tensors
+def unflatten(flat, tensors):
+    outputs = []
+    offset = 0
+    for tensor in tensors:
+        numel = tensor.numel()
+        outputs.append(flat.narrow(0, offset, numel).view_as(tensor))
+        offset += numel
+    return tuple(outputs)
+
+
+# Draw i.i.d. Gaussian noise for the differential-privacy perturbation
+def generate_gaussian_noise(length, mean, std):
+    gaussian_noise = ms.Tensor(np.random.normal(mean, std, length), ms.float32)
+    return gaussian_noise
+
+
+# Forward pass returning the loss and the network output
+def forward_fn(data, label):
+    output = model(data)
+    loss = criterion(output, label)
+    return loss, output
+
+
+# ============================================================================
+# Distributed and Communication Initialization
+# ============================================================================
+ms.set_context(mode=ms.GRAPH_MODE)
+ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.DATA_PARALLEL, gradients_mean=True)
+comm.init()
+ms.set_seed(1)
+np.random.seed(1)
+gpu = comm.get_rank()
+num_workers = comm.get_group_size()
+
+
+if __name__ == "__main__":
+    """
+    Users only need to modify the following region for their own training task.
+    """
+    # ==================================================================================================
+    epochs = 10
+    lr = 0.01
+    sigma = 0.008
+    train_batch_size = 64
+    test_batch_size = 100
+
+    # Model and Optimizer Initialization
+    model = resnet18(num_classes=10)
+    criterion = nn.CrossEntropyLoss()
+    optimizer = nn.SGD(model.trainable_params(), learning_rate=lr)
+
+    # Estimated bound on the gradient l2-norm; it scales the Gaussian noise
+    # (gradients are not explicitly clipped in this example)
+    gradient_norm_bound_estimate = 5
+
+    # Data Load
+    dataset_path = "./dataset/CIFAR10"
+    train_dataset = create_dataset('CIFAR10', dataset_path, batch_size=train_batch_size, train=True)
+    test_dataset = create_dataset('CIFAR10', dataset_path, batch_size=test_batch_size, train=False)
+    # ==================================================================================================
+    """
+    Users only need to modify the above region for their own training task.
+    """
+
+    # Define Gradient Computation and Communication
+    grad_fn = ms.value_and_grad(forward_fn, None, model.trainable_params(), has_aux=True)
+    grad_reducer = nn.DistributedGradReducer(optimizer.parameters)
+
+    # Training and Testing Pipeline
+    for epoch in range(epochs):
+        # Train pipeline
+        model.set_train()
+        for iteration, (inputs, targets) in enumerate(train_dataset):
+            (loss, _), grads = grad_fn(inputs, targets)  # loss and gradients
+            flatten_grads = flatten(grads)
+
+            # Lazily size the compression state on the first iteration
+            if epoch == 0 and iteration == 0:
+                length = flatten_grads.numel()
+                local_s = ms.Tensor(np.zeros(length), ms.float32)
+                global_s = ms.Tensor(np.zeros(length), ms.float32)
+
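+            # ----------------------------------------------------------------------------
+            # Top-K sparsified private communication step (topkspar_ddl): after the
+            # Gaussian perturbation, the 10% of coordinates of the residual with the
+            # smallest magnitude are zeroed out (ops.top_k on the negated absolute
+            # values selects them), so only the largest 90% of entries survive the
+            # all-reduce; the error-feedback accumulators re-inject the dropped mass
+            # in later iterations.
+            # ----------------------------------------------------------------------------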
+            standard_vr = sigma * gradient_norm_bound_estimate
+            noise = generate_gaussian_noise(flatten_grads.numel(), 0, standard_vr)
+            flatten_grads += noise
+            before_sparsify = flatten_grads.copy() - local_s.copy()
+            num_zeros = int(length * 0.1)
+            _, indices = ops.top_k(-ops.abs(before_sparsify), num_zeros)
+            zero_indices_ms = ops.cast(indices, ms.int32)
+            updates = ms.Tensor(np.zeros(num_zeros, dtype=np.float32), ms.float32)
+            local_v = ops.tensor_scatter_elements(before_sparsify, zero_indices_ms, updates, axis=0)
+            local_s += 0.5 * local_v.copy()
+            unflattened_local_v = unflatten(local_v, grads)
+            all_reduced_local_v = grad_reducer(unflattened_local_v)
+            global_v = global_s.copy() + flatten(all_reduced_local_v)
+            global_s += 0.5 * flatten(all_reduced_local_v)
+            unflattened_global_v = unflatten(global_v, grads)
+            optimizer(unflattened_global_v)  # apply the aggregated update
+
+        print('Epoch: {}, Training Loss: {:.4f}'.format(epoch, loss.item()))
+
+        # Test pipeline
+        model.set_train(False)
+        total = 0
+        correct = 0
+        for step, (inputs, targets) in enumerate(test_dataset):
+            outputs = model(inputs)
+            _, predicted = ops.max(outputs, 1)
+            total += len(targets)
+            correct += (predicted == targets).sum().item()
+        print('Epoch: {}, Accuracy: {:.4f}'.format(epoch, 100 * correct / total))
+
+
-- 
Gitee