Code:
weight_file = './Models/TSSTG/tsstg-model.pth'
graph_args = {'strategy': 'spatial'}
class_names = ['Standing', 'Walking', 'Sitting', 'Lying Down',
               'Stand up', 'Sit down', 'Fall Down']
num_class = len(class_names)
# Build the two-stream ST-GCN action model and load the trained weights
model = TwoStreamSpatialTemporalGraph(graph_args, num_class).to(device)
model.load_state_dict(torch.load(weight_file, map_location=device))
model.eval()
# torch.save(model, 'pose.pt')
# Dummy example inputs for the two streams
input1 = torch.ones(30, 13, 3)
input2 = torch.ones(1, 2)
trace_model = torch.jit.trace(model, (input1, input2))
Error:
Traceback (most recent call last):
  File "D:/Inspur/fall_down/Human-Falling-Detect-Tracks/pt2rknn.py", line 29, in <module>
    trace_model = torch.jit.trace(model, (input1, input2))
  File "D:\Software\Python36\lib\site-packages\torch\jit\__init__.py", line 875, in trace
    check_tolerance, _force_outplace, _module_class)
  File "D:\Software\Python36\lib\site-packages\torch\jit\__init__.py", line 1027, in trace_module
    module._c._create_method_from_trace(method_name, func, example_inputs, var_lookup_fn, _force_outplace)
  File "D:\Software\Python36\lib\site-packages\torch\nn\modules\module.py", line 548, in __call__
    result = self._slow_forward(*input, **kwargs)
  File "D:\Software\Python36\lib\site-packages\torch\nn\modules\module.py", line 534, in _slow_forward
    result = self.forward(*input, **kwargs)
TypeError: forward() takes 2 positional arguments but 3 were given

Could someone point me in the right direction: how do I convert a multi-input model like this to RKNN format?
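The error message itself says forward() accepts only one argument besides self, while torch.jit.trace unpacks the example_inputs tuple (input1, input2) into two positional arguments. The sketch below is hedged and uses a hypothetical stand-in, TupleInputNet, that mimics that one-argument forward(inputs) taking a (pts, mot) pair (my understanding of how TwoStreamSpatialTemporalGraph is called; please check Actionsrecognition/Models.py). It shows two ways to trace it: wrapping the pair in an outer tuple so it arrives as one argument, or tracing a hypothetical TwoInputWrapper whose forward(pts, mot) takes two separate tensors. The second form gives a traced graph with two plain tensor inputs, which is generally what converters such as RKNN-Toolkit expect. The shapes (1, 3, 30, 14) and (1, 2, 29, 14) are assumptions based on how I believe the repo's TSSTG.predict prepares the pose and motion streams (30 frames, 14 joints, motion being frame differences). The question's (30, 13, 3) and (1, 2) probably do not match either stream, so verify them against ActionsEstLoader.py.

```python
import torch
import torch.nn as nn


class TupleInputNet(nn.Module):
    """Hypothetical stand-in for TwoStreamSpatialTemporalGraph.

    Assumption: like that model, forward() takes ONE argument, a tuple
    (pts, mot). Passing trace(model, (a, b)) therefore calls forward(a, b)
    and fails with "takes 2 positional arguments but 3 were given".
    """

    def __init__(self, num_class: int = 7):
        super().__init__()
        self.pts_fc = nn.Linear(3, 8)   # pose stream: 3 channels (x, y, score)
        self.mot_fc = nn.Linear(2, 8)   # motion stream: 2 channels (dx, dy)
        self.head = nn.Linear(16, num_class)

    def forward(self, inputs):
        pts, mot = inputs
        # Average over time (dim 2) and joints (dim 3): (N, C, T, V) -> (N, C)
        a = self.pts_fc(pts.mean(dim=(2, 3)))
        b = self.mot_fc(mot.mean(dim=(2, 3)))
        return torch.sigmoid(self.head(torch.cat([a, b], dim=-1)))


class TwoInputWrapper(nn.Module):
    """Hypothetical wrapper exposing forward(pts, mot) with two tensor inputs."""

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    def forward(self, pts, mot):
        # Repack the two tensors into the tuple the wrapped model expects
        return self.model((pts, mot))


model = TupleInputNet().eval()

# Assumed shapes: batch 1, 3 or 2 channels, 30 (or 29) frames, 14 joints
pts = torch.ones(1, 3, 30, 14)
mot = torch.ones(1, 2, 29, 14)

with torch.no_grad():
    # Option 1: the outer 1-tuple means forward() receives (pts, mot) as ONE argument
    traced_tuple = torch.jit.trace(model, ((pts, mot),))
    # Option 2: trace the wrapper, whose forward() takes two positional tensors
    traced_pair = torch.jit.trace(TwoInputWrapper(model), (pts, mot))

    ref = model((pts, mot))
    out_tuple = traced_tuple((pts, mot))
    out_pair = traced_pair(pts, mot)

print(out_pair.shape)  # torch.Size([1, 7])
```

With the real model, the same pattern would be TwoInputWrapper(model) traced on correctly shaped pts/mot tensors, then saved with traced_pair.save('pose.pt'). That file is then handed to RKNN-Toolkit's PyTorch loader with one input size per input. The exact input-size format (with or without the batch dimension) differs between toolkit versions, so check the docs for the version you use.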