Last edited by DDX on 2021-3-19 16:21
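Background:
1) The model was trained with PyTorch and saved in .pth format (with torch.save(model, path), which stores both the network structure and its parameters). I don't want to train it again, so I'm converting it to ONNX and then to RKNN format with rknn-toolkit.
2) rknn-toolkit==1.3
3) The compute stick is a TB-RK1808-S0 running its original, never-updated firmware.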
(1) pth -> onnx conversion code:
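```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the full model saved with torch.save(model, path).
# The model's class definition must be importable here for unpickling to work.
# (On PyTorch 2.6+ loading a full model also needs weights_only=False.)
model = torch.load("./logs/1.pth", map_location=device)
model.to(device)  # keep the model on the same device as the dummy input
model.eval()      # set the model to inference mode

batch_size = 4               # batch size
input_shape = (3, 600, 600)  # input shape (C, H, W); change to your own
x = torch.randn(batch_size, *input_shape, device=device)  # dummy input tensor

export_onnx_file = "./test.onnx"  # target ONNX file name
torch.onnx.export(model,
                  x,
                  export_onnx_file,
                  opset_version=10,
                  do_constant_folding=True,  # apply constant-folding optimization
                  input_names=["input"],     # input name
                  output_names=["output"],   # output name
                  dynamic_axes={"input": {0: "batch_size"},   # dynamic batch dimension
                                "output": {0: "batch_size"}})
```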
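The export above uses opset_version=10, and the failing node in the log below is a MaxPool carrying ceil_mode=1. In the ONNX operator changelog, ceil_mode first appears in MaxPool-10; MaxPool-1 and MaxPool-8 do not have it. A checker resolves each node to the newest schema version it knows that is not newer than the model's opset, so one plausible reading of the error (an assumption, not confirmed in this post) is that the onnx package installed alongside rknn-toolkit 1.3 only knows schemas up to opset 9 and therefore validates this node against MaxPool-8. The sketch below mimics that lookup in plain Python; MAXPOOL_ATTRS, maxpool_schema_version and unrecognized_attrs are hypothetical names, not part of onnx.

```python
# Hypothetical lookup table (not part of the onnx package): attribute names
# accepted by each ONNX MaxPool schema version, taken from the ONNX operator
# changelog. Versions 11 and 12 keep the same attribute names as version 10.
MAXPOOL_ATTRS = {
    1: {"auto_pad", "kernel_shape", "pads", "strides"},
    8: {"auto_pad", "kernel_shape", "pads", "strides", "storage_order"},
    10: {"auto_pad", "kernel_shape", "pads", "strides", "storage_order",
         "ceil_mode", "dilations"},
}


def maxpool_schema_version(opset: int) -> int:
    """Newest MaxPool schema version in the table that is not newer than opset."""
    return max(v for v in MAXPOOL_ATTRS if v <= opset)


def unrecognized_attrs(node_attrs, newest_known_opset: int) -> list:
    """Attributes a checker that only knows schemas up to newest_known_opset would reject."""
    allowed = MAXPOOL_ATTRS[maxpool_schema_version(newest_known_opset)]
    return sorted(set(node_attrs) - allowed)


# Attribute names of the failing node "357" -> "358" from the error log.
node_attrs = ["ceil_mode", "kernel_shape", "pads", "strides"]
print(unrecognized_attrs(node_attrs, 10))  # checker that knows opset 10: []
print(unrecognized_attrs(node_attrs, 9))   # checker that stops at opset 9: ['ceil_mode']
```

Comparing the installed onnx version (`onnx.__version__`) with the opset used at export is a quick way to test this assumption.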
(2) Converting the ONNX model to RKNN then fails. I can't tell what the error below is complaining about; could someone help explain it?
--> config model
done
--> Loading model
/home/jf/anaconda3/lib/python3.6/site-packages/onnx_tf/common/__init__.py:87: UserWarning: FrontendHandler.get_outputs_names is deprecated. It will be removed in future release.. Use node.outputs instead.
warnings.warn(message)
E Catch exception when loading onnx model: ./test.onnx!
E Traceback (most recent call last):
E File "rknn/api/rknn_base.py", line 469, in rknn.api.rknn_base.RKNNBase.load_onnx
E File "rknn/base/RKNNlib/converter/convert_onnx.py", line 494, in rknn.base.RKNNlib.converter.convert_onnx.convert_onnx.__init__
E File "rknn/base/RKNNlib/converter/convert_onnx.py", line 497, in rknn.base.RKNNlib.converter.convert_onnx.convert_onnx.__init__
E File "/home/jf/anaconda3/lib/python3.6/site-packages/onnx/checker.py", line 86, in check_model
E C.check_model(model.SerializeToString())
E onnx.onnx_cpp2py_export.checker.ValidationError: Unrecognized attribute: ceil_mode for operator MaxPool
E ==> Context: Bad node spec: input: "357" output: "358" op_type: "MaxPool" attribute { name: "ceil_mode" i: 1 type: INT } attribute { name: "kernel_shape" ints: 3 ints: 3 type: INTS } attribute { name: "pads" ints: 0 ints: 0 ints: 0 ints: 0 type: INTS } attribute { name: "strides" ints: 2 ints: 2 type: INTS }
Load test failed!
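For reference, ceil_mode is not cosmetic: it makes PyTorch round the number of pooling windows up instead of down, so removing it from a model can change downstream feature-map sizes. A minimal sketch of the output-size rule from the PyTorch nn.MaxPool2d documentation, applied to the 3x3, stride-2, unpadded pooling in the error; pool_out_size is a hypothetical helper, and the input sizes 75/150/300 are arbitrary examples, not taken from this model.

```python
def pool_out_size(size: int, kernel: int, stride: int, pad: int = 0,
                  dilation: int = 1, ceil_mode: bool = False) -> int:
    """Output length of one spatial dimension of nn.MaxPool2d (PyTorch's documented rule)."""
    span = size + 2 * pad - dilation * (kernel - 1) - 1
    if ceil_mode:
        out = -(-span // stride) + 1  # round the window count up
        # A last window that would start inside the right padding is dropped.
        if (out - 1) * stride >= size + pad:
            out -= 1
    else:
        out = span // stride + 1      # round the window count down
    return out


# 3x3, stride-2, zero-padding pooling, as in the failing MaxPool node.
for h in (75, 150, 300):
    print(h, pool_out_size(h, 3, 2), pool_out_size(h, 3, 2, ceil_mode=True))
# 75 37 37
# 150 74 75
# 300 149 150
```

Sizes where H - 3 is odd, such as 150 and 300, differ by one between the two modes, which is why ceil_mode=1 matters for shape inference in the converted graph.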