diff --git a/PyTorch/built-in/rl/PPO_for_Pytorch/PPO.py b/PyTorch/built-in/rl/PPO_for_Pytorch/PPO.py
index fee85271ef11675f01082b97cb10b3bdffdaaf7f..901f397d1d32a7c2a088733b5142f3d66ad3304b 100644
--- a/PyTorch/built-in/rl/PPO_for_Pytorch/PPO.py
+++ b/PyTorch/built-in/rl/PPO_for_Pytorch/PPO.py
@@ -269,8 +269,9 @@ class PPO:
         torch.save(self.policy_old.state_dict(), checkpoint_path)
 
     def load(self, checkpoint_path):
-        self.policy_old.load_state_dict(torch.load(checkpoint_path, map_location=lambda storage, loc: storage))
-        self.policy.load_state_dict(torch.load(checkpoint_path, map_location=lambda storage, loc: storage))
+        device_npu = torch.device('npu')
+        self.policy_old.load_state_dict(torch.load(checkpoint_path, map_location=device_npu, weights_only=False))
+        self.policy.load_state_dict(torch.load(checkpoint_path, map_location=device_npu, weights_only=False))
 
 
 
diff --git a/PyTorch/built-in/rl/PPO_for_Pytorch/README.md b/PyTorch/built-in/rl/PPO_for_Pytorch/README.md
index 361f17c526158461f6d7d3e15317a829d5ecbb76..4cf87788c903d5438a36288573b164f3f2ffc350 100644
--- a/PyTorch/built-in/rl/PPO_for_Pytorch/README.md
+++ b/PyTorch/built-in/rl/PPO_for_Pytorch/README.md
@@ -131,13 +131,14 @@
 | NAME | FPS | MAX Training TimeSteps | Average Reward |
 |--------------| ------ | ---------------------- | -------------- |
 | 1p-竞品V | 585.37 | 3000000 | 197.75 |
-| 1p-NPU-Atlas 800T A2 | 284.02 | 3000000 | 256.06 |
+| 1p-NPU-Atlas 800T A2 | 284.02 | 3000000 | 240 |
 
 说明:上表为历史数据,仅供参考。2025年5月10日更新的性能数据如下:
 
 | NAME | 精度类型 | FPS |
 | :------ |:-------:|:------:|
 | 1p-竞品 | FP16 | 585.37 |
 | 1p-Atlas 900 A2 PoDc | FP16 | 413.79 |
+| 1p-Atlas 800T A2 | FP16 | 336.84 |
 # 公网地址说明
 无。
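
For readers applying the same change elsewhere, the sketch below (not part of the patch) shows the intent of the new `load()` logic in a device-agnostic form: map the checkpoint to the Ascend NPU when the `torch_npu` plugin exposes `torch.npu`, otherwise fall back to CPU, and pass `weights_only=False` explicitly since the default flipped to `True` in PyTorch 2.6. The helper name `load_state` is hypothetical, and a PyTorch version that accepts the `weights_only` keyword (>= 1.13) is assumed.

```python
# Minimal sketch, not part of the patch: a device-agnostic variant of the
# patched load() logic. `load_state` is a hypothetical helper; torch.npu is
# assumed to be provided by the torch_npu plugin on Ascend machines.
import torch


def load_state(policy, policy_old, checkpoint_path):
    # Mirror the patch: target the NPU when torch_npu is installed and an NPU
    # device is visible; otherwise map tensors onto the CPU so the sketch
    # still runs on machines without Ascend hardware.
    if hasattr(torch, "npu") and torch.npu.is_available():
        device = torch.device("npu")
    else:
        device = torch.device("cpu")

    # weights_only=False matches the patched calls: the checkpoint is trusted,
    # and PyTorch >= 2.6 defaults weights_only to True, which rejects pickled
    # objects beyond plain tensors.
    state_dict = torch.load(checkpoint_path, map_location=device,
                            weights_only=False)
    policy_old.load_state_dict(state_dict)
    policy.load_state_dict(state_dict)
```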