PyTorch Quantization Error: Backend Mismatch
Environment: PyTorch 1.7.1
Error description: When running quantization-aware training (QAT) with the PyTorch quantization package, the final convert step fails:
Traceback (most recent call last):
File "train.py", line 136, in <module>
main()
File "train.py", line 126, in main
quantized_model = torch.quantization.convert(model.eval(), inplace=False)
File "/data/yyl/anaconda3/envs/pytorch-1.7.1/lib/python3.7/site-packages/torch/quantization/quantize.py", line 414, in convert
_convert(module, mapping, inplace=True)
File "/data/yyl/anaconda3/envs/pytorch-1.7.1/lib/python3.7/site-packages/torch/quantization/quantize.py", line 458, in _convert
_convert(mod, mapping, inplace=True)
File "/data/yyl/anaconda3/envs/pytorch-1.7.1/lib/python3.7/site-packages/torch/quantization/quantize.py", line 459, in _convert
reassign[name] = swap_module(mod, mapping)
File "/data/yyl/anaconda3/envs/pytorch-1.7.1/lib/python3.7/site-packages/torch/quantization/quantize.py", line 485, in swap_module
new_mod = mapping[type(mod)].from_float(mod)
File "/data/yyl/anaconda3/envs/pytorch-1.7.1/lib/python3.7/site-packages/torch/nn/quantized/modules/conv.py", line 368, in from_float
return cls.get_qconv(mod, activation_post_process, weight_post_process)
File "/data/yyl/anaconda3/envs/pytorch-1.7.1/lib/python3.7/site-packages/torch/nn/quantized/modules/conv.py", line 157, in get_qconv
qweight = _quantize_weight(mod.weight.float(), weight_post_process)
File "/data/yyl/anaconda3/envs/pytorch-1.7.1/lib/python3.7/site-packages/torch/nn/quantized/modules/utils.py", line 16, in _quantize_weight
wt_scale.to(torch.double), wt_zp.to(torch.int64), wt_axis, torch.qint8)
RuntimeError: Could not run 'aten::quantize_per_channel' with arguments from the 'CUDA' backend. 'aten::quantize_per_channel' is only available for these backends: [CPU, BackendSelect, Named, AutogradOther, AutogradCPU, AutogradCUDA, AutogradXLA, AutogradPrivateUse1, AutogradPrivateUse2, AutogradPrivateUse3, Tracer, Autocast, Batched, VmapMode].
CPU: registered at /opt/conda/conda-bld/pytorch_1607370141920/work/build/aten/src/ATen/CPUType.cpp:2127 [kernel]
BackendSelect: fallthrough registered at /opt/conda/conda-bld/pytorch_1607370141920/work/aten/src/ATen/core/BackendSelectFallbackKernel.cpp:3 [backend fallback]
Named: registered at /opt/conda/conda-bld/pytorch_1607370141920/work/aten/src/ATen/core/NamedRegistrations.cpp:7 [backend fallback]
AutogradOther: registered at /opt/conda/conda-bld/pytorch_1607370141920/work/torch/csrc/autograd/generated/VariableType_2.cpp:8078 [autograd kernel]
AutogradCPU: registered at /opt/conda/conda-bld/pytorch_1607370141920/work/torch/csrc/autograd/generated/VariableType_2.cpp:8078 [autograd kernel]
AutogradCUDA: registered at /opt/conda/conda-bld/pytorch_1607370141920/work/torch/csrc/autograd/generated/VariableType_2.cpp:8078 [autograd kernel]
AutogradXLA: registered at /opt/conda/conda-bld/pytorch_1607370141920/work/torch/csrc/autograd/generated/VariableType_2.cpp:8078 [autograd kernel]
AutogradPrivateUse1: registered at /opt/conda/conda-bld/pytorch_1607370141920/work/torch/csrc/autograd/generated/VariableType_2.cpp:8078 [autograd kernel]
AutogradPrivateUse2: registered at /opt/conda/conda-bld/pytorch_1607370141920/work/torch/csrc/autograd/generated/VariableType_2.cpp:8078 [autograd kernel]
AutogradPrivateUse3: registered at /opt/conda/conda-bld/pytorch_1607370141920/work/torch/csrc/autograd/generated/VariableType_2.cpp:8078 [autograd kernel]
Tracer: registered at /opt/conda/conda-bld/pytorch_1607370141920/work/torch/csrc/autograd/generated/TraceType_2.cpp:9654 [kernel]
Autocast: fallthrough registered at /opt/conda/conda-bld/pytorch_1607370141920/work/aten/src/ATen/autocast_mode.cpp:254 [backend fallback]
Batched: registered at /opt/conda/conda-bld/pytorch_1607370141920/work/aten/src/ATen/BatchingRegistrations.cpp:511 [backend fallback]
VmapMode: fallthrough registered at /opt/conda/conda-bld/pytorch_1607370141920/work/aten/src/ATen/VmapModeRegistrations.cpp:33 [backend fallback]
(pytorch-1.7.1) ➜ CIFAR-10 python train.py
Files already downloaded and verified
Files already downloaded and verified
/data/yyl/anaconda3/envs/pytorch-1.7.1/lib/python3.7/site-packages/torch/quantization/observer.py:121: UserWarning: Please use quant_min and quant_max to specify the range for observers. reduce_range will be deprecated in a future release of PyTorch.
reduce_range will be deprecated in a future release of PyTorch."
Solution: My model was trained on CUDA, but the quantized operators (such as aten::quantize_per_channel in the traceback) are only implemented for the CPU backend. The fix is to move the model to the CPU before converting it:
quantized_model = torch.quantization.convert(model.cpu().eval(), inplace=False)
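To see the fix in context, here is a minimal, self-contained sketch of the QAT workflow. The toy model, layer sizes, and training placeholder are illustrative assumptions, not the original train.py; the essential point is the `.cpu()` call before `convert`:

```python
import torch
import torch.nn as nn

# Toy model with quant/dequant stubs marking the quantized region
# (illustrative; the original post's model is not shown)
class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()
        self.conv = nn.Conv2d(3, 8, 3)
        self.relu = nn.ReLU()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.relu(self.conv(x))
        return self.dequant(x)

device = "cuda" if torch.cuda.is_available() else "cpu"
model = ToyModel().to(device)

# Attach a QAT qconfig and insert fake-quant/observer modules
model.qconfig = torch.quantization.get_default_qat_qconfig("fbgemm")
torch.quantization.prepare_qat(model, inplace=True)

# ... the QAT training loop would run here, on `device` ...
model(torch.randn(1, 3, 8, 8, device=device))  # stand-in forward pass

# Key fix: move the model to CPU before convert, because the quantized
# kernels (e.g. aten::quantize_per_channel) are registered for CPU only
quantized_model = torch.quantization.convert(model.cpu().eval(), inplace=False)

# Quantized inference then runs on CPU inputs
out = quantized_model(torch.randn(1, 3, 8, 8))
print(out.shape)
```

Calling `convert` while the weights are still CUDA tensors is exactly what triggers the "Could not run 'aten::quantize_per_channel' with arguments from the 'CUDA' backend" error above.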