- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from rknn.api import RKNN
class siamrpn(nn.Module):
    """Minimal module that convolves ``x`` with a kernel ``w`` supplied at call time.

    The kernel is a second forward input rather than a learned parameter,
    which is the functional ``F.conv2d`` pattern whose RKNN conversion this
    snippet is used to reproduce.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, w):
        """Return ``F.conv2d(x, w)`` with PyTorch's defaults (stride 1, no padding, no bias)."""
        correlated = F.conv2d(x, w)
        return correlated
if __name__ == "__main__":
    # Local import: only this script entry point needs sys.exit (the bare
    # `exit` helper comes from the `site` module and is not always present).
    import sys

    net = siamrpn()
    net.eval()
    torch_model = 'siamrpn.pt'

    # export pt
    # torch.Tensor(*sizes) allocates UNINITIALIZED memory (arbitrary values,
    # possibly NaN/Inf). Only the shapes matter for tracing, so use
    # well-defined random data instead.
    example_inputs = (torch.rand(1, 3, 224, 224), torch.rand(1, 3, 4, 4))
    trace_model = torch.jit.trace(net, example_inputs)
    print(trace_model.code)
    trace_model.save(torch_model)

    # export rknn
    rknn = RKNN(verbose=True)
    try:
        # pre-process config
        print('--> config model')
        rknn.config(target_platform='rk3399pro')
        print('done')

        # Load pytorch model
        # input_size_list has one [C, H, W] entry per forward() input, in the
        # same order as forward(x, w); the batch dimension is omitted.
        print('--> Loading model')
        ret = rknn.load_pytorch(model=torch_model,
                                input_size_list=[[3, 224, 224], [3, 4, 4]])
        if ret != 0:
            print('Load pytorch model failed!')
            sys.exit(ret)
        print('done')
    finally:
        # Free the toolkit's resources on every path, including the failure exit.
        rknn.release()
复制代码
pfwhnudhwq 发表于 2020-7-15 12:08
请问:
1. 不能使用 PyTorch 的函数式接口吗?也就是代码里的 F.conv2d(而 nn.Conv2d 可以正确加载)。
2.有提供3399pro,d ...
jefferyzhang 发表于 2020-7-16 09:08
1. 能否使用函数式接口请参看文档;根据其他客户的经验,很多 F.xxx 这类函数确实无法正确转换。
2.只要是toybrick ...
import sys

import numpy as np


def _check_ret(ret, what):
    """Exit the script with ``ret`` if an RKNN toolkit call reported failure.

    Args:
        ret: return code of an RKNN call; non-zero means failure.
        what: short description used in the ``'<what> failed!'`` message.
    """
    if ret != 0:
        print(what + ' failed!')
        sys.exit(ret)


if __name__ == "__main__":
    # Compare the RKNN-converted model with the original TorchScript model.
    model = 'pt/feature.pt'
    input_size_list = [[3, 127, 127]]

    # ONE uint8 image (HWC layout) is fed to BOTH backends. Previously RKNN
    # received 255s while PyTorch received np.ones (1s). The two runs saw
    # different data, so a large difference was guaranteed regardless of
    # conversion quality.
    inp_hwc = np.full((127, 127, 3), 255, dtype=np.uint8)

    # rknn inference
    # NOTE(review): no target_platform is configured, so the toolkit uses its
    # default (the captured log reports rk1808); set it explicitly if the
    # model is meant for rk3399pro.
    # NOTE(review): assumes rknn.config applies no mean/std normalization; if
    # one is added, apply the same transform to the PyTorch input below.
    rknn = RKNN()
    try:
        rknn.config(reorder_channel='0 1 2')
        _check_ret(rknn.load_pytorch(model=model, input_size_list=input_size_list),
                   'Load pytorch model')
        _check_ret(rknn.build(do_quantization=False), 'Build model')
        _check_ret(rknn.init_runtime(), 'Init runtime')
        outputs = rknn.inference(inputs=[inp_hwc])
    finally:
        # Release toolkit resources even when a step above fails.
        rknn.release()

    # pytorch inference on the same pixels: HWC -> CHW, add batch dim, float32.
    inp_nchw = torch.from_numpy(
        np.ascontiguousarray(inp_hwc.transpose(2, 0, 1))).float().unsqueeze(0)
    module = torch.jit.load(model)
    with torch.no_grad():
        out = module(inp_nchw).cpu().numpy()

    # difference between rknn and pytorch. `outputs` is a list of output
    # arrays; compare against the first (assumes a single-output model).
    diff = np.abs(out - outputs[0])
    print(np.sum(diff))
    print('max abs diff:', np.max(diff))
- WARNING: Token 'COMMENT' defined, but not used
- WARNING: There is 1 unused token
- W The target_platform is not set in config, using default target platform rk1808.
- W [set_chip_platform_env:187]evaluate model on RK1808
- 1384.8684
复制代码
pfwhnudhwq 发表于 2020-7-16 11:38
谢谢。能否帮我看看下面这段代码:为什么用 rknn.load_pytorch 转换后的推理结果和 PyTorch 的推理结果相差这么多?
...
欢迎光临 Toybrick (https://t.rock-chips.com/) | Powered by Discuz! X3.3 |