Toybrick

RK3399Pro入门教程(9)MNIST RKNN量化教程

peng

注册会员

积分
169
楼主
发表于 2019-9-20 08:34:19    查看: 43355|回复: 11 | [复制链接]    打印 | 显示全部楼层

一.搭建网络
    model.py,同tensorflow官网mnist的例子差不多,不懂的可以去官网看下官网代码解析
  1. import tensorflow as tf

  2. #这里输入采用28x28方便之后进行rknn量化
  3. x = tf.placeholder("float", [None, 28,28],name='x')
  4. y_ = tf.placeholder("float", [None,10],name='y_')
  5. keep_prob = tf.placeholder("float", name='keep_prob')
  6. def weight_variable(shape,name):
  7.     initial = tf.truncated_normal(shape, stddev=0.1)
  8.     return tf.Variable(initial,name=name)

  9. def bias_variable(shape,name):
  10.     initial = tf.constant(0.1, shape=shape)
  11.     return tf.Variable(initial,name=name)

  12. def conv2d(x, W):
  13.     return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')

  14. def max_pool_2x2(x):
  15.     return tf.nn.max_pool(x, ksize=[1, 2, 2, 1],
  16.                         strides=[1, 2, 2, 1], padding='SAME')

  17. # convolution layer
  18. def lenet5_layer(input, weight, bias,weight_name,bias_name):
  19.     W_conv = weight_variable(weight,weight_name)
  20.     b_conv = bias_variable(bias,bias_name)
  21.     h_conv = tf.nn.relu(conv2d(input, W_conv) + b_conv)
  22.     return max_pool_2x2(h_conv)
  23. # connected layer
  24. def dense_layer(layer, weight, bias,weight_name,bias_name):
  25.     W_fc = weight_variable(weight,weight_name)
  26.     b_fc = bias_variable(bias,bias_name)
  27.     return tf.nn.relu(tf.matmul(layer, W_fc) + b_fc)

  28. def build_model(is_training):
  29.     #first conv
  30.     x_image = tf.reshape(x, [-1,28,28,1])
  31.     W_conv1 = [5, 5, 1, 32]
  32.     b_conv1 = [32]
  33.     layer1 = lenet5_layer(x_image,W_conv1,b_conv1,'W_conv1','b_conv1')
  34.     #second conv
  35.     W_conv2 = [5, 5, 32, 64]
  36.     b_conv2 = [64]
  37.     layer2 = lenet5_layer(layer1,W_conv2,b_conv2,'W_conv2','b_conv2')
  38.     #third conv
  39.     W_fc1 = [7 * 7 * 64, 1024]
  40.     b_fc1 = [1024]
  41.     layer2_flat = tf.reshape(layer2, [-1, 7*7*64])
  42.     layer3 = dense_layer(layer2_flat,W_fc1,b_fc1,'W_fc1','b_fc1')
  43.     #softmax
  44.     W_fc2 = weight_variable([1024, 10],'W_fc2')
  45.     b_fc2 = bias_variable([10],'b_fc2')
  46.     if is_training:
  47.         #dropout
  48.         h_fc1_drop = tf.nn.dropout(layer3, keep_prob)
  49.         finaloutput=tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2,name="y_conv")
  50.     else:
  51.         finaloutput=tf.nn.softmax(tf.matmul(layer3, W_fc2) + b_fc2,name="y_conv")
  52.     print('finaloutput:', finaloutput)
  53.     return finaloutput
