代码拉取完成,页面将自动刷新
import torch
import matplotlib.pyplot as plt
import numpy as np
def generate_positional_encoding(max_len, d_model, device='cpu'):
pe = torch.zeros(max_len, d_model, device=device)
position = torch.arange(0, max_len, device=device).unsqueeze(1) # shape: (max_len, 1)
div_term = torch.exp(torch.arange(0, d_model, 2, device=device) * (-np.log(10000.0) / d_model)) # shape: (d_model//2,)
pe[:, 0::2] = torch.sin(position * div_term) # 偶数索引:正弦
pe[:, 1::2] = torch.cos(position * div_term) # 奇数索引:余弦
return pe
# 设置参数(例如:序列长度50,特征维度128)
max_len = 50 # 序列长度(位置数量)
d_model = 128 # 特征维度
pe = generate_positional_encoding(max_len, d_model)
# 转换为numpy数组(便于绘图)
pe_np = pe.cpu().numpy()
# 创建画布
plt.figure(figsize=(12, 6))
# for i in range(0, 16, 2): # 每隔2取一个正弦维度(0, 2, 4...)
# plt.plot(pe_np[:, i], label=f'Sin Dim {i}', linestyle='-', alpha=0.7)
plt.plot(pe_np[:, 20], label=f'Sin Dim {20}', linestyle='-', alpha=0.7)
# for i in range(1, 17, 2): # 每隔2取一个余弦维度(1, 3, 5...)
# plt.plot(pe_np[:, i], label=f'Cos Dim {i}', linestyle='--', alpha=0.7)
plt.plot(pe_np[:, 21], label=f'Cos Dim {21}', linestyle='--', alpha=0.7)
# 添加标题和标签
plt.title('Position Encode Cos Sin', fontsize=14)
plt.xlabel('position index', fontsize=12)
plt.ylabel('encode value', fontsize=12)
plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left') # 图例放在右侧
plt.grid(alpha=0.3)
plt.tight_layout() # 自动调整布局
plt.savefig('positional_encoding.png', # 文件名
dpi=300, # 分辨率(300dpi清晰)
bbox_inches='tight') # 确保图例不被截断
plt.show()
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。