手撕VQVAE(向量量化变分自编码器) – Day2 – VAVAE.py

Vector Quantized Variational AutoEncoder (VQVAE) 模型原理

VQVAE 网络结构图

在这里插入图片描述

在这里插入图片描述

向量量化变分自编码器 网络结构

和AE以及VAE的区别

  • Part1:隐空间(latent)不是直接利用神经网络生成的,也不是是利用神经网络生成均值和方差来采样生成隐空间,而是先利用编码器进行图像编码,得到隐藏间的向量(用于查询),然后到一个向量字典(图中的codebook)中进行查询与其最接近的向量,作为最终的编码向量

    • 例如:原本是直接生成(1,10)的向量,或者是生成10个均值 u u u和10个标准差 σ \sigma σ(满足正态分布),然后通过采样生成10个向量值,但是现在是利用生成的10维向量,到字典中查询到最接近的那个10维向量,作为目标的隐空间编码向量
    • 存在的问题查询,并取最接近的向量的操作是无法进行梯度反向传播的,因此需要利用直通量化操作(Straight-Through Estimator (STE))进行反向传播,来进行编码器的参数更新。
  • Part2直通量化技巧:设编码器编码得到的隐空间向量为encode_z,查询得到的最接近的隐空间向量为z_q,那么利用公式如下,即可完成梯度的反向传播,更新编码器参数。

    • z q = e n c o d e z + ( z q − e n c o d e z ) . d e t a c h ( ) z_q=encode_z+(z_q-encode_z).detach() zq=encodez+(zqencodez).detach()

    • 这里将得到的 z q z_q zq作为编码器的输入,那么这里就将后续解码器的梯度利用这个公式以梯度为1,传递到编码器

  • Part3:损失函数去掉了KL散度(因为没有正态分布的采样了,因此VQVAE其实也失去了生成的能力,主要起编码的作用),这里添加了两个损失,分别为量化损失承诺损失,两者的共同点是都是衡量编码器和查询后得到的向量的距离,差别是更新谁的参数,量化损失更新查询字典的参数承诺损失更新编码器的参数。(个人记法就是,承诺,意味金标准,因此是以查询的那个字典作为标准来更新编码器的参数而量化则意味着图像量化为向量作为标准,那就更新字典的参数

    • VAE(变分自编码器)的损失函数由两部分组成:重构损失KL 散度

      1. 重构损失(Reconstruction Loss)

      重构损失衡量了生成的数据 ( \hat{x} ) 与输入数据 ( x ) 之间的差异,通常使用 均方误差(MSE)交叉熵 来衡量。对于连续数据,常用 MSE,对于离散数据,常用交叉熵。
      L reconstruction = E q ( z ∣ x ) [ ∥ x − x ^ ∥ 2 ] \mathcal{L}_{\text{reconstruction}} = \mathbb{E}_{q(z|x)}[\| x - \hat{x} \|^2] Lreconstruction=Eq(zx)[xx^2]
      这里:

      • ( x ^ \hat{x} x^ = p(x|z) ) 是解码器生成的样本。
      • ( x ) 是原始输入数据。
      • ( ∥ x − x ^ ∥ 2 \| x - \hat{x} \|^2 xx^2 ) 是输入数据与生成数据之间的均方误差。

      2. 向量量化损失(Vector Quantization Loss)- 更新字典参数

      • 向量量化损失用于衡量编码器输出的特征与字典向量之间的差异。其目标是最小化每个编码的特征与最接近的字典向量之间的欧几里得距离。该损失由以下公式表示:
        L v q = ∥ sg ( e ( x ) ) − z e ∥ 2 L_{vq} = \| \text{sg}(e(x)) - z_e \|^2 Lvq=sg(e(x))ze2
        这里:

        • e ( x ) e(x) e(x) 是编码器输出的特征(编码后的表示)
        • z e z_e ze 是量化后的向量
        • ( sg ( ⋅ ) \text{sg}(\cdot) sg() ) 表示 stop gradient 操作,确保梯度不通过量化操作传播。

      3. 承诺损失(Commitment Loss)- 更新编码器参数

      • 承诺损失用于鼓励编码器的输出尽量接近其最近的字典向量。其目标是最小化编码器输出与其最接近字典向量之间的差距。该损失由以下公式表示:
        L c o m m i t = β ∥ e ( x ) − s g ( z e ) ∥ 2 L_{commit} = \beta \| e(x) - sg(z_e) \|^2 Lcommit=βe(x)sg(ze)2
        这里:

        • β \beta β 是一个超参数,通常为一个较小的值,控制承诺损失的权重,
        • e ( x ) e(x) e(x) 是编码器输出的特征,
        • z e z_e ze 是量化后的字典向量。

VQVAE.py代码 - 向量量化变分自编码器模块

VQVAE代码

Part1 库函数

# 该模块主要是为了实现VAE模型的,

'''
# Part1 引入相关的库函数
'''
import torch
from torch import nn
from dataset import Mnist_dataset

Part2 初始化一个简化的残差类 - 作为图像编码和解码的一部分

'''
# Part2 设计残差的类函数
'''
class ResidualBlock(nn.Module):

    def __init__(self, dim):
        super().__init__()
        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(dim, dim, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim, dim, 1)

    def forward(self, x):
        tmp = self.relu(x)
        tmp = self.conv1(tmp)
        tmp = self.relu(tmp)
        tmp = self.conv2(tmp)
        return x + tmp

Part3 VQVAE类

'''
# Part3 设计VQVAE的类函数
'''
class VQVAE(nn.Module):
    def __init__(self, img_channel, encode_f1_channel, latent_size, num_emd):
        super().__init__()
        # VQVAE的编码器,一般是先经过几个卷积,然后最后通过reshape或者view来铺平。
        self.encode = nn.Sequential(
            nn.Conv2d(in_channels=img_channel, out_channels=encode_f1_channel, kernel_size=4, stride=2, padding=1),
            # 卷积层
            nn.ReLU(),  # ReLU激活函数
            nn.Conv2d(in_channels=encode_f1_channel, out_channels=encode_f1_channel, kernel_size=4, stride=2,
                      padding=1),  # 卷积层
            nn.ReLU(),  # ReLU激活函数
            nn.Conv2d(in_channels=encode_f1_channel, out_channels=encode_f1_channel, kernel_size=3, stride=1,
                      padding=1),
            # # 再来两个Resblock保持原图像不用变的
            ResidualBlock(encode_f1_channel),
            ResidualBlock(encode_f1_channel),
        )

        self.emd_dic = nn.Embedding(num_embeddings=num_emd, embedding_dim=latent_size)
        # 进行归一化
        self.emd_dic.weight.data.uniform_(-1.0 / num_emd,
                                          1.0 / num_emd),

        self.decode = nn.Sequential(
            nn.Conv2d(latent_size, encode_f1_channel, kernel_size=3, stride=1, padding=1),  # 卷积层,用于恢复到输入通道数
            ResidualBlock(encode_f1_channel),
            ResidualBlock(encode_f1_channel),
            # nn.ReLU(),
            nn.ConvTranspose2d(encode_f1_channel, encode_f1_channel, kernel_size=4, stride=2, padding=1),  # 反卷积层
            nn.ReLU(),  # ReLU激活函数
            nn.ConvTranspose2d(encode_f1_channel, img_channel, kernel_size=4, stride=2, padding=1),  # 反卷积层
            # # 再来两个Resblock保持原图像不用变的

        )

    def forward(self, x):
        # part1 编码部分,得到编码后的向量
        encode_z = self.encode(x)  # (batch,encode_f1_channel,img_size,img_size) # (b,encode_f1_channel,7,7)

        emd_dic_data = self.emd_dic.weight.data  # (num_emd,latent_size)
        # 获取两者的维度大小,来进行下面的操作
        num_emd, latent_size = emd_dic_data.shape

        # part2 计算距离部分:然后需要利用通道计算和字典向量的距离,先扩展,便于广播计算
        '''
        # 为什么这样扩展
        # 1. 为了能够计算每张图像和每个向量之间的距离,将扩展后每张图像对应的维度和所有向量对应的维度进行对应。也就是图像batch后面的所有维度,需要和整个字典向量对应,因此
        # 2. encode_z应当会比总的字典向量的维度多一个batch维度,但是为了统一,所以需要把字典向量前面添加一个维度。
        # 3. 然后要保证每个图像要和每个向量进行对应,所以num_emd,也应该比一张图像的(C,H,W)前面一个维度,所以图像前面要插入一个维度。最终形成下面的局面,五维向量。
        '''

        encode_z_broadcast = encode_z.unsqueeze(1)  # (b, 1, encode_f1_channel, 7, 7)
        emd_dic_data_broad_cast = emd_dic_data.reshape(1, num_emd, latent_size, 1, 1)  # (1, num_emd, latent_size, 1, 1)

        # 开始计算距离
        dist = torch.sum((encode_z_broadcast - emd_dic_data_broad_cast) ** 2, dim=2)  # 从每张图像对应的维度开始进行距离结果(b, num_emd, 7, 7)

        # 取出num_emd个距离中最小的那个# (batch,img_size*img_size)
        min_dist_index = torch.argmin(dist, dim=1)  # (b, 7, 7)
        finnel_latent = self.emd_dic(min_dist_index).permute(0, 3, 2, 1)  # (b, 7, 7, 32) - > (b, 32, 7, 7)

        # part3 变换得到最终的潜层向量
        # (batch, latent_size, img_size, img_size)

        decoder_input = encode_z + (finnel_latent-encode_z).detach()


        # part4 解码部分
        x_hat = self.decode(decoder_input)  # (batch,channel*img_size*img_size)

        return x_hat, encode_z, finnel_latent  # 后面俩个用于计算损失

Part4 测试

'''
# Part4 开始测试
'''
if __name__ == '__main__':
    img, label = Mnist_dataset[0]
    vqvae = VQVAE(img_channel=1, encode_f1_channel=64, num_emd=512, latent_size=64)
    result, encode_z, finnel_latent = vqvae(img.unsqueeze(0))
    print(result.size())

参考

视频讲解:DALL·E 2(内含扩散模型介绍)【论文精读】_哔哩哔哩_bilibili

模型原理讲解:自学资料 - Dalle2模型 - 文生图技术-CSDN博客

github资料:YanxinTong/VQVAE_Pytorch: 利用 Pytorch 手撕 VQVAE 模型

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