复制代码
二.训练网络
    train.py,这里代码兼容了tf伪量化的代码,这里我们把create_training_graph()传入的参数is_quantify设为False就可以了,由于mnist拿到的train和test数据shape都是(784,),这里定义了一个reshape_batch函数把train时的batch以及test时的输入都reshape成(28,28),具体代码如下:
  1. # -*- coding=utf-8 -*-
  2. import tensorflow as tf
  3. from tensorflow.examples.tutorials.mnist import input_data
  4. from model import build_model, x, y_, keep_prob

  5. mnist = input_data.read_data_sets("../MNIST_data/", one_hot=True)
  6. def create_training_graph(is_quantify):
  7.     #创建训练图,加入create_training_graph:
  8.     g = tf.get_default_graph()   # 给create_training_graph的参数,默认图
  9.     #调用网络定义,也就是拿到输出
  10.     y_conv = build_model(is_training=True)    #这里的is_training设置为True,因为前面模型定义写了训练时要用到dropout
  11.     #损失函数
  12.     cross_entropy = -tf.reduce_sum(y_*tf.log(y_conv))
  13.     print('cost:', cross_entropy)
  14.     if is_quantify:
  15.         # 加入 create_training_graph函数,注意位置要在loss之后, optimize之前
  16.         tf.contrib.quantize.create_training_graph(input_graph=g, quant_delay=0)
  17.     #  optimize
  18.     optimize = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
  19.     #计算准确率
  20.     correct_prediction = tf.equal(tf.argmax(y_conv,1), tf.argmax(y_,1))
  21.     # 给出识别准确率[这会返回我们一个布尔值的列表.为了确定哪些部分是正确的,我们要把它转换成浮点值,然后再示均值。 比如, [True, False, True, True] 会转换成 [1,0,1,1] ,从而它的准确率就是0.75.]   
  22.     accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))

  23.     # 返回所需数据,供训练使用
  24.     return dict(
  25.         x=x,
  26.         y=y_,
  27.         keep_prob=keep_prob,
  28.         optimize=optimize,
  29.         cost=cross_entropy,
  30.         correct_prediction=correct_prediction,
  31.         accuracy=accuracy,
  32.     )

  33. def reshape_batch(batch):
  34.     rebatch = []
  35.     for item in batch:
  36.         b = item.reshape(28,28)
  37.         rebatch.append(b)
  38.     return rebatch
  39. #开始训练
  40. def train_network(graph,ckpt,point_dir,pbtxt):
  41.     # 初始化
  42.     init = tf.global_variables_initializer()
  43.     # 调用Saver函数保存所需文件
  44.     saver = tf.train.Saver()
  45.     # 创建上下文,开始训练sess.run(init)
  46.     with tf.Session() as sess:
  47.         sess.run(init)
  48.         # 一共训练两万次,准确率达到百分99以上
  49.         for i in range(20000):
  50.         # 每次处理50张图片
  51.             batch = mnist.train.next_batch(50)
  52.             # 每100次保存并打印一次准确率等
  53.             if i % 100 == 0:
  54.             # feed_dict喂数据,数据全reshape成28x28
  55.                 train_accuracy = sess.run([graph['accuracy']], feed_dict={
  56.                                                                            graph['x']:reshape_batch(batch[0]),    # batch[0]存的图片数据
  57.                                                                            graph['y']:batch[1],    # batch[1]存的标签
  58.                                                                            graph['keep_prob']: 1.0})
  59.                 print("step %d, training accuracy %g"%(i, train_accuracy[0]))
  60.             sess.run([graph['optimize']], feed_dict={
  61.                                                        graph['x']:reshape_batch(batch[0]),
  62.                                                        graph['y']:batch[1],
  63.                                                        graph['keep_prob']:0.5})
  64.         test_accuracy = sess.run([graph['accuracy']], feed_dict={
  65.                                                                   graph['x']: reshape_batch(mnist.test.images),
  66.                                                                   graph['y']: mnist.test.labels,
  67.                                                                   graph['keep_prob']: 1.0})
  68.         print("Test accuracy %g" % test_accuracy[0])
  69.         # 保存ckpt(checkpoint)和pbtxt。记得把路径改成自己的路径
  70.         saver.save(sess, ckpt)
  71.         tf.train.write_graph(sess.graph_def,point_dir,pbtxt, True)
  72.         print(tf.trainable_variables())
  73.         print(tf.get_variable('W_fc2',[1024, 10]).value)


  74. if __name__ == "__main__":
  75.     ckpt = './checkpoint/mnist.ckpt'
  76.     point_dir = './checkpoint'
  77.     pbtxt = 'mnist.pbtxt'
  78.     g1 = create_training_graph(False)
  79.     train_network(g1,ckpt,point_dir,pbtxt)

