代码拉取完成,页面将自动刷新
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
torch.manual_seed(1)
lstm = nn.LSTM(3, 3) # 一个词的 input_size, hidden_state_size
inputs = [torch.randn(1, 3) for _ in range(5)] # 定义LSTM的输入数据,此处不是mini batch
hidden = (torch.randn(1, 1, 3), # h_0(initial hidden state) of shape (num_layers * num_directions, batch, hidden_size)
torch.randn(1, 1, 3)) # c_0(initial cell state) of shape (num_layers * num_directions, batch, hidden_size)
for i in inputs:
# Step through the sequence one element at a time: 此处一个sequence中实际只有一个word
# out shape (seq_len, batch, num_directions * hidden_size): return (h_t) from the last layer of the LSTM, for each t
# hidden=(hn,cn) when t = seq_len
# h_n of shape (num_layers * num_directions, batch, hidden_size), c_n of shape (num_layers * num_directions, batch, hidden_size)
out, hidden = lstm(i.view(1, 1, -1), hidden)
print("i={},out={},hidden={}".format(i,out,hidden))
# 接下来,我们把5个单词全部放在一个sequence中进行处理
inputs = torch.cat(inputs).view(len(inputs), 1, -1) # 先转为ndarray,把二维张量转为三维张量
hidden = (torch.randn(1, 1, 3), torch.randn(1, 1, 3)) # clean out h0,c0
out, hn_cn = lstm(inputs, hidden)
print("\nout={},hn_cn={}".format(out,hn_cn))
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。