Toybrick

调用librknn_api接口移植pytorch模型

Leonod

新手上路

积分
19
楼主
发表于 2020-12-3 15:37:49    查看: 8026|回复: 0 | [复制链接]    打印 | 只看该作者
楼主调用以下脚本:______________________________

import numpy as np
import cv2
from rknn.api import RKNN
import torchvision.models as models
import torch


def export_pytorch_model():
    net = models.resnet18(pretrained=True)
    net.eval()
    trace_model = torch.jit.trace(net, torch.Tensor(1,3,224,224))
    trace_model.save('./resnet18.pt')


def show_outputs(output):
    output_sorted = sorted(output, reverse=True)
    top5_str = '\n-----TOP 5-----\n'
    for i in range(5):
        value = output_sorted[i]
        index = np.where(output == value)
        for j in range(len(index)):
            if (i + j) >= 5:
                break
            if value > 0:
                topi = '{}: {}\n'.format(index[j], value)
            else:
                topi = '-1: 0.0\n'
            top5_str += topi
    print(top5_str)


def show_perfs(perfs):
    perfs = 'perfs: {}\n'.format(perfs)
    print(perfs)


def softmax(x):
    return np.exp(x)/sum(np.exp(x))


if __name__ == '__main__':

    export_pytorch_model()
    model = './resnet18.pt'
    input_size_list = [[3,224,224]]
    rknn = RKNN()
    print('--> config model')
    rknn.config(channel_mean_value='123.675 116.28 103.53 58.395', reorder_channel='0 1 2')
    print('done')
    print('--> Loading model')
    ret = rknn.load_pytorch(model=model, input_size_list=input_size_list)
    if ret != 0:
        print('Load pytorch model failed!')
        exit(ret)
    print('done')
    print('--> Building model')
    ret = rknn.build(do_quantization=True, dataset='./dataset.txt')
    if ret != 0:
        print('Build pytorch failed!')
        exit(ret)
    print('done')
    print('--> Export RKNN model')
    ret = rknn.export_rknn('./resnet_18.rknn')
    if ret != 0:
        print('Export resnet_18.rknn failed!')
        exit(ret)
    print('done')

    ret = rknn.load_rknn('./resnet_18.rknn')

    # Set inputs
    img = cv2.imread('./space_shuttle_224.jpg')
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # init runtime environment
    print('--> Init runtime environment')
    ret = rknn.init_runtime()
    if ret != 0:
        print('Init runtime environment failed')
        exit(ret)
    print('done')
    print('--> Running model')
    outputs = rknn.inference(inputs=[img])
    show_outputs(softmax(np.array(outputs[0][0])))
    print('done')
    rknn.release()


————————————————————————————————————————————
可以将pytorch的模型转换为rknn的模型,并且能够输出推断结果。

现在楼主想在调用rknn提供的c语言接口(库为librknn_api.so),实现加载模型(pytorch模型转rknn模型的模型),推断输出目标的类型,坐标位置,置信度信息。请问有相关的示例代码吗?
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

产品中心 购买渠道 开源社区 Wiki教程 资料下载 关于Toybrick


快速回复 返回顶部 返回列表