"""Convert a toy PyTorch model: torch -> onnx -> rknn, then compare the
inference results of all three runtimes on the same input."""
import torch
import numpy as np
from torch import nn

# Output file names for the two converted models.
model_name = "little_model_func_conv"
ONNX_MODEL = model_name + '.onnx'
RKNN_MODEL = model_name + '_from_onnx' + '.rknn'

# Dummy-input geometry; the tensor fed through every runtime is
# NCHW = (1, num_channel, image_size_h, image_size_w) = (1, 3, 2, 1).
image_size_h = 2
image_size_w = 1
num_channel = 3
# === torch model definition ===
class ToyModel(nn.Module):
    """Minimal network — a single element-wise ReLU.

    Kept deliberately tiny so that layout mistakes in the
    torch -> onnx -> rknn conversion chain are easy to spot by eye.
    """

    def __init__(self):
        super().__init__()
        # Attribute is named `op` on purpose: it appears in print(net) output.
        self.op = nn.ReLU()

    def forward(self, x):
        # ReLU is shape-preserving, so input and output layouts match exactly.
        return self.op(x)
net = ToyModel()
print("==== network ====")
print(net)
net.eval()  # inference mode (no effect on a bare ReLU, but good hygiene)

# === Conversion 1: torch -> onnx ===
print("--> torch model inference result")
# Deterministic ramp input 0..5, NCHW = (1, 3, 2, 1): each channel holds
# distinct values, so a channel-order mistake downstream is visible at a glance.
input_tensor = torch.Tensor(
    np.arange(num_channel * image_size_h * image_size_w).reshape(
        1, num_channel, image_size_h, image_size_w
    )
)
# Use the public export API instead of the private torch.onnx._export;
# the traced output it used to return was never read, so no assignment.
torch.onnx.export(net, input_tensor, ONNX_MODEL, export_params=True)
# === Conversion 2: onnx -> rknn, then rknn inference ===
from rknn.api import RKNN

rknn = RKNN()

# NOTE(review): per vendor guidance, rknn-toolkit reorders input channels and
# (before v1.0.0) returns NHWC outputs by default — the transcript below shows
# the rknn result with channels reversed relative to torch/onnx. Pin the
# channel order explicitly; confirm against your toolkit version.
rknn.config(reorder_channel='0 1 2')

print('--> Loading model')
ret = rknn.load_onnx(model=ONNX_MODEL)
if ret != 0:
    # Error messages previously referred to "resnet50v2" — copy-paste residue
    # from a sample script; this is the toy ONNX model.
    print('Load ONNX model failed!')
    exit(ret)
print('done')

# Build model (no quantization, so the dataset file is not actually consumed).
print('--> Building model')
ret = rknn.build(do_quantization=False, dataset='./dataset.txt')
if ret != 0:
    print('Build RKNN model failed!')
    exit(ret)
print('done')

# Export rknn model
print('--> Export RKNN model')
ret = rknn.export_rknn(RKNN_MODEL)
if ret != 0:
    print('Export RKNN model failed!')
    exit(ret)
print('done')

# === rknn inference ===
# init runtime environment (simulator or attached NPU device)
print("--> Init runtime environment")
ret = rknn.init_runtime()
if ret != 0:
    print("Init runtime environment failed")
    exit(ret)
print('done')

# inference
print("--> Running rknn model")
# rknn.inference expects NHWC by default, hence the HWC transpose of the
# single NCHW sample.
rknn_input = input_tensor.numpy()[0].transpose(1, 2, 0)
print('----> rknn input')
print(rknn_input)
rknn_outputs = rknn.inference(inputs=[rknn_input])[0][0]
# === torch inference (reference result) ===
print('----> torch input')
print(input_tensor)
# Run the batch through the net, then strip the batch dim and move to numpy.
torch_batch_out = net(input_tensor)
torch_inference_result = torch_batch_out[0].detach().cpu().numpy()

# === onnx inference ===
import onnxruntime
def to_numpy(tensor):
    """Return *tensor* as a host numpy array, detaching from autograd first
    when the tensor participates in a graph."""
    if tensor.requires_grad:
        tensor = tensor.detach()
    return tensor.cpu().numpy()
# Run the exported ONNX model through onnxruntime on the same input.
ort_session = onnxruntime.InferenceSession(ONNX_MODEL)
input_name = ort_session.get_inputs()[0].name
output_names = [out.name for out in ort_session.get_outputs()]
ort_outs = ort_session.run(output_names, {input_name: to_numpy(input_tensor)})

# === compare & show results ===
# All three runtimes saw the same ramp input; their printed outputs should match.
for label, value in (
    ("torch", torch_inference_result),
    ("onnx", ort_outs),
    ("rknn", rknn_outputs),
):
    print(f"~~~~~~ {label} model infer output ~~~~~~")
    print(value)
复制代码
- ----> rknn input
- [[[0. 2. 4.]]
- [[1. 3. 5.]]]
- ----> torch input
- tensor([[[[0.],
- [1.]],
- [[2.],
- [3.]],
- [[4.],
- [5.]]]])
- ~~~~~~ torch model infer output ~~~~~~
- [[[0.]
- [1.]]
- [[2.]
- [3.]]
- [[4.]
- [5.]]]
- ~~~~~~ onnx model infer output ~~~~~~
- [array([[[[0.],
- [1.]],
- [[2.],
- [3.]],
- [[4.],
- [5.]]]], dtype=float32)]
- ~~~~~~ rknn model infer output ~~~~~~
- [[[4.]
- [5.]]
- [[2.]
- [3.]]
- [[0.]
- [1.]]]
复制代码
leok 发表于 2020-4-27 17:45
请确认rknn toolkit的版本,1.0.0以前版本输出是NHWC,1.0.0之后版本输出和原始模型一致。 ...
kkkaaa 发表于 2020-4-27 17:49
我好像发现问题所在了
rknn_input = input_tensor.numpy()[0][::-1].transpose(1,2,0)
如果事先把 rknn_in ...
leok 发表于 2020-4-27 20:37
rknn toolkit默认inference的data_format是nhwc
- rknn.config(reorder_channel="0 1 2")
复制代码
欢迎光临 Toybrick (https://t.rock-chips.com/) | Powered by Discuz! X3.3 |