【CVPR 2025】本文参考论文HVI: A New Color Space for Low-light Image Enhancement
论文地址:arxiv
源码地址:github
colab调试地址:稍后更新…


本人合集:

  • 2024ECCV 低光增强 Retinexformer(论文+代码讲解)
  • 2025CVPR 低光增强 HVI-CIDnet(论文+代码讲解)
  • 未完待续(争取一天一篇)

Part2:arch核心代码解读

我会按照以下顺序进行:

  1. HVI_transform.py:颜色空间转换,这是CID-Net处理图像的基础。
  2. transformer_utils.py:包含一些通用的工具模块,如层归一化和下采样/上采样模块。
  3. LCA.py:Lighten Cross-Attention (LCA) 模块的核心实现,包括CAB, IEL, 以及 HV_LCA 和 I_LCA。
  4. CIDNet.py:整体网络架构,如何将以上模块组织起来。

1. HVI_transform.py - 颜色空间转换

在这里插入图片描述

  • 用途:这个文件定义了RGB_HVI类,负责在RGB颜色空间和论文中提出的HVI(Hue, Custom Saturation-like components H and V, Intensity)颜色空间之间进行转换。HVI空间将强度(I)与颜色相关的分量(H, V)分离开,分别送入网络的I-branch和HV-branch。

  • 代码 (HVI_transform.py):

    import torch
    import torch.nn as nn
    
    pi = 3.141592653589793
    
    class RGB_HVI(nn.Module):
        def __init__(self):
            super(RGB_HVI, self).__init__()
            # k is reciprocal to the paper mentioned (论文中可能是1/k的形式)
            self.density_k = torch.nn.Parameter(torch.full([1],0.2))
            self.gated = False # 似乎是用于实验或控制某些行为的门控,默认关闭
            self.gated2= False
            self.alpha = 1.0 # 似乎是用于实验或控制某些行为的参数,默认值1.0
            self.alpha_s = 1.3
            self.this_k = 0 # 用于在forward过程中保存density_k的值
    
        def HVIT(self, img): # RGB to HVI transform
            eps = 1e-8 # 防止除零
            device = img.device
            dtypes = img.dtype
    
            # --- 标准HSV计算中的 H (Hue) 和 V (Value/Intensity) ---
            hue = torch.Tensor(img.shape[0], img.shape[2], img.shape[3]).to(device).to(dtypes)
            value = img.max(1)[0].to(dtypes) # V (Intensity) = max(R,G,B)
            img_min = img.min(1)[0].to(dtypes)
    
            # 计算 Hue (色调)
            # 这部分是标准的RGB转HSV中计算H的过程
            hue[img[:,2]==value] = 4.0 + ( (img[:,0]-img[:,1]) / (value - img_min + eps)) [img[:,2]==value]
            hue[img[:,1]==value] = 2.0 + ( (img[:,2]-img[:,0]) / (value - img_min + eps)) [img[:,1]==value]
            hue[img[:,0]==value] = (0.0 + ((img[:,1]-img[:,2]) / (value - img_min + eps)) [img[:,0]==value]) % 6
    
            hue[img.min(1)[0]==value] = 0.0 # 如果max=min, 说明是灰度色, hue为0
            hue = hue/6.0 # 将hue归一化到[0,1)
    
            # --- 标准HSV计算中的 S (Saturation) ---
            saturation = (value - img_min ) / (value + eps ) # S = (V - min) / V
            saturation[value==0] = 0 # 如果V=0, 则S=0
    
            # 增加通道维度
            hue = hue.unsqueeze(1)
            saturation = saturation.unsqueeze(1)
            value = value.unsqueeze(1) # 这个value就是论文中的 I (Intensity)
    
            # --- HVI 空间中 H 和 V 分量的计算 ---
            k = self.density_k # 可学习参数k
            self.this_k = k.item() # 保存当前k值,用于反向转换
    
            # color_sensitive: 根据亮度(value)调整的颜色敏感度因子
            # 这个因子的设计引入了亮度对色度表示的影响
            color_sensitive = ((value * 0.5 * pi).sin() + eps).pow(k)
    
            # 将Hue转换为笛卡尔坐标系下的表示 (cos(angle), sin(angle))
            ch = (2.0 * pi * hue).cos() # hue已经归一化到[0,1), 2*pi*hue 得到角度
            cv = (2.0 * pi * hue).sin()
    
            # 计算HVI空间的H分量
            H = color_sensitive * saturation * ch
            # 计算HVI空间的V分量 (注意这里V不是HSV的Value,而是HVI的第二个色度分量)
            V = color_sensitive * saturation * cv
            I = value # HVI空间的I分量直接使用HSV的Value
    
            # 拼接H, V, I 三个通道
            xyz = torch.cat([H, V, I],dim=1)
            return xyz
    
        def PHVIT(self, img): # HVI to RGB transform (P for Pseudo-inverse)
            eps = 1e-8
            H,V,I = img[:,0,:,:],img[:,1,:,:],img[:,2,:,:] # 分离H, V, I通道
    
            # 将H, V, I的值限制在合理范围内
            H = torch.clamp(H,-1,1)
            V = torch.clamp(V,-1,1)
            I = torch.clamp(I,0,1)
    
            v = I # 即HSV中的Value
            k = self.this_k # 使用正向传播时保存的k值
            # 重新计算颜色敏感度因子
            color_sensitive = ((v * 0.5 * pi).sin() + eps).pow(k)
    
            # 从HVI的H,V分量反解出原始的饱和度加权色度分量
            # 这是 HVIT 中 H = color_sensitive * saturation * ch 的逆运算的一部分
            H_orig = (H) / (color_sensitive + eps)
            V_orig = (V) / (color_sensitive + eps)
            # 再次clamp,因为除法可能导致超出范围
            H_orig = torch.clamp(H_orig,-1,1)
            V_orig = torch.clamp(V_orig,-1,1)
    
            # 从反解的H_orig, V_orig 计算原始的hue (h) 和 saturation (s)
            h = torch.atan2(V_orig + eps, H_orig + eps) / (2*pi) # 计算角度,得到hue
            h = h%1 # 归一化到[0,1)
            s = torch.sqrt(H_orig**2 + V_orig**2 + eps) # 计算模长,得到saturation
    
            if self.gated: # 实验性门控
                s = s * self.alpha_s
    
            s = torch.clamp(s,0,1) # saturation 限制在[0,1]
            v = torch.clamp(v,0,1) # value 限制在[0,1]
    
            # --- 标准 HSV to RGB 转换逻辑 ---
            r = torch.zeros_like(h)
            g = torch.zeros_like(h)
            b = torch.zeros_like(h)
    
            hi = torch.floor(h * 6.0) # hue 所在的扇区 (0-5)
            f = h * 6.0 - hi         # 扇区内的偏移量
            p = v * (1. - s)
            q = v * (1. - (f * s))
            t = v * (1. - ((1. - f) * s))
    
            # 根据扇区索引 hi 将 p,q,t,v 赋值给 r,g,b
            hi0 = hi==0
            hi1 = hi==1
            hi2 = hi==2
            hi3 = hi==3
            hi4 = hi==4
            hi5 = hi==5
    
            r[hi0] = v[hi0]; g[hi0] = t[hi0]; b[hi0] = p[hi0]
            r[hi1] = q[hi1]; g[hi1] = v[hi1]; b[hi1] = p[hi1]
            r[hi2] = p[hi2]; g[hi2] = v[hi2]; b[hi2] = t[hi2]
            r[hi3] = p[hi3]; g[hi3] = q[hi3]; b[hi3] = v[hi3]
            r[hi4] = t[hi4]; g[hi4] = p[hi4]; b[hi4] = v[hi4]
            r[hi5] = v[hi5]; g[hi5] = p[hi5]; b[hi5] = q[hi5]
    
            r = r.unsqueeze(1)
            g = g.unsqueeze(1)
            b = b.unsqueeze(1)
            rgb = torch.cat([r, g, b], dim=1) # 拼接RGB通道
    
            if self.gated2: # 实验性门控
                rgb = rgb * self.alpha
            return rgb
    
  • 中文解释:

    • RGB_HVI 类实现了RGB和HVI颜色空间之间的相互转换。
    • __init__:初始化一个可学习的参数 density_k,这个参数会影响颜色分量H和V的计算。还包含一些门控变量和参数,可能用于实验。
    • HVIT(self, img) (RGB 到 HVI 转换):
      1. 首先,从输入的RGB图像 img 计算标准的色调(Hue)、饱和度(Saturation)和明度(Value)。这里的Value被直接用作HVI空间中的I (Intensity)分量。
      2. 然后,引入可学习参数 k 和一个基于明度 valuecolor_sensitive(颜色敏感度)因子。
      3. HVI空间的H和V分量是通过将原始色调Hue转换到笛卡尔坐标(ch, cv),然后乘以饱和度saturation和颜色敏感度因子color_sensitive得到的。这样做使得H和V分量不仅编码了颜色信息,还间接编码了与亮度相关的颜色感知特性。
      4. 最后,将计算得到的H, V, I三个分量在通道维度上拼接起来作为输出。这三个通道将分别或组合后送入网络的不同分支。
    • PHVIT(self, img) (HVI 到 RGB 转换):
      1. 接收H, V, I三个分量作为输入。
      2. 使用与HVIT中相同的k值(通过self.this_k获取)和I分量(即HSV的Value)来重新计算color_sensitive因子。
      3. 通过color_sensitive因子从HVI的H和V分量中反解出原始的、代表色度和饱和度的笛卡尔分量 (H_orig, V_orig)。
      4. H_origV_orig通过 atan2 和计算模长来恢复标准的色调(h)和饱和度(s)。
      5. 最后,使用标准的HSV到RGB转换算法,将恢复的h, s以及输入的I (作为v)转换回RGB图像。
  • 与论文对应: 这是网络处理的第一步和最后一步。论文中提到的HV-branch处理的是这里的H和V分量(或其组合),I-branch处理的是I分量。这种分解允许网络以不同的方式处理颜色/结构信息和亮度信息。