复制代码
三.保存网络参数
    freese.py,将网络中参数和变量从ckpt中读出来,保存为pb文件,同上一步一样,将frozen函数的is_quantify设为False就可以了:
  1. import tensorflow as tf
  2. import os.path
  3. from model import build_model
  4. from tensorflow.python.framework import graph_util

  5. # 创建推理图
  6. def create_inference_graph():
  7.     """Build the mnist model for evaluation."""
  8. # 调用网络,Create an output to use for inference.
  9.     logits = build_model(is_training=False)
  10.     return logits
  11.     # # 得到分类输出  
  12.     # tf.nn.softmax(logits, name='output')
  13. def load_variables_from_checkpoint(sess, start_checkpoint):
  14.     """Utility function to centralize checkpoint restoration.
  15.     Args:
  16.       sess: TensorFlow session.
  17.       start_checkpoint: Path to saved checkpoint on disk.
  18.     """
  19.     saver = tf.train.Saver(tf.global_variables())
  20.     saver.restore(sess, start_checkpoint)

  21. def frozen(is_quantify,ckpt,pbtxt):
  22.     # Create the model and load its weights.
  23.     init = tf.global_variables_initializer()
  24.     with tf.Session() as sess:
  25.         sess.run(init)
  26. # 推理图
  27.         logits = create_inference_graph()  
  28. # 加入create_eval_graph(),转化为tflite可接受的格式。以下语句中有路径的,记得改路径。
  29.         if is_quantify:
  30.             tf.contrib.quantize.create_eval_graph()
  31.         load_variables_from_checkpoint(sess, ckpt)
  32.         # Turn all the variables into inline constants inside the graph and save it.
  33. # 固化 frozen:ckpt + pbtxt
  34.         frozen_graph_def = graph_util.convert_variables_to_constants(
  35.             sess, sess.graph_def, ['y_conv'])
  36. # 保存最终的pb模型
  37.         tf.train.write_graph(
  38.             frozen_graph_def,
  39.             os.path.dirname(pbtxt),
  40.             os.path.basename(pbtxt),
  41.             as_text=False)
  42.         tf.logging.info('Saved frozen graph to %s', pbtxt)

  43. if __name__ == "__main__":
  44.     ckpt = './checkpoint/mnist.ckpt'
  45.     pbtxt = 'mnist_frozen_graph.pb'
  46.     frozen(False,ckpt,pbtxt)
  47.     #is_quantify False   mnist_frozen_graph_not_28x28.pb
  48.     # ckpt = './checkpoint_not/mnist.ckpt'
  49.     # pbtxt = 'mnist_frozen_graph_not.pb'
  50.     # frozen(False,ckpt,pbtxt)
  51.     # ckpt = './test/mnist.ckpt'
  52.     # pbtxt = 'test.pb'
  53.     # frozen(False,ckpt,pbtxt)
