1 Star 0 Fork 1

赵子健/强化学习期中作业_对DQN进行改进

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
utils_memory.py 2.04 KB
一键复制 编辑 原始数据 按行查看 历史
赵子健 提交于 3年前 . code -version1
from typing import (
Tuple,
)
import torch
from utils_types import (
BatchAction,
BatchDone,
BatchNext,
BatchReward,
BatchState,
TensorStack5,
TorchDevice,
)
class ReplayMemory(object):
    """Fixed-capacity ring buffer of Atari-style transitions for DQN.

    Each slot stores a folded stack of 5 consecutive 84x84 uint8 frames,
    so frames [0:4] form the state and frames [1:5] form the next state.
    Once full, new transitions overwrite the oldest ones.
    """

    def __init__(
            self,
            channels: int,
            capacity: int,
            device: TorchDevice,
            full_sink: bool = True,
    ) -> None:
        """Pre-allocate storage for ``capacity`` transitions.

        When ``full_sink`` is true the whole buffer lives on ``device``;
        otherwise tensors stay on CPU and batches are moved on sampling.
        """
        self.__device = device
        self.__capacity = capacity
        self.__size = 0
        self.__pos = 0

        def place(tensor):
            # Optionally park the whole buffer on the target device up front.
            return tensor.to(device) if full_sink else tensor

        self.__m_states = place(torch.zeros(
            (capacity, channels, 84, 84), dtype=torch.uint8))
        self.__m_actions = place(torch.zeros((capacity, 1), dtype=torch.long))
        self.__m_rewards = place(torch.zeros((capacity, 1), dtype=torch.int8))
        self.__m_dones = place(torch.zeros((capacity, 1), dtype=torch.bool))

    def push(
            self,
            folded_state: TensorStack5,
            action: int,
            reward: int,
            done: bool,
    ) -> None:
        """Write one transition at the cursor, overwriting the oldest slot."""
        cursor = self.__pos
        self.__m_states[cursor] = folded_state
        self.__m_actions[cursor, 0] = action
        self.__m_rewards[cursor, 0] = reward
        self.__m_dones[cursor, 0] = done
        # Advance the write cursor with wrap-around; size grows until full.
        self.__pos = (cursor + 1) % self.__capacity
        self.__size = max(self.__size, cursor + 1)

    def sample(self, batch_size: int) -> Tuple[
            BatchState,
            BatchAction,
            BatchReward,
            BatchNext,
            BatchDone,
    ]:
        """Draw ``batch_size`` transitions uniformly (with replacement).

        Returns float tensors on the configured device; states/next-states
        are the first/last four frames of each stored 5-frame fold.
        """
        indices = torch.randint(0, high=self.__size, size=(batch_size,))
        folds = self.__m_states[indices]
        b_state = folds[:, :4].to(self.__device).float()
        b_next = folds[:, 1:].to(self.__device).float()
        b_action = self.__m_actions[indices].to(self.__device)
        b_reward = self.__m_rewards[indices].to(self.__device).float()
        b_done = self.__m_dones[indices].to(self.__device).float()
        return b_state, b_action, b_reward, b_next, b_done

    def __len__(self) -> int:
        """Number of transitions currently stored (at most capacity)."""
        return self.__size
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/zzj_rs/RL_homework_DQN.git
git@gitee.com:zzj_rs/RL_homework_DQN.git
zzj_rs
RL_homework_DQN
强化学习期中作业_对DQN进行改进
master

搜索帮助