Toybrick

标题: Tensorflow转换RKNN模型求助 [打印本页]

作者: 86667780    时间: 2020-5-4 20:33
标题: Tensorflow转换RKNN模型求助
  1. from rknn.api import RKNN
复制代码
转换代码如上。我查看了 Toolkit 使用说明书：由于输入是 4096 个空间点数据，每个点有 x、y、z 坐标，因此 config 中没有配置归一化。另外还有一个疑问：我使用的 PB 模型，是用开源算法训练得到 CKPT 后自己进行固化的。固化时只指定了输出节点，所以我一直不确定输入节点是不是 Placeholder。
网络结构如下
[attach]1055[/attach][attach]1056[/attach]

源码model和train文件如下

Model
  1. import tensorflow as tf
  2. import math
  3. import time
  4. import numpy as np
  5. import os
  6. import sys
  7. BASE_DIR = os.path.dirname(os.path.abspath(__file__))
  8. ROOT_DIR = os.path.dirname(BASE_DIR)
  9. sys.path.append(os.path.join(ROOT_DIR, 'utils'))
  10. import tf_util

  11. def placeholder_inputs(batch_size, num_point):
  12.     pointclouds_pl = tf.placeholder(tf.float32,
  13.                                      shape=(batch_size, num_point, 9))
  14.     labels_pl = tf.placeholder(tf.int32,
  15.                                 shape=(batch_size, num_point))
  16.     return pointclouds_pl, labels_pl

  17. def get_model(point_cloud, is_training, bn_decay=None):
  18.     """ ConvNet baseline, input is BxNx3 gray image """
  19.     batch_size = point_cloud.get_shape()[0].value
  20.     num_point = point_cloud.get_shape()[1].value

  21.     input_image = tf.expand_dims(point_cloud, -1)
  22.     # CONV
  23.     net = tf_util.conv2d(input_image, 64, [1,9], padding='VALID', stride=[1,1],
  24.                          bn=True, is_training=is_training, scope='conv1', bn_decay=bn_decay)
  25.     net = tf_util.conv2d(net, 64, [1,1], padding='VALID', stride=[1,1],
  26.                          bn=True, is_training=is_training, scope='conv2', bn_decay=bn_decay)
  27.     net = tf_util.conv2d(net, 64, [1,1], padding='VALID', stride=[1,1],
  28.                          bn=True, is_training=is_training, scope='conv3', bn_decay=bn_decay)
  29.     net = tf_util.conv2d(net, 128, [1,1], padding='VALID', stride=[1,1],
  30.                          bn=True, is_training=is_training, scope='conv4', bn_decay=bn_decay)
  31.     points_feat1 = tf_util.conv2d(net, 1024, [1,1], padding='VALID', stride=[1,1],
  32.                          bn=True, is_training=is_training, scope='conv5', bn_decay=bn_decay)
  33.     # MAX
  34.     pc_feat1 = tf_util.max_pool2d(points_feat1, [num_point,1], padding='VALID', scope='maxpool1')
  35.     # FC
  36.     pc_feat1 = tf.reshape(pc_feat1, [batch_size, -1])
  37.     pc_feat1 = tf_util.fully_connected(pc_feat1, 256, bn=True, is_training=is_training, scope='fc1', bn_decay=bn_decay)
  38.     pc_feat1 = tf_util.fully_connected(pc_feat1, 128, bn=True, is_training=is_training, scope='fc2', bn_decay=bn_decay)
  39.     print(pc_feat1)
  40.    
  41.     # CONCAT
  42.     pc_feat1_expand = tf.tile(tf.reshape(pc_feat1, [batch_size, 1, 1, -1]), [1, num_point, 1, 1])
  43.     points_feat1_concat = tf.concat(axis=3, values=[points_feat1, pc_feat1_expand])
  44.    
  45.     # CONV
  46.     net = tf_util.conv2d(points_feat1_concat, 512, [1,1], padding='VALID', stride=[1,1],
  47.                          bn=True, is_training=is_training, scope='conv6')
  48.     net = tf_util.conv2d(net, 256, [1,1], padding='VALID', stride=[1,1],
  49.                          bn=True, is_training=is_training, scope='conv7')
  50.     net = tf_util.dropout(net, keep_prob=0.7, is_training=is_training, scope='dp1')
  51.     net = tf_util.conv2d(net, 13, [1,1], padding='VALID', stride=[1,1],
  52.                          activation_fn=None, scope='conv8')
  53.     net = tf.squeeze(net, [2])

  54.     return net

  55. def get_loss(pred, label):
  56.     """ pred: B,N,13
  57.         label: B,N """
  58.     loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=pred, labels=label)
  59.     return tf.reduce_mean(loss)

  60. if __name__ == "__main__":
  61.     with tf.Graph().as_default():
  62.         a = tf.placeholder(tf.float32, shape=(32,4096,9))
  63.         net = get_model(a, tf.constant(True))
  64.         with tf.Session() as sess:
  65.             init = tf.global_variables_initializer()
  66.             sess.run(init)
  67.             start = time.time()
  68.             for i in range(100):
  69.                 print(i)
  70.                 sess.run(net, feed_dict={a:np.random.rand(32,4096,9)})
  71.             print(time.time() - start)
