❤️ 如果你也关注 AI 的发展现状,且对 AI 应用开发非常感兴趣,我会每日跟你分享最新的 AI 资讯和开源应用,也会不定期分享自己的想法和开源实例,欢迎关注我哦!

🥦 微信公众号|搜一搜:蚝油菜花 🥦


🚀 快速阅读

  1. 量化压缩:将扩散模型的权重和激活值量化到4位,减少模型大小和内存占用。
  2. 加速推理:通过量化减少计算复杂度,提高模型在GPU上的推理速度。
  3. 低秩分支:引入低秩分支处理量化中的异常值,减少量化误差,提升图像质量。

正文(附运行示例)

SVDQuant 是什么

公众号: 蚝油菜花 - nunchaku

SVDQuant是由MIT研究团队推出的后训练量化技术,专门针对扩散模型进行优化。该技术通过将模型的权重和激活值量化至4位,显著减少了内存占用,并加速了推理过程。SVDQuant引入了一个高精度的低秩分支,用于吸收量化过程中的异常值,从而在保持图像质量的同时,实现了在16GB 4090 GPU上3.5倍的显存优化和8.7倍的延迟减少。

SVDQuant支持DiT和UNet架构,并能无缝集成现成的低秩适配器(LoRAs),无需重新量化。这为在资源受限的设备上部署大型扩散模型提供了有效的解决方案。

SVDQuant 的主要功能

  • 量化压缩:将扩散模型的权重和激活值量化到4位,减少模型大小,降低内存占用。
  • 加速推理:量化减少计算复杂度,提高模型在GPU上的推理速度。
  • 低秩分支吸收异常值:引入低秩分支处理量化中的异常值,减少量化误差。
  • 内核融合:设计推理引擎Nunchaku,基于内核融合减少内存访问,进一步提升推理效率。
  • 支持多种架构:兼容DiT和UNet架构的扩散模型。
  • LoRA集成:无缝集成低秩适配器(LoRAs),无需重新量化。

SVDQuant 的技术原理

  • 量化处理:对模型的权重和激活值进行4位量化,对保持模型性能构成挑战。
  • 异常值处理:用平滑技术将激活值中的异常值转移到权重上,基于SVD分解权重,将权重分解为低秩分量和残差。
  • 低秩分支:引入16位精度的低秩分支处理权重中的异常值,将残差量化到4位,降低量化难度。
  • Eckart-Young-Mirsky定理:移除权重中的主导奇异值,大幅减小权重的幅度和异常值。
  • 推理引擎Nunchaku:设计推理引擎,基于融合低秩分支和低比特分支的内核,减少内存访问和内核调用次数,降低延迟。

如何运行 SVDQuant

安装依赖

首先,创建并激活一个conda环境,然后安装所需的依赖包:

conda create -n nunchaku python=3.11
conda activate nunchaku
pip install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1 --index-url https://download.pytorch.org/whl/cu121
pip install diffusers ninja wheel transformers accelerate sentencepiece protobuf
pip install huggingface_hub peft opencv-python einops gradio spaces GPUtil

安装 nunchaku

确保你已经安装了gcc/g++>=11。如果没有,可以通过Conda安装:

conda install -c conda-forge gxx=11 gcc=11

然后从源码构建并安装nunchaku包:

git clone https://github.com/mit-han-lab/nunchaku.git
cd nunchaku
git submodule init
git submodule update
pip install -e .

使用示例

example.py中,提供了一个运行INT4 FLUX.1-schnell模型的最小脚本:

import torch
from diffusers import FluxPipeline
from nunchaku.models.transformer_flux import NunchakuFluxTransformer2dModel

transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-schnell")
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")
image = pipeline("A cat holding a sign that says hello world", num_inference_steps=4, guidance_scale=0).images[0]
image.save("example.png")

资源


❤️ 如果你也关注 AI 的发展现状,且对 AI 应用开发非常感兴趣,我会每日跟你分享最新的 AI 资讯和开源应用,也会不定期分享自己的想法和开源实例,欢迎关注我哦!

🥦 微信公众号|搜一搜:蚝油菜花 🥦

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