程序员最近都爱上了这个网站  程序员们快来瞅瞅吧!  it98k网:it98k.com

本站消息

站长简介/公众号

  出租广告位,需要合作请联系站长

+关注
已关注

分类  

暂无分类

标签  

暂无标签

日期归档  

关于opencv中 tf_text_graph_ssd.py生成pbtxt报错的解决方案

发布于2019-08-19 10:26     阅读(1155)     评论(0)     点赞(0)     收藏(4)


opencv中的dnn模块很强大,可以加载现在主流的深度学习模型,darknet、caffe、tensorflow等,我本人最近在使用tensorflow进行目标检测,用的是ssdlite-mobilenet-v2。tensorflow的object_detection api使用起来很方便,而且大佬研究出来的模型也都没毛病,效果很好,换换数据集,调调参数就ok。训练过程不再本篇博客的讨论范围内,假定现在已经训练完毕,并且生成了pb文件。

一般我们恢复模型都是要用到tensorflow,大致代码如下所示:

  1. import numpy as np
  2. import tensorflow as tf
  3. import cv2 as cv
  4. # Read the graph.
  5. with tf.gfile.FastGFile('frozen_inference_graph.pb', 'rb') as f:
  6. graph_def = tf.GraphDef()
  7. graph_def.ParseFromString(f.read())
  8. with tf.Session() as sess:
  9. # Restore session
  10. sess.graph.as_default()
  11. tf.import_graph_def(graph_def, name='')
  12. # Read and preprocess an image.
  13. img = cv.imread('example.jpg')
  14. rows = img.shape[0]
  15. cols = img.shape[1]
  16. inp = cv.resize(img, (300, 300))
  17. inp = inp[:, :, [2, 1, 0]] # BGR2RGB
  18. # Run the model
  19. out = sess.run([sess.graph.get_tensor_by_name('num_detections:0'),
  20. sess.graph.get_tensor_by_name('detection_scores:0'),
  21. sess.graph.get_tensor_by_name('detection_boxes:0'),
  22. sess.graph.get_tensor_by_name('detection_classes:0')],
  23. feed_dict={'image_tensor:0': inp.reshape(1, inp.shape[0], inp.shape[1], 3)})
  24. # Visualize detected bounding boxes.
  25. num_detections = int(out[0][0])
  26. for i in range(num_detections):
  27. classId = int(out[3][0][i])
  28. score = float(out[1][0][i])
  29. bbox = [float(v) for v in out[2][0][i]]
  30. if score > 0.3:
  31. x = bbox[1] * cols
  32. y = bbox[0] * rows
  33. right = bbox[3] * cols
  34. bottom = bbox[2] * rows
  35. cv.rectangle(img, (int(x), int(y)), (int(right), int(bottom)), (125, 255, 51), thickness=2)
  36. cv.imshow('TensorFlow MobileNet-SSD', img)
  37. cv.waitKey()

如果电脑没有tensorflow想要恢复模型怎么办?那就可以考虑下opencv了,opencv加载tensorflow模型需要pb文件和pbtxt文件,pbtxt是可以根据pb文件生成的,在opencv的源代码中进入sample/dnn 文件夹中,由于我们的模型是ssd因此找到以下两个文件

tf_text_graph_ssd.py

tf_text_graph_common.py

 

在命令行执行

python tf_text_graph_ssd.py --input ssdlite.pb --config pipeline.config --output ssdlite.pbtxt

ssdlite.pb是自己训练的pb模型,pipeline是训练时的配置文件,生成pb时会有,ssdlite.pbtxt就是生成的文件用于opencv恢复模型

但是我本人在这里运行时报错了,如下图所示:

在github中有人提到训练完的模型需要转换一下

  1. import tensorflow as tf
  2. from tensorflow.tools.graph_transforms import TransformGraph
  3. with tf.gfile.FastGFile('ssdlite.pb', 'rb') as f:
  4. graph_def = tf.GraphDef()
  5. graph_def.ParseFromString(f.read())
  6. graph_def = TransformGraph(graph_def, ['image_tensor'], ['detection_boxes', 'detection_classes', 'detection_scores', 'num_detections'], ['sort_by_execution_order'])
  7. with tf.gfile.FastGFile('ssdlite_new.pb', 'wb') as f:
  8. f.write(graph_def.SerializeToString())#保存新的模型

转化完pb模型之后再运行上面的代码,成功生成pbtxt文件

然后用opencv恢复模型

net = cv2.dnn.readNetFromTensorflow(weightsPath,configPath)

大致代码如下:

  1. import cv2 as cv
  2. cvNet = cv.dnn.readNetFromTensorflow('frozen_inference_graph.pb', 'graph.pbtxt')
  3. img = cv.imread('test.jpg')
  4. rows = img.shape[0]
  5. cols = img.shape[1]
  6. cvNet.setInput(cv.dnn.blobFromImage(img, size=(300, 300), swapRB=True, crop=False))
  7. cvOut = cvNet.forward()
  8. for detection in cvOut[0,0,:,:]:
  9. score = float(detection[2])
  10. if score > 0.3:
  11. left = detection[3] * cols
  12. top = detection[4] * rows
  13. right = detection[5] * cols
  14. bottom = detection[6] * rows
  15. cv.rectangle(img, (int(left), int(top)), (int(right), int(bottom)), (23, 230, 210), thickness=2)
  16. cv.imshow('img', img)
  17. cv.waitKey()

在此需要感谢dkurt大神的转换代码,顺便提一下ssdlite的速度是真的快,pc端cpu下大约60ms一帧,如果需要移植到arm平台上,可借助腾讯的ncnn,在移动端加速效果也是很棒的。

参考的一些资料:

https://github.com/opencv/opencv/wiki/TensorFlow-Object-Detection-API

https://github.com/opencv/opencv/issues/11560



所属网站分类: 技术文章 > 博客

作者:爸爸去挣钱我去幼儿园

链接:https://www.pythonheidong.com/blog/article/48827/6e671ad30b2f7d55fc2f/

来源:python黑洞网

任何形式的转载都请注明出处,如有侵权 一经发现 必将追究其法律责任

0 0
收藏该文
已收藏

评论内容:(最多支持255个字符)