Toybrick

标题: 关于onnx模型转换问题 [打印本页]

作者: DDX    时间: 2021-3-19 16:18
标题: 关于onnx模型转换问题
本帖最后由 DDX 于 2021-3-19 16:21 编辑

前提:1)由于pytorch训练的模型市pth格式(torch.save(model,path)用的保存网络格式和参数的方式),不想再训练一遍,所以转换成onnx,然后再用rknn-toolkit转换成rknn格式
           2)rknn-toolkit==1.3
           3)加速棒是tb-rk1808-s0,固件市最原始的版本

(1) pth-->>onnx代码:

'''

import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = torch.load("./logs/1.pth") # pytorch模型加载
batch_size = 4  #批处理大小
input_shape = (3, 600, 600)   #输入数据,改成自己的输入shape

# #set the model to inference mode
model.eval()

x = torch.randn(batch_size, *input_shape)   # 生成张量
x = x.to(device)
export_onnx_file = "./test.onnx"        # 目的ONNX文件名
torch.onnx.export(model,
                    x,
                    export_onnx_file,
                    opset_version=10,
                    do_constant_folding=True,    # 是否执行常量折叠优化
                    input_names=["input"],    # 输入名
                    output_names=["output"],    # 输出名
                    dynamic_axes={"input":{0:"batch_size"},  # 批处理变量
                                    "output":{0:"batch_size"}})
'''


(2)然后onnx转换rknn出错,看不懂里面的内容市什么出错,请帮忙答疑一下

--> config model
done
--> Loading model
/home/jf/anaconda3/lib/python3.6/site-packages/onnx_tf/common/__init__.py:87: UserWarning: FrontendHandler.get_outputs_names is deprecated. It will be removed in future release.. Use node.outputs instead.
  warnings.warn(message)
E Catch exception when loading onnx model: ./test.onnx!
E Traceback (most recent call last):
E   File "rknn/api/rknn_base.py", line 469, in rknn.api.rknn_base.RKNNBase.load_onnx
E   File "rknn/base/RKNNlib/converter/convert_onnx.py", line 494, in rknn.base.RKNNlib.converter.convert_onnx.convert_onnx.__init__
E   File "rknn/base/RKNNlib/converter/convert_onnx.py", line 497, in rknn.base.RKNNlib.converter.convert_onnx.convert_onnx.__init__
E   File "/home/jf/anaconda3/lib/python3.6/site-packages/onnx/checker.py", line 86, in check_model
E     C.check_model(model.SerializeToString())
E onnx.onnx_cpp2py_export.checker.ValidationError: Unrecognized attribute: ceil_mode for operator MaxPool
E ==> Context: Bad node spec: input: "357" output: "358" op_type: "MaxPool" attribute { name: "ceil_mode" i: 1 type: INT } attribute { name: "kernel_shape" ints: 3 ints: 3 type: INTS } attribute { name: "pads" ints: 0 ints: 0 ints: 0 ints: 0 type: INTS } attribute { name: "strides" ints: 2 ints: 2 type: INTS }
Load test failed!

作者: bobby_jiang    时间: 2021-3-30 15:50
转onnx的时候把opset_version设成9以后再试下。你现在生成的onnx模型版本太高,解析失败了。
作者: DDX    时间: 2021-11-15 11:12
bobby_jiang 发表于 2021-3-30 15:50
转onnx的时候把opset_version设成9以后再试下。你现在生成的onnx模型版本太高,解析失败了。 ...

谢谢您的答复




欢迎光临 Toybrick (https://t.rock-chips.com/) Powered by Discuz! X3.3