Toybrick

LSTM尝试了pytorch, tensorflw都不行

xsky

中级会员

积分
388
楼主
发表于 2020-2-28 16:17:17    查看: 8683|回复: 3 | [复制链接]    打印 | 只看该作者
本帖最后由 xsky 于 2020-2-28 16:23 编辑

rknn     1.3
D RKNNAPI:   API: 1.3.0
D RKNNAPI:   DRV: 1.3.0

torch                1.2.0
tensorflow         1.14.0
onnx                 1.4.1
onnx-tf              1.2.1


目的是导出并测试一个LSTM运算

完整代码:
其中文件名
xxx-pt为pytorch建立模型,xxx-tf为tensorflow
xxx-run 为去掉模型创建转换,只加载.rknn并运行
op-check-xx  为单个运算符测试


其中pytorch,模型定义代码
  1. import platform
  2. import os
  3. import torch
  4. import numpy as np

  5. from rknn.api import RKNN

  6. import onnx
  7. from onnx_tf.backend import prepare


  8. class LstmNode(torch.nn.Module):
  9.     def __init__(self, input_size, hidden_size):
  10.         super(LstmNode, self).__init__()
  11.         self._fc_x = torch.nn.Linear(input_size, hidden_size)
  12.         self._fc_hc = torch.nn.Linear(hidden_size, hidden_size)

  13.     def forward(self, x, hc0):
  14.         a = self._fc_x(x)
  15.         b = self._fc_hc(hc0)
  16.         return a + b


  17. class HardTanh(torch.nn.Module):
  18.     def __init__(self):
  19.         super(HardTanh, self).__init__()

  20.     def forward(self, x):
  21.         return (x * 0.5 + 0.5)


  22. class PicewiseLinear(torch.nn.Module):
  23.     def __init__(self, in_size, out_size):
  24.         super(PicewiseLinear, self).__init__()
  25.         self._linear = torch.nn.Linear(in_size, out_size)

  26.     def forward(self, x):
  27.         ox = self._linear(x)
  28.         ox = ox.clamp_(0, 1)
  29.         pass


  30. class LstmUnit(torch.nn.Module):
  31.     def __init__(self, input_size, hidden_size):
  32.         super(LstmUnit, self).__init__()
  33.         self._tanh = torch.nn.Hardtanh()  # torch.nn.Tanh()
  34.         self._sigmoid = torch.nn.Sigmoid()
  35.         self._fc_it = LstmNode(input_size, hidden_size)
  36.         self._fc_ft = LstmNode(input_size, hidden_size)
  37.         self._fc_gt = LstmNode(input_size, hidden_size)
  38.         self._fc_ot = LstmNode(input_size, hidden_size)
  39.         pass

  40.     def forward(self, x, h0, c0):
  41.         # # _tanh替换为_sigmoid, onnx加载计算错误;  load_torch加载可转换但模型初始化失败
  42.         # it = self._sigmoid(self._fc_it(x, h0))
  43.         # ft = self._sigmoid(self._fc_ft(x, h0))
  44.         # gt = self._sigmoid(self._fc_gt(x, h0))      # self._tanh
  45.         # ot = self._sigmoid(self._fc_ot(x, h0))
  46.         # ct = ft * c0 + it * gt
  47.         # ht = ot * self._sigmoid(ct)                 # ot * self._tanh(ct)

  48.         #  去掉sigmoid/tanh只剩矩阵乘加, onnx加载计算结果错误, load_pytorch结果正确
  49.         it = self._fc_it(x, h0)
  50.         ft = self._fc_ft(x, h0)
  51.         gt = self._fc_gt(x, h0)  # self._tanh
  52.         ot = self._fc_ot(x, h0)
  53.         ct = ft * c0 + it * gt
  54.         ht = ot * ct  # ot * self._tanh(ct)
  55.         return ot, ht, ct


  56. class LSTM(torch.nn.Module):

  57.     def __init__(self, seq_len, input_size, hidden_size):
  58.         super(LSTM, self).__init__()

  59.         self._seq_len = seq_len
  60.         self._input_size = input_size
  61.         self._hidden_size = hidden_size
  62.         # self._lstm = torch.nn.LSTM(input_size, hidden_size, num_layers=1, batch_first=True)
  63.         # self._fc = torch.nn.Linear(input_size, hidden_size)

  64.         self._lstm_unit = LstmUnit(input_size, hidden_size)

  65.     def forward(self, x, h0, c0):
  66.         ix = x.view(1, -1)
  67.         ih = h0.view(1, -1)
  68.         ic = c0.view(1, -1)
  69.         y, oh, oc = self._lstm_unit(ix, ih, ic)
  70.         oy = y.view(1, seq_len, -1)
  71.         oh = oh.view(1, 1, -1)
  72.         oc = ic.view(1, 1, -1)

  73.         # 此段循环会生成 select算子,rknn.load_pytorch时报错
  74.         # ix = x.view(1, self._seq_len, -1)
  75.         # ih = h0.view(1, -1)
  76.         # ic = c0.view(1, -1)
  77.         # y = []
  78.         # for i in range(self._seq_len):
  79.         #     xt = ix[0][i].view(1, -1)
  80.         #     yt, ih, ic = self._lstm_unit(xt, ih, ic)
  81.         #     y.append(yt)
  82.         # oy = torch.cat(y)
  83.         # oh = ih.view(1, 1, -1)
  84.         # oc = ic.view(1, 1, -1)

  85.         # Pytorch LSTM
  86.         # ix = x.view(1, seq_len, -1)
  87.         # ih = h0.view(1, 1, -1)
  88.         # ic = c0.view(1, 1, -1)
  89.         # oy, (oh, oc) = self._lstm(ix, (ih, ic))

  90.         return oy, oh, oc

  91.         # return self._sigmoid(ix), self._sigmoid(ih), self._sigmoid(ic)
