温馨提示×

温馨提示×

您好,登录后才能下订单哦!

密码登录×
登录注册×
其他方式登录
点击 登录注册 即表示同意《亿速云用户服务条款》

pytorch模型怎么转onnx模型

发布时间:2022-08-30 14:11:48 来源:亿速云 阅读:186 作者:iii 栏目:开发技术

这篇文章主要介绍“pytorch模型怎么转onnx模型”,在日常操作中,相信很多人在pytorch模型怎么转onnx模型问题上存在疑惑,小编查阅了各式资料,整理出简单好用的操作方法,希望对大家解答”pytorch模型怎么转onnx模型”的疑惑有所帮助!接下来,请跟着小编一起来学习吧!

    学习内容

    前提条件:需要安装onnx 和 onnxruntime,可以通过 pip install onnx 和 pip install onnxruntime 进行安装

    1 . pytorch 转 onnx

    pytorch 转 onnx 只需要一个函数 torch.onnx.export

    torch.onnx.export(model, args, path, export_params, verbose, input_names, output_names, do_constant_folding, dynamic_axes, opset_version)

    参数说明:

    • model——需要导出的pytorch模型

    • args——模型的输入参数,满足输入层的shape正确即可。

    • path——输出的onnx模型的位置。例如‘yolov5.onnx’。

    • export_params——输出模型是否可训练。default=True,表示导出trained model,否则untrained。

    • verbose——是否打印模型转换信息。default=False。

    • input_names——输入节点名称。default=None。

    • output_names——输出节点名称。default=None。

    • do_constant_folding——是否使用常量折叠(不了解),默认即可。default=True。

    • dynamic_axes——模型的输入输出有时是可变的,如Rnn,或者输出图像的batch可变,可通过该参数设置。如输入层的shape为(b,3,h,w),batch,height,width是可变的,但是chancel是固定三通道。
      格式如下 :
      1)仅list(int) dynamic_axes={‘input’:[0,2,3],‘output’:[0,1]}
      2)仅dict<int, string> dynamic_axes={&lsquo;input&rsquo;:{0:&lsquo;batch&rsquo;,2:&lsquo;height&rsquo;,3:&lsquo;width&rsquo;},&lsquo;output&rsquo;:{0:&lsquo;batch&rsquo;,1:&lsquo;c&rsquo;}}
      3)mixed dynamic_axes={&lsquo;input&rsquo;:{0:&lsquo;batch&rsquo;,2:&lsquo;height&rsquo;,3:&lsquo;width&rsquo;},&lsquo;output&rsquo;:[0,1]}

    • opset_version&mdash;&mdash;opset的版本,低版本不支持upsample等操作。

    import torch
    import torch.nn
    import onnx
    
    model = torch.load('best.pt')
    model.eval()
    
    input_names = ['input']
    output_names = ['output']
    
    x = torch.randn(1,3,32,32,requires_grad=True)
    
    torch.onnx.export(model, x, 'best.onnx', input_names=input_names, output_names=output_names, verbose='True')

    2 . 运行onnx模型

    检查onnx模型,并使用onnxruntime运行。

    import onnx
    import onnxruntime as ort
    
    model = onnx.load('best.onnx')
    onnx.checker.check_model(model)
    
    session = ort.InferenceSession('best.onnx')
    x=np.random.randn(1,3,32,32).astype(np.float32)  # 注意输入type一定要np.float32!!!!!
    # x= torch.randn(batch_size,chancel,h,w)
    
    
    outputs = session.run(None,input = { 'input' : x })

    参数说明:

    • output_names: default=None
      用来指定输出哪些,以及顺序
      若为None,则按序输出所有的output,即返回[output_0,output_1]
      若为[&lsquo;output_1&rsquo;,&lsquo;output_0&rsquo;],则返回[output_1,output_0]
      若为[&lsquo;output_0&rsquo;],则仅返回[output_0:tensor]

    • input:dict
      可以通过session.get_inputs().name获得名称
      其中key值要求与torch.onnx.export中设定的一致

    3.onnx模型输出与pytorch模型比对

    import numpy as np
    np.testing.assert_allclose(torch_result[0].detach().numpu(),onnx_result,rtol=0.0001)

    如前所述,经验表明,ONNX 模型的运行效率明显优于原 PyTorch 模型,这似乎是源于 ONNX 模型生成过程中的优化,这也导致了模型的生成过程比较耗时,但整体效率依旧可观。

    此外,根据对 ONNX 模型和 PyTorch 模型运行结果的统计分析(误差的均值和标准差),可以看出 ONNX 模型的运行结果误差很小、基本可靠。

    到此,关于“pytorch模型怎么转onnx模型”的学习就结束了,希望能够解决大家的疑惑。理论与实践的搭配能更好的帮助大家学习,快去试试吧!若想继续学习更多相关知识,请继续关注亿速云网站,小编会继续努力为大家带来更多实用的文章!

    向AI问一下细节

    免责声明:本站发布的内容(图片、视频和文字)以原创、转载和分享为主,文章观点不代表本网站立场,如果涉及侵权请联系站长邮箱:is@yisu.com进行举报,并提供相关证据,一经查实,将立刻删除涉嫌侵权内容。

    AI