Toybrick

多输入时,通道顺序为何会乱? Python接口怎么用多输入?

jefferyzhang

版主

积分
12975
楼主
发表于 2020-2-23 11:04:38 | 显示全部楼层
为啥不直接用pytorch模型转rknn试试?
回复

使用道具 举报

jefferyzhang

版主

积分
12975
沙发
发表于 2020-2-25 08:48:05 | 显示全部楼层
1. 顺序不一致要再看下,我们还在确认这个问题中
2. name可以通过Netron看
3. 多输入python请参看教程有写,文档也有写
4. 不明白你说的多输入形状不一致是什么意思,如果是多个input节点当然是可以的,如果是多batch当然是不行的
回复

使用道具 举报

jefferyzhang

版主

积分
12975
板凳
发表于 2020-2-26 09:50:54 | 显示全部楼层
NPU部门回复:
1.支持多输入多输出,多个输入时shape可以不一样,但是输入要按照nchw的格式,具体可以参考my_multiple_input_test.py;
2.目前pytorch和onnx暂不支持lstm,tensorflow有些可以支持;

my_multiple_input_test.py:
  1. import torch
  2. from rknn.api import RKNN
  3. import numpy as np

  4. class Net(torch.nn.Module):

  5.     def __init__(self):
  6.         super(Net, self).__init__()

  7.         self.conv1 = torch.nn.Conv2d(1,6,3)
  8.         self.conv2 = torch.nn.Conv2d(2,12,3)
  9.         self.conv3 = torch.nn.Conv2d(3,24,3)

  10.     def forward(self, x, y, z):
  11.         x = self.conv1(x)
  12.         y = self.conv2(y)
  13.         z = self.conv3(z)

  14.         return x,y,z

  15. def E_D(vector1, vector2):
  16.     print(np.linalg.norm(vector1 - vector2))

  17. def cos_d(vector1, vector2):
  18.     d = np.dot(vector1, vector2) / (np.linalg.norm(vector1) * (np.linalg.norm(vector2)))
  19.     print(d)


  20. if __name__ == '__main__':

  21.     net = Net()
  22.     i1 = torch.rand(1,1,5,5)
  23.     i2 = torch.rand(1,2,7,7)
  24.     i3 = torch.rand(1,3,9,9)
  25.     trace_model = torch.jit.trace(net, (i1,i2,i3))
  26.     trace_model.save('test.pt')


  27.     rknn = RKNN(verbose=True)
  28.     rknn.config(batch_size=1,
  29.                 channel_mean_value='0 1#0 0 1#0 0 0 1',
  30.                 reorder_channel='0 1 2#0 1 2#0 1 2',
  31.                 epochs=1)

  32.     ret = rknn.load_pytorch(model='test.pt', input_size_list=[[1,5,5],[2,7,7],[3,9,9]])
  33.     # ret = rknn.load_onnx(model='lstm{128x64}.pt.onnx')
  34.     if ret != 0:
  35.         print('Load pytorch model failed!')
  36.         exit(ret)

  37.     ret = rknn.build(do_quantization=False, dataset='./dataset.txt')
  38.     if ret != 0:
  39.         print('Build pytorch failed!')
  40.         exit(ret)

  41.     ret = rknn.init_runtime(target='rk1808',device_id='7e9f3eb02ede60e8')
  42.     if ret != 0:
  43.         print('Init runtime environment failed')
  44.         exit(ret)


  45.     rknn_r = rknn.inference(inputs=[i1.numpy(), i2.numpy(), i3.numpy()],
  46.                             data_type='float32',
  47.                             data_format='nchw')
  48.     pytorch_r = net(i1,i2,i3)


  49.     for d1, d2 in zip(rknn_r, pytorch_r):
  50.         d1 = d1.ravel()
  51.         d2 = d2.detach().numpy().ravel()

  52.         E_D(d1, d2)
  53.         cos_d(d1, d2)

  54.         print()


复制代码

回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

产品中心 购买渠道 开源社区 Wiki教程 资料下载 关于Toybrick


快速回复 返回顶部 返回列表