朱宗鑫 / rlcard (forked from Daochen Zha / rlcard)
run_rl.py (3.40 KB)
Last commit: Daochen Zha, 2021-06-21 15:18, "Fix bugs in plot"

''' An example of training a reinforcement learning agent on the environments in RLCard
'''
import os
import argparse

import torch

import rlcard
from rlcard.agents import RandomAgent
from rlcard.utils import get_device, set_seed, tournament, reorganize, Logger, plot_curve


def train(args):
    # Check whether gpu is available
    device = get_device()

    # Seed numpy, torch, random
    set_seed(args.seed)

    # Make the environment with seed
    env = rlcard.make(args.env, config={'seed': args.seed})

    # Initialize the agent and use random agents as opponents
    if args.algorithm == 'dqn':
        from rlcard.agents import DQNAgent
        agent = DQNAgent(num_actions=env.num_actions,
                         state_shape=env.state_shape[0],
                         mlp_layers=[64,64],
                         device=device)
    elif args.algorithm == 'nfsp':
        from rlcard.agents import NFSPAgent
        agent = NFSPAgent(num_actions=env.num_actions,
                          state_shape=env.state_shape[0],
                          hidden_layers_sizes=[64,64],
                          q_mlp_layers=[64,64],
                          device=device)

    # The learning agent takes the first seat; fill the remaining seats
    # with random agents. The range must start at 1 so the total number
    # of agents equals env.num_players.
    agents = [agent]
    for _ in range(1, env.num_players):
        agents.append(RandomAgent(num_actions=env.num_actions))
    env.set_agents(agents)
    # Start training
    with Logger(args.log_dir) as logger:
        for episode in range(args.num_episodes):

            if args.algorithm == 'nfsp':
                agents[0].sample_episode_policy()

            # Generate data from the environment
            trajectories, payoffs = env.run(is_training=True)

            # Reorganize the data into (state, action, reward, next_state, done) transitions
            trajectories = reorganize(trajectories, payoffs)

            # Feed transitions into agent memory, and train the agent.
            # Here, we assume that the learning agent always plays the first position
            # and the other players play randomly (if any)
            for ts in trajectories[0]:
                agent.feed(ts)

            # Evaluate the performance. Play with random agents.
            if episode % args.evaluate_every == 0:
                logger.log_performance(env.timestep, tournament(env, args.num_eval_games)[0])

        # Get the paths to the logged results
        csv_path, fig_path = logger.csv_path, logger.fig_path

    # Plot the learning curve (after the Logger has closed the csv file)
    plot_curve(csv_path, fig_path, args.algorithm)

    # Save model
    save_path = os.path.join(args.log_dir, 'model.pth')
    torch.save(agent, save_path)
    print('Model saved in', save_path)

if __name__ == '__main__':
    parser = argparse.ArgumentParser("DQN/NFSP example in RLCard")
    parser.add_argument('--env', type=str, default='leduc-holdem',
                        choices=['blackjack', 'leduc-holdem', 'limit-holdem', 'doudizhu',
                                 'mahjong', 'no-limit-holdem', 'uno', 'gin-rummy'])
    parser.add_argument('--algorithm', type=str, default='dqn', choices=['dqn', 'nfsp'])
    parser.add_argument('--cuda', type=str, default='')
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--num_episodes', type=int, default=5000)
    parser.add_argument('--num_eval_games', type=int, default=2000)
    parser.add_argument('--evaluate_every', type=int, default=100)
    parser.add_argument('--log_dir', type=str, default='experiments/leduc_holdem_dqn_result/')
    args = parser.parse_args()

    os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda
    train(args)
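
Usage note (not part of the file): with RLCard and PyTorch installed, the script is run as, for example, python run_rl.py --env leduc-holdem --algorithm dqn. Because training ends with torch.save(agent, save_path), the whole agent object is serialized, so a trained model can be reloaded and re-evaluated later. The snippet below is a minimal sketch of that, assuming the default --log_dir was used; the environment name, seed, and number of evaluation games are illustrative choices, not values fixed by the script.

import torch
import rlcard
from rlcard.agents import RandomAgent
from rlcard.utils import get_device, tournament

device = get_device()
env = rlcard.make('leduc-holdem', config={'seed': 0})

# torch.save stored the entire agent object, so torch.load restores it
# directly (the rlcard agent classes must be importable when loading)
agent = torch.load('experiments/leduc_holdem_dqn_result/model.pth',
                   map_location=device)

# Seat the trained agent in the first position and random agents
# elsewhere, mirroring the setup used during training
env.set_agents([agent] + [RandomAgent(num_actions=env.num_actions)
                          for _ in range(1, env.num_players)])

# tournament() returns the average payoff of each seat over the given
# number of games; index 0 is the trained agent
payoffs = tournament(env, 1000)
print('Average payoff of the trained agent:', payoffs[0])

Since tournament() is the same helper the training loop calls for its periodic evaluation, the printed payoff is directly comparable to the values logged on the learning curve.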