Toybrick

pytorch->onnx-rknn 输入channel是否等于3决定推理结果是否正确

kkkaaa

中级会员

积分
203
楼主
发表于 2020-4-30 11:26:31    查看: 6480|回复: 5 | [复制链接]    打印 | 只看该作者
本帖最后由 kkkaaa 于 2020-4-30 14:43 编辑

不知道是不是一个 bug
rknn 版本:1.3.1b1
输入给 rknn 的 tensor 已经做了 transpose, 满足 data_format = 'nhwc'
输出没有做任何处理,保持模型返回的结果


我测试的 toy model 定义如下
  1. class ToyModel(nn.Module):
  2.     def __init__(self, ):
  3.         super(ToyModel, self).__init__()
  4.         self.op = nn.ReLU()

  5.     def forward(self, x):
  6.         x = self.op(x)
  7.         return x
复制代码

op 替换成其他简单 op 效果一样,为了明显看出规律,使用了 relu.


当输入 tensor 的形状为 (1,3, H, W) 时,结果的 channel 维颠倒了
当输入 tensor 的形状为 (1,C, H, W) and C!=3 时, 结果是正确的

以下是一个完整的可打印出 torcu, onnx, rknn 模型输入和推理结果的脚本

  1. """ torch -> onnx -> rknn """

  2. import torch
  3. import numpy as np
  4. from torch import nn


  5. model_name = "little_model_func_conv"
  6. ONNX_MODEL = model_name + '.onnx'
  7. RKNN_MODEL = model_name + '_from_onnx' + '.rknn'

  8. image_size_h = 2
  9. image_size_w =1
  10. num_channel =3

  11. # === torch 模型初始化 ===


  12. class ToyModel(nn.Module):
  13.     def __init__(self, ):
  14.         super(ToyModel, self).__init__()
  15.         self.op = nn.ReLU()

  16.     def forward(self, x):
  17.         x = self.op(x)
  18.         return x

  19. net = ToyModel()

  20. print("==== network ====")
  21. print(net)
  22. net.eval()

  23. # === 转化1: torch2onnx ===
  24. print("--> torch model inference result")

  25. input_tensor = torch.Tensor(np.arange(num_channel * image_size_h * image_size_w).reshape(1, num_channel, image_size_h, image_size_w))
  26. torch_out = torch.onnx._export(net, input_tensor, ONNX_MODEL, export_params=True)

  27. # === 转化2: onnx2rknn & rknn inference ===
  28. from rknn.api import RKNN
  29. rknn = RKNN()

  30. print('--> Loading model')
  31. ret = rknn.load_onnx(model=ONNX_MODEL)
  32. if ret != 0:
  33.     print('Load resnet50v2 failed!')
  34.     exit(ret)
  35. print('done')

  36. # Build model
  37. print('--> Building model')
  38. ret = rknn.build(do_quantization=False, dataset='./dataset.txt')
  39. if ret != 0:
  40.     print('Build resnet50 failed!')
  41.     exit(ret)
  42. print('done')

  43. # Export rknn model
  44. print('--> Export RKNN model')
  45. ret = rknn.export_rknn(RKNN_MODEL)
  46. if ret != 0:
  47.     print('Export resnet50v2.rknn failed!')
  48.     exit(ret)
  49. print('done')

  50. # === rknn inference ===
  51. # init runtime environment
  52. print("--> Init runtime environment")
  53. ret = rknn.init_runtime()
  54. if ret != 0:
  55.     print("Init runtime environment failed")
  56.     exit(ret)
  57. print('done')

  58. # inference
  59. print("--> Running rknn model")
  60. rknn_input = input_tensor.numpy()[0].transpose(1,2,0)
  61. print('----> rknn input')
  62. print(rknn_input)
  63. rknn_outputs = rknn.inference(inputs=[rknn_input])[0][0]  #[::-1]

  64. # === torch inference ===
  65. print('----> torch input')
  66. print(input_tensor)
  67. torch_inference_result = net(input_tensor)[0].detach().cpu().numpy()

  68. # === onnx inference ===
  69. import onnxruntime
  70. def to_numpy(tensor):
  71.     return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

  72. ort_session = onnxruntime.InferenceSession(ONNX_MODEL)
  73. ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(input_tensor)}
  74. ort_outs = ort_session.run([x.name for x in ort_session.get_outputs()], ort_inputs)


  75. # === compare & show results ===
  76. print("~~~~~~ torch model infer output ~~~~~~")
  77. print(torch_inference_result)
  78. print("~~~~~~ onnx model infer output ~~~~~~")
  79. print(ort_outs)
  80. print("~~~~~~ rknn model infer output ~~~~~~")
  81. print(rknn_outputs)
复制代码
这个问题困扰我好久了,请看看是不是有 bug, 谢谢
回复

使用道具 举报

jefferyzhang

版主

积分
12953
沙发
发表于 2020-5-1 22:24:09 | 只看该作者
1. 试下1.3.2正式版本,
2. 文档有介绍,nchw都是要显式申明设置给rknn的,不写的话结果可能不符合预期。具体看下文档,我自己没玩过pytorch。
回复

使用道具 举报

kkkaaa

中级会员

积分
203
板凳
 楼主| 发表于 2020-5-6 16:40:54 | 只看该作者
jefferyzhang 发表于 2020-5-1 22:24
1. 试下1.3.2正式版本,
2. 文档有介绍,nchw都是要显式申明设置给rknn的,不写的话结果可能不符合预期。具 ...

我试了 1.3.2, 而且显式指定 data_format = 'nhwc', 结果还是一样。
我觉得这本质上是一个从 onnx 到 rknn 转换的问题。

我把我观察到的现象描述的再具体一点。
当torch/onnx的输入是
tensor([[[[0.],
          [1.]],

         [[2.],
          [3.]]]])
即 rknn 的输入是
[[[0. 2.]]

[[1. 3.]]]时

rknn 输出是
[[[0.]
  [1.]]

[[2.]
  [3.]]]
----------
当 pytorch/onnx 输入是
tensor([[[[0.],
          [1.]],

         [[2.],
          [3.]],

         [[4.],
          [5.]]]])
即 rknn 输入是
[[[0. 2. 4.]]

[[1. 3. 5.]]] 时,
rknn 输出是
[[[4.]
  [5.]]

[[2.]
  [3.]]

[[0.]
  [1.]]]
回复

使用道具 举报

kkkaaa

中级会员

积分
203
地板
 楼主| 发表于 2020-5-8 14:23:15 | 只看该作者
这个问题挺困扰我的,虽然找到可以保持输出一致的方法,但是总觉得很奇怪

也不知道是不是我哪里做错了,不过我真的已经检查过很多次了
回复

使用道具 举报

leok

版主

积分
894
5#
发表于 2020-5-12 17:56:23 | 只看该作者
kkkaaa 发表于 2020-5-8 14:23
这个问题挺困扰我的,虽然找到可以保持输出一致的方法,但是总觉得很奇怪

也不知道是不是我哪里做错了,不 ...

显示指定NHWC或者NCHW,输入的数据也要是这个格式。如果不是的话,可以可以用np.transpose接口去转换。
回复

使用道具 举报

kkkaaa

中级会员

积分
203
6#
 楼主| 发表于 2020-5-14 21:53:37 | 只看该作者
找到原因了,和帖子 http://t.rock-chips.com/forum.php?mod=viewthread&tid=1542&extra= 是同一个问题
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

产品中心 购买渠道 开源社区 Wiki教程 资料下载 关于Toybrick


快速回复 返回顶部 返回列表