diff --git a/README_en.md b/README_en.md
index a14128071bf352403b4423370f11a7840b9a8d5a..102afd4658fbf4f72c2c546a766e7782d3116f54 100644
--- a/README_en.md
+++ b/README_en.md
@@ -237,9 +237,8 @@ For the supported models listed above, we provide training scripts and readme in
| torch | 2.1.0 |
| torch_npu | release v5.0.0 |
-The current repository uses Megatron commitId [bcce6f54](https://github.com/NVIDIA/Megatron-LM/tree/bcce6f54e075e3c3374ea67adefe54f3f2da2b07)
-【Based on the latest version, the performance statistics from our testing are as follows】
+【Based on the current version of Megatron, the performance statistics from our testing are as follows】
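The model READMEs patched below all repeat the same environment setup. As a consolidated reference (not part of the patch itself), the flow those hunks add looks roughly like the sketch below; the pinned commits bcce6f (Megatron-LM) and 224ae35e8fc96778f957029d1371ddb623452a50 (AscendSpeed) are taken from the hunks, while the CANN path and package filenames should be adapted to your environment.

```shell
# Consolidated sketch of the setup steps this patch adds to each model README.
git clone https://gitee.com/ascend/ModelLink.git
git clone https://github.com/NVIDIA/Megatron-LM.git
cd Megatron-LM
git checkout -f bcce6f                      # pinned Megatron-LM commit
cp -r megatron ../ModelLink/                # drop the upstream megatron package into ModelLink
cd ..

# CANN environment (adjust the path to your installation)
source /usr/local/Ascend/ascend-toolkit/set_env.sh

# AscendSpeed acceleration library, installed from source at the pinned commit
git clone https://gitee.com/ascend/AscendSpeed.git
cd AscendSpeed
git checkout 224ae35e8fc96778f957029d1371ddb623452a50
pip install -r requirements.txt
pip install -e .
cd ..

# remaining ModelLink dependencies
cd ModelLink
pip install -r requirements.txt
```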
diff --git a/examples/aquila/README.md b/examples/aquila/README.md
index d26df0a6a442a2c2c1b11f16261ea785bed3aae3..8575529334854a4262fc54d0c386b1000afd5c5f 100644
--- a/examples/aquila/README.md
+++ b/examples/aquila/README.md
@@ -28,6 +28,11 @@ Aquila-7B 训练的硬件配置如下:
```shell
git clone https://gitee.com/ascend/ModelLink.git
+git clone https://github.com/NVIDIA/Megatron-LM.git
+cd Megatron-LM
+git checkout -f bcce6f
+cp -r megatron ../ModelLink/
+cd ..
cd ModelLink
mkdir logs
mkdir model_from_hf
@@ -48,15 +53,17 @@ pip install torch_npu-2.1.0.postxxxx-cp38-cp38-xxxx_aarch64.whl
# 通过PTA上提供的安装包,以whl文件方式安装apex
pip install apex-0.1_ascend*-cp38-cp38m-linux_aarch64.whl
-# 拉取AscendSpeed源代码,进入AscendSpeed目录,source相应的CANN环境文件,然后源码方式安装ascendspeed加速包
-git clone https://gitee.com/ascend/AscendSpeed.git
-cd AscendSpeed/
# 请按照您的真实环境修改 set_env.sh 路径
source /usr/local/Ascend/ascend-toolkit/set_env.sh
-pip install -e ./
+# 拉取AscendSpeed源代码,进入AscendSpeed目录,然后源码方式安装ascendspeed加速包
+git clone https://gitee.com/ascend/AscendSpeed.git
+cd AscendSpeed/
+git checkout 224ae35e8fc96778f957029d1371ddb623452a50
+pip install -r requirements.txt
+pip install -e .
+cd ..
-# 进入ModelLink主目录,安装其余依赖包
-cd ModelLink/
+# 安装其余依赖包
pip install -r requirements.txt
```
diff --git a/examples/aquila/README_en.md b/examples/aquila/README_en.md
index 26bba5d923c0352bc919bf0c9a464535a0c87ed2..3fd5c58b9b685272b13b9f7fd83a52f7c1d559e1 100644
--- a/examples/aquila/README_en.md
+++ b/examples/aquila/README_en.md
@@ -28,6 +28,11 @@ Here's a hardware summary of pre-training Aquila-7B:
```shell
git clone https://gitee.com/ascend/ModelLink.git
+git clone https://github.com/NVIDIA/Megatron-LM.git
+cd Megatron-LM
+git checkout -f bcce6f
+cp -r megatron ../ModelLink/
+cd ..
cd ModelLink
mkdir logs
mkdir model_from_hf
@@ -45,14 +50,18 @@ conda activate test
pip install torch==2.1.0
pip install torch_npu-2.1.0.postxxxx-cp38-cp38-xxxx_aarch64.whl
pip install apex-0.1_ascend*-cp38-cp38m-linux_aarch64.whl
-# use git to clone the AscendSpeed source code, enter the directory, source the set_env.sh file based on your host settings(you may need to change the path), then install ascendspeed package by source code
+
+# source the set_env.sh file based on your host settings (you may need to change the path)
+source /usr/local/Ascend/ascend-toolkit/set_env.sh
+# use git to clone the AscendSpeed source code, enter the directory, then install the ascendspeed package from source
git clone https://gitee.com/ascend/AscendSpeed.git
cd AscendSpeed/
-source /usr/local/Ascend/ascend-toolkit/set_env.sh
-pip install -e ./
+git checkout 224ae35e8fc96778f957029d1371ddb623452a50
+pip install -r requirements.txt
+pip install -e .
+cd ..
-# enter the ModelLink/ directory and install other packages
-cd ModelLink/
+# install other packages
pip install -r requirements.txt
```
diff --git a/examples/baichuan/README.md b/examples/baichuan/README.md
index 3533bed6ee98f6d1d98f072a49b86b6510832aa1..535a9b8c87a1c903aca367476d0eb1d54784f12d 100644
--- a/examples/baichuan/README.md
+++ b/examples/baichuan/README.md
@@ -41,6 +41,11 @@ Baichuan-7B 训练的硬件配置如下:
```shell
git clone https://gitee.com/ascend/ModelLink.git
+git clone https://github.com/NVIDIA/Megatron-LM.git
+cd Megatron-LM
+git checkout -f bcce6f
+cp -r megatron ../ModelLink/
+cd ..
cd ModelLink
mkdir logs
mkdir model_from_hf
@@ -66,6 +71,7 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh
# 安装加速库
git clone https://gitee.com/ascend/AscendSpeed.git
cd AscendSpeed
+git checkout 224ae35e8fc96778f957029d1371ddb623452a50
pip install -r requirements.txt
pip3 install -e .
cd ..
@@ -262,6 +268,11 @@ Baichuan-13B 训练的硬件配置如下:
```shell
git clone https://gitee.com/ascend/ModelLink.git
+git clone https://github.com/NVIDIA/Megatron-LM.git
+cd Megatron-LM
+git checkout -f bcce6f
+cp -r megatron ../ModelLink/
+cd ..
cd ModelLink
mkdir logs
mkdir model_from_hf
@@ -287,6 +298,7 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh
# 安装加速库
git clone https://gitee.com/ascend/AscendSpeed.git
cd AscendSpeed
+git checkout 224ae35e8fc96778f957029d1371ddb623452a50
pip install -r requirements.txt
pip3 install -e .
cd ..
diff --git a/examples/baichuan/README_en.md b/examples/baichuan/README_en.md
index fcb128f86ac9d2eca3363c642c2700583ce76ecf..b3d493c328022bb5f66f728e21c8c448f5938c1d 100644
--- a/examples/baichuan/README_en.md
+++ b/examples/baichuan/README_en.md
@@ -42,6 +42,11 @@ Here's a hardware summary of pre-training Baichuan-7B:
```shell
git clone https://gitee.com/ascend/ModelLink.git
+git clone https://github.com/NVIDIA/Megatron-LM.git
+cd Megatron-LM
+git checkout -f bcce6f
+cp -r megatron ../ModelLink/
+cd ..
cd ModelLink
mkdir logs
mkdir model_from_hf
@@ -67,6 +72,7 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh
# install AscendSpeed
git clone https://gitee.com/ascend/AscendSpeed.git
cd AscendSpeed
+git checkout 224ae35e8fc96778f957029d1371ddb623452a50
pip install -r requirements.txt
pip3 install -e .
cd ..
@@ -264,6 +270,11 @@ Here's a hardware summary of pre-training Baichuan-13B:
```shell
git clone https://gitee.com/ascend/ModelLink.git
+git clone https://github.com/NVIDIA/Megatron-LM.git
+cd Megatron-LM
+git checkout -f bcce6f
+cp -r megatron ../ModelLink/
+cd ..
cd ModelLink
mkdir logs
mkdir model_from_hf
@@ -289,6 +300,7 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh
#install Ascendspeed
git clone https://gitee.com/ascend/AscendSpeed.git
cd AscendSpeed
+git checkout 224ae35e8fc96778f957029d1371ddb623452a50
pip install -r requirements.txt
pip3 install -e .
cd ..
diff --git a/examples/baichuan2/README.md b/examples/baichuan2/README.md
index 3cef67bd90bbf525dba3d3a49ba982a33d33b940..0636a1857bb501d1921da63b3cfb29bd1d43971d 100644
--- a/examples/baichuan2/README.md
+++ b/examples/baichuan2/README.md
@@ -39,6 +39,11 @@ Baichuan2-7B 训练的硬件配置如下:
```shell
git clone https://gitee.com/ascend/ModelLink.git
+git clone https://github.com/NVIDIA/Megatron-LM.git
+cd Megatron-LM
+git checkout -f bcce6f
+cp -r megatron ../ModelLink/
+cd ..
cd ModelLink
mkdir logs
mkdir model_from_hf
@@ -64,6 +69,7 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh
# 安装加速库
git clone https://gitee.com/ascend/AscendSpeed.git
cd AscendSpeed
+git checkout 224ae35e8fc96778f957029d1371ddb623452a50
pip install -r requirements.txt
pip3 install -e .
cd ..
@@ -263,6 +269,11 @@ Baichuan2-13B 训练的硬件配置如下:
```shell
git clone https://gitee.com/ascend/ModelLink.git
+git clone https://github.com/NVIDIA/Megatron-LM.git
+cd Megatron-LM
+git checkout -f bcce6f
+cp -r megatron ../ModelLink/
+cd ..
cd ModelLink
mkdir logs
mkdir model_from_hf
@@ -288,6 +299,7 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh
# 安装加速库
git clone https://gitee.com/ascend/AscendSpeed.git
cd AscendSpeed
+git checkout 224ae35e8fc96778f957029d1371ddb623452a50
pip install -r requirements.txt
pip3 install -e .
cd ..
diff --git a/examples/baichuan2/README_en.md b/examples/baichuan2/README_en.md
index a47668e13f80046dcf3ea7a3b56e4331e25723a7..c433bf770b89b3d6fd777a52cd1f9a473c0c52a4 100644
--- a/examples/baichuan2/README_en.md
+++ b/examples/baichuan2/README_en.md
@@ -39,6 +39,11 @@ Here's a hardware summary of pre-training Baichuan2-7B:
```shell
git clone https://gitee.com/ascend/ModelLink.git
+git clone https://github.com/NVIDIA/Megatron-LM.git
+cd Megatron-LM
+git checkout -f bcce6f
+cp -r megatron ../ModelLink/
+cd ..
cd ModelLink
mkdir logs
mkdir model_from_hf
@@ -64,6 +69,7 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh
# install AscendSpeed
git clone https://gitee.com/ascend/AscendSpeed.git
cd AscendSpeed
+git checkout 224ae35e8fc96778f957029d1371ddb623452a50
pip install -r requirements.txt
pip3 install -e .
cd ..
@@ -263,6 +269,11 @@ Here's a hardware summary of pre-training Baichuan2-13B:
```shell
git clone https://gitee.com/ascend/ModelLink.git
+git clone https://github.com/NVIDIA/Megatron-LM.git
+cd Megatron-LM
+git checkout -f bcce6f
+cp -r megatron ../ModelLink/
+cd ..
cd ModelLink
mkdir logs
mkdir model_from_hf
@@ -288,6 +299,7 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh
# install AscendSpeed
git clone https://gitee.com/ascend/AscendSpeed.git
cd AscendSpeed
+git checkout 224ae35e8fc96778f957029d1371ddb623452a50
pip install -r requirements.txt
pip3 install -e .
cd ..
diff --git a/examples/bloom/README.md b/examples/bloom/README.md
index 38fb639c5ead7a43614e806fed6b9c4db01f1c02..27d49a897c6d258a0b44aa6de05aa2b30647f840 100644
--- a/examples/bloom/README.md
+++ b/examples/bloom/README.md
@@ -24,6 +24,11 @@ Bloom-7B 训练的硬件配置如下:
```shell
git clone https://gitee.com/ascend/ModelLink.git
+git clone https://github.com/NVIDIA/Megatron-LM.git
+cd Megatron-LM
+git checkout -f bcce6f
+cp -r megatron ../ModelLink/
+cd ..
cd ModelLink
mkdir logs
mkdir model_from_hf
@@ -49,6 +54,7 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh
# 安装加速库
git clone https://gitee.com/ascend/AscendSpeed.git
cd AscendSpeed
+git checkout 224ae35e8fc96778f957029d1371ddb623452a50
pip install -r requirements.txt
pip3 install -e .
cd ..
@@ -231,6 +237,11 @@ Bloom-176B 训练的硬件配置:
```shell
git clone https://gitee.com/ascend/ModelLink.git
+git clone https://github.com/NVIDIA/Megatron-LM.git
+cd Megatron-LM
+git checkout -f bcce6f
+cp -r megatron ../ModelLink/
+cd ..
cd ModelLink
mkdir logs
mkdir model_from_hf
@@ -256,6 +267,7 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh
# 安装加速库
git clone https://gitee.com/ascend/AscendSpeed.git
cd AscendSpeed
+git checkout 224ae35e8fc96778f957029d1371ddb623452a50
pip install -r requirements.txt
pip3 install -e .
cd ..
diff --git a/examples/bloom/README_en.md b/examples/bloom/README_en.md
index e6789e6ca484b6c539052aeb097d473ec1bbd853..64e10251f3d3d6d4a739e7b543d7afe1d5598de1 100644
--- a/examples/bloom/README_en.md
+++ b/examples/bloom/README_en.md
@@ -24,6 +24,11 @@ Here's a hardware summary of pre-training Bloom-7B:
```shell
git clone https://gitee.com/ascend/ModelLink.git
+git clone https://github.com/NVIDIA/Megatron-LM.git
+cd Megatron-LM
+git checkout -f bcce6f
+cp -r megatron ../ModelLink/
+cd ..
cd ModelLink
mkdir logs
mkdir model_from_hf
@@ -49,6 +54,7 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh
# install AscendSpeed
git clone https://gitee.com/ascend/AscendSpeed.git
cd AscendSpeed
+git checkout 224ae35e8fc96778f957029d1371ddb623452a50
pip install -r requirements.txt
pip3 install -e .
cd ..
@@ -233,6 +239,11 @@ Here's a hardware summary of pre-training Bloom-176B:
```shell
git clone https://gitee.com/ascend/ModelLink.git
+git clone https://github.com/NVIDIA/Megatron-LM.git
+cd Megatron-LM
+git checkout -f bcce6f
+cp -r megatron ../ModelLink/
+cd ..
cd ModelLink
mkdir logs
mkdir model_from_hf
@@ -258,6 +269,7 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh
# install AscendSpeed
git clone https://gitee.com/ascend/AscendSpeed.git
cd AscendSpeed
+git checkout 224ae35e8fc96778f957029d1371ddb623452a50
pip install -r requirements.txt
pip3 install -e .
cd ..
diff --git a/examples/intern/README.md b/examples/intern/README.md
index 91cc5af84dcea33a762e56cd0edac9489df2693d..9b99ac6b6d0a761d16613aacfd09384964ee0a05 100644
--- a/examples/intern/README.md
+++ b/examples/intern/README.md
@@ -37,6 +37,11 @@ InternLM-7B 训练的硬件配置如下:
```shell
git clone https://gitee.com/ascend/ModelLink.git
+git clone https://github.com/NVIDIA/Megatron-LM.git
+cd Megatron-LM
+git checkout -f bcce6f
+cp -r megatron ../ModelLink/
+cd ..
cd ModelLink
mkdir logs
mkdir model_from_hf
@@ -62,6 +67,7 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh
# 安装加速库
git clone https://gitee.com/ascend/AscendSpeed.git
cd AscendSpeed
+git checkout 224ae35e8fc96778f957029d1371ddb623452a50
pip install -r requirements.txt
pip3 install -e .
cd ..
@@ -222,6 +228,11 @@ InternLM-65B 训练的硬件配置如下:
```shell
git clone https://gitee.com/ascend/ModelLink.git
+git clone https://github.com/NVIDIA/Megatron-LM.git
+cd Megatron-LM
+git checkout -f bcce6f
+cp -r megatron ../ModelLink/
+cd ..
cd ModelLink
mkdir logs
mkdir model_from_hf
@@ -247,6 +258,7 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh
# 安装加速库
git clone https://gitee.com/ascend/AscendSpeed.git
cd AscendSpeed
+git checkout 224ae35e8fc96778f957029d1371ddb623452a50
pip install -r requirements.txt
pip3 install -e .
cd ..
diff --git a/examples/intern/README_en.md b/examples/intern/README_en.md
index 085797dea72ddb58f11f7ad8ae2b58886166ba6c..56d77e91c0f2d2dda31da560e269d1c7f2f5c0da 100644
--- a/examples/intern/README_en.md
+++ b/examples/intern/README_en.md
@@ -38,6 +38,11 @@ Here's a hardware summary of pre-training InternLM-7B:
```shell
git clone https://gitee.com/ascend/ModelLink.git
+git clone https://github.com/NVIDIA/Megatron-LM.git
+cd Megatron-LM
+git checkout -f bcce6f
+cp -r megatron ../ModelLink/
+cd ..
cd ModelLink
mkdir logs
mkdir model_from_hf
@@ -63,6 +68,7 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh
# install AscendSpeed
git clone https://gitee.com/ascend/AscendSpeed.git
cd AscendSpeed
+git checkout 224ae35e8fc96778f957029d1371ddb623452a50
pip install -r requirements.txt
pip3 install -e .
cd ..
@@ -222,6 +228,11 @@ Here's a hardware summary of pre-training InternLM-65B:
```shell
git clone https://gitee.com/ascend/ModelLink.git
+git clone https://github.com/NVIDIA/Megatron-LM.git
+cd Megatron-LM
+git checkout -f bcce6f
+cp -r megatron ../ModelLink/
+cd ..
cd ModelLink
mkdir logs
mkdir model_from_hf
@@ -247,6 +258,7 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh
# install AscendSpeed
git clone https://gitee.com/ascend/AscendSpeed.git
cd AscendSpeed
+git checkout 224ae35e8fc96778f957029d1371ddb623452a50
pip install -r requirements.txt
pip3 install -e .
cd ..
diff --git a/examples/llama/README.md b/examples/llama/README.md
index 6d65e3a5db350c8788e40aeed50e00526b555531..fa0b323a6ea333867238ebde306b4140f8414b58 100644
--- a/examples/llama/README.md
+++ b/examples/llama/README.md
@@ -40,6 +40,11 @@ LLaMA-7B/13B 训练的硬件配置如下:
```shell
git clone https://gitee.com/ascend/ModelLink.git
+ git clone https://github.com/NVIDIA/Megatron-LM.git
+ cd Megatron-LM
+ git checkout -f bcce6f
+ cp -r megatron ../ModelLink/
+ cd ..
cd ModelLink
mkdir logs
mkdir model_from_hf
@@ -60,6 +65,7 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh
# 安装加速库
git clone https://gitee.com/ascend/AscendSpeed.git
cd AscendSpeed
+git checkout 224ae35e8fc96778f957029d1371ddb623452a50
pip install -r requirements.txt
pip3 install -e .
cd ..
@@ -470,6 +476,11 @@ LLaMA-33B/65B 训练的硬件配置:
```shell
git clone https://gitee.com/ascend/ModelLink.git
+ git clone https://github.com/NVIDIA/Megatron-LM.git
+ cd Megatron-LM
+ git checkout -f bcce6f
+ cp -r megatron ../ModelLink/
+ cd ..
cd ModelLink
mkdir logs
mkdir model_from_hf
@@ -494,6 +505,7 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh
# 安装加速库
git clone https://gitee.com/ascend/AscendSpeed.git
cd AscendSpeed
+git checkout 224ae35e8fc96778f957029d1371ddb623452a50
pip install -r requirements.txt
pip3 install -e .
cd ..
diff --git a/examples/llama/README_en.md b/examples/llama/README_en.md
index 7cdefac288a98f966ce2a17b8df87542f547c800..2cd333875b12d5e8af6d468e58ef1f4edd4b1ac7 100644
--- a/examples/llama/README_en.md
+++ b/examples/llama/README_en.md
@@ -39,6 +39,11 @@ Here's a hardware summary of pre-training LLaMA-7B/13B:
```shell
git clone https://gitee.com/ascend/ModelLink.git
+git clone https://github.com/NVIDIA/Megatron-LM.git
+cd Megatron-LM
+git checkout -f bcce6f
+cp -r megatron ../ModelLink/
+cd ..
cd ModelLink
mkdir logs
mkdir model_from_hf
@@ -60,6 +65,7 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh
# install ascendspeed
git clone https://gitee.com/ascend/AscendSpeed.git
cd AscendSpeed
+git checkout 224ae35e8fc96778f957029d1371ddb623452a50
pip install -r requirements.txt
pip3 install -e .
cd ..
@@ -452,6 +458,11 @@ The model was trained using alpaca datasets.
```shell
git clone https://gitee.com/ascend/ModelLink.git
+git clone https://github.com/NVIDIA/Megatron-LM.git
+cd Megatron-LM
+git checkout -f bcce6f
+cp -r megatron ../ModelLink/
+cd ..
cd ModelLink
mkdir logs
mkdir model_from_hf
@@ -475,6 +486,7 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh
# install AscendSpeed
git clone https://gitee.com/ascend/AscendSpeed.git
cd AscendSpeed
+git checkout 224ae35e8fc96778f957029d1371ddb623452a50
pip install -r requirements.txt
pip3 install -e .
cd ..
diff --git a/examples/llama2/README.md b/examples/llama2/README.md
index ace911a0968a393576cdfd31daf7acdafda4b0a1..991b5eddb72145061e30241912083c37c2f0f43f 100755
--- a/examples/llama2/README.md
+++ b/examples/llama2/README.md
@@ -47,6 +47,11 @@ LLAMA2-7B 训练的硬件配置:
```shell
git clone https://gitee.com/ascend/ModelLink.git
+ git clone https://github.com/NVIDIA/Megatron-LM.git
+ cd Megatron-LM
+ git checkout -f bcce6f
+ cp -r megatron ../ModelLink/
+ cd ..
cd ModelLink
mkdir logs
mkdir model_from_hf
@@ -71,6 +76,7 @@ LLAMA2-7B 训练的硬件配置:
# 安装加速库
git clone https://gitee.com/ascend/AscendSpeed.git
cd AscendSpeed
+ git checkout 224ae35e8fc96778f957029d1371ddb623452a50
pip install -r requirements.txt
pip3 install -e .
cd ..
@@ -394,6 +400,11 @@ LLaMA2-13B 训练的硬件配置:
```shell
git clone https://gitee.com/ascend/ModelLink.git
+ git clone https://github.com/NVIDIA/Megatron-LM.git
+ cd Megatron-LM
+ git checkout -f bcce6f
+ cp -r megatron ../ModelLink/
+ cd ..
cd ModelLink
mkdir logs
mkdir model_from_hf
@@ -418,6 +429,7 @@ LLaMA2-13B 训练的硬件配置:
# 安装加速库
git clone https://gitee.com/ascend/AscendSpeed.git
cd AscendSpeed
+ git checkout 224ae35e8fc96778f957029d1371ddb623452a50
pip install -r requirements.txt
pip3 install -e .
cd ..
@@ -681,6 +693,11 @@ LLaMA2-34B/70B 训练的硬件配置:
```shell
git clone https://gitee.com/ascend/ModelLink.git
+ git clone https://github.com/NVIDIA/Megatron-LM.git
+ cd Megatron-LM
+ git checkout -f bcce6f
+ cp -r megatron ../ModelLink/
+ cd ..
cd ModelLink
mkdir logs
mkdir model_from_hf
@@ -705,6 +722,7 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh
# 安装加速库
git clone https://gitee.com/ascend/AscendSpeed.git
cd AscendSpeed
+git checkout 224ae35e8fc96778f957029d1371ddb623452a50
pip install -r requirements.txt
pip3 install -e .
cd ..
diff --git a/examples/llama2/README_en.md b/examples/llama2/README_en.md
index 1626bb4ec73df93c360896db1a14fc5e4a416772..e2bdd42597185fa689b97a6f0411cfcbb922fcb8 100644
--- a/examples/llama2/README_en.md
+++ b/examples/llama2/README_en.md
@@ -46,6 +46,11 @@ Here's a hardware summary of pre-training LLAMA2-7B:
```shell
git clone https://gitee.com/ascend/ModelLink.git
+ git clone https://github.com/NVIDIA/Megatron-LM.git
+ cd Megatron-LM
+ git checkout -f bcce6f
+ cp -r megatron ../ModelLink/
+ cd ..
cd ModelLink
mkdir logs
mkdir model_from_hf
@@ -70,6 +75,7 @@ Here's a hardware summary of pre-training LLAMA2-7B:
# install AscendSpeed
git clone https://gitee.com/ascend/AscendSpeed.git
cd AscendSpeed
+ git checkout 224ae35e8fc96778f957029d1371ddb623452a50
pip install -r requirements.txt
pip3 install -e .
cd ..
@@ -413,6 +419,11 @@ Here's a hardware summary of pre-training LLaMA2-13B:
```shell
git clone https://gitee.com/ascend/ModelLink.git
+ git clone https://github.com/NVIDIA/Megatron-LM.git
+ cd Megatron-LM
+ git checkout -f bcce6f
+ cp -r megatron ../ModelLink/
+ cd ..
cd ModelLink
mkdir logs
mkdir model_from_hf
@@ -437,6 +448,7 @@ Here's a hardware summary of pre-training LLaMA2-13B:
# install AscendSpeed
git clone https://gitee.com/ascend/AscendSpeed.git
cd AscendSpeed
+ git checkout 224ae35e8fc96778f957029d1371ddb623452a50
pip install -r requirements.txt
pip3 install -e .
cd ..
@@ -696,6 +708,11 @@ Here's a hardware summary of pre-training LLaMA2-34B/70B:
```shell
git clone https://gitee.com/ascend/ModelLink.git
+git clone https://github.com/NVIDIA/Megatron-LM.git
+cd Megatron-LM
+git checkout -f bcce6f
+cp -r megatron ../ModelLink/
+cd ..
cd ModelLink
mkdir logs
mkdir model_from_hf
@@ -721,6 +738,7 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh
# install AscendSpeed
git clone https://gitee.com/ascend/AscendSpeed.git
cd AscendSpeed
+git checkout 224ae35e8fc96778f957029d1371ddb623452a50
pip install -r requirements.txt
pip3 install -e .
cd ..
diff --git a/examples/mixtral/README.md b/examples/mixtral/README.md
index 84e626163f5ba979cde41c010b773809d6af295b..e7b218aee08cbd16db122f17974807c8e3eb5a10 100644
--- a/examples/mixtral/README.md
+++ b/examples/mixtral/README.md
@@ -40,6 +40,11 @@
```shell
git clone https://gitee.com/ascend/ModelLink.git
+git clone https://github.com/NVIDIA/Megatron-LM.git
+cd Megatron-LM
+git checkout -f bcce6f
+cp -r megatron ../ModelLink/
+cd ..
cd ModelLink
mkdir logs
mkdir model_from_hf
@@ -65,6 +70,7 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh
# 安装加速库
git clone https://gitee.com/ascend/AscendSpeed.git
cd AscendSpeed
+git checkout 224ae35e8fc96778f957029d1371ddb623452a50
pip install -r requirements.txt
pip3 install -e .
cd ..
diff --git a/examples/mixtral/README_en.md b/examples/mixtral/README_en.md
index dc25da17d6787c6aa8682b1603bf145827fe6421..f462c6abcc674d3aac206bf1e8116456979192be 100644
--- a/examples/mixtral/README_en.md
+++ b/examples/mixtral/README_en.md
@@ -40,6 +40,11 @@ Recommended hardware configuration for inference:
```shell
git clone https://gitee.com/ascend/ModelLink.git
+git clone https://github.com/NVIDIA/Megatron-LM.git
+cd Megatron-LM
+git checkout -f bcce6f
+cp -r megatron ../ModelLink/
+cd ..
cd ModelLink
mkdir logs
mkdir model_from_hf
@@ -65,6 +70,7 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh
# install AscendSpeed
git clone https://gitee.com/ascend/AscendSpeed.git
cd AscendSpeed
+git checkout 224ae35e8fc96778f957029d1371ddb623452a50
pip install -r requirements.txt
pip3 install -e .
cd ..
diff --git a/examples/qwen/README.md b/examples/qwen/README.md
index 64013256aeacf902163dfcc99260c9c275a016ea..d1004d3916ee8a564d4e84d2826d6b63ca2978b7 100644
--- a/examples/qwen/README.md
+++ b/examples/qwen/README.md
@@ -47,6 +47,11 @@ Qwen-7B 训练的硬件配置:
```shell
git clone https://gitee.com/ascend/ModelLink.git
+ git clone https://github.com/NVIDIA/Megatron-LM.git
+ cd Megatron-LM
+ git checkout -f bcce6f
+ cp -r megatron ../ModelLink/
+ cd ..
cd ModelLink
mkdir logs
mkdir model_from_hf
@@ -68,6 +73,7 @@ Qwen-7B 训练的硬件配置:
# 安装加速库
git clone https://gitee.com/ascend/AscendSpeed.git
cd AscendSpeed
+ git checkout 224ae35e8fc96778f957029d1371ddb623452a50
pip install -r requirements.txt
pip install -e .
cd ..
@@ -273,6 +279,11 @@ Qwen-14B 训练的硬件配置:
```shell
git clone https://gitee.com/ascend/ModelLink.git
+ git clone https://github.com/NVIDIA/Megatron-LM.git
+ cd Megatron-LM
+ git checkout -f bcce6f
+ cp -r megatron ../ModelLink/
+ cd ..
cd ModelLink
mkdir logs
mkdir model_from_hf
@@ -294,6 +305,7 @@ Qwen-14B 训练的硬件配置:
# 安装加速库
git clone https://gitee.com/ascend/AscendSpeed.git
cd AscendSpeed
+ git checkout 224ae35e8fc96778f957029d1371ddb623452a50
pip install -r requirements.txt
pip install -e .
cd ..
@@ -505,6 +517,11 @@ Qwen-72B 训练的硬件配置:
```shell
git clone https://gitee.com/ascend/ModelLink.git
+ git clone https://github.com/NVIDIA/Megatron-LM.git
+ cd Megatron-LM
+ git checkout -f bcce6f
+ cp -r megatron ../ModelLink/
+ cd ..
cd ModelLink
mkdir logs
mkdir model_from_hf
@@ -526,6 +543,7 @@ Qwen-72B 训练的硬件配置:
# 安装加速库
git clone https://gitee.com/ascend/AscendSpeed.git
cd AscendSpeed
+ git checkout 224ae35e8fc96778f957029d1371ddb623452a50
pip install -r requirements.txt
pip install -e .
cd ..
diff --git a/examples/qwen/README_en.md b/examples/qwen/README_en.md
index 6f96b1a0131db4a5d812a39973ecf66f121a3a62..c98c2912b333ed4a417e3f4c6fa167e8d5422242 100644
--- a/examples/qwen/README_en.md
+++ b/examples/qwen/README_en.md
@@ -46,6 +46,11 @@ Here's a hardware summary of pre-training Qwen-7B:
```shell
git clone https://gitee.com/ascend/ModelLink.git
+ git clone https://github.com/NVIDIA/Megatron-LM.git
+ cd Megatron-LM
+ git checkout -f bcce6f
+ cp -r megatron ../ModelLink/
+ cd ..
cd ModelLink
mkdir logs
mkdir model_from_hf
@@ -67,6 +72,7 @@ Here's a hardware summary of pre-training Qwen-7B:
# install AscendSpeed
git clone https://gitee.com/ascend/AscendSpeed.git
cd AscendSpeed
+ git checkout 224ae35e8fc96778f957029d1371ddb623452a50
pip install -r requirements.txt
pip install -e .
cd ..
@@ -275,6 +281,11 @@ Here's a hardware summary of pre-training Qwen-14B:
```shell
git clone https://gitee.com/ascend/ModelLink.git
+ git clone https://github.com/NVIDIA/Megatron-LM.git
+ cd Megatron-LM
+ git checkout -f bcce6f
+ cp -r megatron ../ModelLink/
+ cd ..
cd ModelLink
mkdir logs
mkdir model_from_hf
@@ -296,6 +307,7 @@ Here's a hardware summary of pre-training Qwen-14B:
# install AscendSpeed
git clone https://gitee.com/ascend/AscendSpeed.git
cd AscendSpeed
+ git checkout 224ae35e8fc96778f957029d1371ddb623452a50
pip install -r requirements.txt
pip install -e .
cd ..
@@ -510,6 +522,11 @@ Here's a hardware summary of pre-training Qwen-72B:
```shell
git clone https://gitee.com/ascend/ModelLink.git
+ git clone https://github.com/NVIDIA/Megatron-LM.git
+ cd Megatron-LM
+ git checkout -f bcce6f
+ cp -r megatron ../ModelLink/
+ cd ..
cd ModelLink
mkdir logs
mkdir model_from_hf
@@ -531,6 +548,7 @@ Here's a hardware summary of pre-training Qwen-72B:
# install AscendSpeed
git clone https://gitee.com/ascend/AscendSpeed.git
cd AscendSpeed
+ git checkout 224ae35e8fc96778f957029d1371ddb623452a50
pip install -r requirements.txt
pip install -e .
cd ..
diff --git a/megatron/__init__.py b/megatron/__init__.py
deleted file mode 100644
index c35de282a27e754544b1961d42bb5be27b83af71..0000000000000000000000000000000000000000
--- a/megatron/__init__.py
+++ /dev/null
@@ -1,19 +0,0 @@
-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-
-import torch
-
-from .global_vars import get_args, get_retro_args
-from .global_vars import get_current_global_batch_size
-from .global_vars import get_num_microbatches
-from .global_vars import get_signal_handler
-from .global_vars import update_num_microbatches
-from .global_vars import get_tokenizer
-from .global_vars import get_tensorboard_writer
-from .global_vars import get_wandb_writer
-from .global_vars import get_adlr_autoresume
-from .global_vars import get_timers
-from .initialize import initialize_megatron
-
-from .utils import (print_rank_0,
- is_last_rank,
- print_rank_last)
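Removing this vendored copy of megatron/__init__.py means the names it exported (get_args, get_tokenizer, initialize_megatron, print_rank_0, ...) are now expected to resolve from the upstream Megatron-LM tree that the README steps above copy into ModelLink. A minimal sanity check, assuming the copy was made from commit bcce6f and the dependencies listed above are installed:

```shell
# From the ModelLink working directory that now contains the copied megatron/ package,
# verify the same top-level entry points the deleted vendored __init__.py used to expose.
cd ModelLink
python -c "from megatron import get_args, print_rank_0, initialize_megatron; print('megatron import OK')"
```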
diff --git a/megatron/arguments.py b/megatron/arguments.py
deleted file mode 100644
index d4f1cd5a324b5d23928a690ea7cf7bda30f4432f..0000000000000000000000000000000000000000
--- a/megatron/arguments.py
+++ /dev/null
@@ -1,1403 +0,0 @@
-# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
-
-"""Megatron arguments."""
-
-import argparse
-import dataclasses
-import json
-import os
-import torch
-import types
-
-import torch.nn.functional as F
-from megatron.global_vars import set_retro_args, get_retro_args
-from tools.retro.utils import get_args_path as get_retro_args_path
-
-from megatron.core.models.retro import RetroConfig
-from megatron.core.transformer import TransformerConfig
-
-
-def parse_args(extra_args_provider=None, ignore_unknown_args=False):
- """Parse all arguments."""
- parser = argparse.ArgumentParser(description='Megatron-LM Arguments',
- allow_abbrev=False)
-
- # Standard arguments.
- parser = _add_network_size_args(parser)
- parser = _add_regularization_args(parser)
- parser = _add_training_args(parser)
- parser = _add_initialization_args(parser)
- parser = _add_learning_rate_args(parser)
- parser = _add_checkpointing_args(parser)
- parser = _add_mixed_precision_args(parser)
- parser = _add_distributed_args(parser)
- parser = _add_validation_args(parser)
- parser = _add_data_args(parser)
- parser = _add_autoresume_args(parser)
- parser = _add_biencoder_args(parser)
- parser = _add_vision_args(parser)
- parser = _add_logging_args(parser)
- parser = _add_inference_args(parser)
- parser = _add_transformer_engine_args(parser)
- parser = _add_retro_args(parser)
- parser = _add_experimental_args(parser)
-
- # Custom arguments.
- if extra_args_provider is not None:
- parser = extra_args_provider(parser)
-
- # Parse.
- if ignore_unknown_args:
- args, _ = parser.parse_known_args()
- else:
- args = parser.parse_args()
-
- # Args from environment
- args.rank = int(os.getenv('RANK', '0'))
- args.world_size = int(os.getenv("WORLD_SIZE", '1'))
-
- return args
-
-def validate_args(args, defaults={}):
- # Tensor model parallel size.
- args.tensor_model_parallel_size = min(
- args.tensor_model_parallel_size, args.world_size)
- assert args.world_size % args.tensor_model_parallel_size == 0, 'world size'\
- ' ({}) is not divisible by tensor model parallel size ({})'.format(
- args.world_size, args.tensor_model_parallel_size)
- # Pipeline model parallel size.
- args.pipeline_model_parallel_size = min(
- args.pipeline_model_parallel_size,
- (args.world_size // args.tensor_model_parallel_size))
- args.transformer_pipeline_model_parallel_size = (
- args.pipeline_model_parallel_size - 1
- if args.standalone_embedding_stage else
- args.pipeline_model_parallel_size
- )
- # Checks.
- model_parallel_size = args.pipeline_model_parallel_size * \
- args.tensor_model_parallel_size
- assert args.world_size % (model_parallel_size * args.context_parallel_size) == 0, \
- 'world size ({}) is not divisible by tensor parallel size ({}) times ' \
- 'pipeline parallel size ({}) times context parallel size ({})'.format(
- args.world_size, args.tensor_model_parallel_size,
- args.pipeline_model_parallel_size, args.context_parallel_size)
- args.data_parallel_size = args.world_size // (model_parallel_size * args.context_parallel_size)
- if args.rank == 0:
- print('using world size: {}, data-parallel size: {}, '
- 'context-parallel size: {} '
- 'tensor-model-parallel size: {}, '
- 'pipeline-model-parallel size: {} '.format(
- args.world_size, args.data_parallel_size,
- args.context_parallel_size,
- args.tensor_model_parallel_size,
- args.pipeline_model_parallel_size), flush=True)
- if args.pipeline_model_parallel_size > 1:
- if args.pipeline_model_parallel_split_rank is not None:
- assert args.pipeline_model_parallel_split_rank < \
- args.pipeline_model_parallel_size, 'split rank needs'\
- ' to be less than pipeline model parallel size ({})'.format(
- args.pipeline_model_parallel_size)
-
- if args.tp_comm_overlap:
- assert args.sequence_parallel == True, 'Tensor parallel communication/GEMM overlap can happen only when sequence parallelism is enabled'
-
-
- # Deprecated arguments
- assert args.batch_size is None, '--batch-size argument is no longer ' \
- 'valid, use --micro-batch-size instead'
- del args.batch_size
- assert args.warmup is None, '--warmup argument is no longer valid, use ' \
- '--lr-warmup-fraction instead'
- del args.warmup
- assert args.model_parallel_size is None, '--model-parallel-size is no ' \
- 'longer valid, use --tensor-model-parallel-size instead'
- del args.model_parallel_size
-
- if args.checkpoint_activations:
- if args.rank == 0:
- print('--checkpoint-activations is no longer valid, use --recompute-activations, '
- 'or, for more control, --recompute-granularity and --recompute-method.')
- exit()
- del args.checkpoint_activations
-
- if args.recompute_activations:
- args.recompute_granularity = 'selective'
- del args.recompute_activations
-
- # Set input defaults.
- for key in defaults:
- # For default to be valid, it should not be provided in the
- # arguments that are passed to the program. We check this by
- # ensuring the arg is set to None.
- if getattr(args, key, None) is not None:
- if args.rank == 0:
- print('WARNING: overriding default arguments for {key}:{v} \
- with {key}:{v2}'.format(key=key, v=defaults[key],
- v2=getattr(args, key)),
- flush=True)
- else:
- setattr(args, key, defaults[key])
-
- # Batch size.
- assert args.micro_batch_size is not None
- assert args.micro_batch_size > 0
- if args.global_batch_size is None:
- args.global_batch_size = args.micro_batch_size * args.data_parallel_size
- if args.rank == 0:
- print('setting global batch size to {}'.format(
- args.global_batch_size), flush=True)
- assert args.global_batch_size > 0
- if args.num_layers_per_virtual_pipeline_stage is not None:
- assert args.pipeline_model_parallel_size > 2, \
- 'pipeline-model-parallel size should be greater than 2 with ' \
- 'interleaved schedule'
- assert args.num_layers % args.transformer_pipeline_model_parallel_size == 0, \
- 'number of layers should be divisible by the pipeline parallel size'
- num_layers_per_pipeline_stage = args.num_layers // args.transformer_pipeline_model_parallel_size
- assert num_layers_per_pipeline_stage % args.num_layers_per_virtual_pipeline_stage == 0, \
- 'number of layers per pipeline stage must be divisible number of layers per virtual pipeline stage'
- args.virtual_pipeline_model_parallel_size = num_layers_per_pipeline_stage // \
- args.num_layers_per_virtual_pipeline_stage
- else:
- args.virtual_pipeline_model_parallel_size = None
- # Overlap P2P communication is disabled if not using the interleaved schedule.
- args.overlap_p2p_comm = False
- if args.rank == 0:
- print('WARNING: Setting args.overlap_p2p_comm to False since non-interleaved '
- 'schedule does not support overlapping p2p communication')
-
- if args.overlap_param_gather:
- assert args.use_distributed_optimizer, \
- '--overlap-param-gather only supported with distributed optimizer'
-
- # Parameters dtype.
- args.params_dtype = torch.float
- if args.fp16:
- assert not args.bf16
- args.params_dtype = torch.half
- if args.bf16:
- assert not args.fp16
- args.params_dtype = torch.bfloat16
- # bfloat16 requires gradient accumulation and all-reduce to
- # be done in fp32.
- if not args.accumulate_allreduce_grads_in_fp32:
- args.accumulate_allreduce_grads_in_fp32 = True
- if args.rank == 0:
- print('accumulate and all-reduce gradients in fp32 for '
- 'bfloat16 data type.', flush=True)
-
- if args.rank == 0:
- print('using {} for parameters ...'.format(args.params_dtype),
- flush=True)
-
- if args.dataloader_type is None:
- args.dataloader_type = 'single'
-
- # Consumed tokens.
- args.consumed_train_samples = 0
- args.consumed_valid_samples = 0
-
- # Support for variable sequence lengths across batches/microbatches.
- # set it if the dataloader supports generation of variable sequence lengths
- # across batches/microbatches. Due to additional communication overhead
- # during pipeline parallelism, it should not be set if sequence length
- # is constant during training.
- args.variable_seq_lengths = False
-
- # Iteration-based training.
- if args.train_iters:
- # If we use iteration-based training, make sure the
- # sample-based options are off.
- assert args.train_samples is None, \
- 'expected iteration-based training'
- assert args.lr_decay_samples is None, \
- 'expected iteration-based learning rate decay'
- assert args.lr_warmup_samples == 0, \
- 'expected iteration-based learning rate warmup'
- assert args.rampup_batch_size is None, \
- 'expected no batch-size rampup for iteration-based training'
- if args.lr_warmup_fraction is not None:
- assert args.lr_warmup_iters == 0, \
- 'can only specify one of lr-warmup-fraction and lr-warmup-iters'
-
- # Sample-based training.
- if args.train_samples:
- # If we use sample-based training, make sure the
- # iteration-based options are off.
- assert args.train_iters is None, \
- 'expected sample-based training'
- assert args.lr_decay_iters is None, \
- 'expected sample-based learning rate decay'
- assert args.lr_warmup_iters == 0, \
- 'expected sample-based learnig rate warmup'
- if args.lr_warmup_fraction is not None:
- assert args.lr_warmup_samples == 0, \
- 'can only specify one of lr-warmup-fraction ' \
- 'and lr-warmup-samples'
-
- if args.num_layers is not None:
- assert args.encoder_num_layers is None, \
- 'cannot have both num-layers and encoder-num-layers specified'
- args.encoder_num_layers = args.num_layers
- else:
- assert args.encoder_num_layers is not None, \
- 'either num-layers or encoder-num-layers should be specified'
- args.num_layers = args.encoder_num_layers
-
- # Check required arguments.
- required_args = ['num_layers', 'hidden_size', 'num_attention_heads',
- 'max_position_embeddings']
- for req_arg in required_args:
- _check_arg_is_not_none(args, req_arg)
-
- # Checks.
- if args.ffn_hidden_size is None:
- if args.swiglu:
- # reduce the dimnesion for MLP since projections happens on
- # two linear layers. this keeps the number of paramters in
- # the same ballpark as the counterpart with 4*h size
- # we keep it a multiple of 64, which means the actual tensor size
- # will be a multiple of 64 / tp_size
- args.ffn_hidden_size = int((4 * args.hidden_size * 2 / 3) / 64) * 64
- else:
- args.ffn_hidden_size = 4 * args.hidden_size
-
- if args.kv_channels is None:
- assert args.hidden_size % args.num_attention_heads == 0
- args.kv_channels = args.hidden_size // args.num_attention_heads
-
- if args.seq_length is not None:
- assert args.encoder_seq_length is None
- args.encoder_seq_length = args.seq_length
- else:
- assert args.encoder_seq_length is not None
- args.seq_length = args.encoder_seq_length
-
- if args.seq_length is not None:
- assert args.max_position_embeddings >= args.seq_length
- if args.decoder_seq_length is not None:
- assert args.max_position_embeddings >= args.decoder_seq_length
- if args.lr is not None:
- assert args.min_lr <= args.lr
- if args.save is not None:
- assert args.save_interval is not None
- # Mixed precision checks.
- if args.fp16_lm_cross_entropy:
- assert args.fp16, 'lm cross entropy in fp16 only support in fp16 mode.'
- if args.fp32_residual_connection:
- assert args.fp16 or args.bf16, \
- 'residual connection in fp32 only supported when using fp16 or bf16.'
-
- if args.weight_decay_incr_style == 'constant':
- assert args.start_weight_decay is None
- assert args.end_weight_decay is None
- args.start_weight_decay = args.weight_decay
- args.end_weight_decay = args.weight_decay
- else:
- assert args.start_weight_decay is not None
- assert args.end_weight_decay is not None
-
- TORCH_MAJOR = int(torch.__version__.split('.')[0])
- TORCH_MINOR = int(torch.__version__.split('.')[1])
- # Persistent fused layer norm.
- if TORCH_MAJOR < 1 or (TORCH_MAJOR == 1 and TORCH_MINOR < 11):
- args.no_persist_layer_norm = True
- if args.rank == 0:
- print('Persistent fused layer norm kernel is supported from '
- 'pytorch v1.11 (nvidia pytorch container paired with v1.11). '
- 'Defaulting to no_persist_layer_norm=True')
-
- # Activation recomputing.
- if args.distribute_saved_activations:
- assert args.tensor_model_parallel_size > 1, 'can distribute ' \
- 'recomputed activations only across tensor model ' \
- 'parallel groups'
- assert args.recompute_granularity == 'full', \
- 'distributed recompute activations is only '\
- 'application to full recompute granularity'
- assert args.recompute_method is not None, \
- 'for distributed recompute activations to work you '\
- 'need to use a recompute method '
- assert (TORCH_MAJOR, TORCH_MINOR) >= (1, 10), \
- 'distributed recompute activations are supported for pytorch ' \
- 'v1.10 and above (Nvidia Pytorch container >= 21.07). Current ' \
- 'pytorch version is v%s.%s.' % (TORCH_MAJOR, TORCH_MINOR)
-
- if args.recompute_granularity == 'selective':
- assert args.recompute_method is None, \
- 'recompute method is not yet supported for ' \
- 'selective recomputing granularity'
-
- # disable sequence parallelism when tp=1
- # to avoid change in numerics when
- # sequence_parallelism is enabled.
- if args.tensor_model_parallel_size == 1:
- args.sequence_parallel = False
-
- # disable async_tensor_model_parallel_allreduce when
- # model parallel memory optimization is enabled
- if args.sequence_parallel:
- args.async_tensor_model_parallel_allreduce = False
-
- if os.environ.get('CUDA_DEVICE_MAX_CONNECTIONS') != "1":
- if args.sequence_parallel:
- raise RuntimeError(
- "Using sequence parallelism requires setting the environment variable "
- "CUDA_DEVICE_MAX_CONNECTIONS to 1")
- if args.async_tensor_model_parallel_allreduce:
- raise RuntimeError(
- "Using async gradient all reduce requires setting the environment "
- "variable CUDA_DEVICE_MAX_CONNECTIONS to 1")
-
- # Disable bias gelu fusion if we are disabling bias altogether
- if not args.add_bias_linear:
- args.bias_gelu_fusion = False
-
- # Retro checks.
- if args.retro_add_retriever:
-
- # Sequence parallelism unsupported.
- assert not args.sequence_parallel, \
- "retro currently does not support sequence parallelism."
-
- # Pipeline parallelism unsupported.
- assert args.pipeline_model_parallel_size == 1, \
- "retro currently does not support pipeline parallelism."
-
- # Load retro args.
- retro_args_path = get_retro_args_path(args.retro_workdir)
- assert os.path.exists(retro_args_path), "retro workdir missing args.json"
- with open(retro_args_path) as f:
- retro_args = types.SimpleNamespace(**json.load(f))
- retro_args.retro_return_doc_ids = args.retro_return_doc_ids
- retro_args.retro_gpt_retrieved_length = \
- args.retro_num_retrieved_chunks * \
- retro_args.retro_gpt_chunk_length
- set_retro_args(retro_args)
-
- # Legacy RoPE arguments
- if args.use_rotary_position_embeddings:
- args.position_embedding_type = 'rope'
-
- # Would just need to add 'NoPE' as a position_embedding_type to support this, but for now
- # don't allow it to keep things simple
- if not args.add_position_embedding and args.position_embedding_type != 'rope':
- raise RuntimeError('--no-position-embedding is deprecated, use --position-embedding-type')
-
- # MoE Spec check
- if args.num_experts is not None:
- assert args.spec is None, "Model Spec must be None when using MoEs"
-
- # Expert parallelism check
- if args.expert_model_parallel_size > 1:
- assert args.num_experts is not None, "num_experts must be non None to use expert model parallelism"
- assert args.num_experts % args.expert_model_parallel_size == 0, \
- "Number of experts should be a multiple of expert model parallel_size."
- assert not args.use_distributed_optimizer, \
- "Expert parallelism is not suppored with distributed optimizer."
- assert not args.fp16, \
- "Expert parallelism is not supported with fp16 training."
- if args.tensor_model_parallel_size > 1:
- assert args.sequence_parallel, \
- "When using expert parallelism and tensor parallelism, sequence parallelism must be used."
-
- # Print arguments.
- _print_args("arguments", args)
- retro_args = get_retro_args()
- if retro_args and args != retro_args:
- _print_args("retro arguments", types.SimpleNamespace(**{k:v for k,v in vars(retro_args).items() if k.startswith("retro")}, rank=args.rank))
-
- return args
-
-
-def _print_args(title, args):
- """Print arguments."""
- if args.rank == 0:
- print(f'------------------------ {title} ------------------------',
- flush=True)
- str_list = []
- for arg in vars(args):
- dots = '.' * (48 - len(arg))
- str_list.append(' {} {} {}'.format(arg, dots, getattr(args, arg)))
- for arg in sorted(str_list, key=lambda x: x.lower()):
- print(arg, flush=True)
- print(f'-------------------- end of {title} ---------------------',
- flush=True)
-
-
-def _check_arg_is_not_none(args, arg):
- assert getattr(args, arg) is not None, '{} argument is None'.format(arg)
-
-def core_transformer_config_from_args(args):
-
- # Translate args to core transformer configuration
- kw_args = {}
- for f in dataclasses.fields(TransformerConfig):
- if hasattr(args, f.name):
- kw_args[f.name] = getattr(args, f.name)
- kw_args['persist_layer_norm'] = not args.no_persist_layer_norm
- kw_args['layernorm_zero_centered_gamma'] = args.apply_layernorm_1p
- kw_args['layernorm_epsilon'] = args.norm_epsilon
- kw_args['deallocate_pipeline_outputs'] = True
- kw_args['pipeline_dtype'] = args.params_dtype
- kw_args['batch_p2p_comm'] = not args.overlap_p2p_comm
- kw_args['num_moe_experts'] = args.num_experts
- if args.swiglu:
- kw_args['activation_func'] = F.silu
- kw_args['gated_linear_unit'] = True
- kw_args['bias_gelu_fusion'] = False
- if args.squared_relu:
- assert not args.swiglu
- def squared_relu(x):
- return torch.pow(F.relu(x), 2)
- kw_args['activation_func'] = squared_relu
- if args.init_method_xavier_uniform:
- kw_args['init_method'] = torch.nn.init.xavier_uniform_
- kw_args['scaled_init_method'] = torch.nn.init.xavier_uniform_
- if args.group_query_attention:
- kw_args['num_query_groups'] = args.num_query_groups
- else:
- kw_args['num_query_groups'] = None
-
- # If using Retro, return Retro config.
- retro_args = get_retro_args()
- if retro_args:
- kw_args['retro_preprocess'] = retro_args
- return RetroConfig(**kw_args)
-
- # Return Transformer config.
- return TransformerConfig(**kw_args)
-
-
-def _add_transformer_engine_args(parser):
- group = parser.add_argument_group(title='Transformer-Engine')
-
- group.add_argument('--fp8-format', default=None,
- choices=['e4m3', 'hybrid'],
- help='Which fp8 format scheme to use for FP8 tensors in the forward and backward pass',
- dest='fp8')
- group.add_argument('--fp8-margin', type=int, default=0,
- help='Scaling margin for fp8',
- dest='fp8_margin')
- group.add_argument('--fp8-interval', type=int, default=1,
- help='Scaling update interval for fp8',
- dest='fp8_interval')
- group.add_argument('--fp8-amax-history-len', type=int, default=1,
- help='Number of steps for which amax history is recorded per tensor',
- dest='fp8_amax_history_len')
- group.add_argument('--fp8-amax-compute-algo', default='most_recent',
- choices=['most_recent', 'max'],
- help='Algorithm for computing amax from history',
- dest='fp8_amax_compute_algo')
- group.add_argument('--no-fp8-wgrad', action='store_false',
- help='Execute wgrad in higher precision even for FP8 runs',
- dest='fp8_wgrad')
- group.add_argument('--transformer-impl', default='local',
- choices=['local', 'transformer_engine'],
- help='Which Transformer implementation to use.')
-
- return parser
-
-def _add_inference_args(parser):
- group = parser.add_argument_group(title='inference')
-
- group.add_argument('--inference-batch-times-seqlen-threshold',
- type=int, default=512,
- help='During inference, if batch-size times '
- 'sequence-length is smaller than this threshold '
- 'then we will not use pipelining, otherwise we will.')
- group.add_argument('--max-tokens-to-oom',
- type=int, default=12000,
- help='Maximum number of tokens during inference'
- 'tokens here is # in prompt + # to generate'
- 'Allows us to throw an error before OOM crashes server')
- group.add_argument('--output-bert-embeddings', action='store_true',
- help='Output Bert embeddings (via mean pooling) from '
- 'model, rather than its binary head output or entire '
- 'hidden batch.')
- group.add_argument('--bert-embedder-type', default="megatron",
- choices=["megatron", "huggingface"],
- help='Select either Megatron or Huggingface as the '
- 'Bert embedder.')
-
- return parser
-
-
-def _add_retro_args(parser):
- group = parser.add_argument_group(title='retro')
-
- group.add_argument('--retro-workdir', default=None,
- help='Retro working directory, which contains the '
- 'preprocessed data for for pretraining. This directory '
- 'is built during preprocessing (see '
- 'tools/retro/README.md), and contains subdirectories '
- 'for the chunk database and pretraining neighbors.')
- group.add_argument('--retro-add-retriever',
- action='store_true', default=False,
- help='Add a retriever to the transformer, for use in '
- 'pretraining a Retro model.')
- group.add_argument('--retro-cyclic-train-iters', type=int, default=None,
- help='Set number of training iterations for cyclic '
- 'Retro training.')
- group.add_argument('--retro-encoder-layers', type=int, default=2,
- help='Number of layers to use for the retrieval '
- 'encoder.')
- group.add_argument('--retro-encoder-hidden-dropout',
- type=float, default=0.1, help='Hidden dropout for '
- 'retrieval encoder.')
- group.add_argument('--retro-encoder-attention-dropout',
- type=float, default=0.1, help='Attention dropout for '
- 'retrieval encoder.')
- group.add_argument("--retro-num-neighbors", type=int, default=2,
- help='Number of neighbors to retrieve during '
- 'pretraining.')
- group.add_argument("--retro-num-retrieved-chunks", type=int, default=2,
- help='Number of chunks to retrieve from the retrieval '
- 'database.')
- group.add_argument("--retro-return-doc-ids", action="store_true",
- help="Turn this on when preprocessing retro data.")
- group.add_argument("--retro-no-verify-neighbor-count", action="store_false",
- dest="retro_verify_neighbor_count",
- help="Skip verifying that len(GPT dataset) == len(saved "
- "neighbors).")
-
- # Enforce argument naming convention.
- for action in group._group_actions:
- prefix = action.dest.split("_")[0]
- assert prefix == "retro", \
- "Retro args must be prefixed with '--retro-*', for consistent " \
- "styling. Please fix '%s'." % ", ".join(action.option_strings)
-
- return parser
-
-
-def _add_network_size_args(parser):
- group = parser.add_argument_group(title='network size')
-
- group.add_argument('--num-layers', type=int, default=None,
- help='Number of transformer layers.')
- group.add_argument('--encoder-num-layers', type=int, default=None,
- help='Number of encoder transformer layers.')
- group.add_argument('--decoder-num-layers', type=int, default=None,
- help='Number of decoder transformer layers.')
- group.add_argument('--hidden-size', type=int, default=None,
- help='Tansformer hidden size.')
- group.add_argument('--ffn-hidden-size', type=int, default=None,
- help='Transformer Feed-Forward Network hidden size. '
- 'This is set to 4*hidden-size if not provided')
- group.add_argument('--num-attention-heads', type=int, default=None,
- help='Number of transformer attention heads.')
- group.add_argument('--kv-channels', type=int, default=None,
- help='Projection weights dimension in multi-head '
- 'attention. This is set to '
- ' args.hidden_size // args.num_attention_heads '
- 'if not provided.')
- group.add_argument('--group-query-attention', action='store_true',
- help='Use group-query attention.')
- group.add_argument('--num-query-groups', type=int, default=1)
-
- group.add_argument('--max-position-embeddings', type=int, default=None,
- help='Maximum number of position embeddings to use. '
- 'This is the size of position embedding.')
- group.add_argument('--position-embedding-type', type=str, default='learned_absolute',
- choices=['learned_absolute', 'rope'],
- help='Position embedding type.')
- group.add_argument('--use-rotary-position-embeddings', action='store_true',
- help='Use rotary positional embeddings or not. '
- 'Deprecated: use --position-embedding-type')
- group.add_argument('--rotary-percent', type=float, default=1.0,
- help='Percent of rotary dimension to use, default 100%%')
- group.add_argument('--rotary-seq-len-interpolation-factor', type=int, default=None,
- help='Sequence length interpolation factor for rotary embeddings.')
- group.add_argument('--no-position-embedding',
- action='store_false',
- help='Disable position embedding. Deprecated: use --position-embedding-type',
- dest='add_position_embedding')
- group.add_argument('--make-vocab-size-divisible-by', type=int, default=128,
- help='Pad the vocab size to be divisible by this value.'
- 'This is added for computational efficieny reasons.')
- group.add_argument('--normalization', default='LayerNorm',
- choices=['LayerNorm', 'RMSNorm'],
- help='Which normalization technique to use.')
- group.add_argument('--norm-epsilon', type=float, default=1e-5,
- help='Epsilon for layer norm and RMS norm.')
- group.add_argument('--apply-layernorm-1p', action='store_true',
- help='Adjust LayerNorm weights such that they are centered '
- 'around zero. This improves numerical stability.')
- group.add_argument('--apply-residual-connection-post-layernorm',
- action='store_true',
- help='If set, use original BERT residula connection '
- 'ordering.')
- group.add_argument('--openai-gelu', action='store_true',
- help='Use OpenAIs GeLU implementation. This option'
- 'should not be used unless for backward compatibility'
- 'reasons.')
- group.add_argument('--squared-relu', action='store_true',
- help='Use squared relu activation instead of default gelu')
- group.add_argument('--swiglu', action='store_true',
- help='Use gated linear units and SiLU activation instead of default gelu')
- group.add_argument('--onnx-safe', type=bool, required=False,
- help='Use workarounds for known problems with '
- 'Torch ONNX exporter')
- group.add_argument('--bert-no-binary-head', action='store_false',
- help='Disable BERT binary head.',
- dest='bert_binary_head')
- group.add_argument('--num-experts', type=int, default=None,
- help='Number of Experts in Switch Transformer (None means no Switch)')
- group.add_argument('--untie-embeddings-and-output-weights', action='store_true',
- help='Untie embeddings and output weights.'),
- return parser
-
-
-def _add_logging_args(parser):
- group = parser.add_argument_group(title='logging')
-
- group.add_argument('--log-params-norm', action='store_true',
- help='If set, calculate and log parameters norm.')
- group.add_argument('--log-num-zeros-in-grad', action='store_true',
- help='If set, calculate and log the number of zeros in gradient.')
- group.add_argument('--log-throughput', action='store_true',
- help='If set, calculate and log throughput per GPU.')
- group.add_argument('--timing-log-level', type=int,
- default=0, choices=range(0,3),
- help='Granularity level to measure and report timing. '
- ' 0: report only iteration time and make sure timing '
- ' does not introduce extra overhead.'
- ' 1: report timing for operations that are executed '
- ' very limited times (basically once) during '
- ' each iteration (such as gradient all-reduce) '
- ' 2: report timing for operations that migh be '
- ' executed numerous times during each iteration. '
- 'Note that setting the level to 1 or 2 might '
- 'cause increase in iteration time.')
- group.add_argument('--no-barrier-with-level-1-timing', action='store_false',
- help='If not set, use barrier with level 1 time '
- 'measurements. Note that this is up to the user '
- 'to make sure calling barrier with their timers '
- 'will not result in hangs. This can happen if for '
- 'example the user adds a level 1 timer that is not '
- 'called by all ranks.',
- dest='barrier_with_L1_time')
- group.add_argument('--timing-log-option', type=str, default='minmax',
- choices=['max', 'minmax', 'all'],
- help='Options for logging timing:'
- ' max: report the max timing across all ranks'
- ' minmax: report min and max timings across all ranks'
- ' all: report timings of all ranks.')
- group.add_argument('--tensorboard-log-interval', type=int, default=1,
- help='Report to tensorboard interval.')
- group.add_argument('--tensorboard-queue-size', type=int, default=1000,
- help='Size of the tensorboard queue for pending events '
- 'and summaries before one of the ‘add’ calls forces a '
- 'flush to disk.')
- group.add_argument('--log-timers-to-tensorboard', action='store_true',
- help='If set, write timers to tensorboard.')
- group.add_argument('--log-batch-size-to-tensorboard', action='store_true',
- help='If set, write batch-size to tensorboard.')
- group.add_argument('--no-log-learnig-rate-to-tensorboard',
- action='store_false',
- help='Disable learning rate logging to tensorboard.',
- dest='log_learning_rate_to_tensorboard')
- group.add_argument('--no-log-loss-scale-to-tensorboard',
- action='store_false',
- help='Disable loss-scale logging to tensorboard.',
- dest='log_loss_scale_to_tensorboard')
- group.add_argument('--log-validation-ppl-to-tensorboard',
- action='store_true',
- help='If set, write validation perplexity to '
- 'tensorboard.')
- group.add_argument('--log-memory-to-tensorboard',
- action='store_true',
- help='Enable memory logging to tensorboard.')
- group.add_argument('--log-world-size-to-tensorboard',
- action='store_true',
- help='Enable world size logging to tensorboard.')
- group.add_argument('--wandb-project', type=str, default='',
- help='The wandb project name. Ignore wandb by default.')
- group.add_argument('--wandb-exp-name', type=str, default='',
- help='The wandb experiment name.')
- group.add_argument('--wandb-save-dir', type=str, default='',
- help='Path to save the wandb results locally.')
- return parser
-
-
-def _add_regularization_args(parser):
- group = parser.add_argument_group(title='regularization')
-
- group.add_argument('--attention-dropout', type=float, default=0.1,
- help='Post attention dropout probability.')
- group.add_argument('--hidden-dropout', type=float, default=0.1,
- help='Dropout probability for hidden state transformer.')
- group.add_argument('--weight-decay', type=float, default=0.01,
- help='Weight decay coefficient for L2 regularization.')
- group.add_argument('--start-weight-decay', type=float,
- help='Initial weight decay coefficient for L2 regularization.')
- group.add_argument('--end-weight-decay', type=float,
- help='End of run weight decay coefficient for L2 regularization.')
- group.add_argument('--weight-decay-incr-style', type=str, default='constant',
- choices=['constant', 'linear', 'cosine'],
- help='Weight decay increment function.')
- group.add_argument('--clip-grad', type=float, default=1.0,
- help='Gradient clipping based on global L2 norm.')
- group.add_argument('--adam-beta1', type=float, default=0.9,
- help='First coefficient for computing running averages '
- 'of gradient and its square')
- group.add_argument('--adam-beta2', type=float, default=0.999,
- help='Second coefficient for computing running averages '
- 'of gradient and its square')
- group.add_argument('--adam-eps', type=float, default=1e-08,
- help='Term added to the denominator to improve '
- 'numerical stability')
- group.add_argument('--sgd-momentum', type=float, default=0.9,
- help='Momentum factor for sgd')
- return parser
-
-
-def _add_training_args(parser):
- group = parser.add_argument_group(title='training')
-
- group.add_argument('--micro-batch-size', type=int, default=None,
- help='Batch size per model instance (local batch size). '
- 'Global batch size is local batch size times data '
- 'parallel size times number of micro batches.')
- group.add_argument('--batch-size', type=int, default=None,
- help='Old batch size parameter, do not use. '
- 'Use --micro-batch-size instead')
- group.add_argument('--global-batch-size', type=int, default=None,
- help='Training batch size. If set, it should be a '
- 'multiple of micro-batch-size times data-parallel-size. '
- 'If this value is None, then '
- 'use micro-batch-size * data-parallel-size as the '
- 'global batch size. This choice will result in 1 for '
- 'number of micro-batches.')
- group.add_argument('--rampup-batch-size', nargs='*', default=None,
- help='Batch size ramp up with the following values: '
- ' --rampup-batch-size <start batch size> '
- ' <batch size increment> '
- ' <ramp-up samples> '
- 'For example: '
- ' --rampup-batch-size 16 8 300000 \ '
- ' --global-batch-size 1024 '
- 'will start with global batch size 16 and over '
- ' (1024 - 16) / 8 = 126 intervals will increase '
- 'the batch size linearly to 1024. In each interval '
- 'we will use approximately 300000 / 126 = 2380 samples.')
- group.add_argument('--recompute-activations', action='store_true',
- help='recompute activation to allow for training '
- 'with larger models, sequences, and batch sizes.')
- group.add_argument('--recompute-granularity', type=str, default=None,
- choices=['full', 'selective'],
- help='Checkpoint activations to allow for training '
- 'with larger models, sequences, and batch sizes. '
- 'It is supported at two granularities 1) full: '
- 'whole transformer layer is recomputed, '
- '2) selective: core attention part of the transformer '
- 'layer is recomputed.')
- group.add_argument('--no-check-for-nan-in-loss-and-grad', action='store_false',
- help='Check for NaNs in loss and grad',
- dest='check_for_nan_in_loss_and_grad')
- group.add_argument('--distribute-saved-activations',
- action='store_true',
- help='If set, distribute recomputed activations '
- 'across model parallel group.')
- group.add_argument('--recompute-method', type=str, default=None,
- choices=['uniform', 'block'],
- help='1) uniform: uniformly divide the total number of '
- 'Transformer layers and recompute the input activation of '
- 'each divided chunk at specified granularity, '
- '2) recompute the input activations of only a set number of '
- 'individual Transformer layers per pipeline stage and do the '
- 'rest without any recomputing at specified granularity, '
- 'default) do not apply activations recompute to any layers')
- group.add_argument('--recompute-num-layers', type=int, default=None,
- help='1) uniform: the number of Transformer layers in each '
- 'uniformly divided recompute unit, '
- '2) block: the number of individual Transformer layers '
- 'to recompute within each pipeline stage.')
- group.add_argument('--no-clone-scatter-output-in-embedding', action='store_false',
- help='If not set, clone the output of the scatter in embedding layer to GC original tensor.',
- dest='clone_scatter_output_in_embedding')
- group.add_argument('--profile', action='store_true',
- help='Enable nsys profiling. When using this option, nsys '
- 'options should be specified in commandline. An example '
- 'nsys commandline is `nsys profile -s none -t nvtx,cuda '
- '-o <path/to/output_file> --force-overwrite true '
- '--capture-range=cudaProfilerApi '
- '--capture-range-end=stop`.')
- group.add_argument('--profile-step-start', type=int, default=10,
- help='Global step to start profiling.')
- group.add_argument('--profile-step-end', type=int, default=12,
- help='Global step to stop profiling.')
- group.add_argument('--profile-ranks', nargs='+', type=int, default=[0],
- help='Global ranks to profile.')
- group.add_argument('--tp-comm-overlap', action='store_true', help = 'Enables the '
- ' overlap of Tensor parallel communication and GEMM kernels.')
- group.add_argument('--tp-comm-overlap-cfg', type=str, default=None,
- help = 'Config file when tp_comm_overlap is enabled.')
- group.add_argument('--disable-tp-comm-split-ag', action='store_false',
- help = 'Disables the All-Gather overlap with fprop GEMM.',
- dest='tp_comm_split_ag')
- group.add_argument('--disable-tp-comm-split-rs', action='store_false',
- help = 'Disables the Reduce-Scatter overlap with fprop GEMM.',
- dest='tp_comm_split_rs')
- group.add_argument('--disable-tp-comm-bulk-dgrad', action='store_false',
- help = 'Disables the All-Gather overlap with bprop activation gradient GEMM.',
- dest='tp_comm_bulk_dgrad')
- group.add_argument('--disable-tp-comm-bulk-wgrad', action='store_false',
- help = 'Disables the Reduce-Scatter overlap with bprop weight gradient GEMM.',
- dest='tp_comm_bulk_wgrad')
-
-
- # deprecated
- group.add_argument('--checkpoint-activations', action='store_true',
- help='Checkpoint activation to allow for training '
- 'with larger models, sequences, and batch sizes.')
- group.add_argument('--train-iters', type=int, default=None,
- help='Total number of iterations to train over all '
- 'training runs. Note that either train-iters or '
- 'train-samples should be provided.')
- group.add_argument('--train-samples', type=int, default=None,
- help='Total number of samples to train over all '
- 'training runs. Note that either train-iters or '
- 'train-samples should be provided.')
- group.add_argument('--log-interval', type=int, default=100,
- help='Report loss and timing interval.')
- group.add_argument('--exit-interval', type=int, default=None,
- help='Exit the program after the iteration is divisible '
- 'by this value.')
- group.add_argument('--exit-duration-in-mins', type=int, default=None,
- help='Exit the program after this many minutes.')
- group.add_argument('--exit-signal-handler', action='store_true',
- help='Dynamically save the checkpoint and shutdown the '
- 'training if SIGTERM is received')
- group.add_argument('--tensorboard-dir', type=str, default=None,
- help='Write TensorBoard logs to this directory.')
- group.add_argument('--no-masked-softmax-fusion',
- action='store_false',
- help='Disable fusion of query_key_value scaling, '
- 'masking, and softmax.',
- dest='masked_softmax_fusion')
- group.add_argument('--no-bias-gelu-fusion', action='store_false',
- help='Disable bias and gelu fusion.',
- dest='bias_gelu_fusion')
- group.add_argument('--no-bias-dropout-fusion', action='store_false',
- help='Disable bias and dropout fusion.',
- dest='bias_dropout_fusion')
- group.add_argument('--use-flash-attn', action='store_true',
- help='use FlashAttention implementation of attention. '
- 'https://arxiv.org/abs/2205.14135')
- group.add_argument('--disable-bias-linear', action='store_false',
- help='Disable bias in the linear layers',
- dest='add_bias_linear')
- group.add_argument('--optimizer', type=str, default='adam',
- choices=['adam', 'sgd'],
- help='Optimizer function')
- group.add_argument('--dataloader-type', type=str, default=None,
- choices=['single', 'cyclic'],
- help='Single pass vs multiple pass data loader')
- group.add_argument('--no-async-tensor-model-parallel-allreduce',
- action='store_false',
- help='Disable asynchronous execution of '
- 'tensor-model-parallel all-reduce with weight '
- 'gradient computation of a column-linear layer.',
- dest='async_tensor_model_parallel_allreduce')
- group.add_argument('--no-persist-layer-norm', action='store_true',
- help='Disable using persistent fused layer norm kernel. '
- 'This kernel supports only a set of hidden sizes. Please '
- 'check persist_ln_hidden_sizes if your hidden '
- 'size is supported.')
- group.add_argument('--sequence-parallel', action='store_true',
- help='Enable sequence parallel optimization.')
- group.add_argument('--no-gradient-accumulation-fusion',
- action='store_false',
- help='Disable fusing gradient accumulation to weight '
- 'gradient computation of linear layers',
- dest='gradient_accumulation_fusion')
- group.add_argument('--use-mcore-models', action='store_true',
- help='Use the implementation from megatron core')
- group.add_argument('--manual-gc', action='store_true',
- help='Disable the threshold-based default garbage '
- 'collector and trigger the garbage collection manually. '
- 'Manual garbage collection helps to align the timing of '
- 'the collection across ranks which mitigates the impact '
- 'of CPU-associated jitters. When the manual gc is enabled, '
- 'garbage collection is performed only at the start and the '
- 'end of the validation routine by default.')
- group.add_argument('--manual-gc-interval', type=int, default=0,
- help='Training step interval to trigger manual garbage '
- 'collection. When the value is set to 0, garbage '
- 'collection is not triggered between training steps.')
- group.add_argument('--no-manual-gc-eval', action='store_false',
- help='When using manual garbage collection, disable '
- 'garbage collection at the start and the end of each '
- 'evaluation run.', dest='manual_gc_eval')
-
- return parser
-
-
-def _add_initialization_args(parser):
- group = parser.add_argument_group(title='initialization')
-
- group.add_argument('--seed', type=int, default=1234,
- help='Random seed used for python, numpy, '
- 'pytorch, and cuda.')
- group.add_argument('--data-parallel-random-init', action='store_true',
- help='Enable random initialization of params '
- 'across data parallel ranks')
- group.add_argument('--init-method-std', type=float, default=0.02,
- help='Standard deviation of the zero mean normal '
- 'distribution used for weight initialization.')
- group.add_argument('--init-method-xavier-uniform', action='store_true',
- help='Enable Xavier uniform parameter initialization')
-
- return parser
-
-
-def _add_learning_rate_args(parser):
- group = parser.add_argument_group(title='learning rate')
-
- group.add_argument('--lr', type=float, default=None,
- help='Initial learning rate. Depending on decay style '
- 'and initial warmup, the learning rate at each '
- 'iteration would be different.')
- group.add_argument('--lr-decay-style', type=str, default='linear',
- choices=['constant', 'linear', 'cosine', 'inverse-square-root'],
- help='Learning rate decay function.')
- group.add_argument('--lr-decay-iters', type=int, default=None,
- help='number of iterations to decay learning rate over,'
- ' If None defaults to `--train-iters`')
- group.add_argument('--lr-decay-samples', type=int, default=None,
- help='number of samples to decay learning rate over,'
- ' If None defaults to `--train-samples`')
- group.add_argument('--lr-warmup-fraction', type=float, default=None,
- help='fraction of lr-warmup-(iters/samples) to use '
- 'for warmup (as a float)')
- group.add_argument('--lr-warmup-iters', type=int, default=0,
- help='number of iterations to linearly warmup '
- 'learning rate over.')
- group.add_argument('--lr-warmup-samples', type=int, default=0,
- help='number of samples to linearly warmup '
- 'learning rate over.')
- group.add_argument('--lr-warmup-init', type=float, default=0.0,
- help='Initial value for learning rate warmup. The '
- 'scheduler starts warmup from this value.')
- group.add_argument('--warmup', type=int, default=None,
- help='Old lr warmup argument, do not use. Use one of the '
- '--lr-warmup-* arguments above')
- group.add_argument('--min-lr', type=float, default=0.0,
- help='Minimum value for learning rate. The scheduler '
- 'clips values below this threshold.')
- group.add_argument('--override-opt_param-scheduler', action='store_true',
- help='Reset the values of the scheduler (learning rate, '
- 'warmup iterations, minimum learning rate, maximum '
- 'number of iterations, and decay style) from input '
- 'arguments and ignore values from checkpoints. Note '
- 'that all the above values will be reset.')
- group.add_argument('--use-checkpoint-opt_param-scheduler', action='store_true',
- help='Use checkpoint to set the values of the scheduler '
- '(learning rate, warmup iterations, minimum learning '
- 'rate, maximum number of iterations, and decay style) '
- 'from checkpoint and ignore input arguments.')
-
- return parser
-
-
-def _add_checkpointing_args(parser):
- group = parser.add_argument_group(title='checkpointing')
-
- group.add_argument('--save', type=str, default=None,
- help='Output directory to save checkpoints to.')
- group.add_argument('--save-interval', type=int, default=None,
- help='Number of iterations between checkpoint saves.')
- group.add_argument('--no-save-optim', action='store_true', default=None,
- help='Do not save current optimizer.')
- group.add_argument('--no-save-rng', action='store_true', default=None,
- help='Do not save current rng state.')
- group.add_argument('--load', type=str, default=None,
- help='Directory containing a model checkpoint.')
- group.add_argument('--no-load-optim', action='store_true', default=None,
- help='Do not load optimizer when loading checkpoint.')
- group.add_argument('--no-load-rng', action='store_true', default=None,
- help='Do not load rng state when loading checkpoint.')
- group.add_argument('--finetune', action='store_true',
- help='Load model for finetuning. Do not load optimizer '
- 'or rng state from checkpoint and set iteration to 0. '
- 'Assumed when loading a release checkpoint.')
- group.add_argument('--no-initialization', action='store_false',
- help='Do not perform initialization when building model, '
- 'can reduce startup time when definitely loading from a '
- 'checkpoint',
- dest='perform_initialization')
- group.add_argument('--use-checkpoint-args', action='store_true',
- help='Override any command line arguments with arguments '
- 'from the checkpoint')
- group.add_argument('--exit-on-missing-checkpoint', action='store_true',
- help="If '--load' is set, but checkpoint is not found "
- "(e.g., path typo), then exit instead of random "
- "initialization.")
-
- return parser
-
-
-def _add_mixed_precision_args(parser):
- group = parser.add_argument_group(title='mixed precision')
-
- group.add_argument('--fp16', action='store_true',
- help='Run model in fp16 mode.')
- group.add_argument('--bf16', action='store_true',
- help='Run model in bfloat16 mode.')
- group.add_argument('--loss-scale', type=float, default=None,
- help='Static loss scaling, positive power of 2 '
- 'values can improve fp16 convergence. If None, dynamic '
- 'loss scaling is used.')
- group.add_argument('--initial-loss-scale', type=float, default=2**32,
- help='Initial loss-scale for dynamic loss scaling.')
- group.add_argument('--min-loss-scale', type=float, default=1.0,
- help='Minimum loss scale for dynamic loss scale.')
- group.add_argument('--loss-scale-window', type=float, default=1000,
- help='Window over which to raise/lower dynamic scale.')
- group.add_argument('--hysteresis', type=int, default=2,
- help='hysteresis for dynamic loss scaling')
- group.add_argument('--fp32-residual-connection', action='store_true',
- help='Move residual connections to fp32.')
- group.add_argument('--apply-query-key-layer-scaling', action='store_true',
- help='Scale Q * K^T by 1 / layer-number. '
- 'Useful for fp16 training.')
- group.add_argument('--attention-softmax-in-fp32', action='store_true',
- help='Run attention masking and softmax in fp32. '
- 'This flag is ignored unless '
- '--no-query-key-layer-scaling is specified.')
- group.add_argument('--accumulate-allreduce-grads-in-fp32',
- action='store_true',
- help='Gradient accumulation and all-reduce in fp32.')
- group.add_argument('--fp16-lm-cross-entropy', action='store_true',
- help='Move the cross entropy unreduced loss calculation '
- 'for lm head to fp16.')
-
- return parser
-
-
-def _add_distributed_args(parser):
- group = parser.add_argument_group(title='distributed')
-
- group.add_argument('--tensor-model-parallel-size', type=int, default=1,
- help='Degree of tensor model parallelism.')
- group.add_argument('--pipeline-model-parallel-size', type=int, default=1,
- help='Degree of pipeline model parallelism.')
- group.add_argument('--pipeline-model-parallel-split-rank',
- type=int, default=None,
- help='Rank where encoder and decoder should be split.')
- group.add_argument('--model-parallel-size', type=int, default=None,
- help='Old model parallel argument, do not use. Use '
- '--tensor-model-parallel-size instead.')
- group.add_argument('--num-layers-per-virtual-pipeline-stage', type=int, default=None,
- help='Number of layers per virtual pipeline stage')
- group.add_argument('--no-overlap-p2p-communication', action='store_false',
- help='overlap pipeline parallel communication with forward and backward chunks',
- dest='overlap_p2p_comm')
- group.add_argument('--distributed-backend', default='nccl',
- choices=['nccl', 'gloo'],
- help='Which backend to use for distributed training.')
- group.add_argument('--distributed-timeout-minutes', type=int, default=10,
- help='Timeout minutes for torch.distributed.')
- group.add_argument('--overlap-grad-reduce', action='store_true',
- default=False, help='If set, overlap DDP grad reduce.')
- group.add_argument('--no-delay-grad-reduce', action='store_false',
- help='If not set, delay / synchronize grad reductions in all but first PP stage.',
- dest='delay_grad_reduce')
- group.add_argument('--overlap-param-gather', action='store_true',
- default=False, help='If set, overlap param all-gather in distributed optimizer.')
- group.add_argument('--delay-param-gather', action='store_true',
- default=False, help='If set, delay / synchronize param all-gathers in all but first PP stage.')
- group.add_argument('--no-scatter-gather-tensors-in-pipeline', action='store_false',
- help='If not set, use scatter/gather to optimize communication of tensors in pipeline.',
- dest='scatter_gather_tensors_in_pipeline')
- group.add_argument('--use-ring-exchange-p2p', action='store_true',
- default=False, help='If set, use custom-built ring exchange '
- 'for p2p communications. Note that this option will require '
- 'a custom built image that supports ring-exchange p2p.')
- group.add_argument('--local_rank', type=int, default=None,
- help='local rank passed from distributed launcher.')
- group.add_argument('--lazy-mpu-init', type=bool, required=False,
- help='If set to True, initialize_megatron() '
- 'skips DDP initialization and returns function to '
- 'complete it instead. Also turns on '
- '--use-cpu-initialization flag. This is for '
- 'external DDP manager.' )
- group.add_argument('--use-cpu-initialization', action='store_true',
- default=None, help='If set, affine parallel weights '
- 'initialization uses CPU' )
- group.add_argument('--empty-unused-memory-level', default=0, type=int,
- choices=[0, 1, 2],
- help='Call torch.cuda.empty_cache() each iteration '
- '(training and eval), to reduce fragmentation. '
- '0=off, 1=moderate, 2=aggressive.')
- group.add_argument('--standalone-embedding-stage', action='store_true',
- default=False, help='If set, *input* embedding layer '
- 'is placed on its own pipeline stage, without any '
- 'transformer layers. (For T5, this flag currently only '
- 'affects the encoder embedding.)')
- group.add_argument('--use-distributed-optimizer', action='store_true',
- help='Use distributed optimizer.')
- group.add_argument('--expert-model-parallel-size', type=int, default=1,
- help='Degree of expert model parallelism.')
- group.add_argument('--context-parallel-size', type=int, default=1,
- help='Degree of context parallelism.')
- group.add_argument('--nccl-communicator-config-path', type=str, default=None,
- help='Path to the yaml file with NCCL communicator '
- 'configurations. The number of min/max thread groups and thread '
- 'group cluster size of each communicator can be configured by '
- 'setting `min_ctas`, `max_ctas`, and `cga_cluster_size`.')
- return parser
-
-
-def _add_validation_args(parser):
- group = parser.add_argument_group(title='validation')
-
- group.add_argument('--eval-iters', type=int, default=100,
- help='Number of iterations to run evaluation '
- '(validation/test) for.')
- group.add_argument('--eval-interval', type=int, default=1000,
- help='Interval between running evaluation on '
- 'validation set.')
- group.add_argument('--skip-train', action='store_true',
- default=False, help='If set, bypass the training loop, '
- 'optionally do evaluation for validation/test, and exit.')
-
- return parser
-
-
-def _add_data_args(parser):
- group = parser.add_argument_group(title='data and dataloader')
-
- group.add_argument('--data-path', nargs='*', default=None,
- help='Path to the training dataset. Accepted format: '
- '1) a single data path, 2) multiple datasets in the '
- 'form: dataset1-weight dataset1-path dataset2-weight '
- 'dataset2-path ... It is used with --split when a '
- 'single dataset is used for all three: train, valid '
- 'and test. It is mutually exclusive with the other '
- '--*-data-path args')
- group.add_argument('--split', type=str, default='969, 30, 1',
- help='Comma-separated list of proportions for training,'
- ' validation, and test split. For example the split '
- '`90,5,5` will use 90%% of data for training, 5%% for '
- 'validation and 5%% for test.')
- group.add_argument('--train-data-path', nargs='*', default=None,
- help='Path to the training dataset. Accepted format: '
- '1) a single data path, 2) multiple datasets in the '
- 'form: dataset1-weight dataset1-path dataset2-weight '
- 'dataset2-path ...')
- group.add_argument('--valid-data-path', nargs='*', default=None,
- help='Path to the validation dataset. Accepted format: '
- '1) a single data path, 2) multiple datasets in the '
- 'form: dataset1-weight dataset1-path dataset2-weight '
- 'dataset2-path ...')
- group.add_argument('--test-data-path', nargs='*', default=None,
- help='Path to the test dataset. Accepted format: '
- '1) a single data path, 2) multiple datasets in the '
- 'form: dataset1-weight dataset1-path dataset2-weight '
- 'dataset2-path ...')
- group.add_argument('--data-cache-path', default=None,
- help='Path to a directory to hold cached index files.')
-
- group.add_argument('--vocab-size', type=int, default=None,
- help='Size of vocab before EOD or padding.')
- group.add_argument('--vocab-file', type=str, default=None,
- help='Path to the vocab file.')
- group.add_argument('--merge-file', type=str, default=None,
- help='Path to the BPE merge file.')
- group.add_argument('--vocab-extra-ids', type=int, default=0,
- help='Number of additional vocabulary tokens. '
- 'They are used for span masking in the T5 model')
- group.add_argument('--seq-length', type=int, default=None,
- help='Maximum sequence length to process.')
- group.add_argument('--encoder-seq-length', type=int, default=None,
- help='Maximum encoder sequence length to process. '
- 'This should be exclusive of --seq-length')
- group.add_argument('--decoder-seq-length', type=int, default=None,
- help="Maximum decoder sequence length to process.")
- group.add_argument('--retriever-seq-length', type=int, default=256,
- help='Maximum sequence length for the biencoder model '
- 'for retriever')
- group.add_argument('--sample-rate', type=float, default=1.0,
- help='sample rate for training data. Supposed to be 0 '
- ' < sample_rate < 1')
- group.add_argument('--mask-prob', type=float, default=0.15,
- help='Probability of replacing a token with mask.')
- group.add_argument('--short-seq-prob', type=float, default=0.1,
- help='Probability of producing a short sequence.')
- group.add_argument('--num-workers', type=int, default=2,
- help="Dataloader number of workers.")
- group.add_argument('--tokenizer-type', type=str,
- default=None,
- choices=['BertWordPieceLowerCase',
- 'BertWordPieceCase',
- 'GPT2BPETokenizer',
- 'SentencePieceTokenizer',
- 'GPTSentencePieceTokenizer',
- 'Llama2Tokenizer',
- 'NullTokenizer'],
- help='What type of tokenizer to use.')
- group.add_argument('--tokenizer-model', type=str, default=None,
- help='Sentencepiece tokenizer model.')
- group.add_argument('--reset-position-ids', action='store_true',
- help='Reset position ids after end-of-document token.')
- group.add_argument('--reset-attention-mask', action='store_true',
- help='Reset self attention mask after '
- 'end-of-document token.')
- group.add_argument('--eod-mask-loss', action='store_true',
- help='Mask loss for the end of document tokens.')
-
- return parser
-
-
-def _add_autoresume_args(parser):
- group = parser.add_argument_group(title='autoresume')
-
- group.add_argument('--adlr-autoresume', action='store_true',
- help='Enable autoresume on adlr cluster.')
- group.add_argument('--adlr-autoresume-interval', type=int, default=1000,
- help='Interval over which to check for the autoresume '
- 'termination signal')
-
- return parser
-
-
-def _add_biencoder_args(parser):
- group = parser.add_argument_group(title='biencoder')
-
- # network size
- group.add_argument('--ict-head-size', type=int, default=None,
- help='Size of block embeddings to be used in ICT and '
- 'REALM (paper default: 128)')
- group.add_argument('--biencoder-projection-dim', type=int, default=0,
- help='Size of projection head used in biencoder (paper'
- ' default: 128)')
- group.add_argument('--biencoder-shared-query-context-model', action='store_true',
- help='Whether to share the parameters of the query '
- 'and context models or not')
-
- # checkpointing
- group.add_argument('--ict-load', type=str, default=None,
- help='Directory containing an ICTBertModel checkpoint')
- group.add_argument('--bert-load', type=str, default=None,
- help='Directory containing a BertModel checkpoint '
- '(needed to start ICT and REALM)')
-
- # data
- group.add_argument('--titles-data-path', type=str, default=None,
- help='Path to titles dataset used for ICT')
- group.add_argument('--query-in-block-prob', type=float, default=0.1,
- help='Probability of keeping query in block for '
- 'ICT dataset')
- group.add_argument('--use-one-sent-docs', action='store_true',
- help='Whether to use one sentence documents in ICT')
- group.add_argument('--evidence-data-path', type=str, default=None,
- help='Path to Wikipedia Evidence from the DPR paper')
-
- # training
- group.add_argument('--retriever-report-topk-accuracies', nargs='+', type=int,
- default=[], help="Which top-k accuracies to report "
- "(e.g. '1 5 20')")
- group.add_argument('--retriever-score-scaling', action='store_true',
- help='Whether to scale retriever scores by inverse '
- 'square root of hidden size')
-
- # faiss index
- group.add_argument('--block-data-path', type=str, default=None,
- help='Where to save/load BlockData to/from')
- group.add_argument('--embedding-path', type=str, default=None,
- help='Where to save/load Open-Retrieval Embedding'
- ' data to/from')
-
- # indexer
- group.add_argument('--indexer-batch-size', type=int, default=128,
- help='Batch size to use when running indexing '
- 'jobs')
- group.add_argument('--indexer-log-interval', type=int, default=1000,
- help='After how many batches should the indexer '
- 'report progress')
- return parser
-
-
-def _add_vision_args(parser):
- group = parser.add_argument_group(title="vision")
-
- # general vision arguments
- group.add_argument('--num-classes', type=int, default=1000,
- help='Number of classes in the vision classification task')
- group.add_argument('--img-h', type=int, default=224,
- help='Image height for vision classification task')
- group.add_argument('--img-w', type=int, default=224,
- help='Image width for vision classification task')
- group.add_argument('--num-channels', type=int, default=3,
- help='Number of channels in input image data')
- group.add_argument('--patch-dim', type=int, default=16,
- help='patch dimension')
- group.add_argument('--classes-fraction', type=float, default=1.0,
- help='training with fraction of classes.')
- group.add_argument('--data-per-class-fraction', type=float, default=1.0,
- help='training with fraction of data per class.')
- group.add_argument('--no-data-sharding', action='store_false',
- help='Disable data sharding.',
- dest='data_sharding')
- group.add_argument('--head-lr-mult', type=float, default=1.0,
- help='learning rate multiplier for head during finetuning')
-
- # pretraining type and backbone selection
- group.add_argument('--vision-pretraining', action='store_true',
- help='flag to indicate vision pretraining')
- group.add_argument('--vision-pretraining-type', type=str, default='classify',
- choices=['classify', 'inpaint', 'dino'],
- help='pretraining objectives')
- group.add_argument('--vision-backbone-type', type=str, default='vit',
- choices=['vit', 'mit', 'swin'],
- help='backbone type')
- group.add_argument('--swin-backbone-type', type=str, default='tiny',
- choices=['tiny', 'base', 'h3'],
- help='swin backbone size variant')
-
- # inpainting arguments
- group.add_argument('--mask-type', type=str, default='random',
- choices=['random', 'row'],
- help='mask types')
- group.add_argument('--mask-factor', type=float, default=1.0,
- help='mask size scaling parameter')
-
- # dino arguments
- group.add_argument('--iter-per-epoch', type=int, default=1250,
- help='iterations per epoch')
- group.add_argument('--dino-local-img-size', type=int, default=96,
- help='Image size for DINO local crops')
- group.add_argument('--dino-local-crops-number', type=int, default=10,
- help='Number of local crops')
- group.add_argument('--dino-head-hidden-size', type=int, default=2048,
- help='Hidden dimension size in dino head')
- group.add_argument('--dino-bottleneck-size', type=int, default=256,
- help='Bottleneck dimension in the DINO head')
- group.add_argument('--dino-freeze-last-layer', type=float, default=1,
- help='Freezing last layer weights')
- group.add_argument('--dino-norm-last-layer', action='store_true',
- help='Disable Norm in last layer.')
- group.add_argument('--dino-warmup-teacher-temp', type=float, default=0.04,
- help='warmup teacher temperature')
- group.add_argument('--dino-teacher-temp', type=float, default=0.07,
- help='teacher temperature')
- group.add_argument('--dino-warmup-teacher-temp-epochs', type=int, default=30,
- help='warmup teacher temperature epochs')
-
- return parser
-
-def _add_experimental_args(parser):
- group = parser.add_argument_group(title='experimental')
-
- group.add_argument('--spec', type=str, default=None, nargs=2,
- help='Specify the <module_location function_name> pair '
- 'that returns a spec to customize a model, transformer '
- 'block, or transformer layer, depending on the use case. '
- 'For more details, see the model class, '
- '`transformer_block.py`, or `transformer_layer.py`')
-
- return parser
diff --git a/megatron/checkpointing.py b/megatron/checkpointing.py
deleted file mode 100644
index 2be766e384578cbdff64d199571f0f753c013c6d..0000000000000000000000000000000000000000
--- a/megatron/checkpointing.py
+++ /dev/null
@@ -1,714 +0,0 @@
-# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
-
-"""Input/output checkpointing."""
-
-import os
-import random
-import sys
-import numpy as np
-
-import torch
-
-from megatron import update_num_microbatches
-from megatron.core import mpu, tensor_parallel
-from .global_vars import get_args
-from .utils import (unwrap_model,
- print_rank_0)
-
-
-_CHECKPOINT_VERSION = None
-
-
-def set_checkpoint_version(value):
- global _CHECKPOINT_VERSION
- if _CHECKPOINT_VERSION is not None:
- assert _CHECKPOINT_VERSION == value, \
- "checkpoint versions do not match"
- _CHECKPOINT_VERSION = value
-
-
-def get_checkpoint_version():
- global _CHECKPOINT_VERSION
- return _CHECKPOINT_VERSION
-
-
-def check_checkpoint_args(checkpoint_args):
- """Ensure fixed arguments for a model are the same for the input
- arguments and the ones retrieved from the checkpoint."""
- args = get_args()
-
- def _compare(arg_name, old_arg_name=None, default=None):
- if old_arg_name is not None:
- ckpt_arg_name = old_arg_name
- else:
- ckpt_arg_name = arg_name
- if default is not None:
- checkpoint_value = getattr(checkpoint_args, ckpt_arg_name, default)
- else:
- checkpoint_value = getattr(checkpoint_args, ckpt_arg_name)
- args_value = getattr(args, arg_name)
- error_message = '{} value from checkpoint ({}) is not equal to the ' \
- 'input argument value ({}).'.format(
- arg_name, checkpoint_value, args_value)
- assert checkpoint_value == args_value, error_message
-
- _compare('num_layers')
- _compare('hidden_size')
- _compare('num_attention_heads')
- _compare('add_position_embedding', default=True)
- if args.vocab_file:
- _compare('max_position_embeddings')
- _compare('make_vocab_size_divisible_by')
- _compare('padded_vocab_size')
- _compare('tokenizer_type')
- if args.data_parallel_random_init:
- _compare('data_parallel_random_init')
- if get_checkpoint_version() < 3.0:
- _compare('tensor_model_parallel_size',
- old_arg_name='model_parallel_size')
- if get_checkpoint_version() >= 3.0:
- _compare('tensor_model_parallel_size')
- _compare('pipeline_model_parallel_size')
-
-
-def ensure_directory_exists(filename):
- """Build filename's path if it does not already exists."""
- dirname = os.path.dirname(filename)
- os.makedirs(dirname, exist_ok = True)
-
-
-def get_checkpoint_name(checkpoints_path, iteration, release=False,
- pipeline_parallel=None,
- tensor_rank=None, pipeline_rank=None,
- expert_parallel=None, expert_rank=None):
- """Determine the directory name for this rank's checkpoint."""
- if release:
- directory = 'release'
- else:
- directory = 'iter_{:07d}'.format(iteration)
-
- # Use both the tensor and pipeline MP rank.
- if pipeline_parallel is None:
- pipeline_parallel = (mpu.get_pipeline_model_parallel_world_size() > 1)
- if tensor_rank is None:
- tensor_rank = mpu.get_tensor_model_parallel_rank()
- if pipeline_rank is None:
- pipeline_rank = mpu.get_pipeline_model_parallel_rank()
- if expert_parallel is None:
- expert_parallel = (mpu.get_expert_model_parallel_world_size() > 1)
- if expert_rank is None:
- expert_rank = mpu.get_expert_model_parallel_rank()
-
- # Use both the tensor and pipeline MP rank. If using the distributed
- # optimizer, then the optimizer's path must additionally include the
- # data parallel rank.
- if not pipeline_parallel:
- common_path = os.path.join(checkpoints_path, directory,
- f'mp_rank_{tensor_rank:02d}')
- else:
- common_path = os.path.join(checkpoints_path, directory,
- f'mp_rank_{tensor_rank:02d}_{pipeline_rank:03d}')
-
- if expert_parallel:
- common_path = common_path + f'_{expert_rank:03d}'
-
- return os.path.join(common_path, "model_optim_rng.pt")
-
-
-def get_distributed_optimizer_checkpoint_name(model_checkpoint_name):
- return os.path.join(os.path.dirname(model_checkpoint_name),
- "distrib_optim.pt")
-
-
-def find_checkpoint_rank_0(checkpoints_path, iteration, release=False):
- """Finds the checkpoint for rank 0 without knowing if we are using
- pipeline parallelism/expert parallelism or not.
-
- Since the checkpoint naming scheme changes if pipeline or expert
- parallelism is present, we need to look for both naming schemes if
- we don't know if the checkpoint has pipeline or expert parallelism.
- """
-
- # Look for checkpoint with no pipelining and no expert parallelism
- filename = get_checkpoint_name(checkpoints_path, iteration, release,
- pipeline_parallel=False,
- tensor_rank=0, pipeline_rank=0,
- expert_parallel=False, expert_rank=0)
- if os.path.isfile(filename):
- return filename
-
- # Look for checkpoint with no pipelining and expert parallelism
- filename = get_checkpoint_name(checkpoints_path, iteration, release,
- pipeline_parallel=False,
- tensor_rank=0, pipeline_rank=0,
- expert_parallel=True, expert_rank=0)
- if os.path.isfile(filename):
- return filename
-
- # Look for checkpoint with pipelining and no expert parallelism
- filename = get_checkpoint_name(checkpoints_path, iteration, release,
- pipeline_parallel=True,
- tensor_rank=0, pipeline_rank=0,
- expert_parallel=False, expert_rank=0)
- if os.path.isfile(filename):
- return filename
-
- # Look for checkpoint with pipelining and expert parallelism
- filename = get_checkpoint_name(checkpoints_path, iteration, release,
- pipeline_parallel=True,
- tensor_rank=0, pipeline_rank=0,
- expert_parallel=True, expert_rank=0)
- if os.path.isfile(filename):
- return filename
-
- return None, None
-
-
-def get_checkpoint_tracker_filename(checkpoints_path):
-
- """Tracker file rescords the latest chckpoint during
- training to restart from."""
- return os.path.join(checkpoints_path, 'latest_checkpointed_iteration.txt')
-
-
-def read_metadata(tracker_filename):
- # Read the tracker file and either set the iteration or
- # mark it as a release checkpoint.
- iteration = 0
- release = False
- with open(tracker_filename, 'r') as f:
- metastring = f.read().strip()
- try:
- iteration = int(metastring)
- except ValueError:
- release = metastring == 'release'
- if not release:
- print_rank_0('ERROR: Invalid metadata file {}. Exiting'.format(
- tracker_filename))
- sys.exit()
- assert iteration > 0 or release, 'error parsing metadata file {}'.format(
- tracker_filename)
-
- # Get the max iteration retrieved across the ranks.
- if torch.distributed.is_initialized():
- iters_cuda = torch.cuda.LongTensor([iteration])
- torch.distributed.all_reduce(iters_cuda, op=torch.distributed.ReduceOp.MAX)
- max_iter = iters_cuda[0].item()
-
- # We should now have all the same iteration.
- # If not, print a warning and choose the maximum
- # iteration across all ranks.
- if iteration != max_iter:
- rank = torch.distributed.get_rank()
- print('WARNING: on rank {} found iteration {} in the '
- 'metadata while max iteration across the ranks '
- 'is {}, replacing it with max iteration.'.format(
- rank, iteration, max_iter), flush=True)
- else:
- # When loading a checkpoint outside of training (for example,
- # when editing it), we might not have torch distributed
- # initialized; in this case, just assume we have the latest
- max_iter = iteration
- return max_iter, release
-
-
-def get_rng_state():
- """ collect rng state across data parallel ranks """
- args = get_args()
- rng_state = {
- 'random_rng_state': random.getstate(),
- 'np_rng_state': np.random.get_state(),
- 'torch_rng_state': torch.get_rng_state(),
- 'cuda_rng_state': torch.cuda.get_rng_state(),
- 'rng_tracker_states': tensor_parallel.get_cuda_rng_tracker().get_states()}
-
- rng_state_list = None
- if torch.distributed.is_initialized() and \
- mpu.get_data_parallel_world_size() > 1 and \
- args.data_parallel_random_init:
- rng_state_list = \
- [None for i in range(mpu.get_data_parallel_world_size())]
- torch.distributed.all_gather_object(
- rng_state_list,
- rng_state,
- group=mpu.get_data_parallel_group())
- else:
- rng_state_list = [rng_state]
-
- return rng_state_list
-
-
-def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
- """Save a model checkpoint."""
- args = get_args()
-
- # Only rank zero of the data parallel writes to the disk.
- model = unwrap_model(model)
-
- print_rank_0('saving checkpoint at iteration {:7d} to {}'.format(
- iteration, args.save))
-
- # Collect rng state across data parallel ranks.
- rng_state = get_rng_state()
-
- # Checkpoint name.
- checkpoint_name = get_checkpoint_name(args.save, iteration)
-
- # Save distributed optimizer's custom parameter state.
- if args.use_distributed_optimizer and not args.no_save_optim and optimizer is not None:
- optim_checkpoint_name = \
- get_distributed_optimizer_checkpoint_name(checkpoint_name)
- ensure_directory_exists(optim_checkpoint_name)
- optimizer.save_parameter_state(optim_checkpoint_name)
-
- # Collect args, model, RNG.
- if not torch.distributed.is_initialized() \
- or mpu.get_data_modulo_expert_parallel_rank() == 0:
-
- # Arguments, iteration, and model.
- state_dict = {}
- state_dict['args'] = args
- state_dict['checkpoint_version'] = 3.0
- state_dict['iteration'] = iteration
- if len(model) == 1:
- state_dict['model'] = model[0].state_dict_for_save_checkpoint()
- else:
- for i in range(len(model)):
- mpu.set_virtual_pipeline_model_parallel_rank(i)
- state_dict['model%d' % i] = \
- model[i].state_dict_for_save_checkpoint()
-
- # Optimizer stuff.
- if not args.no_save_optim:
- if optimizer is not None:
- state_dict['optimizer'] = optimizer.state_dict()
- if opt_param_scheduler is not None:
- state_dict['opt_param_scheduler'] = \
- opt_param_scheduler.state_dict()
-
- # RNG states.
- if not args.no_save_rng:
- state_dict["rng_state"] = rng_state
-
- # Save.
- ensure_directory_exists(checkpoint_name)
- torch.save(state_dict, checkpoint_name)
-
- # Wait so everyone is done (necessary)
- if torch.distributed.is_initialized():
- torch.distributed.barrier()
-
- print_rank_0(' successfully saved checkpoint at iteration {:7d} to {}' \
- .format(iteration, args.save))
-
- # And update the latest iteration
- if not torch.distributed.is_initialized() \
- or torch.distributed.get_rank() == 0:
- tracker_filename = get_checkpoint_tracker_filename(args.save)
- with open(tracker_filename, 'w') as f:
- f.write(str(iteration))
-
- # Wait so everyone is done (not necessary)
- if torch.distributed.is_initialized():
- torch.distributed.barrier()
-
-
-def _transpose_first_dim(t, num_splits, num_splits_first, model):
- input_shape = t.size()
- # We use a self_attention module but the values extracted aren't
- # specific to self attention so should work for cross attention as well
- while hasattr(model, 'module'):
- model = model.module
- attention_module = model.language_model.encoder.layers[0].self_attention
- hidden_size_per_attention_head = attention_module.hidden_size_per_attention_head
- num_attention_heads_per_partition = attention_module.num_attention_heads_per_partition
- if num_splits_first:
- """[num_splits * np * hn, h]
- -->(view) [num_splits, np, hn, h]
- -->(transpose) [np, num_splits, hn, h]
- -->(view) [np * num_splits * hn, h] """
-
- intermediate_shape = \
- (num_splits, num_attention_heads_per_partition,
- hidden_size_per_attention_head) + input_shape[1:]
-
- t = t.view(*intermediate_shape)
- t = t.transpose(0, 1).contiguous()
- else:
- """[np * hn * num_splits, h]
- -->(view) [np, hn, num_splits, h]
- -->(transpose) [np, num_splits, hn, h]
- -->(view) [np * num_splits * hn, h] """
-
- intermediate_shape = \
- (num_attention_heads_per_partition,
- hidden_size_per_attention_head, num_splits) +\
- input_shape[1:]
-
- t = t.view(*intermediate_shape)
- t = t.transpose(1, 2).contiguous()
- t = t.view(*input_shape)
-
- return t
-
-
-def fix_query_key_value_ordering(model, checkpoint_version):
- """Fix up query/key/value matrix ordering if checkpoint
- version is smaller than 2.0
- """
- if checkpoint_version < 2.0:
- if isinstance(model, list):
- assert len(model)==1
- model = model[0]
- for name, param in model.named_parameters():
- if name.endswith(('.query_key_value.weight', '.query_key_value.bias')):
- if checkpoint_version == 0:
- fixed_param = _transpose_first_dim(param.data, 3, True, model)
- elif checkpoint_version == 1.0:
- fixed_param = _transpose_first_dim(param.data, 3, False, model)
- else:
- print_rank_0(f"Invalid checkpoint version {checkpoint_version}.")
- sys.exit()
- param.data.copy_(fixed_param)
- if name.endswith(('.key_value.weight', '.key_value.bias')):
- if checkpoint_version == 0:
- fixed_param = _transpose_first_dim(param.data, 2, True, model)
- elif checkpoint_version == 1.0:
- fixed_param = _transpose_first_dim(param.data, 2, False, model)
- else:
- print_rank_0(f"Invalid checkpoint version {checkpoint_version}.")
- sys.exit()
- param.data.copy_(fixed_param)
- print_rank_0(" succesfully fixed query-key-values ordering for"
- " checkpoint version {}".format(checkpoint_version))
-
-
-def _load_base_checkpoint(load_dir, rank0=False):
- """ Load the base state_dict from the given directory
-
- If rank0 is true, just loads rank 0 checkpoint, ignoring arguments.
- """
-
- # Read the tracker file and set the iteration.
- tracker_filename = get_checkpoint_tracker_filename(load_dir)
-
- # If no tracker file, return nothing
- if not os.path.isfile(tracker_filename):
- if not rank0:
- print_rank_0('WARNING: could not find the metadata file {} '.format(
- tracker_filename))
- print_rank_0(' will not load any checkpoints and will start from '
- 'random')
- return None, "", False
-
- # Otherwise, read the tracker file and either set the iteration or
- # mark it as a release checkpoint.
- iteration, release = read_metadata(tracker_filename)
-
- # Checkpoint.
- if rank0:
- checkpoint_name = find_checkpoint_rank_0(load_dir, iteration, release)
- else:
- checkpoint_name = get_checkpoint_name(load_dir, iteration, release)
- if release:
- print_rank_0(f' loading release checkpoint from {load_dir}')
- else:
- print_rank_0(f' loading checkpoint from {load_dir} at iteration {iteration}')
-
- # Load the checkpoint.
- try:
- state_dict = torch.load(checkpoint_name, map_location='cpu')
- except ModuleNotFoundError:
- from megatron.fp16_deprecated import loss_scaler
- # For backward compatibility.
- if not rank0:
- print_rank_0(' > deserializing using the old code structure ...')
- sys.modules['fp16.loss_scaler'] = sys.modules[
- 'megatron.fp16_deprecated.loss_scaler']
- sys.modules['megatron.fp16.loss_scaler'] = sys.modules[
- 'megatron.fp16_deprecated.loss_scaler']
- state_dict = torch.load(checkpoint_name, map_location='cpu')
- sys.modules.pop('fp16.loss_scaler', None)
- sys.modules.pop('megatron.fp16.loss_scaler', None)
- except BaseException as e:
- print_rank_0('could not load the checkpoint')
- print_rank_0(e)
- sys.exit()
-
- return state_dict, checkpoint_name, release
-
-
-def load_args_from_checkpoint(args, load_arg='load'):
- """Set required arguments from the checkpoint specified in the
- arguments.
-
- Will overwrite arguments that have a non-None default value, but
- will leave any arguments that default to None as set.
-
- Returns the same args NameSpace with the new values added/updated.
-
- If no checkpoint is specified in args, or if the checkpoint is
- there but invalid, the arguments will not be modified
-
- """
- load_dir = getattr(args, load_arg)
-
- if load_dir is None:
- print_rank_0('No load directory specified, using provided arguments.')
- return args
-
- state_dict, checkpoint_name, release = _load_base_checkpoint(load_dir, rank0=True)
-
- # Args.
- if not state_dict:
- print_rank_0('Checkpoint not found to provide arguments, using provided arguments.')
- return args
-
- if 'args' not in state_dict:
- print_rank_0('Checkpoint provided does not have arguments saved, using provided arguments.')
- return args
-
- checkpoint_args = state_dict['args']
- checkpoint_version = state_dict.get('checkpoint_version', 0)
- args.iteration = state_dict['iteration']
-
- # One-off conversion for foundation models
- if hasattr(checkpoint_args, 'disable_bias_linear'):
- setattr(checkpoint_args, 'add_bias_linear', not getattr(checkpoint_args, 'disable_bias_linear'))
-
- def _set_arg(arg_name, old_arg_name=None, force=False):
- if not force and getattr(args, arg_name, None) is not None:
- return
-
- if old_arg_name is not None:
- checkpoint_value = getattr(checkpoint_args, old_arg_name, None)
- else:
- checkpoint_value = getattr(checkpoint_args, arg_name, None)
-
- if checkpoint_value is not None:
- print_rank_0(f"Setting {arg_name} to {checkpoint_value} from checkpoint")
- setattr(args, arg_name, checkpoint_value)
- else:
- print_rank_0(f"Checkpoint did not provide arguments {arg_name}")
-
- _set_arg('num_layers')
- _set_arg('hidden_size')
- _set_arg('ffn_hidden_size')
- _set_arg('seq_length')
- _set_arg('num_attention_heads')
- _set_arg('num_query_groups', force=True)
- _set_arg('group_query_attention', force=True)
- _set_arg('kv_channels')
- _set_arg('max_position_embeddings')
- _set_arg('position_embedding_type', force=True)
- _set_arg('add_position_embedding', force=True)
- _set_arg('use_rotary_position_embeddings', force=True)
- _set_arg('rotary_percent', force=True)
- _set_arg('add_bias_linear', force=True)
- _set_arg('swiglu', force=True)
- _set_arg('untie_embeddings_and_output_weights', force=True)
- _set_arg('apply_layernorm_1p', force=True)
- _set_arg('normalization', force=True)
- _set_arg('tokenizer_type')
- _set_arg('padded_vocab_size')
- if checkpoint_version < 3.0:
- _set_arg('tensor_model_parallel_size',
- 'model_parallel_size')
- else:
- _set_arg('tensor_model_parallel_size', force=True)
- _set_arg('pipeline_model_parallel_size', force=True)
- _set_arg('virtual_pipeline_model_parallel_size', force=True)
- _set_arg('num_layers_per_virtual_pipeline_stage')
- return args, checkpoint_args
-
-
-def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', strict=True):
- """Load a model checkpoint and return the iteration.
- strict (bool): whether to strictly enforce that the keys in
- :attr:`state_dict` of the checkpoint match the names of
- parameters and buffers in model.
- """
- args = get_args()
- load_dir = getattr(args, load_arg)
-
- model = unwrap_model(model)
-
- state_dict, checkpoint_name, release = _load_base_checkpoint(load_dir, rank0=False)
-
- # Checkpoint not loaded.
- if state_dict is None:
-
- # Conditionally exit at this point.
- if args.exit_on_missing_checkpoint:
- print_rank_0(">> '--exit-on-missing-checkpoint' set ... exiting. <<")
- torch.distributed.barrier()
- sys.exit()
-
- # Iteration defaults to 0.
- return 0
-
- # Set checkpoint version.
- set_checkpoint_version(state_dict.get('checkpoint_version', 0))
-
- # Set iteration.
- if args.finetune or release:
- iteration = 0
- else:
- try:
- iteration = state_dict['iteration']
- except KeyError:
- try: # Backward compatible with older checkpoints
- iteration = state_dict['total_iters']
- except KeyError:
- print_rank_0('A metadata file exists but unable to load '
- 'iteration from checkpoint {}, exiting'.format(
- checkpoint_name))
- sys.exit()
-
- # Check arguments.
- assert args.consumed_train_samples == 0
- assert args.consumed_valid_samples == 0
- if 'args' in state_dict and not args.finetune:
- checkpoint_args = state_dict['args']
- check_checkpoint_args(checkpoint_args)
- args.consumed_train_samples = getattr(checkpoint_args,
- 'consumed_train_samples', 0)
- update_num_microbatches(consumed_samples=args.consumed_train_samples)
- args.consumed_valid_samples = getattr(checkpoint_args,
- 'consumed_valid_samples', 0)
- else:
- print_rank_0('could not find arguments in the checkpoint ...')
-
- # Model.
- if len(model) == 1:
- model[0].load_state_dict(state_dict['model'], strict=strict)
- else:
- for i in range(len(model)):
- mpu.set_virtual_pipeline_model_parallel_rank(i)
- model[i].load_state_dict(state_dict['model%d' % i], strict=strict)
-
- # Fix up query/key/value matrix ordering if needed.
- checkpoint_version = get_checkpoint_version()
- print_rank_0(f' checkpoint version {checkpoint_version}')
- fix_query_key_value_ordering(model, checkpoint_version)
-
- # Optimizer.
- if not release and not args.finetune and not args.no_load_optim:
- try:
- # Load state dict.
- if optimizer is not None:
- optimizer.load_state_dict(state_dict['optimizer'])
-
- # Load distributed optimizer's custom parameter state.
- if args.use_distributed_optimizer:
- tracker_filename = get_checkpoint_tracker_filename(load_dir)
- iteration, release = read_metadata(tracker_filename)
- model_checkpoint_name = \
- get_checkpoint_name(load_dir, iteration, release)
- optim_checkpoint_name = \
- get_distributed_optimizer_checkpoint_name(
- model_checkpoint_name)
- optimizer.load_parameter_state(optim_checkpoint_name)
-
- # Load scheduler.
- if opt_param_scheduler is not None:
- if 'lr_scheduler' in state_dict: # backward compatibility
- opt_param_scheduler.load_state_dict(state_dict['lr_scheduler'])
- else:
- opt_param_scheduler.load_state_dict(state_dict['opt_param_scheduler'])
- except KeyError:
- print_rank_0('Unable to load optimizer from checkpoint {}. '
- 'Specify --no-load-optim or --finetune to prevent '
- 'attempting to load the optimizer state, '
- 'exiting ...'.format(checkpoint_name))
- sys.exit()
- else:
- if (args.fp16 or args.bf16) and optimizer is not None:
- optimizer.reload_model_params()
-
- # rng states.
- if not release and not args.finetune and not args.no_load_rng:
- try:
- if 'rng_state' in state_dict:
- # access rng_state for data parallel rank
- if args.data_parallel_random_init:
- rng_state = state_dict['rng_state'][mpu.get_data_parallel_rank()]
- else:
- rng_state = state_dict['rng_state'][0]
- random.setstate(rng_state['random_rng_state'])
- np.random.set_state(rng_state['np_rng_state'])
- torch.set_rng_state(rng_state['torch_rng_state'])
- torch.cuda.set_rng_state(rng_state['cuda_rng_state'])
- # Check for empty states array
- if not rng_state['rng_tracker_states']:
- raise KeyError
- tensor_parallel.get_cuda_rng_tracker().set_states(
- rng_state['rng_tracker_states'])
- else: # backward compatibility
- random.setstate(state_dict['random_rng_state'])
- np.random.set_state(state_dict['np_rng_state'])
- torch.set_rng_state(state_dict['torch_rng_state'])
- torch.cuda.set_rng_state(state_dict['cuda_rng_state'])
- # Check for empty states array
- if not state_dict['rng_tracker_states']:
- raise KeyError
- tensor_parallel.get_cuda_rng_tracker().set_states(
- state_dict['rng_tracker_states'])
- except KeyError:
- print_rank_0('Unable to load rng state from checkpoint {}. '
- 'Specify --no-load-rng or --finetune to prevent '
- 'attempting to load the rng state, '
- 'exiting ...'.format(checkpoint_name))
- sys.exit()
-
- # Some utilities want to load a checkpoint without distributed being initialized
- if torch.distributed.is_initialized():
- torch.distributed.barrier()
-
- print_rank_0(f' successfully loaded checkpoint from {args.load} '
- f'at iteration {iteration}')
-
- return iteration
-
-
-def load_biencoder_checkpoint(model, only_query_model=False,
- only_context_model=False, custom_load_path=None):
- """
- selectively load retrieval models for indexing/retrieving
- from saved checkpoints
- """
-
- args = get_args()
-
- model = unwrap_model(model)
-
- load_path = custom_load_path if custom_load_path is not None else args.load
-
- tracker_filename = get_checkpoint_tracker_filename(load_path)
- with open(tracker_filename, 'r') as f:
- iteration = int(f.read().strip())
-
- checkpoint_name = get_checkpoint_name(load_path, iteration,
- args.use_distributed_optimizer,
- release=False)
-
- if mpu.get_data_parallel_rank() == 0:
- print('global rank {} is loading checkpoint {}'.format(
- torch.distributed.get_rank(), checkpoint_name))
-
- state_dict = torch.load(checkpoint_name, map_location='cpu')
- ret_state_dict = state_dict['model']
-
- if only_query_model:
- ret_state_dict.pop('context_model')
- if only_context_model:
- ret_state_dict.pop('query_model')
-
- assert len(model) == 1
- model[0].load_state_dict(ret_state_dict)
- torch.distributed.barrier()
-
- if mpu.get_data_parallel_rank() == 0:
- print(' successfully loaded {}'.format(checkpoint_name))
-
- return model
diff --git a/megatron/core/README.md b/megatron/core/README.md
deleted file mode 100644
index 0c8c61738da2f6526d065eb600e33d23187bcd0c..0000000000000000000000000000000000000000
--- a/megatron/core/README.md
+++ /dev/null
@@ -1 +0,0 @@
-Megatron Core is a library for efficient and scalable training of transformer-based models.
diff --git a/megatron/core/__init__.py b/megatron/core/__init__.py
deleted file mode 100644
index 2858dc692dff393ba9e01dc34cf75e6c74834346..0000000000000000000000000000000000000000
--- a/megatron/core/__init__.py
+++ /dev/null
@@ -1,18 +0,0 @@
-import megatron.core.tensor_parallel
-import megatron.core.utils
-from megatron.core import parallel_state
-from megatron.core.distributed import DistributedDataParallel
-from megatron.core.inference_params import InferenceParams
-from megatron.core.model_parallel_config import ModelParallelConfig
-
-# Alias parallel_state as mpu, its legacy name
-mpu = parallel_state
-
-__all__ = [
- "parallel_state",
- "tensor_parallel",
- "utils",
- "DistributedDataParallel",
- "InferenceParams",
- "ModelParallelConfig",
-]
diff --git a/megatron/core/datasets/Makefile b/megatron/core/datasets/Makefile
deleted file mode 100644
index 8f9db7686696fbea6c94b998db4b40ef426c748d..0000000000000000000000000000000000000000
--- a/megatron/core/datasets/Makefile
+++ /dev/null
@@ -1,9 +0,0 @@
-CXXFLAGS += -O3 -Wall -shared -std=c++11 -fPIC -fdiagnostics-color
-CPPFLAGS += $(shell python3 -m pybind11 --includes)
-LIBNAME = helpers
-LIBEXT = $(shell python3-config --extension-suffix)
-
-default: $(LIBNAME)$(LIBEXT)
-
-%$(LIBEXT): %.cpp
- $(CXX) $(CXXFLAGS) $(CPPFLAGS) $< -o $@
diff --git a/megatron/core/datasets/__init__.py b/megatron/core/datasets/__init__.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/megatron/core/datasets/blended_dataset.py b/megatron/core/datasets/blended_dataset.py
deleted file mode 100644
index 89f3bbc9e5e25b301ef9559f4ff1dc1898faf61c..0000000000000000000000000000000000000000
--- a/megatron/core/datasets/blended_dataset.py
+++ /dev/null
@@ -1,190 +0,0 @@
-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-
-import hashlib
-import json
-import logging
-import os
-import time
-from collections import OrderedDict
-from typing import Dict, List, Tuple, Union
-
-import numpy
-import torch
-
-from megatron.core.datasets.blended_megatron_dataset_config import BlendedMegatronDatasetConfig
-from megatron.core.datasets.megatron_dataset import MegatronDataset
-from megatron.core.datasets.utils import log_single_rank, normalize
-
-logger = logging.getLogger(__name__)
-
-_VERBOSE = False
-
-
-class BlendedDataset(torch.utils.data.Dataset):
- """Conjugating class for a set of MegatronDataset instances
-
- Args:
- datasets (List[MegatronDataset]): The MegatronDataset instances to blend
-
-        weights (List[float]): The weights which determine the dataset blend ratios
-
- size (int): The number of samples to draw from the blend
-
- config (BlendedMegatronDatasetConfig): The config object which informs dataset creation
-
- Raises:
- RuntimeError: When the dataset has fewer or more samples than 'size' post-initialization
- """
-
- def __init__(
- self,
- datasets: List[MegatronDataset],
- weights: List[float],
- size: int,
- config: BlendedMegatronDatasetConfig,
- ) -> None:
- assert len(datasets) < 32767
- assert len(datasets) == len(weights)
- assert numpy.isclose(sum(weights), 1.0)
- assert all(map(lambda _: type(_) == type(datasets[0]), datasets))
-
- # Alert user to unnecessary blending
- if len(datasets) == 1:
- log_single_rank(
- logger, logging.WARNING, f"Building a BlendedDataset for a single MegatronDataset"
- )
-
- # Redundant normalization for bitwise identical comparison with Megatron-LM
- weights = normalize(weights)
-
- self.datasets = datasets
- self.weights = weights
- self.size = size
- self.config = config
-
- unique_identifiers = OrderedDict()
- unique_identifiers["class"] = type(self).__name__
- unique_identifiers["datasets"] = [dataset.unique_identifiers for dataset in self.datasets]
- unique_identifiers["weights"] = self.weights
- unique_identifiers["size"] = self.size
-
- self.unique_description = json.dumps(unique_identifiers, indent=4)
- self.unique_description_hash = hashlib.md5(
- self.unique_description.encode("utf-8")
- ).hexdigest()
-
- self.dataset_index, self.dataset_sample_index = self._build_indices()
-
- # Check size
- _ = self[self.size - 1]
- try:
- _ = self[self.size]
- raise RuntimeError(f"{type(self).__name__} size is improperly bounded")
- except IndexError:
- log_single_rank(logger, logging.INFO, f"> {type(self).__name__} length: {len(self)}")
-
- def __len__(self) -> int:
- return self.size
-
- def __getitem__(self, idx: int) -> Dict[str, Union[int, numpy.ndarray]]:
- dataset_id = self.dataset_index[idx]
- dataset_sample_id = self.dataset_sample_index[idx]
- return {
- "dataset_id": dataset_id,
- **self.datasets[dataset_id][dataset_sample_id],
- }
-
- def _build_indices(self) -> Tuple[numpy.ndarray, numpy.ndarray]:
- """Build and optionally cache the dataset index and the dataset sample index
-
- The dataset index is a 1-D mapping which determines the dataset to query. The dataset
- sample index is a 1-D mapping which determines the sample to request from the queried
- dataset.
-
- Returns:
- Tuple[numpy.ndarray, numpy.ndarray]: The dataset index and the dataset sample index
- """
- path_to_cache = getattr(self.config, "path_to_cache")
-
- if path_to_cache:
- get_path_to = lambda suffix: os.path.join(
- path_to_cache, f"{self.unique_description_hash}-{type(self).__name__}-{suffix}"
- )
- path_to_description = get_path_to("description.txt")
- path_to_dataset_index = get_path_to("dataset_index.npy")
- path_to_dataset_sample_index = get_path_to("dataset_sample_index.npy")
- cache_hit = all(
- map(
- os.path.isfile,
- [path_to_description, path_to_dataset_index, path_to_dataset_sample_index],
- )
- )
- else:
- cache_hit = False
-
- if not path_to_cache or (not cache_hit and torch.distributed.get_rank() == 0):
- log_single_rank(
- logger, logging.INFO, f"Build and save the {type(self).__name__} indices",
- )
-
- # Build the dataset and dataset sample indexes
- log_single_rank(
- logger, logging.INFO, f"\tBuild and save the dataset and dataset sample indexes"
- )
- t_beg = time.time()
- from megatron.core.datasets import helpers
-
- dataset_index = numpy.zeros(self.size, dtype=numpy.int16)
- dataset_sample_index = numpy.zeros(self.size, dtype=numpy.int64)
- helpers.build_blending_indices(
- dataset_index,
- dataset_sample_index,
- self.weights,
- len(self.datasets),
- self.size,
- _VERBOSE,
- )
-
- if path_to_cache:
- os.makedirs(path_to_cache, exist_ok=True)
- # Write the description
- with open(path_to_description, "wt") as writer:
- writer.write(self.unique_description)
- # Save the indexes
- numpy.save(path_to_dataset_index, dataset_index, allow_pickle=True)
- numpy.save(path_to_dataset_sample_index, dataset_sample_index, allow_pickle=True)
- else:
- log_single_rank(
- logger,
- logging.WARNING,
- "Unable to save the indexes because path_to_cache is None",
- )
-
- t_end = time.time()
- log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds")
-
- return dataset_index, dataset_sample_index
-
- log_single_rank(logger, logging.INFO, f"Load the {type(self).__name__} indices")
-
- log_single_rank(
- logger, logging.INFO, f"\tLoad the dataset index from {path_to_dataset_index}"
- )
- t_beg = time.time()
- dataset_index = numpy.load(path_to_dataset_index, allow_pickle=True, mmap_mode='r')
- t_end = time.time()
- log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds")
-
- log_single_rank(
- logger,
- logging.INFO,
- f"\tLoad the dataset sample index from {path_to_dataset_sample_index}",
- )
- t_beg = time.time()
- dataset_sample_index = numpy.load(
- path_to_dataset_sample_index, allow_pickle=True, mmap_mode='r'
- )
- t_end = time.time()
- log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds")
-
- return dataset_index, dataset_sample_index
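
`_build_indices` above defers to the compiled `helpers.build_blending_indices`. For reference, here is a pure-Python sketch of that greedy scheme, as read from the C++ helper deleted further down in this diff: at each step, pick the dataset whose realized sample count lags its weighted target the most. Names are illustrative.

```python
import numpy as np

def build_blending_indices_py(weights, size):
    """Pure-Python sketch of the greedy blending done by the compiled helper."""
    dataset_index = np.zeros(size, dtype=np.int16)
    dataset_sample_index = np.zeros(size, dtype=np.int64)
    current_samples = [0] * len(weights)
    for sample_idx in range(size):
        denom = max(float(sample_idx), 1.0)
        # Pick the dataset furthest behind its target share so far.
        errors = [w * denom - c for w, c in zip(weights, current_samples)]
        chosen = int(np.argmax(errors))
        dataset_index[sample_idx] = chosen
        dataset_sample_index[sample_idx] = current_samples[chosen]
        current_samples[chosen] += 1
    return dataset_index, dataset_sample_index

idx, sub_idx = build_blending_indices_py([0.3, 0.7], size=10)
# idx interleaves datasets 0 and 1 at roughly a 3:7 ratio
```
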
diff --git a/megatron/core/datasets/blended_megatron_dataset_builder.py b/megatron/core/datasets/blended_megatron_dataset_builder.py
deleted file mode 100644
index 3dee4e469616efe867b72c3b95e8f4ef4cf18720..0000000000000000000000000000000000000000
--- a/megatron/core/datasets/blended_megatron_dataset_builder.py
+++ /dev/null
@@ -1,328 +0,0 @@
-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-
-import logging
-import math
-from typing import Any, List, Optional, Tuple, Type, Union
-
-import numpy
-import torch
-
-from megatron.core.datasets.blended_dataset import BlendedDataset
-from megatron.core.datasets.blended_megatron_dataset_config import BlendedMegatronDatasetConfig
-from megatron.core.datasets.indexed_dataset import MMapIndexedDataset
-from megatron.core.datasets.megatron_dataset import MegatronDataset
-from megatron.core.datasets.utils import Split, normalize
-
-logger = logging.getLogger(__name__)
-
-DistributedDataset = Union[BlendedDataset, MegatronDataset, MMapIndexedDataset]
-
-
-class BlendedMegatronDatasetBuilder(object):
- """Builder class for the BlendedDataset and MegatronDataset classes
-
- Args:
- cls (Type[MegatronDataset]): The class to instantiate, must inherit from MegatronDataset
-
- sizes (List[int]): The minimum number of total samples to draw from each split, varies
- with blend
-
- config (BlendedMegatronDatasetConfig): The config object which informs dataset creation
- """
-
- def __init__(
- self, cls: Type[MegatronDataset], sizes: List[int], config: BlendedMegatronDatasetConfig,
- ):
- self.cls = cls
- self.sizes = sizes
- self.config = config
-
- def build(self) -> List[Optional[Union[BlendedDataset, MegatronDataset]]]:
- """Build all dataset splits according to the provided blend(s)
-
- This method is distributed-aware and must be called on all ranks.
-
- The dataset splits returned can vary according to the config. Supply config.blend and
- config.split to build BlendedDataset and/or MegatronDataset splits from the same
- distribution. Supply config.blend_per_split to build BlendedDataset and/or MegatronDataset
- splits from separate distributions.
-
- Returns:
- List[Optional[Union[BlendedDataset, MegatronDataset]]]: A list of either
- MegatronDataset or BlendedDataset (or None) per split
- """
- return self._build_blended_dataset_splits()
-
- def _build_blended_dataset_splits(
- self,
- ) -> List[Optional[Union[BlendedDataset, MegatronDataset]]]:
- """Build all dataset splits according to the provided blend(s)
-
- See the BlendedMegatronDatasetBuilder.build alias for more information.
-
- Returns:
- List[Optional[Union[BlendedDataset, MegatronDataset]]]: A list of either
- MegatronDataset or BlendedDataset (or None) per split
- """
-
- if getattr(self.config, "blend"):
- blend = getattr(self.config, "blend")
- split = getattr(self.config, "split_vector")
-
- # Blend consists of a single prefix
- if len(blend) == 1:
- return self._build_megatron_dataset_splits(blend[0], split, self.sizes)
-
- # Blend consists of multiple weights and prefixes
- (
- prefix_per_dataset,
- weight_per_dataset,
- sizes_per_dataset,
- ) = _get_prefixes_weights_and_sizes_for_blend(blend, self.sizes)
-
- megatron_datasets = [[] for _ in range(len(Split))]
-
- for i in range(len(prefix_per_dataset)):
- megatron_datasets_split = self._build_megatron_dataset_splits(
- prefix_per_dataset[i], split, sizes_per_dataset[i]
- )
- for j in range(len(megatron_datasets_split)):
- megatron_datasets[j].append(megatron_datasets_split[j])
-
- # Sum over all contributing datasets, per split
- size_per_split = list(map(sum, zip(*sizes_per_dataset)))
-
- blended_datasets = []
-
- for i in range(len(megatron_datasets)):
- is_none = map(lambda _: _ is None, megatron_datasets[i])
-
- if split[i] == 0.0:
- assert all(is_none)
- blended_datasets.append(None)
- else:
- assert all(is_none) or not any(is_none)
- blended_datasets.append(
- self._build_generic_dataset(
- BlendedDataset,
- megatron_datasets[i],
- weight_per_dataset,
- size_per_split[i],
- self.config,
- )
- )
-
- return blended_datasets
-
- else:
- blended_datasets = []
- for i in range(len(Split)):
- blend = getattr(self.config, "blend_per_split")[i]
-
- # Blend is not provided
- if not blend:
- blended_datasets.append(None)
- continue
-
- split_spoof = [0.0] * len(Split)
- split_spoof[i] = 1.0
- sizes_spoof = [0] * len(Split)
- sizes_spoof[i] = self.sizes[i]
-
-                # Blend consists of a single prefix
- if len(blend) == 1:
- blended_datasets.append(
- self._build_megatron_dataset_splits(blend[0], split_spoof, sizes_spoof)[i]
- )
-
- # Blend consists of multiple weights and prefixes
- else:
- (
- prefix_per_dataset,
- weight_per_dataset,
- sizes_per_dataset,
- ) = _get_prefixes_weights_and_sizes_for_blend(blend, sizes_spoof)
-
- megatron_datasets = []
- for j in range(len(prefix_per_dataset)):
- megatron_datasets.append(
- self._build_megatron_dataset_splits(
- prefix_per_dataset[j], split_spoof, sizes_per_dataset[j],
- )[i]
- )
-
- size_per_split = list(map(sum, zip(*sizes_per_dataset)))
-
- blended_datasets.append(
- self._build_generic_dataset(
- BlendedDataset,
- megatron_datasets,
- weight_per_dataset,
- size_per_split[i],
- self.config,
- )
- )
-
- return blended_datasets
-
- def _build_megatron_dataset_splits(
- self, path_prefix: str, split: List[float], sizes: List[int],
- ) -> List[Optional[MegatronDataset]]:
- """Build each MegatronDataset split from a single MMapIndexedDataset
-
- Args:
- path_prefix (str): The MMapIndexedDataset .bin and .idx file prefix
-
- split (List[float]): The dataset split ratios (must sum to 1.00)
-
- sizes (List[int]): The number of total samples to draw from each split
-
- Returns:
-            List[Optional[MegatronDataset]]: The MegatronDataset (or None) per split
- """
- indexed_dataset = self._build_generic_dataset(
- MMapIndexedDataset, path_prefix, self.cls.is_multimodal()
- )
-
- if indexed_dataset is not None:
- if self.cls.is_split_by_sequence():
- split_idx_bounds = _get_split_indices(
- split, indexed_dataset.sequence_lengths.shape[0]
- )
- else:
- split_idx_bounds = _get_split_indices(
- split, indexed_dataset.document_indices.shape[0] - 1
- )
- split_indices = [
- numpy.arange(
- start=split_idx_bounds[i],
- stop=split_idx_bounds[i + 1],
- step=1,
- dtype=numpy.int32,
- )
- for i, _ in enumerate(Split)
- ]
- else:
- split_indices = [None for _ in Split]
-
- megatron_datasets = []
- for i, _split in enumerate(Split):
- if split[i] == 0.0:
- megatron_datasets.append(None)
- else:
- megatron_datasets.append(
- self._build_generic_dataset(
- self.cls, indexed_dataset, split_indices[i], sizes[i], _split, self.config
- )
- )
-
- return megatron_datasets
-
- def _build_generic_dataset(
- self, cls: Type[DistributedDataset], *args: Any,
- ) -> Optional[DistributedDataset]:
- """Build the DistributedDataset
-
- Return None if and only if the underlying MegatronDataset class is not built on the current
- rank and torch.distributed is initialized.
-
- Args:
- cls (Type[DistributedDataset]): The DistributedDataset class to be built
-
- args (Tuple[Any]): The positional arguments used to build the provided
- DistributedDataset class
-
- Raises:
- Exception: When the dataset constructor raises an OSError
-
- Returns:
-            Optional[DistributedDataset]: The DistributedDataset instantiation or None
- """
- if torch.distributed.is_initialized():
- rank = torch.distributed.get_rank()
-
- dataset = None
-
- # First, build on rank 0
- if rank == 0 and getattr(self.config, "is_built_on_rank")():
- try:
- dataset = cls(*args)
- except OSError as err:
- log = (
- f"Failed to write dataset materials to the data cache directory. "
- + f"Please supply a directory to which you have write access via "
- + f"the path_to_cache attribute in BlendedMegatronDatasetConfig and "
- + f"retry. Refer to the preserved traceback above for more information."
- )
- raise Exception(log) from err
-
- torch.distributed.barrier()
-
- # After, build on other ranks
- if rank != 0 and getattr(self.config, "is_built_on_rank")():
- dataset = cls(*args)
-
- return dataset
-
- return cls(*args)
-
-
-def _get_split_indices(split: List[float], num_elements: int) -> List[int]:
- """Determine the document index bounds per split
-
- Args:
- split (List[float]): The dataset split ratios (must sum to 1.00)
-
- num_elements (int): The number of elements, e.g. sequences or documents, available for
- the split
-
- Returns:
- List[int]: The indices for all three splits e.g. [0, 900, 990, 1000] for a 1000-document
- set and a [90.0, 9.0, 1.0] split
- """
- split_indices = [0]
- for split_pct in split:
- split_indices.append(split_indices[-1] + int(round(split_pct * float(num_elements))))
- split_indices[1:] = list(
- map(lambda _: _ - (split_indices[-1] - num_elements), split_indices[1:])
- )
-
- assert len(split_indices) == len(split) + 1
- assert split_indices[-1] == num_elements
-
- return split_indices
-
-
-def _get_prefixes_weights_and_sizes_for_blend(
- blend: List[str], target_num_samples_per_split: List[int]
-) -> Tuple[List[str], List[float], List[List[int]]]:
- """Determine the contribution of the MegatronDataset splits to the BlendedDataset splits
-
- Args:
- blend (List[str]): e.g. ["30", "path/to/dataset_1_prefix", "70",
- "path/to/dataset_2_prefix"]
-
- target_num_samples_per_split (List[int]): The number of samples to target for each
- BlendedDataset split
-
- Returns:
- Tuple[List[str], List[float], List[List[int]]]: The prefix strings e.g.
- ["path/to/dataset_1_prefix", "path/to/dataset_2_prefix"], the normalized weights e.g.
- [0.3, 0.7], and the number of samples to request per MegatronDataset per split
- """
- weights, prefixes = zip(
- *[(float(blend[i]), blend[i + 1].strip()) for i in range(0, len(blend), 2)]
- )
-
- weights = normalize(weights)
-
- # Use 0.5% target margin to ensure we satiate the network
- sizes_per_dataset = [
- [
- int(math.ceil(target_num_samples * weight * 1.005))
- for target_num_samples in target_num_samples_per_split
- ]
- for weight in weights
- ]
-
- return prefixes, weights, sizes_per_dataset
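
A short worked trace of `_get_prefixes_weights_and_sizes_for_blend` above, with illustrative paths and sample targets; the 0.5% head-room factor matches the comment in the deleted code.

```python
import math

blend = ["30", "path/to/ds1", "70", "path/to/ds2"]   # illustrative
targets = [1000, 100, 10]                             # train/valid/test samples

weights, prefixes = zip(*[(float(blend[i]), blend[i + 1].strip())
                          for i in range(0, len(blend), 2)])
total = sum(weights)
weights = [w / total for w in weights]                # -> [0.3, 0.7]
sizes_per_dataset = [[int(math.ceil(t * w * 1.005))   # 0.5% head-room
                      for t in targets]
                     for w in weights]
# prefixes          == ('path/to/ds1', 'path/to/ds2')
# sizes_per_dataset == [[302, 31, 4], [704, 71, 8]]
```
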
diff --git a/megatron/core/datasets/blended_megatron_dataset_config.py b/megatron/core/datasets/blended_megatron_dataset_config.py
deleted file mode 100644
index b7e242a4be1ec7e62d0ec89cd00ea1b35d32042b..0000000000000000000000000000000000000000
--- a/megatron/core/datasets/blended_megatron_dataset_config.py
+++ /dev/null
@@ -1,119 +0,0 @@
-# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
-
-import logging
-import re
-from dataclasses import dataclass, field
-from typing import Callable, List, Optional
-
-import torch
-
-from megatron.core.datasets.utils import Split, log_single_rank, normalize
-from megatron.core.parallel_state import get_virtual_pipeline_model_parallel_rank
-
-logger = logging.getLogger(__name__)
-
-
-@dataclass
-class BlendedMegatronDatasetConfig:
- """Configuration object for megatron-core blended and megatron datasets
-
- Attributes:
- is_built_on_rank (Callable): A callable which returns True if the dataset should be built
- on the current rank. It should be Megatron Core parallelism aware i.e. global rank, group
- rank, and virtual rank may inform its return value.
-
- random_seed (int): The seed for all RNG during dataset creation.
-
- sequence_length (int): The sequence length.
-
- blend (Optional[List[str]]): The blend string, consisting of either a single dataset or a
-        flattened sequential sequence of weight-dataset pairs. For example, ["dataset-path1"] and
- ["50", "dataset-path1", "50", "dataset-path2"] are both valid. Not to be used with
- 'blend_per_split'. Defaults to None.
-
-        blend_per_split (Optional[List[Optional[List[str]]]]): A set of blend
-        strings, as defined above, one for each split distribution. Not to be used with 'blend'.
-        Defaults to None.
-
- split (Optional[str]): The split string, a comma separated weighting for the dataset splits
- when drawing samples from a single distribution. Not to be used with 'blend_per_split'.
- Defaults to None.
-
-        split_vector (Optional[List[float]]): The split string, parsed and normalized post-
- initialization. Not to be passed to the constructor.
-
-        path_to_cache (str): Where all reusable dataset indices are to be cached.
- """
-
- is_built_on_rank: Callable
-
- random_seed: int
-
- sequence_length: int
-
- blend: Optional[List[str]] = None
-
- blend_per_split: Optional[List[Optional[List[str]]]] = None
-
- split: Optional[str] = None
-
- split_vector: Optional[List[float]] = field(init=False, default=None)
-
- path_to_cache: str = None
-
- def __post_init__(self):
- """Python dataclass method that is used to modify attributes after initialization. See
- https://docs.python.org/3/library/dataclasses.html#post-init-processing for more details.
- """
- if torch.distributed.is_initialized():
- gb_rank = torch.distributed.get_rank()
- vp_rank = get_virtual_pipeline_model_parallel_rank()
- if gb_rank == 0 and (vp_rank == 0 or vp_rank is None):
- assert (
- self.is_built_on_rank()
- ), "is_built_on_rank must return True when global rank = 0 and vp rank = 0"
-
- if self.blend_per_split is not None and any(self.blend_per_split):
- assert self.blend is None, "blend and blend_per_split are incompatible"
- assert len(self.blend_per_split) == len(
- Split
- ), f"blend_per_split must contain {len(Split)} blends"
- if self.split is not None:
- self.split = None
- log_single_rank(logger, logging.WARNING, f"Let split = {self.split}")
- else:
- assert self.blend is not None, "one of either blend or blend_per_split must be provided"
- assert self.split is not None, "both blend and split must be provided"
- self.split_vector = _parse_and_normalize_split(self.split)
- log_single_rank(logger, logging.INFO, f"Let split_vector = {self.split_vector}")
-
-
-@dataclass
-class GPTDatasetConfig(BlendedMegatronDatasetConfig):
- """Configuration object for megatron-core blended and megatron GPT datasets
-
- Attributes:
- return_document_ids (bool): Whether to return the document ids when querying the dataset.
- """
-
- return_document_ids: bool = False
-
-
-def _parse_and_normalize_split(split: str) -> List[float]:
- """Parse the dataset split ratios from a string
-
- Args:
- split (str): The train valid test split string e.g. "99,1,0"
-
- Returns:
-        List[float]: The train valid test split ratios e.g. [99.0, 1.0, 0.0]
- """
- split = list(map(float, re.findall(r"[.0-9]+", split)))
- split = split + [0.0 for _ in range(len(Split) - len(split))]
-
- assert len(split) == len(Split)
- assert all(map(lambda _: _ >= 0.0, split))
-
- split = normalize(split)
-
- return split
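
For reference, a self-contained trace of the split parsing removed above (regex extraction, zero-padding to one entry per `Split`, then normalization); values are illustrative.

```python
import re

def parse_and_normalize_split(split_str, num_splits=3):
    vals = list(map(float, re.findall(r"[.0-9]+", split_str)))
    vals += [0.0] * (num_splits - len(vals))   # pad to train/valid/test
    total = sum(vals)
    return [v / total for v in vals]

print(parse_and_normalize_split("99,1,0"))    # [0.99, 0.01, 0.0]
print(parse_and_normalize_split("949,50,1"))  # [0.949, 0.05, 0.001]
```
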
diff --git a/megatron/core/datasets/gpt_dataset.py b/megatron/core/datasets/gpt_dataset.py
deleted file mode 100644
index 1004e649a297853a4c799ea99ec0e6819f3c9fdf..0000000000000000000000000000000000000000
--- a/megatron/core/datasets/gpt_dataset.py
+++ /dev/null
@@ -1,460 +0,0 @@
-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-
-import logging
-import os
-import time
-from typing import Dict, Tuple
-
-import numpy
-import torch
-
-from megatron.core.datasets.blended_megatron_dataset_config import GPTDatasetConfig
-from megatron.core.datasets.indexed_dataset import MMapIndexedDataset
-from megatron.core.datasets.megatron_dataset import MegatronDataset
-from megatron.core.datasets.utils import Split, log_single_rank
-
-logger = logging.getLogger(__name__)
-
-
-class GPTDataset(MegatronDataset):
- """The base GPT dataset
-
- Args:
- indexed_dataset (MMapIndexedDataset): The MMapIndexedDataset around which to build the
- MegatronDataset
-
- indexed_indices (numpy.ndarray): The set of the documents indices to expose
-
- num_samples (int): The number of samples to draw from the indexed dataset
-
- index_split (Split): The indexed_indices Split
-
- config (GPTDatasetConfig): The GPT-specific container for all config sourced parameters
- """
-
- def __init__(
- self,
- indexed_dataset: MMapIndexedDataset,
- indexed_indices: numpy.ndarray,
- num_samples: int,
- index_split: Split,
- config: GPTDatasetConfig,
- ) -> None:
- super().__init__(indexed_dataset, indexed_indices, num_samples, index_split, config)
-
- def _finalize(self) -> None:
- """Abstract method implementation
-
- Load or build/cache the document, sample, and shuffle indices
- """
- assert isinstance(self.config, GPTDatasetConfig)
-
- (
- self.document_index,
- self.sample_index,
- self.shuffle_index,
- ) = self._build_document_sample_shuffle_indices()
-
- def __len__(self) -> int:
- """Abstract method implementation
-
- Returns:
- int: The length of the dataset
- """
- return self.sample_index.shape[0] - 1
-
- def __getitem__(self, idx: int) -> Dict[str, numpy.ndarray]:
- """Abstract method implementation
-
- Args:
- idx (int): The index into the dataset
-
- Returns:
- Dict[str, numpy.ndarray]: The text ids and (optionally) the document ids wrapped in a
- dictionary
- """
- text, document_ids = self._query_document_sample_shuffle_indices(idx)
- if getattr(self.config, "return_document_ids"):
- return {"text": text, "document_ids": document_ids}
- else:
- return {"text": text}
-
- @staticmethod
- def is_multimodal() -> bool:
- """Abstract method implementation
-
- Returns:
- bool: False
- """
- return False
-
- @staticmethod
- def is_split_by_sequence() -> bool:
- """Abstract method implementation
-
- Returns:
- bool: True
- """
- return True
-
- def _query_document_sample_shuffle_indices(
- self, idx: int
- ) -> Tuple[numpy.ndarray, numpy.ndarray]:
- """Get the text (token ids) and document ids for a given index
-
- Args:
- idx (int): The index into the dataset
-
- Returns:
- Tuple[numpy.ndarray, numpy.ndarray]: The text ids and document ids
- """
- # Do the shuffle mapping
- idx = self.shuffle_index[idx]
-
- # Get the beginning and end documents and offsets
- doc_index_beg, doc_index_beg_offset = self.sample_index[idx]
- doc_index_end, doc_index_end_offset = self.sample_index[idx + 1]
-
- document_ids = []
- sample_parts = []
-
- # Sample spans a single document
- if doc_index_beg == doc_index_end:
- # Add the document id
- document_ids.append(self.document_index[doc_index_beg])
-
- # Add the entire sample
- sample_parts.append(
- self.indexed_dataset.get(
- self.document_index[doc_index_beg],
- offset=doc_index_beg_offset,
- length=doc_index_end_offset - doc_index_beg_offset + 1,
- )
- )
-
- # Sample spans multiple documents
- else:
- for i in range(doc_index_beg, doc_index_end + 1):
- # Add the document id
- document_ids.append(self.document_index[i])
-
- # Add the sample part
- offset = 0 if i > doc_index_beg else doc_index_beg_offset
- length = None if i < doc_index_end else doc_index_end_offset + 1
- sample_parts.append(
- self.indexed_dataset.get(self.document_index[i], offset=offset, length=length)
- )
-
- return (
- numpy.array(numpy.concatenate(sample_parts), dtype=numpy.int64),
- numpy.array(document_ids, dtype=numpy.int64),
- )
-
- def _build_document_sample_shuffle_indices(
- self,
- ) -> Tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray]:
- """Build the document index, the sample index, and the shuffle index
-
- The document index:
- -- 1-D
- -- An ordered array of document ids
-
- The sample index:
- -- 2-D
- -- The document indices and offsets which mark the start of every sample
-
- The shuffle index:
- -- 1-D
- -- A random permutation of index range of the sample index
-
- Returns:
-            Tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray]: The document index, the sample index, and the
- shuffle index
-
- TODO: Explain the 80% threshold
- """
- path_to_cache = getattr(self.config, "path_to_cache")
- if path_to_cache is None:
- path_to_cache = os.path.join(
- self.indexed_dataset.path_prefix, "cache", f"{type(self).__name__}_indices"
- )
-
- get_path_to = lambda suffix: os.path.join(
- path_to_cache, f"{self.unique_description_hash}-{type(self).__name__}-{suffix}"
- )
- path_to_description = get_path_to("description.txt")
- path_to_document_index = get_path_to("document_index.npy")
- path_to_sample_index = get_path_to("sample_index.npy")
- path_to_shuffle_index = get_path_to("shuffle_index.npy")
- cache_hit = all(
- map(
- os.path.isfile,
- [
- path_to_description,
- path_to_document_index,
- path_to_sample_index,
- path_to_shuffle_index,
- ],
- )
- )
-
- num_tokens_per_epoch = _get_num_tokens_per_epoch(self.indexed_dataset, self.indexed_indices)
-
- sequence_length = getattr(self.config, "sequence_length")
-
- num_epochs = _get_num_epochs(num_tokens_per_epoch, sequence_length, self.num_samples)
-
- if not cache_hit and torch.distributed.get_rank() == 0:
- log_single_rank(
- logger,
- logging.INFO,
- f"Build and save the {type(self).__name__} {self.index_split.name} indices",
- )
-
- if num_epochs == 1:
- separate_final_epoch = False
- else:
- # Get the number of samples for the last epoch
- num_samples_sans_final_epoch = (
- (num_epochs - 1) * num_tokens_per_epoch - 1
- ) // sequence_length
- num_samples_from_final_epoch = self.num_samples - num_samples_sans_final_epoch
- num_samples_per_epoch = (num_tokens_per_epoch - 1) // sequence_length
-
- # num_samples_from_final_epoch should be non-negative
- assert num_samples_from_final_epoch >= 0
-
- # num_samples_from_final_epoch should not exceed max value
- assert num_samples_from_final_epoch <= num_samples_per_epoch + 1
-
- # Separate the final epoch if it falls below the threshold
- threshold = 0.80
- separate_final_epoch = num_samples_from_final_epoch < int(
- threshold * num_samples_per_epoch
- )
-
- log_single_rank(
- logger,
- logging.DEBUG,
- f"> num_samples_from_final_epoch: {num_samples_from_final_epoch}",
- )
- log_single_rank(logger, logging.DEBUG, f"> threshold: {threshold}")
- log_single_rank(
- logger, logging.DEBUG, f"> num_samples_per_epoch: {num_samples_per_epoch}"
- )
-
- log_single_rank(
- logger, logging.DEBUG, f"> separate_final_epoch: {separate_final_epoch}"
- )
-
- numpy_random_state = numpy.random.RandomState(getattr(self.config, "random_seed"))
-
- os.makedirs(path_to_cache, exist_ok=True)
-
- # Write the description
- with open(path_to_description, "wt") as writer:
- writer.write(self.unique_description)
-
- # Build the document index
- log_single_rank(
- logger,
- logging.INFO,
- f"\tBuild and save the document index to {os.path.basename(path_to_document_index)}",
- )
- t_beg = time.time()
- document_index = _build_document_index(
- self.indexed_indices, num_epochs, numpy_random_state, separate_final_epoch
- )
- numpy.save(path_to_document_index, document_index, allow_pickle=True)
- t_end = time.time()
- log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds")
-
- # Build the sample index
- log_single_rank(
- logger,
- logging.INFO,
- f"\tBuild and save the sample index to {os.path.basename(path_to_sample_index)}",
- )
- t_beg = time.time()
- from megatron.core.datasets import helpers
-
- assert document_index.dtype == numpy.int32
- assert self.indexed_dataset.sequence_lengths.dtype == numpy.int32
- sample_index = helpers.build_sample_idx(
- self.indexed_dataset.sequence_lengths,
- document_index,
- sequence_length,
- num_epochs,
- num_tokens_per_epoch,
- )
- numpy.save(path_to_sample_index, sample_index, allow_pickle=True)
- t_end = time.time()
- log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds")
-
- # Build the shuffle index
- log_single_rank(
- logger,
- logging.INFO,
- f"\tBuild and save the shuffle index to {os.path.basename(path_to_shuffle_index)}",
- )
- t_beg = time.time()
- if separate_final_epoch:
- shuffle_index = _build_shuffle_index(
- num_samples_sans_final_epoch, sample_index.shape[0] - 1, numpy_random_state
- )
- else:
- shuffle_index = _build_shuffle_index(
- sample_index.shape[0] - 1, sample_index.shape[0] - 1, numpy_random_state
- )
- numpy.save(path_to_shuffle_index, shuffle_index, allow_pickle=True)
- t_end = time.time()
- log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds")
-
- log_single_rank(
- logger, logging.INFO, f"Load the {type(self).__name__} {self.index_split.name} indices"
- )
-
- log_single_rank(
- logger,
- logging.INFO,
- f"\tLoad the document index from {os.path.basename(path_to_document_index)}",
- )
- t_beg = time.time()
- document_index = numpy.load(path_to_document_index, allow_pickle=True, mmap_mode='r')
- t_end = time.time()
- log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds")
-
- log_single_rank(
- logger,
- logging.INFO,
- f"\tLoad the sample index from {os.path.basename(path_to_sample_index)}",
- )
- t_beg = time.time()
- sample_index = numpy.load(path_to_sample_index, allow_pickle=True, mmap_mode='r')
- t_end = time.time()
- log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds")
-
- log_single_rank(
- logger,
- logging.INFO,
- f"\tLoad the shuffle index from {os.path.basename(path_to_shuffle_index)}",
- )
- t_beg = time.time()
- shuffle_index = numpy.load(path_to_shuffle_index, allow_pickle=True, mmap_mode='r')
- t_end = time.time()
- log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds")
-
- log_single_rank(
- logger, logging.INFO, f"> total number of samples: {sample_index.shape[0] - 1}"
- )
- log_single_rank(logger, logging.INFO, f"> total number of epochs: {num_epochs}")
-
- return document_index, sample_index, shuffle_index
-
-
-def _get_num_tokens_per_epoch(indexed_dataset: MMapIndexedDataset, indices: numpy.ndarray) -> int:
- """Calculate the number of tokens in a single epoch
-
- Args:
- indexed_dataset (MMapIndexedDataset): The underlying MMapIndexedDataset
-
- indices (numpy.ndarray): The subset of indices into the underlying MMapIndexedDataset
-
- Returns:
- int: The number of tokens in a single epoch
- """
- return numpy.sum(indexed_dataset.sequence_lengths[indices])
-
-
-def _get_num_epochs(num_tokens_per_epoch: int, seq_length: int, num_samples: int) -> int:
- """Calculate the number of epochs
-
- Args:
- num_tokens_per_epoch (int): The number of tokens in a single epoch
-
- seq_length (int): The sequence length in tokens
-
- num_samples (int): The total number of samples
-
- Returns:
- int: The number of epochs
- """
- num_epochs = 0
- num_tokens = 0
- while True:
- num_epochs += 1
- num_tokens += num_tokens_per_epoch
- # -1 is because we need to retrieve seq_length + 1 token each time
- # but the last token will overlap with the first token of the next
- # sample except for the last sample.
- if ((num_tokens - 1) // seq_length) >= num_samples:
- return num_epochs
-
-
-def _build_document_index(
- documents: numpy.ndarray,
- num_epochs: int,
- numpy_random_state: numpy.random.RandomState,
- separate_final_epoch: bool,
-) -> numpy.ndarray:
- """Build an array with length = num epochs * num documents
-
- Args:
- documents (numpy.ndarray): the subset of exposed document indices
-
- num_epochs (int): The number of epochs
-
- numpy_random_state (numpy.random.RandomState): The NumPy random state
-
- separate_final_epoch (bool): Whether to exclude the last epoch from the global shuffle
-
- Returns:
- numpy.ndarray: The document index
-
- TODO: Explain separate_final_epoch
- """
- if not separate_final_epoch or num_epochs == 1:
- document_index = numpy.mgrid[0:num_epochs, 0 : len(documents)][1]
- document_index[:] = documents
- document_index = document_index.reshape(-1)
- document_index = document_index.astype(numpy.int32)
- numpy_random_state.shuffle(document_index)
- return document_index
-
- doc_idx_first = _build_document_index(documents, num_epochs - 1, numpy_random_state, False)
- doc_idx_last = _build_document_index(documents, 1, numpy_random_state, False)
- return numpy.concatenate((doc_idx_first, doc_idx_last))
-
-
-def _build_shuffle_index(
- num_samples: int, total_size: int, numpy_random_state: numpy.random.RandomState
-) -> numpy.ndarray:
- """Build the range [0, size) and shuffle
-
- Args:
- num_samples (int): The size of the first shuffle range [0, num_samples)
-
-        total_size (int): The size of the entire index. If larger than 'num_samples', it defines
-            the second shuffle range [num_samples, total_size)
-
- numpy_random_state (numpy.random.RandomState): The NumPy random state
-
- Returns:
- numpy.ndarray: The shuffle index
-
- TODO: Explain [0, num_samples) [num_samples, total_size) split
- """
- dtype_ = numpy.uint32
- if total_size >= (numpy.iinfo(numpy.uint32).max - 1):
- dtype_ = numpy.int64
-
- shuffle_idx_first = numpy.arange(start=0, stop=num_samples, step=1, dtype=dtype_)
- numpy_random_state.shuffle(shuffle_idx_first)
- if num_samples == total_size:
- return shuffle_idx_first
-
- shuffle_idx_last = numpy.arange(start=num_samples, stop=total_size, step=1, dtype=dtype_)
- numpy_random_state.shuffle(shuffle_idx_last)
-
- return numpy.concatenate((shuffle_idx_first, shuffle_idx_last))
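
To make the epoch arithmetic in `_get_num_epochs` above concrete, here is a small sketch with illustrative numbers:

```python
def get_num_epochs(num_tokens_per_epoch, seq_length, num_samples):
    num_epochs, num_tokens = 0, 0
    while True:
        num_epochs += 1
        num_tokens += num_tokens_per_epoch
        # seq_length + 1 tokens are drawn per sample, with one token of overlap
        # between neighbouring samples, hence the -1.
        if (num_tokens - 1) // seq_length >= num_samples:
            return num_epochs

# 10,000 tokens per epoch, 2048-token sequences, 12 requested samples:
# 4 samples fit per epoch, so three epochs of documents are needed.
assert get_num_epochs(10_000, 2048, 12) == 3
```
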
diff --git a/megatron/core/datasets/helpers.cpp b/megatron/core/datasets/helpers.cpp
deleted file mode 100644
index 4e1b3dbc931c4cd1fe23dc8c1e2fba029d104086..0000000000000000000000000000000000000000
--- a/megatron/core/datasets/helpers.cpp
+++ /dev/null
@@ -1,765 +0,0 @@
-/* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. */
-
-/* Helper methods for fast index mapping builds */
-
-#include <algorithm>
-#include <iostream>
-#include <limits>
-#include <math.h>
-#include <stdexcept>
-#include <pybind11/pybind11.h>
-#include <pybind11/numpy.h>
-#include <random>
-
-namespace py = pybind11;
-using namespace std;
-
-const int32_t LONG_SENTENCE_LEN = 512;
-
-void build_blending_indices(py::array_t<int16_t> &dataset_index,
-                            py::array_t<int64_t> &dataset_sample_index,
-                            const py::array_t<double> &weights,
- const int32_t num_datasets,
- const int64_t size, const bool verbose)
-{
- /* Given multiple datasets and a weighting array, build samples
-    such that it follows those weights.*/
-
- if (verbose)
- {
- std::cout << "> building indices for blended datasets ..." << std::endl;
- }
-
- // Get the pointer access without the checks.
- auto dataset_index_ptr = dataset_index.mutable_unchecked<1>();
- auto dataset_sample_index_ptr = dataset_sample_index.mutable_unchecked<1>();
- auto weights_ptr = weights.unchecked<1>();
-
- // Initialize buffer for number of samples used for each dataset.
- int64_t current_samples[num_datasets];
- for (int64_t i = 0; i < num_datasets; ++i)
- {
- current_samples[i] = 0;
- }
-
- // For each sample:
- for (int64_t sample_idx = 0; sample_idx < size; ++sample_idx)
- {
-
- // Determine where the max error in sampling is happening.
-        auto sample_idx_double = std::max(static_cast<double>(sample_idx), 1.0);
- int64_t max_error_index = 0;
- double max_error = weights_ptr[0] * sample_idx_double -
-                           static_cast<double>(current_samples[0]);
- for (int64_t dataset_idx = 1; dataset_idx < num_datasets; ++dataset_idx)
- {
- double error = weights_ptr[dataset_idx] * sample_idx_double -
-                           static_cast<double>(current_samples[dataset_idx]);
- if (error > max_error)
- {
- max_error = error;
- max_error_index = dataset_idx;
- }
- }
-
- // Populate the indices.
-        dataset_index_ptr[sample_idx] = static_cast<int16_t>(max_error_index);
- dataset_sample_index_ptr[sample_idx] = current_samples[max_error_index];
-
- // Update the total samples.
- current_samples[max_error_index] += 1;
- }
-
- // print info
- if (verbose)
- {
- std::cout << " > sample ratios:" << std::endl;
- for (int64_t dataset_idx = 0; dataset_idx < num_datasets; ++dataset_idx)
- {
-            auto ratio = static_cast<double>(current_samples[dataset_idx]) /
-                         static_cast<double>(size);
- std::cout << " dataset " << dataset_idx << ", input: " << weights_ptr[dataset_idx] << ", achieved: " << ratio << std::endl;
- }
- }
-}
-
-py::array build_sample_idx(const py::array_t<int32_t> &sizes_,
-                           const py::array_t<int32_t> &doc_idx_,
- const int32_t seq_length,
- const int32_t num_epochs,
- const int64_t tokens_per_epoch)
-{
- /* Sample index (sample_idx) is used for gpt2 like dataset for which
- the documents are flattened and the samples are built based on this
- 1-D flatten array. It is a 2D array with sizes [number-of-samples + 1, 2]
- where [..., 0] contains the index into `doc_idx` and [..., 1] is the
- starting offset in that document.*/
-
- // Consistency checks.
- assert(seq_length > 1);
- assert(num_epochs > 0);
- assert(tokens_per_epoch > 1);
-
- // Remove bound checks.
- auto sizes = sizes_.unchecked<1>();
- auto doc_idx = doc_idx_.unchecked<1>();
-
-    // Mapping and its length (1D).
- int64_t num_samples = (num_epochs * tokens_per_epoch - 1) / seq_length;
- int32_t *sample_idx = new int32_t[2 * (num_samples + 1)];
-
- // Index into sample_idx.
- int64_t sample_index = 0;
- // Index into doc_idx.
- int64_t doc_idx_index = 0;
-    // Beginning offset for each document.
- int32_t doc_offset = 0;
- // Start with first document and no offset.
- sample_idx[2 * sample_index] = doc_idx_index;
- sample_idx[2 * sample_index + 1] = doc_offset;
- ++sample_index;
-
- while (sample_index <= num_samples)
- {
- // Start with a fresh sequence.
- int32_t remaining_seq_length = seq_length + 1;
- while (remaining_seq_length != 0)
- {
- // Get the document length.
- auto doc_id = doc_idx[doc_idx_index];
- auto doc_length = sizes[doc_id] - doc_offset;
- // And add it to the current sequence.
- remaining_seq_length -= doc_length;
- // If we have more than a full sequence, adjust offset and set
- // remaining length to zero so we return from the while loop.
- // Note that -1 here is for the same reason we have -1 in
- // `_num_epochs` calculations.
- if (remaining_seq_length <= 0)
- {
- doc_offset += (remaining_seq_length + doc_length - 1);
- remaining_seq_length = 0;
- }
- else
- {
-                // Otherwise, start from the beginning of the next document.
- ++doc_idx_index;
- doc_offset = 0;
- }
- }
- // Record the sequence.
- sample_idx[2 * sample_index] = doc_idx_index;
- sample_idx[2 * sample_index + 1] = doc_offset;
- ++sample_index;
- }
-
- // Method to deallocate memory.
- py::capsule free_when_done(sample_idx, [](void *mem_)
- {
-        int32_t *mem = reinterpret_cast<int32_t *>(mem_);
- delete[] mem; });
-
- // Return the numpy array.
- const auto byte_size = sizeof(int32_t);
-    return py::array(std::vector<int64_t>{num_samples + 1, 2}, // shape
- {2 * byte_size, byte_size}, // C-style contiguous strides
- sample_idx, // the data pointer
- free_when_done); // numpy array references
-}
-
-inline int32_t get_target_sample_len(const int32_t short_seq_ratio,
- const int32_t max_length,
- std::mt19937 &rand32_gen)
-{
- /* Training sample length. */
- if (short_seq_ratio == 0)
- {
- return max_length;
- }
- const auto random_number = rand32_gen();
- if ((random_number % short_seq_ratio) == 0)
- {
- return 2 + random_number % (max_length - 1);
- }
- return max_length;
-}
-
-template <typename DocIdx>
-py::array build_mapping_impl(const py::array_t<int64_t> &docs_,
-                             const py::array_t<int32_t> &sizes_,
- const int32_t num_epochs,
- const uint64_t max_num_samples,
- const int32_t max_seq_length,
- const double short_seq_prob,
- const int32_t seed,
- const bool verbose,
- const int32_t min_num_sent)
-{
- /* Build a mapping of (start-index, end-index, sequence-length) where
- start and end index are the indices of the sentences in the sample
- and sequence-length is the target sequence length.
- */
-
- // Consistency checks.
- assert(num_epochs > 0);
- assert(max_seq_length > 1);
- assert(short_seq_prob >= 0.0);
- assert(short_seq_prob <= 1.0);
- assert(seed > 0);
-
- // Remove bound checks.
- auto docs = docs_.unchecked<1>();
- auto sizes = sizes_.unchecked<1>();
-
- // For efficiency, convert probability to ratio. Note: rand() generates int.
- int32_t short_seq_ratio = 0;
- if (short_seq_prob > 0)
- {
-        short_seq_ratio = static_cast<int32_t>(round(1.0 / short_seq_prob));
- }
-
- if (verbose)
- {
- const auto sent_start_index = docs[0];
- const auto sent_end_index = docs[docs_.shape(0) - 1];
- const auto num_sentences = sent_end_index - sent_start_index;
- cout << " using:" << endl
- << std::flush;
- cout << " number of documents: " << docs_.shape(0) - 1 << endl
- << std::flush;
- cout << " sentences range: [" << sent_start_index << ", " << sent_end_index << ")" << endl
- << std::flush;
- cout << " total number of sentences: " << num_sentences << endl
- << std::flush;
- cout << " number of epochs: " << num_epochs << endl
- << std::flush;
- cout << " maximum number of samples: " << max_num_samples << endl
- << std::flush;
- cout << " maximum sequence length: " << max_seq_length << endl
- << std::flush;
- cout << " short sequence probability: " << short_seq_prob << endl
- << std::flush;
-        cout << "   short sequence ratio (1/prob): " << short_seq_ratio << endl
- << std::flush;
- cout << " seed: " << seed << endl
- << std::flush;
- }
-
-    // Mapping and its length (1D).
- int64_t num_samples = -1;
- DocIdx *maps = NULL;
-
- // Perform two iterations, in the first iteration get the size
- // and allocate memory and in the second iteration populate the map.
- bool second = false;
- for (int32_t iteration = 0; iteration < 2; ++iteration)
- {
-
- // Set the seed so both iterations produce the same results.
- std::mt19937 rand32_gen(seed);
-
- // Set the flag on second iteration.
- second = (iteration == 1);
-
- // Counters:
- uint64_t empty_docs = 0;
- uint64_t one_sent_docs = 0;
- uint64_t long_sent_docs = 0;
-
- // Current map index.
- uint64_t map_index = 0;
-
- // For each epoch:
- for (int32_t epoch = 0; epoch < num_epochs; ++epoch)
- {
- if (map_index >= max_num_samples)
- {
- if (verbose && (!second))
- {
- cout << " reached " << max_num_samples << " samples after "
- << epoch << " epochs ..." << endl
- << std::flush;
- }
- break;
- }
- // For each document:
- for (int32_t doc = 0; doc < (docs.shape(0) - 1); ++doc)
- {
-
- // Document sentences are in [sent_index_first, sent_index_last)
- const auto sent_index_first = docs[doc];
- const auto sent_index_last = docs[doc + 1];
-
-                // At the beginning of the document previous index is the
- // start index.
- auto prev_start_index = sent_index_first;
-
- // Remaining documents.
- auto num_remain_sent = sent_index_last - sent_index_first;
-
- // Some bookkeeping
- if ((epoch == 0) && (!second))
- {
- if (num_remain_sent == 0)
- {
- ++empty_docs;
- }
- if (num_remain_sent == 1)
- {
- ++one_sent_docs;
- }
- }
-
- // Detect documents with long sentences.
- bool contains_long_sentence = false;
- if (num_remain_sent > 1)
- {
- for (auto sent_index = sent_index_first;
- sent_index < sent_index_last; ++sent_index)
- {
- if (sizes[sent_index] > LONG_SENTENCE_LEN)
- {
- if ((epoch == 0) && (!second))
- {
- ++long_sent_docs;
- }
- contains_long_sentence = true;
- break;
- }
- }
- }
-
- // If we have more than two sentences.
- if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence))
- {
-
- // Set values.
- auto seq_len = int32_t{0};
- auto num_sent = int32_t{0};
- auto target_seq_len = get_target_sample_len(short_seq_ratio,
- max_seq_length,
- rand32_gen);
-
- // Loop through sentences.
- for (auto sent_index = sent_index_first;
- sent_index < sent_index_last; ++sent_index)
- {
-
- // Add the size and number of sentences.
- seq_len += sizes[sent_index];
- ++num_sent;
- --num_remain_sent;
-
- // If we have reached the target length.
- // and if not only one sentence is left in the document.
-                        // and if we have at least two sentences.
- // and if we have reached end of the document.
- if (((seq_len >= target_seq_len) &&
- (num_remain_sent > 1) &&
- (num_sent >= min_num_sent)) ||
- (num_remain_sent == 0))
- {
-
- // Check for overflow.
- if ((3 * map_index + 2) >
-                                std::numeric_limits<int64_t>::max())
- {
- cout << "number of samples exceeded maximum "
- << "allowed by type int64: "
-                                     << std::numeric_limits<int64_t>::max()
- << endl;
- throw std::overflow_error("Number of samples");
- }
-
- // Populate the map.
- if (second)
- {
- const auto map_index_0 = 3 * map_index;
-                                maps[map_index_0] = static_cast<DocIdx>(prev_start_index);
-                                maps[map_index_0 + 1] = static_cast<DocIdx>(sent_index + 1);
-                                maps[map_index_0 + 2] = static_cast<DocIdx>(target_seq_len);
- }
-
- // Update indices / counters.
- ++map_index;
- prev_start_index = sent_index + 1;
- target_seq_len = get_target_sample_len(short_seq_ratio,
- max_seq_length,
- rand32_gen);
- seq_len = 0;
- num_sent = 0;
- }
-
- } // for (auto sent_index=sent_index_first; ...
- } // if (num_remain_sent > 1) {
- } // for (int doc=0; doc < num_docs; ++doc) {
- } // for (int epoch=0; epoch < num_epochs; ++epoch) {
-
- if (!second)
- {
- if (verbose)
- {
- cout << " number of empty documents: " << empty_docs << endl
- << std::flush;
- cout << " number of documents with one sentence: " << one_sent_docs << endl
- << std::flush;
- cout << " number of documents with long sentences: " << long_sent_docs << endl
- << std::flush;
- cout << " will create mapping for " << map_index << " samples" << endl
- << std::flush;
- }
- assert(maps == NULL);
- assert(num_samples < 0);
- maps = new DocIdx[3 * map_index];
-            num_samples = static_cast<int64_t>(map_index);
- }
-
- } // for (int iteration=0; iteration < 2; ++iteration) {
-
- // Shuffle.
- // We need a 64 bit random number generator as we might have more
- // than 2 billion samples.
- std::mt19937_64 rand64_gen(seed + 1);
- for (auto i = (num_samples - 1); i > 0; --i)
- {
-        const auto j = static_cast<int64_t>(rand64_gen() % (i + 1));
- const auto i0 = 3 * i;
- const auto j0 = 3 * j;
- // Swap values.
- swap(maps[i0], maps[j0]);
- swap(maps[i0 + 1], maps[j0 + 1]);
- swap(maps[i0 + 2], maps[j0 + 2]);
- }
-
- // Method to deallocate memory.
- py::capsule free_when_done(maps, [](void *mem_)
- {
-        DocIdx *mem = reinterpret_cast<DocIdx *>(mem_);
- delete[] mem; });
-
- // Return the numpy array.
- const auto byte_size = sizeof(DocIdx);
-    return py::array(std::vector<int64_t>{num_samples, 3}, // shape
- {3 * byte_size, byte_size}, // C-style contiguous strides
- maps, // the data pointer
- free_when_done); // numpy array references
-}
-
-py::array build_mapping(const py::array_t<int64_t> &docs_,
-                        const py::array_t<int32_t> &sizes_,
- const int num_epochs,
- const uint64_t max_num_samples,
- const int max_seq_length,
- const double short_seq_prob,
- const int seed,
- const bool verbose,
- const int32_t min_num_sent)
-{
-
-    if (sizes_.size() > std::numeric_limits<uint32_t>::max())
- {
- if (verbose)
- {
- cout << " using uint64 for data mapping..." << endl
- << std::flush;
- }
-        return build_mapping_impl<uint64_t>(docs_, sizes_, num_epochs,
- max_num_samples, max_seq_length,
- short_seq_prob, seed, verbose,
- min_num_sent);
- }
- else
- {
- if (verbose)
- {
- cout << " using uint32 for data mapping..." << endl
- << std::flush;
- }
-        return build_mapping_impl<uint32_t>(docs_, sizes_, num_epochs,
- max_num_samples, max_seq_length,
- short_seq_prob, seed, verbose,
- min_num_sent);
- }
-}
-
-template <typename DocIdx>
-py::array build_blocks_mapping_impl(const py::array_t<int64_t> &docs_,
-                                    const py::array_t<int32_t> &sizes_,
-                                    const py::array_t<int32_t> &titles_sizes_,
- const int32_t num_epochs,
- const uint64_t max_num_samples,
- const int32_t max_seq_length,
- const int32_t seed,
- const bool verbose,
- const bool use_one_sent_blocks)
-{
- /* Build a mapping of (start-index, end-index, sequence-length) where
- start and end index are the indices of the sentences in the sample
- and sequence-length is the target sequence length.
- */
-
- // Consistency checks.
- assert(num_epochs > 0);
- assert(max_seq_length > 1);
- assert(seed > 0);
-
- // Remove bound checks.
- auto docs = docs_.unchecked<1>();
- auto sizes = sizes_.unchecked<1>();
- auto titles_sizes = titles_sizes_.unchecked<1>();
-
- if (verbose)
- {
- const auto sent_start_index = docs[0];
- const auto sent_end_index = docs[docs_.shape(0) - 1];
- const auto num_sentences = sent_end_index - sent_start_index;
- cout << " using:" << endl
- << std::flush;
- cout << " number of documents: " << docs_.shape(0) - 1 << endl
- << std::flush;
- cout << " sentences range: [" << sent_start_index << ", " << sent_end_index << ")" << endl
- << std::flush;
- cout << " total number of sentences: " << num_sentences << endl
- << std::flush;
- cout << " number of epochs: " << num_epochs << endl
- << std::flush;
- cout << " maximum number of samples: " << max_num_samples << endl
- << std::flush;
- cout << " maximum sequence length: " << max_seq_length << endl
- << std::flush;
- cout << " seed: " << seed << endl
- << std::flush;
- }
-
- // Mapping and its length (1D).
- int64_t num_samples = -1;
- DocIdx *maps = NULL;
-
- // Acceptable number of sentences per block.
- int min_num_sent = 2;
- if (use_one_sent_blocks)
- {
- min_num_sent = 1;
- }
-
- // Perform two iterations, in the first iteration get the size
- // and allocate memory and in the second iteration populate the map.
- bool second = false;
- for (int32_t iteration = 0; iteration < 2; ++iteration)
- {
-
- // Set the flag on second iteration.
- second = (iteration == 1);
-
- // Current map index.
- uint64_t map_index = 0;
-
- uint64_t empty_docs = 0;
- uint64_t one_sent_docs = 0;
- uint64_t long_sent_docs = 0;
- // For each epoch:
- for (int32_t epoch = 0; epoch < num_epochs; ++epoch)
- {
- // assign every block a unique id
- int32_t block_id = 0;
-
- if (map_index >= max_num_samples)
- {
- if (verbose && (!second))
- {
- cout << " reached " << max_num_samples << " samples after "
- << epoch << " epochs ..." << endl
- << std::flush;
- }
- break;
- }
- // For each document:
- for (int32_t doc = 0; doc < (docs.shape(0) - 1); ++doc)
- {
-
- // Document sentences are in [sent_index_first, sent_index_last)
- const auto sent_index_first = docs[doc];
- const auto sent_index_last = docs[doc + 1];
- const auto target_seq_len = max_seq_length - titles_sizes[doc];
-
-                // At the beginning of the document previous index is the
- // start index.
- auto prev_start_index = sent_index_first;
-
- // Remaining documents.
- auto num_remain_sent = sent_index_last - sent_index_first;
-
- // Some bookkeeping
- if ((epoch == 0) && (!second))
- {
- if (num_remain_sent == 0)
- {
- ++empty_docs;
- }
- if (num_remain_sent == 1)
- {
- ++one_sent_docs;
- }
- }
- // Detect documents with long sentences.
- bool contains_long_sentence = false;
- if (num_remain_sent >= min_num_sent)
- {
- for (auto sent_index = sent_index_first;
- sent_index < sent_index_last; ++sent_index)
- {
- if (sizes[sent_index] > LONG_SENTENCE_LEN)
- {
- if ((epoch == 0) && (!second))
- {
- ++long_sent_docs;
- }
- contains_long_sentence = true;
- break;
- }
- }
- }
- // If we have enough sentences and no long sentences.
- if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence))
- {
-
- // Set values.
- auto seq_len = int32_t{0};
- auto num_sent = int32_t{0};
-
- // Loop through sentences.
- for (auto sent_index = sent_index_first;
- sent_index < sent_index_last; ++sent_index)
- {
-
- // Add the size and number of sentences.
- seq_len += sizes[sent_index];
- ++num_sent;
- --num_remain_sent;
-
- // If we have reached the target length.
- // and there are an acceptable number of sentences left
- // and if we have at least the minimum number of sentences.
- // or if we have reached end of the document.
- if (((seq_len >= target_seq_len) &&
- (num_remain_sent >= min_num_sent) &&
- (num_sent >= min_num_sent)) ||
- (num_remain_sent == 0))
- {
-
- // Populate the map.
- if (second)
- {
- const auto map_index_0 = 4 * map_index;
- // Each sample has 4 items: the starting sentence index, ending sentence index,
- // the index of the document from which the block comes (used for fetching titles)
- // and the unique id of the block (used for creating block indexes)
-
-                                maps[map_index_0] = static_cast<DocIdx>(prev_start_index);
-                                maps[map_index_0 + 1] = static_cast<DocIdx>(sent_index + 1);
-                                maps[map_index_0 + 2] = static_cast<DocIdx>(doc);
-                                maps[map_index_0 + 3] = static_cast<DocIdx>(block_id);
- }
-
- // Update indices / counters.
- ++map_index;
- ++block_id;
- prev_start_index = sent_index + 1;
- seq_len = 0;
- num_sent = 0;
- }
- } // for (auto sent_index=sent_index_first; ...
- } // if (num_remain_sent > 1) {
- } // for (int doc=0; doc < num_docs; ++doc) {
- } // for (int epoch=0; epoch < num_epochs; ++epoch) {
-
- if (!second)
- {
- if (verbose)
- {
- cout << " number of empty documents: " << empty_docs << endl
- << std::flush;
- cout << " number of documents with one sentence: " << one_sent_docs << endl
- << std::flush;
- cout << " number of documents with long sentences: " << long_sent_docs << endl
- << std::flush;
- cout << " will create mapping for " << map_index << " samples" << endl
- << std::flush;
- }
- assert(maps == NULL);
- assert(num_samples < 0);
- maps = new DocIdx[4 * map_index];
-            num_samples = static_cast<int64_t>(map_index);
- }
-
- } // for (int iteration=0; iteration < 2; ++iteration) {
-
- // Shuffle.
- // We need a 64 bit random number generator as we might have more
- // than 2 billion samples.
- std::mt19937_64 rand64_gen(seed + 1);
- for (auto i = (num_samples - 1); i > 0; --i)
- {
-        const auto j = static_cast<int64_t>(rand64_gen() % (i + 1));
- const auto i0 = 4 * i;
- const auto j0 = 4 * j;
- // Swap values.
- swap(maps[i0], maps[j0]);
- swap(maps[i0 + 1], maps[j0 + 1]);
- swap(maps[i0 + 2], maps[j0 + 2]);
- swap(maps[i0 + 3], maps[j0 + 3]);
- }
-
- // Method to deallocate memory.
- py::capsule free_when_done(maps, [](void *mem_)
- {
-        DocIdx *mem = reinterpret_cast<DocIdx *>(mem_);
- delete[] mem; });
-
- // Return the numpy array.
- const auto byte_size = sizeof(DocIdx);
-    return py::array(std::vector<int64_t>{num_samples, 4}, // shape
- {4 * byte_size, byte_size}, // C-style contiguous strides
- maps, // the data pointer
- free_when_done); // numpy array references
-}
-
-py::array build_blocks_mapping(const py::array_t<int64_t> &docs_,
-                               const py::array_t<int32_t> &sizes_,
-                               const py::array_t<int32_t> &titles_sizes_,
- const int num_epochs,
- const uint64_t max_num_samples,
- const int max_seq_length,
- const int seed,
- const bool verbose,
- const bool use_one_sent_blocks)
-{
-
-    if (sizes_.size() > std::numeric_limits<uint32_t>::max())
- {
- if (verbose)
- {
- cout << " using uint64 for data mapping..." << endl
- << std::flush;
- }
-        return build_blocks_mapping_impl<uint64_t>(docs_, sizes_, titles_sizes_,
- num_epochs, max_num_samples, max_seq_length, seed, verbose, use_one_sent_blocks);
- }
- else
- {
- if (verbose)
- {
- cout << " using uint32 for data mapping..." << endl
- << std::flush;
- }
-        return build_blocks_mapping_impl<uint32_t>(docs_, sizes_, titles_sizes_,
- num_epochs, max_num_samples, max_seq_length, seed, verbose, use_one_sent_blocks);
- }
-}
-
-PYBIND11_MODULE(helpers, m)
-{
- m.def("build_mapping", &build_mapping);
- m.def("build_blocks_mapping", &build_blocks_mapping);
- m.def("build_sample_idx", &build_sample_idx);
- m.def("build_blending_indices", &build_blending_indices);
-}
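
For orientation, here is a minimal sketch of how the `build_blocks_mapping` binding above could be exercised from Python once the extension has been compiled; the toy document layout and the import path of the compiled module are assumptions, not part of the deleted file.

```python
# Hedged sketch: drive the pybind11 helpers shown above with toy inputs.
import numpy as np

from megatron.core.datasets.utils import compile_helpers

compile_helpers()  # builds the extension from helpers.cpp on first use
from megatron.core.datasets import helpers  # assumed import path of the built module

# One document made of three sentences: docs holds the per-document sentence
# boundaries, sizes the token count of each sentence, titles_sizes the title length.
docs = np.array([0, 3], dtype=np.int64)
sizes = np.array([64, 80, 72], dtype=np.int32)
titles_sizes = np.array([8], dtype=np.int32)

blocks = helpers.build_blocks_mapping(
    docs, sizes, titles_sizes,
    1,      # num_epochs
    10,     # max_num_samples
    256,    # max_seq_length
    1234,   # seed
    True,   # verbose
    False,  # use_one_sent_blocks
)
# Each row is (start sentence, end sentence, document index, block id),
# matching the 4-item layout written into `maps` above.
print(blocks.shape, blocks.dtype)
```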
diff --git a/megatron/core/datasets/indexed_dataset.py b/megatron/core/datasets/indexed_dataset.py
deleted file mode 100644
index cd62160ceab7a662d8fff1d597712f90ea9e7aba..0000000000000000000000000000000000000000
--- a/megatron/core/datasets/indexed_dataset.py
+++ /dev/null
@@ -1,639 +0,0 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-#
-# This source code is licensed under the MIT license found in the
-# LICENSE file in the root directory of this source tree.
-
-# Essentially re-written in entirety
-
-import logging
-import os
-import shutil
-import struct
-import time
-from enum import Enum
-from functools import lru_cache
-from itertools import accumulate
-from types import TracebackType
-from typing import List, Optional, Tuple, Type, Union
-
-import numpy
-import torch
-
-from megatron.core.datasets.utils import log_single_rank
-
-logger = logging.getLogger(__name__)
-
-_INDEX_HEADER = b"MMIDIDX\x00\x00"
-
-
-class DType(Enum):
- """The NumPy data type Enum for writing/reading the MMapIndexedDataset indices
- """
-
- uint8 = 1
- int8 = 2
- int16 = 3
- int32 = 4
- int64 = 5
- float64 = 6
- float32 = 7
- uint16 = 8
-
- @classmethod
- def code_from_dtype(cls, value: Type[numpy.number]) -> int:
- """Get the code from the dtype
-
- Args:
- value (Type[numpy.number]): The dtype
-
- Returns:
- int: The code
- """
- return cls[value.__name__].value
-
- @classmethod
- def dtype_from_code(cls, value: int) -> Type[numpy.number]:
- """Get the dtype from the code
-
- Args:
- value (int): The code
-
- Returns:
- Type[numpy.number]: The dtype
- """
- return getattr(numpy, cls(value).name)
-
- @staticmethod
- def size(key: Union[int, Type[numpy.number]]) -> int:
- """Get the size of the dtype/code in bytes
-
- Args:
- key (Union[int, Type[numpy.number]]): The dtype or code
-
- Raises:
- ValueError: If the key is neither dtype nor integer code
-
- Returns:
-            int: The size of the dtype/code in bytes
- """
- if isinstance(key, int):
- return DType.dtype_from_code(key)().itemsize
- elif numpy.number in key.__mro__:
- return key().itemsize
- else:
- raise ValueError
-
- @staticmethod
- def optimal_dtype(cardinality: Optional[int]) -> Type[numpy.number]:
- """Get the dtype to use for an index of a certain cardinality
-
- Args:
- cardinality (Optional[int]): The number of elements to be indexed
-
- Returns:
- Type[numpy.number]: The dtype to use for the index
- """
- if cardinality is not None and cardinality < 65500:
- return numpy.uint16
- else:
- return numpy.int32
-
-
-class _IndexWriter(object):
- """Object class to write the index (.idx) file
-
- Args:
- idx_path (str): The path to the index file
-
- dtype (Type[numpy.number]): The dtype of the index file
- """
-
- def __init__(self, idx_path: str, dtype: Type[numpy.number]) -> None:
- self.idx_path = idx_path
- self.dtype = dtype
-
- def __enter__(self) -> "_IndexWriter":
- """Enter the context introduced by the 'with' keyword
-
- Returns:
- _IndexWriter: The instance
- """
- self.idx_writer = open(self.idx_path, "wb")
- # fixed, vestigial practice
- self.idx_writer.write(_INDEX_HEADER)
- # fixed, vestigial practice
-        self.idx_writer.write(struct.pack("<Q", 1))
-        # the numeric code of the dtype
-        self.idx_writer.write(struct.pack("<B", DType.code_from_dtype(self.dtype)))
-        return self
-
-    def __exit__(
-        self,
-        exc_type: Optional[Type[BaseException]],
-        exc_val: Optional[BaseException],
-        exc_tb: Optional[TracebackType],
-    ) -> Optional[bool]:
- """Exit the context introduced by the 'with' keyword
-
- Args:
- exc_type (Optional[Type[BaseException]]): Exception type
-
- exc_val (Optional[BaseException]): Exception value
-
- exc_tb (Optional[TracebackType]): Exception traceback object
-
- Returns:
- Optional[bool]: Whether to silence the exception
- """
- self.idx_writer.close()
-
- def write(
- self,
- sequence_lengths: List[int],
- sequence_modes: Optional[List[int]],
- document_indices: List[int],
- ) -> None:
- """Write the index (.idx) file
-
- Args:
- sequence_lengths (List[int]): The length of each sequence
-
-            sequence_modes (Optional[List[int]]): The mode of each sequence
-
-            document_indices (List[int]): The sequence indices demarcating the end of each document
- """
- sequence_pointers = self._sequence_pointers(sequence_lengths)
-
- # the number of sequences in the dataset
- sequence_count = len(sequence_lengths)
-        self.idx_writer.write(struct.pack("<Q", sequence_count))
-
-        # the number of documents in the dataset
-        document_count = len(document_indices)
-        self.idx_writer.write(struct.pack("<Q", document_count))
-
-        # the number of tokens per sequence
-        sequence_lengths = numpy.array(sequence_lengths, dtype=numpy.int32)
-        self.idx_writer.write(sequence_lengths.tobytes(order="C"))
-        del sequence_lengths
-
-        # the byte offsets for all sequences
-        sequence_pointers = numpy.array(sequence_pointers, dtype=numpy.int64)
-        self.idx_writer.write(sequence_pointers.tobytes(order="C"))
-        del sequence_pointers
-
-        # the sequence indices marking the end of each document
-        document_indices = numpy.array(document_indices, dtype=numpy.int64)
-        self.idx_writer.write(document_indices.tobytes(order="C"))
-
-        # the mode per sequence (multimodal only)
-        if sequence_modes is not None:
-            sequence_modes = numpy.array(sequence_modes, dtype=numpy.int8)
-            self.idx_writer.write(sequence_modes.tobytes(order="C"))
-            del sequence_modes
-
-    def _sequence_pointers(self, sequence_lengths: List[int]) -> List[int]:
- """Build the sequence pointers per the sequence lengths and dtype size
-
- Args:
- sequence_lengths (List[int]): The length of each sequence
-
- Returns:
- List[int]: The pointer to the beginning of each sequence
- """
- itemsize = DType.size(self.dtype)
- curr_ptr = 0
- list_ptr = []
- for length in sequence_lengths:
- list_ptr.append(curr_ptr)
- curr_ptr += length * itemsize
- return list_ptr
-
-
-class _IndexReader(object):
- """Object class to read the index (.idx) file
-
- Args:
- idx_path (str): The path to the index file
-
- multimodal (bool): Whether the dataset is multimodal
- """
-
- def __init__(self, idx_path: str, multimodal: bool) -> None:
-
- log_single_rank(logger, logging.INFO, f"Load the {type(self).__name__} from {idx_path}")
-
- with open(idx_path, "rb") as stream:
- header = stream.read(9)
- assert header == _INDEX_HEADER, f"bad header, cannot read: {idx_path}"
-
-            version = struct.unpack("<Q", stream.read(8))[0]
-            assert version == 1, f"bad version, cannot read: {idx_path}"
-
-            code = struct.unpack("<B", stream.read(1))[0]
-            self.dtype = DType.dtype_from_code(code)
-            self.dtype_size = DType.size(self.dtype)
-
-            self.sequence_count = struct.unpack("<Q", stream.read(8))[0]
-            self.document_count = struct.unpack("<Q", stream.read(8))[0]
-
-            offset = stream.tell()
-
-        self.bin_buffer_mmap = numpy.memmap(idx_path, mode="r", order="C")
-        self.bin_buffer = memoryview(self.bin_buffer_mmap)
-
-        log_single_rank(logger, logging.INFO, f"\tExtract the sequence lengths")
-        t_beg = time.time()
-        self.sequence_lengths = numpy.frombuffer(
-            self.bin_buffer, dtype=numpy.int32, count=self.sequence_count, offset=offset
-        )
-        t_end = time.time()
-        log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds")
-
- log_single_rank(logger, logging.INFO, f"\tExtract the sequence pointers")
- t_beg = time.time()
- self.sequence_pointers = numpy.frombuffer(
- self.bin_buffer,
- dtype=numpy.int64,
- count=self.sequence_count,
- offset=offset + self.sequence_lengths.nbytes,
- )
- t_end = time.time()
- log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds")
-
- log_single_rank(logger, logging.INFO, f"\tExtract the document indices")
- t_beg = time.time()
- self.document_indices = numpy.frombuffer(
- self.bin_buffer,
- dtype=numpy.int64,
- count=self.document_count,
- offset=offset + self.sequence_lengths.nbytes + self.sequence_pointers.nbytes,
- )
- t_end = time.time()
- log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds")
-
- self.sequence_modes = None
- if multimodal:
- log_single_rank(logger, logging.INFO, f"\tExtract the sequence modes")
- t_beg = time.time()
- self.sequence_modes = numpy.frombuffer(
- self.bin_buffer,
- dtype=numpy.int8,
- count=self.sequence_count,
- offset=offset
- + self.sequence_lengths.nbytes
- + self.sequence_pointers.nbytes
- + self.document_indices.nbytes,
- )
- t_end = time.time()
- log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds")
-
- assert self.sequence_lengths.shape[0] == len(self)
- assert self.sequence_lengths.shape[0] == self.sequence_count
- assert self.sequence_lengths.shape[0] == self.document_indices[-1]
-
- log_single_rank(logger, logging.INFO, f"> total number of sequences: {len(self)}")
- log_single_rank(
- logger,
- logging.INFO,
- f"> total number of documents: {self.document_indices.shape[0] - 1}",
- )
-
- def __del__(self) -> None:
- """Clean up the object
- """
- self.bin_buffer_mmap._mmap.close()
- del self.bin_buffer_mmap
-
- def __len__(self) -> int:
- """Return the length of the dataset
-
- Returns:
- int: The length of the dataset
- """
- return self.sequence_count
-
- @lru_cache(maxsize=8)
- def __getitem__(self, idx: int) -> Tuple[numpy.int32, numpy.int64, Optional[numpy.int8]]:
- """Return the pointer, length, and mode at the index
-
- Args:
- idx (int): The index into the dataset
-
- Returns:
- Tuple[numpy.int32, numpy.int64, Optional[numpy.int8]]: The pointer, length and mode at
- the index
- """
- return (
- self.sequence_pointers[idx],
- self.sequence_lengths[idx],
- self.sequence_modes[idx] if self.sequence_modes is not None else None,
- )
-
-
-class MMapIndexedDataset(torch.utils.data.Dataset):
- """The low-level interface dataset class
-
- Args:
- path_prefix (str): The index (.idx) and data (.bin) prefix
-
- multimodal (bool, optional): Whether the dataset is multimodal. Defaults to False.
- """
-
- def __init__(self, path_prefix: str, multimodal: bool = False) -> None:
- super().__init__()
- self.path_prefix = None
- self.multimodal = None
-
- self.index = None
- self.bin_buffer = None
- self.bin_buffer_mmap = None
-
- self.initialize(path_prefix, multimodal)
-
- def initialize(self, path_prefix: str, multimodal: bool) -> None:
- """Initialize the dataset
-
- This method is called by MMapIndexedDataset.__init__ during object creation and by
-        MMapIndexedDataset.__setstate__ during un-pickling
-
- Args:
- path_prefix (str): The index (.idx) and data (.bin) prefix
-
- multimodal (bool): Whether the dataset is multimodal
- """
- self.path_prefix = path_prefix
- self.multimodal = multimodal
- self.index = _IndexReader(get_idx_path(self.path_prefix), self.multimodal)
- self.bin_buffer_mmap = numpy.memmap(get_bin_path(self.path_prefix), mode="r", order="C")
- self.bin_buffer = memoryview(self.bin_buffer_mmap)
-
- def __getstate__(self) -> Tuple[str, bool]:
- """Get the state during pickling
-
- Returns:
- Tuple[str, bool]: The state tuple
- """
- return self.path_prefix, self.multimodal
-
- def __setstate__(self, state: Tuple[str, bool]) -> None:
- """Set the state during un-pickling
-
- Args:
- state (Tuple[str, bool]): The state tuple
- """
- path_prefix, multimodal = state
- self.initialize(path_prefix, multimodal)
-
- def __del__(self) -> None:
- """Clean up the object
- """
- if self.bin_buffer_mmap is not None:
- self.bin_buffer_mmap._mmap.close()
- del self.bin_buffer_mmap
- del self.index
-
- def __len__(self) -> int:
- """Return the length of the dataset i.e. the number of sequences in the index
-
- Returns:
- int: The length of the dataset
- """
- return len(self.index)
-
- def __getitem__(
- self, idx: Union[int, numpy.integer, slice]
- ) -> Union[numpy.ndarray, Tuple[numpy.ndarray, numpy.ndarray]]:
- """Return from the dataset
-
- Args:
- idx (Union[int, numpy.integer, slice]): The index or index slice into the dataset
-
- Raises:
- ValueError: When the index slice is non-contiguous
-
- TypeError: When the index is of an unexpected type
-
- Returns:
- Union[numpy.ndarray, Tuple[numpy.ndarray, numpy.ndarray]]: The sequence tokens and
- modes at the index or index slice
- """
- if isinstance(idx, (int, numpy.integer)):
- sequence_pointer, sequence_length, sequence_mode = self.index[idx]
- sequence = numpy.frombuffer(
- self.bin_buffer,
- dtype=self.index.dtype,
- count=sequence_length,
- offset=sequence_pointer,
- )
- return (sequence, sequence_mode) if sequence_mode is not None else sequence
- elif isinstance(idx, slice):
- start, stop, step = idx.indices(len(self))
- if step != 1:
- raise ValueError("Slices into indexed_dataset must be contiguous")
- sequence_lengths = self.index.sequence_lengths[idx]
- sequence_modes = self.index.sequence_modes[idx] if self.multimodal else None
- sequence_offsets = list(accumulate(sequence_lengths))
- sequences = numpy.split(
- numpy.frombuffer(
- self.bin_buffer,
- dtype=self.index.dtype,
- count=sum(sequence_lengths),
- offset=self.index.sequence_pointers[start],
- ),
- sequence_offsets[:-1],
- )
- return (sequences, sequence_modes) if sequence_modes is not None else sequences
- else:
- raise TypeError("Unexpected type received for idx: {}".format(type(idx)))
-
- def get(self, idx: int, offset: int = 0, length: Optional[int] = None) -> numpy.ndarray:
- """Retrieve a single item from the dataset with the option to only
- return a portion of the item.
-
- get(idx) is the same as [idx] but get() does not support slicing.
- """
- sequence_pointer, sequence_length, sequence_mode = self.index[idx]
- if length is None:
- length = sequence_length - offset
- sequence_pointer += offset * DType.size(self.index.dtype)
- sequence = numpy.frombuffer(
- self.bin_buffer, dtype=self.index.dtype, count=length, offset=sequence_pointer
- )
- return (sequence, sequence_mode) if sequence_mode is not None else sequence
-
- @property
- def sequence_lengths(self) -> numpy.ndarray:
- """Get the sequence lengths
-
- Returns:
- numpy.ndarray: The sequence lengths
- """
- return self.index.sequence_lengths
-
- @property
- def document_indices(self) -> numpy.ndarray:
- """Get the document indices
-
- Returns:
- numpy.ndarray: The document indices
- """
- return self.index.document_indices
-
- def get_document_indices(self) -> numpy.ndarray:
- """Get the document indices
-
- This method is slated for deprecation.
-
- Returns:
- numpy.ndarray: The document indices
- """
- return self.index.document_indices
-
- def set_document_indices(self, document_indices: numpy.ndarray) -> None:
- """Set the document indices
-
- This method is slated for deprecation.
-
- Args:
- document_indices (numpy.ndarray): The document indices
- """
- self.index.document_indices = document_indices
-
- @property
- def sequence_modes(self) -> numpy.ndarray:
- """Get the sequence modes
-
- Returns:
- numpy.ndarray: The sequence modes
- """
- return self.index.sequence_modes
-
- @staticmethod
- def exists(path_prefix: str) -> bool:
- """Return whether the MMapIndexedDataset exists on disk at the prefix
-
- Args:
- path_prefix (str): The prefix to the index (.idx) and data (.bin) files
-
- Returns:
- bool: Whether the MMapIndexedDataset exists on disk at the prefix
- """
- return os.path.exists(get_idx_path(path_prefix)) and os.path.exists(
- get_bin_path(path_prefix)
- )
-
-
-class MMapIndexedDatasetBuilder(object):
- """Builder class for the MMapIndexedDataset class
-
- Args:
- bin_path (str): The path to the data (.bin) file
-
- dtype (Type[numpy.number], optional): The dtype of the index file. Defaults to numpy.int32.
-
- multimodal (bool, optional): Whether the dataset is multimodal. Defaults to False.
- """
-
- def __init__(
- self, bin_path: str, dtype: Type[numpy.number] = numpy.int32, multimodal: bool = False
- ) -> None:
- self.data_file = open(bin_path, "wb")
- self.dtype = dtype
- self.multimodal = multimodal
-
- self.sequence_lengths = []
- self.document_indices = [0]
- self.sequence_modes = [] if self.multimodal else None
-
- def add_item(self, tensor: torch.Tensor, mode: int = 0) -> None:
- """Add a single item to the dataset
-
- Args:
- tensor (torch.Tensor): The item to add to the data file
-
- mode (int, optional): The mode for the item. Defaults to 0.
- """
- np_array = numpy.array(tensor.numpy(), dtype=self.dtype)
- self.data_file.write(np_array.tobytes(order="C"))
- self.sequence_lengths.append(np_array.size)
- if self.multimodal:
- self.sequence_modes.append(mode)
-
- def add_document(
- self, tensor: torch.Tensor, lengths: List[int], modes: Optional[List[int]] = None
- ) -> None:
- """Add an entire document to the dataset
-
- Args:
- tensor (torch.Tensor): The document to add
- lengths (List[int]): The lengths of each item in the document
- modes (Optional[List[int]], optional): The modes for each item in the document.
- Defaults to None.
- """
- np_array = numpy.array(tensor, dtype=self.dtype)
- self.data_file.write(np_array.tobytes(order="C"))
- self.sequence_lengths.extend(lengths)
- self.document_indices.append(len(self.sequence_lengths))
- if self.multimodal:
-            self.sequence_modes.extend(modes if modes is not None else [0] * len(lengths))
-
- def end_document(self) -> None:
- """Finalize the document, for use with MMapIndexedDatasetBuilder.add_item
- """
- self.document_indices.append(len(self.sequence_lengths))
-
- def add_index(self, path_prefix: str) -> None:
- """Add an entire MMapIndexedDataset to the dataset
-
- Args:
- path_prefix (str): The index (.idx) and data (.bin) prefix
- """
- # Concatenate index
- index = _IndexReader(get_idx_path(path_prefix), multimodal=self.multimodal)
- assert index.dtype == self.dtype
-
- offset = len(self.sequence_lengths)
- self.sequence_lengths.extend(index.sequence_lengths)
- self.document_indices.extend((offset + index.document_indices)[1:])
-
- if self.multimodal:
- self.sequence_modes.extend(index.sequence_modes)
-
- # Concatenate data
- with open(get_bin_path(path_prefix), "rb") as f:
- shutil.copyfileobj(f, self.data_file)
-
- def finalize(self, idx_path: str) -> None:
- """Clean up and write the index (.idx) file
-
- Args:
- idx_path (str): The path to the index file
- """
- self.data_file.close()
- with _IndexWriter(idx_path, self.dtype) as writer:
- writer.write(self.sequence_lengths, self.sequence_modes, self.document_indices)
-
-
-def get_idx_path(path_prefix: str) -> str:
- """Get the path to the index file from the prefix
-
- Args:
- path_prefix (str): The prefix
-
- Returns:
- str: The path to the index file
- """
- return path_prefix + ".idx"
-
-
-def get_bin_path(path_prefix: str) -> str:
- """Get the path to the data file from the prefix
-
- Args:
- path_prefix (str): The prefix
-
- Returns:
- str: The path to the data file
- """
- return path_prefix + ".bin"
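
As a reference for readers of the deleted module, a minimal round trip through `MMapIndexedDatasetBuilder` and `MMapIndexedDataset` might look as follows; the scratch prefix is illustrative.

```python
# Hedged sketch: write a tiny dataset with MMapIndexedDatasetBuilder and read it back.
import numpy
import torch

from megatron.core.datasets.indexed_dataset import (
    MMapIndexedDataset,
    MMapIndexedDatasetBuilder,
    get_bin_path,
    get_idx_path,
)

prefix = "/tmp/toy_dataset"  # illustrative path, not mandated by the library

builder = MMapIndexedDatasetBuilder(get_bin_path(prefix), dtype=numpy.int32)
builder.add_item(torch.tensor([1, 2, 3, 4]))  # sequence 0
builder.end_document()                        # document 0 = [sequence 0]
builder.add_item(torch.tensor([5, 6]))        # sequence 1
builder.add_item(torch.tensor([7, 8, 9]))     # sequence 2
builder.end_document()                        # document 1 = [sequences 1, 2]
builder.finalize(get_idx_path(prefix))

dataset = MMapIndexedDataset(prefix)
assert len(dataset) == 3          # three sequences in total
print(dataset[0])                 # -> [1 2 3 4]
print(dataset.document_indices)   # -> [0 1 3]
```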
diff --git a/megatron/core/datasets/megatron_dataset.py b/megatron/core/datasets/megatron_dataset.py
deleted file mode 100644
index d75a6455099916128b67fefe1062f028c5071add..0000000000000000000000000000000000000000
--- a/megatron/core/datasets/megatron_dataset.py
+++ /dev/null
@@ -1,135 +0,0 @@
-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-
-import hashlib
-import json
-from abc import ABC, abstractmethod, abstractstaticmethod
-from collections import OrderedDict
-from typing import Dict, List
-
-import numpy
-import torch
-
-from megatron.core.datasets.blended_megatron_dataset_config import BlendedMegatronDatasetConfig
-from megatron.core.datasets.indexed_dataset import MMapIndexedDataset
-from megatron.core.datasets.utils import Split
-
-
-class MegatronDataset(ABC, torch.utils.data.Dataset):
- """The wrapper class from which dataset classes should inherit e.g. GPTDataset
-
- Args:
- indexed_dataset (MMapIndexedDataset): The MMapIndexedDataset around which to build the
- MegatronDataset
-
- indexed_indices (numpy.ndarray): The set of the documents indices to expose
-
- num_samples (int): The number of samples to draw from the indexed dataset
-
- index_split (Split): The indexed_indices Split
-
- config (BlendedMegatronDatasetConfig): The container for all config sourced parameters
- """
-
- def __init__(
- self,
- indexed_dataset: MMapIndexedDataset,
- indexed_indices: numpy.ndarray,
- num_samples: int,
- index_split: Split,
- config: BlendedMegatronDatasetConfig,
- ) -> None:
- assert indexed_indices.size > 0
- assert num_samples > 0
- assert self.is_multimodal() == indexed_dataset.multimodal
- assert self.is_split_by_sequence() != self.is_split_by_document()
-
- self.indexed_dataset = indexed_dataset
- self.indexed_indices = indexed_indices
- self.num_samples = num_samples
- self.index_split = index_split
- self.config = config
-
- self.unique_identifiers = OrderedDict()
- self.unique_identifiers["class"] = type(self).__name__
- self.unique_identifiers["path_prefix"] = self.indexed_dataset.path_prefix
- self.unique_identifiers["num_samples"] = self.num_samples
- self.unique_identifiers["index_split"] = self.index_split.name
- for attr in self._key_config_attributes():
- self.unique_identifiers[attr] = getattr(self.config, attr)
-
- self.unique_description = json.dumps(self.unique_identifiers, indent=4)
- self.unique_description_hash = hashlib.md5(
- self.unique_description.encode("utf-8")
- ).hexdigest()
-
- self._finalize()
-
- @abstractmethod
- def _finalize(self) -> None:
- """Build the dataset and assert any subclass-specific conditions
- """
- pass
-
- @abstractmethod
- def __len__(self) -> int:
- """Return the length of the dataset
-
- Returns:
- int: See abstract implementation
- """
- pass
-
- @abstractmethod
- def __getitem__(self, idx: int) -> Dict[str, numpy.ndarray]:
- """Return from the dataset
-
- Args:
- idx (int): The index into the dataset
-
- Returns:
- Dict[str, numpy.ndarray]: See abstract implementation
- """
- pass
-
- @abstractstaticmethod
- def is_multimodal() -> bool:
- """Return True if the inheritor class and its internal MMapIndexedDataset are multimodal
-
- Returns:
- bool: See abstract implementation
- """
- pass
-
- @abstractstaticmethod
- def is_split_by_sequence() -> bool:
- """Return whether the dataset is split by sequence
-
- For example, the GPT train/valid/test split is document agnostic
-
- Returns:
- bool: See abstract implementation
- """
- pass
-
- @classmethod
- def is_split_by_document(cls) -> bool:
- """Return whether the dataset is split by document
-
- For example, the BERT train/valid/test split is document aware
-
- Returns:
- bool: The negation of cls.is_split_by_sequence
- """
- return not cls.is_split_by_sequence()
-
- @staticmethod
- def _key_config_attributes() -> List[str]:
- """Return all config attributes which contribute to uniquely identifying the dataset.
-
- These attributes will be used to build a uniquely identifying string and MD5 hash which
- will be used to cache/load the dataset from run to run.
-
- Returns:
- List[str]: The key config attributes
- """
- return ["split", "random_seed", "sequence_length"]
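
A skeletal (hypothetical) subclass makes the contract of the abstract class above concrete; the class name and the placeholder sample logic are invented for illustration.

```python
# Hedged sketch of a MegatronDataset subclass; the class name and the trivial
# sample logic are illustrative and not part of the deleted module.
from typing import Dict

import numpy

from megatron.core.datasets.megatron_dataset import MegatronDataset


class ToySequenceDataset(MegatronDataset):
    def _finalize(self) -> None:
        # A real subclass (e.g. GPTDataset) builds/caches its index mappings here.
        self.sequence_length = self.config.sequence_length

    def __len__(self) -> int:
        return self.num_samples

    def __getitem__(self, idx: int) -> Dict[str, numpy.ndarray]:
        # Placeholder sample: a constant-length array of zeros.
        return {"text": numpy.zeros(self.sequence_length, dtype=numpy.int64)}

    @staticmethod
    def is_multimodal() -> bool:
        return False

    @staticmethod
    def is_split_by_sequence() -> bool:
        return True
```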
diff --git a/megatron/core/datasets/readme.md b/megatron/core/datasets/readme.md
deleted file mode 100644
index 77d1e5862f54a9c224d1c4f655883e1b877616f5..0000000000000000000000000000000000000000
--- a/megatron/core/datasets/readme.md
+++ /dev/null
@@ -1,193 +0,0 @@
-# Data Pipeline
-
-## Data pre-processing
-
-Data preprocessing is built around the following classes:
-
-1. `MMapIndexedDatasetBuilder`
-2. `MMapIndexedDataset`
-
-At the moment, an end-to-end data preprocessing implementation is left to the user. See the class docstring(s) for more details.
-
-#### MMapIndexedDatasetBuilder
-
-The `MMapIndexedDatasetBuilder` is capable of building and merging `MMapIndexedDataset` instances.
-
-#### MMapIndexedDataset
-
-The `MMapIndexedDataset` class is the lowest-level data interface in Megatron Core. Internally, an `MMapIndexedDataset` instance references two binaries: the data file (`.bin`) contains document/sequence data and the index file (`.idx`) contains document/sequence metadata.
-
-The index file stores dataset-level metadata first:
-- The index header, for backward compatibility
-- The index version, for backward compatibility
-- A numeric code corresponding to the data type used to write data to the data file
-- The number of sequences in the dataset
-- The number of documents in the dataset
-
-The index file stores document-level and sequence-level metadata second:
-- In order, the number of elements per sequence
-- In order, the byte offset (pointer) per sequence
-- In order, the consecutive sequence index range `[...)` per document
-- In order, the mode per sequence (in the multimodal case)
-
-## Data loading: construction
-
-Building the data loaders is a distributed-aware process built around the following classes:
-
-1. `BlendedMegatronDatasetConfig`
-2. `BlendedMegatronDatasetBuilder`
-3. `MMapIndexedDataset`
-4. `MegatronDataset`
-5. `BlendedDataset`
-
-See the class docstrings for more details.
-
-#### BlendedMegatronDatasetConfig (extendable)
-
-The `BlendedMegatronDatasetConfig` class parameterizes the `BlendedMegatronDatasetBuilder` and in turn the `MegatronDataset` and `BlendedDataset`.
-
-Different training/inference regimes will require different extensions e.g. the `GPTDatasetConfig`
-
-#### BlendedMegatronDatasetBuilder
-
-The `BlendedMegatronDatasetBuilder` class builds the highest-level data interfaces in Megatron Core.
-
-**NB:** All ranks should attempt to build the dataset via the `BlendedMegatronDatasetBuilder` or the program will hang. Which ranks follow through on their attempts can be controlled via the `BlendedMegatronDatasetConfig`.
-
-#### MMapIndexedDataset
-
-The `MMapIndexedDataset` class is the lowest-level data interface in Megatron Core.
-
-The `MMapIndexedDataset` should already exist on disk before attempting to build any of the high-level data interfaces.
-
-
-#### MegatronDataset (extendable)
-
-The `MegatronDataset` abstract class is a high-level data interface in Megatron Core. It is an abstraction built upon the `MMapIndexedDataset`.
-
-Different training/inference regimes will require different extensions e.g. the `GPTDataset`
-
-#### BlendedDataset
-
-The `BlendedDataset` class is a high-level data interface in Megatron Core. It is an abstraction built upon the `MegatronDataset`.
-
-The `BlendedDataset` is only necessary when a blend of multiple data distributions, i.e. multiple `MegatronDataset` instances, should contribute to a certain dataset split. The blend can be controlled via the `BlendedMegatronDatasetConfig`.
-
-## Data loading: implementation
-
-### GPTDataset
-
-The `GPTDataset` is parameterized by the following variables: the underlying `MMapIndexedDataset` instance `indexed_dataset`, the split indices `indexed_indices` (the contiguous subset of document or sequence indices used for training, validation, and testing), the number of samples `N`, the sequence length `S`, and the random seed `R`.
-
-The `GPTDataset` creates three index mappings to facilitate lookup: (1) the document index, (2) the sample index, and (3) the shuffle index.
-
-1. The document index _Do_idx_ is a 1-D array mapping from _i_ to document index of length `E * |indexed_indices|` where `E` corresponds to the minimum number of epochs such that `E * |indexed_indices| >= N`. The document index is shuffled according to `R`.
-
- ```
- Given:
-
- N = 15
- indexed_indices = [5, 6, 7, 8, 9]
- E = 3
-
- Then, for example:
-
- Do_idx = [8, 8, 9, 6, 7, 5, 8, 5, 6, 6, 5, 9, 7, 7, 9]
- ```
-
-2. The sample index _Sa_idx_ is a 2-D array mapping from _j_ to pairs of (_i_, _Do_idx_[ _i_ ] offset) of shape `[N + 1, 2]`. The rows _j_ and _j_ + 1 serve as the left and right bounds for the _j_-th sample.
-
- ```
- Given:
-
- S = 1024
-
- Then, for example:
-
- Sa_idx[0] = (0, 0)
- Sa_idx[1] = (0, 1024) => Do_idx[0] has length greater than S
- Sa_idx[2] = (1, 512) => Do_idx[0] has length 1536
- Sa_idx[3] = (2, 0) => Do_idx[1] has length 1536
- Sa_idx[4] = (5, 300) => Do_idx[2:5] are shorter documents relative to Do_idx[0:2]
- Sa_idx[5] = (6, 24) => Do_idx[5] has length 1300
- ```
-
-3. The shuffle index _Sh_idx_ is a 1-D array mapping from _k_ to _j_ of length `N`. The shuffle index is shuffled according to `R`.
-
- ```
- Given
-
- N = 10
-
- Then, for example:
-
- Sh_idx = [4, 0, 2, 6, 1, 9, 5, 8, 7, 3]
- ```
-
-To query the `GPTDataset` for the _k_-th sample we do the following
-
-- Use the shuffle index to get the index _j_ into the sample index.
-
- ```
- j = Sh_idx[k]
- ```
-- Use the sample index to get the left and right sample-bounding indices into the document index and the starting token offset for each document.
-
- ```
- i, offset = Sa_idx[j]
- i_next, offset_next = Sa_idx[j + 1]
- ```
-- Use the document index to retrieve `S` tokens from consecutive (in the document index) documents.
-
- ```
- sample = []
- sample += indexed_dataset[Do_idx[i]][offset:]
- if i != i_next:
- sample += indexed_dataset[Do_idx[i + 1:i_next]]
- sample += indexed_dataset[Do_idx[i_next]][:offset_next]
- ```
-
-To save time during initialization, each index is built/cached sequentially on one process rank and subsequently loaded in parallel on other process ranks. The cached indices are unique to a hash generated in the `MegatronDataset.__init__` function.
-
-### BlendedDataset
-
-The `BlendedDataset` is parameterized by the following variables: the underlying `MegatronDataset` instances `D`, the weights `W` (one per dataset), and the size `S`. The `BlendedDataset` will draw samples from contributing datasets in proportion to the weights until achieving a composite dataset of the desired size. During each sampling step, we draw a single sample from the dataset which has the greatest sampling error.
-
-The `BlendedDataset` creates two "blending" indices to facilitate lookup: (1) the dataset index and (2) the dataset sample index.
-
-1. The dataset index _Da_idx_ is a 1-D array mapping from _i_ to dataset index of length `S`.
-
- ```
- Given
-
- D = [d0, d1, d2]
- W = [1/2, 1/4, 1/4]
- S = 4
-
- Then, for example:
-
- Da_idx = [0, 1, 2, 0]
-
- ```
-
-2. The dataset sample index _Sa_idx_ is a 1-D array mapping from _i_ to the sample index for dataset _Da_idx[i]_ of length `S`.
-
- ```
- Given
-
- Da_idx = [0, 1, 2, 0]
-
- Then, for example:
-
- Sa_idx = [0, 0, 0, 1]
- ```
-
-To query the `BlendedDataset` for the _k_-th sample we do the following
-
-- Use the dataset index to retrieve the corresponding dataset from `D` and the dataset sample index to retrieve the corresponding sample from that dataset.
-
- ```
- sample = D[Da_idx[k]][Sa_idx[k]]
- ```
-
-To save time during initialization, each index is built/cached sequentially on one process rank and subsequently loaded in parallel on other process ranks. The cached indices are unique to a hash generated in the `BlendedDataset.__init__` function.
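
The three-index lookup described for `GPTDataset` can be condensed into a few lines of NumPy. This is only a paraphrase of the readme's pseudocode with made-up toy indices, not the actual implementation (the label-shift token is omitted).

```python
# Hedged sketch: the Sh_idx -> Sa_idx -> Do_idx lookup chain with toy values.
import numpy as np

Do_idx = np.array([8, 5, 9, 6])                      # i -> document id
Sa_idx = np.array([[0, 0], [0, 5], [1, 4], [2, 4]])  # j -> (i, offset into that document)
Sh_idx = np.array([1, 0, 2])                         # k -> j

documents = {                                        # toy token store, keyed by document id
    8: list(range(100, 106)),
    5: list(range(200, 205)),
    9: list(range(300, 304)),
    6: list(range(400, 407)),
}

k = 0
j = Sh_idx[k]
i, offset = Sa_idx[j]
i_next, offset_next = Sa_idx[j + 1]

if i == i_next:
    sample = documents[Do_idx[i]][offset:offset_next]
else:
    sample = documents[Do_idx[i]][offset:]
    for m in range(i + 1, i_next):
        sample += documents[Do_idx[m]]
    sample += documents[Do_idx[i_next]][:offset_next]

print(sample)  # -> [105, 200, 201, 202, 203], i.e. S = 5 tokens spanning two documents
```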
diff --git a/megatron/core/datasets/utils.py b/megatron/core/datasets/utils.py
deleted file mode 100644
index 8a3279b5f44bf0f31f5850a836f2412771356d07..0000000000000000000000000000000000000000
--- a/megatron/core/datasets/utils.py
+++ /dev/null
@@ -1,60 +0,0 @@
-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-
-import logging
-from enum import Enum
-from typing import List
-
-import numpy
-import torch
-
-logger = logging.getLogger(__name__)
-
-
-class Split(Enum):
- train = 0
- valid = 1
- test = 2
-
-
-def compile_helpers():
- """Compile C++ helper functions at runtime. Make sure this is invoked on a single process.
- """
- import os
- import subprocess
-
- command = ["make", "-C", os.path.abspath(os.path.dirname(__file__))]
- if subprocess.run(command).returncode != 0:
- import sys
-
- log_single_rank(logger, logging.ERROR, "Failed to compile the C++ dataset helper functions")
- sys.exit(1)
-
-
-def log_single_rank(logger: logging.Logger, *args, rank=0, **kwargs):
- """If torch distributed is initialized, log only on rank
-
- Args:
- logger (logging.Logger): The logger to write the logs
-
- rank (int, optional): The rank to write on. Defaults to 0.
- """
- if torch.distributed.is_initialized():
- if torch.distributed.get_rank() == rank:
- logger.log(*args, **kwargs)
- else:
- logger.log(*args, **kwargs)
-
-
-def normalize(weights: List[float]) -> List[float]:
- """Do non-exponentiated normalization
-
- Args:
- weights (List[float]): The weights
-
- Returns:
- List[float]: The normalized weights
- """
- w = numpy.array(weights, dtype=numpy.float64)
- w_sum = numpy.sum(w)
- w = (w / w_sum).tolist()
- return w
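
A quick, hedged illustration of `normalize` and `log_single_rank` outside of an initialized `torch.distributed` context; the logger name is arbitrary.

```python
# Hedged sketch: the two helpers above, used without torch.distributed initialized.
import logging

from megatron.core.datasets.utils import log_single_rank, normalize

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("toy")

weights = normalize([2.0, 1.0, 1.0])
# Without torch.distributed initialized, log_single_rank logs unconditionally.
log_single_rank(logger, logging.INFO, f"normalized blend weights: {weights}")
# -> normalized blend weights: [0.5, 0.25, 0.25]
```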
diff --git a/megatron/core/dist_checkpointing/__init__.py b/megatron/core/dist_checkpointing/__init__.py
deleted file mode 100644
index df08d7eaba8e55e30a2ade39bb8b01ca580854b5..0000000000000000000000000000000000000000
--- a/megatron/core/dist_checkpointing/__init__.py
+++ /dev/null
@@ -1,11 +0,0 @@
-# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
-
-from .core import check_is_distributed_checkpoint
-from .mapping import LocalNonpersitentObject, ShardedTensor
-from .serialization import (
- load,
- load_common_state_dict,
- load_plain_tensors,
- load_tensors_metadata,
- save,
-)
diff --git a/megatron/core/dist_checkpointing/core.py b/megatron/core/dist_checkpointing/core.py
deleted file mode 100644
index f20a0c3a2de730a747b3bf5696b4bdf8b60550ce..0000000000000000000000000000000000000000
--- a/megatron/core/dist_checkpointing/core.py
+++ /dev/null
@@ -1,41 +0,0 @@
-# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
-
-import json
-from dataclasses import asdict, dataclass
-from pathlib import Path
-from typing import Optional
-
-CONFIG_FNAME = 'metadata.json'
-
-
-class CheckpointingException(Exception):
- pass
-
-
-@dataclass
-class CheckpointingConfig:
- """ Documents backends used in the checkpoint. """
-
- sharded_backend: str
- sharded_backend_version: int = 1
- common_backend: str = 'torch'
- common_backend_version: int = 1
-
-
-def check_is_distributed_checkpoint(checkpoint_dir):
- return maybe_load_config(checkpoint_dir) is not None
-
-
-def maybe_load_config(checkpoint_dir: str) -> Optional[CheckpointingConfig]:
- config_path = Path(checkpoint_dir, CONFIG_FNAME)
- if not config_path.exists():
- return None
- with config_path.open() as f:
- config_dict = json.load(f)
- return CheckpointingConfig(**config_dict)
-
-
-def save_config(config: CheckpointingConfig, checkpoint_dir: str):
- config_path = Path(checkpoint_dir, CONFIG_FNAME)
- with config_path.open('w') as f:
- json.dump(asdict(config), f)
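
Putting the three small helpers above together; the checkpoint directory and the backend string are illustrative only.

```python
# Hedged sketch: write and re-read the metadata.json that marks a directory
# as a distributed checkpoint. Path and backend string are illustrative.
import os

from megatron.core.dist_checkpointing.core import (
    CheckpointingConfig,
    check_is_distributed_checkpoint,
    maybe_load_config,
    save_config,
)

ckpt_dir = "/tmp/toy_ckpt"
os.makedirs(ckpt_dir, exist_ok=True)

save_config(CheckpointingConfig(sharded_backend="zarr"), ckpt_dir)
assert check_is_distributed_checkpoint(ckpt_dir)
print(maybe_load_config(ckpt_dir))
# -> CheckpointingConfig(sharded_backend='zarr', sharded_backend_version=1,
#                        common_backend='torch', common_backend_version=1)
```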
diff --git a/megatron/core/dist_checkpointing/dict_utils.py b/megatron/core/dist_checkpointing/dict_utils.py
deleted file mode 100644
index 493a61c91a89a0ee30a7869a8f3a84bef0e1eff4..0000000000000000000000000000000000000000
--- a/megatron/core/dist_checkpointing/dict_utils.py
+++ /dev/null
@@ -1,219 +0,0 @@
-# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
-
-""" Utilities for operating with dicts and lists. """
-
-from collections import defaultdict
-from typing import Any, Callable, Iterable, Optional, Tuple, Union
-
-import torch
-
-
-def extract_matching_values(
- x: Union[dict, list], predicate: Callable[[Any], bool], return_lists_as_dicts: bool = False
-) -> Tuple[Union[dict, list], Union[dict, list]]:
- """ Return matching and nonmatching values. Keeps hierarchy.
-
- Arguments:
- x (Union[dict, list]) : state dict to process. Top-level argument must be a dict or list
- predicate (object -> bool): determines matching values
- return_lists_as_dicts (bool): if True, matching lists will be turned
- into dicts, with keys indicating the indices of original elements.
- Useful for reconstructing the original hierarchy.
- """
-
- def _set_elem(target, k, v):
- if return_lists_as_dicts:
- target[k] = v
- else:
- target.append(v)
-
- if isinstance(x, dict):
- matching_vals = {}
- nonmatching_vals = {}
- for k, v in x.items():
- if isinstance(v, (list, dict)):
- match, nonmatch = extract_matching_values(v, predicate, return_lists_as_dicts)
- if match:
- matching_vals[k] = match
- if nonmatch or not v:
- nonmatching_vals[k] = nonmatch
- elif predicate(v):
- matching_vals[k] = v
- else:
- nonmatching_vals[k] = v
- elif isinstance(x, list):
- matching_vals = {} if return_lists_as_dicts else []
- nonmatching_vals = {} if return_lists_as_dicts else []
- for ind, v in enumerate(x):
- if isinstance(v, (list, dict)) and v:
- match, nonmatch = extract_matching_values(v, predicate, return_lists_as_dicts)
- if match:
- _set_elem(matching_vals, ind, match)
- if nonmatch or not v:
- _set_elem(nonmatching_vals, ind, nonmatch)
- else:
- target = matching_vals if predicate(v) else nonmatching_vals
- _set_elem(target, ind, v)
- else:
- raise ValueError(f'Unexpected top-level object type: {type(x)}')
- return matching_vals, nonmatching_vals
-
-
-def diff(x1: Any, x2: Any, prefix: Tuple = ()) -> Tuple[list, list, list]:
- mismatch = []
- if isinstance(x1, dict) and isinstance(x2, dict):
- only_left = [prefix + (k,) for k in x1.keys() - x2.keys()]
- only_right = [prefix + (k,) for k in x2.keys() - x1.keys()]
- for k in x2.keys() & x1.keys():
- _left, _right, _mismatch = diff(x1[k], x2[k], prefix + (k,))
- only_left.extend(_left)
- only_right.extend(_right)
- mismatch.extend(_mismatch)
- elif isinstance(x1, list) and isinstance(x2, list):
- only_left = list(range(len(x1) - 1, len(x2) - 1, -1))
-        only_right = list(range(len(x2) - 1, len(x1) - 1, -1))
- for i, (v1, v2) in enumerate(zip(x1, x2)):
- _left, _right, _mismatch = diff(v1, v2, prefix + (i,))
- only_left.extend(_left)
- only_right.extend(_right)
- mismatch.extend(_mismatch)
- else:
- only_left = []
- only_right = []
- if isinstance(x1, torch.Tensor) and isinstance(x2, torch.Tensor):
- _is_mismatch = not torch.all(x1 == x2)
- else:
- try:
- _is_mismatch = bool(x1 != x2)
- except RuntimeError:
- _is_mismatch = True
-
- if _is_mismatch:
- mismatch.append((prefix, type(x1), type(x2)))
-
- return only_left, only_right, mismatch
-
-
-def inspect_keys_types(d: dict, prefix: Tuple = (), indent: int = 4):
- print_indent = lambda: print(' ' * indent * len(prefix), end='')
- for k, v in d.items():
- if isinstance(v, dict):
- print_indent()
- print(f'> {k}:')
- inspect_keys_types(v, prefix + (k,), indent)
- else:
- print_indent()
- if isinstance(v, torch.Tensor):
- print(f'> {k}: {type(v)} of shape {v.shape}')
- else:
- print(f'> {k}: {type(v)}')
-
-
-def inspect_types(x: Any, prefix: Tuple = (), indent: int = 4):
- print_indent = lambda: print(' ' * indent * len(prefix), end='')
- if isinstance(x, dict):
- print()
- for k, v in x.items():
- print_indent()
- print(f'> {k}: ', end='')
- inspect_types(v, prefix + (k,), indent)
- elif isinstance(x, list):
- print()
- for i, v in enumerate(x):
- print_indent()
- print(f'- {i}: ', end='')
- inspect_types(v, prefix + (i,), indent)
- else:
- if isinstance(x, torch.Tensor):
- print(f'Tensor of shape {x.shape}')
- else:
- try:
- x_str = str(x)
- except:
- x_str = ''
- if len(x_str) > 30:
- x_str = x_str[:30] + '... (truncated)'
- print(f'[{type(x)}]: {x_str}')
-
-
-def nested_values(x: Union[dict, list]):
- x_iter = x.values() if isinstance(x, dict) else x
- for v in x_iter:
- if isinstance(v, (dict, list)):
- yield from nested_values(v)
- else:
- yield v
-
-
-def nested_items_iter(x: Union[dict, list]):
- x_iter = x.items() if isinstance(x, dict) else enumerate(x)
- for k, v in x_iter:
- if isinstance(v, (dict, list)):
- yield from nested_items_iter(v)
- else:
- yield x, k, v
-
-
-def dict_map(f: Callable, d: dict):
- for sub_d, k, v in nested_items_iter(d):
- sub_d[k] = f(v)
-
-
-def dict_map_with_key(f: Callable, d: dict):
- for sub_d, k, v in nested_items_iter(d):
- sub_d[k] = f(k, v)
-
-
-def dict_list_map_inplace(f: Callable, x: Union[dict, list]):
- if isinstance(x, dict):
- for k, v in x.items():
- x[k] = dict_list_map_inplace(f, v)
- elif isinstance(x, list):
- x[:] = (dict_list_map_inplace(f, v) for v in x)
- else:
- return f(x)
- return x
-
-
-def dict_list_map_outplace(f: Callable, x: Union[dict, list]):
- if isinstance(x, dict):
- return {k: dict_list_map_outplace(f, v) for k, v in x.items()}
- elif isinstance(x, list):
- return [dict_list_map_outplace(f, v) for v in x]
- else:
- return f(x)
-
-
-def merge(x1: dict, x2: dict, key: Tuple[str, ...] = ()):
- if isinstance(x1, dict) and isinstance(x2, dict):
- for k, v2 in x2.items():
- if k not in x1:
- x1[k] = v2
- else:
- x1[k] = merge(x1[k], v2, key=key + (k,))
- elif isinstance(x1, list) and isinstance(x2, list):
- if len(x1) != len(x2):
- raise ValueError(
- f'Cannot merge two lists with different lengths ({len(x1)} and {len(x2)}, encountered at level {key})'
- )
- for i, v2 in enumerate(x2):
- x1[i] = merge(x1[i], v2, key=key + (i,))
- else:
- raise ValueError(
- f'Duplicate non-dict and non-list values encountered: `{x1}` and `{x2}` (at level {key})'
- )
- return x1
-
-
-def map_reduce(
- xs: Iterable,
- key_fn: Callable = lambda x: x,
- value_fn: Callable = lambda x: x,
- reduce_fn: Callable = lambda x: x,
-) -> dict:
- res = defaultdict(list)
- for x in xs:
- res[key_fn(x)].append(value_fn(x))
- for k in res:
- res[k] = reduce_fn(res[k])
- return dict(res)
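
A toy demonstration of how the nested-dict helpers above compose; the state dict contents are invented.

```python
# Hedged sketch: a toy nested state dict run through the helpers above.
import torch

from megatron.core.dist_checkpointing.dict_utils import (
    dict_list_map_outplace,
    extract_matching_values,
    merge,
    nested_values,
)

state = {"layer": {"weight": torch.ones(2, 2), "bias": torch.zeros(2), "steps": 7}}

# Split tensor leaves from everything else while keeping the hierarchy.
tensors, rest = extract_matching_values(state, lambda v: isinstance(v, torch.Tensor))

# Out-of-place transformation of every leaf in the tensor half.
shapes = dict_list_map_outplace(lambda t: tuple(t.shape), tensors)
print(shapes)                     # {'layer': {'weight': (2, 2), 'bias': (2,)}}
print(list(nested_values(rest)))  # [7]

# Merge the two halves back into a single dict.
print(merge(rest, shapes))        # {'layer': {'steps': 7, 'weight': (2, 2), 'bias': (2,)}}
```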
diff --git a/megatron/core/dist_checkpointing/mapping.py b/megatron/core/dist_checkpointing/mapping.py
deleted file mode 100644
index 2b4d5677d37c01f7196d887f8e7a767d7e0c581b..0000000000000000000000000000000000000000
--- a/megatron/core/dist_checkpointing/mapping.py
+++ /dev/null
@@ -1,308 +0,0 @@
-# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
-
-""" Core library classes. """
-import logging
-from dataclasses import dataclass, replace
-from itertools import chain
-from typing import Any, Callable, Dict, Optional, Tuple, Union
-
-import numpy as np
-import torch
-
-from .core import CheckpointingException
-from .dict_utils import dict_list_map_inplace, dict_list_map_outplace
-
-logger = logging.getLogger(__name__)
-
-# These type definitions are just hints to differentiate a plain model state
-# dict (StateDict) from a state dict with tensors replaced with ShardedTensors
-# (ShardedStateDict).
-StateDict = Dict[str, Any]
-ShardedStateDict = Dict[str, Any]
-ReplicaId = Union[int, Tuple[int, ...]]
-
-
-@dataclass
-class ShardedTensor:
- """Represents a mapping between a local tensor and a global tensor.
-
- Global tensor is assumed to consist of many local tensors distributed
- between different processes.
-
- Attributes:
- key: unique identifier of a global tensor
- data: local tensor data. Can be None only for consistency validation
- dtype: tensor dtype
- local_shape: local tensor shape
- global_shape: global tensor shape
- global_offset: offset of a local tensor in a global tensor, specified
- in number of tensor elements
- axis_fragmentations: global tensor fragmentation of each axis
- replica_id: indicates given local tensor's replication wrt. local
- tensors in different processes
- prepend_axis_num: number of axes prepended to the local tensor
- to reflect global tensor shape.
- The behavior is similar to unsqueezing the local tensor.
- allow_shape_mismatch: if True, during loading, the global shape of a
- stored tensor does not have to match the expected global shape.
- Useful for representing tensors with flexible shape, e.g. padded.
- flattened_range: specifies a slice that should be applied to a flattened
- tensor with `local_shape` in order to get the tensor stored as `data`
- """
-
- key: str
- data: Optional[torch.Tensor]
- dtype: torch.dtype
- local_shape: Tuple[int, ...]
- global_shape: Tuple[int, ...]
- global_offset: Tuple[int, ...]
- axis_fragmentations: Optional[Tuple[int, ...]]
- replica_id: ReplicaId = 0
- prepend_axis_num: int = 0
- allow_shape_mismatch: bool = False
- flattened_range: Optional[slice] = None
-
- def global_slice(self) -> Tuple[Union[int, slice], ...]:
- assert len(self.global_offset) == len(self.local_shape) + self.prepend_axis_num
- return tuple(
- chain(
- (off for off in self.global_offset[: self.prepend_axis_num]),
- (
- slice(off, off + sh)
- for off, sh in zip(
- self.global_offset[self.prepend_axis_num :], self.local_shape
- )
- ),
- )
- )
-
- def global_coordinates(self) -> Tuple[np.ndarray, ...]:
- if self.flattened_range is None:
- raise CheckpointingException(
- f'`global_coordinates` is undefined for'
- f' {self.__class__.__name__} without `flattened_range`'
- )
-
- local_coords = self.local_coordinates()
- assert len(local_coords) + self.prepend_axis_num == len(self.global_offset), (
- len(local_coords),
- self,
- )
- global_coords = tuple(
- c + off
- for c, off in zip((0,) * self.prepend_axis_num + local_coords, self.global_offset)
- )
- return global_coords
-
- def local_coordinates(self) -> Tuple[np.ndarray, ...]:
- if self.flattened_range is None:
- raise CheckpointingException(
- f'`local_coordinates` is undefined for'
- f' {self.__class__.__name__} without `flattened_range`'
- )
-
- # TODO: np.unravel_index?
- mask = np.zeros(np.product(self.local_shape), dtype=bool)
- mask[self.flattened_range] = True
- return np.nonzero(mask.reshape(self.local_shape))
-
- def max_allowed_chunks(self) -> Tuple[int, ...]:
- chunks = []
- for axis_sh, axis_fragm in zip(self.global_shape, self.axis_fragmentations):
- if not self.allow_shape_mismatch and axis_sh % axis_fragm != 0:
- raise CheckpointingException(
- f'Axis shape ({axis_sh}) not divisible' f' by axis fragmentation ({axis_fragm}'
- )
- axis_chunk_size = axis_sh // axis_fragm
- chunks.append(axis_chunk_size)
- return tuple(chunks)
-
- def without_data(self):
- return replace(self, data=None)
-
- @classmethod
- def from_rank_offsets(
- cls,
- key: str,
- data: torch.Tensor,
- *rank_offsets: Tuple[int, int, int],
- replica_id: ReplicaId = 0,
- prepend_axis_num: int = 0,
- allow_shape_mismatch: bool = False,
- ):
- """Allows to construct the ShardedTensor given offset specified in process ranks.
- Arguments:
- key: unique key
- data: local tensor data
- rank_offsets: each tuple (axis, axis_rank_offset, axis_fragm)
- says that if global tensor is divided into `axis_fragm`
- fragment along `axis` axis, then local tensor data
- corresponds to the `axis_rank_offset` chunk.
- replica_id: see ShardedTensor
- prepend_axis_num: see ShardedTensor
- allow_shape_mismatch: see ShardedTensor
- """
- global_offset = [0] * (data.ndim + prepend_axis_num)
- global_shape = ([1] * prepend_axis_num) + list(data.shape)
- axis_fragmentations = [1] * (data.ndim + prepend_axis_num)
- _seen_axis = set()
- for axis, axis_rank_offset, axis_fragm in rank_offsets:
- assert axis >= 0 and axis_rank_offset >= 0 and axis_fragm >= 0, (
- axis,
- axis_rank_offset,
- axis_fragm,
- )
- assert (
- axis_rank_offset < axis_fragm
- ), 'Rank offset must be lower than axis fragmentation'
- if axis in _seen_axis:
- raise CheckpointingException('Duplicated axis specified')
- _seen_axis.add(axis)
-
- local_axis_shape = 1 if axis < prepend_axis_num else data.shape[axis - prepend_axis_num]
- global_shape[axis] = axis_fragm * local_axis_shape
- global_offset[axis] = axis_rank_offset * local_axis_shape
- axis_fragmentations[axis] = axis_fragm
-
- return cls(
- key,
- data,
- data.dtype,
- tuple(data.shape),
- tuple(global_shape),
- tuple(global_offset),
- tuple(axis_fragmentations),
- replica_id,
- prepend_axis_num,
- allow_shape_mismatch,
- )
-
- def __str__(self):
- return f'{self.__class__.__name__}(key=\'{self.key}\')'
-
-
-def is_main_replica(replica_id):
- if isinstance(replica_id, int):
- return replica_id == 0
- return all(r == 0 for r in replica_id)
-
-
-class LocalNonpersitentObject:
- """Object that should not be stored in a checkpoint, but restored locally.
-
- Wrapping any object inside the state dict with LocalNonpersitentObject
- will result in:
- - during saving, this object will *not* be stored in the checkpoint
- - during loading, a local version of this object will be placed in a state dict
- """
-
- def __init__(self, obj):
- self.obj = obj
-
- def unwrap(self):
- return self.obj
-
-
-@dataclass
-class ShardedObject:
- """Represents a mapping between a local object and a global object.
-
- Global object is assumed to consist of many local objects distributed
- between different processes.
-
- NOTE: Contrary to ShardedTensor, it's impossible to change global object
- sharding. Conceptually, ShardedObject is a fully-sharded ShardedTensor
- with atomic arbitrary typed elements.
-
- Attributes:
- key: unique identifier of a global tensor
- data: local object data. Can be None only for consistency validation
- global_shape: global object shape
- global_offset: offset of a local object in a global object, specified
- in number of shards
- replica_id: indicates local object replication wrt. local
- objects in different processes
- """
-
- key: str
- data: object
- global_shape: Tuple[int, ...]
- global_offset: Tuple[int, ...]
- replica_id: ReplicaId = 0
-
- def without_data(self):
- return replace(self, data=None)
-
- @property
- def unique_key(self):
- return f'{self.key}/shard_{".".join(map(str, self.global_offset))}_{".".join(map(str, self.global_shape))}'
-
- def __str__(self):
- return f'{self.__class__.__name__}(key=\'{self.key}\')'
-
-
-@dataclass
-class ShardedTensorFactory:
- """ Allows to apply transformations to tensors before/after serialization.
-
- The essence of those transformations is that they can be applied to
- optimizer states the same way they are applied to the model params.
-
- Builder creates a sub-state-dict out of a tensor before saving, and merger
- merges the corresponding state dict after loading.
- """
-
- key: str
- data: torch.Tensor
- build_fn: Callable[[str, torch.Tensor], ShardedStateDict]
- merge_fn: Callable[[StateDict], torch.Tensor]
-
- def build(self):
- return self.build_fn(self.key, self.data)
-
-
-def apply_factories(sharded_state_dict: ShardedStateDict):
- def apply(x):
- if isinstance(x, ShardedTensorFactory):
- x = x.build()
- return x
-
- dict_list_map_inplace(apply, sharded_state_dict)
-
-
-def apply_factory_merges(x1: StateDict, x2: ShardedStateDict, key: Tuple[str, ...] = ()):
- if isinstance(x2, ShardedTensorFactory):
- return x2.merge_fn(x1)
-
-    # The rest is almost the same as the `merge` function from `dict_utils`
- if isinstance(x1, dict) and isinstance(x2, dict):
- for k, v2 in x2.items():
- if k not in x1:
- raise ValueError(
- f'Different dict keys encountered in `apply_factory_merges` ({x1.keys()} vs {x2.keys()})'
- )
- else:
- x1[k] = apply_factory_merges(x1[k], v2, key=key + (k,))
- elif isinstance(x1, list) and isinstance(x2, list):
- if len(x1) != len(x2):
- err_msg = f'Cannot merge two lists with different lengths ({len(x1)} and {len(x2)}, encountered at key {key})'
- logger.error(err_msg + f'\nx1: {x1}\nx2: {x2}')
- raise ValueError(err_msg)
- for i, v2 in enumerate(x2):
- x1[i] = apply_factory_merges(x1[i], v2, key=key + (i,))
- elif isinstance(x1, list) and isinstance(x2, dict):
- for k, v2 in x2.items():
- if not isinstance(k, int):
- raise ValueError(
- f'Invalid dict key {k} non-integer type encountered in a list-dict merge at level {key}'
- )
- if k >= len(x1):
- raise ValueError(
- f'Dict key {k} out of bound for list of length {len(x1)} (encountered at level {key})'
- )
- x1[k] = apply_factory_merges(x1[k], v2, key=key + (k,))
- else:
- raise ValueError(
- f'Duplicate non-dict and non-list values encountered: `{x1}` and `{x2} (at key {key})`'
- )
- return x1
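
A concrete, hedged reading of `ShardedTensor.from_rank_offsets`: two ranks splitting a `(4, 8)` weight along axis 0 would each describe their shard roughly like this (the key and the hard-coded rank are illustrative).

```python
# Hedged sketch: describe one shard of a (4, 8) weight split 2 ways along axis 0.
# `rank` would normally come from torch.distributed; here it is hard-coded.
import torch

from megatron.core.dist_checkpointing.mapping import ShardedTensor

rank = 0                          # pretend this process owns the first shard
local_weight = torch.randn(2, 8)  # local half of the global (4, 8) tensor

sh_ten = ShardedTensor.from_rank_offsets(
    "decoder.linear.weight",      # key shared by all ranks (illustrative)
    local_weight,
    (0, rank, 2),                 # axis 0, this rank's offset, 2 fragments
)
print(sh_ten.global_shape)    # (4, 8)
print(sh_ten.global_offset)   # (0, 0) here; (2, 0) on rank 1
print(sh_ten.global_slice())  # (slice(0, 2, None), slice(0, 8, None))
```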
diff --git a/megatron/core/dist_checkpointing/optimizer.py b/megatron/core/dist_checkpointing/optimizer.py
deleted file mode 100644
index d1c698787c4678009f09b5496fa4c5ddc17574d8..0000000000000000000000000000000000000000
--- a/megatron/core/dist_checkpointing/optimizer.py
+++ /dev/null
@@ -1,90 +0,0 @@
-# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
-
-""" Optimizer related helpers. """
-
-import logging
-from copy import deepcopy
-from dataclasses import replace
-from itertools import chain
-from typing import Dict, Iterable, List, Tuple, Union
-
-logger = logging.getLogger(__name__)
-
-import torch
-
-from .dict_utils import nested_values
-from .mapping import (
- LocalNonpersitentObject,
- ShardedStateDict,
- ShardedTensor,
- ShardedTensorFactory,
- StateDict,
-)
-from .utils import extract_sharded_tensors, extract_sharded_tensors_and_factories
-
-
-def get_optim_param_to_id_map(optim_params_iter: Iterable[torch.nn.Parameter]) -> Dict[int, int]:
- param_mappings = {}
- for i, param in enumerate(optim_params_iter):
- if id(param) not in param_mappings:
- param_mappings[id(param)] = i
- return param_mappings
-
-
-def get_param_id_to_sharded_param_map(
- model_sharded_state_dict: ShardedStateDict, optim_params_iter: Iterable[torch.nn.Parameter]
-) -> Dict[int, Union[ShardedTensor, ShardedTensorFactory]]:
- model_sharded_state_dict, _ = extract_sharded_tensors_and_factories(model_sharded_state_dict)
- id_to_sharded_param_map = {}
- param_to_id_map = get_optim_param_to_id_map(optim_params_iter)
- for ten in nested_values(model_sharded_state_dict):
- if id(ten.data) in param_to_id_map:
- id_to_sharded_param_map[param_to_id_map[id(ten.data)]] = ten
- else:
- logger.debug(f'{ten} is not tracked by the optimizer')
-
- if not id_to_sharded_param_map:
- logger.warning(
- "Sharded parameters mapping is empty. It means tensors in model state dict"
- " do not correspond to tensors in optimizer parameters map."
- " Make sure to call state_dict with `keep_vars=True`."
- )
- return id_to_sharded_param_map
-
-
-def make_sharded_optimizer_tensor(
- model_param: Union[ShardedTensor, ShardedTensorFactory], optim_param: torch.Tensor, prefix: str
-) -> Union[ShardedTensor, ShardedTensorFactory]:
- if isinstance(model_param, ShardedTensorFactory):
- return replace(model_param, key=f'{prefix}.{model_param.key}', data=optim_param)
-
- assert (
- tuple(optim_param.shape) == model_param.local_shape
- ), f'Optimizer shape ({tuple(optim_param.shape)} does not match model shape ({model_param.local_shape})'
- return replace(
- model_param, key=f'{prefix}.{model_param.key}', data=optim_param, dtype=optim_param.dtype
- )
-
-
-def optim_state_to_sharding_state(
- optim_state_dict: StateDict,
- id_to_sharded_param_map: Dict[int, ShardedTensor],
- exclude_keys: Tuple[str] = (),
-):
- sharded_state = {}
- for param_id, param_state in optim_state_dict['state'].items():
- sharded_state[param_id] = {}
- for state_key, param in param_state.items():
- if state_key in exclude_keys:
- continue
- if param_id in id_to_sharded_param_map:
- sharded_state[param_id][state_key] = make_sharded_optimizer_tensor(
- id_to_sharded_param_map[param_id], param, prefix=f'optimizer.state.{state_key}'
- )
- else:
- raise ValueError(f'Param id {param_id} does not match any model sharded param')
-
- optim_state_dict['param_groups'] = deepcopy(optim_state_dict['param_groups'])
- for group in optim_state_dict['param_groups']:
- group['params'] = LocalNonpersitentObject(group['params'])
- optim_state_dict['state'] = sharded_state
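
The helpers above are meant to be chained when building an optimizer's sharded state dict. A hedged sketch of that flow follows; `model`, `optimizer`, and `model.sharded_state_dict()` are placeholders, and excluding the scalar `step` state is an assumption.

```python
# Hedged sketch of the intended call chain; `model` and `optimizer` are placeholders,
# `model.sharded_state_dict()` is assumed to return ShardedTensors built with
# keep_vars=True (so parameter ids line up), and excluding 'step' is an assumption.
from megatron.core.dist_checkpointing.optimizer import (
    get_param_id_to_sharded_param_map,
    optim_state_to_sharding_state,
)


def build_sharded_optimizer_state(model, optimizer):
    id_to_sharded = get_param_id_to_sharded_param_map(
        model.sharded_state_dict(),
        (p for group in optimizer.param_groups for p in group["params"]),
    )
    optim_sd = optimizer.state_dict()
    # Rewrites optim_sd['state'] into ShardedTensors and wraps 'param_groups'
    # params in LocalNonpersitentObject, as shown above.
    optim_state_to_sharding_state(optim_sd, id_to_sharded, exclude_keys=("step",))
    return optim_sd
```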
diff --git a/megatron/core/dist_checkpointing/serialization.py b/megatron/core/dist_checkpointing/serialization.py
deleted file mode 100644
index 85baa16c21eb784369fd87adb8239cd0ac68d05f..0000000000000000000000000000000000000000
--- a/megatron/core/dist_checkpointing/serialization.py
+++ /dev/null
@@ -1,385 +0,0 @@
-# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
-
-import logging
-import os
-from collections import Counter, defaultdict
-from itertools import chain
-from pathlib import Path
-from typing import Iterable, List, Tuple, Union
-
-import numpy as np
-import torch
-
-from .core import CheckpointingConfig, maybe_load_config, save_config
-from .dict_utils import (
- dict_list_map_inplace,
- diff,
- extract_matching_values,
- map_reduce,
- merge,
- nested_values,
-)
-from .mapping import (
- CheckpointingException,
- ShardedObject,
- ShardedStateDict,
- ShardedTensor,
- ShardedTensorFactory,
- StateDict,
- apply_factories,
- apply_factory_merges,
- is_main_replica,
-)
-from .strategies.base import (
- LoadCommonStrategy,
- LoadShardedStrategy,
- SaveCommonStrategy,
- SaveShardedStrategy,
- StrategyAction,
- get_default_strategy,
-)
-from .utils import extract_sharded_tensors, extract_sharded_tensors_or_nonpersistent
-
-COMMON_STATE_FNAME = 'common.pt'
-
-logger = logging.getLogger(__name__)
-
-
-def load(
- sharded_state_dict: ShardedStateDict,
- checkpoint_dir: str,
- sharded_strategy: Union[LoadShardedStrategy, None] = None,
- common_strategy: Union[LoadCommonStrategy, None] = None,
- validate_access_integrity: bool = True,
-) -> StateDict:
- """Loading entrypoint.
-
- Arguments:
- sharded_state_dict (ShardedStateDict): state dict of the existing model
- populated with ShardedTensors. Used as a mapping to determine which
- parts of global tensors stored in the checkpoint should be loaded.
- checkpoint_dir (str): directory with the checkpoint
- sharded_strategy (LoadShardedStrategy, optional): configures loading behavior for sharded tensors
- common_strategy (LoadCommonStrategy, optional): configures loading behavior for common data
- validate_access_integrity (bool default = True): checks if each tensor shard is accessed
- exactly once (as main replica) by some process
- """
- if common_strategy is not None:
- raise NotImplementedError('The only supported common strategy is torch')
-
- checkpoint_dir = Path(checkpoint_dir)
- common_state_dict = load_common_state_dict(checkpoint_dir)
- if not sharded_state_dict:
- return common_state_dict
-
- sharded_objects, sharded_state_dict = load_sharded_objects(sharded_state_dict, checkpoint_dir)
- merge(common_state_dict, sharded_objects)
-
- saved_config = maybe_load_config(checkpoint_dir)
- if saved_config is None:
- raise CheckpointingException(f'{checkpoint_dir} is not a distributed checkpoint')
-
- sh_ten_factories, _ = extract_matching_values(
- sharded_state_dict,
- lambda x: isinstance(x, ShardedTensorFactory),
- return_lists_as_dicts=True,
- )
- apply_factories(sharded_state_dict)
- sharded_state_dict, _ = extract_sharded_tensors_or_nonpersistent(sharded_state_dict)
- sharded_state_dict, nonpersistent_state_dict = extract_sharded_tensors(sharded_state_dict)
- dict_list_map_inplace(lambda o: o.unwrap(), nonpersistent_state_dict)
- merge(common_state_dict, nonpersistent_state_dict)
-
- if validate_access_integrity:
- validate_sharding_integrity(nested_values(sharded_state_dict))
-
- if sharded_strategy is None:
- sharded_strategy = get_default_strategy(
- StrategyAction.LOAD_SHARDED,
- saved_config.sharded_backend,
- saved_config.sharded_backend_version,
- )
- else:
- # TODO: implement consistency checks here
- pass
- loaded_state_dict = sharded_strategy.load(sharded_state_dict, checkpoint_dir)
-
- loaded_state_dict = apply_factory_merges(loaded_state_dict, sh_ten_factories)
-
- merge(common_state_dict, loaded_state_dict)
- return common_state_dict
-
-
-# TODO: implement it as common torch strategy
-def load_common_state_dict(checkpoint_dir: Path):
- return torch.load(Path(checkpoint_dir) / COMMON_STATE_FNAME, map_location='cpu')
-
-
-def load_sharded_objects(sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
- sharded_objects, sharded_state_dict = extract_matching_values(
- sharded_state_dict, lambda v: isinstance(v, ShardedObject)
- )
-
- def load_sharded_object(sh_obj: ShardedObject):
- sh_obj.data = None
- load_path = (checkpoint_dir / sh_obj.unique_key).with_suffix('.pt')
- loaded_obj = torch.load(load_path)
- return loaded_obj
-
- return dict_list_map_inplace(load_sharded_object, sharded_objects), sharded_state_dict
-
-
-def load_tensors_metadata(
- checkpoint_dir: str, sharded_strategy: Union[LoadShardedStrategy, None] = None
-) -> ShardedStateDict:
- """Load tensors metadata from the checkpoint.
-
- Returns a dictionary similar to a sharded state dict, but note that
- the dictionary keys are simply ShardedTensor keys (contrary to the
- actual sharded state dicts where keys correspond to state dict keys).
-
- Dict values are ShardedTensors without any sharding (so, the only useful
- information is tensors global shape and dtype).
-
- Concrete implementation depends on the loading strategy. If no strategy is
- given, a default for a given backend is used.
- """
- saved_config = maybe_load_config(checkpoint_dir)
- if saved_config is None:
- raise CheckpointingException(f'{checkpoint_dir} is not a distributed checkpoint')
-
- if sharded_strategy is None:
- sharded_strategy = get_default_strategy(
- StrategyAction.LOAD_SHARDED,
- saved_config.sharded_backend,
- saved_config.sharded_backend_version,
- )
- else:
- # TODO: implement consistency checks here
- pass
- return sharded_strategy.load_tensors_metadata(Path(checkpoint_dir))
-
-
-def load_plain_tensors(checkpoint_dir: str):
- """Load checkpoint tensors without any sharding.
-
- NOTE: common state dict is NOT included."""
- sharded_state_dict = load_tensors_metadata(checkpoint_dir)
- # Don't validate integrity because shards will be overlapped
- # if world_size > 1 (all processes load whole tensors)
- return load(sharded_state_dict, checkpoint_dir, validate_access_integrity=False)
-
-
-def save(
- sharded_state_dict: ShardedStateDict,
- checkpoint_dir: str,
- sharded_strategy: Union[SaveShardedStrategy, None] = None,
- common_strategy: Union[SaveCommonStrategy, None] = None,
- validate_access_integrity: bool = True,
-):
- """Saving entrypoint.
-
- Extracts ShardedTensors from the given state dict. Rank 0 saves the
- "regular" part of the checkpoint to common torch file.
- The ShardedTensors are saved according to a strategy specified by the
- config.
-
- Arguments:
- sharded_state_dict (ShardedStateDict): state dict of the populated with
- ShardedTensors. Used as a mapping to determine how local tensors
- should be saved as global tensors in the checkpoint.
- checkpoint_dir (str): directory to save the checkpoint to
- sharded_strategy (SaveShardedStrategy, optional): configures sharded tensors saving behavior and backend
- common_strategy (SaveCommonStrategy, optional): configures common data saving behavior and backend
- validate_access_integrity (bool default = True): checks if each tensor shard is accessed
- exactly once (as main replica) by some process
- """
- checkpoint_dir = Path(checkpoint_dir)
-
- if torch.distributed.get_rank() == 0:
- if not checkpoint_dir.exists():
- raise CheckpointingException(
- f'Checkpoint destination directory does not exist: {checkpoint_dir}'
- )
-
- if next(checkpoint_dir.iterdir(), None) is not None:
- raise CheckpointingException(
- f'Checkpoint destination directory ({checkpoint_dir}) is not empty'
- )
-
- if common_strategy is not None:
- raise NotImplementedError('The only supported common strategy is torch')
-
- if sharded_strategy is None:
- sharded_strategy = get_default_strategy(StrategyAction.SAVE_SHARDED, 'zarr', 1)
-
- apply_factories(sharded_state_dict)
- sharded_state_dict, state_dict = extract_sharded_tensors_or_nonpersistent(sharded_state_dict)
- sharded_state_dict, _ = extract_sharded_tensors(sharded_state_dict)
- sharded_tensors = list(nested_values(sharded_state_dict))
- if validate_access_integrity:
- validate_sharding_integrity(sharded_tensors)
-
- _save_common_dict(state_dict, checkpoint_dir, True)
-
- sharded_strategy.save(sharded_tensors, checkpoint_dir)
- save_config(
- CheckpointingConfig(sharded_strategy.backend, sharded_strategy.version), checkpoint_dir
- )
-
-
-# TODO: implement it as common torch strategy
-def _save_common_dict(
- state_dict: StateDict, checkpoint_dir: Path, validate_consistency: bool = False
-):
- common_state_dict = _extract_and_save_sharded_objects(
- state_dict, checkpoint_dir, validate_consistency
- )
- if torch.distributed.get_rank() == 0:
- torch.save(common_state_dict, checkpoint_dir / COMMON_STATE_FNAME)
- if validate_consistency:
- # TODO: implement checking consistency with rank 0 common dict on other ranks
- pass
- # torch.distributed.barrier()
- # if not torch.distributed.get_rank() == 0:
- # rank_0_state_dict = torch.load(checkpoint_dir / COMMON_STATE_FNAME)
- # print(diff(common_state_dict, rank_0_state_dict))
-
-
-def _extract_and_save_sharded_objects(
- state_dict: StateDict, checkpoint_dir: Path, validate_consistency: bool = False
-):
- sharded_objects, state_dict = extract_matching_values(
- state_dict, lambda v: isinstance(v, ShardedObject)
- )
- sharded_objects = list(nested_values(sharded_objects))
- if validate_consistency:
- validate_objects_sharding_integrity(sharded_objects)
- for sh_obj in sharded_objects:
- if is_main_replica(sh_obj.replica_id):
- save_path = (checkpoint_dir / sh_obj.unique_key).with_suffix('.pt')
- os.makedirs(save_path.parent, exist_ok=True)
- torch.save(sh_obj.data, save_path)
- return state_dict
-
-
-def validate_sharding_integrity(sharded_tensors: Iterable[ShardedTensor]):
- sharding = [ten.without_data() for ten in sharded_tensors]
- all_sharding = [None] * torch.distributed.get_world_size()
- torch.distributed.all_gather_object(all_sharding, sharding)
- if torch.distributed.get_rank() != 0:
- return
-
- key_shardings = defaultdict(list)
- for rank, rank_shardings in enumerate(all_sharding):
- for sharding in rank_shardings:
- key_shardings[sharding.key].append((rank, sharding))
- for key, shardings in key_shardings.items():
- _validate_sharding_for_key(shardings)
-
-
-def _validate_sharding_for_key(rank_sharding: List[Tuple[int, ShardedTensor]]):
- some_rank_shard = rank_sharding[0][1]
- global_shape = some_rank_shard.global_shape
- local_shape = some_rank_shard.local_shape
- dtype = some_rank_shard.dtype
- has_flattened_range = some_rank_shard.flattened_range is not None
- for rank, sharding in rank_sharding:
- assert sharding.dtype == dtype, (sharding.dtype, dtype, some_rank_shard)
- assert sharding.global_shape == global_shape, (
- sharding.global_shape,
- global_shape,
- some_rank_shard,
- )
- assert sharding.local_shape == local_shape, (
- sharding.local_shape,
- local_shape,
- some_rank_shard,
- )
- assert (sharding.flattened_range is not None) == has_flattened_range, (
- (sharding.flattened_range is not None),
- has_flattened_range,
- some_rank_shard,
- )
-
- shard_access_cnt = _compute_shards_access(rank_sharding)
- if has_flattened_range:
- map_reduce(
- rank_sharding,
- lambda x: x[1].global_offset,
- lambda x: x[1],
- _validate_sharding_for_key_flattened,
- )
- else:
- if not torch.all(shard_access_cnt == 1):
- logger.error(f'Invalid access pattern for {rank_sharding[0][1]}: {shard_access_cnt}')
- raise CheckpointingException(f'Invalid access pattern for {rank_sharding[0][1]}')
-
-
-def _compute_shards_access(rank_sharding):
- def chunk_offset(sharding):
- assert len(sharding.global_offset) == len(sharding.local_shape) + sharding.prepend_axis_num
- return tuple(
- chain(
- (off for off in sharding.global_offset[: sharding.prepend_axis_num]),
- (
- off // sh
- for off, sh in zip(
- sharding.global_offset[sharding.prepend_axis_num :], sharding.local_shape
- )
- ),
- )
- )
-
- shard_access_cnt = torch.zeros(
- rank_sharding[0][1].axis_fragmentations, dtype=torch.int, device='cpu'
- )
- for rank, sharding in rank_sharding:
- if is_main_replica(sharding.replica_id):
- shard_access_cnt[chunk_offset(sharding)] += 1
- # TODO: consider validating different replicas too
- return shard_access_cnt
-
-
-def _validate_sharding_for_key_flattened(tensors_by_shard):
- all_slices = []
- local_shape = tensors_by_shard[0].local_shape
- for sharding in tensors_by_shard:
- assert sharding.local_shape == local_shape
- sharding: ShardedTensor
- if not is_main_replica(sharding.replica_id):
- # TODO: this checks only saving (and loading replica_id=0) consistency
- continue
-
- all_slices.append((sharding.flattened_range.start, sharding.flattened_range.stop))
-
- starts, stops = map(np.asarray, zip(*sorted(all_slices)))
- if (
- starts[0] != 0
- or stops[-1] != np.product(local_shape)
- or not np.all(starts[1:] == stops[:-1])
- ):
- logger.error(
- f'Flattened ranges dont cover the whole shard {tensors_by_shard[0]}. Ranges: {(starts, stops)}'
- )
- raise CheckpointingException(
- f'Flattened ranges dont cover the whole shard {tensors_by_shard[0]}'
- )
-
-
-def validate_objects_sharding_integrity(sharded_objects: List[ShardedObject]):
- """ Ensure uniqueness of saved objects. """
- local_sh_objs = [sh_obj.without_data() for sh_obj in sharded_objects]
- all_sh_objs = [None] * torch.distributed.get_world_size()
- torch.distributed.all_gather_object(all_sh_objs, local_sh_objs)
- if torch.distributed.get_rank() != 0:
- return
- unique_keys = [
- sh_obj.unique_key
- for sh_obj in chain.from_iterable(all_sh_objs)
- if is_main_replica(sh_obj.replica_id)
- ]
- if len(unique_keys) != len(set(unique_keys)):
- duplicates = {k: cnt for k, cnt in Counter(unique_keys).items() if cnt > 1}
- logger.error(f'Duplicate ShardedObject keys and counts: {duplicates}')
- raise CheckpointingException(f'Duplicate ShardedObject keys: {list(duplicates.keys())}')
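The removed `serialization.py` exposes `save()` and `load()` as the distributed-checkpointing entrypoints: `save()` writes the common part to `common.pt` on rank 0 and hands sharded tensors to a backend strategy, while `load()` uses a sharded state dict as the mapping of what to read back. A minimal usage sketch, assuming `torch.distributed` is initialized, `zarr`/`tensorstore` are installed for the default backend, and `ShardedTensor.from_rank_offsets` is exported by `megatron.core.dist_checkpointing` as in upstream Megatron-LM (not shown in this diff):

```python
# Single-process sketch of the save()/load() entrypoints defined above.
import os

import torch
import torch.distributed as dist

from megatron.core import dist_checkpointing


def demo(ckpt_dir: str = '/tmp/dist_ckpt_demo'):
    if not dist.is_initialized():
        # Real training reuses the existing process group; this is a 1-rank demo.
        dist.init_process_group(backend='gloo', init_method='tcp://127.0.0.1:29500',
                                rank=0, world_size=1)
    os.makedirs(ckpt_dir, exist_ok=True)  # save() requires an existing, empty directory

    weight = torch.ones(4, 8)
    sharded_sd = {
        # Each rank declares which slice of the global tensor it owns (trivial here: 1 rank).
        'model.weight': dist_checkpointing.ShardedTensor.from_rank_offsets('model.weight', weight),
    }
    dist_checkpointing.save(sharded_sd, ckpt_dir)  # rank 0 also writes common.pt and the config

    template = {
        'model.weight': dist_checkpointing.ShardedTensor.from_rank_offsets(
            'model.weight', torch.empty(4, 8)
        ),
    }
    loaded = dist_checkpointing.load(template, ckpt_dir)
    print(loaded['model.weight'].shape)  # torch.Size([4, 8])


if __name__ == '__main__':
    demo()
```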
diff --git a/megatron/core/dist_checkpointing/strategies/__init__.py b/megatron/core/dist_checkpointing/strategies/__init__.py
deleted file mode 100644
index 7177d973cfccdbf52dbc1fd8efc153f99e389eec..0000000000000000000000000000000000000000
--- a/megatron/core/dist_checkpointing/strategies/__init__.py
+++ /dev/null
@@ -1,16 +0,0 @@
-# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
-
-""" Various loading and saving strategies """
-
-import logging
-
-logger = logging.getLogger(__name__)
-
-try:
- import tensorstore
- import zarr
-
- from .tensorstore import _import_trigger
- from .zarr import _import_trigger
-except ImportError:
- logger.warning('Zarr-based strategies will not be registered because of missing packages')
diff --git a/megatron/core/dist_checkpointing/strategies/base.py b/megatron/core/dist_checkpointing/strategies/base.py
deleted file mode 100644
index 3989ea74a204349fd8623736c5a2ecaf05b2c06b..0000000000000000000000000000000000000000
--- a/megatron/core/dist_checkpointing/strategies/base.py
+++ /dev/null
@@ -1,90 +0,0 @@
-# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
-
-from abc import ABC, abstractmethod
-from collections import defaultdict
-from enum import Enum
-from pathlib import Path
-from typing import Dict, List, Optional
-
-from ..mapping import CheckpointingException, ShardedStateDict, ShardedTensor, StateDict
-
-
-class StrategyAction(Enum):
- LOAD_COMMON = 'load_common'
- LOAD_SHARDED = 'load_sharded'
- SAVE_COMMON = 'save_common'
- SAVE_SHARDED = 'save_sharded'
-
-
-default_strategies = defaultdict(dict)
-
-
-def get_default_strategy(action: StrategyAction, backend: str, version: int):
- try:
- return default_strategies[action.value][(backend, version)]
- except KeyError as e:
- hint = ''
- if backend == 'zarr':
- try:
- import tensorstore
- import zarr
- except ImportError:
- hint = ' Please install `zarr` and `tensorstore<=0.1.45` packages'
- raise CheckpointingException(
- f'Cannot find a default strategy for: {(action.value, backend, version)}.{hint}'
- ) from e
-
-
-class LoadStrategyBase(ABC):
- @abstractmethod
- def check_backend_compatibility(self, loaded_version):
- raise NotImplementedError
-
- @abstractmethod
- def check_version_compatibility(self, loaded_version):
- raise NotImplementedError
-
-
-class SaveStrategyBase(ABC):
- def __init__(self, backend: str, version: int):
- self.backend = backend
- self.version = version
-
-
-class LoadCommonStrategy(LoadStrategyBase):
- @abstractmethod
- def load(self, checkpoint_dir: Path):
- raise NotImplementedError
-
-
-class LoadShardedStrategy(LoadStrategyBase):
- @abstractmethod
- def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
- raise NotImplementedError
-
- @abstractmethod
- def load_tensors_metadata(self, checkpoint_dir: Path):
- """Load tensors metadata from the checkpoint.
-
- Returns a dictionary similar to a sharded state dict, but note that
- the dictionary keys are simply ShardedTensor keys (contrary to the
- actual sharded state dicts where keys correspond to state dict keys).
-
- Dict values are ShardedTensors without any sharding (so, the only useful
- information is tensors global shape and dtype).
- """
- raise NotImplementedError(
- f'{self.__class__.__name__} doesnt allow loading only sharded metadata'
- )
-
-
-class SaveCommonStrategy(SaveStrategyBase):
- @abstractmethod
- def save(self, common_state_dict: StateDict, checkpoint_dir: Path):
- raise NotImplementedError
-
-
-class SaveShardedStrategy(SaveStrategyBase):
- @abstractmethod
- def save(self, sharded_tensors: List[ShardedTensor], checkpoint_dir: Path):
- raise NotImplementedError
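The removed `base.py` keeps a `default_strategies` registry keyed by `(backend, version)` per `StrategyAction`; `get_default_strategy()` simply looks strategies up there, which is how the zarr and tensorstore modules further down register themselves. A hedged sketch of plugging in a custom backend (the `noop` names are purely illustrative):

```python
# Illustrative no-op backend registered the same way as the zarr/tensorstore strategies.
from pathlib import Path
from typing import List

from megatron.core.dist_checkpointing.mapping import ShardedStateDict, ShardedTensor
from megatron.core.dist_checkpointing.strategies.base import (
    LoadShardedStrategy,
    SaveShardedStrategy,
    StrategyAction,
    default_strategies,
)


class NoOpSaveShardedStrategy(SaveShardedStrategy):
    def save(self, sharded_tensors: List[ShardedTensor], checkpoint_dir: Path):
        pass  # a real backend would persist every main-replica shard here


class NoOpLoadShardedStrategy(LoadShardedStrategy):
    def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
        return sharded_state_dict  # a real backend would fill in tensor data

    def load_tensors_metadata(self, checkpoint_dir: Path):
        return {}

    def check_backend_compatibility(self, loaded_version):
        pass

    def check_version_compatibility(self, loaded_version):
        pass


# get_default_strategy() resolves default_strategies[action.value][(backend, version)].
default_strategies[StrategyAction.SAVE_SHARDED.value][('noop', 1)] = NoOpSaveShardedStrategy('noop', 1)
default_strategies[StrategyAction.LOAD_SHARDED.value][('noop', 1)] = NoOpLoadShardedStrategy()
```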
diff --git a/megatron/core/dist_checkpointing/strategies/tensorstore.py b/megatron/core/dist_checkpointing/strategies/tensorstore.py
deleted file mode 100644
index 4a619353a19075f41b68a6cdae519313a7e305b8..0000000000000000000000000000000000000000
--- a/megatron/core/dist_checkpointing/strategies/tensorstore.py
+++ /dev/null
@@ -1,131 +0,0 @@
-# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
-
-""" Strategies using TensorStore to load and save Zarr arrays. """
-
-from functools import partial
-from itertools import starmap
-from logging import getLogger
-from pathlib import Path
-
-import tensorstore as ts
-import torch
-
-from ..core import CheckpointingException
-from ..dict_utils import dict_list_map_inplace
-from ..mapping import ShardedStateDict, ShardedTensor
-from .base import LoadShardedStrategy, StrategyAction, default_strategies
-from .zarr import (
- load_zarr_based_sharded_metadata,
- numpy_to_torch_dtype_dict,
- postprocess_numpy_array,
-)
-
-_import_trigger = None
-
-logger = getLogger(__name__)
-
-
-class TensorStoreLoadShardedStrategy(LoadShardedStrategy):
- def __init__(self, load_directly_on_device: bool = False):
- super().__init__()
- self.load_directly_on_device = load_directly_on_device
-
- def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
- if torch.distributed.get_rank() == 0:
- print(f'Loading distributed checkpoint with {self.__class__.__name__}')
- if self.load_directly_on_device:
- print(f'Loading distributed checkpoint directly on the GPU')
- load_fn = partial(
- _load_from_array,
- checkpoint_dir=checkpoint_dir,
- load_directly_on_device=self.load_directly_on_device,
- )
- dict_list_map_inplace(load_fn, sharded_state_dict)
- return sharded_state_dict
-
- def load_tensors_metadata(self, checkpoint_dir: Path):
- def get_ts_shape_dtype(path):
- arr = open_ts_array(path)
- return arr.shape, arr.dtype.numpy_dtype
-
- return load_zarr_based_sharded_metadata(checkpoint_dir, get_ts_shape_dtype)
-
- def check_backend_compatibility(self, loaded_version):
- pass # TODO
-
- def check_version_compatibility(self, loaded_version):
- pass # TODO
-
-
-def merge_global_slice_with_shape(global_slice, actual_shape, key):
- def _merge_slice(dim_slice, dim_size):
- if isinstance(dim_slice, slice):
- assert (
- dim_slice.start < dim_size
- ), f'Got empty slice for ShardedTensor {key} ({dim_slice}, {dim_size})'
- if dim_slice.stop > dim_size:
- dim_slice = slice(dim_slice.start, dim_size, dim_slice.step)
- return dim_slice
-
- assert len(global_slice) == len(actual_shape), (global_slice, actual_shape, key)
- return tuple(starmap(_merge_slice, zip(global_slice, actual_shape)))
-
-
-def _load_from_array(
- sharded_tensor: ShardedTensor,
- checkpoint_dir: Path,
- load_directly_on_device: bool = False,
- apply_flattened_range: bool = True,
-):
- x = _load_regular_chunk(sharded_tensor, checkpoint_dir)
- ten = postprocess_numpy_array(x, sharded_tensor, apply_flattened_range)
- if load_directly_on_device:
- sharded_tensor.data.data.copy_(ten)
- return sharded_tensor.data
- else:
- return ten
-
-
-def _load_regular_chunk(sharded_tensor: ShardedTensor, checkpoint_dir: Path):
- assert isinstance(sharded_tensor, ShardedTensor), type(sharded_tensor)
- arr = open_ts_array(checkpoint_dir / sharded_tensor.key)
- if sharded_tensor.global_shape == arr.shape:
- x = (
- arr[sharded_tensor.global_slice()].read().result()
- ) # flattened tensors loading is delayed
- elif sharded_tensor.allow_shape_mismatch:
- global_slice = merge_global_slice_with_shape(
- sharded_tensor.global_slice(), arr.shape, sharded_tensor.key
- )
- x = arr[global_slice].read().result() # flattened tensors loading is delayed
- else:
- _msg = (
- f'Global shape mismatch for loaded ({arr.shape})'
- f' and expected ({sharded_tensor.global_shape}) tensor'
- f' for key {sharded_tensor.key}'
- )
- raise CheckpointingException(_msg)
- return x
-
-
-def open_ts_array(arr_path: Path):
- """Opens a Zarr file array with Tensorstore with basic setting.
-
- Arguments:
- arr_path (Path): path to a Zarr (Tensorstore) array
- """
- spec = {'driver': 'zarr', 'metadata_key': '.zarray', 'kvstore': {}}
- spec['kvstore'] = {
- 'driver': 'file',
- 'path': str(arr_path),
- }
- try:
- arr = ts.open(ts.Spec(spec), open=True).result()
- except Exception as e:
- raise CheckpointingException(f'Array {arr_path} could not be loaded. Error: {e}') from e
- return arr
-
-
-default_strategies[StrategyAction.LOAD_SHARDED.value][
- ('zarr', 1)
-] = TensorStoreLoadShardedStrategy()
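`load()` also accepts a strategy instance directly, so the TensorStore loader above can be chosen explicitly, for example to copy tensors straight into pre-allocated device buffers. A small sketch; the arguments are assumed to be prepared as in the earlier save/load example:

```python
from megatron.core.dist_checkpointing import load
from megatron.core.dist_checkpointing.strategies.tensorstore import (
    TensorStoreLoadShardedStrategy,
)


def load_with_tensorstore(sharded_state_dict, checkpoint_dir):
    # load_directly_on_device=True copies into the existing ShardedTensor buffers.
    strategy = TensorStoreLoadShardedStrategy(load_directly_on_device=True)
    return load(sharded_state_dict, checkpoint_dir, sharded_strategy=strategy)
```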
diff --git a/megatron/core/dist_checkpointing/strategies/two_stage.py b/megatron/core/dist_checkpointing/strategies/two_stage.py
deleted file mode 100644
index a9844ff6e54a629fbe689771f172239a41687f5e..0000000000000000000000000000000000000000
--- a/megatron/core/dist_checkpointing/strategies/two_stage.py
+++ /dev/null
@@ -1,256 +0,0 @@
-# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
-
-""" 2-stage checkpoint loading. """
-import os
-import time
-from collections import defaultdict
-from dataclasses import dataclass
-from functools import partial, wraps
-from itertools import chain
-from logging import DEBUG, INFO, StreamHandler, getLogger
-from operator import attrgetter, itemgetter
-from pathlib import Path
-from typing import Iterable, List, NamedTuple, Optional, Tuple, Union
-
-import torch
-
-from ..dict_utils import dict_list_map_inplace, map_reduce, nested_values
-from ..mapping import ShardedStateDict, ShardedTensor, StateDict
-from .base import LoadShardedStrategy
-from .tensorstore import TensorStoreLoadShardedStrategy, _load_from_array, open_ts_array
-from .zarr import flatten_range, load_zarr_based_sharded_metadata
-
-_import_trigger = None
-
-
-timers = defaultdict(list)
-
-logger = getLogger(__name__)
-
-
-def timed(verbose=True):
- def timed_dec(fn):
- name = fn.__name__
-
- @wraps(fn)
- def wrapped(*args, **kwargs):
- if verbose:
- logger.debug(f'{name} init')
- start = time.time()
- ret = fn(*args, **kwargs)
- took = time.time() - start
- if verbose:
- logger.debug(f'{name} took {took}s')
- timers[name].append(took)
- return ret
-
- return wrapped
-
- return timed_dec
-
-
-@dataclass
-class _ShardedTensorMetadata:
- global_rank: int
- sharded_tensor_no_data: ShardedTensor
- dist_group_rank: Tuple[int] # id of distributed group
- dist_group_ranks: Tuple[int] # id of distributed group
- data_size: Optional[int] = None # bytes
-
-
-def sharded_tensor_chunk_id(sharded_tensor: ShardedTensor):
- return (
- sharded_tensor.key,
- sharded_tensor.global_offset,
- )
-
-
-class TwoStageDataParallelLoadShardedStrategy(LoadShardedStrategy):
- """ Loads one checkpoint replica from storage and broadcasts to other nodes.
-
- This strategy loads checkpoint from storage on minimal set of nodes
- and distributes the checkpoint to other nodes with torch.distributed.
- Loading is performed with tensorstore.
-
- Steps:
- 0. (optional) create Gloo distributed groups
- 1. Exchange ShardedTensors metadata between all nodes
- 2. Align needed tensors within DP groups
- 3. For each globally unique tensor:
- a) on one of the ranks load it from storage to CPU and move to CUDA
- b) allocate CUDA tensor on other ranks
- c) broadcast within DP group
- d) copy tensor content to the model param location
- e) free tensor buffers from a) and b)
-
- Notes:
- 1. Loading and broadcasting is done sequentially to avoid both host and device OOMs
- 2. There is a lot of overlap potential between all three steps done for each tensor:
- a) loading from storage to numpy
- b) moving CPU tensors to CUDA
- c) broadcast
-
- """
-
- def __init__(self, data_parallel_group, cpu_transfer=True):
- super().__init__()
-
- self.cpu_transfer = cpu_transfer
- self.data_parallel_group_orig = data_parallel_group
- self.data_parallel_group = None if cpu_transfer else data_parallel_group
- self.dp_group_ranks = tuple(
- sorted(torch.distributed.get_process_group_ranks(data_parallel_group))
- )
- self.dp_group_rank = torch.distributed.get_rank(self.data_parallel_group_orig)
- self.global_rank = torch.distributed.get_rank()
-
- def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
- self.maybe_init_gloo_group()
- all_tensors_sorted = self._build_load_plan(sharded_state_dict)
- self._exchange_loaded_tensors(all_tensors_sorted, sharded_state_dict, checkpoint_dir)
- self.summarize_load_times()
- return sharded_state_dict
-
- def summarize_load_times(self):
- torch.distributed.barrier()
- logger.info('Checkpoint loading finished. Summary:')
- for key, times in sorted(timers.items()):
- times_sum = sum(times)
- max_times = torch.tensor([times_sum], device='cuda')
- avg_times = torch.tensor([times_sum], device='cuda')
- torch.distributed.all_reduce(max_times, op=torch.distributed.ReduceOp.MAX)
- torch.distributed.all_reduce(avg_times, op=torch.distributed.ReduceOp.SUM)
- avg_times /= torch.distributed.get_world_size()
- if torch.distributed.get_rank() == 0:
- logger.info(f'{key}: max {max_times[0]}, avg {avg_times[0]}')
-
- @timed(verbose=False)
- def load_tensor_from_storage(self, checkpoint_dir, ten_meta: _ShardedTensorMetadata):
- logger.debug(f'_load_from_array({ten_meta.sharded_tensor_no_data.key}) init')
- ret = _load_from_array(
- ten_meta.sharded_tensor_no_data,
- checkpoint_dir,
- load_directly_on_device=False,
- apply_flattened_range=False,
- )
- logger.debug(f'_load_from_array({ten_meta.sharded_tensor_no_data.key}) DONE')
- return ret
-
- @timed()
- def maybe_init_gloo_group(self):
- if not self.cpu_transfer:
- return
- all_groups = [None] * torch.distributed.get_world_size()
- torch.distributed.all_gather_object(all_groups, self.dp_group_ranks)
- all_groups = set(tuple(sorted(gr)) for gr in all_groups)
- for group_ranks in sorted(all_groups):
- gloo_pg = torch.distributed.new_group(ranks=group_ranks, backend='gloo')
- if self.global_rank in group_ranks:
- self.data_parallel_group = gloo_pg
- assert self.dp_group_rank == torch.distributed.get_rank(self.data_parallel_group)
-
- def check_backend_compatibility(self, loaded_version):
- pass # TODO
-
- def check_version_compatibility(self, loaded_version):
- pass # TODO
-
- @timed()
- def _build_load_plan(
- self, sharded_state_dict: ShardedStateDict
- ) -> List[_ShardedTensorMetadata]:
- local_meta = [
- _ShardedTensorMetadata(
- self.global_rank,
- sharded_ten.without_data(),
- self.dp_group_rank,
- self.dp_group_ranks,
- )
- for sharded_ten in nested_values(sharded_state_dict)
- ]
- all_meta = [None] * torch.distributed.get_world_size(group=self.data_parallel_group)
- torch.distributed.all_gather_object(all_meta, local_meta, group=self.data_parallel_group)
- all_meta = list(chain.from_iterable(all_meta))
- all_tensors_sorted = self.deduplicate_chunks(all_meta)
- return all_tensors_sorted
-
- @timed()
- def deduplicate_chunks(self, ten_metas: List[_ShardedTensorMetadata]):
- """ Group tensors by chunk and then pick the tensor with the lowest rank.
-
- NOTE: with proper loading overlap, loading from randomized ranks
- (instead of the smallest one) could be beneficial here.
- """
- ten_metas = map_reduce(
- ten_metas,
- key_fn=lambda meta: sharded_tensor_chunk_id(meta.sharded_tensor_no_data),
- reduce_fn=partial(min, key=attrgetter('dist_group_rank')),
- )
- all_metas_sorted = list(map(itemgetter(1), sorted(ten_metas.items())))
- return all_metas_sorted
-
- @timed()
- def _exchange_loaded_tensors(
- self, ten_metas: List[_ShardedTensorMetadata], sharded_state_dict, checkpoint_dir
- ):
- logger.debug(f'_exchange_loaded_tensors, num ten_metas: {len(ten_metas)}')
- for ten_meta in ten_metas:
-
- src_rank = torch.distributed.get_global_rank(
- self.data_parallel_group, ten_meta.dist_group_rank
- )
-
- if self.dp_group_rank == ten_meta.dist_group_rank:
- exchange_tensor = self.load_tensor_from_storage(checkpoint_dir, ten_meta)
- if not self.cpu_transfer:
- exchange_tensor = exchange_tensor.cuda()
- else:
- # TODO: for non-flattened ranges we could reuse the buffer from the start here
- exchange_tensor = torch.empty(
- ten_meta.sharded_tensor_no_data.local_shape,
- device='cpu' if self.cpu_transfer else 'cuda',
- dtype=ten_meta.sharded_tensor_no_data.dtype,
- )
-
- logger.debug(
- f'exchange {ten_meta.sharded_tensor_no_data.key}, {exchange_tensor.shape}({exchange_tensor.numel()}), broadcast({src_rank} -> {self.dp_group_ranks})'
- )
- torch.distributed.broadcast(
- exchange_tensor, group=self.data_parallel_group, src=src_rank
- )
- self._distribute_data_to_state_dict(ten_meta, exchange_tensor, sharded_state_dict)
- logger.debug(f'exchange {ten_meta.sharded_tensor_no_data.key} done')
-
- # free buffer memory
- exchange_tensor = None
-
- @timed(verbose=False)
- def _distribute_data_to_state_dict(
- self,
- ten_meta: _ShardedTensorMetadata,
- loaded_ten: torch.Tensor,
- sharded_state_dict: ShardedStateDict,
- ):
- tensor_key = sharded_tensor_chunk_id(ten_meta.sharded_tensor_no_data)
-
- def _fill_in_data(t: Union[ShardedTensor, torch.Tensor]):
- if not isinstance(t, ShardedTensor) or sharded_tensor_chunk_id(t) != tensor_key:
- # already filled-in or key not matching
- return t
- sharded_tensor: ShardedTensor = t
- x = loaded_ten
- if sharded_tensor.flattened_range is not None:
- x = flatten_range(sharded_tensor, x)
-
- # Reuse existing buffer
- sharded_tensor.data.data.copy_(x)
- return sharded_tensor.data
-
- dict_list_map_inplace(_fill_in_data, sharded_state_dict)
-
- def load_tensors_metadata(self, checkpoint_dir: Path):
- def get_ts_shape_dtype(path):
- arr = open_ts_array(path)
- return arr.shape, arr.dtype.numpy_dtype
-
- return load_zarr_based_sharded_metadata(checkpoint_dir, get_ts_shape_dtype)
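The two-stage strategy above avoids having every rank read from storage: one rank per data-parallel group loads each tensor and broadcasts it to the rest. A hedged sketch of wiring it into `load()`, assuming `megatron.core.parallel_state` has already been initialized by the training setup:

```python
from megatron.core import parallel_state
from megatron.core.dist_checkpointing import load
from megatron.core.dist_checkpointing.strategies.two_stage import (
    TwoStageDataParallelLoadShardedStrategy,
)


def load_with_two_stage(sharded_state_dict, checkpoint_dir):
    strategy = TwoStageDataParallelLoadShardedStrategy(
        parallel_state.get_data_parallel_group(),  # one loading rank per DP group
        cpu_transfer=True,                          # broadcast over a Gloo group to limit device memory use
    )
    return load(sharded_state_dict, checkpoint_dir, sharded_strategy=strategy)
```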
diff --git a/megatron/core/dist_checkpointing/strategies/zarr.py b/megatron/core/dist_checkpointing/strategies/zarr.py
deleted file mode 100644
index 0ce0cf0e27c8ab2441c6432840bb1c8f368632c3..0000000000000000000000000000000000000000
--- a/megatron/core/dist_checkpointing/strategies/zarr.py
+++ /dev/null
@@ -1,285 +0,0 @@
-# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
-
-""" Strategies using Zarr as an underlying format. """
-import os
-from functools import partial
-from logging import getLogger
-from pathlib import Path
-from typing import Callable, List, Optional, Tuple
-
-import numpy as np
-import torch
-import zarr
-
-from ..core import CheckpointingException
-from ..dict_utils import dict_list_map_inplace
-from ..mapping import ShardedStateDict, ShardedTensor, is_main_replica
-from .base import LoadShardedStrategy, SaveShardedStrategy, StrategyAction, default_strategies
-
-numpy_to_torch_dtype_dict = {
- np.dtype('bool'): torch.bool,
- np.dtype('uint8'): torch.uint8,
- np.dtype('int8'): torch.int8,
- np.dtype('int16'): torch.int16,
- np.dtype('int32'): torch.int32,
- np.dtype('int64'): torch.int64,
- np.dtype('float16'): torch.float16,
- np.dtype('float32'): torch.float32,
- np.dtype('float64'): torch.float64,
- np.dtype('complex64'): torch.complex64,
- np.dtype('complex128'): torch.complex128,
-}
-
-torch_to_numpy_dtype_dict = {v: k for k, v in numpy_to_torch_dtype_dict.items()}
-
-
-try:
- import tensorstore
-
- HAS_BFLOAT16 = True
- numpy_to_torch_dtype_dict[np.dtype('bfloat16')] = torch.bfloat16
- torch_to_numpy_dtype_dict[torch.bfloat16] = np.dtype('bfloat16')
-except ImportError:
- HAS_BFLOAT16 = False
-
-_import_trigger = None
-
-logger = getLogger(__name__)
-
-
-class ZarrSaveShardedStrategy(SaveShardedStrategy):
- def save(self, sharded_tensors: List[ShardedTensor], checkpoint_dir: Path):
- arrays = _create_or_open_zarr_arrays(sharded_tensors, checkpoint_dir)
- for ten, arr in zip(sharded_tensors, arrays):
- _save_to_existing_array(ten, arr)
- torch.distributed.barrier()
-
-
-def _create_or_open_zarr_arrays(
- sharded_tensors: List[ShardedTensor], checkpoint_dir: Path
-) -> List[Optional[zarr.Array]]:
- """ Returns list of zarr arrays corresponding to given tensors.
-
- For a sharded tensors that:
- a) is main replica and represents the first chunk (all offsets 0), creates the Zarr array
- b) is main replica but not the first chunk, opens the arrays created in (a) (possibly by other process)
- c) otherwise, sets the corresponding array to None since it won't be used
-
- Args:
- sharded_tensors (List[ShardedTensor]): sharded tensors from a given rank that will be saved to checkpoint
- checkpoint_dir (Path): checkpoint in which the arrays will be created
- """
- arrays = []
- for ten in sharded_tensors:
- arr = _create_zarr_array(ten, checkpoint_dir) if _should_create_array(ten) else None
- arrays.append(arr)
-
- torch.distributed.barrier()
- # Open arrays created above by other processes
- for arr_idx, ten in enumerate(sharded_tensors):
- if arrays[arr_idx] is not None:
- # array created by this process
- assert _should_create_array(ten), ten
- continue
- if not is_main_replica(ten.replica_id):
- # this array won't be needed for saving and can stay None
- continue
- open_kwargs = {}
- if ten.flattened_range is not None:
- open_kwargs['synchronizer'] = zarr.ProcessSynchronizer(
- str(checkpoint_dir / f'{ten.key}.sync')
- )
- arrays[arr_idx] = zarr.open(checkpoint_dir / ten.key, 'r+', **open_kwargs)
- return arrays
-
-
-def _should_create_array(ten: ShardedTensor):
- return (
- is_main_replica(ten.replica_id)
- and set(ten.global_offset) == {0}
- and (ten.flattened_range is None or ten.flattened_range.start == 0)
- )
-
-
-def _save_to_existing_array(sharded_tensor: ShardedTensor, arr: Optional[zarr.Array]):
- if not is_main_replica(sharded_tensor.replica_id):
- return
- assert arr is not None
- x = sharded_tensor.data
- x = x.detach().cpu()
- torch.cuda.synchronize()
- if x.dtype == torch.bfloat16:
- x = x.float()
- x = x.numpy()
- x = x.astype('bfloat16')
- else:
- x = x.numpy()
-
- if sharded_tensor.flattened_range is None:
- arr[sharded_tensor.global_slice()] = x
- else:
- arr.set_coordinate_selection(sharded_tensor.global_coordinates(), x)
-
-
-def _create_zarr_array(sharded_tensor: ShardedTensor, checkpoint_dir: Path):
- np_dtype = torch_to_numpy_dtype_dict[sharded_tensor.dtype]
- try:
- arr = zarr.create(
- sharded_tensor.global_shape,
- dtype=np_dtype,
- store=checkpoint_dir / sharded_tensor.key,
- chunks=sharded_tensor.max_allowed_chunks(),
- compressor=None,
- fill_value=None,
- write_empty_chunks=True,
- )
- except zarr.errors.ContainsArrayError as e:
- raise CheckpointingException(
- f'Array {checkpoint_dir / sharded_tensor.key} already exists'
- ) from e
-
- if HAS_BFLOAT16 and np_dtype == np.dtype('bfloat16'):
- arr._dtype = np_dtype
- zarray = arr.store['.zarray']
- arr.store['.zarray'] = zarray.replace(b'<V2', b'bfloat16')
- elif x_sh > exp_sh:
- assert (
- False
- ), f'Expected shape ({exp_sh}) smaller than actual ({x_sh}) for {repr(expected_sharded_ten)}'
- else:
- pad_args.extend((0, exp_sh - x_sh))
- # TODO: behavior control with envvar is for testing purposes only, remove it
- if not int(os.environ.get('DIST_CKPT_PAD_REPLICATE', 0)):
- return torch.nn.functional.pad(x, pad_args)
-
- # unsqueeze and squeeze to get shapes supported by cudnn
- print(f'Replicating last row for {expected_sharded_ten.key}')
- if x.dtype == torch.bfloat16:
- return (
- torch.nn.functional.pad(x.float().unsqueeze(0), pad_args, mode='replicate')
- .squeeze(0)
- .bfloat16()
- )
- return torch.nn.functional.pad(x.unsqueeze(0), pad_args, mode='replicate').squeeze(0)
-
-
-def load_zarr_based_sharded_metadata(
- checkpoint_dir: Path, get_shape_dtype_fn: Callable[[str], Tuple[Tuple[int], np.dtype]]
-) -> ShardedStateDict:
- """Load metadata of Zarr arrays.
-
- Arguments:
- checkpoint_dir (str): checkpoint root directory
- get_shape_dtype_fn (str -> ((int, ...), np.dtype)): a function returning
- an array shape and dtype for a given Zarr array path
- """
- sharded_state_dict = {}
- for subdir in checkpoint_dir.iterdir():
- if not subdir.is_dir() or not (subdir / '.zarray').exists():
- continue
- key = subdir.name
- arr_shape, arr_dtype = get_shape_dtype_fn(str(subdir))
-
- sharded_state_dict[key] = ShardedTensor(
- key,
- None,
- numpy_to_torch_dtype_dict[arr_dtype],
- arr_shape,
- arr_shape,
- tuple(0 for _ in arr_shape),
- tuple(1 for _ in arr_shape),
- )
- return sharded_state_dict
-
-
-# default_strategies[StrategyAction.LOAD_SHARDED.value][('zarr', 1)] = ZarrLoadShardedStrategy()
-default_strategies[StrategyAction.SAVE_SHARDED.value][('zarr', 1)] = ZarrSaveShardedStrategy(
- 'zarr', 1
-)
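The Zarr backend above registers itself as the default `('zarr', 1)` save strategy; it can also be passed explicitly, which makes the backend and version recorded in the checkpoint config obvious at the call site. A short sketch:

```python
from megatron.core.dist_checkpointing import save
from megatron.core.dist_checkpointing.strategies.zarr import ZarrSaveShardedStrategy


def save_with_zarr(sharded_state_dict, checkpoint_dir):
    # Equivalent to the default registered above: backend 'zarr', version 1.
    return save(sharded_state_dict, checkpoint_dir,
                sharded_strategy=ZarrSaveShardedStrategy('zarr', 1))
```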
diff --git a/megatron/core/dist_checkpointing/utils.py b/megatron/core/dist_checkpointing/utils.py
deleted file mode 100644
index f7976f007408197338b9f9a96eec85db4d63d087..0000000000000000000000000000000000000000
--- a/megatron/core/dist_checkpointing/utils.py
+++ /dev/null
@@ -1,44 +0,0 @@
-# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
-
-from typing import Tuple
-
-from .dict_utils import dict_list_map_inplace, extract_matching_values
-from .mapping import (
- LocalNonpersitentObject,
- ShardedStateDict,
- ShardedTensor,
- ShardedTensorFactory,
- StateDict,
-)
-
-
-def extract_sharded_tensors(
- sharded_state_dict: ShardedStateDict,
-) -> Tuple[ShardedStateDict, StateDict]:
- return extract_matching_values(sharded_state_dict, lambda v: isinstance(v, ShardedTensor))
-
-
-def extract_sharded_tensors_and_factories(
- sharded_state_dict: ShardedStateDict,
-) -> Tuple[ShardedStateDict, StateDict]:
- return extract_matching_values(
- sharded_state_dict, lambda v: isinstance(v, (ShardedTensor, ShardedTensorFactory))
- )
-
-
-def extract_sharded_tensors_or_nonpersistent(
- sharded_state_dict: ShardedStateDict,
-) -> Tuple[ShardedStateDict, StateDict]:
- return extract_matching_values(
- sharded_state_dict,
- lambda v: isinstance(v, (ShardedTensor, LocalNonpersitentObject, ShardedTensorFactory)),
- )
-
-
-def add_prefix_for_sharding(sharded_state_dict: ShardedStateDict, prefix: str):
- def add_prefix(t):
- if isinstance(t, ShardedTensor):
- t.key = f'{prefix}.{t.key}'
- return t
-
- dict_list_map_inplace(add_prefix, sharded_state_dict)
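The removed `utils.py` provides the helpers used to namespace a submodule's sharded state dict and to split ShardedTensors from plain values. A hedged sketch, again assuming `ShardedTensor.from_rank_offsets` from `mapping.py` (not shown in this diff):

```python
import torch

from megatron.core.dist_checkpointing.mapping import ShardedTensor
from megatron.core.dist_checkpointing.utils import (
    add_prefix_for_sharding,
    extract_sharded_tensors,
)

sub_sd = {'weight': ShardedTensor.from_rank_offsets('weight', torch.zeros(2, 2))}
add_prefix_for_sharding(sub_sd, 'decoder.layers.0')
print(sub_sd['weight'].key)  # 'decoder.layers.0.weight'

sharded_part, rest = extract_sharded_tensors({'w': sub_sd['weight'], 'step': 10})
print(list(sharded_part), list(rest))  # ['w'] ['step']
```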
diff --git a/megatron/core/distributed/__init__.py b/megatron/core/distributed/__init__.py
deleted file mode 100644
index 34c7209a27fde7c5202f275663d951276caff85d..0000000000000000000000000000000000000000
--- a/megatron/core/distributed/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-from .distributed_data_parallel import DistributedDataParallel
-from .finalize_model_grads import finalize_model_grads
diff --git a/megatron/core/distributed/distributed_data_parallel.py b/megatron/core/distributed/distributed_data_parallel.py
deleted file mode 100644
index 63f6e3d65ec2bb3a7d771f3dd6fea61216112d67..0000000000000000000000000000000000000000
--- a/megatron/core/distributed/distributed_data_parallel.py
+++ /dev/null
@@ -1,248 +0,0 @@
-# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
-
-from contextlib import contextmanager
-from typing import Dict
-
-import torch
-
-from .. import parallel_state
-from ..transformer.module import MegatronModule
-from ..transformer.transformer_config import TransformerConfig
-from .grad_buffer import GradBuffer
-
-
-class DistributedDataParallel(MegatronModule):
- """
- DDP wrapper which stores grads in contiguous buffers. Also has option of overlapping
- communication with backprop computation by breaking up full model's gradients into smaller
- buckets and running all-reduce / reduce-scatter on each bucket asynchronously. This class
- also provides the option to do the gradient accumulation in a type other than the param type
- (e.g., fp32 for a bf16 model).
-
- Arguments:
- config: Transformer config object.
- module: Underlying model.
- data_parallel_group: Data-parallel process group.
- accumulate_allreduce_grads_in_fp32: If true, do the gradient accumulation and
- communication in fp32.
- overlap_grad_reduce: If true, overlap communication with backprop computation by
- breaking up grads into buckets. If false, single synchronous communication call
- is used instead.
- use_distributed_optimizer: If true, issue reduce-scatter communication calls as part
- of distributed optimizer. If false, issue all-reduce communication calls.
- disable_bucketing: If true, force assign all parameters to a single bucket. If false,
- use standard bucketing policy: assign parameters to smaller buckets and all-reduce
- per bucket _if_ overlap_grad_reduce is True and pp_rank is 0.
-
- """
-
- def __init__(
- self,
- config: TransformerConfig,
- module: torch.nn.Module,
- data_parallel_group: torch.distributed.ProcessGroup,
- accumulate_allreduce_grads_in_fp32: bool,
- overlap_grad_reduce: bool,
- use_distributed_optimizer: bool,
- disable_bucketing: bool = False,
- bucket_size: int = 40000000,
- ):
- super().__init__(config=config)
- self.module = module
-
- # Set bucket_size to infinity if overlap_grad_reduce is False.
- self.overlap_grad_reduce = overlap_grad_reduce
- self.use_distributed_optimizer = use_distributed_optimizer
-
- # Turn off bucketing if overlap_grad_reduce is False, if we are on a pipeline stage
- # that is not the first (since data-parallel communication on these stages is not on
- # the critical path), or if disable_bucketing is True (e.g., we might not want to
- # break up model parameters into buckets for model chunks after the first
- # in the interleaved schedule).
- if not self.overlap_grad_reduce:
- bucket_size = None
- if parallel_state.get_pipeline_model_parallel_rank() > 0:
- bucket_size = None
- if disable_bucketing:
- bucket_size = None
- self.bucket_size = bucket_size
-
- self.module = module
- self.grad_buffers = {}
- self.expert_grads = []
- self.grad_buffer_param_index_map = {}
- self.param_to_grad_buffer = {}
-
- # Group parameters by their gradient type.
- grad_dtype_to_params = {}
- param_to_name = {}
- for name, param in self.module.named_parameters():
- if param.requires_grad and getattr(param, 'allreduce', True):
- param.grad_added_to_main_grad = False
- param_to_name[param] = name
- dtype = torch.float if accumulate_allreduce_grads_in_fp32 else param.dtype
-
- params = grad_dtype_to_params.get(dtype, [])
- params.append(param)
- grad_dtype_to_params[dtype] = params
-
- # Allocate the grad buffers and map the grads.
- # The grad buffer under the hood creates buckets as appropriate based on bucket_size.
- self.data_parallel_world_size = torch.distributed.get_world_size(group=data_parallel_group)
- for dtype, params in grad_dtype_to_params.items():
- self.grad_buffers[dtype] = GradBuffer(
- dtype,
- params,
- data_parallel_group,
- bucket_size,
- param_to_name,
- self.overlap_grad_reduce,
- self.use_distributed_optimizer,
- )
- self.grad_buffer_param_index_map[dtype] = self.grad_buffers[dtype].param_index_map
- for param in params:
- self.param_to_grad_buffer[param] = self.grad_buffers[dtype]
-
- # Allocate separate buffer for MoE params' grads.
- for param in self.module.parameters():
- if param.requires_grad and not getattr(param, 'allreduce', True):
- param.grad_added_to_main_grad = False
- dtype = torch.float if accumulate_allreduce_grads_in_fp32 else param.dtype
- param.main_grad = torch.zeros(
- param.data.shape,
- dtype=dtype,
- device=torch.cuda.current_device(),
- requires_grad=False,
- )
- self.expert_grads.append(param.main_grad)
-
- # Register backward hook.
- # Accumulation function for the gradients need to be stored so they
- # don't go out of scope.
- self.grad_accs = []
- for param in self.module.parameters():
- if param.requires_grad:
- # Expand so we get access to grad_fn.
- param_tmp = param.expand_as(param)
- # Get the gradient accumulator function.
- grad_acc = param_tmp.grad_fn.next_functions[0][0]
- grad_acc.register_hook(self._make_param_hook(param, self.param_to_grad_buffer))
- self.grad_accs.append(grad_acc)
-
- def forward(self, *inputs, **kwargs):
- """
- Calls the wrapped module's forward() method.
- """
- return self.module(*inputs, **kwargs)
-
- def _make_param_hook(
- self, param: torch.nn.Parameter, param_to_grad_buffer: Dict[torch.nn.Parameter, GradBuffer]
- ):
- """
- Creates the all-reduce / reduce-scatter hook for backprop.
- """
-
- def param_hook(*unused):
- if param.requires_grad:
- if self.overlap_grad_reduce:
- assert (
- param.grad is not None
- ), 'param.grad being None is not safe when overlap_grad_reduce is True'
- if param.grad is not None and not param.grad_added_to_main_grad:
- param.main_grad.add_(param.grad.data)
- param.grad = None
- if self.overlap_grad_reduce:
- param_to_grad_buffer[param].register_grad_ready(param)
-
- return param_hook
-
- @contextmanager
- def no_sync(self):
- """
- Context manager that turns off gradient synchronization.
- """
- for grad_buffer in self.grad_buffers.values():
- grad_buffer.is_last_microbatch = False
- try:
- yield
- finally:
- for grad_buffer in self.grad_buffers.values():
- grad_buffer.is_last_microbatch = True
-
- def start_grad_sync(self, *unused):
- """
- Initiates grad sync (all-reduce or reduce-scatter) communication operations
- for all model gradients.
-
- When overlap_grad_reduce is set to True, dispatches asynchronous communication
- calls. When overlap_grad_reduce is set to False, calls synchronous
- communication ops.
- """
- for grad_buffer in self.grad_buffers.values():
- grad_buffer.start_grad_sync()
-
- def finish_grad_sync(self):
- """
- Finishes grad sync (all-reduce or reduce-scatter) communication operations
- for all model gradients.
-
- When overlap_grad_reduce is set to True, waits for asynchronous communication
- calls to complete. When overlap_grad_reduce is set to False, calls synchronous
- communication ops.
- """
- for grad_buffer in self.grad_buffers.values():
- grad_buffer.finish_grad_sync()
-
- for expert_grad in self.expert_grads:
- expert_grad /= self.data_parallel_world_size
-
- def zero_grad_buffer(self, zero_buffer):
- """
- Zeros out all grad buffers. Needs to be called at the beginning of each
- training iteration.
-
- When zero_buffer is set to True, the underlying grad buffer is zeroed out.
- """
- for param in self.module.parameters():
- if param.requires_grad:
- param.grad_added_to_main_grad = False
- for grad_buffer in self.grad_buffers.values():
- grad_buffer.reset(zero_buffer)
- for expert_grad in self.expert_grads:
- expert_grad.zero_()
-
- def broadcast_params(self):
- """
- Syncs parameters across all DP ranks.
- """
- for param in self.module.parameters():
- torch.distributed.broadcast(
- param.data,
- src=parallel_state.get_data_parallel_src_rank(with_context_parallel=True),
- group=parallel_state.get_data_parallel_group(with_context_parallel=True),
- )
-
- def state_dict(self, prefix='', keep_vars=False):
- """
- Returns a dictionary containing references to the whole state of the
- wrapped module.
-
- Both parameters and persistent buffers (e.g. running averages) are included.
- Keys are corresponding parameter and buffer names. Parameters and buffers
- set to None are not included.
- """
- return self.module.state_dict(prefix=prefix, keep_vars=keep_vars)
-
- def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
- """
- Returns wrapped module's state_dict for checkpoint saving.
- """
- return self.module.state_dict_for_save_checkpoint(prefix=prefix, keep_vars=keep_vars)
-
- def load_state_dict(self, state_dict, strict=True):
- """
- Copies parameters and buffers from state_dict into the wrapped module and its
- descendants. If strict is True, then the keys of state_dict must exactly match
- the keys returned by this module’s state_dict() function.
- """
- self.module.load_state_dict(state_dict, strict=strict)
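The removed DDP wrapper keeps gradients in contiguous buffers and reduces them bucket by bucket; `no_sync()` suppresses communication on all but the last microbatch. A hedged sketch of the call pattern, assuming `parallel_state` is initialized and that `config` (a `TransformerConfig`), `model`, and the `microbatches` list come from the usual training setup (none of which are defined in this diff):

```python
from megatron.core import parallel_state
from megatron.core.distributed import DistributedDataParallel as DDP


def build_ddp(config, model):
    return DDP(
        config,
        model,
        data_parallel_group=parallel_state.get_data_parallel_group(),
        accumulate_allreduce_grads_in_fp32=True,  # fp32 main grads, e.g. for a bf16 model
        overlap_grad_reduce=True,                 # bucketed async all-reduce during backprop
        use_distributed_optimizer=False,          # all-reduce rather than reduce-scatter
    )


def accumulate_grads(ddp_model, microbatches):
    ddp_model.zero_grad_buffer(True)              # reset main_grad buffers for this iteration
    with ddp_model.no_sync():                     # skip grad sync for all but the last microbatch
        for batch in microbatches[:-1]:
            ddp_model(batch).sum().backward()
    ddp_model(microbatches[-1]).sum().backward()  # last backward triggers bucketed reduction
    ddp_model.finish_grad_sync()                  # wait for (or issue) the communication calls
```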
diff --git a/megatron/core/distributed/finalize_model_grads.py b/megatron/core/distributed/finalize_model_grads.py
deleted file mode 100644
index 916e4f3ecbffafca7f97d2b33193bb289e12228d..0000000000000000000000000000000000000000
--- a/megatron/core/distributed/finalize_model_grads.py
+++ /dev/null
@@ -1,158 +0,0 @@
-# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
-
-from typing import List
-
-import torch
-from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
-
-from .. import parallel_state
-from ..transformer.transformer_config import TransformerConfig
-from ..utils import get_attr_wrapped_model, get_model_config
-
-
-def _allreduce_word_embedding_grads(model: List[torch.nn.Module], config: TransformerConfig):
- """
- All-reduce word embedding grads.
-
- Reduce grads across first and last stages to ensure that word_embeddings parameters stay in
- sync. This should only run for models that support pipelined model parallelism (BERT and GPT).
- """
-
- if (
- parallel_state.is_rank_in_embedding_group(ignore_virtual=True)
- and parallel_state.get_pipeline_model_parallel_world_size() > 1
- ):
- if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
- model_module = model[0]
- elif parallel_state.is_pipeline_last_stage(ignore_virtual=True):
- model_module = model[-1]
- else: # We do not support the interleaved schedule for T5 yet.
- model_module = model[0]
-
- # Look for module with 'pre_process' attribute to get around the fact that DDP and
- # other wrapper classes inherit from non-core MegatronModule that has
- # 'share_embeddings_and_output_weights' and 'shared_embedding_or_output_weight'
- # attributes already, causing get_attr_wrapped_model() to not unwrap anything here.
- # TODO: Clean this up once the wrapper classes inherit from core MegatronModule.
- model_module = get_attr_wrapped_model(model_module, 'pre_process', return_model_obj=True)
- if model_module.share_embeddings_and_output_weights:
- weight = model_module.shared_embedding_or_output_weight()
- grad = weight.main_grad
- torch.distributed.all_reduce(grad, group=parallel_state.get_embedding_group())
-
-
-def _allreduce_position_embedding_grads(model: List[torch.nn.Module], config: TransformerConfig):
- """
- All-reduce position_embeddings grad across first (encoder) and split (decoder) stages to
- ensure that position embeddings parameters stay in sync. This should only run for T5 models
- with pipeline parallelism.
- """
- if (
- parallel_state.is_rank_in_position_embedding_group()
- and parallel_state.get_pipeline_model_parallel_world_size() > 1
- and config.pipeline_model_parallel_split_rank is not None
- ):
- model_module = model[0]
- grad = get_attr_wrapped_model(
- model_module, 'language_model.embedding.position_embeddings.weight.main_grad'
- )
- torch.distributed.all_reduce(grad, group=parallel_state.get_position_embedding_group())
-
-
-def _allreduce_embedding_grads(model: List[torch.nn.Module], config: TransformerConfig):
- """
- All-reduce both word and position embeddings.
- """
- _allreduce_word_embedding_grads(model, config)
- _allreduce_position_embedding_grads(model, config)
-
-
-def _allreduce_layernorm_grads(model: List[torch.nn.Module], config: TransformerConfig):
- """
- All-reduce layernorm grads (for sequence parallelism).
- """
-
- # All-reduce layernorm parameters across model parallel nodes
- # when sequence parallelism is used
- if parallel_state.get_tensor_model_parallel_world_size() > 1 and config.sequence_parallel:
- grads = []
- for model_chunk in model:
- for param in get_attr_wrapped_model(model_chunk, 'parameters')():
- if getattr(param, 'sequence_parallel', False):
- grad = param.main_grad
- grads.append(grad.data)
- coalesced = _flatten_dense_tensors(grads)
- torch.distributed.all_reduce(
- coalesced, group=parallel_state.get_tensor_model_parallel_group()
- )
- for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
- buf.copy_(synced)
-
-
-def _allreduce_expert_grads(model: List[torch.nn.Module], config: TransformerConfig):
- """
- All-reduce expert grads (for expert parallelism).
- """
-
- # All-reduce switchmlp parameters across data modulo expert parallel nodes
- if (
- config.expert_model_parallel_size > 1
- and config.expert_model_parallel_size < parallel_state.get_data_parallel_world_size()
- ):
- grads = []
- for model_chunk in model:
- for param in get_attr_wrapped_model(model_chunk, 'parameters')():
- if not getattr(param, 'allreduce', True):
- grad = param.main_grad
- grads.append(grad.data)
- coalesced = _flatten_dense_tensors(grads)
- torch.distributed.all_reduce(
- coalesced, group=parallel_state.get_data_modulo_expert_parallel_group()
- )
- for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
- buf.copy_(synced)
-
-
-def finalize_model_grads(model: List[torch.nn.Module]):
- """
- All-reduce all model grads across DP replicas, layernorm grads for sequence parallelism,
- embedding grads across first and last pipeline stages (if not tied), and expert grads
- for expert parallelism.
- """
-
- config = get_model_config(model[0])
-
- # All-reduce / reduce-scatter across DP replicas.
- if config.timers is not None:
- config.timers('all-grads-sync', log_level=1).start(barrier=config.barrier_with_L1_time)
- for model_chunk in model:
- model_chunk.finish_grad_sync()
- if config.timers is not None:
- config.timers('all-grads-sync').stop()
-
- # All-reduce layer-norm grads (for sequence parallelism).
- if config.timers is not None:
- config.timers('layernorm-grads-all-reduce', log_level=1).start(
- barrier=config.barrier_with_L1_time
- )
- _allreduce_layernorm_grads(model, config)
- if config.timers is not None:
- config.timers('layernorm-grads-all-reduce').stop()
-
- # All-reduce embedding grads (for pipeline parallelism).
- if config.timers is not None:
- config.timers('embedding-grads-all-reduce', log_level=1).start(
- barrier=config.barrier_with_L1_time
- )
- _allreduce_embedding_grads(model, config)
- if config.timers is not None:
- config.timers('embedding-grads-all-reduce').stop()
-
- # All-reduce expert grads (for expert parallelism).
- if config.timers is not None:
- config.timers('expert-grads-all-reduce', log_level=1).start(
- barrier=config.barrier_with_L1_time
- )
- _allreduce_expert_grads(model, config)
- if config.timers is not None:
- config.timers('expert-grads-all-reduce').stop()
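`finalize_model_grads()` bundles the end-of-iteration reductions: it calls `finish_grad_sync()` on every model chunk and then runs the layernorm, embedding, and expert all-reduces, so it belongs between the last backward pass and the optimizer step. A short sketch; `model_chunks` and `optimizer` are assumed from the surrounding training loop:

```python
from megatron.core.distributed import finalize_model_grads


def finish_iteration(model_chunks, optimizer):
    finalize_model_grads(model_chunks)  # DP grad sync + layernorm/embedding/expert all-reduces
    optimizer.step()
```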
diff --git a/megatron/core/distributed/grad_buffer.py b/megatron/core/distributed/grad_buffer.py
deleted file mode 100644
index 8bc88a8e710db31840c80444ae726f0b6bd6c1be..0000000000000000000000000000000000000000
--- a/megatron/core/distributed/grad_buffer.py
+++ /dev/null
@@ -1,410 +0,0 @@
-# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
-
-import math
-from logging import getLogger
-from typing import Dict, List
-
-import torch
-
-from .. import parallel_state
-
-logger = getLogger(__name__)
-
-
-def shard_buffer(buffer: torch.Tensor, data_parallel_world_size: int):
- """
- Shard buffer into data_parallel_world_size chunks of equal size.
- """
- assert buffer.numel() % data_parallel_world_size == 0
- shard_size = buffer.numel() // data_parallel_world_size
- sharded_buffer = [
- buffer[(r * shard_size) : ((r + 1) * shard_size)] for r in range(data_parallel_world_size)
- ]
- return sharded_buffer
-
-
-class Bucket:
- """
- Bucket to keep track of a subset of the model's gradients. Provides functionality to register
- when params in the bucket have grads ready to be synced; an asynchronous communication call
- is automatically launched when _all_ params in the bucket have grads ready.
-
- Arguments:
- params: List of parameters whose gradients are collated in this bucket.
- data: View in larger GradBuffer that this bucket is responsible for.
- offset: Offset of this bucket's view in the larger GradBuffer.
- data_parallel_group: Data-parallel process group.
- data_parallel_world_size: World size using the data-parallel group group.
- overlap_grad_reduce: If true, overlap communication with backprop computation by
- breaking up grads into buckets. If false, single synchronous communication call
- is used instead.
- use_distributed_optimizer: If true, issue reduce-scatter communication calls as part
- of distributed optimizer. If false, issue all-reduce communication calls.
- """
-
- def __init__(
- self,
- params: List[torch.nn.Parameter],
- data: torch.Tensor,
- offset: int,
- data_parallel_group: torch.distributed.ProcessGroup,
- data_parallel_world_size: int,
- overlap_grad_reduce: bool,
- use_distributed_optimizer: bool,
- ):
- # State for bookkeeping: params is the set of parameters this bucket is
- # responsible for, params_with_grad is the set of parameters with grads
- # available. When overlap_grad_reduce is True, communication (all-reduce
- # or reduce-scatter) is issued when params_with_grad equals params.
- self.params_list = params
- self.params = set(params)
- self.params_with_grad = set()
- self.data = data
- # The distributed optimizer needs to keep track of this bucket's offset
- # within the full grad_buffer.
- self.offset = offset
- self.data_parallel_group = data_parallel_group
- self.data_parallel_world_size = data_parallel_world_size
- self.data_parallel_rank = torch.distributed.get_rank(group=data_parallel_group)
- self.overlap_grad_reduce = overlap_grad_reduce
- self.use_distributed_optimizer = use_distributed_optimizer
-
- self.reset()
-
- def reset(self):
- """
- Reset metadata in bucket in preparation for the next iteration of training.
- """
- self.params_with_grad = set()
- self.communication_handle = None
- self.communication_issued = False
-
- def start_grad_sync(self):
- """
- Initiates grad sync (all-reduce or reduce-scatter) communication operation
- for this bucket.
-
- When overlap_grad_reduce is set to True, dispatches an asynchronous
- communication call. When overlap_grad_reduce is set to False, makes
- synchronous call.
- """
- assert (
- self.communication_handle is None and not self.communication_issued
- ), 'Should not have multiple communication calls in flight at once'
-
- self.data /= self.data_parallel_world_size
- # Use async_op only when overlap_grad_reduce is True.
- if self.use_distributed_optimizer:
- local_data_view = shard_buffer(self.data, self.data_parallel_world_size)[
- self.data_parallel_rank
- ]
- self.communication_handle = torch.distributed._reduce_scatter_base(
- local_data_view,
- self.data,
- group=self.data_parallel_group,
- async_op=self.overlap_grad_reduce,
- )
- else:
- self.communication_handle = torch.distributed.all_reduce(
- self.data, group=self.data_parallel_group, async_op=self.overlap_grad_reduce
- )
- self.communication_issued = True
-
- def finish_grad_sync(self):
- """
- Finishes grad sync (all-reduce or reduce-scatter) communication operation
- for this bucket.
-
- When overlap_grad_reduce is set to True, waits for asynchronous communication
- call to complete. When overlap_grad_reduce is set to False, makes synchronous call.
- """
- # If overlap_grad_reduce is False, start (and finish) synchronous communication call here.
- if not self.overlap_grad_reduce:
- self.start_grad_sync()
- return
- assert self.communication_handle is not None and self.communication_issued, (
- f'Communication call has not been issued for this bucket '
- f'({len(self.params_with_grad)}/{len(self.params)} params have grad available)'
- )
- self.communication_handle.wait()
-
- def register_grad_ready(self, param: torch.nn.Parameter):
- """
- Registers grads for the passed-in param to be "ready" for grad sync.
-
- When the number of microbatches is greater than 1, we only want to register
- grads as ready when processing the last microbatch and overlap_grad_reduce is True.
- """
- assert param in self.params, 'Param is not in the bucket'
- assert param not in self.params_with_grad, 'Cannot set grad twice'
- assert (
- self.overlap_grad_reduce
- ), 'register_grad_ready() should be called only when overlapping grad reduce'
- self.params_with_grad.add(param)
- # If all params in bucket have grads available, issue communication call.
- if len(self.params_with_grad) == len(self.params):
- self.start_grad_sync()
-
-
-class GradBuffer:
- """
- Groups gradients into a contiguous buffer, and then breaks the buffer into buckets with
- roughly `bucket_size` parameters each.
-
- Arguments:
- dtype: Type of underlying tensor.
- params: List of parameters whose gradients are collated in the underlying tensor.
- data_parallel_group: Data-parallel process group.
- bucket_size: The rough size of each bucket in terms of number of parameters.
- param_to_name: Mapping from `torch.nn.Parameter` to name (for logging purposes).
- overlap_grad_reduce: If true, overlap communication with backprop computation by
- breaking up grads into buckets. If false, a single synchronous communication call
- is used instead.
- use_distributed_optimizer: If true, issue reduce-scatter communication calls as part
- of the distributed optimizer. If false, issue all-reduce communication calls.
- """
-
- def __init__(
- self,
- dtype: torch.dtype,
- params: List[torch.nn.Parameter],
- data_parallel_group: torch.distributed.ProcessGroup,
- bucket_size: int,
- param_to_name: Dict[torch.nn.Parameter, str],
- overlap_grad_reduce: bool,
- use_distributed_optimizer: bool,
- ):
-
- # Check that params are unique.
- unique_params = set()
- for param in params:
- assert param not in unique_params
- unique_params.add(param)
- del unique_params
-
- # Store attributes that will be needed later.
- self.dtype = dtype
- self.data_parallel_group = data_parallel_group
- self.data_parallel_world_size = torch.distributed.get_world_size(
- group=self.data_parallel_group
- )
- self.overlap_grad_reduce = overlap_grad_reduce
- self.use_distributed_optimizer = use_distributed_optimizer
- self.is_last_microbatch = True
-
- # Data structures to store underlying buckets and relevant indexing data.
- self.buckets = []
- self.param_to_bucket = {} # Param -> bucket mapping.
- self.param_index_map = {} # Param -> location in buffer mapping (used in dist. optimizer).
-
- def _pad_if_needed(data_index: int):
- """Pads data indices if using distributed optimizer (to ensure uniform sharding)."""
- if use_distributed_optimizer:
- return (
- int(math.ceil(data_index / self.data_parallel_world_size))
- * self.data_parallel_world_size
- )
- return data_index
-
- # First, figure out how many elements should be in the underlying buffer storage.
- # Note that if we need to split the buffer into smaller buckets, each of these
- # might need to be padded as well (if using the distributed optimizer).
- data_start_index = 0
- bucket_data_start_index = data_start_index
- bucket_params = set()
- self.bucket_indices = []
- bucket_id = 0
- for param in params[::-1]:
- # Iterate through parameters in reverse order to roughly follow backprop order,
- # and skip parameters that don't require gradients.
- if not param.requires_grad:
- continue
- this_numel = param.data.nelement()
- data_end_index = data_start_index + this_numel
- self.param_index_map[param] = (
- data_start_index,
- data_end_index,
- bucket_id,
- )
- bucket_params.add(param)
-
- # If we have enough elements already, form a new bucket.
- # If bucket_size is None, accumulate everything into a single bucket.
-
- # TODO: Remove len(bucket_params) > 1 when the final head that transforms token
- # representations from hidden space to vocabulary space is in a PyTorch module
- # whose forward method is called. If it is not and a bucket contains only this
- # one parameter, we get incorrect behavior (i.e., higher losses) since we do not
- # call the wait function on the bucket's all_gather_handle (we use forward pre-
- # hooks on PyTorch modules to do this when --overlap-param-gather is used).
- # As a temporary workaround, we make sure that no bucket has only one parameter.
- if bucket_size is not None:
- if (data_end_index - bucket_data_start_index) >= bucket_size and len(
- bucket_params
- ) > 1:
- data_end_index = _pad_if_needed(data_end_index)
- self.bucket_indices.append((bucket_data_start_index, data_end_index))
- bucket_data_start_index = data_end_index
- bucket_params = set()
- bucket_id += 1
- data_start_index = data_end_index
-
- # Add remaining params to a new bucket.
- if len(bucket_params) > 0:
- data_end_index = _pad_if_needed(data_end_index)
- self.bucket_indices.append((bucket_data_start_index, data_end_index))
-
- # Next, create underlying storage for buffer (with numel elements that includes
- # padding as necessary).
- self.numel = data_end_index
- if use_distributed_optimizer:
- assert self.numel % self.data_parallel_world_size == 0
- self.data = torch.zeros(
- self.numel, dtype=self.dtype, device=torch.cuda.current_device(), requires_grad=False,
- )
-
- # Finally, map main_grad fields for each parameter with a .grad field.
- bucket_params = set()
- bucket_data_start_index = 0
- cur_bucket_id = 0
- for param in params[::-1]:
- if not param.requires_grad:
- continue
- data_start_index, data_end_index, bucket_id = self.param_index_map[param]
- param.main_grad = self._get(param.data.shape, data_start_index)
- if bucket_id != cur_bucket_id:
- bucket_data_end_index = _pad_if_needed(data_start_index)
- self._set_bucket(
- bucket_params, bucket_data_start_index, bucket_data_end_index, cur_bucket_id
- )
- bucket_data_start_index = bucket_data_end_index
- bucket_params = set()
- assert cur_bucket_id + 1 == len(self.buckets)
- assert bucket_id == cur_bucket_id + 1
- cur_bucket_id = bucket_id
- bucket_params.add(param)
-
- # Add remaining params to a new bucket.
- if len(bucket_params) > 0:
- bucket_data_end_index = _pad_if_needed(data_end_index)
- self._set_bucket(
- bucket_params, bucket_data_start_index, bucket_data_end_index, cur_bucket_id
- )
-
- if not overlap_grad_reduce:
- assert len(bucket_params) == len(
- params
- ), 'All params should be in one bucket when overlap_grad_reduce is False'
-
- # Log buckets for all PP stages.
- if (
- parallel_state.get_data_parallel_rank(with_context_parallel=True) == 0
- and parallel_state.get_tensor_model_parallel_rank() == 0
- ):
- logger.info(
- f'Number of buckets for gradient all-reduce / reduce-scatter: {len(self.buckets)}'
- )
- for index, bucket in enumerate(self.buckets):
- numel = 0
- for param in bucket.params:
- numel += param.data.nelement()
- logger.info(f'Params for bucket {index+1} ({numel} elements):')
- for param in bucket.params:
- logger.info(f' {param_to_name[param]}')
-
- def _get(self, shape: torch.Size, start_index: int) -> torch.Tensor:
- """
- Return a tensor with the input `shape` as a view into the 1-D data starting at
- `start_index`.
- """
- end_index = start_index + shape.numel()
- assert end_index <= self.numel, 'Requested tensor is out of buffer range'
- buffer_tensor = self.data[start_index:end_index]
- buffer_tensor = buffer_tensor.view(shape)
- return buffer_tensor
-
- def _set_bucket(
- self,
- bucket_params: List[torch.nn.Parameter],
- start_index: int,
- end_index: int,
- bucket_id: int,
- ):
- """
- Helper function to create new bucket, add it to list of buckets, and
- also update param->bucket mapping.
- """
-
- # Assert that indices are correctly padded (if needed), and that bucket
- # position is same as originally computed.
- if self.use_distributed_optimizer:
- assert start_index % self.data_parallel_world_size == 0
- assert end_index % self.data_parallel_world_size == 0
- assert (start_index, end_index) == self.bucket_indices[bucket_id]
-
- # Get appropriate view into global GradBuffer.
- bucket_data = self._get(torch.Size([end_index - start_index]), start_index)
- bucket = Bucket(
- params=bucket_params,
- data=bucket_data,
- offset=start_index,
- data_parallel_group=self.data_parallel_group,
- data_parallel_world_size=self.data_parallel_world_size,
- overlap_grad_reduce=self.overlap_grad_reduce,
- use_distributed_optimizer=self.use_distributed_optimizer,
- )
- self.buckets.append(bucket)
- for bucket_param in bucket_params:
- assert bucket_param not in self.param_to_bucket
- self.param_to_bucket[bucket_param] = bucket
-
- def reset(self, zero_buffer):
- """
- Zero out the underlying buffer and reset all buckets in preparation for the next
- iteration of training.
-
- When zero_buffer is set to True, the underlying buffer is zeroed out.
- """
- if zero_buffer:
- self.data.zero_()
- for bucket in self.buckets:
- bucket.reset()
- self.is_last_microbatch = True
-
- def start_grad_sync(self):
- """
- Initiates grad sync (all-reduce or reduce-scatter) communication operations
- for all buckets in the grad buffer.
-
- When overlap_grad_reduce is set to True, dispatches asynchronous communication
- calls. When overlap_grad_reduce is set to False, calls synchronous
- communication ops.
- """
- for bucket in self.buckets:
- bucket.start_grad_sync()
-
- def finish_grad_sync(self):
- """
- Finishes grad sync (all-reduce or reduce-scatter) communication operations
- for all buckets in the grad buffer.
-
- When overlap_grad_reduce is set to True, waits for asynchronous communication
- calls to complete. When overlap_grad_reduce is set to False, calls synchronous
- communication ops.
- """
- for bucket in self.buckets:
- bucket.finish_grad_sync()
-
- def register_grad_ready(self, param: torch.nn.Parameter):
- """
- Registers grads for the passed-in param to be "ready" for grad sync.
-
- When the number of microbatches is greater than 1, we only want to register
- grads as ready when processing the last microbatch and overlap_grad_reduce is True.
- """
- assert (
- self.overlap_grad_reduce
- ), 'register_grad_ready() should only be called when overlap_grad_reduce is True'
- if self.is_last_microbatch:
- bucket = self.param_to_bucket[param]
- bucket.register_grad_ready(param)
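
For readers skimming the bucketing logic above: the sketch below is a single-process illustration (no process group is created, and `shard_flat_buffer` is a hypothetical stand-in for `shard_buffer`) of how a padded flat gradient buffer is averaged and split into equal per-rank views before the reduce-scatter issued by `Bucket.start_grad_sync`.

```python
import torch

def shard_flat_buffer(buffer: torch.Tensor, world_size: int):
    # Mirrors the padding invariant enforced by _pad_if_needed: the buffer
    # length must be a multiple of the data-parallel world size.
    assert buffer.numel() % world_size == 0
    shard_numel = buffer.numel() // world_size
    return [buffer[r * shard_numel:(r + 1) * shard_numel] for r in range(world_size)]

world_size = 4
# Pretend these are a few parameters' gradients packed into one contiguous buffer.
grads = torch.arange(10, dtype=torch.float32)
padded_numel = ((grads.numel() + world_size - 1) // world_size) * world_size
flat_buffer = torch.zeros(padded_numel)
flat_buffer[:grads.numel()] = grads

# Average before communication, as Bucket.start_grad_sync does.
flat_buffer /= world_size
local_views = shard_flat_buffer(flat_buffer, world_size)
print([view.numel() for view in local_views])  # equal-sized shards, one per rank
```
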
diff --git a/megatron/core/enums.py b/megatron/core/enums.py
deleted file mode 100644
index 46e7d3b766af061cd36363f8486f75f7ad80b08f..0000000000000000000000000000000000000000
--- a/megatron/core/enums.py
+++ /dev/null
@@ -1,10 +0,0 @@
-# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
-
-import enum
-
-
-class ModelType(enum.Enum):
- encoder_or_decoder = 1
- encoder_and_decoder = 2
- retro_encoder = 3
- retro_decoder = 4
diff --git a/megatron/core/fusions/__init__.py b/megatron/core/fusions/__init__.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/megatron/core/fusions/fused_bias_dropout.py b/megatron/core/fusions/fused_bias_dropout.py
deleted file mode 100644
index 14c1fe0d718223ba78830cf3099ac02907e65fc2..0000000000000000000000000000000000000000
--- a/megatron/core/fusions/fused_bias_dropout.py
+++ /dev/null
@@ -1,71 +0,0 @@
-# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
-from typing import Optional, Tuple
-
-import torch
-
-
-def _bias_dropout_add_func(x_with_bias, residual, prob, training):
- # type: (Tuple[Tensor, Optional[Tensor]], Tensor, float, bool) -> Tensor
- # NOTE: Previously, the argument `bias` used to be passed as
- # `bias.expand_as(residual)` when the `bias_dropout_func` is called from the
- # transformer layer but broadcasting should automatically take care of that.
- # Also, looking at broadcasting semantics, `expand_as` and broadcasting
- # seem to be identical performance-wise (both just change the view).
-
- x, bias = x_with_bias # unpack
-
- # If we want to train mixed precision, then the output of this function
- # should be half precision. However, in AMP O1, the input (residual) is
- # in fp32, and it will up-cast the result to fp32, causing pipeline parallel
- # GPU communication to hang. Therefore, we need to cast residual to the same
- # dtype as x.
- residual = residual if residual.dtype == x.dtype else residual.to(x.dtype)
-
- # The Dropout operation, Residual Addition and the tensor returning can be
- # done generically outside the if statement, but that stops fusing of Bias
- # Addition-Dropout-Residual Addition operation. So doing it together inside
- # the conditional branch to improve performance
- if bias is not None:
- x = x + bias
- out = torch.nn.functional.dropout(x, p=prob, training=training)
- out = residual + out
- return out
- else:
- out = torch.nn.functional.dropout(x, p=prob, training=training)
- out = residual + out
- return out
-
-
-def bias_dropout_add_unfused(training):
- def _bias_dropout_add(x_with_bias, residual, prob):
- return _bias_dropout_add_func(x_with_bias, residual, prob, training)
-
- return _bias_dropout_add
-
-
-@torch.jit.script
-def bias_dropout_add_fused_train(
- x_with_bias: Tuple[torch.Tensor, Optional[torch.Tensor]], residual: torch.Tensor, prob: float,
-) -> torch.Tensor:
- return _bias_dropout_add_func(x_with_bias, residual, prob, True)
-
-
-@torch.jit.script
-def bias_dropout_add_fused_inference(
- x_with_bias: Tuple[torch.Tensor, Optional[torch.Tensor]], residual: torch.Tensor, prob: float,
-) -> torch.Tensor:
- return _bias_dropout_add_func(x_with_bias, residual, prob, False)
-
-
-def get_bias_dropout_add(training, fused):
- if fused:
- # jit scripting for a nn.module (with dropout) is not
- # triggering the fusion kernel. For now, we use two
- # different nn.functional routines to account for varying
- # dropout semantics during training and inference phases.
- if training:
- return bias_dropout_add_fused_train
- else:
- return bias_dropout_add_fused_inference
- else:
- return bias_dropout_add_unfused(training)
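
A quick sanity check of the pattern fused above (a standalone re-statement, not an import of the deleted module): with dropout probability 0 the operation reduces to `residual + x + bias`, which makes the fused and unfused paths easy to compare.

```python
import torch

def bias_dropout_add_ref(x_with_bias, residual, prob, training):
    # Reference form of the pattern: out = residual + dropout(x + bias).
    x, bias = x_with_bias
    if bias is not None:
        x = x + bias
    return residual + torch.nn.functional.dropout(x, p=prob, training=training)

x = torch.randn(4, 8)
bias = torch.randn(8)
residual = torch.randn(4, 8)

# With prob=0.0, dropout is the identity, so the output is exactly residual + x + bias.
out = bias_dropout_add_ref((x, bias), residual, prob=0.0, training=True)
assert torch.allclose(out, residual + x + bias)
```
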
diff --git a/megatron/core/fusions/fused_bias_gelu.py b/megatron/core/fusions/fused_bias_gelu.py
deleted file mode 100644
index 9c791c180765b99c49e78dedf63444b57fed5ec1..0000000000000000000000000000000000000000
--- a/megatron/core/fusions/fused_bias_gelu.py
+++ /dev/null
@@ -1,48 +0,0 @@
-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-
-import torch
-
-###### BIAS GELU FUSION/ NO AUTOGRAD ################
-# 1/sqrt(2*pi)-> 0.3989423
-# 1/sqrt(2) -> 0.70710678
-# sqrt(2/pi) -> 0.79788456
-# this function is tanh approximation of gelu
-# actual gelu is:
-# x * 0.5 * (1.0 + torch.erf(x * 0.70710678))
-
-
-@torch.jit.script
-def bias_gelu(bias, y):
- x = bias + y
- return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
-
-
-# gradient of tanh approximation of gelu
-# gradient of actual gelu is:
-# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
-@torch.jit.script
-def bias_gelu_back(g, bias, y):
- x = bias + y
- tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
- # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
- ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (
- 1 + tanh_out
- )
- return ff * g
-
-
-class GeLUFunction(torch.autograd.Function):
- @staticmethod
- # bias is an optional argument
- def forward(ctx, input, bias):
- ctx.save_for_backward(input, bias)
- return bias_gelu(bias, input)
-
- @staticmethod
- def backward(ctx, grad_output):
- input, bias = ctx.saved_tensors
- tmp = bias_gelu_back(grad_output, bias, input)
- return tmp, tmp
-
-
-bias_gelu_impl = GeLUFunction.apply
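
To confirm that the constants above implement the tanh approximation of GELU, the following check (illustrative; assumes a PyTorch version that supports the `approximate` keyword, i.e. 1.12 or later) compares it against `torch.nn.functional.gelu`.

```python
import torch

def bias_gelu_ref(bias, y):
    # Same tanh approximation as the jit-scripted forward kernel above.
    x = bias + y
    return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))

y = torch.randn(16, 32)
bias = torch.randn(32)
ours = bias_gelu_ref(bias, y)
reference = torch.nn.functional.gelu(y + bias, approximate='tanh')
print(torch.max((ours - reference).abs()))  # expected to be on the order of 1e-7
```
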
diff --git a/megatron/core/fusions/fused_layer_norm.py b/megatron/core/fusions/fused_layer_norm.py
deleted file mode 100644
index c12ec173d0aabf9548f00bbdd4d1cbadb4d1a9e3..0000000000000000000000000000000000000000
--- a/megatron/core/fusions/fused_layer_norm.py
+++ /dev/null
@@ -1,151 +0,0 @@
-# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
-
-import importlib
-import numbers
-
-import torch
-from torch import Tensor
-from torch.nn import init
-from torch.nn.parameter import Parameter
-
-from megatron.core.transformer import TransformerConfig
-from megatron.core.utils import make_viewless_tensor
-
-try:
- from apex.contrib.layer_norm.layer_norm import FastLayerNormFN
-
- HAVE_PERSIST_LAYER_NORM = True
-except ImportError:
- HAVE_PERSIST_LAYER_NORM = False
-
-try:
- from apex.normalization.fused_layer_norm import FusedLayerNormAffineFunction
-
- HAVE_FUSED_LAYER_NORM = True
-except ImportError:
- HAVE_FUSED_LAYER_NORM = False
-
-
-class FusedLayerNorm(torch.nn.Module):
-
- """Layer Norm, fused into a single CUDA kernel.
-
- Arguments:
- hidden_size (int): Transformer hidden dimension.
-
- eps (float): Epsilon added to denominator, for numerical stability.
-
- persist_layer_norm (bool): Use persistent fused layer norm kernel.
- This kernel supports only a fixed set of hidden sizes. Please
- check persist_ln_hidden_sizes to see if your hidden size is supported.
-
- sequence_parallel (bool): Apply sequence parallelism optimization.
-
- zero_centered_gamma (bool): Adjust LayerNorm weights such that they are
- centered around zero. This improves numerical stability.
-
- config (TransformerConfig): Transformer config. Include to match custom
- layer norm interfaces.
-
- normalization (str): Normalization type, used for Transformer Engine.
- Must equal 'LayerNorm' here.
- """
-
- def __init__(
- self,
- config: TransformerConfig,
- hidden_size: int,
- eps: float = 1e-5,
- persist_layer_norm: bool = True,
- sequence_parallel: bool = False,
- zero_centered_gamma: bool = False,
- normalization: str = "LayerNorm", # included to match TE interface
- ):
- super().__init__()
-
- self.zero_centered_gamma = config.layernorm_zero_centered_gamma
- assert (
- config.normalization == "LayerNorm"
- ), f'({config.normalization}) is not supported in FusedLayerNorm'
-
- # List of hidden sizes supported in the persistent layer norm kernel
- # If the hidden size is not supported, fall back to the non-persistent
- # kernel.
- persist_ln_hidden_sizes = [
- 1024,
- 1536,
- 2048,
- 2304,
- 3072,
- 3840,
- 4096,
- 5120,
- 6144,
- 8192,
- 10240,
- 12288,
- 12800,
- 15360,
- 16384,
- 18432,
- 20480,
- 24576,
- 25600,
- 30720,
- 32768,
- 40960,
- 49152,
- 65536,
- ]
- persist_layer_norm = config.persist_layer_norm
- if hidden_size not in persist_ln_hidden_sizes or not HAVE_PERSIST_LAYER_NORM:
- persist_layer_norm = False
-
- if not persist_layer_norm and not HAVE_FUSED_LAYER_NORM:
- # TODO: Add pytorch only layer norm
- raise ValueError('Apex must currently be installed to use megatron core.')
-
- if isinstance(hidden_size, numbers.Integral):
- hidden_size = (hidden_size,)
- self.hidden_size = torch.Size(hidden_size)
- self.eps = eps
- self.weight = Parameter(torch.Tensor(*hidden_size))
- self.bias = Parameter(torch.Tensor(*hidden_size))
- self.reset_parameters()
- self.persist_layer_norm = persist_layer_norm
- self.sequence_parallel = config.sequence_parallel
-
- # set sequence parallelism flag on weight and bias parameters
- setattr(self.weight, 'sequence_parallel', self.sequence_parallel)
- setattr(self.bias, 'sequence_parallel', self.sequence_parallel)
-
- def reset_parameters(self):
-
- if self.zero_centered_gamma:
- init.zeros_(self.weight)
- init.zeros_(self.bias)
- else:
- init.ones_(self.weight)
- init.zeros_(self.bias)
-
- def forward(self, input: Tensor) -> Tensor:
-
- weight = self.weight + 1 if self.zero_centered_gamma else self.weight
-
- if self.persist_layer_norm:
- output = FastLayerNormFN.apply(input, weight, self.bias, self.eps)
-
- # Apex's fast layer norm function outputs a 'view' tensor (i.e., has
- # a populated '_base' field). This will result in schedule.py's
- # deallocate_output_tensor() throwing an error, so a viewless tensor is
- # created to prevent this.
- output = make_viewless_tensor(
- inp=output, requires_grad=input.requires_grad, keep_graph=True
- )
-
- else:
- output = FusedLayerNormAffineFunction.apply(
- input, weight, self.bias, self.hidden_size, self.eps
- )
-
- return output
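
A minimal functional check of the `zero_centered_gamma` convention used in `forward()` above (illustrative; it uses `torch.nn.functional.layer_norm` instead of the Apex kernels): storing the weight centered around zero and shifting it by `+1` at call time matches a standard LayerNorm whose gamma is initialized to ones.

```python
import torch

hidden = 1024
x = torch.randn(2, 4, hidden)
zero_centered_weight = torch.zeros(hidden)  # what reset_parameters() stores
bias = torch.zeros(hidden)

out_shifted = torch.nn.functional.layer_norm(x, (hidden,), zero_centered_weight + 1, bias, eps=1e-5)
out_standard = torch.nn.functional.layer_norm(x, (hidden,), torch.ones(hidden), bias, eps=1e-5)
assert torch.allclose(out_shifted, out_standard)
```
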
diff --git a/megatron/core/fusions/fused_softmax.py b/megatron/core/fusions/fused_softmax.py
deleted file mode 100644
index 56eb2e80111a1ff2eb0af95ea8cfb4c70a2ab9f0..0000000000000000000000000000000000000000
--- a/megatron/core/fusions/fused_softmax.py
+++ /dev/null
@@ -1,204 +0,0 @@
-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-
-
-import torch
-import torch.nn as nn
-
-from megatron.core.transformer.enums import AttnMaskType
-
-
-class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
- """
- Fused operation which performs following three operations in sequence
- 1. Scale the tensor.
- 2. Apply upper triangular mask (typically used in gpt models).
- 3. Perform softmax.
- """
-
- @staticmethod
- def forward(ctx, inputs, scale):
- import scaled_upper_triang_masked_softmax_cuda
-
- scale_t = torch.tensor([scale])
- softmax_results = scaled_upper_triang_masked_softmax_cuda.forward(inputs, scale_t[0])
-
- ctx.save_for_backward(softmax_results, scale_t)
- return softmax_results
-
- @staticmethod
- def backward(ctx, output_grads):
- import scaled_upper_triang_masked_softmax_cuda
-
- softmax_results, scale_t = ctx.saved_tensors
- input_grads = scaled_upper_triang_masked_softmax_cuda.backward(
- output_grads, softmax_results, scale_t[0]
- )
-
- return input_grads, None
-
-
-class ScaledMaskedSoftmax(torch.autograd.Function):
- """
- Fused operation which performs following three operations in sequence
- 1. Scale the tensor.
- 2. Apply the mask.
- 3. Perform softmax.
- """
-
- @staticmethod
- def forward(ctx, inputs, mask, scale):
- import scaled_masked_softmax_cuda
-
- scale_t = torch.tensor([scale])
-
- softmax_results = scaled_masked_softmax_cuda.forward(inputs, mask, scale_t[0])
- ctx.save_for_backward(softmax_results, scale_t)
- return softmax_results
-
- @staticmethod
- def backward(ctx, output_grads):
- import scaled_masked_softmax_cuda
-
- softmax_results, scale_t = ctx.saved_tensors
-
- input_grads = scaled_masked_softmax_cuda.backward(output_grads, softmax_results, scale_t[0])
- return input_grads, None, None
-
-
-class ScaledSoftmax(torch.autograd.Function):
- """
- Fused operation which performs following two operations in sequence
- 1. Scale the tensor.
- 2. Perform softmax.
- """
-
- @staticmethod
- def forward(ctx, inputs, scale):
- import scaled_softmax_cuda
-
- scale_t = torch.tensor([scale])
-
- softmax_results = scaled_softmax_cuda.forward(inputs, scale_t[0])
- ctx.save_for_backward(softmax_results, scale_t)
- return softmax_results
-
- @staticmethod
- def backward(ctx, output_grads):
- import scaled_softmax_cuda
-
- softmax_results, scale_t = ctx.saved_tensors
-
- input_grads = scaled_softmax_cuda.backward(output_grads, softmax_results, scale_t[0])
- return input_grads, None, None
-
-
-class FusedScaleMaskSoftmax(nn.Module):
- """
- fused operation: scaling + mask + softmax
-
- Arguments:
- input_in_fp16: flag to indicate if input in fp16 data format.
- input_in_bf16: flag to indicate if input in bf16 data format.
- attn_mask_type: attention mask type (pad or causal)
- scaled_masked_softmax_fusion: flag to indicate the user wants to use softmax fusion
- mask_func: mask function to be applied.
- softmax_in_fp32: if true, softmax is performed in fp32 precision.
- scale: scaling factor used in input tensor scaling.
- """
-
- def __init__(
- self,
- input_in_fp16,
- input_in_bf16,
- attn_mask_type,
- scaled_masked_softmax_fusion,
- mask_func,
- softmax_in_fp32,
- scale,
- ):
- super(FusedScaleMaskSoftmax, self).__init__()
- self.input_in_fp16 = input_in_fp16
- self.input_in_bf16 = input_in_bf16
- assert not (
- self.input_in_fp16 and self.input_in_bf16
- ), "both fp16 and bf16 flags cannot be active at the same time."
- self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16
- self.attn_mask_type = attn_mask_type
- self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
- self.mask_func = mask_func
- self.softmax_in_fp32 = softmax_in_fp32
- self.scale = scale
-
- assert self.scale is None or softmax_in_fp32, "softmax should be in fp32 when scaled"
-
- def forward(self, input, mask):
- # [b, np, sq, sk]
- assert input.dim() == 4
-
- if self.is_kernel_available(mask, *input.size()):
- return self.forward_fused_softmax(input, mask)
- else:
- return self.forward_torch_softmax(input, mask)
-
- def is_kernel_available(self, mask, b, np, sq, sk):
- attn_batches = b * np
-
- if (
- self.scaled_masked_softmax_fusion  # user wants to fuse
- and self.input_in_float16  # input must be fp16 or bf16
- and 16 < sk <= 4096  # sk must be in (16, 4096]
- and sq % 4 == 0  # sq must be divisible by 4
- and sk % 4 == 0  # sk must be divisible by 4
- and attn_batches % 4 == 0  # np * b must be divisible by 4
- ):
- if 0 <= sk <= 4096:
- batch_per_block = self.get_batch_per_block(sq, sk, b, np)
-
- if self.attn_mask_type == AttnMaskType.causal:
- if attn_batches % batch_per_block == 0:
- return True
- else:
- if sq % batch_per_block == 0:
- return True
- return False
-
- def forward_fused_softmax(self, input, mask):
- b, np, sq, sk = input.size()
- scale = self.scale if self.scale is not None else 1.0
-
- if self.attn_mask_type == AttnMaskType.causal:
- assert sq == sk, "causal mask is only for self attention"
-
- # input is 3D tensor (attn_batches, sq, sk)
- input = input.view(-1, sq, sk)
- probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale)
- return probs.view(b, np, sq, sk)
- else:
- # input is 4D tensor (b, np, sq, sk)
- if mask is not None:
- return ScaledMaskedSoftmax.apply(input, mask, scale)
- else:
- return ScaledSoftmax.apply(input, scale)
-
- def forward_torch_softmax(self, input, mask):
- if self.input_in_float16 and self.softmax_in_fp32:
- input = input.float()
-
- if self.scale is not None:
- input = input * self.scale
- mask_output = self.mask_func(input, mask) if mask is not None else input
- probs = torch.nn.Softmax(dim=-1)(mask_output)
-
- if self.input_in_float16 and self.softmax_in_fp32:
- if self.input_in_fp16:
- probs = probs.half()
- else:
- probs = probs.bfloat16()
-
- return probs
-
- @staticmethod
- def get_batch_per_block(sq, sk, b, np):
- import scaled_masked_softmax_cuda
-
- return scaled_masked_softmax_cuda.get_batch_per_block(sq, sk, b, np)
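
The unfused fallback (`forward_torch_softmax`) boils down to the steps below. This sketch uses `masked_fill` as a stand-in for the configurable `mask_func` and runs entirely on CPU.

```python
import torch

def masked_softmax_ref(scores_fp16, mask, scale):
    # Upcast to fp32, apply scale and mask, softmax, then cast back to fp16.
    scores = scores_fp16.float() * scale
    scores = scores.masked_fill(mask, float('-inf'))
    return torch.softmax(scores, dim=-1).half()

b, heads, sq, sk = 2, 4, 8, 8
scores = torch.randn(b, heads, sq, sk, dtype=torch.float16)
# Causal mask: True marks positions that must not be attended to.
mask = torch.triu(torch.ones(sq, sk, dtype=torch.bool), diagonal=1).expand(b, heads, sq, sk)
probs = masked_softmax_ref(scores, mask, scale=1.0 / sk ** 0.5)
print(probs.sum(dim=-1))  # each attention row sums to ~1
```
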
diff --git a/megatron/core/inference_params.py b/megatron/core/inference_params.py
deleted file mode 100644
index 287902460fab6d411781fb15c86f0a333b7cf245..0000000000000000000000000000000000000000
--- a/megatron/core/inference_params.py
+++ /dev/null
@@ -1,27 +0,0 @@
-class InferenceParams:
- """Inference parameters that are passed to the main model in order
- to efficiently calculate and store the context during inference."""
-
- def __init__(self, max_batch_size, max_sequence_length):
- self.max_sequence_length = max_sequence_length
- self.max_batch_size = max_batch_size
- self.sequence_len_offset = 0
- self.batch_size_offset = 0
- self.key_value_memory_dict = {}
-
- def swap_key_value_dict(self, batch_idx):
- "swap between batches"
- if len(self.key_value_memory_dict) == 0:
- raise ValueError("should not swap when dict in empty")
-
- for layer_number in self.key_value_memory_dict.keys():
- inference_key_memory, inference_value_memory = self.key_value_memory_dict[layer_number]
- assert (
- len(batch_idx) == inference_key_memory.shape[1]
- ) # make sure batch size is the same
- new_inference_key_memory = inference_key_memory[:, batch_idx]
- new_inference_value_memory = inference_value_memory[:, batch_idx]
- self.key_value_memory_dict[layer_number] = (
- new_inference_key_memory,
- new_inference_value_memory,
- )
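
For context, a self-contained sketch of the bookkeeping that this class supports during incremental decoding. The `[max_seq_len, batch, heads, head_dim]` memory layout is an assumption made for illustration; the code above only implies that the batch dimension is `shape[1]`.

```python
import torch

max_seq_len, batch, heads, head_dim = 16, 2, 4, 8
kv_memory = {
    layer: (
        torch.zeros(max_seq_len, batch, heads, head_dim),  # key cache
        torch.zeros(max_seq_len, batch, heads, head_dim),  # value cache
    )
    for layer in range(2)
}

seq_offset = 0
new_k = torch.randn(1, batch, heads, head_dim)  # keys for one generated token
new_v = torch.randn(1, batch, heads, head_dim)
for layer, (k_mem, v_mem) in kv_memory.items():
    k_mem[seq_offset:seq_offset + 1] = new_k
    v_mem[seq_offset:seq_offset + 1] = new_v
seq_offset += 1  # cf. sequence_len_offset

# Reordering batches (cf. swap_key_value_dict) is fancy indexing along dim 1.
batch_idx = torch.tensor([1, 0])
reordered = {layer: (k[:, batch_idx], v[:, batch_idx]) for layer, (k, v) in kv_memory.items()}
print(seq_offset, reordered[0][0].shape)
```
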
diff --git a/megatron/core/model_parallel_config.py b/megatron/core/model_parallel_config.py
deleted file mode 100644
index 22d34da92129ce2193e5c297bc9ec8a6f3e555ff..0000000000000000000000000000000000000000
--- a/megatron/core/model_parallel_config.py
+++ /dev/null
@@ -1,222 +0,0 @@
-# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
-
-from dataclasses import dataclass
-from typing import Callable, Optional
-
-import torch
-
-
-@dataclass
-class ModelParallelConfig:
- """Base configuration for Megatron Core
-
- Model Parallelism
- -----------------
-
- tensor_model_parallel_size (int): Intra-layer model parallelism. Splits tensors across GPU ranks. Defaults to 1.
-
- context_parallel_size (int): Splits network input along sequence dimension across GPU ranks. Defaults to 1.
-
- pipeline_model_parallel_size (int): Inter-layer model parallelism. Splits transformer layers across GPU
- ranks. Defaults to 1.
-
- virtual_pipeline_model_parallel_size (int): Interleaved pipeline parallelism is used to improve performance by
- reducing the pipeline bubble. Considers a transformer block as a list of smaller transformer (virtual) blocks.
- The number of virtual blocks per pipeline model parallel rank is the virtual model parallel size. See Efficient
- Large-Scale Language Model Training on GPU Clusters Using Megatron-LM: https://arxiv.org/pdf/2104.04473.pdf for
- more details. Defaults to None.
-
- sequence_parallel (bool): Makes tensor parallelism more memory efficient for LLMs (20B+) by
- parallelizing layer norms and dropout sequentially. See Reducing Activation Recomputation in Large Transformer
- Models: https://arxiv.org/abs/2205.05198 for more details. Defaults to False.
-
- expert_model_parallel_size (int): Distributes MoE experts across the sub-data-parallel dimension. Defaults to 1.
-
- Initialization
- --------------
-
- perform_initialization (bool, default=True): If true, weights are initialized. This option can be useful when you
- know you are going to load values from a checkpoint.
-
- use_cpu_initialization (bool, default=False): When set to False, we initialize the weights directly on the GPU.
- Transferring weights from CPU to GPU can take a significant amount of time for large models. Defaults to False.
-
- Training
- --------
-
- fp16 (bool): If true, train with fp16 mixed precision training. Defaults to False.
-
- bf16 (bool): If true, train with bf16 mixed precision training. Defaults to False.
-
- params_dtype (torch.dtype): dtype used when initializing the weights. Defaults to torch.float32.
-
- timers (optional, default=None): TODO
-
- Optimizations
- -------------
-
- gradient_accumulation_fusion (bool): If true, fuses weight gradient accumulation to GEMMs. Requires the custom CUDA
- extension fused_weight_gradient_mlp_cuda module. To use gradient_accumulation_fusion you must install APEX with
- --cpp_ext and --cuda_ext. For example: "pip install --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext\"
- ". Note that the extension requires CUDA>=11. Otherwise, you must turn off gradient accumulation fusion.
- Defaults to False.
-
- async_tensor_model_parallel_allreduce (bool, default=False): If true, enables asynchronous execution of
- tensor-model-parallel all-reduce with the weight gradient computation of a column-parallel linear layer. Defaults to False.
-
- tp_comm_overlap (bool, default=False): If true, allows overlapping of Linear layer execution with tensor parallel
- communication collectives like AllGather/ReduceScatter. Overlapping is done for the linear layers wherever possible
- during the forward and the backward pass. Defaults to False.
-
- tp_comm_split_ag (bool, default=True): If true, allows All-Gather overlap with Fprop GEMM. Ignored if tp_comm_overlap
- is False.
-
- tp_comm_split_rs (bool, default=True): If true, allows Reduce-Scatter overlap with Fprop GEMM. Ignored if
- tp_comm_overlap is False.
-
- tp_comm_bulk_dgrad (bool, default=True): If true, allows All-Gather overlap with Bprop activation gradient GEMM. Ignored
- if tp_comm_overlap is False.
-
- tp_comm_bulk_wgrad (bool, default=True): If true, allows Reduce-Scatter overlap with Bprop weight gradient GEMM. Ignored
- if tp_comm_overlap is False.
-
- Parallelism
- -----------
-
- finalize_model_grads_func (optional): Function that finalizes gradients on all workers. Could include ensuring that
- grads are all-reduced across data parallelism, pipeline parallelism, and sequence parallelism dimensions.
-
- Pipeline Parallelism
- --------------------
-
- pipeline_dtype (required): dtype used in p2p communication, usually params_dtype
-
- grad_scale_func (optional, default=None): If using loss scaling, this function should take the loss and return the
- scaled loss. If None, no function is called on the loss.
-
- enable_autocast (bool): If true, runs the forward step function inside a torch.autocast context. Default is False.
-
- autocast_dtype (torch.dtype): dtype to pass to torch.amp.autocast when enabled. Defaults to params_dtype.
-
- variable_seq_lengths (bool, default=False): Support for variable sequence lengths across microbatches. Setting this
- communicates the size of tensors during pipeline-parallel communication; because of this extra overhead, it
- should only be set if the sequence length varies by microbatch within a global batch.
-
- num_microbatches_with_partial_activation_checkpoints (int, default=None): If int, set the number of microbatches
- where not all of the layers will be checkpointed and recomputed. The rest of the microbatches within the window
- of maximum outstanding microbatches will recompute all layers (either full recompute or selective recompute). If
- None, the checkpoint and recompute will be left up to the forward_step function.
-
- overlap_p2p_comm (bool, optional, default=False): When True, some of the peer-to-peer communication for pipeline
- parallelism will overlap with computation. Must be False if batch_p2p_comm is True.
-
- batch_p2p_comm (bool, default=True): Use batch_isend_irecv instead of individual isend/irecv calls. Must be False
- if overlap_p2p_comm is True.
-
- batch_p2p_sync (bool, default=True): When using batch_isend_irecv, do a cuda.device.synchronize afterward to work
- around a bug in older versions of PyTorch.
-
- use_ring_exchange_p2p (bool, default=False): Use custom ring_exchange kernel instead of
- torch.distributed.batch_isend_irecv(). Requires custom built torch with torch.distributed.ring_exchange.
-
- deallocate_pipeline_outputs (optional, default=False): If True, output data is deallocated after the tensor is sent
- to the next pipeline stage. Helps with saving memory, does nothing when pipeline parallel is not used.
-
- no_sync_func (optional): Function that creates a context that suppresses asynchronous data-parallel
- communication. If the model is an instance of core.distributed.DistributedDataParallel, the default is to use
- core.distributed.DistributedDataParallel.no_sync.
-
- grad_sync_func (optional): Function that launches asynchronous gradient reductions (e.g. distributed optimizer
- gradient reduce-scatters). The function should take one argument: an iterable of parameters whose gradients are
- to be synchronized.
-
- param_sync_func (optional): Function that launches asynchronous parameter synchronizations (e.g. distributed
- optimizer parameter all-gathers). The function should take one argument: an iterable of parameters to be
- synchronized.
-
- pipeline_model_parallel_split_rank (int, default=None): If int, rank where encoder and decoder should be split in
- cases where the model has both an encoder and decoder (e.g., T5). Ignored if None.
-
- barrier_with_L1_time (bool, default=True): If true, use barrier with level 1 time measurements. It is up to the user
- to make sure that calling barrier with their timers will not result in hangs. This can happen if, for example, the user
- adds a level 1 timer that is not called by all ranks.
-
- """
-
- # Model parallelism
- tensor_model_parallel_size: int = 1
- context_parallel_size: int = 1
- pipeline_model_parallel_size: int = 1
- virtual_pipeline_model_parallel_size: Optional[int] = None
- sequence_parallel: bool = False
- expert_model_parallel_size: int = 1
-
- # Initialization
- perform_initialization: bool = True
- use_cpu_initialization: bool = False
-
- # Training
- fp16: bool = False
- bf16: bool = False
- params_dtype: torch.dtype = torch.float32
- timers: Callable = None
-
- # Optimizations
- gradient_accumulation_fusion: bool = False
- async_tensor_model_parallel_allreduce: bool = False
- tp_comm_overlap: bool = False
-
- # Debug Options
- tp_comm_split_ag: bool = True
- tp_comm_split_rs: bool = True
- tp_comm_bulk_wgrad: bool = True
- tp_comm_bulk_dgrad: bool = True
-
- # Parallelism
- finalize_model_grads_func: Callable = None
-
- # Pipeline Parallel
- pipeline_dtype: torch.dtype = None
- grad_scale_func: Callable = None
- enable_autocast: bool = False
- autocast_dtype: torch.dtype = None
- variable_seq_lengths: bool = False
- num_microbatches_with_partial_activation_checkpoints: Optional[int] = None
- overlap_p2p_comm: bool = False
- batch_p2p_comm: bool = True
- batch_p2p_sync: bool = True
- use_ring_exchange_p2p: bool = False
- deallocate_pipeline_outputs: bool = False
- no_sync_func: Callable = None
- grad_sync_func: Callable = None
- param_sync_func: Callable = None
- pipeline_model_parallel_split_rank: Optional[int] = None
-
- # Timing
- barrier_with_L1_time: bool = True
-
- def __post_init__(self):
- """ Python dataclass method that is used to modify attributes after initialization.
- See https://docs.python.org/3/library/dataclasses.html#post-init-processing for more details.
- """
- if self.sequence_parallel:
- if self.tensor_model_parallel_size <= 1:
- raise ValueError("Can not use sequence paralllelism without tensor parallelism")
- if self.async_tensor_model_parallel_allreduce:
- # sequence_parallelism already does this async
- self.async_tensor_model_parallel_allreduce = False
-
- if self.pipeline_model_parallel_size > 1:
- if self.pipeline_dtype is None:
- raise ValueError(
- "When using pipeline parallelism, pipeline_dtype must be specified"
- )
-
- if self.autocast_dtype is None:
- self.autocast_dtype = self.params_dtype
-
- if self.expert_model_parallel_size > 1 and self.tensor_model_parallel_size > 1:
- if self.sequence_parallel is False:
- raise ValueError(
- "When using expert parallelism and tensor parallelism, sequence parallelism must be used"
- )
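
A miniature stand-in (the hypothetical `TinyParallelConfig` below, not the real class) showing the style of cross-field validation that `__post_init__` performs.

```python
from dataclasses import dataclass

@dataclass
class TinyParallelConfig:
    tensor_model_parallel_size: int = 1
    pipeline_model_parallel_size: int = 1
    sequence_parallel: bool = False
    pipeline_dtype: object = None

    def __post_init__(self):
        # Sequence parallelism only makes sense on top of tensor parallelism.
        if self.sequence_parallel and self.tensor_model_parallel_size <= 1:
            raise ValueError("sequence parallelism requires tensor parallelism")
        # Pipeline parallelism needs an explicit dtype for p2p communication.
        if self.pipeline_model_parallel_size > 1 and self.pipeline_dtype is None:
            raise ValueError("pipeline parallelism requires pipeline_dtype")

TinyParallelConfig(tensor_model_parallel_size=2, sequence_parallel=True)  # valid
try:
    TinyParallelConfig(sequence_parallel=True)  # tensor parallel size is 1 -> rejected
except ValueError as err:
    print(err)
```
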
diff --git a/megatron/core/models/T5/__init__.py b/megatron/core/models/T5/__init__.py
deleted file mode 100644
index f65859a6dafcdfeb650f6b4a0da4fdecfe7f4dcf..0000000000000000000000000000000000000000
--- a/megatron/core/models/T5/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-from .t5_model import T5Model
diff --git a/megatron/core/models/T5/t5_model.py b/megatron/core/models/T5/t5_model.py
deleted file mode 100644
index f2ce4809f365ef523274147068cfb0d2815e3051..0000000000000000000000000000000000000000
--- a/megatron/core/models/T5/t5_model.py
+++ /dev/null
@@ -1,466 +0,0 @@
-# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
-
-import logging
-from typing import List, Literal, Optional
-
-import torch
-from torch import Tensor
-
-from megatron.core import InferenceParams, parallel_state, tensor_parallel
-from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding
-from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
-from megatron.core.models.common.language_module.language_module import LanguageModule
-from megatron.core.transformer.enums import AttnMaskType, ModelType
-from megatron.core.transformer.module import MegatronModule
-from megatron.core.transformer.spec_utils import ModuleSpec
-from megatron.core.transformer.transformer_block import TransformerBlock
-from megatron.core.transformer.transformer_config import TransformerConfig
-from megatron.core.utils import make_tp_sharded_tensor_for_checkpoint
-
-
-class T5LMHead(MegatronModule):
- """Masked LM head for T5
-
- Args:
- config (TransformerConfig): transformer config
- parallel_output (bool): whether output logits are distributed or not.
- vocab_size (int): vocabulary size
- pre_process (bool): Include embedding layer
- share_embeddings_and_output_weights (bool): When True, input embeddings and output logit weights are
- shared.
- """
-
- def __init__(
- self,
- config: TransformerConfig,
- parallel_output: bool,
- vocab_size: int,
- pre_process: bool = True,
- share_embeddings_and_output_weights: bool = False,
- ):
- super(T5LMHead, self).__init__(config=config)
-
- self.parallel_output = parallel_output
-
- self.output_layer = tensor_parallel.ColumnParallelLinear(
- config.hidden_size,
- vocab_size,
- config=config,
- init_method=config.init_method,
- bias=share_embeddings_and_output_weights,
- skip_bias_add=not share_embeddings_and_output_weights,
- gather_output=not self.parallel_output,
- skip_weight_param_allocation=pre_process and share_embeddings_and_output_weights,
- )
-
- def forward(self, hidden_states: Tensor, word_embeddings_weight: Tensor) -> Tensor:
- """Forward pass.
-
- Args:
- hidden_states (Tensor): output hidden states from decoder
- word_embeddings_weight (Tensor): word embedding weight
-
- Returns:
- Tensor: logits tensor
- """
-
- logits, _ = self.output_layer(hidden_states, weight=word_embeddings_weight)
- return logits
-
-
-class T5Model(LanguageModule):
- """T5 Language model.
-
- Args:
- config (TransformerConfig): transformer config
-
- transformer_encoder_layer_spec (ModuleSpec): transformer layer customization specs for encoder
-
- transformer_decoder_layer_spec (ModuleSpec): transformer layer customization specs for decoder
-
- vocab_size (int): vocabulary size
-
- max_sequence_length (int): maximum size of sequence. This is used for positional embedding
-
- pre_process (bool): Include embedding layer (used with pipeline parallelism)
- post_process (bool): Include an output layer (used with pipeline parallelism)
-
- fp16_lm_cross_entropy (bool, optional): Defaults to False
-
- parallel_output (bool): Do not gather the outputs, keep them split across tensor parallel ranks
-
- share_embeddings_and_output_weights (bool): When True, input embeddings and output logit weights are
- shared. Defaults to False.
-
- position_embedding_type (string): Position embedding type. Options ['learned_absolute', 'rope'].
- Default is 'learned_absolute'.
-
- rotary_percent (float): Percent of rotary dimension to use for rotary position embeddings.
- Defaults to 1.0 (100%). Ignored unless position_embedding_type is 'rope'.
-
- seq_len_interpolation_factor (float): scale of linearly interpolating RoPE for longer sequences.
- The value must be a float larger than 1.0. Defaults to None.
- """
-
- def __init__(
- self,
- config: TransformerConfig,
- transformer_encoder_layer_spec: ModuleSpec,
- transformer_decoder_layer_spec: ModuleSpec,
- vocab_size: int,
- max_sequence_length: int,
- pre_process: bool = True,
- post_process: bool = True,
- fp16_lm_cross_entropy: bool = False,
- parallel_output: bool = True,
- share_embeddings_and_output_weights: bool = False,
- position_embedding_type: Literal['learned_absolute', 'rope'] = 'learned_absolute',
- rotary_percent: float = 1.0,
- seq_len_interpolation_factor: Optional[float] = None,
- ):
-
- super(T5Model, self).__init__(config=config)
-
- self.config: TransformerConfig = config
- self.transformer_encoder_layer_spec: ModuleSpec = transformer_encoder_layer_spec
- self.transformer_decoder_layer_spec: ModuleSpec = transformer_decoder_layer_spec
- self.vocab_size = vocab_size
- self.max_sequence_length = max_sequence_length
- self.pre_process = pre_process
- self.post_process = post_process
- self.add_encoder = True
- self.add_decoder = True
- self.fp16_lm_cross_entropy = fp16_lm_cross_entropy
- self.parallel_output = parallel_output
- self.share_embeddings_and_output_weights = share_embeddings_and_output_weights
- self.position_embedding_type = position_embedding_type
-
- # megatron core pipelining currently depends on model type
- self.model_type = ModelType.encoder_and_decoder
-
- # Embeddings.
- if self.pre_process:
- self.embedding = LanguageModelEmbedding(
- config=self.config,
- vocab_size=self.vocab_size,
- max_sequence_length=self.max_sequence_length,
- position_embedding_type=self.position_embedding_type,
- )
-
- # Rotary Position Embeddings
- if self.position_embedding_type == 'rope':
- self.rotary_pos_emb = RotaryEmbedding(
- self.config.kv_channels, rotary_percent, seq_len_interpolation_factor
- )
-
- # Transformer encoder
- encoder_spec, decoder_spec = (
- self.transformer_encoder_layer_spec,
- self.transformer_decoder_layer_spec,
- )
- self.encoder = TransformerBlock(
- config=self.config,
- spec=encoder_spec,
- pre_process=self.pre_process,
- post_process=self.post_process,
- )
- # Transformer decoder
- self.decoder = TransformerBlock(
- config=self.config,
- spec=decoder_spec,
- pre_process=self.pre_process,
- post_process=self.post_process,
- )
-
- # Output
- if post_process:
- self.lm_head = T5LMHead(
- config,
- parallel_output,
- self.vocab_size,
- self.pre_process,
- self.share_embeddings_and_output_weights,
- )
- self.output_layer = self.lm_head.output_layer
-
- if self.share_embeddings_and_output_weights and (self.pre_process or self.post_process):
- self.initialize_last_stage_with_word_embeddings()
-
- def forward(
- self,
- encoder_input_ids: Tensor,
- decoder_input_ids: Tensor,
- encoder_attn_mask: Tensor,
- decoder_attn_mask: Tensor,
- encoder_decoder_attn_mask: Tensor,
- lm_labels: Tensor = None,
- inference_params: InferenceParams = None,
- ) -> Tensor:
- """Forward pass.
-
- Args:
- encoder_input_ids (Tensor): input ids for encoder
- decoder_input_ids (Tensor): input ids for decoder
- encoder_attn_mask (Tensor): self-attention mask for encoder
- decoder_attn_mask (Tensor): self-attention mask for decoder
- encoder_decoder_attn_mask (Tensor): cross-attention mask between encoder and decoder
- lm_labels (Tensor): labels for decoder output
- inference_params (InferenceParams): relevant arguments for inferencing
-
- Returns:
- Tensor: loss tensor
- """
-
- (
- encoder_attn_mask,
- decoder_attn_mask,
- encoder_decoder_attn_mask,
- ) = t5_extended_attention_mask(
- [encoder_attn_mask, decoder_attn_mask, encoder_decoder_attn_mask]
- )
- encoder_position_ids = t5_position_ids(encoder_input_ids)
- decoder_position_ids = t5_position_ids(decoder_input_ids)
-
- ## Encoder forward
- # Encoder embedding.
- if self.pre_process:
- encoder_input = self.embedding(
- input_ids=encoder_input_ids, position_ids=encoder_position_ids
- )
- else:
- # intermediate stage of pipeline
- encoder_input = None
-
- # Rotary positional embeddings
- rotary_pos_emb = None
- if self.position_embedding_type == 'rope':
- rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len(
- inference_params, self.encoder, encoder_input, self.config
- )
- rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len)
-
- # Run encoder.
- encoder_hidden_states = self.encoder(
- hidden_states=encoder_input,
- attention_mask=encoder_attn_mask,
- inference_params=inference_params,
- rotary_pos_emb=rotary_pos_emb,
- )
-
- ## Decoder forward
- # Decoder embedding.
- if self.pre_process:
- decoder_input = self.embedding(
- input_ids=decoder_input_ids, position_ids=decoder_position_ids
- )
- else:
- # intermediate stage of pipeline
- decoder_input = None ### should it take encoder_hidden_states
-
- # Rotary positional embeddings
- rotary_pos_emb = None
- if self.position_embedding_type == 'rope':
- rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len(
- inference_params, self.decoder, decoder_input, self.config
- )
- rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len)
-
- # Run decoder.
- decoder_hidden_states = self.decoder(
- hidden_states=decoder_input,
- attention_mask=decoder_attn_mask,
- context=encoder_hidden_states,
- context_mask=encoder_decoder_attn_mask,
- inference_params=inference_params,
- rotary_pos_emb=rotary_pos_emb,
- )
-
- # Return if not post_process
- if not self.post_process:
- return decoder_hidden_states
-
- # logits and loss
- output_weight = None
- if self.share_embeddings_and_output_weights:
- output_weight = self.shared_embedding_or_output_weight()
- logits = self.lm_head(decoder_hidden_states, word_embeddings_weight=output_weight)
-
- if lm_labels is None:
- # [s b h] => [b s h]
- return logits.transpose(0, 1).contiguous()
-
- loss = self.compute_language_model_loss(lm_labels, logits)
-
- return loss
-
- def set_input_tensor(self, input_tensor):
- """ See megatron.model.transformer.set_input_tensor()"""
-
- # This is usually handled in schedules.py but some inference code still
- # gives us non-lists or None
- if not isinstance(input_tensor, list):
- input_tensor = [input_tensor]
-
- if self.add_encoder and self.add_decoder:
- assert (
- len(input_tensor) == 1
- ), 'input_tensor should only be length 1 for stage with both encoder and decoder'
- self.encoder.set_input_tensor(input_tensor[0])
- elif self.add_encoder:
- assert (
- len(input_tensor) == 1
- ), 'input_tensor should only be length 1 for stage with only encoder'
- self.encoder.set_input_tensor(input_tensor[0])
- elif self.add_decoder:
- if len(input_tensor) == 2:
- self.decoder.set_input_tensor(input_tensor[0])
- self.encoder_hidden_state = input_tensor[1]
- elif len(input_tensor) == 1:
- self.decoder.set_input_tensor(None)
- self.encoder_hidden_state = input_tensor[0]
- else:
- raise Exception('input_tensor must have either length 1 or 2')
- else:
- raise Exception('Stage must have at least either encoder or decoder')
-
- def shared_embedding_or_output_weight(self) -> Tensor:
- """Function to share the input embeddings and output logit weights."""
-
- if self.pre_process:
- return self.embedding.word_embeddings.weight
- elif self.post_process:
- return self.lm_head.output_layer.weight
- return None
-
- def sharded_state_dict(self, prefix: str = ''):
- sharded_state_dict = {}
-
- if self.pre_process:
- embedding_prefix = f'{prefix}embedding.'
- embedding_sharded_state_dict = self.embedding.sharded_state_dict(
- prefix=embedding_prefix
- )
- sharded_state_dict.update(embedding_sharded_state_dict)
-
- encoder_prefix = f'{prefix}encoder.'
- encoder_sharded_state_dict = self.encoder.sharded_state_dict(prefix=encoder_prefix)
- sharded_state_dict.update(encoder_sharded_state_dict)
-
- decoder_prefix = f'{prefix}decoder.'
- decoder_sharded_state_dict = self.decoder.sharded_state_dict(prefix=decoder_prefix)
- sharded_state_dict.update(decoder_sharded_state_dict)
-
- if self.post_process:
- output_layer_prefix = f'{prefix}output_layer.'
- output_layer_weight_key = f'{output_layer_prefix}weight'
- output_layer_bias_key = f'{output_layer_prefix}bias'
- if self.share_embeddings_and_output_weights:
- if not self.pre_process:
- # when sharing embeddings with last stage, we need to use the weights from the first stage
- # on pipeline first rank, word embeddings are saved to {prefix}embedding.word_embeddings.weight
- tensor = self.shared_embedding_or_output_weight()
- first_stage_word_emb_key = f'{prefix}embedding.word_embeddings.weight'
- dp_rank = parallel_state.get_data_parallel_rank()
- dp_size = parallel_state.get_data_parallel_world_size()
- last_stage_word_emb_replica_id = (
- dp_rank + dp_size
- ) # copy of first stage embedding
-
- sharded_output_layer_tensor = make_tp_sharded_tensor_for_checkpoint(
- tensor=tensor,
- key=first_stage_word_emb_key,
- replica_id=last_stage_word_emb_replica_id,
- allow_shape_mismatch=True,
- )
-
- sharded_state_dict[output_layer_weight_key] = sharded_output_layer_tensor
- # output_layer.weight is shared, but we still need to process output_layer.bias
- sharded_output_layer_tensor = make_tp_sharded_tensor_for_checkpoint(
- tensor=self.lm_head.output_layer.bias,
- key=output_layer_bias_key,
- allow_shape_mismatch=True,
- )
- sharded_state_dict[output_layer_bias_key] = sharded_output_layer_tensor
- else:
- output_layer_state_dict = self.output_layer.state_dict(
- prefix=output_layer_prefix, keep_vars=True
- )
- output_layer_tensor = output_layer_state_dict[output_layer_weight_key]
- # independent output layer
- sharded_output_layer_tensor = make_tp_sharded_tensor_for_checkpoint(
- tensor=output_layer_tensor,
- key=output_layer_weight_key,
- replica_id=parallel_state.get_data_parallel_rank(),
- allow_shape_mismatch=True,
- )
-
- sharded_state_dict[output_layer_weight_key] = sharded_output_layer_tensor
-
- return sharded_state_dict
-
- def state_dict_for_save_checkpoint(self, prefix: str = '', keep_vars: bool = False):
- """For easy load when model is combined with other heads,
- add an extra key."""
-
- state_dict_ = {}
- state_dict_["embedding"] = self.embedding.state_dict_for_save_checkpoint(
- prefix=prefix, keep_vars=keep_vars
- )
- state_dict_["encoder"] = self.encoder.state_dict_for_save_checkpoint(
- prefix=prefix, keep_vars=keep_vars
- )
- state_dict_["decoder"] = self.decoder.state_dict_for_save_checkpoint(
- prefix=prefix, keep_vars=keep_vars
- )
-
- if self.post_process and self.add_decoder:
- state_dict_["lm_head"] = self.lm_head.state_dict_for_save_checkpoint(
- prefix=prefix, keep_vars=keep_vars
- )
- # Save word_embeddings.
- if self.post_process and not self.pre_process and self.add_decoder:
- state_dict_["word_embeddings_for_head"] = self.embedding.state_dict(
- prefix=prefix, keep_vars=keep_vars
- )
- return state_dict_
-
- def load_state_dict(self, state_dict, strict=True):
- """Customized load."""
- self.embedding.load_state_dict(state_dict["embedding"], strict=strict)
-
- self.encoder.load_state_dict(state_dict["encoder"], strict=strict)
-
- self.decoder.load_state_dict(state_dict["decoder"], strict=strict)
-
- if self.post_process and self.add_decoder:
- self.lm_head.load_state_dict(state_dict["lm_head"], strict=strict)
-
- # Load word embeddings
- if self.post_process and not self.pre_process and self.add_decoder:
- self.word_embeddings.load_state_dict(
- state_dict["word_embeddings_for_head"], strict=strict
- )
-
-
-def t5_extended_attention_mask(attention_mask_list: List[Tensor]) -> List[Tensor]:
- def attn_mask_postprocess(attn_mask):
- # [b, 1, s, s]
- extended_attention_mask = attn_mask.unsqueeze(1)
- return extended_attention_mask
-
- return [attn_mask_postprocess(attn_mask) for attn_mask in attention_mask_list]
-
-
-def t5_position_ids(token_ids: Tensor) -> Tensor:
- """Calculate position ids from token ids
- Args:
- token_ids (Tensor): input tokens
-
- Returns:
- Tensor: position ids
- """
- seq_length = token_ids.size(1)
- position_ids = torch.arange(seq_length, dtype=torch.long, device=token_ids.device)
- position_ids = position_ids.unsqueeze(0).expand_as(token_ids)
-
- return position_ids
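
A quick illustration of the two helpers above: position ids are simply `0..seq_len-1` broadcast across the batch, and the "extended" mask only gains a singleton head dimension.

```python
import torch

token_ids = torch.randint(0, 1000, (2, 5))
seq_length = token_ids.size(1)
position_ids = torch.arange(seq_length, dtype=torch.long).unsqueeze(0).expand_as(token_ids)
print(position_ids.shape)  # [2, 5]

attn_mask = torch.ones(2, 5, 5, dtype=torch.bool)
extended_mask = attn_mask.unsqueeze(1)  # [b, 1, s, s]
print(extended_mask.shape)
```
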
diff --git a/megatron/core/models/T5/t5_spec.py b/megatron/core/models/T5/t5_spec.py
deleted file mode 100644
index 60f33dbd9810d76ce014669b0f247eb8516017f9..0000000000000000000000000000000000000000
--- a/megatron/core/models/T5/t5_spec.py
+++ /dev/null
@@ -1,212 +0,0 @@
-from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
-from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
-from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
-from megatron.core.transformer.attention import (
- CrossAttention,
- CrossAttentionSubmodules,
- SelfAttention,
- SelfAttentionSubmodules,
-)
-from megatron.core.transformer.custom_layers.transformer_engine import (
- TEColumnParallelLinear,
- TEDotProductAttention,
- TELayerNormColumnParallelLinear,
- TENorm,
- TERowParallelLinear,
-)
-from megatron.core.transformer.dot_product_attention import DotProductAttention
-from megatron.core.transformer.enums import AttnMaskType
-from megatron.core.transformer.mlp import MLP, MLPSubmodules
-from megatron.core.transformer.spec_utils import ModuleSpec
-from megatron.core.transformer.transformer_block import (
- TransformerBlockSubmodules,
- get_num_layers_to_build,
-)
-from megatron.core.transformer.transformer_config import TransformerConfig
-from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules
-
-
-def encoder_model_with_transformer_engine_default_spec() -> ModuleSpec:
- """T5 encoder TE spec (uses Transformer Engine components)."""
-
- return ModuleSpec(
- module=TransformerLayer,
- submodules=TransformerLayerSubmodules(
- self_attention=ModuleSpec(
- module=SelfAttention,
- params={"attn_mask_type": AttnMaskType.padding},
- submodules=SelfAttentionSubmodules(
- linear_qkv=TELayerNormColumnParallelLinear,
- core_attention=TEDotProductAttention,
- linear_proj=TERowParallelLinear,
- ),
- ),
- self_attn_bda=get_bias_dropout_add,
- mlp=ModuleSpec(
- module=MLP,
- submodules=MLPSubmodules(
- linear_fc1=TELayerNormColumnParallelLinear, linear_fc2=TERowParallelLinear,
- ),
- ),
- mlp_bda=get_bias_dropout_add,
- ),
- )
-
-
-def decoder_model_with_transformer_engine_default_spec() -> ModuleSpec:
- """T5 decoder TE spec (uses Transformer Engine components)."""
-
- return ModuleSpec(
- module=TransformerLayer,
- submodules=TransformerLayerSubmodules(
- self_attention=ModuleSpec(
- module=SelfAttention,
- params={"attn_mask_type": AttnMaskType.causal},
- submodules=SelfAttentionSubmodules(
- linear_qkv=TELayerNormColumnParallelLinear,
- core_attention=TEDotProductAttention,
- linear_proj=TERowParallelLinear,
- ),
- ),
- self_attn_bda=get_bias_dropout_add,
- pre_cross_attn_layernorm=TENorm,
- cross_attention=ModuleSpec(
- module=CrossAttention,
- submodules=CrossAttentionSubmodules(
- linear_q=TEColumnParallelLinear,
- linear_kv=TEColumnParallelLinear,
- core_attention=TEDotProductAttention,
- linear_proj=TERowParallelLinear,
- ),
- ),
- cross_attn_bda=get_bias_dropout_add,
- mlp=ModuleSpec(
- module=MLP,
- submodules=MLPSubmodules(
- linear_fc1=TELayerNormColumnParallelLinear, linear_fc2=TERowParallelLinear,
- ),
- ),
- mlp_bda=get_bias_dropout_add,
- ),
- )
-
-
-def encoder_model_with_local_spec() -> ModuleSpec:
- """T5 encoder local spec (uses Megatron-Core components)."""
-
- return ModuleSpec(
- module=TransformerLayer,
- submodules=TransformerLayerSubmodules(
- input_layernorm=FusedLayerNorm,
- self_attention=ModuleSpec(
- module=SelfAttention,
- params={"attn_mask_type": AttnMaskType.padding},
- submodules=SelfAttentionSubmodules(
- linear_qkv=ColumnParallelLinear,
- core_attention=DotProductAttention,
- linear_proj=RowParallelLinear,
- ),
- ),
- self_attn_bda=get_bias_dropout_add,
- pre_mlp_layernorm=FusedLayerNorm,
- mlp=ModuleSpec(
- module=MLP,
- submodules=MLPSubmodules(
- linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear,
- ),
- ),
- mlp_bda=get_bias_dropout_add,
- ),
- )
-
-
-def decoder_model_with_local_spec() -> ModuleSpec:
- """T5 decoder local spec (uses Megatron-Core components)."""
-
- return ModuleSpec(
- module=TransformerLayer,
- submodules=TransformerLayerSubmodules(
- input_layernorm=FusedLayerNorm,
- self_attention=ModuleSpec(
- module=SelfAttention,
- params={"attn_mask_type": AttnMaskType.causal},
- submodules=SelfAttentionSubmodules(
- linear_qkv=ColumnParallelLinear,
- core_attention=DotProductAttention,
- linear_proj=RowParallelLinear,
- ),
- ),
- self_attn_bda=get_bias_dropout_add,
- pre_cross_attn_layernorm=FusedLayerNorm,
- cross_attention=ModuleSpec(
- module=CrossAttention,
- submodules=CrossAttentionSubmodules(
- linear_q=ColumnParallelLinear,
- linear_kv=ColumnParallelLinear,
- core_attention=DotProductAttention,
- linear_proj=RowParallelLinear,
- ),
- ),
- cross_attn_bda=get_bias_dropout_add,
- pre_mlp_layernorm=FusedLayerNorm,
- mlp=ModuleSpec(
- module=MLP,
- submodules=MLPSubmodules(
- linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear,
- ),
- ),
- mlp_bda=get_bias_dropout_add,
- ),
- )
-
-
-def get_t5_encoder_with_transformer_engine_block_spec(
- num_layers: int,
-) -> TransformerBlockSubmodules:
- """T5 encoder block spec for Transformer Engine
-
- Args:
- config (TransformerConfig): config, containing number of layers for encoder
- """
-
- layer_spec = encoder_model_with_transformer_engine_default_spec()
- block_spec = TransformerBlockSubmodules([layer_spec] * num_layers)
- return block_spec
-
-
-def get_t5_decoder_with_transformer_engine_block_spec(
- num_layers: int,
-) -> TransformerBlockSubmodules:
- """T5 decoder block spec for Transformer Engine
-
- Args:
- config (TransformerConfig): config, containing number of layers for decoder
- """
-
- layer_spec = decoder_model_with_transformer_engine_default_spec()
- block_spec = TransformerBlockSubmodules([layer_spec] * num_layers)
- return block_spec
-
-
-def get_t5_encoder_with_local_block_spec(num_layers: int) -> TransformerBlockSubmodules:
- """T5 encoder block spec for local (uses Megatron-Core components)
-
- Args:
- num_layers (int): number of encoder layers
- """
-
- layer_spec = encoder_model_with_local_spec()
- block_spec = TransformerBlockSubmodules([layer_spec] * num_layers)
- return block_spec
-
-
-def get_t5_decoder_with_local_block_spec(num_layers: int) -> TransformerBlockSubmodules:
- """T5 decoder block spec for local (uses Megatron-Core components)
-
- Args:
- num_layers (int): number of decoder layers
- """
-
- layer_spec = decoder_model_with_local_spec()
- block_spec = TransformerBlockSubmodules([layer_spec] * num_layers)
- return block_spec
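For orientation, the four `get_t5_*_block_spec` helpers above do nothing more than repeat a per-layer `ModuleSpec` once per layer. A minimal usage sketch, assuming the module path of the deleted file (`megatron.core.models.T5.t5_spec`) is still importable in your Megatron checkout:

```python
# Sketch only: build local (Megatron-Core) block specs for a 12-layer T5
# encoder/decoder. Each helper simply repeats the per-layer spec num_layers times.
from megatron.core.models.T5.t5_spec import (
    get_t5_decoder_with_local_block_spec,
    get_t5_encoder_with_local_block_spec,
)

encoder_block_spec = get_t5_encoder_with_local_block_spec(num_layers=12)
decoder_block_spec = get_t5_decoder_with_local_block_spec(num_layers=12)
# Both objects are TransformerBlockSubmodules wrapping 12 identical layer specs.
```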
diff --git a/megatron/core/models/__init__.py b/megatron/core/models/__init__.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/megatron/core/models/bert/__init__.py b/megatron/core/models/bert/__init__.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/megatron/core/models/bert/bert_layer_specs.py b/megatron/core/models/bert/bert_layer_specs.py
deleted file mode 100644
index 9c36711fdd38cc1dced5764a4dca9f4f6def309d..0000000000000000000000000000000000000000
--- a/megatron/core/models/bert/bert_layer_specs.py
+++ /dev/null
@@ -1,64 +0,0 @@
-from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
-from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
-from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
-from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
-from megatron.core.transformer.custom_layers.transformer_engine import (
- TEDotProductAttention,
- TELayerNormColumnParallelLinear,
- TERowParallelLinear,
-)
-from megatron.core.transformer.dot_product_attention import DotProductAttention
-from megatron.core.transformer.enums import AttnMaskType
-from megatron.core.transformer.mlp import MLP, MLPSubmodules
-from megatron.core.transformer.spec_utils import ModuleSpec
-from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules
-
-# Use this spec to use lower level Transformer Engine modules (required for fp8 training)
-bert_layer_with_transformer_engine_spec = ModuleSpec(
- module=TransformerLayer,
- submodules=TransformerLayerSubmodules(
- self_attention=ModuleSpec(
- module=SelfAttention,
- params={"attn_mask_type": AttnMaskType.padding},
- submodules=SelfAttentionSubmodules(
- linear_qkv=TELayerNormColumnParallelLinear,
- core_attention=TEDotProductAttention,
- linear_proj=TERowParallelLinear,
- ),
- ),
- self_attn_bda=get_bias_dropout_add,
- mlp=ModuleSpec(
- module=MLP,
- submodules=MLPSubmodules(
- linear_fc1=TELayerNormColumnParallelLinear, linear_fc2=TERowParallelLinear,
- ),
- ),
- mlp_bda=get_bias_dropout_add,
- ),
-)
-
-# Use this spec for an implementation using only modules in megatron core
-bert_layer_local_spec = ModuleSpec(
- module=TransformerLayer,
- submodules=TransformerLayerSubmodules(
- input_layernorm=FusedLayerNorm,
- self_attention=ModuleSpec(
- module=SelfAttention,
- params={"attn_mask_type": AttnMaskType.padding},
- submodules=SelfAttentionSubmodules(
- linear_qkv=ColumnParallelLinear,
- core_attention=DotProductAttention,
- linear_proj=RowParallelLinear,
- ),
- ),
- self_attn_bda=get_bias_dropout_add,
- pre_mlp_layernorm=FusedLayerNorm,
- mlp=ModuleSpec(
- module=MLP,
- submodules=MLPSubmodules(
- linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear,
- ),
- ),
- mlp_bda=get_bias_dropout_add,
- ),
-)
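The two module-level specs above differ only in which attention/linear implementations they wire in: the Transformer Engine spec is required for fp8 training, while the local spec uses only Megatron-Core modules. A hedged selection sketch (the import path matches the deleted file; `use_transformer_engine` is an illustrative flag, not a repository option):

```python
from megatron.core.models.bert.bert_layer_specs import (
    bert_layer_local_spec,
    bert_layer_with_transformer_engine_spec,
)

def pick_bert_layer_spec(use_transformer_engine: bool):
    # TE spec: lower-level Transformer Engine modules (needed for fp8 training).
    # Local spec: pure Megatron-Core ColumnParallelLinear / DotProductAttention.
    if use_transformer_engine:
        return bert_layer_with_transformer_engine_spec
    return bert_layer_local_spec
```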
diff --git a/megatron/core/models/bert/bert_lm_head.py b/megatron/core/models/bert/bert_lm_head.py
deleted file mode 100644
index ea6f8f122604a30eb82f77bcca46dcc8e7e3d858..0000000000000000000000000000000000000000
--- a/megatron/core/models/bert/bert_lm_head.py
+++ /dev/null
@@ -1,72 +0,0 @@
-import torch
-from torch import Tensor
-
-from megatron.core import tensor_parallel
-from megatron.core.transformer.module import MegatronModule
-from megatron.core.transformer.transformer_config import TransformerConfig
-from megatron.core.transformer.utils import erf_gelu, get_linear_layer, openai_gelu
-from megatron.model import LayerNorm
-
-
-class BertLMHead(MegatronModule):
- """Masked LM head for Bert
-
- Args:
- hidden_size: hidden size
- config (TransformerConfig): TransformerConfig object
- parallel_output (bool): Do not gather the outputs, keep them split across tensor parallel ranks
- vocab_size(int): The vocabulary size
- share_embeddings_and_output_weights (bool): When True, input embeddings and output logit weights are shared. Defaults to False
- pre_process (bool): Include embedding layer (used with pipeline parallelism)
- """
-
- def __init__(
- self,
- hidden_size: int,
- config: TransformerConfig,
- parallel_output: bool,
- vocab_size: int,
- pre_process: bool,
- share_embeddings_and_output_weights: bool = False,
- ):
- super().__init__(config=config)
-
- self.vocab_size = vocab_size
- self.parallel_output = parallel_output
-
- # TODO: Should switch this to TE?
- self.dense = get_linear_layer(
- hidden_size, hidden_size, config.init_method, config.perform_initialization
- )
-
- setattr(self.dense.weight, 'sequence_parallel', config.sequence_parallel)
- setattr(self.dense.bias, 'sequence_parallel', config.sequence_parallel)
-
- self.layernorm = LayerNorm(
- hidden_size, eps=config.layernorm_epsilon, sequence_parallel=config.sequence_parallel
- )
-
- self.gelu = torch.nn.functional.gelu
- # TODO: Use activation_func in config to determine what to use
- # if config.openai_gelu: # Don't have these configs in transformer config yet
- # self.gelu = openai_gelu
- # elif config.onnx_safe: # Don't have these configs in transformer config yet
- # self.gelu = erf_gelu
-
- self.output_layer = tensor_parallel.ColumnParallelLinear(
- config.hidden_size,
- self.vocab_size,
- config=config,
- init_method=config.init_method,
- bias=True,
- skip_bias_add=False,
- gather_output=not self.parallel_output,
- skip_weight_param_allocation=pre_process and share_embeddings_and_output_weights,
- )
-
- def forward(self, hidden_states: Tensor, word_embeddings_weight: Tensor) -> Tensor:
- hidden_states = self.dense(hidden_states)
- hidden_states = self.gelu(hidden_states)
- hidden_states = self.layernorm(hidden_states)
- logits, _ = self.output_layer(hidden_states, weight=word_embeddings_weight)
- return logits
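The forward pass of `BertLMHead` is a dense projection, GELU, LayerNorm, then a (column-parallel) projection to the vocabulary. A shape-only sketch with plain PyTorch stand-ins (sizes are illustrative; `torch.nn.Linear` replaces Megatron's parallel linear layers here):

```python
import torch
import torch.nn.functional as F

s, b, h, v = 128, 4, 1024, 30522                  # seq, batch, hidden, vocab (example sizes)
hidden_states = torch.randn(s, b, h)

dense = torch.nn.Linear(h, h)                      # stands in for get_linear_layer(...)
layernorm = torch.nn.LayerNorm(h)
output_layer = torch.nn.Linear(h, v)               # stands in for ColumnParallelLinear

x = layernorm(F.gelu(dense(hidden_states)))
logits = output_layer(x)                           # [s, b, v], still sequence-first
assert logits.shape == (s, b, v)
```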
diff --git a/megatron/core/models/bert/bert_model.py b/megatron/core/models/bert/bert_model.py
deleted file mode 100644
index 165c1b39028b08d93f447412be01c241a044bb89..0000000000000000000000000000000000000000
--- a/megatron/core/models/bert/bert_model.py
+++ /dev/null
@@ -1,234 +0,0 @@
-# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
-from typing import Literal, Optional
-
-import torch
-from torch import Tensor
-
-from megatron.core.models.bert.bert_lm_head import BertLMHead
-from megatron.core.models.bert.pooler import Pooler
-from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding
-from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
-from megatron.core.models.common.language_module.language_module import LanguageModule
-from megatron.core.transformer.enums import AttnMaskType, ModelType
-from megatron.core.transformer.spec_utils import ModuleSpec
-from megatron.core.transformer.transformer_block import TransformerBlock
-from megatron.core.transformer.transformer_config import TransformerConfig
-from megatron.core.transformer.utils import get_linear_layer
-from megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids
-
-
-class BertModel(LanguageModule):
- """Transformer language model.
-
- Args:
- config (TransformerConfig): transformer config
- num_tokentypes (int) : Set to 2 when args.bert_binary_head is True, and 0 otherwise. Defaults to 0.
- transformer_layer_spec (ModuleSpec): Specifies module to use for transformer layers
- vocab_size (int): vocabulary size
- max_sequence_length (int): maximum size of sequence. This is used for positional embedding
- pre_process (bool): Include embedding layer (used with pipeline parallelism)
- post_process (bool): Include an output layer (used with pipeline parallelism)
- parallel_output (bool): Do not gather the outputs, keep them split across tensor parallel ranks
- share_embeddings_and_output_weights (bool): When True, input embeddings and output logit weights are shared. Defaults to False.
- position_embedding_type (string): Position embedding type. Options ['learned_absolute', 'rope'].
- Defaults to 'learned_absolute'.
- rotary_percent (float): Percent of rotary dimension to use for rotary position embeddings.
- Defaults to 1.0 (100%). Ignored unless position_embedding_type is 'rope'.
- """
-
- def __init__(
- self,
- config: TransformerConfig,
- num_tokentypes: int,
- transformer_layer_spec: ModuleSpec,
- vocab_size: int,
- max_sequence_length: int,
- pre_process: bool = True,
- post_process: bool = True,
- fp16_lm_cross_entropy: bool = False,
- parallel_output: bool = True,
- share_embeddings_and_output_weights: bool = False,
- position_embedding_type: Literal['learned_absolute', 'rope'] = 'learned_absolute',
- rotary_percent: float = 1.0,
- seq_len_interpolation_factor: Optional[float] = None,
- add_binary_head=True,
- return_embeddings=False,
- ):
- super(BertModel, self).__init__(config=config)
-
- if return_embeddings:
- assert self.post_process and self.add_binary_head
-
- self.config: TransformerConfig = config
- self.transformer_layer_spec: ModuleSpec = transformer_layer_spec
- self.vocab_size = vocab_size
- self.max_sequence_length = max_sequence_length
- self.pre_process = pre_process
- self.post_process = post_process
- self.fp16_lm_cross_entropy = fp16_lm_cross_entropy
- self.parallel_output = parallel_output
- self.share_embeddings_and_output_weights = share_embeddings_and_output_weights
- self.position_embedding_type = position_embedding_type
- self.add_binary_head = add_binary_head
- self.return_embeddings = return_embeddings
-
- # megatron core pipelining currently depends on model type
- self.model_type = ModelType.encoder_or_decoder
-
- # Embeddings.
- if self.pre_process:
- self.embedding = LanguageModelEmbedding(
- config=self.config,
- vocab_size=self.vocab_size,
- max_sequence_length=self.max_sequence_length,
- position_embedding_type=position_embedding_type,
- num_tokentypes=num_tokentypes,
- )
-
- if self.position_embedding_type == 'rope':
- self.rotary_pos_emb = RotaryEmbedding(
- self.config.kv_channels, rotary_percent, seq_len_interpolation_factor
- )
-
- # Transformer.
- self.encoder = TransformerBlock(
- config=self.config,
- spec=self.transformer_layer_spec,
- pre_process=self.pre_process,
- post_process=self.post_process,
- )
-
- # Output
- if post_process:
- # TODO: Make sure you are passing in the mpu_vocab_size properly
- self.lm_head = BertLMHead(
- config.hidden_size,
- config,
- parallel_output,
- self.vocab_size,
- self.pre_process,
- self.share_embeddings_and_output_weights,
- )
-
- self.output_layer = self.lm_head.output_layer
-
- self.binary_head = None
- if self.add_binary_head:
- # TODO: Should switch this to TE?
- self.binary_head = get_linear_layer(
- config.hidden_size, 2, config.init_method, config.perform_initialization
- )
-
- self.pooler = Pooler(
- config.hidden_size, config.init_method, config, config.sequence_parallel
- )
-
- if self.share_embeddings_and_output_weights and (self.pre_process or self.post_process):
- self.initialize_last_stage_with_word_embeddings()
-
- def set_input_tensor(self, input_tensor: Tensor) -> None:
- """Sets input tensor to the model.
-
- See megatron.model.transformer.set_input_tensor()
-
- Args:
- input_tensor (Tensor): Sets the input tensor for the model.
- """
- # This is usually handled in schedules.py but some inference code still
- # gives us non-lists or None
- if not isinstance(input_tensor, list):
- input_tensor = [input_tensor]
-
- assert len(input_tensor) == 1, 'input_tensor should only be length 1 for gpt/bert'
- self.encoder.set_input_tensor(input_tensor[0])
-
- def forward(
- self,
- input_ids: Tensor,
- attention_mask: Tensor,
- tokentype_ids: Tensor = None,
- lm_labels: Tensor = None,
- inference_params=None,
- ):
- """Forward function of BERT model
-
- Forward function of the BERT model. This function passes the input tensors
- through the embedding layer, then the encoder, and finally into the post
- processing layer (optional).
-
- It either returns the Loss values if labels are given or the final hidden units
- """
- extended_attention_mask = bert_extended_attention_mask(attention_mask)
-
- position_ids = bert_position_ids(input_ids)
-
- # Encoder embedding.
- if self.pre_process:
- encoder_input = self.embedding(
- input_ids=input_ids, position_ids=position_ids, tokentype_ids=tokentype_ids
- )
- else:
- # intermediate stage of pipeline
- # decoder will get hidden_states from encoder.input_tensor
- encoder_input = None
-
- # Rotary positional embeddings (Why not move this into BERT/GPTEmbedding?)
- rotary_pos_emb = None
- if self.position_embedding_type == 'rope':
- rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len(
- inference_params, self.encoder, encoder_input, self.config
- )
- rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len)
-
- # Run decoder.
- hidden_states = self.encoder(
- hidden_states=encoder_input,
- attention_mask=extended_attention_mask,
- inference_params=inference_params,
- rotary_pos_emb=rotary_pos_emb,
- )
- if not self.post_process:
- return hidden_states
-
- if self.add_binary_head:
- pooled_output = self.pooler(hidden_states, 0)
-
- if self.return_embeddings:
- embeddings = torch.transpose(hidden_states, 0, 1)
- masks = torch.sum(attention_mask, dim=1)
- # Collect masked embeddings.
- output = torch.zeros(
- size=(embeddings.shape[0], embeddings.shape[2]),
- dtype=torch.float32,
- device=torch.cuda.current_device(),
- )
- for i, (embedding, mask) in enumerate(zip(embeddings, masks)):
- output[i, :] = torch.mean(embedding[1 : mask - 1], dim=0)
- return output
-
- # logits and loss
- output_weight = None
- if self.share_embeddings_and_output_weights:
- output_weight = self.shared_embedding_or_output_weight()
-
- logits = self.lm_head(hidden_states=hidden_states, word_embeddings_weight=output_weight)
-
- binary_logits = None
- if self.binary_head is not None:
- binary_logits = self.binary_head(pooled_output)
-
- if lm_labels is None:
- # [s b h] => [b s h]
- return logits.transpose(0, 1).contiguous(), binary_logits
-
- loss = self.compute_language_model_loss(lm_labels, logits)
-
- return loss, binary_logits
-
- # TODO: add distributed checkpointing
- def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
- pass
-
- # TODO: add distributed checkpointing
- def load_state_dict(self, state_dict, strict=True):
- pass
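When `return_embeddings` is set, the forward pass above mean-pools each sequence's hidden states over its non-padding tokens (skipping the first token and the final valid position). A plain-PyTorch sketch of that pooling loop:

```python
import torch

def mean_pool_embeddings(hidden_states: torch.Tensor, attention_mask: torch.Tensor):
    # hidden_states: [s, b, h] as produced by the encoder; attention_mask: [b, s] of 0/1.
    embeddings = hidden_states.transpose(0, 1)            # [b, s, h]
    lengths = attention_mask.sum(dim=1)                    # valid tokens per sample
    output = torch.zeros(embeddings.size(0), embeddings.size(2))
    for i, (embedding, length) in enumerate(zip(embeddings, lengths)):
        # Average over tokens 1 .. length-2, mirroring embedding[1 : mask - 1] above.
        output[i] = embedding[1 : int(length) - 1].mean(dim=0)
    return output
```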
diff --git a/megatron/core/models/bert/pooler.py b/megatron/core/models/bert/pooler.py
deleted file mode 100644
index c144d8c9c4daba6f1e43d1fe5bc78c9365e1c90e..0000000000000000000000000000000000000000
--- a/megatron/core/models/bert/pooler.py
+++ /dev/null
@@ -1,51 +0,0 @@
-import torch
-from torch import Tensor
-
-from megatron.core import tensor_parallel
-from megatron.core.transformer.module import MegatronModule
-from megatron.core.transformer.transformer_config import TransformerConfig
-from megatron.core.transformer.utils import get_linear_layer
-
-
-class Pooler(MegatronModule):
- """Pooler layer.
-
- Pool hidden states of a specific token (for example start of the
- sequence) and add a linear transformation followed by a tanh.
-
- Args:
- hidden_size (int): The hidden size_
- init_method (callable): weight initialization method for the linear layer. bias is set to zero.
- config (TransformerConfig): The transformer configuration
- sequence_parallel (bool): Whether sequence parallelism is used. Defaults to False
- """
-
- def __init__(
- self,
- hidden_size: int,
- init_method: callable,
- config: TransformerConfig,
- sequence_parallel: bool = False,
- ):
- super(Pooler, self).__init__(config)
- # TODO: Should switch this to TE?
- self.dense = get_linear_layer(
- hidden_size, hidden_size, init_method, config.perform_initialization
- )
- self.sequence_parallel = sequence_parallel
-
- def forward(self, hidden_states: Tensor, sequence_index=0):
- # hidden_states: [s, b, h]
- # sequence_index: index of the token to pool.
-
- # gather data along sequence dimensions
- # same pooler is run on all tensor parallel nodes
- if self.sequence_parallel:
- hidden_states = tensor_parallel.gather_from_sequence_parallel_region(
- hidden_states, tensor_parallel_output_grad=False
- )
-
- pooled = hidden_states[sequence_index, :, :]
- pooled = self.dense(pooled)
- pooled = torch.tanh(pooled)
- return pooled
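`Pooler.forward` simply picks one token's hidden state (index 0 by default), applies a linear layer, and squashes it with tanh. A stand-alone stand-in (illustrative sizes, plain `torch.nn.Linear` in place of `get_linear_layer`):

```python
import torch

hidden_states = torch.randn(128, 4, 1024)         # [s, b, h]
dense = torch.nn.Linear(1024, 1024)

pooled = torch.tanh(dense(hidden_states[0]))      # pool sequence index 0 -> [b, h]
```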
diff --git a/megatron/core/models/common/__init__.py b/megatron/core/models/common/__init__.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/megatron/core/models/common/embeddings/__init__.py b/megatron/core/models/common/embeddings/__init__.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/megatron/core/models/common/embeddings/language_model_embedding.py b/megatron/core/models/common/embeddings/language_model_embedding.py
deleted file mode 100644
index 40d679d7b13dd44797957909c4c4fd2272e1f932..0000000000000000000000000000000000000000
--- a/megatron/core/models/common/embeddings/language_model_embedding.py
+++ /dev/null
@@ -1,163 +0,0 @@
-# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
-
-from typing import Literal, Optional
-
-import torch
-from torch import Tensor
-
-from megatron.core import tensor_parallel
-from megatron.core.transformer.module import MegatronModule
-from megatron.core.transformer.transformer_config import TransformerConfig
-from megatron.core.utils import (
- make_sharded_tensor_for_checkpoint,
- make_tp_sharded_tensor_for_checkpoint,
-)
-
-
-class LanguageModelEmbedding(MegatronModule):
- """Language model embeddings.
-
- Arguments:
- config (TransformerConfig): config object with all necessary configs for TransformerBlock
- vocab_size (int): vocabulary size
- max_sequence_length (int): maximum size of sequence. This
- is used for positional embedding
- add_position_embedding (bool): Add a position embedding.
- embedding_dropout_prob (float): dropout probability for embeddings
- num_tokentypes (int): Set to 0 without a binary head, and 2 with a binary head. Defaults to 0.
- """
-
- def __init__(
- self,
- config: TransformerConfig,
- vocab_size: int,
- max_sequence_length: int,
- position_embedding_type: Literal['learned_absolute', 'rope'] = 'learned_absolute',
- num_tokentypes: int = 0,
- ):
- super().__init__(config=config)
-
- self.config: TransformerConfig = config
- self.vocab_size: int = vocab_size
- self.max_sequence_length: int = max_sequence_length
- self.add_position_embedding: bool = position_embedding_type == 'learned_absolute'
- self.num_tokentypes = num_tokentypes
-
- # Word embeddings (parallel).
- self.word_embeddings = tensor_parallel.VocabParallelEmbedding(
- num_embeddings=self.vocab_size,
- embedding_dim=self.config.hidden_size,
- init_method=self.config.init_method,
- config=self.config,
- )
-
- # Position embedding (serial).
- if self.add_position_embedding:
- self.position_embeddings = torch.nn.Embedding(
- self.max_sequence_length, self.config.hidden_size
- )
-
- # Initialize the position embeddings.
- if self.config.perform_initialization:
- self.config.init_method(self.position_embeddings.weight)
-
- if self.num_tokentypes > 0:
- self.tokentype_embeddings = torch.nn.Embedding(
- self.num_tokentypes, self.config.hidden_size
- )
- # Initialize the token-type embeddings.
- if self.config.perform_initialization:
- self.config.init_method(self.tokentype_embeddings.weight)
- else:
- self.tokentype_embeddings = None
-
- # Embeddings dropout
- self.embedding_dropout = torch.nn.Dropout(self.config.hidden_dropout)
-
- def zero_parameters(self):
- """Zero out all parameters in embedding."""
- self.word_embeddings.weight.data.fill_(0)
- self.word_embeddings.weight.shared = True
- self.position_embeddings.weight.data.fill_(0)
- self.position_embeddings.weight.shared = True
- if self.num_tokentypes > 0:
- self.tokentype_embeddings.weight.data.fill_(0)
- self.tokentype_embeddings.weight.shared = True
-
- def forward(self, input_ids: Tensor, position_ids: Tensor, tokentype_ids: int = None) -> Tensor:
- """Forward pass of the embedding module
- Args:
- input_ids (Tensor): The input tokens
- position_ids (Tensor): The position id's used to calculate position embeddings
- tokentype_ids (int): The token type ids. Used when args.bert_binary_head is set to True. Defaults to None
-
- Returns:
- Tensor: The output embeddings
- """
- word_embeddings = self.word_embeddings(input_ids)
- if self.add_position_embedding:
- position_embeddings = self.position_embeddings(position_ids)
- embeddings = word_embeddings + position_embeddings
- else:
- embeddings = word_embeddings
-
- # Data format change to avoid explicit transposes: [b s h] --> [s b h].
- embeddings = embeddings.transpose(0, 1).contiguous()
-
- if tokentype_ids is not None:
- assert self.tokentype_embeddings is not None
- # [b s h] -> [s b h] (So that it can be added with embeddings)
- tokentype_embedding = self.tokentype_embeddings(tokentype_ids).permute(1, 0, 2)
- embeddings = embeddings + tokentype_embedding
- else:
- assert self.tokentype_embeddings is None
-
- # If the input flag for fp32 residual connection is set, convert for float.
- if self.config.fp32_residual_connection:
- embeddings = embeddings.float()
-
- # Dropout.
- if self.config.sequence_parallel:
- embeddings = tensor_parallel.scatter_to_sequence_parallel_region(embeddings)
- # `scatter_to_sequence_parallel_region` returns a view, which prevents
- # the original tensor from being garbage collected. Clone to facilitate GC.
- # Has a small runtime cost (~0.5%).
- if self.config.clone_scatter_output_in_embedding:
- embeddings = embeddings.clone()
- with tensor_parallel.get_cuda_rng_tracker().fork():
- embeddings = self.embedding_dropout(embeddings)
- else:
- embeddings = self.embedding_dropout(embeddings)
-
- return embeddings
-
- def sharded_state_dict(self, prefix=''):
-
- sharded_state_dict = {}
-
- word_embeddings_prefix = f'{prefix}word_embeddings.'
- word_embeddings_state_dict = self.word_embeddings.state_dict(
- prefix=word_embeddings_prefix, keep_vars=True
- )
-
- sharded_word_embeddings_key = f'{word_embeddings_prefix}weight'
- sharded_word_embeddings_tensor = make_tp_sharded_tensor_for_checkpoint(
- tensor=word_embeddings_state_dict[sharded_word_embeddings_key],
- key=sharded_word_embeddings_key,
- allow_shape_mismatch=True,
- )
- sharded_state_dict[sharded_word_embeddings_key] = sharded_word_embeddings_tensor
-
- if self.add_position_embedding:
- position_embeddings_prefix = f'{prefix}position_embeddings.'
- position_embeddings_state_dict = self.position_embeddings.state_dict(
- prefix=position_embeddings_prefix, keep_vars=True
- )
- sharded_position_embeddings_key = f'{position_embeddings_prefix}weight'
- sharded_position_embeddings_tensor = make_sharded_tensor_for_checkpoint(
- tensor=position_embeddings_state_dict[sharded_position_embeddings_key],
- key=sharded_position_embeddings_key,
- )
- sharded_state_dict[sharded_position_embeddings_key] = sharded_position_embeddings_tensor
-
- return sharded_state_dict
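The embedding module above adds word, (optionally) learned position, and (optionally) token-type embeddings, then switches the layout from [b, s, h] to [s, b, h]. A self-contained sketch of that composition with illustrative sizes:

```python
import torch

vocab, max_len, h = 30522, 512, 1024
word_emb = torch.nn.Embedding(vocab, h)
pos_emb = torch.nn.Embedding(max_len, h)
tok_type_emb = torch.nn.Embedding(2, h)            # num_tokentypes = 2 (binary head case)

input_ids = torch.randint(0, vocab, (4, 128))                        # [b, s]
position_ids = torch.arange(128).unsqueeze(0).expand_as(input_ids)   # [b, s]
tokentype_ids = torch.zeros_like(input_ids)

embeddings = word_emb(input_ids) + pos_emb(position_ids)              # [b, s, h]
embeddings = embeddings.transpose(0, 1).contiguous()                  # [s, b, h]
embeddings = embeddings + tok_type_emb(tokentype_ids).permute(1, 0, 2)
```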
diff --git a/megatron/core/models/common/embeddings/rotary_pos_embedding.py b/megatron/core/models/common/embeddings/rotary_pos_embedding.py
deleted file mode 100644
index ee2260e3ae001c756d3011bb9fe03ece338716be..0000000000000000000000000000000000000000
--- a/megatron/core/models/common/embeddings/rotary_pos_embedding.py
+++ /dev/null
@@ -1,167 +0,0 @@
-# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
-
-from __future__ import annotations
-
-from typing import TYPE_CHECKING
-
-if TYPE_CHECKING:
- from megatron.core.transformer.transformer_config import TransformerConfig
- from megatron.core.transformer.transformer_block import TransformerBlock
-
-import torch
-from torch import Tensor, nn
-
-from megatron.core import parallel_state
-
-__all__ = ['RotaryEmbedding', 'apply_rotary_pos_emb']
-
-
-def get_pos_emb_on_this_cp_rank(pos_emb, seq_dim):
- cp_size = parallel_state.get_context_parallel_world_size()
- cp_rank = parallel_state.get_context_parallel_rank()
- cp_idx = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)], device=pos_emb.device)
- pos_emb = pos_emb.view(
- *pos_emb.shape[:seq_dim], 2 * cp_size, -1, *pos_emb.shape[(seq_dim + 1) :]
- )
- pos_emb = pos_emb.index_select(seq_dim, cp_idx)
- pos_emb = pos_emb.view(*pos_emb.shape[:seq_dim], -1, *pos_emb.shape[(seq_dim + 2) :])
- return pos_emb
-
-
-class RotaryEmbedding(nn.Module):
- """Rotary Embedding for language model.
-
- Args:
- kv_channels (int): Projection weights dimension in multi-head attention. Obtained from transformer config
- rotary_percent (float): Percent of rotary dimension to use for rotary position embeddings.
- seq_len_interpolation_factor (float, optional): scale of linearly interpolating RoPE for longer sequences. The value must be a float larger than 1.0. Defaults to None
- rotary_base (int, optional): Base period for rotary position embeddings. Defaults to 10000.
- """
-
- def __init__(
- self,
- kv_channels: int,
- rotary_percent: float,
- seq_len_interpolation_factor: float = None,
- rotary_base: int = 10000,
- ) -> None:
- super().__init__()
-
- dim = kv_channels
- if rotary_percent < 1.0:
- dim = int(dim * rotary_percent)
-
- self.seq_len_interpolation_factor = seq_len_interpolation_factor
- self.inv_freq = 1.0 / (
- rotary_base
- ** (
- torch.arange(0, dim, 2, dtype=torch.float32, device=torch.cuda.current_device())
- / dim
- )
- )
-
- def forward(self, max_seq_len: int, offset: int = 0) -> Tensor:
- """Forward pass of RoPE embedding.
-
- Args:
- max_seq_len (int): Maximum size of sequence
- offset (int, optional): Offset added to the position indices. Defaults to 0.
-
- Returns:
- Tensor: Embeddings after applying RoPE.
- """
- seq = (
- torch.arange(max_seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
- + offset
- )
-
- if self.seq_len_interpolation_factor is not None:
- seq *= 1 / self.seq_len_interpolation_factor
-
- freqs = torch.outer(seq, self.inv_freq)
- # first part even vector components, second part odd vector components,
- # 2 * dim in dimension size
- emb = torch.cat((freqs, freqs), dim=-1)
- # emb [seq_length, .., dim]
- emb = emb[:, None, None, :]
- if parallel_state.get_context_parallel_world_size() > 1:
- # slice rotary_pos_emb along sequence dimension and select the partition of the current CP rank
- emb = get_pos_emb_on_this_cp_rank(emb, 0)
- return emb
-
- def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
- state_dict.pop(f'{prefix}inv_freq', None)
- return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
-
- def get_rotary_seq_len(
- self,
- inference_params,
- transformer: TransformerBlock,
- transformer_input: Tensor,
- transformer_config: TransformerConfig,
- ) -> float:
- """Function to get the rotary sequence length.
-
- Args:
- inference_params : Used during Inference time
- transformer (TransformerBlock): The transformer block (decoder/encoder) used by the model
- transformer_input (Tensor): Input tensor to the transformer block, used to infer the sequence length
- transformer_config (TransformerConfig): Transformer config used by the model
-
- Returns:
- float: The rotary sequence length
- """
- if inference_params is not None:
- rotary_seq_len = inference_params.max_sequence_length
- else:
- if transformer.input_tensor is not None:
- rotary_seq_len = transformer.input_tensor.size(0)
- else:
- rotary_seq_len = transformer_input.size(0)
-
- if transformer_config.sequence_parallel:
- rotary_seq_len *= transformer_config.tensor_model_parallel_size
-
- rotary_seq_len *= transformer_config.context_parallel_size
-
- return rotary_seq_len
-
-
-def _rotate_half(x: Tensor) -> Tensor:
- """Change sign so the last dimension becomes [-odd, +even]
-
- Args:
- x (Tensor): Input tensor
-
- Returns:
- Tensor: Tensor rotated half
- """
-
- x1, x2 = torch.chunk(x, 2, dim=-1)
- return torch.cat((-x2, x1), dim=-1)
-
-
-def apply_rotary_pos_emb(t: Tensor, freqs: Tensor) -> Tensor:
- """Apply rotary positional embedding to input tensor T.
-
- check https://kexue.fm/archives/8265 for detailed formulas
-
- Args:
- t (Tensor): Input tensor T is of shape [seq_length, ... , dim]
- freqs (Tensor): Rotary Positional embedding tensor freq is of shape [seq_length, ..., dim]
-
- Returns:
- Tensor: The input tensor after applying RoPE
- """
- rot_dim = freqs.shape[-1]
-
- # ideally t_pass is empty, so rotary pos embedding is applied to the full tensor t
- t, t_pass = t[..., :rot_dim], t[..., rot_dim:]
-
- # first part is cosine component
- # second part is sine component, need to change signs with _rotate_half method
- cos_ = torch.cos(freqs).to(t.dtype)
- sin_ = torch.sin(freqs).to(t.dtype)
-
- t = (t * cos_) + (_rotate_half(t) * sin_)
- return torch.cat((t, t_pass), dim=-1)
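The RoPE math above boils down to: build per-position angles from `inv_freq`, duplicate them so even and odd channels share a frequency, then rotate the input with `t * cos + rotate_half(t) * sin`. A self-contained numerical sketch (dimensions are illustrative):

```python
import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = torch.chunk(x, 2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

dim, seq_len, base = 64, 16, 10000
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
freqs = torch.outer(torch.arange(seq_len, dtype=torch.float32), inv_freq)   # [s, dim/2]
emb = torch.cat((freqs, freqs), dim=-1)[:, None, None, :]                   # [s, 1, 1, dim]

t = torch.randn(seq_len, 2, 4, dim)                   # [s, b, heads, dim]
t_rotated = t * emb.cos() + rotate_half(t) * emb.sin()
```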
diff --git a/megatron/core/models/common/language_module/__init__.py b/megatron/core/models/common/language_module/__init__.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/megatron/core/models/common/language_module/language_module.py b/megatron/core/models/common/language_module/language_module.py
deleted file mode 100644
index 97fbbf0f66229eb2ef4ecea52940bfef86342ad8..0000000000000000000000000000000000000000
--- a/megatron/core/models/common/language_module/language_module.py
+++ /dev/null
@@ -1,98 +0,0 @@
-import logging
-
-import torch
-from torch import Tensor
-
-from megatron.core import parallel_state, tensor_parallel
-from megatron.core.transformer.module import MegatronModule
-from megatron.core.transformer.transformer_config import TransformerConfig
-
-
-class LanguageModule(MegatronModule):
- """Base language module that has common helper functions used across GPT, BERT etc.
-
- Args:
- config (TransformerConfig): Input transformer config for the model
- """
-
- def __init__(self, config: TransformerConfig) -> None:
- super().__init__(config=config)
-
- def compute_language_model_loss(self, labels: Tensor, logits: Tensor) -> Tensor:
- """Computes the language model loss (Cross entropy across vocabulary)
-
- Args:
- labels (Tensor): The labels of dimension [batch size, seq length]
- logits (Tensor): The final logits returned by the output layer of the transformer model
-
- Returns:
- Tensor: Loss tensor of dimensions [batch size, sequence_length]
- """
- # [b s] => [s b]
- labels = labels.transpose(0, 1).contiguous()
- loss = tensor_parallel.vocab_parallel_cross_entropy(logits.float(), labels)
-
- # [s b] => [b, s]
- loss = loss.transpose(0, 1).contiguous()
- return loss
-
- def initialize_last_stage_with_word_embeddings(self) -> None:
- """Initializes the word embeddings in the final stage.
-
- This function just initializes word embeddings in the final stage, when we are
- using pipeline parallelism and sharing word embeddings. Nothing to do if we
- aren't sharing weights or aren't using pipeline parallelism.
- """
- if not self.share_embeddings_and_output_weights or (self.pre_process and self.post_process):
- return
-
- if self.post_process and not self.pre_process:
- assert not parallel_state.is_pipeline_first_stage()
- # set word_embeddings weights to 0 here, then copy first
- # stage's weights using all_reduce below.
- self.output_layer.weight.data.fill_(0)
- self.output_layer.weight.shared = True
-
- # Parameters are shared between the word embeddings layers, and the
- # heads at the end of the model. In a pipelined setup with more than
- # one stage, the initial embedding layer and the head are on different
- # workers, so we do the following:
- # 1. Create a second copy of word_embeddings on the last stage, with
- # initial parameters of 0.0.
- # 2. Do an all-reduce between the first and last stage to ensure that
- # the two copies of word_embeddings start off with the same
- # parameter values.
- # 3. In the training loop, perform an all-reduce between the grads of
- # the two word_embeddings layers to ensure that every applied weight
- # update is the same on both stages.
-
- # Ensure that first and last stages have the same initial parameter
- # values.
- if torch.distributed.is_initialized():
- if parallel_state.is_rank_in_embedding_group():
- weight = self.shared_embedding_or_output_weight()
- torch.distributed.all_reduce(
- weight.data, group=parallel_state.get_embedding_group()
- )
-
- elif not getattr(LanguageModule, "embedding_warning_printed", False):
- logging.getLogger(__name__).warning(
- "Distributed processes aren't initialized, so the output layer "
- "is not initialized with weights from the word embeddings. "
- "If you are just manipulating a model this is fine, but "
- "this needs to be handled manually. If you are training "
- "something is definitely wrong."
- )
- LanguageModule.embedding_warning_printed = True
-
- def shared_embedding_or_output_weight(self) -> Tensor:
- """Gets the embedding weight or output logit weight when share_embeddings_and_output_weights is set to True.
-
- Returns:
- Tensor: During pre-processing it returns the input embedding weight, while during post-processing it returns the final output layer's weight
- """
- if self.pre_process:
- return self.embedding.word_embeddings.weight
- elif self.post_process:
- return self.output_layer.weight
- return None
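`compute_language_model_loss` works in sequence-first layout: it transposes the [b, s] labels to [s, b], computes a (vocab-parallel) cross entropy against the [s, b, v] logits, and transposes the per-token loss back to [b, s]. A single-GPU sketch with ordinary cross entropy standing in for the vocab-parallel version:

```python
import torch
import torch.nn.functional as F

s, b, v = 16, 2, 100
logits = torch.randn(s, b, v)                    # [s, b, v] from the output layer
labels = torch.randint(0, v, (b, s))             # [b, s]

labels_sb = labels.transpose(0, 1)               # [s, b], matching the logits layout
loss_sb = F.cross_entropy(
    logits.reshape(s * b, v), labels_sb.reshape(s * b), reduction="none"
).view(s, b)
loss = loss_sb.transpose(0, 1)                   # back to [b, s], as the method returns
```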
diff --git a/megatron/core/models/gpt/__init__.py b/megatron/core/models/gpt/__init__.py
deleted file mode 100644
index 2d5eb8674f1d19673664160d5eddf3432a6a5399..0000000000000000000000000000000000000000
--- a/megatron/core/models/gpt/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-from .gpt_model import GPTModel
diff --git a/megatron/core/models/gpt/gpt_layer_specs.py b/megatron/core/models/gpt/gpt_layer_specs.py
deleted file mode 100644
index aace1590d82456dc3b9d32d783bd01aad6514dd2..0000000000000000000000000000000000000000
--- a/megatron/core/models/gpt/gpt_layer_specs.py
+++ /dev/null
@@ -1,123 +0,0 @@
-# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
-
-from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
-from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
-from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
-from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
-from megatron.core.transformer.custom_layers.transformer_engine import (
- TEDotProductAttention,
- TELayerNormColumnParallelLinear,
- TERowParallelLinear,
-)
-from megatron.core.transformer.dot_product_attention import DotProductAttention
-from megatron.core.transformer.enums import AttnMaskType
-from megatron.core.transformer.mlp import MLP, MLPSubmodules
-from megatron.core.transformer.spec_utils import ModuleSpec
-from megatron.core.transformer.switch_mlp import SwitchMLP
-from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules
-
-
-# Use this spec to use lower level Transformer Engine modules (required for fp8 training)
-def get_gpt_layer_with_transformer_engine_spec() -> ModuleSpec:
- return ModuleSpec(
- module=TransformerLayer,
- submodules=TransformerLayerSubmodules(
- self_attention=ModuleSpec(
- module=SelfAttention,
- params={"attn_mask_type": AttnMaskType.causal},
- submodules=SelfAttentionSubmodules(
- linear_qkv=TELayerNormColumnParallelLinear,
- core_attention=TEDotProductAttention,
- linear_proj=TERowParallelLinear,
- ),
- ),
- self_attn_bda=get_bias_dropout_add,
- mlp=ModuleSpec(
- module=MLP,
- submodules=MLPSubmodules(
- linear_fc1=TELayerNormColumnParallelLinear, linear_fc2=TERowParallelLinear,
- ),
- ),
- mlp_bda=get_bias_dropout_add,
- ),
- )
-
-
-# Use this spec for an implementation using only modules in megatron core
-def get_gpt_layer_local_spec() -> ModuleSpec:
- return ModuleSpec(
- module=TransformerLayer,
- submodules=TransformerLayerSubmodules(
- input_layernorm=FusedLayerNorm,
- self_attention=ModuleSpec(
- module=SelfAttention,
- params={"attn_mask_type": AttnMaskType.causal},
- submodules=SelfAttentionSubmodules(
- linear_qkv=ColumnParallelLinear,
- core_attention=DotProductAttention,
- linear_proj=RowParallelLinear,
- ),
- ),
- self_attn_bda=get_bias_dropout_add,
- pre_mlp_layernorm=FusedLayerNorm,
- mlp=ModuleSpec(
- module=MLP,
- submodules=MLPSubmodules(
- linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear,
- ),
- ),
- mlp_bda=get_bias_dropout_add,
- ),
- )
-
-
-# Use this spec to use lower level Transformer Engine modules and SwitchMLP based MoE
-gpt_layer_with_transformer_engine_spec_moe = ModuleSpec(
- module=TransformerLayer,
- submodules=TransformerLayerSubmodules(
- self_attention=ModuleSpec(
- module=SelfAttention,
- params={"attn_mask_type": AttnMaskType.causal},
- submodules=SelfAttentionSubmodules(
- linear_qkv=TELayerNormColumnParallelLinear,
- core_attention=TEDotProductAttention,
- linear_proj=TERowParallelLinear,
- ),
- ),
- self_attn_bda=get_bias_dropout_add,
- pre_mlp_layernorm=FusedLayerNorm,
- mlp=ModuleSpec(
- module=SwitchMLP, # MOE
- submodules=MLPSubmodules(
- linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear,
- ),
- ),
- mlp_bda=get_bias_dropout_add,
- ),
-)
-
-# Use this spec for an implementation using only modules in megatron core for MoE models
-gpt_layer_local_spec_moe = ModuleSpec(
- module=TransformerLayer,
- submodules=TransformerLayerSubmodules(
- input_layernorm=FusedLayerNorm,
- self_attention=ModuleSpec(
- module=SelfAttention,
- params={"attn_mask_type": AttnMaskType.causal},
- submodules=SelfAttentionSubmodules(
- linear_qkv=ColumnParallelLinear,
- core_attention=DotProductAttention,
- linear_proj=RowParallelLinear,
- ),
- ),
- self_attn_bda=get_bias_dropout_add,
- pre_mlp_layernorm=FusedLayerNorm,
- mlp=ModuleSpec(
- module=SwitchMLP, # MOE
- submodules=MLPSubmodules(
- linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear,
- ),
- ),
- mlp_bda=get_bias_dropout_add,
- ),
-)
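As with the BERT specs, the GPT variants above differ only in the modules they wire together: the `get_gpt_layer_*` functions return dense layers, while the module-level `*_moe` specs swap the MLP for `SwitchMLP`. A small selection sketch (import path as in the deleted file; `num_experts` is an illustrative parameter):

```python
from megatron.core.models.gpt.gpt_layer_specs import (
    get_gpt_layer_local_spec,
    gpt_layer_local_spec_moe,
)

def pick_gpt_layer_spec(num_experts: int):
    # MoE layers route tokens through SwitchMLP experts; dense layers use a single MLP.
    return gpt_layer_local_spec_moe if num_experts > 0 else get_gpt_layer_local_spec()
```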
diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py
deleted file mode 100644
index 2cf26bacacd21ba26bd8df0eefb2d52a0f53ea95..0000000000000000000000000000000000000000
--- a/megatron/core/models/gpt/gpt_model.py
+++ /dev/null
@@ -1,241 +0,0 @@
-# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
-
-import logging
-from typing import Literal, Optional, Union
-
-import torch
-from torch import Tensor
-
-from megatron.core import InferenceParams, parallel_state, tensor_parallel
-from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding
-from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
-from megatron.core.models.common.language_module.language_module import LanguageModule
-from megatron.core.transformer.enums import AttnMaskType, ModelType
-from megatron.core.transformer.spec_utils import ModuleSpec
-from megatron.core.transformer.transformer_block import TransformerBlock
-from megatron.core.transformer.transformer_config import TransformerConfig
-from megatron.core.utils import make_tp_sharded_tensor_for_checkpoint
-
-
-class GPTModel(LanguageModule):
- """GPT Transformer language model.
-
- Args:
- config (TransformerConfig): Transformer config
- transformer_layer_spec (ModuleSpec): Specifies module to use for transformer layers
- vocab_size (int): Vocabulary size
- max_sequence_length (int): maximum size of sequence. This is used for positional embedding
- pre_process (bool, optional): Include embedding layer (used with pipeline parallelism). Defaults to True.
- post_process (bool, optional): Include an output layer (used with pipeline parallelism). Defaults to True.
- fp16_lm_cross_entropy (bool, optional): Defaults to False.
- parallel_output (bool, optional): Do not gather the outputs, keep them split across tensor parallel ranks. Defaults to True.
- share_embeddings_and_output_weights (bool, optional): When True, input embeddings and output logit weights are shared. Defaults to False.
- position_embedding_type (Literal[learned_absolute,rope], optional): Position embedding type. Defaults to 'learned_absolute'.
- rotary_percent (float, optional): Percent of rotary dimension to use for rotary position embeddings. Ignored unless position_embedding_type is 'rope'. Defaults to 1.0.
- rotary_base (int, optional): Base period for rotary position embeddings. Ignored unless position_embedding_type is 'rope'. Defaults to 10000.
- seq_len_interpolation_factor (Optional[float], optional): scale of linearly interpolating RoPE for longer sequences. The value must be a float larger than 1.0. Defaults to None.
- """
-
- def __init__(
- self,
- config: TransformerConfig,
- transformer_layer_spec: ModuleSpec,
- vocab_size: int,
- max_sequence_length: int,
- pre_process: bool = True,
- post_process: bool = True,
- fp16_lm_cross_entropy: bool = False,
- parallel_output: bool = True,
- share_embeddings_and_output_weights: bool = False,
- position_embedding_type: Literal['learned_absolute', 'rope'] = 'learned_absolute',
- rotary_percent: float = 1.0,
- rotary_base: int = 10000,
- seq_len_interpolation_factor: Optional[float] = None,
- ) -> None:
- super().__init__(config=config)
-
- self.transformer_layer_spec: ModuleSpec = transformer_layer_spec
- self.vocab_size = vocab_size
- self.max_sequence_length = max_sequence_length
- self.pre_process = pre_process
- self.post_process = post_process
- self.fp16_lm_cross_entropy = fp16_lm_cross_entropy
- self.parallel_output = parallel_output
- self.share_embeddings_and_output_weights = share_embeddings_and_output_weights
- self.position_embedding_type = position_embedding_type
-
- # megatron core pipelining currently depends on model type
- # TODO: remove this dependency ?
- self.model_type = ModelType.encoder_or_decoder
-
- if self.pre_process:
- self.embedding = LanguageModelEmbedding(
- config=self.config,
- vocab_size=self.vocab_size,
- max_sequence_length=self.max_sequence_length,
- position_embedding_type=position_embedding_type,
- )
-
- if self.position_embedding_type == 'rope':
- self.rotary_pos_emb = RotaryEmbedding(
- kv_channels=self.config.kv_channels,
- rotary_percent=rotary_percent,
- seq_len_interpolation_factor=seq_len_interpolation_factor,
- rotary_base=rotary_base,
- )
-
- # Transformer.
- self.decoder = TransformerBlock(
- config=self.config,
- spec=transformer_layer_spec,
- pre_process=self.pre_process,
- post_process=self.post_process,
- )
-
- # Output
- if post_process:
- self.output_layer = tensor_parallel.ColumnParallelLinear(
- config.hidden_size,
- self.vocab_size,
- config=config,
- init_method=config.init_method,
- bias=False,
- skip_bias_add=False,
- gather_output=not self.parallel_output,
- skip_weight_param_allocation=self.pre_process
- and self.share_embeddings_and_output_weights,
- )
-
- if self.share_embeddings_and_output_weights and (self.pre_process or self.post_process):
- self.initialize_last_stage_with_word_embeddings()
-
- def set_input_tensor(self, input_tensor: Tensor) -> None:
- """Sets input tensor to the model.
-
- See megatron.model.transformer.set_input_tensor()
-
- Args:
- input_tensor (Tensor): Sets the input tensor for the model.
- """
- # This is usually handled in schedules.py but some inference code still
- # gives us non-lists or None
- if not isinstance(input_tensor, list):
- input_tensor = [input_tensor]
-
- assert len(input_tensor) == 1, 'input_tensor should only be length 1 for gpt/bert'
- self.decoder.set_input_tensor(input_tensor[0])
-
- def forward(
- self,
- input_ids: Tensor,
- position_ids: Tensor,
- attention_mask: Tensor,
- decoder_input: Tensor = None,
- labels: Tensor = None,
- inference_params: InferenceParams = None,
- extra_block_kwargs: dict = None,
- ) -> Tensor:
- """Forward function of the GPT model. This function passes the input tensors
- through the embedding layer, then the decoder, and finally into the post
- processing layer (optional).
-
- It either returns the Loss values if labels are given or the final hidden units
- """
- # If decoder_input is provided (not None), then input_ids and position_ids are ignored.
- # Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input.
-
- # Decoder embedding.
- if decoder_input is not None:
- pass
- elif self.pre_process:
- decoder_input = self.embedding(input_ids=input_ids, position_ids=position_ids)
- else:
- # intermediate stage of pipeline
- # decoder will get hidden_states from encoder.input_tensor
- decoder_input = None
-
- # Rotary positional embeddings (embedding is None for PP intermediate devices)
- rotary_pos_emb = None
- if self.position_embedding_type == 'rope':
- rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len(
- inference_params, self.decoder, decoder_input, self.config
- )
- rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len)
-
- # Run decoder.
- hidden_states = self.decoder(
- hidden_states=decoder_input,
- attention_mask=attention_mask,
- inference_params=inference_params,
- rotary_pos_emb=rotary_pos_emb,
- **(extra_block_kwargs or {}),
- )
-
- if not self.post_process:
- return hidden_states
-
- # logits and loss
- output_weight = None
- if self.share_embeddings_and_output_weights:
- output_weight = self.shared_embedding_or_output_weight()
- logits, _ = self.output_layer(hidden_states, weight=output_weight)
-
- if labels is None:
- # [s b h] => [b s h]
- return logits.transpose(0, 1).contiguous()
-
- loss = self.compute_language_model_loss(labels, logits)
-
- return loss
-
- def sharded_state_dict(self, prefix: str = '') -> dict:
- sharded_state_dict = {}
-
- if self.pre_process:
- embedding_prefix = f'{prefix}embedding.'
- embedding_sharded_state_dict = self.embedding.sharded_state_dict(
- prefix=embedding_prefix
- )
- sharded_state_dict.update(embedding_sharded_state_dict)
-
- decoder_prefix = f'{prefix}decoder.'
- decoder_sharded_state_dict = self.decoder.sharded_state_dict(prefix=decoder_prefix)
- sharded_state_dict.update(decoder_sharded_state_dict)
-
- if self.post_process:
- output_layer_prefix = f'{prefix}output_layer.'
- output_layer_key = f'{output_layer_prefix}weight'
- if self.share_embeddings_and_output_weights:
- if not self.pre_process:
- # when sharing embeddings with last stage, we need to use the weights from the first stage
- # on pipeline first rank, word embeddings are saved to {prefix}embedding.word_embeddings.weight
- tensor = self.shared_embedding_or_output_weight()
- first_stage_word_emb_key = f'{prefix}embedding.word_embeddings.weight'
- last_stage_word_emb_replica_id = (
- 1, # copy of first stage embedding
- 0,
- parallel_state.get_data_parallel_rank(),
- )
-
- sharded_output_layer_tensor = make_tp_sharded_tensor_for_checkpoint(
- tensor=tensor,
- key=first_stage_word_emb_key,
- replica_id=last_stage_word_emb_replica_id,
- allow_shape_mismatch=True,
- )
-
- sharded_state_dict[output_layer_key] = sharded_output_layer_tensor
-
- else:
- output_layer_state_dict = self.output_layer.state_dict(
- prefix=output_layer_prefix, keep_vars=True
- )
- output_layer_tensor = output_layer_state_dict[output_layer_key]
- # independent output layer
- sharded_output_layer_tensor = make_tp_sharded_tensor_for_checkpoint(
- tensor=output_layer_tensor, key=output_layer_key, allow_shape_mismatch=True,
- )
-
- sharded_state_dict[output_layer_key] = sharded_output_layer_tensor
-
- return sharded_state_dict
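For orientation, a hypothetical construction of the `GPTModel` above. This is a sketch, not a runnable recipe: it assumes Megatron's model-parallel state has already been initialized (e.g. via `megatron.core.parallel_state.initialize_model_parallel()`), and `TransformerConfig` field names may differ between Megatron versions.

```python
from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec
from megatron.core.models.gpt.gpt_model import GPTModel
from megatron.core.transformer.transformer_config import TransformerConfig

config = TransformerConfig(
    num_layers=2,
    hidden_size=256,
    num_attention_heads=8,
    use_cpu_initialization=True,   # assumption: initialize weights on the CPU
)

model = GPTModel(
    config=config,
    transformer_layer_spec=get_gpt_layer_local_spec(),
    vocab_size=32000,
    max_sequence_length=1024,
)
# forward(input_ids, position_ids, attention_mask) returns [b, s, v] logits when
# labels are omitted, or a [b, s] loss tensor when labels are provided.
```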
diff --git a/megatron/core/models/retro/__init__.py b/megatron/core/models/retro/__init__.py
deleted file mode 100644
index c101fcb1e4cf51be9b2e2268597ed1b1f11a9319..0000000000000000000000000000000000000000
--- a/megatron/core/models/retro/__init__.py
+++ /dev/null
@@ -1,5 +0,0 @@
-# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
-
-from .config import RetroConfig
-from .decoder_spec import get_retro_decoder_block_spec
-from .model import RetroModel
diff --git a/megatron/core/models/retro/base_attention.py b/megatron/core/models/retro/base_attention.py
deleted file mode 100644
index 4bafd48daf321e7db6e907ac520ebc92716c93a6..0000000000000000000000000000000000000000
--- a/megatron/core/models/retro/base_attention.py
+++ /dev/null
@@ -1,45 +0,0 @@
-# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
-
-from megatron.core.models.retro.config import RetroConfig
-from megatron.core.transformer.attention import CrossAttention, CrossAttentionSubmodules
-from megatron.core.transformer.enums import AttnMaskType
-from megatron.core.transformer.module import MegatronModule
-
-
-class BaseRetroCrossAttention(MegatronModule):
-
- """Base class for Retro cross attention, for both encoder & decoder layers.
-
- This class collects the retro arguments below (i.e., num neighbors, chunk
- length, and retrieve length) for use in Retro's custom cross attention
- operators.
-
- Arguments:
- config (RetroConfig): Retro config.
-
- submodules (CrossAttentionSubmodules): Cross attention submodules.
-
- layer_number (int): Layer number within transformer block.
-
- attn_mask_type (AttnMaskType): Mask type ('causal' or 'padding').
- """
-
- def __init__(
- self,
- config: RetroConfig,
- submodules: CrossAttentionSubmodules,
- layer_number: int = 1,
- attn_mask_type: AttnMaskType = AttnMaskType.padding,
- ):
- super().__init__(config=config)
-
- self.attn = CrossAttention(
- config=config,
- submodules=submodules,
- layer_number=layer_number,
- attn_mask_type=attn_mask_type,
- )
-
- self.retro_num_neighbors = config.retro_num_neighbors
- self.retro_chunk_length = config.retro_preprocess.retro_gpt_chunk_length
- self.retro_retrieved_length = config.retro_preprocess.retro_gpt_retrieved_length
diff --git a/megatron/core/models/retro/config.py b/megatron/core/models/retro/config.py
deleted file mode 100644
index 2ffeb94bb386c3d394f17b3fd5e8bbf74d495474..0000000000000000000000000000000000000000
--- a/megatron/core/models/retro/config.py
+++ /dev/null
@@ -1,43 +0,0 @@
-# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
-
-import types
-from dataclasses import dataclass
-
-from megatron.core.transformer import TransformerConfig
-
-
-@dataclass
-class RetroConfig(TransformerConfig):
-
- """Configuration object for Retro models.
-
- Attributes:
-
- retro_preprocess (SimpleNamespace): Retro preprocess arguments.
- retro_workdir (str): Retro working directory, which contains the
- preprocessed data for pretraining. This directory is built during
- preprocessing (see tools/retro/README.md), and contains subdirectories
- for the chunk database and pretraining neighbors.
- retro_encoder_num_layers (int): Number of layers to use for the retrieval
- encoder.
- retro_encoder_hidden_dropout (float): Hidden dropout for retrieval
- encoder.
- retro_encoder_attention_dropout (float): Attention dropout for retrieval
- encoder.
- retro_num_neighbors (int): Number of neighbors to retrieve during
- pretraining.
- retro_num_retrieved_chunks (int): Number of chunks to retrieve from the
- retrieval database.
- retro_verify_neighbor_count (bool): Verify that len(GPT dataset) ==
- len(saved neighbors).
- """
-
- # Retro.
- retro_preprocess: types.SimpleNamespace = None
- retro_workdir: str = None
- retro_encoder_num_layers: int = 2
- retro_encoder_hidden_dropout: float = 0.1
- retro_encoder_attention_dropout: float = 0.1
- retro_num_neighbors: int = 2
- retro_num_retrieved_chunks: int = 2
- retro_verify_neighbor_count: bool = True
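A hedged instantiation sketch for `RetroConfig`: the base `TransformerConfig` fields are illustrative, the workdir path is a placeholder, and `retro_preprocess` carries only the two attributes read by `BaseRetroCrossAttention` above (`retro_gpt_chunk_length`, `retro_gpt_retrieved_length`).

```python
import types

from megatron.core.models.retro.config import RetroConfig

retro_config = RetroConfig(
    num_layers=12,
    hidden_size=768,
    num_attention_heads=12,
    retro_workdir="/path/to/retro_workdir",     # placeholder path
    retro_num_neighbors=2,
    retro_preprocess=types.SimpleNamespace(
        retro_gpt_chunk_length=64,
        retro_gpt_retrieved_length=128,
    ),
)
```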
diff --git a/megatron/core/models/retro/decoder_attention.py b/megatron/core/models/retro/decoder_attention.py
deleted file mode 100644
index f934c6c717f370c682329484b47c39c4f0b71577..0000000000000000000000000000000000000000
--- a/megatron/core/models/retro/decoder_attention.py
+++ /dev/null
@@ -1,301 +0,0 @@
-# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
-
-"""Retro's cross attention modules for the decoder block."""
-
-from functools import partial
-from typing import Callable
-
-import numpy as np
-import torch
-from torch import Tensor
-
-from megatron.core import InferenceParams
-from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
-from megatron.core.models.retro.base_attention import BaseRetroCrossAttention
-from megatron.core.models.retro.config import RetroConfig
-from megatron.core.transformer import ModuleSpec
-from megatron.core.transformer.attention import CrossAttentionSubmodules
-from megatron.core.transformer.enums import AttnMaskType
-from megatron.core.transformer.module import MegatronModule
-from megatron.core.transformer.transformer_block import TransformerBlock
-
-
-class RetroDecoderCrossAttention(BaseRetroCrossAttention):
-
- """Retro decoder's chunked cross attention operator.
-
- See this paper for more details: https://arxiv.org/abs/2112.04426.
- Neighboring chunks retrieved from the chunk database are used here for
- chunked-cross attention.
-
- Arguments:
- config (RetroConfig): Retro config.
-
- submodules (CrossAttentionSubmodules): Cross attention submodules.
-
- layer_number (int): Layer number within transformer block.
-
- attn_mask_type (AttnMaskType): Mask type ('causal' or 'padding').
-
- encoder_block_spec (ModuleSpec): The first Retro decoder
- layer is provided with a transformer block spec to construct the
- neighbor encoder.
- """
-
- def __init__(
- self,
- config: RetroConfig,
- submodules: CrossAttentionSubmodules,
- layer_number: int = 1,
- attn_mask_type: AttnMaskType = AttnMaskType.padding,
- encoder_block_spec: ModuleSpec = None,
- ):
- """
- ** Note about 'encoder_block_spec' **
-
- Retro is an encoder-decoder model that uses its encoder for encoding
- neighboring chunks that are retrieved from a chunk database. These
- encoded neighbors are then used in the decoder stack for performing
- chunked-cross attention (see paper link above).
-
- In contrast to the T5 model, the encoder and decoder are computationally
- intertwined, since the input to the encoder is the output of the self-
- attention of the first decoder layer. As such, the encoder block itself
- is instantiated within the first Retro decoder layer, in order to receive
- the self-attention's output. (Note that only the first decoder layer
- instantiates an encoder block, and the remaining decoder layers use the
- encoder output from the first decoder layer.)
- """
-
- super().__init__(
- config=config,
- submodules=submodules,
- layer_number=layer_number,
- attn_mask_type=attn_mask_type,
- )
-
- if encoder_block_spec:
- self.encoder = TransformerBlock(
- config=config, spec=encoder_block_spec, pre_process=True, post_process=False,
- )
- # self._encoder_key = 'encoder' # ... necessary?
- else:
- self.encoder = None
-
- def forward(
- self,
- hidden_states: Tensor,
- attention_mask: Tensor,
- key_value_states: Tensor = None,
- inference_params: InferenceParams = None,
- # rotary_pos_emb: Tensor = None, # ... unsupported for retro.
- ) -> Tensor:
- """Cross attention for Retro decoder.
-
- Notation:
- ns : Sequence length.
- bs : Batch size.
- d : Hidden size.
- l : Number of chunks per sample (i.e., seq_length/chunk_length).
- m : Number of tokens per chunk.
- k : Number of neighbors.
- r : Number of retrieved tokens (neighbors + continuation).
-
- Arguments:
- hidden_states (Tensor): Transformer layer hidden states.
-
- attention_mask (Tensor): Attention mask.
-
- key_value_states (Tensor): Neighbor embeddings if first decoder
- layer, else encoder output.
-
- inference_params (InferenceParams): Inference params.
- """
-
- # hidden_states: [ ns, bs, d ]
- # key_value_states: [ r, k*bs*l, d ]
-
- ns, bs, d = hidden_states.shape
- l = int(np.ceil(ns / self.retro_chunk_length))
-
- # Retrieve neighbors.
- if self.encoder:
-
- # Sequence length remainder.
- first_ns = ns % self.retro_chunk_length
-
- # Case 1: Sequence length not divisible by chunk length.
- if first_ns > 0:
-
- # Split sequence into first partial chunk & remaining chunks.
- first_chunk, rest_chunk = hidden_states[:first_ns], hidden_states[first_ns:]
-
- # Pad partial chunk with zeros.
- first_chunk = torch.nn.functional.pad(
- first_chunk, (0, 0, 0, 0, 0, self.retro_chunk_length - first_ns), 'constant', 0,
- )
-
- # Concatenate padded chunk with remaining chunks.
- chunked_output = torch.cat((first_chunk, rest_chunk), dim=0) # [ l*m, bs, d ]
-
- # Case 2: Sequence length is divisible by chunk length.
- else:
- chunked_output = hidden_states # [ l*m, bs, d ]
-
- # Chunk & permute hidden states.
- # - hidden_states: [ l*m, bs, d ]
- # - chunked_output: [ m, bs*l, d ]
- chunked_output = (
- chunked_output.reshape(l, self.retro_chunk_length, bs, d)
- .permute(1, 2, 0, 3)
- .reshape(self.retro_chunk_length, bs * l, d)
- .contiguous()
- )
-
- # Encode neighbors. (Note: 'key_value_states' re-assigned here.)
- key_value_states = self.encoder(
- hidden_states=key_value_states,
- attention_mask=attention_mask,
- context=chunked_output,
- context_mask=None,
- inference_params=inference_params,
- ) # [ r, k*bs*l, d ]
- key_value_states = key_value_states.reshape(
- self.retro_retrieved_length * self.retro_num_neighbors, bs * l, d
- ) # [ r*k, bs*l, d ]
-
- # Attend starting at last token of first chunk.
- pad = (ns - 1) % self.retro_chunk_length
- attending_chunks = hidden_states[pad:]
-
- # Pad attending tokens to sequence length.
- padded_chunks = torch.nn.functional.pad(
- attending_chunks, (0, 0, 0, 0, 0, self.retro_chunk_length - 1), 'constant', 0,
- )
-
- # Permute attending chunks.
- # - padded_chunks: [ l*m, bs, d ]
- # - padded_chunked_output: [ m, bs*l, d ] (matches 'chunked_output' above)
- padded_chunked_output = padded_chunks.reshape(l, self.retro_chunk_length, bs, d).permute(
- 1, 2, 0, 3
- )
- padded_chunked_output = padded_chunked_output.reshape(
- self.retro_chunk_length, bs * l, d
- ).contiguous()
-
- # Attend to encoded neighbors.
- attention_output, attention_bias = self.attn(
- padded_chunked_output, None, key_value_states=key_value_states,
- )
-
- # Return dimensions for bias-dropout step.
- return {
- "ns": ns,
- "bs": bs,
- "d": d,
- "l": l,
- "pad": pad,
- "attention_output": attention_output, # [ m, bs*l, d ]
- "attention_bias": attention_bias, # [ d ]
- "context": key_value_states, # [ r*k, bs*l, d ]
- }
-
-
-class RetroDecoderBiasDropoutAdd(MegatronModule):
-
- """Retro decoder's bias-dropout-add operator.
-
- This operator takes care of reshaping and permuting the output from the
- chunk dimension to the sequence dimension.
-
- Arguments:
- config (RetroConfig): Retro config.
- """
-
- def __init__(
- self, config: RetroConfig,
- ):
- super().__init__(config=config)
- self.retro_chunk_length = config.retro_preprocess.retro_gpt_chunk_length
-
- @classmethod
- def _forward(
- cls,
- x_with_bias: dict,
- residual: Tensor,
- prob: float,
- retro_chunk_length: int,
- bias_dropout_add: Callable,
- ) -> Tensor:
- """Per-chunk bias-dropout-add.
-
- Arguments:
- x_with_bias (dict): Attention output and bias, along with other Retro
- relevant parameters.
-
- residual (Tensor): Transformer layer residual.
-
- prob (float): Dropout probability.
-
- retro_chunk_length (int): Retro chunk length (e.g., 64).
-
- bias_dropout_add (Callable): Bias-dropout-add function.
- """
-
- # Extract input dict.
- ns = x_with_bias["ns"]
- bs = x_with_bias["bs"]
- d = x_with_bias["d"]
- l = x_with_bias["l"]
- pad = x_with_bias["pad"]
- attention_output = x_with_bias["attention_output"] # [ m, bs*l, d ]
- attention_bias = x_with_bias["attention_bias"] # [ d ]
-
- # Re-enable torch grad to enable fused optimization.
- with torch.enable_grad():
-
- # Bias-dropout-add.
- x = bias_dropout_add(
- (
- attention_output,
- None if attention_bias is None else attention_bias.expand_as(attention_output),
- ),
- torch.zeros_like(attention_output),
- prob,
- )
-
- # Permute chunks back to sequence dimension.
- # 1. [ m, bs*l, d ]
- # 2. [ m, bs, l, d ]
- # 3. [ l, m, bs, d ]
- # 4. [ m*l, bs, d ] == [ ns, bs, d ]
- x = (
- x.reshape(retro_chunk_length, bs, l, d)
- .permute(2, 0, 1, 3)
- .reshape(retro_chunk_length * l, bs, d)
- )
-
- # Prepend zeros for non-attending tokens.
- x = torch.nn.functional.pad(x, (0, 0, 0, 0, pad, 0), 'constant', 0,)[
- :ns
- ] # [ ns, bs, d ]
-
- # Add residual. [ ns, bs, d ]
- x = x + residual
-
- # Output. [ ns, bs, d ]
- return x
-
- def forward(self, training: bool, fused: bool) -> Tensor:
- """Retro decoder bias-dropout-add.
-
- Arguments:
- training (bool): If training, then apply dropout.
-
- fused (bool): Fuse bias-dropout-add.
- """
- return partial(
- self._forward,
- retro_chunk_length=self.retro_chunk_length,
- bias_dropout_add=get_bias_dropout_add(training, fused),
- )
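The shape bookkeeping in `RetroDecoderCrossAttention.forward` is the subtle part of this hunk. Here is a small, dependency-light sketch (plain PyTorch; the tensor sizes are arbitrary assumptions) of the pad-then-fold step that turns `[ns, bs, d]` hidden states into the `[m, bs*l, d]` layout used for chunked cross attention:

```python
# Standalone sketch of the chunking used by the Retro decoder cross attention:
# pad the sequence to a multiple of the chunk length, then fold chunks into the
# batch dimension. Sizes here are arbitrary assumptions for illustration.
import math
import torch

ns, bs, d = 100, 2, 8          # sequence length, batch size, hidden size
chunk_length = 64              # 'm' in the notation above
l = math.ceil(ns / chunk_length)

hidden_states = torch.randn(ns, bs, d)                              # [ ns, bs, d ]

first_ns = ns % chunk_length
if first_ns > 0:
    first_chunk, rest_chunk = hidden_states[:first_ns], hidden_states[first_ns:]
    # Zero-pad the partial chunk on the sequence dimension.
    first_chunk = torch.nn.functional.pad(
        first_chunk, (0, 0, 0, 0, 0, chunk_length - first_ns), 'constant', 0)
    chunked_output = torch.cat((first_chunk, rest_chunk), dim=0)    # [ l*m, bs, d ]
else:
    chunked_output = hidden_states                                   # [ l*m, bs, d ]

# Fold the chunk index into the batch dimension: [ l*m, bs, d ] -> [ m, bs*l, d ].
chunked_output = (
    chunked_output.reshape(l, chunk_length, bs, d)
    .permute(1, 2, 0, 3)
    .reshape(chunk_length, bs * l, d)
    .contiguous()
)
assert chunked_output.shape == (chunk_length, bs * l, d)
```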
diff --git a/megatron/core/models/retro/decoder_spec.py b/megatron/core/models/retro/decoder_spec.py
deleted file mode 100644
index d23e4981e004c3de26f5f708e90597177211e1db..0000000000000000000000000000000000000000
--- a/megatron/core/models/retro/decoder_spec.py
+++ /dev/null
@@ -1,152 +0,0 @@
-# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
-
-from megatron.core import parallel_state
-from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
-from megatron.core.models.gpt.gpt_layer_specs import (
- get_gpt_layer_local_spec,
- get_gpt_layer_with_transformer_engine_spec,
-)
-from megatron.core.models.retro.config import RetroConfig
-from megatron.core.models.retro.decoder_attention import (
- RetroDecoderBiasDropoutAdd,
- RetroDecoderCrossAttention,
-)
-from megatron.core.models.retro.encoder_spec import get_retro_encoder_block_spec
-from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
-from megatron.core.transformer import ModuleSpec
-from megatron.core.transformer.attention import CrossAttentionSubmodules
-from megatron.core.transformer.custom_layers.transformer_engine import (
- TEColumnParallelLinear,
- TEDotProductAttention,
- TENorm,
- TERowParallelLinear,
-)
-from megatron.core.transformer.dot_product_attention import DotProductAttention
-from megatron.core.transformer.transformer_block import (
- TransformerBlockSubmodules,
- get_num_layers_to_build,
-)
-
-
-def get_retro_decoder_layer_te_spec(encoder_block_spec: ModuleSpec = None) -> ModuleSpec:
- """Retro decoder TE spec (uses Transformer Engine components).
-
- A Retro decoder layer uses custom attention and bias-dropout-add operators
- to perform chunked-cross attention. Additionally, the first Retro decoder
- layer instantiates an entire encoder transformer block. As such, the decoder
- cross attention module takes an optional encoder block spec, which is only
- provided for the first Retro decoder layer.
-
- Arguments:
- encoder_block_spec (ModuleSpec): Retro encoder block spec, to be provided
- for the first Retro decoder layer.
- """
- spec = get_gpt_layer_with_transformer_engine_spec()
- spec.submodules.pre_cross_attn_layernorm = TENorm
- spec.submodules.cross_attention = ModuleSpec(
- module=RetroDecoderCrossAttention,
- params={"encoder_block_spec": encoder_block_spec,},
- submodules=CrossAttentionSubmodules(
- linear_q=TEColumnParallelLinear,
- linear_kv=TEColumnParallelLinear,
- core_attention=TEDotProductAttention,
- linear_proj=TERowParallelLinear,
- ),
- )
- spec.submodules.cross_attn_bda = ModuleSpec(module=RetroDecoderBiasDropoutAdd)
- return spec
-
-
-def get_retro_decoder_layer_local_spec(encoder_block_spec: ModuleSpec = None) -> ModuleSpec:
- """Retro decoder local spec (uses Megatron-Core components).
-
- A Retro decoder layer uses custom attention and bias-dropout-add operators
- to perform chunked-cross attention. Additionally, the first Retro decoder
- layer instantiates an entire encoder transformer block. As such, the decoder
- cross attention module takes an optional encoder block spec, which is only
- provided for the first Retro decoder layer.
-
- Arguments:
- encoder_block_spec (ModuleSpec): Retro encoder block spec, to be provided
- for the first Retro decoder layer.
- """
- spec = get_gpt_layer_local_spec()
- spec.submodules.pre_cross_attn_layernorm = FusedLayerNorm
- spec.submodules.cross_attention = ModuleSpec(
- module=RetroDecoderCrossAttention,
- params={"encoder_block_spec": encoder_block_spec,},
- submodules=CrossAttentionSubmodules(
- linear_q=ColumnParallelLinear,
- linear_kv=ColumnParallelLinear,
- core_attention=DotProductAttention,
- linear_proj=RowParallelLinear,
- ),
- )
- spec.submodules.cross_attn_bda = ModuleSpec(module=RetroDecoderBiasDropoutAdd)
- return spec
-
-
-def get_retro_decoder_block_spec(
- config: RetroConfig, use_transformer_engine: bool
-) -> TransformerBlockSubmodules:
-
- """Retro decoder block spec.
-
- Retro decoder block implementation details:
- - The retro decoder block consists of interleaved GPT layers and customized
- Retro decoder layers.
- - The Retro decoder layers are spaced three layers apart, and start on layer
- 6 or 9 (depending on the total number of layers).
- - The first decoder layer instantiates an encoder block, and it therefore
- passes in an encoder_block_spec.
-
-
- Arguments:
- config (RetroConfig): Retro config.
-
- use_transformer_engine (bool): If True, use Transformer Engine (instead
- of local modules.
- of local modules).
-
- # Num layers.
- assert (
- parallel_state.get_pipeline_model_parallel_world_size() == 1
- ), "retro does not currently support pipeline parallelism."
- assert (
- parallel_state.get_virtual_pipeline_model_parallel_world_size() is None
- ), "retro does not currently support virtual pipeline parallelism."
- num_layers = get_num_layers_to_build(config)
-
- # Retro layer numbers.
- retro_layer_start = 6 if num_layers <= 15 else 9
- retro_layer_numbers = list(range(retro_layer_start, num_layers + 1, 3))
-
- # Layer specs.
- gpt_layer_spec = (
- get_gpt_layer_with_transformer_engine_spec()
- if use_transformer_engine
- else get_gpt_layer_local_spec()
- )
- get_retro_decoder_layer_spec = (
- get_retro_decoder_layer_te_spec
- if use_transformer_engine
- else get_retro_decoder_layer_local_spec
- )
- retro_layer_spec = get_retro_decoder_layer_spec()
- retro_layer_spec_with_retriever = get_retro_decoder_layer_spec(
- get_retro_encoder_block_spec(config, use_transformer_engine)
- )
-
- layer_specs = []
- for layer_number in range(1, num_layers + 1):
- if layer_number == retro_layer_numbers[0]:
- layer_specs.append(retro_layer_spec_with_retriever)
- elif layer_number in retro_layer_numbers:
- layer_specs.append(retro_layer_spec)
- else:
- layer_specs.append(gpt_layer_spec)
-
- # Block spec.
- block_spec = TransformerBlockSubmodules(layer_specs=layer_specs)
-
- return block_spec
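The layer-placement rule in `get_retro_decoder_block_spec` (Retro layers start at layer 6, or 9 for models deeper than 15 layers, and repeat every third layer, with the first one also carrying the retriever/encoder block) can be checked with a few lines of plain Python. The layer counts below are examples only:

```python
# Pure-Python sketch of the Retro decoder layer placement implemented in
# get_retro_decoder_block_spec above. Everything else in a block is a plain GPT layer.
def retro_layer_numbers(num_layers):
    retro_layer_start = 6 if num_layers <= 15 else 9
    return list(range(retro_layer_start, num_layers + 1, 3))


def layer_kinds(num_layers):
    retro_layers = retro_layer_numbers(num_layers)
    kinds = []
    for layer_number in range(1, num_layers + 1):
        if layer_number == retro_layers[0]:
            kinds.append("retro+retriever")   # first Retro layer also holds the encoder block
        elif layer_number in retro_layers:
            kinds.append("retro")
        else:
            kinds.append("gpt")
    return kinds


print(retro_layer_numbers(12))   # [6, 9, 12]
print(retro_layer_numbers(24))   # [9, 12, 15, 18, 21, 24]
print(layer_kinds(12))
```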
diff --git a/megatron/core/models/retro/encoder_attention.py b/megatron/core/models/retro/encoder_attention.py
deleted file mode 100644
index 5840e3e3017150956b0997bb6d697e1eda85d5ff..0000000000000000000000000000000000000000
--- a/megatron/core/models/retro/encoder_attention.py
+++ /dev/null
@@ -1,223 +0,0 @@
-# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
-
-"""Retro's cross attention modules for the encoder block."""
-
-from functools import partial
-from typing import Callable, Optional, Tuple, Type
-
-import torch
-from torch import Tensor
-
-from megatron.core import InferenceParams
-from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
-from megatron.core.models.retro.base_attention import BaseRetroCrossAttention
-from megatron.core.models.retro.config import RetroConfig
-from megatron.core.transformer.module import MegatronModule
-
-
-class RetroEncoderCrossAttention(BaseRetroCrossAttention):
-
- """Retro encoder's cross attention operator.
-
- See this paper for more details: https://arxiv.org/abs/2112.04426.
- Neighboring chunks are retrieved from the chunk database, encoded, and
- used by the decoder layers for chunked cross attention.
-
- Arguments:
- config (RetroConfig): Retro config.
-
- submodules (CrossAttentionSubmodules): Cross attention submodules.
-
- layer_number (int): Layer number within transformer block.
-
- attn_mask_type (AttnMaskType): Mask type ('causal' or 'padding').
- """
-
- def forward(
- self,
- hidden_states: Tensor,
- attention_mask: Tensor,
- key_value_states: Tensor = None,
- inference_params: InferenceParams = None,
- # rotary_pos_emb: Tensor = None, # unsupported for retro.
- ) -> Tensor:
- """Cross attention for Retro encoder.
-
- Notation:
- ns : Sequence length.
- bs : Batch size.
- d : Hidden size.
- l : Number of chunks per sample (i.e., seq_length/chunk_length).
- k : Number of neighbors.
- r : Number of retrieved tokens (neighbors + continuation).
-
- Arguments:
- hidden_states (Tensor): Transformer layer hidden states.
-
- attention_mask (Tensor): Attention mask.
-
- key_value_states (Tensor): Neighbor embeddings.
-
- inference_params (InferenceParams): Inference params.
- """
-
- # Input shape. [ r, bs*l*k, d ]
- ns, bs, d = hidden_states.shape
-
- # Reshape sequence into neighboring chunks.
- # - hidden_states: [ r, bs*l*k, d ]
- # - chunked_outputs: [ r, bs*l, k, d ]
- chunked_outputs = hidden_states.reshape(
- self.retro_retrieved_length, -1, self.retro_num_neighbors, d
- )
-
- # Per-chunk attention.
- attention_output_tuples = []
- for k in range(self.retro_num_neighbors):
-
- # Attend to current neighboring chunks.
- # - chunked_output: [ r, bs*l, d ]
- # - key_value_states: [ m, bs*l, d ]
- # - attention_output: [ r, bs*l, d ]
- # - attention_bias: [ d ]
- chunked_output = chunked_outputs[:, :, k].contiguous()
- attention_output, attention_bias = self.attn(
- hidden_states=chunked_output, # Q (neighbor embedding)
- attention_mask=None,
- key_value_states=key_value_states, # K, V (hidden act)
- )
-
- # Residual connection. [ r, bs*l, d ]
- residual = chunked_output
-
- # Collect tensors.
- attention_output_tuples.append((attention_output, attention_bias, residual,))
-
- # Output. (List[Tuple[( [ r, bs*l, d ], [ d ] )]])
- return attention_output_tuples
-
-
-class RetroEncoderBiasDropoutAdd(MegatronModule):
-
- """Retro encoder's bias-dropout-add operator.
-
- This operator applies bias-dropout-add individually on each neighboring
- chunk that is retrieved from the chunk database.
-
- Arguments:
- config (RetroConfig): Retro config.
- """
-
- def __init__(
- self, config: RetroConfig,
- ):
- super().__init__(config=config)
- self.retro_num_neighbors = config.retro_num_neighbors
-
- @classmethod
- def _forward(
- cls,
- x_with_bias: Tuple[Tensor, Optional[Tensor]],
- residual: Tensor,
- prob: float,
- retro_num_neighbors: int,
- bias_dropout_add: Callable,
- ) -> Tensor:
- """Per-chunk bias-dropout-add.
-
- Arguments:
- x_with_bias (List[Tuple]): Per-neighbor attention output, bias, and residual tuples.
-
- residual (Tensor): Transformer layer residual.
-
- prob (float): Dropout probability.
-
- retro_num_neighbors (int): Number of retrieved neighbor chunks (e.g., 2).
-
- bias_dropout_add (Callable): Bias-dropout-add function.
- """
-
- # Re-enable torch grad to enable fused optimization.
- with torch.enable_grad():
-
- # Per-neighbor bias-dropout-add.
- # - attention_output: [ r, bs*l, d ]
- # - attention_bias: [ d ]
- # - residual: [ r, bs*l, d ]
- # - output: [ r, bs*l, d ]
- outputs = [
- bias_dropout_add(
- (
- attention_output,
- None if attention_bias is None else attention_bias.expand_as(residual),
- ),
- residual,
- prob,
- )
- for attention_output, attention_bias, residual in x_with_bias
- ]
-
- # Concatenate outputs (to shape [r, k*bs*l, d]; see notation above).
- r, _, d = outputs[0].shape
- output = torch.stack(outputs, dim=1).reshape(r, -1, d)
-
- # Output. [ r, k*bs*l, d ]
- return output
-
- def forward(self, training: bool, fused: bool) -> Tensor:
- """Retro encoder bias-dropout-add.
-
- Arguments:
- training (bool): If training, then apply dropout.
-
- fused (bool): Fuse bias-dropout-add.
- """
- return partial(
- self._forward,
- retro_num_neighbors=self.retro_num_neighbors,
- bias_dropout_add=get_bias_dropout_add(training, fused),
- )
-
-
-class RetroEncoderLayerNorm(MegatronModule):
-
- """Retro encoder's layernorm operator.
-
- This operator applies layernorm individually on each neighboring chunk that
- is retrieved from the chunk database, and then concatenates the chunks into
- a single tensor.
-
- Arguments:
- config (RetroConfig): Retro config.
- """
-
- def __init__(
- self, config: RetroConfig, submodules: Type, **kwargs,
- ):
- super().__init__(config=config)
- norm_class = submodules
- self.norm = norm_class(config=config, **kwargs)
- self.retro_num_neighbors = config.retro_num_neighbors
-
- def forward(self, input: Tensor) -> Tensor:
- """Per-chunk layer norm.
-
- Arguments:
- input (Tensor): Input chunks, concatenated into a single tensor.
- """
-
- # Input shape: [ r, k*bs*l, d ]. (see notation above in attention module)
-
- # Split input into 'num_neighbors' tensors.
- chunk_size = input.shape[1] // self.retro_num_neighbors
- inputs = torch.split(input, chunk_size, dim=1)
-
- # Norm.
- outputs = [self.norm(inp.contiguous()) for inp in inputs]
-
- # Concatenate layer norms (to shape [r, k*bs*l, d]; see notation above).
- r, _, d = inputs[0].shape
- output = torch.stack(outputs, dim=1).reshape(r, -1, d)
-
- # Output. [ r, k*bs*l, d ]
- return output
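The encoder-side operators above all share the same bookkeeping: split the `[r, k*bs*l, d]` tensor into `k` per-neighbor tensors along dim 1, apply the operation per neighbor, then stack the results back. A self-contained sketch mirroring the `RetroEncoderLayerNorm` path (the sizes are assumptions):

```python
# Standalone sketch of the per-neighbor bookkeeping used by the Retro encoder
# operators above: split [ r, k*bs*l, d ] into k per-neighbor tensors, process
# them individually, then stack back to [ r, k*bs*l, d ]. Sizes are assumptions.
import torch

r, k, bs_l, d = 128, 2, 8, 16            # retrieved tokens, neighbors, bs*l, hidden
x = torch.randn(r, k * bs_l, d)          # [ r, k*bs*l, d ]

# Split along dim 1 into 'k' neighbor tensors of shape [ r, bs*l, d ].
chunk_size = x.shape[1] // k
neighbors = torch.split(x, chunk_size, dim=1)

# Apply some per-neighbor op (a LayerNorm here, standing in for attention/BDA/norm).
norm = torch.nn.LayerNorm(d)
outputs = [norm(n.contiguous()) for n in neighbors]

# Stack and flatten back to the [ r, k*bs*l, d ] layout.
out = torch.stack(outputs, dim=1).reshape(r, -1, d)
assert out.shape == x.shape
```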
diff --git a/megatron/core/models/retro/encoder_spec.py b/megatron/core/models/retro/encoder_spec.py
deleted file mode 100644
index 63efadedd884f075e35c52b00f216013ede2ab21..0000000000000000000000000000000000000000
--- a/megatron/core/models/retro/encoder_spec.py
+++ /dev/null
@@ -1,141 +0,0 @@
-# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
-
-from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
-from megatron.core.models.gpt.gpt_layer_specs import (
- get_gpt_layer_local_spec,
- get_gpt_layer_with_transformer_engine_spec,
-)
-from megatron.core.models.retro.config import RetroConfig
-from megatron.core.models.retro.encoder_attention import (
- RetroEncoderBiasDropoutAdd,
- RetroEncoderCrossAttention,
- RetroEncoderLayerNorm,
-)
-from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
-from megatron.core.transformer import ModuleSpec
-from megatron.core.transformer.attention import CrossAttentionSubmodules
-from megatron.core.transformer.custom_layers.transformer_engine import (
- TEColumnParallelLinear,
- TEDotProductAttention,
- TENorm,
- TERowParallelLinear,
-)
-from megatron.core.transformer.dot_product_attention import DotProductAttention
-from megatron.core.transformer.enums import AttnMaskType
-from megatron.core.transformer.mlp import MLP, MLPSubmodules
-from megatron.core.transformer.transformer_block import TransformerBlockSubmodules
-
-
-def get_retro_encoder_layer_te_spec() -> ModuleSpec:
- """Retro encoder TE spec (uses Transformer Engine components).
-
- A Retro encoder layer uses custom attention, bias-dropout-add, and layernorm
- operators to encode neighboring chunks that are retrieved from the chunk
- database. Each operator is responsible for iterating the retrieved chunks
- and processing them individually.
- """
- spec = get_gpt_layer_with_transformer_engine_spec()
- spec.submodules.pre_cross_attn_layernorm = TENorm
- spec.submodules.cross_attention = ModuleSpec(
- module=RetroEncoderCrossAttention,
- params={"attn_mask_type": AttnMaskType.padding,},
- submodules=CrossAttentionSubmodules(
- linear_q=TEColumnParallelLinear,
- linear_kv=TEColumnParallelLinear,
- core_attention=TEDotProductAttention,
- linear_proj=TERowParallelLinear,
- ),
- )
- spec.submodules.cross_attn_bda = ModuleSpec(module=RetroEncoderBiasDropoutAdd)
- spec.submodules.pre_mlp_layernorm = ModuleSpec(module=RetroEncoderLayerNorm, submodules=TENorm,)
- spec.submodules.mlp = ModuleSpec(
- module=MLP,
- submodules=MLPSubmodules(
- linear_fc1=TEColumnParallelLinear, linear_fc2=TERowParallelLinear,
- ),
- )
- return spec
-
-
-def get_retro_encoder_layer_local_spec() -> ModuleSpec:
- """Retro encoder local spec (uses Megatron-Core components).
-
- A Retro encoder layer uses custom attention, bias-dropout-add, and layernorm
- operators to encode neighboring chunks that are retrieved from the chunk
- database. Each operator is responsible for iterating the retrieved chunks
- and processing them individually.
- """
- spec = get_gpt_layer_local_spec()
- spec.submodules.pre_cross_attn_layernorm = FusedLayerNorm
- spec.submodules.cross_attention = ModuleSpec(
- module=RetroEncoderCrossAttention,
- params={"attn_mask_type": AttnMaskType.padding,},
- submodules=CrossAttentionSubmodules(
- linear_q=ColumnParallelLinear,
- linear_kv=ColumnParallelLinear,
- core_attention=DotProductAttention,
- linear_proj=RowParallelLinear,
- ),
- )
- spec.submodules.cross_attn_bda = ModuleSpec(module=RetroEncoderBiasDropoutAdd)
- spec.submodules.pre_mlp_layernorm = ModuleSpec(
- module=RetroEncoderLayerNorm, submodules=FusedLayerNorm,
- )
- spec.submodules.mlp = ModuleSpec(
- module=MLP,
- submodules=MLPSubmodules(linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear,),
- )
- return spec
-
-
-def get_retro_encoder_block_spec(
- config: RetroConfig, use_transformer_engine: bool
-) -> TransformerBlockSubmodules:
-
- """Retro encoder block spec.
-
- The retro encoder block consists of one customized Retro encoder layer
- (layer 1), and all of the following layers are standard GPT layers.
-
- Arguments:
- config (RetroConfig): Retro config.
-
- use_transformer_engine (bool): If True, use Transformer Engine (instead
- of local modules).
- """
-
- # Num layers.
- num_layers = config.retro_encoder_num_layers
- retro_layer_numbers = [1]
-
- # Layer specs.
- gpt_layer_spec = (
- get_gpt_layer_with_transformer_engine_spec()
- if use_transformer_engine
- else get_gpt_layer_local_spec()
- )
- get_retro_encoder_layer_spec = (
- get_retro_encoder_layer_te_spec
- if use_transformer_engine
- else get_retro_encoder_layer_local_spec
- )
- retro_layer_spec = get_retro_encoder_layer_spec()
- for spec in (gpt_layer_spec, retro_layer_spec):
- spec.params["hidden_dropout"] = config.retro_encoder_hidden_dropout
- spec.submodules.self_attention.params["attn_mask_type"] = AttnMaskType.padding
- spec.submodules.self_attention.submodules.core_attention = ModuleSpec(
- module=TEDotProductAttention if use_transformer_engine else DotProductAttention,
- params={"attention_dropout": config.retro_encoder_attention_dropout,},
- )
-
- layer_specs = []
- for layer_number in range(1, num_layers + 1):
- if layer_number in retro_layer_numbers:
- layer_specs.append(retro_layer_spec)
- else:
- layer_specs.append(gpt_layer_spec)
-
- # Block spec.
- block_spec = TransformerBlockSubmodules(layer_specs=layer_specs)
-
- return block_spec
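Compared to the decoder block, `get_retro_encoder_block_spec` is simple: layer 1 is the Retro encoder layer, the remaining `retro_encoder_num_layers - 1` layers are plain GPT layers, and the encoder-specific dropout values are pushed onto both layer kinds. A toy sketch of that selection logic, with plain dicts standing in for `ModuleSpec` objects (an assumption made purely for illustration):

```python
# Toy sketch of get_retro_encoder_block_spec's selection and override logic.
# Dicts stand in for ModuleSpec objects; this is illustrative only.
def build_encoder_layer_specs(num_layers, hidden_dropout, attention_dropout):
    gpt_layer_spec = {"kind": "gpt", "params": {}}
    retro_layer_spec = {"kind": "retro_encoder", "params": {}}
    for spec in (gpt_layer_spec, retro_layer_spec):
        spec["params"]["hidden_dropout"] = hidden_dropout
        spec["params"]["attention_dropout"] = attention_dropout

    retro_layer_numbers = [1]
    return [
        retro_layer_spec if layer_number in retro_layer_numbers else gpt_layer_spec
        for layer_number in range(1, num_layers + 1)
    ]


specs = build_encoder_layer_specs(num_layers=2, hidden_dropout=0.1, attention_dropout=0.1)
print([s["kind"] for s in specs])   # ['retro_encoder', 'gpt']
```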
diff --git a/megatron/core/models/retro/model.py b/megatron/core/models/retro/model.py
deleted file mode 100644
index d47c08fb52788b23ba4c3301bd24401d42e6d2a6..0000000000000000000000000000000000000000
--- a/megatron/core/models/retro/model.py
+++ /dev/null
@@ -1,89 +0,0 @@
-# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
-
-"""Retro Model."""
-
-from torch import Tensor
-
-from megatron.core import InferenceParams
-from megatron.core.models.gpt import GPTModel
-
-
-class RetroModel(GPTModel):
-
- """Retro Model.
-
- A Retro model mostly re-uses the GPTModel interface, with the only difference
- being the embedding of the 'context' this is used by Retro for processing
- being the embedding of the 'context' that is used by Retro for processing
- Block.
- """
-
- def forward(
- self,
- input_ids: Tensor,
- position_ids: Tensor,
- attention_mask: Tensor,
- context_input_ids: Tensor = None,
- context_position_ids: Tensor = None,
- context_mask: Tensor = None,
- decoder_input: Tensor = None,
- labels: Tensor = None,
- inference_params: InferenceParams = None,
- ) -> Tensor:
- """RetroModel forward method.
-
- Forward input tokens & mask, along with neighbor tokens & mask, through
- the Retro model.
-
- Arguments:
- input_ids (Tensor): Input token IDs.
-
- position_ids (Tensor): Input position IDs.
-
- attention_mask (Tensor): Input attention mask.
-
- context_input_ids (Tensor): Context (i.e., neighbor) token IDs.
-
- context_position_ids (Tensor): Context (i.e., neighbor) position IDs.
-
- context_mask (Tensor): Context (i.e., neighbor) attention mask.
-
- decoder_input (Tensor): When using pipeline parallelism, input_ids and
- position_ids will only be used on the first stage, and for all other
- stages decoder_input will be provided via communication from the
- previous stage.
-
- labels (Tensor): The labels of dimension [batch size, seq length].
-
- inference_params (InferenceParams): Parameters for inference.
- """
-
- # Argument shapes:
- # Notation:
- # ns : Sequence length.
- # bs : Batch size.
- # d : Hidden size.
- # l : Number of chunks per sample (i.e., seq_length/chunk_length).
- # k : Number of neighbors.
- # r : Number of retrieved tokens (neighbors + continuation).
- # - input_ids: [ bs, ns ]
- # - context_ids: [ k*bs*l, r ]
- # - context: [ r, k*bs*l, d ]
- # - output: [ ns, bs, d ]
-
- # Context embedding (e.g., for Retro neighbor tokens).
- if context_input_ids is not None:
- context = self.embedding(context_input_ids, context_position_ids)
- else:
- context = None
-
- # Call GPTModel.forward, and pass in embedded context.
- return super().forward(
- input_ids=input_ids,
- position_ids=position_ids,
- attention_mask=attention_mask,
- decoder_input=decoder_input,
- labels=labels,
- inference_params=inference_params,
- extra_block_kwargs={"context": context, "context_mask": context_mask,},
- )
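The only Retro-specific work in `RetroModel.forward` is embedding the neighbor tokens and threading the result through `extra_block_kwargs`. The following shape sketch uses a plain `nn.Embedding` as a stand-in for Megatron's embedding module (an assumption) to show how `[k*bs*l, r]` token IDs become the `[r, k*bs*l, d]` context tensor:

```python
# Shape sketch for RetroModel.forward's context path. A plain nn.Embedding
# stands in for Megatron's embedding module (an assumption for illustration).
import torch
import torch.nn as nn

k, bs, l, r, d, vocab = 2, 2, 4, 128, 16, 1000
context_input_ids = torch.randint(0, vocab, (k * bs * l, r))     # [ k*bs*l, r ]
context_position_ids = torch.arange(r).expand(k * bs * l, r)     # [ k*bs*l, r ]

token_emb = nn.Embedding(vocab, d)
pos_emb = nn.Embedding(r, d)

# Megatron embeddings return [ seq, batch, hidden ], so transpose after lookup.
context = (token_emb(context_input_ids) + pos_emb(context_position_ids)).transpose(0, 1)
assert context.shape == (r, k * bs * l, d)                        # [ r, k*bs*l, d ]

# This 'context' (plus a context mask) is what the original forward passes to
# the decoder block via extra_block_kwargs.
```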
diff --git a/megatron/core/package_info.py b/megatron/core/package_info.py
deleted file mode 100644
index 55c49b1785b0cdfa94854ca73431d963150b8c7f..0000000000000000000000000000000000000000
--- a/megatron/core/package_info.py
+++ /dev/null
@@ -1,29 +0,0 @@
-# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
-
-
-MAJOR = 0
-MINOR = 4
-PATCH = 0
-PRE_RELEASE = 'rc0'
-
-# Use the following formatting: (major, minor, patch, pre-release)
-VERSION = (MAJOR, MINOR, PATCH, PRE_RELEASE)
-
-__shortversion__ = '.'.join(map(str, VERSION[:3]))
-__version__ = '.'.join(map(str, VERSION[:3])) + ''.join(VERSION[3:])
-
-__package_name__ = 'megatron_core'
-__contact_names__ = 'NVIDIA'
-__contact_emails__ = 'nemo-toolkit@nvidia.com' # use NeMo Email
-__homepage__ = (
- 'https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/' # use NeMo homepage
-)
-__repository_url__ = 'https://github.com/NVIDIA/Megatron-LM/megatron/core'
-__download_url__ = 'https://github.com/NVIDIA/Megatron-LM/releases'
-__description__ = (
- 'Megatron Core - a library for efficient and scalable training of transformer based models'
-)
-__license__ = 'BSD-3'
-__keywords__ = (
- 'deep learning, machine learning, gpu, NLP, NLU, language, transformer, nvidia, pytorch, torch'
-)
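For reference, the version strings assembled in the deleted `package_info.py` come out as follows (a trivially verifiable illustration):

```python
# How package_info.py formats its version strings:
# (0, 4, 0, 'rc0') -> short version "0.4.0" and full version "0.4.0rc0".
VERSION = (0, 4, 0, 'rc0')

shortversion = '.'.join(map(str, VERSION[:3]))
fullversion = '.'.join(map(str, VERSION[:3])) + ''.join(VERSION[3:])

assert shortversion == '0.4.0'
assert fullversion == '0.4.0rc0'
```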
diff --git a/megatron/core/parallel_state.py b/megatron/core/parallel_state.py
deleted file mode 100644
index 5652b208468537cf2b0e97283d8c719995e1112d..0000000000000000000000000000000000000000
--- a/megatron/core/parallel_state.py
+++ /dev/null
@@ -1,980 +0,0 @@
-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-
-"""Model and data parallel groups."""
-
-import os
-from typing import Optional
-
-import torch
-
-from .utils import GlobalMemoryBuffer
-
-# Intra-layer model parallel group that the current rank belongs to.
-_TENSOR_MODEL_PARALLEL_GROUP = None
-# Inter-layer model parallel group that the current rank belongs to.
-_PIPELINE_MODEL_PARALLEL_GROUP = None
-# Model parallel group (both intra- and pipeline) that the current rank belongs to.
-_MODEL_PARALLEL_GROUP = None
-# Embedding group.
-_EMBEDDING_GROUP = None
-# Position embedding group.
-_POSITION_EMBEDDING_GROUP = None
-# Data parallel group that the current rank belongs to.
-_DATA_PARALLEL_GROUP = None
-_DATA_PARALLEL_GROUP_GLOO = None
-# tensor model parallel group and data parallel group combined
-# used for fp8 and moe training
-_TENSOR_AND_DATA_PARALLEL_GROUP = None
-# Expert parallel group that the current rank belongs to.
-_TENSOR_AND_EXPERT_PARALLEL_GROUP = None
-_DATA_MODULO_EXPERT_PARALLEL_GROUP = None
-
-
-_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
-_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
-_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = None
-
-# These values enable us to change the mpu sizes on the fly.
-_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
-_MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
-_MPU_TENSOR_MODEL_PARALLEL_RANK = None
-_MPU_PIPELINE_MODEL_PARALLEL_RANK = None
-
-# A list of ranks that have a copy of the embedding.
-_EMBEDDING_GLOBAL_RANKS = None
-
-# A list of ranks that have a copy of the position embedding.
-_POSITION_EMBEDDING_GLOBAL_RANKS = None
-
-# A list of global ranks for each pipeline group to ease calculation of the source
-# rank when broadcasting from the first or last pipeline stage.
-_PIPELINE_GLOBAL_RANKS = None
-
-# A list of global ranks for each data parallel group to ease calculation of the source
-# rank when broadcasting weights from src to all other data parallel ranks
-_DATA_PARALLEL_GLOBAL_RANKS = None
-
-# Context parallel group that the current rank belongs to
-_CONTEXT_PARALLEL_GROUP = None
-# A list of global ranks for each context parallel group to ease calculation of the
-# destination rank when exchanging KV/dKV between context parallel_ranks
-_CONTEXT_PARALLEL_GLOBAL_RANKS = None
-
-# Data parallel group information with context parallel combined.
-_DATA_PARALLEL_GROUP_WITH_CP = None
-_DATA_PARALLEL_GROUP_WITH_CP_GLOO = None
-_DATA_PARALLEL_GLOBAL_RANKS_WITH_CP = None
-
-# combined parallel group of TP, DP, and CP used for fp8
-_TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP = None
-
-# Memory buffers to avoid dynamic memory allocation
-_GLOBAL_MEMORY_BUFFER = None
-
-
-def get_nccl_options(pg_name, nccl_comm_cfgs):
- """Set the NCCL process group options.
-
- Arguments:
- pg_name (str): process group name
- nccl_comm_cfgs (dict): nccl communicator configurations
-
- When an option (e.g., max_ctas) is not found in the config, use the NCCL default setting.
- """
- if pg_name in nccl_comm_cfgs:
- nccl_options = torch.distributed.ProcessGroupNCCL.Options()
- nccl_options.config.cga_cluster_size = nccl_comm_cfgs[pg_name].get('cga_cluster_size', 4)
- nccl_options.config.max_ctas = nccl_comm_cfgs[pg_name].get('max_ctas', 32)
- nccl_options.config.min_ctas = nccl_comm_cfgs[pg_name].get('min_ctas', 1)
- return nccl_options
- else:
- return None
-
-
-def initialize_model_parallel(
- tensor_model_parallel_size: int = 1,
- pipeline_model_parallel_size: int = 1,
- virtual_pipeline_model_parallel_size: Optional[int] = None,
- pipeline_model_parallel_split_rank: Optional[int] = None,
- use_sharp: bool = False,
- context_parallel_size: int = 1,
- expert_model_parallel_size: int = 1,
- nccl_communicator_config_path: Optional[str] = None,
-) -> None:
- """Initialize model data parallel groups.
-
- Arguments:
- tensor_model_parallel_size (int, default = 1):
- The number of GPUs to split individual tensors across.
-
- pipeline_model_parallel_size (int, default = 1):
- The number of tensor parallel GPU groups to split the
- Transformer layers across. For example, if
- tensor_model_parallel_size is 4 and
- pipeline_model_parallel_size is 2, the model will be split
- into 2 groups of 4 GPUs.
-
- virtual_pipeline_model_parallel_size (int, optional):
- The number of stages that each pipeline group will have,
- interleaving as necessary. If None, no interleaving is
- performed. For example, if tensor_model_parallel_size is 1,
- pipeline_model_parallel_size is 4,
- virtual_pipeline_model_parallel_size is 2, and there are
- 16 transformer layers in the model, the model will be
- split into 8 stages with two layers each and each GPU
- would get 2 stages as such (layer number starting with 1):
-
- GPU 0: [1, 2] [9, 10]
- GPU 1: [3, 4] [11, 12]
- GPU 2: [5, 6] [13, 14]
- GPU 3: [7, 8] [15, 16]
-
- pipeline_model_parallel_split_rank (int, optional):
- For models with both an encoder and decoder, the rank in
- pipeline to switch between encoder and decoder (i.e. the
- first rank of the decoder). This allows the user to set
- the pipeline parallel size of the encoder and decoder
- independently. For example, if
- pipeline_model_parallel_size is 8 and
- pipeline_model_parallel_split_rank is 3, then ranks 0-2
- will be the encoder and ranks 3-7 will be the decoder.
-
- use_sharp (bool, default = False):
- Set the use of SHARP for the collective communications of
- data-parallel process groups. When `True`, run barrier
- within each data-parallel process group, which specifies
- the SHARP application target groups.
-
- context_parallel_size (int, default = 1):
- The number of tensor parallel GPU groups to split the
- network input sequence length across. The attention computation
- requires tokens of the full sequence length, so GPUs in a
- context parallel group need to communicate with each other
- to exchange information about other sequence chunks.
- Each GPU and its counterparts in other tensor parallel
- groups compose a context parallel group.
-
- For example, assume we have 8 GPUs, if tensor model parallel
- size is 4 and context parallel size is 2, the network input
- will be split into two sequence chunks, which are processed
- by 2 different groups of 4 GPUs. One chunk is processed by
- GPU0-3, the other chunk is processed by GPU4-7. Four groups
- are built to do context parallel communications: [GPU0, GPU4],
- [GPU1, GPU5], [GPU2, GPU6], and [GPU3, GPU7].
-
- Context parallelism partitions sequence length, so it has no
- impact on weights, which means weights are duplicated among
- GPUs in a context parallel group. Hence, weight gradients
- all-reduce is required in backward. For simplicity, we piggyback
- GPUs of context parallelism on data parallel group for
- weight gradient all-reduce.
-
- nccl_communicator_config_path (str, default = None):
- Path to the yaml file of NCCL communicator configurations.
- `min_ctas`, `max_ctas`, and `cga_cluster_size` can be set
- for each communicator.
-
- Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we
- use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
- the model pipeline. The present function will
- create 8 tensor model-parallel groups, 4 pipeline model-parallel groups
- and 8 data-parallel groups as:
- 8 data_parallel groups:
- [g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15]
- 8 tensor model-parallel groups:
- [g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15]
- 4 pipeline model-parallel groups:
- [g0, g4, g8, g12], [g1, g5, g9, g13], [g2, g6, g10, g14], [g3, g7, g11, g15]
- Note that for efficiency, the caller should make sure adjacent ranks
- are on the same DGX box. For example if we are using 2 DGX-1 boxes
- with a total of 16 GPUs, rank 0 to 7 belong to the first box and
- ranks 8 to 15 belong to the second box.
-
- """
- # Get world size and rank. Ensure some consistencies.
- assert torch.distributed.is_initialized()
- world_size: int = torch.distributed.get_world_size()
-
- if (
- world_size
- % (tensor_model_parallel_size * pipeline_model_parallel_size * context_parallel_size)
- != 0
- ):
- raise RuntimeError(
- f"world_size ({world_size}) is not divisible by tensor_model_parallel_size "
- f"({tensor_model_parallel_size}) x pipeline_model_parallel_size ({pipeline_model_parallel_size}) "
- f"x context_parallel_size ({context_parallel_size})"
- )
-
- data_parallel_size: int = world_size // (
- tensor_model_parallel_size * pipeline_model_parallel_size * context_parallel_size
- )
-
- if data_parallel_size % expert_model_parallel_size != 0:
- raise RuntimeError(
- f"data_parallel_size ({data_parallel_size}) is not divisible by expert_model_parallel_size ({expert_model_parallel_size})"
- )
-
- if expert_model_parallel_size > 1 and context_parallel_size > 1:
- raise RuntimeError(
- "combination of expert model parallelism and context parallelism is not supported"
- )
-
- num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size
- num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size
-
- if virtual_pipeline_model_parallel_size is not None:
- if not pipeline_model_parallel_size > 2:
- raise RuntimeError(
- "pipeline-model-parallel size should be greater than 2 with interleaved schedule"
- )
- global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
- global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
- _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0
- _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = virtual_pipeline_model_parallel_size
-
- if pipeline_model_parallel_split_rank is not None:
- global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
- _PIPELINE_MODEL_PARALLEL_SPLIT_RANK = pipeline_model_parallel_split_rank
-
- rank = torch.distributed.get_rank()
-
- nccl_comm_cfgs = {}
- if nccl_communicator_config_path is not None:
- try:
- import yaml
- except ImportError:
- raise RuntimeError(
- "Cannot import `yaml`. Setting custom nccl communicator configs "
- "requires the yaml package."
- )
-
- with open(nccl_communicator_config_path, "r") as stream:
- nccl_comm_cfgs = yaml.safe_load(stream)
-
- # Build the data-parallel groups.
- global _DATA_PARALLEL_GROUP
- global _DATA_PARALLEL_GROUP_GLOO
- global _DATA_PARALLEL_GLOBAL_RANKS
- global _DATA_PARALLEL_GROUP_WITH_CP
- global _DATA_PARALLEL_GROUP_WITH_CP_GLOO
- global _DATA_PARALLEL_GLOBAL_RANKS_WITH_CP
- assert _DATA_PARALLEL_GROUP is None, 'data parallel group is already initialized'
- all_data_parallel_group_ranks_with_cp = []
- for i in range(pipeline_model_parallel_size):
- start_rank = i * num_pipeline_model_parallel_groups
- end_rank = (i + 1) * num_pipeline_model_parallel_groups
- for j in range(context_parallel_size * tensor_model_parallel_size):
- ranks = range(
- start_rank + j, end_rank, context_parallel_size * tensor_model_parallel_size
- )
- group = torch.distributed.new_group(
- ranks, pg_options=get_nccl_options('dp', nccl_comm_cfgs)
- )
- group_gloo = torch.distributed.new_group(ranks, backend="gloo")
- if rank in ranks:
- _DATA_PARALLEL_GROUP = group
- _DATA_PARALLEL_GROUP_GLOO = group_gloo
- _DATA_PARALLEL_GLOBAL_RANKS = ranks
- for j in range(tensor_model_parallel_size):
- ranks_with_cp = range(start_rank + j, end_rank, tensor_model_parallel_size)
- all_data_parallel_group_ranks_with_cp.append(list(ranks_with_cp))
- group_with_cp = torch.distributed.new_group(
- ranks_with_cp, pg_options=get_nccl_options('dp_cp', nccl_comm_cfgs)
- )
- group_with_cp_gloo = torch.distributed.new_group(ranks_with_cp, backend="gloo")
- if rank in ranks_with_cp:
- _DATA_PARALLEL_GROUP_WITH_CP = group_with_cp
- _DATA_PARALLEL_GROUP_WITH_CP_GLOO = group_with_cp_gloo
- _DATA_PARALLEL_GLOBAL_RANKS_WITH_CP = ranks_with_cp
-
- # Apply SHARP to DP process groups
- if use_sharp:
- if rank == 0:
- print(
- "The number of process groups to use SHARP with depends on the type "
- "of the network switch. Nvidia QM1 switch supports SHARP up to 8 "
- "process groups and QM2 supports up to 256 process groups. We apply "
- "SHARP to the communications of the data-parallel domain. If the "
- "number of data-parallel process groups is larger than the max "
- "process groups that the network switch supports, the communication "
- "will fall back to non-SHARP operators. To enable SHARP, "
- "`#SBATCH_NETWORK=sharp` should be set in the sbatch script."
- )
- torch.distributed.barrier(
- group=get_data_parallel_group(with_context_parallel=context_parallel_size > 1),
- device_ids=[torch.cuda.current_device()],
- )
- # Set `NCCL_SHARP_DISABLE=1` to restrict SHARP application to DP process groups
- os.environ["NCCL_SHARP_DISABLE"] = "1"
-
- # Build the context-parallel groups.
- global _CONTEXT_PARALLEL_GROUP
- global _CONTEXT_PARALLEL_GLOBAL_RANKS
- assert _CONTEXT_PARALLEL_GROUP is None, 'context parallel group is already initialized'
- for i in range(pipeline_model_parallel_size):
- for j in range(data_parallel_size):
- start_rank = (
- i * num_pipeline_model_parallel_groups
- + j * tensor_model_parallel_size * context_parallel_size
- )
- end_rank = (
- i * num_pipeline_model_parallel_groups
- + (j + 1) * tensor_model_parallel_size * context_parallel_size
- )
- for k in range(tensor_model_parallel_size):
- ranks = range(start_rank + k, end_rank, tensor_model_parallel_size)
- group = torch.distributed.new_group(
- ranks, pg_options=get_nccl_options('cp', nccl_comm_cfgs)
- )
- if rank in ranks:
- _CONTEXT_PARALLEL_GROUP = group
- _CONTEXT_PARALLEL_GLOBAL_RANKS = ranks
-
- # Build the model-parallel groups.
- global _MODEL_PARALLEL_GROUP
- assert _MODEL_PARALLEL_GROUP is None, 'model parallel group is already initialized'
- for i in range(data_parallel_size * context_parallel_size):
- ranks = [
- data_parallel_group_ranks_with_cp[i]
- for data_parallel_group_ranks_with_cp in all_data_parallel_group_ranks_with_cp
- ]
- group = torch.distributed.new_group(
- ranks, pg_options=get_nccl_options('mp', nccl_comm_cfgs)
- )
- if rank in ranks:
- _MODEL_PARALLEL_GROUP = group
-
- # Build the tensor model-parallel groups.
- global _TENSOR_MODEL_PARALLEL_GROUP
- assert (
- _TENSOR_MODEL_PARALLEL_GROUP is None
- ), 'tensor model parallel group is already initialized'
- for i in range(num_tensor_model_parallel_groups):
- ranks = range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size)
- group = torch.distributed.new_group(
- ranks, pg_options=get_nccl_options('tp', nccl_comm_cfgs)
- )
- if rank in ranks:
- _TENSOR_MODEL_PARALLEL_GROUP = group
-
- # Build the pipeline model-parallel groups and embedding groups
- # (first and last rank in each pipeline model-parallel group).
- global _PIPELINE_MODEL_PARALLEL_GROUP
- global _PIPELINE_GLOBAL_RANKS
- assert (
- _PIPELINE_MODEL_PARALLEL_GROUP is None
- ), 'pipeline model parallel group is already initialized'
- global _EMBEDDING_GROUP
- global _EMBEDDING_GLOBAL_RANKS
- assert _EMBEDDING_GROUP is None, 'embedding group is already initialized'
- global _POSITION_EMBEDDING_GROUP
- global _POSITION_EMBEDDING_GLOBAL_RANKS
- assert _POSITION_EMBEDDING_GROUP is None, 'position embedding group is already initialized'
- for i in range(num_pipeline_model_parallel_groups):
- ranks = range(i, world_size, num_pipeline_model_parallel_groups)
- group = torch.distributed.new_group(
- ranks, pg_options=get_nccl_options('pp', nccl_comm_cfgs)
- )
- if rank in ranks:
- _PIPELINE_MODEL_PARALLEL_GROUP = group
- _PIPELINE_GLOBAL_RANKS = ranks
- # Setup embedding group (to exchange gradients between
- # first and last stages).
- if len(ranks) > 1:
- embedding_ranks = [ranks[0], ranks[-1]]
- position_embedding_ranks = [ranks[0]]
- if pipeline_model_parallel_split_rank is not None:
- if ranks[pipeline_model_parallel_split_rank] not in embedding_ranks:
- embedding_ranks = [
- ranks[0],
- ranks[pipeline_model_parallel_split_rank],
- ranks[-1],
- ]
- if ranks[pipeline_model_parallel_split_rank] not in position_embedding_ranks:
- position_embedding_ranks = [ranks[0], ranks[pipeline_model_parallel_split_rank]]
- else:
- embedding_ranks = ranks
- position_embedding_ranks = ranks
-
- group = torch.distributed.new_group(
- embedding_ranks, pg_options=get_nccl_options('embd', nccl_comm_cfgs)
- )
- if rank in embedding_ranks:
- _EMBEDDING_GROUP = group
- if rank in ranks:
- _EMBEDDING_GLOBAL_RANKS = embedding_ranks
-
- group = torch.distributed.new_group(
- position_embedding_ranks, pg_options=get_nccl_options('embd', nccl_comm_cfgs)
- )
- if rank in position_embedding_ranks:
- _POSITION_EMBEDDING_GROUP = group
- if rank in ranks:
- _POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks
-
- # Build the tensor + data parallel groups.
- global _TENSOR_AND_DATA_PARALLEL_GROUP
- global _TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP
- assert (
- _TENSOR_AND_DATA_PARALLEL_GROUP is None
- ), 'Tensor + data parallel group is already initialized'
- tensor_and_data_group_size_with_cp: int = tensor_model_parallel_size * data_parallel_size * context_parallel_size
- num_tensor_and_data_groups_with_cp: int = world_size // tensor_and_data_group_size_with_cp
- for i in range(num_tensor_and_data_groups_with_cp):
- start_rank = i * tensor_and_data_group_size_with_cp
- end_rank = start_rank + tensor_and_data_group_size_with_cp
- ranks = range(start_rank, end_rank)
- group = torch.distributed.new_group(
- ranks, pg_options=get_nccl_options('tp_dp_cp', nccl_comm_cfgs)
- )
- if rank in ranks:
- _TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP = group
-
- for j in range(context_parallel_size):
- ranks = []
- for k in range(data_parallel_size):
- start_rank = (
- i * tensor_and_data_group_size_with_cp
- + j * tensor_model_parallel_size
- + k * tensor_model_parallel_size * context_parallel_size
- )
- end_rank = start_rank + tensor_model_parallel_size
- ranks = ranks + list(range(start_rank, end_rank))
- group = torch.distributed.new_group(
- ranks, pg_options=get_nccl_options('tp_dp', nccl_comm_cfgs)
- )
- if rank in ranks:
- _TENSOR_AND_DATA_PARALLEL_GROUP = group
-
- # Build the tensor + expert parallel groups
- global _TENSOR_AND_EXPERT_PARALLEL_GROUP
- assert (
- _TENSOR_AND_EXPERT_PARALLEL_GROUP is None
- ), 'Tensor + expert parallel group is already initialized'
- global _DATA_MODULO_EXPERT_PARALLEL_GROUP
- assert (
- _DATA_MODULO_EXPERT_PARALLEL_GROUP is None
- ), 'Data modulo expert group is already initialized'
- tensor_and_data_group_size: int = tensor_model_parallel_size * data_parallel_size
- num_tensor_and_data_groups: int = world_size // tensor_and_data_group_size
- tensor_and_expert_group_size: int = tensor_model_parallel_size * expert_model_parallel_size
- num_expert_groups: int = data_parallel_size // expert_model_parallel_size
- for i in range(num_tensor_and_data_groups):
- for j in range(num_expert_groups):
- start_rank = i * tensor_and_data_group_size + j * tensor_and_expert_group_size
- end_rank = i * tensor_and_data_group_size + (j + 1) * tensor_and_expert_group_size
- ranks = range(start_rank, end_rank)
- group = torch.distributed.new_group(
- ranks, pg_options=get_nccl_options('tp_exp', nccl_comm_cfgs)
- )
- if rank in ranks:
- _TENSOR_AND_EXPERT_PARALLEL_GROUP = group
-
- for i in range(num_tensor_and_data_groups):
- start_rank = i * tensor_and_data_group_size
- end_rank = (i + 1) * tensor_and_data_group_size
- for j in range(tensor_and_expert_group_size):
- ranks = range(start_rank + j, end_rank, tensor_and_expert_group_size)
- group = torch.distributed.new_group(
- ranks, pg_options=get_nccl_options('dp_modulo_exp', nccl_comm_cfgs)
- )
- if rank in ranks:
- _DATA_MODULO_EXPERT_PARALLEL_GROUP = group
-
- # Initialize global memory buffer
- # This isn't really "parallel state" but there isn't another good place to
- # put this. If we end up with a more generic initialization of megatron-core
- # we could stick it there
- _set_global_memory_buffer()
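The 16-GPU example in the `initialize_model_parallel` docstring (tensor parallel 2, pipeline parallel 4, hence data parallel 2) can be reproduced with plain Python, without creating any process groups; the sketch below mirrors the rank arithmetic above and is illustrative only:

```python
# Pure-Python reproduction of the 16-GPU example in the docstring of
# initialize_model_parallel: derive the data-, tensor-, and pipeline-parallel
# rank groups without any torch.distributed calls. Illustrative only.
def parallel_groups(world_size, tp, pp, cp=1):
    dp = world_size // (tp * pp * cp)
    num_tp_groups = world_size // tp
    num_pp_groups = world_size // pp

    tp_groups = [list(range(i * tp, (i + 1) * tp)) for i in range(num_tp_groups)]
    pp_groups = [list(range(i, world_size, num_pp_groups)) for i in range(num_pp_groups)]

    dp_groups = []
    for i in range(pp):
        start, end = i * num_pp_groups, (i + 1) * num_pp_groups
        for j in range(cp * tp):
            dp_groups.append(list(range(start + j, end, cp * tp)))

    assert all(len(g) == dp for g in dp_groups)
    return tp_groups, pp_groups, dp_groups


tp_groups, pp_groups, dp_groups = parallel_groups(world_size=16, tp=2, pp=4)
assert dp_groups[:2] == [[0, 2], [1, 3]]            # matches the docstring example
assert tp_groups[:2] == [[0, 1], [2, 3]]
assert pp_groups[0] == [0, 4, 8, 12]
```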
-
-
-def is_unitialized():
- """Useful for code segments that may be accessed with or without mpu initialization"""
- return _DATA_PARALLEL_GROUP is None
-
-
-def model_parallel_is_initialized():
- """Check if model and data parallel groups are initialized."""
- if (
- _TENSOR_MODEL_PARALLEL_GROUP is None
- or _PIPELINE_MODEL_PARALLEL_GROUP is None
- or _DATA_PARALLEL_GROUP is None
- ):
- return False
- return True
-
-
-def get_model_parallel_group():
- """Get the model parallel group the caller rank belongs to."""
- assert _MODEL_PARALLEL_GROUP is not None, 'model parallel group is not initialized'
- return _MODEL_PARALLEL_GROUP
-
-
-def get_tensor_model_parallel_group(check_initialized=True):
- """Get the tensor model parallel group the caller rank belongs to."""
- if check_initialized:
- assert (
- _TENSOR_MODEL_PARALLEL_GROUP is not None
- ), 'tensor model parallel group is not initialized'
- return _TENSOR_MODEL_PARALLEL_GROUP
-
-
-def get_pipeline_model_parallel_group():
- """Get the pipeline model parallel group the caller rank belongs to."""
- assert (
- _PIPELINE_MODEL_PARALLEL_GROUP is not None
- ), 'pipeline_model parallel group is not initialized'
- return _PIPELINE_MODEL_PARALLEL_GROUP
-
-
-def get_data_parallel_group(with_context_parallel=False):
- """Get the data parallel group the caller rank belongs to."""
- if with_context_parallel:
- assert (
- _DATA_PARALLEL_GROUP_WITH_CP is not None
- ), 'data parallel group with context parallel combined is not initialized'
- return _DATA_PARALLEL_GROUP_WITH_CP
- else:
- assert _DATA_PARALLEL_GROUP is not None, 'data parallel group is not initialized'
- return _DATA_PARALLEL_GROUP
-
-
-def get_data_parallel_group_gloo(with_context_parallel=False):
- """Get the data parallel group-gloo the caller rank belongs to."""
- if with_context_parallel:
- assert (
- _DATA_PARALLEL_GROUP_WITH_CP_GLOO is not None
- ), 'data parallel group-gloo with context parallel combined is not initialized'
- return _DATA_PARALLEL_GROUP_WITH_CP_GLOO
- else:
- assert _DATA_PARALLEL_GROUP_GLOO is not None, 'data parallel group-gloo is not initialized'
- return _DATA_PARALLEL_GROUP_GLOO
-
-
-def get_context_parallel_group(check_initialized=True):
- """Get the context parallel group the caller rank belongs to."""
- if check_initialized:
- assert _CONTEXT_PARALLEL_GROUP is not None, 'context parallel group is not initialized'
- return _CONTEXT_PARALLEL_GROUP
-
-
-def get_context_parallel_global_ranks(check_initialized=True):
- """Get all global ranks of the context parallel group that the caller rank belongs to."""
- if check_initialized:
- assert (
- _CONTEXT_PARALLEL_GLOBAL_RANKS is not None
- ), 'context parallel group is not initialized'
- return _CONTEXT_PARALLEL_GLOBAL_RANKS
-
-
-def get_embedding_group():
- """Get the embedding group the caller rank belongs to."""
- assert _EMBEDDING_GROUP is not None, 'embedding group is not initialized'
- return _EMBEDDING_GROUP
-
-
-def get_position_embedding_group():
- """Get the position embedding group the caller rank belongs to."""
- assert _POSITION_EMBEDDING_GROUP is not None, 'position embedding group is not initialized'
- return _POSITION_EMBEDDING_GROUP
-
-
-def get_amax_reduction_group(with_context_parallel=False):
- """Get the FP8 amax reduction group the caller rank belongs to."""
- if with_context_parallel:
- assert (
- _TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP is not None
- ), 'FP8 amax reduction group is not initialized'
- return _TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP
- else:
- assert (
- _TENSOR_AND_DATA_PARALLEL_GROUP is not None
- ), 'FP8 amax reduction group is not initialized'
- return _TENSOR_AND_DATA_PARALLEL_GROUP
-
-
-def get_tensor_and_data_parallel_group(with_context_parallel=False):
- """Get the tensor and data parallel group the caller rank belongs to."""
- if with_context_parallel:
- assert (
- _TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP is not None
- ), 'tensor and data parallel group is not initialized'
- return _TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP
- else:
- assert (
- _TENSOR_AND_DATA_PARALLEL_GROUP is not None
- ), 'tensor and data parallel group is not initialized'
- return _TENSOR_AND_DATA_PARALLEL_GROUP
-
-
-def get_tensor_and_expert_parallel_group():
- assert (
- _TENSOR_AND_EXPERT_PARALLEL_GROUP is not None
- ), 'tensor and expert parallel group is not initialized'
- return _TENSOR_AND_EXPERT_PARALLEL_GROUP
-
-
-def get_data_modulo_expert_parallel_group():
- assert (
- _DATA_MODULO_EXPERT_PARALLEL_GROUP is not None
- ), 'data modulo expert parallel group is not initialized'
- return _DATA_MODULO_EXPERT_PARALLEL_GROUP
-
-
-def set_tensor_model_parallel_world_size(world_size):
- """Set the tensor model parallel size"""
- global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
- _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = world_size
-
-
-def set_pipeline_model_parallel_world_size(world_size):
- """Set the pipeline model parallel size"""
- global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
- _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = world_size
-
-
-def set_virtual_pipeline_model_parallel_world_size(world_size):
- """Set the pipeline model parallel size"""
- global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
- _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = world_size
-
-
-def get_tensor_model_parallel_world_size():
- """Return world size for the tensor model parallel group."""
- global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
- if _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE is not None:
- return _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
- return torch.distributed.get_world_size(group=get_tensor_model_parallel_group())
-
-
-def get_pipeline_model_parallel_world_size():
- """Return world size for the pipeline model parallel group."""
- global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
- if _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE is not None:
- return _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
- return torch.distributed.get_world_size(group=get_pipeline_model_parallel_group())
-
-
-def set_tensor_model_parallel_rank(rank):
- """Set tensor model parallel rank."""
- global _MPU_TENSOR_MODEL_PARALLEL_RANK
- _MPU_TENSOR_MODEL_PARALLEL_RANK = rank
-
-
-def set_pipeline_model_parallel_rank(rank):
- """Set pipeline model parallel rank."""
- global _MPU_PIPELINE_MODEL_PARALLEL_RANK
- _MPU_PIPELINE_MODEL_PARALLEL_RANK = rank
-
-
-def set_pipeline_model_parallel_split_rank(rank):
- """Set pipeline model parallel split rank."""
- global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
- _PIPELINE_MODEL_PARALLEL_SPLIT_RANK = rank
-
-
-def get_tensor_model_parallel_rank():
- """Return my rank for the tensor model parallel group."""
- global _MPU_TENSOR_MODEL_PARALLEL_RANK
- if _MPU_TENSOR_MODEL_PARALLEL_RANK is not None:
- return _MPU_TENSOR_MODEL_PARALLEL_RANK
- return torch.distributed.get_rank(group=get_tensor_model_parallel_group())
-
-
-def get_pipeline_model_parallel_rank():
- """Return my rank for the pipeline model parallel group."""
- global _MPU_PIPELINE_MODEL_PARALLEL_RANK
- if _MPU_PIPELINE_MODEL_PARALLEL_RANK is not None:
- return _MPU_PIPELINE_MODEL_PARALLEL_RANK
- return torch.distributed.get_rank(group=get_pipeline_model_parallel_group())
-
-
-def get_pipeline_model_parallel_split_rank():
- """Return pipeline model parallel split rank."""
- global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
- return _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
-
-
-def is_pipeline_first_stage(ignore_virtual=False):
- """Return True if in the first pipeline model-parallel stage, False otherwise."""
- if not ignore_virtual:
- if (
- get_virtual_pipeline_model_parallel_world_size() is not None
- and get_virtual_pipeline_model_parallel_rank() != 0
- ):
- return False
- return get_pipeline_model_parallel_rank() == 0
-
-
-def is_pipeline_last_stage(ignore_virtual=False):
- """Return True if in the last pipeline model-parallel stage, False otherwise."""
- if not ignore_virtual:
- virtual_pipeline_model_parallel_world_size = (
- get_virtual_pipeline_model_parallel_world_size()
- )
- if virtual_pipeline_model_parallel_world_size is not None and get_virtual_pipeline_model_parallel_rank() != (
- virtual_pipeline_model_parallel_world_size - 1
- ):
- return False
- return get_pipeline_model_parallel_rank() == (get_pipeline_model_parallel_world_size() - 1)
-
-
-def is_rank_in_embedding_group(ignore_virtual=False):
- """Return true if current rank is in embedding group, False otherwise."""
- rank = torch.distributed.get_rank()
- global _EMBEDDING_GLOBAL_RANKS
- if ignore_virtual:
- return rank in _EMBEDDING_GLOBAL_RANKS
- if rank in _EMBEDDING_GLOBAL_RANKS:
- if rank == _EMBEDDING_GLOBAL_RANKS[0]:
- return is_pipeline_first_stage(ignore_virtual=False)
- elif rank == _EMBEDDING_GLOBAL_RANKS[-1]:
- return is_pipeline_last_stage(ignore_virtual=False)
- else:
- return True
- return False
-
-
-def is_rank_in_position_embedding_group():
- """Return true if current rank is in position embedding group, False otherwise."""
- rank = torch.distributed.get_rank()
- global _POSITION_EMBEDDING_GLOBAL_RANKS
- return rank in _POSITION_EMBEDDING_GLOBAL_RANKS
-
-
-def is_pipeline_stage_before_split(rank=None):
- """Return True if pipeline stage executes encoder block for a model
- with both encoder and decoder."""
- if get_pipeline_model_parallel_world_size() == 1:
- return True
- if rank is None:
- rank = get_pipeline_model_parallel_rank()
- global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
- if _PIPELINE_MODEL_PARALLEL_SPLIT_RANK is None:
- return True
- if rank < _PIPELINE_MODEL_PARALLEL_SPLIT_RANK:
- return True
- return False
-
-
-def is_pipeline_stage_after_split(rank=None):
- """Return True if pipeline stage executes decoder block for a model
- with both encoder and decoder."""
- if get_pipeline_model_parallel_world_size() == 1:
- return True
- if rank is None:
- rank = get_pipeline_model_parallel_rank()
- global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
- if _PIPELINE_MODEL_PARALLEL_SPLIT_RANK is None:
- return True
- if rank >= _PIPELINE_MODEL_PARALLEL_SPLIT_RANK:
- return True
- return False
-
-
-def is_pipeline_stage_at_split():
-    """Return true if pipeline stage executes encoder block and next
-    stage executes decoder block for a model with both encoder and
-    decoder."""
- rank = get_pipeline_model_parallel_rank()
- return is_pipeline_stage_before_split(rank) and is_pipeline_stage_after_split(rank + 1)
-
-
-def get_virtual_pipeline_model_parallel_rank():
- """Return the virtual pipeline-parallel rank."""
- global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
- return _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
-
-
-def set_virtual_pipeline_model_parallel_rank(rank):
- """Set the virtual pipeline-parallel rank."""
- global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
- _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = rank
-
-
-def get_virtual_pipeline_model_parallel_world_size():
- """Return the virtual pipeline-parallel world size."""
- global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
- return _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
-
-
-def get_tensor_model_parallel_src_rank():
- """Calculate the global rank corresponding to the first local rank
- in the tensor model parallel group."""
- global_rank = torch.distributed.get_rank()
- local_world_size = get_tensor_model_parallel_world_size()
- return (global_rank // local_world_size) * local_world_size
-
-
-def get_data_parallel_src_rank(with_context_parallel=False):
- """Calculate the global rank corresponding to the first local rank
- in the data parallel group."""
- if with_context_parallel:
- assert (
- _DATA_PARALLEL_GLOBAL_RANKS_WITH_CP is not None
- ), "Data parallel group with context parallel combined is not initialized"
- return _DATA_PARALLEL_GLOBAL_RANKS_WITH_CP[0]
- else:
- assert _DATA_PARALLEL_GLOBAL_RANKS is not None, "Data parallel group is not initialized"
- return _DATA_PARALLEL_GLOBAL_RANKS[0]
-
-
-def get_pipeline_model_parallel_first_rank():
- """Return the global rank of the first process in the pipeline for the
- current tensor parallel group"""
- assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized"
- return _PIPELINE_GLOBAL_RANKS[0]
-
-
-def get_pipeline_model_parallel_last_rank():
- """Return the global rank of the last process in the pipeline for the
- current tensor parallel group"""
- assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized"
- last_rank_local = get_pipeline_model_parallel_world_size() - 1
- return _PIPELINE_GLOBAL_RANKS[last_rank_local]
-
-
-def get_pipeline_model_parallel_next_rank():
- """Return the global rank that follows the caller in the pipeline"""
- assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized"
- rank_in_pipeline = get_pipeline_model_parallel_rank()
- world_size = get_pipeline_model_parallel_world_size()
- return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size]
-
-
-def get_pipeline_model_parallel_prev_rank():
-    """Return the global rank that precedes the caller in the pipeline"""
- assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized"
- rank_in_pipeline = get_pipeline_model_parallel_rank()
- world_size = get_pipeline_model_parallel_world_size()
- return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size]
-
-
-def get_data_parallel_world_size(with_context_parallel=False):
- """Return world size for the data parallel group."""
- if torch.distributed.is_available() and torch.distributed.is_initialized():
- return torch.distributed.get_world_size(
- group=get_data_parallel_group(with_context_parallel=with_context_parallel)
- )
- else:
- return 0
-
-
-def get_data_parallel_rank(with_context_parallel=False):
- """Return my rank for the data parallel group."""
- if torch.distributed.is_available() and torch.distributed.is_initialized():
- return torch.distributed.get_rank(
- group=get_data_parallel_group(with_context_parallel=with_context_parallel)
- )
- else:
- return 0
-
-
-def get_context_parallel_world_size():
- """Return world size for the context parallel group."""
- if torch.distributed.is_available() and torch.distributed.is_initialized():
- return torch.distributed.get_world_size(group=get_context_parallel_group())
- else:
- return 0
-
-
-def get_context_parallel_rank():
- """Return my rank for the context parallel group."""
- if torch.distributed.is_available() and torch.distributed.is_initialized():
- return torch.distributed.get_rank(group=get_context_parallel_group())
- else:
- return 0
-
-
-def get_expert_model_parallel_world_size():
-    """Return world size for the expert model parallel group"""
- if torch.distributed.is_available() and torch.distributed.is_initialized():
- tensor_and_expert_parallel_world_size = torch.distributed.get_world_size(
- group=get_tensor_and_expert_parallel_group()
- )
- return tensor_and_expert_parallel_world_size // get_tensor_model_parallel_world_size()
- else:
- return 0
-
-
-def get_expert_model_parallel_rank():
- """Return my rank for the expert parallel group"""
- if torch.distributed.is_available() and torch.distributed.is_initialized():
- tensor_and_expert_parallel_rank = torch.distributed.get_rank(
- group=get_tensor_and_expert_parallel_group()
- )
- return tensor_and_expert_parallel_rank // get_tensor_model_parallel_world_size()
- else:
- return 0
-
-
-def get_data_modulo_expert_parallel_rank():
-    """Return my rank for the data modulo expert parallel group."""
- if torch.distributed.is_available() and torch.distributed.is_initialized():
- return torch.distributed.get_rank(group=get_data_modulo_expert_parallel_group())
- else:
- return 0
-
-
-def _set_global_memory_buffer():
- """Initialize global buffer"""
- global _GLOBAL_MEMORY_BUFFER
- assert _GLOBAL_MEMORY_BUFFER is None, 'global memory buffer is already initialized'
- _GLOBAL_MEMORY_BUFFER = GlobalMemoryBuffer()
-
-
-def get_global_memory_buffer():
- """Return the global GlobalMemoryBuffer object"""
- assert _GLOBAL_MEMORY_BUFFER is not None, 'global memory buffer is not initialized'
- return _GLOBAL_MEMORY_BUFFER
-
-
-def destroy_global_memory_buffer():
- """Sets the global memory buffer to None"""
- global _GLOBAL_MEMORY_BUFFER
- _GLOBAL_MEMORY_BUFFER = None
-
-
-def destroy_model_parallel():
- """Set the groups to none."""
- global _MODEL_PARALLEL_GROUP
- _MODEL_PARALLEL_GROUP = None
- global _TENSOR_MODEL_PARALLEL_GROUP
- _TENSOR_MODEL_PARALLEL_GROUP = None
- global _PIPELINE_MODEL_PARALLEL_GROUP
- _PIPELINE_MODEL_PARALLEL_GROUP = None
- global _DATA_PARALLEL_GROUP
- _DATA_PARALLEL_GROUP = None
- global _DATA_PARALLEL_GROUP_WITH_CP
- _DATA_PARALLEL_GROUP_WITH_CP = None
- global _CONTEXT_PARALLEL_GROUP
- _CONTEXT_PARALLEL_GROUP = None
- global _CONTEXT_PARALLEL_GLOBAL_RANKS
- _CONTEXT_PARALLEL_GLOBAL_RANKS = None
- global _EMBEDDING_GROUP
- _EMBEDDING_GROUP = None
- global _POSITION_EMBEDDING_GROUP
- _POSITION_EMBEDDING_GROUP = None
- global _TENSOR_AND_DATA_PARALLEL_GROUP
- _TENSOR_AND_DATA_PARALLEL_GROUP = None
- global _TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP
- _TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP = None
- global _TENSOR_AND_EXPERT_PARALLEL_GROUP
- _TENSOR_AND_EXPERT_PARALLEL_GROUP = None
- global _DATA_MODULO_EXPERT_PARALLEL_GROUP
- _DATA_MODULO_EXPERT_PARALLEL_GROUP = None
- global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
- _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
- global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
- _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
- global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
- _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
- global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
- _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
- global _MPU_TENSOR_MODEL_PARALLEL_RANK
- _MPU_TENSOR_MODEL_PARALLEL_RANK = None
- global _MPU_PIPELINE_MODEL_PARALLEL_RANK
- _MPU_PIPELINE_MODEL_PARALLEL_RANK = None
- global _GLOBAL_MEMORY_BUFFER
- _GLOBAL_MEMORY_BUFFER = None
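
Most of the rank helpers removed above reduce to plain integer arithmetic over the global rank. As an illustration only (not taken from the removed module), the sketch below reproduces the formulas from `get_tensor_model_parallel_src_rank` and `get_expert_model_parallel_rank` for a hypothetical layout; the constant and function names here are examples.

```python
# Illustrative arithmetic only; mirrors the formulas in the removed helpers.
# Hypothetical layout: tensor model parallel size 4 (example value).
TP_WORLD_SIZE = 4

def tensor_model_parallel_src_rank(global_rank: int) -> int:
    # First global rank of the caller's tensor-parallel group,
    # as computed by get_tensor_model_parallel_src_rank above.
    return (global_rank // TP_WORLD_SIZE) * TP_WORLD_SIZE

def expert_model_parallel_rank(tensor_and_expert_rank: int) -> int:
    # Expert-parallel rank recovered from the rank inside the combined
    # tensor-and-expert group, as in get_expert_model_parallel_rank above.
    return tensor_and_expert_rank // TP_WORLD_SIZE

assert tensor_model_parallel_src_rank(6) == 4   # ranks 4..7 form one TP group
assert expert_model_parallel_rank(5) == 1       # second expert-parallel rank
```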
diff --git a/megatron/core/pipeline_parallel/__init__.py b/megatron/core/pipeline_parallel/__init__.py
deleted file mode 100644
index 00cd1ff3826564f9eef6cd9b023c0dd331b5d691..0000000000000000000000000000000000000000
--- a/megatron/core/pipeline_parallel/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-from .schedules import get_forward_backward_func
diff --git a/megatron/core/pipeline_parallel/p2p_communication.py b/megatron/core/pipeline_parallel/p2p_communication.py
deleted file mode 100644
index 29ee34df8cff6a2f0613c0ad9bca175d71e30ba2..0000000000000000000000000000000000000000
--- a/megatron/core/pipeline_parallel/p2p_communication.py
+++ /dev/null
@@ -1,571 +0,0 @@
-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-
-import operator
-from functools import reduce
-from typing import Callable, List, Optional, Tuple, Union
-
-import torch
-
-from megatron import core
-from megatron.core import ModelParallelConfig
-from megatron.core.parallel_state import (
- get_pipeline_model_parallel_group,
- get_pipeline_model_parallel_next_rank,
- get_pipeline_model_parallel_prev_rank,
- get_pipeline_model_parallel_rank,
-)
-
-# Types
-Shape = Union[List[int], torch.Size]
-
-
-def _communicate_shapes(tensor_send_next, tensor_send_prev, recv_prev, recv_next, config):
- """Communicate tensor shapes between stages. Used to communicate
- tensor shapes before the actual tensor communication happens.
- This is required when the sequence lengths across micro batches
- are not uniform.
-
- Takes the following arguments:
- tensor_send_next: tensor to send to next rank (no tensor sent if
- set to None).
- tensor_send_prev: tensor to send to prev rank (no tensor sent if
- set to None).
- recv_prev: boolean for whether tensor should be received from
- previous rank.
- recv_next: boolean for whether tensor should be received from
- next rank.
- Returns:
- (recv_prev_shape, recv_next_shape)
- """
-
- recv_prev_shape_tensor = None
- recv_next_shape_tensor = None
- send_prev_shape_tensor = None
- send_next_shape_tensor = None
- if recv_prev:
- recv_prev_shape_tensor = torch.empty(
- (3), device=torch.cuda.current_device(), dtype=torch.int64
- )
- if recv_next:
- recv_next_shape_tensor = torch.empty(
- (3), device=torch.cuda.current_device(), dtype=torch.int64
- )
- if tensor_send_prev is not None:
- send_prev_shape_tensor = torch.tensor(
- tensor_send_prev.size(), device=torch.cuda.current_device(), dtype=torch.int64
- )
- if tensor_send_next is not None:
- send_next_shape_tensor = torch.tensor(
- tensor_send_next.size(), device=torch.cuda.current_device(), dtype=torch.int64
- )
-
- if config.use_ring_exchange_p2p:
- torch.distributed.ring_exchange(
- tensor_send_prev=send_prev_shape_tensor,
- tensor_recv_prev=recv_prev_shape_tensor,
- tensor_send_next=send_next_shape_tensor,
- tensor_recv_next=recv_next_shape_tensor,
- group=get_pipeline_model_parallel_group(),
- )
- else:
- ops = []
- if send_prev_shape_tensor is not None:
- send_prev_op = torch.distributed.P2POp(
- torch.distributed.isend,
- send_prev_shape_tensor,
- get_pipeline_model_parallel_prev_rank(),
- )
- ops.append(send_prev_op)
- if recv_prev_shape_tensor is not None:
- recv_prev_op = torch.distributed.P2POp(
- torch.distributed.irecv,
- recv_prev_shape_tensor,
- get_pipeline_model_parallel_prev_rank(),
- )
- ops.append(recv_prev_op)
- if send_next_shape_tensor is not None:
- send_next_op = torch.distributed.P2POp(
- torch.distributed.isend,
- send_next_shape_tensor,
- get_pipeline_model_parallel_next_rank(),
- )
- ops.append(send_next_op)
- if recv_next_shape_tensor is not None:
- recv_next_op = torch.distributed.P2POp(
- torch.distributed.irecv,
- recv_next_shape_tensor,
- get_pipeline_model_parallel_next_rank(),
- )
- ops.append(recv_next_op)
- if len(ops) > 0:
- reqs = torch.distributed.batch_isend_irecv(ops)
- for req in reqs:
- req.wait()
-
- # To protect against race condition when using batch_isend_irecv().
- # should take this out once the bug with batch_isend_irecv is resolved.
- torch.cuda.synchronize()
-
- recv_prev_shape = [0, 0, 0]
- if recv_prev_shape_tensor is not None:
- recv_prev_shape = recv_prev_shape_tensor.tolist()
-
- recv_next_shape = [0, 0, 0]
- if recv_next_shape_tensor is not None:
- recv_next_shape = recv_next_shape_tensor.tolist()
-
- return recv_prev_shape, recv_next_shape
-
-
-def _batched_p2p_ops(
- *,
- tensor_send_prev: Optional[torch.Tensor],
- tensor_recv_prev: Optional[torch.Tensor],
- tensor_send_next: Optional[torch.Tensor],
- tensor_recv_next: Optional[torch.Tensor],
- group: torch.distributed.ProcessGroup
-):
- ops = []
- if tensor_send_prev is not None:
- send_prev_op = torch.distributed.P2POp(
- torch.distributed.isend,
- tensor_send_prev,
- get_pipeline_model_parallel_prev_rank(),
- group,
- )
- ops.append(send_prev_op)
- if tensor_recv_prev is not None:
- recv_prev_op = torch.distributed.P2POp(
- torch.distributed.irecv,
- tensor_recv_prev,
- get_pipeline_model_parallel_prev_rank(),
- group,
- )
- ops.append(recv_prev_op)
- if tensor_send_next is not None:
- send_next_op = torch.distributed.P2POp(
- torch.distributed.isend,
- tensor_send_next,
- get_pipeline_model_parallel_next_rank(),
- group,
- )
- ops.append(send_next_op)
- if tensor_recv_next is not None:
- recv_next_op = torch.distributed.P2POp(
- torch.distributed.irecv,
- tensor_recv_next,
- get_pipeline_model_parallel_next_rank(),
- group,
- )
- ops.append(recv_next_op)
- if len(ops) > 0:
- reqs = torch.distributed.batch_isend_irecv(ops)
- else:
- reqs = []
- return reqs
-
-
-def _p2p_ops(
- *,
- tensor_send_prev: Optional[torch.Tensor],
- tensor_recv_prev: Optional[torch.Tensor],
- tensor_send_next: Optional[torch.Tensor],
- tensor_recv_next: Optional[torch.Tensor],
- group: torch.distributed.ProcessGroup
-):
- reqs = []
-    rank = get_pipeline_model_parallel_rank()
-    if rank % 2 == 0:
- if tensor_send_next is not None:
- send_next_req = torch.distributed.isend(
- tensor=tensor_send_next, dst=get_pipeline_model_parallel_next_rank(), group=group,
- )
- reqs.append(send_next_req)
-
- if tensor_recv_prev is not None:
- recv_prev_req = torch.distributed.irecv(
- tensor=tensor_recv_prev, src=get_pipeline_model_parallel_prev_rank(), group=group,
- )
- reqs.append(recv_prev_req)
-
- if tensor_send_prev is not None:
- send_prev_req = torch.distributed.isend(
- tensor=tensor_send_prev, dst=get_pipeline_model_parallel_prev_rank(), group=group,
- )
- reqs.append(send_prev_req)
-
- if tensor_recv_next is not None:
- recv_next_req = torch.distributed.irecv(
- tensor=tensor_recv_next, src=get_pipeline_model_parallel_next_rank(), group=group,
- )
- reqs.append(recv_next_req)
-
- else:
- if tensor_recv_prev is not None:
- recv_prev_req = torch.distributed.irecv(
- tensor=tensor_recv_prev, src=get_pipeline_model_parallel_prev_rank(), group=group,
- )
- reqs.append(recv_prev_req)
-
- if tensor_send_next is not None:
- send_next_req = torch.distributed.isend(
- tensor=tensor_send_next, dst=get_pipeline_model_parallel_next_rank(), group=group,
- )
- reqs.append(send_next_req)
-
- if tensor_recv_next is not None:
- recv_next_req = torch.distributed.irecv(
- tensor=tensor_recv_next, src=get_pipeline_model_parallel_next_rank(), group=group,
- )
- reqs.append(recv_next_req)
-
- if tensor_send_prev is not None:
- send_prev_req = torch.distributed.isend(
- tensor=tensor_send_prev, dst=get_pipeline_model_parallel_prev_rank(), group=group,
- )
- reqs.append(send_prev_req)
- return reqs
-
-
-def _communicate(
- *,
- tensor_send_next: Optional[torch.Tensor],
- tensor_send_prev: Optional[torch.Tensor],
- recv_prev: bool,
- recv_next: bool,
- tensor_shape: Shape,
- config: ModelParallelConfig,
- wait_on_reqs: bool = True
-) -> Tuple[torch.Tensor, torch.Tensor]:
- """Communicate tensors between stages. Used as helper method in other
- communication methods that are used in megatron/schedules.py.
-
- Arguments:
- tensor_send_next (torch.Tensor, optional):
- Tensor to send to next rank (no tensor sent if None)
-
- tensor_send_prev (torch.Tensor, optional):
- Tensor to send to prev rank (no tensor sent if None)
-
- recv_prev (boolean, required):
- whether tensor should be received from previous rank.
-
- recv_next (boolean, required):
- whether tensor should be received from next rank.
-
- tensor_shape (List[int] or torch.Size, required):
- shape of tensor to receive (this method assumes that all
- tensors sent and received in a single function call are
- the same shape).
-
-        wait_on_reqs (boolean, optional, default=True):
- For non-batched p2p communication, wait on each request
- before returning.
-
- Returns:
- tuple containing
-
- - tensor_recv_prev: torch.Tensor if recv_prev is True, None otherwise.
- - tensor_recv_next: torch.Tensor if recv_next is True, None otherwise.
-
- """
-
- # Create placeholder tensors for receive in forward and backward directions
- # if needed.
- tensor_recv_prev = None
- tensor_recv_next = None
-
- if not config.variable_seq_lengths:
- recv_prev_shape = tensor_shape
- recv_next_shape = tensor_shape
- else:
- recv_prev_shape, recv_next_shape = _communicate_shapes(
- tensor_send_next, tensor_send_prev, recv_prev, recv_next, config
- )
-
- if recv_prev:
- if config.pipeline_dtype is None:
- raise RuntimeError("pipeline_dtype must be provided if recv_prev is True")
- if tensor_shape is None:
- raise RuntimeError(
- "tensor_shape must be specified if recv_prev is True. "
- "Common tensor_shape is (seq_length, micro_batch_size, hidden_size)"
- )
- tensor_recv_prev = torch.empty(
- recv_prev_shape,
- requires_grad=True,
- device=torch.cuda.current_device(),
- dtype=config.pipeline_dtype,
- )
- if recv_next:
- if config.pipeline_dtype is None:
- raise RuntimeError("dtype must be provided if recv_next is True")
- if tensor_shape is None:
- raise RuntimeError(
- "tensor_shape must be specified if recv_next is True. "
- "Common tensor_shape is (seq_length, micro_batch_size, hidden_size)"
- )
- tensor_recv_next = torch.empty(
- recv_next_shape,
- requires_grad=True,
- device=torch.cuda.current_device(),
- dtype=config.pipeline_dtype,
- )
-
- # Send tensors in both the forward and backward directions as appropriate.
- if config.use_ring_exchange_p2p:
-
- def _ring_exchange_wrapper(**kwargs):
- torch.distributed.ring_exchange(**kwargs)
- return []
-
- p2p_func = _ring_exchange_wrapper
- elif config.batch_p2p_comm:
- assert wait_on_reqs
- p2p_func = _batched_p2p_ops
- else:
- p2p_func = _p2p_ops
-
- reqs = p2p_func(
- tensor_send_prev=tensor_send_prev,
- tensor_recv_prev=tensor_recv_prev,
- tensor_send_next=tensor_send_next,
- tensor_recv_next=tensor_recv_next,
- group=get_pipeline_model_parallel_group(),
- )
-
- if wait_on_reqs and len(reqs) > 0:
- for req in reqs:
- req.wait()
- reqs = None
-
- if config.batch_p2p_comm and config.batch_p2p_sync:
- # To protect against race condition when using batch_isend_irecv().
- # User should assert that we have a modern enough PyTorch to not need this
- torch.cuda.synchronize()
-
- return tensor_recv_prev, tensor_recv_next, reqs
-
-
-def recv_forward(tensor_shape: Shape, config: ModelParallelConfig) -> torch.Tensor:
-    """Receive tensor from previous rank in pipeline (forward receive).
-
-    See _communicate for argument details.
-    """
-
- if core.parallel_state.is_pipeline_first_stage():
- input_tensor = None
- else:
- if config.timers is not None:
- config.timers('forward-recv', log_level=2).start()
- input_tensor, _, _ = _communicate(
- tensor_send_next=None,
- tensor_send_prev=None,
- recv_prev=True,
- recv_next=False,
- tensor_shape=tensor_shape,
- config=config,
- )
- if config.timers is not None:
- config.timers('forward-recv').stop()
- return input_tensor
-
-
-def recv_backward(tensor_shape: Shape, config: ModelParallelConfig) -> torch.Tensor:
- """Receive tensor from next rank in pipeline (backward receive).
-
- See _communicate for argument details.
- """
- if core.parallel_state.is_pipeline_last_stage():
- output_tensor_grad = None
- else:
- if config.timers is not None:
- config.timers('backward-recv', log_level=2).start()
- _, output_tensor_grad, _ = _communicate(
- tensor_send_next=None,
- tensor_send_prev=None,
- recv_prev=False,
- recv_next=True,
- tensor_shape=tensor_shape,
- config=config,
- )
- if config.timers is not None:
- config.timers('backward-recv').stop()
- return output_tensor_grad
-
-
-def send_forward(output_tensor: torch.Tensor, config: ModelParallelConfig) -> None:
- """Send tensor to next rank in pipeline (forward send).
-
- See _communicate for argument details.
- """
-
- if not core.parallel_state.is_pipeline_last_stage():
- if config.timers is not None:
- config.timers('forward-send', log_level=2).start()
- _communicate(
- tensor_send_next=output_tensor,
- tensor_send_prev=None,
- recv_prev=False,
- recv_next=False,
- tensor_shape=None,
- config=config,
- )
- if config.timers is not None:
- config.timers('forward-send').stop()
-
-
-def send_backward(input_tensor_grad: torch.Tensor, config: ModelParallelConfig) -> None:
- """Send tensor to previous rank in pipeline (backward send).
-
- See _communicate for argument details.
- """
- if not core.parallel_state.is_pipeline_first_stage():
- if config.timers is not None:
- config.timers('backward-send', log_level=2).start()
- _communicate(
- tensor_send_next=None,
- tensor_send_prev=input_tensor_grad,
- recv_prev=False,
- recv_next=False,
- tensor_shape=None,
- config=config,
- )
- if config.timers is not None:
- config.timers('backward-send').stop()
-
-
-def send_forward_recv_backward(
- output_tensor: torch.Tensor, tensor_shape: Shape, config: ModelParallelConfig
-) -> torch.Tensor:
- """Batched send and recv with next rank in pipeline.
-
- See _communicate for argument details.
- """
- if core.parallel_state.is_pipeline_last_stage():
- output_tensor_grad = None
- else:
- if config.timers is not None:
- config.timers('forward-send-backward-recv', log_level=2).start()
- _, output_tensor_grad, _ = _communicate(
- tensor_send_next=output_tensor,
- tensor_send_prev=None,
- recv_prev=False,
- recv_next=True,
- tensor_shape=tensor_shape,
- config=config,
- )
- if config.timers is not None:
- config.timers('forward-send-backward-recv').stop()
- return output_tensor_grad
-
-
-def send_backward_recv_forward(
- input_tensor_grad: torch.Tensor, tensor_shape: Shape, config: ModelParallelConfig
-) -> torch.Tensor:
- """Batched send and recv with previous rank in pipeline.
-
- See _communicate for argument details.
- """
- if core.parallel_state.is_pipeline_first_stage():
- input_tensor = None
- else:
- if config.timers is not None:
- config.timers('backward-send-forward-recv', log_level=2).start()
- input_tensor, _, _ = _communicate(
- tensor_send_next=None,
- tensor_send_prev=input_tensor_grad,
- recv_prev=True,
- recv_next=False,
- tensor_shape=tensor_shape,
- config=config,
- )
- if config.timers is not None:
- config.timers('backward-send-forward-recv').stop()
- return input_tensor
-
-
-def send_forward_recv_forward(
- output_tensor: torch.Tensor,
- recv_prev: bool,
- tensor_shape: Shape,
- config: ModelParallelConfig,
- overlap_p2p_comm: bool = False,
-) -> torch.Tensor:
- """Batched recv from previous rank and send to next rank in pipeline.
-
- See _communicate for argument details.
- """
- if config.timers is not None:
- config.timers('forward-send-forward-recv', log_level=2).start()
- input_tensor, _, wait_handles = _communicate(
- tensor_send_next=output_tensor,
- tensor_send_prev=None,
- recv_prev=recv_prev,
- recv_next=False,
- tensor_shape=tensor_shape,
- wait_on_reqs=(not overlap_p2p_comm),
- config=config,
- )
- if config.timers is not None:
- config.timers('forward-send-forward-recv').stop()
- if overlap_p2p_comm:
- return input_tensor, wait_handles
- return input_tensor
-
-
-def send_backward_recv_backward(
- input_tensor_grad: torch.Tensor,
- recv_next: bool,
- tensor_shape: Shape,
- config: ModelParallelConfig,
- overlap_p2p_comm: bool = False,
-) -> torch.Tensor:
- """Batched recv from next rank and send to previous rank in pipeline.
-
- See _communicate for argument details.
- """
- if config.timers is not None:
- config.timers('backward-send-backward-recv', log_level=2).start()
- _, output_tensor_grad, wait_handles = _communicate(
- tensor_send_next=None,
- tensor_send_prev=input_tensor_grad,
- recv_prev=False,
- recv_next=recv_next,
- tensor_shape=tensor_shape,
- wait_on_reqs=(not overlap_p2p_comm),
- config=config,
- )
- if config.timers is not None:
- config.timers('backward-send-backward-recv').stop()
- if overlap_p2p_comm:
- return output_tensor_grad, wait_handles
- return output_tensor_grad
-
-
-def send_forward_backward_recv_forward_backward(
- output_tensor: torch.Tensor,
- input_tensor_grad: torch.Tensor,
- recv_prev: bool,
- recv_next: bool,
- tensor_shape: Shape,
- config: ModelParallelConfig,
-) -> torch.Tensor:
- """Batched send and recv with previous and next ranks in pipeline.
-
- See _communicate for argument details.
- """
- if config.timers is not None:
- config.timers('forward-backward-send-forward-backward-recv', log_level=2).start()
- input_tensor, output_tensor_grad, _ = _communicate(
- tensor_send_next=output_tensor,
- tensor_send_prev=input_tensor_grad,
- recv_prev=recv_prev,
- recv_next=recv_next,
- tensor_shape=tensor_shape,
- config=config,
- )
- if config.timers is not None:
- config.timers('forward-backward-send-forward-backward-recv').stop()
- return input_tensor, output_tensor_grad
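
One detail worth noting about the removed `_p2p_ops` above: even and odd pipeline ranks post their isend/irecv calls in opposite orders, so the call a stage posts at each step lines up with the matching call posted by its neighbour. A minimal single-process sketch of just that ordering decision (illustrative only, no real communication):

```python
def p2p_call_order(pipeline_rank: int) -> list:
    """Order in which a stage posts its point-to-point calls.

    Mirrors the even/odd branch of the removed _p2p_ops: even ranks
    send to the next stage first, odd ranks receive from the previous
    stage first, so neighbouring stages post matching pairs in step.
    """
    if pipeline_rank % 2 == 0:
        return ["send_next", "recv_prev", "send_prev", "recv_next"]
    return ["recv_prev", "send_next", "recv_next", "send_prev"]

# Rank 0 starts by sending forward exactly when rank 1 starts by receiving.
assert p2p_call_order(0)[0] == "send_next"
assert p2p_call_order(1)[0] == "recv_prev"
```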
diff --git a/megatron/core/pipeline_parallel/schedules.py b/megatron/core/pipeline_parallel/schedules.py
deleted file mode 100644
index 992da781271e22baefabe75c08a308aa9438f3b3..0000000000000000000000000000000000000000
--- a/megatron/core/pipeline_parallel/schedules.py
+++ /dev/null
@@ -1,1293 +0,0 @@
-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-
-import contextlib
-from typing import Callable, Iterator, List, Optional, Union
-
-import torch
-from torch.autograd.variable import Variable
-
-from megatron.core import parallel_state
-from megatron.core.enums import ModelType
-from megatron.core.pipeline_parallel import p2p_communication
-from megatron.core.utils import get_attr_wrapped_model, get_model_config, get_model_type
-
-# Types
-Shape = Union[List[int], torch.Size]
-
-
-def get_forward_backward_func():
- """Retrieves the appropriate forward_backward function given the
- configuration of parallel_state.
-
- Returns a function that will perform all of the forward and
- backward passes of the model given the pipeline model parallel
- world size and virtual pipeline model parallel world size in the
- global parallel_state.
-
- Note that if using sequence parallelism, the sequence length component of
- the tensor shape is updated to original_sequence_length /
- tensor_model_parallel_world_size.
-
- The function returned takes the following arguments:
-
- forward_step_func (required): A function that takes a data
-        iterator and a model as its arguments and returns the model's
- forward output and the loss function. The loss function should
- take one torch.Tensor and return a torch.Tensor of loss and a
- dictionary of string -> torch.Tensor.
-
- A third argument, checkpoint_activations_microbatch, indicates
- that the activations for this microbatch should be
- checkpointed. A None value for this argument indicates that
- the default from the configuration should be used. This is
- used when the
- num_microbatches_with_partial_activation_checkpoints is used.
-
- For example:
-
- def loss_func(loss_mask, output_tensor):
- losses = output_tensor.float()
- loss_mask = loss_mask.view(-1).float()
- loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
-
- # Reduce loss for logging.
- averaged_loss = average_losses_across_data_parallel_group([loss])
-
- return loss, {'lm loss': averaged_loss[0]}
-
- def forward_step(data_iterator, model):
- data, loss_mask = next(data_iterator)
- output = model(data)
- return output, partial(loss_func, loss_mask)
-
-
- forward_backward_func(forward_step_func=forward_step, ...)
-
-
- data_iterator (required): an iterator over the data, will be
- passed as is to forward_step_func. Expected to be a list of
- iterators in the case of interleaved pipeline parallelism.
-
- model (required): the actual model. Expected to be a list of modules in the case of interleaved
- pipeline parallelism. Must be a (potentially wrapped) megatron.core.models.MegatronModule.
-
- num_microbatches (int, required):
- The number of microbatches to go through
-
- seq_length (int, required): Sequence length of the current global batch. If this is a dual-stack
- transformer, this is the encoder's sequence length. This is ignored if variable_seq_lengths
- in the config is True. Otherwise, each microbatch in the current global batch size must use
- this sequence length.
-
- micro_batch_size (int, required): The number of sequences in a microbatch.
-
- decoder_seq_length (int, optional): The sequence length for the decoder in a dual-stack
- transformer. This is ignored for a single-stack transformer.
-
- forward_only (optional, default = False): Perform only the forward step
-
- collect_non_loss_data (optional, bool, default=False): TODO
-
- """
- pipeline_model_parallel_size = parallel_state.get_pipeline_model_parallel_world_size()
- if pipeline_model_parallel_size > 1:
- if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
- forward_backward_func = forward_backward_pipelining_with_interleaving
- else:
- forward_backward_func = forward_backward_pipelining_without_interleaving
- else:
- forward_backward_func = forward_backward_no_pipelining
- return forward_backward_func
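
For orientation, a hedged sketch of how the returned schedule is typically invoked from a training step; `forward_step`, `train_data_iterator`, and `model` are placeholders for objects the caller already has, and the sequence and micro-batch sizes are example values.

```python
from megatron.core.pipeline_parallel import get_forward_backward_func

def train_step(forward_step, train_data_iterator, model, num_microbatches):
    # Assumes megatron.core.parallel_state has already been initialized
    # elsewhere; seq_length / micro_batch_size below are example values.
    forward_backward_func = get_forward_backward_func()
    losses_reduced = forward_backward_func(
        forward_step_func=forward_step,
        data_iterator=train_data_iterator,
        model=model,
        num_microbatches=num_microbatches,
        seq_length=4096,        # per-microbatch sequence length (example)
        micro_batch_size=1,     # sequences per microbatch (example)
        forward_only=False,
    )
    return losses_reduced      # the forward_data_store collected by forward_step
```

The returned list holds the per-microbatch values collected by `forward_step` (for example, the reduced-loss dictionaries shown in the docstring above).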
-
-
-def deallocate_output_tensor(out, deallocate_pipeline_outputs=False):
- '''Pseudo-deallocate (i.e., set to scalar) the output tensor's '.data' field.
-
- This method should be called right after the output tensor has been
- sent to the next pipeline stage. At this point, the output tensor is
- only useful for its '.grad_fn' field, and not its '.data'.
- '''
- if (out is None) or (not deallocate_pipeline_outputs):
- return
- assert isinstance(out, torch.Tensor), "expected Tensor, found %s." % type(out).__name__
- assert out._base is None, "counter-productive to free a view of another tensor."
- out.data = torch.empty((1,), device=out.device, dtype=out.dtype,)
-
-
-def custom_backward(output, grad_output):
- '''Directly call C++ autograd engine.
-
- To make the 'deallocate_output_tensor' (above) optimization work, the C++
- autograd engine must be called directly, bypassing Pytorch's
- torch.autograd.backward. Pytorch's 'backward' checks that the output and
- grad have the same shape, while C++'s 'backward' does not.
- '''
-
- assert output.numel() == 1, "output should be pseudo-'freed' in schedule, to optimize memory"
- assert isinstance(output, torch.Tensor), "output == '%s'." % type(output).__name__
- assert isinstance(grad_output, (torch.Tensor, type(None))), (
- "grad_output == '%s'." % type(grad_output).__name__
- )
-
- # Handle scalar output
- if grad_output is None:
- assert output.numel() == 1, "implicit grad requires scalar output."
- grad_output = torch.ones_like(output, memory_format=torch.preserve_format,)
-
- # Call c++ engine [ see torch/csrc/autograd/python_engine.cpp ]
- Variable._execution_engine.run_backward(
- tensors=(output,),
- grad_tensors=(grad_output,),
- keep_graph=False,
- create_graph=False,
- inputs=tuple(),
- allow_unreachable=True,
- accumulate_grad=True,
- )
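
Taken together, `deallocate_output_tensor` and `custom_backward` let a stage free an activation's storage once it has been sent downstream while still back-propagating through its `grad_fn` when the gradient arrives. A small self-contained sketch of that trick, using plain torch with no pipeline parallelism:

```python
import torch

# Stand-in for a pipeline stage's output activation.
w = torch.randn(4, 4, requires_grad=True)
out = w @ torch.randn(4, 4)
sent_shape = out.shape                      # shape "sent" to the next stage

# Pseudo-deallocate, as in deallocate_output_tensor: keep grad_fn, drop data.
out.data = torch.empty((1,), dtype=out.dtype)

# Gradient that would later arrive from the next stage.
grad_from_next_stage = torch.ones(sent_shape)

# Drive the C++ engine directly, as custom_backward does, bypassing the
# shape check that torch.autograd.backward would apply to `out`.
torch.autograd.Variable._execution_engine.run_backward(
    tensors=(out,),
    grad_tensors=(grad_from_next_stage,),
    keep_graph=False,
    create_graph=False,
    inputs=tuple(),
    allow_unreachable=True,
    accumulate_grad=True,
)
assert w.grad is not None and w.grad.shape == (4, 4)
```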
-
-
-def forward_step(
- forward_step_func,
- data_iterator,
- model,
- num_microbatches,
- input_tensor,
- forward_data_store,
- config,
- collect_non_loss_data=False,
- checkpoint_activations_microbatch=None,
-):
- """Forward step for passed-in model.
-
- If first stage, input tensor is obtained from data_iterator, otherwise
- passed-in input_tensor is used.
-
- Returns output tensor."""
- if config.timers is not None:
- config.timers('forward-compute', log_level=2).start()
-
- unwrap_output_tensor = False
- if not isinstance(input_tensor, list):
- input_tensor = [input_tensor]
- unwrap_output_tensor = True
-
- set_input_tensor = get_attr_wrapped_model(model, "set_input_tensor")
- set_input_tensor(input_tensor)
-
- if config.enable_autocast:
- context_manager = torch.autocast("cuda", dtype=config.autocast_dtype)
- else:
- context_manager = contextlib.nullcontext()
- with context_manager:
- if checkpoint_activations_microbatch is None:
- output_tensor, loss_func = forward_step_func(data_iterator, model)
- else:
- output_tensor, loss_func = forward_step_func(
- data_iterator, model, checkpoint_activations_microbatch
- )
-
- if parallel_state.is_pipeline_last_stage():
- if not collect_non_loss_data:
- output_tensor = loss_func(output_tensor)
- loss, loss_reduced = output_tensor
- output_tensor = loss / num_microbatches
- forward_data_store.append(loss_reduced)
- else:
- data = loss_func(output_tensor, non_loss_data=True)
- forward_data_store.append(data)
-
- if config.timers is not None:
- config.timers('forward-compute').stop()
-
- # If T5 model (or other model with encoder and decoder)
- # and in decoder stack, then send encoder_hidden_state
- # downstream as well.
- model_type = get_model_type(model)
- if (
- parallel_state.is_pipeline_stage_after_split()
- and model_type == ModelType.encoder_and_decoder
- ):
- return [output_tensor, input_tensor[-1]]
- if unwrap_output_tensor:
- return output_tensor
- return [output_tensor]
-
-
-def backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config):
- """Backward step through passed-in output tensor.
-
- If last stage, output_tensor_grad is None, otherwise gradient of loss
- with respect to stage's output tensor.
-
- Returns gradient of loss with respect to input tensor (None if first
- stage)."""
-
- # NOTE: This code currently can handle at most one skip connection. It
- # needs to be modified slightly to support arbitrary numbers of skip
- # connections.
-
- if config.timers is not None:
- config.timers('backward-compute', log_level=2).start()
-
- # Retain the grad on the input_tensor.
- unwrap_input_tensor_grad = False
- if not isinstance(input_tensor, list):
- input_tensor = [input_tensor]
- unwrap_input_tensor_grad = True
- for x in input_tensor:
- if x is not None:
- x.retain_grad()
-
- if not isinstance(output_tensor, list):
- output_tensor = [output_tensor]
- if not isinstance(output_tensor_grad, list):
- output_tensor_grad = [output_tensor_grad]
-
- # Backward pass.
- if output_tensor_grad[0] is None and config.grad_scale_func is not None:
- output_tensor[0] = config.grad_scale_func(output_tensor[0])
-
- if config.deallocate_pipeline_outputs:
- custom_backward(output_tensor[0], output_tensor_grad[0])
- else:
- torch.autograd.backward(output_tensor[0], grad_tensors=output_tensor_grad[0])
-
- # Collect the grad of the input_tensor.
- input_tensor_grad = [None]
- if input_tensor is not None:
- input_tensor_grad = []
- for x in input_tensor:
- if x is None:
- input_tensor_grad.append(None)
- else:
- input_tensor_grad.append(x.grad)
-
- # Handle single skip connection if it exists (encoder_hidden_state in
- # model with encoder and decoder).
- if (
- parallel_state.get_pipeline_model_parallel_world_size() > 1
- and parallel_state.is_pipeline_stage_after_split()
- and model_type == ModelType.encoder_and_decoder
- ):
- if output_tensor_grad[1] is not None:
- input_tensor_grad[-1].add_(output_tensor_grad[1])
- if unwrap_input_tensor_grad:
- input_tensor_grad = input_tensor_grad[0]
-
- if config.timers is not None:
- config.timers('backward-compute').stop()
-
- return input_tensor_grad
-
-
-def forward_backward_no_pipelining(
- *,
- forward_step_func,
- data_iterator: Union[Iterator, List[Iterator]],
- model: Union[torch.nn.Module, List[torch.nn.Module]],
- num_microbatches: int,
- seq_length: int, # unused
- micro_batch_size: int, # unused
- decoder_seq_length: int = None, # unused
- forward_only: bool = False,
- collect_non_loss_data: bool = False,
-):
- """Run forward and backward passes with no pipeline parallelism
- (no inter-stage communication).
-
- Returns dictionary with losses.
-
-
- See get_forward_backward_func() for argument details
- """
-
- if isinstance(model, list):
- assert len(model) == 1, "non-pipeline-parallel schedule does not support model chunking"
- model = model[0]
- if isinstance(data_iterator, list):
- assert (
- len(data_iterator) == 1
- ), "non-pipeline-parallel schedule does not support model chunking"
- data_iterator = data_iterator[0]
-
- config = get_model_config(model)
- if config.timers is not None:
- config.timers('forward-backward', log_level=1).start(barrier=config.barrier_with_L1_time)
-
- no_sync_func = config.no_sync_func
- if no_sync_func is None:
- no_sync_func = contextlib.nullcontext
-
- model_type = get_model_type(model)
-
- forward_data_store = []
- input_tensor, output_tensor_grad = None, None
- with no_sync_func():
- for i in range(num_microbatches - 1):
- output_tensor = forward_step(
- forward_step_func,
- data_iterator,
- model,
- num_microbatches,
- input_tensor,
- forward_data_store,
- config,
- collect_non_loss_data,
- )
- if not forward_only:
- backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config)
-
- # Run computation for last microbatch out of context handler (want to
- # synchronize gradients).
- output_tensor = forward_step(
- forward_step_func,
- data_iterator,
- model,
- num_microbatches,
- input_tensor,
- forward_data_store,
- config,
- collect_non_loss_data,
- )
-
- if not forward_only:
- backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config)
-
- if config.timers is not None:
- config.timers('forward-backward').stop()
-
- if config.finalize_model_grads_func is not None and not forward_only:
- # Finalize model grads (perform full grad all-reduce / reduce-scatter for
- # data parallelism and layernorm all-reduce for sequence parallelism).
- config.finalize_model_grads_func([model])
-
- return forward_data_store
-
-
-def forward_backward_pipelining_with_interleaving(
- *,
- forward_step_func,
- data_iterator: Union[Iterator, List[Iterator]],
- model: Union[torch.nn.Module, List[torch.nn.Module]],
- num_microbatches: int,
- seq_length: int,
- micro_batch_size: int,
- decoder_seq_length: int = None,
- forward_only: bool = False,
- collect_non_loss_data: bool = False,
-):
- """Run interleaved 1F1B schedule (model split into model chunks), with
- communication between pipeline stages as needed.
-
- Returns dictionary with losses if the last stage, empty dict otherwise."""
- assert isinstance(model, list), "interleaved pipeline parallelism expected model chunking"
- assert all(isinstance(chunk, torch.nn.Module) for chunk in model), "invalid model chunking"
- assert isinstance(
- data_iterator, list
- ), "interleaved pipeline parallelism expected each model chunk to have a data iterator"
-
- config = get_model_config(model[0])
- if config.overlap_p2p_comm and config.batch_p2p_comm:
- raise ValueError("Can not use both overlap_p2p_comm and batch_p2p_comm")
-
- if config.timers is not None:
- config.timers('forward-backward', log_level=1).start(barrier=config.barrier_with_L1_time)
-
- # Disable async grad reductions
- no_sync_func = config.no_sync_func
- if isinstance(no_sync_func, list):
-
- def multi_no_sync():
- stack = contextlib.ExitStack()
- for model_chunk_no_sync_func in config.no_sync_func:
- stack.enter_context(model_chunk_no_sync_func())
- return stack
-
- no_sync_func = multi_no_sync
- if no_sync_func is None:
- no_sync_func = contextlib.nullcontext
- no_sync_context = None
-
- if config.grad_sync_func is not None and not isinstance(config.grad_sync_func, list):
- config.grad_sync_func = [config.grad_sync_func for _ in model]
-
- if config.param_sync_func is not None and not isinstance(config.param_sync_func, list):
- config.param_sync_func = [config.param_sync_func for _ in model]
-
- def disable_grad_sync():
- """Disable asynchronous grad reductions"""
- nonlocal no_sync_context
- if no_sync_context is None:
- no_sync_context = no_sync_func()
- no_sync_context.__enter__()
-
- def enable_grad_sync():
- """Enable asynchronous grad reductions"""
- nonlocal no_sync_context
- if no_sync_context is not None:
- no_sync_context.__exit__(None, None, None)
- no_sync_context = None
-
- disable_grad_sync()
-
- # Model chunk IDs with synchronized grads
- synchronized_model_chunks = set()
-
- input_tensors = [[] for _ in range(len(model))]
- output_tensors = [[] for _ in range(len(model))]
- forward_data_store = []
- if not forward_only:
- output_tensor_grads = [[] for _ in range(len(model))]
-
- pipeline_parallel_size = parallel_state.get_pipeline_model_parallel_world_size()
- pipeline_parallel_rank = parallel_state.get_pipeline_model_parallel_rank()
-
- if num_microbatches % pipeline_parallel_size != 0:
- msg = f'number of microbatches ({num_microbatches}) is not divisible by '
- msg += f'pipeline-model-parallel-size ({pipeline_parallel_size}) '
- msg += 'when using interleaved schedule'
- raise RuntimeError(msg)
-
- model_type = get_model_type(model[0])
- if model_type == ModelType.encoder_and_decoder:
- raise RuntimeError("Interleaving is not supported with an encoder and decoder model.")
-
- if decoder_seq_length is not None and decoder_seq_length != seq_length:
- raise RuntimeError(
- "Interleaving is not supported with a different decoder sequence length."
- )
-
- tensor_shape = [seq_length, micro_batch_size, config.hidden_size]
- if config.sequence_parallel:
- tensor_shape[0] = tensor_shape[0] // parallel_state.get_tensor_model_parallel_world_size()
-
- # Compute number of warmup and remaining microbatches.
- num_model_chunks = len(model)
- total_num_microbatches = num_microbatches * num_model_chunks
- all_warmup_microbatches = False
- if forward_only:
- num_warmup_microbatches = total_num_microbatches
- else:
- # Run all forward passes and then all backward passes if number of
- # microbatches is just the number of pipeline stages.
- # Otherwise, perform (num_model_chunks-1)*pipeline_parallel_size on
- # all workers, followed by more microbatches after depending on
- # stage ID (more forward passes for earlier stages, later stages can
- # immediately start with 1F1B).
- if num_microbatches == pipeline_parallel_size:
- num_warmup_microbatches = total_num_microbatches
- all_warmup_microbatches = True
- else:
- num_warmup_microbatches = (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2
- num_warmup_microbatches += (num_model_chunks - 1) * pipeline_parallel_size
- num_warmup_microbatches = min(num_warmup_microbatches, total_num_microbatches)
- num_microbatches_remaining = total_num_microbatches - num_warmup_microbatches
-
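A worked example of the warmup bookkeeping above, with illustrative sizes: `pipeline_parallel_size = 4`, `num_model_chunks = 2`, `num_microbatches = 8`, and `pipeline_parallel_rank = 1`.

```python
pipeline_parallel_size, num_model_chunks, num_microbatches = 4, 2, 8
pipeline_parallel_rank = 1

total_num_microbatches = num_microbatches * num_model_chunks                          # 16
num_warmup_microbatches = (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2   # 4
num_warmup_microbatches += (num_model_chunks - 1) * pipeline_parallel_size            # +4 -> 8
num_warmup_microbatches = min(num_warmup_microbatches, total_num_microbatches)        # 8
num_microbatches_remaining = total_num_microbatches - num_warmup_microbatches         # 8
```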
- # Checkpoint the activations of partial Transformer layers in a number of micro-batches
- # within the maximum outstanding micro-batch backpropagations.
- # Micro-batches with the ids less than 'num_microbatches_with_partial_activation_checkpoints'
- # checkpoint partial Transformer layers (or skip checkpointing) and
- # the rest of micro-batches within a window of micro-batches checkpoint
- # all Transformer layers. The window of micro-batches is set by the maximum
- # outstanding backpropagations and becomes smaller at later pipeline stages.
-    # Please refer to Appendix C in https://arxiv.org/pdf/2205.05198.pdf
- max_outstanding_backprops = None
- if config.num_microbatches_with_partial_activation_checkpoints is not None:
- max_outstanding_backprops = num_warmup_microbatches + 1
-
- # Synchronize params for first two model chunks
- if config.param_sync_func is not None:
- config.param_sync_func[0](model[0].parameters())
- config.param_sync_func[1](model[1].parameters())
-
- def get_model_chunk_id(microbatch_id, forward):
- """Helper method to get the model chunk ID given the iteration number."""
- microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks)
- model_chunk_id = microbatch_id_in_group // pipeline_parallel_size
- if not forward:
- model_chunk_id = num_model_chunks - model_chunk_id - 1
- return model_chunk_id
-
- def is_first_microbatch_for_model_chunk(microbatch_id: int) -> bool:
- """Check if an iteration is the first for a model chunk."""
- microbatch_group_size = pipeline_parallel_size * num_model_chunks
- num_microbatch_groups = total_num_microbatches // microbatch_group_size
- microbatch_group_id = microbatch_id // microbatch_group_size
- microbatch_id_in_group = microbatch_id % microbatch_group_size
- if microbatch_group_id == 0:
- return microbatch_id_in_group % pipeline_parallel_size == 0
- else:
- return False
-
- def is_last_microbatch_for_model_chunk(microbatch_id: int) -> bool:
- """Check if an iteration is the last for a model chunk."""
- microbatch_group_size = pipeline_parallel_size * num_model_chunks
- num_microbatch_groups = total_num_microbatches // microbatch_group_size
- microbatch_group_id = microbatch_id // microbatch_group_size
- microbatch_id_in_group = microbatch_id % microbatch_group_size
- if microbatch_group_id == num_microbatch_groups - 1:
- return microbatch_id_in_group % pipeline_parallel_size == pipeline_parallel_size - 1
- else:
- return False
-
- def forward_step_helper(microbatch_id, checkpoint_activations_microbatch):
- """Helper method to run forward step with model split into chunks
- (run set_virtual_pipeline_model_parallel_rank() before calling
- forward_step())."""
- model_chunk_id = get_model_chunk_id(microbatch_id, forward=True)
- parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id)
-
- # launch param synchronization for next model chunk
- # Note: Asynchronous communication tends to slow down compute.
- # To reduce idling from mismatched microbatch times, we launch
- # asynchronous communication at the same time across the
- # pipeline-parallel group.
- if config.param_sync_func is not None:
- param_sync_microbatch_id = microbatch_id + pipeline_parallel_rank
- if (
- param_sync_microbatch_id < total_num_microbatches
- and is_first_microbatch_for_model_chunk(param_sync_microbatch_id)
- ):
- param_sync_chunk_id = get_model_chunk_id(param_sync_microbatch_id, forward=True) + 1
- if 1 < param_sync_chunk_id < num_model_chunks:
- config.param_sync_func[param_sync_chunk_id](
- model[param_sync_chunk_id].parameters()
- )
-
- # forward step
- if parallel_state.is_pipeline_first_stage():
- if len(input_tensors[model_chunk_id]) == len(output_tensors[model_chunk_id]):
- input_tensors[model_chunk_id].append(None)
- input_tensor = input_tensors[model_chunk_id][-1]
- output_tensor = forward_step(
- forward_step_func,
- data_iterator[model_chunk_id],
- model[model_chunk_id],
- num_microbatches,
- input_tensor,
- forward_data_store,
- config,
- collect_non_loss_data,
- checkpoint_activations_microbatch,
- )
- output_tensors[model_chunk_id].append(output_tensor)
-
- # if forward-only, no need to save tensors for a backward pass
- if forward_only:
- input_tensors[model_chunk_id].pop()
- output_tensors[model_chunk_id].pop()
-
- return output_tensor
-
- def backward_step_helper(microbatch_id):
- """Helper method to run backward step with model split into chunks
- (run set_virtual_pipeline_model_parallel_rank() before calling
- backward_step())."""
- model_chunk_id = get_model_chunk_id(microbatch_id, forward=False)
- parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id)
-
- # launch grad synchronization (default)
- if config.grad_sync_func is None and is_last_microbatch_for_model_chunk(microbatch_id):
- enable_grad_sync()
- synchronized_model_chunks.add(model_chunk_id)
-
- if parallel_state.is_pipeline_last_stage():
- if len(output_tensor_grads[model_chunk_id]) == 0:
- output_tensor_grads[model_chunk_id].append(None)
- input_tensor = input_tensors[model_chunk_id].pop(0)
- output_tensor = output_tensors[model_chunk_id].pop(0)
- output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0)
- input_tensor_grad = backward_step(
- input_tensor, output_tensor, output_tensor_grad, model_type, config
- )
-
- # launch grad synchronization (custom grad sync)
- # Note: Asynchronous communication tends to slow down compute.
- # To reduce idling from mismatched microbatch times, we launch
- # asynchronous communication at the same time across the
- # pipeline-parallel group.
- if config.grad_sync_func is not None:
- grad_sync_microbatch_id = microbatch_id - pipeline_parallel_rank
- if grad_sync_microbatch_id >= 0 and is_last_microbatch_for_model_chunk(
- grad_sync_microbatch_id
- ):
- grad_sync_chunk_id = get_model_chunk_id(grad_sync_microbatch_id, forward=False)
- enable_grad_sync()
- config.grad_sync_func[grad_sync_chunk_id](model[grad_sync_chunk_id].parameters())
- synchronized_model_chunks.add(grad_sync_chunk_id)
- disable_grad_sync()
-
- return input_tensor_grad
-
- # Run warmup forward passes.
- parallel_state.set_virtual_pipeline_model_parallel_rank(0)
- input_tensors[0].append(p2p_communication.recv_forward(tensor_shape, config))
-
- fwd_wait_handles = None
- bwd_wait_handles = None
-
- for k in range(num_warmup_microbatches):
-
- if fwd_wait_handles is not None:
- for req in fwd_wait_handles:
- req.wait()
-
- # Decide to checkpoint all layers' activations of the current micro-batch
- if max_outstanding_backprops is not None:
- checkpoint_activations_microbatch = (
- k % max_outstanding_backprops
- >= config.num_microbatches_with_partial_activation_checkpoints
- )
- else:
- checkpoint_activations_microbatch = None
-
- output_tensor = forward_step_helper(k, checkpoint_activations_microbatch)
-
- # Determine if tensor should be received from previous stage.
- next_forward_model_chunk_id = get_model_chunk_id(k + 1, forward=True)
- recv_prev = True
- if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
- if next_forward_model_chunk_id == 0:
- recv_prev = False
- if k == (total_num_microbatches - 1):
- recv_prev = False
-
- # Don't send tensor downstream if on last stage.
- if parallel_state.is_pipeline_last_stage():
- output_tensor = None
-
- # Send and receive tensors as appropriate (send tensors computed
- # in this iteration; receive tensors for next iteration).
- if not config.overlap_p2p_comm:
- if (
- k == (num_warmup_microbatches - 1)
- and not forward_only
- and not all_warmup_microbatches
- ):
- input_tensor_grad = None
- recv_next = True
- if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
- recv_next = False
- (
- input_tensor,
- output_tensor_grad,
- ) = p2p_communication.send_forward_backward_recv_forward_backward(
- output_tensor,
- input_tensor_grad,
- recv_prev=recv_prev,
- recv_next=recv_next,
- tensor_shape=tensor_shape,
- config=config,
- )
- output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad)
- else:
- input_tensor = p2p_communication.send_forward_recv_forward(
- output_tensor, recv_prev=recv_prev, tensor_shape=tensor_shape, config=config
- )
- input_tensors[next_forward_model_chunk_id].append(input_tensor)
- else:
- input_tensor, fwd_wait_handles = p2p_communication.send_forward_recv_forward(
- output_tensor,
- recv_prev=recv_prev,
- tensor_shape=tensor_shape,
- config=config,
- overlap_p2p_comm=True,
- )
-
- if (
- k == (num_warmup_microbatches - 1)
- and not forward_only
- and not all_warmup_microbatches
- ):
- input_tensor_grad = None
- recv_next = True
- if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
- recv_next = False
-
- (
- output_tensor_grad,
- bwd_wait_handles,
- ) = p2p_communication.send_backward_recv_backward(
- input_tensor_grad,
- recv_next=recv_next,
- tensor_shape=tensor_shape,
- config=config,
- overlap_p2p_comm=True,
- )
-
- output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad)
- input_tensors[next_forward_model_chunk_id].append(input_tensor)
-
- deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
-
- # Run 1F1B in steady state.
- for k in range(num_microbatches_remaining):
- # Forward pass.
- forward_k = k + num_warmup_microbatches
-
- # Decide to checkpoint all layers' activations of the current micro-batch
- if max_outstanding_backprops is not None:
- checkpoint_activations_microbatch = (
- forward_k % max_outstanding_backprops
- >= config.num_microbatches_with_partial_activation_checkpoints
- )
- else:
- checkpoint_activations_microbatch = None
-
- if config.overlap_p2p_comm:
- if fwd_wait_handles is not None:
- for req in fwd_wait_handles:
- req.wait()
-
- deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
-
- output_tensor = forward_step_helper(forward_k, checkpoint_activations_microbatch)
-
- # Determine if current stage has anything to send in either direction,
- # otherwise set tensor to None.
- forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
- parallel_state.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id)
-
- # Last virtual stage no activation tensor to send
- if parallel_state.is_pipeline_last_stage():
- output_tensor = None
-
- # Determine if peers are sending, and where in data structure to put
- # received tensors.
- recv_prev = True
- if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
- # First stage is ahead of last stage by (pipeline_parallel_size - 1).
- next_forward_model_chunk_id = get_model_chunk_id(
- forward_k - (pipeline_parallel_size - 1), forward=True
- )
- if next_forward_model_chunk_id == (num_model_chunks - 1):
- recv_prev = False
- next_forward_model_chunk_id += 1
- else:
- next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1, forward=True)
-
- # If last iteration, don't receive; we already received one extra
- # before the start of the for loop.
- if k == (num_microbatches_remaining - 1):
- recv_prev = False
-
- # Send activation tensor to the next stage and receive activation tensor from the
- # previous stage
- input_tensor, fwd_wait_handles = p2p_communication.send_forward_recv_forward(
- output_tensor,
- recv_prev=recv_prev,
- tensor_shape=tensor_shape,
- config=config,
- overlap_p2p_comm=True,
- )
- # assert fwd_wait_handles is not None
-
- if bwd_wait_handles is not None:
- for req in bwd_wait_handles:
- req.wait()
-
- # Backward pass.
- backward_k = k
- input_tensor_grad = backward_step_helper(backward_k)
-
- backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
- parallel_state.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
-
- # First virtual stage no activation gradient tensor to send
- if parallel_state.is_pipeline_first_stage():
- input_tensor_grad = None
-
- # Determine if the current virtual stage has an activation gradient tensor to receive
- recv_next = True
- if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
- # Last stage is ahead of first stage by (pipeline_parallel_size - 1).
- next_backward_model_chunk_id = get_model_chunk_id(
- backward_k - (pipeline_parallel_size - 1), forward=False
- )
- if next_backward_model_chunk_id == 0:
- recv_next = False
- next_backward_model_chunk_id -= 1
- else:
- next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1, forward=False)
-
- output_tensor_grad, bwd_wait_handles = p2p_communication.send_backward_recv_backward(
- input_tensor_grad,
- recv_next=recv_next,
- tensor_shape=tensor_shape,
- config=config,
- overlap_p2p_comm=True,
- )
-
- else: # no p2p overlap
- output_tensor = forward_step_helper(forward_k, checkpoint_activations_microbatch)
-
- # Backward pass.
- backward_k = k
- input_tensor_grad = backward_step_helper(backward_k)
-
- # Send output_tensor and input_tensor_grad, receive input_tensor
- # and output_tensor_grad.
-
- # Determine if current stage has anything to send in either direction,
- # otherwise set tensor to None.
- forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
- parallel_state.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id)
- if parallel_state.is_pipeline_last_stage():
- output_tensor = None
-
- backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
- parallel_state.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
- if parallel_state.is_pipeline_first_stage():
- input_tensor_grad = None
-
- # Determine if peers are sending, and where in data structure to put
- # received tensors.
- recv_prev = True
- if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
- # First stage is ahead of last stage by (pipeline_parallel_size - 1).
- next_forward_model_chunk_id = get_model_chunk_id(
- forward_k - (pipeline_parallel_size - 1), forward=True
- )
- if next_forward_model_chunk_id == (num_model_chunks - 1):
- recv_prev = False
- next_forward_model_chunk_id += 1
- else:
- next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1, forward=True)
-
- recv_next = True
- if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
- # Last stage is ahead of first stage by (pipeline_parallel_size - 1).
- next_backward_model_chunk_id = get_model_chunk_id(
- backward_k - (pipeline_parallel_size - 1), forward=False
- )
- if next_backward_model_chunk_id == 0:
- recv_next = False
- next_backward_model_chunk_id -= 1
- else:
- next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1, forward=False)
-
- # If last iteration, don't receive; we already received one extra
- # before the start of the for loop.
- if k == (num_microbatches_remaining - 1):
- recv_prev = False
-
- # Communicate tensors.
- (
- input_tensor,
- output_tensor_grad,
- ) = p2p_communication.send_forward_backward_recv_forward_backward(
- output_tensor,
- input_tensor_grad,
- recv_prev=recv_prev,
- recv_next=recv_next,
- tensor_shape=tensor_shape,
- config=config,
- )
- deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
-
- # Put input_tensor and output_tensor_grad in data structures in the
- # right location.
- if recv_prev:
- input_tensors[next_forward_model_chunk_id].append(input_tensor)
- if recv_next:
- output_tensor_grads[next_backward_model_chunk_id].append(output_tensor_grad)
-
- deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
-
- # Run cooldown backward passes (flush out pipeline).
- if not forward_only:
- if config.overlap_p2p_comm and bwd_wait_handles is not None:
- for wait_handle in bwd_wait_handles:
- wait_handle.wait()
-
- if all_warmup_microbatches:
- output_tensor_grads[num_model_chunks - 1].append(
- p2p_communication.recv_backward(tensor_shape, config=config)
- )
- for k in range(num_microbatches_remaining, total_num_microbatches):
- input_tensor_grad = backward_step_helper(k)
- next_backward_model_chunk_id = get_model_chunk_id(k + 1, forward=False)
- recv_next = True
- if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
- if next_backward_model_chunk_id == (num_model_chunks - 1):
- recv_next = False
- if k == (total_num_microbatches - 1):
- recv_next = False
- output_tensor_grads[next_backward_model_chunk_id].append(
- p2p_communication.send_backward_recv_backward(
- input_tensor_grad, recv_next=recv_next, tensor_shape=tensor_shape, config=config
- )
- )
-
- # Launch any remaining grad reductions.
- enable_grad_sync()
- if config.grad_sync_func is not None:
- for model_chunk_id in range(num_model_chunks):
- if model_chunk_id not in synchronized_model_chunks:
- config.grad_sync_func[model_chunk_id](model[model_chunk_id].parameters())
- synchronized_model_chunks.add(model_chunk_id)
-
- if config.timers is not None:
- config.timers('forward-backward').stop()
-
- if config.finalize_model_grads_func is not None and not forward_only:
- # Finalize model grads (perform full grad all-reduce / reduce-scatter for
- # data parallelism, layernorm all-reduce for sequence parallelism, and
- # embedding all-reduce for pipeline parallelism).
- config.finalize_model_grads_func(model)
-
- return forward_data_store
-
-
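The steady-state loop above pairs each forward microbatch `forward_k` with a backward microbatch `backward_k` and routes both through `get_model_chunk_id`. For reference, the sketch below restates the usual chunk-assignment rule as standalone Python (the real helper is a closure defined earlier in this file; passing `pipeline_parallel_size` and `num_model_chunks` explicitly here is only for illustration, assuming the upstream definition).

```python
def get_model_chunk_id(microbatch_id, forward, pipeline_parallel_size, num_model_chunks):
    """Map a microbatch index to the virtual-pipeline model chunk that processes it."""
    microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks)
    model_chunk_id = microbatch_id_in_group // pipeline_parallel_size
    if not forward:
        # Backward passes drain the chunks in reverse order.
        model_chunk_id = num_model_chunks - model_chunk_id - 1
    return model_chunk_id

# With 2 pipeline stages and 2 chunks per stage, forward microbatches 0..7 map to:
print([get_model_chunk_id(k, True, 2, 2) for k in range(8)])   # [0, 0, 1, 1, 0, 0, 1, 1]
```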
-def get_tensor_shapes(
- *,
- rank: int,
- model_type: ModelType,
- seq_length: int,
- micro_batch_size: int,
- decoder_seq_length: int,
- config,
-):
- # Determine right tensor sizes (based on position of rank with respect to split
- # rank) and model size.
- # Send two tensors if model is T5 and rank is in decoder stage:
- # first tensor is decoder (pre-transpose),
- # second tensor is encoder (post-transpose).
- # If model is T5 and rank is at the boundary:
- # send one tensor (post-transpose from encoder).
- # Otherwise, send one tensor (pre-transpose).
- tensor_shapes = []
-
- if config.sequence_parallel:
- seq_length = seq_length // parallel_state.get_tensor_model_parallel_world_size()
- if model_type == ModelType.encoder_and_decoder:
- decoder_seq_length = (
- decoder_seq_length // parallel_state.get_tensor_model_parallel_world_size()
- )
-
- if model_type == ModelType.encoder_and_decoder:
- if parallel_state.is_pipeline_stage_before_split(rank):
- tensor_shapes.append((seq_length, micro_batch_size, config.hidden_size))
- else:
- tensor_shapes.append((decoder_seq_length, micro_batch_size, config.hidden_size))
- tensor_shapes.append((seq_length, micro_batch_size, config.hidden_size))
- else:
- tensor_shapes.append((seq_length, micro_batch_size, config.hidden_size))
- return tensor_shapes
-
-
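`get_tensor_shapes` is what both schedules use to size their p2p buffers. A tiny worked example of the decoder-only branch with sequence parallelism enabled (all sizes below are hypothetical):

```python
# Hypothetical sizes; with sequence parallelism the sequence dimension is pre-divided
# by the tensor-parallel world size, so each p2p transfer carries only the local shard.
seq_length, micro_batch_size, hidden_size = 4096, 2, 5120
tensor_model_parallel_world_size = 8

seq_per_rank = seq_length // tensor_model_parallel_world_size
tensor_shapes = [(seq_per_rank, micro_batch_size, hidden_size)]
print(tensor_shapes)   # [(512, 2, 5120)]
```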
-def recv_forward(tensor_shapes, config):
- input_tensors = []
- for tensor_shape in tensor_shapes:
- if tensor_shape is None:
- input_tensors.append(None)
- else:
- input_tensors.append(p2p_communication.recv_forward(tensor_shape, config))
- return input_tensors
-
-
-def recv_backward(tensor_shapes, config):
- output_tensor_grads = []
- for tensor_shape in tensor_shapes:
- if tensor_shape is None:
- output_tensor_grads.append(None)
- else:
- output_tensor_grads.append(p2p_communication.recv_backward(tensor_shape, config))
- return output_tensor_grads
-
-
-def send_forward(output_tensors, tensor_shapes, config):
- if not isinstance(output_tensors, list):
- output_tensors = [output_tensors]
- for (output_tensor, tensor_shape) in zip(output_tensors, tensor_shapes):
- if tensor_shape is None:
- continue
- p2p_communication.send_forward(output_tensor, config)
-
-
-def send_backward(input_tensor_grads, tensor_shapes, config):
- if not isinstance(input_tensor_grads, list):
- input_tensor_grads = [input_tensor_grads]
- for (input_tensor_grad, tensor_shape) in zip(input_tensor_grads, tensor_shapes):
- if tensor_shape is None:
- continue
- p2p_communication.send_backward(input_tensor_grad, config)
-
-
-def send_forward_recv_backward(output_tensors, tensor_shapes, config):
- if not isinstance(output_tensors, list):
- output_tensors = [output_tensors]
- output_tensor_grads = []
- for (output_tensor, tensor_shape) in zip(output_tensors, tensor_shapes):
- if tensor_shape is None:
- output_tensor_grads.append(None)
- continue
- output_tensor_grad = p2p_communication.send_forward_recv_backward(
- output_tensor, tensor_shape, config
- )
- output_tensor_grads.append(output_tensor_grad)
- return output_tensor_grads
-
-
-def send_backward_recv_forward(input_tensor_grads, tensor_shapes, config):
- if not isinstance(input_tensor_grads, list):
- input_tensor_grads = [input_tensor_grads]
- input_tensors = []
- for (input_tensor_grad, tensor_shape) in zip(input_tensor_grads, tensor_shapes):
- if tensor_shape is None:
- input_tensors.append(None)
- continue
- input_tensor = p2p_communication.send_backward_recv_forward(
- input_tensor_grad, tensor_shape, config
- )
- input_tensors.append(input_tensor)
- return input_tensors
-
-
-def forward_backward_pipelining_without_interleaving(
- *,
- forward_step_func,
- data_iterator: Union[Iterator, List[Iterator]],
- model: Union[torch.nn.Module, List[torch.nn.Module]],
- num_microbatches: int,
- seq_length: int,
- micro_batch_size: int,
- decoder_seq_length: int = None,
- forward_only: bool = False,
- collect_non_loss_data: bool = False,
-):
- """Run non-interleaved 1F1B schedule, with communication between pipeline
- stages.
-
- Returns dictionary with losses if the last stage, empty dict otherwise."""
-
- if isinstance(model, list):
- assert (
- len(model) == 1
- ), "non-interleaved pipeline parallelism does not support model chunking"
- model = model[0]
- if isinstance(data_iterator, list):
- assert (
- len(data_iterator) == 1
- ), "non-pipeline-parallel schedule does not support model chunking"
- data_iterator = data_iterator[0]
-
- config = get_model_config(model)
- if config.overlap_p2p_comm:
- raise ValueError(
- "Non-interleaved pipeline parallelism does not support overlapping p2p communication"
- )
-
- if config.timers is not None:
- config.timers('forward-backward', log_level=1).start(barrier=config.barrier_with_L1_time)
-
- # Disable async grad reductions
- no_sync_func = config.no_sync_func
- if no_sync_func is None:
- no_sync_func = contextlib.nullcontext
- no_sync_context = None
-
- def disable_grad_sync():
- """Disable asynchronous grad reductions"""
- nonlocal no_sync_context
- if no_sync_context is None:
- no_sync_context = no_sync_func()
- no_sync_context.__enter__()
-
- def enable_grad_sync():
- """Enable asynchronous grad reductions"""
- nonlocal no_sync_context
- if no_sync_context is not None:
- no_sync_context.__exit__(None, None, None)
- no_sync_context = None
-
- disable_grad_sync()
-
- # Compute number of warmup microbatches.
- num_warmup_microbatches = (
- parallel_state.get_pipeline_model_parallel_world_size()
- - parallel_state.get_pipeline_model_parallel_rank()
- - 1
- )
- num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches)
- num_microbatches_remaining = num_microbatches - num_warmup_microbatches
-
- # Checkpoint the activations of partial Transformer layers in a number of micro-batches
- # within the maximum outstanding micro-batch backpropagations.
- # Micro-batches with the ids less than 'num_microbatches_with_partial_activation_checkpoints'
- # checkpoint partial Transformer layers (or skip checkpointing) and
- # the rest of micro-batches within a window of micro-batches checkpoint
- # all Transformer layers. The window of micro-batches is set by the maximum
- # outstanding backpropagations and becomes smaller at later pipeline stages.
- # Please refer to Appendix C in https://arxiv.org/pdf/2205.05198.pdf

- max_outstanding_backprops = None
- if config.num_microbatches_with_partial_activation_checkpoints is not None:
- max_outstanding_backprops = num_warmup_microbatches + 1
-
- model_type = get_model_type(model)
-
- rank = parallel_state.get_pipeline_model_parallel_rank()
- recv_tensor_shapes = get_tensor_shapes(
- rank=rank - 1,
- model_type=model_type,
- seq_length=seq_length,
- micro_batch_size=micro_batch_size,
- decoder_seq_length=decoder_seq_length,
- config=config,
- )
- send_tensor_shapes = get_tensor_shapes(
- rank=rank,
- model_type=model_type,
- seq_length=seq_length,
- micro_batch_size=micro_batch_size,
- decoder_seq_length=decoder_seq_length,
- config=config,
- )
-
- # Input, output tensors only need to be saved when doing backward passes
- input_tensors = None
- output_tensors = None
- if not forward_only:
- input_tensors = []
- output_tensors = []
- forward_data_store = []
-
- # Run warmup forward passes.
- for i in range(num_warmup_microbatches):
- # Decide to checkpoint all layers' activations of the current micro-batch
- if max_outstanding_backprops is not None:
- checkpoint_activations_microbatch = (
- i % max_outstanding_backprops
- >= config.num_microbatches_with_partial_activation_checkpoints
- )
- else:
- checkpoint_activations_microbatch = None
-
- input_tensor = recv_forward(recv_tensor_shapes, config)
- output_tensor = forward_step(
- forward_step_func,
- data_iterator,
- model,
- num_microbatches,
- input_tensor,
- forward_data_store,
- config,
- collect_non_loss_data,
- checkpoint_activations_microbatch,
- )
- send_forward(output_tensor, send_tensor_shapes, config)
-
- if not forward_only:
- input_tensors.append(input_tensor)
- output_tensors.append(output_tensor)
- deallocate_output_tensor(output_tensor[0], config.deallocate_pipeline_outputs)
-
- # Before running 1F1B, need to receive first forward tensor.
- # If all microbatches are run in warmup / cooldown phase, then no need to
- # receive this tensor here.
- if num_microbatches_remaining > 0:
- input_tensor = recv_forward(recv_tensor_shapes, config)
-
- # Run 1F1B in steady state.
- for i in range(num_microbatches_remaining):
- last_iteration = i == (num_microbatches_remaining - 1)
-
- # Decide to checkpoint all layers' activations of the current micro-batch
- if max_outstanding_backprops is not None:
- checkpoint_activations_microbatch = (
- (i + num_warmup_microbatches) % max_outstanding_backprops
- ) >= config.num_microbatches_with_partial_activation_checkpoints
- else:
- checkpoint_activations_microbatch = None
-
- output_tensor = forward_step(
- forward_step_func,
- data_iterator,
- model,
- num_microbatches,
- input_tensor,
- forward_data_store,
- config,
- collect_non_loss_data,
- checkpoint_activations_microbatch,
- )
-
- if forward_only:
- send_forward(output_tensor, send_tensor_shapes, config)
-
- if not last_iteration:
- input_tensor = recv_forward(recv_tensor_shapes, config)
-
- else:
- output_tensor_grad = send_forward_recv_backward(
- output_tensor, send_tensor_shapes, config
- )
-
- # Add input_tensor and output_tensor to end of list.
- input_tensors.append(input_tensor)
- output_tensors.append(output_tensor)
- deallocate_output_tensor(output_tensor[0], config.deallocate_pipeline_outputs)
-
- # Pop input_tensor and output_tensor from the start of the list for
- # the backward pass.
- input_tensor = input_tensors.pop(0)
- output_tensor = output_tensors.pop(0)
-
- # Enable grad sync for the last microbatch in the batch if the full
- # backward pass completes in the 1F1B stage.
- if num_warmup_microbatches == 0 and last_iteration:
- if config.grad_sync_func is None or rank == 0:
- enable_grad_sync()
-
- input_tensor_grad = backward_step(
- input_tensor, output_tensor, output_tensor_grad, model_type, config
- )
-
- if last_iteration:
- input_tensor = None
- send_backward(input_tensor_grad, recv_tensor_shapes, config)
- else:
- input_tensor = send_backward_recv_forward(
- input_tensor_grad, recv_tensor_shapes, config
- )
-
- # Run cooldown backward passes.
- if not forward_only:
- for i in range(num_warmup_microbatches):
-
- # Enable async grad reduction in the last backward pass
- # Note: If grad sync function is provided, only enable
- # async grad reduction in first pipeline stage. Other
- # pipeline stages do grad reduction during pipeline
- # bubble.
- if i == num_warmup_microbatches - 1:
- if config.grad_sync_func is None or rank == 0:
- enable_grad_sync()
-
- input_tensor = input_tensors.pop(0)
- output_tensor = output_tensors.pop(0)
-
- output_tensor_grad = recv_backward(send_tensor_shapes, config)
-
- input_tensor_grad = backward_step(
- input_tensor, output_tensor, output_tensor_grad, model_type, config
- )
-
- send_backward(input_tensor_grad, recv_tensor_shapes, config)
-
- # Launch any remaining grad reductions.
- if no_sync_context is not None:
- enable_grad_sync()
- if config.grad_sync_func is not None:
- config.grad_sync_func(model.parameters())
-
- if config.timers is not None:
- config.timers('forward-backward').stop()
-
- if config.finalize_model_grads_func is not None and not forward_only:
- # Finalize model grads (perform full grad all-reduce / reduce-scatter for
- # data parallelism, layernorm all-reduce for sequence parallelism, and
- # embedding all-reduce for pipeline parallelism).
- config.finalize_model_grads_func([model])
-
- return forward_data_store
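The warmup computation near the top of `forward_backward_pipelining_without_interleaving` is the heart of the 1F1B schedule: earlier stages run more warmup forwards so that, once the steady state starts, every stage alternates one forward with one backward. A standalone restatement of the same formula with a small example:

```python
def num_warmup_microbatches(pipeline_parallel_world_size, rank, num_microbatches):
    # Stage `rank` runs (world_size - rank - 1) warmup forwards before entering 1F1B,
    # capped by the total number of microbatches, exactly as computed above.
    return min(pipeline_parallel_world_size - rank - 1, num_microbatches)

# With 4 pipeline stages and 8 microbatches per global batch:
print([num_warmup_microbatches(4, r, 8) for r in range(4)])   # [3, 2, 1, 0]
```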
diff --git a/megatron/core/requirements.txt b/megatron/core/requirements.txt
deleted file mode 100644
index 08ed5eeb4b9f080b780db7d3e0af6712866c0493..0000000000000000000000000000000000000000
--- a/megatron/core/requirements.txt
+++ /dev/null
@@ -1 +0,0 @@
-torch
\ No newline at end of file
diff --git a/megatron/core/tensor_parallel/__init__.py b/megatron/core/tensor_parallel/__init__.py
deleted file mode 100644
index c8040e9e84305bab5577befd73691a019130bd1c..0000000000000000000000000000000000000000
--- a/megatron/core/tensor_parallel/__init__.py
+++ /dev/null
@@ -1,65 +0,0 @@
-from .cross_entropy import vocab_parallel_cross_entropy
-from .data import broadcast_data
-from .layers import (
- ColumnParallelLinear,
- RowParallelLinear,
- VocabParallelEmbedding,
- copy_tensor_model_parallel_attributes,
- linear_with_grad_accumulation_and_async_allreduce,
- param_is_not_tensor_parallel_duplicate,
- set_defaults_if_not_set_tensor_model_parallel_attributes,
- set_tensor_model_parallel_attributes,
-)
-from .mappings import (
- copy_to_tensor_model_parallel_region,
- gather_from_sequence_parallel_region,
- gather_from_sequence_parallel_region_to_moe,
- gather_from_tensor_model_parallel_region,
- reduce_scatter_to_sequence_parallel_region_from_moe,
- scatter_to_sequence_parallel_region,
- scatter_to_tensor_model_parallel_region,
-)
-from .random import (
- checkpoint,
- get_cuda_rng_tracker,
- get_data_parallel_rng_tracker_name,
- model_parallel_cuda_manual_seed,
-)
-from .utils import (
- gather_split_1d_tensor,
- split_tensor_along_last_dim,
- split_tensor_into_1d_equal_chunks,
-)
-
-__all__ = [
- # cross_entropy.py
- "vocab_parallel_cross_entropy",
- # data.py
- "broadcast_data",
- # layers.py
- "ColumnParallelLinear",
- "RowParallelLinear",
- "VocabParallelEmbedding",
- "set_tensor_model_parallel_attributes",
- "set_defaults_if_not_set_tensor_model_parallel_attributes",
- "copy_tensor_model_parallel_attributes",
- "param_is_not_tensor_parallel_duplicate",
- "linear_with_grad_accumulation_and_async_allreduce",
- # mappings.py
- "copy_to_tensor_model_parallel_region",
- "gather_from_tensor_model_parallel_region",
- "gather_from_sequence_parallel_region",
- # "reduce_from_tensor_model_parallel_region",
- "scatter_to_tensor_model_parallel_region",
- "scatter_to_sequence_parallel_region",
- # random.py
- "checkpoint",
- "get_cuda_rng_tracker",
- "model_parallel_cuda_manual_seed",
- # utils.py
- "split_tensor_along_last_dim",
- "split_tensor_into_1d_equal_chunks",
- "gather_split_1d_tensor",
- "gather_from_sequence_parallel_region_to_moe",
- "reduce_scatter_to_sequence_parallel_region_from_moe",
-]
diff --git a/megatron/core/tensor_parallel/cross_entropy.py b/megatron/core/tensor_parallel/cross_entropy.py
deleted file mode 100644
index 645fd1ea0c3673299f354c610c36e27e3345c0bc..0000000000000000000000000000000000000000
--- a/megatron/core/tensor_parallel/cross_entropy.py
+++ /dev/null
@@ -1,142 +0,0 @@
-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-
-import torch
-
-from megatron.core.parallel_state import (
- get_tensor_model_parallel_group,
- get_tensor_model_parallel_rank,
- get_tensor_model_parallel_world_size,
-)
-
-from .utils import VocabUtility
-
-
-class _VocabParallelCrossEntropy(torch.autograd.Function):
- @staticmethod
- def forward(ctx, vocab_parallel_logits, target, label_smoothing=0.0):
-
- # Maximum value along vocab dimension across all GPUs.
- logits_max = torch.max(vocab_parallel_logits, dim=-1)[0]
- torch.distributed.all_reduce(
- logits_max, op=torch.distributed.ReduceOp.MAX, group=get_tensor_model_parallel_group()
- )
- # Subtract the maximum value.
- vocab_parallel_logits = vocab_parallel_logits - logits_max.unsqueeze(dim=-1)
-
- # Get the partition's vocab indices
- get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size
- partition_vocab_size = vocab_parallel_logits.size()[-1]
- rank = get_tensor_model_parallel_rank()
- world_size = get_tensor_model_parallel_world_size()
- vocab_start_index, vocab_end_index = get_vocab_range(partition_vocab_size, rank, world_size)
-
- # Create a mask of valid vocab ids (1 means it needs to be masked).
- target_mask = (target < vocab_start_index) | (target >= vocab_end_index)
- masked_target = target.clone() - vocab_start_index
- masked_target[target_mask] = 0
-
- # Get predicted-logits = logits[target].
- # For simplicity, we convert logits to a 2-D tensor with size
- # [*, partition-vocab-size] and target to a 1-D tensor of size [*].
- logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size)
- masked_target_1d = masked_target.view(-1)
- arange_1d = torch.arange(start=0, end=logits_2d.size()[0], device=logits_2d.device)
- predicted_logits_1d = logits_2d[arange_1d, masked_target_1d]
- predicted_logits_1d = predicted_logits_1d.clone().contiguous()
- predicted_logits = predicted_logits_1d.view_as(target)
- predicted_logits[target_mask] = 0.0
- # All reduce is needed to get the chunks from other GPUs.
- torch.distributed.all_reduce(
- predicted_logits,
- op=torch.distributed.ReduceOp.SUM,
- group=get_tensor_model_parallel_group(),
- )
-
- # Sum of exponential of logits along vocab dimension across all GPUs.
- exp_logits = vocab_parallel_logits
- torch.exp(vocab_parallel_logits, out=exp_logits)
- sum_exp_logits = exp_logits.sum(dim=-1)
- torch.distributed.all_reduce(
- sum_exp_logits,
- op=torch.distributed.ReduceOp.SUM,
- group=get_tensor_model_parallel_group(),
- )
-
- # Loss = log(sum(exp(logits))) - predicted-logit.
- loss = torch.log(sum_exp_logits) - predicted_logits
-
- # Normalize and optionally smooth logits
- exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
-
- vocab_size = exp_logits.size(-1)
- if label_smoothing > 0:
- """
- We'd like to assign 1 / (K - 1) probability mass to every index that is not the ground truth.
- = (1 - alpha) * y_gt + alpha * mean(y_{i for i != gt})
- = (1 - alpha) * y_gt + (alpha / (K - 1)) * \sum_{i != gt} y_i
- = ((K - 1) * (1 - alpha) / (K - 1)) * y_gt + (alpha / (K - 1)) * \sum_{i != gt} y_i
- = (K * (1 - alpha) - 1) / (K - 1)) * y_gt + (alpha / (K - 1)) * \sum_{i} y_i
- = (1 - (alpha * K) / (K - 1)) * y_gt + ( (alpha * K) / (K - 1) ) * \sum_{i} y_i / K
- From: https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/losses/smoothed_cross_entropy.py
- """
- assert 1.0 > label_smoothing > 0.0
- smoothing = label_smoothing * vocab_size / (vocab_size - 1)
-
- # Exp logits at this point are normalized probabilities. So we can just take the log to get log-probs.
- log_probs = torch.log(exp_logits)
- mean_log_probs = log_probs.mean(dim=-1)
- loss = (1.0 - smoothing) * loss - smoothing * mean_log_probs
-
- ctx.label_smoothing, ctx.vocab_size = label_smoothing, vocab_size
-
- # Store softmax, target-mask and masked-target for backward pass.
- ctx.save_for_backward(exp_logits, target_mask, masked_target_1d)
-
- return loss
-
- @staticmethod
- def backward(ctx, grad_output):
-
- # Retrieve tensors from the forward path.
- softmax, target_mask, masked_target_1d = ctx.saved_tensors
- label_smoothing, vocab_size = ctx.label_smoothing, ctx.vocab_size
-
- # All the inputs have softmax as their gradient.
- grad_input = softmax
- # For simplicity, work with the 2D gradient.
- partition_vocab_size = softmax.size()[-1]
- grad_2d = grad_input.view(-1, partition_vocab_size)
-
- # Add the gradient from matching classes.
- arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device)
-
- softmax_update = 1.0 - target_mask.view(-1).float()
-
- if label_smoothing > 0:
- smoothing = label_smoothing * vocab_size / (vocab_size - 1)
- grad_2d[arange_1d, masked_target_1d] -= (1.0 - smoothing) * softmax_update
- average_grad = 1 / vocab_size
- grad_2d[arange_1d, :] -= smoothing * average_grad
- else:
- grad_2d[arange_1d, masked_target_1d] -= softmax_update
-
- # Finally elementwise multiplication with the output gradients.
- grad_input.mul_(grad_output.unsqueeze(dim=-1))
-
- return grad_input, None, None
-
-
-def vocab_parallel_cross_entropy(vocab_parallel_logits, target, label_smoothing=0.0):
- """
- Performs cross entropy loss when logits are split across tensor parallel ranks
-
- Arguments:
- vocab_parallel_logits: logits split across tensor parallel ranks
- dimension is [sequence_length, batch_size, hidden_size]
-
- target: correct vocab ids of dimension [sequence_length, micro_batch_size]
-
- label_smoothing: smoothing factor, must be in range [0.0, 1.0)
- default is no smoothing (=0.0)
- """
- return _VocabParallelCrossEntropy.apply(vocab_parallel_logits, target, label_smoothing)
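For orientation, the quantity `_VocabParallelCrossEntropy` assembles across ranks (max-shift, predicted logit, and log-sum-exp) reduces on a single process to ordinary cross entropy. A CPU-only reference check, with no tensor parallelism involved:

```python
import torch
import torch.nn.functional as F

# Per-token loss = log(sum(exp(logits))) - logits[target], which is what the deleted
# autograd function reconstructs from its partial, all-reduced pieces.
logits = torch.randn(3, 10)                 # [tokens, vocab]
target = torch.tensor([1, 4, 7])
loss = torch.logsumexp(logits, dim=-1) - logits[torch.arange(3), target]
assert torch.allclose(loss, F.cross_entropy(logits, target, reduction='none'))
```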
diff --git a/megatron/core/tensor_parallel/data.py b/megatron/core/tensor_parallel/data.py
deleted file mode 100644
index 45c4fe7eb0cc0b83bcf2407c5bb2c5f36ed4582e..0000000000000000000000000000000000000000
--- a/megatron/core/tensor_parallel/data.py
+++ /dev/null
@@ -1,104 +0,0 @@
-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-
-import torch
-
-from megatron.core.parallel_state import (
- get_tensor_model_parallel_group,
- get_tensor_model_parallel_rank,
- get_tensor_model_parallel_src_rank,
-)
-
-_MAX_DATA_DIM = 5
-
-
-def _check_data_types(keys, data, target_dtype):
- """Check that all the keys have the same target data type."""
- for key in keys:
- assert data[key].dtype == target_dtype, (
- '{} has data type {} which '
- 'is different than {}'.format(key, data[key].dtype, target_dtype)
- )
-
-
-def _build_key_size_numel_dictionaries(keys, data):
- """Build the size on rank 0 and broadcast."""
- max_dim = _MAX_DATA_DIM
- sizes = [0 for _ in range(max_dim) for _ in keys]
-
- # Pack the sizes on rank zero.
- if get_tensor_model_parallel_rank() == 0:
- offset = 0
- for key in keys:
- assert data[key].dim() < max_dim, 'you should increase MAX_DATA_DIM'
- size = data[key].size()
- for i, s in enumerate(size):
- sizes[i + offset] = s
- offset += max_dim
-
- # Move to GPU and broadcast.
- sizes_cuda = torch.cuda.LongTensor(sizes)
- torch.distributed.broadcast(
- sizes_cuda, get_tensor_model_parallel_src_rank(), group=get_tensor_model_parallel_group()
- )
-
- # Move back to cpu and unpack.
- sizes_cpu = sizes_cuda.cpu()
- key_size = {}
- key_numel = {}
- total_numel = 0
- offset = 0
- for key in keys:
- i = 0
- size = []
- numel = 1
- while sizes_cpu[offset + i] > 0:
- this_size = sizes_cpu[offset + i]
- size.append(this_size)
- numel *= this_size
- i += 1
- key_size[key] = size
- key_numel[key] = numel
- total_numel += numel
- offset += max_dim
-
- return key_size, key_numel, total_numel
-
-
-def broadcast_data(keys, data, datatype):
- """Broadcast data from rank zero of each model parallel group to the
- members of the same model parallel group.
-
- Arguments:
- keys: list of keys in the data dictionary to be broadcast
- data: data dictionary of string keys and cpu tensor values.
- datatype: torch data type of all tensors in data associated
- with keys.
- """
- # Build (key, size) and (key, number of elements) dictionaries along
- # with the total number of elements on all ranks.
- key_size, key_numel, total_numel = _build_key_size_numel_dictionaries(keys, data)
-
- # Pack on rank zero.
- if get_tensor_model_parallel_rank() == 0:
- # Check that all keys have the same data type.
- _check_data_types(keys, data, datatype)
- # Flatten the data associated with the keys
- flatten_data = torch.cat([data[key].contiguous().view(-1) for key in keys], dim=0).cuda()
- else:
- flatten_data = torch.empty(total_numel, device=torch.cuda.current_device(), dtype=datatype)
-
- # Broadcast
- torch.distributed.broadcast(
- flatten_data, get_tensor_model_parallel_src_rank(), group=get_tensor_model_parallel_group()
- )
-
- # Unpack
- output = {}
- offset = 0
- for key in keys:
- size = key_size[key]
- numel = key_numel[key]
- output[key] = flatten_data.narrow(0, offset, numel).view(size)
- offset += numel
-
- return output
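`broadcast_data` only ships one flattened buffer over the wire; the per-key structure is rebuilt from the size dictionary on every rank. On a single process, the round trip around the actual `torch.distributed.broadcast` call looks like this (keys and shapes below are made up for illustration):

```python
import torch

data = {'text': torch.arange(12).view(3, 4), 'labels': torch.arange(6).view(2, 3)}

# Pack: flatten every tensor and concatenate, as rank 0 does before broadcasting.
flat = torch.cat([v.contiguous().view(-1) for v in data.values()])

# Unpack: every rank slices the flat buffer back into per-key views.
output, offset = {}, 0
for key, value in data.items():
    output[key] = flat.narrow(0, offset, value.numel()).view(value.size())
    offset += value.numel()

assert all(torch.equal(output[k], data[k]) for k in data)
```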
diff --git a/megatron/core/tensor_parallel/layers.py b/megatron/core/tensor_parallel/layers.py
deleted file mode 100644
index f31ee42df623f4433af90570c5bf0fd0914aa4e0..0000000000000000000000000000000000000000
--- a/megatron/core/tensor_parallel/layers.py
+++ /dev/null
@@ -1,949 +0,0 @@
-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-
-# Parts of the code here are adapted from PyTorch
-# repo: https://github.com/pytorch/pytorch
-
-import math
-import os
-import warnings
-from typing import Callable, Optional
-
-import torch
-import torch.nn.functional as F
-import torch.nn.init as init
-from torch.cuda.amp import custom_bwd, custom_fwd
-from torch.nn.parameter import Parameter
-
-from megatron.core.model_parallel_config import ModelParallelConfig
-from megatron.core.parallel_state import (
- get_global_memory_buffer,
- get_tensor_model_parallel_group,
- get_tensor_model_parallel_rank,
- get_tensor_model_parallel_world_size,
-)
-
-from .mappings import (
- copy_to_tensor_model_parallel_region,
- gather_from_sequence_parallel_region,
- gather_from_tensor_model_parallel_region,
- reduce_from_tensor_model_parallel_region,
- reduce_scatter_to_sequence_parallel_region,
- scatter_to_tensor_model_parallel_region,
-)
-from .random import get_cuda_rng_tracker, get_expert_parallel_rng_tracker_name
-from .utils import VocabUtility, divide, split_tensor_along_last_dim
-
-_grad_accum_fusion_available = True
-try:
- import fused_weight_gradient_mlp_cuda
-except ImportError:
- _grad_accum_fusion_available = False
-
-_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {
- 'tensor_model_parallel': False,
- 'partition_dim': -1,
- 'partition_stride': 1,
-}
-
-
-def param_is_not_tensor_parallel_duplicate(param):
- return (hasattr(param, 'tensor_model_parallel') and param.tensor_model_parallel) or (
- get_tensor_model_parallel_rank() == 0
- )
-
-
-def set_tensor_model_parallel_attributes(tensor, is_parallel, dim, stride):
- # Make sure the attributes are not set.
- for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
- assert not hasattr(tensor, attribute)
- # Set the attributes.
- setattr(tensor, 'tensor_model_parallel', is_parallel)
- setattr(tensor, 'partition_dim', dim)
- setattr(tensor, 'partition_stride', stride)
-
-
-def set_defaults_if_not_set_tensor_model_parallel_attributes(tensor):
- def maybe_set(attribute, value):
- if not hasattr(tensor, attribute):
- setattr(tensor, attribute, value)
-
- for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
- maybe_set(attribute, _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS[attribute])
-
-
-def copy_tensor_model_parallel_attributes(destination_tensor, source_tensor):
- def maybe_copy(attribute):
- if hasattr(source_tensor, attribute):
- setattr(destination_tensor, attribute, getattr(source_tensor, attribute))
-
- for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
- maybe_copy(attribute)
-
-
-def _initialize_affine_weight_gpu(
- weight, init_method, partition_dim, stride=1, expert_parallel=False
-):
- """Initialize affine weight for model parallel on GPU."""
-
- set_tensor_model_parallel_attributes(
- tensor=weight, is_parallel=True, dim=partition_dim, stride=stride
- )
-
- if not expert_parallel:
- with get_cuda_rng_tracker().fork():
- init_method(weight)
- else:
- with get_cuda_rng_tracker().fork(get_expert_parallel_rng_tracker_name()):
- init_method(weight)
-
-
-def _initialize_affine_weight_cpu(
- weight,
- output_size,
- input_size,
- per_partition_size,
- partition_dim,
- init_method,
- stride=1,
- return_master_weight=False,
- *,
- params_dtype=torch.float32,
-):
- """Initialize affine weight for model parallel.
-
- Build the master weight on all processes and scatter
- the relevant chunk."""
-
- set_tensor_model_parallel_attributes(
- tensor=weight, is_parallel=True, dim=partition_dim, stride=stride
- )
-
- # Initialize master weight
- master_weight = torch.empty(output_size, input_size, dtype=torch.float, requires_grad=False)
- init_method(master_weight)
- master_weight = master_weight.to(dtype=params_dtype)
-
- # Split and copy
- per_partition_per_stride_size = divide(per_partition_size, stride)
- weight_list = torch.split(master_weight, per_partition_per_stride_size, dim=partition_dim)
- rank = get_tensor_model_parallel_rank()
- world_size = get_tensor_model_parallel_world_size()
- my_weight_list = weight_list[rank::world_size]
-
- with torch.no_grad():
- torch.cat(my_weight_list, dim=partition_dim, out=weight)
- if return_master_weight:
- return master_weight
- return None
-
-
-class VocabParallelEmbedding(torch.nn.Module):
- """Embedding parallelized in the vocabulary dimension.
-
- This is mainly adapted from torch.nn.Embedding and all the default
- values are kept.
- Arguments:
- num_embeddings: vocabulary size.
- embedding_dim: size of hidden state.
-
- Keyword Arguments:
- config: A megatron.core.ModelParallelConfig object
- """
-
- def __init__(
- self,
- num_embeddings: int,
- embedding_dim: int,
- *,
- init_method: Callable,
- config: ModelParallelConfig,
- ):
- super(VocabParallelEmbedding, self).__init__()
- # Keep the input dimensions.
- self.num_embeddings = num_embeddings
- self.embedding_dim = embedding_dim
- self.tensor_model_parallel_size = get_tensor_model_parallel_world_size()
- # Divide the weight matrix along the vocabulary dimension.
- (
- self.vocab_start_index,
- self.vocab_end_index,
- ) = VocabUtility.vocab_range_from_global_vocab_size(
- self.num_embeddings, get_tensor_model_parallel_rank(), self.tensor_model_parallel_size
- )
- self.num_embeddings_per_partition = self.vocab_end_index - self.vocab_start_index
-
- # Allocate weights and initialize.
- if config.use_cpu_initialization:
- self.weight = Parameter(
- torch.empty(
- self.num_embeddings_per_partition, self.embedding_dim, dtype=config.params_dtype
- )
- )
- if config.perform_initialization:
- _initialize_affine_weight_cpu(
- self.weight,
- self.num_embeddings,
- self.embedding_dim,
- self.num_embeddings_per_partition,
- 0,
- init_method,
- params_dtype=config.params_dtype,
- )
- else:
- self.weight = Parameter(
- torch.empty(
- self.num_embeddings_per_partition,
- self.embedding_dim,
- device=torch.cuda.current_device(),
- dtype=config.params_dtype,
- )
- )
- if config.perform_initialization:
- _initialize_affine_weight_gpu(self.weight, init_method, partition_dim=0, stride=1)
-
- def forward(self, input_):
- assert not torch.any(
- (input_ < 0) | (input_ >= self.num_embeddings)
- ), "An input token is out of bounds of the embedding table"
- if self.tensor_model_parallel_size > 1:
- # Build the mask.
- input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
- # Mask the input.
- masked_input = input_.clone() - self.vocab_start_index
- masked_input[input_mask] = 0
- else:
- masked_input = input_
- # Get the embeddings.
- output_parallel = self.weight[masked_input]
- # Mask the output embedding.
- if self.tensor_model_parallel_size > 1:
- output_parallel[input_mask, :] = 0.0
- # Reduce across all the model parallel GPUs.
- output = reduce_from_tensor_model_parallel_region(output_parallel)
- return output
-
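The masking in `VocabParallelEmbedding.forward` is easiest to see with concrete numbers. A CPU-only sketch of one rank's view (vocabulary size, world size, and token ids below are hypothetical; the real ranges come from `VocabUtility`):

```python
import torch

vocab_size, world_size, rank = 16, 4, 1
per_rank = vocab_size // world_size
vocab_start, vocab_end = rank * per_rank, (rank + 1) * per_rank   # this rank owns [4, 8)

input_ = torch.tensor([2, 5, 7, 12])
input_mask = (input_ < vocab_start) | (input_ >= vocab_end)
masked_input = (input_.clone() - vocab_start).masked_fill(input_mask, 0)

# Out-of-range tokens look up row 0 locally; their embeddings are zeroed before the
# all-reduce, so only the owning rank contributes a non-zero row.
print(masked_input.tolist(), input_mask.tolist())   # [0, 1, 3, 0] [True, False, False, True]
```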
-
-class LinearWithFrozenWeight(torch.autograd.Function):
- """Linear operator that does not calculate gradient for weight.
- This op and LinearWithGradAccumulationAndAsyncCommunication performs
- mathematically-identical forward and DGRAD.
-
- Conceptually this op is the same as torch.nn.functional.linear with
- weight.requires_grad==False, but in experiments they are not identical
- mathematically. """
-
- @staticmethod
- @custom_fwd
- def forward(
- ctx, input, weight, bias,
- ):
- ctx.save_for_backward(weight)
- output = torch.matmul(input, weight.t())
- if bias is not None:
- output = output + bias
- return output
-
- @staticmethod
- @custom_bwd
- def backward(ctx, grad_output):
- (weight,) = ctx.saved_tensors
- grad_input = grad_output.matmul(weight)
- return grad_input, None, None
-
-
-def linear_with_frozen_weight(
- input: torch.Tensor,
- weight: torch.Tensor,
- bias: Optional[torch.Tensor],
- gradient_accumulation_fusion: bool,
- async_grad_allreduce: bool,
- sequence_parallel: bool,
-) -> torch.Tensor:
- """Linear layer execution with weight.requires_grad == False.
-
- This function handles linear layers with weight frozen (untrainable).
- In the forward, it only saves weight and does not save input activations.
- In the backward, it does not perform weight gradient calculation, or
- weight gradient allreduce.
-
- Arguments:
-
- input (torch.Tensor required): input like torch.nn.functional.linear
-
- weight (torch.Tensor required): weight like torch.nn.functional.linear
-
- bias (torch.Tensor optional): bias like torch.nn.functional.linear
-
- gradient_accumulation_fusion (bool required): dummy argument, used to
- keep the API unified between all forward implementation functions.
-
- async_grad_allreduce (bool required): dummy argument, used to
- keep the API unified between all forward implementation functions.
-
- sequence_parallel (bool required): Indicates that sequence
- parallelism is used and thus in the forward pass the input is
- all gathered, and the backward pass the input gradients are
- reduce scattered.
- """
-
- if sequence_parallel:
- input = gather_from_sequence_parallel_region(input, tensor_parallel_output_grad=True)
- else:
- input = input
-
- args = [
- input,
- weight,
- bias,
- ]
-
- return LinearWithFrozenWeight.apply(*args)
-
-
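A single-process analogue of the frozen-weight path: when `weight.requires_grad` is False only the input gradient is produced, which mirrors what `LinearWithFrozenWeight.backward` returns (no weight gradient, hence no weight-gradient all-reduce).

```python
import torch
import torch.nn.functional as F

x = torch.randn(4, 8, requires_grad=True)
w = torch.randn(16, 8, requires_grad=False)   # frozen weight: no wgrad computed
F.linear(x, w).sum().backward()
print(x.grad.shape, w.grad)                   # torch.Size([4, 8]) None
```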
-class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
- """See linear_with_grad_accumulation_and_async_allreduce"""
-
- @staticmethod
- @custom_fwd
- def forward(
- ctx,
- input,
- weight,
- bias,
- gradient_accumulation_fusion,
- async_grad_allreduce,
- sequence_parallel,
- ):
- ctx.save_for_backward(input, weight)
- ctx.use_bias = bias is not None
- ctx.gradient_accumulation_fusion = gradient_accumulation_fusion
- ctx.async_grad_allreduce = async_grad_allreduce
- ctx.sequence_parallel = sequence_parallel
-
- if sequence_parallel:
- world_size = get_tensor_model_parallel_world_size()
- dim_size = list(input.size())
- dim_size[0] = dim_size[0] * world_size
-
- all_gather_buffer = get_global_memory_buffer().get_tensor(dim_size, input.dtype, "mpu")
- torch.distributed._all_gather_base(
- all_gather_buffer, input, group=get_tensor_model_parallel_group()
- )
- total_input = all_gather_buffer
- else:
- total_input = input
-
- output = torch.matmul(total_input, weight.t())
- if bias is not None:
- output = output + bias
- return output
-
- @staticmethod
- @custom_bwd
- def backward(ctx, grad_output):
- input, weight = ctx.saved_tensors
- use_bias = ctx.use_bias
-
- if ctx.sequence_parallel:
- world_size = get_tensor_model_parallel_world_size()
- dim_size = list(input.size())
- dim_size[0] = dim_size[0] * world_size
-
- all_gather_buffer = get_global_memory_buffer().get_tensor(dim_size, input.dtype, "mpu")
- handle = torch.distributed._all_gather_base(
- all_gather_buffer, input, group=get_tensor_model_parallel_group(), async_op=True
- )
-
- # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
- # gather is scheduled before the input gradient computation
- total_input = all_gather_buffer
- else:
- total_input = input
- grad_input = grad_output.matmul(weight)
-
- if ctx.sequence_parallel:
- handle.wait()
-
- # Doing gather + slicing during the NeMo forward pass can make this tensor
- # not be contiguous. PyTorch only checks if the tensor is contiguous, and only
- # clones it if it's not contiguous:
- # https://github.com/pytorch/pytorch/blob/c47cf9bc7f9e02f649ab4ed53fe4d35732c92ab6/torch/_refs/__init__.py#L2761
- grad_output = grad_output.contiguous()
- # Convert the tensor shapes to 2D for execution compatibility
- grad_output = grad_output.view(
- grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2]
- )
- total_input = total_input.view(
- total_input.shape[0] * total_input.shape[1], total_input.shape[2]
- )
-
- if ctx.async_grad_allreduce:
- # Asynchronous all-reduce
- handle = torch.distributed.all_reduce(
- grad_input, group=get_tensor_model_parallel_group(), async_op=True
- )
- # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
- # all-reduce is scheduled before the weight gradient computation
-
- if ctx.sequence_parallel:
- assert not ctx.async_grad_allreduce
- dim_size = list(input.size())
- sub_grad_input = torch.empty(
- dim_size, dtype=input.dtype, device=torch.cuda.current_device(), requires_grad=False
- )
- # reduce_scatter
- handle = torch.distributed._reduce_scatter_base(
- sub_grad_input, grad_input, group=get_tensor_model_parallel_group(), async_op=True
- )
- # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
- # reduce scatter is scheduled before the weight gradient computation
-
- if ctx.gradient_accumulation_fusion:
- if weight.main_grad.dtype == torch.float32:
- fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(
- total_input, grad_output, weight.main_grad
- )
- elif weight.main_grad.dtype in (torch.float16, torch.bfloat16):
- fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(
- total_input, grad_output, weight.main_grad
- )
- else:
- raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
-
- if hasattr(weight, 'grad_added_to_main_grad'):
- # When overlap_grad_reduce is True, need to ensure that backward hooks
- # are all run on the main backprop thread to prevent deadlocks. Setup
- # dummy grad_weight tensor to prevent backward hooks from being run
- # in a background thread.
- grad_weight = torch.empty(
- weight.main_grad.shape,
- dtype=input.dtype,
- device=torch.cuda.current_device(),
- requires_grad=False,
- )
- weight.grad_added_to_main_grad = True
- else:
- grad_weight = None
- else:
- grad_weight = grad_output.t().matmul(total_input)
- grad_bias = grad_output.sum(dim=0) if use_bias else None
-
- if ctx.sequence_parallel:
- handle.wait()
- return sub_grad_input, grad_weight, grad_bias, None, None, None
-
- if ctx.async_grad_allreduce:
- handle.wait()
-
- return grad_input, grad_weight, grad_bias, None, None, None
-
-
-def linear_with_grad_accumulation_and_async_allreduce(
- input: torch.Tensor,
- weight: torch.Tensor,
- bias: Optional[torch.Tensor],
- gradient_accumulation_fusion: bool,
- async_grad_allreduce: bool,
- sequence_parallel: bool,
-) -> torch.Tensor:
- """Linear layer execution with asynchronous communication and
- gradient accumulation fusion in backprop.
-
- This has the option to accumulate the result of backprop
- calculation into an existing gradient buffer, preventing the need
- to do an additional addition kernel after the gradient
- calculation.
-
- Additionally, the tensor parallel all reduce of the input
- gradients can be done asynchronously with the calculation of
- the weight gradients.
-
- In the case of sequence parallelism, the reduce scatter of the
- input gradients is done asynchronously with the calculation of the
- weight gradients.
-
- Use of this module requires that the environment variable
- CUDA_DEVICE_MAX_CONNECTIONS=1. There are a few collective
- operations, noted in the code, that should be scheduled before
- compute kernels to overlap the communication with the computation,
- which is necessary for a speedup but not for correctness so that
- ordering isn't imposed by the scheduler. Setting
- CUDA_DEVICE_MAX_CONNECTIONS=1 forces the kernels to be scheduled
- in the order they are called.
-
- Arguments:
-
- input (torch.Tensor required): input like torch.nn.functional.linear
-
- weight (torch.Tensor required): weight like torch.nn.functional.linear
-
- bias (torch.Tensor optional): bias like torch.nn.functional.linear
-
- gradient_accumulation_fusion (bool required): Perform the gradient
- accumulation fusion, requires the custom CUDA extension
- fused_weight_gradient_mlp_cuda module. To use
- gradient_accumulation_fusion you must install APEX with
- --cpp_ext and --cuda_ext. For example: "pip install
- --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext .\"
- " Note that the extension requires CUDA>=11. Otherwise, you
- must turn off gradient accumulation fusion."
-
- async_grad_allreduce (bool required): Do the allreduce of input
- gradients asynchronously with the computation of weight
- gradients. If sequence_parallel is True, this must be
- False, as no all reduce is performed.
-
- sequence_parallel (bool required): Indicates that sequence
- parallelism is used and thus in the forward pass the input is
- all gathered, and the backward pass the input gradients are
- reduce scattered.
- """
- args = [
- input,
- weight,
- bias,
- gradient_accumulation_fusion,
- async_grad_allreduce,
- sequence_parallel,
- ]
-
- if not linear_with_grad_accumulation_and_async_allreduce.warned:
- if os.environ.get('CUDA_DEVICE_MAX_CONNECTIONS') != "1":
- if sequence_parallel:
- warnings.warn(
- "When using sequence parallelism it is recommended to set the "
- "environment variable CUDA_DEVICE_MAX_CONNECTIONS to 1 for "
- "maximum speedup"
- )
- linear_with_grad_accumulation_and_async_allreduce.warned = True
-
- if async_grad_allreduce:
- warnings.warn(
- "When using async grad allreduce it is recommended to set the "
- "environment variable CUDA_DEVICE_MAX_CONNECTIONS to 1 for "
- "maximum speedup"
- )
- linear_with_grad_accumulation_and_async_allreduce.warned = True
-
- return LinearWithGradAccumulationAndAsyncCommunication.apply(*args)
-
-
-linear_with_grad_accumulation_and_async_allreduce.warned = False
-
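The docstring above warns about `CUDA_DEVICE_MAX_CONNECTIONS`. The variable has to be in the environment before the process creates its CUDA context, so in practice it is exported by the launch script; the minimal Python-side guard below is only illustrative.

```python
import os

# Must be set before the first CUDA call of the process; the overlap of communication
# and compute described above relies on kernels being issued in call order.
os.environ.setdefault("CUDA_DEVICE_MAX_CONNECTIONS", "1")
```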
-
-class ColumnParallelLinear(torch.nn.Module):
- """Linear layer with column parallelism.
-
- The linear layer is defined as Y = XA + b. A is parallelized along
- its second dimension as A = [A_1, ..., A_p].
-
- Arguments:
- input_size: first dimension of matrix A.
- output_size: second dimension of matrix A.
-
- Keyword Arguments
- bias: If true, add bias
- gather_output: If true, call all-gather on output and make Y available
- to all GPUs, otherwise, every GPU will have its output
- which is Y_i = XA_i
- init_method: method to initialize weights. Note that bias is always set
- to zero.
- stride: For the strided linear layers.
- keep_master_weight_for_test: This was added for testing and should be
- set to False. It returns the master weights
- used for initialization.
- skip_bias_add: If True, do not add the bias term, instead
- return it to be added by the caller. This
- enables performance optimizations where bias can
- be fused with other elementwise operations.
- skip_weight_param_allocation: If True, weight parameter is not allocated and must be passed
- as a keyword argument `weight` during the forward pass. Note
- that this does not affect bias, which will be allocated if
- bias is True. Defaults to False.
- is_expert: If True, the layer is treated as an MoE expert layer.
- config: ModelParallelConfig object
- tp_comm_buffer_name: Communication buffer name is not used in
- non-Transformer-Engine modules.
-
- """
-
- def __init__(
- self,
- input_size,
- output_size,
- *,
- config: ModelParallelConfig,
- init_method: Callable,
- bias=True,
- gather_output=False,
- stride=1,
- keep_master_weight_for_test=False,
- skip_bias_add=False,
- skip_weight_param_allocation: bool = False,
- is_expert: bool = False,
- tp_comm_buffer_name: str = None, # Not used
- ):
- super(ColumnParallelLinear, self).__init__()
-
- # Keep input parameters
- self.input_size = input_size
- self.output_size = output_size
- self.gather_output = gather_output
- # Divide the weight matrix along the last dimension.
- world_size = get_tensor_model_parallel_world_size()
- self.output_size_per_partition = divide(output_size, world_size)
- self.skip_bias_add = skip_bias_add
- self.is_expert = is_expert
- self.expert_parallel = config.expert_model_parallel_size > 1
- self.config = config
-
- # Parameters.
- # Note: torch.nn.functional.linear performs XA^T + b and as a result
- # we allocate the transpose.
- # Initialize weight.
- if not skip_weight_param_allocation:
- if config.use_cpu_initialization:
- self.weight = Parameter(
- torch.empty(
- self.output_size_per_partition, self.input_size, dtype=config.params_dtype
- )
- )
- if config.perform_initialization:
- self.master_weight = _initialize_affine_weight_cpu(
- self.weight,
- self.output_size,
- self.input_size,
- self.output_size_per_partition,
- 0,
- init_method,
- stride=stride,
- return_master_weight=keep_master_weight_for_test,
- )
- else:
- self.weight = Parameter(
- torch.empty(
- self.output_size_per_partition,
- self.input_size,
- device=torch.cuda.current_device(),
- dtype=config.params_dtype,
- )
- )
- if config.perform_initialization:
- _initialize_affine_weight_gpu(
- self.weight,
- init_method,
- partition_dim=0,
- stride=stride,
- expert_parallel=(self.is_expert and self.expert_parallel),
- )
-
- setattr(self.weight, 'allreduce', not (self.is_expert and self.expert_parallel))
- else:
- self.weight = None
-
- if bias:
- if config.use_cpu_initialization:
- self.bias = Parameter(
- torch.empty(self.output_size_per_partition, dtype=config.params_dtype)
- )
- else:
- self.bias = Parameter(
- torch.empty(
- self.output_size_per_partition,
- device=torch.cuda.current_device(),
- dtype=config.params_dtype,
- )
- )
- set_tensor_model_parallel_attributes(self.bias, True, 0, stride)
- if config.perform_initialization:
- # Always initialize bias to zero.
- with torch.no_grad():
- self.bias.zero_()
- setattr(self.bias, 'allreduce', not (self.is_expert and self.expert_parallel))
- else:
- self.register_parameter('bias', None)
-
- self.async_tensor_model_parallel_allreduce = (
- config.async_tensor_model_parallel_allreduce and world_size > 1
- )
-
- self.sequence_parallel = config.sequence_parallel
- if self.sequence_parallel and world_size <= 1:
- warnings.warn(
- f"`sequence_parallel` is set to `True`, but tensor model parallel size is {world_size}. "
- f"Disabling sequence parallel."
- )
- self.sequence_parallel = False
-
- if config.gradient_accumulation_fusion and not _grad_accum_fusion_available:
- raise RuntimeError(
- "ColumnParallelLinear was called with gradient_accumulation_fusion set "
- "to True but the custom CUDA extension fused_weight_gradient_mlp_cuda "
- "module is not found. To use gradient_accumulation_fusion you must "
- "install APEX with --cpp_ext and --cuda_ext. For example: "
- "pip install --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext .\" "
- "Note that the extension requires CUDA>=11. Otherwise, you must turn off "
- "gradient accumulation fusion."
- )
- self.gradient_accumulation_fusion = config.gradient_accumulation_fusion
-
- if self.async_tensor_model_parallel_allreduce and self.sequence_parallel:
- raise RuntimeError(
- "`async_tensor_model_parallel_allreduce` and `sequence_parallel` "
- "cannot be enabled at the same time."
- )
-
- self._forward_impl = linear_with_grad_accumulation_and_async_allreduce
- self.explicit_expert_comm = self.is_expert and (
- self.sequence_parallel or self.expert_parallel
- )
-
- def forward(self, input_: torch.Tensor, weight: Optional[torch.Tensor] = None):
- """Forward of ColumnParallelLinear
-
- Args:
- input_: 3D tensor whose order of dimension is [sequence, batch, hidden]
-
- weight (optional): weight tensor to use, compulsory when
- skip_weight_param_allocation is True.
-
- Returns:
- - output
- - bias
-
- """
- if weight is None:
- if self.weight is None:
- raise RuntimeError(
- "weight was not supplied to ColumnParallelLinear forward pass "
- "and skip_weight_param_allocation is True."
- )
- weight = self.weight
- else:
- # Check the weight passed in is the correct shape
- expected_shape = (self.output_size_per_partition, self.input_size)
- if weight.shape != expected_shape:
- raise RuntimeError(
- f"supplied weight's shape is {tuple(weight.shape)}, "
- f"not {expected_shape} as expected"
- )
-
- bias = self.bias if not self.skip_bias_add else None
-
- if (
- self.async_tensor_model_parallel_allreduce
- or self.sequence_parallel
- or self.explicit_expert_comm
- ):
- input_parallel = input_
- else:
- input_parallel = copy_to_tensor_model_parallel_region(input_)
-
- # Matrix multiply.
- if not weight.requires_grad:
- self._forward_impl = linear_with_frozen_weight
- else:
- self._forward_impl = linear_with_grad_accumulation_and_async_allreduce
- output_parallel = self._forward_impl(
- input=input_parallel,
- weight=weight,
- bias=bias,
- gradient_accumulation_fusion=self.gradient_accumulation_fusion,
- async_grad_allreduce=False
- if self.explicit_expert_comm
- else self.async_tensor_model_parallel_allreduce,
- sequence_parallel=False if self.explicit_expert_comm else self.sequence_parallel,
- )
- if self.gather_output:
- # All-gather across the partitions.
- assert not self.sequence_parallel
- output = gather_from_tensor_model_parallel_region(output_parallel)
- else:
- output = output_parallel
- output_bias = self.bias if self.skip_bias_add else None
- return output, output_bias
-
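The column split in `ColumnParallelLinear` is purely along the output dimension, so concatenating the per-rank outputs reproduces the unsplit product; this is what `gather_output=True` does with an all-gather. A CPU-only check with made-up sizes:

```python
import torch

world_size, input_size, output_size = 4, 8, 16
x = torch.randn(3, input_size)
a = torch.randn(output_size, input_size)          # full weight in [out, in] layout

# Each "rank" holds a slice of the output dimension and computes Y_i = X @ A_i^T.
shards = torch.chunk(a, world_size, dim=0)
y_gathered = torch.cat([x @ w.t() for w in shards], dim=-1)
assert torch.allclose(x @ a.t(), y_gathered, atol=1e-6)
```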
-
-class RowParallelLinear(torch.nn.Module):
- """Linear layer with row parallelism.
-
- The linear layer is defined as Y = XA + b. A is parallelized along
- its first dimension and X along its second dimension as:
- - -
- | A_1 |
- | . |
- A = | . | X = [X_1, ..., X_p]
- | . |
- | A_p |
- - -
- Arguments:
- input_size: first dimension of matrix A.
- output_size: second dimension of matrix A.
-
- Keyword Arguments:
- bias: If true, add bias. Note that bias is not parallelized.
- input_is_parallel: If true, we assume that the input is already
- split across the GPUs and we do not split
- again.
- init_method: method to initialize weights. Note that bias is always set
- to zero.
- stride: For the strided linear layers.
- keep_master_weight_for_test: This was added for testing and should be
- set to False. It returns the master weights
- used for initialization.
- skip_bias_add: If True, do not add the bias term, instead
- return it to be added by the caller. This
- enables performance optimizations where bias can
- be fused with other elementwise operations.
- is_expert: If True, the layer is treated as an MoE expert layer
- tp_comm_buffer_name: Communication buffer name. Not used in
- non-Transformer-Engine modules.
- config: ModelParallelConfig object
-
- """
-
- def __init__(
- self,
- input_size: int,
- output_size: int,
- *,
- config: ModelParallelConfig,
- init_method: Callable,
- bias: bool,
- input_is_parallel: bool,
- skip_bias_add: bool,
- stride: int = 1,
- keep_master_weight_for_test: bool = False,
- is_expert: bool = False,
- tp_comm_buffer_name: str = None, # Not used
- ):
- super(RowParallelLinear, self).__init__()
-
- # Keep input parameters
- self.input_size = input_size
- self.output_size = output_size
- self.input_is_parallel = input_is_parallel
- # Divide the weight matrix along the last dimension.
- world_size = get_tensor_model_parallel_world_size()
- self.input_size_per_partition = divide(input_size, world_size)
- self.skip_bias_add = skip_bias_add
- self.config = config
- self.is_expert = is_expert
- self.expert_parallel = config.expert_model_parallel_size > 1
- self.gradient_accumulation_fusion = config.gradient_accumulation_fusion
- self.sequence_parallel = config.sequence_parallel
- if self.sequence_parallel and not self.input_is_parallel:
- raise RuntimeError("To enable `sequence_parallel`, `input_is_parallel` must be `True`")
-
- # Parameters.
- # Note: torch.nn.functional.linear performs XA^T + b and as a result
- # we allocate the transpose.
- # Initialize weight.
- if config.use_cpu_initialization:
- self.weight = Parameter(
- torch.empty(
- self.output_size, self.input_size_per_partition, dtype=config.params_dtype
- )
- )
- if config.perform_initialization:
- self.master_weight = _initialize_affine_weight_cpu(
- self.weight,
- self.output_size,
- self.input_size,
- self.input_size_per_partition,
- 1,
- init_method,
- stride=stride,
- return_master_weight=keep_master_weight_for_test,
- params_dtype=config.params_dtype,
- )
- else:
- self.weight = Parameter(
- torch.empty(
- self.output_size,
- self.input_size_per_partition,
- device=torch.cuda.current_device(),
- dtype=config.params_dtype,
- )
- )
- if config.perform_initialization:
- _initialize_affine_weight_gpu(
- self.weight,
- init_method,
- partition_dim=1,
- stride=stride,
- expert_parallel=(self.is_expert and self.expert_parallel),
- )
- setattr(self.weight, 'allreduce', not (self.is_expert and self.expert_parallel))
-
- if bias:
- if config.use_cpu_initialization:
- self.bias = Parameter(torch.empty(self.output_size, dtype=config.params_dtype))
- else:
- self.bias = Parameter(
- torch.empty(
- self.output_size,
- device=torch.cuda.current_device(),
- dtype=config.params_dtype,
- )
- )
-
- if config.perform_initialization:
- # Always initialize bias to zero.
- with torch.no_grad():
- self.bias.zero_()
- setattr(self.bias, 'allreduce', not (self.is_expert and self.expert_parallel))
- setattr(self.bias, 'sequence_parallel', self.sequence_parallel)
- else:
- self.register_parameter('bias', None)
-
- self._forward_impl = linear_with_grad_accumulation_and_async_allreduce
- self.explicit_expert_comm = self.is_expert and (
- self.sequence_parallel or self.expert_parallel
- )
-
- def forward(self, input_):
- """Forward of RowParallelLinear
-
- Args:
- input_: 3D tensor whose dimension order is [sequence, batch, hidden]
-
- Returns:
- - output
- - bias
- """
- # Set up backprop all-reduce.
- if self.input_is_parallel:
- input_parallel = input_
- else:
- assert not self.sequence_parallel
- input_parallel = scatter_to_tensor_model_parallel_region(input_)
- # Matrix multiply.
- if not self.weight.requires_grad:
- self._forward_impl = linear_with_frozen_weight
- else:
- self._forward_impl = linear_with_grad_accumulation_and_async_allreduce
- output_parallel = self._forward_impl(
- input=input_parallel,
- weight=self.weight,
- bias=None,
- gradient_accumulation_fusion=self.gradient_accumulation_fusion,
- async_grad_allreduce=False,
- sequence_parallel=False,
- )
-
- # All-reduce across all the partitions.
- if self.explicit_expert_comm:
- assert self.skip_bias_add
- output_ = output_parallel
- elif self.sequence_parallel:
- output_ = reduce_scatter_to_sequence_parallel_region(output_parallel)
- else:
- output_ = reduce_from_tensor_model_parallel_region(output_parallel)
- if not self.skip_bias_add:
- output = (output_ + self.bias) if self.bias is not None else output_
- output_bias = None
- else:
- output = output_
- output_bias = self.bias
- return output, output_bias
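For reference, the removed `RowParallelLinear` computes Y = XA + b with A partitioned along its input dimension, so each rank produces a partial product that is then summed across ranks (the all-reduce / reduce-scatter above). A minimal single-process sketch of that partial-sum identity, in plain PyTorch with illustrative names and no distributed groups (not part of the patch):

```python
import torch

def row_parallel_reference(x, a, world_size):
    """Simulate row-parallel Y = X @ A: split A along dim 0 and X along its
    last dim, compute per-"rank" partial products, then sum them (the sum
    stands in for the all-reduce performed by the real layer)."""
    x_shards = torch.chunk(x, world_size, dim=-1)   # X = [X_1, ..., X_p]
    a_shards = torch.chunk(a, world_size, dim=0)    # A stacked as [A_1; ...; A_p]
    partials = [x_i @ a_i for x_i, a_i in zip(x_shards, a_shards)]
    return sum(partials)

x = torch.randn(4, 16)
a = torch.randn(16, 8)
assert torch.allclose(row_parallel_reference(x, a, world_size=4), x @ a, atol=1e-5)
```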
diff --git a/megatron/core/tensor_parallel/mappings.py b/megatron/core/tensor_parallel/mappings.py
deleted file mode 100644
index 95c8841be7c90e740baab285de7d9e03a90c98bb..0000000000000000000000000000000000000000
--- a/megatron/core/tensor_parallel/mappings.py
+++ /dev/null
@@ -1,358 +0,0 @@
-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-
-import torch
-
-from megatron.core.parallel_state import (
- get_tensor_and_expert_parallel_group,
- get_tensor_model_parallel_group,
- get_tensor_model_parallel_rank,
- get_tensor_model_parallel_world_size,
-)
-
-from .utils import split_tensor_along_last_dim
-
-
-def _reduce(input_):
- """All-reduce the input tensor across model parallel group."""
-
- # Bypass the function if we are using only 1 GPU.
- if get_tensor_model_parallel_world_size() == 1:
- return input_
-
- # All-reduce.
- torch.distributed.all_reduce(input_, group=get_tensor_model_parallel_group())
-
- return input_
-
-
-def _split_along_last_dim(input_):
- """Split the tensor along its last dimension and keep the
- corresponding slice."""
-
- world_size = get_tensor_model_parallel_world_size()
- # Bypass the function if we are using only 1 GPU.
- if world_size == 1:
- return input_
-
- # Split along last dimension.
- input_list = split_tensor_along_last_dim(input_, world_size)
-
- # Note: torch.split does not create contiguous tensors by default.
- rank = get_tensor_model_parallel_rank()
- output = input_list[rank].contiguous()
-
- return output
-
-
-def _split_along_first_dim(input_):
- """Split the tensor along its first dimension and keep the
- corresponding slice."""
-
- world_size = get_tensor_model_parallel_world_size()
- # Bypass the function if we are using only 1 GPU.
- if world_size == 1:
- return input_
-
- # Split along first dimension.
- dim_size = input_.size()[0]
- assert (
- dim_size % world_size == 0
- ), "First dimension of the tensor should be divisible by tensor parallel size"
- local_dim_size = dim_size // world_size
- rank = get_tensor_model_parallel_rank()
- dim_offset = rank * local_dim_size
-
- output = input_[dim_offset : dim_offset + local_dim_size].contiguous()
-
- return output
-
-
-def _gather_along_last_dim(input_):
- """Gather tensors and concatinate along the last dimension."""
-
- world_size = get_tensor_model_parallel_world_size()
- # Bypass the function if we are using only 1 GPU.
- if world_size == 1:
- return input_
-
- # Size and dimension.
- last_dim = input_.dim() - 1
- rank = get_tensor_model_parallel_rank()
-
- tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
- tensor_list[rank] = input_
- torch.distributed.all_gather(tensor_list, input_, group=get_tensor_model_parallel_group())
-
- # Note: torch.cat already creates a contiguous tensor.
- output = torch.cat(tensor_list, dim=last_dim).contiguous()
-
- return output
-
-
-def _gather_along_first_dim(input_):
- """Gather tensors and concatinate along the first dimension."""
-
- world_size = get_tensor_model_parallel_world_size()
- # Bypass the function if we are using only 1 GPU.
- if world_size == 1:
- return input_
-
- dim_size = list(input_.size())
- dim_size[0] = dim_size[0] * world_size
-
- output = torch.empty(dim_size, dtype=input_.dtype, device=torch.cuda.current_device())
- torch.distributed._all_gather_base(
- output, input_.contiguous(), group=get_tensor_model_parallel_group()
- )
-
- return output
-
-
-def _reduce_scatter_along_first_dim(input_):
- """Reduce-scatter the input tensor across model parallel group."""
- world_size = get_tensor_model_parallel_world_size()
- # Bypass the function if we are using only 1 GPU.
- if world_size == 1:
- return input_
-
- dim_size = list(input_.size())
- assert (
- dim_size[0] % world_size == 0
- ), "First dimension of the tensor should be divisible by tensor parallel size"
-
- dim_size[0] = dim_size[0] // world_size
-
- output = torch.empty(dim_size, dtype=input_.dtype, device=torch.cuda.current_device())
- torch.distributed._reduce_scatter_base(
- output, input_.contiguous(), group=get_tensor_model_parallel_group()
- )
- return output
-
-
-def _gather_along_first_dim_moe(input_):
- """Gather tensors and concatenate along the first dimension."""
- group = get_tensor_and_expert_parallel_group()
- world_size = torch.distributed.get_world_size(group=group)
- # Bypass the function if we are using only 1 GPU.
- if world_size == 1:
- return input_
-
- dim_size = list(input_.size())
- dim_size[0] = dim_size[0] * world_size
-
- output = torch.empty(dim_size, dtype=input_.dtype, device=torch.cuda.current_device())
- torch.distributed._all_gather_base(output, input_.contiguous(), group=group)
-
- return output
-
-
-def _reduce_scatter_along_first_dim_moe(input_):
- """Reduce-scatter the input tensor across model parallel group."""
- group = get_tensor_and_expert_parallel_group()
- world_size = torch.distributed.get_world_size(group=group)
- # Bypass the function if we are using only 1 GPU.
- if world_size == 1:
- return input_
-
- dim_size = list(input_.size())
- assert dim_size[0] % world_size == 0
- dim_size[0] = dim_size[0] // world_size
-
- output = torch.empty(dim_size, dtype=input_.dtype, device=torch.cuda.current_device())
- torch.distributed._reduce_scatter_base(output, input_.contiguous(), group=group)
- return output
-
-
-class _CopyToModelParallelRegion(torch.autograd.Function):
- """Pass the input to the model parallel region."""
-
- @staticmethod
- def symbolic(graph, input_):
- return input_
-
- @staticmethod
- def forward(ctx, input_):
- return input_
-
- @staticmethod
- def backward(ctx, grad_output):
- return _reduce(grad_output)
-
-
-class _ReduceFromModelParallelRegion(torch.autograd.Function):
- """All-reduce the input from the model parallel region."""
-
- @staticmethod
- def symbolic(graph, input_):
- return _reduce(input_)
-
- @staticmethod
- def forward(ctx, input_):
- return _reduce(input_)
-
- @staticmethod
- def backward(ctx, grad_output):
- return grad_output
-
-
-class _ScatterToModelParallelRegion(torch.autograd.Function):
- """Split the input and keep only the corresponding chuck to the rank."""
-
- @staticmethod
- def symbolic(graph, input_):
- return _split_along_last_dim(input_)
-
- @staticmethod
- def forward(ctx, input_):
- return _split_along_last_dim(input_)
-
- @staticmethod
- def backward(ctx, grad_output):
- return _gather_along_last_dim(grad_output)
-
-
-class _GatherFromModelParallelRegion(torch.autograd.Function):
- """Gather the input from model parallel region and concatinate."""
-
- @staticmethod
- def symbolic(graph, input_):
- return _gather_along_last_dim(input_)
-
- @staticmethod
- def forward(ctx, input_):
- return _gather_along_last_dim(input_)
-
- @staticmethod
- def backward(ctx, grad_output):
- return _split_along_last_dim(grad_output)
-
-
-class _ScatterToSequenceParallelRegion(torch.autograd.Function):
- """Split the input and keep only the corresponding chuck to the rank."""
-
- @staticmethod
- def symbolic(graph, input_):
- return _split_along_first_dim(input_)
-
- @staticmethod
- def forward(ctx, input_):
- return _split_along_first_dim(input_)
-
- @staticmethod
- def backward(ctx, grad_output):
- return _gather_along_first_dim(grad_output)
-
-
-class _GatherFromSequenceParallelRegion(torch.autograd.Function):
- """Gather the input from sequence parallel region and concatinate."""
-
- @staticmethod
- def symbolic(graph, input_, tensor_parallel_output_grad=True):
- return _gather_along_first_dim(input_)
-
- @staticmethod
- def forward(ctx, input_, tensor_parallel_output_grad=True):
- ctx.tensor_parallel_output_grad = tensor_parallel_output_grad
- return _gather_along_first_dim(input_)
-
- @staticmethod
- def backward(ctx, grad_output):
- tensor_parallel_output_grad = ctx.tensor_parallel_output_grad
-
- # If the computation graph after the gather operation is
- # in tensor parallel mode, output gradients need to be
- # reduce-scattered, whereas if the computation is duplicated,
- # output gradients need to be scattered.
- if tensor_parallel_output_grad:
- return _reduce_scatter_along_first_dim(grad_output), None
- else:
- return _split_along_first_dim(grad_output), None
-
-
-class _ReduceScatterToSequenceParallelRegion(torch.autograd.Function):
- """Reduce scatter the input from the model parallel region."""
-
- @staticmethod
- def symbolic(graph, input_):
- return _reduce_scatter_along_first_dim(input_)
-
- @staticmethod
- def forward(ctx, input_):
- return _reduce_scatter_along_first_dim(input_)
-
- @staticmethod
- def backward(ctx, grad_output):
- return _gather_along_first_dim(grad_output)
-
-
-class _GatherFromSequenceParallelRegionToMOE(torch.autograd.Function):
- """Gather the input from model parallel region and concatenate.""" # TODO
-
- @staticmethod
- def symbolic(graph, input_):
- return _gather_along_first_dim_moe(input_)
-
- @staticmethod
- def forward(ctx, input_):
- return _gather_along_first_dim_moe(input_,)
-
- @staticmethod
- def backward(ctx, grad_output):
- return _reduce_scatter_along_first_dim_moe(grad_output)
-
-
-class _ReduceScatterToSequenceParallelRegionFromMOE(torch.autograd.Function):
- """Reduce scatter the input from the model parallel region."""
-
- @staticmethod
- def symbolic(graph, input_):
- return _reduce_scatter_along_first_dim_moe(input_)
-
- @staticmethod
- def forward(ctx, input_):
- return _reduce_scatter_along_first_dim_moe(input_,)
-
- @staticmethod
- def backward(ctx, grad_output):
- return _gather_along_first_dim_moe(grad_output)
-
-
-# -----------------
-# Helper functions.
-# -----------------
-
-
-def copy_to_tensor_model_parallel_region(input_):
- return _CopyToModelParallelRegion.apply(input_)
-
-
-def reduce_from_tensor_model_parallel_region(input_):
- return _ReduceFromModelParallelRegion.apply(input_)
-
-
-def scatter_to_tensor_model_parallel_region(input_):
- return _ScatterToModelParallelRegion.apply(input_)
-
-
-def gather_from_tensor_model_parallel_region(input_):
- return _GatherFromModelParallelRegion.apply(input_)
-
-
-def scatter_to_sequence_parallel_region(input_):
- return _ScatterToSequenceParallelRegion.apply(input_)
-
-
-def gather_from_sequence_parallel_region(input_, tensor_parallel_output_grad=True):
- return _GatherFromSequenceParallelRegion.apply(input_, tensor_parallel_output_grad)
-
-
-def reduce_scatter_to_sequence_parallel_region(input_):
- return _ReduceScatterToSequenceParallelRegion.apply(input_)
-
-
-def gather_from_sequence_parallel_region_to_moe(input_):
- return _GatherFromSequenceParallelRegionToMOE.apply(input_)
-
-
-def reduce_scatter_to_sequence_parallel_region_from_moe(input_):
- return _ReduceScatterToSequenceParallelRegionFromMOE.apply(input_)
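The reduce-scatter helpers removed above combine an all-reduce with a per-rank slice along the first dimension. A single-process sketch of that equivalence, using only plain PyTorch (illustrative names, not part of the patch):

```python
import torch

def reduce_scatter_reference(rank_inputs, rank):
    """Single-process model of _reduce_scatter_along_first_dim: sum the
    per-rank tensors (the "reduce"), then keep this rank's first-dim slice
    (the "scatter")."""
    world_size = len(rank_inputs)
    summed = torch.stack(rank_inputs).sum(dim=0)
    return torch.chunk(summed, world_size, dim=0)[rank]

inputs = [torch.randn(8, 4) for _ in range(4)]
slices = [reduce_scatter_reference(inputs, r) for r in range(4)]
# Concatenating every rank's slice reproduces the fully all-reduced tensor.
assert torch.allclose(torch.cat(slices, dim=0), torch.stack(inputs).sum(dim=0))
```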
diff --git a/megatron/core/tensor_parallel/random.py b/megatron/core/tensor_parallel/random.py
deleted file mode 100644
index 9d51b09f7e3a37042b555f10afd889b0e5930df2..0000000000000000000000000000000000000000
--- a/megatron/core/tensor_parallel/random.py
+++ /dev/null
@@ -1,269 +0,0 @@
-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-
-# Parts of the code here are adapted from PyTorch
-# repo: https://github.com/pytorch/pytorch
-
-import contextlib
-
-import torch
-from torch import _C
-from torch.cuda import _lazy_call
-from torch.cuda import device as device_ctx_manager
-from torch.utils.checkpoint import detach_variable
-
-from megatron.core.parallel_state import (
- get_data_parallel_rank,
- get_expert_model_parallel_rank,
- get_tensor_model_parallel_group,
- get_tensor_model_parallel_rank,
- get_tensor_model_parallel_world_size,
-)
-from megatron.core.utils import safely_set_viewless_tensor_data
-
-from .utils import gather_split_1d_tensor, split_tensor_into_1d_equal_chunks
-
-# Default name for the model parallel rng tracker.
-_MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng'
-_EXPERT_PARALLEL_RNG_TRACKER_NAME = 'expert-parallel-rng'
-_DATA_PARALLEL_RNG_TRACKER_NAME = 'data-parallel-rng'
-
-
-def _set_cuda_rng_state(new_state, device=-1):
- """Sets the random number generator state of the current GPU.
-
- Arguments:
- new_state (torch.ByteTensor): The desired state
- This function is adapted from the PyTorch repo (torch.cuda.set_rng_state)
- with a single change: the input state is not cloned. Cloning caused
- major performance issues for 4+ GPU cases.
- """
- if hasattr(_C, '_cuda_setRNGState') and callable(_C._cuda_setRNGState):
- # older PyTorch
- def cb():
- with device_ctx_manager(device):
- _C._cuda_setRNGState(new_state)
-
- else:
- # newer PyTorch
- if device == -1:
- device = torch.device('cuda')
- elif isinstance(device, str):
- device = torch.device(device)
- elif isinstance(device, int):
- device = torch.device('cuda', device)
-
- def cb():
- idx = device.index
- if idx is None:
- idx = torch.cuda.current_device()
- default_generator = torch.cuda.default_generators[idx]
- default_generator.set_state(new_state)
-
- _lazy_call(cb)
-
-
-def get_expert_parallel_rng_tracker_name():
- global _EXPERT_PARALLEL_RNG_TRACKER_NAME
- return _EXPERT_PARALLEL_RNG_TRACKER_NAME
-
-
-def get_data_parallel_rng_tracker_name():
- global _DATA_PARALLEL_RNG_TRACKER_NAME
- return _DATA_PARALLEL_RNG_TRACKER_NAME
-
-
-class CudaRNGStatesTracker:
- """Tracker for the cuda RNG states.
-
- Using the `add` method, a cuda rng state is initialized based on
- the input `seed` and is assigned to `name`. Later, by forking the
- rng state, we can perform operations and return to our starting
- cuda state.
- """
-
- def __init__(self):
- # Map from a string name to the cuda rng state.
- self.states_ = {}
- # Seeds are kept just for bookkeeping, to ensure no seed is set twice.
- self.seeds_ = set()
-
- def reset(self):
- """Set to the initial state (no tracker)."""
- self.states_ = {}
- self.seeds_ = set()
-
- def get_states(self):
- """Get rng states. Copy the dictionary so we have direct
- pointers to the states, not just a pointer to the dictionary."""
- states = {}
- for name in self.states_:
- states[name] = self.states_[name]
- return states
-
- def set_states(self, states):
- """Set the rng states. For efficiency purposes, we do not check
- the size of seed for compatibility."""
- self.states_ = states
-
- def add(self, name, seed):
- """Track the rng state."""
- # Check seed is not already used.
- if seed in self.seeds_:
- raise Exception('seed {} already exists'.format(seed))
- self.seeds_.add(seed)
- # Check that state is not already defined.
- if name in self.states_:
- raise Exception('cuda rng state {} already exists'.format(name))
- # Get the current rng state.
- orig_rng_state = torch.cuda.get_rng_state()
- # Set the new state and store it.
- torch.cuda.manual_seed(seed)
- self.states_[name] = torch.cuda.get_rng_state()
- # Reset rng state to what it was.
- _set_cuda_rng_state(orig_rng_state)
-
- @contextlib.contextmanager
- def fork(self, name=_MODEL_PARALLEL_RNG_TRACKER_NAME):
- """Fork the cuda rng state, perform operations, and exit with
- the original state."""
- # Check if we have added the state
- if name not in self.states_:
- raise Exception('cuda rng state {} is not added'.format(name))
- # Store current rng state.
- orig_cuda_rng_state = torch.cuda.get_rng_state()
- # Set rng state to the desired one
- _set_cuda_rng_state(self.states_[name])
- # Do the stuff we wanted to do.
- try:
- yield
- finally:
- # Update the current rng state for later use.
- self.states_[name] = torch.cuda.get_rng_state()
- # And set the state to the original state we started with.
- _set_cuda_rng_state(orig_cuda_rng_state)
-
-
-# RNG tracker object.
-_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
-
-
-def get_cuda_rng_tracker():
- """Get cuda rng tracker."""
- return _CUDA_RNG_STATE_TRACKER
-
-
-def model_parallel_cuda_manual_seed(seed):
- """Initialize model parallel cuda seed.
-
- This function should be called after the model parallel is
- initialized. Also, no torch.cuda.manual_seed should be called
- after this function. Basically, this is a replacement for that
- function.
- Two sets of RNG states are tracked:
- default state: This is for data parallelism and is the same among a
- set of model parallel GPUs but different across
- different model parallel groups. This is used for
- example for dropout in the non-tensor-model-parallel regions.
- tensor-model-parallel state: This state is different among a set of model
- parallel GPUs, but the same across data parallel
- groups. This is used for example for dropout in
- model parallel regions.
- """
- # 2718 is just for fun and any POSITIVE value will work.
- offset = seed + 2718
- tensor_model_parallel_seed = offset + get_tensor_model_parallel_rank()
- # Data parallel gets the original seed.
- data_parallel_seed = seed
-
- _CUDA_RNG_STATE_TRACKER.reset()
- # Set the default state.
- torch.cuda.manual_seed(data_parallel_seed)
- _CUDA_RNG_STATE_TRACKER.add(_DATA_PARALLEL_RNG_TRACKER_NAME, data_parallel_seed)
-
- # and model parallel state.
- _CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME, tensor_model_parallel_seed)
-
- expert_parallel_seed = (
- seed + 1024 + 100 * get_expert_model_parallel_rank() + get_tensor_model_parallel_rank()
- )
- _CUDA_RNG_STATE_TRACKER.add(_EXPERT_PARALLEL_RNG_TRACKER_NAME, expert_parallel_seed)
-
-
-class CheckpointFunction(torch.autograd.Function):
- """This function is adapted from torch.utils.checkpoint with
- two main changes:
- 1) torch.cuda.set_rng_state is replaced with `_set_cuda_rng_state`
- 2) the states in the model parallel tracker are also properly
- tracked/set/reset.
- """
-
- @staticmethod
- def forward(ctx, run_function, distribute_saved_activations, *args):
- ctx.run_function = run_function
- ctx.distribute_saved_activations = distribute_saved_activations
-
- # Copy the rng states.
- ctx.fwd_cpu_rng_state = torch.get_rng_state()
- ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state()
- ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()
-
- with torch.no_grad():
- outputs = run_function(*args)
-
- # Divide hidden states across model parallel group and only keep
- # the chunk corresponding to the current rank.
- if distribute_saved_activations:
- ctx.input_0_shape = args[0].data.shape
- safely_set_viewless_tensor_data(
- args[0], split_tensor_into_1d_equal_chunks(args[0].data, new_buffer=True)
- )
-
- # Store everything.
- ctx.save_for_backward(*args)
-
- return outputs
-
- @staticmethod
- def backward(ctx, *args):
- if not torch.autograd._is_checkpoint_valid():
- raise RuntimeError(
- "Checkpointing is not compatible with .grad(), "
- "please use .backward() if possible"
- )
- inputs = ctx.saved_tensors
- if ctx.distribute_saved_activations:
- safely_set_viewless_tensor_data(
- inputs[0], gather_split_1d_tensor(inputs[0].data).view(ctx.input_0_shape)
- )
-
- # Store the current states.
- bwd_cpu_rng_state = torch.get_rng_state()
- bwd_cuda_rng_state = torch.cuda.get_rng_state()
- bwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()
-
- # Set the states to what it used to be before the forward pass.
- torch.set_rng_state(ctx.fwd_cpu_rng_state)
- _set_cuda_rng_state(ctx.fwd_cuda_rng_state)
- get_cuda_rng_tracker().set_states(ctx.fwd_cuda_rng_state_tracker)
-
- # Compute the forward pass.
- detached_inputs = detach_variable(inputs)
- with torch.enable_grad():
- outputs = ctx.run_function(*detached_inputs)
-
- # Set the states back to what it was at the start of this function.
- torch.set_rng_state(bwd_cpu_rng_state)
- _set_cuda_rng_state(bwd_cuda_rng_state)
- get_cuda_rng_tracker().set_states(bwd_cuda_rng_state_tracker)
-
- if isinstance(outputs, torch.Tensor):
- outputs = (outputs,)
- torch.autograd.backward(outputs, args)
- grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp for inp in detached_inputs)
- return (None, None) + grads
-
-
-def checkpoint(function, distribute_saved_activations, *args):
- """Checkpoint a model or part of the model.
- This has been directly copied from torch.utils.checkpoint."""
- return CheckpointFunction.apply(function, distribute_saved_activations, *args)
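The core idea of the removed `CudaRNGStatesTracker` is to keep one RNG stream per name, fork into it, run code, and restore the default stream afterwards. A CPU-only sketch of that pattern using the global torch generator (class and names are illustrative, not the library's API):

```python
import contextlib
import torch

class RngStateTracker:
    """CPU sketch of the tracker idea: seed a named stream once, then fork()
    swaps that stream in, runs code under it, saves the advanced state back,
    and restores the default RNG on exit."""

    def __init__(self):
        self.states = {}

    def add(self, name, seed):
        orig = torch.get_rng_state()
        torch.manual_seed(seed)
        self.states[name] = torch.get_rng_state()
        torch.set_rng_state(orig)          # leave the default stream untouched

    @contextlib.contextmanager
    def fork(self, name):
        orig = torch.get_rng_state()
        torch.set_rng_state(self.states[name])
        try:
            yield
        finally:
            self.states[name] = torch.get_rng_state()
            torch.set_rng_state(orig)

tracker = RngStateTracker()
tracker.add("model-parallel-rng", seed=1234)
with tracker.fork("model-parallel-rng"):
    sample = torch.rand(3)   # reproducible wherever the same seed was added
```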
diff --git a/megatron/core/tensor_parallel/utils.py b/megatron/core/tensor_parallel/utils.py
deleted file mode 100644
index a79ae1e87e33463ab6d54ffef8c73ad214cff3ef..0000000000000000000000000000000000000000
--- a/megatron/core/tensor_parallel/utils.py
+++ /dev/null
@@ -1,113 +0,0 @@
-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-
-from typing import List, Sequence
-
-import torch
-
-from megatron.core import parallel_state
-from megatron.core.utils import divide
-
-
-def split_tensor_along_last_dim(
- tensor: torch.Tensor, num_partitions: int, contiguous_split_chunks: bool = False,
-) -> List[torch.Tensor]:
- """ Split a tensor along its last dimension.
-
- Arguments:
- tensor: input tensor.
- num_partitions: number of partitions to split the tensor
- contiguous_split_chunks: If True, make each chunk contiguous
- in memory.
-
- Returns:
- A list of Tensors
- """
- # Get the size and dimension.
- last_dim = tensor.dim() - 1
- last_dim_size = divide(tensor.size()[last_dim], num_partitions)
- # Split.
- tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
- # Note: torch.split does not create contiguous tensors by default.
- if contiguous_split_chunks:
- return tuple(chunk.contiguous() for chunk in tensor_list)
-
- return tensor_list
-
-
-def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False):
- """ Break a tensor into equal 1D chunks across tensor parallel ranks.
-
- Returns a Tensor or View with this rank's portion of the data.
-
- Arguments:
- tensor: The tensor to split
-
- Keyword Arguments:
- new_buffer (bool): If True, returns a new Tensor.
- If False, returns a view into the existing Tensor.
- Default is False
-
- """
- partition_size = torch.numel(tensor) // parallel_state.get_tensor_model_parallel_world_size()
- start_index = partition_size * parallel_state.get_tensor_model_parallel_rank()
- end_index = start_index + partition_size
- if new_buffer:
- data = torch.empty(
- partition_size,
- dtype=tensor.dtype,
- device=torch.cuda.current_device(),
- requires_grad=False,
- )
- data.copy_(tensor.view(-1)[start_index:end_index])
- else:
- data = tensor.view(-1)[start_index:end_index]
- return data
-
-
-def gather_split_1d_tensor(tensor):
- """ Opposite of split_tensor_into_1d_equal_chunks. Gather values from tensor
- model parallel ranks.
-
- Returns a new Tensor with the gathered data.
-
- Arguments:
- tensor: A Tensor or view of this rank's portion of the data.
- """
- numel_gathered = torch.numel(tensor) * parallel_state.get_tensor_model_parallel_world_size()
- gathered = torch.empty(
- numel_gathered, dtype=tensor.dtype, device=torch.cuda.current_device(), requires_grad=False
- )
- # TODO: This API is experimental in pytorch (as of Feb 2022) and
- # this might break in future pytorch releases. We chose this API
- # as opposed to torch.distributed.all_gather for efficiency reasons.
- # This API calls NCCL all-gather directly, whereas the former does
- # internal copies and can potentially cause a slowdown.
- torch.distributed._all_gather_base(
- gathered, tensor, group=parallel_state.get_tensor_model_parallel_group()
- )
- return gathered
-
-
-class VocabUtility:
- """ Split the vocabulary into `world_size` chunks and return the first
- and last index of the vocabulary belonging to the `rank`
- partition. Note that indices are in [first, last).
-
- """
-
- @staticmethod
- def vocab_range_from_per_partition_vocab_size(
- per_partition_vocab_size: int, rank, world_size: int
- ) -> Sequence[int]:
- index_f = rank * per_partition_vocab_size
- index_l = index_f + per_partition_vocab_size
- return index_f, index_l
-
- @staticmethod
- def vocab_range_from_global_vocab_size(
- global_vocab_size: int, rank: int, world_size: int
- ) -> Sequence[int]:
- per_partition_vocab_size = divide(global_vocab_size, world_size)
- return VocabUtility.vocab_range_from_per_partition_vocab_size(
- per_partition_vocab_size, rank, world_size
- )
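The removed `VocabUtility` simply hands each tensor-parallel rank a contiguous `[first, last)` slice of the vocabulary. A minimal standalone sketch of that arithmetic (function name is illustrative; `divide` in the original additionally asserts exact divisibility):

```python
def vocab_range(global_vocab_size: int, rank: int, world_size: int):
    """Contiguous [first, last) vocabulary slice owned by `rank`."""
    per_partition = global_vocab_size // world_size
    first = rank * per_partition
    return first, first + per_partition

# Example: 50304 tokens over 8 tensor-parallel ranks -> 6288 tokens per rank.
assert vocab_range(50304, rank=3, world_size=8) == (18864, 25152)
```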
diff --git a/megatron/core/transformer/__init__.py b/megatron/core/transformer/__init__.py
deleted file mode 100644
index 7cc10776b7459542eb35ecb5e768dbef9bd54d05..0000000000000000000000000000000000000000
--- a/megatron/core/transformer/__init__.py
+++ /dev/null
@@ -1,6 +0,0 @@
-# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
-
-from .module import MegatronModule
-from .spec_utils import ModuleSpec, build_module
-from .transformer_config import TransformerConfig
-from .transformer_layer import TransformerLayer, TransformerLayerSubmodules
diff --git a/megatron/core/transformer/attention.py b/megatron/core/transformer/attention.py
deleted file mode 100644
index c725c7f3a20b9276b8d58987ea5c47cce136a8d0..0000000000000000000000000000000000000000
--- a/megatron/core/transformer/attention.py
+++ /dev/null
@@ -1,443 +0,0 @@
-# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
-
-from abc import ABC, abstractmethod
-from dataclasses import dataclass
-from typing import Union
-
-import torch
-
-from megatron.core import parallel_state, tensor_parallel
-from megatron.core.models.common.embeddings.rotary_pos_embedding import apply_rotary_pos_emb
-from megatron.core.transformer.enums import AttnMaskType
-from megatron.core.transformer.identity_op import IdentityFuncOp, IdentityOp
-from megatron.core.transformer.module import MegatronModule
-from megatron.core.transformer.spec_utils import ModuleSpec, build_module
-from megatron.core.transformer.transformer_config import TransformerConfig
-from megatron.core.utils import divide
-
-from .enums import AttnMaskType
-from .transformer_config import TransformerConfig
-from .utils import make_sharded_tensors_for_checkpoint
-
-
-@dataclass
-class SelfAttentionSubmodules:
- linear_qkv: Union[ModuleSpec, type] = None
- core_attention: Union[ModuleSpec, type] = None
- linear_proj: Union[ModuleSpec, type] = None
-
-
-@dataclass
-class CrossAttentionSubmodules:
- linear_q: Union[ModuleSpec, type] = None
- linear_kv: Union[ModuleSpec, type] = None
- core_attention: Union[ModuleSpec, type] = None
- linear_proj: Union[ModuleSpec, type] = None
-
-
-class Attention(MegatronModule, ABC):
- """Attention layer abstract class.
-
- This layer only contains common modules required for the "self attn" and
- "cross attn" specializations.
- """
-
- def __init__(
- self,
- config: TransformerConfig,
- submodules: Union[SelfAttentionSubmodules, CrossAttentionSubmodules],
- layer_number: int,
- attn_mask_type: AttnMaskType,
- attention_type: str,
- ):
- super().__init__(config=config)
-
- self.config = config
- self.layer_number = layer_number
- self.attn_mask_type = attn_mask_type
- self.attention_type = attention_type
-
- # For normal attention without groups, num_query_groups == num_attention_heads,
- # so these two will be the same
- self.query_projection_size = self.config.kv_channels * self.config.num_attention_heads
- self.kv_projection_size = self.config.kv_channels * self.config.num_query_groups
-
- # Per attention head and per partition values.
- world_size = parallel_state.get_tensor_model_parallel_world_size()
- self.hidden_size_per_attention_head = divide(
- self.query_projection_size, self.config.num_attention_heads
- )
- self.num_attention_heads_per_partition = divide(self.config.num_attention_heads, world_size)
- self.num_query_groups_per_partition = divide(self.config.num_query_groups, world_size)
-
- self.core_attention = build_module(
- submodules.core_attention,
- config=self.config,
- layer_number=self.layer_number,
- attn_mask_type=self.attn_mask_type,
- attention_type=self.attention_type,
- )
-
- self.checkpoint_core_attention = self.config.recompute_granularity == 'selective'
-
- # Output.
- self.linear_proj = build_module(
- submodules.linear_proj,
- self.query_projection_size,
- self.config.hidden_size,
- config=self.config,
- init_method=self.config.output_layer_init_method,
- bias=self.config.add_bias_linear,
- input_is_parallel=True,
- skip_bias_add=True,
- is_expert=False,
- tp_comm_buffer_name='proj',
- )
-
- def _checkpointed_attention_forward(
- self, query, key, value, attention_mask, rotary_pos_emb=None, attn_mask_type=None
- ):
- """Forward method with selective activation checkpointing."""
-
- def custom_forward(*inputs):
- query = inputs[0]
- key = inputs[1]
- value = inputs[2]
- attention_mask = inputs[3]
- attn_mask_type = inputs[5]
- attn_mask_type = AttnMaskType(attn_mask_type.item())
- output_ = self.core_attention(
- query, key, value, attention_mask, attn_mask_type=attn_mask_type
- )
- return output_
-
- if attn_mask_type is None:
- attn_mask_type = self.attn_mask_type
- attn_mask_type = torch.tensor([attn_mask_type.value], dtype=torch.int)
- hidden_states = tensor_parallel.checkpoint(
- custom_forward, False, query, key, value, attention_mask, rotary_pos_emb, attn_mask_type
- )
-
- return hidden_states
-
- def _allocate_memory(self, inference_max_sequence_length, batch_size, dtype):
- """Allocate memory to store kv cache during inference."""
-
- return torch.empty(
- inference_max_sequence_length,
- batch_size,
- self.num_query_groups_per_partition,
- self.hidden_size_per_attention_head,
- dtype=dtype,
- device=torch.cuda.current_device(),
- )
-
- def _adjust_key_value_for_inference(self, inference_params, key, value, rotary_pos_emb):
- """
- Saves the generated key and value tensors to the end of the buffers in inference_params.
- Returns the full size keys and values from the provided inference_params, as well as
- adjusted rotary_pos_emb.
-
- Returns a tuple: (key, value, rotary_pos_emb, attn_mask_type)
-
- """
- attn_mask_type = self.attn_mask_type
- if inference_params is None:
- return key, value, rotary_pos_emb, attn_mask_type
-
- # =================================================
- # Pre-allocate memory for key-values for inference.
- # =================================================
- is_first_step = False
- if self.layer_number not in inference_params.key_value_memory_dict:
- inf_max_seq_length = inference_params.max_sequence_length
- inf_max_batch_size = inference_params.max_batch_size
- inference_key_memory = self._allocate_memory(
- inf_max_seq_length, inf_max_batch_size, key.dtype
- )
- inference_value_memory = self._allocate_memory(
- inf_max_seq_length, inf_max_batch_size, value.dtype
- )
- inference_params.key_value_memory_dict[self.layer_number] = (
- inference_key_memory,
- inference_value_memory,
- )
- is_first_step = True
- else:
- # Get the pre-allocated buffers for this layer
- inference_key_memory, inference_value_memory = inference_params.key_value_memory_dict[
- self.layer_number
- ]
- attn_mask_type = AttnMaskType.no_mask
-
- batch_start = inference_params.batch_size_offset
- batch_end = batch_start + key.size(1)
- assert batch_end <= inference_key_memory.size(1)
- sequence_start = inference_params.sequence_len_offset
- sequence_end = sequence_start + key.size(0)
- assert sequence_end <= inference_key_memory.size(0)
- # Copy key and values.
- inference_key_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = key
- inference_value_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = value
- key = inference_key_memory[:sequence_end, batch_start:batch_end, ...]
- value = inference_value_memory[:sequence_end, batch_start:batch_end, ...]
-
- # adjust the key rotary positional embedding
- if rotary_pos_emb is not None:
- q_pos_emb, k_pos_emb = rotary_pos_emb
- # need to cross check this condition during inference
- # if not set_inference_key_value_memory:
- if not is_first_step:
- # In inference, we compute one token at a time.
- # Select the correct positional embedding
- # (only the last token in the sequence)
- q_pos_emb = q_pos_emb[sequence_end - 1 : sequence_end]
- else:
- # In the first forward pass of inference,
- # we use the entire provided prefix.
- # q_pos_emb here has the rope embeddings of the entire
- # prefix + to-be-generated output so
- # we slice to just the prefix.
- q_pos_emb = q_pos_emb[:sequence_end, :, :, :]
- k_pos_emb = k_pos_emb[:sequence_end, :, :, :]
- rotary_pos_emb = (q_pos_emb, k_pos_emb)
-
- return key, value, rotary_pos_emb, attn_mask_type
-
- @abstractmethod
- def get_query_key_value_tensors(self, hidden_states, key_value_states):
- """
- This method needs to be implemented based on whether the derived class
- is "self-attn" or "cross-attn".
- """
-
- def forward(
- self,
- hidden_states,
- attention_mask,
- key_value_states=None,
- inference_params=None,
- rotary_pos_emb=None,
- ):
- # hidden_states: [sq, b, h]
-
- # For self attention we just duplicate the rotary_pos_emb if it isn't already
- if rotary_pos_emb is not None and not isinstance(rotary_pos_emb, tuple):
- rotary_pos_emb = (rotary_pos_emb,) * 2
-
- # =====================
- # Query, Key, and Value
- # =====================
- # Get the query, key and value tensors based on the type of attention -
- # self or cross attn.
- query, key, value = self.get_query_key_value_tensors(hidden_states, key_value_states)
-
- # ===================================================
- # Adjust key, value, and rotary_pos_emb for inference
- # ===================================================
- key, value, rotary_pos_emb, attn_mask_type = self._adjust_key_value_for_inference(
- inference_params, key, value, rotary_pos_emb
- )
-
- # ================================================
- # relative positional embedding (rotary embedding)
- # ================================================
- if rotary_pos_emb is not None:
- q_pos_emb, k_pos_emb = rotary_pos_emb
- query = apply_rotary_pos_emb(query, q_pos_emb)
- key = apply_rotary_pos_emb(key, k_pos_emb)
- # TODO, can apply positional embedding to value_layer so it has
- # absolute positional embedding.
- # otherwise, only relative positional embedding takes effect
- # value_layer = apply_rotary_pos_emb(value_layer, k_pos_emb)
-
- # ==================================
- # core attention computation
- # ==================================
-
- if self.checkpoint_core_attention:
- core_attn_out = self._checkpointed_attention_forward(
- query, key, value, attention_mask, attn_mask_type=attn_mask_type
- )
- else:
- core_attn_out = self.core_attention(
- query, key, value, attention_mask, attn_mask_type=attn_mask_type
- )
-
- # =================
- # Output. [sq, b, h]
- # =================
-
- output, bias = self.linear_proj(core_attn_out)
-
- return output, bias
-
-
-class SelfAttention(Attention):
- """Self-attention layer class
-
- Self-attention layer takes input with size [s, b, h]
- and returns output of the same size.
- """
-
- def __init__(
- self,
- config: TransformerConfig,
- submodules: SelfAttentionSubmodules,
- layer_number: int,
- attn_mask_type=AttnMaskType.padding,
- ):
- super().__init__(
- config=config,
- submodules=submodules,
- layer_number=layer_number,
- attn_mask_type=attn_mask_type,
- attention_type="self",
- )
-
- self.linear_qkv = build_module(
- submodules.linear_qkv,
- self.config.hidden_size,
- self.query_projection_size + 2 * self.kv_projection_size,
- config=self.config,
- init_method=self.config.init_method,
- gather_output=False,
- bias=self.config.add_bias_linear,
- skip_bias_add=False,
- is_expert=False,
- tp_comm_buffer_name='qkv',
- )
-
- def get_query_key_value_tensors(self, hidden_states, key_value_states=None):
- """
- Derives `query`, `key` and `value` tensors from `hidden_states`.
- """
- # Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn)]
- mixed_qkv, _ = self.linear_qkv(hidden_states)
-
- # [sq, b, hp] --> [sq, b, ng, (np/ng + 2) * hn]
- new_tensor_shape = mixed_qkv.size()[:-1] + (
- self.num_query_groups_per_partition,
- (
- (self.num_attention_heads_per_partition // self.num_query_groups_per_partition + 2)
- * self.hidden_size_per_attention_head
- ),
- )
- mixed_qkv = mixed_qkv.view(*new_tensor_shape)
-
- # [sq, b, ng, (np/ng + 2) * hn] --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn]
- (query, key, value) = torch.split(
- mixed_qkv,
- [
- (
- self.num_attention_heads_per_partition
- // self.num_query_groups_per_partition
- * self.hidden_size_per_attention_head
- ),
- self.hidden_size_per_attention_head,
- self.hidden_size_per_attention_head,
- ],
- dim=3,
- )
- # [sq, b, ng, np/ng * hn] -> [sq, b, np, hn]
- query = query.reshape(query.size(0), query.size(1), -1, self.hidden_size_per_attention_head)
-
- return query, key, value
-
- def sharded_state_dict(self, prefix='', sharded_key_prefix=None, sharded_offsets=()):
- sharded_key_prefix = prefix if sharded_key_prefix is None else sharded_key_prefix
- sharded_state_dict = {}
- for name, module in (
- ('linear_qkv', self.linear_qkv),
- ('linear_proj', self.linear_proj),
- ):
- sub_sd = module.sharded_state_dict(
- prefix=f'{prefix}{name}.',
- sharded_key_prefix=f'{sharded_key_prefix}{name}.',
- sharded_offsets=sharded_offsets,
- )
- sharded_state_dict.update(sub_sd)
- return sharded_state_dict
-
-
-class CrossAttention(Attention):
- """Cross-attention layer class
-
- Cross-attention layer takes input with size [s, b, h] and context with size
- [s, b, h] and returns output of the same size.
- """
-
- def __init__(
- self,
- config: TransformerConfig,
- submodules: CrossAttentionSubmodules,
- layer_number: int,
- attn_mask_type=AttnMaskType.padding,
- ):
- super().__init__(
- config=config,
- submodules=submodules,
- layer_number=layer_number,
- attn_mask_type=attn_mask_type,
- attention_type="cross",
- )
-
- if self.config.num_query_groups != self.config.num_attention_heads:
- raise ValueError(
- f"Group query attention is not currently supported in cross attention."
- )
- assert self.query_projection_size == self.kv_projection_size
-
- self.linear_q = build_module(
- submodules.linear_q,
- self.config.hidden_size,
- self.query_projection_size,
- config=self.config,
- init_method=self.config.init_method,
- gather_output=False,
- bias=self.config.add_bias_linear,
- skip_bias_add=False,
- is_expert=False,
- )
-
- self.linear_kv = build_module(
- submodules.linear_kv,
- self.config.hidden_size,
- 2 * self.kv_projection_size,
- config=self.config,
- init_method=self.config.init_method,
- gather_output=False,
- bias=self.config.add_bias_linear,
- skip_bias_add=False,
- is_expert=False,
- )
-
- def get_query_key_value_tensors(self, hidden_states, key_value_states):
- """
- Derives `query` tensor from `hidden_states`, and `key`/`value` tensors
- from `key_value_states`.
- """
- # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
- mixed_kv, _ = self.linear_kv(key_value_states)
-
- # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
- new_tensor_shape = mixed_kv.size()[:-1] + (
- self.num_attention_heads_per_partition,
- 2 * self.hidden_size_per_attention_head,
- )
- mixed_kv = mixed_kv.view(*new_tensor_shape)
-
- # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
- (key, value) = tensor_parallel.split_tensor_along_last_dim(mixed_kv, 2)
-
- # Attention head [sq, b, h] --> [sq, b, hp]
- query, _ = self.linear_q(hidden_states)
-
- # [sq, b, hp] --> [sq, b, np, hn]
- new_tensor_shape = query.size()[:-1] + (
- self.num_attention_heads_per_partition,
- self.hidden_size_per_attention_head,
- )
- query = query.view(*new_tensor_shape)
-
- return query, key, value
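The fused QKV split in the removed `SelfAttention.get_query_key_value_tensors` relies on shape bookkeeping for grouped-query attention: each of the `ng` query groups carries `np/ng` query heads plus one shared key and one shared value head. A small sketch of that arithmetic with illustrative sizes (not part of the patch):

```python
import torch

# np = attention heads, ng = query groups, hn = per-head hidden size.
sq, b, np_, ng, hn = 5, 2, 8, 2, 16
mixed_qkv = torch.randn(sq, b, ng * (np_ // ng + 2) * hn)

# [sq, b, ng * (np/ng + 2) * hn] -> [sq, b, ng, (np/ng + 2) * hn]
mixed_qkv = mixed_qkv.view(sq, b, ng, (np_ // ng + 2) * hn)

# Each group splits into np/ng query heads plus one key head and one value head.
query, key, value = torch.split(mixed_qkv, [np_ // ng * hn, hn, hn], dim=3)
query = query.reshape(sq, b, -1, hn)

assert query.shape == (sq, b, np_, hn)
assert key.shape == value.shape == (sq, b, ng, hn)
```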
diff --git a/megatron/core/transformer/custom_layers/__init__.py b/megatron/core/transformer/custom_layers/__init__.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/megatron/core/transformer/custom_layers/transformer_engine.py b/megatron/core/transformer/custom_layers/transformer_engine.py
deleted file mode 100644
index d784184623a877480b4f945ebf6e44661e2fc1a3..0000000000000000000000000000000000000000
--- a/megatron/core/transformer/custom_layers/transformer_engine.py
+++ /dev/null
@@ -1,431 +0,0 @@
-import os
-from importlib.metadata import version
-from typing import Callable
-
-import torch
-import transformer_engine as te
-from pkg_resources import packaging
-from torch import Tensor
-
-from megatron.core import ModelParallelConfig
-from megatron.core.parallel_state import (
- get_context_parallel_global_ranks,
- get_context_parallel_group,
- get_tensor_model_parallel_group,
-)
-from megatron.core.tensor_parallel import get_cuda_rng_tracker
-from megatron.core.transformer.enums import AttnMaskType
-from megatron.core.transformer.transformer_config import TransformerConfig
-from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint
-
-
-def _get_extra_te_kwargs(config: TransformerConfig):
- extra_transformer_engine_kwargs = {
- "params_dtype": config.params_dtype,
- }
-
- te_version = packaging.version.Version(version("transformer-engine"))
- if te_version >= packaging.version.Version("0.12.0"):
- if config.use_cpu_initialization:
- extra_transformer_engine_kwargs["device"] = 'cpu'
- else:
- extra_transformer_engine_kwargs["device"] = torch.cuda.current_device()
- return extra_transformer_engine_kwargs
-
-
-class TENorm:
- """
- A conditional wrapper to initialize an instance of Transformer-Engine's
- `LayerNorm` or `RMSNorm` based on input
- """
-
- # TODO should we ditch normalization config and just use spec to choose LayerNorm vs RMSNorm?
- def __new__(
- cls, config: TransformerConfig, hidden_size: int, eps: float = 1e-5,
- ):
- if config.normalization == "LayerNorm":
- instance = te.pytorch.LayerNorm(
- hidden_size=hidden_size,
- eps=eps,
- sequence_parallel=config.sequence_parallel,
- zero_centered_gamma=config.layernorm_zero_centered_gamma,
- **_get_extra_te_kwargs(config),
- )
- elif config.normalization == "RMSNorm":
- assert hasattr(
- te.pytorch, "RMSNorm"
- ), "Transformer-Engine >= v0.11 required to use this feature"
- instance = te.pytorch.RMSNorm(
- hidden_size=hidden_size,
- eps=eps,
- sequence_parallel=config.sequence_parallel,
- zero_centered_gamma=config.layernorm_zero_centered_gamma,
- **_get_extra_te_kwargs(config),
- )
- else:
- raise Exception('Only LayerNorm and RMSNorm are currently supported')
-
- return instance
-
-
-class TELinear(te.pytorch.Linear):
- """
- Wrapper for the Transformer-Engine's `Linear` layer.
-
- Note that if Megatron's parallel_state has not been initialized
- yet, the tp_group passed to TE will be None and must be set later
- via set_tensor_parallel_group().
- """
-
- def __init__(
- self,
- input_size: int,
- output_size: int,
- *,
- parallel_mode: str,
- config: ModelParallelConfig,
- init_method: Callable,
- bias: bool,
- skip_bias_add: bool,
- skip_weight_param_allocation: bool,
- tp_comm_buffer_name: str = None,
- ):
- self.config = config
-
- # TE returns a zero length Tensor when bias=False and
- # return_bias=True, but we prefer None. So in that case we
- # tell TE to not return the bias, and return None
- # ourselves. This way our forward always returns two values
- # and we don't have to deal with the zero length Tensor.
- self.te_return_bias = skip_bias_add and bias
-
- if skip_weight_param_allocation:
- raise ValueError(
- 'Transformer Engine linear layers do not support skip_weight_param_allocation'
- )
-
- extra_kwargs = _get_extra_te_kwargs(config)
-
- te_version = packaging.version.Version(version("transformer-engine"))
- if te_version >= packaging.version.Version("0.8.0"):
- if self.config.tp_comm_overlap:
- extra_kwargs["ub_split_ag"] = self.config.tp_comm_split_ag
- extra_kwargs["ub_split_rs"] = self.config.tp_comm_split_rs
- if te_version > packaging.version.Version("1.0.0"):
- assert (
- tp_comm_buffer_name is not None
- ), "Buffer name should be set to configure communication overlap settings"
- extra_kwargs["ub_name"] = tp_comm_buffer_name
-
- super().__init__(
- in_features=input_size,
- out_features=output_size,
- sequence_parallel=self.config.sequence_parallel,
- fuse_wgrad_accumulation=self.config.gradient_accumulation_fusion,
- tp_group=get_tensor_model_parallel_group(check_initialized=False),
- tp_size=self.config.tensor_model_parallel_size,
- get_rng_state_tracker=get_cuda_rng_tracker,
- init_method=init_method,
- bias=bias,
- return_bias=self.te_return_bias,
- parallel_mode=parallel_mode,
- **extra_kwargs,
- )
-
- def forward(self, x):
- out = super().forward(x)
-
- # TE only returns a tuple when return_bias is True, otherwise
- # it returns a single Tensor, we always want to return two
- # values regardless of the arguments.
- if self.te_return_bias:
- return out
- return out, None
-
-
-class TELayerNormColumnParallelLinear(te.pytorch.LayerNormLinear):
- """
- Wrapper for the Transformer-Engine's `LayerNormLinear` layer that combines
- layernorm and linear layers
- """
-
- def __init__(
- self,
- input_size: int,
- output_size: int,
- *,
- config: TransformerConfig,
- init_method: Callable,
- gather_output: bool,
- bias: bool,
- skip_bias_add: bool,
- is_expert: bool,
- skip_weight_param_allocation: bool = False,
- tp_comm_buffer_name: str = None,
- ):
- self.config = config
-
- if gather_output:
- raise ValueError('Transformer Engine linear layers do not support gather_output = True')
-
- if is_expert:
- raise ValueError('Transformer Engine linear layers do not yet support MoE')
-
- if skip_weight_param_allocation:
- raise ValueError(
- 'Transformer Engine linear layers do not support skip_weight_param_allocation'
- )
-
- # TE returns a zero length Tensor when bias=False and
- # return_bias=True, but we prefer None. So in that case we
- # tell TE to not return the bias, and return None
- # ourselves. This way our forward always returns two values
- # and we don't have to deal with the zero length Tensor.
- self.te_return_bias = skip_bias_add and bias
-
- extra_kwargs = _get_extra_te_kwargs(config)
-
- # Only Transformer-Engine version >= 0.11.0 supports `RMSNorm`
- te_version = packaging.version.Version(version("transformer-engine"))
- if te_version >= packaging.version.Version("0.11.0"):
- extra_kwargs["normalization"] = self.config.normalization
- elif self.config.normalization != "LayerNorm":
- raise ValueError(
- f"Transformer Engine v{te_version} does not support {self.config.normalization}."
- )
-
- if te_version >= packaging.version.Version("0.8.0"):
- if self.config.tp_comm_overlap:
- extra_kwargs["ub_bulk_wgrad"] = self.config.tp_comm_bulk_wgrad
- extra_kwargs["ub_bulk_dgrad"] = self.config.tp_comm_bulk_dgrad
- extra_kwargs["ub_split_ag"] = self.config.tp_comm_split_ag
- if te_version > packaging.version.Version("1.0.0"):
- assert (
- tp_comm_buffer_name is not None
- ), "Buffer name should be set to configure communication overlap settings"
- extra_kwargs["ub_name"] = tp_comm_buffer_name
-
- super().__init__(
- in_features=input_size,
- out_features=output_size,
- eps=self.config.layernorm_epsilon,
- sequence_parallel=self.config.sequence_parallel,
- fuse_wgrad_accumulation=self.config.gradient_accumulation_fusion,
- tp_group=get_tensor_model_parallel_group(check_initialized=False),
- tp_size=self.config.tensor_model_parallel_size,
- get_rng_state_tracker=get_cuda_rng_tracker,
- init_method=init_method,
- bias=bias,
- return_bias=self.te_return_bias,
- parallel_mode="column",
- return_layernorm_output=False,
- zero_centered_gamma=self.config.layernorm_zero_centered_gamma,
- **extra_kwargs,
- )
-
- def forward(self, x):
- out = super().forward(x)
-
- # TE only returns a tuple when return_bias is True, otherwise
- # it returns a single Tensor, we always want to return two
- # values regardless of the arguments.
- if self.te_return_bias:
- return out
- return out, None
-
- def sharded_state_dict(self, prefix='', sharded_key_prefix=None, sharded_offsets=()):
- """ Sharding along axis 0, bias sharded """
- state_dict = self.state_dict(prefix='', keep_vars=True)
- return make_sharded_tensors_for_checkpoint(
- state_dict, prefix, sharded_key_prefix, {'weight': 0, 'bias': 0}, sharded_offsets
- )
-
-
-class TEColumnParallelLinear(TELinear):
- """
- Wrapper for the Transformer-Engine's `Linear` layer but specialized similar
- to megatron's `ColumnParallelLinear` layer.
- """
-
- def __init__(
- self,
- input_size: int,
- output_size: int,
- *,
- config: ModelParallelConfig,
- init_method: Callable,
- gather_output: bool,
- bias: bool,
- skip_bias_add: bool,
- is_expert: bool,
- skip_weight_param_allocation: bool = False,
- tp_comm_buffer_name: str = None,
- ):
- if gather_output:
- raise ValueError('Transformer Engine linear layers do not support gather_output = True')
-
- if is_expert:
- raise ValueError('Transformer Engine linear layers do not yet support MoE')
-
- super().__init__(
- input_size=input_size,
- output_size=output_size,
- parallel_mode="column",
- config=config,
- init_method=init_method,
- bias=bias,
- skip_bias_add=skip_bias_add,
- skip_weight_param_allocation=skip_weight_param_allocation,
- tp_comm_buffer_name=tp_comm_buffer_name,
- )
-
- def sharded_state_dict(self, prefix='', sharded_key_prefix=None, sharded_offsets=()):
- """ Sharding along axis 0, bias sharded """
- state_dict = self.state_dict(prefix='', keep_vars=True)
- return make_sharded_tensors_for_checkpoint(
- state_dict, prefix, sharded_key_prefix, {'weight': 0, 'bias': 0}, sharded_offsets
- )
-
-
-class TERowParallelLinear(TELinear):
- """
- Wrapper for the Transformer-Engine's `Linear` layer but specialized similar
- to megatron's `RowParallelLinear` layer.
- """
-
- def __init__(
- self,
- input_size: int,
- output_size: int,
- *,
- config: ModelParallelConfig,
- init_method: Callable,
- bias: bool,
- input_is_parallel: bool,
- skip_bias_add: bool,
- is_expert: bool,
- tp_comm_buffer_name: str = None,
- ):
- if not input_is_parallel:
- raise ValueError(
- "Transformer Engine linear layers do not support input_is_parallel = False"
- )
-
- if is_expert:
- raise ValueError('Transformer Engine linear layers do not yet support MoE')
-
- super().__init__(
- input_size=input_size,
- output_size=output_size,
- parallel_mode="row",
- config=config,
- init_method=init_method,
- bias=bias,
- skip_bias_add=skip_bias_add,
- skip_weight_param_allocation=False, # We don't currently use this for row parallel layers
- tp_comm_buffer_name=tp_comm_buffer_name,
- )
-
- def sharded_state_dict(self, prefix='', sharded_key_prefix=None, sharded_offsets=()):
- """ Sharding along axis 1, bias not sharded """
- state_dict = self.state_dict(prefix='', keep_vars=True)
- return make_sharded_tensors_for_checkpoint(
- state_dict, prefix, sharded_key_prefix, {'weight': 1}, sharded_offsets
- )
-
-
-class TEDotProductAttention(te.pytorch.DotProductAttention):
- """
- Wrapper for the Transformer-Engine's `DotProductAttention` layer that also
- has "flash attention" enabled.
-
- Note that if Megatron's parallel_state has not been initialized yet, the
- tp_group and cp_group passed to TE will be None and must be set later
- via set_tensor_parallel_group() and set_context_parallel_group().
- """
-
- cp_stream: torch.cuda.Stream = None
-
- def __init__(
- self,
- config: TransformerConfig,
- layer_number: int,
- attn_mask_type: AttnMaskType,
- attention_type: str,
- attention_dropout: float = None,
- ):
- self.config = config
- self.te_forward_mask_type = False
-
- if self.config.apply_query_key_layer_scaling != bool(
- int(os.getenv('NVTE_APPLY_QK_LAYER_SCALING', '0'))
- ):
- raise ValueError(
- f"apply_query_key_layer_scaling is {self.config.apply_query_key_layer_scaling} "
- f"but environment variable NVTE_APPLY_QK_LAYER_SCALING is "
- f"{os.getenv('NVTE_APPLY_QK_LAYER_SCALING')}. Transformer Engine does not support "
- f"setting query key layer scaling via argument, so these two must match."
- )
-
- extra_kwargs = {}
- te_version = packaging.version.Version(version("transformer-engine"))
- if te_version >= packaging.version.Version("0.11.0"):
- extra_kwargs["num_gqa_groups"] = self.config.num_query_groups
- elif self.config.num_query_groups != self.config.num_attention_heads:
- raise ValueError(
- f"Transformer Engine v{te_version} does not support Grouped Query Attention, "
- f"use a newer version of Transformer Engine. "
- f"(num_query_groups ({self.config.num_query_groups}) != "
- f"num_attention_heads ({self.config.num_attention_heads}))"
- )
-
- if te_version >= packaging.version.Version("0.10.0"):
- extra_kwargs["attention_type"] = attention_type
- # older versions don't need attention_type
-
- if te_version > packaging.version.Version("0.12.0"):
- self.te_forward_mask_type = True
-
- # Only Transformer-Engine version >= 1.0.0 supports context parallelism
- if te_version >= packaging.version.Version("1.0.0"):
- if getattr(TEDotProductAttention, "cp_stream") is None:
- TEDotProductAttention.cp_stream = torch.cuda.Stream()
- extra_kwargs["cp_group"] = get_context_parallel_group(check_initialized=False)
- extra_kwargs["cp_global_ranks"] = get_context_parallel_global_ranks(
- check_initialized=False
- )
- extra_kwargs["cp_stream"] = TEDotProductAttention.cp_stream
- else:
- assert (
- self.config.context_parallel_size == 1
- ), "Only Transformer-Engine version >= 1.0.0 supports context parallelism!"
-
- super().__init__(
- num_attention_heads=self.config.num_attention_heads,
- kv_channels=self.config.kv_channels,
- attention_dropout=self.config.attention_dropout
- if attention_dropout is None
- else attention_dropout,
- attn_mask_type=attn_mask_type.name,
- sequence_parallel=self.config.sequence_parallel,
- tp_size=self.config.tensor_model_parallel_size,
- get_rng_state_tracker=get_cuda_rng_tracker,
- tp_group=get_tensor_model_parallel_group(check_initialized=False),
- layer_number=layer_number,
- **extra_kwargs,
- )
-
- def forward(
- self,
- query: Tensor,
- key: Tensor,
- value: Tensor,
- attention_mask: Tensor,
- attn_mask_type: AttnMaskType,
- ):
- if self.te_forward_mask_type:
- return super().forward(
- query, key, value, attention_mask, attn_mask_type=attn_mask_type.name
- )
- else:
- return super().forward(query, key, value, attention_mask)
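The Transformer-Engine wrappers removed above repeatedly apply one pattern: check the installed `transformer-engine` version and only pass newer keyword arguments when the package can accept them. A hedged sketch of that gating pattern using the standalone `packaging` library (the feature/version pairs below are examples taken from the removed file; the function name is illustrative):

```python
from importlib.metadata import PackageNotFoundError, version
from packaging.version import Version  # pkg_resources.packaging in the original

def build_extra_kwargs(pkg: str = "transformer-engine") -> dict:
    """Only include keyword arguments supported by the installed version."""
    extra = {}
    try:
        installed = Version(version(pkg))
    except PackageNotFoundError:
        return extra                      # package absent: fall back to defaults
    if installed >= Version("0.10.0"):
        extra["attention_type"] = "self"  # accepted from 0.10.0 onward
    if installed >= Version("0.11.0"):
        extra["normalization"] = "RMSNorm"  # RMSNorm support added in 0.11
    return extra
```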
diff --git a/megatron/core/transformer/dot_product_attention.py b/megatron/core/transformer/dot_product_attention.py
deleted file mode 100644
index 7eab478bd00a1c4507e52f80465c31f214a47fd1..0000000000000000000000000000000000000000
--- a/megatron/core/transformer/dot_product_attention.py
+++ /dev/null
@@ -1,195 +0,0 @@
-# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
-
-
-import math
-
-import torch
-from torch import Tensor
-
-from megatron.core import parallel_state, tensor_parallel
-from megatron.core.fusions.fused_softmax import FusedScaleMaskSoftmax
-from megatron.core.transformer.enums import AttnMaskType
-from megatron.core.transformer.module import MegatronModule
-from megatron.core.transformer.transformer_config import TransformerConfig
-from megatron.core.transformer.utils import attention_mask_func
-from megatron.core.utils import divide
-
-
-class DotProductAttention(MegatronModule):
- """
- Region where selective activation recomputation is applied.
- This region is memory intensive but less compute intensive which
- makes activation checkpointing more efficient for LLMs (20B+).
- See Reducing Activation Recomputation in Large Transformer Models: https://arxiv.org/abs/2205.05198 for more details.
-
- We use the following notation:
- h: hidden size
- n: number of attention heads
- p: number of tensor model parallel partitions
- b: batch size
- s: sequence length
- """
-
- def __init__(
- self,
- config: TransformerConfig,
- layer_number: int,
- attn_mask_type: AttnMaskType,
- attention_type: str,
- attention_dropout: float = None,
- ):
- super().__init__(config=config)
-
- self.config: TransformerConfig = config
-
- assert (
- self.config.context_parallel_size == 1
- ), "Context parallelism is only supported by TEDotProductAttention!"
-
- self.layer_number = max(1, layer_number)
- self.attn_mask_type = attn_mask_type
- self.attention_type = attention_type # unused for now
-
- projection_size = self.config.kv_channels * self.config.num_attention_heads
-
- # Per attention head and per partition values.
- world_size = parallel_state.get_tensor_model_parallel_world_size()
- self.hidden_size_per_partition = divide(projection_size, world_size)
- self.hidden_size_per_attention_head = divide(projection_size, config.num_attention_heads)
- self.num_attention_heads_per_partition = divide(self.config.num_attention_heads, world_size)
- self.num_query_groups_per_partition = divide(self.config.num_query_groups, world_size)
-
- coeff = None
- self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
- if self.config.apply_query_key_layer_scaling:
- coeff = self.layer_number
- self.norm_factor *= coeff
-
- self.scale_mask_softmax = FusedScaleMaskSoftmax(
- input_in_fp16=self.config.fp16,
- input_in_bf16=self.config.bf16,
- attn_mask_type=self.attn_mask_type,
- scaled_masked_softmax_fusion=self.config.masked_softmax_fusion,
- mask_func=attention_mask_func,
- softmax_in_fp32=self.config.attention_softmax_in_fp32,
- scale=coeff,
- )
-
- # Dropout. Note that for a single iteration, this layer will generate
- # different outputs on different number of parallel partitions but
- # on average it should not be partition dependent.
- self.attention_dropout = torch.nn.Dropout(
- self.config.attention_dropout if attention_dropout is None else attention_dropout
- )
-
- def forward(
- self,
- query: Tensor,
- key: Tensor,
- value: Tensor,
- attention_mask: Tensor,
- attn_mask_type: AttnMaskType = None,
- ):
-
- # ===================================
- # Raw attention scores. [b, n/p, s, s]
- # ===================================
-
- # expand the key and value [sk, b, ng, hn] -> [sk, b, np, hn]
- # This is a noop for normal attention where ng == np. When using group query attention this
- # creates a view that has the keys and values virtually repeated along their dimension to
- # match the number of queries.
-
- # attn_mask_type is not used.
- if self.num_attention_heads_per_partition // self.num_query_groups_per_partition > 1:
- key = key.repeat_interleave(
- self.num_attention_heads_per_partition // self.num_query_groups_per_partition, dim=2
- )
- value = value.repeat_interleave(
- self.num_attention_heads_per_partition // self.num_query_groups_per_partition, dim=2
- )
-
- # [b, np, sq, sk]
- output_size = (
- query.size(1),
- query.size(2),
- query.size(0),
- key.size(0),
- )
-
- # [sq, b, np, hn] -> [sq, b * np, hn]
- # This will be a simple view when doing normal attention, but in group query attention
- # the key and value tensors are repeated to match the queries so you can't use simple strides
- # to extract the queries.
- query = query.reshape(output_size[2], output_size[0] * output_size[1], -1)
- # [sk, b, np, hn] -> [sk, b * np, hn]
- key = key.view(output_size[3], output_size[0] * output_size[1], -1)
-
-        # preallocating input tensor: [b * np, sq, sk]
- matmul_input_buffer = parallel_state.get_global_memory_buffer().get_tensor(
- (output_size[0] * output_size[1], output_size[2], output_size[3]), query.dtype, "mpu",
- )
-
- # Raw attention scores. [b * np, sq, sk]
- matmul_result = torch.baddbmm(
- matmul_input_buffer,
- query.transpose(0, 1), # [b * np, sq, hn]
- key.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
- beta=0.0,
- alpha=(1.0 / self.norm_factor),
- )
-
- # change view to [b, np, sq, sk]
- attention_scores = matmul_result.view(*output_size)
-
- # ===========================
- # Attention probs and dropout
- # ===========================
-
- # attention scores and attention mask [b, np, sq, sk]
- attention_probs: Tensor = self.scale_mask_softmax(attention_scores, attention_mask)
-
- # This is actually dropping out entire tokens to attend to, which might
- # seem a bit unusual, but is taken from the original Transformer paper.
-
- if not self.config.sequence_parallel:
- with tensor_parallel.get_cuda_rng_tracker().fork():
- attention_probs = self.attention_dropout(attention_probs)
- else:
- attention_probs = self.attention_dropout(attention_probs)
-
- # =========================
- # Context layer. [sq, b, hp]
- # =========================
-
- # value -> context layer.
- # [sk, b, np, hn] --> [b, np, sq, hn]
-
- # context layer shape: [b, np, sq, hn]
- output_size = (
- value.size(1),
- value.size(2),
- query.size(0),
- value.size(3),
- )
-
- # change view [sk, b * np, hn]
- value = value.view(value.size(0), output_size[0] * output_size[1], -1)
-
- # change view [b * np, sq, sk]
- attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
-
- # matmul: [b * np, sq, hn]
- context = torch.bmm(attention_probs, value.transpose(0, 1))
-
- # change view [b, np, sq, hn]
- context = context.view(*output_size)
-
- # [b, np, sq, hn] --> [sq, b, np, hn]
- context = context.permute(2, 0, 1, 3).contiguous()
-
- # [sq, b, np, hn] --> [sq, b, hp]
- new_context_shape = context.size()[:-2] + (self.hidden_size_per_partition,)
- context = context.view(*new_context_shape)
-
- return context
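For reference, here is a small shape check of the grouped-query expansion and score computation performed in the `forward` above, using plain `torch` ops; the sizes (and the local names `np_`, `ng`, `hn`) are made up and only mirror the file's notation.

```python
# Keys/values with ng query groups are repeat_interleave'd along the head
# dimension to match np attention heads, then scores are computed batch-wise.
import torch

sq, sk, b, np_, ng, hn = 4, 4, 2, 8, 2, 16     # np_ = attention heads, ng = query groups
query = torch.randn(sq, b, np_, hn)
key = torch.randn(sk, b, ng, hn)

key = key.repeat_interleave(np_ // ng, dim=2)            # [sk, b, np, hn]
q = query.reshape(sq, b * np_, hn).transpose(0, 1)        # [b*np, sq, hn]
k = key.reshape(sk, b * np_, hn).transpose(0, 1)          # [b*np, sk, hn]

scores = torch.bmm(q, k.transpose(1, 2)) / (hn ** 0.5)    # [b*np, sq, sk]
print(scores.shape)                                       # torch.Size([16, 4, 4])
```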
diff --git a/megatron/core/transformer/enums.py b/megatron/core/transformer/enums.py
deleted file mode 100644
index ab72f3536854413443eb56455fe96171aef5a72e..0000000000000000000000000000000000000000
--- a/megatron/core/transformer/enums.py
+++ /dev/null
@@ -1,26 +0,0 @@
-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-
-import enum
-
-
-# can we get rid of this?
-# it's being used in pipeline schedules
-class ModelType(enum.Enum):
- encoder_or_decoder = 1
- encoder_and_decoder = 2
-
-
-# class LayerType(enum.Enum):
-# encoder = 1
-# decoder = 2
-
-
-class AttnType(enum.Enum):
- self_attn = 1
- cross_attn = 2
-
-
-class AttnMaskType(enum.Enum):
- padding = 1
- causal = 2
- no_mask = 3 # only used for TE
diff --git a/megatron/core/transformer/identity_op.py b/megatron/core/transformer/identity_op.py
deleted file mode 100644
index 5d9388ffcc628bdd0f04dd5969b9e669153446a8..0000000000000000000000000000000000000000
--- a/megatron/core/transformer/identity_op.py
+++ /dev/null
@@ -1,28 +0,0 @@
-# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
-import torch
-
-
-class IdentityOp(torch.nn.Module):
- """
- This is a placeholder for IdentityOp(x) -> x
- """
-
- def __init__(self, *args, **kwargs):
- super().__init__()
-
- def forward(self, x, *args, **kwargs):
- return x
-
-
-class IdentityFuncOp(IdentityOp):
- """
- This is a placeholder for IdentityFuncOp(...)(x) -> IdentityOp(x) -> x.
- Such a func is handy for ops like `bias_dropout_fusion` which themselves
- return a function at runtime based on passed arguments
- """
-
- def __init__(self, *args, **kwargs):
- super().__init__()
-
- def forward(self, *args, **kwargs):
- return super().forward
diff --git a/megatron/core/transformer/mlp.py b/megatron/core/transformer/mlp.py
deleted file mode 100644
index 8f5575b72467ed92752fc06ff5eeb015cb97db92..0000000000000000000000000000000000000000
--- a/megatron/core/transformer/mlp.py
+++ /dev/null
@@ -1,184 +0,0 @@
-# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
-
-from dataclasses import dataclass
-from typing import Tuple, Union
-
-import torch
-import torch.nn.functional as F
-
-from megatron.core import parallel_state
-from megatron.core.dist_checkpointing import ShardedTensor
-from megatron.core.dist_checkpointing.mapping import ShardedTensorFactory
-from megatron.core.fusions.fused_bias_gelu import bias_gelu_impl
-from megatron.core.transformer.module import MegatronModule
-from megatron.core.transformer.spec_utils import ModuleSpec, build_module
-from megatron.core.transformer.transformer_config import TransformerConfig
-from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint
-
-
-@dataclass
-class MLPSubmodules:
- linear_fc1: Union[ModuleSpec, type] = None
- linear_fc2: Union[ModuleSpec, type] = None
-
-
-class MLP(MegatronModule):
- """
-    The MLP takes the input with hidden size h, projects it to a 4*h hidden
-    dimension, applies a nonlinear transformation, and projects the state
-    back to hidden size h.
-
-
- Returns an output and a bias to be added to the output.
- If config.add_bias_linear is False, the bias returned is None.
-
- We use the following notation:
- h: hidden size
- p: number of tensor model parallel partitions
- b: batch size
- s: sequence length
- """
-
- def __init__(
- self, config: TransformerConfig, submodules: MLPSubmodules, is_expert: bool = False
- ):
- super().__init__(config=config)
-
- self.config: TransformerConfig = config
-
- # If this is a gated linear unit we double the output width, see https://arxiv.org/pdf/2002.05202.pdf
- ffn_hidden_size = self.config.ffn_hidden_size
- if self.config.gated_linear_unit:
- ffn_hidden_size *= 2
-
- self.linear_fc1 = build_module(
- submodules.linear_fc1,
- self.config.hidden_size,
- ffn_hidden_size,
- config=self.config,
- init_method=self.config.init_method,
- gather_output=False,
- bias=self.config.add_bias_linear,
- skip_bias_add=True,
- is_expert=is_expert,
- tp_comm_buffer_name='fc1',
- )
-
- if self.config.gated_linear_unit:
-
- def glu(x):
- x = torch.chunk(x, 2, dim=-1)
- return self.config.activation_func(x[0]) * x[1]
-
- self.activation_func = glu
- else:
- self.activation_func = self.config.activation_func
-
- self.linear_fc2 = build_module(
- submodules.linear_fc2,
- self.config.ffn_hidden_size,
- self.config.hidden_size,
- config=self.config,
- init_method=self.config.output_layer_init_method,
- bias=self.config.add_bias_linear,
- input_is_parallel=True,
- skip_bias_add=True,
- is_expert=is_expert,
- tp_comm_buffer_name='fc2',
- )
-
- def forward(self, hidden_states):
-
- # [s, b, 4 * h/p]
- intermediate_parallel, bias_parallel = self.linear_fc1(hidden_states)
-
- if self.config.bias_gelu_fusion:
- assert self.config.add_bias_linear is True
- assert self.activation_func == F.gelu
- intermediate_parallel = bias_gelu_impl(intermediate_parallel, bias_parallel)
- else:
- if bias_parallel is not None:
- intermediate_parallel = intermediate_parallel + bias_parallel
- intermediate_parallel = self.activation_func(intermediate_parallel)
-
- # [s, b, h]
- output, output_bias = self.linear_fc2(intermediate_parallel)
-
- return output, output_bias
-
- def sharded_state_dict(self, prefix='', sharded_key_prefix=None, sharded_offsets=()):
- sharded_key_prefix = prefix if sharded_key_prefix is None else sharded_key_prefix
- sharded_state_dict = {}
- for name, module in self._modules.items():
- if name == 'linear_fc1' and self.config.gated_linear_unit:
- sub_sd = self._sharded_state_dict_for_glu(
- name, module, prefix, sharded_key_prefix, sharded_offsets
- )
- else:
- sub_sd = module.sharded_state_dict(
- prefix=f'{prefix}{name}.',
- sharded_key_prefix=f'{sharded_key_prefix}{name}.',
- sharded_offsets=sharded_offsets,
- )
- sharded_state_dict.update(sub_sd)
- return sharded_state_dict
-
- def _sharded_state_dict_for_glu(
- self,
- module_name: str,
- module: torch.nn.Module,
- prefix: str,
- sharded_key_prefix: str,
- sharded_offsets: Tuple[Tuple[int, int, int]],
- ):
- assert module_name == 'linear_fc1', module_name
- sharded_state_dict = module.sharded_state_dict(
- prefix=f'{prefix}{module_name}.',
- sharded_key_prefix=f'{sharded_key_prefix}{module_name}.',
- sharded_offsets=sharded_offsets,
- )
- weight_key = f'{prefix}{module_name}.weight'
- prev_sh_ten = sharded_state_dict[weight_key]
-
- # We must split the tensor into 2 parts, each sharded separately.
- # This requires a ShardedTensorFactory which `chunk`s during saving
- # and `cat`s during loading
- tp_rank = parallel_state.get_tensor_model_parallel_rank()
- tp_size = parallel_state.get_tensor_model_parallel_world_size()
-
- tp_shard_axis = 0
- replica_id = prev_sh_ten.replica_id
- prepend_axis_num = len(sharded_offsets)
-
- def sh_ten_build_fn(key: str, t: torch.Tensor):
- offset_w = (tp_shard_axis + prepend_axis_num, tp_rank, tp_size * 2)
- offset_v = (tp_shard_axis + prepend_axis_num, tp_size + tp_rank, tp_size * 2)
- with torch.no_grad():
- tensor_w, tensor_v = torch.chunk(t, 2, dim=tp_shard_axis)
- return [
- ShardedTensor.from_rank_offsets(
- key,
- tensor_w,
- *sharded_offsets,
- offset_w,
- replica_id=replica_id,
- prepend_axis_num=1,
- ),
- ShardedTensor.from_rank_offsets(
- key,
- tensor_v,
- *sharded_offsets,
- offset_v,
- replica_id=replica_id,
- prepend_axis_num=1,
- ),
- ]
-
- def sh_ten_merge_fn(sub_state_dict):
- with torch.no_grad():
- return torch.cat(sub_state_dict)
-
- sharded_state_dict[weight_key] = ShardedTensorFactory(
- prev_sh_ten.key, prev_sh_ten.data, sh_ten_build_fn, sh_ten_merge_fn
- )
- return sharded_state_dict
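A small, self-contained sketch of the gated-linear-unit path in the `MLP` removed above: fc1 produces `2 * ffn_hidden_size` features, which are chunked and gated as `activation(x[0]) * x[1]` before fc2 projects back to `hidden_size`. The layer sizes and plain `nn.Linear` layers here are placeholders; the real module builds fc1/fc2 through `build_module`.

```python
import torch
import torch.nn.functional as F

hidden_size, ffn_hidden_size = 8, 32
fc1 = torch.nn.Linear(hidden_size, 2 * ffn_hidden_size, bias=False)  # doubled width for GLU
fc2 = torch.nn.Linear(ffn_hidden_size, hidden_size, bias=False)

def glu(x):
    x0, x1 = torch.chunk(x, 2, dim=-1)
    return F.gelu(x0) * x1        # config.activation_func(x[0]) * x[1] in the deleted code

x = torch.randn(4, 2, hidden_size)   # [s, b, h]
out = fc2(glu(fc1(x)))               # [s, b, h]
print(out.shape)                     # torch.Size([4, 2, 8])
```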
diff --git a/megatron/core/transformer/module.py b/megatron/core/transformer/module.py
deleted file mode 100644
index d20074aa07644dd8c556b69646eb3d0ffaf6012e..0000000000000000000000000000000000000000
--- a/megatron/core/transformer/module.py
+++ /dev/null
@@ -1,157 +0,0 @@
-# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
-"""Megatron Module."""
-
-import torch
-from torch.autograd import Variable
-from torch.nn.parameter import Parameter
-
-from megatron.core import parallel_state
-from megatron.core.transformer.transformer_config import TransformerConfig
-
-_FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor)
-_HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor)
-_BF16_TYPES = (torch.BFloat16Tensor, torch.cuda.BFloat16Tensor)
-
-
-def param_is_not_shared(param):
- return not hasattr(param, 'shared') or not param.shared
-
-
-class MegatronModule(torch.nn.Module):
- """Base Megatron module inhertied by all Models.
-
- Megatron specific extensions of torch Module with support
- for pipelining
-
- Args:
- config (TransformerConfig): Transformer config
- """
-
- # def __init__(self, config: TransformerConfig, share_word_embeddings=True):
- def __init__(self, config: TransformerConfig):
- super().__init__()
- self.config = config
-
- def state_dict_for_save_checkpoint(self, prefix: str = '', keep_vars: bool = False):
- """Override state dict for saving checkpoints Use this function to override the
- state dict for saving checkpoints.
-
- Args:
- prefix (str, optional): _description_. Defaults to ''.
- keep_vars (bool, optional): _description_. Defaults to False.
-
- Returns:
- _type_: _description_
- """
-
- return self.state_dict(prefix=prefix, keep_vars=keep_vars)
-
- def sharded_state_dict(self, prefix: str = ''):
- """Override sharded state dict with Dist Checkpointing.
-
- Override sharded_state_dict when using distributed checkpointing. keep_vars must always be set to True so that optimizer states can be sharded.
-
- Args:
- prefix (str, optional): _description_. Defaults to ''.
-
- Returns:
- _type_: _description_
- """
- return self.state_dict(prefix=prefix, keep_vars=True)
-
-
-def conversion_helper(val, conversion):
- if not isinstance(val, (tuple, list)):
- return conversion(val)
- rtn = [conversion_helper(v, conversion) for v in val]
- if isinstance(val, tuple):
- rtn = tuple(rtn)
- return rtn
-
-
-def fp32_to_float16(val, float16_convertor):
- def half_conversion(val):
- val_typecheck = val
- if isinstance(val_typecheck, (Parameter, Variable)):
- val_typecheck = val.data
- if isinstance(val_typecheck, _FLOAT_TYPES):
- val = float16_convertor(val)
- return val
-
- return conversion_helper(val, half_conversion)
-
-
-def float16_to_fp32(val):
- def float_conversion(val):
- val_typecheck = val
- if isinstance(val_typecheck, (Parameter, Variable)):
- val_typecheck = val.data
- if isinstance(val_typecheck, (_BF16_TYPES, _HALF_TYPES)):
- val = val.float()
- return val
-
- return conversion_helper(val, float_conversion)
-
-
-class Float16Module(MegatronModule):
- """Float 16 Module.
-
- Attributes:
- config (TransformerConfig): Transformer config
- fp16 (bool) : Specifies if the model runs in fp16 mode
- bf16 (bool) : Specifies if the model runs in bf16 mode
-
- Args:
-        config (TransformerConfig): The transformer config used to initialize the model
- """
-
- def __init__(self, config: TransformerConfig, module: torch.nn.Module):
- super(Float16Module, self).__init__(config)
- self.config = config
- self.fp16 = config.fp16
- self.bf16 = config.bf16
-
- if self.fp16:
- self.add_module('module', module.half())
-
- def float16_convertor(val):
- return val.half()
-
- elif self.bf16:
- self.add_module('module', module.bfloat16())
-
- def float16_convertor(val):
- return val.bfloat16()
-
- else:
- raise Exception('Either config.fp16 or config.bf16 should be True.')
-
- self.float16_convertor = float16_convertor
-
- def set_input_tensor(self, input_tensor):
- return self.module.set_input_tensor(input_tensor)
-
- def forward(self, *inputs, **kwargs):
- if parallel_state.is_pipeline_first_stage():
- inputs = fp32_to_float16(inputs, self.float16_convertor)
- outputs = self.module(*inputs, **kwargs)
- if parallel_state.is_pipeline_last_stage():
- outputs = float16_to_fp32(outputs)
- return outputs
-
- def state_dict(self, destination=None, prefix='', keep_vars=False):
- return self.module.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
-
- def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
- """Retrieve state_dict from the module being wrapped."""
- return self.module.state_dict_for_save_checkpoint(prefix=prefix, keep_vars=keep_vars)
-
- def sharded_state_dict(self, prefix=''):
- """Retrieve state_dict from the module being wrapped.
-
- When using distributed checkpointing, keep_vars must always be set to True.
- """
- return self.module.sharded_state_dict(prefix=prefix)
-
- def load_state_dict(self, state_dict, strict=True):
- self.module.load_state_dict(state_dict, strict=strict)
diff --git a/megatron/core/transformer/spec_utils.py b/megatron/core/transformer/spec_utils.py
deleted file mode 100644
index 473933e45297903a76f539db0e1c5990ff2a946d..0000000000000000000000000000000000000000
--- a/megatron/core/transformer/spec_utils.py
+++ /dev/null
@@ -1,109 +0,0 @@
-# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
-
-import types
-from dataclasses import dataclass, field
-from typing import Tuple, Union
-
-import torch
-
-
-@dataclass
-class ModuleSpec:
- """This is a Module Specification dataclass.
-
- Specification defines the location of the module (to import dynamically)
- or the imported module itself. It also defines the params that need to be
- passed to initialize the module.
-
- Args:
- module (Union[Tuple, type]): A tuple describing the location of the
- module class e.g. `(module.location, ModuleClass)` or the imported
- module class itself e.g. `ModuleClass` (which is already imported
- using `from module.location import ModuleClass`).
-        params (dict): A dictionary of params that need to be passed during initialization.
-
- """
-
- module: Union[Tuple, type]
- params: dict = field(default_factory=lambda: {})
- submodules: type = None
-
-
-def import_module(module_path: Tuple[str]):
- """Import a named object from a module in the context of this function.
-
- TODO: make this importer module more robust, at least make sure there
- are no side effects of using this as is
- """
- base_path, name = module_path
- try:
- module = __import__(base_path, globals(), locals(), [name])
- except ImportError as e:
- print(f"couldn't import module due to {e}")
- return None
- return vars(module)[name]
-
-
-def get_module(spec_or_module: Union[ModuleSpec, type], **additional_kwargs):
-    # If a module class is already provided, return it as is
- if isinstance(spec_or_module, (type, types.FunctionType)):
- return spec_or_module
-
- # If the module is provided instead of module path, then return it as is
- if isinstance(spec_or_module.module, (type, types.FunctionType)):
- return spec_or_module.module
-
- # Otherwise, return the dynamically imported module from the module path
- return import_module(spec_or_module.module)
-
-
-def build_module(spec_or_module: Union[ModuleSpec, type], *args, **kwargs):
- # If the passed `spec_or_module` is
- # a `Function`, then return it as it is
- # NOTE: to support an already initialized module add the following condition
- # `or isinstance(spec_or_module, torch.nn.Module)` to the following if check
- if isinstance(spec_or_module, types.FunctionType):
- return spec_or_module
-
- # If the passed `spec_or_module` is actually a spec (instance of
- # `ModuleSpec`) and it specifies a `Function` using its `module`
- # field, return the `Function` as it is
- if isinstance(spec_or_module, ModuleSpec) and isinstance(
- spec_or_module.module, types.FunctionType
- ):
- return spec_or_module.module
-
- # Check if a module class is provided as a spec or if the module path
- # itself is a class
- if isinstance(spec_or_module, type):
- module = spec_or_module
- elif hasattr(spec_or_module, "module") and isinstance(spec_or_module.module, type):
- module = spec_or_module.module
- else:
- # Otherwise, dynamically import the module from the module path
- module = import_module(spec_or_module.module)
-
- # If the imported module is actually a `Function` return it as it is
- if isinstance(module, types.FunctionType):
- return module
-
- # Finally return the initialized module with params from the spec as well
- # as those passed as **kwargs from the code
-
- # Add the `submodules` argument to the module init call if it exists in the
- # spec.
- if hasattr(spec_or_module, "submodules") and spec_or_module.submodules is not None:
- kwargs["submodules"] = spec_or_module.submodules
-
- try:
- return module(
- *args, **spec_or_module.params if hasattr(spec_or_module, "params") else {}, **kwargs
- )
- except Exception as e:
- # improve the error message since we hide the module name in the line above
- import sys
-
- tb = sys.exc_info()[2]
- raise type(e)(f"{str(e)} when instantiating {module.__name__}").with_traceback(
- sys.exc_info()[2]
- )
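A usage sketch of the `ModuleSpec`/`build_module` pattern defined in the file removed above. It assumes an upstream Megatron-LM checkout (which still ships `spec_utils.py`) is importable; `torch.nn.Linear` and `torch.nn.ReLU` are only example targets.

```python
import torch
from megatron.core.transformer.spec_utils import ModuleSpec, build_module

# A spec can point at a class directly; positional args are forwarded to __init__
# together with the spec's params.
linear_spec = ModuleSpec(module=torch.nn.Linear, params={"bias": False})
layer = build_module(linear_spec, 16, 32)
print(type(layer), layer.weight.shape)             # Linear, torch.Size([32, 16])

# ...or at a (module_path, ClassName) pair that is imported lazily.
lazy_spec = ModuleSpec(module=("torch.nn", "ReLU"))
act = build_module(lazy_spec)
print(act(torch.tensor([-1.0, 2.0])))              # tensor([0., 2.])
```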
diff --git a/megatron/core/transformer/switch_mlp.py b/megatron/core/transformer/switch_mlp.py
deleted file mode 100644
index 092c6c6402bad1bf1689b18f5c2aeed53930b17b..0000000000000000000000000000000000000000
--- a/megatron/core/transformer/switch_mlp.py
+++ /dev/null
@@ -1,158 +0,0 @@
-# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
-
-import torch
-
-from megatron.core import parallel_state, tensor_parallel
-from megatron.core.parallel_state import (
- get_tensor_and_expert_parallel_group,
- get_tensor_model_parallel_group,
-)
-from megatron.core.tensor_parallel import get_cuda_rng_tracker, get_data_parallel_rng_tracker_name
-from megatron.core.transformer.module import MegatronModule
-from megatron.core.transformer.transformer_config import TransformerConfig
-
-from .mlp import MLP, MLPSubmodules
-
-
-def sinkhorn(cost, tol=0.0001):
- "Sinkhorn based MoE routing function"
- cost = torch.exp(cost)
- d0 = torch.ones(cost.size(0), device=cost.device, dtype=cost.dtype)
- d1 = torch.ones(cost.size(1), device=cost.device, dtype=cost.dtype)
-
- eps = 0.00000001
- error = 1e9
- d1_old = d1
- while error > tol:
- d0 = (1 / d0.size(0)) * 1 / (torch.sum(d1 * cost, 1) + eps)
- d1 = (1 / d1.size(0)) * 1 / (torch.sum(d0.unsqueeze(1) * cost, 0) + eps)
- error = torch.mean(torch.abs(d1_old - d1))
- d1_old = d1
- return d1 * cost * d0.unsqueeze(1)
-
-
-def get_router_linear_layer(config):
- router = torch.nn.Linear(config.hidden_size, config.num_moe_experts, bias=False)
- with get_cuda_rng_tracker().fork(get_data_parallel_rng_tracker_name()):
- config.init_method(router.weight)
- setattr(router.weight, 'sequence_parallel', config.sequence_parallel)
- return router
-
-
-class SwitchMLP(MegatronModule):
- """
- Top-1 Mixture of Experts Layer. Routes input to one of N MLP "experts"
-    Currently supports Sinkhorn-based expert routing.
- """
-
- def __init__(self, config: TransformerConfig, submodules: MLPSubmodules):
- super().__init__(config=config)
-
- self.config: TransformerConfig = config
-
- self.router = get_router_linear_layer(self.config)
- self.add_bias = config.add_bias_linear
- self.sequence_parallel = config.sequence_parallel
- self.route_algo = sinkhorn
- self.router_activation = torch.sigmoid
- self.expert_parallel_size = parallel_state.get_expert_model_parallel_world_size()
-
- assert self.config.num_moe_experts % self.expert_parallel_size == 0
- self.num_local_experts = self.config.num_moe_experts // self.expert_parallel_size
- local_expert_indices_offset = (
- parallel_state.get_expert_model_parallel_rank() * self.num_local_experts
- )
- self.local_expert_indices = [
- local_expert_indices_offset + i for i in range(self.num_local_experts)
- ]
-
- self.local_experts = torch.nn.ModuleList()
- for _ in range(self.num_local_experts):
- expert = MLP(self.config, submodules, is_expert=True)
- self.local_experts.append(expert)
-
- def gather_indices(self, local_indices):
- """ Gather tensors and concatenate along the first dimension."""
- group = get_tensor_and_expert_parallel_group()
- world_size = torch.distributed.get_world_size(group=group)
- # Bypass the function if we are using only 1 GPU.
- if world_size == 1:
- return local_indices
-
- dim_size = list(local_indices.size())
- dim_size[0] = dim_size[0] * world_size
-
- # TODO pre allocate memory
- output = torch.empty(
- dim_size, dtype=local_indices.dtype, device=torch.cuda.current_device()
- )
- torch.distributed._all_gather_base(output, local_indices.contiguous(), group=group)
- return output
-
- def forward(self, hidden_states):
- hidden_shape = hidden_states.shape
- route = self.router(hidden_states)
- route = route.view(-1, self.config.num_moe_experts)
-
- if self.training:
- with torch.no_grad():
- norm_route = self.route_algo(
- route.detach().to(dtype=torch.float32)
- ) # explicit fp32 conversion for stability
- _, max_ind = torch.max(norm_route, dim=1)
- route = self.router_activation(route)
- max_prob = route[torch.arange(route.size(0)), max_ind]
- else:
- route = self.router_activation(route)
- max_prob, max_ind = torch.max(route, dim=1)
-
- max_prob = torch.unsqueeze(max_prob, 1)
- hidden_states = hidden_states.view(-1, hidden_shape[-1])
-
- if self.sequence_parallel or (self.expert_parallel_size > 1):
- global_hidden_states = tensor_parallel.gather_from_sequence_parallel_region_to_moe(
- hidden_states
- )
- global_indices = self.gather_indices(max_ind)
- else:
- global_hidden_states = hidden_states
- global_indices = max_ind
-
- output_total = torch.zeros_like(global_hidden_states)
- if self.add_bias:
- output_bias_total = torch.zeros_like(global_hidden_states)
-
- for expert_num, expert in enumerate(self.local_experts):
- local_expert_index = self.local_expert_indices[expert_num]
- local_indices = (global_indices == local_expert_index).nonzero()
- hidden = global_hidden_states[local_indices, :]
- output, output_bias = expert(hidden)
-
- output_total[local_indices, :] = output
- if self.add_bias:
- output_bias = output_bias.expand_as(output)
- output_bias_total[local_indices, :] = output_bias
-
- if self.sequence_parallel or (self.expert_parallel_size > 1):
- output_total = tensor_parallel.reduce_scatter_to_sequence_parallel_region_from_moe(
- output_total
- )
- if self.add_bias:
- output_bias_total = tensor_parallel.reduce_scatter_to_sequence_parallel_region_from_moe(
- output_bias_total
- )
- # bias is duplicated across tensor parallelism ranks;
- # reduce scatter reduces bias across tensor parallel_ranks
- output_bias_total = (
- output_bias_total / parallel_state.get_tensor_model_parallel_world_size()
- )
-
- output_total = output_total * max_prob
- output_total = output_total.view(hidden_shape)
- if self.add_bias:
- output_bias_total = output_bias_total * max_prob
- output_bias_total = output_bias_total.view(hidden_shape)
- else:
- output_bias_total = None
-
- return output_total, output_bias_total
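A standalone sketch of the Sinkhorn top-1 routing performed by `SwitchMLP.forward` above, with made-up token and expert counts. It reproduces the deleted `sinkhorn` helper so it runs without Megatron installed: the balanced matrix picks the expert index, while the gate probability comes from the raw router logits.

```python
import torch

def sinkhorn(cost, tol=1e-4):
    # Sinkhorn iteration on exp(logits), as in the deleted routing helper.
    cost = torch.exp(cost)
    d0 = torch.ones(cost.size(0), dtype=cost.dtype)
    d1 = torch.ones(cost.size(1), dtype=cost.dtype)
    eps, error, d1_old = 1e-8, 1e9, d1
    while error > tol:
        d0 = (1 / d0.size(0)) / (torch.sum(d1 * cost, 1) + eps)
        d1 = (1 / d1.size(0)) / (torch.sum(d0.unsqueeze(1) * cost, 0) + eps)
        error = torch.mean(torch.abs(d1_old - d1))
        d1_old = d1
    return d1 * cost * d0.unsqueeze(1)

logits = torch.randn(6, 4)                          # [num_tokens, num_experts]
with torch.no_grad():
    _, expert_idx = torch.max(sinkhorn(logits.float()), dim=1)
gate = torch.sigmoid(logits)[torch.arange(6), expert_idx].unsqueeze(1)
print(expert_idx, gate.shape)                       # one expert per token, [6, 1] gate
```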
diff --git a/megatron/core/transformer/transformer_block.py b/megatron/core/transformer/transformer_block.py
deleted file mode 100644
index 74bf29c859dba98cd6c0ce6aac1444d6d53f81f6..0000000000000000000000000000000000000000
--- a/megatron/core/transformer/transformer_block.py
+++ /dev/null
@@ -1,349 +0,0 @@
-# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
-
-import re
-from contextlib import nullcontext
-from dataclasses import dataclass
-from typing import List, Union
-
-import torch
-from torch import Tensor
-
-from megatron.core import InferenceParams, parallel_state, tensor_parallel
-from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
-from megatron.core.transformer.custom_layers.transformer_engine import TENorm
-from megatron.core.transformer.enums import AttnMaskType
-from megatron.core.transformer.module import MegatronModule
-from megatron.core.transformer.spec_utils import ModuleSpec, build_module
-from megatron.core.transformer.transformer_config import TransformerConfig
-from megatron.core.transformer.transformer_layer import TransformerLayer
-from megatron.core.utils import make_sharded_tensor_for_checkpoint, make_viewless_tensor
-
-
-def get_num_layers_to_build(config: TransformerConfig) -> int:
-
- num_layers_per_pipeline_rank = (
- config.num_layers // parallel_state.get_pipeline_model_parallel_world_size()
- )
-
- if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
- # Interleaved pipeline parallelism:
- # Number of layers in each model chunk is the number of layers in the stage,
- # divided by the number of model chunks in a stage.
- # With 8 layers, 2 stages, and 4 model chunks, we want an assignment of
- # layers to stages like (each list is a model chunk):
- # Stage 0: [0] [2] [4] [6]
- # Stage 1: [1] [3] [5] [7]
- # With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of
- # layers to stages like (each list is a model chunk):
- # Stage 0: [0, 1] [4, 5]
- # Stage 1: [2, 3] [6, 7]
-
- vp_size = parallel_state.get_virtual_pipeline_model_parallel_world_size()
-
- num_layers_per_virtual_rank = num_layers_per_pipeline_rank // vp_size
-
- num_layers_to_build = num_layers_per_virtual_rank
-
- else:
- # Non-interleaved pipeline parallelism:
- # Each stage gets a contiguous set of layers.
-
- num_layers_to_build = num_layers_per_pipeline_rank
-
- return num_layers_to_build
-
-
-@dataclass
-class TransformerBlockSubmodules:
- layer_specs: List[ModuleSpec] = None
-
-
-def _get_block_submodules(
- config: TransformerConfig, spec: Union[TransformerBlockSubmodules, ModuleSpec],
-) -> TransformerBlockSubmodules:
-
- # Transformer block submodules.
- if isinstance(spec, TransformerBlockSubmodules):
- return spec
-
- # ModuleSpec here is generally assumed to be for a transformer layer.
- elif isinstance(spec, ModuleSpec):
- if issubclass(spec.module, TransformerBlock):
- return spec.submodules
- elif issubclass(spec.module, TransformerLayer):
- num_layers = get_num_layers_to_build(config)
- return TransformerBlockSubmodules(layer_specs=[spec] * num_layers)
- else:
- raise Exception(f"specialize for {spec.module.__name__}.")
- else:
- raise Exception(f"specialize for {type(spec).__name__}.")
-
-
-class TransformerBlock(MegatronModule):
- """Transformer class."""
-
- def __init__(
- self,
- config: TransformerConfig,
- spec: Union[TransformerBlockSubmodules, ModuleSpec],
- post_layer_norm: bool = True,
- pre_process: bool = True,
- post_process: bool = True,
- ):
- super().__init__(config=config)
-
- self.submodules = _get_block_submodules(config, spec)
- self.post_layer_norm = post_layer_norm
- self.pre_process = pre_process
- self.post_process = post_process
-
- # required for pipeline parallel schedules
- self.input_tensor = None
-
- self.checkpoint_core_attention = self.config.recompute_granularity == 'selective'
-
- self._build_layers()
- self.num_layers_per_pipeline_rank = len(self.layers)
-
- def _build_layers(self):
- # Transformer layers.
- # @jcasper can we improve how we deal with layer_number?
- # currently it's only used in CoreAttention?
- # if self.apply_query_key_layer_scaling:
- # coeff = self.layer_number
- # self.norm_factor *= coeff
- def build_layer(layer_spec, layer_number):
- return build_module(layer_spec, config=self.config, layer_number=layer_number,)
-
- # offset is implicit in TransformerLayer
- self.layers = torch.nn.ModuleList(
- [
- build_layer(layer_spec, i + 1)
- for i, layer_spec in enumerate(self.submodules.layer_specs)
- ]
- )
-
- # # TODO: add back standalone_embedding_stage
- # if self.num_layers == 0:
- # # When a standalone embedding stage is used (e.g.,
- # # args.standalone_embedding_stage == True), virtual pipeline ranks
- # # on pipeline rank 0 will have zero transformer layers assigned to
- # # them. This results in the model's input and output tensors to be
- # # the same, which will cause failure for certain output tensor
- # # optimizations (e.g., pipeline output deallocation). To remedy
- # # this, we assign a 'no-op' layer on these ranks, which will
- # # disconnect the input tensor from the output tensor.
- # self.num_layers = 1
- # self.layers = torch.nn.ModuleList([NoopTransformerLayer(1)])
- # else:
- # self.layers = torch.nn.ModuleList([build_layer(i + 1 + offset) for i in range(self.num_layers)])
-
- if self.post_process and self.post_layer_norm:
- # Final layer norm before output.
- self.final_layernorm = TENorm(
- config=self.config,
- hidden_size=self.config.hidden_size,
- eps=self.config.layernorm_epsilon,
- )
-
- def _get_layer(self, layer_number: int):
- return self.layers[layer_number]
-
- def _checkpointed_forward(
- self,
- hidden_states: Tensor,
- attention_mask: Tensor,
- context: Tensor,
- context_mask: Tensor,
- rotary_pos_emb: Tensor,
- ):
- """Forward method with activation checkpointing."""
-
- def custom(start: int, end: int):
- def custom_forward(
- hidden_states, attention_mask, context, context_mask, rotary_pos_emb,
- ):
- for index in range(start, end):
- layer = self._get_layer(index)
- hidden_states, context = layer(
- hidden_states=hidden_states,
- attention_mask=attention_mask,
- context=context,
- context_mask=context_mask,
- rotary_pos_emb=rotary_pos_emb,
- inference_params=None,
- )
- return hidden_states, context
-
- return custom_forward
-
- if self.config.recompute_method == 'uniform':
- # Uniformly divide the total number of Transformer layers and checkpoint
- # the input activation of each divided chunk.
-            # A method to further reduce memory usage by reducing the number of checkpoints.
- l = 0
- while l < self.num_layers_per_pipeline_rank:
- hidden_states, context = tensor_parallel.checkpoint(
- custom(l, l + self.config.recompute_num_layers),
- self.config.distribute_saved_activations,
- hidden_states,
- attention_mask,
- context,
- context_mask,
- rotary_pos_emb,
- )
-
- l += self.config.recompute_num_layers
-
- elif self.config.recompute_method == 'block':
- # Checkpoint the input activation of only a set number of individual
- # Transformer layers and skip the rest.
-            # A method that makes full use of device memory by removing redundant re-computation.
- for l in range(self.num_layers_per_pipeline_rank):
- if l < self.config.recompute_num_layers:
- hidden_states, context = tensor_parallel.checkpoint(
- custom(l, l + 1),
- self.config.distribute_saved_activations,
- hidden_states,
- attention_mask,
- context,
- context_mask,
- rotary_pos_emb,
- )
- else:
- hidden_states, context = custom(l, l + 1)(
- hidden_states, attention_mask, context, context_mask, rotary_pos_emb,
- )
- else:
- raise ValueError("Invalid activation recompute method.")
-
- return hidden_states
-
- def set_input_tensor(self, input_tensor: Tensor):
- """Set input tensor to be used instead of forward()'s input.
-
- When doing pipeline parallelism the input from the previous
- stage comes from communication, not from the input, so the
- model's forward_step_func won't have it. This function is thus
- used by internal code to bypass the input provided by the
- forward_step_func"""
- self.input_tensor = input_tensor
-
- def forward(
- self,
- hidden_states: Tensor,
- attention_mask: Tensor,
- context: Tensor = None,
- context_mask: Tensor = None,
- rotary_pos_emb: Tensor = None,
- inference_params: InferenceParams = None,
- ):
- # hidden_states (float): [s, b, h]
- # attention_mask (bool): [1, 1, s, s]
-
- if not self.pre_process:
- # See set_input_tensor()
- hidden_states = self.input_tensor
-
- # Viewless tensor.
- # - We only need to create a viewless tensor in the case of micro batch
- # size (mbs) == 1, since in this case, 'hidden_states.transpose()'
- # above creates a view tensor, and '.contiguous()' is a pass-through.
- # For mbs >= 2, '.contiguous()' creates a new tensor, eliminating
- # the need to make it viewless.
- #
- # However, we don't explicitly check mbs == 1 here because
- # make_viewless_tensor() has negligible overhead when its input
- # is already viewless.
- #
- # - For the 'else' case above, calling make_viewless_tensor() here is
- # likely redundant, since p2p_communication.py (likely originator)
- # already creates viewless tensors. That said, make_viewless_tensor()
- # is called here to be future-proof and corner-case-proof.
- hidden_states = make_viewless_tensor(
- inp=hidden_states, requires_grad=True, keep_graph=True,
- )
-
- if self.config.sequence_parallel:
- rng_context = tensor_parallel.get_cuda_rng_tracker().fork()
- else:
- rng_context = nullcontext()
-
- if self.config.fp8:
- import transformer_engine # To keep out TE dependency when not training in fp8
-
- if self.config.fp8 == "e4m3":
- fp8_format = transformer_engine.common.recipe.Format.E4M3
- elif self.config.fp8 == "hybrid":
- fp8_format = transformer_engine.common.recipe.Format.HYBRID
- else:
- raise ValueError("E4M3 and HYBRID are the only supported FP8 formats.")
-
- fp8_recipe = transformer_engine.common.recipe.DelayedScaling(
- margin=self.config.fp8_margin,
- interval=self.config.fp8_interval,
- fp8_format=fp8_format,
- amax_compute_algo=self.config.fp8_amax_compute_algo,
- amax_history_len=self.config.fp8_amax_history_len,
- override_linear_precision=(False, False, not self.config.fp8_wgrad),
- )
- fp8_group = None
- if parallel_state.model_parallel_is_initialized():
- fp8_group = parallel_state.get_amax_reduction_group(with_context_parallel=True)
- fp8_context = transformer_engine.pytorch.fp8_autocast(
- enabled=True, fp8_recipe=fp8_recipe, fp8_group=fp8_group
- )
- else:
- fp8_context = nullcontext()
-
- with rng_context and fp8_context:
- # Forward pass.
- if self.config.recompute_granularity == 'full':
- hidden_states = self._checkpointed_forward(
- hidden_states=hidden_states,
- attention_mask=attention_mask,
- context=context,
- context_mask=context_mask,
- rotary_pos_emb=rotary_pos_emb,
- )
- else:
- for layer in self.layers:
- hidden_states, context = layer(
- hidden_states=hidden_states,
- attention_mask=attention_mask,
- context=context,
- context_mask=context_mask,
- rotary_pos_emb=rotary_pos_emb,
- inference_params=inference_params,
- )
-
- # Final layer norm.
- if self.post_process and self.post_layer_norm:
- hidden_states = self.final_layernorm(hidden_states)
-
- return hidden_states
-
- def sharded_state_dict(self, prefix: str = ''):
-
- sharded_state_dict = {}
-
- layer_prefix = f'{prefix}layers.'
- for layer in self.layers:
- sharded_state_dict.update(layer.sharded_state_dict(prefix=layer_prefix))
-
- if self.post_process and self.post_layer_norm:
- state_dict = self.state_dict(keep_vars=True)
-
- tensor = state_dict['final_layernorm.weight']
- layer_name = f'{prefix}final_layernorm.weight'
- sharded_state_dict[layer_name] = make_sharded_tensor_for_checkpoint(tensor, layer_name)
-
- # RMSNorm doesn't have bias.
- if 'final_layernorm.bias' in state_dict.keys():
- tensor = state_dict['final_layernorm.bias']
- layer_name = f'{prefix}final_layernorm.bias'
- sharded_state_dict[layer_name] = make_sharded_tensor_for_checkpoint(
- tensor, layer_name
- )
-
- return sharded_state_dict
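A worked example of the layer-assignment arithmetic documented in `get_num_layers_to_build` above (8 layers, 2 pipeline stages, 2 virtual chunks per stage). The small helper mirrors the interleaved-case offset formula from `TransformerLayer._get_layer_offset`; the counts are the same illustrative numbers used in that comment.

```python
num_layers, pp_size, vp_size = 8, 2, 2
layers_per_pp_rank = num_layers // pp_size                 # 4 layers per pipeline stage
layers_per_virtual_rank = layers_per_pp_rank // vp_size    # 2 layers built per model chunk

def chunk_layers(pp_rank: int, vp_rank: int):
    # Global layer indices owned by one (pipeline rank, virtual rank) chunk.
    total_virtual_chunks = num_layers // vp_size           # 4
    offset = vp_rank * total_virtual_chunks + pp_rank * layers_per_virtual_rank
    return list(range(offset, offset + layers_per_virtual_rank))

for pp_rank in range(pp_size):
    print(pp_rank, [chunk_layers(pp_rank, vp) for vp in range(vp_size)])
# 0 [[0, 1], [4, 5]]
# 1 [[2, 3], [6, 7]]
```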
diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py
deleted file mode 100644
index adccd4409bd148cb394683ad3d8e8f69a1c43064..0000000000000000000000000000000000000000
--- a/megatron/core/transformer/transformer_config.py
+++ /dev/null
@@ -1,288 +0,0 @@
-# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
-
-import types
-from dataclasses import dataclass
-from typing import Callable
-
-import torch
-import torch.nn.functional as F
-
-from ..model_parallel_config import ModelParallelConfig
-from ..utils import init_method_normal, scaled_init_method_normal
-
-
-@dataclass
-class TransformerConfig(ModelParallelConfig):
- """Configuration object for megatron-core transformers.
-
- Attributes:
-
- # model architecture
- num_layers (int): Number of transformer layers in a transformer block.
- hidden_size (int): Transformer hidden size.
- ffn_hidden_size (int): Transformer Feed-Forward Network hidden size.
-                                This is set to 4*hidden_size if not provided. Defaults to None.
- num_attention_heads (int): Number of transformer attention heads.
- kv_channels (int): Projection weights dimension in multi-head attention.
- This is set to hidden_size // num_attention_heads if not provided.
- Defaults to None.
- num_query_groups (int): Number of query groups for group query attention. If None, normal attention is used.
-
- hidden_dropout (float): Dropout probability for transformer hidden state. Defaults to 0.1.
- attention_dropout (float): Post attention dropout probability. Defaults to 0.1.
- fp32_residual_connection (bool): If true, move residual connections to fp32.
-        apply_residual_connection_post_layernorm (bool): If true, uses the original BERT residual connection ordering.
- Defaults to False.
- layernorm_epsilon (float): Layernorm epsilon. Defaults to 1e-5.
-
- layernorm_zero_centered_gamma (bool): if set to 'True', the LayerNorm is adjusted to center the gamma values
- around 0. This improves numerical stability. Defaults to False.
-
- add_bias_linear (bool): Include a bias term in all linear layers (QKV projections, after core attention, and two
- in MLP layer). Default is True.
-
- gated_linear_unit (bool): Use a gated linear unit for the first linear layer in the MLP. Defaults to False.
-
- activation_func (Callable): Activation function to use for the non-linearity in the MLP. Defaults to F.gelu.
-
- num_moe_experts (int): Number of experts to use for Mixture of Experts.
- When set, it replaces MLP with Switch MLP. Defaults to None (no MoE).
-
- # initialization
- init_method (Callable): Method to initialize weights. Note that bias is always set to
- zero. Should be a function that takes a single Tensor and
- initializes it. Defaults to
- megatron.core.utils.init_method_normal(init_method_std) which is
-                               torch.nn.init.normal_ with mean=0.0 and std=init_method_std.
-
- output_layer_init_method (Callable): Method to initialize weights of the output layer of
- both attention and MLP blocks. Defaults to
- megatron.core.utils.scaled_init_method_normal(init_method_std)
- which is torch.nn.init.normal_ with mean=0.0 and
- std=init_method_std / math.sqrt(2.0 * num_layers).
-
- init_method_std (float): Standard deviation of the zero mean normal for the default
- initialization method, not used if init_method and
- output_layer_init_method are provided. Defaults to 0.02.
-
- # mixed-precision
- apply_query_key_layer_scaling (bool): If true, scale Q * K^T by 1 / layer-number. Defaults to True.
- attention_softmax_in_fp32 (bool): If true, run attention masking and softmax in fp32.
- This should be true if apply_query_key_layer_scaling is true.
-
- # fusion
-        bias_gelu_fusion (bool): If true, fuses bias and gelu. Defaults to False.
- masked_softmax_fusion (bool): If true, uses softmax fusion.
- persist_layer_norm (bool): If true, uses the persistent fused layer norm kernel.
- This kernel only supports a fixed set of hidden sizes.
- Defaults to False.
- bias_dropout_fusion (bool): If true, uses bias dropout fusion.
-
- # activation recomputation
-
- recompute_granularity (str): megatron-core supports 'selective' activation checkpointing where only the memory
- intensive part of attention is checkpointed. These memory intensive activations
- are also less compute intensive which makes activation checkpointing more efficient
- for LLMs (20B+). See Reducing Activation Recomputation in Large Transformer
- Models: https://arxiv.org/abs/2205.05198 for more details. 'full' will checkpoint
- the entire transformer layer. Must be 'selective' or 'full'. 'selective' always uses all layers.
- Defaults to None.
-
- recompute_method (str): uniform will uniformly divide the total number of transformer layers in a transformer
- block and recompute the input activation of each divided chunk at the specified
- granularity. block will recompute the input activations for only a set number of
- transformer layers per pipeline stage. The rest of the layers in the pipeline stage
- will not have any activations recomputed. Must be 'uniform' or 'block'. Defaults to
- None.
-
- recompute_num_layers (int): When recompute_method is uniform, recompute_num_layers is the number of transformer
- layers in each uniformly divided recompute unit. When recompute_method is block,
- recompute_num_layers is the number of transformer layers to recompute within each
- pipeline stage. Must be None for 'selective' activation checkpointing. Defaults to None.
-
- distribute_saved_activations (bool): If true, distribute recomputed activations across the model parallel
- group. Defaults to None.
-
-    # fp8 related (via Transformer Engine). For detailed info, refer to the Transformer Engine docs at
- # https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/common.html
-
- fp8 (str): If set, enables the use of FP8 precision through Transformer Engine. There are 2 predefined choices: (1) 'e4m3'
- uniformly uses e4m3 for all FP8 tensors, (2) 'hybrid' uses e4m3 for all FP8 activation and weight tensors and
- e5m2 for all FP8 output activation gradient tensors. Defaults to None.
-
- fp8_margin (int): Margin for the scaling factor computation.
-
- fp8_interval (int): Controls how often the scaling factor is recomputed.
-
- fp8_amax_history_len (int): The length of the amax history window used for scaling factor computation.
-
- fp8_amax_compute_algo (str): Algorithm used for choosing the `amax` value for the scaling factor computation.
- There are 2 predefined choices: `max` chooses the largest `amax` in the history
- window, while `most_recent` always chooses the most recently seen value.
-
- fp8_wgrad (bool): When set to False, override FP8 config options and do the wgrad computation in higher precision.
- Defaults to True.
-
- # Miscellaneous
- clone_scatter_output_in_embedding (bool): When set to true, clone the output of scatter_to_sequence_parallel_region
- in embedding layer to facilitate garbage collection of input.
-
- # Experimental
-        normalization (str): Switch between `LayerNorm` and `RMSNorm` as normalization layers. For now, these are primarily
- used by Transformer-Engine's layers like `LayerNormLinear`. Default value is `LayerNorm`.
-
-
- """
-
- # model architecture
- num_layers: int = 0
- hidden_size: int = 0
- num_attention_heads: int = 0
- num_query_groups: int = None
-
- ffn_hidden_size: int = None
- kv_channels: int = None
- hidden_dropout: float = 0.1
- attention_dropout: float = 0.1
- fp32_residual_connection: bool = False
- # @jcasper should we keep this option?
- apply_residual_connection_post_layernorm: bool = False
- layernorm_epsilon: float = 1e-5
- layernorm_zero_centered_gamma: bool = False
- add_bias_linear: bool = True
- gated_linear_unit: bool = False
- activation_func: Callable = F.gelu
- num_moe_experts: int = None
-
- # initialization
- init_method: Callable = None
- output_layer_init_method: Callable = None
- init_method_std: float = 0.02
-
- # mixed-precision
- apply_query_key_layer_scaling: bool = False
- attention_softmax_in_fp32: bool = True
-
- # communication
-
- # fusion
- bias_gelu_fusion: bool = False # TODO: this should be bias_activation_fusion ?
- masked_softmax_fusion: bool = False
- persist_layer_norm: bool = False
- bias_dropout_fusion: bool = False # TODO: this should be bias_dropout_add_fusion?
-
- # activation recomputation
- recompute_granularity: str = None
- recompute_method: str = None
- recompute_num_layers: int = None
- distribute_saved_activations: bool = None
-
- # fp8 related
- fp8: str = None
- fp8_margin: int = 0
- fp8_interval: int = 1
- fp8_amax_history_len: int = 1
- fp8_amax_compute_algo: str = "most_recent"
- fp8_wgrad: bool = True
-
- # miscellaneous
- clone_scatter_output_in_embedding: bool = True
-
- # experimental section (TODO: move to apt. section above once stable)
-    normalization: str = "LayerNorm"  # alt value supported by TE: "RMSNorm"
-
- def __post_init__(self):
- """ Python dataclass method that is used to modify attributes after initialization.
- See https://docs.python.org/3/library/dataclasses.html#post-init-processing for more details.
- """
- super().__post_init__()
- if self.fp16 and self.bf16:
- raise ValueError(
- f'Only one of self.fp16: {self.fp16} and self.bf16 {self.bf16} should be True.'
- )
-
- if self.num_attention_heads % self.tensor_model_parallel_size != 0:
- raise ValueError(
- f"num_attention_heads ({self.num_attention_heads}) must be a multiple of "
- f"tensor_model_parallel_size ({self.tensor_model_parallel_size})."
- )
-
- if self.ffn_hidden_size is None:
- self.ffn_hidden_size = 4 * self.hidden_size
-
- if self.kv_channels is None:
- self.kv_channels = self.hidden_size // self.num_attention_heads
-
- if self.num_query_groups is None:
- self.num_query_groups = self.num_attention_heads
-
- if self.num_query_groups % self.tensor_model_parallel_size != 0:
- raise ValueError(
- f"num_query_groups ({self.num_query_groups}) must be a multiple of "
- f"tensor_model_parallel_size ({self.tensor_model_parallel_size})."
- )
-
- if self.apply_query_key_layer_scaling:
- self.attention_softmax_in_fp32 = True
-
- if self.expert_model_parallel_size > 1 and self.num_moe_experts is None:
-            raise ValueError('num_moe_experts must not be None to use expert parallelism.')
-
- if self.recompute_granularity is not None:
- if not self.recompute_granularity in ['full', 'selective']:
- raise ValueError(
-                    f'recompute_granularity: {self.recompute_granularity} must be "full" or "selective".'
- )
-
- if self.recompute_method is not None:
- if not self.recompute_method in ['block', 'uniform']:
- raise ValueError(
- f'recompute_method: {self.recompute_method} must be "block" or "uniform".'
- )
- elif self.recompute_granularity != 'selective':
- raise ValueError(
- f'Using recompute_granularity: {self.recompute_granularity} so recompute_method must be "block" or "uniform"'
- )
-
- if self.recompute_granularity != 'selective' and self.recompute_num_layers is None:
- raise ValueError(
- f'When using recompute_granularity: {self.recompute_granularity} recompute_num_layers must be between '
- f'1 and num_layers_per_pipeline_rank: {self.num_layers // self.pipeline_model_parallel_size}'
- )
- elif (
- self.recompute_granularity == 'selective' and self.recompute_num_layers is not None
- ):
- raise ValueError(
- f'When using recompute_granularity: {self.recompute_granularity} recompute_num_layers must be None.'
- )
-
- if self.distribute_saved_activations and self.sequence_parallel:
- raise ValueError(
- f'distribute_saved_activations: {self.distribute_saved_activations} must be false when sequence parallel is enabled: {self.sequence_parallel}'
- )
-
- if self.virtual_pipeline_model_parallel_size is not None:
- if not self.num_layers % self.virtual_pipeline_model_parallel_size == 0:
- raise ValueError(
-                    f'num_layers: {self.num_layers} must be divisible by virtual_pipeline_model_parallel_size {self.virtual_pipeline_model_parallel_size}'
- )
-
- if self.apply_query_key_layer_scaling:
- self.attention_softmax_in_fp32 = True
-
- if self.bias_gelu_fusion:
- if not self.add_bias_linear:
- raise ValueError(
- "When bias_gelu_fusion is True, add_bias_linear must also be True."
- )
-
- if self.activation_func != F.gelu:
- raise ValueError(f'When bias_gelu_fusion is True, activation_func must be F.gelu.')
-
- if self.init_method is None:
- self.init_method = init_method_normal(self.init_method_std)
-
- if self.output_layer_init_method is None:
- self.output_layer_init_method = scaled_init_method_normal(
- self.init_method_std, self.num_layers
- )
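A minimal sketch of the derived defaults computed in `__post_init__` above: `ffn_hidden_size` falls back to `4 * hidden_size`, `kv_channels` to `hidden_size // num_attention_heads`, and `num_query_groups` to `num_attention_heads`. It assumes an upstream Megatron-LM checkout is importable; the inherited `ModelParallelConfig` defaults are not shown in this diff, so treat this as illustrative only.

```python
from megatron.core.transformer.transformer_config import TransformerConfig

# Only the three required architecture fields are set; everything else is derived
# or defaulted by the dataclass and its __post_init__.
cfg = TransformerConfig(num_layers=2, hidden_size=128, num_attention_heads=8)
print(cfg.ffn_hidden_size)   # 512  (4 * hidden_size)
print(cfg.kv_channels)       # 16   (hidden_size // num_attention_heads)
print(cfg.num_query_groups)  # 8    (defaults to num_attention_heads)
```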
diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py
deleted file mode 100644
index b9951d4347005f262f255d1d49c531114e9dc554..0000000000000000000000000000000000000000
--- a/megatron/core/transformer/transformer_layer.py
+++ /dev/null
@@ -1,245 +0,0 @@
-# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
-
-from dataclasses import dataclass
-from typing import Union
-
-import torch
-
-from megatron.core import parallel_state
-from megatron.core.dist_checkpointing.mapping import ShardedObject, ShardedTensor
-from megatron.core.transformer.enums import AttnMaskType
-from megatron.core.transformer.identity_op import IdentityFuncOp, IdentityOp
-from megatron.core.transformer.module import MegatronModule
-from megatron.core.transformer.spec_utils import ModuleSpec, build_module
-from megatron.core.transformer.transformer_config import TransformerConfig
-from megatron.core.utils import make_viewless_tensor
-
-
-@dataclass
-class TransformerLayerSubmodules:
- input_layernorm: Union[ModuleSpec, type] = IdentityOp
- self_attention: Union[ModuleSpec, type] = IdentityOp
- self_attn_bda: Union[ModuleSpec, type] = IdentityFuncOp
-
- pre_cross_attn_layernorm: Union[ModuleSpec, type] = IdentityOp
- cross_attention: Union[ModuleSpec, type] = IdentityOp
- cross_attn_bda: Union[ModuleSpec, type] = IdentityFuncOp
-
- pre_mlp_layernorm: Union[ModuleSpec, type] = IdentityOp
- mlp: Union[ModuleSpec, type] = IdentityOp
- mlp_bda: Union[ModuleSpec, type] = IdentityFuncOp
-
-
-class TransformerLayer(MegatronModule):
- """A single transformer layer.
-
- Transformer layer takes input with size [s, b, h] and returns an
- output of the same size.
- """
-
- def __init__(
- self,
- config: TransformerConfig,
- submodules: TransformerLayerSubmodules,
- layer_number: int = 1,
- hidden_dropout: float = None,
- ):
- super().__init__(config=config)
-
- self.layer_number = layer_number + self._get_layer_offset()
- self.hidden_dropout = config.hidden_dropout if hidden_dropout is None else hidden_dropout
-
- ## [Module 1: Input Layernorm] Optional Layernorm on the input data
- # TODO: add pytorch only layernorm
- self.input_layernorm = build_module(
- submodules.input_layernorm,
- config=self.config,
- hidden_size=self.config.hidden_size,
- eps=self.config.layernorm_epsilon,
- )
-
- ## [Module 2: SelfAttention]
- self.self_attention = build_module(
- submodules.self_attention, config=self.config, layer_number=layer_number,
- )
-
- ## [Module 3: BiasDropoutFusion]
- self.self_attn_bda = build_module(submodules.self_attn_bda)
-
- ## [Module 4: Post SelfAttention] Optional Layernorm after self-attn
- self.pre_cross_attn_layernorm = build_module(
- submodules.pre_cross_attn_layernorm,
- config=self.config,
- hidden_size=self.config.hidden_size,
- eps=self.config.layernorm_epsilon,
- )
-
- ## [Module 5: CrossAttention]
- self.cross_attention = build_module(
- submodules.cross_attention, config=self.config, layer_number=layer_number,
- )
-
- ## [Module 6: BiasDropoutFusion]
- self.cross_attn_bda = build_module(submodules.cross_attn_bda, config=self.config,)
-
- ## [Module 7: Pre MLP] Optional Layernorm before MLP
- self.pre_mlp_layernorm = build_module(
- submodules.pre_mlp_layernorm,
- config=self.config,
- hidden_size=self.config.hidden_size,
- eps=self.config.layernorm_epsilon,
- )
-
- ## [Module 8: MLP block]
- # TODO how to set the gpt_layer_spec.py when we have moe_frequency > 1,
- # where MLP and SwitchMLP both appear alternately?
- self.mlp = build_module(submodules.mlp, config=self.config)
-
- ## [Module 9: BiasDropoutFusion]
- self.mlp_bda = build_module(submodules.mlp_bda)
-
- # @jcasper how should we handle nvfuser?
- # Set bias+dropout+add fusion grad_enable execution handler.
- # TORCH_MAJOR = int(torch.__version__.split('.')[0])
- # TORCH_MINOR = int(torch.__version__.split('.')[1])
- # use_nvfuser = TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10)
- # self.bias_dropout_add_exec_handler = nullcontext if use_nvfuser else torch.enable_grad
- self.bias_dropout_add_exec_handler = torch.enable_grad
-
- def _get_layer_offset(self):
-
- pipeline_rank = parallel_state.get_pipeline_model_parallel_rank()
-
- num_layers_per_pipeline_rank = (
- self.config.num_layers // parallel_state.get_pipeline_model_parallel_world_size()
- )
-
- if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
- vp_rank = parallel_state.get_virtual_pipeline_model_parallel_rank()
- vp_size = parallel_state.get_virtual_pipeline_model_parallel_world_size()
-
- total_num_layers = self.config.num_layers
- num_layers_per_virtual_rank = num_layers_per_pipeline_rank // vp_size
- total_virtual_chunks = total_num_layers // vp_size
- offset = vp_rank * total_virtual_chunks + (pipeline_rank * num_layers_per_virtual_rank)
-
- else:
- # Each stage gets a contiguous set of layers.
- if parallel_state.get_pipeline_model_parallel_world_size() > 1:
- offset = pipeline_rank * num_layers_per_pipeline_rank
- else:
- offset = 0
-
- return offset
-
- def forward(
- self,
- hidden_states,
- attention_mask,
- context=None,
- context_mask=None,
- rotary_pos_emb=None,
- inference_params=None,
- ):
- # hidden_states: [s, b, h]
-
- # Residual connection.
- residual = hidden_states
-
- # Optional Input Layer norm
- input_layernorm_output = self.input_layernorm(hidden_states)
-
- # Self attention.
- attention_output_with_bias = self.self_attention(
- input_layernorm_output,
- attention_mask=attention_mask,
- inference_params=inference_params,
- rotary_pos_emb=rotary_pos_emb,
- )
-
- # TODO: could we move `bias_dropout_add_exec_handler` itself
- # inside the module provided in the `bias_dropout_add_spec` module?
- with self.bias_dropout_add_exec_handler():
- hidden_states = self.self_attn_bda(self.training, self.config.bias_dropout_fusion)(
- attention_output_with_bias, residual, self.hidden_dropout
- )
-
- # Residual connection.
- residual = hidden_states
-
- # Optional Layer norm after self-attention
- pre_cross_attn_layernorm_output = self.pre_cross_attn_layernorm(hidden_states)
-
- # Cross attention.
- attention_output_with_bias = self.cross_attention(
- pre_cross_attn_layernorm_output,
- attention_mask=context_mask,
- key_value_states=context,
- inference_params=inference_params,
- )
-
- if isinstance(attention_output_with_bias, dict) and "context" in attention_output_with_bias:
- context = attention_output_with_bias["context"]
-
- # TODO: could we move `bias_dropout_add_exec_handler` itself
- # inside the module provided in the `bias_dropout_add_spec` module?
- with self.bias_dropout_add_exec_handler():
- hidden_states = self.cross_attn_bda(self.training, self.config.bias_dropout_fusion)(
- attention_output_with_bias, residual, self.hidden_dropout
- )
-
- # Residual connection.
- residual = hidden_states
-
- # Optional Layer norm post the cross-attention.
- pre_mlp_layernorm_output = self.pre_mlp_layernorm(hidden_states)
-
- # MLP.
- mlp_output_with_bias = self.mlp(pre_mlp_layernorm_output)
-
- # TODO: could we move `bias_dropout_add_exec_handler` itself
- # inside the module provided in the `bias_dropout_add_spec` module?
- with self.bias_dropout_add_exec_handler():
- hidden_states = self.mlp_bda(self.training, self.config.bias_dropout_fusion)(
- mlp_output_with_bias, residual, self.hidden_dropout
- )
-
- # Jit compiled function creates 'view' tensor. This tensor
- # potentially gets saved in the MPU checkpoint function context,
- # which rejects view tensors. While making a viewless tensor here
- # won't result in memory savings (like the data loader, or
- # p2p_communication), it serves to document the origin of this
- # 'view' tensor.
- output = make_viewless_tensor(
- inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True
- )
-
- return output, context
-
- def sharded_state_dict(self, prefix=''):
- offset = self._get_layer_offset()
- num_layers = self.config.num_layers
-
- global_layer_offset = self.layer_number - 1 # self.layer_number starts at 1
- state_dict_prefix = (
- f'{prefix}{global_layer_offset - offset}.' # module list index in TransformerBlock
- )
- sharded_pp_offset = [
- (0, global_layer_offset, num_layers)
- ] # PP sharding offset for ShardedTensors
-
- attn_state_dict = self.self_attention.sharded_state_dict(
- prefix=f'{state_dict_prefix}self_attention.',
- sharded_key_prefix=f'{prefix}self_attention.',
- sharded_offsets=sharded_pp_offset,
- )
-
- mlp_state_dict = self.mlp.sharded_state_dict(
- prefix=f'{state_dict_prefix}mlp.',
- sharded_key_prefix=f'{prefix}mlp.',
- sharded_offsets=sharded_pp_offset,
- )
-
- sharded_state_dict = {**mlp_state_dict, **attn_state_dict}
-
- return sharded_state_dict
diff --git a/megatron/core/transformer/utils.py b/megatron/core/transformer/utils.py
deleted file mode 100644
index d7d002734f7f8c9fc93767604c40cf76a4fb82d4..0000000000000000000000000000000000000000
--- a/megatron/core/transformer/utils.py
+++ /dev/null
@@ -1,148 +0,0 @@
-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-
-"""Utilities for transformer layers."""
-from operator import itemgetter
-from typing import Any, Dict, Iterable, Optional, Tuple, Union
-
-import torch
-
-from megatron.core import parallel_state
-from megatron.core.dist_checkpointing.mapping import ShardedObject, StateDict
-from megatron.core.utils import (
- make_sharded_tensor_for_checkpoint,
- make_tp_sharded_tensor_for_checkpoint,
-)
-
-
-def get_linear_layer(rows, columns, init_method, perform_initialization=True):
- """Simple linear layer with weight initialization."""
- layer = torch.nn.Linear(rows, columns)
- if perform_initialization: # Take from modelparallel config
- init_method(layer.weight)
- with torch.no_grad():
- layer.bias.zero_()
- return layer
-
-
-def attention_mask_func(attention_scores, attention_mask):
- attention_scores.masked_fill_(attention_mask, -10000.0)
- return attention_scores
-
-
-@torch.jit.script
-def gelu_impl(x):
- """OpenAI's gelu implementation."""
- return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * (1.0 + 0.044715 * x * x)))
-
-
-def openai_gelu(x):
- return gelu_impl(x)
-
-
-# This is actually Python equivalent of torch.nn.functional.gelu(), also with type hints for ONNX exporter
-@torch.jit.script
-def erf_gelu(x):
- return (
- x * 0.5 * (torch.erf(x / 1.41421).to(dtype=x.dtype) + torch.ones_like(x).to(dtype=x.dtype))
- )
-
-
-def make_sharded_tensors_for_checkpoint(
- state_dict: StateDict,
- state_dict_prefix: str,
- sharded_key_prefix: Optional[str] = None,
- tensor_parallel_layers_axis_map: Optional[Dict[str, int]] = None,
- sharded_offsets: Iterable[Tuple[int, int, int]] = (),
- extra_state_suffix: str = '_extra_state',
-):
- """Wraps tensors from transformer layers with ShardedTensor or ShardedObject.
-
- For a given `state_dict`, wraps:
- - all _extra_states with ShardedObject
- - all tensors specified in tensor_parallel_layers_axis_map with TP and DP sharded ShardedTensor
- - other values with DP sharded ShardedTensor
-
- Args:
- state_dict (StateDict): state_dict to convert
- state_dict_prefix (str): prefix appended to keys in final state dict
- sharded_key_prefix (str, optional): prefix appended to ShardedTensor keys
- tensor_parallel_layers_axis_map (Dict[str, int], optional): dict mapping layer
- names to the axis for TP sharding
- sharded_offsets (Iterable[Tuple[int, int, int]], optional): sharding already
- applied (e.g. PP related), passed along to ShardedTensor
- extra_state_suffix (str, default = '_extra_state'): layers with this
- suffix will be wrapped with ShardedObject instead of ShardedTensor.
-
- """
- if sharded_key_prefix is None:
- sharded_key_prefix = state_dict_prefix
-
- if tensor_parallel_layers_axis_map is None:
- tensor_parallel_layers_axis_map = {}
-
- sharded_state_dict = {}
- for layer_name in state_dict.keys():
- tensor = state_dict[layer_name]
- layer_key = f'{state_dict_prefix}{layer_name}'
- sharded_key = f'{sharded_key_prefix}{layer_name}'
-
- if layer_name.endswith(extra_state_suffix):
- sharded_state_dict[layer_key] = make_sharded_object_for_checkpoint(
- tensor, sharded_key, sharded_offsets
- )
-
- elif layer_name in tensor_parallel_layers_axis_map:
- tp_axis = tensor_parallel_layers_axis_map[layer_name]
- sharded_state_dict[layer_key] = make_tp_sharded_tensor_for_checkpoint(
- tensor, sharded_key, tp_axis, prepend_offsets=sharded_offsets,
- )
-
- else:
- sharded_state_dict[layer_key] = make_sharded_tensor_for_checkpoint(
- tensor, sharded_key, prepend_offsets=sharded_offsets,
- )
-
- return sharded_state_dict
-
-
-def make_sharded_object_for_checkpoint(
- obj: Any,
- key: str,
- sharded_offsets: Iterable[Tuple[int, int, int]] = (),
- replica_id: Union[None, int, Tuple[int, ...]] = None,
- **kwargs,
-):
- """ Helper for instantiating a non-sharded ShardedObject (replicated across TP and DP group).
-
- Arguments:
- obj (object): any object to be sharded
- key (str): unique identifier of the object
- sharded_offsets (Iterable[Tuple[int, int, int]]): offsets normally
- prepended to ShardedTensors, will be used as global offsets for
- ShardedObject
- replica_id (Union[None, int, Tuple[int, ...]]): replica id
- """
- if replica_id is None:
- replica_id = (
- 0,
- parallel_state.get_tensor_model_parallel_rank(),
- parallel_state.get_data_parallel_rank(),
- )
-
- return ShardedObject(key, obj, *_get_extra_state_offsets(sharded_offsets), replica_id, **kwargs)
-
-
-def _get_extra_state_offsets(
- sharded_offsets: Iterable[Tuple[int, int, int]]
-) -> Tuple[Tuple[int, ...], Tuple[int, ...]]:
- """ Turns ShardedTensor offsets into offsets suitable for ShardedObject. """
- if sharded_offsets:
- sharded_offsets = sorted(sharded_offsets, key=itemgetter(0)) # sort by axis
- axis, extra_state_offset, extra_state_shape = zip(*sharded_offsets)
- assert list(axis) == list(
- range(len(axis))
- ), f'Expected contiguous axis for offsets: {sharded_offsets}'
- else:
- extra_state_shape = (1,)
- extra_state_offset = (0,)
- return extra_state_shape, extra_state_offset
diff --git a/megatron/core/utils.py b/megatron/core/utils.py
deleted file mode 100644
index d4e042b2d403022b2728ac4b76a7abadc3ec8117..0000000000000000000000000000000000000000
--- a/megatron/core/utils.py
+++ /dev/null
@@ -1,236 +0,0 @@
-# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
-
-"""Utility functions used throughout Megatron core"""
-import math
-import operator
-from functools import reduce
-
-import torch
-
-from megatron.core import parallel_state
-from megatron.core.dist_checkpointing.mapping import ShardedTensor
-
-
-def ensure_divisibility(numerator, denominator):
- """Ensure that numerator is divisible by the denominator."""
- assert numerator % denominator == 0, "{} is not divisible by {}".format(numerator, denominator)
-
-
-def divide(numerator, denominator):
- """Ensure that numerator is divisible by the denominator and return
- the division value."""
- ensure_divisibility(numerator, denominator)
- return numerator // denominator
-
-
-def get_attr_wrapped_model(model, attr, allow_none=True, return_model_obj=False):
- """Get an attribute from a wrapped model.
- If return_model_obj is true, return the object that has the 'attr' attribute;
- otherwise, return the attribute directly."""
- if isinstance(model, list):
- raise RuntimeError("_get_attr_wrapped_model given a list of models")
-
- if allow_none:
-
- def condition(model, attr):
- return not hasattr(model, attr)
-
- else:
-
- def condition(model, attr):
- return getattr(model, attr, None) is None
-
- while condition(model, attr):
- if not hasattr(model, "module"):
- raise RuntimeError(f"_get_attr_wrapped_model couldn't find attribute {attr}")
-
- model = model.module
-
- if return_model_obj:
- return model
- return getattr(model, attr)
-
-
-def get_model_type(model):
- return get_attr_wrapped_model(model, 'model_type')
-
-
-def get_model_config(model):
- return get_attr_wrapped_model(model, 'config', allow_none=False)
-
-
-class GlobalMemoryBuffer:
- """Global buffer to avoid dynamic memory allocations.
- Caller should ensure that buffers of the same name
- are not used concurrently."""
-
- def __init__(self):
- self.buffer = {}
-
- def get_tensor(self, tensor_shape, dtype, name):
- required_len = reduce(operator.mul, tensor_shape, 1)
- if (
- self.buffer.get((name, dtype), None) is None
- or self.buffer[(name, dtype)].numel() < required_len
- ):
- self.buffer[(name, dtype)] = torch.empty(
- required_len, dtype=dtype, device=torch.cuda.current_device(), requires_grad=False
- )
-
- return self.buffer[(name, dtype)][0:required_len].view(*tensor_shape)
-
-
-def _kernel_make_viewless_tensor(inp, requires_grad):
- '''Make a viewless tensor.
-
-    View tensors have the undesirable side-effect of retaining a reference
- to the originally-viewed tensor, even after manually setting the '.data'
- field. This method creates a new tensor that links to the old tensor's
- data, without linking the viewed tensor, referenced via the '._base'
- field.
- '''
- out = torch.empty((1,), dtype=inp.dtype, device=inp.device, requires_grad=requires_grad,)
- out.data = inp.data
- return out
-
-
-class MakeViewlessTensor(torch.autograd.Function):
- '''
- Autograd function to make a viewless tensor.
-
- This function should be used in cases where the computation graph needs
- to be propagated, but we only want a viewless tensor (e.g.,
- ParallelTransformer's hidden_states). Call this function by passing
- 'keep_graph = True' to 'make_viewless_tensor()'.
- '''
-
- @staticmethod
- def forward(ctx, inp, requires_grad):
- return _kernel_make_viewless_tensor(inp, requires_grad)
-
- @staticmethod
- def backward(ctx, grad_output):
- return grad_output, None
-
-
-def make_viewless_tensor(inp, requires_grad, keep_graph):
- '''
- Entry-point for creating viewless tensors.
-
- This method should be used, rather than calling 'MakeViewlessTensor'
- or '_kernel_make_viewless_tensor' directly. This method acts as a
- switch for determining if an autograd function or a regular method
- should be used to create the tensor.
- '''
-
- # return tensor as-is, if not a 'view'
- if inp._base is None:
- return inp
-
- # create viewless tensor
- if keep_graph:
- return MakeViewlessTensor.apply(inp, requires_grad)
- else:
- return _kernel_make_viewless_tensor(inp, requires_grad)
-
-
-def assert_viewless_tensor(tensor, extra_msg=None):
- '''Assert that a tensor is not a view (i.e., its '._base' field is
- not set).'''
- if isinstance(tensor, list):
- [assert_viewless_tensor(t) for t in tensor]
- return tensor
- if not isinstance(tensor, torch.Tensor):
- return tensor
- assert tensor._base is None, (
- "Ensure tensor._base is None before setting tensor.data or storing "
- "tensor to memory buffer. Otherwise, a memory leak will occur (and "
- "likely accumulate over iterations). %s"
- ) % extra_msg
- return tensor
-
-
-def safely_set_viewless_tensor_data(tensor, new_data_tensor):
- '''Safely set tensor's '.data' field.
-
- Check first that the tensor is viewless (i.e., '._base' not set). If not,
- raise an exception.
- '''
- assert_viewless_tensor(
- tensor,
- extra_msg="FYI, tensor._base has shape %s, and new_data_tensor has shape %s."
- % ("--" if tensor._base is None else tensor._base.shape, new_data_tensor.shape),
- )
- tensor.data = new_data_tensor
-
-
-def init_method_normal(sigma):
- """Init method based on N(0, sigma)."""
-
- def init_(tensor):
- return torch.nn.init.normal_(tensor, mean=0.0, std=sigma)
-
- return init_
-
-
-def scaled_init_method_normal(sigma, num_layers):
-    """Init method based on N(0, sigma/sqrt(2*num_layers))."""
- std = sigma / math.sqrt(2.0 * num_layers)
-
- def init_(tensor):
- return torch.nn.init.normal_(tensor, mean=0.0, std=std)
-
- return init_
-
-
-def make_tp_sharded_tensor_for_checkpoint(
- tensor, key, tp_axis=0, replica_id=None, prepend_offsets=(), **kwargs
-):
- """ Helper for instantiating a ShardedTensor where the `tp_axis` dimension is sharded across TP group.
-
- Optionally, can provide offsets which prepend new dimensions to the tensor.
- """
-
- prepend_axis_num = len(prepend_offsets)
-
- if replica_id is None:
- replica_id = (0, 0, parallel_state.get_data_parallel_rank())
-
- return ShardedTensor.from_rank_offsets(
- key,
- tensor,
- *prepend_offsets,
- (
- tp_axis + prepend_axis_num,
- parallel_state.get_tensor_model_parallel_rank(),
- parallel_state.get_tensor_model_parallel_world_size(),
- ),
- replica_id=replica_id,
- prepend_axis_num=prepend_axis_num,
- **kwargs,
- )
-
-
-def make_sharded_tensor_for_checkpoint(tensor, key, prepend_offsets=(), replica_id=None, **kwargs):
- """ Helper for instantiating a non-sharded ShardedTensor (replicated across TP and DP group).
-
- Optionally, can provide offsets which prepend new dimensions to the tensor.
- """
-
- prepend_axis_num = len(prepend_offsets)
-
- if replica_id is None:
- replica_id = (
- 0,
- parallel_state.get_tensor_model_parallel_rank(),
- parallel_state.get_data_parallel_rank(),
- )
-
- return ShardedTensor.from_rank_offsets(
- key,
- tensor,
- *prepend_offsets,
- replica_id=replica_id,
- prepend_axis_num=prepend_axis_num,
- **kwargs,
- )
diff --git a/megatron/data/__init__.py b/megatron/data/__init__.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/megatron/data/autoaugment.py b/megatron/data/autoaugment.py
deleted file mode 100644
index 7f988c5f0411707a8988e63898a49fabb932fbb5..0000000000000000000000000000000000000000
--- a/megatron/data/autoaugment.py
+++ /dev/null
@@ -1,320 +0,0 @@
-"""AutoAugment data augmentation policy for ImageNet.
-
--- Begin license text.
-
-MIT License
-
-Copyright (c) 2018 Philip Popien
-
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
-
--- End license text.
-
-Code adapted from https://github.com/DeepVoltaire/AutoAugment.
-
-This module implements the fixed AutoAugment data augmentation policy for ImageNet provided in
-Appendix A, Table 9 of reference [1]. It does not include any of the search code for augmentation
-policies.
-
-Reference:
-[1] https://arxiv.org/abs/1805.09501
-"""
-
-import random
-
-import numpy as np
-from PIL import Image
-from PIL import ImageEnhance
-from PIL import ImageOps
-
-_MAX_LEVEL = 10 # Maximum integer strength of an augmentation, if applicable.
-
-
-class ImageNetPolicy:
- """Definition of an ImageNetPolicy.
-
- Implements a fixed AutoAugment data augmentation policy targeted at
- ImageNet training by randomly applying at runtime one of the 25 pre-defined
- data augmentation sub-policies provided in Reference [1].
-
- Usage example as a Pytorch Transform:
- >>> transform=transforms.Compose([transforms.Resize(256),
- >>> ImageNetPolicy(),
- >>> transforms.ToTensor()])
- """
-
- def __init__(self, fillcolor=(128, 128, 128)):
- """Initialize an ImageNetPolicy.
-
- Args:
- fillcolor (tuple): RGB color components of the color to be used for
- filling when needed (default: (128, 128, 128), which
- corresponds to gray).
- """
- # Instantiate a list of sub-policies.
- # Each entry of the list is a SubPolicy which consists of
- # two augmentation operations,
- # each of those parametrized as operation, probability, magnitude.
- # Those two operations are applied sequentially on the image upon call.
- self.policies = [
- SubPolicy("posterize", 0.4, 8, "rotate", 0.6, 9, fillcolor),
- SubPolicy("solarize", 0.6, 5, "autocontrast", 0.6, 5, fillcolor),
- SubPolicy("equalize", 0.8, 8, "equalize", 0.6, 3, fillcolor),
- SubPolicy("posterize", 0.6, 7, "posterize", 0.6, 6, fillcolor),
- SubPolicy("equalize", 0.4, 7, "solarize", 0.2, 4, fillcolor),
- SubPolicy("equalize", 0.4, 4, "rotate", 0.8, 8, fillcolor),
- SubPolicy("solarize", 0.6, 3, "equalize", 0.6, 7, fillcolor),
- SubPolicy("posterize", 0.8, 5, "equalize", 1.0, 2, fillcolor),
- SubPolicy("rotate", 0.2, 3, "solarize", 0.6, 8, fillcolor),
- SubPolicy("equalize", 0.6, 8, "posterize", 0.4, 6, fillcolor),
- SubPolicy("rotate", 0.8, 8, "color", 0.4, 0, fillcolor),
- SubPolicy("rotate", 0.4, 9, "equalize", 0.6, 2, fillcolor),
- SubPolicy("equalize", 0.0, 7, "equalize", 0.8, 8, fillcolor),
- SubPolicy("invert", 0.6, 4, "equalize", 1.0, 8, fillcolor),
- SubPolicy("color", 0.6, 4, "contrast", 1.0, 8, fillcolor),
- SubPolicy("rotate", 0.8, 8, "color", 1.0, 2, fillcolor),
- SubPolicy("color", 0.8, 8, "solarize", 0.8, 7, fillcolor),
- SubPolicy("sharpness", 0.4, 7, "invert", 0.6, 8, fillcolor),
- SubPolicy("shearX", 0.6, 5, "equalize", 1.0, 9, fillcolor),
- SubPolicy("color", 0.4, 0, "equalize", 0.6, 3, fillcolor),
- SubPolicy("equalize", 0.4, 7, "solarize", 0.2, 4, fillcolor),
- SubPolicy("solarize", 0.6, 5, "autocontrast", 0.6, 5, fillcolor),
- SubPolicy("invert", 0.6, 4, "equalize", 1.0, 8, fillcolor),
- SubPolicy("color", 0.6, 4, "contrast", 1.0, 8, fillcolor),
- SubPolicy("equalize", 0.8, 8, "equalize", 0.6, 3, fillcolor),
- ]
-
- def __call__(self, img):
- """Define call method for ImageNetPolicy class."""
- policy_idx = random.randint(0, len(self.policies) - 1)
- return self.policies[policy_idx](img)
-
- def __repr__(self):
- """Define repr method for ImageNetPolicy class."""
- return "ImageNetPolicy"
-
-
-class SubPolicy:
- """Definition of a SubPolicy.
-
- A SubPolicy consists of two augmentation operations,
- each of those parametrized as operation, probability, magnitude.
- The two operations are applied sequentially on the image upon call.
- """
-
- def __init__(
- self,
- operation1,
- probability1,
- magnitude_idx1,
- operation2,
- probability2,
- magnitude_idx2,
- fillcolor,
- ):
- """Initialize a SubPolicy.
-
- Args:
- operation1 (str): Key specifying the first augmentation operation.
-                There are fourteen key values altogether (see supported_ops below
-                listing supported operations).
-            probability1 (float): Probability within [0., 1.] of applying the
-                first augmentation operation.
-            magnitude_idx1 (int): Integer specifying the strength of the first
- operation as an index further used to derive the magnitude from a
- range of possible values.
- operation2 (str): Key specifying the second augmentation operation.
- probability2 (float): Probability within [0., 1.] of applying the
- second augmentation operation.
-            magnitude_idx2 (int): Integer specifying the strength of the
- second operation as an index further used to derive the magnitude
- from a range of possible values.
- fillcolor (tuple): RGB color components of the color to be used for
- filling.
- Returns:
- """
- # List of supported operations for operation1 and operation2.
- supported_ops = [
- "shearX",
- "shearY",
- "translateX",
- "translateY",
- "rotate",
- "color",
- "posterize",
- "solarize",
- "contrast",
- "sharpness",
- "brightness",
- "autocontrast",
- "equalize",
- "invert",
- ]
- assert (operation1 in supported_ops) and (
- operation2 in supported_ops
- ), "SubPolicy:one of oper1 or oper2 refers to an unsupported operation."
-
- assert (
- 0.0 <= probability1 <= 1.0 and 0.0 <= probability2 <= 1.0
- ), "SubPolicy: prob1 and prob2 should be within [0., 1.]."
-
- assert (
- isinstance(magnitude_idx1, int) and 0 <= magnitude_idx1 <= 10
- ), "SubPolicy: idx1 should be specified as an integer within [0, 10]."
-
- assert (
- isinstance(magnitude_idx2, int) and 0 <= magnitude_idx2 <= 10
- ), "SubPolicy: idx2 should be specified as an integer within [0, 10]."
-
- # Define a dictionary where each key refers to a specific type of
- # augmentation and the corresponding value is a range of ten possible
- # magnitude values for that augmentation.
- num_levels = _MAX_LEVEL + 1
- ranges = {
- "shearX": np.linspace(0, 0.3, num_levels),
- "shearY": np.linspace(0, 0.3, num_levels),
- "translateX": np.linspace(0, 150 / 331, num_levels),
- "translateY": np.linspace(0, 150 / 331, num_levels),
- "rotate": np.linspace(0, 30, num_levels),
- "color": np.linspace(0.0, 0.9, num_levels),
- "posterize": np.round(np.linspace(8, 4, num_levels), 0).astype(
- np.int32
- ),
- "solarize": np.linspace(256, 0, num_levels), # range [0, 256]
- "contrast": np.linspace(0.0, 0.9, num_levels),
- "sharpness": np.linspace(0.0, 0.9, num_levels),
- "brightness": np.linspace(0.0, 0.9, num_levels),
- "autocontrast": [0]
- * num_levels, # This augmentation doesn't use magnitude parameter.
- "equalize": [0]
- * num_levels, # This augmentation doesn't use magnitude parameter.
- "invert": [0]
- * num_levels, # This augmentation doesn't use magnitude parameter.
- }
-
- def rotate_with_fill(img, magnitude):
- """Define rotation transformation with fill.
-
- The input image is first rotated, then it is blended together with
- a gray mask of the same size. Note that fillcolor as defined
- elsewhere in this module doesn't apply here.
-
- Args:
- magnitude (float): rotation angle in degrees.
- Returns:
- rotated_filled (PIL Image): rotated image with gray filling for
- disoccluded areas unveiled by the rotation.
- """
- rotated = img.convert("RGBA").rotate(magnitude)
- rotated_filled = Image.composite(
- rotated, Image.new("RGBA", rotated.size, (128,) * 4), rotated
- )
- return rotated_filled.convert(img.mode)
-
- # Define a dictionary of augmentation functions where each key refers
- # to a specific type of augmentation and the corresponding value defines
- # the augmentation itself using a lambda function.
- # pylint: disable=unnecessary-lambda
- func_dict = {
- "shearX": lambda img, magnitude: img.transform(
- img.size,
- Image.AFFINE,
- (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0),
- Image.BICUBIC,
- fillcolor=fillcolor,
- ),
- "shearY": lambda img, magnitude: img.transform(
- img.size,
- Image.AFFINE,
- (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0),
- Image.BICUBIC,
- fillcolor=fillcolor,
- ),
- "translateX": lambda img, magnitude: img.transform(
- img.size,
- Image.AFFINE,
- (
- 1,
- 0,
- magnitude * img.size[0] * random.choice([-1, 1]),
- 0,
- 1,
- 0,
- ),
- fillcolor=fillcolor,
- ),
- "translateY": lambda img, magnitude: img.transform(
- img.size,
- Image.AFFINE,
- (
- 1,
- 0,
- 0,
- 0,
- 1,
- magnitude * img.size[1] * random.choice([-1, 1]),
- ),
- fillcolor=fillcolor,
- ),
- "rotate": lambda img, magnitude: rotate_with_fill(img, magnitude),
- "color": lambda img, magnitude: ImageEnhance.Color(img).enhance(
- 1 + magnitude * random.choice([-1, 1])
- ),
- "posterize": lambda img, magnitude: ImageOps.posterize(
- img, magnitude
- ),
- "solarize": lambda img, magnitude: ImageOps.solarize(
- img, magnitude
- ),
- "contrast": lambda img, magnitude: ImageEnhance.Contrast(
- img
- ).enhance(1 + magnitude * random.choice([-1, 1])),
- "sharpness": lambda img, magnitude: ImageEnhance.Sharpness(
- img
- ).enhance(1 + magnitude * random.choice([-1, 1])),
- "brightness": lambda img, magnitude: ImageEnhance.Brightness(
- img
- ).enhance(1 + magnitude * random.choice([-1, 1])),
- "autocontrast": lambda img, magnitude: ImageOps.autocontrast(img),
- "equalize": lambda img, magnitude: ImageOps.equalize(img),
- "invert": lambda img, magnitude: ImageOps.invert(img),
- }
-
- # Store probability, function and magnitude of the first augmentation
- # for the sub-policy.
- self.probability1 = probability1
- self.operation1 = func_dict[operation1]
- self.magnitude1 = ranges[operation1][magnitude_idx1]
-
- # Store probability, function and magnitude of the second augmentation
- # for the sub-policy.
- self.probability2 = probability2
- self.operation2 = func_dict[operation2]
- self.magnitude2 = ranges[operation2][magnitude_idx2]
-
- def __call__(self, img):
- """Define call method for SubPolicy class."""
- # Randomly apply operation 1.
- if random.random() < self.probability1:
- img = self.operation1(img, self.magnitude1)
-
- # Randomly apply operation 2.
- if random.random() < self.probability2:
- img = self.operation2(img, self.magnitude2)
-
- return img
diff --git a/megatron/data/bert_dataset.py b/megatron/data/bert_dataset.py
deleted file mode 100644
index 036e6bccc9166214ed82f836dab4dbff974686ed..0000000000000000000000000000000000000000
--- a/megatron/data/bert_dataset.py
+++ /dev/null
@@ -1,183 +0,0 @@
-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-
-"""BERT Style dataset."""
-
-import numpy as np
-import torch
-
-from megatron import (
- get_args,
- get_tokenizer,
- mpu,
- print_rank_0
-)
-from megatron.data.dataset_utils import (
- get_samples_mapping,
- get_a_and_b_segments,
- truncate_segments,
- create_tokens_and_tokentypes,
- create_masked_lm_predictions
-)
-
-class BertDataset(torch.utils.data.Dataset):
-
- def __init__(self, name, indexed_dataset, data_prefix,
- num_epochs, max_num_samples, masked_lm_prob,
- max_seq_length, short_seq_prob, seed, binary_head):
-
- # Params to store.
- self.name = name
- self.seed = seed
- self.masked_lm_prob = masked_lm_prob
- self.max_seq_length = max_seq_length
- self.binary_head = binary_head
-
- # Dataset.
- self.indexed_dataset = indexed_dataset
-
- # Build the samples mapping.
- self.samples_mapping = get_samples_mapping(self.indexed_dataset,
- data_prefix,
- num_epochs,
- max_num_samples,
- self.max_seq_length - 3, # account for added tokens
- short_seq_prob,
- self.seed,
- self.name,
- self.binary_head)
-
- # Vocab stuff.
- tokenizer = get_tokenizer()
- self.vocab_id_list = list(tokenizer.inv_vocab.keys())
- self.vocab_id_to_token_dict = tokenizer.inv_vocab
- self.cls_id = tokenizer.cls
- self.sep_id = tokenizer.sep
- self.mask_id = tokenizer.mask
- self.pad_id = tokenizer.pad
-
- def __len__(self):
- return self.samples_mapping.shape[0]
-
- def __getitem__(self, idx):
- start_idx, end_idx, seq_length = self.samples_mapping[idx]
- sample = [self.indexed_dataset[i] for i in range(start_idx, end_idx)]
- # Note that this rng state should be numpy and not python since
- # python randint is inclusive whereas the numpy one is exclusive.
-        # We % 2**32 since numpy requires the seed to be between 0 and 2**32 - 1
- np_rng = np.random.RandomState(seed=((self.seed + idx) % 2**32))
- return build_training_sample(sample, seq_length,
- self.max_seq_length, # needed for padding
- self.vocab_id_list,
- self.vocab_id_to_token_dict,
- self.cls_id, self.sep_id,
- self.mask_id, self.pad_id,
- self.masked_lm_prob, np_rng,
- self.binary_head)
-
-
-
-
-def build_training_sample(sample,
- target_seq_length, max_seq_length,
- vocab_id_list, vocab_id_to_token_dict,
- cls_id, sep_id, mask_id, pad_id,
- masked_lm_prob, np_rng, binary_head):
-    """Build training sample.
-
- Arguments:
-        sample: A list of sentences in which each sentence is a list of token ids.
- target_seq_length: Desired sequence length.
- max_seq_length: Maximum length of the sequence. All values are padded to
- this length.
- vocab_id_list: List of vocabulary ids. Used to pick a random id.
- vocab_id_to_token_dict: A dictionary from vocab ids to text tokens.
- cls_id: Start of example id.
- sep_id: Separator id.
- mask_id: Mask token id.
- pad_id: Padding token id.
- masked_lm_prob: Probability to mask tokens.
-        np_rng: Random number generator. Note that this rng state should be
- numpy and not python since python randint is inclusive for
-            the upper bound whereas the numpy one is exclusive.
- """
-
- if binary_head:
- # We assume that we have at least two sentences in the sample
- assert len(sample) > 1
- assert target_seq_length <= max_seq_length
-
- # Divide sample into two segments (A and B).
- if binary_head:
- tokens_a, tokens_b, is_next_random = get_a_and_b_segments(sample,
- np_rng)
- else:
- tokens_a = []
- for j in range(len(sample)):
- tokens_a.extend(sample[j])
- tokens_b = []
- is_next_random = False
-
- # Truncate to `target_sequence_length`.
- max_num_tokens = target_seq_length
- truncated = truncate_segments(tokens_a, tokens_b, len(tokens_a),
- len(tokens_b), max_num_tokens, np_rng)
-
-    # Build tokens and tokentypes.
- tokens, tokentypes = create_tokens_and_tokentypes(tokens_a, tokens_b,
- cls_id, sep_id)
-
- # Masking.
- max_predictions_per_seq = masked_lm_prob * max_num_tokens
- (tokens, masked_positions, masked_labels, _, _) = create_masked_lm_predictions(
- tokens, vocab_id_list, vocab_id_to_token_dict, masked_lm_prob,
- cls_id, sep_id, mask_id, max_predictions_per_seq, np_rng)
-
- # Padding.
- tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np \
- = pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
- masked_labels, pad_id, max_seq_length)
-
- train_sample = {
- 'text': tokens_np,
- 'types': tokentypes_np,
- 'labels': labels_np,
- 'is_random': int(is_next_random),
- 'loss_mask': loss_mask_np,
- 'padding_mask': padding_mask_np,
- 'truncated': int(truncated)}
- return train_sample
-
-
-def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
- masked_labels, pad_id, max_seq_length):
- """Pad sequences and convert them to numpy."""
-
- # Some checks.
- num_tokens = len(tokens)
- padding_length = max_seq_length - num_tokens
- assert padding_length >= 0, \
- f"num_tokens ({num_tokens}) is greater than " \
-        f"max_seq_length ({max_seq_length})."
- assert len(tokentypes) == num_tokens
- assert len(masked_positions) == len(masked_labels)
-
- # Tokens and token types.
- filler = [pad_id] * padding_length
- tokens_np = np.array(tokens + filler, dtype=np.int64)
- tokentypes_np = np.array(tokentypes + filler, dtype=np.int64)
-
- # Padding mask.
- padding_mask_np = np.array([1] * num_tokens + [0] * padding_length,
- dtype=np.int64)
-
-    # Labels and loss mask.
- labels = [-1] * max_seq_length
- loss_mask = [0] * max_seq_length
- for i in range(len(masked_positions)):
- assert masked_positions[i] < num_tokens
- labels[masked_positions[i]] = masked_labels[i]
- loss_mask[masked_positions[i]] = 1
- labels_np = np.array(labels, dtype=np.int64)
- loss_mask_np = np.array(loss_mask, dtype=np.int64)
-
- return tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np
diff --git a/megatron/data/biencoder_dataset_utils.py b/megatron/data/biencoder_dataset_utils.py
deleted file mode 100644
index f137528adaef2af2d87567db9c73ef709b02b0f0..0000000000000000000000000000000000000000
--- a/megatron/data/biencoder_dataset_utils.py
+++ /dev/null
@@ -1,209 +0,0 @@
-import os
-import time
-
-import numpy as np
-import torch
-
-from megatron import get_args, get_tokenizer, print_rank_0
-from megatron.core import mpu, tensor_parallel
-from megatron.data.dataset_utils import create_masked_lm_predictions, \
- pad_and_convert_to_numpy
-from megatron.data.data_samplers import MegatronPretrainingSampler
-
-def make_attention_mask(source_block, target_block):
- """
- Returns a 2-dimensional (2-D) attention mask
- :param source_block: 1-D array
- :param target_block: 1-D array
- """
- mask = (target_block[None, :] >= 1) * (source_block[:, None] >= 1)
- mask = mask.astype(np.int64)
- # (source_length, target_length)
- return mask
-
-def get_one_epoch_dataloader(dataset, micro_batch_size=None):
- """Specifically one epoch to be used in an indexing job."""
- args = get_args()
-
- if micro_batch_size is None:
- micro_batch_size = args.micro_batch_size
- num_workers = args.num_workers
-
-    # Use megatron's sampler with consumed samples set to 0, as
-    # this is only for evaluation and we don't intend to resume halfway.
-    # Also, set drop_last to False, as we don't intend to drop
-    # the last batch.
- batch_sampler = MegatronPretrainingSampler(
- total_samples=len(dataset),
- consumed_samples=0,
- micro_batch_size=args.micro_batch_size,
- data_parallel_rank=mpu.get_data_parallel_rank(),
- data_parallel_size=mpu.get_data_parallel_world_size(),
- drop_last=False)
-
- return torch.utils.data.DataLoader(dataset,
- batch_sampler=batch_sampler,
- num_workers=num_workers,
- pin_memory=True)
-
-
-def get_ict_batch(data_iterator):
- # Items and their type.
- keys = ['query_tokens', 'query_mask',
- 'context_tokens', 'context_mask', 'block_data']
- datatype = torch.int64
-
- # Broadcast data.
- if data_iterator is None:
- data = None
- else:
- data = next(data_iterator)
- data_b = tensor_parallel.broadcast_data(keys, data, datatype)
-
- # Unpack.
- query_tokens = data_b['query_tokens'].long()
- query_mask = data_b['query_mask'] < 0.5
- context_tokens = data_b['context_tokens'].long()
- context_mask = data_b['context_mask'] < 0.5
- block_indices = data_b['block_data'].long()
-
- return query_tokens, query_mask,\
- context_tokens, context_mask, block_indices
-
-
-def join_str_list(str_list):
- """Join a list of strings, handling spaces appropriately"""
- result = ""
- for s in str_list:
- if s.startswith("##"):
- result += s[2:]
- else:
- result += " " + s
- return result
-
-
-class BlockSampleData(object):
- """A struct for fully describing a fixed-size block of data as used in REALM
-
- :param start_idx: for first sentence of the block
- :param end_idx: for last sentence of the block (may be partially truncated in sample construction)
- :param doc_idx: the index of the document from which the block comes in the original indexed dataset
- :param block_idx: a unique integer identifier given to every block.
- """
- def __init__(self, start_idx, end_idx, doc_idx, block_idx):
- self.start_idx = start_idx
- self.end_idx = end_idx
- self.doc_idx = doc_idx
- self.block_idx = block_idx
-
- def as_array(self):
- return np.array([self.start_idx, self.end_idx, self.doc_idx, self.block_idx]).astype(np.int64)
-
- def as_tuple(self):
- return self.start_idx, self.end_idx, self.doc_idx, self.block_idx
-
-
-class BlockSamplesMapping(object):
- def __init__(self, mapping_array):
- # make sure that the array is compatible with BlockSampleData
- assert mapping_array.shape[1] == 4
- self.mapping_array = mapping_array
-
- def __len__(self):
- return self.mapping_array.shape[0]
-
- def __getitem__(self, idx):
- """Get the data associated with an indexed sample."""
- sample_data = BlockSampleData(*self.mapping_array[idx])
- return sample_data
-
-
-def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epochs,
- max_num_samples, max_seq_length, seed, name, use_one_sent_docs=False):
- """Get samples mapping for a dataset over fixed size blocks. This function also requires
- a dataset of the titles for the source documents since their lengths must be taken into account.
-
- :return: samples_mapping (BlockSamplesMapping)
- """
-
- if not num_epochs:
- if not max_num_samples:
- raise ValueError("Need to specify either max_num_samples "
- "or num_epochs")
- num_epochs = np.iinfo(np.int32).max - 1
- if not max_num_samples:
- max_num_samples = np.iinfo(np.int64).max - 1
-
- # Filename of the index mapping
- indexmap_filename = data_prefix
- indexmap_filename += '_{}_indexmap'.format(name)
- if num_epochs != (np.iinfo(np.int32).max - 1):
- indexmap_filename += '_{}ep'.format(num_epochs)
- if max_num_samples != (np.iinfo(np.int64).max - 1):
- indexmap_filename += '_{}mns'.format(max_num_samples)
- indexmap_filename += '_{}msl'.format(max_seq_length)
- indexmap_filename += '_{}s'.format(seed)
- if use_one_sent_docs:
- indexmap_filename += '_1sentok'
- indexmap_filename += '.npy'
-
- # Build the indexed mapping if not exist.
- if mpu.get_data_parallel_rank() == 0 and \
- not os.path.isfile(indexmap_filename):
- print(' > WARNING: could not find index map file {}, building '
- 'the indices on rank 0 ...'.format(indexmap_filename))
-
- # Make sure the types match the helpers input types.
- assert block_dataset.document_indices.dtype == np.int64
- assert block_dataset.sequence_lengths.dtype == np.int32
-
- # Build samples mapping
- verbose = torch.distributed.get_rank() == 0
- start_time = time.time()
- print_rank_0(' > building samples index mapping for {} ...'.format(
- name))
-
- from megatron.core.datasets import helpers
- mapping_array = helpers.build_blocks_mapping(
- block_dataset.document_indices,
- block_dataset.sequence_lengths,
- title_dataset.sequence_lengths,
- num_epochs,
- max_num_samples,
- max_seq_length - 3, # account for added tokens
- seed,
- verbose,
- use_one_sent_docs)
-
-
- print_rank_0(' > done building samples index mapping')
- np.save(indexmap_filename, mapping_array, allow_pickle=True)
- print_rank_0(' > saved the index mapping in {}'.format(
- indexmap_filename))
- # Make sure all the ranks have built the mapping
- print_rank_0(' > elapsed time to build and save samples mapping '
- '(seconds): {:4f}'.format(
- time.time() - start_time))
-
- # This should be a barrier but nccl barrier assumes
- # device_index=rank which is not the case for model
- # parallel case
- counts = torch.cuda.LongTensor([1])
- torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
- assert counts[0].item() == torch.distributed.get_world_size(
- group=mpu.get_data_parallel_group())
-
- # Load indexed dataset.
- print_rank_0(' > loading indexed mapping from {}'.format(
- indexmap_filename))
- start_time = time.time()
-
- mapping_array = np.load(indexmap_filename, allow_pickle=True, mmap_mode='r')
- samples_mapping = BlockSamplesMapping(mapping_array)
-
- print_rank_0(' loaded indexed file in {:3.3f} seconds'.format(
- time.time() - start_time))
- print_rank_0(' total number of samples: {}'.format(
- mapping_array.shape[0]))
-
- return samples_mapping
diff --git a/megatron/data/data_samplers.py b/megatron/data/data_samplers.py
deleted file mode 100644
index 8dec2c192236f0c0af4b5534e963a1e65cf455ad..0000000000000000000000000000000000000000
--- a/megatron/data/data_samplers.py
+++ /dev/null
@@ -1,186 +0,0 @@
-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-
-"""Dataloaders."""
-
-
-import random
-import torch
-import numpy as np
-from torch.utils.data import Dataset
-from megatron import get_args
-from megatron.core import mpu
-
-
-def build_pretraining_data_loader(dataset, consumed_samples):
-    """Build dataloader given an input dataset."""
-
- if dataset is None:
- return None
- args = get_args()
-
- # Megatron sampler
- if args.dataloader_type == 'single':
- batch_sampler = MegatronPretrainingSampler(
- total_samples=len(dataset),
- consumed_samples=consumed_samples,
- micro_batch_size=args.micro_batch_size,
- data_parallel_rank=mpu.get_data_parallel_rank(),
- data_parallel_size=mpu.get_data_parallel_world_size())
- elif args.dataloader_type == 'cyclic':
- batch_sampler = MegatronPretrainingRandomSampler(
- dataset,
- total_samples=len(dataset),
- consumed_samples=consumed_samples,
- micro_batch_size=args.micro_batch_size,
- data_parallel_rank=mpu.get_data_parallel_rank(),
- data_parallel_size=mpu.get_data_parallel_world_size(),
- data_sharding=args.data_sharding)
- else:
- raise Exception('{} dataloader type is not supported.'.format(
- args.dataloader_type))
-
- # Torch dataloader.
- return torch.utils.data.DataLoader(dataset,
- batch_sampler=batch_sampler,
- num_workers=args.num_workers,
- pin_memory=True)
-
-class MegatronPretrainingSampler:
-
- def __init__(self, total_samples, consumed_samples, micro_batch_size,
- data_parallel_rank, data_parallel_size, drop_last=True):
- # Keep a copy of input params for later use.
- self.total_samples = total_samples
- self.consumed_samples = consumed_samples
- self.micro_batch_size = micro_batch_size
- self.data_parallel_rank = data_parallel_rank
- self.micro_batch_times_data_parallel_size = \
- self.micro_batch_size * data_parallel_size
- self.drop_last = drop_last
-
- # Sanity checks.
- assert self.total_samples > 0, \
- 'no sample to consume: {}'.format(self.total_samples)
- assert self.consumed_samples < self.total_samples, \
- 'no samples left to consume: {}, {}'.format(self.consumed_samples,
- self.total_samples)
- assert self.micro_batch_size > 0
- assert data_parallel_size > 0
- assert self.data_parallel_rank < data_parallel_size, \
- 'data_parallel_rank should be smaller than data size: {}, ' \
- '{}'.format(self.data_parallel_rank, data_parallel_size)
-
- def __len__(self):
- return self.total_samples
-
- def get_start_end_idx(self):
- start_idx = self.data_parallel_rank * self.micro_batch_size
- end_idx = start_idx + self.micro_batch_size
- return start_idx, end_idx
-
- def __iter__(self):
- batch = []
-        # The last batch will be dropped if drop_last is not set to False
- for idx in range(self.consumed_samples, self.total_samples):
- batch.append(idx)
- if len(batch) == self.micro_batch_times_data_parallel_size:
- start_idx, end_idx = self.get_start_end_idx()
- yield batch[start_idx:end_idx]
- batch = []
-
-        # Check the last partial batch and see if drop_last is set
- if len(batch) > 0 and not self.drop_last:
- start_idx, end_idx = self.get_start_end_idx()
- yield batch[start_idx:end_idx]
-
-
-class RandomSeedDataset(Dataset):
-
- def __init__(self, dataset):
- args = get_args()
- self.base_seed = args.seed
- self.curr_seed = args.seed
- self.dataset = dataset
-
- def __len__(self):
- return len(self.dataset)
-
- def set_epoch(self, epoch):
- self.curr_seed = self.base_seed + epoch
-
- def __getitem__(self, idx):
- seed = idx + self.curr_seed
- torch.manual_seed(seed)
- random.seed(seed)
- np.random.seed(seed)
- return self.dataset[idx]
-
-
-class MegatronPretrainingRandomSampler:
-
- def __init__(self, dataset, total_samples, consumed_samples, micro_batch_size,
- data_parallel_rank, data_parallel_size, data_sharding):
- # Keep a copy of input params for later use.
- self.dataset = dataset
- self.total_samples = total_samples
- self.consumed_samples = consumed_samples
- self.micro_batch_size = micro_batch_size
- self.data_parallel_rank = data_parallel_rank
- self.data_parallel_size = data_parallel_size
- self.data_sharding = data_sharding
- self.micro_batch_times_data_parallel_size = \
- self.micro_batch_size * data_parallel_size
- self.last_batch_size = \
- self.total_samples % self.micro_batch_times_data_parallel_size
-
- # Sanity checks.
- assert self.total_samples > 0, \
- 'no sample to consume: {}'.format(self.total_samples)
- assert self.micro_batch_size > 0
- assert data_parallel_size > 0
- assert self.data_parallel_rank < data_parallel_size, \
- 'data_parallel_rank should be smaller than data size: {}, ' \
- '{}'.format(self.data_parallel_rank, data_parallel_size)
-
- def __len__(self):
- return self.total_samples
-
- def __iter__(self):
- active_total_samples = self.total_samples - self.last_batch_size
- self.epoch = self.consumed_samples // active_total_samples
- current_epoch_samples = self.consumed_samples % active_total_samples
- assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0
-
- if isinstance(self.dataset, RandomSeedDataset):
- self.dataset.set_epoch(self.epoch)
-
- # data sharding and random sampling
- if self.data_sharding:
- bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) \
- * self.micro_batch_size
- bucket_offset = current_epoch_samples // self.data_parallel_size
- start_idx = self.data_parallel_rank * bucket_size
-
- g = torch.Generator()
- g.manual_seed(self.epoch)
- random_idx = torch.randperm(bucket_size, generator=g).tolist()
- idx_range = [start_idx + x for x in random_idx[bucket_offset:]]
- else:
- full_bucket_size = (self.total_samples // self.micro_batch_size) \
- * self.micro_batch_size
- full_bucket_offset = current_epoch_samples
- g = torch.Generator()
- g.manual_seed(self.epoch)
- idx_range_total = \
- torch.randperm(full_bucket_size, generator=g).tolist()
- idx_range_active = idx_range_total[full_bucket_offset:]
- idx_range = idx_range_active[self.data_parallel_rank::self.data_parallel_size]
-
- batch = []
-        # The last batch, if not complete, will be dropped.
- for idx in idx_range:
- batch.append(idx)
- if len(batch) == self.micro_batch_size:
- self.consumed_samples += self.micro_batch_times_data_parallel_size
- yield batch
- batch = []
diff --git a/megatron/data/dataset_utils.py b/megatron/data/dataset_utils.py
deleted file mode 100644
index 561129c865d14f2c5578dfbdc6b2a8bbfac88b78..0000000000000000000000000000000000000000
--- a/megatron/data/dataset_utils.py
+++ /dev/null
@@ -1,743 +0,0 @@
-# coding=utf-8
-# Copyright 2018 The Google AI Language Team Authors, and NVIDIA.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-# Most of the code here has been copied from:
-# https://github.com/google-research/albert/blob/master/create_pretraining_data.py
-# with some modifications.
-
-import math
-import os
-import time
-import collections
-
-import numpy as np
-import torch
-
-from megatron import (
- get_args,
- print_rank_0
-)
-from megatron.core import mpu
-from megatron.core.datasets.indexed_dataset import MMapIndexedDataset
-
-
-DSET_TYPE_BERT = 'standard_bert'
-DSET_TYPE_ICT = 'ict'
-DSET_TYPE_T5 = 't5'
-DSET_TYPE_MULTIMODAL = 'multimodal'
-
-DSET_TYPES = [DSET_TYPE_BERT, DSET_TYPE_ICT, DSET_TYPE_T5, DSET_TYPE_MULTIMODAL]
-
-
-def get_datasets_weights_and_num_samples(data_prefix,
- train_valid_test_num_samples):
-
- # The data prefix should be in the format of:
- # weight-1, data-prefix-1, weight-2, data-prefix-2, ..
- assert len(data_prefix) % 2 == 0
- num_datasets = len(data_prefix) // 2
- weights = [0]*num_datasets
- prefixes = [0]*num_datasets
- for i in range(num_datasets):
- weights[i] = float(data_prefix[2*i])
- prefixes[i] = (data_prefix[2*i+1]).strip()
- # Normalize weights
- weight_sum = 0.0
- for weight in weights:
- weight_sum += weight
- assert weight_sum > 0.0
- weights = [weight / weight_sum for weight in weights]
-
-    # Add 0.5% (the 1.005 factor) so in case the blending dataset does
- # not uniformly distribute the number of samples, we still have
- # samples left to feed to the network.
- if isinstance(train_valid_test_num_samples, list):
- datasets_train_valid_test_num_samples = []
- for weight in weights:
- datasets_train_valid_test_num_samples.append(
- [int(math.ceil(val * weight * 1.005))
- for val in train_valid_test_num_samples])
- else:
- # Used when separate dataset files are provided for train,
- # valid and test
- datasets_train_valid_test_num_samples = [
- int(math.ceil(train_valid_test_num_samples * weight * 1.005))
- for weight in weights]
-
- return prefixes, weights, datasets_train_valid_test_num_samples
-
-
-def get_a_and_b_segments(sample, np_rng):
- """Divide sample into a and b segments."""
-
- # Number of sentences in the sample.
- n_sentences = len(sample)
- # Make sure we always have two sentences.
- assert n_sentences > 1, 'make sure each sample has at least two sentences.'
-
- # First part:
- # `a_end` is how many sentences go into the `A`.
- a_end = 1
- if n_sentences >= 3:
-        # Note that randint in numpy is exclusive.
- a_end = np_rng.randint(1, n_sentences)
- tokens_a = []
- for j in range(a_end):
- tokens_a.extend(sample[j])
-
- # Second part:
- tokens_b = []
- for j in range(a_end, n_sentences):
- tokens_b.extend(sample[j])
-
- # Random next:
- is_next_random = False
- if np_rng.random() < 0.5:
- is_next_random = True
- tokens_a, tokens_b = tokens_b, tokens_a
-
- return tokens_a, tokens_b, is_next_random
-
-
-def truncate_segments(tokens_a, tokens_b, len_a, len_b, max_num_tokens, np_rng):
- """Truncates a pair of sequences to a maximum sequence length."""
- #print(len_a, len_b, max_num_tokens)
- assert len_a > 0
- if len_a + len_b <= max_num_tokens:
- return False
- while len_a + len_b > max_num_tokens:
- if len_a > len_b:
- len_a -= 1
- tokens = tokens_a
- else:
- len_b -= 1
- tokens = tokens_b
- if np_rng.random() < 0.5:
- del tokens[0]
- else:
- tokens.pop()
- return True
-
-
-def create_tokens_and_tokentypes(tokens_a, tokens_b, cls_id, sep_id):
- """Merge segments A and B, add [CLS] and [SEP] and build tokentypes."""
-
- tokens = []
- tokentypes = []
- # [CLS].
- tokens.append(cls_id)
- tokentypes.append(0)
- # Segment A.
- for token in tokens_a:
- tokens.append(token)
- tokentypes.append(0)
- # [SEP].
- tokens.append(sep_id)
- tokentypes.append(0)
- # Segment B.
- for token in tokens_b:
- tokens.append(token)
- tokentypes.append(1)
- if tokens_b:
- # [SEP].
- tokens.append(sep_id)
- tokentypes.append(1)
-
- return tokens, tokentypes
-
-
-MaskedLmInstance = collections.namedtuple("MaskedLmInstance",
- ["index", "label"])
-
-
-def is_start_piece(piece):
- """Check if the current word piece is the starting piece (BERT)."""
- # When a word has been split into
-    # WordPieces, the first token does not have any marker and any subsequent
- # tokens are prefixed with ##. So whenever we see the ## token, we
- # append it to the previous set of word indexes.
- return not piece.startswith("##")
-
-
-def create_masked_lm_predictions(tokens,
- vocab_id_list, vocab_id_to_token_dict,
- masked_lm_prob,
- cls_id, sep_id, mask_id,
- max_predictions_per_seq,
- np_rng,
- max_ngrams=3,
- do_whole_word_mask=True,
- favor_longer_ngram=False,
- do_permutation=False,
- geometric_dist=False,
- masking_style="bert"):
- """Creates the predictions for the masked LM objective.
- Note: Tokens here are vocab ids and not text tokens."""
-
- cand_indexes = []
- # Note(mingdachen): We create a list for recording if the piece is
- # the starting piece of current token, where 1 means true, so that
- # on-the-fly whole word masking is possible.
- token_boundary = [0] * len(tokens)
-
- for (i, token) in enumerate(tokens):
- if token == cls_id or token == sep_id:
- token_boundary[i] = 1
- continue
-        # Whole Word Masking means that we mask all of the wordpieces
- # corresponding to an original word.
- #
- # Note that Whole Word Masking does *not* change the training code
- # at all -- we still predict each WordPiece independently, softmaxed
- # over the entire vocabulary.
- if (do_whole_word_mask and len(cand_indexes) >= 1 and
- not is_start_piece(vocab_id_to_token_dict[token])):
- cand_indexes[-1].append(i)
- else:
- cand_indexes.append([i])
- if is_start_piece(vocab_id_to_token_dict[token]):
- token_boundary[i] = 1
-
- output_tokens = list(tokens)
-
- masked_lm_positions = []
- masked_lm_labels = []
-
- if masked_lm_prob == 0:
- return (output_tokens, masked_lm_positions,
- masked_lm_labels, token_boundary)
-
- num_to_predict = min(max_predictions_per_seq,
- max(1, int(round(len(tokens) * masked_lm_prob))))
-
- ngrams = np.arange(1, max_ngrams + 1, dtype=np.int64)
- if not geometric_dist:
- # Note(mingdachen):
-        # By default, we set the probabilities to favor shorter ngram sequences.
- pvals = 1. / np.arange(1, max_ngrams + 1)
- pvals /= pvals.sum(keepdims=True)
- if favor_longer_ngram:
- pvals = pvals[::-1]
-
- ngram_indexes = []
- for idx in range(len(cand_indexes)):
- ngram_index = []
- for n in ngrams:
- ngram_index.append(cand_indexes[idx:idx + n])
- ngram_indexes.append(ngram_index)
-
- np_rng.shuffle(ngram_indexes)
-
- (masked_lms, masked_spans) = ([], [])
- covered_indexes = set()
- for cand_index_set in ngram_indexes:
- if len(masked_lms) >= num_to_predict:
- break
- if not cand_index_set:
- continue
- # Note(mingdachen):
-        # Skip the current piece if it is covered in lm masking or previous ngrams.
- for index_set in cand_index_set[0]:
- for index in index_set:
- if index in covered_indexes:
- continue
-
- if not geometric_dist:
- n = np_rng.choice(ngrams[:len(cand_index_set)],
- p=pvals[:len(cand_index_set)] /
- pvals[:len(cand_index_set)].sum(keepdims=True))
- else:
- # Sampling "n" from the geometric distribution and clipping it to
- # the max_ngrams. Using p=0.2 default from the SpanBERT paper
- # https://arxiv.org/pdf/1907.10529.pdf (Sec 3.1)
- n = min(np_rng.geometric(0.2), max_ngrams)
-
- index_set = sum(cand_index_set[n - 1], [])
- n -= 1
- # Note(mingdachen):
- # Repeatedly looking for a candidate that does not exceed the
- # maximum number of predictions by trying shorter ngrams.
- while len(masked_lms) + len(index_set) > num_to_predict:
- if n == 0:
- break
- index_set = sum(cand_index_set[n - 1], [])
- n -= 1
- # If adding a whole-word mask would exceed the maximum number of
- # predictions, then just skip this candidate.
- if len(masked_lms) + len(index_set) > num_to_predict:
- continue
- is_any_index_covered = False
- for index in index_set:
- if index in covered_indexes:
- is_any_index_covered = True
- break
- if is_any_index_covered:
- continue
- for index in index_set:
- covered_indexes.add(index)
- masked_token = None
- if masking_style == "bert":
- # 80% of the time, replace with [MASK]
- if np_rng.random() < 0.8:
- masked_token = mask_id
- else:
- # 10% of the time, keep original
- if np_rng.random() < 0.5:
- masked_token = tokens[index]
- # 10% of the time, replace with random word
- else:
- masked_token = vocab_id_list[np_rng.randint(0, len(vocab_id_list))]
- elif masking_style == "t5":
- masked_token = mask_id
- else:
- raise ValueError("invalid value of masking style")
-
- output_tokens[index] = masked_token
- masked_lms.append(MaskedLmInstance(index=index, label=tokens[index]))
-
- masked_spans.append(MaskedLmInstance(
- index=index_set,
- label=[tokens[index] for index in index_set]))
-
- assert len(masked_lms) <= num_to_predict
- np_rng.shuffle(ngram_indexes)
-
- select_indexes = set()
- if do_permutation:
- for cand_index_set in ngram_indexes:
- if len(select_indexes) >= num_to_predict:
- break
- if not cand_index_set:
- continue
- # Note(mingdachen):
- # Skip the current pieces if they are covered by LM masking or previous ngrams.
- for index_set in cand_index_set[0]:
- for index in index_set:
- if index in covered_indexes or index in select_indexes:
- continue
-
- n = np.random.choice(ngrams[:len(cand_index_set)],
- p=pvals[:len(cand_index_set)] /
- pvals[:len(cand_index_set)].sum(keepdims=True))
- index_set = sum(cand_index_set[n - 1], [])
- n -= 1
-
- while len(select_indexes) + len(index_set) > num_to_predict:
- if n == 0:
- break
- index_set = sum(cand_index_set[n - 1], [])
- n -= 1
- # If adding a whole-word mask would exceed the maximum number of
- # predictions, then just skip this candidate.
- if len(select_indexes) + len(index_set) > num_to_predict:
- continue
- is_any_index_covered = False
- for index in index_set:
- if index in covered_indexes or index in select_indexes:
- is_any_index_covered = True
- break
- if is_any_index_covered:
- continue
- for index in index_set:
- select_indexes.add(index)
- assert len(select_indexes) <= num_to_predict
-
- select_indexes = sorted(select_indexes)
- permute_indexes = list(select_indexes)
- np_rng.shuffle(permute_indexes)
- orig_token = list(output_tokens)
-
- for src_i, tgt_i in zip(select_indexes, permute_indexes):
- output_tokens[src_i] = orig_token[tgt_i]
- masked_lms.append(MaskedLmInstance(index=src_i, label=orig_token[src_i]))
-
- masked_lms = sorted(masked_lms, key=lambda x: x.index)
- # Sort the spans by the index of the first span
- masked_spans = sorted(masked_spans, key=lambda x: x.index[0])
-
- for p in masked_lms:
- masked_lm_positions.append(p.index)
- masked_lm_labels.append(p.label)
- return (output_tokens, masked_lm_positions, masked_lm_labels, token_boundary, masked_spans)
-
-
-def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
- masked_labels, pad_id, max_seq_length):
- """Pad sequences and convert them to numpy."""
-
- # Some checks.
- num_tokens = len(tokens)
- padding_length = max_seq_length - num_tokens
- assert padding_length >= 0
- assert len(tokentypes) == num_tokens
- assert len(masked_positions) == len(masked_labels)
-
- # Tokens and token types.
- filler = [pad_id] * padding_length
- tokens_np = np.array(tokens + filler, dtype=np.int64)
- tokentypes_np = np.array(tokentypes + filler, dtype=np.int64)
-
- # Padding mask.
- padding_mask_np = np.array([1] * num_tokens + [0] * padding_length,
- dtype=np.int64)
-
- # Labels and loss mask.
- labels = [-1] * max_seq_length
- loss_mask = [0] * max_seq_length
- for i in range(len(masked_positions)):
- assert masked_positions[i] < num_tokens
- labels[masked_positions[i]] = masked_labels[i]
- loss_mask[masked_positions[i]] = 1
- labels_np = np.array(labels, dtype=np.int64)
- loss_mask_np = np.array(loss_mask, dtype=np.int64)
-
- return tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np
-
-
-def build_train_valid_test_datasets_with_prefixes(train_valid_test_num_samples,
- max_seq_length,
- seed,
- train_data_prefix=None,
- valid_data_prefix=None,
- test_data_prefix=None,
- binary_head=False,
- max_seq_length_dec=None,
- dataset_type='standard_bert'):
- print_rank_0("Separate data paths provided for train, valid & test.")
-
- train_dataset, valid_dataset, test_dataset = None, None, None
- # Single dataset.
- if train_data_prefix is not None:
- train_dataset = build_dataset("train", train_data_prefix,
- train_valid_test_num_samples[0],
- max_seq_length, seed,
- binary_head, max_seq_length_dec,
- dataset_type=dataset_type)
-
- if valid_data_prefix is not None:
- valid_dataset = build_dataset("valid", valid_data_prefix,
- train_valid_test_num_samples[1],
- max_seq_length, seed,
- binary_head, max_seq_length_dec,
- dataset_type=dataset_type)
-
- if test_data_prefix is not None:
- test_dataset = build_dataset("test", test_data_prefix,
- train_valid_test_num_samples[2],
- max_seq_length, seed,
- binary_head, max_seq_length_dec,
- dataset_type=dataset_type)
-
- return (train_dataset, valid_dataset, test_dataset)
-
-
-def build_train_valid_test_datasets(data_prefix, splits_string,
- train_valid_test_num_samples,
- max_seq_length, seed,
- binary_head=False,
- max_seq_length_dec=None,
- dataset_type='standard_bert'):
-
- if len(data_prefix) == 1:
- return _build_train_valid_test_datasets(data_prefix[0],
- splits_string,
- train_valid_test_num_samples,
- max_seq_length, seed,
- binary_head,
- max_seq_length_dec,
- dataset_type=dataset_type)
-
- raise NotImplementedError("Blending currently unsupported for non-GPT dataset instances")
-
-
-def _build_train_valid_test_datasets(data_prefix, splits_string,
- train_valid_test_num_samples,
- max_seq_length, seed,
- binary_head,
- max_seq_length_dec,
- dataset_type='standard_bert'):
-
- # Indexed dataset.
- indexed_dataset = get_indexed_dataset_(data_prefix,
- dataset_type)
-
- # Get start and end indices of train/valid/test into doc-idx
- # Note that doc-idx is designed to be num-docs + 1 so we can
- # easily iterate over it.
- total_num_of_documents = indexed_dataset.document_indices.shape[0] - 1
- splits = get_train_valid_test_split_(splits_string, total_num_of_documents)
-
- # Print stats about the splits.
- print_rank_0(' > dataset split:')
-
- def print_split_stats(name, index):
- print_rank_0(' {}:'.format(name))
- print_rank_0(' document indices in [{}, {}) total of {} '
- 'documents'.format(splits[index], splits[index + 1],
- splits[index + 1] - splits[index]))
- start_index = indexed_dataset.document_indices[splits[index]]
- end_index = indexed_dataset.document_indices[splits[index + 1]]
- print_rank_0(' sentence indices in [{}, {}) total of {} '
- 'sentences'.format(start_index, end_index,
- end_index - start_index))
- print_split_stats('train', 0)
- print_split_stats('validation', 1)
- print_split_stats('test', 2)
-
- def build_split_dataset(index, name):
- dataset = None
- if splits[index + 1] > splits[index]:
- # Get the pointer to the original doc-idx so we can set it later.
- doc_idx_ptr = indexed_dataset.get_document_indices()
- # Slice the doc-idx
- start_index = splits[index]
- # Add +1 so we can index into the dataset to get the upper bound.
- end_index = splits[index + 1] + 1
- # New doc_idx view.
- indexed_dataset.set_document_indices(doc_idx_ptr[start_index:end_index])
-
- dataset = build_dataset(
- name, data_prefix,
- train_valid_test_num_samples[index], max_seq_length,
- seed, binary_head, max_seq_length_dec,
- dataset_type, indexed_dataset)
-
- # Set the original pointer so dataset remains the main dataset.
- indexed_dataset.set_document_indices(doc_idx_ptr)
- # Checks.
- assert indexed_dataset.document_indices[0] == 0
- assert indexed_dataset.document_indices.shape[0] == \
- (total_num_of_documents + 1)
- return dataset
-
- train_dataset = build_split_dataset(0, 'train')
- valid_dataset = build_split_dataset(1, 'valid')
- test_dataset = build_split_dataset(2, 'test')
-
- return (train_dataset, valid_dataset, test_dataset)
-
-
-def build_dataset(name, data_prefix, max_num_samples,
- max_seq_length, seed, binary_head,
- max_seq_length_dec, dataset_type='standard_bert',
- indexed_dataset=None):
-
- from megatron.data.bert_dataset import BertDataset
- from megatron.data.ict_dataset import ICTDataset
- from megatron.data.t5_dataset import T5Dataset
- from megatron.data.multimodal_dataset import MultiModalDataset
-
- if dataset_type not in DSET_TYPES:
- raise ValueError("Invalid dataset_type: ", dataset_type)
-
- if indexed_dataset is None:
- indexed_dataset = get_indexed_dataset_(data_prefix,
- dataset_type)
-
- kwargs = dict(
- name=name,
- data_prefix=data_prefix,
- num_epochs=None,
- max_num_samples=max_num_samples,
- max_seq_length=max_seq_length,
- seed=seed,
- )
-
- if dataset_type == DSET_TYPE_ICT:
- args = get_args()
-
- title_dataset = get_indexed_dataset_(
- args.titles_data_path,
- dataset_type)
-
- dataset = ICTDataset(
- block_dataset=indexed_dataset,
- title_dataset=title_dataset,
- query_in_block_prob=args.query_in_block_prob,
- use_one_sent_docs=args.use_one_sent_docs,
- binary_head=binary_head,
- **kwargs
- )
- elif dataset_type == DSET_TYPE_T5:
- args = get_args()
- dataset = T5Dataset(
- indexed_dataset=indexed_dataset,
- masked_lm_prob=args.mask_prob,
- max_seq_length_dec=max_seq_length_dec,
- short_seq_prob=args.short_seq_prob,
- **kwargs
- )
- elif dataset_type == DSET_TYPE_BERT:
- args = get_args()
- dataset = BertDataset(
- indexed_dataset=indexed_dataset,
- masked_lm_prob=args.mask_prob,
- short_seq_prob=args.short_seq_prob,
- binary_head=binary_head,
- **kwargs
- )
- elif dataset_type == DSET_TYPE_MULTIMODAL:
- args = get_args()
- dataset = MultiModalDataset(
- name=name,
- data_prefix=data_prefix,
- indexed_dataset=indexed_dataset,
- num_samples=max_num_samples,
- seq_length=max_seq_length,
- seed=seed,
- img_h=args.img_h,
- img_w=args.img_w,
- )
- else:
- raise NotImplementedError("Dataset type not fully implemented.")
-
- return dataset
-
-
-def get_indexed_dataset_(data_prefix, dataset_type):
-
- print_rank_0(' > building dataset index ...')
-
- start_time = time.time()
- multimodal = dataset_type == DSET_TYPE_MULTIMODAL
- indexed_dataset = MMapIndexedDataset(data_prefix, multimodal)
- assert indexed_dataset.sequence_lengths.shape[0] == indexed_dataset.document_indices[-1]
- print_rank_0(' > finished creating indexed dataset in {:4f} '
- 'seconds'.format(time.time() - start_time))
-
- print_rank_0(' > indexed dataset stats:')
- print_rank_0(' number of documents: {}'.format(
- indexed_dataset.document_indices.shape[0] - 1))
- print_rank_0(' number of sentences: {}'.format(
- indexed_dataset.sequence_lengths.shape[0]))
-
- return indexed_dataset
-
-
-def get_train_valid_test_split_(splits_string, size):
- """ Get dataset splits from comma or '/' separated string list."""
-
- splits = []
- if splits_string.find(',') != -1:
- splits = [float(s) for s in splits_string.split(',')]
- elif splits_string.find('/') != -1:
- splits = [float(s) for s in splits_string.split('/')]
- else:
- splits = [float(splits_string)]
- while len(splits) < 3:
- splits.append(0.)
- splits = splits[:3]
- splits_sum = sum(splits)
- assert splits_sum > 0.0
- splits = [split / splits_sum for split in splits]
- splits_index = [0]
- for index, split in enumerate(splits):
- splits_index.append(splits_index[index] +
- int(round(split * float(size))))
- diff = splits_index[-1] - size
- for index in range(1, len(splits_index)):
- splits_index[index] -= diff
- assert len(splits_index) == 4
- assert splits_index[-1] == size
- return splits_index
-
-def get_samples_mapping(indexed_dataset,
- data_prefix,
- num_epochs,
- max_num_samples,
- max_seq_length,
- short_seq_prob,
- seed,
- name,
- binary_head):
- """Get a list that maps a sample index to a starting sentence index, end sentence index, and length"""
-
- if not num_epochs:
- if not max_num_samples:
- raise ValueError("Need to specify either max_num_samples "
- "or num_epochs")
- num_epochs = np.iinfo(np.int32).max - 1
- if not max_num_samples:
- max_num_samples = np.iinfo(np.int64).max - 1
-
- # Filename of the index mapping
- indexmap_filename = data_prefix
- indexmap_filename += '_{}_indexmap'.format(name)
- if num_epochs != (np.iinfo(np.int32).max - 1):
- indexmap_filename += '_{}ep'.format(num_epochs)
- if max_num_samples != (np.iinfo(np.int64).max - 1):
- indexmap_filename += '_{}mns'.format(max_num_samples)
- indexmap_filename += '_{}msl'.format(max_seq_length)
- indexmap_filename += '_{:0.2f}ssp'.format(short_seq_prob)
- indexmap_filename += '_{}s'.format(seed)
- indexmap_filename += '.npy'
-
- # Build the indexed mapping if not exist.
- if torch.distributed.get_rank() == 0 and \
- not os.path.isfile(indexmap_filename):
- print(' > WARNING: could not find index map file {}, building '
- 'the indices on rank 0 ...'.format(indexmap_filename))
-
- # Make sure the types match the helpers input types.
- assert indexed_dataset.document_indices.dtype == np.int64
- assert indexed_dataset.sequence_lengths.dtype == np.int32
-
- # Build samples mapping
- verbose = torch.distributed.get_rank() == 0
- start_time = time.time()
- print_rank_0(' > building samples index mapping for {} ...'.format(
- name))
- # First compile and then import.
- from megatron.core.datasets import helpers
- samples_mapping = helpers.build_mapping(
- indexed_dataset.document_indices,
- indexed_dataset.sequence_lengths,
- num_epochs,
- max_num_samples,
- max_seq_length,
- short_seq_prob,
- seed,
- verbose,
- 2 if binary_head else 1)
- print_rank_0(' > done building samples index mapping')
- np.save(indexmap_filename, samples_mapping, allow_pickle=True)
- print_rank_0(' > saved the index mapping in {}'.format(
- indexmap_filename))
- # Make sure all the ranks have built the mapping
- print_rank_0(' > elapsed time to build and save samples mapping '
- '(seconds): {:4f}'.format(
- time.time() - start_time))
- # This should be a barrier but nccl barrier assumes
- # device_index=rank which is not the case for model
- # parallel case
- counts = torch.cuda.LongTensor([1])
- torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
- torch.distributed.all_reduce(counts, group=mpu.get_pipeline_model_parallel_group())
- assert counts[0].item() == (
- torch.distributed.get_world_size() //
- torch.distributed.get_world_size(group=mpu.get_tensor_model_parallel_group()))
-
- # Load indexed dataset.
- print_rank_0(' > loading indexed mapping from {}'.format(
- indexmap_filename))
- start_time = time.time()
- samples_mapping = np.load(indexmap_filename, allow_pickle=True, mmap_mode='r')
- print_rank_0(' loaded indexed file in {:3.3f} seconds'.format(
- time.time() - start_time))
- print_rank_0(' total number of samples: {}'.format(
- samples_mapping.shape[0]))
-
- return samples_mapping
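The deleted `get_train_valid_test_split_` above turns a split string into document-index boundaries. Below is a minimal standalone sketch of that logic; the function name and the `949,50,1` example are illustrative and not part of the original file.

```python
# Minimal sketch of the split logic in the deleted get_train_valid_test_split_.
def split_boundaries(splits_string, size):
    sep = ',' if ',' in splits_string else '/'
    splits = [float(s) for s in splits_string.split(sep)]
    splits = (splits + [0.0, 0.0])[:3]
    fractions = [s / sum(splits) for s in splits]
    index = [0]
    for frac in fractions:
        index.append(index[-1] + int(round(frac * float(size))))
    # Push any rounding error into the later boundaries so the last one equals size.
    diff = index[-1] - size
    for i in range(1, len(index)):
        index[i] -= diff
    return index

# With 1000 documents, "949,50,1" puts docs [0, 949) in train, [949, 999) in valid,
# and [999, 1000) in test.
assert split_boundaries("949,50,1", 1000) == [0, 949, 999, 1000]
```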
diff --git a/megatron/data/ict_dataset.py b/megatron/data/ict_dataset.py
deleted file mode 100644
index 6dac35ff9d413898146df5c1cc8553719e142105..0000000000000000000000000000000000000000
--- a/megatron/data/ict_dataset.py
+++ /dev/null
@@ -1,156 +0,0 @@
-import itertools
-import random
-
-import numpy as np
-from torch.utils.data import Dataset
-
-from megatron import get_tokenizer
-from megatron import get_args
-from megatron.data.dataset_utils import get_indexed_dataset_
-from megatron.data.realm_dataset_utils import get_block_samples_mapping
-
-def make_attention_mask(source_block, target_block):
- """
- Returns a 2-dimensional (2-D) attention mask
- :param source_block: 1-D array
- :param target_block: 1-D array
- """
- mask = (target_block[None, :] >= 1) * (source_block[:, None] >= 1)
- mask = mask.astype(np.int64)
- # (source_length, target_length)
- return mask
-
-def get_ict_dataset(use_titles=True, query_in_block_prob=1):
- """Get a dataset which uses block samples mappings to get ICT/block indexing data (via get_block())
- rather than for training, since it is only built with a single epoch sample mapping.
- """
- args = get_args()
- block_dataset = get_indexed_dataset_(args.data_path, 'mmap', True)
- titles_dataset = get_indexed_dataset_(args.titles_data_path, 'mmap', True)
-
- kwargs = dict(
- name='full',
- block_dataset=block_dataset,
- title_dataset=titles_dataset,
- data_prefix=args.data_path,
- num_epochs=1,
- max_num_samples=None,
- max_seq_length=args.seq_length,
- seed=1,
- query_in_block_prob=query_in_block_prob,
- use_titles=use_titles,
- use_one_sent_docs=args.use_one_sent_docs
- )
- dataset = ICTDataset(**kwargs)
- return dataset
-
-
-class ICTDataset(Dataset):
- """Dataset containing sentences and their blocks for an inverse cloze task."""
- def __init__(self, name, block_dataset, title_dataset, data_prefix,
- num_epochs, max_num_samples, max_seq_length, query_in_block_prob,
- seed, use_titles=True, use_one_sent_docs=False, binary_head=False):
- self.name = name
- self.seed = seed
- self.max_seq_length = max_seq_length
- self.query_in_block_prob = query_in_block_prob
- self.block_dataset = block_dataset
- self.title_dataset = title_dataset
- self.rng = random.Random(self.seed)
- self.use_titles = use_titles
- self.use_one_sent_docs = use_one_sent_docs
-
- self.samples_mapping = get_block_samples_mapping(
- block_dataset, title_dataset, data_prefix, num_epochs,
- max_num_samples, max_seq_length, seed, name, use_one_sent_docs)
- self.tokenizer = get_tokenizer()
- self.vocab_id_list = list(self.tokenizer.inv_vocab.keys())
- self.vocab_id_to_token_list = self.tokenizer.inv_vocab
- self.cls_id = self.tokenizer.cls
- self.sep_id = self.tokenizer.sep
- self.mask_id = self.tokenizer.mask
- self.pad_id = self.tokenizer.pad
-
- def __len__(self):
- return len(self.samples_mapping)
-
- def __getitem__(self, idx):
- """Get an ICT example of a pseudo-query and the block of text from which it was extracted"""
- sample_data = self.samples_mapping[idx]
- start_idx, end_idx, doc_idx, block_idx = sample_data.as_tuple()
-
- if self.use_titles:
- title = self.title_dataset[int(doc_idx)]
- title_pad_offset = 3 + len(title)
- else:
- title = None
- title_pad_offset = 2
- block = [self.block_dataset[i] for i in range(start_idx, end_idx)]
- assert len(block) > 1 or self.use_one_sent_docs or self.query_in_block_prob == 1
-
- # randint() is inclusive for Python rng
- rand_sent_idx = self.rng.randint(0, len(block) - 1)
-
- # keep the query in the context query_in_block_prob fraction of the time.
- if self.rng.random() < self.query_in_block_prob:
- query = block[rand_sent_idx].copy()
- else:
- query = block.pop(rand_sent_idx)
-
- # still need to truncate because blocks are concluded when
- # the sentence lengths have exceeded max_seq_length.
- query = query[:self.max_seq_length - 2]
- block = list(itertools.chain(*block))[:self.max_seq_length - title_pad_offset]
-
- query_tokens, query_pad_mask = self.concat_and_pad_tokens(query)
- context_tokens, context_pad_mask = self.concat_and_pad_tokens(block, title)
-
- query_mask = make_attention_mask(query_tokens, query_tokens)
- context_mask = make_attention_mask(context_tokens, context_tokens)
-
- block_data = sample_data.as_array()
-
- sample = {
- 'query_tokens': query_tokens,
- 'query_mask': query_mask,
- 'query_pad_mask': query_pad_mask,
- 'context_tokens': context_tokens,
- 'context_mask': context_mask,
- 'context_pad_mask': context_pad_mask,
- 'block_data': block_data,
- }
-
- return sample
-
- def get_block(self, start_idx, end_idx, doc_idx):
- """Get the IDs for an evidence block plus the title of the corresponding document"""
- block = [self.block_dataset[i] for i in range(start_idx, end_idx)]
- title = self.title_dataset[int(doc_idx)]
-
- block = list(itertools.chain(*block))[:self.max_seq_length - (3 + len(title))]
- block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)
-
- return block_tokens, block_pad_mask
-
- def get_null_block(self):
- """Get empty block and title - used in REALM pretraining"""
- block, title = [], []
- block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)
-
- return block_tokens, block_pad_mask
-
- def concat_and_pad_tokens(self, tokens, title=None):
- """Concat with special tokens and pad sequence to self.max_seq_length"""
- tokens = list(tokens)
- if title is None:
- tokens = [self.cls_id] + tokens + [self.sep_id]
- else:
- title = list(title)
- tokens = [self.cls_id] + title + [self.sep_id] + tokens + [self.sep_id]
- assert len(tokens) <= self.max_seq_length
-
- num_pad = self.max_seq_length - len(tokens)
- pad_mask = [1] * len(tokens) + [0] * num_pad
- tokens += [self.pad_id] * num_pad
-
- return np.array(tokens), np.array(pad_mask)
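For reference, the 2-D attention mask built by the removed `make_attention_mask` simply marks which non-padding positions may attend to each other. A small NumPy illustration with made-up token ids:

```python
import numpy as np

# [CLS] hello [SEP] pad pad -- token ids are invented, 0 marks padding.
query = np.array([101, 7592, 102, 0, 0])
mask = (query[None, :] >= 1) * (query[:, None] >= 1)
mask = mask.astype(np.int64)

print(mask.shape)  # (5, 5)
print(mask[0])     # [1 1 1 0 0] -> position 0 attends only to the three real tokens
print(mask[3])     # [0 0 0 0 0] -> padding positions attend to nothing
```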
diff --git a/megatron/data/image_folder.py b/megatron/data/image_folder.py
deleted file mode 100644
index de15b29bf0665562a00bfcab8b106ff2d4ca26f2..0000000000000000000000000000000000000000
--- a/megatron/data/image_folder.py
+++ /dev/null
@@ -1,302 +0,0 @@
-# BSD 3-Clause License
-#
-# Copyright (c) Soumith Chintala 2016,
-# All rights reserved.
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions are met:
-#
-# * Redistributions of source code must retain the above copyright notice, this
-# list of conditions and the following disclaimer.
-#
-# * Redistributions in binary form must reproduce the above copyright notice,
-# this list of conditions and the following disclaimer in the documentation
-# and/or other materials provided with the distribution.
-#
-# * Neither the name of the copyright holder nor the names of its
-# contributors may be used to endorse or promote products derived from
-# this software without specific prior written permission.
-
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
-# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
-# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
-# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
-# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
-# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-# code taken from
-# https://github.com/pytorch/vision/blob/main/torchvision/datasets/folder.py
-# added support for classes_fraction and data_per_class_fraction
-
-from torchvision.datasets import VisionDataset
-from PIL import Image
-
-import os
-import os.path
-from typing import Any, Callable, cast, Dict, List, Optional, Tuple
-import numpy as np
-
-def has_file_allowed_extension(filename: str, extensions: Tuple[str, ...]) -> bool:
- """Checks if a file is an allowed extension.
- Args:
- filename (string): path to a file
- extensions (tuple of strings): extensions to consider (lowercase)
- Returns:
- bool: True if the filename ends with one of given extensions
- """
- return filename.lower().endswith(extensions)
-
-
-def is_image_file(filename: str) -> bool:
- """Checks if a file is an allowed image extension.
- Args:
- filename (string): path to a file
- Returns:
- bool: True if the filename ends with a known image extension
- """
- return has_file_allowed_extension(filename, IMG_EXTENSIONS)
-
-
-def make_dataset(
- directory: str,
- class_to_idx: Dict[str, int],
- data_per_class_fraction: float,
- extensions: Optional[Tuple[str, ...]] = None,
- is_valid_file: Optional[Callable[[str], bool]] = None,
-) -> List[Tuple[str, int]]:
- """Generates a list of samples of a form (path_to_sample, class).
- Args:
- directory (str): root dataset directory
- class_to_idx (Dict[str, int]): dictionary mapping class name to class index
- extensions (optional): A list of allowed extensions.
- Either extensions or is_valid_file should be passed. Defaults to None.
- is_valid_file (optional): A function that takes the path of a file
- and checks if the file is a valid file
- (used to check for corrupt files). Both extensions and
- is_valid_file should not be passed. Defaults to None.
- Raises:
- ValueError: In case ``extensions`` and ``is_valid_file`` are None or both are not None.
- Returns:
- List[Tuple[str, int]]: samples of a form (path_to_sample, class)
- """
- instances = []
- directory = os.path.expanduser(directory)
- both_none = extensions is None and is_valid_file is None
- both_something = extensions is not None and is_valid_file is not None
- if both_none or both_something:
- raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")
- if extensions is not None:
- def is_valid_file(x: str) -> bool:
- return has_file_allowed_extension(x, cast(Tuple[str, ...], extensions))
- is_valid_file = cast(Callable[[str], bool], is_valid_file)
- for target_class in sorted(class_to_idx.keys()):
- class_index = class_to_idx[target_class]
- target_dir = os.path.join(directory, target_class)
- if not os.path.isdir(target_dir):
- continue
- local_instances = []
- for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
- for fname in sorted(fnames):
- path = os.path.join(root, fname)
- if is_valid_file(path):
- item = path, class_index
- local_instances.append(item)
-
- instances.extend(local_instances[0:int(len(local_instances) * data_per_class_fraction)])
-
- return instances
-
-
-class DatasetFolder(VisionDataset):
- """A generic data loader where the samples are arranged in this way: ::
- root/class_x/xxx.ext
- root/class_x/xxy.ext
- root/class_x/[...]/xxz.ext
- root/class_y/123.ext
- root/class_y/nsdf3.ext
- root/class_y/[...]/asd932_.ext
- Args:
- root (string): Root directory path.
- loader (callable): A function to load a sample given its path.
- extensions (tuple[string]): A list of allowed extensions.
- both extensions and is_valid_file should not be passed.
- transform (callable, optional): A function/transform that takes in
- a sample and returns a transformed version.
- E.g., ``transforms.RandomCrop`` for images.
- target_transform (callable, optional): A function/transform that takes
- in the target and transforms it.
- is_valid_file (callable, optional): A function that takes the path of a file
- and checks if the file is a valid file (used to check for corrupt files);
- both extensions and is_valid_file should not be passed.
- Attributes:
- classes (list): List of the class names sorted alphabetically.
- class_to_idx (dict): Dict with items (class_name, class_index).
- samples (list): List of (sample path, class_index) tuples
- targets (list): The class_index value for each image in the dataset
- """
-
- def __init__(
- self,
- root: str,
- loader: Callable[[str], Any],
- extensions: Optional[Tuple[str, ...]] = None,
- transform: Optional[Callable] = None,
- target_transform: Optional[Callable] = None,
- classes_fraction=1.0,
- data_per_class_fraction=1.0,
- is_valid_file: Optional[Callable[[str], bool]] = None,
- ) -> None:
- super(DatasetFolder, self).__init__(root, transform=transform,
- target_transform=target_transform)
- self.classes_fraction = classes_fraction
- self.data_per_class_fraction = data_per_class_fraction
- classes, class_to_idx = self._find_classes(self.root)
- samples = self.make_dataset(self.root,
- class_to_idx,
- self.data_per_class_fraction,
- extensions,
- is_valid_file)
- if len(samples) == 0:
- msg = "Found 0 files in subfolders of: {}\n".format(self.root)
- if extensions is not None:
- msg += "Supported extensions are: {}".format(",".join(extensions))
- raise RuntimeError(msg)
-
- self.loader = loader
- self.extensions = extensions
- self.total = len(samples)
- self.classes = classes
- self.class_to_idx = class_to_idx
- self.samples = samples
- self.targets = [s[1] for s in samples]
-
- @staticmethod
- def make_dataset(
- directory: str,
- class_to_idx: Dict[str, int],
- data_per_class_fraction: float,
- extensions: Optional[Tuple[str, ...]] = None,
- is_valid_file: Optional[Callable[[str], bool]] = None,
- ) -> List[Tuple[str, int]]:
- return make_dataset(directory,
- class_to_idx,
- data_per_class_fraction,
- extensions=extensions,
- is_valid_file=is_valid_file)
-
- def _find_classes(self, dir: str) -> Tuple[List[str], Dict[str, int]]:
- """
- Finds the class folders in a dataset.
- Args:
- dir (string): Root directory path.
- Returns:
- tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary.
- Ensures:
- No class is a subdirectory of another.
- """
- all_classes = [d.name for d in os.scandir(dir) if d.is_dir()]
- classes = all_classes[0:int(len(all_classes) * self.classes_fraction)]
- classes.sort()
- class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
- return classes, class_to_idx
-
- def __getitem__(self, index: int) -> Tuple[Any, Any]:
- """
- Args:
- index (int): Index
- Returns:
- tuple: (sample, target) where target is class_index of the target class.
- """
- curr_index = index
- for x in range(self.total):
- try:
- path, target = self.samples[curr_index]
- sample = self.loader(path)
- break
- except Exception as e:
- curr_index = np.random.randint(0, self.total)
-
- if self.transform is not None:
- sample = self.transform(sample)
- if self.target_transform is not None:
- target = self.target_transform(target)
-
- return sample, target
-
- def __len__(self) -> int:
- return len(self.samples)
-
-
-IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')
-
-
-def pil_loader(path: str) -> Image.Image:
- # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
- with open(path, 'rb') as f:
- img = Image.open(f)
- return img.convert('RGB')
-
-
-# TODO: specify the return type
-def accimage_loader(path: str) -> Any:
- import accimage
- try:
- return accimage.Image(path)
- except IOError:
- # Potentially a decoding problem, fall back to PIL.Image
- return pil_loader(path)
-
-
-def default_loader(path: str) -> Any:
- from torchvision import get_image_backend
- if get_image_backend() == 'accimage':
- return accimage_loader(path)
- else:
- return pil_loader(path)
-
-
-class ImageFolder(DatasetFolder):
- """A generic data loader where the images are arranged in this way: ::
- root/dog/xxx.png
- root/dog/xxy.png
- root/dog/[...]/xxz.png
- root/cat/123.png
- root/cat/nsdf3.png
- root/cat/[...]/asd932_.png
- Args:
- root (string): Root directory path.
- transform (callable, optional): A function/transform that takes in a PIL image
- and returns a transformed version. E.g., ``transforms.RandomCrop``
- target_transform (callable, optional): A function/transform that takes in the
- target and transforms it.
- loader (callable, optional): A function to load an image given its path.
- is_valid_file (callable, optional): A function that takes the path of an Image file
- and checks if the file is a valid file (used to check for corrupt files)
- Attributes:
- classes (list): List of the class names sorted alphabetically.
- class_to_idx (dict): Dict with items (class_name, class_index).
- imgs (list): List of (image path, class_index) tuples
- """
-
- def __init__(
- self,
- root: str,
- transform: Optional[Callable] = None,
- target_transform: Optional[Callable] = None,
- classes_fraction=1.0,
- data_per_class_fraction=1.0,
- loader: Callable[[str], Any] = default_loader,
- is_valid_file: Optional[Callable[[str], bool]] = None,
- ):
- super(ImageFolder, self).__init__(root, loader, IMG_EXTENSIONS if is_valid_file is None else None,
- transform=transform,
- target_transform=target_transform,
- classes_fraction=classes_fraction,
- data_per_class_fraction=data_per_class_fraction,
- is_valid_file=is_valid_file)
- self.imgs = self.samples
-
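The `classes_fraction` and `data_per_class_fraction` arguments added to this `DatasetFolder`/`ImageFolder` variant subsample the class list and each class's file list before training. A pure-Python sketch of that slicing, with invented class names and file counts (no real image directory is touched):

```python
# Invented class names and file lists stand in for a real root/class_x/*.jpg layout.
all_classes = ["cat", "dog", "fox", "owl"]
files_per_class = {c: ["{}_{}.jpg".format(c, i) for i in range(8)] for c in all_classes}

classes_fraction = 0.5          # keep the first half of the class list
data_per_class_fraction = 0.25  # keep the first quarter of each kept class's files

classes = all_classes[:int(len(all_classes) * classes_fraction)]
classes.sort()
class_to_idx = {name: i for i, name in enumerate(classes)}

samples = []
for name in classes:
    local = files_per_class[name]
    kept = local[:int(len(local) * data_per_class_fraction)]
    samples.extend((path, class_to_idx[name]) for path in kept)

print(classes)  # ['cat', 'dog']
print(samples)  # [('cat_0.jpg', 0), ('cat_1.jpg', 0), ('dog_0.jpg', 1), ('dog_1.jpg', 1)]
```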
diff --git a/megatron/data/multimodal_dataset.py b/megatron/data/multimodal_dataset.py
deleted file mode 100644
index 93ea790329c7aa58f2e63745c4aaca464567c99f..0000000000000000000000000000000000000000
--- a/megatron/data/multimodal_dataset.py
+++ /dev/null
@@ -1,54 +0,0 @@
-# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
-
-from PIL import Image, UnidentifiedImageError
-import numpy as np
-import io
-import torch
-
-try:
- from torchvision.transforms import InterpolationMode
- BICUBIC = InterpolationMode.BICUBIC
-except ImportError:
- BICUBIC = Image.BICUBIC
-
-from torchvision.transforms import Compose, ToTensor, Normalize, ToPILImage, RandomResizedCrop, Resize
-
-def _convert_image_to_rgb(image):
- return image.convert("RGB")
-
-def _transform(img_h, img_w):
- return Compose([
- ToPILImage(),
- RandomResizedCrop((img_h, img_w), scale=(0.5, 1.0), interpolation=BICUBIC),
- _convert_image_to_rgb,
- ToTensor(),
- Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
- ])
-
-class MultiModalDataset(torch.utils.data.Dataset):
-
- def __init__(self, name, data_prefix, indexed_dataset,
- num_samples, seq_length, seed, img_h, img_w):
-
- self.name = name
- self.indexed_dataset = indexed_dataset
- self.doc_idx = indexed_dataset.get_document_indices()
- self.visual_transform = _transform(img_h, img_w)
-
- def __len__(self):
- return self.indexed_dataset.sequence_lengths.shape[0]
-
- def __getitem__(self, idx):
- text_sample, mode = self.indexed_dataset.get(self.doc_idx[idx])
- assert mode == 0
- img_sample, mode = self.indexed_dataset.get(self.doc_idx[idx]+1)
- assert mode == 1
- img_pad = img_sample[0].item()
- xs = img_sample[1:].tobytes(order='C')
- xs = xs[:len(xs)-img_pad]
-
- img_sample = np.array(Image.open(io.BytesIO(xs)))
- img_sample = self.visual_transform(img_sample).reshape(-1)
-
- return {'text': np.array(text_sample, dtype=np.int64),
- 'img': np.array(img_sample, dtype=np.float32)}
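`MultiModalDataset.__getitem__` above recovers the raw image bytes from a uint8 record whose first element stores the number of trailing pad bytes. A round-trip sketch of that packing with a placeholder payload; the 4-byte alignment used here is an assumption for illustration only:

```python
import numpy as np

raw = b"\x89PNG...fake image bytes"          # placeholder payload, not a real image
pad = (4 - len(raw) % 4) % 4                 # assumed 4-byte alignment, for illustration
packed = np.frombuffer(bytes([pad]) + raw + b"\x00" * pad, dtype=np.uint8)

# Decoding mirrors __getitem__: the first byte is the pad count, the rest is the image
# minus the trailing pad bytes.
img_pad = packed[0].item()
xs = packed[1:].tobytes(order='C')
xs = xs[:len(xs) - img_pad]
assert xs == raw
```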
diff --git a/megatron/data/orqa_wiki_dataset.py b/megatron/data/orqa_wiki_dataset.py
deleted file mode 100644
index 4019cd764c204b34de0df28bfcc7969c3c19d937..0000000000000000000000000000000000000000
--- a/megatron/data/orqa_wiki_dataset.py
+++ /dev/null
@@ -1,193 +0,0 @@
-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-
-"""Wikipedia dataset from DPR code for ORQA."""
-
-from abc import ABC
-import csv
-import numpy as np
-import random
-import torch
-from torch.utils.data import Dataset
-
-from megatron import print_rank_0, get_args, get_tokenizer
-from megatron.core import tensor_parallel
-from megatron.data.biencoder_dataset_utils import make_attention_mask
-
-def get_open_retrieval_wiki_dataset():
- args = get_args()
- tokenizer = get_tokenizer()
-
- dataset = OpenRetrievalEvidenceDataset('2018 Wikipedia from DPR codebase',
- 'evidence',
- args.evidence_data_path,
- tokenizer,
- args.retriever_seq_length)
- return dataset
-
-
-def get_open_retrieval_batch(data_iterator):
- # Items and their type.
- keys = ['row_id', 'context', 'context_mask', 'context_types',
- 'context_pad_mask']
- datatype = torch.int64
-
- # Broadcast data.
- data = None if data_iterator is None else next(data_iterator)
- data_b = tensor_parallel.broadcast_data(keys, data, datatype)
-
- # Unpack.
- row_id = data_b['row_id'].long()
- context = data_b['context'].long()
-
- # TODO: make the context mask a binary one
- context_mask = (data_b['context_mask'] < 0.5)
-
- context_types = data_b['context_types'].long()
- context_pad_mask = data_b['context_pad_mask'].long()
-
- return row_id, context, context_mask, context_types, context_pad_mask
-
-
-def build_tokens_types_paddings_from_text(row, tokenizer, max_seq_length):
- """Build token types and paddings, trim if needed, and pad if needed."""
-
- title_ids = tokenizer.tokenize(row['title'])
- context_ids = tokenizer.tokenize(row['text'])
-
- # Appending the title of the context at front
- extended_context_ids = title_ids + [tokenizer.sep_id] + context_ids
-
- context_ids, context_types, context_pad_mask = \
- build_tokens_types_paddings_from_ids(extended_context_ids,
- max_seq_length, tokenizer.cls, tokenizer.sep, tokenizer.pad)
-
- return context_ids, context_types, context_pad_mask
-
-
-# noinspection DuplicatedCode
-def build_tokens_types_paddings_from_ids(text_ids, max_seq_length,
- cls_id, sep_id, pad_id):
- """Build token types and paddings, trim if needed, and pad if needed."""
- enc_ids = []
- tokentypes_enc = []
-
- # [CLS].
- enc_ids.append(cls_id)
- tokentypes_enc.append(0)
-
- # A.
- len_src = len(text_ids)
- enc_ids.extend(text_ids)
- tokentypes_enc.extend([0] * len_src)
-
- # Cap the size.
- if len(enc_ids) > max_seq_length - 1:
- enc_ids = enc_ids[0: max_seq_length - 1]
- tokentypes_enc = tokentypes_enc[0: max_seq_length - 1]
-
- # [SEP].
- enc_ids.append(sep_id)
- tokentypes_enc.append(0)
-
- num_tokens_enc = len(enc_ids)
- # Padding.
- padding_length = max_seq_length - len(enc_ids)
- if padding_length > 0:
- enc_ids.extend([pad_id] * padding_length)
- tokentypes_enc.extend([pad_id] * padding_length)
-
- pad_mask = ([1] * num_tokens_enc) + ([0] * padding_length)
- pad_mask = np.array(pad_mask, dtype=np.int64)
-
- return enc_ids, tokentypes_enc, pad_mask
-
-
-def build_sample(row_id, context_ids, context_types, context_pad_mask):
- """Convert to numpy and return a sample consumed by the batch producer."""
-
- context_ids = np.array(context_ids, dtype=np.int64)
- context_types = np.array(context_types, dtype=np.int64)
- context_mask = make_attention_mask(context_ids, context_ids)
-
- sample = ({
- 'row_id': row_id,
- 'context': context_ids,
- 'context_mask': context_mask,
- 'context_types': context_types,
- 'context_pad_mask': context_pad_mask
- })
- return sample
-
-
-class OpenRetrievalEvidenceDataset(ABC, Dataset):
- """Open Retrieval Evidence dataset class."""
-
- def __init__(self, task_name, dataset_name, datapath, tokenizer,
- max_seq_length):
- # Store inputs.
- self.task_name = task_name
- self.dataset_name = dataset_name
- self.tokenizer = tokenizer
- self.max_seq_length = max_seq_length
- print_rank_0(' > building {} dataset for {}:'.format(self.task_name,
- self.dataset_name))
- # Process the files.
- print_rank_0(datapath)
- self.samples, self.id2text = self.process_samples_from_single_path(
- datapath)
-
- args = get_args()
- if args.sample_rate < 1: # subsample
- k = int(len(self.samples) * args.sample_rate)
- self.samples = random.sample(self.samples, k)
-
- print_rank_0(' >> total number of samples: {}'.format(
- len(self.samples)))
-
- def __len__(self):
- return len(self.samples)
-
- def __getitem__(self, idx):
- row = self.samples[idx]
-
- context_ids, context_types, context_pad_mask = \
- build_tokens_types_paddings_from_text(row, self.tokenizer,
- self.max_seq_length)
-
- sample = build_sample(row['doc_id'],
- context_ids,
- context_types,
- context_pad_mask)
- return sample
-
- @staticmethod
- def process_samples_from_single_path(filename):
- print_rank_0(' > Processing {} ...'.format(filename))
- total = 0
-
- rows = []
- id2text = {}
-
- with open(filename) as tsvfile:
- reader = csv.reader(tsvfile, delimiter='\t')
- next(reader, None) # skip the headers
- for row in reader:
- # file format: doc_id, doc_text, title
- doc_id = int(row[0])
- text = row[1]
- title = row[2]
-
- rows.append({'doc_id': doc_id,
- 'text': text,
- 'title': title})
-
- assert doc_id not in id2text
- id2text[doc_id] = (text, title)
-
- total += 1
- if total % 100000 == 0:
- print_rank_0(' > processed {} rows so far ...'.format(
- total))
-
- print_rank_0(' >> processed {} samples.'.format(len(rows)))
- return rows, id2text
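`build_tokens_types_paddings_from_ids` in the removed file lays each evidence passage out as `[CLS] ids [SEP]` followed by padding. A standalone sketch of that layout with made-up ids (`cls=101`, `sep=102`, `pad=0` are arbitrary):

```python
import numpy as np

def layout(text_ids, max_seq_length, cls_id=101, sep_id=102, pad_id=0):
    # [CLS] + (possibly truncated) ids + [SEP], then pad out to max_seq_length.
    ids = [cls_id] + list(text_ids)[:max_seq_length - 2] + [sep_id]
    tokentypes = [0] * len(ids)
    pad_len = max_seq_length - len(ids)
    pad_mask = np.array([1] * len(ids) + [0] * pad_len, dtype=np.int64)
    ids += [pad_id] * pad_len
    tokentypes += [pad_id] * pad_len
    return ids, tokentypes, pad_mask

ids, tokentypes, pad_mask = layout([7592, 2088], max_seq_length=6)
print(ids)       # [101, 7592, 2088, 102, 0, 0]
print(pad_mask)  # [1 1 1 1 0 0]
```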
diff --git a/megatron/data/realm_dataset_utils.py b/megatron/data/realm_dataset_utils.py
deleted file mode 100644
index 3c8672bb583aa0c839fc055405ae3761dda48911..0000000000000000000000000000000000000000
--- a/megatron/data/realm_dataset_utils.py
+++ /dev/null
@@ -1,199 +0,0 @@
-import os
-import time
-
-import numpy as np
-import torch
-
-from megatron import print_rank_0
-from megatron.core import mpu, tensor_parallel
-from megatron.data.dataset_utils import create_masked_lm_predictions, pad_and_convert_to_numpy
-from megatron import get_args, get_tokenizer, print_rank_0
-
-
-def get_one_epoch_dataloader(dataset, micro_batch_size=None):
- """Specifically one epoch to be used in an indexing job."""
- args = get_args()
-
- world_size = mpu.get_data_parallel_world_size()
- rank = mpu.get_data_parallel_rank()
- if micro_batch_size is None:
- micro_batch_size = args.micro_batch_size
- global_batch_size = micro_batch_size * world_size
- num_workers = args.num_workers
-
- sampler = torch.utils.data.SequentialSampler(dataset)
- # importantly, drop_last must be False to get all the data.
- assert False, 'DistributedBatchSampler deprecated, change the implementation'
- from megatron.data.samplers import DistributedBatchSampler
- batch_sampler = DistributedBatchSampler(sampler,
- batch_size=global_batch_size,
- drop_last=False,
- rank=rank,
- world_size=world_size)
-
- return torch.utils.data.DataLoader(dataset,
- batch_sampler=batch_sampler,
- num_workers=num_workers,
- pin_memory=True)
-
-
-def get_ict_batch(data_iterator):
- # Items and their type.
- keys = ['query_tokens', 'query_pad_mask',
- 'block_tokens', 'block_pad_mask', 'block_data']
- datatype = torch.int64
-
- # Broadcast data.
- if data_iterator is None:
- data = None
- else:
- data = next(data_iterator)
- data_b = tensor_parallel.broadcast_data(keys, data, datatype)
-
- # Unpack.
- query_tokens = data_b['query_tokens'].long()
- query_pad_mask = data_b['query_pad_mask'].long()
- block_tokens = data_b['block_tokens'].long()
- block_pad_mask = data_b['block_pad_mask'].long()
- block_indices = data_b['block_data'].long()
-
- return query_tokens, query_pad_mask,\
- block_tokens, block_pad_mask, block_indices
-
-
-def join_str_list(str_list):
- """Join a list of strings, handling spaces appropriately"""
- result = ""
- for s in str_list:
- if s.startswith("##"):
- result += s[2:]
- else:
- result += " " + s
- return result
-
-
-class BlockSampleData(object):
- """A struct for fully describing a fixed-size block of data as used in REALM
-
- :param start_idx: for first sentence of the block
- :param end_idx: for last sentence of the block (may be partially truncated in sample construction)
- :param doc_idx: the index of the document from which the block comes in the original indexed dataset
- :param block_idx: a unique integer identifier given to every block.
- """
- def __init__(self, start_idx, end_idx, doc_idx, block_idx):
- self.start_idx = start_idx
- self.end_idx = end_idx
- self.doc_idx = doc_idx
- self.block_idx = block_idx
-
- def as_array(self):
- return np.array([self.start_idx, self.end_idx, self.doc_idx, self.block_idx]).astype(np.int64)
-
- def as_tuple(self):
- return self.start_idx, self.end_idx, self.doc_idx, self.block_idx
-
-
-class BlockSamplesMapping(object):
- def __init__(self, mapping_array):
- # make sure that the array is compatible with BlockSampleData
- assert mapping_array.shape[1] == 4
- self.mapping_array = mapping_array
-
- def __len__(self):
- return self.mapping_array.shape[0]
-
- def __getitem__(self, idx):
- """Get the data associated with an indexed sample."""
- sample_data = BlockSampleData(*self.mapping_array[idx])
- return sample_data
-
-
-def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epochs,
- max_num_samples, max_seq_length, seed, name, use_one_sent_docs=False):
- """Get samples mapping for a dataset over fixed size blocks. This function also requires
- a dataset of the titles for the source documents since their lengths must be taken into account.
-
- :return: samples_mapping (BlockSamplesMapping)
- """
-
- if not num_epochs:
- if not max_num_samples:
- raise ValueError("Need to specify either max_num_samples "
- "or num_epochs")
- num_epochs = np.iinfo(np.int32).max - 1
- if not max_num_samples:
- max_num_samples = np.iinfo(np.int64).max - 1
-
- # Filename of the index mapping
- indexmap_filename = data_prefix
- indexmap_filename += '_{}_indexmap'.format(name)
- if num_epochs != (np.iinfo(np.int32).max - 1):
- indexmap_filename += '_{}ep'.format(num_epochs)
- if max_num_samples != (np.iinfo(np.int64).max - 1):
- indexmap_filename += '_{}mns'.format(max_num_samples)
- indexmap_filename += '_{}msl'.format(max_seq_length)
- indexmap_filename += '_{}s'.format(seed)
- if use_one_sent_docs:
- indexmap_filename += '_1sentok'
- indexmap_filename += '.npy'
-
- # Build the indexed mapping if not exist.
- if mpu.get_data_parallel_rank() == 0 and \
- not os.path.isfile(indexmap_filename):
- print(' > WARNING: could not find index map file {}, building '
- 'the indices on rank 0 ...'.format(indexmap_filename))
-
- # Make sure the types match the helpers input types.
- assert block_dataset.document_indices.dtype == np.int64
- assert block_dataset.sequence_lengths.dtype == np.int32
-
- # Build samples mapping
- verbose = torch.distributed.get_rank() == 0
- start_time = time.time()
- print_rank_0(' > building samples index mapping for {} ...'.format(
- name))
-
- from megatron.core.datasets import helpers
- mapping_array = helpers.build_blocks_mapping(
- block_dataset.document_indices,
- block_dataset.sequence_lengths,
- title_dataset.sequence_lengths,
- num_epochs,
- max_num_samples,
- max_seq_length - 3, # account for added tokens
- seed,
- verbose,
- use_one_sent_docs)
-
-
- print_rank_0(' > done building samples index mapping')
- np.save(indexmap_filename, mapping_array, allow_pickle=True)
- print_rank_0(' > saved the index mapping in {}'.format(
- indexmap_filename))
- # Make sure all the ranks have built the mapping
- print_rank_0(' > elapsed time to build and save samples mapping '
- '(seconds): {:4f}'.format(
- time.time() - start_time))
-
- # This should be a barrier but nccl barrier assumes
- # device_index=rank which is not the case for model
- # parallel case
- counts = torch.cuda.LongTensor([1])
- torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
- assert counts[0].item() == torch.distributed.get_world_size(
- group=mpu.get_data_parallel_group())
-
- # Load indexed dataset.
- print_rank_0(' > loading indexed mapping from {}'.format(
- indexmap_filename))
- start_time = time.time()
-
- mapping_array = np.load(indexmap_filename, allow_pickle=True, mmap_mode='r')
- samples_mapping = BlockSamplesMapping(mapping_array)
-
- print_rank_0(' loaded indexed file in {:3.3f} seconds'.format(
- time.time() - start_time))
- print_rank_0(' total number of samples: {}'.format(
- mapping_array.shape[0]))
-
- return samples_mapping
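`get_block_samples_mapping` caches its result in a `.npy` file whose name encodes the sampling parameters. A simplified sketch of that naming scheme; the data prefix and parameter values are invented, and the `np.iinfo` sentinel handling of the original is reduced to `None` checks here:

```python
def indexmap_filename(data_prefix, name, num_epochs, max_num_samples,
                      max_seq_length, seed, use_one_sent_docs):
    fname = "{}_{}_indexmap".format(data_prefix, name)
    if num_epochs is not None:
        fname += "_{}ep".format(num_epochs)
    if max_num_samples is not None:
        fname += "_{}mns".format(max_num_samples)
    fname += "_{}msl_{}s".format(max_seq_length, seed)
    if use_one_sent_docs:
        fname += "_1sentok"
    return fname + ".npy"

print(indexmap_filename("my-corpus_text_sentence", "train", 1, None, 288, 1234, True))
# my-corpus_text_sentence_train_indexmap_1ep_288msl_1234s_1sentok.npy
```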
diff --git a/megatron/data/realm_index.py b/megatron/data/realm_index.py
deleted file mode 100644
index 1fa4a309edcd5a761b3a87b973b799c30ac73458..0000000000000000000000000000000000000000
--- a/megatron/data/realm_index.py
+++ /dev/null
@@ -1,224 +0,0 @@
-import itertools
-import os
-import pickle
-import shutil
-
-import numpy as np
-import torch
-
-from megatron import get_args
-from megatron.core import mpu
-
-
-def detach(tensor):
- return tensor.detach().cpu().numpy()
-
-
-class OpenRetreivalDataStore(object):
- """
- Serializable data structure for holding data for blocks --
- embeddings and necessary metadata for Retriever
- """
- def __init__(self, embedding_path=None, load_from_path=True, rank=None):
- self.embed_data = dict()
- if embedding_path is None:
- args = get_args()
- embedding_path = args.embedding_path
- rank = args.rank
- self.embedding_path = embedding_path
- self.rank = rank
-
- if load_from_path:
- self.load_from_file()
-
- block_data_name = os.path.splitext(self.embedding_path)[0]
- self.temp_dir_name = block_data_name + '_tmp'
-
- def state(self):
- return {
- 'embed_data': self.embed_data,
- }
-
- def clear(self):
- """
- Clear the embedding data structures to save memory.
- The metadata ends up getting used, and is also much smaller in
- dimensionality so it isn't really worth clearing.
- """
- self.embed_data = dict()
-
- def load_from_file(self):
- """Populate members from instance saved to file"""
-
- if not mpu.model_parallel_is_initialized() or mpu.get_data_parallel_rank() == 0:
- print("\n> Unpickling BlockData", flush=True)
- state_dict = pickle.load(open(self.embedding_path, 'rb'))
- if not mpu.model_parallel_is_initialized() or mpu.get_data_parallel_rank() == 0:
- print(">> Finished unpickling BlockData\n", flush=True)
-
- self.embed_data = state_dict['embed_data']
-
- def add_block_data(self, row_id, block_embeds, allow_overwrite=False):
- """
- Add data for set of blocks
- :param row_id: 1D array of unique int ids for the blocks
- :param block_embeds: 2D array of embeddings of the blocks
- In the case of retriever this will be [start_idx, end_idx, doc_idx]
- """
- for idx, embed in zip(row_id, block_embeds):
- if not allow_overwrite and idx in self.embed_data:
- raise ValueError("Unexpectedly tried to overwrite block data")
-
- self.embed_data[idx] = np.float16(embed)
-
- def save_shard(self):
- """
- Save the block data that was created in this process
- """
- if not os.path.isdir(self.temp_dir_name):
- os.makedirs(self.temp_dir_name, exist_ok=True)
-
- # save the data for each shard
- with open('{}/{}.pkl'.format(self.temp_dir_name, self.rank), 'wb') \
- as writer:
- pickle.dump(self.state(), writer)
-
- def merge_shards_and_save(self):
- # Combine all the shards made using save_shard
- shard_names = os.listdir(self.temp_dir_name)
- seen_own_shard = False
-
- for fname in os.listdir(self.temp_dir_name):
- shard_rank = int(os.path.splitext(fname)[0])
- if shard_rank == self.rank:
- seen_own_shard = True
- continue
-
- with open('{}/{}'.format(self.temp_dir_name, fname), 'rb') as f:
- data = pickle.load(f)
- old_size = len(self.embed_data)
- shard_size = len(data['embed_data'])
-
- # add the shard's data and check to make sure there
- # is no overlap
- self.embed_data.update(data['embed_data'])
- assert len(self.embed_data) == old_size + shard_size
-
- assert seen_own_shard
-
- # save the consolidated shards and remove temporary directory
- with open(self.embedding_path, 'wb') as final_file:
- pickle.dump(self.state(), final_file)
- shutil.rmtree(self.temp_dir_name, ignore_errors=True)
-
- print("Finished merging {} shards for a total of {} embeds".format(
- len(shard_names), len(self.embed_data)), flush=True)
-
-
-class FaissMIPSIndex(object):
- """
- Wrapper object for a BlockData which does similarity search via FAISS under the hood
- """
- def __init__(self, embed_size, embed_data=None, use_gpu=False):
- self.embed_size = embed_size
- self.embed_data = embed_data
- self.use_gpu = use_gpu
-
- self.mips_index = None
- self._set_mips_index()
-
- def _set_mips_index(self):
- """
- Create a Faiss Flat index with inner product as the metric
- to search against
- """
- try:
- import faiss
- except ImportError:
- raise Exception("Error: Please install faiss to use FaissMIPSIndex")
-
- if not mpu.model_parallel_is_initialized() or mpu.get_data_parallel_rank() == 0:
- print("\n> Building index", flush=True)
-
- cpu_index = faiss.IndexFlatIP(self.embed_size)
-
- if self.use_gpu:
- # create resources and config for GpuIndex
- config = faiss.GpuMultipleClonerOptions()
- config.shard = True
- config.useFloat16 = True
- gpu_index = faiss.index_cpu_to_all_gpus(cpu_index, co=config)
- self.mips_index = faiss.IndexIDMap(gpu_index)
- if not mpu.model_parallel_is_initialized() or mpu.get_data_parallel_rank() == 0:
- print(">> Initialized index on GPU", flush=True)
- else:
- # CPU index supports IDs so wrap with IDMap
- self.mips_index = faiss.IndexIDMap(cpu_index)
- if not mpu.model_parallel_is_initialized() or mpu.get_data_parallel_rank() == 0:
- print(">> Initialized index on CPU", flush=True)
-
- # if we were constructed with a BlockData, then automatically load it
- # when the FAISS structure is built
- if self.embed_data is not None:
- self.add_embed_data(self.embed_data)
-
- def reset_index(self):
- """Delete existing index and create a new"""
- del self.mips_index
-
- # reset the block data so that _set_mips_index will reload it as well
- if self.embed_data is not None:
- embed_data_path = self.embed_data.embedding_path
- del self.embed_data
- self.embed_data = OpenRetreivalDataStore(embed_data_path)
-
- self._set_mips_index()
-
- def update_index(self):
- """Delete existing index and create a new"""
- del self.mips_index
-
- # reset the block data so that _set_mips_index will reload it as well
- if self.embed_data is not None:
- self.embed_data.load_from_file()
- self._set_mips_index()
-
- def add_embed_data(self, all_embed_data):
- """Add the embedding of each block to the underlying FAISS index"""
-
- # this assumes the embed_data is a dict : {int: np.array}
- block_indices, block_embeds = zip(*all_embed_data.embed_data.items())
-
- # the embeddings have to be entered in as float32 even though the math
- # internally is done with float16.
- embeds_arr = np.float32(np.array(block_embeds))
- indices_arr = np.array(block_indices)
-
- # we no longer need the embedding data since it's in the index now
- all_embed_data.clear()
-
- self.mips_index.add_with_ids(embeds_arr, indices_arr)
-
- if not mpu.model_parallel_is_initialized() or mpu.get_data_parallel_rank() == 0:
- print(">>> Finished adding block data to index", flush=True)
-
- def search_mips_index(self, query_embeds, top_k, reconstruct=True):
- """
- Get the top-k blocks by the index distance metric.
-
- :param reconstruct: if True: return a [num_queries x k x embed_dim]
- array of blocks
- if False: return [num_queries x k] array of
- distances, and another for indices
- """
- query_embeds = np.float32(detach(query_embeds))
-
- if reconstruct:
- # get the vectors themselves
- top_k_block_embeds = self.mips_index.search_and_reconstruct(\
- query_embeds, top_k)
- return top_k_block_embeds
- else:
- # get distances and indices of closest vectors
- distances, block_indices = self.mips_index.search(query_embeds, top_k)
- return distances, block_indices
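`FaissMIPSIndex` above wraps a `faiss.IndexFlatIP`, i.e. brute-force maximum inner-product search. A NumPy-only sketch of the equivalent search, with random placeholder embeddings, for readers without FAISS installed:

```python
import numpy as np

rng = np.random.default_rng(0)
block_embeds = rng.standard_normal((1000, 128)).astype(np.float32)  # the "index"
query_embeds = rng.standard_normal((4, 128)).astype(np.float32)
top_k = 5

scores = query_embeds @ block_embeds.T                  # inner products, shape (4, 1000)
block_indices = np.argsort(-scores, axis=1)[:, :top_k]  # ids of the top-k blocks per query
distances = np.take_along_axis(scores, block_indices, axis=1)

print(block_indices.shape, distances.shape)  # (4, 5) (4, 5)
```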
diff --git a/megatron/data/t5_dataset.py b/megatron/data/t5_dataset.py
deleted file mode 100644
index 075b089f8e39ba9d236f5b39d62861dc2f17608d..0000000000000000000000000000000000000000
--- a/megatron/data/t5_dataset.py
+++ /dev/null
@@ -1,258 +0,0 @@
-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-
-"""T5 Style dataset."""
-
-import collections
-
-import numpy as np
-import torch
-
-from megatron import get_tokenizer
-from megatron.data.dataset_utils import (
- create_masked_lm_predictions,
- get_samples_mapping
-)
-
-class T5Dataset(torch.utils.data.Dataset):
-
- def __init__(self, name, indexed_dataset, data_prefix,
- num_epochs, max_num_samples, masked_lm_prob,
- max_seq_length, max_seq_length_dec,
- short_seq_prob, seed):
-
- # Params to store.
- self.name = name
- self.desc = name
- self.seed = seed
- self.masked_lm_prob = masked_lm_prob
- self.max_seq_length = max_seq_length
- self.max_seq_length_dec = max_seq_length_dec
-
- # Dataset.
- self.indexed_dataset = indexed_dataset
-
- # Build the samples mapping.
- self.samples_mapping = get_samples_mapping(self.indexed_dataset,
- data_prefix,
- num_epochs,
- max_num_samples,
- self.max_seq_length - 2, # account for added tokens
- short_seq_prob,
- self.seed,
- self.name,
- False)
-
- # Vocab stuff.
- tokenizer = get_tokenizer()
- self.vocab_id_list = list(tokenizer.inv_vocab.keys())
- self.vocab_id_to_token_dict = tokenizer.inv_vocab
- self.cls_id = tokenizer.cls
- self.sep_id = tokenizer.sep
- self.mask_id = tokenizer.mask
- self.pad_id = tokenizer.pad
- self.bos_id = tokenizer.bos_token_id
- self.eos_id = tokenizer.eos_token_id
- self.sentinel_tokens = tokenizer.additional_special_tokens_ids
- assert len(self.sentinel_tokens) > 0, "Provide the argument --vocab-extra-ids 100 to the script"
-
- def __len__(self):
- return self.samples_mapping.shape[0]
-
- def __getitem__(self, idx):
-
- start_index, end_index, seq_length = self.samples_mapping[idx]
- sample = []
- for index in range(start_index, end_index):
- sample.append(self.indexed_dataset[index])
- # Note that this rng state should be numpy and not python since
- # python randint is inclusive whereas the numpy one is exclusive.
- np_rng = np.random.RandomState(seed=(self.seed + idx))
- return build_training_sample(sample, seq_length,
- self.max_seq_length, # needed for padding
- self.max_seq_length_dec,
- self.vocab_id_list,
- self.vocab_id_to_token_dict,
- self.cls_id, self.sep_id,
- self.mask_id, self.pad_id,
- self.masked_lm_prob, np_rng,
- self.bos_id, self.eos_id,
- self.sentinel_tokens)
-
-
-def build_training_sample(sample, target_seq_length,
- max_seq_length, max_seq_length_dec,
- vocab_id_list, vocab_id_to_token_dict,
- cls_id, sep_id, mask_id, pad_id,
- masked_lm_prob, np_rng, bos_id=None,
- eos_id=None, sentinel_tokens=None):
- """Build training sample.
-
- Arguments:
- sample: A list of sentences in which each sentence is a list token ids.
- target_seq_length: Desired sequence length.
- max_seq_length: Maximum length of the sequence. All values are padded to
- this length.
- vocab_id_list: List of vocabulary ids. Used to pick a random id.
- vocab_id_to_token_dict: A dictionary from vocab ids to text tokens.
- cls_id: Start of example id.
- sep_id: Separator id.
- mask_id: Mask token id.
- pad_id: Padding token id.
- masked_lm_prob: Probability to mask tokens.
-        np_rng: Random number generator. Note that this rng state should be
-              numpy and not python since python randint is inclusive for
-              the upper bound whereas the numpy one is exclusive.
- bos_id: start of decoder example id
- eos_id: end of generation id
- sentinel_tokens: unique value to be substituted for every replaced span
- """
-
- assert target_seq_length <= max_seq_length
-
- # flatten sentences into one list
- tokens = [token for sentence in sample for token in sentence]
-
- # Truncate to `target_sequence_length`.
- max_num_tokens = target_seq_length
- truncated = len(tokens) > max_num_tokens
- tokens = tokens[:max_num_tokens]
-
- # Masking.
- max_predictions_per_seq = masked_lm_prob * max_num_tokens
- (tokens, masked_positions, masked_labels, _, masked_spans) = create_masked_lm_predictions(
- tokens, vocab_id_list, vocab_id_to_token_dict, masked_lm_prob,
- cls_id, sep_id, mask_id, max_predictions_per_seq, np_rng,
- max_ngrams=10, geometric_dist=True, masking_style="t5")
-
- # Padding.
- tokens_enc, tokens_dec_in, labels, enc_mask, \
- dec_mask, enc_dec_mask, loss_mask \
- = pad_and_convert_to_numpy(tokens, masked_positions,
- masked_labels, pad_id, max_seq_length,
- max_seq_length_dec, masked_spans,
- bos_id, eos_id, sentinel_tokens)
-
- train_sample = {
- 'text_enc': tokens_enc,
- 'text_dec': tokens_dec_in,
- 'labels': labels,
- 'loss_mask': loss_mask,
- 'truncated': int(truncated),
- 'enc_mask': enc_mask,
- 'dec_mask': dec_mask,
- 'enc_dec_mask': enc_dec_mask,
- }
- return train_sample
-
-
-def pad_and_convert_to_numpy(tokens, masked_positions,
- masked_labels, pad_id,
- max_seq_length, max_seq_length_dec,
- masked_spans=None, bos_id=None,
- eos_id=None, sentinel_tokens=None):
- """Pad sequences and convert them to numpy."""
-
- sentinel_tokens = collections.deque(sentinel_tokens)
- t5_input = []
- (t5_decoder_in, t5_decoder_out) = ([bos_id], [])
- (start_index, end_index) = (0, None)
- for span in masked_spans:
- flag = sentinel_tokens.popleft()
-
- # Append the same tokens in decoder input and output
- t5_decoder_in.append(flag)
- t5_decoder_in.extend(span.label)
- t5_decoder_out.append(flag)
- t5_decoder_out.extend(span.label)
-
- end_index = span.index[0]
- t5_input.extend(tokens[start_index: end_index])
- t5_input.append(flag)
-
- # the next start index is the token after the last span token
- start_index = span.index[-1] + 1
-
- # Add token to the t5_decoder_out
- t5_decoder_out.append(eos_id)
-
- # Add the remaining tokens to the t5 input
- t5_input.extend(tokens[start_index:])
-
- # assert (len(t5_input) - len(masked_spans)) + \
- # (len(t5_decoder_in) - (len(masked_spans) + 1)) == len(tokens)
-
- # Some checks.
-
- # Encoder-side padding mask.
- num_tokens = len(t5_input)
- padding_length = max_seq_length - num_tokens
- assert padding_length >= 0
- assert len(masked_positions) == len(masked_labels)
-
- # Tokens..
- filler = [pad_id] * padding_length
- tokens_enc = np.array(t5_input + filler, dtype=np.int64)
-
- # Decoder-side padding mask.
- num_tokens_dec = len(t5_decoder_in)
- padding_length_dec = max_seq_length_dec - num_tokens_dec
- assert padding_length_dec >= 0
- filler_dec = [pad_id] * padding_length_dec
- tokens_dec_in = np.array(t5_decoder_in + filler_dec, dtype=np.int64)
-
- # Create attention masks
- enc_mask = make_attention_mask(tokens_enc, tokens_enc)
- enc_dec_mask = make_attention_mask(tokens_dec_in, tokens_enc)
- dec_mask = make_attention_mask(tokens_dec_in, tokens_dec_in)
- dec_mask = dec_mask * make_history_mask(tokens_dec_in)
-
- # Labels mask.
- labels = t5_decoder_out + ([-1] * padding_length_dec)
- labels = np.array(labels, dtype=np.int64)
-
- # Loss mask
- loss_mask = ([1] * num_tokens_dec) + ([0] * padding_length_dec)
- loss_mask = np.array(loss_mask, dtype=np.int64)
-
- return tokens_enc, tokens_dec_in, labels, enc_mask, \
- dec_mask, enc_dec_mask, loss_mask
-
-
-def make_attention_mask(source_block, target_block):
- """
- Returns a 2-dimensional (2-D) attention mask
- :param source_block: 1-D array
- :param target_block: 1-D array
- """
- mask = (target_block[None, :] >= 1) * (source_block[:, None] >= 1)
- mask = mask.astype(np.int64)
- # (source_length, target_length)
- return mask
-
-
-def make_attention_mask_3d(source_block, target_block):
- """
- Returns a 3-dimensional (3-D) attention mask
-    :param source_block: 2-D array (batch, length)
-    :param target_block: 2-D array (batch, length)
- """
- mask = (target_block[:, None, :] >= 1) * (source_block[:, :, None] >= 1)
- # (batch, source_length, target_length)
- # mask = mask.astype(np.int64)
- return mask
-
-
-def make_history_mask(block):
- length = block.shape[0]
- arange = np.arange(length)
- history_mask = (arange[None, ] <= arange[:, None])
- history_mask = history_mask.astype(np.int64)
- return history_mask
-
-
-def make_history_mask_3d(block):
- batch, length = block.shape
- arange = torch.arange(length, device=block.device)
- history_mask = (arange[None, ] <= arange[:, None])[None, ]
- history_mask = history_mask.expand(batch, length, length)
- return history_mask
diff --git a/megatron/data/vit_dataset.py b/megatron/data/vit_dataset.py
deleted file mode 100644
index 82391e9157e1a989dd2e13ea69d4b146284b6f4b..0000000000000000000000000000000000000000
--- a/megatron/data/vit_dataset.py
+++ /dev/null
@@ -1,249 +0,0 @@
-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-import os
-import random
-import numpy as np
-import torch
-import torchvision.transforms as T
-from torchvision import datasets
-from megatron import get_args
-from megatron.data.image_folder import ImageFolder
-from megatron.data.autoaugment import ImageNetPolicy
-from megatron.data.data_samplers import RandomSeedDataset
-from PIL import Image, ImageFilter, ImageOps
-
-
-class GaussianBlur(object):
- """
- Apply Gaussian Blur to the PIL image.
- """
- def __init__(self, p=0.5, radius_min=0.1, radius_max=2.):
- self.prob = p
- self.radius_min = radius_min
- self.radius_max = radius_max
-
- def __call__(self, img):
- do_it = random.random() <= self.prob
- if not do_it:
- return img
-
- return img.filter(
- ImageFilter.GaussianBlur(
- radius=random.uniform(self.radius_min, self.radius_max)
- )
- )
-
-
-class Solarization(object):
- """
- Apply Solarization to the PIL image.
- """
- def __init__(self, p):
- self.p = p
-
- def __call__(self, img):
- if random.random() < self.p:
- return ImageOps.solarize(img)
- else:
- return img
-
-
-class ClassificationTransform():
- def __init__(self, image_size, train=True):
- args = get_args()
- assert args.fp16 or args.bf16
- self.data_type = torch.half if args.fp16 else torch.bfloat16
- if train:
- self.transform = T.Compose([
- T.RandomResizedCrop(image_size),
- T.RandomHorizontalFlip(),
- T.ColorJitter(0.4, 0.4, 0.4, 0.1),
- ImageNetPolicy(),
- T.ToTensor(),
- T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
- T.ConvertImageDtype(self.data_type)
- ])
- else:
- self.transform = T.Compose([
- T.Resize(image_size),
- T.CenterCrop(image_size),
- T.ToTensor(),
- T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
- T.ConvertImageDtype(self.data_type)
- ])
-
- def __call__(self, input):
- output = self.transform(input)
- return output
-
-
-class InpaintingTransform():
- def __init__(self, image_size, train=True):
-
- args = get_args()
- self.mask_factor = args.mask_factor
- self.mask_type = args.mask_type
- self.image_size = image_size
- self.patch_size = args.patch_dim
- self.mask_size = int(self.mask_factor*(image_size[0]/self.patch_size)*(image_size[1]/self.patch_size))
- self.train = train
- assert args.fp16 or args.bf16
- self.data_type = torch.half if args.fp16 else torch.bfloat16
-
- if self.train:
- self.transform = T.Compose([
- T.RandomResizedCrop(self.image_size),
- T.RandomHorizontalFlip(),
- T.ColorJitter(0.4, 0.4, 0.4, 0.1),
- ImageNetPolicy(),
- T.ToTensor(),
- T.ConvertImageDtype(self.data_type)
- ])
- else:
- self.transform = T.Compose([
- T.Resize(self.image_size, interpolation=2),
- T.CenterCrop(self.image_size),
- T.ToTensor(),
- T.ConvertImageDtype(self.data_type)
- ])
-
- def gen_mask(self, image_size, mask_size, mask_type, patch_size):
-        # output: 2-D mask tensor with ones marking the missing patches
- action_list = [[0, 1], [0, -1], [1, 0], [-1, 0]]
- assert image_size[0] == image_size[1]
- img_size_patch = image_size[0] // patch_size
-
- # drop masked patches
- mask = torch.zeros((image_size[0], image_size[1]), dtype=torch.float)
-
- if mask_type == 'random':
- x = torch.randint(0, img_size_patch, ())
- y = torch.randint(0, img_size_patch, ())
- for i in range(mask_size):
- r = torch.randint(0, len(action_list), ())
- x = torch.clamp(x + action_list[r][0], min=0, max=img_size_patch - 1)
- y = torch.clamp(y + action_list[r][1], min=0, max=img_size_patch - 1)
- x_offset = x * patch_size
- y_offset = y * patch_size
- mask[x_offset:x_offset+patch_size, y_offset:y_offset+patch_size] = 1
- else:
- assert mask_type == 'row'
- count = 0
- for x in reversed(range(img_size_patch)):
- for y in reversed(range(img_size_patch)):
- if (count < mask_size):
- count += 1
- x_offset = x * patch_size
- y_offset = y * patch_size
- mask[x_offset:x_offset+patch_size, y_offset:y_offset+patch_size] = 1
- return mask
-
- def __call__(self, input):
- trans_input = self.transform(input)
- mask = self.gen_mask(self.image_size, self.mask_size,
- self.mask_type, self.patch_size)
- mask = mask.unsqueeze(dim=0)
- return trans_input, mask
-
-
-class DinoTransform(object):
- def __init__(self, image_size, train=True):
- args = get_args()
- self.data_type = torch.half if args.fp16 else torch.bfloat16
-
- flip_and_color_jitter = T.Compose([
- T.RandomHorizontalFlip(p=0.5),
- T.RandomApply(
- [T.ColorJitter(brightness=0.4, contrast=0.4,
- saturation=0.2, hue=0.1)],
- p=0.8
- ),
- T.RandomGrayscale(p=0.2),
- ])
-
- if args.fp16 or args.bf16:
- normalize = T.Compose([
- T.ToTensor(),
- T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
- T.ConvertImageDtype(self.data_type)
- ])
- else:
- normalize = T.Compose([
- T.ToTensor(),
- T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
- ])
-
- # first global crop
- scale_const = 0.4
- self.global_transform1 = T.Compose([
- T.RandomResizedCrop(image_size,
- scale=(scale_const, 1),
- interpolation=Image.BICUBIC),
- flip_and_color_jitter,
- GaussianBlur(1.0),
- normalize
- ])
- # second global crop
- self.global_transform2 = T.Compose([
- T.RandomResizedCrop(image_size,
- scale=(scale_const, 1),
- interpolation=Image.BICUBIC),
- flip_and_color_jitter,
- GaussianBlur(0.1),
- Solarization(0.2),
- normalize
- ])
- # transformation for the local small crops
- self.local_crops_number = args.dino_local_crops_number
- self.local_transform = T.Compose([
- T.RandomResizedCrop(args.dino_local_img_size,
- scale=(0.05, scale_const),
- interpolation=Image.BICUBIC),
- flip_and_color_jitter,
- GaussianBlur(p=0.5),
- normalize
- ])
-
- def __call__(self, image):
- crops = []
- crops.append(self.global_transform1(image))
- crops.append(self.global_transform2(image))
- for _ in range(self.local_crops_number):
- crops.append(self.local_transform(image))
- return crops
-
-
-def build_train_valid_datasets(data_path, image_size=224):
- args = get_args()
-
- if args.vision_pretraining_type == 'classify':
- train_transform = ClassificationTransform(image_size)
- val_transform = ClassificationTransform(image_size, train=False)
- elif args.vision_pretraining_type == 'inpaint':
- train_transform = InpaintingTransform(image_size, train=False)
- val_transform = InpaintingTransform(image_size, train=False)
- elif args.vision_pretraining_type == 'dino':
- train_transform = DinoTransform(image_size, train=True)
- val_transform = ClassificationTransform(image_size, train=False)
- else:
- raise Exception('{} vit pretraining type is not supported.'.format(
- args.vit_pretraining_type))
-
- # training dataset
- train_data_path = data_path[0] if len(data_path) <= 2 else data_path[2]
- train_data = ImageFolder(
- root=train_data_path,
- transform=train_transform,
- classes_fraction=args.classes_fraction,
- data_per_class_fraction=args.data_per_class_fraction
- )
- train_data = RandomSeedDataset(train_data)
-
- # validation dataset
- val_data_path = data_path[1]
- val_data = ImageFolder(
- root=val_data_path,
- transform=val_transform
- )
- val_data = RandomSeedDataset(val_data)
-
- return train_data, val_data
diff --git a/megatron/dist_signal_handler.py b/megatron/dist_signal_handler.py
deleted file mode 100644
index a60204f004a3f149da0bc059cd875b6ec390c0c4..0000000000000000000000000000000000000000
--- a/megatron/dist_signal_handler.py
+++ /dev/null
@@ -1,81 +0,0 @@
-import signal
-
-import torch
-
-
-def get_world_size():
- if torch.distributed.is_available() and torch.distributed.is_initialized():
- world_size = torch.distributed.get_world_size()
- else:
- world_size = 1
- return world_size
-
-
-def get_device(local_rank=None):
- backend = torch.distributed.get_backend()
- if backend == 'nccl':
- if local_rank is None:
- device = torch.device('cuda')
- else:
- device = torch.device(f'cuda:{local_rank}')
- elif backend == 'gloo':
- device = torch.device('cpu')
- else:
- raise RuntimeError
- return device
-
-
-def all_gather_item(item, dtype, group=None, async_op=False, local_rank=None):
- if not torch.distributed.is_available() or \
- not torch.distributed.is_initialized():
- return [item]
-
- device = get_device(local_rank)
-
- if group is not None:
- group_size = group.size()
- else:
- group_size = get_world_size()
-
- tensor = torch.tensor([item], device=device, dtype=dtype)
- output_tensors = [
- torch.zeros(1, dtype=tensor.dtype, device=tensor.device)
- for _ in range(group_size)
- ]
- torch.distributed.all_gather(output_tensors, tensor, group, async_op)
- output = [elem.item() for elem in output_tensors]
- return output
-
-
-class DistributedSignalHandler:
- def __init__(self, sig=signal.SIGTERM):
- self.sig = sig
-
- def signals_received(self):
- all_received = all_gather_item(
- self._signal_received, dtype=torch.int32
- )
- return all_received
-
- def __enter__(self):
- self._signal_received = False
- self.released = False
- self.original_handler = signal.getsignal(self.sig)
-
- def handler(signum, frame):
- self._signal_received = True
-
- signal.signal(self.sig, handler)
-
- return self
-
- def __exit__(self, type, value, tb):
- self.release()
-
- def release(self):
- if self.released:
- return False
-
- signal.signal(self.sig, self.original_handler)
- self.released = True
- return True
diff --git a/megatron/fp16_deprecated/__init__.py b/megatron/fp16_deprecated/__init__.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/megatron/fp16_deprecated/loss_scaler.py b/megatron/fp16_deprecated/loss_scaler.py
deleted file mode 100644
index e31d00ad3215b443b6cae8e97da9b03dcdcc3ad4..0000000000000000000000000000000000000000
--- a/megatron/fp16_deprecated/loss_scaler.py
+++ /dev/null
@@ -1,26 +0,0 @@
-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-
-"""For backward compatibility, we need the class definitions to deserialize."""
-
-class LossScaler:
- def __init__(self, scale=1):
- self.cur_scale = scale
-
-class DynamicLossScaler:
- def __init__(self,
- init_scale=2**32,
- scale_factor=2.,
- scale_window=1000,
- min_scale=1,
- delayed_shift=1,
- consecutive_hysteresis=False):
- self.cur_scale = init_scale
- self.cur_iter = 0
- self.last_overflow_iter = -1
- self.scale_factor = scale_factor
- self.scale_window = scale_window
- self.min_scale = min_scale
- self.delayed_shift = delayed_shift
- self.cur_hysteresis = delayed_shift
- self.consecutive_hysteresis = consecutive_hysteresis
-
diff --git a/megatron/fused_kernels/__init__.py b/megatron/fused_kernels/__init__.py
deleted file mode 100644
index 87cceac3e35f983cf9f2264ff651a1067069f9e2..0000000000000000000000000000000000000000
--- a/megatron/fused_kernels/__init__.py
+++ /dev/null
@@ -1,75 +0,0 @@
-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-
-import os
-import pathlib
-import subprocess
-
-from torch.utils import cpp_extension
-
-# Setting this param to a list has the problem of generating different
-# compilation commands (with a different order of architectures) and
-# leading to recompilation of fused kernels. Set it to an empty string
-# to avoid recompilation and assign arch flags explicitly in
-# extra_cuda_cflags below.
-os.environ["TORCH_CUDA_ARCH_LIST"] = ""
-
-
-def load(args):
-
- # Check if cuda 11 is installed for compute capability 8.0
- cc_flag = []
- _, bare_metal_major, bare_metal_minor = _get_cuda_bare_metal_version(
- cpp_extension.CUDA_HOME
- )
- if int(bare_metal_major) >= 11:
- cc_flag.append('-gencode')
- cc_flag.append('arch=compute_80,code=sm_80')
- if int(bare_metal_minor) >= 8:
- cc_flag.append('-gencode')
- cc_flag.append('arch=compute_90,code=sm_90')
-
- # Build path
- srcpath = pathlib.Path(__file__).parent.absolute()
- buildpath = srcpath / "build"
- _create_build_dir(buildpath)
-
- # Helper function to build the kernels.
- def _cpp_extention_load_helper(name, sources, extra_cuda_flags):
- return cpp_extension.load(
- name=name,
- sources=sources,
- build_directory=buildpath,
- extra_cflags=[
- "-O3",
- ],
- extra_cuda_cflags=[
- "-O3",
- "-gencode",
- "arch=compute_70,code=sm_70",
- "--use_fast_math",
- ]
- + extra_cuda_flags
- + cc_flag,
- verbose=(args.rank == 0),
- )
-
-
-def _get_cuda_bare_metal_version(cuda_dir):
- raw_output = subprocess.check_output(
- [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True
- )
- output = raw_output.split()
- release_idx = output.index("release") + 1
- release = output[release_idx].split(".")
- bare_metal_major = release[0]
- bare_metal_minor = release[1][0]
-
- return raw_output, bare_metal_major, bare_metal_minor
-
-
-def _create_build_dir(buildpath):
- try:
- os.mkdir(buildpath)
- except OSError:
- if not os.path.isdir(buildpath):
- print(f"Creation of the build directory {buildpath} failed")
diff --git a/megatron/fused_kernels/compat.h b/megatron/fused_kernels/compat.h
deleted file mode 100644
index 5495d7807762d8b4e3dbc11b28dba15f85bd8108..0000000000000000000000000000000000000000
--- a/megatron/fused_kernels/compat.h
+++ /dev/null
@@ -1,17 +0,0 @@
-/* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. */
-
-/* This code is copied from NVIDIA apex:
- * https://github.com/NVIDIA/apex
- * with minor changes. */
-
-
-
-#ifndef TORCH_CHECK
-#define TORCH_CHECK AT_CHECK
-#endif
-
-#ifdef VERSION_GE_1_3
-#define DATA_PTR data_ptr
-#else
-#define DATA_PTR data
-#endif
diff --git a/megatron/fused_kernels/tests/__init__.py b/megatron/fused_kernels/tests/__init__.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/megatron/fused_kernels/tests/test_fused_kernels.py b/megatron/fused_kernels/tests/test_fused_kernels.py
deleted file mode 100644
index 74024c5020f45ea3e80607fad38a45c4dab3453b..0000000000000000000000000000000000000000
--- a/megatron/fused_kernels/tests/test_fused_kernels.py
+++ /dev/null
@@ -1,388 +0,0 @@
-import math
-
-import torch
-from torch.nn import LayerNorm
-
-from megatron.model.enums import AttnMaskType
-from megatron.model.fused_layer_norm import MixedFusedLayerNorm
-from megatron.model.fused_softmax import FusedScaleMaskSoftmax
-from megatron.model.utils import attention_mask_func
-from megatron.fused_kernels import load
-
-def test_load_fused_kernels():
- try:
- import fused_layer_norm_cuda
- import scaled_masked_softmax_cuda
- import scaled_upper_triang_masked_softmax_cuda
- import torch
-
- print("[Success] load_fused_kernels")
- except ImportError as e:
- print("[Fail] load_fused_kernels")
- raise e
-
-def test_fused_softmax():
- bert = BertModel.from_pretrained("bert-base-cased").cuda().half()
- tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
- test_text = (
- "Hello. How are you? I am fine thank you and you? yes Good. "
- "hi hi hi hi hi hi hi hi hi hi hi hi hi" # 32
- )
-
- tokens = tokenizer(
- [test_text] * 4,
- return_tensors="pt",
- )
-
- embedding_output = bert.embeddings(
- input_ids=tokens["input_ids"].cuda(),
- position_ids=None,
- token_type_ids=tokens["token_type_ids"].cuda(),
- inputs_embeds=None,
- past_key_values_length=0,
- )
-
- # (bsz, 1, 1, seq_len)
- mask = bert.get_extended_attention_mask(
- attention_mask=tokens["attention_mask"].cuda(),
- input_shape=tokens["input_ids"].shape,
- device=bert.device,
- )
- # (bsz, 1, seq_len, seq_len)
- mask = mask.repeat(1, 1, mask.size()[-1], 1)
-
- attention = bert.encoder.layer[0].attention.self
- key_layer = attention.transpose_for_scores(attention.key(embedding_output))
- query_layer = attention.transpose_for_scores(attention.query(embedding_output))
-
- attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
- attention_scores /= math.sqrt(key_layer.size()[-1])
-
- fused_softmax = (
- FusedScaleMaskSoftmax(
- input_in_fp16=True,
- input_in_bf16=False,
- mask_func=attention_mask_func,
- scale=None,
- softmax_in_fp32=False,
- attn_mask_type=AttnMaskType.padding,
- scaled_masked_softmax_fusion=True,
- )
- .cuda()
- .half()
- )
-
- fused_softmax_output = fused_softmax(
- attention_scores,
- (mask != 0),
- )
-
- torch_softmax = (
- FusedScaleMaskSoftmax(
- input_in_fp16=True,
- input_in_bf16=False,
- mask_func=attention_mask_func,
- scale=None,
- softmax_in_fp32=False,
- attn_mask_type=AttnMaskType.padding,
- scaled_masked_softmax_fusion=False,
- )
- .cuda()
- .half()
- )
-
- torch_softmax_output = torch_softmax(
- attention_scores,
- (mask != 0),
- )
-
- test_result = (fused_softmax_output - torch_softmax_output).abs()
-
- while test_result.dim() != 1:
- test_result = test_result.mean(dim=-1)
-
- diff = test_result.mean(dim=-1)
-
- if diff <= 1e-3:
- print(
- f"\n[Success] test_fused_softmax"
- f"\n > mean_difference={diff}"
- f"\n > fused_values={fused_softmax_output[-1][-1][-1][:5].tolist()}"
- f"\n > torch_values={torch_softmax_output[-1][-1][-1][:5].tolist()}"
- )
- else:
- print(
- f"\n[Fail] test_fused_softmax"
- f"\n > mean_difference={diff}, "
- f"\n > fused_values={fused_softmax_output[-1][-1][-1][:5].tolist()}, "
- f"\n > torch_values={torch_softmax_output[-1][-1][-1][:5].tolist()}"
- )
-
-
-def test_fused_upper_triangle_mask_softmax():
- gpt = GPT2Model.from_pretrained("gpt2").cuda().half()
- tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
- test_text = (
- "Hello. How are you? I am fine thank you and you? yes Good. "
- "hi hi hi hi hi hi hi" # 24
- )
-
- tokens = tokenizer(
- [test_text] * 4,
- return_tensors="pt",
- )
-
- attention_mask = tokens["attention_mask"].cuda()
- attention_mask = attention_mask.view(attention_mask.size(0), -1)
- attention_mask = attention_mask[:, None, None, :]
- attention_mask = (1.0 - attention_mask) * -10000.0
- attention_mask = attention_mask.repeat(1, 1, attention_mask.size()[-1], 1)
- attn = gpt.h[0]
-
- hidden_states = gpt.wte(tokens["input_ids"].cuda())
- q, k, v = attn.attn.c_attn(hidden_states).split(768, dim=-1)
- q = attn.attn._split_heads(q, attn.attn.num_heads, attn.attn.head_dim)
- k = attn.attn._split_heads(k, attn.attn.num_heads, attn.attn.head_dim)
- attn_weights = torch.matmul(q, k.transpose(-1, -2))
-
- sq, sk = q.size(-2), k.size(-2)
- causal_mask = attn.attn.bias[:, :, sk - sq : sk, :sk].bool()
- total_mask = ~(causal_mask & (attention_mask == 0))
- """
- tensor([[[[False, True, True, ..., True, True, True],
- [False, False, True, ..., True, True, True],
- [False, False, False, ..., True, True, True],
- ...,
- [False, False, False, ..., False, True, True],
- [False, False, False, ..., False, False, True],
- [False, False, False, ..., False, False, False]]]
- """
-
- fused_softmax = (
- FusedScaleMaskSoftmax(
- input_in_fp16=True,
- input_in_bf16=False,
- mask_func=attention_mask_func,
- scale=None,
- softmax_in_fp32=False,
- attn_mask_type=AttnMaskType.causal,
- scaled_masked_softmax_fusion=True,
- )
- .cuda()
- .half()
- )
-
- fused_softmax_output = fused_softmax(
- attn_weights,
- total_mask,
- )
-
- torch_softmax = (
- FusedScaleMaskSoftmax(
- input_in_fp16=True,
- input_in_bf16=False,
- mask_func=attention_mask_func,
- scale=None,
- softmax_in_fp32=False,
- attn_mask_type=AttnMaskType.causal,
- scaled_masked_softmax_fusion=False,
- )
- .cuda()
- .half()
- )
-
- torch_softmax_output = torch_softmax(
- attn_weights,
- total_mask,
- )
-
- test_result = (fused_softmax_output - torch_softmax_output).abs()
-
- while test_result.dim() != 1:
- test_result = test_result.mean(dim=-1)
-
- diff = test_result.mean(dim=-1)
-
- if diff <= 1e-3:
- print(
- f"\n[Success] test_fused_upper_triangle_mask_softmax"
- f"\n > mean_difference={diff}"
- f"\n > fused_values={fused_softmax_output[-1][-1][-1][:5].tolist()}"
- f"\n > torch_values={torch_softmax_output[-1][-1][-1][:5].tolist()}"
- )
- else:
- print(
- f"\n[Fail] test_fused_upper_triangle_mask_softmax"
- f"\n > mean_difference={diff}, "
- f"\n > fused_values={fused_softmax_output[-1][-1][-1][:5].tolist()}, "
- f"\n > torch_values={torch_softmax_output[-1][-1][-1][:5].tolist()}"
- )
-
-
-def test_layer_norm():
- bert = BertModel.from_pretrained("bert-base-cased").cuda().half()
- tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
- test_text = (
- "Hello. How are you? I am fine thank you and you? yes Good. "
- "hi hi hi hi hi hi hi hi hi hi hi hi hi" # 32
- )
-
- tokens = tokenizer(
- [test_text] * 4,
- return_tensors="pt",
- )
-
- # [bsz, seq_len, d_model]
- embedding_output = (
- bert.embeddings(
- input_ids=tokens["input_ids"].cuda(),
- position_ids=None,
- token_type_ids=tokens["token_type_ids"].cuda(),
- inputs_embeds=None,
- past_key_values_length=0,
- )
- .cuda()
- .half()
- )
-
- fused_layernorm_layer = (
- MixedFusedLayerNorm(normalized_shape=embedding_output.size(-1)).cuda().half()
- )
-
- torch_layernorm_layer = (
- LayerNorm(normalized_shape=embedding_output.size(-1)).cuda().half()
- )
-
- fused_output = fused_layernorm_layer(embedding_output)
- torch_output = torch_layernorm_layer(embedding_output)
- test_result = (fused_output - torch_output).abs()
-
- while test_result.dim() != 1:
- test_result = test_result.mean(dim=-1)
-
- diff = test_result.mean(dim=-1)
-
- if diff <= 1e-3:
- print(
- f"\n[Success] test_layer_norm"
- f"\n > mean_difference={diff}"
- f"\n > fused_values={fused_output[-1][-1][:5].tolist()}"
- f"\n > torch_values={torch_output[-1][-1][:5].tolist()}"
- )
- else:
- print(
- f"\n[Fail] test_layer_norm"
- f"\n > mean_difference={diff}, "
- f"\n > fused_values={fused_output[-1][-1][:5].tolist()}, "
- f"\n > torch_values={torch_output[-1][-1][:5].tolist()}"
- )
-
-
-def attention_mask_func(attention_scores, attention_mask):
- attention_scores.masked_fill_(attention_mask, -10000.0)
- return attention_scores
-
-
-def forward_torch_softmax(input, mask, scale):
- input = input * scale
- mask_output = attention_mask_func(input, mask) if mask is not None else input
- probs = torch.nn.Softmax(dim=-1)(mask_output)
- return probs
-
-
-def test_masked_softmax_forward():
- import scaled_masked_softmax_cuda
-
- batch = 2
- attn = 16
- scale_t = torch.tensor([1.0])
- for qlen in [128, 256, 1024, 2048, 4096]:
- for klen in [128, 256, 1024, 2048]:
- inputs = torch.normal(0, 2, (batch, attn, qlen, klen), dtype=torch.float16, device='cuda:0')
- masks = torch.randint(0, 2, (batch, 1, qlen, klen), dtype=torch.bool, device='cuda:0')
- softmax_results = scaled_masked_softmax_cuda.forward(inputs, masks, scale_t[0].item())
- softmax_results_torch = forward_torch_softmax(inputs, masks, scale_t[0].item())
- error = (softmax_results_torch - softmax_results).abs().max()
- assert error < 1e-3
-
-def test_masked_softmax_backward():
- import scaled_masked_softmax_cuda
-
- batch = 2
- attn = 16
- scale_t = torch.tensor([1.0])
- for qlen in [128, 256, 1024, 2048, 4096]:
- for klen in [128, 256, 1024, 2048]:
- inputs = torch.normal(0, 2, (batch, attn, qlen, klen), dtype=torch.float16, device='cuda:0')
- backward = torch.rand_like(inputs, dtype=torch.float16, device='cuda:0')
- masks = torch.randint(0, 2, (batch, 1, qlen, klen), dtype=torch.bool, device='cuda:0')
- softmax_results = scaled_masked_softmax_cuda.forward(inputs, masks, scale_t[0].item())
- back_grad = scaled_masked_softmax_cuda.backward(backward, softmax_results, scale_t[0].item())
-
- inputs.requires_grad = True
- softmax_results_torch = forward_torch_softmax(inputs, masks, scale_t[0].item())
- softmax_results_torch.backward(backward)
- error = (back_grad - inputs.grad).abs().max()
- assert error < 1e-3
-
-
-def test_allmasked_softmax_forward():
- import scaled_masked_softmax_cuda
-
- batch = 2
- attn = 16
- scale_t = torch.tensor([1.0])
- for qlen in [128, 256, 1024, 2048, 4096]:
- for klen in [128, 256, 1024, 2048]:
- inputs = torch.normal(0, 2, (batch, attn, qlen, klen), dtype=torch.float16, device='cuda:0')
- masks = torch.ones((batch, 1, qlen, klen), dtype=torch.bool, device='cuda:0')
- softmax_results = scaled_masked_softmax_cuda.forward(inputs, masks, scale_t[0].item())
- softmax_results_torch = torch.zeros_like(inputs)
- error = (softmax_results_torch - softmax_results).abs().max()
- assert error == 0.0
-
-
-def test_allmasked_softmax_backward():
- import scaled_masked_softmax_cuda
-
- batch = 2
- attn = 16
- scale_t = torch.tensor([1.0])
- for qlen in [128, 256, 1024, 2048, 4096]:
- for klen in [128, 256, 1024, 2048]:
- inputs = torch.normal(0, 2, (batch, attn, qlen, klen), dtype=torch.float16, device='cuda:0')
- backward = torch.rand_like(inputs, dtype=torch.float16, device='cuda:0')
- masks = torch.ones((batch, 1, qlen, klen), dtype=torch.bool, device='cuda:0')
- softmax_results = scaled_masked_softmax_cuda.forward(inputs, masks, scale_t[0].item())
- back_grad = scaled_masked_softmax_cuda.backward(backward, softmax_results, scale_t[0].item())
- inputs.requires_grad = True
- softmax_results_torch = forward_torch_softmax(inputs, masks, scale_t[0].item())
- softmax_results_torch.backward(backward)
- error = (back_grad - inputs.grad).abs().max()
- assert error < 1e-3
-
-
-if __name__ == "__main__":
- try:
- from transformers import BertTokenizer, GPT2Tokenizer
- from transformers.models.bert.modeling_bert import BertModel
- from transformers.models.gpt2.modeling_gpt2 import GPT2Model
- import transformers
-
- transformers.logging.set_verbosity(
- transformers.logging.FATAL,
- )
-
- except:
- print("\n[Fail] Please install `transformers` package to test fused kernels\n")
- exit(-1)
-
- load()
- test_masked_softmax_forward()
- test_masked_softmax_backward()
- test_allmasked_softmax_forward()
- test_allmasked_softmax_backward()
- test_load_fused_kernels()
- test_fused_softmax()
- test_fused_upper_triangle_mask_softmax()
- test_layer_norm()
diff --git a/megatron/fused_kernels/type_shim.h b/megatron/fused_kernels/type_shim.h
deleted file mode 100644
index d60a6f8c6fb50e241f9ddcc852adec71e963e1b2..0000000000000000000000000000000000000000
--- a/megatron/fused_kernels/type_shim.h
+++ /dev/null
@@ -1,103 +0,0 @@
-/* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. */
-
-
-#include <ATen/ATen.h>
-#include "compat.h"
-
-
-#define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) \
- switch(TYPE) \
- { \
- case at::ScalarType::Half: \
- { \
- using scalar_t = at::Half; \
- __VA_ARGS__; \
- break; \
- } \
- case at::ScalarType::BFloat16: \
- { \
- using scalar_t = at::BFloat16; \
- __VA_ARGS__; \
- break; \
- } \
- default: \
- AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
- }
-
-
-#define DISPATCH_HALF_BFLOAT_AND_FLOAT(TYPE, NAME, ...) \
- switch(TYPE) \
- { \
- case at::ScalarType::Half: \
- { \
- using scalar_t = at::Half; \
- __VA_ARGS__; \
- break; \
- } \
- case at::ScalarType::BFloat16: \
- { \
- using scalar_t = at::BFloat16; \
- __VA_ARGS__; \
- break; \
- } \
- case at::ScalarType::Float: \
- { \
- using scalar_t = float; \
- __VA_ARGS__; \
- break; \
- } \
- default: \
- AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
- }
-
-
-
-#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
- switch(TYPEIN) \
- { \
- case at::ScalarType::Float: \
- { \
- using scalar_t_in = float; \
- switch(TYPEOUT) \
- { \
- case at::ScalarType::Float: \
- { \
- using scalar_t_out = float; \
- __VA_ARGS__; \
- break; \
- } \
- case at::ScalarType::Half: \
- { \
- using scalar_t_out = at::Half; \
- __VA_ARGS__; \
- break; \
- } \
- case at::ScalarType::BFloat16: \
- { \
- using scalar_t_out = at::BFloat16; \
- __VA_ARGS__; \
- break; \
- } \
- default: \
- AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
- } \
- break; \
- } \
- case at::ScalarType::Half: \
- { \
- using scalar_t_in = at::Half; \
- using scalar_t_out = at::Half; \
- __VA_ARGS__; \
- break; \
- } \
- case at::ScalarType::BFloat16: \
- { \
- using scalar_t_in = at::BFloat16; \
- using scalar_t_out = at::BFloat16; \
- __VA_ARGS__; \
- break; \
- } \
- default: \
- AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \
- }
-
diff --git a/megatron/global_vars.py b/megatron/global_vars.py
deleted file mode 100644
index 04b1448ee80b2707a882d482edce3731356e99eb..0000000000000000000000000000000000000000
--- a/megatron/global_vars.py
+++ /dev/null
@@ -1,223 +0,0 @@
-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-
-"""Megatron global variables."""
-
-import os
-import sys
-import torch
-
-from megatron import dist_signal_handler
-from megatron.tokenizer import build_tokenizer
-from .microbatches import build_num_microbatches_calculator
-from .timers import Timers
-
-_GLOBAL_ARGS = None
-_GLOBAL_RETRO_ARGS = None
-_GLOBAL_NUM_MICROBATCHES_CALCULATOR = None
-_GLOBAL_TOKENIZER = None
-_GLOBAL_TENSORBOARD_WRITER = None
-_GLOBAL_WANDB_WRITER = None
-_GLOBAL_ADLR_AUTORESUME = None
-_GLOBAL_TIMERS = None
-_GLOBAL_SIGNAL_HANDLER = None
-
-def get_args():
- """Return arguments."""
- _ensure_var_is_initialized(_GLOBAL_ARGS, 'args')
- return _GLOBAL_ARGS
-
-
-def get_retro_args():
- """Return retro arguments."""
- return _GLOBAL_RETRO_ARGS
-
-
-def get_num_microbatches():
- return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get()
-
-
-def get_current_global_batch_size():
- return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get_current_global_batch_size()
-
-
-def update_num_microbatches(consumed_samples, consistency_check=True):
- _GLOBAL_NUM_MICROBATCHES_CALCULATOR.update(consumed_samples,
- consistency_check)
-
-
-def get_tokenizer():
- """Return tokenizer."""
- _ensure_var_is_initialized(_GLOBAL_TOKENIZER, 'tokenizer')
- return _GLOBAL_TOKENIZER
-
-
-def get_tensorboard_writer():
- """Return tensorboard writer. It can be None so no need
- to check if it is initialized."""
- return _GLOBAL_TENSORBOARD_WRITER
-
-
-def get_wandb_writer():
- """Return tensorboard writer. It can be None so no need
- to check if it is initialized."""
- return _GLOBAL_WANDB_WRITER
-
-
-def get_adlr_autoresume():
- """ADLR autoresume object. It can be None so no need
- to check if it is initialized."""
- return _GLOBAL_ADLR_AUTORESUME
-
-
-def get_timers():
- """Return timers."""
- _ensure_var_is_initialized(_GLOBAL_TIMERS, 'timers')
- return _GLOBAL_TIMERS
-
-
-def get_signal_handler():
- _ensure_var_is_initialized(_GLOBAL_SIGNAL_HANDLER, 'signal handler')
- return _GLOBAL_SIGNAL_HANDLER
-
-
-def _set_signal_handler():
- global _GLOBAL_SIGNAL_HANDLER
- _ensure_var_is_not_initialized(_GLOBAL_SIGNAL_HANDLER, 'signal handler')
- _GLOBAL_SIGNAL_HANDLER = dist_signal_handler.DistributedSignalHandler().__enter__()
-
-
-
-def set_global_variables(args, build_tokenizer=True):
- """Set args, tokenizer, tensorboard-writer, adlr-autoresume, and timers."""
-
- assert args is not None
-
- _ensure_var_is_not_initialized(_GLOBAL_ARGS, 'args')
- set_args(args)
-
- _build_num_microbatches_calculator(args)
- if build_tokenizer:
- _ = _build_tokenizer(args)
- _set_tensorboard_writer(args)
- _set_wandb_writer(args)
- _set_adlr_autoresume(args)
- _set_timers(args)
-
- if args.exit_signal_handler:
- _set_signal_handler()
-
-
-def set_args(args):
- global _GLOBAL_ARGS
- _GLOBAL_ARGS = args
-
-
-def set_retro_args(retro_args):
- global _GLOBAL_RETRO_ARGS
- _GLOBAL_RETRO_ARGS = retro_args
-
-
-def _build_num_microbatches_calculator(args):
-
- global _GLOBAL_NUM_MICROBATCHES_CALCULATOR
- _ensure_var_is_not_initialized(_GLOBAL_NUM_MICROBATCHES_CALCULATOR,
- 'num microbatches calculator')
-
- _GLOBAL_NUM_MICROBATCHES_CALCULATOR = build_num_microbatches_calculator(
- args)
-
-
-def _build_tokenizer(args):
- """Initialize tokenizer."""
- global _GLOBAL_TOKENIZER
- _ensure_var_is_not_initialized(_GLOBAL_TOKENIZER, 'tokenizer')
- _GLOBAL_TOKENIZER = build_tokenizer(args)
- return _GLOBAL_TOKENIZER
-
-
-def rebuild_tokenizer(args):
- global _GLOBAL_TOKENIZER
- _GLOBAL_TOKENIZER = None
- return _build_tokenizer(args)
-
-
-def _set_tensorboard_writer(args):
- """Set tensorboard writer."""
- global _GLOBAL_TENSORBOARD_WRITER
- _ensure_var_is_not_initialized(_GLOBAL_TENSORBOARD_WRITER,
- 'tensorboard writer')
-
- if hasattr(args, 'tensorboard_dir') and \
- args.tensorboard_dir and args.rank == (args.world_size - 1):
- try:
- from torch.utils.tensorboard import SummaryWriter
- print('> setting tensorboard ...')
- _GLOBAL_TENSORBOARD_WRITER = SummaryWriter(
- log_dir=args.tensorboard_dir,
- max_queue=args.tensorboard_queue_size)
- except ModuleNotFoundError:
- print('WARNING: TensorBoard writing requested but is not '
- 'available (are you using PyTorch 1.1.0 or later?), '
- 'no TensorBoard logs will be written.', flush=True)
-
-
-def _set_wandb_writer(args):
- global _GLOBAL_WANDB_WRITER
- _ensure_var_is_not_initialized(_GLOBAL_WANDB_WRITER,
- 'wandb writer')
- if getattr(args, 'wandb_project', '') and args.rank == (args.world_size - 1):
- if args.wandb_exp_name == '':
- raise ValueError("Please specify the wandb experiment name!")
-
- import wandb
- if args.wandb_save_dir:
- save_dir = args.wandb_save_dir
- else:
- # Defaults to the save dir.
- save_dir = os.path.join(args.save, 'wandb')
- wandb_kwargs = {
- 'dir': save_dir,
- 'name': args.wandb_exp_name,
- 'project': args.wandb_project,
- 'config': vars(args)}
- os.makedirs(wandb_kwargs['dir'], exist_ok=True)
- wandb.init(**wandb_kwargs)
- _GLOBAL_WANDB_WRITER = wandb
-
-
-def _set_adlr_autoresume(args):
- """Initialize ADLR autoresume."""
- global _GLOBAL_ADLR_AUTORESUME
- _ensure_var_is_not_initialized(_GLOBAL_ADLR_AUTORESUME, 'adlr autoresume')
-
- if args.adlr_autoresume:
- if args.rank == 0:
- print('enabling autoresume ...', flush=True)
- sys.path.append(os.environ.get('SUBMIT_SCRIPTS', ''))
- try:
- from userlib.auto_resume import AutoResume
- except BaseException:
- print('ADLR autoresume is not available, exiting ...')
- sys.exit()
-
- _GLOBAL_ADLR_AUTORESUME = AutoResume
-
-
-def _set_timers(args):
- """Initialize timers."""
- global _GLOBAL_TIMERS
- _ensure_var_is_not_initialized(_GLOBAL_TIMERS, 'timers')
- _GLOBAL_TIMERS = Timers(args.timing_log_level, args.timing_log_option)
-
-
-def _ensure_var_is_initialized(var, name):
- """Make sure the input variable is not None."""
- assert var is not None, '{} is not initialized.'.format(name)
-
-
-def _ensure_var_is_not_initialized(var, name):
- """Make sure the input variable is not None."""
- assert var is None, '{} is already initialized.'.format(name)
-
-
-
diff --git a/megatron/indexer.py b/megatron/indexer.py
deleted file mode 100644
index 45f530a7d4d7d72bfeb1f2d9a3098b3812a332fe..0000000000000000000000000000000000000000
--- a/megatron/indexer.py
+++ /dev/null
@@ -1,129 +0,0 @@
-import sys
-import time
-import torch
-import torch.distributed as dist
-
-from megatron import get_args, print_rank_0
-from megatron.core import mpu
-from megatron.checkpointing import load_biencoder_checkpoint
-from megatron.data.orqa_wiki_dataset import get_open_retrieval_wiki_dataset
-from megatron.data.orqa_wiki_dataset import get_open_retrieval_batch
-from megatron.data.biencoder_dataset_utils import get_one_epoch_dataloader
-from megatron.data.realm_index import detach, OpenRetreivalDataStore
-from megatron.model.biencoder_model import get_model_provider
-from megatron.training import get_model
-
-
-class IndexBuilder(object):
- """
- Object for taking one pass over a dataset and creating a BlockData of its
- embeddings
- """
- def __init__(self):
- args = get_args()
- self.model = None
- self.dataloader = None
- self.evidence_embedder_obj = None
- self.biencoder_shared_query_context_model = \
- args.biencoder_shared_query_context_model
-
- # need to know whether we're using a REALM checkpoint (args.load)
- # or ICT checkpoint
- assert not (args.load and args.ict_load)
-
- self.log_interval = args.indexer_log_interval
- self.batch_size = args.indexer_batch_size
-
- self.load_attributes()
- self.is_main_builder = mpu.get_data_parallel_rank() == 0
- self.num_total_builders = mpu.get_data_parallel_world_size()
- self.iteration = self.total_processed = 0
-
- def load_attributes(self):
- """
- Load the necessary attributes: model, dataloader and empty BlockData
- """
- only_context_model = True
- if self.biencoder_shared_query_context_model:
- only_context_model = False
-
- model = get_model(get_model_provider(only_context_model=\
- only_context_model, biencoder_shared_query_context_model=\
- self.biencoder_shared_query_context_model))
-
- self.model = load_biencoder_checkpoint(model,
- only_context_model=only_context_model)
-
- assert len(self.model) == 1
- self.model[0].eval()
-
- self.dataset = get_open_retrieval_wiki_dataset()
- self.dataloader = iter(get_one_epoch_dataloader(self.dataset, \
- self.batch_size))
-
- self.evidence_embedder_obj = OpenRetreivalDataStore( \
- load_from_path=False)
-
- def track_and_report_progress(self, batch_size):
- """
- Utility function for tracking progress
- """
- self.iteration += 1
- self.total_processed += batch_size * self.num_total_builders
- if self.is_main_builder and self.iteration % self.log_interval == 0:
- print('Batch {:10d} | Total {:10d}'.format(self.iteration,
- self.total_processed), flush=True)
-
- def build_and_save_index(self):
- """
- Goes through one epoch of the dataloader and adds all data to this
- instance's BlockData.
-
- The copy of BlockData is saved as a shard, which when run in a
- distributed setting will be consolidated by the rank 0 process
- and saved as a final pickled BlockData.
- """
- assert len(self.model) == 1
- unwrapped_model = self.model[0]
-
- while not hasattr(unwrapped_model, 'embed_text'):
- unwrapped_model = unwrapped_model.module
-
- while True:
- try:
- # batch also has query_tokens and query_pad_data
- row_id, context_tokens, context_mask, context_types, \
- context_pad_mask = get_open_retrieval_batch( \
- self.dataloader)
- except (StopIteration, IndexError):
- break
-
- # TODO: can we add with torch.no_grad() to reduce memory usage
- # detach, separate fields and add to BlockData
- assert context_mask.dtype == torch.bool
- context_logits = unwrapped_model.embed_text(
- unwrapped_model.context_model, context_tokens, context_mask,
- context_types)
-
- context_logits = detach(context_logits)
- row_id = detach(row_id)
-
- self.evidence_embedder_obj.add_block_data(row_id, context_logits)
- self.track_and_report_progress(batch_size=len(row_id))
-
- # This process signals to finalize its shard and then synchronize with
- # the other processes
- self.evidence_embedder_obj.save_shard()
- torch.distributed.barrier()
- del self.model
-
- # rank 0 process builds the final copy
- if self.is_main_builder:
- self.evidence_embedder_obj.merge_shards_and_save()
- # make sure that every single piece of data was embedded
- assert len(self.evidence_embedder_obj.embed_data) == \
- len(self.dataset)
- self.evidence_embedder_obj.clear()
-
- # complete building the final copy
- torch.distributed.barrier()
diff --git a/megatron/initialize.py b/megatron/initialize.py
deleted file mode 100644
index fb7866ab03510701afb351022f566e8f1e7ba00b..0000000000000000000000000000000000000000
--- a/megatron/initialize.py
+++ /dev/null
@@ -1,387 +0,0 @@
-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-
-"""Megatron initialization."""
-
-import random
-import os
-import time
-
-import numpy as np
-import torch
-from datetime import timedelta
-
-from megatron import fused_kernels
-from megatron import get_adlr_autoresume
-from megatron import get_args
-from megatron import get_tensorboard_writer
-from megatron.core import mpu, tensor_parallel
-from megatron.arguments import parse_args, validate_args
-from megatron.checkpointing import load_args_from_checkpoint
-from megatron.global_vars import set_global_variables
-from megatron.model.transformer import bias_dropout_add_fused_train
-from megatron.model.fused_bias_gelu import bias_gelu
-
-def initialize_megatron(
- extra_args_provider=None,
- args_defaults={},
- ignore_unknown_args=False,
- allow_no_cuda=False,
- skip_mpu_initialization=False,
-):
- """Set global variables, initialize distributed, and
- set autoresume and random seeds.
- `allow_no_cuda` should not be set unless using megatron for cpu only
- data processing. In general this arg should not be set unless you know
- what you are doing.
- Returns a function to finalize distributed env initialization
- (optionally, only when args.lazy_mpu_init == True)
- """
- if not allow_no_cuda:
- # Make sure cuda is available.
- assert torch.cuda.is_available(), "Megatron requires CUDA."
-
- # Parse arguments
- args = parse_args(extra_args_provider, ignore_unknown_args)
-
- if args.use_checkpoint_args or args_defaults.get("use_checkpoint_args", False):
-        assert args.load is not None, "--use-checkpoint-args requires --load argument"
- load_args_from_checkpoint(args)
-
- validate_args(args, args_defaults)
-
- # set global args, build tokenizer, and set adlr-autoresume,
- # tensorboard-writer, and timers.
- set_global_variables(args)
-
- # torch.distributed initialization
- def finish_mpu_init():
- args = get_args()
- # Pytorch distributed.
- _initialize_distributed()
-
- # Random seeds for reproducibility.
- if args.rank == 0:
- print("> setting random seeds to {} ...".format(args.seed))
- _set_random_seed(args.seed, args.data_parallel_random_init)
-
- if skip_mpu_initialization:
- return None
-
- args = get_args()
- if args.lazy_mpu_init:
- # TODO is this still a necessary option?
- args.use_cpu_initialization = True
- # delayed initialization of DDP-related stuff
- # We only set basic DDP globals
- mpu.set_tensor_model_parallel_world_size(args.tensor_model_parallel_size)
- # and return function for external DDP manager
- # to call when it has DDP initialized
- mpu.set_tensor_model_parallel_rank(args.rank)
- return finish_mpu_init
- else:
- # Megatron's MPU is the master. Complete initialization right away.
- finish_mpu_init()
-
- # Autoresume.
- _init_autoresume()
-
- # Compile dependencies.
- _compile_dependencies()
-
- if args.tp_comm_overlap:
- _initialize_tp_communicators()
-
- # No continuation function
- return None
-
-
-def _compile_dependencies():
-
- args = get_args()
-
- # =========================
- # Compile dataset C++ code.
- # =========================
- # TODO: move this to ninja
- if torch.distributed.get_rank() == 0:
- start_time = time.time()
- print("> compiling dataset index builder ...")
- from megatron.core.datasets.utils import compile_helpers
-
- compile_helpers()
- print(
- ">>> done with dataset index builder. Compilation time: {:.3f} "
- "seconds".format(time.time() - start_time),
- flush=True,
- )
-
- # ==================
- # Load fused kernels
- # ==================
-
- # Custom kernel constraints check.
- seq_len = args.seq_length
- attn_batch_size = (
- args.num_attention_heads / args.tensor_model_parallel_size
- ) * args.micro_batch_size
- # Constraints on sequence length and attn_batch_size to enable warp based
- # optimization and upper triangular optimization (for causal mask)
- custom_kernel_constraint = (
- seq_len > 16
- and seq_len <= 16384
- and seq_len % 4 == 0
- and attn_batch_size % 4 == 0
- )
- # Print a warning.
- if not (
- (args.fp16 or args.bf16)
- and custom_kernel_constraint
- and args.masked_softmax_fusion
- ):
- if args.rank == 0:
- print(
- "WARNING: constraints for invoking optimized"
- " fused softmax kernel are not met. We default"
- " back to unfused kernel invocations.",
- flush=True,
- )
-
- # Always build on rank zero first.
- if torch.distributed.get_rank() == 0:
- start_time = time.time()
- print("> compiling and loading fused kernels ...", flush=True)
- fused_kernels.load(args)
- torch.distributed.barrier()
- else:
- torch.distributed.barrier()
- fused_kernels.load(args)
- # Simple barrier to make sure all ranks have passed the
- # compilation phase successfully before moving on to the
- # rest of the program. We think this might ensure that
- # the lock is released.
- torch.distributed.barrier()
- if torch.distributed.get_rank() == 0:
- print(
- ">>> done with compiling and loading fused kernels. "
- "Compilation time: {:.3f} seconds".format(time.time() - start_time),
- flush=True,
- )
-
-def _initialize_tp_communicators():
- """ initializing the communicators with user buffers for high-performance tensor-model-parallel
- communication overlap """
-
- try:
- import yaml
-
- import transformer_engine
- from transformer_engine.pytorch import module as te_module
-
- except ImportError:
- raise RuntimeError("Tensor Parallel Communication/GEMM Overlap optimization needs 'yaml' and "
- "'transformer_engine' packages")
-
- args = get_args()
-
- if args.tp_comm_overlap_cfg is not None:
- with open(args.tp_comm_overlap_cfg,"r") as stream:
- ub_cfgs = yaml.safe_load(stream)
- else:
- ub_cfgs = {}
-
- input_shape = [args.seq_length * args.micro_batch_size , args.hidden_size]
-
-    # We create an MPI process group, which is needed to bootstrap the pipelined
-    # tensor-model-parallel communication overlap.
- torch.distributed.new_group(backend='mpi')
-
- te_module.base.initialize_ub(shape = input_shape, tp_size = args.tensor_model_parallel_size,
- use_fp8 = (args.fp8 is not None) , ub_cfgs = ub_cfgs,)
-
-def _initialize_distributed():
- """Initialize torch.distributed and core model parallel."""
- args = get_args()
-
- device_count = torch.cuda.device_count()
- if torch.distributed.is_initialized():
-
- if args.rank == 0:
- print(
- "torch distributed is already initialized, "
- "skipping initialization ...",
- flush=True,
- )
- args.rank = torch.distributed.get_rank()
- args.world_size = torch.distributed.get_world_size()
-
- else:
-
- if args.rank == 0:
- print("> initializing torch distributed ...", flush=True)
- # Manually set the device ids.
- if device_count > 0:
- device = args.rank % device_count
- if args.local_rank is not None:
- assert (
- args.local_rank == device
- ), "expected local-rank to be the same as rank % device-count."
- else:
- args.local_rank = device
- torch.cuda.set_device(device)
- # Call the init process
- torch.distributed.init_process_group(
- backend=args.distributed_backend,
- world_size=args.world_size,
- rank=args.rank,
- timeout=timedelta(minutes=args.distributed_timeout_minutes),
- )
-
- # Set the tensor model-parallel, pipeline model-parallel, and
- # data-parallel communicators.
- if device_count > 0:
- if mpu.model_parallel_is_initialized():
- print("model parallel is already initialized")
- else:
- mpu.initialize_model_parallel(
- args.tensor_model_parallel_size,
- args.pipeline_model_parallel_size,
- args.virtual_pipeline_model_parallel_size,
- args.pipeline_model_parallel_split_rank,
- context_parallel_size=args.context_parallel_size,
- expert_model_parallel_size=args.expert_model_parallel_size,
- nccl_communicator_config_path=args.nccl_communicator_config_path,
- )
- if args.rank == 0:
- print(
- f"> initialized tensor model parallel with size "
- f"{mpu.get_tensor_model_parallel_world_size()}"
- )
- print(
- f"> initialized pipeline model parallel with size "
- f"{mpu.get_pipeline_model_parallel_world_size()}"
- )
-
-
-def _init_autoresume():
- """Set autoresume start time."""
- autoresume = get_adlr_autoresume()
- if autoresume:
- torch.distributed.barrier()
- autoresume.init()
- torch.distributed.barrier()
-
-
-def _set_random_seed(seed_, data_parallel_random_init=False):
- """Set random seed for reproducability."""
- if seed_ is not None and seed_ > 0:
- # Ensure that different pipeline MP stages get different seeds.
- seed = seed_ + (100 * mpu.get_pipeline_model_parallel_rank())
- # Ensure different data parallel ranks get different seeds
- if data_parallel_random_init:
- seed = seed + (10 * mpu.get_data_parallel_rank())
- random.seed(seed)
- np.random.seed(seed)
- torch.manual_seed(seed)
- if torch.cuda.device_count() > 0:
- tensor_parallel.model_parallel_cuda_manual_seed(seed)
- else:
- raise ValueError("Seed ({}) should be a positive integer.".format(seed))
-
-
-def write_args_to_tensorboard():
- """Write arguments to tensorboard."""
- args = get_args()
- writer = get_tensorboard_writer()
- if writer:
- for arg in vars(args):
- writer.add_text(arg, str(getattr(args, arg)), global_step=args.iteration)
-
-
-def set_jit_fusion_options():
- """Set PyTorch JIT layer fusion options."""
- # flags required to enable jit fusion kernels
- TORCH_MAJOR = int(torch.__version__.split(".")[0])
- TORCH_MINOR = int(torch.__version__.split(".")[1])
- if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10):
- # nvfuser
- torch._C._jit_set_profiling_executor(True)
- torch._C._jit_set_profiling_mode(True)
- torch._C._jit_override_can_fuse_on_cpu(False)
- torch._C._jit_override_can_fuse_on_gpu(False)
- torch._C._jit_set_texpr_fuser_enabled(False)
- torch._C._jit_set_nvfuser_enabled(True)
- torch._C._debug_set_autodiff_subgraph_inlining(False)
- else:
- # legacy pytorch fuser
- torch._C._jit_set_profiling_mode(False)
- torch._C._jit_set_profiling_executor(False)
- torch._C._jit_override_can_fuse_on_cpu(True)
- torch._C._jit_override_can_fuse_on_gpu(True)
-
- _warmup_jit_function()
-
-
-def _warmup_jit_function():
- """Compilie JIT functions before the main training steps"""
- args = get_args()
- if args.bf16:
- dtype = torch.bfloat16
- elif args.fp16:
- dtype = torch.float16
- else:
- dtype = torch.float32
-
- # Warmup fused bias+gelu
- bias = torch.rand(
- args.ffn_hidden_size // args.tensor_model_parallel_size,
- dtype=dtype,
- device="cuda",
- )
- input = torch.rand(
- (
- args.seq_length,
- args.micro_batch_size,
- args.ffn_hidden_size // args.tensor_model_parallel_size,
- ),
- dtype=dtype,
- device="cuda",
- )
- # Warmup JIT fusions with the input grad_enable state of both forward
- # prop and recomputation
- for bias_grad, input_grad in zip([True, True], [False, True]):
- bias.requires_grad, input.requires_grad = bias_grad, input_grad
- for _ in range(5):
- output = bias_gelu(bias, input)
- del bias, input, output
-
- # Warmup fused bias+dropout+add
- if args.sequence_parallel:
- seq_length = args.seq_length // mpu.get_tensor_model_parallel_world_size()
- else:
- seq_length = args.seq_length
- input = torch.rand(
- (seq_length, args.micro_batch_size, args.hidden_size),
- dtype=dtype,
- device="cuda",
- )
- residual = torch.rand(
- (seq_length, args.micro_batch_size, args.hidden_size),
- dtype=dtype,
- device="cuda",
- )
- bias = torch.rand((args.hidden_size), dtype=dtype, device="cuda").expand_as(
- residual
- )
- dropout_rate = 0.1
- # Warmup JIT fusions with the input grad_enable state of both forward
- # prop and recomputation
- for input_grad, bias_grad, residual_grad in zip(
- [False, True], [True, True], [True, True]
- ):
- input.requires_grad = input_grad
- bias.requires_grad = bias_grad
- residual.requires_grad = residual_grad
- for _ in range(5):
- output = bias_dropout_add_fused_train(input, bias, residual, dropout_rate)
- del bias, input, residual, output
- torch.cuda.empty_cache()
diff --git a/megatron/log_handler.py b/megatron/log_handler.py
deleted file mode 100644
index 06f5d1842d1d8bb89ca78633854ce4d910761f1a..0000000000000000000000000000000000000000
--- a/megatron/log_handler.py
+++ /dev/null
@@ -1,24 +0,0 @@
-# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
-
-import sys
-from logging import LogRecord, StreamHandler
-
-BLACKLISTED_MODULES = ["torch.distributed"]
-
-
-class CustomHandler(StreamHandler):
- """
- Custom handler to filter out logging from code outside of
- Megatron Core, and dump to stdout.
- """
-
- def __init__(self):
- super().__init__(stream=sys.stdout)
-
- def filter(self, record: LogRecord) -> bool:
- # Prevent log entries that come from the blacklisted modules
- # through (e.g., PyTorch Distributed).
- for blacklisted_module in BLACKLISTED_MODULES:
- if record.name.startswith(blacklisted_module):
- return False
- return True
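
`CustomHandler` above is a stdout `StreamHandler` whose `filter` drops records coming from blacklisted loggers such as `torch.distributed`. A small sketch of how such a handler plugs into the standard `logging` setup (the class and logger names here are illustrative re-creations, not the deleted module itself):

```python
import logging
import sys

class StdoutFilterHandler(logging.StreamHandler):
    """Send records to stdout, dropping noisy third-party loggers."""

    BLACKLIST = ("torch.distributed",)

    def __init__(self):
        super().__init__(stream=sys.stdout)

    def filter(self, record: logging.LogRecord) -> bool:
        # keep a record only if its logger name is not blacklisted
        return not record.name.startswith(self.BLACKLIST)

root = logging.getLogger()
root.setLevel(logging.INFO)
root.addHandler(StdoutFilterHandler())

logging.getLogger("megatron").info("kept")              # printed
logging.getLogger("torch.distributed").info("dropped")  # filtered out
```
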
diff --git a/megatron/memory.py b/megatron/memory.py
deleted file mode 100644
index a5fef75baa749d557da227bbccf706501ffdd10f..0000000000000000000000000000000000000000
--- a/megatron/memory.py
+++ /dev/null
@@ -1,132 +0,0 @@
-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-
-
-import torch
-
-
-# A dictionary of all the memory buffers allocated.
-_MEM_BUFFS = dict()
-
-
-def allocate_mem_buff(name, numel, dtype, track_usage):
- """Allocate a memory buffer."""
- assert name not in _MEM_BUFFS, \
- 'memory buffer {} already allocated.'.format(name)
- _MEM_BUFFS[name] = MemoryBuffer(name, numel, dtype, track_usage)
- return _MEM_BUFFS[name]
-
-
-def get_mem_buff(name):
- """Get the memory buffer."""
- return _MEM_BUFFS[name]
-
-
-class MemoryBuffer:
- """Contiguous memory buffer.
- Allocate a contiguous block of memory of type `dtype` and size `numel`. It is
- used to reduce memory fragmentation.
-
- Usage: After the allocation, the `_start` index is set to the first
- index of the memory. A memory chunk starting from the `_start` index
- can be `allocated` for an input tensor, with the elements of the
- tensor being copied. The buffer can be reused by resetting the
- `_start` index.
-
- """
- def __init__(self, name, numel, dtype, track_usage):
- if torch.distributed.get_rank() == 0:
- element_size = torch.tensor([], dtype=dtype).element_size()
- print('> building the {} memory buffer with {} num elements '
- 'and {} dtype ({:.1f} MB)...'.format(
- name, numel, dtype, numel*element_size/1024/1024),
- flush=True)
- self.name = name
- self.numel = numel
- self.dtype = dtype
- self.data = torch.empty(self.numel,
- dtype=self.dtype,
- device=torch.cuda.current_device(),
- requires_grad=False)
-
- # Index tracking the start of the free memory.
- self._start = 0
-
- # Values used for tracking usage.
- self.track_usage = track_usage
- if self.track_usage:
- self.in_use_value = 0.0
- self.total_value = 0.0
-
-
- def reset(self):
- """Reset the buffer start index to the beginning of the buffer."""
- self._start = 0
-
-
- def is_in_use(self):
- """Whether the current buffer hold on to any memory."""
- return self._start > 0
-
-
- def numel_in_use(self):
- """Return number of elements in use."""
- return self._start
-
-
- def add(self, tensor):
- """Allocate a chunk of memory from the buffer to tensor and copy
- the values."""
- assert tensor.dtype == self.dtype, \
- 'Input tensor type {} different from buffer type {}'.format(
- tensor.dtype, self.dtype)
- # Number of elements of the input tensor.
- tensor_numel = torch.numel(tensor)
- new_start = self._start + tensor_numel
- assert new_start <= self.numel, \
- 'Not enough memory left in the buffer ({} > {})'.format(
- tensor_numel, self.numel - self._start)
- # New tensor is a view into the memory.
- new_tensor = self.data[self._start:new_start]
- self._start = new_start
- new_tensor = new_tensor.view(tensor.shape)
- new_tensor.copy_(tensor)
- # Return a pointer to the new tensor.
- return new_tensor
-
-
- def get_data(self):
- """Return the data currently in use."""
- if self.track_usage:
- self.in_use_value += float(self._start)
- self.total_value += float(self.numel)
- return self.data[:self._start]
-
-
- def print_average_usage(self):
- """Print memory usage average over time. We would like this value
- to be as high as possible."""
- assert self.track_usage, 'You need to enable track usage.'
- if torch.distributed.get_rank() == 0:
- print(' > usage of {} memory buffer: {:.2f} %'.format(
- self.name, self.in_use_value * 100.0 / self.total_value),
- flush=True)
-
-
-
-class RingMemBuffer:
- """A ring of memory buffers."""
-
- def __init__(self, name, num_buffers, numel, dtype, track_usage):
- self.num_buffers = num_buffers
- self.buffers = [
- allocate_mem_buff(name+' {}'.format(i), numel, dtype, track_usage)
- for i in range(num_buffers)]
- self._index = -1
-
-
- def get_next_buffer(self):
- self._index += 1
- self._index = self._index % self.num_buffers
- buff = self.buffers[self._index]
- assert not buff.is_in_use(), 'buffer is already in use.'
- return buff
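
The buffer above hands out views into one large pre-allocated tensor and is reclaimed wholesale with `reset()`, which avoids allocator fragmentation. A CPU-only sketch of the same pattern (the class below is a simplified stand-in, not the deleted `MemoryBuffer`, which allocates on the current CUDA device and reports usage on rank 0):

```python
import torch

class SimpleMemoryBuffer:
    """One big allocation; tensors are views carved off from `_start`."""

    def __init__(self, numel, dtype=torch.float32):
        self.data = torch.empty(numel, dtype=dtype)
        self.numel = numel
        self._start = 0

    def add(self, tensor):
        n = tensor.numel()
        assert self._start + n <= self.numel, "not enough memory left in the buffer"
        view = self.data[self._start:self._start + n].view(tensor.shape)
        view.copy_(tensor)
        self._start += n
        return view

    def reset(self):
        # reclaim the whole buffer for the next iteration
        self._start = 0

buf = SimpleMemoryBuffer(1 << 20)
a = buf.add(torch.randn(256, 1024))  # `a` is a view into buf.data
buf.reset()
```
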
diff --git a/megatron/microbatches.py b/megatron/microbatches.py
deleted file mode 100644
index 6449d7479c9c983b4813889ee8f1beec9e027cc3..0000000000000000000000000000000000000000
--- a/megatron/microbatches.py
+++ /dev/null
@@ -1,144 +0,0 @@
-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-
-"""Megatron number of micro-batches calculators."""
-
-from abc import ABC
-from abc import abstractmethod
-
-
-def build_num_microbatches_calculator(args):
-
- # Constant num micro-batches.
- if args.rampup_batch_size is None:
- num_microbatches_calculator = ConstantNumMicroBatches(
- args.global_batch_size, args.micro_batch_size,
- args.data_parallel_size)
- if args.rank == 0:
- print('setting number of micro-batches to constant {}'.format(
- num_microbatches_calculator.get()), flush=True)
-
- else:
- assert len(args.rampup_batch_size) == 3, 'expected the following ' \
- 'format: --rampup-batch-size <start batch size> ' \
- '<batch size increment> <ramp-up samples>'
- start_batch_size = int(args.rampup_batch_size[0])
- batch_size_increment = int(args.rampup_batch_size[1])
- ramup_samples = int(args.rampup_batch_size[2])
- if args.rank == 0:
- print('will use batch size rampup starting from global batch '
- 'size {} to global batch size {} with batch size increments '
- '{} over {} samples.'.format(start_batch_size,
- args.global_batch_size,
- batch_size_increment,
- ramup_samples), flush=True)
- num_microbatches_calculator = RampupBatchsizeNumMicroBatches(
- start_batch_size, batch_size_increment, ramup_samples,
- args.global_batch_size, args.micro_batch_size,
- args.data_parallel_size)
-
- return num_microbatches_calculator
-
-
-class NumMicroBatchesCalculator(ABC):
-
- def __init__(self):
- self.num_micro_batches = None
- self.current_global_batch_size = None
-
- def get(self):
- return self.num_micro_batches
-
- def get_current_global_batch_size(self):
- return self.current_global_batch_size
-
- @abstractmethod
- def update(self, consumed_samples, consistency_check):
- pass
-
-
-class ConstantNumMicroBatches(NumMicroBatchesCalculator):
-
- def __init__(self, global_batch_size, micro_batch_size, data_parallel_size):
- micro_batch_times_data_parallel = micro_batch_size * \
- data_parallel_size
- assert global_batch_size % micro_batch_times_data_parallel == 0, \
- 'global batch size ({}) is not divisible by micro batch size ({})' \
- ' times data parallel size ({})'.format(global_batch_size,
- micro_batch_size,
- data_parallel_size)
- self.num_micro_batches = global_batch_size // \
- micro_batch_times_data_parallel
- assert self.num_micro_batches >= 1
- self.current_global_batch_size = global_batch_size
-
- def update(self, consumed_samples, consistency_check):
- pass
-
-
-class RampupBatchsizeNumMicroBatches(NumMicroBatchesCalculator):
-
- def __init__(self, start_batch_size, batch_size_increment, ramup_samples,
- global_batch_size, micro_batch_size, data_parallel_size):
- """Batch size ramp up.
- Over
- steps = (global-batch-size - start-batch-size) / batch_size_increment
- increment batch size from start-batch-size to global-batch-size using
- rampup-samples / steps
- samples.
- Arguments:
- start_batch_size: global batch size to start with
- batch_size_increment: global batch size increments
- ramup_samples: number of samples to use ramp up global
- batch size from `start_batch_size` to `global_batch_size`
- global_batch_size: global batch size post rampup
- micro_batch_size: micro batch size
- data_parallel_size: data parallel size.
- """
-
- self.micro_batch_size = micro_batch_size
- self.data_parallel_size = data_parallel_size
- self.micro_batch_times_data_parallel_size = self.micro_batch_size * \
- self.data_parallel_size
- assert self.micro_batch_times_data_parallel_size > 0
-
- assert start_batch_size > 0
- self.start_batch_size = start_batch_size
-
- assert global_batch_size > 0
- self.global_batch_size = global_batch_size
- diff_batch_size = self.global_batch_size - self.start_batch_size
- assert diff_batch_size >= 0
- assert batch_size_increment > 0
- self.batch_size_increment = batch_size_increment
- assert diff_batch_size % batch_size_increment == 0, 'expected ' \
- 'global batch size interval ({}) to be divisible by global batch ' \
- 'size increment ({})'.format(diff_batch_size, batch_size_increment)
-
- num_increments = diff_batch_size // self.batch_size_increment
- self.ramup_samples = ramup_samples
- assert self.ramup_samples >= 0
- self.rampup_samples_per_increment = self.ramup_samples / num_increments
-
- # Initialize number of microbatches.
- self.update(0, False)
-
-
- def update(self, consumed_samples, consistency_check):
-
- if consumed_samples > self.ramup_samples:
- self.current_global_batch_size = self.global_batch_size
- else:
- steps = int(consumed_samples / self.rampup_samples_per_increment)
- self.current_global_batch_size = self.start_batch_size + \
- steps * self.batch_size_increment
- assert self.current_global_batch_size <= self.global_batch_size
-
- if consistency_check:
- assert self.current_global_batch_size % \
- self.micro_batch_times_data_parallel_size == 0, 'current global ' \
- 'batch size ({}) is not divisible by micro-batch-size ({}) times ' \
- 'data parallel size ({})'.format(self.current_global_batch_size,
- self.micro_batch_size,
- self.data_parallel_size)
- self.num_micro_batches = self.current_global_batch_size // \
- self.micro_batch_times_data_parallel_size
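
A worked example of the ramp-up arithmetic implemented by `RampupBatchsizeNumMicroBatches.update()` above; the numbers are illustrative. With a start batch size of 32, a target global batch size of 256 and an increment of 32, there are (256 - 32) / 32 = 7 increments, so over 700,000 ramp-up samples the global batch size grows by 32 every 100,000 consumed samples, and the number of micro-batches is the current global batch size divided by `micro_batch_size * data_parallel_size`:

```python
def num_microbatches(consumed, start=32, target=256, increment=32,
                     rampup_samples=700_000, micro_batch=4, dp=8):
    # mirrors the update() logic above, with illustrative defaults
    increments = (target - start) // increment            # 7
    samples_per_increment = rampup_samples / increments   # 100_000
    if consumed > rampup_samples:
        gbs = target
    else:
        gbs = start + int(consumed / samples_per_increment) * increment
    return gbs // (micro_batch * dp)

print(num_microbatches(0))        # 1  (global batch size 32)
print(num_microbatches(350_000))  # 4  (global batch size 128)
print(num_microbatches(900_000))  # 8  (global batch size 256)
```
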
diff --git a/megatron/model/__init__.py b/megatron/model/__init__.py
deleted file mode 100644
index cb010e5fb6c318ae849ad647d8f6d4ee4e309931..0000000000000000000000000000000000000000
--- a/megatron/model/__init__.py
+++ /dev/null
@@ -1,10 +0,0 @@
-# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
-
-from .fused_layer_norm import MixedFusedLayerNorm as LayerNorm
-from .rms_norm import RMSNorm
-
-from .bert_model import BertModel
-from .gpt_model import GPTModel
-from .t5_model import T5Model
-from .language_model import get_language_model
-from .module import Float16Module
diff --git a/megatron/model/bert_model.py b/megatron/model/bert_model.py
deleted file mode 100644
index cd4bb35db725faebb2304305f207b1afd777d412..0000000000000000000000000000000000000000
--- a/megatron/model/bert_model.py
+++ /dev/null
@@ -1,257 +0,0 @@
-# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
-
-"""BERT model."""
-
-import torch
-
-from megatron import get_args
-from megatron.core import tensor_parallel
-from megatron.model.enums import AttnMaskType
-from megatron.model.language_model import parallel_lm_logits
-from megatron.model.language_model import get_language_model
-from megatron.model.utils import get_norm
-from megatron.model.utils import openai_gelu, erf_gelu
-from megatron.model.utils import get_linear_layer
-from megatron.model.utils import init_method_normal
-from megatron.model.utils import scaled_init_method_normal
-from .module import MegatronModule
-
-
-def bert_extended_attention_mask(attention_mask):
- # We create a 3D attention mask from a 2D tensor mask.
- # [b, 1, s]
- attention_mask_b1s = attention_mask.unsqueeze(1)
- # [b, s, 1]
- attention_mask_bs1 = attention_mask.unsqueeze(2)
- # [b, s, s]
- attention_mask_bss = attention_mask_b1s * attention_mask_bs1
- # [b, 1, s, s]
- extended_attention_mask = attention_mask_bss.unsqueeze(1)
-
- # Convert attention mask to binary:
- extended_attention_mask = (extended_attention_mask < 0.5)
-
- return extended_attention_mask
-
-def bert_position_ids(token_ids):
- # Create position ids
- seq_length = token_ids.size(1)
- position_ids = torch.arange(seq_length, dtype=torch.long,
- device=token_ids.device)
- position_ids = position_ids.unsqueeze(0).expand_as(token_ids)
-
- return position_ids
-
-
-class BertLMHead(MegatronModule):
- """Masked LM head for Bert
-
- Arguments:
- config: TransformerConfig object
- mpu_vocab_size: model parallel size of vocabulary.
- parallel_output: whether output logits being distributed or not.
- """
-
- def __init__(self, mpu_vocab_size, config, parallel_output):
- super().__init__(config=config)
-
- args = get_args()
- self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size))
- tensor_parallel.set_tensor_model_parallel_attributes(self.bias, True, 0, 1)
- self.parallel_output = parallel_output
-
- self.dense = get_linear_layer(config.hidden_size, config.hidden_size, config.init_method)
- setattr(self.dense.weight, 'sequence_parallel', config.sequence_parallel)
- setattr(self.dense.bias, 'sequence_parallel', config.sequence_parallel)
-
- self.norm = get_norm(config)
- self.gelu = torch.nn.functional.gelu
- if args.openai_gelu:
- self.gelu = openai_gelu
- elif args.onnx_safe:
- self.gelu = erf_gelu
-
- def forward(self, hidden_states, word_embeddings_weight):
- hidden_states = self.dense(hidden_states)
- hidden_states = self.gelu(hidden_states)
- hidden_states = self.norm(hidden_states)
- output = parallel_lm_logits(hidden_states,
- word_embeddings_weight,
- self.parallel_output,
- bias=self.bias)
- return output
-
- def load_state_dict(self, state_dict, strict=True):
- """Customize load."""
-
- # Handle renaming layernorm -> norm in component names
- state_dict_ = {}
- for key in state_dict.keys():
- newkey = key.replace("layernorm", "norm")
- state_dict_[newkey] = state_dict[key]
-
- super().load_state_dict(state_dict_, strict)
-
-
-def post_language_model_processing(lm_output, pooled_output,
- lm_head, binary_head,
- lm_labels,
- logit_weights,
- fp16_lm_cross_entropy):
- # Output.
- lm_logits = lm_head(
- lm_output, logit_weights)
-
- binary_logits = None
- if binary_head is not None:
- binary_logits = binary_head(pooled_output)
-
- if lm_labels is None:
- # [s b h] => [b s h]
- return lm_logits.transpose(0,1).contiguous(), binary_logits
- else:
- # [b s] => [s b]
- lm_labels = lm_labels.transpose(0,1).contiguous()
- # lm_logits : [s, b, h] and lm_labels: [s, b]
- if fp16_lm_cross_entropy:
- assert lm_logits.dtype == torch.half
- lm_loss = tensor_parallel.vocab_parallel_cross_entropy(lm_logits, lm_labels)
- else:
- lm_loss = tensor_parallel.vocab_parallel_cross_entropy(lm_logits.float(),
- lm_labels)
- # [s, b] => [b s]
- lm_loss = lm_loss.transpose(0,1).contiguous()
- return lm_loss, binary_logits
-
-
-class BertModel(MegatronModule):
- """Bert Language model."""
-
- def __init__(self,
- config,
- num_tokentypes=2,
- add_binary_head=True,
- parallel_output=True,
- pre_process=True,
- post_process=True):
- super().__init__(config=config)
- args = get_args()
-
- # TODO this option is not yet implemented in BERT
- assert args.untie_embeddings_and_output_weights is False
-
- self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
- self.add_binary_head = add_binary_head
- self.parallel_output = parallel_output
- self.pre_process = pre_process
- self.post_process = post_process
-
- self.return_embeddings = args.output_bert_embeddings
- if self.return_embeddings:
- assert self.post_process and self.add_binary_head
-
- self.language_model, self._language_model_key = get_language_model(
- config=config,
- num_tokentypes=num_tokentypes,
- add_pooler=self.add_binary_head,
- encoder_attn_mask_type=AttnMaskType.padding,
- pre_process=self.pre_process,
- post_process=self.post_process)
-
- self.initialize_word_embeddings()
- if self.post_process:
- self.lm_head = BertLMHead(self.shared_embedding_or_output_weight().size(0), config, parallel_output)
- self._lm_head_key = 'lm_head'
- self.binary_head = None
- if self.add_binary_head:
- self.binary_head = get_linear_layer(config.hidden_size, 2,
- config.init_method)
- self._binary_head_key = 'binary_head'
-
- def set_input_tensor(self, input_tensor):
- """See megatron.model.transformer.set_input_tensor()"""
- self.language_model.set_input_tensor(input_tensor)
-
- def forward(self, bert_model_input, attention_mask,
- tokentype_ids=None, lm_labels=None):
-
- extended_attention_mask = bert_extended_attention_mask(attention_mask)
- input_ids = bert_model_input
- position_ids = bert_position_ids(input_ids)
-
- lm_output = self.language_model(
- input_ids,
- position_ids,
- extended_attention_mask,
- tokentype_ids=tokentype_ids
- )
-
- if self.post_process and self.add_binary_head:
- lm_output, pooled_output = lm_output
-
- # Return pooled output (e.g., when computing Bert embeddings).
- if self.return_embeddings:
-
- # Sum attention mask.
- embeddings = torch.transpose(lm_output, 0, 1)
- masks = torch.sum(attention_mask, dim=1)
-
- # Collect masked embeddings.
- output = torch.zeros(
- size=(embeddings.shape[0], embeddings.shape[2]),
- dtype=torch.float32,
- device=torch.cuda.current_device())
- for i, (embedding, mask) in enumerate(zip(embeddings, masks)):
- output[i, :] = torch.mean(embedding[1: mask - 1], dim=0)
-
- return output
-
- else:
- pooled_output = None
-
- if self.post_process:
- return post_language_model_processing(lm_output, pooled_output,
- self.lm_head, self.binary_head,
- lm_labels,
- self.shared_embedding_or_output_weight(),
- self.fp16_lm_cross_entropy)
- else:
- return lm_output
-
-
- def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
- """For easy load when model is combined with other heads,
- add an extra key."""
-
- state_dict_ = {}
- state_dict_[self._language_model_key] \
- = self.language_model.state_dict_for_save_checkpoint(prefix=prefix,
- keep_vars=keep_vars)
- if self.post_process:
- state_dict_[self._lm_head_key] \
- = self.lm_head.state_dict_for_save_checkpoint(prefix=prefix,
- keep_vars=keep_vars)
- if self.post_process and self.add_binary_head:
- state_dict_[self._binary_head_key] \
- = self.binary_head.state_dict(prefix=prefix, keep_vars=keep_vars)
- # Save word_embeddings.
- if self.post_process and not self.pre_process:
- state_dict_[self._word_embeddings_for_head_key] \
- = self.word_embeddings.state_dict(prefix=prefix, keep_vars=keep_vars)
- return state_dict_
-
- def load_state_dict(self, state_dict, strict=True):
- """Customized load."""
-
- self.language_model.load_state_dict(
- state_dict[self._language_model_key], strict=strict)
- if self.post_process:
- self.lm_head.load_state_dict(
- state_dict[self._lm_head_key], strict=strict)
- if self.post_process and self.add_binary_head:
- self.binary_head.load_state_dict(
- state_dict[self._binary_head_key], strict=strict)
- # Load word_embeddings.
- if self.post_process and not self.pre_process:
- self.word_embeddings.load_state_dict(
- state_dict[self._word_embeddings_for_head_key], strict=strict)
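
For reference, the two helpers at the top of the deleted `bert_model.py` turn a padding mask of shape [b, s] into a boolean [b, 1, s, s] mask (True where attention must be blocked) and build position ids by broadcasting 0..s-1 over the batch. A small illustration with made-up tensors:

```python
import torch

# 1 = real token, 0 = padding; batch of 2 sequences of length 4
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])

# [b, s] -> [b, s, s] pairwise mask, then [b, 1, s, s]; True means "mask out"
mask_bss = attention_mask.unsqueeze(1) * attention_mask.unsqueeze(2)
extended_attention_mask = (mask_bss.unsqueeze(1) < 0.5)
print(extended_attention_mask.shape)  # torch.Size([2, 1, 4, 4])

# position ids are simply 0..s-1 broadcast over the batch
position_ids = torch.arange(4).unsqueeze(0).expand_as(attention_mask)
print(position_ids)  # tensor([[0, 1, 2, 3], [0, 1, 2, 3]])
```
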
diff --git a/megatron/model/biencoder_model.py b/megatron/model/biencoder_model.py
deleted file mode 100644
index c910879dc8dc48b96273683996f608453dc0729c..0000000000000000000000000000000000000000
--- a/megatron/model/biencoder_model.py
+++ /dev/null
@@ -1,328 +0,0 @@
-import os
-import torch
-import sys
-
-from megatron import get_args, print_rank_0, get_tokenizer
-from megatron.core import mpu
-from megatron.checkpointing import fix_query_key_value_ordering
-from megatron.checkpointing import get_checkpoint_tracker_filename
-from megatron.checkpointing import get_checkpoint_name
-from megatron.model.bert_model import bert_position_ids
-from megatron.model.enums import AttnMaskType
-from megatron.model.language_model import get_language_model
-from megatron.model.utils import get_linear_layer
-from megatron.model.utils import init_method_normal
-from megatron.model.utils import scaled_init_method_normal
-from .module import MegatronModule
-
-def get_model_provider(only_query_model=False, only_context_model=False,
- biencoder_shared_query_context_model=False):
-
- def model_provider(pre_process=True, post_process=True):
- """Build the model."""
-
- print_rank_0('building Biencoder model ...')
- model = biencoder_model_provider(only_query_model=only_query_model,
- only_context_model = only_context_model,
- biencoder_shared_query_context_model = \
- biencoder_shared_query_context_model,
- pre_process=pre_process, post_process=post_process)
-
- return model
-
- return model_provider
-
-
-def biencoder_model_provider(only_query_model=False,
- only_context_model=False,
- biencoder_shared_query_context_model=False,
- pre_process=True,
- post_process=True):
- """Build the model."""
-
- assert mpu.get_tensor_model_parallel_world_size() == 1 and \
- mpu.get_pipeline_model_parallel_world_size() == 1, \
- "Model parallel size > 1 not supported for ICT"
-
- print_rank_0('building BiEncoderModel...')
-
- # simpler to just keep using 2 tokentypes since
- # the LM we initialize with has 2 tokentypes
- model = BiEncoderModel(
- num_tokentypes=2,
- parallel_output=False,
- only_query_model=only_query_model,
- only_context_model=only_context_model,
- biencoder_shared_query_context_model=\
- biencoder_shared_query_context_model,
- pre_process=pre_process,
- post_process=post_process)
-
- return model
-
-
-class BiEncoderModel(MegatronModule):
- """Bert-based module for Biencoder model."""
-
- def __init__(self,
- num_tokentypes=1,
- parallel_output=True,
- only_query_model=False,
- only_context_model=False,
- biencoder_shared_query_context_model=False,
- pre_process=True,
- post_process=True):
- super(BiEncoderModel, self).__init__()
- args = get_args()
-
- bert_kwargs = dict(
- num_tokentypes=num_tokentypes,
- parallel_output=parallel_output,
- pre_process=pre_process,
- post_process=post_process)
-
- self.biencoder_shared_query_context_model = \
- biencoder_shared_query_context_model
- assert not (only_context_model and only_query_model)
- self.use_context_model = not only_query_model
- self.use_query_model = not only_context_model
- self.biencoder_projection_dim = args.biencoder_projection_dim
-
- if self.biencoder_shared_query_context_model:
- self.model = PretrainedBertModel(**bert_kwargs)
- self._model_key = 'shared_model'
- self.query_model, self.context_model = self.model, self.model
- else:
- if self.use_query_model:
- # this model embeds (pseudo-)queries - Embed_input in the paper
- self.query_model = PretrainedBertModel(**bert_kwargs)
- self._query_key = 'query_model'
-
- if self.use_context_model:
- # this model embeds evidence blocks - Embed_doc in the paper
- self.context_model = PretrainedBertModel(**bert_kwargs)
- self._context_key = 'context_model'
-
- def set_input_tensor(self, input_tensor):
- """See megatron.model.transformer.set_input_tensor()"""
- # this is just a placeholder; it will be needed once model
- # parallelism is used
- # self.language_model.set_input_tensor(input_tensor)
- return
-
- def forward(self, query_tokens, query_attention_mask, query_types,
- context_tokens, context_attention_mask, context_types):
- """Run a forward pass for each of the models and
- return the respective embeddings."""
-
- if self.use_query_model:
- query_logits = self.embed_text(self.query_model,
- query_tokens,
- query_attention_mask,
- query_types)
- else:
- raise ValueError("Cannot embed query without the query model.")
- if self.use_context_model:
- context_logits = self.embed_text(self.context_model,
- context_tokens,
- context_attention_mask,
- context_types)
- else:
- raise ValueError("Cannot embed block without the block model.")
- return query_logits, context_logits
-
- @staticmethod
- def embed_text(model, tokens, attention_mask, token_types):
- """Embed a batch of tokens using the model"""
- logits = model(tokens,
- attention_mask,
- token_types)
- return logits
-
- def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
- """Save dict with state dicts of each of the models."""
- state_dict_ = {}
- if self.biencoder_shared_query_context_model:
- state_dict_[self._model_key] = \
- self.model.state_dict_for_save_checkpoint(
- prefix=prefix, keep_vars=keep_vars)
- else:
- if self.use_query_model:
- state_dict_[self._query_key] = \
- self.query_model.state_dict_for_save_checkpoint(
- prefix=prefix, keep_vars=keep_vars)
-
- if self.use_context_model:
- state_dict_[self._context_key] = \
- self.context_model.state_dict_for_save_checkpoint(
- prefix=prefix, keep_vars=keep_vars)
-
- return state_dict_
-
- def load_state_dict(self, state_dict, strict=True):
- """Load the state dicts of each of the models"""
- if self.biencoder_shared_query_context_model:
- print_rank_0("Loading shared query-context model")
- self.model.load_state_dict(state_dict[self._model_key], \
- strict=strict)
- else:
- if self.use_query_model:
- print_rank_0("Loading query model")
- self.query_model.load_state_dict( \
- state_dict[self._query_key], strict=strict)
-
- if self.use_context_model:
- print_rank_0("Loading context model")
- self.context_model.load_state_dict( \
- state_dict[self._context_key], strict=strict)
-
- def init_state_dict_from_bert(self):
- """Initialize the state from a pretrained BERT model
- on iteration zero of ICT pretraining"""
- args = get_args()
-
- if args.bert_load is None:
- print_rank_0("bert-load argument is None")
- return
-
- tracker_filename = get_checkpoint_tracker_filename(args.bert_load)
- if not os.path.isfile(tracker_filename):
- raise FileNotFoundError("Could not find BERT checkpoint")
- with open(tracker_filename, 'r') as f:
- iteration = int(f.read().strip())
- assert iteration > 0
-
- checkpoint_name = get_checkpoint_name(args.bert_load, iteration, False)
- if mpu.get_data_parallel_rank() == 0:
- print('global rank {} is loading BERT checkpoint {}'.format(
- torch.distributed.get_rank(), checkpoint_name))
-
- # Load the checkpoint.
- try:
- state_dict = torch.load(checkpoint_name, map_location='cpu')
- except ModuleNotFoundError:
- from megatron.fp16_deprecated import loss_scaler
- # For backward compatibility.
- print_rank_0(' > deserializing using the old code structure ...')
- sys.modules['fp16.loss_scaler'] = sys.modules[
- 'megatron.fp16_deprecated.loss_scaler']
- sys.modules['megatron.fp16.loss_scaler'] = sys.modules[
- 'megatron.fp16_deprecated.loss_scaler']
- state_dict = torch.load(checkpoint_name, map_location='cpu')
- sys.modules.pop('fp16.loss_scaler', None)
- sys.modules.pop('megatron.fp16.loss_scaler', None)
- except BaseException:
- print_rank_0('could not load the BERT checkpoint')
- sys.exit()
-
- checkpoint_version = state_dict.get('checkpoint_version', 0)
-
- # load the LM state dict into each model
- model_dict = state_dict['model']['language_model']
-
- if self.biencoder_shared_query_context_model:
- self.model.language_model.load_state_dict(model_dict)
- fix_query_key_value_ordering(self.model, checkpoint_version)
- else:
- if self.use_query_model:
- self.query_model.language_model.load_state_dict(model_dict)
- # give each model the same ict_head to begin with as well
- if self.biencoder_projection_dim > 0:
- query_proj_state_dict = \
- self.state_dict_for_save_checkpoint()\
- [self._query_key]['projection_enc']
- fix_query_key_value_ordering(self.query_model, checkpoint_version)
-
- if self.use_context_model:
- self.context_model.language_model.load_state_dict(model_dict)
- if self.query_model is not None and \
- self.biencoder_projection_dim > 0:
- self.context_model.projection_enc.load_state_dict\
- (query_proj_state_dict)
- fix_query_key_value_ordering(self.context_model, checkpoint_version)
-
-
-class PretrainedBertModel(MegatronModule):
- """BERT-based encoder for queries or contexts used for
- learned information retrieval."""
-
- def __init__(self, num_tokentypes=2,
- parallel_output=True, pre_process=True, post_process=True):
- super(PretrainedBertModel, self).__init__()
-
- args = get_args()
- tokenizer = get_tokenizer()
- self.pad_id = tokenizer.pad
- self.biencoder_projection_dim = args.biencoder_projection_dim
- self.parallel_output = parallel_output
- self.pre_process = pre_process
- self.post_process = post_process
- init_method = init_method_normal(args.init_method_std)
- scaled_init_method = scaled_init_method_normal(
- args.init_method_std, args.num_layers)
-
- self.language_model, self._language_model_key = get_language_model(
- num_tokentypes=num_tokentypes,
- add_pooler=False,
- encoder_attn_mask_type=AttnMaskType.padding,
- init_method=init_method,
- scaled_init_method=scaled_init_method,
- pre_process=self.pre_process,
- post_process=self.post_process)
-
- if args.biencoder_projection_dim > 0:
- self.projection_enc = get_linear_layer(args.hidden_size,
- args.biencoder_projection_dim,
- init_method)
- self._projection_enc_key = 'projection_enc'
-
- def forward(self, input_ids, attention_mask, tokentype_ids=None):
- extended_attention_mask = attention_mask.unsqueeze(1)
- #extended_attention_mask = bert_extended_attention_mask(attention_mask)
- position_ids = bert_position_ids(input_ids)
-
- lm_output = self.language_model(input_ids,
- position_ids,
- extended_attention_mask,
- tokentype_ids=tokentype_ids)
- # This mask will be used in average-pooling and max-pooling
- pool_mask = (input_ids == self.pad_id).unsqueeze(2)
-
- # Taking the representation of the [CLS] token of BERT
- pooled_output = lm_output[0, :, :]
-
- # Convert to the language model output's dtype
- pooled_output = pooled_output.to(lm_output.dtype)
-
- # Output.
- if self.biencoder_projection_dim:
- pooled_output = self.projection_enc(pooled_output)
-
- return pooled_output
-
- def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
- """For easy load when model is combined with other heads,
- add an extra key."""
-
- state_dict_ = {}
- state_dict_[self._language_model_key] \
- = self.language_model.state_dict_for_save_checkpoint(
- prefix=prefix, keep_vars=keep_vars)
-
- if self.biencoder_projection_dim > 0:
- state_dict_[self._projection_enc_key] = \
- self.projection_enc.state_dict(prefix=prefix,
- keep_vars=keep_vars)
-
- return state_dict_
-
- def load_state_dict(self, state_dict, strict=True):
- """Customized load."""
- print_rank_0("loading pretrained weights")
- self.language_model.load_state_dict(
- state_dict[self._language_model_key], strict=strict)
-
- if self.biencoder_projection_dim > 0:
- print_rank_0("loading projection head weights")
- self.projection_enc.load_state_dict(
- state_dict[self._projection_enc_key], strict=strict)
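
The retrieval embedding in `PretrainedBertModel.forward` above is just the [CLS] position of the [s, b, h] encoder output, optionally passed through a projection layer. A minimal sketch with random tensors (sizes are illustrative):

```python
import torch

s, b, h, proj_dim = 32, 4, 16, 8
lm_output = torch.randn(s, b, h)            # encoder output, [s, b, h]
projection_enc = torch.nn.Linear(h, proj_dim)

pooled_output = lm_output[0, :, :]          # [b, h], the [CLS] token
embedding = projection_enc(pooled_output)   # [b, proj_dim] retrieval vector
print(embedding.shape)                      # torch.Size([4, 8])
```
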
diff --git a/megatron/model/classification.py b/megatron/model/classification.py
deleted file mode 100644
index bac50c54cdf9981057a4844a3a8a07d365590b2a..0000000000000000000000000000000000000000
--- a/megatron/model/classification.py
+++ /dev/null
@@ -1,101 +0,0 @@
-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-
-"""Classification model."""
-
-import torch
-
-from megatron import get_args, print_rank_last
-from megatron.model.enums import AttnMaskType
-from megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids
-from megatron.model.language_model import get_language_model
-from megatron.model.utils import get_linear_layer
-from megatron.model.utils import init_method_normal
-from megatron.model.utils import scaled_init_method_normal
-from .module import MegatronModule
-
-
-class Classification(MegatronModule):
-
- def __init__(self,
- config,
- num_classes,
- num_tokentypes=2,
- pre_process=True,
- post_process=True):
- super().__init__(config=config, share_embeddings_and_output_weights=False)
- args = get_args()
-
- self.num_classes = num_classes
- self.pre_process = pre_process
- self.post_process = post_process
-
- self.language_model, self._language_model_key = get_language_model(
- config=config,
- num_tokentypes=num_tokentypes,
- add_pooler=True,
- encoder_attn_mask_type=AttnMaskType.padding,
- pre_process=self.pre_process,
- post_process=self.post_process)
-
- # Multi-choice head.
- if self.post_process:
- self.classification_dropout = torch.nn.Dropout(args.hidden_dropout)
- self.classification_head = get_linear_layer(args.hidden_size,
- self.num_classes,
- init_method)
- self._classification_head_key = 'classification_head'
-
- def set_input_tensor(self, input_tensor):
- """See megatron.model.transformer.set_input_tensor()"""
- self.language_model.set_input_tensor(input_tensor)
-
- def forward(self, model_input, attention_mask, tokentype_ids=None):
-
- extended_attention_mask = bert_extended_attention_mask(attention_mask)
- input_ids = model_input
- position_ids = bert_position_ids(input_ids)
-
- lm_output = self.language_model(
- input_ids,
- position_ids,
- extended_attention_mask,
- tokentype_ids=tokentype_ids
- )
-
- if self.post_process:
- _, pooled_output = lm_output
- classification_output = self.classification_dropout(pooled_output)
- classification_logits = self.classification_head(classification_output)
-
- # Reshape back to separate choices.
- classification_logits = classification_logits.view(-1, self.num_classes)
-
- return classification_logits
- return lm_output
-
- def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
- """For easy load when model is combined with other heads,
- add an extra key."""
-
- state_dict_ = {}
- state_dict_[self._language_model_key] \
- = self.language_model.state_dict_for_save_checkpoint(prefix=prefix,
- keep_vars=keep_vars)
- if self.post_process:
- state_dict_[self._classification_head_key] \
- = self.classification_head.state_dict(prefix=prefix, keep_vars=keep_vars)
- return state_dict_
-
- def load_state_dict(self, state_dict, strict=True):
- """Customized load."""
-
- self.language_model.load_state_dict(
- state_dict[self._language_model_key], strict=strict)
- if self.post_process:
- if self._classification_head_key in state_dict:
- self.classification_head.load_state_dict(
- state_dict[self._classification_head_key], strict=strict)
- else:
- print_rank_last('***WARNING*** could not find {} in the checkpoint, '
- 'initializing to random'.format(
- self._classification_head_key))
diff --git a/megatron/model/enums.py b/megatron/model/enums.py
deleted file mode 100644
index bc4e4aa29a05856bcef01d9e0fb6bfda216c247b..0000000000000000000000000000000000000000
--- a/megatron/model/enums.py
+++ /dev/null
@@ -1,21 +0,0 @@
-# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
-
-import enum
-
-class LayerType(enum.Enum):
- encoder = 1
- decoder = 2
- retro_encoder = 3
- retro_decoder = 4
- retro_decoder_with_retriever = 5
-
-class AttnType(enum.Enum):
- self_attn = 1
- cross_attn = 2
-
-class AttnMaskType(enum.Enum):
- padding = 1
- causal = 2
-
-# For backward compatibility with old model checkpoints
-from megatron.core.enums import ModelType
diff --git a/megatron/model/fused_bias_gelu.py b/megatron/model/fused_bias_gelu.py
deleted file mode 100644
index 29222db024eb5c5e54c7f38f58be8edd45c49b39..0000000000000000000000000000000000000000
--- a/megatron/model/fused_bias_gelu.py
+++ /dev/null
@@ -1,43 +0,0 @@
-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-
-import torch
-
-
-###### BIAS GELU FUSION/ NO AUTOGRAD ################
-# 1/sqrt(2*pi)-> 0.3989423
-# 1/sqrt(2) -> 0.70710678
-# sqrt(2/pi) -> 0.79788456
-# this function is tanh approximation of gelu
-# actual gelu is:
-# x * 0.5 * (1.0 + torch.erf(x * 0.70710678))
-
-@torch.jit.script
-def bias_gelu(bias, y):
- x = bias + y
- return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
-
-# gradient of tanh approximation of gelu
-# gradient of actual gelu is:
-# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
-@torch.jit.script
-def bias_gelu_back(g, bias, y):
- x = bias + y
- tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
- # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
- ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
- return ff*g
-
-class GeLUFunction(torch.autograd.Function):
- @staticmethod
- # bias is an optional argument
- def forward(ctx, input, bias):
- ctx.save_for_backward(input, bias)
- return bias_gelu(bias, input)
-
- @staticmethod
- def backward(ctx, grad_output):
- input, bias = ctx.saved_tensors
- tmp = bias_gelu_back(grad_output, bias, input)
- return tmp, tmp
-
-bias_gelu_impl = GeLUFunction.apply
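
The constants above (0.79788456 ≈ sqrt(2/pi), 0.70710678 ≈ 1/sqrt(2)) come from the tanh approximation of GELU. A quick numeric check against the exact erf form, with the bias folded in as zero (a standalone sketch, not the fused kernel):

```python
import torch

def gelu_exact(x):
    # x * Phi(x) with the exact Gaussian CDF
    return x * 0.5 * (1.0 + torch.erf(x * 0.70710678))

def gelu_tanh(x):
    # bias_gelu(bias=0, y=x) from the file above
    return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))

x = torch.linspace(-4, 4, steps=101)
# the maximum difference is small; the approximation tracks the exact form closely
print((gelu_exact(x) - gelu_tanh(x)).abs().max())
```
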
diff --git a/megatron/model/fused_layer_norm.py b/megatron/model/fused_layer_norm.py
deleted file mode 100644
index c91a674e8cafe548820f6c838e607fb30b1e087c..0000000000000000000000000000000000000000
--- a/megatron/model/fused_layer_norm.py
+++ /dev/null
@@ -1,96 +0,0 @@
-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-
-"""This code is copied fron NVIDIA apex:
- https://github.com/NVIDIA/apex
- with some changes. """
-
-import numbers
-import torch
-from torch.nn.parameter import Parameter
-from torch.nn import init
-import importlib
-
-from megatron.core.utils import make_viewless_tensor
-
-try:
- from apex.contrib.layer_norm.layer_norm import FastLayerNormFN
- HAVE_PERSIST_LAYER_NORM = True
-except:
- HAVE_PERSIST_LAYER_NORM = False
-
-try:
- from apex.normalization.fused_layer_norm import FusedLayerNormAffineFunction
-except:
- FusedLayerNormAffineFunction = None
-
-global fused_layer_norm_cuda
-fused_layer_norm_cuda = None
-
-
-class MixedFusedLayerNorm(torch.nn.Module):
-
- def __init__(self, normalized_shape, eps=1e-5,
- no_persist_layer_norm=True,
- sequence_parallel=False,
- apply_layernorm_1p=False):
- super(MixedFusedLayerNorm, self).__init__()
-
- self.apply_layernorm_1p = apply_layernorm_1p
-
- global fused_layer_norm_cuda
- fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")
-
- # List of hidden sizes supported in the persistent layer norm kernel
- # If the hidden size is not supported, fall back to the non-persistent
- # kernel.
- persist_ln_hidden_sizes = [1024, 1536, 2048, 2304, 3072, 3840, 4096,
- 5120, 6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432, 20480,
- 24576, 25600, 30720, 32768, 40960, 49152, 65536]
- if normalized_shape not in persist_ln_hidden_sizes or \
- not HAVE_PERSIST_LAYER_NORM:
- no_persist_layer_norm = True
-
- if isinstance(normalized_shape, numbers.Integral):
- normalized_shape = (normalized_shape,)
- self.normalized_shape = torch.Size(normalized_shape)
- self.eps = eps
- self.weight = Parameter(torch.Tensor(*normalized_shape))
- self.bias = Parameter(torch.Tensor(*normalized_shape))
- self.reset_parameters()
- self.no_persist_layer_norm = no_persist_layer_norm
- self.sequence_parallel = sequence_parallel
-
- # set sequence parallelism flag on weight and bias parameters
- setattr(self.weight, 'sequence_parallel', self.sequence_parallel)
- setattr(self.bias, 'sequence_parallel', self.sequence_parallel)
-
-
- def reset_parameters(self):
-
- if self.apply_layernorm_1p:
- init.zeros_(self.weight)
- init.zeros_(self.bias)
- else:
- init.ones_(self.weight)
- init.zeros_(self.bias)
-
- def forward(self, input):
-
- weight = self.weight + 1 if self.apply_layernorm_1p else self.weight
-
- if self.no_persist_layer_norm:
- assert FusedLayerNormAffineFunction is not None, \
- "FusedLayerNormAffineFunction is not available, please install apex from https://github.com/NVIDIA/apex"
- return FusedLayerNormAffineFunction.apply(input, weight, self.bias, self.normalized_shape, self.eps)
- else:
- output = FastLayerNormFN.apply(input, weight, self.bias, self.eps)
-
- # Apex's fast layer norm function outputs a 'view' tensor (i.e., has
- # a populated '_base' field). This will result in schedule.py's
- # deallocate_output_tensor() throwing an error, so a viewless tensor is
- # created to prevent this.
- output = make_viewless_tensor(inp = output,
- requires_grad = input.requires_grad,
- keep_graph = True)
-
- return output
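
The `apply_layernorm_1p` option above stores the learnable scale as a deviation from 1 (initialized to zero) and adds 1 back at forward time. A plain-PyTorch sketch of that idea without the fused Apex kernels (the class name is illustrative):

```python
import torch
import torch.nn.functional as F

class LayerNorm1p(torch.nn.Module):
    """Weight holds (gamma - 1), so it starts at zero; 1 is added in forward."""

    def __init__(self, hidden_size, eps=1e-5):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.zeros(hidden_size))
        self.bias = torch.nn.Parameter(torch.zeros(hidden_size))
        self.eps = eps

    def forward(self, x):
        return F.layer_norm(x, (x.shape[-1],), self.weight + 1, self.bias, self.eps)

y = LayerNorm1p(1024)(torch.randn(4, 16, 1024))
print(y.shape)  # torch.Size([4, 16, 1024])
```
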
diff --git a/megatron/model/fused_softmax.py b/megatron/model/fused_softmax.py
deleted file mode 100644
index 9bacf337402c4b08f565e8ddfdfb8e4579f3e257..0000000000000000000000000000000000000000
--- a/megatron/model/fused_softmax.py
+++ /dev/null
@@ -1,213 +0,0 @@
-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-
-
-import torch
-import torch.nn as nn
-from megatron.model.enums import AttnMaskType
-
-
-class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
- """
- Fused operation which performs the following three operations in sequence
- 1. Scale the tensor.
- 2. Apply upper triangular mask (typically used in gpt models).
- 3. Perform softmax.
- """
-
- @staticmethod
- def forward(ctx, inputs, scale):
- import scaled_upper_triang_masked_softmax_cuda
-
- scale_t = torch.tensor([scale])
- softmax_results = scaled_upper_triang_masked_softmax_cuda.forward(
- inputs, scale_t[0]
- )
-
- ctx.save_for_backward(softmax_results, scale_t)
- return softmax_results
-
- @staticmethod
- def backward(ctx, output_grads):
- import scaled_upper_triang_masked_softmax_cuda
-
- softmax_results, scale_t = ctx.saved_tensors
- input_grads = scaled_upper_triang_masked_softmax_cuda.backward(
- output_grads, softmax_results, scale_t[0]
- )
-
- return input_grads, None
-
-
-class ScaledMaskedSoftmax(torch.autograd.Function):
- """
- Fused operation which performs the following three operations in sequence
- 1. Scale the tensor.
- 2. Apply the mask.
- 3. Perform softmax.
- """
-
- @staticmethod
- def forward(ctx, inputs, mask, scale):
- import scaled_masked_softmax_cuda
-
- scale_t = torch.tensor([scale])
-
- softmax_results = scaled_masked_softmax_cuda.forward(inputs, mask, scale_t[0])
- ctx.save_for_backward(softmax_results, scale_t)
- return softmax_results
-
- @staticmethod
- def backward(ctx, output_grads):
- import scaled_masked_softmax_cuda
-
- softmax_results, scale_t = ctx.saved_tensors
-
- input_grads = scaled_masked_softmax_cuda.backward(
- output_grads, softmax_results, scale_t[0]
- )
- return input_grads, None, None
-
-
-class ScaledSoftmax(torch.autograd.Function):
- """
- Fused operation which performs the following two operations in sequence
- 1. Scale the tensor.
- 2. Perform softmax.
- """
-
- @staticmethod
- def forward(ctx, inputs, scale):
- import scaled_softmax_cuda
-
- scale_t = torch.tensor([scale])
-
- softmax_results = scaled_softmax_cuda.forward(
- inputs, scale_t[0]
- )
- ctx.save_for_backward(softmax_results, scale_t)
- return softmax_results
-
- @staticmethod
- def backward(ctx, output_grads):
- import scaled_softmax_cuda
-
- softmax_results, scale_t = ctx.saved_tensors
-
- input_grads = scaled_softmax_cuda.backward(
- output_grads, softmax_results, scale_t[0]
- )
- return input_grads, None, None
-
-
-class FusedScaleMaskSoftmax(nn.Module):
- """
- fused operation: scaling + mask + softmax
-
- Arguments:
- input_in_fp16: flag to indicate if input in fp16 data format.
- input_in_bf16: flag to indicate if input in bf16 data format.
- attn_mask_type: attention mask type (pad or causal)
- scaled_masked_softmax_fusion: flag to indicate the user wants to use softmax fusion
- mask_func: mask function to be applied.
- softmax_in_fp32: if true, softmax is performed at fp32 precision.
- scale: scaling factor used in input tensor scaling.
- """
-
- def __init__(
- self,
- input_in_fp16,
- input_in_bf16,
- attn_mask_type,
- scaled_masked_softmax_fusion,
- mask_func,
- softmax_in_fp32,
- scale,
- ):
- super(FusedScaleMaskSoftmax, self).__init__()
- self.input_in_fp16 = input_in_fp16
- self.input_in_bf16 = input_in_bf16
- assert not (
- self.input_in_fp16 and self.input_in_bf16
- ), "both fp16 and bf16 flags cannot be active at the same time."
- self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16
- self.attn_mask_type = attn_mask_type
- self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
- self.mask_func = mask_func
- self.softmax_in_fp32 = softmax_in_fp32
- self.scale = scale
-
- assert (
- self.scale is None or softmax_in_fp32
- ), "softmax should be in fp32 when scaled"
-
- def forward(self, input, mask):
- # [b, np, sq, sk]
- assert input.dim() == 4
-
- if self.is_kernel_available(mask, *input.size()):
- return self.forward_fused_softmax(input, mask)
- else:
- return self.forward_torch_softmax(input, mask)
-
- def is_kernel_available(self, mask, b, np, sq, sk):
- attn_batches = b * np
-
- if (
- self.scaled_masked_softmax_fusion # user wants to fuse
- and self.input_in_float16 # input must be fp16 or bf16
- and 16 < sk <= 16384 # sk must be 16 ~ 16384
- and sq % 4 == 0 # sq must be divisible by 4
- and sk % 4 == 0 # sk must be divisible by 4
- and attn_batches % 4 == 0 # np * b must be divisible by 4
- ):
- if 0 <= sk <= 16384:
- batch_per_block = self.get_batch_per_block(sq, sk, b, np)
-
- if self.attn_mask_type == AttnMaskType.causal:
- if attn_batches % batch_per_block == 0:
- return True
- else:
- if sq % batch_per_block == 0:
- return True
- return False
-
- def forward_fused_softmax(self, input, mask):
- b, np, sq, sk = input.size()
- scale = self.scale if self.scale is not None else 1.0
-
- if self.attn_mask_type == AttnMaskType.causal:
- assert sq == sk, "causal mask is only for self attention"
-
- # input is 3D tensor (attn_batches, sq, sk)
- input = input.view(-1, sq, sk)
- probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale)
- return probs.view(b, np, sq, sk)
- else:
- # input is 4D tensor (b, np, sq, sk)
- if mask is not None:
- return ScaledMaskedSoftmax.apply(input, mask, scale)
- else:
- return ScaledSoftmax.apply(input, scale)
-
- def forward_torch_softmax(self, input, mask):
- if self.input_in_float16 and self.softmax_in_fp32:
- input = input.float()
-
- if self.scale is not None:
- input = input * self.scale
- mask_output = self.mask_func(input, mask) if mask is not None else input
- probs = torch.nn.Softmax(dim=-1)(mask_output)
-
- if self.input_in_float16 and self.softmax_in_fp32:
- if self.input_in_fp16:
- probs = probs.half()
- else:
- probs = probs.bfloat16()
-
- return probs
-
- @staticmethod
- def get_batch_per_block(sq, sk, b, np):
- import scaled_masked_softmax_cuda
-
- return scaled_masked_softmax_cuda.get_batch_per_block(sq, sk, b, np)
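
When the fused kernels are unavailable, `forward_torch_softmax` above defines the reference semantics: scale the [b, np, sq, sk] scores, apply the mask, then softmax over the last dimension. A standalone sketch of that reference path with a causal mask (the mask construction here is illustrative; Megatron passes in a `mask_func`):

```python
import torch

def reference_scale_mask_softmax(scores, scale=1.0):
    # scores: [b, np, sq, sk] attention logits
    b, n, sq, sk = scores.shape
    scores = scores * scale
    # causal mask: query position i may only attend to key positions <= i
    causal = torch.triu(torch.ones(sq, sk, dtype=torch.bool), diagonal=1)
    scores = scores.masked_fill(causal, float("-inf"))
    return torch.softmax(scores, dim=-1)

probs = reference_scale_mask_softmax(torch.randn(2, 4, 8, 8), scale=0.125)
print(probs.sum(dim=-1))  # every attention row sums to 1
```
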
diff --git a/megatron/model/gpt_model.py b/megatron/model/gpt_model.py
deleted file mode 100644
index dd47188da4a14515bf595766014baa933d6bfef4..0000000000000000000000000000000000000000
--- a/megatron/model/gpt_model.py
+++ /dev/null
@@ -1,122 +0,0 @@
-# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
-
-"""GPT-2 model."""
-
-import torch
-
-from megatron import get_args
-from megatron.core import tensor_parallel
-from .module import MegatronModule
-
-from .enums import AttnMaskType
-from .language_model import parallel_lm_logits
-from .language_model import get_language_model
-
-
-def post_language_model_processing(lm_output, labels, logit_weights,
- parallel_output,
- fp16_lm_cross_entropy):
-
- # Output. Format [s b h]
- output = parallel_lm_logits(
- lm_output,
- logit_weights,
- parallel_output)
-
- if labels is None:
- # [s b h] => [b s h]
- return output.transpose(0,1).contiguous()
- else:
- # [b s] => [s b]
- labels = labels.transpose(0,1).contiguous()
- if fp16_lm_cross_entropy:
- assert output.dtype == torch.half
- loss = tensor_parallel.vocab_parallel_cross_entropy(output, labels)
- else:
- loss = tensor_parallel.vocab_parallel_cross_entropy(output.float(), labels)
-
- # [s b] => [b, s]
- loss = loss.transpose(0,1).contiguous()
- return loss
-
-
-class GPTModel(MegatronModule):
- """GPT-2 Language model."""
-
- def __init__(self,
- config,
- num_tokentypes=0,
- parallel_output=True,
- pre_process=True,
- post_process=True):
- args = get_args()
- super().__init__(config=config, share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights)
-
- self.parallel_output = parallel_output
- self.pre_process = pre_process
- self.post_process = post_process
- self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
- self.untie_embeddings_and_output_weights = args.untie_embeddings_and_output_weights
-
- self.language_model, self._language_model_key = get_language_model(
- config=config,
- num_tokentypes=num_tokentypes,
- add_pooler=False,
- encoder_attn_mask_type=AttnMaskType.causal,
- pre_process=self.pre_process,
- post_process=self.post_process)
-
- if not args.untie_embeddings_and_output_weights:
- self.initialize_word_embeddings()
-
- def set_input_tensor(self, input_tensor):
- """See megatron.model.transformer.set_input_tensor()"""
- self.language_model.set_input_tensor(input_tensor)
-
- def forward(self, input_ids, position_ids, attention_mask,
- retriever_input_ids=None,
- retriever_position_ids=None,
- retriever_attn_mask=None,
- labels=None, tokentype_ids=None, inference_params=None):
-
- lm_output = self.language_model(
- input_ids,
- position_ids,
- attention_mask,
- retriever_input_ids=retriever_input_ids,
- retriever_position_ids=retriever_position_ids,
- retriever_attn_mask=retriever_attn_mask,
- inference_params=inference_params)
-
- if self.post_process:
- return post_language_model_processing(
- lm_output, labels,
- self.language_model.output_layer.weight if self.untie_embeddings_and_output_weights else self.shared_embedding_or_output_weight(),
- self.parallel_output,
- self.fp16_lm_cross_entropy)
- else:
- return lm_output
-
- def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
-
- state_dict_ = {}
- state_dict_[self._language_model_key] \
- = self.language_model.state_dict_for_save_checkpoint(
- prefix=prefix, keep_vars=keep_vars)
- # Save word_embeddings.
- if self.post_process and not self.pre_process and not self.untie_embeddings_and_output_weights:
- state_dict_[self._word_embeddings_for_head_key] \
- = self.word_embeddings.state_dict(prefix=prefix,
- keep_vars=keep_vars)
- return state_dict_
-
- def load_state_dict(self, state_dict, strict=True):
- """Customized load."""
-
- # Load word_embeddings.
- if self.post_process and not self.pre_process and not self.untie_embeddings_and_output_weights:
- self.word_embeddings.load_state_dict(
- state_dict[self._word_embeddings_for_head_key], strict=strict)
- if self._language_model_key in state_dict:
- state_dict = state_dict[self._language_model_key]
- self.language_model.load_state_dict(state_dict, strict=strict)
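
`post_language_model_processing` above keeps activations in [s, b, h] order and only transposes labels and losses at the boundary; without tensor parallelism the vocab-parallel cross entropy reduces to ordinary cross entropy over the vocabulary. A single-device sketch of that data flow (shapes are illustrative):

```python
import torch
import torch.nn.functional as F

s, b, h, vocab = 8, 2, 16, 100
lm_output = torch.randn(s, b, h)              # [s, b, h]
word_embeddings = torch.randn(vocab, h)       # tied output weights
labels = torch.randint(0, vocab, (b, s))      # [b, s]

logits = lm_output @ word_embeddings.t()      # [s, b, vocab]
labels_sb = labels.transpose(0, 1)            # [b, s] -> [s, b]
loss_sb = F.cross_entropy(logits.reshape(-1, vocab),
                          labels_sb.reshape(-1),
                          reduction="none").view(s, b)
loss = loss_sb.transpose(0, 1)                # back to [b, s]
print(loss.shape)                             # torch.Size([2, 8])
```
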
diff --git a/megatron/model/language_model.py b/megatron/model/language_model.py
deleted file mode 100644
index 69bfa2e8018ed30411764bf357111ba5f4ed4d95..0000000000000000000000000000000000000000
--- a/megatron/model/language_model.py
+++ /dev/null
@@ -1,626 +0,0 @@
-# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
-
-"""Transformer based language model."""
-
-import torch
-import torch.nn.functional as F
-
-from megatron import get_args
-from megatron.core import mpu, tensor_parallel
-from megatron.core.enums import ModelType
-from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
-
-from .enums import AttnMaskType, LayerType
-from .module import MegatronModule
-from .transformer import ParallelTransformer
-from .utils import get_linear_layer
-from .utils import init_method_normal, scaled_init_method_normal
-
-
-def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
- bias=None):
- """LM logits using word embedding weights."""
- args = get_args()
- # Parallel logits.
- if args.async_tensor_model_parallel_allreduce or\
- args.sequence_parallel:
- input_parallel = input_
- model_parallel = mpu.get_tensor_model_parallel_world_size() > 1
- async_grad_allreduce = args.async_tensor_model_parallel_allreduce and \
- model_parallel and not args.sequence_parallel
- else:
- input_parallel = tensor_parallel.copy_to_tensor_model_parallel_region(input_)
- async_grad_allreduce = False
-
- # Matrix multiply.
- logits_parallel = tensor_parallel.linear_with_grad_accumulation_and_async_allreduce(
- input=input_parallel,
- weight=word_embeddings_weight,
- bias=bias,
- gradient_accumulation_fusion=args.gradient_accumulation_fusion,
- async_grad_allreduce=async_grad_allreduce,
- sequence_parallel=args.sequence_parallel)
- # Gather if needed.
-
- if parallel_output:
- return logits_parallel
-
- return tensor_parallel.gather_from_tensor_model_parallel_region(logits_parallel)
-
-
-def get_language_model(config, num_tokentypes, add_pooler,
- encoder_attn_mask_type,
- add_encoder=True,
- add_decoder=False,
- decoder_attn_mask_type=AttnMaskType.causal,
- pre_process=True, post_process=True):
- """Build language model and return along with the key to save."""
- args = get_args()
- if config.init_method is None:
- config.init_method = init_method_normal(config.init_method_std)
-
- if config.output_layer_init_method is None:
- config.output_layer_init_method = scaled_init_method_normal(config.init_method_std,
- config.num_layers)
-
- # Language model.
- language_model = TransformerLanguageModel(
- config,
- encoder_attn_mask_type,
- num_tokentypes=num_tokentypes,
- add_encoder=add_encoder,
- add_decoder=add_decoder,
- decoder_attn_mask_type=decoder_attn_mask_type,
- add_pooler=add_pooler,
- pre_process=pre_process,
- post_process=post_process
- )
- # key used for checkpoints.
- language_model_key = 'language_model'
-
- return language_model, language_model_key
-
-
-class Pooler(MegatronModule):
- """Pooler layer.
-
- Pool hidden states of a specific token (for example start of the
- sequence) and add a linear transformation followed by a tanh.
-
- Arguments:
- hidden_size: hidden size
- init_method: weight initialization method for the linear layer.
- bias is set to zero.
- """
-
- def __init__(self, hidden_size, init_method):
- super(Pooler, self).__init__()
- args = get_args()
- self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
- self.sequence_parallel = args.sequence_parallel
-
-
- def forward(self, hidden_states, sequence_index=0):
- # hidden_states: [s, b, h]
- # sequence_index: index of the token to pool.
-
- # gather data along sequence dimensions
- # same pooler is run on all tensor parallel nodes
- if self.sequence_parallel:
- hidden_states = tensor_parallel.gather_from_sequence_parallel_region(
- hidden_states,
- tensor_parallel_output_grad=False)
-
- pooled = hidden_states[sequence_index, :, :]
- pooled = self.dense(pooled)
- pooled = torch.tanh(pooled)
- return pooled
-
-
-class Embedding(MegatronModule):
- """Language model embeddings.
-
- Arguments:
- hidden_size: hidden size
- vocab_size: vocabulary size
- max_sequence_length: maximum size of sequence. This
- is used for positional embedding
- embedding_dropout_prob: dropout probability for embeddings
- init_method: weight initialization method
- num_tokentypes: size of the token-type embeddings. 0 value
- will ignore this embedding
- """
-
- def __init__(self,
- hidden_size,
- vocab_size,
- max_sequence_length,
- embedding_dropout_prob,
- config,
- num_tokentypes=0):
- super(Embedding, self).__init__()
-
- self.hidden_size = hidden_size
- self.init_method = config.init_method
- self.num_tokentypes = num_tokentypes
-
- args = get_args()
-
- # Word embeddings (parallel).
- self.params_dtype = args.params_dtype
- self.word_embeddings = tensor_parallel.VocabParallelEmbedding(
- vocab_size, self.hidden_size, config=config, init_method=config.init_method)
- self._word_embeddings_key = 'word_embeddings'
-
- # Position embedding (serial).
- self.add_position_embedding = args.position_embedding_type == 'learned_absolute'
- if self.add_position_embedding:
- self.position_embeddings = torch.nn.Embedding(
- max_sequence_length, self.hidden_size)
- self._position_embeddings_key = 'position_embeddings'
- # Initialize the position embeddings.
- if args.perform_initialization:
- self.init_method(self.position_embeddings.weight)
-
- # Token type embedding.
- # Add this as an optional field that can be added through
- # method call so we can load a pretrain model without
- # token types and add them as needed.
- self._tokentype_embeddings_key = 'tokentype_embeddings'
- if self.num_tokentypes > 0:
- self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes,
- self.hidden_size)
- # Initialize the token-type embeddings.
- if args.perform_initialization:
- self.init_method(self.tokentype_embeddings.weight)
- else:
- self.tokentype_embeddings = None
-
- self.fp32_residual_connection = args.fp32_residual_connection
- self.sequence_parallel = args.sequence_parallel
- self.clone_scatter_output_in_embedding = args.clone_scatter_output_in_embedding
- # Embeddings dropout
- self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)
-
- def zero_parameters(self):
- """Zero out all parameters in embedding."""
- self.word_embeddings.weight.data.fill_(0)
- self.word_embeddings.weight.shared = True
- if self.add_position_embedding:
- self.position_embeddings.weight.data.fill_(0)
- self.position_embeddings.weight.shared = True
- if self.num_tokentypes > 0:
- self.tokentype_embeddings.weight.data.fill_(0)
- self.tokentype_embeddings.weight.shared = True
-
- def add_tokentype_embeddings(self, num_tokentypes):
- """Add token-type embedding. This function is provided so we can add
- token-type embeddings in case the pretrained model does not have it.
- This allows us to load the model normally and then add this embedding.
- """
- if self.tokentype_embeddings is not None:
- raise Exception('tokentype embeddings is already initialized')
- if torch.distributed.get_rank() == 0:
- print('adding embedding for {} tokentypes'.format(num_tokentypes),
- flush=True)
- self.num_tokentypes = num_tokentypes
- self.tokentype_embeddings = torch.nn.Embedding(num_tokentypes,
- self.hidden_size)
- # Initialize the token-type embeddings.
- args = get_args()
- self.init_method(self.tokentype_embeddings.weight)
-
- def forward(self, input_ids, position_ids, tokentype_ids=None):
- # Embeddings.
- words_embeddings = self.word_embeddings(input_ids)
- if self.add_position_embedding:
- position_embeddings = self.position_embeddings(position_ids)
- embeddings = words_embeddings + position_embeddings
- else:
- embeddings = words_embeddings
-
- if tokentype_ids is not None:
- assert self.tokentype_embeddings is not None
- embeddings = embeddings + self.tokentype_embeddings(tokentype_ids)
- else:
- assert self.tokentype_embeddings is None
-
- # Data format change to avoid explicit transposes: [b s h] --> [s b h].
- embeddings = embeddings.transpose(0, 1).contiguous()
-
- # If the input flag for fp32 residual connection is set, convert for float.
- if self.fp32_residual_connection:
- embeddings = embeddings.float()
-
- # Dropout.
- if self.sequence_parallel:
- embeddings = tensor_parallel.scatter_to_sequence_parallel_region(embeddings)
- # `scatter_to_sequence_parallel_region` returns a view, which prevents
- # the original tensor from being garbage collected. Clone to facilitate GC.
- # Has a small runtime cost (~0.5%).
- if self.clone_scatter_output_in_embedding:
- embeddings = embeddings.clone()
- with tensor_parallel.get_cuda_rng_tracker().fork():
- embeddings = self.embedding_dropout(embeddings)
- else:
- embeddings = self.embedding_dropout(embeddings)
-
- return embeddings
-
- def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
- """For easy load."""
-
- state_dict_ = {}
- state_dict_[self._word_embeddings_key] \
- = self.word_embeddings.state_dict(prefix=prefix,
- keep_vars=keep_vars)
- if self.add_position_embedding:
- state_dict_[self._position_embeddings_key] \
- = self.position_embeddings.state_dict(prefix=prefix,
- keep_vars=keep_vars)
- if self.num_tokentypes > 0:
- state_dict_[self._tokentype_embeddings_key] \
- = self.tokentype_embeddings.state_dict(prefix=prefix,
- keep_vars=keep_vars)
-
- return state_dict_
-
- def load_state_dict(self, state_dict, strict=True):
- """Customized load."""
-
- # Word embedding.
- if self._word_embeddings_key in state_dict:
- state_dict_ = state_dict[self._word_embeddings_key]
- else:
- # for backward compatibility.
- state_dict_ = {}
- for key in state_dict.keys():
- if 'word_embeddings' in key:
- state_dict_[key.split('word_embeddings.')[1]] \
- = state_dict[key]
- self.word_embeddings.load_state_dict(state_dict_, strict=strict)
-
- # Position embedding.
- if self.add_position_embedding:
- if self._position_embeddings_key in state_dict:
- state_dict_ = state_dict[self._position_embeddings_key]
- else:
- # for backward compatibility.
- state_dict_ = {}
- for key in state_dict.keys():
- if 'position_embeddings' in key:
- state_dict_[key.split('position_embeddings.')[1]] \
- = state_dict[key]
- self.position_embeddings.load_state_dict(state_dict_, strict=strict)
-
- # Tokentype embedding.
- if self.num_tokentypes > 0:
- state_dict_ = {}
- if self._tokentype_embeddings_key in state_dict:
- state_dict_ = state_dict[self._tokentype_embeddings_key]
- else:
- # for backward compatibility.
- for key in state_dict.keys():
- if 'tokentype_embeddings' in key:
- state_dict_[key.split('tokentype_embeddings.')[1]] \
- = state_dict[key]
- if len(state_dict_.keys()) > 0:
- self.tokentype_embeddings.load_state_dict(state_dict_,
- strict=strict)
- else:
- print('***WARNING*** expected tokentype embeddings in the '
- 'checkpoint but could not find it', flush=True)
-
-
-class TransformerLanguageModel(MegatronModule):
- """Transformer language model.
-
- Arguments:
- transformer_hparams: transformer hyperparameters
- vocab_size: vocabulary size
- max_sequence_length: maximum size of sequence. This
- is used for positional embedding
- embedding_dropout_prob: dropout probability for embeddings
- num_tokentypes: size of the token-type embeddings. 0 value
- will ignore this embedding
- """
-
- def __init__(self,
- config,
- encoder_attn_mask_type,
- num_tokentypes=0,
- add_encoder=True,
- add_decoder=False,
- decoder_attn_mask_type=AttnMaskType.causal,
- add_pooler=False,
- pre_process=True,
- post_process=True):
- args = get_args()
- # TODO: passing share_embeddings_and_output_weights=False will not work correctly for T5 and embeddings will not be synced. Fix later for T5.
- if args.untie_embeddings_and_output_weights: assert not add_decoder
- super(TransformerLanguageModel, self).__init__(share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights)
-
- self.pre_process = pre_process
- self.post_process = post_process
- self.hidden_size = config.hidden_size
- self.num_tokentypes = num_tokentypes
- self.init_method = config.init_method
- self.add_encoder = add_encoder
- self.encoder_attn_mask_type = encoder_attn_mask_type
- self.add_decoder = add_decoder
- self.decoder_attn_mask_type = decoder_attn_mask_type
- self.add_pooler = add_pooler
- self.encoder_hidden_state = None
- self.add_retriever = args.retro_add_retriever
- self.untie_embeddings_and_output_weights = args.untie_embeddings_and_output_weights
-
- # Embeddings.
- if self.pre_process:
- self.embedding = Embedding(self.hidden_size,
- args.padded_vocab_size,
- args.max_position_embeddings,
- args.hidden_dropout,
- config,
- self.num_tokentypes)
- self._embedding_key = 'embedding'
-
- # Rotary positional embeddings
- self.use_rotary_position_embeddings = \
- args.position_embedding_type == 'rope'
- if self.use_rotary_position_embeddings:
- self.seq_length = args.seq_length
- rotary_dim = args.hidden_size // args.num_attention_heads \
- if args.kv_channels is None else args.kv_channels
-
- # partial rotary embeddings, which is better than full rotary
- # Wang and Komatsuzaki et al
- # https://github.com/kingoflolz/mesh-transformer-jax/
- self.rotary_pos_emb = RotaryEmbedding(
- rotary_dim,
- args.rotary_percent,
- seq_len_interpolation_factor=args.rotary_seq_len_interpolation_factor
- )
-
- # Encoder (usually set to True, False if part of an encoder-decoder
- # architecture and in encoder-only stage).
- if self.add_encoder:
- self.encoder = ParallelTransformer(
- config,
- model_type=args.model_type if not args.retro_add_retriever \
- else ModelType.retro_decoder,
- self_attn_mask_type=self.encoder_attn_mask_type,
- pre_process=self.pre_process,
- post_process=self.post_process,
- )
- self._encoder_key = 'encoder'
- else:
- self.encoder = None
-
- # Decoder (usually set to False, True if part of an encoder-decoder
- # architecture and in decoder-only stage).
- if self.add_decoder:
- self.decoder = ParallelTransformer(
- config,
- model_type=args.model_type,
- layer_type=LayerType.decoder,
- self_attn_mask_type=self.decoder_attn_mask_type,
- pre_process=self.pre_process,
- post_process=self.post_process)
- self._decoder_key = 'decoder'
- else:
- self.decoder = None
-
- if self.post_process:
- # Pooler.
- if self.add_pooler:
- self.pooler = Pooler(self.hidden_size, self.init_method)
- self._pooler_key = 'pooler'
-
- if self.untie_embeddings_and_output_weights:
- self.output_layer = tensor_parallel.ColumnParallelLinear(
- args.hidden_size,
- args.padded_vocab_size,
- config=config,
- init_method=self.init_method,
- bias=False) # Setting bias to False always to keep it consistent with embedding tying that also does not have a bias.
- self._output_layer_key = 'output_layer'
-
- def set_input_tensor(self, input_tensor):
- """ See megatron.model.transformer.set_input_tensor()"""
-
- # This is usually handled in schedules.py but some inference code still
- # gives us non-lists or None
- if not isinstance(input_tensor, list):
- input_tensor = [input_tensor]
-
- if self.add_encoder and self.add_decoder:
- assert len(input_tensor) == 1, \
- 'input_tensor should only be length 1 for stage with both encoder and decoder'
- self.encoder.set_input_tensor(input_tensor[0])
- elif self.add_encoder:
- assert len(input_tensor) == 1, \
- 'input_tensor should only be length 1 for stage with only encoder'
- self.encoder.set_input_tensor(input_tensor[0])
- elif self.add_decoder:
- if len(input_tensor) == 2:
- self.decoder.set_input_tensor(input_tensor[0])
- self.encoder_hidden_state = input_tensor[1]
- elif len(input_tensor) == 1:
- self.decoder.set_input_tensor(None)
- self.encoder_hidden_state = input_tensor[0]
- else:
- raise Exception('input_tensor must have either length 1 or 2')
- else:
- raise Exception('Stage must have at least either encoder or decoder')
-
- def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask,
- dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None,
- retriever_input_ids=None,
- retriever_position_ids=None,
- retriever_attn_mask=None,
- enc_dec_attn_mask=None, tokentype_ids=None,
- inference_params=None,
- pooling_sequence_index=0,
- enc_hidden_states=None, output_enc_hidden=False):
-
- # Encoder embedding.
- if self.pre_process:
- encoder_input = self.embedding(enc_input_ids, enc_position_ids,
- tokentype_ids=tokentype_ids)
- else:
- encoder_input = None
-
- # Retriever embedding.
- if self.add_retriever and self.pre_process:
- retriever_input = self.embedding(retriever_input_ids,
- retriever_position_ids,
- tokentype_ids=tokentype_ids)
- else:
- retriever_input = None
-
- # Rotary positional embeddings
- rotary_pos_emb = None
- if self.use_rotary_position_embeddings:
- if inference_params is not None:
- rotary_pos_emb = \
- self.rotary_pos_emb(inference_params.max_sequence_length)
- else:
- rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
-
- # Run encoder.
- if enc_hidden_states is None:
- if self.encoder is not None:
- encoder_output = self.encoder(
- encoder_input,
- enc_attn_mask,
- retriever_input=retriever_input,
- retriever_attn_mask=retriever_attn_mask,
- inference_params=inference_params,
- rotary_pos_emb=rotary_pos_emb)
- else:
- encoder_output = self.encoder_hidden_state
- else:
- encoder_output = enc_hidden_states.to(encoder_input.dtype)
-
- if self.post_process:
- if self.add_pooler:
- pooled_output = self.pooler(encoder_output,
- pooling_sequence_index)
-
- # output_enc_hidden refers to when we just need the encoder's
- # output. For example, it is helpful to compute
- # similarity between two sequences by average pooling
- if not self.add_decoder or output_enc_hidden:
- if self.add_pooler and self.post_process:
- return encoder_output, pooled_output
- else:
- return encoder_output
-
- # Decoder embedding.
- if self.pre_process:
- decoder_input = self.embedding(dec_input_ids,
- dec_position_ids)
- else:
- decoder_input = None
-
- # Run decoder.
- decoder_output = self.decoder(
- decoder_input,
- dec_attn_mask,
- encoder_output=encoder_output,
- enc_dec_attn_mask=enc_dec_attn_mask,
- inference_params=inference_params,
- rotary_pos_emb=rotary_pos_emb)
-
- if self.add_pooler and self.post_process:
- return decoder_output, encoder_output, pooled_output
- else:
- return decoder_output, encoder_output
-
- def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
- """For easy load."""
-
- state_dict_ = {}
- if self.pre_process:
- state_dict_[self._embedding_key] \
- = self.embedding.state_dict_for_save_checkpoint(prefix=prefix,
- keep_vars=keep_vars)
- if self.add_encoder:
- state_dict_[self._encoder_key] \
- = self.encoder.state_dict_for_save_checkpoint(prefix=prefix,
- keep_vars=keep_vars)
- if self.post_process:
- if self.add_pooler:
- state_dict_[self._pooler_key] \
- = self.pooler.state_dict_for_save_checkpoint(prefix=prefix,
- keep_vars=keep_vars)
- if self.untie_embeddings_and_output_weights:
- state_dict_[self._output_layer_key] \
- = self.output_layer.state_dict(prefix=prefix, keep_vars=keep_vars)
-
- if self.add_decoder:
- state_dict_[self._decoder_key] \
- = self.decoder.state_dict_for_save_checkpoint(prefix=prefix,
- keep_vars=keep_vars)
-
- return state_dict_
-
- def load_state_dict(self, state_dict, strict=True):
- """Customized load."""
-
- # Embedding.
- if self.pre_process:
- if self._embedding_key in state_dict:
- state_dict_ = state_dict[self._embedding_key]
- else:
- # for backward compatibility.
- state_dict_ = {}
- for key in state_dict.keys():
- if '_embeddings' in key:
- state_dict_[key] = state_dict[key]
- self.embedding.load_state_dict(state_dict_, strict=strict)
-
- # Encoder.
- if self.add_encoder:
- if self._encoder_key in state_dict:
- state_dict_ = state_dict[self._encoder_key]
- # For backward compatibility.
- elif 'transformer' in state_dict:
- state_dict_ = state_dict['transformer']
- else:
- # For backward compatibility.
- state_dict_ = {}
- for key in state_dict.keys():
- if 'transformer.' in key:
- state_dict_[key.split('transformer.')[1]] = state_dict[key]
-
- # For backward compatibility.
- state_dict_self_attention = {}
- for key in state_dict_.keys():
- if '.attention.' in key:
- state_dict_self_attention[key.replace(".attention.",
- ".self_attention.")] = state_dict_[key]
- else:
- state_dict_self_attention[key] = state_dict_[key]
- state_dict_ = state_dict_self_attention
-
- self.encoder.load_state_dict(state_dict_, strict=strict)
-
- # Pooler.
- if self.post_process:
- if self.add_pooler:
- assert 'pooler' in state_dict, \
- 'could not find data for pooler in the checkpoint'
- self.pooler.load_state_dict(state_dict[self._pooler_key],
- strict=strict)
- if self.untie_embeddings_and_output_weights:
- assert 'output_layer' in state_dict, \
- 'could not find data for output_layer in the checkpoint'
- self.output_layer.load_state_dict(state_dict[self._output_layer_key],
- strict=strict)
- # Decoder.
- if self.add_decoder:
- assert 'decoder' in state_dict, \
- 'could not find data for decoder in the checkpoint'
- self.decoder.load_state_dict(state_dict[self._decoder_key],
- strict=strict)
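
The `Embedding.forward` path above adds word and learned position embeddings in `[b, s, h]`, transposes to the `[s, b, h]` layout used by the rest of the stack, and only then applies dropout (scattered across ranks when sequence parallelism is on). A minimal single-device sketch of that flow in plain PyTorch, with no tensor or sequence parallelism and made-up sizes:

```python
import torch

# Toy single-device version of the embedding flow above: word + position lookup
# in [b, s, h], transpose to [s, b, h], then dropout. No tensor/sequence
# parallelism; vocab/sequence/hidden sizes below are made up for illustration.
vocab_size, seq_len, hidden = 100, 16, 8
word_emb = torch.nn.Embedding(vocab_size, hidden)
pos_emb = torch.nn.Embedding(seq_len, hidden)
dropout = torch.nn.Dropout(0.1)

input_ids = torch.randint(0, vocab_size, (2, seq_len))                 # [b, s]
position_ids = torch.arange(seq_len).unsqueeze(0).expand_as(input_ids)

embeddings = word_emb(input_ids) + pos_emb(position_ids)               # [b, s, h]
embeddings = dropout(embeddings.transpose(0, 1).contiguous())          # [s, b, h]
print(embeddings.shape)  # torch.Size([16, 2, 8])
```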
diff --git a/megatron/model/module.py b/megatron/model/module.py
deleted file mode 100644
index c2887315a56e5ea575a6370de842d76f21f2937b..0000000000000000000000000000000000000000
--- a/megatron/model/module.py
+++ /dev/null
@@ -1,197 +0,0 @@
-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-
-"""Megatron Module"""
-
-import torch
-from torch.autograd import Variable
-from torch.nn.parameter import Parameter
-
-from megatron import get_args
-from megatron.core import mpu, tensor_parallel
-
-
-_FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor)
-_HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor)
-_BF16_TYPES = (torch.BFloat16Tensor, torch.cuda.BFloat16Tensor)
-
-
-
-def param_is_not_shared(param):
- return not hasattr(param, 'shared') or not param.shared
-
-
-
-class MegatronModule(torch.nn.Module):
- """Megatron specific extensions of torch Module with support
- for pipelining."""
-
- def __init__(self, config=None, share_embeddings_and_output_weights=True):
- super(MegatronModule, self).__init__()
- self.config = config
- self.share_embeddings_and_output_weights = share_embeddings_and_output_weights
-
-
- def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
- """Use this function to override the state dict for
- saving checkpoints."""
- return self.state_dict(prefix=prefix, keep_vars=keep_vars)
-
-
- def shared_embedding_or_output_weight(self):
- if self.pre_process:
- return self.language_model.embedding.word_embeddings.weight
- else:
- if not self.share_embeddings_and_output_weights:
- raise Exception('shared_embedding_or_output_weight() called for last '
- 'stage, but share_embeddings_and_output_weights is false')
- return self.word_embeddings.weight
-
-
- def initialize_word_embeddings(self):
- args = get_args()
- if not self.share_embeddings_and_output_weights:
- raise Exception('initialize_word_embeddings() was called but '
- 'share_embeddings_and_output_weights is false')
-
- # This function just initializes the word embeddings in the final stage
- # when we are using pipeline parallelism. Nothing to do if we aren't
- # using pipeline parallelism.
- if args.pipeline_model_parallel_size == 1:
- return
-
- # Parameters are shared between the word embeddings layers, and the
- # heads at the end of the model. In a pipelined setup with more than
- # one stage, the initial embedding layer and the head are on different
- # workers, so we do the following:
- # 1. Create a second copy of word_embeddings on the last stage, with
- # initial parameters of 0.0.
- # 2. Do an all-reduce between the first and last stage to ensure that
- # the two copies of word_embeddings start off with the same
- # parameter values.
- # 3. In the training loop, do an all-reduce between the grads of
- # the two word_embeddings layers to ensure that every applied weight
- # update is the same on both stages.
- if mpu.is_pipeline_last_stage() and not self.pre_process:
- assert not mpu.is_pipeline_first_stage()
- self._word_embeddings_for_head_key = 'word_embeddings_for_head'
- # set word_embeddings weights to 0 here, then copy first
- # stage's weights using all_reduce below.
- self.word_embeddings = tensor_parallel.VocabParallelEmbedding(
- args.padded_vocab_size, self.config.hidden_size,
- config=self.config, init_method=self.config.init_method)
- self.word_embeddings.weight.data.fill_(0)
- self.word_embeddings.weight.shared = True
-
- # Zero out initial weights for decoder embedding.
- # NOTE: We don't currently support T5 with the interleaved schedule.
- if not mpu.is_pipeline_first_stage(ignore_virtual=True) and \
- self.pre_process:
- self.language_model.embedding.zero_parameters()
-
- if not torch.distributed.is_initialized():
- if not getattr(MegatronModule, "embedding_warning_printed", False):
- print("WARNING! Distributed processes aren't initialized, so "
- "word embeddings in the last layer are not initialized. "
- "If you are just manipulating a model this is fine, but "
- "this needs to be handled manually. If you are training "
- "something is definitely wrong.")
- MegatronModule.embedding_warning_printed = True
- return
-
- # Ensure that first and last stages have the same initial parameter
- # values.
- if mpu.is_rank_in_embedding_group():
- torch.distributed.all_reduce(self.shared_embedding_or_output_weight().data,
- group=mpu.get_embedding_group())
-
- # Ensure that encoder (first stage) and decoder (split stage) position
- # embeddings have the same initial parameter values
- # NOTE: We don't currently support T5 with the interleaved schedule.
- if mpu.is_rank_in_position_embedding_group() and \
- args.pipeline_model_parallel_split_rank is not None:
- # TODO: Support tokentype embedding.
- self.language_model.embedding.cuda()
- position_embeddings = self.language_model.embedding.position_embeddings
- torch.distributed.all_reduce(position_embeddings.weight.data,
- group=mpu.get_position_embedding_group())
-
-
-def conversion_helper(val, conversion):
- """Apply conversion to val. Recursively apply conversion if `val`
- is a nested tuple/list structure."""
- if not isinstance(val, (tuple, list)):
- return conversion(val)
- rtn = [conversion_helper(v, conversion) for v in val]
- if isinstance(val, tuple):
- rtn = tuple(rtn)
- return rtn
-
-
-def fp32_to_float16(val, float16_convertor):
- """Convert fp32 `val` to fp16/bf16"""
- def half_conversion(val):
- val_typecheck = val
- if isinstance(val_typecheck, (Parameter, Variable)):
- val_typecheck = val.data
- if isinstance(val_typecheck, _FLOAT_TYPES):
- val = float16_convertor(val)
- return val
- return conversion_helper(val, half_conversion)
-
-
-def float16_to_fp32(val):
- """Convert fp16/bf16 `val` to fp32"""
- def float_conversion(val):
- val_typecheck = val
- if isinstance(val_typecheck, (Parameter, Variable)):
- val_typecheck = val.data
- if isinstance(val_typecheck, (_BF16_TYPES, _HALF_TYPES)):
- val = val.float()
- return val
- return conversion_helper(val, float_conversion)
-
-
-
-class Float16Module(MegatronModule):
-
- def __init__(self, module, args):
- super(Float16Module, self).__init__()
-
- if args.fp16:
- self.add_module('module', module.half())
- def float16_convertor(val):
- return val.half()
- elif args.bf16:
- self.add_module('module', module.bfloat16())
- def float16_convertor(val):
- return val.bfloat16()
- else:
- raise Exception('should not be here')
-
- self.float16_convertor = float16_convertor
-
-
- def set_input_tensor(self, input_tensor):
- return self.module.set_input_tensor(input_tensor)
-
-
- def forward(self, *inputs, **kwargs):
- if mpu.is_pipeline_first_stage():
- inputs = fp32_to_float16(inputs, self.float16_convertor)
- outputs = self.module(*inputs, **kwargs)
- if mpu.is_pipeline_last_stage():
- outputs = float16_to_fp32(outputs)
- return outputs
-
-
- def state_dict(self, prefix='', keep_vars=False):
- return self.module.state_dict(prefix=prefix, keep_vars=keep_vars)
-
-
- def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
- return self.module.state_dict_for_save_checkpoint(prefix=prefix,
- keep_vars=keep_vars)
-
-
- def load_state_dict(self, state_dict, strict=True):
- self.module.load_state_dict(state_dict, strict=strict)
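
`Float16Module` relies on `conversion_helper` to walk arbitrarily nested tuples/lists of tensors when casting pipeline inputs to fp16/bf16 and outputs back to fp32. A small standalone restatement of that recursion (the tensor shapes here are arbitrary, and this is a sketch rather than the module's API):

```python
import torch

# Standalone restatement of the nested-structure cast idiom used by
# fp32_to_float16 / float16_to_fp32 above; shapes are arbitrary.
def conversion_helper(val, conversion):
    if not isinstance(val, (tuple, list)):
        return conversion(val)
    rtn = [conversion_helper(v, conversion) for v in val]
    return tuple(rtn) if isinstance(val, tuple) else rtn

inputs = (torch.randn(2, 3), [torch.randn(4), torch.randn(5)])
halved = conversion_helper(inputs, lambda t: t.half() if t.is_floating_point() else t)
print([t.dtype for t in halved[1]])  # [torch.float16, torch.float16]
```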
diff --git a/megatron/model/multiple_choice.py b/megatron/model/multiple_choice.py
deleted file mode 100644
index 41f8bb49f6a075a484d1a03654f44d326898056c..0000000000000000000000000000000000000000
--- a/megatron/model/multiple_choice.py
+++ /dev/null
@@ -1,112 +0,0 @@
-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-
-"""Multiple choice model."""
-
-import torch
-
-from megatron import get_args, print_rank_last
-from megatron.model.enums import AttnMaskType
-from megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids
-from megatron.model.language_model import get_language_model
-from megatron.model.utils import get_linear_layer
-from megatron.model.utils import init_method_normal
-from megatron.model.utils import scaled_init_method_normal
-from .module import MegatronModule
-
-
-class MultipleChoice(MegatronModule):
-
- def __init__(self,
- config,
- num_tokentypes=2,
- pre_process=True,
- post_process=True):
- super(MultipleChoice, self).__init__(share_embeddings_and_output_weights=False)
- args = get_args()
-
- self.pre_process = pre_process
- self.post_process = post_process
-
- self.language_model, self._language_model_key = get_language_model(
- config=config,
- num_tokentypes=num_tokentypes,
- add_pooler=True,
- encoder_attn_mask_type=AttnMaskType.padding,
- pre_process=self.pre_process,
- post_process=self.post_process)
-
- # Multi-choice head.
- if self.post_process:
- self.multichoice_dropout = torch.nn.Dropout(args.hidden_dropout)
- self.multichoice_head = get_linear_layer(args.hidden_size, 1,
- config.init_method)
- self._multichoice_head_key = 'multichoice_head'
-
- def set_input_tensor(self, input_tensor):
- """See megatron.model.transformer.set_input_tensor()"""
- self.language_model.set_input_tensor(input_tensor)
-
- def forward(self, model_input, attention_mask, tokentype_ids=None):
-
- # [batch, choices, sequence] --> [batch * choices, sequence] -->
- # transformer --> [batch, choices] --> softmax
-
- # Ensure the shape is [batch-size, choices, sequence]
- assert len(attention_mask.shape) == 3
- num_choices = attention_mask.shape[1]
-
- # Reshape and treat choice dimension the same as batch.
- attention_mask = attention_mask.view(-1, attention_mask.size(-1))
- extended_attention_mask = bert_extended_attention_mask(attention_mask)
-
- input_ids = model_input
- # Do the same as attention_mask for input_ids, tokentype_ids
- assert len(input_ids.shape) == 3
- assert len(tokentype_ids.shape) == 3
- input_ids = input_ids.view(-1, input_ids.size(-1))
- tokentype_ids = tokentype_ids.view(-1, tokentype_ids.size(-1))
- position_ids = bert_position_ids(input_ids)
-
- lm_output = self.language_model(
- input_ids,
- position_ids,
- extended_attention_mask,
- tokentype_ids=tokentype_ids
- )
- if self.post_process:
- _, pooled_output = lm_output
- multichoice_output = self.multichoice_dropout(pooled_output)
- multichoice_logits = self.multichoice_head(multichoice_output)
-
- # Reshape back to separate choices.
- multichoice_logits = multichoice_logits.view(-1, num_choices)
-
- return multichoice_logits
- return lm_output
-
- def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
- """For easy load when model is combined with other heads,
- add an extra key."""
-
- state_dict_ = {}
- state_dict_[self._language_model_key] \
- = self.language_model.state_dict_for_save_checkpoint(prefix=prefix,
- keep_vars=keep_vars)
- if self.post_process:
- state_dict_[self._multichoice_head_key] \
- = self.multichoice_head.state_dict(prefix=prefix, keep_vars=keep_vars)
- return state_dict_
-
- def load_state_dict(self, state_dict, strict=True):
- """Customized load."""
-
- self.language_model.load_state_dict(
- state_dict[self._language_model_key], strict=strict)
- if self.post_process:
- if self._multichoice_head_key in state_dict:
- self.multichoice_head.load_state_dict(
- state_dict[self._multichoice_head_key], strict=strict)
- else:
- print_rank_last('***WARNING*** could not find {} in the checkpoint, '
- 'initializing to random'.format(
- self._multichoice_head_key))
diff --git a/megatron/model/realm_model.py b/megatron/model/realm_model.py
deleted file mode 100644
index 654f2992f62cdcb52c3ef7d40cc2dcb78fd776b6..0000000000000000000000000000000000000000
--- a/megatron/model/realm_model.py
+++ /dev/null
@@ -1,204 +0,0 @@
-import os
-import torch
-
-from megatron import get_args, print_rank_0
-from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name
-from megatron.model import BertModel
-from .module import MegatronModule
-from megatron.core import mpu
-from megatron.model.enums import AttnMaskType
-from megatron.model.utils import get_linear_layer
-from megatron.model.utils import init_method_normal
-from megatron.model.language_model import get_language_model
-from megatron.model.utils import scaled_init_method_normal
-from megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids
-
-
-def general_ict_model_provider(only_query_model=False, only_block_model=False):
- """Build the model."""
- args = get_args()
- assert args.ict_head_size is not None, \
- "Need to specify --ict-head-size to provide an ICTBertModel"
- assert mpu.get_tensor_model_parallel_world_size() == 1 and mpu.get_pipeline_model_parallel_world_size() == 1, \
- "Model parallel size > 1 not supported for ICT"
-
- print_rank_0('building ICTBertModel...')
-
- # simpler to just keep using 2 tokentypes since the LM we initialize with has 2 tokentypes
- model = ICTBertModel(
- ict_head_size=args.ict_head_size,
- num_tokentypes=2,
- parallel_output=True,
- only_query_model=only_query_model,
- only_block_model=only_block_model)
-
- return model
-
-
-class ICTBertModel(MegatronModule):
- """Bert-based module for Inverse Cloze task."""
- def __init__(self,
- ict_head_size,
- num_tokentypes=1,
- parallel_output=True,
- only_query_model=False,
- only_block_model=False):
- super(ICTBertModel, self).__init__()
- bert_kwargs = dict(
- ict_head_size=ict_head_size,
- num_tokentypes=num_tokentypes,
- parallel_output=parallel_output
- )
- assert not (only_block_model and only_query_model)
- self.use_block_model = not only_query_model
- self.use_query_model = not only_block_model
-
- if self.use_query_model:
- # this model embeds (pseudo-)queries - Embed_input in the paper
- self.query_model = IREncoderBertModel(**bert_kwargs)
- self._query_key = 'question_model'
-
- if self.use_block_model:
- # this model embeds evidence blocks - Embed_doc in the paper
- self.block_model = IREncoderBertModel(**bert_kwargs)
- self._block_key = 'context_model'
-
- def forward(self, query_tokens, query_attention_mask, block_tokens, block_attention_mask):
- """Run a forward pass for each of the models and return the respective embeddings."""
- query_logits = self.embed_query(query_tokens, query_attention_mask)
- block_logits = self.embed_block(block_tokens, block_attention_mask)
- return query_logits, block_logits
-
- def embed_query(self, query_tokens, query_attention_mask):
- """Embed a batch of tokens using the query model"""
- if self.use_query_model:
- query_types = torch.cuda.LongTensor(*query_tokens.shape).fill_(0)
- query_ict_logits, _ = self.query_model.forward(query_tokens, query_attention_mask, query_types)
- return query_ict_logits
- else:
- raise ValueError("Cannot embed query without query model.")
-
- def embed_block(self, block_tokens, block_attention_mask):
- """Embed a batch of tokens using the block model"""
- if self.use_block_model:
- block_types = torch.cuda.LongTensor(*block_tokens.shape).fill_(0)
- block_ict_logits, _ = self.block_model.forward(block_tokens, block_attention_mask, block_types)
- return block_ict_logits
- else:
- raise ValueError("Cannot embed block without block model.")
-
- def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
- """Save dict with state dicts of each of the models."""
- state_dict_ = {}
- if self.use_query_model:
- state_dict_[self._query_key] \
- = self.query_model.state_dict_for_save_checkpoint(
- prefix=prefix, keep_vars=keep_vars)
-
- if self.use_block_model:
- state_dict_[self._block_key] \
- = self.block_model.state_dict_for_save_checkpoint(
- prefix=prefix, keep_vars=keep_vars)
-
- return state_dict_
-
- def load_state_dict(self, state_dict, strict=True):
- """Load the state dicts of each of the models"""
- if self.use_query_model:
- print("Loading ICT query model", flush=True)
- self.query_model.load_state_dict(
- state_dict[self._query_key], strict=strict)
-
- if self.use_block_model:
- print("Loading ICT block model", flush=True)
- self.block_model.load_state_dict(
- state_dict[self._block_key], strict=strict)
-
- def init_state_dict_from_bert(self):
- """Initialize the state from a pretrained BERT model on iteration zero of ICT pretraining"""
- args = get_args()
- tracker_filename = get_checkpoint_tracker_filename(args.bert_load)
- if not os.path.isfile(tracker_filename):
- raise FileNotFoundError("Could not find BERT load for ICT")
- with open(tracker_filename, 'r') as f:
- iteration = int(f.read().strip())
- assert iteration > 0
-
- checkpoint_name = get_checkpoint_name(args.bert_load, iteration, False)
- if mpu.get_data_parallel_rank() == 0:
- print('global rank {} is loading checkpoint {}'.format(
- torch.distributed.get_rank(), checkpoint_name))
-
- try:
- state_dict = torch.load(checkpoint_name, map_location='cpu')
- except BaseException:
- raise ValueError("Could not load checkpoint")
-
- # load the LM state dict into each model
- model_dict = state_dict['model']['language_model']
- self.query_model.language_model.load_state_dict(model_dict)
- self.block_model.language_model.load_state_dict(model_dict)
-
- # give each model the same ict_head to begin with as well
- query_ict_head_state_dict = self.state_dict_for_save_checkpoint()[self._query_key]['ict_head']
- self.block_model.ict_head.load_state_dict(query_ict_head_state_dict)
-
-
-class IREncoderBertModel(MegatronModule):
- """BERT-based encoder for queries or blocks used for learned information retrieval."""
- def __init__(self, ict_head_size, num_tokentypes=2, parallel_output=True):
- super(IREncoderBertModel, self).__init__()
- args = get_args()
-
- self.ict_head_size = ict_head_size
- self.parallel_output = parallel_output
- init_method = init_method_normal(args.init_method_std)
- scaled_init_method = scaled_init_method_normal(args.init_method_std,
- args.num_layers)
-
- self.language_model, self._language_model_key = get_language_model(
- num_tokentypes=num_tokentypes,
- add_pooler=True,
- encoder_attn_mask_type=AttnMaskType.padding,
- init_method=init_method,
- scaled_init_method=scaled_init_method)
-
- self.ict_head = get_linear_layer(args.hidden_size, ict_head_size, init_method)
- self._ict_head_key = 'ict_head'
-
- def forward(self, input_ids, attention_mask, tokentype_ids=None):
- extended_attention_mask = bert_extended_attention_mask(
- attention_mask, next(self.language_model.parameters()).dtype)
- position_ids = bert_position_ids(input_ids)
-
- lm_output, pooled_output = self.language_model(
- input_ids,
- position_ids,
- extended_attention_mask,
- tokentype_ids=tokentype_ids)
-
- # Output.
- ict_logits = self.ict_head(pooled_output)
- return ict_logits, None
-
- def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
- """For easy load when model is combined with other heads,
- add an extra key."""
-
- state_dict_ = {}
- state_dict_[self._language_model_key] \
- = self.language_model.state_dict_for_save_checkpoint(prefix=prefix,
- keep_vars=keep_vars)
- state_dict_[self._ict_head_key] \
- = self.ict_head.state_dict(prefix=prefix,
- keep_vars=keep_vars)
- return state_dict_
-
- def load_state_dict(self, state_dict, strict=True):
- """Customized load."""
- self.language_model.load_state_dict(
- state_dict[self._language_model_key], strict=strict)
- self.ict_head.load_state_dict(
- state_dict[self._ict_head_key], strict=strict)
-
-
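
The ICT model above is a dual encoder: one BERT tower embeds (pseudo-)queries, another embeds evidence blocks, and retrieval scores are dot products between the two sets of embeddings. A toy sketch of that scoring step, with stand-in linear heads in place of the BERT towers and `ict_head` (all sizes hypothetical):

```python
import torch

# Toy sketch of the ICT dual-encoder scoring above: a query tower and a block
# (context) tower each map inputs to fixed-size vectors, and retrieval scores
# are their dot products. The BERT towers are replaced by stand-in linear
# layers over random features; all sizes are hypothetical.
hidden, ict_head_size = 8, 4
query_head = torch.nn.Linear(hidden, ict_head_size)
block_head = torch.nn.Linear(hidden, ict_head_size)

query_repr = query_head(torch.randn(3, hidden))   # [num_queries, ict_head_size]
block_repr = block_head(torch.randn(5, hidden))   # [num_blocks, ict_head_size]
scores = query_repr @ block_repr.t()              # [num_queries, num_blocks]
print(scores.shape)  # torch.Size([3, 5])
```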
diff --git a/megatron/model/rms_norm.py b/megatron/model/rms_norm.py
deleted file mode 100644
index d42e7df9a812b4212516c48af61bfe1796a48244..0000000000000000000000000000000000000000
--- a/megatron/model/rms_norm.py
+++ /dev/null
@@ -1,31 +0,0 @@
-# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
-
-import torch
-from torch import nn
-
-class RMSNorm(torch.nn.Module):
-
- def __init__(self,
- dim: int,
- eps: float = 1e-6,
- sequence_parallel: bool = False):
- """RMS Normaliation module
-
- Arguments:
- dim (int): The width of input, i.e. hidden size
- eps (float): epsilon to use for the norm, default to 1e-6
- sequence_parallel (bool): Set to true if sequence parallelism is being used;
- this marks the weights as needing to be allreduced.
- """
- super().__init__()
- self.eps = eps
- self.weight = nn.Parameter(torch.ones(dim))
-
- setattr(self.weight, 'sequence_parallel', sequence_parallel)
-
- def _norm(self, x):
- return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
-
- def forward(self, x):
- output = self._norm(x.float()).type_as(x)
- return output * self.weight
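
A quick numeric check of the RMSNorm formula above, `y = x * rsqrt(mean(x^2) + eps) * weight`, computed in fp32 and cast back to the input dtype (standalone sketch with arbitrary sizes):

```python
import torch

# Quick check of the RMSNorm formula above: normalize in fp32, cast back to the
# input dtype, then scale by the learned weight. Sizes are arbitrary.
x = torch.randn(2, 4, dtype=torch.float16)
eps = 1e-6
weight = torch.ones(4, dtype=x.dtype)
y = (x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + eps)).type_as(x) * weight
print(y.shape, y.dtype)  # torch.Size([2, 4]) torch.float16
```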
diff --git a/megatron/model/t5_model.py b/megatron/model/t5_model.py
deleted file mode 100644
index f9fabd34010e7467546b0c32a79039ef049aec86..0000000000000000000000000000000000000000
--- a/megatron/model/t5_model.py
+++ /dev/null
@@ -1,186 +0,0 @@
-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-
-"""T5 model."""
-
-import torch
-
-from megatron import get_args
-from megatron.core import tensor_parallel
-from megatron.model.enums import AttnMaskType
-from megatron.model.language_model import parallel_lm_logits, get_language_model
-from megatron.model import LayerNorm
-from megatron.model.utils import (
- openai_gelu,
- get_linear_layer
-)
-from .module import MegatronModule
-
-
-def t5_extended_attention_mask(attention_mask_list):
-
- def attn_mask_postprocess(attn_mask):
- # [b, 1, s, s]
- extended_attention_mask = attn_mask.unsqueeze(1)
- return extended_attention_mask
-
- return [attn_mask_postprocess(attn_mask) for attn_mask in attention_mask_list]
-
-
-def t5_position_ids(token_ids):
- # Create position ids
- seq_length = token_ids.size(1)
- position_ids = torch.arange(seq_length, dtype=torch.long,
- device=token_ids.device)
- position_ids = position_ids.unsqueeze(0).expand_as(token_ids)
-
- return position_ids
-
-
-class T5LMHead(MegatronModule):
- """Masked LM head for T5
-
- Arguments:
- mpu_vocab_size: model parallel size of vocabulary.
- parallel_output: whether the output logits are distributed or not.
- """
-
- def __init__(self, mpu_vocab_size, parallel_output):
- super(T5LMHead, self).__init__()
-
- self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size))
- self.bias.model_parallel = True
- self.bias.partition_dim = 0
- self.bias.stride = 1
- self.parallel_output = parallel_output
-
- def forward(self, hidden_states, word_embeddings_weight):
- output = parallel_lm_logits(hidden_states,
- word_embeddings_weight,
- self.parallel_output,
- bias=self.bias)
- return output
-
-
-class T5Model(MegatronModule):
- """T5 Language model."""
-
- def __init__(self,
- config,
- num_tokentypes=0,
- parallel_output=True,
- pre_process=True,
- post_process=True,
- add_encoder=True,
- add_decoder=True):
- super().__init__(config=config)
- args = get_args()
-
- self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
- self.parallel_output = parallel_output
- self.pre_process = pre_process
- self.post_process = post_process
- self.add_encoder = add_encoder
- self.add_decoder = add_decoder
-
- self.language_model, self._language_model_key = get_language_model(
- config=config,
- num_tokentypes=num_tokentypes,
- add_pooler=False,
- add_encoder=add_encoder,
- add_decoder=add_decoder,
- encoder_attn_mask_type=AttnMaskType.padding,
- pre_process=self.pre_process,
- post_process=self.post_process)
-
- self.initialize_word_embeddings()
-
- if self.post_process and self.add_decoder:
- self.lm_head = T5LMHead(
- self.shared_embedding_or_output_weight().size(0),
- parallel_output)
- self._lm_head_key = 'lm_head'
-
- def set_input_tensor(self, input_tensor):
- """See megatron.model.transformer.set_input_tensor()"""
- self.language_model.set_input_tensor(input_tensor)
-
- def forward(self, encoder_input_ids, decoder_input_ids, encoder_attn_mask,
- decoder_attn_mask, encoder_decoder_attn_mask,
- tokentype_ids=None, lm_labels=None, enc_hidden_states=None):
-
- # Converting the attention masks to proper parameter settings
- encoder_attn_mask, decoder_attn_mask, encoder_decoder_attn_mask = t5_extended_attention_mask(
- [encoder_attn_mask, decoder_attn_mask, encoder_decoder_attn_mask])
-
- encoder_position_ids = t5_position_ids(encoder_input_ids)
- decoder_position_ids = t5_position_ids(decoder_input_ids)
-
- lm_output = self.language_model(encoder_input_ids,
- encoder_position_ids,
- encoder_attn_mask,
- decoder_input_ids,
- decoder_position_ids,
- decoder_attn_mask,
- encoder_decoder_attn_mask,
- tokentype_ids=tokentype_ids,
- enc_hidden_states=enc_hidden_states)
-
- if self.post_process and self.add_decoder:
- decoder_output, encoder_output = lm_output
- # Output. [s, b, h]
- lm_logits = self.lm_head(decoder_output,
- self.shared_embedding_or_output_weight())
-
- if lm_labels is None:
- # [s b h] => [b s h]
- return lm_logits.transpose(0,1).contiguous()
- else:
- # [b s] => [s b]
- lm_labels = lm_labels.transpose(0,1).contiguous()
- if self.fp16_lm_cross_entropy:
- assert lm_logits.dtype == torch.half
- lm_loss = tensor_parallel.vocab_parallel_cross_entropy(lm_logits, lm_labels)
- else:
- lm_loss = tensor_parallel.vocab_parallel_cross_entropy(lm_logits.float(),
- lm_labels)
- # [s b] => [b s]
- lm_loss = lm_loss.transpose(0,1).contiguous()
- return lm_loss
- elif self.add_decoder and not self.add_encoder:
- decoder_output, encoder_output = lm_output
- return decoder_output
- else:
- encoder_output = lm_output
- return encoder_output
-
- def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
- """For easy load when model is combined with other heads,
- add an extra key."""
-
- state_dict_ = {}
- state_dict_[self._language_model_key] \
- = self.language_model.state_dict_for_save_checkpoint(prefix=prefix,
- keep_vars=keep_vars)
- if self.post_process and self.add_decoder:
- state_dict_[self._lm_head_key] \
- = self.lm_head.state_dict_for_save_checkpoint(prefix=prefix,
- keep_vars=keep_vars)
- # Save word_embeddings.
- if self.post_process and not self.pre_process and self.add_decoder:
- state_dict_[self._word_embeddings_for_head_key] \
- = self.word_embeddings.state_dict(prefix=prefix,
- keep_vars=keep_vars)
- return state_dict_
-
- def load_state_dict(self, state_dict, strict=True):
- """Customized load."""
-
- self.language_model.load_state_dict(
- state_dict[self._language_model_key], strict=strict)
- if self.post_process and self.add_decoder:
- self.lm_head.load_state_dict(state_dict[self._lm_head_key],
- strict=strict)
- # Load word embeddings.
- if self.post_process and not self.pre_process and self.add_decoder:
- self.word_embeddings.load_state_dict(
- state_dict[self._word_embeddings_for_head_key], strict=strict)
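
The T5 helpers above are small shape utilities: `t5_position_ids` builds `[0..s-1]` per batch row, and `t5_extended_attention_mask` inserts a head dimension so a `[b, s, s]` mask becomes `[b, 1, s, s]`. A shapes-only sketch of both (no model is built; sizes are made up):

```python
import torch

# Shapes-only sketch of the helpers above: position ids per batch row and the
# extended attention mask with an inserted head dimension. Sizes are made up.
token_ids = torch.randint(0, 10, (2, 6))                          # [b, s]
position_ids = torch.arange(6).unsqueeze(0).expand_as(token_ids)  # [b, s]
attn_mask = torch.ones(2, 6, 6, dtype=torch.bool)
extended_mask = attn_mask.unsqueeze(1)                            # [b, 1, s, s]
print(position_ids.shape, extended_mask.shape)
```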
diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py
deleted file mode 100644
index 9f1144c02bb875f67ec556f739855b40eed8ea6f..0000000000000000000000000000000000000000
--- a/megatron/model/transformer.py
+++ /dev/null
@@ -1,1793 +0,0 @@
-# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
-
-"""Transformer."""
-from contextlib import nullcontext
-import math
-import numpy as np
-import torch
-import torch.nn.functional as F
-from typing import Optional
-
-from megatron import get_timers, get_args, get_retro_args, core, get_num_microbatches
-from .module import MegatronModule
-from megatron.core import mpu, tensor_parallel
-from megatron.core.enums import ModelType
-from megatron.model.enums import AttnMaskType, LayerType, AttnType
-from megatron.model.fused_softmax import FusedScaleMaskSoftmax
-from megatron.model.fused_bias_gelu import bias_gelu_impl
-from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding, apply_rotary_pos_emb
-from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu, get_norm
-from megatron.core.tensor_parallel import (
- gather_from_sequence_parallel_region_to_moe,
- reduce_scatter_to_sequence_parallel_region_from_moe,
- get_cuda_rng_tracker,
- get_data_parallel_rng_tracker_name
-)
-from megatron.core.parallel_state import get_tensor_model_parallel_group, get_tensor_and_expert_parallel_group
-
-try:
- from einops import rearrange
-except ImportError:
- rearrange = None
-
-try:
- from flash_attn.flash_attn_interface import flash_attn_unpadded_func
-except ImportError:
- try:
- from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_unpadded_func
- except ImportError:
- flash_attn_unpadded_func = None
-
-""" We use the following notation throughout this file:
- h: hidden size
- n: number of attention heads
- p: number of model parallel partitions
- np: n/p
- hp: h/p
- hn: h/n
- b: batch size
- s: sequence length
- l: number of layers
- Transformer takes input of size [s, b, h] and returns a
- tensor of the same size. We use the following arguments:
- hyperparameters: transformer hyperparameters
-"""
-
-class DropPath(MegatronModule):
- """Drop paths (Stochastic Depth) per sample
- (when applied in main path of residual blocks).
- """
-
- def __init__(self, drop_prob=0.):
- super(DropPath, self).__init__()
- self.drop_prob = drop_prob
-
- def forward(self, hidden_state):
- if self.drop_prob == 0. or not self.training:
- return hidden_state
- keep_prob = 1 - self.drop_prob
- # work with diff dim tensors, not just 2D ConvNets
- # hidden_state: [s, b, h]
- shape = (1,) + (hidden_state.shape[1],) + (1,) * (hidden_state.ndim - 2)
- random_tensor = keep_prob + \
- torch.rand(shape, dtype=hidden_state.dtype, device=hidden_state.device)
- random_tensor.floor_() # binarize
- output = hidden_state.div(keep_prob) * random_tensor
- return output
-
-class ParallelMLP(MegatronModule):
- """MLP.
-
- MLP will take the input with h hidden state, project it to 4*h
- hidden dimension, perform nonlinear transformation, and project the
- state back into h hidden dimension.
- """
-
- def __init__(self, config, is_expert=False):
- super(ParallelMLP, self).__init__()
- args = get_args()
-
- self.add_bias = config.add_bias_linear
-
- ffn_hidden_size = config.ffn_hidden_size
- if config.gated_linear_unit:
- ffn_hidden_size *= 2
-
- # Project to 4h. If using swiglu, double the output width; see https://arxiv.org/pdf/2002.05202.pdf
- self.dense_h_to_4h = tensor_parallel.ColumnParallelLinear(
- config.hidden_size,
- ffn_hidden_size,
- config=config,
- init_method=config.init_method,
- bias=self.add_bias,
- gather_output=False,
- skip_bias_add=True,
- is_expert=is_expert,
- )
-
- self.bias_gelu_fusion = False
- self.activation_func = None
- self.swiglu = args.swiglu
-
- if args.openai_gelu:
- self.activation_func = openai_gelu
- elif args.onnx_safe:
- self.activation_func = erf_gelu
- elif args.swiglu:
- def swiglu(x):
- x = torch.chunk(x, 2, dim=-1)
- return F.silu(x[0]) * x[1]
- self.activation_func = swiglu
- elif args.squared_relu:
- def squared_relu(x):
- return torch.pow(F.relu(x), 2)
- self.activation_func = squared_relu
- else:
- self.bias_gelu_fusion = args.bias_gelu_fusion
- self.activation_func = F.gelu
-
- # Project back to h.
- self.dense_4h_to_h = tensor_parallel.RowParallelLinear(
- config.ffn_hidden_size,
- config.hidden_size,
- config=config,
- init_method=config.output_layer_init_method,
- bias=self.add_bias,
- skip_bias_add=True,
- input_is_parallel=True,
- is_expert=is_expert,
- )
-
- def forward(self, hidden_states):
-
- # [s, b, 4hp]
- intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states)
-
- if self.bias_gelu_fusion:
- assert self.add_bias is True
- assert self.activation_func == F.gelu
- intermediate_parallel = bias_gelu_impl(intermediate_parallel, bias_parallel)
- else:
- if bias_parallel is not None:
- intermediate_parallel = intermediate_parallel + bias_parallel
- intermediate_parallel = self.activation_func(intermediate_parallel)
-
- # [s, b, h]
- output, output_bias = self.dense_4h_to_h(intermediate_parallel)
- return output, output_bias
-
-def sinkhorn(cost, tol=0.0001):
- cost = torch.exp(cost)
- d0 = torch.ones(cost.size(0), device=cost.device, dtype=cost.dtype)
- d1 = torch.ones(cost.size(1), device=cost.device, dtype=cost.dtype)
-
- eps = 0.00000001
- error = 1e9
- d1_old = d1
- while error > tol:
- d0 = (1/d0.size(0))*1/(torch.sum(d1*cost,1) + eps)
- d1 = (1/d1.size(0))*1/(torch.sum(d0.unsqueeze(1)*cost,0)+eps)
- error = torch.mean(torch.abs(d1_old-d1))
- d1_old = d1
- return d1*cost*d0.unsqueeze(1)
-
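
The `sinkhorn` routine above iteratively rescales rows and columns of `exp(cost)` until the result is approximately doubly stochastic, which is what `SwitchMLP` leans on for load balancing during training. A compact standalone restatement with a numeric sanity check (not the repo's test code; the 4x4 size is arbitrary):

```python
import torch

# Compact restatement of the Sinkhorn normalization above with a sanity check:
# after convergence, each row and each column of the 4x4 result sums to ~0.25.
def sinkhorn(cost, tol=1e-4):
    cost = torch.exp(cost)
    d0 = torch.ones(cost.size(0), dtype=cost.dtype)
    d1 = torch.ones(cost.size(1), dtype=cost.dtype)
    eps, error, d1_old = 1e-8, 1e9, d1
    while error > tol:
        d0 = (1 / d0.size(0)) / (torch.sum(d1 * cost, 1) + eps)
        d1 = (1 / d1.size(0)) / (torch.sum(d0.unsqueeze(1) * cost, 0) + eps)
        error = torch.mean(torch.abs(d1_old - d1))
        d1_old = d1
    return d1 * cost * d0.unsqueeze(1)

torch.manual_seed(0)
balanced = sinkhorn(torch.randn(4, 4))
print(balanced.sum(0), balanced.sum(1))  # both ~0.25 per entry
```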
-
-def get_router_linear_layer(config):
- args = get_args()
- router = torch.nn.Linear(args.hidden_size, args.num_experts, bias=False)
- with get_cuda_rng_tracker().fork(get_data_parallel_rng_tracker_name()):
- config.init_method(router.weight)
- setattr(router.weight, 'sequence_parallel',config.sequence_parallel)
- return router
-
-
-class SwitchMLP(MegatronModule):
- """
- Routes input to one of N MLP "experts"
- """
- def __init__(self, config):
- super(SwitchMLP, self).__init__()
- args = get_args()
- self.router = get_router_linear_layer(config)
- self.expert_parallel_size = mpu.get_expert_model_parallel_world_size()
- self.sequence_parallel = config.sequence_parallel
- self.add_bias = config.add_bias_linear
-
- assert args.num_experts % self.expert_parallel_size == 0
- self.num_local_experts = args.num_experts // self.expert_parallel_size
- local_expert_indices_offset = mpu.get_expert_model_parallel_rank() * self.num_local_experts
- self.local_expert_indices = [local_expert_indices_offset + i for i in range(self.num_local_experts)]
-
- self.local_experts = torch.nn.ModuleList()
- for i in range(self.num_local_experts):
- self.local_experts.append(ParallelMLP(config, is_expert=True))
-
- def gather_indices(self, local_indices):
- """ Gather tensors and concatinate along the first dimension."""
- group = get_tensor_and_expert_parallel_group()
- world_size = torch.distributed.get_world_size(group=group)
- # Bypass the function if we are using only 1 GPU.
- if world_size == 1:
- return local_indices
-
- dim_size = list(local_indices.size())
- dim_size[0] = dim_size[0] * world_size
-
- # TODO pre allocate memory
- output = torch.empty(dim_size, dtype=local_indices.dtype,
- device=torch.cuda.current_device())
- torch.distributed._all_gather_base(
- output, local_indices.contiguous(), group=group
- )
- return output
-
- def forward(self, hidden_states):
- # hidden_states: [b, s, h]
- args = get_args()
- s = hidden_states.size(0)
- b = hidden_states.size(1)
- h = hidden_states.size(2)
- route = self.router(hidden_states).view(-1, args.num_experts)
-
- # TODO (rprenger) Right now we're just using the sinkhorn algorithm
- # for load balancing. There should be an option to do no load balancing
- # and the algorithm and parameters should be further tested
- if self.training:
- with torch.no_grad():
- sinkroute = sinkhorn(route.detach().to(dtype=torch.float32))
- _, max_ind = torch.max(sinkroute, dim=1)
- route = torch.sigmoid(route)
- max_prob = route[torch.arange(route.size(0)), max_ind]
- else:
- route = torch.sigmoid(route)
- max_prob, max_ind = torch.max(route, dim=1)
-
- max_prob = torch.unsqueeze(max_prob, 1)
- hidden_states = hidden_states.view(-1, hidden_states.size(2))
-
- # TODO (rprenger) This could be made easier to read
- # Converting [s, b, h] to [s*b, h].
- # Each vector could be routed differently
- if self.sequence_parallel or (self.expert_parallel_size > 1):
- global_hidden_states = \
- gather_from_sequence_parallel_region_to_moe(hidden_states)
- global_indices = self.gather_indices(max_ind)
- else:
- global_hidden_states = hidden_states
- global_indices = max_ind
-
- output_total = torch.zeros_like(global_hidden_states)
- if self.add_bias:
- output_bias_total = torch.zeros_like(global_hidden_states)
-
- for expert_num, expert in enumerate(self.local_experts):
- local_expert_index = self.local_expert_indices[expert_num]
- local_indices = (global_indices == local_expert_index).nonzero()
- hidden = global_hidden_states[local_indices, :]
- output, output_bias = expert(hidden)
- output_total[local_indices, :] = output
- if self.add_bias:
- output_bias = output_bias.expand_as(output)
- output_bias_total[local_indices, :] = output_bias
-
- if self.sequence_parallel or (self.expert_parallel_size > 1):
- output_total = \
- reduce_scatter_to_sequence_parallel_region_from_moe(output_total)
- if self.add_bias:
- output_bias_total = \
- reduce_scatter_to_sequence_parallel_region_from_moe(output_bias_total)
-
- # bias is duplicated across tensor parallelism ranks;
- # reduce scatter reduces bias across tensor parallel ranks
- output_bias_total = \
- output_bias_total/mpu.get_tensor_model_parallel_world_size()
-
- output_total = output_total*max_prob
- output_total = output_total.view(s, b, h)
- if self.add_bias:
- output_bias_total = output_bias_total*max_prob
- output_bias_total = output_bias_total.view(s, b, h)
- else:
- output_bias_total = None
-
- return output_total, output_bias_total
-
-
-class CoreAttention(MegatronModule):
-
- def __init__(self, layer_number, config,
- attn_mask_type=AttnMaskType.padding):
- super(CoreAttention, self).__init__()
- self.fp16 = config.fp16
- self.bf16 = config.bf16
-
- self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling
- self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
- if self.apply_query_key_layer_scaling:
- self.attention_softmax_in_fp32 = True
- self.layer_number = max(1, layer_number)
- self.attn_mask_type = attn_mask_type
- self.sequence_parallel = config.sequence_parallel
-
- projection_size = config.kv_channels * config.num_attention_heads
-
- # Per attention head and per partition values.
- world_size = mpu.get_tensor_model_parallel_world_size()
- self.hidden_size_per_partition = core.utils.divide(projection_size,
- world_size)
- self.hidden_size_per_attention_head = core.utils.divide(
- projection_size, config.num_attention_heads)
- self.num_attention_heads_per_partition = core.utils.divide(
- config.num_attention_heads, world_size)
-
- coeff = None
- self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
- if self.apply_query_key_layer_scaling:
- coeff = self.layer_number
- self.norm_factor *= coeff
-
- self.scale_mask_softmax = FusedScaleMaskSoftmax(
- self.fp16, self.bf16,
- self.attn_mask_type,
- config.masked_softmax_fusion,
- attention_mask_func,
- self.attention_softmax_in_fp32,
- coeff)
-
- # Dropout. Note that for a single iteration, this layer will generate
- # different outputs on different numbers of parallel partitions but
- # on average it should not be partition dependent.
- self.attention_dropout = torch.nn.Dropout(config.attention_dropout)
-
- def forward(self, query_layer, key_layer,
- value_layer, attention_mask):
-
- # ===================================
- # Raw attention scores. [b, np, s, s]
- # ===================================
-
- # [b, np, sq, sk]
- output_size = (query_layer.size(1),
- query_layer.size(2),
- query_layer.size(0),
- key_layer.size(0))
-
- # [sq, b, np, hn] -> [sq, b * np, hn]
- query_layer = query_layer.reshape(output_size[2],
- output_size[0] * output_size[1], -1)
- # [sk, b, np, hn] -> [sk, b * np, hn]
- key_layer = key_layer.view(output_size[3],
- output_size[0] * output_size[1], -1)
-
- # preallocating input tensor: [b * np, sq, sk]
- matmul_input_buffer = mpu.get_global_memory_buffer().get_tensor(
- (output_size[0]*output_size[1], output_size[2], output_size[3]),
- query_layer.dtype, "mpu")
-
- # Raw attention scores. [b * np, sq, sk]
- matmul_result = torch.baddbmm(
- matmul_input_buffer,
- query_layer.transpose(0, 1), # [b * np, sq, hn]
- key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
- beta=0.0, alpha=(1.0/self.norm_factor))
-
- # change view to [b, np, sq, sk]
- attention_scores = matmul_result.view(*output_size)
-
- # ===========================
- # Attention probs and dropout
- # ===========================
-
- # attention scores and attention mask [b, np, sq, sk]
- attention_probs = self.scale_mask_softmax(attention_scores,
- attention_mask)
-
- # This is actually dropping out entire tokens to attend to, which might
- # seem a bit unusual, but is taken from the original Transformer paper.
- if not self.sequence_parallel:
- with tensor_parallel.get_cuda_rng_tracker().fork():
- attention_probs = self.attention_dropout(attention_probs)
- else:
- attention_probs = self.attention_dropout(attention_probs)
-
- # =========================
- # Context layer. [sq, b, hp]
- # =========================
-
- # value_layer -> context layer.
- # [sk, b, np, hn] --> [b, np, sq, hn]
-
- # context layer shape: [b, np, sq, hn]
- output_size = (value_layer.size(1),
- value_layer.size(2),
- query_layer.size(0),
- value_layer.size(3))
-
- # change view [sk, b * np, hn]
- value_layer = value_layer.view(value_layer.size(0),
- output_size[0] * output_size[1], -1)
-
- # change view [b * np, sq, sk]
- attention_probs = attention_probs.view(output_size[0] * output_size[1],
- output_size[2], -1)
-
- # matmul: [b * np, sq, hn]
- context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
-
- # change view [b, np, sq, hn]
- context_layer = context_layer.view(*output_size)
-
- # [b, np, sq, hn] --> [sq, b, np, hn]
- context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
-
- # [sq, b, np, hn] --> [sq, b, hp]
- new_context_layer_shape = context_layer.size()[:-2] + \
- (self.hidden_size_per_partition,)
- context_layer = context_layer.view(*new_context_layer_shape)
-
- return context_layer
-
-
-class FlashSelfAttention(torch.nn.Module):
- """Implement the scaled dot product attention with softmax.
- Arguments
- ---------
- softmax_scale: The temperature to use for the softmax attention.
- (default: 1/sqrt(d_keys) where d_keys is computed at
- runtime)
- attention_dropout: The dropout rate to apply to the attention
- (default: 0.0)
- """
- def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0,
- device=None, dtype=None):
- super().__init__()
- assert flash_attn_unpadded_func is not None, ('Please install FlashAttention first, '
- 'e.g., with pip install flash-attn')
- assert rearrange is not None, 'Please install einops first, e.g., with pip install einops'
- self.causal = causal
- self.softmax_scale = softmax_scale
- self.dropout_p = attention_dropout
-
- def forward(self, q, k, v):
- """Implements the multihead softmax attention.
- Arguments
- ---------
- q, k, v: The tensor containing the query, key, and value. (B, S, H, D)
- """
-
- assert all((i.dtype in [torch.float16, torch.bfloat16] for i in (q,k,v)))
- assert all((i.is_cuda for i in (q,k,v)))
-
- batch_size, seqlen_q = q.shape[0], q.shape[1]
- seqlen_k = k.shape[1]
-
- q, k, v = [rearrange(x, 'b s ... -> (b s) ...') for x in [q, k, v]]
- cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32,
- device=q.device)
-
- if self.training:
- # during training q,k,v always have same seqlen
- assert seqlen_k == seqlen_q
-
- is_causal = self.causal
- cu_seqlens_k = cu_seqlens_q
- dropout_p = self.dropout_p
- else:
- # turn off FA causal mask after first inference autoregressive iteration
- # only on first autoregressive step q,k,v have same seqlen
- is_causal = seqlen_q == seqlen_k
- cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32,
- device=q.device)
- dropout_p = 0
-
- output = flash_attn_unpadded_func(
- q, k, v, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k,
- dropout_p,
- softmax_scale=self.softmax_scale, causal=is_causal
- )
-
- output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
- return output
-
-
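The `cu_seqlens` bookkeeping above is the main subtlety of the unpadded flash-attention call: batch and sequence dimensions are flattened, and an int32 prefix sum of sequence lengths marks where each sequence starts. A minimal torch-only sketch of that bookkeeping (all sizes are illustrative; no flash-attn call is made):

```python
import torch

# Illustrative sizes only.
batch_size, seqlen, heads, head_dim = 2, 4, 3, 8
q = torch.randn(batch_size, seqlen, heads, head_dim)

# Flatten (b, s, ...) -> (b * s, ...), as the rearrange above does.
q_unpadded = q.reshape(batch_size * seqlen, heads, head_dim)

# Cumulative sequence lengths mark where each sequence starts in the
# flattened tensor: [0, s, 2s, ...], int32 as flash-attn expects.
cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen,
                          dtype=torch.int32)
print(q_unpadded.shape)  # torch.Size([8, 3, 8])
print(cu_seqlens)        # tensor([0, 4, 8], dtype=torch.int32)
```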
-class ParallelAttention(MegatronModule):
- """Parallel self-attention layer abstract class.
-
- Self-attention layer takes input with size [s, b, h]
- and returns output of the same size.
- """
-
- def __init__(self, config, layer_number,
- attention_type=AttnType.self_attn,
- attn_mask_type=AttnMaskType.padding):
- super(ParallelAttention, self).__init__()
- args = get_args()
- self.layer_number = max(1, layer_number)
- self.attention_type = attention_type
- self.attn_mask_type = attn_mask_type
- self.params_dtype = config.params_dtype
- self.sequence_parallel = config.sequence_parallel
-
- self.group_query_attention = args.group_query_attention
- self.num_query_groups = args.num_query_groups
-
- query_projection_size = config.kv_channels * config.num_attention_heads
- if self.group_query_attention:
- kv_projection_size = args.kv_channels * args.num_query_groups
- else:
- kv_projection_size = args.kv_channels * args.num_attention_heads
-
- self.use_flash_attn = args.use_flash_attn \
- and attention_type == AttnType.self_attn \
- and self.attn_mask_type == AttnMaskType.causal
- if self.use_flash_attn:
- if flash_attn_unpadded_func is None:
- raise ImportError('FlashAttention is not installed, please install with '
- 'pip install flash-attn')
- assert attention_type == AttnType.self_attn, ('FlashAttention code path only supports '
- 'self-attention for now')
- assert self.attn_mask_type == AttnMaskType.causal, ('FlashAttention code path only '
- 'supports causal mask for now')
- if rearrange is None:
- raise ImportError('einops is not installed, please install with pip install einops')
-
- # Per attention head and per partition values.
- world_size = mpu.get_tensor_model_parallel_world_size()
- self.hidden_size_per_attention_head = core.utils.divide(
- query_projection_size, config.num_attention_heads)
- self.num_attention_heads_per_partition = core.utils.divide(
- config.num_attention_heads, world_size)
-
- if self.group_query_attention:
- if args.num_query_groups % world_size != 0:
- raise NotImplementedError('Currently the num_query_groups should be '
- 'a multiple of the tensor parallel size')
- self.num_query_groups_per_partition = core.utils.divide(
- args.num_query_groups, world_size)
- else:
- self.num_query_groups_per_partition = self.num_attention_heads_per_partition
-
- # Strided linear layer.
- if attention_type == AttnType.self_attn:
- self.query_key_value = tensor_parallel.ColumnParallelLinear(
- config.hidden_size,
- query_projection_size + 2 * kv_projection_size,
- config=config,
- init_method=config.init_method,
- bias=args.add_bias_linear,
- gather_output=False)
- else:
- assert attention_type == AttnType.cross_attn
-
- if self.group_query_attention:
- raise NotImplementedError("Grouped query attention not implemented for cross-attention.")
- assert query_projection_size == kv_projection_size
-
- self.query = tensor_parallel.ColumnParallelLinear(
- config.hidden_size,
- query_projection_size,
- config=config,
- init_method=config.init_method,
- bias=config.add_bias_linear,
- gather_output=False)
-
- self.key_value = tensor_parallel.ColumnParallelLinear(
- config.hidden_size,
- 2 * kv_projection_size,
- config=config,
- init_method=config.init_method,
- bias=config.add_bias_linear,
- gather_output=False)
-
- self.core_attention = CoreAttention(self.layer_number, config,
- self.attn_mask_type)
- self.checkpoint_core_attention = config.recompute_granularity == 'selective'
-
- if self.use_flash_attn:
- self.core_attention_flash = FlashSelfAttention(
- causal=True, attention_dropout=config.attention_dropout
- )
-
- # Output.
- self.dense = tensor_parallel.RowParallelLinear(
- query_projection_size,
- config.hidden_size,
- config=config,
- init_method=config.output_layer_init_method,
- bias=args.add_bias_linear,
- input_is_parallel=True,
- skip_bias_add=True)
-
- def _checkpointed_attention_forward(self, query_layer, key_layer,
- value_layer, attention_mask,
- rotary_pos_emb=None):
- """Forward method with activation checkpointing."""
- def custom_forward(*inputs):
- query_layer = inputs[0]
- key_layer = inputs[1]
- value_layer = inputs[2]
- attention_mask = inputs[3]
- output_ = self.core_attention(query_layer, key_layer,
- value_layer, attention_mask)
- return output_
-
- q_pos_emb, k_pos_emb = (None, None) if rotary_pos_emb is None \
- else rotary_pos_emb
-
- hidden_states = tensor_parallel.checkpoint(
- custom_forward,
- False, query_layer, key_layer, value_layer, attention_mask,
- q_pos_emb, k_pos_emb)
-
- return hidden_states
-
- def _allocate_memory(self, inference_max_sequence_len, batch_size, num_attention_heads):
- return torch.empty(
- inference_max_sequence_len,
- batch_size,
- num_attention_heads,
- self.hidden_size_per_attention_head,
- dtype=self.params_dtype,
- device=torch.cuda.current_device())
-
- def forward(self, hidden_states, attention_mask,
- encoder_output=None, inference_params=None,
- rotary_pos_emb=None):
- # hidden_states: [sq, b, h]
-
- # =================================================
- # Pre-allocate memory for key-values for inference.
- # =================================================
- is_first_step = False
- if inference_params:
- if self.layer_number not in inference_params.key_value_memory_dict:
- inf_max_seq_len = inference_params.max_sequence_length
- inf_max_batch_size = inference_params.max_batch_size
- inference_key_memory = self._allocate_memory(
- inf_max_seq_len, inf_max_batch_size,
- self.num_query_groups_per_partition)
- inference_value_memory = self._allocate_memory(
- inf_max_seq_len, inf_max_batch_size,
- self.num_query_groups_per_partition)
-
- inference_params.key_value_memory_dict[self.layer_number] = (
- inference_key_memory, inference_value_memory)
- is_first_step = True
- else:
- inference_key_memory, inference_value_memory = \
- inference_params.key_value_memory_dict[self.layer_number]
-
- # =====================
- # Query, Key, and Value
- # =====================
- if self.attention_type == AttnType.self_attn:
-
-            # Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn]
- mixed_x_layer, _ = self.query_key_value(hidden_states)
-
- # [sq, b, hp] --> [sq, b, ng, (np/ng + 2) * hn]
- new_tensor_shape = mixed_x_layer.size()[:-1] + (
- self.num_query_groups_per_partition,
- (
- (self.num_attention_heads_per_partition // self.num_query_groups_per_partition + 2)
- * self.hidden_size_per_attention_head
- ),
- )
- mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
-
- # [sq, b, ng, (np/ng + 2) * hn] --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn]
- (query_layer,
- key_layer,
- value_layer) = torch.split(
- mixed_x_layer,
- [
- (
- self.num_attention_heads_per_partition // self.num_query_groups_per_partition
- * self.hidden_size_per_attention_head
- ),
- self.hidden_size_per_attention_head,
- self.hidden_size_per_attention_head
- ],
- dim=3)
-
-            # [sq, b, ng, np/ng * hn] -> [sq, b, np, hn]
- query_layer = query_layer.view(query_layer.size(0), query_layer.size(1), -1, self.hidden_size_per_attention_head)
- else:
- # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
- mixed_kv_layer, _ = self.key_value(encoder_output)
-
- # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
- new_tensor_shape = mixed_kv_layer.size()[:-1] + \
- (self.num_attention_heads_per_partition,
- 2 * self.hidden_size_per_attention_head)
- mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)
-
- # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
- (key_layer,
- value_layer) = tensor_parallel.split_tensor_along_last_dim(mixed_kv_layer, 2)
-
- # Attention head [sq, b, h] --> [sq, b, hp]
- query_layer, _ = self.query(hidden_states)
- # [sq, b, hp] --> [sq, b, np, hn]
- new_tensor_shape = query_layer.size()[:-1] + \
- (self.num_attention_heads_per_partition,
- self.hidden_size_per_attention_head)
- query_layer = query_layer.view(*new_tensor_shape)
-
- # ==================================
- # Adjust key and value for inference
- # ==================================
-
- # duplicate the pos_emb for self attention
- if rotary_pos_emb is not None:
- if isinstance(rotary_pos_emb, tuple):
- rotary_pos_emb = rotary_pos_emb
- else:
- rotary_pos_emb = ((rotary_pos_emb,) * 2)
-
- if inference_params:
- batch_start = inference_params.batch_size_offset
- batch_end = batch_start + key_layer.size(1)
- assert batch_end <= inference_key_memory.size(1)
- sequence_start = inference_params.sequence_len_offset
- sequence_end = sequence_start + key_layer.size(0)
- assert sequence_end <= inference_key_memory.size(0)
- # Copy key and values.
- inference_key_memory[sequence_start:sequence_end,
- batch_start:batch_end, ...] = key_layer
- inference_value_memory[sequence_start:sequence_end,
- batch_start:batch_end, ...] = value_layer
- key_layer = inference_key_memory[
- :sequence_end, batch_start:batch_end, ...]
- value_layer = inference_value_memory[
- :sequence_end, batch_start:batch_end, ...]
-
-
- # adjust the key rotary positional embedding
- if rotary_pos_emb is not None:
- q_pos_emb, k_pos_emb = rotary_pos_emb
- # need to cross check this condition during inference
- # if not set_inference_key_value_memory:
- if not is_first_step:
- # In inference, we compute one token at a time.
- # Select the correct positional embedding
- # (only the last token in the sequence)
- q_pos_emb = q_pos_emb[sequence_end - 1 : sequence_end]
- else:
- # In the first forward pass of inference,
- # we use the entire provided prefix.
- # q_pos_emb here has the rope embeddings of the entire
- # prefix + to-be-generated output so
- # we slice to just the prefix.
- q_pos_emb = q_pos_emb[:sequence_end, :, :, :]
- k_pos_emb = k_pos_emb[:sequence_end, :, :, :]
- rotary_pos_emb = (q_pos_emb, k_pos_emb)
-
- # ==================================
- # core attention computation
- # ==================================
-
- # expand the key_layer and value_layer [sk, b, ng, hn] -> [sk, b, np, hn]
- if self.num_attention_heads_per_partition // self.num_query_groups_per_partition > 1:
- key_layer = key_layer.repeat_interleave(
- self.num_attention_heads_per_partition // self.num_query_groups_per_partition,
- dim = 2
- )
- value_layer = value_layer.repeat_interleave(
- self.num_attention_heads_per_partition // self.num_query_groups_per_partition,
- dim = 2
- )
-
- # apply relative positional encoding (rotary embedding)
- if rotary_pos_emb is not None:
- q_pos_emb, k_pos_emb = rotary_pos_emb
- query_layer = apply_rotary_pos_emb(query_layer, q_pos_emb)
- key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb)
- # TODO, can apply positional embedding to value_layer so it has
- # absolute positional embedding.
- # otherwise, only relative positional embedding takes effect
- # value_layer = apply_rotary_pos_emb(value_layer, k_pos_emb)
-
- if not self.use_flash_attn:
- if self.checkpoint_core_attention:
- context_layer = self._checkpointed_attention_forward(
- query_layer, key_layer, value_layer, attention_mask)
- else:
- context_layer = self.core_attention(
- query_layer, key_layer, value_layer, attention_mask)
- else:
- q, k, v = [rearrange(x, 's b ... -> b s ...').contiguous()
- for x in (query_layer, key_layer, value_layer)]
- if not self.sequence_parallel:
- with tensor_parallel.get_cuda_rng_tracker().fork():
- context_layer = self.core_attention_flash(q, k, v)
- else:
- context_layer = self.core_attention_flash(q, k, v)
- context_layer = rearrange(context_layer, 'b s h d -> s b (h d)').contiguous()
-
- # =================
- # Output. [sq, b, h]
- # =================
-
- output, bias = self.dense(context_layer)
-
- return output, bias
-
-
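The grouped-query split above packs, per query group, `np/ng` query heads plus one key head and one value head into the fused QKV projection. A shape-only sketch of that arithmetic, using illustrative per-partition sizes and `reshape` in place of `view`:

```python
import torch

# Illustrative per-partition sizes: np=8 heads, ng=2 query groups, hn=16
# channels per head, so the fused projection emits
# ng * (np/ng + 2) * hn = 2 * (4 + 2) * 16 = 192 channels per token.
sq, b, np_, ng, hn = 5, 3, 8, 2, 16
mixed = torch.randn(sq, b, ng * (np_ // ng + 2) * hn)

mixed = mixed.view(sq, b, ng, (np_ // ng + 2) * hn)
q, k, v = torch.split(mixed, [np_ // ng * hn, hn, hn], dim=3)
q = q.reshape(sq, b, np_, hn)   # regroup the query heads: [sq, b, np, hn]

print(q.shape, k.shape, v.shape)
# torch.Size([5, 3, 8, 16]) torch.Size([5, 3, 2, 16]) torch.Size([5, 3, 2, 16])
```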
-def bias_dropout_add(x, bias, residual, prob, training):
- # type: (Tensor, Optional[Tensor], Tensor, float, bool) -> Tensor
- if bias is not None:
- x = x + bias
- out = torch.nn.functional.dropout(x, p=prob, training=training)
- out = residual + out
- return out
-
-
-def get_bias_dropout_add(training):
- def _bias_dropout_add(x, bias, residual, prob):
- return bias_dropout_add(x, bias, residual, prob, training)
- return _bias_dropout_add
-
-
-@torch.jit.script
-def bias_dropout_add_fused_train(x: torch.Tensor,
- bias: Optional[torch.Tensor],
- residual: torch.Tensor,
- prob: float) -> torch.Tensor:
- return bias_dropout_add(x, bias, residual, prob, True)
-
-
-@torch.jit.script
-def bias_dropout_add_fused_inference(x: torch.Tensor,
- bias: Optional[torch.Tensor],
- residual: torch.Tensor,
- prob: float) -> torch.Tensor:
- return bias_dropout_add(x, bias, residual, prob, False)
-
-
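The fused variants above simply bind `training` to a constant so TorchScript can specialize the dropout path. A quick usage sketch, assuming the two fused functions defined above are in scope:

```python
import torch

x = torch.randn(4, 2, 8)            # e.g. attention output [s, b, h]
bias = torch.zeros(8)               # broadcastable bias (may also be None)
residual = torch.randn(4, 2, 8)     # the residual stream

out_train = bias_dropout_add_fused_train(x, bias, residual, 0.1)
out_eval = bias_dropout_add_fused_inference(x, bias, residual, 0.1)

# With training=False the dropout is the identity, so the result is exactly
# residual + (x + bias).
assert torch.allclose(out_eval, residual + x + bias)
```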
-class ParallelTransformerLayer(MegatronModule):
- """A single transformer layer.
-
- Transformer layer takes input with size [s, b, h] and returns an
- output of the same size.
- """
-
- def __init__(self, config,
- layer_number, layer_type=LayerType.encoder,
- self_attn_mask_type=AttnMaskType.padding,
- drop_path_rate=0.):
- args = get_args()
-
- super(ParallelTransformerLayer, self).__init__()
- self.layer_number = layer_number
- self.layer_type = layer_type
-
- self.apply_residual_connection_post_norm \
- = config.apply_residual_connection_post_layernorm
-
- self.bf16 = config.bf16
- self.fp32_residual_connection = config.fp32_residual_connection
-
- # Normalize the input data.
- self.input_norm = get_norm(config)
-
- # Self attention.
- self.self_attention = ParallelAttention(
- config,
- layer_number,
- attention_type=AttnType.self_attn,
- attn_mask_type=self_attn_mask_type)
- self.hidden_dropout = config.hidden_dropout
- self.bias_dropout_fusion = config.bias_dropout_fusion
- self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else None
-
- # Normalize the attention output
- self.post_attention_norm = get_norm(config)
-
- # Cross attention.
- if self.layer_type in (LayerType.decoder,
- LayerType.retro_decoder,
- LayerType.retro_decoder_with_retriever,
- LayerType.retro_encoder):
- self.inter_attention = ParallelAttention(
- config,
- layer_number,
- attention_type=AttnType.cross_attn)
- # Normalize the attention output.
- self.post_inter_attention_norm = get_norm(config)
-
- # MLP
- if args.num_experts is not None:
- self.mlp = SwitchMLP(config)
- else:
- self.mlp = ParallelMLP(config)
-
- # Set bias+dropout+add fusion grad_enable execution handler.
- TORCH_MAJOR = int(torch.__version__.split('.')[0])
- TORCH_MINOR = int(torch.__version__.split('.')[1])
- use_nvfuser = TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10)
- self.bias_dropout_add_exec_handler = \
- nullcontext if use_nvfuser else torch.enable_grad
-
- if args.retro_add_retriever:
- retro_args = get_retro_args()
- self.retro_num_neighbors = args.retro_num_neighbors
- self.retro_chunk_length = retro_args.retro_gpt_chunk_length
- self.retro_retrieved_length = retro_args.retro_gpt_retrieved_length
-
- # Retriever (bi-directional transformer with cross attention)
- if layer_type == LayerType.retro_decoder_with_retriever:
- self.retriever = ParallelTransformer(
- config=config,
- model_type=ModelType.retro_encoder,
- self_attn_mask_type=AttnMaskType.padding,
- pre_process=True,
- post_process=False,
- )
- self._retriever_key = 'retriever'
- else:
- self.retriever = None
-
- def default_decoder_cross_attention(self,
- encoder_output,
- enc_dec_attn_mask,
- norm_input,
- norm_output,
- bias_dropout_add_func):
- '''Cross attention for a standard encoder-decoder model.'''
-
- # Attention.
- attention_output, attention_bias = \
- self.inter_attention(norm_output,
- enc_dec_attn_mask,
- encoder_output=encoder_output)
-
- # Residual connection.
- if self.apply_residual_connection_post_norm:
- residual = norm_output
- else:
- residual = norm_input
-
- if attention_bias is not None:
- attention_bias = attention_bias.expand_as(residual)
-
- # Bias-dropout-add.
- with self.bias_dropout_add_exec_handler():
- norm_input = bias_dropout_add_func(
- attention_output,
- attention_bias,
- residual,
- self.hidden_dropout)
-
- # Normalize.
- norm_output = self.post_inter_attention_norm(norm_input)
-
- return norm_input, norm_output
-
- def retro_encoder_cross_attention(self,
- retriever_output,
- norm_input,
- norm_output,
- bias_dropout_add_func):
- """Cross attention for Retro encoder.
-
- Notation:
- ns : Sequence length.
- bs : Batch size.
- d : Hidden size.
- l : Number of chunks per sample (i.e., seq_length/chunk_length).
- k : Number of neighbors.
- r : Number of retrieved tokens (neighbors + continuation).
- """
-
- ns, bs, d = norm_output.shape # [r, bs * l * k, d]
-
- # Divide sequence dimension into chunks.
- chunked_outputs = norm_output.reshape(self.retro_retrieved_length,
- -1,
- self.retro_num_neighbors,
- d)
- chunked_outputs_before_norm = \
- norm_input.reshape(self.retro_retrieved_length, -1,
- self.retro_num_neighbors, d) # [r, bs*l, k, d]
-
- # Per-chunk attention.
- norm_inputs = []
- norm_outputs = []
- for k in range(self.retro_num_neighbors):
-
- # Attention.
- chunked_output = chunked_outputs[:,:,k].contiguous()
- attention_output, attention_bias = \
- self.inter_attention(
- chunked_output, # Q (neighbor embedding)
- None,
- encoder_output=retriever_output) # K, V (hidden act)
-
- # Residual connection.
- if self.apply_residual_connection_post_norm:
- residual = chunked_output
- else:
- residual = chunked_outputs_before_norm[:,:,k]
-
- # Re-enable torch grad to enable fused optimization.
- with torch.enable_grad():
- norm_input = bias_dropout_add_func(
- attention_output,
- None if attention_bias is None else attention_bias.expand_as(residual),
- residual,
- self.hidden_dropout)
- norm_inputs.append(norm_input)
-
- # Layer norm.
- norm_output = self.post_inter_attention_norm(norm_input)
- norm_outputs.append(norm_output)
-
- # Concatenate layer norms.
- # norm_input : [r, k * bs * l, d]
- # norm_output : [r, k * bs * l, d]
- norm_input = torch.stack(norm_inputs, dim=1).reshape(ns, bs, d)
- norm_output = torch.stack(norm_outputs, dim=1).reshape(ns, bs, d)
-
- return norm_input, norm_output
-
- def retro_decoder_cross_attention(self,
- retriever_input,
- retriever_output,
- retriever_attn_mask,
- norm_input,
- norm_output,
- inference_params,
- bias_dropout_add_func):
- """Cross attention for Retro decoder.
-
- Notation:
- ns : Sequence length.
- bs : Batch size.
- d : Hidden size.
- l : Number of chunks per sample (i.e., seq_length/chunk_length).
- m : Number of tokens per chunk.
- k : Number of neighbors.
- r : Number of retrieved tokens (neighbors + continuation).
- """
-
- ns, bs, d = norm_output.shape
- l = int(np.ceil(ns / self.retro_chunk_length))
-
- # Retrieve neighbors.
- if self.layer_type == LayerType.retro_decoder_with_retriever:
- first_ns = ns % self.retro_chunk_length
- if first_ns > 0:
- raise Exception("test this case.")
- first_chunk, rest_chunk = \
- norm_output[:first_ns], norm_output[first_ns:]
- first_chunk = torch.nn.functional.pad(
- first_chunk,
- (0, 0, 0, 0, 0, self.retro_chunk_length - first_ns),
- 'constant',
- 0)
- chunked_output = \
- torch.cat((first_chunk, rest_chunk), dim=0) # [l * m, bs, d]
- else:
- chunked_output = norm_output # [l * m, bs, d]
- chunked_output = chunked_output \
- .reshape(l, self.retro_chunk_length, bs, d) \
- .permute(1, 2, 0, 3) \
- .reshape(self.retro_chunk_length, bs * l, d) \
- .contiguous()
-
- # Get Encoder Output
- retriever_output = self.retriever(
- hidden_states=retriever_input,
- attention_mask=retriever_attn_mask,
- retriever_output=chunked_output,
- retriever_attn_mask=retriever_attn_mask,
- inference_params=inference_params) # [r, k * bs * l , d]
- retriever_output = retriever_output.reshape(
- self.retro_retrieved_length * self.retro_num_neighbors, bs * l, d) # [r * k, bs * l, d]
-
- # Chunks.
- pad = (ns - 1) % self.retro_chunk_length
- attending_chunks = norm_output[pad:]
- padded_chunks = torch.nn.functional.pad(
- attending_chunks,
- (0, 0, 0, 0, 0, self.retro_chunk_length - 1),
- 'constant', 0)
- padded_chunked_output = padded_chunks \
- .reshape(l, self.retro_chunk_length, bs, d) \
- .permute(1, 2, 0, 3)
- padded_chunked_output = padded_chunked_output.reshape(
- self.retro_chunk_length, bs * l, d).contiguous()
-
- # Encoder output.
- attention_output, attention_bias = \
- self.inter_attention(padded_chunked_output,
- None,
- encoder_output=retriever_output)
-
- # Residual connection.
- if self.apply_residual_connection_post_norm:
- residual = norm_output
- else:
- residual = norm_input
-
- # Re-enable torch grad to enable fused optimization.
- with torch.enable_grad():
- norm_input = bias_dropout_add_func(
- attention_output,
- None if attention_bias is None else attention_bias.expand_as(attention_output),
- torch.zeros_like(attention_output),
- self.hidden_dropout)
- norm_input = norm_input \
- .reshape(self.retro_chunk_length, bs, l, d) \
- .permute(2, 0, 1, 3) # [l, m, bs, d]
- norm_input = norm_input.reshape(self.retro_chunk_length * l, bs, d)
- norm_input = torch.nn.functional.pad(
- norm_input,
- (0, 0, 0, 0, pad, 0),
- 'constant', 0)[:ns] # [ns, b, d]
- norm_input = norm_input + residual
-
- # Layer norm post the decoder attention
- norm_output = self.post_inter_attention_norm(norm_input)
-
- return retriever_output, norm_input, norm_output
-
- def forward(self, hidden_states, attention_mask,
- encoder_output=None, enc_dec_attn_mask=None,
- retriever_input=None,
- retriever_output=None,
- retriever_attn_mask=None,
- inference_params=None,
- rotary_pos_emb=None):
- # hidden_states: [s, b, h]
-
- # Layer norm at the beginning of the transformer layer.
- norm_output = self.input_norm(hidden_states)
-
- # Self attention.
- attention_output, attention_bias = \
- self.self_attention(
- norm_output,
- attention_mask,
- inference_params=inference_params,
- rotary_pos_emb=rotary_pos_emb)
-
- # Residual connection.
- if self.apply_residual_connection_post_norm:
- residual = norm_output
- else:
- residual = hidden_states
-
- if self.drop_path is None:
-            # jit scripting for an nn.Module (with dropout) does not
-            # trigger the fusion kernel. For now, we use two
-            # different nn.functional routines to account for varying
-            # dropout semantics during training and inference phases.
- if self.bias_dropout_fusion:
- if self.training:
- bias_dropout_add_func = bias_dropout_add_fused_train
- else:
- bias_dropout_add_func = bias_dropout_add_fused_inference
- else:
- bias_dropout_add_func = get_bias_dropout_add(self.training)
-
- if attention_bias is not None:
- attention_bias = attention_bias.expand_as(residual)
- with self.bias_dropout_add_exec_handler():
- norm_input = bias_dropout_add_func(
- attention_output,
- attention_bias,
- residual,
- self.hidden_dropout)
- else:
- out = torch.nn.functional.dropout(attention_output + attention_bias,
- p=self.hidden_dropout,
- training=self.training)
- norm_input = residual + self.drop_path(out)
-
- # Layer norm post the self attention.
- norm_output = self.post_attention_norm(norm_input)
-
- # Cross attention.
- if self.layer_type == LayerType.encoder:
- pass
- elif self.layer_type == LayerType.decoder:
- norm_input, norm_output = \
- self.default_decoder_cross_attention(
- encoder_output,
- enc_dec_attn_mask,
- norm_input,
- norm_output,
- bias_dropout_add_func)
- elif self.layer_type == LayerType.retro_encoder:
- norm_input, norm_output = \
- self.retro_encoder_cross_attention(
- retriever_output,
- norm_input,
- norm_output,
- bias_dropout_add_func)
- elif self.layer_type in (LayerType.retro_decoder,
- LayerType.retro_decoder_with_retriever):
- retriever_output, norm_input, norm_output = \
- self.retro_decoder_cross_attention(
- retriever_input,
- retriever_output,
- retriever_attn_mask,
- norm_input,
- norm_output,
- inference_params,
- bias_dropout_add_func)
- else:
- raise Exception("Unsupported layer type, '%s'." %
- self.layer_type.name)
-
- # MLP.
- mlp_output, mlp_bias = self.mlp(norm_output)
-
- # Second residual connection.
- if self.apply_residual_connection_post_norm:
- residual = norm_output
- else:
- residual = norm_input
-
- if self.drop_path is None:
- if mlp_bias is not None:
- mlp_bias = mlp_bias.expand_as(residual)
- with self.bias_dropout_add_exec_handler():
- output = bias_dropout_add_func(
- mlp_output,
- mlp_bias,
- residual,
- self.hidden_dropout)
-
- # Jit compiled function creates 'view' tensor. This tensor
- # potentially gets saved in the MPU checkpoint function context,
- # which rejects view tensors. While making a viewless tensor here
- # won't result in memory savings (like the data loader, or
- # p2p_communication), it serves to document the origin of this
- # 'view' tensor.
- output = core.utils.make_viewless_tensor(inp = output,
- requires_grad = output.requires_grad,
- keep_graph = True)
-
- else:
- if mlp_bias is not None:
- mlp_output = mlp_output + mlp_bias
- out = torch.nn.functional.dropout(mlp_output,
- p=self.hidden_dropout,
- training=self.training)
- output = residual + self.drop_path(out)
-
- if self.layer_type == LayerType.retro_decoder_with_retriever:
- return output, retriever_output
- else:
- return output
-
-
-class NoopTransformerLayer(MegatronModule):
- """A single 'no-op' transformer layer.
-
- The sole purpose of this layer is for when a standalone embedding layer
- is used (i.e., args.standalone_embedding_stage == True). In this case,
- zero transformer layers are assigned when pipeline rank == 0. Additionally,
- when virtual pipeline rank >= 1, zero total model parameters are created
- (virtual rank 0 contains the input embedding). This results in the model's
- input and output tensors being the same, which causes an error when
-    performing certain memory optimizations on the output tensor (e.g.,
-    deallocating it). Thus, this layer disconnects the input from the output
-    via a clone. Since ranks containing a no-op layer are generally under-
-    utilized (both compute and memory), there's no worry of any performance
-    degradation.
- """
-
- def __init__(self, layer_number):
- super().__init__()
- self.layer_number = layer_number
-
- def forward(self, hidden_states, attention_mask,
- encoder_output=None, enc_dec_attn_mask=None,
- inference_params=None):
- return hidden_states.clone()
-
-
-def _get_num_layers(args, model_type, is_decoder=False):
- """Compute the number of transformer layers resident on the current rank."""
- is_encoder_and_decoder_model = (model_type == ModelType.encoder_and_decoder)
- if model_type == ModelType.retro_encoder:
- num_layers = args.retro_encoder_layers
- elif mpu.get_pipeline_model_parallel_world_size() > 1:
- if is_encoder_and_decoder_model:
- assert args.pipeline_model_parallel_split_rank is not None
-
- # When a standalone embedding stage is used, a rank is taken from
- # the encoder's ranks, to be used for the encoder's embedding
- # layer. This way, the rank referenced by the 'split rank' remains
- # the same whether or not a standalone embedding stage is used.
- num_ranks_in_encoder = (
- args.pipeline_model_parallel_split_rank - 1
- if args.standalone_embedding_stage else
- args.pipeline_model_parallel_split_rank
- )
- num_ranks_in_decoder = args.transformer_pipeline_model_parallel_size - num_ranks_in_encoder
- assert args.encoder_num_layers % num_ranks_in_encoder == 0, \
- 'encoder_num_layers (%d) must be divisible by number of ranks given to encoder (%d)' % (args.encoder_num_layers, num_ranks_in_encoder)
- assert args.decoder_num_layers % num_ranks_in_decoder == 0, \
- 'decoder_num_layers (%d) must be divisible by number of ranks given to decoder (%d)' % (args.decoder_num_layers, num_ranks_in_decoder)
- if mpu.is_pipeline_stage_before_split():
- num_layers = (
- 0
- if args.standalone_embedding_stage
- and mpu.get_pipeline_model_parallel_rank() == 0 else
- args.encoder_num_layers // num_ranks_in_encoder
- )
- else:
- num_layers = args.decoder_num_layers // num_ranks_in_decoder
- else:
- assert args.num_layers == args.encoder_num_layers
- assert args.num_layers % args.transformer_pipeline_model_parallel_size == 0, \
- 'num_layers must be divisible by transformer_pipeline_model_parallel_size'
-
- # When a standalone embedding stage is used, all transformer layers
- # are divided among pipeline rank >= 1, while on pipeline rank 0,
- # ranks either contain the input embedding layer (virtual pp rank 0),
- # or no layers at all (virtual pp rank >= 1).
- num_layers = (
- 0
- if args.standalone_embedding_stage
- and mpu.get_pipeline_model_parallel_rank() == 0 else
- args.num_layers // args.transformer_pipeline_model_parallel_size
- )
- else:
- if not is_decoder:
- num_layers = args.encoder_num_layers
- else:
- num_layers = args.decoder_num_layers
- return num_layers
-
-
-def _get_layer_type(model_type, default_layer_type, retro_layer_numbers,
- layer_number):
- args = get_args()
- if args.retro_add_retriever and layer_number in retro_layer_numbers:
- if model_type == ModelType.retro_decoder:
- return LayerType.retro_decoder_with_retriever \
- if layer_number == retro_layer_numbers[0] \
- else LayerType.retro_decoder
- elif model_type == ModelType.retro_encoder:
- return LayerType.retro_encoder
- else:
- raise Exception("Unsupported model type, '%s'." % model_type)
- else:
- return default_layer_type
-
-
-class ParallelTransformer(MegatronModule):
- """Transformer class."""
-
- def __init__(self, config,
- model_type, layer_type=LayerType.encoder,
- self_attn_mask_type=AttnMaskType.padding,
- post_norm=True,
- pre_process=True,
- post_process=True,
- drop_path_rate=0.0):
- super(ParallelTransformer, self).__init__()
- args = get_args()
-
- self.layer_type = layer_type
- self.model_type = model_type
- self.bf16 = config.bf16
- self.fp32_residual_connection = config.fp32_residual_connection
- self.post_norm = post_norm
- self.pre_process = pre_process
- self.post_process = post_process
- self.input_tensor = None
- self.drop_path_rate = drop_path_rate
- self.transformer_impl = args.transformer_impl
- self.retro_add_retriever = args.retro_add_retriever
-
-        # Store activation checkpointing flag.
- self.recompute_granularity = config.recompute_granularity
- self.recompute_method = config.recompute_method
- self.recompute_num_layers = config.recompute_num_layers
- self.distribute_saved_activations = \
- config.distribute_saved_activations and not config.sequence_parallel
-
- self.sequence_parallel = config.sequence_parallel
-
- # Transformer Engine Init.
- self.transformer_engine_v_0_10 = False
- self.transformer_engine_v_0_11 = False
- self.transformer_engine_v_0_8 = False
- if self.transformer_impl == 'transformer_engine':
- global transformer_engine
- import transformer_engine
- from importlib.metadata import version
- from pkg_resources import packaging
-
- te_version = packaging.version.Version(version("transformer-engine"))
- if te_version >= packaging.version.Version("0.8.0"):
- self.transformer_engine_v_0_8 = True
- if te_version >= packaging.version.Version("0.10.0"):
- self.transformer_engine_v_0_10 = True
- if te_version >= packaging.version.Version("0.11.0"):
- self.transformer_engine_v_0_11 = True
-
- del version, packaging
-
- assert not args.squared_relu, "TransformerEngine does not support squared relu activation."
-
- self.use_fp8 = args.fp8 is not None
- self.fp8_recipe = None
- self.fp8_group = None
- if self.use_fp8:
- assert args.transformer_impl == 'transformer_engine', \
- 'transformer-engine required for fp8 training and inference'
- self.fp8_group = mpu.get_amax_reduction_group()
- if args.fp8 == "e4m3":
- fp8_format = transformer_engine.common.recipe.Format.E4M3
- elif args.fp8 == "hybrid":
- fp8_format = transformer_engine.common.recipe.Format.HYBRID
- else:
- raise ValueError("The DelayedScaling recipe only supports E4M3 and HYBRID formats.")
- self.fp8_recipe = transformer_engine.common.recipe.DelayedScaling(
- margin=args.fp8_margin,
- interval=args.fp8_interval,
- fp8_format=fp8_format,
- amax_history_len=args.fp8_amax_history_len,
- amax_compute_algo=args.fp8_amax_compute_algo,
- override_linear_precision=(False, False, not args.fp8_wgrad),
- )
-
- self.num_microbatches_in_previous_step = -1
- self.microbatch_count = 0
- self.checkpoint_core_attention = config.recompute_granularity == 'selective'
-
- # Number of layers.
- self.num_layers = _get_num_layers(args, model_type,
- layer_type==LayerType.decoder)
-
- self.drop_path_rates = [
- rate.item() for rate in
- torch.linspace(0, self.drop_path_rate, config.num_layers)]
-
- self.retro_layer_numbers = None
- if model_type == ModelType.retro_decoder:
- retro_layer_start = 6 if config.num_layers <= 15 else 9
- self.retro_layer_numbers = \
- np.arange(retro_layer_start, args.num_layers + 1, 3).tolist()
- if model_type == ModelType.retro_encoder:
- self.retro_layer_numbers = [1]
-
- # Transformer layers.
- if args.retro_add_retriever:
- assert self.recompute_granularity != 'full', \
- "Full recompute not supported for Retro."
- assert args.transformer_impl == 'local', \
- "Transformer engine does not support Retro layers."
- def build_layer(layer_number):
- if args.transformer_impl == 'local':
- current_layer_type = _get_layer_type(
- model_type, layer_type, self.retro_layer_numbers,
- layer_number)
- return ParallelTransformerLayer(
- config,
- layer_number,
- layer_type=current_layer_type,
- self_attn_mask_type=self_attn_mask_type,
- drop_path_rate=self.drop_path_rates[layer_number - 1])
- else:
- # This argument is only available from TE v0.10 onwards.
- extra_transformer_engine_kwargs = {}
- if self.transformer_engine_v_0_8:
- extra_transformer_engine_kwargs["bias"] = args.add_bias_linear
- if self.transformer_engine_v_0_10:
- extra_transformer_engine_kwargs["activation"] = "swiglu" if args.swiglu else "gelu"
- if self.transformer_engine_v_0_11:
- extra_transformer_engine_kwargs["normalization"] = args.normalization
- return transformer_engine.pytorch.TransformerLayer(
- config.hidden_size,
- config.ffn_hidden_size,
- config.num_attention_heads,
- layernorm_epsilon=config.layernorm_epsilon,
- hidden_dropout=config.hidden_dropout,
- attention_dropout=config.attention_dropout,
- init_method=config.init_method,
- output_layer_init_method=config.output_layer_init_method,
- layer_number=layer_number,
- kv_channels=config.kv_channels,
- self_attn_mask_type=self_attn_mask_type.name,
- tp_group=mpu.get_tensor_model_parallel_group(),
- get_rng_state_tracker=tensor_parallel.get_cuda_rng_tracker,
- fuse_wgrad_accumulation=config.gradient_accumulation_fusion,
- apply_query_key_layer_scaling=config.apply_query_key_layer_scaling,
- attention_softmax_in_fp32=config.attention_softmax_in_fp32,
- seq_length=args.seq_length,
- micro_batch_size=args.micro_batch_size,
- sequence_parallel=config.sequence_parallel,
- params_dtype=config.params_dtype,
- apply_residual_connection_post_layernorm=config.apply_residual_connection_post_layernorm,
- output_layernorm=False,
- layer_type="encoder",
- drop_path_rate=self.drop_path_rates[layer_number - 1],
- set_parallel_mode=True,
- fuse_qkv_params=True,
- **extra_transformer_engine_kwargs)
-
- if config.virtual_pipeline_model_parallel_size is not None:
- assert config.num_layers % config.virtual_pipeline_model_parallel_size == 0, \
- 'num_layers_per_stage must be divisible by ' \
- 'virtual_pipeline_model_parallel_size'
- assert args.model_type != ModelType.encoder_and_decoder
- # Number of layers in each model chunk is the number of layers in the stage,
- # divided by the number of model chunks in a stage.
- self.num_layers = self.num_layers // config.virtual_pipeline_model_parallel_size
- # With 8 layers, 2 stages, and 4 model chunks, we want an assignment of
- # layers to stages like (each list is a model chunk):
- # Stage 0: [0] [2] [4] [6]
- # Stage 1: [1] [3] [5] [7]
- # With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of
- # layers to stages like (each list is a model chunk):
- # Stage 0: [0, 1] [4, 5]
- # Stage 1: [2, 3] [6, 7]
- offset = mpu.get_virtual_pipeline_model_parallel_rank() * (
- config.num_layers // config.virtual_pipeline_model_parallel_size) + \
- (mpu.get_pipeline_model_parallel_rank() * self.num_layers)
- else:
- # Each stage gets a contiguous set of layers.
- if args.model_type == ModelType.encoder_and_decoder and \
- mpu.get_pipeline_model_parallel_world_size() > 1:
- pipeline_rank = mpu.get_pipeline_model_parallel_rank()
- if layer_type == LayerType.encoder:
- offset = pipeline_rank * self.num_layers
- else:
- num_ranks_in_enc = args.pipeline_model_parallel_split_rank
- offset = (pipeline_rank - num_ranks_in_enc) * self.num_layers
- else:
- offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers
-
- if self.num_layers == 0:
- # When a standalone embedding stage is used (e.g.,
- # args.standalone_embedding_stage == True), virtual pipeline ranks
- # on pipeline rank 0 will have zero transformer layers assigned to
- # them. This results in the model's input and output tensors to be
- # the same, which will cause failure for certain output tensor
- # optimizations (e.g., pipeline output deallocation). To remedy
- # this, we assign a 'no-op' layer on these ranks, which will
- # disconnect the input tensor from the output tensor.
- self.num_layers = 1
- self.layers = torch.nn.ModuleList([ NoopTransformerLayer(1) ])
- else:
- self.layers = torch.nn.ModuleList(
- [build_layer(i + 1 + offset) for i in range(self.num_layers)])
-
- # Update dropout rate for Retro encoder.
- if model_type == ModelType.retro_encoder:
- for layer in self.layers:
- if layer.self_attention.use_flash_attn:
- layer.self_attention.core_attention_flash.dropout_p = \
- torch.nn.Dropout(args.retro_encoder_attention_dropout)
- else:
- layer.self_attention.core_attention.attention_dropout.p =\
- args.retro_encoder_attention_dropout
- layer.hidden_dropout = args.retro_encoder_hidden_dropout
-
- if self.post_process and self.post_norm:
- # Final layer norm before output.
- self.final_norm = get_norm(config)
-
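The interleaved-schedule offset computed in `__init__` above can be checked by hand. This sketch reproduces the 8-layer / 2-stage / 4-chunk assignment from the comment using plain arithmetic instead of the mpu helpers (all sizes are illustrative):

```python
# Hypothetical sizes matching the comment above; no mpu calls are made.
num_layers, pp_size, vpp_size = 8, 2, 4
layers_per_chunk = num_layers // vpp_size // pp_size   # 1 layer per model chunk

for pp_rank in range(pp_size):
    chunks = []
    for vpp_rank in range(vpp_size):
        # Mirrors: offset = vpp_rank * (num_layers // vpp_size) + pp_rank * self.num_layers
        offset = vpp_rank * (num_layers // vpp_size) + pp_rank * layers_per_chunk
        chunks.append(list(range(offset, offset + layers_per_chunk)))
    print(f"Stage {pp_rank}: {chunks}")
# Stage 0: [[0], [2], [4], [6]]
# Stage 1: [[1], [3], [5], [7]]
```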
- def _get_layer(self, layer_number):
- return self.layers[layer_number]
-
- def _checkpointed_forward(self, hidden_states, attention_mask,
- encoder_output, enc_dec_attn_mask,
- rotary_pos_emb, is_first_microbatch):
- """Forward method with activation checkpointing."""
- def custom(start, end):
- def custom_forward(*args, **kwargs):
- x_, *args = args
- for index in range(start, end):
- layer = self._get_layer(index)
- x_ = layer(x_, *args, **kwargs)
- return x_
- return custom_forward
-
- te_forward_kwargs = {}
- if self.transformer_impl == 'transformer_engine':
- te_forward_kwargs['is_first_microbatch'] = is_first_microbatch
- if self.transformer_engine_v_0_10:
- te_forward_kwargs['rotary_pos_emb'] = rotary_pos_emb
-
- if self.recompute_method == 'uniform':
- # Uniformly divide the total number of Transformer layers and
- # checkpoint the input activation of each divided chunk.
-            # A method that further reduces memory usage by storing fewer checkpoints.
- l = 0
- while l < self.num_layers:
- if self.transformer_impl == 'transformer_engine':
- hidden_states = transformer_engine.pytorch.checkpoint(
- custom(l, l + self.recompute_num_layers),
- self.distribute_saved_activations,
- tensor_parallel.get_cuda_rng_tracker,
- mpu.get_tensor_model_parallel_group(),
- hidden_states, attention_mask, encoder_output,
- enc_dec_attn_mask, **te_forward_kwargs)
- else:
- hidden_states = tensor_parallel.checkpoint(
- custom(l, l + self.recompute_num_layers),
- self.distribute_saved_activations,
- hidden_states, attention_mask,
- encoder_output, enc_dec_attn_mask,
- None, None, None, None, rotary_pos_emb)
-
- l += self.recompute_num_layers
-
- elif self.recompute_method == 'block':
- # Checkpoint the input activation of only a set number of individual
- # Transformer layers and skip the rest.
-            # A method that fully uses the device memory by removing redundant re-computation.
- for l in range(self.num_layers):
- if l < self.recompute_num_layers:
- if self.transformer_impl == 'transformer_engine':
- hidden_states = transformer_engine.pytorch.checkpoint(
- custom(l, l + 1),
- self.distribute_saved_activations,
- tensor_parallel.get_cuda_rng_tracker,
- mpu.get_tensor_model_parallel_group(),
- hidden_states, attention_mask, encoder_output,
- enc_dec_attn_mask, **te_forward_kwargs)
- else:
- hidden_states = tensor_parallel.checkpoint(
- custom(l, l + 1),
- self.distribute_saved_activations,
- hidden_states, attention_mask,
- encoder_output, enc_dec_attn_mask,
- None, None, None, None, rotary_pos_emb)
- else:
- if self.transformer_impl == 'transformer_engine':
- hidden_states = custom(l, l + 1)(
- hidden_states, attention_mask, encoder_output,
- enc_dec_attn_mask, **te_forward_kwargs)
- else:
- hidden_states = custom(l, l + 1)(
- hidden_states, attention_mask,
- encoder_output, enc_dec_attn_mask,
- None, None, None, None, rotary_pos_emb)
- else:
- raise ValueError("Invalid activation recompute method.")
-
- return hidden_states
-
- def set_input_tensor(self, input_tensor):
- """Set input tensor to be used instead of forward()'s input.
-
- When doing pipeline parallelism the input from the previous
- stage comes from communication, not from the input, so the
- model's forward_step_func won't have it. This function is thus
- used by internal code to bypass the input provided by the
- forward_step_func"""
- self.input_tensor = input_tensor
-
- def forward(self, hidden_states, attention_mask,
- encoder_output=None, enc_dec_attn_mask=None,
- retriever_input=None,
- retriever_output=None,
- retriever_attn_mask=None,
- inference_params=None,
- rotary_pos_emb=None):
- # hidden_states: [s, b, h]
-
- # Checks.
- if inference_params:
- assert self.recompute_granularity is None, \
- 'inference does not work with activation checkpointing'
-
- if not self.pre_process:
- # See set_input_tensor()
- hidden_states = self.input_tensor
-
- # Viewless tensor.
- # - We only need to create a viewless tensor in the case of micro batch
- # size (mbs) == 1, since in this case, 'hidden_states.transpose()'
- # above creates a view tensor, and '.contiguous()' is a pass-through.
- # For mbs >= 2, '.contiguous()' creates a new tensor, eliminating
- # the need to make it viewless.
- #
- # However, we don't explicitly check mbs == 1 here because
- # make_viewless_tensor() has negligible overhead when its input
- # is already viewless.
- #
- # - For the 'else' case above, calling make_viewless_tensor() here is
- # likely redundant, since p2p_communication.py (likely originator)
- # already creates viewless tensors. That said, make_viewless_tensor()
- # is called here to be future-proof and corner-case-proof.
- hidden_states = core.utils.make_viewless_tensor(
- hidden_states,
- requires_grad=True,
- keep_graph=True,
- )
-
- # RNG context.
- if self.sequence_parallel:
- rng_context = tensor_parallel.get_cuda_rng_tracker().fork()
- else:
- rng_context = nullcontext()
-
- # Forward layers.
- with rng_context:
-            # The fp8_autocast context manager is a no-op when enabled=False.
-            # The if...else serves to short-circuit name resolution for fp8_autocast.
- with transformer_engine.pytorch.fp8_autocast(
- enabled=self.use_fp8,
- fp8_recipe=self.fp8_recipe,
- fp8_group=self.fp8_group
- ) if self.use_fp8 else nullcontext():
- # Determine if the current iteration is first microbatch
- if self.num_microbatches_in_previous_step != get_num_microbatches():
- self.microbatch_count = 0 # Reset count on new batch size rampup interval
- self.num_microbatches_in_previous_step = get_num_microbatches()
- is_first_microbatch = self.microbatch_count % get_num_microbatches() == 0
-
- # Forward pass.
- if self.recompute_granularity == 'full':
- hidden_states = self._checkpointed_forward(hidden_states,
- attention_mask,
- encoder_output,
- enc_dec_attn_mask,
- rotary_pos_emb,
- is_first_microbatch)
- else:
- forward_kwargs = {
- 'encoder_output': encoder_output,
- 'enc_dec_attn_mask': enc_dec_attn_mask,
- 'inference_params': inference_params,
- }
-
- if self.transformer_impl == 'transformer_engine':
- forward_kwargs['is_first_microbatch'] = is_first_microbatch
- forward_kwargs['checkpoint_core_attention'] = self.checkpoint_core_attention
- if self.transformer_engine_v_0_10:
- forward_kwargs['rotary_pos_emb'] = rotary_pos_emb
- else:
- forward_kwargs['rotary_pos_emb'] = rotary_pos_emb
- forward_kwargs['retriever_input'] = retriever_input
- forward_kwargs['retriever_output'] = retriever_output
- forward_kwargs['retriever_attn_mask'] = retriever_attn_mask
-
- for index in range(self.num_layers):
- layer = self._get_layer(index)
-
- hidden_states = layer(
- hidden_states,
- attention_mask,
- **forward_kwargs)
-
- # First Retro decoder layer returns both hidden_states
- # and retriever_output. Make retriever_output available
-                        # to subsequent Retro layers.
- if isinstance(hidden_states, tuple):
- assert len(hidden_states) == 2
- hidden_states, retriever_output = hidden_states
- forward_kwargs["retriever_output"] = retriever_output
-
- # Skip counter update for eval and activation checkpointing
- if torch.is_grad_enabled() and self.training:
- self.microbatch_count += 1
-
- # Final layer norm.
- if self.post_process and self.post_norm:
- hidden_states = self.final_norm(hidden_states)
-
- return hidden_states
-
- def load_state_dict(self, state_dict, strict=True):
- """Customize load."""
-
- # Handle renaming layernorm -> norm in component names
- state_dict_ = {}
- for key in state_dict.keys():
- newkey = key.replace("layernorm", "norm")
- state_dict_[newkey] = state_dict[key]
-
- super().load_state_dict(state_dict_, strict)
diff --git a/megatron/model/utils.py b/megatron/model/utils.py
deleted file mode 100644
index 15fbe9ad9e215e5a33d03c8fcc58cbc281b5913f..0000000000000000000000000000000000000000
--- a/megatron/model/utils.py
+++ /dev/null
@@ -1,78 +0,0 @@
-# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
-
-"""Utilities for models."""
-
-import math
-
-import torch
-
-from megatron import get_args
-from megatron.model import LayerNorm, RMSNorm
-
-def init_method_normal(sigma):
- """Init method based on N(0, sigma)."""
- def init_(tensor):
- return torch.nn.init.normal_(tensor, mean=0.0, std=sigma)
-
- return init_
-
-
-def scaled_init_method_normal(sigma, num_layers):
- """Init method based on N(0, sigma/sqrt(2*num_layers)."""
- std = sigma / math.sqrt(2.0 * num_layers)
-
- def init_(tensor):
- return torch.nn.init.normal_(tensor, mean=0.0, std=std)
-
- return init_
-
-
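The scaled initializer above shrinks the standard deviation by `sqrt(2 * num_layers)` so that the variance contributed by residual branches does not grow with depth. A small numeric check, assuming the `scaled_init_method_normal` defined above is in scope and using illustrative values:

```python
import math

import torch

# Illustrative values: sigma=0.02 and 24 layers give
# std = 0.02 / sqrt(2 * 24) ≈ 0.00289.
sigma, num_layers = 0.02, 24
init_ = scaled_init_method_normal(sigma, num_layers)  # defined above

weight = torch.empty(1024, 1024)
init_(weight)
print(sigma / math.sqrt(2.0 * num_layers))  # 0.002886751...
print(round(weight.std().item(), 5))        # close to the value above
```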
-def attention_mask_func(attention_scores, attention_mask):
- attention_scores.masked_fill_(attention_mask, -10000.0)
- return attention_scores
-
-
-def get_linear_layer(rows, columns, init_method):
- """Simple linear layer with weight initialization."""
- layer = torch.nn.Linear(rows, columns)
- if get_args().perform_initialization:
- init_method(layer.weight)
- with torch.no_grad():
- layer.bias.zero_()
- return layer
-
-
-@torch.jit.script
-def gelu_impl(x):
- """OpenAI's gelu implementation."""
-    return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x *
-                                       (1.0 + 0.044715 * x * x)))
-
-
-def openai_gelu(x):
-    return gelu_impl(x)
-
-
-# This is actually the Python equivalent of torch.nn.functional.gelu(), with type hints for the ONNX exporter.
-@torch.jit.script
-def erf_gelu(x):
- return x * 0.5 * (torch.erf(x / 1.41421).to(dtype=x.dtype)+torch.ones_like(x).to(dtype=x.dtype))
-
-
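The constant `0.7978845608028654` above is `sqrt(2 / pi)`, and the tanh form is the standard approximation to the exact erf-based GELU. A quick standalone comparison (tolerances are approximate):

```python
import math

import torch

# The tanh approximation should track torch's exact (erf-based) GELU closely.
x = torch.linspace(-3, 3, 7)
approx = 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * x
                                     * (1.0 + 0.044715 * x * x)))
exact = torch.nn.functional.gelu(x)
print((approx - exact).abs().max().item())  # roughly 4e-4, well below 1e-3
```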
-def get_norm(config):
- args = get_args()
- if args.normalization == "LayerNorm":
- return LayerNorm(
- config.hidden_size,
- eps=config.layernorm_epsilon,
- no_persist_layer_norm=not config.persist_layer_norm,
- sequence_parallel=config.sequence_parallel,
- apply_layernorm_1p=args.apply_layernorm_1p)
- elif args.normalization == "RMSNorm":
- if args.apply_layernorm_1p:
- raise NotImplementedError('RMSNorm does not currently support the layernorm_1p formulation.')
-
- return RMSNorm(dim=config.hidden_size,
- eps=config.layernorm_epsilon,
- sequence_parallel=config.sequence_parallel)
- else:
- raise Exception(f"unsupported norm type '{args.normalization}'.")
diff --git a/megatron/model/vision/__init__.py b/megatron/model/vision/__init__.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/megatron/model/vision/classification.py b/megatron/model/vision/classification.py
deleted file mode 100644
index 3d5c823df47934f59319c4e8e1e869a62d425dc2..0000000000000000000000000000000000000000
--- a/megatron/model/vision/classification.py
+++ /dev/null
@@ -1,86 +0,0 @@
-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-
-"""Vision Transformer(VIT) model."""
-
-import torch
-from torch.nn.init import trunc_normal_
-from megatron import get_args
-from megatron.model.utils import get_linear_layer
-from megatron.model.vision.vit_backbone import VitBackbone, VitMlpHead
-from megatron.model.vision.mit_backbone import mit_b3_avg
-from megatron.model.module import MegatronModule
-
-class VitClassificationModel(MegatronModule):
- """Vision Transformer Model."""
-
- def __init__(self, config, num_classes, finetune=False,
- pre_process=True, post_process=True):
- super(VitClassificationModel, self).__init__()
- args = get_args()
- self.config = config
-
- self.hidden_size = args.hidden_size
- self.num_classes = num_classes
- self.finetune = finetune
- self.pre_process = pre_process
- self.post_process = post_process
- self.backbone = VitBackbone(
- config=config,
- pre_process=self.pre_process,
- post_process=self.post_process,
- single_token_output=True
- )
-
- if self.post_process:
- if not self.finetune:
- self.head = VitMlpHead(config, self.hidden_size, self.num_classes)
- else:
- self.head = get_linear_layer(
- self.hidden_size,
- self.num_classes,
- torch.nn.init.zeros_
- )
-
- def set_input_tensor(self, input_tensor):
- """See megatron.model.transformer.set_input_tensor()"""
- self.backbone.set_input_tensor(input_tensor)
-
- def forward(self, input):
- hidden_states = self.backbone(input)
-
- if self.post_process:
- hidden_states = self.head(hidden_states)
-
- return hidden_states
-
-
-class MitClassificationModel(MegatronModule):
- """Mix vision Transformer Model."""
-
- def __init__(self, num_classes,
- pre_process=True, post_process=True):
- super(MitClassificationModel, self).__init__()
- args = get_args()
-
- self.hidden_size = args.hidden_size
- self.num_classes = num_classes
-
- self.backbone = mit_b3_avg()
- self.head = torch.nn.Linear(512, num_classes)
- self.apply(self._init_weights)
-
- def _init_weights(self, m):
- if isinstance(m, torch.nn.Linear):
- trunc_normal_(m.weight, std=.02)
- if isinstance(m, torch.nn.Linear) and m.bias is not None:
- torch.nn.init.constant_(m.bias, 0)
-
- def set_input_tensor(self, input_tensor):
- """See megatron.model.transformer.set_input_tensor()"""
- pass
-
- def forward(self, input):
- hidden_states = self.backbone(input)
- hidden_states = self.head(hidden_states)
-
- return hidden_states
diff --git a/megatron/model/vision/dino.py b/megatron/model/vision/dino.py
deleted file mode 100644
index 151ec26647453cf139616ed2eeaf899c47cd280c..0000000000000000000000000000000000000000
--- a/megatron/model/vision/dino.py
+++ /dev/null
@@ -1,291 +0,0 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-#
-# This source code is licensed under the Apache license found in the
-# LICENSE file in the root directory of this source tree.
-
-# copied from https://github.com/facebookresearch/dino/blob/main/main_dino.py
-# reworked/refactored some parts to make it run in Megatron.
-import math
-import apex
-import einops
-import torch
-import numpy as np
-import torch.nn.functional as F
-from torch.nn.init import trunc_normal_
-from megatron import get_args, print_rank_0
-from megatron.model.utils import get_linear_layer
-from megatron.model.vision.vit_backbone import VitBackbone
-from megatron.model.module import MegatronModule
-from megatron.model.vision.mit_backbone import mit_b5_avg
-from megatron.model.vision.esvit_swin_backbone import get_swin
-
-
-class DINOLoss(torch.nn.Module):
- def __init__(self, out_dim, ncrops, warmup_teacher_temp, teacher_temp,
- warmup_teacher_temp_epochs, nepochs, student_temp=0.1,
- center_momentum=0.9):
- super().__init__()
- self.student_temp = student_temp
- self.center_momentum = center_momentum
- self.ncrops = ncrops
- self.register_buffer("center", torch.zeros(1, out_dim))
-        # we apply a warm-up for the teacher temperature because
-        # too high a temperature makes the training unstable at the beginning
- self.teacher_temp_schedule = np.concatenate((
- np.linspace(warmup_teacher_temp,
- teacher_temp, warmup_teacher_temp_epochs),
- np.ones(nepochs - warmup_teacher_temp_epochs) * teacher_temp
- ))
- self.teacher_temp = teacher_temp
-
- def forward(self, student_output, teacher_output, iteration):
- """
- Cross-entropy between softmax outputs of the teacher
- and student network.
- """
- args = get_args()
- student_out = student_output / self.student_temp
- student_out = student_out.chunk(self.ncrops)
-
- epoch = iteration // args.iter_per_epoch
-
- # teacher centering and sharpening
- temp = self.teacher_temp_schedule[epoch]
- teacher_out = F.softmax((teacher_output - self.center) / temp, dim=-1)
-
- teacher_out = teacher_out.detach().chunk(2)
-
- total_loss = 0
- n_loss_terms = 0
- for iq, q in enumerate(teacher_out):
- for v in range(len(student_out)):
- if v == iq:
- # we skip cases where student and teacher operate on the same view
- continue
- loss = torch.sum(-q * F.log_softmax(student_out[v], dim=-1), dim=-1)
- total_loss += loss.mean()
- n_loss_terms += 1
- total_loss /= n_loss_terms
- self.update_center(teacher_output)
- return total_loss
-
- @torch.no_grad()
- def update_center(self, teacher_output):
- """
- Update center used for teacher output.
- """
- batch_center = torch.sum(teacher_output, dim=0, keepdim=True)
- torch.distributed.all_reduce(batch_center)
- batch_center = batch_center / (len(teacher_output) * torch.distributed.get_world_size())
- self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum)
-
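The centering update above is an exponential moving average of the (all-reduced) batch mean of the teacher outputs. A single-process sketch of the same arithmetic, with illustrative shapes and no `all_reduce`:

```python
import torch

# center <- momentum * center + (1 - momentum) * mean(teacher_output)
center = torch.zeros(1, 4)            # illustrative out_dim = 4
momentum = 0.9
teacher_output = torch.randn(16, 4)   # illustrative batch of teacher logits

batch_center = teacher_output.mean(dim=0, keepdim=True)
center = center * momentum + batch_center * (1 - momentum)
print(center.shape)  # torch.Size([1, 4])
```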
-class DINOHead(torch.nn.Module):
- def __init__(self, in_dim, out_dim, norm_last_layer=True, nlayers=3):
- super().__init__()
- args = get_args()
- hidden_dim = args.dino_head_hidden_size
- bottleneck_dim = args.dino_bottleneck_size
- nlayers = max(nlayers, 1)
- if nlayers == 1:
- self.mlp = torch.nn.Linear(in_dim, bottleneck_dim)
- else:
- layers = [torch.nn.Linear(in_dim, hidden_dim)]
- layers.append(torch.nn.GELU())
- for _ in range(nlayers - 2):
- layers.append(torch.nn.Linear(hidden_dim, hidden_dim))
- layers.append(torch.nn.GELU())
- layers.append(torch.nn.Linear(hidden_dim, bottleneck_dim))
- self.mlp = torch.nn.Sequential(*layers)
- self.apply(self._init_weights)
- self.last_layer = torch.nn.utils.weight_norm(torch.nn.Linear(bottleneck_dim, out_dim, bias=False))
- self.last_layer.weight_g.data.fill_(1)
- if norm_last_layer:
- self.last_layer.weight_g.requires_grad = False
-
- def _init_weights(self, m):
- if isinstance(m, torch.nn.Linear):
- trunc_normal_(m.weight, std=.02)
- if isinstance(m, torch.nn.Linear) and m.bias is not None:
- torch.nn.init.constant_(m.bias, 0)
-
- def forward(self, x):
- x = self.mlp(x)
- x = torch.nn.functional.normalize(x, dim=-1, p=2)
- x = self.last_layer(x)
- return x
-
-
-class MultiCropWrapper(MegatronModule):
-
- """
- Perform forward pass separately on each resolution input.
-    The inputs corresponding to a single resolution are grouped together and a
-    single forward pass is run on them, so the number of forward passes equals
-    the number of different resolutions used. We then
- concatenate all the output features and run the head forward on these
- concatenated features.
- """
- def __init__(self, backbone, head):
- super(MultiCropWrapper, self).__init__()
- # disable layers dedicated to ImageNet labels classification
- #backbone.fc, backbone.head = torch.nn.Identity(), torch.nn.Identity()
- self.backbone = backbone
- self.head = head
-
- def forward(self, x):
- # convert to list
- if not isinstance(x, list):
- x = [x]
- idx_crops = torch.cumsum(torch.unique_consecutive(
- torch.tensor([inp.shape[-1] for inp in x]),
- return_counts=True,
- )[1], 0)
-
- start_idx = 0
- for end_idx in idx_crops:
- _out = self.backbone(torch.cat(x[start_idx: end_idx]))
- if start_idx == 0:
- output = _out
- else:
- output = torch.cat((output, _out))
- start_idx = end_idx
- # Run the head forward on the concatenated features.
- if self.training:
- return self.head(output)
- else:
- return output
-
-
-def cosine_scheduler(base_value, final_value, epochs, niter_per_ep,
- warmup_epochs=0, start_warmup_value=0):
- warmup_schedule = np.array([])
- warmup_iters = warmup_epochs * niter_per_ep
- if warmup_epochs > 0:
- warmup_schedule = \
- np.linspace(start_warmup_value, base_value, warmup_iters)
-
- iters = np.arange(epochs * niter_per_ep - warmup_iters)
- schedule = final_value + 0.5 * (base_value - final_value) \
- * (1 + np.cos(np.pi * iters / len(iters)))
-
- schedule = np.concatenate((warmup_schedule, schedule))
- assert len(schedule) == epochs * niter_per_ep
- return schedule
-
-
-def get_student_backbone_and_num_features(config, pre_process=True, post_process=True):
- args = get_args()
-
- if args.vision_backbone_type == 'vit':
- student = VitBackbone(config,
- pre_process=pre_process,
- post_process=post_process,
- drop_path_rate=0.1,
- single_token_output=True)
- num_features = args.hidden_size
- elif args.vision_backbone_type == 'mit':
- student = mit_b5_avg(drop_path_rate=0.1)
- num_features = 512
- elif args.vision_backbone_type == 'swin':
- student = get_swin()
- num_features = student.num_features
- else:
- raise Exception('{} vision backbone is not supported.'.format(
- args.vision_backbone_type))
-
- return student, num_features
-
-def get_teacher_backbone_and_num_features(config, pre_process=True, post_process=True):
- args = get_args()
-
- if args.vision_backbone_type == 'vit':
- teacher = VitBackbone(config,
- pre_process=pre_process,
- post_process=post_process,
- single_token_output=True)
- num_features = args.hidden_size
- elif args.vision_backbone_type == 'mit':
- teacher = mit_b5_avg(drop_path_rate=0.0)
- num_features = 512
- elif args.vision_backbone_type == 'swin':
- teacher = get_swin(is_teacher=True)
- num_features = teacher.num_features
- else:
- raise Exception('{} vision backbone is not supported.'.format(
- args.vision_backbone_type))
- return teacher, num_features
-
-
-class DINOPretrainModel(MegatronModule):
- def __init__(self, config, pre_process=True, post_process=True):
- super(DINOPretrainModel, self).__init__()
- args = get_args()
- self.config = config
- self.out_dim = 65536
-
- self.dino_loss = DINOLoss(
- self.out_dim,
- args.dino_local_crops_number + 2,
- args.dino_warmup_teacher_temp,
- args.dino_teacher_temp,
- args.dino_warmup_teacher_temp_epochs,
- 300,
- )
-
- self.pre_process = pre_process
- self.post_process = post_process
- self.momentum_teacher = 0.996
-
- student_backbone, num_features = \
- get_student_backbone_and_num_features(config, pre_process, post_process)
-
- self.student = MultiCropWrapper(
- student_backbone,
- DINOHead(num_features, self.out_dim,
- norm_last_layer=args.dino_norm_last_layer)
- )
-
- self.momentum_schedule = cosine_scheduler(
- self.momentum_teacher, 1,
- args.train_iters // args.iter_per_epoch,
- args.iter_per_epoch
- )
-
- teacher_backbone, num_features = \
- get_teacher_backbone_and_num_features(config, pre_process, post_process)
- self.teacher = MultiCropWrapper(
- teacher_backbone,
- DINOHead(num_features, self.out_dim)
- )
- self.teacher.load_state_dict(self.student.state_dict())
-
-        # the teacher is never updated by the optimizer; it only tracks the student via EMA
-        for p in self.teacher.parameters():
-            p.requires_grad = False
-
- def set_input_tensor(self, tensor):
- pass
-
- def forward(self, input):
- student_output = None
- if self.training:
- student_output = self.student(input)
- teacher_output = self.teacher(input[:2])
- else:
- teacher_output = self.teacher(input)
- return student_output, teacher_output
-
- def cancel_gradients_last_layer(self, iteration):
- args = get_args()
- epoch = iteration // args.iter_per_epoch
- if epoch < args.dino_freeze_last_layer:
- for n, p in self.student.named_parameters():
- if "last_layer" in n:
- p.grad = None
-
- def update_momentum(self, iteration):
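-        # EMA update: param_k <- m * param_k + (1 - m) * param_q, with the momentum m
-        # annealed from momentum_teacher towards 1.0 by the cosine schedule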
- with torch.no_grad():
- m = self.momentum_schedule[iteration]
- for param_q, param_k in zip(self.student.parameters(), self.teacher.parameters()):
- param_k.data.mul_(m).add_((1 - m) * param_q.detach().data)
-
diff --git a/megatron/model/vision/esvit_swin_backbone.py b/megatron/model/vision/esvit_swin_backbone.py
deleted file mode 100644
index 70aee3db429bf63680b7c19b4d5acfe25ee2edf5..0000000000000000000000000000000000000000
--- a/megatron/model/vision/esvit_swin_backbone.py
+++ /dev/null
@@ -1,849 +0,0 @@
-# Copyright (c) 2021 Microsoft
-#
-# This source code is licensed under the MIT license found in the
-# LICENSE file in the root directory of this source tree.
-# --------------------------------------------------------
-# Modified by Chunyuan Li (chunyl@microsoft.com)
-# Swin Transformer
-# --------------------------------------------------------
-
-import os
-import logging
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from functools import partial
-import torch.distributed as dist
-from torch.nn.init import trunc_normal_
-from megatron.model.transformer import DropPath
-from megatron import get_args
-from megatron.model import LayerNorm
-import numpy as np
-from math import sqrt
-
-
-class Mlp(nn.Module):
- def __init__(self, in_features, hidden_features=None,
- out_features=None, act_layer=nn.GELU, drop=0.):
- super(Mlp, self).__init__()
- out_features = out_features or in_features
- hidden_features = hidden_features or in_features
- self.fc1 = nn.Linear(in_features, hidden_features)
- self.act = act_layer()
- self.fc2 = nn.Linear(hidden_features, out_features)
- self.drop = nn.Dropout(drop)
-
- def forward(self, x):
- x = self.fc1(x)
- x = self.act(x)
- x = self.drop(x)
- x = self.fc2(x)
- x = self.drop(x)
- return x
-
-
-def window_partition(x, window_size):
- """
- Args:
- x: (B, H, W, C)
- window_size (int): window size
- Returns:
- windows: (num_windows*B, window_size, window_size, C)
- """
- B, H, W, C = x.shape
- x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
- windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
- return windows
-
-
-def window_reverse(windows, window_size, H, W):
- """
- Args:
- windows: (num_windows*B, window_size, window_size, C)
- window_size (int): Window size
- H (int): Height of image
- W (int): Width of image
- Returns:
- x: (B, H, W, C)
- """
- B = int(windows.shape[0] / (H * W / window_size / window_size))
- x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
- x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
- return x
-
-
-class WindowAttention(nn.Module):
- r"""Window based multi-head self attention (W-MSA) module with relative position bias.
-    It supports both shifted and non-shifted windows.
- Args:
- dim (int): Number of input channels.
- window_size (tuple[int]): The height and width of the window.
- num_heads (int): Number of attention heads.
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
- qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
- attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
- proj_drop (float, optional): Dropout ratio of output. Default: 0.0
- """
-
- def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
-
- super(WindowAttention, self).__init__()
- self.dim = dim
- self.window_size = window_size # Wh, Ww
- self.num_heads = num_heads
- head_dim = dim // num_heads
- self.scale = qk_scale or head_dim ** -0.5
-
- # define a parameter table of relative position bias
- self.relative_position_bias_table = nn.Parameter(
- torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
-
- # get pair-wise relative position index for each token inside the window
- coords_h = torch.arange(self.window_size[0])
- coords_w = torch.arange(self.window_size[1])
- coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
- coords_flatten = torch.flatten(coords, 1) # 2 Wh*Ww
- relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
- relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
- relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
- relative_coords[:, :, 1] += self.window_size[1] - 1
- relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
- relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
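-        # entry (i, j) indexes the row of relative_position_bias_table that holds the
-        # bias for the relative offset between tokens i and j within a window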
- self.register_buffer("relative_position_index", relative_position_index)
-
- self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
- self.attn_drop = nn.Dropout(attn_drop)
- self.proj = nn.Linear(dim, dim)
- self.proj_drop = nn.Dropout(proj_drop)
-
- trunc_normal_(self.relative_position_bias_table, std=.02)
- self.softmax = nn.Softmax(dim=-1)
-
- def forward(self, x, mask=None):
- """
- Args:
- x: input features with shape of (num_windows*B, N, C)
- mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
- """
- B_, N, C = x.shape
- qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
- q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
-
- q = q * self.scale
- attn = (q @ k.transpose(-2, -1))
-
- relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
- self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
- relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
- attn = attn + relative_position_bias.unsqueeze(0)
-
- if mask is not None:
- nW = mask.shape[0]
- attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0).type(attn.type())
- attn = attn.view(-1, self.num_heads, N, N)
- attn = self.softmax(attn)
- else:
- attn = self.softmax(attn)
-
- attn_out = attn
- attn = self.attn_drop(attn)
-
- x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
- x = self.proj(x)
- x = self.proj_drop(x)
- return x, attn_out
-
- def extra_repr(self) -> str:
- return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
-
- def flops(self, N):
- # calculate flops for 1 window with token length of N
- flops = 0
- # qkv = self.qkv(x)
- flops += N * self.dim * 3 * self.dim
- # attn = (q @ k.transpose(-2, -1))
- flops += self.num_heads * N * (self.dim // self.num_heads) * N
- # x = (attn @ v)
- flops += self.num_heads * N * N * (self.dim // self.num_heads)
- # x = self.proj(x)
- flops += N * self.dim * self.dim
- return flops
-
- @staticmethod
- def compute_macs(module, input, output):
- B, N, C = input[0].shape
-
- module.__flops__ += module.flops(N) * B
-
-
-class SwinTransformerBlock(nn.Module):
- r"""Swin Transformer Block.
- Args:
- dim (int): Number of input channels.
-        input_resolution (tuple[int]): Input resolution.
- num_heads (int): Number of attention heads.
- window_size (int): Window size.
- shift_size (int): Shift size for SW-MSA.
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
- qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
- drop (float, optional): Dropout rate. Default: 0.0
- attn_drop (float, optional): Attention dropout rate. Default: 0.0
- drop_path (float, optional): Stochastic depth rate. Default: 0.0
- act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
- """
-
- def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
- mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
- act_layer=nn.GELU, norm_layer=nn.LayerNorm):
- super().__init__()
- self.dim = dim
- self.input_resolution = input_resolution
- self.num_heads = num_heads
- self.window_size = window_size
- self.shift_size = shift_size
- self.mlp_ratio = mlp_ratio
- if min(self.input_resolution) <= self.window_size:
- # if window size is larger than input resolution, we don't partition windows
- self.shift_size = 0
- self.window_size = min(self.input_resolution)
-        assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
-
- self.norm1 = norm_layer(dim)
- self.attn = WindowAttention(
- dim, window_size=(self.window_size, self.window_size), num_heads=num_heads,
- qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
-
- self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
- self.norm2 = norm_layer(dim)
- mlp_hidden_dim = int(dim * mlp_ratio)
- self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
-
- self.H = input_resolution[0]
- self.W = input_resolution[1]
-
- self.attn_mask_dict = {}
-
-
- def create_attn_mask(self, H, W):
- # calculate attention mask for SW-MSA
-
- Hp = int(np.ceil(H / self.window_size)) * self.window_size
- Wp = int(np.ceil(W / self.window_size)) * self.window_size
- img_mask = torch.zeros((1, Hp, Wp, 1)) # 1 Hp Wp 1
- h_slices = (slice(0, -self.window_size),
- slice(-self.window_size, -self.shift_size),
- slice(-self.shift_size, None))
- w_slices = (slice(0, -self.window_size),
- slice(-self.window_size, -self.shift_size),
- slice(-self.shift_size, None))
- cnt = 0
- for h in h_slices:
- for w in w_slices:
- img_mask[:, h, w, :] = cnt
- cnt += 1
-
- mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
- mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
- attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
- attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
-
- return attn_mask
-
-
- def forward(self, x):
- B, L, C = x.shape
- H = int(sqrt(L))
- W = H
-
- shortcut = x
- x = self.norm1(x)
- x = x.view(B, H, W, C)
-
- # pad feature maps to multiples of window size
- pad_l = pad_t = 0
- pad_r = (self.window_size - W % self.window_size) % self.window_size
- pad_b = (self.window_size - H % self.window_size) % self.window_size
- x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
- _, Hp, Wp, _ = x.shape
-
- # cyclic shift
- if self.shift_size > 0:
- shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
-
- if H in self.attn_mask_dict.keys():
- attn_mask = self.attn_mask_dict[H]
- else:
- self.attn_mask_dict[H] = self.create_attn_mask(self.H, self.W).to(x.device)
- attn_mask = self.attn_mask_dict[H]
-
- else:
- shifted_x = x
- attn_mask = None
-
- # partition windows
- x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
- x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
-
- # W-MSA/SW-MSA
- attn_windows, attn = self.attn(x_windows, attn_mask) # nW*B, window_size*window_size, C
-
- # merge windows
- attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
- shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
-
- # reverse cyclic shift
- if self.shift_size > 0:
- x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
- else:
- x = shifted_x
-
- if pad_r > 0 or pad_b > 0:
- x = x[:, :H, :W, :].contiguous()
-
- x = x.view(B, H * W, C)
-
- # FFN
- x = shortcut + self.drop_path(x)
- x = x + self.drop_path(self.mlp(self.norm2(x)))
-
- return x, attn
-
- def extra_repr(self) -> str:
- return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
- f"window_size={self.window_size}, shift_size={self.shift_size} mlp_ratio={self.mlp_ratio}"
-
- def flops(self):
- flops = 0
- H, W = self.input_resolution
- # norm1
- flops += self.dim * H * W
- # W-MSA/SW-MSA
- nW = H * W / self.window_size / self.window_size
- flops += nW * self.attn.flops(self.window_size * self.window_size)
- # mlp
- flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
- # norm2
- flops += self.dim * H * W
- return flops
-
-
-class PatchMerging(nn.Module):
- r"""Patch Merging Layer.
- Args:
- input_resolution (tuple[int]): Resolution of input feature.
- dim (int): Number of input channels.
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
- """
-
- def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
- super().__init__()
- self.input_resolution = input_resolution
- self.dim = dim
- self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
- self.norm = norm_layer(4 * dim)
-
- def forward(self, x):
- """ Forward function.
- Args:
-            x: Input feature, tensor size (B, H*W, C). The spatial resolution
-                (H, W) is recovered from L assuming a square feature map (H == W).
- """
- B, L, C = x.shape
- H = int(sqrt(L))
- W = H
-
- x = x.view(B, H, W, C)
-
- # padding
- pad_input = (H % 2 == 1) or (W % 2 == 1)
- if pad_input:
- x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
-
- x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
- x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
- x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
- x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
- x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
- x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
-
- x = self.norm(x)
- x = self.reduction(x)
-
- return x
-
-
- def extra_repr(self) -> str:
- return f"input_resolution={self.input_resolution}, dim={self.dim}"
-
- def flops(self):
- H, W = self.input_resolution
- flops = H * W * self.dim
- flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
- return flops
-
-
-class BasicLayer(nn.Module):
- """A basic Swin Transformer layer for one stage.
- Args:
- dim (int): Number of input channels.
-        input_resolution (tuple[int]): Input resolution.
- depth (int): Number of blocks.
- num_heads (int): Number of attention heads.
- window_size (int): Window size.
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
- qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
- drop (float, optional): Dropout rate. Default: 0.0
- attn_drop (float, optional): Attention dropout rate. Default: 0.0
- drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
- downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
- """
-
- def __init__(self, dim, input_resolution, depth, num_heads, window_size,
- mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
- drop_path=0., norm_layer=nn.LayerNorm, downsample=None):
-
- super().__init__()
- self.dim = dim
- self.input_resolution = input_resolution
- self.depth = depth
-
- self.blocks = nn.ModuleList([
- SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
- num_heads=num_heads, window_size=window_size,
- shift_size=0 if (i % 2 == 0) else window_size // 2,
- mlp_ratio=mlp_ratio,
- qkv_bias=qkv_bias, qk_scale=qk_scale,
- drop=drop, attn_drop=attn_drop,
- drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
- norm_layer=norm_layer)
- for i in range(depth)])
- if downsample is not None:
- self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
- else:
- self.downsample = None
-
- def forward(self, x):
- for blk in self.blocks:
- x, _ = blk(x)
- if self.downsample is not None:
- x = self.downsample(x)
- return x
-
- def forward_with_features(self, x):
- fea = []
- for blk in self.blocks:
- x, _ = blk(x)
- fea.append(x)
- if self.downsample is not None:
- x = self.downsample(x)
- return x, fea
-
- def forward_with_attention(self, x):
- attns = []
- for blk in self.blocks:
- x, attn = blk(x)
- attns.append(attn)
- if self.downsample is not None:
- x = self.downsample(x)
- return x, attns
-
-
- def extra_repr(self) -> str:
- return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
-
- def flops(self):
- flops = 0
- for blk in self.blocks:
- flops += blk.flops()
- if self.downsample is not None:
- flops += self.downsample.flops()
- return flops
-
-
-class PatchEmbed(nn.Module):
- """ Image to Patch Embedding
- """
-
- def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None):
- super().__init__()
- img_size = (img_size, img_size)
- patch_size = (patch_size, patch_size)
- patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
- self.img_size = img_size
- self.patch_size = patch_size
- self.patches_resolution = patches_resolution
- self.num_patches = patches_resolution[0] * patches_resolution[1]
-
- self.in_chans = in_chans
- self.embed_dim = embed_dim
-
- self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
- if norm_layer is not None:
- self.norm = norm_layer(embed_dim)
- else:
- self.norm = None
-
- def forward(self, x):
- B, C, H, W = x.shape
-
- x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
- if self.norm is not None:
- x = self.norm(x)
- return x
-
-
- def flops(self):
- Ho, Wo = self.patches_resolution
- flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
- if self.norm is not None:
- flops += Ho * Wo * self.embed_dim
- return flops
-
-class SwinTransformer(nn.Module):
- r""" Swin Transformer
-    A PyTorch implementation of: `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
- https://arxiv.org/pdf/2103.14030
- Args:
- img_size (int | tuple(int)): Input image size.
- patch_size (int | tuple(int)): Patch size.
- in_chans (int): Number of input channels.
- num_classes (int): Number of classes for classification head.
- embed_dim (int): Embedding dimension.
- depths (tuple(int)): Depth of Swin Transformer layers.
- num_heads (tuple(int)): Number of attention heads in different layers.
- window_size (int): Window size.
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
-        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
- qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
- drop_rate (float): Dropout rate.
- attn_drop_rate (float): Attention dropout rate.
- drop_path_rate (float): Stochastic depth rate.
- norm_layer (nn.Module): normalization layer.
- ape (bool): If True, add absolute position embedding to the patch embedding.
- patch_norm (bool): If True, add normalization after patch embedding.
- """
-
- def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
- embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
- window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
- drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
- norm_layer=nn.LayerNorm, ape=False, patch_norm=True, **kwargs):
- super().__init__()
-
- self.num_classes = num_classes
- self.num_layers = len(depths)
- self.embed_dim = embed_dim
- self.ape = ape
- self.patch_norm = patch_norm
- self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
- self.mlp_ratio = mlp_ratio
-
- self.patch_embed = PatchEmbed(
- img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
- norm_layer=norm_layer if self.patch_norm else None)
- num_patches = self.patch_embed.num_patches
- patches_resolution = self.patch_embed.patches_resolution
- self.patches_resolution = patches_resolution
-
- if self.ape:
- self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
- trunc_normal_(self.absolute_pos_embed, std=.02)
-
- self.pos_drop = nn.Dropout(p=drop_rate)
-
- dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
- self.layers = nn.ModuleList()
- for i_layer in range(self.num_layers):
- layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
- input_resolution=(patches_resolution[0] // (2 ** i_layer),
- patches_resolution[1] // (2 ** i_layer)),
- depth=depths[i_layer],
- num_heads=num_heads[i_layer],
- window_size=window_size,
- mlp_ratio=self.mlp_ratio,
- qkv_bias=qkv_bias, qk_scale=qk_scale,
- drop=drop_rate, attn_drop=attn_drop_rate,
- drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
- norm_layer=norm_layer,
- downsample=PatchMerging if (i_layer < self.num_layers - 1) else None)
- self.layers.append(layer)
-
- self.norm = norm_layer(self.num_features)
- self.avgpool = nn.AdaptiveAvgPool1d(1)
-
- self.apply(self._init_weights)
-
- def _init_weights(self, m):
- if isinstance(m, nn.Linear):
- trunc_normal_(m.weight, std=.02)
- if isinstance(m, nn.Linear) and m.bias is not None:
- nn.init.constant_(m.bias, 0)
- elif isinstance(m, nn.LayerNorm):
- nn.init.constant_(m.bias, 0)
- nn.init.constant_(m.weight, 1.0)
-
- @torch.jit.ignore
- def no_weight_decay(self):
- return {'absolute_pos_embed'}
-
- @torch.jit.ignore
- def no_weight_decay_keywords(self):
- # todo: to be implemented
- return {'relative_position_bias_table'}
-
- def forward(self, x):
- x = self.patch_embed(x)
- if self.ape:
- x = x + self.absolute_pos_embed
- x = self.pos_drop(x)
-
- for layer in self.layers:
- x = layer(x)
-
- x_region = self.norm(x) # B L C
- x = self.avgpool(x_region.transpose(1, 2)) # B C 1
- x = torch.flatten(x, 1)
-
- return x
-
-
- def forward_feature_maps(self, x):
- x = self.patch_embed(x)
- if self.ape:
- x = x + self.absolute_pos_embed
- x = self.pos_drop(x)
-
- for layer in self.layers:
- x = layer(x)
-
- x_grid = self.norm(x) # B L C
- x = self.avgpool(x_grid.transpose(1, 2)) # B C 1
- x = torch.flatten(x, 1)
-
- return x, x_grid
-
-
- def forward_selfattention(self, x, n=1):
-        # n=1 returns the attention map of the last layer; otherwise attention maps from all layers are returned
-
-
- x = self.patch_embed(x)
- if self.ape:
- x = x + self.absolute_pos_embed
- x = self.pos_drop(x)
-
-        if n == 1:
- return self.forward_last_selfattention(x)
- else:
- return self.forward_all_selfattention(x)
-
- def forward_last_selfattention(self, x):
-
- for i, layer in enumerate(self.layers):
- if i < len(self.layers) - 1:
- x = layer(x)
- else:
- x, attns = layer.forward_with_attention(x)
- return attns[-1]
-
- def forward_all_selfattention(self, x):
- attn_out = []
-
- for layer in self.layers:
- x, attns = layer.forward_with_attention(x)
- attn_out += attns
-
- return attn_out
-
-
- def forward_return_n_last_blocks(self, x, n=1, return_patch_avgpool=False, depth=[]):
-
- num_blks = sum(depth)
- start_idx = num_blks - n
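-        # locate the stage (start_stage) and the block inside that stage (start_blk)
-        # at which the last n blocks begin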
-
- sum_cur = 0
- for i, d in enumerate(depth):
- sum_cur_new = sum_cur + d
- if start_idx >= sum_cur and start_idx < sum_cur_new:
- start_stage = i
- start_blk = start_idx - sum_cur
- sum_cur = sum_cur_new
-
-
- x = self.patch_embed(x)
- if self.ape:
- x = x + self.absolute_pos_embed
- x = self.pos_drop(x)
-
- # we will return the averaged token features from the `n` last blocks
- # note: there is no [CLS] token in Swin Transformer
- output = []
- s = 0
- for i, layer in enumerate(self.layers):
- x, fea = layer.forward_with_features(x)
-
- if i >= start_stage:
- for x_ in fea[start_blk:]:
-
- if i == len(self.layers)-1: # use the norm in the last stage
- x_ = self.norm(x_)
-
- x_avg = torch.flatten(self.avgpool(x_.transpose(1, 2)), 1) # B C
- # print(f'Stage {i}, x_avg {x_avg.shape}')
- output.append(x_avg)
-
- start_blk = 0
-
- return torch.cat(output, dim=-1)
-
-
-
- def flops(self):
- flops = 0
- flops += self.patch_embed.flops()
- for i, layer in enumerate(self.layers):
- flops += layer.flops()
- if dist.get_rank() == 0:
- print(f"GFLOPs layer_{i}: {layer.flops() / 1e9}")
- flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
- flops += self.num_features * self.num_classes
- return flops
-
- def init_weights(self, pretrained='', pretrained_layers=[], verbose=True):
- if os.path.isfile(pretrained):
- pretrained_dict = torch.load(pretrained, map_location='cpu')
- logging.info(f'=> loading pretrained model {pretrained}')
- model_dict = self.state_dict()
- pretrained_dict = {
- k: v for k, v in pretrained_dict.items()
- if k in model_dict.keys()
- }
- need_init_state_dict = {}
- for k, v in pretrained_dict.items():
- need_init = (
- k.split('.')[0] in pretrained_layers
-                or pretrained_layers[0] == '*'
- or 'relative_position_index' not in k
- or 'attn_mask' not in k
- )
-
- if need_init:
- if verbose:
- logging.info(f'=> init {k} from {pretrained}')
-
- if 'relative_position_bias_table' in k and v.size() != model_dict[k].size():
- relative_position_bias_table_pretrained = v
- relative_position_bias_table_current = model_dict[k]
- L1, nH1 = relative_position_bias_table_pretrained.size()
- L2, nH2 = relative_position_bias_table_current.size()
- if nH1 != nH2:
- logging.info(f"Error in loading {k}, passing")
- else:
- if L1 != L2:
- logging.info(
- '=> load_pretrained: resized variant: {} to {}'
- .format((L1, nH1), (L2, nH2))
- )
- S1 = int(L1 ** 0.5)
- S2 = int(L2 ** 0.5)
- relative_position_bias_table_pretrained_resized = torch.nn.functional.interpolate(
- relative_position_bias_table_pretrained.permute(1, 0).view(1, nH1, S1, S1),
- size=(S2, S2),
- mode='bicubic')
- v = relative_position_bias_table_pretrained_resized.view(nH2, L2).permute(1, 0)
-
- if 'absolute_pos_embed' in k and v.size() != model_dict[k].size():
- absolute_pos_embed_pretrained = v
- absolute_pos_embed_current = model_dict[k]
- _, L1, C1 = absolute_pos_embed_pretrained.size()
- _, L2, C2 = absolute_pos_embed_current.size()
-                    if C1 != C2:
- logging.info(f"Error in loading {k}, passing")
- else:
- if L1 != L2:
- logging.info(
- '=> load_pretrained: resized variant: {} to {}'
- .format((1, L1, C1), (1, L2, C2))
- )
- S1 = int(L1 ** 0.5)
- S2 = int(L2 ** 0.5)
- absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.reshape(-1, S1, S1, C1)
- absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.permute(0, 3, 1, 2)
- absolute_pos_embed_pretrained_resized = torch.nn.functional.interpolate(
- absolute_pos_embed_pretrained, size=(S2, S2), mode='bicubic')
- v = absolute_pos_embed_pretrained_resized.permute(0, 2, 3, 1).flatten(1, 2)
-
- need_init_state_dict[k] = v
- self.load_state_dict(need_init_state_dict, strict=False)
-
- def freeze_pretrained_layers(self, frozen_layers=[]):
- for name, module in self.named_modules():
- if (
- name.split('.')[0] in frozen_layers
- or '.'.join(name.split('.')[0:2]) in frozen_layers
-                or (len(frozen_layers) > 0 and frozen_layers[0] == '*')
- ):
- for _name, param in module.named_parameters():
- param.requires_grad = False
- logging.info(
- '=> set param {} requires grad to False'
- .format(name)
- )
- for name, param in self.named_parameters():
- if (
- name.split('.')[0] in frozen_layers
-                or (len(frozen_layers) > 0 and frozen_layers[0] == '*')
- and param.requires_grad is True
- ):
- param.requires_grad = False
- logging.info(
- '=> set param {} requires grad to False'
- .format(name)
- )
- return self
-
-
-def get_swin(is_teacher=False):
- args = get_args()
-
- if args.swin_backbone_type == "tiny":
- embed_dim = 96
- depths = [2, 2, 6, 2]
- num_heads = [3, 6, 12, 24]
- drop_path_rate = 0.1
- elif args.swin_backbone_type == 'h3':
- embed_dim = 384
- depths = [2, 2, 18, 2]
- num_heads = [6, 12, 24, 48]
- drop_path_rate = 0.2
- else:
- embed_dim = 128
- depths = [2, 2, 18, 2]
- num_heads = [4, 8, 16, 32]
- drop_path_rate = 0.2
-
- swin = SwinTransformer(
- img_size=224,
- in_chans=3,
- num_classes=1000,
- patch_size=4,
- embed_dim=embed_dim,
- depths=depths,
- num_heads=num_heads,
- window_size=7,
- mlp_ratio=4,
- qkv_bias=True,
- drop_rate=0,
- attn_drop_rate=0,
- drop_path_rate=(0.0 if is_teacher else drop_path_rate),
- norm_layer=partial(LayerNorm, eps=1e-6),
- ape=False,
- patch_norm=True,
- )
-
- return swin
-
diff --git a/megatron/model/vision/inpainting.py b/megatron/model/vision/inpainting.py
deleted file mode 100644
index 6aae9658bc86a110d20feb6f9026d0fb8c2b8f1e..0000000000000000000000000000000000000000
--- a/megatron/model/vision/inpainting.py
+++ /dev/null
@@ -1,152 +0,0 @@
-# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
-#
-# This source code is licensed under the BSD license found in the
-# LICENSE file in the root directory of this source tree.
-
-import math
-import apex
-import einops
-import torch
-import torch.nn.functional as F
-from megatron import get_args, print_rank_0
-from megatron.model.utils import get_linear_layer
-from megatron.model.vision.vit_backbone import VitBackbone
-from megatron.model.module import MegatronModule
-from megatron.model.vision.mit_backbone import mit_b3
-from megatron.model.vision.utils import resize
-
-
-class VitInpaintingModel(MegatronModule):
-
- def __init__(self, config, pre_process=True, post_process=True):
- super(VitInpaintingModel, self).__init__()
- args = get_args()
-
- self.config = config
- self.pre_process = pre_process
- self.post_process = post_process
- self.hidden_size = config.hidden_size
- self.backbone = VitBackbone(
- config=config,
- pre_process=self.pre_process,
- post_process=self.post_process,
- class_token=False,
- )
- self.patch_dim = args.patch_dim
- self.img_h = args.img_h
- self.img_w = args.img_w
- self.seq_length = args.seq_length
- # full mask
-
- if self.post_process:
- self.linear_decoder = get_linear_layer(
- self.hidden_size,
- self.backbone.flatten_dim,
- torch.nn.init.zeros_
- )
-
- def set_input_tensor(self, input_tensor):
- self.backbone.set_input_tensor(input_tensor)
-
- def forward(self, input):
-
- hidden_states = self.backbone(input)
-
- if not self.post_process:
- return hidden_states
- decoded_output = self.linear_decoder(hidden_states)
- output = einops.rearrange(
- decoded_output,
- "b (h w) (p1 p2 c) -> b c (h p1) (w p2)",
- p1=self.patch_dim,
- p2=self.patch_dim,
- h=self.img_h//self.patch_dim,
- w=self.img_w//self.patch_dim,
- )
-
- return output
-
-
-class MLP(torch.nn.Module):
- """
- Linear Embedding
- """
- def __init__(self, input_dim=2048, embed_dim=768):
- super().__init__()
- self.proj = torch.nn.Linear(input_dim, embed_dim)
-
- def forward(self, x):
- x = x.flatten(2).transpose(1, 2)
- x = self.proj(x)
- return x
-
-
-class MitInpaintingModel(MegatronModule):
- """Mix vision Transformer Model."""
-
- def __init__(self, pre_process=True, post_process=True):
- super(MitInpaintingModel, self).__init__()
- self.pre_process = pre_process
- self.post_process = post_process
-
- args = get_args()
- self.patch_dim = args.patch_dim
- self.img_h = args.img_h
- self.img_w = args.img_w
- self.flatten_dim = self.patch_dim * self.patch_dim * 3
- self.backbone = mit_b3()
-
- self.in_channels = [64, 128, 320, 512]
- self.embedding_dim = 768
-
- c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = self.in_channels
-
- self.linear_c4 = MLP(input_dim=c4_in_channels, embed_dim=self.embedding_dim)
- self.linear_c3 = MLP(input_dim=c3_in_channels, embed_dim=self.embedding_dim)
- self.linear_c2 = MLP(input_dim=c2_in_channels, embed_dim=self.embedding_dim)
- self.linear_c1 = MLP(input_dim=c1_in_channels, embed_dim=self.embedding_dim)
-
- self.conv_fuse = torch.nn.Conv2d(self.embedding_dim*4, self.embedding_dim, 1, 1, bias=False)
- self.norm = apex.parallel.SyncBatchNorm(self.embedding_dim)
- self.dropout = torch.nn.Dropout2d(0.1)
-
- self.linear_pred = torch.nn.Conv2d(self.embedding_dim, self.flatten_dim, kernel_size=1)
-
- def set_input_tensor(self, input_tensor):
- """See megatron.model.transformer.set_input_tensor()"""
- pass
-
- def forward(self, input):
- c1, c2, c3, c4 = self.backbone(input)
-
- n, _, h, w = c4.shape
- _c4 = self.linear_c4(c4).permute(0, 2, 1).reshape(n, -1, c4.shape[2], c4.shape[3])
- _c4 = resize(_c4, size=c1.size()[2:], mode='bilinear', align_corners=False)
-
- _c3 = self.linear_c3(c3).permute(0, 2, 1).reshape(n, -1, c3.shape[2], c3.shape[3])
- _c3 = resize(_c3, size=c1.size()[2:], mode='bilinear', align_corners=False)
-
- _c2 = self.linear_c2(c2).permute(0, 2, 1).reshape(n, -1, c2.shape[2], c2.shape[3])
- _c2 = resize(_c2, size=c1.size()[2:], mode='bilinear', align_corners=False)
-
- _c1 = self.linear_c1(c1).permute(0, 2, 1).reshape(n, -1, c1.shape[2], c1.shape[3])
-
- _c = torch.cat([_c4, _c3, _c2, _c1], dim=1)
- _c = self.conv_fuse(_c)
-
- x = self.norm(_c)
- x = F.relu(x, inplace=True)
- x = self.dropout(x)
-
- x = self.linear_pred(x)
-
- output = einops.rearrange(
- x,
- "b (c p1 p2) h w -> b c (h p1) (w p2)",
- p1=self.patch_dim,
- p2=self.patch_dim,
- h=self.img_h//self.patch_dim,
- w=self.img_w//self.patch_dim,
- )
-
- return output
diff --git a/megatron/model/vision/knn_monitor.py b/megatron/model/vision/knn_monitor.py
deleted file mode 100644
index a7d79854eb51f24443fcdce9e029df5512503bf9..0000000000000000000000000000000000000000
--- a/megatron/model/vision/knn_monitor.py
+++ /dev/null
@@ -1,129 +0,0 @@
-import torch.nn.functional as F
-import torch
-from megatron import print_rank_0, get_args
-from megatron.core import mpu
-from megatron.data.vit_dataset import ClassificationTransform
-from megatron.data.image_folder import ImageFolder
-
-_FEATURE_BANK = None
-
-
-def build_data_loader(dataset, drop_last=True, shuffle=False):
- """Data loader. Note that batch-size is the local (per GPU) batch-size."""
- # Sampler.
- args = get_args()
- micro_batch_size = 16
- num_workers = args.num_workers
- world_size = mpu.get_data_parallel_world_size()
- rank = mpu.get_data_parallel_rank()
- sampler = torch.utils.data.distributed.DistributedSampler(
- dataset, num_replicas=world_size, rank=rank,
- drop_last=drop_last, shuffle=shuffle
- )
-
- # Data loader. Note that batch size is the per GPU batch size.
- data_loader = torch.utils.data.DataLoader(
- dataset,
- batch_size=micro_batch_size,
- sampler=sampler,
- shuffle=False,
- num_workers=num_workers,
- drop_last=not drop_last,
- pin_memory=True,
- )
- return data_loader
-
-
-def compute_feature_bank(model):
- args = get_args()
- global _FEATURE_BANK
- feature_bank = []
- feature_label = []
-
- train_ds = ImageFolder(
- root=args.data_path[0],
- transform=ClassificationTransform((args.img_h, args.img_w), train=False),
- data_per_class_fraction=1.0
- )
- classes = len(train_ds.classes)
- dataloader = build_data_loader(train_ds)
-
- for m in model:
- m.eval()
-
- with torch.no_grad():
- for i, batch in enumerate(dataloader):
- images = batch[0].cuda().contiguous()
- labels = batch[1].cuda().contiguous()
- student_feature, teacher_feature = model[0](images)
- feature = F.normalize(teacher_feature.float(), dim=1)
- feature_bank.append(feature)
- feature_label.append(labels)
-
- for m in model:
- m.train()
-
- # [N', D]
- feature_bank = torch.cat(feature_bank, dim=0).contiguous()
- feature_label = torch.cat(feature_label, dim=0).contiguous()
-
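-    # gather the per-rank features across the data-parallel group so that every
-    # rank ends up with the feature bank of the full dataset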
- feature_banks = [torch.zeros_like(feature_bank)
- for i in range(mpu.get_data_parallel_world_size())]
- torch.distributed.all_gather(feature_banks,
- feature_bank,
- group=mpu.get_data_parallel_group())
-
- assert torch.all(torch.eq(feature_banks[mpu.get_data_parallel_rank()],
- feature_bank))
-
- feature_labels = [torch.zeros_like(feature_label)
- for i in range(mpu.get_data_parallel_world_size())]
- torch.distributed.all_gather(feature_labels,
- feature_label,
- group=mpu.get_data_parallel_group())
-
- # [D, N]
- feature_banks = torch.cat(feature_banks, dim=0).t().contiguous()
- # [N]
- feature_labels = torch.cat(feature_labels, dim=0).contiguous()
- print_rank_0("feature_banks size is {}".format(feature_banks.size()))
- print_rank_0("feature labels size is {}".format(feature_labels.size()))
-
- _FEATURE_BANK = (feature_banks, feature_labels, classes)
-
-
-def get_feature_bank():
- global _FEATURE_BANK
- assert _FEATURE_BANK is not None
- return _FEATURE_BANK
-
-
-# knn monitor as in InstDisc https://arxiv.org/abs/1805.01978
-# implementation follows http://github.com/zhirongw/lemniscate.pytorch and
-# https://github.com/leftthomas/SimCLR
-def knn_predict(feature, feature_bank, feature_labels, classes, knn_k, knn_t):
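-    # feature: [B, D] query features, feature_bank: [D, N] memory bank,
-    # feature_labels: [N]; returns class indices sorted from most to least likely
-    # according to a similarity-weighted vote over the k nearest neighbours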
- # compute cos similarity between each feature vector and feature bank ---> [B, N]
- sim_matrix = torch.mm(feature, feature_bank)
- # [B, K]
- sim_weight, sim_indices = sim_matrix.topk(k=knn_k, dim=-1)
- # [B, K]
- sim_labels = torch.gather(feature_labels.expand(feature.size(0), -1),
- dim=-1,
- index=sim_indices)
- sim_weight = (sim_weight / knn_t).exp()
-
- # counts for each class
- one_hot_label = torch.zeros(feature.size(0) * knn_k,
- classes,
- device=sim_labels.device)
- # [B*K, C]
- one_hot_label = one_hot_label.scatter(dim=-1,
- index=sim_labels.view(-1, 1),
- value=1.0)
- # weighted score ---> [B, C]
- pred_scores = torch.sum(
- one_hot_label.view(feature.size(0), -1, classes) * sim_weight.unsqueeze(dim=-1),
- dim=1)
-
- pred_labels = pred_scores.argsort(dim=-1, descending=True)
- return pred_labels
diff --git a/megatron/model/vision/mit_backbone.py b/megatron/model/vision/mit_backbone.py
deleted file mode 100644
index 6640b105dfce36169b51dd93351469000af4fdbf..0000000000000000000000000000000000000000
--- a/megatron/model/vision/mit_backbone.py
+++ /dev/null
@@ -1,415 +0,0 @@
-# Copyright (c) 2023, NVIDIA Corporation. All rights reserved.
-
-import math
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from functools import partial
-from torch.nn.init import trunc_normal_
-from megatron.model.transformer import DropPath
-from megatron.model import LayerNorm
-
-
-class Mlp(nn.Module):
- def __init__(self,
- in_features,
- hidden_features=None,
- out_features=None,
- act_layer=nn.GELU,
- drop=0.):
- super().__init__()
- out_features = out_features or in_features
- hidden_features = hidden_features or in_features
- self.fc1 = nn.Linear(in_features, hidden_features)
- self.dwconv = DWConv(hidden_features)
- self.act = act_layer()
- self.fc2 = nn.Linear(hidden_features, out_features)
- self.drop = nn.Dropout(drop)
-
- self.apply(self._init_weights)
-
- def _init_weights(self, m):
- if isinstance(m, nn.Linear):
- trunc_normal_(m.weight, std=.02)
- if isinstance(m, nn.Linear) and m.bias is not None:
- nn.init.constant_(m.bias, 0)
- elif isinstance(m, nn.LayerNorm):
- nn.init.constant_(m.bias, 0)
- nn.init.constant_(m.weight, 1.0)
- elif isinstance(m, nn.Conv2d):
- fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
- fan_out //= m.groups
- m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
- if m.bias is not None:
- m.bias.data.zero_()
-
- def forward(self, x, H, W):
- x = self.fc1(x)
- x = self.dwconv(x, H, W)
- x = self.act(x)
- x = self.drop(x)
- x = self.fc2(x)
- x = self.drop(x)
- return x
-
-
-class Attention(nn.Module):
- def __init__(self,
- dim,
- num_heads=8,
- qkv_bias=False,
- qk_scale=None,
- attn_drop=0.,
- proj_drop=0.,
- sr_ratio=1):
- super().__init__()
- assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
-
- self.dim = dim
- self.num_heads = num_heads
- head_dim = dim // num_heads
- self.scale = qk_scale or head_dim ** -0.5
-
- self.q = nn.Linear(dim, dim, bias=qkv_bias)
- self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
- self.attn_drop = nn.Dropout(attn_drop)
- self.proj = nn.Linear(dim, dim)
- self.proj_drop = nn.Dropout(proj_drop)
-
- self.sr_ratio = sr_ratio
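-        # spatial-reduction attention: for sr_ratio > 1 the keys/values are computed
-        # from a feature map downsampled by a strided sr_ratio x sr_ratio convolution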
- if sr_ratio > 1:
- self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
- self.norm = LayerNorm(dim)
-
- self.apply(self._init_weights)
-
- def _init_weights(self, m):
- if isinstance(m, nn.Linear):
- trunc_normal_(m.weight, std=.02)
- if isinstance(m, nn.Linear) and m.bias is not None:
- nn.init.constant_(m.bias, 0)
- elif isinstance(m, nn.LayerNorm):
- nn.init.constant_(m.bias, 0)
- nn.init.constant_(m.weight, 1.0)
- elif isinstance(m, nn.Conv2d):
- fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
- fan_out //= m.groups
- m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
- if m.bias is not None:
- m.bias.data.zero_()
-
- def forward(self, x, H, W):
- B, N, C = x.shape
- q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
-
- if self.sr_ratio > 1:
- x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
- x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
- x_ = self.norm(x_)
- kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
- else:
- kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
- k, v = kv[0], kv[1]
-
- attn = (q @ k.transpose(-2, -1)) * self.scale
- attn = attn.softmax(dim=-1)
- attn = self.attn_drop(attn)
-
- x = (attn @ v).transpose(1, 2).reshape(B, N, C)
- x = self.proj(x)
- x = self.proj_drop(x)
-
- return x
-
-
-class Block(nn.Module):
-
- def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
- drop_path=0., act_layer=nn.GELU, norm_layer=LayerNorm, sr_ratio=1):
- super().__init__()
- self.norm1 = norm_layer(dim)
- self.attn = Attention(
- dim,
- num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
- attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio)
- # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
- self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
- self.norm2 = norm_layer(dim)
- mlp_hidden_dim = int(dim * mlp_ratio)
- self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
-
- self.apply(self._init_weights)
-
- def _init_weights(self, m):
- if isinstance(m, nn.Linear):
- trunc_normal_(m.weight, std=.02)
- if isinstance(m, nn.Linear) and m.bias is not None:
- nn.init.constant_(m.bias, 0)
- elif isinstance(m, nn.LayerNorm):
- nn.init.constant_(m.bias, 0)
- nn.init.constant_(m.weight, 1.0)
- elif isinstance(m, nn.Conv2d):
- fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
- fan_out //= m.groups
- m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
- if m.bias is not None:
- m.bias.data.zero_()
-
- def forward(self, x, H, W):
- x = x + self.drop_path(self.attn(self.norm1(x), H, W))
- x = x + self.drop_path(self.mlp(self.norm2(x), H, W))
-
- return x
-
-
-class OverlapPatchEmbed(nn.Module):
- """ Image to Patch Embedding
- """
-
- def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768):
- super().__init__()
- img_size = (img_size, img_size)
- patch_size = (patch_size, patch_size)
-
- self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
- padding=(patch_size[0] // 2, patch_size[1] // 2))
- self.norm = LayerNorm(embed_dim)
-
- self.apply(self._init_weights)
-
- def _init_weights(self, m):
- if isinstance(m, nn.Linear):
- trunc_normal_(m.weight, std=.02)
- if isinstance(m, nn.Linear) and m.bias is not None:
- nn.init.constant_(m.bias, 0)
- elif isinstance(m, nn.LayerNorm):
- nn.init.constant_(m.bias, 0)
- nn.init.constant_(m.weight, 1.0)
- elif isinstance(m, nn.Conv2d):
- fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
- fan_out //= m.groups
- m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
- if m.bias is not None:
- m.bias.data.zero_()
-
- def forward(self, x):
- x = self.proj(x)
- _, _, H, W = x.shape
- x = x.flatten(2).transpose(1, 2)
- x = self.norm(x)
-
- return x, H, W
-
-
-class MixVisionTransformer(nn.Module):
- def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512],
- num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0.,
- attn_drop_rate=0., drop_path_rate=0., norm_layer=LayerNorm,
- depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], output_avg=False):
- super().__init__()
- self.num_classes = num_classes
- self.depths = depths
- self.output_avg = output_avg
-
- # patch_embed
- self.patch_embed1 = OverlapPatchEmbed(img_size=img_size, patch_size=7, stride=4, in_chans=in_chans,
- embed_dim=embed_dims[0])
- self.patch_embed2 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0],
- embed_dim=embed_dims[1])
- self.patch_embed3 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[1],
- embed_dim=embed_dims[2])
- self.patch_embed4 = OverlapPatchEmbed(img_size=img_size // 16, patch_size=3, stride=2, in_chans=embed_dims[2],
- embed_dim=embed_dims[3])
-
- # transformer encoder
- dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
- cur = 0
- self.block1 = nn.ModuleList([Block(
- dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale,
- drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
- sr_ratio=sr_ratios[0])
- for i in range(depths[0])])
- self.norm1 = norm_layer(embed_dims[0])
-
- cur += depths[0]
- self.block2 = nn.ModuleList([Block(
- dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale,
- drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
- sr_ratio=sr_ratios[1])
- for i in range(depths[1])])
- self.norm2 = norm_layer(embed_dims[1])
-
- cur += depths[1]
- self.block3 = nn.ModuleList([Block(
- dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale,
- drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
- sr_ratio=sr_ratios[2])
- for i in range(depths[2])])
- self.norm3 = norm_layer(embed_dims[2])
-
- cur += depths[2]
- self.block4 = nn.ModuleList([Block(
- dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale,
- drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
- sr_ratio=sr_ratios[3])
- for i in range(depths[3])])
- self.norm4 = norm_layer(embed_dims[3])
-
- self.apply(self._init_weights)
-
- def _init_weights(self, m):
- if isinstance(m, nn.Linear):
- trunc_normal_(m.weight, std=.02)
- if isinstance(m, nn.Linear) and m.bias is not None:
- nn.init.constant_(m.bias, 0)
- elif isinstance(m, nn.LayerNorm):
- nn.init.constant_(m.bias, 0)
- nn.init.constant_(m.weight, 1.0)
- elif isinstance(m, nn.Conv2d):
- fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
- fan_out //= m.groups
- m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
- if m.bias is not None:
- m.bias.data.zero_()
-
- def reset_drop_path(self, drop_path_rate):
- dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))]
- cur = 0
- for i in range(self.depths[0]):
- self.block1[i].drop_path.drop_prob = dpr[cur + i]
-
- cur += self.depths[0]
- for i in range(self.depths[1]):
- self.block2[i].drop_path.drop_prob = dpr[cur + i]
-
- cur += self.depths[1]
- for i in range(self.depths[2]):
- self.block3[i].drop_path.drop_prob = dpr[cur + i]
-
- cur += self.depths[2]
- for i in range(self.depths[3]):
- self.block4[i].drop_path.drop_prob = dpr[cur + i]
-
- def freeze_patch_emb(self):
- self.patch_embed1.requires_grad = False
-
- def forward_features(self, x):
- B = x.shape[0]
- outs = []
-
- # stage 1
- x, H, W = self.patch_embed1(x)
- for i, blk in enumerate(self.block1):
- x = blk(x, H, W)
- x = self.norm1(x)
- x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
- outs.append(x)
-
- # stage 2
- x, H, W = self.patch_embed2(x)
- for i, blk in enumerate(self.block2):
- x = blk(x, H, W)
- x = self.norm2(x)
- x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
- outs.append(x)
-
- # stage 3
- x, H, W = self.patch_embed3(x)
- for i, blk in enumerate(self.block3):
- x = blk(x, H, W)
- x = self.norm3(x)
- x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
- outs.append(x)
-
- # stage 4
- x, H, W = self.patch_embed4(x)
- for i, blk in enumerate(self.block4):
- x = blk(x, H, W)
- x = self.norm4(x)
- if not self.output_avg:
- x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
- outs.append(x)
-
- return outs
-
- def forward(self, x):
- x = self.forward_features(x)
-
- if self.output_avg:
- x = x[3].mean(dim=1)
-
- return x
-
-
-class DWConv(nn.Module):
- def __init__(self, dim=768):
- super(DWConv, self).__init__()
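-        # depthwise 3x3 convolution (groups=dim) that injects local spatial context
-        # between the two linear layers of the MLP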
- self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)
-
- def forward(self, x, H, W):
- B, N, C = x.shape
- x = x.transpose(1, 2).view(B, C, H, W)
- x = self.dwconv(x)
- x = x.flatten(2).transpose(1, 2)
-
- return x
-
-class mit_b0(MixVisionTransformer):
- def __init__(self, **kwargs):
- super(mit_b0, self).__init__(
- patch_size=4, embed_dims=[32, 64, 160, 256], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
- qkv_bias=True, norm_layer=partial(LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1],
- drop_rate=0.0, drop_path_rate=0.1)
-
-
-class mit_b1(MixVisionTransformer):
- def __init__(self, **kwargs):
- super(mit_b1, self).__init__(
- patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
- qkv_bias=True, norm_layer=partial(LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1],
- drop_rate=0.0, drop_path_rate=0.1)
-
-
-class mit_b2(MixVisionTransformer):
- def __init__(self, **kwargs):
- super(mit_b2, self).__init__(
- patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
- qkv_bias=True, norm_layer=partial(LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1],
- drop_rate=0.0, drop_path_rate=0.1)
-
-
-class mit_b3(MixVisionTransformer):
- def __init__(self, **kwargs):
- super(mit_b3, self).__init__(
- patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
- qkv_bias=True, norm_layer=partial(LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1],
- drop_rate=0.0, drop_path_rate=0.1)
-
-class mit_b3_avg(MixVisionTransformer):
- def __init__(self, drop_path_rate=0.1, **kwargs):
- super(mit_b3_avg, self).__init__(
- patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
- qkv_bias=True, norm_layer=partial(LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1],
- drop_rate=0.0, drop_path_rate=drop_path_rate, output_avg=True)
-
-class mit_b4(MixVisionTransformer):
- def __init__(self, **kwargs):
- super(mit_b4, self).__init__(
- patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
- qkv_bias=True, norm_layer=partial(LayerNorm, eps=1e-6), depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1],
- drop_rate=0.0, drop_path_rate=0.1)
-
-class mit_b5(MixVisionTransformer):
- def __init__(self, **kwargs):
- super(mit_b5, self).__init__(
- patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
- qkv_bias=True, norm_layer=partial(LayerNorm, eps=1e-6), depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1],
- drop_rate=0.0, drop_path_rate=0.1)
-
-class mit_b5_avg(MixVisionTransformer):
- def __init__(self, drop_path_rate=0.1, **kwargs):
- super(mit_b5_avg, self).__init__(
- patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
- qkv_bias=True, norm_layer=partial(LayerNorm, eps=1e-6), depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1],
- drop_rate=0.0, drop_path_rate=drop_path_rate, output_avg=True)
-
diff --git a/megatron/model/vision/swin_backbone.py b/megatron/model/vision/swin_backbone.py
deleted file mode 100644
index 9a622c7070f5a8335b56ce01398e3e464e3fef6a..0000000000000000000000000000000000000000
--- a/megatron/model/vision/swin_backbone.py
+++ /dev/null
@@ -1,625 +0,0 @@
-# Copyright (c) 2021 Microsoft
-#
-# This source code is licensed under the MIT license found in the
-# LICENSE file in the root directory of this source tree.
-# --------------------------------------------------------
-# Swin Transformer
-# --------------------------------------------------------
-
-import torch
-import torch.nn as nn
-import torch.utils.checkpoint as checkpoint
-from timm.models.layers import DropPath, to_2tuple, trunc_normal_
-from math import sqrt
-
-from megatron import get_args
-from functools import partial
-
-
-class Mlp(nn.Module):
- def __init__(self, in_features, hidden_features=None,
- out_features=None, act_layer=nn.GELU, drop=0.):
- super().__init__()
- out_features = out_features or in_features
- hidden_features = hidden_features or in_features
- self.fc1 = nn.Linear(in_features, hidden_features)
- self.act = act_layer()
- self.fc2 = nn.Linear(hidden_features, out_features)
- self.drop = nn.Dropout(drop)
-
- def forward(self, x):
- x = self.fc1(x)
- x = self.act(x)
- x = self.drop(x)
- x = self.fc2(x)
- x = self.drop(x)
- return x
-
-
-def window_partition(x, window_size):
- """
- Args:
- x: (B, H, W, C)
- window_size (int): window size
-
- Returns:
- windows: (num_windows*B, window_size, window_size, C)
- """
- B, H, W, C = x.shape
- x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
- windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
- return windows
-
-
-def window_reverse(windows, window_size, H, W):
- """
- Args:
- windows: (num_windows*B, window_size, window_size, C)
- window_size (int): Window size
- H (int): Height of image
- W (int): Width of image
-
- Returns:
- x: (B, H, W, C)
- """
- B = int(windows.shape[0] / (H * W / window_size / window_size))
- x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
- x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
- return x
-
-
-class WindowAttention(nn.Module):
- r""" Window based multi-head self attention (W-MSA) module with relative position bias.
- It supports both of shifted and non-shifted window.
-
- Args:
- dim (int): Number of input channels.
- window_size (tuple[int]): The height and width of the window.
- num_heads (int): Number of attention heads.
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
- qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
- attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
- proj_drop (float, optional): Dropout ratio of output. Default: 0.0
- """
-
- def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
-
- super().__init__()
- self.dim = dim
- self.window_size = window_size # Wh, Ww
- self.num_heads = num_heads
- head_dim = dim // num_heads
- self.scale = qk_scale or head_dim ** -0.5
-
- # define a parameter table of relative position bias
- self.relative_position_bias_table = nn.Parameter(
- torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
-
- # get pair-wise relative position index for each token inside the window
- coords_h = torch.arange(self.window_size[0])
- coords_w = torch.arange(self.window_size[1])
- coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
- coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
- relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
- relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
- relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
- relative_coords[:, :, 1] += self.window_size[1] - 1
- relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
- relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
- self.register_buffer("relative_position_index", relative_position_index)
-
- self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
- self.attn_drop = nn.Dropout(attn_drop)
- self.proj = nn.Linear(dim, dim)
- self.proj_drop = nn.Dropout(proj_drop)
-
- trunc_normal_(self.relative_position_bias_table, std=.02)
- self.softmax = nn.Softmax(dim=-1)
-
- def forward(self, x, mask=None):
- """
- Args:
- x: input features with shape of (num_windows*B, N, C)
- mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
- """
- B_, N, C = x.shape
- qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
- q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
-
- q = q * self.scale
- attn = (q @ k.transpose(-2, -1))
-
- relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
- self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
- relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
- attn = attn + relative_position_bias.unsqueeze(0)
-
- if mask is not None:
- nW = mask.shape[0]
- attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
- attn = attn.view(-1, self.num_heads, N, N)
- attn = self.softmax(attn)
- else:
- attn = self.softmax(attn)
-
- attn = self.attn_drop(attn)
-
- x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
- x = self.proj(x)
- x = self.proj_drop(x)
- return x
-
- def extra_repr(self) -> str:
- return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
-
- def flops(self, N):
- # calculate flops for 1 window with token length of N
- flops = 0
- # qkv = self.qkv(x)
- flops += N * self.dim * 3 * self.dim
- # attn = (q @ k.transpose(-2, -1))
- flops += self.num_heads * N * (self.dim // self.num_heads) * N
- # x = (attn @ v)
- flops += self.num_heads * N * N * (self.dim // self.num_heads)
- # x = self.proj(x)
- flops += N * self.dim * self.dim
- return flops
-
-
-class SwinTransformerBlock(nn.Module):
- r""" Swin Transformer Block.
-
- Args:
- dim (int): Number of input channels.
-        input_resolution (tuple[int]): Input resolution.
- num_heads (int): Number of attention heads.
- window_size (int): Window size.
- shift_size (int): Shift size for SW-MSA.
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
- qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
- drop (float, optional): Dropout rate. Default: 0.0
- attn_drop (float, optional): Attention dropout rate. Default: 0.0
- drop_path (float, optional): Stochastic depth rate. Default: 0.0
- act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
- """
-
- def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
- mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
- act_layer=nn.GELU, norm_layer=nn.LayerNorm):
- super().__init__()
- self.dim = dim
- self.input_resolution = input_resolution
- self.num_heads = num_heads
- self.window_size = window_size
- self.shift_size = shift_size
- self.mlp_ratio = mlp_ratio
- if min(self.input_resolution) <= self.window_size:
- # if window size is larger than input resolution, we don't partition windows
- self.shift_size = 0
- self.window_size = min(self.input_resolution)
-        assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
-
- self.norm1 = norm_layer(dim)
- self.attn = WindowAttention(
- dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
- qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
-
- self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
- self.norm2 = norm_layer(dim)
- mlp_hidden_dim = int(dim * mlp_ratio)
- self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
-
- self.H = input_resolution[0]
- self.W = input_resolution[1]
-
- self.attn_mask_dict = {}
-
- def create_attn_mask(self, H, W):
- # calculate attention mask for SW-MSA
-
-        Hp = ceil(H / self.window_size) * self.window_size
-        Wp = ceil(W / self.window_size) * self.window_size
- img_mask = torch.zeros((1, Hp, Wp, 1)) # 1 Hp Wp 1
- h_slices = (slice(0, -self.window_size),
- slice(-self.window_size, -self.shift_size),
- slice(-self.shift_size, None))
- w_slices = (slice(0, -self.window_size),
- slice(-self.window_size, -self.shift_size),
- slice(-self.shift_size, None))
- cnt = 0
- for h in h_slices:
- for w in w_slices:
- img_mask[:, h, w, :] = cnt
- cnt += 1
-
- mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
- mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
- attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
- attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
-
- return attn_mask
-
-
- def forward(self, x):
- B, L, C = x.shape
- H = int(sqrt(L))
- W = H
-
- shortcut = x
- x = self.norm1(x)
- x = x.view(B, H, W, C)
-
-        # cyclic shift; the shifted-window attention mask is cached per resolution
-        if self.shift_size > 0:
-            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
-            if H not in self.attn_mask_dict:
-                self.attn_mask_dict[H] = self.create_attn_mask(H, W).to(x.device)
-            attn_mask = self.attn_mask_dict[H]
-        else:
-            shifted_x, attn_mask = x, None
-
-        # partition windows and apply W-MSA/SW-MSA
-        x_windows = window_partition(shifted_x, self.window_size).view(-1, self.window_size * self.window_size, C)
-        attn_windows = self.attn(x_windows, mask=attn_mask)  # nW*B, window_size*window_size, C
-
- # merge windows
- attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
- shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
-
- # reverse cyclic shift
- if self.shift_size > 0:
- x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
- else:
- x = shifted_x
- x = x.view(B, H * W, C)
-
- # FFN
- x = shortcut + self.drop_path(x)
- x = x + self.drop_path(self.mlp(self.norm2(x)))
-
- return x
-
- def extra_repr(self) -> str:
- return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
- f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
-
- def flops(self):
- flops = 0
- H, W = self.input_resolution
- # norm1
- flops += self.dim * H * W
- # W-MSA/SW-MSA
- nW = H * W / self.window_size / self.window_size
- flops += nW * self.attn.flops(self.window_size * self.window_size)
- # mlp
- flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
- # norm2
- flops += self.dim * H * W
- return flops
-
-
-class PatchMerging(nn.Module):
- r""" Patch Merging Layer.
-
- Args:
- input_resolution (tuple[int]): Resolution of input feature.
- dim (int): Number of input channels.
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
- """
-
- def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
- super().__init__()
- self.input_resolution = input_resolution
- self.dim = dim
- self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
- self.norm = norm_layer(4 * dim)
-
- def forward(self, x):
- """
- x: B, H*W, C
- """
- H, W = self.input_resolution
- B, L, C = x.shape
- assert L == H * W, "input feature has wrong size"
-        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) is not even."
-
- x = x.view(B, H, W, C)
-
- x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
- x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
- x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
- x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
- x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
- x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
-
- x = self.norm(x)
- x = self.reduction(x)
-
- return x
-
- def extra_repr(self) -> str:
- return f"input_resolution={self.input_resolution}, dim={self.dim}"
-
- def flops(self):
- H, W = self.input_resolution
- flops = H * W * self.dim
- flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
- return flops
-
-
-class BasicLayer(nn.Module):
- """ A basic Swin Transformer layer for one stage.
-
- Args:
- dim (int): Number of input channels.
- input_resolution (tuple[int]): Input resolution.
- depth (int): Number of blocks.
- num_heads (int): Number of attention heads.
- window_size (int): Local window size.
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
- qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
- drop (float, optional): Dropout rate. Default: 0.0
- attn_drop (float, optional): Attention dropout rate. Default: 0.0
- drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
- downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
- use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
- """
-
- def __init__(self, dim, input_resolution, depth, num_heads, window_size,
- mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
- drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
-
- super().__init__()
- self.dim = dim
- self.input_resolution = input_resolution
- self.depth = depth
- self.use_checkpoint = use_checkpoint
-
- # build blocks
- self.blocks = nn.ModuleList([
- SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
- num_heads=num_heads, window_size=window_size,
- shift_size=0 if (i % 2 == 0) else window_size // 2,
- mlp_ratio=mlp_ratio,
- qkv_bias=qkv_bias, qk_scale=qk_scale,
- drop=drop, attn_drop=attn_drop,
- drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
- norm_layer=norm_layer)
- for i in range(depth)])
-
- # patch merging layer
- if downsample is not None:
- self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
- else:
- self.downsample = None
-
- def forward(self, x):
- for blk in self.blocks:
- if self.use_checkpoint:
- x = checkpoint.checkpoint(blk, x)
- else:
- x = blk(x)
- x_b4_ds = x
- if self.downsample is not None:
- x = self.downsample(x)
- return x_b4_ds, x
-
- def extra_repr(self) -> str:
- return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
-
- def flops(self):
- flops = 0
- for blk in self.blocks:
- flops += blk.flops()
- if self.downsample is not None:
- flops += self.downsample.flops()
- return flops
-
-
-class PatchEmbed(nn.Module):
- r""" Image to Patch Embedding
-
- Args:
- img_size (int): Image size. Default: 224.
- patch_size (int): Patch token size. Default: 4.
- in_chans (int): Number of input image channels. Default: 3.
- embed_dim (int): Number of linear projection output channels. Default: 96.
- norm_layer (nn.Module, optional): Normalization layer. Default: None
- """
-
- def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
- super().__init__()
- img_size = to_2tuple(img_size)
- patch_size = to_2tuple(patch_size)
- patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
- self.img_size = img_size
- self.patch_size = patch_size
- self.patches_resolution = patches_resolution
- self.num_patches = patches_resolution[0] * patches_resolution[1]
-
- self.in_chans = in_chans
- self.embed_dim = embed_dim
-
- self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
- if norm_layer is not None:
- self.norm = norm_layer(embed_dim)
- else:
- self.norm = None
-
- def forward(self, x):
- B, C, H, W = x.shape
- # FIXME look at relaxing size constraints
- assert H == self.img_size[0] and W == self.img_size[1], \
- f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
- x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
- if self.norm is not None:
- x = self.norm(x)
- return x
-
- def flops(self):
- Ho, Wo = self.patches_resolution
- flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
- if self.norm is not None:
- flops += Ho * Wo * self.embed_dim
- return flops
-
-
-class SwinTransformer(nn.Module):
- r""" Swin Transformer
- A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
- https://arxiv.org/pdf/2103.14030
-
- Args:
- img_size (int | tuple(int)): Input image size. Default 224
- patch_size (int | tuple(int)): Patch size. Default: 4
- in_chans (int): Number of input image channels. Default: 3
- embed_dim (int): Patch embedding dimension. Default: 96
- depths (tuple(int)): Depth of each Swin Transformer layer.
- num_heads (tuple(int)): Number of attention heads in different layers.
- window_size (int): Window size. Default: 7
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
- qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
- qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
- drop_rate (float): Dropout rate. Default: 0
- attn_drop_rate (float): Attention dropout rate. Default: 0
-        drop_path_rate (float): Stochastic depth rate. Default: 0.3
- norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
- ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
- patch_norm (bool): If True, add normalization after patch embedding. Default: True
- use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
- """
-
- def __init__(self, img_size=224, patch_size=4, in_chans=3,
- embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
- window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
- drop_rate=0., attn_drop_rate=0., drop_path_rate=0.3,
- norm_layer=partial(nn.LayerNorm, eps=1e-6), ape=False, patch_norm=True,
- use_checkpoint=False, output_avg=False, **kwargs):
- super().__init__()
-
- self.num_layers = len(depths)
- self.embed_dim = embed_dim
- self.ape = ape
- self.patch_norm = patch_norm
- self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
- self.mlp_ratio = mlp_ratio
- self.img_size = to_2tuple(img_size)
- self.patch_size = to_2tuple(patch_size)
- self.output_avg = output_avg
-
- # split image into non-overlapping patches
- self.patch_embed = PatchEmbed(
- img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
- norm_layer=norm_layer if self.patch_norm else None)
- num_patches = self.patch_embed.num_patches
- patches_resolution = self.patch_embed.patches_resolution
- self.patches_resolution = patches_resolution
-
- # absolute position embedding
- if self.ape:
- self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
- trunc_normal_(self.absolute_pos_embed, std=.02)
-
- self.pos_drop = nn.Dropout(p=drop_rate)
-
- # stochastic depth
- dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
-
- # build layers
- self.layers = nn.ModuleList()
- for i_layer in range(self.num_layers):
- layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
- input_resolution=(patches_resolution[0] // (2 ** i_layer),
- patches_resolution[1] // (2 ** i_layer)),
- depth=depths[i_layer],
- num_heads=num_heads[i_layer],
- window_size=window_size,
- mlp_ratio=self.mlp_ratio,
- qkv_bias=qkv_bias, qk_scale=qk_scale,
- drop=drop_rate, attn_drop=attn_drop_rate,
- drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
- norm_layer=norm_layer,
- downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
- use_checkpoint=use_checkpoint)
- self.layers.append(layer)
-
- self.apply(self._init_weights)
-
- def _init_weights(self, m):
- if isinstance(m, nn.Linear):
- trunc_normal_(m.weight, std=.02)
- if isinstance(m, nn.Linear) and m.bias is not None:
- nn.init.constant_(m.bias, 0)
- elif isinstance(m, nn.LayerNorm):
- nn.init.constant_(m.bias, 0)
- nn.init.constant_(m.weight, 1.0)
-
- @torch.jit.ignore
- def no_weight_decay(self):
- return {'absolute_pos_embed'}
-
- @torch.jit.ignore
- def no_weight_decay_keywords(self):
- return {'relative_position_bias_table'}
-
- def forward(self, x):
- x = self.patch_embed(x)
- if self.ape:
- x = x + self.absolute_pos_embed
- x = self.pos_drop(x)
-
- h = self.img_size[0] // self.patch_size[0]
- w = self.img_size[1] // self.patch_size[1]
- outs = []
-
- for i, layer in enumerate(self.layers):
- px, x = layer(x)
- b, n, c = px.shape
-
- if i != len(self.layers) - 1 or not self.output_avg:
- px = px.permute(0, 2, 1).contiguous()
- px = px.reshape(b, c, h, w)
-                # patch merging halves the spatial resolution after every stage except the last
-                h, w = h // 2, w // 2
- outs.append(px)
-
- if self.output_avg:
- return outs[-1].mean(dim=1)
-
- return outs
-
- def flops(self):
- flops = 0
- flops += self.patch_embed.flops()
- for i, layer in enumerate(self.layers):
- flops += layer.flops()
- flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
-        # this backbone has no classification head, so no head FLOPs are added
- return flops
-
-
-def get_swin(drop_path_rate=0.3, output_avg=False):
- args = get_args()
-
- window_size = 7
- embed_dim = 128
- depths = [2, 2, 18, 2]
- num_heads = [4, 8, 16, 32]
- swin = SwinTransformer(
- img_size=(args.img_h, args.img_w,),
- in_chans=3,
- patch_size=args.patch_dim,
- embed_dim=embed_dim,
- depths=depths,
- num_heads=num_heads,
- window_size=window_size,
- drop_path_rate=drop_path_rate,
- output_avg=output_avg,
- )
-
- return swin
-
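
The swin_backbone.py removed above is organized around window attention: feature maps are cut into non-overlapping windows by window_partition, attention runs inside each window (with a cyclic shift every other block), and window_reverse stitches the windows back together. A minimal standalone sketch of that partition/reverse round trip, mirroring the deleted helpers and runnable outside Megatron (tensor sizes are illustrative):

    import torch

    def window_partition(x, window_size):
        # (B, H, W, C) -> (num_windows * B, window_size, window_size, C)
        B, H, W, C = x.shape
        x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
        return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)

    def window_reverse(windows, window_size, H, W):
        # exact inverse of window_partition
        B = int(windows.shape[0] / (H * W / window_size / window_size))
        x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
        return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)

    x = torch.randn(2, 56, 56, 96)                     # B, H, W, C
    windows = window_partition(x, 7)                   # (2 * 64, 7, 7, 96)
    assert torch.equal(window_reverse(windows, 7, 56, 56), x)
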
diff --git a/megatron/model/vision/utils.py b/megatron/model/vision/utils.py
deleted file mode 100644
index b4068912c8bb234eff54d6b4feae499f7e8ab30c..0000000000000000000000000000000000000000
--- a/megatron/model/vision/utils.py
+++ /dev/null
@@ -1,27 +0,0 @@
-import warnings
-import torch
-import torch.nn.functional as F
-
-
-def resize(input,
- size=None,
- scale_factor=None,
- mode='nearest',
- align_corners=None,
- warning=True):
- if warning:
- if size is not None and align_corners:
- input_h, input_w = tuple(int(x) for x in input.shape[2:])
- output_h, output_w = tuple(int(x) for x in size)
-            if output_h > input_h or output_w > input_w:
- if ((output_h > 1 and output_w > 1 and input_h > 1
- and input_w > 1) and (output_h - 1) % (input_h - 1)
- and (output_w - 1) % (input_w - 1)):
- warnings.warn(
- f'When align_corners={align_corners}, '
-                        'the output would be more aligned if '
- f'input size {(input_h, input_w)} is `x+1` and '
- f'out size {(output_h, output_w)} is `nx+1`')
- if isinstance(size, torch.Size):
- size = tuple(int(x) for x in size)
- return F.interpolate(input, size, scale_factor, mode, align_corners)
diff --git a/megatron/model/vision/vit_backbone.py b/megatron/model/vision/vit_backbone.py
deleted file mode 100644
index 15cf75affcd7cd111d5f2663d707a0026ff708a3..0000000000000000000000000000000000000000
--- a/megatron/model/vision/vit_backbone.py
+++ /dev/null
@@ -1,248 +0,0 @@
-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-
-"""Vision Transformer(VIT) model."""
-
-import math
-import einops
-import torch
-import apex
-import torch.nn.functional as F
-from megatron import get_args
-from megatron.model.transformer import ParallelTransformer
-from megatron.model.utils import (
- get_linear_layer,
- init_method_normal,
- scaled_init_method_normal,
-)
-from megatron.model.module import MegatronModule
-
-CLASS_TOKEN_LENGTH = 8
-
-class VitMlpHead(MegatronModule):
- """Pooler layer.
-
- Pool hidden states of a specific token (for example start of the
-    sequence) and apply a linear layer, a tanh, and a final linear projection.
-
-    Arguments:
-        config: transformer config
-        hidden_size: hidden size
-        num_classes: number of output classes
- """
-
- def __init__(self, config, hidden_size, num_classes):
- super(VitMlpHead, self).__init__()
- self.config = config
- self.dense_in = torch.nn.Linear(hidden_size, hidden_size)
- self.relu = torch.nn.ReLU()
- self.dense_out = torch.nn.Linear(hidden_size, num_classes)
- torch.nn.init.constant_(self.dense_out.bias, -10)
-
- def forward(self, hidden_states):
-        # hidden_states: [b, 1, h]
-        # two-layer head: dense_in -> tanh -> dense_out
- dense_in_result = self.dense_in(hidden_states)
- tanh_result = torch.tanh(dense_in_result)
- dense_out_result = self.dense_out(tanh_result)
- return dense_out_result
-
-
-def isPerfectSquare(x):
- if(x >= 0):
- sr = math.sqrt(x)
- return (int(sr) * int(sr) == x)
- return False
-
-
-def twod_interpolate_position_embeddings_hook(
- state_dict,
- prefix,
- local_metadata,
- strict,
- missing_keys,
- unexpected_keys,
- error_msgs,
-):
-
- args = get_args()
- num_patches_per_dim_h = args.img_h // args.patch_dim
- num_patches_per_dim_w = args.img_w // args.patch_dim
- num_patches = num_patches_per_dim_h * num_patches_per_dim_w
- hidden_size = args.hidden_size
-
- key = prefix + "weight"
-
- assert key in state_dict
- if key in state_dict:
- input_param = state_dict[key]
-
- input_seq_len = input_param.shape[0]
- assert(isPerfectSquare(input_seq_len) or isPerfectSquare(input_seq_len - CLASS_TOKEN_LENGTH))
- input_has_class_token = not isPerfectSquare(input_seq_len)
- num_tok_input = input_seq_len - CLASS_TOKEN_LENGTH if input_has_class_token else input_seq_len
- num_tok_output = num_patches
- output_has_class_token = args.class_token_present
-
- # update input_param and load it to state_dict[key]
- if input_has_class_token:
- input_param_tok = input_param[:CLASS_TOKEN_LENGTH, :]
- input_param_grid = input_param[CLASS_TOKEN_LENGTH:, :]
- else:
- input_param_tok = torch.zeros(CLASS_TOKEN_LENGTH, hidden_size)
- input_param_grid = input_param
-
- assert input_param.shape[1] == hidden_size
-
- if num_tok_input != num_tok_output:
-
- gs_input = int(math.sqrt(num_tok_input))
- gs_new = (num_patches_per_dim_h, num_patches_per_dim_w)
-
- input_param_grid = input_param_grid.transpose(0, 1).contiguous()
- input_param_grid = input_param_grid.reshape(
- (1, -1, gs_input, gs_input)
- )
- input_param_grid = input_param_grid.float()
- scale_factor = (gs_new[0] / gs_input, gs_new[1] / gs_input)
-
- input_param_grid = F.interpolate(
- input_param_grid, scale_factor=scale_factor, mode="bilinear"
- )
-
- input_param_grid = input_param_grid.half()
- input_param_grid = input_param_grid.reshape((-1, num_tok_output))
- input_param_grid = input_param_grid.transpose(0, 1).contiguous()
-
- assert input_param_grid.shape[1] == hidden_size
-
- input_param = input_param_grid
- assert (
- input_param.shape[0] == num_tok_output
- and input_param.shape[1] == hidden_size
- )
-
- if output_has_class_token:
- input_param = torch.cat((input_param_tok, input_param), dim=0)
-
- state_dict[key] = input_param
-
-
-class VitBackbone(MegatronModule):
- """Vision Transformer Model."""
-
- def __init__(self,
- config,
- pre_process=True,
- post_process=True,
- class_token=True,
- single_token_output=False,
- post_layer_norm=True,
- drop_path_rate=0.0):
- super(VitBackbone, self).__init__(share_embeddings_and_output_weights=False)
- args = get_args()
- self.config = config
-
- self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
-
- self.pre_process = pre_process
- self.post_process = post_process
- self.class_token = class_token
- self.post_layer_norm = post_layer_norm
- self.hidden_size = args.hidden_size
- self.patch_dim = args.patch_dim
- self.img_h = args.img_h
- self.img_w = args.img_w
- self.micro_batch_size = args.micro_batch_size
- self.single_token_output = single_token_output
- self.drop_path_rate = drop_path_rate
-
- assert self.img_h % self.patch_dim == 0
- assert self.img_w % self.patch_dim == 0
- self.num_patches_per_dim_h = self.img_h // self.patch_dim
- self.num_patches_per_dim_w = self.img_w // self.patch_dim
- self.num_patches = self.num_patches_per_dim_h * self.num_patches_per_dim_w
- self.seq_length = self.num_patches + (CLASS_TOKEN_LENGTH if self.class_token else 0)
- self.flatten_dim = self.patch_dim * self.patch_dim * args.num_channels
- self.input_tensor = None
- self.position_ids = None
-
- if self.pre_process:
- # cls_token
- if self.class_token:
- self.cls_token = torch.nn.Parameter(
- torch.randn(1, CLASS_TOKEN_LENGTH, self.hidden_size)
- )
- torch.nn.init.zeros_(self.cls_token)
- self.position_ids = torch.arange(self.seq_length).expand(1, -1).cuda()
-
- # Linear encoder
- self.linear_encoder = torch.nn.Linear(
- self.flatten_dim, self.hidden_size
- )
-
- # embedding
- self.position_embeddings = torch.nn.Embedding(
- self.seq_length, self.hidden_size
- )
- init_method_normal(args.init_method_std)(
- self.position_embeddings.weight
- )
-
- args.class_token_present = self.class_token
- self.position_embeddings._register_load_state_dict_pre_hook(
- twod_interpolate_position_embeddings_hook
- )
-
- self.embedding_dropout = torch.nn.Dropout(args.hidden_dropout)
-
- # Transformer
- self.transformer = ParallelTransformer(
- config,
- model_type=args.model_type,
- pre_process=self.pre_process,
- post_process=self.post_process,
- post_layer_norm=self.post_layer_norm,
- drop_path_rate=self.drop_path_rate
- )
-
- def set_input_tensor(self, input_tensor):
- """See megatron.model.transformer.set_input_tensor()"""
- self.transformer.set_input_tensor(input_tensor)
-
- def forward(self, input):
-
- if self.pre_process:
- rearranged_input = einops.rearrange(
- input,
- "b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
- p1=self.patch_dim,
- p2=self.patch_dim,
- )
-
- assert rearranged_input.dtype == torch.half
- encoder_output = self.linear_encoder(rearranged_input)
-
- concatenated_tokens = encoder_output
- if self.class_token:
- cls_tokens = self.cls_token.expand(encoder_output.shape[0], -1, -1)
- concatenated_tokens = torch.cat((cls_tokens, encoder_output), dim=1)
-
- token_embeddings = concatenated_tokens + \
- self.position_embeddings(self.position_ids[:, :concatenated_tokens.shape[1]])
- # [b, s, h] => [s, b, h]
- token_embeddings = token_embeddings.transpose(0, 1).contiguous()
- hidden_states = self.embedding_dropout(token_embeddings)
- else:
- hidden_states = input
-
- hidden_states = self.transformer(hidden_states, None)
-
- if self.post_process:
- # [s b h] => [b s h]
- if self.single_token_output:
- hidden_states = hidden_states[0]
- else:
- hidden_states = hidden_states.transpose(0, 1).contiguous()
-
- return hidden_states
-
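
The twod_interpolate_position_embeddings_hook in the vit_backbone.py removed above resizes a pretrained position-embedding table when the patch grid changes: the grid portion is reshaped to 2D, bilinearly interpolated to the new resolution, and flattened back before being loaded into the state dict. A self-contained sketch of just that resizing step (the helper name and shapes are illustrative, not Megatron API):

    import math
    import torch
    import torch.nn.functional as F

    def resize_pos_embed(pos_embed, new_hw):
        # pos_embed: (num_tokens, hidden) for a square patch grid; returns (new_h * new_w, hidden)
        num_tokens, hidden = pos_embed.shape
        gs = int(math.sqrt(num_tokens))
        grid = pos_embed.transpose(0, 1).reshape(1, hidden, gs, gs)
        grid = F.interpolate(grid, size=new_hw, mode="bilinear", align_corners=False)
        return grid.reshape(hidden, -1).transpose(0, 1)

    old = torch.randn(14 * 14, 768)             # e.g. a 224px / patch-16 checkpoint
    new = resize_pos_embed(old, (16, 16))       # e.g. retargeted to a 256px input
    print(new.shape)                            # torch.Size([256, 768])
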
diff --git a/megatron/mpu/__init__.py b/megatron/mpu/__init__.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/megatron/mpu/tests/__init__.py b/megatron/mpu/tests/__init__.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/megatron/mpu/tests/commons.py b/megatron/mpu/tests/commons.py
deleted file mode 100644
index 611daf0f66692426ee5ad59824f3c421d7b94a90..0000000000000000000000000000000000000000
--- a/megatron/mpu/tests/commons.py
+++ /dev/null
@@ -1,70 +0,0 @@
-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-
-import argparse
-import os
-import random
-import numpy
-import torch
-
-import mpu
-
-
-class IdentityLayer(torch.nn.Module):
- def __init__(self, size, scale=1.0):
- super(IdentityLayer, self).__init__()
- self.weight = torch.nn.Parameter(scale * torch.randn(size))
-
- def forward(self):
- return self.weight
-
-
-def set_random_seed(seed):
- """Set random seed for reproducability."""
- random.seed(seed)
- numpy.random.seed(seed)
- torch.manual_seed(seed)
- mpu.model_parallel_cuda_manual_seed(seed)
-
-
-def initialize_distributed(backend='nccl'):
- """Initialize torch.distributed."""
- # Get local rank in case it is provided.
- parser = argparse.ArgumentParser()
- parser.add_argument('--local_rank', type=int, default=None,
- help='local rank passed from distributed launcher')
- args = parser.parse_args()
- local_rank = args.local_rank
-
- # Get rank and world size.
- rank = int(os.getenv('RANK', '0'))
- world_size = int(os.getenv("WORLD_SIZE", '1'))
-
- print('> initializing torch.distributed with local rank: {}, '
- 'rank: {}, world size: {}'.format(local_rank, rank, world_size))
-
- # Set the device id.
- device = rank % torch.cuda.device_count()
- if local_rank is not None:
- device = local_rank
- torch.cuda.set_device(device)
-
- # Call the init process.
- init_method = 'tcp://'
- master_ip = os.getenv('MASTER_ADDR', 'localhost')
- master_port = os.getenv('MASTER_PORT', '6000')
- init_method += master_ip + ':' + master_port
- torch.distributed.init_process_group(
- backend=backend,
- world_size=world_size,
- rank=rank,
- init_method=init_method)
-
-
-def print_separator(message):
- torch.distributed.barrier()
- filler_len = (78 - len(message)) // 2
- filler = '-' * filler_len
- string = '\n' + filler + ' {} '.format(message) + filler
- if torch.distributed.get_rank() == 0:
- print(string, flush=True)
- torch.distributed.barrier()
diff --git a/megatron/mpu/tests/test_cross_entropy.py b/megatron/mpu/tests/test_cross_entropy.py
deleted file mode 100644
index 00ae42228a9259e12640034a911899b6386882bc..0000000000000000000000000000000000000000
--- a/megatron/mpu/tests/test_cross_entropy.py
+++ /dev/null
@@ -1,95 +0,0 @@
-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-
-import sys
-sys.path.append("../..")
-
-import random
-import torch
-import torch.nn.functional as F
-
-import mpu
-from mpu.cross_entropy import vocab_parallel_cross_entropy
-from commons import set_random_seed, IdentityLayer
-from commons import print_separator, initialize_distributed
-
-
-def torch_cross_entropy(batch_size, seq_length, vocab_size,
- logits_scale, seed):
- set_random_seed(seed)
- identity = IdentityLayer((batch_size, seq_length, vocab_size),
- scale=logits_scale).cuda()
- logits = identity()
- target = torch.cuda.LongTensor(
- size=(batch_size, seq_length)).random_(0, vocab_size)
- loss = F.cross_entropy(logits.view(-1, logits.size()[-1]),
- target.view(-1),
- reduction='none').view_as(target).mean()
- loss.backward()
- return loss, identity.weight.grad
-
-
-def mpu_cross_entropy(batch_size, seq_length, vocab_size,
- logits_scale, seed):
- set_random_seed(seed)
- identity = IdentityLayer((batch_size, seq_length, vocab_size),
- scale=logits_scale).cuda()
- logits = identity()
- logits_parallel = mpu.scatter_to_tensor_model_parallel_region(logits)
- target = torch.cuda.LongTensor(
- size=(batch_size, seq_length)).random_(0, vocab_size)
- loss = vocab_parallel_cross_entropy(logits_parallel, target).mean()
- loss.backward()
- return loss, identity.weight.grad
-
-
-def test_cross_entropy(tensor_model_parallel_size):
-
- if torch.distributed.get_rank() == 0:
- print('> testing cross entropy with model parallel size {} ...'.
- format(tensor_model_parallel_size))
-
- mpu.initialize_model_parallel(tensor_model_parallel_size)
- tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()
-
- batch_size = 13
- seq_length = 17
- vocab_size_per_partition = 11
- logits_scale = 1000.0
- vocab_size = vocab_size_per_partition * tensor_model_parallel_size
- seed = 1234
-
- loss_torch, grad_torch = torch_cross_entropy(batch_size, seq_length,
- vocab_size, logits_scale,
- seed)
- loss_mpu, grad_mpu = mpu_cross_entropy(batch_size, seq_length,
- vocab_size, logits_scale,
- seed)
-
- error = loss_torch.sub_(loss_mpu).abs().max()
- print(' max error in loss on global rank {}: {}'.format(
- torch.distributed.get_rank(), error))
- assert error < 1.0e-6
-
- error = grad_torch.sub_(grad_mpu).abs().max()
- print(' max error in grad on global rank {}: {}'.format(
- torch.distributed.get_rank(), error))
- assert error < 1.0e-6
-
- # Reset groups
- mpu.destroy_tensor_model_parallel()
-
- torch.distributed.barrier()
- if torch.distributed.get_rank() == 0:
- print('>> passed the test :-)')
-
-
-if __name__ == '__main__':
-
- initialize_distributed()
- world_size = torch.distributed.get_world_size()
-
- tensor_model_parallel_size = 1
- while tensor_model_parallel_size <= world_size:
- print_separator('test cross entropy')
- test_cross_entropy(tensor_model_parallel_size)
- tensor_model_parallel_size *= 2
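
The deleted test builds identical logits under a fixed seed, computes the loss once with torch.nn.functional.cross_entropy and once with the vocab-parallel implementation, and asserts that loss and gradients agree to within 1e-6. The same compare-against-a-reference pattern, condensed to a single process with a hand-written log-softmax standing in for the sharded kernel (illustrative only, not the mpu API):

    import torch
    import torch.nn.functional as F

    torch.manual_seed(1234)
    logits = torch.randn(13, 17, 32)                       # batch, seq, vocab
    target = torch.randint(0, 32, (13, 17))

    reference = F.cross_entropy(logits.view(-1, 32), target.view(-1),
                                reduction='none').view_as(target).mean()

    # manual log-softmax cross entropy, the quantity the parallel kernel reproduces shard by shard
    log_probs = logits - torch.logsumexp(logits, dim=-1, keepdim=True)
    manual = -log_probs.gather(-1, target.unsqueeze(-1)).squeeze(-1).mean()

    assert torch.allclose(reference, manual, atol=1e-6)
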
diff --git a/megatron/mpu/tests/test_data.py b/megatron/mpu/tests/test_data.py
deleted file mode 100644
index c30bf4bb8d4dbb0c2d576d20b18b4ae640d00d2c..0000000000000000000000000000000000000000
--- a/megatron/mpu/tests/test_data.py
+++ /dev/null
@@ -1,75 +0,0 @@
-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-
-import sys
-sys.path.append("../..")
-
-import functools
-import operator
-import torch
-import mpu
-from mpu import data as data_utils
-from commons import print_separator, initialize_distributed
-
-
-def test_broadcast_data(tensor_model_parallel_size):
-
- if torch.distributed.get_rank() == 0:
- print('> testing broadcast_data with model parallel size {} ...'.
- format(tensor_model_parallel_size))
-
- mpu.initialize_model_parallel(tensor_model_parallel_size)
- torch.manual_seed(1234 + mpu.get_data_parallel_rank())
- tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()
-
- key_size_t = {'key1': [7, 11],
- 'key2': [8, 2, 1],
- 'key3': [13],
- 'key4': [5, 1, 2],
- 'key5': [5, 12]}
- keys = list(key_size_t.keys())
-
- data = {}
- data_t = {}
- for key in key_size_t:
- data[key] = torch.LongTensor(size=key_size_t[key]).random_(0, 1000)
- data_t[key] = data[key].clone()
- data['keyX'] = torch.FloatTensor(size=(5, )).random_(0, 1000)
- data_t['keyX'] = data['keyX'].clone()
- if mpu.get_tensor_model_parallel_rank() != 0:
- data = None
-
- data_utils._check_data_types(keys, data_t, torch.int64)
- key_size, key_numel, \
- total_numel = data_utils._build_key_size_numel_dictionaries(keys, data)
- for key in keys:
- assert key_size[key] == key_size_t[key]
- total_numel_t = 0
- for key in keys:
- target_size = functools.reduce(operator.mul, key_size_t[key], 1)
- assert key_numel[key] == target_size
- total_numel_t += target_size
- assert total_numel == total_numel_t
-
- data_b = data_utils.broadcast_data(keys, data, torch.int64)
- for key in keys:
- tensor = data_t[key].cuda()
- assert data_b[key].sub(tensor).abs().max() == 0
-
- # Reset groups
- mpu.destroy_tensor_model_parallel()
-
- torch.distributed.barrier()
- if torch.distributed.get_rank() == 0:
- print('>> passed the test :-)')
-
-
-if __name__ == '__main__':
-
- initialize_distributed()
- world_size = torch.distributed.get_world_size()
-
- tensor_model_parallel_size = 1
- while tensor_model_parallel_size <= world_size:
-        print_separator('test broadcast data')
- test_broadcast_data(tensor_model_parallel_size)
- tensor_model_parallel_size *= 2
diff --git a/megatron/mpu/tests/test_initialize.py b/megatron/mpu/tests/test_initialize.py
deleted file mode 100644
index e5d2be37e269d8176a987b8a6ef5d7f47de98394..0000000000000000000000000000000000000000
--- a/megatron/mpu/tests/test_initialize.py
+++ /dev/null
@@ -1,82 +0,0 @@
-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-
-import sys
-sys.path.append("../..")
-
-import torch
-import mpu
-from commons import print_separator, initialize_distributed
-
-
-def test_initialize_model_parallel(tensor_model_parallel_size):
-
- if torch.distributed.get_rank() == 0:
- print('> testing initialize_model_parallel with size {} ...'.format(
- tensor_model_parallel_size))
- tensor_model_parallel_size_ = min(tensor_model_parallel_size,
- torch.distributed.get_world_size())
- assert not mpu.model_parallel_is_initialized()
- mpu.initialize_model_parallel(tensor_model_parallel_size_)
- assert mpu.model_parallel_is_initialized()
-
- # Checks.
- def check(group, world_size, rank):
- assert world_size == torch.distributed.get_world_size(group=group)
- assert rank == torch.distributed.get_rank(group=group)
-
- # Model parallel.
- world_size = tensor_model_parallel_size_
- rank = torch.distributed.get_rank() % tensor_model_parallel_size_
- assert world_size == mpu.get_tensor_model_parallel_world_size()
- assert rank == mpu.get_tensor_model_parallel_rank()
- check(mpu.get_tensor_model_parallel_group(), world_size, rank)
-
- # Data parallel.
- world_size = torch.distributed.get_world_size() // tensor_model_parallel_size_
-    rank = torch.distributed.get_rank() // tensor_model_parallel_size_
- assert world_size == mpu.get_data_parallel_world_size()
- assert rank == mpu.get_data_parallel_rank()
- check(mpu.get_data_parallel_group(), world_size, rank)
-
- # Reset groups
- mpu.destroy_model_parallel()
-
- torch.distributed.barrier()
- if torch.distributed.get_rank() == 0:
- print('>> passed the test :-)')
-
-
-def test_get_tensor_model_parallel_src_rank(tensor_model_parallel_size_):
-
- if torch.distributed.get_rank() == 0:
- print('> testing get_tensor_model_parallel_src_rank with size {} ...'.format(
- tensor_model_parallel_size_))
- tensor_model_parallel_size = min(tensor_model_parallel_size_,
- torch.distributed.get_world_size())
- assert not mpu.model_parallel_is_initialized()
- mpu.initialize_model_parallel(tensor_model_parallel_size)
- assert mpu.model_parallel_is_initialized()
-
- # Checks
- src_rank = torch.distributed.get_rank() - mpu.get_tensor_model_parallel_rank()
- assert mpu.get_tensor_model_parallel_src_rank() == src_rank
-
- # Reset groups
- mpu.destroy_model_parallel()
-
- torch.distributed.barrier()
- if torch.distributed.get_rank() == 0:
- print('>> passed the test :-)')
-
-
-if __name__ == '__main__':
-
- initialize_distributed()
- world_size = torch.distributed.get_world_size()
- tensor_model_parallel_size = 1
- while tensor_model_parallel_size <= world_size:
- print_separator('test initialize model parallel')
- test_initialize_model_parallel(tensor_model_parallel_size)
- print_separator('test model parallel source rank')
- test_get_tensor_model_parallel_src_rank(tensor_model_parallel_size)
- tensor_model_parallel_size *= 2
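
test_initialize_model_parallel above boils down to a bookkeeping identity: with world size N and tensor-model-parallel size T (T dividing N), global rank r sits at tensor-parallel rank r % T and data-parallel rank r // T, and the data-parallel world size is N // T. A pure-Python sketch of that decomposition, no distributed backend required (the function name is illustrative):

    def decompose(global_rank, world_size, tp_size):
        # mirrors the rank arithmetic asserted in the deleted test
        assert world_size % tp_size == 0
        tp_rank = global_rank % tp_size
        dp_rank = global_rank // tp_size
        return tp_rank, dp_rank, world_size // tp_size

    for rank in range(8):
        print(rank, decompose(rank, 8, 2))    # 4 data-parallel replicas, each split over 2 tensor-parallel ranks
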
diff --git a/megatron/mpu/tests/test_layers.py b/megatron/mpu/tests/test_layers.py
deleted file mode 100644
index 73ad4b9459502dc2f68a8e3d0cb66157895eda1d..0000000000000000000000000000000000000000
--- a/megatron/mpu/tests/test_layers.py
+++ /dev/null
@@ -1,517 +0,0 @@
-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-
-import sys
-sys.path.append("../..")
-
-import random
-import torch
-import torch.nn.init as init
-from torch.nn.parameter import Parameter
-
-import mpu
-from mpu import layers
-from commons import set_random_seed, print_separator, initialize_distributed
-
-
-def test_parallel_embedding(tensor_model_parallel_size):
-
- if torch.distributed.get_rank() == 0:
- print('> testing parallel embedding with model parallel size {} ...'.
- format(tensor_model_parallel_size))
-
- mpu.initialize_model_parallel(tensor_model_parallel_size)
- tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()
-
- batch_size = 17
- seq_length = 23
- vocab_size = 48
- hidden_size = 16
- seed = 1236
-
- set_random_seed(123)
- input_data = torch.LongTensor(
- size=(batch_size, seq_length)).random_(0, vocab_size).cuda()
- loss_weight = torch.randn([batch_size, seq_length, hidden_size]).cuda()
-
- set_random_seed(seed)
- embedding_original = torch.nn.Embedding(vocab_size, hidden_size).cuda()
-
- output = embedding_original(input_data)
- loss_original = torch.mul(output, loss_weight).sum()
- loss_original.backward()
-
- set_random_seed(seed)
- embedding_parallel = layers.ParallelEmbedding(
- vocab_size, hidden_size, init_method=init.normal_).cuda()
- output = embedding_parallel(input_data)
- loss_parallel = torch.mul(output, loss_weight).sum()
- loss_parallel.backward()
-
- set_random_seed(seed)
- embedding_vocab_parallel = layers.VocabParallelEmbedding(
- vocab_size, hidden_size, init_method=init.normal_).cuda()
- output = embedding_vocab_parallel(input_data)
- loss_vocab_parallel = torch.mul(output, loss_weight).sum()
- loss_vocab_parallel.backward()
-
- torch.distributed.barrier()
- error = loss_parallel.sub(loss_original).abs()
- print(' error in loss (parallel) on global rank {}: {}'.format(
- torch.distributed.get_rank(), error))
- assert error < 1.0e-12, 'error: {}'.format(error)
-
- torch.distributed.barrier()
- error = loss_vocab_parallel.sub(loss_original).abs()
- print(' error in loss (vocab parallel) on global rank {}: {}'.format(
- torch.distributed.get_rank(), error))
- assert error < 1.0e-12, 'error: {}'.format(error)
-
- weight_grad_orig = torch.split(embedding_original.weight.grad,
- hidden_size // tensor_model_parallel_size,
- 1)[mpu.get_tensor_model_parallel_rank()]
- error = embedding_parallel.weight.grad.sub(weight_grad_orig).abs().max()
- print(' error in grad (parallel) on global rank {}: {}'.format(
- torch.distributed.get_rank(), error))
- assert error < 1.0e-12, 'error: {}'.format(error)
-
- weight_grad_orig = torch.split(embedding_original.weight.grad,
- vocab_size // tensor_model_parallel_size,
- 0)[mpu.get_tensor_model_parallel_rank()]
- error = embedding_vocab_parallel.weight.grad.sub(
- weight_grad_orig).abs().max()
- print(' error in grad (vocab parallel) on global rank {}: {}'.format(
- torch.distributed.get_rank(), error))
- assert error < 1.0e-12, 'error: {}'.format(error)
-
- # Reset groups
- mpu.destroy_model_parallel()
-
- torch.distributed.barrier()
- if torch.distributed.get_rank() == 0:
- print('>> passed the test :-)')
-
-
-def test_initialize_affine_weight(tensor_model_parallel_size):
-
- mpu.initialize_model_parallel(tensor_model_parallel_size)
- if torch.distributed.get_rank() == 0:
- print('> testing initialize_affine_weight with model parallel '
- 'size: {}'.format(tensor_model_parallel_size))
- tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()
-
- seed = 12345
- input_size_coeff = 13
- input_size = input_size_coeff * tensor_model_parallel_size
- output_size_coeff = 17
- output_size = output_size_coeff * tensor_model_parallel_size
-
- # ---------------
- # Column parallel
- # ---------------
- weight = torch.empty(output_size_coeff, input_size)
- set_random_seed(seed)
-    layers._initialize_affine_weight(weight, output_size,
-                                     input_size,
-                                     output_size_coeff, 0,
-                                     torch.nn.init.normal_)
- # Target.
- set_random_seed(seed)
- master_weight = torch.empty(output_size, input_size)
- torch.nn.init.normal_(master_weight)
- rank = mpu.get_tensor_model_parallel_rank()
- my_weight = torch.split(master_weight, output_size_coeff,
- dim=0)[rank].contiguous().clone()
-
- # Compare.
- error = weight.sub(my_weight).abs().max()
- torch.distributed.barrier()
- print(' column parallel max error (should be zero) on global rank '
- '{}: {}'.format(torch.distributed.get_rank(), error))
- assert error < 1.0e-6
-
- # ------------
- # Row parallel
- # ------------
- weight = torch.empty(output_size, input_size_coeff)
- set_random_seed(seed)
- mpu.layers._initialize_affine_weight(weight, output_size, input_size,
- input_size_coeff, 1,
- torch.nn.init.normal_)
- # Target.
- set_random_seed(seed)
- master_weight = torch.empty(output_size, input_size)
- torch.nn.init.normal_(master_weight)
- rank = mpu.get_tensor_model_parallel_rank()
- my_weight = torch.split(master_weight, input_size_coeff,
- dim=1)[rank].contiguous().clone()
-
- # Compare.
- error = weight.sub(my_weight).abs().max()
- torch.distributed.barrier()
- print(' row parallel max error (should be zero) on global rank '
- '{}: {}'.format(torch.distributed.get_rank(), error))
- assert error < 1.0e-6
-
- # Reset groups
- mpu.destroy_model_parallel()
-
- torch.distributed.barrier()
- if torch.distributed.get_rank() == 0:
- print(' >> passed the test :-)')
-
-
-class IdentityLayer2D(torch.nn.Module):
- def __init__(self, m, n):
- super(IdentityLayer2D, self).__init__()
- self.weight = Parameter(torch.Tensor(m, n))
- torch.nn.init.xavier_normal_(self.weight)
-
- def forward(self):
- return self.weight
-
-
-def test_column_parallel_linear(tensor_model_parallel_size):
-
- mpu.initialize_model_parallel(tensor_model_parallel_size)
- if torch.distributed.get_rank() == 0:
- print('> testing ColumnParallelLinear with model parallel '
- 'size: {}'.format(tensor_model_parallel_size))
- tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()
-
- seed = 12345
- set_random_seed(seed)
- input_size_coeff = 13
- input_size = input_size_coeff * tensor_model_parallel_size
- output_size_coeff = 17
- output_size = output_size_coeff * tensor_model_parallel_size
- batch_size = 7
-
- # Network
- identity_layer = IdentityLayer2D(batch_size, input_size).cuda()
- linear_layer = mpu.ColumnParallelLinear(
- input_size, output_size, keep_master_weight_for_test=True).cuda()
- loss_weight = torch.randn([batch_size, output_size]).cuda()
- # Forward
- input_ = identity_layer()
- output = linear_layer(input_)
- loss = torch.mul(output, loss_weight).sum()
- # Backward
- loss.backward()
-
- # Values.
- dLdY = loss_weight
- X = identity_layer.weight
- A = linear_layer.master_weight.cuda()
- dLdA = torch.matmul(dLdY.t(), X)
- dLdb = torch.matmul(torch.ones(batch_size, 1).cuda().t(), dLdY).view(-1)
- dLdX = torch.matmul(dLdY, A)
-
- rank = mpu.get_tensor_model_parallel_rank()
- my_dLdA = torch.split(dLdA, output_size_coeff,
- dim=0)[rank].contiguous().clone()
- error = my_dLdA.sub(linear_layer.weight.grad).abs().max()
- torch.distributed.barrier()
- print(' error in dLdA on global rank {}: {}'.format(
- torch.distributed.get_rank(), error))
- assert error < 1.0e-6
-
- my_dLdb = torch.split(dLdb, output_size_coeff,
- dim=0)[rank].contiguous().clone()
- error = my_dLdb.sub(linear_layer.bias.grad).abs().max()
- torch.distributed.barrier()
- print(' error in dLdb on global rank {}: {}'.format(
- torch.distributed.get_rank(), error))
- assert error < 1.0e-6
-
- error = dLdX.sub(identity_layer.weight.grad).abs().max()
- torch.distributed.barrier()
- print(' error in dLdX on global rank {}: {}'.format(
- torch.distributed.get_rank(), error))
- assert error < 1.0e-6
-
- # Reset groups
- mpu.destroy_model_parallel()
-
- torch.distributed.barrier()
- if torch.distributed.get_rank() == 0:
- print(' >> passed the test :-)')
-
-
-def test_row_parallel_linear(tensor_model_parallel_size):
-
- mpu.initialize_model_parallel(tensor_model_parallel_size)
- if torch.distributed.get_rank() == 0:
- print('> testing RowParallelLinear with model parallel '
- 'size: {}'.format(tensor_model_parallel_size))
- tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()
-
- seed = 12345
- set_random_seed(seed)
- input_size_coeff = 13
- input_size = input_size_coeff * tensor_model_parallel_size
- output_size_coeff = 17
- output_size = output_size_coeff * tensor_model_parallel_size
- batch_size = 7
-
- # Network
- identity_layer = IdentityLayer2D(batch_size, input_size).cuda()
- linear_layer = mpu.RowParallelLinear(
- input_size, output_size, keep_master_weight_for_test=True).cuda()
- loss_weight = torch.randn([batch_size, output_size]).cuda()
- # Forward
- input_ = identity_layer()
- output = linear_layer(input_)
- loss = torch.mul(output, loss_weight).sum()
- # Backward
- loss.backward()
-
- # Values.
- dLdY = loss_weight
- X = identity_layer.weight
- A = linear_layer.master_weight.cuda()
- dLdA = torch.matmul(dLdY.t(), X)
- dLdb = torch.matmul(torch.ones(batch_size, 1).cuda().t(), dLdY).view(-1)
- dLdX = torch.matmul(dLdY, A)
-
- rank = mpu.get_tensor_model_parallel_rank()
- my_dLdA = torch.split(dLdA, input_size_coeff,
- dim=1)[rank].contiguous().clone()
- error = my_dLdA.sub(linear_layer.weight.grad).abs().max()
- torch.distributed.barrier()
- print(' error in dLdA on global rank {}: {}'.format(
- torch.distributed.get_rank(), error))
- assert error < 1.0e-6
-
- error = dLdb.sub(linear_layer.bias.grad).abs().max()
- torch.distributed.barrier()
- print(' error in dLdb on global rank {}: {}'.format(
- torch.distributed.get_rank(), error))
- assert error < 1.0e-6
-
- error = dLdX.sub(identity_layer.weight.grad).abs().max()
- torch.distributed.barrier()
- print(' error in dLdX on global rank {}: {}'.format(
- torch.distributed.get_rank(), error))
- assert error < 1.0e-6
-
- # Reset groups
- mpu.destroy_model_parallel()
-
- torch.distributed.barrier()
- if torch.distributed.get_rank() == 0:
- print(' >> passed the test :-)')
-
-
-class IdentityLayer3D(torch.nn.Module):
- def __init__(self, m, n, k):
- super(IdentityLayer3D, self).__init__()
- self.weight = Parameter(torch.Tensor(m, n, k))
- torch.nn.init.xavier_normal_(self.weight)
-
- def forward(self):
- return self.weight
-
-
-def parallel_self_attention(tensor_model_parallel_size, num_att_heads_per_partition,
- hidden_size_per_att_head, dropout_prob, batch_size,
- sequence_length):
- mpu.initialize_model_parallel(tensor_model_parallel_size)
- tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()
-
- seed = 12345
- set_random_seed(seed)
-
- num_att_heads = num_att_heads_per_partition * \
- torch.distributed.get_world_size()
- hidden_size = hidden_size_per_att_head * num_att_heads
-
- # Network
- identity_layer = IdentityLayer3D(batch_size, sequence_length,
- hidden_size).cuda()
- attention_layer = mpu.BertParallelSelfAttention(hidden_size, num_att_heads,
- dropout_prob).cuda()
- loss_weight = torch.randn([batch_size, sequence_length, hidden_size]).cuda()
- attention_mask = torch.randn([batch_size, 1, 1, sequence_length]).cuda()
- # Forward
- input_ = identity_layer()
- output = attention_layer(input_, attention_mask)
- loss = torch.mul(output, loss_weight).sum()
- # Backward
- loss.backward()
-
- rank = mpu.get_tensor_model_parallel_rank()
- mpu.destroy_model_parallel()
- return rank, hidden_size, tensor_model_parallel_size, loss, \
- attention_layer, identity_layer
-
-
-def test_parallel_self_attention(tensor_model_parallel_size):
-
- if torch.distributed.get_rank() == 0:
- print('> testing ParallelSelfAttention with model parallel '
- 'size: {}'.format(tensor_model_parallel_size))
-
- num_att_heads_per_partition = 3
- hidden_size_per_att_head = 7
- dropout_prob = 0.0 # has to be zero
- batch_size = 5
- sequence_length = 13
-
-    rank_1, hidden_size_1, tensor_model_parallel_size_1, loss_1, \
- attention_layer_1, identity_layer_1 = parallel_self_attention(
- 1, num_att_heads_per_partition,
- hidden_size_per_att_head, dropout_prob, batch_size, sequence_length)
-
- rank, hidden_size, tensor_model_parallel_size, loss, \
- attention_layer, identity_layer = parallel_self_attention(
- tensor_model_parallel_size, num_att_heads_per_partition,
- hidden_size_per_att_head, dropout_prob, batch_size, sequence_length)
-    assert hidden_size_1 == hidden_size
-
- error = loss_1.sub(loss).abs().max()
- torch.distributed.barrier()
- print(' loss error on global rank {}: {}'.format(
- torch.distributed.get_rank(), error))
- assert error < 5.0e-6
-
- my_lin_grad_list = torch.split(
- attention_layer_1.query_key_value.weight.grad,
- hidden_size // tensor_model_parallel_size, 0)[rank::tensor_model_parallel_size]
- my_lin_grad = torch.cat(my_lin_grad_list, dim=0)
- error = my_lin_grad.sub(
- attention_layer.query_key_value.weight.grad).abs().max()
- torch.distributed.barrier()
- print(' weight gradient error on global rank {}: {}'.format(
- torch.distributed.get_rank(), error))
- assert error < 5.0e-6
-
- error = identity_layer_1.weight.grad.sub(
- identity_layer.weight.grad).abs().max()
- torch.distributed.barrier()
- print(' input gradient error on global rank {}: {}'.format(
- torch.distributed.get_rank(), error))
- assert error < 5.0e-6
-
- torch.distributed.barrier()
- if torch.distributed.get_rank() == 0:
- print(' >> passed the test :-)')
-
-
-def parallel_transformer(tensor_model_parallel_size, num_att_heads_per_partition,
- hidden_size_per_att_head, batch_size, sequence_length):
-
- mpu.initialize_model_parallel(tensor_model_parallel_size)
- tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()
-
- seed = 12345
- set_random_seed(seed)
-
- num_att_heads = num_att_heads_per_partition * \
- torch.distributed.get_world_size()
- hidden_size = hidden_size_per_att_head * num_att_heads
- intermediate_size = 4 * hidden_size
-
- # Network
- identity_layer = IdentityLayer3D(batch_size, sequence_length,
- hidden_size).cuda()
- transformer_layer = mpu.BertParallelTransformerLayer(
- hidden_size, intermediate_size, num_att_heads, 0.0, 0.0,
- torch.nn.functional.relu, 1.0e-5).cuda()
-
- loss_weight = torch.randn([batch_size, sequence_length, hidden_size]).cuda()
- attention_mask = torch.randn([batch_size, 1, 1, sequence_length]).cuda()
- # Forward
- input_ = identity_layer()
- output = transformer_layer(input_, attention_mask)
- loss = torch.mul(output, loss_weight).sum()
- # Backward
- loss.backward()
-
- rank = mpu.get_tensor_model_parallel_rank()
- mpu.destroy_model_parallel()
- return rank, hidden_size, tensor_model_parallel_size, loss, \
- transformer_layer, identity_layer
-
-
-def test_parallel_transformer_layer(tensor_model_parallel_size):
-
- if torch.distributed.get_rank() == 0:
- print('> testing ParallelTransformerLayer with model parallel '
- 'size: {}'.format(tensor_model_parallel_size))
-
- num_att_heads_per_partition = 3
- hidden_size_per_att_head = 7
- batch_size = 5
- sequence_length = 13
-
- rank_1, hidden_size_1, tensor_model_parallel_size_1, loss_1, \
- transformer_layer_1, identity_layer_1 = parallel_transformer(
- 1, num_att_heads_per_partition,
- hidden_size_per_att_head, batch_size, sequence_length)
-
- rank, hidden_size, tensor_model_parallel_size, loss, \
- transformer_layer, identity_layer = parallel_transformer(
- tensor_model_parallel_size, num_att_heads_per_partition,
- hidden_size_per_att_head, batch_size, sequence_length)
-
- error = loss_1.sub(loss).abs().max()
- torch.distributed.barrier()
- print(' loss error on global rank {}: {}'.format(
- torch.distributed.get_rank(), error))
- assert error < 5.0e-5, 'error: {}'.format(error)
-
- error = identity_layer_1.weight.grad.sub(
- identity_layer.weight.grad).abs().max()
- torch.distributed.barrier()
- print(' input gradient error on global rank {}: {}'.format(
- torch.distributed.get_rank(), error))
- assert error < 5.0e-5, 'error: {}'.format(error)
-
- torch.distributed.barrier()
- if torch.distributed.get_rank() == 0:
- print(' >> passed the test :-)')
-
-
-if __name__ == '__main__':
-
- torch.backends.cudnn.deterministic = True
- torch.backends.cudnn.benchmark = False
-
- initialize_distributed()
- world_size = torch.distributed.get_world_size()
-
- print_separator('test initialize affine weight')
- tensor_model_parallel_size = 1
- while tensor_model_parallel_size <= world_size:
- test_initialize_affine_weight(tensor_model_parallel_size)
- tensor_model_parallel_size *= 2
-
- tensor_model_parallel_size = 1
- while tensor_model_parallel_size <= world_size:
- print_separator('test parallel embedding')
- test_parallel_embedding(tensor_model_parallel_size)
- tensor_model_parallel_size *= 2
-
- print_separator('test column-parallel linear')
- tensor_model_parallel_size = 1
- while tensor_model_parallel_size <= world_size:
- test_column_parallel_linear(tensor_model_parallel_size)
- tensor_model_parallel_size *= 2
-
- print_separator('test row-parallel linear')
- tensor_model_parallel_size = 1
- while tensor_model_parallel_size <= world_size:
- test_row_parallel_linear(tensor_model_parallel_size)
- tensor_model_parallel_size *= 2
-
- print_separator('test parallel self-attention')
- tensor_model_parallel_size = 1
- while tensor_model_parallel_size <= world_size:
- test_parallel_self_attention(tensor_model_parallel_size)
- tensor_model_parallel_size *= 2
-
- print_separator('test parallel transformer')
- tensor_model_parallel_size = 1
- while tensor_model_parallel_size <= world_size:
- test_parallel_transformer_layer(tensor_model_parallel_size)
- tensor_model_parallel_size *= 2
diff --git a/megatron/mpu/tests/test_random.py b/megatron/mpu/tests/test_random.py
deleted file mode 100644
index 8ee6942cf01fd7d9c93012c37f7b5e4b351f3c15..0000000000000000000000000000000000000000
--- a/megatron/mpu/tests/test_random.py
+++ /dev/null
@@ -1,191 +0,0 @@
-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-
-import sys
-sys.path.append("../..")
-from commons import print_separator
-from commons import initialize_distributed
-import mpu
-import torch
-
-
-def test_set_cuda_rng_state(tensor_model_parallel_size):
-
- if torch.distributed.get_rank() == 0:
- print('> testing set_rng_state with size {} ...'.
- format(tensor_model_parallel_size))
-
- mpu.initialize_model_parallel(tensor_model_parallel_size)
- tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()
-
- size = 123
- seed = 1234
- torch.cuda.manual_seed(1234)
- tensor = torch.cuda.FloatTensor(size)
-
- # Get the state
- rng_state = torch.cuda.get_rng_state()
- rng_state_copy = rng_state.clone()
-
- # Do some stuff.
- for _ in range(5):
- torch.randn(size, out=tensor)
- result_1 = tensor.clone()
-
- assert rng_state.sub(rng_state_copy).max() == 0
- assert torch.cuda.get_rng_state().sub(rng_state_copy).max() > 0
-
- # State should be different.
- new_rng_state = torch.cuda.get_rng_state()
- max_diff = new_rng_state.sub(rng_state).max()
- print(' max diff in rng state (should be non-zero) on global rank {}: {}'.
- format(torch.distributed.get_rank(), max_diff))
- assert max_diff > 0
-
- # Reset the rng state and do the same stuff.
- mpu.random._set_cuda_rng_state(rng_state)
- for _ in range(5):
- torch.randn(size, out=tensor)
- mpu.random._set_cuda_rng_state(rng_state)
- for _ in range(5):
- torch.randn(size, out=tensor)
- result_2 = tensor.clone()
-
- # Results should be the same
- error = result_2.sub(result_1).abs().max()
- print(' max error in generated tensors (should be zero) on '
- 'global rank {}: {}'.format(torch.distributed.get_rank(), error))
- assert error < 1.0e-6
-
- # Input state should have remained intact.
- error = rng_state.sub(rng_state_copy).max()
- print(' max error in rng state (should be zero) on global rank {}: {}'.
- format(torch.distributed.get_rank(), error))
- assert error == 0
-
- # Reset groups
- mpu.destroy_model_parallel()
-
- torch.distributed.barrier()
- if torch.distributed.get_rank() == 0:
- print('>> passed the test :-)')
-
-
-def test_cuda_rng_tracker(tensor_model_parallel_size):
-
- if torch.distributed.get_rank() == 0:
- print('> testing cuda rng tracker with size {} ...'.
- format(tensor_model_parallel_size))
-
- mpu.initialize_model_parallel(tensor_model_parallel_size)
- tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()
-
- seed_1 = 1234
- seed_2 = 4321
- size = [12, 21]
- tensor = torch.cuda.FloatTensor(size)
-
- # Set to seed_1 and generate two tensors.
- torch.cuda.manual_seed(seed_1)
- torch.randn(size, out=tensor)
- target_11 = tensor.clone()
- torch.randn(size, out=tensor)
- target_12 = tensor.clone()
-
- # Set to seed_2 and generate two tensors.
- torch.cuda.manual_seed(seed_2)
- torch.randn(size, out=tensor)
- target_21 = tensor.clone()
- torch.randn(size, out=tensor)
- target_22 = tensor.clone()
-
- # Now if we interleave seed_1 and seed_2,
- # we should still get the same tensors
- torch.cuda.manual_seed(seed_1)
- mpu.get_cuda_rng_tracker().add('test', seed_2)
-
- torch.randn(size, out=tensor)
- result_11 = tensor.clone()
-
- with mpu.get_cuda_rng_tracker().fork('test'):
- torch.randn(size, out=tensor)
- result_21 = tensor.clone()
-
- torch.randn(size, out=tensor)
- result_12 = tensor.clone()
-
- with mpu.get_cuda_rng_tracker().fork('test'):
- torch.randn(size, out=tensor)
- result_22 = tensor.clone()
-
- diff = result_11.sub(result_21).abs().max()
- diff = min(diff, result_12.sub(result_22).abs().max())
- print(' max diff in generated tensors (should be non-zero) on '
- 'global rank {}: {}'.format(torch.distributed.get_rank(), diff))
- assert diff > 1.0e-6
- error = max(result_11.sub(target_11).abs().max(),
- result_12.sub(target_12).abs().max())
- error = max(error, result_21.sub(target_21).abs().max())
- error = max(error, result_22.sub(target_22).abs().max())
- print(' max error in generated tensors (should be zero) on '
- 'global rank {}: {}'.format(torch.distributed.get_rank(), error))
- assert error < 1.0e-6
-
- # Reset the tracker
- mpu.get_cuda_rng_tracker().reset()
-
- # Reset groups
- mpu.destroy_model_parallel()
-
- torch.distributed.barrier()
- if torch.distributed.get_rank() == 0:
- print('>> passed the test :-)')
-
-
-def test_model_parallel_cuda_manual_seed(tensor_model_parallel_size):
-
- if torch.distributed.get_rank() == 0:
- print('> testing model parallel cuda manual seed with size {} ...'.
- format(tensor_model_parallel_size))
-
- mpu.initialize_model_parallel(tensor_model_parallel_size)
- tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()
-
- mpu.model_parallel_cuda_manual_seed(12345)
- assert torch.cuda.initial_seed() == 12345
- with mpu.get_cuda_rng_tracker().fork():
- assert torch.cuda.initial_seed() == (12345 + 2718 +
- mpu.get_tensor_model_parallel_rank())
-
- # Reset the tracker
- mpu.get_cuda_rng_tracker().reset()
-
- # Reset groups
- mpu.destroy_model_parallel()
-
- torch.distributed.barrier()
- if torch.distributed.get_rank() == 0:
- print('>> passed the test :-)')
-
-
-if __name__ == '__main__':
-
- initialize_distributed()
- world_size = torch.distributed.get_world_size()
-
- tensor_model_parallel_size = 1
- while tensor_model_parallel_size <= world_size:
- print_separator('test set rng state')
- test_set_cuda_rng_state(tensor_model_parallel_size)
- tensor_model_parallel_size *= 2
-
- tensor_model_parallel_size = 1
- while tensor_model_parallel_size <= world_size:
- print_separator('test cuda rng tracker')
- test_cuda_rng_tracker(tensor_model_parallel_size)
- tensor_model_parallel_size *= 2
-
- tensor_model_parallel_size = 1
- while tensor_model_parallel_size <= world_size:
- print_separator('test model parallel cuda manual seed')
- test_model_parallel_cuda_manual_seed(tensor_model_parallel_size)
- tensor_model_parallel_size *= 2
diff --git a/megatron/optimizer/__init__.py b/megatron/optimizer/__init__.py
deleted file mode 100644
index 33744a2f3aded055c21d18700c32e3c240d440c0..0000000000000000000000000000000000000000
--- a/megatron/optimizer/__init__.py
+++ /dev/null
@@ -1,141 +0,0 @@
-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-
-from apex.optimizers import FusedAdam as Adam
-from apex.optimizers import FusedSGD as SGD
-
-from megatron import get_args
-
-from .distrib_optimizer import DistributedOptimizer
-from .grad_scaler import ConstantGradScaler, DynamicGradScaler
-from .optimizer import Float16OptimizerWithFloat16Params, FP32Optimizer
-
-def get_param_groups(modules,
- no_weight_decay_cond,
- scale_lr_cond,
- lr_mult):
-    """Creates param groups based on weight decay condition (regularized vs non-regularized)
-       and learning rate scale condition (args.lr vs lr_mult * args.lr).
-       scale_lr_cond is used during finetuning, where the head of the network requires a scaled
-       version of the base learning rate.
-    """
- wd_no_scale_lr = []
- wd_scale_lr = []
- no_wd_no_scale_lr = []
- no_wd_scale_lr = []
- for module in modules:
- for name, param in module.named_parameters():
- if not param.requires_grad:
- continue
-
- if no_weight_decay_cond is not None:
- no_wd = no_weight_decay_cond(name, param)
- else:
- # do not regularize biases nor Norm parameters
- no_wd = name.endswith(".bias") or len(param.shape) == 1
-
- if scale_lr_cond is not None:
- scale_lr = scale_lr_cond(name, param)
- else:
- scale_lr = False
-
- if not no_wd and not scale_lr:
- wd_no_scale_lr.append(param)
- elif not no_wd and scale_lr:
- wd_scale_lr.append(param)
- elif no_wd and not scale_lr:
- no_wd_no_scale_lr.append(param)
- else:
- no_wd_scale_lr.append(param)
-
- param_groups = []
- if len(wd_no_scale_lr):
- param_groups.append({'params': wd_no_scale_lr, 'wd_mult': 1.0, 'lr_mult': 1.0})
- if len(wd_scale_lr):
- param_groups.append({'params': wd_scale_lr, 'wd_mult': 1.0, 'lr_mult': lr_mult})
- if len(no_wd_no_scale_lr):
- param_groups.append({'params': no_wd_no_scale_lr, 'wd_mult': 0.0, 'lr_mult': 1.0})
- if len(no_wd_scale_lr):
- param_groups.append({'params': no_wd_scale_lr, 'wd_mult': 0.0, 'lr_mult': lr_mult})
-
- return param_groups
-
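# A minimal, self-contained sketch (illustration only, not part of the file above)
# of the default weight-decay rule in get_param_groups(): biases and 1-D parameters
# such as LayerNorm weights are excluded from weight decay. The torch.nn.Linear
# module below is a hypothetical stand-in chosen just to show the rule.
import torch

layer = torch.nn.Linear(8, 8)
for name, param in layer.named_parameters():
    no_wd = name.endswith(".bias") or len(param.shape) == 1
    print(f"{name}: {'no weight decay' if no_wd else 'weight decay'}")
# Prints "weight: weight decay" and "bias: no weight decay".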
-def get_megatron_optimizer(model,
- no_weight_decay_cond=None,
- scale_lr_cond=None,
- lr_mult=1.0):
- args = get_args()
-
- # Base optimizer.
- param_groups = get_param_groups(model,
- no_weight_decay_cond,
- scale_lr_cond,
- lr_mult)
-
- if args.optimizer == 'adam':
- optimizer = Adam(param_groups,
- lr=args.lr,
- weight_decay=args.weight_decay,
- betas=(args.adam_beta1, args.adam_beta2),
- eps=args.adam_eps)
- elif args.optimizer == 'sgd':
- optimizer = SGD(param_groups,
- lr=args.lr,
- weight_decay=args.weight_decay,
- momentum=args.sgd_momentum)
- else:
- raise Exception('{} optimizer is not supported.'.format(
- args.optimizer))
-
- # Determine whether the params have main-grad field.
- params_have_main_grad = True
-
- # Mixed precision optimizer.
- # - Note: both the Float16Optimizer and the DistributedOptimizer inherit
- # from the MixedPrecisionOptimizer, which manages any optimizer where
- # the model params and main params are distinct.
- if args.fp16 or args.bf16 or args.use_distributed_optimizer:
-
- # Grad scaler:
- # if loss-scale is provided, instantiate the constant scaler.
- # if we are using fp16 and loss-scale is not present, use a
- # dynamic scaler.
- # otherwise we are running in bf16 with no loss-scale so
- # leave it as None.
- grad_scaler = None
-
- # Constant loss scale.
- if args.loss_scale:
- grad_scaler = ConstantGradScaler(args.loss_scale)
-
- # Dynamic loss scale.
- else:
- if args.fp16:
- grad_scaler = DynamicGradScaler(
- initial_scale=args.initial_loss_scale,
- min_scale=args.min_loss_scale,
- growth_factor=2.0,
- backoff_factor=0.5,
- growth_interval=args.loss_scale_window,
- hysteresis=args.hysteresis)
-
- # Megatron optimizer.
- opt_ty = DistributedOptimizer \
- if args.use_distributed_optimizer else \
- Float16OptimizerWithFloat16Params
- return opt_ty(optimizer,
- args.clip_grad,
- args.log_num_zeros_in_grad,
- args.check_for_nan_in_loss_and_grad,
- params_have_main_grad,
- args.fp16,
- args.bf16,
- args.params_dtype,
- grad_scaler,
- model)
-
- # FP32.
- return FP32Optimizer(optimizer, args.clip_grad,
- args.log_num_zeros_in_grad,
- args.check_for_nan_in_loss_and_grad,
- params_have_main_grad,
- model)
diff --git a/megatron/optimizer/clip_grads.py b/megatron/optimizer/clip_grads.py
deleted file mode 100644
index d6e38afb58a01ce546e6939bdc6a381c943942e2..0000000000000000000000000000000000000000
--- a/megatron/optimizer/clip_grads.py
+++ /dev/null
@@ -1,148 +0,0 @@
-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-
-"""Gradient clipping."""
-
-import os
-
-import torch
-from torch import inf
-
-from apex.multi_tensor_apply import multi_tensor_applier
-import amp_C
-
-from megatron.model.module import param_is_not_shared
-from megatron.core.tensor_parallel import param_is_not_tensor_parallel_duplicate
-
-
-def clip_grad_norm_fp32(parameters, grads_for_norm,
- max_norm, check_for_nan_in_grad,
- norm_type=2, model_parallel_group=None):
- """Clips gradient norm of an iterable of parameters whose gradients
- are in fp32.
-
-    This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_, with
-    added functionality to handle model-parallel parameters. Note that
-    the gradients are modified in place.
-
- Arguments:
- parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
- single Tensor that will have gradients normalized
- grads_for_norm (Iterable[Tensor]): an iterable of Tensors or a single
- Tensor that will be used for calculating the grad norm.
- max_norm (float or int): max norm of the gradients.
- check_for_nan_in_grad (bool): check if gradients have a NaN.
- norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
- infinity norm.
- model_parallel_group (group): given the nature of the distributed
- optimizer, this is passed as an argument.
-
- Returns:
- Total norm of the parameters (viewed as a single vector).
- """
-
- if isinstance(parameters, torch.Tensor):
- parameters = [parameters]
- if isinstance(grads_for_norm, torch.Tensor):
- grads_for_norm = [grads_for_norm]
-
- # Grads.
- grads = []
- for param in parameters:
- if param.grad is not None:
- assert param.grad.type() == 'torch.cuda.FloatTensor'
- grads.append(param.grad.detach())
-
- # Norm parameters.
- max_norm = float(max_norm)
- norm_type = float(norm_type)
- total_norm = 0.0
-
- # Calculate norm.
- if norm_type == inf:
- total_norm = max(grad.abs().max() for grad in grads_for_norm)
- total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
- # Take max across all model-parallel GPUs.
- torch.distributed.all_reduce(total_norm_cuda,
- op=torch.distributed.ReduceOp.MAX,
- group=model_parallel_group)
- total_norm = total_norm_cuda[0].item()
-
- else:
- if norm_type == 2.0:
- dummy_overflow_buf = torch.cuda.IntTensor([0])
- # Use apex's multi-tensor applier for efficiency reasons.
- # Multi-tensor applier takes a function and a list of list
- # and performs the operation on that list all in one kernel.
- if grads_for_norm:
- grad_norm, _ = multi_tensor_applier(
- amp_C.multi_tensor_l2norm,
- dummy_overflow_buf,
- [grads_for_norm],
- False # no per-parameter norm
- )
- else:
- grad_norm = torch.cuda.FloatTensor([0])
- # Since we will be summing across data parallel groups,
- # we need the pow(norm-type).
- total_norm = grad_norm ** norm_type
-
- else:
- for grad in grads_for_norm:
- grad_norm = torch.norm(grad, norm_type)
- total_norm += grad_norm ** norm_type
-
- # Check individual rank grad norms are not NaN
- # prior to model-parallel all-reduce.
- if check_for_nan_in_grad:
- global_rank = torch.distributed.get_rank()
- assert not total_norm.isnan(), (
- f'Rank {global_rank}: found NaN in local grad norm in '
- f'backwards pass. Device: {torch.cuda.current_device()}, '
- f'node: {os.uname()[1]}'
- )
-
- # Sum across all model-parallel GPUs.
- torch.distributed.all_reduce(total_norm,
- op=torch.distributed.ReduceOp.SUM,
- group=model_parallel_group)
- total_norm = total_norm.item() ** (1.0 / norm_type)
-
- # Scale.
- clip_coeff = max_norm / (total_norm + 1.0e-6)
- if clip_coeff < 1.0:
- dummy_overflow_buf = torch.cuda.IntTensor([0])
- multi_tensor_applier(amp_C.multi_tensor_scale,
- dummy_overflow_buf,
- [grads, grads],
- clip_coeff)
-
- return total_norm
-
-
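# A small single-process sketch of the 2-norm path in clip_grad_norm_fp32 above:
# accumulate sum(|g|^p) over the local grads, (in the distributed case) all-reduce
# the partial sums across the model-parallel group, then take the p-th root.
# The random tensors are hypothetical stand-ins for gradients.
import torch

grads_for_norm = [torch.randn(10), torch.randn(5)]
norm_type = 2.0
total_norm = sum(torch.norm(g, norm_type) ** norm_type for g in grads_for_norm)
# torch.distributed.all_reduce(total_norm, op=torch.distributed.ReduceOp.SUM, ...)
# would be called at this point when running with model parallelism.
total_norm = total_norm.item() ** (1.0 / norm_type)
print(total_norm)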
-def count_zeros_fp32(parameters, model_parallel_group):
-
- if isinstance(parameters, torch.Tensor):
- parameters = [parameters]
-
- # Filter parameters based on:
- # - grad should not be none
- # - parameter should not be shared
- # - should not be a replica due to tensor model parallelism
- total_num_zeros = torch.cuda.FloatTensor([0.0])
- for param in parameters:
- grad_not_none = param.grad is not None
- is_not_shared = param_is_not_shared(param)
- is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
- if grad_not_none and is_not_shared and is_not_tp_duplicate:
- grad = param.grad.detach()
- num_zeros = grad.numel() - torch.count_nonzero(grad)
- total_num_zeros = num_zeros + total_num_zeros
-
- # Sum across all model-parallel GPUs.
- torch.distributed.all_reduce(total_num_zeros,
- op=torch.distributed.ReduceOp.SUM,
- group=model_parallel_group)
-
- total_num_zeros = total_num_zeros.item()
-
- return total_num_zeros
diff --git a/megatron/optimizer/distrib_optimizer.py b/megatron/optimizer/distrib_optimizer.py
deleted file mode 100644
index d58b1b08fcdfe81d5603cc5090d6362a3f48d2ac..0000000000000000000000000000000000000000
--- a/megatron/optimizer/distrib_optimizer.py
+++ /dev/null
@@ -1,1161 +0,0 @@
-# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
-
-"""Megatron distributed optimizer."""
-
-
-from apex.optimizers import FusedAdam as Adam
-import math
-import torch
-
-from megatron import get_args
-from megatron import get_timers
-from megatron import print_rank_0
-from megatron.core import mpu, tensor_parallel
-
-from .optimizer import MixedPrecisionOptimizer, _zero_grad_group_helper
-from .utils import shard_buffer
-
-
-
-class Range:
- """
- A range represents a start and end points for indexing a shard
- from a full tensor.
- """
- def __init__(self, start, end):
- self.start = start
- self.end = end
- self.size = end - start
-    def normalize(self, start=0):
- return Range(start, start + self.size)
- def __str__(self):
- return "%d,%d [%d]" % (self.start, self.end, self.size)
- def __len__(self):
- return self.end - self.start
-
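# Illustrative values (hypothetical) for the Range helper defined above: a shard
# covering world indices [4, 10) has size 6, and normalize() re-expresses that
# size from a new start, which is how world ranges are later mapped into
# bucket-local and param-local ranges.
r = Range(4, 10)
print(len(r))                 # 6
print(str(r.normalize()))     # "0,6 [6]"
print(str(r.normalize(100)))  # "100,106 [6]"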
-
-class DistributedOptimizer(MixedPrecisionOptimizer):
- """Distributed optimizer, for all data types (fp16, bf16, and fp32).
-
- Arguments:
- optimizer: base optimizer such as Adam or SGD
-        clip_grad: clip gradients with this global L2 norm. Note
-            that clipping is ignored if clip_grad == 0.
- log_num_zeros_in_grad: return number of zeros in the gradients.
- check_for_nan_in_grad: check if gradients have a NaN.
- params_have_main_grad: flag indicating if parameters have
- a `main_grad` field. If this is set, we are assuming
-            that the model's gradients are stored in the `main_grad`
- field instead of the typical `grad` field. This happens
- for the DDP cases where there is a continuous buffer
- holding the gradients. For example for bfloat16, we want
- to do gradient accumulation and all-reduces in float32
- and as a result we store those gradients in the main_grad.
- Note that main grad is not necessarily in float32.
- fp16: if true, the model is running in fp16.
- bf16: if true, the model is running in bfloat16.
- grad_scaler: used for scaling gradients. Note that this can be
- None. This case happens when `bf16 = True` and we don't
- use any loss scale. Note that for `bf16 = True`, we can have
-            a constant gradient scaler. Also for `bf16 = False`, we
- always require a grad scaler.
- models: list of models (i.e., the virtual pipelining models). This
- is used by the distributed optimizer for mapping parameters.
- """
-
- @classmethod
- def build_model_gbuf_param_range_map(cls, model, dtype, gbuf_world_range, bucket_offset):
- """
- Build mapping from param reference to grad buffer shard ranges.
-
- This method builds a mapping from parameter references to grad
- buffer shard ranges, specific to each data-parallel (DP) rank's
- set of 'owned' parameters. Each grad buffer (padded to be an even
-        contiguous regions, where each DP rank 'owns' a contiguous region.
-        Ownership in this sense means the DP rank is responsible for reducing
- Ownership in this sense means DP rank is responsible for reducing
- the relevant subset of grads, and updating the relevant subset of
- params.
-
- This conceptual partitioning of the grad buffer does NOT respect
- parameter boundaries, and as such it is assumed that each created
- range references a shard (or subset) of the full parameter. It is
- easiest to think of each DP rank as operating (i.e., reducing,
- gathering) purely on views into the grad buffer, for all model-to-
- main & main-to-model operations.
-
- This method creates four ranges:
- - The param's range within the entire grad buffer (i.e., world index).
- - The param's range within the relevant grad bucket's buffer.
- - The param's range within the DP rank's local view of the grad buffer.
- - The param's range within itself (i.e., its shard).
- """
-
- # Param range map.
- param_world_index_map = model.grad_buffer_param_index_map[dtype]
- param_range_map = {}
- for param, param_world_indexes in param_world_index_map.items():
-
- # Param range.
- param_world_start, param_world_end, _ = param_world_indexes
- param_local_start = max(
- 0,
- param_world_start - gbuf_world_range.start)
- param_local_end = min(
- gbuf_world_range.size,
- param_world_end - gbuf_world_range.start)
-
- # Add param, if within local gbuf range.
- if param_local_end > param_local_start:
- param_local_range = Range(param_local_start, param_local_end)
- param_world_range = param_local_range.normalize(
- param_local_start + gbuf_world_range.start)
- param_world_range_in_bucket = Range(param_world_range.start-bucket_offset,
- param_world_range.end-bucket_offset)
- sub_param_start = max(0, gbuf_world_range.start-param_world_start)
- sub_param_range = param_local_range.normalize(sub_param_start)
- param_range_map[param] = {
- "gbuf_world" : param_world_range,
- "gbuf_world_in_bucket": param_world_range_in_bucket,
- "gbuf_local" : param_local_range,
- "param" : sub_param_range,
- }
-
- return param_range_map
-
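# A worked toy example (numbers are hypothetical) of the range arithmetic in
# build_model_gbuf_param_range_map above: a 16-element bucket split across 4
# data-parallel ranks gives each rank a 4-element shard, so a parameter that
# occupies world indices [6, 13) ends up split across ranks 1, 2 and 3.
gbuf_size, dp_world_size = 16, 4
shard_size = gbuf_size // dp_world_size
param_start, param_end = 6, 13
for rank in range(dp_world_size):
    gbuf_start = rank * shard_size
    local_start = max(0, param_start - gbuf_start)
    local_end = min(shard_size, param_end - gbuf_start)
    if local_end > local_start:
        print(f"rank {rank} owns local indices [{local_start}, {local_end})")
# -> rank 1: [2, 4), rank 2: [0, 4), rank 3: [0, 1)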
-
- @classmethod
- def build_model_gbuf_range(cls, model, dtype, bucket_index):
- """
- Build mapping between params and their grad buffers.
-
- This method does the initial setup for the method above. This setup
- includes determining the shard ranges into the DDP's grad buffer for
- each data-parallel (DP) rank. Each DP rank keeps range info for
- all other DP ranks, for the purpose of creating args for
- reduce-scatter and all-gather.
- """
-
- data_parallel_rank = mpu.get_data_parallel_rank(with_context_parallel=True)
- data_parallel_world_size = mpu.get_data_parallel_world_size(with_context_parallel=True)
-
- bucket = model.grad_buffers[dtype].buckets[bucket_index]
- bucket_buffer = bucket.data
- gbuf_size = bucket_buffer.numel()
- assert gbuf_size % data_parallel_world_size == 0, \
- f"Each bucket's buffer size should be divisible by {data_parallel_world_size}"
- max_gbuf_range_size = gbuf_size // data_parallel_world_size
-
- # All world ranges (i.e., across all data parallel ranks).
- gbuf_world_all_ranges = []
- for r in range(data_parallel_world_size):
- # Compute start of chunk in this bucket.
- gbuf_world_start = r * max_gbuf_range_size
- gbuf_world_end = min(gbuf_size, gbuf_world_start+max_gbuf_range_size)
- # Add bucket's offset in grad buffer.
- gbuf_world_range = Range(gbuf_world_start + bucket.offset,
- gbuf_world_end + bucket.offset)
- gbuf_world_all_ranges.append(gbuf_world_range)
-
- # Local DP's ranges.
- gbuf_world_range = gbuf_world_all_ranges[data_parallel_rank]
-
- # Get each param's ranges.
- param_range_map = cls.build_model_gbuf_param_range_map(model,
- dtype,
- gbuf_world_range,
- bucket.offset)
-
- # Group into dict.
- data = {
- "param_map" : param_range_map,
- }
-
- return data
-
-
- @classmethod
- def build_model_gbuf_range_map(cls, model):
- """
- Create param-to-grad-buffer mappings, for grad buffer data types
- within a specific virtual model.
- """
- # Iterate through all buckets to construct param ranges that this rank "owns"
- # (the dp_rank'th shard of each bucket, where each shard is 1/dp_world_size
- # of the bucket).
- return {
- dtype : [cls.build_model_gbuf_range(model, dtype, bucket_index)
- for bucket_index in range(len(model.grad_buffers[dtype].buckets))]
- for dtype in model.grad_buffers
- }
-
-
- @classmethod
- def build_model_param_gbuf_map(cls, model_gbuf_ranges):
- """
- Create a reverse of the model_gbuf_ranges, for referencing in
- opposite direction.
- """
- param_gbuf_map = {}
- for model_index, model_gbuf_range_map in enumerate(model_gbuf_ranges):
- for dtype, gbuf_range_map_for_all_buckets in model_gbuf_range_map.items():
- for bucket_index, gbuf_range_map in enumerate(gbuf_range_map_for_all_buckets):
- for param, _ in gbuf_range_map["param_map"].items():
- assert param not in param_gbuf_map, \
- "Param should not be in param_gbuf_map; each param only belongs to a single bucket"
- param_gbuf_map[param] = (model_index, dtype, bucket_index)
- return param_gbuf_map
-
-
- @classmethod
- def build_optimizer_group_ranges(cls, param_groups, model_gbuf_ranges):
- """
- Create optimizer groups.
-
- Given the set of parameter shard ranges that are owned by the current
- data-parallel (DP) rank, gather the set of parameters that will be
- used (in the method below) to create the current DP's optimizer
- groups.
- """
-
- num_groups = len(param_groups)
-
- # Param group map.
- # World param group map.
-        # - Store a mapping of param -> group_index for all parameters
-        #   across all DP ranks. This is necessary because it is our first
-        #   cross reference between the DDP mappings and the optimizer group
-        #   parameters. This mapping is only used in the next step, when building
-        #   the local mapping over this DP rank's parameters.
- world_param_group_map = {}
- for group_index, group in enumerate(param_groups):
- for param in group["params"]:
- assert param.requires_grad
- world_param_group_map[param] = group_index
-
- # Optimizer group ranges & param-group mapping.
- # - Build a mapping from groups to their contained parameters, and also
- # from parameters to their containing group index and order within
- # the group. The group index and order are particularly important for
- # saving and loading checkpoints.
- local_param_group_map = {}
- group_ranges = [ {"params": []} for _ in param_groups ]
- for model_gbuf_range_map in model_gbuf_ranges:
- for dtype, gbuf_range_map_for_all_buckets in model_gbuf_range_map.items():
- for gbuf_range_map in gbuf_range_map_for_all_buckets:
- for param in gbuf_range_map["param_map"]:
- group_index = world_param_group_map[param]
- group_range = group_ranges[group_index]
- group_range["params"].append(param)
- local_param_group_map[param] = \
- (group_index, len(group_range["params"]) - 1)
-
-        # Attach the original param group (and its index) to each group range.
-        for group_index, group_range in enumerate(group_ranges):
-            group_range["orig_group"] = param_groups[group_index]
-            group_range["orig_group_idx"] = group_index
-
- return local_param_group_map, group_ranges
-
-
- @classmethod
- def build_model_and_main_param_groups(cls,
- model_gbuf_ranges,
- param_gbuf_map,
- opt_group_ranges):
- """
- Create main parameter groups needed for the optimizer step.
-
- These groups encompass both: 1) groups used by this class, for
- reducing/gather, and 2) groups used by the inner optimizer for the
- parameter update. Given that the conceptual grad buffer partitioning
- (created in earlier method) doesn't respect parameter boundaries,
- the optimizer operates on shards of the model parameters, rather than
- the full parameters.
- """
-
- # Parameter groups:
- # model_float16_groups: original float16 parameters
- # model_fp32_groups: original fp32 parameters
- # shard_float16_groups: shards of original float16 parameters
- # shard_fp32_groups: shards of original fp32 parameters
- # shard_fp32_from_float16_groups: fp32 copy of float16 parameters
- model_float16_groups = []
- model_fp32_groups = []
- shard_float16_groups = []
- shard_fp32_groups = []
- shard_fp32_from_float16_groups = []
-
- # Allocate (or slice) each group's param shard.
- for group_index, group_range in enumerate(opt_group_ranges):
-
- # Params of this group.
- model_float16_params_this_group = []
- model_fp32_params_this_group = []
- shard_float16_params_this_group = []
- shard_fp32_params_this_group = []
- shard_fp32_from_float16_params_this_group = []
- model_float16_groups.append(model_float16_params_this_group)
- model_fp32_groups.append(model_fp32_params_this_group)
- shard_float16_groups.append(shard_float16_params_this_group)
- shard_fp32_groups.append(shard_fp32_params_this_group)
- shard_fp32_from_float16_groups.append(
- shard_fp32_from_float16_params_this_group)
-
- for model_param in group_range["params"]:
-
- assert model_param.requires_grad
-
- model_index, dtype, bucket_index = param_gbuf_map[model_param]
- gbuf_range = model_gbuf_ranges[model_index][dtype][bucket_index]
- param_range = gbuf_range["param_map"][model_param]["param"]
-
- # fp16, bf16 params.
- if model_param.type() in ['torch.cuda.HalfTensor',
- 'torch.cuda.BFloat16Tensor']:
-
- # Clone model -> main.
- shard_model_param = model_param.detach().view(-1) \
- [param_range.start:param_range.end]
- shard_main_param = shard_model_param.clone().float()
- tensor_parallel.copy_tensor_model_parallel_attributes(
- shard_model_param, model_param)
- tensor_parallel.copy_tensor_model_parallel_attributes(
- shard_main_param, model_param)
- if hasattr(model_param, 'shared'):
- shard_model_param.shared = model_param.shared
- shard_main_param.shared = model_param.shared
-
- # Add to group.
- model_float16_params_this_group.append(model_param)
- shard_float16_params_this_group.append(shard_model_param)
- shard_fp32_from_float16_params_this_group.append(shard_main_param)
-
- # fp32 params.
- elif model_param.type() == 'torch.cuda.FloatTensor':
- shard_model_param = model_param.view(-1) \
- [param_range.start:param_range.end]
- model_fp32_params_this_group.append(model_param)
- shard_fp32_params_this_group.append(shard_model_param)
- tensor_parallel.copy_tensor_model_parallel_attributes(
- shard_model_param, model_param)
- if hasattr(model_param, 'shared'):
- shard_model_param.shared = model_param.shared
-
- else:
- raise TypeError('Wrapped parameters must be one of '
- 'torch.cuda.FloatTensor, '
- 'torch.cuda.HalfTensor, or '
- 'torch.cuda.BFloat16Tensor. '
- 'Received {}'.format(model_param.type()))
-
- # Update optimizer's params.
- group_range["orig_group"]["params"] = [
- *shard_fp32_params_this_group,
- *shard_fp32_from_float16_params_this_group,
- ]
-
- return (
- model_float16_groups,
- model_fp32_groups,
- shard_float16_groups,
- shard_fp32_groups,
- shard_fp32_from_float16_groups,
- )
-
-
- def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
- check_for_nan_in_grad, params_have_main_grad, fp16,
- bf16, params_dtype, grad_scaler, models):
- """
- See top of class definition for argument descriptions.
-
- The steps in this method create the core mapping between DDP grad
- buffers, parameters, and parameter shard ranges, that is needed for
- converting between model param indexes and main parameter shard
- indexes. This method also updates the optimizer parameter groups
- with the newly created shards.
- """
-
- super().__init__(
- optimizer, clip_grad, log_num_zeros_in_grad,
- check_for_nan_in_grad, params_have_main_grad,
- fp16, bf16, params_dtype, grad_scaler, models)
-
- assert isinstance(optimizer, Adam), \
- "Only Adam currently supported, due to checkpointing requirements."
-
- # Model grad buffer ranges.
- self.model_gbuf_ranges = []
- self.per_bucket_numel = []
- for _, model_chunk in enumerate(self.models):
- self.per_bucket_numel.append(
- {dtype: [bucket.data.numel() for bucket in model_chunk.grad_buffers[dtype].buckets]
- for dtype in model_chunk.grad_buffers})
- self.model_gbuf_ranges.append(self.build_model_gbuf_range_map(model_chunk))
- self.model_param_gbuf_map = \
- self.build_model_param_gbuf_map(self.model_gbuf_ranges)
-
- # Optimizer ranges.
- self.model_param_group_index_map, self.opt_group_ranges = \
- self.build_optimizer_group_ranges(self.optimizer.param_groups,
- self.model_gbuf_ranges)
-
- # Allocate main param shards.
- (
- self.model_float16_groups,
- self.model_fp32_groups,
- self.shard_float16_groups,
- self.shard_fp32_groups,
- self.shard_fp32_from_float16_groups,
- ) = self.build_model_and_main_param_groups(self.model_gbuf_ranges,
- self.model_param_gbuf_map,
- self.opt_group_ranges)
-
- # Initialize param buffers.
- # - These are views on the DDP model's grad buffers, that share
- # storage & have their own dtype. This is safe because the param
- # dtype size is always <= grad dtype size.
- self.param_buffers = []
- for model_index, model in enumerate(self.models):
- current_param_buffers = {}
- for dtype, grad_buffer in model.grad_buffers.items():
- size_ratio = torch.finfo(dtype).bits // torch.finfo(params_dtype).bits
- current_param_buffers[dtype] = []
- for bucket in grad_buffer.buckets:
-
- # Handle older/newer method for getting untyped storage.
- try:
- storage = bucket.data.storage()._untyped()
-                    except AttributeError:
- storage = bucket.data.storage().untyped()
-
- # Typed param buffer.
- param_buffer = torch.tensor(
- storage,
- dtype = params_dtype,
- device = bucket.data.device)
-
- # .storage() ignores views / slices, so param_buffer now points to the start
- # of the grad_buffer instead of to the start of each bucket. As a result,
- # add bucket.offset to make sure param_buffers point to the right region of
- # memory.
- # Since we want the start of each bucket's param_buffer to coincide with the
- # start of the same bucket's grad_buffer (this ensures that zeroing the grad
- # buffer does not zero out params in the param_buffer before they are copied
- # into the model_params), multiply the offset by the size ratio of grads and
- # params.
- offset = bucket.offset * size_ratio
- param_buffer = param_buffer[offset:offset+bucket.data.numel()]
- assert param_buffer.data_ptr() == bucket.data.data_ptr(), \
- "param_buffer and grad_buffer for same bucket should start at the same byte address"
- assert param_buffer.numel() == bucket.data.numel(), \
- "param_buffer and grad_buffer for same bucket should have the same number of elements"
- current_param_buffers[dtype].append(param_buffer)
- self.param_buffers.append(current_param_buffers)
-
- # Now construct data structures to manage all-gather handles.
- self.all_gather_handles = []
- self.all_gather_handle_index_to_bucket_index_map = []
- self.model_index_to_all_gather_handle_index_map = {}
- self.param_to_all_gather_handle_index_map = {}
- self.param_buffer_copied = []
-
- self.pbuf_view_items = self.get_model_param_buffer_dp_views()
- for (model_index, dtype, bucket_index, _, _) in self.pbuf_view_items:
- self.all_gather_handle_index_to_bucket_index_map.append((model_index, dtype, bucket_index))
- all_gather_handle_index = len(self.all_gather_handle_index_to_bucket_index_map) - 1
-
- # Store all all_gather_handle_indices relevant to a particular model chunk.
- if model_index not in self.model_index_to_all_gather_handle_index_map:
- self.model_index_to_all_gather_handle_index_map[model_index] = []
- self.model_index_to_all_gather_handle_index_map[model_index].append(all_gather_handle_index)
-
- for param in self.models[model_index].grad_buffers[dtype].buckets[bucket_index].params_list:
- self.param_to_all_gather_handle_index_map[param] = all_gather_handle_index
- self.param_buffer_copied.append(False)
- self.num_all_gather_handles = len(self.all_gather_handle_index_to_bucket_index_map)
-
- self.overlap_param_gather = get_args().overlap_param_gather
- if self.overlap_param_gather:
- self.remove_pre_hook_handle = torch.nn.modules.module.register_module_forward_pre_hook(
- self._make_forward_pre_hook())
- else:
- self.remove_pre_hook_handle = None
-
- self.update_successful = False
-
- # Update optimizer groups.
- # - Also, leverage state_dict() and load_state_dict() to
- # recast preexisting per-param state tensors.
- self.optimizer.param_groups = \
- [ g["orig_group"] for g in self.opt_group_ranges ]
- self.optimizer.load_state_dict(self.optimizer.state_dict())
-
-
- def get_model_param_range_map(self, param):
- """
- Given a model param, get the index sub-range of the param that this
- data-parallel rank owns.
- """
- model_index, dtype, bucket_index = self.model_param_gbuf_map[param]
- gbuf_range_map = self.model_gbuf_ranges[model_index][dtype][bucket_index]
- param_range_map = gbuf_range_map["param_map"][param]
- return param_range_map
-
-
- def get_model_parallel_group(self):
- """
- With the distributed optimizer, the model parallel group is the
- entire world.
- """
- return None
-
-
- def state_dict(self):
- """
- The state dict contains all non-DP-rank-dependent (i.e., non-parameter-
- related) optimizer variables. The returned state dict can be stored in
- the standard model/RNG checkpoint file. The parameter and dependent
- optimizer state (e.g., exp_avg, exp_avg_sq) are stored in a separate
- checkpoint file by calling 'save_parameter_state()'.
- """
-
- state_dict = {}
-
- # Optimizer state (do not store parameter state here).
- state_dict['optimizer'] = {
- k : v
- for k, v in self.optimizer.state_dict().items()
- if k != "state"
- }
- for param_group in state_dict["optimizer"]["param_groups"]:
- del param_group["params"]
-
- # Grad scaler state.
- if self.grad_scaler:
- state_dict['grad_scaler'] = self.grad_scaler.state_dict()
-
- return state_dict
-
-
- def load_state_dict(self, state_dict):
- """Load the state dict.
-
- As detailed in state_dict(), the state dict contains all non-
- parameter-related variables. This method is notably longer than
- state_dict(), because the Torch optimizers state has yet to be
-        state_dict(), because the Torch optimizer's state has yet to be
-        allocated at this point, and so we must do a cross-referencing between
-        the optimizer's state (and the ordering it expects for parameter state)
- any tensor dimension information, so we must get these dimensions from
- the DP shards mapped during DistributedOptimizer.__init__().
-
- The tensor parameter state is loaded via load_parameter_state(), and
- so this method also must populate the loaded state dict with dummy
- tensor data (i.e., via torch.empty() below). This will be overwritten
- during load_parameter_state().
-
- ** Note: Torch optimizer's state structure. **
- The Torch optimizer stores its state in two levels. The top level is a
- list of groups, where each group contains a list of integer indexes
- (corresponding to parameters) that index into a master parameter list
- that is shared by all groups. As such, three values are necessary for
- maintaining this ordering:
-
- - group_index : The group to which a parameter belongs.
- - group_order : The index of a parameter within its group.
- - state_order : The index of a parameter within the shared parameter
- list.
- """
-
- # Get the Torch optimizer's state dict.
- # - This 'inner' optimizer at this point is unallocated, and only
-        #   contains an integer ordering of parameters within each group, and
- # the ordering of parameters within its flattened parameter state
- # list.
- inner_state_dict = self.optimizer.state_dict()
- state_dict_param_groups = [{
- **group,
- "params" : list(inner_state_dict["param_groups"][idx]["params"]),
- } for idx, group in enumerate(state_dict["optimizer"]["param_groups"])]
-
- # Allocate 'dummy' data for optimizer state (i.e., torch.empty() below)
- # - Real data is overwritten during load_parameter_state().
- state_dict_state = []
- for gbuf_range_maps in self.model_gbuf_ranges:
- for gbuf_range_map_for_all_buckets in gbuf_range_maps.values():
- for gbuf_range_map in gbuf_range_map_for_all_buckets:
- for model_param, param_range_map in \
- gbuf_range_map["param_map"].items():
-
- # Get parameter ordering information (see method docstring
- # for details).
- group_index, group_order = \
- self.model_param_group_index_map[model_param]
- state_order = inner_state_dict["param_groups"] \
- [group_index]["params"][group_order]
-
- # Allocate dummy tensors.
- numel = len(param_range_map["gbuf_world"])
- init_shard = lambda : torch.empty(
- (numel,),
- dtype=torch.float32,
- device=torch.cuda.current_device())
-
- state_dict_state.append((state_order, {
- "exp_avg" : init_shard(),
- "exp_avg_sq" : init_shard(),
- }))
-
- # Sort by state order (see method docstring for details).
- state_dict_state.sort(key = lambda s : s[0])
- state_dict_state = {s[0]:s[1] for s in state_dict_state}
-
- # Optimizer.
- self.optimizer.load_state_dict({
- "state" : state_dict_state,
- "param_groups" : state_dict_param_groups,
- })
-
- # Grad scaler.
- if 'grad_scaler' not in state_dict:
- if self.fp16:
- print_rank_0('***WARNING*** found an old checkpoint, will not '
- 'load grad scaler ...')
- else:
- if self.grad_scaler:
- self.grad_scaler.load_state_dict(state_dict['grad_scaler'])
- else:
-                print_rank_0('***WARNING*** found the grad scaler in the '
- 'checkpoint but it is None in the class. '
- 'Skipping loading grad scaler ...')
-
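# A standalone sketch of the two-level Torch optimizer state layout referenced in
# the docstring above (group_index / group_order / state_order). The tiny Linear
# model and the Adam settings are hypothetical, used only to show the structure.
import torch

model = torch.nn.Linear(4, 4)
opt = torch.optim.Adam([{"params": [model.weight]},   # group_index 0
                        {"params": [model.bias]}])    # group_index 1
model(torch.randn(2, 4)).sum().backward()
opt.step()
sd = opt.state_dict()
print(sd["param_groups"][0]["params"])  # [0] -> weight: group_order 0, state_order 0
print(sd["param_groups"][1]["params"])  # [1] -> bias:  group_order 0, state_order 1
print(sorted(sd["state"].keys()))       # [0, 1] -> per-param state keyed by state_order
print(sorted(sd["state"][0].keys()))    # ['exp_avg', 'exp_avg_sq', 'step']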
-
- def save_parameter_state(self, filename):
- """Save parameter state (i.e., parameter & optimizer tensors).
-
- This method performs three steps:
- - For each DP rank, copy param & optimizer shards to contiguous CPU
- buffers. (e.g., one buffer each for main_param, exp_avg, and
- exp_avg_sq).
- - Gather contiguous buffers on DP rank 0 and concatenate to world
- buffers.
- - Save world buffers to disk (i.e., distrib_opt.pt).
- """
-
- # Data parallelism variables.
- data_parallel_world_size = mpu.get_data_parallel_world_size(with_context_parallel=True)
- data_parallel_rank = mpu.get_data_parallel_rank(with_context_parallel=True)
- data_parallel_group_gloo = mpu.get_data_parallel_group_gloo(with_context_parallel=True)
- data_parallel_global_ranks = list(mpu._DATA_PARALLEL_GLOBAL_RANKS_WITH_CP)
-
- # Collect param states.
- state = {"per_bucket_numel": self.per_bucket_numel}
- for model_idx, gbuf_range_maps in enumerate(self.model_gbuf_ranges):
-
- # Iterate grad buffers (by data type).
- dtype_state = {}
- assert len(gbuf_range_maps) == 1, "single dtype supported, for now."
- for dtype, gbuf_range_map_for_all_buckets in gbuf_range_maps.items():
- world_tensors = {}
- for bucket_idx, gbuf_range_map in enumerate(gbuf_range_map_for_all_buckets):
-
- # Compute local DP contiguous shard's size.
- model = self.models[model_idx]
- gbuf_world_numel = model.grad_buffers[dtype].buckets[bucket_idx].data.numel()
- assert gbuf_world_numel % data_parallel_world_size == 0
- gbuf_local_numel = gbuf_world_numel // data_parallel_world_size
- local_shards = {key: torch.empty((gbuf_local_numel,),
- dtype=torch.float32,
- device="cpu")
- for key in ("param", "exp_avg", "exp_avg_sq")}
-
- # Build contiguous DP rank shards (for param + optim states).
- for model_param, param_range_map in \
- gbuf_range_map["param_map"].items():
-
- # Main param & optimizer states.
- group_index, group_order = \
- self.model_param_group_index_map[model_param]
- main_param = self.optimizer.param_groups \
- [group_index]["params"][group_order]
- optim_state = self.optimizer.state[main_param]
-
- tensors = {
- "param" : main_param,
- **optim_state,
- }
-
- # Copy states into contiguous shard.
- gbuf_local_start = param_range_map["gbuf_local"].start
- gbuf_local_end = param_range_map["gbuf_local"].end
- for key in local_shards:
- local_shards[key][gbuf_local_start:gbuf_local_end] \
- .data.copy_(tensors[key].detach().cpu())
-
- # Gather contiguous shards on DP rank 0.
- for key, send_tensor in local_shards.items():
-
- # Gather tensor list.
- if data_parallel_rank == 0:
- recv_tensors = [torch.empty((gbuf_local_numel,),
- dtype=torch.float32,
- device="cpu")
- for _ in range(data_parallel_world_size)]
- else:
- recv_tensors = None
-
- # Gather.
- torch.distributed.gather(
- send_tensor,
- recv_tensors,
- data_parallel_global_ranks[0],
- data_parallel_group_gloo,
- )
-
- # Concatenate.
- if data_parallel_rank == 0:
- if key not in world_tensors:
- world_tensors[key] = []
- world_tensors[key].append(torch.cat(recv_tensors))
-
- # Collect world state.
- dtype_state[dtype] = world_tensors
- state[model_idx] = dtype_state
-
- # Save param state.
- if data_parallel_rank == 0:
- torch.save(state, filename)
-
-
- def load_parameter_state(self, filename):
- """Load parameter state (i.e., parameter & optimizer tensors).
-
- This method performs the reverse of save_parameter_state():
- - Load world buffers from disk (i.e., distrib_opt.pt).
- - Scatter contiguous buffers from DP rank 0 to each DP rank (each DP
- rank receives its relevant subset of the world buffers).
- - For each DP rank, copy param & optimizer shards from contiguous CPU
- buffers. (e.g., one buffer each for main_param, exp_avg, and
- exp_avg_sq).
- """
-
- # Data parallelism variables.
- data_parallel_world_size = mpu.get_data_parallel_world_size(with_context_parallel=True)
- data_parallel_rank = mpu.get_data_parallel_rank(with_context_parallel=True)
- data_parallel_group_gloo = mpu.get_data_parallel_group_gloo(with_context_parallel=True)
- data_parallel_global_ranks = list(mpu._DATA_PARALLEL_GLOBAL_RANKS_WITH_CP)
-
- # Load on DP rank 0.
- if data_parallel_rank == 0:
- loaded_state = torch.load(filename)
- if "per_bucket_numel" in loaded_state:
- per_bucket_numel_in_checkpoint = loaded_state["per_bucket_numel"]
- assert self.per_bucket_numel == per_bucket_numel_in_checkpoint, \
- (f"Number of elements in each bucket need to be the same in current run "
- f"({self.per_bucket_numel}) and checkpoint ({per_bucket_numel_in_checkpoint})")
-
- # Scatter tensors to all DP ranks.
- for model_idx, gbuf_range_maps in enumerate(self.model_gbuf_ranges):
- for dtype, gbuf_range_map_for_all_buckets in gbuf_range_maps.items():
- for bucket_idx, gbuf_range_map in enumerate(gbuf_range_map_for_all_buckets):
-
- # Compute local DP contiguous shard's size.
- model = self.models[model_idx]
- gbuf_world_numel = model.grad_buffers[dtype].buckets[bucket_idx].data.numel()
- assert gbuf_world_numel % data_parallel_world_size == 0
- gbuf_local_numel = gbuf_world_numel // data_parallel_world_size
-
- # Contiguous local shards (received from DP rank 0).
- local_shards = {key: torch.empty((gbuf_local_numel,),
- dtype=torch.float32,
- device="cpu")
- for key in ("param", "exp_avg", "exp_avg_sq")}
-
- # Scatter local shards from DP rank 0.
- for key, recv_tensor in local_shards.items():
-
- # Scatter tensor list.
- if data_parallel_rank == 0:
- world_tensor_for_all_buckets = loaded_state[model_idx][dtype][key]
- if not isinstance(world_tensor_for_all_buckets, list):
- world_tensor_for_all_buckets = [world_tensor_for_all_buckets]
- assert bucket_idx < len(world_tensor_for_all_buckets), \
- (f"Trying to load state for bucket_id {bucket_idx} (out of "
- f"{len(gbuf_range_map_for_all_buckets)} buckets) from checkpoint; "
- f"checkpoint only has {len(world_tensor_for_all_buckets)} bucket(s)")
- world_tensor = world_tensor_for_all_buckets[bucket_idx]
- gbuf_start_idxs = \
- list(range(0, gbuf_world_numel, gbuf_local_numel))
- send_tensors = [world_tensor[i:(i+gbuf_local_numel)]
- for i in gbuf_start_idxs]
- else:
- send_tensors = None
-
- # Scatter.
- torch.distributed.scatter(
- recv_tensor,
- send_tensors,
- data_parallel_global_ranks[0],
- data_parallel_group_gloo,
- )
-
- # Copy local contiguous shards to param/optim shards.
- for model_param, param_range_map in \
- gbuf_range_map["param_map"].items():
-
- # Main param & optimizer states.
- group_index, group_order = \
- self.model_param_group_index_map[model_param]
- main_param = self.optimizer.param_groups \
- [group_index]["params"][group_order]
- optim_state = self.optimizer.state[main_param]
-
- tensors = {
- "param" : main_param,
- **optim_state,
- }
-
- # Copy states into contiguous shard.
- gbuf_local_start = param_range_map["gbuf_local"].start
- gbuf_local_end = param_range_map["gbuf_local"].end
- for key in local_shards:
- tensors[key].data.copy_(
- local_shards[key][gbuf_local_start:gbuf_local_end])
-
-
- def zero_grad(self, set_to_none=True):
- """
- Zero grads.
-
- We only need to zero the model related parameters, i.e.,
- model_float16_groups & model_fp32_groups. We additionally zero
- the remaining groups as a memory optimization to reduce
- fragmentation; in the case of set_to_none==True, the space
- used by this field can be safely deallocated at this point.
- """
- for groups in (
- self.model_float16_groups,
- self.model_fp32_groups,
- self.shard_float16_groups, # grad empty/unused here?
- self.shard_fp32_groups, # throws grad-access warning
- self.shard_fp32_from_float16_groups):
- for group in groups:
- _zero_grad_group_helper(group, set_to_none)
-
- # If overlapping param all-gather with forward compute, launch all-gather
- # for first accessed bucket here before forward compute is initiated.
- # The all-gather for the next bucket will be launched in the forward
- # pre-hook when this all-gather finishes (to ensure that the communication
- # kernels don't head-of-line block the compute kernels since we run with
- # CUDA_DEVICE_MAX_CONNECTIONS=1 to support sequence parallelism).
- if self.overlap_param_gather:
- self._dispatch_gather_model_params(all_gather_handle_index=0)
-
-
- def get_model_param_buffer_dp_views(self):
- """
- Get shard views of each of the param buffers.
-
- In this nested list, the top level is grouped by the virtual model
- index and the buffer's data type. The sub-level is a list of
- shards of that buffer, where each shard in the list represents
- a contiguous view of the buffer, that is owned by a data-parallel
- rank. The shard boundary does not respect parameter boundaries, and
- so the elements of some parameters are split across data parallel
- ranks.
-
- Additionally, return references to the entire buffers, for use
- in _all_gather_base.
- """
-
- # Buffer views.
- # Add in reverse order in each model chunk since buckets start from the end of the model but we want
- # all-gathers to run first for the start of the model (same order as forward pass).
- # We keep the view_items in model chunk order since we want to still first run all_gather and
- # all_gather_handle.wait() for the first model chunk.
- # In all cases, we want all_gather and all_gather_handle.wait() to be called in the same order,
- # and all_gather_handle.wait() needs to be called just before the corresponding forward pass.
- view_items = []
- for model_index, buffers in enumerate(self.param_buffers):
- view_items_per_model_chunk = []
- for dtype, buf_for_all_buckets in buffers.items():
- for bucket_index, buf in enumerate(buf_for_all_buckets):
- buf_views = shard_buffer(buf)
- view_items_per_model_chunk.insert(0, (model_index, dtype, bucket_index, buf, buf_views))
- view_items.extend(view_items_per_model_chunk)
-
- return view_items
-
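# A toy illustration of the insert-at-front ordering above: bucket 0 of each model
# chunk covers the end of that chunk, so reversing the buckets within a chunk makes
# the all-gather order match the forward pass, while chunks keep their original
# order. The chunk/bucket names are hypothetical labels.
buckets_per_chunk = {0: ["chunk0/bucket0", "chunk0/bucket1"],
                     1: ["chunk1/bucket0", "chunk1/bucket1"]}
view_items = []
for chunk_index, buckets in buckets_per_chunk.items():
    per_chunk = []
    for bucket in buckets:
        per_chunk.insert(0, bucket)
    view_items.extend(per_chunk)
print(view_items)
# -> ['chunk0/bucket1', 'chunk0/bucket0', 'chunk1/bucket1', 'chunk1/bucket0']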
-
- def _dispatch_gather_model_params(self, all_gather_handle_index):
- """
- All-gather updated model params.
-
- The DDP's param buffer is used for the all-gather, and thus no
- tensors are dynamically allocated. After the all-gather, the params
- can be copied from the param buffer to the param.
- """
- if self.update_successful:
- data_parallel_rank = mpu.get_data_parallel_rank(with_context_parallel=True)
- data_parallel_group = mpu.get_data_parallel_group(with_context_parallel=True)
-
- # All-gather updated main params.
- # All param_buf views are guaranteed to have the same number of elements
-            # across all data-parallel ranks, due to padding (done in grad_buffer.py)
-            # that also extends to the param_bufs. Thus, all sub-views will have consistent
- # start / end indexes across data-parallel ranks.
- (model_index, dtype, bucket_index, pbuf, pbuf_views) = self.pbuf_view_items[all_gather_handle_index]
- assert all_gather_handle_index == len(self.all_gather_handles)
- all_gather_handle = torch.distributed._all_gather_base(
- pbuf,
- pbuf_views[data_parallel_rank],
- group = data_parallel_group,
- async_op = self.overlap_param_gather
- )
- self.all_gather_handles.append(all_gather_handle)
- assert self.all_gather_handle_index_to_bucket_index_map[all_gather_handle_index] == \
- (model_index, dtype, bucket_index)
- self.param_buffer_copied.append(False)
-
- if not self.overlap_param_gather:
- self._copy_params_from_param_buffer(all_gather_handle_index)
-
-
-
- def _make_forward_pre_hook(self):
- """
- Create a forward pre-hook to wait on all-gather handles when necessary (i.e.,
- when a module uses a parameter in a bucket with a still incomplete all-gather)
- and then copy the results from the param_buffer into model_params.
- """
-
- def hook(module, *unused):
- assert self.overlap_param_gather, "Should use pre-hook only when overlap_param_gather is True"
-
- # Make sure all parameters in this module have been all-gathered as necessary.
- for param in module.parameters(recurse=False):
- # Skip parameters that don't require grad.
- if not param.requires_grad:
- continue
-
- assert param in self.param_to_all_gather_handle_index_map
- all_gather_handle_index = self.param_to_all_gather_handle_index_map[param]
- self._finish_param_sync_helper(all_gather_handle_index)
-
- return hook
-
-
- def finish_param_sync(self, model_index, *unused):
- """
- Finishes all necessary param syncs for the model_index'th model chunk.
- """
- all_gather_handle_indices = self.model_index_to_all_gather_handle_index_map[model_index]
- for all_gather_handle_index in all_gather_handle_indices:
- self._finish_param_sync_helper(all_gather_handle_index)
-
-
- def _finish_param_sync_helper(self, all_gather_handle_index):
- """
- Waits on all_gather_handle if necessary, then copies params from param_buffer
- into model_params if necessary.
- """
-
- # First check if there is an outstanding all-gather handle for this param.
- # If so, wait on the handle to ensure the communication is finished.
- if all_gather_handle_index >= len(self.all_gather_handles):
- return
-
- all_gather_handle = self.all_gather_handles[all_gather_handle_index]
- if all_gather_handle is not None:
- all_gather_handle.wait()
- self.all_gather_handles[all_gather_handle_index] = None
-
- # Launch the all-gather for the next bucket now.
- # We can't pre-launch all-gathers for all buckets at once since we don't
- # want to head-of-line block the compute kernels with communication kernels
- # (since we run with CUDA_DEVICE_MAX_CONNECTIONS=1 to support sequence
- # parallelism).
- next_all_gather_handle_index = all_gather_handle_index + 1
- if next_all_gather_handle_index < self.num_all_gather_handles:
- self._dispatch_gather_model_params(next_all_gather_handle_index)
-
- # Also check if we have already copied from the param buffer for this
- # handle; if not, complete the copy and mark as such.
- if not self.param_buffer_copied[all_gather_handle_index]:
- self._copy_params_from_param_buffer(all_gather_handle_index)
- self.param_buffer_copied[all_gather_handle_index] = True
-
-
- def _copy_params_from_param_buffer(self, all_gather_handle_index):
- """
- Copy params from param_buffer to model_params.
- """
- (model_index, dtype, bucket_index) = self.all_gather_handle_index_to_bucket_index_map[
- all_gather_handle_index]
- model = self.models[model_index]
- if self.update_successful:
- # Copy from param buffer to each param.
- param_map = model.grad_buffer_param_index_map[dtype]
- for param, (buf_start, buf_end, bucket_index_in_param_map) in param_map.items():
- if bucket_index == bucket_index_in_param_map:
- bucket_offset = model.grad_buffers[dtype].buckets[bucket_index].offset
- param_buf = self.param_buffers[model_index][dtype][bucket_index]
- # buf_start and buf_end store position of this parameter in the full grad_buffer,
- # so need to adjust these indices (by subtracting out bucket_offset) since we
- # have independent param_bufs for each bucket.
- param_buf_shard = param_buf[buf_start-bucket_offset:buf_end-bucket_offset]
- assert param.data.nelement() == param_buf_shard.nelement()
- param.view(-1).detach().copy_(param_buf_shard)
-
- # Zero out the grad buffer in preparation for next set of fwd / bwd passes after copy
- # completes (since param_buffer and grad_buffer are shared for each bucket).
- param_buf = self.param_buffers[model_index][dtype][bucket_index]
- grad_buf = model.grad_buffers[dtype].buckets[bucket_index].data
- assert param_buf.data_ptr() == grad_buf.data_ptr()
- grad_buf.zero_()
-
-
- def _collect_main_grad_data_for_unscaling(self):
- """
- Note: this should be equivalent to the float-16 optimizer's method,
-        but written differently, so the two should be combined.
- """
- return [
- param.grad.data
- for group in self.optimizer.param_groups
- for param in group["params"]
- ]
-
-
- def _get_model_and_main_params_data_float16(self):
- """
- Get aligned list of model and main params.
- """
- model_data = []
- main_data = []
- for model_group, main_group in zip(self.shard_float16_groups,
- self.shard_fp32_from_float16_groups):
- for model_param, main_param in zip(model_group, main_group):
- model_data.append(model_param.data)
- main_data.append(main_param.data)
- return model_data, main_data
-
-
- def _copy_model_grads_to_main_grads(self):
- """
- Copy model grads to main grads.
-
- Since this step follows a reduce-scatter through the DDP's grad
- buffer, this method is responsible for copying the updated grads
- from the grad buffer to the main shard's grad field.
- """
-
- # Utility method for copying group grads.
- def copy_group_grads(model_groups, shard_main_groups):
- for model_group, shard_main_group in zip(model_groups,
- shard_main_groups):
- for model_param, shard_main_param in zip(model_group,
- shard_main_group):
-
- param_range_map = self.get_model_param_range_map(model_param)
- param_range = param_range_map["param"]
- assert param_range.size == shard_main_param.nelement()
-
- model_grad = model_param.main_grad
- shard_model_grad = model_grad.view(-1) \
- [param_range.start:param_range.end]
- shard_main_param.grad = shard_model_grad.float()
-
- # Copy model groups to shard groups.
- copy_group_grads(self.model_float16_groups,
- self.shard_fp32_from_float16_groups)
- copy_group_grads(self.model_fp32_groups,
- self.shard_fp32_groups)
-
-
- def _copy_main_params_to_model_params(self):
- """
- Copy main params to model params.
-
- Since this step is followed by an all-gather through the DDP's grad
- buffer, this method is responsible for copying the updated params
- from the main shards into the correct position in the grad buffer.
- """
-
- # Utility method for copying group params.
- def copy_group_params(shard_main_groups, model_groups):
- for shard_main_group, model_group in zip(shard_main_groups,
- model_groups):
- for shard_main_param, model_param in zip(shard_main_group,
- model_group):
-
- param_range_map = self.get_model_param_range_map(model_param)
- world_range = param_range_map["gbuf_world_in_bucket"]
-
- assert world_range.size == shard_main_param.nelement()
-
- model_id, dtype, bucket_id = self.model_param_gbuf_map[model_param]
- model_param_buffer = self.param_buffers[model_id][dtype][bucket_id]
-
- shard_model_param = model_param_buffer.view(-1) \
- [world_range.start:world_range.end]
-
- shard_model_param.data.copy_(shard_main_param)
-
- # Copy shard groups to model groups.
- copy_group_params(self.shard_fp32_from_float16_groups,
- self.model_float16_groups)
- copy_group_params(self.shard_fp32_groups,
- self.model_fp32_groups)
-
-
- def _copy_model_params_to_main_params(self):
- """
- Copy model params to main params.
-
- During finetuning, this method is used to reload the main params from
- the model params. This copy does not make use of the grad buffer as
- an intermediary.
- """
-
- # Utility method for copying group params.
- def copy_group_params(model_groups, shard_main_groups):
- for model_group, shard_main_group in zip(model_groups,
- shard_main_groups):
- for model_param, shard_main_param in zip(model_group,
- shard_main_group):
-
- param_range_map = self.get_model_param_range_map(model_param)
- param_range = param_range_map["param"]
- assert param_range.size == shard_main_param.nelement()
-
- shard_model_param = model_param.view(-1) \
- [param_range.start:param_range.end]
- shard_main_param.data.copy_(shard_model_param)
-
- # Copy model groups to shard groups.
- copy_group_params(self.model_float16_groups,
- self.shard_fp32_from_float16_groups)
- copy_group_params(self.model_fp32_groups,
- self.shard_fp32_groups)
-
-
- @torch.no_grad()
- def step(self, args, timers):
- self.update_successful, grad_norm, num_zeros_in_grad = super().step(args, timers)
-
- # Reset metadata needed to track results of all-gathers.
- self.all_gather_handles = []
- self.param_buffer_copied = []
-
- # If not overlapping all-gather for parameters, launch synchronous all-gather
- # communication calls here.
- if not self.overlap_param_gather:
- timers('params-all-gather', log_level=1).start(barrier=args.barrier_with_L1_time)
- for all_gather_handle_index in range(self.num_all_gather_handles):
- self._dispatch_gather_model_params(all_gather_handle_index)
- timers('params-all-gather').stop()
-
- return self.update_successful, grad_norm, num_zeros_in_grad
diff --git a/megatron/optimizer/grad_scaler.py b/megatron/optimizer/grad_scaler.py
deleted file mode 100644
index 66f7c907a41816de17e0d6a7bd5cb626722b33a6..0000000000000000000000000000000000000000
--- a/megatron/optimizer/grad_scaler.py
+++ /dev/null
@@ -1,120 +0,0 @@
-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-
-"""Megatron grad scaler."""
-
-from abc import ABC
-from abc import abstractmethod
-
-import torch
-
-
-class MegatronGradScaler(ABC):
-
- def __init__(self, initial_scale):
- """Initialize scale value with the input initial scale."""
- assert initial_scale > 0.0
- self._scale = torch.cuda.FloatTensor([initial_scale])
-
- @property
- def scale(self):
- return self._scale
-
- @property
- def inv_scale(self):
- return self._scale.double().reciprocal().float()
-
- @abstractmethod
- def update(self, found_inf):
- pass
-
- @abstractmethod
- def state_dict(self):
- pass
-
- @abstractmethod
- def load_state_dict(self, state_dict):
- pass
-
-
-
-class ConstantGradScaler(MegatronGradScaler):
-
- def update(self, found_inf):
- pass
-
- def state_dict(self):
- return dict()
-
- def load_state_dict(self, state_dict):
- pass
-
-
-
-class DynamicGradScaler(MegatronGradScaler):
-
- def __init__(self, initial_scale, min_scale,
- growth_factor, backoff_factor,
- growth_interval, hysteresis):
- """"Grad scaler with dynamic scale that gets adjusted
- during training."""
- super(DynamicGradScaler, self).__init__(initial_scale)
-
- # Lower bound on the scale.
- assert min_scale > 0.0
- assert min_scale <= initial_scale
- self.min_scale = torch.cuda.FloatTensor([min_scale])
- # Growth and backoff factors for the scale.
- assert growth_factor > 1.0
- self.growth_factor = torch.cuda.FloatTensor([growth_factor])
- assert backoff_factor < 1.0
- assert backoff_factor > 0.0
- self.backoff_factor = torch.cuda.FloatTensor([backoff_factor])
- # Interval over which if we don't see any inf/nan,
- # we will scale the grad scale by the growth factor.
- assert growth_interval > 0
- self.growth_interval = growth_interval
- # Number of inf/nans we should see before scaling down
- # the grad scale by the backoff factor.
- assert hysteresis > 0
- self.hysteresis = hysteresis
-
- # Trackers.
- self._growth_tracker = 0
- self._hysteresis_tracker = self.hysteresis
-
-
- def update(self, found_inf):
-
- # If we have an inf/nan, growth tracker is set to 0
-        # and hysteresis tracker is reduced by 1.
- if found_inf:
- self._growth_tracker = 0
- self._hysteresis_tracker -= 1
- # Now if we are out of hysteresis count, scale down the loss.
- if self._hysteresis_tracker <= 0:
- self._scale = torch.max(self._scale * self.backoff_factor,
- self.min_scale)
- else:
- # If there is no nan/inf, increment the growth tracker.
- self._growth_tracker += 1
-            # If we have had enough consecutive intervals with no nan/inf:
- if self._growth_tracker == self.growth_interval:
- # Reset the tracker and hysteresis trackers,
- self._growth_tracker = 0
- self._hysteresis_tracker = self.hysteresis
- # and scale up the loss scale.
- self._scale = self._scale * self.growth_factor
-
-
- def state_dict(self):
- state_dict = {}
- state_dict['scale'] = self._scale
- state_dict['growth_tracker'] = self._growth_tracker
- state_dict['hysteresis_tracker'] = self._hysteresis_tracker
- return state_dict
-
-
- def load_state_dict(self, state_dict):
- self._scale = state_dict['scale'].cuda(torch.cuda.current_device())
- self._growth_tracker = state_dict['growth_tracker']
- self._hysteresis_tracker = state_dict['hysteresis_tracker']
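A brief reviewer's note on the `DynamicGradScaler` removed above: its behaviour is easiest to see in isolation. The sketch below is a minimal, CPU-only re-implementation of the same growth/backoff rule, using plain Python floats instead of `torch.cuda.FloatTensor`; the default values are illustrative only and are not taken from Megatron's argument defaults.

```python
# Minimal sketch of the dynamic loss-scale rule above (illustration only).
def simulate_dynamic_scale(found_inf_trace,
                           initial_scale=2.0 ** 16, min_scale=1.0,
                           growth_factor=2.0, backoff_factor=0.5,
                           growth_interval=4, hysteresis=2):
    scale = initial_scale
    growth_tracker = 0
    hysteresis_tracker = hysteresis
    history = []
    for found_inf in found_inf_trace:
        if found_inf:
            # An overflow resets the growth counter and consumes hysteresis.
            growth_tracker = 0
            hysteresis_tracker -= 1
            if hysteresis_tracker <= 0:
                scale = max(scale * backoff_factor, min_scale)
        else:
            # A clean step counts toward growing the scale.
            growth_tracker += 1
            if growth_tracker == growth_interval:
                growth_tracker = 0
                hysteresis_tracker = hysteresis
                scale *= growth_factor
        history.append(scale)
    return history

# Four clean steps grow the scale once; two consecutive overflows back it off.
print(simulate_dynamic_scale([False] * 4 + [True, True]))
# [65536.0, 65536.0, 65536.0, 131072.0, 131072.0, 65536.0]
```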
diff --git a/megatron/optimizer/optimizer.py b/megatron/optimizer/optimizer.py
deleted file mode 100644
index 23749959b90086c7904962c221199e0cd8cb699d..0000000000000000000000000000000000000000
--- a/megatron/optimizer/optimizer.py
+++ /dev/null
@@ -1,644 +0,0 @@
-# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
-
-"""Megatron optimizer."""
-
-from abc import ABC
-from abc import abstractmethod
-from apex.multi_tensor_apply import multi_tensor_applier
-import amp_C
-import torch
-
-from megatron import get_timers
-from megatron import print_rank_0
-from megatron.core import mpu, tensor_parallel
-from megatron.model import Float16Module
-from megatron.model.module import param_is_not_shared
-
-from .clip_grads import clip_grad_norm_fp32, count_zeros_fp32
-
-
-def _zero_grad_group_helper(group, set_to_none):
- """Zero out the gradient for a group of parameters.
- Note: copied from torch.optim.optimizer."""
- for param in group:
- if param.grad is not None:
- if set_to_none:
- param.grad = None
- else:
- if param.grad.grad_fn is not None:
- param.grad.detach_()
- else:
- param.grad.requires_grad_(False)
- param.grad.zero_()
-
-
-def _multi_tensor_copy_this_to_that(this, that, overflow_buf=None):
- """Use multi-tensor-applier to copy values from one list to another.
-    We don't have a bfloat16 implementation so for now if the overflow_buf
- is not provided, we default back to simple loop copy to be compatible
- with bfloat16."""
- if overflow_buf:
- overflow_buf.fill_(0)
- # Scaling with factor `1.0` is equivalent to copy.
- multi_tensor_applier(amp_C.multi_tensor_scale,
- overflow_buf,
- [this, that],
- 1.0)
- else:
- for this_, that_ in zip(this, that):
- that_.copy_(this_)
-
-
-
-class MegatronOptimizer(ABC):
-
-
- def __init__(self, optimizer, clip_grad,
- log_num_zeros_in_grad,
- check_for_nan_in_grad,
- params_have_main_grad,
- models):
-
- """Input optimizer is the base optimizer for example Adam."""
- self.optimizer = optimizer
- assert self.optimizer, 'no optimizer is provided.'
- # Set gradient clipping and logging params.
- self.clip_grad = clip_grad
- self.log_num_zeros_in_grad = log_num_zeros_in_grad
- self.check_for_nan_in_grad = check_for_nan_in_grad
- self.params_have_main_grad = params_have_main_grad
-
- # 'models' are retained for access to the contiguous grad buffers.
- # (see distributed optimizer)
- self.models = models
-
-
- def get_parameters(self):
- params = []
- for param_group in self.optimizer.param_groups:
- for param in param_group['params']:
- params.append(param)
- return params
-
-
- def get_main_grads_for_grad_norm(self):
-
- # Filter parameters based on:
- # - grad should not be none
- # - parameter should not be shared
- # - should not be a replica due to tensor model parallelism
- params = self.get_parameters()
- grads_for_norm = []
- for param in params:
- grad = param.grad
- grad_not_none = grad is not None
- is_not_shared = param_is_not_shared(param)
- is_not_tp_duplicate = tensor_parallel.param_is_not_tensor_parallel_duplicate(param)
- if grad_not_none and is_not_shared and is_not_tp_duplicate:
- grads_for_norm.append(grad)
-
- return grads_for_norm
-
-
- def get_model_parallel_group(self):
- """Default returned here, but the distributed optimizer overrides this."""
- return mpu.get_model_parallel_group()
-
-
- def clip_grad_norm(self, clip_grad, check_for_nan_in_grad):
- params = self.get_parameters()
- grads_for_norm = self.get_main_grads_for_grad_norm()
- return clip_grad_norm_fp32(
- params, grads_for_norm, clip_grad,
- check_for_nan_in_grad,
- model_parallel_group=self.get_model_parallel_group())
-
-
- def count_zeros(self):
- params = self.get_parameters()
- return count_zeros_fp32(params,
- model_parallel_group=self.get_model_parallel_group())
-
-
- @abstractmethod
- def zero_grad(self, set_to_none=True):
- pass
-
-
- @abstractmethod
- def get_loss_scale(self):
- """The output should be a cuda tensor of size 1."""
- pass
-
-
- def scale_loss(self, loss):
- """Simple scaling."""
- return self.get_loss_scale() * loss
-
-
- @abstractmethod
- def reload_model_params(self):
- """Refreshes any internal state from the current model parameters.
- Call whenever the parameters are changed outside of the optimizer.
- For example, when we load a model from a checkpoint without loading
- the optimizer, the model parameters are updated but for fp16 optimizer
- with main parameters, the main parameters need to also be updated."""
- pass
-
-
- @abstractmethod
- def state_dict(self):
- pass
-
-
- @abstractmethod
- def load_state_dict(self, state_dict):
- pass
-
-
- # Promote state so it can be retrieved or set via
- # "optimizer_instance.state"
- def _get_state(self):
- return self.optimizer.state
-
- def _set_state(self, value):
- self.optimizer.state = value
-
- state = property(_get_state, _set_state)
-
-
- # Promote param_groups so it can be retrieved or set via
- # "optimizer_instance.param_groups"
- # (for example, to adjust the learning rate)
- def _get_param_groups(self):
- return self.optimizer.param_groups
-
- def _set_param_groups(self, value):
- self.optimizer.param_groups = value
-
- param_groups = property(_get_param_groups, _set_param_groups)
-
-
- @abstractmethod
- def step(self, args, timers):
- pass
-
-
-
-class MixedPrecisionOptimizer(MegatronOptimizer):
- """Base class for both the float-16 and the distributed optimizer.
-
- Arguments:
- optimizer: base optimizer such as Adam or SGD
-        clip_grad: clip gradients with this global L2 norm. Note
- that clipping is ignored if clip_grad == 0
- log_num_zeros_in_grad: return number of zeros in the gradients.
- check_for_nan_in_grad: check if gradients have a NaN.
- params_have_main_grad: flag indicating if parameters have
- a `main_grad` field. If this is set, we are assuming
-            that the model parameters are stored in the `main_grad`
- field instead of the typical `grad` field. This happens
- for the DDP cases where there is a continuous buffer
- holding the gradients. For example for bfloat16, we want
- to do gradient accumulation and all-reduces in float32
- and as a result we store those gradients in the main_grad.
- Note that main grad is not necessarily in float32.
- fp16: if true, the model is running in fp16.
- bf16: if true, the model is running in bfloat16.
- params_dtype: used by distributed optimizer.
- grad_scaler: used for scaling gradients. Note that this can be
- None. This case happens when `bf16 = True` and we don't
- use any loss scale. Note that for `bf16 = True`, we can have
-            a constant gradient scaler. Also for `bf16 = False`, we
- always require a grad scaler.
- models: list of models (i.e., the virtual pipelining models). This
- is used by the distributed optimizer for mapping parameters.
- """
-
- def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
- check_for_nan_in_grad, params_have_main_grad,
- fp16, bf16, params_dtype, grad_scaler, models):
-
- super().__init__(
- optimizer, clip_grad, log_num_zeros_in_grad,
- check_for_nan_in_grad, params_have_main_grad,
- models)
-
- self.fp16 = fp16
- self.bf16 = bf16
- self.params_dtype = params_dtype
- self.grad_scaler = grad_scaler
-
- # None grad scaler is only supported for bf16.
- if self.grad_scaler is None:
- assert not self.fp16, 'fp16 expects a grad scaler.'
-
-        # Tensor used to determine if a nan/inf has happened.
- # Any non-zero value indicates inf/nan.
- # Note that we keep this for the cases that grad scaler is none.
- # We still record nan/inf if we have a bfloat16 with a grad scaler.
- if self.grad_scaler:
- self.found_inf = torch.cuda.FloatTensor([0.0])
-
- # Dummy tensor needed for apex multi-apply tensor.
- # For bfloat, we don't have multi-tensor apply and for now
- # we set it to none so the multi-tensor apply gets ignored.
- if bf16:
- self._dummy_overflow_buf = None
- else:
- self._dummy_overflow_buf = torch.cuda.IntTensor([0])
-
- # In case grad scaler is not passed, define the unity scale.
- if self.grad_scaler is None:
- self._scale_one = torch.cuda.FloatTensor([1.0])
-
-
- def get_loss_scale(self):
- if self.grad_scaler is None:
- return self._scale_one
- return self.grad_scaler.scale
-
-
- def reload_model_params(self):
- self._copy_model_params_to_main_params()
-
-
- def _unscale_main_grads_and_check_for_nan(self):
-
- # Collect main grads.
- main_grads = self._collect_main_grad_data_for_unscaling()
-
- # Reset found inf.
- self.found_inf.fill_(0.0)
-
- # Unscale and set found inf/nan
- torch._amp_foreach_non_finite_check_and_unscale_(
- main_grads, self.found_inf, self.grad_scaler.inv_scale)
-
- # Update across all model parallel instances.
- torch.distributed.all_reduce(self.found_inf,
- op=torch.distributed.ReduceOp.MAX,
- group=self.get_model_parallel_group())
-
- # Check for nan.
- found_inf_flag = (self.found_inf.item() > 0)
-
- return found_inf_flag
-
-
- @torch.no_grad()
- def step(self, args, timers):
-
- # Copy gradients from model params to main params.
- timers('optimizer-copy-to-main-grad', log_level=1).start(
- barrier=args.barrier_with_L1_time)
- self._copy_model_grads_to_main_grads()
- timers('optimizer-copy-to-main-grad').stop()
-
- # Do unscale, check for inf, and update grad scaler only for
- # the case that grad scaler is provided.
- if self.grad_scaler:
-
- # Unscale and check for inf/nan.
- timers('optimizer-unscale-and-check-inf', log_level=1).start(
- barrier=args.barrier_with_L1_time)
- found_inf_flag = self._unscale_main_grads_and_check_for_nan()
- timers('optimizer-unscale-and-check-inf').stop()
-
- # We are done with scaling gradients
- # so we can update the loss scale.
- self.grad_scaler.update(found_inf_flag)
-
- # If we found inf/nan, skip the update.
- if found_inf_flag:
- return False, None, None
-
- # Clip the main gradients.
- timers('optimizer-clip-main-grad', log_level=1).start(
- barrier=args.barrier_with_L1_time)
- grad_norm = None
- if self.clip_grad > 0.0:
- grad_norm = self.clip_grad_norm(self.clip_grad,
- self.check_for_nan_in_grad)
- timers('optimizer-clip-main-grad').stop()
-
- # Count the zeros in the grads.
- timers('optimizer-count-zeros', log_level=1).start(
- barrier=args.barrier_with_L1_time)
- num_zeros_in_grad = self.count_zeros() if \
- self.log_num_zeros_in_grad else None
- timers('optimizer-count-zeros').stop()
-
- # Step the optimizer.
- timers('optimizer-inner-step', log_level=1).start(
- barrier=args.barrier_with_L1_time)
- self.optimizer.step()
- timers('optimizer-inner-step').stop()
-
- # Update params from main params.
- timers('optimizer-copy-main-to-model-params', log_level=1).start(
- barrier=args.barrier_with_L1_time)
- self._copy_main_params_to_model_params()
- timers('optimizer-copy-main-to-model-params').stop()
-
- # Successful update.
- return True, grad_norm, num_zeros_in_grad
-
-
-class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
- """Float16 optimizer for fp16 and bf16 data types.
-
- Arguments:
- optimizer: base optimizer such as Adam or SGD
-        clip_grad: clip gradients with this global L2 norm. Note
- that clipping is ignored if clip_grad == 0
- log_num_zeros_in_grad: return number of zeros in the gradients.
- check_for_nan_in_grad: check if gradients have a NaN.
- params_have_main_grad: flag indicating if parameters have
- a `main_grad` field. If this is set, we are assuming
-            that the model parameters are stored in the `main_grad`
- field instead of the typical `grad` field. This happens
- for the DDP cases where there is a continuous buffer
- holding the gradients. For example for bfloat16, we want
- to do gradient accumulation and all-reduces in float32
- and as a result we store those gradients in the main_grad.
- Note that main grad is not necessarily in float32.
- fp16: if true, the model is running in fp16.
- bf16: if true, the model is running in bfloat16.
- grad_scaler: used for scaling gradients. Note that this can be
- None. This case happens when `bf16 = True` and we don't
- use any loss scale. Note that for `bf16 = True`, we can have
-            a constant gradient scaler. Also for `bf16 = False`, we
- always require a grad scaler.
- models: list of models (i.e., the virtual pipelining models). This
- is used by the distributed optimizer for mapping parameters.
- """
-
- def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
- check_for_nan_in_grad, params_have_main_grad, fp16, bf16,
- params_dtype, grad_scaler, models):
-
- super().__init__(
- optimizer, clip_grad, log_num_zeros_in_grad,
- check_for_nan_in_grad, params_have_main_grad,
- fp16, bf16, params_dtype, grad_scaler, models)
-
- # ======================
- # main parameter stuff
- # ======================
-
- # Three groups of parameters:
- # float16_groups: original float16 parameters
- # fp32_from_float16_groups: fp32 copy of float16 parameters
- # fp32_from_fp32_groups: original fp32 parameters
- self.float16_groups = []
- self.fp32_from_float16_groups = []
- self.fp32_from_fp32_groups = []
-
- # For all the groups in the original optimizer:
- for param_group in self.optimizer.param_groups:
- float16_params_this_group = []
- fp32_params_this_group = []
- fp32_from_float16_params_this_group = []
- # For all the parameters in this group:
- for i, param in enumerate(param_group['params']):
- if param.requires_grad:
-
- # float16 params:
- if param.type() in ['torch.cuda.HalfTensor',
- 'torch.cuda.BFloat16Tensor']:
- float16_params_this_group.append(param)
- # Create a copy
- main_param = param.detach().clone().float()
- # Copy tensor model parallel attributes.
- tensor_parallel.copy_tensor_model_parallel_attributes(main_param,
- param)
- if hasattr(param, 'shared'):
- main_param.shared = param.shared
- # Replace the optimizer params with the new fp32 copy.
- param_group['params'][i] = main_param
-
- fp32_from_float16_params_this_group.append(main_param)
- # Reset existing state dict key to the new main param.
- if param in self.optimizer.state:
- self.optimizer.state[main_param] \
- = self.optimizer.state.pop(param)
- # fp32 params.
- elif param.type() == 'torch.cuda.FloatTensor':
- fp32_params_this_group.append(param)
- param_group['params'][i] = param
-
- else:
- raise TypeError('Wrapped parameters must be one of '
- 'torch.cuda.FloatTensor, '
- 'torch.cuda.HalfTensor, or '
- 'torch.cuda.BFloat16Tensor. '
- 'Received {}'.format(param.type()))
-
- self.float16_groups.append(float16_params_this_group)
- self.fp32_from_float16_groups.append(
- fp32_from_float16_params_this_group)
- self.fp32_from_fp32_groups.append(fp32_params_this_group)
-
-
- def zero_grad(self, set_to_none=True):
- """We only need to zero the model related parameters, i.e.,
- float16_groups & fp32_from_fp32_groups. We additionally zero
- fp32_from_float16_groups as a memory optimization to reduce
- fragmentation; in the case of set_to_none==True, the space
- used by this field can be safely deallocated at this point."""
- for group in self.float16_groups:
- _zero_grad_group_helper(group, set_to_none)
- for group in self.fp32_from_float16_groups:
- _zero_grad_group_helper(group, set_to_none)
- for group in self.fp32_from_fp32_groups:
- _zero_grad_group_helper(group, set_to_none)
-
-
- def _collect_main_grad_data_for_unscaling(self):
-
- main_grads = []
-
- # fp32 params from float16 ones.
- for main_group in self.fp32_from_float16_groups:
- for main_param in main_group:
- if main_param.grad is not None:
- main_grads.append(main_param.grad.data)
-
- # Append fp32 parameters.
- for main_group in self.fp32_from_fp32_groups:
- for main_param in main_group:
- if main_param.grad is not None:
- main_grads.append(main_param.grad.data)
-
- return main_grads
-
-
- def _get_model_and_main_params_data_float16(self):
- model_data = []
- main_data = []
- for model_group, main_group in zip(self.float16_groups,
- self.fp32_from_float16_groups):
- for model_param, main_param in zip(model_group, main_group):
- model_data.append(model_param.data)
- main_data.append(main_param.data)
- return model_data, main_data
-
-
- def _copy_model_grads_to_main_grads(self):
- # This only needs to be done for the float16 group.
- for model_group, main_group in zip(self.float16_groups,
- self.fp32_from_float16_groups):
- for model_param, main_param in zip(model_group, main_group):
- if self.params_have_main_grad and hasattr(model_param, 'main_grad'):
- main_param.grad = model_param.main_grad.float()
- else:
- if model_param.grad is not None:
- main_param.grad = model_param.grad.float()
-
- # Safe to deallocate model's grad/main_grad after copying.
- # (If using contiguous buffers, main_grad's memory should
- # persist and therefore should not be deallocated.)
- model_param.grad = None
-
- # For fp32 grads, we need to reset the grads to main grad.
- if self.params_have_main_grad:
- for model_group in self.fp32_from_fp32_groups:
- for model_param in model_group:
- model_param.grad = model_param.main_grad
-
-
- def _copy_main_params_to_model_params(self):
- # Only needed for the float16 params.
- model_data, main_data = self._get_model_and_main_params_data_float16()
- _multi_tensor_copy_this_to_that(this=main_data, that=model_data,
- overflow_buf=self._dummy_overflow_buf)
-
-
- def _copy_model_params_to_main_params(self):
- # Only needed for the float16 params.
- model_data, main_data = self._get_model_and_main_params_data_float16()
- _multi_tensor_copy_this_to_that(this=model_data, that=main_data,
- overflow_buf=self._dummy_overflow_buf)
-
-
- def state_dict(self):
- state_dict = {}
- state_dict['optimizer'] = self.optimizer.state_dict()
- if self.grad_scaler:
- state_dict['grad_scaler'] = self.grad_scaler.state_dict()
- state_dict['fp32_from_fp16_params'] = self.fp32_from_float16_groups
- return state_dict
-
-
- def load_state_dict(self, state_dict):
- # Optimizer.
- optimizer_key = 'optimizer'
- if optimizer_key not in state_dict:
- optimizer_key = 'optimizer_state_dict'
- print_rank_0('***WARNING*** loading optimizer from '
- 'an old checkpoint ...')
- self.optimizer.load_state_dict(state_dict[optimizer_key])
-
- # Grad scaler.
- if 'grad_scaler' not in state_dict:
- if self.fp16:
- print_rank_0('***WARNING*** found an old checkpoint, will not '
- 'load grad scaler ...')
- else:
- if self.grad_scaler:
- self.grad_scaler.load_state_dict(state_dict['grad_scaler'])
- else:
-                print_rank_0('***WARNING*** found the grad scaler in the '
- 'checkpoint but it is None in the class. '
- 'Skipping loading grad scaler ...')
-
- # Copy data for the main params.
- fp32_from_float16_params_key = 'fp32_from_fp16_params'
- if fp32_from_float16_params_key not in state_dict:
- fp32_from_float16_params_key = 'fp32_from_fp16'
- for current_group, saved_group in zip(
- self.fp32_from_float16_groups,
- state_dict[fp32_from_float16_params_key]):
- for current_param, saved_param in zip(current_group, saved_group):
- current_param.data.copy_(saved_param.data)
-
-
-class FP32Optimizer(MegatronOptimizer):
-
- def __init__(self, optimizer, clip_grad,
- log_num_zeros_in_grad,
- check_for_nan_in_grad,
- params_have_main_grad,
- models):
-
- super(FP32Optimizer, self).__init__(
- optimizer, clip_grad, log_num_zeros_in_grad,
- check_for_nan_in_grad, params_have_main_grad,
- models)
-
- self._scale = torch.cuda.FloatTensor([1.0])
-
-
- def zero_grad(self, set_to_none=True):
- """Copied from torch.optim.optimizer"""
- for group in self.optimizer.param_groups:
- _zero_grad_group_helper(group['params'], set_to_none)
-
-
- def get_loss_scale(self):
- """FP32 optimizer does not do any scaling."""
- return self._scale
-
-
- @torch.no_grad()
- def step(self, args, timers):
- """Clip gradients (if needed) and step the base optimizer.
- Always return successful since there is no overflow."""
-
- # Copy main_grads to grads.
- timers('optimizer-copy-to-main-grad', log_level=1).start(
- barrier=args.barrier_with_L1_time)
- if self.params_have_main_grad:
- for param_group in self.optimizer.param_groups:
- for param in param_group['params']:
- param.grad = param.main_grad
-
- timers('optimizer-copy-to-main-grad').stop()
-
- # Clip gradients.
- timers('optimizer-clip-main-grad', log_level=1).start(
- barrier=args.barrier_with_L1_time)
- grad_norm = None
- if self.clip_grad > 0.0:
- grad_norm = self.clip_grad_norm(self.clip_grad,
- self.check_for_nan_in_grad)
- timers('optimizer-clip-main-grad').stop()
-
- # count the zeros in the grads
- timers('optimizer-count-zeros', log_level=1).start(
- barrier=args.barrier_with_L1_time)
- num_zeros_in_grad = self.count_zeros() if \
- self.log_num_zeros_in_grad else None
- timers('optimizer-count-zeros').stop()
-
- # Update parameters.
- timers('optimizer-inner-step', log_level=1).start(
- barrier=args.barrier_with_L1_time)
- self.optimizer.step()
- timers('optimizer-inner-step').stop()
-
- # No overflow for FP32 optimizer.
- return True, grad_norm, num_zeros_in_grad
-
-
- def reload_model_params(self):
- pass
-
-
- def state_dict(self):
- return self.optimizer.state_dict()
-
-
- def load_state_dict(self, state_dict):
- self.optimizer.load_state_dict(state_dict)
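To make the master-weight pattern in `Float16OptimizerWithFloat16Params` above concrete, here is a minimal CPU sketch: fp32 copies of the fp16 parameters are created once, gradients are upcast into them, a plain `torch.optim.SGD` steps in fp32, and the result is copied back into the fp16 weights. It is a simplified analogue for illustration, not the removed class.

```python
import torch

# fp16 "model" parameters (ordinary CPU tensors for illustration).
model_params = [torch.randn(4).half().requires_grad_(True)]

# fp32 master copies, as in the float16 optimizer above.
master_params = [p.detach().clone().float().requires_grad_(True) for p in model_params]
opt = torch.optim.SGD(master_params, lr=0.1)

# Pretend a backward pass produced fp16 gradients.
for p in model_params:
    p.grad = torch.ones_like(p)

# 1) Upcast and copy grads from model params to master params.
for model_p, master_p in zip(model_params, master_params):
    master_p.grad = model_p.grad.float()
    model_p.grad = None  # safe to drop the low-precision grad after the copy

# 2) Step the base optimizer in fp32.
opt.step()

# 3) Copy the updated master weights back into the fp16 model params.
with torch.no_grad():
    for model_p, master_p in zip(model_params, master_params):
        model_p.copy_(master_p)

print(model_params[0])
```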
diff --git a/megatron/optimizer/utils.py b/megatron/optimizer/utils.py
deleted file mode 100644
index f4b7cbd634b4f67d3a1fbf5a950cf1cc95664b26..0000000000000000000000000000000000000000
--- a/megatron/optimizer/utils.py
+++ /dev/null
@@ -1,19 +0,0 @@
-# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
-
-"""Utility functions for Megatron optimizer."""
-
-
-from megatron.core import mpu
-
-
-def shard_buffer(buffer):
- """
- Shard buffer into dp_size chunks of equal size.
- """
- data_parallel_world_size = mpu.get_data_parallel_world_size(with_context_parallel=True)
- assert buffer.numel() % data_parallel_world_size == 0
- shard_size = buffer.numel() // data_parallel_world_size
- sharded_buffer = [buffer[(r*shard_size):((r+1)*shard_size)]
- for r in range(data_parallel_world_size)]
- return sharded_buffer
-
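For reference, `shard_buffer` above is just an equal-chunk slicing of a flat tensor. The sketch below reproduces it with the data-parallel world size passed in explicitly, so it runs without `mpu` or a distributed setup (the world size of 4 is an assumption for the example).

```python
import torch

def shard_buffer_local(buffer, data_parallel_world_size):
    # Same slicing as shard_buffer above, with the world size supplied
    # directly instead of being read from mpu.
    assert buffer.numel() % data_parallel_world_size == 0
    shard_size = buffer.numel() // data_parallel_world_size
    return [buffer[r * shard_size:(r + 1) * shard_size]
            for r in range(data_parallel_world_size)]

buf = torch.arange(8)
print(shard_buffer_local(buf, data_parallel_world_size=4))
# [tensor([0, 1]), tensor([2, 3]), tensor([4, 5]), tensor([6, 7])]
```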
diff --git a/megatron/optimizer_param_scheduler.py b/megatron/optimizer_param_scheduler.py
deleted file mode 100644
index 0cf5fb1d8fcdb9bf212828464d1785e1f0284e61..0000000000000000000000000000000000000000
--- a/megatron/optimizer_param_scheduler.py
+++ /dev/null
@@ -1,235 +0,0 @@
-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-
-"""Learning rate decay and weight decay incr functions."""
-
-import math
-
-from megatron import print_rank_0
-
-class OptimizerParamScheduler(object):
- """Anneals learning rate and weight decay"""
-
- def __init__(self, optimizer, init_lr, max_lr, min_lr,
- lr_warmup_steps, lr_decay_steps, lr_decay_style,
- start_wd, end_wd, wd_incr_steps, wd_incr_style,
- use_checkpoint_opt_param_scheduler=True,
- override_opt_param_scheduler=False):
-
- # Class values.
- self.optimizer = optimizer
-
- self.init_lr = init_lr
- self.max_lr = float(max_lr)
- self.min_lr = min_lr
- assert self.min_lr >= 0.0
- assert self.max_lr >= self.min_lr
- assert self.init_lr <= self.max_lr
-
- self.lr_warmup_steps = lr_warmup_steps
- self.num_steps = 0
- self.lr_decay_steps = lr_decay_steps
- assert self.lr_decay_steps > 0
- assert self.lr_warmup_steps < self.lr_decay_steps
-
- self.lr_decay_style = lr_decay_style
-
- self.start_wd = start_wd
- self.end_wd = end_wd
- assert self.start_wd >= 0.0
- assert self.end_wd >= self.start_wd
- self.wd_incr_steps = wd_incr_steps
- self.wd_incr_style = wd_incr_style
-
- self.override_opt_param_scheduler = override_opt_param_scheduler
- self.use_checkpoint_opt_param_scheduler = use_checkpoint_opt_param_scheduler
- if self.override_opt_param_scheduler:
- assert not self.use_checkpoint_opt_param_scheduler, 'both override and '\
- 'use-checkpoint are set.'
-
- # Set the learning rate
- self.step(0)
- print_rank_0('> learning rate decay style: {}'.format(self.lr_decay_style))
-
-
- def get_wd(self):
- """ Weight decay incr functions"""
- if self.num_steps > self.wd_incr_steps:
- return self.end_wd
-
- if self.wd_incr_style == 'constant':
- assert self.start_wd == self.end_wd
- return self.end_wd
-
- incr_ratio = float(self.num_steps) / float(self.wd_incr_steps)
- assert incr_ratio >= 0.0
- assert incr_ratio <= 1.0
- delta_wd = self.end_wd - self.start_wd
-
- if self.wd_incr_style == 'linear':
- coeff = incr_ratio
- elif self.wd_incr_style == 'cosine':
- coeff = 0.5 * (math.cos(math.pi * (1 - incr_ratio)) + 1.0)
- else:
- raise Exception('{} weight decay increment style is not supported.'.format(
- self.wd_incr_style))
-
- return self.start_wd + coeff * delta_wd
-
-
- def get_lr(self):
- """Learning rate decay functions from:
- https://openreview.net/pdf?id=BJYwwY9ll pg. 4"""
-
- # Use linear warmup for the initial part.
- if self.lr_warmup_steps > 0 and self.num_steps <= self.lr_warmup_steps:
- return (
- self.init_lr
- + (
- (self.max_lr - self.init_lr)
- * float(self.num_steps)
- / float(self.lr_warmup_steps)
- )
- )
-
- # If the learning rate is constant, just return the initial value.
- if self.lr_decay_style == 'constant':
- return self.max_lr
-
- # For any steps larger than `self.lr_decay_steps`, use `self.min_lr`.
- if self.num_steps > self.lr_decay_steps:
- return self.min_lr
-
- # If we are done with the warmup period, use the decay style.
- if self.lr_decay_style == 'inverse-square-root':
- warmup_steps = max(self.lr_warmup_steps, 1)
- num_steps = max(self.num_steps, 1)
- lr = self.max_lr * warmup_steps ** 0.5 / (num_steps ** 0.5)
- return max(self.min_lr, lr)
-
- num_steps_ = self.num_steps - self.lr_warmup_steps
- decay_steps_ = self.lr_decay_steps - self.lr_warmup_steps
- decay_ratio = float(num_steps_) / float(decay_steps_)
- assert decay_ratio >= 0.0
- assert decay_ratio <= 1.0
- delta_lr = self.max_lr - self.min_lr
-
- if self.lr_decay_style == 'linear':
- coeff = (1.0 - decay_ratio)
- elif self.lr_decay_style == 'cosine':
- coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0)
- else:
- raise Exception('{} decay style is not supported.'.format(
- self.lr_decay_style))
-
- return self.min_lr + coeff * delta_lr
-
-
- def step(self, increment):
- """Set lr for all parameters groups."""
- self.num_steps += increment
- new_lr = self.get_lr()
- new_wd = self.get_wd()
- for group in self.optimizer.param_groups:
- group['lr'] = new_lr * group.get('lr_mult', 1.0)
- group['weight_decay'] = new_wd * group.get('wd_mult', 1.0)
-
-
- def state_dict(self):
- state_dict = {
- 'max_lr': self.max_lr,
- 'lr_warmup_steps': self.lr_warmup_steps,
- 'num_steps': self.num_steps,
- 'lr_decay_style': self.lr_decay_style,
- 'lr_decay_steps': self.lr_decay_steps,
- 'min_lr': self.min_lr,
- 'start_wd': self.start_wd,
- 'end_wd': self.end_wd,
- 'wd_incr_style': self.wd_incr_style,
- 'wd_incr_steps': self.wd_incr_steps
- }
- return state_dict
-
-
- def _check_and_set(self, cls_value, sd_value, name):
- """Auxiliary function for checking the values in the checkpoint and
- setting them."""
- if self.override_opt_param_scheduler:
- print_rank_0(' > overriding {} value to {}'.format(name, cls_value))
- return cls_value
-
- if not self.use_checkpoint_opt_param_scheduler:
- assert cls_value == sd_value, \
-                f'OptimizerParamScheduler: class input value {cls_value} and checkpoint ' \
- f'value {sd_value} for {name} do not match'
- print_rank_0(' > using checkpoint value {} for {}'.format(sd_value,
- name))
- return sd_value
-
-
- def load_state_dict(self, sd):
-
- if 'start_lr' in sd:
- max_lr_ = sd['start_lr']
- else:
- max_lr_ = sd['max_lr']
- self.max_lr = self._check_and_set(self.max_lr, max_lr_,
- 'learning rate')
-
- self.min_lr = self._check_and_set(self.min_lr, sd['min_lr'],
- 'minimum learning rate')
-
- if 'warmup_iter' in sd:
- lr_warmup_steps_ = sd['warmup_iter']
- elif 'warmup_steps' in sd:
- lr_warmup_steps_ = sd['warmup_steps']
- else:
- lr_warmup_steps_ = sd['lr_warmup_steps']
- self.lr_warmup_steps = self._check_and_set(self.lr_warmup_steps,
- lr_warmup_steps_,
- 'warmup iterations')
-
- if 'end_iter' in sd:
- lr_decay_steps_ = sd['end_iter']
- elif 'decay_steps' in sd:
- lr_decay_steps_ = sd['decay_steps']
- else:
- lr_decay_steps_ = sd['lr_decay_steps']
- self.lr_decay_steps = self._check_and_set(self.lr_decay_steps, lr_decay_steps_,
- 'total number of iterations')
-
- if 'decay_style' in sd:
- lr_decay_style_ = sd['decay_style']
- else:
- lr_decay_style_ = sd['lr_decay_style']
- self.lr_decay_style = self._check_and_set(self.lr_decay_style,
- lr_decay_style_,
- 'learning rate decay style')
-
- if 'num_iters' in sd:
- num_steps = sd['num_iters']
- else:
- num_steps = sd['num_steps']
- self.step(increment=num_steps)
-
-
- if 'start_wd' in sd:
- self.start_wd = self._check_and_set(self.start_wd,
- sd['start_wd'],
- "start weight decay")
- self.end_wd = self._check_and_set(self.end_wd,
- sd['end_wd'],
- "end weight decay")
- self.wd_incr_steps = self._check_and_set(self.wd_incr_steps,
- sd['wd_incr_steps'],
- "total number of weight decay iterations")
- self.wd_incr_style = self._check_and_set(self.wd_incr_style,
- sd['wd_incr_style'],
- "weight decay incr style")
-
-
-
-
-
-
-
-
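The learning-rate logic in `OptimizerParamScheduler.get_lr` above reduces to a pure function of the step count. The sketch below reimplements only the linear-warmup plus cosine-decay branch for illustration (the constant, linear, and inverse-square-root styles are omitted, and the numeric defaults are arbitrary examples).

```python
import math

def warmup_cosine_lr(step, init_lr=0.0, max_lr=3e-4, min_lr=3e-5,
                     warmup_steps=100, decay_steps=1000):
    # Linear warmup, as in get_lr above.
    if warmup_steps > 0 and step <= warmup_steps:
        return init_lr + (max_lr - init_lr) * step / warmup_steps
    # Past the decay horizon, hold at min_lr.
    if step > decay_steps:
        return min_lr
    # Cosine decay between warmup_steps and decay_steps.
    decay_ratio = (step - warmup_steps) / (decay_steps - warmup_steps)
    coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0)
    return min_lr + coeff * (max_lr - min_lr)

for s in (0, 50, 100, 550, 1000, 2000):
    print(s, round(warmup_cosine_lr(s), 6))
```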
diff --git a/megatron/static/index.html b/megatron/static/index.html
deleted file mode 100644
index 806287955bcc02e2d4148855af5ddb36ba94ae72..0000000000000000000000000000000000000000
--- a/megatron/static/index.html
+++ /dev/null
@@ -1,124 +0,0 @@
[index.html body garbled in extraction; recoverable content: a "Prompt Megatron" text-generation demo page with a 0 / 1000 character counter]
diff --git a/megatron/text_generation/__init__.py b/megatron/text_generation/__init__.py
deleted file mode 100644
index 77da7be30ae4d02bd7ab1e4bae86afc8923d4e23..0000000000000000000000000000000000000000
--- a/megatron/text_generation/__init__.py
+++ /dev/null
@@ -1,7 +0,0 @@
-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-
-
-from .api import (
- generate,
- generate_and_post_process,
- beam_search_and_post_process)
diff --git a/megatron/text_generation/api.py b/megatron/text_generation/api.py
deleted file mode 100644
index 4557ff3c12e219daf3a8092390c5a74a7e56b4e0..0000000000000000000000000000000000000000
--- a/megatron/text_generation/api.py
+++ /dev/null
@@ -1,207 +0,0 @@
-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-
-"""Inference API."""
-
-
-import torch
-
-from megatron.core import mpu
-from .communication import broadcast_float_list
-from .generation import (
- generate_tokens_probs_and_return_on_first_stage,
- score_and_return_on_first_stage,
- beam_search_and_return_on_first_stage)
-from .tokenization import (
- tokenize_prompts,
- detokenize_generations)
-
-def generate_and_post_process(model,
- prompts=None,
- tokens_to_generate=0,
- return_output_log_probs=False,
- top_k_sampling=0,
- top_p_sampling=0.0,
- top_p_decay=0.0,
- top_p_bound=0.0,
- temperature=1.0,
- add_BOS=False,
- use_eod_token_for_early_termination=True,
- stop_on_double_eol=False,
- stop_on_eol=False,
- prevent_newline_after_colon=False,
- random_seed=-1,
- return_logits=False):
- """Run inference and post-process outputs, i.e., detokenize,
- move to cpu and convert to list."""
-
- # Main inference.
- tokens, lengths, output_log_probs, logits = generate(
- model,
- prompts=prompts,
- tokens_to_generate=tokens_to_generate,
- return_output_log_probs=return_output_log_probs,
- top_k_sampling=top_k_sampling,
- top_p_sampling=top_p_sampling,
- top_p_decay=top_p_decay,
- top_p_bound=top_p_bound,
- temperature=temperature,
- add_BOS=add_BOS,
- use_eod_token_for_early_termination=use_eod_token_for_early_termination,
- stop_on_double_eol=stop_on_double_eol,
- stop_on_eol=stop_on_eol,
- prevent_newline_after_colon=prevent_newline_after_colon,
- random_seed=random_seed)
-
- # Only post-process on first stage.
- if mpu.is_pipeline_first_stage():
- tokens, prompts_plus_generations, prompts_plus_generations_segments = \
- detokenize_generations(tokens, lengths, True)
-
- if return_output_log_probs:
- output_log_probs = output_log_probs.cpu().numpy().tolist()
- for i, (prob, seg) in enumerate(zip(output_log_probs, prompts_plus_generations_segments)):
- output_log_probs[i] = prob[:len(seg)-1]
-
- if return_logits:
- assert(tokens_to_generate == 0)
- assert(mpu.get_pipeline_model_parallel_world_size() == 1)
- return prompts_plus_generations, prompts_plus_generations_segments, \
- output_log_probs, tokens, logits
- else:
- return prompts_plus_generations, prompts_plus_generations_segments, \
- output_log_probs, tokens
-
- return None
-
-def generate(model,
- prompts=None,
- tokens_to_generate=0,
- return_output_log_probs=False,
- top_k_sampling=0,
- top_p_sampling=0.0,
- top_p_decay=0.0,
- top_p_bound=0.0,
- temperature=1.0,
- add_BOS=False,
- use_eod_token_for_early_termination=True,
- stop_on_double_eol=False,
- stop_on_eol=False,
- prevent_newline_after_colon=False,
- random_seed=-1):
- """Given prompts and input parameters, run inference and return:
- tokens: prompts plus the generated tokens.
- lengths: length of the prompt + generations. Note that we can
- discard tokens in the tokens tensor that are after the
- corresponding length.
- output_log_probs: log probs of the tokens.
- """
-
-    # Make sure input params are available to all ranks.
- values = [tokens_to_generate,
- return_output_log_probs,
- top_k_sampling, top_p_sampling, top_p_decay, top_p_bound,
- temperature, add_BOS, use_eod_token_for_early_termination,
- stop_on_double_eol,
- stop_on_eol,
- prevent_newline_after_colon,
- random_seed]
- values_float_tensor = broadcast_float_list(len(values), float_list=values)
- tokens_to_generate = int(values_float_tensor[0].item())
- return_output_log_probs = bool(values_float_tensor[1].item())
- top_k_sampling = int(values_float_tensor[2].item())
- top_p_sampling = values_float_tensor[3].item()
- top_p_decay = values_float_tensor[4].item()
- top_p_bound = values_float_tensor[5].item()
- temperature = values_float_tensor[6].item()
- add_BOS = bool(values_float_tensor[7].item())
- use_eod_token_for_early_termination = bool(values_float_tensor[8].item())
- stop_on_double_eol = bool(values_float_tensor[9].item())
- stop_on_eol = bool(values_float_tensor[10].item())
- prevent_newline_after_colon = bool(values_float_tensor[11].item())
- random_seed = int(values_float_tensor[12].item())
-
- if random_seed != -1:
- torch.random.manual_seed(random_seed)
-
- # Tokenize prompts and get the batch.
-    # Note that these tensors are broadcast to all ranks.
- if torch.distributed.get_rank() == 0:
- assert prompts is not None
-
- context_tokens_tensor, context_length_tensor = tokenize_prompts(
- prompts=prompts, tokens_to_generate=tokens_to_generate, add_BOS=add_BOS)
-
- if tokens_to_generate == 0:
- return score_and_return_on_first_stage(
- model, context_tokens_tensor, context_length_tensor)
-
- # Main inference function.
- # Note that the outputs are available on the first stage.
- return generate_tokens_probs_and_return_on_first_stage(
- model, context_tokens_tensor, context_length_tensor,
- return_output_log_probs=return_output_log_probs,
- top_k=top_k_sampling,
- top_p=top_p_sampling,
- top_p_decay=top_p_decay,
- top_p_bound=top_p_bound,
- temperature=temperature,
- use_eod_token_for_early_termination=use_eod_token_for_early_termination,
- stop_on_double_eol=stop_on_double_eol,
- stop_on_eol=stop_on_eol,
- prevent_newline_after_colon=prevent_newline_after_colon)
-
-def beam_search_and_post_process(model,
- prompts=None,
- tokens_to_generate=0,
- beam_size=0,
- add_BOS=False,
- stop_token=50256,
- num_return_gen=1,
- length_penalty=1,
- prevent_newline_after_colon=False):
- """Run beam search and post-process outputs, i.e., detokenize,
- move to cpu and convert to list."""
-
- # Main inference.
- tokens, scores = beam_search(model,
- prompts=prompts,
- tokens_to_generate=tokens_to_generate,
- beam_size=beam_size,
- add_BOS=add_BOS,
- stop_token=stop_token,
- num_return_gen=num_return_gen,
- length_penalty=length_penalty,
- prevent_newline_after_colon=prevent_newline_after_colon)
- # Only post-process on first stage.
- if mpu.is_pipeline_first_stage():
- lengths = tokens.size(1)*torch.ones(beam_size, dtype=torch.int64, device=torch.cuda.current_device())
- tokens, prompts_plus_generations, prompts_plus_generations_segments = detokenize_generations(tokens, lengths, True)
- scores = scores.cpu().numpy().tolist()
- return prompts_plus_generations, prompts_plus_generations_segments, scores
-
- return None
-
-def beam_search(model, prompts=None, tokens_to_generate=0, beam_size=0, add_BOS=False, stop_token=50256, num_return_gen=1, length_penalty=1, prevent_newline_after_colon=False):
-    # Make sure input params are available to all ranks.
- values = [tokens_to_generate,
- beam_size,
- add_BOS,
- stop_token,
- num_return_gen,
- length_penalty,
- prevent_newline_after_colon]
- values_float_tensor = broadcast_float_list(len(values), float_list=values)
- tokens_to_generate = int(values_float_tensor[0].item())
- beam_size = int(values_float_tensor[1].item())
- add_BOS = bool(values_float_tensor[2].item())
- stop_token = int(values_float_tensor[3].item())
- num_return_gen = int(values_float_tensor[4].item())
- length_penalty = values_float_tensor[5].item()
- prevent_newline_after_colon = values_float_tensor[6].item()
-
- context_tokens_tensor, context_length_tensor = tokenize_prompts(
- prompts=prompts, tokens_to_generate=tokens_to_generate, add_BOS=add_BOS)
-
- return beam_search_and_return_on_first_stage(model, context_tokens_tensor, context_length_tensor,
- beam_size, stop_token=stop_token, num_return_gen=num_return_gen, length_penalty=length_penalty,
- prevent_newline_after_colon=prevent_newline_after_colon)
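A small note on the pattern used by `generate` and `beam_search` above: all scalar options are flattened into a single float tensor for the broadcast and cast back on arrival. The sketch below shows the same pack/unpack round-trip locally, with made-up option values and without the actual broadcast, just to make the casting explicit.

```python
import torch

# Scalar options of mixed types, mirroring a subset of generate()'s arguments.
values = [64,    # tokens_to_generate (int)
          True,  # return_output_log_probs (bool)
          0,     # top_k_sampling (int)
          0.9,   # top_p_sampling (float)
          1.0]   # temperature (float)

# Pack into one float tensor (this is what broadcast_float_list would send).
packed = torch.tensor(values, dtype=torch.float32)

# Unpack on the receiving side with explicit casts, as in generate() above.
tokens_to_generate = int(packed[0].item())
return_output_log_probs = bool(packed[1].item())
top_k_sampling = int(packed[2].item())
top_p_sampling = packed[3].item()
temperature = packed[4].item()

print(tokens_to_generate, return_output_log_probs, top_k_sampling,
      top_p_sampling, temperature)
```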
diff --git a/megatron/text_generation/beam_utils.py b/megatron/text_generation/beam_utils.py
deleted file mode 100644
index 911a64143a86c8521abd9741df22de528a82f692..0000000000000000000000000000000000000000
--- a/megatron/text_generation/beam_utils.py
+++ /dev/null
@@ -1,64 +0,0 @@
-# coding=utf-8
-# Copyright 2018 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team.
-# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-## from huggingface beam search
-class BeamHypotheses(object):
- def __init__(self, num_beams, length_penalty=1.0, early_stopping=False):
- """
- Initialize n-best list of hypotheses.
- """
- self.length_penalty = length_penalty
- self.early_stopping = early_stopping
- self.num_beams = num_beams
- self.beams = []
- self.worst_score = 1e9
-
- def __len__(self):
- """
- Number of hypotheses in the list.
- """
- return len(self.beams)
-
- def add(self, hyp, sum_logprobs, length):
- """
- Add a new hypothesis to the list.
- """
- score = sum_logprobs / length ** self.length_penalty
- if len(self) < self.num_beams or score > self.worst_score:
- self.beams.append((score, hyp))
- if len(self) > self.num_beams:
- sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.beams)])
- del self.beams[sorted_scores[0][1]]
- self.worst_score = sorted_scores[1][0]
- else:
- self.worst_score = min(score, self.worst_score)
-
- def is_done(self, best_sum_logprobs, cur_len):
- """
-        If there are enough hypotheses and none of the hypotheses being generated
- can become better than the worst one in the heap, then we are done with this sentence.
- """
-
- if len(self) < self.num_beams:
- return False
- elif self.early_stopping:
- return True
- else:
- cur_score = best_sum_logprobs / cur_len ** self.length_penalty
- ret = self.worst_score >= cur_score
- return ret
-
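The ranking rule in `BeamHypotheses.add` above is a length-normalised log-probability, `sum_logprobs / length ** length_penalty`. The toy numbers below (two hypothetical candidates) show how increasing the penalty shifts the preference toward longer hypotheses.

```python
def beam_score(sum_logprobs, length, length_penalty):
    # Same normalisation as BeamHypotheses.add above.
    return sum_logprobs / length ** length_penalty

short = (-5.0, 5)    # (sum of log-probs, length)
long_ = (-9.0, 10)

for penalty in (0.5, 1.0, 2.0):
    s_short = beam_score(*short, penalty)
    s_long = beam_score(*long_, penalty)
    winner = "short" if s_short > s_long else "long"
    print(f"penalty={penalty}: short={s_short:.3f} long={s_long:.3f} -> {winner}")
```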
diff --git a/megatron/text_generation/communication.py b/megatron/text_generation/communication.py
deleted file mode 100644
index dee32077f34904f7585fab0f5180a5d014f7829f..0000000000000000000000000000000000000000
--- a/megatron/text_generation/communication.py
+++ /dev/null
@@ -1,185 +0,0 @@
-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-
-"""Communications utilities."""
-
-
-import torch
-
-from megatron.core import mpu
-
-
-
-# TODO: use functions from megatron/p2p
-def recv_from_prev_pipeline_rank_(recv_buffer=None):
- """Receive from previous pipeline stage and update the
- input buffer inplace."""
- if not mpu.is_pipeline_first_stage():
- assert recv_buffer is not None
- recv_prev_op = torch.distributed.P2POp(
- torch.distributed.irecv, recv_buffer,
- mpu.get_pipeline_model_parallel_prev_rank())
- reqs = torch.distributed.batch_isend_irecv([recv_prev_op])
- for req in reqs:
- req.wait()
-        # To protect against a race condition when using batch_isend_irecv().
- torch.cuda.synchronize()
-
-
-
-# TODO: use functions from megatron/p2p
-def send_to_next_pipeline_rank(tensor=None):
- """Send output to the next pipeline stage."""
- if not mpu.is_pipeline_last_stage():
- assert tensor is not None
- send_next_op = torch.distributed.P2POp(
- torch.distributed.isend, tensor,
- mpu.get_pipeline_model_parallel_next_rank())
- reqs = torch.distributed.batch_isend_irecv([send_next_op])
- for req in reqs:
- req.wait()
-        # To protect against a race condition when using batch_isend_irecv().
- torch.cuda.synchronize()
-
-
-
-def _is_cuda(tensor):
- """Check if a tensor is not none and is cuda."""
- assert tensor is not None
- assert tensor.is_cuda
-
-
-
-def _is_cuda_contiguous(tensor):
- """Check if a tensor is not none, is cuda, and is contiguous."""
- _is_cuda(tensor)
- assert tensor.is_contiguous()
-
-
-
-def broadcast_from_last_pipeline_stage(size, dtype, tensor=None):
- """Broadcast a tensor from last pipeline stage to all ranks."""
-
- is_last_stage = mpu.is_pipeline_last_stage()
-    # If first stage and last stage are the same, then there is no
- # pipeline parallelism and no need to communicate.
- if mpu.is_pipeline_first_stage() and is_last_stage:
- return tensor
-
- if is_last_stage:
- _is_cuda_contiguous(tensor)
- else:
- tensor = torch.empty(size,
- dtype=dtype,
- device=torch.cuda.current_device())
- # Get the group and corresponding source rank.
- src = mpu.get_pipeline_model_parallel_last_rank()
- group = mpu.get_pipeline_model_parallel_group()
- torch.distributed.broadcast(tensor, src, group)
-
- return tensor
-
-
-
-def broadcast_from_last_to_first_pipeline_stage(size, dtype, tensor=None):
- """Broadcast tensor values from last stage into the first stage."""
-
- is_last_stage = mpu.is_pipeline_last_stage()
- is_first_stage = mpu.is_pipeline_first_stage()
-    # If first stage and last stage are the same, then there is no
- # pipeline parallelism and no need to communicate.
- if is_first_stage and is_last_stage:
- return tensor
- # Only first and last stage pipeline stages need to be involved.
- if is_last_stage or is_first_stage:
- if is_last_stage:
- _is_cuda_contiguous(tensor)
- else:
- tensor = torch.empty(size,
- dtype=dtype,
- device=torch.cuda.current_device())
- src = mpu.get_pipeline_model_parallel_last_rank()
- group = mpu.get_embedding_group()
- # Broadcast from last stage into the first stage.
- torch.distributed.broadcast(tensor, src, group)
- else:
- tensor = None
-
- return tensor
-
-
-
-def copy_from_last_to_first_pipeline_stage(size, dtype, tensor=None):
- """Copy tensor values from last stage into the first stage.
- Note that the input tensor is updated in place."""
-
- is_last_stage = mpu.is_pipeline_last_stage()
- is_first_stage = mpu.is_pipeline_first_stage()
-    # If first stage and last stage are the same, then there is no
- # pipeline parallelism and no need to communicate.
- if is_first_stage and is_last_stage:
- return
- # Only first and last stage pipeline stages need to be involved.
- if is_last_stage or is_first_stage:
- _is_cuda(tensor)
- is_contiguous = tensor.is_contiguous()
- src = mpu.get_pipeline_model_parallel_last_rank()
- group = mpu.get_embedding_group()
- if is_contiguous:
- tensor_ = tensor
- else:
- if is_last_stage:
- tensor_ = tensor.contiguous()
- else:
- tensor_ = torch.empty(size,
- dtype=dtype,
- device=torch.cuda.current_device())
- # Broadcast from last stage into the first stage.
- torch.distributed.broadcast(tensor_, src, group)
- # Update the first stage tensor
- if is_first_stage and not is_contiguous:
- tensor[...] = tensor_
-
-
-
-def broadcast_tensor(size, dtype, tensor=None, rank=0):
- """ Given size and type of a tensor on all ranks and the tensor value
- only on a specific rank, broadcast from that rank to all other ranks.
- """
-
- if torch.distributed.get_rank() == rank:
- _is_cuda_contiguous(tensor)
- else:
- tensor = torch.empty(size,
- dtype=dtype,
- device=torch.cuda.current_device())
-
- torch.distributed.broadcast(tensor, rank)
-
- return tensor
-
-
-
-def broadcast_list(size, dtype, list_values=None, rank=0):
- """Broadcast a list of values with a given type."""
-
- tensor = None
- if torch.distributed.get_rank() == rank:
- tensor = torch.tensor(list_values, dtype=dtype,
- device=torch.cuda.current_device())
-
- return broadcast_tensor(size, dtype, tensor=tensor, rank=rank)
-
-
-
-def broadcast_int_list(size, int_list=None, rank=0):
- """Broadcast a list of interger values."""
-
- return broadcast_list(size, torch.int64, list_values=int_list, rank=rank)
-
-
-
-def broadcast_float_list(size, float_list=None, rank=0):
- """Broadcast a list of float values."""
-
- return broadcast_list(size, torch.float32, list_values=float_list,
- rank=rank)
diff --git a/megatron/text_generation/forward_step.py b/megatron/text_generation/forward_step.py
deleted file mode 100644
index 6a88709a521b29d40cf9f7760abe0f95756900f6..0000000000000000000000000000000000000000
--- a/megatron/text_generation/forward_step.py
+++ /dev/null
@@ -1,177 +0,0 @@
-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-
-"""Forward step utilities."""
-
-from collections.abc import Iterable
-
-import torch
-
-from megatron import get_args
-from megatron.core import mpu, InferenceParams
-from .communication import (
- send_to_next_pipeline_rank,
- recv_from_prev_pipeline_rank_)
-
-
-class ForwardStep:
- """Forward step function with all the communications.
- We use a class here to hide the inference parameters
- from the outside caller."""
-
- def __init__(self, model, max_batch_size, max_sequence_length):
- """Set values so we don't need to do it multiple times."""
- # Make sure model is in eval mode.
- assert not isinstance(model, Iterable), \
- 'interleaving schedule is not supported for inference'
- model.eval()
- self.model = model
- # Initialize inference parameters.
- self.inference_params = InferenceParams(max_batch_size,
- max_sequence_length)
- # Pipelining arguments.
- args = get_args()
- self.pipeline_size_larger_than_one = (
- args.pipeline_model_parallel_size > 1)
- # Threshold of pipelining.
- self.pipelining_batch_x_seqlen = \
- args.inference_batch_times_seqlen_threshold
-
-
- def __call__(self, tokens, position_ids, attention_mask):
- """Invocation of the forward methods. Note that self.inference_params
- is being modified by the forward step."""
- # Pipelining case.
- if self.pipeline_size_larger_than_one:
- current_batch_x_seqlen = tokens.size(0) * tokens.size(1)
- if current_batch_x_seqlen >= self.pipelining_batch_x_seqlen:
- micro_batch_size = \
- max(1, self.pipelining_batch_x_seqlen // tokens.size(1))
- return _with_pipelining_forward_step(self.model,
- tokens,
- position_ids,
- attention_mask,
- self.inference_params,
- micro_batch_size)
-
- return _no_pipelining_forward_step(self.model,
- tokens,
- position_ids,
- attention_mask,
- self.inference_params)
-
-
-
-def _get_recv_buffer_dtype(args):
- """Receive happens between the layers."""
- if args.fp32_residual_connection:
- return torch.float
- return args.params_dtype
-
-
-
-def _allocate_recv_buffer(batch_size, sequence_length):
- """Receive happens between the layers with size [s, b, h]."""
- if mpu.is_pipeline_first_stage():
- return None
- args = get_args()
- recv_size = (sequence_length, batch_size, args.hidden_size)
- return torch.empty(recv_size,
- dtype=_get_recv_buffer_dtype(args),
- device=torch.cuda.current_device())
-
-
-
-def _forward_step_helper(model, tokens, position_ids, attention_mask,
- inference_params, recv_buffer=None):
- """Single forward step. Update the allocate memory flag so
- only the first time the memory is allocated."""
- batch_size = tokens.size(0)
- sequence_length = tokens.size(1)
- if recv_buffer is None:
- recv_buffer = _allocate_recv_buffer(batch_size, sequence_length)
-
- # Receive from previous stage.
- recv_from_prev_pipeline_rank_(recv_buffer)
-
- # Forward pass through the model.
- model.set_input_tensor(recv_buffer)
- output_tensor = model(tokens, position_ids, attention_mask,
- inference_params=inference_params)
-
- # Send output to the next stage.
- send_to_next_pipeline_rank(output_tensor)
-
- return output_tensor
-
-
-
-def _no_pipelining_forward_step(model, tokens, position_ids, attention_mask,
- inference_params, recv_buffer=None):
- """If recv_buffer is none, we will allocate one on the fly."""
- # Run a simple forward pass.
- output_tensor = _forward_step_helper(model, tokens, position_ids,
- attention_mask, inference_params,
- recv_buffer=recv_buffer)
- # Update the sequence length offset.
- inference_params.sequence_len_offset += tokens.size(1)
-
- logits = None
- if mpu.is_pipeline_last_stage():
- logits = output_tensor
-
- return logits
-
-
-
-def _with_pipelining_forward_step(model, tokens, position_ids, attention_mask,
- inference_params, micro_batch_size):
- """No interleaving is supported."""
- sequence_length = tokens.size(1)
- batch_size = tokens.size(0)
-
- # Divide the batch dimension into micro batches.
- num_micro_batches, last_chunk = divmod(batch_size,
- micro_batch_size)
- if last_chunk > 0:
- num_micro_batches += 1
-
- # Preallocate memory for output logits.
- logits = None
- if mpu.is_pipeline_last_stage():
- args = get_args()
- logits = torch.empty(
- (batch_size, sequence_length, args.padded_vocab_size),
- dtype=torch.float32, device=torch.cuda.current_device())
-
- # Preallocate recv buffer.
- recv_buffer = _allocate_recv_buffer(micro_batch_size, sequence_length)
-
- for micro_batch_index in range(num_micro_batches):
-        # Slice along the batch dimension.
- start = micro_batch_index * micro_batch_size
- end = min(start + micro_batch_size, batch_size)
- this_micro_batch_size = end - start
- tokens2use = tokens[start:end, ...]
- position_ids2use = position_ids[start:end, ...]
-
- # Run a simple forward pass.
- if this_micro_batch_size != micro_batch_size:
- recv_buffer = None
- output = _forward_step_helper(model, tokens2use, position_ids2use,
- attention_mask, inference_params,
- recv_buffer=recv_buffer)
-
- # Adjust the batch size offset to account for the micro-batch.
- inference_params.batch_size_offset += this_micro_batch_size
-
- # Copy logits.
- if mpu.is_pipeline_last_stage():
- logits[start:end, ...] = output
-
- # Once we are done with all the micro-batches, we can
- # adjust the sequence length offset.
- inference_params.sequence_len_offset += sequence_length
- # and reset the batch size offset
- inference_params.batch_size_offset = 0
-
- return logits
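
The micro-batch split used by `_with_pipelining_forward_step` above is plain integer arithmetic over the batch dimension. A minimal standalone sketch of that bookkeeping (the batch and micro-batch sizes below are made-up examples, not values from the module):

```python
# Sketch of the micro-batch slicing used for pipelined inference above.
def micro_batch_slices(batch_size, micro_batch_size):
    """Yield (start, end) pairs covering the batch dimension."""
    num_micro_batches, last_chunk = divmod(batch_size, micro_batch_size)
    if last_chunk > 0:
        num_micro_batches += 1  # one smaller trailing micro-batch
    for i in range(num_micro_batches):
        start = i * micro_batch_size
        end = min(start + micro_batch_size, batch_size)
        yield start, end

# A batch of 10 with micro-batches of 4 -> (0, 4), (4, 8), (8, 10).
print(list(micro_batch_slices(10, 4)))
```
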
diff --git a/megatron/text_generation/generation.py b/megatron/text_generation/generation.py
deleted file mode 100644
index 11dd9f436b5992739f4d0543b2108317edfd46b1..0000000000000000000000000000000000000000
--- a/megatron/text_generation/generation.py
+++ /dev/null
@@ -1,428 +0,0 @@
-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-
-"""Generation utilities."""
-
-import torch
-import torch.nn.functional as F
-
-from megatron import get_args, get_tokenizer
-from megatron.core import mpu
-from megatron.utils import get_ltor_masks_and_position_ids
-from .communication import (
- copy_from_last_to_first_pipeline_stage,
- broadcast_from_last_pipeline_stage,
- broadcast_from_last_to_first_pipeline_stage)
-from .forward_step import ForwardStep
-from .sampling import sample
-from .beam_utils import BeamHypotheses
-
-def score_and_return_on_first_stage(model, tokens, lengths):
- """Function for just scoring.
- Arguments:
- model: no interleaving is supported.
- tokens: prompt tokens extended to be of size [b, max_prompt_length]
- lengths: original prompt length, size: [b]
- Note: Outside of model, other parameters only need to be available on
- rank 0.
- Outputs:
- output_log_probs: log probability of the selected tokens. size: [b, s]
- """
-
- args = get_args()
-
- batch_size = tokens.size(0)
- max_prompt_length = lengths.max().item()
- assert max_prompt_length == tokens.size(1)
-
- if max_prompt_length > args.max_position_embeddings:
- raise ValueError("Length of prompt + tokens_to_generate longer than allowed")
-
- if max_prompt_length * batch_size > args.max_tokens_to_oom:
- raise ValueError("Too many tokens. " + str(max_prompt_length*batch_size)+ " is greater than "+str(args.max_tokens_to_oom))
-
- # forward step.
- forward_step = ForwardStep(model, batch_size, max_prompt_length)
-
- # ===================
- # Pre-allocate memory
- # ===================
-
- # Log probability of the sequence (prompt + generated tokens).
- output_log_probs = None
- output_log_probs_size = (batch_size, max_prompt_length - 1)
-
- if mpu.is_pipeline_last_stage():
- output_log_probs = torch.empty(output_log_probs_size,
- dtype=torch.float32,
- device=torch.cuda.current_device())
-
- # =============
-    # Run inference
- # =============
- with torch.no_grad():
- attention_mask, position_ids = _build_attention_mask_and_position_ids(tokens)
-
-        # logits will be meaningful only in the last pipeline stage.
- logits = forward_step(tokens, position_ids, attention_mask)
-
- if mpu.is_pipeline_last_stage():
-            # The last stage should always have an output.
- assert logits is not None
- log_probs = F.log_softmax(logits, dim=2)
-
- # Pick the tokens that we need to get the log
- # probabilities for. Note that next input token is
- # the token which we selected in the current logits,
- # so shift by 1.
- indices = torch.unsqueeze(tokens[:, 1:], 2)
- output_log_probs = torch.gather(log_probs, 2, indices).squeeze(2)
-
- # ======================================
- # Broadcast to the first pipeline stage.
- # ======================================
- output_log_probs = broadcast_from_last_to_first_pipeline_stage(
- output_log_probs_size, torch.float32, output_log_probs)
-
- return tokens, lengths, output_log_probs, logits
-
-def generate_tokens_probs_and_return_on_first_stage(
- model, tokens, lengths,
- return_output_log_probs=False,
- top_k=0, top_p=0.0, top_p_decay=0.0, top_p_bound=0.0,
- temperature=1.0,
- use_eod_token_for_early_termination=True,
- stop_on_double_eol=False,
- stop_on_eol=False,
- prevent_newline_after_colon=True
- ):
- """Main token generation function.
- Arguments:
- model: no interleaving is supported.
- tokens: prompt tokens extended to be of size [b, max-sequence-length]
- lengths: original prompt length, size: [b]
- return_output_log_probs: flag to calculate the log probability of
- the generated tokens. Note that the log probability is the one
- from the original logit.
- top_k, top_p: top-k and top-p sampling parameters.
-            Note that top-k = 1 is greedy. Also, these parameters are
-            mutually exclusive, meaning that:
- if top-k > 0 then we expect top-p=0.
- if top-p > 0 then we check for top-k=0.
- temperature: sampling temperature.
- use_eod_token_for_early_termination: if True, do early termination if
- all the sequences have reached this token.
-        prevent_newline_after_colon: if True, disable generating a new line \n after :
- Note: Outside of model, other parameters only need to be available on
- rank 0.
-    Outputs: Note that the size is adjusted to a lower value than
- max-sequence-length if generation is terminated early.
- tokens: prompt and generated tokens. size: [b, :]
- generated_sequence_lengths: total length (including prompt) of
- the generated sequence. size: [b]
- output_log_probs: log probability of the selected tokens. size: [b, s]
- """
-
- args = get_args()
- tokenizer = get_tokenizer()
-
- batch_size = tokens.size(0)
- min_prompt_length = lengths.min().item()
- max_sequence_length = tokens.size(1)
-
- if max_sequence_length > args.max_position_embeddings:
- raise ValueError("Length of prompt + tokens_to_generate longer than allowed")
-
- if max_sequence_length * batch_size > args.max_tokens_to_oom:
- raise ValueError("Too many tokens. " + str(max_sequence_length*batch_size)+ " is greater than "+str(args.max_tokens_to_oom))
-
- # forward step.
- forward_step = ForwardStep(model, batch_size, max_sequence_length)
-
- # Added termination_id to support the case that we want to terminate the
- # generation once that id is generated.
- if hasattr(args, 'eos_id'):
- termination_id = args.eos_id
- else:
- termination_id = tokenizer.eod
-
- # ===================
- # Pre-allocate memory
- # ===================
-
- # Log probability of the sequence (prompt + generated tokens).
- output_log_probs = None
- output_log_probs_size = (batch_size, max_sequence_length - 1)
-    # Lengths of the generated sequences, including prompts.
- generated_sequence_lengths = None
- if mpu.is_pipeline_last_stage():
- if return_output_log_probs:
- output_log_probs = torch.empty(output_log_probs_size,
- dtype=torch.float32,
- device=torch.cuda.current_device())
- generated_sequence_lengths = torch.ones(
- batch_size, dtype=torch.int64,
- device=torch.cuda.current_device()) * max_sequence_length
-
- # Whether we have reached a termination id.
- is_generation_done = torch.zeros(batch_size, dtype=torch.uint8,
- device=torch.cuda.current_device())
-
- # =============
-    # Run inference
- # =============
-
- with torch.no_grad():
- attention_mask, position_ids = _build_attention_mask_and_position_ids(
- tokens)
- prev_context_length = 0
- for context_length in range(min_prompt_length, max_sequence_length):
-
- # Pick the slice that we need to pass through the network.
- tokens2use = tokens[:, prev_context_length:context_length]
- positions2use = position_ids[:, prev_context_length:context_length]
- attention_mask2use = attention_mask[
- ..., prev_context_length:context_length, :context_length]
-
-            # logits will be meaningful only in the last pipeline stage.
- logits = forward_step(tokens2use, positions2use, attention_mask2use)
-
- if mpu.is_pipeline_last_stage():
- if prevent_newline_after_colon:
- logits[tokens2use[:, -1] == tokenizer.tokenize(':')[0], -1, tokenizer.tokenize('\n')[0]] = -1e10 # disable "\n" after ":"
-                # The last stage should always have an output.
- assert logits is not None
-
- # Sample.
- last_token_logits = logits[:, -1, :]
- new_sample = sample(last_token_logits,
- top_k=top_k,
- top_p=top_p,
- temperature=temperature,
- vocab_size=tokenizer.vocab_size)
- if top_p > 0.0 and top_p_decay > 0.0:
- top_p = top_p * top_p_decay
- if top_p_bound > 0.0:
- top_p = max(top_p, top_p_bound)
-
-                # If a prompt length is smaller than or equal to the current
-                # context length, it means we have started generating tokens.
- started = lengths <= context_length
- # Update the tokens.
- tokens[started, context_length] = new_sample[started]
-
- # Calculate the log probabilities.
- if return_output_log_probs:
- log_probs = F.log_softmax(logits, dim=2)
- if return_output_log_probs:
- # Pick the tokens that we need to get the log
- # probabilities for. Note that next input token is
- # the token which we selected in the current logits,
- # so shift by 1.
- indices = torch.unsqueeze(
- tokens[
- :,
- (prev_context_length + 1):(context_length + 1)],
- 2)
- output_log_probs[:,
- prev_context_length:context_length] = \
- torch.gather(log_probs, 2, indices).squeeze(2)
-
- # Update the tokens on the first stage so the next input to
- # the network is correct.
- copy_from_last_to_first_pipeline_stage(batch_size, torch.int64,
- tokens[:, context_length])
-
- # Update the context length for the next token generation.
- prev_context_length = context_length
-
- # Check if all the sequences have hit the termination_id.
- done = None
- if mpu.is_pipeline_last_stage():
-                # TODO(rprenger) These stopping methods are tokenizer dependent;
-                # instead, tokenization should be in the inference loop so stop sequences can be used.
- if stop_on_double_eol:
- hit_double_eol = (new_sample == 628).byte() & started.byte()
- hit_two_eols = (new_sample == 198).byte() & (tokens[:, context_length-1] == 198).byte() & started.byte()
- done_token = hit_double_eol | hit_two_eols
- elif stop_on_eol:
- hit_double_eol = (new_sample == 628).byte() & started.byte()
- hit_eol = (new_sample == 198).byte() & started.byte()
- done_token = hit_double_eol | hit_eol
- else:
- done_token = (new_sample == termination_id).byte() & \
- started.byte()
-
- just_finished = (done_token & ~is_generation_done).bool()
- generated_sequence_lengths[just_finished.view(-1)] = \
- context_length + 1
- is_generation_done = is_generation_done | done_token
- done = torch.all(is_generation_done)
- done = broadcast_from_last_pipeline_stage(1, torch.uint8,
- tensor=done)
- if use_eod_token_for_early_termination and done:
- break
-
- # ===================================================
-    # Update the length based on the max generated length.
- # ===================================================
-
- tokens = tokens[:, :(context_length + 1)]
- if mpu.is_pipeline_last_stage():
- if return_output_log_probs:
- output_log_probs = output_log_probs[:, :context_length]
-
- # ======================================
- # Broadcast to the first pipeline stage.
- # ======================================
-
- generated_sequence_lengths = broadcast_from_last_to_first_pipeline_stage(
- batch_size, torch.int64, generated_sequence_lengths)
- if return_output_log_probs:
- output_log_probs_size = (batch_size, context_length)
- output_log_probs = broadcast_from_last_to_first_pipeline_stage(
- output_log_probs_size, torch.float32, output_log_probs)
-
- return tokens, generated_sequence_lengths, output_log_probs, None
-
-def beam_search_and_return_on_first_stage(model, tokens, lengths, beam_size, stop_token, num_return_gen, length_penalty, prevent_newline_after_colon=True):
- args = get_args()
- tokenizer = get_tokenizer()
-
- batch_size = tokens.size(0)
- assert(batch_size == 1)
- prompt_length = lengths.item()
- final_sequence_length = tokens.size(1)
- final_sequence_length = min(final_sequence_length, args.max_position_embeddings)
-
-    # If the context is too long, there is no room left to generate tokens.
- if prompt_length >= final_sequence_length:
- raise ValueError("context length + tokens_to_generate too large")
-
- # forward step.
- forward_step = ForwardStep(model, beam_size, final_sequence_length)
-
- beam_hyp = BeamHypotheses(beam_size, length_penalty)
- best_batches = None
- done = torch.zeros(1, dtype=torch.uint8, device=torch.cuda.current_device())
- scores = torch.zeros(beam_size,
- dtype=torch.float32,
- device=torch.cuda.current_device()).unsqueeze(1)
- scores_size_tensor, tokens_size_tensor = None, None
- # =============
-    # Run inference
- # =============
- with torch.no_grad():
- tokens = tokens.repeat(beam_size, 1)
- attention_mask, position_ids = _build_attention_mask_and_position_ids(tokens)
- prev_context_length = 0
- for context_length in range(prompt_length, final_sequence_length):
-
- # Pick the slice that we need to pass through the network.
- tokens2use = tokens[:, prev_context_length:context_length]
- positions2use = position_ids[:, prev_context_length:context_length]
- attention_mask2use = attention_mask[
- ..., prev_context_length:context_length, :context_length]
-
-            # logits will be meaningful only in the last pipeline stage.
- logits = forward_step(tokens2use, positions2use, attention_mask2use)
-
- if mpu.is_pipeline_last_stage():
- if prevent_newline_after_colon:
- logits[tokens2use[:, -1] == tokenizer.tokenize(':')[0], -1, tokenizer.tokenize('\n')[0]] = -1e10 # disable "\n" after ":"
- vocab_size = logits.size(2)
- log_probs = F.log_softmax(logits, dim=2)
- new_scores = log_probs[:, -1, :] + scores
-
- if context_length == prompt_length: # if this is the first one
- sorted_scores, indices = torch.sort(new_scores[0,:], descending=True)
- else:
- sorted_scores, indices = torch.sort(new_scores.view(-1), descending=True)
-
- best_beam_ids = torch.div(indices[: 2 * beam_size], vocab_size).trunc().long()
- best_words = indices[:2 * beam_size] % vocab_size
- best_scores = sorted_scores[: 2 * beam_size]
-
- next_beams = []
- for beam_token_rank, (token_id, beam_score, beam_id) in enumerate(
- zip(best_words, best_scores, best_beam_ids)
- ):
- if token_id.item() == stop_token:
- # if beam_token does not belong to top num_beams tokens, it should not be added
- is_beam_token_worse_than_top_num_beams = beam_token_rank >= beam_size
- if is_beam_token_worse_than_top_num_beams:
- continue
- beam_hyp.add(
- tokens[beam_id].clone(),
- beam_score,
- context_length + 1 - prompt_length
- )
- else:
- # add next predicted token since it is not eos_token
- next_beams.append((token_id, beam_score, beam_id))
-
- if len(next_beams) == beam_size:
- break
-
- if beam_hyp.is_done(best_scores.max().item(), context_length + 1 - prompt_length):
- done = torch.ones(1, dtype=torch.uint8, device=torch.cuda.current_device())
-
- best_batches = tokens.new([item[2] for item in next_beams])
- tokens = tokens[best_batches,:]
- tokens[:, context_length] = tokens.new([item[0] for item in next_beams])
- scores = scores.new([item[1] for item in next_beams]).unsqueeze(1)
-
- # torch.distributed.barrier()
- done = broadcast_from_last_pipeline_stage(1, torch.uint8, done)
- if done:
- break
-
- # Update the tokens on the first stage so the next input to
- # the network is correct.
- copy_from_last_to_first_pipeline_stage(tokens.size(), torch.int64,
- tokens)
-
- # set inference key values to make it consistent with best beam index
- best_batches = broadcast_from_last_pipeline_stage(beam_size, torch.int64, best_batches)
- forward_step.inference_params.swap_key_value_dict(best_batches)
-
- # Update the context length for the next token generation.
- prev_context_length = context_length
-
- if mpu.is_pipeline_last_stage():
-        # if the stop token was never generated, add the open beams to the hypotheses
- if not done:
- for beam_id in range(beam_size):
- beam_hyp.add(tokens[beam_id].clone(), scores[beam_id].squeeze(), context_length + 1 - prompt_length)
-
- # rank based on scores
- sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0], reverse=True)
- num_return_gen = min(num_return_gen, len(sorted_hyps))
- scores = [sorted_hyps[i][0] for i in range(num_return_gen)]
- tokens = [sorted_hyps[i][1] for i in range(num_return_gen)]
- scores = torch.stack(scores, dim=0)
- tokens = torch.stack(tokens, dim=0)
- scores_size_tensor = torch.tensor(scores.shape, dtype=torch.int64, device=torch.cuda.current_device())
- tokens_size_tensor = torch.tensor(tokens.shape, dtype=torch.int64, device=torch.cuda.current_device())
-
- scores_size_tensor = broadcast_from_last_pipeline_stage(1, torch.int64, scores_size_tensor)
- tokens_size_tensor = broadcast_from_last_pipeline_stage(2, torch.int64, tokens_size_tensor)
-
- scores = broadcast_from_last_to_first_pipeline_stage(tuple(scores_size_tensor), torch.float32, scores)
- tokens = broadcast_from_last_to_first_pipeline_stage(tuple(tokens_size_tensor), torch.int64, tokens)
-
- return tokens, scores
-
-
-def _build_attention_mask_and_position_ids(tokens):
- """Build the attention mask and postition ids for the input tokens."""
-
-    # Since we are not interested in the loss mask, and reset attention/position
-    # is also False, eod_token is not used, so it is safe to set it to None.
- attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
- data=tokens,
- eod_token=None,
- reset_position_ids=False,
- reset_attention_mask=False,
- eod_mask_loss=False)
-
- return attention_mask, position_ids
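
The core of `generate_tokens_probs_and_return_on_first_stage` above is the sliding-window loop: on each step only the tokens produced since the previous step are fed to the model, relying on cached key/values for the earlier context. A minimal greedy sketch of that loop, with a dummy forward function standing in for the Megatron model (all sizes and the random "model" are illustrative assumptions):

```python
import torch

torch.manual_seed(0)
vocab_size, batch_size, prompt_len, max_len = 50, 2, 3, 8
tokens = torch.randint(vocab_size, (batch_size, max_len))

def dummy_forward(token_slice):
    # Stand-in for the cached Megatron forward step: returns fake logits
    # of shape [b, s, v] for just the new slice of tokens.
    return torch.randn(token_slice.size(0), token_slice.size(1), vocab_size)

prev_context_length = 0
for context_length in range(prompt_len, max_len):
    tokens2use = tokens[:, prev_context_length:context_length]  # new slice only
    logits = dummy_forward(tokens2use)
    new_sample = logits[:, -1, :].argmax(dim=-1)                # greedy pick
    tokens[:, context_length] = new_sample                      # append in place
    prev_context_length = context_length                        # next step sees 1 token
print(tokens)
```
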
diff --git a/megatron/text_generation/sampling.py b/megatron/text_generation/sampling.py
deleted file mode 100644
index 370773a36c087d01e75731e38724cfb35d4acd74..0000000000000000000000000000000000000000
--- a/megatron/text_generation/sampling.py
+++ /dev/null
@@ -1,93 +0,0 @@
-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-
-"""Sampling utilities.
-Part of this code is inspired by:
- - https://github.com/ari-holtzman/degen/blob/master/gen.py
- - https://huggingface.co/transformers/_modules/transformers/generation_logits_process.html
-"""
-
-
-import torch
-
-
-
-def modify_logits_for_top_k_filtering(logits, top_k):
- """Set the logits for none top-k values to -inf."""
-
- filter_ = logits < torch.topk(logits, top_k)[0][..., -1, None]
- logits.masked_fill_(filter_, float('-Inf'))
-
-
-
-def modify_logits_for_top_p_filtering(logits, top_p):
- """Set the logits for none top-p values to -inf."""
-
- # First sort and calculate cumulative sum of probabilities.
- sorted_logits, sorted_indices = torch.sort(logits, descending=True)
- cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
-
-    # Filter based on the cumulative sum.
- filter_ = cumulative_probs > top_p
- # This shift by 1 is weird and I cannot justify it. This existed
- # in the original implementation:
- # https://github.com/ari-holtzman/degen/blob/master/gen.py
- # and I guess it is needed so keeping it for now.
- filter_[:, 1:] = filter_[:, :-1].clone()
- # Make sure we at least have one token to select from.
- filter_[..., 0] = 0
-
- # Fill in the filtered part
- filter_ = filter_.scatter(1, sorted_indices, filter_)
- logits.masked_fill_(filter_, float('-Inf'))
-
-
-
-def sample(logits, top_k=0, top_p=0.0, temperature=1.0, vocab_size=None):
- """ Sample and generate a token.
- Note: logits has the dimension [b, v] where b is the batch size
- and v is the vocabulary size.
- If vocab_size is provided, we will make sure the sample that is
- generated is in [0, vocab-size). This will avoid out of vocabulary
- generations due to padding.
- """
-
- # Check logits for consistency.
- assert logits.ndim == 2, 'expected the logits to be of [b, v] shape.'
- assert logits.type() == 'torch.cuda.FloatTensor', \
- 'input logits should be floats.'
-
-
- # Greedy is just simple argmax.
- if top_k == 1:
- assert top_p == 0.0, 'cannot set both greedy and top-p samplings.'
- samples = torch.argmax(logits, dim=-1)
-
- # Top-k or top-p sampling.
- else:
-        # Clone so we do not modify the inputs.
- logits = logits.clone()
- # Apply temperature in place.
- if temperature != 1.0:
- logits.div_(temperature)
-
- if top_k > 1:
- assert top_p == 0.0, 'cannot set both top-k and top-p samplings.'
- assert top_k <= logits.size(1), 'top-k is larger than logit size.'
- if vocab_size:
- assert top_k < vocab_size, 'top-k is larger than vocab size.'
- modify_logits_for_top_k_filtering(logits, top_k)
-
- elif top_p > 0.0:
- assert top_p <= 1.0, 'top-p should be in (0, 1].'
- modify_logits_for_top_p_filtering(logits, top_p)
-
- # After filtering, we need to recalculate the distribution.
- probs = logits.softmax(dim=-1)
- samples = torch.multinomial(probs, num_samples=1).view(-1)
-
-    # If vocab size is provided, make sure the samples are
- # in the range [0, vocab-size).
- if vocab_size:
- samples = torch.clamp(samples, min=0, max=(vocab_size - 1))
-
- return samples
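
The same top-k / top-p filtering can be exercised on its own; the sketch below mirrors the two filtering helpers and `sample` above, but runs on CPU tensors (the original asserts CUDA float tensors) and drops the vocab-size clamp:

```python
import torch

def sample_sketch(logits, top_k=0, top_p=0.0, temperature=1.0):
    logits = logits.clone().float()
    if temperature != 1.0:
        logits /= temperature
    if top_k > 0:
        # Keep only the k largest logits (top_k == 1 degenerates to greedy).
        kth = torch.topk(logits, top_k)[0][..., -1, None]
        logits.masked_fill_(logits < kth, float('-inf'))
    elif top_p > 0.0:
        # Keep the smallest set of tokens whose cumulative probability
        # exceeds top_p, with the shift-by-one that retains the token
        # crossing the threshold.
        sorted_logits, sorted_idx = torch.sort(logits, descending=True)
        cum_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
        mask = cum_probs > top_p
        mask[:, 1:] = mask[:, :-1].clone()
        mask[..., 0] = False
        mask = mask.scatter(1, sorted_idx, mask)
        logits.masked_fill_(mask, float('-inf'))
    probs = logits.softmax(dim=-1)
    return torch.multinomial(probs, num_samples=1).view(-1)

print(sample_sketch(torch.randn(2, 10), top_k=3))
```
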
diff --git a/megatron/text_generation/tokenization.py b/megatron/text_generation/tokenization.py
deleted file mode 100644
index 4d4eb82e8049e5bdcd5aa84272f21941bef79e5a..0000000000000000000000000000000000000000
--- a/megatron/text_generation/tokenization.py
+++ /dev/null
@@ -1,125 +0,0 @@
-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-
-"""Tokenization utilities."""
-
-
-import torch
-
-
-from megatron import get_tokenizer, get_args
-from .communication import broadcast_int_list, broadcast_tensor
-
-
-def detokenize_generations(tokens_gpu_tensor,
- lengths_gpu_tensor,
- return_segments):
- """Detokenize the generated tokens."""
-
- tokenizer = get_tokenizer()
- args = get_args()
- prompts_plus_generations = []
- if return_segments:
- prompts_plus_generations_segments = []
-
- tokens = tokens_gpu_tensor.cpu().numpy().tolist()
- lengths = lengths_gpu_tensor.cpu().numpy().tolist()
- for sequence_tokens, length in zip(tokens, lengths):
- sequence_tokens = sequence_tokens[:length]
- prompts_plus_generations.append(
- tokenizer.detokenize(sequence_tokens))
- if return_segments:
- words = []
- for token in sequence_tokens:
- if args.tokenizer_type in ['SentencePieceTokenizer',
- 'GPTSentencePieceTokenizer',
- 'Llama2Tokenizer']:
- word = tokenizer.decoder[token]
- elif args.tokenizer_type == 'NullTokenizer':
- word = str(token)
- else:
- word = tokenizer.tokenizer.decoder[token]
- word = bytearray(
- [tokenizer.tokenizer.byte_decoder[c] for c in word]).decode(
- 'utf-8', errors='replace')
- words.append(word)
- prompts_plus_generations_segments.append(words)
-
- if return_segments:
- return tokens, prompts_plus_generations, \
- prompts_plus_generations_segments
-
- return tokens, prompts_plus_generations
-
-
-def tokenize_prompts(prompts=None, tokens_to_generate=None,
- add_BOS=None, rank=0):
- """Tokenize prompts and make them avaiable on all ranks."""
-
- # On all ranks set to None so we can pass them to functions
- sizes_list = None
- prompts_tokens_cuda_long_tensor = None
- prompts_length_cuda_long_tensor = None
-
- # On the specified rank, build the above.
- if torch.distributed.get_rank() == rank:
- assert prompts is not None
- assert tokens_to_generate is not None
- # Tensor of tokens padded and their unpadded length.
- prompts_tokens_cuda_long_tensor, prompts_length_cuda_long_tensor = \
- _tokenize_prompts_and_batch(prompts, tokens_to_generate, add_BOS)
-        # We need the sizes of these tensors for the broadcast
- sizes_list = [prompts_tokens_cuda_long_tensor.size(0), # Batch size
-                      prompts_tokens_cuda_long_tensor.size(1)] # Sequence length
-
- # First, broadcast the sizes.
- sizes_tensor = broadcast_int_list(2, int_list=sizes_list, rank=rank)
-
-    # Now that we have the sizes, we can broadcast the tokens
- # and length tensors.
- sizes = sizes_tensor.tolist()
- prompts_tokens_cuda_long_tensor = broadcast_tensor(
- sizes, torch.int64, tensor=prompts_tokens_cuda_long_tensor, rank=rank)
- prompts_length_cuda_long_tensor = broadcast_tensor(
- sizes[0], torch.int64, tensor=prompts_length_cuda_long_tensor,
- rank=rank)
-
- return prompts_tokens_cuda_long_tensor, prompts_length_cuda_long_tensor
-
-
-def _tokenize_prompts_and_batch(prompts, tokens_to_generate, add_BOS):
- """Given a set of prompts and number of tokens to generate:
- - tokenize prompts
- - set the sequence length to be the max of length of prompts
- plus the number of tokens we would like to generate
- - pad all the sequences to this length so we can convert them
- into a 2D tensor.
- """
-
- # Tokenize all the prompts.
- tokenizer = get_tokenizer()
- if add_BOS:
- prompts_tokens = [[tokenizer.eod] + tokenizer.tokenize(prompt)
- for prompt in prompts]
- else:
- prompts_tokens = [tokenizer.tokenize(prompt) for prompt in prompts]
-
- # Now we have a list of list of tokens which each list has a different
- # size. We want to extend this list to:
- # - incorporate the tokens that need to be generated
- # - make all the sequences equal length.
- # Get the prompts length.
- prompts_length = [len(prompt_tokens) for prompt_tokens in prompts_tokens]
- # Get the max prompts length.
- max_prompt_len = max(prompts_length)
-    # Number of tokens in each sample of the batch.
- samples_length = max_prompt_len + tokens_to_generate
- # Now update the list of list to be of the same size: samples_length.
- for prompt_tokens, prompt_length in zip(prompts_tokens, prompts_length):
- padding_size = samples_length - prompt_length
- prompt_tokens.extend([tokenizer.eod] * padding_size)
-
- # Now we are in a structured format, we can convert to tensors.
- prompts_tokens_tensor = torch.cuda.LongTensor(prompts_tokens)
- prompts_length_tensor = torch.cuda.LongTensor(prompts_length)
-
- return prompts_tokens_tensor, prompts_length_tensor
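
The padding scheme above (pad every prompt to `max_prompt_len + tokens_to_generate` with the EOD id, while keeping the unpadded lengths separately) can be illustrated with a toy whitespace tokenizer; the vocabulary and the pad id below are made up for the example:

```python
vocab = {}
def toy_tokenize(text):
    return [vocab.setdefault(word, len(vocab)) for word in text.split()]

EOD = -1  # stands in for tokenizer.eod
prompts = ["hello world", "a longer prompt here"]
tokens_to_generate = 3

prompts_tokens = [toy_tokenize(p) for p in prompts]
prompts_length = [len(t) for t in prompts_tokens]
samples_length = max(prompts_length) + tokens_to_generate
for toks, length in zip(prompts_tokens, prompts_length):
    toks.extend([EOD] * (samples_length - length))

print(prompts_tokens)   # every row padded to the same length
print(prompts_length)   # the original, unpadded lengths
```
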
diff --git a/megatron/text_generation_server.py b/megatron/text_generation_server.py
deleted file mode 100644
index 8bd6c26fcc5bc0441bc50680d96e550030f0e964..0000000000000000000000000000000000000000
--- a/megatron/text_generation_server.py
+++ /dev/null
@@ -1,241 +0,0 @@
-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-import datetime
-import torch
-import json
-import threading
-from flask import Flask, request, jsonify, current_app
-from flask_restful import Resource, Api
-from megatron import get_args
-from megatron.text_generation import generate_and_post_process
-from megatron.text_generation import beam_search_and_post_process
-
-
-GENERATE_NUM = 0
-BEAM_NUM = 1
-lock = threading.Lock()
-
-class MegatronGenerate(Resource):
- def __init__(self, model):
- self.model = model
-
- @staticmethod
- def send_do_generate():
- choice = torch.cuda.LongTensor([GENERATE_NUM])
- torch.distributed.broadcast(choice, 0)
-
- @staticmethod
- def send_do_beam_search():
- choice = torch.cuda.LongTensor([BEAM_NUM])
- torch.distributed.broadcast(choice, 0)
-
- def put(self):
- args = get_args()
-
- if not "prompts" in request.get_json():
- return "prompts argument required", 400
-
- if "max_len" in request.get_json():
- return "max_len is no longer used. Replace with tokens_to_generate", 400
-
- if "sentences" in request.get_json():
- return "sentences is no longer used. Replace with prompts", 400
-
- prompts = request.get_json()["prompts"]
- if not isinstance(prompts, list):
- return "prompts is not a list of strings", 400
-
- if len(prompts) == 0:
- return "prompts is empty", 400
-
- if len(prompts) > 128:
- return "Maximum number of prompts is 128", 400
-
- tokens_to_generate = 64 # Choosing hopefully sane default. Full sequence is slow
- if "tokens_to_generate" in request.get_json():
- tokens_to_generate = request.get_json()["tokens_to_generate"]
- if not isinstance(tokens_to_generate, int):
- return "tokens_to_generate must be an integer greater than 0"
- if tokens_to_generate < 0:
- return "tokens_to_generate must be an integer greater than or equal to 0"
-
- logprobs = False
- if "logprobs" in request.get_json():
- logprobs = request.get_json()["logprobs"]
- if not isinstance(logprobs, bool):
- return "logprobs must be a boolean value"
-
- if tokens_to_generate == 0 and not logprobs:
- return "tokens_to_generate=0 implies logprobs should be True"
-
- temperature = 1.0
- if "temperature" in request.get_json():
- temperature = request.get_json()["temperature"]
- if not (type(temperature) == int or type(temperature) == float):
- return "temperature must be a positive number less than or equal to 100.0"
- if not (0.0 < temperature <= 100.0):
- return "temperature must be a positive number less than or equal to 100.0"
-
- top_k = 0.0
- if "top_k" in request.get_json():
- top_k = request.get_json()["top_k"]
- if not (type(top_k) == int):
- return "top_k must be an integer equal to or greater than 0 and less than or equal to 1000"
- if not (0 <= top_k <= 1000):
- return "top_k must be equal to or greater than 0 and less than or equal to 1000"
-
- top_p = 0.0
- if "top_p" in request.get_json():
- top_p = request.get_json()["top_p"]
- if not (type(top_p) == float):
- return "top_p must be a positive float less than or equal to 1.0"
- if top_p > 0.0 and top_k > 0.0:
- return "cannot set both top-k and top-p samplings."
- if not (0 <= top_p <= 1.0):
- return "top_p must be less than or equal to 1.0"
-
- top_p_decay = 0.0
- if "top_p_decay" in request.get_json():
- top_p_decay = request.get_json()["top_p_decay"]
- if not (type(top_p_decay) == float):
- return "top_p_decay must be a positive float less than or equal to 1.0"
- if top_p == 0.0:
- return "top_p_decay cannot be set without top_p"
- if not (0 <= top_p_decay <= 1.0):
- return "top_p_decay must be less than or equal to 1.0"
-
- top_p_bound = 0.0
- if "top_p_bound" in request.get_json():
- top_p_bound = request.get_json()["top_p_bound"]
- if not (type(top_p_bound) == float):
- return "top_p_bound must be a positive float less than or equal to top_p"
- if top_p == 0.0:
- return "top_p_bound cannot be set without top_p"
- if not (0.0 < top_p_bound <= top_p):
- return "top_p_bound must be greater than 0 and less than top_p"
-
- add_BOS = False
- if "add_BOS" in request.get_json():
- add_BOS = request.get_json()["add_BOS"]
- if not isinstance(add_BOS, bool):
- return "add_BOS must be a boolean value"
-
- if any([len(prompt) == 0 for prompt in prompts]) and not add_BOS:
- return "Empty prompts require add_BOS=true"
-
- stop_on_double_eol = False
- if "stop_on_double_eol" in request.get_json():
- stop_on_double_eol = request.get_json()["stop_on_double_eol"]
- if not isinstance(stop_on_double_eol, bool):
- return "stop_on_double_eol must be a boolean value"
-
- stop_on_eol = False
- if "stop_on_eol" in request.get_json():
- stop_on_eol = request.get_json()["stop_on_eol"]
- if not isinstance(stop_on_eol, bool):
- return "stop_on_eol must be a boolean value"
-
- prevent_newline_after_colon = False
- if "prevent_newline_after_colon" in request.get_json():
- prevent_newline_after_colon = request.get_json()["prevent_newline_after_colon"]
- if not isinstance(prevent_newline_after_colon, bool):
- return "prevent_newline_after_colon must be a boolean value"
-
- random_seed = -1
- if "random_seed" in request.get_json():
- random_seed = request.get_json()["random_seed"]
- if not isinstance(random_seed, int):
- return "random_seed must be integer"
- if random_seed < 0:
- return "random_seed must be a positive integer"
-
- no_log = False
- if "no_log" in request.get_json():
- no_log = request.get_json()["no_log"]
- if not isinstance(no_log, bool):
- return "no_log must be a boolean value"
-
- beam_width = None
- if "beam_width" in request.get_json():
- beam_width = request.get_json()["beam_width"]
- if not isinstance(beam_width, int):
- return "beam_width must be integer"
- if beam_width < 1:
- return "beam_width must be an integer > 1"
- if len(prompts) > 1:
- return "When doing beam_search, batch size must be 1"
-
- stop_token=50256
- if "stop_token" in request.get_json():
- stop_token = request.get_json()["stop_token"]
- if not isinstance(stop_token, int):
- return "stop_token must be an integer"
-
- length_penalty = 1
- if "length_penalty" in request.get_json():
- length_penalty = request.get_json()["length_penalty"]
- if not isinstance(length_penalty, float):
- return "length_penalty must be a float"
-
-        with lock:  # Need the lock to keep multiple threads from hitting the generation code at once
-
- if not no_log:
- print("request IP: " + str(request.remote_addr))
- print(json.dumps(request.get_json()),flush=True)
- print("start time: ", datetime.datetime.now())
-
- try:
- if beam_width is not None:
- MegatronGenerate.send_do_beam_search() # Tell other ranks we're doing beam_search
- response, response_seg, response_scores = \
- beam_search_and_post_process(
- self.model,
- prompts=prompts,
- tokens_to_generate=tokens_to_generate,
- beam_size = beam_width,
- add_BOS=add_BOS,
- stop_token=stop_token,
- num_return_gen=beam_width, # Returning whole beam
- length_penalty=length_penalty,
- prevent_newline_after_colon=prevent_newline_after_colon
- )
-
- return jsonify({"text": response,
- "segments": response_seg,
- "scores": response_scores})
- else:
- MegatronGenerate.send_do_generate() # Tell other ranks we're doing generate
- response, response_seg, response_logprobs, _ = \
- generate_and_post_process(
- self.model,
- prompts=prompts,
- tokens_to_generate=tokens_to_generate,
- return_output_log_probs=logprobs,
- top_k_sampling=top_k,
- top_p_sampling=top_p,
- top_p_decay=top_p_decay,
- top_p_bound=top_p_bound,
- temperature=temperature,
- add_BOS=add_BOS,
- use_eod_token_for_early_termination=True,
- stop_on_double_eol=stop_on_double_eol,
- stop_on_eol=stop_on_eol,
- prevent_newline_after_colon=prevent_newline_after_colon,
- random_seed=random_seed)
-
- return jsonify({"text": response,
- "segments": response_seg,
- "logprobs": response_logprobs})
-
- except ValueError as ve:
- return ve.args[0]
- print("end time: ", datetime.datetime.now())
-
-
-class MegatronServer(object):
- def __init__(self, model):
- self.app = Flask(__name__, static_url_path='')
- api = Api(self.app)
- api.add_resource(MegatronGenerate, '/api', resource_class_args=[model])
-
- def run(self, url, port):
- self.app.run(url, threaded=True, debug=False, port=port)
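
The server above exposes a single resource at `/api` that handles PUT requests with a JSON body, so a client only needs an HTTP library. A hedged example with `requests` (the host and port are assumptions; they depend on how `MegatronServer.run` is invoked):

```python
import requests

payload = {
    "prompts": ["Megatron is"],
    "tokens_to_generate": 32,
    "top_k": 1,        # greedy decoding
    "logprobs": False,
}
# Assumes the server was started on localhost:5000.
resp = requests.put("http://localhost:5000/api", json=payload)
print(resp.json()["text"])
```
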
diff --git a/megatron/theoretical_memory_usage.py b/megatron/theoretical_memory_usage.py
deleted file mode 100644
index 1a6fb6b5b313dc572ed241cfa6db157bc6784d54..0000000000000000000000000000000000000000
--- a/megatron/theoretical_memory_usage.py
+++ /dev/null
@@ -1,159 +0,0 @@
-# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
-
-"""Computes theoretical memory footprint for model training."""
-
-
-import math
-
-
-NUM_BYTES_IN_MEGABYTE = 1024 * 1024
-
-
-def compute_weight_and_optimizer_memory(args, verbose=False):
- if not args.group_query_attention:
- args.num_query_groups = args.num_attention_heads
- num_parameters_in_transformer_layers = (
- 10
- * args.num_layers
- * args.hidden_size
- * args.hidden_size
- * (
- 1
- + (args.num_query_groups / (5.0 * args.num_attention_heads))
- + (2 / (5 * args.hidden_size))
- + (1 / (5 * args.num_layers * args.hidden_size))
- )
- )
- embedding_size = args.hidden_size * args.padded_vocab_size
- if args.untie_embeddings_and_output_weights:
- num_total_parameters_with_embeddings = num_parameters_in_transformer_layers + (
- 2 * embedding_size
- )
- else:
- num_total_parameters_with_embeddings = num_parameters_in_transformer_layers + embedding_size
- if verbose:
- print(
- f"Number of parameters in billions: {num_total_parameters_with_embeddings / 10**9:.2f}"
- )
-
- # Most loaded model shard has (1/pp_size transformer layers + 1 embedding layer) / tp_size.
- num_parameters_on_most_loaded_model_shard = (
- (num_parameters_in_transformer_layers / args.pipeline_model_parallel_size) + embedding_size
- ) / args.tensor_model_parallel_size
- if args.untie_embeddings_and_output_weights and args.pipeline_model_parallel_size == 1:
- num_parameters_on_most_loaded_model_shard += (
- embedding_size / args.tensor_model_parallel_size
- )
- if verbose:
- print(
- f"Number of parameters in most loaded shard in billions: {num_parameters_on_most_loaded_model_shard / 10**9:.4f}"
- )
-
- if args.pipeline_model_parallel_size > 1:
- # Other shards just have (1/pp_size transformer layers) / tp_size.
- num_parameters_on_other_model_shards = num_parameters_in_transformer_layers / (
- args.pipeline_model_parallel_size * args.tensor_model_parallel_size
- )
- if verbose:
- print(
- f"Number of parameters in other shards in billions: {num_parameters_on_other_model_shards / 10**9:.4f}"
- )
-
- num_bytes_per_parameter = (
- 18 if not args.use_distributed_optimizer else 6 + (12 / args.data_parallel_size)
- )
- weight_and_optimizer_memory = (
- num_parameters_on_most_loaded_model_shard * num_bytes_per_parameter
- )
-
- return weight_and_optimizer_memory
-
-
-def compute_activation_memory(args, num_microbatches, verbose=False):
- # Using formula in Table 2 of https://arxiv.org/pdf/2205.05198.pdf.
- # We are trying to compute the maximum activation footprint, so all calculations in this function
- # are for the first pipeline stage.
-
- # Memory footprint from transformer layer (self-attention and MLP).
- activation_memory = (args.seq_length * args.micro_batch_size * args.hidden_size) * 34
- if verbose:
- print(
- f"Activation memory footprint per transformer layer: "
- f"{activation_memory / NUM_BYTES_IN_MEGABYTE / args.tensor_model_parallel_size:.1f} MB"
- )
- activation_memory *= args.num_layers
-
- # Now add activation memory required for input embeddings, last LayerNorm and output layer.
-
- # Input to embedding (pp_size microbatches in flight).
- activation_memory += (
- 8 * args.seq_length * args.micro_batch_size * args.pipeline_model_parallel_size
- )
- # Dropout in embedding layer (pp_size microbatches in flight).
- activation_memory += (
- args.seq_length
- * args.micro_batch_size
- * args.hidden_size
- * args.pipeline_model_parallel_size
- )
-
- # Multiply by interleaved PP memory factor.
- if args.virtual_pipeline_model_parallel_size is not None:
- interleaved_schedule_memory_penalty = 1 + (
- (args.pipeline_model_parallel_size - 1)
- / (args.pipeline_model_parallel_size * args.virtual_pipeline_model_parallel_size)
- )
- in_flight_microbatches = math.ceil(
- interleaved_schedule_memory_penalty * args.pipeline_model_parallel_size
- )
- if verbose:
- print(
- f"Memory penalty from interleaved schedule: {interleaved_schedule_memory_penalty:.2f}"
- )
- print(f"Number of in-flight microbatches: {in_flight_microbatches}")
- activation_memory *= interleaved_schedule_memory_penalty
-
- # If using non-interleaved schedule, number of microbatches in pipeline can be less than pp_size,
- # so discount accordingly.
- if args.virtual_pipeline_model_parallel_size is None and args.pipeline_model_parallel_size > 1:
- if num_microbatches is not None:
- activation_memory *= min(1, num_microbatches / args.pipeline_model_parallel_size)
- in_flight_microbatches = min(num_microbatches, args.pipeline_model_parallel_size)
- else:
- in_flight_microbatches = args.pipeline_model_parallel_size
- if verbose:
- print(f"Number of in-flight microbatches: {in_flight_microbatches}")
-
- if args.pipeline_model_parallel_size == 1:
- # Inputs to output layer and CE loss.
- activation_memory += (
- args.seq_length
- * args.micro_batch_size
- * args.hidden_size
- * 4
- * (1 + (args.padded_vocab_size / args.hidden_size))
- )
-
- # Activation memory is partitioned by TP size due to tensor and sequence model parallelism.
- return activation_memory / args.tensor_model_parallel_size
-
-
-def report_theoretical_memory(args, num_microbatches=None, verbose=False):
- # Formulae here assume sequence parallelism and selective activation recomputation.
- if not args.sequence_parallel or args.recompute_granularity != 'selective':
- return
-
- weight_and_optimizer_memory = (
- compute_weight_and_optimizer_memory(args, verbose=verbose) / NUM_BYTES_IN_MEGABYTE
- )
- activation_memory = (
- compute_activation_memory(args, num_microbatches=num_microbatches, verbose=verbose)
- / NUM_BYTES_IN_MEGABYTE
- )
- total_memory = weight_and_optimizer_memory + activation_memory
-
- print(
- f"Theoretical memory footprints: weight and optimizer={weight_and_optimizer_memory:.2f} MB, "
- f"activation={activation_memory:.2f} MB, "
- f"total={total_memory:.2f} MB\n"
- )
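
The per-parameter byte count used above is 18 bytes without the distributed optimizer, or `6 + 12 / dp_size` with it, since the 12 bytes of optimizer state per parameter are sharded across data-parallel ranks. A small worked sketch of just the weight-and-optimizer term, for a hypothetical shard size:

```python
NUM_BYTES_IN_MEGABYTE = 1024 * 1024

def weight_and_optimizer_mb(num_params, use_distributed_optimizer, dp_size):
    bytes_per_param = 6 + 12 / dp_size if use_distributed_optimizer else 18
    return num_params * bytes_per_param / NUM_BYTES_IN_MEGABYTE

num_params = 7e9  # hypothetical parameter count of the most loaded shard
print(f"plain optimizer:              {weight_and_optimizer_mb(num_params, False, 8):,.0f} MB")
print(f"distributed optimizer (dp=8): {weight_and_optimizer_mb(num_params, True, 8):,.0f} MB")
```
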
diff --git a/megatron/timers.py b/megatron/timers.py
deleted file mode 100644
index a9478fa014b3a01dd514f74005a4b86294328dc2..0000000000000000000000000000000000000000
--- a/megatron/timers.py
+++ /dev/null
@@ -1,304 +0,0 @@
-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-
-"""Megatron timers."""
-
-from abc import ABC
-from abc import abstractmethod
-import time
-
-import torch
-
-
-
-class TimerBase(ABC):
-
- def __init__(self, name):
- self.name = name
-
- @abstractmethod
- def start(self, barrier=False):
- pass
-
- @abstractmethod
- def stop(self, barrier=False):
- pass
-
- @abstractmethod
- def reset(self):
- pass
-
- @abstractmethod
- def elapsed(self, reset=True, barrier=False):
- pass
-
-
-
-class DummyTimer(TimerBase):
-
- def __init__(self):
- super().__init__('dummy timer')
-
- def start(self, barrier=False):
- return
-
- def stop(self, barrier=False):
- return
-
- def reset(self):
- return
-
- def elapsed(self, reset=True, barrier=False):
- raise Exception('dummy timer should not be used to '
- 'calculate elapsed time')
-
-
-
-class Timer(TimerBase):
- """
- Comment on using `barrier`: If this flag is passed, then all
- the caller processes will wait till all reach the timing routine.
- It is up to the user to make sure all the ranks in `barrier_group`
- call it otherwise, it will result in a hang.
- Comment on `barrier_group`: By default it is set to None which
- in torch distributed land, it will result in the global communicator.
- """
-
- def __init__(self, name):
- super().__init__(name)
- self._elapsed = 0.0
- self._started = False
- # Note that None will default to the global process group
- self._barrier_group = None
- self._start_time = time.time()
-
-
- def set_barrier_group(self, barrier_group):
- self._barrier_group = barrier_group
-
-
- def start(self, barrier=False):
- """Start the timer."""
- assert not self._started, 'timer has already been started'
- if barrier:
- torch.distributed.barrier(group=self._barrier_group)
- torch.cuda.synchronize()
- self._start_time = time.time()
- self._started = True
-
-
- def stop(self, barrier=False):
- """Stop the timer."""
- assert self._started, 'timer is not started'
- if barrier:
- torch.distributed.barrier(group=self._barrier_group)
- torch.cuda.synchronize()
- self._elapsed += (time.time() - self._start_time)
- self._started = False
-
-
- def reset(self):
- """Reset timer."""
- self._elapsed = 0.0
- self._started = False
-
-
- def elapsed(self, reset=True, barrier=False):
- """Calculate the elapsed time."""
- _started = self._started
-        # If timing is in progress, end it first.
- if self._started:
- self.stop(barrier=barrier)
- # Get the elapsed time.
- _elapsed = self._elapsed
- # Reset the elapsed time
- if reset:
- self.reset()
- # If timing was in progress, set it back.
- if _started:
- self.start(barrier=barrier)
- return _elapsed
-
-
-
-class Timers:
- """Group of timers."""
-
- def __init__(self, log_level, log_option):
- self._log_level = log_level
- self._log_option = log_option
- self._timers = {}
- self._log_levels = {}
- self._dummy_timer = DummyTimer()
- self._max_log_level = 2
-
-
- def __call__(self, name, log_level=None):
-        # If the timer has already been set, then check that, if a
-        # log level is provided, it matches the one the timer was created with.
- if name in self._timers:
- if log_level is not None:
- assert log_level == self._log_levels[name], \
- 'input log level {} does not match already existing '\
- 'log level {} for {} timer'.format(
- log_level, self._log_levels[name], name)
- return self._timers[name]
- # If timer does not exist and no log level is provided,
- # set it to the max log level which is 2.
- if log_level is None:
- log_level = self._max_log_level
- assert log_level <= self._max_log_level, \
- 'log level {} is larger than max supported log level {}'.format(
- log_level, self._max_log_level)
- # Now if the input log level is larger than the one set for
- # the timers class, just ignore it and return a dummy timer.
- if log_level > self._log_level:
- return self._dummy_timer
-        # Otherwise, initialize the timer and set the level.
- self._timers[name] = Timer(name)
- self._log_levels[name] = log_level
- return self._timers[name]
-
-
- def _get_elapsed_time_all_ranks(self, names, reset, barrier):
- """
- Assumptions:
- - All the ranks call this function.
- - `names` are identical on all ranks.
- If the above assumptions are not met, calling this function will
-        result in a hang.
- Arguments:
- - names: list of timer names
- - reset: reset the timer after recording the elapsed time
-            - barrier: if set, do a global barrier before time measurements
- """
-
- # First make sure all the callers are in sync.
- if barrier:
- torch.distributed.barrier()
-
- world_size = torch.distributed.get_world_size()
- rank = torch.distributed.get_rank()
-
- # Here we can use gather on the rank we want to print the
- # timing, however, there is no gather_base support in
- # pytorch yet. It is simpler to deal with a single tensor
- # and since we are only gathering a small amount of data,
- # it should be ok to use all-gather instead of gather.
- rank_name_to_time = torch.zeros((world_size, len(names)),
- dtype=torch.float,
- device=torch.cuda.current_device())
- for i, name in enumerate(names):
- if name in self._timers:
- # Here we don't need to pass the barrier flag as all
- # the processes are already in sync. This avoids the
- # issue of different timers having different barrier
- # groups inside their class.
- rank_name_to_time[rank, i] = self._timers[name].elapsed(
- reset=reset)
-
- # See the note above for why we are not using gather.
- torch.distributed._all_gather_base(rank_name_to_time.view(-1),
- rank_name_to_time[rank, :].view(-1))
-
- return rank_name_to_time
-
-
- def _get_global_min_max_time(self, names, reset, barrier, normalizer):
- """Report only min and max times across all ranks."""
-
- rank_name_to_time = self._get_elapsed_time_all_ranks(names, reset,
- barrier)
- name_to_min_max_time = {}
- for i, name in enumerate(names):
- rank_to_time = rank_name_to_time[:, i]
- # filter out the ones we did not have any timings for
- rank_to_time = rank_to_time[rank_to_time > 0.0]
- # If the timer exists:
- if rank_to_time.numel() > 0:
- name_to_min_max_time[name] = (
- rank_to_time.min().item() / normalizer,
- rank_to_time.max().item() / normalizer)
- return name_to_min_max_time
-
-
- def _get_global_min_max_time_string(self, names, reset, barrier,
- normalizer, max_only):
- name_to_min_max_time = self._get_global_min_max_time(
- names, reset, barrier, normalizer)
- if not name_to_min_max_time:
- return None
- output_string = '(min, max) time across ranks (ms):'
- for name in name_to_min_max_time:
- min_time, max_time = name_to_min_max_time[name]
- if max_only:
- output_string += '\n {}: {:.2f}'.format(
- (name+' ').ljust(48, '.'), max_time)
- else:
- output_string += '\n {}: ({:.2f}, {:.2f})'.format(
- (name+' ').ljust(48, '.'), min_time, max_time)
- return output_string
-
-
- def _get_all_ranks_time_string(self, names, reset, barrier, normalizer):
- """Report times across all ranks."""
- rank_name_to_time = self._get_elapsed_time_all_ranks(names, reset,
- barrier)
-
- output_string = 'times across ranks (ms):'
- no_reported_timing = True
- for i, name in enumerate(names):
- not_yet_found = True
- for rank in range(torch.distributed.get_world_size()):
- if rank_name_to_time[rank, i] > 0:
- no_reported_timing = False
- if not_yet_found:
- not_yet_found = False
- output_string += '\n {}:'.format(name)
- output_string += '\n rank {:2d}: {:.2f}'.format(
- rank, rank_name_to_time[rank, i] / normalizer)
- if no_reported_timing:
- return None
- return output_string
-
-
- def log(self, names, rank=None, normalizer=1.0, reset=True, barrier=False):
- """Log a group of timers."""
-
- # Print.
- assert normalizer > 0.0
- if self._log_option in ['max', 'minmax']:
- max_only = False
- if self._log_option == 'max':
- max_only = True
- output_string = self._get_global_min_max_time_string(
- names, reset, barrier, normalizer/1000.0, max_only)
- elif self._log_option == 'all':
- output_string = self._get_all_ranks_time_string(names,
- reset, barrier,
- normalizer/1000.0)
- else:
- raise Exception('unknown timing log option {}'.format(
- self._log_option))
-
- # If no input rank is provided, log on last rank.
- if rank is None:
- rank = torch.distributed.get_world_size() - 1
- if rank == torch.distributed.get_rank() and output_string is not None:
- print(output_string, flush=True)
-
-
- def write(self, names, writer, iteration, normalizer=1.0,
- reset=False, barrier=False):
- """Write timers to a tensorboard writer
- Note that we only report maximum time across ranks to tensorboard.
- """
- # currently when using add_scalars,
- # torch.utils.add_scalars makes each timer its own run, which
-        # pollutes the runs list, so we just add each as a scalar
- assert normalizer > 0.0
- name_to_min_max_time = self._get_global_min_max_time(
- names, reset, barrier, normalizer)
- if writer is not None:
- for name in name_to_min_max_time:
- _, max_time = name_to_min_max_time[name]
- writer.add_scalar(name + '-time', max_time, iteration)
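
`Timer.elapsed()` above uses a pause/resume pattern: if the timer is running it is stopped first, the accumulated time is read (and optionally reset), and then the timer is restarted. A standalone sketch of that pattern without the `torch.distributed` barrier and CUDA synchronization:

```python
import time

class SimpleTimer:
    """Pared-down illustration of the accumulate / pause / resume logic."""

    def __init__(self):
        self._elapsed = 0.0
        self._started = False
        self._start_time = None

    def start(self):
        assert not self._started, 'timer has already been started'
        self._start_time = time.time()
        self._started = True

    def stop(self):
        assert self._started, 'timer is not started'
        self._elapsed += time.time() - self._start_time
        self._started = False

    def elapsed(self, reset=True):
        was_running = self._started
        if was_running:          # end the in-progress measurement first
            self.stop()
        value = self._elapsed
        if reset:
            self._elapsed = 0.0
        if was_running:          # put the timer back the way we found it
            self.start()
        return value

timer = SimpleTimer()
timer.start()
time.sleep(0.01)
print(f"{timer.elapsed() * 1000:.1f} ms")  # timer keeps running afterwards
```
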
diff --git a/megatron/tokenizer/__init__.py b/megatron/tokenizer/__init__.py
deleted file mode 100644
index 59ceb33865a384b4d5f12efee2d31944c79ff109..0000000000000000000000000000000000000000
--- a/megatron/tokenizer/__init__.py
+++ /dev/null
@@ -1,4 +0,0 @@
-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-
-
-from .tokenizer import build_tokenizer
diff --git a/megatron/tokenizer/bert_tokenization.py b/megatron/tokenizer/bert_tokenization.py
deleted file mode 100644
index 642041e778e81a0ddb8bba755ce93116b296a9dd..0000000000000000000000000000000000000000
--- a/megatron/tokenizer/bert_tokenization.py
+++ /dev/null
@@ -1,431 +0,0 @@
-# coding=utf-8
-# Copyright 2018 The Google AI Language Team Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Tokenization classes."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import collections
-import re
-import unicodedata
-import six
-
-
-def validate_case_matches_checkpoint(do_lower_case, init_checkpoint):
- """Checks whether the casing config is consistent with the checkpoint name."""
-
- # The casing has to be passed in by the user and there is no explicit check
- # as to whether it matches the checkpoint. The casing information probably
- # should have been stored in the bert_config.json file, but it's not, so
- # we have to heuristically detect it to validate.
-
- if not init_checkpoint:
- return
-
- m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint)
- if m is None:
- return
-
- model_name = m.group(1)
-
- lower_models = [
- "uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12",
- "multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12"
- ]
-
- cased_models = [
- "cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16",
- "multi_cased_L-12_H-768_A-12"
- ]
-
- is_bad_config = False
- if model_name in lower_models and not do_lower_case:
- is_bad_config = True
- actual_flag = "False"
- case_name = "lowercased"
- opposite_flag = "True"
-
- if model_name in cased_models and do_lower_case:
- is_bad_config = True
- actual_flag = "True"
- case_name = "cased"
- opposite_flag = "False"
-
- if is_bad_config:
- raise ValueError(
- "You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. "
- "However, `%s` seems to be a %s model, so you "
- "should pass in `--do_lower_case=%s` so that the fine-tuning matches "
- "how the model was pre-training. If this error is wrong, please "
- "just comment out this check." % (actual_flag, init_checkpoint,
- model_name, case_name, opposite_flag))
-
-
-def convert_to_unicode(text):
- """Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
- if six.PY3:
- if isinstance(text, str):
- return text
- elif isinstance(text, bytes):
- return text.decode("utf-8", "ignore")
- else:
- raise ValueError("Unsupported string type: %s" % (type(text)))
- elif six.PY2:
- if isinstance(text, str):
- return text.decode("utf-8", "ignore")
- elif isinstance(text, unicode):
- return text
- else:
- raise ValueError("Unsupported string type: %s" % (type(text)))
- else:
- raise ValueError("Not running on Python2 or Python 3?")
-
-
-def printable_text(text):
- """Returns text encoded in a way suitable for print or `tf.logging`."""
-
- # These functions want `str` for both Python2 and Python3, but in one case
- # it's a Unicode string and in the other it's a byte string.
- if six.PY3:
- if isinstance(text, str):
- return text
- elif isinstance(text, bytes):
- return text.decode("utf-8", "ignore")
- else:
- raise ValueError("Unsupported string type: %s" % (type(text)))
- elif six.PY2:
- if isinstance(text, str):
- return text
- elif isinstance(text, unicode):
- return text.encode("utf-8")
- else:
- raise ValueError("Unsupported string type: %s" % (type(text)))
- else:
- raise ValueError("Not running on Python2 or Python 3?")
-
-
-def load_vocab(vocab_file):
- """Loads a vocabulary file into a dictionary."""
- vocab = collections.OrderedDict()
- index = 0
- with open(vocab_file, "r", encoding = "utf-8") as reader:
- while True:
- token = convert_to_unicode(reader.readline())
- if not token:
- break
- token = token.strip()
- vocab[token] = index
- index += 1
- return vocab
-
-
-def convert_by_vocab(vocab, items):
- """Converts a sequence of [tokens|ids] using the vocab."""
- output = []
- for item in items:
- output.append(vocab[item])
- return output
-
-
-def convert_tokens_to_ids(vocab, tokens):
- return convert_by_vocab(vocab, tokens)
-
-
-def convert_ids_to_tokens(inv_vocab, ids):
- return convert_by_vocab(inv_vocab, ids)
-
-
-def whitespace_tokenize(text):
- """Runs basic whitespace cleaning and splitting on a piece of text."""
- text = text.strip()
- if not text:
- return []
- tokens = text.split()
- return tokens
-
-
-class FullTokenizer(object):
- """Runs end-to-end tokenziation."""
-
- def __init__(self, vocab_file, do_lower_case=True):
- self.vocab = load_vocab(vocab_file)
- self.inv_vocab = {v: k for k, v in self.vocab.items()}
- self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
- self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
-
- def tokenize(self, text):
- split_tokens = []
- for token in self.basic_tokenizer.tokenize(text):
- for sub_token in self.wordpiece_tokenizer.tokenize(token):
- split_tokens.append(sub_token)
-
- return split_tokens
-
- def convert_tokens_to_ids(self, tokens):
- return convert_by_vocab(self.vocab, tokens)
-
- def convert_ids_to_tokens(self, ids):
- return convert_by_vocab(self.inv_vocab, ids)
-
- @staticmethod
- def convert_tokens_to_string(tokens, clean_up_tokenization_spaces=True):
- """ Converts a sequence of tokens (string) in a single string. """
-
- def clean_up_tokenization(out_string):
- """ Clean up a list of simple English tokenization artifacts
-            like spaces before punctuation and abbreviated forms.
- """
- out_string = (
- out_string.replace(" .", ".")
- .replace(" ?", "?")
- .replace(" !", "!")
- .replace(" ,", ",")
- .replace(" ' ", "'")
- .replace(" n't", "n't")
- .replace(" 'm", "'m")
- .replace(" 's", "'s")
- .replace(" 've", "'ve")
- .replace(" 're", "'re")
- )
- return out_string
-
- text = ' '.join(tokens).replace(' ##', '').strip()
- if clean_up_tokenization_spaces:
- clean_text = clean_up_tokenization(text)
- return clean_text
- else:
- return text
-
- def vocab_size(self):
- return len(self.vocab)
-
-
-class BasicTokenizer(object):
- """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
-
- def __init__(self, do_lower_case=True):
- """Constructs a BasicTokenizer.
-
- Args:
- do_lower_case: Whether to lower case the input.
- """
- self.do_lower_case = do_lower_case
-
- def tokenize(self, text):
- """Tokenizes a piece of text."""
- text = convert_to_unicode(text)
- text = self._clean_text(text)
-
- # This was added on November 1st, 2018 for the multilingual and Chinese
- # models. This is also applied to the English models now, but it doesn't
- # matter since the English models were not trained on any Chinese data
- # and generally don't have any Chinese data in them (there are Chinese
- # characters in the vocabulary because Wikipedia does have some Chinese
- # words in the English Wikipedia.).
- text = self._tokenize_chinese_chars(text)
-
- orig_tokens = whitespace_tokenize(text)
- split_tokens = []
- for token in orig_tokens:
- if self.do_lower_case:
- token = token.lower()
- token = self._run_strip_accents(token)
- split_tokens.extend(self._run_split_on_punc(token))
-
- output_tokens = whitespace_tokenize(" ".join(split_tokens))
- return output_tokens
-
- def _run_strip_accents(self, text):
- """Strips accents from a piece of text."""
- text = unicodedata.normalize("NFD", text)
- output = []
- for char in text:
- cat = unicodedata.category(char)
- if cat == "Mn":
- continue
- output.append(char)
- return "".join(output)
-
- def _run_split_on_punc(self, text):
- """Splits punctuation on a piece of text."""
- chars = list(text)
- i = 0
- start_new_word = True
- output = []
- while i < len(chars):
- char = chars[i]
- if _is_punctuation(char):
- output.append([char])
- start_new_word = True
- else:
- if start_new_word:
- output.append([])
- start_new_word = False
- output[-1].append(char)
- i += 1
-
- return ["".join(x) for x in output]
-
- def _tokenize_chinese_chars(self, text):
- """Adds whitespace around any CJK character."""
- output = []
- for char in text:
- cp = ord(char)
- if self._is_chinese_char(cp):
- output.append(" ")
- output.append(char)
- output.append(" ")
- else:
- output.append(char)
- return "".join(output)
-
- def _is_chinese_char(self, cp):
- """Checks whether CP is the codepoint of a CJK character."""
- # This defines a "chinese character" as anything in the CJK Unicode block:
- # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
- #
- # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
- # despite its name. The modern Korean Hangul alphabet is a different block,
- # as is Japanese Hiragana and Katakana. Those alphabets are used to write
- # space-separated words, so they are not treated specially and handled
- # like all of the other languages.
- if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
- (cp >= 0x3400 and cp <= 0x4DBF) or #
- (cp >= 0x20000 and cp <= 0x2A6DF) or #
- (cp >= 0x2A700 and cp <= 0x2B73F) or #
- (cp >= 0x2B740 and cp <= 0x2B81F) or #
- (cp >= 0x2B820 and cp <= 0x2CEAF) or
- (cp >= 0xF900 and cp <= 0xFAFF) or #
- (cp >= 0x2F800 and cp <= 0x2FA1F)): #
- return True
-
- return False
-
- def _clean_text(self, text):
- """Performs invalid character removal and whitespace cleanup on text."""
- output = []
- for char in text:
- cp = ord(char)
- if cp == 0 or cp == 0xfffd or _is_control(char):
- continue
- if _is_whitespace(char):
- output.append(" ")
- else:
- output.append(char)
- return "".join(output)
-
-
-class WordpieceTokenizer(object):
- """Runs WordPiece tokenziation."""
-
- def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200):
- self.vocab = vocab
- self.unk_token = unk_token
- self.max_input_chars_per_word = max_input_chars_per_word
-
- def tokenize(self, text):
- """Tokenizes a piece of text into its word pieces.
-
- This uses a greedy longest-match-first algorithm to perform tokenization
- using the given vocabulary.
-
- For example:
- input = "unaffable"
- output = ["un", "##aff", "##able"]
-
- Args:
- text: A single token or whitespace separated tokens. This should have
- already been passed through `BasicTokenizer`.
-
- Returns:
- A list of wordpiece tokens.
- """
-
- text = convert_to_unicode(text)
-
- output_tokens = []
- for token in whitespace_tokenize(text):
- chars = list(token)
- if len(chars) > self.max_input_chars_per_word:
- output_tokens.append(self.unk_token)
- continue
-
- is_bad = False
- start = 0
- sub_tokens = []
- while start < len(chars):
- end = len(chars)
- cur_substr = None
- while start < end:
- substr = "".join(chars[start:end])
- if start > 0:
- substr = "##" + substr
- if substr in self.vocab:
- cur_substr = substr
- break
- end -= 1
- if cur_substr is None:
- is_bad = True
- break
- sub_tokens.append(cur_substr)
- start = end
-
- if is_bad:
- output_tokens.append(self.unk_token)
- else:
- output_tokens.extend(sub_tokens)
- return output_tokens
-
-
-def _is_whitespace(char):
- """Checks whether `chars` is a whitespace character."""
- # \t, \n, and \r are technically control characters but we treat them
- # as whitespace since they are generally considered as such.
- if char == " " or char == "\t" or char == "\n" or char == "\r":
- return True
- cat = unicodedata.category(char)
- if cat == "Zs":
- return True
- return False
-
-
-def _is_control(char):
- """Checks whether `chars` is a control character."""
- # These are technically control characters but we count them as whitespace
- # characters.
- if char == "\t" or char == "\n" or char == "\r":
- return False
- cat = unicodedata.category(char)
- if cat in ("Cc", "Cf"):
- return True
- return False
-
-
-def _is_punctuation(char):
- """Checks whether `chars` is a punctuation character."""
- cp = ord(char)
- # We treat all non-letter/number ASCII as punctuation.
- # Characters such as "^", "$", and "`" are not in the Unicode
- # Punctuation class but we treat them as punctuation anyways, for
- # consistency.
- if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
- (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
- return True
- cat = unicodedata.category(char)
- if cat.startswith("P"):
- return True
- return False
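The `WordpieceTokenizer` removed above tokenizes each whitespace-split word with a greedy longest-match-first search over the vocabulary, marking word-internal pieces with a `##` prefix and mapping the whole word to `[UNK]` when no prefix matches. A minimal standalone sketch of that matching loop, using a hypothetical toy vocabulary, is shown below.

```python
# Minimal sketch of greedy longest-match-first WordPiece matching, mirroring
# the removed WordpieceTokenizer.tokenize(); the toy vocabulary is hypothetical.
def wordpiece(token, vocab, unk_token="[UNK]"):
    chars = list(token)
    sub_tokens = []
    start = 0
    while start < len(chars):
        end = len(chars)
        cur_substr = None
        while start < end:
            substr = "".join(chars[start:end])
            if start > 0:
                substr = "##" + substr          # continuation piece
            if substr in vocab:
                cur_substr = substr             # longest match found
                break
            end -= 1
        if cur_substr is None:
            return [unk_token]                  # no prefix matched: whole word is unknown
        sub_tokens.append(cur_substr)
        start = end
    return sub_tokens

toy_vocab = {"un", "##aff", "##able"}
print(wordpiece("unaffable", toy_vocab))        # ['un', '##aff', '##able']
```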
diff --git a/megatron/tokenizer/gpt2_tokenization.py b/megatron/tokenizer/gpt2_tokenization.py
deleted file mode 100644
index 3f37e449089b8f779c93ac7f034085fd7607bfb0..0000000000000000000000000000000000000000
--- a/megatron/tokenizer/gpt2_tokenization.py
+++ /dev/null
@@ -1,321 +0,0 @@
-# coding=utf-8
-# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Tokenization classes for OpenAI GPT."""
-
-from __future__ import (absolute_import, division, print_function,
- unicode_literals)
-
-import sys
-import json
-import logging
-import os
-import regex as re
-from io import open
-
-try:
- from functools import lru_cache
-except ImportError:
- # Just a dummy decorator to get the checks to run on python2
- # because honestly I don't want to support a byte-level unicode BPE
- # tokenizer on python 2 right now.
- def lru_cache():
- return lambda func: func
-
-
-logger = logging.getLogger(__name__)
-
-PRETRAINED_VOCAB_ARCHIVE_MAP = {
- 'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json",
-}
-PRETRAINED_MERGES_ARCHIVE_MAP = {
- 'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt",
-}
-PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
- 'gpt2': 1024,
-}
-VOCAB_NAME = 'vocab.json'
-MERGES_NAME = 'merges.txt'
-SPECIAL_TOKENS_NAME = 'special_tokens.txt'
-
-
-@lru_cache()
-def bytes_to_unicode():
- """
- Returns a list of utf-8 bytes and a corresponding list of unicode strings.
- The reversible bpe codes work on unicode strings.
- This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
- When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
- This is a significant percentage of your normal, say, 32K bpe vocab.
- To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
- And avoids mapping to whitespace/control characters the bpe code barfs on.
- """
- _chr = unichr if sys.version_info[0] == 2 else chr
- bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + \
- list(range(ord("®"), ord("ÿ") + 1))
- cs = bs[:]
- n = 0
- for b in range(2**8):
- if b not in bs:
- bs.append(b)
- cs.append(2**8 + n)
- n += 1
- cs = [_chr(n) for n in cs]
- return dict(zip(bs, cs))
-
-
-def get_pairs(word):
- """Return set of symbol pairs in a word.
-
- Word is represented as tuple of symbols (symbols being variable-length strings).
- """
- pairs = set()
- prev_char = word[0]
- for char in word[1:]:
- pairs.add((prev_char, char))
- prev_char = char
- return pairs
-
-
-class GPT2Tokenizer(object):
- """
- GPT-2 BPE tokenizer. Peculiarities:
- - Byte-level BPE
- """
- @classmethod
- def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
- """
- Instantiate a GPT2Tokenizer from a pre-trained model file.
- Download and cache the pre-trained model file if needed.
- """
- if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
- vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
- merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path]
- special_tokens_file = None
- else:
- vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
- merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME)
- special_tokens_file = os.path.join(pretrained_model_name_or_path, SPECIAL_TOKENS_NAME)
- if not os.path.exists(special_tokens_file):
- special_tokens_file = None
- else:
- logger.info("loading special tokens file {}".format(special_tokens_file))
- # redirect to the cache, if necessary
- try:
- from .file_utils import cached_path
- resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
- resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir)
- except EnvironmentError:
- logger.error(
- "Model name '{}' was not found in model name list ({}). "
- "We assumed '{}' was a path or url but couldn't find files {} and {} "
- "at this path or url.".format(
- pretrained_model_name_or_path,
- ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
- pretrained_model_name_or_path,
- vocab_file, merges_file))
- return None
- if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file:
- logger.info("loading vocabulary file {}".format(vocab_file))
- logger.info("loading merges file {}".format(merges_file))
- else:
- logger.info("loading vocabulary file {} from cache at {}".format(
- vocab_file, resolved_vocab_file))
- logger.info("loading merges file {} from cache at {}".format(
- merges_file, resolved_merges_file))
- if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
- # if we're using a pretrained model, ensure the tokenizer won't index sequences longer
- # than the number of positional embeddings
- max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
- kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
- # Instantiate tokenizer.
- if special_tokens_file and 'special_tokens' not in kwargs:
- special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1]
- else:
- special_tokens = kwargs.pop('special_tokens', [])
- tokenizer = cls(
- resolved_vocab_file,
- resolved_merges_file,
- special_tokens=special_tokens,
- *inputs,
- **kwargs)
- return tokenizer
-
- def __init__(self, vocab_file, merges_file, errors='replace',
- special_tokens=None, max_len=None):
- self.max_len = max_len if max_len is not None else int(1e12)
- self.encoder = json.load(open(vocab_file))
- self.decoder = {v: k for k, v in self.encoder.items()}
- self.errors = errors # how to handle errors in decoding
- self.byte_encoder = bytes_to_unicode()
- self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
- bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
- bpe_merges = [tuple(merge.split()) for merge in bpe_data]
- self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
- self.cache = {}
-
- # Should have added re.IGNORECASE so BPE merges can happen for
- # capitalized versions of contractions
- self.pat = re.compile(
- r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
-
- self.special_tokens = {}
- self.special_tokens_decoder = {}
- self.set_special_tokens(special_tokens)
-
- def __len__(self):
- return len(self.encoder) + len(self.special_tokens)
-
- def set_special_tokens(self, special_tokens):
- """ Add a list of additional tokens to the encoder.
- The additional tokens are indexed starting from the last index of the
- current vocabulary in the order of the `special_tokens` list.
- """
- if not special_tokens:
- self.special_tokens = {}
- self.special_tokens_decoder = {}
- return
- self.special_tokens = dict((tok, len(self.encoder) + i)
- for i, tok in enumerate(special_tokens))
- self.special_tokens_decoder = {v: k for k, v in self.special_tokens.items()}
- logger.info("Special tokens {}".format(self.special_tokens))
-
- def bpe(self, token):
- if token in self.cache:
- return self.cache[token]
- word = tuple(token)
- pairs = get_pairs(word)
-
- if not pairs:
- return token
-
- while True:
- bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
- if bigram not in self.bpe_ranks:
- break
- first, second = bigram
- new_word = []
- i = 0
- while i < len(word):
- try:
- j = word.index(first, i)
- new_word.extend(word[i:j])
- i = j
- except BaseException:
- new_word.extend(word[i:])
- break
-
- if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
- new_word.append(first + second)
- i += 2
- else:
- new_word.append(word[i])
- i += 1
- new_word = tuple(new_word)
- word = new_word
- if len(word) == 1:
- break
- else:
- pairs = get_pairs(word)
- word = ' '.join(word)
- self.cache[token] = word
- return word
-
- def tokenize(self, text):
- """ Tokenize a string. """
- bpe_tokens = []
- for token in re.findall(self.pat, text):
- if sys.version_info[0] == 2:
- token = ''.join(self.byte_encoder[ord(b)] for b in token)
- else:
- token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
- bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' '))
- return bpe_tokens
-
- def convert_tokens_to_ids(self, tokens):
- """ Converts a sequence of tokens into ids using the vocab. """
- ids = []
- if isinstance(tokens, str) or (sys.version_info[0] == 2 and isinstance(tokens, unicode)):
- if tokens in self.special_tokens:
- return self.special_tokens[tokens]
- else:
- return self.encoder.get(tokens, 0)
- for token in tokens:
- if token in self.special_tokens:
- ids.append(self.special_tokens[token])
- else:
- ids.append(self.encoder.get(token, 0))
- if len(ids) > self.max_len:
- logger.warning(
- "Token indices sequence length is longer than the specified maximum "
- " sequence length for this OpenAI GPT model ({} > {}). Running this"
- " sequence through the model will result in indexing errors".format(
- len(ids), self.max_len)
- )
- return ids
-
- def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
- """Converts a sequence of ids in BPE tokens using the vocab."""
- tokens = []
- for i in ids:
- if i in self.special_tokens_decoder:
- if not skip_special_tokens:
- tokens.append(self.special_tokens_decoder[i])
- else:
- tokens.append(self.decoder[i])
- return tokens
-
- def encode(self, text):
- return self.convert_tokens_to_ids(self.tokenize(text))
-
- def decode(self, tokens):
- text = ''.join([self.decoder[token] for token in tokens])
- text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
- return text
-
- def save_vocabulary(self, vocab_path):
- """Save the tokenizer vocabulary and merge files to a directory."""
- if not os.path.isdir(vocab_path):
- logger.error("Vocabulary path ({}) should be a directory".format(vocab_path))
- return
- vocab_file = os.path.join(vocab_path, VOCAB_NAME)
- merge_file = os.path.join(vocab_path, MERGES_NAME)
- special_tokens_file = os.path.join(vocab_path, SPECIAL_TOKENS_NAME)
-
- with open(vocab_file, 'w', encoding='utf-8') as f:
- f.write(json.dumps(self.encoder, ensure_ascii=False))
-
- index = 0
- with open(merge_file, "w", encoding="utf-8") as writer:
- writer.write(u'#version: 0.2\n')
- for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
- if index != token_index:
- logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive."
- " Please check that the tokenizer is not corrupted!".format(merge_file))
- index = token_index
- writer.write(' '.join(bpe_tokens) + u'\n')
- index += 1
-
- index = len(self.encoder)
- with open(special_tokens_file, 'w', encoding='utf-8') as writer:
- for token, token_index in sorted(self.special_tokens.items(), key=lambda kv: kv[1]):
- if index != token_index:
- logger.warning("Saving special tokens vocabulary to {}: BPE indices are not consecutive."
- " Please check that the tokenizer is not corrupted!".format(special_tokens_file))
- index = token_index
- writer.write(token + u'\n')
- index += 1
-
- return vocab_file, merge_file, special_tokens_file
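The deleted `GPT2Tokenizer.bpe()` applies byte-level BPE: a word starts as a tuple of symbols, and the adjacent pair with the lowest merge rank is merged repeatedly until no known merge remains. A minimal sketch of that merge loop, driven by a hypothetical toy merge table, is shown below; in the real tokenizer the ranks come from `merges.txt` and the input symbols are the byte-to-unicode characters produced by `bytes_to_unicode()`.

```python
# Minimal sketch of the greedy BPE merge loop in the removed GPT2Tokenizer.bpe();
# the toy merge ranks are hypothetical and only for illustration.
def bpe(token, bpe_ranks):
    word = tuple(token)
    while len(word) > 1:
        pairs = {(word[i], word[i + 1]) for i in range(len(word) - 1)}
        # pick the adjacent pair with the lowest merge rank
        bigram = min(pairs, key=lambda p: bpe_ranks.get(p, float("inf")))
        if bigram not in bpe_ranks:
            break                                   # no known merge left
        first, second = bigram
        new_word, i = [], 0
        while i < len(word):
            if i < len(word) - 1 and word[i] == first and word[i + 1] == second:
                new_word.append(first + second)     # apply the merge
                i += 2
            else:
                new_word.append(word[i])
                i += 1
        word = tuple(new_word)
    return " ".join(word)

toy_ranks = {("l", "o"): 0, ("lo", "w"): 1}
print(bpe("lower", toy_ranks))                      # "low e r"
```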
diff --git a/megatron/tokenizer/tokenizer.py b/megatron/tokenizer/tokenizer.py
deleted file mode 100644
index 98643343c5c7621f8351b75e22ea9203ae61ad7b..0000000000000000000000000000000000000000
--- a/megatron/tokenizer/tokenizer.py
+++ /dev/null
@@ -1,588 +0,0 @@
-# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
-
-"""Megatron tokenizers."""
-
-from abc import ABC
-from abc import abstractmethod
-
-from .bert_tokenization import FullTokenizer as FullBertTokenizer
-from .gpt2_tokenization import GPT2Tokenizer
-
-def build_tokenizer(args):
- """Initialize tokenizer."""
- if args.rank == 0:
- print('> building {} tokenizer ...'.format(args.tokenizer_type),
- flush=True)
-
- # Select and instantiate the tokenizer.
- if args.tokenizer_type == 'BertWordPieceLowerCase':
- assert args.vocab_file is not None
- tokenizer = _BertWordPieceTokenizer(vocab_file=args.vocab_file,
- lower_case=True,
- vocab_extra_ids=args.vocab_extra_ids)
- elif args.tokenizer_type == 'BertWordPieceCase':
- assert args.vocab_file is not None
- tokenizer = _BertWordPieceTokenizer(vocab_file=args.vocab_file,
- lower_case=False,
- vocab_extra_ids=args.vocab_extra_ids)
- elif args.tokenizer_type == 'GPT2BPETokenizer':
- assert args.vocab_file is not None
- assert args.merge_file is not None
- tokenizer = _GPT2BPETokenizer(args.vocab_file, args.merge_file)
- elif args.tokenizer_type == 'SentencePieceTokenizer':
- assert args.tokenizer_model is not None
- tokenizer = _SentencePieceTokenizer(args.tokenizer_model, vocab_extra_ids=args.vocab_extra_ids)
- elif args.tokenizer_type == 'GPTSentencePieceTokenizer':
- assert args.tokenizer_model is not None
- tokenizer = _GPTSentencePieceTokenizer(args.tokenizer_model)
- elif args.tokenizer_type == 'Llama2Tokenizer':
- assert args.tokenizer_model is not None
- tokenizer = _Llama2Tokenizer(args.tokenizer_model)
- elif args.tokenizer_type == 'NullTokenizer':
- assert args.vocab_size is not None
- tokenizer = _NullTokenizer(args.vocab_size)
- else:
- raise NotImplementedError('{} tokenizer is not '
- 'implemented.'.format(args.tokenizer_type))
-
- # Add vocab size (if not already set from a checkpoint).
- if getattr(args, "padded_vocab_size", None) is None:
- args.padded_vocab_size = _vocab_size_with_padding(tokenizer.vocab_size,
- args)
-
- return tokenizer
-
-
-def _vocab_size_with_padding(orig_vocab_size, args):
- """Pad vocab size so it is divisible by model parallel size and
- still has a GPU-friendly size."""
-
- after = orig_vocab_size
- multiple = args.make_vocab_size_divisible_by * \
- args.tensor_model_parallel_size
- while (after % multiple) != 0:
- after += 1
- if args.rank == 0:
- print(' > padded vocab (size: {}) with {} dummy tokens '
- '(new size: {})'.format(
- orig_vocab_size, after - orig_vocab_size, after), flush=True)
- return after
-
-
-class AbstractTokenizer(ABC):
- """Abstract class for tokenizer."""
-
- def __init__(self, name):
- self.name = name
- super().__init__()
-
- @property
- @abstractmethod
- def vocab_size(self):
- pass
-
- @property
- @abstractmethod
- def vocab(self):
- """Dictionary from vocab text token to id token."""
- pass
-
- @property
- @abstractmethod
- def inv_vocab(self):
- """Dictionary from vocab id token to text token."""
- pass
-
- @abstractmethod
- def tokenize(self, text):
- pass
-
- def detokenize(self, token_ids):
- raise NotImplementedError('detokenizer is not implemented for {} '
- 'tokenizer'.format(self.name))
-
- @property
- def cls(self):
- raise NotImplementedError('CLS is not provided for {} '
- 'tokenizer'.format(self.name))
-
- @property
- def sep(self):
- raise NotImplementedError('SEP is not provided for {} '
- 'tokenizer'.format(self.name))
-
- @property
- def pad(self):
- raise NotImplementedError('PAD is not provided for {} '
- 'tokenizer'.format(self.name))
-
- @property
- def eod(self):
- raise NotImplementedError('EOD is not provided for {} '
- 'tokenizer'.format(self.name))
-
- @property
- def mask(self):
- raise NotImplementedError('MASK is not provided for {} '
- 'tokenizer'.format(self.name))
-
-
-class _BertWordPieceTokenizer(AbstractTokenizer):
- """Original BERT wordpiece tokenizer."""
-
- def __init__(self, vocab_file, lower_case=True, vocab_extra_ids=0):
- if lower_case:
- name = 'BERT Lower Case'
- else:
- name = 'BERT Upper Case'
- super().__init__(name)
- self.tokenizer = FullBertTokenizer(vocab_file, do_lower_case=lower_case)
- self.cls_id = self.tokenizer.vocab['[CLS]']
- self.sep_id = self.tokenizer.vocab['[SEP]']
- self.pad_id = self.tokenizer.vocab['[PAD]']
- self.mask_id = self.tokenizer.vocab['[MASK]']
- self._additional_special_tokens = []
-
- # (dsachan) Add BOS and EOS tokens
- SPECIAL_TOKENS = {'eos_token': '[EOS]',
- 'bos_token': '[BOS]'}
- self._bos_token = '[BOS]'
- self.add_token(self._bos_token)
- self._bos_token_id = self.vocab.get(self._bos_token)
-
- self._eos_token = '[EOS]'
- self.add_token(self._eos_token)
- self._eos_token_id = self.vocab.get(self._eos_token)
-
- # (dsachan) Add additional special tokens
- # These can be used as sentinel tokens in T5 model inputs
- additional_special_tokens = []
- additional_special_tokens.extend(
- ["".format(i) for i in range(vocab_extra_ids)])
- self.add_additional_special_tokens(additional_special_tokens)
-
- def add_token(self, token):
- if token not in self.vocab:
- self.inv_vocab[self.vocab_size] = token
- # self.vocab_size comes from len(vocab)
- # and it will increase as we add elements
- self.vocab[token] = self.vocab_size
-
- def add_additional_special_tokens(self, tokens_list):
- setattr(self, "additional_special_tokens", tokens_list)
- for value in tokens_list:
- self.add_token(value)
-
- @property
- def vocab_size(self):
- return self.tokenizer.vocab_size()
-
- @property
- def vocab(self):
- return self.tokenizer.vocab
-
- @property
- def inv_vocab(self):
- return self.tokenizer.inv_vocab
-
- def tokenize(self, text):
- text_tokens = self.tokenizer.tokenize(text)
- return self.tokenizer.convert_tokens_to_ids(text_tokens)
-
- def decode(self, ids):
- tokens = self.tokenizer.convert_ids_to_tokens(ids)
- return self.tokenizer.convert_tokens_to_string(tokens)
-
- def decode_token_ids(self, token_ids):
- tokens = self.tokenizer.convert_ids_to_tokens(token_ids)
- exclude_list = ['[PAD]', '[CLS]']
- non_pads = [t for t in tokens if t not in exclude_list]
-
- result = ""
- for s in non_pads:
- if s.startswith("##"):
- result += s[2:]
- else:
- result += " " + s
-
- return result
-
- @property
- def cls(self):
- return self.cls_id
-
- @property
- def sep(self):
- return self.sep_id
-
- @property
- def pad(self):
- return self.pad_id
-
- @property
- def mask(self):
- return self.mask_id
-
- @property
- def bos_token(self):
- """ Beginning of sentence token id """
- return self._bos_token
-
- @property
- def eos_token(self):
- """ End of sentence token id """
- return self._eos_token
-
- @property
- def additional_special_tokens(self):
- """ All the additional special tokens you may want to use (list of strings)."""
- return self._additional_special_tokens
-
- @property
- def bos_token_id(self):
- """ Id of the beginning of sentence token in the vocabulary."""
- return self._bos_token_id
-
- @property
- def eos_token_id(self):
- """ Id of the end of sentence token in the vocabulary."""
- return self._eos_token_id
-
- @property
- def additional_special_tokens_ids(self):
- """ Ids of all the additional special tokens in the vocabulary (list of integers)."""
- return [self.vocab.get(token) for token in self._additional_special_tokens]
-
- @additional_special_tokens.setter
- def additional_special_tokens(self, value):
- self._additional_special_tokens = value
-
-
-class _GPT2BPETokenizer(AbstractTokenizer):
- """Original GPT2 BPE tokenizer."""
-
- def __init__(self, vocab_file, merge_file):
- name = 'GPT2 BPE'
- super().__init__(name)
-
- self.tokenizer = GPT2Tokenizer(vocab_file, merge_file, errors='replace',
- special_tokens=[], max_len=None)
- self.eod_id = self.tokenizer.encoder['<|endoftext|>']
-
- @property
- def vocab_size(self):
- return len(self.tokenizer.encoder)
-
- @property
- def vocab(self):
- return self.tokenizer.encoder
-
- @property
- def inv_vocab(self):
- return self.tokenizer.decoder
-
- def tokenize(self, text):
- return self.tokenizer.encode(text)
-
- def detokenize(self, token_ids):
- return self.tokenizer.decode(token_ids)
-
- @property
- def eod(self):
- return self.eod_id
-
-
-class _SentencePieceTokenizer(AbstractTokenizer):
- """SentencePieceTokenizer-Megatron wrapper"""
-
- def __init__(self, model_file, vocab_extra_ids=0):
- name = 'SentencePieceTokenizer'
- super().__init__(name)
-
- import sentencepiece
- self.tokenizer = sentencepiece.SentencePieceProcessor(model_file=model_file)
- self._initalize(vocab_extra_ids)
-
- def _populate_vocab(self):
- self._vocab = {}
- self._inv_vocab = {}
-
- for i in range(len(self.tokenizer)):
- t = self.tokenizer.id_to_piece(i)
- self._inv_vocab[i] = t
- self._vocab[t] = i
-
- def _initalize(self, vocab_extra_ids):
- self._populate_vocab()
- self._special_tokens = {}
- self._inv_special_tokens = {}
-
- self._t5_tokens = []
-
- def _add_special_token(t):
- if t not in self._vocab:
- next_id = len(self._vocab)
- self._vocab[t] = next_id
- self._inv_vocab[next_id] = t
- self._special_tokens[t] = self._vocab[t]
- self._inv_special_tokens[self._vocab[t]] = t
-
- _add_special_token('<CLS>')
- self._cls_id = self._vocab['<CLS>']
- _add_special_token('<SEP>')
- self._sep_id = self._vocab['<SEP>']
- _add_special_token('<EOD>')
- self._eod_id = self._vocab['<EOD>']
- _add_special_token('<MASK>')
- self._mask_id = self._vocab['<MASK>']
-
- pad_id = self.tokenizer.pad_id()
- try:
- pad_token = self.tokenizer.id_to_piece(pad_id)
- except IndexError:
- pad_token = '<PAD>'
- _add_special_token(pad_token)
- self._pad_id = self._vocab[pad_token]
-
- bos_id = self.tokenizer.bos_id()
- try:
- bos_token = self.tokenizer.id_to_piece(bos_id)
- except IndexError:
- bos_token = '<BOS>'
- _add_special_token(bos_token)
- self._bos_id = self._vocab[bos_token]
-
- eos_id = self.tokenizer.eos_id()
- try:
- eos_token = self.tokenizer.id_to_piece(eos_id)
- except IndexError:
- eos_token = '<EOS>'
- _add_special_token(eos_token)
- self._eos_id = self._vocab[eos_token]
-
- for i in range(vocab_extra_ids):
- t = "".format(i)
- _add_special_token(t)
- self._t5_tokens += [t]
-
- @property
- def vocab_size(self):
- return len(self._vocab)
-
- @property
- def vocab(self):
- return self._vocab
-
- @property
- def inv_vocab(self):
- return self._inv_vocab
-
- @property
- def decoder(self):
- return self._inv_vocab
-
- @property
- def encoder(self):
- return self._vocab
-
- # From:
- # https://github.com/NVIDIA/NeMo/blob/c8fa217e811d60d11d014827c7f3845ff6c99ae7/nemo/collections/common/tokenizers/sentencepiece_tokenizer.py#L89
- def tokenize(self, text):
- ids = []
- idx = 0
-
- while 1:
- indices = {}
- for token in self._special_tokens:
- try:
- indices[token] = text[idx:].index(token)
- except ValueError:
- continue
- if len(indices) == 0:
- break
-
- next_token = min(indices, key=indices.get)
- next_idx = idx + indices[next_token]
-
- ids.extend(self.tokenizer.encode_as_ids(text[idx:next_idx]))
- ids.append(self._special_tokens[next_token])
- idx = next_idx + len(next_token)
-
- ids.extend(self.tokenizer.encode_as_ids(text[idx:]))
- return ids
-
- # From:
- # https://github.com/NVIDIA/NeMo/blob/c8fa217e811d60d11d014827c7f3845ff6c99ae7/nemo/collections/common/tokenizers/sentencepiece_tokenizer.py#L125
- def detokenize(self, ids):
- text = ""
- last_i = 0
-
- for i, id in enumerate(ids):
- if id in self._inv_special_tokens:
- text += self.tokenizer.decode_ids(ids[last_i:i]) + " "
- text += self._inv_special_tokens[id] + " "
- last_i = i + 1
-
- text += self.tokenizer.decode_ids(ids[last_i:])
- return text
-
- @property
- def cls(self):
- return self._cls_id
-
- @property
- def sep(self):
- return self._sep_id
-
- @property
- def pad(self):
- return self._pad_id
-
- @property
- def bos_token_id(self):
- return self._bos_id
-
- @property
- def bos(self):
- return self._bos_id
-
- @property
- def eod(self):
- return self._eod_id
-
- @property
- def eos_token_id(self):
- return self._eos_id
-
- @property
- def eos(self):
- return self._eos_id
-
- @property
- def mask(self):
- return self._mask_id
-
- @property
- def additional_special_tokens_ids(self):
- return [self.vocab[k] for k in self._t5_tokens]
-
-class _GPTSentencePieceTokenizer(_SentencePieceTokenizer):
- """SentencePieceTokenizer-Megatron wrapper"""
-
- def __init__(self, model_file,):
- super().__init__(model_file, vocab_extra_ids=0)
-
- def _initalize(self, vocab_extra_ids):
- self._populate_vocab()
-
- self._pad_id = self.tokenizer.pad_id()
- self._bos_id = self.tokenizer.bos_id()
- self._eos_id = self.tokenizer.eos_id()
-
- def tokenize(self, text):
- return self.tokenizer.encode_as_ids(text)
-
- def detokenize(self, ids):
- return self.tokenizer.decode_ids(ids)
-
- @property
- def cls(self):
- return -1
-
- @property
- def sep(self):
- return -1
-
- @property
- def mask(self):
- return -1
-
- @property
- def eod(self):
- return self._eos_id
-
- @property
- def additional_special_tokens_ids(self):
- return None
-
-class _Llama2Tokenizer(_SentencePieceTokenizer):
- """SentencePieceTokenizer-Megatron wrapper"""
-
- def __init__(self, model_file,):
- super().__init__(model_file, vocab_extra_ids=0)
-
- def _initalize(self, vocab_extra_ids):
- self._populate_vocab()
-
- # BOS / EOS token IDs
- self.n_words: int = self.tokenizer.vocab_size()
- self.bos_id: int = self.tokenizer.bos_id()
- self.eos_id: int = self.tokenizer.eos_id()
- self.pad_id: int = self.tokenizer.pad_id()
- assert self.tokenizer.vocab_size() == self.tokenizer.get_piece_size()
-
- def tokenize(self, s: str, bos=True, eos=False):
- '''Default args for text completion, not chat/dialog.'''
- assert type(s) is str
- t = self.tokenizer.encode(s)
- if bos:
- t = [self.bos_id] + t
- if eos:
- t = t + [self.eos_id]
- return t
-
- def detokenize(self, ids):
- return self.tokenizer.decode_ids(ids)
-
- @property
- def cls(self):
- return -1
-
- @property
- def sep(self):
- return -1
-
- @property
- def mask(self):
- return -1
-
- @property
- def eod(self):
- return self.eos_id
-
- @property
- def additional_special_tokens_ids(self):
- return None
-
-class _NullTokenizer:
- def __init__(self, vocab_size):
- vocab_size = int(vocab_size)
- self._eos_id = vocab_size
- self.vocab_size = vocab_size+1
-
- def tokenize(self, text):
- return [int(x) for x in text.split(' ')]
-
- def detokenize(self, ids):
- text = [str(x) for x in ids]
- return ' '.join(text)
-
- @property
- def cls(self):
- return -1
-
- @property
- def sep(self):
- return -1
-
- @property
- def mask(self):
- return -1
-
- @property
- def eod(self):
- return self._eos_id
-
- @property
- def additional_special_tokens_ids(self):
- return None
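Besides selecting the tokenizer class, the removed `tokenizer.py` also pads the vocabulary (`_vocab_size_with_padding`) so that the embedding table is divisible by `make_vocab_size_divisible_by * tensor_model_parallel_size`. A minimal sketch of that rounding rule, with hypothetical example sizes, follows.

```python
# Minimal sketch of the vocab-padding rule in the removed _vocab_size_with_padding();
# the divisor and example vocab size below are hypothetical values for illustration.
def padded_vocab_size(orig_vocab_size, divisible_by=128, tp_size=8):
    multiple = divisible_by * tp_size
    after = orig_vocab_size
    while after % multiple != 0:    # round up to the next multiple
        after += 1
    return after

print(padded_vocab_size(50257))     # 51200 (padded with 943 dummy tokens)
```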
diff --git a/megatron/training.py b/megatron/training.py
deleted file mode 100644
index 823402b6ef6beabe0dab507da9349bfafda20451..0000000000000000000000000000000000000000
--- a/megatron/training.py
+++ /dev/null
@@ -1,1167 +0,0 @@
-# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
-
-"""Pretrain utilities."""
-
-import gc
-from datetime import datetime
-import math
-import logging
-import sys
-from .log_handler import CustomHandler
-# Make default logging level INFO, but filter out all log messages not from MCore.
-logging.basicConfig(handlers=[CustomHandler()], level=logging.INFO)
-from .theoretical_memory_usage import report_theoretical_memory
-import time
-# The earliest we can measure the start time.
-_TRAIN_START_TIME = time.time()
-import torch
-
-from megatron import get_args
-from megatron import get_signal_handler
-from megatron import get_timers
-from megatron import get_tensorboard_writer
-from megatron import get_wandb_writer
-from megatron import get_current_global_batch_size
-from megatron import get_num_microbatches
-from megatron import is_last_rank
-from megatron import update_num_microbatches
-from megatron.core import mpu, tensor_parallel
-from megatron.core.utils import get_model_config
-from megatron import print_rank_0
-from megatron import print_rank_last
-from megatron.checkpointing import load_checkpoint
-from megatron.checkpointing import save_checkpoint
-from megatron.model import Float16Module
-from megatron.model import GPTModel
-from megatron.core.distributed import DistributedDataParallel as DDP
-from megatron.core.distributed import finalize_model_grads
-from megatron.core.enums import ModelType
-from megatron.optimizer import get_megatron_optimizer
-from megatron.initialize import initialize_megatron
-from megatron.initialize import write_args_to_tensorboard
-from megatron.initialize import set_jit_fusion_options
-from megatron.optimizer_param_scheduler import OptimizerParamScheduler
-from megatron.utils import check_adlr_autoresume_termination
-from megatron.utils import unwrap_model
-from megatron.data.data_samplers import build_pretraining_data_loader
-from megatron.utils import calc_params_l2_norm
-from megatron.core.pipeline_parallel import get_forward_backward_func
-from megatron.utils import report_memory
-from megatron.model.vision.knn_monitor import compute_feature_bank
-
-
-def print_datetime(string):
- """Note that this call will sync across all ranks."""
- torch.distributed.barrier()
- time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
- print_rank_0('[' + string + '] datetime: {} '.format(time_str))
-
-
-def num_floating_point_operations(args, batch_size):
- if not args.group_query_attention:
- args.num_query_groups = args.num_attention_heads
- return (
- 60
- * batch_size
- * args.seq_length
- * args.num_layers
- * args.hidden_size
- * args.hidden_size
- * (
- 1
- + (args.num_query_groups / (5 * args.num_attention_heads))
- + (args.seq_length / (5 * args.hidden_size))
- + (args.padded_vocab_size / (10 * args.num_layers * args.hidden_size))
- )
- )
-
-
-def pretrain(train_valid_test_dataset_provider,
- model_provider,
- model_type,
- forward_step_func,
- process_non_loss_data_func=None,
- extra_args_provider=None,
- args_defaults={}):
- """Main training program.
-
- This function will run the following in the order provided:
- 1) initialize Megatron.
- 2) setup model, optimizer and lr schedule using the model_provider.
- 3) call train_val_test_data_provider to get train/val/test datasets.
- 4) train the model using the forward_step_func.
-
- Arguments:
- train_valid_test_dataset_provider: a function that takes the size of
- train/valid/test dataset and returns `train, valid, test` datasets.
- model_provider: a function that returns a vanilla version of the
- model. By vanilla we mean a simple model on cpu with no fp16 or ddp.
- model_type: an enum that specifies the type of model being trained.
- forward_step_func: a function that takes a `data iterator` and `model`,
- and returns a `loss` scalar with a dictionary with key:values being
- the info we would like to monitor during training, for example
- `lm-loss: value`. We also require that this function add
- `batch generator` to the timers class.
- process_non_loss_data_func: a function to post process outputs of the
- network. It can be used for dumping output tensors (e.g images) to
- tensorboard. It takes `collected data`(list of tensors),
- `current iteration index` and `tensorboard writer` as arguments.
- extra_args_provider: a function that takes a parser and adds arguments
- to it. It is used for programs to add their own arguments.
- args_defaults: a dictionary from argument-name to argument-value. It
- is used to set already-parsed arguments.
- """
-
- # Initialize and get arguments, timers, and Tensorboard writer.
- initialize_megatron(extra_args_provider=extra_args_provider,
- args_defaults=args_defaults)
- # Set pytorch JIT layer fusion options and warmup JIT functions.
- set_jit_fusion_options()
-
- # Adjust the startup time so it reflects the largest value.
- # This will be closer to what scheduler will see (outside of
- # image ... launches.
- global _TRAIN_START_TIME
- start_time_tensor = torch.cuda.DoubleTensor([_TRAIN_START_TIME])
- torch.distributed.all_reduce(start_time_tensor,
- op=torch.distributed.ReduceOp.MIN)
- _TRAIN_START_TIME = start_time_tensor.item()
- print_rank_0('time to initialize megatron (seconds): {:.3f}'.format(
- time.time() - _TRAIN_START_TIME))
- print_datetime('after megatron is initialized')
-
- args = get_args()
- timers = get_timers()
-
- # Model, optimizer, and learning rate.
- timers('model-and-optimizer-setup', log_level=0).start(barrier=True)
- model, optimizer, opt_param_scheduler = setup_model_and_optimizer(
- model_provider, model_type)
- timers('model-and-optimizer-setup').stop()
- print_datetime('after model, optimizer, and learning rate '
- 'scheduler are built')
- config = get_model_config(model[0])
-
- # Data stuff.
- timers('train/valid/test-data-iterators-setup', log_level=0).start(
- barrier=True)
- if args.virtual_pipeline_model_parallel_size is not None:
- train_data_iterator = []
- valid_data_iterator = []
- test_data_iterator = []
- for i in range(len(model)):
- mpu.set_virtual_pipeline_model_parallel_rank(i)
- iterators = build_train_valid_test_data_iterators(
- train_valid_test_dataset_provider)
- train_data_iterator.append(iterators[0])
- valid_data_iterator.append(iterators[1])
- test_data_iterator.append(iterators[2])
- else:
- train_data_iterator, valid_data_iterator, test_data_iterator \
- = build_train_valid_test_data_iterators(
- train_valid_test_dataset_provider)
- timers('train/valid/test-data-iterators-setup').stop()
- print_datetime('after dataloaders are built')
-
- # Print setup timing.
- print_rank_0('done with setup ...')
- timers.log(['model-and-optimizer-setup',
- 'train/valid/test-data-iterators-setup'], barrier=True)
-
- if not args.skip_train:
- print_rank_0('training ...')
-
- if args.dataloader_type == 'cyclic' and args.retro_add_retriever:
- args.train_iters = args.retro_cyclic_train_iters
- print_rank_0("retro cyclic train iters : %d" % args.train_iters)
-
- iteration = 0
- if args.do_train and args.train_iters > 0:
- iteration = train(forward_step_func,
- model, optimizer, opt_param_scheduler,
- train_data_iterator, valid_data_iterator,
- process_non_loss_data_func, config)
-
- print_datetime('after training is done')
-
- if args.save and iteration != 0:
- save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
- else:
- print_rank_0('skipping training (--skip-train is on) ...')
-
- iteration = args.iteration
-
- if args.do_valid:
- prefix = f'iteration {iteration} on validation set'
- evaluate_and_print_results(prefix, forward_step_func,
- valid_data_iterator, model,
- iteration, process_non_loss_data_func, config,
- verbose=True, write_to_tensorboard=not args.skip_train)
-
- if args.do_test:
- prefix = f'iteration {iteration} on test set'
- evaluate_and_print_results(prefix, forward_step_func,
- test_data_iterator, model,
- iteration, process_non_loss_data_func, config,
- verbose=True, write_to_tensorboard=not args.skip_train)
-
-
-def update_train_iters(args):
-
- # For iteration-based training, we don't need to do anything
- if args.train_iters:
- return
-
- # Constant batch size with sample-based training.
- if args.rampup_batch_size is None:
- args.train_iters = args.train_samples // args.global_batch_size
-
- else:
- # Sample based training with rampup batch size.
- iterations = 0
- consumed_samples = 0
- # Rampup phase.
- while consumed_samples <= int(args.rampup_batch_size[2]):
- update_num_microbatches(consumed_samples, consistency_check=False)
- consumed_samples += get_current_global_batch_size()
- iterations += 1
- # Reset
- update_num_microbatches(0, consistency_check=False)
- # Constant phase
- # Note that we throw away any partial last batch.
- iterations += (args.train_samples - consumed_samples) // \
- args.global_batch_size
- args.train_iters = iterations
-
- print_rank_0('setting training iterations to {}'.format(args.train_iters))
-
-
-def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap_with_ddp=True):
- """Build the model."""
- args = get_args()
- args.model_type = model_type
-
- # Build model.
- if mpu.get_pipeline_model_parallel_world_size() > 1 and \
- args.virtual_pipeline_model_parallel_size is not None:
- assert model_type != ModelType.encoder_and_decoder, \
- "Interleaved schedule not supported for model with both encoder and decoder"
- model = []
- for i in range(args.virtual_pipeline_model_parallel_size):
- mpu.set_virtual_pipeline_model_parallel_rank(i)
- # Set pre_process and post_process only after virtual rank is set.
- pre_process = mpu.is_pipeline_first_stage()
- post_process = mpu.is_pipeline_last_stage()
- this_model = model_provider_func(
- pre_process=pre_process,
- post_process=post_process
- )
- this_model.model_type = model_type
- model.append(this_model)
- else:
- pre_process = mpu.is_pipeline_first_stage()
- post_process = mpu.is_pipeline_last_stage()
- add_encoder = True
- add_decoder = True
- if model_type == ModelType.encoder_and_decoder:
- if mpu.get_pipeline_model_parallel_world_size() > 1:
- assert args.pipeline_model_parallel_split_rank is not None, \
- "Split rank needs to be specified for model with both encoder and decoder"
- rank = mpu.get_pipeline_model_parallel_rank()
- split_rank = args.pipeline_model_parallel_split_rank
- world_size = mpu.get_pipeline_model_parallel_world_size()
- pre_process = rank == 0 or rank == split_rank
- post_process = (rank == (split_rank - 1)) or (
- rank == (world_size - 1))
- add_encoder = mpu.is_pipeline_stage_before_split()
- add_decoder = mpu.is_pipeline_stage_after_split()
- model = model_provider_func(
- pre_process=pre_process,
- post_process=post_process,
- add_encoder=add_encoder,
- add_decoder=add_decoder)
- else:
- model = model_provider_func(
- pre_process=pre_process,
- post_process=post_process
- )
- model.model_type = model_type
-
- if not isinstance(model, list):
- model = [model]
-
- # Disallow training and inference with Transformer Engine
- # for non-GPT models
- args.allow_transformer_engine = all([type(m) == GPTModel for m in model])
- # assert args.allow_transformer_engine or args.transformer_impl == 'local', \
- # 'Transformer Engine is only approved for GPT models'
-
- # Set tensor model parallel attributes if not set.
- # Only parameters that are already tensor model parallel have these
- # attributes set for them. We should make sure the default attributes
- # are set for all params so the optimizer can use them.
- for model_module in model:
- for param in model_module.parameters():
- tensor_parallel.set_defaults_if_not_set_tensor_model_parallel_attributes(param)
-
- # Print number of parameters.
- if mpu.get_data_parallel_rank() == 0:
- print(' > number of parameters on (tensor, pipeline) '
- 'model parallel rank ({}, {}): {}'.format(
- mpu.get_tensor_model_parallel_rank(),
- mpu.get_pipeline_model_parallel_rank(),
- sum([sum([p.nelement() for p in model_module.parameters()])
- for model_module in model])), flush=True)
-
- # GPU allocation.
- for model_module in model:
- model_module.cuda(torch.cuda.current_device())
-
- # Fp16 conversion.
- if args.fp16 or args.bf16:
- model = [Float16Module(model_module, args) for model_module in model]
-
- if wrap_with_ddp:
- config = get_model_config(model[0])
- model = [DDP(config,
- model_chunk,
- data_parallel_group=mpu.get_data_parallel_group(with_context_parallel=True),
- accumulate_allreduce_grads_in_fp32=args.accumulate_allreduce_grads_in_fp32,
- overlap_grad_reduce=args.overlap_grad_reduce,
- use_distributed_optimizer=args.use_distributed_optimizer,
- # Turn off bucketing for model_chunk 2 onwards, since communication for these
- # model chunks is overlapped with compute anyway.
- disable_bucketing=(model_chunk_idx > 0))
- for (model_chunk_idx, model_chunk) in enumerate(model)]
-
- # Broadcast params from data parallel src rank to other data parallel ranks.
- if args.data_parallel_random_init:
- for model_module in model:
- model_module.broadcast_params()
-
- return model
-
-
-def get_optimizer_param_scheduler(optimizer):
- """Build the learning rate scheduler."""
- args = get_args()
-
- # Iteration-based training.
- if args.train_iters:
- if args.lr_decay_iters is None:
- args.lr_decay_iters = args.train_iters
- lr_decay_steps = args.lr_decay_iters * args.global_batch_size
- wd_incr_steps = args.train_iters * args.global_batch_size
- if args.lr_warmup_fraction is not None:
- lr_warmup_steps = args.lr_warmup_fraction * lr_decay_steps
- else:
- lr_warmup_steps = args.lr_warmup_iters * args.global_batch_size
- # Sample-based training.
- elif args.train_samples:
- # We need to set training iters for later use. Technically
- # we need to adjust the training samples too (due to last
- # batch being incomplete) but we leave it as is for now.
- update_train_iters(args)
- if args.lr_decay_samples is None:
- args.lr_decay_samples = args.train_samples
- lr_decay_steps = args.lr_decay_samples
- wd_incr_steps = args.train_samples
- if args.lr_warmup_fraction is not None:
- lr_warmup_steps = args.lr_warmup_fraction * lr_decay_steps
- else:
- lr_warmup_steps = args.lr_warmup_samples
- else:
- raise Exception(
- 'either train-iters or train-samples should be provided.')
-
- opt_param_scheduler = OptimizerParamScheduler(
- optimizer,
- init_lr=args.lr_warmup_init,
- max_lr=args.lr,
- min_lr=args.min_lr,
- lr_warmup_steps=lr_warmup_steps,
- lr_decay_steps=lr_decay_steps,
- lr_decay_style=args.lr_decay_style,
- start_wd=args.start_weight_decay,
- end_wd=args.end_weight_decay,
- wd_incr_steps=wd_incr_steps,
- wd_incr_style=args.weight_decay_incr_style,
- use_checkpoint_opt_param_scheduler=args.use_checkpoint_opt_param_scheduler,
- override_opt_param_scheduler=args.override_opt_param_scheduler)
-
- return opt_param_scheduler
-
-
-def setup_model_and_optimizer(model_provider_func,
- model_type,
- no_wd_decay_cond=None,
- scale_lr_cond=None,
- lr_mult=1.0):
- """Setup model and optimizer."""
- args = get_args()
-
- model = get_model(model_provider_func, model_type)
- unwrapped_model = unwrap_model(model)
-
- optimizer = get_megatron_optimizer(model, no_wd_decay_cond,
- scale_lr_cond, lr_mult)
- opt_param_scheduler = get_optimizer_param_scheduler(optimizer)
-
- if args.load is not None:
- timers = get_timers()
- timers('load-checkpoint', log_level=0).start(barrier=True)
- args.iteration = load_checkpoint(model, optimizer, opt_param_scheduler)
- timers('load-checkpoint').stop(barrier=True)
- timers.log(['load-checkpoint'])
- else:
- args.iteration = 0
-
- # get model without FP16 and/or DDP wrappers
- if args.iteration == 0 and len(unwrapped_model) == 1 \
- and hasattr(unwrapped_model[0], 'init_state_dict_from_bert'):
- print_rank_0("Initializing ICT from pretrained BERT model")
- unwrapped_model[0].init_state_dict_from_bert()
- if args.fp16:
- optimizer.reload_model_params()
-
- return model, optimizer, opt_param_scheduler
-
-
-
-def train_step(forward_step_func, data_iterator,
- model, optimizer, opt_param_scheduler, config):
- """Single training step."""
- args = get_args()
- timers = get_timers()
-
- # Set grad to zero.
- for model_chunk in model:
- # If using distributed optimizer, don't zero buffer here; zeroing of buffer is
- # handled automatically by the optimizer after all-gathers finish.
- # Otherwise, zero the buffer.
- model_chunk.zero_grad_buffer(zero_buffer=(not args.use_distributed_optimizer))
- optimizer.zero_grad()
-
- # Forward pass.
- forward_backward_func = get_forward_backward_func()
- losses_reduced = forward_backward_func(
- forward_step_func=forward_step_func,
- data_iterator=data_iterator,
- model=model,
- num_microbatches=get_num_microbatches(),
- seq_length=args.seq_length,
- micro_batch_size=args.micro_batch_size,
- decoder_seq_length=args.decoder_seq_length,
- forward_only=False)
-
- # Empty unused memory.
- if args.empty_unused_memory_level >= 1:
- torch.cuda.empty_cache()
-
- # Vision gradients.
- if args.vision_pretraining and args.vision_pretraining_type == "dino":
- unwrapped_model = unwrap_model(model[0])
- unwrapped_model.cancel_gradients_last_layer(args.curr_iteration)
-
- # Update parameters.
- timers('optimizer', log_level=1).start(barrier=args.barrier_with_L1_time)
- update_successful, grad_norm, num_zeros_in_grad = optimizer.step(args, timers)
- timers('optimizer').stop()
-
- # Vision momentum.
- if args.vision_pretraining and args.vision_pretraining_type == "dino":
- unwrapped_model = unwrap_model(model[0])
- unwrapped_model.update_momentum(args.curr_iteration)
-
- # Update learning rate.
- if update_successful:
- increment = get_num_microbatches() * \
- args.micro_batch_size * \
- args.data_parallel_size
- opt_param_scheduler.step(increment=increment)
- skipped_iter = 0
- else:
- skipped_iter = 1
-
- # Empty unused memory.
- if args.empty_unused_memory_level >= 2:
- torch.cuda.empty_cache()
-
- if mpu.is_pipeline_last_stage(ignore_virtual=True):
- # Average loss across microbatches.
- loss_reduced = {}
- for key in losses_reduced[0]:
- losses_reduced_for_key = [x[key] for x in losses_reduced]
- loss_reduced[key] = sum(losses_reduced_for_key) / len(losses_reduced_for_key)
- return loss_reduced, skipped_iter, grad_norm, num_zeros_in_grad
- return {}, skipped_iter, grad_norm, num_zeros_in_grad
-
-
-def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
- loss_scale, report_memory_flag, skipped_iter,
- grad_norm, params_norm, num_zeros_in_grad):
- """Log training information such as losses, timing, ...."""
- args = get_args()
- timers = get_timers()
- writer = get_tensorboard_writer()
- wandb_writer = get_wandb_writer()
-
- # Advanced, skipped, and Nan iterations.
- advanced_iters_key = 'advanced iterations'
- skipped_iters_key = 'skipped iterations'
- nan_iters_key = 'nan iterations'
- # Advanced iterations.
- if not skipped_iter:
- total_loss_dict[advanced_iters_key] = total_loss_dict.get(
- advanced_iters_key, 0) + 1
- else:
- if advanced_iters_key not in total_loss_dict:
- total_loss_dict[advanced_iters_key] = 0
- # Skipped iterations.
- total_loss_dict[skipped_iters_key] = total_loss_dict.get(
- skipped_iters_key, 0) + skipped_iter
- # Update losses and set nan iterations
- got_nan = False
- for key in loss_dict:
- if not skipped_iter:
- total_loss_dict[key] = total_loss_dict.get(
- key, torch.cuda.FloatTensor([0.0])) + loss_dict[key]
- else:
- value = loss_dict[key].float().sum().item()
- is_nan = value == float('inf') or \
- value == -float('inf') or \
- value != value
- got_nan = got_nan or is_nan
- total_loss_dict[nan_iters_key] = total_loss_dict.get(
- nan_iters_key, 0) + int(got_nan)
-
- # Logging.
- timers_to_log = [
- 'forward-backward',
- 'forward-compute',
- 'backward-compute',
- 'batch-generator',
- 'forward-recv',
- 'forward-send',
- 'backward-recv',
- 'backward-send',
- 'forward-send-forward-recv',
- 'forward-send-backward-recv',
- 'backward-send-forward-recv',
- 'backward-send-backward-recv',
- 'forward-backward-send-forward-backward-recv',
- 'layernorm-grads-all-reduce',
- 'embedding-grads-all-reduce',
- 'all-grads-sync',
- 'params-all-gather',
- 'optimizer-copy-to-main-grad',
- 'optimizer-unscale-and-check-inf',
- 'optimizer-clip-main-grad',
- 'optimizer-count-zeros',
- 'optimizer-inner-step',
- 'optimizer-copy-main-to-model-params',
- 'optimizer']
-
- # Calculate batch size.
- batch_size = args.micro_batch_size * args.data_parallel_size * \
- get_num_microbatches()
-
- total_iterations = total_loss_dict[advanced_iters_key] + \
- total_loss_dict[skipped_iters_key]
-
- # Tensorboard values.
- # Timer requires all the ranks to call.
- if args.log_timers_to_tensorboard and \
- (iteration % args.tensorboard_log_interval == 0):
- timers.write(timers_to_log, writer, iteration,
- normalizer=total_iterations)
- if writer and (iteration % args.tensorboard_log_interval == 0):
- if wandb_writer:
- wandb_writer.log({'samples vs steps': args.consumed_train_samples},
- iteration)
- if args.log_learning_rate_to_tensorboard:
- writer.add_scalar('learning-rate', learning_rate, iteration)
- writer.add_scalar('learning-rate vs samples', learning_rate,
- args.consumed_train_samples)
- if wandb_writer:
- wandb_writer.log({'learning-rate': learning_rate}, iteration)
- if args.log_batch_size_to_tensorboard:
- writer.add_scalar('batch-size', batch_size, iteration)
- writer.add_scalar('batch-size vs samples', batch_size,
- args.consumed_train_samples)
- if wandb_writer:
- wandb_writer.log({'batch-size': batch_size}, iteration)
- for key in loss_dict:
- writer.add_scalar(key , loss_dict[key], iteration)
- writer.add_scalar(key + ' vs samples', loss_dict[key],
- args.consumed_train_samples)
- if wandb_writer:
- wandb_writer.log({key: loss_dict[key]}, iteration)
- if args.log_loss_scale_to_tensorboard:
- writer.add_scalar('loss-scale', loss_scale, iteration)
- writer.add_scalar('loss-scale vs samples', loss_scale,
- args.consumed_train_samples)
- if wandb_writer:
- wandb_writer.log({'loss-scale': loss_scale}, iteration)
- if args.log_world_size_to_tensorboard:
- writer.add_scalar('world-size', args.world_size, iteration)
- writer.add_scalar('world-size vs samples', args.world_size,
- args.consumed_train_samples)
- if wandb_writer:
- wandb_writer.log({'world-size': args.world_size}, iteration)
- if grad_norm is not None:
- writer.add_scalar('grad-norm', grad_norm, iteration)
- writer.add_scalar('grad-norm vs samples', grad_norm,
- args.consumed_train_samples)
- if wandb_writer:
- wandb_writer.log({'grad-norm': grad_norm}, iteration)
- if num_zeros_in_grad is not None:
- writer.add_scalar('num-zeros', num_zeros_in_grad, iteration)
- writer.add_scalar('num-zeros vs samples', num_zeros_in_grad,
- args.consumed_train_samples)
- if wandb_writer:
- wandb_writer.log({'num-zeros': num_zeros_in_grad}, iteration)
- if params_norm is not None:
- writer.add_scalar('params-norm', params_norm, iteration)
- writer.add_scalar('params-norm vs samples', params_norm,
- args.consumed_train_samples)
- if wandb_writer:
- wandb_writer.log({'params-norm': params_norm}, iteration)
- if args.log_memory_to_tensorboard:
- mem_stats = torch.cuda.memory_stats()
- writer.add_scalar(
- "mem-reserved-bytes",
- mem_stats["reserved_bytes.all.current"],
- iteration,
- )
- writer.add_scalar(
- "mem-allocated-bytes",
- mem_stats["allocated_bytes.all.current"],
- iteration,
- )
- writer.add_scalar(
- "mem-allocated-count",
- mem_stats["allocation.all.current"],
- iteration,
- )
-
- if iteration % args.log_interval == 0:
- elapsed_time = timers('interval-time').elapsed(barrier=True)
- elapsed_time_per_iteration = elapsed_time / total_iterations
- throughput = num_floating_point_operations(args, batch_size) / (
- elapsed_time_per_iteration * 10**12 * args.world_size)
- if args.log_timers_to_tensorboard:
- if writer:
- writer.add_scalar('iteration-time',
- elapsed_time_per_iteration, iteration)
- if wandb_writer:
- wandb_writer.log({'iteration-time': elapsed_time_per_iteration},
- iteration)
- log_string = ' iteration {:8d}/{:8d} |'.format(
- iteration, args.train_iters)
- log_string += ' consumed samples: {:12d} |'.format(
- args.consumed_train_samples)
- log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(
- elapsed_time_per_iteration * 1000.0)
- if args.log_throughput:
- log_string += f' throughput per GPU (TFLOP/s/GPU): {throughput:.1f} |'
- if args.log_timers_to_tensorboard:
- if writer:
- writer.add_scalar('throughput', throughput, iteration)
- if wandb_writer:
- wandb_writer.log({'throughput': throughput}, iteration)
- log_string += ' learning rate: {:.3E} |'.format(learning_rate)
- log_string += ' global batch size: {:5d} |'.format(batch_size)
- for key in total_loss_dict:
- if key not in [advanced_iters_key, skipped_iters_key,
- nan_iters_key]:
- avg = total_loss_dict[key].item() / \
- float(max(1, total_loss_dict[advanced_iters_key]))
- if avg > 0.0:
- log_string += ' {}: {:.6E} |'.format(key, avg)
- total_loss_dict[key] = torch.cuda.FloatTensor([0.0])
- log_string += ' loss scale: {:.1f} |'.format(loss_scale)
- if grad_norm is not None:
- log_string += ' grad norm: {:.3f} |'.format(grad_norm)
- if num_zeros_in_grad is not None:
- log_string += ' num zeros: {:.1f} |'.format(num_zeros_in_grad)
- if params_norm is not None:
- log_string += ' params norm: {:.3f} |'.format(params_norm)
- log_string += ' number of skipped iterations: {:3d} |'.format(
- total_loss_dict[skipped_iters_key])
- log_string += ' number of nan iterations: {:3d} |'.format(
- total_loss_dict[nan_iters_key])
- total_loss_dict[advanced_iters_key] = 0
- total_loss_dict[skipped_iters_key] = 0
- total_loss_dict[nan_iters_key] = 0
- print_rank_last(log_string)
- if report_memory_flag and learning_rate > 0.:
- # Report memory after optimizer state has been initialized.
- if torch.distributed.get_rank() == 0:
- num_microbatches = get_num_microbatches()
- report_theoretical_memory(args, num_microbatches=num_microbatches, verbose=True)
- report_memory('(after {} iterations)'.format(iteration))
- report_memory_flag = False
- timers.log(timers_to_log, normalizer=args.log_interval)
-
- return report_memory_flag
-
-
-def save_checkpoint_and_time(iteration, model, optimizer, opt_param_scheduler):
- timers = get_timers()
- # Extra barrier is added to make sure
- # all ranks report the max time.
- timers('save-checkpoint', log_level=0).start(barrier=True)
- save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
- timers('save-checkpoint').stop(barrier=True)
- timers.log(['save-checkpoint'])
-
-
-def train(forward_step_func, model, optimizer, opt_param_scheduler,
- train_data_iterator, valid_data_iterator,
- process_non_loss_data_func, config):
- """Train the model function."""
- args = get_args()
- timers = get_timers()
-
- # Write args to tensorboard
- write_args_to_tensorboard()
-
- # Turn on training mode which enables dropout.
- for model_module in model:
- model_module.train()
-
- # Tracking loss.
- total_loss_dict = {}
-
- # Iterations.
- iteration = args.iteration
-
- # Setup some training config params
- config.grad_scale_func = optimizer.scale_loss
- config.timers = timers
- if isinstance(model[0], DDP) and args.overlap_grad_reduce:
- assert config.no_sync_func is None, \
- ('When overlap_grad_reduce is True, config.no_sync_func must be None; '
- 'a custom no_sync_func is not supported when overlapping grad-reduce')
- config.no_sync_func = [model_chunk.no_sync for model_chunk in model]
- if len(model) == 1:
- config.no_sync_func = config.no_sync_func[0]
- if args.delay_grad_reduce:
- config.grad_sync_func = [model_chunk.start_grad_sync for model_chunk in model]
- if len(model) == 1:
- config.grad_sync_func = config.grad_sync_func[0]
- if args.overlap_param_gather and args.delay_param_gather:
- config.param_sync_func = [lambda x: optimizer.finish_param_sync(model_index, x)
- for model_index in range(len(model))]
- if len(model) == 1:
- config.param_sync_func = config.param_sync_func[0]
- config.finalize_model_grads_func = finalize_model_grads
-
- timers('interval-time', log_level=0).start(barrier=True)
- print_datetime('before the start of the training step')
- report_memory_flag = True
- exit = False
-
- if args.manual_gc:
- # Disable the default garbage collector and perform the collection manually.
- # This is to align the timing of garbage collection across ranks.
- assert args.manual_gc_interval >= 0, \
- 'Manual garbage collection interval should be larger than or equal to 0.'
- gc.disable()
- gc.collect()
-
- while iteration < args.train_iters:
- if args.profile and \
- iteration == args.profile_step_start and \
- torch.distributed.get_rank() in args.profile_ranks:
- torch.cuda.cudart().cudaProfilerStart()
- torch.autograd.profiler.emit_nvtx(record_shapes=True).__enter__()
-
- update_num_microbatches(args.consumed_train_samples)
- args.curr_iteration = iteration
- loss_dict, skipped_iter, grad_norm, num_zeros_in_grad = \
- train_step(forward_step_func,
- train_data_iterator,
- model,
- optimizer,
- opt_param_scheduler,
- config)
- iteration += 1
- args.consumed_train_samples += mpu.get_data_parallel_world_size() * \
- args.micro_batch_size * \
- get_num_microbatches()
-
- # Logging.
- loss_scale = optimizer.get_loss_scale().item()
- params_norm = None
- if args.log_params_norm:
- params_norm = calc_params_l2_norm(model)
- report_memory_flag = training_log(loss_dict, total_loss_dict,
- optimizer.param_groups[0]['lr'],
- iteration, loss_scale,
- report_memory_flag, skipped_iter,
- grad_norm, params_norm, num_zeros_in_grad)
-
- # Autoresume
- if args.adlr_autoresume and \
- (iteration % args.adlr_autoresume_interval == 0):
- check_adlr_autoresume_termination(iteration, model, optimizer,
- opt_param_scheduler)
-
- # Evaluation
- if args.eval_interval and iteration % args.eval_interval == 0 and \
- args.do_valid:
- timers('interval-time').stop()
- if args.manual_gc and args.manual_gc_eval:
- # Collect all objects.
- gc.collect()
- prefix = 'iteration {}'.format(iteration)
- evaluate_and_print_results(prefix, forward_step_func,
- valid_data_iterator, model,
- iteration, process_non_loss_data_func,
- config, False)
- if args.manual_gc and args.manual_gc_eval:
- # Collect only the objects created and used in evaluation.
- gc.collect(generation=0)
- timers('interval-time', log_level=0).start(barrier=True)
-
- # Checkpointing
- saved_checkpoint = False
- if args.exit_signal_handler:
- signal_handler = get_signal_handler()
- if any(signal_handler.signals_received()):
- save_checkpoint_and_time(iteration, model, optimizer,
- opt_param_scheduler)
- print_datetime('exiting program after receiving SIGTERM.')
- exit = True
- break
-
- if args.save and args.save_interval and \
- iteration % args.save_interval == 0:
- timers('interval-time').stop()
- save_checkpoint_and_time(iteration, model, optimizer,
- opt_param_scheduler)
- saved_checkpoint = True
- timers('interval-time', log_level=0).start(barrier=True)
-
- # Exiting based on duration
- if args.exit_duration_in_mins:
- train_time = (time.time() - _TRAIN_START_TIME) / 60.0
- done_cuda = torch.cuda.IntTensor(
- [train_time > args.exit_duration_in_mins])
- torch.distributed.all_reduce(
- done_cuda, op=torch.distributed.ReduceOp.MAX)
- done = done_cuda.item()
- if done:
- if not saved_checkpoint:
- save_checkpoint_and_time(iteration, model, optimizer,
- opt_param_scheduler)
- print_datetime('exiting program after {} minutes'.format(train_time))
- exit = True
- break
-
- # Exiting based on iterations
- if args.exit_interval and iteration % args.exit_interval == 0:
- if args.save and not saved_checkpoint:
- save_checkpoint_and_time(iteration, model, optimizer,
- opt_param_scheduler)
- torch.distributed.barrier()
- print_datetime('exiting program at iteration {}'.format(iteration))
- exit = True
- break
-
- if args.profile and \
- iteration == args.profile_step_end and \
- torch.distributed.get_rank() in args.profile_ranks:
- torch.cuda.cudart().cudaProfilerStop()
-
- if args.manual_gc:
- if args.manual_gc_interval != 0 and iteration % args.manual_gc_interval == 0:
- gc.collect()
-
- # Flush TensorBoard and WandB writers.
- writer = get_tensorboard_writer()
- if writer:
- writer.flush()
- wandb_writer = get_wandb_writer()
- if wandb_writer:
- wandb_writer.finish()
-
- # If any exit conditions (signal handler, duration, iterations) have been reached, exit.
- if exit:
- sys.exit()
-
- return iteration
-
-
-def evaluate(forward_step_func,
- data_iterator,
- model,
- process_non_loss_data_func,
- config,
- verbose=False):
- """Evaluation."""
- args = get_args()
- timers = get_timers()
-
- timers('evaluate', log_level=0).start(barrier=True)
-
- if args.vision_pretraining and args.vision_pretraining_type == "dino":
- compute_feature_bank(model)
-
- # Turn on evaluation mode which disables dropout.
- for model_module in model:
- model_module.eval()
-
- total_loss_dict = {}
-
- # Make the validation batch size independent of the training batch size.
- eval_batch_size = args.global_batch_size
- eval_num_microbatches = eval_batch_size // \
- (args.micro_batch_size * args.data_parallel_size)
-
- with torch.no_grad():
- iteration = 0
- if verbose:
- print_rank_0(f'Evaluating on {args.eval_iters * eval_batch_size} samples')
- while iteration < args.eval_iters:
- iteration += 1
- if verbose:
- print_rank_0(f'Evaluating iter {iteration}/{args.eval_iters}')
-
- forward_backward_func = get_forward_backward_func()
- # Don't care about timing during evaluation
- config.timers = None
- loss_dicts = forward_backward_func(
- forward_step_func=forward_step_func,
- data_iterator=data_iterator,
- model=model,
- num_microbatches=eval_num_microbatches,
- seq_length=args.seq_length,
- micro_batch_size=args.micro_batch_size,
- decoder_seq_length=args.decoder_seq_length,
- forward_only=True)
- config.timers = get_timers()
-
- # Empty unused memory
- if args.empty_unused_memory_level >= 1:
- torch.cuda.empty_cache()
-
- if mpu.is_pipeline_last_stage(ignore_virtual=True):
- # Reduce across processes.
- for loss_dict in loss_dicts:
- for key in loss_dict:
- total_loss_dict[key] = total_loss_dict.get(
- key, torch.cuda.FloatTensor([0.0])) + loss_dict[key]
-
- args.consumed_valid_samples += eval_batch_size
-
- if args.exit_duration_in_mins:
- train_time = (time.time() - _TRAIN_START_TIME) / 60.0
- done_cuda = torch.cuda.IntTensor(
- [train_time > args.exit_duration_in_mins])
- torch.distributed.all_reduce(
- done_cuda, op=torch.distributed.ReduceOp.MAX)
- done = done_cuda.item()
- if done:
- print_rank_0('Exiting during evaluation, timelimit reached')
- return None, None, True
-
- collected_non_loss_data = None
- if process_non_loss_data_func is not None and is_last_rank():
- collected_non_loss_data = forward_backward_func(
- forward_step_func=forward_step_func,
- data_iterator=data_iterator,
- model=model,
- num_microbatches=get_num_microbatches(),
- seq_length=args.seq_length,
- micro_batch_size=args.micro_batch_size,
- decoder_seq_length=args.decoder_seq_length,
- forward_only=True,
- collect_non_loss_data=True)
-
- # Move model back to the train mode.
- for model_module in model:
- model_module.train()
-
- for key in total_loss_dict:
- total_loss_dict[key] /= args.eval_iters * eval_num_microbatches
-
- timers('evaluate').stop()
- timers.log(['evaluate'])
-
- return total_loss_dict, collected_non_loss_data, False
-
-def evaluate_and_print_results(prefix, forward_step_func,
- data_iterator, model,
- iteration, process_non_loss_data_func, config,
- verbose=False, write_to_tensorboard=True):
- """Helper function to evaluate and dump results on screen."""
- args = get_args()
- if write_to_tensorboard:
- writer = get_tensorboard_writer()
- else:
- writer = None
-
- wandb_writer = get_wandb_writer()
-
- total_loss_dict, collected_non_loss_data, timelimit = evaluate(
- forward_step_func, data_iterator, model,
- process_non_loss_data_func, config, verbose)
- # Timelimit hit during evaluation
- if timelimit:
- return
- string = ' validation loss at {} | '.format(prefix)
- for key in total_loss_dict:
- string += '{} value: {:.6E} | '.format(key, total_loss_dict[key].item())
- ppl = math.exp(min(20, total_loss_dict[key].item()))
- string += '{} PPL: {:.6E} | '.format(key, ppl)
- if writer:
- writer.add_scalar('{} validation'.format(key),
- total_loss_dict[key].item(),
- iteration)
- writer.add_scalar('{} validation vs samples'.format(key),
- total_loss_dict[key].item(),
- args.consumed_train_samples)
- if args.log_validation_ppl_to_tensorboard:
- writer.add_scalar('{} validation ppl'.format(key), ppl,
- iteration)
- writer.add_scalar('{} validation ppl vs samples'.format(key),
- ppl, args.consumed_train_samples)
- if wandb_writer and is_last_rank():
- wandb_writer.log({
- '{} validation'.format(key): total_loss_dict[key].item()},
- iteration)
-
- if process_non_loss_data_func is not None and writer and is_last_rank():
- process_non_loss_data_func(collected_non_loss_data, iteration, writer)
-
- length = len(string) + 1
- print_rank_last('-' * length)
- print_rank_last(string)
- print_rank_last('-' * length)
-
-
-def cyclic_iter(iterable):
- while True:
- for x in iterable:
- yield x
-
-
-def build_train_valid_test_datasets(build_train_valid_test_datasets_provider):
- """Build pretraining datasets."""
-
- args = get_args()
-
- # Number of train/valid/test samples.
- if args.train_samples:
- train_samples = args.train_samples
- else:
- train_samples = args.train_iters * args.global_batch_size
- eval_iters = (args.train_iters // args.eval_interval + 1) * \
- args.eval_iters
- test_iters = args.eval_iters
- train_val_test_num_samples = [train_samples,
- eval_iters * args.global_batch_size,
- test_iters * args.global_batch_size]
- print_rank_0(' > datasets target sizes (minimum size):')
- print_rank_0(' train: {}'.format(train_val_test_num_samples[0]))
- print_rank_0(' validation: {}'.format(train_val_test_num_samples[1]))
- print_rank_0(' test: {}'.format(train_val_test_num_samples[2]))
-
- # Build the datasets.
- return build_train_valid_test_datasets_provider(train_val_test_num_samples)
-
-
-def build_train_valid_test_data_loaders(
- build_train_valid_test_datasets_provider):
- """Build pretraining data loaders."""
-
- args = get_args()
-
- (train_dataloader, valid_dataloader, test_dataloader) = (None, None, None)
-
- print_rank_0('> building train, validation, and test datasets ...')
-
- # Backward compatibility, assume fixed batch size.
- if args.iteration > 0 and args.consumed_train_samples == 0:
- assert args.train_samples is None, \
- 'only backward compatibility support for iteration-based training'
- args.consumed_train_samples = args.iteration * args.global_batch_size
- if args.iteration > 0 and args.consumed_valid_samples == 0:
- if args.train_samples is None:
- args.consumed_valid_samples = (args.iteration // args.eval_interval) * \
- args.eval_iters * args.global_batch_size
-
- # Rely on distributed-aware core datasets (temporary).
- is_distributed = getattr(build_train_valid_test_datasets_provider, "is_distributed", False)
-
- # Construct the data pipeline
- if is_distributed or mpu.get_tensor_model_parallel_rank() == 0:
-
- # Build datasets.
- train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
- build_train_valid_test_datasets_provider)
- # Build dataloaders.
- train_dataloader = build_pretraining_data_loader(
- train_ds, args.consumed_train_samples)
- if args.skip_train:
- valid_dataloader = build_pretraining_data_loader(valid_ds, 0)
- else:
- valid_dataloader = build_pretraining_data_loader(
- valid_ds, args.consumed_valid_samples)
- test_dataloader = build_pretraining_data_loader(test_ds, 0)
-
- # Flags to know if we need to do training/validation/testing.
- do_train = train_dataloader is not None and args.train_iters > 0
- do_valid = valid_dataloader is not None and args.eval_iters > 0
- do_test = test_dataloader is not None and args.eval_iters > 0
- flags = torch.cuda.LongTensor(
- [int(do_train), int(do_valid), int(do_test)])
- else:
- flags = torch.cuda.LongTensor([0, 0, 0])
-
- torch.distributed.broadcast(flags, 0)
-
- args.do_train = getattr(args, "do_train", False) or flags[0].item()
- args.do_valid = getattr(args, "do_valid", False) or flags[1].item()
- args.do_test = getattr(args, "do_test", False) or flags[2].item()
-
- return train_dataloader, valid_dataloader, test_dataloader
-
-
-def build_train_valid_test_data_iterators(
- build_train_valid_test_datasets_provider):
- """Build pretraining data iterators."""
-
- args = get_args()
-
- # Build loaders.
- train_dataloader, valid_dataloader, test_dataloader = \
- build_train_valid_test_data_loaders(
- build_train_valid_test_datasets_provider)
-
- # Build iterators.
- dl_type = args.dataloader_type
- assert dl_type in ['single', 'cyclic']
-
- if train_dataloader is not None:
- train_data_iterator = iter(train_dataloader) if dl_type == 'single' \
- else iter(cyclic_iter(train_dataloader))
- else:
- train_data_iterator = None
-
- if valid_dataloader is not None:
- valid_data_iterator = iter(valid_dataloader) if dl_type == 'single' \
- else iter(cyclic_iter(valid_dataloader))
- else:
- valid_data_iterator = None
-
- if test_dataloader is not None:
- test_data_iterator = iter(test_dataloader) if dl_type == 'single' \
- else iter(cyclic_iter(test_dataloader))
- else:
- test_data_iterator = None
-
- return train_data_iterator, valid_data_iterator, test_data_iterator
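For readers tracking the removal of `megatron/training.py` above: the bookkeeping in the deleted `train()` and `evaluate()` follows from the batch-size decomposition `global_batch_size = data_parallel_size * micro_batch_size * num_microbatches`. The sketch below only illustrates that arithmetic; the numeric values are assumptions, not defaults from any launch script.

```python
# Minimal sketch of the sample accounting in the deleted train()/evaluate().
# The concrete numbers are illustrative assumptions.
data_parallel_size = 8    # number of data-parallel replicas
micro_batch_size = 4      # samples per replica per forward/backward pass
num_microbatches = 16     # gradient-accumulation steps per iteration

# One training iteration consumes exactly one global batch.
global_batch_size = data_parallel_size * micro_batch_size * num_microbatches
assert global_batch_size == 512

# train(): consumed_train_samples advances by one global batch per iteration.
consumed_train_samples = 0
for _iteration in range(3):
    consumed_train_samples += data_parallel_size * micro_batch_size * num_microbatches
assert consumed_train_samples == 3 * global_batch_size

# evaluate(): the eval microbatch count is derived from the same decomposition,
# so validation also advances consumed_valid_samples by global_batch_size per step.
eval_num_microbatches = global_batch_size // (micro_batch_size * data_parallel_size)
assert eval_num_microbatches == num_microbatches
```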
diff --git a/megatron/utils.py b/megatron/utils.py
deleted file mode 100644
index af9b4a07e08735dcb6b362a6cad4ee7b2a45a8db..0000000000000000000000000000000000000000
--- a/megatron/utils.py
+++ /dev/null
@@ -1,271 +0,0 @@
-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-
-"""General utilities."""
-
-import sys
-
-import torch
-
-try:
- from apex.multi_tensor_apply import multi_tensor_applier
-except ImportError:
- multi_tensor_applier = None
-
-try:
- import amp_C
-except ImportError:
- amp_C = None
-
-from megatron import (
- get_args,
- get_adlr_autoresume,
-)
-from megatron.core import DistributedDataParallel as DDP
-from megatron.core import mpu
-from megatron.core.tensor_parallel import param_is_not_tensor_parallel_duplicate
-from megatron.model import Float16Module
-from megatron.model.module import param_is_not_shared
-
-
-ALL_MODULE_WRAPPER_CLASSNAMES = (DDP, Float16Module)
-
-
-def unwrap_model(model, module_instances=ALL_MODULE_WRAPPER_CLASSNAMES):
- return_list = True
- if not isinstance(model, list):
- model = [model]
- return_list = False
- unwrapped_model = []
- for model_module in model:
- while isinstance(model_module, module_instances):
- model_module = model_module.module
- unwrapped_model.append(model_module)
- if not return_list:
- return unwrapped_model[0]
- return unwrapped_model
-
-
-def calc_params_l2_norm(model):
- """Calculate l2 norm of parameters """
- args = get_args()
- if not isinstance(model, list):
- model = [model]
- # Remove duplicate params.
- params_data = []
- for model_ in model:
- for param in model_.parameters():
- is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
- if mpu.get_expert_model_parallel_rank() > 0:
- if not getattr(param, 'allreduce', True) and is_not_tp_duplicate:
- assert param_is_not_shared(param)
- params_data.append(param.data.float() if args.bf16 else param.data)
- else:
- is_not_shared = param_is_not_shared(param)
- if is_not_shared and is_not_tp_duplicate:
- params_data.append(param.data.float() if args.bf16 else param.data)
-
- # Check the availability of apex
- assert multi_tensor_applier is not None and amp_C is not None, \
- "apex is not available, please install it from https://github.com/NVIDIA/apex"
-
- # Calculate norm
- dummy_overflow_buf = torch.cuda.IntTensor([0])
- norm, _ = multi_tensor_applier(
- amp_C.multi_tensor_l2norm,
- dummy_overflow_buf,
- [params_data],
- False # no per-parameter norm
- )
- norm_2 = norm * norm
- if mpu.get_expert_model_parallel_world_size() == 1:
- # Sum across all model-parallel GPUs (tensor + pipeline).
- torch.distributed.all_reduce(norm_2,
- op=torch.distributed.ReduceOp.SUM,
- group=mpu.get_model_parallel_group())
- else:
- # Sum across tensor, pipeline and expert model-parallel GPUs.
- torch.distributed.all_reduce(norm_2,
- op=torch.distributed.ReduceOp.SUM,
- group=mpu.get_tensor_and_expert_parallel_group())
- torch.distributed.all_reduce(norm_2,
- op=torch.distributed.ReduceOp.SUM,
- group=mpu.get_pipeline_model_parallel_group())
- return norm_2.item() ** 0.5
-
-
-def average_losses_across_data_parallel_group(losses):
- """Reduce a tensor of losses across all GPUs."""
- averaged_losses = torch.cat(
- [loss.clone().detach().view(1) for loss in losses])
- torch.distributed.all_reduce(averaged_losses,
- group=mpu.get_data_parallel_group())
- averaged_losses = averaged_losses / \
- torch.distributed.get_world_size(group=mpu.get_data_parallel_group())
-
- return averaged_losses
-
-
-def report_memory(name):
- """Simple GPU memory report."""
- mega_bytes = 1024.0 * 1024.0
- string = name + ' memory (MB)'
- string += ' | allocated: {}'.format(
- torch.cuda.memory_allocated() / mega_bytes)
- string += ' | max allocated: {}'.format(
- torch.cuda.max_memory_allocated() / mega_bytes)
- string += ' | reserved: {}'.format(
- torch.cuda.memory_reserved() / mega_bytes)
- string += ' | max reserved: {}'.format(
- torch.cuda.max_memory_reserved() / mega_bytes)
- if mpu.get_data_parallel_rank() == 0:
- print("[Rank {}] {}".format(torch.distributed.get_rank(), string),
- flush=True)
-
-
-def print_params_min_max_norm(optimizer, iteration):
- """Print min, max, and norm of all parameters."""
- index = 0
- rank = torch.distributed.get_rank()
- string = 'iteration, rank, index, tensor-model-parallel, min, max, norm\n'
- optimizer_ = optimizer.optimizer
- for param_group in optimizer_.param_groups:
- for param in param_group['params']:
- index += 1
- min_ = param.data.min()
- max_ = param.data.max()
- norm = torch.linalg.norm(param.data)
- string += '{:7d}, {:4d}, {:4d}, {:2d}, '.format(
- iteration, rank, index, int(param.tensor_model_parallel))
- string += '{:.6E}, {:.6E}, {:.6E}\n'.format(min_, max_, norm)
- print(string, flush=True)
-
-
-def check_adlr_autoresume_termination(iteration, model,
- optimizer, opt_param_scheduler):
- """Check for autoresume signal and exit if it is received."""
- from megatron.checkpointing import save_checkpoint
-
- args = get_args()
- autoresume = get_adlr_autoresume()
- # Add a barrier to ensure consistency.
- torch.distributed.barrier()
- if autoresume.termination_requested():
- if args.save:
- save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
- print_rank_0(">>> autoresume termination request found!")
- if torch.distributed.get_rank() == 0:
- autoresume.request_resume()
- print_rank_0(">>> training terminated. Returning")
- sys.exit(0)
-
-
-def get_ltor_masks_and_position_ids(data,
- eod_token,
- reset_position_ids,
- reset_attention_mask,
- eod_mask_loss):
- """Build masks and position id for left to right model."""
-
- # Extract batch size and sequence length.
- micro_batch_size, seq_length = data.size()
-
- # Attention mask (lower triangular).
- if reset_attention_mask:
- att_mask_batch = micro_batch_size
- else:
- att_mask_batch = 1
- attention_mask = torch.tril(torch.ones(
- (att_mask_batch, seq_length, seq_length), device=data.device)).view(
- att_mask_batch, 1, seq_length, seq_length)
-
- # Loss mask.
- loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device)
- if eod_mask_loss:
- loss_mask[data == eod_token] = 0.0
-
- # Position ids.
- position_ids = torch.arange(seq_length, dtype=torch.long,
- device=data.device)
- position_ids = position_ids.unsqueeze(0).expand_as(data)
- # We need to clone as the ids will be modified based on the batch index.
- if reset_position_ids:
- position_ids = position_ids.clone()
-
- if reset_position_ids or reset_attention_mask:
- # Loop through the batches:
- for b in range(micro_batch_size):
-
- # Find indices where the EOD token is.
- eod_index = position_ids[b, data[b] == eod_token]
- # Detach indices from positions if we are going to modify positions.
- if reset_position_ids:
- eod_index = eod_index.clone()
-
- # Loop through EOD indices:
- prev_index = 0
- for j in range(eod_index.size()[0]):
- i = eod_index[j]
- # Mask out attention across the EOD boundary.
- if reset_attention_mask:
- attention_mask[b, 0, (i + 1):, :(i + 1)] = 0
- # Reset positions.
- if reset_position_ids:
- position_ids[b, (i + 1):] -= (i + 1 - prev_index)
- prev_index = i + 1
-
- # Convert attention mask to binary:
- attention_mask = (attention_mask < 0.5)
-
- return attention_mask, loss_mask, position_ids
-
-
-def get_batch_on_this_cp_rank(batch):
- """ Slice batch input along sequence dimension into multiple chunks,
- which are parallelized across GPUs in a context parallel group.
- """
-
- # With causal masking, each token only attends to its prior tokens. Simply splitting
- # the sequence into CP chunks can result in severe load imbalance: chunks at the end
- # of the sequence have a bigger workload than the others. To address this, we split
- # the sequence into 2*CP chunks. Assuming CP=2, we then get 4 chunks; chunk_0 and
- # chunk_3 are assigned to GPU0, chunk_1 and chunk_2 are assigned to GPU1, so that
- # the workload is balanced across the GPUs in a context parallel group.
- args = get_args()
- cp_size = args.context_parallel_size
- if cp_size > 1:
- cp_rank = mpu.get_context_parallel_rank()
- for key, val in batch.items():
- seq_dim = 1 if key != 'attention_mask' else 2
- val = val.view(
- *val.shape[0:seq_dim],
- 2 * cp_size,
- val.shape[seq_dim] // (2 * cp_size),
- *val.shape[(seq_dim + 1) :],
- )
- index = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)], device=val.device)
- val = val.index_select(seq_dim, index)
- val = val.view(*val.shape[0:seq_dim], -1, *val.shape[(seq_dim + 2) :])
- batch[key] = val
-
- return batch
-
-
-def print_rank_0(message):
- """If distributed is initialized, print only on rank 0."""
- if torch.distributed.is_initialized():
- if torch.distributed.get_rank() == 0:
- print(message, flush=True)
- else:
- print(message, flush=True)
-
-def is_last_rank():
- return torch.distributed.get_rank() == (
- torch.distributed.get_world_size() - 1)
-
-def print_rank_last(message):
- """If distributed is initialized, print only on last rank."""
- if torch.distributed.is_initialized():
- if is_last_rank():
- print(message, flush=True)
- else:
- print(message, flush=True)
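The load-balanced context-parallel split in the deleted `get_batch_on_this_cp_rank` above can be reproduced in isolation. The sketch below is not part of the Megatron API: `cp_slice` is a hypothetical helper that takes `cp_size` and `cp_rank` as plain arguments instead of reading them from `get_args()` and `mpu`, purely to show how each rank receives one chunk from the front of the sequence and its mirror from the back.

```python
import torch

def cp_slice(val: torch.Tensor, cp_size: int, cp_rank: int, seq_dim: int = 1) -> torch.Tensor:
    """Select the two sequence chunks (cp_rank and 2*cp_size-1-cp_rank) for one CP rank."""
    # Split the sequence dimension into 2*cp_size equal chunks.
    val = val.view(
        *val.shape[0:seq_dim],
        2 * cp_size,
        val.shape[seq_dim] // (2 * cp_size),
        *val.shape[(seq_dim + 1):],
    )
    # Each rank takes one chunk from the front and its mirror from the back, which
    # roughly balances the causal-attention workload across the CP group.
    index = torch.tensor([cp_rank, 2 * cp_size - cp_rank - 1], device=val.device)
    val = val.index_select(seq_dim, index)
    # Merge the two selected chunks back into a contiguous sequence dimension.
    return val.view(*val.shape[0:seq_dim], -1, *val.shape[(seq_dim + 2):])

tokens = torch.arange(16).unsqueeze(0)           # [batch=1, seq=16]
print(cp_slice(tokens, cp_size=2, cp_rank=0))    # chunks 0 and 3 -> tokens 0-3 and 12-15
print(cp_slice(tokens, cp_size=2, cp_rank=1))    # chunks 1 and 2 -> tokens 4-11
```

With `cp_size=2`, rank 0 ends up with tokens 0-3 and 12-15 while rank 1 gets tokens 4-11, so the causal-attention work (which grows with token position) is spread roughly evenly across the two ranks.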