# TrainGPT **Repository Path**: asmots/train-gpt ## Basic Information - **Project Name**: TrainGPT - **Description**: No description available - **Primary Language**: Unknown - **License**: Not specified - **Default Branch**: master - **Homepage**: None - **GVP Project**: No ## Statistics - **Stars**: 0 - **Forks**: 0 - **Created**: 2025-12-06 - **Last Updated**: 2025-12-21 ## Categories & Tags **Categories**: Uncategorized **Tags**: None ## README # TrainGPT: 基于 MinGPT 的贪吃蛇 AI 训练项目 本项目旨在通过行为克隆(Behavior Cloning)技术,训练一个基于 Transformer (MinGPT) 的 AI 来玩贪吃蛇游戏。项目包含数据生成、模型训练、推理测试全流程。 ## 📺 效果演示 | 🤪 低分模型 (Low Score) | 🧠 高分模型 (High Score) | | :---: | :---: | | | | | *随机策略训练,经常撞墙* | *A* 策略训练,走位精准* | --- ## 🚀 快速开始 ### 1. 环境准备 本项目基于 Python 和 PyTorch 开发,建议使用 GPU 进行训练。 **核心依赖:** * Python 3.8+ * PyTorch (建议 2.0+, 需支持 CUDA) * Pygame (用于游戏环境) * Numpy **安装命令:** ```bash pip install torch torchvision numpy pygame ``` ### 2. 直接试用(已有模型) 本项目提供了两个预训练模型,你可以通过可视化脚本直观感受它们的区别。脚本已内置 **HUD (抬头显示)**,实时展示 AI 的神经网络思考过程。 **模型列表:** * 🧠 **高分模型 (推荐)**: `min_gpt/checkpoints/best_model.pt` (基于 A* 数据训练,走位风骚) * 🤪 **低分模型**: `min_gpt/model_data/low_score/best_model.pt` (基于随机数据训练,经常撞墙) **运行命令:** ```bash cd min_gpt # 1. 试用高分模型 (默认) python play_game_with_gpt.py --model checkpoints/best_model.pt --speed 40 # 2. 试用低分模型 (形成鲜明对比) python play_game_with_gpt.py --model model_data/low_score/best_model.pt --speed 100 ``` **HUD 界面说明:** * **State Perception**: 红灯亮起代表 AI 探测到了该方向的危险(墙或身体)。 * **Action Probabilities**: 柱状图显示 AI 对下一步动作的信心(直走/左转/右转)。 ### 3. 生成训练数据 我们需要先生成高质量的专家数据。本项目提供了基于 **A* 算法** 的数据生成器,能产生高质量的路径规划数据。 ```bash cd snake_ai # 生成 100 局高分数据(分数 >= 50) python generate_data.py --astar --num_episodes 100 ``` 生成的数据会保存在 `snake_ai/snake_expert_data.jsonl`。 **⚠️ 重要提示:** 请将生成的 `snake_expert_data.jsonl` 复制到 `min_gpt` 目录下,以便训练脚本读取: ```bash copy snake_expert_data.jsonl ..\min_gpt\ ``` ### 4. 训练 MinGPT 模型 使用生成的专家数据训练 Transformer 模型。 ```bash cd ../min_gpt # 开始训练 python train.py ``` 训练过程中会自动保存验证集表现最好的模型到 `checkpoints/best_model.pt`。 ### 5. 测试与可视化 加载训练好的模型,观看 AI 玩游戏。 ```bash # 在 min_gpt 目录下运行 python play_game_with_gpt.py ``` --- ## 🧠 技术原理详解 ### 1. 数据流与状态表示 (Data Representation) 为了让 Transformer 理解贪吃蛇的游戏局面,我们将每一帧的游戏状态抽象为一个 **11维的布尔向量**。 ![State Vector Diagram](min_gpt/snake_state_vector_diagram.png) **11维状态向量详解:** 1. **危险探测 (3维)**: * `Danger Straight`: 前方是否有障碍(墙或身体) * `Danger Right`: 右侧是否有障碍 * `Danger Left`: 左侧是否有障碍 2. **当前移动方向 (4维)**: * `Dir Left`, `Dir Right`, `Dir Up`, `Dir Down` (One-hot 编码) 3. **食物相对位置 (4维)**: * `Food Left`, `Food Right`, `Food Up`, `Food Down` **滑动窗口 (Sliding Window):** Transformer 不仅仅看当前这一帧,而是看**过去 20 帧**的序列。 * 输入张量形状: `(Batch_Size, 20, 11)` * 这让模型能够感知“速度”、“趋势”以及短期的历史决策路径。 **训练数据格式 (JSONL):** ```json { "episode_id": 0, "final_score": 52, "steps": [ { "state": [0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1], // 11维状态 "action": 1, // 0:直走, 1:右转, 2:左转 "reward": 0, "next_state": [...] }, ... ] } ``` ### 2. MinGPT 模型架构 (Model Architecture) 本项目采用的是 **Decoder-only Transformer** 架构,类似于 GPT-2 的微缩版。 ![MinGPT Architecture](min_gpt/mingpt_architecture_chinese.png) **核心组件解析:** 1. **线性投影 (Linear Projection)**: * 将 11维的离散状态映射到 128维的连续向量空间 (Embedding)。 * 作用:提取特征,将稀疏的布尔信息转化为稠密的语义信息。 2. **位置编码 (Positional Encoding)**: * Transformer 本质上是并行处理的,不具备时间概念。 * 我们需要加上可学习的位置向量,告诉模型“这是第 T-1 步,那是第 T 步”。 3. **多头因果自注意力 (Masked Multi-Head Self-Attention)**: * **Self-Attention**: 模型会计算序列中每一帧与其他帧的相关性。例如,当前的“危险”状态可能需要参考几步前的“转向”操作。 * **Causal Mask (因果掩码)**: 这是一个下三角矩阵掩码,确保模型在预测第 T 步动作时,**只能看到 T 之前的信息**,绝对看不到未来。这是“自回归”生成的关键。 ![Attention Mechanism](min_gpt/attention_mechanism_chinese.png) 4. **前馈网络 (Feed Forward Network)**: * 对每个时间步的特征进行非线性变换,增强模型的表达能力。 ### 3. 为什么选择 Transformer? 相比于传统的 DQN (Deep Q-Network): * **序列建模能力**: DQN 通常只看当前一帧(或堆叠几帧),而 Transformer 天生擅长处理长序列。它能更好地理解“局面是如何演变到这一步的”。 * **行为克隆 (Behavior Cloning)**: 我们将强化学习问题转化为了**监督学习**问题(Sequence Modeling)。即:给定过去的历史状态,预测专家会做的下一个动作。 * **泛化潜力**: 这种架构也就是目前大语言模型(LLM)的基础,证明了其强大的模式识别和泛化能力。 --- ## 💡 数据质量的重要性 **"Garbage In, Garbage Out"** 在训练过程中,我们对比了两组数据: | 数据来源 | 策略 | 平均分 | 训练后模型表现 | | :--- | :--- | :--- | :--- | | **随机/弱DQN** | 随机游走 + 基础避障 | < 10 | **差**。经常在空地转圈,容易把自己围死。 | | **A* 算法** | 全局最优路径规划 | **> 50** | **优**。走位风骚,能通过极其狭窄的通道,有明显的规划感。 | **结论**:对于模仿学习,**少量的高质量专家数据**(如 100 局 A* 数据)远胜于**大量低质量数据**(如 10000 局随机数据)。 --- ## 💻 训练环境与依赖 * **硬件**: * 推荐使用 NVIDIA GPU (支持 CUDA)。 * 实测 RTX 4060 训练 100 局数据仅需 1-2 分钟。 * CPU 也可以训练,但速度较慢。 * **软件**: * Python 3.8+ * PyTorch 1.10+ * Pygame (用于渲染游戏画面) ```bash pip install torch torchvision numpy pygame ``` --- ## 🛠️ 目录结构说明 ``` TrainGPT/ ├── min_gpt/ # 🧠 核心算法目录 │ ├── model.py # Transformer 模型定义 (Block, CausalSelfAttention) │ ├── train.py # 训练主循环 │ ├── dataset.py # 自定义 Dataset 和 DataLoader │ ├── config.py # 配置文件 (维度, 层数, 学习率等) │ ├── play_game_with_gpt.py # ✨ 推理与可视化脚本 │ └── ... │ ├── snake_ai/ # 🐍 游戏环境与数据生成 │ ├── snake_game.py # 贪吃蛇游戏引擎 (Pygame) │ ├── generate_data.py # 数据生成入口 (支持 --astar) │ ├── a_star_solver.py # A* 寻路算法实现 │ └── agent.py # 基础 DQN Agent (用于对比) │ └── README.md # 项目文档 ```