复制代码


Train
  1. import argparse
  2. import math
  3. import h5py
  4. import numpy as np
  5. import tensorflow as tf
  6. import socket

  7. import os
  8. import sys
  9. BASE_DIR = os.path.dirname(os.path.abspath(__file__))
  10. ROOT_DIR = os.path.dirname(BASE_DIR)
  11. sys.path.append(BASE_DIR)
  12. sys.path.append(ROOT_DIR)
  13. sys.path.append(os.path.join(ROOT_DIR, 'utils'))
  14. import provider
  15. import tf_util
  16. from model import *


  17. parser = argparse.ArgumentParser()
  18. parser.add_argument('--gpu', type=int, default=0, help='GPU to use [default: GPU 0]')
  19. parser.add_argument('--log_dir', default='log', help='Log dir [default: log]')
  20. parser.add_argument('--num_point', type=int, default=4096, help='Point number [default: 4096]')
  21. parser.add_argument('--max_epoch', type=int, default=50, help='Epoch to run [default: 50]')
  22. parser.add_argument('--batch_size', type=int, default=24, help='Batch Size during training [default: 24]')
  23. parser.add_argument('--learning_rate', type=float, default=0.001, help='Initial learning rate [default: 0.001]')
  24. parser.add_argument('--momentum', type=float, default=0.9, help='Initial learning rate [default: 0.9]')
  25. parser.add_argument('--optimizer', default='adam', help='adam or momentum [default: adam]')
  26. parser.add_argument('--decay_step', type=int, default=300000, help='Decay step for lr decay [default: 300000]')
  27. parser.add_argument('--decay_rate', type=float, default=0.5, help='Decay rate for lr decay [default: 0.5]')
  28. parser.add_argument('--test_area', type=int, default=6, help='Which area to use for test, option: 1-6 [default: 6]')
  29. FLAGS = parser.parse_args()


  30. BATCH_SIZE = FLAGS.batch_size
  31. NUM_POINT = FLAGS.num_point
  32. MAX_EPOCH = FLAGS.max_epoch
  33. NUM_POINT = FLAGS.num_point
  34. BASE_LEARNING_RATE = FLAGS.learning_rate
  35. GPU_INDEX = FLAGS.gpu
  36. MOMENTUM = FLAGS.momentum
  37. OPTIMIZER = FLAGS.optimizer
  38. DECAY_STEP = FLAGS.decay_step
  39. DECAY_RATE = FLAGS.decay_rate

  40. LOG_DIR = FLAGS.log_dir
  41. if not os.path.exists(LOG_DIR): os.mkdir(LOG_DIR)
  42. os.system('cp model.py %s' % (LOG_DIR)) # bkp of model def
  43. os.system('cp train.py %s' % (LOG_DIR)) # bkp of train procedure
  44. LOG_FOUT = open(os.path.join(LOG_DIR, 'log_train.txt'), 'w')
  45. LOG_FOUT.write(str(FLAGS)+'\n')

  46. MAX_NUM_POINT = 4096
  47. NUM_CLASSES = 13

  48. BN_INIT_DECAY = 0.5
  49. BN_DECAY_DECAY_RATE = 0.5
  50. #BN_DECAY_DECAY_STEP = float(DECAY_STEP * 2)
  51. BN_DECAY_DECAY_STEP = float(DECAY_STEP)
  52. BN_DECAY_CLIP = 0.99

  53. HOSTNAME = socket.gethostname()

  54. ALL_FILES = provider.getDataFiles('indoor3d_sem_seg_hdf5_data/all_files.txt')
  55. room_filelist = [line.rstrip() for line in open('indoor3d_sem_seg_hdf5_data/room_filelist.txt')]

  56. # Load ALL data
  57. data_batch_list = []
  58. label_batch_list = []
  59. for h5_filename in ALL_FILES:
  60.     data_batch, label_batch = provider.loadDataFile(h5_filename)
  61.     data_batch_list.append(data_batch)
  62.     label_batch_list.append(label_batch)
  63. data_batches = np.concatenate(data_batch_list, 0)
  64. label_batches = np.concatenate(label_batch_list, 0)
  65. print(data_batches.shape)
  66. print(label_batches.shape)

  67. test_area = 'Area_'+str(FLAGS.test_area)
  68. train_idxs = []
  69. test_idxs = []
  70. for i,room_name in enumerate(room_filelist):
  71.     if test_area in room_name:
  72.         test_idxs.append(i)
  73.     else:
  74.         train_idxs.append(i)

  75. train_data = data_batches[train_idxs,...]
  76. train_label = label_batches[train_idxs]
  77. test_data = data_batches[test_idxs,...]
  78. test_label = label_batches[test_idxs]
  79. print(train_data.shape, train_label.shape)
  80. print(test_data.shape, test_label.shape)




  81. def log_string(out_str):
  82.     LOG_FOUT.write(out_str+'\n')
  83.     LOG_FOUT.flush()
  84.     print(out_str)


  85. def get_learning_rate(batch):
  86.     learning_rate = tf.train.exponential_decay(
  87.                         BASE_LEARNING_RATE,  # Base learning rate.
  88.                         batch * BATCH_SIZE,  # Current index into the dataset.
  89.                         DECAY_STEP,          # Decay step.
  90.                         DECAY_RATE,          # Decay rate.
  91.                         staircase=True)
  92.     learning_rate = tf.maximum(learning_rate, 0.00001) # CLIP THE LEARNING RATE!!
  93.     return learning_rate        

  94. def get_bn_decay(batch):
  95.     bn_momentum = tf.train.exponential_decay(
  96.                       BN_INIT_DECAY,
  97.                       batch*BATCH_SIZE,
  98.                       BN_DECAY_DECAY_STEP,
  99.                       BN_DECAY_DECAY_RATE,
  100.                       staircase=True)
  101.     bn_decay = tf.minimum(BN_DECAY_CLIP, 1 - bn_momentum)
  102.     return bn_decay

  103. def train():
  104.     with tf.Graph().as_default():
  105.         with tf.device('/gpu:'+str(GPU_INDEX)):
  106.             pointclouds_pl, labels_pl = placeholder_inputs(BATCH_SIZE, NUM_POINT)
  107.             is_training_pl = tf.placeholder(tf.bool, shape=())
  108.             
  109.             # Note the global_step=batch parameter to minimize.
  110.             # That tells the optimizer to helpfully increment the 'batch' parameter for you every time it trains.
  111.             batch = tf.Variable(0)
  112.             bn_decay = get_bn_decay(batch)
  113.             tf.summary.scalar('bn_decay', bn_decay)

  114.             # Get model and loss
  115.             pred = get_model(pointclouds_pl, is_training_pl, bn_decay=bn_decay)
  116.             loss = get_loss(pred, labels_pl)
  117.             tf.summary.scalar('loss', loss)

  118.             correct = tf.equal(tf.argmax(pred, 2), tf.to_int64(labels_pl))
  119.             accuracy = tf.reduce_sum(tf.cast(correct, tf.float32)) / float(BATCH_SIZE*NUM_POINT)
  120.             tf.summary.scalar('accuracy', accuracy)

  121.             # Get training operator
  122.             learning_rate = get_learning_rate(batch)
  123.             tf.summary.scalar('learning_rate', learning_rate)
  124.             if OPTIMIZER == 'momentum':
  125.                 optimizer = tf.train.MomentumOptimizer(learning_rate, momentum=MOMENTUM)
  126.             elif OPTIMIZER == 'adam':
  127.                 optimizer = tf.train.AdamOptimizer(learning_rate)
  128.             train_op = optimizer.minimize(loss, global_step=batch)
  129.             
  130.             # Add ops to save and restore all the variables.
  131.             saver = tf.train.Saver()
  132.             
  133.         # Create a session
  134.         config = tf.ConfigProto()
  135.         config.gpu_options.allow_growth = True
  136.         config.allow_soft_placement = True
  137.         config.log_device_placement = True
  138.         sess = tf.Session(config=config)

  139.         # Add summary writers
  140.         merged = tf.summary.merge_all()
  141.         train_writer = tf.summary.FileWriter(os.path.join(LOG_DIR, 'train'),
  142.                                   sess.graph)
  143.         test_writer = tf.summary.FileWriter(os.path.join(LOG_DIR, 'test'))

  144.         # Init variables
  145.         init = tf.global_variables_initializer()
  146.         sess.run(init, {is_training_pl:True})

  147.         ops = {'pointclouds_pl': pointclouds_pl,
  148.                'labels_pl': labels_pl,
  149.                'is_training_pl': is_training_pl,
  150.                'pred': pred,
  151.                'loss': loss,
  152.                'train_op': train_op,
  153.                'merged': merged,
  154.                'step': batch}

  155.         for epoch in range(MAX_EPOCH):
  156.             log_string('**** EPOCH %03d ****' % (epoch))
  157.             sys.stdout.flush()
  158.             
  159.             train_one_epoch(sess, ops, train_writer)
  160.             eval_one_epoch(sess, ops, test_writer)
  161.             
  162.             # Save the variables to disk.
  163.             if epoch % 10 == 0:
  164.                 save_path = saver.save(sess, os.path.join(LOG_DIR, "model.ckpt"))
  165.                 log_string("Model saved in file: %s" % save_path)




  166. def train_one_epoch(sess, ops, train_writer):
  167.     """ ops: dict mapping from string to tf ops """
  168.     is_training = True
  169.    
  170.     log_string('----')
  171.     current_data, current_label, _ = provider.shuffle_data(train_data[:,0:NUM_POINT,:], train_label)
  172.    
  173.     file_size = current_data.shape[0]
  174.     num_batches = file_size // BATCH_SIZE
  175.    
  176.     total_correct = 0
  177.     total_seen = 0
  178.     loss_sum = 0
  179.    
  180.     for batch_idx in range(num_batches):
  181.         if batch_idx % 100 == 0:
  182.             print('Current batch/total batch num: %d/%d'%(batch_idx,num_batches))
  183.         start_idx = batch_idx * BATCH_SIZE
  184.         end_idx = (batch_idx+1) * BATCH_SIZE
  185.         
  186.         feed_dict = {ops['pointclouds_pl']: current_data[start_idx:end_idx, :, :],
  187.                      ops['labels_pl']: current_label[start_idx:end_idx],
  188.                      ops['is_training_pl']: is_training,}
  189.         summary, step, _, loss_val, pred_val = sess.run([ops['merged'], ops['step'], ops['train_op'], ops['loss'], ops['pred']],
  190.                                          feed_dict=feed_dict)
  191.         train_writer.add_summary(summary, step)
  192.         pred_val = np.argmax(pred_val, 2)
  193.         correct = np.sum(pred_val == current_label[start_idx:end_idx])
  194.         total_correct += correct
  195.         total_seen += (BATCH_SIZE*NUM_POINT)
  196.         loss_sum += loss_val
  197.    
  198.     log_string('mean loss: %f' % (loss_sum / float(num_batches)))
  199.     log_string('accuracy: %f' % (total_correct / float(total_seen)))

  200.         
  201. def eval_one_epoch(sess, ops, test_writer):
  202.     """ ops: dict mapping from string to tf ops """
  203.     is_training = False
  204.     total_correct = 0
  205.     total_seen = 0
  206.     loss_sum = 0
  207.     total_seen_class = [0 for _ in range(NUM_CLASSES)]
  208.     total_correct_class = [0 for _ in range(NUM_CLASSES)]
  209.    
  210.     log_string('----')
  211.     current_data = test_data[:,0:NUM_POINT,:]
  212.     current_label = np.squeeze(test_label)
  213.    
  214.     file_size = current_data.shape[0]
  215.     num_batches = file_size // BATCH_SIZE
  216.    
  217.     for batch_idx in range(num_batches):
  218.         start_idx = batch_idx * BATCH_SIZE
  219.         end_idx = (batch_idx+1) * BATCH_SIZE

  220.         feed_dict = {ops['pointclouds_pl']: current_data[start_idx:end_idx, :, :],
  221.                      ops['labels_pl']: current_label[start_idx:end_idx],
  222.                      ops['is_training_pl']: is_training}
  223.         summary, step, loss_val, pred_val = sess.run([ops['merged'], ops['step'], ops['loss'], ops['pred']],
  224.                                       feed_dict=feed_dict)
  225.         test_writer.add_summary(summary, step)
  226.         pred_val = np.argmax(pred_val, 2)
  227.         correct = np.sum(pred_val == current_label[start_idx:end_idx])
  228.         total_correct += correct
  229.         total_seen += (BATCH_SIZE*NUM_POINT)
  230.         loss_sum += (loss_val*BATCH_SIZE)
  231.         for i in range(start_idx, end_idx):
  232.             for j in range(NUM_POINT):
  233.                 l = current_label[i, j]
  234.                 total_seen_class[l] += 1
  235.                 total_correct_class[l] += (pred_val[i-start_idx, j] == l)
  236.             
  237.     log_string('eval mean loss: %f' % (loss_sum / float(total_seen/NUM_POINT)))
  238.     log_string('eval accuracy: %f'% (total_correct / float(total_seen)))
  239.     log_string('eval avg class acc: %f' % (np.mean(np.array(total_correct_class)/np.array(total_seen_class,dtype=np.float))))
  240.          


  241. if __name__ == "__main__":
  242.     train()
  243.     LOG_FOUT.close()
