Mamba-ST 代码复现
MambaST 代码复现
·
1、环境配置
conda create -n mambast python=3.9
conda activate mambast
pip install torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 --index-url https://download.pytorch.org/whl/cu121
pip install -r requirements.txt
报错:
Building wheels for collected packages: mamba_ssm
Building wheel for mamba_ssm (setup.py) ... error
error: subprocess-exited-with-error
× python setup.py bdist_wheel did not run successfully.
│ exit code: 1
╰─> [9 lines of output]
No CUDA runtime is found, using CUDA_HOME='/usr/local/cuda'
torch.__version__ = 2.3.1+cu121
running bdist_wheel
Guessing wheel URL: https://github.com/state-spaces/mamba/releases/download/v2.1.0/mamba_ssm-2.1.0+cu122torch2.3cxx11abiFALSE-cp39-cp39-linux_x86_64.whl
error: <urlopen error retrieval incomplete: got only 20971520 out of 323800418 bytes>
[end of output]
note: This error originates from a subprocess, and is likely not a problem with pip.
ERROR: Failed building wheel for mamba_ssm
Running setup.py clean for mamba_ssm
Failed to build mamba_ssm
ERROR: Failed to build installable wheels for some pyproject.toml based projects (mamba_ssm)
手动点击下面的链接下载:
https://github.com/state-spaces/mamba/releases/download/v2.1.0/mamba_ssm-2.1.0+cu122torch2.3cxx11abiFALSE-cp39-cp39-linux_x86_64.whl
pip install mamba_ssm-2.1.0+cu122torch2.3cxx11abiFALSE-cp39-cp39-linux_x86_64.whl
2、train
下载 VGG 模型
https://drive.google.com/drive/folders/1pVhJFwk2f3arP7zUDFAe5_PJrPSG1gc2
修改 scripts/train.sh 文件中的数据集路径、修改VGG 模型路径
修改 main.py 中的参数,例如最大迭代步数:
import argparse # 用于解析命令行参数
import torch # 导入 PyTorch 库
import random # 用于设置随机种子,确保实验可复现
from eval import eval # 导入评估函数
from test import test # 导入测试函数
from train import train # 导入训练函数
if __name__ == '__main__':
# 创建 ArgumentParser 对象,用于解析命令行输入的参数
parser = argparse.ArgumentParser()
# 基本选项:处理输入内容和风格图像的路径
parser.add_argument('--content', type=str, # 单张内容图像路径
help='File path to the content image')
parser.add_argument('--style', type=str, # 单张风格图像路径,或多个风格图像路径,用逗号分隔,进行风格插值或空间控制
help='File path to the style image, or multiple style \
images separated by commas if you want to do style \
interpolation or spatial control')
parser.add_argument('--content_dir', default='./datasets/train2014', type=str, # 内容图像的目录路径
help='Directory path to a batch of content images')
parser.add_argument('--style_dir', default='./datasets/Images', type=str, # 风格图像的目录路径
help='Directory path to a batch of style images')
parser.add_argument('--content_test_dir', default='./datasets/train2014', type=str, # 测试集内容图像路径
help='Directory path to a batch of content images')
parser.add_argument('--style_test_dir', default='./datasets/Images', type=str, # 测试集风格图像路径
help='Directory path to a batch of style images')
parser.add_argument('--vgg', type=str, default='./experiments/vgg_normalised.pth') # 预训练的 VGG 网络路径,需下载
# 训练选项:训练过程中保存模型和结果的路径
parser.add_argument('--checkpoints_dir', default='./experiments', # 模型保存路径
help='Directory to save the model')
parser.add_argument('--results_dir', default='./experiments', # 结果保存路径
help='Directory to save the results')
parser.add_argument('--log_dir', default='./logs', # 日志保存路径
help='Directory to save the log')
parser.add_argument('--decoder_path', type=str, default='experiments/decoder_iter_160000.pth') # 解码器模型路径
parser.add_argument('--mamba_path', type=str, default='experiments/transformer_iter_160000.pth') # Mamba 网络路径
parser.add_argument('--embedding_path', type=str, default='experiments/embedding_iter_160000.pth') # 嵌入模型路径
parser.add_argument('--continue_train', action='store_true') # 是否继续训练的标志
parser.add_argument('--resume_iter', type=int, default=0) # 继续训练时恢复的迭代步数
parser.add_argument('--lr', type=float, default=5e-4) # 学习率
parser.add_argument('--lr_decay', type=float, default=1e-3) # 学习率衰减系数
parser.add_argument('--max_iter', type=int, default=160000) # 最大训练迭代次数
parser.add_argument('--print_every', type=int, default=1000) # 每多少步打印一次训练信息
parser.add_argument('--eval_every', type=int, default=10000) # 每多少步进行一次评估
parser.add_argument('--batch_size', type=int, default=8) # 批次大小
parser.add_argument('--style_weight', type=float, default=10.0) # 风格损失的权重
parser.add_argument('--content_weight', type=float, default=7.0) # 内容损失的权重
parser.add_argument('--l1_weight', type=float, default=70.0) # L1 损失的权重
parser.add_argument('--l2_weight', type=float, default=1.0) # L2 损失的权重
parser.add_argument('--n_threads', type=int, default=16) # 数据加载的线程数
parser.add_argument('--save_model_interval', type=int, default=10000) # 每多少步保存一次模型
parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'),
help="Type of positional embedding to use on top of the image features") # 位置嵌入类型
parser.add_argument('--hidden_dim', default=512, type=int, # Transformer 隐藏层的维度
help="Size of the embeddings (dimension of the transformer)")
parser.add_argument('--seed', default=777, type=int, # 随机种子,确保实验可复现
help="Seed for reproducibility")
parser.add_argument('--model_name', type=str, default='Mamba-ST') # 模型名称
parser.add_argument('--use_pos_embed', action='store_true') # 是否使用位置嵌入
parser.add_argument('--rnd_style', action='store_true') # 是否使用随机风格
parser.add_argument('--output_dir', type=str, default='', # 输出图像保存路径
help='Directory to save the output image(s)')
parser.add_argument('--img_size', type=int, default=256) # 输入图像的尺寸
parser.add_argument('--mode', type=str, required=True, # 训练模式,必须指定 ('train', 'eval', 'test')
choices=['train', 'eval', 'test'])
parser.add_argument('--d_state', type=int, default=16, help='Mamba hidden state dimension') # Mamba 隐藏状态维度
# 解析命令行参数
args = parser.parse_args()
# 设置随机种子,确保实验可复现
torch.manual_seed(args.seed)
random.seed(args.seed)
# 根据不同模式执行不同的任务
if args.mode == 'train': # 如果是训练模式,调用训练函数
print(args)
train(args)
elif args.mode == 'eval': # 如果是评估模式,调用评估函数
print(args)
eval(args)
elif args.mode == 'test': # 如果是测试模式,调用测试函数
test(args)
else: # 如果模式无效,打印错误信息
print("Specify valid mode")
开始训练:
sh scripts/train.sh
3、test
sh scripts/test.sh
更多推荐
所有评论(0)