1、环境配置

conda create -n mambast python=3.9
conda activate mambast
pip install torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 --index-url https://download.pytorch.org/whl/cu121
pip install -r requirements.txt

报错:

Building wheels for collected packages: mamba_ssm
  Building wheel for mamba_ssm (setup.py) ... error
  error: subprocess-exited-with-error
  
  × python setup.py bdist_wheel did not run successfully.
  │ exit code: 1
  ╰─> [9 lines of output]
      No CUDA runtime is found, using CUDA_HOME='/usr/local/cuda'
      
      
      torch.__version__  = 2.3.1+cu121
      
      
      running bdist_wheel
      Guessing wheel URL:  https://github.com/state-spaces/mamba/releases/download/v2.1.0/mamba_ssm-2.1.0+cu122torch2.3cxx11abiFALSE-cp39-cp39-linux_x86_64.whl
      error: <urlopen error retrieval incomplete: got only 20971520 out of 323800418 bytes>
      [end of output]
  
  note: This error originates from a subprocess, and is likely not a problem with pip.
  ERROR: Failed building wheel for mamba_ssm
  Running setup.py clean for mamba_ssm
Failed to build mamba_ssm
ERROR: Failed to build installable wheels for some pyproject.toml based projects (mamba_ssm)

 手动点击下面的链接下载:

https://github.com/state-spaces/mamba/releases/download/v2.1.0/mamba_ssm-2.1.0+cu122torch2.3cxx11abiFALSE-cp39-cp39-linux_x86_64.whl
pip install mamba_ssm-2.1.0+cu122torch2.3cxx11abiFALSE-cp39-cp39-linux_x86_64.whl

2、train

下载 VGG 模型

https://drive.google.com/drive/folders/1pVhJFwk2f3arP7zUDFAe5_PJrPSG1gc2

修改 scripts/train.sh 文件中的数据集路径、修改VGG 模型路径

修改 main.py 中的参数,例如最大迭代步数:

import argparse  # 用于解析命令行参数
import torch  # 导入 PyTorch 库
import random  # 用于设置随机种子,确保实验可复现
from eval import eval  # 导入评估函数
from test import test  # 导入测试函数
from train import train  # 导入训练函数


if __name__ == '__main__':
    # 创建 ArgumentParser 对象,用于解析命令行输入的参数
    parser = argparse.ArgumentParser()

    # 基本选项:处理输入内容和风格图像的路径
    parser.add_argument('--content', type=str,  # 单张内容图像路径
                    help='File path to the content image')
    parser.add_argument('--style', type=str,  # 单张风格图像路径,或多个风格图像路径,用逗号分隔,进行风格插值或空间控制
                        help='File path to the style image, or multiple style \
                        images separated by commas if you want to do style \
                        interpolation or spatial control')
    parser.add_argument('--content_dir', default='./datasets/train2014', type=str,  # 内容图像的目录路径
                        help='Directory path to a batch of content images')
    parser.add_argument('--style_dir', default='./datasets/Images', type=str,  # 风格图像的目录路径
                        help='Directory path to a batch of style images')
    parser.add_argument('--content_test_dir', default='./datasets/train2014', type=str,  # 测试集内容图像路径
                        help='Directory path to a batch of content images')
    parser.add_argument('--style_test_dir', default='./datasets/Images', type=str,  # 测试集风格图像路径
                        help='Directory path to a batch of style images')
    parser.add_argument('--vgg', type=str, default='./experiments/vgg_normalised.pth')  # 预训练的 VGG 网络路径,需下载

    # 训练选项:训练过程中保存模型和结果的路径
    parser.add_argument('--checkpoints_dir', default='./experiments',  # 模型保存路径
                        help='Directory to save the model')
    parser.add_argument('--results_dir', default='./experiments',  # 结果保存路径
                        help='Directory to save the results')
    parser.add_argument('--log_dir', default='./logs',  # 日志保存路径
                        help='Directory to save the log')
    parser.add_argument('--decoder_path', type=str, default='experiments/decoder_iter_160000.pth')  # 解码器模型路径
    parser.add_argument('--mamba_path', type=str, default='experiments/transformer_iter_160000.pth')  # Mamba 网络路径
    parser.add_argument('--embedding_path', type=str, default='experiments/embedding_iter_160000.pth')  # 嵌入模型路径
    parser.add_argument('--continue_train', action='store_true')  # 是否继续训练的标志
    parser.add_argument('--resume_iter', type=int, default=0)  # 继续训练时恢复的迭代步数
    parser.add_argument('--lr', type=float, default=5e-4)  # 学习率
    parser.add_argument('--lr_decay', type=float, default=1e-3)  # 学习率衰减系数
    parser.add_argument('--max_iter', type=int, default=160000)  # 最大训练迭代次数
    parser.add_argument('--print_every', type=int, default=1000)  # 每多少步打印一次训练信息
    parser.add_argument('--eval_every', type=int, default=10000)  # 每多少步进行一次评估
    parser.add_argument('--batch_size', type=int, default=8)  # 批次大小
    parser.add_argument('--style_weight', type=float, default=10.0)  # 风格损失的权重
    parser.add_argument('--content_weight', type=float, default=7.0)  # 内容损失的权重
    parser.add_argument('--l1_weight', type=float, default=70.0)  # L1 损失的权重
    parser.add_argument('--l2_weight', type=float, default=1.0)  # L2 损失的权重
    parser.add_argument('--n_threads', type=int, default=16)  # 数据加载的线程数
    parser.add_argument('--save_model_interval', type=int, default=10000)  # 每多少步保存一次模型
    parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'),
                            help="Type of positional embedding to use on top of the image features")  # 位置嵌入类型
    parser.add_argument('--hidden_dim', default=512, type=int,  # Transformer 隐藏层的维度
                            help="Size of the embeddings (dimension of the transformer)")
    parser.add_argument('--seed', default=777, type=int,  # 随机种子,确保实验可复现
                            help="Seed for reproducibility")
    parser.add_argument('--model_name', type=str, default='Mamba-ST')  # 模型名称
    parser.add_argument('--use_pos_embed', action='store_true')  # 是否使用位置嵌入
    parser.add_argument('--rnd_style', action='store_true')  # 是否使用随机风格
    parser.add_argument('--output_dir', type=str, default='',  # 输出图像保存路径
                    help='Directory to save the output image(s)')
    parser.add_argument('--img_size', type=int, default=256)  # 输入图像的尺寸
    parser.add_argument('--mode', type=str, required=True,  # 训练模式,必须指定 ('train', 'eval', 'test')
                        choices=['train', 'eval', 'test'])
    parser.add_argument('--d_state', type=int, default=16, help='Mamba hidden state dimension')  # Mamba 隐藏状态维度

    # 解析命令行参数
    args = parser.parse_args()

    # 设置随机种子,确保实验可复现
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    
    # 根据不同模式执行不同的任务
    if args.mode == 'train':  # 如果是训练模式,调用训练函数
        print(args)
        train(args)
    elif args.mode == 'eval':  # 如果是评估模式,调用评估函数
        print(args)
        eval(args)
    elif args.mode == 'test':  # 如果是测试模式,调用测试函数
        test(args)
    else:  # 如果模式无效,打印错误信息
        print("Specify valid mode")

 开始训练:

sh scripts/train.sh

3、test

sh scripts/test.sh

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