Toybrick

torch.jit.trace多输入模型的转换

zhw

注册会员

积分
107
楼主
发表于 2020-10-17 10:32:23    查看: 13151|回复: 3 | [复制链接]    打印 | 只看该作者
代码:weight_file = './Models/TSSTG/tsstg-model.pth'
graph_args = {'strategy': 'spatial'}
class_names = ['Standing', 'Walking', 'Sitting', 'Lying Down',
               'Stand up', 'Sit down', 'Fall Down']
num_class = len(class_names)
model = TwoStreamSpatialTemporalGraph(graph_args, num_class).to(device)
model.load_state_dict(torch.load(weight_file, map_location=device))
model.eval()
# torch.save(model, 'pose.pt')
input1 = torch.ones(30, 13, 3)
input2 = torch.ones(1, 2)
trace_model = torch.jit.trace(model, (input1, input2))
报错:[size=15.0667px]Traceback (most recent call last):  File "D:/Inspur/fall_down/Human-Falling-Detect-Tracks/pt2rknn.py", line 29, in <module>    trace_model = torch.jit.trace(model, (input1, input2))  File "D:\Software\Python36\lib\site-packages\torch\jit\__init__.py", line 875, in trace    check_tolerance, _force_outplace, _module_class)  File "D:\Software\Python36\lib\site-packages\torch\jit\__init__.py", line 1027, in trace_module    module._c._create_method_from_trace(method_name, func, example_inputs, var_lookup_fn, _force_outplace)  File "D:\Software\Python36\lib\site-packages\torch\nn\modules\module.py", line 548, in __call__    result = self._slow_forward(*input, **kwargs)  File "D:\Software\Python36\lib\site-packages\torch\nn\modules\module.py", line 534, in _slow_forward    result = self.forward(*input, **kwargs)TypeError: forward() takes 2 positional arguments but 3 were given[size=15.0667px]求大神指点,这种多输入的模型怎么转换rknn格式
回复

使用道具 举报

jefferyzhang

版主

积分
13580
沙发
发表于 2020-10-17 11:46:34 | 只看该作者
请问你这代码跟rknn有何关系么。。。。
回复

使用道具 举报

zhw

注册会员

积分
107
板凳
 楼主| 发表于 2020-10-17 14:11:07 | 只看该作者
jefferyzhang 发表于 2020-10-17 11:46
请问你这代码跟rknn有何关系么。。。。

torch转rknn不是应该先用torch.jit.trace保存模型,然后再转rknn吗?我torch.jit.trace保存模型的问题解决不了
回复

使用道具 举报

fly123

注册会员

积分
60
地板
发表于 2021-1-28 19:12:05 | 只看该作者
zhw 发表于 2020-10-17 14:11
torch转rknn不是应该先用torch.jit.trace保存模型,然后再转rknn吗?我torch.jit.trace保存模型的问题解 ...

检查下网络的 def forward,需要两个输入
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

产品中心 购买渠道 开源社区 Wiki教程 资料下载 关于Toybrick


快速回复 返回顶部 返回列表