- """ torch -> onnx -> rknn """
- import torch
- import numpy as np
- from torch import nn
- model_name = "little_model_func_conv"
- ONNX_MODEL = model_name + '.onnx'
- RKNN_MODEL = model_name + '_from_onnx' + '.rknn'
- image_size_h = 2
- image_size_w =1
- num_channel =3
- # === torch 模型初始化 ===
- class ToyModel(nn.Module):
- def __init__(self, ):
- super(ToyModel, self).__init__()
- self.op = nn.ReLU()
- def forward(self, x):
- x = self.op(x)
- return x
- net = ToyModel()
- print("==== network ====")
- print(net)
- net.eval()
- # === 转化1: torch2onnx ===
- print("--> torch model inference result")
- input_tensor = torch.Tensor(np.arange(num_channel * image_size_h * image_size_w).reshape(1, num_channel, image_size_h, image_size_w))
- torch_out = torch.onnx._export(net, input_tensor, ONNX_MODEL, export_params=True)
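# (optional sanity check, not in the original repro) verify that the
# exported graph is well-formed before handing it to RKNN; assumes the
# onnx package is installed
import onnx
onnx.checker.check_model(onnx.load(ONNX_MODEL))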
# === step 2: onnx -> rknn ===
from rknn.api import RKNN
rknn = RKNN()
print('--> Loading model')
ret = rknn.load_onnx(model=ONNX_MODEL)
if ret != 0:
    print('Load onnx model failed!')
    exit(ret)
print('done')

# Build model (no quantization, so no calibration dataset is needed)
print('--> Building model')
ret = rknn.build(do_quantization=False)
if ret != 0:
    print('Build model failed!')
    exit(ret)
print('done')

# Export rknn model
print('--> Export RKNN model')
ret = rknn.export_rknn(RKNN_MODEL)
if ret != 0:
    print('Export rknn model failed!')
    exit(ret)
print('done')

# === rknn inference ===
# init runtime environment
print('--> Init runtime environment')
ret = rknn.init_runtime()
if ret != 0:
    print('Init runtime environment failed')
    exit(ret)
print('done')

# inference: the NCHW tensor is transposed to HWC before being fed in
print('--> Running rknn model')
rknn_input = input_tensor.numpy()[0].transpose(1, 2, 0)
print('----> rknn input')
print(rknn_input)
rknn_outputs = rknn.inference(inputs=[rknn_input])[0][0]  # [::-1]

# === torch inference ===
print('----> torch input')
print(input_tensor)
torch_inference_result = net(input_tensor)[0].detach().cpu().numpy()

# === onnx inference ===
import onnxruntime

def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

ort_session = onnxruntime.InferenceSession(ONNX_MODEL)
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(input_tensor)}
ort_outs = ort_session.run([x.name for x in ort_session.get_outputs()], ort_inputs)

# === compare & show results ===
print("~~~~~~ torch model infer output ~~~~~~")
print(torch_inference_result)
print("~~~~~~ onnx model infer output ~~~~~~")
print(ort_outs)
print("~~~~~~ rknn model infer output ~~~~~~")
print(rknn_outputs)
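For completeness, the three results can also be compared numerically instead of by eye. A minimal sketch, assuming the RKNN output comes back flattened (which the [0][0] indexing above suggests) and is reshaped to CxHxW:

# sketch: numeric comparison of the three backends; reshaping the flat
# RKNN output to CxHxW is an assumption, not part of the original repro
rknn_out = np.asarray(rknn_outputs).reshape(num_channel, image_size_h, image_size_w)
print('torch vs onnx:', np.allclose(torch_inference_result, ort_outs[0][0]))
print('torch vs rknn:', np.allclose(torch_inference_result, rknn_out))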
This problem has been bothering me for a long time. Please take a look at whether this is a bug. Thanks!

jefferyzhang posted on 2020-5-1 22:24
1. Try the official 1.3.2 release.
2. The documentation covers this: the NCHW layout always has to be explicitly declared to RKNN; if you don't set it, the results may not match expectations. ...
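In code, that advice presumably amounts to something like the sketch below. It assumes the data_format argument of rknn.inference() (listed in the RKNN-Toolkit 1.3.x API) is what declares the layout; the variable names come from the script above:

# sketch of declaring the layout explicitly, per the reply above;
# assumes rknn.inference() accepts data_format (RKNN-Toolkit 1.3.x)
rknn_input_nchw = input_tensor.numpy()  # keep the 1x3x2x1 NCHW tensor as-is
outputs = rknn.inference(inputs=[rknn_input_nchw], data_format='nchw')
print(outputs[0])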
kkkaaa posted on 2020-5-8 14:23
This problem has really been bothering me. Although I found a way to keep the outputs consistent, it still feels strange.
I don't know whether I did something wrong somewhere, or ...
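kkkaaa's actual workaround is cut off in the quote above. Purely as a hypothetical illustration of the kind of fix that lines the numbers up, one could reinterpret the flat RKNN output as HWC and transpose it back to CHW before comparing with torch:

# hypothetical workaround, not necessarily what kkkaaa did: treat the
# flat RKNN output as HWC, then transpose back to CHW for comparison
out_hwc = np.asarray(rknn_outputs).reshape(image_size_h, image_size_w, num_channel)
print(out_hwc.transpose(2, 0, 1))  # directly comparable to torch_inference_result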