# %%
import os
import torch
import torch.distributed
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributed.tensor.parallel import parallelize_module, PairwiseParallel
from torch.distributed import init_process_group, destroy_process_group
from torch.utils.data import Dataset, DataLoader
from torch.distributed._tensor import DeviceMesh
from torch.amp.autocast_mode import autocast
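# Note: PairwiseParallel is an experimental tensor-parallel plan from the early
# PyTorch 2.x releases (later superseded by the ColwiseParallel/RowwiseParallel styles);
# this example assumes a PyTorch build that still ships it.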
# %%
def init_dist():
    # torchrun sets LOCAL_RANK for each process; initialize the NCCL process group
    # and bind this process to its local GPU.
    init_process_group('nccl')
    torch.cuda.set_device(int(os.environ['LOCAL_RANK']))
# %%
# A minimal two-layer MLP; the two consecutive Linear layers are what PairwiseParallel shards.
class nets(nn.Module):
def __init__(self, inputdim: int, hiddendim: int, outdim: int):
super(nets, self).__init__()
self.input = nn.Linear(inputdim, hiddendim)
self.output = nn.Linear(hiddendim, outdim)
def forward(self, x: torch.Tensor):
print(f"Rank:{os.environ['LOCAL_RANK']}, Input Weight:{self.input.weight.shape}, Output Weight:{self.output.weight.shape}")
print(f"Rank:{os.environ['LOCAL_RANK']}, Weight:{self.output.weight}, Output Weight:{self.output.weight.shape}") # 查看输出权重值维度变成一半,但是查看输出权重的shape依然是没有拆分的shape,说明我们查看权重的shape时,Pytorch的张量并行会一直记录未拆分的shape
x = self.input(x)
o= self.output(x)
return o
# %%
class datasets(Dataset):
def __init__(self, batchsize: int, inputdim: int):
# self.data = torch.randn((batchsize, inputdim))
self.data = torch.tensor([[-1.2685, -0.3359, 0.3753, 0.8139, -0.8377, 0.5173, -2.3060, -1.2877,
0.5419, 0.2264, 0.3391, 1.0687, -0.4949, -0.3194, -0.2784, -0.3434],
[ 0.6138, -0.6626, -0.8777, -0.6654, 1.0892, -1.2870, 0.9698, -0.7955,
-1.2492, 1.4883, 0.1950, 1.1504, 1.7145, -1.3542, -0.1323, 0.9527],
[-0.6840, 1.1146, 1.2526, -1.6137, -0.4373, 0.0589, 1.5469, 2.1319,
1.3031, 1.5617, 2.3328, 0.0339, -0.4445, 0.2684, 1.5338, 0.2438],
[-0.6055, 0.2712, -0.8205, 0.9398, 0.2239, -0.7185, 1.1680, 0.7111,
-0.9498, 0.2168, -0.6744, -0.5974, -1.3757, 0.9873, -0.4443, 1.2712],
[-1.4351, -0.2610, -0.9145, -1.4046, -0.1739, -0.1877, -0.0097, -0.9801,
0.2072, 0.8775, 0.6582, 0.5227, -1.2462, -0.5198, -0.3489, 0.6775],
[ 0.3895, 0.2157, 0.9020, -0.9164, 1.0149, -0.1645, 0.1544, 0.4764,
0.8894, 0.0314, -0.0906, -0.6251, -1.7766, 0.3059, -0.3503, -0.8400],
[-1.1152, -0.4049, 0.9521, 1.1675, 1.5218, -1.6813, 0.4359, -0.7398,
-1.5968, -0.6112, -0.0541, 0.2037, -1.0194, -1.7481, 0.0972, 1.3010],
[-0.3545, 0.3862, 0.3221, -1.4349, 1.3182, -1.2108, 0.7134, 1.1861,
-0.1904, 0.5095, 0.2032, 0.1373, -2.0926, -1.8493, -0.8085, -0.4283],
[-0.0496, -0.8880, -0.0983, 0.1265, -0.0858, -1.3326, 0.3974, 1.2103,
0.3597, -1.5291, 0.6048, -0.0075, 0.0552, -0.4060, 0.0734, 0.2467],
[ 0.8863, 0.7070, 0.3548, 0.7311, 0.0273, -0.4060, -0.9918, 1.1632,
0.6291, 0.5502, 0.4488, -1.3875, 0.4756, -1.2840, -0.8738, 0.2462],
[-1.3289, -0.9069, -1.2382, 0.3249, 0.1817, -0.0952, 1.8952, 1.1803,
-0.2355, 0.2777, 0.5185, 0.6320, 0.1538, 0.1887, 1.0202, -0.0270],
[-0.8395, -0.3305, -0.4228, -0.7803, -0.5505, -1.1946, -1.6778, 1.0611,
-0.8079, 0.0617, 0.9839, -0.2509, -0.5673, 0.1756, 1.7068, -1.2233],
[-1.9708, 0.0852, 0.4051, 0.3808, 1.5265, 0.1554, -0.1910, 2.1397,
-0.7290, 0.1410, -0.3454, -0.5503, -0.4366, 1.2320, -0.4126, -1.0403],
[ 0.5622, -1.8529, -2.4480, 0.7545, 0.2815, -0.7455, 0.5176, 1.0525,
-1.2969, 0.6270, 0.6824, 1.2537, -0.4509, -0.7034, -0.5586, -0.6927],
[-1.8240, -1.4849, 0.9169, 0.5198, 1.2471, 1.5980, 0.8311, -0.2132,
-0.5307, -0.0044, 1.4256, 0.6691, 1.3700, -0.4051, -1.4883, -0.1453],
[-0.6203, 0.3816, -2.3648, -1.7413, -0.5032, 0.2807, -0.0148, -1.0793,
0.5438, -0.8159, 1.3272, 1.1375, 1.5259, 0.8258, -0.5577, 1.1011]])
# self.label = torch.randint(0, 2,(batchsize,))
self.label = torch.tensor([1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 1, 0, 0, 1, 0, 1])
def __getitem__(self, idx) -> tuple[torch.Tensor, torch.Tensor]:
return self.data[idx], self.label[idx]
def __len__(self):
return len(self.label)
# %%
init_dist()
if int(os.environ["LOCAL_RANK"]) != 0:
torch.distributed.barrier()
# 依然会多次调用,所以采取数据集固定
train_dataset = datasets(16, 16)
if int(os.environ['LOCAL_RANK']) == 0:
torch.distributed.barrier()
# train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=False, num_workers=4, pin_memory=True, sampler=DistributedSampler(train_dataset))
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=False, num_workers=4, pin_memory=True)
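# With tensor parallelism every rank must consume the *same* batch, which is why no
# DistributedSampler is used and shuffle=False keeps the iteration order identical
# across processes. An alternative sketch (not used here): generate random data on
# rank 0 and broadcast it so all ranks still see the same samples, e.g.
#   data = torch.randn(16, 16, device=f"cuda:{os.environ['LOCAL_RANK']}")
#   torch.distributed.broadcast(data, src=0)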
device = int(os.environ['LOCAL_RANK']) if torch.cuda.is_available() else 'cpu'
net = nets(16, 32, 1)
net.to(device)
# PairwiseParallel shards the first Linear column-wise and the second row-wise,
# so each rank holds one half of every weight matrix.
parallel_model = parallelize_module(
    module=net,
    device_mesh=DeviceMesh('cuda', [0, 1]),
    parallelize_plan=PairwiseParallel()
)
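# Optional sketch: inspect the sharded parameters directly. This assumes parallelize_module
# replaces them with DTensors (true for the releases that ship PairwiseParallel); `.shape`
# reports the global, unsharded size while `.to_local()` exposes only this rank's shard.
from torch.distributed._tensor import DTensor
for name, param in parallel_model.named_parameters():
    if isinstance(param, DTensor):
        print(f"Rank:{os.environ['LOCAL_RANK']}, {name}: global {tuple(param.shape)}, local {tuple(param.to_local().shape)}")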
optimizer = optim.Adam(parallel_model.parameters(), lr=0.001)
loss_fun = torch.nn.BCEWithLogitsLoss()  # sigmoid + BCE in one op; autocast runs it in float32
loss_item = 0
for idx, data in enumerate(train_dataloader):
optimizer.zero_grad()
input_data = data[0]
input_data = input_data.to(device)
    label_data = data[1].to(torch.float16)  # BCEWithLogitsLoss expects floating-point targets
# print(f"Data-0:{data[0]}")
# print(f"Data-1:{data[1]}")
label_data = label_data.to(device)
with autocast(input_data.device.type):
out = parallel_model(input_data)
        # ! The sharded partial outputs are merged automatically; print `out` to compare the values across ranks.
print(f"Out Shape:{out.shape}")
# print(f"Out: {out}")
# print(f"Device Mesh:{out.device_mesh}")
loss = loss_fun(out, label_data.unsqueeze(-1))
# print(f"Label: {label_data}")
# print(f"Loss: {loss}")
loss.backward()
optimizer.step()
loss_item += loss.to('cpu').item()
destroy_process_group()
# %%
print(f"Loss:{loss_item / 16}")