复制代码
四.将pb模型转为rknn
     由于量化rknn模型需要相应图片集,因此我们先要获取相应的数据集进入mnist数据目录下,解压t10k-images-idx3-ubyte.gz,然后运行get_image.py,将原先压缩的数据转为图片,同时得到量化需要的dataset.txt文件。
    get_image.py
  1. import struct
  2. import numpy as np
  3. #import matplotlib.pyplot as plt
  4. import PIL.Image
  5. from PIL import Image
  6. import os

  7. os.system("mkdir ../MNIST_data/mnist_test")
  8. filename='../MNIST_data/t10k-images.idx3-ubyte'
  9. dataset = './dataset.txt'
  10. binfile=open(filename,'rb')
  11. buf=binfile.read()
  12. index=0
  13. data_list = []
  14. magic,numImages,numRows,numColumns=struct.unpack_from('>IIII',buf,index)
  15. index+=struct.calcsize('>IIII')
  16. for image in range(0,numImages):
  17.     im=struct.unpack_from('>784B',buf,index)
  18.     index+=struct.calcsize('>784B')
  19.     im=np.array(im,dtype='uint8')
  20.     im=im.reshape(28,28)
  21.     im=Image.fromarray(im)
  22.     im.save('../MNIST_data/mnist_test/test_%s.jpg'%image,'jpeg')
  23.     data_list.append('../MNIST_data/mnist_test/test_%s.jpg\n'%image)
  24. with open(dataset,'w+') as ff:
  25.     ff.writelines(data_list)
复制代码
rknn_transfer.py:
  1. from rknn.api import RKNN

  2. def common_transfer(pb_name,export_name):
  3.         ret = 0
  4.         #看具体log 传入verbose=True
  5.         rknn = RKNN()
  6.         #灰度图无需此步操作
  7.         # rknn.config(channel_mean_value='', reorder_channel='')
  8.         print('--> Loading model')

  9.         ret = rknn.load_tensorflow(
  10.                 tf_pb='./mnist_frozen_graph.pb',
  11.                 inputs=['x'],
  12.                 outputs=['y_conv'],
  13.                 input_size_list=[[28,28,1]])
  14.         if ret != 0:
  15.                 print('load_tensorflow error')
  16.                 rknn.release()
  17.                 return ret
  18.         print('done')
  19.         print('--> Building model')
  20.         rknn.build(do_quantization=False)
  21.         print('done')
  22.         # 导出保存rknn模型文件
  23.         rknn.export_rknn('./mnist.rknn')
  24.         # Release RKNN Context
  25.         rknn.release()
  26.         return ret

  27. def quantify_transfer(pb_name,dataset_name,export_name):
  28.         ret = 0
  29.         print(pb_name,dataset_name,export_name)
  30.         rknn = RKNN()
  31.         rknn.config(channel_mean_value='', reorder_channel='',quantized_dtype='dynamic_fixed_point-8')
  32.         print('--> Loading model')
  33.         ret = rknn.load_tensorflow(
  34.                 tf_pb=pb_name,
  35.                 inputs=['x'],
  36.                 outputs=['y_conv'],
  37.                 input_size_list=[[28,28,1]])
  38.         if ret != 0:
  39.                 print('load_tensorflow error')
  40.                 rknn.release()
  41.                 return ret
  42.         print('done')
  43.         print('--> Building model')
  44.         rknn.build(do_quantization=True,dataset=dataset_name)
  45.         print('done')
  46.         # 导出保存rknn模型文件
  47.         rknn.export_rknn(export_name)
  48.         # Release RKNN Context
  49.         rknn.release()
  50.         return ret
  51. if __name__ == '__main__':
  52.         #pb转化为rknn模型
  53.         pb_name = './mnist_frozen_graph.pb'
  54.         export_name = './mnist.rknn'
  55.         ret = common_transfer(pb_name,export_name)
  56.         if ret != 0:
  57.                 print('======common transfer error !!===========')
  58.         else:
  59.                 print('======common transfer ok !!===========')
  60.         dataset_name = './dataset.txt'
  61.         export_name = './mnist_quantization.rknn'
  62.         #pb转化为量化的rknn模型
  63.         quantify_transfer(pb_name,dataset_name,export_name)
  64.         if ret != 0:
  65.                 print('======quantization transfer 10000 error !!===========')
  66.         else:
  67.                 print('======quantization transfer 10000 ok !!===========')

