|
关于转换 PyTorch 中 nn.ConvTranspose2d() 模块的 bug：RKNN 与 PyTorch 输出的维度对不上
- import numpy as np
- from rknn.api import RKNN
- import torch
def export_pytorch_model(pt_save_path='./simplepose.pt'):
    """Build a one-layer ConvTranspose2d network, trace it, and save it as TorchScript.

    The network applies
    ``nn.ConvTranspose2d(3, 64, kernel_size=4, stride=2, padding=1)`` to a
    (1, 3, 224, 224) input. PyTorch's transposed-convolution output size
    (no dilation, no output_padding) is
    ``H_out = (H_in - 1) * stride - 2 * padding + kernel_size``,
    i.e. (224 - 1) * 2 - 2 + 4 = 448. So the traced model returns a
    (1, 64, 448, 448) tensor, which is the reference shape to compare the
    converted RKNN model against.

    Args:
        pt_save_path: File path for the traced TorchScript model. Defaults to
            ``'./simplepose.pt'``, the path the original script hard-coded.

    Returns:
        The path the model was written to (``pt_save_path``).
    """
    # The original snippet referenced ``nn`` without ever importing it
    # (NameError); import it locally so this function is self-contained.
    from torch import nn

    class SimplePose(nn.Module):
        """Single transposed-convolution layer (3 -> 64 channels, doubles H and W)."""

        def __init__(self):
            super().__init__()
            self.preact = nn.ConvTranspose2d(3, 64, kernel_size=4, stride=2, padding=1)

        def forward(self, x):
            return self.preact(x)

    # eval() returns the module itself; tracing in eval mode is the usual
    # practice for export even though this layer has no train/eval difference.
    net = SimplePose().eval()
    dummy_input = torch.randn([1, 3, 224, 224])
    with torch.no_grad():
        trace_model = torch.jit.trace(net, dummy_input)
    trace_model.save(pt_save_path)
    return pt_save_path
if __name__ == '__main__':
    import sys

    # Reproduce the reported problem: convert the traced ConvTranspose2d model
    # with RKNN, feed the same random input to RKNN and to PyTorch, and compare
    # the output shapes (PyTorch's expected shape is (1, 64, 448, 448)).
    pt_model = export_pytorch_model()
    input_size_list = [[3, 224, 224]]

    # Create RKNN object
    rknn = RKNN()
    # try/finally guarantees rknn.release() runs on every path, including the
    # early sys.exit() calls below; the original skipped release on failure.
    try:
        # pre-process config
        print('--> config model')
        rknn.config(channel_mean_value='0 0 0 1', reorder_channel='0 1 2')
        print('done')

        # Load pytorch model
        print('--> Loading model')
        ret = rknn.load_pytorch(model=pt_model, input_size_list=input_size_list)
        if ret != 0:
            print('Load pytorch model failed!')
            sys.exit(ret)
        print('done')

        # Build model
        print('--> Building model')
        ret = rknn.build(do_quantization=False, dataset='./dataset.txt')
        if ret != 0:
            print('Build pytorch failed!')
            sys.exit(ret)
        print('done')

        # Set inputs: one random image in HWC layout. RKNN receives it as-is;
        # the PyTorch side gets the same data transposed to NCHW below.
        img = np.random.randn(224, 224, 3).astype(np.float32)

        # init runtime environment
        print('--> Init runtime environment')
        ret = rknn.init_runtime()
        if ret != 0:
            print('Init runtime environment failed')
            sys.exit(ret)
        print('done')

        # Inference
        print('--> Running model')
        outputs = rknn.inference(inputs=[img])
        rknn_o = outputs[0]
        print('rknn output:')
        print(rknn_o.shape)

        # pytorch inference on the same data: HWC -> NHWC -> NCHW
        pt_i = np.expand_dims(img, axis=0)
        pt_i = np.transpose(pt_i, [0, 3, 1, 2])
        pt_i = torch.from_numpy(pt_i)
        model = torch.jit.load(pt_model)
        with torch.no_grad():
            pt_o = model(pt_i).numpy()
        print('pytorch output:')
        print(pt_o.shape)

        # Make the comparison explicit so the mismatch is obvious in the log.
        if rknn_o.shape == pt_o.shape:
            print('shapes match, max abs diff:', float(np.abs(rknn_o - pt_o).max()))
        else:
            print('shape mismatch: rknn {} ({} elements) vs pytorch {} ({} elements)'.format(
                rknn_o.shape, rknn_o.size, pt_o.shape, pt_o.size))
    finally:
        rknn.release()
|
本帖子中包含更多资源
您需要登录才可以下载或查看，没有账号？立即注册
x
|