# alg_SAC_Discrete.py
# https://hrl.boyuai.com/chapter/2/sac%E7%AE%97%E6%B3%95/
# Maximum-entropy RL: adjust the balance between exploration and exploitation by controlling the entropy of the actions the policy takes
# https://github.com/thu-ml/tianshou/blob/master/tianshou/policy/modelfree/discrete_sac.py
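#
# In sketch form, maximum-entropy RL optimizes
#     J(pi) = E[ sum_t gamma^t * ( r_t + alpha * H(pi(.|s_t)) ) ],   H(pi(.|s)) = -sum_a pi(a|s) * log pi(a|s),
# where the temperature alpha trades off exploration (high entropy) against exploitation.
# Because the action space here is discrete, expectations over actions can be computed exactly
# from the softmax probabilities, so the soft TD target used in calc_target below is
#     y = r + gamma * (1 - done) * ( sum_a pi(a|s') * min(Q1_target, Q2_target)(s', a) + alpha * H(pi(.|s')) )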
import random
import gym
import numpy as np
import torch
import torch.nn.functional as F
import rl_utils


class PolicyNet(torch.nn.Module):
    def __init__(self, state_dim, hidden_dim, action_dim):
        super(PolicyNet, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, action_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return F.softmax(self.fc2(x), dim=1)


class QValueNet(torch.nn.Module):
    ''' Q-network with a single hidden layer '''
    def __init__(self, state_dim, hidden_dim, action_dim):
        super(QValueNet, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, action_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return self.fc2(x)


class SAC:
    ''' SAC algorithm for discrete action spaces '''
    def __init__(self, state_dim, hidden_dim, action_dim, actor_lr, critic_lr,
                 alpha_lr, target_entropy, para_soft_update, discount_factor, device):
        self.actor = PolicyNet(state_dim, hidden_dim, action_dim).to(device)  # policy network
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=actor_lr)
        self.critic_1 = QValueNet(state_dim, hidden_dim, action_dim).to(device)  # first Q-network
        self.target_critic_1 = QValueNet(state_dim, hidden_dim, action_dim).to(device)  # first target Q-network
        self.target_critic_1.load_state_dict(self.critic_1.state_dict())
        self.critic_1_optimizer = torch.optim.Adam(self.critic_1.parameters(), lr=critic_lr)
        self.critic_2 = QValueNet(state_dim, hidden_dim, action_dim).to(device)  # second Q-network
        self.target_critic_2 = QValueNet(state_dim, hidden_dim, action_dim).to(device)  # second target Q-network
        self.target_critic_2.load_state_dict(self.critic_2.state_dict())  # initialize the target Q-networks with the same parameters as the Q-networks
        self.critic_2_optimizer = torch.optim.Adam(self.critic_2.parameters(), lr=critic_lr)
        self.discount_factor = discount_factor
        self.para_soft_update = para_soft_update
        self.device = device
        # Optimize log(alpha) rather than alpha itself; this keeps training more stable
        self.log_alpha = torch.tensor(np.log(0.01), dtype=torch.float)
        self.log_alpha.requires_grad = True  # allow gradients to flow to alpha
        self.log_alpha_optimizer = torch.optim.Adam([self.log_alpha], lr=alpha_lr)
        self.target_entropy = target_entropy  # target entropy

    def take_action(self, state):
        state = torch.tensor([state], dtype=torch.float).to(self.device)
        probs = self.actor(state)
        action_dist = torch.distributions.Categorical(probs)
        action = action_dist.sample()
        return action.item()

    # Compute the target Q-value; the expectation over actions is taken directly
    # with the policy's output probabilities
    def calc_target(self, rewards, next_states, dones):
        next_probs = self.actor(next_states)
        next_log_probs = torch.log(next_probs + 1e-8)
        entropy = -torch.sum(next_probs * next_log_probs, dim=1, keepdim=True)
        # take the element-wise minimum of the two target Q-networks
        q1_value = self.target_critic_1(next_states)
        q2_value = self.target_critic_2(next_states)
        min_qvalue = torch.sum(next_probs * torch.min(q1_value, q2_value), dim=1, keepdim=True)
        next_value = min_qvalue + self.log_alpha.exp() * entropy
        td_target = rewards + self.discount_factor * next_value * (1 - dones)
        return td_target

    def soft_update(self, net, target_net):
        for param_target, param in zip(target_net.parameters(), net.parameters()):
            param_target.data.copy_(param_target.data * (1.0 - self.para_soft_update)
                                    + param.data * self.para_soft_update)

    def update(self, transition_dict):
        states = torch.tensor(transition_dict['states'], dtype=torch.float).to(self.device)
        actions = torch.tensor(transition_dict['actions']).view(-1, 1).to(self.device)  # actions are discrete indices, not floats
        rewards = torch.tensor(transition_dict['rewards'], dtype=torch.float).view(-1, 1).to(self.device)
        next_states = torch.tensor(transition_dict['next_states'], dtype=torch.float).to(self.device)
        dones = torch.tensor(transition_dict['dones'], dtype=torch.float).view(-1, 1).to(self.device)
        actions = actions.type(torch.long)

        '''Update the two Q-networks'''
        td_target = self.calc_target(rewards, next_states, dones)
        critic_1_q_values = self.critic_1(states).gather(1, actions)
        critic_1_loss = torch.mean(F.mse_loss(critic_1_q_values, td_target.detach()))
        self.critic_1_optimizer.zero_grad()
        critic_1_loss.backward()
        self.critic_1_optimizer.step()
        critic_2_q_values = self.critic_2(states).gather(1, actions)
        critic_2_loss = torch.mean(F.mse_loss(critic_2_q_values, td_target.detach()))
        self.critic_2_optimizer.zero_grad()
        critic_2_loss.backward()
        self.critic_2_optimizer.step()

        '''Update the policy network'''
        probs = self.actor(states)
        log_probs = torch.log(probs + 1e-8)
        # compute the entropy directly from the action probabilities
        entropy = -torch.sum(probs * log_probs, dim=1, keepdim=True)
        q1_value = self.critic_1(states)
        q2_value = self.critic_2(states)
        min_qvalue = torch.sum(probs * torch.min(q1_value, q2_value), dim=1, keepdim=True)  # expectation taken directly with the probabilities
        actor_loss = torch.mean(-self.log_alpha.exp() * entropy - min_qvalue)
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        '''Update the entropy coefficient alpha'''
        alpha_loss = torch.mean((entropy - self.target_entropy).detach() * self.log_alpha.exp())
        self.log_alpha_optimizer.zero_grad()
        alpha_loss.backward()
        self.log_alpha_optimizer.step()

        self.soft_update(self.critic_1, self.target_critic_1)
        self.soft_update(self.critic_2, self.target_critic_2)
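

# A minimal smoke test of the update interface (a sketch, not used by the training run below):
# it only illustrates the batch layout that SAC.update expects, i.e. a dict with keys
# 'states', 'actions', 'rewards', 'next_states', 'dones'; rl_utils.ReplayBuffer presumably
# hands back batches in this same layout. Flip the flag to True to run it standalone.
RUN_UPDATE_SMOKE_TEST = False
if RUN_UPDATE_SMOKE_TEST:
    _agent = SAC(state_dim=4, hidden_dim=32, action_dim=2, actor_lr=1e-3, critic_lr=1e-2,
                 alpha_lr=1e-2, target_entropy=-1, para_soft_update=0.005,
                 discount_factor=0.98, device=torch.device('cpu'))
    _batch = {
        'states': np.random.randn(8, 4).astype(np.float32),
        'actions': np.random.randint(0, 2, size=8),
        'rewards': np.ones(8, dtype=np.float32),
        'next_states': np.random.randn(8, 4).astype(np.float32),
        'dones': np.zeros(8, dtype=np.float32),
    }
    _agent.update(_batch)  # runs one gradient step on both critics, the actor and alpha
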
actor_lr = 1e-3
critic_lr = 1e-2
alpha_lr = 1e-2
num_episodes = 200
hidden_dim = 128
discount_factor = 0.98
para_soft_update = 0.005  # soft-update coefficient
buffer_size = 10000
minimal_size = 500
batch_size = 64
target_entropy = -1
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
env_name = 'CartPole-v1'
env = gym.make(env_name)
seed2024 = 0
random.seed(seed2024)
np.random.seed(seed2024)
env.reset(seed=seed2024)
torch.manual_seed(seed2024)
replay_buffer = rl_utils.ReplayBuffer(buffer_size)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
agent = SAC(state_dim, hidden_dim, action_dim, actor_lr, critic_lr, alpha_lr,
            target_entropy, para_soft_update, discount_factor, device)
alg_name = 'SAC'
print('Training!!!!')
return_list = rl_utils.train_off_policy_agent(env, agent, num_episodes, replay_buffer, minimal_size, batch_size)
rl_utils.plot_results(return_list, env_name, alg_name, string_train_test='Training', moving_average_weight=9)
print('Testing!!!!')
return_list_test = rl_utils.test_agent(env, agent, num_episodes=50)
rl_utils.plot_results(return_list_test, env_name, alg_name, string_train_test='Testing', moving_average_weight=3)
# print('Rendering!!!!')
# rl_utils.test_agent_render(env, agent)
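
# A minimal manual evaluation rollout, as a sketch. It assumes a Gym version in which
# env.reset() returns just the observation and env.step() returns (obs, reward, done, info);
# with Gymnasium / gym>=0.26, reset() returns (obs, info) and step() returns five values,
# so the unpacking below would need adjusting. rl_utils.test_agent presumably runs a
# similar loop over full episodes.
def rollout_once(env, agent, max_steps=500):
    state = env.reset()
    episode_return = 0.0
    for _ in range(max_steps):
        action = agent.take_action(state)  # sample an action from the current policy
        state, reward, done, _ = env.step(action)
        episode_return += reward
        if done:
            break
    return episode_return

# print('Single-episode return:', rollout_once(env, agent))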