复制代码
五.对比pb和rknn的推理结果,比较他们的准确度
     分别运行tf_predict.py,rknn_predict.py得到tf模型,rknn模型,量化的rknn模型的运行结果:
     tf_predict.py
  1. #! -*- coding: utf-8 -*-
  2. from __future__ import absolute_import, unicode_literals
  3. from tensorflow.examples.tutorials.mnist import input_data
  4. import tensorflow as tf

  5. mnist = input_data.read_data_sets("../MNIST_data/", one_hot=True)
  6. origin_test = mnist.test.images
  7. reshape_test = []
  8. for t in origin_test:
  9.     b = t.reshape(28,28)
  10.     reshape_test.append(b)
  11. for length in [100,500,1000,10000]:
  12.     with tf.Graph().as_default():
  13.         output_graph_def = tf.GraphDef()
  14.         output_graph_path = './mnist_frozen_graph.pb'

  15.         with open(output_graph_path, 'rb') as f:
  16.             output_graph_def.ParseFromString(f.read())
  17.             _ = tf.import_graph_def(output_graph_def, name="")
  18.      
  19.         with tf.Session() as sess:
  20.             sess.run(tf.global_variables_initializer())
  21.             input = sess.graph.get_tensor_by_name("x:0")
  22.             output = sess.graph.get_tensor_by_name("y_conv:0")
  23.             y_conv_2 = sess.run(output, feed_dict={input:reshape_test[0:length]})
  24.             y_2 = mnist.test.labels[0:length]
  25.             print("first image:",y_conv_2[0])
  26.             correct_prediction_2 = tf.equal(tf.argmax(y_conv_2, 1), tf.argmax(y_2, 1))
  27.             accuracy_2 = tf.reduce_mean(tf.cast(correct_prediction_2, "float"))
  28.             print('%d:'%length,"check accuracy %g" % sess.run(accuracy_2))
复制代码
rknn_predict.py
  1. import numpy as np
  2. from PIL import Image
  3. from rknn.api import RKNN
  4. import cv2
  5. from tensorflow.examples.tutorials.mnist import input_data
  6. import tensorflow as tf

  7. mnist = input_data.read_data_sets("../MNIST_data/", one_hot=True)
  8. print(mnist.test.images[0].shape)
  9. # 解析模型的输出,获得概率最大的手势和对应的概率
  10. def get_predict(probability):
  11.     data = probability[0][0]
  12.     data = data.tolist()
  13.     max_prob = max(data)
  14.     return data.index(max_prob), max_prob
  15. # return data.index(max_prob), max_prob;
  16. def load_model(model_name):
  17.     # 创建RKNN对象
  18.     rknn = RKNN()
  19.     # 载入RKNN模型
  20.     print('-->loading model')
  21.     rknn.load_rknn(model_name)
  22.     print('loading model done')
  23.     # 初始化RKNN运行环境
  24.     print('--> Init runtime environment')
  25.     ret = rknn.init_runtime()
  26.     if ret != 0:
  27.        print('Init runtime environment failed')
  28.        exit(ret)
  29.     print('done')
  30.     return rknn
  31. def predict(rknn,length):
  32.     acc_count = 0
  33.     for i in range(length):
  34.         # im = mnist.test.images[i]
  35.         im = Image.open("../MNIST_data/mnist_test/test_%d.jpg"%i)   # 加载图片
  36.         im = im.resize((28,28),Image.ANTIALIAS)
  37.         im = np.asarray(im)
  38.         outputs = rknn.inference(inputs=[im])
  39.         pred, prob = get_predict(outputs)
  40.         if i ==0:
  41.             print(outputs)
  42.             print(prob)
  43.             print(pred)
  44.         if i ==100 or i ==500 or i ==1000 or i ==10000:
  45.             result = float(acc_count)/i
  46.             print('result%d:'%i,result)
  47.         if list(mnist.test.labels[i]).index(1) == pred:
  48.             acc_count += 1
  49.     result = float(acc_count)/length
  50.     print('result:',result)
  51.     # acc_count = 0
  52.     # length = len(mnist.test.images)
  53.     # for i in range(length):
  54.         # im = mnist.test.images[i]# 加载图片
  55.         # outputs = rknn.inference(inputs=[im])   # 运行推理,得到推理结果
  56.         # pred, prob = get_predict(outputs)     # 将推理结果转化为可视信息
  57.         # if i%100 == 0:
  58.             # print(prob)
  59.             # print(pred)
  60.             # print(acc_count)
  61.             # print(list(mnist.test.labels[i]).index(1))
  62.         # if list(mnist.test.labels[i]).index(1) == pred:
  63.             # acc_count += 1
  64.     # result = float(acc_count)/length
  65.     # print('result:',result)
  66. if __name__=="__main__":
  67.     #此处要改成相应的量化或者非量化rknn模型
  68.     model_name = './mnist.rknn'
  69.     length = 10000
  70.     rknn = load_model(model_name)
  71.     predict(rknn,length)

  72.     rknn.release()
