diff --git a/README.md b/README.md
index 305271d34b256f5288d95f0dbaeb21d5a5e3eb3e..cd59c958ef2590979111b77d31c45b173a6f18ed 100644
--- a/README.md
+++ b/README.md
@@ -236,9 +236,8 @@ ModelLink旨在为华为 [昇腾芯片](https://open.codehub.huawei.com/OpenBaiz
 | torch | 2.1.0 |
 | torch_npu | release v5.0.0 |
-当前仓库使用的megatron commitId为[bcce6f54](https://github.com/NVIDIA/Megatron-LM/tree/bcce6f54e075e3c3374ea67adefe54f3f2da2b07)
-【基于最新版本我们实测的性能情况统计如下】
+【基于现版本megatron我们实测的性能情况统计如下】
diff --git a/README_en.md b/README_en.md
index a14128071bf352403b4423370f11a7840b9a8d5a..102afd4658fbf4f72c2c546a766e7782d3116f54 100644
--- a/README_en.md
+++ b/README_en.md
@@ -237,9 +237,8 @@ For the supported models listed above, we provide training scripts and readme in
 | torch | 2.1.0 |
 | torch_npu | release v5.0.0 |
-The current repository uses Megatron commitId [bcce6f54](https://github.com/NVIDIA/Megatron-LM/tree/bcce6f54e075e3c3374ea67adefe54f3f2da2b07)
-【Based on the latest version, the performance statistics from our testing are as follows】
+【Based on the current version of megatron, the performance statistics from our testing are as follows】
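The per-model README changes that follow all add the same two dependency pins: Megatron-LM checked out at `bcce6f`, with its `megatron/` package copied into ModelLink, and AscendSpeed checked out at `224ae35e8fc96778f957029d1371ddb623452a50`. As a consolidated reference, here is a minimal sketch of that setup flow under a few assumptions: the conda environment creation from the individual READMEs is omitted, the `torch_npu`/`apex` wheel names are the placeholders used in those READMEs, the CANN path `/usr/local/Ascend/ascend-toolkit/set_env.sh` may differ on your host, and the final two verification commands are an optional addition rather than part of any README.

```shell
# Clone ModelLink and copy the pinned Megatron-LM sources into it
git clone https://gitee.com/ascend/ModelLink.git
git clone https://github.com/NVIDIA/Megatron-LM.git
cd Megatron-LM
git checkout -f bcce6f                          # pinned Megatron-LM commit
cp -r megatron ../ModelLink/
cd ..

cd ModelLink
mkdir logs
mkdir model_from_hf

# NPU PyTorch stack (wheel names are placeholders; use the wheels built for your platform)
pip install torch==2.1.0
pip install torch_npu-2.1.0.postxxxx-cp38-cp38-xxxx_aarch64.whl
pip install apex-0.1_ascend*-cp38-cp38m-linux_aarch64.whl

# CANN environment (adjust the path to your installation)
source /usr/local/Ascend/ascend-toolkit/set_env.sh

# AscendSpeed at the pinned commit
git clone https://gitee.com/ascend/AscendSpeed.git
cd AscendSpeed
git checkout 224ae35e8fc96778f957029d1371ddb623452a50
pip install -r requirements.txt
pip install -e .
cd ..

# Remaining ModelLink dependencies
pip install -r requirements.txt

# Optional sanity check: both dependencies should report the pinned revisions
git -C ../Megatron-LM rev-parse --short HEAD    # expect bcce6f...
git -C AscendSpeed rev-parse HEAD               # expect 224ae35e8fc96778f957029d1371ddb623452a50
```

Pinning both repositories to fixed commits keeps the copied `megatron/` sources and the installed `ascendspeed` package consistent with the versions the performance statistics above were presumably measured against.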
diff --git a/examples/aquila/README.md b/examples/aquila/README.md index d26df0a6a442a2c2c1b11f16261ea785bed3aae3..8575529334854a4262fc54d0c386b1000afd5c5f 100644 --- a/examples/aquila/README.md +++ b/examples/aquila/README.md @@ -28,6 +28,11 @@ Aquila-7B 训练的硬件配置如下: ```shell git clone https://gitee.com/ascend/ModelLink.git +git clone https://github.com/NVIDIA/Megatron-LM.git +cd Megatron-LM +git checkout -f bcce6f +cp -r megatron ../ModelLink/ +cd .. cd ModelLink mkdir logs mkdir model_from_hf @@ -48,15 +53,17 @@ pip install torch_npu-2.1.0.postxxxx-cp38-cp38-xxxx_aarch64.whl # 通过PTA上提供的安装包,以whl文件方式安装apex pip install apex-0.1_ascend*-cp38-cp38m-linux_aarch64.whl -# 拉取AscendSpeed源代码,进入AscendSpeed目录,source相应的CANN环境文件,然后源码方式安装ascendspeed加速包 -git clone https://gitee.com/ascend/AscendSpeed.git -cd AscendSpeed/ # 请按照您的真实环境修改 set_env.sh 路径 source /usr/local/Ascend/ascend-toolkit/set_env.sh -pip install -e ./ +# 拉取AscendSpeed源代码,进入AscendSpeed目录,然后源码方式安装ascendspeed加速包 +git clone https://gitee.com/ascend/AscendSpeed.git +cd AscendSpeed/ +git checkout 224ae35e8fc96778f957029d1371ddb623452a50 +pip install -r requirements.txt +pip install -e . +cd .. -# 进入ModelLink主目录,安装其余依赖包 -cd ModelLink/ +# 安装其余依赖包 pip install -r requirements.txt ``` diff --git a/examples/aquila/README_en.md b/examples/aquila/README_en.md index 26bba5d923c0352bc919bf0c9a464535a0c87ed2..3fd5c58b9b685272b13b9f7fd83a52f7c1d559e1 100644 --- a/examples/aquila/README_en.md +++ b/examples/aquila/README_en.md @@ -28,6 +28,11 @@ Here's a hardware summary of pre-training Aquila-7B: ```shell git clone https://gitee.com/ascend/ModelLink.git +git clone https://github.com/NVIDIA/Megatron-LM.git +cd Megatron-LM +git checkout -f bcce6f +cp -r megatron ../ModelLink/ +cd .. cd ModelLink mkdir logs mkdir model_from_hf @@ -45,14 +50,18 @@ conda activate test pip install torch==2.1.0 pip install torch_npu-2.1.0.postxxxx-cp38-cp38-xxxx_aarch64.whl pip install apex-0.1_ascend*-cp38-cp38m-linux_aarch64.whl -# use git to clone the AscendSpeed source code, enter the directory, source the set_env.sh file based on your host settings(you may need to change the path), then install ascendspeed package by source code + +# source the set_env.sh file based on your host settings(you may need to change the path) +source /usr/local/Ascend/ascend-toolkit/set_env.sh +# use git to clone the AscendSpeed source code, enter the directory, then install ascendspeed package by source code git clone https://gitee.com/ascend/AscendSpeed.git cd AscendSpeed/ -source /usr/local/Ascend/ascend-toolkit/set_env.sh -pip install -e ./ +git checkout 224ae35e8fc96778f957029d1371ddb623452a50 +pip install -r requirements.txt +pip install -e . +cd .. -# enter the ModelLink/ directory and install other packages -cd ModelLink/ +# install other packages pip install -r requirements.txt ``` diff --git a/examples/baichuan/README.md b/examples/baichuan/README.md index 3533bed6ee98f6d1d98f072a49b86b6510832aa1..535a9b8c87a1c903aca367476d0eb1d54784f12d 100644 --- a/examples/baichuan/README.md +++ b/examples/baichuan/README.md @@ -41,6 +41,11 @@ Baichuan-7B 训练的硬件配置如下: ```shell git clone https://gitee.com/ascend/ModelLink.git +git clone https://github.com/NVIDIA/Megatron-LM.git +cd Megatron-LM +git checkout -f bcce6f +cp -r megatron ../ModelLink/ +cd .. 
cd ModelLink mkdir logs mkdir model_from_hf @@ -66,6 +71,7 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh # 安装加速库 git clone https://gitee.com/ascend/AscendSpeed.git cd AscendSpeed +git checkout 224ae35e8fc96778f957029d1371ddb623452a50 pip install -r requirements.txt pip3 install -e . cd .. @@ -262,6 +268,11 @@ Baichuan-13B 训练的硬件配置如下: ```shell git clone https://gitee.com/ascend/ModelLink.git +git clone https://github.com/NVIDIA/Megatron-LM.git +cd Megatron-LM +git checkout -f bcce6f +cp -r megatron ../ModelLink/ +cd .. cd ModelLink mkdir logs mkdir model_from_hf @@ -287,6 +298,7 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh # 安装加速库 git clone https://gitee.com/ascend/AscendSpeed.git cd AscendSpeed +git checkout 224ae35e8fc96778f957029d1371ddb623452a50 pip install -r requirements.txt pip3 install -e . cd .. diff --git a/examples/baichuan/README_en.md b/examples/baichuan/README_en.md index fcb128f86ac9d2eca3363c642c2700583ce76ecf..b3d493c328022bb5f66f728e21c8c448f5938c1d 100644 --- a/examples/baichuan/README_en.md +++ b/examples/baichuan/README_en.md @@ -42,6 +42,11 @@ Here's a hardware summary of pre-training Baichuan-7B: ```shell git clone https://gitee.com/ascend/ModelLink.git +git clone https://github.com/NVIDIA/Megatron-LM.git +cd Megatron-LM +git checkout -f bcce6f +cp -r megatron ../ModelLink/ +cd .. cd ModelLink mkdir logs mkdir model_from_hf @@ -67,6 +72,7 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh # install AscendSpeed git clone https://gitee.com/ascend/AscendSpeed.git cd AscendSpeed +git checkout 224ae35e8fc96778f957029d1371ddb623452a50 pip install -r requirements.txt pip3 install -e . cd .. @@ -264,6 +270,11 @@ Here's a hardware summary of pre-training Baichuan-13B: ```shell git clone https://gitee.com/ascend/ModelLink.git +git clone https://github.com/NVIDIA/Megatron-LM.git +cd Megatron-LM +git checkout -f bcce6f +cp -r megatron ../ModelLink/ +cd .. cd ModelLink mkdir logs mkdir model_from_hf @@ -289,6 +300,7 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh #install Ascendspeed git clone https://gitee.com/ascend/AscendSpeed.git cd AscendSpeed +git checkout 224ae35e8fc96778f957029d1371ddb623452a50 pip install -r requirements.txt pip3 install -e . cd .. diff --git a/examples/baichuan2/README.md b/examples/baichuan2/README.md index 3cef67bd90bbf525dba3d3a49ba982a33d33b940..0636a1857bb501d1921da63b3cfb29bd1d43971d 100644 --- a/examples/baichuan2/README.md +++ b/examples/baichuan2/README.md @@ -39,6 +39,11 @@ Baichuan2-7B 训练的硬件配置如下: ```shell git clone https://gitee.com/ascend/ModelLink.git +git clone https://github.com/NVIDIA/Megatron-LM.git +cd Megatron-LM +git checkout -f bcce6f +cp -r megatron ../ModelLink/ +cd .. cd ModelLink mkdir logs mkdir model_from_hf @@ -64,6 +69,7 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh # 安装加速库 git clone https://gitee.com/ascend/AscendSpeed.git cd AscendSpeed +git checkout 224ae35e8fc96778f957029d1371ddb623452a50 pip install -r requirements.txt pip3 install -e . cd .. @@ -263,6 +269,11 @@ Baichuan2-13B 训练的硬件配置如下: ```shell git clone https://gitee.com/ascend/ModelLink.git +git clone https://github.com/NVIDIA/Megatron-LM.git +cd Megatron-LM +git checkout -f bcce6f +cp -r megatron ../ModelLink/ +cd .. cd ModelLink mkdir logs mkdir model_from_hf @@ -288,6 +299,7 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh # 安装加速库 git clone https://gitee.com/ascend/AscendSpeed.git cd AscendSpeed +git checkout 224ae35e8fc96778f957029d1371ddb623452a50 pip install -r requirements.txt pip3 install -e . cd .. 
diff --git a/examples/baichuan2/README_en.md b/examples/baichuan2/README_en.md index a47668e13f80046dcf3ea7a3b56e4331e25723a7..c433bf770b89b3d6fd777a52cd1f9a473c0c52a4 100644 --- a/examples/baichuan2/README_en.md +++ b/examples/baichuan2/README_en.md @@ -39,6 +39,11 @@ Here's a hardware summary of pre-training Baichuan2-7B: ```shell git clone https://gitee.com/ascend/ModelLink.git +git clone https://github.com/NVIDIA/Megatron-LM.git +cd Megatron-LM +git checkout -f bcce6f +cp -r megatron ../ModelLink/ +cd .. cd ModelLink mkdir logs mkdir model_from_hf @@ -64,6 +69,7 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh # install AscendSpeed git clone https://gitee.com/ascend/AscendSpeed.git cd AscendSpeed +git checkout 224ae35e8fc96778f957029d1371ddb623452a50 pip install -r requirements.txt pip3 install -e . cd .. @@ -263,6 +269,11 @@ Here's a hardware summary of pre-training Baichuan2-13B: ```shell git clone https://gitee.com/ascend/ModelLink.git +git clone https://github.com/NVIDIA/Megatron-LM.git +cd Megatron-LM +git checkout -f bcce6f +cp -r megatron ../ModelLink/ +cd .. cd ModelLink mkdir logs mkdir model_from_hf @@ -288,6 +299,7 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh # install AscendSpeed git clone https://gitee.com/ascend/AscendSpeed.git cd AscendSpeed +git checkout 224ae35e8fc96778f957029d1371ddb623452a50 pip install -r requirements.txt pip3 install -e . cd .. diff --git a/examples/bloom/README.md b/examples/bloom/README.md index 38fb639c5ead7a43614e806fed6b9c4db01f1c02..27d49a897c6d258a0b44aa6de05aa2b30647f840 100644 --- a/examples/bloom/README.md +++ b/examples/bloom/README.md @@ -24,6 +24,11 @@ Bloom-7B 训练的硬件配置如下: ```shell git clone https://gitee.com/ascend/ModelLink.git +git clone https://github.com/NVIDIA/Megatron-LM.git +cd Megatron-LM +git checkout -f bcce6f +cp -r megatron ../ModelLink/ +cd .. cd ModelLink mkdir logs mkdir model_from_hf @@ -49,6 +54,7 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh # 安装加速库 git clone https://gitee.com/ascend/AscendSpeed.git cd AscendSpeed +git checkout 224ae35e8fc96778f957029d1371ddb623452a50 pip install -r requirements.txt pip3 install -e . cd .. @@ -231,6 +237,11 @@ Bloom-176B 训练的硬件配置: ```shell git clone https://gitee.com/ascend/ModelLink.git +git clone https://github.com/NVIDIA/Megatron-LM.git +cd Megatron-LM +git checkout -f bcce6f +cp -r megatron ../ModelLink/ +cd .. cd ModelLink mkdir logs mkdir model_from_hf @@ -256,6 +267,7 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh # 安装加速库 git clone https://gitee.com/ascend/AscendSpeed.git cd AscendSpeed +git checkout 224ae35e8fc96778f957029d1371ddb623452a50 pip install -r requirements.txt pip3 install -e . cd .. diff --git a/examples/bloom/README_en.md b/examples/bloom/README_en.md index e6789e6ca484b6c539052aeb097d473ec1bbd853..64e10251f3d3d6d4a739e7b543d7afe1d5598de1 100644 --- a/examples/bloom/README_en.md +++ b/examples/bloom/README_en.md @@ -24,6 +24,11 @@ Here's a hardware summary of pre-training Bloom-7B: ```shell git clone https://gitee.com/ascend/ModelLink.git +git clone https://github.com/NVIDIA/Megatron-LM.git +cd Megatron-LM +git checkout -f bcce6f +cp -r megatron ../ModelLink/ +cd .. cd ModelLink mkdir logs mkdir model_from_hf @@ -49,6 +54,7 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh # install AscendSpeed git clone https://gitee.com/ascend/AscendSpeed.git cd AscendSpeed +git checkout 224ae35e8fc96778f957029d1371ddb623452a50 pip install -r requirements.txt pip3 install -e . cd .. 
@@ -233,6 +239,11 @@ Here's a hardware summary of pre-training Bloom-176B: ```shell git clone https://gitee.com/ascend/ModelLink.git +git clone https://github.com/NVIDIA/Megatron-LM.git +cd Megatron-LM +git checkout -f bcce6f +cp -r megatron ../ModelLink/ +cd .. cd ModelLink mkdir logs mkdir model_from_hf @@ -258,6 +269,7 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh # install AscendSpeed git clone https://gitee.com/ascend/AscendSpeed.git cd AscendSpeed +git checkout 224ae35e8fc96778f957029d1371ddb623452a50 pip install -r requirements.txt pip3 install -e . cd .. diff --git a/examples/intern/README.md b/examples/intern/README.md index 91cc5af84dcea33a762e56cd0edac9489df2693d..9b99ac6b6d0a761d16613aacfd09384964ee0a05 100644 --- a/examples/intern/README.md +++ b/examples/intern/README.md @@ -37,6 +37,11 @@ InternLM-7B 训练的硬件配置如下: ```shell git clone https://gitee.com/ascend/ModelLink.git +git clone https://github.com/NVIDIA/Megatron-LM.git +cd Megatron-LM +git checkout -f bcce6f +cp -r megatron ../ModelLink/ +cd .. cd ModelLink mkdir logs mkdir model_from_hf @@ -62,6 +67,7 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh # 安装加速库 git clone https://gitee.com/ascend/AscendSpeed.git cd AscendSpeed +git checkout 224ae35e8fc96778f957029d1371ddb623452a50 pip install -r requirements.txt pip3 install -e . cd .. @@ -222,6 +228,11 @@ InternLM-65B 训练的硬件配置如下: ```shell git clone https://gitee.com/ascend/ModelLink.git +git clone https://github.com/NVIDIA/Megatron-LM.git +cd Megatron-LM +git checkout -f bcce6f +cp -r megatron ../ModelLink/ +cd .. cd ModelLink mkdir logs mkdir model_from_hf @@ -247,6 +258,7 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh # 安装加速库 git clone https://gitee.com/ascend/AscendSpeed.git cd AscendSpeed +git checkout 224ae35e8fc96778f957029d1371ddb623452a50 pip install -r requirements.txt pip3 install -e . cd .. diff --git a/examples/intern/README_en.md b/examples/intern/README_en.md index 085797dea72ddb58f11f7ad8ae2b58886166ba6c..56d77e91c0f2d2dda31da560e269d1c7f2f5c0da 100644 --- a/examples/intern/README_en.md +++ b/examples/intern/README_en.md @@ -38,6 +38,11 @@ Here's a hardware summary of pre-training InternLM-7B: ```shell git clone https://gitee.com/ascend/ModelLink.git +git clone https://github.com/NVIDIA/Megatron-LM.git +cd Megatron-LM +git checkout -f bcce6f +cp -r megatron ../ModelLink/ +cd .. cd ModelLink mkdir logs mkdir model_from_hf @@ -63,6 +68,7 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh # install AscendSpeed git clone https://gitee.com/ascend/AscendSpeed.git cd AscendSpeed +git checkout 224ae35e8fc96778f957029d1371ddb623452a50 pip install -r requirements.txt pip3 install -e . cd .. @@ -222,6 +228,11 @@ Here's a hardware summary of pre-training InternLM-65B: ```shell git clone https://gitee.com/ascend/ModelLink.git +git clone https://github.com/NVIDIA/Megatron-LM.git +cd Megatron-LM +git checkout -f bcce6f +cp -r megatron ../ModelLink/ +cd .. cd ModelLink mkdir logs mkdir model_from_hf @@ -247,6 +258,7 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh # install AscendSpeed git clone https://gitee.com/ascend/AscendSpeed.git cd AscendSpeed +git checkout 224ae35e8fc96778f957029d1371ddb623452a50 pip install -r requirements.txt pip3 install -e . cd .. 
diff --git a/examples/llama/README.md b/examples/llama/README.md index 6d65e3a5db350c8788e40aeed50e00526b555531..fa0b323a6ea333867238ebde306b4140f8414b58 100644 --- a/examples/llama/README.md +++ b/examples/llama/README.md @@ -40,6 +40,11 @@ LLaMA-7B/13B 训练的硬件配置如下: ```shell git clone https://gitee.com/ascend/ModelLink.git + git clone https://github.com/NVIDIA/Megatron-LM.git + cd Megatron-LM + git checkout -f bcce6f + cp -r megatron ../ModelLink/ + cd .. cd ModelLink mkdir logs mkdir model_from_hf @@ -60,6 +65,7 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh # 安装加速库 git clone https://gitee.com/ascend/AscendSpeed.git cd AscendSpeed +git checkout 224ae35e8fc96778f957029d1371ddb623452a50 pip install -r requirements.txt pip3 install -e . cd .. @@ -470,6 +476,11 @@ LLaMA-33B/65B 训练的硬件配置: ```shell git clone https://gitee.com/ascend/ModelLink.git + git clone https://github.com/NVIDIA/Megatron-LM.git + cd Megatron-LM + git checkout -f bcce6f + cp -r megatron ../ModelLink/ + cd .. cd ModelLink mkdir logs mkdir model_from_hf @@ -494,6 +505,7 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh # 安装加速库 git clone https://gitee.com/ascend/AscendSpeed.git cd AscendSpeed +git checkout 224ae35e8fc96778f957029d1371ddb623452a50 pip install -r requirements.txt pip3 install -e . cd .. diff --git a/examples/llama/README_en.md b/examples/llama/README_en.md index 7cdefac288a98f966ce2a17b8df87542f547c800..2cd333875b12d5e8af6d468e58ef1f4edd4b1ac7 100644 --- a/examples/llama/README_en.md +++ b/examples/llama/README_en.md @@ -39,6 +39,11 @@ Here's a hardware summary of pre-training LLaMA-7B/13B: ```shell git clone https://gitee.com/ascend/ModelLink.git +git clone https://github.com/NVIDIA/Megatron-LM.git +cd Megatron-LM +git checkout -f bcce6f +cp -r megatron ../ModelLink/ +cd .. cd ModelLink mkdir logs mkdir model_from_hf @@ -60,6 +65,7 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh # install ascendspeed git clone https://gitee.com/ascend/AscendSpeed.git cd AscendSpeed +git checkout 224ae35e8fc96778f957029d1371ddb623452a50 pip install -r requirements.txt pip3 install -e . cd .. @@ -452,6 +458,11 @@ The model was trained using alpaca datasets. ```shell git clone https://gitee.com/ascend/ModelLink.git +git clone https://github.com/NVIDIA/Megatron-LM.git +cd Megatron-LM +git checkout -f bcce6f +cp -r megatron ../ModelLink/ +cd .. cd ModelLink mkdir logs mkdir model_from_hf @@ -475,6 +486,7 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh # install AscendSpeed git clone https://gitee.com/ascend/AscendSpeed.git cd AscendSpeed +git checkout 224ae35e8fc96778f957029d1371ddb623452a50 pip install -r requirements.txt pip3 install -e . cd .. diff --git a/examples/llama2/README.md b/examples/llama2/README.md index ace911a0968a393576cdfd31daf7acdafda4b0a1..991b5eddb72145061e30241912083c37c2f0f43f 100755 --- a/examples/llama2/README.md +++ b/examples/llama2/README.md @@ -47,6 +47,11 @@ LLAMA2-7B 训练的硬件配置: ```shell git clone https://gitee.com/ascend/ModelLink.git + git clone https://github.com/NVIDIA/Megatron-LM.git + cd Megatron-LM + git checkout -f bcce6f + cp -r megatron ../ModelLink/ + cd .. cd ModelLink mkdir logs mkdir model_from_hf @@ -71,6 +76,7 @@ LLAMA2-7B 训练的硬件配置: # 安装加速库 git clone https://gitee.com/ascend/AscendSpeed.git cd AscendSpeed + git checkout 224ae35e8fc96778f957029d1371ddb623452a50 pip install -r requirements.txt pip3 install -e . cd .. 
@@ -394,6 +400,11 @@ LLaMA2-13B 训练的硬件配置: ```shell git clone https://gitee.com/ascend/ModelLink.git + git clone https://github.com/NVIDIA/Megatron-LM.git + cd Megatron-LM + git checkout -f bcce6f + cp -r megatron ../ModelLink/ + cd .. cd ModelLink mkdir logs mkdir model_from_hf @@ -418,6 +429,7 @@ LLaMA2-13B 训练的硬件配置: # 安装加速库 git clone https://gitee.com/ascend/AscendSpeed.git cd AscendSpeed + git checkout 224ae35e8fc96778f957029d1371ddb623452a50 pip install -r requirements.txt pip3 install -e . cd .. @@ -681,6 +693,11 @@ LLaMA2-34B/70B 训练的硬件配置: ```shell git clone https://gitee.com/ascend/ModelLink.git + git clone https://github.com/NVIDIA/Megatron-LM.git + cd Megatron-LM + git checkout -f bcce6f + cp -r megatron ../ModelLink/ + cd .. cd ModelLink mkdir logs mkdir model_from_hf @@ -705,6 +722,7 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh # 安装加速库 git clone https://gitee.com/ascend/AscendSpeed.git cd AscendSpeed +git checkout 224ae35e8fc96778f957029d1371ddb623452a50 pip install -r requirements.txt pip3 install -e . cd .. diff --git a/examples/llama2/README_en.md b/examples/llama2/README_en.md index 1626bb4ec73df93c360896db1a14fc5e4a416772..e2bdd42597185fa689b97a6f0411cfcbb922fcb8 100644 --- a/examples/llama2/README_en.md +++ b/examples/llama2/README_en.md @@ -46,6 +46,11 @@ Here's a hardware summary of pre-training LLAMA2-7B: ```shell git clone https://gitee.com/ascend/ModelLink.git + git clone https://github.com/NVIDIA/Megatron-LM.git + cd Megatron-LM + git checkout -f bcce6f + cp -r megatron ../ModelLink/ + cd .. cd ModelLink mkdir logs mkdir model_from_hf @@ -70,6 +75,7 @@ Here's a hardware summary of pre-training LLAMA2-7B: # install AscendSpeed git clone https://gitee.com/ascend/AscendSpeed.git cd AscendSpeed + git checkout 224ae35e8fc96778f957029d1371ddb623452a50 pip install -r requirements.txt pip3 install -e . cd .. @@ -413,6 +419,11 @@ Here's a hardware summary of pre-training LLaMA2-13B: ```shell git clone https://gitee.com/ascend/ModelLink.git + git clone https://github.com/NVIDIA/Megatron-LM.git + cd Megatron-LM + git checkout -f bcce6f + cp -r megatron ../ModelLink/ + cd .. cd ModelLink mkdir logs mkdir model_from_hf @@ -437,6 +448,7 @@ Here's a hardware summary of pre-training LLaMA2-13B: # install AscendSpeed git clone https://gitee.com/ascend/AscendSpeed.git cd AscendSpeed + git checkout 224ae35e8fc96778f957029d1371ddb623452a50 pip install -r requirements.txt pip3 install -e . cd .. @@ -696,6 +708,11 @@ Here's a hardware summary of pre-training LLaMA2-34B/70B: ```shell git clone https://gitee.com/ascend/ModelLink.git +git clone https://github.com/NVIDIA/Megatron-LM.git +cd Megatron-LM +git checkout -f bcce6f +cp -r megatron ../ModelLink/ +cd .. cd ModelLink mkdir logs mkdir model_from_hf @@ -721,6 +738,7 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh # install AscendSpeed git clone https://gitee.com/ascend/AscendSpeed.git cd AscendSpeed +git checkout 224ae35e8fc96778f957029d1371ddb623452a50 pip install -r requirements.txt pip3 install -e . cd .. diff --git a/examples/mixtral/README.md b/examples/mixtral/README.md index 84e626163f5ba979cde41c010b773809d6af295b..e7b218aee08cbd16db122f17974807c8e3eb5a10 100644 --- a/examples/mixtral/README.md +++ b/examples/mixtral/README.md @@ -40,6 +40,11 @@ ```shell git clone https://gitee.com/ascend/ModelLink.git +git clone https://github.com/NVIDIA/Megatron-LM.git +cd Megatron-LM +git checkout -f bcce6f +cp -r megatron ../ModelLink/ +cd .. 
cd ModelLink mkdir logs mkdir model_from_hf @@ -65,6 +70,7 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh # 安装加速库 git clone https://gitee.com/ascend/AscendSpeed.git cd AscendSpeed +git checkout 224ae35e8fc96778f957029d1371ddb623452a50 pip install -r requirements.txt pip3 install -e . cd .. diff --git a/examples/mixtral/README_en.md b/examples/mixtral/README_en.md index dc25da17d6787c6aa8682b1603bf145827fe6421..f462c6abcc674d3aac206bf1e8116456979192be 100644 --- a/examples/mixtral/README_en.md +++ b/examples/mixtral/README_en.md @@ -40,6 +40,11 @@ Recommended hardware configuration for inference: ```shell git clone https://gitee.com/ascend/ModelLink.git +git clone https://github.com/NVIDIA/Megatron-LM.git +cd Megatron-LM +git checkout -f bcce6f +cp -r megatron ../ModelLink/ +cd .. cd ModelLink mkdir logs mkdir model_from_hf @@ -65,6 +70,7 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh # install AscendSpeed git clone https://gitee.com/ascend/AscendSpeed.git cd AscendSpeed +git checkout 224ae35e8fc96778f957029d1371ddb623452a50 pip install -r requirements.txt pip3 install -e . cd .. diff --git a/examples/qwen/README.md b/examples/qwen/README.md index 64013256aeacf902163dfcc99260c9c275a016ea..d1004d3916ee8a564d4e84d2826d6b63ca2978b7 100644 --- a/examples/qwen/README.md +++ b/examples/qwen/README.md @@ -47,6 +47,11 @@ Qwen-7B 训练的硬件配置: ```shell git clone https://gitee.com/ascend/ModelLink.git + git clone https://github.com/NVIDIA/Megatron-LM.git + cd Megatron-LM + git checkout -f bcce6f + cp -r megatron ../ModelLink/ + cd .. cd ModelLink mkdir logs mkdir model_from_hf @@ -68,6 +73,7 @@ Qwen-7B 训练的硬件配置: # 安装加速库 git clone https://gitee.com/ascend/AscendSpeed.git cd AscendSpeed + git checkout 224ae35e8fc96778f957029d1371ddb623452a50 pip install -r requirements.txt pip install -e . cd .. @@ -273,6 +279,11 @@ Qwen-14B 训练的硬件配置: ```shell git clone https://gitee.com/ascend/ModelLink.git + git clone https://github.com/NVIDIA/Megatron-LM.git + cd Megatron-LM + git checkout -f bcce6f + cp -r megatron ../ModelLink/ + cd .. cd ModelLink mkdir logs mkdir model_from_hf @@ -294,6 +305,7 @@ Qwen-14B 训练的硬件配置: # 安装加速库 git clone https://gitee.com/ascend/AscendSpeed.git cd AscendSpeed + git checkout 224ae35e8fc96778f957029d1371ddb623452a50 pip install -r requirements.txt pip install -e . cd .. @@ -505,6 +517,11 @@ Qwen-72B 训练的硬件配置: ```shell git clone https://gitee.com/ascend/ModelLink.git + git clone https://github.com/NVIDIA/Megatron-LM.git + cd Megatron-LM + git checkout -f bcce6f + cp -r megatron ../ModelLink/ + cd .. cd ModelLink mkdir logs mkdir model_from_hf @@ -526,6 +543,7 @@ Qwen-72B 训练的硬件配置: # 安装加速库 git clone https://gitee.com/ascend/AscendSpeed.git cd AscendSpeed + git checkout 224ae35e8fc96778f957029d1371ddb623452a50 pip install -r requirements.txt pip install -e . cd .. diff --git a/examples/qwen/README_en.md b/examples/qwen/README_en.md index 6f96b1a0131db4a5d812a39973ecf66f121a3a62..c98c2912b333ed4a417e3f4c6fa167e8d5422242 100644 --- a/examples/qwen/README_en.md +++ b/examples/qwen/README_en.md @@ -46,6 +46,11 @@ Here's a hardware summary of pre-training Qwen-7B: ```shell git clone https://gitee.com/ascend/ModelLink.git + git clone https://github.com/NVIDIA/Megatron-LM.git + cd Megatron-LM + git checkout -f bcce6f + cp -r megatron ../ModelLink/ + cd .. 
cd ModelLink mkdir logs mkdir model_from_hf @@ -67,6 +72,7 @@ Here's a hardware summary of pre-training Qwen-7B: # install AscendSpeed git clone https://gitee.com/ascend/AscendSpeed.git cd AscendSpeed + git checkout 224ae35e8fc96778f957029d1371ddb623452a50 pip install -r requirements.txt pip install -e . cd .. @@ -275,6 +281,11 @@ Here's a hardware summary of pre-training Qwen-14B: ```shell git clone https://gitee.com/ascend/ModelLink.git + git clone https://github.com/NVIDIA/Megatron-LM.git + cd Megatron-LM + git checkout -f bcce6f + cp -r megatron ../ModelLink/ + cd .. cd ModelLink mkdir logs mkdir model_from_hf @@ -296,6 +307,7 @@ Here's a hardware summary of pre-training Qwen-14B: # install AscendSpeed git clone https://gitee.com/ascend/AscendSpeed.git cd AscendSpeed + git checkout 224ae35e8fc96778f957029d1371ddb623452a50 pip install -r requirements.txt pip install -e . cd .. @@ -510,6 +522,11 @@ Here's a hardware summary of pre-training Qwen-72B: ```shell git clone https://gitee.com/ascend/ModelLink.git + git clone https://github.com/NVIDIA/Megatron-LM.git + cd Megatron-LM + git checkout -f bcce6f + cp -r megatron ../ModelLink/ + cd .. cd ModelLink mkdir logs mkdir model_from_hf @@ -531,6 +548,7 @@ Here's a hardware summary of pre-training Qwen-72B: # install AscendSpeed git clone https://gitee.com/ascend/AscendSpeed.git cd AscendSpeed + git checkout 224ae35e8fc96778f957029d1371ddb623452a50 pip install -r requirements.txt pip install -e . cd .. diff --git a/megatron/__init__.py b/megatron/__init__.py deleted file mode 100644 index c35de282a27e754544b1961d42bb5be27b83af71..0000000000000000000000000000000000000000 --- a/megatron/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. - -import torch - -from .global_vars import get_args, get_retro_args -from .global_vars import get_current_global_batch_size -from .global_vars import get_num_microbatches -from .global_vars import get_signal_handler -from .global_vars import update_num_microbatches -from .global_vars import get_tokenizer -from .global_vars import get_tensorboard_writer -from .global_vars import get_wandb_writer -from .global_vars import get_adlr_autoresume -from .global_vars import get_timers -from .initialize import initialize_megatron - -from .utils import (print_rank_0, - is_last_rank, - print_rank_last) diff --git a/megatron/arguments.py b/megatron/arguments.py deleted file mode 100644 index d4f1cd5a324b5d23928a690ea7cf7bda30f4432f..0000000000000000000000000000000000000000 --- a/megatron/arguments.py +++ /dev/null @@ -1,1403 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - -"""Megatron arguments.""" - -import argparse -import dataclasses -import json -import os -import torch -import types - -import torch.nn.functional as F -from megatron.global_vars import set_retro_args, get_retro_args -from tools.retro.utils import get_args_path as get_retro_args_path - -from megatron.core.models.retro import RetroConfig -from megatron.core.transformer import TransformerConfig - - -def parse_args(extra_args_provider=None, ignore_unknown_args=False): - """Parse all arguments.""" - parser = argparse.ArgumentParser(description='Megatron-LM Arguments', - allow_abbrev=False) - - # Standard arguments. 
- parser = _add_network_size_args(parser) - parser = _add_regularization_args(parser) - parser = _add_training_args(parser) - parser = _add_initialization_args(parser) - parser = _add_learning_rate_args(parser) - parser = _add_checkpointing_args(parser) - parser = _add_mixed_precision_args(parser) - parser = _add_distributed_args(parser) - parser = _add_validation_args(parser) - parser = _add_data_args(parser) - parser = _add_autoresume_args(parser) - parser = _add_biencoder_args(parser) - parser = _add_vision_args(parser) - parser = _add_logging_args(parser) - parser = _add_inference_args(parser) - parser = _add_transformer_engine_args(parser) - parser = _add_retro_args(parser) - parser = _add_experimental_args(parser) - - # Custom arguments. - if extra_args_provider is not None: - parser = extra_args_provider(parser) - - # Parse. - if ignore_unknown_args: - args, _ = parser.parse_known_args() - else: - args = parser.parse_args() - - # Args from environment - args.rank = int(os.getenv('RANK', '0')) - args.world_size = int(os.getenv("WORLD_SIZE", '1')) - - return args - -def validate_args(args, defaults={}): - # Tensor model parallel size. - args.tensor_model_parallel_size = min( - args.tensor_model_parallel_size, args.world_size) - assert args.world_size % args.tensor_model_parallel_size == 0, 'world size'\ - ' ({}) is not divisible by tensor model parallel size ({})'.format( - args.world_size, args.tensor_model_parallel_size) - # Pipeline model parallel size. - args.pipeline_model_parallel_size = min( - args.pipeline_model_parallel_size, - (args.world_size // args.tensor_model_parallel_size)) - args.transformer_pipeline_model_parallel_size = ( - args.pipeline_model_parallel_size - 1 - if args.standalone_embedding_stage else - args.pipeline_model_parallel_size - ) - # Checks. 
- model_parallel_size = args.pipeline_model_parallel_size * \ - args.tensor_model_parallel_size - assert args.world_size % (model_parallel_size * args.context_parallel_size) == 0, \ - 'world size ({}) is not divisible by tensor parallel size ({}) times ' \ - 'pipeline parallel size ({}) times context parallel size ({})'.format( - args.world_size, args.tensor_model_parallel_size, - args.pipeline_model_parallel_size, args.context_parallel_size) - args.data_parallel_size = args.world_size // (model_parallel_size * args.context_parallel_size) - if args.rank == 0: - print('using world size: {}, data-parallel size: {}, ' - 'context-parallel size: {} ' - 'tensor-model-parallel size: {}, ' - 'pipeline-model-parallel size: {} '.format( - args.world_size, args.data_parallel_size, - args.context_parallel_size, - args.tensor_model_parallel_size, - args.pipeline_model_parallel_size), flush=True) - if args.pipeline_model_parallel_size > 1: - if args.pipeline_model_parallel_split_rank is not None: - assert args.pipeline_model_parallel_split_rank < \ - args.pipeline_model_parallel_size, 'split rank needs'\ - ' to be less than pipeline model parallel size ({})'.format( - args.pipeline_model_parallel_size) - - if args.tp_comm_overlap: - assert args.sequence_parallel == True, 'Tensor parallel communication/GEMM overlap can happen only when sequence parallelism is enabled' - - - # Deprecated arguments - assert args.batch_size is None, '--batch-size argument is no longer ' \ - 'valid, use --micro-batch-size instead' - del args.batch_size - assert args.warmup is None, '--warmup argument is no longer valid, use ' \ - '--lr-warmup-fraction instead' - del args.warmup - assert args.model_parallel_size is None, '--model-parallel-size is no ' \ - 'longer valid, use --tensor-model-parallel-size instead' - del args.model_parallel_size - - if args.checkpoint_activations: - if args.rank == 0: - print('--checkpoint-activations is no longer valid, use --recompute-activations, ' - 'or, for more control, --recompute-granularity and --recompute-method.') - exit() - del args.checkpoint_activations - - if args.recompute_activations: - args.recompute_granularity = 'selective' - del args.recompute_activations - - # Set input defaults. - for key in defaults: - # For default to be valid, it should not be provided in the - # arguments that are passed to the program. We check this by - # ensuring the arg is set to None. - if getattr(args, key, None) is not None: - if args.rank == 0: - print('WARNING: overriding default arguments for {key}:{v} \ - with {key}:{v2}'.format(key=key, v=defaults[key], - v2=getattr(args, key)), - flush=True) - else: - setattr(args, key, defaults[key]) - - # Batch size. 
- assert args.micro_batch_size is not None - assert args.micro_batch_size > 0 - if args.global_batch_size is None: - args.global_batch_size = args.micro_batch_size * args.data_parallel_size - if args.rank == 0: - print('setting global batch size to {}'.format( - args.global_batch_size), flush=True) - assert args.global_batch_size > 0 - if args.num_layers_per_virtual_pipeline_stage is not None: - assert args.pipeline_model_parallel_size > 2, \ - 'pipeline-model-parallel size should be greater than 2 with ' \ - 'interleaved schedule' - assert args.num_layers % args.transformer_pipeline_model_parallel_size == 0, \ - 'number of layers should be divisible by the pipeline parallel size' - num_layers_per_pipeline_stage = args.num_layers // args.transformer_pipeline_model_parallel_size - assert num_layers_per_pipeline_stage % args.num_layers_per_virtual_pipeline_stage == 0, \ - 'number of layers per pipeline stage must be divisible number of layers per virtual pipeline stage' - args.virtual_pipeline_model_parallel_size = num_layers_per_pipeline_stage // \ - args.num_layers_per_virtual_pipeline_stage - else: - args.virtual_pipeline_model_parallel_size = None - # Overlap P2P communication is disabled if not using the interleaved schedule. - args.overlap_p2p_comm = False - if args.rank == 0: - print('WARNING: Setting args.overlap_p2p_comm to False since non-interleaved ' - 'schedule does not support overlapping p2p communication') - - if args.overlap_param_gather: - assert args.use_distributed_optimizer, \ - '--overlap-param-gather only supported with distributed optimizer' - - # Parameters dtype. - args.params_dtype = torch.float - if args.fp16: - assert not args.bf16 - args.params_dtype = torch.half - if args.bf16: - assert not args.fp16 - args.params_dtype = torch.bfloat16 - # bfloat16 requires gradient accumulation and all-reduce to - # be done in fp32. - if not args.accumulate_allreduce_grads_in_fp32: - args.accumulate_allreduce_grads_in_fp32 = True - if args.rank == 0: - print('accumulate and all-reduce gradients in fp32 for ' - 'bfloat16 data type.', flush=True) - - if args.rank == 0: - print('using {} for parameters ...'.format(args.params_dtype), - flush=True) - - if args.dataloader_type is None: - args.dataloader_type = 'single' - - # Consumed tokens. - args.consumed_train_samples = 0 - args.consumed_valid_samples = 0 - - # Support for variable sequence lengths across batches/microbatches. - # set it if the dataloader supports generation of variable sequence lengths - # across batches/microbatches. Due to additional communication overhead - # during pipeline parallelism, it should not be set if sequence length - # is constant during training. - args.variable_seq_lengths = False - - # Iteration-based training. - if args.train_iters: - # If we use iteration-based training, make sure the - # sample-based options are off. - assert args.train_samples is None, \ - 'expected iteration-based training' - assert args.lr_decay_samples is None, \ - 'expected iteration-based learning rate decay' - assert args.lr_warmup_samples == 0, \ - 'expected iteration-based learning rate warmup' - assert args.rampup_batch_size is None, \ - 'expected no batch-size rampup for iteration-based training' - if args.lr_warmup_fraction is not None: - assert args.lr_warmup_iters == 0, \ - 'can only specify one of lr-warmup-fraction and lr-warmup-iters' - - # Sample-based training. - if args.train_samples: - # If we use sample-based training, make sure the - # iteration-based options are off. 
- assert args.train_iters is None, \ - 'expected sample-based training' - assert args.lr_decay_iters is None, \ - 'expected sample-based learning rate decay' - assert args.lr_warmup_iters == 0, \ - 'expected sample-based learnig rate warmup' - if args.lr_warmup_fraction is not None: - assert args.lr_warmup_samples == 0, \ - 'can only specify one of lr-warmup-fraction ' \ - 'and lr-warmup-samples' - - if args.num_layers is not None: - assert args.encoder_num_layers is None, \ - 'cannot have both num-layers and encoder-num-layers specified' - args.encoder_num_layers = args.num_layers - else: - assert args.encoder_num_layers is not None, \ - 'either num-layers or encoder-num-layers should be specified' - args.num_layers = args.encoder_num_layers - - # Check required arguments. - required_args = ['num_layers', 'hidden_size', 'num_attention_heads', - 'max_position_embeddings'] - for req_arg in required_args: - _check_arg_is_not_none(args, req_arg) - - # Checks. - if args.ffn_hidden_size is None: - if args.swiglu: - # reduce the dimnesion for MLP since projections happens on - # two linear layers. this keeps the number of paramters in - # the same ballpark as the counterpart with 4*h size - # we keep it a multiple of 64, which means the actual tensor size - # will be a multiple of 64 / tp_size - args.ffn_hidden_size = int((4 * args.hidden_size * 2 / 3) / 64) * 64 - else: - args.ffn_hidden_size = 4 * args.hidden_size - - if args.kv_channels is None: - assert args.hidden_size % args.num_attention_heads == 0 - args.kv_channels = args.hidden_size // args.num_attention_heads - - if args.seq_length is not None: - assert args.encoder_seq_length is None - args.encoder_seq_length = args.seq_length - else: - assert args.encoder_seq_length is not None - args.seq_length = args.encoder_seq_length - - if args.seq_length is not None: - assert args.max_position_embeddings >= args.seq_length - if args.decoder_seq_length is not None: - assert args.max_position_embeddings >= args.decoder_seq_length - if args.lr is not None: - assert args.min_lr <= args.lr - if args.save is not None: - assert args.save_interval is not None - # Mixed precision checks. - if args.fp16_lm_cross_entropy: - assert args.fp16, 'lm cross entropy in fp16 only support in fp16 mode.' - if args.fp32_residual_connection: - assert args.fp16 or args.bf16, \ - 'residual connection in fp32 only supported when using fp16 or bf16.' - - if args.weight_decay_incr_style == 'constant': - assert args.start_weight_decay is None - assert args.end_weight_decay is None - args.start_weight_decay = args.weight_decay - args.end_weight_decay = args.weight_decay - else: - assert args.start_weight_decay is not None - assert args.end_weight_decay is not None - - TORCH_MAJOR = int(torch.__version__.split('.')[0]) - TORCH_MINOR = int(torch.__version__.split('.')[1]) - # Persistent fused layer norm. - if TORCH_MAJOR < 1 or (TORCH_MAJOR == 1 and TORCH_MINOR < 11): - args.no_persist_layer_norm = True - if args.rank == 0: - print('Persistent fused layer norm kernel is supported from ' - 'pytorch v1.11 (nvidia pytorch container paired with v1.11). ' - 'Defaulting to no_persist_layer_norm=True') - - # Activation recomputing. 
- if args.distribute_saved_activations: - assert args.tensor_model_parallel_size > 1, 'can distribute ' \ - 'recomputed activations only across tensor model ' \ - 'parallel groups' - assert args.recompute_granularity == 'full', \ - 'distributed recompute activations is only '\ - 'application to full recompute granularity' - assert args.recompute_method is not None, \ - 'for distributed recompute activations to work you '\ - 'need to use a recompute method ' - assert (TORCH_MAJOR, TORCH_MINOR) >= (1, 10), \ - 'distributed recompute activations are supported for pytorch ' \ - 'v1.10 and above (Nvidia Pytorch container >= 21.07). Current ' \ - 'pytorch version is v%s.%s.' % (TORCH_MAJOR, TORCH_MINOR) - - if args.recompute_granularity == 'selective': - assert args.recompute_method is None, \ - 'recompute method is not yet supported for ' \ - 'selective recomputing granularity' - - # disable sequence parallelism when tp=1 - # to avoid change in numerics when - # sequence_parallelism is enabled. - if args.tensor_model_parallel_size == 1: - args.sequence_parallel = False - - # disable async_tensor_model_parallel_allreduce when - # model parallel memory optimization is enabled - if args.sequence_parallel: - args.async_tensor_model_parallel_allreduce = False - - if os.environ.get('CUDA_DEVICE_MAX_CONNECTIONS') != "1": - if args.sequence_parallel: - raise RuntimeError( - "Using sequence parallelism requires setting the environment variable " - "CUDA_DEVICE_MAX_CONNECTIONS to 1") - if args.async_tensor_model_parallel_allreduce: - raise RuntimeError( - "Using async gradient all reduce requires setting the environment " - "variable CUDA_DEVICE_MAX_CONNECTIONS to 1") - - # Disable bias gelu fusion if we are disabling bias altogether - if not args.add_bias_linear: - args.bias_gelu_fusion = False - - # Retro checks. - if args.retro_add_retriever: - - # Sequence parallelism unsupported. - assert not args.sequence_parallel, \ - "retro currently does not support sequence parallelism." - - # Pipeline parallelism unsupported. - assert args.pipeline_model_parallel_size == 1, \ - "retro currently does not support pipeline parallelism." - - # Load retro args. - retro_args_path = get_retro_args_path(args.retro_workdir) - assert os.path.exists(retro_args_path), "retro workdir missing args.json" - with open(retro_args_path) as f: - retro_args = types.SimpleNamespace(**json.load(f)) - retro_args.retro_return_doc_ids = args.retro_return_doc_ids - retro_args.retro_gpt_retrieved_length = \ - args.retro_num_retrieved_chunks * \ - retro_args.retro_gpt_chunk_length - set_retro_args(retro_args) - - # Legacy RoPE arguments - if args.use_rotary_position_embeddings: - args.position_embedding_type = 'rope' - - # Would just need to add 'NoPE' as a position_embedding_type to support this, but for now - # don't allow it to keep things simple - if not args.add_position_embedding and args.position_embedding_type != 'rope': - raise RuntimeError('--no-position-embedding is deprecated, use --position-embedding-type') - - # MoE Spec check - if args.num_experts is not None: - assert args.spec is None, "Model Spec must be None when using MoEs" - - # Expert parallelism check - if args.expert_model_parallel_size > 1: - assert args.num_experts is not None, "num_experts must be non None to use expert model parallelism" - assert args.num_experts % args.expert_model_parallel_size == 0, \ - "Number of experts should be a multiple of expert model parallel_size." 
- assert not args.use_distributed_optimizer, \ - "Expert parallelism is not suppored with distributed optimizer." - assert not args.fp16, \ - "Expert parallelism is not supported with fp16 training." - if args.tensor_model_parallel_size > 1: - assert args.sequence_parallel, \ - "When using expert parallelism and tensor parallelism, sequence parallelism must be used." - - # Print arguments. - _print_args("arguments", args) - retro_args = get_retro_args() - if retro_args and args != retro_args: - _print_args("retro arguments", types.SimpleNamespace(**{k:v for k,v in vars(retro_args).items() if k.startswith("retro")}, rank=args.rank)) - - return args - - -def _print_args(title, args): - """Print arguments.""" - if args.rank == 0: - print(f'------------------------ {title} ------------------------', - flush=True) - str_list = [] - for arg in vars(args): - dots = '.' * (48 - len(arg)) - str_list.append(' {} {} {}'.format(arg, dots, getattr(args, arg))) - for arg in sorted(str_list, key=lambda x: x.lower()): - print(arg, flush=True) - print(f'-------------------- end of {title} ---------------------', - flush=True) - - -def _check_arg_is_not_none(args, arg): - assert getattr(args, arg) is not None, '{} argument is None'.format(arg) - -def core_transformer_config_from_args(args): - - # Translate args to core transformer configuration - kw_args = {} - for f in dataclasses.fields(TransformerConfig): - if hasattr(args, f.name): - kw_args[f.name] = getattr(args, f.name) - kw_args['persist_layer_norm'] = not args.no_persist_layer_norm - kw_args['layernorm_zero_centered_gamma'] = args.apply_layernorm_1p - kw_args['layernorm_epsilon'] = args.norm_epsilon - kw_args['deallocate_pipeline_outputs'] = True - kw_args['pipeline_dtype'] = args.params_dtype - kw_args['batch_p2p_comm'] = not args.overlap_p2p_comm - kw_args['num_moe_experts'] = args.num_experts - if args.swiglu: - kw_args['activation_func'] = F.silu - kw_args['gated_linear_unit'] = True - kw_args['bias_gelu_fusion'] = False - if args.squared_relu: - assert not args.swiglu - def squared_relu(x): - return torch.pow(F.relu(x), 2) - kw_args['activation_func'] = squared_relu - if args.init_method_xavier_uniform: - kw_args['init_method'] = torch.nn.init.xavier_uniform_ - kw_args['scaled_init_method'] = torch.nn.init.xavier_uniform_ - if args.group_query_attention: - kw_args['num_query_groups'] = args.num_query_groups - else: - kw_args['num_query_groups'] = None - - # If using Retro, return Retro config. - retro_args = get_retro_args() - if retro_args: - kw_args['retro_preprocess'] = retro_args - return RetroConfig(**kw_args) - - # Return Transformer config. 
- return TransformerConfig(**kw_args) - - -def _add_transformer_engine_args(parser): - group = parser.add_argument_group(title='Transformer-Engine') - - group.add_argument('--fp8-format', default=None, - choices=['e4m3', 'hybrid'], - help='Which fp8 format scheme to use for FP8 tensors in the forward and backward pass', - dest='fp8') - group.add_argument('--fp8-margin', type=int, default=0, - help='Scaling margin for fp8', - dest='fp8_margin') - group.add_argument('--fp8-interval', type=int, default=1, - help='Scaling update interval for fp8', - dest='fp8_interval') - group.add_argument('--fp8-amax-history-len', type=int, default=1, - help='Number of steps for which amax history is recorded per tensor', - dest='fp8_amax_history_len') - group.add_argument('--fp8-amax-compute-algo', default='most_recent', - choices=['most_recent', 'max'], - help='Algorithm for computing amax from history', - dest='fp8_amax_compute_algo') - group.add_argument('--no-fp8-wgrad', action='store_false', - help='Execute wgrad in higher precision even for FP8 runs', - dest='fp8_wgrad') - group.add_argument('--transformer-impl', default='local', - choices=['local', 'transformer_engine'], - help='Which Transformer implementation to use.') - - return parser - -def _add_inference_args(parser): - group = parser.add_argument_group(title='inference') - - group.add_argument('--inference-batch-times-seqlen-threshold', - type=int, default=512, - help='During inference, if batch-size times ' - 'sequence-length is smaller than this threshold ' - 'then we will not use pipelining, otherwise we will.') - group.add_argument('--max-tokens-to-oom', - type=int, default=12000, - help='Maximum number of tokens during inference' - 'tokens here is # in prompt + # to generate' - 'Allows us to throw an error before OOM crashes server') - group.add_argument('--output-bert-embeddings', action='store_true', - help='Output Bert embeddings (via mean pooling) from ' - 'model, rather than its binary head output or entire ' - 'hidden batch.') - group.add_argument('--bert-embedder-type', default="megatron", - choices=["megatron", "huggingface"], - help='Select either Megatron or Huggingface as the ' - 'Bert embedder.') - - return parser - - -def _add_retro_args(parser): - group = parser.add_argument_group(title='retro') - - group.add_argument('--retro-workdir', default=None, - help='Retro working directory, which contains the ' - 'preprocessed data for for pretraining. 
This directory ' - 'is built during preprocessing (see ' - 'tools/retro/README.md), and contains subdirectories ' - 'for the chunk database and pretraining neighbors.') - group.add_argument('--retro-add-retriever', - action='store_true', default=False, - help='Add a retriever to the transformer, for use in ' - 'pretraining a Retro model.') - group.add_argument('--retro-cyclic-train-iters', type=int, default=None, - help='Set number of training iterations for cyclic ' - 'Retro training.') - group.add_argument('--retro-encoder-layers', type=int, default=2, - help='Number of layers to use for the retrieval ' - 'encoder.') - group.add_argument('--retro-encoder-hidden-dropout', - type=float, default=0.1, help='Hidden dropout for ' - 'retrieval encoder.') - group.add_argument('--retro-encoder-attention-dropout', - type=float, default=0.1, help='Attention dropout for ' - 'retrieval encoder.') - group.add_argument("--retro-num-neighbors", type=int, default=2, - help='Number of neighbors to retrieve during ' - 'pretraining.') - group.add_argument("--retro-num-retrieved-chunks", type=int, default=2, - help='Number of chunks to retrieve from the retrieval ' - 'database.') - group.add_argument("--retro-return-doc-ids", action="store_true", - help="Turn this on when preprocessing retro data.") - group.add_argument("--retro-no-verify-neighbor-count", action="store_false", - dest="retro_verify_neighbor_count", - help="Skip verifying that len(GPT dataset) == len(saved " - "neighbors).") - - # Enforce argument naming convention. - for action in group._group_actions: - prefix = action.dest.split("_")[0] - assert prefix == "retro", \ - "Retro args must be prefixed with '--retro-*', for consistent " \ - "styling. Please fix '%s'." % ", ".join(action.option_strings) - - return parser - - -def _add_network_size_args(parser): - group = parser.add_argument_group(title='network size') - - group.add_argument('--num-layers', type=int, default=None, - help='Number of transformer layers.') - group.add_argument('--encoder-num-layers', type=int, default=None, - help='Number of encoder transformer layers.') - group.add_argument('--decoder-num-layers', type=int, default=None, - help='Number of decoder transformer layers.') - group.add_argument('--hidden-size', type=int, default=None, - help='Tansformer hidden size.') - group.add_argument('--ffn-hidden-size', type=int, default=None, - help='Transformer Feed-Forward Network hidden size. ' - 'This is set to 4*hidden-size if not provided') - group.add_argument('--num-attention-heads', type=int, default=None, - help='Number of transformer attention heads.') - group.add_argument('--kv-channels', type=int, default=None, - help='Projection weights dimension in multi-head ' - 'attention. This is set to ' - ' args.hidden_size // args.num_attention_heads ' - 'if not provided.') - group.add_argument('--group-query-attention', action='store_true', - help='Use group-query attention.') - group.add_argument('--num-query-groups', type=int, default=1) - - group.add_argument('--max-position-embeddings', type=int, default=None, - help='Maximum number of position embeddings to use. ' - 'This is the size of position embedding.') - group.add_argument('--position-embedding-type', type=str, default='learned_absolute', - choices=['learned_absolute', 'rope'], - help='Position embedding type.') - group.add_argument('--use-rotary-position-embeddings', action='store_true', - help='Use rotary positional embeddings or not. 
' - 'Deprecated: use --position-embedding-type') - group.add_argument('--rotary-percent', type=float, default=1.0, - help='Percent of rotary dimension to use, default 100%%') - group.add_argument('--rotary-seq-len-interpolation-factor', type=int, default=None, - help='Sequence length interpolation factor for rotary embeddings.') - group.add_argument('--no-position-embedding', - action='store_false', - help='Disable position embedding. Deprecated: use --position-embedding-type', - dest='add_position_embedding') - group.add_argument('--make-vocab-size-divisible-by', type=int, default=128, - help='Pad the vocab size to be divisible by this value.' - 'This is added for computational efficieny reasons.') - group.add_argument('--normalization', default='LayerNorm', - choices=['LayerNorm', 'RMSNorm'], - help='Which normalization technique to use.') - group.add_argument('--norm-epsilon', type=float, default=1e-5, - help='Epsilon for layer norm and RMS norm.') - group.add_argument('--apply-layernorm-1p', action='store_true', - help='Adjust LayerNorm weights such that they are centered ' - 'around zero. This improves numerical stability.') - group.add_argument('--apply-residual-connection-post-layernorm', - action='store_true', - help='If set, use original BERT residula connection ' - 'ordering.') - group.add_argument('--openai-gelu', action='store_true', - help='Use OpenAIs GeLU implementation. This option' - 'should not be used unless for backward compatibility' - 'reasons.') - group.add_argument('--squared-relu', action='store_true', - help='Use squared relu activation instead of default gelu') - group.add_argument('--swiglu', action='store_true', - help='Use gated linear units and SiLU activation instead of default gelu') - group.add_argument('--onnx-safe', type=bool, required=False, - help='Use workarounds for known problems with ' - 'Torch ONNX exporter') - group.add_argument('--bert-no-binary-head', action='store_false', - help='Disable BERT binary head.', - dest='bert_binary_head') - group.add_argument('--num-experts', type=int, default=None, - help='Number of Experts in Switch Transformer (None means no Switch)') - group.add_argument('--untie-embeddings-and-output-weights', action='store_true', - help='Untie embeddings and output weights.'), - return parser - - -def _add_logging_args(parser): - group = parser.add_argument_group(title='logging') - - group.add_argument('--log-params-norm', action='store_true', - help='If set, calculate and log parameters norm.') - group.add_argument('--log-num-zeros-in-grad', action='store_true', - help='If set, calculate and log the number of zeros in gradient.') - group.add_argument('--log-throughput', action='store_true', - help='If set, calculate and log throughput per GPU.') - group.add_argument('--timing-log-level', type=int, - default=0, choices=range(0,3), - help='Granularity level to measure and report timing. ' - ' 0: report only iteration time and make sure timing ' - ' does not introduce extra overhead.' - ' 1: report timing for operations that are executed ' - ' very limited times (basically once) during ' - ' each iteration (such as gradient all-reduce) ' - ' 2: report timing for operations that migh be ' - ' executed numerous times during each iteration. ' - 'Note that setting the level to 1 or 2 might ' - 'cause increase in iteration time.') - group.add_argument('--no-barrier-with-level-1-timing', action='store_false', - help='If not set, use barrier with level 1 time ' - 'measurements. 
Note that this is up to the user ' - 'to make sure calling barrier with their timers ' - 'will not result in hangs. This can happen if for ' - 'example the user adds a level 1 timer that is not ' - 'called by all ranks.', - dest='barrier_with_L1_time') - group.add_argument('--timing-log-option', type=str, default='minmax', - choices=['max', 'minmax', 'all'], - help='Options for logging timing:' - ' max: report the max timing across all ranks' - ' minmax: report min and max timings across all ranks' - ' all: report timings of all ranks.') - group.add_argument('--tensorboard-log-interval', type=int, default=1, - help='Report to tensorboard interval.') - group.add_argument('--tensorboard-queue-size', type=int, default=1000, - help='Size of the tensorboard queue for pending events ' - 'and summaries before one of the ‘add’ calls forces a ' - 'flush to disk.') - group.add_argument('--log-timers-to-tensorboard', action='store_true', - help='If set, write timers to tensorboard.') - group.add_argument('--log-batch-size-to-tensorboard', action='store_true', - help='If set, write batch-size to tensorboard.') - group.add_argument('--no-log-learnig-rate-to-tensorboard', - action='store_false', - help='Disable learning rate logging to tensorboard.', - dest='log_learning_rate_to_tensorboard') - group.add_argument('--no-log-loss-scale-to-tensorboard', - action='store_false', - help='Disable loss-scale logging to tensorboard.', - dest='log_loss_scale_to_tensorboard') - group.add_argument('--log-validation-ppl-to-tensorboard', - action='store_true', - help='If set, write validation perplexity to ' - 'tensorboard.') - group.add_argument('--log-memory-to-tensorboard', - action='store_true', - help='Enable memory logging to tensorboard.') - group.add_argument('--log-world-size-to-tensorboard', - action='store_true', - help='Enable world size logging to tensorboard.') - group.add_argument('--wandb-project', type=str, default='', - help='The wandb project name. 
Ignore wandb by default.') - group.add_argument('--wandb-exp-name', type=str, default='', - help='The wandb experiment name.') - group.add_argument('--wandb-save-dir', type=str, default='', - help='Path to save the wandb results locally.') - return parser - - -def _add_regularization_args(parser): - group = parser.add_argument_group(title='regularization') - - group.add_argument('--attention-dropout', type=float, default=0.1, - help='Post attention dropout probability.') - group.add_argument('--hidden-dropout', type=float, default=0.1, - help='Dropout probability for hidden state transformer.') - group.add_argument('--weight-decay', type=float, default=0.01, - help='Weight decay coefficient for L2 regularization.') - group.add_argument('--start-weight-decay', type=float, - help='Initial weight decay coefficient for L2 regularization.') - group.add_argument('--end-weight-decay', type=float, - help='End of run weight decay coefficient for L2 regularization.') - group.add_argument('--weight-decay-incr-style', type=str, default='constant', - choices=['constant', 'linear', 'cosine'], - help='Weight decay increment function.') - group.add_argument('--clip-grad', type=float, default=1.0, - help='Gradient clipping based on global L2 norm.') - group.add_argument('--adam-beta1', type=float, default=0.9, - help='First coefficient for computing running averages ' - 'of gradient and its square') - group.add_argument('--adam-beta2', type=float, default=0.999, - help='Second coefficient for computing running averages ' - 'of gradient and its square') - group.add_argument('--adam-eps', type=float, default=1e-08, - help='Term added to the denominator to improve' - 'numerical stability') - group.add_argument('--sgd-momentum', type=float, default=0.9, - help='Momentum factor for sgd') - return parser - - -def _add_training_args(parser): - group = parser.add_argument_group(title='training') - - group.add_argument('--micro-batch-size', type=int, default=None, - help='Batch size per model instance (local batch size). ' - 'Global batch size is local batch size times data ' - 'parallel size times number of micro batches.') - group.add_argument('--batch-size', type=int, default=None, - help='Old batch size parameter, do not use. ' - 'Use --micro-batch-size instead') - group.add_argument('--global-batch-size', type=int, default=None, - help='Training batch size. If set, it should be a ' - 'multiple of micro-batch-size times data-parallel-size. ' - 'If this value is None, then ' - 'use micro-batch-size * data-parallel-size as the ' - 'global batch size. This choice will result in 1 for ' - 'number of micro-batches.') - group.add_argument('--rampup-batch-size', nargs='*', default=None, - help='Batch size ramp up with the following values:' - ' --rampup-batch-size ' - ' ' - ' ' - 'For example:' - ' --rampup-batch-size 16 8 300000 \ ' - ' --global-batch-size 1024' - 'will start with global batch size 16 and over ' - ' (1024 - 16) / 8 = 126 intervals will increase' - 'the batch size linearly to 1024. In each interval' - 'we will use approximately 300000 / 126 = 2380 samples.') - group.add_argument('--recompute-activations', action='store_true', - help='recompute activation to allow for training ' - 'with larger models, sequences, and batch sizes.') - group.add_argument('--recompute-granularity', type=str, default=None, - choices=['full', 'selective'], - help='Checkpoint activations to allow for training ' - 'with larger models, sequences, and batch sizes. 
' - 'It is supported at two granularities 1) full: ' - 'whole transformer layer is recomputed, ' - '2) selective: core attention part of the transformer ' - 'layer is recomputed.') - group.add_argument('--no-check-for-nan-in-loss-and-grad', action='store_false', - help='Check for NaNs in loss and grad', - dest='check_for_nan_in_loss_and_grad') - group.add_argument('--distribute-saved-activations', - action='store_true', - help='If set, distribute recomputed activations ' - 'across model parallel group.') - group.add_argument('--recompute-method', type=str, default=None, - choices=['uniform', 'block'], - help='1) uniform: uniformly divide the total number of ' - 'Transformer layers and recompute the input activation of ' - 'each divided chunk at specified granularity, ' - '2) recompute the input activations of only a set number of ' - 'individual Transformer layers per pipeline stage and do the ' - 'rest without any recomputing at specified granularity' - 'default) do not apply activations recompute to any layers') - group.add_argument('--recompute-num-layers', type=int, default=None, - help='1) uniform: the number of Transformer layers in each ' - 'uniformly divided recompute unit, ' - '2) block: the number of individual Transformer layers ' - 'to recompute within each pipeline stage.') - group.add_argument('--no-clone-scatter-output-in-embedding', action='store_false', - help='If not set, clone the output of the scatter in embedding layer to GC original tensor.', - dest='clone_scatter_output_in_embedding') - group.add_argument('--profile', action='store_true', - help='Enable nsys profiling. When using this option, nsys ' - 'options should be specified in commandline. An example ' - 'nsys commandline is `nsys profile -s none -t nvtx,cuda ' - '-o --force-overwrite true ' - '--capture-range=cudaProfilerApi ' - '--capture-range-end=stop`.') - group.add_argument('--profile-step-start', type=int, default=10, - help='Global step to start profiling.') - group.add_argument('--profile-step-end', type=int, default=12, - help='Global step to stop profiling.') - group.add_argument('--profile-ranks', nargs='+', type=int, default=[0], - help='Global ranks to profile.') - group.add_argument('--tp-comm-overlap', action='store_true', help = 'Enables the ' - ' overlap of Tensor parallel communication and GEMM kernels.') - group.add_argument('--tp-comm-overlap-cfg', type=str, default=None, - help = 'Config file when tp_comm_overlap is enabled.') - group.add_argument('--disable-tp-comm-split-ag', action='store_false', - help = 'Disables the All-Gather overlap with fprop GEMM.', - dest='tp_comm_split_ag') - group.add_argument('--disable-tp-comm-split-rs', action='store_false', - help = 'Disables the Reduce-Scatter overlap with fprop GEMM.', - dest='tp_comm_split_rs') - group.add_argument('--disable-tp-comm-bulk-dgrad', action='store_false', - help = 'Disables the All-Gather overlap with bprop activation gradient GEMM.', - dest='tp_comm_bulk_dgrad') - group.add_argument('--disable-tp-comm-bulk-wgrad', action='store_false', - help = 'Disables the Reduce-Scatter overlap with bprop weight gradient GEMM.', - dest='tp_comm_bulk_wgrad') - - - # deprecated - group.add_argument('--checkpoint-activations', action='store_true', - help='Checkpoint activation to allow for training ' - 'with larger models, sequences, and batch sizes.') - group.add_argument('--train-iters', type=int, default=None, - help='Total number of iterations to train over all ' - 'training runs. 
Note that either train-iters or ' - 'train-samples should be provided.') - group.add_argument('--train-samples', type=int, default=None, - help='Total number of samples to train over all ' - 'training runs. Note that either train-iters or ' - 'train-samples should be provided.') - group.add_argument('--log-interval', type=int, default=100, - help='Report loss and timing interval.') - group.add_argument('--exit-interval', type=int, default=None, - help='Exit the program after the iteration is divisible ' - 'by this value.') - group.add_argument('--exit-duration-in-mins', type=int, default=None, - help='Exit the program after this many minutes.') - group.add_argument('--exit-signal-handler', action='store_true', - help='Dynamically save the checkpoint and shutdown the ' - 'training if SIGTERM is received') - group.add_argument('--tensorboard-dir', type=str, default=None, - help='Write TensorBoard logs to this directory.') - group.add_argument('--no-masked-softmax-fusion', - action='store_false', - help='Disable fusion of query_key_value scaling, ' - 'masking, and softmax.', - dest='masked_softmax_fusion') - group.add_argument('--no-bias-gelu-fusion', action='store_false', - help='Disable bias and gelu fusion.', - dest='bias_gelu_fusion') - group.add_argument('--no-bias-dropout-fusion', action='store_false', - help='Disable bias and dropout fusion.', - dest='bias_dropout_fusion') - group.add_argument('--use-flash-attn', action='store_true', - help='use FlashAttention implementation of attention. ' - 'https://arxiv.org/abs/2205.14135') - group.add_argument('--disable-bias-linear', action='store_false', - help='Disable bias in the linear layers', - dest='add_bias_linear') - group.add_argument('--optimizer', type=str, default='adam', - choices=['adam', 'sgd'], - help='Optimizer function') - group.add_argument('--dataloader-type', type=str, default=None, - choices=['single', 'cyclic'], - help='Single pass vs multiple pass data loader') - group.add_argument('--no-async-tensor-model-parallel-allreduce', - action='store_false', - help='Disable asynchronous execution of ' - 'tensor-model-parallel all-reduce with weight ' - 'gradient compuation of a column-linear layer.', - dest='async_tensor_model_parallel_allreduce') - group.add_argument('--no-persist-layer-norm', action='store_true', - help='Disable using persistent fused layer norm kernel. ' - 'This kernel supports only a set of hidden sizes. Please ' - 'check persist_ln_hidden_sizes if your hidden ' - 'size is supported.') - group.add_argument('--sequence-parallel', action='store_true', - help='Enable sequence parallel optimization.') - group.add_argument('--no-gradient-accumulation-fusion', - action='store_false', - help='Disable fusing gradient accumulation to weight ' - 'gradient computation of linear layers', - dest='gradient_accumulation_fusion') - group.add_argument('--use-mcore-models', action='store_true', - help='Use the implementation from megatron core') - group.add_argument('--manual-gc', action='store_true', - help='Disable the threshold-based default garbage ' - 'collector and trigger the garbage collection manually. ' - 'Manual garbage collection helps to align the timing of ' - 'the collection across ranks which mitigates the impact ' - 'of CPU-associated jitters. 
When the manual gc is enabled, ' - 'garbage collection is performed only at the start and the ' - 'end of the validation routine by default.') - group.add_argument('--manual-gc-interval', type=int, default=0, - help='Training step interval to trigger manual garbage ' - 'collection. When the value is set to 0, garbage ' - 'collection is not triggered between training steps.') - group.add_argument('--no-manual-gc-eval', action='store_false', - help='When using manual garbage collection, disable ' - 'garbage collection at the start and the end of each ' - 'evaluation run.', dest='manual_gc_eval') - - return parser - - -def _add_initialization_args(parser): - group = parser.add_argument_group(title='initialization') - - group.add_argument('--seed', type=int, default=1234, - help='Random seed used for python, numpy, ' - 'pytorch, and cuda.') - group.add_argument('--data-parallel-random-init', action='store_true', - help='Enable random initialization of params ' - 'across data parallel ranks') - group.add_argument('--init-method-std', type=float, default=0.02, - help='Standard deviation of the zero mean normal ' - 'distribution used for weight initialization.') - group.add_argument('--init-method-xavier-uniform', action='store_true', - help='Enable Xavier uniform parameter initialization') - - return parser - - -def _add_learning_rate_args(parser): - group = parser.add_argument_group(title='learning rate') - - group.add_argument('--lr', type=float, default=None, - help='Initial learning rate. Depending on decay style ' - 'and initial warmup, the learing rate at each ' - 'iteration would be different.') - group.add_argument('--lr-decay-style', type=str, default='linear', - choices=['constant', 'linear', 'cosine', 'inverse-square-root'], - help='Learning rate decay function.') - group.add_argument('--lr-decay-iters', type=int, default=None, - help='number of iterations to decay learning rate over,' - ' If None defaults to `--train-iters`') - group.add_argument('--lr-decay-samples', type=int, default=None, - help='number of samples to decay learning rate over,' - ' If None defaults to `--train-samples`') - group.add_argument('--lr-warmup-fraction', type=float, default=None, - help='fraction of lr-warmup-(iters/samples) to use ' - 'for warmup (as a float)') - group.add_argument('--lr-warmup-iters', type=int, default=0, - help='number of iterations to linearly warmup ' - 'learning rate over.') - group.add_argument('--lr-warmup-samples', type=int, default=0, - help='number of samples to linearly warmup ' - 'learning rate over.') - group.add_argument('--lr-warmup-init', type=float, default=0.0, - help='Initial value for learning rate warmup. The ' - 'scheduler starts warmup from this value.') - group.add_argument('--warmup', type=int, default=None, - help='Old lr warmup argument, do not use. Use one of the' - '--lr-warmup-* arguments above') - group.add_argument('--min-lr', type=float, default=0.0, - help='Minumum value for learning rate. The scheduler' - 'clip values below this threshold.') - group.add_argument('--override-opt_param-scheduler', action='store_true', - help='Reset the values of the scheduler (learning rate,' - 'warmup iterations, minimum learning rate, maximum ' - 'number of iterations, and decay style from input ' - 'arguments and ignore values from checkpoints. 
Note' - 'that all the above values will be reset.') - group.add_argument('--use-checkpoint-opt_param-scheduler', action='store_true', - help='Use checkpoint to set the values of the scheduler ' - '(learning rate, warmup iterations, minimum learning ' - 'rate, maximum number of iterations, and decay style ' - 'from checkpoint and ignore input arguments.') - - return parser - - -def _add_checkpointing_args(parser): - group = parser.add_argument_group(title='checkpointing') - - group.add_argument('--save', type=str, default=None, - help='Output directory to save checkpoints to.') - group.add_argument('--save-interval', type=int, default=None, - help='Number of iterations between checkpoint saves.') - group.add_argument('--no-save-optim', action='store_true', default=None, - help='Do not save current optimizer.') - group.add_argument('--no-save-rng', action='store_true', default=None, - help='Do not save current rng state.') - group.add_argument('--load', type=str, default=None, - help='Directory containing a model checkpoint.') - group.add_argument('--no-load-optim', action='store_true', default=None, - help='Do not load optimizer when loading checkpoint.') - group.add_argument('--no-load-rng', action='store_true', default=None, - help='Do not load rng state when loading checkpoint.') - group.add_argument('--finetune', action='store_true', - help='Load model for finetuning. Do not load optimizer ' - 'or rng state from checkpoint and set iteration to 0. ' - 'Assumed when loading a release checkpoint.') - group.add_argument('--no-initialization', action='store_false', - help='Do not perform initialization when building model, ' - 'can reduce startup time when definitely loading from a ' - 'checkpoint', - dest='perform_initialization') - group.add_argument('--use-checkpoint-args', action='store_true', - help='Override any command line arguments with arguments ' - 'from the checkpoint') - group.add_argument('--exit-on-missing-checkpoint', action='store_true', - help="If '--load' is set, but checkpoint is not found " - "(e.g., path typo), then exit instead of random " - "initialization.") - - return parser - - -def _add_mixed_precision_args(parser): - group = parser.add_argument_group(title='mixed precision') - - group.add_argument('--fp16', action='store_true', - help='Run model in fp16 mode.') - group.add_argument('--bf16', action='store_true', - help='Run model in bfloat16 mode.') - group.add_argument('--loss-scale', type=float, default=None, - help='Static loss scaling, positive power of 2 ' - 'values can improve fp16 convergence. If None, dynamic' - 'loss scaling is used.') - group.add_argument('--initial-loss-scale', type=float, default=2**32, - help='Initial loss-scale for dynamic loss scaling.') - group.add_argument('--min-loss-scale', type=float, default=1.0, - help='Minimum loss scale for dynamic loss scale.') - group.add_argument('--loss-scale-window', type=float, default=1000, - help='Window over which to raise/lower dynamic scale.') - group.add_argument('--hysteresis', type=int, default=2, - help='hysteresis for dynamic loss scaling') - group.add_argument('--fp32-residual-connection', action='store_true', - help='Move residual connections to fp32.') - group.add_argument('--apply-query-key-layer-scaling', action='store_true', - help='Scale Q * K^T by 1 / layer-number. ' - 'Useful for fp16 training.') - group.add_argument('--attention-softmax-in-fp32', action='store_true', - help='Run attention masking and softmax in fp32. 
' - 'This flag is ignored unless ' - '--no-query-key-layer-scaling is specified.') - group.add_argument('--accumulate-allreduce-grads-in-fp32', - action='store_true', - help='Gradient accumulation and all-reduce in fp32.') - group.add_argument('--fp16-lm-cross-entropy', action='store_true', - help='Move the cross entropy unreduced loss calculation' - 'for lm head to fp16.') - - return parser - - -def _add_distributed_args(parser): - group = parser.add_argument_group(title='distributed') - - group.add_argument('--tensor-model-parallel-size', type=int, default=1, - help='Degree of tensor model parallelism.') - group.add_argument('--pipeline-model-parallel-size', type=int, default=1, - help='Degree of pipeline model parallelism.') - group.add_argument('--pipeline-model-parallel-split-rank', - type=int, default=None, - help='Rank where encoder and decoder should be split.') - group.add_argument('--model-parallel-size', type=int, default=None, - help='Old model parallel argument, do not use. Use ' - '--tensor-model-parallel-size instead.') - group.add_argument('--num-layers-per-virtual-pipeline-stage', type=int, default=None, - help='Number of layers per virtual pipeline stage') - group.add_argument('--no-overlap-p2p-communication', action='store_false', - help='overlap pipeline parallel communication with forward and backward chunks', - dest='overlap_p2p_comm') - group.add_argument('--distributed-backend', default='nccl', - choices=['nccl', 'gloo'], - help='Which backend to use for distributed training.') - group.add_argument('--distributed-timeout-minutes', type=int, default=10, - help='Timeout minutes for torch.distributed.') - group.add_argument('--overlap-grad-reduce', action='store_true', - default=False, help='If set, overlap DDP grad reduce.') - group.add_argument('--no-delay-grad-reduce', action='store_false', - help='If not set, delay / synchronize grad reductions in all but first PP stage.', - dest='delay_grad_reduce') - group.add_argument('--overlap-param-gather', action='store_true', - default=False, help='If set, overlap param all-gather in distributed optimizer.') - group.add_argument('--delay-param-gather', action='store_true', - default=False, help='If set, delay / synchronize param all-gathers in all but first PP stage.') - group.add_argument('--no-scatter-gather-tensors-in-pipeline', action='store_false', - help='If not set, use scatter/gather to optimize communication of tensors in pipeline.', - dest='scatter_gather_tensors_in_pipeline') - group.add_argument('--use-ring-exchange-p2p', action='store_true', - default=False, help='If set, use custom-built ring exchange ' - 'for p2p communications. Note that this option will require ' - 'a custom built image that support ring-exchange p2p.') - group.add_argument('--local_rank', type=int, default=None, - help='local rank passed from distributed launcher.') - group.add_argument('--lazy-mpu-init', type=bool, required=False, - help='If set to True, initialize_megatron() ' - 'skips DDP initialization and returns function to ' - 'complete it instead.Also turns on ' - '--use-cpu-initialization flag. This is for ' - 'external DDP manager.' ) - group.add_argument('--use-cpu-initialization', action='store_true', - default=None, help='If set, affine parallel weights ' - 'initialization uses CPU' ) - group.add_argument('--empty-unused-memory-level', default=0, type=int, - choices=[0, 1, 2], - help='Call torch.cuda.empty_cache() each iteration ' - '(training and eval), to reduce fragmentation.' 
- '0=off, 1=moderate, 2=aggressive.') - group.add_argument('--standalone-embedding-stage', action='store_true', - default=False, help='If set, *input* embedding layer ' - 'is placed on its own pipeline stage, without any ' - 'transformer layers. (For T5, this flag currently only ' - 'affects the encoder embedding.)') - group.add_argument('--use-distributed-optimizer', action='store_true', - help='Use distributed optimizer.') - group.add_argument('--expert-model-parallel-size', type=int, default=1, - help='Degree of expert model parallelism.') - group.add_argument('--context-parallel-size', type=int, default=1, - help='Degree of context parallelism.') - group.add_argument('--nccl-communicator-config-path', type=str, default=None, - help='Path to the yaml file with NCCL communicator ' - 'configurations. The number of min/max thread groups and thread ' - 'group cluster size of each communicator can be configured by ' - 'setting `min_ctas`, `max_ctas`, and `cga_cluster_size`.') - return parser - - -def _add_validation_args(parser): - group = parser.add_argument_group(title='validation') - - group.add_argument('--eval-iters', type=int, default=100, - help='Number of iterations to run for evaluation' - 'validation/test for.') - group.add_argument('--eval-interval', type=int, default=1000, - help='Interval between running evaluation on ' - 'validation set.') - group.add_argument('--skip-train', action='store_true', - default=False, help='If set, bypass the training loop, ' - 'optionally do evaluation for validation/test, and exit.') - - return parser - - -def _add_data_args(parser): - group = parser.add_argument_group(title='data and dataloader') - - group.add_argument('--data-path', nargs='*', default=None, - help='Path to the training dataset. Accepted format:' - '1) a single data path, 2) multiple datasets in the' - 'form: dataset1-weight dataset1-path dataset2-weight ' - 'dataset2-path ... It is used with --split when a ' - 'single dataset used for all three: train, valid ' - 'and test. It is exclusive to the other ' - '--*-data-path args') - group.add_argument('--split', type=str, default='969, 30, 1', - help='Comma-separated list of proportions for training,' - ' validation, and test split. For example the split ' - '`90,5,5` will use 90%% of data for training, 5%% for ' - 'validation and 5%% for test.') - group.add_argument('--train-data-path', nargs='*', default=None, - help='Path to the training dataset. Accepted format:' - '1) a single data path, 2) multiple datasets in the' - 'form: dataset1-weight dataset1-path dataset2-weight ' - 'dataset2-path ...') - group.add_argument('--valid-data-path', nargs='*', default=None, - help='Path to the validation dataset. Accepted format:' - '1) a single data path, 2) multiple datasets in the' - 'form: dataset1-weight dataset1-path dataset2-weight ' - 'dataset2-path ...') - group.add_argument('--test-data-path', nargs='*', default=None, - help='Path to the test dataset. 
Accepted format:' - '1) a single data path, 2) multiple datasets in the' - 'form: dataset1-weight dataset1-path dataset2-weight ' - 'dataset2-path ...') - group.add_argument('--data-cache-path', default=None, - help='Path to a directory to hold cached index files.') - - group.add_argument('--vocab-size', type=int, default=None, - help='Size of vocab before EOD or padding.') - group.add_argument('--vocab-file', type=str, default=None, - help='Path to the vocab file.') - group.add_argument('--merge-file', type=str, default=None, - help='Path to the BPE merge file.') - group.add_argument('--vocab-extra-ids', type=int, default=0, - help='Number of additional vocabulary tokens. ' - 'They are used for span masking in the T5 model') - group.add_argument('--seq-length', type=int, default=None, - help='Maximum sequence length to process.') - group.add_argument('--encoder-seq-length', type=int, default=None, - help='Maximum encoder sequence length to process.' - 'This should be exclusive of --seq-length') - group.add_argument('--decoder-seq-length', type=int, default=None, - help="Maximum decoder sequence length to process.") - group.add_argument('--retriever-seq-length', type=int, default=256, - help='Maximum sequence length for the biencoder model ' - 'for retriever') - group.add_argument('--sample-rate', type=float, default=1.0, - help='sample rate for training data. Supposed to be 0 ' - ' < sample_rate < 1') - group.add_argument('--mask-prob', type=float, default=0.15, - help='Probability of replacing a token with mask.') - group.add_argument('--short-seq-prob', type=float, default=0.1, - help='Probability of producing a short sequence.') - group.add_argument('--num-workers', type=int, default=2, - help="Dataloader number of workers.") - group.add_argument('--tokenizer-type', type=str, - default=None, - choices=['BertWordPieceLowerCase', - 'BertWordPieceCase', - 'GPT2BPETokenizer', - 'SentencePieceTokenizer', - 'GPTSentencePieceTokenizer', - 'Llama2Tokenizer', - 'NullTokenizer'], - help='What type of tokenizer to use.') - group.add_argument('--tokenizer-model', type=str, default=None, - help='Sentencepiece tokenizer model.') - group.add_argument('--reset-position-ids', action='store_true', - help='Reset posistion ids after end-of-document token.') - group.add_argument('--reset-attention-mask', action='store_true', - help='Reset self attention maske after ' - 'end-of-document token.') - group.add_argument('--eod-mask-loss', action='store_true', - help='Mask loss for the end of document tokens.') - - return parser - - -def _add_autoresume_args(parser): - group = parser.add_argument_group(title='autoresume') - - group.add_argument('--adlr-autoresume', action='store_true', - help='Enable autoresume on adlr cluster.') - group.add_argument('--adlr-autoresume-interval', type=int, default=1000, - help='Intervals over which check for autoresume' - 'termination signal') - - return parser - - -def _add_biencoder_args(parser): - group = parser.add_argument_group(title='biencoder') - - # network size - group.add_argument('--ict-head-size', type=int, default=None, - help='Size of block embeddings to be used in ICT and ' - 'REALM (paper default: 128)') - group.add_argument('--biencoder-projection-dim', type=int, default=0, - help='Size of projection head used in biencoder (paper' - ' default: 128)') - group.add_argument('--biencoder-shared-query-context-model', action='store_true', - help='Whether to share the parameters of the query ' - 'and context models or not') - - # checkpointing - 
group.add_argument('--ict-load', type=str, default=None, - help='Directory containing an ICTBertModel checkpoint') - group.add_argument('--bert-load', type=str, default=None, - help='Directory containing an BertModel checkpoint ' - '(needed to start ICT and REALM)') - - # data - group.add_argument('--titles-data-path', type=str, default=None, - help='Path to titles dataset used for ICT') - group.add_argument('--query-in-block-prob', type=float, default=0.1, - help='Probability of keeping query in block for ' - 'ICT dataset') - group.add_argument('--use-one-sent-docs', action='store_true', - help='Whether to use one sentence documents in ICT') - group.add_argument('--evidence-data-path', type=str, default=None, - help='Path to Wikipedia Evidence frm DPR paper') - - # training - group.add_argument('--retriever-report-topk-accuracies', nargs='+', type=int, - default=[], help="Which top-k accuracies to report " - "(e.g. '1 5 20')") - group.add_argument('--retriever-score-scaling', action='store_true', - help='Whether to scale retriever scores by inverse ' - 'square root of hidden size') - - # faiss index - group.add_argument('--block-data-path', type=str, default=None, - help='Where to save/load BlockData to/from') - group.add_argument('--embedding-path', type=str, default=None, - help='Where to save/load Open-Retrieval Embedding' - ' data to/from') - - # indexer - group.add_argument('--indexer-batch-size', type=int, default=128, - help='How large of batches to use when doing indexing ' - 'jobs') - group.add_argument('--indexer-log-interval', type=int, default=1000, - help='After how many batches should the indexer ' - 'report progress') - return parser - - -def _add_vision_args(parser): - group = parser.add_argument_group(title="vision") - - # general vision arguements - group.add_argument('--num-classes', type=int, default=1000, - help='num of classes in vision classificaiton task') - group.add_argument('--img-h', type=int, default=224, - help='Image height for vision classification task') - group.add_argument('--img-w', type=int, default=224, - help='Image height for vision classification task') - group.add_argument('--num-channels', type=int, default=3, - help='Number of channels in input image data') - group.add_argument('--patch-dim', type=int, default=16, - help='patch dimension') - group.add_argument('--classes-fraction', type=float, default=1.0, - help='training with fraction of classes.') - group.add_argument('--data-per-class-fraction', type=float, default=1.0, - help='training with fraction of data per class.') - group.add_argument('--no-data-sharding', action='store_false', - help='Disable data sharding.', - dest='data_sharding') - group.add_argument('--head-lr-mult', type=float, default=1.0, - help='learning rate multiplier for head during finetuning') - - # pretraining type and backbone selection` - group.add_argument('--vision-pretraining', action='store_true', - help='flag to indicate vision pretraining') - group.add_argument('--vision-pretraining-type', type=str, default='classify', - choices=['classify', 'inpaint', 'dino'], - help='pretraining objectives') - group.add_argument('--vision-backbone-type', type=str, default='vit', - choices=['vit', 'mit', 'swin'], - help='backbone types types') - group.add_argument('--swin-backbone-type', type=str, default='tiny', - choices=['tiny', 'base', 'h3'], - help='pretraining objectives') - - # inpainting arguments - group.add_argument('--mask-type', type=str, default='random', - choices=['random', 'row'], - help='mask types') - 
group.add_argument('--mask-factor', type=float, default=1.0, - help='mask size scaling parameter') - - # dino arguments - group.add_argument('--iter-per-epoch', type=int, default=1250, - help='iterations per epoch') - group.add_argument('--dino-local-img-size', type=int, default=96, - help='Image size for vision classification task') - group.add_argument('--dino-local-crops-number', type=int, default=10, - help='Number of local crops') - group.add_argument('--dino-head-hidden-size', type=int, default=2048, - help='Hidden dimension size in dino head') - group.add_argument('--dino-bottleneck-size', type=int, default=256, - help='Bottle neck dimension in dino head ') - group.add_argument('--dino-freeze-last-layer', type=float, default=1, - help='Freezing last layer weights') - group.add_argument('--dino-norm-last-layer', action='store_true', - help='Disable Norm in last layer.') - group.add_argument('--dino-warmup-teacher-temp', type=float, default=0.04, - help='warump teacher temperature') - group.add_argument('--dino-teacher-temp', type=float, default=0.07, - help='teacher temperature') - group.add_argument('--dino-warmup-teacher-temp-epochs', type=int, default=30, - help='warmup teacher temperaure epochs') - - return parser - -def _add_experimental_args(parser): - group = parser.add_argument_group(title='experimental') - - group.add_argument('--spec', type=str, default=None, nargs=2, - help='Specify the pair ' - 'that returns a spec to customize a model, transformer ' - 'block, or transformer layer, depending on the use case. ' - 'For more details, see the model class, ' - '`transformer_block.py`, or `transformer_layer.py`') - - return parser diff --git a/megatron/checkpointing.py b/megatron/checkpointing.py deleted file mode 100644 index 2be766e384578cbdff64d199571f0f753c013c6d..0000000000000000000000000000000000000000 --- a/megatron/checkpointing.py +++ /dev/null @@ -1,714 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 
- -"""Input/output checkpointing.""" - -import os -import random -import sys -import numpy as np - -import torch - -from megatron import update_num_microbatches -from megatron.core import mpu, tensor_parallel -from .global_vars import get_args -from .utils import (unwrap_model, - print_rank_0) - - -_CHECKPOINT_VERSION = None - - -def set_checkpoint_version(value): - global _CHECKPOINT_VERSION - if _CHECKPOINT_VERSION is not None: - assert _CHECKPOINT_VERSION == value, \ - "checkpoint versions do not match" - _CHECKPOINT_VERSION = value - - -def get_checkpoint_version(): - global _CHECKPOINT_VERSION - return _CHECKPOINT_VERSION - - -def check_checkpoint_args(checkpoint_args): - """Ensure fixed arguments for a model are the same for the input - arguments and the one retrieved from checkpoint.""" - args = get_args() - - def _compare(arg_name, old_arg_name=None, default=None): - if old_arg_name is not None: - ckpt_arg_name = old_arg_name - else: - ckpt_arg_name = arg_name - if default is not None: - checkpoint_value = getattr(checkpoint_args, ckpt_arg_name, default) - else: - checkpoint_value = getattr(checkpoint_args, ckpt_arg_name) - args_value = getattr(args, arg_name) - error_message = '{} value from checkpoint ({}) is not equal to the ' \ - 'input argument value ({}).'.format( - arg_name, checkpoint_value, args_value) - assert checkpoint_value == args_value, error_message - - _compare('num_layers') - _compare('hidden_size') - _compare('num_attention_heads') - _compare('add_position_embedding', default=True) - if args.vocab_file: - _compare('max_position_embeddings') - _compare('make_vocab_size_divisible_by') - _compare('padded_vocab_size') - _compare('tokenizer_type') - if args.data_parallel_random_init: - _compare('data_parallel_random_init') - if get_checkpoint_version() < 3.0: - _compare('tensor_model_parallel_size', - old_arg_name='model_parallel_size') - if get_checkpoint_version() >= 3.0: - _compare('tensor_model_parallel_size') - _compare('pipeline_model_parallel_size') - - -def ensure_directory_exists(filename): - """Build filename's path if it does not already exists.""" - dirname = os.path.dirname(filename) - os.makedirs(dirname, exist_ok = True) - - -def get_checkpoint_name(checkpoints_path, iteration, release=False, - pipeline_parallel=None, - tensor_rank=None, pipeline_rank=None, - expert_parallel=None, expert_rank=None): - """Determine the directory name for this rank's checkpoint.""" - if release: - directory = 'release' - else: - directory = 'iter_{:07d}'.format(iteration) - - # Use both the tensor and pipeline MP rank. - if pipeline_parallel is None: - pipeline_parallel = (mpu.get_pipeline_model_parallel_world_size() > 1) - if tensor_rank is None: - tensor_rank = mpu.get_tensor_model_parallel_rank() - if pipeline_rank is None: - pipeline_rank = mpu.get_pipeline_model_parallel_rank() - if expert_parallel is None: - expert_parallel = (mpu.get_expert_model_parallel_world_size() > 1) - if expert_rank is None: - expert_rank = mpu.get_expert_model_parallel_rank() - - # Use both the tensor and pipeline MP rank. If using the distributed - # optimizer, then the optimizer's path must additionally include the - # data parallel rank. 
- if not pipeline_parallel: - common_path = os.path.join(checkpoints_path, directory, - f'mp_rank_{tensor_rank:02d}') - else: - common_path = os.path.join(checkpoints_path, directory, - f'mp_rank_{tensor_rank:02d}_{pipeline_rank:03d}') - - if expert_parallel: - common_path = common_path + f'_{expert_rank:03d}' - - return os.path.join(common_path, "model_optim_rng.pt") - - -def get_distributed_optimizer_checkpoint_name(model_checkpoint_name): - return os.path.join(os.path.dirname(model_checkpoint_name), - "distrib_optim.pt") - - -def find_checkpoint_rank_0(checkpoints_path, iteration, release=False): - """Finds the checkpoint for rank 0 without knowing if we are using - pipeline parallelism/expert parallelism or not. - - Since the checkpoint naming scheme changes if pipeline or expert - parallelism is present, we need to look for both naming schemes if - we don't know if the checkpoint has pipeline or expert parallelism. - """ - - # Look for checkpoint with no pipelining and no expert parallelism - filename = get_checkpoint_name(checkpoints_path, iteration, release, - pipeline_parallel=False, - tensor_rank=0, pipeline_rank=0, - expert_parallel=False, expert_rank=0) - if os.path.isfile(filename): - return filename - - # Look for checkpoint with no pipelining and expert parallelism - filename = get_checkpoint_name(checkpoints_path, iteration, release, - pipeline_parallel=False, - tensor_rank=0, pipeline_rank=0, - expert_parallel=True, expert_rank=0) - if os.path.isfile(filename): - return filename - - # Look for checkpoint with pipelining and no expert parallelism - filename = get_checkpoint_name(checkpoints_path, iteration, release, - pipeline_parallel=True, - tensor_rank=0, pipeline_rank=0, - expert_parallel=False, expert_rank=0) - if os.path.isfile(filename): - return filename - - # Look for checkpoint with pipelining and expert parallelism - filename = get_checkpoint_name(checkpoints_path, iteration, release, - pipeline_parallel=True, - tensor_rank=0, pipeline_rank=0, - expert_parallel=True, expert_rank=0) - if os.path.isfile(filename): - return filename - - return None, None - - -def get_checkpoint_tracker_filename(checkpoints_path): - - """Tracker file rescords the latest chckpoint during - training to restart from.""" - return os.path.join(checkpoints_path, 'latest_checkpointed_iteration.txt') - - -def read_metadata(tracker_filename): - # Read the tracker file and either set the iteration or - # mark it as a release checkpoint. - iteration = 0 - release = False - with open(tracker_filename, 'r') as f: - metastring = f.read().strip() - try: - iteration = int(metastring) - except ValueError: - release = metastring == 'release' - if not release: - print_rank_0('ERROR: Invalid metadata file {}. Exiting'.format( - tracker_filename)) - sys.exit() - assert iteration > 0 or release, 'error parsing metadata file {}'.format( - tracker_filename) - - # Get the max iteration retrieved across the ranks. - if torch.distributed.is_initialized(): - iters_cuda = torch.cuda.LongTensor([iteration]) - torch.distributed.all_reduce(iters_cuda, op=torch.distributed.ReduceOp.MAX) - max_iter = iters_cuda[0].item() - - # We should now have all the same iteration. - # If not, print a warning and chose the maximum - # iteration across all ranks. 
- if iteration != max_iter: - rank = torch.distributed.get_rank() - print('WARNING: on rank {} found iteration {} in the ' - 'metadata while max iteration across the ranks ' - 'is {}, replacing it with max iteration.'.format( - rank, iteration, max_iter), flush=True) - else: - # When loading a checkpoint outside of training (for example, - # when editing it), we might not have torch distributed - # initialized, in this case, just assume we have the latest - max_iter = iteration - return max_iter, release - - -def get_rng_state(): - """ collect rng state across data parallel ranks """ - args = get_args() - rng_state = { - 'random_rng_state': random.getstate(), - 'np_rng_state': np.random.get_state(), - 'torch_rng_state': torch.get_rng_state(), - 'cuda_rng_state': torch.cuda.get_rng_state(), - 'rng_tracker_states': tensor_parallel.get_cuda_rng_tracker().get_states()} - - rng_state_list = None - if torch.distributed.is_initialized() and \ - mpu.get_data_parallel_world_size() > 1 and \ - args.data_parallel_random_init: - rng_state_list = \ - [None for i in range(mpu.get_data_parallel_world_size())] - torch.distributed.all_gather_object( - rng_state_list, - rng_state, - group=mpu.get_data_parallel_group()) - else: - rng_state_list = [rng_state] - - return rng_state_list - - -def save_checkpoint(iteration, model, optimizer, opt_param_scheduler): - """Save a model checkpoint.""" - args = get_args() - - # Only rank zero of the data parallel writes to the disk. - model = unwrap_model(model) - - print_rank_0('saving checkpoint at iteration {:7d} to {}'.format( - iteration, args.save)) - - # Collect rng state across data parallel ranks. - rng_state = get_rng_state() - - # Checkpoint name. - checkpoint_name = get_checkpoint_name(args.save, iteration) - - # Save distributed optimizer's custom parameter state. - if args.use_distributed_optimizer and not args.no_save_optim and optimizer is not None: - optim_checkpoint_name = \ - get_distributed_optimizer_checkpoint_name(checkpoint_name) - ensure_directory_exists(optim_checkpoint_name) - optimizer.save_parameter_state(optim_checkpoint_name) - - # Collect args, model, RNG. - if not torch.distributed.is_initialized() \ - or mpu.get_data_modulo_expert_parallel_rank() == 0: - - # Arguments, iteration, and model. - state_dict = {} - state_dict['args'] = args - state_dict['checkpoint_version'] = 3.0 - state_dict['iteration'] = iteration - if len(model) == 1: - state_dict['model'] = model[0].state_dict_for_save_checkpoint() - else: - for i in range(len(model)): - mpu.set_virtual_pipeline_model_parallel_rank(i) - state_dict['model%d' % i] = \ - model[i].state_dict_for_save_checkpoint() - - # Optimizer stuff. - if not args.no_save_optim: - if optimizer is not None: - state_dict['optimizer'] = optimizer.state_dict() - if opt_param_scheduler is not None: - state_dict['opt_param_scheduler'] = \ - opt_param_scheduler.state_dict() - - # RNG states. - if not args.no_save_rng: - state_dict["rng_state"] = rng_state - - # Save. 
- ensure_directory_exists(checkpoint_name) - torch.save(state_dict, checkpoint_name) - - # Wait so everyone is done (necessary) - if torch.distributed.is_initialized(): - torch.distributed.barrier() - - print_rank_0(' successfully saved checkpoint at iteration {:7d} to {}' \ - .format(iteration, args.save)) - - # And update the latest iteration - if not torch.distributed.is_initialized() \ - or torch.distributed.get_rank() == 0: - tracker_filename = get_checkpoint_tracker_filename(args.save) - with open(tracker_filename, 'w') as f: - f.write(str(iteration)) - - # Wait so everyone is done (not necessary) - if torch.distributed.is_initialized(): - torch.distributed.barrier() - - -def _transpose_first_dim(t, num_splits, num_splits_first, model): - input_shape = t.size() - # We use a self_attention module but the values extracted aren't - # specific to self attention so should work for cross attention as well - while hasattr(model, 'module'): - model = model.module - attention_module = model.language_model.encoder.layers[0].self_attention - hidden_size_per_attention_head = attention_module.hidden_size_per_attention_head - num_attention_heads_per_partition = attention_module.num_attention_heads_per_partition - if num_splits_first: - """[num_splits * np * hn, h] - -->(view) [num_splits, np, hn, h] - -->(tranpose) [np, num_splits, hn, h] - -->(view) [np * num_splits * hn, h] """ - - intermediate_shape = \ - (num_splits, num_attention_heads_per_partition, - hidden_size_per_attention_head) + input_shape[1:] - - t = t.view(*intermediate_shape) - t = t.transpose(0, 1).contiguous() - else: - """[np * hn * num_splits, h] - -->(view) [np, hn, num_splits, h] - -->(tranpose) [np, num_splits, hn, h] - -->(view) [np * num_splits * hn, h] """ - - intermediate_shape = \ - (num_attention_heads_per_partition, - hidden_size_per_attention_head, num_splits) +\ - input_shape[1:] - - t = t.view(*intermediate_shape) - t = t.transpose(1, 2).contiguous() - t = t.view(*input_shape) - - return t - - -def fix_query_key_value_ordering(model, checkpoint_version): - """Fix up query/key/value matrix ordering if checkpoint - version is smaller than 2.0 - """ - if checkpoint_version < 2.0: - if isinstance(model, list): - assert len(model)==1 - model = model[0] - for name, param in model.named_parameters(): - if name.endswith(('.query_key_value.weight', '.query_key_value.bias')): - if checkpoint_version == 0: - fixed_param = _transpose_first_dim(param.data, 3, True, model) - elif checkpoint_version == 1.0: - fixed_param = _transpose_first_dim(param.data, 3, False, model) - else: - print_rank_0(f"Invalid checkpoint version {checkpoint_version}.") - sys.exit() - param.data.copy_(fixed_param) - if name.endswith(('.key_value.weight', '.key_value.bias')): - if checkpoint_version == 0: - fixed_param = _transpose_first_dim(param.data, 2, True, model) - elif checkpoint_version == 1.0: - fixed_param = _transpose_first_dim(param.data, 2, False, model) - else: - print_rank_0(f"Invalid checkpoint version {checkpoint_version}.") - sys.exit() - param.data.copy_(fixed_param) - print_rank_0(" succesfully fixed query-key-values ordering for" - " checkpoint version {}".format(checkpoint_version)) - - -def _load_base_checkpoint(load_dir, rank0=False): - """ Load the base state_dict from the given directory - - If rank0 is true, just loads rank 0 checkpoint, ignoring arguments. - """ - - # Read the tracker file and set the iteration. 
- tracker_filename = get_checkpoint_tracker_filename(load_dir) - - # If no tracker file, return nothing - if not os.path.isfile(tracker_filename): - if not rank0: - print_rank_0('WARNING: could not find the metadata file {} '.format( - tracker_filename)) - print_rank_0(' will not load any checkpoints and will start from ' - 'random') - return None, "", False - - # Otherwise, read the tracker file and either set the iteration or - # mark it as a release checkpoint. - iteration, release = read_metadata(tracker_filename) - - # Checkpoint. - if rank0: - checkpoint_name = find_checkpoint_rank_0(load_dir, iteration, release) - else: - checkpoint_name = get_checkpoint_name(load_dir, iteration, release) - if release: - print_rank_0(f' loading release checkpoint from {load_dir}') - else: - print_rank_0(f' loading checkpoint from {load_dir} at iteration {iteration}') - - # Load the checkpoint. - try: - state_dict = torch.load(checkpoint_name, map_location='cpu') - except ModuleNotFoundError: - from megatron.fp16_deprecated import loss_scaler - # For backward compatibility. - if not rank0: - print_rank_0(' > deserializing using the old code structure ...') - sys.modules['fp16.loss_scaler'] = sys.modules[ - 'megatron.fp16_deprecated.loss_scaler'] - sys.modules['megatron.fp16.loss_scaler'] = sys.modules[ - 'megatron.fp16_deprecated.loss_scaler'] - state_dict = torch.load(checkpoint_name, map_location='cpu') - sys.modules.pop('fp16.loss_scaler', None) - sys.modules.pop('megatron.fp16.loss_scaler', None) - except BaseException as e: - print_rank_0('could not load the checkpoint') - print_rank_0(e) - sys.exit() - - return state_dict, checkpoint_name, release - - -def load_args_from_checkpoint(args, load_arg='load'): - """Set required arguments from the checkpoint specified in the - arguments. - - Will overwrite arguments that have a non-None default value, but - will leave any arguments that default to None as set. - - Returns the same args NameSpace with the new values added/updated. - - If no checkpoint is specified in args, or if the checkpoint is - there but invalid, the arguments will not be modified - - """ - load_dir = getattr(args, load_arg) - - if load_dir is None: - print_rank_0('No load directory specified, using provided arguments.') - return args - - state_dict, checkpoint_name, release = _load_base_checkpoint(load_dir, rank0=True) - - # Args. 
- if not state_dict: - print_rank_0('Checkpoint not found to provide arguments, using provided arguments.') - return args - - if 'args' not in state_dict: - print_rank_0('Checkpoint provided does not have arguments saved, using provided arguments.') - return args - - checkpoint_args = state_dict['args'] - checkpoint_version = state_dict.get('checkpoint_version', 0) - args.iteration = state_dict['iteration'] - - # One-off conversion for foundation models - if hasattr(checkpoint_args, 'disable_bias_linear'): - setattr(checkpoint_args, 'add_bias_linear', not getattr(checkpoint_args, 'disable_bias_linear')) - - def _set_arg(arg_name, old_arg_name=None, force=False): - if not force and getattr(args, arg_name, None) is not None: - return - - if old_arg_name is not None: - checkpoint_value = getattr(checkpoint_args, old_arg_name, None) - else: - checkpoint_value = getattr(checkpoint_args, arg_name, None) - - if checkpoint_value is not None: - print_rank_0(f"Setting {arg_name} to {checkpoint_value} from checkpoint") - setattr(args, arg_name, checkpoint_value) - else: - print_rank_0(f"Checkpoint did not provide arguments {arg_name}") - - _set_arg('num_layers') - _set_arg('hidden_size') - _set_arg('ffn_hidden_size') - _set_arg('seq_length') - _set_arg('num_attention_heads') - _set_arg('num_query_groups', force=True) - _set_arg('group_query_attention', force=True) - _set_arg('kv_channels') - _set_arg('max_position_embeddings') - _set_arg('position_embedding_type', force=True) - _set_arg('add_position_embedding', force=True) - _set_arg('use_rotary_position_embeddings', force=True) - _set_arg('rotary_percent', force=True) - _set_arg('add_bias_linear', force=True) - _set_arg('swiglu', force=True) - _set_arg('untie_embeddings_and_output_weights', force=True) - _set_arg('apply_layernorm_1p', force=True) - _set_arg('normalization', force=True) - _set_arg('tokenizer_type') - _set_arg('padded_vocab_size') - if checkpoint_version < 3.0: - _set_arg('tensor_model_parallel_size', - 'model_parallel_size') - else: - _set_arg('tensor_model_parallel_size', force=True) - _set_arg('pipeline_model_parallel_size', force=True) - _set_arg('virtual_pipeline_model_parallel_size', force=True) - _set_arg('num_layers_per_virtual_pipeline_stage') - return args, checkpoint_args - - -def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', strict=True): - """Load a model checkpoint and return the iteration. - strict (bool): whether to strictly enforce that the keys in - :attr:`state_dict` of the checkpoint match the names of - parameters and buffers in model. - """ - args = get_args() - load_dir = getattr(args, load_arg) - - model = unwrap_model(model) - - state_dict, checkpoint_name, release = _load_base_checkpoint(load_dir, rank0=False) - - # Checkpoint not loaded. - if state_dict is None: - - # Conditionally exit at this point. - if args.exit_on_missing_checkpoint: - print_rank_0(">> '--exit-on-missing-checkpoint' set ... exiting. <<") - torch.distributed.barrier() - sys.exit() - - # Iteration defaults to 0. - return 0 - - # Set checkpoint version. - set_checkpoint_version(state_dict.get('checkpoint_version', 0)) - - # Set iteration. 
- if args.finetune or release: - iteration = 0 - else: - try: - iteration = state_dict['iteration'] - except KeyError: - try: # Backward compatible with older checkpoints - iteration = state_dict['total_iters'] - except KeyError: - print_rank_0('A metadata file exists but unable to load ' - 'iteration from checkpoint {}, exiting'.format( - checkpoint_name)) - sys.exit() - - # Check arguments. - assert args.consumed_train_samples == 0 - assert args.consumed_valid_samples == 0 - if 'args' in state_dict and not args.finetune: - checkpoint_args = state_dict['args'] - check_checkpoint_args(checkpoint_args) - args.consumed_train_samples = getattr(checkpoint_args, - 'consumed_train_samples', 0) - update_num_microbatches(consumed_samples=args.consumed_train_samples) - args.consumed_valid_samples = getattr(checkpoint_args, - 'consumed_valid_samples', 0) - else: - print_rank_0('could not find arguments in the checkpoint ...') - - # Model. - if len(model) == 1: - model[0].load_state_dict(state_dict['model'], strict=strict) - else: - for i in range(len(model)): - mpu.set_virtual_pipeline_model_parallel_rank(i) - model[i].load_state_dict(state_dict['model%d' % i], strict=strict) - - # Fix up query/key/value matrix ordering if needed. - checkpoint_version = get_checkpoint_version() - print_rank_0(f' checkpoint version {checkpoint_version}') - fix_query_key_value_ordering(model, checkpoint_version) - - # Optimizer. - if not release and not args.finetune and not args.no_load_optim: - try: - # Load state dict. - if optimizer is not None: - optimizer.load_state_dict(state_dict['optimizer']) - - # Load distributed optimizer's custom parameter state. - if args.use_distributed_optimizer: - tracker_filename = get_checkpoint_tracker_filename(load_dir) - iteration, release = read_metadata(tracker_filename) - model_checkpoint_name = \ - get_checkpoint_name(load_dir, iteration, release) - optim_checkpoint_name = \ - get_distributed_optimizer_checkpoint_name( - model_checkpoint_name) - optimizer.load_parameter_state(optim_checkpoint_name) - - # Load scheduler. - if opt_param_scheduler is not None: - if 'lr_scheduler' in state_dict: # backward compatbility - opt_param_scheduler.load_state_dict(state_dict['lr_scheduler']) - else: - opt_param_scheduler.load_state_dict(state_dict['opt_param_scheduler']) - except KeyError: - print_rank_0('Unable to load optimizer from checkpoint {}. ' - 'Specify --no-load-optim or --finetune to prevent ' - 'attempting to load the optimizer state, ' - 'exiting ...'.format(checkpoint_name)) - sys.exit() - else: - if (args.fp16 or args.bf16) and optimizer is not None: - optimizer.reload_model_params() - - # rng states. 
- if not release and not args.finetune and not args.no_load_rng: - try: - if 'rng_state' in state_dict: - # access rng_state for data parallel rank - if args.data_parallel_random_init: - rng_state = state_dict['rng_state'][mpu.get_data_parallel_rank()] - else: - rng_state = state_dict['rng_state'][0] - random.setstate(rng_state['random_rng_state']) - np.random.set_state(rng_state['np_rng_state']) - torch.set_rng_state(rng_state['torch_rng_state']) - torch.cuda.set_rng_state(rng_state['cuda_rng_state']) - # Check for empty states array - if not rng_state['rng_tracker_states']: - raise KeyError - tensor_parallel.get_cuda_rng_tracker().set_states( - rng_state['rng_tracker_states']) - else: # backward compatability - random.setstate(state_dict['random_rng_state']) - np.random.set_state(state_dict['np_rng_state']) - torch.set_rng_state(state_dict['torch_rng_state']) - torch.cuda.set_rng_state(state_dict['cuda_rng_state']) - # Check for empty states array - if not state_dict['rng_tracker_states']: - raise KeyError - tensor_parallel.get_cuda_rng_tracker().set_states( - state_dict['rng_tracker_states']) - except KeyError: - print_rank_0('Unable to load rng state from checkpoint {}. ' - 'Specify --no-load-rng or --finetune to prevent ' - 'attempting to load the rng state, ' - 'exiting ...'.format(checkpoint_name)) - sys.exit() - - # Some utilities want to load a checkpoint without distributed being initialized - if torch.distributed.is_initialized(): - torch.distributed.barrier() - - print_rank_0(f' successfully loaded checkpoint from {args.load} ' - f'at iteration {iteration}') - - return iteration - - -def load_biencoder_checkpoint(model, only_query_model=False, - only_context_model=False, custom_load_path=None): - """ - selectively load retrieval models for indexing/retrieving - from saved checkpoints - """ - - args = get_args() - - model = unwrap_model(model) - - load_path = custom_load_path if custom_load_path is not None else args.load - - tracker_filename = get_checkpoint_tracker_filename(load_path) - with open(tracker_filename, 'r') as f: - iteration = int(f.read().strip()) - - checkpoint_name = get_checkpoint_name(load_path, iteration, - args.use_distributed_optimizer, - release=False) - - if mpu.get_data_parallel_rank() == 0: - print('global rank {} is loading checkpoint {}'.format( - torch.distributed.get_rank(), checkpoint_name)) - - state_dict = torch.load(checkpoint_name, map_location='cpu') - ret_state_dict = state_dict['model'] - - if only_query_model: - ret_state_dict.pop('context_model') - if only_context_model: - ret_state_dict.pop('query_model') - - assert len(model) == 1 - model[0].load_state_dict(ret_state_dict) - torch.distributed.barrier() - - if mpu.get_data_parallel_rank() == 0: - print(' successfully loaded {}'.format(checkpoint_name)) - - return model diff --git a/megatron/core/README.md b/megatron/core/README.md deleted file mode 100644 index 0c8c61738da2f6526d065eb600e33d23187bcd0c..0000000000000000000000000000000000000000 --- a/megatron/core/README.md +++ /dev/null @@ -1 +0,0 @@ -Megatron Core is a library for efficient and scalable training of transformer based models. 
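
The deleted `megatron/checkpointing.py` above derives every rank's checkpoint path from the iteration number, the tensor/pipeline/expert parallel ranks, and a `latest_checkpointed_iteration.txt` tracker file. The following is a minimal sketch of that layout for reference only; `checkpoint_path` is a hypothetical helper name, and just the directory scheme is taken from the deleted `get_checkpoint_name` and `get_checkpoint_tracker_filename`.

```python
import os

def checkpoint_path(root, iteration, tp_rank, pp_rank=None, ep_rank=None, release=False):
    # Layout implied by get_checkpoint_name():
    #   <root>/iter_<7-digit iteration>/mp_rank_<tp>[_<pp>][_<ep>]/model_optim_rng.pt
    # A 'release' checkpoint replaces the iter_* directory with 'release'.
    directory = 'release' if release else 'iter_{:07d}'.format(iteration)
    rank_dir = f'mp_rank_{tp_rank:02d}'
    if pp_rank is not None:   # suffix only present when pipeline parallelism is used
        rank_dir += f'_{pp_rank:03d}'
    if ep_rank is not None:   # suffix only present when expert parallelism is used
        rank_dir += f'_{ep_rank:03d}'
    return os.path.join(root, directory, rank_dir, 'model_optim_rng.pt')

# The tracker file recording the most recently saved iteration sits at the top level:
tracker = os.path.join('/ckpts', 'latest_checkpointed_iteration.txt')
print(checkpoint_path('/ckpts', 1000, tp_rank=0, pp_rank=1))
# -> /ckpts/iter_0001000/mp_rank_00_001/model_optim_rng.pt
```

`find_checkpoint_rank_0` in the same deleted file probes all four combinations of the optional pipeline/expert suffixes for rank 0, which is why loading tools can recover a checkpoint without knowing the original parallel layout.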
diff --git a/megatron/core/__init__.py b/megatron/core/__init__.py deleted file mode 100644 index 2858dc692dff393ba9e01dc34cf75e6c74834346..0000000000000000000000000000000000000000 --- a/megatron/core/__init__.py +++ /dev/null @@ -1,18 +0,0 @@ -import megatron.core.tensor_parallel -import megatron.core.utils -from megatron.core import parallel_state -from megatron.core.distributed import DistributedDataParallel -from megatron.core.inference_params import InferenceParams -from megatron.core.model_parallel_config import ModelParallelConfig - -# Alias parallel_state as mpu, its legacy name -mpu = parallel_state - -__all__ = [ - "parallel_state", - "tensor_parallel", - "utils", - "DistributedDataParallel", - "InferenceParams", - "ModelParallelConfig", -] diff --git a/megatron/core/datasets/Makefile b/megatron/core/datasets/Makefile deleted file mode 100644 index 8f9db7686696fbea6c94b998db4b40ef426c748d..0000000000000000000000000000000000000000 --- a/megatron/core/datasets/Makefile +++ /dev/null @@ -1,9 +0,0 @@ -CXXFLAGS += -O3 -Wall -shared -std=c++11 -fPIC -fdiagnostics-color -CPPFLAGS += $(shell python3 -m pybind11 --includes) -LIBNAME = helpers -LIBEXT = $(shell python3-config --extension-suffix) - -default: $(LIBNAME)$(LIBEXT) - -%$(LIBEXT): %.cpp - $(CXX) $(CXXFLAGS) $(CPPFLAGS) $< -o $@ diff --git a/megatron/core/datasets/__init__.py b/megatron/core/datasets/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/megatron/core/datasets/blended_dataset.py b/megatron/core/datasets/blended_dataset.py deleted file mode 100644 index 89f3bbc9e5e25b301ef9559f4ff1dc1898faf61c..0000000000000000000000000000000000000000 --- a/megatron/core/datasets/blended_dataset.py +++ /dev/null @@ -1,190 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
- -import hashlib -import json -import logging -import os -import time -from collections import OrderedDict -from typing import Dict, List, Tuple, Union - -import numpy -import torch - -from megatron.core.datasets.blended_megatron_dataset_config import BlendedMegatronDatasetConfig -from megatron.core.datasets.megatron_dataset import MegatronDataset -from megatron.core.datasets.utils import log_single_rank, normalize - -logger = logging.getLogger(__name__) - -_VERBOSE = False - - -class BlendedDataset(torch.utils.data.Dataset): - """Conjugating class for a set of MegatronDataset instances - - Args: - datasets (List[MegatronDataset]): The MegatronDataset instances to blend - - weights (List[float]): The weights which determines the dataset blend ratios - - size (int): The number of samples to draw from the blend - - config (BlendedMegatronDatasetConfig): The config object which informs dataset creation - - Raises: - RuntimeError: When the dataset has fewer or more samples than 'size' post-initialization - """ - - def __init__( - self, - datasets: List[MegatronDataset], - weights: List[float], - size: int, - config: BlendedMegatronDatasetConfig, - ) -> None: - assert len(datasets) < 32767 - assert len(datasets) == len(weights) - assert numpy.isclose(sum(weights), 1.0) - assert all(map(lambda _: type(_) == type(datasets[0]), datasets)) - - # Alert user to unnecessary blending - if len(datasets) == 1: - log_single_rank( - logger, logging.WARNING, f"Building a BlendedDataset for a single MegatronDataset" - ) - - # Redundant normalization for bitwise identical comparison with Megatron-LM - weights = normalize(weights) - - self.datasets = datasets - self.weights = weights - self.size = size - self.config = config - - unique_identifiers = OrderedDict() - unique_identifiers["class"] = type(self).__name__ - unique_identifiers["datasets"] = [dataset.unique_identifiers for dataset in self.datasets] - unique_identifiers["weights"] = self.weights - unique_identifiers["size"] = self.size - - self.unique_description = json.dumps(unique_identifiers, indent=4) - self.unique_description_hash = hashlib.md5( - self.unique_description.encode("utf-8") - ).hexdigest() - - self.dataset_index, self.dataset_sample_index = self._build_indices() - - # Check size - _ = self[self.size - 1] - try: - _ = self[self.size] - raise RuntimeError(f"{type(self).__name__} size is improperly bounded") - except IndexError: - log_single_rank(logger, logging.INFO, f"> {type(self).__name__} length: {len(self)}") - - def __len__(self) -> int: - return self.size - - def __getitem__(self, idx: int) -> Dict[str, Union[int, numpy.ndarray]]: - dataset_id = self.dataset_index[idx] - dataset_sample_id = self.dataset_sample_index[idx] - return { - "dataset_id": dataset_id, - **self.datasets[dataset_id][dataset_sample_id], - } - - def _build_indices(self) -> Tuple[numpy.ndarray, numpy.ndarray]: - """Build and optionally cache the dataset index and the dataset sample index - - The dataset index is a 1-D mapping which determines the dataset to query. The dataset - sample index is a 1-D mapping which determines the sample to request from the queried - dataset. 
- - Returns: - Tuple[numpy.ndarray, numpy.ndarray]: The dataset index and the dataset sample index - """ - path_to_cache = getattr(self.config, "path_to_cache") - - if path_to_cache: - get_path_to = lambda suffix: os.path.join( - path_to_cache, f"{self.unique_description_hash}-{type(self).__name__}-{suffix}" - ) - path_to_description = get_path_to("description.txt") - path_to_dataset_index = get_path_to("dataset_index.npy") - path_to_dataset_sample_index = get_path_to("dataset_sample_index.npy") - cache_hit = all( - map( - os.path.isfile, - [path_to_description, path_to_dataset_index, path_to_dataset_sample_index], - ) - ) - else: - cache_hit = False - - if not path_to_cache or (not cache_hit and torch.distributed.get_rank() == 0): - log_single_rank( - logger, logging.INFO, f"Build and save the {type(self).__name__} indices", - ) - - # Build the dataset and dataset sample indexes - log_single_rank( - logger, logging.INFO, f"\tBuild and save the dataset and dataset sample indexes" - ) - t_beg = time.time() - from megatron.core.datasets import helpers - - dataset_index = numpy.zeros(self.size, dtype=numpy.int16) - dataset_sample_index = numpy.zeros(self.size, dtype=numpy.int64) - helpers.build_blending_indices( - dataset_index, - dataset_sample_index, - self.weights, - len(self.datasets), - self.size, - _VERBOSE, - ) - - if path_to_cache: - os.makedirs(path_to_cache, exist_ok=True) - # Write the description - with open(path_to_description, "wt") as writer: - writer.write(self.unique_description) - # Save the indexes - numpy.save(path_to_dataset_index, dataset_index, allow_pickle=True) - numpy.save(path_to_dataset_sample_index, dataset_sample_index, allow_pickle=True) - else: - log_single_rank( - logger, - logging.WARNING, - "Unable to save the indexes because path_to_cache is None", - ) - - t_end = time.time() - log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds") - - return dataset_index, dataset_sample_index - - log_single_rank(logger, logging.INFO, f"Load the {type(self).__name__} indices") - - log_single_rank( - logger, logging.INFO, f"\tLoad the dataset index from {path_to_dataset_index}" - ) - t_beg = time.time() - dataset_index = numpy.load(path_to_dataset_index, allow_pickle=True, mmap_mode='r') - t_end = time.time() - log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds") - - log_single_rank( - logger, - logging.INFO, - f"\tLoad the dataset sample index from {path_to_dataset_sample_index}", - ) - t_beg = time.time() - dataset_sample_index = numpy.load( - path_to_dataset_sample_index, allow_pickle=True, mmap_mode='r' - ) - t_end = time.time() - log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds") - - return dataset_index, dataset_sample_index diff --git a/megatron/core/datasets/blended_megatron_dataset_builder.py b/megatron/core/datasets/blended_megatron_dataset_builder.py deleted file mode 100644 index 3dee4e469616efe867b72c3b95e8f4ef4cf18720..0000000000000000000000000000000000000000 --- a/megatron/core/datasets/blended_megatron_dataset_builder.py +++ /dev/null @@ -1,328 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
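The `_build_indices` method above delegates the actual mixing to the compiled `helpers.build_blending_indices` routine. The underlying rule is a greedy assignment: at every step, draw from the dataset whose sampled share lags furthest behind its target weight. A rough pure-Python sketch of that rule follows; the function name `blend_indices` is illustrative, and the real code uses the compiled pybind11 helper instead:

```python
import numpy

def blend_indices(weights, size):
    """Greedy blend: pick the dataset with the largest gap between
    weight * samples_drawn_so_far and the samples already taken from it."""
    weights = numpy.asarray(weights, dtype=numpy.float64)
    dataset_index = numpy.zeros(size, dtype=numpy.int16)
    dataset_sample_index = numpy.zeros(size, dtype=numpy.int64)
    taken = numpy.zeros(len(weights), dtype=numpy.int64)
    for i in range(size):
        gaps = weights * max(i, 1) - taken
        pick = int(numpy.argmax(gaps))
        dataset_index[i] = pick                # which dataset to query
        dataset_sample_index[i] = taken[pick]  # which sample of that dataset
        taken[pick] += 1
    return dataset_index, dataset_sample_index

# A 30/70 blend over 10 samples ends up with 3 draws from dataset 0 and 7 from dataset 1.
print(blend_indices([0.3, 0.7], 10)[0])
```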
- -import logging -import math -from typing import Any, List, Optional, Tuple, Type, Union - -import numpy -import torch - -from megatron.core.datasets.blended_dataset import BlendedDataset -from megatron.core.datasets.blended_megatron_dataset_config import BlendedMegatronDatasetConfig -from megatron.core.datasets.indexed_dataset import MMapIndexedDataset -from megatron.core.datasets.megatron_dataset import MegatronDataset -from megatron.core.datasets.utils import Split, normalize - -logger = logging.getLogger(__name__) - -DistributedDataset = Union[BlendedDataset, MegatronDataset, MMapIndexedDataset] - - -class BlendedMegatronDatasetBuilder(object): - """Builder class for the BlendedDataset and MegatronDataset classes - - Args: - cls (Type[MegatronDataset]): The class to instantiate, must inherit from MegatronDataset - - sizes (List[int]): The minimum number of total samples to draw from each split, varies - with blend - - config (BlendedMegatronDatasetConfig): The config object which informs dataset creation - """ - - def __init__( - self, cls: Type[MegatronDataset], sizes: List[int], config: BlendedMegatronDatasetConfig, - ): - self.cls = cls - self.sizes = sizes - self.config = config - - def build(self) -> List[Optional[Union[BlendedDataset, MegatronDataset]]]: - """Build all dataset splits according to the provided blend(s) - - This method is distributed-aware and must be called on all ranks. - - The dataset splits returned can vary according to the config. Supply config.blend and - config.split to build BlendedDataset and/or MegatronDataset splits from the same - distribution. Supply config.blend_per_split to build BlendedDataset and/or MegatronDataset - splits from separate distributions. - - Returns: - List[Optional[Union[BlendedDataset, MegatronDataset]]]: A list of either - MegatronDataset or BlendedDataset (or None) per split - """ - return self._build_blended_dataset_splits() - - def _build_blended_dataset_splits( - self, - ) -> List[Optional[Union[BlendedDataset, MegatronDataset]]]: - """Build all dataset splits according to the provided blend(s) - - See the BlendedMegatronDatasetBuilder.build alias for more information. 
- - Returns: - List[Optional[Union[BlendedDataset, MegatronDataset]]]: A list of either - MegatronDataset or BlendedDataset (or None) per split - """ - - if getattr(self.config, "blend"): - blend = getattr(self.config, "blend") - split = getattr(self.config, "split_vector") - - # Blend consists of a single prefix - if len(blend) == 1: - return self._build_megatron_dataset_splits(blend[0], split, self.sizes) - - # Blend consists of multiple weights and prefixes - ( - prefix_per_dataset, - weight_per_dataset, - sizes_per_dataset, - ) = _get_prefixes_weights_and_sizes_for_blend(blend, self.sizes) - - megatron_datasets = [[] for _ in range(len(Split))] - - for i in range(len(prefix_per_dataset)): - megatron_datasets_split = self._build_megatron_dataset_splits( - prefix_per_dataset[i], split, sizes_per_dataset[i] - ) - for j in range(len(megatron_datasets_split)): - megatron_datasets[j].append(megatron_datasets_split[j]) - - # Sum over all contributing datasets, per split - size_per_split = list(map(sum, zip(*sizes_per_dataset))) - - blended_datasets = [] - - for i in range(len(megatron_datasets)): - is_none = map(lambda _: _ is None, megatron_datasets[i]) - - if split[i] == 0.0: - assert all(is_none) - blended_datasets.append(None) - else: - assert all(is_none) or not any(is_none) - blended_datasets.append( - self._build_generic_dataset( - BlendedDataset, - megatron_datasets[i], - weight_per_dataset, - size_per_split[i], - self.config, - ) - ) - - return blended_datasets - - else: - blended_datasets = [] - for i in range(len(Split)): - blend = getattr(self.config, "blend_per_split")[i] - - # Blend is not provided - if not blend: - blended_datasets.append(None) - continue - - split_spoof = [0.0] * len(Split) - split_spoof[i] = 1.0 - sizes_spoof = [0] * len(Split) - sizes_spoof[i] = self.sizes[i] - - # Blend consists of a sigle prefix - if len(blend) == 1: - blended_datasets.append( - self._build_megatron_dataset_splits(blend[0], split_spoof, sizes_spoof)[i] - ) - - # Blend consists of multiple weights and prefixes - else: - ( - prefix_per_dataset, - weight_per_dataset, - sizes_per_dataset, - ) = _get_prefixes_weights_and_sizes_for_blend(blend, sizes_spoof) - - megatron_datasets = [] - for j in range(len(prefix_per_dataset)): - megatron_datasets.append( - self._build_megatron_dataset_splits( - prefix_per_dataset[j], split_spoof, sizes_per_dataset[j], - )[i] - ) - - size_per_split = list(map(sum, zip(*sizes_per_dataset))) - - blended_datasets.append( - self._build_generic_dataset( - BlendedDataset, - megatron_datasets, - weight_per_dataset, - size_per_split[i], - self.config, - ) - ) - - return blended_datasets - - def _build_megatron_dataset_splits( - self, path_prefix: str, split: List[float], sizes: List[int], - ) -> List[Optional[MegatronDataset]]: - """Build each MegatronDataset split from a single MMapIndexedDataset - - Args: - path_prefix (str): The MMapIndexedDataset .bin and .idx file prefix - - split (List[float]): The dataset split ratios (must sum to 1.00) - - sizes (List[int]): The number of total samples to draw from each split - - Returns: - List[Optional[MegatronDataset]]: The MegatronDatset (or None) per split - """ - indexed_dataset = self._build_generic_dataset( - MMapIndexedDataset, path_prefix, self.cls.is_multimodal() - ) - - if indexed_dataset is not None: - if self.cls.is_split_by_sequence(): - split_idx_bounds = _get_split_indices( - split, indexed_dataset.sequence_lengths.shape[0] - ) - else: - split_idx_bounds = _get_split_indices( - split, 
indexed_dataset.document_indices.shape[0] - 1 - ) - split_indices = [ - numpy.arange( - start=split_idx_bounds[i], - stop=split_idx_bounds[i + 1], - step=1, - dtype=numpy.int32, - ) - for i, _ in enumerate(Split) - ] - else: - split_indices = [None for _ in Split] - - megatron_datasets = [] - for i, _split in enumerate(Split): - if split[i] == 0.0: - megatron_datasets.append(None) - else: - megatron_datasets.append( - self._build_generic_dataset( - self.cls, indexed_dataset, split_indices[i], sizes[i], _split, self.config - ) - ) - - return megatron_datasets - - def _build_generic_dataset( - self, cls: Type[DistributedDataset], *args: Any, - ) -> Optional[DistributedDataset]: - """Build the DistributedDataset - - Return None if and only if the underlying MegatronDataset class is not built on the current - rank and torch.distributed is initialized. - - Args: - cls (Type[DistributedDataset]): The DistributedDataset class to be built - - args (Tuple[Any]): The positional arguments used to build the provided - DistributedDataset class - - Raises: - Exception: When the dataset constructor raises an OSError - - Returns: - Optional[DistributedDataset]: The DistributedDataset instantion or None - """ - if torch.distributed.is_initialized(): - rank = torch.distributed.get_rank() - - dataset = None - - # First, build on rank 0 - if rank == 0 and getattr(self.config, "is_built_on_rank")(): - try: - dataset = cls(*args) - except OSError as err: - log = ( - f"Failed to write dataset materials to the data cache directory. " - + f"Please supply a directory to which you have write access via " - + f"the path_to_cache attribute in BlendedMegatronDatasetConfig and " - + f"retry. Refer to the preserved traceback above for more information." - ) - raise Exception(log) from err - - torch.distributed.barrier() - - # After, build on other ranks - if rank != 0 and getattr(self.config, "is_built_on_rank")(): - dataset = cls(*args) - - return dataset - - return cls(*args) - - -def _get_split_indices(split: List[float], num_elements: int) -> List[int]: - """Determine the document index bounds per split - - Args: - split (List[float]): The dataset split ratios (must sum to 1.00) - - num_elements (int): The number of elements, e.g. sequences or documents, available for - the split - - Returns: - List[int]: The indices for all three splits e.g. [0, 900, 990, 1000] for a 1000-document - set and a [90.0, 9.0, 1.0] split - """ - split_indices = [0] - for split_pct in split: - split_indices.append(split_indices[-1] + int(round(split_pct * float(num_elements)))) - split_indices[1:] = list( - map(lambda _: _ - (split_indices[-1] - num_elements), split_indices[1:]) - ) - - assert len(split_indices) == len(split) + 1 - assert split_indices[-1] == num_elements - - return split_indices - - -def _get_prefixes_weights_and_sizes_for_blend( - blend: List[str], target_num_samples_per_split: List[int] -) -> Tuple[List[str], List[float], List[List[int]]]: - """Determine the contribution of the MegatronDataset splits to the BlendedDataset splits - - Args: - blend (List[str]): e.g. ["30", "path/to/dataset_1_prefix", "70", - "path/to/dataset_2_prefix"] - - target_num_samples_per_split (List[int]): The number of samples to target for each - BlendedDataset split - - Returns: - Tuple[List[str], List[float], List[List[int]]]: The prefix strings e.g. - ["path/to/dataset_1_prefix", "path/to/dataset_2_prefix"], the normalized weights e.g. 
- [0.3, 0.7], and the number of samples to request per MegatronDataset per split - """ - weights, prefixes = zip( - *[(float(blend[i]), blend[i + 1].strip()) for i in range(0, len(blend), 2)] - ) - - weights = normalize(weights) - - # Use 0.5% target margin to ensure we satiate the network - sizes_per_dataset = [ - [ - int(math.ceil(target_num_samples * weight * 1.005)) - for target_num_samples in target_num_samples_per_split - ] - for weight in weights - ] - - return prefixes, weights, sizes_per_dataset diff --git a/megatron/core/datasets/blended_megatron_dataset_config.py b/megatron/core/datasets/blended_megatron_dataset_config.py deleted file mode 100644 index b7e242a4be1ec7e62d0ec89cd00ea1b35d32042b..0000000000000000000000000000000000000000 --- a/megatron/core/datasets/blended_megatron_dataset_config.py +++ /dev/null @@ -1,119 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - -import logging -import re -from dataclasses import dataclass, field -from typing import Callable, List, Optional - -import torch - -from megatron.core.datasets.utils import Split, log_single_rank, normalize -from megatron.core.parallel_state import get_virtual_pipeline_model_parallel_rank - -logger = logging.getLogger(__name__) - - -@dataclass -class BlendedMegatronDatasetConfig: - """Configuration object for megatron-core blended and megatron datasets - - Attributes: - is_built_on_rank (Callable): A callable which returns True if the dataset should be built - on the current rank. It should be Megatron Core parallelism aware i.e. global rank, group - rank, and virtual rank may inform its return value. - - random_seed (int): The seed for all RNG during dataset creation. - - sequence_length (int): The sequence length. - - blend (Optional[List[str]]): The blend string, consisting of either a single dataset or a - flattened sequential sequence of weight-dataset pairs. For exampe, ["dataset-path1"] and - ["50", "dataset-path1", "50", "dataset-path2"] are both valid. Not to be used with - 'blend_per_split'. Defaults to None. - - blend_per_split (blend_per_split: Optional[List[Optional[List[str]]]]): A set of blend - strings, as defined above, one for each split distribution. Not to be used with 'blend'. - Defauls to None. - - split (Optional[str]): The split string, a comma separated weighting for the dataset splits - when drawing samples from a single distribution. Not to be used with 'blend_per_split'. - Defaults to None. - - split_vector: (Optional[List[float]]): The split string, parsed and normalized post- - initialization. Not to be passed to the constructor. - - path_to_cache (str): Where all re-useable dataset indices are to be cached. - """ - - is_built_on_rank: Callable - - random_seed: int - - sequence_length: int - - blend: Optional[List[str]] = None - - blend_per_split: Optional[List[Optional[List[str]]]] = None - - split: Optional[str] = None - - split_vector: Optional[List[float]] = field(init=False, default=None) - - path_to_cache: str = None - - def __post_init__(self): - """Python dataclass method that is used to modify attributes after initialization. See - https://docs.python.org/3/library/dataclasses.html#post-init-processing for more details. 
- """ - if torch.distributed.is_initialized(): - gb_rank = torch.distributed.get_rank() - vp_rank = get_virtual_pipeline_model_parallel_rank() - if gb_rank == 0 and (vp_rank == 0 or vp_rank is None): - assert ( - self.is_built_on_rank() - ), "is_built_on_rank must return True when global rank = 0 and vp rank = 0" - - if self.blend_per_split is not None and any(self.blend_per_split): - assert self.blend is None, "blend and blend_per_split are incompatible" - assert len(self.blend_per_split) == len( - Split - ), f"blend_per_split must contain {len(Split)} blends" - if self.split is not None: - self.split = None - log_single_rank(logger, logging.WARNING, f"Let split = {self.split}") - else: - assert self.blend is not None, "one of either blend or blend_per_split must be provided" - assert self.split is not None, "both blend and split must be provided" - self.split_vector = _parse_and_normalize_split(self.split) - log_single_rank(logger, logging.INFO, f"Let split_vector = {self.split_vector}") - - -@dataclass -class GPTDatasetConfig(BlendedMegatronDatasetConfig): - """Configuration object for megatron-core blended and megatron GPT datasets - - Attributes: - return_document_ids (bool): Whether to return the document ids when querying the dataset. - """ - - return_document_ids: bool = False - - -def _parse_and_normalize_split(split: str) -> List[float]: - """Parse the dataset split ratios from a string - - Args: - split (str): The train valid test split string e.g. "99,1,0" - - Returns: - List[float]: The trian valid test split ratios e.g. [99.0, 1.0, 0.0] - """ - split = list(map(float, re.findall(r"[.0-9]+", split))) - split = split + [0.0 for _ in range(len(Split) - len(split))] - - assert len(split) == len(Split) - assert all(map(lambda _: _ >= 0.0, split)) - - split = normalize(split) - - return split diff --git a/megatron/core/datasets/gpt_dataset.py b/megatron/core/datasets/gpt_dataset.py deleted file mode 100644 index 1004e649a297853a4c799ea99ec0e6819f3c9fdf..0000000000000000000000000000000000000000 --- a/megatron/core/datasets/gpt_dataset.py +++ /dev/null @@ -1,460 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
- -import logging -import os -import time -from typing import Dict, Tuple - -import numpy -import torch - -from megatron.core.datasets.blended_megatron_dataset_config import GPTDatasetConfig -from megatron.core.datasets.indexed_dataset import MMapIndexedDataset -from megatron.core.datasets.megatron_dataset import MegatronDataset -from megatron.core.datasets.utils import Split, log_single_rank - -logger = logging.getLogger(__name__) - - -class GPTDataset(MegatronDataset): - """The base GPT dataset - - Args: - indexed_dataset (MMapIndexedDataset): The MMapIndexedDataset around which to build the - MegatronDataset - - indexed_indices (numpy.ndarray): The set of the documents indices to expose - - num_samples (int): The number of samples to draw from the indexed dataset - - index_split (Split): The indexed_indices Split - - config (GPTDatasetConfig): The GPT-specific container for all config sourced parameters - """ - - def __init__( - self, - indexed_dataset: MMapIndexedDataset, - indexed_indices: numpy.ndarray, - num_samples: int, - index_split: Split, - config: GPTDatasetConfig, - ) -> None: - super().__init__(indexed_dataset, indexed_indices, num_samples, index_split, config) - - def _finalize(self) -> None: - """Abstract method implementation - - Load or build/cache the document, sample, and shuffle indices - """ - assert isinstance(self.config, GPTDatasetConfig) - - ( - self.document_index, - self.sample_index, - self.shuffle_index, - ) = self._build_document_sample_shuffle_indices() - - def __len__(self) -> int: - """Abstract method implementation - - Returns: - int: The length of the dataset - """ - return self.sample_index.shape[0] - 1 - - def __getitem__(self, idx: int) -> Dict[str, numpy.ndarray]: - """Abstract method implementation - - Args: - idx (int): The index into the dataset - - Returns: - Dict[str, numpy.ndarray]: The text ids and (optionally) the document ids wrapped in a - dictionary - """ - text, document_ids = self._query_document_sample_shuffle_indices(idx) - if getattr(self.config, "return_document_ids"): - return {"text": text, "document_ids": document_ids} - else: - return {"text": text} - - @staticmethod - def is_multimodal() -> bool: - """Abstract method implementation - - Returns: - bool: False - """ - return False - - @staticmethod - def is_split_by_sequence() -> bool: - """Abstract method implementation - - Returns: - bool: True - """ - return True - - def _query_document_sample_shuffle_indices( - self, idx: int - ) -> Tuple[numpy.ndarray, numpy.ndarray]: - """Get the text (token ids) and document ids for a given index - - Args: - idx (int): The index into the dataset - - Returns: - Tuple[numpy.ndarray, numpy.ndarray]: The text ids and document ids - """ - # Do the shuffle mapping - idx = self.shuffle_index[idx] - - # Get the beginning and end documents and offsets - doc_index_beg, doc_index_beg_offset = self.sample_index[idx] - doc_index_end, doc_index_end_offset = self.sample_index[idx + 1] - - document_ids = [] - sample_parts = [] - - # Sample spans a single document - if doc_index_beg == doc_index_end: - # Add the document id - document_ids.append(self.document_index[doc_index_beg]) - - # Add the entire sample - sample_parts.append( - self.indexed_dataset.get( - self.document_index[doc_index_beg], - offset=doc_index_beg_offset, - length=doc_index_end_offset - doc_index_beg_offset + 1, - ) - ) - - # Sample spans multiple documents - else: - for i in range(doc_index_beg, doc_index_end + 1): - # Add the document id - document_ids.append(self.document_index[i]) 
- - # Add the sample part - offset = 0 if i > doc_index_beg else doc_index_beg_offset - length = None if i < doc_index_end else doc_index_end_offset + 1 - sample_parts.append( - self.indexed_dataset.get(self.document_index[i], offset=offset, length=length) - ) - - return ( - numpy.array(numpy.concatenate(sample_parts), dtype=numpy.int64), - numpy.array(document_ids, dtype=numpy.int64), - ) - - def _build_document_sample_shuffle_indices( - self, - ) -> Tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray]: - """Build the document index, the sample index, and the shuffle index - - The document index: - -- 1-D - -- An ordered array of document ids - - The sample index: - -- 2-D - -- The document indices and offsets which mark the start of every sample - - The shuffle index: - -- 1-D - -- A random permutation of index range of the sample index - - Returns: - Tuple[numpy.ndarray, numpy.ndarray]: The document index, the sample index, and the - shuffle index - - TODO: Explain the 80% threshold - """ - path_to_cache = getattr(self.config, "path_to_cache") - if path_to_cache is None: - path_to_cache = os.path.join( - self.indexed_dataset.path_prefix, "cache", f"{type(self).__name__}_indices" - ) - - get_path_to = lambda suffix: os.path.join( - path_to_cache, f"{self.unique_description_hash}-{type(self).__name__}-{suffix}" - ) - path_to_description = get_path_to("description.txt") - path_to_document_index = get_path_to("document_index.npy") - path_to_sample_index = get_path_to("sample_index.npy") - path_to_shuffle_index = get_path_to("shuffle_index.npy") - cache_hit = all( - map( - os.path.isfile, - [ - path_to_description, - path_to_document_index, - path_to_sample_index, - path_to_shuffle_index, - ], - ) - ) - - num_tokens_per_epoch = _get_num_tokens_per_epoch(self.indexed_dataset, self.indexed_indices) - - sequence_length = getattr(self.config, "sequence_length") - - num_epochs = _get_num_epochs(num_tokens_per_epoch, sequence_length, self.num_samples) - - if not cache_hit and torch.distributed.get_rank() == 0: - log_single_rank( - logger, - logging.INFO, - f"Build and save the {type(self).__name__} {self.index_split.name} indices", - ) - - if num_epochs == 1: - separate_final_epoch = False - else: - # Get the number of samples for the last epoch - num_samples_sans_final_epoch = ( - (num_epochs - 1) * num_tokens_per_epoch - 1 - ) // sequence_length - num_samples_from_final_epoch = self.num_samples - num_samples_sans_final_epoch - num_samples_per_epoch = (num_tokens_per_epoch - 1) // sequence_length - - # num_samples_from_final_epoch should be non-negative - assert num_samples_from_final_epoch >= 0 - - # num_samples_from_final_epoch should not exceed max value - assert num_samples_from_final_epoch <= num_samples_per_epoch + 1 - - # Separate the final epoch if it falls below the threshold - threshold = 0.80 - separate_final_epoch = num_samples_from_final_epoch < int( - threshold * num_samples_per_epoch - ) - - log_single_rank( - logger, - logging.DEBUG, - f"> num_samples_from_final_epoch: {num_samples_from_final_epoch}", - ) - log_single_rank(logger, logging.DEBUG, f"> threshold: {threshold}") - log_single_rank( - logger, logging.DEBUG, f"> num_samples_per_epoch: {num_samples_per_epoch}" - ) - - log_single_rank( - logger, logging.DEBUG, f"> separate_final_epoch: {separate_final_epoch}" - ) - - numpy_random_state = numpy.random.RandomState(getattr(self.config, "random_seed")) - - os.makedirs(path_to_cache, exist_ok=True) - - # Write the description - with open(path_to_description, "wt") as writer: - 
writer.write(self.unique_description) - - # Build the document index - log_single_rank( - logger, - logging.INFO, - f"\tBuild and save the document index to {os.path.basename(path_to_document_index)}", - ) - t_beg = time.time() - document_index = _build_document_index( - self.indexed_indices, num_epochs, numpy_random_state, separate_final_epoch - ) - numpy.save(path_to_document_index, document_index, allow_pickle=True) - t_end = time.time() - log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds") - - # Build the sample index - log_single_rank( - logger, - logging.INFO, - f"\tBuild and save the sample index to {os.path.basename(path_to_sample_index)}", - ) - t_beg = time.time() - from megatron.core.datasets import helpers - - assert document_index.dtype == numpy.int32 - assert self.indexed_dataset.sequence_lengths.dtype == numpy.int32 - sample_index = helpers.build_sample_idx( - self.indexed_dataset.sequence_lengths, - document_index, - sequence_length, - num_epochs, - num_tokens_per_epoch, - ) - numpy.save(path_to_sample_index, sample_index, allow_pickle=True) - t_end = time.time() - log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds") - - # Build the shuffle index - log_single_rank( - logger, - logging.INFO, - f"\tBuild and save the shuffle index to {os.path.basename(path_to_shuffle_index)}", - ) - t_beg = time.time() - if separate_final_epoch: - shuffle_index = _build_shuffle_index( - num_samples_sans_final_epoch, sample_index.shape[0] - 1, numpy_random_state - ) - else: - shuffle_index = _build_shuffle_index( - sample_index.shape[0] - 1, sample_index.shape[0] - 1, numpy_random_state - ) - numpy.save(path_to_shuffle_index, shuffle_index, allow_pickle=True) - t_end = time.time() - log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds") - - log_single_rank( - logger, logging.INFO, f"Load the {type(self).__name__} {self.index_split.name} indices" - ) - - log_single_rank( - logger, - logging.INFO, - f"\tLoad the document index from {os.path.basename(path_to_document_index)}", - ) - t_beg = time.time() - document_index = numpy.load(path_to_document_index, allow_pickle=True, mmap_mode='r') - t_end = time.time() - log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds") - - log_single_rank( - logger, - logging.INFO, - f"\tLoad the sample index from {os.path.basename(path_to_sample_index)}", - ) - t_beg = time.time() - sample_index = numpy.load(path_to_sample_index, allow_pickle=True, mmap_mode='r') - t_end = time.time() - log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds") - - log_single_rank( - logger, - logging.INFO, - f"\tLoad the shuffle index from {os.path.basename(path_to_shuffle_index)}", - ) - t_beg = time.time() - shuffle_index = numpy.load(path_to_shuffle_index, allow_pickle=True, mmap_mode='r') - t_end = time.time() - log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds") - - log_single_rank( - logger, logging.INFO, f"> total number of samples: {sample_index.shape[0] - 1}" - ) - log_single_rank(logger, logging.INFO, f"> total number of epochs: {num_epochs}") - - return document_index, sample_index, shuffle_index - - -def _get_num_tokens_per_epoch(indexed_dataset: MMapIndexedDataset, indices: numpy.ndarray) -> int: - """Calculate the number of tokens in a single epoch - - Args: - indexed_dataset (MMapIndexedDataset): The underlying MMapIndexedDataset - - indices (numpy.ndarray): The 
subset of indices into the underlying MMapIndexedDataset - - Returns: - int: The number of tokens in a single epoch - """ - return numpy.sum(indexed_dataset.sequence_lengths[indices]) - - -def _get_num_epochs(num_tokens_per_epoch: int, seq_length: int, num_samples: int) -> int: - """Calculate the number of epochs - - Args: - num_tokens_per_epoch (int): The number of tokens in a single epoch - - seq_length (int): The sequence length in tokens - - num_samples (int): The total number of samples - - Returns: - int: The number of epochs - """ - num_epochs = 0 - num_tokens = 0 - while True: - num_epochs += 1 - num_tokens += num_tokens_per_epoch - # -1 is because we need to retrieve seq_length + 1 token each time - # but the last token will overlap with the first token of the next - # sample except for the last sample. - if ((num_tokens - 1) // seq_length) >= num_samples: - return num_epochs - - -def _build_document_index( - documents: numpy.ndarray, - num_epochs: int, - numpy_random_state: numpy.random.RandomState, - separate_final_epoch: bool, -) -> numpy.ndarray: - """Build an array with length = num epochs * num documents - - Args: - documents (numpy.ndarray): the subset of exposed document indices - - num_epochs (int): The number of epochs - - numpy_random_state (numpy.random.RandomState): The NumPy random state - - separate_final_epoch (bool): Whether to exclude the last epoch from the global shuffle - - Returns: - numpy.ndarray: The document index - - TODO: Explain separate_final_epoch - """ - if not separate_final_epoch or num_epochs == 1: - document_index = numpy.mgrid[0:num_epochs, 0 : len(documents)][1] - document_index[:] = documents - document_index = document_index.reshape(-1) - document_index = document_index.astype(numpy.int32) - numpy_random_state.shuffle(document_index) - return document_index - - doc_idx_first = _build_document_index(documents, num_epochs - 1, numpy_random_state, False) - doc_idx_last = _build_document_index(documents, 1, numpy_random_state, False) - return numpy.concatenate((doc_idx_first, doc_idx_last)) - - -def _build_shuffle_index( - num_samples: int, total_size: int, numpy_random_state: numpy.random.RandomState -) -> numpy.ndarray: - """Build the range [0, size) and shuffle - - Args: - num_samples (int): The size of the first shuffle range [0, num_samples) - - total_size (int): The size of the entire index. If larger than 'num_samples', it defines - - the second shuffle range [num_samples, total_size) - - numpy_random_state (numpy.random.RandomState): The NumPy random state - - Returns: - numpy.ndarray: The shuffle index - - TODO: Explain [0, num_samples) [num_samples, total_size) split - """ - dtype_ = numpy.uint32 - if total_size >= (numpy.iinfo(numpy.uint32).max - 1): - dtype_ = numpy.int64 - - shuffle_idx_first = numpy.arange(start=0, stop=num_samples, step=1, dtype=dtype_) - numpy_random_state.shuffle(shuffle_idx_first) - if num_samples == total_size: - return shuffle_idx_first - - shuffle_idx_last = numpy.arange(start=num_samples, stop=total_size, step=1, dtype=dtype_) - numpy_random_state.shuffle(shuffle_idx_last) - - return numpy.concatenate((shuffle_idx_first, shuffle_idx_last)) diff --git a/megatron/core/datasets/helpers.cpp b/megatron/core/datasets/helpers.cpp deleted file mode 100644 index 4e1b3dbc931c4cd1fe23dc8c1e2fba029d104086..0000000000000000000000000000000000000000 --- a/megatron/core/datasets/helpers.cpp +++ /dev/null @@ -1,765 +0,0 @@ -/* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
*/ - -/* Helper methods for fast index mapping builds */ - -#include -#include -#include -#include -#include -#include -#include -#include - -namespace py = pybind11; -using namespace std; - -const int32_t LONG_SENTENCE_LEN = 512; - -void build_blending_indices(py::array_t &dataset_index, - py::array_t &dataset_sample_index, - const py::array_t &weights, - const int32_t num_datasets, - const int64_t size, const bool verbose) -{ - /* Given multiple datasets and a weighting array, build samples - such that it follows those wieghts.*/ - - if (verbose) - { - std::cout << "> building indices for blended datasets ..." << std::endl; - } - - // Get the pointer access without the checks. - auto dataset_index_ptr = dataset_index.mutable_unchecked<1>(); - auto dataset_sample_index_ptr = dataset_sample_index.mutable_unchecked<1>(); - auto weights_ptr = weights.unchecked<1>(); - - // Initialize buffer for number of samples used for each dataset. - int64_t current_samples[num_datasets]; - for (int64_t i = 0; i < num_datasets; ++i) - { - current_samples[i] = 0; - } - - // For each sample: - for (int64_t sample_idx = 0; sample_idx < size; ++sample_idx) - { - - // Determine where the max error in sampling is happening. - auto sample_idx_double = std::max(static_cast(sample_idx), 1.0); - int64_t max_error_index = 0; - double max_error = weights_ptr[0] * sample_idx_double - - static_cast(current_samples[0]); - for (int64_t dataset_idx = 1; dataset_idx < num_datasets; ++dataset_idx) - { - double error = weights_ptr[dataset_idx] * sample_idx_double - - static_cast(current_samples[dataset_idx]); - if (error > max_error) - { - max_error = error; - max_error_index = dataset_idx; - } - } - - // Populate the indices. - dataset_index_ptr[sample_idx] = static_cast(max_error_index); - dataset_sample_index_ptr[sample_idx] = current_samples[max_error_index]; - - // Update the total samples. - current_samples[max_error_index] += 1; - } - - // print info - if (verbose) - { - std::cout << " > sample ratios:" << std::endl; - for (int64_t dataset_idx = 0; dataset_idx < num_datasets; ++dataset_idx) - { - auto ratio = static_cast(current_samples[dataset_idx]) / - static_cast(size); - std::cout << " dataset " << dataset_idx << ", input: " << weights_ptr[dataset_idx] << ", achieved: " << ratio << std::endl; - } - } -} - -py::array build_sample_idx(const py::array_t &sizes_, - const py::array_t &doc_idx_, - const int32_t seq_length, - const int32_t num_epochs, - const int64_t tokens_per_epoch) -{ - /* Sample index (sample_idx) is used for gpt2 like dataset for which - the documents are flattened and the samples are built based on this - 1-D flatten array. It is a 2D array with sizes [number-of-samples + 1, 2] - where [..., 0] contains the index into `doc_idx` and [..., 1] is the - starting offset in that document.*/ - - // Consistency checks. - assert(seq_length > 1); - assert(num_epochs > 0); - assert(tokens_per_epoch > 1); - - // Remove bound checks. - auto sizes = sizes_.unchecked<1>(); - auto doc_idx = doc_idx_.unchecked<1>(); - - // Mapping and it's length (1D). - int64_t num_samples = (num_epochs * tokens_per_epoch - 1) / seq_length; - int32_t *sample_idx = new int32_t[2 * (num_samples + 1)]; - - // Index into sample_idx. - int64_t sample_index = 0; - // Index into doc_idx. - int64_t doc_idx_index = 0; - // Begining offset for each document. - int32_t doc_offset = 0; - // Start with first document and no offset. 
- sample_idx[2 * sample_index] = doc_idx_index; - sample_idx[2 * sample_index + 1] = doc_offset; - ++sample_index; - - while (sample_index <= num_samples) - { - // Start with a fresh sequence. - int32_t remaining_seq_length = seq_length + 1; - while (remaining_seq_length != 0) - { - // Get the document length. - auto doc_id = doc_idx[doc_idx_index]; - auto doc_length = sizes[doc_id] - doc_offset; - // And add it to the current sequence. - remaining_seq_length -= doc_length; - // If we have more than a full sequence, adjust offset and set - // remaining length to zero so we return from the while loop. - // Note that -1 here is for the same reason we have -1 in - // `_num_epochs` calculations. - if (remaining_seq_length <= 0) - { - doc_offset += (remaining_seq_length + doc_length - 1); - remaining_seq_length = 0; - } - else - { - // Otherwise, start from the begining of the next document. - ++doc_idx_index; - doc_offset = 0; - } - } - // Record the sequence. - sample_idx[2 * sample_index] = doc_idx_index; - sample_idx[2 * sample_index + 1] = doc_offset; - ++sample_index; - } - - // Method to deallocate memory. - py::capsule free_when_done(sample_idx, [](void *mem_) - { - int32_t *mem = reinterpret_cast(mem_); - delete[] mem; }); - - // Return the numpy array. - const auto byte_size = sizeof(int32_t); - return py::array(std::vector{num_samples + 1, 2}, // shape - {2 * byte_size, byte_size}, // C-style contiguous strides - sample_idx, // the data pointer - free_when_done); // numpy array references -} - -inline int32_t get_target_sample_len(const int32_t short_seq_ratio, - const int32_t max_length, - std::mt19937 &rand32_gen) -{ - /* Training sample length. */ - if (short_seq_ratio == 0) - { - return max_length; - } - const auto random_number = rand32_gen(); - if ((random_number % short_seq_ratio) == 0) - { - return 2 + random_number % (max_length - 1); - } - return max_length; -} - -template -py::array build_mapping_impl(const py::array_t &docs_, - const py::array_t &sizes_, - const int32_t num_epochs, - const uint64_t max_num_samples, - const int32_t max_seq_length, - const double short_seq_prob, - const int32_t seed, - const bool verbose, - const int32_t min_num_sent) -{ - /* Build a mapping of (start-index, end-index, sequence-length) where - start and end index are the indices of the sentences in the sample - and sequence-length is the target sequence length. - */ - - // Consistency checks. - assert(num_epochs > 0); - assert(max_seq_length > 1); - assert(short_seq_prob >= 0.0); - assert(short_seq_prob <= 1.0); - assert(seed > 0); - - // Remove bound checks. - auto docs = docs_.unchecked<1>(); - auto sizes = sizes_.unchecked<1>(); - - // For efficiency, convert probability to ratio. Note: rand() generates int. 
- int32_t short_seq_ratio = 0; - if (short_seq_prob > 0) - { - short_seq_ratio = static_cast(round(1.0 / short_seq_prob)); - } - - if (verbose) - { - const auto sent_start_index = docs[0]; - const auto sent_end_index = docs[docs_.shape(0) - 1]; - const auto num_sentences = sent_end_index - sent_start_index; - cout << " using:" << endl - << std::flush; - cout << " number of documents: " << docs_.shape(0) - 1 << endl - << std::flush; - cout << " sentences range: [" << sent_start_index << ", " << sent_end_index << ")" << endl - << std::flush; - cout << " total number of sentences: " << num_sentences << endl - << std::flush; - cout << " number of epochs: " << num_epochs << endl - << std::flush; - cout << " maximum number of samples: " << max_num_samples << endl - << std::flush; - cout << " maximum sequence length: " << max_seq_length << endl - << std::flush; - cout << " short sequence probability: " << short_seq_prob << endl - << std::flush; - cout << " short sequence ration (1/prob): " << short_seq_ratio << endl - << std::flush; - cout << " seed: " << seed << endl - << std::flush; - } - - // Mapping and it's length (1D). - int64_t num_samples = -1; - DocIdx *maps = NULL; - - // Perform two iterations, in the first iteration get the size - // and allocate memory and in the second iteration populate the map. - bool second = false; - for (int32_t iteration = 0; iteration < 2; ++iteration) - { - - // Set the seed so both iterations produce the same results. - std::mt19937 rand32_gen(seed); - - // Set the flag on second iteration. - second = (iteration == 1); - - // Counters: - uint64_t empty_docs = 0; - uint64_t one_sent_docs = 0; - uint64_t long_sent_docs = 0; - - // Current map index. - uint64_t map_index = 0; - - // For each epoch: - for (int32_t epoch = 0; epoch < num_epochs; ++epoch) - { - if (map_index >= max_num_samples) - { - if (verbose && (!second)) - { - cout << " reached " << max_num_samples << " samples after " - << epoch << " epochs ..." << endl - << std::flush; - } - break; - } - // For each document: - for (int32_t doc = 0; doc < (docs.shape(0) - 1); ++doc) - { - - // Document sentences are in [sent_index_first, sent_index_last) - const auto sent_index_first = docs[doc]; - const auto sent_index_last = docs[doc + 1]; - - // At the begining of the document previous index is the - // start index. - auto prev_start_index = sent_index_first; - - // Remaining documents. - auto num_remain_sent = sent_index_last - sent_index_first; - - // Some bookkeeping - if ((epoch == 0) && (!second)) - { - if (num_remain_sent == 0) - { - ++empty_docs; - } - if (num_remain_sent == 1) - { - ++one_sent_docs; - } - } - - // Detect documents with long sentences. - bool contains_long_sentence = false; - if (num_remain_sent > 1) - { - for (auto sent_index = sent_index_first; - sent_index < sent_index_last; ++sent_index) - { - if (sizes[sent_index] > LONG_SENTENCE_LEN) - { - if ((epoch == 0) && (!second)) - { - ++long_sent_docs; - } - contains_long_sentence = true; - break; - } - } - } - - // If we have more than two sentences. - if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence)) - { - - // Set values. - auto seq_len = int32_t{0}; - auto num_sent = int32_t{0}; - auto target_seq_len = get_target_sample_len(short_seq_ratio, - max_seq_length, - rand32_gen); - - // Loop through sentences. - for (auto sent_index = sent_index_first; - sent_index < sent_index_last; ++sent_index) - { - - // Add the size and number of sentences. 
- seq_len += sizes[sent_index]; - ++num_sent; - --num_remain_sent; - - // If we have reached the target length. - // and if not only one sentence is left in the document. - // and if we have at least two sentneces. - // and if we have reached end of the document. - if (((seq_len >= target_seq_len) && - (num_remain_sent > 1) && - (num_sent >= min_num_sent)) || - (num_remain_sent == 0)) - { - - // Check for overflow. - if ((3 * map_index + 2) > - std::numeric_limits::max()) - { - cout << "number of samples exceeded maximum " - << "allowed by type int64: " - << std::numeric_limits::max() - << endl; - throw std::overflow_error("Number of samples"); - } - - // Populate the map. - if (second) - { - const auto map_index_0 = 3 * map_index; - maps[map_index_0] = static_cast(prev_start_index); - maps[map_index_0 + 1] = static_cast(sent_index + 1); - maps[map_index_0 + 2] = static_cast(target_seq_len); - } - - // Update indices / counters. - ++map_index; - prev_start_index = sent_index + 1; - target_seq_len = get_target_sample_len(short_seq_ratio, - max_seq_length, - rand32_gen); - seq_len = 0; - num_sent = 0; - } - - } // for (auto sent_index=sent_index_first; ... - } // if (num_remain_sent > 1) { - } // for (int doc=0; doc < num_docs; ++doc) { - } // for (int epoch=0; epoch < num_epochs; ++epoch) { - - if (!second) - { - if (verbose) - { - cout << " number of empty documents: " << empty_docs << endl - << std::flush; - cout << " number of documents with one sentence: " << one_sent_docs << endl - << std::flush; - cout << " number of documents with long sentences: " << long_sent_docs << endl - << std::flush; - cout << " will create mapping for " << map_index << " samples" << endl - << std::flush; - } - assert(maps == NULL); - assert(num_samples < 0); - maps = new DocIdx[3 * map_index]; - num_samples = static_cast(map_index); - } - - } // for (int iteration=0; iteration < 2; ++iteration) { - - // Shuffle. - // We need a 64 bit random number generator as we might have more - // than 2 billion samples. - std::mt19937_64 rand64_gen(seed + 1); - for (auto i = (num_samples - 1); i > 0; --i) - { - const auto j = static_cast(rand64_gen() % (i + 1)); - const auto i0 = 3 * i; - const auto j0 = 3 * j; - // Swap values. - swap(maps[i0], maps[j0]); - swap(maps[i0 + 1], maps[j0 + 1]); - swap(maps[i0 + 2], maps[j0 + 2]); - } - - // Method to deallocate memory. - py::capsule free_when_done(maps, [](void *mem_) - { - DocIdx *mem = reinterpret_cast(mem_); - delete[] mem; }); - - // Return the numpy array. - const auto byte_size = sizeof(DocIdx); - return py::array(std::vector{num_samples, 3}, // shape - {3 * byte_size, byte_size}, // C-style contiguous strides - maps, // the data pointer - free_when_done); // numpy array references -} - -py::array build_mapping(const py::array_t &docs_, - const py::array_t &sizes_, - const int num_epochs, - const uint64_t max_num_samples, - const int max_seq_length, - const double short_seq_prob, - const int seed, - const bool verbose, - const int32_t min_num_sent) -{ - - if (sizes_.size() > std::numeric_limits::max()) - { - if (verbose) - { - cout << " using uint64 for data mapping..." << endl - << std::flush; - } - return build_mapping_impl(docs_, sizes_, num_epochs, - max_num_samples, max_seq_length, - short_seq_prob, seed, verbose, - min_num_sent); - } - else - { - if (verbose) - { - cout << " using uint32 for data mapping..." 
<< endl - << std::flush; - } - return build_mapping_impl(docs_, sizes_, num_epochs, - max_num_samples, max_seq_length, - short_seq_prob, seed, verbose, - min_num_sent); - } -} - -template -py::array build_blocks_mapping_impl(const py::array_t &docs_, - const py::array_t &sizes_, - const py::array_t &titles_sizes_, - const int32_t num_epochs, - const uint64_t max_num_samples, - const int32_t max_seq_length, - const int32_t seed, - const bool verbose, - const bool use_one_sent_blocks) -{ - /* Build a mapping of (start-index, end-index, sequence-length) where - start and end index are the indices of the sentences in the sample - and sequence-length is the target sequence length. - */ - - // Consistency checks. - assert(num_epochs > 0); - assert(max_seq_length > 1); - assert(seed > 0); - - // Remove bound checks. - auto docs = docs_.unchecked<1>(); - auto sizes = sizes_.unchecked<1>(); - auto titles_sizes = titles_sizes_.unchecked<1>(); - - if (verbose) - { - const auto sent_start_index = docs[0]; - const auto sent_end_index = docs[docs_.shape(0) - 1]; - const auto num_sentences = sent_end_index - sent_start_index; - cout << " using:" << endl - << std::flush; - cout << " number of documents: " << docs_.shape(0) - 1 << endl - << std::flush; - cout << " sentences range: [" << sent_start_index << ", " << sent_end_index << ")" << endl - << std::flush; - cout << " total number of sentences: " << num_sentences << endl - << std::flush; - cout << " number of epochs: " << num_epochs << endl - << std::flush; - cout << " maximum number of samples: " << max_num_samples << endl - << std::flush; - cout << " maximum sequence length: " << max_seq_length << endl - << std::flush; - cout << " seed: " << seed << endl - << std::flush; - } - - // Mapping and its length (1D). - int64_t num_samples = -1; - DocIdx *maps = NULL; - - // Acceptable number of sentences per block. - int min_num_sent = 2; - if (use_one_sent_blocks) - { - min_num_sent = 1; - } - - // Perform two iterations, in the first iteration get the size - // and allocate memory and in the second iteration populate the map. - bool second = false; - for (int32_t iteration = 0; iteration < 2; ++iteration) - { - - // Set the flag on second iteration. - second = (iteration == 1); - - // Current map index. - uint64_t map_index = 0; - - uint64_t empty_docs = 0; - uint64_t one_sent_docs = 0; - uint64_t long_sent_docs = 0; - // For each epoch: - for (int32_t epoch = 0; epoch < num_epochs; ++epoch) - { - // assign every block a unique id - int32_t block_id = 0; - - if (map_index >= max_num_samples) - { - if (verbose && (!second)) - { - cout << " reached " << max_num_samples << " samples after " - << epoch << " epochs ..." << endl - << std::flush; - } - break; - } - // For each document: - for (int32_t doc = 0; doc < (docs.shape(0) - 1); ++doc) - { - - // Document sentences are in [sent_index_first, sent_index_last) - const auto sent_index_first = docs[doc]; - const auto sent_index_last = docs[doc + 1]; - const auto target_seq_len = max_seq_length - titles_sizes[doc]; - - // At the begining of the document previous index is the - // start index. - auto prev_start_index = sent_index_first; - - // Remaining documents. - auto num_remain_sent = sent_index_last - sent_index_first; - - // Some bookkeeping - if ((epoch == 0) && (!second)) - { - if (num_remain_sent == 0) - { - ++empty_docs; - } - if (num_remain_sent == 1) - { - ++one_sent_docs; - } - } - // Detect documents with long sentences. 
- bool contains_long_sentence = false; - if (num_remain_sent >= min_num_sent) - { - for (auto sent_index = sent_index_first; - sent_index < sent_index_last; ++sent_index) - { - if (sizes[sent_index] > LONG_SENTENCE_LEN) - { - if ((epoch == 0) && (!second)) - { - ++long_sent_docs; - } - contains_long_sentence = true; - break; - } - } - } - // If we have enough sentences and no long sentences. - if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence)) - { - - // Set values. - auto seq_len = int32_t{0}; - auto num_sent = int32_t{0}; - - // Loop through sentences. - for (auto sent_index = sent_index_first; - sent_index < sent_index_last; ++sent_index) - { - - // Add the size and number of sentences. - seq_len += sizes[sent_index]; - ++num_sent; - --num_remain_sent; - - // If we have reached the target length. - // and there are an acceptable number of sentences left - // and if we have at least the minimum number of sentences. - // or if we have reached end of the document. - if (((seq_len >= target_seq_len) && - (num_remain_sent >= min_num_sent) && - (num_sent >= min_num_sent)) || - (num_remain_sent == 0)) - { - - // Populate the map. - if (second) - { - const auto map_index_0 = 4 * map_index; - // Each sample has 4 items: the starting sentence index, ending sentence index, - // the index of the document from which the block comes (used for fetching titles) - // and the unique id of the block (used for creating block indexes) - - maps[map_index_0] = static_cast(prev_start_index); - maps[map_index_0 + 1] = static_cast(sent_index + 1); - maps[map_index_0 + 2] = static_cast(doc); - maps[map_index_0 + 3] = static_cast(block_id); - } - - // Update indices / counters. - ++map_index; - ++block_id; - prev_start_index = sent_index + 1; - seq_len = 0; - num_sent = 0; - } - } // for (auto sent_index=sent_index_first; ... - } // if (num_remain_sent > 1) { - } // for (int doc=0; doc < num_docs; ++doc) { - } // for (int epoch=0; epoch < num_epochs; ++epoch) { - - if (!second) - { - if (verbose) - { - cout << " number of empty documents: " << empty_docs << endl - << std::flush; - cout << " number of documents with one sentence: " << one_sent_docs << endl - << std::flush; - cout << " number of documents with long sentences: " << long_sent_docs << endl - << std::flush; - cout << " will create mapping for " << map_index << " samples" << endl - << std::flush; - } - assert(maps == NULL); - assert(num_samples < 0); - maps = new DocIdx[4 * map_index]; - num_samples = static_cast(map_index); - } - - } // for (int iteration=0; iteration < 2; ++iteration) { - - // Shuffle. - // We need a 64 bit random number generator as we might have more - // than 2 billion samples. - std::mt19937_64 rand64_gen(seed + 1); - for (auto i = (num_samples - 1); i > 0; --i) - { - const auto j = static_cast(rand64_gen() % (i + 1)); - const auto i0 = 4 * i; - const auto j0 = 4 * j; - // Swap values. - swap(maps[i0], maps[j0]); - swap(maps[i0 + 1], maps[j0 + 1]); - swap(maps[i0 + 2], maps[j0 + 2]); - swap(maps[i0 + 3], maps[j0 + 3]); - } - - // Method to deallocate memory. - py::capsule free_when_done(maps, [](void *mem_) - { - DocIdx *mem = reinterpret_cast(mem_); - delete[] mem; }); - - // Return the numpy array. 
- const auto byte_size = sizeof(DocIdx); - return py::array(std::vector{num_samples, 4}, // shape - {4 * byte_size, byte_size}, // C-style contiguous strides - maps, // the data pointer - free_when_done); // numpy array references -} - -py::array build_blocks_mapping(const py::array_t &docs_, - const py::array_t &sizes_, - const py::array_t &titles_sizes_, - const int num_epochs, - const uint64_t max_num_samples, - const int max_seq_length, - const int seed, - const bool verbose, - const bool use_one_sent_blocks) -{ - - if (sizes_.size() > std::numeric_limits::max()) - { - if (verbose) - { - cout << " using uint64 for data mapping..." << endl - << std::flush; - } - return build_blocks_mapping_impl(docs_, sizes_, titles_sizes_, - num_epochs, max_num_samples, max_seq_length, seed, verbose, use_one_sent_blocks); - } - else - { - if (verbose) - { - cout << " using uint32 for data mapping..." << endl - << std::flush; - } - return build_blocks_mapping_impl(docs_, sizes_, titles_sizes_, - num_epochs, max_num_samples, max_seq_length, seed, verbose, use_one_sent_blocks); - } -} - -PYBIND11_MODULE(helpers, m) -{ - m.def("build_mapping", &build_mapping); - m.def("build_blocks_mapping", &build_blocks_mapping); - m.def("build_sample_idx", &build_sample_idx); - m.def("build_blending_indices", &build_blending_indices); -} diff --git a/megatron/core/datasets/indexed_dataset.py b/megatron/core/datasets/indexed_dataset.py deleted file mode 100644 index cd62160ceab7a662d8fff1d597712f90ea9e7aba..0000000000000000000000000000000000000000 --- a/megatron/core/datasets/indexed_dataset.py +++ /dev/null @@ -1,639 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. 
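The `build_sample_idx` helper removed just above is the core of the GPT sample mapping: it walks the flattened token stream, drawing `seq_length + 1` tokens per sample, and records for each sample the position in the document index and the in-document offset where that sample begins. A simplified Python rendition of the same walk, intended only as a readable stand-in for the C++/pybind11 version; the toy inputs in the usage lines are made up for illustration:

```python
import numpy

def build_sample_idx(sizes, doc_idx, seq_length, num_epochs, tokens_per_epoch):
    """Return an array of shape [num_samples + 1, 2] holding
    (position in doc_idx, offset within that document) per sample boundary."""
    num_samples = (num_epochs * tokens_per_epoch - 1) // seq_length
    sample_idx = numpy.zeros((num_samples + 1, 2), dtype=numpy.int32)
    doc_pos, doc_offset = 0, 0
    sample_idx[0] = (doc_pos, doc_offset)
    for i in range(1, num_samples + 1):
        remaining = seq_length + 1  # each sample needs seq_length + 1 tokens
        while remaining > 0:
            doc_length = sizes[doc_idx[doc_pos]] - doc_offset
            remaining -= doc_length
            if remaining <= 0:
                # The sample ends inside this document; keep the leftover as the new offset.
                doc_offset += remaining + doc_length - 1
                remaining = 0
            else:
                # Move on to the next document and start from its beginning.
                doc_pos += 1
                doc_offset = 0
        sample_idx[i] = (doc_pos, doc_offset)
    return sample_idx

# Two documents of 5 and 7 tokens, one epoch, samples of length 4:
print(build_sample_idx(numpy.array([5, 7]), numpy.array([0, 1]), 4, 1, 12))
# boundaries: (doc 0, offset 0), (doc 0, offset 4), (doc 1, offset 3)
```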
- -# Essentially re-written in entirety - -import logging -import os -import shutil -import struct -import time -from enum import Enum -from functools import lru_cache -from itertools import accumulate -from types import TracebackType -from typing import List, Optional, Tuple, Type, Union - -import numpy -import torch - -from megatron.core.datasets.utils import log_single_rank - -logger = logging.getLogger(__name__) - -_INDEX_HEADER = b"MMIDIDX\x00\x00" - - -class DType(Enum): - """The NumPy data type Enum for writing/reading the MMapIndexedDataset indices - """ - - uint8 = 1 - int8 = 2 - int16 = 3 - int32 = 4 - int64 = 5 - float64 = 6 - float32 = 7 - uint16 = 8 - - @classmethod - def code_from_dtype(cls, value: Type[numpy.number]) -> int: - """Get the code from the dtype - - Args: - value (Type[numpy.number]): The dtype - - Returns: - int: The code - """ - return cls[value.__name__].value - - @classmethod - def dtype_from_code(cls, value: int) -> Type[numpy.number]: - """Get the dtype from the code - - Args: - value (int): The code - - Returns: - Type[numpy.number]: The dtype - """ - return getattr(numpy, cls(value).name) - - @staticmethod - def size(key: Union[int, Type[numpy.number]]) -> int: - """Get the size of the dtype/code in bytes - - Args: - key (Union[int, Type[numpy.number]]): The dtype or code - - Raises: - ValueError: If the key is neither dtype nor integer code - - Returns: - int: The size of the dtype/code in in bytes - """ - if isinstance(key, int): - return DType.dtype_from_code(key)().itemsize - elif numpy.number in key.__mro__: - return key().itemsize - else: - raise ValueError - - @staticmethod - def optimal_dtype(cardinality: Optional[int]) -> Type[numpy.number]: - """Get the dtype to use for an index of a certain cardinality - - Args: - cardinality (Optional[int]): The number of elements to be indexed - - Returns: - Type[numpy.number]: The dtype to use for the index - """ - if cardinality is not None and cardinality < 65500: - return numpy.uint16 - else: - return numpy.int32 - - -class _IndexWriter(object): - """Object class to write the index (.idx) file - - Args: - idx_path (str): The path to the index file - - dtype (Type[numpy.number]): The dtype of the index file - """ - - def __init__(self, idx_path: str, dtype: Type[numpy.number]) -> None: - self.idx_path = idx_path - self.dtype = dtype - - def __enter__(self) -> "_IndexWriter": - """Enter the context introduced by the 'with' keyword - - Returns: - _IndexWriter: The instance - """ - self.idx_writer = open(self.idx_path, "wb") - # fixed, vestigial practice - self.idx_writer.write(_INDEX_HEADER) - # fixed, vestigial practice - self.idx_writer.write(struct.pack(" Optional[bool]: - """Exit the context introduced by the 'with' keyword - - Args: - exc_type (Optional[Type[BaseException]]): Exception type - - exc_val (Optional[BaseException]): Exception value - - exc_tb (Optional[TracebackType]): Exception traceback object - - Returns: - Optional[bool]: Whether to silence the exception - """ - self.idx_writer.close() - - def write( - self, - sequence_lengths: List[int], - sequence_modes: Optional[List[int]], - document_indices: List[int], - ) -> None: - """Write the index (.idx) file - - Args: - sequence_lengths (List[int]): The length of each sequence - - sequence_modes (Optional[List[int]]): The mode of each sequences - - document_indices (List[int]): The seqyebce indices demarcating the end of each document - """ - sequence_pointers = self._sequence_pointers(sequence_lengths) - - # the number of sequences in the 
dataset - sequence_count = len(sequence_lengths) - self.idx_writer.write(struct.pack(" List[int]: - """Build the sequence pointers per the sequence lengths and dtype size - - Args: - sequence_lengths (List[int]): The length of each sequence - - Returns: - List[int]: The pointer to the beginning of each sequence - """ - itemsize = DType.size(self.dtype) - curr_ptr = 0 - list_ptr = [] - for length in sequence_lengths: - list_ptr.append(curr_ptr) - curr_ptr += length * itemsize - return list_ptr - - -class _IndexReader(object): - """Object class to read the index (.idx) file - - Args: - idx_path (str): The path to the index file - - multimodal (bool): Whether the dataset is multimodal - """ - - def __init__(self, idx_path: str, multimodal: bool) -> None: - - log_single_rank(logger, logging.INFO, f"Load the {type(self).__name__} from {idx_path}") - - with open(idx_path, "rb") as stream: - header = stream.read(9) - assert header == _INDEX_HEADER, f"bad header, cannot read: {idx_path}" - - version = struct.unpack(" time elapsed: {t_end - t_beg:4f} seconds") - - log_single_rank(logger, logging.INFO, f"\tExtract the sequence pointers") - t_beg = time.time() - self.sequence_pointers = numpy.frombuffer( - self.bin_buffer, - dtype=numpy.int64, - count=self.sequence_count, - offset=offset + self.sequence_lengths.nbytes, - ) - t_end = time.time() - log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds") - - log_single_rank(logger, logging.INFO, f"\tExtract the document indices") - t_beg = time.time() - self.document_indices = numpy.frombuffer( - self.bin_buffer, - dtype=numpy.int64, - count=self.document_count, - offset=offset + self.sequence_lengths.nbytes + self.sequence_pointers.nbytes, - ) - t_end = time.time() - log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds") - - self.sequence_modes = None - if multimodal: - log_single_rank(logger, logging.INFO, f"\tExtract the sequence modes") - t_beg = time.time() - self.sequence_modes = numpy.frombuffer( - self.bin_buffer, - dtype=numpy.int8, - count=self.sequence_count, - offset=offset - + self.sequence_lengths.nbytes - + self.sequence_pointers.nbytes - + self.document_indices.nbytes, - ) - t_end = time.time() - log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds") - - assert self.sequence_lengths.shape[0] == len(self) - assert self.sequence_lengths.shape[0] == self.sequence_count - assert self.sequence_lengths.shape[0] == self.document_indices[-1] - - log_single_rank(logger, logging.INFO, f"> total number of sequences: {len(self)}") - log_single_rank( - logger, - logging.INFO, - f"> total number of documents: {self.document_indices.shape[0] - 1}", - ) - - def __del__(self) -> None: - """Clean up the object - """ - self.bin_buffer_mmap._mmap.close() - del self.bin_buffer_mmap - - def __len__(self) -> int: - """Return the length of the dataset - - Returns: - int: The length of the dataset - """ - return self.sequence_count - - @lru_cache(maxsize=8) - def __getitem__(self, idx: int) -> Tuple[numpy.int32, numpy.int64, Optional[numpy.int8]]: - """Return the pointer, length, and mode at the index - - Args: - idx (int): The index into the dataset - - Returns: - Tuple[numpy.int32, numpy.int64, Optional[numpy.int8]]: The pointer, length and mode at - the index - """ - return ( - self.sequence_pointers[idx], - self.sequence_lengths[idx], - self.sequence_modes[idx] if self.sequence_modes is not None else None, - ) - - -class 
MMapIndexedDataset(torch.utils.data.Dataset): - """The low-level interface dataset class - - Args: - path_prefix (str): The index (.idx) and data (.bin) prefix - - multimodal (bool, optional): Whether the dataset is multimodal. Defaults to False. - """ - - def __init__(self, path_prefix: str, multimodal: bool = False) -> None: - super().__init__() - self.path_prefix = None - self.multimodal = None - - self.index = None - self.bin_buffer = None - self.bin_buffer_mmap = None - - self.initialize(path_prefix, multimodal) - - def initialize(self, path_prefix: str, multimodal: bool) -> None: - """Initialize the dataset - - This method is called by MMapIndexedDataset.__init__ during object creation and by - MMapIndexedDataset.__setstate__ during un-puckling - - Args: - path_prefix (str): The index (.idx) and data (.bin) prefix - - multimodal (bool): Whether the dataset is multimodal - """ - self.path_prefix = path_prefix - self.multimodal = multimodal - self.index = _IndexReader(get_idx_path(self.path_prefix), self.multimodal) - self.bin_buffer_mmap = numpy.memmap(get_bin_path(self.path_prefix), mode="r", order="C") - self.bin_buffer = memoryview(self.bin_buffer_mmap) - - def __getstate__(self) -> Tuple[str, bool]: - """Get the state during pickling - - Returns: - Tuple[str, bool]: The state tuple - """ - return self.path_prefix, self.multimodal - - def __setstate__(self, state: Tuple[str, bool]) -> None: - """Set the state during un-pickling - - Args: - state (Tuple[str, bool]): The state tuple - """ - path_prefix, multimodal = state - self.initialize(path_prefix, multimodal) - - def __del__(self) -> None: - """Clean up the object - """ - if self.bin_buffer_mmap is not None: - self.bin_buffer_mmap._mmap.close() - del self.bin_buffer_mmap - del self.index - - def __len__(self) -> int: - """Return the length of the dataset i.e. 
the number of sequences in the index - - Returns: - int: The length of the dataset - """ - return len(self.index) - - def __getitem__( - self, idx: Union[int, numpy.integer, slice] - ) -> Union[numpy.ndarray, Tuple[numpy.ndarray, numpy.ndarray]]: - """Return from the dataset - - Args: - idx (Union[int, numpy.integer, slice]): The index or index slice into the dataset - - Raises: - ValueError: When the index slice is non-contiguous - - TypeError: When the index is of an unexpected type - - Returns: - Union[numpy.ndarray, Tuple[numpy.ndarray, numpy.ndarray]]: The sequence tokens and - modes at the index or index slice - """ - if isinstance(idx, (int, numpy.integer)): - sequence_pointer, sequence_length, sequence_mode = self.index[idx] - sequence = numpy.frombuffer( - self.bin_buffer, - dtype=self.index.dtype, - count=sequence_length, - offset=sequence_pointer, - ) - return (sequence, sequence_mode) if sequence_mode is not None else sequence - elif isinstance(idx, slice): - start, stop, step = idx.indices(len(self)) - if step != 1: - raise ValueError("Slices into indexed_dataset must be contiguous") - sequence_lengths = self.index.sequence_lengths[idx] - sequence_modes = self.index.sequence_modes[idx] if self.multimodal else None - sequence_offsets = list(accumulate(sequence_lengths)) - sequences = numpy.split( - numpy.frombuffer( - self.bin_buffer, - dtype=self.index.dtype, - count=sum(sequence_lengths), - offset=self.index.sequence_pointers[start], - ), - sequence_offsets[:-1], - ) - return (sequences, sequence_modes) if sequence_modes is not None else sequences - else: - raise TypeError("Unexpected type received for idx: {}".format(type(idx))) - - def get(self, idx: int, offset: int = 0, length: Optional[int] = None) -> numpy.ndarray: - """Retrieve a single item from the dataset with the option to only - return a portion of the item. - - get(idx) is the same as [idx] but get() does not support slicing. - """ - sequence_pointer, sequence_length, sequence_mode = self.index[idx] - if length is None: - length = sequence_length - offset - sequence_pointer += offset * DType.size(self.index.dtype) - sequence = numpy.frombuffer( - self.bin_buffer, dtype=self.index.dtype, count=length, offset=sequence_pointer - ) - return (sequence, sequence_mode) if sequence_mode is not None else sequence - - @property - def sequence_lengths(self) -> numpy.ndarray: - """Get the sequence lengths - - Returns: - numpy.ndarray: The sequence lengths - """ - return self.index.sequence_lengths - - @property - def document_indices(self) -> numpy.ndarray: - """Get the document indices - - Returns: - numpy.ndarray: The document indices - """ - return self.index.document_indices - - def get_document_indices(self) -> numpy.ndarray: - """Get the document indices - - This method is slated for deprecation. - - Returns: - numpy.ndarray: The document indices - """ - return self.index.document_indices - - def set_document_indices(self, document_indices: numpy.ndarray) -> None: - """Set the document indices - - This method is slated for deprecation. 
- - Args: - document_indices (numpy.ndarray): The document indices - """ - self.index.document_indices = document_indices - - @property - def sequence_modes(self) -> numpy.ndarray: - """Get the sequence modes - - Returns: - numpy.ndarray: The sequence modes - """ - return self.index.sequence_modes - - @staticmethod - def exists(path_prefix: str) -> bool: - """Return whether the MMapIndexedDataset exists on disk at the prefix - - Args: - path_prefix (str): The prefix to the index (.idx) and data (.bin) files - - Returns: - bool: Whether the MMapIndexedDataset exists on disk at the prefix - """ - return os.path.exists(get_idx_path(path_prefix)) and os.path.exists( - get_bin_path(path_prefix) - ) - - -class MMapIndexedDatasetBuilder(object): - """Builder class for the MMapIndexedDataset class - - Args: - bin_path (str): The path to the data (.bin) file - - dtype (Type[numpy.number], optional): The dtype of the index file. Defaults to numpy.int32. - - multimodal (bool, optional): Whether the dataset is multimodal. Defaults to False. - """ - - def __init__( - self, bin_path: str, dtype: Type[numpy.number] = numpy.int32, multimodal: bool = False - ) -> None: - self.data_file = open(bin_path, "wb") - self.dtype = dtype - self.multimodal = multimodal - - self.sequence_lengths = [] - self.document_indices = [0] - self.sequence_modes = [] if self.multimodal else None - - def add_item(self, tensor: torch.Tensor, mode: int = 0) -> None: - """Add a single item to the dataset - - Args: - tensor (torch.Tensor): The item to add to the data file - - mode (int, optional): The mode for the item. Defaults to 0. - """ - np_array = numpy.array(tensor.numpy(), dtype=self.dtype) - self.data_file.write(np_array.tobytes(order="C")) - self.sequence_lengths.append(np_array.size) - if self.multimodal: - self.sequence_modes.append(mode) - - def add_document( - self, tensor: torch.Tensor, lengths: List[int], modes: Optional[List[int]] = None - ) -> None: - """Add an entire document to the dataset - - Args: - tensor (torch.Tensor): The document to add - lengths (List[int]): The lengths of each item in the document - modes (Optional[List[int]], optional): The modes for each item in the document. - Defaults to None. 
- """ - np_array = numpy.array(tensor, dtype=self.dtype) - self.data_file.write(np_array.tobytes(order="C")) - self.sequence_lengths.extend(lengths) - self.document_indices.append(len(self.sequence_lengths)) - if self.multimodal: - self.sequence_modes.extend(modes if modes is not None else [0] * lengths) - - def end_document(self) -> None: - """Finalize the document, for use with MMapIndexedDatasetBuilder.add_item - """ - self.document_indices.append(len(self.sequence_lengths)) - - def add_index(self, path_prefix: str) -> None: - """Add an entire MMapIndexedDataset to the dataset - - Args: - path_prefix (str): The index (.idx) and data (.bin) prefix - """ - # Concatenate index - index = _IndexReader(get_idx_path(path_prefix), multimodal=self.multimodal) - assert index.dtype == self.dtype - - offset = len(self.sequence_lengths) - self.sequence_lengths.extend(index.sequence_lengths) - self.document_indices.extend((offset + index.document_indices)[1:]) - - if self.multimodal: - self.sequence_modes.extend(index.sequence_modes) - - # Concatenate data - with open(get_bin_path(path_prefix), "rb") as f: - shutil.copyfileobj(f, self.data_file) - - def finalize(self, idx_path: str) -> None: - """Clean up and write the index (.idx) file - - Args: - idx_path (str): The path to the index file - """ - self.data_file.close() - with _IndexWriter(idx_path, self.dtype) as writer: - writer.write(self.sequence_lengths, self.sequence_modes, self.document_indices) - - -def get_idx_path(path_prefix: str) -> str: - """Get the path to the index file from the prefix - - Args: - path_prefix (str): The prefix - - Returns: - str: The path to the index file - """ - return path_prefix + ".idx" - - -def get_bin_path(path_prefix: str) -> str: - """Get the path to the data file from the prefix - - Args: - path_prefix (str): The prefix - - Returns: - str: The path to the data file - """ - return path_prefix + ".bin" diff --git a/megatron/core/datasets/megatron_dataset.py b/megatron/core/datasets/megatron_dataset.py deleted file mode 100644 index d75a6455099916128b67fefe1062f028c5071add..0000000000000000000000000000000000000000 --- a/megatron/core/datasets/megatron_dataset.py +++ /dev/null @@ -1,135 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. - -import hashlib -import json -from abc import ABC, abstractmethod, abstractstaticmethod -from collections import OrderedDict -from typing import Dict, List - -import numpy -import torch - -from megatron.core.datasets.blended_megatron_dataset_config import BlendedMegatronDatasetConfig -from megatron.core.datasets.indexed_dataset import MMapIndexedDataset -from megatron.core.datasets.utils import Split - - -class MegatronDataset(ABC, torch.utils.data.Dataset): - """The wrapper class from which dataset classes should inherit e.g. 
GPTDataset - - Args: - indexed_dataset (MMapIndexedDataset): The MMapIndexedDataset around which to build the - MegatronDataset - - indexed_indices (numpy.ndarray): The set of the documents indices to expose - - num_samples (int): The number of samples to draw from the indexed dataset - - index_split (Split): The indexed_indices Split - - config (BlendedMegatronDatasetConfig): The container for all config sourced parameters - """ - - def __init__( - self, - indexed_dataset: MMapIndexedDataset, - indexed_indices: numpy.ndarray, - num_samples: int, - index_split: Split, - config: BlendedMegatronDatasetConfig, - ) -> None: - assert indexed_indices.size > 0 - assert num_samples > 0 - assert self.is_multimodal() == indexed_dataset.multimodal - assert self.is_split_by_sequence() != self.is_split_by_document() - - self.indexed_dataset = indexed_dataset - self.indexed_indices = indexed_indices - self.num_samples = num_samples - self.index_split = index_split - self.config = config - - self.unique_identifiers = OrderedDict() - self.unique_identifiers["class"] = type(self).__name__ - self.unique_identifiers["path_prefix"] = self.indexed_dataset.path_prefix - self.unique_identifiers["num_samples"] = self.num_samples - self.unique_identifiers["index_split"] = self.index_split.name - for attr in self._key_config_attributes(): - self.unique_identifiers[attr] = getattr(self.config, attr) - - self.unique_description = json.dumps(self.unique_identifiers, indent=4) - self.unique_description_hash = hashlib.md5( - self.unique_description.encode("utf-8") - ).hexdigest() - - self._finalize() - - @abstractmethod - def _finalize(self) -> None: - """Build the dataset and assert any subclass-specific conditions - """ - pass - - @abstractmethod - def __len__(self) -> int: - """Return the length of the dataset - - Returns: - int: See abstract implementation - """ - pass - - @abstractmethod - def __getitem__(self, idx: int) -> Dict[str, numpy.ndarray]: - """Return from the dataset - - Args: - idx (int): The index into the dataset - - Returns: - Dict[str, numpy.ndarray]: See abstract implementation - """ - pass - - @abstractstaticmethod - def is_multimodal() -> bool: - """Return True if the inheritor class and its internal MMapIndexedDataset are multimodal - - Returns: - bool: See abstract implementation - """ - pass - - @abstractstaticmethod - def is_split_by_sequence() -> bool: - """Return whether the dataset is split by sequence - - For example, the GPT train/valid/test split is document agnostic - - Returns: - bool: See abstract implementation - """ - pass - - @classmethod - def is_split_by_document(cls) -> bool: - """Return whether the dataset is split by document - - For example, the BERT train/valid/test split is document aware - - Returns: - bool: The negation of cls.is_split_by_sequence - """ - return not cls.is_split_by_sequence() - - @staticmethod - def _key_config_attributes() -> List[str]: - """Return all config attributes which contribute to uniquely identifying the dataset. - - These attributes will be used to build a uniquely identifying string and MD5 hash which - will be used to cache/load the dataset from run to run. 
- - Returns: - List[str]: The key config attributes - """ - return ["split", "random_seed", "sequence_length"] diff --git a/megatron/core/datasets/readme.md b/megatron/core/datasets/readme.md deleted file mode 100644 index 77d1e5862f54a9c224d1c4f655883e1b877616f5..0000000000000000000000000000000000000000 --- a/megatron/core/datasets/readme.md +++ /dev/null @@ -1,193 +0,0 @@ -# Data Pipeline - -## Data pre-processing - -Data preprocessing is built around the following classes: - -1. `MMapIndexedDatasetBuilder` -2. `MMapIndexedDataset` - -At the moment, an end-to-end data preprocessing implementation is left to the user. See the class docstring(s) for more details. - -#### MMapIndexedDatasetBuilder - -The `MMapIndexedDatasetBuilder` is capable of building and merging `MMapIndexedDataset` instances. - -#### MMapIndexedDataset - -The `MMapIndexedDataset` class is the lowest-level data interface in Megatron Core. Internally, an `MMapIndexedDataset` instance references two binaries: the data file (`.bin`) contains document/sequence data and the index file (`.idx`) contains document/sequence metadata. - -The index file stores dataset-level metadata first: -- The index header, for backward compatibility -- The index version, for backward compatibility -- A numeric code corresponding to the data type used to write data to the data file -- The number of sequences in the dataset -- The number of documents in the dataset - -The index file stores document-level and sequence-level metadata second: -- In order, the number of elements per sequence -- In order, the byte offset (pointer) per sequence -- In order, the consecutive sequence index range `[...)` per document -- In order, the mode per sequence (in the multimodal case) - -## Data loading: construction - -Building the data loaders is a distributed-aware process built around the following classes: - -1. `BlendedMegatronDatasetConfig` -2. `BlendedMegatronDatasetBuilder` -3. `MMapIndexedDataset` -3. `MegatronDataset` -4. `BlendedDataset` - -See the class docstrings for more details. - -#### BlendedMegatronDatasetConfig (extendable) - -The `BlendedMegatronDatasetConfig` class parameterizes the `BlendedMegatronDatasetBuilder` and in turn the `MegatronDataset` and `BlendedDataset`. - -Different training/inference regimes will require different extensions e.g. the `GPTDatasetConfig` - -#### BlendedMegatronDatasetBuilder - -The `BlendedMegatronDatasetBuilder` class builds the highest-level data interfaces in Megatron Core. - -**NB:** All ranks should attempt to build the dataset via the `BlendedMegatronDatasetBuilder` or the program will hang. Which ranks follow through on their attempts can be controlled via the `BlendedMegatronDatasetConfig`. - -#### MMapIndexedDataset - -The `MMapIndexedDataset` class is the lowest-level data interface in Megatron Core. - -The `MMapIndexedDataset` should already exist on disk before attempting to build any of the high-level data interfaces. - - -#### MegatronDataset (extendable) - -The `MegatronDataset` abstract class is a high-level data interface in Megatron Core. It is an abstraction built upon the `MMapIndexedDataset`. - -Different training/inference regimes will require different extensions e.g. the `GPTDataset` - -#### BlendedDataset - -The `BlendedDataset` class is a high-level data interface in Megatron Core. It is an abstraction built upon the `MegatronDataset`. - -The `BlendedDataset` is only necessary when a blend multiple data distributions, i.e. 
multiple `MegatronDataset` instances, should contribute to a certain dataset split. The blend can be controlled via the `BlendedMegatronDatasetConfig`. - -## Data loading: implementation - -### GPTDataset - -The `GPTDataset` is parameterized by the following variables: the underlying `MMapIndexedDataset` instance `indexed_dataset`, the split indices `indexed_indices` (the congituous subset of document or sequence indices used for training, validation, and testing), the number of samples `N`, the sequence length `S`, and the random seed `R`. - -The `GPTDataset` creates three index mappings to facilitate lookup: (1) the document index, (2) the sample index, and (3) the shuffle index. - -1. The document index _Do_idx_ is a 1-D array mapping from _i_ to document index of length `E * |indexed_indices|` where `E` corresponds to the minimum number of epochs such that `E * |indexed_indices| >= N`. The document index is shuffled according to `R`. - - ``` - Given: - - N = 15 - indexed_indices = [5, 6, 7, 8, 9] - E = 3 - - Then, for example: - - Do_idx = [8, 8, 9, 6, 7, 5, 8, 5, 6, 6, 5, 9, 7, 7, 9] - ``` - -2. The sample index _Sa_idx_ is a 2-D array mapping from _j_ to pairs of (_i_, _Do_idx_[ _i_ ] offset) of shape `[N + 1, 2]`. The rows _j_ and _j_ + 1 serve as the left and right bounds for the _j_-th sample. - - ``` - Given: - - S = 1024 - - Then, for example: - - Sa_idx[0] = (0, 0) - Sa_idx[1] = (0, 1024) => Do_idx[0] has length greater than S - Sa_idx[2] = (1, 512) => Do_idx[0] has length 1536 - Sa_idx[3] = (2, 0) => Do_idx[1] has length 1536 - Sa_idx[4] = (5, 300) => Do_idx[2:5] are shorter documents relative to Do_idx[0:2] - Sa_idx[5] = (6, 24) => Do_idx[5] has length 1300 - ``` - -3. The shuffle index _Sh_idx_ is a 1-D array mapping from _k_ to _j_ of length `N`. The shuffle index is shuffled according to `R`. - - ``` - Given - - N = 10 - - Then, for example: - - Sh_idx = [4, 0, 2, 6, 1, 9, 5, 8, 7, 3] - ``` - -To query the `GPTDataset` for the _k_-th sample we do the following - -- Use the shuffle index to get the index _j_ into the sample index. - - ``` - j = Sh_idx[k] - ``` -- Use the sample index to get the left and right sample-bounding indices into the document index and the starting token offset for each document. - - ``` - i, offset = Sa_idx[j] - i_next, offset_next = Sa_idx[j + 1] - ``` -- Use the document index to retrieve `S` tokens from consecutive (in the document index) documents. - - ``` - sample = [] - sample += indexed_dataset[Do_idx[i]][offset:] - if i != i_next: - sample += indexed_dataset[Do_idx[i + 1:i_next]] - sample += indexed_dataset[Do_idx[i_next]][:offset_next] - ``` - -To save time during initialization, each index is built/cached sequentially on one process rank and subsequently loaded in parallel on other process ranks. The cached indices are unique to a hash generated in the `MegatronDataset.__init__` function. - -### BlendedDataset - -The `BlendedDataset` is parameterized by the following variables: the underlying `MegatronDataset` instances `D`, the weights `W` (one per dataset), and the size `S`. The `BlendedDataset` will draw samples from contributing datasets in proportion to the weights until achieving a composite dataset of the desired size. During each sampling step, we draw a single sample from the dataset which has the greatest sampling error. - -The `BlendedDataset` creates two "blending" indices to facilitate lookup: (1) the dataset index and (2) the dataset sample index. - -1. 
The dataset index _Da_idx_ is a 1-D array mapping from _i_ to dataset index of length `S`. - - ``` - Given - - D = [d0, d1, d2] - W = [1/2, 1/4, 1/4] - S = 4 - - Then, for example: - - Da_idx = [0, 1, 2, 0] - - ``` - -2. The dataset sample index _Sa_idx_ is a 1-D mapping from _i_ to the sample index for dataset _Da_idx[i]_ of length `S`. - - ``` - Given - - Da_idx = [0, 1, 2, 0] - - Then, for example: - - Sa_idx = [0, 0, 0, 1] - ``` - -To query the `BlendedDataset` for the _k_-th sample we do the following - -- Use the dataset index to retrieve the corresponding dataset from `D` and the dataset sample index to retrieve the corresponding sample from that dataset. - - ``` - sample = D[Da_idx[k]][Sa_idx[k]] - ``` - -To save time during initialization, each index is built/cached sequentially on one process rank and subsequently loaded in parallel on other process ranks. The cached indices are unique to a hash generated in the `BlendedDataset.__init__` function. diff --git a/megatron/core/datasets/utils.py b/megatron/core/datasets/utils.py deleted file mode 100644 index 8a3279b5f44bf0f31f5850a836f2412771356d07..0000000000000000000000000000000000000000 --- a/megatron/core/datasets/utils.py +++ /dev/null @@ -1,60 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. - -import logging -from enum import Enum -from typing import List - -import numpy -import torch - -logger = logging.getLogger(__name__) - - -class Split(Enum): - train = 0 - valid = 1 - test = 2 - - -def compile_helpers(): - """Compile C++ helper functions at runtime. Make sure this is invoked on a single process. - """ - import os - import subprocess - - command = ["make", "-C", os.path.abspath(os.path.dirname(__file__))] - if subprocess.run(command).returncode != 0: - import sys - - log_single_rank(logger, logging.ERROR, "Failed to compile the C++ dataset helper functions") - sys.exit(1) - - -def log_single_rank(logger: logging.Logger, *args, rank=0, **kwargs): - """If torch distributed is initialized, log only on rank - - Args: - logger (logging.Logger): The logger to write the logs - - rank (int, optional): The rank to write on. Defaults to 0. - """ - if torch.distributed.is_initialized(): - if torch.distributed.get_rank() == rank: - logger.log(*args, **kwargs) - else: - logger.log(*args, **kwargs) - - -def normalize(weights: List[float]) -> List[float]: - """Do non-exponentiated normalization - - Args: - weights (List[float]): The weights - - Returns: - List[float]: The normalized weights - """ - w = numpy.array(weights, dtype=numpy.float64) - w_sum = numpy.sum(w) - w = (w / w_sum).tolist() - return w diff --git a/megatron/core/dist_checkpointing/__init__.py b/megatron/core/dist_checkpointing/__init__.py deleted file mode 100644 index df08d7eaba8e55e30a2ade39bb8b01ca580854b5..0000000000000000000000000000000000000000 --- a/megatron/core/dist_checkpointing/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. 
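To make the blended lookup described in the deleted readme above concrete, here is a minimal standalone sketch (illustrative only, not the `BlendedDataset` implementation) of the k-th sample query using the dataset index and the dataset sample index:

```python
from typing import List, Sequence

def blended_lookup(datasets: Sequence[Sequence], dataset_index: List[int],
                   dataset_sample_index: List[int], k: int):
    """k-th sample lookup as described in the readme: pick the dataset via
    Da_idx, then the sample within that dataset via Sa_idx."""
    return datasets[dataset_index[k]][dataset_sample_index[k]]

# Toy blend of three "datasets" with weights roughly 1/2, 1/4, 1/4.
d0, d1, d2 = ["a0", "a1"], ["b0"], ["c0"]
Da_idx = [0, 1, 2, 0]   # which dataset serves the k-th sample
Sa_idx = [0, 0, 0, 1]   # which sample within that dataset
print([blended_lookup([d0, d1, d2], Da_idx, Sa_idx, k) for k in range(4)])
# ['a0', 'b0', 'c0', 'a1']
```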
- -from .core import check_is_distributed_checkpoint -from .mapping import LocalNonpersitentObject, ShardedTensor -from .serialization import ( - load, - load_common_state_dict, - load_plain_tensors, - load_tensors_metadata, - save, -) diff --git a/megatron/core/dist_checkpointing/core.py b/megatron/core/dist_checkpointing/core.py deleted file mode 100644 index f20a0c3a2de730a747b3bf5696b4bdf8b60550ce..0000000000000000000000000000000000000000 --- a/megatron/core/dist_checkpointing/core.py +++ /dev/null @@ -1,41 +0,0 @@ -# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. - -import json -from dataclasses import asdict, dataclass -from pathlib import Path -from typing import Optional - -CONFIG_FNAME = 'metadata.json' - - -class CheckpointingException(Exception): - pass - - -@dataclass -class CheckpointingConfig: - """ Documents backends used in the checkpoint. """ - - sharded_backend: str - sharded_backend_version: int = 1 - common_backend: str = 'torch' - common_backend_version: int = 1 - - -def check_is_distributed_checkpoint(checkpoint_dir): - return maybe_load_config(checkpoint_dir) is not None - - -def maybe_load_config(checkpoint_dir: str) -> Optional[CheckpointingConfig]: - config_path = Path(checkpoint_dir, CONFIG_FNAME) - if not config_path.exists(): - return None - with config_path.open() as f: - config_dict = json.load(f) - return CheckpointingConfig(**config_dict) - - -def save_config(config: CheckpointingConfig, checkpoint_dir: str): - config_path = Path(checkpoint_dir, CONFIG_FNAME) - with config_path.open('w') as f: - json.dump(asdict(config), f) diff --git a/megatron/core/dist_checkpointing/dict_utils.py b/megatron/core/dist_checkpointing/dict_utils.py deleted file mode 100644 index 493a61c91a89a0ee30a7869a8f3a84bef0e1eff4..0000000000000000000000000000000000000000 --- a/megatron/core/dist_checkpointing/dict_utils.py +++ /dev/null @@ -1,219 +0,0 @@ -# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. - -""" Utilities for operating with dicts and lists. """ - -from collections import defaultdict -from typing import Any, Callable, Iterable, Optional, Tuple, Union - -import torch - - -def extract_matching_values( - x: Union[dict, list], predicate: Callable[[Any], bool], return_lists_as_dicts: bool = False -) -> Tuple[Union[dict, list], Union[dict, list]]: - """ Return matching and nonmatching values. Keeps hierarchy. - - Arguments: - x (Union[dict, list]) : state dict to process. Top-level argument must be a dict or list - predicate (object -> bool): determines matching values - return_lists_as_dicts (bool): if True, matching lists will be turned - into dicts, with keys indicating the indices of original elements. - Useful for reconstructing the original hierarchy. 
- """ - - def _set_elem(target, k, v): - if return_lists_as_dicts: - target[k] = v - else: - target.append(v) - - if isinstance(x, dict): - matching_vals = {} - nonmatching_vals = {} - for k, v in x.items(): - if isinstance(v, (list, dict)): - match, nonmatch = extract_matching_values(v, predicate, return_lists_as_dicts) - if match: - matching_vals[k] = match - if nonmatch or not v: - nonmatching_vals[k] = nonmatch - elif predicate(v): - matching_vals[k] = v - else: - nonmatching_vals[k] = v - elif isinstance(x, list): - matching_vals = {} if return_lists_as_dicts else [] - nonmatching_vals = {} if return_lists_as_dicts else [] - for ind, v in enumerate(x): - if isinstance(v, (list, dict)) and v: - match, nonmatch = extract_matching_values(v, predicate, return_lists_as_dicts) - if match: - _set_elem(matching_vals, ind, match) - if nonmatch or not v: - _set_elem(nonmatching_vals, ind, nonmatch) - else: - target = matching_vals if predicate(v) else nonmatching_vals - _set_elem(target, ind, v) - else: - raise ValueError(f'Unexpected top-level object type: {type(x)}') - return matching_vals, nonmatching_vals - - -def diff(x1: Any, x2: Any, prefix: Tuple = ()) -> Tuple[list, list, list]: - mismatch = [] - if isinstance(x1, dict) and isinstance(x2, dict): - only_left = [prefix + (k,) for k in x1.keys() - x2.keys()] - only_right = [prefix + (k,) for k in x2.keys() - x1.keys()] - for k in x2.keys() & x1.keys(): - _left, _right, _mismatch = diff(x1[k], x2[k], prefix + (k,)) - only_left.extend(_left) - only_right.extend(_right) - mismatch.extend(_mismatch) - elif isinstance(x1, list) and isinstance(x2, list): - only_left = list(range(len(x1) - 1, len(x2) - 1, -1)) - only_right = list(range(len(x1) - 1, len(x2) - 1, -1)) - for i, (v1, v2) in enumerate(zip(x1, x2)): - _left, _right, _mismatch = diff(v1, v2, prefix + (i,)) - only_left.extend(_left) - only_right.extend(_right) - mismatch.extend(_mismatch) - else: - only_left = [] - only_right = [] - if isinstance(x1, torch.Tensor) and isinstance(x2, torch.Tensor): - _is_mismatch = not torch.all(x1 == x2) - else: - try: - _is_mismatch = bool(x1 != x2) - except RuntimeError: - _is_mismatch = True - - if _is_mismatch: - mismatch.append((prefix, type(x1), type(x2))) - - return only_left, only_right, mismatch - - -def inspect_keys_types(d: dict, prefix: Tuple = (), indent: int = 4): - print_indent = lambda: print(' ' * indent * len(prefix), end='') - for k, v in d.items(): - if isinstance(v, dict): - print_indent() - print(f'> {k}:') - inspect_keys_types(v, prefix + (k,), indent) - else: - print_indent() - if isinstance(v, torch.Tensor): - print(f'> {k}: {type(v)} of shape {v.shape}') - else: - print(f'> {k}: {type(v)}') - - -def inspect_types(x: Any, prefix: Tuple = (), indent: int = 4): - print_indent = lambda: print(' ' * indent * len(prefix), end='') - if isinstance(x, dict): - print() - for k, v in x.items(): - print_indent() - print(f'> {k}: ', end='') - inspect_types(v, prefix + (k,), indent) - elif isinstance(x, list): - print() - for i, v in enumerate(x): - print_indent() - print(f'- {i}: ', end='') - inspect_types(v, prefix + (i,), indent) - else: - if isinstance(x, torch.Tensor): - print(f'Tensor of shape {x.shape}') - else: - try: - x_str = str(x) - except: - x_str = '' - if len(x_str) > 30: - x_str = x_str[:30] + '... 
(truncated)' - print(f'[{type(x)}]: {x_str}') - - -def nested_values(x: Union[dict, list]): - x_iter = x.values() if isinstance(x, dict) else x - for v in x_iter: - if isinstance(v, (dict, list)): - yield from nested_values(v) - else: - yield v - - -def nested_items_iter(x: Union[dict, list]): - x_iter = x.items() if isinstance(x, dict) else enumerate(x) - for k, v in x_iter: - if isinstance(v, (dict, list)): - yield from nested_items_iter(v) - else: - yield x, k, v - - -def dict_map(f: Callable, d: dict): - for sub_d, k, v in nested_items_iter(d): - sub_d[k] = f(v) - - -def dict_map_with_key(f: Callable, d: dict): - for sub_d, k, v in nested_items_iter(d): - sub_d[k] = f(k, v) - - -def dict_list_map_inplace(f: Callable, x: Union[dict, list]): - if isinstance(x, dict): - for k, v in x.items(): - x[k] = dict_list_map_inplace(f, v) - elif isinstance(x, list): - x[:] = (dict_list_map_inplace(f, v) for v in x) - else: - return f(x) - return x - - -def dict_list_map_outplace(f: Callable, x: Union[dict, list]): - if isinstance(x, dict): - return {k: dict_list_map_outplace(f, v) for k, v in x.items()} - elif isinstance(x, list): - return [dict_list_map_outplace(f, v) for v in x] - else: - return f(x) - - -def merge(x1: dict, x2: dict, key: Tuple[str, ...] = ()): - if isinstance(x1, dict) and isinstance(x2, dict): - for k, v2 in x2.items(): - if k not in x1: - x1[k] = v2 - else: - x1[k] = merge(x1[k], v2, key=key + (k,)) - elif isinstance(x1, list) and isinstance(x2, list): - if len(x1) != len(x2): - raise ValueError( - f'Cannot merge two lists with different lengths ({len(x1)} and {len(x2)}, encountered at level {key})' - ) - for i, v2 in enumerate(x2): - x1[i] = merge(x1[i], v2, key=key + (i,)) - else: - raise ValueError( - f'Duplicate non-dict and non-list values encountered: `{x1}` and `{x2}` (at level {key})' - ) - return x1 - - -def map_reduce( - xs: Iterable, - key_fn: Callable = lambda x: x, - value_fn: Callable = lambda x: x, - reduce_fn: Callable = lambda x: x, -) -> dict: - res = defaultdict(list) - for x in xs: - res[key_fn(x)].append(value_fn(x)) - for k in res: - res[k] = reduce_fn(res[k]) - return dict(res) diff --git a/megatron/core/dist_checkpointing/mapping.py b/megatron/core/dist_checkpointing/mapping.py deleted file mode 100644 index 2b4d5677d37c01f7196d887f8e7a767d7e0c581b..0000000000000000000000000000000000000000 --- a/megatron/core/dist_checkpointing/mapping.py +++ /dev/null @@ -1,308 +0,0 @@ -# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. - -""" Core library classes. """ -import logging -from dataclasses import dataclass, replace -from itertools import chain -from typing import Any, Callable, Dict, Optional, Tuple, Union - -import numpy as np -import torch - -from .core import CheckpointingException -from .dict_utils import dict_list_map_inplace, dict_list_map_outplace - -logger = logging.getLogger(__name__) - -# These type definitions are just hints to differentiate a plain model state -# dict (StateDict) from a state dict with tensors replaced with ShardedTensors -# (ShardedStateDict). -StateDict = Dict[str, Any] -ShardedStateDict = Dict[str, Any] -ReplicaId = Union[int, Tuple[int, ...]] - - -@dataclass -class ShardedTensor: - """Represents a mapping between a local tensor and a global tensor. - - Global tensor is assumed to consist of many local tensors distributed - between different processes. - - Attributes: - key: unique identifier of a global tensor - data: local tensor data. 
Can be None only for consistency validation - dtype: tensor dtype - local_shape: local tensor shape - global_shape: global tensor shape - global_offset: offset of a local tensor in a global tensor, specified - in number of tensor elements - axis_fragmentations: global tensor fragmentation of each axis - replica_id: indicates given local tensor's replication wrt. local - tensors in different processes - prepend_axis_num: number of axes prepended to the local tensor - to reflect global tensor shape. - The behavior is similar to unsqueezing the local tensor. - allow_shape_mismatch: if True, during loading, the global shape of a - stored tensor does not have to match the expected global shape. - Useful for representing tensors with flexible shape, e.g. padded. - flattened_range: specifies a slice that should be applied to a flattened - tensor with `local_shape` in order to get the tensor stored as `data` - """ - - key: str - data: Optional[torch.Tensor] - dtype: torch.dtype - local_shape: Tuple[int, ...] - global_shape: Tuple[int, ...] - global_offset: Tuple[int, ...] - axis_fragmentations: Optional[Tuple[int, ...]] - replica_id: ReplicaId = 0 - prepend_axis_num: int = 0 - allow_shape_mismatch: bool = False - flattened_range: Optional[slice] = None - - def global_slice(self) -> Tuple[Union[int, slice], ...]: - assert len(self.global_offset) == len(self.local_shape) + self.prepend_axis_num - return tuple( - chain( - (off for off in self.global_offset[: self.prepend_axis_num]), - ( - slice(off, off + sh) - for off, sh in zip( - self.global_offset[self.prepend_axis_num :], self.local_shape - ) - ), - ) - ) - - def global_coordinates(self) -> Tuple[np.ndarray, ...]: - if self.flattened_range is None: - raise CheckpointingException( - f'`global_coordinates` is undefined for' - f' {self.__class__.__name__} without `flattened_range`' - ) - - local_coords = self.local_coordinates() - assert len(local_coords) + self.prepend_axis_num == len(self.global_offset), ( - len(local_coords), - self, - ) - global_coords = tuple( - c + off - for c, off in zip((0,) * self.prepend_axis_num + local_coords, self.global_offset) - ) - return global_coords - - def local_coordinates(self) -> Tuple[np.ndarray, ...]: - if self.flattened_range is None: - raise CheckpointingException( - f'`local_coordinates` is undefined for' - f' {self.__class__.__name__} without `flattened_range`' - ) - - # TODO: np.unravel_index? - mask = np.zeros(np.product(self.local_shape), dtype=bool) - mask[self.flattened_range] = True - return np.nonzero(mask.reshape(self.local_shape)) - - def max_allowed_chunks(self) -> Tuple[int, ...]: - chunks = [] - for axis_sh, axis_fragm in zip(self.global_shape, self.axis_fragmentations): - if not self.allow_shape_mismatch and axis_sh % axis_fragm != 0: - raise CheckpointingException( - f'Axis shape ({axis_sh}) not divisible' f' by axis fragmentation ({axis_fragm}' - ) - axis_chunk_size = axis_sh // axis_fragm - chunks.append(axis_chunk_size) - return tuple(chunks) - - def without_data(self): - return replace(self, data=None) - - @classmethod - def from_rank_offsets( - cls, - key: str, - data: torch.Tensor, - *rank_offsets: Tuple[int, int, int], - replica_id: ReplicaId = 0, - prepend_axis_num: int = 0, - allow_shape_mismatch: bool = False, - ): - """Allows to construct the ShardedTensor given offset specified in process ranks. 
- Arguments: - key: unique key - data: local tensor data - rank_offsets: each tuple (axis, axis_rank_offset, axis_fragm) - says that if global tensor is divided into `axis_fragm` - fragment along `axis` axis, then local tensor data - corresponds to the `axis_rank_offset` chunk. - replica_id: see ShardedTensor - prepend_axis_num: see ShardedTensor - allow_shape_mismatch: see ShardedTensor - """ - global_offset = [0] * (data.ndim + prepend_axis_num) - global_shape = ([1] * prepend_axis_num) + list(data.shape) - axis_fragmentations = [1] * (data.ndim + prepend_axis_num) - _seen_axis = set() - for axis, axis_rank_offset, axis_fragm in rank_offsets: - assert axis >= 0 and axis_rank_offset >= 0 and axis_fragm >= 0, ( - axis, - axis_rank_offset, - axis_fragm, - ) - assert ( - axis_rank_offset < axis_fragm - ), 'Rank offset must be lower than axis fragmentation' - if axis in _seen_axis: - raise CheckpointingException('Duplicated axis specified') - _seen_axis.add(axis) - - local_axis_shape = 1 if axis < prepend_axis_num else data.shape[axis - prepend_axis_num] - global_shape[axis] = axis_fragm * local_axis_shape - global_offset[axis] = axis_rank_offset * local_axis_shape - axis_fragmentations[axis] = axis_fragm - - return cls( - key, - data, - data.dtype, - tuple(data.shape), - tuple(global_shape), - tuple(global_offset), - tuple(axis_fragmentations), - replica_id, - prepend_axis_num, - allow_shape_mismatch, - ) - - def __str__(self): - return f'{self.__class__.__name__}(key=\'{self.key}\')' - - -def is_main_replica(replica_id): - if isinstance(replica_id, int): - return replica_id == 0 - return all(r == 0 for r in replica_id) - - -class LocalNonpersitentObject: - """Object that should not be stored in a checkpoint, but restored locally. - - Wrapping any object inside the state dict with LocalNonpersitentObject - will result in: - - during saving, this object will *not* be stored in the checkpoint - - during loading, a local version of this object will be placed in a state dict - """ - - def __init__(self, obj): - self.obj = obj - - def unwrap(self): - return self.obj - - -@dataclass -class ShardedObject: - """Represents a mapping between a local object and a global object. - - Global object is assumed to consist of many local objects distributed - between different processes. - - NOTE: Contrary to ShardedTensor, it's impossible to change global object - sharding. Conceptually, ShardedObject is a fully-sharded ShardedTensor - with atomic arbitrary typed elements. - - Attributes: - key: unique identifier of a global tensor - data: local object data. Can be None only for consistency validation - global_shape: global object shape - global_offset: offset of a local object in a global object, specified - in number of shards - replica_id: indicates local object replication wrt. local - objects in different processes - """ - - key: str - data: object - global_shape: Tuple[int, ...] - global_offset: Tuple[int, ...] - replica_id: ReplicaId = 0 - - def without_data(self): - return replace(self, data=None) - - @property - def unique_key(self): - return f'{self.key}/shard_{".".join(map(str, self.global_offset))}_{".".join(map(str, self.global_shape))}' - - def __str__(self): - return f'{self.__class__.__name__}(key=\'{self.key}\')' - - -@dataclass -class ShardedTensorFactory: - """ Allows to apply transformations to tensors before/after serialization. - - The essence of those transformations is that they can be applied to - optimizer states the same way they are applied to the model params. 
- - Builder creates a sub-state-dict out of a tensor before saving, and merger - merges the corresponding state dict after loading. - """ - - key: str - data: torch.Tensor - build_fn: Callable[[str, torch.Tensor], ShardedStateDict] - merge_fn: Callable[[StateDict], torch.Tensor] - - def build(self): - return self.build_fn(self.key, self.data) - - -def apply_factories(sharded_state_dict: ShardedStateDict): - def apply(x): - if isinstance(x, ShardedTensorFactory): - x = x.build() - return x - - dict_list_map_inplace(apply, sharded_state_dict) - - -def apply_factory_merges(x1: StateDict, x2: ShardedStateDict, key: Tuple[str, ...] = ()): - if isinstance(x2, ShardedTensorFactory): - return x2.merge_fn(x1) - - # There rest is almost the same as the `merge` function from `dict_utils` - if isinstance(x1, dict) and isinstance(x2, dict): - for k, v2 in x2.items(): - if k not in x1: - raise ValueError( - f'Different dict keys encountered in `apply_factory_merges` ({x1.keys()} vs {x2.keys()})' - ) - else: - x1[k] = apply_factory_merges(x1[k], v2, key=key + (k,)) - elif isinstance(x1, list) and isinstance(x2, list): - if len(x1) != len(x2): - err_msg = f'Cannot merge two lists with different lengths ({len(x1)} and {len(x2)}, encountered at key {key})' - logger.error(err_msg + f'\nx1: {x1}\nx2: {x2}') - raise ValueError(err_msg) - for i, v2 in enumerate(x2): - x1[i] = apply_factory_merges(x1[i], v2, key=key + (i,)) - elif isinstance(x1, list) and isinstance(x2, dict): - for k, v2 in x2.items(): - if not isinstance(k, int): - raise ValueError( - f'Invalid dict key {k} non-integer type encountered in a list-dict merge at level {key}' - ) - if k >= len(x1): - raise ValueError( - f'Dict key {k} out of bound for list of length {len(x1)} (encountered at level {key})' - ) - x1[k] = apply_factory_merges(x1[k], v2, key=key + (k,)) - else: - raise ValueError( - f'Duplicate non-dict and non-list values encountered: `{x1}` and `{x2} (at key {key})`' - ) - return x1 diff --git a/megatron/core/dist_checkpointing/optimizer.py b/megatron/core/dist_checkpointing/optimizer.py deleted file mode 100644 index d1c698787c4678009f09b5496fa4c5ddc17574d8..0000000000000000000000000000000000000000 --- a/megatron/core/dist_checkpointing/optimizer.py +++ /dev/null @@ -1,90 +0,0 @@ -# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. - -""" Optimizer related helpers. 
""" - -import logging -from copy import deepcopy -from dataclasses import replace -from itertools import chain -from typing import Dict, Iterable, List, Tuple, Union - -logger = logging.getLogger(__name__) - -import torch - -from .dict_utils import nested_values -from .mapping import ( - LocalNonpersitentObject, - ShardedStateDict, - ShardedTensor, - ShardedTensorFactory, - StateDict, -) -from .utils import extract_sharded_tensors, extract_sharded_tensors_and_factories - - -def get_optim_param_to_id_map(optim_params_iter: Iterable[torch.nn.Parameter]) -> Dict[int, int]: - param_mappings = {} - for i, param in enumerate(optim_params_iter): - if id(param) not in param_mappings: - param_mappings[id(param)] = i - return param_mappings - - -def get_param_id_to_sharded_param_map( - model_sharded_state_dict: ShardedStateDict, optim_params_iter: Iterable[torch.nn.Parameter] -) -> Dict[int, Union[ShardedTensor, ShardedTensorFactory]]: - model_sharded_state_dict, _ = extract_sharded_tensors_and_factories(model_sharded_state_dict) - id_to_sharded_param_map = {} - param_to_id_map = get_optim_param_to_id_map(optim_params_iter) - for ten in nested_values(model_sharded_state_dict): - if id(ten.data) in param_to_id_map: - id_to_sharded_param_map[param_to_id_map[id(ten.data)]] = ten - else: - logger.debug(f'{ten} is not tracked by the optimizer') - - if not id_to_sharded_param_map: - logger.warning( - "Sharded parameters mapping is empty. It means tensors in model state dict" - " do not correspond to tensors in optimizer parameters map." - " Make sure to call state_dict with `keep_vars=True`." - ) - return id_to_sharded_param_map - - -def make_sharded_optimizer_tensor( - model_param: Union[ShardedTensor, ShardedTensorFactory], optim_param: torch.Tensor, prefix: str -) -> Union[ShardedTensor, ShardedTensorFactory]: - if isinstance(model_param, ShardedTensorFactory): - return replace(model_param, key=f'{prefix}.{model_param.key}', data=optim_param) - - assert ( - tuple(optim_param.shape) == model_param.local_shape - ), f'Optimizer shape ({tuple(optim_param.shape)} does not match model shape ({model_param.local_shape})' - return replace( - model_param, key=f'{prefix}.{model_param.key}', data=optim_param, dtype=optim_param.dtype - ) - - -def optim_state_to_sharding_state( - optim_state_dict: StateDict, - id_to_sharded_param_map: Dict[int, ShardedTensor], - exclude_keys: Tuple[str] = (), -): - sharded_state = {} - for param_id, param_state in optim_state_dict['state'].items(): - sharded_state[param_id] = {} - for state_key, param in param_state.items(): - if state_key in exclude_keys: - continue - if param_id in id_to_sharded_param_map: - sharded_state[param_id][state_key] = make_sharded_optimizer_tensor( - id_to_sharded_param_map[param_id], param, prefix=f'optimizer.state.{state_key}' - ) - else: - raise ValueError(f'Param id {param_id} does not match any model sharded param') - - optim_state_dict['param_groups'] = deepcopy(optim_state_dict['param_groups']) - for group in optim_state_dict['param_groups']: - group['params'] = LocalNonpersitentObject(group['params']) - optim_state_dict['state'] = sharded_state diff --git a/megatron/core/dist_checkpointing/serialization.py b/megatron/core/dist_checkpointing/serialization.py deleted file mode 100644 index 85baa16c21eb784369fd87adb8239cd0ac68d05f..0000000000000000000000000000000000000000 --- a/megatron/core/dist_checkpointing/serialization.py +++ /dev/null @@ -1,385 +0,0 @@ -# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. 
- -import logging -import os -from collections import Counter, defaultdict -from itertools import chain -from pathlib import Path -from typing import Iterable, List, Tuple, Union - -import numpy as np -import torch - -from .core import CheckpointingConfig, maybe_load_config, save_config -from .dict_utils import ( - dict_list_map_inplace, - diff, - extract_matching_values, - map_reduce, - merge, - nested_values, -) -from .mapping import ( - CheckpointingException, - ShardedObject, - ShardedStateDict, - ShardedTensor, - ShardedTensorFactory, - StateDict, - apply_factories, - apply_factory_merges, - is_main_replica, -) -from .strategies.base import ( - LoadCommonStrategy, - LoadShardedStrategy, - SaveCommonStrategy, - SaveShardedStrategy, - StrategyAction, - get_default_strategy, -) -from .utils import extract_sharded_tensors, extract_sharded_tensors_or_nonpersistent - -COMMON_STATE_FNAME = 'common.pt' - -logger = logging.getLogger(__name__) - - -def load( - sharded_state_dict: ShardedStateDict, - checkpoint_dir: str, - sharded_strategy: Union[LoadShardedStrategy, None] = None, - common_strategy: Union[LoadCommonStrategy, None] = None, - validate_access_integrity: bool = True, -) -> StateDict: - """Loading entrypoint. - - Arguments: - sharded_state_dict (ShardedStateDict): state dict of the existing model - populated with ShardedTensors. Used as a mapping to determine which - parts of global tensors stored in the checkpoint should be loaded. - checkpoint_dir (str): directory with the checkpoint - sharded_strategy (LoadShardedStrategy, optional): configures loading behavior for sharded tensors - common_strategy (LoadCommonStrategy, optional): configures loading behavior for common data - validate_access_integrity (bool default = True): checks if each tensor shard is accessed - exactly once (as main replica) by some process - """ - if common_strategy is not None: - raise NotImplementedError('The only supported common strategy is torch') - - checkpoint_dir = Path(checkpoint_dir) - common_state_dict = load_common_state_dict(checkpoint_dir) - if not sharded_state_dict: - return common_state_dict - - sharded_objects, sharded_state_dict = load_sharded_objects(sharded_state_dict, checkpoint_dir) - merge(common_state_dict, sharded_objects) - - saved_config = maybe_load_config(checkpoint_dir) - if saved_config is None: - raise CheckpointingException(f'{checkpoint_dir} is not a distributed checkpoint') - - sh_ten_factories, _ = extract_matching_values( - sharded_state_dict, - lambda x: isinstance(x, ShardedTensorFactory), - return_lists_as_dicts=True, - ) - apply_factories(sharded_state_dict) - sharded_state_dict, _ = extract_sharded_tensors_or_nonpersistent(sharded_state_dict) - sharded_state_dict, nonpersistent_state_dict = extract_sharded_tensors(sharded_state_dict) - dict_list_map_inplace(lambda o: o.unwrap(), nonpersistent_state_dict) - merge(common_state_dict, nonpersistent_state_dict) - - if validate_access_integrity: - validate_sharding_integrity(nested_values(sharded_state_dict)) - - if sharded_strategy is None: - sharded_strategy = get_default_strategy( - StrategyAction.LOAD_SHARDED, - saved_config.sharded_backend, - saved_config.sharded_backend_version, - ) - else: - # TODO: implement consistency checks here - pass - loaded_state_dict = sharded_strategy.load(sharded_state_dict, checkpoint_dir) - - loaded_state_dict = apply_factory_merges(loaded_state_dict, sh_ten_factories) - - merge(common_state_dict, loaded_state_dict) - return common_state_dict - - -# TODO: implement it as common torch 
strategy -def load_common_state_dict(checkpoint_dir: Path): - return torch.load(Path(checkpoint_dir) / COMMON_STATE_FNAME, map_location='cpu') - - -def load_sharded_objects(sharded_state_dict: ShardedStateDict, checkpoint_dir: Path): - sharded_objects, sharded_state_dict = extract_matching_values( - sharded_state_dict, lambda v: isinstance(v, ShardedObject) - ) - - def load_sharded_object(sh_obj: ShardedObject): - sh_obj.data = None - load_path = (checkpoint_dir / sh_obj.unique_key).with_suffix('.pt') - loaded_obj = torch.load(load_path) - return loaded_obj - - return dict_list_map_inplace(load_sharded_object, sharded_objects), sharded_state_dict - - -def load_tensors_metadata( - checkpoint_dir: str, sharded_strategy: Union[LoadShardedStrategy, None] = None -) -> ShardedStateDict: - """Load tensors metadata from the checkpoint. - - Returns a dictionary similar to a sharded state dict, but note that - the dictionary keys are simply ShardedTensor keys (contrary to the - actual sharded state dicts where keys correspond to state dict keys). - - Dict values are ShardedTensors without any sharding (so, the only useful - information is tensors global shape and dtype). - - Concrete implementation depends on the loading strategy. If no strategy is - given, a default for a given backend is used. - """ - saved_config = maybe_load_config(checkpoint_dir) - if saved_config is None: - raise CheckpointingException(f'{checkpoint_dir} is not a distributed checkpoint') - - if sharded_strategy is None: - sharded_strategy = get_default_strategy( - StrategyAction.LOAD_SHARDED, - saved_config.sharded_backend, - saved_config.sharded_backend_version, - ) - else: - # TODO: implement consistency checks here - pass - return sharded_strategy.load_tensors_metadata(Path(checkpoint_dir)) - - -def load_plain_tensors(checkpoint_dir: str): - """Load checkpoint tensors without any sharding. - - NOTE: common state dict is NOT included.""" - sharded_state_dict = load_tensors_metadata(checkpoint_dir) - # Don't validate integrity because shards will be overlapped - # if world_size > 1 (all processes load whole tensors) - return load(sharded_state_dict, checkpoint_dir, validate_access_integrity=False) - - -def save( - sharded_state_dict: ShardedStateDict, - checkpoint_dir: str, - sharded_strategy: Union[SaveShardedStrategy, None] = None, - common_strategy: Union[SaveCommonStrategy, None] = None, - validate_access_integrity: bool = True, -): - """Saving entrypoint. - - Extracts ShardedTensors from the given state dict. Rank 0 saves the - "regular" part of the checkpoint to common torch file. - The ShardedTensors are saved according to a strategy specified by the - config. - - Arguments: - sharded_state_dict (ShardedStateDict): state dict of the populated with - ShardedTensors. Used as a mapping to determine how local tensors - should be saved as global tensors in the checkpoint. 
- checkpoint_dir (str): directory to save the checkpoint to - sharded_strategy (SaveShardedStrategy, optional): configures sharded tensors saving behavior and backend - common_strategy (SaveCommonStrategy, optional): configures common data saving behavior and backend - validate_access_integrity (bool default = True): checks if each tensor shard is accessed - exactly once (as main replica) by some process - """ - checkpoint_dir = Path(checkpoint_dir) - - if torch.distributed.get_rank() == 0: - if not checkpoint_dir.exists(): - raise CheckpointingException( - f'Checkpoint destination directory does not exist: {checkpoint_dir}' - ) - - if next(checkpoint_dir.iterdir(), None) is not None: - raise CheckpointingException( - f'Checkpoint destination directory ({checkpoint_dir}) is not empty' - ) - - if common_strategy is not None: - raise NotImplementedError('The only supported common strategy is torch') - - if sharded_strategy is None: - sharded_strategy = get_default_strategy(StrategyAction.SAVE_SHARDED, 'zarr', 1) - - apply_factories(sharded_state_dict) - sharded_state_dict, state_dict = extract_sharded_tensors_or_nonpersistent(sharded_state_dict) - sharded_state_dict, _ = extract_sharded_tensors(sharded_state_dict) - sharded_tensors = list(nested_values(sharded_state_dict)) - if validate_access_integrity: - validate_sharding_integrity(sharded_tensors) - - _save_common_dict(state_dict, checkpoint_dir, True) - - sharded_strategy.save(sharded_tensors, checkpoint_dir) - save_config( - CheckpointingConfig(sharded_strategy.backend, sharded_strategy.version), checkpoint_dir - ) - - -# TODO: implement it as common torch strategy -def _save_common_dict( - state_dict: StateDict, checkpoint_dir: Path, validate_consistency: bool = False -): - common_state_dict = _extract_and_save_sharded_objects( - state_dict, checkpoint_dir, validate_consistency - ) - if torch.distributed.get_rank() == 0: - torch.save(common_state_dict, checkpoint_dir / COMMON_STATE_FNAME) - if validate_consistency: - # TODO: implement checking consistency with rank 0 common dict on other ranks - pass - # torch.distributed.barrier() - # if not torch.distributed.get_rank() == 0: - # rank_0_state_dict = torch.load(checkpoint_dir / COMMON_STATE_FNAME) - # print(diff(common_state_dict, rank_0_state_dict)) - - -def _extract_and_save_sharded_objects( - state_dict: StateDict, checkpoint_dir: Path, validate_consistency: bool = False -): - sharded_objects, state_dict = extract_matching_values( - state_dict, lambda v: isinstance(v, ShardedObject) - ) - sharded_objects = list(nested_values(sharded_objects)) - if validate_consistency: - validate_objects_sharding_integrity(sharded_objects) - for sh_obj in sharded_objects: - if is_main_replica(sh_obj.replica_id): - save_path = (checkpoint_dir / sh_obj.unique_key).with_suffix('.pt') - os.makedirs(save_path.parent, exist_ok=True) - torch.save(sh_obj.data, save_path) - return state_dict - - -def validate_sharding_integrity(sharded_tensors: Iterable[ShardedTensor]): - sharding = [ten.without_data() for ten in sharded_tensors] - all_sharding = [None] * torch.distributed.get_world_size() - torch.distributed.all_gather_object(all_sharding, sharding) - if torch.distributed.get_rank() != 0: - return - - key_shardings = defaultdict(list) - for rank, rank_shardings in enumerate(all_sharding): - for sharding in rank_shardings: - key_shardings[sharding.key].append((rank, sharding)) - for key, shardings in key_shardings.items(): - _validate_sharding_for_key(shardings) - - -def 
_validate_sharding_for_key(rank_sharding: List[Tuple[int, ShardedTensor]]): - some_rank_shard = rank_sharding[0][1] - global_shape = some_rank_shard.global_shape - local_shape = some_rank_shard.local_shape - dtype = some_rank_shard.dtype - has_flattened_range = some_rank_shard.flattened_range is not None - for rank, sharding in rank_sharding: - assert sharding.dtype == dtype, (sharding.dtype, dtype, some_rank_shard) - assert sharding.global_shape == global_shape, ( - sharding.global_shape, - global_shape, - some_rank_shard, - ) - assert sharding.local_shape == local_shape, ( - sharding.local_shape, - local_shape, - some_rank_shard, - ) - assert (sharding.flattened_range is not None) == has_flattened_range, ( - (sharding.flattened_range is not None), - has_flattened_range, - some_rank_shard, - ) - - shard_access_cnt = _compute_shards_access(rank_sharding) - if has_flattened_range: - map_reduce( - rank_sharding, - lambda x: x[1].global_offset, - lambda x: x[1], - _validate_sharding_for_key_flattened, - ) - else: - if not torch.all(shard_access_cnt == 1): - logger.error(f'Invalid access pattern for {rank_sharding[0][1]}: {shard_access_cnt}') - raise CheckpointingException(f'Invalid access pattern for {rank_sharding[0][1]}') - - -def _compute_shards_access(rank_sharding): - def chunk_offset(sharding): - assert len(sharding.global_offset) == len(sharding.local_shape) + sharding.prepend_axis_num - return tuple( - chain( - (off for off in sharding.global_offset[: sharding.prepend_axis_num]), - ( - off // sh - for off, sh in zip( - sharding.global_offset[sharding.prepend_axis_num :], sharding.local_shape - ) - ), - ) - ) - - shard_access_cnt = torch.zeros( - rank_sharding[0][1].axis_fragmentations, dtype=torch.int, device='cpu' - ) - for rank, sharding in rank_sharding: - if is_main_replica(sharding.replica_id): - shard_access_cnt[chunk_offset(sharding)] += 1 - # TODO: consider validating different replicas too - return shard_access_cnt - - -def _validate_sharding_for_key_flattened(tensors_by_shard): - all_slices = [] - local_shape = tensors_by_shard[0].local_shape - for sharding in tensors_by_shard: - assert sharding.local_shape == local_shape - sharding: ShardedTensor - if not is_main_replica(sharding.replica_id): - # TODO: this checks only saving (and loading replica_id=0) consistency - continue - - all_slices.append((sharding.flattened_range.start, sharding.flattened_range.stop)) - - starts, stops = map(np.asarray, zip(*sorted(all_slices))) - if ( - starts[0] != 0 - or stops[-1] != np.product(local_shape) - or not np.all(starts[1:] == stops[:-1]) - ): - logger.error( - f'Flattened ranges dont cover the whole shard {tensors_by_shard[0]}. Ranges: {(starts, stops)}' - ) - raise CheckpointingException( - f'Flattened ranges dont cover the whole shard {tensors_by_shard[0]}' - ) - - -def validate_objects_sharding_integrity(sharded_objects: List[ShardedObject]): - """ Ensure uniqueness of saved objects. 
""" - local_sh_objs = [sh_obj.without_data() for sh_obj in sharded_objects] - all_sh_objs = [None] * torch.distributed.get_world_size() - torch.distributed.all_gather_object(all_sh_objs, local_sh_objs) - if torch.distributed.get_rank() != 0: - return - unique_keys = [ - sh_obj.unique_key - for sh_obj in chain.from_iterable(all_sh_objs) - if is_main_replica(sh_obj.replica_id) - ] - if len(unique_keys) != len(set(unique_keys)): - duplicates = {k: cnt for k, cnt in Counter(unique_keys).items() if cnt > 1} - logger.error(f'Duplicate ShardedObject keys and counts: {duplicates}') - raise CheckpointingException(f'Duplicate ShardedObject keys: {list(duplicates.keys())}') diff --git a/megatron/core/dist_checkpointing/strategies/__init__.py b/megatron/core/dist_checkpointing/strategies/__init__.py deleted file mode 100644 index 7177d973cfccdbf52dbc1fd8efc153f99e389eec..0000000000000000000000000000000000000000 --- a/megatron/core/dist_checkpointing/strategies/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. - -""" Various loading and saving strategies """ - -import logging - -logger = logging.getLogger(__name__) - -try: - import tensorstore - import zarr - - from .tensorstore import _import_trigger - from .zarr import _import_trigger -except ImportError: - logger.warning('Zarr-based strategies will not be registered because of missing packages') diff --git a/megatron/core/dist_checkpointing/strategies/base.py b/megatron/core/dist_checkpointing/strategies/base.py deleted file mode 100644 index 3989ea74a204349fd8623736c5a2ecaf05b2c06b..0000000000000000000000000000000000000000 --- a/megatron/core/dist_checkpointing/strategies/base.py +++ /dev/null @@ -1,90 +0,0 @@ -# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. - -from abc import ABC, abstractmethod -from collections import defaultdict -from enum import Enum -from pathlib import Path -from typing import Dict, List, Optional - -from ..mapping import CheckpointingException, ShardedStateDict, ShardedTensor, StateDict - - -class StrategyAction(Enum): - LOAD_COMMON = 'load_common' - LOAD_SHARDED = 'load_sharded' - SAVE_COMMON = 'save_common' - SAVE_SHARDED = 'save_sharded' - - -default_strategies = defaultdict(dict) - - -def get_default_strategy(action: StrategyAction, backend: str, version: int): - try: - return default_strategies[action.value][(backend, version)] - except KeyError as e: - hint = '' - if backend == 'zarr': - try: - import tensorstore - import zarr - except ImportError: - hint = ' Please install `zarr` and `tensorstore<=0.1.45` packages' - raise CheckpointingException( - f'Cannot find a default strategy for: {(action.value, backend, version)}.{hint}' - ) from e - - -class LoadStrategyBase(ABC): - @abstractmethod - def check_backend_compatibility(self, loaded_version): - raise NotImplementedError - - @abstractmethod - def check_version_compatibility(self, loaded_version): - raise NotImplementedError - - -class SaveStrategyBase(ABC): - def __init__(self, backend: str, version: int): - self.backend = backend - self.version = version - - -class LoadCommonStrategy(LoadStrategyBase): - @abstractmethod - def load(self, checkpoint_dir: Path): - raise NotImplementedError - - -class LoadShardedStrategy(LoadStrategyBase): - @abstractmethod - def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path): - raise NotImplementedError - - @abstractmethod - def load_tensors_metadata(self, checkpoint_dir: Path): - """Load tensors metadata from the checkpoint. 
- - Returns a dictionary similar to a sharded state dict, but note that - the dictionary keys are simply ShardedTensor keys (contrary to the - actual sharded state dicts where keys correspond to state dict keys). - - Dict values are ShardedTensors without any sharding (so, the only useful - information is tensors global shape and dtype). - """ - raise NotImplementedError( - f'{self.__class__.__name__} doesnt allow loading only sharded metadata' - ) - - -class SaveCommonStrategy(SaveStrategyBase): - @abstractmethod - def save(self, common_state_dict: StateDict, checkpoint_dir: Path): - raise NotImplementedError - - -class SaveShardedStrategy(SaveStrategyBase): - @abstractmethod - def save(self, sharded_tensors: List[ShardedTensor], checkpoint_dir: Path): - raise NotImplementedError diff --git a/megatron/core/dist_checkpointing/strategies/tensorstore.py b/megatron/core/dist_checkpointing/strategies/tensorstore.py deleted file mode 100644 index 4a619353a19075f41b68a6cdae519313a7e305b8..0000000000000000000000000000000000000000 --- a/megatron/core/dist_checkpointing/strategies/tensorstore.py +++ /dev/null @@ -1,131 +0,0 @@ -# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. - -""" Strategies using TensorStore to load and save Zarr arrays. """ - -from functools import partial -from itertools import starmap -from logging import getLogger -from pathlib import Path - -import tensorstore as ts -import torch - -from ..core import CheckpointingException -from ..dict_utils import dict_list_map_inplace -from ..mapping import ShardedStateDict, ShardedTensor -from .base import LoadShardedStrategy, StrategyAction, default_strategies -from .zarr import ( - load_zarr_based_sharded_metadata, - numpy_to_torch_dtype_dict, - postprocess_numpy_array, -) - -_import_trigger = None - -logger = getLogger(__name__) - - -class TensorStoreLoadShardedStrategy(LoadShardedStrategy): - def __init__(self, load_directly_on_device: bool = False): - super().__init__() - self.load_directly_on_device = load_directly_on_device - - def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path): - if torch.distributed.get_rank() == 0: - print(f'Loading distributed checkpoint with {self.__class__.__name__}') - if self.load_directly_on_device: - print(f'Loading distributed checkpoint directly on the GPU') - load_fn = partial( - _load_from_array, - checkpoint_dir=checkpoint_dir, - load_directly_on_device=self.load_directly_on_device, - ) - dict_list_map_inplace(load_fn, sharded_state_dict) - return sharded_state_dict - - def load_tensors_metadata(self, checkpoint_dir: Path): - def get_ts_shape_dtype(path): - arr = open_ts_array(path) - return arr.shape, arr.dtype.numpy_dtype - - return load_zarr_based_sharded_metadata(checkpoint_dir, get_ts_shape_dtype) - - def check_backend_compatibility(self, loaded_version): - pass # TODO - - def check_version_compatibility(self, loaded_version): - pass # TODO - - -def merge_global_slice_with_shape(global_slice, actual_shape, key): - def _merge_slice(dim_slice, dim_size): - if isinstance(dim_slice, slice): - assert ( - dim_slice.start < dim_size - ), f'Got empty slice for ShardedTensor {key} ({dim_slice}, {dim_size})' - if dim_slice.stop > dim_size: - dim_slice = slice(dim_slice.start, dim_size, dim_slice.step) - return dim_slice - - assert len(global_slice) == len(actual_shape), (global_slice, actual_shape, key) - return tuple(starmap(_merge_slice, zip(global_slice, actual_shape))) - - -def _load_from_array( - sharded_tensor: ShardedTensor, - checkpoint_dir: Path, - 
load_directly_on_device: bool = False, - apply_flattened_range: bool = True, -): - x = _load_regular_chunk(sharded_tensor, checkpoint_dir) - ten = postprocess_numpy_array(x, sharded_tensor, apply_flattened_range) - if load_directly_on_device: - sharded_tensor.data.data.copy_(ten) - return sharded_tensor.data - else: - return ten - - -def _load_regular_chunk(sharded_tensor: ShardedTensor, checkpoint_dir: Path): - assert isinstance(sharded_tensor, ShardedTensor), type(sharded_tensor) - arr = open_ts_array(checkpoint_dir / sharded_tensor.key) - if sharded_tensor.global_shape == arr.shape: - x = ( - arr[sharded_tensor.global_slice()].read().result() - ) # flattened tensors loading is delayed - elif sharded_tensor.allow_shape_mismatch: - global_slice = merge_global_slice_with_shape( - sharded_tensor.global_slice(), arr.shape, sharded_tensor.key - ) - x = arr[global_slice].read().result() # flattened tensors loading is delayed - else: - _msg = ( - f'Global shape mismatch for loaded ({arr.shape})' - f' and expected ({sharded_tensor.global_shape}) tensor' - f' for key {sharded_tensor.key}' - ) - raise CheckpointingException(_msg) - return x - - -def open_ts_array(arr_path: Path): - """Opens a Zarr file array with Tensorstore with basic setting. - - Arguments: - arr_path (Path): path to a Zarr (Tensorstore) array - """ - spec = {'driver': 'zarr', 'metadata_key': '.zarray', 'kvstore': {}} - spec['kvstore'] = { - 'driver': 'file', - 'path': str(arr_path), - } - try: - arr = ts.open(ts.Spec(spec), open=True).result() - except Exception as e: - raise CheckpointingException(f'Array {arr_path} could not be loaded. Error: {e}') from e - return arr - - -default_strategies[StrategyAction.LOAD_SHARDED.value][ - ('zarr', 1) -] = TensorStoreLoadShardedStrategy() diff --git a/megatron/core/dist_checkpointing/strategies/two_stage.py b/megatron/core/dist_checkpointing/strategies/two_stage.py deleted file mode 100644 index a9844ff6e54a629fbe689771f172239a41687f5e..0000000000000000000000000000000000000000 --- a/megatron/core/dist_checkpointing/strategies/two_stage.py +++ /dev/null @@ -1,256 +0,0 @@ -# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. - -""" 2-stage checkpoint loading. 
""" -import os -import time -from collections import defaultdict -from dataclasses import dataclass -from functools import partial, wraps -from itertools import chain -from logging import DEBUG, INFO, StreamHandler, getLogger -from operator import attrgetter, itemgetter -from pathlib import Path -from typing import Iterable, List, NamedTuple, Optional, Tuple, Union - -import torch - -from ..dict_utils import dict_list_map_inplace, map_reduce, nested_values -from ..mapping import ShardedStateDict, ShardedTensor, StateDict -from .base import LoadShardedStrategy -from .tensorstore import TensorStoreLoadShardedStrategy, _load_from_array, open_ts_array -from .zarr import flatten_range, load_zarr_based_sharded_metadata - -_import_trigger = None - - -timers = defaultdict(list) - -logger = getLogger(__name__) - - -def timed(verbose=True): - def timed_dec(fn): - name = fn.__name__ - - @wraps(fn) - def wrapped(*args, **kwargs): - if verbose: - logger.debug(f'{name} init') - start = time.time() - ret = fn(*args, **kwargs) - took = time.time() - start - if verbose: - logger.debug(f'{name} took {took}s') - timers[name].append(took) - return ret - - return wrapped - - return timed_dec - - -@dataclass -class _ShardedTensorMetadata: - global_rank: int - sharded_tensor_no_data: ShardedTensor - dist_group_rank: Tuple[int] # id of distributed group - dist_group_ranks: Tuple[int] # id of distributed group - data_size: Optional[int] = None # bytes - - -def sharded_tensor_chunk_id(sharded_tensor: ShardedTensor): - return ( - sharded_tensor.key, - sharded_tensor.global_offset, - ) - - -class TwoStageDataParallelLoadShardedStrategy(LoadShardedStrategy): - """ Loads one checkpoint replica from storage and broadcasts to other nodes. - - This strategy loads checkpoint from storage on minimal set of nodes - and distributes the checkpoint to other nodes with torch.distributed. - Loading is performed with tensorstore. - - Steps: - 0. (optional) create Gloo distributed groups - 1. Exchange ShardedTensors metadata between all nodes - 2. Align needed tensors within DP groups - 3. For each globally unique tensor: - a) on one of the ranks load it from storage to CPU and move to CUDA - b) allocate CUDA tensor on other ranks - c) broadcast within DP group - d) copy tensor content to the model param location - e) free tensor buffers from a) and b) - - Notes: - 1. Loading and broadcasting is done sequentially to avoid both host and device OOMs - 2. There is a lot of overlap potential between all three steps done for each tensor: - a) loading from storage to numpy - b) moving CPU tensors to CUDA - c) broadcast - - """ - - def __init__(self, data_parallel_group, cpu_transfer=True): - super().__init__() - - self.cpu_transfer = cpu_transfer - self.data_parallel_group_orig = data_parallel_group - self.data_parallel_group = None if cpu_transfer else data_parallel_group - self.dp_group_ranks = tuple( - sorted(torch.distributed.get_process_group_ranks(data_parallel_group)) - ) - self.dp_group_rank = torch.distributed.get_rank(self.data_parallel_group_orig) - self.global_rank = torch.distributed.get_rank() - - def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path): - self.maybe_init_gloo_group() - all_tensors_sorted = self._build_load_plan(sharded_state_dict) - self._exchange_loaded_tensors(all_tensors_sorted, sharded_state_dict, checkpoint_dir) - self.summarize_load_times() - return sharded_state_dict - - def summarize_load_times(self): - torch.distributed.barrier() - logger.info('Checkpoint loading finished. 
Summary:') - for key, times in sorted(timers.items()): - times_sum = sum(times) - max_times = torch.tensor([times_sum], device='cuda') - avg_times = torch.tensor([times_sum], device='cuda') - torch.distributed.all_reduce(max_times, op=torch.distributed.ReduceOp.MAX) - torch.distributed.all_reduce(avg_times, op=torch.distributed.ReduceOp.SUM) - avg_times /= torch.distributed.get_world_size() - if torch.distributed.get_rank() == 0: - logger.info(f'{key}: max {max_times[0]}, avg {avg_times[0]}') - - @timed(verbose=False) - def load_tensor_from_storage(self, checkpoint_dir, ten_meta: _ShardedTensorMetadata): - logger.debug(f'_load_from_array({ten_meta.sharded_tensor_no_data.key}) init') - ret = _load_from_array( - ten_meta.sharded_tensor_no_data, - checkpoint_dir, - load_directly_on_device=False, - apply_flattened_range=False, - ) - logger.debug(f'_load_from_array({ten_meta.sharded_tensor_no_data.key}) DONE') - return ret - - @timed() - def maybe_init_gloo_group(self): - if not self.cpu_transfer: - return - all_groups = [None] * torch.distributed.get_world_size() - torch.distributed.all_gather_object(all_groups, self.dp_group_ranks) - all_groups = set(tuple(sorted(gr)) for gr in all_groups) - for group_ranks in sorted(all_groups): - gloo_pg = torch.distributed.new_group(ranks=group_ranks, backend='gloo') - if self.global_rank in group_ranks: - self.data_parallel_group = gloo_pg - assert self.dp_group_rank == torch.distributed.get_rank(self.data_parallel_group) - - def check_backend_compatibility(self, loaded_version): - pass # TODO - - def check_version_compatibility(self, loaded_version): - pass # TODO - - @timed() - def _build_load_plan( - self, sharded_state_dict: ShardedStateDict - ) -> List[_ShardedTensorMetadata]: - local_meta = [ - _ShardedTensorMetadata( - self.global_rank, - sharded_ten.without_data(), - self.dp_group_rank, - self.dp_group_ranks, - ) - for sharded_ten in nested_values(sharded_state_dict) - ] - all_meta = [None] * torch.distributed.get_world_size(group=self.data_parallel_group) - torch.distributed.all_gather_object(all_meta, local_meta, group=self.data_parallel_group) - all_meta = list(chain.from_iterable(all_meta)) - all_tensors_sorted = self.deduplicate_chunks(all_meta) - return all_tensors_sorted - - @timed() - def deduplicate_chunks(self, ten_metas: List[_ShardedTensorMetadata]): - """ Group tensors by chunk and then pick the tensor with the lowest rank. - - NOTE: with proper loading overlap, loading from randomized ranks - (instead of the smallest one) could be beneficial here. 
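    For example (hypothetical ranks): if the same chunk (same key and global
    offset) is present on DP ranks 0 and 2, only the rank-0 metadata survives
    the reduction below, so rank 0 reads that chunk from storage and later
    broadcasts it within the data-parallel group.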
- """ - ten_metas = map_reduce( - ten_metas, - key_fn=lambda meta: sharded_tensor_chunk_id(meta.sharded_tensor_no_data), - reduce_fn=partial(min, key=attrgetter('dist_group_rank')), - ) - all_metas_sorted = list(map(itemgetter(1), sorted(ten_metas.items()))) - return all_metas_sorted - - @timed() - def _exchange_loaded_tensors( - self, ten_metas: List[_ShardedTensorMetadata], sharded_state_dict, checkpoint_dir - ): - logger.debug(f'_exchange_loaded_tensors, num ten_metas: {len(ten_metas)}') - for ten_meta in ten_metas: - - src_rank = torch.distributed.get_global_rank( - self.data_parallel_group, ten_meta.dist_group_rank - ) - - if self.dp_group_rank == ten_meta.dist_group_rank: - exchange_tensor = self.load_tensor_from_storage(checkpoint_dir, ten_meta) - if not self.cpu_transfer: - exchange_tensor = exchange_tensor.cuda() - else: - # TODO: for non-flattened ranges we could reuse the buffer from the start here - exchange_tensor = torch.empty( - ten_meta.sharded_tensor_no_data.local_shape, - device='cpu' if self.cpu_transfer else 'cuda', - dtype=ten_meta.sharded_tensor_no_data.dtype, - ) - - logger.debug( - f'exchange {ten_meta.sharded_tensor_no_data.key}, {exchange_tensor.shape}({exchange_tensor.numel()}), broadcast({src_rank} -> {self.dp_group_ranks})' - ) - torch.distributed.broadcast( - exchange_tensor, group=self.data_parallel_group, src=src_rank - ) - self._distribute_data_to_state_dict(ten_meta, exchange_tensor, sharded_state_dict) - logger.debug(f'exchange {ten_meta.sharded_tensor_no_data.key} done') - - # free buffer memory - exchange_tensor = None - - @timed(verbose=False) - def _distribute_data_to_state_dict( - self, - ten_meta: _ShardedTensorMetadata, - loaded_ten: torch.Tensor, - sharded_state_dict: ShardedStateDict, - ): - tensor_key = sharded_tensor_chunk_id(ten_meta.sharded_tensor_no_data) - - def _fill_in_data(t: Union[ShardedTensor, torch.Tensor]): - if not isinstance(t, ShardedTensor) or sharded_tensor_chunk_id(t) != tensor_key: - # already filled-in or key not matching - return t - sharded_tensor: ShardedTensor = t - x = loaded_ten - if sharded_tensor.flattened_range is not None: - x = flatten_range(sharded_tensor, x) - - # Reuse existing buffer - sharded_tensor.data.data.copy_(x) - return sharded_tensor.data - - dict_list_map_inplace(_fill_in_data, sharded_state_dict) - - def load_tensors_metadata(self, checkpoint_dir: Path): - def get_ts_shape_dtype(path): - arr = open_ts_array(path) - return arr.shape, arr.dtype.numpy_dtype - - return load_zarr_based_sharded_metadata(checkpoint_dir, get_ts_shape_dtype) diff --git a/megatron/core/dist_checkpointing/strategies/zarr.py b/megatron/core/dist_checkpointing/strategies/zarr.py deleted file mode 100644 index 0ce0cf0e27c8ab2441c6432840bb1c8f368632c3..0000000000000000000000000000000000000000 --- a/megatron/core/dist_checkpointing/strategies/zarr.py +++ /dev/null @@ -1,285 +0,0 @@ -# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. - -""" Strategies using Zarr as an underlying format. 
""" -import os -from functools import partial -from logging import getLogger -from pathlib import Path -from typing import Callable, List, Optional, Tuple - -import numpy as np -import torch -import zarr - -from ..core import CheckpointingException -from ..dict_utils import dict_list_map_inplace -from ..mapping import ShardedStateDict, ShardedTensor, is_main_replica -from .base import LoadShardedStrategy, SaveShardedStrategy, StrategyAction, default_strategies - -numpy_to_torch_dtype_dict = { - np.dtype('bool'): torch.bool, - np.dtype('uint8'): torch.uint8, - np.dtype('int8'): torch.int8, - np.dtype('int16'): torch.int16, - np.dtype('int32'): torch.int32, - np.dtype('int64'): torch.int64, - np.dtype('float16'): torch.float16, - np.dtype('float32'): torch.float32, - np.dtype('float64'): torch.float64, - np.dtype('complex64'): torch.complex64, - np.dtype('complex128'): torch.complex128, -} - -torch_to_numpy_dtype_dict = {v: k for k, v in numpy_to_torch_dtype_dict.items()} - - -try: - import tensorstore - - HAS_BFLOAT16 = True - numpy_to_torch_dtype_dict[np.dtype('bfloat16')] = torch.bfloat16 - torch_to_numpy_dtype_dict[torch.bfloat16] = np.dtype('bfloat16') -except ImportError: - HAS_BFLOAT16 = False - -_import_trigger = None - -logger = getLogger(__name__) - - -class ZarrSaveShardedStrategy(SaveShardedStrategy): - def save(self, sharded_tensors: List[ShardedTensor], checkpoint_dir: Path): - arrays = _create_or_open_zarr_arrays(sharded_tensors, checkpoint_dir) - for ten, arr in zip(sharded_tensors, arrays): - _save_to_existing_array(ten, arr) - torch.distributed.barrier() - - -def _create_or_open_zarr_arrays( - sharded_tensors: List[ShardedTensor], checkpoint_dir: Path -) -> List[Optional[zarr.Array]]: - """ Returns list of zarr arrays corresponding to given tensors. 
- - For a sharded tensors that: - a) is main replica and represents the first chunk (all offsets 0), creates the Zarr array - b) is main replica but not the first chunk, opens the arrays created in (a) (possibly by other process) - c) otherwise, sets the corresponding array to None since it won't be used - - Args: - sharded_tensors (List[ShardedTensor]): sharded tensors from a given rank that will be saved to checkpoint - checkpoint_dir (Path): checkpoint in which the arrays will be created - """ - arrays = [] - for ten in sharded_tensors: - arr = _create_zarr_array(ten, checkpoint_dir) if _should_create_array(ten) else None - arrays.append(arr) - - torch.distributed.barrier() - # Open arrays created above by other processes - for arr_idx, ten in enumerate(sharded_tensors): - if arrays[arr_idx] is not None: - # array created by this process - assert _should_create_array(ten), ten - continue - if not is_main_replica(ten.replica_id): - # this array won't be needed for saving and can stay None - continue - open_kwargs = {} - if ten.flattened_range is not None: - open_kwargs['synchronizer'] = zarr.ProcessSynchronizer( - str(checkpoint_dir / f'{ten.key}.sync') - ) - arrays[arr_idx] = zarr.open(checkpoint_dir / ten.key, 'r+', **open_kwargs) - return arrays - - -def _should_create_array(ten: ShardedTensor): - return ( - is_main_replica(ten.replica_id) - and set(ten.global_offset) == {0} - and (ten.flattened_range is None or ten.flattened_range.start == 0) - ) - - -def _save_to_existing_array(sharded_tensor: ShardedTensor, arr: Optional[zarr.Array]): - if not is_main_replica(sharded_tensor.replica_id): - return - assert arr is not None - x = sharded_tensor.data - x = x.detach().cpu() - torch.cuda.synchronize() - if x.dtype == torch.bfloat16: - x = x.float() - x = x.numpy() - x = x.astype('bfloat16') - else: - x = x.numpy() - - if sharded_tensor.flattened_range is None: - arr[sharded_tensor.global_slice()] = x - else: - arr.set_coordinate_selection(sharded_tensor.global_coordinates(), x) - - -def _create_zarr_array(sharded_tensor: ShardedTensor, checkpoint_dir: Path): - np_dtype = torch_to_numpy_dtype_dict[sharded_tensor.dtype] - try: - arr = zarr.create( - sharded_tensor.global_shape, - dtype=np_dtype, - store=checkpoint_dir / sharded_tensor.key, - chunks=sharded_tensor.max_allowed_chunks(), - compressor=None, - fill_value=None, - write_empty_chunks=True, - ) - except zarr.errors.ContainsArrayError as e: - raise CheckpointingException( - f'Array {checkpoint_dir / sharded_tensor.key} already exists' - ) from e - - if HAS_BFLOAT16 and np_dtype == np.dtype('bfloat16'): - arr._dtype = np_dtype - zarray = arr.store['.zarray'] - arr.store['.zarray'] = zarray.replace(b' exp_sh: - assert ( - False - ), f'Expected shape ({exp_sh}) smaller than actual ({x_sh}) for {repr(expected_sharded_ten)}' - else: - pad_args.extend((0, exp_sh - x_sh)) - # TODO: behavior control with envvar is for testing purposes only, remove it - if not int(os.environ.get('DIST_CKPT_PAD_REPLICATE', 0)): - return torch.nn.functional.pad(x, pad_args) - - # unsqueeze and squeeze to get shapes supported by cudnn - print(f'Replicating last row for {expected_sharded_ten.key}') - if x.dtype == torch.bfloat16: - return ( - torch.nn.functional.pad(x.float().unsqueeze(0), pad_args, mode='replicate') - .squeeze(0) - .bfloat16() - ) - return torch.nn.functional.pad(x.unsqueeze(0), pad_args, mode='replicate').squeeze(0) - - -def load_zarr_based_sharded_metadata( - checkpoint_dir: Path, get_shape_dtype_fn: Callable[[str], Tuple[Tuple[int], 
np.dtype]] -) -> ShardedStateDict: - """Load metadata of Zarr arrays. - - Arguments: - checkpoint_dir (str): checkpoint root directory - get_shape_dtype_fn (str -> ((int, ...), np.dtype)): a function returning - an array shape and dtype for a given Zarr array path - """ - sharded_state_dict = {} - for subdir in checkpoint_dir.iterdir(): - if not subdir.is_dir() or not (subdir / '.zarray').exists(): - continue - key = subdir.name - arr_shape, arr_dtype = get_shape_dtype_fn(str(subdir)) - - sharded_state_dict[key] = ShardedTensor( - key, - None, - numpy_to_torch_dtype_dict[arr_dtype], - arr_shape, - arr_shape, - tuple(0 for _ in arr_shape), - tuple(1 for _ in arr_shape), - ) - return sharded_state_dict - - -# default_strategies[StrategyAction.LOAD_SHARDED.value][('zarr', 1)] = ZarrLoadShardedStrategy() -default_strategies[StrategyAction.SAVE_SHARDED.value][('zarr', 1)] = ZarrSaveShardedStrategy( - 'zarr', 1 -) diff --git a/megatron/core/dist_checkpointing/utils.py b/megatron/core/dist_checkpointing/utils.py deleted file mode 100644 index f7976f007408197338b9f9a96eec85db4d63d087..0000000000000000000000000000000000000000 --- a/megatron/core/dist_checkpointing/utils.py +++ /dev/null @@ -1,44 +0,0 @@ -# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. - -from typing import Tuple - -from .dict_utils import dict_list_map_inplace, extract_matching_values -from .mapping import ( - LocalNonpersitentObject, - ShardedStateDict, - ShardedTensor, - ShardedTensorFactory, - StateDict, -) - - -def extract_sharded_tensors( - sharded_state_dict: ShardedStateDict, -) -> Tuple[ShardedStateDict, StateDict]: - return extract_matching_values(sharded_state_dict, lambda v: isinstance(v, ShardedTensor)) - - -def extract_sharded_tensors_and_factories( - sharded_state_dict: ShardedStateDict, -) -> Tuple[ShardedStateDict, StateDict]: - return extract_matching_values( - sharded_state_dict, lambda v: isinstance(v, (ShardedTensor, ShardedTensorFactory)) - ) - - -def extract_sharded_tensors_or_nonpersistent( - sharded_state_dict: ShardedStateDict, -) -> Tuple[ShardedStateDict, StateDict]: - return extract_matching_values( - sharded_state_dict, - lambda v: isinstance(v, (ShardedTensor, LocalNonpersitentObject, ShardedTensorFactory)), - ) - - -def add_prefix_for_sharding(sharded_state_dict: ShardedStateDict, prefix: str): - def add_prefix(t): - if isinstance(t, ShardedTensor): - t.key = f'{prefix}.{t.key}' - return t - - dict_list_map_inplace(add_prefix, sharded_state_dict) diff --git a/megatron/core/distributed/__init__.py b/megatron/core/distributed/__init__.py deleted file mode 100644 index 34c7209a27fde7c5202f275663d951276caff85d..0000000000000000000000000000000000000000 --- a/megatron/core/distributed/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .distributed_data_parallel import DistributedDataParallel -from .finalize_model_grads import finalize_model_grads diff --git a/megatron/core/distributed/distributed_data_parallel.py b/megatron/core/distributed/distributed_data_parallel.py deleted file mode 100644 index 63f6e3d65ec2bb3a7d771f3dd6fea61216112d67..0000000000000000000000000000000000000000 --- a/megatron/core/distributed/distributed_data_parallel.py +++ /dev/null @@ -1,248 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - -from contextlib import contextmanager -from typing import Dict - -import torch - -from .. 
import parallel_state -from ..transformer.module import MegatronModule -from ..transformer.transformer_config import TransformerConfig -from .grad_buffer import GradBuffer - - -class DistributedDataParallel(MegatronModule): - """ - DDP wrapper which stores grads in contiguous buffers. Also has option of overlapping - communication with backprop computation by breaking up full model's gradients into smaller - buckets and running all-reduce / reduce-scatter on each bucket asynchronously. This class - also provides the option to do the gradient accumulation in a type other than the param type - (e.g., fp32 for a bf16 model). - - Arguments: - config: Transformer config object. - module: Underlying model. - data_parallel_group: Data-parallel process group. - accumulate_allreduce_grads_in_fp32: If true, do the gradient accumulation and - communication in fp32. - overlap_grad_reduce: If true, overlap communication with backprop computation by - breaking up grads into buckets. If false, single synchronous communication call - is used instead. - use_distributed_optimizer: If true, issue reduce-scatter communication calls as part - of distributed optimizer. If false, issue all-reduce communication calls. - disable_bucketing: If true, force assign all parameters to a single bucket. If false, - use standard bucketing policy: assign parameters to smaller buckets and all-reduce - per bucket _if_ overlap_grad_reduce is True and pp_rank is 0. - - """ - - def __init__( - self, - config: TransformerConfig, - module: torch.nn.Module, - data_parallel_group: torch.distributed.ProcessGroup, - accumulate_allreduce_grads_in_fp32: bool, - overlap_grad_reduce: bool, - use_distributed_optimizer: bool, - disable_bucketing: bool = False, - bucket_size: int = 40000000, - ): - super().__init__(config=config) - self.module = module - - # Set bucket_size to infinity if overlap_grad_reduce is False. - self.overlap_grad_reduce = overlap_grad_reduce - self.use_distributed_optimizer = use_distributed_optimizer - - # Turn off bucketing if overlap_grad_reduce is False, if we are on a pipeline stage - # that is not the first (since data-parallel communication on these stages is not on - # the critical path), or if disable_bucketing is True (e.g., we might not want to - # break up model parameters into buckets for model chunks after the first - # in the interleaved schedule). - if not self.overlap_grad_reduce: - bucket_size = None - if parallel_state.get_pipeline_model_parallel_rank() > 0: - bucket_size = None - if disable_bucketing: - bucket_size = None - self.bucket_size = bucket_size - - self.module = module - self.grad_buffers = {} - self.expert_grads = [] - self.grad_buffer_param_index_map = {} - self.param_to_grad_buffer = {} - - # Group parameters by their gradient type. - grad_dtype_to_params = {} - param_to_name = {} - for name, param in self.module.named_parameters(): - if param.requires_grad and getattr(param, 'allreduce', True): - param.grad_added_to_main_grad = False - param_to_name[param] = name - dtype = torch.float if accumulate_allreduce_grads_in_fp32 else param.dtype - - params = grad_dtype_to_params.get(dtype, []) - params.append(param) - grad_dtype_to_params[dtype] = params - - # Allocate the grad buffers and map the grads. - # The grad buffer under the hood creates buckets as appropriate based on bucket_size. 
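# Illustrative note (hypothetical model): if every parameter is bf16 and
# accumulate_allreduce_grads_in_fp32 is True, the grouping above yields a
# single torch.float entry, so exactly one contiguous fp32 GradBuffer is
# allocated below and each param's .main_grad becomes a view into it.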
- self.data_parallel_world_size = torch.distributed.get_world_size(group=data_parallel_group) - for dtype, params in grad_dtype_to_params.items(): - self.grad_buffers[dtype] = GradBuffer( - dtype, - params, - data_parallel_group, - bucket_size, - param_to_name, - self.overlap_grad_reduce, - self.use_distributed_optimizer, - ) - self.grad_buffer_param_index_map[dtype] = self.grad_buffers[dtype].param_index_map - for param in params: - self.param_to_grad_buffer[param] = self.grad_buffers[dtype] - - # Allocate separate buffer for MoE params' grads. - for param in self.module.parameters(): - if param.requires_grad and not getattr(param, 'allreduce', True): - param.grad_added_to_main_grad = False - dtype = torch.float if accumulate_allreduce_grads_in_fp32 else param.dtype - param.main_grad = torch.zeros( - param.data.shape, - dtype=dtype, - device=torch.cuda.current_device(), - requires_grad=False, - ) - self.expert_grads.append(param.main_grad) - - # Register backward hook. - # Accumulation function for the gradients need to be stored so they - # don't go out of scope. - self.grad_accs = [] - for param in self.module.parameters(): - if param.requires_grad: - # Expand so we get access to grad_fn. - param_tmp = param.expand_as(param) - # Get the gradient accumulator function. - grad_acc = param_tmp.grad_fn.next_functions[0][0] - grad_acc.register_hook(self._make_param_hook(param, self.param_to_grad_buffer)) - self.grad_accs.append(grad_acc) - - def forward(self, *inputs, **kwargs): - """ - Calls the wrapped module's forward() method. - """ - return self.module(*inputs, **kwargs) - - def _make_param_hook( - self, param: torch.nn.Parameter, param_to_grad_buffer: Dict[torch.nn.Parameter, GradBuffer] - ): - """ - Creates the all-reduce / reduce-scatter hook for backprop. - """ - - def param_hook(*unused): - if param.requires_grad: - if self.overlap_grad_reduce: - assert ( - param.grad is not None - ), 'param.grad being None is not safe when overlap_grad_reduce is True' - if param.grad is not None and not param.grad_added_to_main_grad: - param.main_grad.add_(param.grad.data) - param.grad = None - if self.overlap_grad_reduce: - param_to_grad_buffer[param].register_grad_ready(param) - - return param_hook - - @contextmanager - def no_sync(self): - """ - Context manager that turns off gradient synchronization. - """ - for grad_buffer in self.grad_buffers.values(): - grad_buffer.is_last_microbatch = False - try: - yield - finally: - for grad_buffer in self.grad_buffers.values(): - grad_buffer.is_last_microbatch = True - - def start_grad_sync(self, *unused): - """ - Initiates grad sync (all-reduce or reduce-scatter) communication operations - for all model gradients. - - When overlap_grad_reduce is set to True, dispatches asynchronous communication - calls. When overlap_grad_reduce is set to False, calls synchronous - communication ops. - """ - for grad_buffer in self.grad_buffers.values(): - grad_buffer.start_grad_sync() - - def finish_grad_sync(self): - """ - Finishes grad sync (all-reduce or reduce-scatter) communication operations - for all model gradients. - - When overlap_grad_reduce is set to True, waits for asynchronous communication - calls to complete. When overlap_grad_reduce is set to False, calls synchronous - communication ops. - """ - for grad_buffer in self.grad_buffers.values(): - grad_buffer.finish_grad_sync() - - for expert_grad in self.expert_grads: - expert_grad /= self.data_parallel_world_size - - def zero_grad_buffer(self, zero_buffer): - """ - Zeros out all grad buffers. 
Needs to be called at the beginning of each - training iteration. - - When zero_buffer is set to True, the underlying grad buffer is zeroed out. - """ - for param in self.module.parameters(): - if param.requires_grad: - param.grad_added_to_main_grad = False - for grad_buffer in self.grad_buffers.values(): - grad_buffer.reset(zero_buffer) - for expert_grad in self.expert_grads: - expert_grad.zero_() - - def broadcast_params(self): - """ - Syncs parameters across all DP ranks. - """ - for param in self.module.parameters(): - torch.distributed.broadcast( - param.data, - src=parallel_state.get_data_parallel_src_rank(with_context_parallel=True), - group=parallel_state.get_data_parallel_group(with_context_parallel=True), - ) - - def state_dict(self, prefix='', keep_vars=False): - """ - Returns a dictionary containing references to the whole state of the - wrapped module. - - Both parameters and persistent buffers (e.g. running averages) are included. - Keys are corresponding parameter and buffer names. Parameters and buffers - set to None are not included. - """ - return self.module.state_dict(prefix=prefix, keep_vars=keep_vars) - - def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False): - """ - Returns wrapped module's state_dict for checkpoint saving. - """ - return self.module.state_dict_for_save_checkpoint(prefix=prefix, keep_vars=keep_vars) - - def load_state_dict(self, state_dict, strict=True): - """ - Copies parameters and buffers from state_dict into the wrapped module and its - descendants. If strict is True, then the keys of state_dict must exactly match - the keys returned by this module’s state_dict() function. - """ - self.module.load_state_dict(state_dict, strict=strict) diff --git a/megatron/core/distributed/finalize_model_grads.py b/megatron/core/distributed/finalize_model_grads.py deleted file mode 100644 index 916e4f3ecbffafca7f97d2b33193bb289e12228d..0000000000000000000000000000000000000000 --- a/megatron/core/distributed/finalize_model_grads.py +++ /dev/null @@ -1,158 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - -from typing import List - -import torch -from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors - -from .. import parallel_state -from ..transformer.transformer_config import TransformerConfig -from ..utils import get_attr_wrapped_model, get_model_config - - -def _allreduce_word_embedding_grads(model: List[torch.nn.Module], config: TransformerConfig): - """ - All-reduce word embedding grads. - - Reduce grads across first and last stages to ensure that word_embeddings parameters stay in - sync. This should only run for models that support pipelined model parallelism (BERT and GPT). - """ - - if ( - parallel_state.is_rank_in_embedding_group(ignore_virtual=True) - and parallel_state.get_pipeline_model_parallel_world_size() > 1 - ): - if parallel_state.is_pipeline_first_stage(ignore_virtual=True): - model_module = model[0] - elif parallel_state.is_pipeline_last_stage(ignore_virtual=True): - model_module = model[-1] - else: # We do not support the interleaved schedule for T5 yet. - model_module = model[0] - - # Look for module with 'pre_process' attribute to get around the fact that DDP and - # other wrapper classes inherit from non-core MegatronModule that has - # 'share_embeddings_and_output_weights' and 'shared_embedding_or_output_weight' - # attributes already, causing get_attr_wrapped_model() to not unwrap anything here. - # TODO: Clean this up once the wrapper classes inherit from core MegatronModule. 
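# Sketch of the effect (assuming tied input/output embeddings and at least two
# pipeline stages): the first and last stages each hold a copy of the shared
# embedding weight, and the all-reduce below sums their main_grad so both
# copies keep receiving identical gradients.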
- model_module = get_attr_wrapped_model(model_module, 'pre_process', return_model_obj=True) - if model_module.share_embeddings_and_output_weights: - weight = model_module.shared_embedding_or_output_weight() - grad = weight.main_grad - torch.distributed.all_reduce(grad, group=parallel_state.get_embedding_group()) - - -def _allreduce_position_embedding_grads(model: List[torch.nn.Module], config: TransformerConfig): - """ - All-reduce position_embeddings grad across first (encoder) and split (decoder) stages to - ensure that position embeddings parameters stay in sync. This should only run for T5 models - with pipeline parallelism. - """ - if ( - parallel_state.is_rank_in_position_embedding_group() - and parallel_state.get_pipeline_model_parallel_world_size() > 1 - and config.pipeline_model_parallel_split_rank is not None - ): - model_module = model[0] - grad = get_attr_wrapped_model( - model_module, 'language_model.embedding.position_embeddings.weight.main_grad' - ) - torch.distributed.all_reduce(grad, group=parallel_state.get_position_embedding_group()) - - -def _allreduce_embedding_grads(model: List[torch.nn.Module], config: TransformerConfig): - """ - All-reduce both word and position embeddings. - """ - _allreduce_word_embedding_grads(model, config) - _allreduce_position_embedding_grads(model, config) - - -def _allreduce_layernorm_grads(model: List[torch.nn.Module], config: TransformerConfig): - """ - All-reduce layernorm grads (for sequence parallelism). - """ - - # All-reduce layernorm parameters across model parallel nodes - # when sequence parallelism is used - if parallel_state.get_tensor_model_parallel_world_size() > 1 and config.sequence_parallel: - grads = [] - for model_chunk in model: - for param in get_attr_wrapped_model(model_chunk, 'parameters')(): - if getattr(param, 'sequence_parallel', False): - grad = param.main_grad - grads.append(grad.data) - coalesced = _flatten_dense_tensors(grads) - torch.distributed.all_reduce( - coalesced, group=parallel_state.get_tensor_model_parallel_group() - ) - for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)): - buf.copy_(synced) - - -def _allreduce_expert_grads(model: List[torch.nn.Module], config: TransformerConfig): - """ - All-reduce expert grads (for expert parallelism). - """ - - # All-reduce switchmlp parameters across data modulo expert parallel nodes - if ( - config.expert_model_parallel_size > 1 - and config.expert_model_parallel_size < parallel_state.get_data_parallel_world_size() - ): - grads = [] - for model_chunk in model: - for param in get_attr_wrapped_model(model_chunk, 'parameters')(): - if not getattr(param, 'allreduce', True): - grad = param.main_grad - grads.append(grad.data) - coalesced = _flatten_dense_tensors(grads) - torch.distributed.all_reduce( - coalesced, group=parallel_state.get_data_modulo_expert_parallel_group() - ) - for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)): - buf.copy_(synced) - - -def finalize_model_grads(model: List[torch.nn.Module]): - """ - All-reduce all model grads across DP replicas, layernorm grads for sequence parallelism, - embedding grads across first and last pipeline stages (if not tied), and expert grads - for expert parallelism. - """ - - config = get_model_config(model[0]) - - # All-reduce / reduce-scatter across DP replicas. 
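# (finish_grad_sync() on each model chunk either waits for the bucket
# communication already launched during backprop, when overlap_grad_reduce is
# enabled, or issues the all-reduce / reduce-scatter synchronously at this
# point.)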
- if config.timers is not None: - config.timers('all-grads-sync', log_level=1).start(barrier=config.barrier_with_L1_time) - for model_chunk in model: - model_chunk.finish_grad_sync() - if config.timers is not None: - config.timers('all-grads-sync').stop() - - # All-reduce layer-norm grads (for sequence parallelism). - if config.timers is not None: - config.timers('layernorm-grads-all-reduce', log_level=1).start( - barrier=config.barrier_with_L1_time - ) - _allreduce_layernorm_grads(model, config) - if config.timers is not None: - config.timers('layernorm-grads-all-reduce').stop() - - # All-reduce embedding grads (for pipeline parallelism). - if config.timers is not None: - config.timers('embedding-grads-all-reduce', log_level=1).start( - barrier=config.barrier_with_L1_time - ) - _allreduce_embedding_grads(model, config) - if config.timers is not None: - config.timers('embedding-grads-all-reduce').stop() - - # All-reduce expert grads (for expert parallelism). - if config.timers is not None: - config.timers('expert-grads-all-reduce', log_level=1).start( - barrier=config.barrier_with_L1_time - ) - _allreduce_expert_grads(model, config) - if config.timers is not None: - config.timers('expert-grads-all-reduce').stop() diff --git a/megatron/core/distributed/grad_buffer.py b/megatron/core/distributed/grad_buffer.py deleted file mode 100644 index 8bc88a8e710db31840c80444ae726f0b6bd6c1be..0000000000000000000000000000000000000000 --- a/megatron/core/distributed/grad_buffer.py +++ /dev/null @@ -1,410 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - -import math -from logging import getLogger -from typing import Dict, List - -import torch - -from .. import parallel_state - -logger = getLogger(__name__) - - -def shard_buffer(buffer: torch.Tensor, data_parallel_world_size: int): - """ - Shard buffer into data_parallel_world_size chunks of equal size. - """ - assert buffer.numel() % data_parallel_world_size == 0 - shard_size = buffer.numel() // data_parallel_world_size - sharded_buffer = [ - buffer[(r * shard_size) : ((r + 1) * shard_size)] for r in range(data_parallel_world_size) - ] - return sharded_buffer - - -class Bucket: - """ - Bucket to keep track of a subset of the model's gradients. Provides functionality to register - when params in the bucket have grads ready to be synced; an asynchronous communication call - is automatically launched when _all_ params in the bucket have grads ready. - - Arguments: - params: List of parameters whose gradients are collated in this bucket. - data: View in larger GradBuffer that this bucket is responsible for. - offset: Offset of this bucket's view in the larger GradBuffer. - data_parallel_group: Data-parallel process group. - data_parallel_world_size: World size using the data-parallel group group. - overlap_grad_reduce: If true, overlap communication with backprop computation by - breaking up grads into buckets. If false, single synchronous communication call - is used instead. - use_distributed_optimizer: If true, issue reduce-scatter communication calls as part - of distributed optimizer. If false, issue all-reduce communication calls. 
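    Example (hypothetical sizes): with data_parallel_world_size=4 and a
    1024-element bucket, start_grad_sync() first divides the data by 4, then
    either reduce-scatters it so each rank keeps only its own 256-element shard
    `shard_buffer(self.data, 4)[self.data_parallel_rank]` (distributed
    optimizer), or all-reduces it so every rank keeps the full 1024 averaged
    elements.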
- """ - - def __init__( - self, - params: List[torch.nn.Parameter], - data: torch.Tensor, - offset: int, - data_parallel_group: torch.distributed.ProcessGroup, - data_parallel_world_size: int, - overlap_grad_reduce: bool, - use_distributed_optimizer: bool, - ): - # State for bookkeeping: params is the set of parameters this bucket is - # responsible for, params_with_grad is the set of parameters with grads - # available. When overlap_grad_reduce is True, communication (all-reduce - # or reduce-scatter) is issued when params_with_grad equals params. - self.params_list = params - self.params = set(params) - self.params_with_grad = set() - self.data = data - # The distributed optimizer needs to keep track of this bucket's offset - # within the full grad_buffer. - self.offset = offset - self.data_parallel_group = data_parallel_group - self.data_parallel_world_size = data_parallel_world_size - self.data_parallel_rank = torch.distributed.get_rank(group=data_parallel_group) - self.overlap_grad_reduce = overlap_grad_reduce - self.use_distributed_optimizer = use_distributed_optimizer - - self.reset() - - def reset(self): - """ - Reset metadata in bucket in preparation for the next iteration of training. - """ - self.params_with_grad = set() - self.communication_handle = None - self.communication_issued = False - - def start_grad_sync(self): - """ - Initiates grad sync (all-reduce or reduce-scatter) communication operation - for this bucket. - - When overlap_grad_reduce is set to True, dispatches an asynchronous - communication call. When overlap_grad_reduce is set to False, makes - synchronous call. - """ - assert ( - self.communication_handle is None and not self.communication_issued - ), 'Should not have multiple communication calls in flight at once' - - self.data /= self.data_parallel_world_size - # Use async_op only when overlap_grad_reduce is True. - if self.use_distributed_optimizer: - local_data_view = shard_buffer(self.data, self.data_parallel_world_size)[ - self.data_parallel_rank - ] - self.communication_handle = torch.distributed._reduce_scatter_base( - local_data_view, - self.data, - group=self.data_parallel_group, - async_op=self.overlap_grad_reduce, - ) - else: - self.communication_handle = torch.distributed.all_reduce( - self.data, group=self.data_parallel_group, async_op=self.overlap_grad_reduce - ) - self.communication_issued = True - - def finish_grad_sync(self): - """ - Finishes grad sync (all-reduce or reduce-scatter) communication operation - for this bucket. - - When overlap_grad_reduce is set to True, waits for asynchronous communication - call to complete. When overlap_grad_reduce is set to False, makes synchronous call. - """ - # If overlap_grad_reduce is False, start (and finish) synchronous communication call here. - if not self.overlap_grad_reduce: - self.start_grad_sync() - return - assert self.communication_handle is not None and self.communication_issued, ( - f'Communication call has not been issued for this bucket ' - f'({len(self.params_with_grad)}/{len(self.params)} params have grad available)' - ) - self.communication_handle.wait() - - def register_grad_ready(self, param: torch.nn.Parameter): - """ - Registers grads for the passed-in param to be "ready" for grad sync. - - When the number of microbatches is greater than 1, we only want to register - grads as ready when processing the last microbatch and overlap_grad_reduce is True. 
- """ - assert param in self.params, 'Param is not in the bucket' - assert param not in self.params_with_grad, 'Cannot set grad twice' - assert ( - self.overlap_grad_reduce - ), 'register_grad_ready() should be called only when overlapping grad reduce' - self.params_with_grad.add(param) - # If all params in bucket have grads available, issue communication call. - if len(self.params_with_grad) == len(self.params): - self.start_grad_sync() - - -class GradBuffer: - """ - Groups gradients into a contiguous buffer, and then breaks the buffer into buckets with - roughly `bucket_size` parameters each. - - Arguments: - dtype: Type of underlying tensor. - params: List of parameters whose gradients are collated in the underlying tensor. - data_parallel_group: Data-parallel process group. - bucket_size: The rough size of each bucket in terms of number of parameters. - param_to_name: Mapping from `torch.nn.Parameter` to name (for logging purposes). - overlap_grad_reduce: If true, overlap communication with backprop computation by - breaking up grads into buckets. If false, single synchronous communication call - is used instead. - use_distributed_optimizer: If true, issue reduce-scatter communication calls as part - of distributed optimizer. If false, issue all-reduce communication calls. - """ - - def __init__( - self, - dtype: torch.dtype, - params: List[torch.nn.Parameter], - data_parallel_group: torch.distributed.ProcessGroup, - bucket_size: int, - param_to_name: Dict[torch.nn.Parameter, str], - overlap_grad_reduce: bool, - use_distributed_optimizer: bool, - ): - - # Check that params are unique. - unique_params = set() - for param in params: - assert param not in unique_params - unique_params.add(param) - del unique_params - - # Store attributes that will be needed later. - self.dtype = dtype - self.data_parallel_group = data_parallel_group - self.data_parallel_world_size = torch.distributed.get_world_size( - group=self.data_parallel_group - ) - self.overlap_grad_reduce = overlap_grad_reduce - self.use_distributed_optimizer = use_distributed_optimizer - self.is_last_microbatch = True - - # Data structures to store underlying buckets and relevant indexing data. - self.buckets = [] - self.param_to_bucket = {} # Param -> bucket mapping. - self.param_index_map = {} # Param -> location in buffer mapping (used in dist. optimizer). - - def _pad_if_needed(data_index: int): - """Pads data indices if using distributed optimizer (to ensure uniform sharding).""" - if use_distributed_optimizer: - return ( - int(math.ceil(data_index / self.data_parallel_world_size)) - * self.data_parallel_world_size - ) - return data_index - - # First, figure out how many elements should be in the underlying buffer storage. - # Note that if we need to split the buffer into smaller buckets, each of these - # might need to be padded as well (if using the distributed optimizer). - data_start_index = 0 - bucket_data_start_index = data_start_index - bucket_params = set() - self.bucket_indices = [] - bucket_id = 0 - for param in params[::-1]: - # Iterate through parameters in reverse order to roughly follow backprop order, - # and skip parameters that don't require gradients. - if not param.requires_grad: - continue - this_numel = param.data.nelement() - data_end_index = data_start_index + this_numel - self.param_index_map[param] = ( - data_start_index, - data_end_index, - bucket_id, - ) - bucket_params.add(param) - - # If we have enough elements already, form a new bucket. 
- # If bucket_size is None, accumulate everything into a single bucket. - - # TODO: Remove len(bucket_params) > 1 when the final head that transforms token - # representations from hidden space to vocabulary space is in a PyTorch module - # whose forward method is called. If it is not and a bucket contains only this - # one parameter, we get incorrect behavior (i.e., higher losses) since we do not - # call the wait function on the bucket's all_gather_handle (we use forward pre- - # hooks on PyTorch modules to do this when --overlap-param-gather is used). - # As a temporary workaround, we make sure that no bucket has only one parameter. - if bucket_size is not None: - if (data_end_index - bucket_data_start_index) >= bucket_size and len( - bucket_params - ) > 1: - data_end_index = _pad_if_needed(data_end_index) - self.bucket_indices.append((bucket_data_start_index, data_end_index)) - bucket_data_start_index = data_end_index - bucket_params = set() - bucket_id += 1 - data_start_index = data_end_index - - # Add remaining params to a new bucket. - if len(bucket_params) > 0: - data_end_index = _pad_if_needed(data_end_index) - self.bucket_indices.append((bucket_data_start_index, data_end_index)) - - # Next, create underlying storage for buffer (with numel elements that includes - # padding as necessary). - self.numel = data_end_index - if use_distributed_optimizer: - assert self.numel % self.data_parallel_world_size == 0 - self.data = torch.zeros( - self.numel, dtype=self.dtype, device=torch.cuda.current_device(), requires_grad=False, - ) - - # Finally, map main_grad fields for each parameter with a .grad field. - bucket_params = set() - bucket_data_start_index = 0 - cur_bucket_id = 0 - for param in params[::-1]: - if not param.requires_grad: - continue - data_start_index, data_end_index, bucket_id = self.param_index_map[param] - param.main_grad = self._get(param.data.shape, data_start_index) - if bucket_id != cur_bucket_id: - bucket_data_end_index = _pad_if_needed(data_start_index) - self._set_bucket( - bucket_params, bucket_data_start_index, bucket_data_end_index, cur_bucket_id - ) - bucket_data_start_index = bucket_data_end_index - bucket_params = set() - assert cur_bucket_id + 1 == len(self.buckets) - assert bucket_id == cur_bucket_id + 1 - cur_bucket_id = bucket_id - bucket_params.add(param) - - # Add remaining params to a new bucket. - if len(bucket_params) > 0: - bucket_data_end_index = _pad_if_needed(data_end_index) - self._set_bucket( - bucket_params, bucket_data_start_index, bucket_data_end_index, cur_bucket_id - ) - - if not overlap_grad_reduce: - assert len(bucket_params) == len( - params - ), 'All params should be in one bucket when overlap_grad_reduce is False' - - # Log buckets for all PP stages. - if ( - parallel_state.get_data_parallel_rank(with_context_parallel=True) == 0 - and parallel_state.get_tensor_model_parallel_rank() == 0 - ): - logger.info( - f'Number of buckets for gradient all-reduce / reduce-scatter: {len(self.buckets)}' - ) - for index, bucket in enumerate(self.buckets): - numel = 0 - for param in bucket.params: - numel += param.data.nelement() - logger.info(f'Params for bucket {index+1} ({numel} elements):') - for param in bucket.params: - logger.info(f' {param_to_name[param]}') - - def _get(self, shape: torch.Size, start_index: int) -> torch.Tensor: - """ - Return a tensor with the input `shape` as a view into the 1-D data starting at - `start_index`. 
- """ - end_index = start_index + shape.numel() - assert end_index <= self.numel, 'Requested tensor is out of buffer range' - buffer_tensor = self.data[start_index:end_index] - buffer_tensor = buffer_tensor.view(shape) - return buffer_tensor - - def _set_bucket( - self, - bucket_params: List[torch.nn.Parameter], - start_index: int, - end_index: int, - bucket_id: int, - ): - """ - Helper function to create new bucket, add it to list of buckets, and - also update param->bucket mapping. - """ - - # Assert that indices are correctly padded (if needed), and that bucket - # position is same as originally computed. - if self.use_distributed_optimizer: - assert start_index % self.data_parallel_world_size == 0 - assert end_index % self.data_parallel_world_size == 0 - assert (start_index, end_index) == self.bucket_indices[bucket_id] - - # Get appropriate view into global GradBuffer. - bucket_data = self._get(torch.Size([end_index - start_index]), start_index) - bucket = Bucket( - params=bucket_params, - data=bucket_data, - offset=start_index, - data_parallel_group=self.data_parallel_group, - data_parallel_world_size=self.data_parallel_world_size, - overlap_grad_reduce=self.overlap_grad_reduce, - use_distributed_optimizer=self.use_distributed_optimizer, - ) - self.buckets.append(bucket) - for bucket_param in bucket_params: - assert bucket_param not in self.param_to_bucket - self.param_to_bucket[bucket_param] = bucket - - def reset(self, zero_buffer): - """ - Zero out the underlying buffer and reset all buckets in preparation for the next - iteration of training. - - When zero_buffer is set to True, the underlying buffer is zeroed out. - """ - if zero_buffer: - self.data.zero_() - for bucket in self.buckets: - bucket.reset() - self.is_last_microbatch = True - - def start_grad_sync(self): - """ - Initiates grad sync (all-reduce or reduce-scatter) communication operations - for all buckets in the grad buffer. - - When overlap_grad_reduce is set to True, dispatches asynchronous communication - calls. When overlap_grad_reduce is set to False, calls synchronous - communication ops. - """ - for bucket in self.buckets: - bucket.start_grad_sync() - - def finish_grad_sync(self): - """ - Finishes grad sync (all-reduce or reduce-scatter) communication operations - for all buckets in the grad buffer. - - When overlap_grad_reduce is set to True, waits for asynchronous communication - calls to complete. When overlap_grad_reduce is set to False, calls synchronous - communication ops. - """ - for bucket in self.buckets: - bucket.finish_grad_sync() - - def register_grad_ready(self, param: torch.nn.Parameter): - """ - Registers grads for the passed-in param to be "ready" for grad sync. - - When the number of microbatches is greater than 1, we only want to register - grads as ready when processing the last microbatch and overlap_grad_reduce is True. - """ - assert ( - self.overlap_grad_reduce - ), 'register_grad_ready() should only be called when overlap_grad_reduce is True' - if self.is_last_microbatch: - bucket = self.param_to_bucket[param] - bucket.register_grad_ready(param) diff --git a/megatron/core/enums.py b/megatron/core/enums.py deleted file mode 100644 index 46e7d3b766af061cd36363f8486f75f7ad80b08f..0000000000000000000000000000000000000000 --- a/megatron/core/enums.py +++ /dev/null @@ -1,10 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 
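The `Bucket`/`GradBuffer` code removed above pre-divides the flat gradient buffer by the data-parallel world size, then issues a reduce-scatter when the distributed optimizer shards the buffer and a plain all-reduce otherwise. The sketch below is illustrative only: `sync_bucket`, the single-process `gloo` group, and the use of the public `reduce_scatter_tensor` (in place of the internal `_reduce_scatter_base` call in the removed code) are assumptions made for this example.

```python
import os
import torch
import torch.distributed as dist

def sync_bucket(flat_grads: torch.Tensor, group, use_distributed_optimizer: bool) -> torch.Tensor:
    """Average a flat grad buffer across data-parallel ranks (simplified, synchronous)."""
    world = dist.get_world_size(group=group)
    rank = dist.get_rank(group=group)
    flat_grads /= world  # pre-divide so the reduction yields an average
    if use_distributed_optimizer:
        # The buffer is padded to divide evenly; each rank keeps only its own shard.
        shard = flat_grads.numel() // world
        local_view = flat_grads[rank * shard:(rank + 1) * shard]
        handle = dist.reduce_scatter_tensor(local_view, flat_grads, group=group, async_op=True)
    else:
        handle = dist.all_reduce(flat_grads, group=group, async_op=True)
    handle.wait()  # with overlap_grad_reduce this wait would be deferred to finish_grad_sync()
    return flat_grads

if __name__ == "__main__":
    # Single-process group so the sketch runs anywhere (assumption for the demo only).
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29501")
    dist.init_process_group("gloo", rank=0, world_size=1)
    print(sync_bucket(torch.ones(8), dist.group.WORLD, use_distributed_optimizer=False))
    dist.destroy_process_group()
```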
- -import enum - - -class ModelType(enum.Enum): - encoder_or_decoder = 1 - encoder_and_decoder = 2 - retro_encoder = 3 - retro_decoder = 4 diff --git a/megatron/core/fusions/__init__.py b/megatron/core/fusions/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/megatron/core/fusions/fused_bias_dropout.py b/megatron/core/fusions/fused_bias_dropout.py deleted file mode 100644 index 14c1fe0d718223ba78830cf3099ac02907e65fc2..0000000000000000000000000000000000000000 --- a/megatron/core/fusions/fused_bias_dropout.py +++ /dev/null @@ -1,71 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. -from typing import Optional, Tuple - -import torch - - -def _bias_dropout_add_func(x_with_bias, residual, prob, training): - # type: (Tuple[Tensor, Optional[Tensor]], Tensor, float, bool) -> Tensor - # NOTE: Previously, the argument `bias` used to be passed as - # `bias.expand_as(residual)` when the `bias_dropout_func` is called from the - # transformer layer but broadcasting should automatically take care of that. - # Also, looking at broadcasting semantics, `expand_as` and broadcasting - # seem to be identical performance-wise (both just change the view). - - x, bias = x_with_bias # unpack - - # If we want to train mixed precision, then the output of this function - # should be half precision. However, in AMP O1, the input (residual) is - # in fp32, and it will up-cast the result to fp32, causing pipeline parallel - # GPU communication to hang. Therefore, we need to cast residual to the same - # dtype as x. - residual = residual if residual.dtype == x.dtype else residual.to(x.dtype) - - # The Dropout operation, Residual Addition and the tensor returning can be - # done generically outside the if statement, but that stops fusing of Bias - # Addition-Dropout-Residual Addition operation. So doing it together inside - # the conditional branch to improve performance - if bias is not None: - x = x + bias - out = torch.nn.functional.dropout(x, p=prob, training=training) - out = residual + out - return out - else: - out = torch.nn.functional.dropout(x, p=prob, training=training) - out = residual + out - return out - - -def bias_dropout_add_unfused(training): - def _bias_dropout_add(x_with_bias, residual, prob): - return _bias_dropout_add_func(x_with_bias, residual, prob, training) - - return _bias_dropout_add - - -@torch.jit.script -def bias_dropout_add_fused_train( - x_with_bias: Tuple[torch.Tensor, Optional[torch.Tensor]], residual: torch.Tensor, prob: float, -) -> torch.Tensor: - return _bias_dropout_add_func(x_with_bias, residual, prob, True) - - -@torch.jit.script -def bias_dropout_add_fused_inference( - x_with_bias: Tuple[torch.Tensor, Optional[torch.Tensor]], residual: torch.Tensor, prob: float, -) -> torch.Tensor: - return _bias_dropout_add_func(x_with_bias, residual, prob, False) - - -def get_bias_dropout_add(training, fused): - if fused: - # jit scripting for a nn.module (with dropout) is not - # triggering the fusion kernel. For now, we use two - # different nn.functional routines to account for varying - # dropout semantics during training and inference phases. 
- if training: - return bias_dropout_add_fused_train - else: - return bias_dropout_add_fused_inference - else: - return bias_dropout_add_unfused(training) diff --git a/megatron/core/fusions/fused_bias_gelu.py b/megatron/core/fusions/fused_bias_gelu.py deleted file mode 100644 index 9c791c180765b99c49e78dedf63444b57fed5ec1..0000000000000000000000000000000000000000 --- a/megatron/core/fusions/fused_bias_gelu.py +++ /dev/null @@ -1,48 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. - -import torch - -###### BIAS GELU FUSION/ NO AUTOGRAD ################ -# 1/sqrt(2*pi)-> 0.3989423 -# 1/sqrt(2) -> 0.70710678 -# sqrt(2/pi) -> 0.79788456 -# this function is tanh approximation of gelu -# actual gelu is: -# x * 0.5 * (1.0 + torch.erf(x * 0.70710678)) - - -@torch.jit.script -def bias_gelu(bias, y): - x = bias + y - return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))) - - -# gradient of tanh approximation of gelu -# gradient of actual gelu is: -# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x) -@torch.jit.script -def bias_gelu_back(g, bias, y): - x = bias + y - tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) - # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243 - ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * ( - 1 + tanh_out - ) - return ff * g - - -class GeLUFunction(torch.autograd.Function): - @staticmethod - # bias is an optional argument - def forward(ctx, input, bias): - ctx.save_for_backward(input, bias) - return bias_gelu(bias, input) - - @staticmethod - def backward(ctx, grad_output): - input, bias = ctx.saved_tensors - tmp = bias_gelu_back(grad_output, bias, input) - return tmp, tmp - - -bias_gelu_impl = GeLUFunction.apply diff --git a/megatron/core/fusions/fused_layer_norm.py b/megatron/core/fusions/fused_layer_norm.py deleted file mode 100644 index c12ec173d0aabf9548f00bbdd4d1cbadb4d1a9e3..0000000000000000000000000000000000000000 --- a/megatron/core/fusions/fused_layer_norm.py +++ /dev/null @@ -1,151 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - -import importlib -import numbers - -import torch -from torch import Tensor -from torch.nn import init -from torch.nn.parameter import Parameter - -from megatron.core.transformer import TransformerConfig -from megatron.core.utils import make_viewless_tensor - -try: - from apex.contrib.layer_norm.layer_norm import FastLayerNormFN - - HAVE_PERSIST_LAYER_NORM = True -except: - HAVE_PERSIST_LAYER_NORM = False - -try: - from apex.normalization.fused_layer_norm import FusedLayerNormAffineFunction - - HAVE_FUSED_LAYER_NORM = True -except: - HAVE_FUSED_LAYER_NORM = False - - -class FusedLayerNorm(torch.nn.Module): - - """Layer Norm, fused into a single CUDA kernel. - - Arguments: - hidden_size (int): Transformer hidden dimension. - - eps (float): Epsilon added to denominator, for numerical stability. - - persist_layer_norm (bool): Use persistent fused layer norm kernel. - This kernel supports only a set of hidden sizes. Please - check persist_ln_hidden_sizes if your hidden size is supported. - - sequence parallel (bool): Apply sequence parallelism optimization. - - zero_centered_gamma (bool): Adjust LayerNorm weights such that they are - centered around zero. This improves numerical stability. - - config (TransformerConfig): Transformer config. Include to match custom - layer norm interfaces. - - normalization (str): Normalization type, used for Transformer Engine. - Must equal 'LayerNorm' here. 
- """ - - def __init__( - self, - config: TransformerConfig, - hidden_size: int, - eps: float = 1e-5, - persist_layer_norm: bool = True, - sequence_parallel: bool = False, - zero_centered_gamma: bool = False, - normalization: str = "LayerNorm", # included to match TE interface - ): - super().__init__() - - self.zero_centered_gamma = config.layernorm_zero_centered_gamma - assert ( - config.normalization == "LayerNorm" - ), f'({config.normalization}) is not supported in FusedLayerNorm' - - # List of hiddens sizes supported in the persistent layer norm kernel - # If the hidden size is not supported, fall back to the non-persistent - # kernel. - persist_ln_hidden_sizes = [ - 1024, - 1536, - 2048, - 2304, - 3072, - 3840, - 4096, - 5120, - 6144, - 8192, - 10240, - 12288, - 12800, - 15360, - 16384, - 18432, - 20480, - 24576, - 25600, - 30720, - 32768, - 40960, - 49152, - 65536, - ] - persist_layer_norm = config.persist_layer_norm - if hidden_size not in persist_ln_hidden_sizes or not HAVE_PERSIST_LAYER_NORM: - persist_layer_norm = False - - if not persist_layer_norm and not HAVE_FUSED_LAYER_NORM: - # TODO: Add pytorch only layer norm - raise ValueError(f'Apex must currently be installed to use megatron core.') - - if isinstance(hidden_size, numbers.Integral): - hidden_size = (hidden_size,) - self.hidden_size = torch.Size(hidden_size) - self.eps = eps - self.weight = Parameter(torch.Tensor(*hidden_size)) - self.bias = Parameter(torch.Tensor(*hidden_size)) - self.reset_parameters() - self.persist_layer_norm = persist_layer_norm - self.sequence_parallel = config.sequence_parallel - - # set sequence parallelism flag on weight and bias parameters - setattr(self.weight, 'sequence_parallel', self.sequence_parallel) - setattr(self.bias, 'sequence_parallel', self.sequence_parallel) - - def reset_parameters(self): - - if self.zero_centered_gamma: - init.zeros_(self.weight) - init.zeros_(self.bias) - else: - init.ones_(self.weight) - init.zeros_(self.bias) - - def forward(self, input: Tensor) -> Tensor: - - weight = self.weight + 1 if self.zero_centered_gamma else self.weight - - if self.persist_layer_norm: - output = FastLayerNormFN.apply(input, weight, self.bias, self.eps) - - # Apex's fast layer norm function outputs a 'view' tensor (i.e., has - # a populated '_base' field). This will result in schedule.py's - # deallocate_output_tensor() throwing an error, so a viewless tensor is - # created to prevent this. - output = make_viewless_tensor( - inp=output, requires_grad=input.requires_grad, keep_graph=True - ) - - else: - output = FusedLayerNormAffineFunction.apply( - input, weight, self.bias, self.hidden_size, self.eps - ) - - return output diff --git a/megatron/core/fusions/fused_softmax.py b/megatron/core/fusions/fused_softmax.py deleted file mode 100644 index 56eb2e80111a1ff2eb0af95ea8cfb4c70a2ab9f0..0000000000000000000000000000000000000000 --- a/megatron/core/fusions/fused_softmax.py +++ /dev/null @@ -1,204 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. - - -import torch -import torch.nn as nn - -from megatron.core.transformer.enums import AttnMaskType - - -class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function): - """ - Fused operation which performs following three operations in sequence - 1. Scale the tensor. - 2. Apply upper triangular mask (typically used in gpt models). - 3. Perform softmax. 
- """ - - @staticmethod - def forward(ctx, inputs, scale): - import scaled_upper_triang_masked_softmax_cuda - - scale_t = torch.tensor([scale]) - softmax_results = scaled_upper_triang_masked_softmax_cuda.forward(inputs, scale_t[0]) - - ctx.save_for_backward(softmax_results, scale_t) - return softmax_results - - @staticmethod - def backward(ctx, output_grads): - import scaled_upper_triang_masked_softmax_cuda - - softmax_results, scale_t = ctx.saved_tensors - input_grads = scaled_upper_triang_masked_softmax_cuda.backward( - output_grads, softmax_results, scale_t[0] - ) - - return input_grads, None - - -class ScaledMaskedSoftmax(torch.autograd.Function): - """ - Fused operation which performs following three operations in sequence - 1. Scale the tensor. - 2. Apply the mask. - 3. Perform softmax. - """ - - @staticmethod - def forward(ctx, inputs, mask, scale): - import scaled_masked_softmax_cuda - - scale_t = torch.tensor([scale]) - - softmax_results = scaled_masked_softmax_cuda.forward(inputs, mask, scale_t[0]) - ctx.save_for_backward(softmax_results, scale_t) - return softmax_results - - @staticmethod - def backward(ctx, output_grads): - import scaled_masked_softmax_cuda - - softmax_results, scale_t = ctx.saved_tensors - - input_grads = scaled_masked_softmax_cuda.backward(output_grads, softmax_results, scale_t[0]) - return input_grads, None, None - - -class ScaledSoftmax(torch.autograd.Function): - """ - Fused operation which performs following two operations in sequence - 1. Scale the tensor. - 2. Perform softmax. - """ - - @staticmethod - def forward(ctx, inputs, scale): - import scaled_softmax_cuda - - scale_t = torch.tensor([scale]) - - softmax_results = scaled_softmax_cuda.forward(inputs, scale_t[0]) - ctx.save_for_backward(softmax_results, scale_t) - return softmax_results - - @staticmethod - def backward(ctx, output_grads): - import scaled_softmax_cuda - - softmax_results, scale_t = ctx.saved_tensors - - input_grads = scaled_softmax_cuda.backward(output_grads, softmax_results, scale_t[0]) - return input_grads, None, None - - -class FusedScaleMaskSoftmax(nn.Module): - """ - fused operation: scaling + mask + softmax - - Arguments: - input_in_fp16: flag to indicate if input in fp16 data format. - input_in_bf16: flag to indicate if input in bf16 data format. - attn_mask_type: attention mask type (pad or causal) - scaled_masked_softmax_fusion: flag to indicate user want to use softmax fusion - mask_func: mask function to be applied. - softmax_in_fp32: if true, softmax in performed at fp32 precision. - scale: scaling factor used in input tensor scaling. - """ - - def __init__( - self, - input_in_fp16, - input_in_bf16, - attn_mask_type, - scaled_masked_softmax_fusion, - mask_func, - softmax_in_fp32, - scale, - ): - super(FusedScaleMaskSoftmax, self).__init__() - self.input_in_fp16 = input_in_fp16 - self.input_in_bf16 = input_in_bf16 - assert not ( - self.input_in_fp16 and self.input_in_bf16 - ), "both fp16 and bf16 flags cannot be active at the same time." 
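For intuition, the custom CUDA kernels wrapped by the three autograd functions above compute the same result as the plain-PyTorch reference below (shown here for the causal case); `scaled_upper_triang_masked_softmax_ref` is a name invented for this sketch and is not part of the removed module.

```python
import torch

def scaled_upper_triang_masked_softmax_ref(scores: torch.Tensor, scale: float) -> torch.Tensor:
    """Unfused reference: scale, apply a causal (upper-triangular) mask, softmax over keys."""
    # scores: [attn_batches, sq, sk] with sq == sk (self-attention only)
    sq, sk = scores.shape[-2], scores.shape[-1]
    causal = torch.triu(torch.ones(sq, sk, dtype=torch.bool, device=scores.device), diagonal=1)
    masked = (scores * scale).masked_fill(causal, float("-inf"))
    return torch.softmax(masked, dim=-1)

# Example: 2 attention batches, 4 query and 4 key positions.
probs = scaled_upper_triang_masked_softmax_ref(torch.randn(2, 4, 4), scale=0.125)
print(probs.shape)  # torch.Size([2, 4, 4]); each row sums to 1 over the unmasked keys
```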
- self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16 - self.attn_mask_type = attn_mask_type - self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion - self.mask_func = mask_func - self.softmax_in_fp32 = softmax_in_fp32 - self.scale = scale - - assert self.scale is None or softmax_in_fp32, "softmax should be in fp32 when scaled" - - def forward(self, input, mask): - # [b, np, sq, sk] - assert input.dim() == 4 - - if self.is_kernel_available(mask, *input.size()): - return self.forward_fused_softmax(input, mask) - else: - return self.forward_torch_softmax(input, mask) - - def is_kernel_available(self, mask, b, np, sq, sk): - attn_batches = b * np - - if ( - self.scaled_masked_softmax_fusion # user want to fuse - and self.input_in_float16 # input must be fp16 - and 16 < sk <= 4096 # sk must be 16 ~ 2048 - and sq % 4 == 0 # sq must be divisor of 4 - and sk % 4 == 0 # sk must be divisor of 4 - and attn_batches % 4 == 0 # np * b must be divisor of 4 - ): - if 0 <= sk <= 4096: - batch_per_block = self.get_batch_per_block(sq, sk, b, np) - - if self.attn_mask_type == AttnMaskType.causal: - if attn_batches % batch_per_block == 0: - return True - else: - if sq % batch_per_block == 0: - return True - return False - - def forward_fused_softmax(self, input, mask): - b, np, sq, sk = input.size() - scale = self.scale if self.scale is not None else 1.0 - - if self.attn_mask_type == AttnMaskType.causal: - assert sq == sk, "causal mask is only for self attention" - - # input is 3D tensor (attn_batches, sq, sk) - input = input.view(-1, sq, sk) - probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale) - return probs.view(b, np, sq, sk) - else: - # input is 4D tensor (b, np, sq, sk) - if mask is not None: - return ScaledMaskedSoftmax.apply(input, mask, scale) - else: - return ScaledSoftmax.apply(input, scale) - - def forward_torch_softmax(self, input, mask): - if self.input_in_float16 and self.softmax_in_fp32: - input = input.float() - - if self.scale is not None: - input = input * self.scale - mask_output = self.mask_func(input, mask) if mask is not None else input - probs = torch.nn.Softmax(dim=-1)(mask_output) - - if self.input_in_float16 and self.softmax_in_fp32: - if self.input_in_fp16: - probs = probs.half() - else: - probs = probs.bfloat16() - - return probs - - @staticmethod - def get_batch_per_block(sq, sk, b, np): - import scaled_masked_softmax_cuda - - return scaled_masked_softmax_cuda.get_batch_per_block(sq, sk, b, np) diff --git a/megatron/core/inference_params.py b/megatron/core/inference_params.py deleted file mode 100644 index 287902460fab6d411781fb15c86f0a333b7cf245..0000000000000000000000000000000000000000 --- a/megatron/core/inference_params.py +++ /dev/null @@ -1,27 +0,0 @@ -class InferenceParams: - """Inference parameters that are passed to the main model in order - to efficienly calculate and store the context during inference.""" - - def __init__(self, max_batch_size, max_sequence_length): - self.max_sequence_length = max_sequence_length - self.max_batch_size = max_batch_size - self.sequence_len_offset = 0 - self.batch_size_offset = 0 - self.key_value_memory_dict = {} - - def swap_key_value_dict(self, batch_idx): - "swap between batches" - if len(self.key_value_memory_dict) == 0: - raise ValueError("should not swap when dict in empty") - - for layer_number in self.key_value_memory_dict.keys(): - inference_key_memory, inference_value_memory = self.key_value_memory_dict[layer_number] - assert ( - len(batch_idx) == inference_key_memory.shape[1] - ) # make sure batch 
size is the same - new_inference_key_memory = inference_key_memory[:, batch_idx] - new_inference_value_memory = inference_value_memory[:, batch_idx] - self.key_value_memory_dict[layer_number] = ( - new_inference_key_memory, - new_inference_value_memory, - ) diff --git a/megatron/core/model_parallel_config.py b/megatron/core/model_parallel_config.py deleted file mode 100644 index 22d34da92129ce2193e5c297bc9ec8a6f3e555ff..0000000000000000000000000000000000000000 --- a/megatron/core/model_parallel_config.py +++ /dev/null @@ -1,222 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - -from dataclasses import dataclass -from typing import Callable, Optional - -import torch - - -@dataclass -class ModelParallelConfig: - """Base configuration for Megatron Core - - Model Parallelism - ----------------- - - tensor_model_parallel_size (int): Intra-layer model parallelism. Splits tensors across GPU ranks. Defaults to 1. - - context_parallel_size (int): Splits network input along sequence dimension across GPU ranks. Defaults to 1. - - pipeline_model_parallel_size (int): Inter-layer model parallelism. Splits transformer layers across GPU - ranks. Defaults to 1. - - virtual_pipeline_model_parallel_size (int): Interleaved pipeline parallelism is used to improve performance by - reducing the pipeline bubble. Considers a transformer block as a list of smaller transformer (virtual) blocks. - The number of virtual blocks per pipeline model parallel rank is the virtual model parallel size. See Efficient - Large-Scale Language Model Training on GPU Clusters Using Megatron-LM: https://arxiv.org/pdf/2104.04473.pdf for - more details. Defaults to None. - - sequence_parallel (bool): Makes tensor parallelism more memory efficient for LLMs (20B+) by - parallelizing layer norms and dropout sequentially. See Reducing Activation Recomputation in Large Transformer - Models: https://arxiv.org/abs/2205.05198 for more details. Defaults to False. - - expert_model_parallel_size (int): Distributes Moe Experts across sub data parallel dimension. Defaults to False. - - Initialization - -------------- - - perform_initialization (bool, default=True): If true, weights are initialized. This option can be useful when you - know you are going to load values from a checkpoint. - - use_cpu_initialization: (bool, default=False): When set to False, we initialize the weights directly on the GPU. - Transferring weights from CPU to GPU can take a significant amount of time for large models. Defaults to False. - - Training - -------- - - fp16 (bool): If true, train with fp16 mixed precision training. Defaults to False. - - bf16 (bool): If true, train with bf16 mixed precision training. Defaults to False. - - params_dtype (torch.dtype): dtype used when intializing the weights. Defaults to torch.float32 - - timers (optional, default=None): TODO - - Optimizations - ------------- - - gradient_accumulation_fusion (bool): If true, fuses weight gradient accumulation to GEMMs. Requires the custom CUDA - extension fused_weight_gradient_mlp_cuda module. To use gradient_accumulation_fusion you must install APEX with - --cpp_ext and --cuda_ext. For example: "pip install --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext\" - ". Note that the extension requires CUDA>=11. Otherwise, you must turn off gradient accumulation fusion. - Defaults to False. 
- - async_tensor_model_parallel_allreduce (bool, default=True): If true, enables asynchronous execution of - tensor-model-parallel all-reduce with weight gradient compuation of a column-linear layer. Defaults to False. - - tp_comm_overlap (bool, default=False): If true, allows overlapping of Linear layer execution with tensor parallel - communication collectives like AllGather/ReduceScatter. Overlapping is done for the linear layers wherever possible - during the forward and the backward pass. Defaults to False. - - tp_comm_split_ag (bool, default=True): If true, allows All-Gather overlap with Fprop GEMM. Don't care if tp_comm_overlap - is False. - - tp_comm_split_rs (bool, default=True): If true, allows Reduce-Scatter overlap with Fprop GEMM. Don't care if - tp_comm_overlap is False. - - tp_comm_bulk_dgrad (bool, default=True): If true, allows All-Gather overlap with Bprop activation gradient GEMM. Don't - care if tp_comm_overlap is False. - - tp_comm_bulk_wgrad (bool, default=True): If true, allows Reduce-Scatter overlap with Bprop weight gradient GEMM. Don't - care if tp_comm_overlap is False. - - Parallelism - ----------- - - finalize_model_grads_func (optional): Function that finalizes gradients on all workers. Could include ensuring that - grads are all-reduced across data parallelism, pipeline parallelism, and sequence parallelism dimensions. - - Pipeline Parallelism - -------------------- - - pipeline_dtype (required): dtype used in p2p communication, usually params_dtype - - grad_scale_func (optional, default=None): If using loss scaling, this function should take the loss and return the - scaled loss. If None, no function is called on the loss. - - enable_autocast (bool): If true runs the forward step function inside torch.autocast context. Default is False. - - autocast_dtype (torch.dtype): dtype to pass to torch.amp.autocast when enabled. Default is pipeline_dtype. - - variable_seq_lengths (bool, default=False): Support for variable sequence lengths across microbatches. Setting this - communicates the size of tensors during pipeline parallelism communication, because of this extra overhead it - should only be set if the sequence length varies by microbatch within a global batch. - - num_microbatches_with_partial_activation_checkpoints (int, default=None): If int, set the number of microbatches - where not all of the layers will be checkpointed and recomputed. The rest of the microbatches within the window - of maximum outstanding microbatches will recompute all layers (either full recompute or selective recompute). If - None, the checkpoint and recompute will be left up to the forward_step function. - - overlap_p2p_comm (bool, optional, default=False): When True some of the peer to peer communication for pipeline - parallelism will overlap with computation. Must be False if batch_p2p_comm is true. - - batch_p2p_comm (bool, default=True): Use batch_isend_irecv instead of individual isend/irecv calls. Must be False - if overlap_p2p_comm is True. - - batch_p2p_sync (bool, default=True): When using batch_isend_irecv, do a cuda.device.synchronize afterward to work - around a bug in older version of PyTorch. - - use_ring_exchange_p2p (bool, default=False): Use custom ring_exchange kernel instead of - torch.distributed.batch_isend_irecv(). Requires custom built torch with torch.distributed.ring_exchange. - - deallocate_pipeline_outputs (optional, default=False): If True, output data is deallocated after the tensor is sent - to the next pipeline stage. 
Helps with saving memory, does nothing when pipeline parallel is not used. - - no_sync_func (optional): Function that creates a context that suppresses asynchronous data-parallel - communication. If the model is an instance of core.distributed.DistributedDataParallel, the default is to use - core.distributed.DistributedDataParallel.no_sync. - - grad_sync_func (optional): Function that launches asynchronous gradient reductions (e.g. distributed optimizer - gradient reduce-scatters). The function should take one argument: an iterable of parameters whose gradients are - to be synchronized. - - param_sync_func (optional): Function that launches asynchronous parameter synchronizations (e.g. distributed - optimizer parameter all-gathers). The function should take one argument: an iterable of parameters to be - synchronized. - - pipeline_model_parallel_split_rank (int, default=None): If int, rank where encoder and decoder should be split in - cases where the model has both an encoder and decoder (e.g., T5). Ignored if None. - - barrier_with_L1_time (bool, default=True): If true, use barrier with level 1 time measurements. It is up to the user - to make sure calling barrier with their timers will not result in hangs. This can happen if for example the user - adds a level 1 timer that is not called by all ranks. - - """ - - # Model parallelism - tensor_model_parallel_size: int = 1 - context_parallel_size: int = 1 - pipeline_model_parallel_size: int = 1 - virtual_pipeline_model_parallel_size: Optional[int] = None - sequence_parallel: bool = False - expert_model_parallel_size: int = 1 - - # Initialization - perform_initialization: bool = True - use_cpu_initialization: bool = False - - # Training - fp16: bool = False - bf16: bool = False - params_dtype: torch.dtype = torch.float32 - timers: Callable = None - - # Optimizations - gradient_accumulation_fusion: bool = False - async_tensor_model_parallel_allreduce: bool = False - tp_comm_overlap: bool = False - - # Debug Options - tp_comm_split_ag: bool = True - tp_comm_split_rs: bool = True - tp_comm_bulk_wgrad: bool = True - tp_comm_bulk_dgrad: bool = True - - # Parallelism - finalize_model_grads_func: Callable = None - - # Pipeline Parallel - pipeline_dtype: torch.dtype = None - grad_scale_func: Callable = None - enable_autocast: bool = False - autocast_dtype: torch.dtype = None - variable_seq_lengths: bool = False - num_microbatches_with_partial_activation_checkpoints: Optional[int] = None - overlap_p2p_comm: bool = False - batch_p2p_comm: bool = True - batch_p2p_sync: bool = True - use_ring_exchange_p2p: bool = False - deallocate_pipeline_outputs: bool = False - no_sync_func: Callable = None - grad_sync_func: Callable = None - param_sync_func: Callable = None - pipeline_model_parallel_split_rank: Optional[int] = None - - # Timing - barrier_with_L1_time: bool = True - - def __post_init__(self): - """ Python dataclass method that is used to modify attributes after initialization. - See https://docs.python.org/3/library/dataclasses.html#post-init-processing for more details. 
- """ - if self.sequence_parallel: - if self.tensor_model_parallel_size <= 1: - raise ValueError("Can not use sequence paralllelism without tensor parallelism") - if self.async_tensor_model_parallel_allreduce: - # sequence_parallelism already does this async - self.async_tensor_model_parallel_allreduce = False - - if self.pipeline_model_parallel_size > 1: - if self.pipeline_dtype is None: - raise ValueError( - "When using pipeline parallelism, pipeline_dtype must be specified" - ) - - if self.autocast_dtype is None: - self.autocast_dtype = self.params_dtype - - if self.expert_model_parallel_size > 1 and self.tensor_model_parallel_size > 1: - if self.sequence_parallel is False: - raise ValueError( - "When using expert parallelism and tensor parallelism, sequence parallelism must be used" - ) diff --git a/megatron/core/models/T5/__init__.py b/megatron/core/models/T5/__init__.py deleted file mode 100644 index f65859a6dafcdfeb650f6b4a0da4fdecfe7f4dcf..0000000000000000000000000000000000000000 --- a/megatron/core/models/T5/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .t5_model import T5Model diff --git a/megatron/core/models/T5/t5_model.py b/megatron/core/models/T5/t5_model.py deleted file mode 100644 index f2ce4809f365ef523274147068cfb0d2815e3051..0000000000000000000000000000000000000000 --- a/megatron/core/models/T5/t5_model.py +++ /dev/null @@ -1,466 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - -import logging -from typing import List, Literal, Optional - -import torch -from torch import Tensor - -from megatron.core import InferenceParams, parallel_state, tensor_parallel -from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding -from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding -from megatron.core.models.common.language_module.language_module import LanguageModule -from megatron.core.transformer.enums import AttnMaskType, ModelType -from megatron.core.transformer.module import MegatronModule -from megatron.core.transformer.spec_utils import ModuleSpec -from megatron.core.transformer.transformer_block import TransformerBlock -from megatron.core.transformer.transformer_config import TransformerConfig -from megatron.core.utils import make_tp_sharded_tensor_for_checkpoint - - -class T5LMHead(MegatronModule): - """Masked LM head for T5 - - Args: - config (TransformerConfig): transformer config - parallel_output (bool): wether output logits being distributed or not. - vocab_size (int): vocabulary size - pre_process (bool): Include embedding layer - share_embeddings_and_output_weights (bool): When True, input embeddings and output logit weights are - shared. - """ - - def __init__( - self, - config: TransformerConfig, - parallel_output: bool, - vocab_size: int, - pre_process: bool = True, - share_embeddings_and_output_weights: bool = False, - ): - super(T5LMHead, self).__init__(config=config) - - self.parallel_output = parallel_output - - self.output_layer = tensor_parallel.ColumnParallelLinear( - config.hidden_size, - vocab_size, - config=config, - init_method=config.init_method, - bias=share_embeddings_and_output_weights, - skip_bias_add=not share_embeddings_and_output_weights, - gather_output=not self.parallel_output, - skip_weight_param_allocation=pre_process and share_embeddings_and_output_weights, - ) - - def forward(self, hidden_states: Tensor, word_embeddings_weight: Tensor) -> Tensor: - """Forward pass. 
- - Args: - hidden_states (Tensor): output hidden states from decoder - word_embeddings_weight (Tensor): word embedding weight - - Returns: - Tensor: logits tensor - """ - - logits, _ = self.output_layer(hidden_states, weight=word_embeddings_weight) - return logits - - -class T5Model(LanguageModule): - """T5 Language model. - - Args: - config (TransformerConfig): transformer config - - transformer_encoder_layer_spec (ModuleSpec): transformer layer customization specs for encoder - - transformer_decoder_layer_spec (ModuleSpec): transformer layer customization specs for decoder - - vocab_size (int): vocabulary size - - max_sequence_length (int): maximum size of sequence. This is used for positional embedding - - pre_process (bool): Include embedding layer (used with pipeline parallelism) - post_process (bool): Include an output layer (used with pipeline parallelism) - - fp16_lm_cross_entropy (bool, optional): Defaults to False - - parallel_output (bool): Do not gather the outputs, keep them split across tensor parallel ranks - - share_embeddings_and_output_weights (bool): When True, input embeddings and output logit weights are - shared. Defaults to False. - - position_embedding_type (string): Position embedding type. Options ['learned_absolute', 'rope']. - Defaults is 'learned_absolute'. - - rotary_percent (float): Percent of rotary dimension to use for rotary position embeddings. - Defaults to 1.0 (100%). Ignored unless position_embedding_type is 'rope'. - - seq_len_interpolation_factor (float): scale of linearly interpolating RoPE for longer sequences. - The value must be a float larger than 1.0. Defaults to None. - """ - - def __init__( - self, - config: TransformerConfig, - transformer_encoder_layer_spec: ModuleSpec, - transformer_decoder_layer_spec: ModuleSpec, - vocab_size: int, - max_sequence_length: int, - pre_process: bool = True, - post_process: bool = True, - fp16_lm_cross_entropy: bool = False, - parallel_output: bool = True, - share_embeddings_and_output_weights: bool = False, - position_embedding_type: Literal['learned_absolute', 'rope'] = 'learned_absolute', - rotary_percent: float = 1.0, - seq_len_interpolation_factor: Optional[float] = None, - ): - - super(T5Model, self).__init__(config=config) - - self.config: TransformerConfig = config - self.transformer_encoder_layer_spec: ModuleSpec = transformer_encoder_layer_spec - self.transformer_decoder_layer_spec: ModuleSpec = transformer_decoder_layer_spec - self.vocab_size = vocab_size - self.max_sequence_length = max_sequence_length - self.pre_process = pre_process - self.post_process = post_process - self.add_encoder = True - self.add_decoder = True - self.fp16_lm_cross_entropy = fp16_lm_cross_entropy - self.parallel_output = parallel_output - self.share_embeddings_and_output_weights = share_embeddings_and_output_weights - self.position_embedding_type = position_embedding_type - - # megatron core pipelining currently depends on model type - self.model_type = ModelType.encoder_and_decoder - - # Embeddings. 
- if self.pre_process: - self.embedding = LanguageModelEmbedding( - config=self.config, - vocab_size=self.vocab_size, - max_sequence_length=self.max_sequence_length, - position_embedding_type=self.position_embedding_type, - ) - - # Rotary Position Embeddings - if self.position_embedding_type == 'rope': - self.rotary_pos_emb = RotaryEmbedding( - self.config.kv_channels, rotary_percent, seq_len_interpolation_factor - ) - - # Transformer encoder - encoder_spec, decoder_spec = ( - self.transformer_encoder_layer_spec, - self.transformer_decoder_layer_spec, - ) - self.encoder = TransformerBlock( - config=self.config, - spec=encoder_spec, - pre_process=self.pre_process, - post_process=self.post_process, - ) - # Transformer decoder - self.decoder = TransformerBlock( - config=self.config, - spec=decoder_spec, - pre_process=self.pre_process, - post_process=self.post_process, - ) - - # Output - if post_process: - self.lm_head = T5LMHead( - config, - parallel_output, - self.vocab_size, - self.pre_process, - self.share_embeddings_and_output_weights, - ) - self.output_layer = self.lm_head.output_layer - - if self.share_embeddings_and_output_weights and (self.pre_process or self.post_process): - self.initialize_last_stage_with_word_embeddings() - - def forward( - self, - encoder_input_ids: Tensor, - decoder_input_ids: Tensor, - encoder_attn_mask: Tensor, - decoder_attn_mask: Tensor, - encoder_decoder_attn_mask: Tensor, - lm_labels: Tensor = None, - inference_params: InferenceParams = None, - ) -> Tensor: - """Forward pass. - - Args: - encoder_input_ids (Tensor): input ids for encoder - decoder_input_ids (Tensor): input ids for decoder - encoder_attn_mask (Tensor): self-attention mask for encoder - decoder_attn_mask (Tensor): self-attention mask for decoder - encoder_decoder_attn_mask (Tensor): cross-attention mask between encoder and decoder - lm_labels (Tensor): labels for decoder output - inference_params (InferenceParams): relevant arguments for inferencing - - Returns: - Tensor: loss tensor - """ - - ( - encoder_attn_mask, - decoder_attn_mask, - encoder_decoder_attn_mask, - ) = t5_extended_attention_mask( - [encoder_attn_mask, decoder_attn_mask, encoder_decoder_attn_mask] - ) - encoder_position_ids = t5_position_ids(encoder_input_ids) - decoder_position_ids = t5_position_ids(decoder_input_ids) - - ## Encoder forward - # Encoder embedding. - if self.pre_process: - encoder_input = self.embedding( - input_ids=encoder_input_ids, position_ids=encoder_position_ids - ) - else: - # intermediate stage of pipeline - encoder_input = None - - # Rotary positional embeddings - rotary_pos_emb = None - if self.position_embedding_type == 'rope': - rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len( - inference_params, self.encoder, encoder_input, self.config - ) - rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len) - - # Run encoder. - encoder_hidden_states = self.encoder( - hidden_states=encoder_input, - attention_mask=encoder_attn_mask, - inference_params=inference_params, - rotary_pos_emb=rotary_pos_emb, - ) - - ## Decoder forward - # Decoder embedding. 
- if self.pre_process: - decoder_input = self.embedding( - input_ids=decoder_input_ids, position_ids=decoder_position_ids - ) - else: - # intermediate stage of pipeline - decoder_input = None ### should it take encoder_hidden_states - - # Rotary positional embeddings - rotary_pos_emb = None - if self.position_embedding_type == 'rope': - rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len( - inference_params, self.decoder, decoder_input, self.config - ) - rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len) - - # Run decoder. - decoder_hidden_states = self.decoder( - hidden_states=decoder_input, - attention_mask=decoder_attn_mask, - context=encoder_hidden_states, - context_mask=encoder_decoder_attn_mask, - inference_params=inference_params, - rotary_pos_emb=rotary_pos_emb, - ) - - # Return if not post_process - if not self.post_process: - return decoder_hidden_states - - # logits and loss - output_weight = None - if self.share_embeddings_and_output_weights: - output_weight = self.shared_embedding_or_output_weight() - logits = self.lm_head(decoder_hidden_states, word_embeddings_weight=output_weight) - - if lm_labels is None: - # [s b h] => [b s h] - return logits.transpose(0, 1).contiguous() - - loss = self.compute_language_model_loss(lm_labels, logits) - - return loss - - def set_input_tensor(self, input_tensor): - """ See megatron.model.transformer.set_input_tensor()""" - - # This is usually handled in schedules.py but some inference code still - # gives us non-lists or None - if not isinstance(input_tensor, list): - input_tensor = [input_tensor] - - if self.add_encoder and self.add_decoder: - assert ( - len(input_tensor) == 1 - ), 'input_tensor should only be length 1 for stage with both encoder and decoder' - self.encoder.set_input_tensor(input_tensor[0]) - elif self.add_encoder: - assert ( - len(input_tensor) == 1 - ), 'input_tensor should only be length 1 for stage with only encoder' - self.encoder.set_input_tensor(input_tensor[0]) - elif self.add_decoder: - if len(input_tensor) == 2: - self.decoder.set_input_tensor(input_tensor[0]) - self.encoder_hidden_state = input_tensor[1] - elif len(input_tensor) == 1: - self.decoder.set_input_tensor(None) - self.encoder_hidden_state = input_tensor[0] - else: - raise Exception('input_tensor must have either length 1 or 2') - else: - raise Exception('Stage must have at least either encoder or decoder') - - def shared_embedding_or_output_weight(self) -> Tensor: - """Function to share the input embeddings and output logit weights.""" - - if self.pre_process: - return self.embedding.word_embeddings.weight - elif self.post_process: - return self.lm_head.output_layer.weight - return None - - def sharded_state_dict(self, prefix: str = ''): - sharded_state_dict = {} - - if self.pre_process: - embedding_prefix = f'{prefix}embedding.' - embedding_sharded_state_dict = self.embedding.sharded_state_dict( - prefix=embedding_prefix - ) - sharded_state_dict.update(embedding_sharded_state_dict) - - encoder_prefix = f'{prefix}encoder.' - encoder_sharded_state_dict = self.encoder.sharded_state_dict(prefix=encoder_prefix) - sharded_state_dict.update(encoder_sharded_state_dict) - - decoder_prefix = f'{prefix}decoder.' - decoder_sharded_state_dict = self.decoder.sharded_state_dict(prefix=decoder_prefix) - sharded_state_dict.update(decoder_sharded_state_dict) - - if self.post_process: - output_layer_prefix = f'{prefix}output_layer.' 
- output_layer_weight_key = f'{output_layer_prefix}weight' - output_layer_bias_key = f'{output_layer_prefix}bias' - if self.share_embeddings_and_output_weights: - if not self.pre_process: - # when sharing embeddings with last stage, we need to use the weights from the first stage - # on pipeline first rank, word embeddings are saved to {prefix}embedding.word_embeddings.weight - tensor = self.shared_embedding_or_output_weight() - first_stage_word_emb_key = f'{prefix}embedding.word_embeddings.weight' - dp_rank = parallel_state.get_data_parallel_rank() - dp_size = parallel_state.get_data_parallel_world_size() - last_stage_word_emb_replica_id = ( - dp_rank + dp_size - ) # copy of first stage embedding - - sharded_output_layer_tensor = make_tp_sharded_tensor_for_checkpoint( - tensor=tensor, - key=first_stage_word_emb_key, - replica_id=last_stage_word_emb_replica_id, - allow_shape_mismatch=True, - ) - - sharded_state_dict[output_layer_weight_key] = sharded_output_layer_tensor - # output_layer.weight is shared, but we still need to process output_layer.bias - sharded_output_layer_tensor = make_tp_sharded_tensor_for_checkpoint( - tensor=self.lm_head.output_layer.bias, - key=output_layer_bias_key, - allow_shape_mismatch=True, - ) - sharded_state_dict[output_layer_bias_key] = sharded_output_layer_tensor - else: - output_layer_state_dict = self.output_layer.state_dict( - prefix=output_layer_prefix, keep_vars=True - ) - output_layer_tensor = output_layer_state_dict[output_layer_weight_key] - # independent output layer - sharded_output_layer_tensor = make_tp_sharded_tensor_for_checkpoint( - tensor=output_layer_tensor, - key=output_layer_weight_key, - replica_id=parallel_state.get_data_parallel_rank(), - allow_shape_mismatch=True, - ) - - sharded_state_dict[output_layer_weight_key] = sharded_output_layer_tensor - - return sharded_state_dict - - def state_dict_for_save_checkpoint(self, prefix: str = '', keep_vars: bool = False): - """For easy load when model is combined with other heads, - add an extra key.""" - - state_dict_ = {} - state_dict_["embedding"] = self.embedding.state_dict_for_save_checkpoint( - prefix=prefix, keep_vars=keep_vars - ) - state_dict_["encoder"] = self.encoder.state_dict_for_save_checkpoint( - prefix=prefix, keep_vars=keep_vars - ) - state_dict_["decoder"] = self.decoder.state_dict_for_save_checkpoint( - prefix=prefix, keep_vars=keep_vars - ) - - if self.post_process and self.add_decoder: - state_dict_["lm_head"] = self.lm_head.state_dict_for_save_checkpoint( - prefix=prefix, keep_vars=keep_vars - ) - # Save word_embeddings. 
- if self.post_process and not self.pre_process and self.add_decoder: - state_dict_["word_embeddings_for_head"] = self.embedding.state_dict( - prefix=prefix, keep_vars=keep_vars - ) - return state_dict_ - - def load_state_dict(self, state_dict, strict=True): - """Customized load.""" - self.embedding.load_state_dict(state_dict["embedding"], strict=strict) - - self.encoder.load_state_dict(state_dict["encoder"], strict=strict) - - self.decoder.load_state_dict(state_dict["decoder"], strict=strict) - - if self.post_process and self.add_decoder: - self.lm_head.load_state_dict(state_dict["lm_head"], strict=strict) - - # Load word embeddings - if self.post_process and not self.pre_process and self.add_decoder: - self.word_embeddings.load_state_dict( - state_dict["word_embeddings_for_head"], strict=strict - ) - - -def t5_extended_attention_mask(attention_mask_list: List[Tensor]) -> List[Tensor]: - def attn_mask_postprocess(attn_mask): - # [b, 1, s, s] - extended_attention_mask = attn_mask.unsqueeze(1) - return extended_attention_mask - - return [attn_mask_postprocess(attn_mask) for attn_mask in attention_mask_list] - - -def t5_position_ids(token_ids: Tensor) -> Tensor: - """Calculate position ids from token ids - Args: - token_ids (Tensor): input tokens - - Returns: - Tensor: position ids - """ - seq_length = token_ids.size(1) - position_ids = torch.arange(seq_length, dtype=torch.long, device=token_ids.device) - position_ids = position_ids.unsqueeze(0).expand_as(token_ids) - - return position_ids diff --git a/megatron/core/models/T5/t5_spec.py b/megatron/core/models/T5/t5_spec.py deleted file mode 100644 index 60f33dbd9810d76ce014669b0f247eb8516017f9..0000000000000000000000000000000000000000 --- a/megatron/core/models/T5/t5_spec.py +++ /dev/null @@ -1,212 +0,0 @@ -from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add -from megatron.core.fusions.fused_layer_norm import FusedLayerNorm -from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear -from megatron.core.transformer.attention import ( - CrossAttention, - CrossAttentionSubmodules, - SelfAttention, - SelfAttentionSubmodules, -) -from megatron.core.transformer.custom_layers.transformer_engine import ( - TEColumnParallelLinear, - TEDotProductAttention, - TELayerNormColumnParallelLinear, - TENorm, - TERowParallelLinear, -) -from megatron.core.transformer.dot_product_attention import DotProductAttention -from megatron.core.transformer.enums import AttnMaskType -from megatron.core.transformer.mlp import MLP, MLPSubmodules -from megatron.core.transformer.spec_utils import ModuleSpec -from megatron.core.transformer.transformer_block import ( - TransformerBlockSubmodules, - get_num_layers_to_build, -) -from megatron.core.transformer.transformer_config import TransformerConfig -from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules - - -def encoder_model_with_transformer_engine_default_spec() -> ModuleSpec: - """T5 encoder TE spec (uses Transformer Engine components).""" - - return ModuleSpec( - module=TransformerLayer, - submodules=TransformerLayerSubmodules( - self_attention=ModuleSpec( - module=SelfAttention, - params={"attn_mask_type": AttnMaskType.padding}, - submodules=SelfAttentionSubmodules( - linear_qkv=TELayerNormColumnParallelLinear, - core_attention=TEDotProductAttention, - linear_proj=TERowParallelLinear, - ), - ), - self_attn_bda=get_bias_dropout_add, - mlp=ModuleSpec( - module=MLP, - submodules=MLPSubmodules( - 
linear_fc1=TELayerNormColumnParallelLinear, linear_fc2=TERowParallelLinear, - ), - ), - mlp_bda=get_bias_dropout_add, - ), - ) - - -def decoder_model_with_transformer_engine_default_spec() -> ModuleSpec: - """T5 decoder TE spec (uses Transformer Engine components).""" - - return ModuleSpec( - module=TransformerLayer, - submodules=TransformerLayerSubmodules( - self_attention=ModuleSpec( - module=SelfAttention, - params={"attn_mask_type": AttnMaskType.causal}, - submodules=SelfAttentionSubmodules( - linear_qkv=TELayerNormColumnParallelLinear, - core_attention=TEDotProductAttention, - linear_proj=TERowParallelLinear, - ), - ), - self_attn_bda=get_bias_dropout_add, - pre_cross_attn_layernorm=TENorm, - cross_attention=ModuleSpec( - module=CrossAttention, - submodules=CrossAttentionSubmodules( - linear_q=TEColumnParallelLinear, - linear_kv=TEColumnParallelLinear, - core_attention=TEDotProductAttention, - linear_proj=TERowParallelLinear, - ), - ), - cross_attn_bda=get_bias_dropout_add, - mlp=ModuleSpec( - module=MLP, - submodules=MLPSubmodules( - linear_fc1=TELayerNormColumnParallelLinear, linear_fc2=TERowParallelLinear, - ), - ), - mlp_bda=get_bias_dropout_add, - ), - ) - - -def encoder_model_with_local_spec() -> ModuleSpec: - """T5 encoder local spec (uses Megatron-Core components).""" - - return ModuleSpec( - module=TransformerLayer, - submodules=TransformerLayerSubmodules( - input_layernorm=FusedLayerNorm, - self_attention=ModuleSpec( - module=SelfAttention, - params={"attn_mask_type": AttnMaskType.padding}, - submodules=SelfAttentionSubmodules( - linear_qkv=ColumnParallelLinear, - core_attention=DotProductAttention, - linear_proj=RowParallelLinear, - ), - ), - self_attn_bda=get_bias_dropout_add, - pre_mlp_layernorm=FusedLayerNorm, - mlp=ModuleSpec( - module=MLP, - submodules=MLPSubmodules( - linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear, - ), - ), - mlp_bda=get_bias_dropout_add, - ), - ) - - -def decoder_model_with_local_spec() -> ModuleSpec: - """T5 decoder local spec (uses Megatron-Core components).""" - - return ModuleSpec( - module=TransformerLayer, - submodules=TransformerLayerSubmodules( - input_layernorm=FusedLayerNorm, - self_attention=ModuleSpec( - module=SelfAttention, - params={"attn_mask_type": AttnMaskType.causal}, - submodules=SelfAttentionSubmodules( - linear_qkv=ColumnParallelLinear, - core_attention=DotProductAttention, - linear_proj=RowParallelLinear, - ), - ), - self_attn_bda=get_bias_dropout_add, - pre_cross_attn_layernorm=FusedLayerNorm, - cross_attention=ModuleSpec( - module=CrossAttention, - submodules=CrossAttentionSubmodules( - linear_q=ColumnParallelLinear, - linear_kv=ColumnParallelLinear, - core_attention=DotProductAttention, - linear_proj=RowParallelLinear, - ), - ), - cross_attn_bda=get_bias_dropout_add, - pre_mlp_layernorm=FusedLayerNorm, - mlp=ModuleSpec( - module=MLP, - submodules=MLPSubmodules( - linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear, - ), - ), - mlp_bda=get_bias_dropout_add, - ), - ) - - -def get_t5_encoder_with_transformer_engine_block_spec( - num_layers: int, -) -> TransformerBlockSubmodules: - """T5 encoder block spec for Transformer Engine - - Args: - config (TransformerConfig): config, containing number of layers for encoder - """ - - layer_spec = encoder_model_with_transformer_engine_default_spec() - block_spec = TransformerBlockSubmodules([layer_spec] * num_layers) - return block_spec - - -def get_t5_decoder_with_transformer_engine_block_spec( - num_layers: int, -) -> TransformerBlockSubmodules: - """T5 decoder 
block spec for Transformer Engine - - Args: - config (TransformerConfig): config, containing number of layers for decoder - """ - - layer_spec = decoder_model_with_transformer_engine_default_spec() - block_spec = TransformerBlockSubmodules([layer_spec] * num_layers) - return block_spec - - -def get_t5_encoder_with_local_block_spec(num_layers: int) -> TransformerBlockSubmodules: - """T5 encoder block spec for local (uses Megatron-Core components) - - Args: - num_layers (int): number of encoder layers - """ - - layer_spec = encoder_model_with_local_spec() - block_spec = TransformerBlockSubmodules([layer_spec] * num_layers) - return block_spec - - -def get_t5_decoder_with_local_block_spec(num_layers: int) -> TransformerBlockSubmodules: - """T5 decoder block spec for local (uses Megatron-Core components) - - Args: - num_layers (int): number of decoder layers - """ - - layer_spec = decoder_model_with_local_spec() - block_spec = TransformerBlockSubmodules([layer_spec] * num_layers) - return block_spec diff --git a/megatron/core/models/__init__.py b/megatron/core/models/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/megatron/core/models/bert/__init__.py b/megatron/core/models/bert/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/megatron/core/models/bert/bert_layer_specs.py b/megatron/core/models/bert/bert_layer_specs.py deleted file mode 100644 index 9c36711fdd38cc1dced5764a4dca9f4f6def309d..0000000000000000000000000000000000000000 --- a/megatron/core/models/bert/bert_layer_specs.py +++ /dev/null @@ -1,64 +0,0 @@ -from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add -from megatron.core.fusions.fused_layer_norm import FusedLayerNorm -from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear -from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules -from megatron.core.transformer.custom_layers.transformer_engine import ( - TEDotProductAttention, - TELayerNormColumnParallelLinear, - TERowParallelLinear, -) -from megatron.core.transformer.dot_product_attention import DotProductAttention -from megatron.core.transformer.enums import AttnMaskType -from megatron.core.transformer.mlp import MLP, MLPSubmodules -from megatron.core.transformer.spec_utils import ModuleSpec -from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules - -# Use this spec to use lower level Transformer Engine modules (required for fp8 training) -bert_layer_with_transformer_engine_spec = ModuleSpec( - module=TransformerLayer, - submodules=TransformerLayerSubmodules( - self_attention=ModuleSpec( - module=SelfAttention, - params={"attn_mask_type": AttnMaskType.padding}, - submodules=SelfAttentionSubmodules( - linear_qkv=TELayerNormColumnParallelLinear, - core_attention=TEDotProductAttention, - linear_proj=TERowParallelLinear, - ), - ), - self_attn_bda=get_bias_dropout_add, - mlp=ModuleSpec( - module=MLP, - submodules=MLPSubmodules( - linear_fc1=TELayerNormColumnParallelLinear, linear_fc2=TERowParallelLinear, - ), - ), - mlp_bda=get_bias_dropout_add, - ), -) - -# Use this spec for an implementation using only modules in megatron core -bert_layer_local_spec = ModuleSpec( - module=TransformerLayer, - submodules=TransformerLayerSubmodules( - input_layernorm=FusedLayerNorm, - self_attention=ModuleSpec( - module=SelfAttention, - 
params={"attn_mask_type": AttnMaskType.padding}, - submodules=SelfAttentionSubmodules( - linear_qkv=ColumnParallelLinear, - core_attention=DotProductAttention, - linear_proj=RowParallelLinear, - ), - ), - self_attn_bda=get_bias_dropout_add, - pre_mlp_layernorm=FusedLayerNorm, - mlp=ModuleSpec( - module=MLP, - submodules=MLPSubmodules( - linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear, - ), - ), - mlp_bda=get_bias_dropout_add, - ), -) diff --git a/megatron/core/models/bert/bert_lm_head.py b/megatron/core/models/bert/bert_lm_head.py deleted file mode 100644 index ea6f8f122604a30eb82f77bcca46dcc8e7e3d858..0000000000000000000000000000000000000000 --- a/megatron/core/models/bert/bert_lm_head.py +++ /dev/null @@ -1,72 +0,0 @@ -import torch -from torch import Tensor - -from megatron.core import tensor_parallel -from megatron.core.transformer.module import MegatronModule -from megatron.core.transformer.transformer_config import TransformerConfig -from megatron.core.transformer.utils import erf_gelu, get_linear_layer, openai_gelu -from megatron.model import LayerNorm - - -class BertLMHead(MegatronModule): - """Masked LM head for Bert - - Args: - hidden_size: hidden size - config (TransformerConfig): TransformerConfig object - parallel_output (bool): Do not gather the outputs, keep them split across tensor parallel ranks - vocab_size(int): The vocabulary size - share_embeddings_and_output_weights (bool): When True, input embeddings and output logit weights are shared. Defaults to False - pre_process (bool): Include embedding layer (used with pipeline parallelism) - """ - - def __init__( - self, - hidden_size: int, - config: TransformerConfig, - parallel_output: bool, - vocab_size: int, - pre_process: bool, - share_embeddings_and_output_weights: bool = False, - ): - super().__init__(config=config) - - self.vocab_size = vocab_size - self.parallel_output = parallel_output - - # TODO: Shoudl switch this to TE ? 
- self.dense = get_linear_layer( - hidden_size, hidden_size, config.init_method, config.perform_initialization - ) - - setattr(self.dense.weight, 'sequence_parallel', config.sequence_parallel) - setattr(self.dense.bias, 'sequence_parallel', config.sequence_parallel) - - self.layernorm = LayerNorm( - hidden_size, eps=config.layernorm_epsilon, sequence_parallel=config.sequence_parallel - ) - - self.gelu = torch.nn.functional.gelu - # TODO Use activation_func in config to determine what to use - # if config.openai_gelu: # Dont have these configs in transfomer config yet - # self.gelu = openai_gelu - # elif config.onnx_safe: # Dont have these configs in transfomer config yet - # self.gelu = erf_gelu - - self.output_layer = tensor_parallel.ColumnParallelLinear( - config.hidden_size, - self.vocab_size, - config=config, - init_method=config.init_method, - bias=True, - skip_bias_add=False, - gather_output=not self.parallel_output, - skip_weight_param_allocation=pre_process and share_embeddings_and_output_weights, - ) - - def forward(self, hidden_states: Tensor, word_embeddings_weight: Tensor) -> Tensor: - hidden_states = self.dense(hidden_states) - hidden_states = self.gelu(hidden_states) - hidden_states = self.layernorm(hidden_states) - logits, _ = self.output_layer(hidden_states, weight=word_embeddings_weight) - return logits diff --git a/megatron/core/models/bert/bert_model.py b/megatron/core/models/bert/bert_model.py deleted file mode 100644 index 165c1b39028b08d93f447412be01c241a044bb89..0000000000000000000000000000000000000000 --- a/megatron/core/models/bert/bert_model.py +++ /dev/null @@ -1,234 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. -from typing import Literal, Optional - -import torch -from torch import Tensor - -from megatron.core.models.bert.bert_lm_head import BertLMHead -from megatron.core.models.bert.pooler import Pooler -from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding -from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding -from megatron.core.models.common.language_module.language_module import LanguageModule -from megatron.core.transformer.enums import AttnMaskType, ModelType -from megatron.core.transformer.spec_utils import ModuleSpec -from megatron.core.transformer.transformer_block import TransformerBlock -from megatron.core.transformer.transformer_config import TransformerConfig -from megatron.core.transformer.utils import get_linear_layer -from megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids - - -class BertModel(LanguageModule): - """Transformer language model. - - Args: - config (TransformerConfig): transformer config - num_tokentypes (int) : Set to 2 when args.bert_binary_head is True, and 0 otherwise. Defaults to 0. - transformer_layer_spec (ModuleSpec): Specifies module to use for transformer layers - vocab_size (int): vocabulary size - max_sequence_length (int): maximum size of sequence. This is used for positional embedding - pre_process (bool): Include embedding layer (used with pipeline parallelism) - post_process (bool): Include an output layer (used with pipeline parallelism) - parallel_output (bool): Do not gather the outputs, keep them split across tensor parallel ranks - share_embeddings_and_output_weights (bool): When True, input embeddings and output logit weights are shared. Defaults to False. - position_embedding_type (string): Position embedding type. Options ['learned_absolute', 'rope']. 
- Defaults is 'learned_absolute'. - rotary_percent (float): Percent of rotary dimension to use for rotary position embeddings. - Defaults to 1.0 (100%). Ignored unless position_embedding_type is 'rope'. - """ - - def __init__( - self, - config: TransformerConfig, - num_tokentypes: int, - transformer_layer_spec: ModuleSpec, - vocab_size: int, - max_sequence_length: int, - pre_process: bool = True, - post_process: bool = True, - fp16_lm_cross_entropy: bool = False, - parallel_output: bool = True, - share_embeddings_and_output_weights: bool = False, - position_embedding_type: Literal['learned_absolute', 'rope'] = 'learned_absolute', - rotary_percent: float = 1.0, - seq_len_interpolation_factor: Optional[float] = None, - add_binary_head=True, - return_embeddings=False, - ): - super(BertModel, self).__init__(config=config) - - if return_embeddings: - assert self.post_process and self.add_binary_head - - self.config: TransformerConfig = config - self.transformer_layer_spec: ModuleSpec = transformer_layer_spec - self.vocab_size = vocab_size - self.max_sequence_length = max_sequence_length - self.pre_process = pre_process - self.post_process = post_process - self.fp16_lm_cross_entropy = fp16_lm_cross_entropy - self.parallel_output = parallel_output - self.share_embeddings_and_output_weights = share_embeddings_and_output_weights - self.position_embedding_type = position_embedding_type - self.add_binary_head = add_binary_head - self.return_embeddings = return_embeddings - - # megatron core pipelining currently depends on model type - self.model_type = ModelType.encoder_or_decoder - - # Embeddings. - if self.pre_process: - self.embedding = LanguageModelEmbedding( - config=self.config, - vocab_size=self.vocab_size, - max_sequence_length=self.max_sequence_length, - position_embedding_type=position_embedding_type, - num_tokentypes=num_tokentypes, - ) - - if self.position_embedding_type == 'rope': - self.rotary_pos_emb = RotaryEmbedding( - self.config.kv_channels, rotary_percent, seq_len_interpolation_factor - ) - - # Transformer. - self.encoder = TransformerBlock( - config=self.config, - spec=self.transformer_layer_spec, - pre_process=self.pre_process, - post_process=self.post_process, - ) - - # Output - if post_process: - # TODO: Make sure you are passing in the mpu_vocab_size properly - self.lm_head = BertLMHead( - config.hidden_size, - config, - parallel_output, - self.vocab_size, - self.pre_process, - self.share_embeddings_and_output_weights, - ) - - self.output_layer = self.lm_head.output_layer - - self.binary_head = None - if self.add_binary_head: - # TODO: Shoudl switch this to TE ? - self.binary_head = get_linear_layer( - config.hidden_size, 2, config.init_method, config.perform_initialization - ) - - self.pooler = Pooler( - config.hidden_size, config.init_method, config, config.sequence_parallel - ) - - if self.share_embeddings_and_output_weights and (self.pre_process or self.post_process): - self.initialize_last_stage_with_word_embeddings() - - def set_input_tensor(self, input_tensor: Tensor) -> None: - """Sets input tensor to the model. - - See megatron.model.transformer.set_input_tensor() - - Args: - input_tensor (Tensor): Sets the input tensor for the model. 
- """ - # This is usually handled in schedules.py but some inference code still - # gives us non-lists or None - if not isinstance(input_tensor, list): - input_tensor = [input_tensor] - - assert len(input_tensor) == 1, 'input_tensor should only be length 1 for gpt/bert' - self.encoder.set_input_tensor(input_tensor[0]) - - def forward( - self, - input_ids: Tensor, - attention_mask: Tensor, - tokentype_ids: Tensor = None, - lm_labels: Tensor = None, - inference_params=None, - ): - """Forward function of BERT model - - Forward function of the BERT Model This function passes the input tensors - through the embedding layer, and then the encoder and finally into the post - processing layer (optional). - - It either returns the Loss values if labels are given or the final hidden units - """ - extended_attention_mask = bert_extended_attention_mask(attention_mask) - - position_ids = bert_position_ids(input_ids) - - # Encoder embedding. - if self.pre_process: - encoder_input = self.embedding( - input_ids=input_ids, position_ids=position_ids, tokentype_ids=tokentype_ids - ) - else: - # intermediate stage of pipeline - # decoder will get hidden_states from encoder.input_tensor - encoder_input = None - - # Rotary positional embeddings (Why not move this into BERT/GPTEmberdding ?) - rotary_pos_emb = None - if self.position_embedding_type == 'rope': - rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len( - inference_params, self.encoder, encoder_input, self.config - ) - rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len) - - # Run decoder. - hidden_states = self.encoder( - hidden_states=encoder_input, - attention_mask=extended_attention_mask, - inference_params=inference_params, - rotary_pos_emb=rotary_pos_emb, - ) - if not self.post_process: - return hidden_states - - if self.add_binary_head: - pooled_output = self.pooler(hidden_states, 0) - - if self.return_embeddings: - embeddings = torch.transpose(hidden_states, 0, 1) - masks = torch.sum(attention_mask, dim=1) - # Collect masked embeddings. 
- output = torch.zeros( - size=(embeddings.shape[0], embeddings.shape[2]), - dtype=torch.float32, - device=torch.cuda.current_device(), - ) - for i, (embedding, mask) in enumerate(zip(embeddings, masks)): - output[i, :] = torch.mean(embedding[1 : mask - 1], dim=0) - return output - - # logits and loss - output_weight = None - if self.share_embeddings_and_output_weights: - output_weight = self.shared_embedding_or_output_weight() - - logits = self.lm_head(hidden_states=hidden_states, word_embeddings_weight=output_weight) - - binary_logits = None - if self.binary_head is not None: - binary_logits = self.binary_head(pooled_output) - - if lm_labels is None: - # [s b h] => [b s h] - return logits.transpose(0, 1).contiguous(), binary_logits - - loss = self.compute_language_model_loss(lm_labels, logits) - - return loss, binary_logits - - # TODO: add distributed checkpointing - def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False): - pass - - # TODO: add distributed checkpointing - def load_state_dict(self, state_dict, strict=True): - pass diff --git a/megatron/core/models/bert/pooler.py b/megatron/core/models/bert/pooler.py deleted file mode 100644 index c144d8c9c4daba6f1e43d1fe5bc78c9365e1c90e..0000000000000000000000000000000000000000 --- a/megatron/core/models/bert/pooler.py +++ /dev/null @@ -1,51 +0,0 @@ -import torch -from torch import Tensor - -from megatron.core import tensor_parallel -from megatron.core.transformer.module import MegatronModule -from megatron.core.transformer.transformer_config import TransformerConfig -from megatron.core.transformer.utils import get_linear_layer - - -class Pooler(MegatronModule): - """Pooler layer. - - Pool hidden states of a specific token (for example start of the - sequence) and add a linear transformation followed by a tanh. - - Args: - hidden_size (int): The hidden size_ - init_method (callable): weight initialization method for the linear layer. bias is set to zero. - config (TransformerConfig): The transformer configuration - sequence_parallel (bool): Using squence parallel ? Defaults to False - """ - - def __init__( - self, - hidden_size: int, - init_method: callable, - config: TransformerConfig, - sequence_parallel: bool = False, - ): - super(Pooler, self).__init__(config) - # TODO: Shoudl switch this to TE ? - self.dense = get_linear_layer( - hidden_size, hidden_size, init_method, config.perform_initialization - ) - self.sequence_parallel = sequence_parallel - - def forward(self, hidden_states: Tensor, sequence_index=0): - # hidden_states: [s, b, h] - # sequence_index: index of the token to pool. 
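# --- Illustrative sketch (not part of the deleted file): the return_embeddings
# branch above mean-pools each sample's token embeddings between position 1 and
# the last unpadded position. A plain-PyTorch version of the same pooling,
# assuming batch-first embeddings and a 0/1 padding mask:
import torch


def toy_masked_mean(embeddings: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    # embeddings: [b, s, h]; attention_mask: [b, s] with 1 marking real tokens.
    lengths = attention_mask.sum(dim=1)
    pooled = torch.zeros(embeddings.size(0), embeddings.size(2), dtype=torch.float32)
    for i, (emb, n) in enumerate(zip(embeddings, lengths)):
        pooled[i] = emb[1 : int(n) - 1].mean(dim=0)  # same slice as the code above
    return pooled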
- - # gather data along sequence dimensions - # same pooler is run on all tensor parallel nodes - if self.sequence_parallel: - hidden_states = tensor_parallel.gather_from_sequence_parallel_region( - hidden_states, tensor_parallel_output_grad=False - ) - - pooled = hidden_states[sequence_index, :, :] - pooled = self.dense(pooled) - pooled = torch.tanh(pooled) - return pooled diff --git a/megatron/core/models/common/__init__.py b/megatron/core/models/common/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/megatron/core/models/common/embeddings/__init__.py b/megatron/core/models/common/embeddings/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/megatron/core/models/common/embeddings/language_model_embedding.py b/megatron/core/models/common/embeddings/language_model_embedding.py deleted file mode 100644 index 40d679d7b13dd44797957909c4c4fd2272e1f932..0000000000000000000000000000000000000000 --- a/megatron/core/models/common/embeddings/language_model_embedding.py +++ /dev/null @@ -1,163 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - -from typing import Literal, Optional - -import torch -from torch import Tensor - -from megatron.core import tensor_parallel -from megatron.core.transformer.module import MegatronModule -from megatron.core.transformer.transformer_config import TransformerConfig -from megatron.core.utils import ( - make_sharded_tensor_for_checkpoint, - make_tp_sharded_tensor_for_checkpoint, -) - - -class LanguageModelEmbedding(MegatronModule): - """Language model embeddings. - - Arguments: - config (TransformerConfig): config object with all necessary configs for TransformerBlock - vocab_size (int): vocabulary size - max_sequence_length (int): maximum size of sequence. This - is used for positional embedding - add_position_embedding (bool): Add a position embedding. - embedding_dropout_prob (float): dropout probability for embeddings - num_tokentypes (int): Set to 0 without binary head, and 2 with a binary head . Defaults to 0. - """ - - def __init__( - self, - config: TransformerConfig, - vocab_size: int, - max_sequence_length: int, - position_embedding_type: Literal['learned_absolute', 'rope'] = 'learned_absolute', - num_tokentypes: int = 0, - ): - super().__init__(config=config) - - self.config: TransformerConfig = config - self.vocab_size: int = vocab_size - self.max_sequence_length: int = max_sequence_length - self.add_position_embedding: bool = position_embedding_type == 'learned_absolute' - self.num_tokentypes = num_tokentypes - - # Word embeddings (parallel). - self.word_embeddings = tensor_parallel.VocabParallelEmbedding( - num_embeddings=self.vocab_size, - embedding_dim=self.config.hidden_size, - init_method=self.config.init_method, - config=self.config, - ) - - # Position embedding (serial). - if self.add_position_embedding: - self.position_embeddings = torch.nn.Embedding( - self.max_sequence_length, self.config.hidden_size - ) - - # Initialize the position embeddings. - if self.config.perform_initialization: - self.config.init_method(self.position_embeddings.weight) - - if self.num_tokentypes > 0: - self.tokentype_embeddings = torch.nn.Embedding( - self.num_tokentypes, self.config.hidden_size - ) - # Initialize the token-type embeddings. 
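# --- Illustrative sketch (not part of the deleted file): the Pooler above takes the
# hidden state of a single token (by default the first), applies a dense layer and a
# tanh. A minimal single-device analogue without tensor/sequence parallelism:
import torch


class ToyPooler(torch.nn.Module):
    def __init__(self, hidden_size: int):
        super().__init__()
        self.dense = torch.nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states: torch.Tensor, sequence_index: int = 0) -> torch.Tensor:
        # hidden_states: [s, b, h] (sequence-first, as in the Megatron code above).
        pooled = hidden_states[sequence_index, :, :]  # [b, h]
        return torch.tanh(self.dense(pooled))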
- if self.config.perform_initialization: - self.config.init_method(self.tokentype_embeddings.weight) - else: - self.tokentype_embeddings = None - - # Embeddings dropout - self.embedding_dropout = torch.nn.Dropout(self.config.hidden_dropout) - - def zero_parameters(self): - """Zero out all parameters in embedding.""" - self.word_embeddings.weight.data.fill_(0) - self.word_embeddings.weight.shared = True - self.position_embeddings.weight.data.fill_(0) - self.position_embeddings.weight.shared = True - if self.num_tokentypes > 0: - self.tokentype_embeddings.weight.data.fill_(0) - self.tokentype_embeddings.weight.shared = True - - def forward(self, input_ids: Tensor, position_ids: Tensor, tokentype_ids: int = None) -> Tensor: - """Forward pass of the embedding module - Args: - input_ids (Tensor): The input tokens - position_ids (Tensor): The position id's used to calculate position embeddings - tokentype_ids (int): The token type ids. Used when args.bert_binary_head is set to True. Defaults to None - - Returns: - Tensor: The output embeddings - """ - word_embeddings = self.word_embeddings(input_ids) - if self.add_position_embedding: - position_embeddings = self.position_embeddings(position_ids) - embeddings = word_embeddings + position_embeddings - else: - embeddings = word_embeddings - - # Data format change to avoid explicit tranposes : [b s h] --> [s b h]. - embeddings = embeddings.transpose(0, 1).contiguous() - - if tokentype_ids is not None: - assert self.tokentype_embeddings is not None - # [b s h] -> [s b h] (So that it can be added with embeddings) - tokentype_embedding = self.tokentype_embeddings(tokentype_ids).permute(1, 0, 2) - embeddings = embeddings + tokentype_embedding - else: - assert self.tokentype_embeddings is None - - # If the input flag for fp32 residual connection is set, convert for float. - if self.config.fp32_residual_connection: - embeddings = embeddings.float() - - # Dropout. - if self.config.sequence_parallel: - embeddings = tensor_parallel.scatter_to_sequence_parallel_region(embeddings) - # `scatter_to_sequence_parallel_region` returns a view, which prevents - # the original tensor from being garbage collected. Clone to facilitate GC. - # Has a small runtime cost (~0.5%). - if self.config.clone_scatter_output_in_embedding: - embeddings = embeddings.clone() - with tensor_parallel.get_cuda_rng_tracker().fork(): - embeddings = self.embedding_dropout(embeddings) - else: - embeddings = self.embedding_dropout(embeddings) - - return embeddings - - def sharded_state_dict(self, prefix=''): - - sharded_state_dict = {} - - word_embeddings_prefix = f'{prefix}word_embeddings.' - word_embeddings_state_dict = self.word_embeddings.state_dict( - prefix=word_embeddings_prefix, keep_vars=True - ) - - sharded_word_embeddings_key = f'{word_embeddings_prefix}weight' - sharded_word_embeddings_tensor = make_tp_sharded_tensor_for_checkpoint( - tensor=word_embeddings_state_dict[sharded_word_embeddings_key], - key=sharded_word_embeddings_key, - allow_shape_mismatch=True, - ) - sharded_state_dict[sharded_word_embeddings_key] = sharded_word_embeddings_tensor - - if self.add_position_embedding: - position_embeddings_prefix = f'{prefix}position_embeddings.' 
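# --- Illustrative sketch (not part of the deleted file): the embedding forward
# above adds word and (optionally) learned position / token-type embeddings, then
# switches to the sequence-first [s, b, h] layout and applies dropout. A minimal
# single-device analogue:
import torch


class ToyEmbedding(torch.nn.Module):
    def __init__(self, vocab_size: int, max_seq_len: int, hidden_size: int, dropout: float = 0.1):
        super().__init__()
        self.word = torch.nn.Embedding(vocab_size, hidden_size)
        self.pos = torch.nn.Embedding(max_seq_len, hidden_size)
        self.dropout = torch.nn.Dropout(dropout)

    def forward(self, input_ids: torch.Tensor, position_ids: torch.Tensor) -> torch.Tensor:
        emb = self.word(input_ids) + self.pos(position_ids)   # [b, s, h]
        emb = emb.transpose(0, 1).contiguous()                # [s, b, h]
        return self.dropout(emb)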
- position_embeddings_state_dict = self.position_embeddings.state_dict( - prefix=position_embeddings_prefix, keep_vars=True - ) - sharded_position_embeddings_key = f'{position_embeddings_prefix}weight' - sharded_position_embeddings_tensor = make_sharded_tensor_for_checkpoint( - tensor=position_embeddings_state_dict[sharded_position_embeddings_key], - key=sharded_position_embeddings_key, - ) - sharded_state_dict[sharded_position_embeddings_key] = sharded_position_embeddings_tensor - - return sharded_state_dict diff --git a/megatron/core/models/common/embeddings/rotary_pos_embedding.py b/megatron/core/models/common/embeddings/rotary_pos_embedding.py deleted file mode 100644 index ee2260e3ae001c756d3011bb9fe03ece338716be..0000000000000000000000000000000000000000 --- a/megatron/core/models/common/embeddings/rotary_pos_embedding.py +++ /dev/null @@ -1,167 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - -from __future__ import annotations - -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from megatron.core.transformer.transformer_config import TransformerConfig - from megatron.core.transformer.transformer_block import TransformerBlock - -import torch -from torch import Tensor, nn - -from megatron.core import parallel_state - -__all__ = ['RotaryEmbedding', 'apply_rotary_pos_emb'] - - -def get_pos_emb_on_this_cp_rank(pos_emb, seq_dim): - cp_size = parallel_state.get_context_parallel_world_size() - cp_rank = parallel_state.get_context_parallel_rank() - cp_idx = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)], device=pos_emb.device) - pos_emb = pos_emb.view( - *pos_emb.shape[:seq_dim], 2 * cp_size, -1, *pos_emb.shape[(seq_dim + 1) :] - ) - pos_emb = pos_emb.index_select(seq_dim, cp_idx) - pos_emb = pos_emb.view(*pos_emb.shape[:seq_dim], -1, *pos_emb.shape[(seq_dim + 2) :]) - return pos_emb - - -class RotaryEmbedding(nn.Module): - """Rotary Embedding for language model. - - Args: - kv_channels (int): Projection weights dimension in multi-head attention. Obtained from transformer config - rotary_percent (float): Percent of rotary dimension to use for rotary position embeddings. - seq_len_interpolation_factor (float, optional): scale of linearly interpolating RoPE for longer sequences. The value must be a float larger than 1.0. Defaults to None - rotary_base (int, optional): Base period for rotary position embeddings. Defaults to 10000. - """ - - def __init__( - self, - kv_channels: int, - rotary_percent: float, - seq_len_interpolation_factor: float = None, - rotary_base: int = 10000, - ) -> None: - super().__init__() - - dim = kv_channels - if rotary_percent < 1.0: - dim = int(dim * rotary_percent) - - self.seq_len_interpolation_factor = seq_len_interpolation_factor - self.inv_freq = 1.0 / ( - rotary_base - ** ( - torch.arange(0, dim, 2, dtype=torch.float32, device=torch.cuda.current_device()) - / dim - ) - ) - - def forward(self, max_seq_len: int, offset: int = 0) -> Tensor: - """Forward pass of RoPE embedding. - - Args: - max_seq_len (int): Maximum size of sequence - offset (int, optional): _description_. Defaults to 0. - - Returns: - Tensor: Embeddings after applying RoPE. 
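# --- Illustrative sketch (not part of the deleted file): the RoPE math implemented
# by RotaryEmbedding.forward and apply_rotary_pos_emb in this file, re-derived in a
# few lines of plain PyTorch for a single [s, b, heads, head_dim] tensor:
import torch


def toy_rope(t: torch.Tensor, base: int = 10000) -> torch.Tensor:
    s, _, _, dim = t.shape
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    freqs = torch.outer(torch.arange(s, dtype=torch.float32), inv_freq)   # [s, dim/2]
    emb = torch.cat((freqs, freqs), dim=-1)[:, None, None, :]             # [s, 1, 1, dim]
    cos_, sin_ = emb.cos().to(t.dtype), emb.sin().to(t.dtype)
    x1, x2 = torch.chunk(t, 2, dim=-1)
    rotated = torch.cat((-x2, x1), dim=-1)                                # "rotate half"
    return t * cos_ + rotated * sin_


q = torch.randn(16, 2, 8, 64)      # [seq, batch, heads, head_dim]
q_rope = toy_rope(q)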
- """ - seq = ( - torch.arange(max_seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype) - + offset - ) - - if self.seq_len_interpolation_factor is not None: - seq *= 1 / self.seq_len_interpolation_factor - - freqs = torch.outer(seq, self.inv_freq) - # first part even vector components, second part odd vector components, - # 2 * dim in dimension size - emb = torch.cat((freqs, freqs), dim=-1) - # emb [seq_length, .., dim] - emb = emb[:, None, None, :] - if parallel_state.get_context_parallel_world_size() > 1: - # slice rotary_pos_emb along sequence dimension and select the parition of the current CP rank - emb = get_pos_emb_on_this_cp_rank(emb, 0) - return emb - - def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): - state_dict.pop(f'{prefix}inv_freq', None) - return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) - - def get_rotary_seq_len( - self, - inference_params, - transformer: TransformerBlock, - transformer_input: Tensor, - transformer_config: TransformerConfig, - ) -> float: - """Function to get the rotary sequence length. - - Args: - inference_params : Used during Inference time - transformer (TransformerBlock): The transformer block (decoder/encoder) used by the model - transformer_input (Tensor): _description_ - transformer_config (TransformerConfig): Transformer config used by the model - - Returns: - float: The rotary sequence length - """ - if inference_params is not None: - rotary_seq_len = inference_params.max_sequence_length - else: - if transformer.input_tensor is not None: - rotary_seq_len = transformer.input_tensor.size(0) - else: - rotary_seq_len = transformer_input.size(0) - - if transformer_config.sequence_parallel: - rotary_seq_len *= transformer_config.tensor_model_parallel_size - - rotary_seq_len *= transformer_config.context_parallel_size - - return rotary_seq_len - - -def _rotate_half(x: Tensor) -> Tensor: - """Change sign so the last dimension becomes [-odd, +even] - - Args: - x (Tensor): Input tensor - - Returns: - Tensor: Tensor rotated half - """ - - x1, x2 = torch.chunk(x, 2, dim=-1) - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb(t: Tensor, freqs: Tensor) -> Tensor: - """Apply rotary positional embedding to input tensor T. - - check https://kexue.fm/archives/8265 for detailed formulas - - Args: - t (Tensor): Input tensor T is of shape [seq_length, ... 
, dim] - freqs (Tensor): Rotary Positional embedding tensor freq is of shape [seq_length, ..., dim] - - Returns: - Tensor: The input tensor after applying RoPE - """ - rot_dim = freqs.shape[-1] - - # ideally t_pass is empty so rotary pos embedding is applied to all tensor t - t, t_pass = t[..., :rot_dim], t[..., rot_dim:] - - # first part is cosine component - # second part is sine component, need to change signs with _rotate_half method - cos_ = torch.cos(freqs).to(t.dtype) - sin_ = torch.sin(freqs).to(t.dtype) - - t = (t * cos_) + (_rotate_half(t) * sin_) - return torch.cat((t, t_pass), dim=-1) diff --git a/megatron/core/models/common/language_module/__init__.py b/megatron/core/models/common/language_module/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/megatron/core/models/common/language_module/language_module.py b/megatron/core/models/common/language_module/language_module.py deleted file mode 100644 index 97fbbf0f66229eb2ef4ecea52940bfef86342ad8..0000000000000000000000000000000000000000 --- a/megatron/core/models/common/language_module/language_module.py +++ /dev/null @@ -1,98 +0,0 @@ -import logging - -import torch -from torch import Tensor - -from megatron.core import parallel_state, tensor_parallel -from megatron.core.transformer.module import MegatronModule -from megatron.core.transformer.transformer_config import TransformerConfig - - -class LanguageModule(MegatronModule): - """Base language module that has common helper functions used across GPT, BERT etc. - - Args: - config (TransformerConfig): Input transformer config for the model - """ - - def __init__(self, config: TransformerConfig) -> None: - super().__init__(config=config) - - def compute_language_model_loss(self, labels: Tensor, logits: Tensor) -> Tensor: - """Computes the language model loss (Cross entropy across vocabulary) - - Args: - labels (Tensor): The labels of dimension [batch size, seq length] - logits (Tensor): The final logits returned by the output layer of the transformer model - - Returns: - Tensor: Loss tensor of dimensions [batch size, sequence_length] - """ - # [b s] => [s b] - labels = labels.transpose(0, 1).contiguous() - loss = tensor_parallel.vocab_parallel_cross_entropy(logits.float(), labels) - - # [s b] => [b, s] - loss = loss.transpose(0, 1).contiguous() - return loss - - def initialize_last_stage_with_word_embeddings(self) -> None: - """Intializes the word embeddings in the final stage. - - This function just initalizes word embeddings in the final stage, when we are - using pipeline parallelism and sharind word embeddings. Nothing to do if we - arn't sharing weights or aren't using Pipeline parallelism - """ - if not self.share_embeddings_and_output_weights or (self.pre_process and self.post_process): - return - - if self.post_process and not self.pre_process: - assert not parallel_state.is_pipeline_first_stage() - # set word_embeddings weights to 0 here, then copy first - # stage's weights using all_reduce below. - self.output_layer.weight.data.fill_(0) - self.output_layer.weight.shared = True - - # Parameters are shared between the word embeddings layers, and the - # heads at the end of the model. In a pipelined setup with more than - # one stage, the initial embedding layer and the head are on different - # workers, so we do the following: - # 1. Create a second copy of word_embeddings on the last stage, with - # initial parameters of 0.0. - # 2. 
Do an all-reduce between the first and last stage to ensure that - # the two copies of word_embeddings start off with the same - # parameter values. - # 3. In the training loop, before an all-reduce between the grads of - # the two word_embeddings layers to ensure that every applied weight - # update is the same on both stages. - - # Ensure that first and last stages have the same initial parameter - # values. - if torch.distributed.is_initialized(): - if parallel_state.is_rank_in_embedding_group(): - weight = self.shared_embedding_or_output_weight() - torch.distributed.all_reduce( - weight.data, group=parallel_state.get_embedding_group() - ) - - elif not getattr(LanguageModule, "embedding_warning_printed", False): - logging.getLogger(__name__).warning( - "Distributed processes aren't initialized, so the output layer " - "is not initialized with weights from the word embeddings. " - "If you are just manipulating a model this is fine, but " - "this needs to be handled manually. If you are training " - "something is definitely wrong." - ) - LanguageModule.embedding_warning_printed = True - - def shared_embedding_or_output_weight(self) -> Tensor: - """Gets the emedding weight or output logit weights when share embedding and output weights set to True. - - Returns: - Tensor: During pre processing it returns the input embeddings weight while during post processing it returns the final output layers weight - """ - if self.pre_process: - return self.embedding.word_embeddings.weight - elif self.post_process: - return self.output_layer.weight - return None diff --git a/megatron/core/models/gpt/__init__.py b/megatron/core/models/gpt/__init__.py deleted file mode 100644 index 2d5eb8674f1d19673664160d5eddf3432a6a5399..0000000000000000000000000000000000000000 --- a/megatron/core/models/gpt/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .gpt_model import GPTModel diff --git a/megatron/core/models/gpt/gpt_layer_specs.py b/megatron/core/models/gpt/gpt_layer_specs.py deleted file mode 100644 index aace1590d82456dc3b9d32d783bd01aad6514dd2..0000000000000000000000000000000000000000 --- a/megatron/core/models/gpt/gpt_layer_specs.py +++ /dev/null @@ -1,123 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 
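# --- Illustrative sketch (not part of the deleted files): when
# share_embeddings_and_output_weights is set, language_module.py above ties the
# input embedding matrix to the output projection (and keeps the pipeline-stage
# copies in sync with the all-reduce described in the comments). The tying itself,
# on a single device, reduces to reusing the embedding weight as the logit matrix:
import torch
import torch.nn.functional as F

vocab, hidden = 1000, 64
embedding = torch.nn.Embedding(vocab, hidden)
hidden_states = torch.randn(8, 2, hidden)            # [s, b, h]
logits = F.linear(hidden_states, embedding.weight)   # output layer reuses the embedding weight
assert logits.shape == (8, 2, vocab)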
- -from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add -from megatron.core.fusions.fused_layer_norm import FusedLayerNorm -from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear -from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules -from megatron.core.transformer.custom_layers.transformer_engine import ( - TEDotProductAttention, - TELayerNormColumnParallelLinear, - TERowParallelLinear, -) -from megatron.core.transformer.dot_product_attention import DotProductAttention -from megatron.core.transformer.enums import AttnMaskType -from megatron.core.transformer.mlp import MLP, MLPSubmodules -from megatron.core.transformer.spec_utils import ModuleSpec -from megatron.core.transformer.switch_mlp import SwitchMLP -from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules - - -# Use this spec to use lower level Transformer Engine modules (required for fp8 training) -def get_gpt_layer_with_transformer_engine_spec() -> ModuleSpec: - return ModuleSpec( - module=TransformerLayer, - submodules=TransformerLayerSubmodules( - self_attention=ModuleSpec( - module=SelfAttention, - params={"attn_mask_type": AttnMaskType.causal}, - submodules=SelfAttentionSubmodules( - linear_qkv=TELayerNormColumnParallelLinear, - core_attention=TEDotProductAttention, - linear_proj=TERowParallelLinear, - ), - ), - self_attn_bda=get_bias_dropout_add, - mlp=ModuleSpec( - module=MLP, - submodules=MLPSubmodules( - linear_fc1=TELayerNormColumnParallelLinear, linear_fc2=TERowParallelLinear, - ), - ), - mlp_bda=get_bias_dropout_add, - ), - ) - - -# Use this spec for an implementation using only modules in megatron core -def get_gpt_layer_local_spec() -> ModuleSpec: - return ModuleSpec( - module=TransformerLayer, - submodules=TransformerLayerSubmodules( - input_layernorm=FusedLayerNorm, - self_attention=ModuleSpec( - module=SelfAttention, - params={"attn_mask_type": AttnMaskType.causal}, - submodules=SelfAttentionSubmodules( - linear_qkv=ColumnParallelLinear, - core_attention=DotProductAttention, - linear_proj=RowParallelLinear, - ), - ), - self_attn_bda=get_bias_dropout_add, - pre_mlp_layernorm=FusedLayerNorm, - mlp=ModuleSpec( - module=MLP, - submodules=MLPSubmodules( - linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear, - ), - ), - mlp_bda=get_bias_dropout_add, - ), - ) - - -# Use this spec to use lower level Transformer Engine modules and SwitchMLP based MoE -gpt_layer_with_transformer_engine_spec_moe = ModuleSpec( - module=TransformerLayer, - submodules=TransformerLayerSubmodules( - self_attention=ModuleSpec( - module=SelfAttention, - params={"attn_mask_type": AttnMaskType.causal}, - submodules=SelfAttentionSubmodules( - linear_qkv=TELayerNormColumnParallelLinear, - core_attention=TEDotProductAttention, - linear_proj=TERowParallelLinear, - ), - ), - self_attn_bda=get_bias_dropout_add, - pre_mlp_layernorm=FusedLayerNorm, - mlp=ModuleSpec( - module=SwitchMLP, # MOE - submodules=MLPSubmodules( - linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear, - ), - ), - mlp_bda=get_bias_dropout_add, - ), -) - -# Use this spec for an implementation using only modules in megatron core for MoE models -gpt_layer_local_spec_moe = ModuleSpec( - module=TransformerLayer, - submodules=TransformerLayerSubmodules( - input_layernorm=FusedLayerNorm, - self_attention=ModuleSpec( - module=SelfAttention, - params={"attn_mask_type": AttnMaskType.causal}, - submodules=SelfAttentionSubmodules( - 
linear_qkv=ColumnParallelLinear, - core_attention=DotProductAttention, - linear_proj=RowParallelLinear, - ), - ), - self_attn_bda=get_bias_dropout_add, - pre_mlp_layernorm=FusedLayerNorm, - mlp=ModuleSpec( - module=SwitchMLP, # MOE - submodules=MLPSubmodules( - linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear, - ), - ), - mlp_bda=get_bias_dropout_add, - ), -) diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py deleted file mode 100644 index 2cf26bacacd21ba26bd8df0eefb2d52a0f53ea95..0000000000000000000000000000000000000000 --- a/megatron/core/models/gpt/gpt_model.py +++ /dev/null @@ -1,241 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - -import logging -from typing import Literal, Optional, Union - -import torch -from torch import Tensor - -from megatron.core import InferenceParams, parallel_state, tensor_parallel -from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding -from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding -from megatron.core.models.common.language_module.language_module import LanguageModule -from megatron.core.transformer.enums import AttnMaskType, ModelType -from megatron.core.transformer.spec_utils import ModuleSpec -from megatron.core.transformer.transformer_block import TransformerBlock -from megatron.core.transformer.transformer_config import TransformerConfig -from megatron.core.utils import make_tp_sharded_tensor_for_checkpoint - - -class GPTModel(LanguageModule): - """GPT Transformer language model. - - Args: - config (TransformerConfig): Transformer config - transformer_layer_spec (ModuleSpec): Specifies module to use for transformer layers - vocab_size (int): Vocabulary size - max_sequence_length (int): maximum size of sequence. This is used for positional embedding - pre_process (bool, optional): Include embedding layer (used with pipeline parallelism). Defaults to True. - post_process (bool, optional): Include an output layer (used with pipeline parallelism). Defaults to True. - fp16_lm_cross_entropy (bool, optional): Defaults to False. - parallel_output (bool, optional): Do not gather the outputs, keep them split across tensor parallel ranks. Defaults to True. - share_embeddings_and_output_weights (bool, optional): When True, input embeddings and output logit weights are shared. Defaults to False. - position_embedding_type (Literal[learned_absolute,rope], optional): Position embedding type.. Defaults to 'learned_absolute'. - rotary_percent (float, optional): Percent of rotary dimension to use for rotary position embeddings. Ignored unless position_embedding_type is 'rope'. Defaults to 1.0. - rotary_base (int, optional): Base period for rotary position embeddings. Ignored unless position_embedding_type is 'rope'. Defaults to 10000. - seq_len_interpolation_factor (Optional[float], optional): scale of linearly interpolating RoPE for longer sequences. The value must be a float larger than 1.0. Defaults to None. 
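# --- Illustrative sketch (not part of the deleted file): putting gpt_layer_specs.py
# and the GPTModel constructor documented just above together. This assumes an
# environment where megatron-core is importable and torch.distributed /
# megatron.core.parallel_state have already been initialized elsewhere; the sizes
# are arbitrary example values, and TransformerConfig's required base fields are
# assumed to be num_layers / hidden_size / num_attention_heads.
from megatron.core.models.gpt import GPTModel
from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec
from megatron.core.transformer.transformer_config import TransformerConfig

config = TransformerConfig(num_layers=12, hidden_size=768, num_attention_heads=12)
model = GPTModel(
    config=config,
    transformer_layer_spec=get_gpt_layer_local_spec(),
    vocab_size=32000,
    max_sequence_length=2048,
    share_embeddings_and_output_weights=True,
    position_embedding_type='rope',
)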
- """ - - def __init__( - self, - config: TransformerConfig, - transformer_layer_spec: ModuleSpec, - vocab_size: int, - max_sequence_length: int, - pre_process: bool = True, - post_process: bool = True, - fp16_lm_cross_entropy: bool = False, - parallel_output: bool = True, - share_embeddings_and_output_weights: bool = False, - position_embedding_type: Literal['learned_absolute', 'rope'] = 'learned_absolute', - rotary_percent: float = 1.0, - rotary_base: int = 10000, - seq_len_interpolation_factor: Optional[float] = None, - ) -> None: - super().__init__(config=config) - - self.transformer_layer_spec: ModuleSpec = transformer_layer_spec - self.vocab_size = vocab_size - self.max_sequence_length = max_sequence_length - self.pre_process = pre_process - self.post_process = post_process - self.fp16_lm_cross_entropy = fp16_lm_cross_entropy - self.parallel_output = parallel_output - self.share_embeddings_and_output_weights = share_embeddings_and_output_weights - self.position_embedding_type = position_embedding_type - - # megatron core pipelining currently depends on model type - # TODO: remove this dependency ? - self.model_type = ModelType.encoder_or_decoder - - if self.pre_process: - self.embedding = LanguageModelEmbedding( - config=self.config, - vocab_size=self.vocab_size, - max_sequence_length=self.max_sequence_length, - position_embedding_type=position_embedding_type, - ) - - if self.position_embedding_type == 'rope': - self.rotary_pos_emb = RotaryEmbedding( - kv_channels=self.config.kv_channels, - rotary_percent=rotary_percent, - seq_len_interpolation_factor=seq_len_interpolation_factor, - rotary_base=rotary_base, - ) - - # Transformer. - self.decoder = TransformerBlock( - config=self.config, - spec=transformer_layer_spec, - pre_process=self.pre_process, - post_process=self.post_process, - ) - - # Output - if post_process: - self.output_layer = tensor_parallel.ColumnParallelLinear( - config.hidden_size, - self.vocab_size, - config=config, - init_method=config.init_method, - bias=False, - skip_bias_add=False, - gather_output=not self.parallel_output, - skip_weight_param_allocation=self.pre_process - and self.share_embeddings_and_output_weights, - ) - - if self.share_embeddings_and_output_weights and (self.pre_process or self.post_process): - self.initialize_last_stage_with_word_embeddings() - - def set_input_tensor(self, input_tensor: Tensor) -> None: - """Sets input tensor to the model. - - See megatron.model.transformer.set_input_tensor() - - Args: - input_tensor (Tensor): Sets the input tensor for the model. - """ - # This is usually handled in schedules.py but some inference code still - # gives us non-lists or None - if not isinstance(input_tensor, list): - input_tensor = [input_tensor] - - assert len(input_tensor) == 1, 'input_tensor should only be length 1 for gpt/bert' - self.decoder.set_input_tensor(input_tensor[0]) - - def forward( - self, - input_ids: Tensor, - position_ids: Tensor, - attention_mask: Tensor, - decoder_input: Tensor = None, - labels: Tensor = None, - inference_params: InferenceParams = None, - extra_block_kwargs: dict = None, - ) -> Tensor: - """Forward function of the GPT Model This function passes the input tensors - through the embedding layer, and then the decoeder and finally into the post - processing layer (optional). - - It either returns the Loss values if labels are given or the final hidden units - """ - # If decoder_input is provided (not None), then input_ids and position_ids are ignored. 
- # Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input. - - # Decoder embedding. - if decoder_input is not None: - pass - elif self.pre_process: - decoder_input = self.embedding(input_ids=input_ids, position_ids=position_ids) - else: - # intermediate stage of pipeline - # decoder will get hidden_states from encoder.input_tensor - decoder_input = None - - # Rotary positional embeddings (embedding is None for PP intermediate devices) - rotary_pos_emb = None - if self.position_embedding_type == 'rope': - rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len( - inference_params, self.decoder, decoder_input, self.config - ) - rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len) - - # Run decoder. - hidden_states = self.decoder( - hidden_states=decoder_input, - attention_mask=attention_mask, - inference_params=inference_params, - rotary_pos_emb=rotary_pos_emb, - **(extra_block_kwargs or {}), - ) - - if not self.post_process: - return hidden_states - - # logits and loss - output_weight = None - if self.share_embeddings_and_output_weights: - output_weight = self.shared_embedding_or_output_weight() - logits, _ = self.output_layer(hidden_states, weight=output_weight) - - if labels is None: - # [s b h] => [b s h] - return logits.transpose(0, 1).contiguous() - - loss = self.compute_language_model_loss(labels, logits) - - return loss - - def sharded_state_dict(self, prefix: str = '') -> dict: - sharded_state_dict = {} - - if self.pre_process: - embedding_prefix = f'{prefix}embedding.' - embedding_sharded_state_dict = self.embedding.sharded_state_dict( - prefix=embedding_prefix - ) - sharded_state_dict.update(embedding_sharded_state_dict) - - decoder_prefix = f'{prefix}decoder.' - decoder_sharded_state_dict = self.decoder.sharded_state_dict(prefix=decoder_prefix) - sharded_state_dict.update(decoder_sharded_state_dict) - - if self.post_process: - output_layer_prefix = f'{prefix}output_layer.' 
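# --- Illustrative sketch (not part of the deleted file): the tail of
# GPTModel.forward above keeps activations sequence-first. Without labels it
# returns [b, s, vocab] logits; with labels it returns a [b, s] per-token loss.
# The real code uses vocab_parallel_cross_entropy; plain cross_entropy stands in
# for it here on a single device.
import torch
import torch.nn.functional as F


def toy_logits_or_loss(logits_sbv: torch.Tensor, labels_bs: torch.Tensor = None) -> torch.Tensor:
    if labels_bs is None:
        return logits_sbv.transpose(0, 1).contiguous()          # [s, b, v] -> [b, s, v]
    labels_sb = labels_bs.transpose(0, 1).contiguous()          # [b, s] -> [s, b]
    loss_sb = F.cross_entropy(
        logits_sbv.float().reshape(-1, logits_sbv.size(-1)),
        labels_sb.reshape(-1),
        reduction="none",
    ).view_as(labels_sb)
    return loss_sb.transpose(0, 1).contiguous()                  # [s, b] -> [b, s]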
- output_layer_key = f'{output_layer_prefix}weight' - if self.share_embeddings_and_output_weights: - if not self.pre_process: - # when sharing embeddings with last stage, we need to use the weights from the first stage - # on pipeline first rank, word embeddings are saved to {prefix}embedding.word_embeddings.weight - tensor = self.shared_embedding_or_output_weight() - first_stage_word_emb_key = f'{prefix}embedding.word_embeddings.weight' - last_stage_word_emb_replica_id = ( - 1, # copy of first stage embedding - 0, - parallel_state.get_data_parallel_rank(), - ) - - sharded_output_layer_tensor = make_tp_sharded_tensor_for_checkpoint( - tensor=tensor, - key=first_stage_word_emb_key, - replica_id=last_stage_word_emb_replica_id, - allow_shape_mismatch=True, - ) - - sharded_state_dict[output_layer_key] = sharded_output_layer_tensor - - else: - output_layer_state_dict = self.output_layer.state_dict( - prefix=output_layer_prefix, keep_vars=True - ) - output_layer_tensor = output_layer_state_dict[output_layer_key] - # independent output layer - sharded_output_layer_tensor = make_tp_sharded_tensor_for_checkpoint( - tensor=output_layer_tensor, key=output_layer_key, allow_shape_mismatch=True, - ) - - sharded_state_dict[output_layer_key] = sharded_output_layer_tensor - - return sharded_state_dict diff --git a/megatron/core/models/retro/__init__.py b/megatron/core/models/retro/__init__.py deleted file mode 100644 index c101fcb1e4cf51be9b2e2268597ed1b1f11a9319..0000000000000000000000000000000000000000 --- a/megatron/core/models/retro/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - -from .config import RetroConfig -from .decoder_spec import get_retro_decoder_block_spec -from .model import RetroModel diff --git a/megatron/core/models/retro/base_attention.py b/megatron/core/models/retro/base_attention.py deleted file mode 100644 index 4bafd48daf321e7db6e907ac520ebc92716c93a6..0000000000000000000000000000000000000000 --- a/megatron/core/models/retro/base_attention.py +++ /dev/null @@ -1,45 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - -from megatron.core.models.retro.config import RetroConfig -from megatron.core.transformer.attention import CrossAttention, CrossAttentionSubmodules -from megatron.core.transformer.enums import AttnMaskType -from megatron.core.transformer.module import MegatronModule - - -class BaseRetroCrossAttention(MegatronModule): - - """Base class for Retro cross attention, for both encoder & decoder layers. - - This class collects the retro arguments below (i.e., num neighbors, chunk - length, and retrieve length) for use in Retro's custom cross attention - operators. - - Arguments: - config (RetroConfig): Retro config. - - submodules (CrossAttentionSubmodules): Cross attention submodules. - - layer_number (int): Layer number within transformer block. - - attn_mask_type (AttnMaskType): Mask type ('causal' or 'padding'). 
- """ - - def __init__( - self, - config: RetroConfig, - submodules: CrossAttentionSubmodules, - layer_number: int = 1, - attn_mask_type: AttnMaskType = AttnMaskType.padding, - ): - super().__init__(config=config) - - self.attn = CrossAttention( - config=config, - submodules=submodules, - layer_number=layer_number, - attn_mask_type=attn_mask_type, - ) - - self.retro_num_neighbors = config.retro_num_neighbors - self.retro_chunk_length = config.retro_preprocess.retro_gpt_chunk_length - self.retro_retrieved_length = config.retro_preprocess.retro_gpt_retrieved_length diff --git a/megatron/core/models/retro/config.py b/megatron/core/models/retro/config.py deleted file mode 100644 index 2ffeb94bb386c3d394f17b3fd5e8bbf74d495474..0000000000000000000000000000000000000000 --- a/megatron/core/models/retro/config.py +++ /dev/null @@ -1,43 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - -import types -from dataclasses import dataclass - -from megatron.core.transformer import TransformerConfig - - -@dataclass -class RetroConfig(TransformerConfig): - - """Configuration object for Retro models. - - Attributes: - - retro_preprocess (SimpleNamespace): Retro preprocess arguments. - retro_workdir (str): Retro working directory, which contains the - preprocessed data for for pretraining. This directory is built during - preprocessing (see tools/retro/README.md), and contains subdirectories - for the chunk database and pretraining neighbors. - retro_encoder_layers (int): Number of layers to use for the retrieval - encoder. - retro_encoder_hidden_dropout (float): Hidden dropout for retrieval - encoder. - retro_encoder_attention_dropout (float): Attention dropout for retrieval - encoder. - retro_num_neighbors (int): Number of neighbors to retrieve during - pretraining. - retro_num_retrieved_chunks (int): Number of chunks to retrieve from the - retrieval database. - retro_verify_neighbor_count (bool): Verify that len(GPT dataset) == - len(saved neighbors). - """ - - # Retro. - retro_preprocess: types.SimpleNamespace = None - retro_workdir: str = None - retro_encoder_num_layers: int = 2 - retro_encoder_hidden_dropout: float = 0.1 - retro_encoder_attention_dropout: float = 0.1 - retro_num_neighbors: int = 2 - retro_num_retrieved_chunks: int = 2 - retro_verify_neighbor_count: bool = True diff --git a/megatron/core/models/retro/decoder_attention.py b/megatron/core/models/retro/decoder_attention.py deleted file mode 100644 index f934c6c717f370c682329484b47c39c4f0b71577..0000000000000000000000000000000000000000 --- a/megatron/core/models/retro/decoder_attention.py +++ /dev/null @@ -1,301 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 
- -"""Retro's cross attention modules for the decoder block.""" - -from functools import partial -from typing import Callable - -import numpy as np -import torch -from torch import Tensor - -from megatron.core import InferenceParams -from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add -from megatron.core.models.retro.base_attention import BaseRetroCrossAttention -from megatron.core.models.retro.config import RetroConfig -from megatron.core.transformer import ModuleSpec -from megatron.core.transformer.attention import CrossAttentionSubmodules -from megatron.core.transformer.enums import AttnMaskType -from megatron.core.transformer.module import MegatronModule -from megatron.core.transformer.transformer_block import TransformerBlock - - -class RetroDecoderCrossAttention(BaseRetroCrossAttention): - - """Retro decoder's chunked cross attention operator. - - See this paper for more details: https://arxiv.org/abs/2112.04426. - Neighboring chunks retrieved from the chunk database are used here for - chunked-cross attention. - - Arguments: - config (RetroConfig): Retro config. - - submodules (CrossAttentionSubmodules): Cross attention submodules. - - layer_number (int): Layer number within transformer block. - - attn_mask_type (AttnMaskType): Mask type ('causal' or 'padding'). - - encoder_block_spec (ModuleSpec): The first Retro decoder - layer is provided with a transformer block spec to construct the - neighbor encoder. - """ - - def __init__( - self, - config: RetroConfig, - submodules: CrossAttentionSubmodules, - layer_number: int = 1, - attn_mask_type: AttnMaskType = AttnMaskType.padding, - encoder_block_spec: ModuleSpec = None, - ): - """ - ** Note about 'encoder_block_spec' ** - - Retro is an encoder-decoder model that uses its encoder for encoding - neighboring chunks that are retrieved from a chunk database. These - encoded neighbors are then used in the decoder stack for performing - chunked-cross attention (see paper link above). - - In contrast to the T5 model, the encoder and decoder are computationally - intertwined, since the input to the encoder is the output of the self- - attention of the first decoder layer. As such, the encoder block itself - is instantiated within the first Retro decoder layer, in order to receive - the self-attention's output. (Note, that only the first decoder layer - instantiates an encoder block, and the remaining decoder layers use the - encoder output from the first decoder layer.) - """ - - super().__init__( - config=config, - submodules=submodules, - layer_number=layer_number, - attn_mask_type=attn_mask_type, - ) - - if encoder_block_spec: - self.encoder = TransformerBlock( - config=config, spec=encoder_block_spec, pre_process=True, post_process=False, - ) - # self._encoder_key = 'encoder' # ... necessary? - else: - self.encoder = None - - def forward( - self, - hidden_states: Tensor, - attention_mask: Tensor, - key_value_states: Tensor = None, - inference_params: InferenceParams = None, - # rotary_pos_emb: Tensor = None, # ... unsupported for retro. - ) -> Tensor: - """Cross attention for Retro decoder. - - Notation: - ns : Sequence length. - bs : Batch size. - d : Hidden size. - l : Number of chunks per sample (i.e., seq_length/chunk_length). - m : Number of tokens per chunk. - k : Number of neighbors. - r : Number of retrieved tokens (neighbors + continuation). - - Arguments: - hidden_states (Tensor): Transformer layer hidden states. - - attention_mask (Tensor): Attention mask. 
- - key_value_states (Tensor): Neighbor embeddings if first decoder - layer, else encoder output. - - inference_params (InferenceParams): Inference params. - """ - - # hidden_states: [ ns, bs, d ] - # key_value_states: [ r, k*bs*l, d ] - - ns, bs, d = hidden_states.shape - l = int(np.ceil(ns / self.retro_chunk_length)) - - # Retrieve neighbors. - if self.encoder: - - # Sequence length remainder. - first_ns = ns % self.retro_chunk_length - - # Case 1: Sequence length not divisible by chunk length. - if first_ns > 0: - - # Split sequence into first partial chunk & remaining chunks. - first_chunk, rest_chunk = hidden_states[:first_ns], hidden_states[first_ns:] - - # Pad partial chunk with zeros. - first_chunk = torch.nn.functional.pad( - first_chunk, (0, 0, 0, 0, 0, self.retro_chunk_length - first_ns), 'constant', 0, - ) - - # Concatenate padded chunk with remaining chunks. - chunked_output = torch.cat((first_chunk, rest_chunk), dim=0) # [ l*m, bs, d ] - - # Case 2: Sequence length is divisible by chunk length. - else: - chunked_output = hidden_states # [ l*m, bs, d ] - - # Chunk & permute hidden states. - # - hidden_states: [ l*m, bs, d ] - # - chunked_output: [ m, bs*l, d ] - chunked_output = ( - chunked_output.reshape(l, self.retro_chunk_length, bs, d) - .permute(1, 2, 0, 3) - .reshape(self.retro_chunk_length, bs * l, d) - .contiguous() - ) - - # Encode neighbors. (Note: 'key_value_states' re-assigned here.) - key_value_states = self.encoder( - hidden_states=key_value_states, - attention_mask=attention_mask, - context=chunked_output, - context_mask=None, - inference_params=inference_params, - ) # [ r, k*bs*l, d ] - key_value_states = key_value_states.reshape( - self.retro_retrieved_length * self.retro_num_neighbors, bs * l, d - ) # [ r*k, bs*l, d ] - - # Attend starting at last token of first chunk. - pad = (ns - 1) % self.retro_chunk_length - attending_chunks = hidden_states[pad:] - - # Pad attending tokens to sequence length. - padded_chunks = torch.nn.functional.pad( - attending_chunks, (0, 0, 0, 0, 0, self.retro_chunk_length - 1), 'constant', 0, - ) - - # Permute attending chunks. - # - padded_chunks: [ l*m, bs, d ] - # - padded_chunked_output: [ m, bs*l, d ] (matches 'chunked_output' above) - padded_chunked_output = padded_chunks.reshape(l, self.retro_chunk_length, bs, d).permute( - 1, 2, 0, 3 - ) - padded_chunked_output = padded_chunked_output.reshape( - self.retro_chunk_length, bs * l, d - ).contiguous() - - # Attend to encoded neighbors. - attention_output, attention_bias = self.attn( - padded_chunked_output, None, key_value_states=key_value_states, - ) - - # Return dimensions for bias-dropout step. - return { - "ns": ns, - "bs": bs, - "d": d, - "l": l, - "pad": pad, - "attention_output": attention_output, # [ m, bs*l, d ] - "attention_bias": attention_bias, # [ d ] - "context": key_value_states, # [ r*k, bs*l, d ] - } - - -class RetroDecoderBiasDropoutAdd(MegatronModule): - - """Retro decoder's bias-dropout-add operator. - - This operator takes care of reshaping and permuting the output from the - chunk dimension to the sequence dimension. - - Arguments: - config (RetroConfig): Retro config. - """ - - def __init__( - self, config: RetroConfig, - ): - super().__init__(config=config) - self.retro_chunk_length = config.retro_preprocess.retro_gpt_chunk_length - - @classmethod - def _forward( - cls, - x_with_bias: dict, - residual: Tensor, - prob: float, - retro_chunk_length: int, - bias_dropout_add: Callable, - ) -> Tensor: - """Per-chunk bias-dropout-add. 
- - Arguments: - x_with_bias (dict): Attention output and bias, along with other Retro - relevant parameters. - - residual (Tensor): Transformer layer residual. - - prob (float): Dropout probability. - - retro_chunk_length (int): Retro chunk length (e.g., 64). - - bias_dropout_add (Callable): Bias-dropout-add function. - """ - - # Extract input dict. - ns = x_with_bias["ns"] - bs = x_with_bias["bs"] - d = x_with_bias["d"] - l = x_with_bias["l"] - pad = x_with_bias["pad"] - attention_output = x_with_bias["attention_output"] # [ m, bs*l, d ] - attention_bias = x_with_bias["attention_bias"] # [ d ] - - # Re-enable torch grad to enable fused optimization. - with torch.enable_grad(): - - # Bias-dropout-add. - x = bias_dropout_add( - ( - attention_output, - None if attention_bias is None else attention_bias.expand_as(attention_output), - ), - torch.zeros_like(attention_output), - prob, - ) - - # Permute chunks back to sequence dimension. - # 1. [ m, bs*l, d ] - # 2. [ m, bs, l, d ] - # 3. [ l, m, bs, d ] - # 4. [ m*l, bs, d ] == [ ns, bs, d ] - x = ( - x.reshape(retro_chunk_length, bs, l, d) - .permute(2, 0, 1, 3) - .reshape(retro_chunk_length * l, bs, d) - ) - - # Prepend zeros for non-attending tokens. - x = torch.nn.functional.pad(x, (0, 0, 0, 0, pad, 0), 'constant', 0,)[ - :ns - ] # [ ns, bs, d ] - - # Add residual. [ ns, bs, d ] - x = x + residual - - # Output. [ ns, bs, d ] - return x - - def forward(self, training: bool, fused: bool) -> Tensor: - """Retro decoder bias-dropout-add. - - Arguments: - training (bool): If training, then apply dropout. - - fused (bool): Fuse bias-dropout-add. - """ - return partial( - self._forward, - retro_chunk_length=self.retro_chunk_length, - bias_dropout_add=get_bias_dropout_add(training, fused), - ) diff --git a/megatron/core/models/retro/decoder_spec.py b/megatron/core/models/retro/decoder_spec.py deleted file mode 100644 index d23e4981e004c3de26f5f708e90597177211e1db..0000000000000000000000000000000000000000 --- a/megatron/core/models/retro/decoder_spec.py +++ /dev/null @@ -1,152 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - -from megatron.core import parallel_state -from megatron.core.fusions.fused_layer_norm import FusedLayerNorm -from megatron.core.models.gpt.gpt_layer_specs import ( - get_gpt_layer_local_spec, - get_gpt_layer_with_transformer_engine_spec, -) -from megatron.core.models.retro.config import RetroConfig -from megatron.core.models.retro.decoder_attention import ( - RetroDecoderBiasDropoutAdd, - RetroDecoderCrossAttention, -) -from megatron.core.models.retro.encoder_spec import get_retro_encoder_block_spec -from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear -from megatron.core.transformer import ModuleSpec -from megatron.core.transformer.attention import CrossAttentionSubmodules -from megatron.core.transformer.custom_layers.transformer_engine import ( - TEColumnParallelLinear, - TEDotProductAttention, - TENorm, - TERowParallelLinear, -) -from megatron.core.transformer.dot_product_attention import DotProductAttention -from megatron.core.transformer.transformer_block import ( - TransformerBlockSubmodules, - get_num_layers_to_build, -) - - -def get_retro_decoder_layer_te_spec(encoder_block_spec: ModuleSpec = None) -> ModuleSpec: - """Retro decoder TE spec (uses Transformer Engine components). - - A Retro decoder layer uses custom attention and bias-dropout-add operators - to perform chunked-cross attention. 
Additionally, the first Retro decoder - layer instantiates an entire encoder transformer block. As such, the decoder - cross attention module takes an optional encoder block spec, which is only - provided for the first Retro decoder layer. - - Arguments: - encoder_block_spec (ModuleSpec): Retro encoder block spec, to be provided - for the first Retro decoder layer. - """ - spec = get_gpt_layer_with_transformer_engine_spec() - spec.submodules.pre_cross_attn_layernorm = TENorm - spec.submodules.cross_attention = ModuleSpec( - module=RetroDecoderCrossAttention, - params={"encoder_block_spec": encoder_block_spec,}, - submodules=CrossAttentionSubmodules( - linear_q=TEColumnParallelLinear, - linear_kv=TEColumnParallelLinear, - core_attention=TEDotProductAttention, - linear_proj=TERowParallelLinear, - ), - ) - spec.submodules.cross_attn_bda = ModuleSpec(module=RetroDecoderBiasDropoutAdd) - return spec - - -def get_retro_decoder_layer_local_spec(encoder_block_spec: ModuleSpec = None) -> ModuleSpec: - """Retro decoder local spec (uses Megatron-Core components). - - A Retro decoder layer uses custom attention and bias-dropout-add operators - to perform chunked-cross attention. Additionally, the first Retro decoder - layer instantiates an entire encoder transformer block. As such, the decoder - cross attention module takes an optional encoder block spec, which is only - provided for the first Retro decoder layer. - - Arguments: - encoder_block_spec (ModuleSpec): Retro encoder block spec, to be provided - for the first Retro decoder layer. - """ - spec = get_gpt_layer_local_spec() - spec.submodules.pre_cross_attn_layernorm = FusedLayerNorm - spec.submodules.cross_attention = ModuleSpec( - module=RetroDecoderCrossAttention, - params={"encoder_block_spec": encoder_block_spec,}, - submodules=CrossAttentionSubmodules( - linear_q=ColumnParallelLinear, - linear_kv=ColumnParallelLinear, - core_attention=DotProductAttention, - linear_proj=RowParallelLinear, - ), - ) - spec.submodules.cross_attn_bda = ModuleSpec(module=RetroDecoderBiasDropoutAdd) - return spec - - -def get_retro_decoder_block_spec( - config: RetroConfig, use_transformer_engine: bool -) -> TransformerBlockSubmodules: - - """Retro decoder block spec. - - Retro decoder block implementation details: - - The retro decoder block consists of interleaved GPT layers and customized - Retro decoder layers. - - The Retro decoder layers are spaced three layers apart, and start on layer - 6 or 9 (depending on the total number of layers). - - The first decoder layer instantiates an encoder block, and it therefore - passes in an encoder_block_spec. - - - Arguments: - config (RetroConfig): Retro config. - - use_transformer_engine (bool): If True, use Transformer Engine (instead - of local modules. - """ - - # Num layers. - assert ( - parallel_state.get_pipeline_model_parallel_world_size() == 1 - ), "retro does not currently support pipeline parallelism." - assert ( - parallel_state.get_virtual_pipeline_model_parallel_world_size() is None - ), "retro does not currently support virtual pipeline parallelism." - num_layers = get_num_layers_to_build(config) - - # Retro layer numbers. - retro_layer_start = 6 if num_layers <= 15 else 9 - retro_layer_numbers = list(range(retro_layer_start, num_layers + 1, 3)) - - # Layer specs. 
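    # Illustrative sketch of the numbering above (the depths are assumed values,
    # simply mirroring the retro_layer_start / retro_layer_numbers logic):
    #
    #   list(range(6, 12 + 1, 3)) == [6, 9, 12]                 # num_layers = 12
    #   list(range(9, 24 + 1, 3)) == [9, 12, 15, 18, 21, 24]    # num_layers = 24
    #
    # The first Retro layer (6 or 9 here) is the one that also receives the
    # encoder block spec below; all remaining layers use the plain GPT layer spec.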
- gpt_layer_spec = ( - get_gpt_layer_with_transformer_engine_spec() - if use_transformer_engine - else get_gpt_layer_local_spec() - ) - get_retro_decoder_layer_spec = ( - get_retro_decoder_layer_te_spec - if use_transformer_engine - else get_retro_decoder_layer_local_spec - ) - retro_layer_spec = get_retro_decoder_layer_spec() - retro_layer_spec_with_retriever = get_retro_decoder_layer_spec( - get_retro_encoder_block_spec(config, use_transformer_engine) - ) - - layer_specs = [] - for layer_number in range(1, num_layers + 1): - if layer_number == retro_layer_numbers[0]: - layer_specs.append(retro_layer_spec_with_retriever) - elif layer_number in retro_layer_numbers: - layer_specs.append(retro_layer_spec) - else: - layer_specs.append(gpt_layer_spec) - - # Block spec. - block_spec = TransformerBlockSubmodules(layer_specs=layer_specs) - - return block_spec diff --git a/megatron/core/models/retro/encoder_attention.py b/megatron/core/models/retro/encoder_attention.py deleted file mode 100644 index 5840e3e3017150956b0997bb6d697e1eda85d5ff..0000000000000000000000000000000000000000 --- a/megatron/core/models/retro/encoder_attention.py +++ /dev/null @@ -1,223 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - -"""Retro's cross attention modules for the encoder block.""" - -from functools import partial -from typing import Callable, Optional, Tuple, Type - -import torch -from torch import Tensor - -from megatron.core import InferenceParams -from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add -from megatron.core.models.retro.base_attention import BaseRetroCrossAttention -from megatron.core.models.retro.config import RetroConfig -from megatron.core.transformer.module import MegatronModule - - -class RetroEncoderCrossAttention(BaseRetroCrossAttention): - - """Retro encoder's cross attention operator. - - See this paper for more details: https://arxiv.org/abs/2112.04426. - Neighboring chunks are retrieved from the chunk database, encoded, and - used by the decoder layers for chunked cross attention. - - Arguments: - config (RetroConfig): Retro config. - - submodules (CrossAttentionSubmodules): Cross attention submodules. - - layer_number (int): Layer number within transformer block. - - attn_mask_type (AttnMaskType): Mask type ('causal' or 'padding'). - """ - - def forward( - self, - hidden_states: Tensor, - attention_mask: Tensor, - key_value_states: Tensor = None, - inference_params: InferenceParams = None, - # rotary_pos_emb: Tensor = None, # unsupported for retro. - ) -> Tensor: - """Cross attention for Retro encoder. - - Notation: - ns : Sequence length. - bs : Batch size. - d : Hidden size. - l : Number of chunks per sample (i.e., seq_length/chunk_length). - k : Number of neighbors. - r : Number of retrieved tokens (neighbors + continuation). - - Arguments: - hidden_states (Tensor): Transformer layer hidden states. - - attention_mask (Tensor): Attention mask. - - key_value_states (Tensor): Neighbor embeddings. - - inference_params (InferenceParams): Inference params. - """ - - # Input shape. [ r, bs*l*k, d ] - ns, bs, d = hidden_states.shape - - # Reshape sequence into neighboring chunks. - # - hidden_states: [ r, bs*l*k, d ] - # - chunked_outputs: [ r, bs*l, k, d ] - chunked_outputs = hidden_states.reshape( - self.retro_retrieved_length, -1, self.retro_num_neighbors, d - ) - - # Per-chunk attention. - attention_output_tuples = [] - for k in range(self.retro_num_neighbors): - - # Attend to current neighboring chunks. 
- # - chunked_output: [ r, bs*l, d ] - # - key_value_states: [ m, bs*l, d ] - # - attention_output: [ r, bs*l, d ] - # - attention_bias: [ d ] - chunked_output = chunked_outputs[:, :, k].contiguous() - attention_output, attention_bias = self.attn( - hidden_states=chunked_output, # Q (neighbor embedding) - attention_mask=None, - key_value_states=key_value_states, # K, V (hidden act) - ) - - # Residual connection. [ r, bs*l, d ] - residual = chunked_output - - # Collect tensors. - attention_output_tuples.append((attention_output, attention_bias, residual,)) - - # Output. (List[Tuple[( [ r, bs*l, d ], [ d ] )]]) - return attention_output_tuples - - -class RetroEncoderBiasDropoutAdd(MegatronModule): - - """Retro encoder's bias-dropout-add operator. - - This operator applies bias-dropout-add individually on each neighboring - chunk that is retrieved from the chunk database. - - Arguments: - config (RetroConfig): Retro config. - """ - - def __init__( - self, config: RetroConfig, - ): - super().__init__(config=config) - self.retro_num_neighbors = config.retro_num_neighbors - - @classmethod - def _forward( - cls, - x_with_bias: Tuple[Tensor, Optional[Tensor]], - residual: Tensor, - prob: float, - retro_num_neighbors: int, - bias_dropout_add: Callable, - ) -> Tensor: - """Per-chunk bias-dropout-add. - - Arguments: - x_with_bias (dict): Attention output and bias tuple. - - residual (Tensor): Transformer layer residual. - - prob (float): Dropout probability. - - retro_num_neighbors (int): Number of retrieved neighbor chunks (e.g., 2). - - bias_dropout_add (Callable): Bias-dropout-add function. - """ - - # Re-enable torch grad to enable fused optimization. - with torch.enable_grad(): - - # Per-neighbor bias-dropout-add. - # - attention_output: [ r, bs*l, d ] - # - attention_bias: [ d ] - # - residual: [ r, bs*l, d ] - # - output: [ r, bs*l, d ] - outputs = [ - bias_dropout_add( - ( - attention_output, - None if attention_bias is None else attention_bias.expand_as(residual), - ), - residual, - prob, - ) - for attention_output, attention_bias, residual in x_with_bias - ] - - # Concatenate outputs (to shape [r, k*bs*l, d]; see notation above). - r, _, d = outputs[0].shape - output = torch.stack(outputs, dim=1).reshape(r, -1, d) - - # Output. [ r, k*bs*l, d ] - return output - - def forward(self, training: bool, fused: bool) -> Tensor: - """Retro decoder bias-dropout-add. - - Arguments: - training (bool): If training, then apply dropout. - - fused (bool): Fuse bias-dropout-add. - """ - return partial( - self._forward, - retro_num_neighbors=self.retro_num_neighbors, - bias_dropout_add=get_bias_dropout_add(training, fused), - ) - - -class RetroEncoderLayerNorm(MegatronModule): - - """Retro encoder's layernorm operator. - - This operator applies layernorm individually on each neighboring chunk that - is retrieved from the chunk database, and then concatenates the chunks into - a single tensor. - - Arguments: - config (RetroConfig): Retro config. - """ - - def __init__( - self, config: RetroConfig, submodules: Type, **kwargs, - ): - super().__init__(config=config) - norm_class = submodules - self.norm = norm_class(config=config, **kwargs) - self.retro_num_neighbors = config.retro_num_neighbors - - def forward(self, input: Tensor) -> Tensor: - """Per-chunk layer norm. - - Arguments: - input (Tensor): Input chunks, concatenated into a single tensor. - """ - - # Input shape: [ r, k*bs*l, d ]. (see notation above in attention module) - - # Split input into 'num_neighbors' tensors. 
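        # Illustrative sketch of the shape plumbing below (toy sizes are
        # assumptions, not values from the original code): splitting
        # [ r, k*bs*l, d ] along dim=1 yields k tensors of shape [ r, bs*l, d ],
        # and stacking them on dim=1 followed by a reshape restores the
        # original [ r, k*bs*l, d ] layout.
        #
        #   import torch
        #   r, k, bsl, d = 4, 2, 3, 8
        #   x = torch.randn(r, k * bsl, d)
        #   chunks = torch.split(x, x.shape[1] // k, dim=1)   # k x [ r, bs*l, d ]
        #   assert torch.equal(torch.stack(chunks, dim=1).reshape(r, -1, d), x)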
- chunk_size = input.shape[1] // self.retro_num_neighbors - inputs = torch.split(input, chunk_size, dim=1) - - # Norm. - outputs = [self.norm(inp.contiguous()) for inp in inputs] - - # Concatenate layer norms (to shape [r, k*bs*l, d]; see notation above). - r, _, d = inputs[0].shape - output = torch.stack(outputs, dim=1).reshape(r, -1, d) - - # Output. [ r, k*bs*l, d ] - return output diff --git a/megatron/core/models/retro/encoder_spec.py b/megatron/core/models/retro/encoder_spec.py deleted file mode 100644 index 63efadedd884f075e35c52b00f216013ede2ab21..0000000000000000000000000000000000000000 --- a/megatron/core/models/retro/encoder_spec.py +++ /dev/null @@ -1,141 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - -from megatron.core.fusions.fused_layer_norm import FusedLayerNorm -from megatron.core.models.gpt.gpt_layer_specs import ( - get_gpt_layer_local_spec, - get_gpt_layer_with_transformer_engine_spec, -) -from megatron.core.models.retro.config import RetroConfig -from megatron.core.models.retro.encoder_attention import ( - RetroEncoderBiasDropoutAdd, - RetroEncoderCrossAttention, - RetroEncoderLayerNorm, -) -from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear -from megatron.core.transformer import ModuleSpec -from megatron.core.transformer.attention import CrossAttentionSubmodules -from megatron.core.transformer.custom_layers.transformer_engine import ( - TEColumnParallelLinear, - TEDotProductAttention, - TENorm, - TERowParallelLinear, -) -from megatron.core.transformer.dot_product_attention import DotProductAttention -from megatron.core.transformer.enums import AttnMaskType -from megatron.core.transformer.mlp import MLP, MLPSubmodules -from megatron.core.transformer.transformer_block import TransformerBlockSubmodules - - -def get_retro_encoder_layer_te_spec() -> ModuleSpec: - """Retro encoder TE spec (uses Transformer Engine components). - - A Retro encoder layer uses custom attention, bias-dropout-add, and layernorm - operators to encode neighboring chunks that are retrieved from the chunk - database. Each operator is responsible for iterating the retrieved chunks - and processing them individually. - """ - spec = get_gpt_layer_with_transformer_engine_spec() - spec.submodules.pre_cross_attn_layernorm = TENorm - spec.submodules.cross_attention = ModuleSpec( - module=RetroEncoderCrossAttention, - params={"attn_mask_type": AttnMaskType.padding,}, - submodules=CrossAttentionSubmodules( - linear_q=TEColumnParallelLinear, - linear_kv=TEColumnParallelLinear, - core_attention=TEDotProductAttention, - linear_proj=TERowParallelLinear, - ), - ) - spec.submodules.cross_attn_bda = ModuleSpec(module=RetroEncoderBiasDropoutAdd) - spec.submodules.pre_mlp_layernorm = ModuleSpec(module=RetroEncoderLayerNorm, submodules=TENorm,) - spec.submodules.mlp = ModuleSpec( - module=MLP, - submodules=MLPSubmodules( - linear_fc1=TEColumnParallelLinear, linear_fc2=TERowParallelLinear, - ), - ) - return spec - - -def get_retro_encoder_layer_local_spec() -> ModuleSpec: - """Retro encoder local spec (uses Megatron-Core components). - - A Retro encoder layer uses custom attention, bias-dropout-add, and layernorm - operators to encode neighboring chunks that are retrieved from the chunk - database. Each operator is responsible for iterating the retrieved chunks - and processing them individually. 
- """ - spec = get_gpt_layer_local_spec() - spec.submodules.pre_cross_attn_layernorm = FusedLayerNorm - spec.submodules.cross_attention = ModuleSpec( - module=RetroEncoderCrossAttention, - params={"attn_mask_type": AttnMaskType.padding,}, - submodules=CrossAttentionSubmodules( - linear_q=ColumnParallelLinear, - linear_kv=ColumnParallelLinear, - core_attention=DotProductAttention, - linear_proj=RowParallelLinear, - ), - ) - spec.submodules.cross_attn_bda = ModuleSpec(module=RetroEncoderBiasDropoutAdd) - spec.submodules.pre_mlp_layernorm = ModuleSpec( - module=RetroEncoderLayerNorm, submodules=FusedLayerNorm, - ) - spec.submodules.mlp = ModuleSpec( - module=MLP, - submodules=MLPSubmodules(linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear,), - ) - return spec - - -def get_retro_encoder_block_spec( - config: RetroConfig, use_transformer_engine: bool -) -> TransformerBlockSubmodules: - - """Retro encoder block spec. - - The retro encoder block consists of one customized Retro encoder layer - (layer 1), and all of the following layers are standard GPT layers. - - Arguments: - config (RetroConfig): Retro config. - - use_transformer_engine (bool): If True, use Transformer Engine (instead - of local modules. - """ - - # Num layers. - num_layers = config.retro_encoder_num_layers - retro_layer_numbers = [1] - - # Layer specs. - gpt_layer_spec = ( - get_gpt_layer_with_transformer_engine_spec() - if use_transformer_engine - else get_gpt_layer_local_spec() - ) - get_retro_encoder_layer_spec = ( - get_retro_encoder_layer_te_spec - if use_transformer_engine - else get_retro_encoder_layer_local_spec - ) - retro_layer_spec = get_retro_encoder_layer_spec() - for spec in (gpt_layer_spec, retro_layer_spec): - spec.params["hidden_dropout"] = config.retro_encoder_hidden_dropout - spec.submodules.self_attention.params["attn_mask_type"] = AttnMaskType.padding - spec.submodules.self_attention.submodules.core_attention = ModuleSpec( - module=TEDotProductAttention if use_transformer_engine else DotProductAttention, - params={"attention_dropout": config.retro_encoder_attention_dropout,}, - ) - - layer_specs = [] - for layer_number in range(1, num_layers + 1): - if layer_number in retro_layer_numbers: - layer_specs.append(retro_layer_spec) - else: - layer_specs.append(gpt_layer_spec) - - # Block spec. - block_spec = TransformerBlockSubmodules(layer_specs=layer_specs) - - return block_spec diff --git a/megatron/core/models/retro/model.py b/megatron/core/models/retro/model.py deleted file mode 100644 index d47c08fb52788b23ba4c3301bd24401d42e6d2a6..0000000000000000000000000000000000000000 --- a/megatron/core/models/retro/model.py +++ /dev/null @@ -1,89 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - -"""Retro Model.""" - -from torch import Tensor - -from megatron.core import InferenceParams -from megatron.core.models.gpt import GPTModel - - -class RetroModel(GPTModel): - - """Retro Model. - - A Retro model mostly re-uses the GPTModel interface, with the only difference - being the embedding of the 'context' this is used by Retro for processing - neighbor tokens. This embedded context is then forwarded to the Transformer - Block. - """ - - def forward( - self, - input_ids: Tensor, - position_ids: Tensor, - attention_mask: Tensor, - context_input_ids: Tensor = None, - context_position_ids: Tensor = None, - context_mask: Tensor = None, - decoder_input: Tensor = None, - labels: Tensor = None, - inference_params: InferenceParams = None, - ) -> Tensor: - """RetroModel forward method. 
-
-        Forward input tokens & mask, along with neighbor tokens & mask, through
-        the Retro model.
-
-        Arguments:
-          input_ids (Tensor): Input token IDs.
-
-          position_ids (Tensor): Input position IDs.
-
-          attention_mask (Tensor): Input attention mask.
-
-          context_input_ids (Tensor): Context (i.e., neighbor) token IDs.
-
-          context_position_ids (Tensor): Context (i.e., neighbor) position IDs.
-
-          context_mask (Tensor): Context (i.e., neighbor) attention mask.
-
-          decoder_input (Tensor): When using pipeline parallelism, input_ids and
-          position_ids will only be used on the first stage, and for all other
-          stages decoder_input will be provided via communication from the
-          previous stage.
-
-          labels (Tensor): The labels of dimension [batch size, seq length].
-
-          inference_params (InferenceParams): Parameters for inference.
-        """
-
-        # Argument shapes:
-        #   Notation:
-        #     ns : Sequence length.
-        #     bs : Batch size.
-        #     d  : Hidden size.
-        #     l  : Number of chunks per sample (i.e., seq_length/chunk_length).
-        #     k  : Number of neighbors.
-        #     r  : Number of retrieved tokens (neighbors + continuation).
-        #   - input_ids:   [ bs, ns ]
-        #   - context_ids: [ k*bs*l, r ]
-        #   - context:     [ r, k*bs*l, d ]
-        #   - output:      [ ns, bs, d ]
-
-        # Context embedding (e.g., for Retro neighbor tokens).
-        if context_input_ids is not None:
-            context = self.embedding(context_input_ids, context_position_ids)
-        else:
-            context = None
-
-        # Call GPTModel.forward, and pass in embedded context.
-        return super().forward(
-            input_ids=input_ids,
-            position_ids=position_ids,
-            attention_mask=attention_mask,
-            decoder_input=decoder_input,
-            labels=labels,
-            inference_params=inference_params,
-            extra_block_kwargs={"context": context, "context_mask": context_mask,},
-        )
diff --git a/megatron/core/package_info.py b/megatron/core/package_info.py
deleted file mode 100644
index 55c49b1785b0cdfa94854ca73431d963150b8c7f..0000000000000000000000000000000000000000
--- a/megatron/core/package_info.py
+++ /dev/null
@@ -1,29 +0,0 @@
-# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
-
-
-MAJOR = 0
-MINOR = 4
-PATCH = 0
-PRE_RELEASE = 'rc0'
-
-# Use the following formatting: (major, minor, patch, pre-release)
-VERSION = (MAJOR, MINOR, PATCH, PRE_RELEASE)
-
-__shortversion__ = '.'.join(map(str, VERSION[:3]))
-__version__ = '.'.join(map(str, VERSION[:3])) + ''.join(VERSION[3:])
-
-__package_name__ = 'megatron_core'
-__contact_names__ = 'NVIDIA'
-__contact_emails__ = 'nemo-toolkit@nvidia.com'  # use NeMo Email
-__homepage__ = (
-    'https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/'  # use NeMo homepage
-)
-__repository_url__ = 'https://github.com/NVIDIA/Megatron-LM/megatron/core'
-__download_url__ = 'https://github.com/NVIDIA/Megatron-LM/releases'
-__description__ = (
-    'Megatron Core - a library for efficient and scalable training of transformer based models'
-)
-__license__ = 'BSD-3'
-__keywords__ = (
-    'deep learning, machine learning, gpu, NLP, NLU, language, transformer, nvidia, pytorch, torch'
-)
diff --git a/megatron/core/parallel_state.py b/megatron/core/parallel_state.py
deleted file mode 100644
index 5652b208468537cf2b0e97283d8c719995e1112d..0000000000000000000000000000000000000000
--- a/megatron/core/parallel_state.py
+++ /dev/null
@@ -1,980 +0,0 @@
-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-
-"""Model and data parallel groups."""
-
-import os
-from typing import Optional
-
-import torch
-
-from .utils import GlobalMemoryBuffer
-
-# Intra-layer model parallel group that the current rank belongs to.
-_TENSOR_MODEL_PARALLEL_GROUP = None -# Inter-layer model parallel group that the current rank belongs to. -_PIPELINE_MODEL_PARALLEL_GROUP = None -# Model parallel group (both intra- and pipeline) that the current rank belongs to. -_MODEL_PARALLEL_GROUP = None -# Embedding group. -_EMBEDDING_GROUP = None -# Position embedding group. -_POSITION_EMBEDDING_GROUP = None -# Data parallel group that the current rank belongs to. -_DATA_PARALLEL_GROUP = None -_DATA_PARALLEL_GROUP_GLOO = None -# tensor model parallel group and data parallel group combined -# used for fp8 and moe training -_TENSOR_AND_DATA_PARALLEL_GROUP = None -# Expert parallel group that the current rank belongs to. -_TENSOR_AND_EXPERT_PARALLEL_GROUP = None -_DATA_MODULO_EXPERT_PARALLEL_GROUP = None - - -_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None -_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None -_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = None - -# These values enable us to change the mpu sizes on the fly. -_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None -_MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None -_MPU_TENSOR_MODEL_PARALLEL_RANK = None -_MPU_PIPELINE_MODEL_PARALLEL_RANK = None - -# A list of ranks that have a copy of the embedding. -_EMBEDDING_GLOBAL_RANKS = None - -# A list of ranks that have a copy of the position embedding. -_POSITION_EMBEDDING_GLOBAL_RANKS = None - -# A list of global ranks for each pipeline group to ease calculation of the source -# rank when broadcasting from the first or last pipeline stage. -_PIPELINE_GLOBAL_RANKS = None - -# A list of global ranks for each data parallel group to ease calculation of the source -# rank when broadcasting weights from src to all other data parallel ranks -_DATA_PARALLEL_GLOBAL_RANKS = None - -# Context parallel group that the current rank belongs to -_CONTEXT_PARALLEL_GROUP = None -# A list of global ranks for each context parallel group to ease calculation of the -# destination rank when exchanging KV/dKV between context parallel_ranks -_CONTEXT_PARALLEL_GLOBAL_RANKS = None - -# Data parallel group information with context parallel combined. -_DATA_PARALLEL_GROUP_WITH_CP = None -_DATA_PARALLEL_GROUP_WITH_CP_GLOO = None -_DATA_PARALLEL_GLOBAL_RANKS_WITH_CP = None - -# combined parallel group of TP, DP, and CP used for fp8 -_TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP = None - -# Memory buffers to avoid dynamic memory allocation -_GLOBAL_MEMORY_BUFFER = None - - -def get_nccl_options(pg_name, nccl_comm_cfgs): - """Set the NCCL process group options. - - Arguments: - pg_name (str): process group name - nccl_comm_cfgs (dict): nccl communicator configurations - - When an option (e.g., max_ctas) is not found in the config, use the NCCL default setting. - """ - if pg_name in nccl_comm_cfgs: - nccl_options = torch.distributed.ProcessGroupNCCL.Options() - nccl_options.config.cga_cluster_size = nccl_comm_cfgs[pg_name].get('cga_cluster_size', 4) - nccl_options.config.max_ctas = nccl_comm_cfgs[pg_name].get('max_ctas', 32) - nccl_options.config.min_ctas = nccl_comm_cfgs[pg_name].get('min_ctas', 1) - return nccl_options - else: - return None - - -def initialize_model_parallel( - tensor_model_parallel_size: int = 1, - pipeline_model_parallel_size: int = 1, - virtual_pipeline_model_parallel_size: Optional[int] = None, - pipeline_model_parallel_split_rank: Optional[int] = None, - use_sharp: bool = False, - context_parallel_size: int = 1, - expert_model_parallel_size: int = 1, - nccl_communicator_config_path: Optional[str] = None, -) -> None: - """Initialize model data parallel groups. 
- - Arguments: - tensor_model_parallel_size (int, default = 1): - The number of GPUs to split individual tensors across. - - pipeline_model_parallel_size (int, default = 1): - The number of tensor parallel GPU groups to split the - Transformer layers across. For example, if - tensor_model_parallel_size is 4 and - pipeline_model_parallel_size is 2, the model will be split - into 2 groups of 4 GPUs. - - virtual_pipeline_model_parallel_size (int, optional): - The number of stages that each pipeline group will have, - interleaving as necessary. If None, no interleaving is - performed. For example, if tensor_model_parallel_size is 1, - pipeline_model_parallel_size is 4, - virtual_pipeline_model_parallel_size is 2, and there are - 16 transformer layers in the model, the model will be - split into 8 stages with two layers each and each GPU - would get 2 stages as such (layer number starting with 1): - - GPU 0: [1, 2] [9, 10] - GPU 1: [3, 4] [11, 12] - GPU 2: [5, 6] [13, 14] - GPU 3: [7, 8] [15, 16] - - pipeline_model_parallel_split_rank (int, optional): - For models with both an encoder and decoder, the rank in - pipeline to switch between encoder and decoder (i.e. the - first rank of the decoder). This allows the user to set - the pipeline parallel size of the encoder and decoder - independently. For example, if - pipeline_model_parallel_size is 8 and - pipeline_model_parallel_split_rank is 3, then ranks 0-2 - will be the encoder and ranks 3-7 will be the decoder. - - use_sharp (bool, default = False): - Set the use of SHARP for the collective communications of - data-parallel process groups. When `True`, run barrier - within each data-parallel process group, which specifies - the SHARP application target groups. - - context_parallel_size (int, default = 1): - The number of tensor parallel GPU groups to split the - network input sequence length across. Compute of attention - module requires tokens of full sequence length, so GPUs - in a context parallel group need to communicate with each - other to exchange information of other sequence chunks. - Each GPU and its counterparts in other tensor parallel - groups compose a context parallel group. - - For example, assume we have 8 GPUs, if tensor model parallel - size is 4 and context parallel size is 2, the network input - will be split into two sequence chunks, which are processed - by 2 different groups of 4 GPUs. One chunk is processed by - GPU0-3, the other chunk is processed by GPU4-7. Four groups - are build to do context parallel communications: [GPU0, GPU4], - [GPU1, GPU5], [GPU2, GPU6], and [GPU3, GPU7]. - - Context parallelism partitions sequence length, so it has no - impact on weights, which means weights are duplicated among - GPUs in a context parallel group. Hence, weight gradients - all-reduce is required in backward. For simplicity, we piggyback - GPUs of context parallelism on data parallel group for - weight gradient all-reduce. - - nccl_communicator_config_path (str, default = None): - Path to the yaml file of NCCL communicator configurations. - `min_ctas`, `max_ctas`, and `cga_cluster_size` can be set - for each communicator. - - Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we - use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize - the model pipeline. 
The present function will
-    create 8 tensor model-parallel groups, 4 pipeline model-parallel groups
-    and 8 data-parallel groups as:
-        8 data_parallel groups:
-            [g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15]
-        8 tensor model-parallel groups:
-            [g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15]
-        4 pipeline model-parallel groups:
-            [g0, g4, g8, g12], [g1, g5, g9, g13], [g2, g6, g10, g14], [g3, g7, g11, g15]
-    Note that for efficiency, the caller should make sure adjacent ranks
-    are on the same DGX box. For example if we are using 2 DGX-1 boxes
-    with a total of 16 GPUs, rank 0 to 7 belong to the first box and
-    ranks 8 to 15 belong to the second box.
-
-    """
-    # Get world size and rank. Ensure some consistencies.
-    assert torch.distributed.is_initialized()
-    world_size: int = torch.distributed.get_world_size()
-
-    if (
-        world_size
-        % (tensor_model_parallel_size * pipeline_model_parallel_size * context_parallel_size)
-        != 0
-    ):
-        raise RuntimeError(
-            f"world_size ({world_size}) is not divisible by tensor_model_parallel_size "
-            f"({tensor_model_parallel_size}) x pipeline_model_parallel_size ({pipeline_model_parallel_size}) "
-            f"x context_parallel_size ({context_parallel_size})"
-        )
-
-    data_parallel_size: int = world_size // (
-        tensor_model_parallel_size * pipeline_model_parallel_size * context_parallel_size
-    )
-
-    if data_parallel_size % expert_model_parallel_size != 0:
-        raise RuntimeError(
-            f"data_parallel_size ({data_parallel_size}) is not divisible by expert_model_parallel_size "
-        )
-
-    if expert_model_parallel_size > 1 and context_parallel_size > 1:
-        raise RuntimeError(
-            f"combination of expert model parallelism and context parallelism is not supported"
-        )
-
-    num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size
-    num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size
-
-    if virtual_pipeline_model_parallel_size is not None:
-        if not pipeline_model_parallel_size > 2:
-            raise RuntimeError(
-                "pipeline-model-parallel size should be greater than 2 with interleaved schedule"
-            )
-        global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
-        global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
-        _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0
-        _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = virtual_pipeline_model_parallel_size
-
-    if pipeline_model_parallel_split_rank is not None:
-        global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
-        _PIPELINE_MODEL_PARALLEL_SPLIT_RANK = pipeline_model_parallel_split_rank
-
-    rank = torch.distributed.get_rank()
-
-    nccl_comm_cfgs = {}
-    if nccl_communicator_config_path is not None:
-        try:
-            import yaml
-        except ImportError:
-            raise RuntimeError(
-                "Cannot import `yaml`. Setting custom nccl communicator configs "
-                "requires the yaml package."
-            )
-
-        with open(nccl_communicator_config_path, "r") as stream:
-            nccl_comm_cfgs = yaml.safe_load(stream)
-
-    # Build the data-parallel groups.
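    # Illustrative sketch (assumed toy sizes, mirroring the docstring example of
    # tp=2, pp=4, cp=1 on 16 GPUs): the loops below enumerate the data-parallel
    # rank groups as
    #
    #   num_pp_groups = 16 // 4                        # 4
    #   dp_groups = [
    #       list(range(i * num_pp_groups + j, (i + 1) * num_pp_groups, 2))
    #       for i in range(4) for j in range(2)
    #   ]
    #   # -> [[0, 2], [1, 3], [4, 6], [5, 7], [8, 10], [9, 11], [12, 14], [13, 15]]
    #
    # which matches the "8 data_parallel groups" listed in the docstring above.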
-    global _DATA_PARALLEL_GROUP
-    global _DATA_PARALLEL_GROUP_GLOO
-    global _DATA_PARALLEL_GLOBAL_RANKS
-    global _DATA_PARALLEL_GROUP_WITH_CP
-    global _DATA_PARALLEL_GROUP_WITH_CP_GLOO
-    global _DATA_PARALLEL_GLOBAL_RANKS_WITH_CP
-    assert _DATA_PARALLEL_GROUP is None, 'data parallel group is already initialized'
-    all_data_parallel_group_ranks_with_cp = []
-    for i in range(pipeline_model_parallel_size):
-        start_rank = i * num_pipeline_model_parallel_groups
-        end_rank = (i + 1) * num_pipeline_model_parallel_groups
-        for j in range(context_parallel_size * tensor_model_parallel_size):
-            ranks = range(
-                start_rank + j, end_rank, context_parallel_size * tensor_model_parallel_size
-            )
-            group = torch.distributed.new_group(
-                ranks, pg_options=get_nccl_options('dp', nccl_comm_cfgs)
-            )
-            group_gloo = torch.distributed.new_group(ranks, backend="gloo")
-            if rank in ranks:
-                _DATA_PARALLEL_GROUP = group
-                _DATA_PARALLEL_GROUP_GLOO = group_gloo
-                _DATA_PARALLEL_GLOBAL_RANKS = ranks
-        for j in range(tensor_model_parallel_size):
-            ranks_with_cp = range(start_rank + j, end_rank, tensor_model_parallel_size)
-            all_data_parallel_group_ranks_with_cp.append(list(ranks_with_cp))
-            group_with_cp = torch.distributed.new_group(
-                ranks_with_cp, pg_options=get_nccl_options('dp_cp', nccl_comm_cfgs)
-            )
-            group_with_cp_gloo = torch.distributed.new_group(ranks_with_cp, backend="gloo")
-            if rank in ranks_with_cp:
-                _DATA_PARALLEL_GROUP_WITH_CP = group_with_cp
-                _DATA_PARALLEL_GROUP_WITH_CP_GLOO = group_with_cp_gloo
-                _DATA_PARALLEL_GLOBAL_RANKS_WITH_CP = ranks_with_cp
-
-    # Apply SHARP to DP process groups
-    if use_sharp:
-        if rank == 0:
-            print(
-                "The number of process groups to use SHARP with depends on the type "
-                "of the network switch. Nvidia QM1 switch supports SHARP up to 8 "
-                "process groups and QM2 supports up to 256 process groups. We apply "
-                "SHARP to the communications of the data-parallel domain. If the "
-                "number of data-parallel process groups is larger than the max "
-                "process groups that the network switch supports, the communication "
-                "will fall back to non-SHARP operators. To enable SHARP, "
-                "`#SBATCH_NETWORK=sharp` should be set in the sbatch script."
-            )
-        torch.distributed.barrier(
-            group=get_data_parallel_group(with_context_parallel=context_parallel_size > 1),
-            device_ids=[torch.cuda.current_device()],
-        )
-        # Set `NCCL_SHARP_DISABLE=1` to restrict SHARP application to DP process groups
-        os.environ["NCCL_SHARP_DISABLE"] = "1"
-
-    # Build the context-parallel groups.
-    global _CONTEXT_PARALLEL_GROUP
-    global _CONTEXT_PARALLEL_GLOBAL_RANKS
-    assert _CONTEXT_PARALLEL_GROUP is None, 'context parallel group is already initialized'
-    for i in range(pipeline_model_parallel_size):
-        for j in range(data_parallel_size):
-            start_rank = (
-                i * num_pipeline_model_parallel_groups
-                + j * tensor_model_parallel_size * context_parallel_size
-            )
-            end_rank = (
-                i * num_pipeline_model_parallel_groups
-                + (j + 1) * tensor_model_parallel_size * context_parallel_size
-            )
-            for k in range(tensor_model_parallel_size):
-                ranks = range(start_rank + k, end_rank, tensor_model_parallel_size)
-                group = torch.distributed.new_group(
-                    ranks, pg_options=get_nccl_options('cp', nccl_comm_cfgs)
-                )
-                if rank in ranks:
-                    _CONTEXT_PARALLEL_GROUP = group
-                    _CONTEXT_PARALLEL_GLOBAL_RANKS = ranks
-
-    # Build the model-parallel groups.
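    # Continuing the 16-GPU sketch (tp=2, pp=4, cp=1, so dp=2): each model-parallel
    # group takes one rank from every data-parallel group, i.e.
    #
    #   mp_groups = [[g[i] for g in dp_groups] for i in range(2)]
    #   # -> [[0, 1, 4, 5, 8, 9, 12, 13], [2, 3, 6, 7, 10, 11, 14, 15]]
    #
    # giving dp (= 2) groups of tp * pp (= 8) ranks each. (Sketch only; the values
    # are assumptions carried over from the docstring example.)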
- global _MODEL_PARALLEL_GROUP - assert _MODEL_PARALLEL_GROUP is None, 'model parallel group is already initialized' - for i in range(data_parallel_size * context_parallel_size): - ranks = [ - data_parallel_group_ranks_with_cp[i] - for data_parallel_group_ranks_with_cp in all_data_parallel_group_ranks_with_cp - ] - group = torch.distributed.new_group( - ranks, pg_options=get_nccl_options('mp', nccl_comm_cfgs) - ) - if rank in ranks: - _MODEL_PARALLEL_GROUP = group - - # Build the tensor model-parallel groups. - global _TENSOR_MODEL_PARALLEL_GROUP - assert ( - _TENSOR_MODEL_PARALLEL_GROUP is None - ), 'tensor model parallel group is already initialized' - for i in range(num_tensor_model_parallel_groups): - ranks = range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size) - group = torch.distributed.new_group( - ranks, pg_options=get_nccl_options('tp', nccl_comm_cfgs) - ) - if rank in ranks: - _TENSOR_MODEL_PARALLEL_GROUP = group - - # Build the pipeline model-parallel groups and embedding groups - # (first and last rank in each pipeline model-parallel group). - global _PIPELINE_MODEL_PARALLEL_GROUP - global _PIPELINE_GLOBAL_RANKS - assert ( - _PIPELINE_MODEL_PARALLEL_GROUP is None - ), 'pipeline model parallel group is already initialized' - global _EMBEDDING_GROUP - global _EMBEDDING_GLOBAL_RANKS - assert _EMBEDDING_GROUP is None, 'embedding group is already initialized' - global _POSITION_EMBEDDING_GROUP - global _POSITION_EMBEDDING_GLOBAL_RANKS - assert _POSITION_EMBEDDING_GROUP is None, 'position embedding group is already initialized' - for i in range(num_pipeline_model_parallel_groups): - ranks = range(i, world_size, num_pipeline_model_parallel_groups) - group = torch.distributed.new_group( - ranks, pg_options=get_nccl_options('pp', nccl_comm_cfgs) - ) - if rank in ranks: - _PIPELINE_MODEL_PARALLEL_GROUP = group - _PIPELINE_GLOBAL_RANKS = ranks - # Setup embedding group (to exchange gradients between - # first and last stages). - if len(ranks) > 1: - embedding_ranks = [ranks[0], ranks[-1]] - position_embedding_ranks = [ranks[0]] - if pipeline_model_parallel_split_rank is not None: - if ranks[pipeline_model_parallel_split_rank] not in embedding_ranks: - embedding_ranks = [ - ranks[0], - ranks[pipeline_model_parallel_split_rank], - ranks[-1], - ] - if ranks[pipeline_model_parallel_split_rank] not in position_embedding_ranks: - position_embedding_ranks = [ranks[0], ranks[pipeline_model_parallel_split_rank]] - else: - embedding_ranks = ranks - position_embedding_ranks = ranks - - group = torch.distributed.new_group( - embedding_ranks, pg_options=get_nccl_options('embd', nccl_comm_cfgs) - ) - if rank in embedding_ranks: - _EMBEDDING_GROUP = group - if rank in ranks: - _EMBEDDING_GLOBAL_RANKS = embedding_ranks - - group = torch.distributed.new_group( - position_embedding_ranks, pg_options=get_nccl_options('embd', nccl_comm_cfgs) - ) - if rank in position_embedding_ranks: - _POSITION_EMBEDDING_GROUP = group - if rank in ranks: - _POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks - - # Build the tensor + data parallel groups. 
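    # Continuing the same sketch: with tp=2, dp=2, cp=1 each tensor + data parallel
    # group (with CP folded in) spans tp * dp * cp = 4 consecutive ranks:
    #
    #   tp_dp_groups = [list(range(4 * i, 4 * (i + 1))) for i in range(4)]
    #   # -> [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]
    #
    # (Illustrative only; values assumed from the docstring example above.)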
- global _TENSOR_AND_DATA_PARALLEL_GROUP - global _TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP - assert ( - _TENSOR_AND_DATA_PARALLEL_GROUP is None - ), 'Tensor + data parallel group is already initialized' - tensor_and_data_group_size_with_cp: int = tensor_model_parallel_size * data_parallel_size * context_parallel_size - num_tensor_and_data_groups_with_cp: int = world_size // tensor_and_data_group_size_with_cp - for i in range(num_tensor_and_data_groups_with_cp): - start_rank = i * tensor_and_data_group_size_with_cp - end_rank = start_rank + tensor_and_data_group_size_with_cp - ranks = range(start_rank, end_rank) - group = torch.distributed.new_group( - ranks, pg_options=get_nccl_options('tp_dp_cp', nccl_comm_cfgs) - ) - if rank in ranks: - _TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP = group - - for j in range(context_parallel_size): - ranks = [] - for k in range(data_parallel_size): - start_rank = ( - i * tensor_and_data_group_size_with_cp - + j * tensor_model_parallel_size - + k * tensor_model_parallel_size * context_parallel_size - ) - end_rank = start_rank + tensor_model_parallel_size - ranks = ranks + list(range(start_rank, end_rank)) - group = torch.distributed.new_group( - ranks, pg_options=get_nccl_options('tp_dp', nccl_comm_cfgs) - ) - if rank in ranks: - _TENSOR_AND_DATA_PARALLEL_GROUP = group - - # Build the tensor + expert parallel groups - global _TENSOR_AND_EXPERT_PARALLEL_GROUP - assert ( - _TENSOR_AND_EXPERT_PARALLEL_GROUP is None - ), 'Tensor + expert parallel group is already initialized' - global _DATA_MODULO_EXPERT_PARALLEL_GROUP - assert ( - _DATA_MODULO_EXPERT_PARALLEL_GROUP is None - ), 'Data modulo expert group is already initialized' - tensor_and_data_group_size: int = tensor_model_parallel_size * data_parallel_size - num_tensor_and_data_groups: int = world_size // tensor_and_data_group_size - tensor_and_expert_group_size: int = tensor_model_parallel_size * expert_model_parallel_size - num_expert_groups: int = data_parallel_size // expert_model_parallel_size - for i in range(num_tensor_and_data_groups): - for j in range(num_expert_groups): - start_rank = i * tensor_and_data_group_size + j * tensor_and_expert_group_size - end_rank = i * tensor_and_data_group_size + (j + 1) * tensor_and_expert_group_size - ranks = range(start_rank, end_rank) - group = torch.distributed.new_group( - ranks, pg_options=get_nccl_options('tp_exp', nccl_comm_cfgs) - ) - if rank in ranks: - _TENSOR_AND_EXPERT_PARALLEL_GROUP = group - - for i in range(num_tensor_and_data_groups): - start_rank = i * tensor_and_data_group_size - end_rank = (i + 1) * tensor_and_data_group_size - for j in range(tensor_and_expert_group_size): - ranks = range(start_rank + j, end_rank, tensor_and_expert_group_size) - group = torch.distributed.new_group( - ranks, pg_options=get_nccl_options('dp_modulo_exp', nccl_comm_cfgs) - ) - if rank in ranks: - _DATA_MODULO_EXPERT_PARALLEL_GROUP = group - - # Initialize global memory buffer - # This isn't really "parallel state" but there isn't another good place to - # put this. 
If we end up with a more generic initialization of megatron-core - # we could stick it there - _set_global_memory_buffer() - - -def is_unitialized(): - """Useful for code segments that may be accessed with or without mpu initialization""" - return _DATA_PARALLEL_GROUP is None - - -def model_parallel_is_initialized(): - """Check if model and data parallel groups are initialized.""" - if ( - _TENSOR_MODEL_PARALLEL_GROUP is None - or _PIPELINE_MODEL_PARALLEL_GROUP is None - or _DATA_PARALLEL_GROUP is None - ): - return False - return True - - -def get_model_parallel_group(): - """Get the model parallel group the caller rank belongs to.""" - assert _MODEL_PARALLEL_GROUP is not None, 'model parallel group is not initialized' - return _MODEL_PARALLEL_GROUP - - -def get_tensor_model_parallel_group(check_initialized=True): - """Get the tensor model parallel group the caller rank belongs to.""" - if check_initialized: - assert ( - _TENSOR_MODEL_PARALLEL_GROUP is not None - ), 'tensor model parallel group is not initialized' - return _TENSOR_MODEL_PARALLEL_GROUP - - -def get_pipeline_model_parallel_group(): - """Get the pipeline model parallel group the caller rank belongs to.""" - assert ( - _PIPELINE_MODEL_PARALLEL_GROUP is not None - ), 'pipeline_model parallel group is not initialized' - return _PIPELINE_MODEL_PARALLEL_GROUP - - -def get_data_parallel_group(with_context_parallel=False): - """Get the data parallel group the caller rank belongs to.""" - if with_context_parallel: - assert ( - _DATA_PARALLEL_GROUP_WITH_CP is not None - ), 'data parallel group with context parallel combined is not initialized' - return _DATA_PARALLEL_GROUP_WITH_CP - else: - assert _DATA_PARALLEL_GROUP is not None, 'data parallel group is not initialized' - return _DATA_PARALLEL_GROUP - - -def get_data_parallel_group_gloo(with_context_parallel=False): - """Get the data parallel group-gloo the caller rank belongs to.""" - if with_context_parallel: - assert ( - _DATA_PARALLEL_GROUP_WITH_CP_GLOO is not None - ), 'data parallel group-gloo with context parallel combined is not initialized' - return _DATA_PARALLEL_GROUP_WITH_CP_GLOO - else: - assert _DATA_PARALLEL_GROUP_GLOO is not None, 'data parallel group-gloo is not initialized' - return _DATA_PARALLEL_GROUP_GLOO - - -def get_context_parallel_group(check_initialized=True): - """Get the context parallel group the caller rank belongs to.""" - if check_initialized: - assert _CONTEXT_PARALLEL_GROUP is not None, 'context parallel group is not initialized' - return _CONTEXT_PARALLEL_GROUP - - -def get_context_parallel_global_ranks(check_initialized=True): - """Get all global ranks of the context parallel group that the caller rank belongs to.""" - if check_initialized: - assert ( - _CONTEXT_PARALLEL_GLOBAL_RANKS is not None - ), 'context parallel group is not initialized' - return _CONTEXT_PARALLEL_GLOBAL_RANKS - - -def get_embedding_group(): - """Get the embedding group the caller rank belongs to.""" - assert _EMBEDDING_GROUP is not None, 'embedding group is not initialized' - return _EMBEDDING_GROUP - - -def get_position_embedding_group(): - """Get the position embedding group the caller rank belongs to.""" - assert _POSITION_EMBEDDING_GROUP is not None, 'position embedding group is not initialized' - return _POSITION_EMBEDDING_GROUP - - -def get_amax_reduction_group(with_context_parallel=False): - """Get the FP8 amax reduction group the caller rank belongs to.""" - if with_context_parallel: - assert ( - _TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP is not None - ), 'FP8 amax 
reduction group is not initialized' - return _TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP - else: - assert ( - _TENSOR_AND_DATA_PARALLEL_GROUP is not None - ), 'FP8 amax reduction group is not initialized' - return _TENSOR_AND_DATA_PARALLEL_GROUP - - -def get_tensor_and_data_parallel_group(with_context_parallel=False): - """Get the tensor and data parallel group the caller rank belongs to.""" - if with_context_parallel: - assert ( - _TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP is not None - ), 'tensor and data parallel group is not initialized' - return _TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP - else: - assert ( - _TENSOR_AND_DATA_PARALLEL_GROUP is not None - ), 'tensor and data parallel group is not initialized' - return _TENSOR_AND_DATA_PARALLEL_GROUP - - -def get_tensor_and_expert_parallel_group(): - assert ( - _TENSOR_AND_EXPERT_PARALLEL_GROUP is not None - ), 'tensor and expert parallel group is not initialized' - return _TENSOR_AND_EXPERT_PARALLEL_GROUP - - -def get_data_modulo_expert_parallel_group(): - assert ( - _DATA_MODULO_EXPERT_PARALLEL_GROUP is not None - ), 'data modulo expert parallel group is not initialized' - return _DATA_MODULO_EXPERT_PARALLEL_GROUP - - -def set_tensor_model_parallel_world_size(world_size): - """Set the tensor model parallel size""" - global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE - _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = world_size - - -def set_pipeline_model_parallel_world_size(world_size): - """Set the pipeline model parallel size""" - global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE - _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = world_size - - -def set_virtual_pipeline_model_parallel_world_size(world_size): - """Set the pipeline model parallel size""" - global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE - _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = world_size - - -def get_tensor_model_parallel_world_size(): - """Return world size for the tensor model parallel group.""" - global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE - if _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE is not None: - return _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE - return torch.distributed.get_world_size(group=get_tensor_model_parallel_group()) - - -def get_pipeline_model_parallel_world_size(): - """Return world size for the pipeline model parallel group.""" - global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE - if _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE is not None: - return _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE - return torch.distributed.get_world_size(group=get_pipeline_model_parallel_group()) - - -def set_tensor_model_parallel_rank(rank): - """Set tensor model parallel rank.""" - global _MPU_TENSOR_MODEL_PARALLEL_RANK - _MPU_TENSOR_MODEL_PARALLEL_RANK = rank - - -def set_pipeline_model_parallel_rank(rank): - """Set pipeline model parallel rank.""" - global _MPU_PIPELINE_MODEL_PARALLEL_RANK - _MPU_PIPELINE_MODEL_PARALLEL_RANK = rank - - -def set_pipeline_model_parallel_split_rank(rank): - """Set pipeline model parallel split rank.""" - global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK - _PIPELINE_MODEL_PARALLEL_SPLIT_RANK = rank - - -def get_tensor_model_parallel_rank(): - """Return my rank for the tensor model parallel group.""" - global _MPU_TENSOR_MODEL_PARALLEL_RANK - if _MPU_TENSOR_MODEL_PARALLEL_RANK is not None: - return _MPU_TENSOR_MODEL_PARALLEL_RANK - return torch.distributed.get_rank(group=get_tensor_model_parallel_group()) - - -def get_pipeline_model_parallel_rank(): - """Return my rank for the pipeline model parallel group.""" - global _MPU_PIPELINE_MODEL_PARALLEL_RANK - if 
_MPU_PIPELINE_MODEL_PARALLEL_RANK is not None: - return _MPU_PIPELINE_MODEL_PARALLEL_RANK - return torch.distributed.get_rank(group=get_pipeline_model_parallel_group()) - - -def get_pipeline_model_parallel_split_rank(): - """Return pipeline model parallel split rank.""" - global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK - return _PIPELINE_MODEL_PARALLEL_SPLIT_RANK - - -def is_pipeline_first_stage(ignore_virtual=False): - """Return True if in the first pipeline model-parallel stage, False otherwise.""" - if not ignore_virtual: - if ( - get_virtual_pipeline_model_parallel_world_size() is not None - and get_virtual_pipeline_model_parallel_rank() != 0 - ): - return False - return get_pipeline_model_parallel_rank() == 0 - - -def is_pipeline_last_stage(ignore_virtual=False): - """Return True if in the last pipeline model-parallel stage, False otherwise.""" - if not ignore_virtual: - virtual_pipeline_model_parallel_world_size = ( - get_virtual_pipeline_model_parallel_world_size() - ) - if virtual_pipeline_model_parallel_world_size is not None and get_virtual_pipeline_model_parallel_rank() != ( - virtual_pipeline_model_parallel_world_size - 1 - ): - return False - return get_pipeline_model_parallel_rank() == (get_pipeline_model_parallel_world_size() - 1) - - -def is_rank_in_embedding_group(ignore_virtual=False): - """Return true if current rank is in embedding group, False otherwise.""" - rank = torch.distributed.get_rank() - global _EMBEDDING_GLOBAL_RANKS - if ignore_virtual: - return rank in _EMBEDDING_GLOBAL_RANKS - if rank in _EMBEDDING_GLOBAL_RANKS: - if rank == _EMBEDDING_GLOBAL_RANKS[0]: - return is_pipeline_first_stage(ignore_virtual=False) - elif rank == _EMBEDDING_GLOBAL_RANKS[-1]: - return is_pipeline_last_stage(ignore_virtual=False) - else: - return True - return False - - -def is_rank_in_position_embedding_group(): - """Return true if current rank is in position embedding group, False otherwise.""" - rank = torch.distributed.get_rank() - global _POSITION_EMBEDDING_GLOBAL_RANKS - return rank in _POSITION_EMBEDDING_GLOBAL_RANKS - - -def is_pipeline_stage_before_split(rank=None): - """Return True if pipeline stage executes encoder block for a model - with both encoder and decoder.""" - if get_pipeline_model_parallel_world_size() == 1: - return True - if rank is None: - rank = get_pipeline_model_parallel_rank() - global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK - if _PIPELINE_MODEL_PARALLEL_SPLIT_RANK is None: - return True - if rank < _PIPELINE_MODEL_PARALLEL_SPLIT_RANK: - return True - return False - - -def is_pipeline_stage_after_split(rank=None): - """Return True if pipeline stage executes decoder block for a model - with both encoder and decoder.""" - if get_pipeline_model_parallel_world_size() == 1: - return True - if rank is None: - rank = get_pipeline_model_parallel_rank() - global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK - if _PIPELINE_MODEL_PARALLEL_SPLIT_RANK is None: - return True - if rank >= _PIPELINE_MODEL_PARALLEL_SPLIT_RANK: - return True - return False - - -def is_pipeline_stage_at_split(): - """Return true if pipeline stage executes decoder block and next - stage executes encoder block for a model with both encoder and - decoder.""" - rank = get_pipeline_model_parallel_rank() - return is_pipeline_stage_before_split(rank) and is_pipeline_stage_after_split(rank + 1) - - -def get_virtual_pipeline_model_parallel_rank(): - """Return the virtual pipeline-parallel rank.""" - global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK - return _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK - - -def 
set_virtual_pipeline_model_parallel_rank(rank): - """Set the virtual pipeline-parallel rank.""" - global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK - _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = rank - - -def get_virtual_pipeline_model_parallel_world_size(): - """Return the virtual pipeline-parallel world size.""" - global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE - return _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE - - -def get_tensor_model_parallel_src_rank(): - """Calculate the global rank corresponding to the first local rank - in the tensor model parallel group.""" - global_rank = torch.distributed.get_rank() - local_world_size = get_tensor_model_parallel_world_size() - return (global_rank // local_world_size) * local_world_size - - -def get_data_parallel_src_rank(with_context_parallel=False): - """Calculate the global rank corresponding to the first local rank - in the data parallel group.""" - if with_context_parallel: - assert ( - _DATA_PARALLEL_GLOBAL_RANKS_WITH_CP is not None - ), "Data parallel group with context parallel combined is not initialized" - return _DATA_PARALLEL_GLOBAL_RANKS_WITH_CP[0] - else: - assert _DATA_PARALLEL_GLOBAL_RANKS is not None, "Data parallel group is not initialized" - return _DATA_PARALLEL_GLOBAL_RANKS[0] - - -def get_pipeline_model_parallel_first_rank(): - """Return the global rank of the first process in the pipeline for the - current tensor parallel group""" - assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized" - return _PIPELINE_GLOBAL_RANKS[0] - - -def get_pipeline_model_parallel_last_rank(): - """Return the global rank of the last process in the pipeline for the - current tensor parallel group""" - assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized" - last_rank_local = get_pipeline_model_parallel_world_size() - 1 - return _PIPELINE_GLOBAL_RANKS[last_rank_local] - - -def get_pipeline_model_parallel_next_rank(): - """Return the global rank that follows the caller in the pipeline""" - assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized" - rank_in_pipeline = get_pipeline_model_parallel_rank() - world_size = get_pipeline_model_parallel_world_size() - return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size] - - -def get_pipeline_model_parallel_prev_rank(): - """Return the global rank that preceeds the caller in the pipeline""" - assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized" - rank_in_pipeline = get_pipeline_model_parallel_rank() - world_size = get_pipeline_model_parallel_world_size() - return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size] - - -def get_data_parallel_world_size(with_context_parallel=False): - """Return world size for the data parallel group.""" - if torch.distributed.is_available() and torch.distributed.is_initialized(): - return torch.distributed.get_world_size( - group=get_data_parallel_group(with_context_parallel=with_context_parallel) - ) - else: - return 0 - - -def get_data_parallel_rank(with_context_parallel=False): - """Return my rank for the data parallel group.""" - if torch.distributed.is_available() and torch.distributed.is_initialized(): - return torch.distributed.get_rank( - group=get_data_parallel_group(with_context_parallel=with_context_parallel) - ) - else: - return 0 - - -def get_context_parallel_world_size(): - """Return world size for the context parallel group.""" - if torch.distributed.is_available() and torch.distributed.is_initialized(): - return 
torch.distributed.get_world_size(group=get_context_parallel_group()) - else: - return 0 - - -def get_context_parallel_rank(): - """Return my rank for the context parallel group.""" - if torch.distributed.is_available() and torch.distributed.is_initialized(): - return torch.distributed.get_rank(group=get_context_parallel_group()) - else: - return 0 - - -def get_expert_model_parallel_world_size(): - """Return my rank for the expert parallel group""" - if torch.distributed.is_available() and torch.distributed.is_initialized(): - tensor_and_expert_parallel_world_size = torch.distributed.get_world_size( - group=get_tensor_and_expert_parallel_group() - ) - return tensor_and_expert_parallel_world_size // get_tensor_model_parallel_world_size() - else: - return 0 - - -def get_expert_model_parallel_rank(): - """Return my rank for the expert parallel group""" - if torch.distributed.is_available() and torch.distributed.is_initialized(): - tensor_and_expert_parallel_rank = torch.distributed.get_rank( - group=get_tensor_and_expert_parallel_group() - ) - return tensor_and_expert_parallel_rank // get_tensor_model_parallel_world_size() - else: - return 0 - - -def get_data_modulo_expert_parallel_rank(): - """Return my rank for the context parallel group.""" - if torch.distributed.is_available() and torch.distributed.is_initialized(): - return torch.distributed.get_rank(group=get_data_modulo_expert_parallel_group()) - else: - return 0 - - -def _set_global_memory_buffer(): - """Initialize global buffer""" - global _GLOBAL_MEMORY_BUFFER - assert _GLOBAL_MEMORY_BUFFER is None, 'global memory buffer is already initialized' - _GLOBAL_MEMORY_BUFFER = GlobalMemoryBuffer() - - -def get_global_memory_buffer(): - """Return the global GlobalMemoryBuffer object""" - assert _GLOBAL_MEMORY_BUFFER is not None, 'global memory buffer is not initialized' - return _GLOBAL_MEMORY_BUFFER - - -def destroy_global_memory_buffer(): - """Sets the global memory buffer to None""" - global _GLOBAL_MEMORY_BUFFER - _GLOBAL_MEMORY_BUFFER = None - - -def destroy_model_parallel(): - """Set the groups to none.""" - global _MODEL_PARALLEL_GROUP - _MODEL_PARALLEL_GROUP = None - global _TENSOR_MODEL_PARALLEL_GROUP - _TENSOR_MODEL_PARALLEL_GROUP = None - global _PIPELINE_MODEL_PARALLEL_GROUP - _PIPELINE_MODEL_PARALLEL_GROUP = None - global _DATA_PARALLEL_GROUP - _DATA_PARALLEL_GROUP = None - global _DATA_PARALLEL_GROUP_WITH_CP - _DATA_PARALLEL_GROUP_WITH_CP = None - global _CONTEXT_PARALLEL_GROUP - _CONTEXT_PARALLEL_GROUP = None - global _CONTEXT_PARALLEL_GLOBAL_RANKS - _CONTEXT_PARALLEL_GLOBAL_RANKS = None - global _EMBEDDING_GROUP - _EMBEDDING_GROUP = None - global _POSITION_EMBEDDING_GROUP - _POSITION_EMBEDDING_GROUP = None - global _TENSOR_AND_DATA_PARALLEL_GROUP - _TENSOR_AND_DATA_PARALLEL_GROUP = None - global _TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP - _TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP = None - global _TENSOR_AND_EXPERT_PARALLEL_GROUP - _TENSOR_AND_EXPERT_PARALLEL_GROUP = None - global _DATA_MODULO_EXPERT_PARALLEL_GROUP - _DATA_MODULO_EXPERT_PARALLEL_GROUP = None - global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK - _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None - global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE - _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None - global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE - _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None - global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE - _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None - global _MPU_TENSOR_MODEL_PARALLEL_RANK - _MPU_TENSOR_MODEL_PARALLEL_RANK = None - 
global _MPU_PIPELINE_MODEL_PARALLEL_RANK - _MPU_PIPELINE_MODEL_PARALLEL_RANK = None - global _GLOBAL_MEMORY_BUFFER - _GLOBAL_MEMORY_BUFFER = None diff --git a/megatron/core/pipeline_parallel/__init__.py b/megatron/core/pipeline_parallel/__init__.py deleted file mode 100644 index 00cd1ff3826564f9eef6cd9b023c0dd331b5d691..0000000000000000000000000000000000000000 --- a/megatron/core/pipeline_parallel/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .schedules import get_forward_backward_func diff --git a/megatron/core/pipeline_parallel/p2p_communication.py b/megatron/core/pipeline_parallel/p2p_communication.py deleted file mode 100644 index 29ee34df8cff6a2f0613c0ad9bca175d71e30ba2..0000000000000000000000000000000000000000 --- a/megatron/core/pipeline_parallel/p2p_communication.py +++ /dev/null @@ -1,571 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. - -import operator -from functools import reduce -from typing import Callable, List, Optional, Tuple, Union - -import torch - -from megatron import core -from megatron.core import ModelParallelConfig -from megatron.core.parallel_state import ( - get_pipeline_model_parallel_group, - get_pipeline_model_parallel_next_rank, - get_pipeline_model_parallel_prev_rank, - get_pipeline_model_parallel_rank, -) - -# Types -Shape = Union[List[int], torch.Size] - - -def _communicate_shapes(tensor_send_next, tensor_send_prev, recv_prev, recv_next, config): - """Communicate tensor shapes between stages. Used to communicate - tensor shapes before the actual tensor communication happens. - This is required when the sequence lengths across micro batches - are not uniform. - - Takes the following arguments: - tensor_send_next: tensor to send to next rank (no tensor sent if - set to None). - tensor_send_prev: tensor to send to prev rank (no tensor sent if - set to None). - recv_prev: boolean for whether tensor should be received from - previous rank. - recv_next: boolean for whether tensor should be received from - next rank. 
- Returns: - (recv_prev_shape, recv_next_shape) - """ - - recv_prev_shape_tensor = None - recv_next_shape_tensor = None - send_prev_shape_tensor = None - send_next_shape_tensor = None - if recv_prev: - recv_prev_shape_tensor = torch.empty( - (3), device=torch.cuda.current_device(), dtype=torch.int64 - ) - if recv_next: - recv_next_shape_tensor = torch.empty( - (3), device=torch.cuda.current_device(), dtype=torch.int64 - ) - if tensor_send_prev is not None: - send_prev_shape_tensor = torch.tensor( - tensor_send_prev.size(), device=torch.cuda.current_device(), dtype=torch.int64 - ) - if tensor_send_next is not None: - send_next_shape_tensor = torch.tensor( - tensor_send_next.size(), device=torch.cuda.current_device(), dtype=torch.int64 - ) - - if config.use_ring_exchange_p2p: - torch.distributed.ring_exchange( - tensor_send_prev=send_prev_shape_tensor, - tensor_recv_prev=recv_prev_shape_tensor, - tensor_send_next=send_next_shape_tensor, - tensor_recv_next=recv_next_shape_tensor, - group=get_pipeline_model_parallel_group(), - ) - else: - ops = [] - if send_prev_shape_tensor is not None: - send_prev_op = torch.distributed.P2POp( - torch.distributed.isend, - send_prev_shape_tensor, - get_pipeline_model_parallel_prev_rank(), - ) - ops.append(send_prev_op) - if recv_prev_shape_tensor is not None: - recv_prev_op = torch.distributed.P2POp( - torch.distributed.irecv, - recv_prev_shape_tensor, - get_pipeline_model_parallel_prev_rank(), - ) - ops.append(recv_prev_op) - if send_next_shape_tensor is not None: - send_next_op = torch.distributed.P2POp( - torch.distributed.isend, - send_next_shape_tensor, - get_pipeline_model_parallel_next_rank(), - ) - ops.append(send_next_op) - if recv_next_shape_tensor is not None: - recv_next_op = torch.distributed.P2POp( - torch.distributed.irecv, - recv_next_shape_tensor, - get_pipeline_model_parallel_next_rank(), - ) - ops.append(recv_next_op) - if len(ops) > 0: - reqs = torch.distributed.batch_isend_irecv(ops) - for req in reqs: - req.wait() - - # To protect against race condition when using batch_isend_irecv(). - # should take this out once the bug with batch_isend_irecv is resolved. 
- torch.cuda.synchronize() - - recv_prev_shape = [0, 0, 0] - if recv_prev_shape_tensor is not None: - recv_prev_shape = recv_prev_shape_tensor.tolist() - - recv_next_shape = [0, 0, 0] - if recv_next_shape_tensor is not None: - recv_next_shape = recv_next_shape_tensor.tolist() - - return recv_prev_shape, recv_next_shape - - -def _batched_p2p_ops( - *, - tensor_send_prev: Optional[torch.Tensor], - tensor_recv_prev: Optional[torch.Tensor], - tensor_send_next: Optional[torch.Tensor], - tensor_recv_next: Optional[torch.Tensor], - group: torch.distributed.ProcessGroup -): - ops = [] - if tensor_send_prev is not None: - send_prev_op = torch.distributed.P2POp( - torch.distributed.isend, - tensor_send_prev, - get_pipeline_model_parallel_prev_rank(), - group, - ) - ops.append(send_prev_op) - if tensor_recv_prev is not None: - recv_prev_op = torch.distributed.P2POp( - torch.distributed.irecv, - tensor_recv_prev, - get_pipeline_model_parallel_prev_rank(), - group, - ) - ops.append(recv_prev_op) - if tensor_send_next is not None: - send_next_op = torch.distributed.P2POp( - torch.distributed.isend, - tensor_send_next, - get_pipeline_model_parallel_next_rank(), - group, - ) - ops.append(send_next_op) - if tensor_recv_next is not None: - recv_next_op = torch.distributed.P2POp( - torch.distributed.irecv, - tensor_recv_next, - get_pipeline_model_parallel_next_rank(), - group, - ) - ops.append(recv_next_op) - if len(ops) > 0: - reqs = torch.distributed.batch_isend_irecv(ops) - else: - reqs = [] - return reqs - - -def _p2p_ops( - *, - tensor_send_prev: Optional[torch.Tensor], - tensor_recv_prev: Optional[torch.Tensor], - tensor_send_next: Optional[torch.Tensor], - tensor_recv_next: Optional[torch.Tensor], - group: torch.distributed.ProcessGroup -): - reqs = [] - rank = get_pipeline_model_parallel_rank() - if get_pipeline_model_parallel_rank() % 2 == 0: - if tensor_send_next is not None: - send_next_req = torch.distributed.isend( - tensor=tensor_send_next, dst=get_pipeline_model_parallel_next_rank(), group=group, - ) - reqs.append(send_next_req) - - if tensor_recv_prev is not None: - recv_prev_req = torch.distributed.irecv( - tensor=tensor_recv_prev, src=get_pipeline_model_parallel_prev_rank(), group=group, - ) - reqs.append(recv_prev_req) - - if tensor_send_prev is not None: - send_prev_req = torch.distributed.isend( - tensor=tensor_send_prev, dst=get_pipeline_model_parallel_prev_rank(), group=group, - ) - reqs.append(send_prev_req) - - if tensor_recv_next is not None: - recv_next_req = torch.distributed.irecv( - tensor=tensor_recv_next, src=get_pipeline_model_parallel_next_rank(), group=group, - ) - reqs.append(recv_next_req) - - else: - if tensor_recv_prev is not None: - recv_prev_req = torch.distributed.irecv( - tensor=tensor_recv_prev, src=get_pipeline_model_parallel_prev_rank(), group=group, - ) - reqs.append(recv_prev_req) - - if tensor_send_next is not None: - send_next_req = torch.distributed.isend( - tensor=tensor_send_next, dst=get_pipeline_model_parallel_next_rank(), group=group, - ) - reqs.append(send_next_req) - - if tensor_recv_next is not None: - recv_next_req = torch.distributed.irecv( - tensor=tensor_recv_next, src=get_pipeline_model_parallel_next_rank(), group=group, - ) - reqs.append(recv_next_req) - - if tensor_send_prev is not None: - send_prev_req = torch.distributed.isend( - tensor=tensor_send_prev, dst=get_pipeline_model_parallel_prev_rank(), group=group, - ) - reqs.append(send_prev_req) - return reqs - - -def _communicate( - *, - tensor_send_next: Optional[torch.Tensor], - 
tensor_send_prev: Optional[torch.Tensor], - recv_prev: bool, - recv_next: bool, - tensor_shape: Shape, - config: ModelParallelConfig, - wait_on_reqs: bool = True -) -> Tuple[torch.Tensor, torch.Tensor]: - """Communicate tensors between stages. Used as helper method in other - communication methods that are used in megatron/schedules.py. - - Arguments: - tensor_send_next (torch.Tensor, optional): - Tensor to send to next rank (no tensor sent if None) - - tensor_send_prev (torch.Tensor, optional): - Tensor to send to prev rank (no tensor sent if None) - - recv_prev (boolean, required): - whether tensor should be received from previous rank. - - recv_next (boolean, required): - whether tensor should be received from next rank. - - tensor_shape (List[int] or torch.Size, required): - shape of tensor to receive (this method assumes that all - tensors sent and received in a single function call are - the same shape). - - wait_on_reqs (boolean, optional, default=False): - For non-batched p2p communication, wait on each request - before returning. - - Returns: - tuple containing - - - tensor_recv_prev: torch.Tensor if recv_prev is True, None otherwise. - - tensor_recv_next: torch.Tensor if recv_next is True, None otherwise. - - """ - - # Create placeholder tensors for receive in forward and backward directions - # if needed. - tensor_recv_prev = None - tensor_recv_next = None - - if not config.variable_seq_lengths: - recv_prev_shape = tensor_shape - recv_next_shape = tensor_shape - else: - recv_prev_shape, recv_next_shape = _communicate_shapes( - tensor_send_next, tensor_send_prev, recv_prev, recv_next, config - ) - - if recv_prev: - if config.pipeline_dtype is None: - raise RuntimeError("pipeline_dtype must be provided if recv_prev is True") - if tensor_shape is None: - raise RuntimeError( - "tensor_shape must be specified if recv_prev is True. " - "Common tensor_shape is (seq_length, micro_batch_size, hidden_size)" - ) - tensor_recv_prev = torch.empty( - recv_prev_shape, - requires_grad=True, - device=torch.cuda.current_device(), - dtype=config.pipeline_dtype, - ) - if recv_next: - if config.pipeline_dtype is None: - raise RuntimeError("dtype must be provided if recv_next is True") - if tensor_shape is None: - raise RuntimeError( - "tensor_shape must be specified if recv_next is True. " - "Common tensor_shape is (seq_length, micro_batch_size, hidden_size)" - ) - tensor_recv_next = torch.empty( - recv_next_shape, - requires_grad=True, - device=torch.cuda.current_device(), - dtype=config.pipeline_dtype, - ) - - # Send tensors in both the forward and backward directions as appropriate. - if config.use_ring_exchange_p2p: - - def _ring_exchange_wrapper(**kwargs): - torch.distributed.ring_exchange(**kwargs) - return [] - - p2p_func = _ring_exchange_wrapper - elif config.batch_p2p_comm: - assert wait_on_reqs - p2p_func = _batched_p2p_ops - else: - p2p_func = _p2p_ops - - reqs = p2p_func( - tensor_send_prev=tensor_send_prev, - tensor_recv_prev=tensor_recv_prev, - tensor_send_next=tensor_send_next, - tensor_recv_next=tensor_recv_next, - group=get_pipeline_model_parallel_group(), - ) - - if wait_on_reqs and len(reqs) > 0: - for req in reqs: - req.wait() - reqs = None - - if config.batch_p2p_comm and config.batch_p2p_sync: - # To protect against race condition when using batch_isend_irecv(). 
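Note: one detail of the removed `_p2p_ops` above is easy to miss: even and odd pipeline ranks issue their sends and receives in mirror-image order, so two neighbouring stages never both block on a send at the same time. The sketch below only shows that ordering decision; the real code issues `torch.distributed.isend`/`irecv` against the pipeline-parallel group.

```python
# Sketch (not from the patch): the mirror-image operation ordering used by the
# removed _p2p_ops to avoid hangs with unbatched point-to-point calls.
def p2p_op_order(pipeline_rank: int) -> list:
    if pipeline_rank % 2 == 0:
        # Even stages: send forward first, then receive from the previous stage.
        return ["send_next", "recv_prev", "send_prev", "recv_next"]
    # Odd stages: receive first, pairing up with the even neighbour's send.
    return ["recv_prev", "send_next", "recv_next", "send_prev"]

for rank in range(4):
    print(rank, p2p_op_order(rank))
```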
- # User should assert that we have a modern enough PyTorch to not need this - torch.cuda.synchronize() - - return tensor_recv_prev, tensor_recv_next, reqs - - -def recv_forward(tensor_shape: Shape, config: ModelParallelConfig) -> torch.Tensor: - """ Receive tensor from previous rank in pipeline (forward receive). - - - See _communicate for argument details. - """ - - if core.parallel_state.is_pipeline_first_stage(): - input_tensor = None - else: - if config.timers is not None: - config.timers('forward-recv', log_level=2).start() - input_tensor, _, _ = _communicate( - tensor_send_next=None, - tensor_send_prev=None, - recv_prev=True, - recv_next=False, - tensor_shape=tensor_shape, - config=config, - ) - if config.timers is not None: - config.timers('forward-recv').stop() - return input_tensor - - -def recv_backward(tensor_shape: Shape, config: ModelParallelConfig) -> torch.Tensor: - """Receive tensor from next rank in pipeline (backward receive). - - See _communicate for argument details. - """ - if core.parallel_state.is_pipeline_last_stage(): - output_tensor_grad = None - else: - if config.timers is not None: - config.timers('backward-recv', log_level=2).start() - _, output_tensor_grad, _ = _communicate( - tensor_send_next=None, - tensor_send_prev=None, - recv_prev=False, - recv_next=True, - tensor_shape=tensor_shape, - config=config, - ) - if config.timers is not None: - config.timers('backward-recv').stop() - return output_tensor_grad - - -def send_forward(output_tensor: torch.Tensor, config: ModelParallelConfig) -> None: - """Send tensor to next rank in pipeline (forward send). - - See _communicate for argument details. - """ - - if not core.parallel_state.is_pipeline_last_stage(): - if config.timers is not None: - config.timers('forward-send', log_level=2).start() - _communicate( - tensor_send_next=output_tensor, - tensor_send_prev=None, - recv_prev=False, - recv_next=False, - tensor_shape=None, - config=config, - ) - if config.timers is not None: - config.timers('forward-send').stop() - - -def send_backward(input_tensor_grad: torch.Tensor, config: ModelParallelConfig) -> None: - """Send tensor to previous rank in pipeline (backward send). - - See _communicate for argument details. - """ - if not core.parallel_state.is_pipeline_first_stage(): - if config.timers is not None: - config.timers('backward-send', log_level=2).start() - _communicate( - tensor_send_next=None, - tensor_send_prev=input_tensor_grad, - recv_prev=False, - recv_next=False, - tensor_shape=None, - config=config, - ) - if config.timers is not None: - config.timers('backward-send').stop() - - -def send_forward_recv_backward( - output_tensor: torch.Tensor, tensor_shape: Shape, config: ModelParallelConfig -) -> torch.Tensor: - """Batched send and recv with next rank in pipeline. - - See _communicate for argument details. - """ - if core.parallel_state.is_pipeline_last_stage(): - output_tensor_grad = None - else: - if config.timers is not None: - config.timers('forward-send-backward-recv', log_level=2).start() - _, output_tensor_grad, _ = _communicate( - tensor_send_next=output_tensor, - tensor_send_prev=None, - recv_prev=False, - recv_next=True, - tensor_shape=tensor_shape, - config=config, - ) - if config.timers is not None: - config.timers('forward-send-backward-recv').stop() - return output_tensor_grad - - -def send_backward_recv_forward( - input_tensor_grad: torch.Tensor, tensor_shape: Shape, config: ModelParallelConfig -) -> torch.Tensor: - """Batched send and recv with previous rank in pipeline. 
- - See _communicate for argument details. - """ - if core.parallel_state.is_pipeline_first_stage(): - input_tensor = None - else: - if config.timers is not None: - config.timers('backward-send-forward-recv', log_level=2).start() - input_tensor, _, _ = _communicate( - tensor_send_next=None, - tensor_send_prev=input_tensor_grad, - recv_prev=True, - recv_next=False, - tensor_shape=tensor_shape, - config=config, - ) - if config.timers is not None: - config.timers('backward-send-forward-recv').stop() - return input_tensor - - -def send_forward_recv_forward( - output_tensor: torch.Tensor, - recv_prev: bool, - tensor_shape: Shape, - config: ModelParallelConfig, - overlap_p2p_comm: bool = False, -) -> torch.Tensor: - """Batched recv from previous rank and send to next rank in pipeline. - - See _communicate for argument details. - """ - if config.timers is not None: - config.timers('forward-send-forward-recv', log_level=2).start() - input_tensor, _, wait_handles = _communicate( - tensor_send_next=output_tensor, - tensor_send_prev=None, - recv_prev=recv_prev, - recv_next=False, - tensor_shape=tensor_shape, - wait_on_reqs=(not overlap_p2p_comm), - config=config, - ) - if config.timers is not None: - config.timers('forward-send-forward-recv').stop() - if overlap_p2p_comm: - return input_tensor, wait_handles - return input_tensor - - -def send_backward_recv_backward( - input_tensor_grad: torch.Tensor, - recv_next: bool, - tensor_shape: Shape, - config: ModelParallelConfig, - overlap_p2p_comm: bool = False, -) -> torch.Tensor: - """Batched recv from next rank and send to previous rank in pipeline. - - See _communicate for argument details. - """ - if config.timers is not None: - config.timers('backward-send-backward-recv', log_level=2).start() - _, output_tensor_grad, wait_handles = _communicate( - tensor_send_next=None, - tensor_send_prev=input_tensor_grad, - recv_prev=False, - recv_next=recv_next, - tensor_shape=tensor_shape, - wait_on_reqs=(not overlap_p2p_comm), - config=config, - ) - if config.timers is not None: - config.timers('backward-send-backward-recv').stop() - if overlap_p2p_comm: - return output_tensor_grad, wait_handles - return output_tensor_grad - - -def send_forward_backward_recv_forward_backward( - output_tensor: torch.Tensor, - input_tensor_grad: torch.Tensor, - recv_prev: bool, - recv_next: bool, - tensor_shape: Shape, - config: ModelParallelConfig, -) -> torch.Tensor: - """Batched send and recv with previous and next ranks in pipeline. - - See _communicate for argument details. - """ - if config.timers is not None: - config.timers('forward-backward-send-forward-backward-recv', log_level=2).start() - input_tensor, output_tensor_grad, _ = _communicate( - tensor_send_next=output_tensor, - tensor_send_prev=input_tensor_grad, - recv_prev=recv_prev, - recv_next=recv_next, - tensor_shape=tensor_shape, - config=config, - ) - if config.timers is not None: - config.timers('forward-backward-send-forward-backward-recv').stop() - return input_tensor, output_tensor_grad diff --git a/megatron/core/pipeline_parallel/schedules.py b/megatron/core/pipeline_parallel/schedules.py deleted file mode 100644 index 992da781271e22baefabe75c08a308aa9438f3b3..0000000000000000000000000000000000000000 --- a/megatron/core/pipeline_parallel/schedules.py +++ /dev/null @@ -1,1293 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
- -import contextlib -from typing import Callable, Iterator, List, Optional, Union - -import torch -from torch.autograd.variable import Variable - -from megatron.core import parallel_state -from megatron.core.enums import ModelType -from megatron.core.pipeline_parallel import p2p_communication -from megatron.core.utils import get_attr_wrapped_model, get_model_config, get_model_type - -# Types -Shape = Union[List[int], torch.Size] - - -def get_forward_backward_func(): - """Retrieves the appropriate forward_backward function given the - configuration of parallel_state. - - Returns a function that will perform all of the forward and - backward passes of the model given the pipeline model parallel - world size and virtual pipeline model parallel world size in the - global parallel_state. - - Note that if using sequence parallelism, the sequence length component of - the tensor shape is updated to original_sequence_length / - tensor_model_parallel_world_size. - - The function returned takes the following arguments: - - forward_step_func (required): A function that takes a data - iterator and a model as its arguments and return the model's - forward output and the loss function. The loss function should - take one torch.Tensor and return a torch.Tensor of loss and a - dictionary of string -> torch.Tensor. - - A third argument, checkpoint_activations_microbatch, indicates - that the activations for this microbatch should be - checkpointed. A None value for this argument indicates that - the default from the configuration should be used. This is - used when the - num_microbatches_with_partial_activation_checkpoints is used. - - For example: - - def loss_func(loss_mask, output_tensor): - losses = output_tensor.float() - loss_mask = loss_mask.view(-1).float() - loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum() - - # Reduce loss for logging. - averaged_loss = average_losses_across_data_parallel_group([loss]) - - return loss, {'lm loss': averaged_loss[0]} - - def forward_step(data_iterator, model): - data, loss_mask = next(data_iterator) - output = model(data) - return output, partial(loss_func, loss_mask) - - - forward_backward_func(forward_step_func=forward_step, ...) - - - data_iterator (required): an iterator over the data, will be - passed as is to forward_step_func. Expected to be a list of - iterators in the case of interleaved pipeline parallelism. - - model (required): the actual model. Expected to be a list of modules in the case of interleaved - pipeline parallelism. Must be a (potentially wrapped) megatron.core.models.MegatronModule. - - num_microbatches (int, required): - The number of microbatches to go through - - seq_length (int, required): Sequence length of the current global batch. If this is a dual-stack - transformer, this is the encoder's sequence length. This is ignored if variable_seq_lengths - in the config is True. Otherwise, each microbatch in the current global batch size must use - this sequence length. - - micro_batch_size (int, required): The number of sequences in a microbatch. - - decoder_seq_length (int, optional): The sequence length for the decoder in a dual-stack - transformer. This is ignored for a single-stack transformer. 
- - forward_only (optional, default = False): Perform only the forward step - - collect_non_loss_data (optional, bool, default=False): TODO - - """ - pipeline_model_parallel_size = parallel_state.get_pipeline_model_parallel_world_size() - if pipeline_model_parallel_size > 1: - if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None: - forward_backward_func = forward_backward_pipelining_with_interleaving - else: - forward_backward_func = forward_backward_pipelining_without_interleaving - else: - forward_backward_func = forward_backward_no_pipelining - return forward_backward_func - - -def deallocate_output_tensor(out, deallocate_pipeline_outputs=False): - '''Pseudo-deallocate (i.e., set to scalar) the output tensor's '.data' field. - - This method should be called right after the output tensor has been - sent to the next pipeline stage. At this point, the output tensor is - only useful for its '.grad_fn' field, and not its '.data'. - ''' - if (out is None) or (not deallocate_pipeline_outputs): - return - assert isinstance(out, torch.Tensor), "expected Tensor, found %s." % type(out).__name__ - assert out._base is None, "counter-productive to free a view of another tensor." - out.data = torch.empty((1,), device=out.device, dtype=out.dtype,) - - -def custom_backward(output, grad_output): - '''Directly call C++ autograd engine. - - To make the 'deallocate_output_tensor' (above) optimization work, the C++ - autograd engine must be called directly, bypassing Pytorch's - torch.autograd.backward. Pytorch's 'backward' checks that the output and - grad have the same shape, while C++'s 'backward' does not. - ''' - - assert output.numel() == 1, "output should be pseudo-'freed' in schedule, to optimize memory" - assert isinstance(output, torch.Tensor), "output == '%s'." % type(output).__name__ - assert isinstance(grad_output, (torch.Tensor, type(None))), ( - "grad_output == '%s'." % type(grad_output).__name__ - ) - - # Handle scalar output - if grad_output is None: - assert output.numel() == 1, "implicit grad requires scalar output." - grad_output = torch.ones_like(output, memory_format=torch.preserve_format,) - - # Call c++ engine [ see torch/csrc/autograd/python_engine.cpp ] - Variable._execution_engine.run_backward( - tensors=(output,), - grad_tensors=(grad_output,), - keep_graph=False, - create_graph=False, - inputs=tuple(), - allow_unreachable=True, - accumulate_grad=True, - ) - - -def forward_step( - forward_step_func, - data_iterator, - model, - num_microbatches, - input_tensor, - forward_data_store, - config, - collect_non_loss_data=False, - checkpoint_activations_microbatch=None, -): - """Forward step for passed-in model. - - If first stage, input tensor is obtained from data_iterator, otherwise - passed-in input_tensor is used. 
- - Returns output tensor.""" - if config.timers is not None: - config.timers('forward-compute', log_level=2).start() - - unwrap_output_tensor = False - if not isinstance(input_tensor, list): - input_tensor = [input_tensor] - unwrap_output_tensor = True - - set_input_tensor = get_attr_wrapped_model(model, "set_input_tensor") - set_input_tensor(input_tensor) - - if config.enable_autocast: - context_manager = torch.autocast("cuda", dtype=config.autocast_dtype) - else: - context_manager = contextlib.nullcontext() - with context_manager: - if checkpoint_activations_microbatch is None: - output_tensor, loss_func = forward_step_func(data_iterator, model) - else: - output_tensor, loss_func = forward_step_func( - data_iterator, model, checkpoint_activations_microbatch - ) - - if parallel_state.is_pipeline_last_stage(): - if not collect_non_loss_data: - output_tensor = loss_func(output_tensor) - loss, loss_reduced = output_tensor - output_tensor = loss / num_microbatches - forward_data_store.append(loss_reduced) - else: - data = loss_func(output_tensor, non_loss_data=True) - forward_data_store.append(data) - - if config.timers is not None: - config.timers('forward-compute').stop() - - # If T5 model (or other model with encoder and decoder) - # and in decoder stack, then send encoder_hidden_state - # downstream as well. - model_type = get_model_type(model) - if ( - parallel_state.is_pipeline_stage_after_split() - and model_type == ModelType.encoder_and_decoder - ): - return [output_tensor, input_tensor[-1]] - if unwrap_output_tensor: - return output_tensor - return [output_tensor] - - -def backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config): - """Backward step through passed-in output tensor. - - If last stage, output_tensor_grad is None, otherwise gradient of loss - with respect to stage's output tensor. - - Returns gradient of loss with respect to input tensor (None if first - stage).""" - - # NOTE: This code currently can handle at most one skip connection. It - # needs to be modified slightly to support arbitrary numbers of skip - # connections. - - if config.timers is not None: - config.timers('backward-compute', log_level=2).start() - - # Retain the grad on the input_tensor. - unwrap_input_tensor_grad = False - if not isinstance(input_tensor, list): - input_tensor = [input_tensor] - unwrap_input_tensor_grad = True - for x in input_tensor: - if x is not None: - x.retain_grad() - - if not isinstance(output_tensor, list): - output_tensor = [output_tensor] - if not isinstance(output_tensor_grad, list): - output_tensor_grad = [output_tensor_grad] - - # Backward pass. - if output_tensor_grad[0] is None and config.grad_scale_func is not None: - output_tensor[0] = config.grad_scale_func(output_tensor[0]) - - if config.deallocate_pipeline_outputs: - custom_backward(output_tensor[0], output_tensor_grad[0]) - else: - torch.autograd.backward(output_tensor[0], grad_tensors=output_tensor_grad[0]) - - # Collect the grad of the input_tensor. - input_tensor_grad = [None] - if input_tensor is not None: - input_tensor_grad = [] - for x in input_tensor: - if x is None: - input_tensor_grad.append(None) - else: - input_tensor_grad.append(x.grad) - - # Handle single skip connection if it exists (encoder_hidden_state in - # model with encoder and decoder). 
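Note: `backward_step` above relies on the `deallocate_output_tensor` / `custom_backward` pair defined earlier in this removed file. The sketch below reproduces the idea in isolation: after an output has been sent downstream, its `.data` payload is swapped for a one-element tensor while the autograd graph on `.grad_fn` is kept, and the C++ autograd engine is invoked directly because `torch.autograd.backward` would reject the shape mismatch. Tensor sizes here are assumed; the engine call mirrors the removed `custom_backward`.

```python
import torch
from torch.autograd.variable import Variable

# Sketch (assumed sizes): pseudo-free a pipeline output, then backprop through
# its grad_fn the way the removed custom_backward() did.
x = torch.randn(4, 3, requires_grad=True)
out = x * 2                                   # stand-in for a stage's output activation
grad_out = torch.ones(4, 3)                   # gradient arriving from the next stage

out.data = torch.empty((1,), dtype=out.dtype)  # keep grad_fn, drop the large payload

# Call the autograd engine directly (same call pattern as the removed code);
# it never compares grad_out against out.data, so the pseudo-free is harmless.
Variable._execution_engine.run_backward(
    tensors=(out,),
    grad_tensors=(grad_out,),
    keep_graph=False,
    create_graph=False,
    inputs=tuple(),
    allow_unreachable=True,
    accumulate_grad=True,
)
print(x.grad)  # == 2 * ones(4, 3): gradients still flow as usual
```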
- if ( - parallel_state.get_pipeline_model_parallel_world_size() > 1 - and parallel_state.is_pipeline_stage_after_split() - and model_type == ModelType.encoder_and_decoder - ): - if output_tensor_grad[1] is not None: - input_tensor_grad[-1].add_(output_tensor_grad[1]) - if unwrap_input_tensor_grad: - input_tensor_grad = input_tensor_grad[0] - - if config.timers is not None: - config.timers('backward-compute').stop() - - return input_tensor_grad - - -def forward_backward_no_pipelining( - *, - forward_step_func, - data_iterator: Union[Iterator, List[Iterator]], - model: Union[torch.nn.Module, List[torch.nn.Module]], - num_microbatches: int, - seq_length: int, # unused - micro_batch_size: int, # unused - decoder_seq_length: int = None, # unused - forward_only: bool = False, - collect_non_loss_data: bool = False, -): - """Run forward and backward passes with no pipeline parallelism - (no inter-stage communication). - - Returns dictionary with losses. - - - See get_forward_backward_func() for argument details - """ - - if isinstance(model, list): - assert len(model) == 1, "non-pipeline-parallel schedule does not support model chunking" - model = model[0] - if isinstance(data_iterator, list): - assert ( - len(data_iterator) == 1 - ), "non-pipeline-parallel schedule does not support model chunking" - data_iterator = data_iterator[0] - - config = get_model_config(model) - if config.timers is not None: - config.timers('forward-backward', log_level=1).start(barrier=config.barrier_with_L1_time) - - no_sync_func = config.no_sync_func - if no_sync_func is None: - no_sync_func = contextlib.nullcontext - - model_type = get_model_type(model) - - forward_data_store = [] - input_tensor, output_tensor_grad = None, None - with no_sync_func(): - for i in range(num_microbatches - 1): - output_tensor = forward_step( - forward_step_func, - data_iterator, - model, - num_microbatches, - input_tensor, - forward_data_store, - config, - collect_non_loss_data, - ) - if not forward_only: - backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config) - - # Run computation for last microbatch out of context handler (want to - # synchronize gradients). - output_tensor = forward_step( - forward_step_func, - data_iterator, - model, - num_microbatches, - input_tensor, - forward_data_store, - config, - collect_non_loss_data, - ) - - if not forward_only: - backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config) - - if config.timers is not None: - config.timers('forward-backward').stop() - - if config.finalize_model_grads_func is not None and not forward_only: - # Finalize model grads (perform full grad all-reduce / reduce-scatter for - # data parallelism and layernorm all-reduce for sequence parallelism). - config.finalize_model_grads_func([model]) - - return forward_data_store - - -def forward_backward_pipelining_with_interleaving( - *, - forward_step_func, - data_iterator: Union[Iterator, List[Iterator]], - model: Union[torch.nn.Module, List[torch.nn.Module]], - num_microbatches: int, - seq_length: int, - micro_batch_size: int, - decoder_seq_length: int = None, - forward_only: bool = False, - collect_non_loss_data: bool = False, -): - """Run interleaved 1F1B schedule (model split into model chunks), with - communication between pipeline stages as needed. 
- - Returns dictionary with losses if the last stage, empty dict otherwise.""" - assert isinstance(model, list), "interleaved pipeline parallelism expected model chunking" - assert all(isinstance(chunk, torch.nn.Module) for chunk in model), "invalid model chunking" - assert isinstance( - data_iterator, list - ), "interleaved pipeline parallelism expected each model chunk to have a data iterator" - - config = get_model_config(model[0]) - if config.overlap_p2p_comm and config.batch_p2p_comm: - raise ValueError("Can not use both overlap_p2p_comm and batch_p2p_comm") - - if config.timers is not None: - config.timers('forward-backward', log_level=1).start(barrier=config.barrier_with_L1_time) - - # Disable async grad reductions - no_sync_func = config.no_sync_func - if isinstance(no_sync_func, list): - - def multi_no_sync(): - stack = contextlib.ExitStack() - for model_chunk_no_sync_func in config.no_sync_func: - stack.enter_context(model_chunk_no_sync_func()) - return stack - - no_sync_func = multi_no_sync - if no_sync_func is None: - no_sync_func = contextlib.nullcontext - no_sync_context = None - - if config.grad_sync_func is not None and not isinstance(config.grad_sync_func, list): - config.grad_sync_func = [config.grad_sync_func for _ in model] - - if config.param_sync_func is not None and not isinstance(config.param_sync_func, list): - config.param_sync_func = [config.param_sync_func for _ in model] - - def disable_grad_sync(): - """Disable asynchronous grad reductions""" - nonlocal no_sync_context - if no_sync_context is None: - no_sync_context = no_sync_func() - no_sync_context.__enter__() - - def enable_grad_sync(): - """Enable asynchronous grad reductions""" - nonlocal no_sync_context - if no_sync_context is not None: - no_sync_context.__exit__(None, None, None) - no_sync_context = None - - disable_grad_sync() - - # Model chunk IDs with synchronized grads - synchronized_model_chunks = set() - - input_tensors = [[] for _ in range(len(model))] - output_tensors = [[] for _ in range(len(model))] - forward_data_store = [] - if not forward_only: - output_tensor_grads = [[] for _ in range(len(model))] - - pipeline_parallel_size = parallel_state.get_pipeline_model_parallel_world_size() - pipeline_parallel_rank = parallel_state.get_pipeline_model_parallel_rank() - - if num_microbatches % pipeline_parallel_size != 0: - msg = f'number of microbatches ({num_microbatches}) is not divisible by ' - msg += f'pipeline-model-parallel-size ({pipeline_parallel_size}) ' - msg += 'when using interleaved schedule' - raise RuntimeError(msg) - - model_type = get_model_type(model[0]) - if model_type == ModelType.encoder_and_decoder: - raise RuntimeError("Interleaving is not supported with an encoder and decoder model.") - - if decoder_seq_length is not None and decoder_seq_length != seq_length: - raise RuntimeError( - "Interleaving is not supported with a different decoder sequence length." - ) - - tensor_shape = [seq_length, micro_batch_size, config.hidden_size] - if config.sequence_parallel: - tensor_shape[0] = tensor_shape[0] // parallel_state.get_tensor_model_parallel_world_size() - - # Compute number of warmup and remaining microbatches. - num_model_chunks = len(model) - total_num_microbatches = num_microbatches * num_model_chunks - all_warmup_microbatches = False - if forward_only: - num_warmup_microbatches = total_num_microbatches - else: - # Run all forward passes and then all backward passes if number of - # microbatches is just the number of pipeline stages. 
- # Otherwise, perform (num_model_chunks-1)*pipeline_parallel_size on - # all workers, followed by more microbatches after depending on - # stage ID (more forward passes for earlier stages, later stages can - # immediately start with 1F1B). - if num_microbatches == pipeline_parallel_size: - num_warmup_microbatches = total_num_microbatches - all_warmup_microbatches = True - else: - num_warmup_microbatches = (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2 - num_warmup_microbatches += (num_model_chunks - 1) * pipeline_parallel_size - num_warmup_microbatches = min(num_warmup_microbatches, total_num_microbatches) - num_microbatches_remaining = total_num_microbatches - num_warmup_microbatches - - # Checkpoint the activations of partial Transformer layers in a number of micro-batches - # within the maximum outstanding micro-batch backpropagations. - # Micro-batches with the ids less than 'num_microbatches_with_partial_activation_checkpoints' - # checkpoint partial Transformer layers (or skip checkpointing) and - # the rest of micro-batches within a window of micro-batches checkpoint - # all Transformer layers. The window of micro-batches is set by the maximum - # outstanding backpropagations and becomes smaller at later pipeline stages. - # Please refer the appendix C in https://arxiv.org/pdf/2205.05198.pdf - max_outstanding_backprops = None - if config.num_microbatches_with_partial_activation_checkpoints is not None: - max_outstanding_backprops = num_warmup_microbatches + 1 - - # Synchronize params for first two model chunks - if config.param_sync_func is not None: - config.param_sync_func[0](model[0].parameters()) - config.param_sync_func[1](model[1].parameters()) - - def get_model_chunk_id(microbatch_id, forward): - """Helper method to get the model chunk ID given the iteration number.""" - microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks) - model_chunk_id = microbatch_id_in_group // pipeline_parallel_size - if not forward: - model_chunk_id = num_model_chunks - model_chunk_id - 1 - return model_chunk_id - - def is_first_microbatch_for_model_chunk(microbatch_id: int) -> bool: - """Check if an iteration is the first for a model chunk.""" - microbatch_group_size = pipeline_parallel_size * num_model_chunks - num_microbatch_groups = total_num_microbatches // microbatch_group_size - microbatch_group_id = microbatch_id // microbatch_group_size - microbatch_id_in_group = microbatch_id % microbatch_group_size - if microbatch_group_id == 0: - return microbatch_id_in_group % pipeline_parallel_size == 0 - else: - return False - - def is_last_microbatch_for_model_chunk(microbatch_id: int) -> bool: - """Check if an iteration is the last for a model chunk.""" - microbatch_group_size = pipeline_parallel_size * num_model_chunks - num_microbatch_groups = total_num_microbatches // microbatch_group_size - microbatch_group_id = microbatch_id // microbatch_group_size - microbatch_id_in_group = microbatch_id % microbatch_group_size - if microbatch_group_id == num_microbatch_groups - 1: - return microbatch_id_in_group % pipeline_parallel_size == pipeline_parallel_size - 1 - else: - return False - - def forward_step_helper(microbatch_id, checkpoint_activations_microbatch): - """Helper method to run forward step with model split into chunks - (run set_virtual_pipeline_model_parallel_rank() before calling - forward_step()).""" - model_chunk_id = get_model_chunk_id(microbatch_id, forward=True) - parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id) - - # 
launch param synchronization for next model chunk - # Note: Asynchronous communication tends to slow down compute. - # To reduce idling from mismatched microbatch times, we launch - # asynchronous communication at the same time across the - # pipeline-parallel group. - if config.param_sync_func is not None: - param_sync_microbatch_id = microbatch_id + pipeline_parallel_rank - if ( - param_sync_microbatch_id < total_num_microbatches - and is_first_microbatch_for_model_chunk(param_sync_microbatch_id) - ): - param_sync_chunk_id = get_model_chunk_id(param_sync_microbatch_id, forward=True) + 1 - if 1 < param_sync_chunk_id < num_model_chunks: - config.param_sync_func[param_sync_chunk_id]( - model[param_sync_chunk_id].parameters() - ) - - # forward step - if parallel_state.is_pipeline_first_stage(): - if len(input_tensors[model_chunk_id]) == len(output_tensors[model_chunk_id]): - input_tensors[model_chunk_id].append(None) - input_tensor = input_tensors[model_chunk_id][-1] - output_tensor = forward_step( - forward_step_func, - data_iterator[model_chunk_id], - model[model_chunk_id], - num_microbatches, - input_tensor, - forward_data_store, - config, - collect_non_loss_data, - checkpoint_activations_microbatch, - ) - output_tensors[model_chunk_id].append(output_tensor) - - # if forward-only, no need to save tensors for a backward pass - if forward_only: - input_tensors[model_chunk_id].pop() - output_tensors[model_chunk_id].pop() - - return output_tensor - - def backward_step_helper(microbatch_id): - """Helper method to run backward step with model split into chunks - (run set_virtual_pipeline_model_parallel_rank() before calling - backward_step()).""" - model_chunk_id = get_model_chunk_id(microbatch_id, forward=False) - parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id) - - # launch grad synchronization (default) - if config.grad_sync_func is None and is_last_microbatch_for_model_chunk(microbatch_id): - enable_grad_sync() - synchronized_model_chunks.add(model_chunk_id) - - if parallel_state.is_pipeline_last_stage(): - if len(output_tensor_grads[model_chunk_id]) == 0: - output_tensor_grads[model_chunk_id].append(None) - input_tensor = input_tensors[model_chunk_id].pop(0) - output_tensor = output_tensors[model_chunk_id].pop(0) - output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0) - input_tensor_grad = backward_step( - input_tensor, output_tensor, output_tensor_grad, model_type, config - ) - - # launch grad synchronization (custom grad sync) - # Note: Asynchronous communication tends to slow down compute. - # To reduce idling from mismatched microbatch times, we launch - # asynchronous communication at the same time across the - # pipeline-parallel group. - if config.grad_sync_func is not None: - grad_sync_microbatch_id = microbatch_id - pipeline_parallel_rank - if grad_sync_microbatch_id >= 0 and is_last_microbatch_for_model_chunk( - grad_sync_microbatch_id - ): - grad_sync_chunk_id = get_model_chunk_id(grad_sync_microbatch_id, forward=False) - enable_grad_sync() - config.grad_sync_func[grad_sync_chunk_id](model[grad_sync_chunk_id].parameters()) - synchronized_model_chunks.add(grad_sync_chunk_id) - disable_grad_sync() - - return input_tensor_grad - - # Run warmup forward passes. 
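Note: the warmup bookkeeping of the removed interleaved (virtual-pipeline) schedule is compact enough to misread, so here is a standalone sketch of the two pieces defined above: the microbatch-to-model-chunk mapping and the per-rank warmup count. Sizes are assumed; the arithmetic mirrors the removed code.

```python
# Sketch (assumed sizes, logic mirrored from the removed interleaved schedule).
pipeline_parallel_size = 4   # assumed
num_model_chunks = 2         # assumed (virtual pipeline size)
num_microbatches = 8         # assumed; must be divisible by pipeline_parallel_size
total_num_microbatches = num_microbatches * num_model_chunks

def get_model_chunk_id(microbatch_id, forward):
    # Microbatches are grouped into blocks of (pipeline size * num chunks);
    # each block walks through the model chunks in order (reversed for backward).
    microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks)
    model_chunk_id = microbatch_id_in_group // pipeline_parallel_size
    if not forward:
        model_chunk_id = num_model_chunks - model_chunk_id - 1
    return model_chunk_id

for rank in range(pipeline_parallel_size):
    # (The removed code special-cases num_microbatches == pipeline_parallel_size,
    #  in which case every microbatch is treated as warmup.)
    warmup = (pipeline_parallel_size - rank - 1) * 2 + (num_model_chunks - 1) * pipeline_parallel_size
    warmup = min(warmup, total_num_microbatches)
    print(f"rank {rank}: {warmup} warmup microbatches, "
          f"{total_num_microbatches - warmup} in steady-state 1F1B")

print([get_model_chunk_id(k, forward=True) for k in range(total_num_microbatches)])
```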
- parallel_state.set_virtual_pipeline_model_parallel_rank(0) - input_tensors[0].append(p2p_communication.recv_forward(tensor_shape, config)) - - fwd_wait_handles = None - bwd_wait_handles = None - - for k in range(num_warmup_microbatches): - - if fwd_wait_handles is not None: - for req in fwd_wait_handles: - req.wait() - - # Decide to checkpoint all layers' activations of the current micro-batch - if max_outstanding_backprops is not None: - checkpoint_activations_microbatch = ( - k % max_outstanding_backprops - >= config.num_microbatches_with_partial_activation_checkpoints - ) - else: - checkpoint_activations_microbatch = None - - output_tensor = forward_step_helper(k, checkpoint_activations_microbatch) - - # Determine if tensor should be received from previous stage. - next_forward_model_chunk_id = get_model_chunk_id(k + 1, forward=True) - recv_prev = True - if parallel_state.is_pipeline_first_stage(ignore_virtual=True): - if next_forward_model_chunk_id == 0: - recv_prev = False - if k == (total_num_microbatches - 1): - recv_prev = False - - # Don't send tensor downstream if on last stage. - if parallel_state.is_pipeline_last_stage(): - output_tensor = None - - # Send and receive tensors as appropriate (send tensors computed - # in this iteration; receive tensors for next iteration). - if not config.overlap_p2p_comm: - if ( - k == (num_warmup_microbatches - 1) - and not forward_only - and not all_warmup_microbatches - ): - input_tensor_grad = None - recv_next = True - if parallel_state.is_pipeline_last_stage(ignore_virtual=True): - recv_next = False - ( - input_tensor, - output_tensor_grad, - ) = p2p_communication.send_forward_backward_recv_forward_backward( - output_tensor, - input_tensor_grad, - recv_prev=recv_prev, - recv_next=recv_next, - tensor_shape=tensor_shape, - config=config, - ) - output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad) - else: - input_tensor = p2p_communication.send_forward_recv_forward( - output_tensor, recv_prev=recv_prev, tensor_shape=tensor_shape, config=config - ) - input_tensors[next_forward_model_chunk_id].append(input_tensor) - else: - input_tensor, fwd_wait_handles = p2p_communication.send_forward_recv_forward( - output_tensor, - recv_prev=recv_prev, - tensor_shape=tensor_shape, - config=config, - overlap_p2p_comm=True, - ) - - if ( - k == (num_warmup_microbatches - 1) - and not forward_only - and not all_warmup_microbatches - ): - input_tensor_grad = None - recv_next = True - if parallel_state.is_pipeline_last_stage(ignore_virtual=True): - recv_next = False - - ( - output_tensor_grad, - bwd_wait_handles, - ) = p2p_communication.send_backward_recv_backward( - input_tensor_grad, - recv_next=recv_next, - tensor_shape=tensor_shape, - config=config, - overlap_p2p_comm=True, - ) - - output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad) - input_tensors[next_forward_model_chunk_id].append(input_tensor) - - deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs) - - # Run 1F1B in steady state. - for k in range(num_microbatches_remaining): - # Forward pass. 
- forward_k = k + num_warmup_microbatches - - # Decide to checkpoint all layers' activations of the current micro-batch - if max_outstanding_backprops is not None: - checkpoint_activations_microbatch = ( - forward_k % max_outstanding_backprops - >= config.num_microbatches_with_partial_activation_checkpoints - ) - else: - checkpoint_activations_microbatch = None - - if config.overlap_p2p_comm: - if fwd_wait_handles is not None: - for req in fwd_wait_handles: - req.wait() - - deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs) - - output_tensor = forward_step_helper(forward_k, checkpoint_activations_microbatch) - - # Determine if current stage has anything to send in either direction, - # otherwise set tensor to None. - forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True) - parallel_state.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id) - - # Last virtual stage no activation tensor to send - if parallel_state.is_pipeline_last_stage(): - output_tensor = None - - # Determine if peers are sending, and where in data structure to put - # received tensors. - recv_prev = True - if parallel_state.is_pipeline_first_stage(ignore_virtual=True): - # First stage is ahead of last stage by (pipeline_parallel_size - 1). - next_forward_model_chunk_id = get_model_chunk_id( - forward_k - (pipeline_parallel_size - 1), forward=True - ) - if next_forward_model_chunk_id == (num_model_chunks - 1): - recv_prev = False - next_forward_model_chunk_id += 1 - else: - next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1, forward=True) - - # If last iteration, don't receive; we already received one extra - # before the start of the for loop. - if k == (num_microbatches_remaining - 1): - recv_prev = False - - # Send activation tensor to the next stage and receive activation tensor from the - # previous stage - input_tensor, fwd_wait_handles = p2p_communication.send_forward_recv_forward( - output_tensor, - recv_prev=recv_prev, - tensor_shape=tensor_shape, - config=config, - overlap_p2p_comm=True, - ) - # assert fwd_wait_handles is not None - - if bwd_wait_handles is not None: - for req in bwd_wait_handles: - req.wait() - - # Backward pass. - backward_k = k - input_tensor_grad = backward_step_helper(backward_k) - - backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False) - parallel_state.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id) - - # First virtual stage no activation gradient tensor to send - if parallel_state.is_pipeline_first_stage(): - input_tensor_grad = None - - # Determine if the current virtual stage has an activation gradient tensor to receive - recv_next = True - if parallel_state.is_pipeline_last_stage(ignore_virtual=True): - # Last stage is ahead of first stage by (pipeline_parallel_size - 1). - next_backward_model_chunk_id = get_model_chunk_id( - backward_k - (pipeline_parallel_size - 1), forward=False - ) - if next_backward_model_chunk_id == 0: - recv_next = False - next_backward_model_chunk_id -= 1 - else: - next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1, forward=False) - - output_tensor_grad, bwd_wait_handles = p2p_communication.send_backward_recv_backward( - input_tensor_grad, - recv_next=recv_next, - tensor_shape=tensor_shape, - config=config, - overlap_p2p_comm=True, - ) - - else: # no p2p overlap - output_tensor = forward_step_helper(forward_k, checkpoint_activations_microbatch) - - # Backward pass. 
- backward_k = k - input_tensor_grad = backward_step_helper(backward_k) - - # Send output_tensor and input_tensor_grad, receive input_tensor - # and output_tensor_grad. - - # Determine if current stage has anything to send in either direction, - # otherwise set tensor to None. - forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True) - parallel_state.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id) - if parallel_state.is_pipeline_last_stage(): - output_tensor = None - - backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False) - parallel_state.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id) - if parallel_state.is_pipeline_first_stage(): - input_tensor_grad = None - - # Determine if peers are sending, and where in data structure to put - # received tensors. - recv_prev = True - if parallel_state.is_pipeline_first_stage(ignore_virtual=True): - # First stage is ahead of last stage by (pipeline_parallel_size - 1). - next_forward_model_chunk_id = get_model_chunk_id( - forward_k - (pipeline_parallel_size - 1), forward=True - ) - if next_forward_model_chunk_id == (num_model_chunks - 1): - recv_prev = False - next_forward_model_chunk_id += 1 - else: - next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1, forward=True) - - recv_next = True - if parallel_state.is_pipeline_last_stage(ignore_virtual=True): - # Last stage is ahead of first stage by (pipeline_parallel_size - 1). - next_backward_model_chunk_id = get_model_chunk_id( - backward_k - (pipeline_parallel_size - 1), forward=False - ) - if next_backward_model_chunk_id == 0: - recv_next = False - next_backward_model_chunk_id -= 1 - else: - next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1, forward=False) - - # If last iteration, don't receive; we already received one extra - # before the start of the for loop. - if k == (num_microbatches_remaining - 1): - recv_prev = False - - # Communicate tensors. - ( - input_tensor, - output_tensor_grad, - ) = p2p_communication.send_forward_backward_recv_forward_backward( - output_tensor, - input_tensor_grad, - recv_prev=recv_prev, - recv_next=recv_next, - tensor_shape=tensor_shape, - config=config, - ) - deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs) - - # Put input_tensor and output_tensor_grad in data structures in the - # right location. - if recv_prev: - input_tensors[next_forward_model_chunk_id].append(input_tensor) - if recv_next: - output_tensor_grads[next_backward_model_chunk_id].append(output_tensor_grad) - - deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs) - - # Run cooldown backward passes (flush out pipeline). 
- if not forward_only: - if config.overlap_p2p_comm and bwd_wait_handles is not None: - for wait_handle in bwd_wait_handles: - wait_handle.wait() - - if all_warmup_microbatches: - output_tensor_grads[num_model_chunks - 1].append( - p2p_communication.recv_backward(tensor_shape, config=config) - ) - for k in range(num_microbatches_remaining, total_num_microbatches): - input_tensor_grad = backward_step_helper(k) - next_backward_model_chunk_id = get_model_chunk_id(k + 1, forward=False) - recv_next = True - if parallel_state.is_pipeline_last_stage(ignore_virtual=True): - if next_backward_model_chunk_id == (num_model_chunks - 1): - recv_next = False - if k == (total_num_microbatches - 1): - recv_next = False - output_tensor_grads[next_backward_model_chunk_id].append( - p2p_communication.send_backward_recv_backward( - input_tensor_grad, recv_next=recv_next, tensor_shape=tensor_shape, config=config - ) - ) - - # Launch any remaining grad reductions. - enable_grad_sync() - if config.grad_sync_func is not None: - for model_chunk_id in range(num_model_chunks): - if model_chunk_id not in synchronized_model_chunks: - config.grad_sync_func[model_chunk_id](model[model_chunk_id].parameters()) - synchronized_model_chunks.add(model_chunk_id) - - if config.timers is not None: - config.timers('forward-backward').stop() - - if config.finalize_model_grads_func is not None and not forward_only: - # Finalize model grads (perform full grad all-reduce / reduce-scatter for - # data parallelism, layernorm all-reduce for sequence parallelism, and - # embedding all-reduce for pipeline parallelism). - config.finalize_model_grads_func(model) - - return forward_data_store - - -def get_tensor_shapes( - *, - rank: int, - model_type: ModelType, - seq_length: int, - micro_batch_size: int, - decoder_seq_length: int, - config, -): - # Determine right tensor sizes (based on position of rank with respect to split - # rank) and model size. - # Send two tensors if model is T5 and rank is in decoder stage: - # first tensor is decoder (pre-transpose), - # second tensor is encoder (post-transpose). - # If model is T5 and rank is at the boundary: - # send one tensor (post-transpose from encoder). - # Otherwise, send one tensor (pre-transpose). 
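Note: a quick numeric illustration of what `get_tensor_shapes` (implemented just below) produces for a decoder-only model when sequence parallelism shards the sequence dimension. All values here are assumed examples.

```python
# Sketch (assumed values, not from the patch): the p2p buffer shape for a
# decoder-only model with sequence parallelism enabled.
seq_length = 4096                     # assumed global sequence length
micro_batch_size = 2                  # assumed
hidden_size = 5120                    # assumed
tensor_model_parallel_world_size = 8  # assumed
sequence_parallel = True              # assumed

if sequence_parallel:
    # Each tensor-parallel rank only holds its shard of the sequence dimension.
    seq_length = seq_length // tensor_model_parallel_world_size  # 512

tensor_shapes = [(seq_length, micro_batch_size, hidden_size)]
print(tensor_shapes)  # [(512, 2, 5120)] -- the shape every send/recv allocates
```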
- tensor_shapes = [] - - if config.sequence_parallel: - seq_length = seq_length // parallel_state.get_tensor_model_parallel_world_size() - if model_type == ModelType.encoder_and_decoder: - decoder_seq_length = ( - decoder_seq_length // parallel_state.get_tensor_model_parallel_world_size() - ) - - if model_type == ModelType.encoder_and_decoder: - if parallel_state.is_pipeline_stage_before_split(rank): - tensor_shapes.append((seq_length, micro_batch_size, config.hidden_size)) - else: - tensor_shapes.append((decoder_seq_length, micro_batch_size, config.hidden_size)) - tensor_shapes.append((seq_length, micro_batch_size, config.hidden_size)) - else: - tensor_shapes.append((seq_length, micro_batch_size, config.hidden_size)) - return tensor_shapes - - -def recv_forward(tensor_shapes, config): - input_tensors = [] - for tensor_shape in tensor_shapes: - if tensor_shape is None: - input_tensors.append(None) - else: - input_tensors.append(p2p_communication.recv_forward(tensor_shape, config)) - return input_tensors - - -def recv_backward(tensor_shapes, config): - output_tensor_grads = [] - for tensor_shape in tensor_shapes: - if tensor_shape is None: - output_tensor_grads.append(None) - else: - output_tensor_grads.append(p2p_communication.recv_backward(tensor_shape, config)) - return output_tensor_grads - - -def send_forward(output_tensors, tensor_shapes, config): - if not isinstance(output_tensors, list): - output_tensors = [output_tensors] - for (output_tensor, tensor_shape) in zip(output_tensors, tensor_shapes): - if tensor_shape is None: - continue - p2p_communication.send_forward(output_tensor, config) - - -def send_backward(input_tensor_grads, tensor_shapes, config): - if not isinstance(input_tensor_grads, list): - input_tensor_grads = [input_tensor_grads] - for (input_tensor_grad, tensor_shape) in zip(input_tensor_grads, tensor_shapes): - if tensor_shape is None: - continue - p2p_communication.send_backward(input_tensor_grad, config) - - -def send_forward_recv_backward(output_tensors, tensor_shapes, config): - if not isinstance(output_tensors, list): - output_tensors = [output_tensors] - output_tensor_grads = [] - for (output_tensor, tensor_shape) in zip(output_tensors, tensor_shapes): - if tensor_shape is None: - output_tensor_grads.append(None) - continue - output_tensor_grad = p2p_communication.send_forward_recv_backward( - output_tensor, tensor_shape, config - ) - output_tensor_grads.append(output_tensor_grad) - return output_tensor_grads - - -def send_backward_recv_forward(input_tensor_grads, tensor_shapes, config): - if not isinstance(input_tensor_grads, list): - input_tensor_grads = [input_tensor_grads] - input_tensors = [] - for (input_tensor_grad, tensor_shape) in zip(input_tensor_grads, tensor_shapes): - if tensor_shape is None: - input_tensors.append(None) - continue - input_tensor = p2p_communication.send_backward_recv_forward( - input_tensor_grad, tensor_shape, config - ) - input_tensors.append(input_tensor) - return input_tensors - - -def forward_backward_pipelining_without_interleaving( - *, - forward_step_func, - data_iterator: Union[Iterator, List[Iterator]], - model: Union[torch.nn.Module, List[torch.nn.Module]], - num_microbatches: int, - seq_length: int, - micro_batch_size: int, - decoder_seq_length: int = None, - forward_only: bool = False, - collect_non_loss_data: bool = False, -): - """Run non-interleaved 1F1B schedule, with communication between pipeline - stages. 
- - Returns dictionary with losses if the last stage, empty dict otherwise.""" - - if isinstance(model, list): - assert ( - len(model) == 1 - ), "non-interleaved pipeline parallelism does not support model chunking" - model = model[0] - if isinstance(data_iterator, list): - assert ( - len(data_iterator) == 1 - ), "non-pipeline-parallel schedule does not support model chunking" - data_iterator = data_iterator[0] - - config = get_model_config(model) - if config.overlap_p2p_comm: - raise ValueError( - "Non-interleaved pipeline parallelism does not support overlapping p2p communication" - ) - - if config.timers is not None: - config.timers('forward-backward', log_level=1).start(barrier=config.barrier_with_L1_time) - - # Disable async grad reductions - no_sync_func = config.no_sync_func - if no_sync_func is None: - no_sync_func = contextlib.nullcontext - no_sync_context = None - - def disable_grad_sync(): - """Disable asynchronous grad reductions""" - nonlocal no_sync_context - if no_sync_context is None: - no_sync_context = no_sync_func() - no_sync_context.__enter__() - - def enable_grad_sync(): - """Enable asynchronous grad reductions""" - nonlocal no_sync_context - if no_sync_context is not None: - no_sync_context.__exit__(None, None, None) - no_sync_context = None - - disable_grad_sync() - - # Compute number of warmup microbatches. - num_warmup_microbatches = ( - parallel_state.get_pipeline_model_parallel_world_size() - - parallel_state.get_pipeline_model_parallel_rank() - - 1 - ) - num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches) - num_microbatches_remaining = num_microbatches - num_warmup_microbatches - - # Checkpoint the activations of partial Transformer layers in a number of micro-batches - # within the maximum outstanding micro-batch backpropagations. - # Micro-batches with the ids less than 'num_microbatches_with_partial_activation_checkpoints' - # checkpoint partial Transformer layers (or skip checkpointing) and - # the rest of micro-batches within a window of micro-batches checkpoint - # all Transformer layers. The window of micro-batches is set by the maximum - # outstanding backpropagations and becomes smaller at later pipeline stages. - # Please refer the appendix C in https://arxiv.org/pdf/2205.05198.pdf - max_outstanding_backprops = None - if config.num_microbatches_with_partial_activation_checkpoints is not None: - max_outstanding_backprops = num_warmup_microbatches + 1 - - model_type = get_model_type(model) - - rank = parallel_state.get_pipeline_model_parallel_rank() - recv_tensor_shapes = get_tensor_shapes( - rank=rank - 1, - model_type=model_type, - seq_length=seq_length, - micro_batch_size=micro_batch_size, - decoder_seq_length=decoder_seq_length, - config=config, - ) - send_tensor_shapes = get_tensor_shapes( - rank=rank, - model_type=model_type, - seq_length=seq_length, - micro_batch_size=micro_batch_size, - decoder_seq_length=decoder_seq_length, - config=config, - ) - - # Input, output tensors only need to be saved when doing backward passes - input_tensors = None - output_tensors = None - if not forward_only: - input_tensors = [] - output_tensors = [] - forward_data_store = [] - - # Run warmup forward passes. 
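Note: for the non-interleaved schedule above, the split of microbatches into warmup, steady-state 1F1B, and cooldown is easiest to see with numbers. A small sketch (pipeline size and microbatch count are assumed; the arithmetic is taken from the removed code):

```python
# Sketch (assumed sizes): per-rank phase split in the removed non-interleaved
# 1F1B schedule. Earlier pipeline stages run more warmup forwards; each warmup
# forward still owes a backward, which runs during the cooldown phase.
pipeline_parallel_size = 4   # assumed
num_microbatches = 8         # assumed

for rank in range(pipeline_parallel_size):
    num_warmup = min(pipeline_parallel_size - rank - 1, num_microbatches)
    num_remaining = num_microbatches - num_warmup
    print(f"rank {rank}: warmup={num_warmup}, steady 1F1B={num_remaining}, cooldown={num_warmup}")
```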
- for i in range(num_warmup_microbatches): - # Decide to checkpoint all layers' activations of the current micro-batch - if max_outstanding_backprops is not None: - checkpoint_activations_microbatch = ( - i % max_outstanding_backprops - >= config.num_microbatches_with_partial_activation_checkpoints - ) - else: - checkpoint_activations_microbatch = None - - input_tensor = recv_forward(recv_tensor_shapes, config) - output_tensor = forward_step( - forward_step_func, - data_iterator, - model, - num_microbatches, - input_tensor, - forward_data_store, - config, - collect_non_loss_data, - checkpoint_activations_microbatch, - ) - send_forward(output_tensor, send_tensor_shapes, config) - - if not forward_only: - input_tensors.append(input_tensor) - output_tensors.append(output_tensor) - deallocate_output_tensor(output_tensor[0], config.deallocate_pipeline_outputs) - - # Before running 1F1B, need to receive first forward tensor. - # If all microbatches are run in warmup / cooldown phase, then no need to - # receive this tensor here. - if num_microbatches_remaining > 0: - input_tensor = recv_forward(recv_tensor_shapes, config) - - # Run 1F1B in steady state. - for i in range(num_microbatches_remaining): - last_iteration = i == (num_microbatches_remaining - 1) - - # Decide to checkpoint all layers' activations of the current micro-batch - if max_outstanding_backprops is not None: - checkpoint_activations_microbatch = ( - (i + num_warmup_microbatches) % max_outstanding_backprops - ) >= config.num_microbatches_with_partial_activation_checkpoints - else: - checkpoint_activations_microbatch = None - - output_tensor = forward_step( - forward_step_func, - data_iterator, - model, - num_microbatches, - input_tensor, - forward_data_store, - config, - collect_non_loss_data, - checkpoint_activations_microbatch, - ) - - if forward_only: - send_forward(output_tensor, send_tensor_shapes, config) - - if not last_iteration: - input_tensor = recv_forward(recv_tensor_shapes, config) - - else: - output_tensor_grad = send_forward_recv_backward( - output_tensor, send_tensor_shapes, config - ) - - # Add input_tensor and output_tensor to end of list. - input_tensors.append(input_tensor) - output_tensors.append(output_tensor) - deallocate_output_tensor(output_tensor[0], config.deallocate_pipeline_outputs) - - # Pop input_tensor and output_tensor from the start of the list for - # the backward pass. - input_tensor = input_tensors.pop(0) - output_tensor = output_tensors.pop(0) - - # Enable grad sync for the last microbatch in the batch if the full - # backward pass completes in the 1F1B stage. - if num_warmup_microbatches == 0 and last_iteration: - if config.grad_sync_func is None or rank == 0: - enable_grad_sync() - - input_tensor_grad = backward_step( - input_tensor, output_tensor, output_tensor_grad, model_type, config - ) - - if last_iteration: - input_tensor = None - send_backward(input_tensor_grad, recv_tensor_shapes, config) - else: - input_tensor = send_backward_recv_forward( - input_tensor_grad, recv_tensor_shapes, config - ) - - # Run cooldown backward passes. - if not forward_only: - for i in range(num_warmup_microbatches): - - # Enable async grad reduction in the last backward pass - # Note: If grad sync function is provided, only enable - # async grad reduction in first pipeline stage. Other - # pipeline stages do grad reduction during pipeline - # bubble. 
- if i == num_warmup_microbatches - 1: - if config.grad_sync_func is None or rank == 0: - enable_grad_sync() - - input_tensor = input_tensors.pop(0) - output_tensor = output_tensors.pop(0) - - output_tensor_grad = recv_backward(send_tensor_shapes, config) - - input_tensor_grad = backward_step( - input_tensor, output_tensor, output_tensor_grad, model_type, config - ) - - send_backward(input_tensor_grad, recv_tensor_shapes, config) - - # Launch any remaining grad reductions. - if no_sync_context is not None: - enable_grad_sync() - if config.grad_sync_func is not None: - config.grad_sync_func(model.parameters()) - - if config.timers is not None: - config.timers('forward-backward').stop() - - if config.finalize_model_grads_func is not None and not forward_only: - # Finalize model grads (perform full grad all-reduce / reduce-scatter for - # data parallelism, layernorm all-reduce for sequence parallelism, and - # embedding all-reduce for pipeline parallelism). - config.finalize_model_grads_func([model]) - - return forward_data_store diff --git a/megatron/core/requirements.txt b/megatron/core/requirements.txt deleted file mode 100644 index 08ed5eeb4b9f080b780db7d3e0af6712866c0493..0000000000000000000000000000000000000000 --- a/megatron/core/requirements.txt +++ /dev/null @@ -1 +0,0 @@ -torch \ No newline at end of file diff --git a/megatron/core/tensor_parallel/__init__.py b/megatron/core/tensor_parallel/__init__.py deleted file mode 100644 index c8040e9e84305bab5577befd73691a019130bd1c..0000000000000000000000000000000000000000 --- a/megatron/core/tensor_parallel/__init__.py +++ /dev/null @@ -1,65 +0,0 @@ -from .cross_entropy import vocab_parallel_cross_entropy -from .data import broadcast_data -from .layers import ( - ColumnParallelLinear, - RowParallelLinear, - VocabParallelEmbedding, - copy_tensor_model_parallel_attributes, - linear_with_grad_accumulation_and_async_allreduce, - param_is_not_tensor_parallel_duplicate, - set_defaults_if_not_set_tensor_model_parallel_attributes, - set_tensor_model_parallel_attributes, -) -from .mappings import ( - copy_to_tensor_model_parallel_region, - gather_from_sequence_parallel_region, - gather_from_sequence_parallel_region_to_moe, - gather_from_tensor_model_parallel_region, - reduce_scatter_to_sequence_parallel_region_from_moe, - scatter_to_sequence_parallel_region, - scatter_to_tensor_model_parallel_region, -) -from .random import ( - checkpoint, - get_cuda_rng_tracker, - get_data_parallel_rng_tracker_name, - model_parallel_cuda_manual_seed, -) -from .utils import ( - gather_split_1d_tensor, - split_tensor_along_last_dim, - split_tensor_into_1d_equal_chunks, -) - -__all__ = [ - # cross_entropy.py - "vocab_parallel_cross_entropy", - # data.py - "broadcast_data", - # layers.py - "ColumnParallelLinear", - "RowParallelLinear", - "VocabParallelEmbedding", - "set_tensor_model_parallel_attributes", - "set_defaults_if_not_set_tensor_model_parallel_attributes", - "copy_tensor_model_parallel_attributes", - "param_is_not_tensor_parallel_duplicate", - "linear_with_grad_accumulation_and_async_allreduce", - # mappings.py - "copy_to_tensor_model_parallel_region", - "gather_from_tensor_model_parallel_region", - "gather_from_sequence_parallel_region", - # "reduce_from_tensor_model_parallel_region", - "scatter_to_tensor_model_parallel_region", - "scatter_to_sequence_parallel_region", - # random.py - "checkpoint", - "get_cuda_rng_tracker", - "model_parallel_cuda_manual_seed", - # utils.py - "split_tensor_along_last_dim", - "split_tensor_into_1d_equal_chunks", - 
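The deleted `forward_backward_pipelining_without_interleaving` above drives the classic 1F1B schedule: warmup forwards, a steady one-forward/one-backward phase, then cooldown backwards. A small sketch of just the microbatch bookkeeping, with pipeline ranks simulated in a loop (the counts mirror the deleted code; the example sizes are illustrative):

```python
# Microbatch bookkeeping of the deleted non-interleaved 1F1B schedule:
# each pipeline rank runs some warmup forwards, then alternates one
# forward / one backward, then drains the remaining backwards.

def one_f_one_b_counts(pp_world_size, pp_rank, num_microbatches):
    num_warmup = min(pp_world_size - pp_rank - 1, num_microbatches)
    num_remaining = num_microbatches - num_warmup          # steady-state 1F1B pairs
    return num_warmup, num_remaining


num_microbatches, pp_world_size = 8, 4
for rank in range(pp_world_size):
    warmup, steady = one_f_one_b_counts(pp_world_size, rank, num_microbatches)
    # the cooldown backward count equals the warmup forward count
    print(f"pp rank {rank}: {warmup} warmup F, {steady} 1F1B pairs, {warmup} cooldown B")
# The last stage (rank 3) needs no warmup, so it starts the first backward
# immediately; earlier stages keep at most (warmup + 1) activations in flight.
```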
"gather_split_1d_tensor", - "gather_from_sequence_parallel_region_to_moe", - "reduce_scatter_to_sequence_parallel_region_from_moe", -] diff --git a/megatron/core/tensor_parallel/cross_entropy.py b/megatron/core/tensor_parallel/cross_entropy.py deleted file mode 100644 index 645fd1ea0c3673299f354c610c36e27e3345c0bc..0000000000000000000000000000000000000000 --- a/megatron/core/tensor_parallel/cross_entropy.py +++ /dev/null @@ -1,142 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. - -import torch - -from megatron.core.parallel_state import ( - get_tensor_model_parallel_group, - get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, -) - -from .utils import VocabUtility - - -class _VocabParallelCrossEntropy(torch.autograd.Function): - @staticmethod - def forward(ctx, vocab_parallel_logits, target, label_smoothing=0.0): - - # Maximum value along vocab dimension across all GPUs. - logits_max = torch.max(vocab_parallel_logits, dim=-1)[0] - torch.distributed.all_reduce( - logits_max, op=torch.distributed.ReduceOp.MAX, group=get_tensor_model_parallel_group() - ) - # Subtract the maximum value. - vocab_parallel_logits = vocab_parallel_logits - logits_max.unsqueeze(dim=-1) - - # Get the partition's vocab indecies - get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size - partition_vocab_size = vocab_parallel_logits.size()[-1] - rank = get_tensor_model_parallel_rank() - world_size = get_tensor_model_parallel_world_size() - vocab_start_index, vocab_end_index = get_vocab_range(partition_vocab_size, rank, world_size) - - # Create a mask of valid vocab ids (1 means it needs to be masked). - target_mask = (target < vocab_start_index) | (target >= vocab_end_index) - masked_target = target.clone() - vocab_start_index - masked_target[target_mask] = 0 - - # Get predicted-logits = logits[target]. - # For Simplicity, we convert logits to a 2-D tensor with size - # [*, partition-vocab-size] and target to a 1-D tensor of size [*]. - logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size) - masked_target_1d = masked_target.view(-1) - arange_1d = torch.arange(start=0, end=logits_2d.size()[0], device=logits_2d.device) - predicted_logits_1d = logits_2d[arange_1d, masked_target_1d] - predicted_logits_1d = predicted_logits_1d.clone().contiguous() - predicted_logits = predicted_logits_1d.view_as(target) - predicted_logits[target_mask] = 0.0 - # All reduce is needed to get the chunks from other GPUs. - torch.distributed.all_reduce( - predicted_logits, - op=torch.distributed.ReduceOp.SUM, - group=get_tensor_model_parallel_group(), - ) - - # Sum of exponential of logits along vocab dimension across all GPUs. - exp_logits = vocab_parallel_logits - torch.exp(vocab_parallel_logits, out=exp_logits) - sum_exp_logits = exp_logits.sum(dim=-1) - torch.distributed.all_reduce( - sum_exp_logits, - op=torch.distributed.ReduceOp.SUM, - group=get_tensor_model_parallel_group(), - ) - - # Loss = log(sum(exp(logits))) - predicted-logit. - loss = torch.log(sum_exp_logits) - predicted_logits - - # Normalize and optionally smooth logits - exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1)) - - vocab_size = exp_logits.size(-1) - if label_smoothing > 0: - """ - We'd like to assign 1 / (K - 1) probability mass to every index that is not the ground truth. 
- = (1 - alpha) * y_gt + alpha * mean(y_{i for i != gt}) - = (1 - alpha) * y_gt + (alpha / (K - 1)) * \sum_{i != gt} y_i - = ((K - 1) * (1 - alpha) / (K - 1)) * y_gt + (alpha / (K - 1)) * \sum_{i != gt} y_i - = (K * (1 - alpha) - 1) / (K - 1)) * y_gt + (alpha / (K - 1)) * \sum_{i} y_i - = (1 - (alpha * K) / (K - 1)) * y_gt + ( (alpha * K) / (K - 1) ) * \sum_{i} y_i / K - From: https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/losses/smoothed_cross_entropy.py - """ - assert 1.0 > label_smoothing > 0.0 - smoothing = label_smoothing * vocab_size / (vocab_size - 1) - - # Exp logits at this point are normalized probabilities. So we can just take the log to get log-probs. - log_probs = torch.log(exp_logits) - mean_log_probs = log_probs.mean(dim=-1) - loss = (1.0 - smoothing) * loss - smoothing * mean_log_probs - - ctx.label_smoothing, ctx.vocab_size = label_smoothing, vocab_size - - # Store softmax, target-mask and masked-target for backward pass. - ctx.save_for_backward(exp_logits, target_mask, masked_target_1d) - - return loss - - @staticmethod - def backward(ctx, grad_output): - - # Retreive tensors from the forward path. - softmax, target_mask, masked_target_1d = ctx.saved_tensors - label_smoothing, vocab_size = ctx.label_smoothing, ctx.vocab_size - - # All the inputs have softmax as thier gradient. - grad_input = softmax - # For simplicity, work with the 2D gradient. - partition_vocab_size = softmax.size()[-1] - grad_2d = grad_input.view(-1, partition_vocab_size) - - # Add the gradient from matching classes. - arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device) - - softmax_update = 1.0 - target_mask.view(-1).float() - - if label_smoothing > 0: - smoothing = label_smoothing * vocab_size / (vocab_size - 1) - grad_2d[arange_1d, masked_target_1d] -= (1.0 - smoothing) * softmax_update - average_grad = 1 / vocab_size - grad_2d[arange_1d, :] -= smoothing * average_grad - else: - grad_2d[arange_1d, masked_target_1d] -= softmax_update - - # Finally elementwise multiplication with the output gradients. - grad_input.mul_(grad_output.unsqueeze(dim=-1)) - - return grad_input, None, None - - -def vocab_parallel_cross_entropy(vocab_parallel_logits, target, label_smoothing=0.0): - """ - Performs cross entropy loss when logits are split across tensor parallel ranks - - Arguments: - vocab_parallel_logits: logits split across tensor parallel ranks - dimension is [sequence_length, batch_size, hidden_size] - - target: correct vocab ids of dimseion [sequence_length, micro_batch_size] - - lobal_smoothing: smoothing factor, must be in range [0.0, 1.0) - default is no smoothing (=0.0) - """ - return _VocabParallelCrossEntropy.apply(vocab_parallel_logits, target, label_smoothing) diff --git a/megatron/core/tensor_parallel/data.py b/megatron/core/tensor_parallel/data.py deleted file mode 100644 index 45c4fe7eb0cc0b83bcf2407c5bb2c5f36ed4582e..0000000000000000000000000000000000000000 --- a/megatron/core/tensor_parallel/data.py +++ /dev/null @@ -1,104 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
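For reference, the vocab-parallel cross entropy removed above computes the loss without materializing the full softmax on any one rank: each rank contributes the exp-sum of its vocab slice, plus the target logit only if it owns the target id, and two all-reduces combine them. A single-process sketch that simulates the shards in a loop and checks the result against `torch.nn.functional.cross_entropy` (the tensor sizes and the simulated all-reduces are assumptions for illustration):

```python
# Single-process sketch of the deleted _VocabParallelCrossEntropy.forward:
# the vocab dimension is sharded, each simulated "rank" contributes its
# partial exp-sum and (if it owns the target id) the target logit, and the
# all-reduces are replaced by plain sums over the shards.
import torch

def sharded_cross_entropy(full_logits, target, tp_size):
    vocab = full_logits.size(-1)
    part = vocab // tp_size
    shards = full_logits.split(part, dim=-1)

    # all-reduce(MAX) over the per-shard maxima, for numerical stability
    logits_max = torch.stack([s.max(dim=-1).values for s in shards]).max(dim=0).values

    predicted = torch.zeros_like(logits_max)
    sum_exp = torch.zeros_like(logits_max)
    for rank, shard in enumerate(shards):
        shard = shard - logits_max.unsqueeze(-1)
        lo, hi = rank * part, (rank + 1) * part
        owned = (target >= lo) & (target < hi)             # targets this rank owns
        local_target = (target - lo).clamp(0, part - 1)
        pred = shard.gather(-1, local_target.unsqueeze(-1)).squeeze(-1)
        predicted += torch.where(owned, pred, torch.zeros_like(pred))  # all-reduce(SUM)
        sum_exp += shard.exp().sum(dim=-1)                             # all-reduce(SUM)

    # loss = log(sum(exp(logits))) - predicted logit, as in the deleted code
    return torch.log(sum_exp) - predicted


logits = torch.randn(5, 3, 16)                             # [seq, batch, vocab]
target = torch.randint(0, 16, (5, 3))
reference = torch.nn.functional.cross_entropy(
    logits.reshape(-1, 16), target.reshape(-1), reduction='none').reshape(5, 3)
print(torch.allclose(sharded_cross_entropy(logits, target, tp_size=4), reference, atol=1e-5))
```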
- -import torch - -from megatron.core.parallel_state import ( - get_tensor_model_parallel_group, - get_tensor_model_parallel_rank, - get_tensor_model_parallel_src_rank, -) - -_MAX_DATA_DIM = 5 - - -def _check_data_types(keys, data, target_dtype): - """Check that all the keys have the same target data type.""" - for key in keys: - assert data[key].dtype == target_dtype, ( - '{} has data type {} which ' - 'is different than {}'.format(key, data[key].dtype, target_dtype) - ) - - -def _build_key_size_numel_dictionaries(keys, data): - """Build the size on rank 0 and broadcast.""" - max_dim = _MAX_DATA_DIM - sizes = [0 for _ in range(max_dim) for _ in keys] - - # Pack the sizes on rank zero. - if get_tensor_model_parallel_rank() == 0: - offset = 0 - for key in keys: - assert data[key].dim() < max_dim, 'you should increase MAX_DATA_DIM' - size = data[key].size() - for i, s in enumerate(size): - sizes[i + offset] = s - offset += max_dim - - # Move to GPU and broadcast. - sizes_cuda = torch.cuda.LongTensor(sizes) - torch.distributed.broadcast( - sizes_cuda, get_tensor_model_parallel_src_rank(), group=get_tensor_model_parallel_group() - ) - - # Move back to cpu and unpack. - sizes_cpu = sizes_cuda.cpu() - key_size = {} - key_numel = {} - total_numel = 0 - offset = 0 - for key in keys: - i = 0 - size = [] - numel = 1 - while sizes_cpu[offset + i] > 0: - this_size = sizes_cpu[offset + i] - size.append(this_size) - numel *= this_size - i += 1 - key_size[key] = size - key_numel[key] = numel - total_numel += numel - offset += max_dim - - return key_size, key_numel, total_numel - - -def broadcast_data(keys, data, datatype): - """Broadcast data from rank zero of each model parallel group to the - members of the same model parallel group. - - Arguments: - keys: list of keys in the data disctionary to be broadcasted - data: data dictionary of string keys and cpu tensor values. - datatype: torch data type of all tensors in data associated - with keys. - """ - # Build (key, size) and (key, number of elements) dictionaries along - # with the total number of elements on all ranks. - key_size, key_numel, total_numel = _build_key_size_numel_dictionaries(keys, data) - - # Pack on rank zero. - if get_tensor_model_parallel_rank() == 0: - # Check that all keys have the same data type. - _check_data_types(keys, data, datatype) - # Flatten the data associated with the keys - flatten_data = torch.cat([data[key].contiguous().view(-1) for key in keys], dim=0).cuda() - else: - flatten_data = torch.empty(total_numel, device=torch.cuda.current_device(), dtype=datatype) - - # Broadcast - torch.distributed.broadcast( - flatten_data, get_tensor_model_parallel_src_rank(), group=get_tensor_model_parallel_group() - ) - - # Unpack - output = {} - offset = 0 - for key in keys: - size = key_size[key] - numel = key_numel[key] - output[key] = flatten_data.narrow(0, offset, numel).view(size) - offset += numel - - return output diff --git a/megatron/core/tensor_parallel/layers.py b/megatron/core/tensor_parallel/layers.py deleted file mode 100644 index f31ee42df623f4433af90570c5bf0fd0914aa4e0..0000000000000000000000000000000000000000 --- a/megatron/core/tensor_parallel/layers.py +++ /dev/null @@ -1,949 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
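The deleted `broadcast_data` above avoids broadcasting Python metadata by packing every tensor shape into a fixed-width slot of `_MAX_DATA_DIM` integers, broadcasting one flat LongTensor, and letting the other ranks reconstruct `(shape, numel)` per key. A sketch of just that packing scheme with `torch.distributed` omitted (the sample keys and shapes are made up):

```python
# Sketch of the size-packing scheme in the deleted broadcast_data(): rank 0
# flattens every tensor shape into a fixed slot of _MAX_DATA_DIM entries and
# broadcasts one LongTensor; other ranks recover (shape, numel) per key.
# torch.distributed is omitted here; the sample keys and shapes are made up.
import torch

MAX_DATA_DIM = 5

def pack_sizes(keys, data):
    sizes = [0] * (MAX_DATA_DIM * len(keys))
    for slot, key in enumerate(keys):
        for i, s in enumerate(data[key].size()):
            sizes[slot * MAX_DATA_DIM + i] = s
    return torch.tensor(sizes, dtype=torch.long)            # what rank 0 would broadcast

def unpack_sizes(keys, sizes):
    key_size, key_numel, total_numel = {}, {}, 0
    for slot, key in enumerate(keys):
        shape, i = [], slot * MAX_DATA_DIM
        while sizes[i] > 0:                                  # trailing zeros end the shape
            shape.append(int(sizes[i]))
            i += 1
        key_size[key] = shape
        key_numel[key] = int(torch.tensor(shape).prod())
        total_numel += key_numel[key]
    return key_size, key_numel, total_numel

data = {"tokens": torch.zeros(4, 2048, dtype=torch.long),
        "loss_mask": torch.zeros(4, 2048)}
keys = sorted(data)
print(unpack_sizes(keys, pack_sizes(keys, data)))
```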
- -# Parts of the code here are adapted from PyTorch -# repo: https://github.com/pytorch/pytorch - -import math -import os -import warnings -from typing import Callable, Optional - -import torch -import torch.nn.functional as F -import torch.nn.init as init -from torch.cuda.amp import custom_bwd, custom_fwd -from torch.nn.parameter import Parameter - -from megatron.core.model_parallel_config import ModelParallelConfig -from megatron.core.parallel_state import ( - get_global_memory_buffer, - get_tensor_model_parallel_group, - get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, -) - -from .mappings import ( - copy_to_tensor_model_parallel_region, - gather_from_sequence_parallel_region, - gather_from_tensor_model_parallel_region, - reduce_from_tensor_model_parallel_region, - reduce_scatter_to_sequence_parallel_region, - scatter_to_tensor_model_parallel_region, -) -from .random import get_cuda_rng_tracker, get_expert_parallel_rng_tracker_name -from .utils import VocabUtility, divide, split_tensor_along_last_dim - -_grad_accum_fusion_available = True -try: - import fused_weight_gradient_mlp_cuda -except ImportError: - _grad_accum_fusion_available = False - -_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = { - 'tensor_model_parallel': False, - 'partition_dim': -1, - 'partition_stride': 1, -} - - -def param_is_not_tensor_parallel_duplicate(param): - return (hasattr(param, 'tensor_model_parallel') and param.tensor_model_parallel) or ( - get_tensor_model_parallel_rank() == 0 - ) - - -def set_tensor_model_parallel_attributes(tensor, is_parallel, dim, stride): - # Make sure the attributes are not set. - for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS: - assert not hasattr(tensor, attribute) - # Set the attributes. - setattr(tensor, 'tensor_model_parallel', is_parallel) - setattr(tensor, 'partition_dim', dim) - setattr(tensor, 'partition_stride', stride) - - -def set_defaults_if_not_set_tensor_model_parallel_attributes(tensor): - def maybe_set(attribute, value): - if not hasattr(tensor, attribute): - setattr(tensor, attribute, value) - - for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS: - maybe_set(attribute, _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS[attribute]) - - -def copy_tensor_model_parallel_attributes(destination_tensor, source_tensor): - def maybe_copy(attribute): - if hasattr(source_tensor, attribute): - setattr(destination_tensor, attribute, getattr(source_tensor, attribute)) - - for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS: - maybe_copy(attribute) - - -def _initialize_affine_weight_gpu( - weight, init_method, partition_dim, stride=1, expert_parallel=False -): - """Initialize affine weight for model parallel on GPU.""" - - set_tensor_model_parallel_attributes( - tensor=weight, is_parallel=True, dim=partition_dim, stride=stride - ) - - if not expert_parallel: - with get_cuda_rng_tracker().fork(): - init_method(weight) - else: - with get_cuda_rng_tracker().fork(get_expert_parallel_rng_tracker_name()): - init_method(weight) - - -def _initialize_affine_weight_cpu( - weight, - output_size, - input_size, - per_partition_size, - partition_dim, - init_method, - stride=1, - return_master_weight=False, - *, - params_dtype=torch.float32, -): - """Initialize affine weight for model parallel. 
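The attribute helpers removed above implement a simple tagging convention: sharded parameters carry `tensor_model_parallel`, `partition_dim` and `partition_stride`, and `param_is_not_tensor_parallel_duplicate` counts untagged (replicated) parameters only on TP rank 0. A tiny sketch of that convention (the parameter shapes below are arbitrary):

```python
# Sketch of the tagging convention handled by the deleted attribute helpers:
# sharded parameters carry three attributes, and untagged (replicated)
# parameters are counted only on TP rank 0. The shapes below are arbitrary.
import torch

w_sharded = torch.nn.Parameter(torch.empty(128, 64))   # a column-partitioned weight
w_sharded.tensor_model_parallel = True
w_sharded.partition_dim = 0                             # partitioned along dim 0
w_sharded.partition_stride = 1

b_shared = torch.nn.Parameter(torch.empty(64))          # e.g. a LayerNorm bias,
                                                        # replicated on every TP rank

def is_not_tensor_parallel_duplicate(param, tp_rank):
    return getattr(param, 'tensor_model_parallel', False) or tp_rank == 0

for rank in range(2):
    print(rank, is_not_tensor_parallel_duplicate(w_sharded, rank),
          is_not_tensor_parallel_duplicate(b_shared, rank))
# TP rank 1 skips the replicated bias (rank 0 already owns it) but still
# owns its own slice of the sharded weight, e.g. for grad-norm computation.
```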
- - Build the master weight on all processes and scatter - the relevant chunk.""" - - set_tensor_model_parallel_attributes( - tensor=weight, is_parallel=True, dim=partition_dim, stride=stride - ) - - # Initialize master weight - master_weight = torch.empty(output_size, input_size, dtype=torch.float, requires_grad=False) - init_method(master_weight) - master_weight = master_weight.to(dtype=params_dtype) - - # Split and copy - per_partition_per_stride_size = divide(per_partition_size, stride) - weight_list = torch.split(master_weight, per_partition_per_stride_size, dim=partition_dim) - rank = get_tensor_model_parallel_rank() - world_size = get_tensor_model_parallel_world_size() - my_weight_list = weight_list[rank::world_size] - - with torch.no_grad(): - torch.cat(my_weight_list, dim=partition_dim, out=weight) - if return_master_weight: - return master_weight - return None - - -class VocabParallelEmbedding(torch.nn.Module): - """Embedding parallelized in the vocabulary dimension. - - This is mainly adapted from torch.nn.Embedding and all the default - values are kept. - Arguments: - num_embeddings: vocabulary size. - embedding_dim: size of hidden state. - - Keyword Arguments: - config: A megatron.core.ModelParallelConfig object - """ - - def __init__( - self, - num_embeddings: int, - embedding_dim: int, - *, - init_method: Callable, - config: ModelParallelConfig, - ): - super(VocabParallelEmbedding, self).__init__() - # Keep the input dimensions. - self.num_embeddings = num_embeddings - self.embedding_dim = embedding_dim - self.tensor_model_parallel_size = get_tensor_model_parallel_world_size() - # Divide the weight matrix along the vocaburaly dimension. - ( - self.vocab_start_index, - self.vocab_end_index, - ) = VocabUtility.vocab_range_from_global_vocab_size( - self.num_embeddings, get_tensor_model_parallel_rank(), self.tensor_model_parallel_size - ) - self.num_embeddings_per_partition = self.vocab_end_index - self.vocab_start_index - - # Allocate weights and initialize. - if config.use_cpu_initialization: - self.weight = Parameter( - torch.empty( - self.num_embeddings_per_partition, self.embedding_dim, dtype=config.params_dtype - ) - ) - if config.perform_initialization: - _initialize_affine_weight_cpu( - self.weight, - self.num_embeddings, - self.embedding_dim, - self.num_embeddings_per_partition, - 0, - init_method, - params_dtype=config.params_dtype, - ) - else: - self.weight = Parameter( - torch.empty( - self.num_embeddings_per_partition, - self.embedding_dim, - device=torch.cuda.current_device(), - dtype=config.params_dtype, - ) - ) - if config.perform_initialization: - _initialize_affine_weight_gpu(self.weight, init_method, partition_dim=0, stride=1) - - def forward(self, input_): - assert not torch.any( - (input_ < 0) | (input_ >= self.num_embeddings) - ), "An input token is out of bounds of the embedding table" - if self.tensor_model_parallel_size > 1: - # Build the mask. - input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index) - # Mask the input. - masked_input = input_.clone() - self.vocab_start_index - masked_input[input_mask] = 0 - else: - masked_input = input_ - # Get the embeddings. - output_parallel = self.weight[masked_input] - # Mask the output embedding. - if self.tensor_model_parallel_size > 1: - output_parallel[input_mask, :] = 0.0 - # Reduce across all the model parallel GPUs. 
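The deleted `VocabParallelEmbedding` shards the embedding table over contiguous vocab ranges; each rank looks up only the tokens it owns, zeroes the rest, and (as the code continues just below) an all-reduce sums the partial results. A single-process sketch that simulates the ranks in a loop and checks against a dense lookup (vocab size, hidden size and TP degree are made-up values):

```python
# Single-process sketch of the deleted VocabParallelEmbedding lookup: each
# simulated rank holds a contiguous vocab slice, zeroes rows of tokens it
# does not own, and the final all-reduce(SUM) is replaced by "+=".
# Vocab size, hidden size and TP degree below are made-up values.
import torch

vocab, hidden, tp_size = 16, 4, 4
torch.manual_seed(0)
full_weight = torch.randn(vocab, hidden)
tokens = torch.randint(0, vocab, (3, 5))                 # [batch, seq]

per_rank = vocab // tp_size
output = torch.zeros(*tokens.shape, hidden)
for rank in range(tp_size):
    start, end = rank * per_rank, (rank + 1) * per_rank
    shard = full_weight[start:end]                       # this rank's embedding rows
    not_owned = (tokens < start) | (tokens >= end)
    local_ids = (tokens - start).clamp(0, per_rank - 1)
    partial = shard[local_ids]
    partial[not_owned] = 0.0                             # mask tokens owned elsewhere
    output += partial                                    # simulated all-reduce(SUM)

print(torch.allclose(output, full_weight[tokens]))       # True
```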
- output = reduce_from_tensor_model_parallel_region(output_parallel) - return output - - -class LinearWithFrozenWeight(torch.autograd.Function): - """Linear operator that does not calculate gradient for weight. - This op and LinearWithGradAccumulationAndAsyncCommunication performs - mathematically-identical forward and DGRAD. - - Conceptually this op is the same as torch.nn.functional.linear with - weight.requires_grad==False, but in experiments they are not identical - mathematically. """ - - @staticmethod - @custom_fwd - def forward( - ctx, input, weight, bias, - ): - ctx.save_for_backward(weight) - output = torch.matmul(input, weight.t()) - if bias is not None: - output = output + bias - return output - - @staticmethod - @custom_bwd - def backward(ctx, grad_output): - (weight,) = ctx.saved_tensors - grad_input = grad_output.matmul(weight) - return grad_input, None, None - - -def linear_with_frozen_weight( - input: torch.Tensor, - weight: torch.Tensor, - bias: Optional[torch.Tensor], - gradient_accumulation_fusion: bool, - async_grad_allreduce: bool, - sequence_parallel: bool, -) -> torch.Tensor: - """Linear layer execution with weight.requires_grad == False. - - This function handles linear layers with weight frozen (untrainable). - In the forward, it only saves weight and does not save input activations. - In the backward, it does not perform weight gradient calculation, or - weight gradient allreduce. - - Arguments: - - input (torch.Tensor required): input like torch.nn.functional.linear - - weight (torch.Tensor required): weight like torch.nn.functional.linear - - bias (torch.Tensor optional): bias like torch.nn.functional.linear - - gradient_accumulation_fusion (bool required): dummy argument, used to - keep the API unified between all forward implementation functions. - - async_grad_allreduce (bool required): dummy argument, used to - keep the API unified between all forward implementation functions. - - sequence_parallel (bool required): Indicates that sequence - parallelism is used and thus in the forward pass the input is - all gathered, and the backward pass the input gradients are - reduce scattered. 
- """ - - if sequence_parallel: - input = gather_from_sequence_parallel_region(input, tensor_parallel_output_grad=True) - else: - input = input - - args = [ - input, - weight, - bias, - ] - - return LinearWithFrozenWeight.apply(*args) - - -class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function): - """See linear_with_grad_accumulation_and_async_allreduce""" - - @staticmethod - @custom_fwd - def forward( - ctx, - input, - weight, - bias, - gradient_accumulation_fusion, - async_grad_allreduce, - sequence_parallel, - ): - ctx.save_for_backward(input, weight) - ctx.use_bias = bias is not None - ctx.gradient_accumulation_fusion = gradient_accumulation_fusion - ctx.async_grad_allreduce = async_grad_allreduce - ctx.sequence_parallel = sequence_parallel - - if sequence_parallel: - world_size = get_tensor_model_parallel_world_size() - dim_size = list(input.size()) - dim_size[0] = dim_size[0] * world_size - - all_gather_buffer = get_global_memory_buffer().get_tensor(dim_size, input.dtype, "mpu") - torch.distributed._all_gather_base( - all_gather_buffer, input, group=get_tensor_model_parallel_group() - ) - total_input = all_gather_buffer - else: - total_input = input - - output = torch.matmul(total_input, weight.t()) - if bias is not None: - output = output + bias - return output - - @staticmethod - @custom_bwd - def backward(ctx, grad_output): - input, weight = ctx.saved_tensors - use_bias = ctx.use_bias - - if ctx.sequence_parallel: - world_size = get_tensor_model_parallel_world_size() - dim_size = list(input.size()) - dim_size[0] = dim_size[0] * world_size - - all_gather_buffer = get_global_memory_buffer().get_tensor(dim_size, input.dtype, "mpu") - handle = torch.distributed._all_gather_base( - all_gather_buffer, input, group=get_tensor_model_parallel_group(), async_op=True - ) - - # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the - # gather is scheduled before the input gradient computation - total_input = all_gather_buffer - else: - total_input = input - grad_input = grad_output.matmul(weight) - - if ctx.sequence_parallel: - handle.wait() - - # Doing gather + slicing during the NeMo forward pass can make this tensor - # not be contiguous. 
PyTorch only checks if the tensor is contiguous, and only - # clones it if it's not contiguous: - # https://github.com/pytorch/pytorch/blob/c47cf9bc7f9e02f649ab4ed53fe4d35732c92ab6/torch/_refs/__init__.py#L2761 - grad_output = grad_output.contiguous() - # Convert the tensor shapes to 2D for execution compatibility - grad_output = grad_output.view( - grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2] - ) - total_input = total_input.view( - total_input.shape[0] * total_input.shape[1], total_input.shape[2] - ) - - if ctx.async_grad_allreduce: - # Asynchronous all-reduce - handle = torch.distributed.all_reduce( - grad_input, group=get_tensor_model_parallel_group(), async_op=True - ) - # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the - # all-reduce is scheduled before the weight gradient computation - - if ctx.sequence_parallel: - assert not ctx.async_grad_allreduce - dim_size = list(input.size()) - sub_grad_input = torch.empty( - dim_size, dtype=input.dtype, device=torch.cuda.current_device(), requires_grad=False - ) - # reduce_scatter - handle = torch.distributed._reduce_scatter_base( - sub_grad_input, grad_input, group=get_tensor_model_parallel_group(), async_op=True - ) - # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the - # reduce scatter is scheduled before the weight gradient computation - - if ctx.gradient_accumulation_fusion: - if weight.main_grad.dtype == torch.float32: - fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32( - total_input, grad_output, weight.main_grad - ) - elif weight.main_grad.dtype in (torch.float16, torch.bfloat16): - fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16( - total_input, grad_output, weight.main_grad - ) - else: - raise RuntimeError("Unsupported gradient type for gradient accumulation fusion") - - if hasattr(weight, 'grad_added_to_main_grad'): - # When overlap_grad_reduce is True, need to ensure that backward hooks - # are all run on the main backprop thread to prevent deadlocks. Setup - # dummy grad_weight tensor to prevent backward hooks from being run - # in a background thread. - grad_weight = torch.empty( - weight.main_grad.shape, - dtype=input.dtype, - device=torch.cuda.current_device(), - requires_grad=False, - ) - weight.grad_added_to_main_grad = True - else: - grad_weight = None - else: - grad_weight = grad_output.t().matmul(total_input) - grad_bias = grad_output.sum(dim=0) if use_bias else None - - if ctx.sequence_parallel: - handle.wait() - return sub_grad_input, grad_weight, grad_bias, None, None, None - - if ctx.async_grad_allreduce: - handle.wait() - - return grad_input, grad_weight, grad_bias, None, None, None - - -def linear_with_grad_accumulation_and_async_allreduce( - input: torch.Tensor, - weight: torch.Tensor, - bias: Optional[torch.Tensor], - gradient_accumulation_fusion: bool, - async_grad_allreduce: bool, - sequence_parallel: bool, -) -> torch.Tensor: - """Linear layer execution with asynchronous communication and - gradient accumulation fusion in backprop. - - This has the option to accumulate the result of backprop - calculation into an existing gradient buffer, preventing the need - to do an additional addition kernel after the gradient - calculation. - - Additionally, the tensor parallel all reduce of the input - gradients can be done asynchronously with the calculation of - the weight gradients. - - In the case of sequence parallelism, the reduce scatter of the - input gradients is done asynchronously with the calcluation of the - weight gradients. 
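The backward pass removed above flattens `(sequence, micro-batch)` into one dimension before the weight-gradient GEMM and, when gradient-accumulation fusion is enabled, writes the result straight into `weight.main_grad` instead of returning a fresh `grad_weight`. A sketch of that math with the `fused_weight_gradient_mlp_cuda` kernel replaced by a plain in-place matmul (dimensions are illustrative):

```python
# Sketch of the weight-gradient path in the deleted backward() above: the
# (sequence, micro-batch) dimensions are collapsed before the GEMM, and the
# fused_weight_gradient_mlp_cuda kernel is replaced here by a plain matmul
# that accumulates into a pre-allocated main_grad buffer. Sizes are made up.
import torch

s, b, h_in, h_out = 8, 2, 16, 32
weight = torch.randn(h_out, h_in)              # linear weights are stored [out, in]
total_input = torch.randn(s, b, h_in)
grad_output = torch.randn(s, b, h_out)
main_grad = torch.zeros(h_out, h_in)           # persistent grad buffer on the weight

grad_input = grad_output @ weight              # dgrad: [s, b, h_in]

go_2d = grad_output.reshape(s * b, h_out)      # collapse to 2D, as the deleted code does
in_2d = total_input.reshape(s * b, h_in)
main_grad += go_2d.t() @ in_2d                 # wgrad accumulated in place (no extra add kernel)

print(grad_input.shape, main_grad.shape)       # torch.Size([8, 2, 16]) torch.Size([32, 16])
```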
- - Use of this module requires that the environment variable - CUDA_DEVICE_MAX_CONNECTIONS=1. There are a few collective - operations, noted in the code, that should be scheduled before - compute kernels to overlap the communication with the computation, - which is necessary for a speedup but not for correctness so that - ordering isn't imposed by the scheduler. Setting - CUDA_DEVICE_MAX_CONNECTIONS=1 forces the kernels to be scheduled - in the order they are called. - - Arguments: - - input (torch.Tensor required): input like torch.nn.functional.linear - - weight (torch.Tensor required): weight like torch.nn.functional.linear - - bias (torch.Tensor optional): bias like torch.nn.functional.linear - - gradient_accumulation_fusion (bool required): Perform the gradient - accumulation fusion, requires the custom CUDA extension - fused_weight_gradient_mlp_cuda module. To use - gradient_accumulation_fusion you must install APEX with - --cpp_ext and --cuda_ext. For example: "pip install - --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext .\" - " Note that the extension requires CUDA>=11. Otherwise, you - must turn off gradient accumulation fusion." - - async_grad_allreduce (bool required): Do the allreduce of input - gradients asyncronously with the computation of weight - gradients. If sequence_parallel is True, this must be - False, as no all reduce is performed. - - sequence_parallel (bool required): Indicates that sequence - parallelism is used and thus in the forward pass the input is - all gathered, and the backward pass the input gradients are - reduce scattered. - """ - args = [ - input, - weight, - bias, - gradient_accumulation_fusion, - async_grad_allreduce, - sequence_parallel, - ] - - if not linear_with_grad_accumulation_and_async_allreduce.warned: - if os.environ.get('CUDA_DEVICE_MAX_CONNECTIONS') != "1": - if sequence_parallel: - warnings.warn( - "When using sequence parallelism it is recommended to set the " - "environment variable CUDA_DEVICE_MAX_CONNECTIONS to 1 for " - "maximum speedup" - ) - linear_with_grad_accumulation_and_async_allreduce.warned = True - - if async_grad_allreduce: - warnings.warn( - "When using async grad allreduce it is recommended to set the " - "environment variable CUDA_DEVICE_MAX_CONNECTIONS to 1 for " - "maximum speedup" - ) - linear_with_grad_accumulation_and_async_allreduce.warned = True - - return LinearWithGradAccumulationAndAsyncCommunication.apply(*args) - - -linear_with_grad_accumulation_and_async_allreduce.warned = False - - -class ColumnParallelLinear(torch.nn.Module): - """Linear layer with column parallelism. - - The linear layer is defined as Y = XA + b. A is parallelized along - its second dimension as A = [A_1, ..., A_p]. - - Arguments: - input_size: first dimension of matrix A. - output_size: second dimension of matrix A. - - Keyword Arguments - bias: If true, add bias - gather_output: If true, call all-gather on output and make Y available - to all GPUs, otherwise, every GPU will have its output - which is Y_i = XA_i - init_method: method to initialize weights. Note that bias is always set - to zero. - stride: For the strided linear layers. - keep_master_weight_for_test: This was added for testing and should be - set to False. It returns the master weights - used for initialization. - skip_bias_add: If True, do not add the bias term, instead - return it to be added by the caller. This - enables performance optimations where bias can - be fused with other elementwise operations. 
- skip_weight_param_allocation: If True, weight parameter is not allocated and must be passed - as a keyword argument `weight` during the forward pass. Note - that this does not affect bias, which will be allocated if - bias is True. Defaults to False. - is_expert: If True, the layer is treated as an MoE expert layer. - config: ModelParallelConfig object - tp_comm_buffer_name: Communication buffer name is not used in - non-Transformer-Engine modules. - - """ - - def __init__( - self, - input_size, - output_size, - *, - config: ModelParallelConfig, - init_method: Callable, - bias=True, - gather_output=False, - stride=1, - keep_master_weight_for_test=False, - skip_bias_add=False, - skip_weight_param_allocation: bool = False, - is_expert: bool = False, - tp_comm_buffer_name: str = None, # Not used - ): - super(ColumnParallelLinear, self).__init__() - - # Keep input parameters - self.input_size = input_size - self.output_size = output_size - self.gather_output = gather_output - # Divide the weight matrix along the last dimension. - world_size = get_tensor_model_parallel_world_size() - self.output_size_per_partition = divide(output_size, world_size) - self.skip_bias_add = skip_bias_add - self.is_expert = is_expert - self.expert_parallel = config.expert_model_parallel_size > 1 - self.config = config - - # Parameters. - # Note: torch.nn.functional.linear performs XA^T + b and as a result - # we allocate the transpose. - # Initialize weight. - if not skip_weight_param_allocation: - if config.use_cpu_initialization: - self.weight = Parameter( - torch.empty( - self.output_size_per_partition, self.input_size, dtype=config.params_dtype - ) - ) - if config.perform_initialization: - self.master_weight = _initialize_affine_weight_cpu( - self.weight, - self.output_size, - self.input_size, - self.output_size_per_partition, - 0, - init_method, - stride=stride, - return_master_weight=keep_master_weight_for_test, - ) - else: - self.weight = Parameter( - torch.empty( - self.output_size_per_partition, - self.input_size, - device=torch.cuda.current_device(), - dtype=config.params_dtype, - ) - ) - if config.perform_initialization: - _initialize_affine_weight_gpu( - self.weight, - init_method, - partition_dim=0, - stride=stride, - expert_parallel=(self.is_expert and self.expert_parallel), - ) - - setattr(self.weight, 'allreduce', not (self.is_expert and self.expert_parallel)) - else: - self.weight = None - - if bias: - if config.use_cpu_initialization: - self.bias = Parameter( - torch.empty(self.output_size_per_partition, dtype=config.params_dtype) - ) - else: - self.bias = Parameter( - torch.empty( - self.output_size_per_partition, - device=torch.cuda.current_device(), - dtype=config.params_dtype, - ) - ) - set_tensor_model_parallel_attributes(self.bias, True, 0, stride) - if config.perform_initialization: - # Always initialize bias to zero. - with torch.no_grad(): - self.bias.zero_() - setattr(self.bias, 'allreduce', not (self.is_expert and self.expert_parallel)) - else: - self.register_parameter('bias', None) - - self.async_tensor_model_parallel_allreduce = ( - config.async_tensor_model_parallel_allreduce and world_size > 1 - ) - - self.sequence_parallel = config.sequence_parallel - if self.sequence_parallel and world_size <= 1: - warnings.warn( - f"`sequence_parallel` is set to `True`, but tensor model parallel size is {world_size}. " - f"Disabling sequence parallel." 
- ) - self.sequence_parallel = False - - if config.gradient_accumulation_fusion and not _grad_accum_fusion_available: - raise RuntimeError( - "ColumnParallelLinear was called with gradient_accumulation_fusion set " - "to True but the custom CUDA extension fused_weight_gradient_mlp_cuda " - "module is not found. To use gradient_accumulation_fusion you must " - "install APEX with --cpp_ext and --cuda_ext. For example: " - "pip install --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext .\" " - "Note that the extension requires CUDA>=11. Otherwise, you must turn off " - "gradient accumulation fusion." - ) - self.gradient_accumulation_fusion = config.gradient_accumulation_fusion - - if self.async_tensor_model_parallel_allreduce and self.sequence_parallel: - raise RuntimeError( - "`async_tensor_model_parallel_allreduce` and `sequence_parallel` " - "cannot be enabled at the same time." - ) - - self._forward_impl = linear_with_grad_accumulation_and_async_allreduce - self.explicit_expert_comm = self.is_expert and ( - self.sequence_parallel or self.expert_parallel - ) - - def forward(self, input_: torch.Tensor, weight: Optional[torch.Tensor] = None): - """Forward of ColumnParallelLinear - - Args: - input_: 3D tensor whose order of dimension is [sequence, batch, hidden] - - weight (optional): weight tensor to use, compulsory when - skip_weight_param_allocation is True. - - Returns: - - output - - bias - - """ - if weight is None: - if self.weight is None: - raise RuntimeError( - "weight was not supplied to ColumnParallelLinear forward pass " - "and skip_weight_param_allocation is True." - ) - weight = self.weight - else: - # Check the weight passed in is the correct shape - expected_shape = (self.output_size_per_partition, self.input_size) - if weight.shape != expected_shape: - raise RuntimeError( - f"supplied weight's shape is {tuple(weight.shape)}, " - f"not {expected_shape} as expected" - ) - - bias = self.bias if not self.skip_bias_add else None - - if ( - self.async_tensor_model_parallel_allreduce - or self.sequence_parallel - or self.explicit_expert_comm - ): - input_parallel = input_ - else: - input_parallel = copy_to_tensor_model_parallel_region(input_) - - # Matrix multiply. - if not weight.requires_grad: - self._forward_impl = linear_with_frozen_weight - else: - self._forward_impl = linear_with_grad_accumulation_and_async_allreduce - output_parallel = self._forward_impl( - input=input_parallel, - weight=weight, - bias=bias, - gradient_accumulation_fusion=self.gradient_accumulation_fusion, - async_grad_allreduce=False - if self.explicit_expert_comm - else self.async_tensor_model_parallel_allreduce, - sequence_parallel=False if self.explicit_expert_comm else self.sequence_parallel, - ) - if self.gather_output: - # All-gather across the partitions. - assert not self.sequence_parallel - output = gather_from_tensor_model_parallel_region(output_parallel) - else: - output = output_parallel - output_bias = self.bias if self.skip_bias_add else None - return output, output_bias - - -class RowParallelLinear(torch.nn.Module): - """Linear layer with row parallelism. - - The linear layer is defined as Y = XA + b. A is parallelized along - its first dimension and X along its second dimension as: - - - - | A_1 | - | . | - A = | . | X = [X_1, ..., X_p] - | . | - | A_p | - - - - Arguments: - input_size: first dimension of matrix A. - output_size: second dimension of matrix A. - - Keyword Arguments: - bias: If true, add bias. Note that bias is not parallelized. 
- input_is_parallel: If true, we assume that the input is already - split across the GPUs and we do not split - again. - init_method: method to initialize weights. Note that bias is always set - to zero. - stride: For the strided linear layers. - keep_master_weight_for_test: This was added for testing and should be - set to False. It returns the master weights - used for initialization. - skip_bias_add: If True, do not add the bias term, instead - return it to be added by the caller. This - enables performance optimations where bias can - be fused with other elementwise operations. - is_expert: If True, the layer is treated as an MoE expert layer - tp_comm_buffer_name: Communication buffer name. Not used in - non-Transformer-Engine modules. - config: ModelParallelConfig object - - """ - - def __init__( - self, - input_size: int, - output_size: int, - *, - config: ModelParallelConfig, - init_method: Callable, - bias: bool, - input_is_parallel: bool, - skip_bias_add: bool, - stride: int = 1, - keep_master_weight_for_test: bool = False, - is_expert: bool = False, - tp_comm_buffer_name: str = None, # Not used - ): - super(RowParallelLinear, self).__init__() - - # Keep input parameters - self.input_size = input_size - self.output_size = output_size - self.input_is_parallel = input_is_parallel - # Divide the weight matrix along the last dimension. - world_size = get_tensor_model_parallel_world_size() - self.input_size_per_partition = divide(input_size, world_size) - self.skip_bias_add = skip_bias_add - self.config = config - self.is_expert = is_expert - self.expert_parallel = config.expert_model_parallel_size > 1 - self.gradient_accumulation_fusion = config.gradient_accumulation_fusion - self.sequence_parallel = config.sequence_parallel - if self.sequence_parallel and not self.input_is_parallel: - raise RuntimeError("To enable `sequence_parallel`, `input_is_parallel` must be `True`") - - # Parameters. - # Note: torch.nn.functional.linear performs XA^T + b and as a result - # we allocate the transpose. - # Initialize weight. - if config.use_cpu_initialization: - self.weight = Parameter( - torch.empty( - self.output_size, self.input_size_per_partition, dtype=config.params_dtype - ) - ) - if config.perform_initialization: - self.master_weight = _initialize_affine_weight_cpu( - self.weight, - self.output_size, - self.input_size, - self.input_size_per_partition, - 1, - init_method, - stride=stride, - return_master_weight=keep_master_weight_for_test, - params_dtype=config.params_dtype, - ) - else: - self.weight = Parameter( - torch.empty( - self.output_size, - self.input_size_per_partition, - device=torch.cuda.current_device(), - dtype=config.params_dtype, - ) - ) - if config.perform_initialization: - _initialize_affine_weight_gpu( - self.weight, - init_method, - partition_dim=1, - stride=stride, - expert_parallel=(self.is_expert and self.expert_parallel), - ) - setattr(self.weight, 'allreduce', not (self.is_expert and self.expert_parallel)) - - if bias: - if config.use_cpu_initialization: - self.bias = Parameter(torch.empty(self.output_size, dtype=config.params_dtype)) - else: - self.bias = Parameter( - torch.empty( - self.output_size, - device=torch.cuda.current_device(), - dtype=config.params_dtype, - ) - ) - - if config.perform_initialization: - # Always initialize bias to zero. 
- with torch.no_grad(): - self.bias.zero_() - setattr(self.bias, 'allreduce', not (self.is_expert and self.expert_parallel)) - setattr(self.bias, 'sequence_parallel', self.sequence_parallel) - else: - self.register_parameter('bias', None) - - self._forward_impl = linear_with_grad_accumulation_and_async_allreduce - self.explicit_expert_comm = self.is_expert and ( - self.sequence_parallel or self.expert_parallel - ) - - def forward(self, input_): - """Forward of RowParallelLinear - - Args: - input_: 3D tensor whose order of dimension is [sequence, batch, hidden] - - Returns: - - output - - bias - """ - # Set up backprop all-reduce. - if self.input_is_parallel: - input_parallel = input_ - else: - assert not self.sequence_parallel - input_parallel = scatter_to_tensor_model_parallel_region(input_) - # Matrix multiply. - if not self.weight.requires_grad: - self._forward_impl = linear_with_frozen_weight - else: - self._forward_impl = linear_with_grad_accumulation_and_async_allreduce - output_parallel = self._forward_impl( - input=input_parallel, - weight=self.weight, - bias=None, - gradient_accumulation_fusion=self.gradient_accumulation_fusion, - async_grad_allreduce=False, - sequence_parallel=False, - ) - - # All-reduce across all the partitions. - if self.explicit_expert_comm: - assert self.skip_bias_add - output_ = output_parallel - elif self.sequence_parallel: - output_ = reduce_scatter_to_sequence_parallel_region(output_parallel) - else: - output_ = reduce_from_tensor_model_parallel_region(output_parallel) - if not self.skip_bias_add: - output = (output_ + self.bias) if self.bias is not None else output_ - output_bias = None - else: - output = output_ - output_bias = self.bias - return output, output_bias diff --git a/megatron/core/tensor_parallel/mappings.py b/megatron/core/tensor_parallel/mappings.py deleted file mode 100644 index 95c8841be7c90e740baab285de7d9e03a90c98bb..0000000000000000000000000000000000000000 --- a/megatron/core/tensor_parallel/mappings.py +++ /dev/null @@ -1,358 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. - -import torch - -from megatron.core.parallel_state import ( - get_tensor_and_expert_parallel_group, - get_tensor_model_parallel_group, - get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, -) - -from .utils import split_tensor_along_last_dim - - -def _reduce(input_): - """All-reduce the input tensor across model parallel group.""" - - # Bypass the function if we are using only 1 GPU. - if get_tensor_model_parallel_world_size() == 1: - return input_ - - # All-reduce. - torch.distributed.all_reduce(input_, group=get_tensor_model_parallel_group()) - - return input_ - - -def _split_along_last_dim(input_): - """Split the tensor along its last dimension and keep the - corresponding slice.""" - - world_size = get_tensor_model_parallel_world_size() - # Bypass the function if we are using only 1 GPU. - if world_size == 1: - return input_ - - # Split along last dimension. - input_list = split_tensor_along_last_dim(input_, world_size) - - # Note: torch.split does not create contiguous tensors by default. - rank = get_tensor_model_parallel_rank() - output = input_list[rank].contiguous() - - return output - - -def _split_along_first_dim(input_): - """Split the tensor along its first dimension and keep the - corresponding slice.""" - - world_size = get_tensor_model_parallel_world_size() - # Bypass the function if we are using only 1 GPU. - if world_size == 1: - return input_ - - # Split along first dimension. 
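Taken together, the deleted `ColumnParallelLinear` / `RowParallelLinear` pair is the standard Megatron MLP layout: the column-parallel output shards feed the matching row-parallel input shards directly, so only the final output needs a single all-reduce. A single-process sketch that simulates the TP ranks in a loop and checks the partitioned result against the dense computation (sizes are made up; weights are stored `[out, in]` as in the deleted code):

```python
# Single-process sketch of the column-parallel -> row-parallel pairing: the
# output shards of the first GEMM feed the input shards of the second, so
# only one all-reduce (simulated by "+=") is needed per MLP block.
# TP degree and layer sizes below are made-up values.
import torch

torch.manual_seed(0)
tp, hidden, ffn, tokens = 4, 8, 32, 5
x = torch.randn(tokens, hidden)
w1 = torch.randn(ffn, hidden)                  # ColumnParallelLinear: split along dim 0 (output)
w2 = torch.randn(hidden, ffn)                  # RowParallelLinear: split along dim 1 (input)

reference = torch.relu(x @ w1.t()) @ w2.t()    # the unpartitioned computation

partial_sum = torch.zeros(tokens, hidden)
for rank in range(tp):
    w1_shard = w1.chunk(tp, dim=0)[rank]       # [ffn/tp, hidden]
    w2_shard = w2.chunk(tp, dim=1)[rank]       # [hidden, ffn/tp]
    y_shard = torch.relu(x @ w1_shard.t())     # stays sharded; no gather in between
    partial_sum += y_shard @ w2_shard.t()      # one all-reduce(SUM) in the real layers

print(torch.allclose(partial_sum, reference, atol=1e-4))   # True
```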
- dim_size = input_.size()[0] - assert ( - dim_size % world_size == 0 - ), "First dimension of the tensor should be divisible by tensor parallel size" - local_dim_size = dim_size // world_size - rank = get_tensor_model_parallel_rank() - dim_offset = rank * local_dim_size - - output = input_[dim_offset : dim_offset + local_dim_size].contiguous() - - return output - - -def _gather_along_last_dim(input_): - """Gather tensors and concatinate along the last dimension.""" - - world_size = get_tensor_model_parallel_world_size() - # Bypass the function if we are using only 1 GPU. - if world_size == 1: - return input_ - - # Size and dimension. - last_dim = input_.dim() - 1 - rank = get_tensor_model_parallel_rank() - - tensor_list = [torch.empty_like(input_) for _ in range(world_size)] - tensor_list[rank] = input_ - torch.distributed.all_gather(tensor_list, input_, group=get_tensor_model_parallel_group()) - - # Note: torch.cat already creates a contiguous tensor. - output = torch.cat(tensor_list, dim=last_dim).contiguous() - - return output - - -def _gather_along_first_dim(input_): - """Gather tensors and concatinate along the first dimension.""" - - world_size = get_tensor_model_parallel_world_size() - # Bypass the function if we are using only 1 GPU. - if world_size == 1: - return input_ - - dim_size = list(input_.size()) - dim_size[0] = dim_size[0] * world_size - - output = torch.empty(dim_size, dtype=input_.dtype, device=torch.cuda.current_device()) - torch.distributed._all_gather_base( - output, input_.contiguous(), group=get_tensor_model_parallel_group() - ) - - return output - - -def _reduce_scatter_along_first_dim(input_): - """Reduce-scatter the input tensor across model parallel group.""" - world_size = get_tensor_model_parallel_world_size() - # Bypass the function if we are using only 1 GPU. - if world_size == 1: - return input_ - - dim_size = list(input_.size()) - assert ( - dim_size[0] % world_size == 0 - ), "First dimension of the tensor should be divisible by tensor parallel size" - - dim_size[0] = dim_size[0] // world_size - - output = torch.empty(dim_size, dtype=input_.dtype, device=torch.cuda.current_device()) - torch.distributed._reduce_scatter_base( - output, input_.contiguous(), group=get_tensor_model_parallel_group() - ) - return output - - -def _gather_along_first_dim_moe(input_): - """Gather tensors and concatenate along the first dimension.""" - group = get_tensor_and_expert_parallel_group() - world_size = torch.distributed.get_world_size(group=group) - # Bypass the function if we are using only 1 GPU. - if world_size == 1: - return input_ - - dim_size = list(input_.size()) - dim_size[0] = dim_size[0] * world_size - - output = torch.empty(dim_size, dtype=input_.dtype, device=torch.cuda.current_device()) - torch.distributed._all_gather_base(output, input_.contiguous(), group=group) - - return output - - -def _reduce_scatter_along_first_dim_moe(input_): - """Reduce-scatter the input tensor across model parallel group.""" - group = get_tensor_and_expert_parallel_group() - world_size = torch.distributed.get_world_size(group=group) - # Bypass the function if we are using only 1 GPU. 
- if world_size == 1: - return input_ - - dim_size = list(input_.size()) - assert dim_size[0] % world_size == 0 - dim_size[0] = dim_size[0] // world_size - - output = torch.empty(dim_size, dtype=input_.dtype, device=torch.cuda.current_device()) - torch.distributed._reduce_scatter_base(output, input_.contiguous(), group=group) - return output - - -class _CopyToModelParallelRegion(torch.autograd.Function): - """Pass the input to the model parallel region.""" - - @staticmethod - def symbolic(graph, input_): - return input_ - - @staticmethod - def forward(ctx, input_): - return input_ - - @staticmethod - def backward(ctx, grad_output): - return _reduce(grad_output) - - -class _ReduceFromModelParallelRegion(torch.autograd.Function): - """All-reduce the input from the model parallel region.""" - - @staticmethod - def symbolic(graph, input_): - return _reduce(input_) - - @staticmethod - def forward(ctx, input_): - return _reduce(input_) - - @staticmethod - def backward(ctx, grad_output): - return grad_output - - -class _ScatterToModelParallelRegion(torch.autograd.Function): - """Split the input and keep only the corresponding chuck to the rank.""" - - @staticmethod - def symbolic(graph, input_): - return _split_along_last_dim(input_) - - @staticmethod - def forward(ctx, input_): - return _split_along_last_dim(input_) - - @staticmethod - def backward(ctx, grad_output): - return _gather_along_last_dim(grad_output) - - -class _GatherFromModelParallelRegion(torch.autograd.Function): - """Gather the input from model parallel region and concatinate.""" - - @staticmethod - def symbolic(graph, input_): - return _gather_along_last_dim(input_) - - @staticmethod - def forward(ctx, input_): - return _gather_along_last_dim(input_) - - @staticmethod - def backward(ctx, grad_output): - return _split_along_last_dim(grad_output) - - -class _ScatterToSequenceParallelRegion(torch.autograd.Function): - """Split the input and keep only the corresponding chuck to the rank.""" - - @staticmethod - def symbolic(graph, input_): - return _split_along_first_dim(input_) - - @staticmethod - def forward(ctx, input_): - return _split_along_first_dim(input_) - - @staticmethod - def backward(ctx, grad_output): - return _gather_along_first_dim(grad_output) - - -class _GatherFromSequenceParallelRegion(torch.autograd.Function): - """Gather the input from sequence parallel region and concatinate.""" - - @staticmethod - def symbolic(graph, input_, tensor_parallel_output_grad=True): - return _gather_along_first_dim(input_) - - @staticmethod - def forward(ctx, input_, tensor_parallel_output_grad=True): - ctx.tensor_parallel_output_grad = tensor_parallel_output_grad - return _gather_along_first_dim(input_) - - @staticmethod - def backward(ctx, grad_output): - tensor_parallel_output_grad = ctx.tensor_parallel_output_grad - - # If the computation graph after the gather operation is - # in the tensor parallel mode, output gradients need to reduce - # scattered and whereas if the computation is duplicated, - # output gradients need to be scattered. 
- if tensor_parallel_output_grad: - return _reduce_scatter_along_first_dim(grad_output), None - else: - return _split_along_first_dim(grad_output), None - - -class _ReduceScatterToSequenceParallelRegion(torch.autograd.Function): - """Reduce scatter the input from the model parallel region.""" - - @staticmethod - def symbolic(graph, input_): - return _reduce_scatter_along_first_dim(input_) - - @staticmethod - def forward(ctx, input_): - return _reduce_scatter_along_first_dim(input_) - - @staticmethod - def backward(ctx, grad_output): - return _gather_along_first_dim(grad_output) - - -class _GatherFromSequenceParallelRegionToMOE(torch.autograd.Function): - """Gather the input from model parallel region and concatenate.""" # TODO - - @staticmethod - def symbolic(graph, input_): - return _gather_along_first_dim_moe(input_) - - @staticmethod - def forward(ctx, input_): - return _gather_along_first_dim_moe(input_,) - - @staticmethod - def backward(ctx, grad_output): - return _reduce_scatter_along_first_dim_moe(grad_output) - - -class _ReduceScatterToSequenceParallelRegionFromMOE(torch.autograd.Function): - """Reduce scatter the input from the model parallel region.""" - - @staticmethod - def symbolic(graph, input_): - return _reduce_scatter_along_first_dim_moe(input_) - - @staticmethod - def forward(ctx, input_): - return _reduce_scatter_along_first_dim_moe(input_,) - - @staticmethod - def backward(ctx, grad_output): - return _gather_along_first_dim_moe(grad_output) - - -# ----------------- -# Helper functions. -# ----------------- - - -def copy_to_tensor_model_parallel_region(input_): - return _CopyToModelParallelRegion.apply(input_) - - -def reduce_from_tensor_model_parallel_region(input_): - return _ReduceFromModelParallelRegion.apply(input_) - - -def scatter_to_tensor_model_parallel_region(input_): - return _ScatterToModelParallelRegion.apply(input_) - - -def gather_from_tensor_model_parallel_region(input_): - return _GatherFromModelParallelRegion.apply(input_) - - -def scatter_to_sequence_parallel_region(input_): - return _ScatterToSequenceParallelRegion.apply(input_) - - -def gather_from_sequence_parallel_region(input_, tensor_parallel_output_grad=True): - return _GatherFromSequenceParallelRegion.apply(input_, tensor_parallel_output_grad) - - -def reduce_scatter_to_sequence_parallel_region(input_): - return _ReduceScatterToSequenceParallelRegion.apply(input_) - - -def gather_from_sequence_parallel_region_to_moe(input_): - return _GatherFromSequenceParallelRegionToMOE.apply(input_) - - -def reduce_scatter_to_sequence_parallel_region_from_moe(input_): - return _ReduceScatterToSequenceParallelRegionFromMOE.apply(input_) diff --git a/megatron/core/tensor_parallel/random.py b/megatron/core/tensor_parallel/random.py deleted file mode 100644 index 9d51b09f7e3a37042b555f10afd889b0e5930df2..0000000000000000000000000000000000000000 --- a/megatron/core/tensor_parallel/random.py +++ /dev/null @@ -1,269 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
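The mapping helpers removed above come in conjugate pairs: a forward split along the first (sequence) dimension pairs with an all-gather in backward, and a forward all-gather pairs with a reduce-scatter. A small sketch of the first-dimension split/gather round trip on one process (the tensor and TP degree are illustrative):

```python
# Sketch of the first-dimension (sequence) split used by the deleted
# scatter_to_sequence_parallel_region and its gather counterpart: each TP
# rank keeps one contiguous slice, and concatenating the slices recovers
# the original activation. The tensor and TP degree are illustrative.
import torch

tp = 4
activation = torch.arange(8 * 3).reshape(8, 3).float()     # [seq, hidden]

def split_along_first_dim(t, rank, world_size):
    assert t.size(0) % world_size == 0, "seq dim must be divisible by TP size"
    local = t.size(0) // world_size
    return t[rank * local:(rank + 1) * local].contiguous()

shards = [split_along_first_dim(activation, r, tp) for r in range(tp)]
gathered = torch.cat(shards, dim=0)                         # what the all-gather rebuilds

print(torch.equal(gathered, activation))                    # True
print(shards[2])                                            # rows 4..5 live on TP rank 2
```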
- -# Parts of the code here are adapted from PyTorch -# repo: https://github.com/pytorch/pytorch - -import contextlib - -import torch -from torch import _C -from torch.cuda import _lazy_call -from torch.cuda import device as device_ctx_manager -from torch.utils.checkpoint import detach_variable - -from megatron.core.parallel_state import ( - get_data_parallel_rank, - get_expert_model_parallel_rank, - get_tensor_model_parallel_group, - get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, -) -from megatron.core.utils import safely_set_viewless_tensor_data - -from .utils import gather_split_1d_tensor, split_tensor_into_1d_equal_chunks - -# Default name for the model parallel rng tracker. -_MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng' -_EXPERT_PARALLEL_RNG_TRACKER_NAME = 'expert-parallel-rng' -_DATA_PARALLEL_RNG_TRACKER_NAME = 'data-parallel-rng' - - -def _set_cuda_rng_state(new_state, device=-1): - """Sets the random number generator state of the current GPU. - - Argumentss: - new_state (torch.ByteTensor): The desired state - This function is adapted from PyTorch repo (torch.cuda.set_rng_state) - with a single change: the input state is not cloned. Cloning caused - major performance issues for +4 GPU cases. - """ - if hasattr(_C, '_cuda_setRNGState') and callable(_C._cuda_setRNGState): - # older PyTorch - def cb(): - with device_ctx_manager(device): - _C._cuda_setRNGState(new_state) - - else: - # newer PyTorch - if device == -1: - device = torch.device('cuda') - elif isinstance(device, str): - device = torch.device(device) - elif isinstance(device, int): - device = torch.device('cuda', device) - - def cb(): - idx = device.index - if idx is None: - idx = torch.cuda.current_device() - default_generator = torch.cuda.default_generators[idx] - default_generator.set_state(new_state) - - _lazy_call(cb) - - -def get_expert_parallel_rng_tracker_name(): - global _EXPERT_PARALLEL_RNG_TRACKER_NAME - return _EXPERT_PARALLEL_RNG_TRACKER_NAME - - -def get_data_parallel_rng_tracker_name(): - global _DATA_PARALLEL_RNG_TRACKER_NAME - return _DATA_PARALLEL_RNG_TRACKER_NAME - - -class CudaRNGStatesTracker: - """Tracker for the cuda RNG states. - - Using the `add` method, a cuda rng state is initialized based on - the input `seed` and is assigned to `name`. Later, by forking the - rng state, we can perform operations and return to our starting - cuda state. - """ - - def __init__(self): - # Map from a string name to the cuda rng state. - self.states_ = {} - # Seeds are just for book keeping and ensure no seed is set twice. - self.seeds_ = set() - - def reset(self): - """Set to the initial state (no tracker).""" - self.states_ = {} - self.seeds_ = set() - - def get_states(self): - """Get rng states. Copy the dictionary so we have direct - pointers to the states, not just a pointer to the dictionary.""" - states = {} - for name in self.states_: - states[name] = self.states_[name] - return states - - def set_states(self, states): - """Set the rng states. For efficiency purposes, we do not check - the size of seed for compatibility.""" - self.states_ = states - - def add(self, name, seed): - """Track the rng state.""" - # Check seed is not already used. - if seed in self.seeds_: - raise Exception('seed {} already exists'.format(seed)) - self.seeds_.add(seed) - # Check that state is not already defined. - if name in self.states_: - raise Exception('cuda rng state {} already exists'.format(name)) - # Get the current rng state. 
- orig_rng_state = torch.cuda.get_rng_state() - # Set the new state and store it. - torch.cuda.manual_seed(seed) - self.states_[name] = torch.cuda.get_rng_state() - # Reset rng state to what it was. - _set_cuda_rng_state(orig_rng_state) - - @contextlib.contextmanager - def fork(self, name=_MODEL_PARALLEL_RNG_TRACKER_NAME): - """Fork the cuda rng state, perform operations, and exit with - the original state.""" - # Check if we have added the state - if name not in self.states_: - raise Exception('cuda rng state {} is not added'.format(name)) - # Store current rng state. - orig_cuda_rng_state = torch.cuda.get_rng_state() - # Set rng state to the desired one - _set_cuda_rng_state(self.states_[name]) - # Do the stuff we wanted to do. - try: - yield - finally: - # Update the current rng state for later use. - self.states_[name] = torch.cuda.get_rng_state() - # And set the state to the original state we started with. - _set_cuda_rng_state(orig_cuda_rng_state) - - -# RNG tracker object. -_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker() - - -def get_cuda_rng_tracker(): - """Get cuda rng tracker.""" - return _CUDA_RNG_STATE_TRACKER - - -def model_parallel_cuda_manual_seed(seed): - """Initialize model parallel cuda seed. - - This function should be called after the model parallel is - initialized. Also, no torch.cuda.manual_seed should be called - after this function. Basically, this is replacement for that - function. - Two set of RNG states are tracked: - default state: This is for data parallelism and is the same among a - set of model parallel GPUs but different across - different model paralle groups. This is used for - example for dropout in the non-tensor-model-parallel regions. - tensor-model-parallel state: This state is different among a set of model - parallel GPUs, but the same across data parallel - groups. This is used for example for dropout in - model parallel regions. - """ - # 2718 is just for fun and any POSITIVE value will work. - offset = seed + 2718 - tensor_model_parallel_seed = offset + get_tensor_model_parallel_rank() - # Data parallel gets the original seed. - data_parallel_seed = seed - - _CUDA_RNG_STATE_TRACKER.reset() - # Set the default state. - torch.cuda.manual_seed(data_parallel_seed) - _CUDA_RNG_STATE_TRACKER.add(_DATA_PARALLEL_RNG_TRACKER_NAME, data_parallel_seed) - - # and model parallel state. - _CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME, tensor_model_parallel_seed) - - expert_parallel_seed = ( - seed + 1024 + 100 * get_expert_model_parallel_rank() + get_tensor_model_parallel_rank() - ) - _CUDA_RNG_STATE_TRACKER.add(_EXPERT_PARALLEL_RNG_TRACKER_NAME, expert_parallel_seed) - - -class CheckpointFunction(torch.autograd.Function): - """This function is adapted from torch.utils.checkpoint with - two main changes: - 1) torch.cuda.set_rng_state is replaced with `_set_cuda_rng_state` - 2) the states in the model parallel tracker are also properly - tracked/set/reset. - """ - - @staticmethod - def forward(ctx, run_function, distribute_saved_activations, *args): - ctx.run_function = run_function - ctx.distribute_saved_activations = distribute_saved_activations - - # Copy the rng states. - ctx.fwd_cpu_rng_state = torch.get_rng_state() - ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state() - ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states() - - with torch.no_grad(): - outputs = run_function(*args) - - # Divide hidden states across model parallel group and only keep - # the chunk corresponding to the current rank. 
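The tracker deleted here keeps one saved CUDA RNG state per name and swaps it in around a region of work, which is how dropout in tensor-parallel regions draws from a different random stream than dropout in data-parallel regions while remaining reproducible. A rough CPU analogue of the same bookkeeping, using the global torch generator instead of the per-device CUDA generator, is sketched below; it illustrates only the add/fork pattern and is not the removed class.

```python
# Rough CPU analogue of the deleted CudaRNGStatesTracker: one saved RNG state
# per name, swapped in around a region and saved back on exit.
import contextlib

import torch


class RNGStatesTracker:
    def __init__(self):
        self.states = {}

    def add(self, name, seed):
        # Seed a fresh state for `name` without disturbing the current state.
        orig = torch.get_rng_state()
        torch.manual_seed(seed)
        self.states[name] = torch.get_rng_state()
        torch.set_rng_state(orig)

    @contextlib.contextmanager
    def fork(self, name):
        orig = torch.get_rng_state()
        torch.set_rng_state(self.states[name])
        try:
            yield
        finally:
            self.states[name] = torch.get_rng_state()  # remember where this stream left off
            torch.set_rng_state(orig)


tracker = RNGStatesTracker()
tracker.add("model-parallel-rng", 1234)
with tracker.fork("model-parallel-rng"):
    x = torch.rand(2)  # drawn from the named stream, not the default one
```

The `CheckpointFunction` that follows snapshots exactly these states before the forward pass, so the recomputation during backward reproduces the same dropout masks.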
- if distribute_saved_activations: - ctx.input_0_shape = args[0].data.shape - safely_set_viewless_tensor_data( - args[0], split_tensor_into_1d_equal_chunks(args[0].data, new_buffer=True) - ) - - # Store everything. - ctx.save_for_backward(*args) - - return outputs - - @staticmethod - def backward(ctx, *args): - if not torch.autograd._is_checkpoint_valid(): - raise RuntimeError( - "Checkpointing is not compatible with .grad(), " - "please use .backward() if possible" - ) - inputs = ctx.saved_tensors - if ctx.distribute_saved_activations: - safely_set_viewless_tensor_data( - inputs[0], gather_split_1d_tensor(inputs[0].data).view(ctx.input_0_shape) - ) - - # Store the current states. - bwd_cpu_rng_state = torch.get_rng_state() - bwd_cuda_rng_state = torch.cuda.get_rng_state() - bwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states() - - # Set the states to what it used to be before the forward pass. - torch.set_rng_state(ctx.fwd_cpu_rng_state) - _set_cuda_rng_state(ctx.fwd_cuda_rng_state) - get_cuda_rng_tracker().set_states(ctx.fwd_cuda_rng_state_tracker) - - # Compute the forward pass. - detached_inputs = detach_variable(inputs) - with torch.enable_grad(): - outputs = ctx.run_function(*detached_inputs) - - # Set the states back to what it was at the start of this function. - torch.set_rng_state(bwd_cpu_rng_state) - _set_cuda_rng_state(bwd_cuda_rng_state) - get_cuda_rng_tracker().set_states(bwd_cuda_rng_state_tracker) - - if isinstance(outputs, torch.Tensor): - outputs = (outputs,) - torch.autograd.backward(outputs, args) - grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp for inp in detached_inputs) - return (None, None) + grads - - -def checkpoint(function, distribute_saved_activations, *args): - """Checkpoint a model or part of the model. - This has been directly copied from torch.utils.checkpoint.""" - return CheckpointFunction.apply(function, distribute_saved_activations, *args) diff --git a/megatron/core/tensor_parallel/utils.py b/megatron/core/tensor_parallel/utils.py deleted file mode 100644 index a79ae1e87e33463ab6d54ffef8c73ad214cff3ef..0000000000000000000000000000000000000000 --- a/megatron/core/tensor_parallel/utils.py +++ /dev/null @@ -1,113 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. - -from typing import List, Sequence - -import torch - -from megatron.core import parallel_state -from megatron.core.utils import divide - - -def split_tensor_along_last_dim( - tensor: torch.Tensor, num_partitions: int, contiguous_split_chunks: bool = False, -) -> List[torch.Tensor]: - """ Split a tensor along its last dimension. - - Arguments: - tensor: input tensor. - num_partitions: number of partitions to split the tensor - contiguous_split_chunks: If True, make each chunk contiguous - in memory. - - Returns: - A list of Tensors - """ - # Get the size and dimension. - last_dim = tensor.dim() - 1 - last_dim_size = divide(tensor.size()[last_dim], num_partitions) - # Split. - tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) - # Note: torch.split does not create contiguous tensors by default. - if contiguous_split_chunks: - return tuple(chunk.contiguous() for chunk in tensor_list) - - return tensor_list - - -def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False): - """ Break a tensor into equal 1D chunks across tensor parallel ranks. - - Returns a Tensor or View with this rank's portion of the data. - - Arguments: - tensor: The tensor to split - - Keyword Arguments: - new_buffer (bool): If True, returns a new Tensor. 
- If False, returns a view into the existing Tensor. - Default is False - - """ - partition_size = torch.numel(tensor) // parallel_state.get_tensor_model_parallel_world_size() - start_index = partition_size * parallel_state.get_tensor_model_parallel_rank() - end_index = start_index + partition_size - if new_buffer: - data = torch.empty( - partition_size, - dtype=tensor.dtype, - device=torch.cuda.current_device(), - requires_grad=False, - ) - data.copy_(tensor.view(-1)[start_index:end_index]) - else: - data = tensor.view(-1)[start_index:end_index] - return data - - -def gather_split_1d_tensor(tensor): - """ Opposite of split_tensor_into_1d_equal_chunks. Gather values from tensor - model parallel ranks. - - Returns a new Tensor with the gathered data. - - Arguments: - tensor: A Tensor or view of this rank's portion of the data. - """ - numel_gathered = torch.numel(tensor) * parallel_state.get_tensor_model_parallel_world_size() - gathered = torch.empty( - numel_gathered, dtype=tensor.dtype, device=torch.cuda.current_device(), requires_grad=False - ) - # TODO: This API is experimental in pytorch (as of Feb 2022) and - # this might break in future pytorch releases. We chose this API - # as opposed to torch.distributed.all_gather for efficiency reasons. - # This API calls directly NCCL all-gather versus the former does - # internal copies and can potentially cause slow down. - torch.distributed._all_gather_base( - gathered, tensor, group=parallel_state.get_tensor_model_parallel_group() - ) - return gathered - - -class VocabUtility: - """ Split the vocabulary into `world_size` chunks and return the first - and last index of the vocabulary belonging to the `rank` - partition: Note that indices in [fist, last) - - """ - - @staticmethod - def vocab_range_from_per_partition_vocab_size( - per_partition_vocab_size: int, rank, world_size: int - ) -> Sequence[int]: - index_f = rank * per_partition_vocab_size - index_l = index_f + per_partition_vocab_size - return index_f, index_l - - @staticmethod - def vocab_range_from_global_vocab_size( - global_vocab_size: int, rank: int, world_size: int - ) -> Sequence[int]: - per_partition_vocab_size = divide(global_vocab_size, world_size) - return VocabUtility.vocab_range_from_per_partition_vocab_size( - per_partition_vocab_size, rank, world_size - ) diff --git a/megatron/core/transformer/__init__.py b/megatron/core/transformer/__init__.py deleted file mode 100644 index 7cc10776b7459542eb35ecb5e768dbef9bd54d05..0000000000000000000000000000000000000000 --- a/megatron/core/transformer/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - -from .module import MegatronModule -from .spec_utils import ModuleSpec, build_module -from .transformer_config import TransformerConfig -from .transformer_layer import TransformerLayer, TransformerLayerSubmodules diff --git a/megatron/core/transformer/attention.py b/megatron/core/transformer/attention.py deleted file mode 100644 index c725c7f3a20b9276b8d58987ea5c47cce136a8d0..0000000000000000000000000000000000000000 --- a/megatron/core/transformer/attention.py +++ /dev/null @@ -1,443 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 
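`VocabUtility` in the utils file removed above reduces to simple integer arithmetic: each tensor-parallel rank owns a contiguous half-open slice `[first, last)` of the vocabulary. A worked example with hypothetical sizes (a 50304-token vocabulary over 8 ranks) makes the indexing concrete.

```python
# Worked example of the vocab partitioning removed above: a 50304-token
# vocabulary split across a tensor-parallel group of 8 ranks.
def vocab_range(global_vocab_size: int, rank: int, world_size: int):
    assert global_vocab_size % world_size == 0  # `divide` in the original enforces this
    per_partition = global_vocab_size // world_size
    first = rank * per_partition
    last = first + per_partition
    return first, last  # half-open interval [first, last)


for rank in range(8):
    print(rank, vocab_range(50304, rank, 8))
# rank 0 owns [0, 6288), rank 1 owns [6288, 12576), ..., rank 7 owns [44016, 50304)
```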
- -from abc import ABC, abstractmethod -from dataclasses import dataclass -from typing import Union - -import torch - -from megatron.core import parallel_state, tensor_parallel -from megatron.core.models.common.embeddings.rotary_pos_embedding import apply_rotary_pos_emb -from megatron.core.transformer.enums import AttnMaskType -from megatron.core.transformer.identity_op import IdentityFuncOp, IdentityOp -from megatron.core.transformer.module import MegatronModule -from megatron.core.transformer.spec_utils import ModuleSpec, build_module -from megatron.core.transformer.transformer_config import TransformerConfig -from megatron.core.utils import divide - -from .enums import AttnMaskType -from .transformer_config import TransformerConfig -from .utils import make_sharded_tensors_for_checkpoint - - -@dataclass -class SelfAttentionSubmodules: - linear_qkv: Union[ModuleSpec, type] = None - core_attention: Union[ModuleSpec, type] = None - linear_proj: Union[ModuleSpec, type] = None - - -@dataclass -class CrossAttentionSubmodules: - linear_q: Union[ModuleSpec, type] = None - linear_kv: Union[ModuleSpec, type] = None - core_attention: Union[ModuleSpec, type] = None - linear_proj: Union[ModuleSpec, type] = None - - -class Attention(MegatronModule, ABC): - """Attention layer abstract class. - - This layer only contains common modules required for the "self attn" and - "cross attn" specializations. - """ - - def __init__( - self, - config: TransformerConfig, - submodules: Union[SelfAttentionSubmodules, CrossAttentionSubmodules], - layer_number: int, - attn_mask_type: AttnMaskType, - attention_type: str, - ): - super().__init__(config=config) - - self.config = config - self.layer_number = layer_number - self.attn_mask_type = attn_mask_type - self.attention_type = attention_type - - # For normal attention without groups, num_query_groups == num_attention_heads, - # so these two will be the same - self.query_projection_size = self.config.kv_channels * self.config.num_attention_heads - self.kv_projection_size = self.config.kv_channels * self.config.num_query_groups - - # Per attention head and per partition values. - world_size = parallel_state.get_tensor_model_parallel_world_size() - self.hidden_size_per_attention_head = divide( - self.query_projection_size, self.config.num_attention_heads - ) - self.num_attention_heads_per_partition = divide(self.config.num_attention_heads, world_size) - self.num_query_groups_per_partition = divide(self.config.num_query_groups, world_size) - - self.core_attention = build_module( - submodules.core_attention, - config=self.config, - layer_number=self.layer_number, - attn_mask_type=self.attn_mask_type, - attention_type=self.attention_type, - ) - - self.checkpoint_core_attention = self.config.recompute_granularity == 'selective' - - # Output. 
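The `__init__` above is largely integer bookkeeping: projection widths are derived from `kv_channels`, and attention heads and query groups are divided evenly across the tensor-parallel group. The worked example below uses hypothetical sizes (32 heads, 8 query groups, `kv_channels` 128, TP world size 4) to show the resulting per-rank values, including the fused QKV width that `SelfAttention` later relies on.

```python
# Hypothetical sizes, only to illustrate the partition arithmetic above.
kv_channels = 128           # per-head dimension (hn)
num_attention_heads = 32    # np
num_query_groups = 8        # ng (grouped-query attention)
tp_world_size = 4

query_projection_size = kv_channels * num_attention_heads  # 4096
kv_projection_size = kv_channels * num_query_groups        # 1024

hidden_size_per_attention_head = query_projection_size // num_attention_heads  # 128
heads_per_partition = num_attention_heads // tp_world_size                     # 8
groups_per_partition = num_query_groups // tp_world_size                       # 2

# Width of the fused QKV projection on each rank: ng * (np/ng + 2) * hn
fused_qkv_width = (
    groups_per_partition
    * (heads_per_partition // groups_per_partition + 2)
    * kv_channels
)
print(fused_qkv_width)  # 2 * (4 + 2) * 128 = 1536, i.e. (4096 + 2*1024) / 4
```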
- self.linear_proj = build_module( - submodules.linear_proj, - self.query_projection_size, - self.config.hidden_size, - config=self.config, - init_method=self.config.output_layer_init_method, - bias=self.config.add_bias_linear, - input_is_parallel=True, - skip_bias_add=True, - is_expert=False, - tp_comm_buffer_name='proj', - ) - - def _checkpointed_attention_forward( - self, query, key, value, attention_mask, rotary_pos_emb=None, attn_mask_type=None - ): - """Forward method with selective activation checkpointing.""" - - def custom_forward(*inputs): - query = inputs[0] - key = inputs[1] - value = inputs[2] - attention_mask = inputs[3] - attn_mask_type = inputs[5] - attn_mask_type = AttnMaskType(attn_mask_type.item()) - output_ = self.core_attention( - query, key, value, attention_mask, attn_mask_type=attn_mask_type - ) - return output_ - - if attn_mask_type is None: - attn_mask_type = self.attn_mask_type - attn_mask_type = torch.tensor([attn_mask_type.value], dtype=torch.int) - hidden_states = tensor_parallel.checkpoint( - custom_forward, False, query, key, value, attention_mask, rotary_pos_emb, attn_mask_type - ) - - return hidden_states - - def _allocate_memory(self, inference_max_sequence_length, batch_size, dtype): - """Allocate memory to store kv cache during inference.""" - - return torch.empty( - inference_max_sequence_length, - batch_size, - self.num_query_groups_per_partition, - self.hidden_size_per_attention_head, - dtype=dtype, - device=torch.cuda.current_device(), - ) - - def _adjust_key_value_for_inference(self, inference_params, key, value, rotary_pos_emb): - """ - Saves the generated key and value tensors to the end of the buffers in inference_params. - Returns the full size keys and values from the provided inference_params, as well as - adjusted rotary_pos_emb. - - Returns a tuple: (key, value, rotary_pos_emb) - - """ - attn_mask_type = self.attn_mask_type - if inference_params is None: - return key, value, rotary_pos_emb, attn_mask_type - - # ================================================= - # Pre-allocate memory for key-values for inference. - # ================================================= - is_first_step = False - if self.layer_number not in inference_params.key_value_memory_dict: - inf_max_seq_length = inference_params.max_sequence_length - inf_max_batch_size = inference_params.max_batch_size - inference_key_memory = self._allocate_memory( - inf_max_seq_length, inf_max_batch_size, key.dtype - ) - inference_value_memory = self._allocate_memory( - inf_max_seq_length, inf_max_batch_size, value.dtype - ) - inference_params.key_value_memory_dict[self.layer_number] = ( - inference_key_memory, - inference_value_memory, - ) - is_first_step = True - else: - # Get the pre-allocated buffers for this layer - inference_key_memory, inference_value_memory = inference_params.key_value_memory_dict[ - self.layer_number - ] - attn_mask_type = AttnMaskType.no_mask - - batch_start = inference_params.batch_size_offset - batch_end = batch_start + key.size(1) - assert batch_end <= inference_key_memory.size(1) - sequence_start = inference_params.sequence_len_offset - sequence_end = sequence_start + key.size(0) - assert sequence_end <= inference_key_memory.size(0) - # Copy key and values. - inference_key_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = key - inference_value_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = value - key = inference_key_memory[:sequence_end, batch_start:batch_end, ...] 
- value = inference_value_memory[:sequence_end, batch_start:batch_end, ...] - - # adjust the key rotary positional embedding - if rotary_pos_emb is not None: - q_pos_emb, k_pos_emb = rotary_pos_emb - # need to cross check this condition during inference - # if not set_inference_key_value_memory: - if not is_first_step: - # In inference, we compute one token at a time. - # Select the correct positional embedding - # (only the last token in the sequence) - q_pos_emb = q_pos_emb[sequence_end - 1 : sequence_end] - else: - # In the first forward pass of inference, - # we use the entire provided prefix. - # q_pos_emb here has the rope embeddings of the entire - # prefix + to-be-generated output so - # we slice to just the prefix. - q_pos_emb = q_pos_emb[:sequence_end, :, :, :] - k_pos_emb = k_pos_emb[:sequence_end, :, :, :] - rotary_pos_emb = (q_pos_emb, k_pos_emb) - - return key, value, rotary_pos_emb, attn_mask_type - - @abstractmethod - def get_query_key_value_tensors(self, hidden_states, key_value_states): - """ - This method needs to be implemented based on whether the derived class - is "self-attn" or "cross-attn". - """ - - def forward( - self, - hidden_states, - attention_mask, - key_value_states=None, - inference_params=None, - rotary_pos_emb=None, - ): - # hidden_states: [sq, b, h] - - # For self attention we just duplicate the rotary_pos_emb if it isn't already - if rotary_pos_emb is not None and not isinstance(rotary_pos_emb, tuple): - rotary_pos_emb = (rotary_pos_emb,) * 2 - - # ===================== - # Query, Key, and Value - # ===================== - # Get the query, key and value tensors based on the type of attention - - # self or cross attn. - query, key, value = self.get_query_key_value_tensors(hidden_states, key_value_states) - - # =================================================== - # Adjust key, value, and rotary_pos_emb for inference - # =================================================== - key, value, rotary_pos_emb, attn_mask_type = self._adjust_key_value_for_inference( - inference_params, key, value, rotary_pos_emb - ) - - # ================================================ - # relative positional embedding (rotary embedding) - # ================================================ - if rotary_pos_emb is not None: - q_pos_emb, k_pos_emb = rotary_pos_emb - query = apply_rotary_pos_emb(query, q_pos_emb) - key = apply_rotary_pos_emb(key, k_pos_emb) - # TODO, can apply positional embedding to value_layer so it has - # absolute positional embedding. - # otherwise, only relative positional embedding takes effect - # value_layer = apply_rotary_pos_emb(value_layer, k_pos_emb) - - # ================================== - # core attention computation - # ================================== - - if self.checkpoint_core_attention: - core_attn_out = self._checkpointed_attention_forward( - query, key, value, attention_mask, attn_mask_type=attn_mask_type - ) - else: - core_attn_out = self.core_attention( - query, key, value, attention_mask, attn_mask_type=attn_mask_type - ) - - # ================= - # Output. [sq, b, h] - # ================= - - output, bias = self.linear_proj(core_attn_out) - - return output, bias - - -class SelfAttention(Attention): - """Self-attention layer class - - Self-attention layer takes input with size [s, b, h] - and returns output of the same size. 
- """ - - def __init__( - self, - config: TransformerConfig, - submodules: SelfAttentionSubmodules, - layer_number: int, - attn_mask_type=AttnMaskType.padding, - ): - super().__init__( - config=config, - submodules=submodules, - layer_number=layer_number, - attn_mask_type=attn_mask_type, - attention_type="self", - ) - - self.linear_qkv = build_module( - submodules.linear_qkv, - self.config.hidden_size, - self.query_projection_size + 2 * self.kv_projection_size, - config=self.config, - init_method=self.config.init_method, - gather_output=False, - bias=self.config.add_bias_linear, - skip_bias_add=False, - is_expert=False, - tp_comm_buffer_name='qkv', - ) - - def get_query_key_value_tensors(self, hidden_states, key_value_states=None): - """ - Derives `query`, `key` and `value` tensors from `hidden_states`. - """ - # Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn)] - mixed_qkv, _ = self.linear_qkv(hidden_states) - - # [sq, b, hp] --> [sq, b, ng, (np/ng + 2) * hn] - new_tensor_shape = mixed_qkv.size()[:-1] + ( - self.num_query_groups_per_partition, - ( - (self.num_attention_heads_per_partition // self.num_query_groups_per_partition + 2) - * self.hidden_size_per_attention_head - ), - ) - mixed_qkv = mixed_qkv.view(*new_tensor_shape) - - # [sq, b, ng, (np/ng + 2) * hn] --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn] - (query, key, value) = torch.split( - mixed_qkv, - [ - ( - self.num_attention_heads_per_partition - // self.num_query_groups_per_partition - * self.hidden_size_per_attention_head - ), - self.hidden_size_per_attention_head, - self.hidden_size_per_attention_head, - ], - dim=3, - ) - # [sq, b, ng, np/ng * hn] -> [sq, b, np, hn] - query = query.reshape(query.size(0), query.size(1), -1, self.hidden_size_per_attention_head) - - return query, key, value - - def sharded_state_dict(self, prefix='', sharded_key_prefix=None, sharded_offsets=()): - sharded_key_prefix = prefix if sharded_key_prefix is None else sharded_key_prefix - sharded_state_dict = {} - for name, module in ( - ('linear_qkv', self.linear_qkv), - ('linear_proj', self.linear_proj), - ): - sub_sd = module.sharded_state_dict( - prefix=f'{prefix}{name}.', - sharded_key_prefix=f'{sharded_key_prefix}{name}.', - sharded_offsets=sharded_offsets, - ) - sharded_state_dict.update(sub_sd) - return sharded_state_dict - - -class CrossAttention(Attention): - """Cross-attention layer class - - Cross-attention layer takes input with size [s, b, h] and context with size - [s, b, h] and returns output of the same size. - """ - - def __init__( - self, - config: TransformerConfig, - submodules: CrossAttentionSubmodules, - layer_number: int, - attn_mask_type=AttnMaskType.padding, - ): - super().__init__( - config=config, - submodules=submodules, - layer_number=layer_number, - attn_mask_type=attn_mask_type, - attention_type="cross", - ) - - if self.config.num_query_groups != self.config.num_attention_heads: - raise ValueError( - f"Group query attention is not currently supported in cross attention." 
- ) - assert self.query_projection_size == self.kv_projection_size - - self.linear_q = build_module( - submodules.linear_q, - self.config.hidden_size, - self.query_projection_size, - config=self.config, - init_method=self.config.init_method, - gather_output=False, - bias=self.config.add_bias_linear, - skip_bias_add=False, - is_expert=False, - ) - - self.linear_kv = build_module( - submodules.linear_kv, - self.config.hidden_size, - 2 * self.kv_projection_size, - config=self.config, - init_method=self.config.init_method, - gather_output=False, - bias=self.config.add_bias_linear, - skip_bias_add=False, - is_expert=False, - ) - - def get_query_key_value_tensors(self, hidden_states, key_value_states): - """ - Derives `query` tensor from `hidden_states`, and `key`/`value` tensors - from `key_value_states`. - """ - # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)] - mixed_kv, _ = self.linear_kv(key_value_states) - - # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn] - new_tensor_shape = mixed_kv.size()[:-1] + ( - self.num_attention_heads_per_partition, - 2 * self.hidden_size_per_attention_head, - ) - mixed_kv = mixed_kv.view(*new_tensor_shape) - - # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn] - (key, value) = tensor_parallel.split_tensor_along_last_dim(mixed_kv, 2) - - # Attention head [sq, b, h] --> [sq, b, hp] - query, _ = self.linear_q(hidden_states) - - # [sq, b, hp] --> [sq, b, np, hn] - new_tensor_shape = query.size()[:-1] + ( - self.num_attention_heads_per_partition, - self.hidden_size_per_attention_head, - ) - query = query.view(*new_tensor_shape) - - return query, key, value diff --git a/megatron/core/transformer/custom_layers/__init__.py b/megatron/core/transformer/custom_layers/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/megatron/core/transformer/custom_layers/transformer_engine.py b/megatron/core/transformer/custom_layers/transformer_engine.py deleted file mode 100644 index d784184623a877480b4f945ebf6e44661e2fc1a3..0000000000000000000000000000000000000000 --- a/megatron/core/transformer/custom_layers/transformer_engine.py +++ /dev/null @@ -1,431 +0,0 @@ -import os -from importlib.metadata import version -from typing import Callable - -import torch -import transformer_engine as te -from pkg_resources import packaging -from torch import Tensor - -from megatron.core import ModelParallelConfig -from megatron.core.parallel_state import ( - get_context_parallel_global_ranks, - get_context_parallel_group, - get_tensor_model_parallel_group, -) -from megatron.core.tensor_parallel import get_cuda_rng_tracker -from megatron.core.transformer.enums import AttnMaskType -from megatron.core.transformer.transformer_config import TransformerConfig -from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint - - -def _get_extra_te_kwargs(config: TransformerConfig): - extra_transformer_engine_kwargs = { - "params_dtype": config.params_dtype, - } - - te_version = packaging.version.Version(version("transformer-engine")) - if te_version >= packaging.version.Version("0.12.0"): - if config.use_cpu_initialization: - extra_transformer_engine_kwargs["device"] = 'cpu' - else: - extra_transformer_engine_kwargs["device"] = torch.cuda.current_device() - return extra_transformer_engine_kwargs - - -class TENorm: - """ - A conditional wrapper to initialize an instance of Transformer-Engine's - `LayerNorm` or `RMSNorm` based on input - """ - - # TODO should we ditch normalization config and just use spec 
to choose LayerNorm vs RMSNorm? - def __new__( - cls, config: TransformerConfig, hidden_size: int, eps: float = 1e-5, - ): - if config.normalization == "LayerNorm": - instance = te.pytorch.LayerNorm( - hidden_size=hidden_size, - eps=eps, - sequence_parallel=config.sequence_parallel, - zero_centered_gamma=config.layernorm_zero_centered_gamma, - **_get_extra_te_kwargs(config), - ) - elif config.normalization == "RMSNorm": - assert hasattr( - te.pytorch, "RMSNorm" - ), "Transformer-Engine >= v0.11 required to use this feature" - instance = te.pytorch.RMSNorm( - hidden_size=hidden_size, - eps=eps, - sequence_parallel=config.sequence_parallel, - zero_centered_gamma=config.layernorm_zero_centered_gamma, - **_get_extra_te_kwargs(config), - ) - else: - raise Exception('Only LayerNorm and RMSNorm are curently supported') - - return instance - - -class TELinear(te.pytorch.Linear): - """ - Wrapper for the Transformer-Engine's `Linear` layer. - - Note that if Megatron's parallel_state has not been initialized - yet, the tp_group passed to TE will be None and must be set later - via set_tensor_parallel_group(). - """ - - def __init__( - self, - input_size: int, - output_size: int, - *, - parallel_mode: str, - config: ModelParallelConfig, - init_method: Callable, - bias: bool, - skip_bias_add: bool, - skip_weight_param_allocation: bool, - tp_comm_buffer_name: str = None, - ): - self.config = config - - # TE returns a zero length Tensor when bias=False and - # return_bias=True, but we prefer None. So in that case we - # tell TE to not return the bias, and return None - # ourselves. This way our forward always returns two values - # and we don't have to deal with the zero length Tensor. - self.te_return_bias = skip_bias_add and bias - - if skip_weight_param_allocation: - raise ValueError( - 'Transformer Engine linear layers do not support skip_weight_param_allocation' - ) - - extra_kwargs = _get_extra_te_kwargs(config) - - te_version = packaging.version.Version(version("transformer-engine")) - if te_version >= packaging.version.Version("0.8.0"): - if self.config.tp_comm_overlap: - extra_kwargs["ub_split_ag"] = self.config.tp_comm_split_ag - extra_kwargs["ub_split_rs"] = self.config.tp_comm_split_rs - if te_version > packaging.version.Version("1.0.0"): - assert ( - tp_comm_buffer_name is not None - ), "Buffer name should be set to configure communication overlap settings" - extra_kwargs["ub_name"] = tp_comm_buffer_name - - super().__init__( - in_features=input_size, - out_features=output_size, - sequence_parallel=self.config.sequence_parallel, - fuse_wgrad_accumulation=self.config.gradient_accumulation_fusion, - tp_group=get_tensor_model_parallel_group(check_initialized=False), - tp_size=self.config.tensor_model_parallel_size, - get_rng_state_tracker=get_cuda_rng_tracker, - init_method=init_method, - bias=bias, - return_bias=self.te_return_bias, - parallel_mode=parallel_mode, - **extra_kwargs, - ) - - def forward(self, x): - out = super().forward(x) - - # TE only returns a tuple when return_bias is True, otherwise - # it returns a single Tensor, we always want to return two - # values regardless of the arguments. 
- if self.te_return_bias: - return out - return out, None - - -class TELayerNormColumnParallelLinear(te.pytorch.LayerNormLinear): - """ - Wrapper for the Transformer-Engine's `LayerNormLinear` layer that combines - layernorm and linear layers - """ - - def __init__( - self, - input_size: int, - output_size: int, - *, - config: TransformerConfig, - init_method: Callable, - gather_output: bool, - bias: bool, - skip_bias_add: bool, - is_expert: bool, - skip_weight_param_allocation: bool = False, - tp_comm_buffer_name: str = None, - ): - self.config = config - - if gather_output: - raise ValueError('Transformer Engine linear layers do not support gather_output = True') - - if is_expert: - raise ValueError('Transformer Engine linear layers do not yet support MoE') - - if skip_weight_param_allocation: - raise ValueError( - 'Transformer Engine linear layers do not support skip_weight_param_allocation' - ) - - # TE returns a zero length Tensor when bias=False and - # return_bias=True, but we prefer None. So in that case we - # tell TE to not return the bias, and return None - # ourselves. This way our forward always returns two values - # and we don't have to deal with the zero length Tensor. - self.te_return_bias = skip_bias_add and bias - - extra_kwargs = _get_extra_te_kwargs(config) - - # Only Transformer-Engine version >= 0.11.0 supports `RMSNorm` - te_version = packaging.version.Version(version("transformer-engine")) - if te_version >= packaging.version.Version("0.11.0"): - extra_kwargs["normalization"] = self.config.normalization - elif self.config.normalization != "LayerNorm": - raise ValueError( - f"Transformer Engine v{te_version} does not support {self.config.normalization}." - ) - - if te_version >= packaging.version.Version("0.8.0"): - if self.config.tp_comm_overlap: - extra_kwargs["ub_bulk_wgrad"] = self.config.tp_comm_bulk_wgrad - extra_kwargs["ub_bulk_dgrad"] = self.config.tp_comm_bulk_dgrad - extra_kwargs["ub_split_ag"] = self.config.tp_comm_split_ag - if te_version > packaging.version.Version("1.0.0"): - assert ( - tp_comm_buffer_name is not None - ), "Buffer name should be set to configure communication overlap settings" - extra_kwargs["ub_name"] = tp_comm_buffer_name - - super().__init__( - in_features=input_size, - out_features=output_size, - eps=self.config.layernorm_epsilon, - sequence_parallel=self.config.sequence_parallel, - fuse_wgrad_accumulation=self.config.gradient_accumulation_fusion, - tp_group=get_tensor_model_parallel_group(check_initialized=False), - tp_size=self.config.tensor_model_parallel_size, - get_rng_state_tracker=get_cuda_rng_tracker, - init_method=init_method, - bias=bias, - return_bias=self.te_return_bias, - parallel_mode="column", - return_layernorm_output=False, - zero_centered_gamma=self.config.layernorm_zero_centered_gamma, - **extra_kwargs, - ) - - def forward(self, x): - out = super().forward(x) - - # TE only returns a tuple when return_bias is True, otherwise - # it returns a single Tensor, we always want to return two - # values regardless of the arguments. 
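A recurring pattern in these Transformer Engine wrappers is version gating: extra constructor kwargs are only added when the installed `transformer-engine` package is new enough to accept them, and older versions raise an explicit error instead of silently ignoring a feature. The sketch below condenses the `_get_extra_te_kwargs` helper removed above into its essential shape; it assumes `transformer-engine` and `packaging` are installed and simplifies the device selection.

```python
# Sketch of version-gated kwargs, as used by the wrappers above: only pass a
# constructor argument when the installed Transformer Engine understands it.
from importlib.metadata import version as installed_version

from packaging import version


def build_extra_kwargs(config):
    extra = {"params_dtype": getattr(config, "params_dtype", None)}
    te_version = version.parse(installed_version("transformer-engine"))
    if te_version >= version.parse("0.12.0"):
        # Newer TE accepts an explicit device; "cuda" used here for simplicity,
        # the removed helper passes torch.cuda.current_device() instead.
        extra["device"] = "cpu" if getattr(config, "use_cpu_initialization", False) else "cuda"
    return extra
```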
- if self.te_return_bias: - return out - return out, None - - def sharded_state_dict(self, prefix='', sharded_key_prefix=None, sharded_offsets=()): - """ Sharding along axis 0, bias sharded """ - state_dict = self.state_dict(prefix='', keep_vars=True) - return make_sharded_tensors_for_checkpoint( - state_dict, prefix, sharded_key_prefix, {'weight': 0, 'bias': 0}, sharded_offsets - ) - - -class TEColumnParallelLinear(TELinear): - """ - Wrapper for the Transformer-Engine's `Linear` layer but specialized similar - to megatron's `ColumnParallelLinear` layer. - """ - - def __init__( - self, - input_size: int, - output_size: int, - *, - config: ModelParallelConfig, - init_method: Callable, - gather_output: bool, - bias: bool, - skip_bias_add: bool, - is_expert: bool, - skip_weight_param_allocation: bool = False, - tp_comm_buffer_name: str = None, - ): - if gather_output: - raise ValueError('Transformer Engine linear layers do not support gather_output = True') - - if is_expert: - raise ValueError('Transformer Engine linear layers do not yet support MoE') - - super().__init__( - input_size=input_size, - output_size=output_size, - parallel_mode="column", - config=config, - init_method=init_method, - bias=bias, - skip_bias_add=skip_bias_add, - skip_weight_param_allocation=skip_weight_param_allocation, - tp_comm_buffer_name=tp_comm_buffer_name, - ) - - def sharded_state_dict(self, prefix='', sharded_key_prefix=None, sharded_offsets=()): - """ Sharding along axis 0, bias sharded """ - state_dict = self.state_dict(prefix='', keep_vars=True) - return make_sharded_tensors_for_checkpoint( - state_dict, prefix, sharded_key_prefix, {'weight': 0, 'bias': 0}, sharded_offsets - ) - - -class TERowParallelLinear(TELinear): - """ - Wrapper for the Transformer-Engine's `Linear` layer but specialized similar - to megatron's `RowParallelLinear` layer. - """ - - def __init__( - self, - input_size: int, - output_size: int, - *, - config: ModelParallelConfig, - init_method: Callable, - bias: bool, - input_is_parallel: bool, - skip_bias_add: bool, - is_expert: bool, - tp_comm_buffer_name: str = None, - ): - if not input_is_parallel: - raise ValueError( - "Transformer Engine linear layers do not support input_is_parallel = False" - ) - - if is_expert: - raise ValueError('Transformer Engine linear layers do not yet support MoE') - - super().__init__( - input_size=input_size, - output_size=output_size, - parallel_mode="row", - config=config, - init_method=init_method, - bias=bias, - skip_bias_add=skip_bias_add, - skip_weight_param_allocation=False, # We don't currently use this for row parallel layers - tp_comm_buffer_name=tp_comm_buffer_name, - ) - - def sharded_state_dict(self, prefix='', sharded_key_prefix=None, sharded_offsets=()): - """ Sharding along axis 1, bias not sharded """ - state_dict = self.state_dict(prefix='', keep_vars=True) - return make_sharded_tensors_for_checkpoint( - state_dict, prefix, sharded_key_prefix, {'weight': 1}, sharded_offsets - ) - - -class TEDotProductAttention(te.pytorch.DotProductAttention): - """ - Wrapper for the Transformer-Engine's `DotProductAttention` layer that also - has "flash attention" enabled. - - Note that if Megatron's parallel_state has not been initialized yet, the - tp_group and cp_group passed to TE will be None and must be set later - via set_tensor_parallel_group() and set_context_parallel_group(). 
- """ - - cp_stream: torch.cuda.Stream = None - - def __init__( - self, - config: TransformerConfig, - layer_number: int, - attn_mask_type: AttnMaskType, - attention_type: str, - attention_dropout: float = None, - ): - self.config = config - self.te_forward_mask_type = False - - if self.config.apply_query_key_layer_scaling != bool( - int(os.getenv('NVTE_APPLY_QK_LAYER_SCALING', '0')) - ): - raise ValueError( - f"apply_query_key_layer_scaling is {self.config.apply_query_key_layer_scaling} " - f"but environment variable NVTE_APPLY_QK_LAYER_SCALING is " - f"{os.getenv('NVTE_APPLY_QK_LAYER_SCALING')}. Transformer Engine does not support " - f"setting query key layer scaling via argument, so these two must match." - ) - - extra_kwargs = {} - te_version = packaging.version.Version(version("transformer-engine")) - if te_version >= packaging.version.Version("0.11.0"): - extra_kwargs["num_gqa_groups"] = self.config.num_query_groups - elif self.config.num_query_groups != self.config.num_attention_heads: - raise ValueError( - f"Transformer Engine v{te_version} does not support Grouped Query Attention, " - f"use a newer version of Transformer Engine. " - f"(num_query_groups ({self.config.num_query_groups}) != " - f"num_attention_heads ({self.config.num_attention_heads}))" - ) - - if te_version >= packaging.version.Version("0.10.0"): - extra_kwargs["attention_type"] = attention_type - # older version don't need attention_type - - if te_version > packaging.version.Version("0.12.0"): - self.te_forward_mask_type = True - - # Only Transformer-Engine version >= 1.0.0 supports context parallelism - if te_version >= packaging.version.Version("1.0.0"): - if getattr(TEDotProductAttention, "cp_stream") is None: - TEDotProductAttention.cp_stream = torch.cuda.Stream() - extra_kwargs["cp_group"] = get_context_parallel_group(check_initialized=False) - extra_kwargs["cp_global_ranks"] = get_context_parallel_global_ranks( - check_initialized=False - ) - extra_kwargs["cp_stream"] = TEDotProductAttention.cp_stream - else: - assert ( - self.config.context_parallel_size == 1 - ), "Only Transformer-Engine version >= 1.0.0 supports context parallelism!" - - super().__init__( - num_attention_heads=self.config.num_attention_heads, - kv_channels=self.config.kv_channels, - attention_dropout=self.config.attention_dropout - if attention_dropout is None - else attention_dropout, - attn_mask_type=attn_mask_type.name, - sequence_parallel=self.config.sequence_parallel, - tp_size=self.config.tensor_model_parallel_size, - get_rng_state_tracker=get_cuda_rng_tracker, - tp_group=get_tensor_model_parallel_group(check_initialized=False), - layer_number=layer_number, - **extra_kwargs, - ) - - def forward( - self, - query: Tensor, - key: Tensor, - value: Tensor, - attention_mask: Tensor, - attn_mask_type: AttnMaskType, - ): - if self.te_forward_mask_type: - return super().forward( - query, key, value, attention_mask, attn_mask_type=attn_mask_type.name - ) - else: - return super().forward(query, key, value, attention_mask) diff --git a/megatron/core/transformer/dot_product_attention.py b/megatron/core/transformer/dot_product_attention.py deleted file mode 100644 index 7eab478bd00a1c4507e52f80465c31f214a47fd1..0000000000000000000000000000000000000000 --- a/megatron/core/transformer/dot_product_attention.py +++ /dev/null @@ -1,195 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 
- - -import math - -import torch -from torch import Tensor - -from megatron.core import parallel_state, tensor_parallel -from megatron.core.fusions.fused_softmax import FusedScaleMaskSoftmax -from megatron.core.transformer.enums import AttnMaskType -from megatron.core.transformer.module import MegatronModule -from megatron.core.transformer.transformer_config import TransformerConfig -from megatron.core.transformer.utils import attention_mask_func -from megatron.core.utils import divide - - -class DotProductAttention(MegatronModule): - """ - Region where selective activation recomputation is applied. - This region is memory intensive but less compute intensive which - makes activation checkpointing more efficient for LLMs (20B+). - See Reducing Activation Recomputation in Large Transformer Models: https://arxiv.org/abs/2205.05198 for more details. - - We use the following notation: - h: hidden size - n: number of attention heads - p: number of tensor model parallel partitions - b: batch size - s: sequence length - """ - - def __init__( - self, - config: TransformerConfig, - layer_number: int, - attn_mask_type: AttnMaskType, - attention_type: str, - attention_dropout: float = None, - ): - super().__init__(config=config) - - self.config: TransformerConfig = config - - assert ( - self.config.context_parallel_size == 1 - ), "Context parallelism is only supported by TEDotProductAttention!" - - self.layer_number = max(1, layer_number) - self.attn_mask_type = attn_mask_type - self.attention_type = attention_type # unused for now - - projection_size = self.config.kv_channels * self.config.num_attention_heads - - # Per attention head and per partition values. - world_size = parallel_state.get_tensor_model_parallel_world_size() - self.hidden_size_per_partition = divide(projection_size, world_size) - self.hidden_size_per_attention_head = divide(projection_size, config.num_attention_heads) - self.num_attention_heads_per_partition = divide(self.config.num_attention_heads, world_size) - self.num_query_groups_per_partition = divide(self.config.num_query_groups, world_size) - - coeff = None - self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) - if self.config.apply_query_key_layer_scaling: - coeff = self.layer_number - self.norm_factor *= coeff - - self.scale_mask_softmax = FusedScaleMaskSoftmax( - input_in_fp16=self.config.fp16, - input_in_bf16=self.config.bf16, - attn_mask_type=self.attn_mask_type, - scaled_masked_softmax_fusion=self.config.masked_softmax_fusion, - mask_func=attention_mask_func, - softmax_in_fp32=self.config.attention_softmax_in_fp32, - scale=coeff, - ) - - # Dropout. Note that for a single iteration, this layer will generate - # different outputs on different number of parallel partitions but - # on average it should not be partition dependent. - self.attention_dropout = torch.nn.Dropout( - self.config.attention_dropout if attention_dropout is None else attention_dropout - ) - - def forward( - self, - query: Tensor, - key: Tensor, - value: Tensor, - attention_mask: Tensor, - attn_mask_type: AttnMaskType = None, - ): - - # =================================== - # Raw attention scores. [b, n/p, s, s] - # =================================== - - # expand the key and value [sk, b, ng, hn] -> [sk, b, np, hn] - # This is a noop for normal attention where ng == np. When using group query attention this - # creates a view that has the keys and values virtually repeated along their dimension to - # match the number of queries. - - # attn_mask_type is not used. 
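The forward pass that follows broadcasts grouped keys/values up to the full head count and then computes raw scores with a scaled batched matmul (`baddbmm` into a preallocated global buffer). The toy-size sketch below reproduces those two steps with plain `bmm` and checks them against an einsum reference; it assumes the default `1/sqrt(hn)` scaling (query-key layer scaling disabled), and all sizes are made up for illustration.

```python
# Shape sketch of the grouped-query expansion and score computation used by
# the forward pass below, on toy sizes (sq=sk=5, b=2, np=8 heads, ng=2 groups, hn=16).
import math

import torch

sq, sk, b, np_, ng, hn = 5, 5, 2, 8, 2, 16
query = torch.randn(sq, b, np_, hn)
key = torch.randn(sk, b, ng, hn)

# [sk, b, ng, hn] -> [sk, b, np, hn]: each group's key is repeated np/ng times.
key_exp = key.repeat_interleave(np_ // ng, dim=2)

# Reference: plain einsum, scaled by 1/sqrt(hn).
ref = torch.einsum("sbnh,tbnh->bnst", query, key_exp) / math.sqrt(hn)

# Same computation via the flattened bmm layout used in the code below
# (the original uses baddbmm into a preallocated buffer with alpha=1/norm_factor).
q2 = query.reshape(sq, b * np_, hn).transpose(0, 1)          # [b*np, sq, hn]
k2 = key_exp.reshape(sk, b * np_, hn).transpose(0, 1)        # [b*np, sk, hn]
scores = torch.bmm(q2, k2.transpose(1, 2)) / math.sqrt(hn)   # [b*np, sq, sk]
assert torch.allclose(scores.view(b, np_, sq, sk), ref, atol=1e-5)
```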
- if self.num_attention_heads_per_partition // self.num_query_groups_per_partition > 1: - key = key.repeat_interleave( - self.num_attention_heads_per_partition // self.num_query_groups_per_partition, dim=2 - ) - value = value.repeat_interleave( - self.num_attention_heads_per_partition // self.num_query_groups_per_partition, dim=2 - ) - - # [b, np, sq, sk] - output_size = ( - query.size(1), - query.size(2), - query.size(0), - key.size(0), - ) - - # [sq, b, np, hn] -> [sq, b * np, hn] - # This will be a simple view when doing normal attention, but in group query attention - # the key and value tensors are repeated to match the queries so you can't use simple strides - # to extract the queries. - query = query.reshape(output_size[2], output_size[0] * output_size[1], -1) - # [sk, b, np, hn] -> [sk, b * np, hn] - key = key.view(output_size[3], output_size[0] * output_size[1], -1) - - # preallocting input tensor: [b * np, sq, sk] - matmul_input_buffer = parallel_state.get_global_memory_buffer().get_tensor( - (output_size[0] * output_size[1], output_size[2], output_size[3]), query.dtype, "mpu", - ) - - # Raw attention scores. [b * np, sq, sk] - matmul_result = torch.baddbmm( - matmul_input_buffer, - query.transpose(0, 1), # [b * np, sq, hn] - key.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] - beta=0.0, - alpha=(1.0 / self.norm_factor), - ) - - # change view to [b, np, sq, sk] - attention_scores = matmul_result.view(*output_size) - - # =========================== - # Attention probs and dropout - # =========================== - - # attention scores and attention mask [b, np, sq, sk] - attention_probs: Tensor = self.scale_mask_softmax(attention_scores, attention_mask) - - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. - - if not self.config.sequence_parallel: - with tensor_parallel.get_cuda_rng_tracker().fork(): - attention_probs = self.attention_dropout(attention_probs) - else: - attention_probs = self.attention_dropout(attention_probs) - - # ========================= - # Context layer. [sq, b, hp] - # ========================= - - # value -> context layer. - # [sk, b, np, hn] --> [b, np, sq, hn] - - # context layer shape: [b, np, sq, hn] - output_size = ( - value.size(1), - value.size(2), - query.size(0), - value.size(3), - ) - - # change view [sk, b * np, hn] - value = value.view(value.size(0), output_size[0] * output_size[1], -1) - - # change view [b * np, sq, sk] - attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1) - - # matmul: [b * np, sq, hn] - context = torch.bmm(attention_probs, value.transpose(0, 1)) - - # change view [b, np, sq, hn] - context = context.view(*output_size) - - # [b, np, sq, hn] --> [sq, b, np, hn] - context = context.permute(2, 0, 1, 3).contiguous() - - # [sq, b, np, hn] --> [sq, b, hp] - new_context_shape = context.size()[:-2] + (self.hidden_size_per_partition,) - context = context.view(*new_context_shape) - - return context diff --git a/megatron/core/transformer/enums.py b/megatron/core/transformer/enums.py deleted file mode 100644 index ab72f3536854413443eb56455fe96171aef5a72e..0000000000000000000000000000000000000000 --- a/megatron/core/transformer/enums.py +++ /dev/null @@ -1,26 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. - -import enum - - -# can we get rid of this? 
-# it's being used in pipeline schedules -class ModelType(enum.Enum): - encoder_or_decoder = 1 - encoder_and_decoder = 2 - - -# class LayerType(enum.Enum): -# encoder = 1 -# decoder = 2 - - -class AttnType(enum.Enum): - self_attn = 1 - cross_attn = 2 - - -class AttnMaskType(enum.Enum): - padding = 1 - causal = 2 - no_mask = 3 # only used for TE diff --git a/megatron/core/transformer/identity_op.py b/megatron/core/transformer/identity_op.py deleted file mode 100644 index 5d9388ffcc628bdd0f04dd5969b9e669153446a8..0000000000000000000000000000000000000000 --- a/megatron/core/transformer/identity_op.py +++ /dev/null @@ -1,28 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. -import torch - - -class IdentityOp(torch.nn.Module): - """ - This is a placeholder for IdentityOp(x) -> x - """ - - def __init__(self, *args, **kwargs): - super().__init__() - - def forward(self, x, *args, **kwargs): - return x - - -class IdentityFuncOp(IdentityOp): - """ - This is a placeholder for IdentityFuncOp(...)(x) -> IdentityOp(x) -> x. - Such a func is handy for ops like `bias_dropout_fusion` which themselves - return a function at runtime based on passed arguments - """ - - def __init__(self, *args, **kwargs): - super().__init__() - - def forward(self, *args, **kwargs): - return super().forward diff --git a/megatron/core/transformer/mlp.py b/megatron/core/transformer/mlp.py deleted file mode 100644 index 8f5575b72467ed92752fc06ff5eeb015cb97db92..0000000000000000000000000000000000000000 --- a/megatron/core/transformer/mlp.py +++ /dev/null @@ -1,184 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - -from dataclasses import dataclass -from typing import Tuple, Union - -import torch -import torch.nn.functional as F - -from megatron.core import parallel_state -from megatron.core.dist_checkpointing import ShardedTensor -from megatron.core.dist_checkpointing.mapping import ShardedTensorFactory -from megatron.core.fusions.fused_bias_gelu import bias_gelu_impl -from megatron.core.transformer.module import MegatronModule -from megatron.core.transformer.spec_utils import ModuleSpec, build_module -from megatron.core.transformer.transformer_config import TransformerConfig -from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint - - -@dataclass -class MLPSubmodules: - linear_fc1: Union[ModuleSpec, type] = None - linear_fc2: Union[ModuleSpec, type] = None - - -class MLP(MegatronModule): - """ - MLP will take the input with h hidden state, project it to 4*h - hidden dimension, perform nonlinear transformation, and project the - state back into h hidden dimension. - - - Returns an output and a bias to be added to the output. - If config.add_bias_linear is False, the bias returned is None. 
- - We use the following notation: - h: hidden size - p: number of tensor model parallel partitions - b: batch size - s: sequence length - """ - - def __init__( - self, config: TransformerConfig, submodules: MLPSubmodules, is_expert: bool = False - ): - super().__init__(config=config) - - self.config: TransformerConfig = config - - # If this is a gated linear unit we double the output width, see https://arxiv.org/pdf/2002.05202.pdf - ffn_hidden_size = self.config.ffn_hidden_size - if self.config.gated_linear_unit: - ffn_hidden_size *= 2 - - self.linear_fc1 = build_module( - submodules.linear_fc1, - self.config.hidden_size, - ffn_hidden_size, - config=self.config, - init_method=self.config.init_method, - gather_output=False, - bias=self.config.add_bias_linear, - skip_bias_add=True, - is_expert=is_expert, - tp_comm_buffer_name='fc1', - ) - - if self.config.gated_linear_unit: - - def glu(x): - x = torch.chunk(x, 2, dim=-1) - return self.config.activation_func(x[0]) * x[1] - - self.activation_func = glu - else: - self.activation_func = self.config.activation_func - - self.linear_fc2 = build_module( - submodules.linear_fc2, - self.config.ffn_hidden_size, - self.config.hidden_size, - config=self.config, - init_method=self.config.output_layer_init_method, - bias=self.config.add_bias_linear, - input_is_parallel=True, - skip_bias_add=True, - is_expert=is_expert, - tp_comm_buffer_name='fc2', - ) - - def forward(self, hidden_states): - - # [s, b, 4 * h/p] - intermediate_parallel, bias_parallel = self.linear_fc1(hidden_states) - - if self.config.bias_gelu_fusion: - assert self.config.add_bias_linear is True - assert self.activation_func == F.gelu - intermediate_parallel = bias_gelu_impl(intermediate_parallel, bias_parallel) - else: - if bias_parallel is not None: - intermediate_parallel = intermediate_parallel + bias_parallel - intermediate_parallel = self.activation_func(intermediate_parallel) - - # [s, b, h] - output, output_bias = self.linear_fc2(intermediate_parallel) - - return output, output_bias - - def sharded_state_dict(self, prefix='', sharded_key_prefix=None, sharded_offsets=()): - sharded_key_prefix = prefix if sharded_key_prefix is None else sharded_key_prefix - sharded_state_dict = {} - for name, module in self._modules.items(): - if name == 'linear_fc1' and self.config.gated_linear_unit: - sub_sd = self._sharded_state_dict_for_glu( - name, module, prefix, sharded_key_prefix, sharded_offsets - ) - else: - sub_sd = module.sharded_state_dict( - prefix=f'{prefix}{name}.', - sharded_key_prefix=f'{sharded_key_prefix}{name}.', - sharded_offsets=sharded_offsets, - ) - sharded_state_dict.update(sub_sd) - return sharded_state_dict - - def _sharded_state_dict_for_glu( - self, - module_name: str, - module: torch.nn.Module, - prefix: str, - sharded_key_prefix: str, - sharded_offsets: Tuple[Tuple[int, int, int]], - ): - assert module_name == 'linear_fc1', module_name - sharded_state_dict = module.sharded_state_dict( - prefix=f'{prefix}{module_name}.', - sharded_key_prefix=f'{sharded_key_prefix}{module_name}.', - sharded_offsets=sharded_offsets, - ) - weight_key = f'{prefix}{module_name}.weight' - prev_sh_ten = sharded_state_dict[weight_key] - - # We must split the tensor into 2 parts, each sharded separately. 
- # This requires a ShardedTensorFactory which `chunk`s during saving - # and `cat`s during loading - tp_rank = parallel_state.get_tensor_model_parallel_rank() - tp_size = parallel_state.get_tensor_model_parallel_world_size() - - tp_shard_axis = 0 - replica_id = prev_sh_ten.replica_id - prepend_axis_num = len(sharded_offsets) - - def sh_ten_build_fn(key: str, t: torch.Tensor): - offset_w = (tp_shard_axis + prepend_axis_num, tp_rank, tp_size * 2) - offset_v = (tp_shard_axis + prepend_axis_num, tp_size + tp_rank, tp_size * 2) - with torch.no_grad(): - tensor_w, tensor_v = torch.chunk(t, 2, dim=tp_shard_axis) - return [ - ShardedTensor.from_rank_offsets( - key, - tensor_w, - *sharded_offsets, - offset_w, - replica_id=replica_id, - prepend_axis_num=1, - ), - ShardedTensor.from_rank_offsets( - key, - tensor_v, - *sharded_offsets, - offset_v, - replica_id=replica_id, - prepend_axis_num=1, - ), - ] - - def sh_ten_merge_fn(sub_state_dict): - with torch.no_grad(): - return torch.cat(sub_state_dict) - - sharded_state_dict[weight_key] = ShardedTensorFactory( - prev_sh_ten.key, prev_sh_ten.data, sh_ten_build_fn, sh_ten_merge_fn - ) - return sharded_state_dict diff --git a/megatron/core/transformer/module.py b/megatron/core/transformer/module.py deleted file mode 100644 index d20074aa07644dd8c556b69646eb3d0ffaf6012e..0000000000000000000000000000000000000000 --- a/megatron/core/transformer/module.py +++ /dev/null @@ -1,157 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. -"""Megatron Module.""" - -import torch -from torch.autograd import Variable -from torch.nn.parameter import Parameter - -from megatron.core import parallel_state -from megatron.core.transformer.transformer_config import TransformerConfig - -_FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor) -_HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor) -_BF16_TYPES = (torch.BFloat16Tensor, torch.cuda.BFloat16Tensor) - - -def param_is_not_shared(param): - return not hasattr(param, 'shared') or not param.shared - - -class MegatronModule(torch.nn.Module): - """Base Megatron module inhertied by all Models. - - Megatron specific extensions of torch Module with support - for pipelining - - Args: - config (TransformerConfig): Transformer config - """ - - # def __init__(self, config: TransformerConfig, share_word_embeddings=True): - def __init__(self, config: TransformerConfig): - super().__init__() - self.config = config - - def state_dict_for_save_checkpoint(self, prefix: str = '', keep_vars: bool = False): - """Override state dict for saving checkpoints Use this function to override the - state dict for saving checkpoints. - - Args: - prefix (str, optional): _description_. Defaults to ''. - keep_vars (bool, optional): _description_. Defaults to False. - - Returns: - _type_: _description_ - """ - - return self.state_dict(prefix=prefix, keep_vars=keep_vars) - - def sharded_state_dict(self, prefix: str = ''): - """Override sharded state dict with Dist Checkpointing. - - Override sharded_state_dict when using distributed checkpointing. keep_vars must always be set to True so that optimizer states can be sharded. - - Args: - prefix (str, optional): _description_. Defaults to ''. 
- - Returns: - _type_: _description_ - """ - return self.state_dict(prefix=prefix, keep_vars=True) - - -def conversion_helper(val, conversion): - if not isinstance(val, (tuple, list)): - return conversion(val) - rtn = [conversion_helper(v, conversion) for v in val] - if isinstance(val, tuple): - rtn = tuple(rtn) - return rtn - - -def fp32_to_float16(val, float16_convertor): - def half_conversion(val): - val_typecheck = val - if isinstance(val_typecheck, (Parameter, Variable)): - val_typecheck = val.data - if isinstance(val_typecheck, _FLOAT_TYPES): - val = float16_convertor(val) - return val - - return conversion_helper(val, half_conversion) - - -def float16_to_fp32(val): - def float_conversion(val): - val_typecheck = val - if isinstance(val_typecheck, (Parameter, Variable)): - val_typecheck = val.data - if isinstance(val_typecheck, (_BF16_TYPES, _HALF_TYPES)): - val = val.float() - return val - - return conversion_helper(val, float_conversion) - - -class Float16Module(MegatronModule): - """Float 16 Module. - - Attributes: - config (TransformerConfig): Transformer config - fp16 (bool) : Specifies if the model runs in fp16 mode - bf16 (bool) : Specifies if the model runs in bf16 mode - - Args: - config (TransformerConfig): The transformer config used to initalize the model - """ - - def __init__(self, config: TransformerConfig, module: torch.nn.Module): - super(Float16Module, self).__init__(config) - self.config = config - self.fp16 = config.fp16 - self.bf16 = config.bf16 - - if self.fp16: - self.add_module('module', module.half()) - - def float16_convertor(val): - return val.half() - - elif self.bf16: - self.add_module('module', module.bfloat16()) - - def float16_convertor(val): - return val.bfloat16() - - else: - raise Exception('Either config.fp16 or config.bf16 should be True.') - - self.float16_convertor = float16_convertor - - def set_input_tensor(self, input_tensor): - return self.module.set_input_tensor(input_tensor) - - def forward(self, *inputs, **kwargs): - if parallel_state.is_pipeline_first_stage(): - inputs = fp32_to_float16(inputs, self.float16_convertor) - outputs = self.module(*inputs, **kwargs) - if parallel_state.is_pipeline_last_stage(): - outputs = float16_to_fp32(outputs) - return outputs - - def state_dict(self, destination=None, prefix='', keep_vars=False): - return self.module.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars) - - def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False): - """Retrieve state_dict from the module being wrapped.""" - return self.module.state_dict_for_save_checkpoint(prefix=prefix, keep_vars=keep_vars) - - def sharded_state_dict(self, prefix=''): - """Retrieve state_dict from the module being wrapped. - - When using distributed checkpointing, keep_vars must always be set to True. - """ - return self.module.sharded_state_dict(prefix=prefix) - - def load_state_dict(self, state_dict, strict=True): - self.module.load_state_dict(state_dict, strict=strict) diff --git a/megatron/core/transformer/spec_utils.py b/megatron/core/transformer/spec_utils.py deleted file mode 100644 index 473933e45297903a76f539db0e1c5990ff2a946d..0000000000000000000000000000000000000000 --- a/megatron/core/transformer/spec_utils.py +++ /dev/null @@ -1,109 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - -import types -from dataclasses import dataclass, field -from typing import Tuple, Union - -import torch - - -@dataclass -class ModuleSpec: - """This is a Module Specification dataclass. 
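`conversion_helper` and the `fp32_to_float16` / `float16_to_fp32` wrappers above walk arbitrarily nested tuples and lists, so pipeline-stage inputs and outputs can be cast without knowing their exact structure. A minimal sketch of that recursion, assuming only PyTorch (the real helpers additionally unwrap `Parameter`/`Variable` and check the concrete tensor types):

```python
import torch


def convert_nested(val, convert):
    """Apply `convert` to every leaf of (possibly nested) tuples/lists,
    preserving the container structure."""
    if not isinstance(val, (tuple, list)):
        return convert(val)
    out = [convert_nested(v, convert) for v in val]
    return tuple(out) if isinstance(val, tuple) else out


def to_half(v):
    return v.half() if isinstance(v, torch.Tensor) and v.is_floating_point() else v


inputs = (torch.randn(2, 3), [torch.randn(4)], torch.tensor([1, 2]))
halved = convert_nested(inputs, to_half)
print(halved[0].dtype, halved[1][0].dtype, halved[2].dtype)  # torch.float16 torch.float16 torch.int64
```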
- - Specification defines the location of the module (to import dynamically) - or the imported module itself. It also defines the params that need to be - passed to initialize the module. - - Args: - module (Union[Tuple, type]): A tuple describing the location of the - module class e.g. `(module.location, ModuleClass)` or the imported - module class itself e.g. `ModuleClass` (which is already imported - using `from module.location import ModuleClass`). - params (dict): A dictionary of params that need to be passed while init. - - """ - - module: Union[Tuple, type] - params: dict = field(default_factory=lambda: {}) - submodules: type = None - - -def import_module(module_path: Tuple[str]): - """Import a named object from a module in the context of this function. - - TODO: make this importer module more robust, at least make sure there - are no side effects of using this as is - """ - base_path, name = module_path - try: - module = __import__(base_path, globals(), locals(), [name]) - except ImportError as e: - print(f"couldn't import module due to {e}") - return None - return vars(module)[name] - - -def get_module(spec_or_module: Union[ModuleSpec, type], **additional_kwargs): - # If a module clas is already provided return it as is - if isinstance(spec_or_module, (type, types.FunctionType)): - return spec_or_module - - # If the module is provided instead of module path, then return it as is - if isinstance(spec_or_module.module, (type, types.FunctionType)): - return spec_or_module.module - - # Otherwise, return the dynamically imported module from the module path - return import_module(spec_or_module.module) - - -def build_module(spec_or_module: Union[ModuleSpec, type], *args, **kwargs): - # If the passed `spec_or_module` is - # a `Function`, then return it as it is - # NOTE: to support an already initialized module add the following condition - # `or isinstance(spec_or_module, torch.nn.Module)` to the following if check - if isinstance(spec_or_module, types.FunctionType): - return spec_or_module - - # If the passed `spec_or_module` is actually a spec (instance of - # `ModuleSpec`) and it specifies a `Function` using its `module` - # field, return the `Function` as it is - if isinstance(spec_or_module, ModuleSpec) and isinstance( - spec_or_module.module, types.FunctionType - ): - return spec_or_module.module - - # Check if a module class is provided as a spec or if the module path - # itself is a class - if isinstance(spec_or_module, type): - module = spec_or_module - elif hasattr(spec_or_module, "module") and isinstance(spec_or_module.module, type): - module = spec_or_module.module - else: - # Otherwise, dynamically import the module from the module path - module = import_module(spec_or_module.module) - - # If the imported module is actually a `Function` return it as it is - if isinstance(module, types.FunctionType): - return module - - # Finally return the initialized module with params from the spec as well - # as those passed as **kwargs from the code - - # Add the `submodules` argument to the module init call if it exists in the - # spec. 
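`ModuleSpec` plus `build_module` (above) let a layer be described either by an already-imported class or by a dotted import path, with constructor params carried alongside. A reduced, self-contained sketch of that pattern; `MiniSpec` and `build` are illustrative names, not the real API surface:

```python
from dataclasses import dataclass, field
from importlib import import_module
from typing import Tuple, Union

import torch


@dataclass
class MiniSpec:
    """Reduced stand-in for a module spec: a class or ('module.path', 'ClassName')
    tuple, plus constructor params."""
    module: Union[Tuple[str, str], type]
    params: dict = field(default_factory=dict)


def build(spec: MiniSpec, **kwargs):
    target = spec.module
    if isinstance(target, tuple):  # dotted path -> import dynamically
        mod_path, name = target
        target = getattr(import_module(mod_path), name)
    return target(**spec.params, **kwargs)


# Both forms construct the same layer.
ln_a = build(MiniSpec(module=torch.nn.LayerNorm, params={"normalized_shape": 16}))
ln_b = build(MiniSpec(module=("torch.nn", "LayerNorm")), normalized_shape=16)
print(type(ln_a).__name__, type(ln_b).__name__)  # LayerNorm LayerNorm
```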
- if hasattr(spec_or_module, "submodules") and spec_or_module.submodules is not None: - kwargs["submodules"] = spec_or_module.submodules - - try: - return module( - *args, **spec_or_module.params if hasattr(spec_or_module, "params") else {}, **kwargs - ) - except Exception as e: - # improve the error message since we hide the module name in the line above - import sys - - tb = sys.exc_info()[2] - raise type(e)(f"{str(e)} when instantiating {module.__name__}").with_traceback( - sys.exc_info()[2] - ) diff --git a/megatron/core/transformer/switch_mlp.py b/megatron/core/transformer/switch_mlp.py deleted file mode 100644 index 092c6c6402bad1bf1689b18f5c2aeed53930b17b..0000000000000000000000000000000000000000 --- a/megatron/core/transformer/switch_mlp.py +++ /dev/null @@ -1,158 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - -import torch - -from megatron.core import parallel_state, tensor_parallel -from megatron.core.parallel_state import ( - get_tensor_and_expert_parallel_group, - get_tensor_model_parallel_group, -) -from megatron.core.tensor_parallel import get_cuda_rng_tracker, get_data_parallel_rng_tracker_name -from megatron.core.transformer.module import MegatronModule -from megatron.core.transformer.transformer_config import TransformerConfig - -from .mlp import MLP, MLPSubmodules - - -def sinkhorn(cost, tol=0.0001): - "Sinkhorn based MoE routing function" - cost = torch.exp(cost) - d0 = torch.ones(cost.size(0), device=cost.device, dtype=cost.dtype) - d1 = torch.ones(cost.size(1), device=cost.device, dtype=cost.dtype) - - eps = 0.00000001 - error = 1e9 - d1_old = d1 - while error > tol: - d0 = (1 / d0.size(0)) * 1 / (torch.sum(d1 * cost, 1) + eps) - d1 = (1 / d1.size(0)) * 1 / (torch.sum(d0.unsqueeze(1) * cost, 0) + eps) - error = torch.mean(torch.abs(d1_old - d1)) - d1_old = d1 - return d1 * cost * d0.unsqueeze(1) - - -def get_router_linear_layer(config): - router = torch.nn.Linear(config.hidden_size, config.num_moe_experts, bias=False) - with get_cuda_rng_tracker().fork(get_data_parallel_rng_tracker_name()): - config.init_method(router.weight) - setattr(router.weight, 'sequence_parallel', config.sequence_parallel) - return router - - -class SwitchMLP(MegatronModule): - """ - Top-1 Mixture of Experts Layer. Routes input to one of N MLP "experts" - Curently supports Sinkhorn based expert routing. 
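The `sinkhorn` routine above implements the load-balancing normalization used for expert routing: it alternately rescales the token (row) and expert (column) marginals of `exp(cost)` until the column scaling stops changing. A standalone copy with a toy balance check, assuming only PyTorch:

```python
import torch


def sinkhorn(cost: torch.Tensor, tol: float = 1e-4) -> torch.Tensor:
    """Alternately normalize rows (tokens) and columns (experts) of exp(cost)."""
    cost = torch.exp(cost)
    d0 = torch.ones(cost.size(0))
    d1 = torch.ones(cost.size(1))
    eps, error, d1_old = 1e-8, 1e9, d1
    while error > tol:
        d0 = (1.0 / d0.size(0)) / (torch.sum(d1 * cost, dim=1) + eps)
        d1 = (1.0 / d1.size(0)) / (torch.sum(d0.unsqueeze(1) * cost, dim=0) + eps)
        error = torch.mean(torch.abs(d1_old - d1))
        d1_old = d1
    return d1 * cost * d0.unsqueeze(1)


logits = torch.randn(64, 4)                # 64 tokens, 4 experts
assignment = sinkhorn(logits).argmax(dim=1)
print(torch.bincount(assignment, minlength=4))  # roughly 16 tokens per expert
```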
- """ - - def __init__(self, config: TransformerConfig, submodules: MLPSubmodules): - super().__init__(config=config) - - self.config: TransformerConfig = config - - self.router = get_router_linear_layer(self.config) - self.add_bias = config.add_bias_linear - self.sequence_parallel = config.sequence_parallel - self.route_algo = sinkhorn - self.router_activation = torch.sigmoid - self.expert_parallel_size = parallel_state.get_expert_model_parallel_world_size() - - assert self.config.num_moe_experts % self.expert_parallel_size == 0 - self.num_local_experts = self.config.num_moe_experts // self.expert_parallel_size - local_expert_indices_offset = ( - parallel_state.get_expert_model_parallel_rank() * self.num_local_experts - ) - self.local_expert_indices = [ - local_expert_indices_offset + i for i in range(self.num_local_experts) - ] - - self.local_experts = torch.nn.ModuleList() - for _ in range(self.num_local_experts): - expert = MLP(self.config, submodules, is_expert=True) - self.local_experts.append(expert) - - def gather_indices(self, local_indices): - """ Gather tensors and concatenate along the first dimension.""" - group = get_tensor_and_expert_parallel_group() - world_size = torch.distributed.get_world_size(group=group) - # Bypass the function if we are using only 1 GPU. - if world_size == 1: - return local_indices - - dim_size = list(local_indices.size()) - dim_size[0] = dim_size[0] * world_size - - # TODO pre allocate memory - output = torch.empty( - dim_size, dtype=local_indices.dtype, device=torch.cuda.current_device() - ) - torch.distributed._all_gather_base(output, local_indices.contiguous(), group=group) - return output - - def forward(self, hidden_states): - hidden_shape = hidden_states.shape - route = self.router(hidden_states) - route = route.view(-1, self.config.num_moe_experts) - - if self.training: - with torch.no_grad(): - norm_route = self.route_algo( - route.detach().to(dtype=torch.float32) - ) # explicit fp32 conversion for stability - _, max_ind = torch.max(norm_route, dim=1) - route = self.router_activation(route) - max_prob = route[torch.arange(route.size(0)), max_ind] - else: - route = self.router_activation(route) - max_prob, max_ind = torch.max(route, dim=1) - - max_prob = torch.unsqueeze(max_prob, 1) - hidden_states = hidden_states.view(-1, hidden_shape[-1]) - - if self.sequence_parallel or (self.expert_parallel_size > 1): - global_hidden_states = tensor_parallel.gather_from_sequence_parallel_region_to_moe( - hidden_states - ) - global_indices = self.gather_indices(max_ind) - else: - global_hidden_states = hidden_states - global_indices = max_ind - - output_total = torch.zeros_like(global_hidden_states) - if self.add_bias: - output_bias_total = torch.zeros_like(global_hidden_states) - - for expert_num, expert in enumerate(self.local_experts): - local_expert_index = self.local_expert_indices[expert_num] - local_indices = (global_indices == local_expert_index).nonzero() - hidden = global_hidden_states[local_indices, :] - output, output_bias = expert(hidden) - - output_total[local_indices, :] = output - if self.add_bias: - output_bias = output_bias.expand_as(output) - output_bias_total[local_indices, :] = output_bias - - if self.sequence_parallel or (self.expert_parallel_size > 1): - output_total = tensor_parallel.reduce_scatter_to_sequence_parallel_region_from_moe( - output_total - ) - if self.add_bias: - output_bias_total = tensor_parallel.reduce_scatter_to_sequence_parallel_region_from_moe( - output_bias_total - ) - # bias is duplicated across tensor 
parallelism ranks; - # reduce scatter reduces bias across tensor parallel_ranks - output_bias_total = ( - output_bias_total / parallel_state.get_tensor_model_parallel_world_size() - ) - - output_total = output_total * max_prob - output_total = output_total.view(hidden_shape) - if self.add_bias: - output_bias_total = output_bias_total * max_prob - output_bias_total = output_bias_total.view(hidden_shape) - else: - output_bias_total = None - - return output_total, output_bias_total diff --git a/megatron/core/transformer/transformer_block.py b/megatron/core/transformer/transformer_block.py deleted file mode 100644 index 74bf29c859dba98cd6c0ce6aac1444d6d53f81f6..0000000000000000000000000000000000000000 --- a/megatron/core/transformer/transformer_block.py +++ /dev/null @@ -1,349 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - -import re -from contextlib import nullcontext -from dataclasses import dataclass -from typing import List, Union - -import torch -from torch import Tensor - -from megatron.core import InferenceParams, parallel_state, tensor_parallel -from megatron.core.fusions.fused_layer_norm import FusedLayerNorm -from megatron.core.transformer.custom_layers.transformer_engine import TENorm -from megatron.core.transformer.enums import AttnMaskType -from megatron.core.transformer.module import MegatronModule -from megatron.core.transformer.spec_utils import ModuleSpec, build_module -from megatron.core.transformer.transformer_config import TransformerConfig -from megatron.core.transformer.transformer_layer import TransformerLayer -from megatron.core.utils import make_sharded_tensor_for_checkpoint, make_viewless_tensor - - -def get_num_layers_to_build(config: TransformerConfig) -> int: - - num_layers_per_pipeline_rank = ( - config.num_layers // parallel_state.get_pipeline_model_parallel_world_size() - ) - - if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None: - # Interleaved pipeline parallelism: - # Number of layers in each model chunk is the number of layers in the stage, - # divided by the number of model chunks in a stage. - # With 8 layers, 2 stages, and 4 model chunks, we want an assignment of - # layers to stages like (each list is a model chunk): - # Stage 0: [0] [2] [4] [6] - # Stage 1: [1] [3] [5] [7] - # With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of - # layers to stages like (each list is a model chunk): - # Stage 0: [0, 1] [4, 5] - # Stage 1: [2, 3] [6, 7] - - vp_size = parallel_state.get_virtual_pipeline_model_parallel_world_size() - - num_layers_per_virtual_rank = num_layers_per_pipeline_rank // vp_size - - num_layers_to_build = num_layers_per_virtual_rank - - else: - # Non-interleaved pipeline parallelism: - # Each stage gets a contiguous set of layers. - - num_layers_to_build = num_layers_per_pipeline_rank - - return num_layers_to_build - - -@dataclass -class TransformerBlockSubmodules: - layer_specs: List[ModuleSpec] = None - - -def _get_block_submodules( - config: TransformerConfig, spec: Union[TransformerBlockSubmodules, ModuleSpec], -) -> TransformerBlockSubmodules: - - # Transformer block submodules. - if isinstance(spec, TransformerBlockSubmodules): - return spec - - # ModuleSpec here is generally assumed to be for a transformer layer. 
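`SwitchMLP.forward` above picks each token's winning expert from Sinkhorn-normalized scores during training (plain activated scores at inference) and scales the expert output by the router probability. A single-process toy sketch of that dispatch loop, leaving out the tensor/expert-parallel gathers and the bias handling:

```python
import torch


def top1_dispatch(hidden: torch.Tensor, router_logits: torch.Tensor, experts) -> torch.Tensor:
    """Send each token to the expert with the highest sigmoid router score and
    scale the expert output by that score."""
    probs = torch.sigmoid(router_logits)
    max_prob, max_ind = torch.max(probs, dim=1)  # per-token winner
    output = torch.zeros_like(hidden)
    for expert_id, expert in enumerate(experts):
        token_idx = (max_ind == expert_id).nonzero(as_tuple=True)[0]
        if token_idx.numel() > 0:
            output[token_idx] = expert(hidden[token_idx])
    return output * max_prob.unsqueeze(1)


hidden = torch.randn(10, 8)
router = torch.nn.Linear(8, 2, bias=False)
experts = [torch.nn.Linear(8, 8) for _ in range(2)]
print(top1_dispatch(hidden, router(hidden), experts).shape)  # torch.Size([10, 8])
```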
- elif isinstance(spec, ModuleSpec): - if issubclass(spec.module, TransformerBlock): - return spec.submodules - elif issubclass(spec.module, TransformerLayer): - num_layers = get_num_layers_to_build(config) - return TransformerBlockSubmodules(layer_specs=[spec] * num_layers) - else: - raise Exception(f"specialize for {spec.module.__name__}.") - else: - raise Exception(f"specialize for {type(spec).__name__}.") - - -class TransformerBlock(MegatronModule): - """Transformer class.""" - - def __init__( - self, - config: TransformerConfig, - spec: Union[TransformerBlockSubmodules, ModuleSpec], - post_layer_norm: bool = True, - pre_process: bool = True, - post_process: bool = True, - ): - super().__init__(config=config) - - self.submodules = _get_block_submodules(config, spec) - self.post_layer_norm = post_layer_norm - self.pre_process = pre_process - self.post_process = post_process - - # required for pipeline parallel schedules - self.input_tensor = None - - self.checkpoint_core_attention = self.config.recompute_granularity == 'selective' - - self._build_layers() - self.num_layers_per_pipeline_rank = len(self.layers) - - def _build_layers(self): - # Transformer layers. - # @jcasper can we improve how we deal with layer_number? - # currently it's only used in CoreAttention? - # if self.apply_query_key_layer_scaling: - # coeff = self.layer_number - # self.norm_factor *= coeff - def build_layer(layer_spec, layer_number): - return build_module(layer_spec, config=self.config, layer_number=layer_number,) - - # offset is implicit in TransformerLayer - self.layers = torch.nn.ModuleList( - [ - build_layer(layer_spec, i + 1) - for i, layer_spec in enumerate(self.submodules.layer_specs) - ] - ) - - # # TODO: add back standalone_embedding_stage - # if self.num_layers == 0: - # # When a standalone embedding stage is used (e.g., - # # args.standalone_embedding_stage == True), virtual pipeline ranks - # # on pipeline rank 0 will have zero transformer layers assigned to - # # them. This results in the model's input and output tensors to be - # # the same, which will cause failure for certain output tensor - # # optimizations (e.g., pipeline output deallocation). To remedy - # # this, we assign a 'no-op' layer on these ranks, which will - # # disconnect the input tensor from the output tensor. - # self.num_layers = 1 - # self.layers = torch.nn.ModuleList([NoopTransformerLayer(1)]) - # else: - # self.layers = torch.nn.ModuleList([build_layer(i + 1 + offset) for i in range(self.num_layers)]) - - if self.post_process and self.post_layer_norm: - # Final layer norm before output. 
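`get_num_layers_to_build` above (used by `_get_block_submodules` to decide how many layer specs to replicate) reduces to a small piece of integer arithmetic. A pure-Python sketch reproducing the worked 8-layer / 2-stage examples from its comments:

```python
from typing import Optional


def layers_to_build(num_layers: int, pp_size: int, vp_size: Optional[int]) -> int:
    """Each pipeline rank owns num_layers / pp_size layers; with interleaving,
    each of its vp_size model chunks builds an equal slice of that share."""
    per_pipeline_rank = num_layers // pp_size
    return per_pipeline_rank if vp_size is None else per_pipeline_rank // vp_size


# The worked examples from the comments: 8 layers, 2 pipeline stages.
print(layers_to_build(8, 2, None))  # 4 -> one contiguous block of 4 layers per stage
print(layers_to_build(8, 2, 4))     # 1 -> four single-layer model chunks per stage
print(layers_to_build(8, 2, 2))     # 2 -> two two-layer model chunks per stage
```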
- self.final_layernorm = TENorm( - config=self.config, - hidden_size=self.config.hidden_size, - eps=self.config.layernorm_epsilon, - ) - - def _get_layer(self, layer_number: int): - return self.layers[layer_number] - - def _checkpointed_forward( - self, - hidden_states: Tensor, - attention_mask: Tensor, - context: Tensor, - context_mask: Tensor, - rotary_pos_emb: Tensor, - ): - """Forward method with activation checkpointing.""" - - def custom(start: int, end: int): - def custom_forward( - hidden_states, attention_mask, context, context_mask, rotary_pos_emb, - ): - for index in range(start, end): - layer = self._get_layer(index) - hidden_states, context = layer( - hidden_states=hidden_states, - attention_mask=attention_mask, - context=context, - context_mask=context_mask, - rotary_pos_emb=rotary_pos_emb, - inference_params=None, - ) - return hidden_states, context - - return custom_forward - - if self.config.recompute_method == 'uniform': - # Uniformly divide the total number of Transformer layers and checkpoint - # the input activation of each divided chunk. - # A method to further reduce memory usage reducing checkpoints. - l = 0 - while l < self.num_layers_per_pipeline_rank: - hidden_states, context = tensor_parallel.checkpoint( - custom(l, l + self.config.recompute_num_layers), - self.config.distribute_saved_activations, - hidden_states, - attention_mask, - context, - context_mask, - rotary_pos_emb, - ) - - l += self.config.recompute_num_layers - - elif self.config.recompute_method == 'block': - # Checkpoint the input activation of only a set number of individual - # Transformer layers and skip the rest. - # A method fully use the device memory removing redundant re-computation. - for l in range(self.num_layers_per_pipeline_rank): - if l < self.config.recompute_num_layers: - hidden_states, context = tensor_parallel.checkpoint( - custom(l, l + 1), - self.config.distribute_saved_activations, - hidden_states, - attention_mask, - context, - context_mask, - rotary_pos_emb, - ) - else: - hidden_states, context = custom(l, l + 1)( - hidden_states, attention_mask, context, context_mask, rotary_pos_emb, - ) - else: - raise ValueError("Invalid activation recompute method.") - - return hidden_states - - def set_input_tensor(self, input_tensor: Tensor): - """Set input tensor to be used instead of forward()'s input. - - When doing pipeline parallelism the input from the previous - stage comes from communication, not from the input, so the - model's forward_step_func won't have it. This function is thus - used by internal code to bypass the input provided by the - forward_step_func""" - self.input_tensor = input_tensor - - def forward( - self, - hidden_states: Tensor, - attention_mask: Tensor, - context: Tensor = None, - context_mask: Tensor = None, - rotary_pos_emb: Tensor = None, - inference_params: InferenceParams = None, - ): - # hidden_states (float): [s, b, h] - # attention_mask (bool): [1, 1, s, s] - - if not self.pre_process: - # See set_input_tensor() - hidden_states = self.input_tensor - - # Viewless tensor. - # - We only need to create a viewless tensor in the case of micro batch - # size (mbs) == 1, since in this case, 'hidden_states.transpose()' - # above creates a view tensor, and '.contiguous()' is a pass-through. - # For mbs >= 2, '.contiguous()' creates a new tensor, eliminating - # the need to make it viewless. - # - # However, we don't explicitly check mbs == 1 here because - # make_viewless_tensor() has negligible overhead when its input - # is already viewless. 
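`_checkpointed_forward` above supports two full-recompute schedules: 'uniform' checkpoints every chunk of `recompute_num_layers` consecutive layers, while 'block' checkpoints only the first `recompute_num_layers` layers of the stage and runs the rest normally. A single-process sketch of the two loops, using `torch.utils.checkpoint` as a stand-in for `tensor_parallel.checkpoint` (which also takes the `distribute_saved_activations` flag):

```python
import torch
from torch.utils.checkpoint import checkpoint


def run_layers(layers, x, recompute_method: str, recompute_num_layers: int):
    """'uniform': checkpoint every chunk of recompute_num_layers consecutive layers.
    'block'  : checkpoint only the first recompute_num_layers layers, run the rest normally."""
    def chunk_forward(start, end):
        def fwd(h):
            for i in range(start, end):
                h = layers[i](h)
            return h
        return fwd

    if recompute_method == "uniform":
        l = 0
        while l < len(layers):
            x = checkpoint(chunk_forward(l, l + recompute_num_layers), x, use_reentrant=False)
            l += recompute_num_layers
    elif recompute_method == "block":
        for l in range(len(layers)):
            if l < recompute_num_layers:
                x = checkpoint(chunk_forward(l, l + 1), x, use_reentrant=False)
            else:
                x = chunk_forward(l, l + 1)(x)
    else:
        raise ValueError("recompute_method must be 'uniform' or 'block'")
    return x


layers = torch.nn.ModuleList([torch.nn.Linear(16, 16) for _ in range(4)])
x = torch.randn(2, 16, requires_grad=True)
run_layers(layers, x, "uniform", 2).sum().backward()
```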
- # - # - For the 'else' case above, calling make_viewless_tensor() here is - # likely redundant, since p2p_communication.py (likely originator) - # already creates viewless tensors. That said, make_viewless_tensor() - # is called here to be future-proof and corner-case-proof. - hidden_states = make_viewless_tensor( - inp=hidden_states, requires_grad=True, keep_graph=True, - ) - - if self.config.sequence_parallel: - rng_context = tensor_parallel.get_cuda_rng_tracker().fork() - else: - rng_context = nullcontext() - - if self.config.fp8: - import transformer_engine # To keep out TE dependency when not training in fp8 - - if self.config.fp8 == "e4m3": - fp8_format = transformer_engine.common.recipe.Format.E4M3 - elif self.config.fp8 == "hybrid": - fp8_format = transformer_engine.common.recipe.Format.HYBRID - else: - raise ValueError("E4M3 and HYBRID are the only supported FP8 formats.") - - fp8_recipe = transformer_engine.common.recipe.DelayedScaling( - margin=self.config.fp8_margin, - interval=self.config.fp8_interval, - fp8_format=fp8_format, - amax_compute_algo=self.config.fp8_amax_compute_algo, - amax_history_len=self.config.fp8_amax_history_len, - override_linear_precision=(False, False, not self.config.fp8_wgrad), - ) - fp8_group = None - if parallel_state.model_parallel_is_initialized(): - fp8_group = parallel_state.get_amax_reduction_group(with_context_parallel=True) - fp8_context = transformer_engine.pytorch.fp8_autocast( - enabled=True, fp8_recipe=fp8_recipe, fp8_group=fp8_group - ) - else: - fp8_context = nullcontext() - - with rng_context and fp8_context: - # Forward pass. - if self.config.recompute_granularity == 'full': - hidden_states = self._checkpointed_forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - context=context, - context_mask=context_mask, - rotary_pos_emb=rotary_pos_emb, - ) - else: - for layer in self.layers: - hidden_states, context = layer( - hidden_states=hidden_states, - attention_mask=attention_mask, - context=context, - context_mask=context_mask, - rotary_pos_emb=rotary_pos_emb, - inference_params=inference_params, - ) - - # Final layer norm. - if self.post_process and self.post_layer_norm: - hidden_states = self.final_layernorm(hidden_states) - - return hidden_states - - def sharded_state_dict(self, prefix: str = ''): - - sharded_state_dict = {} - - layer_prefix = f'{prefix}layers.' - for layer in self.layers: - sharded_state_dict.update(layer.sharded_state_dict(prefix=layer_prefix)) - - if self.post_process and self.post_layer_norm: - state_dict = self.state_dict(keep_vars=True) - - tensor = state_dict['final_layernorm.weight'] - layer_name = f'{prefix}final_layernorm.weight' - sharded_state_dict[layer_name] = make_sharded_tensor_for_checkpoint(tensor, layer_name) - - # RMSNorm doesn't have bias. - if 'final_layernorm.bias' in state_dict.keys(): - tensor = state_dict['final_layernorm.bias'] - layer_name = f'{prefix}final_layernorm.bias' - sharded_state_dict[layer_name] = make_sharded_tensor_for_checkpoint( - tensor, layer_name - ) - - return sharded_state_dict diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py deleted file mode 100644 index adccd4409bd148cb394683ad3d8e8f69a1c43064..0000000000000000000000000000000000000000 --- a/megatron/core/transformer/transformer_config.py +++ /dev/null @@ -1,288 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 
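One detail worth flagging in the forward pass above: `with rng_context and fp8_context:` evaluates the `and` first, so only a single context manager is actually entered (the second operand when the first is truthy); that is plain Python semantics rather than anything Megatron-specific. If both the RNG fork and the FP8 autocast ever need to be active together, the usual pattern is an `ExitStack`; a small sketch with placeholder no-op contexts:

```python
from contextlib import ExitStack, nullcontext


def combined_context(rng_ctx=None, fp8_ctx=None) -> ExitStack:
    """Enter every context that is actually supplied; omitted ones fall back to
    a no-op nullcontext."""
    stack = ExitStack()
    stack.enter_context(rng_ctx if rng_ctx is not None else nullcontext())
    stack.enter_context(fp8_ctx if fp8_ctx is not None else nullcontext())
    return stack


# Placeholder no-op contexts stand in for the RNG-tracker fork and fp8_autocast.
with combined_context(rng_ctx=nullcontext(), fp8_ctx=None):
    pass
```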
- -import types -from dataclasses import dataclass -from typing import Callable - -import torch -import torch.nn.functional as F - -from ..model_parallel_config import ModelParallelConfig -from ..utils import init_method_normal, scaled_init_method_normal - - -@dataclass -class TransformerConfig(ModelParallelConfig): - """Configuration object for megatron-core transformers. - - Attributes: - - # model architecture - num_layers (int): Number of transformer layers in a transformer block. - hidden_size (int): Transformer hidden size. - ffn_hidden_size (int): Transformer Feed-Forward Network hidden size. - This is set to 4*hidden_size if not provided. Defaults to None.') - num_attention_heads (int): Number of transformer attention heads. - kv_channels (int): Projection weights dimension in multi-head attention. - This is set to hidden_size // num_attention_heads if not provided. - Defaults to None. - num_query_groups (int): Number of query groups for group query attention. If None, normal attention is used. - - hidden_dropout (float): Dropout probability for transformer hidden state. Defaults to 0.1. - attention_dropout (float): Post attention dropout probability. Defaults to 0.1. - fp32_residual_connection (bool): If true, move residual connections to fp32. - apply_residual_connection_post_layernorm (bool): If true, uses the original BERT residule connection ordering. - Defaults to False. - layernorm_epsilon (float): Layernorm epsilon. Defaults to 1e-5. - - layernorm_zero_centered_gamma (bool): if set to 'True', the LayerNorm is adjusted to center the gamma values - around 0. This improves numerical stability. Defaults to False. - - add_bias_linear (bool): Include a bias term in all linear layers (QKV projections, after core attention, and two - in MLP layer). Default is True. - - gated_linear_unit (bool): Use a gated linear unit for the first linear layer in the MLP. Defaults to False. - - activation_func (Callable): Activation function to use for the non-linearity in the MLP. Defaults to F.gelu. - - num_moe_experts (int): Number of experts to use for Mixture of Experts. - When set, it replaces MLP with Switch MLP. Defaults to None (no MoE). - - # initialization - init_method (Callable): Method to initialize weights. Note that bias is always set to - zero. Should be a function that takes a single Tensor and - initializes it. Defaults to - megatron.core.utils.init_method_normal(init_method_std) which is - torch.nn.init.normal_ with mean=0.0 and std=init_method_Std. - - output_layer_init_method (Callable): Method to initialize weights of the output layer of - both attention and MLP blocks. Defaults to - megatron.core.utils.scaled_init_method_normal(init_method_std) - which is torch.nn.init.normal_ with mean=0.0 and - std=init_method_std / math.sqrt(2.0 * num_layers). - - init_method_std (float): Standard deviation of the zero mean normal for the default - initialization method, not used if init_method and - output_layer_init_method are provided. Defaults to 0.02. - - # mixed-precision - apply_query_key_layer_scaling (bool): If true, scale Q * K^T by 1 / layer-number. Defaults to True. - attention_softmax_in_fp32 (bool): If true, run attention masking and softmax in fp32. - This should be true if apply_query_key_layer_scaling is true. - - # fusion - bias_gelu_fustion (bool): If true, fuses bias and gelu. Defaults to False. - masked_softmax_fusion (bool): If true, uses softmax fusion. - persist_layer_norm (bool): If true, uses the persistent fused layer norm kernel. 
- This kernel only supports a fixed set of hidden sizes. - Defaults to False. - bias_dropout_fusion (bool): If true, uses bias dropout fusion. - - # activation recomputation - - recompute_granularity (str): megatron-core supports 'selective' activation checkpointing where only the memory - intensive part of attention is checkpointed. These memory intensive activations - are also less compute intensive which makes activation checkpointing more efficient - for LLMs (20B+). See Reducing Activation Recomputation in Large Transformer - Models: https://arxiv.org/abs/2205.05198 for more details. 'full' will checkpoint - the entire transformer layer. Must be 'selective' or 'full'. 'selective' always uses all layers. - Defaults to None. - - recompute_method (str): uniform will uniformly divide the total number of transformer layers in a transformer - block and recompute the input activation of each divided chunk at the specified - granularity. block will recompute the input activations for only a set number of - transformer layers per pipeline stage. The rest of the layers in the pipeline stage - will not have any activations recomputed. Must be 'uniform' or 'block'. Defaults to - None. - - recompute_num_layers (int): When recompute_method is uniform, recompute_num_layers is the number of transformer - layers in each uniformly divided recompute unit. When recompute_method is block, - recompute_num_layers is the number of transformer layers to recompute within each - pipeline stage. Must be None for 'selective' activation checkpointing. Defaults to None. - - distribute_saved_activations (bool): If true, distribute recomputed activations across the model parallel - group. Defaults to None. - - # fp8 related (via Transformer Engine). For detailed info, refer the the Transformer Engine docs at - # https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/common.html - - fp8 (str): If set, enables the use of FP8 precision through Transformer Engine. There are 2 predefined choices: (1) 'e4m3' - uniformly uses e4m3 for all FP8 tensors, (2) 'hybrid' uses e4m3 for all FP8 activation and weight tensors and - e5m2 for all FP8 output activation gradient tensors. Defaults to None. - - fp8_margin (int): Margin for the scaling factor computation. - - fp8_interval (int): Controls how often the scaling factor is recomputed. - - fp8_amax_history_len (int): The length of the amax history window used for scaling factor computation. - - fp8_amax_compute_algo (str): Algorithm used for choosing the `amax` value for the scaling factor computation. - There are 2 predefined choices: `max` chooses the largest `amax` in the history - window, while `most_recent` always chooses the most recently seen value. - - fp8_wgrad (bool): When set to False, override FP8 config options and do the wgrad computation in higher precision. - Defaults to True. - - # Miscellaneous - clone_scatter_output_in_embedding (bool): When set to true, clone the output of scatter_to_sequence_parallel_region - in embedding layer to facilitate garbage collection of input. - - # Experimental - normalization (str): Swtich b/w `LayerNorm` and `RMSNorm` as normalization layers. For now, these are primarily - used by Transformer-Engine's layers like `LayerNormLinear`. Default value is `LayerNorm`. 
- - - """ - - # model architecture - num_layers: int = 0 - hidden_size: int = 0 - num_attention_heads: int = 0 - num_query_groups: int = None - - ffn_hidden_size: int = None - kv_channels: int = None - hidden_dropout: float = 0.1 - attention_dropout: float = 0.1 - fp32_residual_connection: bool = False - # @jcasper should we keep this option? - apply_residual_connection_post_layernorm: bool = False - layernorm_epsilon: float = 1e-5 - layernorm_zero_centered_gamma: bool = False - add_bias_linear: bool = True - gated_linear_unit: bool = False - activation_func: Callable = F.gelu - num_moe_experts: int = None - - # initialization - init_method: Callable = None - output_layer_init_method: Callable = None - init_method_std: float = 0.02 - - # mixed-precision - apply_query_key_layer_scaling: bool = False - attention_softmax_in_fp32: bool = True - - # communication - - # fusion - bias_gelu_fusion: bool = False # TODO: this should be bias_activation_fusion ? - masked_softmax_fusion: bool = False - persist_layer_norm: bool = False - bias_dropout_fusion: bool = False # TODO: this should be bias_dropout_add_fusion? - - # activation recomputation - recompute_granularity: str = None - recompute_method: str = None - recompute_num_layers: int = None - distribute_saved_activations: bool = None - - # fp8 related - fp8: str = None - fp8_margin: int = 0 - fp8_interval: int = 1 - fp8_amax_history_len: int = 1 - fp8_amax_compute_algo: str = "most_recent" - fp8_wgrad: bool = True - - # miscellaneous - clone_scatter_output_in_embedding: bool = True - - # experimental section (TODO: move to apt. section above once stable) - normalization: bool = "LayerNorm" # alt value supported by TE: "RMSNorm" - - def __post_init__(self): - """ Python dataclass method that is used to modify attributes after initialization. - See https://docs.python.org/3/library/dataclasses.html#post-init-processing for more details. - """ - super().__post_init__() - if self.fp16 and self.bf16: - raise ValueError( - f'Only one of self.fp16: {self.fp16} and self.bf16 {self.bf16} should be True.' - ) - - if self.num_attention_heads % self.tensor_model_parallel_size != 0: - raise ValueError( - f"num_attention_heads ({self.num_attention_heads}) must be a multiple of " - f"tensor_model_parallel_size ({self.tensor_model_parallel_size})." - ) - - if self.ffn_hidden_size is None: - self.ffn_hidden_size = 4 * self.hidden_size - - if self.kv_channels is None: - self.kv_channels = self.hidden_size // self.num_attention_heads - - if self.num_query_groups is None: - self.num_query_groups = self.num_attention_heads - - if self.num_query_groups % self.tensor_model_parallel_size != 0: - raise ValueError( - f"num_query_groups ({self.num_query_groups}) must be a multiple of " - f"tensor_model_parallel_size ({self.tensor_model_parallel_size})." - ) - - if self.apply_query_key_layer_scaling: - self.attention_softmax_in_fp32 = True - - if self.expert_model_parallel_size > 1 and self.num_moe_experts is None: - raise ValueError(f'num_moe_experts must be non None to use expert-parallel.') - - if self.recompute_granularity is not None: - if not self.recompute_granularity in ['full', 'selective']: - raise ValueError( - f'When using recompute_granuarlity: {self.recompute_granularity} must be "full" or "selective".' - ) - - if self.recompute_method is not None: - if not self.recompute_method in ['block', 'uniform']: - raise ValueError( - f'recompute_method: {self.recompute_method} must be "block" or "uniform".' 
- ) - elif self.recompute_granularity != 'selective': - raise ValueError( - f'Using recompute_granularity: {self.recompute_granularity} so recompute_method must be "block" or "uniform"' - ) - - if self.recompute_granularity != 'selective' and self.recompute_num_layers is None: - raise ValueError( - f'When using recompute_granularity: {self.recompute_granularity} recompute_num_layers must be between ' - f'1 and num_layers_per_pipeline_rank: {self.num_layers // self.pipeline_model_parallel_size}' - ) - elif ( - self.recompute_granularity == 'selective' and self.recompute_num_layers is not None - ): - raise ValueError( - f'When using recompute_granularity: {self.recompute_granularity} recompute_num_layers must be None.' - ) - - if self.distribute_saved_activations and self.sequence_parallel: - raise ValueError( - f'distribute_saved_activations: {self.distribute_saved_activations} must be false when sequence parallel is enabled: {self.sequence_parallel}' - ) - - if self.virtual_pipeline_model_parallel_size is not None: - if not self.num_layers % self.virtual_pipeline_model_parallel_size == 0: - raise ValueError( - f'num_layers: {self.num_layers} must be divisible by virtual_model_parallel_size {self.virtual_pipeline_model_parallel_size}' - ) - - if self.apply_query_key_layer_scaling: - self.attention_softmax_in_fp32 = True - - if self.bias_gelu_fusion: - if not self.add_bias_linear: - raise ValueError( - "When bias_gelu_fusion is True, add_bias_linear must also be True." - ) - - if self.activation_func != F.gelu: - raise ValueError(f'When bias_gelu_fusion is True, activation_func must be F.gelu.') - - if self.init_method is None: - self.init_method = init_method_normal(self.init_method_std) - - if self.output_layer_init_method is None: - self.output_layer_init_method = scaled_init_method_normal( - self.init_method_std, self.num_layers - ) diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py deleted file mode 100644 index b9951d4347005f262f255d1d49c531114e9dc554..0000000000000000000000000000000000000000 --- a/megatron/core/transformer/transformer_layer.py +++ /dev/null @@ -1,245 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - -from dataclasses import dataclass -from typing import Union - -import torch - -from megatron.core import parallel_state -from megatron.core.dist_checkpointing.mapping import ShardedObject, ShardedTensor -from megatron.core.transformer.enums import AttnMaskType -from megatron.core.transformer.identity_op import IdentityFuncOp, IdentityOp -from megatron.core.transformer.module import MegatronModule -from megatron.core.transformer.spec_utils import ModuleSpec, build_module -from megatron.core.transformer.transformer_config import TransformerConfig -from megatron.core.utils import make_viewless_tensor - - -@dataclass -class TransformerLayerSubmodules: - input_layernorm: Union[ModuleSpec, type] = IdentityOp - self_attention: Union[ModuleSpec, type] = IdentityOp - self_attn_bda: Union[ModuleSpec, type] = IdentityFuncOp - - pre_cross_attn_layernorm: Union[ModuleSpec, type] = IdentityOp - cross_attention: Union[ModuleSpec, type] = IdentityOp - cross_attn_bda: Union[ModuleSpec, type] = IdentityFuncOp - - pre_mlp_layernorm: Union[ModuleSpec, type] = IdentityOp - mlp: Union[ModuleSpec, type] = IdentityOp - mlp_bda: Union[ModuleSpec, type] = IdentityFuncOp - - -class TransformerLayer(MegatronModule): - """A single transformer layer. 
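`TransformerConfig.__post_init__` above fills several derived defaults before validating the parallelism, recompute, and fusion settings: `ffn_hidden_size` falls back to `4 * hidden_size`, `kv_channels` to `hidden_size // num_attention_heads`, and `num_query_groups` to `num_attention_heads`. A reduced stand-in dataclass showing just that default-filling step (illustrative names, not the real class):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class MiniTransformerConfig:
    """Reduced stand-in for TransformerConfig showing how __post_init__ fills
    derived defaults (the real class also validates parallelism, recompute and
    fusion settings)."""
    num_layers: int
    hidden_size: int
    num_attention_heads: int
    ffn_hidden_size: Optional[int] = None
    kv_channels: Optional[int] = None
    num_query_groups: Optional[int] = None

    def __post_init__(self):
        if self.ffn_hidden_size is None:
            self.ffn_hidden_size = 4 * self.hidden_size
        if self.kv_channels is None:
            self.kv_channels = self.hidden_size // self.num_attention_heads
        if self.num_query_groups is None:
            self.num_query_groups = self.num_attention_heads


cfg = MiniTransformerConfig(num_layers=24, hidden_size=2048, num_attention_heads=16)
print(cfg.ffn_hidden_size, cfg.kv_channels, cfg.num_query_groups)  # 8192 128 16
```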
- - Transformer layer takes input with size [s, b, h] and returns an - output of the same size. - """ - - def __init__( - self, - config: TransformerConfig, - submodules: TransformerLayerSubmodules, - layer_number: int = 1, - hidden_dropout: float = None, - ): - super().__init__(config=config) - - self.layer_number = layer_number + self._get_layer_offset() - self.hidden_dropout = config.hidden_dropout if hidden_dropout is None else hidden_dropout - - ## [Module 1: Input Layernorm] Optional Layernorm on the input data - # TODO: add pytorch only layernorm - self.input_layernorm = build_module( - submodules.input_layernorm, - config=self.config, - hidden_size=self.config.hidden_size, - eps=self.config.layernorm_epsilon, - ) - - ## [Module 2: SelfAttention] - self.self_attention = build_module( - submodules.self_attention, config=self.config, layer_number=layer_number, - ) - - ## [Module 3: BiasDropoutFusion] - self.self_attn_bda = build_module(submodules.self_attn_bda) - - ## [Module 4: Post SelfAttention] Optional Layernorm after self-attn - self.pre_cross_attn_layernorm = build_module( - submodules.pre_cross_attn_layernorm, - config=self.config, - hidden_size=self.config.hidden_size, - eps=self.config.layernorm_epsilon, - ) - - ## [Module 5: CrossAttention] - self.cross_attention = build_module( - submodules.cross_attention, config=self.config, layer_number=layer_number, - ) - - ## [Module 6: BiasDropoutFusion] - self.cross_attn_bda = build_module(submodules.cross_attn_bda, config=self.config,) - - ## [Module 7: Pre MLP] Optional Layernorm before MLP - self.pre_mlp_layernorm = build_module( - submodules.pre_mlp_layernorm, - config=self.config, - hidden_size=self.config.hidden_size, - eps=self.config.layernorm_epsilon, - ) - - ## [Module 8: MLP block] - # TODO how to set the gpt_layer_spec.py when we have moe_frequency > 1, - # where MLP and SwitchMLP both appear alternately? - self.mlp = build_module(submodules.mlp, config=self.config) - - ## [Module 9: BiasDropoutFusion] - self.mlp_bda = build_module(submodules.mlp_bda) - - # @jcasper how should we handle nvfuser? - # Set bias+dropout+add fusion grad_enable execution handler. - # TORCH_MAJOR = int(torch.__version__.split('.')[0]) - # TORCH_MINOR = int(torch.__version__.split('.')[1]) - # use_nvfuser = TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10) - # self.bias_dropout_add_exec_handler = nullcontext if use_nvfuser else torch.enable_grad - self.bias_dropout_add_exec_handler = torch.enable_grad - - def _get_layer_offset(self): - - pipeline_rank = parallel_state.get_pipeline_model_parallel_rank() - - num_layers_per_pipeline_rank = ( - self.config.num_layers // parallel_state.get_pipeline_model_parallel_world_size() - ) - - if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None: - vp_rank = parallel_state.get_virtual_pipeline_model_parallel_rank() - vp_size = parallel_state.get_virtual_pipeline_model_parallel_world_size() - - total_num_layers = self.config.num_layers - num_layers_per_virtual_rank = num_layers_per_pipeline_rank // vp_size - total_virtual_chunks = total_num_layers // vp_size - offset = vp_rank * total_virtual_chunks + (pipeline_rank * num_layers_per_virtual_rank) - - else: - # Each stage gets a contiguous set of layers. 
- if parallel_state.get_pipeline_model_parallel_world_size() > 1: - offset = pipeline_rank * num_layers_per_pipeline_rank - else: - offset = 0 - - return offset - - def forward( - self, - hidden_states, - attention_mask, - context=None, - context_mask=None, - rotary_pos_emb=None, - inference_params=None, - ): - # hidden_states: [s, b, h] - - # Residual connection. - residual = hidden_states - - # Optional Input Layer norm - input_layernorm_output = self.input_layernorm(hidden_states) - - # Self attention. - attention_output_with_bias = self.self_attention( - input_layernorm_output, - attention_mask=attention_mask, - inference_params=inference_params, - rotary_pos_emb=rotary_pos_emb, - ) - - # TODO: could we move `bias_dropout_add_exec_handler` itself - # inside the module provided in the `bias_dropout_add_spec` module? - with self.bias_dropout_add_exec_handler(): - hidden_states = self.self_attn_bda(self.training, self.config.bias_dropout_fusion)( - attention_output_with_bias, residual, self.hidden_dropout - ) - - # Residual connection. - residual = hidden_states - - # Optional Layer norm after self-attention - pre_cross_attn_layernorm_output = self.pre_cross_attn_layernorm(hidden_states) - - # Cross attention. - attention_output_with_bias = self.cross_attention( - pre_cross_attn_layernorm_output, - attention_mask=context_mask, - key_value_states=context, - inference_params=inference_params, - ) - - if isinstance(attention_output_with_bias, dict) and "context" in attention_output_with_bias: - context = attention_output_with_bias["context"] - - # TODO: could we move `bias_dropout_add_exec_handler` itself - # inside the module provided in the `bias_dropout_add_spec` module? - with self.bias_dropout_add_exec_handler(): - hidden_states = self.cross_attn_bda(self.training, self.config.bias_dropout_fusion)( - attention_output_with_bias, residual, self.hidden_dropout - ) - - # Residual connection. - residual = hidden_states - - # Optional Layer norm post the cross-attention. - pre_mlp_layernorm_output = self.pre_mlp_layernorm(hidden_states) - - # MLP. - mlp_output_with_bias = self.mlp(pre_mlp_layernorm_output) - - # TODO: could we move `bias_dropout_add_exec_handler` itself - # inside the module provided in the `bias_dropout_add_spec` module? - with self.bias_dropout_add_exec_handler(): - hidden_states = self.mlp_bda(self.training, self.config.bias_dropout_fusion)( - mlp_output_with_bias, residual, self.hidden_dropout - ) - - # Jit compiled function creates 'view' tensor. This tensor - # potentially gets saved in the MPU checkpoint function context, - # which rejects view tensors. While making a viewless tensor here - # won't result in memory savings (like the data loader, or - # p2p_communication), it serves to document the origin of this - # 'view' tensor. - output = make_viewless_tensor( - inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True - ) - - return output, context - - def sharded_state_dict(self, prefix=''): - offset = self._get_layer_offset() - num_layers = self.config.num_layers - - global_layer_offset = self.layer_number - 1 # self.layer_number starts at 1 - state_dict_prefix = ( - f'{prefix}{global_layer_offset - offset}.' 
# module list index in TransformerBlock - ) - sharded_pp_offset = [ - (0, global_layer_offset, num_layers) - ] # PP sharding offset for ShardedTensors - - attn_state_dict = self.self_attention.sharded_state_dict( - prefix=f'{state_dict_prefix}self_attention.', - sharded_key_prefix=f'{prefix}self_attention.', - sharded_offsets=sharded_pp_offset, - ) - - mlp_state_dict = self.mlp.sharded_state_dict( - prefix=f'{state_dict_prefix}mlp.', - sharded_key_prefix=f'{prefix}mlp.', - sharded_offsets=sharded_pp_offset, - ) - - sharded_state_dict = {**mlp_state_dict, **attn_state_dict} - - return sharded_state_dict diff --git a/megatron/core/transformer/utils.py b/megatron/core/transformer/utils.py deleted file mode 100644 index d7d002734f7f8c9fc93767604c40cf76a4fb82d4..0000000000000000000000000000000000000000 --- a/megatron/core/transformer/utils.py +++ /dev/null @@ -1,148 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. - -"""Utilities for transformer layers.""" -from operator import itemgetter -from typing import Any, Dict, Iterable, Optional, Tuple, Union - -import torch - -from megatron.core import parallel_state -from megatron.core.dist_checkpointing.mapping import ShardedObject, StateDict -from megatron.core.utils import ( - make_sharded_tensor_for_checkpoint, - make_tp_sharded_tensor_for_checkpoint, -) - - -def get_linear_layer(rows, columns, init_method, perform_initialization=True): - """Simple linear layer with weight initialization.""" - layer = torch.nn.Linear(rows, columns) - if perform_initialization: # Take from modelparallel config - init_method(layer.weight) - with torch.no_grad(): - layer.bias.zero_() - return layer - - -def attention_mask_func(attention_scores, attention_mask): - attention_scores.masked_fill_(attention_mask, -10000.0) - return attention_scores - - -@torch.jit.script -def gelu_impl(x): - """OpenAI's gelu implementation.""" - return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * (1.0 + 0.044715 * x * x))) - - -def openai_gelu(x): - return gelu_impl(x) - - -# This is actually Python equivalent of torch.nn.functional.gelu(), also with type hints for ONNX exporter -@torch.jit.script -def erf_gelu(x): - return ( - x * 0.5 * (torch.erf(x / 1.41421).to(dtype=x.dtype) + torch.ones_like(x).to(dtype=x.dtype)) - ) - - -def make_sharded_tensors_for_checkpoint( - state_dict: StateDict, - state_dict_prefix: str, - sharded_key_prefix: Optional[str] = None, - tensor_parallel_layers_axis_map: Optional[Dict[str, int]] = None, - sharded_offsets: Iterable[Tuple[int, int, int]] = (), - extra_state_suffix: str = '_extra_state', -): - """Wraps tensors from transformer layers with ShardedTensor or ShardedObject. - - For a given `state_dict`, wraps: - - all _extra_states with ShardedObject - - all tensors specified in tensor_parallel_layers_axis_map with TP and DP sharded ShardedTensor - - other values with DP sharded ShardedTensor - - Args: - state_dict (StateDict): state_dict to convert - state_dict_prefix (str): prefix appended to keys in final state dict - sharded_key_prefix (str, optional): prefix appended to ShardedTensor keys - tensor_parallel_layers_axis_map (Dict[str, int], optional): dict mapping layer - names to the axis for TP sharding - sharded_offsets (Iterable[Tuple[int, int, int]], optional): sharding already - applied (e.g. PP related), passed along to ShardedTensor - extra_state_suffix (str, default = '_extra_state'): layers with this - suffix will be wrapped with ShardedObject instead of ShardedTensor. 
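`TransformerLayer._get_layer_offset` above converts a pipeline rank (and, with interleaving, a virtual-pipeline rank) into the global index of that chunk's first layer; `sharded_state_dict` then subtracts the offset from `layer_number - 1` to recover the local module-list index. A pure-Python sketch of the offset arithmetic, checked against the interleaving example in the `transformer_block.py` comments:

```python
def layer_offset(num_layers: int, pp_rank: int, pp_size: int,
                 vp_rank=None, vp_size=None) -> int:
    """Global index of the first layer owned by this pipeline rank / virtual chunk."""
    per_pipeline_rank = num_layers // pp_size
    if vp_size is not None:
        per_virtual_rank = per_pipeline_rank // vp_size
        per_virtual_chunk = num_layers // vp_size
        return vp_rank * per_virtual_chunk + pp_rank * per_virtual_rank
    return pp_rank * per_pipeline_rank if pp_size > 1 else 0


# 8 layers, 2 pipeline stages, 2 virtual chunks per stage:
# stage 0 owns chunks starting at layers 0 and 4, stage 1 at layers 2 and 6.
print([layer_offset(8, pp, 2, vp, 2) for vp in range(2) for pp in range(2)])  # [0, 2, 4, 6]
```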
- - """ - if sharded_key_prefix is None: - sharded_key_prefix = state_dict_prefix - - if tensor_parallel_layers_axis_map is None: - tensor_parallel_layers_axis_map = {} - - sharded_state_dict = {} - for layer_name in state_dict.keys(): - tensor = state_dict[layer_name] - layer_key = f'{state_dict_prefix}{layer_name}' - sharded_key = f'{sharded_key_prefix}{layer_name}' - - if layer_name.endswith(extra_state_suffix): - sharded_state_dict[layer_key] = make_sharded_object_for_checkpoint( - tensor, sharded_key, sharded_offsets - ) - - elif layer_name in tensor_parallel_layers_axis_map: - tp_axis = tensor_parallel_layers_axis_map[layer_name] - sharded_state_dict[layer_key] = make_tp_sharded_tensor_for_checkpoint( - tensor, sharded_key, tp_axis, prepend_offsets=sharded_offsets, - ) - - else: - sharded_state_dict[layer_key] = make_sharded_tensor_for_checkpoint( - tensor, sharded_key, prepend_offsets=sharded_offsets, - ) - - return sharded_state_dict - - -def make_sharded_object_for_checkpoint( - obj: Any, - key: str, - sharded_offsets: Iterable[Tuple[int, int, int]] = (), - replica_id: Union[None, int, Tuple[int, ...]] = None, - **kwargs, -): - """ Helper for instantiating a non-sharded ShardedObject (replicated across TP and DP group). - - Arguments: - obj (object): any object to be sharded - key (str): unique identifier of the object - sharded_offsets (Iterable[Tuple[int, int, int]]): offsets normally - prepended to ShardedTensors, will be used as global offsets for - ShardedObject - replica_id (Union[None, int, Tuple[int, ...]]): replica id - """ - if replica_id is None: - replica_id = ( - 0, - parallel_state.get_tensor_model_parallel_rank(), - parallel_state.get_data_parallel_rank(), - ) - - return ShardedObject(key, obj, *_get_extra_state_offsets(sharded_offsets), replica_id, **kwargs) - - -def _get_extra_state_offsets( - sharded_offsets: Iterable[Tuple[int, int, int]] -) -> Tuple[Tuple[int, ...], Tuple[int, ...]]: - """ Turns ShardedTensor offsets into offsets suitable for ShardedObject. """ - if sharded_offsets: - sharded_offsets = sorted(sharded_offsets, key=itemgetter(0)) # sort by axis - axis, extra_state_offset, extra_state_shape = zip(*sharded_offsets) - assert list(axis) == list( - range(len(axis)) - ), f'Expected contiguous axis for offsets: {sharded_offsets}' - else: - extra_state_shape = (1,) - extra_state_offset = (0,) - return extra_state_shape, extra_state_offset diff --git a/megatron/core/utils.py b/megatron/core/utils.py deleted file mode 100644 index d4e042b2d403022b2728ac4b76a7abadc3ec8117..0000000000000000000000000000000000000000 --- a/megatron/core/utils.py +++ /dev/null @@ -1,236 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - -"""Utility functions used throughout Megatron core""" -import math -import operator -from functools import reduce - -import torch - -from megatron.core import parallel_state -from megatron.core.dist_checkpointing.mapping import ShardedTensor - - -def ensure_divisibility(numerator, denominator): - """Ensure that numerator is divisible by the denominator.""" - assert numerator % denominator == 0, "{} is not divisible by {}".format(numerator, denominator) - - -def divide(numerator, denominator): - """Ensure that numerator is divisible by the denominator and return - the division value.""" - ensure_divisibility(numerator, denominator) - return numerator // denominator - - -def get_attr_wrapped_model(model, attr, allow_none=True, return_model_obj=False): - """Get an attribute from a wrapped model. 
- If return_model_obj is true, return the object that has the 'attr' attribute; - otherwise, return the attribute directly.""" - if isinstance(model, list): - raise RuntimeError("_get_attr_wrapped_model given a list of models") - - if allow_none: - - def condition(model, attr): - return not hasattr(model, attr) - - else: - - def condition(model, attr): - return getattr(model, attr, None) is None - - while condition(model, attr): - if not hasattr(model, "module"): - raise RuntimeError(f"_get_attr_wrapped_model couldn't find attribute {attr}") - - model = model.module - - if return_model_obj: - return model - return getattr(model, attr) - - -def get_model_type(model): - return get_attr_wrapped_model(model, 'model_type') - - -def get_model_config(model): - return get_attr_wrapped_model(model, 'config', allow_none=False) - - -class GlobalMemoryBuffer: - """Global buffer to avoid dynamic memory allocations. - Caller should ensure that buffers of the same name - are not used concurrently.""" - - def __init__(self): - self.buffer = {} - - def get_tensor(self, tensor_shape, dtype, name): - required_len = reduce(operator.mul, tensor_shape, 1) - if ( - self.buffer.get((name, dtype), None) is None - or self.buffer[(name, dtype)].numel() < required_len - ): - self.buffer[(name, dtype)] = torch.empty( - required_len, dtype=dtype, device=torch.cuda.current_device(), requires_grad=False - ) - - return self.buffer[(name, dtype)][0:required_len].view(*tensor_shape) - - -def _kernel_make_viewless_tensor(inp, requires_grad): - '''Make a viewless tensor. - - View tensors have the undesirable side-affect of retaining a reference - to the originally-viewed tensor, even after manually setting the '.data' - field. This method creates a new tensor that links to the old tensor's - data, without linking the viewed tensor, referenced via the '._base' - field. - ''' - out = torch.empty((1,), dtype=inp.dtype, device=inp.device, requires_grad=requires_grad,) - out.data = inp.data - return out - - -class MakeViewlessTensor(torch.autograd.Function): - ''' - Autograd function to make a viewless tensor. - - This function should be used in cases where the computation graph needs - to be propagated, but we only want a viewless tensor (e.g., - ParallelTransformer's hidden_states). Call this function by passing - 'keep_graph = True' to 'make_viewless_tensor()'. - ''' - - @staticmethod - def forward(ctx, inp, requires_grad): - return _kernel_make_viewless_tensor(inp, requires_grad) - - @staticmethod - def backward(ctx, grad_output): - return grad_output, None - - -def make_viewless_tensor(inp, requires_grad, keep_graph): - ''' - Entry-point for creating viewless tensors. - - This method should be used, rather than calling 'MakeViewlessTensor' - or '_kernel_make_viewless_tensor' directly. This method acts as a - switch for determining if an autograd function or a regular method - should be used to create the tensor. 
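`get_attr_wrapped_model` above walks nested `.module` wrappers (such as the `Float16Module` shown earlier) until some level exposes the requested attribute. A minimal sketch with toy wrapper modules; the class names here are illustrative only:

```python
import torch


def unwrap_attr(model: torch.nn.Module, attr: str):
    """Walk nested `.module` wrappers until one of them exposes `attr`."""
    while not hasattr(model, attr):
        if not hasattr(model, "module"):
            raise RuntimeError(f"could not find attribute {attr}")
        model = model.module
    return getattr(model, attr)


class Inner(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.model_type = "encoder_or_decoder"


class Wrapper(torch.nn.Module):
    def __init__(self, module):
        super().__init__()
        self.add_module("module", module)


wrapped = Wrapper(Wrapper(Inner()))
print(unwrap_attr(wrapped, "model_type"))  # encoder_or_decoder
```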
- ''' - - # return tensor as-is, if not a 'view' - if inp._base is None: - return inp - - # create viewless tensor - if keep_graph: - return MakeViewlessTensor.apply(inp, requires_grad) - else: - return _kernel_make_viewless_tensor(inp, requires_grad) - - -def assert_viewless_tensor(tensor, extra_msg=None): - '''Assert that a tensor is not a view (i.e., its '._base' field is - not set).''' - if isinstance(tensor, list): - [assert_viewless_tensor(t) for t in tensor] - return tensor - if not isinstance(tensor, torch.Tensor): - return tensor - assert tensor._base is None, ( - "Ensure tensor._base is None before setting tensor.data or storing " - "tensor to memory buffer. Otherwise, a memory leak will occur (and " - "likely accumulate over iterations). %s" - ) % extra_msg - return tensor - - -def safely_set_viewless_tensor_data(tensor, new_data_tensor): - '''Safely set tensor's '.data' field. - - Check first that the tensor is viewless (i.e., '._base' not set). If not, - raise an exception. - ''' - assert_viewless_tensor( - tensor, - extra_msg="FYI, tensor._base has shape %s, and new_data_tensor has shape %s." - % ("--" if tensor._base is None else tensor._base.shape, new_data_tensor.shape), - ) - tensor.data = new_data_tensor - - -def init_method_normal(sigma): - """Init method based on N(0, sigma).""" - - def init_(tensor): - return torch.nn.init.normal_(tensor, mean=0.0, std=sigma) - - return init_ - - -def scaled_init_method_normal(sigma, num_layers): - """Init method based on N(0, sigma/sqrt(2*num_layers).""" - std = sigma / math.sqrt(2.0 * num_layers) - - def init_(tensor): - return torch.nn.init.normal_(tensor, mean=0.0, std=std) - - return init_ - - -def make_tp_sharded_tensor_for_checkpoint( - tensor, key, tp_axis=0, replica_id=None, prepend_offsets=(), **kwargs -): - """ Helper for instantiating a ShardedTensor where the `tp_axis` dimension is sharded across TP group. - - Optionally, can provide offsets which prepend new dimensions to the tensor. - """ - - prepend_axis_num = len(prepend_offsets) - - if replica_id is None: - replica_id = (0, 0, parallel_state.get_data_parallel_rank()) - - return ShardedTensor.from_rank_offsets( - key, - tensor, - *prepend_offsets, - ( - tp_axis + prepend_axis_num, - parallel_state.get_tensor_model_parallel_rank(), - parallel_state.get_tensor_model_parallel_world_size(), - ), - replica_id=replica_id, - prepend_axis_num=prepend_axis_num, - **kwargs, - ) - - -def make_sharded_tensor_for_checkpoint(tensor, key, prepend_offsets=(), replica_id=None, **kwargs): - """ Helper for instantiating a non-sharded ShardedTensor (replicated across TP and DP group). - - Optionally, can provide offsets which prepend new dimensions to the tensor. 
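`scaled_init_method_normal` above draws output-layer weights from N(0, sigma / sqrt(2 * num_layers)), the usual depth-scaled initialization for residual branches. A tiny sketch confirming the resulting standard deviation:

```python
import math

import torch


def scaled_init(sigma: float, num_layers: int):
    """Return an initializer drawing from N(0, sigma / sqrt(2 * num_layers))."""
    std = sigma / math.sqrt(2.0 * num_layers)
    return lambda tensor: torch.nn.init.normal_(tensor, mean=0.0, std=std)


w = torch.empty(1024, 1024)
scaled_init(sigma=0.02, num_layers=24)(w)
print(round(w.std().item(), 4))  # close to 0.02 / sqrt(48) ≈ 0.0029
```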
- """ - - prepend_axis_num = len(prepend_offsets) - - if replica_id is None: - replica_id = ( - 0, - parallel_state.get_tensor_model_parallel_rank(), - parallel_state.get_data_parallel_rank(), - ) - - return ShardedTensor.from_rank_offsets( - key, - tensor, - *prepend_offsets, - replica_id=replica_id, - prepend_axis_num=prepend_axis_num, - **kwargs, - ) diff --git a/megatron/data/__init__.py b/megatron/data/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/megatron/data/autoaugment.py b/megatron/data/autoaugment.py deleted file mode 100644 index 7f988c5f0411707a8988e63898a49fabb932fbb5..0000000000000000000000000000000000000000 --- a/megatron/data/autoaugment.py +++ /dev/null @@ -1,320 +0,0 @@ -"""AutoAugment data augmentation policy for ImageNet. - --- Begin license text. - -MIT License - -Copyright (c) 2018 Philip Popien - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. - --- End license text. - -Code adapted from https://github.com/DeepVoltaire/AutoAugment. - -This module implements the fixed AutoAugment data augmentation policy for ImageNet provided in -Appendix A, Table 9 of reference [1]. It does not include any of the search code for augmentation -policies. - -Reference: -[1] https://arxiv.org/abs/1805.09501 -""" - -import random - -import numpy as np -from PIL import Image -from PIL import ImageEnhance -from PIL import ImageOps - -_MAX_LEVEL = 10 # Maximum integer strength of an augmentation, if applicable. - - -class ImageNetPolicy: - """Definition of an ImageNetPolicy. - - Implements a fixed AutoAugment data augmentation policy targeted at - ImageNet training by randomly applying at runtime one of the 25 pre-defined - data augmentation sub-policies provided in Reference [1]. - - Usage example as a Pytorch Transform: - >>> transform=transforms.Compose([transforms.Resize(256), - >>> ImageNetPolicy(), - >>> transforms.ToTensor()]) - """ - - def __init__(self, fillcolor=(128, 128, 128)): - """Initialize an ImageNetPolicy. - - Args: - fillcolor (tuple): RGB color components of the color to be used for - filling when needed (default: (128, 128, 128), which - corresponds to gray). - """ - # Instantiate a list of sub-policies. - # Each entry of the list is a SubPolicy which consists of - # two augmentation operations, - # each of those parametrized as operation, probability, magnitude. - # Those two operations are applied sequentially on the image upon call. 
- self.policies = [ - SubPolicy("posterize", 0.4, 8, "rotate", 0.6, 9, fillcolor), - SubPolicy("solarize", 0.6, 5, "autocontrast", 0.6, 5, fillcolor), - SubPolicy("equalize", 0.8, 8, "equalize", 0.6, 3, fillcolor), - SubPolicy("posterize", 0.6, 7, "posterize", 0.6, 6, fillcolor), - SubPolicy("equalize", 0.4, 7, "solarize", 0.2, 4, fillcolor), - SubPolicy("equalize", 0.4, 4, "rotate", 0.8, 8, fillcolor), - SubPolicy("solarize", 0.6, 3, "equalize", 0.6, 7, fillcolor), - SubPolicy("posterize", 0.8, 5, "equalize", 1.0, 2, fillcolor), - SubPolicy("rotate", 0.2, 3, "solarize", 0.6, 8, fillcolor), - SubPolicy("equalize", 0.6, 8, "posterize", 0.4, 6, fillcolor), - SubPolicy("rotate", 0.8, 8, "color", 0.4, 0, fillcolor), - SubPolicy("rotate", 0.4, 9, "equalize", 0.6, 2, fillcolor), - SubPolicy("equalize", 0.0, 7, "equalize", 0.8, 8, fillcolor), - SubPolicy("invert", 0.6, 4, "equalize", 1.0, 8, fillcolor), - SubPolicy("color", 0.6, 4, "contrast", 1.0, 8, fillcolor), - SubPolicy("rotate", 0.8, 8, "color", 1.0, 2, fillcolor), - SubPolicy("color", 0.8, 8, "solarize", 0.8, 7, fillcolor), - SubPolicy("sharpness", 0.4, 7, "invert", 0.6, 8, fillcolor), - SubPolicy("shearX", 0.6, 5, "equalize", 1.0, 9, fillcolor), - SubPolicy("color", 0.4, 0, "equalize", 0.6, 3, fillcolor), - SubPolicy("equalize", 0.4, 7, "solarize", 0.2, 4, fillcolor), - SubPolicy("solarize", 0.6, 5, "autocontrast", 0.6, 5, fillcolor), - SubPolicy("invert", 0.6, 4, "equalize", 1.0, 8, fillcolor), - SubPolicy("color", 0.6, 4, "contrast", 1.0, 8, fillcolor), - SubPolicy("equalize", 0.8, 8, "equalize", 0.6, 3, fillcolor), - ] - - def __call__(self, img): - """Define call method for ImageNetPolicy class.""" - policy_idx = random.randint(0, len(self.policies) - 1) - return self.policies[policy_idx](img) - - def __repr__(self): - """Define repr method for ImageNetPolicy class.""" - return "ImageNetPolicy" - - -class SubPolicy: - """Definition of a SubPolicy. - - A SubPolicy consists of two augmentation operations, - each of those parametrized as operation, probability, magnitude. - The two operations are applied sequentially on the image upon call. - """ - - def __init__( - self, - operation1, - probability1, - magnitude_idx1, - operation2, - probability2, - magnitude_idx2, - fillcolor, - ): - """Initialize a SubPolicy. - - Args: - operation1 (str): Key specifying the first augmentation operation. - There are fourteen key values altogether (see supported_ops below - listing supported operations). probability1 (float): Probability - within [0., 1.] of applying the first augmentation operation. - magnitude_idx1 (int): Integer specifiying the strength of the first - operation as an index further used to derive the magnitude from a - range of possible values. - operation2 (str): Key specifying the second augmentation operation. - probability2 (float): Probability within [0., 1.] of applying the - second augmentation operation. - magnitude_idx2 (int): Integer specifiying the strength of the - second operation as an index further used to derive the magnitude - from a range of possible values. - fillcolor (tuple): RGB color components of the color to be used for - filling. - Returns: - """ - # List of supported operations for operation1 and operation2. 
- supported_ops = [ - "shearX", - "shearY", - "translateX", - "translateY", - "rotate", - "color", - "posterize", - "solarize", - "contrast", - "sharpness", - "brightness", - "autocontrast", - "equalize", - "invert", - ] - assert (operation1 in supported_ops) and ( - operation2 in supported_ops - ), "SubPolicy:one of oper1 or oper2 refers to an unsupported operation." - - assert ( - 0.0 <= probability1 <= 1.0 and 0.0 <= probability2 <= 1.0 - ), "SubPolicy: prob1 and prob2 should be within [0., 1.]." - - assert ( - isinstance(magnitude_idx1, int) and 0 <= magnitude_idx1 <= 10 - ), "SubPolicy: idx1 should be specified as an integer within [0, 10]." - - assert ( - isinstance(magnitude_idx2, int) and 0 <= magnitude_idx2 <= 10 - ), "SubPolicy: idx2 should be specified as an integer within [0, 10]." - - # Define a dictionary where each key refers to a specific type of - # augmentation and the corresponding value is a range of ten possible - # magnitude values for that augmentation. - num_levels = _MAX_LEVEL + 1 - ranges = { - "shearX": np.linspace(0, 0.3, num_levels), - "shearY": np.linspace(0, 0.3, num_levels), - "translateX": np.linspace(0, 150 / 331, num_levels), - "translateY": np.linspace(0, 150 / 331, num_levels), - "rotate": np.linspace(0, 30, num_levels), - "color": np.linspace(0.0, 0.9, num_levels), - "posterize": np.round(np.linspace(8, 4, num_levels), 0).astype( - np.int32 - ), - "solarize": np.linspace(256, 0, num_levels), # range [0, 256] - "contrast": np.linspace(0.0, 0.9, num_levels), - "sharpness": np.linspace(0.0, 0.9, num_levels), - "brightness": np.linspace(0.0, 0.9, num_levels), - "autocontrast": [0] - * num_levels, # This augmentation doesn't use magnitude parameter. - "equalize": [0] - * num_levels, # This augmentation doesn't use magnitude parameter. - "invert": [0] - * num_levels, # This augmentation doesn't use magnitude parameter. - } - - def rotate_with_fill(img, magnitude): - """Define rotation transformation with fill. - - The input image is first rotated, then it is blended together with - a gray mask of the same size. Note that fillcolor as defined - elsewhere in this module doesn't apply here. - - Args: - magnitude (float): rotation angle in degrees. - Returns: - rotated_filled (PIL Image): rotated image with gray filling for - disoccluded areas unveiled by the rotation. - """ - rotated = img.convert("RGBA").rotate(magnitude) - rotated_filled = Image.composite( - rotated, Image.new("RGBA", rotated.size, (128,) * 4), rotated - ) - return rotated_filled.convert(img.mode) - - # Define a dictionary of augmentation functions where each key refers - # to a specific type of augmentation and the corresponding value defines - # the augmentation itself using a lambda function. 
- # pylint: disable=unnecessary-lambda - func_dict = { - "shearX": lambda img, magnitude: img.transform( - img.size, - Image.AFFINE, - (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0), - Image.BICUBIC, - fillcolor=fillcolor, - ), - "shearY": lambda img, magnitude: img.transform( - img.size, - Image.AFFINE, - (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0), - Image.BICUBIC, - fillcolor=fillcolor, - ), - "translateX": lambda img, magnitude: img.transform( - img.size, - Image.AFFINE, - ( - 1, - 0, - magnitude * img.size[0] * random.choice([-1, 1]), - 0, - 1, - 0, - ), - fillcolor=fillcolor, - ), - "translateY": lambda img, magnitude: img.transform( - img.size, - Image.AFFINE, - ( - 1, - 0, - 0, - 0, - 1, - magnitude * img.size[1] * random.choice([-1, 1]), - ), - fillcolor=fillcolor, - ), - "rotate": lambda img, magnitude: rotate_with_fill(img, magnitude), - "color": lambda img, magnitude: ImageEnhance.Color(img).enhance( - 1 + magnitude * random.choice([-1, 1]) - ), - "posterize": lambda img, magnitude: ImageOps.posterize( - img, magnitude - ), - "solarize": lambda img, magnitude: ImageOps.solarize( - img, magnitude - ), - "contrast": lambda img, magnitude: ImageEnhance.Contrast( - img - ).enhance(1 + magnitude * random.choice([-1, 1])), - "sharpness": lambda img, magnitude: ImageEnhance.Sharpness( - img - ).enhance(1 + magnitude * random.choice([-1, 1])), - "brightness": lambda img, magnitude: ImageEnhance.Brightness( - img - ).enhance(1 + magnitude * random.choice([-1, 1])), - "autocontrast": lambda img, magnitude: ImageOps.autocontrast(img), - "equalize": lambda img, magnitude: ImageOps.equalize(img), - "invert": lambda img, magnitude: ImageOps.invert(img), - } - - # Store probability, function and magnitude of the first augmentation - # for the sub-policy. - self.probability1 = probability1 - self.operation1 = func_dict[operation1] - self.magnitude1 = ranges[operation1][magnitude_idx1] - - # Store probability, function and magnitude of the second augmentation - # for the sub-policy. - self.probability2 = probability2 - self.operation2 = func_dict[operation2] - self.magnitude2 = ranges[operation2][magnitude_idx2] - - def __call__(self, img): - """Define call method for SubPolicy class.""" - # Randomly apply operation 1. - if random.random() < self.probability1: - img = self.operation1(img, self.magnitude1) - - # Randomly apply operation 2. - if random.random() < self.probability2: - img = self.operation2(img, self.magnitude2) - - return img diff --git a/megatron/data/bert_dataset.py b/megatron/data/bert_dataset.py deleted file mode 100644 index 036e6bccc9166214ed82f836dab4dbff974686ed..0000000000000000000000000000000000000000 --- a/megatron/data/bert_dataset.py +++ /dev/null @@ -1,183 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. - -"""BERT Style dataset.""" - -import numpy as np -import torch - -from megatron import ( - get_args, - get_tokenizer, - mpu, - print_rank_0 -) -from megatron.data.dataset_utils import ( - get_samples_mapping, - get_a_and_b_segments, - truncate_segments, - create_tokens_and_tokentypes, - create_masked_lm_predictions -) - -class BertDataset(torch.utils.data.Dataset): - - def __init__(self, name, indexed_dataset, data_prefix, - num_epochs, max_num_samples, masked_lm_prob, - max_seq_length, short_seq_prob, seed, binary_head): - - # Params to store. - self.name = name - self.seed = seed - self.masked_lm_prob = masked_lm_prob - self.max_seq_length = max_seq_length - self.binary_head = binary_head - - # Dataset. 
- self.indexed_dataset = indexed_dataset - - # Build the samples mapping. - self.samples_mapping = get_samples_mapping(self.indexed_dataset, - data_prefix, - num_epochs, - max_num_samples, - self.max_seq_length - 3, # account for added tokens - short_seq_prob, - self.seed, - self.name, - self.binary_head) - - # Vocab stuff. - tokenizer = get_tokenizer() - self.vocab_id_list = list(tokenizer.inv_vocab.keys()) - self.vocab_id_to_token_dict = tokenizer.inv_vocab - self.cls_id = tokenizer.cls - self.sep_id = tokenizer.sep - self.mask_id = tokenizer.mask - self.pad_id = tokenizer.pad - - def __len__(self): - return self.samples_mapping.shape[0] - - def __getitem__(self, idx): - start_idx, end_idx, seq_length = self.samples_mapping[idx] - sample = [self.indexed_dataset[i] for i in range(start_idx, end_idx)] - # Note that this rng state should be numpy and not python since - # python randint is inclusive whereas the numpy one is exclusive. - # We % 2**32 since numpy requres the seed to be between 0 and 2**32 - 1 - np_rng = np.random.RandomState(seed=((self.seed + idx) % 2**32)) - return build_training_sample(sample, seq_length, - self.max_seq_length, # needed for padding - self.vocab_id_list, - self.vocab_id_to_token_dict, - self.cls_id, self.sep_id, - self.mask_id, self.pad_id, - self.masked_lm_prob, np_rng, - self.binary_head) - - - - -def build_training_sample(sample, - target_seq_length, max_seq_length, - vocab_id_list, vocab_id_to_token_dict, - cls_id, sep_id, mask_id, pad_id, - masked_lm_prob, np_rng, binary_head): - """Biuld training sample. - - Arguments: - sample: A list of sentences in which each sentence is a list token ids. - target_seq_length: Desired sequence length. - max_seq_length: Maximum length of the sequence. All values are padded to - this length. - vocab_id_list: List of vocabulary ids. Used to pick a random id. - vocab_id_to_token_dict: A dictionary from vocab ids to text tokens. - cls_id: Start of example id. - sep_id: Separator id. - mask_id: Mask token id. - pad_id: Padding token id. - masked_lm_prob: Probability to mask tokens. - np_rng: Random number genenrator. Note that this rng state should be - numpy and not python since python randint is inclusive for - the opper bound whereas the numpy one is exclusive. - """ - - if binary_head: - # We assume that we have at least two sentences in the sample - assert len(sample) > 1 - assert target_seq_length <= max_seq_length - - # Divide sample into two segments (A and B). - if binary_head: - tokens_a, tokens_b, is_next_random = get_a_and_b_segments(sample, - np_rng) - else: - tokens_a = [] - for j in range(len(sample)): - tokens_a.extend(sample[j]) - tokens_b = [] - is_next_random = False - - # Truncate to `target_sequence_length`. - max_num_tokens = target_seq_length - truncated = truncate_segments(tokens_a, tokens_b, len(tokens_a), - len(tokens_b), max_num_tokens, np_rng) - - # Build tokens and toketypes. - tokens, tokentypes = create_tokens_and_tokentypes(tokens_a, tokens_b, - cls_id, sep_id) - - # Masking. - max_predictions_per_seq = masked_lm_prob * max_num_tokens - (tokens, masked_positions, masked_labels, _, _) = create_masked_lm_predictions( - tokens, vocab_id_list, vocab_id_to_token_dict, masked_lm_prob, - cls_id, sep_id, mask_id, max_predictions_per_seq, np_rng) - - # Padding. 
- tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np \ - = pad_and_convert_to_numpy(tokens, tokentypes, masked_positions, - masked_labels, pad_id, max_seq_length) - - train_sample = { - 'text': tokens_np, - 'types': tokentypes_np, - 'labels': labels_np, - 'is_random': int(is_next_random), - 'loss_mask': loss_mask_np, - 'padding_mask': padding_mask_np, - 'truncated': int(truncated)} - return train_sample - - -def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions, - masked_labels, pad_id, max_seq_length): - """Pad sequences and convert them to numpy.""" - - # Some checks. - num_tokens = len(tokens) - padding_length = max_seq_length - num_tokens - assert padding_length >= 0, \ - f"num_tokens ({num_tokens}) is greater than " \ - "max_seq_length ({max_seq_length})." - assert len(tokentypes) == num_tokens - assert len(masked_positions) == len(masked_labels) - - # Tokens and token types. - filler = [pad_id] * padding_length - tokens_np = np.array(tokens + filler, dtype=np.int64) - tokentypes_np = np.array(tokentypes + filler, dtype=np.int64) - - # Padding mask. - padding_mask_np = np.array([1] * num_tokens + [0] * padding_length, - dtype=np.int64) - - # Lables and loss mask. - labels = [-1] * max_seq_length - loss_mask = [0] * max_seq_length - for i in range(len(masked_positions)): - assert masked_positions[i] < num_tokens - labels[masked_positions[i]] = masked_labels[i] - loss_mask[masked_positions[i]] = 1 - labels_np = np.array(labels, dtype=np.int64) - loss_mask_np = np.array(loss_mask, dtype=np.int64) - - return tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np diff --git a/megatron/data/biencoder_dataset_utils.py b/megatron/data/biencoder_dataset_utils.py deleted file mode 100644 index f137528adaef2af2d87567db9c73ef709b02b0f0..0000000000000000000000000000000000000000 --- a/megatron/data/biencoder_dataset_utils.py +++ /dev/null @@ -1,209 +0,0 @@ -import os -import time - -import numpy as np -import torch - -from megatron import get_args, get_tokenizer, print_rank_0 -from megatron.core import mpu, tensor_parallel -from megatron.data.dataset_utils import create_masked_lm_predictions, \ - pad_and_convert_to_numpy -from megatron.data.data_samplers import MegatronPretrainingSampler - -def make_attention_mask(source_block, target_block): - """ - Returns a 2-dimensional (2-D) attention mask - :param source_block: 1-D array - :param target_block: 1-D array - """ - mask = (target_block[None, :] >= 1) * (source_block[:, None] >= 1) - mask = mask.astype(np.int64) - # (source_length, target_length) - return mask - -def get_one_epoch_dataloader(dataset, micro_batch_size=None): - """Specifically one epoch to be used in an indexing job.""" - args = get_args() - - if micro_batch_size is None: - micro_batch_size = args.micro_batch_size - num_workers = args.num_workers - - # Use megatron's sampler with consumed samples set to 0 as - # this is only for evaluation and don't intend to resume half way. - # Also, set the drop last to false as don't intend to remove - # the last batch - batch_sampler = MegatronPretrainingSampler( - total_samples=len(dataset), - consumed_samples=0, - micro_batch_size=args.micro_batch_size, - data_parallel_rank=mpu.get_data_parallel_rank(), - data_parallel_size=mpu.get_data_parallel_world_size(), - drop_last=False) - - return torch.utils.data.DataLoader(dataset, - batch_sampler=batch_sampler, - num_workers=num_workers, - pin_memory=True) - - -def get_ict_batch(data_iterator): - # Items and their type. 
- keys = ['query_tokens', 'query_mask', - 'context_tokens', 'context_mask', 'block_data'] - datatype = torch.int64 - - # Broadcast data. - if data_iterator is None: - data = None - else: - data = next(data_iterator) - data_b = tensor_parallel.broadcast_data(keys, data, datatype) - - # Unpack. - query_tokens = data_b['query_tokens'].long() - query_mask = data_b['query_mask'] < 0.5 - context_tokens = data_b['context_tokens'].long() - context_mask = data_b['context_mask'] < 0.5 - block_indices = data_b['block_data'].long() - - return query_tokens, query_mask,\ - context_tokens, context_mask, block_indices - - -def join_str_list(str_list): - """Join a list of strings, handling spaces appropriately""" - result = "" - for s in str_list: - if s.startswith("##"): - result += s[2:] - else: - result += " " + s - return result - - -class BlockSampleData(object): - """A struct for fully describing a fixed-size block of data as used in REALM - - :param start_idx: for first sentence of the block - :param end_idx: for last sentence of the block (may be partially truncated in sample construction) - :param doc_idx: the index of the document from which the block comes in the original indexed dataset - :param block_idx: a unique integer identifier given to every block. - """ - def __init__(self, start_idx, end_idx, doc_idx, block_idx): - self.start_idx = start_idx - self.end_idx = end_idx - self.doc_idx = doc_idx - self.block_idx = block_idx - - def as_array(self): - return np.array([self.start_idx, self.end_idx, self.doc_idx, self.block_idx]).astype(np.int64) - - def as_tuple(self): - return self.start_idx, self.end_idx, self.doc_idx, self.block_idx - - -class BlockSamplesMapping(object): - def __init__(self, mapping_array): - # make sure that the array is compatible with BlockSampleData - assert mapping_array.shape[1] == 4 - self.mapping_array = mapping_array - - def __len__(self): - return self.mapping_array.shape[0] - - def __getitem__(self, idx): - """Get the data associated with an indexed sample.""" - sample_data = BlockSampleData(*self.mapping_array[idx]) - return sample_data - - -def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epochs, - max_num_samples, max_seq_length, seed, name, use_one_sent_docs=False): - """Get samples mapping for a dataset over fixed size blocks. This function also requires - a dataset of the titles for the source documents since their lengths must be taken into account. - - :return: samples_mapping (BlockSamplesMapping) - """ - - if not num_epochs: - if not max_num_samples: - raise ValueError("Need to specify either max_num_samples " - "or num_epochs") - num_epochs = np.iinfo(np.int32).max - 1 - if not max_num_samples: - max_num_samples = np.iinfo(np.int64).max - 1 - - # Filename of the index mapping - indexmap_filename = data_prefix - indexmap_filename += '_{}_indexmap'.format(name) - if num_epochs != (np.iinfo(np.int32).max - 1): - indexmap_filename += '_{}ep'.format(num_epochs) - if max_num_samples != (np.iinfo(np.int64).max - 1): - indexmap_filename += '_{}mns'.format(max_num_samples) - indexmap_filename += '_{}msl'.format(max_seq_length) - indexmap_filename += '_{}s'.format(seed) - if use_one_sent_docs: - indexmap_filename += '_1sentok' - indexmap_filename += '.npy' - - # Build the indexed mapping if not exist. 
- if mpu.get_data_parallel_rank() == 0 and \ - not os.path.isfile(indexmap_filename): - print(' > WARNING: could not find index map file {}, building ' - 'the indices on rank 0 ...'.format(indexmap_filename)) - - # Make sure the types match the helpers input types. - assert block_dataset.document_indices.dtype == np.int64 - assert block_dataset.sequence_lengths.dtype == np.int32 - - # Build samples mapping - verbose = torch.distributed.get_rank() == 0 - start_time = time.time() - print_rank_0(' > building samples index mapping for {} ...'.format( - name)) - - from megatron.core.datasets import helpers - mapping_array = helpers.build_blocks_mapping( - block_dataset.document_indices, - block_dataset.sequence_lengths, - title_dataset.sequence_lengths, - num_epochs, - max_num_samples, - max_seq_length - 3, # account for added tokens - seed, - verbose, - use_one_sent_docs) - - - print_rank_0(' > done building samples index mapping') - np.save(indexmap_filename, mapping_array, allow_pickle=True) - print_rank_0(' > saved the index mapping in {}'.format( - indexmap_filename)) - # Make sure all the ranks have built the mapping - print_rank_0(' > elapsed time to build and save samples mapping ' - '(seconds): {:4f}'.format( - time.time() - start_time)) - - # This should be a barrier but nccl barrier assumes - # device_index=rank which is not the case for model - # parallel case - counts = torch.cuda.LongTensor([1]) - torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group()) - assert counts[0].item() == torch.distributed.get_world_size( - group=mpu.get_data_parallel_group()) - - # Load indexed dataset. - print_rank_0(' > loading indexed mapping from {}'.format( - indexmap_filename)) - start_time = time.time() - - mapping_array = np.load(indexmap_filename, allow_pickle=True, mmap_mode='r') - samples_mapping = BlockSamplesMapping(mapping_array) - - print_rank_0(' loaded indexed file in {:3.3f} seconds'.format( - time.time() - start_time)) - print_rank_0(' total number of samples: {}'.format( - mapping_array.shape[0])) - - return samples_mapping diff --git a/megatron/data/data_samplers.py b/megatron/data/data_samplers.py deleted file mode 100644 index 8dec2c192236f0c0af4b5534e963a1e65cf455ad..0000000000000000000000000000000000000000 --- a/megatron/data/data_samplers.py +++ /dev/null @@ -1,186 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
- -"""Dataloaders.""" - - -import random -import torch -import numpy as np -from torch.utils.data import Dataset -from megatron import get_args -from megatron.core import mpu - - -def build_pretraining_data_loader(dataset, consumed_samples): - """Buld dataloader given an input dataset.""" - - if dataset is None: - return None - args = get_args() - - # Megatron sampler - if args.dataloader_type == 'single': - batch_sampler = MegatronPretrainingSampler( - total_samples=len(dataset), - consumed_samples=consumed_samples, - micro_batch_size=args.micro_batch_size, - data_parallel_rank=mpu.get_data_parallel_rank(), - data_parallel_size=mpu.get_data_parallel_world_size()) - elif args.dataloader_type == 'cyclic': - batch_sampler = MegatronPretrainingRandomSampler( - dataset, - total_samples=len(dataset), - consumed_samples=consumed_samples, - micro_batch_size=args.micro_batch_size, - data_parallel_rank=mpu.get_data_parallel_rank(), - data_parallel_size=mpu.get_data_parallel_world_size(), - data_sharding=args.data_sharding) - else: - raise Exception('{} dataloader type is not supported.'.format( - args.dataloader_type)) - - # Torch dataloader. - return torch.utils.data.DataLoader(dataset, - batch_sampler=batch_sampler, - num_workers=args.num_workers, - pin_memory=True) - -class MegatronPretrainingSampler: - - def __init__(self, total_samples, consumed_samples, micro_batch_size, - data_parallel_rank, data_parallel_size, drop_last=True): - # Keep a copy of input params for later use. - self.total_samples = total_samples - self.consumed_samples = consumed_samples - self.micro_batch_size = micro_batch_size - self.data_parallel_rank = data_parallel_rank - self.micro_batch_times_data_parallel_size = \ - self.micro_batch_size * data_parallel_size - self.drop_last = drop_last - - # Sanity checks. 
- assert self.total_samples > 0, \ - 'no sample to consume: {}'.format(self.total_samples) - assert self.consumed_samples < self.total_samples, \ - 'no samples left to consume: {}, {}'.format(self.consumed_samples, - self.total_samples) - assert self.micro_batch_size > 0 - assert data_parallel_size > 0 - assert self.data_parallel_rank < data_parallel_size, \ - 'data_parallel_rank should be smaller than data size: {}, ' \ - '{}'.format(self.data_parallel_rank, data_parallel_size) - - def __len__(self): - return self.total_samples - - def get_start_end_idx(self): - start_idx = self.data_parallel_rank * self.micro_batch_size - end_idx = start_idx + self.micro_batch_size - return start_idx, end_idx - - def __iter__(self): - batch = [] - # Last batch will be dropped if drop_last is not set False - for idx in range(self.consumed_samples, self.total_samples): - batch.append(idx) - if len(batch) == self.micro_batch_times_data_parallel_size: - start_idx, end_idx = self.get_start_end_idx() - yield batch[start_idx:end_idx] - batch = [] - - # Check the last partial batch and see drop_last is set - if len(batch) > 0 and not self.drop_last: - start_idx, end_idx = self.get_start_end_idx() - yield batch[start_idx:end_idx] - - -class RandomSeedDataset(Dataset): - - def __init__(self, dataset): - args = get_args() - self.base_seed = args.seed - self.curr_seed = args.seed - self.dataset = dataset - - def __len__(self): - return len(self.dataset) - - def set_epoch(self, epoch): - self.curr_seed = self.base_seed + epoch - - def __getitem__(self, idx): - seed = idx + self.curr_seed - torch.manual_seed(seed) - random.seed(seed) - np.random.seed(seed) - return self.dataset[idx] - - -class MegatronPretrainingRandomSampler: - - def __init__(self, dataset, total_samples, consumed_samples, micro_batch_size, - data_parallel_rank, data_parallel_size, data_sharding): - # Keep a copy of input params for later use. - self.dataset = dataset - self.total_samples = total_samples - self.consumed_samples = consumed_samples - self.micro_batch_size = micro_batch_size - self.data_parallel_rank = data_parallel_rank - self.data_parallel_size = data_parallel_size - self.data_sharding = data_sharding - self.micro_batch_times_data_parallel_size = \ - self.micro_batch_size * data_parallel_size - self.last_batch_size = \ - self.total_samples % self.micro_batch_times_data_parallel_size - - # Sanity checks. 
- assert self.total_samples > 0, \ - 'no sample to consume: {}'.format(self.total_samples) - assert self.micro_batch_size > 0 - assert data_parallel_size > 0 - assert self.data_parallel_rank < data_parallel_size, \ - 'data_parallel_rank should be smaller than data size: {}, ' \ - '{}'.format(self.data_parallel_rank, data_parallel_size) - - def __len__(self): - return self.total_samples - - def __iter__(self): - active_total_samples = self.total_samples - self.last_batch_size - self.epoch = self.consumed_samples // active_total_samples - current_epoch_samples = self.consumed_samples % active_total_samples - assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0 - - if isinstance(self.dataset, RandomSeedDataset): - self.dataset.set_epoch(self.epoch) - - # data sharding and random sampling - if self.data_sharding: - bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) \ - * self.micro_batch_size - bucket_offset = current_epoch_samples // self.data_parallel_size - start_idx = self.data_parallel_rank * bucket_size - - g = torch.Generator() - g.manual_seed(self.epoch) - random_idx = torch.randperm(bucket_size, generator=g).tolist() - idx_range = [start_idx + x for x in random_idx[bucket_offset:]] - else: - full_bucket_size = (self.total_samples // self.micro_batch_size) \ - * self.micro_batch_size - full_bucket_offset = current_epoch_samples - g = torch.Generator() - g.manual_seed(self.epoch) - idx_range_total = \ - torch.randperm(full_bucket_size, generator=g).tolist() - idx_range_active = idx_range_total[full_bucket_offset:] - idx_range = idx_range_active[self.data_parallel_rank::self.data_parallel_size] - - batch = [] - # Last batch if not complete will be dropped. - for idx in idx_range: - batch.append(idx) - if len(batch) == self.micro_batch_size: - self.consumed_samples += self.micro_batch_times_data_parallel_size - yield batch - batch = [] diff --git a/megatron/data/dataset_utils.py b/megatron/data/dataset_utils.py deleted file mode 100644 index 561129c865d14f2c5578dfbdc6b2a8bbfac88b78..0000000000000000000000000000000000000000 --- a/megatron/data/dataset_utils.py +++ /dev/null @@ -1,743 +0,0 @@ -# coding=utf-8 -# Copyright 2018 The Google AI Language Team Authors, and NVIDIA. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -# Most of the code here has been copied from: -# https://github.com/google-research/albert/blob/master/create_pretraining_data.py -# with some modifications. 
- -import math -import os -import time -import collections - -import numpy as np -import torch - -from megatron import ( - get_args, - print_rank_0 -) -from megatron.core import mpu -from megatron.core.datasets.indexed_dataset import MMapIndexedDataset - - -DSET_TYPE_BERT = 'standard_bert' -DSET_TYPE_ICT = 'ict' -DSET_TYPE_T5 = 't5' -DSET_TYPE_MULTIMODAL = 'multimodal' - -DSET_TYPES = [DSET_TYPE_BERT, DSET_TYPE_ICT, DSET_TYPE_T5, DSET_TYPE_MULTIMODAL] - - -def get_datasets_weights_and_num_samples(data_prefix, - train_valid_test_num_samples): - - # The data prefix should be in the format of: - # weight-1, data-prefix-1, weight-2, data-prefix-2, .. - assert len(data_prefix) % 2 == 0 - num_datasets = len(data_prefix) // 2 - weights = [0]*num_datasets - prefixes = [0]*num_datasets - for i in range(num_datasets): - weights[i] = float(data_prefix[2*i]) - prefixes[i] = (data_prefix[2*i+1]).strip() - # Normalize weights - weight_sum = 0.0 - for weight in weights: - weight_sum += weight - assert weight_sum > 0.0 - weights = [weight / weight_sum for weight in weights] - - # Add 0.5% (the 1.005 factor) so in case the bleding dataset does - # not uniformly distribute the number of samples, we still have - # samples left to feed to the network. - if isinstance(train_valid_test_num_samples, list): - datasets_train_valid_test_num_samples = [] - for weight in weights: - datasets_train_valid_test_num_samples.append( - [int(math.ceil(val * weight * 1.005)) - for val in train_valid_test_num_samples]) - else: - # Used when separate dataset files are provided for train, - # valid and test - datasets_train_valid_test_num_samples = [ - int(math.ceil(train_valid_test_num_samples * weight * 1.005)) - for weight in weights] - - return prefixes, weights, datasets_train_valid_test_num_samples - - -def get_a_and_b_segments(sample, np_rng): - """Divide sample into a and b segments.""" - - # Number of sentences in the sample. - n_sentences = len(sample) - # Make sure we always have two sentences. - assert n_sentences > 1, 'make sure each sample has at least two sentences.' - - # First part: - # `a_end` is how many sentences go into the `A`. - a_end = 1 - if n_sentences >= 3: - # Note that randin in numpy is exclusive. - a_end = np_rng.randint(1, n_sentences) - tokens_a = [] - for j in range(a_end): - tokens_a.extend(sample[j]) - - # Second part: - tokens_b = [] - for j in range(a_end, n_sentences): - tokens_b.extend(sample[j]) - - # Random next: - is_next_random = False - if np_rng.random() < 0.5: - is_next_random = True - tokens_a, tokens_b = tokens_b, tokens_a - - return tokens_a, tokens_b, is_next_random - - -def truncate_segments(tokens_a, tokens_b, len_a, len_b, max_num_tokens, np_rng): - """Truncates a pair of sequences to a maximum sequence length.""" - #print(len_a, len_b, max_num_tokens) - assert len_a > 0 - if len_a + len_b <= max_num_tokens: - return False - while len_a + len_b > max_num_tokens: - if len_a > len_b: - len_a -= 1 - tokens = tokens_a - else: - len_b -= 1 - tokens = tokens_b - if np_rng.random() < 0.5: - del tokens[0] - else: - tokens.pop() - return True - - -def create_tokens_and_tokentypes(tokens_a, tokens_b, cls_id, sep_id): - """Merge segments A and B, add [CLS] and [SEP] and build tokentypes.""" - - tokens = [] - tokentypes = [] - # [CLS]. - tokens.append(cls_id) - tokentypes.append(0) - # Segment A. - for token in tokens_a: - tokens.append(token) - tokentypes.append(0) - # [SEP]. - tokens.append(sep_id) - tokentypes.append(0) - # Segment B. 
- for token in tokens_b: - tokens.append(token) - tokentypes.append(1) - if tokens_b: - # [SEP]. - tokens.append(sep_id) - tokentypes.append(1) - - return tokens, tokentypes - - -MaskedLmInstance = collections.namedtuple("MaskedLmInstance", - ["index", "label"]) - - -def is_start_piece(piece): - """Check if the current word piece is the starting piece (BERT).""" - # When a word has been split into - # WordPieces, the first token does not have any marker and any subsequence - # tokens are prefixed with ##. So whenever we see the ## token, we - # append it to the previous set of word indexes. - return not piece.startswith("##") - - -def create_masked_lm_predictions(tokens, - vocab_id_list, vocab_id_to_token_dict, - masked_lm_prob, - cls_id, sep_id, mask_id, - max_predictions_per_seq, - np_rng, - max_ngrams=3, - do_whole_word_mask=True, - favor_longer_ngram=False, - do_permutation=False, - geometric_dist=False, - masking_style="bert"): - """Creates the predictions for the masked LM objective. - Note: Tokens here are vocab ids and not text tokens.""" - - cand_indexes = [] - # Note(mingdachen): We create a list for recording if the piece is - # the starting piece of current token, where 1 means true, so that - # on-the-fly whole word masking is possible. - token_boundary = [0] * len(tokens) - - for (i, token) in enumerate(tokens): - if token == cls_id or token == sep_id: - token_boundary[i] = 1 - continue - # Whole Word Masking means that if we mask all of the wordpieces - # corresponding to an original word. - # - # Note that Whole Word Masking does *not* change the training code - # at all -- we still predict each WordPiece independently, softmaxed - # over the entire vocabulary. - if (do_whole_word_mask and len(cand_indexes) >= 1 and - not is_start_piece(vocab_id_to_token_dict[token])): - cand_indexes[-1].append(i) - else: - cand_indexes.append([i]) - if is_start_piece(vocab_id_to_token_dict[token]): - token_boundary[i] = 1 - - output_tokens = list(tokens) - - masked_lm_positions = [] - masked_lm_labels = [] - - if masked_lm_prob == 0: - return (output_tokens, masked_lm_positions, - masked_lm_labels, token_boundary) - - num_to_predict = min(max_predictions_per_seq, - max(1, int(round(len(tokens) * masked_lm_prob)))) - - ngrams = np.arange(1, max_ngrams + 1, dtype=np.int64) - if not geometric_dist: - # Note(mingdachen): - # By default, we set the probilities to favor shorter ngram sequences. - pvals = 1. / np.arange(1, max_ngrams + 1) - pvals /= pvals.sum(keepdims=True) - if favor_longer_ngram: - pvals = pvals[::-1] - - ngram_indexes = [] - for idx in range(len(cand_indexes)): - ngram_index = [] - for n in ngrams: - ngram_index.append(cand_indexes[idx:idx + n]) - ngram_indexes.append(ngram_index) - - np_rng.shuffle(ngram_indexes) - - (masked_lms, masked_spans) = ([], []) - covered_indexes = set() - for cand_index_set in ngram_indexes: - if len(masked_lms) >= num_to_predict: - break - if not cand_index_set: - continue - # Note(mingdachen): - # Skip current piece if they are covered in lm masking or previous ngrams. - for index_set in cand_index_set[0]: - for index in index_set: - if index in covered_indexes: - continue - - if not geometric_dist: - n = np_rng.choice(ngrams[:len(cand_index_set)], - p=pvals[:len(cand_index_set)] / - pvals[:len(cand_index_set)].sum(keepdims=True)) - else: - # Sampling "n" from the geometric distribution and clipping it to - # the max_ngrams. 
Using p=0.2 default from the SpanBERT paper - # https://arxiv.org/pdf/1907.10529.pdf (Sec 3.1) - n = min(np_rng.geometric(0.2), max_ngrams) - - index_set = sum(cand_index_set[n - 1], []) - n -= 1 - # Note(mingdachen): - # Repeatedly looking for a candidate that does not exceed the - # maximum number of predictions by trying shorter ngrams. - while len(masked_lms) + len(index_set) > num_to_predict: - if n == 0: - break - index_set = sum(cand_index_set[n - 1], []) - n -= 1 - # If adding a whole-word mask would exceed the maximum number of - # predictions, then just skip this candidate. - if len(masked_lms) + len(index_set) > num_to_predict: - continue - is_any_index_covered = False - for index in index_set: - if index in covered_indexes: - is_any_index_covered = True - break - if is_any_index_covered: - continue - for index in index_set: - covered_indexes.add(index) - masked_token = None - if masking_style == "bert": - # 80% of the time, replace with [MASK] - if np_rng.random() < 0.8: - masked_token = mask_id - else: - # 10% of the time, keep original - if np_rng.random() < 0.5: - masked_token = tokens[index] - # 10% of the time, replace with random word - else: - masked_token = vocab_id_list[np_rng.randint(0, len(vocab_id_list))] - elif masking_style == "t5": - masked_token = mask_id - else: - raise ValueError("invalid value of masking style") - - output_tokens[index] = masked_token - masked_lms.append(MaskedLmInstance(index=index, label=tokens[index])) - - masked_spans.append(MaskedLmInstance( - index=index_set, - label=[tokens[index] for index in index_set])) - - assert len(masked_lms) <= num_to_predict - np_rng.shuffle(ngram_indexes) - - select_indexes = set() - if do_permutation: - for cand_index_set in ngram_indexes: - if len(select_indexes) >= num_to_predict: - break - if not cand_index_set: - continue - # Note(mingdachen): - # Skip current piece if they are covered in lm masking or previous ngrams. - for index_set in cand_index_set[0]: - for index in index_set: - if index in covered_indexes or index in select_indexes: - continue - - n = np.random.choice(ngrams[:len(cand_index_set)], - p=pvals[:len(cand_index_set)] / - pvals[:len(cand_index_set)].sum(keepdims=True)) - index_set = sum(cand_index_set[n - 1], []) - n -= 1 - - while len(select_indexes) + len(index_set) > num_to_predict: - if n == 0: - break - index_set = sum(cand_index_set[n - 1], []) - n -= 1 - # If adding a whole-word mask would exceed the maximum number of - # predictions, then just skip this candidate. 
- if len(select_indexes) + len(index_set) > num_to_predict: - continue - is_any_index_covered = False - for index in index_set: - if index in covered_indexes or index in select_indexes: - is_any_index_covered = True - break - if is_any_index_covered: - continue - for index in index_set: - select_indexes.add(index) - assert len(select_indexes) <= num_to_predict - - select_indexes = sorted(select_indexes) - permute_indexes = list(select_indexes) - np_rng.shuffle(permute_indexes) - orig_token = list(output_tokens) - - for src_i, tgt_i in zip(select_indexes, permute_indexes): - output_tokens[src_i] = orig_token[tgt_i] - masked_lms.append(MaskedLmInstance(index=src_i, label=orig_token[src_i])) - - masked_lms = sorted(masked_lms, key=lambda x: x.index) - # Sort the spans by the index of the first span - masked_spans = sorted(masked_spans, key=lambda x: x.index[0]) - - for p in masked_lms: - masked_lm_positions.append(p.index) - masked_lm_labels.append(p.label) - return (output_tokens, masked_lm_positions, masked_lm_labels, token_boundary, masked_spans) - - -def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions, - masked_labels, pad_id, max_seq_length): - """Pad sequences and convert them to numpy.""" - - # Some checks. - num_tokens = len(tokens) - padding_length = max_seq_length - num_tokens - assert padding_length >= 0 - assert len(tokentypes) == num_tokens - assert len(masked_positions) == len(masked_labels) - - # Tokens and token types. - filler = [pad_id] * padding_length - tokens_np = np.array(tokens + filler, dtype=np.int64) - tokentypes_np = np.array(tokentypes + filler, dtype=np.int64) - - # Padding mask. - padding_mask_np = np.array([1] * num_tokens + [0] * padding_length, - dtype=np.int64) - - # Lables and loss mask. - labels = [-1] * max_seq_length - loss_mask = [0] * max_seq_length - for i in range(len(masked_positions)): - assert masked_positions[i] < num_tokens - labels[masked_positions[i]] = masked_labels[i] - loss_mask[masked_positions[i]] = 1 - labels_np = np.array(labels, dtype=np.int64) - loss_mask_np = np.array(loss_mask, dtype=np.int64) - - return tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np - - -def build_train_valid_test_datasets_with_prefixes(train_valid_test_num_samples, - max_seq_length, - seed, - train_data_prefix=None, - valid_data_prefix=None, - test_data_prefix=None, - binary_head=False, - max_seq_length_dec=None, - dataset_type='standard_bert'): - print_rank_0("Separate data paths provided for train, valid & test.") - - train_dataset, valid_dataset, test_dataset = None, None, None - # Single dataset. 
- if train_data_prefix is not None: - train_dataset = build_dataset("train", train_data_prefix, - train_valid_test_num_samples[0], - max_seq_length, seed, - binary_head, max_seq_length_dec, - dataset_type=dataset_type) - - if valid_data_prefix is not None: - valid_dataset = build_dataset("valid", valid_data_prefix, - train_valid_test_num_samples[1], - max_seq_length, seed, False, - binary_head, max_seq_length_dec, - dataset_type=dataset_type) - - if test_data_prefix is not None: - test_dataset = build_dataset("test", test_data_prefix, - train_valid_test_num_samples[2], - max_seq_length, seed, False, - binary_head, max_seq_length_dec, - dataset_type=dataset_type) - - return (train_dataset, valid_dataset, test_dataset) - - -def build_train_valid_test_datasets(data_prefix, splits_string, - train_valid_test_num_samples, - max_seq_length, seed, - binary_head=False, - max_seq_length_dec=None, - dataset_type='standard_bert'): - - if len(data_prefix) == 1: - return _build_train_valid_test_datasets(data_prefix[0], - splits_string, - train_valid_test_num_samples, - max_seq_length, seed, - binary_head, - max_seq_length_dec, - dataset_type=dataset_type) - - raise NotImplementedError("Blending currently unsupported for non-GPT dataset instances") - - -def _build_train_valid_test_datasets(data_prefix, splits_string, - train_valid_test_num_samples, - max_seq_length, seed, - binary_head, - max_seq_length_dec, - dataset_type='standard_bert'): - - # Indexed dataset. - indexed_dataset = get_indexed_dataset_(data_prefix, - dataset_type) - - # Get start and end indices of train/valid/train into doc-idx - # Note that doc-idx is desinged to be num-docs + 1 so we can - # easily iterate over it. - total_num_of_documents = indexed_dataset.document_indices.shape[0] - 1 - splits = get_train_valid_test_split_(splits_string, total_num_of_documents) - - # Print stats about the splits. - print_rank_0(' > dataset split:') - - def print_split_stats(name, index): - print_rank_0(' {}:'.format(name)) - print_rank_0(' document indices in [{}, {}) total of {} ' - 'documents'.format(splits[index], splits[index + 1], - splits[index + 1] - splits[index])) - start_index = indexed_dataset.document_indices[splits[index]] - end_index = indexed_dataset.document_indices[splits[index + 1]] - print_rank_0(' sentence indices in [{}, {}) total of {} ' - 'sentences'.format(start_index, end_index, - end_index - start_index)) - print_split_stats('train', 0) - print_split_stats('validation', 1) - print_split_stats('test', 2) - - def build_split_dataset(index, name): - dataset = None - if splits[index + 1] > splits[index]: - # Get the pointer to the original doc-idx so we can set it later. - doc_idx_ptr = indexed_dataset.get_document_indices() - # Slice the doc-idx - start_index = splits[index] - # Add +1 so we can index into the dataset to get the upper bound. - end_index = splits[index + 1] + 1 - # New doc_idx view. - indexed_dataset.set_document_indices(doc_idx_ptr[start_index:end_index]) - - dataset = build_dataset( - name, data_prefix, - train_valid_test_num_samples[index], max_seq_length, - seed, binary_head, max_seq_length_dec, - dataset_type, indexed_dataset) - - # Set the original pointer so dataset remains the main dataset. - indexed_dataset.set_document_indices(doc_idx_ptr) - # Checks. 
- assert indexed_dataset.document_indices[0] == 0 - assert indexed_dataset.document_indices.shape[0] == \ - (total_num_of_documents + 1) - return dataset - - train_dataset = build_split_dataset(0, 'train') - valid_dataset = build_split_dataset(1, 'valid') - test_dataset = build_split_dataset(2, 'test') - - return (train_dataset, valid_dataset, test_dataset) - - -def build_dataset(name, data_prefix, max_num_samples, - max_seq_length, seed, binary_head, - max_seq_length_dec, dataset_type='standard_bert', - indexed_dataset=None): - - from megatron.data.bert_dataset import BertDataset - from megatron.data.ict_dataset import ICTDataset - from megatron.data.t5_dataset import T5Dataset - from megatron.data.multimodal_dataset import MultiModalDataset - - if dataset_type not in DSET_TYPES: - raise ValueError("Invalid dataset_type: ", dataset_type) - - if indexed_dataset is None: - indexed_dataset = get_indexed_dataset_(data_prefix, - dataset_type) - - kwargs = dict( - name=name, - data_prefix=data_prefix, - num_epochs=None, - max_num_samples=max_num_samples, - max_seq_length=max_seq_length, - seed=seed, - ) - - if dataset_type == DSET_TYPE_ICT: - args = get_args() - - title_dataset = get_indexed_dataset_( - args.titles_data_path, - dataset_type) - - dataset = ICTDataset( - block_dataset=indexed_dataset, - title_dataset=title_dataset, - query_in_block_prob=args.query_in_block_prob, - use_one_sent_docs=args.use_one_sent_docs, - binary_head=binary_head, - **kwargs - ) - elif dataset_type == DSET_TYPE_T5: - args = get_args() - dataset = T5Dataset( - indexed_dataset=indexed_dataset, - masked_lm_prob=args.mask_prob, - max_seq_length_dec=max_seq_length_dec, - short_seq_prob=args.short_seq_prob, - **kwargs - ) - elif dataset_type == DSET_TYPE_BERT: - args = get_args() - dataset = BertDataset( - indexed_dataset=indexed_dataset, - masked_lm_prob=args.mask_prob, - short_seq_prob=args.short_seq_prob, - binary_head=binary_head, - **kwargs - ) - elif dataset_type == DSET_TYPE_MULTIMODAL: - args = get_args() - dataset = MultiModalDataset( - name=name, - data_prefix=data_prefix, - indexed_dataset=indexed_dataset, - num_samples=max_num_samples, - seq_length=max_seq_length, - seed=seed, - img_h=args.img_h, - img_w=args.img_w, - ) - else: - raise NotImplementedError("Dataset type not fully implemented.") - - return dataset - - -def get_indexed_dataset_(data_prefix, dataset_type): - - print_rank_0(' > building dataset index ...') - - start_time = time.time() - multimodal = dataset_type == DSET_TYPE_MULTIMODAL - indexed_dataset = MMapIndexedDataset(data_prefix, multimodal) - assert indexed_dataset.sequence_lengths.shape[0] == indexed_dataset.document_indices[-1] - print_rank_0(' > finished creating indexed dataset in {:4f} ' - 'seconds'.format(time.time() - start_time)) - - print_rank_0(' > indexed dataset stats:') - print_rank_0(' number of documents: {}'.format( - indexed_dataset.document_indices.shape[0] - 1)) - print_rank_0(' number of sentences: {}'.format( - indexed_dataset.sequence_lengths.shape[0])) - - return indexed_dataset - - -def get_train_valid_test_split_(splits_string, size): - """ Get dataset splits from comma or '/' separated string list.""" - - splits = [] - if splits_string.find(',') != -1: - splits = [float(s) for s in splits_string.split(',')] - elif splits_string.find('/') != -1: - splits = [float(s) for s in splits_string.split('/')] - else: - splits = [float(splits_string)] - while len(splits) < 3: - splits.append(0.) 
- splits = splits[:3] - splits_sum = sum(splits) - assert splits_sum > 0.0 - splits = [split / splits_sum for split in splits] - splits_index = [0] - for index, split in enumerate(splits): - splits_index.append(splits_index[index] + - int(round(split * float(size)))) - diff = splits_index[-1] - size - for index in range(1, len(splits_index)): - splits_index[index] -= diff - assert len(splits_index) == 4 - assert splits_index[-1] == size - return splits_index - -def get_samples_mapping(indexed_dataset, - data_prefix, - num_epochs, - max_num_samples, - max_seq_length, - short_seq_prob, - seed, - name, - binary_head): - """Get a list that maps a sample index to a starting sentence index, end sentence index, and length""" - - if not num_epochs: - if not max_num_samples: - raise ValueError("Need to specify either max_num_samples " - "or num_epochs") - num_epochs = np.iinfo(np.int32).max - 1 - if not max_num_samples: - max_num_samples = np.iinfo(np.int64).max - 1 - - # Filename of the index mapping - indexmap_filename = data_prefix - indexmap_filename += '_{}_indexmap'.format(name) - if num_epochs != (np.iinfo(np.int32).max - 1): - indexmap_filename += '_{}ep'.format(num_epochs) - if max_num_samples != (np.iinfo(np.int64).max - 1): - indexmap_filename += '_{}mns'.format(max_num_samples) - indexmap_filename += '_{}msl'.format(max_seq_length) - indexmap_filename += '_{:0.2f}ssp'.format(short_seq_prob) - indexmap_filename += '_{}s'.format(seed) - indexmap_filename += '.npy' - - # Build the indexed mapping if not exist. - if torch.distributed.get_rank() == 0 and \ - not os.path.isfile(indexmap_filename): - print(' > WARNING: could not find index map file {}, building ' - 'the indices on rank 0 ...'.format(indexmap_filename)) - - # Make sure the types match the helpers input types. - assert indexed_dataset.document_indices.dtype == np.int64 - assert indexed_dataset.sequence_lengths.dtype == np.int32 - - # Build samples mapping - verbose = torch.distributed.get_rank() == 0 - start_time = time.time() - print_rank_0(' > building samples index mapping for {} ...'.format( - name)) - # First compile and then import. - from megatron.core.datasets import helpers - samples_mapping = helpers.build_mapping( - indexed_dataset.document_indices, - indexed_dataset.sequence_lengths, - num_epochs, - max_num_samples, - max_seq_length, - short_seq_prob, - seed, - verbose, - 2 if binary_head else 1) - print_rank_0(' > done building samples index maping') - np.save(indexmap_filename, samples_mapping, allow_pickle=True) - print_rank_0(' > saved the index mapping in {}'.format( - indexmap_filename)) - # Make sure all the ranks have built the mapping - print_rank_0(' > elasped time to build and save samples mapping ' - '(seconds): {:4f}'.format( - time.time() - start_time)) - # This should be a barrier but nccl barrier assumes - # device_index=rank which is not the case for model - # parallel case - counts = torch.cuda.LongTensor([1]) - torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group()) - torch.distributed.all_reduce(counts, group=mpu.get_pipeline_model_parallel_group()) - assert counts[0].item() == ( - torch.distributed.get_world_size() // - torch.distributed.get_world_size(group=mpu.get_tensor_model_parallel_group())) - - # Load indexed dataset. 
- print_rank_0(' > loading indexed mapping from {}'.format( - indexmap_filename)) - start_time = time.time() - samples_mapping = np.load(indexmap_filename, allow_pickle=True, mmap_mode='r') - print_rank_0(' loaded indexed file in {:3.3f} seconds'.format( - time.time() - start_time)) - print_rank_0(' total number of samples: {}'.format( - samples_mapping.shape[0])) - - return samples_mapping diff --git a/megatron/data/ict_dataset.py b/megatron/data/ict_dataset.py deleted file mode 100644 index 6dac35ff9d413898146df5c1cc8553719e142105..0000000000000000000000000000000000000000 --- a/megatron/data/ict_dataset.py +++ /dev/null @@ -1,156 +0,0 @@ -import itertools -import random - -import numpy as np -from torch.utils.data import Dataset - -from megatron import get_tokenizer -from megatron import get_args -from megatron.data.dataset_utils import get_indexed_dataset_ -from megatron.data.realm_dataset_utils import get_block_samples_mapping - -def make_attention_mask(source_block, target_block): - """ - Returns a 2-dimensional (2-D) attention mask - :param source_block: 1-D array - :param target_block: 1-D array - """ - mask = (target_block[None, :] >= 1) * (source_block[:, None] >= 1) - mask = mask.astype(np.int64) - # (source_length, target_length) - return mask - -def get_ict_dataset(use_titles=True, query_in_block_prob=1): - """Get a dataset which uses block samples mappings to get ICT/block indexing data (via get_block()) - rather than for training, since it is only built with a single epoch sample mapping. - """ - args = get_args() - block_dataset = get_indexed_dataset_(args.data_path, 'mmap', True) - titles_dataset = get_indexed_dataset_(args.titles_data_path, 'mmap', True) - - kwargs = dict( - name='full', - block_dataset=block_dataset, - title_dataset=titles_dataset, - data_prefix=args.data_path, - num_epochs=1, - max_num_samples=None, - max_seq_length=args.seq_length, - seed=1, - query_in_block_prob=query_in_block_prob, - use_titles=use_titles, - use_one_sent_docs=args.use_one_sent_docs - ) - dataset = ICTDataset(**kwargs) - return dataset - - -class ICTDataset(Dataset): - """Dataset containing sentences and their blocks for an inverse cloze task.""" - def __init__(self, name, block_dataset, title_dataset, data_prefix, - num_epochs, max_num_samples, max_seq_length, query_in_block_prob, - seed, use_titles=True, use_one_sent_docs=False, binary_head=False): - self.name = name - self.seed = seed - self.max_seq_length = max_seq_length - self.query_in_block_prob = query_in_block_prob - self.block_dataset = block_dataset - self.title_dataset = title_dataset - self.rng = random.Random(self.seed) - self.use_titles = use_titles - self.use_one_sent_docs = use_one_sent_docs - - self.samples_mapping = get_block_samples_mapping( - block_dataset, title_dataset, data_prefix, num_epochs, - max_num_samples, max_seq_length, seed, name, use_one_sent_docs) - self.tokenizer = get_tokenizer() - self.vocab_id_list = list(self.tokenizer.inv_vocab.keys()) - self.vocab_id_to_token_list = self.tokenizer.inv_vocab - self.cls_id = self.tokenizer.cls - self.sep_id = self.tokenizer.sep - self.mask_id = self.tokenizer.mask - self.pad_id = self.tokenizer.pad - - def __len__(self): - return len(self.samples_mapping) - - def __getitem__(self, idx): - """Get an ICT example of a pseudo-query and the block of text from which it was extracted""" - sample_data = self.samples_mapping[idx] - start_idx, end_idx, doc_idx, block_idx = sample_data.as_tuple() - - if self.use_titles: - title = self.title_dataset[int(doc_idx)] - 
title_pad_offset = 3 + len(title) - else: - title = None - title_pad_offset = 2 - block = [self.block_dataset[i] for i in range(start_idx, end_idx)] - assert len(block) > 1 or self.use_one_sent_docs or self.query_in_block_prob == 1 - - # randint() is inclusive for Python rng - rand_sent_idx = self.rng.randint(0, len(block) - 1) - - # keep the query in the context query_in_block_prob fraction of the time. - if self.rng.random() < self.query_in_block_prob: - query = block[rand_sent_idx].copy() - else: - query = block.pop(rand_sent_idx) - - # still need to truncate because blocks are concluded when - # the sentence lengths have exceeded max_seq_length. - query = query[:self.max_seq_length - 2] - block = list(itertools.chain(*block))[:self.max_seq_length - title_pad_offset] - - query_tokens, query_pad_mask = self.concat_and_pad_tokens(query) - context_tokens, context_pad_mask = self.concat_and_pad_tokens(block, title) - - query_mask = make_attention_mask(query_tokens, query_tokens) - context_mask = make_attention_mask(context_tokens, context_tokens) - - block_data = sample_data.as_array() - - sample = { - 'query_tokens': query_tokens, - 'query_mask': query_mask, - 'query_pad_mask': query_pad_mask, - 'context_tokens': context_tokens, - 'context_mask': context_mask, - 'context_pad_mask': context_pad_mask, - 'block_data': block_data, - } - - return sample - - def get_block(self, start_idx, end_idx, doc_idx): - """Get the IDs for an evidence block plus the title of the corresponding document""" - block = [self.block_dataset[i] for i in range(start_idx, end_idx)] - title = self.title_dataset[int(doc_idx)] - - block = list(itertools.chain(*block))[:self.max_seq_length - (3 + len(title))] - block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title) - - return block_tokens, block_pad_mask - - def get_null_block(self): - """Get empty block and title - used in REALM pretraining""" - block, title = [], [] - block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title) - - return block_tokens, block_pad_mask - - def concat_and_pad_tokens(self, tokens, title=None): - """Concat with special tokens and pad sequence to self.max_seq_length""" - tokens = list(tokens) - if title is None: - tokens = [self.cls_id] + tokens + [self.sep_id] - else: - title = list(title) - tokens = [self.cls_id] + title + [self.sep_id] + tokens + [self.sep_id] - assert len(tokens) <= self.max_seq_length - - num_pad = self.max_seq_length - len(tokens) - pad_mask = [1] * len(tokens) + [0] * num_pad - tokens += [self.pad_id] * num_pad - - return np.array(tokens), np.array(pad_mask) diff --git a/megatron/data/image_folder.py b/megatron/data/image_folder.py deleted file mode 100644 index de15b29bf0665562a00bfcab8b106ff2d4ca26f2..0000000000000000000000000000000000000000 --- a/megatron/data/image_folder.py +++ /dev/null @@ -1,302 +0,0 @@ -# BSD 3-Clause License -# -# Copyright (c) Soumith Chintala 2016, -# All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# * Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# * Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. 
-# -# * Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. - -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -# code taken from -# https://github.com/pytorch/vision/blob/main/torchvision/datasets/folder.py -# added support for classes_fraction and data_per_class_fraction - -from torchvision.datasets import VisionDataset -from PIL import Image - -import os -import os.path -from typing import Any, Callable, cast, Dict, List, Optional, Tuple -import numpy as np - -def has_file_allowed_extension(filename: str, extensions: Tuple[str, ...]) -> bool: - """Checks if a file is an allowed extension. - Args: - filename (string): path to a file - extensions (tuple of strings): extensions to consider (lowercase) - Returns: - bool: True if the filename ends with one of given extensions - """ - return filename.lower().endswith(extensions) - - -def is_image_file(filename: str) -> bool: - """Checks if a file is an allowed image extension. - Args: - filename (string): path to a file - Returns: - bool: True if the filename ends with a known image extension - """ - return has_file_allowed_extension(filename, IMG_EXTENSIONS) - - -def make_dataset( - directory: str, - class_to_idx: Dict[str, int], - data_per_class_fraction: float, - extensions: Optional[Tuple[str, ...]] = None, - is_valid_file: Optional[Callable[[str], bool]] = None, -) -> List[Tuple[str, int]]: - """Generates a list of samples of a form (path_to_sample, class). - Args: - directory (str): root dataset directory - class_to_idx (Dict[str, int]): dictionary mapping class name to class index - extensions (optional): A list of allowed extensions. - Either extensions or is_valid_file should be passed. Defaults to None. - is_valid_file (optional): A function that takes path of a file - and checks if the file is a valid file - (used to check of corrupt files) both extensions and - is_valid_file should not be passed. Defaults to None. - Raises: - ValueError: In case ``extensions`` and ``is_valid_file`` are None or both are not None. 
- Returns: - List[Tuple[str, int]]: samples of a form (path_to_sample, class) - """ - instances = [] - directory = os.path.expanduser(directory) - both_none = extensions is None and is_valid_file is None - both_something = extensions is not None and is_valid_file is not None - if both_none or both_something: - raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time") - if extensions is not None: - def is_valid_file(x: str) -> bool: - return has_file_allowed_extension(x, cast(Tuple[str, ...], extensions)) - is_valid_file = cast(Callable[[str], bool], is_valid_file) - for target_class in sorted(class_to_idx.keys()): - class_index = class_to_idx[target_class] - target_dir = os.path.join(directory, target_class) - if not os.path.isdir(target_dir): - continue - local_instances = [] - for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)): - for fname in sorted(fnames): - path = os.path.join(root, fname) - if is_valid_file(path): - item = path, class_index - local_instances.append(item) - - instances.extend(local_instances[0:int(len(local_instances) * data_per_class_fraction)]) - - return instances - - -class DatasetFolder(VisionDataset): - """A generic data loader where the samples are arranged in this way: :: - root/class_x/xxx.ext - root/class_x/xxy.ext - root/class_x/[...]/xxz.ext - root/class_y/123.ext - root/class_y/nsdf3.ext - root/class_y/[...]/asd932_.ext - Args: - root (string): Root directory path. - loader (callable): A function to load a sample given its path. - extensions (tuple[string]): A list of allowed extensions. - both extensions and is_valid_file should not be passed. - transform (callable, optional): A function/transform that takes in - a sample and returns a transformed version. - E.g, ``transforms.RandomCrop`` for images. - target_transform (callable, optional): A function/transform that takes - in the target and transforms it. - is_valid_file (callable, optional): A function that takes path of a file - and check if the file is a valid file (used to check of corrupt files) - both extensions and is_valid_file should not be passed. - Attributes: - classes (list): List of the class names sorted alphabetically. - class_to_idx (dict): Dict with items (class_name, class_index). 
- samples (list): List of (sample path, class_index) tuples - targets (list): The class_index value for each image in the dataset - """ - - def __init__( - self, - root: str, - loader: Callable[[str], Any], - extensions: Optional[Tuple[str, ...]] = None, - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, - classes_fraction=1.0, - data_per_class_fraction=1.0, - is_valid_file: Optional[Callable[[str], bool]] = None, - ) -> None: - super(DatasetFolder, self).__init__(root, transform=transform, - target_transform=target_transform) - self.classes_fraction = classes_fraction - self.data_per_class_fraction = data_per_class_fraction - classes, class_to_idx = self._find_classes(self.root) - samples = self.make_dataset(self.root, - class_to_idx, - self.data_per_class_fraction, - extensions, - is_valid_file) - if len(samples) == 0: - msg = "Found 0 files in subfolders of: {}\n".format(self.root) - if extensions is not None: - msg += "Supported extensions are: {}".format(",".join(extensions)) - raise RuntimeError(msg) - - self.loader = loader - self.extensions = extensions - self.total = len(samples) - self.classes = classes - self.class_to_idx = class_to_idx - self.samples = samples - self.targets = [s[1] for s in samples] - - @staticmethod - def make_dataset( - directory: str, - class_to_idx: Dict[str, int], - data_per_class_fraction: float, - extensions: Optional[Tuple[str, ...]] = None, - is_valid_file: Optional[Callable[[str], bool]] = None, - ) -> List[Tuple[str, int]]: - return make_dataset(directory, - class_to_idx, - data_per_class_fraction, - extensions=extensions, - is_valid_file=is_valid_file) - - def _find_classes(self, dir: str) -> Tuple[List[str], Dict[str, int]]: - """ - Finds the class folders in a dataset. - Args: - dir (string): Root directory path. - Returns: - tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary. - Ensures: - No class is a subdirectory of another. - """ - all_classes = [d.name for d in os.scandir(dir) if d.is_dir()] - classes = all_classes[0:int(len(all_classes) * self.classes_fraction)] - classes.sort() - class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)} - return classes, class_to_idx - - def __getitem__(self, index: int) -> Tuple[Any, Any]: - """ - Args: - index (int): Index - Returns: - tuple: (sample, target) where target is class_index of the target class. 
- """ - curr_index = index - for x in range(self.total): - try: - path, target = self.samples[curr_index] - sample = self.loader(path) - break - except Exception as e: - curr_index = np.random.randint(0, self.total) - - if self.transform is not None: - sample = self.transform(sample) - if self.target_transform is not None: - target = self.target_transform(target) - - return sample, target - - def __len__(self) -> int: - return len(self.samples) - - -IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp') - - -def pil_loader(path: str) -> Image.Image: - # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) - with open(path, 'rb') as f: - img = Image.open(f) - return img.convert('RGB') - - -# TODO: specify the return type -def accimage_loader(path: str) -> Any: - import accimage - try: - return accimage.Image(path) - except IOError: - # Potentially a decoding problem, fall back to PIL.Image - return pil_loader(path) - - -def default_loader(path: str) -> Any: - from torchvision import get_image_backend - if get_image_backend() == 'accimage': - return accimage_loader(path) - else: - return pil_loader(path) - - -class ImageFolder(DatasetFolder): - """A generic data loader where the images are arranged in this way: :: - root/dog/xxx.png - root/dog/xxy.png - root/dog/[...]/xxz.png - root/cat/123.png - root/cat/nsdf3.png - root/cat/[...]/asd932_.png - Args: - root (string): Root directory path. - transform (callable, optional): A function/transform that takes in an PIL image - and returns a transformed version. E.g, ``transforms.RandomCrop`` - target_transform (callable, optional): A function/transform that takes in the - target and transforms it. - loader (callable, optional): A function to load an image given its path. - is_valid_file (callable, optional): A function that takes path of an Image file - and check if the file is a valid file (used to check of corrupt files) - Attributes: - classes (list): List of the class names sorted alphabetically. - class_to_idx (dict): Dict with items (class_name, class_index). - imgs (list): List of (image path, class_index) tuples - """ - - def __init__( - self, - root: str, - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, - classes_fraction=1.0, - data_per_class_fraction=1.0, - loader: Callable[[str], Any] = default_loader, - is_valid_file: Optional[Callable[[str], bool]] = None, - ): - super(ImageFolder, self).__init__(root, loader, IMG_EXTENSIONS if is_valid_file is None else None, - transform=transform, - target_transform=target_transform, - classes_fraction=classes_fraction, - data_per_class_fraction=data_per_class_fraction, - is_valid_file=is_valid_file) - self.imgs = self.samples - diff --git a/megatron/data/multimodal_dataset.py b/megatron/data/multimodal_dataset.py deleted file mode 100644 index 93ea790329c7aa58f2e63745c4aaca464567c99f..0000000000000000000000000000000000000000 --- a/megatron/data/multimodal_dataset.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 
- -from PIL import Image, UnidentifiedImageError -import numpy as np -import io -import torch - -try: - from torchvision.transforms import InterpolationMode - BICUBIC = InterpolationMode.BICUBIC -except ImportError: - BICUBIC = Image.BICUBIC - -from torchvision.transforms import Compose, ToTensor, Normalize, ToPILImage, RandomResizedCrop, Resize - -def _convert_image_to_rgb(image): - return image.convert("RGB") - -def _transform(img_h, img_w): - return Compose([ - ToPILImage(), - RandomResizedCrop((img_h, img_w), scale=(0.5, 1.0), interpolation=BICUBIC), - _convert_image_to_rgb, - ToTensor(), - Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), - ]) - -class MultiModalDataset(torch.utils.data.Dataset): - - def __init__(self, name, data_prefix, indexed_dataset, - num_samples, seq_length, seed, img_h, img_w): - - self.name = name - self.indexed_dataset = indexed_dataset - self.doc_idx = indexed_dataset.get_document_indices() - self.visual_transform = _transform(img_h, img_w) - - def __len__(self): - return self.indexed_dataset.sequence_lengths.shape[0] - - def __getitem__(self, idx): - text_sample, mode = self.indexed_dataset.get(self.doc_idx[idx]) - assert mode == 0 - img_sample, mode = self.indexed_dataset.get(self.doc_idx[idx]+1) - assert mode == 1 - img_pad = img_sample[0].item() - xs = img_sample[1:].tobytes(order='C') - xs = xs[:len(xs)-img_pad] - - img_sample = np.array(Image.open(io.BytesIO(xs))) - img_sample = self.visual_transform(img_sample).reshape(-1) - - return {'text': np.array(text_sample, dtype=np.int64), - 'img': np.array(img_sample, dtype=np.float32)} diff --git a/megatron/data/orqa_wiki_dataset.py b/megatron/data/orqa_wiki_dataset.py deleted file mode 100644 index 4019cd764c204b34de0df28bfcc7969c3c19d937..0000000000000000000000000000000000000000 --- a/megatron/data/orqa_wiki_dataset.py +++ /dev/null @@ -1,193 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. - -"""Wikipedia dataset from DPR code for ORQA.""" - -from abc import ABC -import csv -import numpy as np -import random -import torch -from torch.utils.data import Dataset - -from megatron import print_rank_0, get_args, get_tokenizer -from megatron.core import tensor_parallel -from megatron.data.biencoder_dataset_utils import make_attention_mask - -def get_open_retrieval_wiki_dataset(): - args = get_args() - tokenizer = get_tokenizer() - - dataset = OpenRetrievalEvidenceDataset('2018 Wikipedia from DPR codebase', - 'evidence', - args.evidence_data_path, - tokenizer, - args.retriever_seq_length) - return dataset - - -def get_open_retrieval_batch(data_iterator): - # Items and their type. - keys = ['row_id', 'context', 'context_mask', 'context_types', - 'context_pad_mask'] - datatype = torch.int64 - - # Broadcast data. - data = None if data_iterator is None else next(data_iterator) - data_b = tensor_parallel.broadcast_data(keys, data, datatype) - - # Unpack. 
- row_id = data_b['row_id'].long() - context = data_b['context'].long() - - # TODO: make the context mask a binary one - context_mask = (data_b['context_mask'] < 0.5) - - context_types = data_b['context_types'].long() - context_pad_mask = data_b['context_pad_mask'].long() - - return row_id, context, context_mask, context_types, context_pad_mask - - -def build_tokens_types_paddings_from_text(row, tokenizer, max_seq_length): - """Build token types and paddings, trim if needed, and pad if needed.""" - - title_ids = tokenizer.tokenize(row['title']) - context_ids = tokenizer.tokenize(row['text']) - - # Appending the title of the context at front - extended_context_ids = title_ids + [tokenizer.sep_id] + context_ids - - context_ids, context_types, context_pad_mask = \ - build_tokens_types_paddings_from_ids(extended_context_ids, - max_seq_length, tokenizer.cls, tokenizer.sep, tokenizer.pad) - - return context_ids, context_types, context_pad_mask - - -# noinspection DuplicatedCode -def build_tokens_types_paddings_from_ids(text_ids, max_seq_length, - cls_id, sep_id, pad_id): - """Build token types and paddings, trim if needed, and pad if needed.""" - enc_ids = [] - tokentypes_enc = [] - - # [CLS]. - enc_ids.append(cls_id) - tokentypes_enc.append(0) - - # A. - len_src = len(text_ids) - enc_ids.extend(text_ids) - tokentypes_enc.extend([0] * len_src) - - # Cap the size. - if len(enc_ids) > max_seq_length - 1: - enc_ids = enc_ids[0: max_seq_length - 1] - tokentypes_enc = tokentypes_enc[0: max_seq_length - 1] - - # [SEP]. - enc_ids.append(sep_id) - tokentypes_enc.append(0) - - num_tokens_enc = len(enc_ids) - # Padding. - padding_length = max_seq_length - len(enc_ids) - if padding_length > 0: - enc_ids.extend([pad_id] * padding_length) - tokentypes_enc.extend([pad_id] * padding_length) - - pad_mask = ([1] * num_tokens_enc) + ([0] * padding_length) - pad_mask = np.array(pad_mask, dtype=np.int64) - - return enc_ids, tokentypes_enc, pad_mask - - -def build_sample(row_id, context_ids, context_types, context_pad_mask): - """Convert to numpy and return a sample consumed by the batch producer.""" - - context_ids = np.array(context_ids, dtype=np.int64) - context_types = np.array(context_types, dtype=np.int64) - context_mask = make_attention_mask(context_ids, context_ids) - - sample = ({ - 'row_id': row_id, - 'context': context_ids, - 'context_mask': context_mask, - 'context_types': context_types, - 'context_pad_mask': context_pad_mask - }) - return sample - - -class OpenRetrievalEvidenceDataset(ABC, Dataset): - """Open Retrieval Evidence dataset class.""" - - def __init__(self, task_name, dataset_name, datapath, tokenizer, - max_seq_length): - # Store inputs. - self.task_name = task_name - self.dataset_name = dataset_name - self.tokenizer = tokenizer - self.max_seq_length = max_seq_length - print_rank_0(' > building {} dataset for {}:'.format(self.task_name, - self.dataset_name)) - # Process the files. 
- print_rank_0(datapath) - self.samples, self.id2text = self.process_samples_from_single_path( - datapath) - - args = get_args() - if args.sample_rate < 1: # subsample - k = int(len(self.samples) * args.sample_rate) - self.samples = random.sample(self.samples, k) - - print_rank_0(' >> total number of samples: {}'.format( - len(self.samples))) - - def __len__(self): - return len(self.samples) - - def __getitem__(self, idx): - row = self.samples[idx] - - context_ids, context_types, context_pad_mask = \ - build_tokens_types_paddings_from_text(row, self.tokenizer, - self.max_seq_length) - - sample = build_sample(row['doc_id'], - context_ids, - context_types, - context_pad_mask) - return sample - - @staticmethod - def process_samples_from_single_path(filename): - print_rank_0(' > Processing {} ...'.format(filename)) - total = 0 - - rows = [] - id2text = {} - - with open(filename) as tsvfile: - reader = csv.reader(tsvfile, delimiter='\t') - next(reader, None) # skip the headers - for row in reader: - # file format: doc_id, doc_text, title - doc_id = int(row[0]) - text = row[1] - title = row[2] - - rows.append({'doc_id': doc_id, - 'text': text, - 'title': title}) - - assert doc_id not in id2text - id2text[doc_id] = (text, title) - - total += 1 - if total % 100000 == 0: - print_rank_0(' > processed {} rows so far ...'.format( - total)) - - print_rank_0(' >> processed {} samples.'.format(len(rows))) - return rows, id2text diff --git a/megatron/data/realm_dataset_utils.py b/megatron/data/realm_dataset_utils.py deleted file mode 100644 index 3c8672bb583aa0c839fc055405ae3761dda48911..0000000000000000000000000000000000000000 --- a/megatron/data/realm_dataset_utils.py +++ /dev/null @@ -1,199 +0,0 @@ -import os -import time - -import numpy as np -import torch - -from megatron import print_rank_0 -from megatron.core import mpu, tensor_parallel -from megatron.data.dataset_utils import create_masked_lm_predictions, pad_and_convert_to_numpy -from megatron import get_args, get_tokenizer, print_rank_0 - - -def get_one_epoch_dataloader(dataset, micro_batch_size=None): - """Specifically one epoch to be used in an indexing job.""" - args = get_args() - - world_size = mpu.get_data_parallel_world_size() - rank = mpu.get_data_parallel_rank() - if micro_batch_size is None: - micro_batch_size = args.micro_batch_size - global_batch_size = micro_batch_size * world_size - num_workers = args.num_workers - - sampler = torch.utils.data.SequentialSampler(dataset) - # importantly, drop_last must be False to get all the data. - assert False, 'DistributedBatchSampler deprecated, change the implementation' - from megatron.data.samplers import DistributedBatchSampler - batch_sampler = DistributedBatchSampler(sampler, - batch_size=global_batch_size, - drop_last=False, - rank=rank, - world_size=world_size) - - return torch.utils.data.DataLoader(dataset, - batch_sampler=batch_sampler, - num_workers=num_workers, - pin_memory=True) - - -def get_ict_batch(data_iterator): - # Items and their type. - keys = ['query_tokens', 'query_pad_mask', - 'block_tokens', 'block_pad_mask', 'block_data'] - datatype = torch.int64 - - # Broadcast data. - if data_iterator is None: - data = None - else: - data = next(data_iterator) - data_b = tensor_parallel.broadcast_data(keys, data, datatype) - - # Unpack. 
- query_tokens = data_b['query_tokens'].long() - query_pad_mask = data_b['query_pad_mask'].long() - block_tokens = data_b['block_tokens'].long() - block_pad_mask = data_b['block_pad_mask'].long() - block_indices = data_b['block_data'].long() - - return query_tokens, query_pad_mask,\ - block_tokens, block_pad_mask, block_indices - - -def join_str_list(str_list): - """Join a list of strings, handling spaces appropriately""" - result = "" - for s in str_list: - if s.startswith("##"): - result += s[2:] - else: - result += " " + s - return result - - -class BlockSampleData(object): - """A struct for fully describing a fixed-size block of data as used in REALM - - :param start_idx: for first sentence of the block - :param end_idx: for last sentence of the block (may be partially truncated in sample construction) - :param doc_idx: the index of the document from which the block comes in the original indexed dataset - :param block_idx: a unique integer identifier given to every block. - """ - def __init__(self, start_idx, end_idx, doc_idx, block_idx): - self.start_idx = start_idx - self.end_idx = end_idx - self.doc_idx = doc_idx - self.block_idx = block_idx - - def as_array(self): - return np.array([self.start_idx, self.end_idx, self.doc_idx, self.block_idx]).astype(np.int64) - - def as_tuple(self): - return self.start_idx, self.end_idx, self.doc_idx, self.block_idx - - -class BlockSamplesMapping(object): - def __init__(self, mapping_array): - # make sure that the array is compatible with BlockSampleData - assert mapping_array.shape[1] == 4 - self.mapping_array = mapping_array - - def __len__(self): - return self.mapping_array.shape[0] - - def __getitem__(self, idx): - """Get the data associated with an indexed sample.""" - sample_data = BlockSampleData(*self.mapping_array[idx]) - return sample_data - - -def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epochs, - max_num_samples, max_seq_length, seed, name, use_one_sent_docs=False): - """Get samples mapping for a dataset over fixed size blocks. This function also requires - a dataset of the titles for the source documents since their lengths must be taken into account. - - :return: samples_mapping (BlockSamplesMapping) - """ - - if not num_epochs: - if not max_num_samples: - raise ValueError("Need to specify either max_num_samples " - "or num_epochs") - num_epochs = np.iinfo(np.int32).max - 1 - if not max_num_samples: - max_num_samples = np.iinfo(np.int64).max - 1 - - # Filename of the index mapping - indexmap_filename = data_prefix - indexmap_filename += '_{}_indexmap'.format(name) - if num_epochs != (np.iinfo(np.int32).max - 1): - indexmap_filename += '_{}ep'.format(num_epochs) - if max_num_samples != (np.iinfo(np.int64).max - 1): - indexmap_filename += '_{}mns'.format(max_num_samples) - indexmap_filename += '_{}msl'.format(max_seq_length) - indexmap_filename += '_{}s'.format(seed) - if use_one_sent_docs: - indexmap_filename += '_1sentok' - indexmap_filename += '.npy' - - # Build the indexed mapping if not exist. - if mpu.get_data_parallel_rank() == 0 and \ - not os.path.isfile(indexmap_filename): - print(' > WARNING: could not find index map file {}, building ' - 'the indices on rank 0 ...'.format(indexmap_filename)) - - # Make sure the types match the helpers input types. 
- assert block_dataset.document_indices.dtype == np.int64 - assert block_dataset.sequence_lengths.dtype == np.int32 - - # Build samples mapping - verbose = torch.distributed.get_rank() == 0 - start_time = time.time() - print_rank_0(' > building samples index mapping for {} ...'.format( - name)) - - from megatron.core.datasets import helpers - mapping_array = helpers.build_blocks_mapping( - block_dataset.document_indices, - block_dataset.sequence_lengths, - title_dataset.sequence_lengths, - num_epochs, - max_num_samples, - max_seq_length - 3, # account for added tokens - seed, - verbose, - use_one_sent_docs) - - - print_rank_0(' > done building samples index mapping') - np.save(indexmap_filename, mapping_array, allow_pickle=True) - print_rank_0(' > saved the index mapping in {}'.format( - indexmap_filename)) - # Make sure all the ranks have built the mapping - print_rank_0(' > elapsed time to build and save samples mapping ' - '(seconds): {:4f}'.format( - time.time() - start_time)) - - # This should be a barrier but nccl barrier assumes - # device_index=rank which is not the case for model - # parallel case - counts = torch.cuda.LongTensor([1]) - torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group()) - assert counts[0].item() == torch.distributed.get_world_size( - group=mpu.get_data_parallel_group()) - - # Load indexed dataset. - print_rank_0(' > loading indexed mapping from {}'.format( - indexmap_filename)) - start_time = time.time() - - mapping_array = np.load(indexmap_filename, allow_pickle=True, mmap_mode='r') - samples_mapping = BlockSamplesMapping(mapping_array) - - print_rank_0(' loaded indexed file in {:3.3f} seconds'.format( - time.time() - start_time)) - print_rank_0(' total number of samples: {}'.format( - mapping_array.shape[0])) - - return samples_mapping diff --git a/megatron/data/realm_index.py b/megatron/data/realm_index.py deleted file mode 100644 index 1fa4a309edcd5a761b3a87b973b799c30ac73458..0000000000000000000000000000000000000000 --- a/megatron/data/realm_index.py +++ /dev/null @@ -1,224 +0,0 @@ -import itertools -import os -import pickle -import shutil - -import numpy as np -import torch - -from megatron import get_args -from megatron.core import mpu - - -def detach(tensor): - return tensor.detach().cpu().numpy() - - -class OpenRetreivalDataStore(object): - """ - Serializable data structure for holding data for blocks -- - embeddings and necessary metadata for Retriever - """ - def __init__(self, embedding_path=None, load_from_path=True, rank=None): - self.embed_data = dict() - if embedding_path is None: - args = get_args() - embedding_path = args.embedding_path - rank = args.rank - self.embedding_path = embedding_path - self.rank = rank - - if load_from_path: - self.load_from_file() - - block_data_name = os.path.splitext(self.embedding_path)[0] - self.temp_dir_name = block_data_name + '_tmp' - - def state(self): - return { - 'embed_data': self.embed_data, - } - - def clear(self): - """ - Clear the embedding data structures to save memory. - The metadata ends up getting used, and is also much smaller in - dimensionality so it isn't really worth clearing. 
- """ - self.embed_data = dict() - - def load_from_file(self): - """Populate members from instance saved to file""" - - if not mpu.model_parallel_is_initialized() or mpu.get_data_parallel_rank() == 0: - print("\n> Unpickling BlockData", flush=True) - state_dict = pickle.load(open(self.embedding_path, 'rb')) - if not mpu.model_parallel_is_initialized() or mpu.get_data_parallel_rank() == 0: - print(">> Finished unpickling BlockData\n", flush=True) - - self.embed_data = state_dict['embed_data'] - - def add_block_data(self, row_id, block_embeds, allow_overwrite=False): - """ - Add data for set of blocks - :param row_id: 1D array of unique int ids for the blocks - :param block_embeds: 2D array of embeddings of the blocks - In the case of retriever this will be [start_idx, end_idx, doc_idx] - """ - for idx, embed in zip(row_id, block_embeds): - if not allow_overwrite and idx in self.embed_data: - raise ValueError("Unexpectedly tried to overwrite block data") - - self.embed_data[idx] = np.float16(embed) - - def save_shard(self): - """ - Save the block data that was created this in this process - """ - if not os.path.isdir(self.temp_dir_name): - os.makedirs(self.temp_dir_name, exist_ok=True) - - # save the data for each shard - with open('{}/{}.pkl'.format(self.temp_dir_name, self.rank), 'wb') \ - as writer: - pickle.dump(self.state(), writer) - - def merge_shards_and_save(self): - #Combine all the shards made using save_shard - shard_names = os.listdir(self.temp_dir_name) - seen_own_shard = False - - for fname in os.listdir(self.temp_dir_name): - shard_rank = int(os.path.splitext(fname)[0]) - if shard_rank == self.rank: - seen_own_shard = True - continue - - with open('{}/{}'.format(self.temp_dir_name, fname), 'rb') as f: - data = pickle.load(f) - old_size = len(self.embed_data) - shard_size = len(data['embed_data']) - - # add the shard's data and check to make sure there - # is no overlap - self.embed_data.update(data['embed_data']) - assert len(self.embed_data) == old_size + shard_size - - assert seen_own_shard - - # save the consolidated shards and remove temporary directory - with open(self.embedding_path, 'wb') as final_file: - pickle.dump(self.state(), final_file) - shutil.rmtree(self.temp_dir_name, ignore_errors=True) - - print("Finished merging {} shards for a total of {} embeds".format( - len(shard_names), len(self.embed_data)), flush=True) - - -class FaissMIPSIndex(object): - """ - Wrapper object for a BlockData which similarity search via FAISS under the hood - """ - def __init__(self, embed_size, embed_data=None, use_gpu=False): - self.embed_size = embed_size - self.embed_data = embed_data - self.use_gpu = use_gpu - - self.mips_index = None - self._set_mips_index() - - def _set_mips_index(self): - """ - Create a Faiss Flat index with inner product as the metric - to search against - """ - try: - import faiss - except ImportError: - raise Exception("Error: Please install faiss to use FaissMIPSIndex") - - if not mpu.model_parallel_is_initialized() or mpu.get_data_parallel_rank() == 0: - print("\n> Building index", flush=True) - - cpu_index = faiss.IndexFlatIP(self.embed_size) - - if self.use_gpu: - # create resources and config for GpuIndex - config = faiss.GpuMultipleClonerOptions() - config.shard = True - config.useFloat16 = True - gpu_index = faiss.index_cpu_to_all_gpus(cpu_index, co=config) - self.mips_index = faiss.IndexIDMap(gpu_index) - if not mpu.model_parallel_is_initialized() or mpu.get_data_parallel_rank() == 0: - print(">> Initialized index on GPU", flush=True) - else: - # 
CPU index supports IDs so wrap with IDMap - self.mips_index = faiss.IndexIDMap(cpu_index) - if not mpu.model_parallel_is_initialized() or mpu.get_data_parallel_rank() == 0: - print(">> Initialized index on CPU", flush=True) - - # if we were constructed with a BlockData, then automatically load it - # when the FAISS structure is built - if self.embed_data is not None: - self.add_embed_data(self.embed_data) - - def reset_index(self): - """Delete existing index and create a new""" - del self.mips_index - - # reset the block data so that _set_block_index will reload it as well - if self.embed_data is not None: - embed_data_path = self.embed_data.embedding_path - del self.embed_data - self.embed_data = OpenRetreivalDataStore(embed_data_path) - - self._set_mips_index() - - def update_index(self): - """Delete existing index and create a new""" - del self.mips_index - - # reset the block data so that _set_mips_index will reload it as well - if self.embed_data is not None: - self.embed_data.load_from_file() - self._set_mips_index() - - def add_embed_data(self, all_embed_data): - """Add the embedding of each block to the underlying FAISS index""" - - # this assumes the embed_data is a dict : {int: np.array} - block_indices, block_embeds = zip(*all_embed_data.embed_data.items()) - - # the embeddings have to be entered in as float32 even though the math - # internally is done with float16. - embeds_arr = np.float32(np.array(block_embeds)) - indices_arr = np.array(block_indices) - - # we no longer need the embedding data since it's in the index now - all_embed_data.clear() - - self.mips_index.add_with_ids(embeds_arr, indices_arr) - - if not mpu.model_parallel_is_initialized() or mpu.get_data_parallel_rank() == 0: - print(">>> Finished adding block data to index", flush=True) - - def search_mips_index(self, query_embeds, top_k, reconstruct=True): - """ - Get the top-k blocks by the index distance metric. - - :param reconstruct: if True: return a [num_queries x k x embed_dim] - array of blocks - if False: return [num_queries x k] array of - distances, and another for indices - """ - query_embeds = np.float32(detach(query_embeds)) - - if reconstruct: - # get the vectors themselves - top_k_block_embeds = self.mips_index.search_and_reconstruct(\ - query_embeds, top_k) - return top_k_block_embeds - else: - # get distances and indices of closest vectors - distances, block_indices = self.mips_index.search(query_embeds, top_k) - return distances, block_indices diff --git a/megatron/data/t5_dataset.py b/megatron/data/t5_dataset.py deleted file mode 100644 index 075b089f8e39ba9d236f5b39d62861dc2f17608d..0000000000000000000000000000000000000000 --- a/megatron/data/t5_dataset.py +++ /dev/null @@ -1,258 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. - -"""T5 Style dataset.""" - -import collections - -import numpy as np -import torch - -from megatron import get_tokenizer -from megatron.data.dataset_utils import ( - create_masked_lm_predictions, - get_samples_mapping -) - -class T5Dataset(torch.utils.data.Dataset): - - def __init__(self, name, indexed_dataset, data_prefix, - num_epochs, max_num_samples, masked_lm_prob, - max_seq_length, max_seq_length_dec, - short_seq_prob, seed): - - # Params to store. - self.name = name - self.desc = name - self.seed = seed - self.masked_lm_prob = masked_lm_prob - self.max_seq_length = max_seq_length - self.max_seq_length_dec = max_seq_length_dec - - # Dataset. - self.indexed_dataset = indexed_dataset - - # Build the samples mapping. 
- self.samples_mapping = get_samples_mapping(self.indexed_dataset, - data_prefix, - num_epochs, - max_num_samples, - self.max_seq_length - 2, # account for added tokens - short_seq_prob, - self.seed, - self.name, - False) - - # Vocab stuff. - tokenizer = get_tokenizer() - self.vocab_id_list = list(tokenizer.inv_vocab.keys()) - self.vocab_id_to_token_dict = tokenizer.inv_vocab - self.cls_id = tokenizer.cls - self.sep_id = tokenizer.sep - self.mask_id = tokenizer.mask - self.pad_id = tokenizer.pad - self.bos_id = tokenizer.bos_token_id - self.eos_id = tokenizer.eos_token_id - self.sentinel_tokens = tokenizer.additional_special_tokens_ids - assert len(self.sentinel_tokens) > 0, "Provide the argument --vocab-extra-ids 100 to the script" - - def __len__(self): - return self.samples_mapping.shape[0] - - def __getitem__(self, idx): - - start_index, end_index, seq_length = self.samples_mapping[idx] - sample = [] - for index in range(start_index, end_index): - sample.append(self.indexed_dataset[index]) - # Note that this rng state should be numpy and not python since - # python randint is inclusive whereas the numpy one is exclusive. - np_rng = np.random.RandomState(seed=(self.seed + idx)) - return build_training_sample(sample, seq_length, - self.max_seq_length, # needed for padding - self.max_seq_length_dec, - self.vocab_id_list, - self.vocab_id_to_token_dict, - self.cls_id, self.sep_id, - self.mask_id, self.pad_id, - self.masked_lm_prob, np_rng, - self.bos_id, self.eos_id, - self.sentinel_tokens) - - -def build_training_sample(sample, target_seq_length, - max_seq_length, max_seq_length_dec, - vocab_id_list, vocab_id_to_token_dict, - cls_id, sep_id, mask_id, pad_id, - masked_lm_prob, np_rng, bos_id=None, - eos_id=None, sentinel_tokens=None): - """Build training sample. - - Arguments: - sample: A list of sentences in which each sentence is a list token ids. - target_seq_length: Desired sequence length. - max_seq_length: Maximum length of the sequence. All values are padded to - this length. - vocab_id_list: List of vocabulary ids. Used to pick a random id. - vocab_id_to_token_dict: A dictionary from vocab ids to text tokens. - cls_id: Start of example id. - sep_id: Separator id. - mask_id: Mask token id. - pad_id: Padding token id. - masked_lm_prob: Probability to mask tokens. - np_rng: Random number genenrator. Note that this rng state should be - numpy and not python since python randint is inclusive for - the opper bound whereas the numpy one is exclusive. - bos_id: start of decoder example id - eos_id: end of generation id - sentinel_tokens: unique value to be substituted for every replaced span - """ - - assert target_seq_length <= max_seq_length - - # flatten sentences into one list - tokens = [token for sentence in sample for token in sentence] - - # Truncate to `target_sequence_length`. - max_num_tokens = target_seq_length - truncated = len(tokens) > max_num_tokens - tokens = tokens[:max_num_tokens] - - # Masking. - max_predictions_per_seq = masked_lm_prob * max_num_tokens - (tokens, masked_positions, masked_labels, _, masked_spans) = create_masked_lm_predictions( - tokens, vocab_id_list, vocab_id_to_token_dict, masked_lm_prob, - cls_id, sep_id, mask_id, max_predictions_per_seq, np_rng, - max_ngrams=10, geometric_dist=True, masking_style="t5") - - # Padding. 
- tokens_enc, tokens_dec_in, labels, enc_mask, \ - dec_mask, enc_dec_mask, loss_mask \ - = pad_and_convert_to_numpy(tokens, masked_positions, - masked_labels, pad_id, max_seq_length, - max_seq_length_dec, masked_spans, - bos_id, eos_id, sentinel_tokens) - - train_sample = { - 'text_enc': tokens_enc, - 'text_dec': tokens_dec_in, - 'labels': labels, - 'loss_mask': loss_mask, - 'truncated': int(truncated), - 'enc_mask': enc_mask, - 'dec_mask': dec_mask, - 'enc_dec_mask': enc_dec_mask, - } - return train_sample - - -def pad_and_convert_to_numpy(tokens, masked_positions, - masked_labels, pad_id, - max_seq_length, max_seq_length_dec, - masked_spans=None, bos_id=None, - eos_id=None, sentinel_tokens=None): - """Pad sequences and convert them to numpy.""" - - sentinel_tokens = collections.deque(sentinel_tokens) - t5_input = [] - (t5_decoder_in, t5_decoder_out) = ([bos_id], []) - (start_index, end_index) = (0, None) - for span in masked_spans: - flag = sentinel_tokens.popleft() - - # Append the same tokens in decoder input and output - t5_decoder_in.append(flag) - t5_decoder_in.extend(span.label) - t5_decoder_out.append(flag) - t5_decoder_out.extend(span.label) - - end_index = span.index[0] - t5_input.extend(tokens[start_index: end_index]) - t5_input.append(flag) - - # the next start index is the token after the last span token - start_index = span.index[-1] + 1 - - # Add token to the t5_decoder_out - t5_decoder_out.append(eos_id) - - # Add the remaining tokens to the t5 input - t5_input.extend(tokens[start_index:]) - - # assert (len(t5_input) - len(masked_spans)) + \ - # (len(t5_decoder_in) - (len(masked_spans) + 1)) == len(tokens) - - # Some checks. - - # Encoder-side padding mask. - num_tokens = len(t5_input) - padding_length = max_seq_length - num_tokens - assert padding_length >= 0 - assert len(masked_positions) == len(masked_labels) - - # Tokens.. - filler = [pad_id] * padding_length - tokens_enc = np.array(t5_input + filler, dtype=np.int64) - - # Decoder-side padding mask. - num_tokens_dec = len(t5_decoder_in) - padding_length_dec = max_seq_length_dec - num_tokens_dec - assert padding_length_dec >= 0 - filler_dec = [pad_id] * padding_length_dec - tokens_dec_in = np.array(t5_decoder_in + filler_dec, dtype=np.int64) - - # Create attention masks - enc_mask = make_attention_mask(tokens_enc, tokens_enc) - enc_dec_mask = make_attention_mask(tokens_dec_in, tokens_enc) - dec_mask = make_attention_mask(tokens_dec_in, tokens_dec_in) - dec_mask = dec_mask * make_history_mask(tokens_dec_in) - - # Labels mask. 
- labels = t5_decoder_out + ([-1] * padding_length_dec) - labels = np.array(labels, dtype=np.int64) - - # Loss mask - loss_mask = ([1] * num_tokens_dec) + ([0] * padding_length_dec) - loss_mask = np.array(loss_mask, dtype=np.int64) - - return tokens_enc, tokens_dec_in, labels, enc_mask, \ - dec_mask, enc_dec_mask, loss_mask - - -def make_attention_mask(source_block, target_block): - """ - Returns a 2-dimensional (2-D) attention mask - :param source_block: 1-D array - :param target_block: 1-D array - """ - mask = (target_block[None, :] >= 1) * (source_block[:, None] >= 1) - mask = mask.astype(np.int64) - # (source_length, target_length) - return mask - - -def make_attention_mask_3d(source_block, target_block): - """ - Returns a 3-dimensional (3-D) attention mask - :param source_block: 1-D array - :param target_block: 1-D array - """ - mask = (target_block[:, None, :] >= 1) * (source_block[:, :, None] >= 1) - # (batch, source_length, target_length) - # mask = mask.astype(np.int64) - return mask - - -def make_history_mask(block): - length = block.shape[0] - arange = np.arange(length) - history_mask = (arange[None, ] <= arange[:, None]) - history_mask = history_mask.astype(np.int64) - return history_mask - - -def make_history_mask_3d(block): - batch, length = block.shape - arange = torch.arange(length, device=block.device) - history_mask = (arange[None, ] <= arange[:, None])[None, ] - history_mask = history_mask.expand(batch, length, length) - return history_mask diff --git a/megatron/data/vit_dataset.py b/megatron/data/vit_dataset.py deleted file mode 100644 index 82391e9157e1a989dd2e13ea69d4b146284b6f4b..0000000000000000000000000000000000000000 --- a/megatron/data/vit_dataset.py +++ /dev/null @@ -1,249 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. -import os -import random -import numpy as np -import torch -import torchvision.transforms as T -from torchvision import datasets -from megatron import get_args -from megatron.data.image_folder import ImageFolder -from megatron.data.autoaugment import ImageNetPolicy -from megatron.data.data_samplers import RandomSeedDataset -from PIL import Image, ImageFilter, ImageOps - - -class GaussianBlur(object): - """ - Apply Gaussian Blur to the PIL image. - """ - def __init__(self, p=0.5, radius_min=0.1, radius_max=2.): - self.prob = p - self.radius_min = radius_min - self.radius_max = radius_max - - def __call__(self, img): - do_it = random.random() <= self.prob - if not do_it: - return img - - return img.filter( - ImageFilter.GaussianBlur( - radius=random.uniform(self.radius_min, self.radius_max) - ) - ) - - -class Solarization(object): - """ - Apply Solarization to the PIL image. 
- """ - def __init__(self, p): - self.p = p - - def __call__(self, img): - if random.random() < self.p: - return ImageOps.solarize(img) - else: - return img - - -class ClassificationTransform(): - def __init__(self, image_size, train=True): - args = get_args() - assert args.fp16 or args.bf16 - self.data_type = torch.half if args.fp16 else torch.bfloat16 - if train: - self.transform = T.Compose([ - T.RandomResizedCrop(image_size), - T.RandomHorizontalFlip(), - T.ColorJitter(0.4, 0.4, 0.4, 0.1), - ImageNetPolicy(), - T.ToTensor(), - T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), - T.ConvertImageDtype(self.data_type) - ]) - else: - self.transform = T.Compose([ - T.Resize(image_size), - T.CenterCrop(image_size), - T.ToTensor(), - T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), - T.ConvertImageDtype(self.data_type) - ]) - - def __call__(self, input): - output = self.transform(input) - return output - - -class InpaintingTransform(): - def __init__(self, image_size, train=True): - - args = get_args() - self.mask_factor = args.mask_factor - self.mask_type = args.mask_type - self.image_size = image_size - self.patch_size = args.patch_dim - self.mask_size = int(self.mask_factor*(image_size[0]/self.patch_size)*(image_size[1]/self.patch_size)) - self.train = train - assert args.fp16 or args.bf16 - self.data_type = torch.half if args.fp16 else torch.bfloat16 - - if self.train: - self.transform = T.Compose([ - T.RandomResizedCrop(self.image_size), - T.RandomHorizontalFlip(), - T.ColorJitter(0.4, 0.4, 0.4, 0.1), - ImageNetPolicy(), - T.ToTensor(), - T.ConvertImageDtype(self.data_type) - ]) - else: - self.transform = T.Compose([ - T.Resize(self.image_size, interpolation=2), - T.CenterCrop(self.image_size), - T.ToTensor(), - T.ConvertImageDtype(self.data_type) - ]) - - def gen_mask(self, image_size, mask_size, mask_type, patch_size): - # output: mask as a list with indices for missing patches - action_list = [[0, 1], [0, -1], [1, 0], [-1, 0]] - assert image_size[0] == image_size[1] - img_size_patch = image_size[0] // patch_size - - # drop masked patches - mask = torch.zeros((image_size[0], image_size[1]), dtype=torch.float) - - if mask_type == 'random': - x = torch.randint(0, img_size_patch, ()) - y = torch.randint(0, img_size_patch, ()) - for i in range(mask_size): - r = torch.randint(0, len(action_list), ()) - x = torch.clamp(x + action_list[r][0], min=0, max=img_size_patch - 1) - y = torch.clamp(y + action_list[r][1], min=0, max=img_size_patch - 1) - x_offset = x * patch_size - y_offset = y * patch_size - mask[x_offset:x_offset+patch_size, y_offset:y_offset+patch_size] = 1 - else: - assert mask_type == 'row' - count = 0 - for x in reversed(range(img_size_patch)): - for y in reversed(range(img_size_patch)): - if (count < mask_size): - count += 1 - x_offset = x * patch_size - y_offset = y * patch_size - mask[x_offset:x_offset+patch_size, y_offset:y_offset+patch_size] = 1 - return mask - - def __call__(self, input): - trans_input = self.transform(input) - mask = self.gen_mask(self.image_size, self.mask_size, - self.mask_type, self.patch_size) - mask = mask.unsqueeze(dim=0) - return trans_input, mask - - -class DinoTransform(object): - def __init__(self, image_size, train=True): - args = get_args() - self.data_type = torch.half if args.fp16 else torch.bfloat16 - - flip_and_color_jitter = T.Compose([ - T.RandomHorizontalFlip(p=0.5), - T.RandomApply( - [T.ColorJitter(brightness=0.4, contrast=0.4, - saturation=0.2, hue=0.1)], - p=0.8 - ), - T.RandomGrayscale(p=0.2), - ]) - - if args.fp16 
or args.bf16: - normalize = T.Compose([ - T.ToTensor(), - T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), - T.ConvertImageDtype(self.data_type) - ]) - else: - normalize = T.Compose([ - T.ToTensor(), - T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), - ]) - - # first global crop - scale_const = 0.4 - self.global_transform1 = T.Compose([ - T.RandomResizedCrop(image_size, - scale=(scale_const, 1), - interpolation=Image.BICUBIC), - flip_and_color_jitter, - GaussianBlur(1.0), - normalize - ]) - # second global crop - self.global_transform2 = T.Compose([ - T.RandomResizedCrop(image_size, - scale=(scale_const, 1), - interpolation=Image.BICUBIC), - flip_and_color_jitter, - GaussianBlur(0.1), - Solarization(0.2), - normalize - ]) - # transformation for the local small crops - self.local_crops_number = args.dino_local_crops_number - self.local_transform = T.Compose([ - T.RandomResizedCrop(args.dino_local_img_size, - scale=(0.05, scale_const), - interpolation=Image.BICUBIC), - flip_and_color_jitter, - GaussianBlur(p=0.5), - normalize - ]) - - def __call__(self, image): - crops = [] - crops.append(self.global_transform1(image)) - crops.append(self.global_transform2(image)) - for _ in range(self.local_crops_number): - crops.append(self.local_transform(image)) - return crops - - -def build_train_valid_datasets(data_path, image_size=224): - args = get_args() - - if args.vision_pretraining_type == 'classify': - train_transform = ClassificationTransform(image_size) - val_transform = ClassificationTransform(image_size, train=False) - elif args.vision_pretraining_type == 'inpaint': - train_transform = InpaintingTransform(image_size, train=False) - val_transform = InpaintingTransform(image_size, train=False) - elif args.vision_pretraining_type == 'dino': - train_transform = DinoTransform(image_size, train=True) - val_transform = ClassificationTransform(image_size, train=False) - else: - raise Exception('{} vit pretraining type is not supported.'.format( - args.vit_pretraining_type)) - - # training dataset - train_data_path = data_path[0] if len(data_path) <= 2 else data_path[2] - train_data = ImageFolder( - root=train_data_path, - transform=train_transform, - classes_fraction=args.classes_fraction, - data_per_class_fraction=args.data_per_class_fraction - ) - train_data = RandomSeedDataset(train_data) - - # validation dataset - val_data_path = data_path[1] - val_data = ImageFolder( - root=val_data_path, - transform=val_transform - ) - val_data = RandomSeedDataset(val_data) - - return train_data, val_data diff --git a/megatron/dist_signal_handler.py b/megatron/dist_signal_handler.py deleted file mode 100644 index a60204f004a3f149da0bc059cd875b6ec390c0c4..0000000000000000000000000000000000000000 --- a/megatron/dist_signal_handler.py +++ /dev/null @@ -1,81 +0,0 @@ -import signal - -import torch - - -def get_world_size(): - if torch.distributed.is_available() and torch.distributed.is_initialized(): - world_size = torch.distributed.get_world_size() - else: - world_size = 1 - return world_size - - -def get_device(local_rank=None): - backend = torch.distributed.get_backend() - if backend == 'nccl': - if local_rank is None: - device = torch.device('cuda') - else: - device = torch.device(f'cuda:{local_rank}') - elif backend == 'gloo': - device = torch.device('cpu') - else: - raise RuntimeError - return device - - -def all_gather_item(item, dtype, group=None, async_op=False, local_rank=None): - if not torch.distributed.is_available() or \ - not torch.distributed.is_initialized(): - return [item] - - 
device = get_device(local_rank) - - if group is not None: - group_size = group.size() - else: - group_size = get_world_size() - - tensor = torch.tensor([item], device=device, dtype=dtype) - output_tensors = [ - torch.zeros(1, dtype=tensor.dtype, device=tensor.device) - for _ in range(group_size) - ] - torch.distributed.all_gather(output_tensors, tensor, group, async_op) - output = [elem.item() for elem in output_tensors] - return output - - -class DistributedSignalHandler: - def __init__(self, sig=signal.SIGTERM): - self.sig = sig - - def signals_received(self): - all_received = all_gather_item( - self._signal_received, dtype=torch.int32 - ) - return all_received - - def __enter__(self): - self._signal_received = False - self.released = False - self.original_handler = signal.getsignal(self.sig) - - def handler(signum, frame): - self._signal_received = True - - signal.signal(self.sig, handler) - - return self - - def __exit__(self, type, value, tb): - self.release() - - def release(self): - if self.released: - return False - - signal.signal(self.sig, self.original_handler) - self.released = True - return True diff --git a/megatron/fp16_deprecated/__init__.py b/megatron/fp16_deprecated/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/megatron/fp16_deprecated/loss_scaler.py b/megatron/fp16_deprecated/loss_scaler.py deleted file mode 100644 index e31d00ad3215b443b6cae8e97da9b03dcdcc3ad4..0000000000000000000000000000000000000000 --- a/megatron/fp16_deprecated/loss_scaler.py +++ /dev/null @@ -1,26 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. - -"""For backward compatibility, we need the class definitions to deserialize.""" - -class LossScaler: - def __init__(self, scale=1): - self.cur_scale = scale - -class DynamicLossScaler: - def __init__(self, - init_scale=2**32, - scale_factor=2., - scale_window=1000, - min_scale=1, - delayed_shift=1, - consecutive_hysteresis=False): - self.cur_scale = init_scale - self.cur_iter = 0 - self.last_overflow_iter = -1 - self.scale_factor = scale_factor - self.scale_window = scale_window - self.min_scale = min_scale - self.delayed_shift = delayed_shift - self.cur_hysteresis = delayed_shift - self.consecutive_hysteresis = consecutive_hysteresis - diff --git a/megatron/fused_kernels/__init__.py b/megatron/fused_kernels/__init__.py deleted file mode 100644 index 87cceac3e35f983cf9f2264ff651a1067069f9e2..0000000000000000000000000000000000000000 --- a/megatron/fused_kernels/__init__.py +++ /dev/null @@ -1,75 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. - -import os -import pathlib -import subprocess - -from torch.utils import cpp_extension - -# Setting this param to a list has a problem of generating different -# compilation commands (with diferent order of architectures) and -# leading to recompilation of fused kernels. 
Set it to empty string -# to avoid recompilation and assign arch flags explicity in -# extra_cuda_cflags below -os.environ["TORCH_CUDA_ARCH_LIST"] = "" - - -def load(args): - - # Check if cuda 11 is installed for compute capability 8.0 - cc_flag = [] - _, bare_metal_major, bare_metal_minor = _get_cuda_bare_metal_version( - cpp_extension.CUDA_HOME - ) - if int(bare_metal_major) >= 11: - cc_flag.append('-gencode') - cc_flag.append('arch=compute_80,code=sm_80') - if int(bare_metal_minor) >= 8: - cc_flag.append('-gencode') - cc_flag.append('arch=compute_90,code=sm_90') - - # Build path - srcpath = pathlib.Path(__file__).parent.absolute() - buildpath = srcpath / "build" - _create_build_dir(buildpath) - - # Helper function to build the kernels. - def _cpp_extention_load_helper(name, sources, extra_cuda_flags): - return cpp_extension.load( - name=name, - sources=sources, - build_directory=buildpath, - extra_cflags=[ - "-O3", - ], - extra_cuda_cflags=[ - "-O3", - "-gencode", - "arch=compute_70,code=sm_70", - "--use_fast_math", - ] - + extra_cuda_flags - + cc_flag, - verbose=(args.rank == 0), - ) - - -def _get_cuda_bare_metal_version(cuda_dir): - raw_output = subprocess.check_output( - [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True - ) - output = raw_output.split() - release_idx = output.index("release") + 1 - release = output[release_idx].split(".") - bare_metal_major = release[0] - bare_metal_minor = release[1][0] - - return raw_output, bare_metal_major, bare_metal_minor - - -def _create_build_dir(buildpath): - try: - os.mkdir(buildpath) - except OSError: - if not os.path.isdir(buildpath): - print(f"Creation of the build directory {buildpath} failed") diff --git a/megatron/fused_kernels/compat.h b/megatron/fused_kernels/compat.h deleted file mode 100644 index 5495d7807762d8b4e3dbc11b28dba15f85bd8108..0000000000000000000000000000000000000000 --- a/megatron/fused_kernels/compat.h +++ /dev/null @@ -1,17 +0,0 @@ -/* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. */ - -/*This code is copied fron NVIDIA apex: - * https://github.com/NVIDIA/apex - * with minor changes. 
*/ - - - -#ifndef TORCH_CHECK -#define TORCH_CHECK AT_CHECK -#endif - -#ifdef VERSION_GE_1_3 -#define DATA_PTR data_ptr -#else -#define DATA_PTR data -#endif diff --git a/megatron/fused_kernels/tests/__init__.py b/megatron/fused_kernels/tests/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/megatron/fused_kernels/tests/test_fused_kernels.py b/megatron/fused_kernels/tests/test_fused_kernels.py deleted file mode 100644 index 74024c5020f45ea3e80607fad38a45c4dab3453b..0000000000000000000000000000000000000000 --- a/megatron/fused_kernels/tests/test_fused_kernels.py +++ /dev/null @@ -1,388 +0,0 @@ -import math - -import torch -from torch.nn import LayerNorm - -from megatron.model.enums import AttnMaskType -from megatron.model.fused_layer_norm import MixedFusedLayerNorm -from megatron.model.fused_softmax import FusedScaleMaskSoftmax -from megatron.model.utils import attention_mask_func -from megatron.fused_kernels import load - -def test_load_fused_kernels(): - try: - import fused_layer_norm_cuda - import scaled_masked_softmax_cuda - import scaled_upper_triang_masked_softmax_cuda - import torch - - print("[Success] load_fused_kernels") - except ImportError as e: - print("[Fail] load_fused_kernels") - raise e - -def test_fused_softmax(): - bert = BertModel.from_pretrained("bert-base-cased").cuda().half() - tokenizer = BertTokenizer.from_pretrained("bert-base-cased") - test_text = ( - "Hello. How are you? I am fine thank you and you? yes Good. " - "hi hi hi hi hi hi hi hi hi hi hi hi hi" # 32 - ) - - tokens = tokenizer( - [test_text] * 4, - return_tensors="pt", - ) - - embedding_output = bert.embeddings( - input_ids=tokens["input_ids"].cuda(), - position_ids=None, - token_type_ids=tokens["token_type_ids"].cuda(), - inputs_embeds=None, - past_key_values_length=0, - ) - - # (bsz, 1, 1, seq_len) - mask = bert.get_extended_attention_mask( - attention_mask=tokens["attention_mask"].cuda(), - input_shape=tokens["input_ids"].shape, - device=bert.device, - ) - # (bsz, 1, seq_len, seq_len) - mask = mask.repeat(1, 1, mask.size()[-1], 1) - - attention = bert.encoder.layer[0].attention.self - key_layer = attention.transpose_for_scores(attention.key(embedding_output)) - query_layer = attention.transpose_for_scores(attention.query(embedding_output)) - - attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) - attention_scores /= math.sqrt(key_layer.size()[-1]) - - fused_softmax = ( - FusedScaleMaskSoftmax( - input_in_fp16=True, - input_in_bf16=False, - mask_func=attention_mask_func, - scale=None, - softmax_in_fp32=False, - attn_mask_type=AttnMaskType.padding, - scaled_masked_softmax_fusion=True, - ) - .cuda() - .half() - ) - - fused_softmax_output = fused_softmax( - attention_scores, - (mask != 0), - ) - - torch_softmax = ( - FusedScaleMaskSoftmax( - input_in_fp16=True, - input_in_bf16=False, - mask_func=attention_mask_func, - scale=None, - softmax_in_fp32=False, - attn_mask_type=AttnMaskType.padding, - scaled_masked_softmax_fusion=False, - ) - .cuda() - .half() - ) - - torch_softmax_output = torch_softmax( - attention_scores, - (mask != 0), - ) - - test_result = (fused_softmax_output - torch_softmax_output).abs() - - while test_result.dim() != 1: - test_result = test_result.mean(dim=-1) - - diff = test_result.mean(dim=-1) - - if diff <= 1e-3: - print( - f"\n[Success] test_fused_softmax" - f"\n > mean_difference={diff}" - f"\n > fused_values={fused_softmax_output[-1][-1][-1][:5].tolist()}" - f"\n > 
torch_values={torch_softmax_output[-1][-1][-1][:5].tolist()}" - ) - else: - print( - f"\n[Fail] test_fused_softmax" - f"\n > mean_difference={diff}, " - f"\n > fused_values={fused_softmax_output[-1][-1][-1][:5].tolist()}, " - f"\n > torch_values={torch_softmax_output[-1][-1][-1][:5].tolist()}" - ) - - -def test_fused_upper_triangle_mask_softmax(): - gpt = GPT2Model.from_pretrained("gpt2").cuda().half() - tokenizer = GPT2Tokenizer.from_pretrained("gpt2") - test_text = ( - "Hello. How are you? I am fine thank you and you? yes Good. " - "hi hi hi hi hi hi hi" # 24 - ) - - tokens = tokenizer( - [test_text] * 4, - return_tensors="pt", - ) - - attention_mask = tokens["attention_mask"].cuda() - attention_mask = attention_mask.view(attention_mask.size(0), -1) - attention_mask = attention_mask[:, None, None, :] - attention_mask = (1.0 - attention_mask) * -10000.0 - attention_mask = attention_mask.repeat(1, 1, attention_mask.size()[-1], 1) - attn = gpt.h[0] - - hidden_states = gpt.wte(tokens["input_ids"].cuda()) - q, k, v = attn.attn.c_attn(hidden_states).split(768, dim=-1) - q = attn.attn._split_heads(q, attn.attn.num_heads, attn.attn.head_dim) - k = attn.attn._split_heads(k, attn.attn.num_heads, attn.attn.head_dim) - attn_weights = torch.matmul(q, k.transpose(-1, -2)) - - sq, sk = q.size(-2), k.size(-2) - causal_mask = attn.attn.bias[:, :, sk - sq : sk, :sk].bool() - total_mask = ~(causal_mask & (attention_mask == 0)) - """ - tensor([[[[False, True, True, ..., True, True, True], - [False, False, True, ..., True, True, True], - [False, False, False, ..., True, True, True], - ..., - [False, False, False, ..., False, True, True], - [False, False, False, ..., False, False, True], - [False, False, False, ..., False, False, False]]] - """ - - fused_softmax = ( - FusedScaleMaskSoftmax( - input_in_fp16=True, - input_in_bf16=False, - mask_func=attention_mask_func, - scale=None, - softmax_in_fp32=False, - attn_mask_type=AttnMaskType.causal, - scaled_masked_softmax_fusion=True, - ) - .cuda() - .half() - ) - - fused_softmax_output = fused_softmax( - attn_weights, - total_mask, - ) - - torch_softmax = ( - FusedScaleMaskSoftmax( - input_in_fp16=True, - input_in_bf16=False, - mask_func=attention_mask_func, - scale=None, - softmax_in_fp32=False, - attn_mask_type=AttnMaskType.causal, - scaled_masked_softmax_fusion=False, - ) - .cuda() - .half() - ) - - torch_softmax_output = torch_softmax( - attn_weights, - total_mask, - ) - - test_result = (fused_softmax_output - torch_softmax_output).abs() - - while test_result.dim() != 1: - test_result = test_result.mean(dim=-1) - - diff = test_result.mean(dim=-1) - - if diff <= 1e-3: - print( - f"\n[Success] test_fused_upper_triangle_mask_softmax" - f"\n > mean_difference={diff}" - f"\n > fused_values={fused_softmax_output[-1][-1][-1][:5].tolist()}" - f"\n > torch_values={torch_softmax_output[-1][-1][-1][:5].tolist()}" - ) - else: - print( - f"\n[Fail] test_fused_upper_triangle_mask_softmax" - f"\n > mean_difference={diff}, " - f"\n > fused_values={fused_softmax_output[-1][-1][-1][:5].tolist()}, " - f"\n > torch_values={torch_softmax_output[-1][-1][-1][:5].tolist()}" - ) - - -def test_layer_norm(): - bert = BertModel.from_pretrained("bert-base-cased").cuda().half() - tokenizer = BertTokenizer.from_pretrained("bert-base-cased") - test_text = ( - "Hello. How are you? I am fine thank you and you? yes Good. 
" - "hi hi hi hi hi hi hi hi hi hi hi hi hi" # 32 - ) - - tokens = tokenizer( - [test_text] * 4, - return_tensors="pt", - ) - - # [bsz, seq_len, d_model] - embedding_output = ( - bert.embeddings( - input_ids=tokens["input_ids"].cuda(), - position_ids=None, - token_type_ids=tokens["token_type_ids"].cuda(), - inputs_embeds=None, - past_key_values_length=0, - ) - .cuda() - .half() - ) - - fused_layernorm_layer = ( - MixedFusedLayerNorm(normalized_shape=embedding_output.size(-1)).cuda().half() - ) - - torch_layernorm_layer = ( - LayerNorm(normalized_shape=embedding_output.size(-1)).cuda().half() - ) - - fused_output = fused_layernorm_layer(embedding_output) - torch_output = torch_layernorm_layer(embedding_output) - test_result = (fused_output - torch_output).abs() - - while test_result.dim() != 1: - test_result = test_result.mean(dim=-1) - - diff = test_result.mean(dim=-1) - - if diff <= 1e-3: - print( - f"\n[Success] test_layer_norm" - f"\n > mean_difference={diff}" - f"\n > fused_values={fused_output[-1][-1][:5].tolist()}" - f"\n > torch_values={torch_output[-1][-1][:5].tolist()}" - ) - else: - print( - f"\n[Fail] test_layer_norm" - f"\n > mean_difference={diff}, " - f"\n > fused_values={fused_output[-1][-1][:5].tolist()}, " - f"\n > torch_values={torch_output[-1][-1][:5].tolist()}" - ) - - -def attention_mask_func(attention_scores, attention_mask): - attention_scores.masked_fill_(attention_mask, -10000.0) - return attention_scores - - -def forward_torch_softmax(input, mask, scale): - input = input * scale - mask_output = attention_mask_func(input, mask) if mask is not None else input - probs = torch.nn.Softmax(dim=-1)(mask_output) - return probs - - -def test_masked_softmax_forward(): - import scaled_masked_softmax_cuda - - batch = 2 - attn = 16 - scale_t = torch.tensor([1.0]) - for qlen in [128, 256, 1024, 2048, 4096]: - for klen in [128, 256, 1024, 2048]: - inputs = torch.normal(0, 2, (batch, attn, qlen, klen), dtype=torch.float16, device='cuda:0') - masks = torch.randint(0, 2, (batch, 1, qlen, klen), dtype=torch.bool, device='cuda:0') - softmax_results = scaled_masked_softmax_cuda.forward(inputs, masks, scale_t[0].item()) - softmax_results_torch = forward_torch_softmax(inputs, masks, scale_t[0].item()) - error = (softmax_results_torch - softmax_results).abs().max() - assert error < 1e-3 - -def test_masked_softmax_backward(): - import scaled_masked_softmax_cuda - - batch = 2 - attn = 16 - scale_t = torch.tensor([1.0]) - for qlen in [128, 256, 1024, 2048, 4096]: - for klen in [128, 256, 1024, 2048]: - inputs = torch.normal(0, 2, (batch, attn, qlen, klen), dtype=torch.float16, device='cuda:0') - backward = torch.rand_like(inputs, dtype=torch.float16, device='cuda:0') - masks = torch.randint(0, 2, (batch, 1, qlen, klen), dtype=torch.bool, device='cuda:0') - softmax_results = scaled_masked_softmax_cuda.forward(inputs, masks, scale_t[0].item()) - back_grad = scaled_masked_softmax_cuda.backward(backward, softmax_results, scale_t[0].item()) - - inputs.requires_grad = True - softmax_results_torch = forward_torch_softmax(inputs, masks, scale_t[0].item()) - softmax_results_torch.backward(backward) - error = (back_grad - inputs.grad).abs().max() - assert error < 1e-3 - - -def test_allmasked_softmax_forward(): - import scaled_masked_softmax_cuda - - batch = 2 - attn = 16 - scale_t = torch.tensor([1.0]) - for qlen in [128, 256, 1024, 2048, 4096]: - for klen in [128, 256, 1024, 2048]: - inputs = torch.normal(0, 2, (batch, attn, qlen, klen), dtype=torch.float16, device='cuda:0') - masks = 
torch.ones((batch, 1, qlen, klen), dtype=torch.bool, device='cuda:0') - softmax_results = scaled_masked_softmax_cuda.forward(inputs, masks, scale_t[0].item()) - softmax_results_torch = torch.zeros_like(inputs) - error = (softmax_results_torch - softmax_results).abs().max() - assert error == 0.0 - - -def test_allmasked_softmax_backward(): - import scaled_masked_softmax_cuda - - batch = 2 - attn = 16 - scale_t = torch.tensor([1.0]) - for qlen in [128, 256, 1024, 2048, 4096]: - for klen in [128, 256, 1024, 2048]: - inputs = torch.normal(0, 2, (batch, attn, qlen, klen), dtype=torch.float16, device='cuda:0') - backward = torch.rand_like(inputs, dtype=torch.float16, device='cuda:0') - masks = torch.ones((batch, 1, qlen, klen), dtype=torch.bool, device='cuda:0') - softmax_results = scaled_masked_softmax_cuda.forward(inputs, masks, scale_t[0].item()) - back_grad = scaled_masked_softmax_cuda.backward(backward, softmax_results, scale_t[0].item()) - inputs.requires_grad = True - softmax_results_torch = forward_torch_softmax(inputs, masks, scale_t[0].item()) - softmax_results_torch.backward(backward) - error = (back_grad - inputs.grad).abs().max() - assert error < 1e-3 - - -if __name__ == "__main__": - try: - from transformers import BertTokenizer, GPT2Tokenizer - from transformers.models.bert.modeling_bert import BertModel - from transformers.models.gpt2.modeling_gpt2 import GPT2Model - import transformers - - transformers.logging.set_verbosity( - transformers.logging.FATAL, - ) - - except: - print("\n[Fail] Please install `transformers` package to test fused kernels\n") - exit(-1) - - load() - test_masked_softmax_forward() - test_masked_softmax_backward() - test_allmasked_softmax_forward() - test_allmasked_softmax_backward() - test_load_fused_kernels() - test_fused_softmax() - test_fused_upper_triangle_mask_softmax() - test_layer_norm() diff --git a/megatron/fused_kernels/type_shim.h b/megatron/fused_kernels/type_shim.h deleted file mode 100644 index d60a6f8c6fb50e241f9ddcc852adec71e963e1b2..0000000000000000000000000000000000000000 --- a/megatron/fused_kernels/type_shim.h +++ /dev/null @@ -1,103 +0,0 @@ -/* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. */ - - -#include -#include "compat.h" - - -#define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) \ - switch(TYPE) \ - { \ - case at::ScalarType::Half: \ - { \ - using scalar_t = at::Half; \ - __VA_ARGS__; \ - break; \ - } \ - case at::ScalarType::BFloat16: \ - { \ - using scalar_t = at::BFloat16; \ - __VA_ARGS__; \ - break; \ - } \ - default: \ - AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ - } - - -#define DISPATCH_HALF_BFLOAT_AND_FLOAT(TYPE, NAME, ...) \ - switch(TYPE) \ - { \ - case at::ScalarType::Half: \ - { \ - using scalar_t = at::Half; \ - __VA_ARGS__; \ - break; \ - } \ - case at::ScalarType::BFloat16: \ - { \ - using scalar_t = at::BFloat16; \ - __VA_ARGS__; \ - break; \ - } \ - case at::ScalarType::Float: \ - { \ - using scalar_t = float; \ - __VA_ARGS__; \ - break; \ - } \ - default: \ - AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ - } - - - -#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) 
\ - switch(TYPEIN) \ - { \ - case at::ScalarType::Float: \ - { \ - using scalar_t_in = float; \ - switch(TYPEOUT) \ - { \ - case at::ScalarType::Float: \ - { \ - using scalar_t_out = float; \ - __VA_ARGS__; \ - break; \ - } \ - case at::ScalarType::Half: \ - { \ - using scalar_t_out = at::Half; \ - __VA_ARGS__; \ - break; \ - } \ - case at::ScalarType::BFloat16: \ - { \ - using scalar_t_out = at::BFloat16; \ - __VA_ARGS__; \ - break; \ - } \ - default: \ - AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \ - } \ - break; \ - } \ - case at::ScalarType::Half: \ - { \ - using scalar_t_in = at::Half; \ - using scalar_t_out = at::Half; \ - __VA_ARGS__; \ - break; \ - } \ - case at::ScalarType::BFloat16: \ - { \ - using scalar_t_in = at::BFloat16; \ - using scalar_t_out = at::BFloat16; \ - __VA_ARGS__; \ - break; \ - } \ - default: \ - AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \ - } - diff --git a/megatron/global_vars.py b/megatron/global_vars.py deleted file mode 100644 index 04b1448ee80b2707a882d482edce3731356e99eb..0000000000000000000000000000000000000000 --- a/megatron/global_vars.py +++ /dev/null @@ -1,223 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. - -"""Megatron global variables.""" - -import os -import sys -import torch - -from megatron import dist_signal_handler -from megatron.tokenizer import build_tokenizer -from .microbatches import build_num_microbatches_calculator -from .timers import Timers - -_GLOBAL_ARGS = None -_GLOBAL_RETRO_ARGS = None -_GLOBAL_NUM_MICROBATCHES_CALCULATOR = None -_GLOBAL_TOKENIZER = None -_GLOBAL_TENSORBOARD_WRITER = None -_GLOBAL_WANDB_WRITER = None -_GLOBAL_ADLR_AUTORESUME = None -_GLOBAL_TIMERS = None -_GLOBAL_SIGNAL_HANDLER = None - -def get_args(): - """Return arguments.""" - _ensure_var_is_initialized(_GLOBAL_ARGS, 'args') - return _GLOBAL_ARGS - - -def get_retro_args(): - """Return retro arguments.""" - return _GLOBAL_RETRO_ARGS - - -def get_num_microbatches(): - return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get() - - -def get_current_global_batch_size(): - return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get_current_global_batch_size() - - -def update_num_microbatches(consumed_samples, consistency_check=True): - _GLOBAL_NUM_MICROBATCHES_CALCULATOR.update(consumed_samples, - consistency_check) - - -def get_tokenizer(): - """Return tokenizer.""" - _ensure_var_is_initialized(_GLOBAL_TOKENIZER, 'tokenizer') - return _GLOBAL_TOKENIZER - - -def get_tensorboard_writer(): - """Return tensorboard writer. It can be None so no need - to check if it is initialized.""" - return _GLOBAL_TENSORBOARD_WRITER - - -def get_wandb_writer(): - """Return tensorboard writer. It can be None so no need - to check if it is initialized.""" - return _GLOBAL_WANDB_WRITER - - -def get_adlr_autoresume(): - """ADLR autoresume object. 
It can be None so no need - to check if it is initialized.""" - return _GLOBAL_ADLR_AUTORESUME - - -def get_timers(): - """Return timers.""" - _ensure_var_is_initialized(_GLOBAL_TIMERS, 'timers') - return _GLOBAL_TIMERS - - -def get_signal_handler(): - _ensure_var_is_initialized(_GLOBAL_SIGNAL_HANDLER, 'signal handler') - return _GLOBAL_SIGNAL_HANDLER - - -def _set_signal_handler(): - global _GLOBAL_SIGNAL_HANDLER - _ensure_var_is_not_initialized(_GLOBAL_SIGNAL_HANDLER, 'signal handler') - _GLOBAL_SIGNAL_HANDLER = dist_signal_handler.DistributedSignalHandler().__enter__() - - - -def set_global_variables(args, build_tokenizer=True): - """Set args, tokenizer, tensorboard-writer, adlr-autoresume, and timers.""" - - assert args is not None - - _ensure_var_is_not_initialized(_GLOBAL_ARGS, 'args') - set_args(args) - - _build_num_microbatches_calculator(args) - if build_tokenizer: - _ = _build_tokenizer(args) - _set_tensorboard_writer(args) - _set_wandb_writer(args) - _set_adlr_autoresume(args) - _set_timers(args) - - if args.exit_signal_handler: - _set_signal_handler() - - -def set_args(args): - global _GLOBAL_ARGS - _GLOBAL_ARGS = args - - -def set_retro_args(retro_args): - global _GLOBAL_RETRO_ARGS - _GLOBAL_RETRO_ARGS = retro_args - - -def _build_num_microbatches_calculator(args): - - global _GLOBAL_NUM_MICROBATCHES_CALCULATOR - _ensure_var_is_not_initialized(_GLOBAL_NUM_MICROBATCHES_CALCULATOR, - 'num microbatches calculator') - - _GLOBAL_NUM_MICROBATCHES_CALCULATOR = build_num_microbatches_calculator( - args) - - -def _build_tokenizer(args): - """Initialize tokenizer.""" - global _GLOBAL_TOKENIZER - _ensure_var_is_not_initialized(_GLOBAL_TOKENIZER, 'tokenizer') - _GLOBAL_TOKENIZER = build_tokenizer(args) - return _GLOBAL_TOKENIZER - - -def rebuild_tokenizer(args): - global _GLOBAL_TOKENIZER - _GLOBAL_TOKENIZER = None - return _build_tokenizer(args) - - -def _set_tensorboard_writer(args): - """Set tensorboard writer.""" - global _GLOBAL_TENSORBOARD_WRITER - _ensure_var_is_not_initialized(_GLOBAL_TENSORBOARD_WRITER, - 'tensorboard writer') - - if hasattr(args, 'tensorboard_dir') and \ - args.tensorboard_dir and args.rank == (args.world_size - 1): - try: - from torch.utils.tensorboard import SummaryWriter - print('> setting tensorboard ...') - _GLOBAL_TENSORBOARD_WRITER = SummaryWriter( - log_dir=args.tensorboard_dir, - max_queue=args.tensorboard_queue_size) - except ModuleNotFoundError: - print('WARNING: TensorBoard writing requested but is not ' - 'available (are you using PyTorch 1.1.0 or later?), ' - 'no TensorBoard logs will be written.', flush=True) - - -def _set_wandb_writer(args): - global _GLOBAL_WANDB_WRITER - _ensure_var_is_not_initialized(_GLOBAL_WANDB_WRITER, - 'wandb writer') - if getattr(args, 'wandb_project', '') and args.rank == (args.world_size - 1): - if args.wandb_exp_name == '': - raise ValueError("Please specify the wandb experiment name!") - - import wandb - if args.wandb_save_dir: - save_dir = args.wandb_save_dir - else: - # Defaults to the save dir. 
- save_dir = os.path.join(args.save, 'wandb') - wandb_kwargs = { - 'dir': save_dir, - 'name': args.wandb_exp_name, - 'project': args.wandb_project, - 'config': vars(args)} - os.makedirs(wandb_kwargs['dir'], exist_ok=True) - wandb.init(**wandb_kwargs) - _GLOBAL_WANDB_WRITER = wandb - - -def _set_adlr_autoresume(args): - """Initialize ADLR autoresume.""" - global _GLOBAL_ADLR_AUTORESUME - _ensure_var_is_not_initialized(_GLOBAL_ADLR_AUTORESUME, 'adlr autoresume') - - if args.adlr_autoresume: - if args.rank == 0: - print('enabling autoresume ...', flush=True) - sys.path.append(os.environ.get('SUBMIT_SCRIPTS', '')) - try: - from userlib.auto_resume import AutoResume - except BaseException: - print('ADLR autoresume is not available, exiting ...') - sys.exit() - - _GLOBAL_ADLR_AUTORESUME = AutoResume - - -def _set_timers(args): - """Initialize timers.""" - global _GLOBAL_TIMERS - _ensure_var_is_not_initialized(_GLOBAL_TIMERS, 'timers') - _GLOBAL_TIMERS = Timers(args.timing_log_level, args.timing_log_option) - - -def _ensure_var_is_initialized(var, name): - """Make sure the input variable is not None.""" - assert var is not None, '{} is not initialized.'.format(name) - - -def _ensure_var_is_not_initialized(var, name): - """Make sure the input variable is not None.""" - assert var is None, '{} is already initialized.'.format(name) - - - diff --git a/megatron/indexer.py b/megatron/indexer.py deleted file mode 100644 index 45f530a7d4d7d72bfeb1f2d9a3098b3812a332fe..0000000000000000000000000000000000000000 --- a/megatron/indexer.py +++ /dev/null @@ -1,129 +0,0 @@ -import sys -import time -import torch -import torch.distributed as dist - -from megatron import get_args, print_rank_0 -from megatron.core import mpu -from megatron.checkpointing import load_biencoder_checkpoint -from megatron.data.orqa_wiki_dataset import get_open_retrieval_wiki_dataset -from megatron.data.orqa_wiki_dataset import get_open_retrieval_batch -from megatron.data.biencoder_dataset_utils import get_one_epoch_dataloader -from megatron.data.realm_index import detach, OpenRetreivalDataStore -from megatron.model.biencoder_model import get_model_provider -from megatron.training import get_model - - -class IndexBuilder(object): - """ - Object for taking one pass over a dataset and creating a BlockData of its - embeddings - """ - def __init__(self): - args = get_args() - self.model = None - self.dataloader = None - self.evidence_embedder_obj = None - self.biencoder_shared_query_context_model = \ - args.biencoder_shared_query_context_model - - # need to know whether we're using a REALM checkpoint (args.load) - # or ICT checkpoint - assert not (args.load and args.ict_load) - - self.log_interval = args.indexer_log_interval - self.batch_size = args.indexer_batch_size - - self.load_attributes() - self.is_main_builder = mpu.get_data_parallel_rank() == 0 - self.num_total_builders = mpu.get_data_parallel_world_size() - self.iteration = self.total_processed = 0 - - def load_attributes(self): - """ - Load the necessary attributes: model, dataloader and empty BlockData - """ - only_context_model = True - if self.biencoder_shared_query_context_model: - only_context_model = False - - model = get_model(get_model_provider(only_context_model=\ - only_context_model, biencoder_shared_query_context_model=\ - self.biencoder_shared_query_context_model)) - - self.model = load_biencoder_checkpoint(model, - only_context_model=only_context_model) - - assert len(self.model) == 1 - self.model[0].eval() - - self.dataset = get_open_retrieval_wiki_dataset() - 
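As an aside on the `global_vars.py` accessors removed above: they guard module-level singletons with "not yet initialized" / "already initialized" assertions so that `get_args()` and friends fail loudly when called out of order. A minimal sketch of that pattern (the `Config` type and names here are hypothetical, not the Megatron API):

```python
# Illustrative sketch of the guarded module-level singleton pattern used by
# the removed global_vars.py: set the object once, then read it anywhere.
from dataclasses import dataclass
from typing import Optional

_GLOBAL_CONFIG: Optional["Config"] = None


@dataclass
class Config:
    seed: int = 1234
    micro_batch_size: int = 4


def set_config(config: Config) -> None:
    global _GLOBAL_CONFIG
    assert _GLOBAL_CONFIG is None, "config is already initialized."
    _GLOBAL_CONFIG = config


def get_config() -> Config:
    assert _GLOBAL_CONFIG is not None, "config is not initialized."
    return _GLOBAL_CONFIG


if __name__ == "__main__":
    set_config(Config(seed=42))
    print(get_config().seed)  # any later caller sees the same object
```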
self.dataloader = iter(get_one_epoch_dataloader(self.dataset, \ - self.batch_size)) - - self.evidence_embedder_obj = OpenRetreivalDataStore( \ - load_from_path=False) - - def track_and_report_progress(self, batch_size): - """ - Utility function for tracking progress - """ - self.iteration += 1 - self.total_processed += batch_size * self.num_total_builders - if self.is_main_builder and self.iteration % self.log_interval == 0: - print('Batch {:10d} | Total {:10d}'.format(self.iteration, - self.total_processed), flush=True) - - def build_and_save_index(self): - """ - Goes through one epoch of the dataloader and adds all data to this - instance's BlockData. - - The copy of BlockData is saved as a shard, which when run in a - distributed setting will be consolidated by the rank 0 process - and saved as a final pickled BlockData. - """ - assert len(self.model) == 1 - unwrapped_model = self.model[0] - - while not hasattr(unwrapped_model, 'embed_text'): - unwrapped_model = unwrapped_model.module - - while True: - try: - # batch also has query_tokens and query_pad_data - row_id, context_tokens, context_mask, context_types, \ - context_pad_mask = get_open_retrieval_batch( \ - self.dataloader) - except (StopIteration, IndexError): - break - - # TODO: can we add with torch.no_grad() to reduce memory usage - # detach, separate fields and add to BlockData - assert context_mask.dtype == torch.bool - context_logits = unwrapped_model.embed_text( - unwrapped_model.context_model, context_tokens, context_mask, - context_types) - - context_logits = detach(context_logits) - row_id = detach(row_id) - - self.evidence_embedder_obj.add_block_data(row_id, context_logits) - self.track_and_report_progress(batch_size=len(row_id)) - - # This process signals to finalize its shard and then synchronize with - # the other processes - self.evidence_embedder_obj.save_shard() - torch.distributed.barrier() - del self.model - - # rank 0 process builds the final copy - if self.is_main_builder: - self.evidence_embedder_obj.merge_shards_and_save() - # make sure that every single piece of data was embedded - assert len(self.evidence_embedder_obj.embed_data) == \ - len(self.dataset) - self.evidence_embedder_obj.clear() - - # complete building the final copy - torch.distributed.barrier() diff --git a/megatron/initialize.py b/megatron/initialize.py deleted file mode 100644 index fb7866ab03510701afb351022f566e8f1e7ba00b..0000000000000000000000000000000000000000 --- a/megatron/initialize.py +++ /dev/null @@ -1,387 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. - -"""Megatron initialization.""" - -import random -import os -import time - -import numpy as np -import torch -from datetime import timedelta - -from megatron import fused_kernels -from megatron import get_adlr_autoresume -from megatron import get_args -from megatron import get_tensorboard_writer -from megatron.core import mpu, tensor_parallel -from megatron.arguments import parse_args, validate_args -from megatron.checkpointing import load_args_from_checkpoint -from megatron.global_vars import set_global_variables -from megatron.model.transformer import bias_dropout_add_fused_train -from megatron.model.fused_bias_gelu import bias_gelu - -def initialize_megatron( - extra_args_provider=None, - args_defaults={}, - ignore_unknown_args=False, - allow_no_cuda=False, - skip_mpu_initialization=False, -): - """Set global variables, initialize distributed, and - set autoresume and random seeds. 
- `allow_no_cuda` should not be set unless using megatron for cpu only - data processing. In general this arg should not be set unless you know - what you are doing. - Returns a function to finalize distributed env initialization - (optionally, only when args.lazy_mpu_init == True) - """ - if not allow_no_cuda: - # Make sure cuda is available. - assert torch.cuda.is_available(), "Megatron requires CUDA." - - # Parse arguments - args = parse_args(extra_args_provider, ignore_unknown_args) - - if args.use_checkpoint_args or args_defaults.get("use_checkpoint_args", False): - assert args.load is not None, "--use-checkpoints-args requires --load argument" - load_args_from_checkpoint(args) - - validate_args(args, args_defaults) - - # set global args, build tokenizer, and set adlr-autoresume, - # tensorboard-writer, and timers. - set_global_variables(args) - - # torch.distributed initialization - def finish_mpu_init(): - args = get_args() - # Pytorch distributed. - _initialize_distributed() - - # Random seeds for reproducibility. - if args.rank == 0: - print("> setting random seeds to {} ...".format(args.seed)) - _set_random_seed(args.seed, args.data_parallel_random_init) - - if skip_mpu_initialization: - return None - - args = get_args() - if args.lazy_mpu_init: - # TODO is this still a necessary option? - args.use_cpu_initialization = True - # delayed initialization of DDP-related stuff - # We only set basic DDP globals - mpu.set_tensor_model_parallel_world_size(args.tensor_model_parallel_size) - # and return function for external DDP manager - # to call when it has DDP initialized - mpu.set_tensor_model_parallel_rank(args.rank) - return finish_mpu_init - else: - # Megatron's MPU is the master. Complete initialization right away. - finish_mpu_init() - - # Autoresume. - _init_autoresume() - - # Compile dependencies. - _compile_dependencies() - - if args.tp_comm_overlap: - _initialize_tp_communicators() - - # No continuation function - return None - - -def _compile_dependencies(): - - args = get_args() - - # ========================= - # Compile dataset C++ code. - # ========================= - # TODO: move this to ninja - if torch.distributed.get_rank() == 0: - start_time = time.time() - print("> compiling dataset index builder ...") - from megatron.core.datasets.utils import compile_helpers - - compile_helpers() - print( - ">>> done with dataset index builder. Compilation time: {:.3f} " - "seconds".format(time.time() - start_time), - flush=True, - ) - - # ================== - # Load fused kernels - # ================== - - # Custom kernel constraints check. - seq_len = args.seq_length - attn_batch_size = ( - args.num_attention_heads / args.tensor_model_parallel_size - ) * args.micro_batch_size - # Constraints on sequence length and attn_batch_size to enable warp based - # optimization and upper triangular optimization (for causal mask) - custom_kernel_constraint = ( - seq_len > 16 - and seq_len <= 16384 - and seq_len % 4 == 0 - and attn_batch_size % 4 == 0 - ) - # Print a warning. - if not ( - (args.fp16 or args.bf16) - and custom_kernel_constraint - and args.masked_softmax_fusion - ): - if args.rank == 0: - print( - "WARNING: constraints for invoking optimized" - " fused softmax kernel are not met. We default" - " back to unfused kernel invocations.", - flush=True, - ) - - # Always build on rank zero first. 
- if torch.distributed.get_rank() == 0: - start_time = time.time() - print("> compiling and loading fused kernels ...", flush=True) - fused_kernels.load(args) - torch.distributed.barrier() - else: - torch.distributed.barrier() - fused_kernels.load(args) - # Simple barrier to make sure all ranks have passed the - # compilation phase successfully before moving on to the - # rest of the program. We think this might ensure that - # the lock is released. - torch.distributed.barrier() - if torch.distributed.get_rank() == 0: - print( - ">>> done with compiling and loading fused kernels. " - "Compilation time: {:.3f} seconds".format(time.time() - start_time), - flush=True, - ) - -def _initialize_tp_communicators(): - """ initializing the communicators with user buffers for high-performance tensor-model-parallel - communication overlap """ - - try: - import yaml - - import transformer_engine - from transformer_engine.pytorch import module as te_module - - except ImportError: - raise RuntimeError("Tensor Parallel Communication/GEMM Overlap optimization needs 'yaml' and " - "'transformer_engine' packages") - - args = get_args() - - if args.tp_comm_overlap_cfg is not None: - with open(args.tp_comm_overlap_cfg,"r") as stream: - ub_cfgs = yaml.safe_load(stream) - else: - ub_cfgs = {} - - input_shape = [args.seq_length * args.micro_batch_size , args.hidden_size] - - #We create a MPI process group, which is needed to bootstrap the pipelined - #tensor-model-parallel communication overlap - torch.distributed.new_group(backend='mpi') - - te_module.base.initialize_ub(shape = input_shape, tp_size = args.tensor_model_parallel_size, - use_fp8 = (args.fp8 is not None) , ub_cfgs = ub_cfgs,) - -def _initialize_distributed(): - """Initialize torch.distributed and core model parallel.""" - args = get_args() - - device_count = torch.cuda.device_count() - if torch.distributed.is_initialized(): - - if args.rank == 0: - print( - "torch distributed is already initialized, " - "skipping initialization ...", - flush=True, - ) - args.rank = torch.distributed.get_rank() - args.world_size = torch.distributed.get_world_size() - - else: - - if args.rank == 0: - print("> initializing torch distributed ...", flush=True) - # Manually set the device ids. - if device_count > 0: - device = args.rank % device_count - if args.local_rank is not None: - assert ( - args.local_rank == device - ), "expected local-rank to be the same as rank % device-count." - else: - args.local_rank = device - torch.cuda.set_device(device) - # Call the init process - torch.distributed.init_process_group( - backend=args.distributed_backend, - world_size=args.world_size, - rank=args.rank, - timeout=timedelta(minutes=args.distributed_timeout_minutes), - ) - - # Set the tensor model-parallel, pipeline model-parallel, and - # data-parallel communicators. 
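The `_initialize_distributed` helper shown above follows the common pattern of mapping each rank to a local device before calling `init_process_group`. A standalone sketch of that pattern (this is not the Megatron entry point; the `gloo` backend and env:// rendezvous via torchrun-style environment variables are assumptions made here so it runs on CPU):

```python
# Illustrative sketch of the rank -> device mapping plus process-group init
# pattern from the removed _initialize_distributed. Assumes RANK, WORLD_SIZE,
# MASTER_ADDR and MASTER_PORT are set by the launcher (e.g. torchrun).
import os
from datetime import timedelta

import torch
import torch.distributed as dist


def init_distributed(backend: str = "gloo", timeout_minutes: int = 10) -> int:
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    # Pin each rank to a local device when accelerators are present.
    device_count = torch.cuda.device_count()
    local_rank = rank % device_count if device_count > 0 else -1
    if local_rank >= 0:
        torch.cuda.set_device(local_rank)

    if not dist.is_initialized():
        dist.init_process_group(
            backend=backend,
            world_size=world_size,
            rank=rank,
            timeout=timedelta(minutes=timeout_minutes),
        )
    return local_rank


if __name__ == "__main__":
    local_rank = init_distributed()
    print(f"rank {dist.get_rank()} of {dist.get_world_size()}, local device {local_rank}")
```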
- if device_count > 0: - if mpu.model_parallel_is_initialized(): - print("model parallel is already initialized") - else: - mpu.initialize_model_parallel( - args.tensor_model_parallel_size, - args.pipeline_model_parallel_size, - args.virtual_pipeline_model_parallel_size, - args.pipeline_model_parallel_split_rank, - context_parallel_size=args.context_parallel_size, - expert_model_parallel_size=args.expert_model_parallel_size, - nccl_communicator_config_path=args.nccl_communicator_config_path, - ) - if args.rank == 0: - print( - f"> initialized tensor model parallel with size " - f"{mpu.get_tensor_model_parallel_world_size()}" - ) - print( - f"> initialized pipeline model parallel with size " - f"{mpu.get_pipeline_model_parallel_world_size()}" - ) - - -def _init_autoresume(): - """Set autoresume start time.""" - autoresume = get_adlr_autoresume() - if autoresume: - torch.distributed.barrier() - autoresume.init() - torch.distributed.barrier() - - -def _set_random_seed(seed_, data_parallel_random_init=False): - """Set random seed for reproducability.""" - if seed_ is not None and seed_ > 0: - # Ensure that different pipeline MP stages get different seeds. - seed = seed_ + (100 * mpu.get_pipeline_model_parallel_rank()) - # Ensure different data parallel ranks get different seeds - if data_parallel_random_init: - seed = seed + (10 * mpu.get_data_parallel_rank()) - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - if torch.cuda.device_count() > 0: - tensor_parallel.model_parallel_cuda_manual_seed(seed) - else: - raise ValueError("Seed ({}) should be a positive integer.".format(seed)) - - -def write_args_to_tensorboard(): - """Write arguments to tensorboard.""" - args = get_args() - writer = get_tensorboard_writer() - if writer: - for arg in vars(args): - writer.add_text(arg, str(getattr(args, arg)), global_step=args.iteration) - - -def set_jit_fusion_options(): - """Set PyTorch JIT layer fusion options.""" - # flags required to enable jit fusion kernels - TORCH_MAJOR = int(torch.__version__.split(".")[0]) - TORCH_MINOR = int(torch.__version__.split(".")[1]) - if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10): - # nvfuser - torch._C._jit_set_profiling_executor(True) - torch._C._jit_set_profiling_mode(True) - torch._C._jit_override_can_fuse_on_cpu(False) - torch._C._jit_override_can_fuse_on_gpu(False) - torch._C._jit_set_texpr_fuser_enabled(False) - torch._C._jit_set_nvfuser_enabled(True) - torch._C._debug_set_autodiff_subgraph_inlining(False) - else: - # legacy pytorch fuser - torch._C._jit_set_profiling_mode(False) - torch._C._jit_set_profiling_executor(False) - torch._C._jit_override_can_fuse_on_cpu(True) - torch._C._jit_override_can_fuse_on_gpu(True) - - _warmup_jit_function() - - -def _warmup_jit_function(): - """Compilie JIT functions before the main training steps""" - args = get_args() - if args.bf16: - dtype = torch.bfloat16 - elif args.fp16: - dtype = torch.float16 - else: - dtype = torch.float32 - - # Warmup fused bias+gelu - bias = torch.rand( - args.ffn_hidden_size // args.tensor_model_parallel_size, - dtype=dtype, - device="cuda", - ) - input = torch.rand( - ( - args.seq_length, - args.micro_batch_size, - args.ffn_hidden_size // args.tensor_model_parallel_size, - ), - dtype=dtype, - device="cuda", - ) - # Warmup JIT fusions with the input grad_enable state of both forward - # prop and recomputation - for bias_grad, input_grad in zip([True, True], [False, True]): - bias.requires_grad, input.requires_grad = bias_grad, input_grad - for _ in range(5): - 
output = bias_gelu(bias, input) - del bias, input, output - - # Warmup fused bias+dropout+add - if args.sequence_parallel: - seq_length = args.seq_length // mpu.get_tensor_model_parallel_world_size() - else: - seq_length = args.seq_length - input = torch.rand( - (seq_length, args.micro_batch_size, args.hidden_size), - dtype=dtype, - device="cuda", - ) - residual = torch.rand( - (seq_length, args.micro_batch_size, args.hidden_size), - dtype=dtype, - device="cuda", - ) - bias = torch.rand((args.hidden_size), dtype=dtype, device="cuda").expand_as( - residual - ) - dropout_rate = 0.1 - # Warmup JIT fusions with the input grad_enable state of both forward - # prop and recomputation - for input_grad, bias_grad, residual_grad in zip( - [False, True], [True, True], [True, True] - ): - input.requires_grad = input_grad - bias.requires_grad = bias_grad - residual.requires_grad = residual_grad - for _ in range(5): - output = bias_dropout_add_fused_train(input, bias, residual, dropout_rate) - del bias, input, residual, output - torch.cuda.empty_cache() diff --git a/megatron/log_handler.py b/megatron/log_handler.py deleted file mode 100644 index 06f5d1842d1d8bb89ca78633854ce4d910761f1a..0000000000000000000000000000000000000000 --- a/megatron/log_handler.py +++ /dev/null @@ -1,24 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - -import sys -from logging import LogRecord, StreamHandler - -BLACKLISTED_MODULES = ["torch.distributed"] - - -class CustomHandler(StreamHandler): - """ - Custom handler to filter out logging from code outside of - Megatron Core, and dump to stdout. - """ - - def __init__(self): - super().__init__(stream=sys.stdout) - - def filter(self, record: LogRecord) -> bool: - # Prevent log entries that come from the blacklisted modules - # through (e.g., PyTorch Distributed). - for blacklisted_module in BLACKLISTED_MODULES: - if record.name.startswith(blacklisted_module): - return False - return True diff --git a/megatron/memory.py b/megatron/memory.py deleted file mode 100644 index a5fef75baa749d557da227bbccf706501ffdd10f..0000000000000000000000000000000000000000 --- a/megatron/memory.py +++ /dev/null @@ -1,132 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. - - -import torch - - -# A dictionary of all the memory buffers allocated. -_MEM_BUFFS = dict() - - -def allocate_mem_buff(name, numel, dtype, track_usage): - """Allocate a memory buffer.""" - assert name not in _MEM_BUFFS, \ - 'memory buffer {} already allocated.'.format(name) - _MEM_BUFFS[name] = MemoryBuffer(name, numel, dtype, track_usage) - return _MEM_BUFFS[name] - - -def get_mem_buff(name): - """Get the memory buffer.""" - return _MEM_BUFFS[name] - - -class MemoryBuffer: - """Contiguous memory buffer. - Allocate a contiguous memory of type `dtype` and size `numel`. It is - used to reduce memory fragmentation. - - Usage: After the allocation, the `_start` index is set tot the first - index of the memory. A memory chunk starting from `_start` index - can be `allocated` for an input tensor, with the elements of the - tensor being coppied. The buffer can be reused by resetting the - `_start` index. 
- - """ - def __init__(self, name, numel, dtype, track_usage): - if torch.distributed.get_rank() == 0: - element_size = torch.tensor([], dtype=dtype).element_size() - print('> building the {} memory buffer with {} num elements ' - 'and {} dtype ({:.1f} MB)...'.format( - name, numel, dtype, numel*element_size/1024/1024), - flush=True) - self.name = name - self.numel = numel - self.dtype = dtype - self.data = torch.empty(self.numel, - dtype=self.dtype, - device=torch.cuda.current_device(), - requires_grad=False) - - # Index tracking the start of the free memory. - self._start = 0 - - # Values used for tracking usage. - self.track_usage = track_usage - if self.track_usage: - self.in_use_value = 0.0 - self.total_value = 0.0 - - - def reset(self): - """Reset the buffer start index to the beginning of the buffer.""" - self._start = 0 - - - def is_in_use(self): - """Whether the current buffer hold on to any memory.""" - return self._start > 0 - - - def numel_in_use(self): - """Return number of elements in use.""" - return self._start - - - def add(self, tensor): - """Allocate a chunk of memory from the buffer to tensor and copy - the values.""" - assert tensor.dtype == self.dtype, \ - 'Input tensor type {} different from buffer type {}'.format( - tensor.dtype, self.dtype) - # Number of elements of the input tensor. - tensor_numel = torch.numel(tensor) - new_start = self._start + tensor_numel - assert new_start <= self.numel, \ - 'Not enough memory left in the buffer ({} > {})'.format( - tensor_numel, self.numel - self._start) - # New tensor is a view into the memory. - new_tensor = self.data[self._start:new_start] - self._start = new_start - new_tensor = new_tensor.view(tensor.shape) - new_tensor.copy_(tensor) - # Return a pointer to the new tensor. - return new_tensor - - - def get_data(self): - """Return the data currently in use.""" - if self.track_usage: - self.in_use_value += float(self._start) - self.total_value += float(self.numel) - return self.data[:self._start] - - - def print_average_usage(self): - """Print memory usage average over time. We would like this value - to be as high as possible.""" - assert self.track_usage, 'You need to enable track usage.' - if torch.distributed.get_rank() == 0: - print(' > usage of {} memory buffer: {:.2f} %'.format( - self.name, self.in_use_value * 100.0 / self.total_value), - flush=True) - - - -class RingMemBuffer: - """A ring of memory buffers.""" - - def __init__(self, name, num_buffers, numel, dtype, track_usage): - self.num_buffers = num_buffers - self.buffers = [ - allocate_mem_buff(name+' {}'.format(i), numel, dtype, track_usage) - for i in range(num_buffers)] - self._index = -1 - - - def get_next_buffer(self): - self._index += 1 - self._index = self._index % self.num_buffers - buff = self.buffers[self._index] - assert not buff.is_in_use(), 'buffer is already in use.' - return buff diff --git a/megatron/microbatches.py b/megatron/microbatches.py deleted file mode 100644 index 6449d7479c9c983b4813889ee8f1beec9e027cc3..0000000000000000000000000000000000000000 --- a/megatron/microbatches.py +++ /dev/null @@ -1,144 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. - -"""Megatron number of micro-batches calculators.""" - -from abc import ABC -from abc import abstractmethod - - -def build_num_microbatches_calculator(args): - - # Constant num micro-batches. 
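To make the contract of the removed `MemoryBuffer` concrete: tensors are copied into one pre-allocated contiguous region, handed back as views, and the whole region is recycled with `reset()`. A small usage sketch under that contract (a standalone re-implementation of just the pieces exercised, on CPU; it does not import the deleted module):

```python
# Illustrative sketch of the contiguous-buffer pattern behind the removed
# MemoryBuffer: copy tensors into one allocation, return views into it, and
# reuse the storage by resetting the start index. The original allocates on
# the current CUDA device and tracks usage statistics.
import torch


class TinyMemoryBuffer:
    def __init__(self, numel: int, dtype: torch.dtype = torch.float32):
        self.data = torch.empty(numel, dtype=dtype)
        self._start = 0

    def add(self, tensor: torch.Tensor) -> torch.Tensor:
        n = tensor.numel()
        assert self._start + n <= self.data.numel(), "not enough memory left in the buffer"
        view = self.data[self._start:self._start + n].view(tensor.shape)
        view.copy_(tensor)
        self._start += n
        return view

    def reset(self) -> None:
        self._start = 0


if __name__ == "__main__":
    buf = TinyMemoryBuffer(numel=1024)
    a = buf.add(torch.ones(4, 8))   # view into the first 32 elements
    b = buf.add(torch.zeros(16))    # next 16 elements, no new allocation
    assert a.data_ptr() == buf.data.data_ptr()
    buf.reset()                     # recycle the same storage for the next step
    print(a.shape, b.shape)
```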
- if args.rampup_batch_size is None: - num_microbatches_calculator = ConstantNumMicroBatches( - args.global_batch_size, args.micro_batch_size, - args.data_parallel_size) - if args.rank == 0: - print('setting number of micro-batches to constant {}'.format( - num_microbatches_calculator.get()), flush=True) - - else: - assert len(args.rampup_batch_size) == 3, 'expected the following ' \ - 'format: --rampup-batch-size ' \ - ' ' - start_batch_size = int(args.rampup_batch_size[0]) - batch_size_increment = int(args.rampup_batch_size[1]) - ramup_samples = int(args.rampup_batch_size[2]) - if args.rank == 0: - print('will use batch size rampup starting from global batch ' - 'size {} to global batch size {} with batch size increments ' - '{} over {} samples.'.format(start_batch_size, - args.global_batch_size, - batch_size_increment, - ramup_samples), flush=True) - num_microbatches_calculator = RampupBatchsizeNumMicroBatches( - start_batch_size, batch_size_increment, ramup_samples, - args.global_batch_size, args.micro_batch_size, - args.data_parallel_size) - - return num_microbatches_calculator - - -class NumMicroBatchesCalculator(ABC): - - def __init__(self): - self.num_micro_batches = None - self.current_global_batch_size = None - - def get(self): - return self.num_micro_batches - - def get_current_global_batch_size(self): - return self.current_global_batch_size - - @abstractmethod - def update(self, consumed_samples, consistency_check): - pass - - -class ConstantNumMicroBatches(NumMicroBatchesCalculator): - - def __init__(self, global_batch_size, micro_batch_size, data_parallel_size): - micro_batch_times_data_parallel = micro_batch_size * \ - data_parallel_size - assert global_batch_size % micro_batch_times_data_parallel == 0, \ - 'global batch size ({}) is not divisible by micro batch size ({})' \ - ' times data parallel size ({})'.format(global_batch_size, - micro_batch_size, - data_parallel_size) - self.num_micro_batches = global_batch_size // \ - micro_batch_times_data_parallel - assert self.num_micro_batches >= 1 - self.current_global_batch_size = global_batch_size - - def update(self, consumed_samples, consistency_check): - pass - - -class RampupBatchsizeNumMicroBatches(NumMicroBatchesCalculator): - - def __init__(self, start_batch_size, batch_size_increment, ramup_samples, - global_batch_size, micro_batch_size, data_parallel_size): - """Batch size ramp up. - Over - steps = (global-batch-size - start-batch-size) / batch_size_increment - increment batch size from start-batch-size to global-batch-size using - rampup-samples / steps - samples. - Arguments: - start_batch_size: global batch size to start with - batch_size_increment: global batch size increments - ramup_samples: number of samples to use ramp up global - batch size from `start_batch_size` to `global_batch_size` - global_batch_size: global batch size post rampup - micro_batch_size: micro batch size - data_parallel_size: data parallel size. 
- """ - - self.micro_batch_size = micro_batch_size - self.data_parallel_size = data_parallel_size - self.micro_batch_times_data_parallel_size = self.micro_batch_size * \ - self.data_parallel_size - assert self.micro_batch_times_data_parallel_size > 0 - - assert start_batch_size > 0 - self.start_batch_size = start_batch_size - - assert global_batch_size > 0 - self.global_batch_size = global_batch_size - diff_batch_size = self.global_batch_size - self.start_batch_size - assert diff_batch_size >= 0 - assert batch_size_increment > 0 - self.batch_size_increment = batch_size_increment - assert diff_batch_size % batch_size_increment == 0, 'expected ' \ - 'global batch size interval ({}) to be divisible by global batch ' \ - 'size increment ({})'.format(diff_batch_size, batch_size_increment) - - num_increments = diff_batch_size // self.batch_size_increment - self.ramup_samples = ramup_samples - assert self.ramup_samples >= 0 - self.rampup_samples_per_increment = self.ramup_samples / num_increments - - # Initialize number of microbatches. - self.update(0, False) - - - def update(self, consumed_samples, consistency_check): - - if consumed_samples > self.ramup_samples: - self.current_global_batch_size = self.global_batch_size - else: - steps = int(consumed_samples / self.rampup_samples_per_increment) - self.current_global_batch_size = self.start_batch_size + \ - steps * self.batch_size_increment - assert self.current_global_batch_size <= self.global_batch_size - - if consistency_check: - assert self.current_global_batch_size % \ - self.micro_batch_times_data_parallel_size == 0, 'current global ' \ - 'batch size ({}) is not divisible by micro-batch-size ({}) times' \ - 'data parallel size ({})'.format(self.current_global_batch_size, - self.micro_batch_size, - self.data_parallel_size) - self.num_micro_batches = self.current_global_batch_size // \ - self.micro_batch_times_data_parallel_size diff --git a/megatron/model/__init__.py b/megatron/model/__init__.py deleted file mode 100644 index cb010e5fb6c318ae849ad647d8f6d4ee4e309931..0000000000000000000000000000000000000000 --- a/megatron/model/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - -from .fused_layer_norm import MixedFusedLayerNorm as LayerNorm -from .rms_norm import RMSNorm - -from .bert_model import BertModel -from .gpt_model import GPTModel -from .t5_model import T5Model -from .language_model import get_language_model -from .module import Float16Module diff --git a/megatron/model/bert_model.py b/megatron/model/bert_model.py deleted file mode 100644 index cd4bb35db725faebb2304305f207b1afd777d412..0000000000000000000000000000000000000000 --- a/megatron/model/bert_model.py +++ /dev/null @@ -1,257 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - -"""BERT model.""" - -import torch - -from megatron import get_args -from megatron.core import tensor_parallel -from megatron.model.enums import AttnMaskType -from megatron.model.language_model import parallel_lm_logits -from megatron.model.language_model import get_language_model -from megatron.model.utils import get_norm -from megatron.model.utils import openai_gelu, erf_gelu -from megatron.model.utils import get_linear_layer -from megatron.model.utils import init_method_normal -from megatron.model.utils import scaled_init_method_normal -from .module import MegatronModule - - -def bert_extended_attention_mask(attention_mask): - # We create a 3D attention mask from a 2D tensor mask. 
- # [b, 1, s] - attention_mask_b1s = attention_mask.unsqueeze(1) - # [b, s, 1] - attention_mask_bs1 = attention_mask.unsqueeze(2) - # [b, s, s] - attention_mask_bss = attention_mask_b1s * attention_mask_bs1 - # [b, 1, s, s] - extended_attention_mask = attention_mask_bss.unsqueeze(1) - - # Convert attention mask to binary: - extended_attention_mask = (extended_attention_mask < 0.5) - - return extended_attention_mask - -def bert_position_ids(token_ids): - # Create position ids - seq_length = token_ids.size(1) - position_ids = torch.arange(seq_length, dtype=torch.long, - device=token_ids.device) - position_ids = position_ids.unsqueeze(0).expand_as(token_ids) - - return position_ids - - -class BertLMHead(MegatronModule): - """Masked LM head for Bert - - Arguments: - config: TransformerConfig object - mpu_vocab_size: model parallel size of vocabulary. - parallel_output: whether output logits being distributed or not. - """ - - def __init__(self, mpu_vocab_size, config, parallel_output): - super().__init__(config=config) - - args = get_args() - self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size)) - tensor_parallel.set_tensor_model_parallel_attributes(self.bias, True, 0, 1) - self.parallel_output = parallel_output - - self.dense = get_linear_layer(config.hidden_size, config.hidden_size, config.init_method) - setattr(self.dense.weight, 'sequence_parallel', config.sequence_parallel) - setattr(self.dense.bias, 'sequence_parallel', config.sequence_parallel) - - self.norm = get_norm(config) - self.gelu = torch.nn.functional.gelu - if args.openai_gelu: - self.gelu = openai_gelu - elif args.onnx_safe: - self.gelu = erf_gelu - - def forward(self, hidden_states, word_embeddings_weight): - hidden_states = self.dense(hidden_states) - hidden_states = self.gelu(hidden_states) - hidden_states = self.norm(hidden_states) - output = parallel_lm_logits(hidden_states, - word_embeddings_weight, - self.parallel_output, - bias=self.bias) - return output - - def load_state_dict(self, state_dict, strict=True): - """Customize load.""" - - # Handle renaming layernorm -> norm in component names - state_dict_ = {} - for key in state_dict.keys(): - newkey = key.replace("layernorm", "norm") - state_dict_[newkey] = state_dict[key] - - super().load_state_dict(state_dict_, strict) - - -def post_language_model_processing(lm_output, pooled_output, - lm_head, binary_head, - lm_labels, - logit_weights, - fp16_lm_cross_entropy): - # Output. 
- lm_logits = lm_head( - lm_output, logit_weights) - - binary_logits = None - if binary_head is not None: - binary_logits = binary_head(pooled_output) - - if lm_labels is None: - # [s b h] => [b s h] - return lm_logits.transpose(0,1).contiguous(), binary_logits - else: - # [b s] => [s b] - lm_labels = lm_labels.transpose(0,1).contiguous() - # lm_logits : [s, b, h] and lm_labels: [s, b] - if fp16_lm_cross_entropy: - assert lm_logits.dtype == torch.half - lm_loss = tensor_parallel.vocab_parallel_cross_entropy(lm_logits, lm_labels) - else: - lm_loss = tensor_parallel.vocab_parallel_cross_entropy(lm_logits.float(), - lm_labels) - # [s, b] => [b s] - lm_loss = lm_loss.transpose(0,1).contiguous() - return lm_loss, binary_logits - - -class BertModel(MegatronModule): - """Bert Language model.""" - - def __init__(self, - config, - num_tokentypes=2, - add_binary_head=True, - parallel_output=True, - pre_process=True, - post_process=True): - super().__init__(config=config) - args = get_args() - - # TODO this option is not yet implemented in BERT - assert args.untie_embeddings_and_output_weights is False - - self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy - self.add_binary_head = add_binary_head - self.parallel_output = parallel_output - self.pre_process = pre_process - self.post_process = post_process - - self.return_embeddings = args.output_bert_embeddings - if self.return_embeddings: - assert self.post_process and self.add_binary_head - - self.language_model, self._language_model_key = get_language_model( - config=config, - num_tokentypes=num_tokentypes, - add_pooler=self.add_binary_head, - encoder_attn_mask_type=AttnMaskType.padding, - pre_process=self.pre_process, - post_process=self.post_process) - - self.initialize_word_embeddings() - if self.post_process: - self.lm_head = BertLMHead(self.shared_embedding_or_output_weight().size(0), config, parallel_output) - self._lm_head_key = 'lm_head' - self.binary_head = None - if self.add_binary_head: - self.binary_head = get_linear_layer(config.hidden_size, 2, - config.init_method) - self._binary_head_key = 'binary_head' - - def set_input_tensor(self, input_tensor): - """See megatron.model.transformer.set_input_tensor()""" - self.language_model.set_input_tensor(input_tensor) - - def forward(self, bert_model_input, attention_mask, - tokentype_ids=None, lm_labels=None): - - extended_attention_mask = bert_extended_attention_mask(attention_mask) - input_ids = bert_model_input - position_ids = bert_position_ids(input_ids) - - lm_output = self.language_model( - input_ids, - position_ids, - extended_attention_mask, - tokentype_ids=tokentype_ids - ) - - if self.post_process and self.add_binary_head: - lm_output, pooled_output = lm_output - - # Return pooled output (e.g., when computing Bert embeddings). - if self.return_embeddings: - - # Sum attention mask. - embeddings = torch.transpose(lm_output, 0, 1) - masks = torch.sum(attention_mask, dim=1) - - # Collect masked embeddings. 
- output = torch.zeros( - size=(embeddings.shape[0], embeddings.shape[2]), - dtype=torch.float32, - device=torch.cuda.current_device()) - for i, (embedding, mask) in enumerate(zip(embeddings, masks)): - output[i, :] = torch.mean(embedding[1: mask - 1], dim=0) - - return output - - else: - pooled_output = None - - if self.post_process: - return post_language_model_processing(lm_output, pooled_output, - self.lm_head, self.binary_head, - lm_labels, - self.shared_embedding_or_output_weight(), - self.fp16_lm_cross_entropy) - else: - return lm_output - - - def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False): - """For easy load when model is combined with other heads, - add an extra key.""" - - state_dict_ = {} - state_dict_[self._language_model_key] \ - = self.language_model.state_dict_for_save_checkpoint(prefix=prefix, - keep_vars=keep_vars) - if self.post_process: - state_dict_[self._lm_head_key] \ - = self.lm_head.state_dict_for_save_checkpoint(prefix=prefix, - keep_vars=keep_vars) - if self.post_process and self.add_binary_head: - state_dict_[self._binary_head_key] \ - = self.binary_head.state_dict(prefix=prefix, keep_vars=keep_vars) - # Save word_embeddings. - if self.post_process and not self.pre_process: - state_dict_[self._word_embeddings_for_head_key] \ - = self.word_embeddings.state_dict(prefix=prefix, keep_vars=keep_vars) - return state_dict_ - - def load_state_dict(self, state_dict, strict=True): - """Customized load.""" - - self.language_model.load_state_dict( - state_dict[self._language_model_key], strict=strict) - if self.post_process: - self.lm_head.load_state_dict( - state_dict[self._lm_head_key], strict=strict) - if self.post_process and self.add_binary_head: - self.binary_head.load_state_dict( - state_dict[self._binary_head_key], strict=strict) - # Load word_embeddings. 
- if self.post_process and not self.pre_process: - self.word_embeddings.load_state_dict( - state_dict[self._word_embeddings_for_head_key], strict=strict) diff --git a/megatron/model/biencoder_model.py b/megatron/model/biencoder_model.py deleted file mode 100644 index c910879dc8dc48b96273683996f608453dc0729c..0000000000000000000000000000000000000000 --- a/megatron/model/biencoder_model.py +++ /dev/null @@ -1,328 +0,0 @@ -import os -import torch -import sys - -from megatron import get_args, print_rank_0, get_tokenizer -from megatron.core import mpu -from megatron.checkpointing import fix_query_key_value_ordering -from megatron.checkpointing import get_checkpoint_tracker_filename -from megatron.checkpointing import get_checkpoint_name -from megatron.model.bert_model import bert_position_ids -from megatron.model.enums import AttnMaskType -from megatron.model.language_model import get_language_model -from megatron.model.utils import get_linear_layer -from megatron.model.utils import init_method_normal -from megatron.model.utils import scaled_init_method_normal -from .module import MegatronModule - -def get_model_provider(only_query_model=False, only_context_model=False, - biencoder_shared_query_context_model=False): - - def model_provider(pre_process=True, post_process=True): - """Build the model.""" - - print_rank_0('building Bienoder model ...') - model = biencoder_model_provider(only_query_model=only_query_model, - only_context_model = only_context_model, - biencoder_shared_query_context_model = \ - biencoder_shared_query_context_model, - pre_process=pre_process, post_process=post_process) - - return model - - return model_provider - - -def biencoder_model_provider(only_query_model=False, - only_context_model=False, - biencoder_shared_query_context_model=False, - pre_process=True, - post_process=True): - """Build the model.""" - - assert mpu.get_tensor_model_parallel_world_size() == 1 and \ - mpu.get_pipeline_model_parallel_world_size() == 1, \ - "Model parallel size > 1 not supported for ICT" - - print_rank_0('building BiEncoderModel...') - - # simpler to just keep using 2 tokentypes since - # the LM we initialize with has 2 tokentypes - model = BiEncoderModel( - num_tokentypes=2, - parallel_output=False, - only_query_model=only_query_model, - only_context_model=only_context_model, - biencoder_shared_query_context_model=\ - biencoder_shared_query_context_model, - pre_process=pre_process, - post_process=post_process) - - return model - - -class BiEncoderModel(MegatronModule): - """Bert-based module for Biencoder model.""" - - def __init__(self, - num_tokentypes=1, - parallel_output=True, - only_query_model=False, - only_context_model=False, - biencoder_shared_query_context_model=False, - pre_process=True, - post_process=True): - super(BiEncoderModel, self).__init__() - args = get_args() - - bert_kwargs = dict( - num_tokentypes=num_tokentypes, - parallel_output=parallel_output, - pre_process=pre_process, - post_process=post_process) - - self.biencoder_shared_query_context_model = \ - biencoder_shared_query_context_model - assert not (only_context_model and only_query_model) - self.use_context_model = not only_query_model - self.use_query_model = not only_context_model - self.biencoder_projection_dim = args.biencoder_projection_dim - - if self.biencoder_shared_query_context_model: - self.model = PretrainedBertModel(**bert_kwargs) - self._model_key = 'shared_model' - self.query_model, self.context_model = self.model, self.model - else: - if self.use_query_model: - # this model embeds 
(pseudo-)queries - Embed_input in the paper - self.query_model = PretrainedBertModel(**bert_kwargs) - self._query_key = 'query_model' - - if self.use_context_model: - # this model embeds evidence blocks - Embed_doc in the paper - self.context_model = PretrainedBertModel(**bert_kwargs) - self._context_key = 'context_model' - - def set_input_tensor(self, input_tensor): - """See megatron.model.transformer.set_input_tensor()""" - # this is just a placeholder and will be needed when model - # parallelism will be used - # self.language_model.set_input_tensor(input_tensor) - return - - def forward(self, query_tokens, query_attention_mask, query_types, - context_tokens, context_attention_mask, context_types): - """Run a forward pass for each of the models and - return the respective embeddings.""" - - if self.use_query_model: - query_logits = self.embed_text(self.query_model, - query_tokens, - query_attention_mask, - query_types) - else: - raise ValueError("Cannot embed query without the query model.") - if self.use_context_model: - context_logits = self.embed_text(self.context_model, - context_tokens, - context_attention_mask, - context_types) - else: - raise ValueError("Cannot embed block without the block model.") - return query_logits, context_logits - - @staticmethod - def embed_text(model, tokens, attention_mask, token_types): - """Embed a batch of tokens using the model""" - logits = model(tokens, - attention_mask, - token_types) - return logits - - def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False): - """Save dict with state dicts of each of the models.""" - state_dict_ = {} - if self.biencoder_shared_query_context_model: - state_dict_[self._model_key] = \ - self.model.state_dict_for_save_checkpoint( - prefix=prefix, keep_vars=keep_vars) - else: - if self.use_query_model: - state_dict_[self._query_key] = \ - self.query_model.state_dict_for_save_checkpoint( - prefix=prefix, keep_vars=keep_vars) - - if self.use_context_model: - state_dict_[self._context_key] = \ - self.context_model.state_dict_for_save_checkpoint( - prefix=prefix, keep_vars=keep_vars) - - return state_dict_ - - def load_state_dict(self, state_dict, strict=True): - """Load the state dicts of each of the models""" - if self.biencoder_shared_query_context_model: - print_rank_0("Loading shared query-context model") - self.model.load_state_dict(state_dict[self._model_key], \ - strict=strict) - else: - if self.use_query_model: - print_rank_0("Loading query model") - self.query_model.load_state_dict( \ - state_dict[self._query_key], strict=strict) - - if self.use_context_model: - print_rank_0("Loading context model") - self.context_model.load_state_dict( \ - state_dict[self._context_key], strict=strict) - - def init_state_dict_from_bert(self): - """Initialize the state from a pretrained BERT model - on iteration zero of ICT pretraining""" - args = get_args() - - if args.bert_load is None: - print_rank_0("bert-load argument is None") - return - - tracker_filename = get_checkpoint_tracker_filename(args.bert_load) - if not os.path.isfile(tracker_filename): - raise FileNotFoundError("Could not find BERT checkpoint") - with open(tracker_filename, 'r') as f: - iteration = int(f.read().strip()) - assert iteration > 0 - - checkpoint_name = get_checkpoint_name(args.bert_load, iteration, False) - if mpu.get_data_parallel_rank() == 0: - print('global rank {} is loading BERT checkpoint {}'.format( - torch.distributed.get_rank(), checkpoint_name)) - - # Load the checkpoint. 
- try: - state_dict = torch.load(checkpoint_name, map_location='cpu') - except ModuleNotFoundError: - from megatron.fp16_deprecated import loss_scaler - # For backward compatibility. - print_rank_0(' > deserializing using the old code structure ...') - sys.modules['fp16.loss_scaler'] = sys.modules[ - 'megatron.fp16_deprecated.loss_scaler'] - sys.modules['megatron.fp16.loss_scaler'] = sys.modules[ - 'megatron.fp16_deprecated.loss_scaler'] - state_dict = torch.load(checkpoint_name, map_location='cpu') - sys.modules.pop('fp16.loss_scaler', None) - sys.modules.pop('megatron.fp16.loss_scaler', None) - except BaseException: - print_rank_0('could not load the BERT checkpoint') - sys.exit() - - checkpoint_version = state_dict.get('checkpoint_version', 0) - - # load the LM state dict into each model - model_dict = state_dict['model']['language_model'] - - if self.biencoder_shared_query_context_model: - self.model.language_model.load_state_dict(model_dict) - fix_query_key_value_ordering(self.model, checkpoint_version) - else: - if self.use_query_model: - self.query_model.language_model.load_state_dict(model_dict) - # give each model the same ict_head to begin with as well - if self.biencoder_projection_dim > 0: - query_proj_state_dict = \ - self.state_dict_for_save_checkpoint()\ - [self._query_key]['projection_enc'] - fix_query_key_value_ordering(self.query_model, checkpoint_version) - - if self.use_context_model: - self.context_model.language_model.load_state_dict(model_dict) - if self.query_model is not None and \ - self.biencoder_projection_dim > 0: - self.context_model.projection_enc.load_state_dict\ - (query_proj_state_dict) - fix_query_key_value_ordering(self.context_model, checkpoint_version) - - -class PretrainedBertModel(MegatronModule): - """BERT-based encoder for queries or contexts used for - learned information retrieval.""" - - def __init__(self, num_tokentypes=2, - parallel_output=True, pre_process=True, post_process=True): - super(PretrainedBertModel, self).__init__() - - args = get_args() - tokenizer = get_tokenizer() - self.pad_id = tokenizer.pad - self.biencoder_projection_dim = args.biencoder_projection_dim - self.parallel_output = parallel_output - self.pre_process = pre_process - self.post_process = post_process - init_method = init_method_normal(args.init_method_std) - scaled_init_method = scaled_init_method_normal( - args.init_method_std, args.num_layers) - - self.language_model, self._language_model_key = get_language_model( - num_tokentypes=num_tokentypes, - add_pooler=False, - encoder_attn_mask_type=AttnMaskType.padding, - init_method=init_method, - scaled_init_method=scaled_init_method, - pre_process=self.pre_process, - post_process=self.post_process) - - if args.biencoder_projection_dim > 0: - self.projection_enc = get_linear_layer(args.hidden_size, - args.biencoder_projection_dim, - init_method) - self._projection_enc_key = 'projection_enc' - - def forward(self, input_ids, attention_mask, tokentype_ids=None): - extended_attention_mask = attention_mask.unsqueeze(1) - #extended_attention_mask = bert_extended_attention_mask(attention_mask) - position_ids = bert_position_ids(input_ids) - - lm_output = self.language_model(input_ids, - position_ids, - extended_attention_mask, - tokentype_ids=tokentype_ids) - # This mask will be used in average-pooling and max-pooling - pool_mask = (input_ids == self.pad_id).unsqueeze(2) - - # Taking the representation of the [CLS] token of BERT - pooled_output = lm_output[0, :, :] - - # Converting to float16 dtype - pooled_output = 
pooled_output.to(lm_output.dtype) - - # Output. - if self.biencoder_projection_dim: - pooled_output = self.projection_enc(pooled_output) - - return pooled_output - - def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False): - """For easy load when model is combined with other heads, - add an extra key.""" - - state_dict_ = {} - state_dict_[self._language_model_key] \ - = self.language_model.state_dict_for_save_checkpoint( - prefix=prefix, keep_vars=keep_vars) - - if self.biencoder_projection_dim > 0: - state_dict_[self._projection_enc_key] = \ - self.projection_enc.state_dict(prefix=prefix, - keep_vars=keep_vars) - - return state_dict_ - - def load_state_dict(self, state_dict, strict=True): - """Customized load.""" - print_rank_0("loading pretrained weights") - self.language_model.load_state_dict( - state_dict[self._language_model_key], strict=strict) - - if self.biencoder_projection_dim > 0: - print_rank_0("loading projection head weights") - self.projection_enc.load_state_dict( - state_dict[self._projection_enc_key], strict=strict) diff --git a/megatron/model/classification.py b/megatron/model/classification.py deleted file mode 100644 index bac50c54cdf9981057a4844a3a8a07d365590b2a..0000000000000000000000000000000000000000 --- a/megatron/model/classification.py +++ /dev/null @@ -1,101 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. - -"""Classification model.""" - -import torch - -from megatron import get_args, print_rank_last -from megatron.model.enums import AttnMaskType -from megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids -from megatron.model.language_model import get_language_model -from megatron.model.utils import get_linear_layer -from megatron.model.utils import init_method_normal -from megatron.model.utils import scaled_init_method_normal -from .module import MegatronModule - - -class Classification(MegatronModule): - - def __init__(self, - config, - num_classes, - num_tokentypes=2, - pre_process=True, - post_process=True): - super().__init__(config=config, share_embeddings_and_output_weights=False) - args = get_args() - - self.num_classes = num_classes - self.pre_process = pre_process - self.post_process = post_process - - self.language_model, self._language_model_key = get_language_model( - config=config, - num_tokentypes=num_tokentypes, - add_pooler=True, - encoder_attn_mask_type=AttnMaskType.padding, - pre_process=self.pre_process, - post_process=self.post_process) - - # Multi-choice head. - if self.post_process: - self.classification_dropout = torch.nn.Dropout(args.hidden_dropout) - self.classification_head = get_linear_layer(args.hidden_size, - self.num_classes, - init_method) - self._classification_head_key = 'classification_head' - - def set_input_tensor(self, input_tensor): - """See megatron.model.transformer.set_input_tensor()""" - self.language_model.set_input_tensor(input_tensor) - - def forward(self, model_input, attention_mask, tokentype_ids=None): - - extended_attention_mask = bert_extended_attention_mask(attention_mask) - input_ids = model_input - position_ids = bert_position_ids(input_ids) - - lm_output = self.language_model( - input_ids, - position_ids, - extended_attention_mask, - tokentype_ids=tokentype_ids - ) - - if self.post_process: - _, pooled_output = lm_output - classification_output = self.classification_dropout(pooled_output) - classification_logits = self.classification_head(classification_output) - - # Reshape back to separate choices. 
- classification_logits = classification_logits.view(-1, self.num_classes) - - return classification_logits - return lm_output - - def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False): - """For easy load when model is combined with other heads, - add an extra key.""" - - state_dict_ = {} - state_dict_[self._language_model_key] \ - = self.language_model.state_dict_for_save_checkpoint(prefix=prefix, - keep_vars=keep_vars) - if self.post_process: - state_dict_[self._classification_head_key] \ - = self.classification_head.state_dict(prefix=prefix, keep_vars=keep_vars) - return state_dict_ - - def load_state_dict(self, state_dict, strict=True): - """Customized load.""" - - self.language_model.load_state_dict( - state_dict[self._language_model_key], strict=strict) - if self.post_process: - if self._classification_head_key in state_dict: - self.classification_head.load_state_dict( - state_dict[self._classification_head_key], strict=strict) - else: - print_rank_last('***WARNING*** could not find {} in the checkpoint, ' - 'initializing to random'.format( - self._classification_head_key)) diff --git a/megatron/model/enums.py b/megatron/model/enums.py deleted file mode 100644 index bc4e4aa29a05856bcef01d9e0fb6bfda216c247b..0000000000000000000000000000000000000000 --- a/megatron/model/enums.py +++ /dev/null @@ -1,21 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - -import enum - -class LayerType(enum.Enum): - encoder = 1 - decoder = 2 - retro_encoder = 3 - retro_decoder = 4 - retro_decoder_with_retriever = 5 - -class AttnType(enum.Enum): - self_attn = 1 - cross_attn = 2 - -class AttnMaskType(enum.Enum): - padding = 1 - causal = 2 - -# For backward compatibility with old model checkpoints -from megatron.core.enums import ModelType diff --git a/megatron/model/fused_bias_gelu.py b/megatron/model/fused_bias_gelu.py deleted file mode 100644 index 29222db024eb5c5e54c7f38f58be8edd45c49b39..0000000000000000000000000000000000000000 --- a/megatron/model/fused_bias_gelu.py +++ /dev/null @@ -1,43 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. - -import torch - - -###### BIAS GELU FUSION/ NO AUTOGRAD ################ -# 1/sqrt(2*pi)-> 0.3989423 -# 1/sqrt(2) -> 0.70710678 -# sqrt(2/pi) -> 0.79788456 -# this function is tanh approximation of gelu -# actual gelu is: -# x * 0.5 * (1.0 + torch.erf(x * 0.70710678)) - -@torch.jit.script -def bias_gelu(bias, y): - x = bias + y - return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))) - -# gradient of tanh approximation of gelu -# gradient of actual gelu is: -# 0.5 * (1. 
+ torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x) -@torch.jit.script -def bias_gelu_back(g, bias, y): - x = bias + y - tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) - # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243 - ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out) - return ff*g - -class GeLUFunction(torch.autograd.Function): - @staticmethod - # bias is an optional argument - def forward(ctx, input, bias): - ctx.save_for_backward(input, bias) - return bias_gelu(bias, input) - - @staticmethod - def backward(ctx, grad_output): - input, bias = ctx.saved_tensors - tmp = bias_gelu_back(grad_output, bias, input) - return tmp, tmp - -bias_gelu_impl = GeLUFunction.apply diff --git a/megatron/model/fused_layer_norm.py b/megatron/model/fused_layer_norm.py deleted file mode 100644 index c91a674e8cafe548820f6c838e607fb30b1e087c..0000000000000000000000000000000000000000 --- a/megatron/model/fused_layer_norm.py +++ /dev/null @@ -1,96 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. - -"""This code is copied fron NVIDIA apex: - https://github.com/NVIDIA/apex - with some changes. """ - -import numbers -import torch -from torch.nn.parameter import Parameter -from torch.nn import init -import importlib - -from megatron.core.utils import make_viewless_tensor - -try: - from apex.contrib.layer_norm.layer_norm import FastLayerNormFN - HAVE_PERSIST_LAYER_NORM = True -except: - HAVE_PERSIST_LAYER_NORM = False - -try: - from apex.normalization.fused_layer_norm import FusedLayerNormAffineFunction -except: - FusedLayerNormAffineFunction = None - -global fused_layer_norm_cuda -fused_layer_norm_cuda = None - - -class MixedFusedLayerNorm(torch.nn.Module): - - def __init__(self, normalized_shape, eps=1e-5, - no_persist_layer_norm=True, - sequence_parallel=False, - apply_layernorm_1p=False): - super(MixedFusedLayerNorm, self).__init__() - - self.apply_layernorm_1p = apply_layernorm_1p - - global fused_layer_norm_cuda - fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda") - - # List of hiddens sizes supported in the persistent layer norm kernel - # If the hidden size is not supported, fall back to the non-persistent - # kernel. 
- persist_ln_hidden_sizes = [1024, 1536, 2048, 2304, 3072, 3840, 4096, - 5120, 6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432, 20480, - 24576, 25600, 30720, 32768, 40960, 49152, 65536] - if normalized_shape not in persist_ln_hidden_sizes or \ - not HAVE_PERSIST_LAYER_NORM: - no_persist_layer_norm = True - - if isinstance(normalized_shape, numbers.Integral): - normalized_shape = (normalized_shape,) - self.normalized_shape = torch.Size(normalized_shape) - self.eps = eps - self.weight = Parameter(torch.Tensor(*normalized_shape)) - self.bias = Parameter(torch.Tensor(*normalized_shape)) - self.reset_parameters() - self.no_persist_layer_norm = no_persist_layer_norm - self.sequence_parallel = sequence_parallel - - # set sequence parallelism flag on weight and bias parameters - setattr(self.weight, 'sequence_parallel', self.sequence_parallel) - setattr(self.bias, 'sequence_parallel', self.sequence_parallel) - - - def reset_parameters(self): - - if self.apply_layernorm_1p: - init.zeros_(self.weight) - init.zeros_(self.bias) - else: - init.ones_(self.weight) - init.zeros_(self.bias) - - def forward(self, input): - - weight = self.weight + 1 if self.apply_layernorm_1p else self.weight - - if self.no_persist_layer_norm: - assert FusedLayerNormAffineFunction is not None, \ - "FusedLayerNormAffineFunction is not available, please install apex from https://github.com/NVIDIA/apex" - return FusedLayerNormAffineFunction.apply(input, weight, self.bias, self.normalized_shape, self.eps) - else: - output = FastLayerNormFN.apply(input, weight, self.bias, self.eps) - - # Apex's fast layer norm function outputs a 'view' tensor (i.e., has - # a populated '_base' field). This will result in schedule.py's - # deallocate_output_tensor() throwing an error, so a viewless tensor is - # created to prevent this. - output = make_viewless_tensor(inp = output, - requires_grad = input.requires_grad, - keep_graph = True) - - return output diff --git a/megatron/model/fused_softmax.py b/megatron/model/fused_softmax.py deleted file mode 100644 index 9bacf337402c4b08f565e8ddfdfb8e4579f3e257..0000000000000000000000000000000000000000 --- a/megatron/model/fused_softmax.py +++ /dev/null @@ -1,213 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. - - -import torch -import torch.nn as nn -from megatron.model.enums import AttnMaskType - - -class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function): - """ - Fused operation which performs following three operations in sequence - 1. Scale the tensor. - 2. Apply upper triangular mask (typically used in gpt models). - 3. Perform softmax. - """ - - @staticmethod - def forward(ctx, inputs, scale): - import scaled_upper_triang_masked_softmax_cuda - - scale_t = torch.tensor([scale]) - softmax_results = scaled_upper_triang_masked_softmax_cuda.forward( - inputs, scale_t[0] - ) - - ctx.save_for_backward(softmax_results, scale_t) - return softmax_results - - @staticmethod - def backward(ctx, output_grads): - import scaled_upper_triang_masked_softmax_cuda - - softmax_results, scale_t = ctx.saved_tensors - input_grads = scaled_upper_triang_masked_softmax_cuda.backward( - output_grads, softmax_results, scale_t[0] - ) - - return input_grads, None - - -class ScaledMaskedSoftmax(torch.autograd.Function): - """ - Fused operation which performs following three operations in sequence - 1. Scale the tensor. - 2. Apply the mask. - 3. Perform softmax. 
- """ - - @staticmethod - def forward(ctx, inputs, mask, scale): - import scaled_masked_softmax_cuda - - scale_t = torch.tensor([scale]) - - softmax_results = scaled_masked_softmax_cuda.forward(inputs, mask, scale_t[0]) - ctx.save_for_backward(softmax_results, scale_t) - return softmax_results - - @staticmethod - def backward(ctx, output_grads): - import scaled_masked_softmax_cuda - - softmax_results, scale_t = ctx.saved_tensors - - input_grads = scaled_masked_softmax_cuda.backward( - output_grads, softmax_results, scale_t[0] - ) - return input_grads, None, None - - -class ScaledSoftmax(torch.autograd.Function): - """ - Fused operation which performs following two operations in sequence - 1. Scale the tensor. - 2. Perform softmax. - """ - - @staticmethod - def forward(ctx, inputs, scale): - import scaled_softmax_cuda - - scale_t = torch.tensor([scale]) - - softmax_results = scaled_softmax_cuda.forward( - inputs, scale_t[0] - ) - ctx.save_for_backward(softmax_results, scale_t) - return softmax_results - - @staticmethod - def backward(ctx, output_grads): - import scaled_softmax_cuda - - softmax_results, scale_t = ctx.saved_tensors - - input_grads = scaled_softmax_cuda.backward( - output_grads, softmax_results, scale_t[0] - ) - return input_grads, None, None - - -class FusedScaleMaskSoftmax(nn.Module): - """ - fused operation: scaling + mask + softmax - - Arguments: - input_in_fp16: flag to indicate if input in fp16 data format. - input_in_bf16: flag to indicate if input in bf16 data format. - attn_mask_type: attention mask type (pad or causal) - scaled_masked_softmax_fusion: flag to indicate user want to use softmax fusion - mask_func: mask function to be applied. - softmax_in_fp32: if true, softmax in performed at fp32 precision. - scale: scaling factor used in input tensor scaling. - """ - - def __init__( - self, - input_in_fp16, - input_in_bf16, - attn_mask_type, - scaled_masked_softmax_fusion, - mask_func, - softmax_in_fp32, - scale, - ): - super(FusedScaleMaskSoftmax, self).__init__() - self.input_in_fp16 = input_in_fp16 - self.input_in_bf16 = input_in_bf16 - assert not ( - self.input_in_fp16 and self.input_in_bf16 - ), "both fp16 and bf16 flags cannot be active at the same time." 
- self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16 - self.attn_mask_type = attn_mask_type - self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion - self.mask_func = mask_func - self.softmax_in_fp32 = softmax_in_fp32 - self.scale = scale - - assert ( - self.scale is None or softmax_in_fp32 - ), "softmax should be in fp32 when scaled" - - def forward(self, input, mask): - # [b, np, sq, sk] - assert input.dim() == 4 - - if self.is_kernel_available(mask, *input.size()): - return self.forward_fused_softmax(input, mask) - else: - return self.forward_torch_softmax(input, mask) - - def is_kernel_available(self, mask, b, np, sq, sk): - attn_batches = b * np - - if ( - self.scaled_masked_softmax_fusion # user want to fuse - and self.input_in_float16 # input must be fp16 - and 16 < sk <= 16384 # sk must be 16 ~ 16384 - and sq % 4 == 0 # sq must be divisor of 4 - and sk % 4 == 0 # sk must be divisor of 4 - and attn_batches % 4 == 0 # np * b must be divisor of 4 - ): - if 0 <= sk <= 16384: - batch_per_block = self.get_batch_per_block(sq, sk, b, np) - - if self.attn_mask_type == AttnMaskType.causal: - if attn_batches % batch_per_block == 0: - return True - else: - if sq % batch_per_block == 0: - return True - return False - - def forward_fused_softmax(self, input, mask): - b, np, sq, sk = input.size() - scale = self.scale if self.scale is not None else 1.0 - - if self.attn_mask_type == AttnMaskType.causal: - assert sq == sk, "causal mask is only for self attention" - - # input is 3D tensor (attn_batches, sq, sk) - input = input.view(-1, sq, sk) - probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale) - return probs.view(b, np, sq, sk) - else: - # input is 4D tensor (b, np, sq, sk) - if mask is not None: - return ScaledMaskedSoftmax.apply(input, mask, scale) - else: - return ScaledSoftmax.apply(input, scale) - - def forward_torch_softmax(self, input, mask): - if self.input_in_float16 and self.softmax_in_fp32: - input = input.float() - - if self.scale is not None: - input = input * self.scale - mask_output = self.mask_func(input, mask) if mask is not None else input - probs = torch.nn.Softmax(dim=-1)(mask_output) - - if self.input_in_float16 and self.softmax_in_fp32: - if self.input_in_fp16: - probs = probs.half() - else: - probs = probs.bfloat16() - - return probs - - @staticmethod - def get_batch_per_block(sq, sk, b, np): - import scaled_masked_softmax_cuda - - return scaled_masked_softmax_cuda.get_batch_per_block(sq, sk, b, np) diff --git a/megatron/model/gpt_model.py b/megatron/model/gpt_model.py deleted file mode 100644 index dd47188da4a14515bf595766014baa933d6bfef4..0000000000000000000000000000000000000000 --- a/megatron/model/gpt_model.py +++ /dev/null @@ -1,122 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - -"""GPT-2 model.""" - -import torch - -from megatron import get_args -from megatron.core import tensor_parallel -from .module import MegatronModule - -from .enums import AttnMaskType -from .language_model import parallel_lm_logits -from .language_model import get_language_model - - -def post_language_model_processing(lm_output, labels, logit_weights, - parallel_output, - fp16_lm_cross_entropy): - - # Output. 
Format [s b h] - output = parallel_lm_logits( - lm_output, - logit_weights, - parallel_output) - - if labels is None: - # [s b h] => [b s h] - return output.transpose(0,1).contiguous() - else: - # [b s] => [s b] - labels = labels.transpose(0,1).contiguous() - if fp16_lm_cross_entropy: - assert output.dtype == torch.half - loss = tensor_parallel.vocab_parallel_cross_entropy(output, labels) - else: - loss = tensor_parallel.vocab_parallel_cross_entropy(output.float(), labels) - - # [s b] => [b, s] - loss = loss.transpose(0,1).contiguous() - return loss - - -class GPTModel(MegatronModule): - """GPT-2 Language model.""" - - def __init__(self, - config, - num_tokentypes=0, - parallel_output=True, - pre_process=True, - post_process=True): - args = get_args() - super().__init__(config=config, share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights) - - self.parallel_output = parallel_output - self.pre_process = pre_process - self.post_process = post_process - self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy - self.untie_embeddings_and_output_weights = args.untie_embeddings_and_output_weights - - self.language_model, self._language_model_key = get_language_model( - config=config, - num_tokentypes=num_tokentypes, - add_pooler=False, - encoder_attn_mask_type=AttnMaskType.causal, - pre_process=self.pre_process, - post_process=self.post_process) - - if not args.untie_embeddings_and_output_weights: - self.initialize_word_embeddings() - - def set_input_tensor(self, input_tensor): - """See megatron.model.transformer.set_input_tensor()""" - self.language_model.set_input_tensor(input_tensor) - - def forward(self, input_ids, position_ids, attention_mask, - retriever_input_ids=None, - retriever_position_ids=None, - retriever_attn_mask=None, - labels=None, tokentype_ids=None, inference_params=None): - - lm_output = self.language_model( - input_ids, - position_ids, - attention_mask, - retriever_input_ids=retriever_input_ids, - retriever_position_ids=retriever_position_ids, - retriever_attn_mask=retriever_attn_mask, - inference_params=inference_params) - - if self.post_process: - return post_language_model_processing( - lm_output, labels, - self.language_model.output_layer.weight if self.untie_embeddings_and_output_weights else self.shared_embedding_or_output_weight(), - self.parallel_output, - self.fp16_lm_cross_entropy) - else: - return lm_output - - def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False): - - state_dict_ = {} - state_dict_[self._language_model_key] \ - = self.language_model.state_dict_for_save_checkpoint( - prefix=prefix, keep_vars=keep_vars) - # Save word_embeddings. - if self.post_process and not self.pre_process and not self.untie_embeddings_and_output_weights: - state_dict_[self._word_embeddings_for_head_key] \ - = self.word_embeddings.state_dict(prefix=prefix, - keep_vars=keep_vars) - return state_dict_ - - def load_state_dict(self, state_dict, strict=True): - """Customized load.""" - - # Load word_embeddings. 
- if self.post_process and not self.pre_process and not self.untie_embeddings_and_output_weights: - self.word_embeddings.load_state_dict( - state_dict[self._word_embeddings_for_head_key], strict=strict) - if self._language_model_key in state_dict: - state_dict = state_dict[self._language_model_key] - self.language_model.load_state_dict(state_dict, strict=strict) diff --git a/megatron/model/language_model.py b/megatron/model/language_model.py deleted file mode 100644 index 69bfa2e8018ed30411764bf357111ba5f4ed4d95..0000000000000000000000000000000000000000 --- a/megatron/model/language_model.py +++ /dev/null @@ -1,626 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - -"""Transformer based language model.""" - -import torch -import torch.nn.functional as F - -from megatron import get_args -from megatron.core import mpu, tensor_parallel -from megatron.core.enums import ModelType -from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding - -from .enums import AttnMaskType, LayerType -from .module import MegatronModule -from .transformer import ParallelTransformer -from .utils import get_linear_layer -from .utils import init_method_normal, scaled_init_method_normal - - -def parallel_lm_logits(input_, word_embeddings_weight, parallel_output, - bias=None): - """LM logits using word embedding weights.""" - args = get_args() - # Parallel logits. - if args.async_tensor_model_parallel_allreduce or\ - args.sequence_parallel: - input_parallel = input_ - model_parallel = mpu.get_tensor_model_parallel_world_size() > 1 - async_grad_allreduce = args.async_tensor_model_parallel_allreduce and \ - model_parallel and not args.sequence_parallel - else: - input_parallel = tensor_parallel.copy_to_tensor_model_parallel_region(input_) - async_grad_allreduce = False - - # Matrix multiply. - logits_parallel = tensor_parallel.linear_with_grad_accumulation_and_async_allreduce( - input=input_parallel, - weight=word_embeddings_weight, - bias=bias, - gradient_accumulation_fusion=args.gradient_accumulation_fusion, - async_grad_allreduce=async_grad_allreduce, - sequence_parallel=args.sequence_parallel) - # Gather if needed. - - if parallel_output: - return logits_parallel - - return tensor_parallel.gather_from_tensor_model_parallel_region(logits_parallel) - - -def get_language_model(config, num_tokentypes, add_pooler, - encoder_attn_mask_type, - add_encoder=True, - add_decoder=False, - decoder_attn_mask_type=AttnMaskType.causal, - pre_process=True, post_process=True): - """Build language model and return along with the key to save.""" - args = get_args() - if config.init_method is None: - config.init_method = init_method_normal(config.init_method_std) - - if config.output_layer_init_method is None: - config.output_layer_init_method = scaled_init_method_normal(config.init_method_std, - config.num_layers) - - # Language model. - language_model = TransformerLanguageModel( - config, - encoder_attn_mask_type, - num_tokentypes=num_tokentypes, - add_encoder=add_encoder, - add_decoder=add_decoder, - decoder_attn_mask_type=decoder_attn_mask_type, - add_pooler=add_pooler, - pre_process=pre_process, - post_process=post_process - ) - # key used for checkpoints. - language_model_key = 'language_model' - - return language_model, language_model_key - - -class Pooler(MegatronModule): - """Pooler layer. - - Pool hidden states of a specific token (for example start of the - sequence) and add a linear transformation followed by a tanh. 
- - Arguments: - hidden_size: hidden size - init_method: weight initialization method for the linear layer. - bias is set to zero. - """ - - def __init__(self, hidden_size, init_method): - super(Pooler, self).__init__() - args = get_args() - self.dense = get_linear_layer(hidden_size, hidden_size, init_method) - self.sequence_parallel = args.sequence_parallel - - - def forward(self, hidden_states, sequence_index=0): - # hidden_states: [s, b, h] - # sequence_index: index of the token to pool. - - # gather data along sequence dimensions - # same pooler is run on all tensor parallel nodes - if self.sequence_parallel: - hidden_states = tensor_parallel.gather_from_sequence_parallel_region( - hidden_states, - tensor_parallel_output_grad=False) - - pooled = hidden_states[sequence_index, :, :] - pooled = self.dense(pooled) - pooled = torch.tanh(pooled) - return pooled - - -class Embedding(MegatronModule): - """Language model embeddings. - - Arguments: - hidden_size: hidden size - vocab_size: vocabulary size - max_sequence_length: maximum size of sequence. This - is used for positional embedding - embedding_dropout_prob: dropout probability for embeddings - init_method: weight initialization method - num_tokentypes: size of the token-type embeddings. 0 value - will ignore this embedding - """ - - def __init__(self, - hidden_size, - vocab_size, - max_sequence_length, - embedding_dropout_prob, - config, - num_tokentypes=0): - super(Embedding, self).__init__() - - self.hidden_size = hidden_size - self.init_method = config.init_method - self.num_tokentypes = num_tokentypes - - args = get_args() - - # Word embeddings (parallel). - self.params_dtype = args.params_dtype - self.word_embeddings = tensor_parallel.VocabParallelEmbedding( - vocab_size, self.hidden_size, config=config, init_method=config.init_method) - self._word_embeddings_key = 'word_embeddings' - - # Position embedding (serial). - self.add_position_embedding = args.position_embedding_type == 'learned_absolute' - if self.add_position_embedding: - self.position_embeddings = torch.nn.Embedding( - max_sequence_length, self.hidden_size) - self._position_embeddings_key = 'position_embeddings' - # Initialize the position embeddings. - if args.perform_initialization: - self.init_method(self.position_embeddings.weight) - - # Token type embedding. - # Add this as an optional field that can be added through - # method call so we can load a pretrain model without - # token types and add them as needed. - self._tokentype_embeddings_key = 'tokentype_embeddings' - if self.num_tokentypes > 0: - self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes, - self.hidden_size) - # Initialize the token-type embeddings. 
- if args.perform_initialization: - self.init_method(self.tokentype_embeddings.weight) - else: - self.tokentype_embeddings = None - - self.fp32_residual_connection = args.fp32_residual_connection - self.sequence_parallel = args.sequence_parallel - self.clone_scatter_output_in_embedding = args.clone_scatter_output_in_embedding - # Embeddings dropout - self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob) - - def zero_parameters(self): - """Zero out all parameters in embedding.""" - self.word_embeddings.weight.data.fill_(0) - self.word_embeddings.weight.shared = True - if self.add_position_embedding: - self.position_embeddings.weight.data.fill_(0) - self.position_embeddings.weight.shared = True - if self.num_tokentypes > 0: - self.tokentype_embeddings.weight.data.fill_(0) - self.tokentype_embeddings.weight.shared = True - - def add_tokentype_embeddings(self, num_tokentypes): - """Add token-type embedding. This function is provided so we can add - token-type embeddings in case the pretrained model does not have it. - This allows us to load the model normally and then add this embedding. - """ - if self.tokentype_embeddings is not None: - raise Exception('tokentype embeddings is already initialized') - if torch.distributed.get_rank() == 0: - print('adding embedding for {} tokentypes'.format(num_tokentypes), - flush=True) - self.num_tokentypes = num_tokentypes - self.tokentype_embeddings = torch.nn.Embedding(num_tokentypes, - self.hidden_size) - # Initialize the token-type embeddings. - args = get_args() - self.init_method(self.tokentype_embeddings.weight) - - def forward(self, input_ids, position_ids, tokentype_ids=None): - # Embeddings. - words_embeddings = self.word_embeddings(input_ids) - if self.add_position_embedding: - position_embeddings = self.position_embeddings(position_ids) - embeddings = words_embeddings + position_embeddings - else: - embeddings = words_embeddings - - if tokentype_ids is not None: - assert self.tokentype_embeddings is not None - embeddings = embeddings + self.tokentype_embeddings(tokentype_ids) - else: - assert self.tokentype_embeddings is None - - # Data format change to avoid explicit tranposes : [b s h] --> [s b h]. - embeddings = embeddings.transpose(0, 1).contiguous() - - # If the input flag for fp32 residual connection is set, convert for float. - if self.fp32_residual_connection: - embeddings = embeddings.float() - - # Dropout. - if self.sequence_parallel: - embeddings = tensor_parallel.scatter_to_sequence_parallel_region(embeddings) - # `scatter_to_sequence_parallel_region` returns a view, which prevents - # the original tensor from being garbage collected. Clone to facilitate GC. - # Has a small runtime cost (~0.5%). 
- if self.clone_scatter_output_in_embedding: - embeddings = embeddings.clone() - with tensor_parallel.get_cuda_rng_tracker().fork(): - embeddings = self.embedding_dropout(embeddings) - else: - embeddings = self.embedding_dropout(embeddings) - - return embeddings - - def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False): - """For easy load.""" - - state_dict_ = {} - state_dict_[self._word_embeddings_key] \ - = self.word_embeddings.state_dict(prefix=prefix, - keep_vars=keep_vars) - if self.add_position_embedding: - state_dict_[self._position_embeddings_key] \ - = self.position_embeddings.state_dict(prefix=prefix, - keep_vars=keep_vars) - if self.num_tokentypes > 0: - state_dict_[self._tokentype_embeddings_key] \ - = self.tokentype_embeddings.state_dict(prefix=prefix, - keep_vars=keep_vars) - - return state_dict_ - - def load_state_dict(self, state_dict, strict=True): - """Customized load.""" - - # Word embedding. - if self._word_embeddings_key in state_dict: - state_dict_ = state_dict[self._word_embeddings_key] - else: - # for backward compatibility. - state_dict_ = {} - for key in state_dict.keys(): - if 'word_embeddings' in key: - state_dict_[key.split('word_embeddings.')[1]] \ - = state_dict[key] - self.word_embeddings.load_state_dict(state_dict_, strict=strict) - - # Position embedding. - if self.add_position_embedding: - if self._position_embeddings_key in state_dict: - state_dict_ = state_dict[self._position_embeddings_key] - else: - # for backward compatibility. - state_dict_ = {} - for key in state_dict.keys(): - if 'position_embeddings' in key: - state_dict_[key.split('position_embeddings.')[1]] \ - = state_dict[key] - self.position_embeddings.load_state_dict(state_dict_, strict=strict) - - # Tokentype embedding. - if self.num_tokentypes > 0: - state_dict_ = {} - if self._tokentype_embeddings_key in state_dict: - state_dict_ = state_dict[self._tokentype_embeddings_key] - else: - # for backward compatibility. - for key in state_dict.keys(): - if 'tokentype_embeddings' in key: - state_dict_[key.split('tokentype_embeddings.')[1]] \ - = state_dict[key] - if len(state_dict_.keys()) > 0: - self.tokentype_embeddings.load_state_dict(state_dict_, - strict=strict) - else: - print('***WARNING*** expected tokentype embeddings in the ' - 'checkpoint but could not find it', flush=True) - - -class TransformerLanguageModel(MegatronModule): - """Transformer language model. - - Arguments: - transformer_hparams: transformer hyperparameters - vocab_size: vocabulary size - max_sequence_length: maximum size of sequence. This - is used for positional embedding - embedding_dropout_prob: dropout probability for embeddings - num_tokentypes: size of the token-type embeddings. 0 value - will ignore this embedding - """ - - def __init__(self, - config, - encoder_attn_mask_type, - num_tokentypes=0, - add_encoder=True, - add_decoder=False, - decoder_attn_mask_type=AttnMaskType.causal, - add_pooler=False, - pre_process=True, - post_process=True): - args = get_args() - # TODO: passing share_embeddings_and_output_weights=False will not work correctly for T5 and embeddings will not be synced. Fix later for T5. 
- if args.untie_embeddings_and_output_weights: assert not add_decoder - super(TransformerLanguageModel, self).__init__(share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights) - - self.pre_process = pre_process - self.post_process = post_process - self.hidden_size = config.hidden_size - self.num_tokentypes = num_tokentypes - self.init_method = config.init_method - self.add_encoder = add_encoder - self.encoder_attn_mask_type = encoder_attn_mask_type - self.add_decoder = add_decoder - self.decoder_attn_mask_type = decoder_attn_mask_type - self.add_pooler = add_pooler - self.encoder_hidden_state = None - self.add_retriever = args.retro_add_retriever - self.untie_embeddings_and_output_weights = args.untie_embeddings_and_output_weights - - # Embeddings. - if self.pre_process: - self.embedding = Embedding(self.hidden_size, - args.padded_vocab_size, - args.max_position_embeddings, - args.hidden_dropout, - config, - self.num_tokentypes) - self._embedding_key = 'embedding' - - # Rotary positional embeddings - self.use_rotary_position_embeddings = \ - args.position_embedding_type == 'rope' - if self.use_rotary_position_embeddings: - self.seq_length = args.seq_length - rotary_dim = args.hidden_size // args.num_attention_heads \ - if args.kv_channels is None else args.kv_channels - - # partial rotary embeddings, which is better than full rotary - # Wang and Komatsuzaki et al - # https://github.com/kingoflolz/mesh-transformer-jax/ - self.rotary_pos_emb = RotaryEmbedding( - rotary_dim, - args.rotary_percent, - seq_len_interpolation_factor=args.rotary_seq_len_interpolation_factor - ) - - # Encoder (usually set to True, False if part of an encoder-decoder - # architecture and in encoder-only stage). - if self.add_encoder: - self.encoder = ParallelTransformer( - config, - model_type=args.model_type if not args.retro_add_retriever \ - else ModelType.retro_decoder, - self_attn_mask_type=self.encoder_attn_mask_type, - pre_process=self.pre_process, - post_process=self.post_process, - ) - self._encoder_key = 'encoder' - else: - self.encoder = None - - # Decoder (usually set to False, True if part of an encoder-decoder - # architecture and in decoder-only stage). - if self.add_decoder: - self.decoder = ParallelTransformer( - config, - model_type=args.model_type, - layer_type=LayerType.decoder, - self_attn_mask_type=self.decoder_attn_mask_type, - pre_process=self.pre_process, - post_process=self.post_process) - self._decoder_key = 'decoder' - else: - self.decoder = None - - if self.post_process: - # Pooler. - if self.add_pooler: - self.pooler = Pooler(self.hidden_size, self.init_method) - self._pooler_key = 'pooler' - - if self.untie_embeddings_and_output_weights: - self.output_layer = tensor_parallel.ColumnParallelLinear( - args.hidden_size, - args.padded_vocab_size, - config=config, - init_method=self.init_method, - bias=False) # Setting bias to False always to keep it consistent with embedding tying that also does not have a bias. 
- self._output_layer_key = 'output_layer' - - def set_input_tensor(self, input_tensor): - """ See megatron.model.transformer.set_input_tensor()""" - - # This is usually handled in schedules.py but some inference code still - # gives us non-lists or None - if not isinstance(input_tensor, list): - input_tensor = [input_tensor] - - if self.add_encoder and self.add_decoder: - assert len(input_tensor) == 1, \ - 'input_tensor should only be length 1 for stage with both encoder and decoder' - self.encoder.set_input_tensor(input_tensor[0]) - elif self.add_encoder: - assert len(input_tensor) == 1, \ - 'input_tensor should only be length 1 for stage with only encoder' - self.encoder.set_input_tensor(input_tensor[0]) - elif self.add_decoder: - if len(input_tensor) == 2: - self.decoder.set_input_tensor(input_tensor[0]) - self.encoder_hidden_state = input_tensor[1] - elif len(input_tensor) == 1: - self.decoder.set_input_tensor(None) - self.encoder_hidden_state = input_tensor[0] - else: - raise Exception('input_tensor must have either length 1 or 2') - else: - raise Exception('Stage must have at least either encoder or decoder') - - def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask, - dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None, - retriever_input_ids=None, - retriever_position_ids=None, - retriever_attn_mask=None, - enc_dec_attn_mask=None, tokentype_ids=None, - inference_params=None, - pooling_sequence_index=0, - enc_hidden_states=None, output_enc_hidden=False): - - # Encoder embedding. - if self.pre_process: - encoder_input = self.embedding(enc_input_ids, enc_position_ids, - tokentype_ids=tokentype_ids) - else: - encoder_input = None - - # Retriever embedding. - if self.add_retriever and self.pre_process: - retriever_input = self.embedding(retriever_input_ids, - retriever_position_ids, - tokentype_ids=tokentype_ids) - else: - retriever_input = None - - # Rotary positional embeddings - rotary_pos_emb = None - if self.use_rotary_position_embeddings: - if inference_params is not None: - rotary_pos_emb = \ - self.rotary_pos_emb(inference_params.max_sequence_length) - else: - rotary_pos_emb = self.rotary_pos_emb(self.seq_length) - - # Run encoder. - if enc_hidden_states is None: - if self.encoder is not None: - encoder_output = self.encoder( - encoder_input, - enc_attn_mask, - retriever_input=retriever_input, - retriever_attn_mask=retriever_attn_mask, - inference_params=inference_params, - rotary_pos_emb=rotary_pos_emb) - else: - encoder_output = self.encoder_hidden_state - else: - encoder_output = enc_hidden_states.to(encoder_input.dtype) - - if self.post_process: - if self.add_pooler: - pooled_output = self.pooler(encoder_output, - pooling_sequence_index) - - # output_enc_hidden refers to when we just need the encoder's - # output. For example, it is helpful to compute - # similarity between two sequences by average pooling - if not self.add_decoder or output_enc_hidden: - if self.add_pooler and self.post_process: - return encoder_output, pooled_output - else: - return encoder_output - - # Decoder embedding. - if self.pre_process: - decoder_input = self.embedding(dec_input_ids, - dec_position_ids) - else: - decoder_input = None - - # Run decoder. 
- decoder_output = self.decoder( - decoder_input, - dec_attn_mask, - encoder_output=encoder_output, - enc_dec_attn_mask=enc_dec_attn_mask, - inference_params=inference_params, - rotary_pos_emb=rotary_pos_emb) - - if self.add_pooler and self.post_process: - return decoder_output, encoder_output, pooled_output - else: - return decoder_output, encoder_output - - def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False): - """For easy load.""" - - state_dict_ = {} - if self.pre_process: - state_dict_[self._embedding_key] \ - = self.embedding.state_dict_for_save_checkpoint(prefix=prefix, - keep_vars=keep_vars) - if self.add_encoder: - state_dict_[self._encoder_key] \ - = self.encoder.state_dict_for_save_checkpoint(prefix=prefix, - keep_vars=keep_vars) - if self.post_process: - if self.add_pooler: - state_dict_[self._pooler_key] \ - = self.pooler.state_dict_for_save_checkpoint(prefix=prefix, - keep_vars=keep_vars) - if self.untie_embeddings_and_output_weights: - state_dict_[self._output_layer_key] \ - = self.output_layer.state_dict(prefix=prefix, keep_vars=keep_vars) - - if self.add_decoder: - state_dict_[self._decoder_key] \ - = self.decoder.state_dict_for_save_checkpoint(prefix=prefix, - keep_vars=keep_vars) - - return state_dict_ - - def load_state_dict(self, state_dict, strict=True): - """Customized load.""" - - # Embedding. - if self.pre_process: - if self._embedding_key in state_dict: - state_dict_ = state_dict[self._embedding_key] - else: - # for backward compatibility. - state_dict_ = {} - for key in state_dict.keys(): - if '_embeddings' in key: - state_dict_[key] = state_dict[key] - self.embedding.load_state_dict(state_dict_, strict=strict) - - # Encoder. - if self.add_encoder: - if self._encoder_key in state_dict: - state_dict_ = state_dict[self._encoder_key] - # For backward compatibility. - elif 'transformer' in state_dict: - state_dict_ = state_dict['transformer'] - else: - # For backward compatibility. - state_dict_ = {} - for key in state_dict.keys(): - if 'transformer.' in key: - state_dict_[key.split('transformer.')[1]] = state_dict[key] - - # For backward compatibility. - state_dict_self_attention = {} - for key in state_dict_.keys(): - if '.attention.' in key: - state_dict_self_attention[key.replace(".attention.", - ".self_attention.")] = state_dict_[key] - else: - state_dict_self_attention[key] = state_dict_[key] - state_dict_ = state_dict_self_attention - - self.encoder.load_state_dict(state_dict_, strict=strict) - - # Pooler. - if self.post_process: - if self.add_pooler: - assert 'pooler' in state_dict, \ - 'could not find data for pooler in the checkpoint' - self.pooler.load_state_dict(state_dict[self._pooler_key], - strict=strict) - if self.untie_embeddings_and_output_weights: - assert 'output_layer' in state_dict, \ - 'could not find data for output_layer in the checkpoint' - self.output_layer.load_state_dict(state_dict[self._output_layer_key], - strict=strict) - # Decoder. - if self.add_decoder: - assert 'decoder' in state_dict, \ - 'could not find data for pooler in the checkpoint' - self.decoder.load_state_dict(state_dict[self._decoder_key], - strict=strict) diff --git a/megatron/model/module.py b/megatron/model/module.py deleted file mode 100644 index c2887315a56e5ea575a6370de842d76f21f2937b..0000000000000000000000000000000000000000 --- a/megatron/model/module.py +++ /dev/null @@ -1,197 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
- -"""Megatron Module""" - -import torch -from torch.autograd import Variable -from torch.nn.parameter import Parameter - -from megatron import get_args -from megatron.core import mpu, tensor_parallel - - -_FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor) -_HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor) -_BF16_TYPES = (torch.BFloat16Tensor, torch.cuda.BFloat16Tensor) - - - -def param_is_not_shared(param): - return not hasattr(param, 'shared') or not param.shared - - - -class MegatronModule(torch.nn.Module): - """Megatron specific extensions of torch Module with support - for pipelining.""" - - def __init__(self, config=None, share_embeddings_and_output_weights=True): - super(MegatronModule, self).__init__() - self.config = config - self.share_embeddings_and_output_weights = share_embeddings_and_output_weights - - - def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False): - """Use this function to override the state dict for - saving checkpoints.""" - return self.state_dict(prefix=prefix, keep_vars=keep_vars) - - - def shared_embedding_or_output_weight(self): - if self.pre_process: - return self.language_model.embedding.word_embeddings.weight - else: - if not self.share_embeddings_and_output_weights: - raise Exception('shared_embedding_or_output_weight() called for last ' - 'stage, but share_embeddings_and_output_weights is false') - return self.word_embeddings.weight - - - def initialize_word_embeddings(self): - args = get_args() - if not self.share_embeddings_and_output_weights: - raise Exception('initialize_word_embeddings() was called but ' - 'share_embeddings_and_output_weights is false') - - # This function just initializes the word embeddings in the final stage - # when we are using pipeline parallelism. Nothing to do if we aren't - # using pipeline parallelism. - if args.pipeline_model_parallel_size == 1: - return - - # Parameters are shared between the word embeddings layers, and the - # heads at the end of the model. In a pipelined setup with more than - # one stage, the initial embedding layer and the head are on different - # workers, so we do the following: - # 1. Create a second copy of word_embeddings on the last stage, with - # initial parameters of 0.0. - # 2. Do an all-reduce between the first and last stage to ensure that - # the two copies of word_embeddings start off with the same - # parameter values. - # 3. In the training loop, before an all-reduce between the grads of - # the two word_embeddings layers to ensure that every applied weight - # update is the same on both stages. - if mpu.is_pipeline_last_stage() and not self.pre_process: - assert not mpu.is_pipeline_first_stage() - self._word_embeddings_for_head_key = 'word_embeddings_for_head' - # set word_embeddings weights to 0 here, then copy first - # stage's weights using all_reduce below. - self.word_embeddings = tensor_parallel.VocabParallelEmbedding( - args.padded_vocab_size, self.config.hidden_size, - config=self.config, init_method=self.config.init_method) - self.word_embeddings.weight.data.fill_(0) - self.word_embeddings.weight.shared = True - - # Zero out initial weights for decoder embedding. - # NOTE: We don't currently support T5 with the interleaved schedule. - if not mpu.is_pipeline_first_stage(ignore_virtual=True) and \ - self.pre_process: - self.language_model.embedding.zero_parameters() - - if not torch.distributed.is_initialized(): - if not getattr(MegatronModule, "embedding_warning_printed", False): - print("WARNING! 
Distributed processes aren't initialized, so " - "word embeddings in the last layer are not initialized. " - "If you are just manipulating a model this is fine, but " - "this needs to be handled manually. If you are training " - "something is definitely wrong.") - MegatronModule.embedding_warning_printed = True - return - - # Ensure that first and last stages have the same initial parameter - # values. - if mpu.is_rank_in_embedding_group(): - torch.distributed.all_reduce(self.shared_embedding_or_output_weight().data, - group=mpu.get_embedding_group()) - - # Ensure that encoder(first stage) and decoder(split stage) position - # embeddings have the same initial parameter values - # NOTE: We don't currently support T5 with the interleaved schedule. - if mpu.is_rank_in_position_embedding_group() and \ - args.pipeline_model_parallel_split_rank is not None: - # TODO: Support tokentype embedding. - self.language_model.embedding.cuda() - position_embeddings = self.language_model.embedding.position_embeddings - torch.distributed.all_reduce(position_embeddings.weight.data, - group=mpu.get_position_embedding_group()) - - -def conversion_helper(val, conversion): - """Apply conversion to val. Recursively apply conversion if `val` - #is a nested tuple/list structure.""" - if not isinstance(val, (tuple, list)): - return conversion(val) - rtn = [conversion_helper(v, conversion) for v in val] - if isinstance(val, tuple): - rtn = tuple(rtn) - return rtn - - -def fp32_to_float16(val, float16_convertor): - """Convert fp32 `val` to fp16/bf16""" - def half_conversion(val): - val_typecheck = val - if isinstance(val_typecheck, (Parameter, Variable)): - val_typecheck = val.data - if isinstance(val_typecheck, _FLOAT_TYPES): - val = float16_convertor(val) - return val - return conversion_helper(val, half_conversion) - - -def float16_to_fp32(val): - """Convert fp16/bf16 `val` to fp32""" - def float_conversion(val): - val_typecheck = val - if isinstance(val_typecheck, (Parameter, Variable)): - val_typecheck = val.data - if isinstance(val_typecheck, (_BF16_TYPES, _HALF_TYPES)): - val = val.float() - return val - return conversion_helper(val, float_conversion) - - - -class Float16Module(MegatronModule): - - def __init__(self, module, args): - super(Float16Module, self).__init__() - - if args.fp16: - self.add_module('module', module.half()) - def float16_convertor(val): - return val.half() - elif args.bf16: - self.add_module('module', module.bfloat16()) - def float16_convertor(val): - return val.bfloat16() - else: - raise Exception('should not be here') - - self.float16_convertor = float16_convertor - - - def set_input_tensor(self, input_tensor): - return self.module.set_input_tensor(input_tensor) - - - def forward(self, *inputs, **kwargs): - if mpu.is_pipeline_first_stage(): - inputs = fp32_to_float16(inputs, self.float16_convertor) - outputs = self.module(*inputs, **kwargs) - if mpu.is_pipeline_last_stage(): - outputs = float16_to_fp32(outputs) - return outputs - - - def state_dict(self, prefix='', keep_vars=False): - return self.module.state_dict(prefix=prefix, keep_vars=keep_vars) - - - def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False): - return self.module.state_dict_for_save_checkpoint(prefix=prefix, - keep_vars=keep_vars) - - - def load_state_dict(self, state_dict, strict=True): - self.module.load_state_dict(state_dict, strict=strict) diff --git a/megatron/model/multiple_choice.py b/megatron/model/multiple_choice.py deleted file mode 100644 index 
41f8bb49f6a075a484d1a03654f44d326898056c..0000000000000000000000000000000000000000 --- a/megatron/model/multiple_choice.py +++ /dev/null @@ -1,112 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. - -"""Multiple choice model.""" - -import torch - -from megatron import get_args, print_rank_last -from megatron.model.enums import AttnMaskType -from megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids -from megatron.model.language_model import get_language_model -from megatron.model.utils import get_linear_layer -from megatron.model.utils import init_method_normal -from megatron.model.utils import scaled_init_method_normal -from .module import MegatronModule - - -class MultipleChoice(MegatronModule): - - def __init__(self, - config, - num_tokentypes=2, - pre_process=True, - post_process=True): - super(MultipleChoice, self).__init__(share_embeddings_and_output_weights=False) - args = get_args() - - self.pre_process = pre_process - self.post_process = post_process - - self.language_model, self._language_model_key = get_language_model( - config=config, - num_tokentypes=num_tokentypes, - add_pooler=True, - encoder_attn_mask_type=AttnMaskType.padding, - pre_process=self.pre_process, - post_process=self.post_process) - - # Multi-choice head. - if self.post_process: - self.multichoice_dropout = torch.nn.Dropout(args.hidden_dropout) - self.multichoice_head = get_linear_layer(args.hidden_size, 1, - init_method) - self._multichoice_head_key = 'multichoice_head' - - def set_input_tensor(self, input_tensor): - """See megatron.model.transformer.set_input_tensor()""" - self.language_model.set_input_tensor(input_tensor) - - def forward(self, model_input, attention_mask, tokentype_ids=None): - - # [batch, choices, sequence] --> [batch * choices, sequence] --> - # transformer --> [batch, choices] --> softmax - - # Ensure the shape is [batch-size, choices, sequence] - assert len(attention_mask.shape) == 3 - num_choices = attention_mask.shape[1] - - # Reshape and treat choice dimension the same as batch. - attention_mask = attention_mask.view(-1, attention_mask.size(-1)) - extended_attention_mask = bert_extended_attention_mask(attention_mask) - - input_ids = model_input - # Do the same as attention_mask for input_ids, tokentype_ids - assert len(input_ids.shape) == 3 - assert len(tokentype_ids.shape) == 3 - input_ids = input_ids.view(-1, input_ids.size(-1)) - tokentype_ids = tokentype_ids.view(-1, tokentype_ids.size(-1)) - position_ids = bert_position_ids(input_ids) - - lm_output = self.language_model( - input_ids, - position_ids, - extended_attention_mask, - tokentype_ids=tokentype_ids - ) - if self.post_process: - _, pooled_output = lm_output - multichoice_output = self.multichoice_dropout(pooled_output) - multichoice_logits = self.multichoice_head(multichoice_output) - - # Reshape back to separate choices. 
- multichoice_logits = multichoice_logits.view(-1, num_choices) - - return multichoice_logits - return lm_output - - def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False): - """For easy load when model is combined with other heads, - add an extra key.""" - - state_dict_ = {} - state_dict_[self._language_model_key] \ - = self.language_model.state_dict_for_save_checkpoint(prefix=prefix, - keep_vars=keep_vars) - if self.post_process: - state_dict_[self._multichoice_head_key] \ - = self.multichoice_head.state_dict(prefix=prefix, keep_vars=keep_vars) - return state_dict_ - - def load_state_dict(self, state_dict, strict=True): - """Customized load.""" - - self.language_model.load_state_dict( - state_dict[self._language_model_key], strict=strict) - if self.post_process: - if self._multichoice_head_key in state_dict: - self.multichoice_head.load_state_dict( - state_dict[self._multichoice_head_key], strict=strict) - else: - print_rank_last('***WARNING*** could not find {} in the checkpoint, ' - 'initializing to random'.format( - self._multichoice_head_key)) diff --git a/megatron/model/realm_model.py b/megatron/model/realm_model.py deleted file mode 100644 index 654f2992f62cdcb52c3ef7d40cc2dcb78fd776b6..0000000000000000000000000000000000000000 --- a/megatron/model/realm_model.py +++ /dev/null @@ -1,204 +0,0 @@ -import os -import torch - -from megatron import get_args, print_rank_0 -from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name -from megatron.model import BertModel -from .module import MegatronModule -from megatron.core import mpu -from megatron.model.enums import AttnMaskType -from megatron.model.utils import get_linear_layer -from megatron.model.utils import init_method_normal -from megatron.model.language_model import get_language_model -from megatron.model.utils import scaled_init_method_normal -from megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids - - -def general_ict_model_provider(only_query_model=False, only_block_model=False): - """Build the model.""" - args = get_args() - assert args.ict_head_size is not None, \ - "Need to specify --ict-head-size to provide an ICTBertModel" - assert mpu.get_tensor_model_parallel_world_size() == 1 and mpu.get_pipeline_model_parallel_world_size() == 1, \ - "Model parallel size > 1 not supported for ICT" - - print_rank_0('building ICTBertModel...') - - # simpler to just keep using 2 tokentypes since the LM we initialize with has 2 tokentypes - model = ICTBertModel( - ict_head_size=args.ict_head_size, - num_tokentypes=2, - parallel_output=True, - only_query_model=only_query_model, - only_block_model=only_block_model) - - return model - - -class ICTBertModel(MegatronModule): - """Bert-based module for Inverse Cloze task.""" - def __init__(self, - ict_head_size, - num_tokentypes=1, - parallel_output=True, - only_query_model=False, - only_block_model=False): - super(ICTBertModel, self).__init__() - bert_kwargs = dict( - ict_head_size=ict_head_size, - num_tokentypes=num_tokentypes, - parallel_output=parallel_output - ) - assert not (only_block_model and only_query_model) - self.use_block_model = not only_query_model - self.use_query_model = not only_block_model - - if self.use_query_model: - # this model embeds (pseudo-)queries - Embed_input in the paper - self.query_model = IREncoderBertModel(**bert_kwargs) - self._query_key = 'question_model' - - if self.use_block_model: - # this model embeds evidence blocks - Embed_doc in the paper - self.block_model = 
IREncoderBertModel(**bert_kwargs) - self._block_key = 'context_model' - - def forward(self, query_tokens, query_attention_mask, block_tokens, block_attention_mask): - """Run a forward pass for each of the models and return the respective embeddings.""" - query_logits = self.embed_query(query_tokens, query_attention_mask) - block_logits = self.embed_block(block_tokens, block_attention_mask) - return query_logits, block_logits - - def embed_query(self, query_tokens, query_attention_mask): - """Embed a batch of tokens using the query model""" - if self.use_query_model: - query_types = torch.cuda.LongTensor(*query_tokens.shape).fill_(0) - query_ict_logits, _ = self.query_model.forward(query_tokens, query_attention_mask, query_types) - return query_ict_logits - else: - raise ValueError("Cannot embed query without query model.") - - def embed_block(self, block_tokens, block_attention_mask): - """Embed a batch of tokens using the block model""" - if self.use_block_model: - block_types = torch.cuda.LongTensor(*block_tokens.shape).fill_(0) - block_ict_logits, _ = self.block_model.forward(block_tokens, block_attention_mask, block_types) - return block_ict_logits - else: - raise ValueError("Cannot embed block without block model.") - - def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False): - """Save dict with state dicts of each of the models.""" - state_dict_ = {} - if self.use_query_model: - state_dict_[self._query_key] \ - = self.query_model.state_dict_for_save_checkpoint( - prefix=prefix, keep_vars=keep_vars) - - if self.use_block_model: - state_dict_[self._block_key] \ - = self.block_model.state_dict_for_save_checkpoint( - prefix=prefix, keep_vars=keep_vars) - - return state_dict_ - - def load_state_dict(self, state_dict, strict=True): - """Load the state dicts of each of the models""" - if self.use_query_model: - print("Loading ICT query model", flush=True) - self.query_model.load_state_dict( - state_dict[self._query_key], strict=strict) - - if self.use_block_model: - print("Loading ICT block model", flush=True) - self.block_model.load_state_dict( - state_dict[self._block_key], strict=strict) - - def init_state_dict_from_bert(self): - """Initialize the state from a pretrained BERT model on iteration zero of ICT pretraining""" - args = get_args() - tracker_filename = get_checkpoint_tracker_filename(args.bert_load) - if not os.path.isfile(tracker_filename): - raise FileNotFoundError("Could not find BERT load for ICT") - with open(tracker_filename, 'r') as f: - iteration = int(f.read().strip()) - assert iteration > 0 - - checkpoint_name = get_checkpoint_name(args.bert_load, iteration, False) - if mpu.get_data_parallel_rank() == 0: - print('global rank {} is loading checkpoint {}'.format( - torch.distributed.get_rank(), checkpoint_name)) - - try: - state_dict = torch.load(checkpoint_name, map_location='cpu') - except BaseException: - raise ValueError("Could not load checkpoint") - - # load the LM state dict into each model - model_dict = state_dict['model']['language_model'] - self.query_model.language_model.load_state_dict(model_dict) - self.block_model.language_model.load_state_dict(model_dict) - - # give each model the same ict_head to begin with as well - query_ict_head_state_dict = self.state_dict_for_save_checkpoint()[self._query_key]['ict_head'] - self.block_model.ict_head.load_state_dict(query_ict_head_state_dict) - - -class IREncoderBertModel(MegatronModule): - """BERT-based encoder for queries or blocks used for learned information retrieval.""" - def __init__(self, 
ict_head_size, num_tokentypes=2, parallel_output=True): - super(IREncoderBertModel, self).__init__() - args = get_args() - - self.ict_head_size = ict_head_size - self.parallel_output = parallel_output - init_method = init_method_normal(args.init_method_std) - scaled_init_method = scaled_init_method_normal(args.init_method_std, - args.num_layers) - - self.language_model, self._language_model_key = get_language_model( - num_tokentypes=num_tokentypes, - add_pooler=True, - encoder_attn_mask_type=AttnMaskType.padding, - init_method=init_method, - scaled_init_method=scaled_init_method) - - self.ict_head = get_linear_layer(args.hidden_size, ict_head_size, init_method) - self._ict_head_key = 'ict_head' - - def forward(self, input_ids, attention_mask, tokentype_ids=None): - extended_attention_mask = bert_extended_attention_mask( - attention_mask, next(self.language_model.parameters()).dtype) - position_ids = bert_position_ids(input_ids) - - lm_output, pooled_output = self.language_model( - input_ids, - position_ids, - extended_attention_mask, - tokentype_ids=tokentype_ids) - - # Output. - ict_logits = self.ict_head(pooled_output) - return ict_logits, None - - def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False): - """For easy load when model is combined with other heads, - add an extra key.""" - - state_dict_ = {} - state_dict_[self._language_model_key] \ - = self.language_model.state_dict_for_save_checkpoint(prefix=prefix, - keep_vars=keep_vars) - state_dict_[self._ict_head_key] \ - = self.ict_head.state_dict(prefix=prefix, - keep_vars=keep_vars) - return state_dict_ - - def load_state_dict(self, state_dict, strict=True): - """Customized load.""" - self.language_model.load_state_dict( - state_dict[self._language_model_key], strict=strict) - self.ict_head.load_state_dict( - state_dict[self._ict_head_key], strict=strict) - - diff --git a/megatron/model/rms_norm.py b/megatron/model/rms_norm.py deleted file mode 100644 index d42e7df9a812b4212516c48af61bfe1796a48244..0000000000000000000000000000000000000000 --- a/megatron/model/rms_norm.py +++ /dev/null @@ -1,31 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - -import torch -from torch import nn - -class RMSNorm(torch.nn.Module): - - def __init__(self, - dim: int, - eps: float = 1e-6, - sequence_parallel: bool = False): - """RMS Normaliation module - - Arguments: - dim (int): The width of input, i.e. hidden size - eps (float): epsilon to use for the norm, default to 1e-6 - sequence_parallel (bool): Set to true if sequence parallelism is being used, - this marks the weights as needing to be allreduced. - """ - super().__init__() - self.eps = eps - self.weight = nn.Parameter(torch.ones(dim)) - - setattr(self.weight, 'sequence_parallel', sequence_parallel) - - def _norm(self, x): - return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) - - def forward(self, x): - output = self._norm(x.float()).type_as(x) - return output * self.weight diff --git a/megatron/model/t5_model.py b/megatron/model/t5_model.py deleted file mode 100644 index f9fabd34010e7467546b0c32a79039ef049aec86..0000000000000000000000000000000000000000 --- a/megatron/model/t5_model.py +++ /dev/null @@ -1,186 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
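The rms_norm.py removed above normalizes each vector by the reciprocal root-mean-square of its last dimension (computed in fp32) and then scales by a learned weight. A minimal standalone sketch of the same computation; the class name here is illustrative:

```python
import torch
from torch import nn

class SimpleRMSNorm(nn.Module):
    """Minimal RMSNorm: x * rsqrt(mean(x^2) + eps), scaled by a learned weight."""
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Compute the norm in fp32 for stability, then cast back to the input dtype.
        norm = x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps)
        return norm.type_as(x) * self.weight

x = torch.randn(2, 4, 8)
print(SimpleRMSNorm(8)(x).shape)  # torch.Size([2, 4, 8])
```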
- -"""T5 model.""" - -import torch - -from megatron import get_args -from megatron.core import tensor_parallel -from megatron.model.enums import AttnMaskType -from megatron.model.language_model import parallel_lm_logits, get_language_model -from megatron.model import LayerNorm -from megatron.model.utils import ( - openai_gelu, - get_linear_layer -) -from .module import MegatronModule - - -def t5_extended_attention_mask(attention_mask_list): - - def attn_mask_postprocess(attn_mask): - # [b, 1, s, s] - extended_attention_mask = attn_mask.unsqueeze(1) - return extended_attention_mask - - return [attn_mask_postprocess(attn_mask) for attn_mask in attention_mask_list] - - -def t5_position_ids(token_ids): - # Create position ids - seq_length = token_ids.size(1) - position_ids = torch.arange(seq_length, dtype=torch.long, - device=token_ids.device) - position_ids = position_ids.unsqueeze(0).expand_as(token_ids) - - return position_ids - - -class T5LMHead(MegatronModule): - """Masked LM head for T5 - - Arguments: - mpu_vocab_size: model parallel size of vocabulary. - parallel_output: wether output logits being distributed or not. - """ - - def __init__(self, mpu_vocab_size, parallel_output): - super(T5LMHead, self).__init__() - - self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size)) - self.bias.model_parallel = True - self.bias.partition_dim = 0 - self.bias.stride = 1 - self.parallel_output = parallel_output - - def forward(self, hidden_states, word_embeddings_weight): - output = parallel_lm_logits(hidden_states, - word_embeddings_weight, - self.parallel_output, - bias=self.bias) - return output - - -class T5Model(MegatronModule): - """T5 Language model.""" - - def __init__(self, - config, - num_tokentypes=0, - parallel_output=True, - pre_process=True, - post_process=True, - add_encoder=True, - add_decoder=True): - super().__init__(config=config) - args = get_args() - - self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy - self.parallel_output = parallel_output - self.pre_process = pre_process - self.post_process = post_process - self.add_encoder = add_encoder - self.add_decoder = add_decoder - - self.language_model, self._language_model_key = get_language_model( - config=config, - num_tokentypes=num_tokentypes, - add_pooler=False, - add_encoder=add_encoder, - add_decoder=add_decoder, - encoder_attn_mask_type=AttnMaskType.padding, - pre_process=self.pre_process, - post_process=self.post_process) - - self.initialize_word_embeddings() - - if self.post_process and self.add_decoder: - self.lm_head = T5LMHead( - self.shared_embedding_or_output_weight().size(0), - parallel_output) - self._lm_head_key = 'lm_head' - - def set_input_tensor(self, input_tensor): - """See megatron.model.transformer.set_input_tensor()""" - self.language_model.set_input_tensor(input_tensor) - - def forward(self, encoder_input_ids, decoder_input_ids, encoder_attn_mask, - decoder_attn_mask, encoder_decoder_attn_mask, - tokentype_ids=None, lm_labels=None, enc_hidden_states=None): - - # Converting the attention masks to proper parameter settings - encoder_attn_mask, decoder_attn_mask, encoder_decoder_attn_mask = t5_extended_attention_mask( - [encoder_attn_mask, decoder_attn_mask, encoder_decoder_attn_mask]) - - encoder_position_ids = t5_position_ids(encoder_input_ids) - decoder_position_ids = t5_position_ids(decoder_input_ids) - - lm_output = self.language_model(encoder_input_ids, - encoder_position_ids, - encoder_attn_mask, - decoder_input_ids, - decoder_position_ids, - decoder_attn_mask, - encoder_decoder_attn_mask, - 
tokentype_ids=tokentype_ids, - enc_hidden_states=enc_hidden_states) - - if self.post_process and self.add_decoder: - decoder_output, encoder_output = lm_output - # Output. [s, b, h] - lm_logits = self.lm_head(decoder_output, - self.shared_embedding_or_output_weight()) - - if lm_labels is None: - # [s b h] => [b s h] - return lm_logits.transpose(0,1).contiguous() - else: - # [b s] => [s b] - lm_labels = lm_labels.transpose(0,1).contiguous() - if self.fp16_lm_cross_entropy: - assert lm_logits.dtype == torch.half - lm_loss = tensor_parallel.vocab_parallel_cross_entropy(lm_logits, lm_labels) - else: - lm_loss = tensor_parallel.vocab_parallel_cross_entropy(lm_logits.float(), - lm_labels) - # [s b] => [b s] - lm_loss = lm_loss.transpose(0,1).contiguous() - return lm_loss - elif self.add_decoder and not self.add_encoder: - decoder_output, encoder_output = lm_output - return decoder_output - else: - encoder_output = lm_output - return encoder_output - - def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False): - """For easy load when model is combined with other heads, - add an extra key.""" - - state_dict_ = {} - state_dict_[self._language_model_key] \ - = self.language_model.state_dict_for_save_checkpoint(prefix=prefix, - keep_vars=keep_vars) - if self.post_process and self.add_decoder: - state_dict_[self._lm_head_key] \ - = self.lm_head.state_dict_for_save_checkpoint(prefix=prefix, - keep_vars=keep_vars) - # Save word_embeddings. - if self.post_process and not self.pre_process and self.add_decoder: - state_dict_[self._word_embeddings_for_head_key] \ - = self.word_embeddings.state_dict(prefix=prefix, - keep_vars=keep_vars) - return state_dict_ - - def load_state_dict(self, state_dict, strict=True): - """Customized load.""" - - self.language_model.load_state_dict( - state_dict[self._language_model_key], strict=strict) - if self.post_process and self.add_decoder: - self.lm_head.load_state_dict(state_dict[self._lm_head_key], - strict=strict) - # Load word embeddings. - if self.post_process and not self.pre_process and self.add_decoder: - self.word_embeddings.load_state_dict( - state_dict[self._word_embeddings_for_head_key], strict=strict) diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py deleted file mode 100644 index 9f1144c02bb875f67ec556f739855b40eed8ea6f..0000000000000000000000000000000000000000 --- a/megatron/model/transformer.py +++ /dev/null @@ -1,1793 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 
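The t5_position_ids and t5_extended_attention_mask helpers in the t5_model.py removed above are small shape utilities: position ids are a broadcast arange over the sequence dimension, and each [b, s, s] mask gains a singleton head dimension. A quick sketch of the same transformations on dummy tensors:

```python
import torch

token_ids = torch.randint(0, 100, (2, 5))           # [b, s]
attn_mask = torch.ones(2, 5, 5, dtype=torch.bool)   # [b, s, s]

# Position ids: arange over the sequence, expanded to the batch shape.
position_ids = torch.arange(token_ids.size(1), device=token_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(token_ids)   # [b, s]

# Extended mask: insert a head dimension -> [b, 1, s, s].
extended_mask = attn_mask.unsqueeze(1)

print(position_ids.shape, extended_mask.shape)
```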
- -"""Transformer.""" -from contextlib import nullcontext -import math -import numpy as np -import torch -import torch.nn.functional as F -from typing import Optional - -from megatron import get_timers, get_args, get_retro_args, core, get_num_microbatches -from .module import MegatronModule -from megatron.core import mpu, tensor_parallel -from megatron.core.enums import ModelType -from megatron.model.enums import AttnMaskType, LayerType, AttnType -from megatron.model.fused_softmax import FusedScaleMaskSoftmax -from megatron.model.fused_bias_gelu import bias_gelu_impl -from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding, apply_rotary_pos_emb -from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu, get_norm -from megatron.core.tensor_parallel import ( - gather_from_sequence_parallel_region_to_moe, - reduce_scatter_to_sequence_parallel_region_from_moe, - get_cuda_rng_tracker, - get_data_parallel_rng_tracker_name -) -from megatron.core.parallel_state import get_tensor_model_parallel_group, get_tensor_and_expert_parallel_group - -try: - from einops import rearrange -except ImportError: - rearrange = None - -try: - from flash_attn.flash_attn_interface import flash_attn_unpadded_func -except ImportError: - try: - from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_unpadded_func - except ImportError: - flash_attn_unpadded_func = None - -""" We use the following notation throughout this file: - h: hidden size - n: number of attention heads - p: number of model parallel partitions - np: n/p - hp: h/p - hn: h/n - b: batch size - s: sequence length - l: number of layers - Transformer takes input of size [s, b, h] and returns a - tensor of the same size. We use the following arguments: - hyperparameters: transformer hyperparameters -""" - -class DropPath(MegatronModule): - """Drop paths (Stochastic Depth) per sample - (when applied in main path of residual blocks). - """ - - def __init__(self, drop_prob=0.): - super(DropPath, self).__init__() - self.drop_prob = drop_prob - - def forward(self, hidden_state): - if self.drop_prob == 0. or not self.training: - return hidden_state - keep_prob = 1 - self.drop_prob - # work with diff dim tensors, not just 2D ConvNets - # hidden_state: [s, b, h] - shape = (1,) + (hidden_state.shape[1],) + (1,) * (hidden_state.ndim - 2) - random_tensor = keep_prob + \ - torch.rand(shape, dtype=hidden_state.dtype, device=hidden_state.device) - random_tensor.floor_() # binarize - output = hidden_state.div(keep_prob) * random_tensor - return output - -class ParallelMLP(MegatronModule): - """MLP. - - MLP will take the input with h hidden state, project it to 4*h - hidden dimension, perform nonlinear transformation, and project the - state back into h hidden dimension. - """ - - def __init__(self, config, is_expert=False): - super(ParallelMLP, self).__init__() - args = get_args() - - self.add_bias = config.add_bias_linear - - ffn_hidden_size = config.ffn_hidden_size - if config.gated_linear_unit: - ffn_hidden_size *= 2 - - # Project to 4h. 
If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf - self.dense_h_to_4h = tensor_parallel.ColumnParallelLinear( - config.hidden_size, - ffn_hidden_size, - config=config, - init_method=config.init_method, - bias=self.add_bias, - gather_output=False, - skip_bias_add=True, - is_expert=is_expert, - ) - - self.bias_gelu_fusion = False - self.activation_func = None - self.swiglu = args.swiglu - - if args.openai_gelu: - self.activation_func = openai_gelu - elif args.onnx_safe: - self.activation_func = erf_gelu - elif args.swiglu: - def swiglu(x): - x = torch.chunk(x, 2, dim=-1) - return F.silu(x[0]) * x[1] - self.activation_func = swiglu - elif args.squared_relu: - def squared_relu(x): - return torch.pow(F.relu(x), 2) - self.activation_func = squared_relu - else: - self.bias_gelu_fusion = args.bias_gelu_fusion - self.activation_func = F.gelu - - # Project back to h. - self.dense_4h_to_h = tensor_parallel.RowParallelLinear( - config.ffn_hidden_size, - config.hidden_size, - config=config, - init_method=config.output_layer_init_method, - bias=self.add_bias, - skip_bias_add=True, - input_is_parallel=True, - is_expert=is_expert, - ) - - def forward(self, hidden_states): - - # [s, b, 4hp] - intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states) - - if self.bias_gelu_fusion: - assert self.add_bias is True - assert self.activation_func == F.gelu - intermediate_parallel = bias_gelu_impl(intermediate_parallel, bias_parallel) - else: - if bias_parallel is not None: - intermediate_parallel = intermediate_parallel + bias_parallel - intermediate_parallel = self.activation_func(intermediate_parallel) - - # [s, b, h] - output, output_bias = self.dense_4h_to_h(intermediate_parallel) - return output, output_bias - -def sinkhorn(cost, tol=0.0001): - cost = torch.exp(cost) - d0 = torch.ones(cost.size(0), device=cost.device, dtype=cost.dtype) - d1 = torch.ones(cost.size(1), device=cost.device, dtype=cost.dtype) - - eps = 0.00000001 - error = 1e9 - d1_old = d1 - while error > tol: - d0 = (1/d0.size(0))*1/(torch.sum(d1*cost,1) + eps) - d1 = (1/d1.size(0))*1/(torch.sum(d0.unsqueeze(1)*cost,0)+eps) - error = torch.mean(torch.abs(d1_old-d1)) - d1_old = d1 - return d1*cost*d0.unsqueeze(1) - - -def get_router_linear_layer(config): - args = get_args() - router = torch.nn.Linear(args.hidden_size, args.num_experts, bias=False) - with get_cuda_rng_tracker().fork(get_data_parallel_rng_tracker_name()): - config.init_method(router.weight) - setattr(router.weight, 'sequence_parallel',config.sequence_parallel) - return router - - -class SwitchMLP(MegatronModule): - """ - Routes input to one of N MLP "experts" - """ - def __init__(self, config): - super(SwitchMLP, self).__init__() - args = get_args() - self.router = get_router_linear_layer(config) - self.expert_parallel_size = mpu.get_expert_model_parallel_world_size() - self.sequence_parallel = config.sequence_parallel - self.add_bias = config.add_bias_linear - - assert args.num_experts % self.expert_parallel_size == 0 - self.num_local_experts = args.num_experts // self.expert_parallel_size - local_expert_indices_offset = mpu.get_expert_model_parallel_rank() * self.num_local_experts - self.local_expert_indices = [local_expert_indices_offset + i for i in range(self.num_local_experts)] - - self.local_experts = torch.nn.ModuleList() - for i in range(self.num_local_experts): - self.local_experts.append(ParallelMLP(config, is_expert=True)) - - def gather_indices(self, local_indices): - """ Gather tensors and concatinate along the first 
dimension.""" - group = get_tensor_and_expert_parallel_group() - world_size = torch.distributed.get_world_size(group=group) - # Bypass the function if we are using only 1 GPU. - if world_size == 1: - return local_indices - - dim_size = list(local_indices.size()) - dim_size[0] = dim_size[0] * world_size - - # TODO pre allocate memory - output = torch.empty(dim_size, dtype=local_indices.dtype, - device=torch.cuda.current_device()) - torch.distributed._all_gather_base( - output, local_indices.contiguous(), group=group - ) - return output - - def forward(self, hidden_states): - # hidden_states: [b, s, h] - args = get_args() - s = hidden_states.size(0) - b = hidden_states.size(1) - h = hidden_states.size(2) - route = self.router(hidden_states).view(-1, args.num_experts) - - # TODO (rprenger) Right now we're just using the sinkhorn algorithm - # for load balancing. There should be an option to do no load balancing - # and the algorithm and parametets should be further tested - if self.training: - with torch.no_grad(): - sinkroute = sinkhorn(route.detach().to(dtype=torch.float32)) - _, max_ind = torch.max(sinkroute, dim=1) - route = torch.sigmoid(route) - max_prob = route[torch.arange(route.size(0)), max_ind] - else: - route = torch.sigmoid(route) - max_prob, max_ind = torch.max(route, dim=1) - - max_prob = torch.unsqueeze(max_prob, 1) - hidden_states = hidden_states.view(-1, hidden_states.size(2)) - - # TODO (rprenger) TODO this could be made easier to read - # Converting [s, b, h] to [s*b, h]. - # Each vector could be routed differently - if self.sequence_parallel or (self.expert_parallel_size > 1): - global_hidden_states = \ - gather_from_sequence_parallel_region_to_moe(hidden_states) - global_indices = self.gather_indices(max_ind) - else: - global_hidden_states = hidden_states - global_indices = max_ind - - output_total = torch.zeros_like(global_hidden_states) - if self.add_bias: - output_bias_total = torch.zeros_like(global_hidden_states) - - for expert_num, expert in enumerate(self.local_experts): - local_expert_index = self.local_expert_indices[expert_num] - local_indices = (global_indices == local_expert_index).nonzero() - hidden = global_hidden_states[local_indices, :] - output, output_bias = expert(hidden) - output_total[local_indices, :] = output - if self.add_bias: - output_bias = output_bias.expand_as(output) - output_bias_total[local_indices, :] = output_bias - - if self.sequence_parallel or (self.expert_parallel_size > 1): - output_total = \ - reduce_scatter_to_sequence_parallel_region_from_moe(output_total) - if self.add_bias: - output_bias_total = \ - reduce_scatter_to_sequence_parallel_region_from_moe(output_bias_total) - - # bias is duplicated across tensor parallelism ranks; - # reduce scatter reduces bias across tensor parallel_ranks - output_bias_total = \ - output_bias_total/mpu.get_tensor_model_parallel_world_size() - - output_total = output_total*max_prob - output_total = output_total.view(s, b, h) - if self.add_bias: - output_bias_total = output_bias_total*max_prob - output_bias_total = output_bias_total.view(s, b, h) - else: - output_bias_total = None - - return output_total, output_bias_total - - -class CoreAttention(MegatronModule): - - def __init__(self, layer_number, config, - attn_mask_type=AttnMaskType.padding): - super(CoreAttention, self).__init__() - self.fp16 = config.fp16 - self.bf16 = config.bf16 - - self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling - self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32 - if 
self.apply_query_key_layer_scaling: - self.attention_softmax_in_fp32 = True - self.layer_number = max(1, layer_number) - self.attn_mask_type = attn_mask_type - self.sequence_parallel = config.sequence_parallel - - projection_size = config.kv_channels * config.num_attention_heads - - # Per attention head and per partition values. - world_size = mpu.get_tensor_model_parallel_world_size() - self.hidden_size_per_partition = core.utils.divide(projection_size, - world_size) - self.hidden_size_per_attention_head = core.utils.divide( - projection_size, config.num_attention_heads) - self.num_attention_heads_per_partition = core.utils.divide( - config.num_attention_heads, world_size) - - coeff = None - self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) - if self.apply_query_key_layer_scaling: - coeff = self.layer_number - self.norm_factor *= coeff - - self.scale_mask_softmax = FusedScaleMaskSoftmax( - self.fp16, self.bf16, - self.attn_mask_type, - config.masked_softmax_fusion, - attention_mask_func, - self.attention_softmax_in_fp32, - coeff) - - # Dropout. Note that for a single iteration, this layer will generate - # different outputs on different number of parallel partitions but - # on average it should not be partition dependent. - self.attention_dropout = torch.nn.Dropout(config.attention_dropout) - - def forward(self, query_layer, key_layer, - value_layer, attention_mask): - - # =================================== - # Raw attention scores. [b, np, s, s] - # =================================== - - # [b, np, sq, sk] - output_size = (query_layer.size(1), - query_layer.size(2), - query_layer.size(0), - key_layer.size(0)) - - # [sq, b, np, hn] -> [sq, b * np, hn] - query_layer = query_layer.reshape(output_size[2], - output_size[0] * output_size[1], -1) - # [sk, b, np, hn] -> [sk, b * np, hn] - key_layer = key_layer.view(output_size[3], - output_size[0] * output_size[1], -1) - - # preallocting input tensor: [b * np, sq, sk] - matmul_input_buffer = mpu.get_global_memory_buffer().get_tensor( - (output_size[0]*output_size[1], output_size[2], output_size[3]), - query_layer.dtype, "mpu") - - # Raw attention scores. [b * np, sq, sk] - matmul_result = torch.baddbmm( - matmul_input_buffer, - query_layer.transpose(0, 1), # [b * np, sq, hn] - key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] - beta=0.0, alpha=(1.0/self.norm_factor)) - - # change view to [b, np, sq, sk] - attention_scores = matmul_result.view(*output_size) - - # =========================== - # Attention probs and dropout - # =========================== - - # attention scores and attention mask [b, np, sq, sk] - attention_probs = self.scale_mask_softmax(attention_scores, - attention_mask) - - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. - if not self.sequence_parallel: - with tensor_parallel.get_cuda_rng_tracker().fork(): - attention_probs = self.attention_dropout(attention_probs) - else: - attention_probs = self.attention_dropout(attention_probs) - - # ========================= - # Context layer. [sq, b, hp] - # ========================= - - # value_layer -> context layer. 
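CoreAttention above spells out unfused attention in Megatron's [s, b, np, hn] layout: fold batch and heads together, take scaled Q·Kᵀ with baddbmm, apply masked softmax and dropout, then bmm against V. A compact sketch of the same math written in the plain [b, heads, s, d] layout (shapes and names below are illustrative, and masking is omitted):

```python
import math
import torch

b, heads, s, d = 2, 4, 16, 8
q = torch.randn(b, heads, s, d)
k = torch.randn(b, heads, s, d)
v = torch.randn(b, heads, s, d)

# Raw scores scaled by 1/sqrt(head_dim): [b, heads, s, s].
scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(d)

# Softmax over the key dimension, then dropout on the attention probabilities.
probs = torch.softmax(scores, dim=-1)
probs = torch.nn.functional.dropout(probs, p=0.1, training=True)

# Weighted sum of values: [b, heads, s, d].
context = torch.matmul(probs, v)
print(context.shape)
```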
- # [sk, b, np, hn] --> [b, np, sq, hn] - - # context layer shape: [b, np, sq, hn] - output_size = (value_layer.size(1), - value_layer.size(2), - query_layer.size(0), - value_layer.size(3)) - - # change view [sk, b * np, hn] - value_layer = value_layer.view(value_layer.size(0), - output_size[0] * output_size[1], -1) - - # change view [b * np, sq, sk] - attention_probs = attention_probs.view(output_size[0] * output_size[1], - output_size[2], -1) - - # matmul: [b * np, sq, hn] - context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1)) - - # change view [b, np, sq, hn] - context_layer = context_layer.view(*output_size) - - # [b, np, sq, hn] --> [sq, b, np, hn] - context_layer = context_layer.permute(2, 0, 1, 3).contiguous() - - # [sq, b, np, hn] --> [sq, b, hp] - new_context_layer_shape = context_layer.size()[:-2] + \ - (self.hidden_size_per_partition,) - context_layer = context_layer.view(*new_context_layer_shape) - - return context_layer - - -class FlashSelfAttention(torch.nn.Module): - """Implement the scaled dot product attention with softmax. - Arguments - --------- - softmax_scale: The temperature to use for the softmax attention. - (default: 1/sqrt(d_keys) where d_keys is computed at - runtime) - attention_dropout: The dropout rate to apply to the attention - (default: 0.0) - """ - def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0, - device=None, dtype=None): - super().__init__() - assert flash_attn_unpadded_func is not None, ('Please install FlashAttention first, ' - 'e.g., with pip install flash-attn') - assert rearrange is not None, 'Please install einops first, e.g., with pip install einops' - self.causal = causal - self.softmax_scale = softmax_scale - self.dropout_p = attention_dropout - - def forward(self, q, k, v): - """Implements the multihead softmax attention. - Arguments - --------- - q, k, v: The tensor containing the query, key, and value. (B, S, H, D) - """ - - assert all((i.dtype in [torch.float16, torch.bfloat16] for i in (q,k,v))) - assert all((i.is_cuda for i in (q,k,v))) - - batch_size, seqlen_q = q.shape[0], q.shape[1] - seqlen_k = k.shape[1] - - q, k, v = [rearrange(x, 'b s ... -> (b s) ...') for x in [q, k, v]] - cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, - device=q.device) - - if self.training: - # during training q,k,v always have same seqlen - assert seqlen_k == seqlen_q - - is_causal = self.causal - cu_seqlens_k = cu_seqlens_q - dropout_p = self.dropout_p - else: - # turn off FA causal mask after first inference autoregressive iteration - # only on first autoregressive step q,k,v have same seqlen - is_causal = seqlen_q == seqlen_k - cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, - device=q.device) - dropout_p = 0 - - output = flash_attn_unpadded_func( - q, k, v, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k, - dropout_p, - softmax_scale=self.softmax_scale, causal=is_causal - ) - - output = rearrange(output, '(b s) ... -> b s ...', b=batch_size) - return output - - -class ParallelAttention(MegatronModule): - """Parallel self-attention layer abstract class. - - Self-attention layer takes input with size [s, b, h] - and returns output of the same size. 
- """ - - def __init__(self, config, layer_number, - attention_type=AttnType.self_attn, - attn_mask_type=AttnMaskType.padding): - super(ParallelAttention, self).__init__() - args = get_args() - self.layer_number = max(1, layer_number) - self.attention_type = attention_type - self.attn_mask_type = attn_mask_type - self.params_dtype = config.params_dtype - self.sequence_parallel = config.sequence_parallel - - self.group_query_attention = args.group_query_attention - self.num_query_groups = args.num_query_groups - - query_projection_size = config.kv_channels * config.num_attention_heads - if self.group_query_attention: - kv_projection_size = args.kv_channels * args.num_query_groups - else: - kv_projection_size = args.kv_channels * args.num_attention_heads - - self.use_flash_attn = args.use_flash_attn \ - and attention_type == AttnType.self_attn \ - and self.attn_mask_type == AttnMaskType.causal - if self.use_flash_attn: - if flash_attn_unpadded_func is None: - raise ImportError('FlashAttention is not installed, please install with ' - 'pip install flash-attn') - assert attention_type == AttnType.self_attn, ('FlashAttention code path only supports ' - 'self-attention for now') - assert self.attn_mask_type == AttnMaskType.causal, ('FlashAttention code path only ' - 'supports causal mask for now') - if rearrange is None: - raise ImportError('einops is not installed, please install with pip install einops') - - # Per attention head and per partition values. - world_size = mpu.get_tensor_model_parallel_world_size() - self.hidden_size_per_attention_head = core.utils.divide( - query_projection_size, config.num_attention_heads) - self.num_attention_heads_per_partition = core.utils.divide( - config.num_attention_heads, world_size) - - if self.group_query_attention: - if args.num_query_groups % world_size != 0: - raise NotImplementedError('Currently the num_query_groups should be ' - 'a multiple of the tensor parallel size') - self.num_query_groups_per_partition = core.utils.divide( - args.num_query_groups, world_size) - else: - self.num_query_groups_per_partition = self.num_attention_heads_per_partition - - # Strided linear layer. - if attention_type == AttnType.self_attn: - self.query_key_value = tensor_parallel.ColumnParallelLinear( - config.hidden_size, - query_projection_size + 2 * kv_projection_size, - config=config, - init_method=config.init_method, - bias=args.add_bias_linear, - gather_output=False) - else: - assert attention_type == AttnType.cross_attn - - if self.group_query_attention: - raise NotImplementedError("Grouped query attention not implemented for cross-attention.") - assert query_projection_size == kv_projection_size - - self.query = tensor_parallel.ColumnParallelLinear( - config.hidden_size, - query_projection_size, - config=config, - init_method=config.init_method, - bias=config.add_bias_linear, - gather_output=False) - - self.key_value = tensor_parallel.ColumnParallelLinear( - config.hidden_size, - 2 * kv_projection_size, - config=config, - init_method=config.init_method, - bias=config.add_bias_linear, - gather_output=False) - - self.core_attention = CoreAttention(self.layer_number, config, - self.attn_mask_type) - self.checkpoint_core_attention = config.recompute_granularity == 'selective' - - if self.use_flash_attn: - self.core_attention_flash = FlashSelfAttention( - causal=True, attention_dropout=config.attention_dropout - ) - - # Output. 
- self.dense = tensor_parallel.RowParallelLinear( - query_projection_size, - config.hidden_size, - config=config, - init_method=config.output_layer_init_method, - bias=args.add_bias_linear, - input_is_parallel=True, - skip_bias_add=True) - - def _checkpointed_attention_forward(self, query_layer, key_layer, - value_layer, attention_mask, - rotary_pos_emb=None): - """Forward method with activation checkpointing.""" - def custom_forward(*inputs): - query_layer = inputs[0] - key_layer = inputs[1] - value_layer = inputs[2] - attention_mask = inputs[3] - output_ = self.core_attention(query_layer, key_layer, - value_layer, attention_mask) - return output_ - - q_pos_emb, k_pos_emb = (None, None) if rotary_pos_emb is None \ - else rotary_pos_emb - - hidden_states = tensor_parallel.checkpoint( - custom_forward, - False, query_layer, key_layer, value_layer, attention_mask, - q_pos_emb, k_pos_emb) - - return hidden_states - - def _allocate_memory(self, inference_max_sequence_len, batch_size, num_attention_heads): - return torch.empty( - inference_max_sequence_len, - batch_size, - num_attention_heads, - self.hidden_size_per_attention_head, - dtype=self.params_dtype, - device=torch.cuda.current_device()) - - def forward(self, hidden_states, attention_mask, - encoder_output=None, inference_params=None, - rotary_pos_emb=None): - # hidden_states: [sq, b, h] - - # ================================================= - # Pre-allocate memory for key-values for inference. - # ================================================= - is_first_step = False - if inference_params: - if self.layer_number not in inference_params.key_value_memory_dict: - inf_max_seq_len = inference_params.max_sequence_length - inf_max_batch_size = inference_params.max_batch_size - inference_key_memory = self._allocate_memory( - inf_max_seq_len, inf_max_batch_size, - self.num_query_groups_per_partition) - inference_value_memory = self._allocate_memory( - inf_max_seq_len, inf_max_batch_size, - self.num_query_groups_per_partition) - - inference_params.key_value_memory_dict[self.layer_number] = ( - inference_key_memory, inference_value_memory) - is_first_step = True - else: - inference_key_memory, inference_value_memory = \ - inference_params.key_value_memory_dict[self.layer_number] - - # ===================== - # Query, Key, and Value - # ===================== - if self.attention_type == AttnType.self_attn: - - # Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn)] - mixed_x_layer, _ = self.query_key_value(hidden_states) - - # [sq, b, hp] --> [sq, b, ng, (np/ng + 2) * hn] - new_tensor_shape = mixed_x_layer.size()[:-1] + ( - self.num_query_groups_per_partition, - ( - (self.num_attention_heads_per_partition // self.num_query_groups_per_partition + 2) - * self.hidden_size_per_attention_head - ), - ) - mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) - - # [sq, b, ng, (np/ng + 2) * hn] --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn] - (query_layer, - key_layer, - value_layer) = torch.split( - mixed_x_layer, - [ - ( - self.num_attention_heads_per_partition // self.num_query_groups_per_partition - * self.hidden_size_per_attention_head - ), - self.hidden_size_per_attention_head, - self.hidden_size_per_attention_head - ], - dim=3) - - # [sq, b, ng, np/ng * hn] -> [sq, b, np, hn] - - query_layer = query_layer.view(query_layer.size(0), query_layer.size(1), -1, self.hidden_size_per_attention_head) - else: - # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)] - mixed_kv_layer, _ = self.key_value(encoder_output) - - # [sk, b, 
(np * 2 * hn)] --> [sk, b, np, 2 * hn] - new_tensor_shape = mixed_kv_layer.size()[:-1] + \ - (self.num_attention_heads_per_partition, - 2 * self.hidden_size_per_attention_head) - mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape) - - # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn] - (key_layer, - value_layer) = tensor_parallel.split_tensor_along_last_dim(mixed_kv_layer, 2) - - # Attention head [sq, b, h] --> [sq, b, hp] - query_layer, _ = self.query(hidden_states) - # [sq, b, hp] --> [sq, b, np, hn] - new_tensor_shape = query_layer.size()[:-1] + \ - (self.num_attention_heads_per_partition, - self.hidden_size_per_attention_head) - query_layer = query_layer.view(*new_tensor_shape) - - # ================================== - # Adjust key and value for inference - # ================================== - - # duplicate the pos_emb for self attention - if rotary_pos_emb is not None: - if isinstance(rotary_pos_emb, tuple): - rotary_pos_emb = rotary_pos_emb - else: - rotary_pos_emb = ((rotary_pos_emb,) * 2) - - if inference_params: - batch_start = inference_params.batch_size_offset - batch_end = batch_start + key_layer.size(1) - assert batch_end <= inference_key_memory.size(1) - sequence_start = inference_params.sequence_len_offset - sequence_end = sequence_start + key_layer.size(0) - assert sequence_end <= inference_key_memory.size(0) - # Copy key and values. - inference_key_memory[sequence_start:sequence_end, - batch_start:batch_end, ...] = key_layer - inference_value_memory[sequence_start:sequence_end, - batch_start:batch_end, ...] = value_layer - key_layer = inference_key_memory[ - :sequence_end, batch_start:batch_end, ...] - value_layer = inference_value_memory[ - :sequence_end, batch_start:batch_end, ...] - - - # adjust the key rotary positional embedding - if rotary_pos_emb is not None: - q_pos_emb, k_pos_emb = rotary_pos_emb - # need to cross check this condition during inference - # if not set_inference_key_value_memory: - if not is_first_step: - # In inference, we compute one token at a time. - # Select the correct positional embedding - # (only the last token in the sequence) - q_pos_emb = q_pos_emb[sequence_end - 1 : sequence_end] - else: - # In the first forward pass of inference, - # we use the entire provided prefix. - # q_pos_emb here has the rope embeddings of the entire - # prefix + to-be-generated output so - # we slice to just the prefix. - q_pos_emb = q_pos_emb[:sequence_end, :, :, :] - k_pos_emb = k_pos_emb[:sequence_end, :, :, :] - rotary_pos_emb = (q_pos_emb, k_pos_emb) - - # ================================== - # core attention computation - # ================================== - - # expand the key_layer and value_layer [sk, b, ng, hn] -> [sk, b, np, hn] - if self.num_attention_heads_per_partition // self.num_query_groups_per_partition > 1: - key_layer = key_layer.repeat_interleave( - self.num_attention_heads_per_partition // self.num_query_groups_per_partition, - dim = 2 - ) - value_layer = value_layer.repeat_interleave( - self.num_attention_heads_per_partition // self.num_query_groups_per_partition, - dim = 2 - ) - - # apply relative positional encoding (rotary embedding) - if rotary_pos_emb is not None: - q_pos_emb, k_pos_emb = rotary_pos_emb - query_layer = apply_rotary_pos_emb(query_layer, q_pos_emb) - key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb) - # TODO, can apply positional embedding to value_layer so it has - # absolute positional embedding. 
- # otherwise, only relative positional embedding takes effect - # value_layer = apply_rotary_pos_emb(value_layer, k_pos_emb) - - if not self.use_flash_attn: - if self.checkpoint_core_attention: - context_layer = self._checkpointed_attention_forward( - query_layer, key_layer, value_layer, attention_mask) - else: - context_layer = self.core_attention( - query_layer, key_layer, value_layer, attention_mask) - else: - q, k, v = [rearrange(x, 's b ... -> b s ...').contiguous() - for x in (query_layer, key_layer, value_layer)] - if not self.sequence_parallel: - with tensor_parallel.get_cuda_rng_tracker().fork(): - context_layer = self.core_attention_flash(q, k, v) - else: - context_layer = self.core_attention_flash(q, k, v) - context_layer = rearrange(context_layer, 'b s h d -> s b (h d)').contiguous() - - # ================= - # Output. [sq, b, h] - # ================= - - output, bias = self.dense(context_layer) - - return output, bias - - -def bias_dropout_add(x, bias, residual, prob, training): - # type: (Tensor, Optional[Tensor], Tensor, float, bool) -> Tensor - if bias is not None: - x = x + bias - out = torch.nn.functional.dropout(x, p=prob, training=training) - out = residual + out - return out - - -def get_bias_dropout_add(training): - def _bias_dropout_add(x, bias, residual, prob): - return bias_dropout_add(x, bias, residual, prob, training) - return _bias_dropout_add - - -@torch.jit.script -def bias_dropout_add_fused_train(x: torch.Tensor, - bias: Optional[torch.Tensor], - residual: torch.Tensor, - prob: float) -> torch.Tensor: - return bias_dropout_add(x, bias, residual, prob, True) - - -@torch.jit.script -def bias_dropout_add_fused_inference(x: torch.Tensor, - bias: Optional[torch.Tensor], - residual: torch.Tensor, - prob: float) -> torch.Tensor: - return bias_dropout_add(x, bias, residual, prob, False) - - -class ParallelTransformerLayer(MegatronModule): - """A single transformer layer. - - Transformer layer takes input with size [s, b, h] and returns an - output of the same size. - """ - - def __init__(self, config, - layer_number, layer_type=LayerType.encoder, - self_attn_mask_type=AttnMaskType.padding, - drop_path_rate=0.): - args = get_args() - - super(ParallelTransformerLayer, self).__init__() - self.layer_number = layer_number - self.layer_type = layer_type - - self.apply_residual_connection_post_norm \ - = config.apply_residual_connection_post_layernorm - - self.bf16 = config.bf16 - self.fp32_residual_connection = config.fp32_residual_connection - - # Normalize the input data. - self.input_norm = get_norm(config) - - # Self attention. - self.self_attention = ParallelAttention( - config, - layer_number, - attention_type=AttnType.self_attn, - attn_mask_type=self_attn_mask_type) - self.hidden_dropout = config.hidden_dropout - self.bias_dropout_fusion = config.bias_dropout_fusion - self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else None - - # Normalize the attention output - self.post_attention_norm = get_norm(config) - - # Cross attention. - if self.layer_type in (LayerType.decoder, - LayerType.retro_decoder, - LayerType.retro_decoder_with_retriever, - LayerType.retro_encoder): - self.inter_attention = ParallelAttention( - config, - layer_number, - attention_type=AttnType.cross_attn) - # Normalize the attention output. - self.post_inter_attention_norm = get_norm(config) - - # MLP - if args.num_experts is not None: - self.mlp = SwitchMLP(config) - else: - self.mlp = ParallelMLP(config) - - # Set bias+dropout+add fusion grad_enable execution handler. 
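The bias_dropout_add helpers removed above fold the bias add, dropout, and residual add of each sub-layer into one function so the elementwise chain can be jit-scripted (and fused by nvFuser on recent PyTorch). A minimal sketch of the unfused pattern they implement, called on dummy [s, b, h] tensors:

```python
from typing import Optional
import torch

def bias_dropout_add(x: torch.Tensor,
                     bias: Optional[torch.Tensor],
                     residual: torch.Tensor,
                     prob: float,
                     training: bool) -> torch.Tensor:
    # Add the (optional) bias, apply dropout, then add the residual branch.
    if bias is not None:
        x = x + bias
    out = torch.nn.functional.dropout(x, p=prob, training=training)
    return residual + out

attn_out = torch.randn(16, 2, 64)   # [s, b, h]
residual = torch.randn(16, 2, 64)
print(bias_dropout_add(attn_out, None, residual, prob=0.1, training=True).shape)
```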
- TORCH_MAJOR = int(torch.__version__.split('.')[0]) - TORCH_MINOR = int(torch.__version__.split('.')[1]) - use_nvfuser = TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10) - self.bias_dropout_add_exec_handler = \ - nullcontext if use_nvfuser else torch.enable_grad - - if args.retro_add_retriever: - retro_args = get_retro_args() - self.retro_num_neighbors = args.retro_num_neighbors - self.retro_chunk_length = retro_args.retro_gpt_chunk_length - self.retro_retrieved_length = retro_args.retro_gpt_retrieved_length - - # Retriever (bi-directional transformer with cross attention) - if layer_type == LayerType.retro_decoder_with_retriever: - self.retriever = ParallelTransformer( - config=config, - model_type=ModelType.retro_encoder, - self_attn_mask_type=AttnMaskType.padding, - pre_process=True, - post_process=False, - ) - self._retriever_key = 'retriever' - else: - self.retriever = None - - def default_decoder_cross_attention(self, - encoder_output, - enc_dec_attn_mask, - norm_input, - norm_output, - bias_dropout_add_func): - '''Cross attention for a standard encoder-decoder model.''' - - # Attention. - attention_output, attention_bias = \ - self.inter_attention(norm_output, - enc_dec_attn_mask, - encoder_output=encoder_output) - - # Residual connection. - if self.apply_residual_connection_post_norm: - residual = norm_output - else: - residual = norm_input - - if attention_bias is not None: - attention_bias = attention_bias.expand_as(residual) - - # Bias-dropout-add. - with self.bias_dropout_add_exec_handler(): - norm_input = bias_dropout_add_func( - attention_output, - attention_bias, - residual, - self.hidden_dropout) - - # Normalize. - norm_output = self.post_inter_attention_norm(norm_input) - - return norm_input, norm_output - - def retro_encoder_cross_attention(self, - retriever_output, - norm_input, - norm_output, - bias_dropout_add_func): - """Cross attention for Retro encoder. - - Notation: - ns : Sequence length. - bs : Batch size. - d : Hidden size. - l : Number of chunks per sample (i.e., seq_length/chunk_length). - k : Number of neighbors. - r : Number of retrieved tokens (neighbors + continuation). - """ - - ns, bs, d = norm_output.shape # [r, bs * l * k, d] - - # Divide sequence dimension into chunks. - chunked_outputs = norm_output.reshape(self.retro_retrieved_length, - -1, - self.retro_num_neighbors, - d) - chunked_outputs_before_norm = \ - norm_input.reshape(self.retro_retrieved_length, -1, - self.retro_num_neighbors, d) # [r, bs*l, k, d] - - # Per-chunk attention. - norm_inputs = [] - norm_outputs = [] - for k in range(self.retro_num_neighbors): - - # Attention. - chunked_output = chunked_outputs[:,:,k].contiguous() - attention_output, attention_bias = \ - self.inter_attention( - chunked_output, # Q (neighbor embedding) - None, - encoder_output=retriever_output) # K, V (hidden act) - - # Residual connection. - if self.apply_residual_connection_post_norm: - residual = chunked_output - else: - residual = chunked_outputs_before_norm[:,:,k] - - # Re-enable torch grad to enable fused optimization. - with torch.enable_grad(): - norm_input = bias_dropout_add_func( - attention_output, - None if attention_bias is None else attention_bias.expand_as(residual), - residual, - self.hidden_dropout) - norm_inputs.append(norm_input) - - # Layer norm. - norm_output = self.post_inter_attention_norm(norm_input) - norm_outputs.append(norm_output) - - # Concatenate layer norms. 
- # norm_input : [r, k * bs * l, d] - # norm_output : [r, k * bs * l, d] - norm_input = torch.stack(norm_inputs, dim=1).reshape(ns, bs, d) - norm_output = torch.stack(norm_outputs, dim=1).reshape(ns, bs, d) - - return norm_input, norm_output - - def retro_decoder_cross_attention(self, - retriever_input, - retriever_output, - retriever_attn_mask, - norm_input, - norm_output, - inference_params, - bias_dropout_add_func): - """Cross attention for Retro decoder. - - Notation: - ns : Sequence length. - bs : Batch size. - d : Hidden size. - l : Number of chunks per sample (i.e., seq_length/chunk_length). - m : Number of tokens per chunk. - k : Number of neighbors. - r : Number of retrieved tokens (neighbors + continuation). - """ - - ns, bs, d = norm_output.shape - l = int(np.ceil(ns / self.retro_chunk_length)) - - # Retrieve neighbors. - if self.layer_type == LayerType.retro_decoder_with_retriever: - first_ns = ns % self.retro_chunk_length - if first_ns > 0: - raise Exception("test this case.") - first_chunk, rest_chunk = \ - norm_output[:first_ns], norm_output[first_ns:] - first_chunk = torch.nn.functional.pad( - first_chunk, - (0, 0, 0, 0, 0, self.retro_chunk_length - first_ns), - 'constant', - 0) - chunked_output = \ - torch.cat((first_chunk, rest_chunk), dim=0) # [l * m, bs, d] - else: - chunked_output = norm_output # [l * m, bs, d] - chunked_output = chunked_output \ - .reshape(l, self.retro_chunk_length, bs, d) \ - .permute(1, 2, 0, 3) \ - .reshape(self.retro_chunk_length, bs * l, d) \ - .contiguous() - - # Get Encoder Output - retriever_output = self.retriever( - hidden_states=retriever_input, - attention_mask=retriever_attn_mask, - retriever_output=chunked_output, - retriever_attn_mask=retriever_attn_mask, - inference_params=inference_params) # [r, k * bs * l , d] - retriever_output = retriever_output.reshape( - self.retro_retrieved_length * self.retro_num_neighbors, bs * l, d) # [r * k, bs * l, d] - - # Chunks. - pad = (ns - 1) % self.retro_chunk_length - attending_chunks = norm_output[pad:] - padded_chunks = torch.nn.functional.pad( - attending_chunks, - (0, 0, 0, 0, 0, self.retro_chunk_length - 1), - 'constant', 0) - padded_chunked_output = padded_chunks \ - .reshape(l, self.retro_chunk_length, bs, d) \ - .permute(1, 2, 0, 3) - padded_chunked_output = padded_chunked_output.reshape( - self.retro_chunk_length, bs * l, d).contiguous() - - # Encoder output. - attention_output, attention_bias = \ - self.inter_attention(padded_chunked_output, - None, - encoder_output=retriever_output) - - # Residual connection. - if self.apply_residual_connection_post_norm: - residual = norm_output - else: - residual = norm_input - - # Re-enable torch grad to enable fused optimization. 
- with torch.enable_grad(): - norm_input = bias_dropout_add_func( - attention_output, - None if attention_bias is None else attention_bias.expand_as(attention_output), - torch.zeros_like(attention_output), - self.hidden_dropout) - norm_input = norm_input \ - .reshape(self.retro_chunk_length, bs, l, d) \ - .permute(2, 0, 1, 3) # [l, m, bs, d] - norm_input = norm_input.reshape(self.retro_chunk_length * l, bs, d) - norm_input = torch.nn.functional.pad( - norm_input, - (0, 0, 0, 0, pad, 0), - 'constant', 0)[:ns] # [ns, b, d] - norm_input = norm_input + residual - - # Layer norm post the decoder attention - norm_output = self.post_inter_attention_norm(norm_input) - - return retriever_output, norm_input, norm_output - - def forward(self, hidden_states, attention_mask, - encoder_output=None, enc_dec_attn_mask=None, - retriever_input=None, - retriever_output=None, - retriever_attn_mask=None, - inference_params=None, - rotary_pos_emb=None): - # hidden_states: [s, b, h] - - # Layer norm at the beginning of the transformer layer. - norm_output = self.input_norm(hidden_states) - - # Self attention. - attention_output, attention_bias = \ - self.self_attention( - norm_output, - attention_mask, - inference_params=inference_params, - rotary_pos_emb=rotary_pos_emb) - - # Residual connection. - if self.apply_residual_connection_post_norm: - residual = norm_output - else: - residual = hidden_states - - if self.drop_path is None: - # jit scripting for a nn.module (with dropout) is not - # trigerring the fusion kernel. For now, we use two - # different nn.functional routines to account for varying - # dropout semantics during training and inference phases. - if self.bias_dropout_fusion: - if self.training: - bias_dropout_add_func = bias_dropout_add_fused_train - else: - bias_dropout_add_func = bias_dropout_add_fused_inference - else: - bias_dropout_add_func = get_bias_dropout_add(self.training) - - if attention_bias is not None: - attention_bias = attention_bias.expand_as(residual) - with self.bias_dropout_add_exec_handler(): - norm_input = bias_dropout_add_func( - attention_output, - attention_bias, - residual, - self.hidden_dropout) - else: - out = torch.nn.functional.dropout(attention_output + attention_bias, - p=self.hidden_dropout, - training=self.training) - norm_input = residual + self.drop_path(out) - - # Layer norm post the self attention. - norm_output = self.post_attention_norm(norm_input) - - # Cross attention. - if self.layer_type == LayerType.encoder: - pass - elif self.layer_type == LayerType.decoder: - norm_input, norm_output = \ - self.default_decoder_cross_attention( - encoder_output, - enc_dec_attn_mask, - norm_input, - norm_output, - bias_dropout_add_func) - elif self.layer_type == LayerType.retro_encoder: - norm_input, norm_output = \ - self.retro_encoder_cross_attention( - retriever_output, - norm_input, - norm_output, - bias_dropout_add_func) - elif self.layer_type in (LayerType.retro_decoder, - LayerType.retro_decoder_with_retriever): - retriever_output, norm_input, norm_output = \ - self.retro_decoder_cross_attention( - retriever_input, - retriever_output, - retriever_attn_mask, - norm_input, - norm_output, - inference_params, - bias_dropout_add_func) - else: - raise Exception("Unsupported layer type, '%s'." % - self.layer_type.name) - - # MLP. - mlp_output, mlp_bias = self.mlp(norm_output) - - # Second residual connection. 
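The layer forward above follows the standard pre-norm residual layout: normalize, attend, dropout-add back into the residual stream, normalize again, run the MLP, and add the second residual. A condensed skeleton of that control flow using stock PyTorch modules as stand-ins (the class and module names here are hypothetical, not the Megatron implementation):

```python
import torch
from torch import nn

class PreNormBlock(nn.Module):
    """Stand-in pre-norm block: norm -> attention -> add, norm -> MLP -> add."""
    def __init__(self, hidden: int, dropout: float = 0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden)
        self.attn = nn.MultiheadAttention(hidden, num_heads=4, batch_first=True)
        self.norm2 = nn.LayerNorm(hidden)
        self.mlp = nn.Sequential(nn.Linear(hidden, 4 * hidden), nn.GELU(),
                                 nn.Linear(4 * hidden, hidden))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.norm1(x)
        attn_out, _ = self.attn(h, h, h)
        x = x + self.dropout(attn_out)                  # first residual connection
        x = x + self.dropout(self.mlp(self.norm2(x)))   # second residual connection
        return x

print(PreNormBlock(32)(torch.randn(2, 16, 32)).shape)
```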
- if self.apply_residual_connection_post_norm: - residual = norm_output - else: - residual = norm_input - - if self.drop_path is None: - if mlp_bias is not None: - mlp_bias = mlp_bias.expand_as(residual) - with self.bias_dropout_add_exec_handler(): - output = bias_dropout_add_func( - mlp_output, - mlp_bias, - residual, - self.hidden_dropout) - - # Jit compiled function creates 'view' tensor. This tensor - # potentially gets saved in the MPU checkpoint function context, - # which rejects view tensors. While making a viewless tensor here - # won't result in memory savings (like the data loader, or - # p2p_communication), it serves to document the origin of this - # 'view' tensor. - output = core.utils.make_viewless_tensor(inp = output, - requires_grad = output.requires_grad, - keep_graph = True) - - else: - if mlp_bias is not None: - mlp_output = mlp_output + mlp_bias - out = torch.nn.functional.dropout(mlp_output, - p=self.hidden_dropout, - training=self.training) - output = residual + self.drop_path(out) - - if self.layer_type == LayerType.retro_decoder_with_retriever: - return output, retriever_output - else: - return output - - -class NoopTransformerLayer(MegatronModule): - """A single 'no-op' transformer layer. - - The sole purpose of this layer is for when a standalone embedding layer - is used (i.e., args.standalone_embedding_stage == True). In this case, - zero transformer layers are assigned when pipeline rank == 0. Additionally, - when virtual pipeline rank >= 1, zero total model parameters are created - (virtual rank 0 contains the input embedding). This results in the model's - input and output tensors being the same, which causes an error when - performing certain memory optimiations on the output tensor (e.g., - deallocating it). Thus, this layer disconnects the input from the output - via a clone. Since ranks containing a no-op layer are generally under- - utilized (both compute and memory), there's no worry of any performance - degredation. - """ - - def __init__(self, layer_number): - super().__init__() - self.layer_number = layer_number - - def forward(self, hidden_states, attention_mask, - encoder_output=None, enc_dec_attn_mask=None, - inference_params=None): - return hidden_states.clone() - - -def _get_num_layers(args, model_type, is_decoder=False): - """Compute the number of transformer layers resident on the current rank.""" - is_encoder_and_decoder_model = (model_type == ModelType.encoder_and_decoder) - if model_type == ModelType.retro_encoder: - num_layers = args.retro_encoder_layers - elif mpu.get_pipeline_model_parallel_world_size() > 1: - if is_encoder_and_decoder_model: - assert args.pipeline_model_parallel_split_rank is not None - - # When a standalone embedding stage is used, a rank is taken from - # the encoder's ranks, to be used for the encoder's embedding - # layer. This way, the rank referenced by the 'split rank' remains - # the same whether or not a standalone embedding stage is used. 
- num_ranks_in_encoder = ( - args.pipeline_model_parallel_split_rank - 1 - if args.standalone_embedding_stage else - args.pipeline_model_parallel_split_rank - ) - num_ranks_in_decoder = args.transformer_pipeline_model_parallel_size - num_ranks_in_encoder - assert args.encoder_num_layers % num_ranks_in_encoder == 0, \ - 'encoder_num_layers (%d) must be divisible by number of ranks given to encoder (%d)' % (args.encoder_num_layers, num_ranks_in_encoder) - assert args.decoder_num_layers % num_ranks_in_decoder == 0, \ - 'decoder_num_layers (%d) must be divisible by number of ranks given to decoder (%d)' % (args.decoder_num_layers, num_ranks_in_decoder) - if mpu.is_pipeline_stage_before_split(): - num_layers = ( - 0 - if args.standalone_embedding_stage - and mpu.get_pipeline_model_parallel_rank() == 0 else - args.encoder_num_layers // num_ranks_in_encoder - ) - else: - num_layers = args.decoder_num_layers // num_ranks_in_decoder - else: - assert args.num_layers == args.encoder_num_layers - assert args.num_layers % args.transformer_pipeline_model_parallel_size == 0, \ - 'num_layers must be divisible by transformer_pipeline_model_parallel_size' - - # When a standalone embedding stage is used, all transformer layers - # are divided among pipeline rank >= 1, while on pipeline rank 0, - # ranks either contain the input embedding layer (virtual pp rank 0), - # or no layers at all (virtual pp rank >= 1). - num_layers = ( - 0 - if args.standalone_embedding_stage - and mpu.get_pipeline_model_parallel_rank() == 0 else - args.num_layers // args.transformer_pipeline_model_parallel_size - ) - else: - if not is_decoder: - num_layers = args.encoder_num_layers - else: - num_layers = args.decoder_num_layers - return num_layers - - -def _get_layer_type(model_type, default_layer_type, retro_layer_numbers, - layer_number): - args = get_args() - if args.retro_add_retriever and layer_number in retro_layer_numbers: - if model_type == ModelType.retro_decoder: - return LayerType.retro_decoder_with_retriever \ - if layer_number == retro_layer_numbers[0] \ - else LayerType.retro_decoder - elif model_type == ModelType.retro_encoder: - return LayerType.retro_encoder - else: - raise Exception("Unsupported model type, '%s'." % model_type) - else: - return default_layer_type - - -class ParallelTransformer(MegatronModule): - """Transformer class.""" - - def __init__(self, config, - model_type, layer_type=LayerType.encoder, - self_attn_mask_type=AttnMaskType.padding, - post_norm=True, - pre_process=True, - post_process=True, - drop_path_rate=0.0): - super(ParallelTransformer, self).__init__() - args = get_args() - - self.layer_type = layer_type - self.model_type = model_type - self.bf16 = config.bf16 - self.fp32_residual_connection = config.fp32_residual_connection - self.post_norm = post_norm - self.pre_process = pre_process - self.post_process = post_process - self.input_tensor = None - self.drop_path_rate = drop_path_rate - self.transformer_impl = args.transformer_impl - self.retro_add_retriever = args.retro_add_retriever - - # Store activation checkpoiting flag. - self.recompute_granularity = config.recompute_granularity - self.recompute_method = config.recompute_method - self.recompute_num_layers = config.recompute_num_layers - self.distribute_saved_activations = \ - config.distribute_saved_activations and not config.sequence_parallel - - self.sequence_parallel = config.sequence_parallel - - # Transformer Engine Init. 
- self.transformer_engine_v_0_10 = False - self.transformer_engine_v_0_11 = False - self.transformer_engine_v_0_8 = False - if self.transformer_impl == 'transformer_engine': - global transformer_engine - import transformer_engine - from importlib.metadata import version - from pkg_resources import packaging - - te_version = packaging.version.Version(version("transformer-engine")) - if te_version >= packaging.version.Version("0.8.0"): - self.transformer_engine_v_0_8 = True - if te_version >= packaging.version.Version("0.10.0"): - self.transformer_engine_v_0_10 = True - if te_version >= packaging.version.Version("0.11.0"): - self.transformer_engine_v_0_11 = True - - del version, packaging - - assert not args.squared_relu, "TransformerEngine does not support squared relu activation." - - self.use_fp8 = args.fp8 is not None - self.fp8_recipe = None - self.fp8_group = None - if self.use_fp8: - assert args.transformer_impl == 'transformer_engine', \ - 'transformer-engine required for fp8 training and inference' - self.fp8_group = mpu.get_amax_reduction_group() - if args.fp8 == "e4m3": - fp8_format = transformer_engine.common.recipe.Format.E4M3 - elif args.fp8 == "hybrid": - fp8_format = transformer_engine.common.recipe.Format.HYBRID - else: - raise ValueError("The DelayedScaling recipe only supports E4M3 and HYBRID formats.") - self.fp8_recipe = transformer_engine.common.recipe.DelayedScaling( - margin=args.fp8_margin, - interval=args.fp8_interval, - fp8_format=fp8_format, - amax_history_len=args.fp8_amax_history_len, - amax_compute_algo=args.fp8_amax_compute_algo, - override_linear_precision=(False, False, not args.fp8_wgrad), - ) - - self.num_microbatches_in_previous_step = -1 - self.microbatch_count = 0 - self.checkpoint_core_attention = config.recompute_granularity == 'selective' - - # Number of layers. - self.num_layers = _get_num_layers(args, model_type, - layer_type==LayerType.decoder) - - self.drop_path_rates = [ - rate.item() for rate in - torch.linspace(0, self.drop_path_rate, config.num_layers)] - - self.retro_layer_numbers = None - if model_type == ModelType.retro_decoder: - retro_layer_start = 6 if config.num_layers <= 15 else 9 - self.retro_layer_numbers = \ - np.arange(retro_layer_start, args.num_layers + 1, 3).tolist() - if model_type == ModelType.retro_encoder: - self.retro_layer_numbers = [1] - - # Transformer layers. - if args.retro_add_retriever: - assert self.recompute_granularity != 'full', \ - "Full recompute not supported for Retro." - assert args.transformer_impl == 'local', \ - "Transformer engine does not support Retro layers." - def build_layer(layer_number): - if args.transformer_impl == 'local': - current_layer_type = _get_layer_type( - model_type, layer_type, self.retro_layer_numbers, - layer_number) - return ParallelTransformerLayer( - config, - layer_number, - layer_type=current_layer_type, - self_attn_mask_type=self_attn_mask_type, - drop_path_rate=self.drop_path_rates[layer_number - 1]) - else: - # This argument is only available from TE v0.10 onwards. 
- extra_transformer_engine_kwargs = {} - if self.transformer_engine_v_0_8: - extra_transformer_engine_kwargs["bias"] = args.add_bias_linear - if self.transformer_engine_v_0_10: - extra_transformer_engine_kwargs["activation"] = "swiglu" if args.swiglu else "gelu" - if self.transformer_engine_v_0_11: - extra_transformer_engine_kwargs["normalization"] = args.normalization - return transformer_engine.pytorch.TransformerLayer( - config.hidden_size, - config.ffn_hidden_size, - config.num_attention_heads, - layernorm_epsilon=config.layernorm_epsilon, - hidden_dropout=config.hidden_dropout, - attention_dropout=config.attention_dropout, - init_method=config.init_method, - output_layer_init_method=config.output_layer_init_method, - layer_number=layer_number, - kv_channels=config.kv_channels, - self_attn_mask_type=self_attn_mask_type.name, - tp_group=mpu.get_tensor_model_parallel_group(), - get_rng_state_tracker=tensor_parallel.get_cuda_rng_tracker, - fuse_wgrad_accumulation=config.gradient_accumulation_fusion, - apply_query_key_layer_scaling=config.apply_query_key_layer_scaling, - attention_softmax_in_fp32=config.attention_softmax_in_fp32, - seq_length=args.seq_length, - micro_batch_size=args.micro_batch_size, - sequence_parallel=config.sequence_parallel, - params_dtype=config.params_dtype, - apply_residual_connection_post_layernorm=config.apply_residual_connection_post_layernorm, - output_layernorm=False, - layer_type="encoder", - drop_path_rate=self.drop_path_rates[layer_number - 1], - set_parallel_mode=True, - fuse_qkv_params=True, - **extra_transformer_engine_kwargs) - - if config.virtual_pipeline_model_parallel_size is not None: - assert config.num_layers % config.virtual_pipeline_model_parallel_size == 0, \ - 'num_layers_per_stage must be divisible by ' \ - 'virtual_pipeline_model_parallel_size' - assert args.model_type != ModelType.encoder_and_decoder - # Number of layers in each model chunk is the number of layers in the stage, - # divided by the number of model chunks in a stage. - self.num_layers = self.num_layers // config.virtual_pipeline_model_parallel_size - # With 8 layers, 2 stages, and 4 model chunks, we want an assignment of - # layers to stages like (each list is a model chunk): - # Stage 0: [0] [2] [4] [6] - # Stage 1: [1] [3] [5] [7] - # With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of - # layers to stages like (each list is a model chunk): - # Stage 0: [0, 1] [4, 5] - # Stage 1: [2, 3] [6, 7] - offset = mpu.get_virtual_pipeline_model_parallel_rank() * ( - config.num_layers // config.virtual_pipeline_model_parallel_size) + \ - (mpu.get_pipeline_model_parallel_rank() * self.num_layers) - else: - # Each stage gets a contiguous set of layers. - if args.model_type == ModelType.encoder_and_decoder and \ - mpu.get_pipeline_model_parallel_world_size() > 1: - pipeline_rank = mpu.get_pipeline_model_parallel_rank() - if layer_type == LayerType.encoder: - offset = pipeline_rank * self.num_layers - else: - num_ranks_in_enc = args.pipeline_model_parallel_split_rank - offset = (pipeline_rank - num_ranks_in_enc) * self.num_layers - else: - offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers - - if self.num_layers == 0: - # When a standalone embedding stage is used (e.g., - # args.standalone_embedding_stage == True), virtual pipeline ranks - # on pipeline rank 0 will have zero transformer layers assigned to - # them. 
This results in the model's input and output tensors to be - # the same, which will cause failure for certain output tensor - # optimizations (e.g., pipeline output deallocation). To remedy - # this, we assign a 'no-op' layer on these ranks, which will - # disconnect the input tensor from the output tensor. - self.num_layers = 1 - self.layers = torch.nn.ModuleList([ NoopTransformerLayer(1) ]) - else: - self.layers = torch.nn.ModuleList( - [build_layer(i + 1 + offset) for i in range(self.num_layers)]) - - # Update dropout rate for Retro encoder. - if model_type == ModelType.retro_encoder: - for layer in self.layers: - if layer.self_attention.use_flash_attn: - layer.self_attention.core_attention_flash.dropout_p = \ - torch.nn.Dropout(args.retro_encoder_attention_dropout) - else: - layer.self_attention.core_attention.attention_dropout.p =\ - args.retro_encoder_attention_dropout - layer.hidden_dropout = args.retro_encoder_hidden_dropout - - if self.post_process and self.post_norm: - # Final layer norm before output. - self.final_norm = get_norm(config) - - def _get_layer(self, layer_number): - return self.layers[layer_number] - - def _checkpointed_forward(self, hidden_states, attention_mask, - encoder_output, enc_dec_attn_mask, - rotary_pos_emb, is_first_microbatch): - """Forward method with activation checkpointing.""" - def custom(start, end): - def custom_forward(*args, **kwargs): - x_, *args = args - for index in range(start, end): - layer = self._get_layer(index) - x_ = layer(x_, *args, **kwargs) - return x_ - return custom_forward - - te_forward_kwargs = {} - if self.transformer_impl == 'transformer_engine': - te_forward_kwargs['is_first_microbatch'] = is_first_microbatch - if self.transformer_engine_v_0_10: - te_forward_kwargs['rotary_pos_emb'] = rotary_pos_emb - - if self.recompute_method == 'uniform': - # Uniformly divide the total number of Transformer layers and - # checkpoint the input activation of each divided chunk. - # A method to further reduce memory usage reducing checkpoints. - l = 0 - while l < self.num_layers: - if self.transformer_impl == 'transformer_engine': - hidden_states = transformer_engine.pytorch.checkpoint( - custom(l, l + self.recompute_num_layers), - self.distribute_saved_activations, - tensor_parallel.get_cuda_rng_tracker, - mpu.get_tensor_model_parallel_group(), - hidden_states, attention_mask, encoder_output, - enc_dec_attn_mask, **te_forward_kwargs) - else: - hidden_states = tensor_parallel.checkpoint( - custom(l, l + self.recompute_num_layers), - self.distribute_saved_activations, - hidden_states, attention_mask, - encoder_output, enc_dec_attn_mask, - None, None, None, None, rotary_pos_emb) - - l += self.recompute_num_layers - - elif self.recompute_method == 'block': - # Checkpoint the input activation of only a set number of individual - # Transformer layers and skip the rest. - # A method fully use the device memory removing redundant re-computation. 
- for l in range(self.num_layers): - if l < self.recompute_num_layers: - if self.transformer_impl == 'transformer_engine': - hidden_states = transformer_engine.pytorch.checkpoint( - custom(l, l + 1), - self.distribute_saved_activations, - tensor_parallel.get_cuda_rng_tracker, - mpu.get_tensor_model_parallel_group(), - hidden_states, attention_mask, encoder_output, - enc_dec_attn_mask, **te_forward_kwargs) - else: - hidden_states = tensor_parallel.checkpoint( - custom(l, l + 1), - self.distribute_saved_activations, - hidden_states, attention_mask, - encoder_output, enc_dec_attn_mask, - None, None, None, None, rotary_pos_emb) - else: - if self.transformer_impl == 'transformer_engine': - hidden_states = custom(l, l + 1)( - hidden_states, attention_mask, encoder_output, - enc_dec_attn_mask, **te_forward_kwargs) - else: - hidden_states = custom(l, l + 1)( - hidden_states, attention_mask, - encoder_output, enc_dec_attn_mask, - None, None, None, None, rotary_pos_emb) - else: - raise ValueError("Invalid activation recompute method.") - - return hidden_states - - def set_input_tensor(self, input_tensor): - """Set input tensor to be used instead of forward()'s input. - - When doing pipeline parallelism the input from the previous - stage comes from communication, not from the input, so the - model's forward_step_func won't have it. This function is thus - used by internal code to bypass the input provided by the - forward_step_func""" - self.input_tensor = input_tensor - - def forward(self, hidden_states, attention_mask, - encoder_output=None, enc_dec_attn_mask=None, - retriever_input=None, - retriever_output=None, - retriever_attn_mask=None, - inference_params=None, - rotary_pos_emb=None): - # hidden_states: [s, b, h] - - # Checks. - if inference_params: - assert self.recompute_granularity is None, \ - 'inference does not work with activation checkpointing' - - if not self.pre_process: - # See set_input_tensor() - hidden_states = self.input_tensor - - # Viewless tensor. - # - We only need to create a viewless tensor in the case of micro batch - # size (mbs) == 1, since in this case, 'hidden_states.transpose()' - # above creates a view tensor, and '.contiguous()' is a pass-through. - # For mbs >= 2, '.contiguous()' creates a new tensor, eliminating - # the need to make it viewless. - # - # However, we don't explicitly check mbs == 1 here because - # make_viewless_tensor() has negligible overhead when its input - # is already viewless. - # - # - For the 'else' case above, calling make_viewless_tensor() here is - # likely redundant, since p2p_communication.py (likely originator) - # already creates viewless tensors. That said, make_viewless_tensor() - # is called here to be future-proof and corner-case-proof. - hidden_states = core.utils.make_viewless_tensor( - hidden_states, - requires_grad=True, - keep_graph=True, - ) - - # RNG context. - if self.sequence_parallel: - rng_context = tensor_parallel.get_cuda_rng_tracker().fork() - else: - rng_context = nullcontext() - - # Forward layers. 
- with rng_context: - # The fp8_autocast context manager is a no-op when enabled=True - # The if...else serves to short circuit name resolution for fp8_autocast - with transformer_engine.pytorch.fp8_autocast( - enabled=self.use_fp8, - fp8_recipe=self.fp8_recipe, - fp8_group=self.fp8_group - ) if self.use_fp8 else nullcontext(): - # Determine if the current iteration is first microbatch - if self.num_microbatches_in_previous_step != get_num_microbatches(): - self.microbatch_count = 0 # Reset count on new batch size rampup interval - self.num_microbatches_in_previous_step = get_num_microbatches() - is_first_microbatch = self.microbatch_count % get_num_microbatches() == 0 - - # Forward pass. - if self.recompute_granularity == 'full': - hidden_states = self._checkpointed_forward(hidden_states, - attention_mask, - encoder_output, - enc_dec_attn_mask, - rotary_pos_emb, - is_first_microbatch) - else: - forward_kwargs = { - 'encoder_output': encoder_output, - 'enc_dec_attn_mask': enc_dec_attn_mask, - 'inference_params': inference_params, - } - - if self.transformer_impl == 'transformer_engine': - forward_kwargs['is_first_microbatch'] = is_first_microbatch - forward_kwargs['checkpoint_core_attention'] = self.checkpoint_core_attention - if self.transformer_engine_v_0_10: - forward_kwargs['rotary_pos_emb'] = rotary_pos_emb - else: - forward_kwargs['rotary_pos_emb'] = rotary_pos_emb - forward_kwargs['retriever_input'] = retriever_input - forward_kwargs['retriever_output'] = retriever_output - forward_kwargs['retriever_attn_mask'] = retriever_attn_mask - - for index in range(self.num_layers): - layer = self._get_layer(index) - - hidden_states = layer( - hidden_states, - attention_mask, - **forward_kwargs) - - # First Retro decoder layer returns both hidden_states - # and retriever_output. Make retriever_output available - # to subsequence Retro layers. - if isinstance(hidden_states, tuple): - assert len(hidden_states) == 2 - hidden_states, retriever_output = hidden_states - forward_kwargs["retriever_output"] = retriever_output - - # Skip counter update for eval and activation checkpointing - if torch.is_grad_enabled() and self.training: - self.microbatch_count += 1 - - # Final layer norm. - if self.post_process and self.post_norm: - hidden_states = self.final_norm(hidden_states) - - return hidden_states - - def load_state_dict(self, state_dict, strict=True): - """Customize load.""" - - # Handle renaming layernorm -> norm in component names - state_dict_ = {} - for key in state_dict.keys(): - newkey = key.replace("layernorm", "norm") - state_dict_[newkey] = state_dict[key] - - super().load_state_dict(state_dict_, strict) diff --git a/megatron/model/utils.py b/megatron/model/utils.py deleted file mode 100644 index 15fbe9ad9e215e5a33d03c8fcc58cbc281b5913f..0000000000000000000000000000000000000000 --- a/megatron/model/utils.py +++ /dev/null @@ -1,78 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 
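Note on the removed `ParallelTransformer.load_state_dict` above: it remaps checkpoint keys from the older `layernorm` naming to `norm` before delegating to the base loader. The snippet below is a minimal, self-contained sketch of that remapping pattern on a toy module; `TinyBlock` and `load_with_renamed_norms` are illustrative names invented here, not part of Megatron or ModelLink.

```python
# Minimal sketch: remap legacy "layernorm" checkpoint keys to the newer
# "norm" naming before loading, mirroring the override removed above.
# TinyBlock and load_with_renamed_norms are illustrative names only.
import torch


class TinyBlock(torch.nn.Module):
    def __init__(self, hidden: int = 8):
        super().__init__()
        self.input_norm = torch.nn.LayerNorm(hidden)
        self.linear = torch.nn.Linear(hidden, hidden)


def load_with_renamed_norms(module: torch.nn.Module, state_dict: dict) -> None:
    """Rename '...layernorm...' keys to '...norm...' and load the result."""
    remapped = {key.replace("layernorm", "norm"): value
                for key, value in state_dict.items()}
    module.load_state_dict(remapped, strict=True)


if __name__ == "__main__":
    block = TinyBlock()
    # Fake a legacy checkpoint whose keys still say 'input_layernorm'.
    legacy = {key.replace("input_norm", "input_layernorm"): value
              for key, value in block.state_dict().items()}
    load_with_renamed_norms(block, legacy)
    print("loaded keys:", list(block.state_dict().keys()))
```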
- -"""Utilities for models.""" - -import math - -import torch - -from megatron import get_args -from megatron.model import LayerNorm, RMSNorm - -def init_method_normal(sigma): - """Init method based on N(0, sigma).""" - def init_(tensor): - return torch.nn.init.normal_(tensor, mean=0.0, std=sigma) - - return init_ - - -def scaled_init_method_normal(sigma, num_layers): - """Init method based on N(0, sigma/sqrt(2*num_layers).""" - std = sigma / math.sqrt(2.0 * num_layers) - - def init_(tensor): - return torch.nn.init.normal_(tensor, mean=0.0, std=std) - - return init_ - - -def attention_mask_func(attention_scores, attention_mask): - attention_scores.masked_fill_(attention_mask, -10000.0) - return attention_scores - - -def get_linear_layer(rows, columns, init_method): - """Simple linear layer with weight initialization.""" - layer = torch.nn.Linear(rows, columns) - if get_args().perform_initialization: - init_method(layer.weight) - with torch.no_grad(): - layer.bias.zero_() - return layer - - -@torch.jit.script -def gelu_impl(x): - """OpenAI's gelu implementation.""" - return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * - - (1.0 + 0.044715 * x * x))) -def openai_gelu(x): - return gelu_impl(x) - - -#This is actually Python equivalent of torch.nn.functional.gelu(), also with type hints for ONNX exporter -@torch.jit.script -def erf_gelu(x): - return x * 0.5 * (torch.erf(x / 1.41421).to(dtype=x.dtype)+torch.ones_like(x).to(dtype=x.dtype)) - - -def get_norm(config): - args = get_args() - if args.normalization == "LayerNorm": - return LayerNorm( - config.hidden_size, - eps=config.layernorm_epsilon, - no_persist_layer_norm=not config.persist_layer_norm, - sequence_parallel=config.sequence_parallel, - apply_layernorm_1p=args.apply_layernorm_1p) - elif args.normalization == "RMSNorm": - if args.apply_layernorm_1p: - raise NotImplementedError('RMSNorm does not currently support the layernorm_1p formulation.') - - return RMSNorm(dim=config.hidden_size, - eps=config.layernorm_epsilon, - sequence_parallel=config.sequence_parallel) - else: - raise Exception(f"unsupported norm type '{args.normalization}'.") diff --git a/megatron/model/vision/__init__.py b/megatron/model/vision/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/megatron/model/vision/classification.py b/megatron/model/vision/classification.py deleted file mode 100644 index 3d5c823df47934f59319c4e8e1e869a62d425dc2..0000000000000000000000000000000000000000 --- a/megatron/model/vision/classification.py +++ /dev/null @@ -1,86 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
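The removed `megatron/model/utils.py` above defines the weight-initialization helpers used by the transformer layers: a plain N(0, sigma) initializer and a depth-scaled variant with std = sigma / sqrt(2 * num_layers) for output projections. The sketch below is an illustrative re-implementation of that scaling (names mirror the removed file, but nothing here is imported from Megatron):

```python
# Illustrative re-implementation of the two init helpers removed above:
# plain N(0, sigma) for input projections, and sigma / sqrt(2 * num_layers)
# for output projections so residual-branch variance stays bounded with depth.
import math
import torch


def init_method_normal(sigma: float):
    def init_(tensor: torch.Tensor) -> torch.Tensor:
        return torch.nn.init.normal_(tensor, mean=0.0, std=sigma)
    return init_


def scaled_init_method_normal(sigma: float, num_layers: int):
    std = sigma / math.sqrt(2.0 * num_layers)

    def init_(tensor: torch.Tensor) -> torch.Tensor:
        return torch.nn.init.normal_(tensor, mean=0.0, std=std)
    return init_


if __name__ == "__main__":
    attn_in = torch.nn.Linear(1024, 1024)
    attn_out = torch.nn.Linear(1024, 1024)
    init_method_normal(0.02)(attn_in.weight)                          # input projections
    scaled_init_method_normal(0.02, num_layers=24)(attn_out.weight)   # output projections
    print(attn_in.weight.std().item(), attn_out.weight.std().item())  # ~0.02 vs ~0.0029
```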
- -"""Vision Transformer(VIT) model.""" - -import torch -from torch.nn.init import trunc_normal_ -from megatron import get_args -from megatron.model.utils import get_linear_layer -from megatron.model.vision.vit_backbone import VitBackbone, VitMlpHead -from megatron.model.vision.mit_backbone import mit_b3_avg -from megatron.model.module import MegatronModule - -class VitClassificationModel(MegatronModule): - """Vision Transformer Model.""" - - def __init__(self, config, num_classes, finetune=False, - pre_process=True, post_process=True): - super(VitClassificationModel, self).__init__() - args = get_args() - self.config = config - - self.hidden_size = args.hidden_size - self.num_classes = num_classes - self.finetune = finetune - self.pre_process = pre_process - self.post_process = post_process - self.backbone = VitBackbone( - config=config, - pre_process=self.pre_process, - post_process=self.post_process, - single_token_output=True - ) - - if self.post_process: - if not self.finetune: - self.head = VitMlpHead(config, self.hidden_size, self.num_classes) - else: - self.head = get_linear_layer( - self.hidden_size, - self.num_classes, - torch.nn.init.zeros_ - ) - - def set_input_tensor(self, input_tensor): - """See megatron.model.transformer.set_input_tensor()""" - self.backbone.set_input_tensor(input_tensor) - - def forward(self, input): - hidden_states = self.backbone(input) - - if self.post_process: - hidden_states = self.head(hidden_states) - - return hidden_states - - -class MitClassificationModel(MegatronModule): - """Mix vision Transformer Model.""" - - def __init__(self, num_classes, - pre_process=True, post_process=True): - super(MitClassificationModel, self).__init__() - args = get_args() - - self.hidden_size = args.hidden_size - self.num_classes = num_classes - - self.backbone = mit_b3_avg() - self.head = torch.nn.Linear(512, num_classes) - self.apply(self._init_weights) - - def _init_weights(self, m): - if isinstance(m, torch.nn.Linear): - trunc_normal_(m.weight, std=.02) - if isinstance(m, torch.nn.Linear) and m.bias is not None: - torch.nn.init.constant_(m.bias, 0) - - def set_input_tensor(self, input_tensor): - """See megatron.model.transformer.set_input_tensor()""" - pass - - def forward(self, input): - hidden_states = self.backbone(input) - hidden_states = self.head(hidden_states) - - return hidden_states diff --git a/megatron/model/vision/dino.py b/megatron/model/vision/dino.py deleted file mode 100644 index 151ec26647453cf139616ed2eeaf899c47cd280c..0000000000000000000000000000000000000000 --- a/megatron/model/vision/dino.py +++ /dev/null @@ -1,291 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the Apache license found in the -# LICENSE file in the root directory of this source tree. - -# copied from https://github.com/facebookresearch/dino/blob/main/main_dino.py -# reworked/refactored some parts to make it run in Megatron. 
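The removed `VitClassificationModel` above picks its head from the `finetune` flag: a zero-initialized linear layer for finetuning, otherwise an MLP head on top of the backbone's single-token output. Below is a hedged, self-contained sketch of that pattern; `DummyBackbone` and `build_head` are stand-ins invented for illustration and are not the actual `VitBackbone`/`VitMlpHead` classes.

```python
# Sketch of the head selection in the removed VitClassificationModel:
# finetuning attaches a plain zero-initialized Linear, pretraining uses
# an MLP head. DummyBackbone is a stand-in, not the real VitBackbone.
import torch


class DummyBackbone(torch.nn.Module):
    """Stand-in returning one [hidden_size] token per image."""
    def __init__(self, hidden_size: int):
        super().__init__()
        self.proj = torch.nn.Linear(3 * 16 * 16, hidden_size)

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        patches = images[:, :, :16, :16].reshape(images.size(0), -1)
        return self.proj(patches)


def build_head(hidden_size: int, num_classes: int, finetune: bool) -> torch.nn.Module:
    if finetune:
        head = torch.nn.Linear(hidden_size, num_classes)
        torch.nn.init.zeros_(head.weight)   # mirrors get_linear_layer(..., zeros_)
        torch.nn.init.zeros_(head.bias)
        return head
    # Pretraining path: a small MLP standing in for VitMlpHead.
    return torch.nn.Sequential(
        torch.nn.Linear(hidden_size, hidden_size),
        torch.nn.Tanh(),
        torch.nn.Linear(hidden_size, num_classes),
    )


if __name__ == "__main__":
    backbone, head = DummyBackbone(64), build_head(64, 10, finetune=True)
    logits = head(backbone(torch.randn(2, 3, 224, 224)))
    print(logits.shape)  # torch.Size([2, 10])
```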
-import math -import apex -import einops -import torch -import numpy as np -import torch.nn.functional as F -from torch.nn.init import trunc_normal_ -from megatron import get_args, print_rank_0 -from megatron.model.utils import get_linear_layer -from megatron.model.vision.vit_backbone import VitBackbone -from megatron.model.module import MegatronModule -from megatron.model.vision.mit_backbone import mit_b5_avg -from megatron.model.vision.esvit_swin_backbone import get_swin - - -class DINOLoss(torch.nn.Module): - def __init__(self, out_dim, ncrops, warmup_teacher_temp, teacher_temp, - warmup_teacher_temp_epochs, nepochs, student_temp=0.1, - center_momentum=0.9): - super().__init__() - self.student_temp = student_temp - self.center_momentum = center_momentum - self.ncrops = ncrops - self.register_buffer("center", torch.zeros(1, out_dim)) - # we apply a warm up for the teacher temperature because - # a too high temperature makes the training instable at the beginning - self.teacher_temp_schedule = np.concatenate(( - np.linspace(warmup_teacher_temp, - teacher_temp, warmup_teacher_temp_epochs), - np.ones(nepochs - warmup_teacher_temp_epochs) * teacher_temp - )) - self.teacher_temp = teacher_temp - - def forward(self, student_output, teacher_output, iteration): - """ - Cross-entropy between softmax outputs of the teacher - and student network. - """ - args = get_args() - student_out = student_output / self.student_temp - student_out = student_out.chunk(self.ncrops) - - epoch = iteration // args.iter_per_epoch - - # teacher centering and sharpening - temp = self.teacher_temp_schedule[epoch] - teacher_out = F.softmax((teacher_output - self.center) / temp, dim=-1) - - teacher_out = teacher_out.detach().chunk(2) - - total_loss = 0 - n_loss_terms = 0 - for iq, q in enumerate(teacher_out): - for v in range(len(student_out)): - if v == iq: - # we skip cases where student and teacher operate on the same view - continue - loss = torch.sum(-q * F.log_softmax(student_out[v], dim=-1), dim=-1) - total_loss += loss.mean() - n_loss_terms += 1 - total_loss /= n_loss_terms - self.update_center(teacher_output) - return total_loss - - @torch.no_grad() - def update_center(self, teacher_output): - """ - Update center used for teacher output. 
- """ - batch_center = torch.sum(teacher_output, dim=0, keepdim=True) - torch.distributed.all_reduce(batch_center) - batch_center = batch_center / (len(teacher_output) * torch.distributed.get_world_size()) - self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum) - -class DINOHead(torch.nn.Module): - def __init__(self, in_dim, out_dim, norm_last_layer=True, nlayers=3): - super().__init__() - args = get_args() - hidden_dim = args.dino_head_hidden_size - bottleneck_dim = args.dino_bottleneck_size - nlayers = max(nlayers, 1) - if nlayers == 1: - self.mlp = torch.nn.Linear(in_dim, bottleneck_dim) - else: - layers = [torch.nn.Linear(in_dim, hidden_dim)] - layers.append(torch.nn.GELU()) - for _ in range(nlayers - 2): - layers.append(torch.nn.Linear(hidden_dim, hidden_dim)) - layers.append(torch.nn.GELU()) - layers.append(torch.nn.Linear(hidden_dim, bottleneck_dim)) - self.mlp = torch.nn.Sequential(*layers) - self.apply(self._init_weights) - self.last_layer = torch.nn.utils.weight_norm(torch.nn.Linear(bottleneck_dim, out_dim, bias=False)) - self.last_layer.weight_g.data.fill_(1) - if norm_last_layer: - self.last_layer.weight_g.requires_grad = False - - def _init_weights(self, m): - if isinstance(m, torch.nn.Linear): - trunc_normal_(m.weight, std=.02) - if isinstance(m, torch.nn.Linear) and m.bias is not None: - torch.nn.init.constant_(m.bias, 0) - - def forward(self, x): - x = self.mlp(x) - x = torch.nn.functional.normalize(x, dim=-1, p=2) - x = self.last_layer(x) - return x - - -class MultiCropWrapper(MegatronModule): - - """ - Perform forward pass separately on each resolution input. - The inputs corresponding to a single resolution are clubbed and single - forward is run on the same resolution inputs. Hence we do several - forward passes = number of different resolutions used. We then - concatenate all the output features and run the head forward on these - concatenated features. - """ - def __init__(self, backbone, head): - super(MultiCropWrapper, self).__init__() - # disable layers dedicated to ImageNet labels classification - #backbone.fc, backbone.head = torch.nn.Identity(), torch.nn.Identity() - self.backbone = backbone - self.head = head - - def forward(self, x): - # convert to list - if not isinstance(x, list): - x = [x] - idx_crops = torch.cumsum(torch.unique_consecutive( - torch.tensor([inp.shape[-1] for inp in x]), - return_counts=True, - )[1], 0) - - start_idx = 0 - for end_idx in idx_crops: - _out = self.backbone(torch.cat(x[start_idx: end_idx])) - if start_idx == 0: - output = _out - else: - output = torch.cat((output, _out)) - start_idx = end_idx - # Run the head forward on the concatenated features. 
- if self.training: - return self.head(output) - else: - return output - - -def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, - warmup_epochs=0, start_warmup_value=0): - warmup_schedule = np.array([]) - warmup_iters = warmup_epochs * niter_per_ep - if warmup_epochs > 0: - warmup_schedule = \ - np.linspace(start_warmup_value, base_value, warmup_iters) - - iters = np.arange(epochs * niter_per_ep - warmup_iters) - schedule = final_value + 0.5 * (base_value - final_value) \ - * (1 + np.cos(np.pi * iters / len(iters))) - - schedule = np.concatenate((warmup_schedule, schedule)) - assert len(schedule) == epochs * niter_per_ep - return schedule - - -def get_student_backbone_and_num_features(config, pre_process=True, post_process=True): - args = get_args() - - if args.vision_backbone_type == 'vit': - student = VitBackbone(config, - pre_process=pre_process, - post_process=post_process, - drop_path_rate=0.1, - single_token_output=True) - num_features = args.hidden_size - elif args.vision_backbone_type == 'mit': - student = mit_b5_avg(drop_path_rate=0.1) - num_features = 512 - elif args.vision_backbone_type == 'swin': - student = get_swin() - num_features = student.num_features - else: - raise Exception('{} vision backbone is not supported.'.format( - args.vision_backbone_type)) - - return student, num_features - -def get_teacher_backbone_and_num_features(config, pre_process=True, post_process=True): - args = get_args() - - if args.vision_backbone_type == 'vit': - teacher = VitBackbone(config, - pre_process=pre_process, - post_process=post_process, - single_token_output=True) - num_features = args.hidden_size - elif args.vision_backbone_type == 'mit': - teacher = mit_b5_avg(drop_path_rate=0.0) - num_features = 512 - elif args.vision_backbone_type == 'swin': - teacher = get_swin(is_teacher=True) - num_features = teacher.num_features - else: - raise Exception('{} vision backbone is not supported.'.format( - args.vision_backbone_type)) - return teacher, num_features - - -class DINOPretrainModel(MegatronModule): - def __init__(self, config, pre_process=True, post_process=True): - super(DINOPretrainModel, self).__init__() - args = get_args() - self.config = config - self.out_dim = 65536 - - self.dino_loss = DINOLoss( - self.out_dim, - args.dino_local_crops_number + 2, - args.dino_warmup_teacher_temp, - args.dino_teacher_temp, - args.dino_warmup_teacher_temp_epochs, - 300, - ) - - self.pre_process = pre_process - self.post_process = post_process - self.momentum_teacher = 0.996 - - student_backbone, num_features = \ - get_student_backbone_and_num_features(config, pre_process, post_process) - - self.student = MultiCropWrapper( - student_backbone, - DINOHead(num_features, self.out_dim, - norm_last_layer=args.dino_norm_last_layer) - ) - - self.momentum_schedule = cosine_scheduler( - self.momentum_teacher, 1, - args.train_iters // args.iter_per_epoch, - args.iter_per_epoch - ) - - teacher_backbone, num_features = \ - get_teacher_backbone_and_num_features(config, pre_process, post_process) - self.teacher = MultiCropWrapper( - teacher_backbone, - DINOHead(num_features, self.out_dim) - ) - self.teacher.load_state_dict(self.student.state_dict()) - - for p in self.teacher.parameters(): - if hasattr(p, "requires_grad") and p.requires_grad is not None: - p.requires_grad = False - - def set_input_tensor(self, tensor): - pass - - def forward(self, input): - student_output = None - if self.training: - student_output = self.student(input) - teacher_output = self.teacher(input[:2]) - else: - teacher_output 
= self.teacher(input) - return student_output, teacher_output - - def cancel_gradients_last_layer(self, iteration): - args = get_args() - epoch = iteration // args.iter_per_epoch - if epoch < args.dino_freeze_last_layer: - for n, p in self.student.named_parameters(): - if "last_layer" in n: - p.grad = None - - def update_momentum(self, iteration): - with torch.no_grad(): - m = self.momentum_schedule[iteration] - for param_q, param_k in zip(self.student.parameters(), self.teacher.parameters()): - param_k.data.mul_(m).add_((1 - m) * param_q.detach().data) - diff --git a/megatron/model/vision/esvit_swin_backbone.py b/megatron/model/vision/esvit_swin_backbone.py deleted file mode 100644 index 70aee3db429bf63680b7c19b4d5acfe25ee2edf5..0000000000000000000000000000000000000000 --- a/megatron/model/vision/esvit_swin_backbone.py +++ /dev/null @@ -1,849 +0,0 @@ -# Copyright (c) 2021 Microsoft -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. -# -------------------------------------------------------- -# Modified by Chunyuan Li (chunyl@microsoft.com) -# Swin Transformer -# -------------------------------------------------------- - -import os -import logging -import torch -import torch.nn as nn -import torch.nn.functional as F -from functools import partial -import torch.distributed as dist -from torch.nn.init import trunc_normal_ -from megatron.model.transformer import DropPath -from megatron import get_args -from megatron.model import LayerNorm -import numpy as np -from math import sqrt - - -class Mlp(nn.Module): - def __init__(self, in_features, hidden_features=None, - out_features=None, act_layer=nn.GELU, drop=0.): - super(Mlp, self).__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - self.fc1 = nn.Linear(in_features, hidden_features) - self.act = act_layer() - self.fc2 = nn.Linear(hidden_features, out_features) - self.drop = nn.Dropout(drop) - - def forward(self, x): - x = self.fc1(x) - x = self.act(x) - x = self.drop(x) - x = self.fc2(x) - x = self.drop(x) - return x - - -def window_partition(x, window_size): - """ - Args: - x: (B, H, W, C) - window_size (int): window size - Returns: - windows: (num_windows*B, window_size, window_size, C) - """ - B, H, W, C = x.shape - x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) - windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) - return windows - - -def window_reverse(windows, window_size, H, W): - """ - Args: - windows: (num_windows*B, window_size, window_size, C) - window_size (int): Window size - H (int): Height of image - W (int): Width of image - Returns: - x: (B, H, W, C) - """ - B = int(windows.shape[0] / (H * W / window_size / window_size)) - x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) - x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) - return x - - -class WindowAttention(nn.Module): - r"""Window based multi-head self attention (W-MSA) module with relative position bias. - It supports both of shifted and non-shifted window. - Args: - dim (int): Number of input channels. - window_size (tuple[int]): The height and width of the window. - num_heads (int): Number of attention heads. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. 
Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set - attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 - proj_drop (float, optional): Dropout ratio of output. Default: 0.0 - """ - - def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): - - super(WindowAttention, self).__init__() - self.dim = dim - self.window_size = window_size # Wh, Ww - self.num_heads = num_heads - head_dim = dim // num_heads - self.scale = qk_scale or head_dim ** -0.5 - - # define a parameter table of relative position bias - self.relative_position_bias_table = nn.Parameter( - torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH - - # get pair-wise relative position index for each token inside the window - coords_h = torch.arange(self.window_size[0]) - coords_w = torch.arange(self.window_size[1]) - coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww - coords_flatten = torch.flatten(coords, 1) # 2 Wh*Ww - relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww - relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 - relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 - relative_coords[:, :, 1] += self.window_size[1] - 1 - relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 - relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww - self.register_buffer("relative_position_index", relative_position_index) - - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) - self.proj_drop = nn.Dropout(proj_drop) - - trunc_normal_(self.relative_position_bias_table, std=.02) - self.softmax = nn.Softmax(dim=-1) - - def forward(self, x, mask=None): - """ - Args: - x: input features with shape of (num_windows*B, N, C) - mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None - """ - B_, N, C = x.shape - qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) - q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) - - q = q * self.scale - attn = (q @ k.transpose(-2, -1)) - - relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( - self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH - relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww - attn = attn + relative_position_bias.unsqueeze(0) - - if mask is not None: - nW = mask.shape[0] - attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0).type(attn.type()) - attn = attn.view(-1, self.num_heads, N, N) - attn = self.softmax(attn) - else: - attn = self.softmax(attn) - - attn_out = attn - attn = self.attn_drop(attn) - - x = (attn @ v).transpose(1, 2).reshape(B_, N, C) - x = self.proj(x) - x = self.proj_drop(x) - return x, attn_out - - def extra_repr(self) -> str: - return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' - - def flops(self, N): - # calculate flops for 1 window with token length of N - flops = 0 - # qkv = self.qkv(x) - flops += N * self.dim * 3 * self.dim - # attn = (q @ k.transpose(-2, -1)) - flops += self.num_heads * N * (self.dim // self.num_heads) * N - # x = (attn @ v) - flops += self.num_heads * N * N * (self.dim // 
self.num_heads) - # x = self.proj(x) - flops += N * self.dim * self.dim - return flops - - @staticmethod - def compute_macs(module, input, output): - B, N, C = input[0].shape - - module.__flops__ += module.flops(N) * B - - -class SwinTransformerBlock(nn.Module): - r"""Swin Transformer Block. - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resulotion. - num_heads (int): Number of attention heads. - window_size (int): Window size. - shift_size (int): Shift size for SW-MSA. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float, optional): Stochastic depth rate. Default: 0.0 - act_layer (nn.Module, optional): Activation layer. Default: nn.GELU - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - """ - - def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, - mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., - act_layer=nn.GELU, norm_layer=nn.LayerNorm): - super().__init__() - self.dim = dim - self.input_resolution = input_resolution - self.num_heads = num_heads - self.window_size = window_size - self.shift_size = shift_size - self.mlp_ratio = mlp_ratio - if min(self.input_resolution) <= self.window_size: - # if window size is larger than input resolution, we don't partition windows - self.shift_size = 0 - self.window_size = min(self.input_resolution) - assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" - - self.norm1 = norm_layer(dim) - self.attn = WindowAttention( - dim, window_size=(self.window_size, self.window_size), num_heads=num_heads, - qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) - - self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() - self.norm2 = norm_layer(dim) - mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) - - self.H = input_resolution[0] - self.W = input_resolution[1] - - self.attn_mask_dict = {} - - - def create_attn_mask(self, H, W): - # calculate attention mask for SW-MSA - - Hp = int(np.ceil(H / self.window_size)) * self.window_size - Wp = int(np.ceil(W / self.window_size)) * self.window_size - img_mask = torch.zeros((1, Hp, Wp, 1)) # 1 Hp Wp 1 - h_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - w_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - cnt = 0 - for h in h_slices: - for w in w_slices: - img_mask[:, h, w, :] = cnt - cnt += 1 - - mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 - mask_windows = mask_windows.view(-1, self.window_size * self.window_size) - attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) - attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) - - return attn_mask - - - def forward(self, x): - B, L, C = x.shape - H = int(sqrt(L)) - W = H - - shortcut = x - x = self.norm1(x) - x = x.view(B, H, W, C) - - # pad feature maps to multiples of window size - pad_l = pad_t = 0 - pad_r = (self.window_size - W % self.window_size) % self.window_size - pad_b = (self.window_size - H % self.window_size) % self.window_size - x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) - _, Hp, Wp, _ = x.shape - - # cyclic shift - if self.shift_size > 0: - shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) - - if H in self.attn_mask_dict.keys(): - attn_mask = self.attn_mask_dict[H] - else: - self.attn_mask_dict[H] = self.create_attn_mask(self.H, self.W).to(x.device) - attn_mask = self.attn_mask_dict[H] - - else: - shifted_x = x - attn_mask = None - - # partition windows - x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C - x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C - - # W-MSA/SW-MSA - attn_windows, attn = self.attn(x_windows, attn_mask) # nW*B, window_size*window_size, C - - # merge windows - attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) - shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C - - # reverse cyclic shift - if self.shift_size > 0: - x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) - else: - x = shifted_x - - if pad_r > 0 or pad_b > 0: - x = x[:, :H, :W, :].contiguous() - - x = x.view(B, H * W, C) - - # FFN - x = shortcut + self.drop_path(x) - x = x + self.drop_path(self.mlp(self.norm2(x))) - - return x, attn - - def extra_repr(self) -> str: - return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ - f"window_size={self.window_size}, shift_size={self.shift_size} mlp_ratio={self.mlp_ratio}" - - def flops(self): - flops = 0 - H, W = self.input_resolution - # norm1 - flops += self.dim * H * W - # W-MSA/SW-MSA - nW = H * W / self.window_size / self.window_size - flops += nW * self.attn.flops(self.window_size * self.window_size) - # mlp - flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio - # norm2 - flops += self.dim * H * W - return flops - - -class PatchMerging(nn.Module): 
- r"""Patch Merging Layer. - Args: - input_resolution (tuple[int]): Resolution of input feature. - dim (int): Number of input channels. - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - """ - - def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): - super().__init__() - self.input_resolution = input_resolution - self.dim = dim - self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) - self.norm = norm_layer(4 * dim) - - def forward(self, x): - """ Forward function. - Args: - x: Input feature, tensor size (B, H*W, C). - H, W: Spatial resolution of the input feature. - """ - B, L, C = x.shape - H = int(sqrt(L)) - W = H - - x = x.view(B, H, W, C) - - # padding - pad_input = (H % 2 == 1) or (W % 2 == 1) - if pad_input: - x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) - - x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C - x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C - x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C - x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C - x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C - x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C - - x = self.norm(x) - x = self.reduction(x) - - return x - - - def extra_repr(self) -> str: - return f"input_resolution={self.input_resolution}, dim={self.dim}" - - def flops(self): - H, W = self.input_resolution - flops = H * W * self.dim - flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim - return flops - - -class BasicLayer(nn.Module): - """A basic Swin Transformer layer for one stage. - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resulotion. - depth (int): Number of blocks. - num_heads (int): Number of attention heads. - window_size (int): Window size. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - downsample (nn.Module | None, optional): Downsample layer at the end of the layer. 
Default: None - """ - - def __init__(self, dim, input_resolution, depth, num_heads, window_size, - mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., - drop_path=0., norm_layer=nn.LayerNorm, downsample=None): - - super().__init__() - self.dim = dim - self.input_resolution = input_resolution - self.depth = depth - - self.blocks = nn.ModuleList([ - SwinTransformerBlock(dim=dim, input_resolution=input_resolution, - num_heads=num_heads, window_size=window_size, - shift_size=0 if (i % 2 == 0) else window_size // 2, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop, attn_drop=attn_drop, - drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, - norm_layer=norm_layer) - for i in range(depth)]) - if downsample is not None: - self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) - else: - self.downsample = None - - def forward(self, x): - for blk in self.blocks: - x, _ = blk(x) - if self.downsample is not None: - x = self.downsample(x) - return x - - def forward_with_features(self, x): - fea = [] - for blk in self.blocks: - x, _ = blk(x) - fea.append(x) - if self.downsample is not None: - x = self.downsample(x) - return x, fea - - def forward_with_attention(self, x): - attns = [] - for blk in self.blocks: - x, attn = blk(x) - attns.append(attn) - if self.downsample is not None: - x = self.downsample(x) - return x, attns - - - def extra_repr(self) -> str: - return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" - - def flops(self): - flops = 0 - for blk in self.blocks: - flops += blk.flops() - if self.downsample is not None: - flops += self.downsample.flops() - return flops - - -class PatchEmbed(nn.Module): - """ Image to Patch Embedding - """ - - def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None): - super().__init__() - img_size = (img_size, img_size) - patch_size = (patch_size, patch_size) - patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] - self.img_size = img_size - self.patch_size = patch_size - self.patches_resolution = patches_resolution - self.num_patches = patches_resolution[0] * patches_resolution[1] - - self.in_chans = in_chans - self.embed_dim = embed_dim - - self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) - if norm_layer is not None: - self.norm = norm_layer(embed_dim) - else: - self.norm = None - - def forward(self, x): - B, C, H, W = x.shape - - x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C - if self.norm is not None: - x = self.norm(x) - return x - - - def flops(self): - Ho, Wo = self.patches_resolution - flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) - if self.norm is not None: - flops += Ho * Wo * self.embed_dim - return flops - -class SwinTransformer(nn.Module): - r""" Swin Transformer - A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - - https://arxiv.org/pdf/2103.14030 - Args: - img_size (int | tuple(int)): Input image size. - patch_size (int | tuple(int)): Patch size. - in_chans (int): Number of input channels. - num_classes (int): Number of classes for classification head. - embed_dim (int): Embedding dimension. - depths (tuple(int)): Depth of Swin Transformer layers. - num_heads (tuple(int)): Number of attention heads in different layers. - window_size (int): Window size. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 
- qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: Truee - qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. - drop_rate (float): Dropout rate. - attn_drop_rate (float): Attention dropout rate. - drop_path_rate (float): Stochastic depth rate. - norm_layer (nn.Module): normalization layer. - ape (bool): If True, add absolute position embedding to the patch embedding. - patch_norm (bool): If True, add normalization after patch embedding. - """ - - def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, - embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], - window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, - drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, - norm_layer=nn.LayerNorm, ape=False, patch_norm=True, **kwargs): - super().__init__() - - self.num_classes = num_classes - self.num_layers = len(depths) - self.embed_dim = embed_dim - self.ape = ape - self.patch_norm = patch_norm - self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) - self.mlp_ratio = mlp_ratio - - self.patch_embed = PatchEmbed( - img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, - norm_layer=norm_layer if self.patch_norm else None) - num_patches = self.patch_embed.num_patches - patches_resolution = self.patch_embed.patches_resolution - self.patches_resolution = patches_resolution - - if self.ape: - self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) - trunc_normal_(self.absolute_pos_embed, std=.02) - - self.pos_drop = nn.Dropout(p=drop_rate) - - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule - self.layers = nn.ModuleList() - for i_layer in range(self.num_layers): - layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), - input_resolution=(patches_resolution[0] // (2 ** i_layer), - patches_resolution[1] // (2 ** i_layer)), - depth=depths[i_layer], - num_heads=num_heads[i_layer], - window_size=window_size, - mlp_ratio=self.mlp_ratio, - qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop_rate, attn_drop=attn_drop_rate, - drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], - norm_layer=norm_layer, - downsample=PatchMerging if (i_layer < self.num_layers - 1) else None) - self.layers.append(layer) - - self.norm = norm_layer(self.num_features) - self.avgpool = nn.AdaptiveAvgPool1d(1) - - self.apply(self._init_weights) - - def _init_weights(self, m): - if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=.02) - if isinstance(m, nn.Linear) and m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) - - @torch.jit.ignore - def no_weight_decay(self): - return {'absolute_pos_embed'} - - @torch.jit.ignore - def no_weight_decay_keywords(self): - # todo: to be implemented - return {'relative_position_bias_table'} - - def forward(self, x): - x = self.patch_embed(x) - if self.ape: - x = x + self.absolute_pos_embed - x = self.pos_drop(x) - - for layer in self.layers: - x = layer(x) - - x_region = self.norm(x) # B L C - x = self.avgpool(x_region.transpose(1, 2)) # B C 1 - x = torch.flatten(x, 1) - - return x - - - def forward_feature_maps(self, x): - x = self.patch_embed(x) - if self.ape: - x = x + self.absolute_pos_embed - x = self.pos_drop(x) - - for layer in self.layers: - x = layer(x) - - x_grid = self.norm(x) # B L C - x = self.avgpool(x_grid.transpose(1, 2)) # B C 1 - x = torch.flatten(x, 1) - - return x, 
x_grid - - - def forward_selfattention(self, x, n=1): - # n=1 return the last layer attn map; otherwise return attn maps in all layers - - - x = self.patch_embed(x) - if self.ape: - x = x + self.absolute_pos_embed - x = self.pos_drop(x) - - if n==1: - return self.forward_last_selfattention(x) - else: - return self.forward_all_selfattention(x) - - def forward_last_selfattention(self, x): - - for i, layer in enumerate(self.layers): - if i < len(self.layers) - 1: - x = layer(x) - else: - x, attns = layer.forward_with_attention(x) - return attns[-1] - - def forward_all_selfattention(self, x): - attn_out = [] - - for layer in self.layers: - x, attns = layer.forward_with_attention(x) - attn_out += attns - - return attn_out - - - def forward_return_n_last_blocks(self, x, n=1, return_patch_avgpool=False, depth=[]): - - num_blks = sum(depth) - start_idx = num_blks - n - - sum_cur = 0 - for i, d in enumerate(depth): - sum_cur_new = sum_cur + d - if start_idx >= sum_cur and start_idx < sum_cur_new: - start_stage = i - start_blk = start_idx - sum_cur - sum_cur = sum_cur_new - - - x = self.patch_embed(x) - if self.ape: - x = x + self.absolute_pos_embed - x = self.pos_drop(x) - - # we will return the averaged token features from the `n` last blocks - # note: there is no [CLS] token in Swin Transformer - output = [] - s = 0 - for i, layer in enumerate(self.layers): - x, fea = layer.forward_with_features(x) - - if i >= start_stage: - for x_ in fea[start_blk:]: - - if i == len(self.layers)-1: # use the norm in the last stage - x_ = self.norm(x_) - - x_avg = torch.flatten(self.avgpool(x_.transpose(1, 2)), 1) # B C - # print(f'Stage {i}, x_avg {x_avg.shape}') - output.append(x_avg) - - start_blk = 0 - - return torch.cat(output, dim=-1) - - - - def flops(self): - flops = 0 - flops += self.patch_embed.flops() - for i, layer in enumerate(self.layers): - flops += layer.flops() - if dist.get_rank() == 0: - print(f"GFLOPs layer_{i}: {layer.flops() / 1e9}") - flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers) - flops += self.num_features * self.num_classes - return flops - - def init_weights(self, pretrained='', pretrained_layers=[], verbose=True): - if os.path.isfile(pretrained): - pretrained_dict = torch.load(pretrained, map_location='cpu') - logging.info(f'=> loading pretrained model {pretrained}') - model_dict = self.state_dict() - pretrained_dict = { - k: v for k, v in pretrained_dict.items() - if k in model_dict.keys() - } - need_init_state_dict = {} - for k, v in pretrained_dict.items(): - need_init = ( - k.split('.')[0] in pretrained_layers - or pretrained_layers[0] is '*' - or 'relative_position_index' not in k - or 'attn_mask' not in k - ) - - if need_init: - if verbose: - logging.info(f'=> init {k} from {pretrained}') - - if 'relative_position_bias_table' in k and v.size() != model_dict[k].size(): - relative_position_bias_table_pretrained = v - relative_position_bias_table_current = model_dict[k] - L1, nH1 = relative_position_bias_table_pretrained.size() - L2, nH2 = relative_position_bias_table_current.size() - if nH1 != nH2: - logging.info(f"Error in loading {k}, passing") - else: - if L1 != L2: - logging.info( - '=> load_pretrained: resized variant: {} to {}' - .format((L1, nH1), (L2, nH2)) - ) - S1 = int(L1 ** 0.5) - S2 = int(L2 ** 0.5) - relative_position_bias_table_pretrained_resized = torch.nn.functional.interpolate( - relative_position_bias_table_pretrained.permute(1, 0).view(1, nH1, S1, S1), - size=(S2, S2), - mode='bicubic') - v = 
relative_position_bias_table_pretrained_resized.view(nH2, L2).permute(1, 0) - - if 'absolute_pos_embed' in k and v.size() != model_dict[k].size(): - absolute_pos_embed_pretrained = v - absolute_pos_embed_current = model_dict[k] - _, L1, C1 = absolute_pos_embed_pretrained.size() - _, L2, C2 = absolute_pos_embed_current.size() - if C1 != C1: - logging.info(f"Error in loading {k}, passing") - else: - if L1 != L2: - logging.info( - '=> load_pretrained: resized variant: {} to {}' - .format((1, L1, C1), (1, L2, C2)) - ) - S1 = int(L1 ** 0.5) - S2 = int(L2 ** 0.5) - absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.reshape(-1, S1, S1, C1) - absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.permute(0, 3, 1, 2) - absolute_pos_embed_pretrained_resized = torch.nn.functional.interpolate( - absolute_pos_embed_pretrained, size=(S2, S2), mode='bicubic') - v = absolute_pos_embed_pretrained_resized.permute(0, 2, 3, 1).flatten(1, 2) - - need_init_state_dict[k] = v - self.load_state_dict(need_init_state_dict, strict=False) - - def freeze_pretrained_layers(self, frozen_layers=[]): - for name, module in self.named_modules(): - if ( - name.split('.')[0] in frozen_layers - or '.'.join(name.split('.')[0:2]) in frozen_layers - or (len(frozen_layers) > 0 and frozen_layers[0] is '*') - ): - for _name, param in module.named_parameters(): - param.requires_grad = False - logging.info( - '=> set param {} requires grad to False' - .format(name) - ) - for name, param in self.named_parameters(): - if ( - name.split('.')[0] in frozen_layers - or (len(frozen_layers) > 0 and frozen_layers[0] is '*') - and param.requires_grad is True - ): - param.requires_grad = False - logging.info( - '=> set param {} requires grad to False' - .format(name) - ) - return self - - -def get_swin(is_teacher=False): - args = get_args() - - if args.swin_backbone_type == "tiny": - embed_dim = 96 - depths = [2, 2, 6, 2] - num_heads = [3, 6, 12, 24] - drop_path_rate = 0.1 - elif args.swin_backbone_type == 'h3': - embed_dim = 384 - depths = [2, 2, 18, 2] - num_heads = [6, 12, 24, 48] - drop_path_rate = 0.2 - else: - embed_dim = 128 - depths = [2, 2, 18, 2] - num_heads = [4, 8, 16, 32] - drop_path_rate = 0.2 - - swin = SwinTransformer( - img_size=224, - in_chans=3, - num_classes=1000, - patch_size=4, - embed_dim=embed_dim, - depths=depths, - num_heads=num_heads, - window_size=7, - mlp_ratio=4, - qkv_bias=True, - drop_rate=0, - attn_drop_rate=0, - drop_path_rate=(0.0 if is_teacher else drop_path_rate), - norm_layer=partial(LayerNorm, eps=1e-6), - ape=False, - patch_norm=True, - ) - - return swin - diff --git a/megatron/model/vision/inpainting.py b/megatron/model/vision/inpainting.py deleted file mode 100644 index 6aae9658bc86a110d20feb6f9026d0fb8c2b8f1e..0000000000000000000000000000000000000000 --- a/megatron/model/vision/inpainting.py +++ /dev/null @@ -1,152 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. 
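The `init_weights` routine in the Swin backbone deleted above adapts a pretrained relative-position-bias table to a different window size by reshaping the `(L1, nH)` table into a square grid and interpolating it bicubically to the new grid. A minimal standalone sketch of that resizing step, assuming square windows; the helper name and the toy shapes are illustrative, not taken from the original file:

```python
import torch
import torch.nn.functional as F

def resize_rel_pos_bias(table: torch.Tensor, new_window: int) -> torch.Tensor:
    """Bicubically resize a (L1, nH) relative-position-bias table so it has
    the (2*new_window - 1)**2 rows required by a different window size."""
    L1, nH = table.shape
    S1 = int(L1 ** 0.5)          # old grid side, i.e. 2*Wh_old - 1
    S2 = 2 * new_window - 1      # new grid side
    grid = table.permute(1, 0).view(1, nH, S1, S1)
    grid = F.interpolate(grid, size=(S2, S2), mode='bicubic')
    return grid.view(nH, S2 * S2).permute(1, 0)

# e.g. a table trained with window 7 (13*13 = 169 rows) resized for window 12
old = torch.randn((2 * 7 - 1) ** 2, 4)
new = resize_rel_pos_bias(old, new_window=12)
print(new.shape)  # torch.Size([529, 4]) == ((2*12 - 1)**2, nH)
```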
- -import math -import apex -import einops -import torch -import torch.nn.functional as F -from megatron import get_args, print_rank_0 -from megatron.model.utils import get_linear_layer -from megatron.model.vision.vit_backbone import VitBackbone -from megatron.model.module import MegatronModule -from megatron.model.vision.mit_backbone import mit_b3 -from megatron.model.vision.utils import resize - - -class VitInpaintingModel(MegatronModule): - - def __init__(self, config, pre_process=True, post_process=True): - super(VitInpaintingModel, self).__init__() - args = get_args() - - self.config = config - self.pre_process = pre_process - self.post_process = post_process - self.hidden_size = config.hidden_size - self.backbone = VitBackbone( - config=config, - pre_process=self.pre_process, - post_process=self.post_process, - class_token=False, - ) - self.patch_dim = args.patch_dim - self.img_h = args.img_h - self.img_w = args.img_w - self.seq_length = args.seq_length - # full mask - - if self.post_process: - self.linear_decoder = get_linear_layer( - self.hidden_size, - self.backbone.flatten_dim, - torch.nn.init.zeros_ - ) - - def set_input_tensor(self, input_tensor): - self.backbone.set_input_tensor(input_tensor) - - def forward(self, input): - - hidden_states = self.backbone(input) - - if not self.post_process: - return hidden_states - decoded_output = self.linear_decoder(hidden_states) - output = einops.rearrange( - decoded_output, - "b (h w) (p1 p2 c) -> b c (h p1) (w p2)", - p1=self.patch_dim, - p2=self.patch_dim, - h=self.img_h//self.patch_dim, - w=self.img_w//self.patch_dim, - ) - - return output - - -class MLP(torch.nn.Module): - """ - Linear Embedding - """ - def __init__(self, input_dim=2048, embed_dim=768): - super().__init__() - self.proj = torch.nn.Linear(input_dim, embed_dim) - - def forward(self, x): - x = x.flatten(2).transpose(1, 2) - x = self.proj(x) - return x - - -class MitInpaintingModel(MegatronModule): - """Mix vision Transformer Model.""" - - def __init__(self, pre_process=True, post_process=True): - super(MitInpaintingModel, self).__init__() - self.pre_process = pre_process - self.post_process = post_process - - args = get_args() - self.patch_dim = args.patch_dim - self.img_h = args.img_h - self.img_w = args.img_w - self.flatten_dim = self.patch_dim * self.patch_dim * 3 - self.backbone = mit_b3() - - self.in_channels = [64, 128, 320, 512] - self.embedding_dim = 768 - - c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = self.in_channels - - self.linear_c4 = MLP(input_dim=c4_in_channels, embed_dim=self.embedding_dim) - self.linear_c3 = MLP(input_dim=c3_in_channels, embed_dim=self.embedding_dim) - self.linear_c2 = MLP(input_dim=c2_in_channels, embed_dim=self.embedding_dim) - self.linear_c1 = MLP(input_dim=c1_in_channels, embed_dim=self.embedding_dim) - - self.conv_fuse = torch.nn.Conv2d(self.embedding_dim*4, self.embedding_dim, 1, 1, bias=False) - self.norm = apex.parallel.SyncBatchNorm(self.embedding_dim) - self.dropout = torch.nn.Dropout2d(0.1) - - self.linear_pred = torch.nn.Conv2d(self.embedding_dim, self.flatten_dim, kernel_size=1) - - def set_input_tensor(self, input_tensor): - """See megatron.model.transformer.set_input_tensor()""" - pass - - def forward(self, input): - c1, c2, c3, c4 = self.backbone(input) - - n, _, h, w = c4.shape - _c4 = self.linear_c4(c4).permute(0, 2, 1).reshape(n, -1, c4.shape[2], c4.shape[3]) - _c4 = resize(_c4, size=c1.size()[2:], mode='bilinear', align_corners=False) - - _c3 = self.linear_c3(c3).permute(0, 2, 1).reshape(n, -1, 
c3.shape[2], c3.shape[3]) - _c3 = resize(_c3, size=c1.size()[2:], mode='bilinear', align_corners=False) - - _c2 = self.linear_c2(c2).permute(0, 2, 1).reshape(n, -1, c2.shape[2], c2.shape[3]) - _c2 = resize(_c2, size=c1.size()[2:], mode='bilinear', align_corners=False) - - _c1 = self.linear_c1(c1).permute(0, 2, 1).reshape(n, -1, c1.shape[2], c1.shape[3]) - - _c = torch.cat([_c4, _c3, _c2, _c1], dim=1) - _c = self.conv_fuse(_c) - - x = self.norm(_c) - x = F.relu(x, inplace=True) - x = self.dropout(x) - - x = self.linear_pred(x) - - output = einops.rearrange( - x, - "b (c p1 p2) h w -> b c (h p1) (w p2)", - p1=self.patch_dim, - p2=self.patch_dim, - h=self.img_h//self.patch_dim, - w=self.img_w//self.patch_dim, - ) - - return output diff --git a/megatron/model/vision/knn_monitor.py b/megatron/model/vision/knn_monitor.py deleted file mode 100644 index a7d79854eb51f24443fcdce9e029df5512503bf9..0000000000000000000000000000000000000000 --- a/megatron/model/vision/knn_monitor.py +++ /dev/null @@ -1,129 +0,0 @@ -import torch.nn.functional as F -import torch -from megatron import print_rank_0, get_args -from megatron.core import mpu -from megatron.data.vit_dataset import ClassificationTransform -from megatron.data.image_folder import ImageFolder - -_FEATURE_BANK = None - - -def build_data_loader(dataset, drop_last=True, shuffle=False): - """Data loader. Note that batch-size is the local (per GPU) batch-size.""" - # Sampler. - args = get_args() - micro_batch_size = 16 - num_workers = args.num_workers - world_size = mpu.get_data_parallel_world_size() - rank = mpu.get_data_parallel_rank() - sampler = torch.utils.data.distributed.DistributedSampler( - dataset, num_replicas=world_size, rank=rank, - drop_last=drop_last, shuffle=shuffle - ) - - # Data loader. Note that batch size is the per GPU batch size. 
- data_loader = torch.utils.data.DataLoader( - dataset, - batch_size=micro_batch_size, - sampler=sampler, - shuffle=False, - num_workers=num_workers, - drop_last=not drop_last, - pin_memory=True, - ) - return data_loader - - -def compute_feature_bank(model): - args = get_args() - global _FEATURE_BANK - feature_bank = [] - feature_label = [] - - train_ds = ImageFolder( - root=args.data_path[0], - transform=ClassificationTransform((args.img_h, args.img_w), train=False), - data_per_class_fraction=1.0 - ) - classes = len(train_ds.classes) - dataloader = build_data_loader(train_ds) - - for m in model: - m.eval() - - with torch.no_grad(): - for i, batch in enumerate(dataloader): - images = batch[0].cuda().contiguous() - labels = batch[1].cuda().contiguous() - student_feature, teacher_feature = model[0](images) - feature = F.normalize(teacher_feature.float(), dim=1) - feature_bank.append(feature) - feature_label.append(labels) - - for m in model: - m.train() - - # [N', D] - feature_bank = torch.cat(feature_bank, dim=0).contiguous() - feature_label = torch.cat(feature_label, dim=0).contiguous() - - feature_banks = [torch.zeros_like(feature_bank) - for i in range(mpu.get_data_parallel_world_size())] - torch.distributed.all_gather(feature_banks, - feature_bank, - group=mpu.get_data_parallel_group()) - - assert torch.all(torch.eq(feature_banks[mpu.get_data_parallel_rank()], - feature_bank)) - - feature_labels = [torch.zeros_like(feature_label) - for i in range(mpu.get_data_parallel_world_size())] - torch.distributed.all_gather(feature_labels, - feature_label, - group=mpu.get_data_parallel_group()) - - # [D, N] - feature_banks = torch.cat(feature_banks, dim=0).t().contiguous() - # [N] - feature_labels = torch.cat(feature_labels, dim=0).contiguous() - print_rank_0("feature_banks size is {}".format(feature_banks.size())) - print_rank_0("feature labels size is {}".format(feature_labels.size())) - - _FEATURE_BANK = (feature_banks, feature_labels, classes) - - -def get_feature_bank(): - global _FEATURE_BANK - assert _FEATURE_BANK is not None - return _FEATURE_BANK - - -# knn monitor as in InstDisc https://arxiv.org/abs/1805.01978 -# implementation follows http://github.com/zhirongw/lemniscate.pytorch and -# https://github.com/leftthomas/SimCLR -def knn_predict(feature, feature_bank, feature_labels, classes, knn_k, knn_t): - # compute cos similarity between each feature vector and feature bank ---> [B, N] - sim_matrix = torch.mm(feature, feature_bank) - # [B, K] - sim_weight, sim_indices = sim_matrix.topk(k=knn_k, dim=-1) - # [B, K] - sim_labels = torch.gather(feature_labels.expand(feature.size(0), -1), - dim=-1, - index=sim_indices) - sim_weight = (sim_weight / knn_t).exp() - - # counts for each class - one_hot_label = torch.zeros(feature.size(0) * knn_k, - classes, - device=sim_labels.device) - # [B*K, C] - one_hot_label = one_hot_label.scatter(dim=-1, - index=sim_labels.view(-1, 1), - value=1.0) - # weighted score ---> [B, C] - pred_scores = torch.sum( - one_hot_label.view(feature.size(0), -1, classes) * sim_weight.unsqueeze(dim=-1), - dim=1) - - pred_labels = pred_scores.argsort(dim=-1, descending=True) - return pred_labels diff --git a/megatron/model/vision/mit_backbone.py b/megatron/model/vision/mit_backbone.py deleted file mode 100644 index 6640b105dfce36169b51dd93351469000af4fdbf..0000000000000000000000000000000000000000 --- a/megatron/model/vision/mit_backbone.py +++ /dev/null @@ -1,415 +0,0 @@ -# Copyright (c) 2023, NVIDIA Corporation. All rights reserved. 
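`knn_monitor.py` above classifies a query feature by cosine similarity against a normalized feature bank: take the top-k most similar bank entries, weight their one-hot labels by exp(similarity / T), and rank the classes by the summed votes (the InstDisc-style monitor cited in the source). A self-contained sketch of that scoring step on random data; the toy shapes are made up, but the tensor manipulations mirror the deleted `knn_predict`:

```python
import torch
import torch.nn.functional as F

B, D, N, C, K, T = 4, 16, 100, 10, 5, 0.07  # queries, feat dim, bank size, classes, top-k, temperature

feature = F.normalize(torch.randn(B, D), dim=1)            # query features [B, D]
feature_bank = F.normalize(torch.randn(N, D), dim=1).t()   # bank stored as [D, N], as compute_feature_bank does
feature_labels = torch.randint(0, C, (N,))                 # bank labels [N]

sim = feature @ feature_bank                               # cosine similarity [B, N]
sim_weight, sim_idx = sim.topk(k=K, dim=-1)                # [B, K]
sim_labels = feature_labels.expand(B, -1).gather(-1, sim_idx)
sim_weight = (sim_weight / T).exp()                        # temperature-scaled weights

one_hot = torch.zeros(B * K, C).scatter(-1, sim_labels.view(-1, 1), 1.0)
scores = (one_hot.view(B, K, C) * sim_weight.unsqueeze(-1)).sum(dim=1)  # weighted votes [B, C]
pred_labels = scores.argsort(dim=-1, descending=True)      # classes ranked per query
print(pred_labels[:, 0])                                   # top-1 prediction for each query
```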
- -import math -import torch -import torch.nn as nn -import torch.nn.functional as F -from functools import partial -from torch.nn.init import trunc_normal_ -from megatron.model.transformer import DropPath -from megatron.model import LayerNorm - - -class Mlp(nn.Module): - def __init__(self, - in_features, - hidden_features=None, - out_features=None, - act_layer=nn.GELU, - drop=0.): - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - self.fc1 = nn.Linear(in_features, hidden_features) - self.dwconv = DWConv(hidden_features) - self.act = act_layer() - self.fc2 = nn.Linear(hidden_features, out_features) - self.drop = nn.Dropout(drop) - - self.apply(self._init_weights) - - def _init_weights(self, m): - if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=.02) - if isinstance(m, nn.Linear) and m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) - elif isinstance(m, nn.Conv2d): - fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels - fan_out //= m.groups - m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) - if m.bias is not None: - m.bias.data.zero_() - - def forward(self, x, H, W): - x = self.fc1(x) - x = self.dwconv(x, H, W) - x = self.act(x) - x = self.drop(x) - x = self.fc2(x) - x = self.drop(x) - return x - - -class Attention(nn.Module): - def __init__(self, - dim, - num_heads=8, - qkv_bias=False, - qk_scale=None, - attn_drop=0., - proj_drop=0., - sr_ratio=1): - super().__init__() - assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." - - self.dim = dim - self.num_heads = num_heads - head_dim = dim // num_heads - self.scale = qk_scale or head_dim ** -0.5 - - self.q = nn.Linear(dim, dim, bias=qkv_bias) - self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) - self.proj_drop = nn.Dropout(proj_drop) - - self.sr_ratio = sr_ratio - if sr_ratio > 1: - self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) - self.norm = LayerNorm(dim) - - self.apply(self._init_weights) - - def _init_weights(self, m): - if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=.02) - if isinstance(m, nn.Linear) and m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) - elif isinstance(m, nn.Conv2d): - fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels - fan_out //= m.groups - m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) - if m.bias is not None: - m.bias.data.zero_() - - def forward(self, x, H, W): - B, N, C = x.shape - q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) - - if self.sr_ratio > 1: - x_ = x.permute(0, 2, 1).reshape(B, C, H, W) - x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1) - x_ = self.norm(x_) - kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) - else: - kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) - k, v = kv[0], kv[1] - - attn = (q @ k.transpose(-2, -1)) * self.scale - attn = attn.softmax(dim=-1) - attn = self.attn_drop(attn) - - x = (attn @ v).transpose(1, 2).reshape(B, N, C) - x = self.proj(x) - x = self.proj_drop(x) - - return x - - -class Block(nn.Module): - - def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., 
attn_drop=0., - drop_path=0., act_layer=nn.GELU, norm_layer=LayerNorm, sr_ratio=1): - super().__init__() - self.norm1 = norm_layer(dim) - self.attn = Attention( - dim, - num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, - attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio) - # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here - self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() - self.norm2 = norm_layer(dim) - mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) - - self.apply(self._init_weights) - - def _init_weights(self, m): - if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=.02) - if isinstance(m, nn.Linear) and m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) - elif isinstance(m, nn.Conv2d): - fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels - fan_out //= m.groups - m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) - if m.bias is not None: - m.bias.data.zero_() - - def forward(self, x, H, W): - x = x + self.drop_path(self.attn(self.norm1(x), H, W)) - x = x + self.drop_path(self.mlp(self.norm2(x), H, W)) - - return x - - -class OverlapPatchEmbed(nn.Module): - """ Image to Patch Embedding - """ - - def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768): - super().__init__() - img_size = (img_size, img_size) - patch_size = (patch_size, patch_size) - - self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride, - padding=(patch_size[0] // 2, patch_size[1] // 2)) - self.norm = LayerNorm(embed_dim) - - self.apply(self._init_weights) - - def _init_weights(self, m): - if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=.02) - if isinstance(m, nn.Linear) and m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) - elif isinstance(m, nn.Conv2d): - fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels - fan_out //= m.groups - m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) - if m.bias is not None: - m.bias.data.zero_() - - def forward(self, x): - x = self.proj(x) - _, _, H, W = x.shape - x = x.flatten(2).transpose(1, 2) - x = self.norm(x) - - return x, H, W - - -class MixVisionTransformer(nn.Module): - def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512], - num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0., - attn_drop_rate=0., drop_path_rate=0., norm_layer=LayerNorm, - depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], output_avg=False): - super().__init__() - self.num_classes = num_classes - self.depths = depths - self.output_avg = output_avg - - # patch_embed - self.patch_embed1 = OverlapPatchEmbed(img_size=img_size, patch_size=7, stride=4, in_chans=in_chans, - embed_dim=embed_dims[0]) - self.patch_embed2 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0], - embed_dim=embed_dims[1]) - self.patch_embed3 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[1], - embed_dim=embed_dims[2]) - self.patch_embed4 = OverlapPatchEmbed(img_size=img_size // 16, patch_size=3, stride=2, in_chans=embed_dims[2], - embed_dim=embed_dims[3]) - - # transformer encoder - dpr = [x.item() for x in 
torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule - cur = 0 - self.block1 = nn.ModuleList([Block( - dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, - sr_ratio=sr_ratios[0]) - for i in range(depths[0])]) - self.norm1 = norm_layer(embed_dims[0]) - - cur += depths[0] - self.block2 = nn.ModuleList([Block( - dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, - sr_ratio=sr_ratios[1]) - for i in range(depths[1])]) - self.norm2 = norm_layer(embed_dims[1]) - - cur += depths[1] - self.block3 = nn.ModuleList([Block( - dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, - sr_ratio=sr_ratios[2]) - for i in range(depths[2])]) - self.norm3 = norm_layer(embed_dims[2]) - - cur += depths[2] - self.block4 = nn.ModuleList([Block( - dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, - sr_ratio=sr_ratios[3]) - for i in range(depths[3])]) - self.norm4 = norm_layer(embed_dims[3]) - - self.apply(self._init_weights) - - def _init_weights(self, m): - if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=.02) - if isinstance(m, nn.Linear) and m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) - elif isinstance(m, nn.Conv2d): - fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels - fan_out //= m.groups - m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) - if m.bias is not None: - m.bias.data.zero_() - - def reset_drop_path(self, drop_path_rate): - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))] - cur = 0 - for i in range(self.depths[0]): - self.block1[i].drop_path.drop_prob = dpr[cur + i] - - cur += self.depths[0] - for i in range(self.depths[1]): - self.block2[i].drop_path.drop_prob = dpr[cur + i] - - cur += self.depths[1] - for i in range(self.depths[2]): - self.block3[i].drop_path.drop_prob = dpr[cur + i] - - cur += self.depths[2] - for i in range(self.depths[3]): - self.block4[i].drop_path.drop_prob = dpr[cur + i] - - def freeze_patch_emb(self): - self.patch_embed1.requires_grad = False - - def forward_features(self, x): - B = x.shape[0] - outs = [] - - # stage 1 - x, H, W = self.patch_embed1(x) - for i, blk in enumerate(self.block1): - x = blk(x, H, W) - x = self.norm1(x) - x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() - outs.append(x) - - # stage 2 - x, H, W = self.patch_embed2(x) - for i, blk in enumerate(self.block2): - x = blk(x, H, W) - x = self.norm2(x) - x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() - outs.append(x) - - # stage 3 - x, H, W = self.patch_embed3(x) - for i, blk in enumerate(self.block3): - x = blk(x, H, W) - x = self.norm3(x) - x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() - outs.append(x) - - # stage 4 - x, H, W = self.patch_embed4(x) - for i, blk in enumerate(self.block4): - x = blk(x, H, W) - x = self.norm4(x) - if not self.output_avg: - x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() - 
outs.append(x) - - return outs - - def forward(self, x): - x = self.forward_features(x) - - if self.output_avg: - x = x[3].mean(dim=1) - - return x - - -class DWConv(nn.Module): - def __init__(self, dim=768): - super(DWConv, self).__init__() - self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) - - def forward(self, x, H, W): - B, N, C = x.shape - x = x.transpose(1, 2).view(B, C, H, W) - x = self.dwconv(x) - x = x.flatten(2).transpose(1, 2) - - return x - -class mit_b0(MixVisionTransformer): - def __init__(self, **kwargs): - super(mit_b0, self).__init__( - patch_size=4, embed_dims=[32, 64, 160, 256], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], - qkv_bias=True, norm_layer=partial(LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], - drop_rate=0.0, drop_path_rate=0.1) - - -class mit_b1(MixVisionTransformer): - def __init__(self, **kwargs): - super(mit_b1, self).__init__( - patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], - qkv_bias=True, norm_layer=partial(LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], - drop_rate=0.0, drop_path_rate=0.1) - - -class mit_b2(MixVisionTransformer): - def __init__(self, **kwargs): - super(mit_b2, self).__init__( - patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], - qkv_bias=True, norm_layer=partial(LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], - drop_rate=0.0, drop_path_rate=0.1) - - -class mit_b3(MixVisionTransformer): - def __init__(self, **kwargs): - super(mit_b3, self).__init__( - patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], - qkv_bias=True, norm_layer=partial(LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1], - drop_rate=0.0, drop_path_rate=0.1) - -class mit_b3_avg(MixVisionTransformer): - def __init__(self, drop_path_rate=0.1, **kwargs): - super(mit_b3_avg, self).__init__( - patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], - qkv_bias=True, norm_layer=partial(LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1], - drop_rate=0.0, drop_path_rate=drop_path_rate, output_avg=True) - -class mit_b4(MixVisionTransformer): - def __init__(self, **kwargs): - super(mit_b4, self).__init__( - patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], - qkv_bias=True, norm_layer=partial(LayerNorm, eps=1e-6), depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1], - drop_rate=0.0, drop_path_rate=0.1) - -class mit_b5(MixVisionTransformer): - def __init__(self, **kwargs): - super(mit_b5, self).__init__( - patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], - qkv_bias=True, norm_layer=partial(LayerNorm, eps=1e-6), depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1], - drop_rate=0.0, drop_path_rate=0.1) - -class mit_b5_avg(MixVisionTransformer): - def __init__(self, drop_path_rate=0.1, **kwargs): - super(mit_b5_avg, self).__init__( - patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], - qkv_bias=True, norm_layer=partial(LayerNorm, eps=1e-6), depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1], - drop_rate=0.0, drop_path_rate=drop_path_rate, output_avg=True) - diff --git a/megatron/model/vision/swin_backbone.py b/megatron/model/vision/swin_backbone.py deleted file mode 100644 index 9a622c7070f5a8335b56ce01398e3e464e3fef6a..0000000000000000000000000000000000000000 --- 
a/megatron/model/vision/swin_backbone.py +++ /dev/null @@ -1,625 +0,0 @@ -# Copyright (c) 2021 Microsoft -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. -# -------------------------------------------------------- -# Swin Transformer -# -------------------------------------------------------- - -import torch -import torch.nn as nn -import torch.utils.checkpoint as checkpoint -from timm.models.layers import DropPath, to_2tuple, trunc_normal_ -from math import sqrt - -from megatron import get_args -from functools import partial - - -class Mlp(nn.Module): - def __init__(self, in_features, hidden_features=None, - out_features=None, act_layer=nn.GELU, drop=0.): - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - self.fc1 = nn.Linear(in_features, hidden_features) - self.act = act_layer() - self.fc2 = nn.Linear(hidden_features, out_features) - self.drop = nn.Dropout(drop) - - def forward(self, x): - x = self.fc1(x) - x = self.act(x) - x = self.drop(x) - x = self.fc2(x) - x = self.drop(x) - return x - - -def window_partition(x, window_size): - """ - Args: - x: (B, H, W, C) - window_size (int): window size - - Returns: - windows: (num_windows*B, window_size, window_size, C) - """ - B, H, W, C = x.shape - x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) - windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) - return windows - - -def window_reverse(windows, window_size, H, W): - """ - Args: - windows: (num_windows*B, window_size, window_size, C) - window_size (int): Window size - H (int): Height of image - W (int): Width of image - - Returns: - x: (B, H, W, C) - """ - B = int(windows.shape[0] / (H * W / window_size / window_size)) - x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) - x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) - return x - - -class WindowAttention(nn.Module): - r""" Window based multi-head self attention (W-MSA) module with relative position bias. - It supports both of shifted and non-shifted window. - - Args: - dim (int): Number of input channels. - window_size (tuple[int]): The height and width of the window. - num_heads (int): Number of attention heads. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set - attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 - proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 - """ - - def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): - - super().__init__() - self.dim = dim - self.window_size = window_size # Wh, Ww - self.num_heads = num_heads - head_dim = dim // num_heads - self.scale = qk_scale or head_dim ** -0.5 - - # define a parameter table of relative position bias - self.relative_position_bias_table = nn.Parameter( - torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH - - # get pair-wise relative position index for each token inside the window - coords_h = torch.arange(self.window_size[0]) - coords_w = torch.arange(self.window_size[1]) - coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww - coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww - relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww - relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 - relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 - relative_coords[:, :, 1] += self.window_size[1] - 1 - relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 - relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww - self.register_buffer("relative_position_index", relative_position_index) - - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) - self.proj_drop = nn.Dropout(proj_drop) - - trunc_normal_(self.relative_position_bias_table, std=.02) - self.softmax = nn.Softmax(dim=-1) - - def forward(self, x, mask=None): - """ - Args: - x: input features with shape of (num_windows*B, N, C) - mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None - """ - B_, N, C = x.shape - qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) - q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) - - q = q * self.scale - attn = (q @ k.transpose(-2, -1)) - - relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( - self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH - relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww - attn = attn + relative_position_bias.unsqueeze(0) - - if mask is not None: - nW = mask.shape[0] - attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) - attn = attn.view(-1, self.num_heads, N, N) - attn = self.softmax(attn) - else: - attn = self.softmax(attn) - - attn = self.attn_drop(attn) - - x = (attn @ v).transpose(1, 2).reshape(B_, N, C) - x = self.proj(x) - x = self.proj_drop(x) - return x - - def extra_repr(self) -> str: - return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' - - def flops(self, N): - # calculate flops for 1 window with token length of N - flops = 0 - # qkv = self.qkv(x) - flops += N * self.dim * 3 * self.dim - # attn = (q @ k.transpose(-2, -1)) - flops += self.num_heads * N * (self.dim // self.num_heads) * N - # x = (attn @ v) - flops += self.num_heads * N * N * (self.dim // self.num_heads) - # x = self.proj(x) - flops += N * self.dim * self.dim - return flops - - -class SwinTransformerBlock(nn.Module): - r""" Swin Transformer Block. - - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resulotion. - num_heads (int): Number of attention heads. 
- window_size (int): Window size. - shift_size (int): Shift size for SW-MSA. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float, optional): Stochastic depth rate. Default: 0.0 - act_layer (nn.Module, optional): Activation layer. Default: nn.GELU - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - """ - - def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, - mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., - act_layer=nn.GELU, norm_layer=nn.LayerNorm): - super().__init__() - self.dim = dim - self.input_resolution = input_resolution - self.num_heads = num_heads - self.window_size = window_size - self.shift_size = shift_size - self.mlp_ratio = mlp_ratio - if min(self.input_resolution) <= self.window_size: - # if window size is larger than input resolution, we don't partition windows - self.shift_size = 0 - self.window_size = min(self.input_resolution) - assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" - - self.norm1 = norm_layer(dim) - self.attn = WindowAttention( - dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, - qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) - - self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() - self.norm2 = norm_layer(dim) - mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) - - self.H = input_resolution[0] - self.W = input_resolution[1] - - self.attn_mask_dict = {} - - def create_attn_mask(self, H, W): - # calculate attention mask for SW-MSA - - Hp = int(np.ceil(H / self.window_size)) * self.window_size - Wp = int(np.ceil(W / self.window_size)) * self.window_size - img_mask = torch.zeros((1, Hp, Wp, 1)) # 1 Hp Wp 1 - h_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - w_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - cnt = 0 - for h in h_slices: - for w in w_slices: - img_mask[:, h, w, :] = cnt - cnt += 1 - - mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 - mask_windows = mask_windows.view(-1, self.window_size * self.window_size) - attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) - attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) - - return attn_mask - - - def forward(self, x): - B, L, C = x.shape - H = int(sqrt(L)) - W = H - - shortcut = x - x = self.norm1(x) - x = x.view(B, H, W, C) - - # cyclic shift - if self.shift_size > 0: - shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) - else: - shifted_x = x - - # partition windows - x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C - x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C - - # W-MSA/SW-MSA - attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C - - # merge windows - attn_windows 
= attn_windows.view(-1, self.window_size, self.window_size, C) - shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C - - # reverse cyclic shift - if self.shift_size > 0: - x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) - else: - x = shifted_x - x = x.view(B, H * W, C) - - # FFN - x = shortcut + self.drop_path(x) - x = x + self.drop_path(self.mlp(self.norm2(x))) - - return x - - def extra_repr(self) -> str: - return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ - f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" - - def flops(self): - flops = 0 - H, W = self.input_resolution - # norm1 - flops += self.dim * H * W - # W-MSA/SW-MSA - nW = H * W / self.window_size / self.window_size - flops += nW * self.attn.flops(self.window_size * self.window_size) - # mlp - flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio - # norm2 - flops += self.dim * H * W - return flops - - -class PatchMerging(nn.Module): - r""" Patch Merging Layer. - - Args: - input_resolution (tuple[int]): Resolution of input feature. - dim (int): Number of input channels. - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - """ - - def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): - super().__init__() - self.input_resolution = input_resolution - self.dim = dim - self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) - self.norm = norm_layer(4 * dim) - - def forward(self, x): - """ - x: B, H*W, C - """ - H, W = self.input_resolution - B, L, C = x.shape - assert L == H * W, "input feature has wrong size" - assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." - - x = x.view(B, H, W, C) - - x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C - x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C - x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C - x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C - x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C - x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C - - x = self.norm(x) - x = self.reduction(x) - - return x - - def extra_repr(self) -> str: - return f"input_resolution={self.input_resolution}, dim={self.dim}" - - def flops(self): - H, W = self.input_resolution - flops = H * W * self.dim - flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim - return flops - - -class BasicLayer(nn.Module): - """ A basic Swin Transformer layer for one stage. - - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resolution. - depth (int): Number of blocks. - num_heads (int): Number of attention heads. - window_size (int): Local window size. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None - use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
- """ - - def __init__(self, dim, input_resolution, depth, num_heads, window_size, - mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., - drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): - - super().__init__() - self.dim = dim - self.input_resolution = input_resolution - self.depth = depth - self.use_checkpoint = use_checkpoint - - # build blocks - self.blocks = nn.ModuleList([ - SwinTransformerBlock(dim=dim, input_resolution=input_resolution, - num_heads=num_heads, window_size=window_size, - shift_size=0 if (i % 2 == 0) else window_size // 2, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop, attn_drop=attn_drop, - drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, - norm_layer=norm_layer) - for i in range(depth)]) - - # patch merging layer - if downsample is not None: - self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) - else: - self.downsample = None - - def forward(self, x): - for blk in self.blocks: - if self.use_checkpoint: - x = checkpoint.checkpoint(blk, x) - else: - x = blk(x) - x_b4_ds = x - if self.downsample is not None: - x = self.downsample(x) - return x_b4_ds, x - - def extra_repr(self) -> str: - return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" - - def flops(self): - flops = 0 - for blk in self.blocks: - flops += blk.flops() - if self.downsample is not None: - flops += self.downsample.flops() - return flops - - -class PatchEmbed(nn.Module): - r""" Image to Patch Embedding - - Args: - img_size (int): Image size. Default: 224. - patch_size (int): Patch token size. Default: 4. - in_chans (int): Number of input image channels. Default: 3. - embed_dim (int): Number of linear projection output channels. Default: 96. - norm_layer (nn.Module, optional): Normalization layer. Default: None - """ - - def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): - super().__init__() - img_size = to_2tuple(img_size) - patch_size = to_2tuple(patch_size) - patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] - self.img_size = img_size - self.patch_size = patch_size - self.patches_resolution = patches_resolution - self.num_patches = patches_resolution[0] * patches_resolution[1] - - self.in_chans = in_chans - self.embed_dim = embed_dim - - self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) - if norm_layer is not None: - self.norm = norm_layer(embed_dim) - else: - self.norm = None - - def forward(self, x): - B, C, H, W = x.shape - # FIXME look at relaxing size constraints - assert H == self.img_size[0] and W == self.img_size[1], \ - f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." - x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C - if self.norm is not None: - x = self.norm(x) - return x - - def flops(self): - Ho, Wo = self.patches_resolution - flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) - if self.norm is not None: - flops += Ho * Wo * self.embed_dim - return flops - - -class SwinTransformer(nn.Module): - r""" Swin Transformer - A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - - https://arxiv.org/pdf/2103.14030 - - Args: - img_size (int | tuple(int)): Input image size. Default 224 - patch_size (int | tuple(int)): Patch size. Default: 4 - in_chans (int): Number of input image channels. 
Default: 3 - embed_dim (int): Patch embedding dimension. Default: 96 - depths (tuple(int)): Depth of each Swin Transformer layer. - num_heads (tuple(int)): Number of attention heads in different layers. - window_size (int): Window size. Default: 7 - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 - qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None - drop_rate (float): Dropout rate. Default: 0 - attn_drop_rate (float): Attention dropout rate. Default: 0 - drop_path_rate (float): Stochastic depth rate. Default: 0.1 - norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. - ape (bool): If True, add absolute position embedding to the patch embedding. Default: False - patch_norm (bool): If True, add normalization after patch embedding. Default: True - use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False - """ - - def __init__(self, img_size=224, patch_size=4, in_chans=3, - embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], - window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, - drop_rate=0., attn_drop_rate=0., drop_path_rate=0.3, - norm_layer=partial(nn.LayerNorm, eps=1e-6), ape=False, patch_norm=True, - use_checkpoint=False, output_avg=False, **kwargs): - super().__init__() - - self.num_layers = len(depths) - self.embed_dim = embed_dim - self.ape = ape - self.patch_norm = patch_norm - self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) - self.mlp_ratio = mlp_ratio - self.img_size = to_2tuple(img_size) - self.patch_size = to_2tuple(patch_size) - self.output_avg = output_avg - - # split image into non-overlapping patches - self.patch_embed = PatchEmbed( - img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, - norm_layer=norm_layer if self.patch_norm else None) - num_patches = self.patch_embed.num_patches - patches_resolution = self.patch_embed.patches_resolution - self.patches_resolution = patches_resolution - - # absolute position embedding - if self.ape: - self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) - trunc_normal_(self.absolute_pos_embed, std=.02) - - self.pos_drop = nn.Dropout(p=drop_rate) - - # stochastic depth - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule - - # build layers - self.layers = nn.ModuleList() - for i_layer in range(self.num_layers): - layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), - input_resolution=(patches_resolution[0] // (2 ** i_layer), - patches_resolution[1] // (2 ** i_layer)), - depth=depths[i_layer], - num_heads=num_heads[i_layer], - window_size=window_size, - mlp_ratio=self.mlp_ratio, - qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop_rate, attn_drop=attn_drop_rate, - drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], - norm_layer=norm_layer, - downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, - use_checkpoint=use_checkpoint) - self.layers.append(layer) - - self.apply(self._init_weights) - - def _init_weights(self, m): - if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=.02) - if isinstance(m, nn.Linear) and m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) - - @torch.jit.ignore - def no_weight_decay(self): - return {'absolute_pos_embed'} - - @torch.jit.ignore - def 
no_weight_decay_keywords(self): - return {'relative_position_bias_table'} - - def forward(self, x): - x = self.patch_embed(x) - if self.ape: - x = x + self.absolute_pos_embed - x = self.pos_drop(x) - - h = self.img_size[0] // self.patch_size[0] - w = self.img_size[1] // self.patch_size[1] - outs = [] - - for i, layer in enumerate(self.layers): - px, x = layer(x) - b, n, c = px.shape - - if i != len(self.layers) - 1 or not self.output_avg: - px = px.permute(0, 2, 1).contiguous() - px = px.reshape(b, c, h, w) - # is this a fair assumption ?? i think it's baked into the architecture - h, w = h//2, w//2 - outs.append(px) - - if self.output_avg: - return outs[-1].mean(dim=1) - - return outs - - def flops(self): - flops = 0 - flops += self.patch_embed.flops() - for i, layer in enumerate(self.layers): - flops += layer.flops() - flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers) - flops += self.num_features * self.num_classes - return flops - - -def get_swin(drop_path_rate=0.3, output_avg=False): - args = get_args() - - window_size = 7 - embed_dim = 128 - depths = [2, 2, 18, 2] - num_heads = [4, 8, 16, 32] - swin = SwinTransformer( - img_size=(args.img_h, args.img_w,), - in_chans=3, - patch_size=args.patch_dim, - embed_dim=embed_dim, - depths=depths, - num_heads=num_heads, - window_size=window_size, - drop_path_rate=drop_path_rate, - output_avg=output_avg, - ) - - return swin - diff --git a/megatron/model/vision/utils.py b/megatron/model/vision/utils.py deleted file mode 100644 index b4068912c8bb234eff54d6b4feae499f7e8ab30c..0000000000000000000000000000000000000000 --- a/megatron/model/vision/utils.py +++ /dev/null @@ -1,27 +0,0 @@ -import warnings -import torch -import torch.nn.functional as F - - -def resize(input, - size=None, - scale_factor=None, - mode='nearest', - align_corners=None, - warning=True): - if warning: - if size is not None and align_corners: - input_h, input_w = tuple(int(x) for x in input.shape[2:]) - output_h, output_w = tuple(int(x) for x in size) - if output_h > input_h or output_w > output_h: - if ((output_h > 1 and output_w > 1 and input_h > 1 - and input_w > 1) and (output_h - 1) % (input_h - 1) - and (output_w - 1) % (input_w - 1)): - warnings.warn( - f'When align_corners={align_corners}, ' - 'the output would more aligned if ' - f'input size {(input_h, input_w)} is `x+1` and ' - f'out size {(output_h, output_w)} is `nx+1`') - if isinstance(size, torch.Size): - size = tuple(int(x) for x in size) - return F.interpolate(input, size, scale_factor, mode, align_corners) diff --git a/megatron/model/vision/vit_backbone.py b/megatron/model/vision/vit_backbone.py deleted file mode 100644 index 15cf75affcd7cd111d5f2663d707a0026ff708a3..0000000000000000000000000000000000000000 --- a/megatron/model/vision/vit_backbone.py +++ /dev/null @@ -1,248 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. - -"""Vision Transformer(VIT) model.""" - -import math -import einops -import torch -import apex -import torch.nn.functional as F -from megatron import get_args -from megatron.model.transformer import ParallelTransformer -from megatron.model.utils import ( - get_linear_layer, - init_method_normal, - scaled_init_method_normal, -) -from megatron.model.module import MegatronModule - -CLASS_TOKEN_LENGTH = 8 - -class VitMlpHead(MegatronModule): - """Pooler layer. - - Pool hidden states of a specific token (for example start of the - sequence) and add a linear transformation followed by a tanh. 
- - Arguments: - hidden_size: hidden size - init_method: weight initialization method for the linear layer. - bias is set to zero. - """ - - def __init__(self, config, hidden_size, num_classes): - super(VitMlpHead, self).__init__() - self.config = config - self.dense_in = torch.nn.Linear(hidden_size, hidden_size) - self.relu = torch.nn.ReLU() - self.dense_out = torch.nn.Linear(hidden_size, num_classes) - torch.nn.init.constant_(self.dense_out.bias, -10) - - def forward(self, hidden_states): - # hidden_states: [b, 1, h] - # sequence_index: index of the token to pool. - dense_in_result = self.dense_in(hidden_states) - tanh_result = torch.tanh(dense_in_result) - dense_out_result = self.dense_out(tanh_result) - return dense_out_result - - -def isPerfectSquare(x): - if(x >= 0): - sr = math.sqrt(x) - return (int(sr) * int(sr) == x) - return False - - -def twod_interpolate_position_embeddings_hook( - state_dict, - prefix, - local_metadata, - strict, - missing_keys, - unexpected_keys, - error_msgs, -): - - args = get_args() - num_patches_per_dim_h = args.img_h // args.patch_dim - num_patches_per_dim_w = args.img_w // args.patch_dim - num_patches = num_patches_per_dim_h * num_patches_per_dim_w - hidden_size = args.hidden_size - - key = prefix + "weight" - - assert key in state_dict - if key in state_dict: - input_param = state_dict[key] - - input_seq_len = input_param.shape[0] - assert(isPerfectSquare(input_seq_len) or isPerfectSquare(input_seq_len - CLASS_TOKEN_LENGTH)) - input_has_class_token = not isPerfectSquare(input_seq_len) - num_tok_input = input_seq_len - CLASS_TOKEN_LENGTH if input_has_class_token else input_seq_len - num_tok_output = num_patches - output_has_class_token = args.class_token_present - - # update input_param and load it to state_dict[key] - if input_has_class_token: - input_param_tok = input_param[:CLASS_TOKEN_LENGTH, :] - input_param_grid = input_param[CLASS_TOKEN_LENGTH:, :] - else: - input_param_tok = torch.zeros(CLASS_TOKEN_LENGTH, hidden_size) - input_param_grid = input_param - - assert input_param.shape[1] == hidden_size - - if num_tok_input != num_tok_output: - - gs_input = int(math.sqrt(num_tok_input)) - gs_new = (num_patches_per_dim_h, num_patches_per_dim_w) - - input_param_grid = input_param_grid.transpose(0, 1).contiguous() - input_param_grid = input_param_grid.reshape( - (1, -1, gs_input, gs_input) - ) - input_param_grid = input_param_grid.float() - scale_factor = (gs_new[0] / gs_input, gs_new[1] / gs_input) - - input_param_grid = F.interpolate( - input_param_grid, scale_factor=scale_factor, mode="bilinear" - ) - - input_param_grid = input_param_grid.half() - input_param_grid = input_param_grid.reshape((-1, num_tok_output)) - input_param_grid = input_param_grid.transpose(0, 1).contiguous() - - assert input_param_grid.shape[1] == hidden_size - - input_param = input_param_grid - assert ( - input_param.shape[0] == num_tok_output - and input_param.shape[1] == hidden_size - ) - - if output_has_class_token: - input_param = torch.cat((input_param_tok, input_param), dim=0) - - state_dict[key] = input_param - - -class VitBackbone(MegatronModule): - """Vision Transformer Model.""" - - def __init__(self, - config, - pre_process=True, - post_process=True, - class_token=True, - single_token_output=False, - post_layer_norm=True, - drop_path_rate=0.0): - super(VitBackbone, self).__init__(share_embeddings_and_output_weights=False) - args = get_args() - self.config = config - - self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy - - self.pre_process = pre_process - 
self.post_process = post_process - self.class_token = class_token - self.post_layer_norm = post_layer_norm - self.hidden_size = args.hidden_size - self.patch_dim = args.patch_dim - self.img_h = args.img_h - self.img_w = args.img_w - self.micro_batch_size = args.micro_batch_size - self.single_token_output = single_token_output - self.drop_path_rate = drop_path_rate - - assert self.img_h % self.patch_dim == 0 - assert self.img_w % self.patch_dim == 0 - self.num_patches_per_dim_h = self.img_h // self.patch_dim - self.num_patches_per_dim_w = self.img_w // self.patch_dim - self.num_patches = self.num_patches_per_dim_h * self.num_patches_per_dim_w - self.seq_length = self.num_patches + (CLASS_TOKEN_LENGTH if self.class_token else 0) - self.flatten_dim = self.patch_dim * self.patch_dim * args.num_channels - self.input_tensor = None - self.position_ids = None - - if self.pre_process: - # cls_token - if self.class_token: - self.cls_token = torch.nn.Parameter( - torch.randn(1, CLASS_TOKEN_LENGTH, self.hidden_size) - ) - torch.nn.init.zeros_(self.cls_token) - self.position_ids = torch.arange(self.seq_length).expand(1, -1).cuda() - - # Linear encoder - self.linear_encoder = torch.nn.Linear( - self.flatten_dim, self.hidden_size - ) - - # embedding - self.position_embeddings = torch.nn.Embedding( - self.seq_length, self.hidden_size - ) - init_method_normal(args.init_method_std)( - self.position_embeddings.weight - ) - - args.class_token_present = self.class_token - self.position_embeddings._register_load_state_dict_pre_hook( - twod_interpolate_position_embeddings_hook - ) - - self.embedding_dropout = torch.nn.Dropout(args.hidden_dropout) - - # Transformer - self.transformer = ParallelTransformer( - config, - model_type=args.model_type, - pre_process=self.pre_process, - post_process=self.post_process, - post_layer_norm=self.post_layer_norm, - drop_path_rate=self.drop_path_rate - ) - - def set_input_tensor(self, input_tensor): - """See megatron.model.transformer.set_input_tensor()""" - self.transformer.set_input_tensor(input_tensor) - - def forward(self, input): - - if self.pre_process: - rearranged_input = einops.rearrange( - input, - "b c (h p1) (w p2) -> b (h w) (p1 p2 c)", - p1=self.patch_dim, - p2=self.patch_dim, - ) - - assert rearranged_input.dtype == torch.half - encoder_output = self.linear_encoder(rearranged_input) - - concatenated_tokens = encoder_output - if self.class_token: - cls_tokens = self.cls_token.expand(encoder_output.shape[0], -1, -1) - concatenated_tokens = torch.cat((cls_tokens, encoder_output), dim=1) - - token_embeddings = concatenated_tokens + \ - self.position_embeddings(self.position_ids[:, :concatenated_tokens.shape[1]]) - # [b, s, h] => [s, b, h] - token_embeddings = token_embeddings.transpose(0, 1).contiguous() - hidden_states = self.embedding_dropout(token_embeddings) - else: - hidden_states = input - - hidden_states = self.transformer(hidden_states, None) - - if self.post_process: - # [s b h] => [b s h] - if self.single_token_output: - hidden_states = hidden_states[0] - else: - hidden_states = hidden_states.transpose(0, 1).contiguous() - - return hidden_states - diff --git a/megatron/mpu/__init__.py b/megatron/mpu/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/megatron/mpu/tests/__init__.py b/megatron/mpu/tests/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/megatron/mpu/tests/commons.py 
b/megatron/mpu/tests/commons.py deleted file mode 100644 index 611daf0f66692426ee5ad59824f3c421d7b94a90..0000000000000000000000000000000000000000 --- a/megatron/mpu/tests/commons.py +++ /dev/null @@ -1,70 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. - -import argparse -import os -import random -import numpy -import torch - -import mpu - - -class IdentityLayer(torch.nn.Module): - def __init__(self, size, scale=1.0): - super(IdentityLayer, self).__init__() - self.weight = torch.nn.Parameter(scale * torch.randn(size)) - - def forward(self): - return self.weight - - -def set_random_seed(seed): - """Set random seed for reproducability.""" - random.seed(seed) - numpy.random.seed(seed) - torch.manual_seed(seed) - mpu.model_parallel_cuda_manual_seed(seed) - - -def initialize_distributed(backend='nccl'): - """Initialize torch.distributed.""" - # Get local rank in case it is provided. - parser = argparse.ArgumentParser() - parser.add_argument('--local_rank', type=int, default=None, - help='local rank passed from distributed launcher') - args = parser.parse_args() - local_rank = args.local_rank - - # Get rank and world size. - rank = int(os.getenv('RANK', '0')) - world_size = int(os.getenv("WORLD_SIZE", '1')) - - print('> initializing torch.distributed with local rank: {}, ' - 'rank: {}, world size: {}'.format(local_rank, rank, world_size)) - - # Set the device id. - device = rank % torch.cuda.device_count() - if local_rank is not None: - device = local_rank - torch.cuda.set_device(device) - - # Call the init process. - init_method = 'tcp://' - master_ip = os.getenv('MASTER_ADDR', 'localhost') - master_port = os.getenv('MASTER_PORT', '6000') - init_method += master_ip + ':' + master_port - torch.distributed.init_process_group( - backend=backend, - world_size=world_size, - rank=rank, - init_method=init_method) - - -def print_separator(message): - torch.distributed.barrier() - filler_len = (78 - len(message)) // 2 - filler = '-' * filler_len - string = '\n' + filler + ' {} '.format(message) + filler - if torch.distributed.get_rank() == 0: - print(string, flush=True) - torch.distributed.barrier() diff --git a/megatron/mpu/tests/test_cross_entropy.py b/megatron/mpu/tests/test_cross_entropy.py deleted file mode 100644 index 00ae42228a9259e12640034a911899b6386882bc..0000000000000000000000000000000000000000 --- a/megatron/mpu/tests/test_cross_entropy.py +++ /dev/null @@ -1,95 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
- -from commons import set_random_seed -from commons import IdentityLayer -from commons import print_separator -from commons import initialize_distributed -from mpu.cross_entropy import vocab_parallel_cross_entropy -import mpu -import torch.nn.functional as F -import torch -import random -import sys -sys.path.append("../..") - - -def torch_cross_entropy(batch_size, seq_length, vocab_size, - logits_scale, seed): - set_random_seed(seed) - identity = IdentityLayer((batch_size, seq_length, vocab_size), - scale=logits_scale).cuda() - logits = identity() - target = torch.cuda.LongTensor( - size=(batch_size, seq_length)).random_(0, vocab_size) - loss = F.cross_entropy(logits.view(-1, logits.size()[-1]), - target.view(-1), - reduction='none').view_as(target).mean() - loss.backward() - return loss, identity.weight.grad - - -def mpu_cross_entropy(batch_size, seq_length, vocab_size, - logits_scale, seed): - set_random_seed(seed) - identity = IdentityLayer((batch_size, seq_length, vocab_size), - scale=logits_scale).cuda() - logits = identity() - logits_parallel = mpu.scatter_to_tensor_model_parallel_region(logits) - target = torch.cuda.LongTensor( - size=(batch_size, seq_length)).random_(0, vocab_size) - loss = vocab_parallel_cross_entropy(logits_parallel, target).mean() - loss.backward() - return loss, identity.weight.grad - - -def test_cross_entropy(tensor_model_parallel_size): - - if torch.distributed.get_rank() == 0: - print('> testing cross entropy with model parallel size {} ...'. - format(tensor_model_parallel_size)) - - mpu.initialize_model_parallel(tensor_model_parallel_size) - tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size() - - batch_size = 13 - seq_length = 17 - vocab_size_per_partition = 11 - logits_scale = 1000.0 - vocab_size = vocab_size_per_partition * tensor_model_parallel_size - seed = 1234 - - loss_torch, grad_torch = torch_cross_entropy(batch_size, seq_length, - vocab_size, logits_scale, - seed) - loss_mpu, grad_mpu = mpu_cross_entropy(batch_size, seq_length, - vocab_size, logits_scale, - seed) - - error = loss_torch.sub_(loss_mpu).abs().max() - print(' max error in loss on global rank {}: {}'.format( - torch.distributed.get_rank(), error)) - assert error < 1.0e-6 - - error = grad_torch.sub_(grad_mpu).abs().max() - print(' max error in grad on global rank {}: {}'.format( - torch.distributed.get_rank(), error)) - assert error < 1.0e-6 - - # Reset groups - mpu.destroy_tensor_model_parallel() - - torch.distributed.barrier() - if torch.distributed.get_rank() == 0: - print('>> passed the test :-)') - - -if __name__ == '__main__': - - initialize_distributed() - world_size = torch.distributed.get_world_size() - - tensor_model_parallel_size = 1 - while tensor_model_parallel_size <= world_size: - print_separator('test cross entropy') - test_cross_entropy(tensor_model_parallel_size) - tensor_model_parallel_size *= 2 diff --git a/megatron/mpu/tests/test_data.py b/megatron/mpu/tests/test_data.py deleted file mode 100644 index c30bf4bb8d4dbb0c2d576d20b18b4ae640d00d2c..0000000000000000000000000000000000000000 --- a/megatron/mpu/tests/test_data.py +++ /dev/null @@ -1,75 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
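The deleted `test_cross_entropy.py` checks that `vocab_parallel_cross_entropy`, computed from vocab shards, matches `torch.nn.functional.cross_entropy` on the full logits. A single-process sketch of why that holds: each shard only needs to contribute its exp-sum and the target logit it owns, and summing those partial results (the all-reduce in the real kernel) recovers the full loss. The loop below simulates the shards as slices; the real implementation also subtracts a global max for numerical stability, which is omitted here.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, vocab, shards = 4, 12, 3
logits = torch.randn(batch, vocab)
target = torch.randint(0, vocab, (batch,))

# Unsharded reference, as in torch_cross_entropy() above.
ref = F.cross_entropy(logits, target)

# Each simulated "rank" sees only a vocab slice and contributes two partial
# results that would be all-reduced across ranks: its exp-sum and the target
# logit it owns (zero if it does not own the target).
shard_size = vocab // shards
sum_exp = torch.zeros(batch)
target_logit = torch.zeros(batch)
for r in range(shards):
    lo, hi = r * shard_size, (r + 1) * shard_size
    part = logits[:, lo:hi]
    sum_exp += part.exp().sum(dim=-1)
    owned = (target >= lo) & (target < hi)
    local = (target.clamp(lo, hi - 1) - lo).unsqueeze(1)
    target_logit += torch.where(owned, part.gather(1, local).squeeze(1),
                                torch.zeros(batch))

loss = (sum_exp.log() - target_logit).mean()
print(torch.allclose(loss, ref, atol=1e-5))  # True
```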
- -from commons import print_separator -from commons import initialize_distributed -from mpu import data as data_utils -import mpu -import torch -import functools -import operator -import sys -sys.path.append("../..") - - -def test_broadcast_data(tensor_model_parallel_size): - - if torch.distributed.get_rank() == 0: - print('> testing broadcast_data with model parallel size {} ...'. - format(tensor_model_parallel_size)) - - mpu.initialize_model_parallel(tensor_model_parallel_size) - torch.manual_seed(1234 + mpu.get_data_parallel_rank()) - tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size() - - key_size_t = {'key1': [7, 11], - 'key2': [8, 2, 1], - 'key3': [13], - 'key4': [5, 1, 2], - 'key5': [5, 12]} - keys = list(key_size_t.keys()) - - data = {} - data_t = {} - for key in key_size_t: - data[key] = torch.LongTensor(size=key_size_t[key]).random_(0, 1000) - data_t[key] = data[key].clone() - data['keyX'] = torch.FloatTensor(size=(5, )).random_(0, 1000) - data_t['keyX'] = data['keyX'].clone() - if mpu.get_tensor_model_parallel_rank() != 0: - data = None - - data_utils._check_data_types(keys, data_t, torch.int64) - key_size, key_numel, \ - total_numel = data_utils._build_key_size_numel_dictionaries(keys, data) - for key in keys: - assert key_size[key] == key_size_t[key] - total_numel_t = 0 - for key in keys: - target_size = functools.reduce(operator.mul, key_size_t[key], 1) - assert key_numel[key] == target_size - total_numel_t += target_size - assert total_numel == total_numel_t - - data_b = data_utils.broadcast_data(keys, data, torch.int64) - for key in keys: - tensor = data_t[key].cuda() - assert data_b[key].sub(tensor).abs().max() == 0 - - # Reset groups - mpu.destroy_tensor_model_parallel() - - torch.distributed.barrier() - if torch.distributed.get_rank() == 0: - print('>> passed the test :-)') - - -if __name__ == '__main__': - - initialize_distributed() - world_size = torch.distributed.get_world_size() - - tensor_model_parallel_size = 1 - while tensor_model_parallel_size <= world_size: - print_separator('test test broadcast data') - test_broadcast_data(tensor_model_parallel_size) - tensor_model_parallel_size *= 2 diff --git a/megatron/mpu/tests/test_initialize.py b/megatron/mpu/tests/test_initialize.py deleted file mode 100644 index e5d2be37e269d8176a987b8a6ef5d7f47de98394..0000000000000000000000000000000000000000 --- a/megatron/mpu/tests/test_initialize.py +++ /dev/null @@ -1,82 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. - -from commons import print_separator -from commons import initialize_distributed -import mpu -import torch -import sys -sys.path.append("../..") - - -def test_initialize_model_parallel(tensor_model_parallel_size): - - if torch.distributed.get_rank() == 0: - print('> testing initialize_model_parallel with size {} ...'.format( - tensor_model_parallel_size)) - tensor_model_parallel_size_ = min(tensor_model_parallel_size, - torch.distributed.get_world_size()) - assert not mpu.model_parallel_is_initialized() - mpu.initialize_model_parallel(tensor_model_parallel_size_) - assert mpu.model_parallel_is_initialized() - - # Checks. - def check(group, world_size, rank): - assert world_size == torch.distributed.get_world_size(group=group) - assert rank == torch.distributed.get_rank(group=group) - - # Model parallel. 
- world_size = tensor_model_parallel_size_ - rank = torch.distributed.get_rank() % tensor_model_parallel_size_ - assert world_size == mpu.get_tensor_model_parallel_world_size() - assert rank == mpu.get_tensor_model_parallel_rank() - check(mpu.get_tensor_model_parallel_group(), world_size, rank) - - # Data parallel. - world_size = torch.distributed.get_world_size() // tensor_model_parallel_size_ - rank = torch.distributed.get_rank() // tensor_model_parallel_size - assert world_size == mpu.get_data_parallel_world_size() - assert rank == mpu.get_data_parallel_rank() - check(mpu.get_data_parallel_group(), world_size, rank) - - # Reset groups - mpu.destroy_model_parallel() - - torch.distributed.barrier() - if torch.distributed.get_rank() == 0: - print('>> passed the test :-)') - - -def test_get_tensor_model_parallel_src_rank(tensor_model_parallel_size_): - - if torch.distributed.get_rank() == 0: - print('> testing get_tensor_model_parallel_src_rank with size {} ...'.format( - tensor_model_parallel_size_)) - tensor_model_parallel_size = min(tensor_model_parallel_size_, - torch.distributed.get_world_size()) - assert not mpu.model_parallel_is_initialized() - mpu.initialize_model_parallel(tensor_model_parallel_size) - assert mpu.model_parallel_is_initialized() - - # Checks - src_rank = torch.distributed.get_rank() - mpu.get_tensor_model_parallel_rank() - assert mpu.get_tensor_model_parallel_src_rank() == src_rank - - # Reset groups - mpu.destroy_model_parallel() - - torch.distributed.barrier() - if torch.distributed.get_rank() == 0: - print('>> passed the test :-)') - - -if __name__ == '__main__': - - initialize_distributed() - world_size = torch.distributed.get_world_size() - tensor_model_parallel_size = 1 - while tensor_model_parallel_size <= world_size: - print_separator('test initialize model parallel') - test_initialize_model_parallel(tensor_model_parallel_size) - print_separator('test model parallel source rank') - test_get_tensor_model_parallel_src_rank(tensor_model_parallel_size) - tensor_model_parallel_size *= 2 diff --git a/megatron/mpu/tests/test_layers.py b/megatron/mpu/tests/test_layers.py deleted file mode 100644 index 73ad4b9459502dc2f68a8e3d0cb66157895eda1d..0000000000000000000000000000000000000000 --- a/megatron/mpu/tests/test_layers.py +++ /dev/null @@ -1,517 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. - -from mpu import layers -from commons import set_random_seed -from commons import print_separator -from commons import initialize_distributed -import mpu -from torch.nn.parameter import Parameter -import torch.nn.init as init -import torch -import random -import sys -sys.path.append("../..") - - -def test_parallel_embedding(tensor_model_parallel_size): - - if torch.distributed.get_rank() == 0: - print('> testing parallel embedding with model parallel size {} ...'. 
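The checks in the deleted `test_initialize.py` rest on a fixed mapping from a global rank to its tensor- and data-parallel coordinates: modulo and integer division by the tensor-parallel size. A framework-free illustration with an assumed world size of 8:

```python
def parallel_coords(global_rank: int, tp_size: int):
    """(tensor_parallel_rank, data_parallel_rank) for a global rank, matching
    the assertions in test_initialize_model_parallel()."""
    return global_rank % tp_size, global_rank // tp_size

world_size, tp_size = 8, 2                    # assumed sizes for illustration
dp_size = world_size // tp_size               # data-parallel world size == 4
for rank in range(world_size):
    tp_rank, dp_rank = parallel_coords(rank, tp_size)
    print(f"rank {rank}: tp_rank={tp_rank} dp_rank={dp_rank}")
# Ranks {0,1}, {2,3}, {4,5}, {6,7} form the tensor-parallel groups;
# ranks {0,2,4,6} and {1,3,5,7} form the data-parallel groups.
```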
- format(tensor_model_parallel_size)) - - mpu.initialize_model_parallel(tensor_model_parallel_size) - tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size() - - batch_size = 17 - seq_length = 23 - vocab_size = 48 - hidden_size = 16 - seed = 1236 - - set_random_seed(123) - input_data = torch.LongTensor( - size=(batch_size, seq_length)).random_(0, vocab_size).cuda() - loss_weight = torch.randn([batch_size, seq_length, hidden_size]).cuda() - - set_random_seed(seed) - embedding_original = torch.nn.Embedding(vocab_size, hidden_size).cuda() - - output = embedding_original(input_data) - loss_original = torch.mul(output, loss_weight).sum() - loss_original.backward() - - set_random_seed(seed) - embedding_parallel = layers.ParallelEmbedding( - vocab_size, hidden_size, init_method=init.normal_).cuda() - output = embedding_parallel(input_data) - loss_parallel = torch.mul(output, loss_weight).sum() - loss_parallel.backward() - - set_random_seed(seed) - embedding_vocab_parallel = layers.VocabParallelEmbedding( - vocab_size, hidden_size, init_method=init.normal_).cuda() - output = embedding_vocab_parallel(input_data) - loss_vocab_parallel = torch.mul(output, loss_weight).sum() - loss_vocab_parallel.backward() - - torch.distributed.barrier() - error = loss_parallel.sub(loss_original).abs() - print(' error in loss (parallel) on global rank {}: {}'.format( - torch.distributed.get_rank(), error)) - assert error < 1.0e-12, 'error: {}'.format(error) - - torch.distributed.barrier() - error = loss_vocab_parallel.sub(loss_original).abs() - print(' error in loss (vocab parallel) on global rank {}: {}'.format( - torch.distributed.get_rank(), error)) - assert error < 1.0e-12, 'error: {}'.format(error) - - weight_grad_orig = torch.split(embedding_original.weight.grad, - hidden_size // tensor_model_parallel_size, - 1)[mpu.get_tensor_model_parallel_rank()] - error = embedding_parallel.weight.grad.sub(weight_grad_orig).abs().max() - print(' error in grad (parallel) on global rank {}: {}'.format( - torch.distributed.get_rank(), error)) - assert error < 1.0e-12, 'error: {}'.format(error) - - weight_grad_orig = torch.split(embedding_original.weight.grad, - vocab_size // tensor_model_parallel_size, - 0)[mpu.get_tensor_model_parallel_rank()] - error = embedding_vocab_parallel.weight.grad.sub( - weight_grad_orig).abs().max() - print(' error in grad (vocab parallel) on global rank {}: {}'.format( - torch.distributed.get_rank(), error)) - assert error < 1.0e-12, 'error: {}'.format(error) - - # Reset groups - mpu.destroy_model_parallel() - - torch.distributed.barrier() - if torch.distributed.get_rank() == 0: - print('>> passed the test :-)') - - -def test_initialize_affine_weight(tensor_model_parallel_size): - - mpu.initialize_model_parallel(tensor_model_parallel_size) - if torch.distributed.get_rank() == 0: - print('> testing initialize_affine_weight with model parallel ' - 'size: {}'.format(tensor_model_parallel_size)) - tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size() - - seed = 12345 - input_size_coeff = 13 - input_size = input_size_coeff * tensor_model_parallel_size - output_size_coeff = 17 - output_size = output_size_coeff * tensor_model_parallel_size - - # --------------- - # Column parallel - # --------------- - weight = torch.empty(output_size_coeff, input_size) - set_random_seed(seed) - layers._initialize_affine_weight(weight, output_size, input_size, - - output_size_coeff, 0, - torch.nn.init.normal_) - # Target. 
- set_random_seed(seed) - master_weight = torch.empty(output_size, input_size) - torch.nn.init.normal_(master_weight) - rank = mpu.get_tensor_model_parallel_rank() - my_weight = torch.split(master_weight, output_size_coeff, - dim=0)[rank].contiguous().clone() - - # Compare. - error = weight.sub(my_weight).abs().max() - torch.distributed.barrier() - print(' column parallel max error (should be zero) on global rank ' - '{}: {}'.format(torch.distributed.get_rank(), error)) - assert error < 1.0e-6 - - # ------------ - # Row parallel - # ------------ - weight = torch.empty(output_size, input_size_coeff) - set_random_seed(seed) - mpu.layers._initialize_affine_weight(weight, output_size, input_size, - input_size_coeff, 1, - torch.nn.init.normal_) - # Target. - set_random_seed(seed) - master_weight = torch.empty(output_size, input_size) - torch.nn.init.normal_(master_weight) - rank = mpu.get_tensor_model_parallel_rank() - my_weight = torch.split(master_weight, input_size_coeff, - dim=1)[rank].contiguous().clone() - - # Compare. - error = weight.sub(my_weight).abs().max() - torch.distributed.barrier() - print(' row parallel max error (should be zero) on global rank ' - '{}: {}'.format(torch.distributed.get_rank(), error)) - assert error < 1.0e-6 - - # Reset groups - mpu.destroy_model_parallel() - - torch.distributed.barrier() - if torch.distributed.get_rank() == 0: - print(' >> passed the test :-)') - - -class IdentityLayer2D(torch.nn.Module): - def __init__(self, m, n): - super(IdentityLayer2D, self).__init__() - self.weight = Parameter(torch.Tensor(m, n)) - torch.nn.init.xavier_normal_(self.weight) - - def forward(self): - return self.weight - - -def test_column_parallel_linear(tensor_model_parallel_size): - - mpu.initialize_model_parallel(tensor_model_parallel_size) - if torch.distributed.get_rank() == 0: - print('> testing ColumnParallelLinear with model parallel ' - 'size: {}'.format(tensor_model_parallel_size)) - tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size() - - seed = 12345 - set_random_seed(seed) - input_size_coeff = 13 - input_size = input_size_coeff * tensor_model_parallel_size - output_size_coeff = 17 - output_size = output_size_coeff * tensor_model_parallel_size - batch_size = 7 - - # Network - identity_layer = IdentityLayer2D(batch_size, input_size).cuda() - linear_layer = mpu.ColumnParallelLinear( - input_size, output_size, keep_master_weight_for_test=True).cuda() - loss_weight = torch.randn([batch_size, output_size]).cuda() - # Forward - input_ = identity_layer() - output = linear_layer(input_) - loss = torch.mul(output, loss_weight).sum() - # Backward - loss.backward() - - # Values. 
- dLdY = loss_weight - X = identity_layer.weight - A = linear_layer.master_weight.cuda() - dLdA = torch.matmul(dLdY.t(), X) - dLdb = torch.matmul(torch.ones(batch_size, 1).cuda().t(), dLdY).view(-1) - dLdX = torch.matmul(dLdY, A) - - rank = mpu.get_tensor_model_parallel_rank() - my_dLdA = torch.split(dLdA, output_size_coeff, - dim=0)[rank].contiguous().clone() - error = my_dLdA.sub(linear_layer.weight.grad).abs().max() - torch.distributed.barrier() - print(' error in dLdA on global rank {}: {}'.format( - torch.distributed.get_rank(), error)) - assert error < 1.0e-6 - - my_dLdb = torch.split(dLdb, output_size_coeff, - dim=0)[rank].contiguous().clone() - error = my_dLdb.sub(linear_layer.bias.grad).abs().max() - torch.distributed.barrier() - print(' error in dLdb on global rank {}: {}'.format( - torch.distributed.get_rank(), error)) - assert error < 1.0e-6 - - error = dLdX.sub(identity_layer.weight.grad).abs().max() - torch.distributed.barrier() - print(' error in dLdX on global rank {}: {}'.format( - torch.distributed.get_rank(), error)) - assert error < 1.0e-6 - - # Reset groups - mpu.destroy_model_parallel() - - torch.distributed.barrier() - if torch.distributed.get_rank() == 0: - print(' >> passed the test :-)') - - -def test_row_parallel_linear(tensor_model_parallel_size): - - mpu.initialize_model_parallel(tensor_model_parallel_size) - if torch.distributed.get_rank() == 0: - print('> testing RowParallelLinear with model parallel ' - 'size: {}'.format(tensor_model_parallel_size)) - tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size() - - seed = 12345 - set_random_seed(seed) - input_size_coeff = 13 - input_size = input_size_coeff * tensor_model_parallel_size - output_size_coeff = 17 - output_size = output_size_coeff * tensor_model_parallel_size - batch_size = 7 - - # Network - identity_layer = IdentityLayer2D(batch_size, input_size).cuda() - linear_layer = mpu.RowParallelLinear( - input_size, output_size, keep_master_weight_for_test=True).cuda() - loss_weight = torch.randn([batch_size, output_size]).cuda() - # Forward - input_ = identity_layer() - output = linear_layer(input_) - loss = torch.mul(output, loss_weight).sum() - # Backward - loss.backward() - - # Values. 
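The hand-derived gradients used above (`dLdA = dLdYᵀ·X`, `dLdb = 1ᵀ·dLdY`, `dLdX = dLdY·A`) can be verified against autograd on a single device before worrying about how they are split across ranks. A CPU-only sketch, independent of `mpu`:

```python
import torch

torch.manual_seed(0)
batch, in_features, out_features = 7, 13, 17

X = torch.randn(batch, in_features, requires_grad=True)
linear = torch.nn.Linear(in_features, out_features)
dLdY = torch.randn(batch, out_features)      # plays the role of loss_weight

loss = (linear(X) * dLdY).sum()              # so dloss/dY == dLdY exactly
loss.backward()

dLdA = dLdY.t() @ X                          # weight grad, [out_features, in_features]
dLdb = dLdY.sum(dim=0)                       # bias grad,   [out_features]
dLdX = dLdY @ linear.weight                  # input grad,  [batch, in_features]

print(torch.allclose(dLdA, linear.weight.grad, atol=1e-5))  # True
print(torch.allclose(dLdb, linear.bias.grad, atol=1e-5))    # True
print(torch.allclose(dLdX, X.grad, atol=1e-5))              # True
```

The column-parallel test then keeps only the rows of `dLdA` and `dLdb` owned by the local rank, while the row-parallel test slices `dLdA` along the input dimension instead.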
- dLdY = loss_weight - X = identity_layer.weight - A = linear_layer.master_weight.cuda() - dLdA = torch.matmul(dLdY.t(), X) - dLdb = torch.matmul(torch.ones(batch_size, 1).cuda().t(), dLdY).view(-1) - dLdX = torch.matmul(dLdY, A) - - rank = mpu.get_tensor_model_parallel_rank() - my_dLdA = torch.split(dLdA, input_size_coeff, - dim=1)[rank].contiguous().clone() - error = my_dLdA.sub(linear_layer.weight.grad).abs().max() - torch.distributed.barrier() - print(' error in dLdA on global rank {}: {}'.format( - torch.distributed.get_rank(), error)) - assert error < 1.0e-6 - - error = dLdb.sub(linear_layer.bias.grad).abs().max() - torch.distributed.barrier() - print(' error in dLdb on global rank {}: {}'.format( - torch.distributed.get_rank(), error)) - assert error < 1.0e-6 - - error = dLdX.sub(identity_layer.weight.grad).abs().max() - torch.distributed.barrier() - print(' error in dLdX on global rank {}: {}'.format( - torch.distributed.get_rank(), error)) - assert error < 1.0e-6 - - # Reset groups - mpu.destroy_model_parallel() - - torch.distributed.barrier() - if torch.distributed.get_rank() == 0: - print(' >> passed the test :-)') - - -class IdentityLayer3D(torch.nn.Module): - def __init__(self, m, n, k): - super(IdentityLayer3D, self).__init__() - self.weight = Parameter(torch.Tensor(m, n, k)) - torch.nn.init.xavier_normal_(self.weight) - - def forward(self): - return self.weight - - -def parallel_self_attention(tensor_model_parallel_size, num_att_heads_per_partition, - hidden_size_per_att_head, dropout_prob, batch_size, - sequence_length): - mpu.initialize_model_parallel(tensor_model_parallel_size) - tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size() - - seed = 12345 - set_random_seed(seed) - - num_att_heads = num_att_heads_per_partition * \ - torch.distributed.get_world_size() - hidden_size = hidden_size_per_att_head * num_att_heads - - # Network - identity_layer = IdentityLayer3D(batch_size, sequence_length, - hidden_size).cuda() - attention_layer = mpu.BertParallelSelfAttention(hidden_size, num_att_heads, - dropout_prob).cuda() - loss_weight = torch.randn([batch_size, sequence_length, hidden_size]).cuda() - attention_mask = torch.randn([batch_size, 1, 1, sequence_length]).cuda() - # Forward - input_ = identity_layer() - output = attention_layer(input_, attention_mask) - loss = torch.mul(output, loss_weight).sum() - # Backward - loss.backward() - - rank = mpu.get_tensor_model_parallel_rank() - mpu.destroy_model_parallel() - return rank, hidden_size, tensor_model_parallel_size, loss, \ - attention_layer, identity_layer - - -def test_parallel_self_attention(tensor_model_parallel_size): - - if torch.distributed.get_rank() == 0: - print('> testing ParallelSelfAttention with model parallel ' - 'size: {}'.format(tensor_model_parallel_size)) - - num_att_heads_per_partition = 3 - hidden_size_per_att_head = 7 - dropout_prob = 0.0 # has to be zero - batch_size = 5 - sequence_length = 13 - - rank_1, hideen_size_1, tensor_model_parallel_size_1, loss_1, \ - attention_layer_1, identity_layer_1 = parallel_self_attention( - 1, num_att_heads_per_partition, - hidden_size_per_att_head, dropout_prob, batch_size, sequence_length) - - rank, hidden_size, tensor_model_parallel_size, loss, \ - attention_layer, identity_layer = parallel_self_attention( - tensor_model_parallel_size, num_att_heads_per_partition, - hidden_size_per_att_head, dropout_prob, batch_size, sequence_length) - assert hideen_size_1 == hidden_size - - error = loss_1.sub(loss).abs().max() - torch.distributed.barrier() - 
print(' loss error on global rank {}: {}'.format( - torch.distributed.get_rank(), error)) - assert error < 5.0e-6 - - my_lin_grad_list = torch.split( - attention_layer_1.query_key_value.weight.grad, - hidden_size // tensor_model_parallel_size, 0)[rank::tensor_model_parallel_size] - my_lin_grad = torch.cat(my_lin_grad_list, dim=0) - error = my_lin_grad.sub( - attention_layer.query_key_value.weight.grad).abs().max() - torch.distributed.barrier() - print(' weight gradient error on global rank {}: {}'.format( - torch.distributed.get_rank(), error)) - assert error < 5.0e-6 - - error = identity_layer_1.weight.grad.sub( - identity_layer.weight.grad).abs().max() - torch.distributed.barrier() - print(' input gradient error on global rank {}: {}'.format( - torch.distributed.get_rank(), error)) - assert error < 5.0e-6 - - torch.distributed.barrier() - if torch.distributed.get_rank() == 0: - print(' >> passed the test :-)') - - -def parallel_transformer(tensor_model_parallel_size, num_att_heads_per_partition, - hidden_size_per_att_head, batch_size, sequence_length): - - mpu.initialize_model_parallel(tensor_model_parallel_size) - tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size() - - seed = 12345 - set_random_seed(seed) - - num_att_heads = num_att_heads_per_partition * \ - torch.distributed.get_world_size() - hidden_size = hidden_size_per_att_head * num_att_heads - intermediate_size = 4 * hidden_size - - # Network - identity_layer = IdentityLayer3D(batch_size, sequence_length, - hidden_size).cuda() - transformer_layer = mpu.BertParallelTransformerLayer( - hidden_size, intermediate_size, num_att_heads, 0.0, 0.0, - torch.nn.functional.relu, 1.0e-5).cuda() - - loss_weight = torch.randn([batch_size, sequence_length, hidden_size]).cuda() - attention_mask = torch.randn([batch_size, 1, 1, sequence_length]).cuda() - # Forward - input_ = identity_layer() - output = transformer_layer(input_, attention_mask) - loss = torch.mul(output, loss_weight).sum() - # Backward - loss.backward() - - rank = mpu.get_tensor_model_parallel_rank() - mpu.destroy_model_parallel() - return rank, hidden_size, tensor_model_parallel_size, loss, \ - transformer_layer, identity_layer - - -def test_parallel_transformer_layer(tensor_model_parallel_size): - - if torch.distributed.get_rank() == 0: - print('> testing ParallelTransformerLayer with model parallel ' - 'size: {}'.format(tensor_model_parallel_size)) - - num_att_heads_per_partition = 3 - hidden_size_per_att_head = 7 - batch_size = 5 - sequence_length = 13 - - rank_1, hidden_size_1, tensor_model_parallel_size_1, loss_1, \ - transformer_layer_1, identity_layer_1 = parallel_transformer( - 1, num_att_heads_per_partition, - hidden_size_per_att_head, batch_size, sequence_length) - - rank, hidden_size, tensor_model_parallel_size, loss, \ - transformer_layer, identity_layer = parallel_transformer( - tensor_model_parallel_size, num_att_heads_per_partition, - hidden_size_per_att_head, batch_size, sequence_length) - - error = loss_1.sub(loss).abs().max() - torch.distributed.barrier() - print(' loss error on global rank {}: {}'.format( - torch.distributed.get_rank(), error)) - assert error < 5.0e-5, 'error: {}'.format(error) - - error = identity_layer_1.weight.grad.sub( - identity_layer.weight.grad).abs().max() - torch.distributed.barrier() - print(' input gradient error on global rank {}: {}'.format( - torch.distributed.get_rank(), error)) - assert error < 5.0e-5, 'error: {}'.format(error) - - torch.distributed.barrier() - if torch.distributed.get_rank() == 0: - print(' 
>> passed the test :-)') - - -if __name__ == '__main__': - - torch.backends.cudnn.deterministic = True - torch.backends.cudnn.benchmark = False - - initialize_distributed() - world_size = torch.distributed.get_world_size() - - print_separator('test initialize affine weight') - tensor_model_parallel_size = 1 - while tensor_model_parallel_size <= world_size: - test_initialize_affine_weight(tensor_model_parallel_size) - tensor_model_parallel_size *= 2 - - tensor_model_parallel_size = 1 - while tensor_model_parallel_size <= world_size: - print_separator('test parallel embedding') - test_parallel_embedding(tensor_model_parallel_size) - tensor_model_parallel_size *= 2 - - print_separator('test column-parallel linear') - tensor_model_parallel_size = 1 - while tensor_model_parallel_size <= world_size: - test_column_parallel_linear(tensor_model_parallel_size) - tensor_model_parallel_size *= 2 - - print_separator('test row-parallel linear') - tensor_model_parallel_size = 1 - while tensor_model_parallel_size <= world_size: - test_row_parallel_linear(tensor_model_parallel_size) - tensor_model_parallel_size *= 2 - - print_separator('test parallel self-attention') - tensor_model_parallel_size = 1 - while tensor_model_parallel_size <= world_size: - test_parallel_self_attention(tensor_model_parallel_size) - tensor_model_parallel_size *= 2 - - print_separator('test parallel transformer') - tensor_model_parallel_size = 1 - while tensor_model_parallel_size <= world_size: - test_parallel_transformer_layer(tensor_model_parallel_size) - tensor_model_parallel_size *= 2 diff --git a/megatron/mpu/tests/test_random.py b/megatron/mpu/tests/test_random.py deleted file mode 100644 index 8ee6942cf01fd7d9c93012c37f7b5e4b351f3c15..0000000000000000000000000000000000000000 --- a/megatron/mpu/tests/test_random.py +++ /dev/null @@ -1,191 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. - -from commons import print_separator -from commons import initialize_distributed -import mpu -import torch -import sys -sys.path.append("../..") - - -def test_set_cuda_rng_state(tensor_model_parallel_size): - - if torch.distributed.get_rank() == 0: - print('> testing set_rng_state with size {} ...'. - format(tensor_model_parallel_size)) - - mpu.initialize_model_parallel(tensor_model_parallel_size) - tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size() - - size = 123 - seed = 1234 - torch.cuda.manual_seed(1234) - tensor = torch.cuda.FloatTensor(size) - - # Get the state - rng_state = torch.cuda.get_rng_state() - rng_state_copy = rng_state.clone() - - # Do some stuff. - for _ in range(5): - torch.randn(size, out=tensor) - result_1 = tensor.clone() - - assert rng_state.sub(rng_state_copy).max() == 0 - assert torch.cuda.get_rng_state().sub(rng_state_copy).max() > 0 - - # State should be different. - new_rng_state = torch.cuda.get_rng_state() - max_diff = new_rng_state.sub(rng_state).max() - print(' max diff in rng state (should be non-zero) on global rank {}: {}'. - format(torch.distributed.get_rank(), max_diff)) - assert max_diff > 0 - - # Reset the rng state and do the same stuff. 
- mpu.random._set_cuda_rng_state(rng_state) - for _ in range(5): - torch.randn(size, out=tensor) - mpu.random._set_cuda_rng_state(rng_state) - for _ in range(5): - torch.randn(size, out=tensor) - result_2 = tensor.clone() - - # Results should be the same - error = result_2.sub(result_1).abs().max() - print(' max error in generated tensors (should be zero) on ' - 'global rank {}: {}'.format(torch.distributed.get_rank(), error)) - assert error < 1.0e-6 - - # Input state should have remained intact. - error = rng_state.sub(rng_state_copy).max() - print(' max error in rng state (should be zero) on global rank {}: {}'. - format(torch.distributed.get_rank(), error)) - assert error == 0 - - # Reset groups - mpu.destroy_model_parallel() - - torch.distributed.barrier() - if torch.distributed.get_rank() == 0: - print('>> passed the test :-)') - - -def test_cuda_rng_tracker(tensor_model_parallel_size): - - if torch.distributed.get_rank() == 0: - print('> testing cuda rng tracker with size {} ...'. - format(tensor_model_parallel_size)) - - mpu.initialize_model_parallel(tensor_model_parallel_size) - tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size() - - seed_1 = 1234 - seed_2 = 4321 - size = [12, 21] - tensor = torch.cuda.FloatTensor(size) - - # Set to seed_1 and generate two tensors. - torch.cuda.manual_seed(seed_1) - torch.randn(size, out=tensor) - target_11 = tensor.clone() - torch.randn(size, out=tensor) - target_12 = tensor.clone() - - # Set to seed_2 and generate two tensors. - torch.cuda.manual_seed(seed_2) - torch.randn(size, out=tensor) - target_21 = tensor.clone() - torch.randn(size, out=tensor) - target_22 = tensor.clone() - - # Now if we interleave seed_1 and seed_2, - # we should still get the same tensors - torch.cuda.manual_seed(seed_1) - mpu.get_cuda_rng_tracker().add('test', seed_2) - - torch.randn(size, out=tensor) - result_11 = tensor.clone() - - with mpu.get_cuda_rng_tracker().fork('test'): - torch.randn(size, out=tensor) - result_21 = tensor.clone() - - torch.randn(size, out=tensor) - result_12 = tensor.clone() - - with mpu.get_cuda_rng_tracker().fork('test'): - torch.randn(size, out=tensor) - result_22 = tensor.clone() - - diff = result_11.sub(result_21).abs().max() - diff = min(diff, result_12.sub(result_22).abs().max()) - print(' max diff in generated tensors (should be non-zero) on ' - 'global rank {}: {}'.format(torch.distributed.get_rank(), diff)) - assert diff > 1.0e-6 - error = max(result_11.sub(target_11).abs().max(), - result_12.sub(target_12).abs().max()) - error = max(error, result_21.sub(target_21).abs().max()) - error = max(error, result_22.sub(target_22).abs().max()) - print(' max error in generated tensors (should be zero) on ' - 'global rank {}: {}'.format(torch.distributed.get_rank(), error)) - assert error < 1.0e-6 - - # Reset the tracker - mpu.get_cuda_rng_tracker().reset() - - # Reset groups - mpu.destroy_model_parallel() - - torch.distributed.barrier() - if torch.distributed.get_rank() == 0: - print('>> passed the test :-)') - - -def test_model_parallel_cuda_manual_seed(tensor_model_parallel_size): - - if torch.distributed.get_rank() == 0: - print('> testing model parallel cuda manual seed with size {} ...'. 
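The deleted `test_random.py` exercises the pattern behind `get_cuda_rng_tracker().fork()`: draw from a separate RNG stream inside the context, and leave the default stream exactly where it was. A much-simplified CPU stand-in (it re-seeds on every fork instead of persisting per-name forked state like the real tracker does):

```python
import contextlib
import torch

@contextlib.contextmanager
def fork_rng(seed: int):
    """Draw from a separate RNG stream inside the context, then restore the
    default stream -- the pattern get_cuda_rng_tracker().fork() provides."""
    saved = torch.get_rng_state()
    torch.manual_seed(seed)
    try:
        yield
    finally:
        torch.set_rng_state(saved)

torch.manual_seed(1234)
a1 = torch.randn(3)
with fork_rng(4321):
    b1 = torch.randn(3)          # comes from the forked stream
a2 = torch.randn(3)              # continues the default stream, unaffected

torch.manual_seed(1234)
print(torch.equal(a1, torch.randn(3)), torch.equal(a2, torch.randn(3)))  # True True
```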
- format(tensor_model_parallel_size)) - - mpu.initialize_model_parallel(tensor_model_parallel_size) - tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size() - - mpu.model_parallel_cuda_manual_seed(12345) - assert torch.cuda.initial_seed() == 12345 - with mpu.get_cuda_rng_tracker().fork(): - assert torch.cuda.initial_seed() == (12345 + 2718 + - mpu.get_tensor_model_parallel_rank()) - - # Reset the tracker - mpu.get_cuda_rng_tracker().reset() - - # Reset groups - mpu.destroy_model_parallel() - - torch.distributed.barrier() - if torch.distributed.get_rank() == 0: - print('>> passed the test :-)') - - -if __name__ == '__main__': - - initialize_distributed() - world_size = torch.distributed.get_world_size() - - tensor_model_parallel_size = 1 - while tensor_model_parallel_size <= world_size: - print_separator('test set rng state') - test_set_cuda_rng_state(tensor_model_parallel_size) - tensor_model_parallel_size *= 2 - - tensor_model_parallel_size = 1 - while tensor_model_parallel_size <= world_size: - print_separator('test cuda rng tracker') - test_cuda_rng_tracker(tensor_model_parallel_size) - tensor_model_parallel_size *= 2 - - tensor_model_parallel_size = 1 - while tensor_model_parallel_size <= world_size: - print_separator('test model parallel cuda manual seed') - test_model_parallel_cuda_manual_seed(tensor_model_parallel_size) - tensor_model_parallel_size *= 2 diff --git a/megatron/optimizer/__init__.py b/megatron/optimizer/__init__.py deleted file mode 100644 index 33744a2f3aded055c21d18700c32e3c240d440c0..0000000000000000000000000000000000000000 --- a/megatron/optimizer/__init__.py +++ /dev/null @@ -1,141 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. - -from apex.optimizers import FusedAdam as Adam -from apex.optimizers import FusedSGD as SGD - -from megatron import get_args - -from .distrib_optimizer import DistributedOptimizer -from .grad_scaler import ConstantGradScaler, DynamicGradScaler -from .optimizer import Float16OptimizerWithFloat16Params, FP32Optimizer - -def get_param_groups(modules, - no_weight_decay_cond, - scale_lr_cond, - lr_mult): - """creates param groups based on weight decay condition (regularized vs non regularized) - and learning rate scale condition (args.lr vs lr_mult * args.lr) - scale_lr_cond is used during finetuning where head of the network requires a scaled - version of the base learning rate. 
- """ - wd_no_scale_lr = [] - wd_scale_lr = [] - no_wd_no_scale_lr = [] - no_wd_scale_lr = [] - for module in modules: - for name, param in module.named_parameters(): - if not param.requires_grad: - continue - - if no_weight_decay_cond is not None: - no_wd = no_weight_decay_cond(name, param) - else: - # do not regularize biases nor Norm parameters - no_wd = name.endswith(".bias") or len(param.shape) == 1 - - if scale_lr_cond is not None: - scale_lr = scale_lr_cond(name, param) - else: - scale_lr = False - - if not no_wd and not scale_lr: - wd_no_scale_lr.append(param) - elif not no_wd and scale_lr: - wd_scale_lr.append(param) - elif no_wd and not scale_lr: - no_wd_no_scale_lr.append(param) - else: - no_wd_scale_lr.append(param) - - param_groups = [] - if len(wd_no_scale_lr): - param_groups.append({'params': wd_no_scale_lr, 'wd_mult': 1.0, 'lr_mult': 1.0}) - if len(wd_scale_lr): - param_groups.append({'params': wd_scale_lr, 'wd_mult': 1.0, 'lr_mult': lr_mult}) - if len(no_wd_no_scale_lr): - param_groups.append({'params': no_wd_no_scale_lr, 'wd_mult': 0.0, 'lr_mult': 1.0}) - if len(no_wd_scale_lr): - param_groups.append({'params': no_wd_scale_lr, 'wd_mult': 0.0, 'lr_mult': lr_mult}) - - return param_groups - -def get_megatron_optimizer(model, - no_weight_decay_cond=None, - scale_lr_cond=None, - lr_mult=1.0): - args = get_args() - - # Base optimizer. - param_groups = get_param_groups(model, - no_weight_decay_cond, - scale_lr_cond, - lr_mult) - - if args.optimizer == 'adam': - optimizer = Adam(param_groups, - lr=args.lr, - weight_decay=args.weight_decay, - betas=(args.adam_beta1, args.adam_beta2), - eps=args.adam_eps) - elif args.optimizer == 'sgd': - optimizer = SGD(param_groups, - lr=args.lr, - weight_decay=args.weight_decay, - momentum=args.sgd_momentum) - else: - raise Exception('{} optimizer is not supported.'.format( - args.optimizer)) - - # Determine whether the params have main-grad field. - params_have_main_grad = True - - # Mixed precision optimizer. - # - Note: both the Float16Optimizer and the DistributedOptimizer inherit - # from the MixedPrecisionOptimizer, which manages any optimizer where - # the model params and main params are distinct. - if args.fp16 or args.bf16 or args.use_distributed_optimizer: - - # Grad scaler: - # if loss-scale is provided, instantiate the constant scaler. - # if we are using fp16 and loss-scale is not present, use a - # dynamic scaler. - # otherwise we are running in bf16 with no loss-scale so - # leave it as None. - grad_scaler = None - - # Constant loss scale. - if args.loss_scale: - grad_scaler = ConstantGradScaler(args.loss_scale) - - # Dynamic loss scale. - else: - if args.fp16: - grad_scaler = DynamicGradScaler( - initial_scale=args.initial_loss_scale, - min_scale=args.min_loss_scale, - growth_factor=2.0, - backoff_factor=0.5, - growth_interval=args.loss_scale_window, - hysteresis=args.hysteresis) - - # Megatron optimizer. - opt_ty = DistributedOptimizer \ - if args.use_distributed_optimizer else \ - Float16OptimizerWithFloat16Params - return opt_ty(optimizer, - args.clip_grad, - args.log_num_zeros_in_grad, - args.check_for_nan_in_loss_and_grad, - params_have_main_grad, - args.fp16, - args.bf16, - args.params_dtype, - grad_scaler, - model) - - # FP32. 
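`get_param_groups` above splits parameters by two predicates: whether they should be weight-decayed (biases and 1-D norm parameters are not, by default) and whether they get a scaled learning rate. A self-contained sketch of the weight-decay split, with `torch.optim.Adam` standing in for apex's `FusedAdam` and plain per-group `weight_decay` values instead of the `wd_mult`/`lr_mult` fields that Megatron's scheduler consumes:

```python
import torch

model = torch.nn.Sequential(
    torch.nn.Linear(8, 8),
    torch.nn.LayerNorm(8),
    torch.nn.Linear(8, 2),
)

decay, no_decay = [], []
for name, param in model.named_parameters():
    if not param.requires_grad:
        continue
    # Same default rule as get_param_groups(): no weight decay for biases
    # or for any 1-D parameter (e.g. LayerNorm weights and biases).
    if name.endswith(".bias") or param.dim() == 1:
        no_decay.append(param)
    else:
        decay.append(param)

optimizer = torch.optim.Adam(
    [
        {"params": decay, "weight_decay": 0.01},
        {"params": no_decay, "weight_decay": 0.0},
    ],
    lr=1e-3, betas=(0.9, 0.95), eps=1e-8,
)
print([len(g["params"]) for g in optimizer.param_groups])  # [2, 4]
```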
- return FP32Optimizer(optimizer, args.clip_grad, - args.log_num_zeros_in_grad, - args.check_for_nan_in_loss_and_grad, - params_have_main_grad, - model) diff --git a/megatron/optimizer/clip_grads.py b/megatron/optimizer/clip_grads.py deleted file mode 100644 index d6e38afb58a01ce546e6939bdc6a381c943942e2..0000000000000000000000000000000000000000 --- a/megatron/optimizer/clip_grads.py +++ /dev/null @@ -1,148 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. - -"""Gradient clipping.""" - -import os - -import torch -from torch import inf - -from apex.multi_tensor_apply import multi_tensor_applier -import amp_C - -from megatron.model.module import param_is_not_shared -from megatron.core.tensor_parallel import param_is_not_tensor_parallel_duplicate - - -def clip_grad_norm_fp32(parameters, grads_for_norm, - max_norm, check_for_nan_in_grad, - norm_type=2, model_parallel_group=None): - """Clips gradient norm of an iterable of parameters whose gradients - are in fp32. - - This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and - added functionality to handle model parallel parameters. Note that - the gradients are modified in place. - - Arguments: - parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a - single Tensor that will have gradients normalized - grads_for_norm (Iterable[Tensor]): an iterable of Tensors or a single - Tensor that will be used for calculating the grad norm. - max_norm (float or int): max norm of the gradients. - check_for_nan_in_grad (bool): check if gradients have a NaN. - norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for - infinity norm. - model_parallel_group (group): given the nature of the distributed - optimizer, this is passed as an argument. - - Returns: - Total norm of the parameters (viewed as a single vector). - """ - - if isinstance(parameters, torch.Tensor): - parameters = [parameters] - if isinstance(grads_for_norm, torch.Tensor): - grads_for_norm = [grads_for_norm] - - # Grads. - grads = [] - for param in parameters: - if param.grad is not None: - assert param.grad.type() == 'torch.cuda.FloatTensor' - grads.append(param.grad.detach()) - - # Norm parameters. - max_norm = float(max_norm) - norm_type = float(norm_type) - total_norm = 0.0 - - # Calculate norm. - if norm_type == inf: - total_norm = max(grad.abs().max() for grad in grads_for_norm) - total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) - # Take max across all model-parallel GPUs. - torch.distributed.all_reduce(total_norm_cuda, - op=torch.distributed.ReduceOp.MAX, - group=model_parallel_group) - total_norm = total_norm_cuda[0].item() - - else: - if norm_type == 2.0: - dummy_overflow_buf = torch.cuda.IntTensor([0]) - # Use apex's multi-tensor applier for efficiency reasons. - # Multi-tensor applier takes a function and a list of list - # and performs the operation on that list all in one kernel. - if grads_for_norm: - grad_norm, _ = multi_tensor_applier( - amp_C.multi_tensor_l2norm, - dummy_overflow_buf, - [grads_for_norm], - False # no per-parameter norm - ) - else: - grad_norm = torch.cuda.FloatTensor([0]) - # Since we will be summing across data parallel groups, - # we need the pow(norm-type). - total_norm = grad_norm ** norm_type - - else: - for grad in grads_for_norm: - grad_norm = torch.norm(grad, norm_type) - total_norm += grad_norm ** norm_type - - # Check individual rank grad norms are not NaN - # prior to model-parallel all-reduce. 
- if check_for_nan_in_grad: - global_rank = torch.distributed.get_rank() - assert not total_norm.isnan(), ( - f'Rank {global_rank}: found NaN in local grad norm in ' - f'backwards pass. Device: {torch.cuda.current_device()}, ' - f'node: {os.uname()[1]}' - ) - - # Sum across all model-parallel GPUs. - torch.distributed.all_reduce(total_norm, - op=torch.distributed.ReduceOp.SUM, - group=model_parallel_group) - total_norm = total_norm.item() ** (1.0 / norm_type) - - # Scale. - clip_coeff = max_norm / (total_norm + 1.0e-6) - if clip_coeff < 1.0: - dummy_overflow_buf = torch.cuda.IntTensor([0]) - multi_tensor_applier(amp_C.multi_tensor_scale, - dummy_overflow_buf, - [grads, grads], - clip_coeff) - - return total_norm - - -def count_zeros_fp32(parameters, model_parallel_group): - - if isinstance(parameters, torch.Tensor): - parameters = [parameters] - - # Filter parameters based on: - # - grad should not be none - # - parameter should not be shared - # - should not be a replica due to tensor model parallelism - total_num_zeros = torch.cuda.FloatTensor([0.0]) - for param in parameters: - grad_not_none = param.grad is not None - is_not_shared = param_is_not_shared(param) - is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param) - if grad_not_none and is_not_shared and is_not_tp_duplicate: - grad = param.grad.detach() - num_zeros = grad.numel() - torch.count_nonzero(grad) - total_num_zeros = num_zeros + total_num_zeros - - # Sum across all model-parallel GPUs. - torch.distributed.all_reduce(total_num_zeros, - op=torch.distributed.ReduceOp.SUM, - group=model_parallel_group) - - total_num_zeros = total_num_zeros.item() - - return total_num_zeros diff --git a/megatron/optimizer/distrib_optimizer.py b/megatron/optimizer/distrib_optimizer.py deleted file mode 100644 index d58b1b08fcdfe81d5603cc5090d6362a3f48d2ac..0000000000000000000000000000000000000000 --- a/megatron/optimizer/distrib_optimizer.py +++ /dev/null @@ -1,1161 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - -"""Megatron distributed optimizer.""" - - -from apex.optimizers import FusedAdam as Adam -import math -import torch - -from megatron import get_args -from megatron import get_timers -from megatron import print_rank_0 -from megatron.core import mpu, tensor_parallel - -from .optimizer import MixedPrecisionOptimizer, _zero_grad_group_helper -from .utils import shard_buffer - - - -class Range: - """ - A range represents a start and end points for indexing a shard - from a full tensor. - """ - def __init__(self, start, end): - self.start = start - self.end = end - self.size = end - start - def normalize(self, start = 0): - return Range(start, start + self.size) - def __str__(self): - return "%d,%d [%d]" % (self.start, self.end, self.size) - def __len__(self): - return self.end - self.start - - -class DistributedOptimizer(MixedPrecisionOptimizer): - """Distributed optimizer, for all data types (fp16, bf16, and fp32). - - Arguments: - optimizer: base optimizer such as Adam or SGD - clip_grad: clip gradeints with this global L2 norm. Note - that clipping is ignored if clip_grad == 0 - log_num_zeros_in_grad: return number of zeros in the gradients. - check_for_nan_in_grad: check if gradients have a NaN. - params_have_main_grad: flag indicating if parameters have - a `main_grad` field. If this is set, we are assuming - that the model parameters are store in the `main_grad` - field instead of the typical `grad` field. This happens - for the DDP cases where there is a continuous buffer - holding the gradients. 
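Stripped of the apex multi-tensor kernels and the model-parallel all-reduce, `clip_grad_norm_fp32` comes down to: form one global L2 norm over all gradients, then scale every gradient by `max_norm / (total_norm + 1e-6)` when that coefficient is below one. A single-process sketch of that core step:

```python
import torch

def clip_grad_norm_simple(parameters, max_norm: float, eps: float = 1.0e-6) -> float:
    """Single-process analogue of clip_grad_norm_fp32 (L2 norm only): gradients
    are scaled in place and the pre-clip total norm is returned."""
    grads = [p.grad.detach() for p in parameters if p.grad is not None]
    # Sum of squared per-tensor norms: the quantity that would be all-reduced
    # across model-parallel ranks before the square root is taken.
    total_norm = torch.sqrt(sum((g.norm(2) ** 2 for g in grads), torch.tensor(0.0)))
    clip_coeff = max_norm / (total_norm + eps)
    if clip_coeff < 1.0:
        for g in grads:
            g.mul_(clip_coeff)
    return total_norm.item()

layer = torch.nn.Linear(4, 4)
(layer(torch.randn(2, 4)).pow(2).sum() * 100.0).backward()
print(clip_grad_norm_simple(layer.parameters(), max_norm=1.0))  # norm before clipping
```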
For example for bfloat16, we want - to do gradient accumulation and all-reduces in float32 - and as a result we store those gradients in the main_grad. - Note that main grad is not necessarily in float32. - fp16: if true, the model is running in fp16. - bf16: if true, the model is running in bfloat16. - grad_scaler: used for scaling gradients. Note that this can be - None. This case happens when `bf16 = True` and we don't - use any loss scale. Note that for `bf16 = True`, we can have - a constnat gradient scaler. Also for `bf16 = False`, we - always require a grad scaler. - models: list of models (i.e., the virtual pipelining models). This - is used by the distributed optimizer for mapping parameters. - """ - - @classmethod - def build_model_gbuf_param_range_map(cls, model, dtype, gbuf_world_range, bucket_offset): - """ - Build mapping from param reference to grad buffer shard ranges. - - This method builds a mapping from parameter references to grad - buffer shard ranges, specific to each data-parallel (DP) rank's - set of 'owned' parameters. Each grad buffer (padded to be an even - multiple of DP-world-size) is conceptually divided into DP-world-size - contiguous regions, where each DP rank 'owns' a contiguous regions. - Ownership in this sense means DP rank is responsible for reducing - the relevant subset of grads, and updating the relevant subset of - params. - - This conceptual partitioning of the grad buffer does NOT respect - parameter boundaries, and as such it is assumed that each created - range references a shard (or subset) of the full parameter. It is - easiest to think of each DP rank as operating (i.e., reducing, - gathering) purely on views into the grad buffer, for all model-to- - main & main-to-model operations. - - This method creates four ranges: - - The param's range within the entire grad buffer (i.e., world index). - - The param's range within the relevant grad bucket's buffer. - - The param's range within the DP rank's local view of the grad buffer. - - The param's range within itself (i.e., its shard). - """ - - # Param range map. - param_world_index_map = model.grad_buffer_param_index_map[dtype] - param_range_map = {} - for param, param_world_indexes in param_world_index_map.items(): - - # Param range. - param_world_start, param_world_end, _ = param_world_indexes - param_local_start = max( - 0, - param_world_start - gbuf_world_range.start) - param_local_end = min( - gbuf_world_range.size, - param_world_end - gbuf_world_range.start) - - # Add param, if within local gbuf range. - if param_local_end > param_local_start: - param_local_range = Range(param_local_start, param_local_end) - param_world_range = param_local_range.normalize( - param_local_start + gbuf_world_range.start) - param_world_range_in_bucket = Range(param_world_range.start-bucket_offset, - param_world_range.end-bucket_offset) - sub_param_start = max(0, gbuf_world_range.start-param_world_start) - sub_param_range = param_local_range.normalize(sub_param_start) - param_range_map[param] = { - "gbuf_world" : param_world_range, - "gbuf_world_in_bucket": param_world_range_in_bucket, - "gbuf_local" : param_local_range, - "param" : sub_param_range, - } - - return param_range_map - - - @classmethod - def build_model_gbuf_range(cls, model, dtype, bucket_index): - """ - Build mapping between params and their grad buffers. - - This method does the initial setup for the method above. This setup - includes determining the shard ranges into the DDP's grad buffer for - each data-parallel (DP) rank. 
Each DP rank keeps range info for - all other DP ranks, for the purpose of creating args for - reduce-scatter and all-gather. - """ - - data_parallel_rank = mpu.get_data_parallel_rank(with_context_parallel=True) - data_parallel_world_size = mpu.get_data_parallel_world_size(with_context_parallel=True) - - bucket = model.grad_buffers[dtype].buckets[bucket_index] - bucket_buffer = bucket.data - gbuf_size = bucket_buffer.numel() - assert gbuf_size % data_parallel_world_size == 0, \ - f"Each bucket's buffer size should be divisible by {data_parallel_world_size}" - max_gbuf_range_size = gbuf_size // data_parallel_world_size - - # All world ranges (i.e., across all data parallel ranks). - gbuf_world_all_ranges = [] - for r in range(data_parallel_world_size): - # Compute start of chunk in this bucket. - gbuf_world_start = r * max_gbuf_range_size - gbuf_world_end = min(gbuf_size, gbuf_world_start+max_gbuf_range_size) - # Add bucket's offset in grad buffer. - gbuf_world_range = Range(gbuf_world_start + bucket.offset, - gbuf_world_end + bucket.offset) - gbuf_world_all_ranges.append(gbuf_world_range) - - # Local DP's ranges. - gbuf_world_range = gbuf_world_all_ranges[data_parallel_rank] - - # Get each param's ranges. - param_range_map = cls.build_model_gbuf_param_range_map(model, - dtype, - gbuf_world_range, - bucket.offset) - - # Group into dict. - data = { - "param_map" : param_range_map, - } - - return data - - - @classmethod - def build_model_gbuf_range_map(cls, model): - """ - Create param-to-grad-buffer mappings, for grad buffer data types - within a specific virtual model. - """ - # Iterate through all buckets to construct param ranges that this rank "owns" - # (the dp_rank'th shard of each bucket, where each shard is 1/dp_world_size - # of the bucket). - return { - dtype : [cls.build_model_gbuf_range(model, dtype, bucket_index) - for bucket_index in range(len(model.grad_buffers[dtype].buckets))] - for dtype in model.grad_buffers - } - - - @classmethod - def build_model_param_gbuf_map(cls, model_gbuf_ranges): - """ - Create a reverse of the model_gbuf_ranges, for referencing in - opposite direction. - """ - param_gbuf_map = {} - for model_index, model_gbuf_range_map in enumerate(model_gbuf_ranges): - for dtype, gbuf_range_map_for_all_buckets in model_gbuf_range_map.items(): - for bucket_index, gbuf_range_map in enumerate(gbuf_range_map_for_all_buckets): - for param, _ in gbuf_range_map["param_map"].items(): - assert param not in param_gbuf_map, \ - "Param should not be in param_gbuf_map; each param only belongs to a single bucket" - param_gbuf_map[param] = (model_index, dtype, bucket_index) - return param_gbuf_map - - - @classmethod - def build_optimizer_group_ranges(cls, param_groups, model_gbuf_ranges): - """ - Create optimizer groups. - - Given the set of parameter shard ranges that are owned by the current - data-parallel (DP) rank, gather the set of parameters that will be - used (in the method below) to create the current DP's optimizer - groups. - """ - - num_groups = len(param_groups) - - # Param group map. - # World param group map. - # - Store a mapping of for all parameters - # across all DP ranks. This is necessary because it is our first - # cross reference between the DDP mappings and the optimizer group - # parameters. This mapping only for use in the next step of building - # the local mapping over this DP rank's parameters. 
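`build_model_gbuf_param_range_map` and `build_model_gbuf_range` above carve each (padded) grad-buffer bucket into `dp_world_size` contiguous shards and intersect every parameter's world range with the shard the local rank owns. A framework-free sketch of that intersection, with made-up bucket and parameter sizes:

```python
def shard_ranges(bucket_numel: int, dp_world_size: int):
    """Split a (padded) grad-buffer bucket into dp_world_size contiguous
    [start, end) shards, one per data-parallel rank."""
    assert bucket_numel % dp_world_size == 0
    shard = bucket_numel // dp_world_size
    return [(r * shard, (r + 1) * shard) for r in range(dp_world_size)]

def owned_slice(param_start, param_end, shard_start, shard_end):
    """Portion of a parameter's [param_start, param_end) world range that falls
    inside one rank's shard, or None if they do not overlap."""
    lo, hi = max(param_start, shard_start), min(param_end, shard_end)
    return (lo, hi) if hi > lo else None

# Bucket of 16 elements, 4 data-parallel ranks, two flattened params:
# param A occupies [0, 6), param B occupies [6, 16).
params = {"A": (0, 6), "B": (6, 16)}
for rank, (s, e) in enumerate(shard_ranges(16, 4)):
    owned = {n: owned_slice(ps, pe, s, e) for n, (ps, pe) in params.items()}
    print(f"rank {rank} owns", {n: r for n, r in owned.items() if r})
# rank 0 owns {'A': (0, 4)}
# rank 1 owns {'A': (4, 6), 'B': (6, 8)}
# rank 2 owns {'B': (8, 12)}
# rank 3 owns {'B': (12, 16)}
```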
- world_param_group_map = {} - for group_index, group in enumerate(param_groups): - for param in group["params"]: - assert param.requires_grad - world_param_group_map[param] = group_index - - # Optimizer group ranges & param-group mapping. - # - Build a mapping from groups to their contained parameters, and also - # from parameters to their containing group index and order within - # the group. The group index and order are particularly important for - # saving and loading checkpoints. - local_param_group_map = {} - group_ranges = [ {"params": []} for _ in param_groups ] - for model_gbuf_range_map in model_gbuf_ranges: - for dtype, gbuf_range_map_for_all_buckets in model_gbuf_range_map.items(): - for gbuf_range_map in gbuf_range_map_for_all_buckets: - for param in gbuf_range_map["param_map"]: - group_index = world_param_group_map[param] - group_range = group_ranges[group_index] - group_range["params"].append(param) - local_param_group_map[param] = \ - (group_index, len(group_range["params"]) - 1) - - # Squeeze zero-size group ranges. - for group_index, group_range in enumerate(group_ranges): - group_range["orig_group"] = param_groups[group_index] - group_range["orig_group_idx"] = param_groups[group_index] - - return local_param_group_map, group_ranges - - - @classmethod - def build_model_and_main_param_groups(cls, - model_gbuf_ranges, - param_gbuf_map, - opt_group_ranges): - """ - Create main parameter groups needed for the optimizer step. - - These groups encompass both: 1) groups used by this class, for - reducing/gather, and 2) groups used by the inner optimizer for the - parameter update. Given that the conceptual grad buffer partitioning - (created in earlier method) doesn't respect parameter boundaries, - the optimizer operates on shards of the model parameters, rather than - the full parameters. - """ - - # Parameter groups: - # model_float16_groups: original float16 parameters - # model_fp32_groups: original fp32 parameters - # shard_float16_groups: shards of original float16 parameters - # shard_fp32_groups: shards of original fp32 parameters - # shard_fp32_from_float16_groups: fp32 copy of float16 parameters - model_float16_groups = [] - model_fp32_groups = [] - shard_float16_groups = [] - shard_fp32_groups = [] - shard_fp32_from_float16_groups = [] - - # Allocate (or slice) each group's param shard. - for group_index, group_range in enumerate(opt_group_ranges): - - # Params of this group. - model_float16_params_this_group = [] - model_fp32_params_this_group = [] - shard_float16_params_this_group = [] - shard_fp32_params_this_group = [] - shard_fp32_from_float16_params_this_group = [] - model_float16_groups.append(model_float16_params_this_group) - model_fp32_groups.append(model_fp32_params_this_group) - shard_float16_groups.append(shard_float16_params_this_group) - shard_fp32_groups.append(shard_fp32_params_this_group) - shard_fp32_from_float16_groups.append( - shard_fp32_from_float16_params_this_group) - - for model_param in group_range["params"]: - - assert model_param.requires_grad - - model_index, dtype, bucket_index = param_gbuf_map[model_param] - gbuf_range = model_gbuf_ranges[model_index][dtype][bucket_index] - param_range = gbuf_range["param_map"][model_param]["param"] - - # fp16, bf16 params. - if model_param.type() in ['torch.cuda.HalfTensor', - 'torch.cuda.BFloat16Tensor']: - - # Clone model -> main. 
- shard_model_param = model_param.detach().view(-1) \ - [param_range.start:param_range.end] - shard_main_param = shard_model_param.clone().float() - tensor_parallel.copy_tensor_model_parallel_attributes( - shard_model_param, model_param) - tensor_parallel.copy_tensor_model_parallel_attributes( - shard_main_param, model_param) - if hasattr(model_param, 'shared'): - shard_model_param.shared = model_param.shared - shard_main_param.shared = model_param.shared - - # Add to group. - model_float16_params_this_group.append(model_param) - shard_float16_params_this_group.append(shard_model_param) - shard_fp32_from_float16_params_this_group.append(shard_main_param) - - # fp32 params. - elif model_param.type() == 'torch.cuda.FloatTensor': - shard_model_param = model_param.view(-1) \ - [param_range.start:param_range.end] - model_fp32_params_this_group.append(model_param) - shard_fp32_params_this_group.append(shard_model_param) - tensor_parallel.copy_tensor_model_parallel_attributes( - shard_model_param, model_param) - if hasattr(model_param, 'shared'): - shard_model_param.shared = model_param.shared - - else: - raise TypeError('Wrapped parameters must be one of ' - 'torch.cuda.FloatTensor, ' - 'torch.cuda.HalfTensor, or ' - 'torch.cuda.BFloat16Tensor. ' - 'Received {}'.format(model_param.type())) - - # Update optimizer's params. - group_range["orig_group"]["params"] = [ - *shard_fp32_params_this_group, - *shard_fp32_from_float16_params_this_group, - ] - - return ( - model_float16_groups, - model_fp32_groups, - shard_float16_groups, - shard_fp32_groups, - shard_fp32_from_float16_groups, - ) - - - def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad, - check_for_nan_in_grad, params_have_main_grad, fp16, - bf16, params_dtype, grad_scaler, models): - """ - See top of class definition for argument descriptions. - - The steps in this method create the core mapping between DDP grad - buffers, parameters, and parameter shard ranges, that is needed for - converting between model param indexes and main parameter shard - indexes. This method also updates the optimizer parameter groups - with the newly created shards. - """ - - super().__init__( - optimizer, clip_grad, log_num_zeros_in_grad, - check_for_nan_in_grad, params_have_main_grad, - fp16, bf16, params_dtype, grad_scaler, models) - - assert isinstance(optimizer, Adam), \ - "Only Adam currently supported, due to checkpointing requirements." - - # Model grad buffer ranges. - self.model_gbuf_ranges = [] - self.per_bucket_numel = [] - for _, model_chunk in enumerate(self.models): - self.per_bucket_numel.append( - {dtype: [bucket.data.numel() for bucket in model_chunk.grad_buffers[dtype].buckets] - for dtype in model_chunk.grad_buffers}) - self.model_gbuf_ranges.append(self.build_model_gbuf_range_map(model_chunk)) - self.model_param_gbuf_map = \ - self.build_model_param_gbuf_map(self.model_gbuf_ranges) - - # Optimizer ranges. - self.model_param_group_index_map, self.opt_group_ranges = \ - self.build_optimizer_group_ranges(self.optimizer.param_groups, - self.model_gbuf_ranges) - - # Allocate main param shards. - ( - self.model_float16_groups, - self.model_fp32_groups, - self.shard_float16_groups, - self.shard_fp32_groups, - self.shard_fp32_from_float16_groups, - ) = self.build_model_and_main_param_groups(self.model_gbuf_ranges, - self.model_param_gbuf_map, - self.opt_group_ranges) - - # Initialize param buffers. - # - These are views on the DDP model's grad buffers, that share - # storage & have their own dtype. 
This is safe because the param - # dtype size is always <= grad dtype size. - self.param_buffers = [] - for model_index, model in enumerate(self.models): - current_param_buffers = {} - for dtype, grad_buffer in model.grad_buffers.items(): - size_ratio = torch.finfo(dtype).bits // torch.finfo(params_dtype).bits - current_param_buffers[dtype] = [] - for bucket in grad_buffer.buckets: - - # Handle older/newer method for getting untyped storage. - try: - storage = bucket.data.storage()._untyped() - except: - storage = bucket.data.storage().untyped() - - # Typed param buffer. - param_buffer = torch.tensor( - storage, - dtype = params_dtype, - device = bucket.data.device) - - # .storage() ignores views / slices, so param_buffer now points to the start - # of the grad_buffer instead of to the start of each bucket. As a result, - # add bucket.offset to make sure param_buffers point to the right region of - # memory. - # Since we want the start of each bucket's param_buffer to coincide with the - # start of the same bucket's grad_buffer (this ensures that zeroing the grad - # buffer does not zero out params in the param_buffer before they are copied - # into the model_params), multiply the offset by the size ratio of grads and - # params. - offset = bucket.offset * size_ratio - param_buffer = param_buffer[offset:offset+bucket.data.numel()] - assert param_buffer.data_ptr() == bucket.data.data_ptr(), \ - "param_buffer and grad_buffer for same bucket should start at the same byte address" - assert param_buffer.numel() == bucket.data.numel(), \ - "param_buffer and grad_buffer for same bucket should have the same number of elements" - current_param_buffers[dtype].append(param_buffer) - self.param_buffers.append(current_param_buffers) - - # Now construct data structures to manage all-gather handles. - self.all_gather_handles = [] - self.all_gather_handle_index_to_bucket_index_map = [] - self.model_index_to_all_gather_handle_index_map = {} - self.param_to_all_gather_handle_index_map = {} - self.param_buffer_copied = [] - - self.pbuf_view_items = self.get_model_param_buffer_dp_views() - for (model_index, dtype, bucket_index, _, _) in self.pbuf_view_items: - self.all_gather_handle_index_to_bucket_index_map.append((model_index, dtype, bucket_index)) - all_gather_handle_index = len(self.all_gather_handle_index_to_bucket_index_map) - 1 - - # Store all all_gather_handle_indices relevant to a particular model chunk. - if model_index not in self.model_index_to_all_gather_handle_index_map: - self.model_index_to_all_gather_handle_index_map[model_index] = [] - self.model_index_to_all_gather_handle_index_map[model_index].append(all_gather_handle_index) - - for param in self.models[model_index].grad_buffers[dtype].buckets[bucket_index].params_list: - self.param_to_all_gather_handle_index_map[param] = all_gather_handle_index - self.param_buffer_copied.append(False) - self.num_all_gather_handles = len(self.all_gather_handle_index_to_bucket_index_map) - - self.overlap_param_gather = get_args().overlap_param_gather - if self.overlap_param_gather: - self.remove_pre_hook_handle = torch.nn.modules.module.register_module_forward_pre_hook( - self._make_forward_pre_hook()) - else: - self.remove_pre_hook_handle = None - - self.update_successful = False - - # Update optimizer groups. - # - Also, leverage state_dict() and load_state_dict() to - # recast preexisting per-param state tensors. 
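The param buffers set up above are not fresh allocations; they reinterpret the bytes of each grad-buffer bucket in the parameter dtype. Below is a minimal CPU sketch of that aliasing idea, using `Tensor.view(dtype)` in place of the untyped-storage path taken in the removed code; the buffer size and dtypes are illustrative only.

```python
import torch

# One bucket's grad data, kept in float32 for accumulation / reduce-scatter.
grad_buffer = torch.zeros(16, dtype=torch.float32)

# Params live in a smaller dtype, so each fp32 slot can hold `size_ratio`
# param elements (2 for bfloat16).
size_ratio = torch.finfo(torch.float32).bits // torch.finfo(torch.bfloat16).bits
assert grad_buffer.view(torch.bfloat16).numel() == grad_buffer.numel() * size_ratio

# Reinterpret the same bytes as bfloat16 and keep only as many elements as the
# bucket has grads, starting at the bucket's start address, so zeroing the grad
# buffer cannot clobber params that have not yet been copied out.
param_buffer = grad_buffer.view(torch.bfloat16)[:grad_buffer.numel()]

assert param_buffer.data_ptr() == grad_buffer.data_ptr()  # same start address
assert param_buffer.numel() == grad_buffer.numel()        # same element count
```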
- self.optimizer.param_groups = \ - [ g["orig_group"] for g in self.opt_group_ranges ] - self.optimizer.load_state_dict(self.optimizer.state_dict()) - - - def get_model_param_range_map(self, param): - """ - Given a model param, get the index sub-range of the param that this - data-parallel rank owns. - """ - model_index, dtype, bucket_index = self.model_param_gbuf_map[param] - gbuf_range_map = self.model_gbuf_ranges[model_index][dtype][bucket_index] - param_range_map = gbuf_range_map["param_map"][param] - return param_range_map - - - def get_model_parallel_group(self): - """ - With the distributed optimizer, the model parallel group is the - entire world. - """ - return None - - - def state_dict(self): - """ - The state dict contains all non-DP-rank-dependent (i.e., non-parameter- - related) optimizer variables. The returned state dict can be stored in - the standard model/RNG checkpoint file. The parameter and dependent - optimizer state (e.g., exp_avg, exp_avg_sq) are stored in a separate - checkpoint file by calling 'save_parameter_state()'. - """ - - state_dict = {} - - # Optimizer state (do not store parameter state here). - state_dict['optimizer'] = { - k : v - for k, v in self.optimizer.state_dict().items() - if k != "state" - } - for param_group in state_dict["optimizer"]["param_groups"]: - del param_group["params"] - - # Grad scaler state. - if self.grad_scaler: - state_dict['grad_scaler'] = self.grad_scaler.state_dict() - - return state_dict - - - def load_state_dict(self, state_dict): - """Load the state dict. - - As detailed in state_dict(), the state dict contains all non- - parameter-related variables. This method is notably longer than - state_dict(), because the Torch optimizers state has yet to be - allocated at this point, and so we must do a cross referencing between - the optimizers state (and the ordering it expects for parameter state) - and this DP rank's shards. The optimizer at this point does not contain - any tensor dimension information, so we must get these dimensions from - the DP shards mapped during DistributedOptimizer.__init__(). - - The tensor parameter state is loaded via load_parameter_state(), and - so this method also must populate the loaded state dict with dummy - tensor data (i.e., via torch.empty() below). This will be overwritten - during load_parameter_state(). - - ** Note: Torch optimizer's state structure. ** - The Torch optimizer stores its state in two levels. The top level is a - list of groups, where each group contains a list of integer indexes - (corresponding to parameters) that index into a master parameter list - that is shared by all groups. As such, three values are necessary for - maintaining this ordering: - - - group_index : The group to which a parameter belongs. - - group_order : The index of a parameter within its group. - - state_order : The index of a parameter within the shared parameter - list. - """ - - # Get the Torch optimizer's state dict. - # - This 'inner' optimizer at this point is unallocated, and only - # contains an integer odering of parameters within each group, and - # the ordering of parameters within its flattened parameter state - # list. - inner_state_dict = self.optimizer.state_dict() - state_dict_param_groups = [{ - **group, - "params" : list(inner_state_dict["param_groups"][idx]["params"]), - } for idx, group in enumerate(state_dict["optimizer"]["param_groups"])] - - # Allocate 'dummy' data for optimizer state (i.e., torch.empty() below) - # - Real data is overwritten during load_parameter_state(). 
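As a concrete illustration of the group_index / group_order / state_order bookkeeping described in the docstring above, this is how a stock `torch.optim.Adam` state dict indexes parameters. The tiny parameters below are stand-ins for illustration, not ModelLink code.

```python
import torch

w1 = torch.nn.Parameter(torch.randn(4))
w2 = torch.nn.Parameter(torch.randn(4))
w3 = torch.nn.Parameter(torch.randn(4))
opt = torch.optim.Adam([
    {"params": [w1, w2], "lr": 1e-3},   # group 0
    {"params": [w3], "lr": 1e-4},       # group 1
])

sd = opt.state_dict()
# Each group lists integer indices into one flat parameter ordering shared by
# all groups.
print(sd["param_groups"][0]["params"])  # [0, 1]
print(sd["param_groups"][1]["params"])  # [2]

# For w3: group_index == 1, group_order == 0, and its state_order is the key
# under sd["state"] once the optimizer has stepped and allocated
# exp_avg / exp_avg_sq for it.
group_index, group_order = 1, 0
state_order = sd["param_groups"][group_index]["params"][group_order]
assert state_order == 2
```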
- state_dict_state = [] - for gbuf_range_maps in self.model_gbuf_ranges: - for gbuf_range_map_for_all_buckets in gbuf_range_maps.values(): - for gbuf_range_map in gbuf_range_map_for_all_buckets: - for model_param, param_range_map in \ - gbuf_range_map["param_map"].items(): - - # Get parameter ordering information (see method docstring - # for details). - group_index, group_order = \ - self.model_param_group_index_map[model_param] - state_order = inner_state_dict["param_groups"] \ - [group_index]["params"][group_order] - - # Allocate dummy tensors. - numel = len(param_range_map["gbuf_world"]) - init_shard = lambda : torch.empty( - (numel,), - dtype=torch.float32, - device=torch.cuda.current_device()) - - state_dict_state.append((state_order, { - "exp_avg" : init_shard(), - "exp_avg_sq" : init_shard(), - })) - - # Sort by state order (see method docstring for details). - state_dict_state.sort(key = lambda s : s[0]) - state_dict_state = {s[0]:s[1] for s in state_dict_state} - - # Optimizer. - self.optimizer.load_state_dict({ - "state" : state_dict_state, - "param_groups" : state_dict_param_groups, - }) - - # Grad scaler. - if 'grad_scaler' not in state_dict: - if self.fp16: - print_rank_0('***WARNING*** found an old checkpoint, will not ' - 'load grad scaler ...') - else: - if self.grad_scaler: - self.grad_scaler.load_state_dict(state_dict['grad_scaler']) - else: - print_rank_0('***WARNING*** fould the grad scaler in the ' - 'checkpoint but it is None in the class. ' - 'Skipping loading grad scaler ...') - - - def save_parameter_state(self, filename): - """Save parameter state (i.e., parameter & optimizer tensors). - - This method performs three steps: - - For each DP rank, copy param & optimizer shards to contiguous CPU - buffers. (e.g., one buffer each for main_param, exp_avg, and - exp_avg_sq). - - Gather contiguous buffers on DP rank 0 and concatenate to world - buffers. - - Save world buffers to disk (i.e., distrib_opt.pt). - """ - - # Data parallelism variables. - data_parallel_world_size = mpu.get_data_parallel_world_size(with_context_parallel=True) - data_parallel_rank = mpu.get_data_parallel_rank(with_context_parallel=True) - data_parallel_group_gloo = mpu.get_data_parallel_group_gloo(with_context_parallel=True) - data_parallel_global_ranks = list(mpu._DATA_PARALLEL_GLOBAL_RANKS_WITH_CP) - - # Collect param states. - state = {"per_bucket_numel": self.per_bucket_numel} - for model_idx, gbuf_range_maps in enumerate(self.model_gbuf_ranges): - - # Iterate grad buffers (by data type). - dtype_state = {} - assert len(gbuf_range_maps) == 1, "single dtype supported, for now." - for dtype, gbuf_range_map_for_all_buckets in gbuf_range_maps.items(): - world_tensors = {} - for bucket_idx, gbuf_range_map in enumerate(gbuf_range_map_for_all_buckets): - - # Compute local DP contiguous shard's size. - model = self.models[model_idx] - gbuf_world_numel = model.grad_buffers[dtype].buckets[bucket_idx].data.numel() - assert gbuf_world_numel % data_parallel_world_size == 0 - gbuf_local_numel = gbuf_world_numel // data_parallel_world_size - local_shards = {key: torch.empty((gbuf_local_numel,), - dtype=torch.float32, - device="cpu") - for key in ("param", "exp_avg", "exp_avg_sq")} - - # Build contiguous DP rank shards (for param + optim states). - for model_param, param_range_map in \ - gbuf_range_map["param_map"].items(): - - # Main param & optimizer states. 
- group_index, group_order = \ - self.model_param_group_index_map[model_param] - main_param = self.optimizer.param_groups \ - [group_index]["params"][group_order] - optim_state = self.optimizer.state[main_param] - - tensors = { - "param" : main_param, - **optim_state, - } - - # Copy states into contiguous shard. - gbuf_local_start = param_range_map["gbuf_local"].start - gbuf_local_end = param_range_map["gbuf_local"].end - for key in local_shards: - local_shards[key][gbuf_local_start:gbuf_local_end] \ - .data.copy_(tensors[key].detach().cpu()) - - # Gather contiguous shards on DP rank 0. - for key, send_tensor in local_shards.items(): - - # Gather tensor list. - if data_parallel_rank == 0: - recv_tensors = [torch.empty((gbuf_local_numel,), - dtype=torch.float32, - device="cpu") - for _ in range(data_parallel_world_size)] - else: - recv_tensors = None - - # Gather. - torch.distributed.gather( - send_tensor, - recv_tensors, - data_parallel_global_ranks[0], - data_parallel_group_gloo, - ) - - # Concatenate. - if data_parallel_rank == 0: - if key not in world_tensors: - world_tensors[key] = [] - world_tensors[key].append(torch.cat(recv_tensors)) - - # Collect world state. - dtype_state[dtype] = world_tensors - state[model_idx] = dtype_state - - # Save param state. - if data_parallel_rank == 0: - torch.save(state, filename) - - - def load_parameter_state(self, filename): - """Load parameter state (i.e., parameter & optimizer tensors). - - This method performs the reverse of save_parameter_state(): - - Load world buffers from disk (i.e., distrib_opt.pt). - - Scatter contiguous buffers from DP rank 0 to each DP rank (each DP - rank receives its relevant subset of the world buffers). - - For each DP rank, copy param & optimizer shards from contiguous CPU - buffers. (e.g., one buffer each for main_param, exp_avg, and - exp_avg_sq). - """ - - # Data parallelism variables. - data_parallel_world_size = mpu.get_data_parallel_world_size(with_context_parallel=True) - data_parallel_rank = mpu.get_data_parallel_rank(with_context_parallel=True) - data_parallel_group_gloo = mpu.get_data_parallel_group_gloo(with_context_parallel=True) - data_parallel_global_ranks = list(mpu._DATA_PARALLEL_GLOBAL_RANKS_WITH_CP) - - # Load on DP rank 0. - if data_parallel_rank == 0: - loaded_state = torch.load(filename) - if "per_bucket_numel" in loaded_state: - per_bucket_numel_in_checkpoint = loaded_state["per_bucket_numel"] - assert self.per_bucket_numel == per_bucket_numel_in_checkpoint, \ - (f"Number of elements in each bucket need to be the same in current run " - f"({self.per_bucket_numel}) and checkpoint ({per_bucket_numel_in_checkpoint})") - - # Scatter tensors to all DP ranks. - for model_idx, gbuf_range_maps in enumerate(self.model_gbuf_ranges): - for dtype, gbuf_range_map_for_all_buckets in gbuf_range_maps.items(): - for bucket_idx, gbuf_range_map in enumerate(gbuf_range_map_for_all_buckets): - - # Compute local DP contiguous shard's size. - model = self.models[model_idx] - gbuf_world_numel = model.grad_buffers[dtype].buckets[bucket_idx].data.numel() - assert gbuf_world_numel % data_parallel_world_size == 0 - gbuf_local_numel = gbuf_world_numel // data_parallel_world_size - - # Contiguous local shards (received from DP rank 0). - local_shards = {key: torch.empty((gbuf_local_numel,), - dtype=torch.float32, - device="cpu") - for key in ("param", "exp_avg", "exp_avg_sq")} - - # Scatter local shards from DP rank 0. - for key, recv_tensor in local_shards.items(): - - # Scatter tensor list. 
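Stripped of `torch.distributed`, the save/load round trip above reduces to cutting each bucket into equal contiguous per-rank slices and concatenating them back on data-parallel rank 0. A CPU-only sketch with illustrative sizes (the names mirror the removed code but nothing here is ModelLink API):

```python
import torch

data_parallel_world_size = 4
gbuf_world_numel = 16
assert gbuf_world_numel % data_parallel_world_size == 0
gbuf_local_numel = gbuf_world_numel // data_parallel_world_size

# "Local shards" as each rank would build them before the gather.
local_shards = [torch.full((gbuf_local_numel,), float(rank))
                for rank in range(data_parallel_world_size)]

# Rank-0 side of save_parameter_state(): concatenate gathered shards into one
# world tensor per state key.
world_tensor = torch.cat(local_shards)

# Rank-0 side of load_parameter_state(): cut the world tensor back into
# per-rank slices at multiples of gbuf_local_numel.
send_tensors = [world_tensor[i:i + gbuf_local_numel]
                for i in range(0, gbuf_world_numel, gbuf_local_numel)]
for rank, shard in enumerate(send_tensors):
    assert torch.equal(shard, local_shards[rank])
```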
- if data_parallel_rank == 0: - world_tensor_for_all_buckets = loaded_state[model_idx][dtype][key] - if not isinstance(world_tensor_for_all_buckets, list): - world_tensor_for_all_buckets = [world_tensor_for_all_buckets] - assert bucket_idx < len(world_tensor_for_all_buckets), \ - (f"Trying to load state for bucket_id {bucket_idx} (out of " - f"{len(gbuf_range_map_for_all_buckets)} buckets) from checkpoint; " - f"checkpoint only has {len(world_tensor_for_all_buckets)} bucket(s)") - world_tensor = world_tensor_for_all_buckets[bucket_idx] - gbuf_start_idxs = \ - list(range(0, gbuf_world_numel, gbuf_local_numel)) - send_tensors = [world_tensor[i:(i+gbuf_local_numel)] - for i in gbuf_start_idxs] - else: - send_tensors = None - - # Scatter. - torch.distributed.scatter( - recv_tensor, - send_tensors, - data_parallel_global_ranks[0], - data_parallel_group_gloo, - ) - - # Copy local contiguous shards to param/optim shards. - for model_param, param_range_map in \ - gbuf_range_map["param_map"].items(): - - # Main param & optimizer states. - group_index, group_order = \ - self.model_param_group_index_map[model_param] - main_param = self.optimizer.param_groups \ - [group_index]["params"][group_order] - optim_state = self.optimizer.state[main_param] - - tensors = { - "param" : main_param, - **optim_state, - } - - # Copy states into contiguous shard. - gbuf_local_start = param_range_map["gbuf_local"].start - gbuf_local_end = param_range_map["gbuf_local"].end - for key in local_shards: - tensors[key].data.copy_( - local_shards[key][gbuf_local_start:gbuf_local_end]) - - - def zero_grad(self, set_to_none=True): - """ - Zero grads. - - We only need to zero the model related parameters, i.e., - model_float16_groups & model_fp32_groups. We additionally zero - the remaining groups as a memory optimization to reduce - fragmentation; in the case of set_to_none==True, the space - used by this field can be safely deallocated at this point. - """ - for groups in ( - self.model_float16_groups, - self.model_fp32_groups, - self.shard_float16_groups, # grad empty/unused here? - self.shard_fp32_groups, # throws grad-access warning - self.shard_fp32_from_float16_groups): - for group in groups: - _zero_grad_group_helper(group, set_to_none) - - # If overlapping param all-gather with forward compute, launch all-gather - # for first accessed bucket here before forward compute is initiated. - # The all-gather for the next bucket will be launched in the forward - # pre-hook when this all-gather finishes (to ensure that the communication - # kernels don't head-of-line block the compute kernels since we run with - # CUDA_DEVICE_MAX_CONNECTIONS=1 to support sequence parallelism). - if self.overlap_param_gather: - self._dispatch_gather_model_params(all_gather_handle_index=0) - - - def get_model_param_buffer_dp_views(self): - """ - Get shard views of each of the param buffers. - - In this nested list, the top level is grouped by the virtual model - index and the buffer's data type. The sub-level is a list of - shards of that buffer, where each shard in the list represents - a contiguous view of the buffer, that is owned by a data-parallel - rank. The shard boundary does not respect parameter boundaries, and - so the elements of some parameters are split across data parallel - ranks. - - Additionally, return references to the entire buffers, for use - in _all_gather_base. - """ - - # Buffer views. 
- # Add in reverse order in each model chunk since buckets start from the end of the model but we want - # all-gathers to run first for the start of the model (same order as forward pass). - # We keep the view_items in model chunk order since we want to still first run all_gather and - # all_gather_handle.wait() for the first model chunk. - # In all cases, we want all_gather and all_gather_handle.wait() to be called in the same order, - # and all_gather_handle.wait() needs to be called just before the corresponding forward pass. - view_items = [] - for model_index, buffers in enumerate(self.param_buffers): - view_items_per_model_chunk = [] - for dtype, buf_for_all_buckets in buffers.items(): - for bucket_index, buf in enumerate(buf_for_all_buckets): - buf_views = shard_buffer(buf) - view_items_per_model_chunk.insert(0, (model_index, dtype, bucket_index, buf, buf_views)) - view_items.extend(view_items_per_model_chunk) - - return view_items - - - def _dispatch_gather_model_params(self, all_gather_handle_index): - """ - All-gather updated model params. - - The DDP's param buffer is used for the all-gather, and thus no - tensors are dynamically allocated. After the all-gather, the params - can be copied from the param buffer to the param. - """ - if self.update_successful: - data_parallel_rank = mpu.get_data_parallel_rank(with_context_parallel=True) - data_parallel_group = mpu.get_data_parallel_group(with_context_parallel=True) - - # All-gather updated main params. - # All param_buf views are guaranteed to have the same number of elements - # across all data-parallel ranks, due to padding (done in grad_buffer.py), - # and extended to the param_bufs. Thus, all sub-views will have consistent - # start / end indexes across data-parallel ranks. - (model_index, dtype, bucket_index, pbuf, pbuf_views) = self.pbuf_view_items[all_gather_handle_index] - assert all_gather_handle_index == len(self.all_gather_handles) - all_gather_handle = torch.distributed._all_gather_base( - pbuf, - pbuf_views[data_parallel_rank], - group = data_parallel_group, - async_op = self.overlap_param_gather - ) - self.all_gather_handles.append(all_gather_handle) - assert self.all_gather_handle_index_to_bucket_index_map[all_gather_handle_index] == \ - (model_index, dtype, bucket_index) - self.param_buffer_copied.append(False) - - if not self.overlap_param_gather: - self._copy_params_from_param_buffer(all_gather_handle_index) - - - - def _make_forward_pre_hook(self): - """ - Create a forward pre-hook to wait on all-gather handles when necessary (i.e., - when a module uses a parameter in a bucket with a still incomplete all-gather) - and then copy the results from the param_buffer into model_params. - """ - - def hook(module, *unused): - assert self.overlap_param_gather, "Should use pre-hook only when overlap_param_gather is True" - - # Make sure all parameters in this module have been all-gathered as necessary. - for param in module.parameters(recurse=False): - # Skip parameters that don't require grad. - if not param.requires_grad: - continue - - assert param in self.param_to_all_gather_handle_index_map - all_gather_handle_index = self.param_to_all_gather_handle_index_map[param] - self._finish_param_sync_helper(all_gather_handle_index) - - return hook - - - def finish_param_sync(self, model_index, *unused): - """ - Finishes all necessary param syncs for the model_index'th model chunk. 
- """ - all_gather_handle_indices = self.model_index_to_all_gather_handle_index_map[model_index] - for all_gather_handle_index in all_gather_handle_indices: - self._finish_param_sync_helper(all_gather_handle_index) - - - def _finish_param_sync_helper(self, all_gather_handle_index): - """ - Waits on all_gather_handle if necessary, then copies params from param_buffer - into model_params if necessary. - """ - - # First check if there is an outstanding all-gather handle for this param. - # If so, wait on the handle to ensure the communication is finished. - if all_gather_handle_index >= len(self.all_gather_handles): - return - - all_gather_handle = self.all_gather_handles[all_gather_handle_index] - if all_gather_handle is not None: - all_gather_handle.wait() - self.all_gather_handles[all_gather_handle_index] = None - - # Launch the all-gather for the next bucket now. - # We can't pre-launch all-gathers for all buckets at once since we don't - # want to head-of-line block the compute kernels with communication kernels - # (since we run with CUDA_DEVICE_MAX_CONNECTIONS=1 to support sequence - # parallelism). - next_all_gather_handle_index = all_gather_handle_index + 1 - if next_all_gather_handle_index < self.num_all_gather_handles: - self._dispatch_gather_model_params(next_all_gather_handle_index) - - # Also check if we have already copied from the param buffer for this - # handle; if not, complete the copy and mark as such. - if not self.param_buffer_copied[all_gather_handle_index]: - self._copy_params_from_param_buffer(all_gather_handle_index) - self.param_buffer_copied[all_gather_handle_index] = True - - - def _copy_params_from_param_buffer(self, all_gather_handle_index): - """ - Copy params from param_buffer to model_params. - """ - (model_index, dtype, bucket_index) = self.all_gather_handle_index_to_bucket_index_map[ - all_gather_handle_index] - model = self.models[model_index] - if self.update_successful: - # Copy from param buffer to each param. - param_map = model.grad_buffer_param_index_map[dtype] - for param, (buf_start, buf_end, bucket_index_in_param_map) in param_map.items(): - if bucket_index == bucket_index_in_param_map: - bucket_offset = model.grad_buffers[dtype].buckets[bucket_index].offset - param_buf = self.param_buffers[model_index][dtype][bucket_index] - # buf_start and buf_end store position of this parameter in the full grad_buffer, - # so need to adjust these indices (by subtracting out bucket_offset) since we - # have independent param_bufs for each bucket. - param_buf_shard = param_buf[buf_start-bucket_offset:buf_end-bucket_offset] - assert param.data.nelement() == param_buf_shard.nelement() - param.view(-1).detach().copy_(param_buf_shard) - - # Zero out the grad buffer in preparation for next set of fwd / bwd passes after copy - # completes (since param_buffer and grad_buffer are shared for each bucket). - param_buf = self.param_buffers[model_index][dtype][bucket_index] - grad_buf = model.grad_buffers[dtype].buckets[bucket_index].data - assert param_buf.data_ptr() == grad_buf.data_ptr() - grad_buf.zero_() - - - def _collect_main_grad_data_for_unscaling(self): - """ - Note: this should be equivalent to the float-16 optimizer's method, - but writtent differently, so the two should be combined. - """ - return [ - param.grad.data - for group in self.optimizer.param_groups - for param in group["params"] - ] - - - def _get_model_and_main_params_data_float16(self): - """ - Get aligned list of model and main params. 
- """ - model_data = [] - main_data = [] - for model_group, main_group in zip(self.shard_float16_groups, - self.shard_fp32_from_float16_groups): - for model_param, main_param in zip(model_group, main_group): - model_data.append(model_param.data) - main_data.append(main_param.data) - return model_data, main_data - - - def _copy_model_grads_to_main_grads(self): - """ - Copy model grads to main grads. - - Since this step follows a reduce-scatter through the DDP's grad - buffer, this method is responsible for copying the updated grads - from the grad buffer to the main shard's grad field. - """ - - # Utility method for copying group grads. - def copy_group_grads(model_groups, shard_main_groups): - for model_group, shard_main_group in zip(model_groups, - shard_main_groups): - for model_param, shard_main_param in zip(model_group, - shard_main_group): - - param_range_map = self.get_model_param_range_map(model_param) - param_range = param_range_map["param"] - assert param_range.size == shard_main_param.nelement() - - model_grad = model_param.main_grad - shard_model_grad = model_grad.view(-1) \ - [param_range.start:param_range.end] - shard_main_param.grad = shard_model_grad.float() - - # Copy model groups to shard groups. - copy_group_grads(self.model_float16_groups, - self.shard_fp32_from_float16_groups) - copy_group_grads(self.model_fp32_groups, - self.shard_fp32_groups) - - - def _copy_main_params_to_model_params(self): - """ - Copy main params to model params. - - Since this step is followed by an all-gather through the DDP's grad - buffer, this method is responsible for copying the updated params - from the main shards into the correct position in the grad buffer. - """ - - # Utility method for copying group params. - def copy_group_params(shard_main_groups, model_groups): - for shard_main_group, model_group in zip(shard_main_groups, - model_groups): - for shard_main_param, model_param in zip(shard_main_group, - model_group): - - param_range_map = self.get_model_param_range_map(model_param) - world_range = param_range_map["gbuf_world_in_bucket"] - - assert world_range.size == shard_main_param.nelement() - - model_id, dtype, bucket_id = self.model_param_gbuf_map[model_param] - model_param_buffer = self.param_buffers[model_id][dtype][bucket_id] - - shard_model_param = model_param_buffer.view(-1) \ - [world_range.start:world_range.end] - - shard_model_param.data.copy_(shard_main_param) - - # Copy shard groups to model groups. - copy_group_params(self.shard_fp32_from_float16_groups, - self.model_float16_groups) - copy_group_params(self.shard_fp32_groups, - self.model_fp32_groups) - - - def _copy_model_params_to_main_params(self): - """ - Copy model params to main params. - - During finetuning, this method is used to reload the main params from - the model params. This copy does not make use of the grad buffer as - an intermediary. - """ - - # Utility method for copying group params. - def copy_group_params(model_groups, shard_main_groups): - for model_group, shard_main_group in zip(model_groups, - shard_main_groups): - for model_param, shard_main_param in zip(model_group, - shard_main_group): - - param_range_map = self.get_model_param_range_map(model_param) - param_range = param_range_map["param"] - assert param_range.size == shard_main_param.nelement() - - shard_model_param = model_param.view(-1) \ - [param_range.start:param_range.end] - shard_main_param.data.copy_(shard_model_param) - - # Copy model groups to shard groups. 
- copy_group_params(self.model_float16_groups, - self.shard_fp32_from_float16_groups) - copy_group_params(self.model_fp32_groups, - self.shard_fp32_groups) - - - @torch.no_grad() - def step(self, args, timers): - self.update_successful, grad_norm, num_zeros_in_grad = super().step(args, timers) - - # Reset metadata needed to track results of all-gathers. - self.all_gather_handles = [] - self.param_buffer_copied = [] - - # If not overlapping all-gather for parameters, launch synchronous all-gather - # communication calls here. - if not self.overlap_param_gather: - timers('params-all-gather', log_level=1).start(barrier=args.barrier_with_L1_time) - for all_gather_handle_index in range(self.num_all_gather_handles): - self._dispatch_gather_model_params(all_gather_handle_index) - timers('params-all-gather').stop() - - return self.update_successful, grad_norm, num_zeros_in_grad diff --git a/megatron/optimizer/grad_scaler.py b/megatron/optimizer/grad_scaler.py deleted file mode 100644 index 66f7c907a41816de17e0d6a7bd5cb626722b33a6..0000000000000000000000000000000000000000 --- a/megatron/optimizer/grad_scaler.py +++ /dev/null @@ -1,120 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. - -"""Megatron grad scaler.""" - -from abc import ABC -from abc import abstractmethod - -import torch - - -class MegatronGradScaler(ABC): - - def __init__(self, initial_scale): - """Initialize scale value with the input initial scale.""" - assert initial_scale > 0.0 - self._scale = torch.cuda.FloatTensor([initial_scale]) - - @property - def scale(self): - return self._scale - - @property - def inv_scale(self): - return self._scale.double().reciprocal().float() - - @abstractmethod - def update(self, found_inf): - pass - - @abstractmethod - def state_dict(self): - pass - - @abstractmethod - def load_state_dict(self, state_dict): - pass - - - -class ConstantGradScaler(MegatronGradScaler): - - def update(self, found_inf): - pass - - def state_dict(self): - return dict() - - def load_state_dict(self, state_dict): - pass - - - -class DynamicGradScaler(MegatronGradScaler): - - def __init__(self, initial_scale, min_scale, - growth_factor, backoff_factor, - growth_interval, hysteresis): - """"Grad scaler with dynamic scale that gets adjusted - during training.""" - super(DynamicGradScaler, self).__init__(initial_scale) - - # Lower bound on the scale. - assert min_scale > 0.0 - assert min_scale <= initial_scale - self.min_scale = torch.cuda.FloatTensor([min_scale]) - # Growth and backoff factors for the scale. - assert growth_factor > 1.0 - self.growth_factor = torch.cuda.FloatTensor([growth_factor]) - assert backoff_factor < 1.0 - assert backoff_factor > 0.0 - self.backoff_factor = torch.cuda.FloatTensor([backoff_factor]) - # Interval over which if we don't see any inf/nan, - # we will scale the grad scale by the growth factor. - assert growth_interval > 0 - self.growth_interval = growth_interval - # Number of inf/nans we should see before scaling down - # the grad scale by the backoff factor. - assert hysteresis > 0 - self.hysteresis = hysteresis - - # Trackers. - self._growth_tracker = 0 - self._hysteresis_tracker = self.hysteresis - - - def update(self, found_inf): - - # If we have an inf/nan, growth tracker is set to 0 - # and hysterisis tracker is reduced by 1. - if found_inf: - self._growth_tracker = 0 - self._hysteresis_tracker -= 1 - # Now if we are out of hysteresis count, scale down the loss. 
- if self._hysteresis_tracker <= 0: - self._scale = torch.max(self._scale * self.backoff_factor, - self.min_scale) - else: - # If there is no nan/inf, increment the growth tracker. - self._growth_tracker += 1 - # If we have had enough consequitive intervals with no nan/inf: - if self._growth_tracker == self.growth_interval: - # Reset the tracker and hysteresis trackers, - self._growth_tracker = 0 - self._hysteresis_tracker = self.hysteresis - # and scale up the loss scale. - self._scale = self._scale * self.growth_factor - - - def state_dict(self): - state_dict = {} - state_dict['scale'] = self._scale - state_dict['growth_tracker'] = self._growth_tracker - state_dict['hysteresis_tracker'] = self._hysteresis_tracker - return state_dict - - - def load_state_dict(self, state_dict): - self._scale = state_dict['scale'].cuda(torch.cuda.current_device()) - self._growth_tracker = state_dict['growth_tracker'] - self._hysteresis_tracker = state_dict['hysteresis_tracker'] diff --git a/megatron/optimizer/optimizer.py b/megatron/optimizer/optimizer.py deleted file mode 100644 index 23749959b90086c7904962c221199e0cd8cb699d..0000000000000000000000000000000000000000 --- a/megatron/optimizer/optimizer.py +++ /dev/null @@ -1,644 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - -"""Megatron optimizer.""" - -from abc import ABC -from abc import abstractmethod -from apex.multi_tensor_apply import multi_tensor_applier -import amp_C -import torch - -from megatron import get_timers -from megatron import print_rank_0 -from megatron.core import mpu, tensor_parallel -from megatron.model import Float16Module -from megatron.model.module import param_is_not_shared - -from .clip_grads import clip_grad_norm_fp32, count_zeros_fp32 - - -def _zero_grad_group_helper(group, set_to_none): - """Zero out the gradient for a group of parameters. - Note: copied from torch.optim.optimizer.""" - for param in group: - if param.grad is not None: - if set_to_none: - param.grad = None - else: - if param.grad.grad_fn is not None: - param.grad.detach_() - else: - param.grad.requires_grad_(False) - param.grad.zero_() - - -def _multi_tensor_copy_this_to_that(this, that, overflow_buf=None): - """Use multi-tensor-applier to copy values from one list to another. - We don't have a blfoat16 implementation so for now if the overflow_buf - is not provided, we default back to simple loop copy to be compatible - with bfloat16.""" - if overflow_buf: - overflow_buf.fill_(0) - # Scaling with factor `1.0` is equivalent to copy. - multi_tensor_applier(amp_C.multi_tensor_scale, - overflow_buf, - [this, that], - 1.0) - else: - for this_, that_ in zip(this, that): - that_.copy_(this_) - - - -class MegatronOptimizer(ABC): - - - def __init__(self, optimizer, clip_grad, - log_num_zeros_in_grad, - check_for_nan_in_grad, - params_have_main_grad, - models): - - """Input optimizer is the base optimizer for example Adam.""" - self.optimizer = optimizer - assert self.optimizer, 'no optimizer is provided.' - # Set gradient clipping and logging params. - self.clip_grad = clip_grad - self.log_num_zeros_in_grad = log_num_zeros_in_grad - self.check_for_nan_in_grad = check_for_nan_in_grad - self.params_have_main_grad = params_have_main_grad - - # 'models' are retained for access to the contiguous grad buffers. 
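The dynamic loss-scale rule implemented by DynamicGradScaler above can be stated compactly without CUDA tensors. A pure-Python sketch with illustrative hyperparameters (the real class keeps the scale and factors in device tensors):

```python
# Every overflow step zeroes the growth tracker and consumes one unit of
# hysteresis; once hysteresis is used up the scale backs off (bounded below by
# min_scale). growth_interval consecutive clean steps restore the hysteresis
# budget and grow the scale.
def update_scale(scale, found_inf, growth_tracker, hysteresis_tracker, *,
                 min_scale=1.0, growth_factor=2.0, backoff_factor=0.5,
                 growth_interval=3, hysteresis=2):
    if found_inf:
        growth_tracker = 0
        hysteresis_tracker -= 1
        if hysteresis_tracker <= 0:
            scale = max(scale * backoff_factor, min_scale)
    else:
        growth_tracker += 1
        if growth_tracker == growth_interval:
            growth_tracker = 0
            hysteresis_tracker = hysteresis
            scale = scale * growth_factor
    return scale, growth_tracker, hysteresis_tracker

# Two overflows in a row halve the scale; three clean steps double it again.
scale, g, h = 65536.0, 0, 2
for found_inf in (True, True, False, False, False):
    scale, g, h = update_scale(scale, found_inf, g, h)
print(scale)  # 65536.0 * 0.5 * 2.0 == 65536.0
```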
- # (see distributed optimizer) - self.models = models - - - def get_parameters(self): - params = [] - for param_group in self.optimizer.param_groups: - for param in param_group['params']: - params.append(param) - return params - - - def get_main_grads_for_grad_norm(self): - - # Filter parameters based on: - # - grad should not be none - # - parameter should not be shared - # - should not be a replica due to tensor model parallelism - params = self.get_parameters() - grads_for_norm = [] - for param in params: - grad = param.grad - grad_not_none = grad is not None - is_not_shared = param_is_not_shared(param) - is_not_tp_duplicate = tensor_parallel.param_is_not_tensor_parallel_duplicate(param) - if grad_not_none and is_not_shared and is_not_tp_duplicate: - grads_for_norm.append(grad) - - return grads_for_norm - - - def get_model_parallel_group(self): - """Default returned here, but the distributed optimizer overrides this.""" - return mpu.get_model_parallel_group() - - - def clip_grad_norm(self, clip_grad, check_for_nan_in_grad): - params = self.get_parameters() - grads_for_norm = self.get_main_grads_for_grad_norm() - return clip_grad_norm_fp32( - params, grads_for_norm, clip_grad, - check_for_nan_in_grad, - model_parallel_group=self.get_model_parallel_group()) - - - def count_zeros(self): - params = self.get_parameters() - return count_zeros_fp32(params, - model_parallel_group=self.get_model_parallel_group()) - - - @abstractmethod - def zero_grad(self, set_to_none=True): - pass - - - @abstractmethod - def get_loss_scale(self): - """The output should be a cuda tensor of size 1.""" - pass - - - def scale_loss(self, loss): - """Simple scaling.""" - return self.get_loss_scale() * loss - - - @abstractmethod - def reload_model_params(self): - """Refreshes any internal state from the current model parameters. - Call whenever the parameters are changed outside of the optimizer. - For example, when we load a model from a checkpoint without loading - the optimizer, the model parameters are updated but for fp16 optimizer - with main parameters, the main parameters need to also be updated.""" - pass - - - @abstractmethod - def state_dict(self): - pass - - - @abstractmethod - def load_state_dict(self, state_dict): - pass - - - # Promote state so it can be retrieved or set via - # "optimizer_instance.state" - def _get_state(self): - return self.optimizer.state - - def _set_state(self, value): - self.optimizer.state = value - - state = property(_get_state, _set_state) - - - # Promote param_groups so it can be retrieved or set via - # "optimizer_instance.param_groups" - # (for example, to adjust the learning rate) - def _get_param_groups(self): - return self.optimizer.param_groups - - def _set_param_groups(self, value): - self.optimizer.param_groups = value - - param_groups = property(_get_param_groups, _set_param_groups) - - - @abstractmethod - def step(self, args, timers): - pass - - - -class MixedPrecisionOptimizer(MegatronOptimizer): - """Base class for both the float-16 and the distributed optimizer. - - Arguments: - optimizer: base optimizer such as Adam or SGD - clip_grad: clip gradeints with this global L2 norm. Note - that clipping is ignored if clip_grad == 0 - log_num_zeros_in_grad: return number of zeros in the gradients. - check_for_nan_in_grad: check if gradients have a NaN. - params_have_main_grad: flag indicating if parameters have - a `main_grad` field. If this is set, we are assuming - that the model parameters are store in the `main_grad` - field instead of the typical `grad` field. 
This happens - for the DDP cases where there is a continuous buffer - holding the gradients. For example for bfloat16, we want - to do gradient accumulation and all-reduces in float32 - and as a result we store those gradients in the main_grad. - Note that main grad is not necessarily in float32. - fp16: if true, the model is running in fp16. - bf16: if true, the model is running in bfloat16. - params_dtype: used by distributed optimizer. - grad_scaler: used for scaling gradients. Note that this can be - None. This case happens when `bf16 = True` and we don't - use any loss scale. Note that for `bf16 = True`, we can have - a constnat gradient scaler. Also for `bf16 = False`, we - always require a grad scaler. - models: list of models (i.e., the virtual pipelining models). This - is used by the distributed optimizer for mapping parameters. - """ - - def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad, - check_for_nan_in_grad, params_have_main_grad, - fp16, bf16, params_dtype, grad_scaler, models): - - super().__init__( - optimizer, clip_grad, log_num_zeros_in_grad, - check_for_nan_in_grad, params_have_main_grad, - models) - - self.fp16 = fp16 - self.bf16 = bf16 - self.params_dtype = params_dtype - self.grad_scaler = grad_scaler - - # None grad scaler is only supported for bf16. - if self.grad_scaler is None: - assert not self.fp16, 'fp16 expects a grad scaler.' - - # Tensor used to determine if a nan/if has happend. - # Any non-zero value indicates inf/nan. - # Note that we keep this for the cases that grad scaler is none. - # We still record nan/inf if we have a bfloat16 with a grad scaler. - if self.grad_scaler: - self.found_inf = torch.cuda.FloatTensor([0.0]) - - # Dummy tensor needed for apex multi-apply tensor. - # For bfloat, we don't have multi-tensor apply and for now - # we set it to none so the multi-tensor apply gets ignored. - if bf16: - self._dummy_overflow_buf = None - else: - self._dummy_overflow_buf = torch.cuda.IntTensor([0]) - - # In case grad scaler is not passed, define the unity scale. - if self.grad_scaler is None: - self._scale_one = torch.cuda.FloatTensor([1.0]) - - - def get_loss_scale(self): - if self.grad_scaler is None: - return self._scale_one - return self.grad_scaler.scale - - - def reload_model_params(self): - self._copy_model_params_to_main_params() - - - def _unscale_main_grads_and_check_for_nan(self): - - # Collect main grads. - main_grads = self._collect_main_grad_data_for_unscaling() - - # Reset found inf. - self.found_inf.fill_(0.0) - - # Unscale and set found inf/nan - torch._amp_foreach_non_finite_check_and_unscale_( - main_grads, self.found_inf, self.grad_scaler.inv_scale) - - # Update across all model parallel instances. - torch.distributed.all_reduce(self.found_inf, - op=torch.distributed.ReduceOp.MAX, - group=self.get_model_parallel_group()) - - # Check for nan. - found_inf_flag = (self.found_inf.item() > 0) - - return found_inf_flag - - - @torch.no_grad() - def step(self, args, timers): - - # Copy gradients from model params to main params. - timers('optimizer-copy-to-main-grad', log_level=1).start( - barrier=args.barrier_with_L1_time) - self._copy_model_grads_to_main_grads() - timers('optimizer-copy-to-main-grad').stop() - - # Do unscale, check for inf, and update grad scaler only for - # the case that grad scaler is provided. - if self.grad_scaler: - - # Unscale and check for inf/nan. 
- timers('optimizer-unscale-and-check-inf', log_level=1).start( - barrier=args.barrier_with_L1_time) - found_inf_flag = self._unscale_main_grads_and_check_for_nan() - timers('optimizer-unscale-and-check-inf').stop() - - # We are done with scaling gradients - # so we can update the loss scale. - self.grad_scaler.update(found_inf_flag) - - # If we found inf/nan, skip the update. - if found_inf_flag: - return False, None, None - - # Clip the main gradients. - timers('optimizer-clip-main-grad', log_level=1).start( - barrier=args.barrier_with_L1_time) - grad_norm = None - if self.clip_grad > 0.0: - grad_norm = self.clip_grad_norm(self.clip_grad, - self.check_for_nan_in_grad) - timers('optimizer-clip-main-grad').stop() - - # Count the zeros in the grads. - timers('optimizer-count-zeros', log_level=1).start( - barrier=args.barrier_with_L1_time) - num_zeros_in_grad = self.count_zeros() if \ - self.log_num_zeros_in_grad else None - timers('optimizer-count-zeros').stop() - - # Step the optimizer. - timers('optimizer-inner-step', log_level=1).start( - barrier=args.barrier_with_L1_time) - self.optimizer.step() - timers('optimizer-inner-step').stop() - - # Update params from main params. - timers('optimizer-copy-main-to-model-params', log_level=1).start( - barrier=args.barrier_with_L1_time) - self._copy_main_params_to_model_params() - timers('optimizer-copy-main-to-model-params').stop() - - # Successful update. - return True, grad_norm, num_zeros_in_grad - - -class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer): - """Float16 optimizer for fp16 and bf16 data types. - - Arguments: - optimizer: base optimizer such as Adam or SGD - clip_grad: clip gradeints with this global L2 norm. Note - that clipping is ignored if clip_grad == 0 - log_num_zeros_in_grad: return number of zeros in the gradients. - check_for_nan_in_grad: check if gradients have a NaN. - params_have_main_grad: flag indicating if parameters have - a `main_grad` field. If this is set, we are assuming - that the model parameters are store in the `main_grad` - field instead of the typical `grad` field. This happens - for the DDP cases where there is a continuous buffer - holding the gradients. For example for bfloat16, we want - to do gradient accumulation and all-reduces in float32 - and as a result we store those gradients in the main_grad. - Note that main grad is not necessarily in float32. - fp16: if true, the model is running in fp16. - bf16: if true, the model is running in bfloat16. - grad_scaler: used for scaling gradients. Note that this can be - None. This case happens when `bf16 = True` and we don't - use any loss scale. Note that for `bf16 = True`, we can have - a constnat gradient scaler. Also for `bf16 = False`, we - always require a grad scaler. - models: list of models (i.e., the virtual pipelining models). This - is used by the distributed optimizer for mapping parameters. 
- """ - - def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad, - check_for_nan_in_grad, params_have_main_grad, fp16, bf16, - params_dtype, grad_scaler, models): - - super().__init__( - optimizer, clip_grad, log_num_zeros_in_grad, - check_for_nan_in_grad, params_have_main_grad, - fp16, bf16, params_dtype, grad_scaler, models) - - # ====================== - # main parameter stuff - # ====================== - - # Three groups of parameters: - # float16_groups: original float16 parameters - # fp32_from_float16_groups: fp32 copy of float16 parameters - # fp32_from_fp32_groups: original fp32 parameters - self.float16_groups = [] - self.fp32_from_float16_groups = [] - self.fp32_from_fp32_groups = [] - - # For all the groups in the original optimizer: - for param_group in self.optimizer.param_groups: - float16_params_this_group = [] - fp32_params_this_group = [] - fp32_from_float16_params_this_group = [] - # For all the parameters in this group: - for i, param in enumerate(param_group['params']): - if param.requires_grad: - - # float16 params: - if param.type() in ['torch.cuda.HalfTensor', - 'torch.cuda.BFloat16Tensor']: - float16_params_this_group.append(param) - # Create a copy - main_param = param.detach().clone().float() - # Copy tensor model parallel attributes. - tensor_parallel.copy_tensor_model_parallel_attributes(main_param, - param) - if hasattr(param, 'shared'): - main_param.shared = param.shared - # Replace the optimizer params with the new fp32 copy. - param_group['params'][i] = main_param - - fp32_from_float16_params_this_group.append(main_param) - # Reset existing state dict key to the new main param. - if param in self.optimizer.state: - self.optimizer.state[main_param] \ - = self.optimizer.state.pop(param) - # fp32 params. - elif param.type() == 'torch.cuda.FloatTensor': - fp32_params_this_group.append(param) - param_group['params'][i] = param - - else: - raise TypeError('Wrapped parameters must be one of ' - 'torch.cuda.FloatTensor, ' - 'torch.cuda.HalfTensor, or ' - 'torch.cuda.BFloat16Tensor. ' - 'Received {}'.format(param.type())) - - self.float16_groups.append(float16_params_this_group) - self.fp32_from_float16_groups.append( - fp32_from_float16_params_this_group) - self.fp32_from_fp32_groups.append(fp32_params_this_group) - - - def zero_grad(self, set_to_none=True): - """We only need to zero the model related parameters, i.e., - float16_groups & fp32_from_fp32_groups. We additionally zero - fp32_from_float16_groups as a memory optimization to reduce - fragmentation; in the case of set_to_none==True, the space - used by this field can be safely deallocated at this point.""" - for group in self.float16_groups: - _zero_grad_group_helper(group, set_to_none) - for group in self.fp32_from_float16_groups: - _zero_grad_group_helper(group, set_to_none) - for group in self.fp32_from_fp32_groups: - _zero_grad_group_helper(group, set_to_none) - - - def _collect_main_grad_data_for_unscaling(self): - - main_grads = [] - - # fp32 params from float16 ones. - for main_group in self.fp32_from_float16_groups: - for main_param in main_group: - if main_param.grad is not None: - main_grads.append(main_param.grad.data) - - # Append fp32 parameters. 
- for main_group in self.fp32_from_fp32_groups: - for main_param in main_group: - if main_param.grad is not None: - main_grads.append(main_param.grad.data) - - return main_grads - - - def _get_model_and_main_params_data_float16(self): - model_data = [] - main_data = [] - for model_group, main_group in zip(self.float16_groups, - self.fp32_from_float16_groups): - for model_param, main_param in zip(model_group, main_group): - model_data.append(model_param.data) - main_data.append(main_param.data) - return model_data, main_data - - - def _copy_model_grads_to_main_grads(self): - # This only needs to be done for the float16 group. - for model_group, main_group in zip(self.float16_groups, - self.fp32_from_float16_groups): - for model_param, main_param in zip(model_group, main_group): - if self.params_have_main_grad and hasattr(model_param, 'main_grad'): - main_param.grad = model_param.main_grad.float() - else: - if model_param.grad is not None: - main_param.grad = model_param.grad.float() - - # Safe to deallocate model's grad/main_grad after copying. - # (If using contiguous buffers, main_grad's memory should - # persist and therefore should not be deallocated.) - model_param.grad = None - - # For fp32 grads, we need to reset the grads to main grad. - if self.params_have_main_grad: - for model_group in self.fp32_from_fp32_groups: - for model_param in model_group: - model_param.grad = model_param.main_grad - - - def _copy_main_params_to_model_params(self): - # Only needed for the float16 params. - model_data, main_data = self._get_model_and_main_params_data_float16() - _multi_tensor_copy_this_to_that(this=main_data, that=model_data, - overflow_buf=self._dummy_overflow_buf) - - - def _copy_model_params_to_main_params(self): - # Only needed for the float16 params. - model_data, main_data = self._get_model_and_main_params_data_float16() - _multi_tensor_copy_this_to_that(this=model_data, that=main_data, - overflow_buf=self._dummy_overflow_buf) - - - def state_dict(self): - state_dict = {} - state_dict['optimizer'] = self.optimizer.state_dict() - if self.grad_scaler: - state_dict['grad_scaler'] = self.grad_scaler.state_dict() - state_dict['fp32_from_fp16_params'] = self.fp32_from_float16_groups - return state_dict - - - def load_state_dict(self, state_dict): - # Optimizer. - optimizer_key = 'optimizer' - if optimizer_key not in state_dict: - optimizer_key = 'optimizer_state_dict' - print_rank_0('***WARNING*** loading optimizer from ' - 'an old checkpoint ...') - self.optimizer.load_state_dict(state_dict[optimizer_key]) - - # Grad scaler. - if 'grad_scaler' not in state_dict: - if self.fp16: - print_rank_0('***WARNING*** found an old checkpoint, will not ' - 'load grad scaler ...') - else: - if self.grad_scaler: - self.grad_scaler.load_state_dict(state_dict['grad_scaler']) - else: - print_rank_0('***WARNING*** fould the grad scaler in the ' - 'checkpoint but it is None in the class. ' - 'Skipping loading grad scaler ...') - - # Copy data for the main params. 
- fp32_from_float16_params_key = 'fp32_from_fp16_params' - if fp32_from_float16_params_key not in state_dict: - fp32_from_float16_params_key = 'fp32_from_fp16' - for current_group, saved_group in zip( - self.fp32_from_float16_groups, - state_dict[fp32_from_float16_params_key]): - for current_param, saved_param in zip(current_group, saved_group): - current_param.data.copy_(saved_param.data) - - -class FP32Optimizer(MegatronOptimizer): - - def __init__(self, optimizer, clip_grad, - log_num_zeros_in_grad, - check_for_nan_in_grad, - params_have_main_grad, - models): - - super(FP32Optimizer, self).__init__( - optimizer, clip_grad, log_num_zeros_in_grad, - check_for_nan_in_grad, params_have_main_grad, - models) - - self._scale = torch.cuda.FloatTensor([1.0]) - - - def zero_grad(self, set_to_none=True): - """Copied from torch.optim.optimizer""" - for group in self.optimizer.param_groups: - _zero_grad_group_helper(group['params'], set_to_none) - - - def get_loss_scale(self): - """FP32 optimizer does not do any scaling.""" - return self._scale - - - @torch.no_grad() - def step(self, args, timers): - """Clip gradients (if needed) and step the base optimizer. - Always return successful since there is no overflow.""" - - # Copy main_grads to grads. - timers('optimizer-copy-to-main-grad', log_level=1).start( - barrier=args.barrier_with_L1_time) - if self.params_have_main_grad: - for param_group in self.optimizer.param_groups: - for param in param_group['params']: - param.grad = param.main_grad - - timers('optimizer-copy-to-main-grad').stop() - - # Clip gradients. - timers('optimizer-clip-main-grad', log_level=1).start( - barrier=args.barrier_with_L1_time) - grad_norm = None - if self.clip_grad > 0.0: - grad_norm = self.clip_grad_norm(self.clip_grad, - self.check_for_nan_in_grad) - timers('optimizer-clip-main-grad').stop() - - # count the zeros in the grads - timers('optimizer-count-zeros', log_level=1).start( - barrier=args.barrier_with_L1_time) - num_zeros_in_grad = self.count_zeros() if \ - self.log_num_zeros_in_grad else None - timers('optimizer-count-zeros').stop() - - # Update parameters. - timers('optimizer-inner-step', log_level=1).start( - barrier=args.barrier_with_L1_time) - self.optimizer.step() - timers('optimizer-inner-step').stop() - - # No overflow for FP32 optimizer. - return True, grad_norm, num_zeros_in_grad - - - def reload_model_params(self): - pass - - - def state_dict(self): - return self.optimizer.state_dict() - - - def load_state_dict(self, state_dict): - self.optimizer.load_state_dict(state_dict) diff --git a/megatron/optimizer/utils.py b/megatron/optimizer/utils.py deleted file mode 100644 index f4b7cbd634b4f67d3a1fbf5a950cf1cc95664b26..0000000000000000000000000000000000000000 --- a/megatron/optimizer/utils.py +++ /dev/null @@ -1,19 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - -"""Utility functions for Megatron optimizer.""" - - -from megatron.core import mpu - - -def shard_buffer(buffer): - """ - Shard buffer into dp_size chunks of equal size. 
- """ - data_parallel_world_size = mpu.get_data_parallel_world_size(with_context_parallel=True) - assert buffer.numel() % data_parallel_world_size == 0 - shard_size = buffer.numel() // data_parallel_world_size - sharded_buffer = [buffer[(r*shard_size):((r+1)*shard_size)] - for r in range(data_parallel_world_size)] - return sharded_buffer - diff --git a/megatron/optimizer_param_scheduler.py b/megatron/optimizer_param_scheduler.py deleted file mode 100644 index 0cf5fb1d8fcdb9bf212828464d1785e1f0284e61..0000000000000000000000000000000000000000 --- a/megatron/optimizer_param_scheduler.py +++ /dev/null @@ -1,235 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. - -"""Learning rate decay and weight decay incr functions.""" - -import math - -from megatron import print_rank_0 - -class OptimizerParamScheduler(object): - """Anneals learning rate and weight decay""" - - def __init__(self, optimizer, init_lr, max_lr, min_lr, - lr_warmup_steps, lr_decay_steps, lr_decay_style, - start_wd, end_wd, wd_incr_steps, wd_incr_style, - use_checkpoint_opt_param_scheduler=True, - override_opt_param_scheduler=False): - - # Class values. - self.optimizer = optimizer - - self.init_lr = init_lr - self.max_lr = float(max_lr) - self.min_lr = min_lr - assert self.min_lr >= 0.0 - assert self.max_lr >= self.min_lr - assert self.init_lr <= self.max_lr - - self.lr_warmup_steps = lr_warmup_steps - self.num_steps = 0 - self.lr_decay_steps = lr_decay_steps - assert self.lr_decay_steps > 0 - assert self.lr_warmup_steps < self.lr_decay_steps - - self.lr_decay_style = lr_decay_style - - self.start_wd = start_wd - self.end_wd = end_wd - assert self.start_wd >= 0.0 - assert self.end_wd >= self.start_wd - self.wd_incr_steps = wd_incr_steps - self.wd_incr_style = wd_incr_style - - self.override_opt_param_scheduler = override_opt_param_scheduler - self.use_checkpoint_opt_param_scheduler = use_checkpoint_opt_param_scheduler - if self.override_opt_param_scheduler: - assert not self.use_checkpoint_opt_param_scheduler, 'both override and '\ - 'use-checkpoint are set.' - - # Set the learning rate - self.step(0) - print_rank_0('> learning rate decay style: {}'.format(self.lr_decay_style)) - - - def get_wd(self): - """ Weight decay incr functions""" - if self.num_steps > self.wd_incr_steps: - return self.end_wd - - if self.wd_incr_style == 'constant': - assert self.start_wd == self.end_wd - return self.end_wd - - incr_ratio = float(self.num_steps) / float(self.wd_incr_steps) - assert incr_ratio >= 0.0 - assert incr_ratio <= 1.0 - delta_wd = self.end_wd - self.start_wd - - if self.wd_incr_style == 'linear': - coeff = incr_ratio - elif self.wd_incr_style == 'cosine': - coeff = 0.5 * (math.cos(math.pi * (1 - incr_ratio)) + 1.0) - else: - raise Exception('{} weight decay increment style is not supported.'.format( - self.wd_incr_style)) - - return self.start_wd + coeff * delta_wd - - - def get_lr(self): - """Learning rate decay functions from: - https://openreview.net/pdf?id=BJYwwY9ll pg. 4""" - - # Use linear warmup for the initial part. - if self.lr_warmup_steps > 0 and self.num_steps <= self.lr_warmup_steps: - return ( - self.init_lr - + ( - (self.max_lr - self.init_lr) - * float(self.num_steps) - / float(self.lr_warmup_steps) - ) - ) - - # If the learning rate is constant, just return the initial value. - if self.lr_decay_style == 'constant': - return self.max_lr - - # For any steps larger than `self.lr_decay_steps`, use `self.min_lr`. 
- if self.num_steps > self.lr_decay_steps: - return self.min_lr - - # If we are done with the warmup period, use the decay style. - if self.lr_decay_style == 'inverse-square-root': - warmup_steps = max(self.lr_warmup_steps, 1) - num_steps = max(self.num_steps, 1) - lr = self.max_lr * warmup_steps ** 0.5 / (num_steps ** 0.5) - return max(self.min_lr, lr) - - num_steps_ = self.num_steps - self.lr_warmup_steps - decay_steps_ = self.lr_decay_steps - self.lr_warmup_steps - decay_ratio = float(num_steps_) / float(decay_steps_) - assert decay_ratio >= 0.0 - assert decay_ratio <= 1.0 - delta_lr = self.max_lr - self.min_lr - - if self.lr_decay_style == 'linear': - coeff = (1.0 - decay_ratio) - elif self.lr_decay_style == 'cosine': - coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0) - else: - raise Exception('{} decay style is not supported.'.format( - self.lr_decay_style)) - - return self.min_lr + coeff * delta_lr - - - def step(self, increment): - """Set lr for all parameters groups.""" - self.num_steps += increment - new_lr = self.get_lr() - new_wd = self.get_wd() - for group in self.optimizer.param_groups: - group['lr'] = new_lr * group.get('lr_mult', 1.0) - group['weight_decay'] = new_wd * group.get('wd_mult', 1.0) - - - def state_dict(self): - state_dict = { - 'max_lr': self.max_lr, - 'lr_warmup_steps': self.lr_warmup_steps, - 'num_steps': self.num_steps, - 'lr_decay_style': self.lr_decay_style, - 'lr_decay_steps': self.lr_decay_steps, - 'min_lr': self.min_lr, - 'start_wd': self.start_wd, - 'end_wd': self.end_wd, - 'wd_incr_style': self.wd_incr_style, - 'wd_incr_steps': self.wd_incr_steps - } - return state_dict - - - def _check_and_set(self, cls_value, sd_value, name): - """Auxiliary function for checking the values in the checkpoint and - setting them.""" - if self.override_opt_param_scheduler: - print_rank_0(' > overriding {} value to {}'.format(name, cls_value)) - return cls_value - - if not self.use_checkpoint_opt_param_scheduler: - assert cls_value == sd_value, \ - f'OptimizerParamScheduler: class input value {cls_value} and checkpoint' \ - f'value {sd_value} for {name} do not match' - print_rank_0(' > using checkpoint value {} for {}'.format(sd_value, - name)) - return sd_value - - - def load_state_dict(self, sd): - - if 'start_lr' in sd: - max_lr_ = sd['start_lr'] - else: - max_lr_ = sd['max_lr'] - self.max_lr = self._check_and_set(self.max_lr, max_lr_, - 'learning rate') - - self.min_lr = self._check_and_set(self.min_lr, sd['min_lr'], - 'minimum learning rate') - - if 'warmup_iter' in sd: - lr_warmup_steps_ = sd['warmup_iter'] - elif 'warmup_steps' in sd: - lr_warmup_steps_ = sd['warmup_steps'] - else: - lr_warmup_steps_ = sd['lr_warmup_steps'] - self.lr_warmup_steps = self._check_and_set(self.lr_warmup_steps, - lr_warmup_steps_, - 'warmup iterations') - - if 'end_iter' in sd: - lr_decay_steps_ = sd['end_iter'] - elif 'decay_steps' in sd: - lr_decay_steps_ = sd['decay_steps'] - else: - lr_decay_steps_ = sd['lr_decay_steps'] - self.lr_decay_steps = self._check_and_set(self.lr_decay_steps, lr_decay_steps_, - 'total number of iterations') - - if 'decay_style' in sd: - lr_decay_style_ = sd['decay_style'] - else: - lr_decay_style_ = sd['lr_decay_style'] - self.lr_decay_style = self._check_and_set(self.lr_decay_style, - lr_decay_style_, - 'learning rate decay style') - - if 'num_iters' in sd: - num_steps = sd['num_iters'] - else: - num_steps = sd['num_steps'] - self.step(increment=num_steps) - - - if 'start_wd' in sd: - self.start_wd = self._check_and_set(self.start_wd, - 
sd['start_wd'], - "start weight decay") - self.end_wd = self._check_and_set(self.end_wd, - sd['end_wd'], - "end weight decay") - self.wd_incr_steps = self._check_and_set(self.wd_incr_steps, - sd['wd_incr_steps'], - "total number of weight decay iterations") - self.wd_incr_style = self._check_and_set(self.wd_incr_style, - sd['wd_incr_style'], - "weight decay incr style") - - - - - - - - diff --git a/megatron/static/index.html b/megatron/static/index.html deleted file mode 100644 index 806287955bcc02e2d4148855af5ddb36ba94ae72..0000000000000000000000000000000000000000 --- a/megatron/static/index.html +++ /dev/null @@ -1,124 +0,0 @@ - - - - - - - -Megatron - - - -
-  [static demo page: a "Prompt Megatron" prompt box with a 0 / 1000 character counter; the rest of this 124-line HTML file is not shown]
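The page deleted above was a thin browser front end for the REST endpoint implemented in `megatron/text_generation_server.py`, which is also removed later in this diff. For orientation only, a minimal client call against that endpoint could look like the sketch below; the host and port are placeholders, the parameter values are illustrative, and the field names are taken from the deleted server code.

```python
# Hypothetical client for the deleted /api text-generation endpoint.
# Host/port are placeholders; the server is launched via MegatronServer.run(url, port).
import requests

payload = {
    "prompts": ["Deep learning is"],  # up to 128 prompt strings per request
    "tokens_to_generate": 64,
    "temperature": 1.0,
    "top_k": 0,                       # top_k and top_p are mutually exclusive
    "top_p": 0.9,
    "logprobs": False,
    "add_BOS": False,
}

# The Flask resource implements PUT rather than POST.
response = requests.put("http://localhost:5000/api", json=payload)
result = response.json()
print(result["text"])      # prompts plus generated continuations
print(result["segments"])  # per-token segments of each returned string
```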
- - - - - diff --git a/megatron/text_generation/__init__.py b/megatron/text_generation/__init__.py deleted file mode 100644 index 77da7be30ae4d02bd7ab1e4bae86afc8923d4e23..0000000000000000000000000000000000000000 --- a/megatron/text_generation/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. - - -from .api import ( - generate, - generate_and_post_process, - beam_search_and_post_process) diff --git a/megatron/text_generation/api.py b/megatron/text_generation/api.py deleted file mode 100644 index 4557ff3c12e219daf3a8092390c5a74a7e56b4e0..0000000000000000000000000000000000000000 --- a/megatron/text_generation/api.py +++ /dev/null @@ -1,207 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. - -"""Inference API.""" - - -import torch - -from megatron.core import mpu -from .communication import broadcast_float_list -from .generation import ( - generate_tokens_probs_and_return_on_first_stage, - score_and_return_on_first_stage, - beam_search_and_return_on_first_stage) -from .tokenization import ( - tokenize_prompts, - detokenize_generations) - -def generate_and_post_process(model, - prompts=None, - tokens_to_generate=0, - return_output_log_probs=False, - top_k_sampling=0, - top_p_sampling=0.0, - top_p_decay=0.0, - top_p_bound=0.0, - temperature=1.0, - add_BOS=False, - use_eod_token_for_early_termination=True, - stop_on_double_eol=False, - stop_on_eol=False, - prevent_newline_after_colon=False, - random_seed=-1, - return_logits=False): - """Run inference and post-process outputs, i.e., detokenize, - move to cpu and convert to list.""" - - # Main inference. - tokens, lengths, output_log_probs, logits = generate( - model, - prompts=prompts, - tokens_to_generate=tokens_to_generate, - return_output_log_probs=return_output_log_probs, - top_k_sampling=top_k_sampling, - top_p_sampling=top_p_sampling, - top_p_decay=top_p_decay, - top_p_bound=top_p_bound, - temperature=temperature, - add_BOS=add_BOS, - use_eod_token_for_early_termination=use_eod_token_for_early_termination, - stop_on_double_eol=stop_on_double_eol, - stop_on_eol=stop_on_eol, - prevent_newline_after_colon=prevent_newline_after_colon, - random_seed=random_seed) - - # Only post-process on first stage. - if mpu.is_pipeline_first_stage(): - tokens, prompts_plus_generations, prompts_plus_generations_segments = \ - detokenize_generations(tokens, lengths, True) - - if return_output_log_probs: - output_log_probs = output_log_probs.cpu().numpy().tolist() - for i, (prob, seg) in enumerate(zip(output_log_probs, prompts_plus_generations_segments)): - output_log_probs[i] = prob[:len(seg)-1] - - if return_logits: - assert(tokens_to_generate == 0) - assert(mpu.get_pipeline_model_parallel_world_size() == 1) - return prompts_plus_generations, prompts_plus_generations_segments, \ - output_log_probs, tokens, logits - else: - return prompts_plus_generations, prompts_plus_generations_segments, \ - output_log_probs, tokens - - return None - -def generate(model, - prompts=None, - tokens_to_generate=0, - return_output_log_probs=False, - top_k_sampling=0, - top_p_sampling=0.0, - top_p_decay=0.0, - top_p_bound=0.0, - temperature=1.0, - add_BOS=False, - use_eod_token_for_early_termination=True, - stop_on_double_eol=False, - stop_on_eol=False, - prevent_newline_after_colon=False, - random_seed=-1): - """Given prompts and input parameters, run inference and return: - tokens: prompts plus the generated tokens. - lengths: length of the prompt + generations. 
Note that we can - discard tokens in the tokens tensor that are after the - corresponding length. - output_log_probs: log probs of the tokens. - """ - - # Make sure input params are avaialble to all ranks. - values = [tokens_to_generate, - return_output_log_probs, - top_k_sampling, top_p_sampling, top_p_decay, top_p_bound, - temperature, add_BOS, use_eod_token_for_early_termination, - stop_on_double_eol, - stop_on_eol, - prevent_newline_after_colon, - random_seed] - values_float_tensor = broadcast_float_list(len(values), float_list=values) - tokens_to_generate = int(values_float_tensor[0].item()) - return_output_log_probs = bool(values_float_tensor[1].item()) - top_k_sampling = int(values_float_tensor[2].item()) - top_p_sampling = values_float_tensor[3].item() - top_p_decay = values_float_tensor[4].item() - top_p_bound = values_float_tensor[5].item() - temperature = values_float_tensor[6].item() - add_BOS = bool(values_float_tensor[7].item()) - use_eod_token_for_early_termination = bool(values_float_tensor[8].item()) - stop_on_double_eol = bool(values_float_tensor[9].item()) - stop_on_eol = bool(values_float_tensor[10].item()) - prevent_newline_after_colon = bool(values_float_tensor[11].item()) - random_seed = int(values_float_tensor[12].item()) - - if random_seed != -1: - torch.random.manual_seed(random_seed) - - # Tokenize prompts and get the batch. - # Note that these tensors are broadcaseted to all ranks. - if torch.distributed.get_rank() == 0: - assert prompts is not None - - context_tokens_tensor, context_length_tensor = tokenize_prompts( - prompts=prompts, tokens_to_generate=tokens_to_generate, add_BOS=add_BOS) - - if tokens_to_generate == 0: - return score_and_return_on_first_stage( - model, context_tokens_tensor, context_length_tensor) - - # Main inference function. - # Note that the outputs are available on the first stage. - return generate_tokens_probs_and_return_on_first_stage( - model, context_tokens_tensor, context_length_tensor, - return_output_log_probs=return_output_log_probs, - top_k=top_k_sampling, - top_p=top_p_sampling, - top_p_decay=top_p_decay, - top_p_bound=top_p_bound, - temperature=temperature, - use_eod_token_for_early_termination=use_eod_token_for_early_termination, - stop_on_double_eol=stop_on_double_eol, - stop_on_eol=stop_on_eol, - prevent_newline_after_colon=prevent_newline_after_colon) - -def beam_search_and_post_process(model, - prompts=None, - tokens_to_generate=0, - beam_size=0, - add_BOS=False, - stop_token=50256, - num_return_gen=1, - length_penalty=1, - prevent_newline_after_colon=False): - """Run beam search and post-process outputs, i.e., detokenize, - move to cpu and convert to list.""" - - # Main inference. - tokens, scores = beam_search(model, - prompts=prompts, - tokens_to_generate=tokens_to_generate, - beam_size=beam_size, - add_BOS=add_BOS, - stop_token=stop_token, - num_return_gen=num_return_gen, - length_penalty=length_penalty, - prevent_newline_after_colon=prevent_newline_after_colon) - # Only post-process on first stage. 
- if mpu.is_pipeline_first_stage(): - lengths = tokens.size(1)*torch.ones(beam_size, dtype=torch.int64, device=torch.cuda.current_device()) - tokens, prompts_plus_generations, prompts_plus_generations_segments = detokenize_generations(tokens, lengths, True) - scores = scores.cpu().numpy().tolist() - return prompts_plus_generations, prompts_plus_generations_segments, scores - - return None - -def beam_search(model, prompts=None, tokens_to_generate=0, beam_size=0, add_BOS=False, stop_token=50256, num_return_gen=1, length_penalty=1, prevent_newline_after_colon=False): - # Make sure input params are avaialble to all ranks. - values = [tokens_to_generate, - beam_size, - add_BOS, - stop_token, - num_return_gen, - length_penalty, - prevent_newline_after_colon] - values_float_tensor = broadcast_float_list(len(values), float_list=values) - tokens_to_generate = int(values_float_tensor[0].item()) - beam_size = int(values_float_tensor[1].item()) - add_BOS = bool(values_float_tensor[2].item()) - stop_token = int(values_float_tensor[3].item()) - num_return_gen = int(values_float_tensor[4].item()) - length_penalty = values_float_tensor[5].item() - prevent_newline_after_colon = values_float_tensor[6].item() - - context_tokens_tensor, context_length_tensor = tokenize_prompts( - prompts=prompts, tokens_to_generate=tokens_to_generate, add_BOS=add_BOS) - - return beam_search_and_return_on_first_stage(model, context_tokens_tensor, context_length_tensor, - beam_size, stop_token=stop_token, num_return_gen=num_return_gen, length_penalty=length_penalty, - prevent_newline_after_colon=prevent_newline_after_colon) diff --git a/megatron/text_generation/beam_utils.py b/megatron/text_generation/beam_utils.py deleted file mode 100644 index 911a64143a86c8521abd9741df22de528a82f692..0000000000000000000000000000000000000000 --- a/megatron/text_generation/beam_utils.py +++ /dev/null @@ -1,64 +0,0 @@ -# coding=utf-8 -# Copyright 2018 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team. -# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -## from huggingface beam search -class BeamHypotheses(object): - def __init__(self, num_beams, length_penalty=1.0, early_stopping=False): - """ - Initialize n-best list of hypotheses. - """ - self.length_penalty = length_penalty - self.early_stopping = early_stopping - self.num_beams = num_beams - self.beams = [] - self.worst_score = 1e9 - - def __len__(self): - """ - Number of hypotheses in the list. - """ - return len(self.beams) - - def add(self, hyp, sum_logprobs, length): - """ - Add a new hypothesis to the list. 
- """ - score = sum_logprobs / length ** self.length_penalty - if len(self) < self.num_beams or score > self.worst_score: - self.beams.append((score, hyp)) - if len(self) > self.num_beams: - sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.beams)]) - del self.beams[sorted_scores[0][1]] - self.worst_score = sorted_scores[1][0] - else: - self.worst_score = min(score, self.worst_score) - - def is_done(self, best_sum_logprobs, cur_len): - """ - If there are enough hypotheses and that none of the hypotheses being generated - can become better than the worst one in the heap, then we are done with this sentence. - """ - - if len(self) < self.num_beams: - return False - elif self.early_stopping: - return True - else: - cur_score = best_sum_logprobs / cur_len ** self.length_penalty - ret = self.worst_score >= cur_score - return ret - diff --git a/megatron/text_generation/communication.py b/megatron/text_generation/communication.py deleted file mode 100644 index dee32077f34904f7585fab0f5180a5d014f7829f..0000000000000000000000000000000000000000 --- a/megatron/text_generation/communication.py +++ /dev/null @@ -1,185 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. - -"""Communications utilities.""" - - -import torch - -from megatron.core import mpu - - - -# TODO: use functions from megatron/p2p -def recv_from_prev_pipeline_rank_(recv_buffer=None): - """Receive from previous pipeline stage and update the - input buffer inplace.""" - if not mpu.is_pipeline_first_stage(): - assert recv_buffer is not None - recv_prev_op = torch.distributed.P2POp( - torch.distributed.irecv, recv_buffer, - mpu.get_pipeline_model_parallel_prev_rank()) - reqs = torch.distributed.batch_isend_irecv([recv_prev_op]) - for req in reqs: - req.wait() - # To protect against race condition when using batch_isend_irecv(). - torch.cuda.synchronize() - - - -# TODO: use functions from megatron/p2p -def send_to_next_pipeline_rank(tensor=None): - """Send output to the next pipeline stage.""" - if not mpu.is_pipeline_last_stage(): - assert tensor is not None - send_next_op = torch.distributed.P2POp( - torch.distributed.isend, tensor, - mpu.get_pipeline_model_parallel_next_rank()) - reqs = torch.distributed.batch_isend_irecv([send_next_op]) - for req in reqs: - req.wait() - # To protect against race condition when using batch_isend_irecv(). - torch.cuda.synchronize() - - - -def _is_cuda(tensor): - """Check if a tensor is not none and is cuda.""" - assert tensor is not None - assert tensor.is_cuda - - - -def _is_cuda_contiguous(tensor): - """Check if a tensor is not none, is cuda, and is contiguous.""" - _is_cuda(tensor) - assert tensor.is_contiguous() - - - -def broadcast_from_last_pipeline_stage(size, dtype, tensor=None): - """Broadcast a tensor from last pipeline stage to all ranks.""" - - is_last_stage = mpu.is_pipeline_last_stage() - # If first stage and last state are the same, then there is no - # pipeline parallelism and no need to communicate. - if mpu.is_pipeline_first_stage() and is_last_stage: - return tensor - - if is_last_stage: - _is_cuda_contiguous(tensor) - else: - tensor = torch.empty(size, - dtype=dtype, - device=torch.cuda.current_device()) - # Get the group and corresponding source rank. 
- src = mpu.get_pipeline_model_parallel_last_rank() - group = mpu.get_pipeline_model_parallel_group() - torch.distributed.broadcast(tensor, src, group) - - return tensor - - - -def broadcast_from_last_to_first_pipeline_stage(size, dtype, tensor=None): - """Broadcast tensor values from last stage into the first stage.""" - - is_last_stage = mpu.is_pipeline_last_stage() - is_first_stage = mpu.is_pipeline_first_stage() - # If first stage and last state are the same, then there is no - # pipeline parallelism and no need to communicate. - if is_first_stage and is_last_stage: - return tensor - # Only first and last stage pipeline stages need to be involved. - if is_last_stage or is_first_stage: - if is_last_stage: - _is_cuda_contiguous(tensor) - else: - tensor = torch.empty(size, - dtype=dtype, - device=torch.cuda.current_device()) - src = mpu.get_pipeline_model_parallel_last_rank() - group = mpu.get_embedding_group() - # Broadcast from last stage into the first stage. - torch.distributed.broadcast(tensor, src, group) - else: - tensor = None - - return tensor - - - -def copy_from_last_to_first_pipeline_stage(size, dtype, tensor=None): - """Copy tensor values from last stage into the first stage. - Note that the input tensor is updated in place.""" - - is_last_stage = mpu.is_pipeline_last_stage() - is_first_stage = mpu.is_pipeline_first_stage() - # If first stage and last state are the same, then there is no - # pipeline parallelism and no need to communicate. - if is_first_stage and is_last_stage: - return - # Only first and last stage pipeline stages need to be involved. - if is_last_stage or is_first_stage: - _is_cuda(tensor) - is_contiguous = tensor.is_contiguous() - src = mpu.get_pipeline_model_parallel_last_rank() - group = mpu.get_embedding_group() - if is_contiguous: - tensor_ = tensor - else: - if is_last_stage: - tensor_ = tensor.contiguous() - else: - tensor_ = torch.empty(size, - dtype=dtype, - device=torch.cuda.current_device()) - # Broadcast from last stage into the first stage. - torch.distributed.broadcast(tensor_, src, group) - # Update the first stage tensor - if is_first_stage and not is_contiguous: - tensor[...] = tensor_ - - - -def broadcast_tensor(size, dtype, tensor=None, rank=0): - """ Given size and type of a tensor on all ranks and the tensor value - only on a specific rank, broadcast from that rank to all other ranks. 
- """ - - if torch.distributed.get_rank() == rank: - _is_cuda_contiguous(tensor) - else: - tensor = torch.empty(size, - dtype=dtype, - device=torch.cuda.current_device()) - - torch.distributed.broadcast(tensor, rank) - - return tensor - - - -def broadcast_list(size, dtype, list_values=None, rank=0): - """Broadcast a list of values with a given type.""" - - tensor = None - if torch.distributed.get_rank() == rank: - tensor = torch.tensor(list_values, dtype=dtype, - device=torch.cuda.current_device()) - - return broadcast_tensor(size, dtype, tensor=tensor, rank=rank) - - - -def broadcast_int_list(size, int_list=None, rank=0): - """Broadcast a list of interger values.""" - - return broadcast_list(size, torch.int64, list_values=int_list, rank=rank) - - - -def broadcast_float_list(size, float_list=None, rank=0): - """Broadcast a list of float values.""" - - return broadcast_list(size, torch.float32, list_values=float_list, - rank=rank) diff --git a/megatron/text_generation/forward_step.py b/megatron/text_generation/forward_step.py deleted file mode 100644 index 6a88709a521b29d40cf9f7760abe0f95756900f6..0000000000000000000000000000000000000000 --- a/megatron/text_generation/forward_step.py +++ /dev/null @@ -1,177 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. - -"""Forward step utilities.""" - -from collections.abc import Iterable - -import torch - -from megatron import get_args -from megatron.core import mpu, InferenceParams -from .communication import ( - send_to_next_pipeline_rank, - recv_from_prev_pipeline_rank_) - - -class ForwardStep: - """Forward step function with all the communications. - We use a class here to hide the inference parameters - from the outside caller.""" - - def __init__(self, model, max_batch_size, max_sequence_length): - """Set values so we don't need to do it multiple times.""" - # Make sure model is in eval mode. - assert not isinstance(model, Iterable), \ - 'interleaving schedule is not supported for inference' - model.eval() - self.model = model - # Initialize inference parameters. - self.inference_params = InferenceParams(max_batch_size, - max_sequence_length) - # Pipelining arguments. - args = get_args() - self.pipeline_size_larger_than_one = ( - args.pipeline_model_parallel_size > 1) - # Threshold of pipelining. - self.pipelining_batch_x_seqlen = \ - args.inference_batch_times_seqlen_threshold - - - def __call__(self, tokens, position_ids, attention_mask): - """Invocation of the forward methods. Note that self.inference_params - is being modified by the forward step.""" - # Pipelining case. 
- if self.pipeline_size_larger_than_one: - current_batch_x_seqlen = tokens.size(0) * tokens.size(1) - if current_batch_x_seqlen >= self.pipelining_batch_x_seqlen: - micro_batch_size = \ - max(1, self.pipelining_batch_x_seqlen // tokens.size(1)) - return _with_pipelining_forward_step(self.model, - tokens, - position_ids, - attention_mask, - self.inference_params, - micro_batch_size) - - return _no_pipelining_forward_step(self.model, - tokens, - position_ids, - attention_mask, - self.inference_params) - - - -def _get_recv_buffer_dtype(args): - """Receive happens between the layers.""" - if args.fp32_residual_connection: - return torch.float - return args.params_dtype - - - -def _allocate_recv_buffer(batch_size, sequence_length): - """Receive happens between the layers with size [s, b, h].""" - if mpu.is_pipeline_first_stage(): - return None - args = get_args() - recv_size = (sequence_length, batch_size, args.hidden_size) - return torch.empty(recv_size, - dtype=_get_recv_buffer_dtype(args), - device=torch.cuda.current_device()) - - - -def _forward_step_helper(model, tokens, position_ids, attention_mask, - inference_params, recv_buffer=None): - """Single forward step. Update the allocate memory flag so - only the first time the memory is allocated.""" - batch_size = tokens.size(0) - sequence_length = tokens.size(1) - if recv_buffer is None: - recv_buffer = _allocate_recv_buffer(batch_size, sequence_length) - - # Receive from previous stage. - recv_from_prev_pipeline_rank_(recv_buffer) - - # Forward pass through the model. - model.set_input_tensor(recv_buffer) - output_tensor = model(tokens, position_ids, attention_mask, - inference_params=inference_params) - - # Send output to the next stage. - send_to_next_pipeline_rank(output_tensor) - - return output_tensor - - - -def _no_pipelining_forward_step(model, tokens, position_ids, attention_mask, - inference_params, recv_buffer=None): - """If recv_buffer is none, we will allocate one on the fly.""" - # Run a simple forward pass. - output_tensor = _forward_step_helper(model, tokens, position_ids, - attention_mask, inference_params, - recv_buffer=recv_buffer) - # Update the sequence length offset. - inference_params.sequence_len_offset += tokens.size(1) - - logits = None - if mpu.is_pipeline_last_stage(): - logits = output_tensor - - return logits - - - -def _with_pipelining_forward_step(model, tokens, position_ids, attention_mask, - inference_params, micro_batch_size): - """No interleaving is supported.""" - sequence_length = tokens.size(1) - batch_size = tokens.size(0) - - # Divide the batch dimension into micro batches. - num_micro_batches, last_chunk = divmod(batch_size, - micro_batch_size) - if last_chunk > 0: - num_micro_batches += 1 - - # Preallocate memory for output logits. - logits = None - if mpu.is_pipeline_last_stage(): - args = get_args() - logits = torch.empty( - (batch_size, sequence_length, args.padded_vocab_size), - dtype=torch.float32, device=torch.cuda.current_device()) - - # Preallocate recv buffer. - recv_buffer = _allocate_recv_buffer(micro_batch_size, sequence_length) - - for micro_batch_index in range(num_micro_batches): - # Slice among the batch dimenion. - start = micro_batch_index * micro_batch_size - end = min(start + micro_batch_size, batch_size) - this_micro_batch_size = end - start - tokens2use = tokens[start:end, ...] - position_ids2use = position_ids[start:end, ...] - - # Run a simple forward pass. 
- if this_micro_batch_size != micro_batch_size: - recv_buffer = None - output = _forward_step_helper(model, tokens2use, position_ids2use, - attention_mask, inference_params, - recv_buffer=recv_buffer) - - # Adjust the batch size offset to account for the micro-batch. - inference_params.batch_size_offset += this_micro_batch_size - - # Copy logits. - if mpu.is_pipeline_last_stage(): - logits[start:end, ...] = output - - # Once we are done with all the micro-batches, we can - # adjust the sequence length offset. - inference_params.sequence_len_offset += sequence_length - # and reset the batch size offset - inference_params.batch_size_offset = 0 - - return logits diff --git a/megatron/text_generation/generation.py b/megatron/text_generation/generation.py deleted file mode 100644 index 11dd9f436b5992739f4d0543b2108317edfd46b1..0000000000000000000000000000000000000000 --- a/megatron/text_generation/generation.py +++ /dev/null @@ -1,428 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. - -"""Generation utilities.""" - -import torch -import torch.nn.functional as F - -from megatron import get_args, get_tokenizer -from megatron.core import mpu -from megatron.utils import get_ltor_masks_and_position_ids -from .communication import ( - copy_from_last_to_first_pipeline_stage, - broadcast_from_last_pipeline_stage, - broadcast_from_last_to_first_pipeline_stage) -from .forward_step import ForwardStep -from .sampling import sample -from .beam_utils import BeamHypotheses - -def score_and_return_on_first_stage(model, tokens, lengths): - """Function for just scoring. - Arguments: - model: no interleaving is supported. - tokens: prompt tokens extended to be of size [b, max_prompt_length] - lengths: original prompt length, size: [b] - Note: Outside of model, other parameters only need to be available on - rank 0. - Outputs: - output_log_probs: log probability of the selected tokens. size: [b, s] - """ - - args = get_args() - - batch_size = tokens.size(0) - max_prompt_length = lengths.max().item() - assert max_prompt_length == tokens.size(1) - - if max_prompt_length > args.max_position_embeddings: - raise ValueError("Length of prompt + tokens_to_generate longer than allowed") - - if max_prompt_length * batch_size > args.max_tokens_to_oom: - raise ValueError("Too many tokens. " + str(max_prompt_length*batch_size)+ " is greater than "+str(args.max_tokens_to_oom)) - - # forward step. - forward_step = ForwardStep(model, batch_size, max_prompt_length) - - # =================== - # Pre-allocate memory - # =================== - - # Log probability of the sequence (prompt + generated tokens). - output_log_probs = None - output_log_probs_size = (batch_size, max_prompt_length - 1) - - if mpu.is_pipeline_last_stage(): - output_log_probs = torch.empty(output_log_probs_size, - dtype=torch.float32, - device=torch.cuda.current_device()) - - # ============= - # Run infernece - # ============= - with torch.no_grad(): - attention_mask, position_ids = _build_attention_mask_and_position_ids(tokens) - - # logits will be meanigful only in the last pipeline stage. - logits = forward_step(tokens, position_ids, attention_mask) - - if mpu.is_pipeline_last_stage(): - # Always the last stage should have an output. - assert logits is not None - log_probs = F.log_softmax(logits, dim=2) - - # Pick the tokens that we need to get the log - # probabilities for. Note that next input token is - # the token which we selected in the current logits, - # so shift by 1. 
- indices = torch.unsqueeze(tokens[:, 1:], 2) - output_log_probs = torch.gather(log_probs, 2, indices).squeeze(2) - - # ====================================== - # Broadcast to the first pipeline stage. - # ====================================== - output_log_probs = broadcast_from_last_to_first_pipeline_stage( - output_log_probs_size, torch.float32, output_log_probs) - - return tokens, lengths, output_log_probs, logits - -def generate_tokens_probs_and_return_on_first_stage( - model, tokens, lengths, - return_output_log_probs=False, - top_k=0, top_p=0.0, top_p_decay=0.0, top_p_bound=0.0, - temperature=1.0, - use_eod_token_for_early_termination=True, - stop_on_double_eol=False, - stop_on_eol=False, - prevent_newline_after_colon=True - ): - """Main token generation function. - Arguments: - model: no interleaving is supported. - tokens: prompt tokens extended to be of size [b, max-sequence-length] - lengths: original prompt length, size: [b] - return_output_log_probs: flag to calculate the log probability of - the generated tokens. Note that the log probability is the one - from the original logit. - top_k, top_p: top-k and top-p sampling parameters. - Note that top-k = 1 is gready. Also, these paramters are - exclusive meaning that: - if top-k > 0 then we expect top-p=0. - if top-p > 0 then we check for top-k=0. - temperature: sampling temperature. - use_eod_token_for_early_termination: if True, do early termination if - all the sequences have reached this token. - prevent_newline_after_colon: if True, it will disable generating new line \n after : - Note: Outside of model, other parameters only need to be available on - rank 0. - Outputs: Note that is size is adjusted to a lower value than - max-sequence-length if generation is terminated early. - tokens: prompt and generated tokens. size: [b, :] - generated_sequence_lengths: total length (including prompt) of - the generated sequence. size: [b] - output_log_probs: log probability of the selected tokens. size: [b, s] - """ - - args = get_args() - tokenizer = get_tokenizer() - - batch_size = tokens.size(0) - min_prompt_length = lengths.min().item() - max_sequence_length = tokens.size(1) - - if max_sequence_length > args.max_position_embeddings: - raise ValueError("Length of prompt + tokens_to_generate longer than allowed") - - if max_sequence_length * batch_size > args.max_tokens_to_oom: - raise ValueError("Too many tokens. " + str(max_sequence_length*batch_size)+ " is greater than "+str(args.max_tokens_to_oom)) - - # forward step. - forward_step = ForwardStep(model, batch_size, max_sequence_length) - - # Added termination_id to support the case that we want to terminate the - # generation once that id is generated. - if hasattr(args, 'eos_id'): - termination_id = args.eos_id - else: - termination_id = tokenizer.eod - - # =================== - # Pre-allocate memory - # =================== - - # Log probability of the sequence (prompt + generated tokens). - output_log_probs = None - output_log_probs_size = (batch_size, max_sequence_length - 1) - # Lengths of generated seuquence including including prompts. - generated_sequence_lengths = None - if mpu.is_pipeline_last_stage(): - if return_output_log_probs: - output_log_probs = torch.empty(output_log_probs_size, - dtype=torch.float32, - device=torch.cuda.current_device()) - generated_sequence_lengths = torch.ones( - batch_size, dtype=torch.int64, - device=torch.cuda.current_device()) * max_sequence_length - - # Whether we have reached a termination id. 
- is_generation_done = torch.zeros(batch_size, dtype=torch.uint8, - device=torch.cuda.current_device()) - - # ============= - # Run infernece - # ============= - - with torch.no_grad(): - attention_mask, position_ids = _build_attention_mask_and_position_ids( - tokens) - prev_context_length = 0 - for context_length in range(min_prompt_length, max_sequence_length): - - # Pick the slice that we need to pass through the network. - tokens2use = tokens[:, prev_context_length:context_length] - positions2use = position_ids[:, prev_context_length:context_length] - attention_mask2use = attention_mask[ - ..., prev_context_length:context_length, :context_length] - - # logits will be meanigful only in the last pipeline stage. - logits = forward_step(tokens2use, positions2use, attention_mask2use) - - if mpu.is_pipeline_last_stage(): - if prevent_newline_after_colon: - logits[tokens2use[:, -1] == tokenizer.tokenize(':')[0], -1, tokenizer.tokenize('\n')[0]] = -1e10 # disable "\n" after ":" - # Always the last stage should have an output. - assert logits is not None - - # Sample. - last_token_logits = logits[:, -1, :] - new_sample = sample(last_token_logits, - top_k=top_k, - top_p=top_p, - temperature=temperature, - vocab_size=tokenizer.vocab_size) - if top_p > 0.0 and top_p_decay > 0.0: - top_p = top_p * top_p_decay - if top_p_bound > 0.0: - top_p = max(top_p, top_p_bound) - - # If a prompt length is smaller or equal th current context - # length, it means we have started generating tokens - started = lengths <= context_length - # Update the tokens. - tokens[started, context_length] = new_sample[started] - - # Calculate the log probabilities. - if return_output_log_probs: - log_probs = F.log_softmax(logits, dim=2) - if return_output_log_probs: - # Pick the tokens that we need to get the log - # probabilities for. Note that next input token is - # the token which we selected in the current logits, - # so shift by 1. - indices = torch.unsqueeze( - tokens[ - :, - (prev_context_length + 1):(context_length + 1)], - 2) - output_log_probs[:, - prev_context_length:context_length] = \ - torch.gather(log_probs, 2, indices).squeeze(2) - - # Update the tokens on the first stage so the next input to - # the network is correct. - copy_from_last_to_first_pipeline_stage(batch_size, torch.int64, - tokens[:, context_length]) - - # Update the context length for the next token generation. - prev_context_length = context_length - - # Check if all the sequences have hit the termination_id. 
- done = None - if mpu.is_pipeline_last_stage(): - # TODO(rprenger) These stopping methods are tokenizer dependent - # instead tokenization should be in the inference loop so stop sequences can be used - if stop_on_double_eol: - hit_double_eol = (new_sample == 628).byte() & started.byte() - hit_two_eols = (new_sample == 198).byte() & (tokens[:, context_length-1] == 198).byte() & started.byte() - done_token = hit_double_eol | hit_two_eols - elif stop_on_eol: - hit_double_eol = (new_sample == 628).byte() & started.byte() - hit_eol = (new_sample == 198).byte() & started.byte() - done_token = hit_double_eol | hit_eol - else: - done_token = (new_sample == termination_id).byte() & \ - started.byte() - - just_finished = (done_token & ~is_generation_done).bool() - generated_sequence_lengths[just_finished.view(-1)] = \ - context_length + 1 - is_generation_done = is_generation_done | done_token - done = torch.all(is_generation_done) - done = broadcast_from_last_pipeline_stage(1, torch.uint8, - tensor=done) - if use_eod_token_for_early_termination and done: - break - - # =================================================== - # Update the length of based on max generated length. - # =================================================== - - tokens = tokens[:, :(context_length + 1)] - if mpu.is_pipeline_last_stage(): - if return_output_log_probs: - output_log_probs = output_log_probs[:, :context_length] - - # ====================================== - # Broadcast to the first pipeline stage. - # ====================================== - - generated_sequence_lengths = broadcast_from_last_to_first_pipeline_stage( - batch_size, torch.int64, generated_sequence_lengths) - if return_output_log_probs: - output_log_probs_size = (batch_size, context_length) - output_log_probs = broadcast_from_last_to_first_pipeline_stage( - output_log_probs_size, torch.float32, output_log_probs) - - return tokens, generated_sequence_lengths, output_log_probs, None - -def beam_search_and_return_on_first_stage(model, tokens, lengths, beam_size, stop_token, num_return_gen, length_penalty, prevent_newline_after_colon=True): - args = get_args() - tokenizer = get_tokenizer() - - batch_size = tokens.size(0) - assert(batch_size == 1) - prompt_length = lengths.item() - final_sequence_length = tokens.size(1) - final_sequence_length = min(final_sequence_length, args.max_position_embeddings) - - # If the context is too big, this happens - if prompt_length >= final_sequence_length: - raise ValueError("context length + tokens_to_generate too large") - - # forward step. - forward_step = ForwardStep(model, beam_size, final_sequence_length) - - beam_hyp = BeamHypotheses(beam_size, length_penalty) - best_batches = None - done = torch.zeros(1, dtype=torch.uint8, device=torch.cuda.current_device()) - scores = torch.zeros(beam_size, - dtype=torch.float32, - device=torch.cuda.current_device()).unsqueeze(1) - scores_size_tensor, tokens_size_tensor = None, None - # ============= - # Run infernece - # ============= - with torch.no_grad(): - tokens = tokens.repeat(beam_size, 1) - attention_mask, position_ids = _build_attention_mask_and_position_ids(tokens) - prev_context_length = 0 - for context_length in range(prompt_length, final_sequence_length): - - # Pick the slice that we need to pass through the network. 
- tokens2use = tokens[:, prev_context_length:context_length] - positions2use = position_ids[:, prev_context_length:context_length] - attention_mask2use = attention_mask[ - ..., prev_context_length:context_length, :context_length] - - # logits will be meanigful only in the last pipeline stage. - logits = forward_step(tokens2use, positions2use, attention_mask2use) - - if mpu.is_pipeline_last_stage(): - if prevent_newline_after_colon: - logits[tokens2use[:, -1] == tokenizer.tokenize(':')[0], -1, tokenizer.tokenize('\n')[0]] = -1e10 # disable "\n" after ":" - vocab_size = logits.size(2) - log_probs = F.log_softmax(logits, dim=2) - new_scores = log_probs[:, -1, :] + scores - - if context_length == prompt_length: # if this is the first one - sorted_scores, indices = torch.sort(new_scores[0,:], descending=True) - else: - sorted_scores, indices = torch.sort(new_scores.view(-1), descending=True) - - best_beam_ids = torch.div(indices[: 2 * beam_size], vocab_size).trunc().long() - best_words = indices[:2 * beam_size] % vocab_size - best_scores = sorted_scores[: 2 * beam_size] - - next_beams = [] - for beam_token_rank, (token_id, beam_score, beam_id) in enumerate( - zip(best_words, best_scores, best_beam_ids) - ): - if token_id.item() == stop_token: - # if beam_token does not belong to top num_beams tokens, it should not be added - is_beam_token_worse_than_top_num_beams = beam_token_rank >= beam_size - if is_beam_token_worse_than_top_num_beams: - continue - beam_hyp.add( - tokens[beam_id].clone(), - beam_score, - context_length + 1 - prompt_length - ) - else: - # add next predicted token since it is not eos_token - next_beams.append((token_id, beam_score, beam_id)) - - if len(next_beams) == beam_size: - break - - if beam_hyp.is_done(best_scores.max().item(), context_length + 1 - prompt_length): - done = torch.ones(1, dtype=torch.uint8, device=torch.cuda.current_device()) - - best_batches = tokens.new([item[2] for item in next_beams]) - tokens = tokens[best_batches,:] - tokens[:, context_length] = tokens.new([item[0] for item in next_beams]) - scores = scores.new([item[1] for item in next_beams]).unsqueeze(1) - - # torch.distributed.barrier() - done = broadcast_from_last_pipeline_stage(1, torch.uint8, done) - if done: - break - - # Update the tokens on the first stage so the next input to - # the network is correct. - copy_from_last_to_first_pipeline_stage(tokens.size(), torch.int64, - tokens) - - # set inference key values to make it consistent with best beam index - best_batches = broadcast_from_last_pipeline_stage(beam_size, torch.int64, best_batches) - forward_step.inference_params.swap_key_value_dict(best_batches) - - # Update the context length for the next token generation. 
- prev_context_length = context_length - - if mpu.is_pipeline_last_stage(): - # if cannot find stop token, add open beams to hyps - if not done: - for beam_id in range(beam_size): - beam_hyp.add(tokens[beam_id].clone(), scores[beam_id].squeeze(), context_length + 1 - prompt_length) - - # rank based on scores - sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0], reverse=True) - num_return_gen = min(num_return_gen, len(sorted_hyps)) - scores = [sorted_hyps[i][0] for i in range(num_return_gen)] - tokens = [sorted_hyps[i][1] for i in range(num_return_gen)] - scores = torch.stack(scores, dim=0) - tokens = torch.stack(tokens, dim=0) - scores_size_tensor = torch.tensor(scores.shape, dtype=torch.int64, device=torch.cuda.current_device()) - tokens_size_tensor = torch.tensor(tokens.shape, dtype=torch.int64, device=torch.cuda.current_device()) - - scores_size_tensor = broadcast_from_last_pipeline_stage(1, torch.int64, scores_size_tensor) - tokens_size_tensor = broadcast_from_last_pipeline_stage(2, torch.int64, tokens_size_tensor) - - scores = broadcast_from_last_to_first_pipeline_stage(tuple(scores_size_tensor), torch.float32, scores) - tokens = broadcast_from_last_to_first_pipeline_stage(tuple(tokens_size_tensor), torch.int64, tokens) - - return tokens, scores - - -def _build_attention_mask_and_position_ids(tokens): - """Build the attention mask and postition ids for the input tokens.""" - - # Since we are not interested in loss-mask and reset attention/position - # is also False, eod_token is not used so it is safe to set it to None. - attention_mask, _, position_ids = get_ltor_masks_and_position_ids( - data=tokens, - eod_token=None, - reset_position_ids=False, - reset_attention_mask=False, - eod_mask_loss=False) - - return attention_mask, position_ids diff --git a/megatron/text_generation/sampling.py b/megatron/text_generation/sampling.py deleted file mode 100644 index 370773a36c087d01e75731e38724cfb35d4acd74..0000000000000000000000000000000000000000 --- a/megatron/text_generation/sampling.py +++ /dev/null @@ -1,93 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. - -"""Sampling utilities. -Part of this code is inspired by: - - https://github.com/ari-holtzman/degen/blob/master/gen.py - - https://huggingface.co/transformers/_modules/transformers/generation_logits_process.html -""" - - -import torch - - - -def modify_logits_for_top_k_filtering(logits, top_k): - """Set the logits for none top-k values to -inf.""" - - filter_ = logits < torch.topk(logits, top_k)[0][..., -1, None] - logits.masked_fill_(filter_, float('-Inf')) - - - -def modify_logits_for_top_p_filtering(logits, top_p): - """Set the logits for none top-p values to -inf.""" - - # First sort and calculate cumulative sum of probabilities. - sorted_logits, sorted_indices = torch.sort(logits, descending=True) - cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1) - - # Filteration based on the cumulative sum. - filter_ = cumulative_probs > top_p - # This shift by 1 is weird and I cannot justify it. This existed - # in the original implementation: - # https://github.com/ari-holtzman/degen/blob/master/gen.py - # and I guess it is needed so keeping it for now. - filter_[:, 1:] = filter_[:, :-1].clone() - # Make sure we at least have one token to select from. 
- filter_[..., 0] = 0 - - # Fill in the filtered part - filter_ = filter_.scatter(1, sorted_indices, filter_) - logits.masked_fill_(filter_, float('-Inf')) - - - -def sample(logits, top_k=0, top_p=0.0, temperature=1.0, vocab_size=None): - """ Sample and generate a token. - Note: logits has the dimension [b, v] where b is the batch size - and v is the vocabulary size. - If vocab_size is provided, we will make sure the sample that is - generated is in [0, vocab-size). This will avoid out of vocabulary - generations due to padding. - """ - - # Check logits for consistency. - assert logits.ndim == 2, 'expected the logits to be of [b, v] shape.' - assert logits.type() == 'torch.cuda.FloatTensor', \ - 'input logits should be floats.' - - - # Greedy is just simple argmax. - if top_k == 1: - assert top_p == 0.0, 'cannot set both greedy and top-p samplings.' - samples = torch.argmax(logits, dim=-1) - - # Top-k or top-p sampling. - else: - # Clone so we do not modify the inputs, - logits = logits.clone() - # Apply temperature in place. - if temperature != 1.0: - logits.div_(temperature) - - if top_k > 1: - assert top_p == 0.0, 'cannot set both top-k and top-p samplings.' - assert top_k <= logits.size(1), 'top-k is larger than logit size.' - if vocab_size: - assert top_k < vocab_size, 'top-k is larger than vocab size.' - modify_logits_for_top_k_filtering(logits, top_k) - - elif top_p > 0.0: - assert top_p <= 1.0, 'top-p should be in (0, 1].' - modify_logits_for_top_p_filtering(logits, top_p) - - # After filtering, we need to recalculate the distribution. - probs = logits.softmax(dim=-1) - samples = torch.multinomial(probs, num_samples=1).view(-1) - - # If vocab size is provided, make sure the samples are in - # in the range [0, vocab-size). - if vocab_size: - samples = torch.clamp(samples, min=0, max=(vocab_size - 1)) - - return samples diff --git a/megatron/text_generation/tokenization.py b/megatron/text_generation/tokenization.py deleted file mode 100644 index 4d4eb82e8049e5bdcd5aa84272f21941bef79e5a..0000000000000000000000000000000000000000 --- a/megatron/text_generation/tokenization.py +++ /dev/null @@ -1,125 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
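For readers who want to sanity-check the masking behaviour of the sampling module deleted above, the toy example below mirrors its top-k and top-p (nucleus) filtering on a hand-written logits row. It is a CPU-only illustration with made-up values, not a drop-in replacement.

```python
# Toy illustration of top-k and top-p (nucleus) masking, mirroring the logic of
# the deleted megatron/text_generation/sampling.py. Values are made up.
import torch

def top_k_mask(logits, k):
    # Keep only the k largest logits per row; everything else becomes -inf.
    kth_value = torch.topk(logits, k)[0][..., -1, None]
    return logits.masked_fill(logits < kth_value, float("-inf"))

def top_p_mask(logits, p):
    # Keep the smallest prefix of sorted tokens whose cumulative probability exceeds p.
    sorted_logits, sorted_idx = torch.sort(logits, descending=True)
    cum_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
    remove = cum_probs > p
    remove[:, 1:] = remove[:, :-1].clone()  # shift by one so the boundary token is kept
    remove[:, 0] = False                    # always keep the most likely token
    remove = remove.scatter(1, sorted_idx, remove)
    return logits.masked_fill(remove, float("-inf"))

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0, -3.0]])
print(top_k_mask(logits, k=2))                    # only the two largest survive
print(top_p_mask(logits, p=0.8).softmax(dim=-1))  # renormalized nucleus distribution
```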
- -"""Tokenization utilities.""" - - -import torch - - -from megatron import get_tokenizer, get_args -from .communication import broadcast_int_list, broadcast_tensor - - -def detokenize_generations(tokens_gpu_tensor, - lengths_gpu_tensor, - return_segments): - """Detokenize the generated tokens.""" - - tokenizer = get_tokenizer() - args = get_args() - prompts_plus_generations = [] - if return_segments: - prompts_plus_generations_segments = [] - - tokens = tokens_gpu_tensor.cpu().numpy().tolist() - lengths = lengths_gpu_tensor.cpu().numpy().tolist() - for sequence_tokens, length in zip(tokens, lengths): - sequence_tokens = sequence_tokens[:length] - prompts_plus_generations.append( - tokenizer.detokenize(sequence_tokens)) - if return_segments: - words = [] - for token in sequence_tokens: - if args.tokenizer_type in ['SentencePieceTokenizer', - 'GPTSentencePieceTokenizer', - 'Llama2Tokenizer']: - word = tokenizer.decoder[token] - elif args.tokenizer_type == 'NullTokenizer': - word = str(token) - else: - word = tokenizer.tokenizer.decoder[token] - word = bytearray( - [tokenizer.tokenizer.byte_decoder[c] for c in word]).decode( - 'utf-8', errors='replace') - words.append(word) - prompts_plus_generations_segments.append(words) - - if return_segments: - return tokens, prompts_plus_generations, \ - prompts_plus_generations_segments - - return tokens, prompts_plus_generations - - -def tokenize_prompts(prompts=None, tokens_to_generate=None, - add_BOS=None, rank=0): - """Tokenize prompts and make them avaiable on all ranks.""" - - # On all ranks set to None so we can pass them to functions - sizes_list = None - prompts_tokens_cuda_long_tensor = None - prompts_length_cuda_long_tensor = None - - # On the specified rank, build the above. - if torch.distributed.get_rank() == rank: - assert prompts is not None - assert tokens_to_generate is not None - # Tensor of tokens padded and their unpadded length. - prompts_tokens_cuda_long_tensor, prompts_length_cuda_long_tensor = \ - _tokenize_prompts_and_batch(prompts, tokens_to_generate, add_BOS) - # We need the sizes of these tensors for the boradcast - sizes_list = [prompts_tokens_cuda_long_tensor.size(0), # Batch size - prompts_tokens_cuda_long_tensor.size(1)] # Sequence lenght - - # First, broadcast the sizes. - sizes_tensor = broadcast_int_list(2, int_list=sizes_list, rank=rank) - - # Now that we have the sizes, we can boradcast the tokens - # and length tensors. - sizes = sizes_tensor.tolist() - prompts_tokens_cuda_long_tensor = broadcast_tensor( - sizes, torch.int64, tensor=prompts_tokens_cuda_long_tensor, rank=rank) - prompts_length_cuda_long_tensor = broadcast_tensor( - sizes[0], torch.int64, tensor=prompts_length_cuda_long_tensor, - rank=rank) - - return prompts_tokens_cuda_long_tensor, prompts_length_cuda_long_tensor - - -def _tokenize_prompts_and_batch(prompts, tokens_to_generate, add_BOS): - """Given a set of prompts and number of tokens to generate: - - tokenize prompts - - set the sequence length to be the max of length of prompts - plus the number of tokens we would like to generate - - pad all the sequences to this length so we can convert them - into a 2D tensor. - """ - - # Tokenize all the prompts. - tokenizer = get_tokenizer() - if add_BOS: - prompts_tokens = [[tokenizer.eod] + tokenizer.tokenize(prompt) - for prompt in prompts] - else: - prompts_tokens = [tokenizer.tokenize(prompt) for prompt in prompts] - - # Now we have a list of list of tokens which each list has a different - # size. 
We want to extend this list to: - # - incorporate the tokens that need to be generated - # - make all the sequences equal length. - # Get the prompts length. - prompts_length = [len(prompt_tokens) for prompt_tokens in prompts_tokens] - # Get the max prompts length. - max_prompt_len = max(prompts_length) - # Number of tokens in the each sample of the batch. - samples_length = max_prompt_len + tokens_to_generate - # Now update the list of list to be of the same size: samples_length. - for prompt_tokens, prompt_length in zip(prompts_tokens, prompts_length): - padding_size = samples_length - prompt_length - prompt_tokens.extend([tokenizer.eod] * padding_size) - - # Now we are in a structured format, we can convert to tensors. - prompts_tokens_tensor = torch.cuda.LongTensor(prompts_tokens) - prompts_length_tensor = torch.cuda.LongTensor(prompts_length) - - return prompts_tokens_tensor, prompts_length_tensor diff --git a/megatron/text_generation_server.py b/megatron/text_generation_server.py deleted file mode 100644 index 8bd6c26fcc5bc0441bc50680d96e550030f0e964..0000000000000000000000000000000000000000 --- a/megatron/text_generation_server.py +++ /dev/null @@ -1,241 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. -import datetime -import torch -import json -import threading -from flask import Flask, request, jsonify, current_app -from flask_restful import Resource, Api -from megatron import get_args -from megatron.text_generation import generate_and_post_process -from megatron.text_generation import beam_search_and_post_process - - -GENERATE_NUM = 0 -BEAM_NUM = 1 -lock = threading.Lock() - -class MegatronGenerate(Resource): - def __init__(self, model): - self.model = model - - @staticmethod - def send_do_generate(): - choice = torch.cuda.LongTensor([GENERATE_NUM]) - torch.distributed.broadcast(choice, 0) - - @staticmethod - def send_do_beam_search(): - choice = torch.cuda.LongTensor([BEAM_NUM]) - torch.distributed.broadcast(choice, 0) - - def put(self): - args = get_args() - - if not "prompts" in request.get_json(): - return "prompts argument required", 400 - - if "max_len" in request.get_json(): - return "max_len is no longer used. Replace with tokens_to_generate", 400 - - if "sentences" in request.get_json(): - return "sentences is no longer used. Replace with prompts", 400 - - prompts = request.get_json()["prompts"] - if not isinstance(prompts, list): - return "prompts is not a list of strings", 400 - - if len(prompts) == 0: - return "prompts is empty", 400 - - if len(prompts) > 128: - return "Maximum number of prompts is 128", 400 - - tokens_to_generate = 64 # Choosing hopefully sane default. 
Full sequence is slow - if "tokens_to_generate" in request.get_json(): - tokens_to_generate = request.get_json()["tokens_to_generate"] - if not isinstance(tokens_to_generate, int): - return "tokens_to_generate must be an integer greater than 0" - if tokens_to_generate < 0: - return "tokens_to_generate must be an integer greater than or equal to 0" - - logprobs = False - if "logprobs" in request.get_json(): - logprobs = request.get_json()["logprobs"] - if not isinstance(logprobs, bool): - return "logprobs must be a boolean value" - - if tokens_to_generate == 0 and not logprobs: - return "tokens_to_generate=0 implies logprobs should be True" - - temperature = 1.0 - if "temperature" in request.get_json(): - temperature = request.get_json()["temperature"] - if not (type(temperature) == int or type(temperature) == float): - return "temperature must be a positive number less than or equal to 100.0" - if not (0.0 < temperature <= 100.0): - return "temperature must be a positive number less than or equal to 100.0" - - top_k = 0.0 - if "top_k" in request.get_json(): - top_k = request.get_json()["top_k"] - if not (type(top_k) == int): - return "top_k must be an integer equal to or greater than 0 and less than or equal to 1000" - if not (0 <= top_k <= 1000): - return "top_k must be equal to or greater than 0 and less than or equal to 1000" - - top_p = 0.0 - if "top_p" in request.get_json(): - top_p = request.get_json()["top_p"] - if not (type(top_p) == float): - return "top_p must be a positive float less than or equal to 1.0" - if top_p > 0.0 and top_k > 0.0: - return "cannot set both top-k and top-p samplings." - if not (0 <= top_p <= 1.0): - return "top_p must be less than or equal to 1.0" - - top_p_decay = 0.0 - if "top_p_decay" in request.get_json(): - top_p_decay = request.get_json()["top_p_decay"] - if not (type(top_p_decay) == float): - return "top_p_decay must be a positive float less than or equal to 1.0" - if top_p == 0.0: - return "top_p_decay cannot be set without top_p" - if not (0 <= top_p_decay <= 1.0): - return "top_p_decay must be less than or equal to 1.0" - - top_p_bound = 0.0 - if "top_p_bound" in request.get_json(): - top_p_bound = request.get_json()["top_p_bound"] - if not (type(top_p_bound) == float): - return "top_p_bound must be a positive float less than or equal to top_p" - if top_p == 0.0: - return "top_p_bound cannot be set without top_p" - if not (0.0 < top_p_bound <= top_p): - return "top_p_bound must be greater than 0 and less than top_p" - - add_BOS = False - if "add_BOS" in request.get_json(): - add_BOS = request.get_json()["add_BOS"] - if not isinstance(add_BOS, bool): - return "add_BOS must be a boolean value" - - if any([len(prompt) == 0 for prompt in prompts]) and not add_BOS: - return "Empty prompts require add_BOS=true" - - stop_on_double_eol = False - if "stop_on_double_eol" in request.get_json(): - stop_on_double_eol = request.get_json()["stop_on_double_eol"] - if not isinstance(stop_on_double_eol, bool): - return "stop_on_double_eol must be a boolean value" - - stop_on_eol = False - if "stop_on_eol" in request.get_json(): - stop_on_eol = request.get_json()["stop_on_eol"] - if not isinstance(stop_on_eol, bool): - return "stop_on_eol must be a boolean value" - - prevent_newline_after_colon = False - if "prevent_newline_after_colon" in request.get_json(): - prevent_newline_after_colon = request.get_json()["prevent_newline_after_colon"] - if not isinstance(prevent_newline_after_colon, bool): - return "prevent_newline_after_colon must be a boolean value" - - 
random_seed = -1 - if "random_seed" in request.get_json(): - random_seed = request.get_json()["random_seed"] - if not isinstance(random_seed, int): - return "random_seed must be integer" - if random_seed < 0: - return "random_seed must be a positive integer" - - no_log = False - if "no_log" in request.get_json(): - no_log = request.get_json()["no_log"] - if not isinstance(no_log, bool): - return "no_log must be a boolean value" - - beam_width = None - if "beam_width" in request.get_json(): - beam_width = request.get_json()["beam_width"] - if not isinstance(beam_width, int): - return "beam_width must be integer" - if beam_width < 1: - return "beam_width must be an integer > 1" - if len(prompts) > 1: - return "When doing beam_search, batch size must be 1" - - stop_token=50256 - if "stop_token" in request.get_json(): - stop_token = request.get_json()["stop_token"] - if not isinstance(stop_token, int): - return "stop_token must be an integer" - - length_penalty = 1 - if "length_penalty" in request.get_json(): - length_penalty = request.get_json()["length_penalty"] - if not isinstance(length_penalty, float): - return "length_penalty must be a float" - - with lock: # Need to get lock to keep multiple threads from hitting code - - if not no_log: - print("request IP: " + str(request.remote_addr)) - print(json.dumps(request.get_json()),flush=True) - print("start time: ", datetime.datetime.now()) - - try: - if beam_width is not None: - MegatronGenerate.send_do_beam_search() # Tell other ranks we're doing beam_search - response, response_seg, response_scores = \ - beam_search_and_post_process( - self.model, - prompts=prompts, - tokens_to_generate=tokens_to_generate, - beam_size = beam_width, - add_BOS=add_BOS, - stop_token=stop_token, - num_return_gen=beam_width, # Returning whole beam - length_penalty=length_penalty, - prevent_newline_after_colon=prevent_newline_after_colon - ) - - return jsonify({"text": response, - "segments": response_seg, - "scores": response_scores}) - else: - MegatronGenerate.send_do_generate() # Tell other ranks we're doing generate - response, response_seg, response_logprobs, _ = \ - generate_and_post_process( - self.model, - prompts=prompts, - tokens_to_generate=tokens_to_generate, - return_output_log_probs=logprobs, - top_k_sampling=top_k, - top_p_sampling=top_p, - top_p_decay=top_p_decay, - top_p_bound=top_p_bound, - temperature=temperature, - add_BOS=add_BOS, - use_eod_token_for_early_termination=True, - stop_on_double_eol=stop_on_double_eol, - stop_on_eol=stop_on_eol, - prevent_newline_after_colon=prevent_newline_after_colon, - random_seed=random_seed) - - return jsonify({"text": response, - "segments": response_seg, - "logprobs": response_logprobs}) - - except ValueError as ve: - return ve.args[0] - print("end time: ", datetime.datetime.now()) - - -class MegatronServer(object): - def __init__(self, model): - self.app = Flask(__name__, static_url_path='') - api = Api(self.app) - api.add_resource(MegatronGenerate, '/api', resource_class_args=[model]) - - def run(self, url, port): - self.app.run(url, threaded=True, debug=False, port=port) diff --git a/megatron/theoretical_memory_usage.py b/megatron/theoretical_memory_usage.py deleted file mode 100644 index 1a6fb6b5b313dc572ed241cfa6db157bc6784d54..0000000000000000000000000000000000000000 --- a/megatron/theoretical_memory_usage.py +++ /dev/null @@ -1,159 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 
- -"""Computes theoretical memory footprint for model training.""" - - -import math - - -NUM_BYTES_IN_MEGABYTE = 1024 * 1024 - - -def compute_weight_and_optimizer_memory(args, verbose=False): - if not args.group_query_attention: - args.num_query_groups = args.num_attention_heads - num_parameters_in_transformer_layers = ( - 10 - * args.num_layers - * args.hidden_size - * args.hidden_size - * ( - 1 - + (args.num_query_groups / (5.0 * args.num_attention_heads)) - + (2 / (5 * args.hidden_size)) - + (1 / (5 * args.num_layers * args.hidden_size)) - ) - ) - embedding_size = args.hidden_size * args.padded_vocab_size - if args.untie_embeddings_and_output_weights: - num_total_parameters_with_embeddings = num_parameters_in_transformer_layers + ( - 2 * embedding_size - ) - else: - num_total_parameters_with_embeddings = num_parameters_in_transformer_layers + embedding_size - if verbose: - print( - f"Number of parameters in billions: {num_total_parameters_with_embeddings / 10**9:.2f}" - ) - - # Most loaded model shard has (1/pp_size transformer layers + 1 embedding layer) / tp_size. - num_parameters_on_most_loaded_model_shard = ( - (num_parameters_in_transformer_layers / args.pipeline_model_parallel_size) + embedding_size - ) / args.tensor_model_parallel_size - if args.untie_embeddings_and_output_weights and args.pipeline_model_parallel_size == 1: - num_parameters_on_most_loaded_model_shard += ( - embedding_size / args.tensor_model_parallel_size - ) - if verbose: - print( - f"Number of parameters in most loaded shard in billions: {num_parameters_on_most_loaded_model_shard / 10**9:.4f}" - ) - - if args.pipeline_model_parallel_size > 1: - # Other shards just have (1/pp_size transformer layers) / tp_size. - num_parameters_on_other_model_shards = num_parameters_in_transformer_layers / ( - args.pipeline_model_parallel_size * args.tensor_model_parallel_size - ) - if verbose: - print( - f"Number of parameters in other shards in billions: {num_parameters_on_other_model_shards / 10**9:.4f}" - ) - - num_bytes_per_parameter = ( - 18 if not args.use_distributed_optimizer else 6 + (12 / args.data_parallel_size) - ) - weight_and_optimizer_memory = ( - num_parameters_on_most_loaded_model_shard * num_bytes_per_parameter - ) - - return weight_and_optimizer_memory - - -def compute_activation_memory(args, num_microbatches, verbose=False): - # Using formula in Table 2 of https://arxiv.org/pdf/2205.05198.pdf. - # We are trying to compute the maximum activation footprint, so all calculations in this function - # are for the first pipeline stage. - - # Memory footprint from transformer layer (self-attention and MLP). - activation_memory = (args.seq_length * args.micro_batch_size * args.hidden_size) * 34 - if verbose: - print( - f"Activation memory footprint per transformer layer: " - f"{activation_memory / NUM_BYTES_IN_MEGABYTE / args.tensor_model_parallel_size:.1f} MB" - ) - activation_memory *= args.num_layers - - # Now add activation memory required for input embeddings, last LayerNorm and output layer. - - # Input to embedding (pp_size microbatches in flight). - activation_memory += ( - 8 * args.seq_length * args.micro_batch_size * args.pipeline_model_parallel_size - ) - # Dropout in embedding layer (pp_size microbatches in flight). - activation_memory += ( - args.seq_length - * args.micro_batch_size - * args.hidden_size - * args.pipeline_model_parallel_size - ) - - # Multiply by interleaved PP memory factor. 
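The byte count per parameter used above is 18 for plain mixed-precision Adam and 6 + 12/data_parallel_size when the distributed optimizer shards optimizer state across data-parallel ranks. A small worked sketch of that arithmetic; the 7B figure and parallel sizes are assumptions for illustration:

```python
# Per-parameter byte counts used above (mixed-precision Adam): 18 bytes/param
# without the distributed optimizer, 6 + 12/data_parallel_size with it.
MB = 1024 * 1024

def bytes_per_param(use_distributed_optimizer: bool, data_parallel_size: int) -> float:
    return 6 + 12 / data_parallel_size if use_distributed_optimizer else 18

# Illustrative only: ~7B parameters spread over an 8-way model-parallel shard,
# with 16-way data parallelism.
params_on_shard = 7e9 / 8
for use_do in (False, True):
    mem_mb = params_on_shard * bytes_per_param(use_do, data_parallel_size=16) / MB
    print(f"distributed optimizer={use_do}: {mem_mb:,.0f} MB of weights + optimizer state")
```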
- if args.virtual_pipeline_model_parallel_size is not None: - interleaved_schedule_memory_penalty = 1 + ( - (args.pipeline_model_parallel_size - 1) - / (args.pipeline_model_parallel_size * args.virtual_pipeline_model_parallel_size) - ) - in_flight_microbatches = math.ceil( - interleaved_schedule_memory_penalty * args.pipeline_model_parallel_size - ) - if verbose: - print( - f"Memory penalty from interleaved schedule: {interleaved_schedule_memory_penalty:.2f}" - ) - print(f"Number of in-flight microbatches: {in_flight_microbatches}") - activation_memory *= interleaved_schedule_memory_penalty - - # If using non-interleaved schedule, number of microbatches in pipeline can be less than pp_size, - # so discount accordingly. - if args.virtual_pipeline_model_parallel_size is None and args.pipeline_model_parallel_size > 1: - if num_microbatches is not None: - activation_memory *= min(1, num_microbatches / args.pipeline_model_parallel_size) - in_flight_microbatches = min(num_microbatches, args.pipeline_model_parallel_size) - else: - in_flight_microbatches = args.pipeline_model_parallel_size - if verbose: - print(f"Number of in-flight microbatches: {in_flight_microbatches}") - - if args.pipeline_model_parallel_size == 1: - # Inputs to output layer and CE loss. - activation_memory += ( - args.seq_length - * args.micro_batch_size - * args.hidden_size - * 4 - * (1 + (args.padded_vocab_size / args.hidden_size)) - ) - - # Activation memory is partitioned by TP size due to tensor and sequence model parallelism. - return activation_memory / args.tensor_model_parallel_size - - -def report_theoretical_memory(args, num_microbatches=None, verbose=False): - # Formulae here assume sequence parallelism and selective activation recomputation. - if not args.sequence_parallel or args.recompute_granularity != 'selective': - return - - weight_and_optimizer_memory = ( - compute_weight_and_optimizer_memory(args, verbose=verbose) / NUM_BYTES_IN_MEGABYTE - ) - activation_memory = ( - compute_activation_memory(args, num_microbatches=num_microbatches, verbose=verbose) - / NUM_BYTES_IN_MEGABYTE - ) - total_memory = weight_and_optimizer_memory + activation_memory - - print( - f"Theoretical memory footprints: weight and optimizer={weight_and_optimizer_memory:.2f} MB, " - f"activation={activation_memory:.2f} MB, " - f"total={total_memory:.2f} MB\n" - ) diff --git a/megatron/timers.py b/megatron/timers.py deleted file mode 100644 index a9478fa014b3a01dd514f74005a4b86294328dc2..0000000000000000000000000000000000000000 --- a/megatron/timers.py +++ /dev/null @@ -1,304 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
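`compute_activation_memory` in the module deleted above charges roughly 34·s·b·h bytes of activations per transformer layer (Table 2 of arXiv:2205.05198) and finally divides by the tensor-parallel size, since tensor and sequence parallelism partition the activations. A toy calculation with assumed hyperparameters:

```python
MB = 1024 * 1024

def activation_mb_per_layer(seq_len: int, micro_batch: int, hidden: int,
                            tensor_parallel: int) -> float:
    # 34 * s * b * h bytes per transformer layer (sequence parallelism plus
    # selective recomputation), partitioned across the tensor-parallel group.
    return 34 * seq_len * micro_batch * hidden / tensor_parallel / MB

# Illustrative hyperparameters, not taken from this diff.
print(activation_mb_per_layer(seq_len=4096, micro_batch=1, hidden=4096,
                              tensor_parallel=8))   # 68.0 MB per layer
```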
- -"""Megatron timers.""" - -from abc import ABC -from abc import abstractmethod -import time - -import torch - - - -class TimerBase(ABC): - - def __init__(self, name): - self.name = name - - @abstractmethod - def start(self, barrier=False): - pass - - @abstractmethod - def stop(self, barrier=False): - pass - - @abstractmethod - def reset(self): - pass - - @abstractmethod - def elapsed(self, reset=True, barrier=False): - pass - - - -class DummyTimer(TimerBase): - - def __init__(self): - super().__init__('dummy timer') - - def start(self, barrier=False): - return - - def stop(self, barrier=False): - return - - def reset(self): - return - - def elapsed(self, reset=True, barrier=False): - raise Exception('dummy timer should not be used to ' - 'calculate elapsed time') - - - -class Timer(TimerBase): - """ - Comment on using `barrier`: If this flag is passed, then all - the caller processes will wait till all reach the timing routine. - It is up to the user to make sure all the ranks in `barrier_group` - call it otherwise, it will result in a hang. - Comment on `barrier_group`: By default it is set to None which - in torch distributed land, it will result in the global communicator. - """ - - def __init__(self, name): - super().__init__(name) - self._elapsed = 0.0 - self._started = False - # Note that None will default to the global process group - self._barrier_group = None - self._start_time = time.time() - - - def set_barrier_group(self, barrier_group): - self._barrier_group = barrier_group - - - def start(self, barrier=False): - """Start the timer.""" - assert not self._started, 'timer has already been started' - if barrier: - torch.distributed.barrier(group=self._barrier_group) - torch.cuda.synchronize() - self._start_time = time.time() - self._started = True - - - def stop(self, barrier=False): - """Stop the timer.""" - assert self._started, 'timer is not started' - if barrier: - torch.distributed.barrier(group=self._barrier_group) - torch.cuda.synchronize() - self._elapsed += (time.time() - self._start_time) - self._started = False - - - def reset(self): - """Reset timer.""" - self._elapsed = 0.0 - self._started = False - - - def elapsed(self, reset=True, barrier=False): - """Calculate the elapsed time.""" - _started = self._started - # If the timing in progress, end it first. - if self._started: - self.stop(barrier=barrier) - # Get the elapsed time. - _elapsed = self._elapsed - # Reset the elapsed time - if reset: - self.reset() - # If timing was in progress, set it back. - if _started: - self.start(barrier=barrier) - return _elapsed - - - -class Timers: - """Group of timers.""" - - def __init__(self, log_level, log_option): - self._log_level = log_level - self._log_option = log_option - self._timers = {} - self._log_levels = {} - self._dummy_timer = DummyTimer() - self._max_log_level = 2 - - - def __call__(self, name, log_level=None): - # If the timer has already been set, then check if the log-level - # is provided, it matches the one that the timer was created with. - if name in self._timers: - if log_level is not None: - assert log_level == self._log_levels[name], \ - 'input log level {} does not match already existing '\ - 'log level {} for {} timer'.format( - log_level, self._log_levels[name], name) - return self._timers[name] - # If timer does not exist and no log level is provided, - # set it to the max log level which is 2. 
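The deleted `Timer` accumulates elapsed wall-clock time across start/stop pairs and optionally resets on read. The sketch below is a simplified stand-in that keeps only that accumulate/elapsed/reset behaviour (no barriers or device synchronization; the timer name is illustrative):

```python
import time

class SimpleTimer:
    """Minimal stand-in for the deleted Timer: accumulate across start/stop,
    optionally reset when elapsed() is read (no barrier / device sync here)."""
    def __init__(self, name):
        self.name, self._elapsed, self._started = name, 0.0, False
    def start(self):
        assert not self._started, "timer has already been started"
        self._t0, self._started = time.time(), True
    def stop(self):
        assert self._started, "timer is not started"
        self._elapsed += time.time() - self._t0
        self._started = False
    def elapsed(self, reset=True):
        e = self._elapsed
        if reset:
            self._elapsed = 0.0
        return e

t = SimpleTimer("data-loading")
for _ in range(3):              # elapsed time accumulates across start/stop pairs
    t.start(); time.sleep(0.01); t.stop()
print(f"{t.name}: {t.elapsed() * 1000:.1f} ms")
```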
- if log_level is None: - log_level = self._max_log_level - assert log_level <= self._max_log_level, \ - 'log level {} is larger than max supported log level {}'.format( - log_level, self._max_log_level) - # Now if the input log level is larger than the one set for - # the timers class, just ignore it and return a dummy timer. - if log_level > self._log_level: - return self._dummy_timer - # Otherwise, initalize the timer and set the level. - self._timers[name] = Timer(name) - self._log_levels[name] = log_level - return self._timers[name] - - - def _get_elapsed_time_all_ranks(self, names, reset, barrier): - """ - Assumptions: - - All the ranks call this function. - - `names` are identical on all ranks. - If the above assumptions are not met, calling this function will - result in hang. - Arguments: - - names: list of timer names - - reset: reset the timer after recording the elapsed time - - barrier: if set, do a global barrier before time measurments - """ - - # First make sure all the callers are in sync. - if barrier: - torch.distributed.barrier() - - world_size = torch.distributed.get_world_size() - rank = torch.distributed.get_rank() - - # Here we can use gather on the rank we want to print the - # timing, however, there is no gather_base support in - # pytorch yet. It is simpler to deal with a single tensor - # and since we are only gathering a small amount of data, - # it should be ok to use all-gather instead of gather. - rank_name_to_time = torch.zeros((world_size, len(names)), - dtype=torch.float, - device=torch.cuda.current_device()) - for i, name in enumerate(names): - if name in self._timers: - # Here we don't need to pass the barrier flag as all - # the processes are already in sync. This avoids the - # issue of different timers having different barrier - # groups inside their class. - rank_name_to_time[rank, i] = self._timers[name].elapsed( - reset=reset) - - # See the note above for why we are not using gather. 
- torch.distributed._all_gather_base(rank_name_to_time.view(-1), - rank_name_to_time[rank, :].view(-1)) - - return rank_name_to_time - - - def _get_global_min_max_time(self, names, reset, barrier, normalizer): - """Report only min and max times across all ranks.""" - - rank_name_to_time = self._get_elapsed_time_all_ranks(names, reset, - barrier) - name_to_min_max_time = {} - for i, name in enumerate(names): - rank_to_time = rank_name_to_time[:, i] - # filter out the ones we did not have any timings for - rank_to_time = rank_to_time[rank_to_time > 0.0] - # If the timer exists: - if rank_to_time.numel() > 0: - name_to_min_max_time[name] = ( - rank_to_time.min().item() / normalizer, - rank_to_time.max().item() / normalizer) - return name_to_min_max_time - - - def _get_global_min_max_time_string(self, names, reset, barrier, - normalizer, max_only): - name_to_min_max_time = self._get_global_min_max_time( - names, reset, barrier, normalizer) - if not name_to_min_max_time: - return None - output_string = '(min, max) time across ranks (ms):' - for name in name_to_min_max_time: - min_time, max_time = name_to_min_max_time[name] - if max_only: - output_string += '\n {}: {:.2f}'.format( - (name+' ').ljust(48, '.'), max_time) - else: - output_string += '\n {}: ({:.2f}, {:.2f})'.format( - (name+' ').ljust(48, '.'), min_time, max_time) - return output_string - - - def _get_all_ranks_time_string(self, names, reset, barrier, normalizer): - """Report times across all ranks.""" - rank_name_to_time = self._get_elapsed_time_all_ranks(names, reset, - barrier) - - output_string = 'times across ranks (ms):' - no_reported_timing = True - for i, name in enumerate(names): - not_yet_found = True - for rank in range(torch.distributed.get_world_size()): - if rank_name_to_time[rank, i] > 0: - no_reported_timing = False - if not_yet_found: - not_yet_found = False - output_string += '\n {}:'.format(name) - output_string += '\n rank {:2d}: {:.2f}'.format( - rank, rank_name_to_time[rank, i] / normalizer) - if no_reported_timing: - return None - return output_string - - - def log(self, names, rank=None, normalizer=1.0, reset=True, barrier=False): - """Log a group of timers.""" - - # Print. - assert normalizer > 0.0 - if self._log_option in ['max', 'minmax']: - max_only = False - if self._log_option == 'max': - max_only = True - output_string = self._get_global_min_max_time_string( - names, reset, barrier, normalizer/1000.0, max_only) - elif self._log_option == 'all': - output_string = self._get_all_ranks_time_string(names, - reset, barrier, - normalizer/1000.0) - else: - raise Exception('unknown timing log option {}'.format( - self._log_option)) - - # If no input rank is provided, log on last rank. - if rank is None: - rank = torch.distributed.get_world_size() - 1 - if rank == torch.distributed.get_rank() and output_string is not None: - print(output_string, flush=True) - - - def write(self, names, writer, iteration, normalizer=1.0, - reset=False, barrier=False): - """Write timers to a tensorboard writer - Note that we only report maximum time across ranks to tensorboard. 
- """ - # currently when using add_scalars, - # torch.utils.add_scalars makes each timer its own run, which - # polutes the runs list, so we just add each as a scalar - assert normalizer > 0.0 - name_to_min_max_time = self._get_global_min_max_time( - names, reset, barrier, normalizer) - if writer is not None: - for name in name_to_min_max_time: - _, max_time = name_to_min_max_time[name] - writer.add_scalar(name + '-time', max_time, iteration) diff --git a/megatron/tokenizer/__init__.py b/megatron/tokenizer/__init__.py deleted file mode 100644 index 59ceb33865a384b4d5f12efee2d31944c79ff109..0000000000000000000000000000000000000000 --- a/megatron/tokenizer/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. - - -from .tokenizer import build_tokenizer diff --git a/megatron/tokenizer/bert_tokenization.py b/megatron/tokenizer/bert_tokenization.py deleted file mode 100644 index 642041e778e81a0ddb8bba755ce93116b296a9dd..0000000000000000000000000000000000000000 --- a/megatron/tokenizer/bert_tokenization.py +++ /dev/null @@ -1,431 +0,0 @@ -# coding=utf-8 -# Copyright 2018 The Google AI Language Team Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tokenization classes.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import collections -import re -import unicodedata -import six - - -def validate_case_matches_checkpoint(do_lower_case, init_checkpoint): - """Checks whether the casing config is consistent with the checkpoint name.""" - - # The casing has to be passed in by the user and there is no explicit check - # as to whether it matches the checkpoint. The casing information probably - # should have been stored in the bert_config.json file, but it's not, so - # we have to heuristically detect it to validate. - - if not init_checkpoint: - return - - m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint) - if m is None: - return - - model_name = m.group(1) - - lower_models = [ - "uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12", - "multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12" - ] - - cased_models = [ - "cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16", - "multi_cased_L-12_H-768_A-12" - ] - - is_bad_config = False - if model_name in lower_models and not do_lower_case: - is_bad_config = True - actual_flag = "False" - case_name = "lowercased" - opposite_flag = "True" - - if model_name in cased_models and do_lower_case: - is_bad_config = True - actual_flag = "True" - case_name = "cased" - opposite_flag = "False" - - if is_bad_config: - raise ValueError( - "You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. " - "However, `%s` seems to be a %s model, so you " - "should pass in `--do_lower_case=%s` so that the fine-tuning matches " - "how the model was pre-training. If this error is wrong, please " - "just comment out this check." 
% (actual_flag, init_checkpoint, - model_name, case_name, opposite_flag)) - - -def convert_to_unicode(text): - """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" - if six.PY3: - if isinstance(text, str): - return text - elif isinstance(text, bytes): - return text.decode("utf-8", "ignore") - else: - raise ValueError("Unsupported string type: %s" % (type(text))) - elif six.PY2: - if isinstance(text, str): - return text.decode("utf-8", "ignore") - elif isinstance(text, unicode): - return text - else: - raise ValueError("Unsupported string type: %s" % (type(text))) - else: - raise ValueError("Not running on Python2 or Python 3?") - - -def printable_text(text): - """Returns text encoded in a way suitable for print or `tf.logging`.""" - - # These functions want `str` for both Python2 and Python3, but in one case - # it's a Unicode string and in the other it's a byte string. - if six.PY3: - if isinstance(text, str): - return text - elif isinstance(text, bytes): - return text.decode("utf-8", "ignore") - else: - raise ValueError("Unsupported string type: %s" % (type(text))) - elif six.PY2: - if isinstance(text, str): - return text - elif isinstance(text, unicode): - return text.encode("utf-8") - else: - raise ValueError("Unsupported string type: %s" % (type(text))) - else: - raise ValueError("Not running on Python2 or Python 3?") - - -def load_vocab(vocab_file): - """Loads a vocabulary file into a dictionary.""" - vocab = collections.OrderedDict() - index = 0 - with open(vocab_file, "r", encoding = "utf-8") as reader: - while True: - token = convert_to_unicode(reader.readline()) - if not token: - break - token = token.strip() - vocab[token] = index - index += 1 - return vocab - - -def convert_by_vocab(vocab, items): - """Converts a sequence of [tokens|ids] using the vocab.""" - output = [] - for item in items: - output.append(vocab[item]) - return output - - -def convert_tokens_to_ids(vocab, tokens): - return convert_by_vocab(vocab, tokens) - - -def convert_ids_to_tokens(inv_vocab, ids): - return convert_by_vocab(inv_vocab, ids) - - -def whitespace_tokenize(text): - """Runs basic whitespace cleaning and splitting on a piece of text.""" - text = text.strip() - if not text: - return [] - tokens = text.split() - return tokens - - -class FullTokenizer(object): - """Runs end-to-end tokenziation.""" - - def __init__(self, vocab_file, do_lower_case=True): - self.vocab = load_vocab(vocab_file) - self.inv_vocab = {v: k for k, v in self.vocab.items()} - self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) - self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) - - def tokenize(self, text): - split_tokens = [] - for token in self.basic_tokenizer.tokenize(text): - for sub_token in self.wordpiece_tokenizer.tokenize(token): - split_tokens.append(sub_token) - - return split_tokens - - def convert_tokens_to_ids(self, tokens): - return convert_by_vocab(self.vocab, tokens) - - def convert_ids_to_tokens(self, ids): - return convert_by_vocab(self.inv_vocab, ids) - - @staticmethod - def convert_tokens_to_string(tokens, clean_up_tokenization_spaces=True): - """ Converts a sequence of tokens (string) in a single string. """ - - def clean_up_tokenization(out_string): - """ Clean up a list of simple English tokenization artifacts - like spaces before punctuations and abreviated forms. 
- """ - out_string = ( - out_string.replace(" .", ".") - .replace(" ?", "?") - .replace(" !", "!") - .replace(" ,", ",") - .replace(" ' ", "'") - .replace(" n't", "n't") - .replace(" 'm", "'m") - .replace(" 's", "'s") - .replace(" 've", "'ve") - .replace(" 're", "'re") - ) - return out_string - - text = ' '.join(tokens).replace(' ##', '').strip() - if clean_up_tokenization_spaces: - clean_text = clean_up_tokenization(text) - return clean_text - else: - return text - - def vocab_size(self): - return len(self.vocab) - - -class BasicTokenizer(object): - """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" - - def __init__(self, do_lower_case=True): - """Constructs a BasicTokenizer. - - Args: - do_lower_case: Whether to lower case the input. - """ - self.do_lower_case = do_lower_case - - def tokenize(self, text): - """Tokenizes a piece of text.""" - text = convert_to_unicode(text) - text = self._clean_text(text) - - # This was added on November 1st, 2018 for the multilingual and Chinese - # models. This is also applied to the English models now, but it doesn't - # matter since the English models were not trained on any Chinese data - # and generally don't have any Chinese data in them (there are Chinese - # characters in the vocabulary because Wikipedia does have some Chinese - # words in the English Wikipedia.). - text = self._tokenize_chinese_chars(text) - - orig_tokens = whitespace_tokenize(text) - split_tokens = [] - for token in orig_tokens: - if self.do_lower_case: - token = token.lower() - token = self._run_strip_accents(token) - split_tokens.extend(self._run_split_on_punc(token)) - - output_tokens = whitespace_tokenize(" ".join(split_tokens)) - return output_tokens - - def _run_strip_accents(self, text): - """Strips accents from a piece of text.""" - text = unicodedata.normalize("NFD", text) - output = [] - for char in text: - cat = unicodedata.category(char) - if cat == "Mn": - continue - output.append(char) - return "".join(output) - - def _run_split_on_punc(self, text): - """Splits punctuation on a piece of text.""" - chars = list(text) - i = 0 - start_new_word = True - output = [] - while i < len(chars): - char = chars[i] - if _is_punctuation(char): - output.append([char]) - start_new_word = True - else: - if start_new_word: - output.append([]) - start_new_word = False - output[-1].append(char) - i += 1 - - return ["".join(x) for x in output] - - def _tokenize_chinese_chars(self, text): - """Adds whitespace around any CJK character.""" - output = [] - for char in text: - cp = ord(char) - if self._is_chinese_char(cp): - output.append(" ") - output.append(char) - output.append(" ") - else: - output.append(char) - return "".join(output) - - def _is_chinese_char(self, cp): - """Checks whether CP is the codepoint of a CJK character.""" - # This defines a "chinese character" as anything in the CJK Unicode block: - # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) - # - # Note that the CJK Unicode block is NOT all Japanese and Korean characters, - # despite its name. The modern Korean Hangul alphabet is a different block, - # as is Japanese Hiragana and Katakana. Those alphabets are used to write - # space-separated words, so they are not treated specially and handled - # like the all of the other languages. 
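`BasicTokenizer` above lower-cases input, strips accents by NFD-decomposing and dropping combining marks (Unicode category `Mn`), and splits punctuation into standalone tokens. A compressed illustration of those two steps (simplified to Unicode `P*` categories only, whereas the original `_is_punctuation` also treats some ASCII symbols as punctuation):

```python
import unicodedata

def strip_accents(text: str) -> str:
    # NFD-decompose, then drop combining marks (category "Mn"), as above.
    return "".join(c for c in unicodedata.normalize("NFD", text)
                   if unicodedata.category(c) != "Mn")

def split_punctuation(text: str) -> list:
    # Every punctuation character becomes its own token.
    out, word = [], ""
    for ch in text:
        if unicodedata.category(ch).startswith("P"):
            if word:
                out.append(word)
                word = ""
            out.append(ch)
        else:
            word += ch
    if word:
        out.append(word)
    return out

print(strip_accents("naïve café"))         # -> "naive cafe"
print(split_punctuation("hello, world!"))  # -> ['hello', ',', ' world', '!']
```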
- if ((cp >= 0x4E00 and cp <= 0x9FFF) or # - (cp >= 0x3400 and cp <= 0x4DBF) or # - (cp >= 0x20000 and cp <= 0x2A6DF) or # - (cp >= 0x2A700 and cp <= 0x2B73F) or # - (cp >= 0x2B740 and cp <= 0x2B81F) or # - (cp >= 0x2B820 and cp <= 0x2CEAF) or - (cp >= 0xF900 and cp <= 0xFAFF) or # - (cp >= 0x2F800 and cp <= 0x2FA1F)): # - return True - - return False - - def _clean_text(self, text): - """Performs invalid character removal and whitespace cleanup on text.""" - output = [] - for char in text: - cp = ord(char) - if cp == 0 or cp == 0xfffd or _is_control(char): - continue - if _is_whitespace(char): - output.append(" ") - else: - output.append(char) - return "".join(output) - - -class WordpieceTokenizer(object): - """Runs WordPiece tokenziation.""" - - def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200): - self.vocab = vocab - self.unk_token = unk_token - self.max_input_chars_per_word = max_input_chars_per_word - - def tokenize(self, text): - """Tokenizes a piece of text into its word pieces. - - This uses a greedy longest-match-first algorithm to perform tokenization - using the given vocabulary. - - For example: - input = "unaffable" - output = ["un", "##aff", "##able"] - - Args: - text: A single token or whitespace separated tokens. This should have - already been passed through `BasicTokenizer. - - Returns: - A list of wordpiece tokens. - """ - - text = convert_to_unicode(text) - - output_tokens = [] - for token in whitespace_tokenize(text): - chars = list(token) - if len(chars) > self.max_input_chars_per_word: - output_tokens.append(self.unk_token) - continue - - is_bad = False - start = 0 - sub_tokens = [] - while start < len(chars): - end = len(chars) - cur_substr = None - while start < end: - substr = "".join(chars[start:end]) - if start > 0: - substr = "##" + substr - if substr in self.vocab: - cur_substr = substr - break - end -= 1 - if cur_substr is None: - is_bad = True - break - sub_tokens.append(cur_substr) - start = end - - if is_bad: - output_tokens.append(self.unk_token) - else: - output_tokens.extend(sub_tokens) - return output_tokens - - -def _is_whitespace(char): - """Checks whether `chars` is a whitespace character.""" - # \t, \n, and \r are technically contorl characters but we treat them - # as whitespace since they are generally considered as such. - if char == " " or char == "\t" or char == "\n" or char == "\r": - return True - cat = unicodedata.category(char) - if cat == "Zs": - return True - return False - - -def _is_control(char): - """Checks whether `chars` is a control character.""" - # These are technically control characters but we count them as whitespace - # characters. - if char == "\t" or char == "\n" or char == "\r": - return False - cat = unicodedata.category(char) - if cat in ("Cc", "Cf"): - return True - return False - - -def _is_punctuation(char): - """Checks whether `chars` is a punctuation character.""" - cp = ord(char) - # We treat all non-letter/number ASCII as punctuation. - # Characters such as "^", "$", and "`" are not in the Unicode - # Punctuation class but we treat them as punctuation anyways, for - # consistency. 
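`WordpieceTokenizer` above segments each whitespace token greedily, always taking the longest vocabulary entry that matches the remaining characters and prefixing continuations with `##`. A toy reproduction of the example given in its docstring:

```python
def wordpiece(token: str, vocab: set, unk: str = "[UNK]") -> list:
    """Greedy longest-match-first segmentation, as in the deleted class."""
    pieces, start = [], 0
    while start < len(token):
        end, cur = len(token), None
        while start < end:
            sub = token[start:end]
            if start > 0:
                sub = "##" + sub       # continuation pieces carry the ## prefix
            if sub in vocab:
                cur = sub
                break
            end -= 1
        if cur is None:                # no piece matched: whole token is unknown
            return [unk]
        pieces.append(cur)
        start = end
    return pieces

vocab = {"un", "##aff", "##able"}      # toy vocabulary from the docstring example
print(wordpiece("unaffable", vocab))   # -> ['un', '##aff', '##able']
```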
- if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or - (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): - return True - cat = unicodedata.category(char) - if cat.startswith("P"): - return True - return False diff --git a/megatron/tokenizer/gpt2_tokenization.py b/megatron/tokenizer/gpt2_tokenization.py deleted file mode 100644 index 3f37e449089b8f779c93ac7f034085fd7607bfb0..0000000000000000000000000000000000000000 --- a/megatron/tokenizer/gpt2_tokenization.py +++ /dev/null @@ -1,321 +0,0 @@ -# coding=utf-8 -# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tokenization classes for OpenAI GPT.""" - -from __future__ import (absolute_import, division, print_function, - unicode_literals) - -import sys -import json -import logging -import os -import regex as re -from io import open - -try: - from functools import lru_cache -except ImportError: - # Just a dummy decorator to get the checks to run on python2 - # because honestly I don't want to support a byte-level unicode BPE - # tokenizer on python 2 right now. - def lru_cache(): - return lambda func: func - - -logger = logging.getLogger(__name__) - -PRETRAINED_VOCAB_ARCHIVE_MAP = { - 'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json", -} -PRETRAINED_MERGES_ARCHIVE_MAP = { - 'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt", -} -PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = { - 'gpt2': 1024, -} -VOCAB_NAME = 'vocab.json' -MERGES_NAME = 'merges.txt' -SPECIAL_TOKENS_NAME = 'special_tokens.txt' - - -@lru_cache() -def bytes_to_unicode(): - """ - Returns list of utf-8 byte and a corresponding list of unicode strings. - The reversible bpe codes work on unicode strings. - This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. - When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. - This is a signficant percentage of your normal, say, 32K bpe vocab. - To avoid that, we want lookup tables between utf-8 bytes and unicode strings. - And avoids mapping to whitespace/control characters the bpe code barfs on. - """ - _chr = unichr if sys.version_info[0] == 2 else chr - bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + \ - list(range(ord("®"), ord("ÿ") + 1)) - cs = bs[:] - n = 0 - for b in range(2**8): - if b not in bs: - bs.append(b) - cs.append(2**8 + n) - n += 1 - cs = [_chr(n) for n in cs] - return dict(zip(bs, cs)) - - -def get_pairs(word): - """Return set of symbol pairs in a word. - - Word is represented as tuple of symbols (symbols being variable-length strings). - """ - pairs = set() - prev_char = word[0] - for char in word[1:]: - pairs.add((prev_char, char)) - prev_char = char - return pairs - - -class GPT2Tokenizer(object): - """ - GPT-2 BPE tokenizer. 
Peculiarities: - - Byte-level BPE - """ - @classmethod - def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs): - """ - Instantiate a PreTrainedBertModel from a pre-trained model file. - Download and cache the pre-trained model file if needed. - """ - if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP: - vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path] - merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path] - special_tokens_file = None - else: - vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME) - merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME) - special_tokens_file = os.path.join(pretrained_model_name_or_path, SPECIAL_TOKENS_NAME) - if not os.path.exists(special_tokens_file): - special_tokens_file = None - else: - logger.info("loading special tokens file {}".format(special_tokens_file)) - # redirect to the cache, if necessary - try: - from .file_utils import cached_path - resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir) - resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir) - except EnvironmentError: - logger.error( - "Model name '{}' was not found in model name list ({}). " - "We assumed '{}' was a path or url but couldn't find files {} and {} " - "at this path or url.".format( - pretrained_model_name_or_path, - ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()), - pretrained_model_name_or_path, - vocab_file, merges_file)) - return None - if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file: - logger.info("loading vocabulary file {}".format(vocab_file)) - logger.info("loading merges file {}".format(merges_file)) - else: - logger.info("loading vocabulary file {} from cache at {}".format( - vocab_file, resolved_vocab_file)) - logger.info("loading merges file {} from cache at {}".format( - merges_file, resolved_merges_file)) - if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP: - # if we're using a pretrained model, ensure the tokenizer wont index sequences longer - # than the number of positional embeddings - max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path] - kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) - # Instantiate tokenizer. 
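`bytes_to_unicode`, shown in full earlier in this file's diff, builds a reversible mapping from all 256 byte values to printable characters so that byte-level BPE never has to operate on raw whitespace or control bytes. A quick check of that property using the same construction; the sample string is an arbitrary choice:

```python
def bytes_to_unicode() -> dict:
    # Same construction as above: keep printable latin-1 bytes as themselves,
    # remap the remaining bytes to code points 256 and up.
    bs = (list(range(ord("!"), ord("~") + 1)) +
          list(range(ord("¡"), ord("¬") + 1)) +
          list(range(ord("®"), ord("ÿ") + 1)))
    cs = bs[:]
    n = 0
    for b in range(256):
        if b not in bs:
            bs.append(b)
            cs.append(256 + n)
            n += 1
    return dict(zip(bs, map(chr, cs)))

table = bytes_to_unicode()
assert len(table) == 256 and len(set(table.values())) == 256  # bijection over all bytes
print("".join(table[b] for b in "héllo".encode("utf-8")))      # -> hÃ©llo
```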
- if special_tokens_file and 'special_tokens' not in kwargs: - special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1] - else: - special_tokens = kwargs.pop('special_tokens', []) - tokenizer = cls( - resolved_vocab_file, - resolved_merges_file, - special_tokens=special_tokens, - *inputs, - **kwargs) - return tokenizer - - def __init__(self, vocab_file, merges_file, errors='replace', - special_tokens=None, max_len=None): - self.max_len = max_len if max_len is not None else int(1e12) - self.encoder = json.load(open(vocab_file)) - self.decoder = {v: k for k, v in self.encoder.items()} - self.errors = errors # how to handle errors in decoding - self.byte_encoder = bytes_to_unicode() - self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} - bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1] - bpe_merges = [tuple(merge.split()) for merge in bpe_data] - self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) - self.cache = {} - - # Should haved added re.IGNORECASE so BPE merges can happen for - # capitalized versions of contractions - self.pat = re.compile( - r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""") - - self.special_tokens = {} - self.special_tokens_decoder = {} - self.set_special_tokens(special_tokens) - - def __len__(self): - return len(self.encoder) + len(self.special_tokens) - - def set_special_tokens(self, special_tokens): - """ Add a list of additional tokens to the encoder. - The additional tokens are indexed starting from the last index of the - current vocabulary in the order of the `special_tokens` list. - """ - if not special_tokens: - self.special_tokens = {} - self.special_tokens_decoder = {} - return - self.special_tokens = dict((tok, len(self.encoder) + i) - for i, tok in enumerate(special_tokens)) - self.special_tokens_decoder = {v: k for k, v in self.special_tokens.items()} - logger.info("Special tokens {}".format(self.special_tokens)) - - def bpe(self, token): - if token in self.cache: - return self.cache[token] - word = tuple(token) - pairs = get_pairs(word) - - if not pairs: - return token - - while True: - bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf'))) - if bigram not in self.bpe_ranks: - break - first, second = bigram - new_word = [] - i = 0 - while i < len(word): - try: - j = word.index(first, i) - new_word.extend(word[i:j]) - i = j - except BaseException: - new_word.extend(word[i:]) - break - - if word[i] == first and i < len(word) - 1 and word[i + 1] == second: - new_word.append(first + second) - i += 2 - else: - new_word.append(word[i]) - i += 1 - new_word = tuple(new_word) - word = new_word - if len(word) == 1: - break - else: - pairs = get_pairs(word) - word = ' '.join(word) - self.cache[token] = word - return word - - def tokenize(self, text): - """ Tokenize a string. """ - bpe_tokens = [] - for token in re.findall(self.pat, text): - if sys.version_info[0] == 2: - token = ''.join(self.byte_encoder[ord(b)] for b in token) - else: - token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) - bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' ')) - return bpe_tokens - - def convert_tokens_to_ids(self, tokens): - """ Converts a sequence of tokens into ids using the vocab. 
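The `bpe` method above repeatedly merges the adjacent symbol pair with the lowest merge rank until no ranked pair remains, caching the result per token. A stripped-down version of that loop with a toy, made-up merge table (not GPT-2's real `merges.txt`):

```python
def get_pairs(word):
    return {(a, b) for a, b in zip(word, word[1:])}

def bpe(token: str, ranks: dict) -> list:
    """Merge the lowest-ranked adjacent pair until none is rankable."""
    word = tuple(token)
    while len(word) > 1:
        pairs = get_pairs(word)
        best = min(pairs, key=lambda p: ranks.get(p, float("inf")))
        if best not in ranks:
            break
        first, second = best
        merged, i = [], 0
        while i < len(word):
            if i < len(word) - 1 and word[i] == first and word[i + 1] == second:
                merged.append(first + second)
                i += 2
            else:
                merged.append(word[i])
                i += 1
        word = tuple(merged)
    return list(word)

ranks = {("l", "o"): 0, ("lo", "w"): 1, ("h", "e"): 2}  # hypothetical merge ranks
print(bpe("hello", ranks))   # -> ['he', 'l', 'lo']
```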
""" - ids = [] - if isinstance(tokens, str) or (sys.version_info[0] == 2 and isinstance(tokens, unicode)): - if tokens in self.special_tokens: - return self.special_tokens[tokens] - else: - return self.encoder.get(tokens, 0) - for token in tokens: - if token in self.special_tokens: - ids.append(self.special_tokens[token]) - else: - ids.append(self.encoder.get(token, 0)) - if len(ids) > self.max_len: - logger.warning( - "Token indices sequence length is longer than the specified maximum " - " sequence length for this OpenAI GPT model ({} > {}). Running this" - " sequence through the model will result in indexing errors".format( - len(ids), self.max_len) - ) - return ids - - def convert_ids_to_tokens(self, ids, skip_special_tokens=False): - """Converts a sequence of ids in BPE tokens using the vocab.""" - tokens = [] - for i in ids: - if i in self.special_tokens_decoder: - if not skip_special_tokens: - tokens.append(self.special_tokens_decoder[i]) - else: - tokens.append(self.decoder[i]) - return tokens - - def encode(self, text): - return self.convert_tokens_to_ids(self.tokenize(text)) - - def decode(self, tokens): - text = ''.join([self.decoder[token] for token in tokens]) - text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors) - return text - - def save_vocabulary(self, vocab_path): - """Save the tokenizer vocabulary and merge files to a directory.""" - if not os.path.isdir(vocab_path): - logger.error("Vocabulary path ({}) should be a directory".format(vocab_path)) - return - vocab_file = os.path.join(vocab_path, VOCAB_NAME) - merge_file = os.path.join(vocab_path, MERGES_NAME) - special_tokens_file = os.path.join(vocab_path, SPECIAL_TOKENS_NAME) - - with open(vocab_file, 'w', encoding='utf-8') as f: - f.write(json.dumps(self.encoder, ensure_ascii=False)) - - index = 0 - with open(merge_file, "w", encoding="utf-8") as writer: - writer.write(u'#version: 0.2\n') - for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]): - if index != token_index: - logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive." - " Please check that the tokenizer is not corrupted!".format(merge_file)) - index = token_index - writer.write(' '.join(bpe_tokens) + u'\n') - index += 1 - - index = len(self.encoder) - with open(special_tokens_file, 'w', encoding='utf-8') as writer: - for token, token_index in sorted(self.special_tokens.items(), key=lambda kv: kv[1]): - if index != token_index: - logger.warning("Saving special tokens vocabulary to {}: BPE indices are not consecutive." - " Please check that the tokenizer is not corrupted!".format(special_tokens_file)) - index = token_index - writer.write(token + u'\n') - index += 1 - - return vocab_file, merge_file, special_tokens_file diff --git a/megatron/tokenizer/tokenizer.py b/megatron/tokenizer/tokenizer.py deleted file mode 100644 index 98643343c5c7621f8351b75e22ea9203ae61ad7b..0000000000000000000000000000000000000000 --- a/megatron/tokenizer/tokenizer.py +++ /dev/null @@ -1,588 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - -"""Megatron tokenizers.""" - -from abc import ABC -from abc import abstractmethod - -from .bert_tokenization import FullTokenizer as FullBertTokenizer -from .gpt2_tokenization import GPT2Tokenizer - -def build_tokenizer(args): - """Initialize tokenizer.""" - if args.rank == 0: - print('> building {} tokenizer ...'.format(args.tokenizer_type), - flush=True) - - # Select and instantiate the tokenizer. 
- if args.tokenizer_type == 'BertWordPieceLowerCase': - assert args.vocab_file is not None - tokenizer = _BertWordPieceTokenizer(vocab_file=args.vocab_file, - lower_case=True, - vocab_extra_ids=args.vocab_extra_ids) - elif args.tokenizer_type == 'BertWordPieceCase': - assert args.vocab_file is not None - tokenizer = _BertWordPieceTokenizer(vocab_file=args.vocab_file, - lower_case=False, - vocab_extra_ids=args.vocab_extra_ids) - elif args.tokenizer_type == 'GPT2BPETokenizer': - assert args.vocab_file is not None - assert args.merge_file is not None - tokenizer = _GPT2BPETokenizer(args.vocab_file, args.merge_file) - elif args.tokenizer_type == 'SentencePieceTokenizer': - assert args.tokenizer_model is not None - tokenizer = _SentencePieceTokenizer(args.tokenizer_model, vocab_extra_ids=args.vocab_extra_ids) - elif args.tokenizer_type == 'GPTSentencePieceTokenizer': - assert args.tokenizer_model is not None - tokenizer = _GPTSentencePieceTokenizer(args.tokenizer_model) - elif args.tokenizer_type == 'Llama2Tokenizer': - assert args.tokenizer_model is not None - tokenizer = _Llama2Tokenizer(args.tokenizer_model) - elif args.tokenizer_type == 'NullTokenizer': - assert args.vocab_size is not None - tokenizer = _NullTokenizer(args.vocab_size) - else: - raise NotImplementedError('{} tokenizer is not ' - 'implemented.'.format(args.tokenizer_type)) - - # Add vocab size (if not already set from a checkpoint). - if getattr(args, "padded_vocab_size", None) is None: - args.padded_vocab_size = _vocab_size_with_padding(tokenizer.vocab_size, - args) - - return tokenizer - - -def _vocab_size_with_padding(orig_vocab_size, args): - """Pad vocab size so it is divisible by model parallel size and - still having GPU friendly size.""" - - after = orig_vocab_size - multiple = args.make_vocab_size_divisible_by * \ - args.tensor_model_parallel_size - while (after % multiple) != 0: - after += 1 - if args.rank == 0: - print(' > padded vocab (size: {}) with {} dummy tokens ' - '(new size: {})'.format( - orig_vocab_size, after - orig_vocab_size, after), flush=True) - return after - - -class AbstractTokenizer(ABC): - """Abstract class for tokenizer.""" - - def __init__(self, name): - self.name = name - super().__init__() - - @property - @abstractmethod - def vocab_size(self): - pass - - @property - @abstractmethod - def vocab(self): - """Dictionary from vocab text token to id token.""" - pass - - @property - @abstractmethod - def inv_vocab(self): - """Dictionary from vocab id token to text token.""" - pass - - @abstractmethod - def tokenize(self, text): - pass - - def detokenize(self, token_ids): - raise NotImplementedError('detokenizer is not implemented for {} ' - 'tokenizer'.format(self.name)) - - @property - def cls(self): - raise NotImplementedError('CLS is not provided for {} ' - 'tokenizer'.format(self.name)) - - @property - def sep(self): - raise NotImplementedError('SEP is not provided for {} ' - 'tokenizer'.format(self.name)) - - @property - def pad(self): - raise NotImplementedError('PAD is not provided for {} ' - 'tokenizer'.format(self.name)) - - @property - def eod(self): - raise NotImplementedError('EOD is not provided for {} ' - 'tokenizer'.format(self.name)) - - @property - def mask(self): - raise NotImplementedError('MASK is not provided for {} ' - 'tokenizer'.format(self.name)) - - -class _BertWordPieceTokenizer(AbstractTokenizer): - """Original BERT wordpiece tokenizer.""" - - def __init__(self, vocab_file, lower_case=True, vocab_extra_ids=0): - if lower_case: - name = 'BERT Lower Case' - else: - name = 
'BERT Upper Case' - super().__init__(name) - self.tokenizer = FullBertTokenizer(vocab_file, do_lower_case=lower_case) - self.cls_id = self.tokenizer.vocab['[CLS]'] - self.sep_id = self.tokenizer.vocab['[SEP]'] - self.pad_id = self.tokenizer.vocab['[PAD]'] - self.mask_id = self.tokenizer.vocab['[MASK]'] - self._additional_special_tokens = [] - - # (dsachan) Add BOS and EOS tokens - SPECIAL_TOKENS = {'eos_token': '[EOS]', - 'bos_token': '[BOS]'} - self._bos_token = '[BOS]' - self.add_token(self._bos_token) - self._bos_token_id = self.vocab.get(self._bos_token) - - self._eos_token = '[EOS]' - self.add_token(self._eos_token) - self._eos_token_id = self.vocab.get(self._eos_token) - - # (dsachan) Add additional special tokens - # These can be used as sentinel tokens in T5 model inputs - additional_special_tokens = [] - additional_special_tokens.extend( - ["".format(i) for i in range(vocab_extra_ids)]) - self.add_additional_special_tokens(additional_special_tokens) - - def add_token(self, token): - if token not in self.vocab: - self.inv_vocab[self.vocab_size] = token - # self.vocab_size comes from len(vocab) - # and it will increase as we add elements - self.vocab[token] = self.vocab_size - - def add_additional_special_tokens(self, tokens_list): - setattr(self, "additional_special_tokens", tokens_list) - for value in tokens_list: - self.add_token(value) - - @property - def vocab_size(self): - return self.tokenizer.vocab_size() - - @property - def vocab(self): - return self.tokenizer.vocab - - @property - def inv_vocab(self): - return self.tokenizer.inv_vocab - - def tokenize(self, text): - text_tokens = self.tokenizer.tokenize(text) - return self.tokenizer.convert_tokens_to_ids(text_tokens) - - def decode(self, ids): - tokens = self.tokenizer.convert_ids_to_tokens(ids) - return self.tokenizer.convert_tokens_to_string(tokens) - - def decode_token_ids(self, token_ids): - tokens = self.tokenizer.convert_ids_to_tokens(token_ids) - exclude_list = ['[PAD]', '[CLS]'] - non_pads = [t for t in tokens if t not in exclude_list] - - result = "" - for s in non_pads: - if s.startswith("##"): - result += s[2:] - else: - result += " " + s - - return result - - @property - def cls(self): - return self.cls_id - - @property - def sep(self): - return self.sep_id - - @property - def pad(self): - return self.pad_id - - @property - def mask(self): - return self.mask_id - - @property - def bos_token(self): - """ Beginning of sentence token id """ - return self._bos_token - - @property - def eos_token(self): - """ End of sentence token id """ - return self._eos_token - - @property - def additional_special_tokens(self): - """ All the additional special tokens you may want to use (list of strings).""" - return self._additional_special_tokens - - @property - def bos_token_id(self): - """ Id of the beginning of sentence token in the vocabulary.""" - return self._bos_token_id - - @property - def eos_token_id(self): - """ Id of the end of sentence token in the vocabulary.""" - return self._eos_token_id - - @property - def additional_special_tokens_ids(self): - """ Ids of all the additional special tokens in the vocabulary (list of integers).""" - return [self.vocab.get(token) for token in self._additional_special_tokens] - - @additional_special_tokens.setter - def additional_special_tokens(self, value): - self._additional_special_tokens = value - - -class _GPT2BPETokenizer(AbstractTokenizer): - """Original GPT2 BPE tokenizer.""" - - def __init__(self, vocab_file, merge_file): - name = 'GPT2 BPE' - super().__init__(name) - - 
self.tokenizer = GPT2Tokenizer(vocab_file, merge_file, errors='replace', - special_tokens=[], max_len=None) - self.eod_id = self.tokenizer.encoder['<|endoftext|>'] - - @property - def vocab_size(self): - return len(self.tokenizer.encoder) - - @property - def vocab(self): - return self.tokenizer.encoder - - @property - def inv_vocab(self): - return self.tokenizer.decoder - - def tokenize(self, text): - return self.tokenizer.encode(text) - - def detokenize(self, token_ids): - return self.tokenizer.decode(token_ids) - - @property - def eod(self): - return self.eod_id - - -class _SentencePieceTokenizer(AbstractTokenizer): - """SentencePieceTokenizer-Megatron wrapper""" - - def __init__(self, model_file, vocab_extra_ids=0): - name = 'SentencePieceTokenizer' - super().__init__(name) - - import sentencepiece - self.tokenizer = sentencepiece.SentencePieceProcessor(model_file=model_file) - self._initalize(vocab_extra_ids) - - def _populate_vocab(self): - self._vocab = {} - self._inv_vocab = {} - - for i in range(len(self.tokenizer)): - t = self.tokenizer.id_to_piece(i) - self._inv_vocab[i] = t - self._vocab[t] = i - - def _initalize(self, vocab_extra_ids): - self._populate_vocab() - self._special_tokens = {} - self._inv_special_tokens = {} - - self._t5_tokens = [] - - def _add_special_token(t): - if t not in self._vocab: - next_id = len(self._vocab) - self._vocab[t] = next_id - self._inv_vocab[next_id] = t - self._special_tokens[t] = self._vocab[t] - self._inv_special_tokens[self._vocab[t]] = t - - _add_special_token('') - self._cls_id = self._vocab[''] - _add_special_token('') - self._sep_id = self._vocab[''] - _add_special_token('') - self._eod_id = self._vocab[''] - _add_special_token('') - self._mask_id = self._vocab[''] - - pad_id = self.tokenizer.pad_id() - try: - pad_token = self.tokenizer.id_to_piece(pad_id) - except IndexError: - pad_token = '' - _add_special_token(pad_token) - self._pad_id = self._vocab[pad_token] - - bos_id = self.tokenizer.bos_id() - try: - bos_token = self.tokenizer.id_to_piece(bos_id) - except IndexError: - bos_token = '' - _add_special_token(bos_token) - self._bos_id = self._vocab[bos_token] - - eos_id = self.tokenizer.eos_id() - try: - eos_token = self.tokenizer.id_to_piece(eos_id) - except IndexError: - eos_token = '' - _add_special_token(eos_token) - self._eos_id = self._vocab[eos_token] - - for i in range(vocab_extra_ids): - t = "".format(i) - _add_special_token(t) - self._t5_tokens += [t] - - @property - def vocab_size(self): - return len(self._vocab) - - @property - def vocab(self): - return self._vocab - - @property - def inv_vocab(self): - return self._inv_vocab - - @property - def decoder(self): - return self._inv_vocab - - @property - def encoder(self): - return self._vocab - - # From: - # https://github.com/NVIDIA/NeMo/blob/c8fa217e811d60d11d014827c7f3845ff6c99ae7/nemo/collections/common/tokenizers/sentencepiece_tokenizer.py#L89 - def tokenize(self, text): - ids = [] - idx = 0 - - while 1: - indices = {} - for token in self._special_tokens: - try: - indices[token] = text[idx:].index(token) - except ValueError: - continue - if len(indices) == 0: - break - - next_token = min(indices, key=indices.get) - next_idx = idx + indices[next_token] - - ids.extend(self.tokenizer.encode_as_ids(text[idx:next_idx])) - ids.append(self._special_tokens[next_token]) - idx = next_idx + len(next_token) - - ids.extend(self.tokenizer.encode_as_ids(text[idx:])) - return ids - - # From: - # 
https://github.com/NVIDIA/NeMo/blob/c8fa217e811d60d11d014827c7f3845ff6c99ae7/nemo/collections/common/tokenizers/sentencepiece_tokenizer.py#L125 - def detokenize(self, ids): - text = "" - last_i = 0 - - for i, id in enumerate(ids): - if id in self._inv_special_tokens: - text += self.tokenizer.decode_ids(ids[last_i:i]) + " " - text += self._inv_special_tokens[id] + " " - last_i = i + 1 - - text += self.tokenizer.decode_ids(ids[last_i:]) - return text - - @property - def cls(self): - return self._cls_id - - @property - def sep(self): - return self._sep_id - - @property - def pad(self): - return self._pad_id - - @property - def bos_token_id(self): - return self._bos_id - - @property - def bos(self): - return self._bos_id - - @property - def eod(self): - return self._eod_id - - @property - def eos_token_id(self): - return self._eos_id - - @property - def eos(self): - return self._eos_id - - @property - def mask(self): - return self._mask_id - - @property - def additional_special_tokens_ids(self): - return [self.vocab[k] for k in self._t5_tokens] - -class _GPTSentencePieceTokenizer(_SentencePieceTokenizer): - """SentencePieceTokenizer-Megatron wrapper""" - - def __init__(self, model_file,): - super().__init__(model_file, vocab_extra_ids=0) - - def _initalize(self, vocab_extra_ids): - self._populate_vocab() - - self._pad_id = self.tokenizer.pad_id() - self._bos_id = self.tokenizer.bos_id() - self._eos_id = self.tokenizer.eos_id() - - def tokenize(self, text): - return self.tokenizer.encode_as_ids(text) - - def detokenize(self, ids): - return self.tokenizer.decode_ids(ids) - - @property - def cls(self): - return -1 - - @property - def sep(self): - return -1 - - @property - def mask(self): - return -1 - - @property - def eod(self): - return self._eos_id - - @property - def additional_special_tokens_ids(self): - return None - -class _Llama2Tokenizer(_SentencePieceTokenizer): - """SentencePieceTokenizer-Megatron wrapper""" - - def __init__(self, model_file,): - super().__init__(model_file, vocab_extra_ids=0) - - def _initalize(self, vocab_extra_ids): - self._populate_vocab() - - # BOS / EOS token IDs - self.n_words: int = self.tokenizer.vocab_size() - self.bos_id: int = self.tokenizer.bos_id() - self.eos_id: int = self.tokenizer.eos_id() - self.pad_id: int = self.tokenizer.pad_id() - assert self.tokenizer.vocab_size() == self.tokenizer.get_piece_size() - - def tokenize(self, s: str, bos=True, eos=False): - '''Default args for text completion, not chat/dialog.''' - assert type(s) is str - t = self.tokenizer.encode(s) - if bos: - t = [self.bos_id] + t - if eos: - t = t + [self.eos_id] - return t - - def detokenize(self, ids): - return self.tokenizer.decode_ids(ids) - - @property - def cls(self): - return -1 - - @property - def sep(self): - return -1 - - @property - def mask(self): - return -1 - - @property - def eod(self): - return self.eos_id - - @property - def additional_special_tokens_ids(self): - return None - -class _NullTokenizer: - def __init__(self, vocab_size): - vocab_size = int(vocab_size) - self._eos_id = vocab_size - self.vocab_size = vocab_size+1 - - def tokenize(self, text): - return [int(x) for x in text.split(' ')] - - def detokenize(self, ids): - text = [str(x) for x in ids] - return ' '.join(text) - - @property - def cls(self): - return -1 - - @property - def sep(self): - return -1 - - @property - def mask(self): - return -1 - - @property - def eod(self): - return self._eos_id - - @property - def additional_special_tokens_ids(self): - return None diff --git a/megatron/training.py 
b/megatron/training.py deleted file mode 100644 index 823402b6ef6beabe0dab507da9349bfafda20451..0000000000000000000000000000000000000000 --- a/megatron/training.py +++ /dev/null @@ -1,1167 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - -"""Pretrain utilities.""" - -import gc -from datetime import datetime -import math -import logging -import sys -from .log_handler import CustomHandler -# Make default logging level INFO, but filter out all log messages not from MCore. -logging.basicConfig(handlers=[CustomHandler()], level=logging.INFO) -from .theoretical_memory_usage import report_theoretical_memory -import time -# The earliest we can measure the start time. -_TRAIN_START_TIME = time.time() -import torch - -from megatron import get_args -from megatron import get_signal_handler -from megatron import get_timers -from megatron import get_tensorboard_writer -from megatron import get_wandb_writer -from megatron import get_current_global_batch_size -from megatron import get_num_microbatches -from megatron import is_last_rank -from megatron import update_num_microbatches -from megatron.core import mpu, tensor_parallel -from megatron.core.utils import get_model_config -from megatron import print_rank_0 -from megatron import print_rank_last -from megatron.checkpointing import load_checkpoint -from megatron.checkpointing import save_checkpoint -from megatron.model import Float16Module -from megatron.model import GPTModel -from megatron.core.distributed import DistributedDataParallel as DDP -from megatron.core.distributed import finalize_model_grads -from megatron.core.enums import ModelType -from megatron.optimizer import get_megatron_optimizer -from megatron.initialize import initialize_megatron -from megatron.initialize import write_args_to_tensorboard -from megatron.initialize import set_jit_fusion_options -from megatron.optimizer_param_scheduler import OptimizerParamScheduler -from megatron.utils import check_adlr_autoresume_termination -from megatron.utils import unwrap_model -from megatron.data.data_samplers import build_pretraining_data_loader -from megatron.utils import calc_params_l2_norm -from megatron.core.pipeline_parallel import get_forward_backward_func -from megatron.utils import report_memory -from megatron.model.vision.knn_monitor import compute_feature_bank - - -def print_datetime(string): - """Note that this call will sync across all ranks.""" - torch.distributed.barrier() - time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S') - print_rank_0('[' + string + '] datetime: {} '.format(time_str)) - - -def num_floating_point_operations(args, batch_size): - if not args.group_query_attention: - args.num_query_groups = args.num_attention_heads - return ( - 60 - * batch_size - * args.seq_length - * args.num_layers - * args.hidden_size - * args.hidden_size - * ( - 1 - + (args.num_query_groups / (5 * args.num_attention_heads)) - + (args.seq_length / (5 * args.hidden_size)) - + (args.padded_vocab_size / (10 * args.num_layers * args.hidden_size)) - ) - ) - - -def pretrain(train_valid_test_dataset_provider, - model_provider, - model_type, - forward_step_func, - process_non_loss_data_func=None, - extra_args_provider=None, - args_defaults={}): - """Main training program. - - This function will run the followings in the order provided: - 1) initialize Megatron. - 2) setup model, optimizer and lr schedule using the model_provider. - 3) call train_val_test_data_provider to get train/val/test datasets. - 4) train the modle using the forward_step_func. 
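`num_floating_point_operations` above estimates the FLOPs of one training iteration as 60·B·s·L·h²·(1 + g/(5a) + s/(5h) + V/(10·L·h)), with g query groups, a attention heads and V the padded vocabulary size. A worked instance of that expression; the hyperparameters below are illustrative assumptions, not values from this diff:

```python
def flops_per_iteration(batch, seq, layers, hidden, heads, query_groups, vocab):
    # Same expression as num_floating_point_operations above.
    return (60 * batch * seq * layers * hidden * hidden
            * (1
               + query_groups / (5 * heads)
               + seq / (5 * hidden)
               + vocab / (10 * layers * hidden)))

# Assumed 7B-class settings for illustration.
tflops = flops_per_iteration(batch=256, seq=4096, layers=32, hidden=4096,
                             heads=32, query_groups=32, vocab=32000) / 1e12
print(f"{tflops:.0f} TFLOPs per iteration")
```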
- - Arguments: - train_valid_test_dataset_provider: a function that takes the size of - train/valid/test dataset and returns `train, valid, test` datasets. - model_provider: a function that returns a vanilla version of the - model. By vanilla we mean a simple model on cpu with no fp16 or ddp. - model_type: an enum that specifies the type of model being trained. - forward_step_func: a function that takes a `data iterator` and `model`, - and returns a `loss` scalar with a dictionary with key:values being - the info we would like to monitor during training, for example - `lm-loss: value`. We also require that this function add - `batch generator` to the timers class. - process_non_loss_data_func: a function to post process outputs of the - network. It can be used for dumping output tensors (e.g images) to - tensorboard. It takes `collected data`(list of tensors), - `current iteration index` and `tensorboard writer` as arguments. - extra_args_provider: a function that takes a parser and adds arguments - to it. It is used for programs to add their own arguments. - args_defaults: a dictionary from argument-name to argument-value. It - to set already parse arguments. - """ - - # Initalize and get arguments, timers, and Tensorboard writer. - initialize_megatron(extra_args_provider=extra_args_provider, - args_defaults=args_defaults) - # Set pytorch JIT layer fusion options and warmup JIT functions. - set_jit_fusion_options() - - # Adjust the startup time so it reflects the largest value. - # This will be closer to what scheduler will see (outside of - # image ... launches. - global _TRAIN_START_TIME - start_time_tensor = torch.cuda.DoubleTensor([_TRAIN_START_TIME]) - torch.distributed.all_reduce(start_time_tensor, - op=torch.distributed.ReduceOp.MIN) - _TRAIN_START_TIME = start_time_tensor.item() - print_rank_0('time to initialize megatron (seconds): {:.3f}'.format( - time.time() - _TRAIN_START_TIME)) - print_datetime('after megatron is initialized') - - args = get_args() - timers = get_timers() - - # Model, optimizer, and learning rate. - timers('model-and-optimizer-setup', log_level=0).start(barrier=True) - model, optimizer, opt_param_scheduler = setup_model_and_optimizer( - model_provider, model_type) - timers('model-and-optimizer-setup').stop() - print_datetime('after model, optimizer, and learning rate ' - 'scheduler are built') - config = get_model_config(model[0]) - - # Data stuff. - timers('train/valid/test-data-iterators-setup', log_level=0).start( - barrier=True) - if args.virtual_pipeline_model_parallel_size is not None: - train_data_iterator = [] - valid_data_iterator = [] - test_data_iterator = [] - for i in range(len(model)): - mpu.set_virtual_pipeline_model_parallel_rank(i) - iterators = build_train_valid_test_data_iterators( - train_valid_test_dataset_provider) - train_data_iterator.append(iterators[0]) - valid_data_iterator.append(iterators[1]) - test_data_iterator.append(iterators[2]) - else: - train_data_iterator, valid_data_iterator, test_data_iterator \ - = build_train_valid_test_data_iterators( - train_valid_test_dataset_provider) - timers('train/valid/test-data-iterators-setup').stop() - print_datetime('after dataloaders are built') - - # Print setup timing. 
- print_rank_0('done with setup ...') - timers.log(['model-and-optimizer-setup', - 'train/valid/test-data-iterators-setup'], barrier=True) - - if not args.skip_train: - print_rank_0('training ...') - - if args.dataloader_type == 'cyclic' and args.retro_add_retriever: - args.train_iters = args.retro_cyclic_train_iters - print_rank_0("retro cyclic train iters : %d" % args.train_iters) - - iteration = 0 - if args.do_train and args.train_iters > 0: - iteration = train(forward_step_func, - model, optimizer, opt_param_scheduler, - train_data_iterator, valid_data_iterator, - process_non_loss_data_func, config) - - print_datetime('after training is done') - - if args.save and iteration != 0: - save_checkpoint(iteration, model, optimizer, opt_param_scheduler) - else: - print_rank_0('skipping training (--skip-train is on) ...') - - iteration = args.iteration - - if args.do_valid: - prefix = f'iteration {iteration} on validation set' - evaluate_and_print_results(prefix, forward_step_func, - valid_data_iterator, model, - iteration, process_non_loss_data_func, config, - verbose=True, write_to_tensorboard=not args.skip_train) - - if args.do_test: - prefix = f'iteration {iteration} on test set' - evaluate_and_print_results(prefix, forward_step_func, - test_data_iterator, model, - iteration, process_non_loss_data_func, config, - verbose=True, write_to_tensorboard=not args.skip_train) - - -def update_train_iters(args): - - # For iteration-based training, we don't need to do anything - if args.train_iters: - return - - # Constant batch size with sample-based training. - if args.rampup_batch_size is None: - args.train_iters = args.train_samples // args.global_batch_size - - else: - # Sample based training with rampup batch size. - iterations = 0 - consumed_samples = 0 - # Rampup phase. - while consumed_samples <= int(args.rampup_batch_size[2]): - update_num_microbatches(consumed_samples, consistency_check=False) - consumed_samples += get_current_global_batch_size() - iterations += 1 - # Reset - update_num_microbatches(0, consistency_check=False) - # Constant phase - # Note that we throw away any partial last batch. - iterations += (args.train_samples - consumed_samples) // \ - args.global_batch_size - args.train_iters = iterations - - print_rank_0('setting training iterations to {}'.format(args.train_iters)) - - -def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap_with_ddp=True): - """Build the model.""" - args = get_args() - args.model_type = model_type - - # Build model. - if mpu.get_pipeline_model_parallel_world_size() > 1 and \ - args.virtual_pipeline_model_parallel_size is not None: - assert model_type != ModelType.encoder_and_decoder, \ - "Interleaved schedule not supported for model with both encoder and decoder" - model = [] - for i in range(args.virtual_pipeline_model_parallel_size): - mpu.set_virtual_pipeline_model_parallel_rank(i) - # Set pre_process and post_process only after virtual rank is set. 
- pre_process = mpu.is_pipeline_first_stage() - post_process = mpu.is_pipeline_last_stage() - this_model = model_provider_func( - pre_process=pre_process, - post_process=post_process - ) - this_model.model_type = model_type - model.append(this_model) - else: - pre_process = mpu.is_pipeline_first_stage() - post_process = mpu.is_pipeline_last_stage() - add_encoder = True - add_decoder = True - if model_type == ModelType.encoder_and_decoder: - if mpu.get_pipeline_model_parallel_world_size() > 1: - assert args.pipeline_model_parallel_split_rank is not None, \ - "Split rank needs to be specified for model with both encoder and decoder" - rank = mpu.get_pipeline_model_parallel_rank() - split_rank = args.pipeline_model_parallel_split_rank - world_size = mpu.get_pipeline_model_parallel_world_size() - pre_process = rank == 0 or rank == split_rank - post_process = (rank == (split_rank - 1)) or ( - rank == (world_size - 1)) - add_encoder = mpu.is_pipeline_stage_before_split() - add_decoder = mpu.is_pipeline_stage_after_split() - model = model_provider_func( - pre_process=pre_process, - post_process=post_process, - add_encoder=add_encoder, - add_decoder=add_decoder) - else: - model = model_provider_func( - pre_process=pre_process, - post_process=post_process - ) - model.model_type = model_type - - if not isinstance(model, list): - model = [model] - - # Disallow training and inference with Transformer Engine - # for non-GPT models - args.allow_transformer_engine = all([type(m) == GPTModel for m in model]) - # assert args.allow_transformer_engine or args.transformer_impl == 'local', \ - # 'Transformer Engine is only approved for GPT models' - - # Set tensor model parallel attributes if not set. - # Only parameters that are already tensor model parallel have these - # attributes set for them. We should make sure the default attributes - # are set for all params so the optimizer can use them. - for model_module in model: - for param in model_module.parameters(): - tensor_parallel.set_defaults_if_not_set_tensor_model_parallel_attributes(param) - - # Print number of parameters. - if mpu.get_data_parallel_rank() == 0: - print(' > number of parameters on (tensor, pipeline) ' - 'model parallel rank ({}, {}): {}'.format( - mpu.get_tensor_model_parallel_rank(), - mpu.get_pipeline_model_parallel_rank(), - sum([sum([p.nelement() for p in model_module.parameters()]) - for model_module in model])), flush=True) - - # GPU allocation. - for model_module in model: - model_module.cuda(torch.cuda.current_device()) - - # Fp16 conversion. - if args.fp16 or args.bf16: - model = [Float16Module(model_module, args) for model_module in model] - - if wrap_with_ddp: - config = get_model_config(model[0]) - model = [DDP(config, - model_chunk, - data_parallel_group=mpu.get_data_parallel_group(with_context_parallel=True), - accumulate_allreduce_grads_in_fp32=args.accumulate_allreduce_grads_in_fp32, - overlap_grad_reduce=args.overlap_grad_reduce, - use_distributed_optimizer=args.use_distributed_optimizer, - # Turn off bucketing for model_chunk 2 onwards, since communication for these - # model chunks is overlapped with compute anyway. - disable_bucketing=(model_chunk_idx > 0)) - for (model_chunk_idx, model_chunk) in enumerate(model)] - - # Broadcast params from data parallel src rank to other data parallel ranks. 
- if args.data_parallel_random_init: - for model_module in model: - model_module.broadcast_params() - - return model - - -def get_optimizer_param_scheduler(optimizer): - """Build the learning rate scheduler.""" - args = get_args() - - # Iteration-based training. - if args.train_iters: - if args.lr_decay_iters is None: - args.lr_decay_iters = args.train_iters - lr_decay_steps = args.lr_decay_iters * args.global_batch_size - wd_incr_steps = args.train_iters * args.global_batch_size - if args.lr_warmup_fraction is not None: - lr_warmup_steps = args.lr_warmup_fraction * lr_decay_steps - else: - lr_warmup_steps = args.lr_warmup_iters * args.global_batch_size - # Sample-based training. - elif args.train_samples: - # We need to set training iters for later use. Technically - # we need to adjust the training samples too (due to last - # batch being incomplete) but we leave it as is for now. - update_train_iters(args) - if args.lr_decay_samples is None: - args.lr_decay_samples = args.train_samples - lr_decay_steps = args.lr_decay_samples - wd_incr_steps = args.train_samples - if args.lr_warmup_fraction is not None: - lr_warmup_steps = args.lr_warmup_fraction * lr_decay_steps - else: - lr_warmup_steps = args.lr_warmup_samples - else: - raise Exception( - 'either train-iters or train-samples should be provided.') - - opt_param_scheduler = OptimizerParamScheduler( - optimizer, - init_lr=args.lr_warmup_init, - max_lr=args.lr, - min_lr=args.min_lr, - lr_warmup_steps=lr_warmup_steps, - lr_decay_steps=lr_decay_steps, - lr_decay_style=args.lr_decay_style, - start_wd=args.start_weight_decay, - end_wd=args.end_weight_decay, - wd_incr_steps=wd_incr_steps, - wd_incr_style=args.weight_decay_incr_style, - use_checkpoint_opt_param_scheduler=args.use_checkpoint_opt_param_scheduler, - override_opt_param_scheduler=args.override_opt_param_scheduler) - - return opt_param_scheduler - - -def setup_model_and_optimizer(model_provider_func, - model_type, - no_wd_decay_cond=None, - scale_lr_cond=None, - lr_mult=1.0): - """Setup model and optimizer.""" - args = get_args() - - model = get_model(model_provider_func, model_type) - unwrapped_model = unwrap_model(model) - - optimizer = get_megatron_optimizer(model, no_wd_decay_cond, - scale_lr_cond, lr_mult) - opt_param_scheduler = get_optimizer_param_scheduler(optimizer) - - if args.load is not None: - timers = get_timers() - timers('load-checkpoint', log_level=0).start(barrier=True) - args.iteration = load_checkpoint(model, optimizer, opt_param_scheduler) - timers('load-checkpoint').stop(barrier=True) - timers.log(['load-checkpoint']) - else: - args.iteration = 0 - - # get model without FP16 and/or DDP wrappers - if args.iteration == 0 and len(unwrapped_model) == 1 \ - and hasattr(unwrapped_model[0], 'init_state_dict_from_bert'): - print_rank_0("Initializing ICT from pretrained BERT model") - unwrapped_model[0].init_state_dict_from_bert() - if args.fp16: - optimizer.reload_model_params() - - return model, optimizer, opt_param_scheduler - - - -def train_step(forward_step_func, data_iterator, - model, optimizer, opt_param_scheduler, config): - """Single training step.""" - args = get_args() - timers = get_timers() - - # Set grad to zero. - for model_chunk in model: - # If using distributed optimizer, don't zero buffer here; zeroing of buffer is - # handled automatically by the optimizer after all-gathers finish. - # Otherwise, zero the buffer. - model_chunk.zero_grad_buffer(zero_buffer=(not args.use_distributed_optimizer)) - optimizer.zero_grad() - - # Forward pass. 
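Editorial aside: the deleted `get_optimizer_param_scheduler` above sizes the learning-rate schedule in samples rather than iterations. The sketch below restates that step arithmetic for the iteration-based case; the warmup/decay curve shown is an assumed linear-warmup plus cosine-decay stand-in for illustration only, the real shape is chosen by `OptimizerParamScheduler` and `--lr-decay-style`.

```python
# Illustrative sketch of the schedule sizing in the deleted get_optimizer_param_scheduler().
import math

def schedule_steps(train_iters, global_batch_size, lr_decay_iters=None,
                   lr_warmup_fraction=None, lr_warmup_iters=0):
    lr_decay_iters = lr_decay_iters if lr_decay_iters is not None else train_iters
    lr_decay_steps = lr_decay_iters * global_batch_size   # schedule measured in samples
    wd_incr_steps = train_iters * global_batch_size
    if lr_warmup_fraction is not None:
        lr_warmup_steps = lr_warmup_fraction * lr_decay_steps
    else:
        lr_warmup_steps = lr_warmup_iters * global_batch_size
    return lr_warmup_steps, lr_decay_steps, wd_incr_steps

def toy_lr(step, max_lr, min_lr, warmup_steps, decay_steps):
    # Assumed shape for illustration: linear warmup, then cosine decay to min_lr.
    if step < warmup_steps:
        return max_lr * step / max(1, warmup_steps)
    frac = min(1.0, (step - warmup_steps) / max(1, decay_steps - warmup_steps))
    return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * frac))

warmup, decay, _ = schedule_steps(train_iters=10000, global_batch_size=1024,
                                  lr_warmup_fraction=0.01)
print(warmup, decay, toy_lr(warmup, 3e-4, 3e-5, warmup, decay))
```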
- forward_backward_func = get_forward_backward_func() - losses_reduced = forward_backward_func( - forward_step_func=forward_step_func, - data_iterator=data_iterator, - model=model, - num_microbatches=get_num_microbatches(), - seq_length=args.seq_length, - micro_batch_size=args.micro_batch_size, - decoder_seq_length=args.decoder_seq_length, - forward_only=False) - - # Empty unused memory. - if args.empty_unused_memory_level >= 1: - torch.cuda.empty_cache() - - # Vision gradients. - if args.vision_pretraining and args.vision_pretraining_type == "dino": - unwrapped_model = unwrap_model(model[0]) - unwrapped_model.cancel_gradients_last_layer(args.curr_iteration) - - # Update parameters. - timers('optimizer', log_level=1).start(barrier=args.barrier_with_L1_time) - update_successful, grad_norm, num_zeros_in_grad = optimizer.step(args, timers) - timers('optimizer').stop() - - # Vision momentum. - if args.vision_pretraining and args.vision_pretraining_type == "dino": - unwrapped_model = unwrap_model(model[0]) - unwrapped_model.update_momentum(args.curr_iteration) - - # Update learning rate. - if update_successful: - increment = get_num_microbatches() * \ - args.micro_batch_size * \ - args.data_parallel_size - opt_param_scheduler.step(increment=increment) - skipped_iter = 0 - else: - skipped_iter = 1 - - # Empty unused memory. - if args.empty_unused_memory_level >= 2: - torch.cuda.empty_cache() - - if mpu.is_pipeline_last_stage(ignore_virtual=True): - # Average loss across microbatches. - loss_reduced = {} - for key in losses_reduced[0]: - losses_reduced_for_key = [x[key] for x in losses_reduced] - loss_reduced[key] = sum(losses_reduced_for_key) / len(losses_reduced_for_key) - return loss_reduced, skipped_iter, grad_norm, num_zeros_in_grad - return {}, skipped_iter, grad_norm, num_zeros_in_grad - - -def training_log(loss_dict, total_loss_dict, learning_rate, iteration, - loss_scale, report_memory_flag, skipped_iter, - grad_norm, params_norm, num_zeros_in_grad): - """Log training information such as losses, timing, ....""" - args = get_args() - timers = get_timers() - writer = get_tensorboard_writer() - wandb_writer = get_wandb_writer() - - # Advanced, skipped, and Nan iterations. - advanced_iters_key = 'advanced iterations' - skipped_iters_key = 'skipped iterations' - nan_iters_key = 'nan iterations' - # Advanced iterations. - if not skipped_iter: - total_loss_dict[advanced_iters_key] = total_loss_dict.get( - advanced_iters_key, 0) + 1 - else: - if advanced_iters_key not in total_loss_dict: - total_loss_dict[advanced_iters_key] = 0 - # Skipped iterations. - total_loss_dict[skipped_iters_key] = total_loss_dict.get( - skipped_iters_key, 0) + skipped_iter - # Update losses and set nan iterations - got_nan = False - for key in loss_dict: - if not skipped_iter: - total_loss_dict[key] = total_loss_dict.get( - key, torch.cuda.FloatTensor([0.0])) + loss_dict[key] - else: - value = loss_dict[key].float().sum().item() - is_nan = value == float('inf') or \ - value == -float('inf') or \ - value != value - got_nan = got_nan or is_nan - total_loss_dict[nan_iters_key] = total_loss_dict.get( - nan_iters_key, 0) + int(got_nan) - - # Logging. 
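Editorial aside: the tail of the deleted `train_step` above advances the scheduler by one global batch of samples (micro batch size times data-parallel size times the number of microbatches) and reports the mean of the per-microbatch loss dicts. A minimal stand-alone sketch of that bookkeeping, with made-up loss values:

```python
# Sketch of the bookkeeping at the end of the deleted train_step().
micro_batch_size, data_parallel_size, num_microbatches = 4, 8, 16
samples_per_step = micro_batch_size * data_parallel_size * num_microbatches  # global batch size

losses_reduced = [{'lm loss': 2.31}, {'lm loss': 2.27}, {'lm loss': 2.35}]    # made-up values
loss_reduced = {}
for key in losses_reduced[0]:
    vals = [d[key] for d in losses_reduced]
    loss_reduced[key] = sum(vals) / len(vals)

print(samples_per_step, loss_reduced)   # 512 {'lm loss': 2.31}
```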
- timers_to_log = [ - 'forward-backward', - 'forward-compute', - 'backward-compute', - 'batch-generator', - 'forward-recv', - 'forward-send', - 'backward-recv', - 'backward-send', - 'forward-send-forward-recv', - 'forward-send-backward-recv', - 'backward-send-forward-recv', - 'backward-send-backward-recv', - 'forward-backward-send-forward-backward-recv', - 'layernorm-grads-all-reduce', - 'embedding-grads-all-reduce', - 'all-grads-sync', - 'params-all-gather', - 'optimizer-copy-to-main-grad', - 'optimizer-unscale-and-check-inf', - 'optimizer-clip-main-grad', - 'optimizer-count-zeros', - 'optimizer-inner-step', - 'optimizer-copy-main-to-model-params', - 'optimizer'] - - # Calculate batch size. - batch_size = args.micro_batch_size * args.data_parallel_size * \ - get_num_microbatches() - - total_iterations = total_loss_dict[advanced_iters_key] + \ - total_loss_dict[skipped_iters_key] - - # Tensorboard values. - # Timer requires all the ranks to call. - if args.log_timers_to_tensorboard and \ - (iteration % args.tensorboard_log_interval == 0): - timers.write(timers_to_log, writer, iteration, - normalizer=total_iterations) - if writer and (iteration % args.tensorboard_log_interval == 0): - if wandb_writer: - wandb_writer.log({'samples vs steps': args.consumed_train_samples}, - iteration) - if args.log_learning_rate_to_tensorboard: - writer.add_scalar('learning-rate', learning_rate, iteration) - writer.add_scalar('learning-rate vs samples', learning_rate, - args.consumed_train_samples) - if wandb_writer: - wandb_writer.log({'learning-rate': learning_rate}, iteration) - if args.log_batch_size_to_tensorboard: - writer.add_scalar('batch-size', batch_size, iteration) - writer.add_scalar('batch-size vs samples', batch_size, - args.consumed_train_samples) - if wandb_writer: - wandb_writer.log({'batch-size': batch_size}, iteration) - for key in loss_dict: - writer.add_scalar(key , loss_dict[key], iteration) - writer.add_scalar(key + ' vs samples', loss_dict[key], - args.consumed_train_samples) - if wandb_writer: - wandb_writer.log({key: loss_dict[key]}, iteration) - if args.log_loss_scale_to_tensorboard: - writer.add_scalar('loss-scale', loss_scale, iteration) - writer.add_scalar('loss-scale vs samples', loss_scale, - args.consumed_train_samples) - if wandb_writer: - wandb_writer.log({'loss-scale': loss_scale}, iteration) - if args.log_world_size_to_tensorboard: - writer.add_scalar('world-size', args.world_size, iteration) - writer.add_scalar('world-size vs samples', args.world_size, - args.consumed_train_samples) - if wandb_writer: - wandb_writer.log({'world-size': args.world_size}, iteration) - if grad_norm is not None: - writer.add_scalar('grad-norm', grad_norm, iteration) - writer.add_scalar('grad-norm vs samples', grad_norm, - args.consumed_train_samples) - if wandb_writer: - wandb_writer.log({'grad-norm': grad_norm}, iteration) - if num_zeros_in_grad is not None: - writer.add_scalar('num-zeros', num_zeros_in_grad, iteration) - writer.add_scalar('num-zeros vs samples', num_zeros_in_grad, - args.consumed_train_samples) - if wandb_writer: - wandb_writer.log({'num-zeros': num_zeros_in_grad}, iteration) - if params_norm is not None: - writer.add_scalar('params-norm', params_norm, iteration) - writer.add_scalar('params-norm vs samples', params_norm, - args.consumed_train_samples) - if wandb_writer: - wandb_writer.log({'params-norm': params_norm}, iteration) - if args.log_memory_to_tensorboard: - mem_stats = torch.cuda.memory_stats() - writer.add_scalar( - "mem-reserved-bytes", - 
mem_stats["reserved_bytes.all.current"], - iteration, - ) - writer.add_scalar( - "mem-allocated-bytes", - mem_stats["allocated_bytes.all.current"], - iteration, - ) - writer.add_scalar( - "mem-allocated-count", - mem_stats["allocation.all.current"], - iteration, - ) - - if iteration % args.log_interval == 0: - elapsed_time = timers('interval-time').elapsed(barrier=True) - elapsed_time_per_iteration = elapsed_time / total_iterations - throughput = num_floating_point_operations(args, batch_size) / ( - elapsed_time_per_iteration * 10**12 * args.world_size) - if args.log_timers_to_tensorboard: - if writer: - writer.add_scalar('iteration-time', - elapsed_time_per_iteration, iteration) - if wandb_writer: - wandb_writer.log({'iteration-time': elapsed_time_per_iteration}, - iteration) - log_string = ' iteration {:8d}/{:8d} |'.format( - iteration, args.train_iters) - log_string += ' consumed samples: {:12d} |'.format( - args.consumed_train_samples) - log_string += ' elapsed time per iteration (ms): {:.1f} |'.format( - elapsed_time_per_iteration * 1000.0) - if args.log_throughput: - log_string += f' throughput per GPU (TFLOP/s/GPU): {throughput:.1f} |' - if args.log_timers_to_tensorboard: - if writer: - writer.add_scalar('throughput', throughput, iteration) - if wandb_writer: - wandb_writer.log({'throughput': throughput}, iteration) - log_string += ' learning rate: {:.3E} |'.format(learning_rate) - log_string += ' global batch size: {:5d} |'.format(batch_size) - for key in total_loss_dict: - if key not in [advanced_iters_key, skipped_iters_key, - nan_iters_key]: - avg = total_loss_dict[key].item() / \ - float(max(1, total_loss_dict[advanced_iters_key])) - if avg > 0.0: - log_string += ' {}: {:.6E} |'.format(key, avg) - total_loss_dict[key] = torch.cuda.FloatTensor([0.0]) - log_string += ' loss scale: {:.1f} |'.format(loss_scale) - if grad_norm is not None: - log_string += ' grad norm: {:.3f} |'.format(grad_norm) - if num_zeros_in_grad is not None: - log_string += ' num zeros: {:.1f} |'.format(num_zeros_in_grad) - if params_norm is not None: - log_string += ' params norm: {:.3f} |'.format(params_norm) - log_string += ' number of skipped iterations: {:3d} |'.format( - total_loss_dict[skipped_iters_key]) - log_string += ' number of nan iterations: {:3d} |'.format( - total_loss_dict[nan_iters_key]) - total_loss_dict[advanced_iters_key] = 0 - total_loss_dict[skipped_iters_key] = 0 - total_loss_dict[nan_iters_key] = 0 - print_rank_last(log_string) - if report_memory_flag and learning_rate > 0.: - # Report memory after optimizer state has been initialized. - if torch.distributed.get_rank() == 0: - num_microbatches = get_num_microbatches() - report_theoretical_memory(args, num_microbatches=num_microbatches, verbose=True) - report_memory('(after {} iterations)'.format(iteration)) - report_memory_flag = False - timers.log(timers_to_log, normalizer=args.log_interval) - - return report_memory_flag - - -def save_checkpoint_and_time(iteration, model, optimizer, opt_param_scheduler): - timers = get_timers() - # Extra barrier is added to make sure - # all ranks report the max time. 
- timers('save-checkpoint', log_level=0).start(barrier=True) - save_checkpoint(iteration, model, optimizer, opt_param_scheduler) - timers('save-checkpoint').stop(barrier=True) - timers.log(['save-checkpoint']) - - -def train(forward_step_func, model, optimizer, opt_param_scheduler, - train_data_iterator, valid_data_iterator, - process_non_loss_data_func, config): - """Train the model function.""" - args = get_args() - timers = get_timers() - - # Write args to tensorboard - write_args_to_tensorboard() - - # Turn on training mode which enables dropout. - for model_module in model: - model_module.train() - - # Tracking loss. - total_loss_dict = {} - - # Iterations. - iteration = args.iteration - - # Setup some training config params - config.grad_scale_func = optimizer.scale_loss - config.timers = timers - if isinstance(model[0], DDP) and args.overlap_grad_reduce: - assert config.no_sync_func is None, \ - ('When overlap_grad_reduce is True, config.no_sync_func must be None; ' - 'a custom no_sync_func is not supported when overlapping grad-reduce') - config.no_sync_func = [model_chunk.no_sync for model_chunk in model] - if len(model) == 1: - config.no_sync_func = config.no_sync_func[0] - if args.delay_grad_reduce: - config.grad_sync_func = [model_chunk.start_grad_sync for model_chunk in model] - if len(model) == 1: - config.grad_sync_func = config.grad_sync_func[0] - if args.overlap_param_gather and args.delay_param_gather: - config.param_sync_func = [lambda x: optimizer.finish_param_sync(model_index, x) - for model_index in range(len(model))] - if len(model) == 1: - config.param_sync_func = config.param_sync_func[0] - config.finalize_model_grads_func = finalize_model_grads - - timers('interval-time', log_level=0).start(barrier=True) - print_datetime('before the start of training step') - report_memory_flag = True - exit = False - - if args.manual_gc: - # Disable the default garbage collector and perform the collection manually. - # This is to align the timing of garbage collection across ranks. - assert args.manual_gc_interval >= 0, \ - 'Manual garbage collection interval should be laerger than or equal to 0.' - gc.disable() - gc.collect() - - while iteration < args.train_iters: - if args.profile and \ - iteration == args.profile_step_start and \ - torch.distributed.get_rank() in args.profile_ranks: - torch.cuda.cudart().cudaProfilerStart() - torch.autograd.profiler.emit_nvtx(record_shapes=True).__enter__() - - update_num_microbatches(args.consumed_train_samples) - args.curr_iteration = iteration - loss_dict, skipped_iter, grad_norm, num_zeros_in_grad = \ - train_step(forward_step_func, - train_data_iterator, - model, - optimizer, - opt_param_scheduler, - config) - iteration += 1 - args.consumed_train_samples += mpu.get_data_parallel_world_size() * \ - args.micro_batch_size * \ - get_num_microbatches() - - # Logging. 
- loss_scale = optimizer.get_loss_scale().item() - params_norm = None - if args.log_params_norm: - params_norm = calc_params_l2_norm(model) - report_memory_flag = training_log(loss_dict, total_loss_dict, - optimizer.param_groups[0]['lr'], - iteration, loss_scale, - report_memory_flag, skipped_iter, - grad_norm, params_norm, num_zeros_in_grad) - - # Autoresume - if args.adlr_autoresume and \ - (iteration % args.adlr_autoresume_interval == 0): - check_adlr_autoresume_termination(iteration, model, optimizer, - opt_param_scheduler) - - # Evaluation - if args.eval_interval and iteration % args.eval_interval == 0 and \ - args.do_valid: - timers('interval-time').stop() - if args.manual_gc and args.manual_gc_eval: - # Collect all objects. - gc.collect() - prefix = 'iteration {}'.format(iteration) - evaluate_and_print_results(prefix, forward_step_func, - valid_data_iterator, model, - iteration, process_non_loss_data_func, - config, False) - if args.manual_gc and args.manual_gc_eval: - # Collect only the objects created and used in evaluation. - gc.collect(generation=0) - timers('interval-time', log_level=0).start(barrier=True) - - # Checkpointing - saved_checkpoint = False - if args.exit_signal_handler: - signal_handler = get_signal_handler() - if any(signal_handler.signals_received()): - save_checkpoint_and_time(iteration, model, optimizer, - opt_param_scheduler) - print_datetime('exiting program after receiving SIGTERM.') - exit = True - break - - if args.save and args.save_interval and \ - iteration % args.save_interval == 0: - timers('interval-time').stop() - save_checkpoint_and_time(iteration, model, optimizer, - opt_param_scheduler) - saved_checkpoint = True - timers('interval-time', log_level=0).start(barrier=True) - - # Exiting based on duration - if args.exit_duration_in_mins: - train_time = (time.time() - _TRAIN_START_TIME) / 60.0 - done_cuda = torch.cuda.IntTensor( - [train_time > args.exit_duration_in_mins]) - torch.distributed.all_reduce( - done_cuda, op=torch.distributed.ReduceOp.MAX) - done = done_cuda.item() - if done: - if not saved_checkpoint: - save_checkpoint_and_time(iteration, model, optimizer, - opt_param_scheduler) - print_datetime('exiting program after {} minutes'.format(train_time)) - exit = True - break - - # Exiting based on iterations - if args.exit_interval and iteration % args.exit_interval == 0: - if args.save and not saved_checkpoint: - save_checkpoint_and_time(iteration, model, optimizer, - opt_param_scheduler) - torch.distributed.barrier() - print_datetime('exiting program at iteration {}'.format(iteration)) - exit = True - break - - if args.profile and \ - iteration == args.profile_step_end and \ - torch.distributed.get_rank() in args.profile_ranks: - torch.cuda.cudart().cudaProfilerStop() - - if args.manual_gc: - if args.manual_gc_interval != 0 and iteration % args.manual_gc_interval == 0: - gc.collect() - - # Flush TensorBoard and WandB writers. - writer = get_tensorboard_writer() - if writer: - writer.flush() - wandb_writer = get_wandb_writer() - if wandb_writer: - wandb_writer.finish() - - # If any exit conditions (signal handler, duration, iterations) have been reached, exit. 
- if exit: - sys.exit() - - return iteration - - -def evaluate(forward_step_func, - data_iterator, - model, - process_non_loss_data_func, - config, - verbose=False): - """Evaluation.""" - args = get_args() - timers = get_timers() - - timers('evaluate', log_level=0).start(barrier=True) - - if args.vision_pretraining and args.vision_pretraining_type == "dino": - compute_feature_bank(model) - - # Turn on evaluation mode which disables dropout. - for model_module in model: - model_module.eval() - - total_loss_dict = {} - - # make validation batch size independent from training batch size - eval_batch_size = args.global_batch_size - eval_num_microbatches = eval_batch_size // \ - (args.micro_batch_size * args.data_parallel_size) - - with torch.no_grad(): - iteration = 0 - if verbose: - print_rank_0(f'Evaluating on {args.eval_iters * eval_batch_size} samples') - while iteration < args.eval_iters: - iteration += 1 - if verbose: - print_rank_0(f'Evaluating iter {iteration}/{args.eval_iters}') - - forward_backward_func = get_forward_backward_func() - # Don't care about timing during evaluation - config.timers = None - loss_dicts = forward_backward_func( - forward_step_func=forward_step_func, - data_iterator=data_iterator, - model=model, - num_microbatches=eval_num_microbatches, - seq_length=args.seq_length, - micro_batch_size=args.micro_batch_size, - decoder_seq_length=args.decoder_seq_length, - forward_only=True) - config.timers = get_timers() - - # Empty unused memory - if args.empty_unused_memory_level >= 1: - torch.cuda.empty_cache() - - if mpu.is_pipeline_last_stage(ignore_virtual=True): - # Reduce across processes. - for loss_dict in loss_dicts: - for key in loss_dict: - total_loss_dict[key] = total_loss_dict.get( - key, torch.cuda.FloatTensor([0.0])) + loss_dict[key] - - args.consumed_valid_samples += eval_batch_size - - if args.exit_duration_in_mins: - train_time = (time.time() - _TRAIN_START_TIME) / 60.0 - done_cuda = torch.cuda.IntTensor( - [train_time > args.exit_duration_in_mins]) - torch.distributed.all_reduce( - done_cuda, op=torch.distributed.ReduceOp.MAX) - done = done_cuda.item() - if done: - print_rank_0('Exiting during evaluation, timelimit reached') - return None, None, True - - collected_non_loss_data = None - if process_non_loss_data_func is not None and is_last_rank(): - collected_non_loss_data = forward_backward_func( - forward_step_func=forward_step_func, - data_iterator=data_iterator, - model=model, - num_microbatches=get_num_microbatches(), - seq_length=args.seq_length, - micro_batch_size=args.micro_batch_size, - decoder_seq_length=args.decoder_seq_length, - forward_only=True, - collect_non_loss_data=True) - - # Move model back to the train mode. 
- for model_module in model: - model_module.train() - - for key in total_loss_dict: - total_loss_dict[key] /= args.eval_iters * eval_num_microbatches - - timers('evaluate').stop() - timers.log(['evaluate']) - - return total_loss_dict, collected_non_loss_data, False - -def evaluate_and_print_results(prefix, forward_step_func, - data_iterator, model, - iteration, process_non_loss_data_func, config, - verbose=False, write_to_tensorboard=True): - """Helper function to evaluate and dump results on screen.""" - args = get_args() - if write_to_tensorboard: - writer = get_tensorboard_writer() - else: - writer = None - - wandb_writer = get_wandb_writer() - - total_loss_dict, collected_non_loss_data, timelimit = evaluate( - forward_step_func, data_iterator, model, - process_non_loss_data_func, config, verbose) - # Timelimit hit during evaluation - if timelimit: - return - string = ' validation loss at {} | '.format(prefix) - for key in total_loss_dict: - string += '{} value: {:.6E} | '.format(key, total_loss_dict[key].item()) - ppl = math.exp(min(20, total_loss_dict[key].item())) - string += '{} PPL: {:.6E} | '.format(key, ppl) - if writer: - writer.add_scalar('{} validation'.format(key), - total_loss_dict[key].item(), - iteration) - writer.add_scalar('{} validation vs samples'.format(key), - total_loss_dict[key].item(), - args.consumed_train_samples) - if args.log_validation_ppl_to_tensorboard: - writer.add_scalar('{} validation ppl'.format(key), ppl, - iteration) - writer.add_scalar('{} validation ppl vs samples'.format(key), - ppl, args.consumed_train_samples) - if wandb_writer and is_last_rank(): - wandb_writer.log({ - '{} validation'.format(key): total_loss_dict[key].item()}, - iteration) - - if process_non_loss_data_func is not None and writer and is_last_rank(): - process_non_loss_data_func(collected_non_loss_data, iteration, writer) - - length = len(string) + 1 - print_rank_last('-' * length) - print_rank_last(string) - print_rank_last('-' * length) - - -def cyclic_iter(iter): - while True: - for x in iter: - yield x - - -def build_train_valid_test_datasets(build_train_valid_test_datasets_provider): - """Build pretraining datasets.""" - - args = get_args() - - # Number of train/valid/test samples. - if args.train_samples: - train_samples = args.train_samples - else: - train_samples = args.train_iters * args.global_batch_size - eval_iters = (args.train_iters // args.eval_interval + 1) * \ - args.eval_iters - test_iters = args.eval_iters - train_val_test_num_samples = [train_samples, - eval_iters * args.global_batch_size, - test_iters * args.global_batch_size] - print_rank_0(' > datasets target sizes (minimum size):') - print_rank_0(' train: {}'.format(train_val_test_num_samples[0])) - print_rank_0(' validation: {}'.format(train_val_test_num_samples[1])) - print_rank_0(' test: {}'.format(train_val_test_num_samples[2])) - - # Build the datasets. - return build_train_valid_test_datasets_provider(train_val_test_num_samples) - - -def build_train_valid_test_data_loaders( - build_train_valid_test_datasets_provider): - """Build pretraining data loaders.""" - - args = get_args() - - (train_dataloader, valid_dataloader, test_dataloader) = (None, None, None) - - print_rank_0('> building train, validation, and test datasets ...') - - # Backward compatibility, assume fixed batch size. 
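Editorial aside: two small pieces of arithmetic from the deleted code above, restated as a runnable sketch, are the capped perplexity that `evaluate_and_print_results` derives from a validation loss and the minimum train/valid/test sample targets that `build_train_valid_test_datasets` requests for iteration-based training. The argument values in the example are assumptions.

```python
import math

# Perplexity as reported by the deleted evaluate_and_print_results(): the loss is
# capped at 20 before exponentiation to avoid overflow in the log output.
def validation_ppl(loss):
    return math.exp(min(20, loss))

# Minimum dataset sizes requested by the deleted build_train_valid_test_datasets().
def dataset_targets(train_iters, global_batch_size, eval_interval, eval_iters):
    train_samples = train_iters * global_batch_size
    valid_iters = (train_iters // eval_interval + 1) * eval_iters
    test_iters = eval_iters
    return [train_samples, valid_iters * global_batch_size, test_iters * global_batch_size]

print(validation_ppl(2.31))                                          # ~10.07
print(dataset_targets(10000, 1024, eval_interval=1000, eval_iters=10))
```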
- if args.iteration > 0 and args.consumed_train_samples == 0: - assert args.train_samples is None, \ - 'only backward compatiblity support for iteration-based training' - args.consumed_train_samples = args.iteration * args.global_batch_size - if args.iteration > 0 and args.consumed_valid_samples == 0: - if args.train_samples is None: - args.consumed_valid_samples = (args.iteration // args.eval_interval) * \ - args.eval_iters * args.global_batch_size - - # Rely on distributed-aware core datasets, temporary - is_distributed = getattr(build_train_valid_test_datasets_provider, "is_distributed", False) - - # Construct the data pipeline - if is_distributed or mpu.get_tensor_model_parallel_rank() == 0: - - # Build datasets. - train_ds, valid_ds, test_ds = build_train_valid_test_datasets( - build_train_valid_test_datasets_provider) - # Build dataloders. - train_dataloader = build_pretraining_data_loader( - train_ds, args.consumed_train_samples) - if args.skip_train: - valid_dataloader = build_pretraining_data_loader(valid_ds, 0) - else: - valid_dataloader = build_pretraining_data_loader( - valid_ds, args.consumed_valid_samples) - test_dataloader = build_pretraining_data_loader(test_ds, 0) - - # Flags to know if we need to do training/validation/testing. - do_train = train_dataloader is not None and args.train_iters > 0 - do_valid = valid_dataloader is not None and args.eval_iters > 0 - do_test = test_dataloader is not None and args.eval_iters > 0 - flags = torch.cuda.LongTensor( - [int(do_train), int(do_valid), int(do_test)]) - else: - flags = torch.cuda.LongTensor([0, 0, 0]) - - torch.distributed.broadcast(flags, 0) - - args.do_train = getattr(args, "do_train", False) or flags[0].item() - args.do_valid = getattr(args, "do_valid", False) or flags[1].item() - args.do_test = getattr(args, "do_test", False) or flags[2].item() - - return train_dataloader, valid_dataloader, test_dataloader - - -def build_train_valid_test_data_iterators( - build_train_valid_test_datasets_provider): - """Build pretraining data iterators.""" - - args = get_args() - - # Build loaders. - train_dataloader, valid_dataloader, test_dataloader = \ - build_train_valid_test_data_loaders( - build_train_valid_test_datasets_provider) - - # Build iterators. - dl_type = args.dataloader_type - assert dl_type in ['single', 'cyclic'] - - if train_dataloader is not None: - train_data_iterator = iter(train_dataloader) if dl_type == 'single' \ - else iter(cyclic_iter(train_dataloader)) - else: - train_data_iterator = None - - if valid_dataloader is not None: - valid_data_iterator = iter(valid_dataloader) if dl_type == 'single' \ - else iter(cyclic_iter(valid_dataloader)) - else: - valid_data_iterator = None - - if test_dataloader is not None: - test_data_iterator = iter(test_dataloader) if dl_type == 'single' \ - else iter(cyclic_iter(test_dataloader)) - else: - test_data_iterator = None - - return train_data_iterator, valid_data_iterator, test_data_iterator diff --git a/megatron/utils.py b/megatron/utils.py deleted file mode 100644 index af9b4a07e08735dcb6b362a6cad4ee7b2a45a8db..0000000000000000000000000000000000000000 --- a/megatron/utils.py +++ /dev/null @@ -1,271 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
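Editorial aside: the deleted `build_train_valid_test_data_iterators` above wraps each dataloader either in plain `iter()` (the `single` dataloader type, one pass over the data) or in `cyclic_iter()` (the `cyclic` type, which restarts forever). A tiny self-contained sketch of that choice, using a plain list as a stand-in for a real DataLoader:

```python
# Sketch of the iterator wrapping done by the deleted build_train_valid_test_data_iterators().
def cyclic_iter(it):
    while True:
        for x in it:
            yield x

def make_data_iterator(dataloader, dataloader_type):
    assert dataloader_type in ('single', 'cyclic')
    if dataloader is None:
        return None
    return iter(dataloader) if dataloader_type == 'single' else iter(cyclic_iter(dataloader))

toy_loader = [{'tokens': [1, 2]}, {'tokens': [3, 4]}]   # stand-in for a real DataLoader
it = make_data_iterator(toy_loader, 'cyclic')
print([next(it) for _ in range(5)])                      # wraps around after two batches
```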
- -"""General utilities.""" - -import sys - -import torch - -try: - from apex.multi_tensor_apply import multi_tensor_applier -except ImportError: - multi_tensor_applier = None - -try: - import amp_C -except ImportError: - amp_C = None - -from megatron import ( - get_args, - get_adlr_autoresume, -) -from megatron.core import DistributedDataParallel as DDP -from megatron.core import mpu -from megatron.core.tensor_parallel import param_is_not_tensor_parallel_duplicate -from megatron.model import Float16Module -from megatron.model.module import param_is_not_shared - - -ALL_MODULE_WRAPPER_CLASSNAMES = (DDP, Float16Module) - - -def unwrap_model(model, module_instances=ALL_MODULE_WRAPPER_CLASSNAMES): - return_list = True - if not isinstance(model, list): - model = [model] - return_list = False - unwrapped_model = [] - for model_module in model: - while isinstance(model_module, module_instances): - model_module = model_module.module - unwrapped_model.append(model_module) - if not return_list: - return unwrapped_model[0] - return unwrapped_model - - -def calc_params_l2_norm(model): - """Calculate l2 norm of parameters """ - args = get_args() - if not isinstance(model, list): - model = [model] - # Remove duplicate params. - params_data = [] - for model_ in model: - for param in model_.parameters(): - is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param) - if mpu.get_expert_model_parallel_rank() > 0: - if not getattr(param, 'allreduce', True) and is_not_tp_duplicate: - assert param_is_not_shared(param) - params_data.append(param.data.float() if args.bf16 else param.data) - else: - is_not_shared = param_is_not_shared(param) - if is_not_shared and is_not_tp_duplicate: - params_data.append(param.data.float() if args.bf16 else param.data) - - # Check the availability of apex - assert multi_tensor_applier is not None and amp_C is not None, \ - "apex is not available, please install it from https://github.com/NVIDIA/apex" - - # Calculate norm - dummy_overflow_buf = torch.cuda.IntTensor([0]) - norm, _ = multi_tensor_applier( - amp_C.multi_tensor_l2norm, - dummy_overflow_buf, - [params_data], - False # no per-parameter norm - ) - norm_2 = norm * norm - if mpu.get_expert_model_parallel_world_size() == 1: - # Sum across all model-parallel GPUs(tensor + pipeline). - torch.distributed.all_reduce(norm_2, - op=torch.distributed.ReduceOp.SUM, - group=mpu.get_model_parallel_group()) - else: - # Sum across tensor, pipeline and expert model-parallel GPUs. 
- torch.distributed.all_reduce(norm_2, - op=torch.distributed.ReduceOp.SUM, - group=mpu.get_tensor_and_expert_parallel_group()) - torch.distributed.all_reduce(norm_2, - op=torch.distributed.ReduceOp.SUM, - group=mpu.get_pipeline_model_parallel_group()) - return norm_2.item() ** 0.5 - - -def average_losses_across_data_parallel_group(losses): - """Reduce a tensor of losses across all GPUs.""" - averaged_losses = torch.cat( - [loss.clone().detach().view(1) for loss in losses]) - torch.distributed.all_reduce(averaged_losses, - group=mpu.get_data_parallel_group()) - averaged_losses = averaged_losses / \ - torch.distributed.get_world_size(group=mpu.get_data_parallel_group()) - - return averaged_losses - - -def report_memory(name): - """Simple GPU memory report.""" - mega_bytes = 1024.0 * 1024.0 - string = name + ' memory (MB)' - string += ' | allocated: {}'.format( - torch.cuda.memory_allocated() / mega_bytes) - string += ' | max allocated: {}'.format( - torch.cuda.max_memory_allocated() / mega_bytes) - string += ' | reserved: {}'.format( - torch.cuda.memory_reserved() / mega_bytes) - string += ' | max reserved: {}'.format( - torch.cuda.max_memory_reserved() / mega_bytes) - if mpu.get_data_parallel_rank() == 0: - print("[Rank {}] {}".format(torch.distributed.get_rank(), string), - flush=True) - - -def print_params_min_max_norm(optimizer, iteration): - """Print min, max, and norm of all parameters.""" - index = 0 - rank = torch.distributed.get_rank() - string = 'iteration, rank, index, tensor-model-parallel, min, max, norm\n' - optimizer_ = optimizer.optimizer - for param_group in optimizer_.param_groups: - for param in param_group['params']: - index += 1 - min_ = param.data.min() - max_ = param.data.max() - norm = torch.linalg.norm(param.data) - string += '{:7d}, {:4d}, {:4d}, {:2d}, '.format( - iteration, rank, index, int(param.tensor_model_parallel)) - string += '{:.6E}, {:.6E}, {:.6E}\n'.format(min_, max_, norm) - print(string, flush=True) - - -def check_adlr_autoresume_termination(iteration, model, - optimizer, opt_param_scheduler): - """Check for autoresume signal and exit if it is received.""" - from megatron.checkpointing import save_checkpoint - - args = get_args() - autoresume = get_adlr_autoresume() - # Add barrier to ensure consistnecy. - torch.distributed.barrier() - if autoresume.termination_requested(): - if args.save: - save_checkpoint(iteration, model, optimizer, opt_param_scheduler) - print_rank_0(">>> autoresume termination request found!") - if torch.distributed.get_rank() == 0: - autoresume.request_resume() - print_rank_0(">>> training terminated. Returning") - sys.exit(0) - - -def get_ltor_masks_and_position_ids(data, - eod_token, - reset_position_ids, - reset_attention_mask, - eod_mask_loss): - """Build masks and position id for left to right model.""" - - # Extract batch size and sequence length. - micro_batch_size, seq_length = data.size() - - # Attention mask (lower triangular). - if reset_attention_mask: - att_mask_batch = micro_batch_size - else: - att_mask_batch = 1 - attention_mask = torch.tril(torch.ones( - (att_mask_batch, seq_length, seq_length), device=data.device)).view( - att_mask_batch, 1, seq_length, seq_length) - - # Loss mask. - loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device) - if eod_mask_loss: - loss_mask[data == eod_token] = 0.0 - - # Position ids. 
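Editorial aside: the mask construction in the deleted `get_ltor_masks_and_position_ids` above can be tried in isolation. The sketch below builds the lower-triangular attention mask, the EOD-aware loss mask, and the plain position ids for a toy batch; it deliberately omits the optional `reset_position_ids` / `reset_attention_mask` handling, and the token values are made up.

```python
import torch

# Sketch of the mask construction in the deleted get_ltor_masks_and_position_ids(),
# without the optional reset handling.
def ltor_masks(data, eod_token, eod_mask_loss=True):
    micro_batch_size, seq_length = data.size()
    # Lower-triangular (causal) mask shared across the batch, shape [1, 1, s, s].
    attention_mask = torch.tril(
        torch.ones((1, seq_length, seq_length), device=data.device)
    ).view(1, 1, seq_length, seq_length)
    # Loss mask: train on every position except (optionally) the EOD tokens.
    loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device)
    if eod_mask_loss:
        loss_mask[data == eod_token] = 0.0
    # Position ids simply count 0..s-1 for every sample.
    position_ids = torch.arange(seq_length, dtype=torch.long,
                                device=data.device).unsqueeze(0).expand_as(data)
    # As in the original, the attention mask ends up as a boolean "masked-out" indicator.
    attention_mask = attention_mask < 0.5
    return attention_mask, loss_mask, position_ids

tokens = torch.tensor([[5, 7, 0, 9, 2, 0]])   # 0 plays the role of the EOD token here
att, loss, pos = ltor_masks(tokens, eod_token=0)
print(att.shape, loss.tolist(), pos.tolist())
```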
- position_ids = torch.arange(seq_length, dtype=torch.long, - device=data.device) - position_ids = position_ids.unsqueeze(0).expand_as(data) - # We need to clone as the ids will be modifed based on batch index. - if reset_position_ids: - position_ids = position_ids.clone() - - if reset_position_ids or reset_attention_mask: - # Loop through the batches: - for b in range(micro_batch_size): - - # Find indecies where EOD token is. - eod_index = position_ids[b, data[b] == eod_token] - # Detach indecies from positions if going to modify positions. - if reset_position_ids: - eod_index = eod_index.clone() - - # Loop through EOD indecies: - prev_index = 0 - for j in range(eod_index.size()[0]): - i = eod_index[j] - # Mask attention loss. - if reset_attention_mask: - attention_mask[b, 0, (i + 1):, :(i + 1)] = 0 - # Reset positions. - if reset_position_ids: - position_ids[b, (i + 1):] -= (i + 1 - prev_index) - prev_index = i + 1 - - # Convert attention mask to binary: - attention_mask = (attention_mask < 0.5) - - return attention_mask, loss_mask, position_ids - - -def get_batch_on_this_cp_rank(batch): - """ Slice batch input along sequence dimension into multiple chunks, - which are parallelized across GPUs in a context parallel group. - """ - - # With causal masking, each token only attends to its prior tokens. Simply split - # sequence into CP chunks can result in severe load imbalance. That's to say, chunks - # at the end of sequence have bigger workload than others. To address this issue, - # we split sequence into 2*CP ranks. Assuming CP=2, we then get 4 chunks, chunk_0 - # and chunk_3 are assigned to GPU0, chunk_1 and chunk_2 are assigned to GPU1, so - # that we can get balanced workload among GPUs in a context parallel group. - args = get_args() - cp_size = args.context_parallel_size - if cp_size > 1: - cp_rank = mpu.get_context_parallel_rank() - for key, val in batch.items(): - seq_dim = 1 if key != 'attention_mask' else 2 - val = val.view( - *val.shape[0:seq_dim], - 2 * cp_size, - val.shape[seq_dim] // (2 * cp_size), - *val.shape[(seq_dim + 1) :], - ) - index = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)], device=val.device) - val = val.index_select(seq_dim, index) - val = val.view(*val.shape[0:seq_dim], -1, *val.shape[(seq_dim + 2) :]) - batch[key] = val - - return batch - - -def print_rank_0(message): - """If distributed is initialized, print only on rank 0.""" - if torch.distributed.is_initialized(): - if torch.distributed.get_rank() == 0: - print(message, flush=True) - else: - print(message, flush=True) - -def is_last_rank(): - return torch.distributed.get_rank() == ( - torch.distributed.get_world_size() - 1) - -def print_rank_last(message): - """If distributed is initialized, print only on last rank.""" - if torch.distributed.is_initialized(): - if is_last_rank(): - print(message, flush=True) - else: - print(message, flush=True)
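Editorial aside: the chunk assignment described in the comment of the deleted `get_batch_on_this_cp_rank` (split the sequence into `2 * cp_size` chunks and give rank `r` chunks `r` and `2 * cp_size - 1 - r`) can be checked with a small stand-alone sketch. `torch.distributed` and the Megatron `mpu` are replaced here by an explicit `cp_rank` argument, and a single tensor stands in for the batch dict.

```python
import torch

# Sketch of the load-balanced context-parallel split in the deleted
# get_batch_on_this_cp_rank(): the sequence is cut into 2 * cp_size chunks and rank r
# keeps chunks r and (2 * cp_size - 1 - r), so every rank gets one "early" and one
# "late" chunk of the causal sequence.
def slice_for_cp_rank(val, cp_size, cp_rank, seq_dim=1):
    val = val.view(
        *val.shape[0:seq_dim],
        2 * cp_size,
        val.shape[seq_dim] // (2 * cp_size),
        *val.shape[(seq_dim + 1):],
    )
    index = torch.tensor([cp_rank, 2 * cp_size - cp_rank - 1], device=val.device)
    val = val.index_select(seq_dim, index)
    return val.view(*val.shape[0:seq_dim], -1, *val.shape[(seq_dim + 2):])

tokens = torch.arange(16).view(1, 16)     # one sample, sequence length 16
for rank in range(2):                     # cp_size = 2 -> 4 chunks of length 4
    print(rank, slice_for_cp_rank(tokens, cp_size=2, cp_rank=rank).tolist())
# rank 0 keeps chunks 0 and 3, rank 1 keeps chunks 1 and 2
```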