| Algorithms | Action Space | Policy | Update | Envs |
| --- | --- | --- | --- | --- |
| DQN (double, dueling, PER) | Discrete only | -- | Off-policy | Atari, Classic Control |
| AC | Discrete/Continuous | Stochastic | On-policy | All |
| PG | Discrete/Continuous | Stochastic | On-policy | All |
| DDPG | Continuous | Deterministic | Off-policy | Classic Control, Box2D, MuJoCo, Robotics, DeepMind Control, RLBench |
| TD3 | Continuous | Deterministic | Off-policy | Classic Control, Box2D, MuJoCo, Robotics, DeepMind Control, RLBench |
| SAC | Continuous | Stochastic | Off-policy | Classic Control, Box2D, MuJoCo, Robotics, DeepMind Control, RLBench |
| A3C | Discrete/Continuous | Stochastic | On-policy | Atari, Classic Control, Box2D, MuJoCo, Robotics, DeepMind Control |
| PPO | Discrete/Continuous | Stochastic | On-policy | All |
| DPPO | Discrete/Continuous | Stochastic | On-policy | Atari, Classic Control, Box2D, MuJoCo, Robotics, DeepMind Control |
| TRPO | Discrete/Continuous | Stochastic | On-policy | All |
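Each snippet below follows the same four-step pattern: build the environment with build_env, fetch default hyperparameters with call_default_params, instantiate the algorithm class named by AlgName, then call learn once in train mode and once in test mode. The snippets assume the helpers and algorithm classes are already imported; with an RLzoo-style package layout that would look roughly like the sketch below (the module paths are an assumption, adjust them to your installation):

# Assumed imports (RLzoo-style layout; adjust to your installation)
from rlzoo.common.env_wrappers import build_env       # environment factory
from rlzoo.common.utils import call_default_params    # default hyperparameter lookup
from rlzoo.algorithms import *                         # DQN, AC, PG, DDPG, TD3, SAC, A3C, PPO, DPPO, TRPO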
AlgName = 'DQN'
EnvName = 'PongNoFrameskip-v4'
EnvType = 'atari'
# EnvName = 'CartPole-v1'
# EnvType = 'classic_control'  # the environment name must match its environment type
env = build_env(EnvName, EnvType)
alg_params, learn_params = call_default_params(env, EnvType, AlgName)
alg = eval(AlgName+'(**alg_params)')
alg.learn(env=env, mode='train', render=False, **learn_params)
alg.learn(env=env, mode='test', render=True, **learn_params)
AlgName = 'AC'
EnvName = 'PongNoFrameskip-v4'
EnvType = 'atari'
# EnvName = 'Pendulum-v0'
# EnvType = 'classic_control'
# EnvName = 'BipedalWalker-v2'
# EnvType = 'box2d'
# EnvName = 'Ant-v2'
# EnvType = 'mujoco'
# EnvName = 'FetchPush-v1'
# EnvType = 'robotics'
# EnvName = 'FishSwim-v0'
# EnvType = 'dm_control'
# EnvName = 'ReachTarget'
# EnvType = 'rlbench'
env = build_env(EnvName, EnvType)
alg_params, learn_params = call_default_params(env, EnvType, AlgName)
alg = eval(AlgName+'(**alg_params)')
alg.learn(env=env, mode='train', render=False, **learn_params)
alg.learn(env=env, mode='test', render=True, **learn_params)
AlgName = 'PG'
EnvName = 'PongNoFrameskip-v4'
EnvType = 'atari'
# EnvName = 'CartPole-v0'
# EnvType = 'classic_control'
# EnvName = 'BipedalWalker-v2'
# EnvType = 'box2d'
# EnvName = 'Ant-v2'
# EnvType = 'mujoco'
# EnvName = 'FetchPush-v1'
# EnvType = 'robotics'
# EnvName = 'FishSwim-v0'
# EnvType = 'dm_control'
# EnvName = 'ReachTarget'
# EnvType = 'rlbench'
env = build_env(EnvName, EnvType)
alg_params, learn_params = call_default_params(env, EnvType, AlgName)
alg = eval(AlgName+'(**alg_params)')
alg.learn(env=env, mode='train', render=False, **learn_params)
alg.learn(env=env, mode='test', render=True, **learn_params)
AlgName = 'DDPG'
EnvName = 'Pendulum-v0'  # DDPG supports continuous action spaces only
EnvType = 'classic_control'
# EnvName = 'BipedalWalker-v2'
# EnvType = 'box2d'
# EnvName = 'Ant-v2'
# EnvType = 'mujoco'
# EnvName = 'FetchPush-v1'
# EnvType = 'robotics'
# EnvName = 'FishSwim-v0'
# EnvType = 'dm_control'
# EnvName = 'ReachTarget'
# EnvType = 'rlbench'
env = build_env(EnvName, EnvType)
alg_params, learn_params = call_default_params(env, EnvType, AlgName)
alg = eval(AlgName+'(**alg_params)')
alg.learn(env=env, mode='train', render=False, **learn_params)
alg.learn(env=env, mode='test', render=True, **learn_params)
AlgName = 'TD3'
EnvName = 'Pendulum-v0'  # TD3 supports continuous action spaces only
EnvType = 'classic_control'
# EnvName = 'BipedalWalker-v2'
# EnvType = 'box2d'
# EnvName = 'Ant-v2'
# EnvType = 'mujoco'
# EnvName = 'FetchPush-v1'
# EnvType = 'robotics'
# EnvName = 'FishSwim-v0'
# EnvType = 'dm_control'
# EnvName = 'ReachTarget'
# EnvType = 'rlbench'
env = build_env(EnvName, EnvType)
alg_params, learn_params = call_default_params(env, EnvType, AlgName)
alg = eval(AlgName+'(**alg_params)')
alg.learn(env=env, mode='train', render=False, **learn_params)
alg.learn(env=env, mode='test', render=True, **learn_params)
AlgName = 'SAC'
EnvName = 'Pendulum-v0'  # SAC supports continuous action spaces only
EnvType = 'classic_control'
# EnvName = 'BipedalWalker-v2'
# EnvType = 'box2d'
# EnvName = 'Ant-v2'
# EnvType = 'mujoco'
# EnvName = 'FetchPush-v1'
# EnvType = 'robotics'
# EnvName = 'FishSwim-v0'
# EnvType = 'dm_control'
# EnvName = 'ReachTarget'
# EnvType = 'rlbench'
env = build_env(EnvName, EnvType)
alg_params, learn_params = call_default_params(env, EnvType, AlgName)
alg = eval(AlgName+'(**alg_params)')
alg.learn(env=env, mode='train', render=False, **learn_params)
alg.learn(env=env, mode='test', render=True, **learn_params)
AlgName = 'A3C'
EnvName = 'PongNoFrameskip-v4'
EnvType = 'atari'
# EnvName = 'Pendulum-v0'  # continuous action space
# EnvType = 'classic_control'
# EnvName = 'BipedalWalker-v2'
# EnvType = 'box2d'
# EnvName = 'Ant-v2'
# EnvType = 'mujoco'
# EnvName = 'FetchPush-v1'
# EnvType = 'robotics'
# EnvName = 'FishSwim-v0'
# EnvType = 'dm_control'
number_workers = 2  # the number of parallel workers must be specified for A3C
env = build_env(EnvName, EnvType, nenv=number_workers)
alg_params, learn_params = call_default_params(env, EnvType, AlgName)
alg = eval(AlgName+'(**alg_params)')
alg.learn(env=env, mode='train', render=False, **learn_params)
alg.learn(env=env, mode='test', render=True, **learn_params)
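A3C (and DPPO further below) trains several environment copies in parallel, so the wall-clock speed-up depends on how many workers the machine can actually run. The snippet above fixes number_workers at 2; a minimal, framework-independent sketch for sizing it to the available cores instead (the heuristic is an assumption, not part of the API above):

import multiprocessing

# Roughly one worker per two logical cores, but never fewer than two,
# since the benefit of A3C comes from decorrelated parallel actors.
number_workers = max(2, multiprocessing.cpu_count() // 2)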
EnvName = 'PongNoFrameskip-v4'
EnvType = 'atari'
# EnvName = 'Pendulum-v0'
# EnvType = 'classic_control'
# EnvName = 'BipedalWalker-v2'
# EnvType = 'box2d'
# EnvName = 'Ant-v2'
# EnvType = 'mujoco'
# EnvName = 'FetchPush-v1'
# EnvType = 'robotics'
# EnvName = 'FishSwim-v0'
# EnvType = 'dm_control'
# EnvName = 'ReachTarget'
# EnvType = 'rlbench'
env = build_env(EnvName, EnvType)
alg_params, learn_params = call_default_params(env, EnvType, 'PPO')
alg = PPO(method='clip', **alg_params) # specify 'clip' or 'penalty' method for PPO
alg.learn(env=env, mode='train', render=False, **learn_params)
alg.learn(env=env, mode='test', render=True, **learn_params)
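For reference, the 'clip' method above optimises PPO's clipped surrogate objective. The following is a purely illustrative NumPy sketch of that loss, not part of the training API shown here:

import numpy as np

def ppo_clip_objective(ratio, advantage, epsilon=0.2):
    """Clipped surrogate: mean of min(r*A, clip(r, 1-eps, 1+eps)*A).
    ratio     -- per-sample probability ratio pi_new(a|s) / pi_old(a|s)
    advantage -- per-sample advantage estimate
    """
    unclipped = ratio * advantage
    clipped = np.clip(ratio, 1.0 - epsilon, 1.0 + epsilon) * advantage
    # PPO maximises the elementwise minimum, averaged over the batch
    return np.mean(np.minimum(unclipped, clipped))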
EnvName = 'PongNoFrameskip-v4'
EnvType = 'atari'
# EnvName = 'Pendulum-v0'
# EnvType = 'classic_control'
# EnvName = 'BipedalWalker-v2'
# EnvType = 'box2d'
# EnvName = 'Ant-v2'
# EnvType = 'mujoco'
# EnvName = 'FetchPush-v1'
# EnvType = 'robotics'
# EnvName = 'FishSwim-v0'
# EnvType = 'dm_control'
# EnvName = 'ReachTarget'
# EnvType = 'rlbench'
number_workers = 2  # the number of parallel workers must be specified for DPPO
env = build_env(EnvName, EnvType, nenv=number_workers)
alg_params, learn_params = call_default_params(env, EnvType, 'DPPO')
alg = DPPO(method='penalty', **alg_params)  # specify 'clip' or 'penalty' method for DPPO
alg.learn(env=env, mode='train', render=False, **learn_params)
alg.learn(env=env, mode='test', render=True, **learn_params)
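The 'penalty' method used here replaces clipping with a KL penalty: the surrogate gain is reduced by beta times the KL divergence between the old and new policies, with beta typically adapted between updates. Again a purely illustrative NumPy sketch, not part of the API above:

import numpy as np

def ppo_penalty_objective(ratio, advantage, kl, beta=1.0):
    """KL-penalty surrogate: mean(r*A) - beta * mean(KL(pi_old || pi_new)).
    kl -- per-sample KL divergence between the old and new policies
    """
    return np.mean(ratio * advantage) - beta * np.mean(kl)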
AlgName = 'TRPO'
EnvName = 'PongNoFrameskip-v4'
EnvType = 'atari'
# EnvName = 'CartPole-v0'
# EnvType = 'classic_control'
# EnvName = 'BipedalWalker-v2'
# EnvType = 'box2d'
# EnvName = 'Ant-v2'
# EnvType = 'mujoco'
# EnvName = 'FetchPush-v1'
# EnvType = 'robotics'
# EnvName = 'FishSwim-v0'
# EnvType = 'dm_control'
# EnvName = 'ReachTarget'
# EnvType = 'rlbench'
env = build_env(EnvName, EnvType)
alg_params, learn_params = call_default_params(env, EnvType, AlgName)
alg = eval(AlgName+'(**alg_params)')
alg.learn(env=env, mode='train', render=False, **learn_params)
alg.learn(env=env, mode='test', render=True, **learn_params)