程序员最近都爱上了这个网站  程序员们快来瞅瞅吧!  it98k网:it98k.com

本站消息

站长简介/公众号

  出租广告位,需要合作请联系站长

+关注
已关注

分类  

暂无分类

标签  

暂无标签

日期归档  

2023-05(1)

2023-06(3)

tvm学习笔记 (三):载入onnx格式模型

发布于2019-08-17 20:10     阅读(4502)     评论(0)     点赞(4)     收藏(4)


1、模型转换

  1. import onnx
  2. import numpy as np
  3. import tvm
  4. import tvm.relay as relay
  5. onnx_model = onnx.load('test.onnx')
  6. target = tvm.target.create('llvm')
  7. input_name = '0' # change '1' to '0'
  8. shape_dict = {input_name: (1, 3, 224, 224)}
  9. sym, params = relay.frontend.from_onnx(onnx_model, shape_dict)
  10. with relay.build_config(opt_level=2):
  11. graph, lib, params = relay.build_module.build(sym, target, params=params)
  12. dtype = 'float32'
  13. from tvm.contrib import graph_runtime
  14. print("Output model files")
  15. libpath = "./test.so"
  16. lib.export_library(libpath)
  17. graph_json_path = "./test.json"
  18. with open(graph_json_path, 'w') as fo:
  19. fo.write(graph)
  20. param_path = "./test.params"
  21. with open(param_path, 'wb') as fo:
  22. fo.write(relay.save_param_dict(params))

2、模型部署

  1. import numpy as np
  2. import tvm
  3. import tvm.relay as relay
  4. from tvm.contrib import graph_runtime
  5. import cv2 as cv
  6. test_json = 'test.json'
  7. test_lib = 'test.so'
  8. test_param = 'test.params'
  9. loaded_json = open(test_json).read()
  10. loaded_lib = tvm.module.load(test_lib)
  11. loaded_params = bytearray(open(test_param, "rb").read())
  12. def preprocess(img_src):
  13. img_src= cv.cvtColor(img_src, cv.COLOR_BGR2RGB)
  14. img_src= cv.resize(img_src, (224, 224))
  15. input_data = np.array(img_src).astype(np.float32)
  16. input_data = input_data / 255.0
  17. input_data = np.transpose(input_data, (2, 0, 1))
  18. input_data[0] = (input_data[0] - 0.485)/ 0.229
  19. input_data[1] = (input_data[1] - 0.456)/ 0.224
  20. input_data[2] = (input_data[2] - 0.406)/ 0.225
  21. input_data = input_data[np.newaxis, :].copy()
  22. return input_data
  23. img = cv.imread("29.jpg")
  24. img_input = preprocess(img)
  25. ctx = tvm.cpu(0)
  26. module = graph_runtime.create(loaded_json, loaded_lib, ctx)
  27. module.load_params(loaded_params)
  28. # run the module
  29. module.set_input("0", img_input)
  30. module.run()
  31. out_deploy = module.get_output(0).asnumpy()
  32. print(classes[np.argmax(out_deploy)])

3、遇到问题

转Mobilenet-SSD的onnx模型时遇到问题:

  1. /* an internal invariant was violated while typechecking your program [23:48:07]
  2. /home/mirror/workspace/tvm/src/relay/op/tensor/transform.cc:204: Check failed:
  3. e_dtype == dtype (int64 vs. int32) : relay.concatenate requires all tensors have
  4. the same dtype; */

在讨论区找到一个讨论帖子:

https://discuss.tvm.ai/t/relay-onnx-load-resnet-onnx-to-relay-failed/2411

尝试使用onnx-simplifier工具:

git 地址:

https://github.com/daquexian/onnx-simplifier.git

安装使用:

  1. >> pip3 install onnx-simplifier
  2. >> python3 -m onnxsim input_model output_model

然后再进行模型载入编译就搞定了,感谢大佬们提供的工具~

参考资料:

https://docs.tvm.ai/tutorials/frontend/from_onnx.html#sphx-glr-tutorials-frontend-from-onnx-py



所属网站分类: 技术文章 > 博客

作者:集天地之正气

链接:https://www.pythonheidong.com/blog/article/48407/31b71faefe0b0dfb022d/

来源:python黑洞网

任何形式的转载都请注明出处,如有侵权 一经发现 必将追究其法律责任

4 0
收藏该文
已收藏

评论内容:(最多支持255个字符)