U-Net: Convolutional Networks for Biomedical Images Segmentation
paper:U-Net: Convolutional Networks for Biomedical Image Segmentation以MMSegmentation中unet的实现为例,假设batch_size=4,输入shape为(4, 3, 480, 480)。
·
paper: U-Net: Convolutional Networks for Biomedical Image Segmentation
创新点
- 提出了U型encoder-decoder的网络结构,通过skip-connection操作更好的融合浅层的位置信息和深层的语义信息。U-Net借鉴FCN采用全卷积的结构,相比于FCN一个重要的改变是在上采样部分也有大量的特征通道,这允许网络将上下文信息传播到更高分辨率的层。
- 医疗图像分割的任务,训练数据非常少,作者通过应用弹性形变做了大量的数据增强。
- 提出使用加权损失。
一些需要注意的实现细节
- 原论文实现中没有使用padding,因此输出feature map的分辨率逐渐减小,在下面介绍的mmsegmentation的实现中采用了padding,因此当stride=1时输出特征图的分辨率不变。
- FCN中skip-connection融合浅层信息与深层信息是通过add的方式,而U-Net中是通过concatenate的方式.
实现细节解析
以MMSegmentation中unet的实现为例,假设batch_size=4,输入shape为(4, 3, 480, 480)。
Backbone
- encode阶段共5个stage,每个stage中有一个ConvBlock,ConvBlock由2个Conv-BN-Relu组成。除了第1个stage,后4个stage在ConvBlock前都有1个2x2-s2的maxpool。每个stage的第1个conv的输出通道x2。因此encode阶段每个stage的输出shape分别为(4, 64, 480, 480)、(4, 128, 240, 240)、(4, 256, 120, 120)、(4, 512, 60, 60)、(4, 1024, 30, 30)。
- decode阶段共4个stage,和encode后4个降采样的stage对应。每个stage分为upsample、concatenate、conv三个步骤。upsample由一个scale_factor=2的bilinear插值和1个Conv-BN-Relu组成,其中的conv是1x1-s1通道数减半的卷积。第二步concatenate将upsample的输出与encode阶段分辨率大小相同的输出沿通道方向拼接到一起。第三步是一个ConvBlock,和encode阶段一样,这里的ConvBlock也由两个Conv-BN-Relu组成,因为upsample后通道数减半,但和encode对应输出拼接后通道数又还原回去了,这里的ConvBlock中的第一个conv再将输出通道数减半。因此decode阶段每个stage的输出shape分别为(4, 1024, 30, 30)、(4, 512, 60, 60)、(4, 256, 120, 120)、(4, 128 , 240, 240)、(4, 64, 480, 480)。注意decode共4个stage,因此实际的输出是后4个,第一个输出就是encode最后一个stage的输出。
FCN Head
- backbone中decode阶段的最后一个stage的输出(4, 64, 480, 480)作为head的输入。首先经过一个3x3-s1的conv-bn-relu,通道数不变。然后经过ratio=0.1的dropout。最后经过一个1x1的conv得到模型最终的输出,输出通道数为类别数(包含背景)。
Loss
- loss采用cross-entropy loss
Auxiliary Head
- backbone中decode阶段的倒数第二个stage的输出(4, 128, 240, 240)作为auxiliary head的输入。经过一个3x3-s1的conv-bn-relu,输出通道数减半为64。经过ratio=0.1的dropout。最后经过一个1x1的conv得到模型最终的输出,输出通道数为类别数(包含背景)。
- 辅助分支的Loss也是cross-entropy loss,注意这个分支的最终输出分辨率为原始gt的一半,因此在计算loss时需要先通过双线性插值上采样。
模型的完整结构
EncoderDecoder(
(backbone): UNet(
(encoder): ModuleList(
(0): Sequential(
(0): BasicConvBlock(
(convs): Sequential(
(0): ConvModule(
(conv): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): _BatchNormXd(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
(1): ConvModule(
(conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): _BatchNormXd(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
)
)
)
(1): Sequential(
(0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(1): BasicConvBlock(
(convs): Sequential(
(0): ConvModule(
(conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): _BatchNormXd(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
(1): ConvModule(
(conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): _BatchNormXd(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
)
)
)
(2): Sequential(
(0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(1): BasicConvBlock(
(convs): Sequential(
(0): ConvModule(
(conv): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): _BatchNormXd(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
(1): ConvModule(
(conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): _BatchNormXd(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
)
)
)
(3): Sequential(
(0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(1): BasicConvBlock(
(convs): Sequential(
(0): ConvModule(
(conv): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): _BatchNormXd(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
(1): ConvModule(
(conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): _BatchNormXd(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
)
)
)
(4): Sequential(
(0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(1): BasicConvBlock(
(convs): Sequential(
(0): ConvModule(
(conv): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): _BatchNormXd(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
(1): ConvModule(
(conv): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): _BatchNormXd(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
)
)
)
)
(decoder): ModuleList(
(0): UpConvBlock(
(conv_block): BasicConvBlock(
(convs): Sequential(
(0): ConvModule(
(conv): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): _BatchNormXd(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
(1): ConvModule(
(conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): _BatchNormXd(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
)
)
(upsample): InterpConv(
(interp_upsample): Sequential(
(0): Upsample()
(1): ConvModule(
(conv): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): _BatchNormXd(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
)
)
)
(1): UpConvBlock(
(conv_block): BasicConvBlock(
(convs): Sequential(
(0): ConvModule(
(conv): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): _BatchNormXd(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
(1): ConvModule(
(conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): _BatchNormXd(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
)
)
(upsample): InterpConv(
(interp_upsample): Sequential(
(0): Upsample()
(1): ConvModule(
(conv): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): _BatchNormXd(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
)
)
)
(2): UpConvBlock(
(conv_block): BasicConvBlock(
(convs): Sequential(
(0): ConvModule(
(conv): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): _BatchNormXd(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
(1): ConvModule(
(conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): _BatchNormXd(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
)
)
(upsample): InterpConv(
(interp_upsample): Sequential(
(0): Upsample()
(1): ConvModule(
(conv): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): _BatchNormXd(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
)
)
)
(3): UpConvBlock(
(conv_block): BasicConvBlock(
(convs): Sequential(
(0): ConvModule(
(conv): Conv2d(1024, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): _BatchNormXd(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
(1): ConvModule(
(conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): _BatchNormXd(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
)
)
(upsample): InterpConv(
(interp_upsample): Sequential(
(0): Upsample()
(1): ConvModule(
(conv): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): _BatchNormXd(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
)
)
)
)
)
init_cfg=[{'type': 'Kaiming', 'layer': 'Conv2d'}, {'type': 'Constant', 'val': 1, 'layer': ['_BatchNorm', 'GroupNorm']}]
(decode_head): FCNHead(
input_transform=None, ignore_index=255, align_corners=False
(loss_decode): CrossEntropyLoss(avg_non_ignore=False)
(conv_seg): Conv2d(64, 2, kernel_size=(1, 1), stride=(1, 1))
(dropout): Dropout2d(p=0.1, inplace=False)
(convs): Sequential(
(0): ConvModule(
(conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): _BatchNormXd(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
)
)
init_cfg={'type': 'Normal', 'std': 0.01, 'override': {'name': 'conv_seg'}}
(auxiliary_head): FCNHead(
input_transform=None, ignore_index=255, align_corners=False
(loss_decode): CrossEntropyLoss(avg_non_ignore=False)
(conv_seg): Conv2d(64, 2, kernel_size=(1, 1), stride=(1, 1))
(dropout): Dropout2d(p=0.1, inplace=False)
(convs): Sequential(
(0): ConvModule(
(conv): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): _BatchNormXd(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(activate): ReLU(inplace=True)
)
)
)
init_cfg={'type': 'Normal', 'std': 0.01, 'override': {'name': 'conv_seg'}}
)
更多推荐
所有评论(0)