复制代码

实验:
1、直接使用Pytorch 的LSTM导出:.load_pytorch时正常,.build时报错2、Tensorflow单层动态LSTM导出:  不支持的算子,TensorArrayGatherV3;如果把输出节点上移,报错:E AttributeError: 'NoneType' object has no attribute 'op'
3、使用Pytorch的底层运算符构造LSTM单元及迭代,即上面贴出的代码:
     (1) 其中tanh报错: E KeyError: 'aten::tanh';
     (2) 如果去掉tanh的调用:load_pytorch可以转换为rknn,但inin_runtime时报错。
     (3)如果再把sigmoid也去掉,即只剩  Linear0 + Linear1,以及向量乘、向量加的操作: load_pytorch->rknn计算结果正确
     以上步骤均可成功通过onnx->rknn,  但多输入的顺序会乱序;  如果按乱序后的输入调整输入,onnx运行结果不正确。
    onnx->rknn 其中onnx模型是输入输出是有名字的,但转为rknn之后,C API运行查询到的输入名字是空的。
     (4)在(3)的基础上,只有乘法和加法,针对LSTM的基本单元,使用for循环在序列数上面迭代:
               
load_pytorch报错:E KeyError: 'aten::select';
                load_onnx,build报错:E ValueError: Try match Gather_26ut0 failed, catch exception!

4、对单个运算的测试
    load_tensorflow->RKNN均支持tanh, sigmoid,计算结果正确。


请问下:
1、LSTM除了写自定义操作符还有别的办法么?  
2、可以帮反馈下一版本支持一下双层双向动态LSTM么?  其主要的运算是Linear,应该NPU还是会有加速效果的。  但是用自定义运算符编写,其中应该会增加调度开销。底层驱动实现算子应该效率会高些。
3、使用自定义运算符,能没有instructions,类似NEON.h,将NPU支持加速的功能封装。
4、自定义运算符文档中提到的PPU模块全称是?有详细资料么? 这个模块可以针对哪些操作进行加速?





回复

使用道具 举报

xsky

中级会员

积分
388
沙发
 楼主| 发表于 2020-3-4 12:05:39 | 只看该作者
LSTM,有人试过么?
回复

使用道具 举报

zht

注册会员

积分
74
板凳
发表于 2020-3-5 16:30:34 | 只看该作者
1. lstm可以参考: https://github.com/MaybeShewill-CV/CRNN_Tensorflow
2. PPU是一个可编程的模块,模块函数参考:https://www.khronos.org/openvx/
回复

使用道具 举报

xsky

中级会员

积分
388
地板
 楼主| 发表于 2020-3-6 09:44:44 | 只看该作者
zht 发表于 2020-3-5 16:30
1. lstm可以参考: https://github.com/MaybeShewill-CV/CRNN_Tensorflow
2. PPU是一个可编程的模块,模块函 ...

嗯,谢谢
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

产品中心 购买渠道 开源社区 Wiki教程 资料下载 关于Toybrick


快速回复 返回顶部 返回列表