Toybrick

标题: torch.jit.trace多输入模型的转换 [打印本页]

作者: zhw    时间: 2020-10-17 10:32
标题: torch.jit.trace多输入模型的转换
代码:weight_file = './Models/TSSTG/tsstg-model.pth'
graph_args = {'strategy': 'spatial'}
class_names = ['Standing', 'Walking', 'Sitting', 'Lying Down',
               'Stand up', 'Sit down', 'Fall Down']
num_class = len(class_names)
model = TwoStreamSpatialTemporalGraph(graph_args, num_class).to(device)
model.load_state_dict(torch.load(weight_file, map_location=device))
model.eval()
# torch.save(model, 'pose.pt')
input1 = torch.ones(30, 13, 3)
input2 = torch.ones(1, 2)
trace_model = torch.jit.trace(model, (input1, input2))
报错:[size=15.0667px]Traceback (most recent call last):  File "D:/Inspur/fall_down/Human-Falling-Detect-Tracks/pt2rknn.py", line 29, in <module>    trace_model = torch.jit.trace(model, (input1, input2))  File "D:\Software\Python36\lib\site-packages\torch\jit\__init__.py", line 875, in trace    check_tolerance, _force_outplace, _module_class)  File "D:\Software\Python36\lib\site-packages\torch\jit\__init__.py", line 1027, in trace_module    module._c._create_method_from_trace(method_name, func, example_inputs, var_lookup_fn, _force_outplace)  File "D:\Software\Python36\lib\site-packages\torch\nn\modules\module.py", line 548, in __call__    result = self._slow_forward(*input, **kwargs)  File "D:\Software\Python36\lib\site-packages\torch\nn\modules\module.py", line 534, in _slow_forward    result = self.forward(*input, **kwargs)TypeError: forward() takes 2 positional arguments but 3 were given[size=15.0667px]求大神指点,这种多输入的模型怎么转换rknn格式

作者: jefferyzhang    时间: 2020-10-17 11:46
请问你这代码跟rknn有何关系么。。。。
作者: zhw    时间: 2020-10-17 14:11
jefferyzhang 发表于 2020-10-17 11:46
请问你这代码跟rknn有何关系么。。。。

torch转rknn不是应该先用torch.jit.trace保存模型,然后再转rknn吗?我torch.jit.trace保存模型的问题解决不了
作者: fly123    时间: 2021-1-28 19:12
zhw 发表于 2020-10-17 14:11
torch转rknn不是应该先用torch.jit.trace保存模型,然后再转rknn吗?我torch.jit.trace保存模型的问题解 ...

检查下网络的 def forward,需要两个输入




欢迎光临 Toybrick (https://t.rock-chips.com/) Powered by Discuz! X3.3