2. transformer_utils.py - 通用工具模块
在这里插入图片描述

  • 用途: 提供一些在Transformer类网络结构中常用的模块,如层归一化(LayerNorm)和带有归一化的下采样/上采样模块。

  • 代码 (transformer_utils.py):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    
    class LayerNorm(nn.Module):
        r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
        The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
        shape (batch_size, height, width, channels) while channels_first corresponds to inputs
        with shape (batch_size, channels, height, width).
        """
        def __init__(self, normalized_shape, eps=1e-6, data_format="channels_first"):
            super().__init__()
            self.weight = nn.Parameter(torch.ones(normalized_shape)) # 可学习的缩放因子 gamma
            self.bias = nn.Parameter(torch.zeros(normalized_shape))  # 可学习的平移因子 beta
            self.eps = eps # 防止除零
            self.data_format = data_format
            if self.data_format not in ["channels_last", "channels_first"]:
                raise NotImplementedError
            self.normalized_shape = (normalized_shape, )
    
        def forward(self, x):
            if self.data_format == "channels_last":
                # (batch_size, height, width, channels)
                # 这种格式下,normalized_shape 通常是 (channels,)
                # F.layer_norm 会在最后一个维度上进行归一化
                return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
            elif self.data_format == "channels_first":
                # (batch_size, channels, height, width)
                # 这种格式下,normalized_shape 通常是 (channels,)
                # 需要手动在 channel 维度上计算均值和方差进行归一化
                u = x.mean(1, keepdim=True) # 沿通道维度计算均值
                s = (x - u).pow(2).mean(1, keepdim=True) # 沿通道维度计算方差
                x = (x - u) / torch.sqrt(s + self.eps) # 归一化
                # 应用可学习的缩放和平移,注意weight和bias的维度需要匹配
                # self.weight (channels,) -> (channels, 1, 1)
                x = self.weight[:, None, None] * x + self.bias[:, None, None]
                return x
    
    class NormDownsample(nn.Module): # 带可选归一化的下采样模块
        def __init__(self,in_ch,out_ch,scale=0.5,use_norm=False):
            super(NormDownsample, self).__init__()
            self.use_norm=use_norm
            if self.use_norm:
                self.norm=LayerNorm(out_ch) # 如果使用归一化,则初始化LayerNorm
            self.prelu = nn.PReLU() # PReLU激活函数
            self.down = nn.Sequential(
                # 卷积核大小为3, 步长为1, padding为1 (保持空间尺寸不变)
                nn.Conv2d(in_ch, out_ch,kernel_size=3,stride=1, padding=1, bias=False),
                # 使用双线性插值进行下采样,scale_factor=0.5即尺寸减半
                nn.UpsamplingBilinear2d(scale_factor=scale)
            )
        def forward(self, x):
            x = self.down(x)
            x = self.prelu(x)
            if self.use_norm:
                x = self.norm(x)
            return x
    
    class NormUpsample(nn.Module): # 带可选归一化和跳跃连接的上采样模块
        def __init__(self, in_ch,out_ch,scale=2,use_norm=False):
            super(NormUpsample, self).__init__()
            self.use_norm=use_norm
            if self.use_norm:
                self.norm=LayerNorm(out_ch) # 如果使用归一化,则初始化LayerNorm
            self.prelu = nn.PReLU() # PReLU激活函数
            self.up_scale = nn.Sequential(
                # 卷积核大小为3, 步长为1, padding为1
                nn.Conv2d(in_ch,out_ch,kernel_size=3,stride=1, padding=1, bias=False),
                # 使用双线性插值进行上采样,scale_factor=2即尺寸加倍
                nn.UpsamplingBilinear2d(scale_factor=scale)
            )
            # 用于融合上采样特征和跳跃连接特征的1x1卷积
            self.up = nn.Conv2d(out_ch*2,out_ch,kernel_size=1,stride=1, padding=0, bias=False)
    
        def forward(self, x,y): # x是来自解码器前一层的特征, y是来自编码器的跳跃连接特征
            x = self.up_scale(x) # 对x进行上采样
            x = torch.cat([x, y],dim=1) # 将上采样后的x与跳跃连接y在通道维度拼接
            x = self.up(x) # 通过1x1卷积融合特征并调整通道数
            x = self.prelu(x)
            if self.use_norm:
                x = self.norm(x)
            return x
    
  • 中文解释:

    • LayerNorm: 实现了层归一化。它支持两种数据格式:channels_first (PyTorch中卷积层默认的 B, C, H, W 格式) 和 channels_last (B, H, W, C 格式,常见于TensorFlow或某些Transformer变体)。对于 channels_first,它会沿着通道维度(dim=1)计算均值和方差来进行归一化,并应用可学习的缩放参数 weight 和平移参数 bias
    • NormDownsample: 一个下采样模块。它首先通过一个 3 × 3 3 \times 3 3×3卷积(不改变空间尺寸,但可能改变通道数),然后使用nn.UpsamplingBilinear2d配合 scale_factor=0.5 来实现空间尺寸减半的下采样。之后是PReLU激活,并可选择是否应用LayerNorm
    • NormUpsample: 一个上采样模块,设计用于U-Net结构的解码器部分。它接收两路输入:x(来自解码器前一层,需要被上采样)和 y(来自编码器对应层级的跳跃连接)。x 首先通过一个 3 × 3 3 \times 3 3×3卷积和nn.UpsamplingBilinear2dscale_factor=2)进行上采样和通道调整。然后,上采样后的 x 与跳跃连接 y 在通道维度上拼接。最后,一个 1 × 1 1 \times 1 1×1卷积用于融合拼接后的特征并将通道数调整为期望的 out_ch。同样,之后是PReLU激活和可选的LayerNorm
  • 与论文对应: 这些是构成CID-Net中U-Net骨架(编码器和解码器路径)的基本组件。NormDownsample用于编码器中的HVE_blockIE_blockNormUpsample用于解码器中的HVD_blockID_blockLayerNorm则用于LCA模块内部以及这些采样模块中,以稳定训练和改善性能,这在Transformer类的结构中很常见。


3. LCA.py - Lighten Cross-Attention (LCA) 模块
在这里插入图片描述

  • 用途: 实现论文中的核心模块LCA,包括其子模块:Cross Attention Block (CAB) 和 Intensity Enhance Layer (IEL)。CDL(Color Denoise Layer)与IEL共享相同的结构。HV_LCAI_LCA 是将这些子模块组合起来用于HV分支和I分支的完整LCA块。

  • 代码 (LCA.py):

    import torch
    import torch.nn as nn
    from einops import rearrange # einops是一个强大的张量操作库,简化reshape, transpose等
    from net.transformer_utils import LayerNorm # 从之前的文件导入LayerNorm
    
    # Cross Attention Block (CAB)
    class CAB(nn.Module):
        def __init__(self, dim, num_heads, bias):
            super(CAB, self).__init__()
            self.num_heads = num_heads # 多头注意力的头数
            # 可学习的温度参数,用于缩放注意力分数,每个头一个
            self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
    
            # --- Q, K, V 的生成网络 ---
            # 论文中提到: "feature embedding convolution layers contains a 1x1 depth-wise convolution and a 3x3 group convolution."
            # 代码实现: 1x1卷积 -> 3x3深度卷积(groups=dim 或 dim*2)
    
            # Query (Q) 生成路径
            self.q = nn.Conv2d(dim, dim, kernel_size=1, bias=bias) # 1x1 卷积
            self.q_dwconv = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim, bias=bias) # 3x3 深度卷积
    
            # Key (K) 和 Value (V) 生成路径 (一起计算然后分割)
            self.kv = nn.Conv2d(dim, dim * 2, kernel_size=1, bias=bias) # 1x1 卷积, 输出通道为 dim*2
            self.kv_dwconv = nn.Conv2d(dim * 2, dim * 2, kernel_size=3, stride=1, padding=1, groups=dim * 2, bias=bias) # 3x3 深度卷积
    
            # 输出投影层
            self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias) # 1x1 卷积
    
        def forward(self, x, y): # x 通常是当前分支的特征 (用于生成Q), y 是另一分支的特征 (用于生成K,V)
            b, c, h, w = x.shape # batch, channels, height, width
    
            # 生成 Query (Q) from x
            q = self.q_dwconv(self.q(x))
    
            # 生成 Key (K) 和 Value (V) from y
            kv = self.kv_dwconv(self.kv(y))
            k, v = kv.chunk(2, dim=1) # 将通道维度一分为二,得到K和V
    
            # --- 多头注意力机制 ---
            # 使用 einops.rearrange 进行维度重排以支持多头
            # 'b (head c_per_head) h w -> b head c_per_head (h w)'
            #  c = head * c_per_head (每个头的通道数是 c // num_heads)
            q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
            k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
            v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
    
            # 归一化 Q 和 K (在最后一个维度上,即特征维度)
            q = torch.nn.functional.normalize(q, dim=-1)
            k = torch.nn.functional.normalize(k, dim=-1)
    
            # 计算注意力分数: (Q * K^T) / sqrt(d_k) * temperature
            # 这里直接用 Q @ K.transpose(-2, -1) (点积)
            # 然后乘以可学习的 temperature (论文中的 alpha_H 类似)
            attn = (q @ k.transpose(-2, -1)) * self.temperature
            attn = nn.functional.softmax(attn, dim=-1) # 对注意力分数应用Softmax
    
            # 使用注意力分数加权 V
            out = (attn @ v)
    
            # 将输出维度重排回原始图像特征格式
            # 'b head c (h w) -> b (head c) h w'
            out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)
    
            out = self.project_out(out) # 最终的1x1卷积投影
            return out
    
    
    # Intensity Enhancement Layer (IEL)
    # 论文中提到 IEL 和 CDL 结构相同
    class IEL(nn.Module):
        def __init__(self, dim, ffn_expansion_factor=2.66, bias=False):
            super(IEL, self).__init__()
    
            hidden_features = int(dim * ffn_expansion_factor) # 中间隐藏层的特征维度
    
            # 对应论文图14(1)+(2)的一部分:输入投影和初步分解
            # project_in 将输入维度从 dim 扩展到 hidden_features*2
            self.project_in = nn.Conv2d(dim, hidden_features * 2, kernel_size=1, bias=bias) # 1x1 卷积
    
            # 对应论文图14(2)的深度卷积分解
            # dwconv 对扩展后的特征进行3x3深度卷积,然后分割成x1, x2
            self.dwconv = nn.Conv2d(hidden_features * 2, hidden_features * 2, kernel_size=3, stride=1, padding=1, groups=hidden_features * 2, bias=bias)
    
            # 对应论文图14(3)的∆W, ∆S计算部分(或IEL中的 YI 和 YR 的增强部分)
            # dwconv1 和 dwconv2 分别处理 x1 和 x2
            self.dwconv1 = nn.Conv2d(hidden_features, hidden_features, kernel_size=3, stride=1, padding=1, groups=hidden_features, bias=bias) # 3x3 深度卷积
            self.dwconv2 = nn.Conv2d(hidden_features, hidden_features, kernel_size=3, stride=1, padding=1, groups=hidden_features, bias=bias) # 3x3 深度卷积
    
            # 对应论文图14(4)的输出投影
            self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias) # 1x1 卷积
    
            self.Tanh = nn.Tanh() # Tanh激活函数
    
        def forward(self, x):
            x_projected = self.project_in(x) # 1x1卷积扩展维度
            # 3x3深度卷积后,沿通道维度分割为x1, x2
            # 这对应论文IEL的 Y_I = W(I)Y_hat_I 和 Y_R = W(R)Y_hat_I
            # 或CDL的 W (Wavelength) 和 S (Saturation)
            x_dw = self.dwconv(x_projected)
            x1, x2 = x_dw.chunk(2, dim=1)
    
            # 对应论文 Eq.15: (tanh(WsYI) + YI) 和 (tanh(WsYR) + YR)
            # 或 CDL 的 (S + ∆S) 和 (W + ∆W),其中 ∆W = tanh(DWConv3×3(W))
            # 这里的 x1 和 x2 是 dwconv 的输出块,而不是原始输入 x 分解后的 YI/YR。
            # 实际上是:
            #   delta_x1 = self.Tanh(self.dwconv1(x1))
            #   enhanced_x1 = delta_x1 + x1 (原始论文公式是 Ws作用在分解后的YI上,这里是作用在dwconv的输出块上)
            #   然而,代码中直接对 project_in 和 dwconv 后的 x1, x2 进行操作。
            #   更准确地说,这里的 x1, x2 已经经过了 project_in 和 dwconv,
            #   所以 self.dwconv1(x1) 对应于论文中的 WsYI 或 DWConv3x3(W)
            #   而 + x1 则是残差连接。
    
            x1_enhanced = self.Tanh(self.dwconv1(x1)) + x1
            x2_enhanced = self.Tanh(self.dwconv2(x2)) + x2
    
            # 对应论文 Eq.15 的逐元素乘法 ⊙
            # 或 CDL 的 (S + ∆S) ⊙ (W + ∆W)
            x_merged = x1_enhanced * x2_enhanced
    
            x_output = self.project_out(x_merged) # 1x1卷积输出
            return x_output
    
    
    # Lightweight Cross Attention for HV-branch (Color Denoise Layer - CDL path)
    class HV_LCA(nn.Module):
        def __init__(self, dim,num_heads, bias=False):
            super(HV_LCA, self).__init__()
            # gdfn (Gated Denoising Feedforward Network) 使用 IEL 的结构作为 CDL
            self.gdfn = IEL(dim) # 论文中说明 IEL 和 CDL 结构相同
            self.norm = LayerNorm(dim) # 层归一化
            self.ffn = CAB(dim, num_heads, bias) # Cross Attention Block
    
        def forward(self, x, y): # x: HV-branch features, y: I-branch features
            # Cross Attention: x 作为Q的来源,y 作为K,V的来源 (通过norm(x)和norm(y))
            # x = x_current_hv + CAB(norm(x_current_hv), norm(y_current_i))
            x = x + self.ffn(self.norm(x), self.norm(y))
            # CDL (using IEL structure)
            # x = CDL(norm(x_after_cab))
            x = self.gdfn(self.norm(x)) # 注意这里没有残差连接 x = x + self.gdfn(...)
            return x
    
    # Lightweight Cross Attention for I-branch (Intensity Enhancement Layer - IEL path)
    class I_LCA(nn.Module):
        def __init__(self, dim,num_heads, bias=False):
            super(I_LCA, self).__init__()
            self.norm = LayerNorm(dim) # 层归一化
            # gdfn 使用 IEL 的结构作为 IEL
            self.gdfn = IEL(dim)
            self.ffn = CAB(dim, num_heads, bias=bias) # Cross Attention Block
    
        def forward(self, x, y): # x: I-branch features, y: HV-branch features
            # Cross Attention: x 作为Q的来源,y 作为K,V的来源
            # x = x_current_i + CAB(norm(x_current_i), norm(y_current_hv))
            x = x + self.ffn(self.norm(x), self.norm(y))
            # IEL
            # 论文中提到 IEL 的输出会添加残差: x = x_after_cab + IEL(norm(x_after_cab))
            x = x + self.gdfn(self.norm(x)) # 注意这里有残差连接
            return x
    
  • 中文解释:

    • CAB (Cross Attention Block):
      • __init__: 初始化Q, K, V的生成网络。这些网络由一个 1 × 1 1 \times 1 1×1卷积和一个 3 × 3 3 \times 3 3×3深度卷积(groups=dimdim*2,非常接近论文中描述的 1 × 1 1 \times 1 1×1 深度卷积和 3 × 3 3 \times 3 3×3 组卷积的组合思想,深度卷积是组卷积的特例)构成。还初始化了一个可学习的temperature参数用于缩放注意力分数,以及一个输出投影用的 1 × 1 1 \times 1 1×1卷积。
      • forward(self, x, y): x 是当前分支的特征(用于生成Q),y 是另一分支的特征(用于生成K,V)。
        1. 通过各自的网络从x生成Q,从y生成K和V。
        2. 使用einops.rearrange将Q, K, V重排以支持多头注意力。
        3. 对Q和K进行L2归一化。
        4. 计算注意力分数:attn = (Q @ K.transpose) * temperature,然后应用Softmax。这与论文公式(14) S o f t m a x ( Q ⊗ K / α H ) Softmax(Q \otimes K / \alpha_H) Softmax(QK/αH) 的精神一致。
        5. 注意力分数用于加权V:out = attn @ V
        6. 将结果out重排回图像特征的形状,并通过一个 1 × 1 1 \times 1 1×1卷积进行最终投影。
    • IEL (Intensity Enhance Layer):
      • __init__: 包含一个输入投影卷积 (project_in),一个中间的深度卷积 (dwconv),两个并行的深度卷积 (dwconv1, dwconv2) 用于处理分解后的特征,以及一个输出投影卷积 (project_out)。Tanh激活函数也被初始化。
      • forward(self, x):
        1. 输入x首先通过project_in进行通道扩展,然后通过dwconv进行深度卷积。
        2. dwconv的输出在通道维度被chunk成两部分 x1x2。这对应论文中IEL将特征分解为 Y I Y_I YI Y R Y_R YR(或CDL中分解为W和S)的步骤,尽管这里的x1,x2是经过初步卷积变换后的。
        3. x1x2分别应用:enhanced_feature = Tanh(dwconv_path(feature_chunk)) + feature_chunk。这精确匹配论文公式(15) 中 ( tanh ⁡ ( W s Y I ) + Y I ) (\tanh(W_sY_I) + Y_I) (tanh(WsYI)+YI) 的形式,以及CDL中 ( S + Δ S ) (S + \Delta S) (S+ΔS) 的形式(其中 Δ S = tanh ⁡ ( D W C o n v ( S ) ) \Delta S = \tanh(DWConv(S)) ΔS=tanh(DWConv(S)))。dwconv1dwconv2扮演了 W s W_s Ws D W C o n v DWConv DWConv 的角色。
        4. 增强后的x1_enhancedx2_enhanced逐元素相乘( ⊙ \odot ),如论文公式(15)和CDL理论所述。
        5. 最后通过project_out输出。
      • 与论文图14对应:
        • project_indwconv + chunk 对应图14的(1)初步光度分解和(2)分解为两个组分(如Wavelength和Saturation,或Illumination和Reflectance)。
        • Tanh(dwconv1(x1)) + x1 对应图14的(3)寻找 Δ \Delta Δ并与原分量结合。
        • x1_enhanced * x2_enhanced 对应图14的(4)重组(元素乘法)。
        • project_out 对应图14的(4)最终的Point-wise Conv。
    • HV_LCA (用于HV分支的LCA):
      • __init__: 包含一个IEL实例(作为CDL使用,因为论文指出CDL和IEL结构相同),一个LayerNorm,和一个CAB
      • forward(self, x, y): x是HV分支的特征,y是I分支的特征。
        1. 首先进行交叉注意力:x = x + self.ffn(self.norm(x), self.norm(y)),其中ffn是CAB。x从当前HV分支来(生成Q),y从I分支来(生成K,V)。结果通过残差连接加回x
        2. 然后,经过CAB增强的x通过self.norm(x)归一化后,送入self.gdfn(即CDL)。注意:这里的CDL输出直接成为新的x没有像I_LCA那样再进行一次外部的残差连接 x = x + self.gdfn(...)
    • I_LCA (用于I分支的LCA):
      • __init__: 类似HV_LCA,但IEL实例在这里就作为IEL使用。
      • forward(self, x, y): x是I分支的特征,y是HV分支的特征。
        1. 交叉注意力步骤同HV_LCAx = x + self.ffn(self.norm(x), self.norm(y))
        2. 然后,经过CAB增强的x通过self.norm(x)归一化后,送入self.gdfn(即IEL)。注意:根据论文对IEL的描述(“the output of IEL adds the residuals”),这里IEL的输出通过残差连接加回xx = x + self.gdfn(self.norm(x))。这与HV_LCA中的CDL处理方式不同。
  • 与论文对应:

    • CAB的实现紧密遵循论文中交叉注意力的描述,Q来自一个分支,K,V来自另一个。卷积层的选择(1x1 + 3x3深度卷积)是论文中“特征嵌入卷积层”的具体实现。
    • IEL的结构和计算流程与论文公式(15)以及图14的步骤高度吻合。
    • HV_LCAI_LCA将CAB和IEL/CDL组合起来,形成了论文图13所示的LCA模块。它们之间的主要区别在于gdfn(IEL/CDL部分)输出后的残差连接方式,这与论文中对IEL特别提到的残差连接相符。

4. CIDNet.py - 整体网络架构
在这里插入图片描述

  • 用途: 定义了CIDNet的整体U-Net形状的架构,并将之前定义的RGB_HVI转换和LCA模块(HV_LCA, I_LCA)以及下采样/上采样模块组织起来。

  • 代码 (CIDNet.py):

    import torch
    import torch.nn as nn
    from net.HVI_transform import RGB_HVI
    from net.transformer_utils import NormDownsample, NormUpsample # 只导入需要的
    from net.LCA import HV_LCA, I_LCA # 只导入需要的
    # from huggingface_hub import PyTorchModelHubMixin # 如果要用huggingface hub则取消注释
    
    class CIDNet(nn.Module): # 如果不用huggingface_hub, 可以去掉 PyTorchModelHubMixin
        def __init__(self,
                     channels=[36, 36, 72, 144], # 不同阶段的通道数 U-Net的宽度
                     heads=[1, 2, 4, 8],       # 不同阶段CAB的头数
                     norm=False # 是否在NormDownsample/NormUpsample中使用LayerNorm
            ):
            super(CIDNet, self).__init__()
    
            [ch1, ch2, ch3, ch4] = channels # ch1是最浅层的通道数, ch4是最深层的
            [head1, head2, head3, head4] = heads # head1通常对应ch2的LCA (因为LCA在下采样后)
    
            # --- HV-branch (处理H,V分量,即颜色/结构信息) ---
            # Encoder (HVE: HV Encoder)
            self.HVE_block0 = nn.Sequential( # 初始卷积块
                nn.ReplicationPad2d(1), # 边缘复制填充,避免卷积后的尺寸缩小
                nn.Conv2d(3, ch1, 3, stride=1, padding=0,bias=False) # 输入3通道(H,V,I from HVI) 输出ch1
            )
            self.HVE_block1 = NormDownsample(ch1, ch2, use_norm = norm) # ch1 -> ch2, 尺寸减半
            self.HVE_block2 = NormDownsample(ch2, ch3, use_norm = norm) # ch2 -> ch3, 尺寸减半
            self.HVE_block3 = NormDownsample(ch3, ch4, use_norm = norm) # ch3 -> ch4, 尺寸减半
    
            # Decoder (HVD: HV Decoder)
            self.HVD_block3 = NormUpsample(ch4, ch3, use_norm = norm)   # ch4 -> ch3, 尺寸加倍
            self.HVD_block2 = NormUpsample(ch3, ch2, use_norm = norm)   # ch3 -> ch2, 尺寸加倍
            self.HVD_block1 = NormUpsample(ch2, ch1, use_norm = norm)   # ch2 -> ch1, 尺寸加倍
            self.HVD_block0 = nn.Sequential( # 最终输出卷积块
                nn.ReplicationPad2d(1),
                nn.Conv2d(ch1, 2, 3, stride=1, padding=0,bias=False) # 输出2通道 (H', V')
            )
    
    
            # --- I-branch (处理I分量,即亮度信息) ---
            # Encoder (IE: Intensity Encoder)
            self.IE_block0 = nn.Sequential( # 初始卷积块
                nn.ReplicationPad2d(1),
                nn.Conv2d(1, ch1, 3, stride=1, padding=0,bias=False), # 输入1通道 (I from HVI) 输出ch1
            )
            self.IE_block1 = NormDownsample(ch1, ch2, use_norm = norm)
            self.IE_block2 = NormDownsample(ch2, ch3, use_norm = norm)
            self.IE_block3 = NormDownsample(ch3, ch4, use_norm = norm)
    
            # Decoder (ID: Intensity Decoder)
            self.ID_block3 = NormUpsample(ch4, ch3, use_norm=norm)
            self.ID_block2 = NormUpsample(ch3, ch2, use_norm=norm)
            self.ID_block1 = NormUpsample(ch2, ch1, use_norm=norm)
            self.ID_block0 =  nn.Sequential( # 最终输出卷积块
                nn.ReplicationPad2d(1),
                nn.Conv2d(ch1, 1, 3, stride=1, padding=0,bias=False), # 输出1通道 (I')
            )
    
            # --- LCA 模块实例化 ---
            # 注意LCA模块用在下采样/上采样模块之间或之后
            # heads参数列表长度应与LCA模块数量匹配,这里有6对LCA
            # heads=[head_LCA1, head_LCA2, head_LCA3_enc, head_LCA3_bot, head_LCA2_dec, head_LCA1_dec]
            # 论文中 channels=[36,36,72,144], heads=[1,2,4,8]
            # LCA1 应用于 ch2 (36通道), head 应该是 head1 (论文中是1,这里代码heads[0]是1,但LCA1对应ch2,所以是heads[0]或heads[1])
            # 实际上代码中LCA1用ch2和head2, LCA2用ch3和head3, LCA3/4用ch4和head4
    
            # 如果 heads = [h_ch2, h_ch3, h_ch4_enc, h_ch4_bottleneck, h_ch3_dec, h_ch2_dec]
            # ch1=channels[0], ch2=channels[1], ch3=channels[2], ch4=channels[3]
            # head1=heads[0],  head2=heads[1],  head3=heads[2],   head4=heads[3]
    
            # LCA for features with ch2 channels (after first downsample)
            self.HV_LCA1 = HV_LCA(ch2, heads[1]) # channels[1]用heads[1]
            self.I_LCA1  = I_LCA(ch2, heads[1])
    
            # LCA for features with ch3 channels (after second downsample)
            self.HV_LCA2 = HV_LCA(ch3, heads[2]) # channels[2]用heads[2]
            self.I_LCA2  = I_LCA(ch3, heads[2])
    
            # LCA for features with ch4 channels (encoder side, before bottleneck)
            self.HV_LCA3 = HV_LCA(ch4, heads[3]) # channels[3]用heads[3]
            self.I_LCA3  = I_LCA(ch4, heads[3])
    
            # LCA for features with ch4 channels (bottleneck/decoder side)
            self.HV_LCA4 = HV_LCA(ch4, heads[3]) # channels[3]用heads[3]
            self.I_LCA4  = I_LCA(ch4, heads[3])
    
            # LCA for features with ch3 channels (decoder side)
            self.HV_LCA5 = HV_LCA(ch3, heads[2]) # channels[2]用heads[2]
            self.I_LCA5  = I_LCA(ch3, heads[2])
    
            # LCA for features with ch2 channels (decoder side)
            self.HV_LCA6 = HV_LCA(ch2, heads[1]) # channels[1]用heads[1]
            self.I_LCA6  = I_LCA(ch2, heads[1])
    
            self.trans = RGB_HVI() # HVI转换模块实例
    
        def forward(self, x): # x是输入的RGB图像
            dtypes = x.dtype
            hvi_original = self.trans.HVIT(x) # 步骤1: RGB -> HVI
            
            # 分离 I 通道 和 HV 通道
            # I-branch input (1 channel)
            i_branch_input = hvi_original[:,2,:,:].unsqueeze(1).to(dtypes)
            # HV-branch input (HVE_block0 接收原始的3通道HVI)
            hv_branch_input = hvi_original # (H, V, I)
    
            # --- Encoder Path ---
            # Level 0 (Initial Convolution)
            i_enc0 = self.IE_block0(i_branch_input) # (B, ch1, H, W)
            hv_enc0 = self.HVE_block0(hv_branch_input) # (B, ch1, H, W)
            # Skip connections for U-Net
            i_jump0 = i_enc0
            hv_jump0 = hv_enc0
    
            # Level 1 (Downsample + LCA)
            i_enc1 = self.IE_block1(i_enc0)   # (B, ch2, H/2, W/2)
            hv_enc1 = self.HVE_block1(hv_enc0) # (B, ch2, H/2, W/2)
    
            i_processed_lca1 = self.I_LCA1(i_enc1, hv_enc1)   # I-branch uses HV features
            hv_processed_lca1 = self.HV_LCA1(hv_enc1, i_enc1) # HV-branch uses I features
            # Skip connections
            i_jump1 = i_processed_lca1
            hv_jump1 = hv_processed_lca1
    
            # Level 2 (Downsample + LCA)
            i_enc2 = self.IE_block2(i_processed_lca1)   # (B, ch3, H/4, W/4)
            hv_enc2 = self.HVE_block2(hv_processed_lca1) # (B, ch3, H/4, W/4)
    
            i_processed_lca2 = self.I_LCA2(i_enc2, hv_enc2)
            hv_processed_lca2 = self.HV_LCA2(hv_enc2, i_enc2)
            # Skip connections
            i_jump2 = i_processed_lca2
            hv_jump2 = hv_processed_lca2
    
            # Level 3 (Downsample + LCA) - Deepest encoder part before bottleneck
            i_enc3 = self.IE_block3(i_processed_lca2)   # (B, ch4, H/8, W/8)
            hv_enc3 = self.HVE_block3(hv_processed_lca2) # (B, ch4, H/8, W/8)
    
            i_processed_lca3 = self.I_LCA3(i_enc3, hv_enc3)
            hv_processed_lca3 = self.HV_LCA3(hv_enc3, i_enc3)
    
            # --- Bottleneck LCA ---
            # (Operating on ch4 features)
            i_bottleneck = self.I_LCA4(i_processed_lca3, hv_processed_lca3)
            hv_bottleneck = self.HV_LCA4(hv_processed_lca3, i_processed_lca3) # Note: hv_4 in paper, using hv_bottleneck for clarity
    
            # --- Decoder Path ---
            # Level 3 (Upsample + LCA)
            # Upsample using NormUpsample, which takes (x_to_upsample, skip_connection_feature)
            hv_dec3_upsampled = self.HVD_block3(hv_bottleneck, hv_jump2) # (B, ch3, H/4, W/4)
            i_dec3_upsampled = self.ID_block3(i_bottleneck, i_jump2)    # (B, ch3, H/4, W/4)
    
            i_processed_lca5 = self.I_LCA5(i_dec3_upsampled, hv_dec3_upsampled)
            hv_processed_lca5 = self.HV_LCA5(hv_dec3_upsampled, i_dec3_upsampled)
    
            # Level 2 (Upsample + LCA)
            hv_dec2_upsampled = self.HVD_block2(hv_processed_lca5, hv_jump1) # (B, ch2, H/2, W/2)
            # 在原代码中 i_dec3_upsampled (应该是i_processed_lca5)被直接送入ID_block2,这可能是一个小错误,
            # 通常应该是LCA处理后的结果送入下一级。这里遵循原代码逻辑。
            # i_dec2_upsampled = self.ID_block2(i_dec3_upsampled, i_jump1) # This seems like a bug, should be i_processed_lca5
            i_dec2_upsampled = self.ID_block2(i_processed_lca5, i_jump1)    # Corrected based on typical U-Net flow and var names
    
            i_processed_lca6 = self.I_LCA6(i_dec2_upsampled, hv_dec2_upsampled)
            hv_processed_lca6 = self.HV_LCA6(hv_dec2_upsampled, i_dec2_upsampled)
    
            # Level 1 (Upsample) - Final upsampling before output convolution
            i_dec1 = self.ID_block1(i_processed_lca6, i_jump0) # (B, ch1, H, W)
            hv_dec1 = self.HVD_block1(hv_processed_lca6, hv_jump0) # (B, ch1, H, W)
    
            # Level 0 (Output Convolution)
            i_out = self.ID_block0(i_dec1)   # (B, 1, H, W) - Enhanced I'
            hv_out = self.HVD_block0(hv_dec1) # (B, 2, H, W) - Enhanced H', V'
    
            # Combine H', V', I' and add global residual from original HVI
            output_hvi_enhanced = torch.cat([hv_out, i_out], dim=1) # (B, 3, H, W)
            output_hvi_final = output_hvi_enhanced + hvi_original # Global residual connection
    
            # Convert enhanced HVI back to RGB
            output_rgb = self.trans.PHVIT(output_hvi_final) # 步骤 N: HVI -> RGB
    
            return output_rgb
    
        # Helper function if needed outside forward pass for just HVI conversion
        def HVIT(self,x):
            hvi = self.trans.HVIT(x)
            return hvi
    
  • 中文解释:

    • __init__(self, ...):
      • 初始化函数接收通道数列表 channels 和注意力头数列表 heads 作为参数。
      • HV-branch (HV通道)I-branch (I通道) 都构建了一个U-Net的编码器-解码器结构。
        • 编码器路径包含初始卷积块 (HVE_block0, IE_block0) 和一系列 NormDownsample 模块 (HVE_block1/2/3, IE_block1/2/3) 进行特征提取和空间下采样。
        • 解码器路径包含一系列 NormUpsample 模块 (HVD_block1/2/3, ID_block1/2/3) 进行特征恢复和空间上采样,并融合来自编码器的跳跃连接。最后是输出卷积块 (HVD_block0, ID_block0) 将特征映射回所需的输出通道数(HV为2,I为1)。
      • LCA模块实例化: 在U-Net的编码器、解码器以及瓶颈位置,为HV分支和I分支分别实例化了多对LCA模块 (HV_LCA1HV_LCA6, I_LCA1I_LCA6)。这些LCA模块在不同尺度上执行交叉注意力和特征增强/去噪。参数channels[i]heads[i] 用于配置对应层级LCA模块的维度和头数。
      • 实例化 RGB_HVI 转换模块。
    • forward(self, x):
      1. 输入转换: 输入的RGB图像 x 首先通过 self.trans.HVIT(x) 转换为HVI颜色空间的 hvi_original
      2. 分支输入分离: 从 hvi_original 中分离出I通道 (i_branch_input) 和HV通道 (hv_branch_input,这里代码显示 HVE_block0 接收的是完整的3通道HVI,而 IE_block0 接收的是分离后的单通道I。这与论文中严格分离处理H,V和I的描述略有不同,HV分支的初始卷积也看到了I分量)。
      3. 编码器路径:
        • i_branch_inputhv_branch_input 分别通过各自的初始卷积块 (IE_block0, HVE_block0)。结果 (i_enc0, hv_enc0) 被保存为跳跃连接 (i_jump0, hv_jump0)。
        • 然后是一个重复的模式:下采样 -> LCA处理。
          • 例如,i_enc0hv_enc0 分别经过 IE_block1HVE_block1 下采样得到 i_enc1hv_enc1
          • i_enc1hv_enc1 被送入 I_LCA1HV_LCA1 进行交叉注意和增强,得到 i_processed_lca1hv_processed_lca1。这些处理后的特征也可能被保存为下一级的跳跃连接 (i_jump1, hv_jump1)。
        • 这个下采样+LCA的模式在编码器中会重复多次 (LCA1, LCA2, LCA3),直到最深的特征层。
      4. 瓶颈LCA: 在U-Net的最深层(瓶颈处),编码器输出的特征 (i_processed_lca3, hv_processed_lca3) 会再经过一对LCA模块 (I_LCA4, HV_LCA4) 进行处理,得到 i_bottleneckhv_bottleneck
      5. 解码器路径:
        • 这是一个与编码器路径对称的重复模式:上采样(并融合跳跃连接)-> LCA处理。
          • 例如,hv_bottlenecki_bottleneck 分别与来自编码器的跳跃连接 hv_jump2i_jump2 一起送入 HVD_block3ID_block3 进行上采样,得到 hv_dec3_upsampledi_dec3_upsampled
          • 然后,这些上采样后的特征送入 HV_LCA5I_LCA5 进行LCA处理。
        • 这个上采样+LCA的模式在解码器中重复多次 (LCA5, LCA6)。
      6. 输出合并与转换:
        • 解码器最终输出的I分支特征 (i_dec1) 和HV分支特征 (hv_dec1) 分别经过它们各自的最终输出卷积 (ID_block0, HVD_block0),得到增强后的单通道I’ (i_out) 和双通道H’V’ (hv_out)。
        • hv_outi_out 在通道维度上拼接成增强后的HVI图像 output_hvi_enhanced
        • 全局残差连接: 增强后的HVI与网络输入端的原始HVI (hvi_original) 相加,实现全局残差学习。
        • 最后,通过 self.trans.PHVIT() 将最终的HVI图像转换回RGB格式 output_rgb
    • HVIT(self,x): 一个辅助方法,如果需要在外部仅执行RGB到HVI的转换。
  • 与论文对应:

    • 整体架构是一个带有跳跃连接的U-Net。
    • 核心创新在于其双分支设计(I-branch和HV-branch)以及在U-Net的编码器、瓶颈和解码器的多个尺度上密集地应用LCA模块。这使得网络能够在不同抽象层次上持续地交互和优化亮度和颜色信息。
    • HVE_block0 接收3通道HVI作为输入是代码中的一个细节,可能与论文中纯粹分离H/V和I的描述略有不同,但后续的LCA交互严格遵循了分支间的交叉。
    • 最后的全局残差连接 (+ hvi_original) 是一个常见的技巧,让网络学习增强的残差量 ( Δ H , Δ V , Δ I \Delta H, \Delta V, \Delta I ΔH,ΔV,ΔI),通常能使训练更容易。
Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