1 Star 0 Fork 0

TangXiangjie/FedSTSS

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
fedstss_server.py 5.25 KB
一键复制 编辑 原始数据 按行查看 历史
import pickle
import sys
import threading
import numpy as np
import requests
from flask import Flask, request
from requests_toolbelt.adapters import source
import time_logger
from config import ServerConfig
import tensorflow as tf
# 设置可见的设备为 CPU
tf.config.set_visible_devices([], 'GPU')
config = ServerConfig(int(sys.argv[1]))
api = Flask(__name__)
training_round = 0
clients_secret = []
clients_duration = []
total_download_cost = 0
total_upload_cost = 0
def recv_thread(clients_secret, data, remote_addr):
time_logger.server_received() # 记录接受客户端发来的份额时间戳
global total_download_cost
total_download_cost += len(data)
print(f"[DOWNLOAD] Secret of {remote_addr} received. size: {len(data)}")
secret = pickle.loads(data) # 反序列化接受到来自各个客户端的份额
# print(secret) # 这是一个字典
secret_list = []
# 字典转列表
for key in secret:
array = secret[key]
# 将数组转换为列表并添加到 all_servers 列表中
secret_list.append(array.tolist())
clients_secret.append(secret_list)
print(len(clients_secret)) # 5个客户端
print(len(clients_secret[0])) # 6层模型参数
# clients_secret.append(secret) # 添加到客户端秘密份额列表
# 这是一个全局变量 进程级变量 len(clients_secret)=number_of_clients
print(f"[SECRET] Secret opened successfully.")
if len(clients_secret) != config.number_of_clients:
return
time_logger.server_start() # 收集够了 开始部分加和 记录时间戳
# ****************************
# 初始化结果列表
model = []
# 遍历所有 clients_secret 并累加每个元组的值
for i in range(len(clients_secret[0])):
if isinstance(clients_secret[0][i][0], list): # 检查是否是二维列表
result_row = []
for j in range(len(clients_secret[0][i])):
result_sub_row = []
for k in range(len(clients_secret[0][i][j])):
key = clients_secret[0][i][j][k][0]
value_sum = sum(clients_secret_list[i][j][k][1] for clients_secret_list in clients_secret)
result_sub_row.append((key, value_sum))
result_row.append(result_sub_row)
model.append(result_row)
else: # 一维列表
result_row = []
for j in range(len(clients_secret[0][i])):
key = clients_secret[0][i][j][0]
value_sum = sum(clients_secret_list[i][j][1] for clients_secret_list in clients_secret)
result_row.append((key, value_sum))
model.append(result_row)
# *************************
pickle_model = pickle.dumps(model)
len_dumped_model = len(pickle_model)
time_logger.server_start_upload() # 子服务器向聚合服务器发送初步加和的秘密份额
global total_upload_cost
total_upload_cost += len(pickle_model)
url = f'http://{config.master_server_address}:{config.master_server_port}/recv'
# 构造用于发送数据的 URL,
# 其中 config.master_server_address 和 config.master_server_port 分别是聚合服务器的地址和端口
# /recv 是聚合服务器上用于接收数据的路由。
s = requests.Session()
# 创建一个新的请求会话。
new_source = source.SourceAddressAdapter(config.server_address)
# 创建一个地址适配器,这个适配器指定了发送请求时使用的源地址(子服务器的地址)。config.server_address 是子服务器的IP地址。
s.mount('http://', new_source)
# 将创建的源地址适配器挂载到会话的 HTTP 通信上。这意味着所有通过这个会话发出的 HTTP 请求都将使用指定的源地址。
print(s.post(url, pickle_model).json())
# 通过会话发送一个 POST 请求到聚合服务器,请求的内容是序列化的模型数据 pickle_model。
# 调用 .json() 方法来解析返回的 JSON 数据,并打印出来。这通常用于获取聚合服务器处理请求的结果。
clients_secret.clear()
# 清空 clients_secret 列表。这是因为在发送完数据后,当前的秘密数据已经不再需要,清空列表为接收下一轮的数据做准备。
global training_round
training_round += 1
print(f"[UPLOAD] Sent aggregated weights to the master, size: {len_dumped_model}")
print(f"[DOWNLOAD] Total download cost so far: {total_download_cost}")
print(f"[UPLOAD] Total upload cost so far: {total_upload_cost}")
print(f"********************** [ROUND] Round {training_round} completed **********************")
time_logger.server_idle()
# 客户端请求/recv 传来模型
# recv_thread线程里边收集
# 达到足够数量继续加和
# 然后新建会话 绑定源地址 发给聚合服务器
@api.route('/recv', methods=['POST'])
def recv():
# recv到post里边的内容,并且作为参数传给线程
my_thread = threading.Thread(target=recv_thread, args=(clients_secret,
request.data, request.remote_addr))
my_thread.start()
return {"response": "ok"}
api.run(host=config.server_address, port=int(config.server_base_port) + int(sys.argv[1]), debug=True, threaded=True)
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/tangxiangjie/fedSTSS.git
git@gitee.com:tangxiangjie/fedSTSS.git
tangxiangjie
fedSTSS
FedSTSS
master

搜索帮助

23e8dbc6 1850385 7e0993f3 1850385