zhouzz / 强化学习 · alg_PPO_Discrete_big.py (https://gitee.com/zhouzizhen/learnDRL.git)
import torch
import torch.nn.functional as F
import rl_utils
# from PPO_Discrete import PPO

class ValueNet(torch.nn.Module):
    def __init__(self, state_dim, hidden_dim):
        super(ValueNet, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = torch.nn.Linear(hidden_dim, 1)

    def forward(self, x):
        x = F.relu(self.fc2(F.relu(self.fc1(x))))
        return self.fc3(x)

class PolicyNet(torch.nn.Module):
    def __init__(self, state_dim, hidden_dim, action_dim):
        super(PolicyNet, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = torch.nn.Linear(hidden_dim, action_dim)

    def forward(self, x):
        x = F.relu(self.fc2(F.relu(self.fc1(x))))
        return F.softmax(self.fc3(x), dim=1)

class PPO:
    ''' PPO algorithm with the clipped surrogate objective '''
    def __init__(self, state_dim, hidden_dim, action_dim, actor_lr, critic_lr,
                 epochs, para_GAE_lmbda, para_PPO_clip, discount_fac, device):
        self.actor = PolicyNet(state_dim, hidden_dim, action_dim).to(device)
        self.critic = ValueNet(state_dim, hidden_dim).to(device)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=actor_lr)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=critic_lr)
        self.epochs = epochs  # number of training passes over each collected trajectory
        self.discount_fac = discount_fac
        self.para_GAE_lmbda = para_GAE_lmbda
        self.para_PPO_clip = para_PPO_clip  # clipping range parameter (epsilon) in PPO
        self.device = device

    def take_action(self, state):
        state = torch.tensor([state], dtype=torch.float).to(self.device)
        probs = self.actor(state)
        action_dist = torch.distributions.Categorical(probs)
        action = action_dist.sample()
        return action.item()

    def update(self, transition_dict):
        states = torch.tensor(transition_dict['states'], dtype=torch.float).to(self.device)
        actions = torch.tensor(transition_dict['actions']).view(-1, 1).to(self.device)
        rewards = torch.tensor(transition_dict['rewards'], dtype=torch.float).view(-1, 1).to(self.device)
        next_states = torch.tensor(transition_dict['next_states'], dtype=torch.float).to(self.device)
        dones = torch.tensor(transition_dict['dones'], dtype=torch.float).view(-1, 1).to(self.device)
        td_target = rewards + self.discount_fac * self.critic(next_states) * (1 - dones)
        # temporal-difference (TD) error
        td_delta = td_target - self.critic(states)
        # advantage under the behaviour policy theta', estimated with GAE
        advantage = rl_utils.compute_advantage(self.discount_fac, self.para_GAE_lmbda,
                                               td_delta.cpu()).to(self.device)
        # log-probabilities of the taken actions in `states` under the behaviour policy theta'
        old_log_probs = torch.log(self.actor(states).gather(1, actions)).detach()
        for _ in range(self.epochs):
            # log-probabilities of the same actions under the current policy theta
            log_probs = torch.log(self.actor(states).gather(1, actions))
            # probability ratio between the current policy and the behaviour policy
            ratio = torch.exp(log_probs - old_log_probs)
            surr1 = ratio * advantage
            surr2 = torch.clamp(ratio, 1 - self.para_PPO_clip, 1 + self.para_PPO_clip) * advantage  # clipped term
            actor_loss = torch.mean(-torch.min(surr1, surr2))  # PPO clipped surrogate loss
            critic_loss = torch.mean(F.mse_loss(self.critic(states), td_target.detach()))
            self.actor_optimizer.zero_grad()
            self.critic_optimizer.zero_grad()
            actor_loss.backward()
            critic_loss.backward()
            self.actor_optimizer.step()
            self.critic_optimizer.step()
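
For reference, the actor loss computed in the loop above is the negated clipped surrogate objective from the PPO paper, where para_PPO_clip plays the role of epsilon, the ratio is taken between the current policy and the policy that collected the data, and the advantage is the GAE estimate:

    L^{CLIP}(\theta) = \mathbb{E}_t\!\left[\min\!\left(r_t(\theta)\,\hat{A}_t,\ \mathrm{clip}\!\left(r_t(\theta),\,1-\epsilon,\,1+\epsilon\right)\hat{A}_t\right)\right],
    \qquad
    r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta'}(a_t \mid s_t)}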
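
The only dependency besides PyTorch is rl_utils.compute_advantage, whose source is not shown on this page. A minimal sketch of what it presumably does, assuming the usual Generalized Advantage Estimation (GAE) helper that its signature (gamma, lambda, per-step TD errors in; advantage tensor out) suggests:

def compute_advantage(gamma, lmbda, td_delta):
    # Hypothetical re-implementation of rl_utils.compute_advantage (assumption, not the repo's code).
    # GAE: A_t = sum_{l>=0} (gamma * lmbda)^l * delta_{t+l}, computed by a backward scan over one episode.
    td_delta = td_delta.detach().numpy()
    advantage_list = []
    advantage = 0.0
    for delta in td_delta[::-1]:
        advantage = gamma * lmbda * advantage + delta
        advantage_list.append(advantage)
    advantage_list.reverse()
    return torch.tensor(advantage_list, dtype=torch.float)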
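
A minimal usage sketch of the PPO class above, assuming the classic Gym API (env.reset() returns the observation, env.step() returns a 4-tuple), which is what take_action's input handling suggests; the environment name, episode count, and hyperparameter values here are illustrative assumptions, not taken from the repository:

import gym

env = gym.make('CartPole-v1')  # assumed discrete-action test environment
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
agent = PPO(state_dim, 128, action_dim, actor_lr=3e-4, critic_lr=1e-3,
            epochs=10, para_GAE_lmbda=0.95, para_PPO_clip=0.2,
            discount_fac=0.98, device=device)

for episode in range(500):
    # collect one full episode, then run several PPO epochs on it
    transition_dict = {'states': [], 'actions': [], 'next_states': [],
                       'rewards': [], 'dones': []}
    state = env.reset()
    done = False
    while not done:
        action = agent.take_action(state)
        next_state, reward, done, _ = env.step(action)
        transition_dict['states'].append(state)
        transition_dict['actions'].append(action)
        transition_dict['next_states'].append(next_state)
        transition_dict['rewards'].append(reward)
        transition_dict['dones'].append(done)
        state = next_state
    agent.update(transition_dict)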