Toybrick

pytorch init runtime 报错

SongJ

注册会员

积分
87
楼主
发表于 2020-6-1 11:22:30    查看: 6022|回复: 4 | [复制链接]    打印 | 只看该作者
新手上路,用pytorch做了mnist手写体数字识别,在转换模型的时候报错
网络结构如下:
  1. class Net(nn.Module):
  2.     def __init__(self):
  3.         super(Net, self).__init__()
  4.         self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
  5.         self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
  6.         self.conv2_drop = nn.Dropout2d()
  7.         self.fc1 = nn.Linear(320, 50)
  8.         self.fc2 = nn.Linear(50, 10)

  9.     def forward(self, x):
  10.         x = F.relu(F.max_pool2d(self.conv1(x), 2))
  11.         x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
  12.         x = x.view(-1, 320)
  13.         x = F.relu(self.fc1(x))
  14.         x = F.dropout(x, training=self.training)
  15.         x = self.fc2(x)
  16.         return F.log_softmax(x, dim=1)
复制代码
经训练后保存为mnist.py,然后将其转换成torchscript文件
  1. def save_traced_model():
  2.     model = Net()
  3.     model.load_state_dict(torch.load('mnist.pt'))
  4.     model.eval()
  5.     example = torch.rand(64, 1, 28, 28)
  6.     traced_script_model = torch.jit.trace(model, example)
  7.     traced_script_model.save("traced_model.pt")
复制代码
然后进行模型转化
  1. if __name__ == '__main__':

  2.     model = './traced_model.pt'
  3.     input_size_list = [[1, 28, 28]]

  4.     rknn = RKNN()

  5.     # pre-process config
  6.     print('--> config model')
  7.     rknn.config(channel_mean_value='0 255', reorder_channel='0 1 2')
  8.     print('done')

  9.     # Load pytorch model
  10.     print('--> Loading model')
  11.     ret = rknn.load_pytorch(model=model, input_size_list=input_size_list)
  12.     if ret != 0:
  13.         print('Load pytorch model failed!')
  14.         exit(ret)
  15.     print('done')

  16.     # Build model
  17.     print('--> Building model')
  18.     ret = rknn.build(do_quantization=False, dataset='./1.txt')
  19.     if ret != 0:
  20.         print('Build pytorch failed!')
  21.         exit(ret)
  22.     print('done')

  23.     # Export rknn model
  24.     print('--> Export RKNN model')
  25.     ret = rknn.export_rknn('./mnist.rknn')
  26.     if ret != 0:
  27.         print('Export mnist.rknn failed!')
  28.         exit(ret)
  29.     print('done')

  30.     ret = rknn.load_rknn('./mnist.rknn')

  31.     # init runtime environment
  32.     print('--> Init runtime environment')
  33.     ret = rknn.init_runtime()
  34.     if ret != 0:
  35.         print('Init runtime environment failed')
  36.         exit(ret)
  37.     print('done')
复制代码
结果运行到init runtime的时候报错segmentation fault(core dumped)


请问可能是什么原因呢?手写体网络本身可以运行得到测试结果,应该没什么问题吧。。

本帖子中包含更多资源

您需要 登录 才可以下载或查看,没有帐号?立即注册

x
回复

使用道具 举报

jefferyzhang

版主

积分
12944
沙发
发表于 2020-6-1 14:12:07 | 只看该作者
1. 转换看过去没有问题。
2. W the target_platform is not set in config , 你是在哪里跑的
回复

使用道具 举报

SongJ

注册会员

积分
87
板凳
 楼主| 发表于 2020-6-1 14:54:47 | 只看该作者
jefferyzhang 发表于 2020-6-1 14:12
1. 转换看过去没有问题。
2. W the target_platform is not set in config , 你是在哪里跑的 ...

在pc上跑的
回复

使用道具 举报

jefferyzhang

版主

积分
12944
地板
发表于 2020-6-1 15:05:39 | 只看该作者

PC跑仿真?
回复

使用道具 举报

SongJ

注册会员

积分
87
5#
 楼主| 发表于 2020-6-1 15:38:23 | 只看该作者

对的,在PC上跑的仿真
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

产品中心 购买渠道 开源社区 Wiki教程 资料下载 关于Toybrick


快速回复 返回顶部 返回列表