Toybrick

标题: 关于级联网络的推理中的格式转换问题 [打印本页]

作者: 萌虎龟来    时间: 2023-11-22 11:09
标题: 关于级联网络的推理中的格式转换问题
因为需要再RK3588部署一个单目标跟踪项目,众所周知,跟踪项目通常为两个模型,一个特征提取模型,一个特征融合模型
特征提取模型输入为cv Mat,输出为tensor,具体如下:
  1. [Instance branch]: INPUT_ATTR
  2. [TENSOR INDEX:            0]
  3. [TENSOR NAME:             input]
  4. [TENSOR N_DIMS:           4]
  5. [TENSOR DIMS:             [1, 255, 255, 3, ]]
  6. [TENSOR N_ELEMS:          195075]
  7. [TENSOR SIZE:             195075]
  8. [TENSOR W_STRIDE:         256]
  9. [TENSOR SIZE_WITH_STRIDE: 195840]
  10. [TENSOR FORMAT:           NHWC]
  11. [TENSOR TYPE:             INT8]
  12. [TENSOR QNT_TYPE:         AFFINE]
  13. [TENSOR ZP:               -128]
  14. [TENSOR SCALE:            1]

  15. [Instance branch]: OUTPUT_ATTR
  16. [TENSOR INDEX:            0]
  17. [TENSOR NAME:             output]
  18. [TENSOR N_DIMS:           4]
  19. [TENSOR DIMS:             [1, 48, 16, 16, ]]
  20. [TENSOR N_ELEMS:          12288]
  21. [TENSOR SIZE:             12288]
  22. [TENSOR W_STRIDE:         0]
  23. [TENSOR SIZE_WITH_STRIDE: 12288]
  24. [TENSOR FORMAT:           NCHW]
  25. [TENSOR TYPE:             INT8]
  26. [TENSOR QNT_TYPE:         AFFINE]
  27. [TENSOR ZP:               3]
  28. [TENSOR SCALE:            0.14404]
复制代码
特征融合模型输入为tensor, 输出为tensor,具体如下:
  1. [Correlation branch]: INPUT_ATTR
  2. [TENSOR INDEX:            0]
  3. [TENSOR NAME:             input1]
  4. [TENSOR N_DIMS:           4]
  5. [TENSOR DIMS:             [1, 8, 8, 48, ]]
  6. [TENSOR N_ELEMS:          3072]
  7. [TENSOR SIZE:             3072]
  8. [TENSOR W_STRIDE:         8]
  9. [TENSOR SIZE_WITH_STRIDE: 3072]
  10. [TENSOR FORMAT:           NHWC]
  11. [TENSOR TYPE:             INT8]
  12. [TENSOR QNT_TYPE:         AFFINE]
  13. [TENSOR ZP:               -22]
  14. [TENSOR SCALE:            0.124435]

  15. [Correlation branch]: INPUT_ATTR
  16. [TENSOR INDEX:            1]
  17. [TENSOR NAME:             input2]
  18. [TENSOR N_DIMS:           4]
  19. [TENSOR DIMS:             [1, 16, 16, 48, ]]
  20. [TENSOR N_ELEMS:          12288]
  21. [TENSOR SIZE:             12288]
  22. [TENSOR W_STRIDE:         16]
  23. [TENSOR SIZE_WITH_STRIDE: 12288]
  24. [TENSOR FORMAT:           NHWC]
  25. [TENSOR TYPE:             INT8]
  26. [TENSOR QNT_TYPE:         AFFINE]
  27. [TENSOR ZP:               -4]
  28. [TENSOR SCALE:            0.189816]

  29. [Correlation branch]: OUTPUT_ATTR
  30. [TENSOR INDEX:            0]
  31. [TENSOR NAME:             output1]
  32. [TENSOR N_DIMS:           4]
  33. [TENSOR DIMS:             [1, 2, 16, 16, ]]
  34. [TENSOR N_ELEMS:          512]
  35. [TENSOR SIZE:             512]
  36. [TENSOR W_STRIDE:         0]
  37. [TENSOR SIZE_WITH_STRIDE: 4096]
  38. [TENSOR FORMAT:           NCHW]
  39. [TENSOR TYPE:             INT8]
  40. [TENSOR QNT_TYPE:         AFFINE]
  41. [TENSOR ZP:               -5]
  42. [TENSOR SCALE:            0.0481221]

  43. [Correlation branch]: OUTPUT_ATTR
  44. [TENSOR INDEX:            1]
  45. [TENSOR NAME:             output2]
  46. [TENSOR N_DIMS:           4]
  47. [TENSOR DIMS:             [1, 4, 16, 16, ]]
  48. [TENSOR N_ELEMS:          1024]
  49. [TENSOR SIZE:             1024]
  50. [TENSOR W_STRIDE:         0]
  51. [TENSOR SIZE_WITH_STRIDE: 4096]
  52. [TENSOR FORMAT:           NCHW]
  53. [TENSOR TYPE:             INT8]
  54. [TENSOR QNT_TYPE:         AFFINE]
  55. [TENSOR ZP:               -128]
  56. [TENSOR SCALE:            0.315689]
复制代码
当前我采用的编程思路为特征提取模型采用零拷贝API
因为特征融合网络输入通道不为1,3,4,所以只能采用通用API
问题在于:
特征提取网络输出格式为NCHW,数据类型为INT8,特征融合网络输入NHWC,数据类型为INT8
我针对特征提取网络输出做了如下设置:
  1. attr_output_z.fmt = RKNN_TENSOR_NCHW;
  2.     attr_output_z.type = RKNN_TENSOR_FLOAT32;
  3.     ret = rknn_set_io_mem(ctx_z, mem_input_z, &attr_input_z);
复制代码
对于特征融合网络如下设置:
  1. rknn_input input;
  2.     for (int i = 0; i < io_num.n_input; i++) {
  3.         input.index = i;
  4.         input.type = RKNN_TENSOR_FLOAT32;
  5.         input.size = attr_inputs[i].size * sizeof(float);
  6.         input.fmt = RKNN_TENSOR_NCHW;
  7.         input.pass_through = 0;

  8.         inputs.push_back(input);
  9.     }
复制代码
两个模型之间的数据赋值,即特征提取网络结果作为特征融合模型输入代码如下:
  1. inputs[0].buf = mem_output_z->virt_addr;
复制代码


报错信息如下:
  1. E RKNN: [02:51:54.114] Meet unsupported src layout for normalize: 2
  2. E RKNN: [02:51:54.114] rknn_inputs_set, normalize error(-1) index=0
  3. E RKNN: [02:51:54.114] Meet unsupported src layout for normalize: 2
  4. E RKNN: [02:51:54.114] rknn_inputs_set, normalize error(-1) index=1
复制代码
请问是思路有问题,比如统一使用通用API?还是需要手动将NCHW的输入转换为NHWC,在进行操作?





欢迎光临 Toybrick (https://t.rock-chips.com/) Powered by Discuz! X3.3