多尺度特征融合模块学习1
CM-UNet: Hybrid CNN-Mamba UNet for Remote Sensing Image Semantic Segmentation学习
·
模块来源
[2405.10530] CM-UNet: Hybrid CNN-Mamba UNet for Remote Sensing Image Semantic Segmentation
模块名称
CM-UNet
模块作用
提取多尺度特征信息
模块结构
模块代码
class ChannelAttentionModule(nn.Module):
    """Channel attention gate built from pooled descriptors and a shared bottleneck.

    The input is squeezed to one value per channel twice: once with adaptive
    average pooling and once with adaptive max pooling (both with output
    size 1). Each descriptor is passed through the same two-layer stack of
    bias-free 1x1 convolutions, the two results are added, and a sigmoid maps
    the sum to attention weights.

    Args:
        in_channels: Number of channels of the incoming feature map.
        reduction: Divisor for the bottleneck width. The hidden width is
            ``in_channels // reduction``, clamped to a minimum of 1 so that
            ``in_channels < reduction`` does not create a zero-channel conv.
    """

    def __init__(self, in_channels, reduction=4):
        super(ChannelAttentionModule, self).__init__()
        # Guard against a zero-width bottleneck for small channel counts.
        hidden_channels = max(1, in_channels // reduction)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc = nn.Sequential(
            nn.Conv2d(in_channels, hidden_channels, 1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden_channels, in_channels, 1, bias=False),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return sigmoid attention weights with one value per channel.

        The spatial size of the result is 1x1 because both pooling layers
        are created with output size 1.
        """
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        # The bottleneck is shared, so both descriptors use the same weights.
        return self.sigmoid(avg_out + max_out)
class SpatialAttentionModule(nn.Module):
    """Spatial attention gate computed from channel-wise mean and max maps.

    The input is collapsed along dim 1 into a mean map and a max map, the two
    maps are stacked into a 2-channel tensor, and a single bias-free
    convolution followed by a sigmoid turns them into a 1-channel attention
    map.

    Args:
        kernel_size: Kernel size of the convolution; padding is
            ``kernel_size // 2``.
    """

    def __init__(self, kernel_size=7):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels=2,
            out_channels=1,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            bias=False,
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return the sigmoid-activated single-channel attention map for ``x``."""
        channel_mean = x.mean(dim=1, keepdim=True)
        channel_max = x.max(dim=1, keepdim=True).values
        # Stack the two summaries so the convolution sees them as 2 channels.
        pooled = torch.cat((channel_mean, channel_max), dim=1)
        return self.sigmoid(self.conv1(pooled))
class FusionConv(nn.Module):
    """Fuse three feature maps with multi-scale convolutions and attention.

    The three inputs are concatenated on dim 1 and projected to a reduced
    width by a 1x1 convolution. Two branches are then computed from that
    reduced tensor:

    * a channel branch, where the reduced tensor is scaled by channel
      attention;
    * a spatial branch, where 3x3, 5x5 and 7x7 convolutions are summed and
      the sum is scaled by spatial attention.

    The branches are added and projected to ``out_channels`` by a final
    1x1 convolution.

    Args:
        in_channels: Channel count of the concatenated input, i.e. the sum
            of the channels of ``x1``, ``x2`` and ``x4``.
        out_channels: Channel count of the output.
        factor: Divisor giving the reduced width,
            ``int(out_channels // factor)``.
    """

    def __init__(self, in_channels, out_channels, factor=4.0):
        super().__init__()
        reduced_channels = int(out_channels // factor)

        # Submodules are created in the original order so that parameter
        # registration and random initialisation stay the same.
        self.down = nn.Conv2d(
            in_channels, reduced_channels, kernel_size=1, stride=1
        )
        self.conv_3x3 = nn.Conv2d(
            reduced_channels, reduced_channels, kernel_size=3, stride=1, padding=1
        )
        self.conv_5x5 = nn.Conv2d(
            reduced_channels, reduced_channels, kernel_size=5, stride=1, padding=2
        )
        self.conv_7x7 = nn.Conv2d(
            reduced_channels, reduced_channels, kernel_size=7, stride=1, padding=3
        )
        self.spatial_attention = SpatialAttentionModule()
        self.channel_attention = ChannelAttentionModule(reduced_channels)
        self.up = nn.Conv2d(
            reduced_channels, out_channels, kernel_size=1, stride=1
        )
        # Not used by forward(); kept because it is part of the module's
        # registered parameters.
        self.down_2 = nn.Conv2d(
            in_channels, reduced_channels, kernel_size=1, stride=1
        )

    def forward(self, x1, x2, x4):
        """Return the fused feature map with ``out_channels`` channels."""
        reduced = self.down(torch.cat([x1, x2, x4], dim=1))

        # Channel branch: reweight the reduced features per channel.
        channel_branch = reduced * self.channel_attention(reduced)

        # Spatial branch: sum three receptive-field sizes, then reweight
        # per spatial location.
        multi_scale = (
            self.conv_3x3(reduced)
            + self.conv_5x5(reduced)
            + self.conv_7x7(reduced)
        )
        spatial_branch = multi_scale * self.spatial_attention(multi_scale)

        return self.up(spatial_branch + channel_branch)
class DownFusion(nn.Module):
    """Fuse two feature maps along the channel axis through a ``FusionConv``.

    Args:
        in_channels: Channel count expected by the fusion block, i.e. the
            sum of the channels of ``x1`` and ``x2``.
        out_channels: Channel count of the fused output.
    """

    def __init__(self, in_channels, out_channels):
        super(DownFusion, self).__init__()
        self.fusion_conv = FusionConv(in_channels, out_channels)
        # NOTE(review): CAM is constructed but was never applied in the
        # original forward(), which ended with the no-op ``x_fused = + x_fused``.
        # A residual attention term may have been intended there; confirm
        # against the reference implementation before wiring it in. The
        # attribute is kept so the module's parameter set is unchanged.
        self.CAM = ChannelAttentionModule(out_channels)

    def forward(self, x1, x2):
        """Return the fusion of ``x1`` and ``x2`` with ``out_channels`` channels.

        ``FusionConv.forward`` takes three tensors and concatenates them on
        dim 1 itself. The original code passed a single pre-concatenated
        tensor, which always raised ``TypeError``. A zero-channel slice fills
        the third slot, so the concatenation inside the fusion block is
        exactly ``cat([x1, x2], dim=1)``.
        """
        no_channels = x2[:, :0]
        return self.fusion_conv(x1, x2, no_channels)
class MSAA(nn.Module):
    """Multi-scale attention aggregation: a thin wrapper around ``FusionConv``.

    Args:
        in_channels: Channel count of the concatenated inputs given to the
            fusion block.
        out_channels: Channel count of the aggregated output.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.fusion_conv = FusionConv(in_channels, out_channels)

    def forward(self, x1, x2, x4, last=False):
        """Aggregate ``x1``, ``x2`` and ``x4`` into a single feature map.

        ``last`` is accepted for interface compatibility but is not used.
        """
        # Author's design note (translated): x2 runs from low to high and x4
        # from high to low; x2 carries semantic information while x4
        # supplements edge features. An earlier variant fused (x1, x2) and
        # (x1, x4) separately and summed the results; it is disabled here in
        # favour of one joint fusion.
        return self.fusion_conv(x1, x2, x4)
总结
CM-UNet 是一种利用最新 Mamba 架构进行遥感(RS)图像语义分割的高效框架。该设计采用新颖的 UNet 形结构,以应对大规模 RS 图像中目标变化显著的问题。编码器利用 ResNet 提取纹理信息,而解码器则使用 CSMamba 块来有效捕获全局长距离依赖关系。此外,CM-UNet 还集成了多尺度注意力聚合(MSAA)模块和多输出增强机制,以进一步支持多尺度特征学习。CM-UNet 已在三个 RS 语义分割数据集上得到验证,实验结果证明了该方法的优越性。
更多推荐
所有评论(0)