#!/usr/bin/env python
# -*- encoding: utf-8 -*-
'''
@File : ConvRNN.py
@Time : 2020/03/09
@Author : jhhuang96
@Mail : hjh096@126.com
@Version : 1.0
@Description: convrnn cell
'''
import torch
import torch.nn as nn
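# Tensor layout used throughout this module (inferred from the forward passes below):
# sequence inputs are (S, B, C, H, W); hidden states are (B, num_features, H, W).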
class CGRU_cell(nn.Module):
"""
ConvGRU Cell
"""
def __init__(self, shape, input_channels, filter_size, num_features):
super(CGRU_cell, self).__init__()
self.shape = shape
self.input_channels = input_channels
        # the kernel_size of the input-to-state convolution equals that of the state-to-state convolution
self.filter_size = filter_size
self.num_features = num_features
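        # same-style padding so the spatial size is preserved (filter_size is assumed odd)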
self.padding = (filter_size - 1) // 2
self.conv1 = nn.Sequential(
nn.Conv2d(self.input_channels + self.num_features,
2 * self.num_features, self.filter_size, 1,
self.padding),
nn.GroupNorm(2 * self.num_features // 32, 2 * self.num_features))
self.conv2 = nn.Sequential(
nn.Conv2d(self.input_channels + self.num_features,
self.num_features, self.filter_size, 1, self.padding),
nn.GroupNorm(self.num_features // 32, self.num_features))
def forward(self, inputs=None, hidden_state=None, seq_len=10):
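        # inputs: (S, B, C, H, W) or None (treated as zero input)
        # hidden_state: (B, num_features, H, W) or None (initialized to zeros)
        # returns: stacked hidden states (S, B, num_features, H, W) and the last hidden state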
# seq_len=10 for moving_mnist
if hidden_state is None:
htprev = torch.zeros(inputs.size(1), self.num_features,
self.shape[0], self.shape[1]).cuda()
else:
htprev = hidden_state
output_inner = []
for index in range(seq_len):
if inputs is None:
x = torch.zeros(htprev.size(0), self.input_channels,
self.shape[0], self.shape[1]).cuda()
else:
x = inputs[index, ...]
            combined_1 = torch.cat((x, htprev), 1)  # concatenate [X_t, H_{t-1}] along channels
            gates = self.conv1(combined_1)  # W * [X_t, H_{t-1}] -> update and reset gates
zgate, rgate = torch.split(gates, self.num_features, dim=1)
# zgate, rgate = gates.chunk(2, 1)
z = torch.sigmoid(zgate)
r = torch.sigmoid(rgate)
            combined_2 = torch.cat((x, r * htprev), 1)  # candidate: h' = tanh(W * [X_t, r * H_{t-1}])
ht = self.conv2(combined_2)
ht = torch.tanh(ht)
htnext = (1 - z) * htprev + z * ht
output_inner.append(htnext)
htprev = htnext
return torch.stack(output_inner), htnext
class CLSTM_cell(nn.Module):
"""ConvLSTMCell
"""
def __init__(self, shape, input_channels, filter_size, num_features):
super(CLSTM_cell, self).__init__()
self.shape = shape # H, W
self.input_channels = input_channels
self.filter_size = filter_size
self.num_features = num_features
# in this way the output has the same size
self.padding = (filter_size - 1) // 2
self.conv = nn.Sequential(
nn.Conv2d(self.input_channels + self.num_features,
4 * self.num_features, self.filter_size, 1,
self.padding),
nn.GroupNorm(4 * self.num_features // 32, 4 * self.num_features))
def forward(self, inputs=None, hidden_state=None, seq_len=10):
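        # inputs: (S, B, C, H, W) or None (treated as zero input)
        # hidden_state: tuple (hx, cx), each (B, num_features, H, W), or None (zeros)
        # returns: stacked hidden states (S, B, num_features, H, W) and the final (hy, cy)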
# seq_len=10 for moving_mnist
if hidden_state is None:
hx = torch.zeros(inputs.size(1), self.num_features, self.shape[0],
self.shape[1]).cuda()
cx = torch.zeros(inputs.size(1), self.num_features, self.shape[0],
self.shape[1]).cuda()
else:
hx, cx = hidden_state
output_inner = []
for index in range(seq_len):
if inputs is None:
x = torch.zeros(hx.size(0), self.input_channels, self.shape[0],
self.shape[1]).cuda()
else:
x = inputs[index, ...]
combined = torch.cat((x, hx), 1)
            gates = self.conv(combined)  # gates: (B, 4 * num_features, H, W)
            # split into the four gates: input, forget, cell candidate, output
ingate, forgetgate, cellgate, outgate = torch.split(
gates, self.num_features, dim=1)
ingate = torch.sigmoid(ingate)
forgetgate = torch.sigmoid(forgetgate)
cellgate = torch.tanh(cellgate)
outgate = torch.sigmoid(outgate)
cy = (forgetgate * cx) + (ingate * cellgate)
hy = outgate * torch.tanh(cy)
output_inner.append(hy)
hx = hy
cx = cy
return torch.stack(output_inner), (hy, cy)
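

# Minimal smoke test (a sketch, not part of the original file). It assumes a CUDA
# device is available, since both cells allocate their zero states with .cuda(),
# and uses example sizes (10-step sequences of 64x64 single-channel frames) chosen
# purely for illustration.
if __name__ == '__main__':
    seq_len, batch, channels, height, width = 10, 2, 1, 64, 64
    frames = torch.randn(seq_len, batch, channels, height, width).cuda()

    gru = CGRU_cell(shape=(height, width), input_channels=channels,
                    filter_size=5, num_features=64).cuda()
    gru_out, gru_h = gru(frames, hidden_state=None, seq_len=seq_len)
    print('CGRU outputs:', gru_out.shape, 'last hidden:', gru_h.shape)

    lstm = CLSTM_cell(shape=(height, width), input_channels=channels,
                      filter_size=5, num_features=64).cuda()
    lstm_out, (hy, cy) = lstm(frames, hidden_state=None, seq_len=seq_len)
    print('CLSTM outputs:', lstm_out.shape,
          'last hidden:', hy.shape, 'last cell:', cy.shape)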