复制代码


运行PB转换RKNN时报错信息如下
  1. W:tensorflow:From /home/toybrick/.local/lib/python3.7/site-packages/onnx_tf/handlers/backend/ceil.py:10: The name tf.ceil is deprecated. Please use tf.math.ceil instead.

  2. W:tensorflow:From /home/toybrick/.local/lib/python3.7/site-packages/onnx_tf/handlers/backend/depth_to_space.py:12: The name tf.depth_to_space is deprecated. Please use tf.compat.v1.depth_to_space instead.

  3. W:tensorflow:
  4. The TensorFlow contrib module will not be included in TensorFlow 2.0.
  5. For more information, please see:
  6.   * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  7.   * https://github.com/tensorflow/addons
  8.   * https://github.com/tensorflow/io (for I/O related ops)
  9. If you depend on functionality not listed there, please file an issue.

  10. W:tensorflow:From /home/toybrick/.local/lib/python3.7/site-packages/onnx_tf/handlers/backend/log.py:10: The name tf.log is deprecated. Please use tf.math.log instead.

  11. W:tensorflow:From /home/toybrick/.local/lib/python3.7/site-packages/onnx_tf/handlers/backend/random_normal.py:9: The name tf.random_normal is deprecated. Please use tf.random.normal instead.

  12. W:tensorflow:From /home/toybrick/.local/lib/python3.7/site-packages/onnx_tf/handlers/backend/random_uniform.py:9: The name tf.random_uniform is deprecated. Please use tf.random.uniform instead.

  13. W:tensorflow:From /home/toybrick/.local/lib/python3.7/site-packages/onnx_tf/handlers/backend/upsample.py:13: The name tf.image.resize_images is deprecated. Please use tf.image.resize instead.

  14. /home/toybrick/.local/lib/python3.7/site-packages/onnx_tf/common/__init__.py:87: UserWarning: FrontendHandler.get_outputs_names is deprecated. It will be removed in future release.. Use node.outputs instead.
  15.   warnings.warn(message)
  16. W:tensorflow:From /home/toybrick/.local/lib/python3.7/site-packages/rknn/api/rknn.py:67: extract_sub_graph (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version.
  17. Instructions for updating:
  18. Use `tf.compat.v1.graph_util.extract_sub_graph`
  19. E Catch exception when loading tensorflow model: pointnet_model_final.pb!
  20. E Traceback (most recent call last):
  21. E   File "rknn/api/rknn_base.py", line 215, in rknn.api.rknn_base.RKNNBase.load_tensorflow
  22. E   File "rknn/base/RKNNlib/converter/convert_tf.py", line 527, in rknn.base.RKNNlib.converter.convert_tf.convert_tf.pre_process
  23. E   File "rknn/base/RKNNlib/converter/tensorflowloader.py", line 77, in rknn.base.RKNNlib.converter.tensorflowloader.TF_Graph_Preprocess.pre_proces
  24. E   File "rknn/base/RKNNlib/converter/tensorflowloader.py", line 520, in rknn.base.RKNNlib.converter.tensorflowloader.TF_Graph_Preprocess.freeze_switch_path_v3
  25. E   File "rknn/base/RKNNlib/converter/tensorflowloader.py", line 436, in rknn.base.RKNNlib.converter.tensorflowloader.TF_Graph_Preprocess.freeze_switch_path_v3.fix_select_branch
  26. E   File "rknn/base/RKNNlib/converter/tf_util.py", line 198, in rknn.base.RKNNlib.converter.tf_util.TFProto_Util.change_input
  27. E   File "/home/toybrick/.local/lib/python3.7/site-packages/google/protobuf/internal/containers.py", line 204, in __getitem__
  28. E     return self._values[key]
  29. E IndexError: list index out of range
  30. done
  31. --> Building model
  32. Traceback (most recent call last):
  33.   File "rknn_transfer.py", line 26, in <module>
  34.     rknn.build()
  35.   File "/home/toybrick/.local/lib/python3.7/site-packages/rknn/api/rknn.py", line 222, in build
  36.     inputs = self.rknn_base.net.get_input_layers()
  37. AttributeError: 'NoneType' object has no attribute 'get_input_layers'
复制代码



作者: 86667780    时间: 2020-5-4 20:35
  1. from rknn.api import RKNN  


  2. if __name__ == '__main__':
  3.     # 创建RKNN执行对象
  4.     rknn = RKNN()

  5.     rknn.config()

  6. # 加载TensorFlow模型
  7. # tf_pb='digital_gesture.pb'指定待转换的TensorFlow模型
  8. # inputs指定模型中的输入节点
  9. # outputs指定模型中输出节点
  10. # input_size_list指定模型输入的大小
  11.     print('--> Loading model')
  12.     rknn.load_tensorflow(tf_pb='pointnet_model_final.pb',
  13.                          inputs=['Placeholder'],
  14.                          outputs=['Squeeze'],
  15.                          input_size_list=[[4096, 3]])
  16.     print('done')

  17. # 创建解析pb模型
  18. # do_quantization=False指定不进行量化
  19. # 量化会减小模型的体积和提升运算速度,但是会有精度的丢失
  20.     print('--> Building model')
  21.     rknn.build(do_quantization=False)
  22.     print('done')

  23.     # 导出保存rknn模型文件
  24.     print('--> Exporting model')
  25.     rknn.export_rknn('./digital_gesture.rknn')
  26.     print('done')

  27.     # Release RKNN Context
  28.     print('--> Release model')
  29.     rknn.release()
  30.     print('done')
复制代码


呃，一楼的转换代码不小心没贴全，完整代码补充在上面。
作者: leok    时间: 2020-5-5 11:29
pb模型文件发出来
作者: jefferyzhang    时间: 2020-5-5 13:26
E   File "/home/toybrick/.local/lib/python3.7/site-packages/google/protobuf/internal/containers.py", line 204, in __getitem__
E     return self._values[key]
E IndexError: list index out of range

1. 确认并提供下tensorflow版本号。
2. 确认并提供 rknn-toolkit 版本号。
3. 确认这个pb文件在转换环境下是可以被tensorflow加载并推理的。
作者: 86667780    时间: 2020-5-5 14:03
leok 发表于 2020-5-5 11:29
pb模型文件发出来

PB 模型在压缩包中。由于论坛等级较低、附件大小受限，没办法上传原始的 CKPT 文件。

作者: 86667780    时间: 2020-5-5 14:27
jefferyzhang 发表于 2020-5-5 13:26
E   File "/home/toybrick/.local/lib/python3.7/site-packages/google/protobuf/internal/containers.py", ...

您好
1. TensorFlow 版本为 1.14.0。
2. 运行官方 rknn/mobilenet-ssd 示例，返回的版本如下：
    RKNNAPI:   API:1.3.2
    RKNNAPI:   DRV:1.3.1
    rknn-toolkit:  1.3.2(?)（不确定这个版本号该怎么查看；用 pip 安装时显示的是 rknn-1.3.2。我是按照“rknn-toolkit && rknn-api for Toybrick”那个帖子安装的，应该都是最新版本）
3. pb 模型我自己用 Netron 查看过，觉得网络没什么问题，所以没有单独测试（我是入门新手，时间比较紧，当时想跳过这一步直接用 RKNN 推理）。我现在去试一下。
作者: leok    时间: 2020-5-5 16:47
86667780 发表于 2020-5-5 14:03
PB模型在压缩包中,等级比较低没办法上传原始的CKPT文件(受大小限制)。
...

conv1/Conv2D
你用这个节点作为 input 去转换，前面 2 个节点的运算放到前处理中完成。




欢迎光临 Toybrick (https://t.rock-chips.com/) Powered by Discuz! X3.3