Ai
1 Star 1 Fork 0

LEVSONGSW/DeepLearnLog

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
plotPositionEmbedding.py 1.77 KB
一键复制 编辑 原始数据 按行查看 历史
LEVSONGSW 提交于 2025-08-27 19:27 +08:00 . Plot Position Encode
import torch
import matplotlib.pyplot as plt
import numpy as np
def generate_positional_encoding(max_len, d_model, device='cpu'):
pe = torch.zeros(max_len, d_model, device=device)
position = torch.arange(0, max_len, device=device).unsqueeze(1) # shape: (max_len, 1)
div_term = torch.exp(torch.arange(0, d_model, 2, device=device) * (-np.log(10000.0) / d_model)) # shape: (d_model//2,)
pe[:, 0::2] = torch.sin(position * div_term) # 偶数索引:正弦
pe[:, 1::2] = torch.cos(position * div_term) # 奇数索引:余弦
return pe
# 设置参数(例如:序列长度50,特征维度128)
max_len = 50 # 序列长度(位置数量)
d_model = 128 # 特征维度
pe = generate_positional_encoding(max_len, d_model)
# 转换为numpy数组(便于绘图)
pe_np = pe.cpu().numpy()
# 创建画布
plt.figure(figsize=(12, 6))
# for i in range(0, 16, 2): # 每隔2取一个正弦维度(0, 2, 4...)
# plt.plot(pe_np[:, i], label=f'Sin Dim {i}', linestyle='-', alpha=0.7)
plt.plot(pe_np[:, 20], label=f'Sin Dim {20}', linestyle='-', alpha=0.7)
# for i in range(1, 17, 2): # 每隔2取一个余弦维度(1, 3, 5...)
# plt.plot(pe_np[:, i], label=f'Cos Dim {i}', linestyle='--', alpha=0.7)
plt.plot(pe_np[:, 21], label=f'Cos Dim {21}', linestyle='--', alpha=0.7)
# 添加标题和标签
plt.title('Position Encode Cos Sin', fontsize=14)
plt.xlabel('position index', fontsize=12)
plt.ylabel('encode value', fontsize=12)
plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left') # 图例放在右侧
plt.grid(alpha=0.3)
plt.tight_layout() # 自动调整布局
plt.savefig('positional_encoding.png', # 文件名
dpi=300, # 分辨率(300dpi清晰)
bbox_inches='tight') # 确保图例不被截断
plt.show()
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/levsongsw/deep-learn-log.git
git@gitee.com:levsongsw/deep-learn-log.git
levsongsw
deep-learn-log
DeepLearnLog
master

搜索帮助