# pt2ms_convert **Repository Path**: jungheil/pt2ms_convert ## Basic Information - **Project Name**: pt2ms_convert - **Description**: pt2ms_convert - **Primary Language**: Unknown - **License**: Not specified - **Default Branch**: master - **Homepage**: None - **GVP Project**: No ## Statistics - **Stars**: 0 - **Forks**: 1 - **Created**: 2022-12-19 - **Last Updated**: 2023-06-17 ## Categories & Tags **Categories**: Uncategorized **Tags**: None ## README ## Pytorch 2 Mindspore Converter ### 简介 这是一个权重转换脚本,能够将训练好的pytorch权重文件转换为mindspore的权重文件。只需要pytorch的 `pth`权重文件和mindspore的网络源代码或者其生成的 `ckpt`权重文件即可完成转换。 mindspore和pytorch的部分算子参数名字不同,需要进行转换。例如,pytorch 中的 `LayerNorm`参数为 `weight`和 `bias`,但是在mindspore中的参数名分别为 `gamma`和 `beta`,这会导致该算子无法匹配。但是在权重文件中并没有算子类型的信息,而且我们不能简单的根据参数名字进行转换,例如 pytorch中的 `Linear`和mindspore中的 `Dense`参数都为 `weight`和 `bias`,这会出现混淆的情况。幸运的是目前参数名称不相同的情况只出现在 `BatchNorm`、`GroupNorm`和 `LayerNorm`当中,若**网络包含以上算子,请务必注意**。本脚本通过节点的前缀进行区分,映射规则在 `opsmap.yml`中。 脚本根据网络结构对参数进行匹配,算子的名称并不会影响匹配,但是对结构的顺序非常敏感。虽然在不存在歧义的情况下,脚本会对顺序不一致的参数进行匹配,但是仍然**非常建议根据pytorch的结构编写mindspore的代码**! 对于脚本无法匹配的参数,将会在最后进行输出。可通过 `prematch.yml`文件进行手动匹配。 ### opsmap.yml 若网络中包含 `GroupNorm`和 `LayerNorm`必须修改本文件!以 `LayerNorm`为例子: ```yaml - name: LayerNorm prefix: - "*.layernorm*" - "*.ln*" map: weight: gamma bias: beta ``` 其中,`map`中的key为pytorch中的算子参数名,value是mindspore中对应的算子参数名。`prefix`是参数的前缀,若算子的前缀在 `prefix`中,则根据 `map`的规则对参数名称进行转换。`prefix`支持 使用 通配符 `*`,如 `*.ln*`能够匹配: * `aaa.bbb.ln` * `aaa.ln2` ### prematch.yml 手动对pytorch和mindspore的权重进行匹配,同时手动匹配的算子不参与之后的最优匹配中。例: ```yaml pt_first_conv.0: ms_first_conv.0 pt.features.0.branch_main.bn: features.0.branch_main.bn ``` ### 开始 1. 根据网络情况修改 `opsmap.yml`文件 2. 准备mindspore的网络源代码或权重文件 3. 准备pytorch权重文件 4. 运行脚本 | 参数 | 说明 | | ------------------- | ------------------------------------------- | | `ms_net`, `m` | mindspore网络源代码 | | `pt_net` , `p` | pytorch网络源代码 | | `pt_weight` , `w` | pytorch权重 | | `ms_weight` , `n` | mindspore权重文件,在不提供网络源代码时使用 | | `strict_type` , `t` | 严格参数数据类型 | | `dst_net` , `d` | 生成网络路径 | | `order_match` , `o` | 严格按顺序匹配 | * 提供网络源代码能会在转换过后,输出网络迁移后pytorch和mindspore网络对于随机输入其计算结果的差异: ```bash python weight_conv_pt2ms.py -m models.shufflenetv2_ms.ShuffleNetV2 -p models.shufflenetv2_pt.ShuffleNetV2 -w weights/ShuffleNetV2.1.5x.pth.tar ``` * 不输入网络源代码,并使用严格类型,输出路径为 `weights/shufflenetv2_new.ckpt`: ```bash python weight_conv_pt2ms.py -n weights/shufflenetv2.ckpt -w weights/ShuffleNetV2.1.5x.pth.tar -d weights/shufflenetv2_new.ckpt -t ``` 5. 若没有未匹配的参数,完成;若存在,则根据实际情况调整mindspore网络参数结构(特别是顺序)或者通过 `prematch.yml`文件手动匹配。