Ai
1 Star 1 Fork 0

LEVSONGSW/DeepLearnLog

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
TensorParallel.py 5.33 KB
一键复制 编辑 原始数据 按行查看 历史
LEVSONGSW 提交于 2025-07-25 17:09 +08:00 . Distributed Parallel Learn Log
# %%
from turtle import Vec2D
from sympy import sec
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
# %%
# ! Tensor (matrix) parallelism
# ! Reference: https://github.com/chunhuizhang/pytorch_distribute_tutorials/blob/main/tutorials/tensor_parallel.ipynb
# %%
# * Quick NumPy check: compute Z = f(X @ A) @ B densely, then again with A split
# * column-wise and B split row-wise, and compare the two results.
n_rows, d_in, d_hidden, d_out = 100, 200, 300, 400
X = np.random.randn(n_rows, d_in)     # input matrix
A = np.random.randn(d_in, d_hidden)   # first weight matrix (sharded by columns below)
B = np.random.randn(d_hidden, d_out)  # second weight matrix (sharded by rows below)
# %%
def split_columnwise(A, num_splits):
    """Split ``A`` into ``num_splits`` equal-width blocks along its columns (axis 1).

    Thin wrapper over ``np.split``, which raises if axis 1 cannot be divided
    evenly into ``num_splits`` parts. Returns the list of blocks, left to right.
    """
    column_blocks = np.split(A, indices_or_sections=num_splits, axis=1)
    return column_blocks
def split_rowwise(A, num_splits):
    """Split ``A`` into ``num_splits`` equal-height blocks along its rows (axis 0).

    Thin wrapper over ``np.split``, which raises if axis 0 cannot be divided
    evenly into ``num_splits`` parts. Returns the list of blocks, top to bottom.
    """
    row_blocks = np.split(A, indices_or_sections=num_splits, axis=0)
    return row_blocks
# %%
def normal_forward_pass(X, A, B, f):
    """Single-device reference for the two-layer product ``f(X·A)·B``.

    Args:
        X: input matrix.
        A: first weight matrix.
        B: second weight matrix.
        f: activation applied to ``X·A`` (the notebook passes ``np.tanh``).

    Returns:
        ``np.dot(f(np.dot(X, A)), B)``.
    """
    hidden = np.dot(X, A)
    activated = f(hidden)
    return np.dot(activated, B)
# %%
def tensor_parallel_forward_pass(X, A, B, f, num_splits=2):
    """Compute ``f(X·A)·B`` by sharding the weights, simulating tensor parallelism.

    ``A`` is split column-wise and ``B`` row-wise into ``num_splits`` shards.
    Shard ``i`` computes ``f(X·A_i)·B_i`` on its own, and the partial products
    are summed. The result matches ``normal_forward_pass`` only when ``f`` is
    elementwise (e.g. ``np.tanh``), because each column of ``X·A`` must be
    activated independently of the other columns.

    Args:
        X: input matrix (2-D as used in this notebook).
        A: first weight matrix; its column count must be divisible by ``num_splits``.
        B: second weight matrix; its row count must be divisible by ``num_splits``.
        f: activation applied to each shard's ``X·A_i``.
        num_splits: number of simulated devices/shards. Defaults to 2.

    Returns:
        The sum of all per-shard partial products.
    """
    column_shards = split_columnwise(A, num_splits)
    row_shards = split_rowwise(B, num_splits)
    Z = None
    for A_i, B_i in zip(column_shards, row_shards):
        partial = np.dot(f(np.dot(X, A_i)), B_i)
        # Summing the partials stands in for the cross-device reduction step.
        Z = partial if Z is None else Z + partial
    return Z
