1 Star 0 Fork 0

zhouzz / 强化学习

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
main_on_policy.py 2.09 KB
一键复制 编辑 原始数据 按行查看 历史
zhouzz 提交于 2024-05-03 11:53 . 2024年5月3日11:53:20
import random
import gym
import numpy as np
import torch
import rl_utils
# Use the CUDA device when torch reports one is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# A single seed value is fed to Python's `random`, NumPy and PyTorch.
seedseed = 0
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _seed_rng(seedseed)
# --- Hyperparameters ---
hidden_dim = 128
discount_factor = 0.9
num_episodes = 2000

# --- Environment ---
env_name = 'Pendulum-v1'
env = gym.make(env_name)
env.reset(seed=seedseed)  # seed the environment with the script-wide seed

# Dimensions are taken from the first entry of each space's shape.
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]

# First component of the action-space bounds.
action_low = env.action_space.low[0]  # minimum action value
action_high = env.action_space.high[0]  # maximum action value

# Report the environment's dimensions and action bounds.
print(f'state_dim = {state_dim}')
print(f'action_dim = {action_dim}')
print(f'action_low = {action_low}')
print(f'action_high = {action_high}')
# Choose which on-policy algorithm to build. Only 'TRPO' and 'PPO' are
# handled; any other value is rejected explicitly so that `agent` can never
# be left undefined for the code that follows.
alg_name = 'PPO'
if alg_name == 'TRPO':
    lmbda = 0.9
    critic_lr = 1e-2
    kl_constraint = 0.00005
    alpha = 0.5
    # Imported inside the branch so only the selected algorithm's module
    # is loaded.
    from on_policy.alg_TRPO_Continuous import TRPOContinuous
    agent = TRPOContinuous(
        hidden_dim, env.observation_space, env.action_space,
        lmbda, kl_constraint, alpha, critic_lr, discount_factor, device,
    )
elif alg_name == 'PPO':
    actor_lr = 1e-4
    critic_lr = 5e-3
    para_GAE_lmbda = 0.9
    epochs = 10
    para_PPO_clip = 0.2
    # Imported inside the branch so only the selected algorithm's module
    # is loaded.
    from on_policy.alg_PPO_Continuous import PPOContinuous
    agent = PPOContinuous(
        state_dim, hidden_dim, action_dim, actor_lr, critic_lr,
        para_GAE_lmbda, epochs, para_PPO_clip, discount_factor, device,
    )
else:
    raise ValueError(
        f"Unsupported alg_name {alg_name!r}; expected 'TRPO' or 'PPO'"
    )
# Phase 1: train the agent and plot the per-episode returns.
print('Training!!!!')
return_list = rl_utils.train_on_policy_agent(env, agent, num_episodes)
rl_utils.plot_results(
    return_list,
    env_name,
    alg_name,
    string_train_test='Training',
    moving_average_weight=9,
)

# Phase 2: evaluate the agent for 50 episodes and plot those returns.
print('Testing!!!!')
return_list_test = rl_utils.test_agent(env, agent, num_episodes=50)
rl_utils.plot_results(
    return_list_test,
    env_name,
    alg_name,
    string_train_test='Testing',
    moving_average_weight=3,
)

# Phase 3: run the agent through the rendering helper.
print('Rendering!!!!')
rl_utils.test_agent_render(env, agent)

# Disabled timing snippet (would need `import time` to be re-enabled):
# time_start = time.perf_counter()  # record the start time
# time_end = time.perf_counter()  # record the end time
# time_sum = time_end - time_start  # elapsed run time, in seconds
# print('time = %f' %time_sum)
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/zhouzizhen/learnDRL.git
git@gitee.com:zhouzizhen/learnDRL.git
zhouzizhen
learnDRL
强化学习
master

搜索帮助

344bd9b3 5694891 D2dac590 5694891