# snn inference **Repository Path**: kzkx/snn-inference ## Basic Information - **Project Name**: snn inference - **Description**: 2025年秋季国科大《GPU架构与编程》 - **Primary Language**: Unknown - **License**: Not specified - **Default Branch**: master - **Homepage**: None - **GVP Project**: No ## Statistics - **Stars**: 0 - **Forks**: 2 - **Created**: 2026-01-12 - **Last Updated**: 2026-01-25 ## Categories & Tags **Categories**: Uncategorized **Tags**: None ## README # 2025年秋季国科大《GPU架构与编程》作业一 - SNN FashionMNIST 推理(CUDA) ## 1. 目录结构与关键文件 - `inference.cu`:推理主程序(含 CUDA kernel 与 `scnn_inference`)。 - `train.py` :训练脚本(用于生成权重文件)。 ## 2. 训练与权重 训练脚本会生成推理所需的 10 个 txt 权重文件(固定文件名)。推理程序只依赖这些 txt 文件与 `data/` 下测试集。 ## 3. 网络拓扑(固定) 推理程序实现的网络拓扑满足: ``` Conv -> IF -> Pool -> Conv -> IF -> Pool -> Flatten -> FC -> IF -> FC -> IF -> FC ``` 其中: - 时间步长 `T = 5`(在 `scnn_inference` 内部设置;满足要求 `T >= 2`)。 - `IF` 为 integrate-and-fire:每步累积膜电位,超过阈值 `v_threshold=1.0` 触发脉冲并重置膜电位。 - 各层间特征以二值脉冲(0/1)传播(最终输出层为累积计分)。 ## 4. 实现要点(inference.cu) 为降低推理耗时,`inference.cu` 内做了若干优化(不改变拓扑/层类型的前提下): - Conv1+Conv2 融合 mega kernel:一个 block 处理一个样本,减少中间特征落地与 kernel launch 开销。 - Conv 权重/偏置使用 constant memory:小参数常量缓存,降低访存开销。 - IF 更新使用 PTX 级分支消除:将“触发/重置”写成 predicated 指令,减少分支代价。 - FC 层使用 WMMA(Tensor Core):将 FC 权重 padding 到 16 的倍数,使用 half + wmma 进行矩阵乘。 - pinned host memory + 双 stream:H2D、计算、D2H 尝试重叠。 - 单次推理全 batch:本实现将测试集一次性作为 batch 送入,尽量减少调度开销。 ## 5. 编译(Linux / nvcc) ### 5.1 环境要求 - Ubuntu18.04 - 已安装 CUDA Toolkit(含 `nvcc`),CUDA版本为11.8 - 单块Tesla V100S-PCIE-32GB ### 5.2 编译命令 在仓库根目录打开终端: ```sh nvcc inference.cu -o inference_prog -Xcompiler "-O3 -std=c++14" -gencode arch=compute_70,code=sm_70 -rdc=true ``` ## 6. 运行(重要:参数目录与数据路径) `main()` 需要一个参数:``,该目录用于读取 10 个权重文件。 同时,测试集数据路径是通过下面的相对路径拼出来的: - 权重读取:`/conv1.weight.txt` 等 - 数据读取:`/../../.. /data/FashionMNIST/raw/...` 因此,**推荐的运行方式**是: 1) 把 10 个权重/偏置文件复制到 `inference_prog` 同目录下 2) 在 `inference_prog` 目录下运行 ```sh ./inference_prog . ``` ## 7. 推理程序输出 程序运行后输出一行: ``` : ``` - `seconds`:推理耗时(秒),保留 4 位小数 - `accuracy`:在 FashionMNIST 测试集(t10k)上的分类准确率,保留 4 位小数 示例: ``` 0.1234:0.8765 ```