1 Star 0 Fork 83

朱宗鑫 / rlcard

forked from Daochen Zha / rlcard 
加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
run_cfr.py 1.84 KB
一键复制 编辑 原始数据 按行查看 历史
Daochen Zha 提交于 2021-06-21 15:18 . Fix bugs in plot
''' An example of solving Leduc Hold'em with CFR (chance sampling)
'''
import os
import argparse
import rlcard
from rlcard.agents import CFRAgent, RandomAgent
from rlcard.utils import set_seed, tournament, Logger, plot_curve
def train(args):
    """Train a CFR (chance sampling) agent on Leduc Hold'em.

    The agent is trained for ``args.num_episodes`` iterations. Every
    ``args.evaluate_every`` iterations, and once more after the final
    iteration, the model is saved to ``<args.log_dir>/cfr_model`` and
    evaluated in a tournament against a ``RandomAgent``. The result is logged,
    and after training the logged results are plotted as a learning curve.

    Args:
        args: Parsed command-line namespace providing ``seed``,
            ``num_episodes``, ``num_eval_games``, ``evaluate_every`` and
            ``log_dir``.
    """
    # Seed numpy, torch and random before building anything else.
    set_seed(args.seed)

    # Make environments; CFR only supports Leduc Hold'em. ``allow_step_back``
    # is enabled on the training env only, presumably so the CFR agent can
    # undo moves while traversing the game tree. Both envs are seeded with
    # ``args.seed`` (previously a hard-coded 0) so that ``--seed`` also
    # controls the environments' own random state.
    env = rlcard.make('leduc-holdem', config={'seed': args.seed, 'allow_step_back': True})
    eval_env = rlcard.make('leduc-holdem', config={'seed': args.seed})

    # Initialize the CFR agent; its model lives under <log_dir>/cfr_model.
    agent = CFRAgent(env, os.path.join(args.log_dir, 'cfr_model'))
    agent.load()  # If we have a saved model, resume from it first.

    # Evaluate CFR (seat 0) against a random agent (seat 1).
    eval_env.set_agents([agent, RandomAgent(num_actions=env.num_actions)])

    last_episode = args.num_episodes - 1
    with Logger(args.log_dir) as logger:
        for episode in range(args.num_episodes):
            agent.train()
            print('\rIteration {}'.format(episode), end='')
            # Evaluate on the regular schedule, and also after the final
            # iteration; otherwise the iterations after the last multiple of
            # evaluate_every would never be saved or logged.
            if episode % args.evaluate_every == 0 or episode == last_episode:
                agent.save()  # Save model
                # Index 0 of the tournament result belongs to the CFR agent,
                # since it occupies seat 0 in eval_env.
                logger.log_performance(env.timestep, tournament(eval_env, args.num_eval_games)[0])

        # Get the output paths while the logger is still open.
        csv_path, fig_path = logger.csv_path, logger.fig_path

    # Plot the learning curve after the logger context has exited.
    plot_curve(csv_path, fig_path, 'cfr')
def parse_args(argv=None):
    """Parse the command-line options for the CFR example.

    Args:
        argv: Argument list to parse, excluding the program name. ``None``
            makes argparse fall back to ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``seed``, ``num_episodes``,
        ``num_eval_games``, ``evaluate_every`` and ``log_dir``.

    Raises:
        SystemExit: if the arguments are invalid, including a non-positive
            ``--evaluate_every`` (reported via ``parser.error``).
    """
    # The first positional parameter of ArgumentParser is ``prog``; the text
    # must be passed as ``description`` to appear as the help description
    # rather than replacing the program name in the usage line.
    parser = argparse.ArgumentParser(description="CFR example in RLCard")
    parser.add_argument('--seed', type=int, default=42,
                        help='Random seed.')
    parser.add_argument('--num_episodes', type=int, default=5000,
                        help='Number of CFR training iterations.')
    parser.add_argument('--num_eval_games', type=int, default=2000,
                        help='Number of games played in each evaluation.')
    parser.add_argument('--evaluate_every', type=int, default=100,
                        help='Save and evaluate the model every N iterations.')
    parser.add_argument('--log_dir', type=str, default='experiments/leduc_holdem_cfr_result/',
                        help='Directory for logs, the learning curve and the saved model.')
    args = parser.parse_args(argv)
    # train() computes ``episode % evaluate_every``: zero would crash with
    # ZeroDivisionError inside the training loop, and negative values are
    # meaningless, so reject them up front.
    if args.evaluate_every <= 0:
        parser.error('--evaluate_every must be a positive integer')
    return args


if __name__ == '__main__':
    train(parse_args())
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/dreamszhu/rlcard.git
git@gitee.com:dreamszhu/rlcard.git
dreamszhu
rlcard
rlcard
master

搜索帮助

344bd9b3 5694891 D2dac590 5694891