From b86326fdf6ae09b07a34f6b318ecf97eaa3110c6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E5=87=AF=E5=AE=87?= Date: Fri, 10 Oct 2025 17:15:12 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9PPO=E5=8A=A0=E8=BD=BD?= =?UTF-8?q?=E6=A8=A1=E5=9E=8B=E7=9B=B8=E5=85=B3=E9=80=BB=E8=BE=91=E4=BD=BF?= =?UTF-8?q?=E5=85=B6=E6=AD=A3=E5=B8=B8=E8=AE=AD=E7=BB=83&=E6=9B=B4?= =?UTF-8?q?=E6=96=B0=E6=80=A7=E8=83=BD=E6=95=B0=E6=8D=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- PyTorch/built-in/rl/PPO_for_Pytorch/PPO.py | 5 +++-- PyTorch/built-in/rl/PPO_for_Pytorch/README.md | 3 ++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/PyTorch/built-in/rl/PPO_for_Pytorch/PPO.py b/PyTorch/built-in/rl/PPO_for_Pytorch/PPO.py index fee85271ef..901f397d1d 100644 --- a/PyTorch/built-in/rl/PPO_for_Pytorch/PPO.py +++ b/PyTorch/built-in/rl/PPO_for_Pytorch/PPO.py @@ -269,8 +269,9 @@ class PPO: torch.save(self.policy_old.state_dict(), checkpoint_path) def load(self, checkpoint_path): - self.policy_old.load_state_dict(torch.load(checkpoint_path, map_location=lambda storage, loc: storage)) - self.policy.load_state_dict(torch.load(checkpoint_path, map_location=lambda storage, loc: storage)) + device_npu = torch.device('npu') + self.policy_old.load_state_dict(torch.load(checkpoint_path, map_location=device_npu, weights_only=False)) + self.policy.load_state_dict(torch.load(checkpoint_path, map_location=device_npu, weights_only=False)) diff --git a/PyTorch/built-in/rl/PPO_for_Pytorch/README.md b/PyTorch/built-in/rl/PPO_for_Pytorch/README.md index 361f17c526..4cf87788c9 100644 --- a/PyTorch/built-in/rl/PPO_for_Pytorch/README.md +++ b/PyTorch/built-in/rl/PPO_for_Pytorch/README.md @@ -131,13 +131,14 @@ | NAME | FPS | MAX Training TimeSteps | Average Reward | |--------------| ------ | ---------------------- | -------------- | | 1p-竞品V | 585.37 | 3000000 | 197.75 | -| 1p-NPU-Atlas 800T A2 | 284.02 | 3000000 | 256.06 | +| 1p-NPU-Atlas 800T A2 | 284.02 | 3000000 | 240 | 说明:上表为历史数据,仅供参考。2025年5月10日更新的性能数据如下: | NAME | 精度类型 | FPS | | :------ |:-------:|:------:| | 1p-竞品 | FP16 | 585.37 | | 1p-Atlas 900 A2 PoDc | FP16 | 413.79 | +| 1p-Atlas 800T A2 | FP16 | 336.84 | # 公网地址说明 无。 -- Gitee