1. 引言

卷积神经网络(CNN)在计算机视觉领域取得了巨大成功,广泛应用于图像分类、目标检测、语义分割等任务。然而,CNN 存在天然的局限性:

  • 局部感受野:CNN 只能捕获局部特征,对长距离依赖关系建模能力不足。
  • 不变性不足:旋转、缩放、角度变化可能会影响模型的表现。
  • 计算资源消耗大:深度 CNN(如 ResNet-101、EfficientNet)参数量巨大,计算需求高。

解决方案?
近年来,研究人员提出了一系列方法来改进 CNN,包括:

  1. 注意力机制(Attention Mechanism):如 SE(Squeeze-and-Excitation)、CBAM(Convolutional Block Attention Module)、Transformer 结构等,使 CNN 具备全局感知能力。
  2. 混合架构(Hybrid Models):结合 CNN 和 Transformer 提取特征,提高任务表现。
  3. 轻量化 CNN 设计:如 MobileNetV3、EfficientNet,优化计算效率。

本教程将介绍如何结合注意力机制Transformer 结构,突破 CNN 的局限性,并提供代码示例。


2. 传统 CNN 的局限性

2.1 局部感受野问题

CNN 依靠卷积核(kernel)对局部区域进行特征提取,而卷积层的堆叠仅能逐层扩大感受野。例如:

import torch
import torch.nn as nn

conv_layer = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1)

此处 kernel_size=3,表示每个卷积核仅能看到 3x3 范围的像素,尽管深度 CNN 通过多层卷积扩大感受野,但仍然缺乏全局信息的建模能力。


3. 注意力机制如何提升 CNN?

3.1 SE(Squeeze-and-Excitation)注意力机制

SE 模块能够提升 CNN 对重要特征的关注度,通过全局池化计算通道间关系。其核心结构:

  1. Squeeze(全局平均池化):将特征图压缩成通道级别的向量。
  2. Excitation(通道权重调整):利用 MLP 学习每个通道的重要性。
  3. Recalibration(特征重新加权):调整特征图权重,增强重要特征。

代码实现 SE 模块:

class SEBlock(nn.Module):
    def __init__(self, channels, reduction=16):
        super(SEBlock, self).__init__()
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction),
            nn.ReLU(),
            nn.Linear(channels // reduction, channels),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.shape
        y = self.global_avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)
  • global_avg_pool(x) 提取全局信息。
  • fc(y) 通过 MLP 计算通道注意力权重。
  • x * y.expand_as(x) 重新加权特征图。

SE 模块可直接嵌入 CNN 结构,如 ResNet:

class SE_ResNet(nn.Module):
    def __init__(self):
        super(SE_ResNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        self.se = SEBlock(64)  # 添加 SE 注意力模块

    def forward(self, x):
        x = self.conv1(x)
        x = self.se(x)
        return x

3.2 CBAM(Convolutional Block Attention Module)注意力机制

CBAM 结合了 通道注意力(Channel Attention)空间注意力(Spatial Attention)

  1. 通道注意力:学习不同通道的重要性。
  2. 空间注意力:学习特征图中不同位置的重要性。

CBAM 代码实现:

class ChannelAttention(nn.Module):
    def __init__(self, channels, reduction=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction),
            nn.ReLU(),
            nn.Linear(channels // reduction, channels),
            nn.Sigmoid()
        )

    def forward(self, x):
        avg_out = self.fc(self.avg_pool(x).view(x.size(0), -1))
        max_out = self.fc(self.max_pool(x).view(x.size(0), -1))
        out = avg_out + max_out
        return x * out.view(x.size(0), x.size(1), 1, 1)

class SpatialAttention(nn.Module):
    def __init__(self):
        super(SpatialAttention, self).__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size=7, padding=3)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        out = torch.cat([avg_out, max_out], dim=1)
        out = self.sigmoid(self.conv(out))
        return x * out
  • ChannelAttention 通过通道池化(avg_pool & max_pool)计算重要性权重。
  • SpatialAttention 通过 7x7 卷积计算空间特征权重。

CBAM 适用于多个 CNN 结构,如 ResNet, EfficientNet

class CBAM_ResNet(nn.Module):
    def __init__(self):
        super(CBAM_ResNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        self.ca = ChannelAttention(64)
        self.sa = SpatialAttention()

    def forward(self, x):
        x = self.conv1(x)
        x = self.ca(x)
        x = self.sa(x)
        return x

4. Transformer 如何进一步提升 CNN?

CNN 在提取局部特征时表现良好,而 Transformer 适用于全局信息建模。近年来,一些混合架构(Hybrid Models)将 CNN 与 Transformer 结合,例如:

  • Swin Transformer:使用滑动窗口(Shifted Window)机制,兼具 CNN 和 Transformer 优势。
  • ConvNeXt:基于 CNN 但借鉴 Transformer 设计,提升全局特征建模能力。

如何替换 CNN Backbone 为 Swin Transformer?

from timm.models.swin_transformer import swin_tiny_patch4_window7_224

class SwinCNN(nn.Module):
    def __init__(self):
        super(SwinCNN, self).__init__()
        self.backbone = swin_tiny_patch4_window7_224(pretrained=True)

    def forward(self, x):
        return self.backbone(x)
  • swin_tiny_patch4_window7_224(pretrained=True) 预训练的 Swin Transformer,可用于分类和检测任务。

5. 结论

  • CNN 仍是计算机视觉任务的主流,但局限性明显(局部感受野、计算成本高)
  • 结合注意力机制(SE、CBAM)可有效提升 CNN 的特征建模能力
  • 混合架构(CNN+Transformer)是未来趋势,提高长距离依赖建模能力

🚀 希望本教程对你有所帮助!

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