# %%
Z_normal = normal_forward_pass(X, A, B, np.tanh)
Z_tensor = tensor_parallel_forward_pass(X, A, B, np.tanh)
# %%
# np.tanh is elementwise, so the sharded computation must reproduce the dense one.
# Assert rather than rely only on the cell's displayed value, so a plain script run
# also performs the check.
is_equivalent = np.allclose(Z_normal, Z_tensor)
assert is_equivalent, "tensor-parallel forward pass differs from the dense forward pass"
is_equivalent
# %%
# ! Model: tensor parallelism for a bias-free h -> 4h -> h MLP (no activation in between)
x = torch.randn(size=(1, 5, 10), dtype=torch.float32)  # (batch, seq_len, embedding_dim)
embedding_dim = x.size(2)
embedding_dim
# %%
# Dense reference, first layer: h -> 4h.
dense_h_to_4h = nn.Linear(embedding_dim, embedding_dim*4, bias=False)
output1 = dense_h_to_4h(x)
output1.shape
# %%
# Dense reference, second layer: 4h -> h.
dense_4h_to_h = nn.Linear(embedding_dim*4, embedding_dim, bias=False)
output2 = dense_4h_to_h(output1)
output2.shape
# %%
n_device = 2
if (embedding_dim * 4) % n_device != 0:
    raise ValueError(f"4h ({embedding_dim * 4}) must be divisible by n_device ({n_device})")
half_4h = embedding_dim * 4 // n_device  # width of each device's slice of the 4h hidden layer
# NOTE: the cells below build exactly two shards ("first" / "second"); with any other
# n_device the weight copies below fail with a shape mismatch instead of silently
# producing wrong shards.
# %%
# Column-parallel first layer. nn.Linear stores weight as (out_features, in_features),
# so taking a slice of output features means slicing weight rows. The weights are
# copied (not aliased) so each shard owns its own parameters, as a separate device would.
dense_h_to_4h_shard0 = nn.Linear(embedding_dim, half_4h, bias=False)
with torch.no_grad():
    dense_h_to_4h_shard0.weight.copy_(dense_h_to_4h.weight[:half_4h, :])
first_h_to_4h = dense_h_to_4h_shard0(x)
first_h_to_4h.shape
# %%
dense_h_to_4h_shard1 = nn.Linear(embedding_dim, half_4h, bias=False)
with torch.no_grad():
    dense_h_to_4h_shard1.weight.copy_(dense_h_to_4h.weight[half_4h:, :])
second_h_to_4h = dense_h_to_4h_shard1(x)
second_h_to_4h.shape
# %%
# Row-parallel second layer. Each shard only sees its half_4h slice of the hidden
# activations, so in_features must be half_4h to match the sliced weight
# (dense_4h_to_h.weight[:, :half_4h] has shape (embedding_dim, half_4h)).
dense_4h_to_h_shard0 = nn.Linear(half_4h, embedding_dim, bias=False)
with torch.no_grad():
    dense_4h_to_h_shard0.weight.copy_(dense_4h_to_h.weight[:, :half_4h])
first_4h_to_h = dense_4h_to_h_shard0(first_h_to_4h)
first_4h_to_h.shape
# %%
dense_4h_to_h_shard1 = nn.Linear(half_4h, embedding_dim, bias=False)
with torch.no_grad():
    dense_4h_to_h_shard1.weight.copy_(dense_4h_to_h.weight[:, half_4h:])
second_4h_to_h = dense_4h_to_h_shard1(second_h_to_4h)
second_4h_to_h.shape
# %%
# Summing the two partial outputs recovers the dense output (on real multi-GPU
# hardware this sum would be an all-reduce). Checked numerically instead of by eye.
mlp_tp_output = first_4h_to_h + second_4h_to_h
assert torch.allclose(mlp_tp_output, output2, atol=1e-6), "tensor-parallel MLP output differs from the dense reference"
mlp_tp_output
# %%
output2
# %%
# ! Attention module: single-device multi-head self-attention, used as the reference
x = torch.randn(size=(1, 5, 32), dtype=torch.float32)
bsz, seq_len, hidden_size = x.size()
hidden_size
# %%
num_heads = 4
head_dim = hidden_size // num_heads
head_dim
# %%
# Bias-free projections: Q/K/V map hidden_size -> num_heads * head_dim, Wo maps back.
Wq = nn.Linear(hidden_size, num_heads * head_dim, bias=False)
Wk = nn.Linear(hidden_size, num_heads * head_dim, bias=False)
Wv = nn.Linear(hidden_size, num_heads * head_dim, bias=False)
Wo = nn.Linear(head_dim * num_heads, hidden_size, bias=False)
# %%
# Project, then reshape (bsz, seq_len, num_heads * head_dim) -> (bsz, num_heads, seq_len, head_dim).
q = Wq(x).view(bsz, seq_len, num_heads, head_dim).transpose(1, 2)
k = Wk(x).view(bsz, seq_len, num_heads, head_dim).transpose(1, 2)
v = Wv(x).view(bsz, seq_len, num_heads, head_dim).transpose(1, 2)
v.shape
# %%
# Scaled dot-product scores: (bsz, num_heads, seq_len, seq_len).
attn_weight = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(head_dim)
attn_weight.shape
# %%
attn_weight = F.softmax(attn_weight, dim=-1)
attn_output = torch.matmul(attn_weight, v)
attn_output.shape
# %%
# Merge heads back: (bsz, num_heads, seq_len, head_dim) -> (bsz, seq_len, num_heads * head_dim).
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, seq_len, num_heads * head_dim)
attn_output.shape
# %%
attn_output_non_tp = Wo(attn_output)
attn_output_non_tp.shape
# %%
attn_output_non_tp
# %%
# ! Tensor parallelism inside attention: each device owns a contiguous group of heads
n_devices = 2
if num_heads % n_devices != 0:
    raise ValueError(f"num_heads ({num_heads}) must be divisible by n_devices ({n_devices})")
