# [Hosting-page notice, not part of the script] Code fetch complete; the page will refresh automatically.
import random
import gym
import numpy as np
import torch
import rl_utils
# Run on CUDA when torch reports it as available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed PyTorch's, NumPy's and Python's random number generators with the
# same value.
seedseed = 0
torch.manual_seed(seedseed)
np.random.seed(seedseed)
random.seed(seedseed)
# --- parameters ---
hidden_dim = 128
discount_factor = 0.9
num_episodes = 2000

# --- env ---
env_name = 'Pendulum-v1'
env = gym.make(env_name)
env.reset(seed=seedseed)  # reset once, passing the script-wide seed

# Dimensions are read from the first entry of each space's shape.
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]

# First component of the action-space bounds.
action_high = env.action_space.high[0]  # maximum action value
action_low = env.action_space.low[0]  # minimum action value

# One print call, one item per line (same output as four separate prints).
print(
    f'state_dim = {state_dim}',
    f'action_dim = {action_dim}',
    f'action_low = {action_low}',
    f'action_high = {action_high}',
    sep='\n',
)
# Choose the algorithm and build the matching agent.
# The agent class is imported inside its branch, so only the module of the
# selected algorithm is loaded.
# NOTE(review): the meaning of each hyperparameter is defined by the agent
# classes (TRPOContinuous / PPOContinuous), which are not visible here.
alg_name = 'PPO'
if alg_name == 'TRPO':
    lmbda = 0.9
    critic_lr = 1e-2
    kl_constraint = 0.00005
    alpha = 0.5
    from on_policy.alg_TRPO_Continuous import TRPOContinuous
    agent = TRPOContinuous(
        hidden_dim, env.observation_space, env.action_space,
        lmbda, kl_constraint, alpha, critic_lr, discount_factor, device,
    )
elif alg_name == 'PPO':
    actor_lr = 1e-4
    critic_lr = 5e-3
    para_GAE_lmbda = 0.9
    epochs = 10
    para_PPO_clip = 0.2
    from on_policy.alg_PPO_Continuous import PPOContinuous
    agent = PPOContinuous(
        state_dim, hidden_dim, action_dim, actor_lr, critic_lr,
        para_GAE_lmbda, epochs, para_PPO_clip, discount_factor, device,
    )
else:
    # Fail here with a clear message instead of a NameError on `agent`
    # further down the script.
    raise ValueError(
        f"Unsupported alg_name {alg_name!r}; expected 'TRPO' or 'PPO'"
    )
# Train the agent, then plot the list returned by the training helper.
print('Training!!!!')
return_list = rl_utils.train_on_policy_agent(env, agent, num_episodes)
rl_utils.plot_results(
    return_list,
    env_name,
    alg_name,
    string_train_test='Training',
    moving_average_weight=9,
)

# Run the agent through the test helper for 50 episodes and plot the result.
print('Testing!!!!')
return_list_test = rl_utils.test_agent(env, agent, num_episodes=50)
rl_utils.plot_results(
    return_list_test,
    env_name,
    alg_name,
    string_train_test='Testing',
    moving_average_weight=3,
)

# Finally, hand the environment and agent to the rendering helper.
print('Rendering!!!!')
rl_utils.test_agent_render(env, agent)
# time_start = time.perf_counter()  # record the start time
# time_end = time.perf_counter()  # record the end time
# time_sum = time_end - time_start  # the difference is the program's execution time, in seconds
# print('time = %f' %time_sum)
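# [Hosting-page notice, not part of the script] This section may contain content that is unsuitable for display, so the page does not show it. You can review and modify it yourself using the relevant editing functions.
# [Hosting-page notice, not part of the script] If you confirm that the content does not involve inappropriate language / pure advertising or traffic diversion / violence / vulgar or pornographic material / infringement / piracy / false information / worthless content, or content that violates national laws and regulations, you can click submit to appeal, and we will handle it as soon as possible.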