|
新手上路,用pytorch做了mnist手写体数字识别,在转换模型的时候报错
网络结构如下:
- class Net(nn.Module):
- def __init__(self):
- super(Net, self).__init__()
- self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
- self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
- self.conv2_drop = nn.Dropout2d()
- self.fc1 = nn.Linear(320, 50)
- self.fc2 = nn.Linear(50, 10)
- def forward(self, x):
- x = F.relu(F.max_pool2d(self.conv1(x), 2))
- x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
- x = x.view(-1, 320)
- x = F.relu(self.fc1(x))
- x = F.dropout(x, training=self.training)
- x = self.fc2(x)
- return F.log_softmax(x, dim=1)
经训练后保存为mnist.py,然后将其转换成torchscript文件- def save_traced_model():
- model = Net()
- model.load_state_dict(torch.load('mnist.pt'))
- model.eval()
- example = torch.rand(64, 1, 28, 28)
- traced_script_model = torch.jit.trace(model, example)
- traced_script_model.save("traced_model.pt")
然后进行模型转化- if __name__ == '__main__':
- model = './traced_model.pt'
- input_size_list = [[1, 28, 28]]
- rknn = RKNN()
- # pre-process config
- print('--> config model')
- rknn.config(channel_mean_value='0 255', reorder_channel='0 1 2')
- print('done')
- # Load pytorch model
- print('--> Loading model')
- ret = rknn.load_pytorch(model=model, input_size_list=input_size_list)
- if ret != 0:
- print('Load pytorch model failed!')
- exit(ret)
- print('done')
- # Build model
- print('--> Building model')
- ret = rknn.build(do_quantization=False, dataset='./1.txt')
- if ret != 0:
- print('Build pytorch failed!')
- exit(ret)
- print('done')
- # Export rknn model
- print('--> Export RKNN model')
- ret = rknn.export_rknn('./mnist.rknn')
- if ret != 0:
- print('Export mnist.rknn failed!')
- exit(ret)
- print('done')
- ret = rknn.load_rknn('./mnist.rknn')
- # init runtime environment
- print('--> Init runtime environment')
- ret = rknn.init_runtime()
- if ret != 0:
- print('Init runtime environment failed')
- exit(ret)
- print('done')
结果运行到init runtime的时候报错segmentation fault(core dumped)
请问可能是什么原因呢?手写体网络本身可以运行得到测试结果,应该没什么问题吧。。
|
本帖子中包含更多资源
您需要 登录 才可以下载或查看,没有帐号?立即注册
x
|