# TrainGPT
**Repository Path**: asmots/train-gpt
## Basic Information
- **Project Name**: TrainGPT
- **Description**: No description available
- **Primary Language**: Unknown
- **License**: Not specified
- **Default Branch**: master
- **Homepage**: None
- **GVP Project**: No
## Statistics
- **Stars**: 0
- **Forks**: 0
- **Created**: 2025-12-06
- **Last Updated**: 2025-12-21
## Categories & Tags
**Categories**: Uncategorized
**Tags**: None
## README
# TrainGPT: 基于 MinGPT 的贪吃蛇 AI 训练项目
本项目旨在通过行为克隆(Behavior Cloning)技术,训练一个基于 Transformer (MinGPT) 的 AI 来玩贪吃蛇游戏。项目包含数据生成、模型训练、推理测试全流程。
## 📺 效果演示
| 🤪 低分模型 (Low Score) | 🧠 高分模型 (High Score) |
| :---: | :---: |
| | |
| *随机策略训练,经常撞墙* | *A* 策略训练,走位精准* |
---
## 🚀 快速开始
### 1. 环境准备
本项目基于 Python 和 PyTorch 开发,建议使用 GPU 进行训练。
**核心依赖:**
* Python 3.8+
* PyTorch (建议 2.0+, 需支持 CUDA)
* Pygame (用于游戏环境)
* Numpy
**安装命令:**
```bash
pip install torch torchvision numpy pygame
```
### 2. 直接试用(已有模型)
本项目提供了两个预训练模型,你可以通过可视化脚本直观感受它们的区别。脚本已内置 **HUD (抬头显示)**,实时展示 AI 的神经网络思考过程。
**模型列表:**
* 🧠 **高分模型 (推荐)**: `min_gpt/checkpoints/best_model.pt` (基于 A* 数据训练,走位风骚)
* 🤪 **低分模型**: `min_gpt/model_data/low_score/best_model.pt` (基于随机数据训练,经常撞墙)
**运行命令:**
```bash
cd min_gpt
# 1. 试用高分模型 (默认)
python play_game_with_gpt.py --model checkpoints/best_model.pt --speed 40
# 2. 试用低分模型 (形成鲜明对比)
python play_game_with_gpt.py --model model_data/low_score/best_model.pt --speed 100
```
**HUD 界面说明:**
* **State Perception**: 红灯亮起代表 AI 探测到了该方向的危险(墙或身体)。
* **Action Probabilities**: 柱状图显示 AI 对下一步动作的信心(直走/左转/右转)。
### 3. 生成训练数据
我们需要先生成高质量的专家数据。本项目提供了基于 **A* 算法** 的数据生成器,能产生高质量的路径规划数据。
```bash
cd snake_ai
# 生成 100 局高分数据(分数 >= 50)
python generate_data.py --astar --num_episodes 100
```
生成的数据会保存在 `snake_ai/snake_expert_data.jsonl`。
**⚠️ 重要提示:**
请将生成的 `snake_expert_data.jsonl` 复制到 `min_gpt` 目录下,以便训练脚本读取:
```bash
copy snake_expert_data.jsonl ..\min_gpt\
```
### 4. 训练 MinGPT 模型
使用生成的专家数据训练 Transformer 模型。
```bash
cd ../min_gpt
# 开始训练
python train.py
```
训练过程中会自动保存验证集表现最好的模型到 `checkpoints/best_model.pt`。
### 5. 测试与可视化
加载训练好的模型,观看 AI 玩游戏。
```bash
# 在 min_gpt 目录下运行
python play_game_with_gpt.py
```
---
## 🧠 技术原理详解
### 1. 数据流与状态表示 (Data Representation)
为了让 Transformer 理解贪吃蛇的游戏局面,我们将每一帧的游戏状态抽象为一个 **11维的布尔向量**。