# Kept separate from num_heads so that num_heads keeps meaning "total heads".
heads_per_device = num_heads // n_devices
shard_width = heads_per_device * head_dim
# %%
Wq.weight.shape
# %%
# nn.Linear stores weight as (out_features, in_features). Splitting Q/K/V along dim 0
# gives each device a contiguous slice of output features; because of the
# view(bsz, seq_len, num_heads, head_dim) above, that slice is a group of whole heads.
Wq_blocks = Wq.weight.split(shard_width, dim=0)
Wk_blocks = Wk.weight.split(shard_width, dim=0)
Wv_blocks = Wv.weight.split(shard_width, dim=0)
for block in Wq_blocks:
    print(block.shape)
# %%
# Wo is split along dim 1 (its input features) so each device projects only its own heads.
Wo_blocks = Wo.weight.split(shard_width, dim=1)
for block in Wo_blocks:
    print(block.shape)
# %%
def _partial_attention(x, w_q, w_k, w_v, w_o, n_heads, head_dim):
    """Compute one device's partial attention output for its shard of heads.

    Runs the same steps as the single-device reference above, but only with the
    Q/K/V weight rows and Wo weight columns belonging to ``n_heads`` heads.
    The result's last dimension equals ``w_o``'s out-features (hidden_size here);
    summing the partial outputs of all devices reproduces the dense output.
    """
    bsz, seq_len, _ = x.shape

    def split_heads(t):
        # (bsz, seq_len, n_heads * head_dim) -> (bsz, n_heads, seq_len, head_dim)
        return t.view(bsz, seq_len, n_heads, head_dim).transpose(1, 2)

    q = split_heads(F.linear(x, w_q))
    k = split_heads(F.linear(x, w_k))
    v = split_heads(F.linear(x, w_v))
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(head_dim)
    weights = F.softmax(scores, dim=-1)
    merged = torch.matmul(weights, v).transpose(1, 2).contiguous()
    merged = merged.reshape(bsz, seq_len, n_heads * head_dim)
    return F.linear(merged, w_o)
# %%
partial_outputs = [
    _partial_attention(x, w_q, w_k, w_v, w_o, heads_per_device, head_dim)
    for w_q, w_k, w_v, w_o in zip(Wq_blocks, Wk_blocks, Wv_blocks, Wo_blocks)
]
partial_outputs[0].shape
# %%
# Summing the per-device partial outputs recovers the dense output (on real multi-GPU
# hardware this sum would be an all-reduce). Checked numerically instead of by eye.
attn_output_tp = torch.stack(partial_outputs).sum(dim=0)
assert torch.allclose(attn_output_tp, attn_output_non_tp, atol=1e-6), "tensor-parallel attention output differs from the dense reference"
attn_output_tp
# %%
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/levsongsw/deep-learn-log.git
git@gitee.com:levsongsw/deep-learn-log.git
levsongsw
deep-learn-log
DeepLearnLog
master

搜索帮助