复制代码
得到最终的结果对比图表如下:


本帖子中包含更多资源

您需要 登录 才可以下载或查看,没有帐号?立即注册

x
回复

使用道具 举报

peng

注册会员

积分
169
沙发
 楼主| 发表于 2019-9-20 08:43:10 | 显示全部楼层
本帖最后由 peng 于 2019-9-27 14:50 编辑

《量化训练》关键的两个函数为:tf.contrib.quantize.create_training_graph,tf.contrib.quantize.create_eval_graph(input_graph=g),在上一节中代码中已经提供了,具体可看以及几个步骤:第一步,修改train.py,将create_training_graph()的is_quantify设为True.然后重新训练模型。
  1. if __name__ == "__main__":
  2. ckpt = './checkpoint_fake/mnist.ckpt'
  3. point_dir = './checkpoint_fake'
  4. pbtxt = 'mnist_fake.pbtxt'
  5. g1 = create_training_graph(True)
  6. train_network(g1,ckpt,point_dir,pbtxt)
复制代码
第二步,修改freese.py,将frozen()的is_quantify设为True.重新固化成pb文件,
  1. if __name__ == "__main__":
  2. ckpt = './checkpoint_fake/mnist.ckpt'
  3. pbtxt = 'mnist_frozen_graph_fake.pb'
  4. frozen(True,ckpt,pbtxt)
复制代码

第三步,使用toco工具将pb模型转化为全量化的tflite模型(输入,输出和权值都变成int8类型)

  1. #!/bin/sh

  2. toco \
  3. --graph_def_file=mnist_frozen_graph_fake.pb \
  4. --output_file=mnist_fakequantize.tflite \
  5. --output_format=TFLITE \
  6. --inference_type=QUANTIZED_UINT8 \
  7. --input_shapes=1,28,28 \
  8. --input_arrays=x \
  9. --output_arrays=y_conv \
  10. --mean_values=0 \
  11. --std_dev_values=256 \
  12. --change_concat_input_ranges=false --allow_custom_ops
复制代码


第四步,将tflite模型转为rknn,修改rknn_transfer.py:

  1. def tflite_transfer():
  2. rknn = RKNN()
  3. print('--> Loading model')
  4. ret = rknn.load_tflite(model = './mnist_fakequantize.tflite')
  5. print('done')
  6. print('--> Building model')
  7. rknn.build(do_quantization=False)
  8. print('done')
  9. # 导出保存rknn模型文件
  10. rknn.export_rknn('./mnist_quantization_fake.rknn')
  11. # Release RKNN Context
  12. rknn.release()
复制代码
  1. if __name__ == '__main__':
  2. tflite_transfer()
