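# [Web page notice captured with the source; not part of the program]
# Code pull complete; the page will refresh automatically.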
import numpy as np
import torch
from torch import nn
from torch.nn import init
def autopad(k, p=None, d=1):  # kernel, padding, dilation
    """Pad to 'same' shape outputs.

    Args:
        k: Kernel size, either an int or a sequence of ints.
        p: Explicit padding. When it is not ``None`` it is returned unchanged.
        d: Dilation. Values above 1 enlarge the kernel to its dilated extent
            before the padding is derived.

    Returns:
        ``p`` when it was supplied; otherwise half of the (dilated) kernel
        size, as an int for an int ``k`` and as a list for a sequence ``k``.
    """
    if d > 1:
        # Actual kernel size once the gaps introduced by dilation are counted.
        k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k]
    if p is None:
        # Auto-pad: half of the kernel size, rounded down.
        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]
    return p
class Flatten(nn.Module):
    """Collapse every dimension after the first one into a single dimension."""

    def forward(self, x):
        """Return ``x`` reshaped to ``(x.shape[0], -1)``.

        ``reshape`` is used rather than ``view`` so that non-contiguous inputs
        (for example transposed or permuted tensors) are accepted; for
        contiguous inputs the result is the same as with ``view``.
        """
        return x.reshape(x.shape[0], -1)
class ChannelAttention(nn.Module):
    """Channel branch: global average pooling followed by a bottleneck MLP.

    Args:
        channel: Number of input channels; also the width of the output layer.
        reduction: Divisor applied to ``channel`` to obtain the hidden width.
        num_layers: Number of hidden ``Linear -> BatchNorm1d -> ReLU`` stages.
    """

    def __init__(self, channel, reduction=16, num_layers=3):
        super().__init__()
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        # Keep at least one hidden unit so the bottleneck does not collapse to
        # zero width when channel < reduction.
        hidden = max(channel // reduction, 1)
        gate_channels = [channel] + [hidden] * num_layers + [channel]
        self.ca = nn.Sequential()
        # Built-in flatten: (N, C, 1, 1) -> (N, C); it holds no parameters.
        self.ca.add_module('flatten', nn.Flatten())
        # One stage per hidden layer; the final projection is added afterwards.
        for i, (n_in, n_out) in enumerate(zip(gate_channels[:-2], gate_channels[1:-1])):
            self.ca.add_module('fc%d' % i, nn.Linear(n_in, n_out))
            self.ca.add_module('bn%d' % i, nn.BatchNorm1d(n_out))
            self.ca.add_module('relu%d' % i, nn.ReLU())
        self.ca.add_module('last_fc', nn.Linear(gate_channels[-2], gate_channels[-1]))

    def forward(self, x):
        """Return per-channel responses expanded to the shape of ``x``.

        ``x`` is expected to be a 4-D feature map whose second dimension has
        ``channel`` entries; the pooled descriptor goes through the MLP and is
        then repeated over the two trailing dimensions.
        """
        res = self.avgpool(x)
        res = self.ca(res)
        # Re-attach two trailing singleton dims so the result expands over x.
        return res.unsqueeze(-1).unsqueeze(-1).expand_as(x)
class SpatialAttention(nn.Module):
    """Spatial branch: 1x1 reduction, dilated 3x3 convolutions, 1x1 projection.

    Args:
        channel: Number of input channels.
        reduction: Divisor applied to ``channel`` to obtain the reduced width.
        num_layers: Number of dilated ``Conv2d -> BatchNorm2d -> ReLU`` stages.
        dia_val: Dilation used by every 3x3 convolution.
    """

    def __init__(self, channel, reduction=16, num_layers=3, dia_val=2):
        super().__init__()
        # Keep at least one reduced channel so the bottleneck does not collapse
        # to zero width when channel < reduction.
        mid = max(channel // reduction, 1)
        self.sa = nn.Sequential()
        self.sa.add_module('conv_reduce1',
                           nn.Conv2d(kernel_size=1, in_channels=channel, out_channels=mid))
        self.sa.add_module('bn_reduce1', nn.BatchNorm2d(mid))
        self.sa.add_module('relu_reduce1', nn.ReLU())
        for i in range(num_layers):
            # autopad picks the padding for the dilated 3x3 kernel.
            self.sa.add_module('conv_%d' % i,
                               nn.Conv2d(kernel_size=3, in_channels=mid, out_channels=mid,
                                         padding=autopad(3, None, dia_val), dilation=dia_val))
            self.sa.add_module('bn_%d' % i, nn.BatchNorm2d(mid))
            self.sa.add_module('relu_%d' % i, nn.ReLU())
        # Project down to a single-channel map.
        self.sa.add_module('last_conv', nn.Conv2d(mid, 1, kernel_size=1))

    def forward(self, x):
        """Return the single-channel response map expanded to the shape of ``x``."""
        res = self.sa(x)
        return res.expand_as(x)
class BAMBlock(nn.Module):
    """Combine the channel and spatial branches into one residual attention gate.

    Args:
        channel: Number of channels of the input feature map.
        reduction: Bottleneck divisor forwarded to both branches.
        dia_val: Dilation forwarded to the spatial branch.
    """

    def __init__(self, channel=512, reduction=16, dia_val=2):
        super().__init__()
        self.ca = ChannelAttention(channel=channel, reduction=reduction)
        self.sa = SpatialAttention(channel=channel, reduction=reduction, dia_val=dia_val)
        self.sigmoid = nn.Sigmoid()

    def init_weights(self):
        """Re-initialise all conv, batch-norm and linear sub-modules in place.

        This is not invoked by ``__init__``; call it explicitly when wanted.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
                # BatchNorm1d is included because the channel branch uses it.
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x):
        """Return ``(1 + sigmoid(sa(x) + ca(x))) * x``.

        Raises:
            ValueError: If ``x`` is not a 4-D tensor.
        """
        if x.dim() != 4:
            raise ValueError('BAMBlock expects a 4-D input, got %d-D' % x.dim())
        sa_out = self.sa(x)
        ca_out = self.ca(x)
        weight = self.sigmoid(sa_out + ca_out)
        # Residual form: the gate scales x and the result is added back to x.
        return (1 + weight) * x
if __name__ == '__main__':
    # Smoke test: push a random batch through the block and print the shape.
    # The tensor is named `x` to avoid shadowing the builtin `input`.
    x = torch.randn(50, 512, 7, 7)
    bam = BAMBlock(channel=512, reduction=16, dia_val=2)
    output = bam(x)
    print(output.shape)
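# [Web page notice captured with the source; not part of the program]
# Content here may be unsuitable for display and is not shown by the page. You can review and modify it with the relevant editing functions.
# If you confirm the content does not involve inappropriate language / pure advertising / violence / vulgar or pornographic material / infringement / piracy / false or worthless content, or content that violates national laws and regulations, you can click submit to appeal, and it will be handled as soon as possible.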