The index-selection operator (specifically, its forward pass) is slow for large sizes on Ascend.
We wrote a standalone test to demonstrate this, using the following parameters:
max_size = 600000
choice_size = max_size // 2
We select choice_size random rows from a tensor of shape (max_size, 128).
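For reference, the operation under test is plain advanced indexing with a LongTensor, which, as far as we understand, computes the same gather as an explicit torch.index_select along dim 0. A minimal standalone sketch (the names here are illustrative, not taken from the test script):

import torch

x = torch.randn(600000, 128)
idx = torch.randint(0, 600000, (300000,))

rows_a = x[idx]                          # advanced indexing, as used in the test
rows_b = torch.index_select(x, 0, idx)   # explicit index_select, same result
assert torch.equal(rows_a, rows_b)

If the two forms take different kernel paths on Ascend, timing both could already help narrow the problem down.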
With these settings, the Ascend forward pass is about 17 times slower than the backward pass (15.64 s vs. 0.93 s). That is strange.
Comparing total time with an NVIDIA A100, Ascend is roughly 60 times slower here (16.58 s vs. 0.28 s).
It would be great to understand the reason for this behavior: is it an inefficient torch operator implementation, or a hardware limitation?
Results (all times in seconds, 20 batches):

                          Ascend 910    NVIDIA A100
Forward pass, total          15.6363         0.0791
Forward pass, mean            0.7818         0.0040
Backward pass, total          0.9307         0.1502
Backward pass, mean           0.0465         0.0075
Total                        16.5843         0.2789
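To pinpoint where the forward time goes, an operator-level profile would help. Below is a minimal sketch using the standard torch.profiler API; on Ascend builds, NPU activities are exposed through torch_npu.profiler, which we have not shown here, so treat this as an assumption about the setup rather than something we ran:

import torch

def profile_one_step(model, inputs):
    # Capture one forward/backward step and print the top operators by CPU time.
    with torch.profiler.profile(record_shapes=True) as prof:
        out = model(inputs)
        loss = 1.0 - out.sum()
        loss.backward()
    print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=10))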
Below is the code of the test. To run it on device 0:
python3 torch_large_index_select_test.py -g 0
(Pass -g -1 to run on CPU; -b sets the number of batches, default 20.)
from functools import wraps
import argparse
import time

import torch
from tqdm import tqdm

max_size = 600000
choice_size = max_size // 2


def synchronize(device):
    # Block until all queued device work finishes, so that time.perf_counter()
    # measures the kernels themselves and not just the kernel launches.
    if device.type == "npu":
        torch.npu.synchronize(device)
    if device.type == "cuda":
        torch.cuda.synchronize(device)
def model_test(model_init_func):
    @wraps(model_init_func)
    def model_test_wrapper(device, num_batches, init_blank_shot=True):
        print(
            f"!!!!!!!!!!------------{model_init_func.__name__} TEST BEGAN ------------!!!!!!!!!!"
        )
        model_init_start_time = time.perf_counter()
        model = model_init_func(device)
        model_init_total_time = time.perf_counter() - model_init_start_time
        opt = torch.optim.Adam(model.parameters())
        print(f"Model init time:\t{model_init_total_time:.4f}")
        inputs = torch.randn((max_size, 128), dtype=torch.float32, device=device)
        forward_path_times = []
        backward_path_times = []
        total_batch_times = []
        if init_blank_shot:
            # Warm-up step: run one untimed forward/backward so that lazy
            # initialization and first-call compilation do not skew the timings.
            temp_results = model(inputs)
            loss = 1.0 - temp_results.sum()
            opt.zero_grad()
            loss.backward()
            opt.step()
            synchronize(device)
        for _ in tqdm(range(num_batches)):
            synchronize(device)
            fp_start_time = time.perf_counter()
            temp_results = model(inputs)
            synchronize(device)
            forward_path_times.append(time.perf_counter() - fp_start_time)
            loss = 1.0 - temp_results.sum()
            opt.zero_grad()
            bp_start_time = time.perf_counter()
            loss.backward()
            synchronize(device)
            backward_path_times.append(time.perf_counter() - bp_start_time)
            opt.step()
            synchronize(device)
            total_batch_times.append(time.perf_counter() - fp_start_time)
        print(f"Model forward path total time:\t{sum(forward_path_times):.4f}")
        print(
            f"Model forward path mean time:\t{sum(forward_path_times) / num_batches:.4f}"
        )
        print(f"Model backward path total time:\t{sum(backward_path_times):.4f}")
        print(
            f"Model backward path mean time:\t{sum(backward_path_times) / num_batches:.4f}"
        )
        print(f"Total time:\t{sum(total_batch_times):.4f}")
        print(f"Average model time:\t{sum(total_batch_times) / num_batches:.4f}")
        print(
            f"!!!!!!!!!!------------{model_init_func.__name__} TEST ENDED ------------!!!!!!!!!!"
        )

    return model_test_wrapper
class SelectionModelClass(torch.nn.Module):
    def __init__(self, device):
        super().__init__()
        self.dense_layer = torch.nn.Linear(in_features=128, out_features=128).to(device)
        self.idx_list = torch.randint(
            0, max_size, (choice_size,), dtype=torch.long, device=device
        )

    def forward(self, x):
        x = self.dense_layer(x)
        x = torch.relu(x)
        # !!!! This is the main cause of the slowness !!!!
        x = x[self.idx_list]
        return x


@model_test
def SelectionModelTest(device):
    return SelectionModelClass(device=device)
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--gpu", "-g", default=-1, type=int, help="-1 means cpu")
    parser.add_argument("--batches-num", "-b", default=20, type=int, help="batches num")
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    device = None
    if args.gpu == -1:
        device = torch.device("cpu")
    else:
        try:
            import torch_npu

            if not torch.npu.is_available():
                print("NPU is not available\n")
            else:
                # Without this line, torch sends tensors to "npu:0" anyway. :(
                torch.npu.set_device(args.gpu)
                device = torch.device(f"npu:{args.gpu}")
        except Exception:
            # torch_npu is not installed; fall back to CUDA.
            print("NPU is not available\n")
            if torch.cuda.is_available():
                device = torch.device(f"cuda:{args.gpu}")
            else:
                print("Cuda is not available\n")
    assert device is not None, "Can't instantiate device"
    print(f"Current device: {device}")
    SelectionModelTest(device, args.batches_num)
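For completeness, here is a possible standalone microbenchmark (our own addition, reusing max_size, choice_size, and synchronize from the script above) that isolates the gather from the Linear/ReLU. Note that the backward of x[idx] is a scatter-add into the source tensor, which is why one would normally expect the backward, not the forward, to be the expensive direction:

def time_gather(device, iters=20):
    # Time only x[idx] and its backward, without the Linear/ReLU.
    x = torch.randn(max_size, 128, device=device, requires_grad=True)
    idx = torch.randint(0, max_size, (choice_size,), device=device)
    fwd = bwd = 0.0
    for _ in range(iters):
        synchronize(device)
        t0 = time.perf_counter()
        y = x[idx]
        synchronize(device)
        fwd += time.perf_counter() - t0
        t1 = time.perf_counter()
        y.sum().backward()
        synchronize(device)
        bwd += time.perf_counter() - t1
        x.grad = None  # drop the accumulated gradient between iterations
    print(f"gather forward mean: {fwd / iters:.4f}s, backward mean: {bwd / iters:.4f}s")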
Can you provide a copy of the profiling data?
Logs have been provided in a private message.