wangc_coder/resume — test.py (1.38 KB), committed by wangc_coder on 2022-06-20 17:36 +08:00
(This repository declares no open-source license file (LICENSE); check the project description and upstream dependencies before use.)
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

torch.manual_seed(1)

lstm = nn.LSTM(3, 3)  # input_size=3, hidden_size=3 for a single word

inputs = [torch.randn(1, 3) for _ in range(5)]  # five single-word inputs; not a mini-batch

hidden = (torch.randn(1, 1, 3),  # h_0 (initial hidden state), shape (num_layers * num_directions, batch, hidden_size)
          torch.randn(1, 1, 3))  # c_0 (initial cell state), same shape

for i in inputs:
    # Step through the sequence one element at a time; each "sequence" here holds just one word.
    # out has shape (seq_len, batch, num_directions * hidden_size): h_t from the last LSTM layer, for each t.
    # hidden = (h_n, c_n) at t = seq_len;
    # h_n and c_n each have shape (num_layers * num_directions, batch, hidden_size).
    out, hidden = lstm(i.view(1, 1, -1), hidden)
    print("i={}, out={}, hidden={}".format(i, out, hidden))

# Next, process all five words as a single sequence.
inputs = torch.cat(inputs).view(len(inputs), 1, -1)  # concatenate, then reshape the 2-D tensor to 3-D (seq_len, batch, input_size)
hidden = (torch.randn(1, 1, 3), torch.randn(1, 1, 3))  # fresh h_0, c_0
out, hn_cn = lstm(inputs, hidden)
print("\nout={}, hn_cn={}".format(out, hn_cn))
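As a sanity check (a sketch, not part of the original file): stepping the LSTM one timestep at a time while threading the hidden state through is mathematically equivalent to one call over the whole sequence, *provided both start from the same initial `(h_0, c_0)`*. The script above uses a fresh random state for the second call, so its outputs differ; with a shared initial state they match:

```python
import torch
import torch.nn as nn

torch.manual_seed(1)
lstm = nn.LSTM(3, 3)
inputs = [torch.randn(1, 3) for _ in range(5)]
h0c0 = (torch.randn(1, 1, 3), torch.randn(1, 1, 3))  # shared initial state

# Step-by-step, threading the hidden state through each call.
hidden = h0c0
step_outs = []
for i in inputs:
    out, hidden = lstm(i.view(1, 1, -1), hidden)
    step_outs.append(out)

# One call over the full sequence, starting from the identical initial state.
seq = torch.cat(inputs).view(len(inputs), 1, -1)  # (seq_len, batch, input_size)
full_out, full_hidden = lstm(seq, h0c0)

print(torch.allclose(torch.cat(step_outs), full_out, atol=1e-6))   # True
print(torch.allclose(hidden[0], full_hidden[0], atol=1e-6))        # True: h_n matches
print(torch.allclose(hidden[1], full_hidden[1], atol=1e-6))        # True: c_n matches
```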