1 Star 0 Fork 0

TangXiangjie/FedSTSS

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
fedavgclient.py 4.81 KB
一键复制 编辑 原始数据 按行查看 历史
import pickle # pickle 用于序列化和反序列化Python对象。
import sys # 用于访问与Python解释器交互的功能。
import threading # 用于并行执行。
import numpy as np # 用于数学和数组操作。
from flask import Flask, request # 用于创建Web服务器和处理HTTP请求
import tensorflow as tf
# 设置可见的设备为 CPU
tf.config.set_visible_devices([], 'GPU')
import mnistcommon
import flcommon
import time_logger
from config import ClientConfig
# 是自定义模块,包含联邦学习的通用函数、MNIST数据处理、时间记录和配置设置。
config = ClientConfig(int(sys.argv[1])) # 通过命令行参数获取客户端索引,并加载相应的配置。
client_datasets = mnistcommon.load_train_dataset(config.number_of_clients, permute=True)
# 调用 mnistcommon.load_train_dataset 加载训练数据,根据客户端数进行分割并可能进行随机排序。
api = Flask(__name__)
# 创建一个Flask应用实例,用于接收和发送HTTP请求
total_upload_cost = 0
total_download_cost = 0
# 跟踪上传和下载的数据成本。
training_round = 0
# 训练和通信函数:
# start_next_round 函数处理接收到的模型权重,更新本地模型,执行一轮训练,然后发送更新回服务器。
# 使用 pickle.loads(data) 加载接收的模型权重。
# 使用 model.set_weights 和 model.fit 更新和训练模型。
# 将训练后的模型权重序列化并发送。
# 更新上传和下载成本统计。
def start_next_round(data):
time_logger.client_start()
x_train, y_train = client_datasets[config.client_index][0], client_datasets[config.client_index][1]
model = mnistcommon.get_model()
global training_round
if training_round != 0:
round_weight = pickle.loads(data)
model.set_weights(round_weight)
print(
f"Model: FedAvg, "
f"Round: {training_round + 1}/{config.training_rounds}, "
f"Client {config.client_index + 1}/{config.number_of_clients}, "
f"Dataset Size: {len(x_train)}")
model.fit(x_train, y_train, epochs=config.epochs, batch_size=config.batch_size, verbose=config.verbose,
validation_split=config.validation_split)
round_weight = np.array(model.get_weights())
layers = []
for index, value in enumerate(round_weight):
layers.append(value.astype('float64'))
pickle_model = pickle.dumps(np.array(layers))
flcommon.send_to_fedavg_server(pickle_model, config)
len_serialized_model = len(pickle_model)
global total_upload_cost
total_upload_cost += len_serialized_model
print(f"[Upload] Size of the object to send to server is {len_serialized_model}")
print(f"Sent {training_round} to server")
global total_download_cost
print(f"[DOWNLOAD] Total download cost so far: {total_download_cost}")
print(f"[UPLOAD] Total upload cost so far: {total_upload_cost}")
training_round += 1
print(f"********************** Round {training_round} completed **********************")
print("Waiting to receive response from server...")
time_logger.client_idle()
# Flask 路由处理:
# /recv 用于接收来自服务器的数据。
# /start 用于开始训练流程,触发第一轮训练。
# 为每个网络请求分配一个新线程,以非阻塞方式处理。
@api.route('/recv', methods=['POST'])
def recv():
my_thread = threading.Thread(target=recv_thread, args=(request.data, ))
my_thread.start()
return {"response": "ok"}
@api.route('/start', methods=['GET'])
def start():
time_logger.start_training()
my_thread = threading.Thread(target=start_next_round, args=(0, ))
my_thread.start()
return {"response": "ok"}
# 接收线程函数:
# recv_thread 处理从服务器接收的数据,更新全局下载成本,判断是否完成所有训练轮次,如果未完成则继续训练。
def recv_thread(data):
global total_download_cost
total_download_cost += len(data)
global training_round
if config.training_rounds == training_round:
time_logger.finish_training()
time_logger.print_result()
print(f"[DOWNLOAD] Total download cost so far: {total_download_cost}")
print(f"[UPLOAD] Total upload cost so far: {total_upload_cost}")
print("Training finished.")
return
start_next_round(data,)
# api.run 启动Flask应用,监听配置中指定的IP地址和端口。端口通过命令行参数动态确定,这允许多个客户端实例在同一台机器上运行而不冲突。
api.run(host=flcommon.get_ip(config), port=config.client_base_port + int(sys.argv[1]), debug=True, threaded=True)
# 错误处理和日志记录
# time_logger 用于记录训练开始、结束和空闲时的时间,帮助分析性能和调试。
# 所有重要的训练和网络活动都打印在控制台上,这有助于跟踪进度和诊断问题。
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/tangxiangjie/fedSTSS.git
git@gitee.com:tangxiangjie/fedSTSS.git
tangxiangjie
fedSTSS
FedSTSS
master

搜索帮助

23e8dbc6 1850385 7e0993f3 1850385