- import platform
- import os
- import torch
- import numpy as np
- from rknn.api import RKNN
- import onnx
- from onnx_tf.backend import prepare
class LstmNode(torch.nn.Module):
    """Affine pre-activation shared by every gate: ``fc_x(x) + fc_hc(hc0)``.

    Each projection has its own bias, so the result is
    ``x @ W_x.T + b_x + hc0 @ W_h.T + b_h``.

    Args:
        input_size: feature size of ``x``.
        hidden_size: feature size of ``hc0`` and of the output.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        # Creation order (input projection first, then recurrent projection)
        # fixes the order in which random initial weights are drawn.
        self._fc_x = torch.nn.Linear(input_size, hidden_size)
        self._fc_hc = torch.nn.Linear(hidden_size, hidden_size)

    def forward(self, x, hc0):
        """Return the sum of the input and recurrent projections."""
        return self._fc_x(x) + self._fc_hc(hc0)
class HardTanh(torch.nn.Module):
    """Element-wise affine map ``0.5 * x + 0.5``.

    NOTE(review): despite the name, no clamping is applied. This is therefore
    neither ``torch.nn.Hardtanh`` (clamp to [-1, 1]) nor a hard sigmoid
    (clamp to [0, 1]); only the linear part of a hard-sigmoid approximation
    is computed. Confirm the intent before using it as an activation.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return ``0.5 * x + 0.5`` (no saturation)."""
        half = 0.5
        return half * x + half
class PicewiseLinear(torch.nn.Module):
    """Linear layer followed by an element-wise clamp to ``[0, 1]``.

    ``forward(x) = clamp(x @ W.T + b, 0, 1)``: one saturating linear segment.
    The class name keeps its original spelling so existing references work.

    Args:
        in_size: feature size of the input.
        out_size: feature size of the output.
    """

    def __init__(self, in_size, out_size):
        super().__init__()
        self._linear = torch.nn.Linear(in_size, out_size)

    def forward(self, x):
        """Return the linear projection of ``x`` clamped element-wise to [0, 1]."""
        # Out-of-place clamp: same values as an in-place ``clamp_`` without
        # mutating the Linear output tensor.
        return self._linear(x).clamp(0, 1)
class LstmUnit(torch.nn.Module):
    """One LSTM-like cell step, stripped down for RKNN conversion experiments.

    Each gate pre-activation is an ``LstmNode`` (``fc_x(x) + fc_hc(h0)``).
    ``forward`` deliberately applies no sigmoid/tanh, so it is pure
    matrix multiply-adds:

        c_t = f_t * c0 + i_t * g_t
        h_t = o_t * c_t

    Conversion notes from the original experiments:
      * Gated variant (sigmoid applied to all four gates and to c_t, where a
        standard LSTM uses tanh for g_t and for c_t): the ONNX-loaded model
        computes wrong results; load_pytorch converts, but model
        initialisation fails.
      * Current variant (sigmoid/tanh removed, only matmul + add): the
        ONNX-loaded model still computes wrong results, while load_pytorch
        gives correct results.

    Args:
        input_size: feature size of ``x``.
        hidden_size: feature size of ``h0``, ``c0`` and every output.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        # Parameter-free activations kept for the gated variant (Hardtanh
        # stands in for Tanh); forward() does not currently use them.
        self._tanh = torch.nn.Hardtanh()
        self._sigmoid = torch.nn.Sigmoid()
        # One affine node per gate: input, forget, cell candidate, output.
        # Creation order fixes the order of random parameter initialisation.
        self._fc_it = LstmNode(input_size, hidden_size)
        self._fc_ft = LstmNode(input_size, hidden_size)
        self._fc_gt = LstmNode(input_size, hidden_size)
        self._fc_ot = LstmNode(input_size, hidden_size)

    def forward(self, x, h0, c0):
        """Compute one un-activated cell step.

        Returns:
            ``(o_t, h_t, c_t)``.
            NOTE(review): the first element is the output-gate value o_t,
            not h_t, which a standard LSTM would emit as the step output.
        """
        gate_nodes = (self._fc_it, self._fc_ft, self._fc_gt, self._fc_ot)
        i_t, f_t, g_t, o_t = (node(x, h0) for node in gate_nodes)
        c_t = f_t * c0 + i_t * g_t
        h_t = o_t * c_t
        return o_t, h_t, c_t
class LSTM(torch.nn.Module):
    """Single-call LSTM substitute that avoids per-timestep indexing.

    The whole input is flattened into one row and passed through a single
    ``LstmUnit`` call. The unit's first output is then reshaped to
    ``(1, seq_len, -1)``.

    A faithful recurrence loops over timesteps with ``ix[0][i]``. That form
    exports ``select`` operators, which make ``rknn.load_pytorch`` fail, so
    this flattened form is used instead. ``torch.nn.LSTM`` was the original
    baseline.

    Args:
        seq_len: number of timesteps the output sequence is reshaped into.
        input_size: feature size given to ``LstmUnit``. Because ``x`` is
            flattened to one row, this must equal the element count of ``x``.
        hidden_size: feature size of the unit's outputs. It must be divisible
            by ``seq_len`` for the output reshape to succeed.
    """

    def __init__(self, seq_len, input_size, hidden_size):
        super().__init__()
        self._seq_len = seq_len
        self._input_size = input_size
        self._hidden_size = hidden_size
        self._lstm_unit = LstmUnit(input_size, hidden_size)

    def forward(self, x, h0, c0):
        """Run one flattened LSTM step.

        Returns:
            ``(oy, oh, oc)`` with shapes ``(1, seq_len, -1)``, ``(1, 1, -1)``
            and ``(1, 1, -1)``. They are the unit's first output, the new
            hidden state and the new cell state.
        """
        ix = x.view(1, -1)
        ih = h0.view(1, -1)
        ic = c0.view(1, -1)
        y, oh, oc = self._lstm_unit(ix, ih, ic)
        # Reshape with this instance's seq_len, not a module-level global.
        oy = y.view(1, self._seq_len, -1)
        oh = oh.view(1, 1, -1)
        # Return the updated cell state produced by the unit, not the input c0.
        oc = oc.view(1, 1, -1)
        return oy, oh, oc
复制代码
zht 发表于 2020-3-5 16:30
1. lstm可以参考: https://github.com/MaybeShewill-CV/CRNN_Tensorflow
2. PPU是一个可编程的模块,模块函 ...
欢迎光临 Toybrick (https://t.rock-chips.com/) | Powered by Discuz! X3.3 |