# GNN_learning **Repository Path**: haobingwen/gnn_learning ## Basic Information - **Project Name**: GNN_learning - **Description**: No description available - **Primary Language**: Unknown - **License**: Not specified - **Default Branch**: master - **Homepage**: None - **GVP Project**: No ## Statistics - **Stars**: 0 - **Forks**: 0 - **Created**: 2025-11-23 - **Last Updated**: 2025-12-11 ## Categories & Tags **Categories**: Uncategorized **Tags**: None ## README # Graph Neural Network (GNN) Learning Codes 这是一个关于图神经网络(GNN)的学习与实验仓库,包含多种经典 GNN 模型及图相关任务示例代码。本项目已经按「库代码 (`src/`)+ 任务入口 (`tasks/`) + 实验脚本 (`experiments/`)」的结构进行了重构,方便阅读、复用和扩展。 --- ## 目录结构总览 项目根目录下的主要结构如下: - `data/`:数据集根目录(Cora、Planetoid、PubMed、PCQM4M 等)。 - `src/`:**库代码**,包含可复用的数据集封装、模型、层与工具函数。 - `src/datasets/` - `planetoid_pubmed.py`:PubMed 引文网络数据集封装(`PlanetoidPubMed`)。 - `src/models/` - `node_classification/` - `gat.py`:节点分类用 GAT 模型定义。 - `graph_regression/` - `gin_conv.py`:GIN 卷积层。 - `gin_node.py`:基于 GIN 的节点表征网络。 - `gin_graph.py`:图级 GIN 池化模型 `GINGraphPooling`。 - `mol_encoder.py`:分子图的原子/键特征编码器 `AtomEncoder`/`BondEncoder`。 - `pcqm4m_data.py`:PCQM4M 图回归数据集封装 `MyPCQM4MDataset`。 - `src/layers/` - `ckgconv.py`(后续可迁入):自定义 GNN 层。 - `src/utils/`:可放置通用工具(目前为空,未来可从 `experiments/` 中抽取)。 - `tasks/`:**任务入口脚本**,负责解析参数、组织数据与模型、启动训练或可视化。 - `node_classification/` - `node_classification.py`:基于 PubMed 的 GAT 节点分类任务入口脚本,内部调用 `src.datasets.PlanetoidPubMed` 与 `src.models.node_classification.GAT`。 - `cluster_gcn.py`:Cluster-GCN 节点分类脚本(暂依赖旧结构,可逐步迁移)。 - `edge_classification/` - `edge_classification.py`:边分类任务入口。 - `graph_embedding/` - `deepwalk_2vec.py`:DeepWalk + word2vec 节点嵌入与可视化。 - `graph_embedding.py`:基于邻接矩阵幂 + UMAP 的图嵌入可视化(Walklets 风格)。 - `lle_embedding.py`:LLE 等流形学习相关可视化。 - `graph_regression/` - `main.py`:基于 GIN 的图回归主训练脚本,内部使用 `src.models.graph_regression` 中的模型和数据集封装。 - `run.sh`:运行脚本示例(Linux/macOS 环境下)。 - `saves/`:保存 GIN 回归任务训练输出(日志、模型权重等)。 - `experiments/`:**实验与演示脚本**,用于探索、画图、快速试验。 - `dataset_analysis.py`:数据集分析与可视化。 - `dataset_test.py`:数据读取/预处理测试。 - `graph_test.py`:图构造与操作测试。 - `randwalk_test.py`:随机游走/DeepWalk 等相关试验。 - `model_compare.py`:模型对比实验。 - `gbrx.py`:图可视化实验脚本。 - `notebooks/` - `learn_node_representation.ipynb`:节点表示学习的 Jupyter Notebook 教程。 - `outputs/` - 存放生成的图像与 HTML 文件,如 DeepWalk 可视化结果等。 --- ## 环境依赖 项目基于 Python 3.x,核心依赖包括: - `torch`(PyTorch) - `torch-geometric`(PyG)以及其依赖(`torch-scatter` 等) - `ogb`(Open Graph Benchmark,用于 PCQM4M 等) - `networkx` - `gensim` - `matplotlib` - `scikit-learn` - `umap-learn` - `python-louvain`(社区划分,用于图嵌入可视化) - 其它常规科学计算库:`numpy`, `pandas`, `tqdm` 等 建议在虚拟环境中安装依赖,例如: ```bash pip install torch torchvision torchaudio # 按照官网选择与你 CUDA 对应的版本 pip install torch-geometric torch-scatter torch-sparse torch-cluster torch-spline-conv -f https://data.pyg.org/whl/torch-$(python -c "import torch;print(torch.__version__.split('+')[0])")+cpu.html pip install ogb networkx gensim matplotlib scikit-learn umap-learn python-louvain pandas tqdm ``` > 安装 PyTorch 与 PyG 时请参考各自官网,根据你的 CUDA 版本选择合适的安装命令;上面仅为示意。 --- ## 如何运行各类任务 以下示例都假定你在项目根目录(包含本 `README.md` 的目录)下运行命令。 ### 1. 节点分类(PubMed + GAT) 入口脚本:`tasks/node_classification/node_classification.py` 运行方式(推荐包形式): ```bash python -m tasks.node_classification.node_classification ``` 或直接运行脚本: ```bash python tasks/node_classification/node_classification.py ``` 该脚本会: - 使用 `src.datasets.PlanetoidPubMed` 从 `data/PlanetoidPubMed` 目录加载 PubMed 数据; - 构建 `src.models.node_classification.GAT` 模型; - 在训练集上训练,在测试集上评估准确率并打印日志。 > 如需调整隐藏维度、层数、dropout 等,可在 `src/models/node_classification/gat.py` 内修改 GAT 定义,或在入口脚本中增加参数解析逻辑。 #### 1.1 Cluster-GCN 节点分类 Cluster-GCN 脚本仍位于:`tasks/node_classification/cluster_gcn.py`,暂时依赖原始结构,可这样运行: ```bash python tasks/node_classification/cluster_gcn.py ``` > 该脚本通常需要 Reddit 等大规模数据集,运行前请确保 `data/` 中数据路径正确,或根据脚本中的说明下载数据。 --- ### 2. 图嵌入与可视化(Graph Embedding) 目录:`tasks/graph_embedding/` #### 2.1 DeepWalk + word2vec 脚本:`tasks/graph_embedding/deepwalk_2vec.py` 运行: ```bash python tasks/graph_embedding/deepwalk_2vec.py ``` 主要流程: - 使用 `networkx` 构造图(如空手道俱乐部图); - 运行随机游走生成序列; - 使用 `gensim` 进行 word2vec 训练得到节点嵌入; - 使用 `matplotlib` / 降维方法将嵌入可视化,结果保存到 `outputs/` 目录下(如 `deepwalk_word2vec.html`)。 #### 2.2 基于邻接矩阵幂 + UMAP 的图嵌入 脚本:`tasks/graph_embedding/graph_embedding.py` 运行: ```bash python tasks/graph_embedding/graph_embedding.py ``` 主要流程: - 基于 `networkx` 获取图的邻接矩阵 $A$; - 计算 $A^k$ 作为高阶邻接特征; - 使用 `umap-learn` 将高维特征降到 2D; - 使用 `python-louvain` 做社区划分,用社区着色; - 生成多子图可视化并保存图片。 #### 2.3 其它嵌入(LLE 等) 脚本:`tasks/graph_embedding/lle_embedding.py` 等,可按需运行: ```bash python tasks/graph_embedding/lle_embedding.py ``` --- ### 3. 图回归(Graph Regression,PCQM4M + GIN) 目录:`tasks/graph_regression/` 核心库代码:`src/models/graph_regression/` 入口脚本:`tasks/graph_regression/main.py` #### 3.1 准备数据集 `MyPCQM4MDataset` 会自动从官方链接下载 PCQM4M 数据,并解压到你指定的 `--dataset_root` 目录下。建议新建一个目录,例如 `dataset/`: ```bash mkdir dataset ``` 首次运行会自动下载并处理数据,时间较长且占用硬盘空间较多,请确保网络与磁盘空间充足。 #### 3.2 运行训练脚本 从项目根目录运行(推荐): ```bash python -m tasks.graph_regression.main --task_name GINGraphPooling --dataset_root dataset ``` 或进入子目录运行: ```bash cd tasks/graph_regression python main.py --task_name GINGraphPooling --dataset_root ../../dataset ``` 常用参数说明(在 `main.py` 中由 `argparse` 解析): - `--task_name`:实验名,用于区分不同实验并创建 `saves/*` 目录。 - `--device`:GPU 编号(如 `0`),若无 GPU 则会自动使用 CPU。 - `--num_layers`:GINConv 层数,默认 5。 - `--graph_pooling`:图池化方式,`sum` / `mean` / `max` / `attention` / `set2set`。 - `--emb_dim`:节点嵌入维度,默认 256。 - `--drop_ratio`:dropout 比例。 - `--batch_size`:训练批大小,默认 512。 - `--epochs`:最大训练轮数,默认 100。 - `--early_stop`:若验证集 MAE 连续若干 epoch 未提升则提前停止。 - `--dataset_root`:PCQM4M 数据集根目录(必选)。 训练过程中: - 使用 `SummaryWriter` 将指标记录到 TensorBoard; - 在验证集上监控 MAE 并保存在 `tasks/graph_regression/saves/` 下; - 可选地在最优 checkpoint 上对测试集推理并保存提交文件。 使用 TensorBoard 查看训练曲线示例: ```bash tensorboard --logdir tasks/graph_regression/saves ``` --- ## 实验脚本与 Notebook `experiments/` 目录下脚本适合作为「沙箱」用来: - 探索不同图数据的性质(`dataset_analysis.py`); - 尝试不同可视化(`graph_test.py`, `gbrx.py`); - 对比不同模型/参数组合(`model_compare.py`); - 验证随机游走/嵌入等基础组件(`randwalk_test.py`)。 `notebooks/learn_node_representation.ipynb` 提供了交互式的节点表示学习示例,推荐用 Jupyter / VS Code 打开阅读和修改。 --- ## 小结与建议 - **库代码统一在 `src/` 下维护**: - 新增模型或数据集时,建议优先放到 `src/models/` 或 `src/datasets/`,再在 `tasks/` 中创建对应的入口脚本。 - **任务入口统一放在 `tasks/` 下**: - 每个子目录对应一个任务类型(节点分类 / 图回归 / 嵌入 / 边分类等)。 - **实验脚本放在 `experiments/`**,避免和正式任务入口混在一起。 如果你想继续扩展这个仓库(比如增加 GraphSAGE、GIN 的节点分类版本、GCN 的图分类 demo 等),可以直接在 `src/models/` 中新增子模块,然后仿照现有 `tasks/*` 目录写入口脚本即可。 --- ## 许可证 MIT License