**11维状态向量详解:**
1. **危险探测 (3维)**:
* `Danger Straight`: 前方是否有障碍(墙或身体)
* `Danger Right`: 右侧是否有障碍
* `Danger Left`: 左侧是否有障碍
2. **当前移动方向 (4维)**:
* `Dir Left`, `Dir Right`, `Dir Up`, `Dir Down` (One-hot 编码)
3. **食物相对位置 (4维)**:
* `Food Left`, `Food Right`, `Food Up`, `Food Down`
**滑动窗口 (Sliding Window):**
Transformer 不仅仅看当前这一帧,而是看**过去 20 帧**的序列。
* 输入张量形状: `(Batch_Size, 20, 11)`
* 这让模型能够感知“速度”、“趋势”以及短期的历史决策路径。
**训练数据格式 (JSONL):**
```json
{
"episode_id": 0,
"final_score": 52,
"steps": [
{
"state": [0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1], // 11维状态
"action": 1, // 0:直走, 1:右转, 2:左转
"reward": 0,
"next_state": [...]
},
...
]
}
```
### 2. MinGPT 模型架构 (Model Architecture)
本项目采用的是 **Decoder-only Transformer** 架构,类似于 GPT-2 的微缩版。

**核心组件解析:**
1. **线性投影 (Linear Projection)**:
* 将 11维的离散状态映射到 128维的连续向量空间 (Embedding)。
* 作用:提取特征,将稀疏的布尔信息转化为稠密的语义信息。
2. **位置编码 (Positional Encoding)**:
* Transformer 本质上是并行处理的,不具备时间概念。
* 我们需要加上可学习的位置向量,告诉模型“这是第 T-1 步,那是第 T 步”。
3. **多头因果自注意力 (Masked Multi-Head Self-Attention)**:
* **Self-Attention**: 模型会计算序列中每一帧与其他帧的相关性。例如,当前的“危险”状态可能需要参考几步前的“转向”操作。
* **Causal Mask (因果掩码)**: 这是一个下三角矩阵掩码,确保模型在预测第 T 步动作时,**只能看到 T 之前的信息**,绝对看不到未来。这是“自回归”生成的关键。

4. **前馈网络 (Feed Forward Network)**:
* 对每个时间步的特征进行非线性变换,增强模型的表达能力。
### 3. 为什么选择 Transformer?
相比于传统的 DQN (Deep Q-Network):
* **序列建模能力**: DQN 通常只看当前一帧(或堆叠几帧),而 Transformer 天生擅长处理长序列。它能更好地理解“局面是如何演变到这一步的”。
* **行为克隆 (Behavior Cloning)**: 我们将强化学习问题转化为了**监督学习**问题(Sequence Modeling)。即:给定过去的历史状态,预测专家会做的下一个动作。
* **泛化潜力**: 这种架构也就是目前大语言模型(LLM)的基础,证明了其强大的模式识别和泛化能力。
---
## 💡 数据质量的重要性
**"Garbage In, Garbage Out"**
在训练过程中,我们对比了两组数据:
| 数据来源 | 策略 | 平均分 | 训练后模型表现 |
| :--- | :--- | :--- | :--- |
| **随机/弱DQN** | 随机游走 + 基础避障 | < 10 | **差**。经常在空地转圈,容易把自己围死。 |
| **A* 算法** | 全局最优路径规划 | **> 50** | **优**。走位风骚,能通过极其狭窄的通道,有明显的规划感。 |
**结论**:对于模仿学习,**少量的高质量专家数据**(如 100 局 A* 数据)远胜于**大量低质量数据**(如 10000 局随机数据)。
---
## 💻 训练环境与依赖
* **硬件**:
* 推荐使用 NVIDIA GPU (支持 CUDA)。
* 实测 RTX 4060 训练 100 局数据仅需 1-2 分钟。
* CPU 也可以训练,但速度较慢。
* **软件**:
* Python 3.8+
* PyTorch 1.10+
* Pygame (用于渲染游戏画面)
```bash
pip install torch torchvision numpy pygame
```
---
## 🛠️ 目录结构说明
```
TrainGPT/
├── min_gpt/ # 🧠 核心算法目录
│ ├── model.py # Transformer 模型定义 (Block, CausalSelfAttention)
│ ├── train.py # 训练主循环
│ ├── dataset.py # 自定义 Dataset 和 DataLoader
│ ├── config.py # 配置文件 (维度, 层数, 学习率等)
│ ├── play_game_with_gpt.py # ✨ 推理与可视化脚本
│ └── ...
│
├── snake_ai/ # 🐍 游戏环境与数据生成
│ ├── snake_game.py # 贪吃蛇游戏引擎 (Pygame)
│ ├── generate_data.py # 数据生成入口 (支持 --astar)
│ ├── a_star_solver.py # A* 寻路算法实现
│ └── agent.py # 基础 DQN Agent (用于对比)
│
└── README.md # 项目文档
```