复制代码
第五步,分别运行tflite_predict.py,和rknn_predict.py,对比tflite和rknn推理的结果和准确度:
tflite_predict.py
  1. mport tensorflow as tf
  2. from tensorflow.examples.tutorials.mnist import input_data
  3. from PIL import Image

  4. mnist = input_data.read_data_sets("../../MNIST_data/", one_hot=True)
  5. length = 100
  6. # 加载模型并分配张量
  7. interpreter = tf.lite.Interpreter(model_path="mnist_fakequantize.tflite")
  8. interpreter.allocate_tensors()
  9. # 获取输入输出张量
  10. input_details = interpreter.get_input_details()
  11. print(input_details)
  12. output_details = interpreter.get_output_details()
  13. print(output_details)



  14. acc_count = 0
  15. for i in range(length):
  16. #im = mnist.test.images[i]
  17. im = Image.open("../../MNIST_data/mnist_test/test_%d.jpg"%i) # 加载图片
  18. im = im.resize((28,28),Image.ANTIALIAS)
  19. im = np.asarray(im)
  20. # print(im.dtype)
  21. input_shape = input_details[0]['shape']
  22. input_data = im.reshape(1,28,28)
  23. # print(input_data.dtype)
  24. interpreter.set_tensor(input_details[0]['index'], input_data)

  25. interpreter.invoke()
  26. output_data = interpreter.get_tensor(output_details[0]['index'])
  27. # print(output_data)
  28. # print(output_data.argmax(axis=1)[0])
  29. if i <3:
  30. print(output_data)

  31. if list(mnist.test.labels[i]).index(1) == output_data.argmax(axis=1)[0]:
  32. acc_count += 1
  33. result = float(acc_count)/length
  34. print('result:',result)
复制代码
rknn_predict.py
  1. import numpy as np
  2. from PIL import Image
  3. from rknn.api import RKNN
  4. import cv2
  5. from tensorflow.examples.tutorials.mnist import input_data
  6. import tensorflow as tf

  7. mnist = input_data.read_data_sets("../MNIST_data/", one_hot=True)
  8. print(mnist.test.images[0].shape)
  9. # 解析模型的输出,获得概率最大的手势和对应的概率
  10. def get_predict(probability):
  11. data = probability[0][0]
  12. data = data.tolist()
  13. max_prob = max(data)
  14. return data.index(max_prob), max_prob
  15. # return data.index(max_prob), max_prob;
  16. def load_model(model_name):
  17. # 创建RKNN对象
  18. rknn = RKNN()
  19. # 载入RKNN模型
  20. print('-->loading model')
  21. rknn.load_rknn(model_name)
  22. print('loading model done')
  23. # 初始化RKNN运行环境
  24. print('--> Init runtime environment')
  25. ret = rknn.init_runtime()
  26. if ret != 0:
  27. print('Init runtime environment failed')
  28. exit(ret)
  29. print('done')
  30. return rknn
  31. def predict(rknn,length):
  32. acc_count = 0
  33. for i in range(length):
  34. im = mnist.test.images[i]
  35. # im = Image.open("../MNIST_data/mnist_test/test_%d.jpg"%i) # 加载图片
  36. # im = im.resize((28,28),Image.ANTIALIAS)
  37. # im = np.asarray(im)
  38. im = im.reshape(1,28,28)
  39. outputs = rknn.inference(inputs=[im])
  40. pred, prob = get_predict(outputs)
  41. if i ==0:
  42. print(outputs)
  43. print(prob)
  44. print(pred)
  45. if i ==100 or i ==500 or i ==1000 or i ==10000:
  46. result = float(acc_count)/i
  47. print('result%d:'%i,result)
  48. if list(mnist.test.labels[i]).index(1) == pred:
  49. acc_count += 1
  50. result = float(acc_count)/length
  51. print('result:',result)
  52. if __name__=="__main__":
  53. #此处要改成相应的量化或者非量化rknn模型
  54. model_name = './mnist_quantization_fake.rknn'
  55. length = 10000
  56. rknn = load_model(model_name)
  57. predict(rknn,length)

  58. rknn.release()
复制代码
最终运行结果如下表所示:

本帖子中包含更多资源

您需要 登录 才可以下载或查看,没有帐号?立即注册

x
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

产品中心 购买渠道 开源社区 Wiki教程 资料下载 关于Toybrick


快速回复 返回顶部 返回列表