程序员最近都爱上了这个网站  程序员们快来瞅瞅吧!  it98k网:it98k.com

本站消息

站长简介/公众号

  出租广告位,需要合作请联系站长

+关注
已关注

分类  

暂无分类

标签  

暂无标签

日期归档  

【目标检测】训练技巧:加速,改成多线程转VOC矩形框和生成tfrecord

发布于2019-11-07 10:50     阅读(1403)     评论(0)     点赞(4)     收藏(4)


改成多线程生成VOC2007矩形框(labelme JSON 转 VOC XML):

# -*- coding: utf-8 -*-
"""
Created on Thu Sep 19 14:51:00 2019

@author: Andrea
"""

import os
import numpy as np
import codecs
import json
from glob import glob
import cv2
import shutil
from sklearn.model_selection import train_test_split
import threading

# 1. Input/output locations.
labelme_path = "I:\\biaozhutuxiang\\fangdichan1106-banannanan"  # directory holding the original labelme annotations (.json) and images (.jpg)
saved_path = "I:\\biaozhutuxiang\\VOC2007-fangdichan1106-banannanan\\"  # root of the VOC2007-style output; keep the trailing separator, sub-paths are appended by plain concatenation

# 2. Create the VOC directory layout. exist_ok=True replaces the separate
# exists()/makedirs() check (no race, and re-running the script is harmless).
for _voc_subdir in ("Annotations", "JPEGImages/", "ImageSets/Main/"):
    os.makedirs(saved_path + _voc_subdir, exist_ok=True)
    
    
class LoadThread(threading.Thread):
    """Thread that converts one labelme JSON annotation into a VOC XML file.

    Given a directory entry name from ``labelme_path``, the thread reads
    ``<labelme_path>/<stem>.json`` and ``<labelme_path>/<stem>.jpg``. The image
    is read only to obtain its size. The thread then writes
    ``<saved_path>Annotations/<stem>.xml``. Each labelme shape becomes the
    axis-aligned box spanning the min/max of its points; boxes with zero
    width or zero height are skipped.

    Entries that do not end in ``.json`` are ignored, so the caller can pass
    the raw output of ``os.listdir(labelme_path)``. ``get_result()`` returns
    the entry name the thread was created with.
    """

    def __init__(self, json_file_):
        super(LoadThread, self).__init__()
        # Directory entry name (not a full path) inside labelme_path.
        self.json_file_ = json_file_

    def run(self):
        # Labels and file names are user data: escape them so characters such
        # as '&' or '<' cannot produce malformed XML.
        from xml.sax.saxutils import escape

        # endswith() rather than a substring test, so e.g. "a.json.bak" is skipped.
        if not self.json_file_.endswith('.json'):
            return
        stem = self.json_file_[:-len('.json')]
        print(stem)
        json_filename = os.path.join(labelme_path, stem + ".json")
        image_filename = os.path.join(labelme_path, stem + ".jpg")
        print(json_filename)

        # Context manager so the JSON file handle is closed deterministically.
        with open(json_filename, "r", encoding="utf-8") as f:
            json_file = json.load(f)

        print(image_filename)
        image = cv2.imread(image_filename)
        if image is None:
            # cv2.imread reports a missing/unreadable file by returning None
            # instead of raising; skip the sample rather than crash on .shape.
            print('{} could not be read, skipping {}'.format(image_filename, json_filename))
            return
        height, width, channels = image.shape

        lines = [
            '<annotation>\n',
            '\t<folder>' + 'UAV_data' + '</folder>\n',
            '\t<filename>' + escape(stem + ".jpg") + '</filename>\n',
            '\t<source>\n',
            '\t\t<database>The UAV autolanding</database>\n',
            '\t\t<annotation>UAV AutoLanding</annotation>\n',
            '\t\t<image>flickr</image>\n',
            '\t\t<flickrid>NULL</flickrid>\n',
            '\t</source>\n',
            '\t<owner>\n',
            '\t\t<flickrid>NULL</flickrid>\n',
            '\t\t<name>Yuanyiqin</name>\n',
            '\t</owner>\n',
            '\t<size>\n',
            '\t\t<width>' + str(width) + '</width>\n',
            '\t\t<height>' + str(height) + '</height>\n',
            '\t\t<depth>' + str(channels) + '</depth>\n',
            '\t</size>\n',
            # <segmented> is a direct child of <annotation>: one level of indent.
            '\t<segmented>0</segmented>\n',
        ]
        for multi in json_file["shapes"]:
            points = np.array(multi["points"])
            xmin = points[:, 0].min()
            xmax = points[:, 0].max()
            ymin = points[:, 1].min()
            ymax = points[:, 1].max()
            label = multi["label"]
            if xmax <= xmin or ymax <= ymin:
                continue  # degenerate (zero-area) box
            lines.extend([
                '\t<object>\n',
                '\t\t<name>' + escape(str(label)) + '</name>\n',
                '\t\t<pose>Unspecified</pose>\n',
                '\t\t<truncated>1</truncated>\n',
                '\t\t<difficult>0</difficult>\n',
                '\t\t<bndbox>\n',
                '\t\t\t<xmin>' + str(xmin) + '</xmin>\n',
                '\t\t\t<ymin>' + str(ymin) + '</ymin>\n',
                '\t\t\t<xmax>' + str(xmax) + '</xmax>\n',
                '\t\t\t<ymax>' + str(ymax) + '</ymax>\n',
                '\t\t</bndbox>\n',
                '\t</object>\n',
            ])
            print(json_filename, xmin, ymin, xmax, ymax, label)
        lines.append('</annotation>')

        with open(saved_path + "Annotations/" + stem + ".xml", "w", encoding="utf-8") as out:
            out.writelines(lines)

    def get_result(self):
        """Return the directory entry name this thread was created with."""
        return self.json_file_
    
# 3. Number of worker threads started per batch when converting annotations.
threadnum = 64

# 4. Read the labelme annotations and write one VOC XML per JSON file.
if __name__ == '__main__':
    img_list = os.listdir(labelme_path)
    img_length = len(img_list)
    # Process the directory in batches of `threadnum` threads, joining each
    # batch before starting the next. A stepped range() never yields an empty
    # trailing batch, and an empty directory simply skips the loop.
    for start in range(0, img_length, threadnum):
        print('batch start/total:', start, img_length)
        batch = []
        for json_file_ in img_list[start:start + threadnum]:
            print('json_file_:', json_file_)
            thread = LoadThread(json_file_)
            batch.append(thread)
            thread.start()
        for thread in batch:
            thread.join()  # wait for the worker, otherwise the main thread races ahead
            print('Down json_file_:', thread.get_result())

    # 5. Copy the images into VOC2007/JPEGImages/.
    # os.path.join is required: labelme_path has no trailing separator, so plain
    # concatenation would glob "<dir>*.jpg" next to the directory, not inside it.
    image_files = glob(os.path.join(labelme_path, "*.jpg"))
    print("copy image files to VOC007/JPEGImages/")
    for image in image_files:
        shutil.copy(image, saved_path + "JPEGImages/")

    # 6. Write the ImageSets/Main list files from the XML files actually produced.
    txtsavepath = saved_path + "ImageSets/Main/"
    xml_files = glob(os.path.join(saved_path + "Annotations", "*.xml"))
    # basename/splitext handle both '/' and '\\' separators. Sorting makes the
    # seeded split below reproducible regardless of directory listing order.
    total_files = sorted(os.path.splitext(os.path.basename(p))[0] for p in xml_files)
    if len(total_files) >= 2:
        train_files, val_files = train_test_split(total_files, test_size=0.15, random_state=42)
    else:
        # train_test_split raises ValueError for 0 samples, and for 1 sample the
        # train split would be empty; keep everything in train instead.
        train_files, val_files = list(total_files), []

    split_lists = (
        ('trainval.txt', total_files),
        ('train.txt', train_files),
        ('val.txt', val_files),
        ('test.txt', []),  # created empty; no test split is generated here
    )
    for list_name, names in split_lists:
        with open(os.path.join(txtsavepath, list_name), 'w') as f:
            f.writelines(name + "\n" for name in names)

改成多线程生成tfrecord:

# -*- coding: utf-8 -*-
from __future__ import division, print_function, absolute_import
import sys
sys.path.append('../../')
import xml.etree.cElementTree as ET
import numpy as np
import tensorflow as tf
import math
import glob
import cv2
from libs.label_name_dict.label_dict import *
from help_utils.tools import *
import threading
import random

# Command-line flags (TF1 tf.app.flags). img_format includes the leading dot
# because it is appended directly to the annotation file stem.
tf.app.flags.DEFINE_string('VOC_dir', '/home/yuanyq/Detect_DL/FPN_Tensorflow/data/io/VOC2007/', 'Voc dir')
tf.app.flags.DEFINE_string('xml_dir', 'Annotations', 'xml dir')
tf.app.flags.DEFINE_string('image_dir', 'JPEGImages', 'image dir')
tf.app.flags.DEFINE_string('save_name', 'train', 'save name')
tf.app.flags.DEFINE_string('save_dir', '../tfrecord/', 'save dir')
tf.app.flags.DEFINE_string('img_format', '.jpg', 'format of image')
tf.app.flags.DEFINE_string('dataset', 'pascal', 'dataset')
FLAGS = tf.app.flags.FLAGS

# Number of worker threads started per batch by convert_pascal_to_tfrecord().
threadnum = 128
# Progress counter. A module-level `global` statement is a no-op, so it was
# dropped; no active code in this script reads the counter.
count = 0
class LoadThread(threading.Thread):
    """Thread that serializes one VOC image/annotation pair into a TFRecord.

    Args:
        xml: path of the VOC XML annotation file.
        image_path: directory holding the images; the image file name is the
            XML file stem plus ``FLAGS.img_format``.
        xml_path: annotation directory. Kept for interface compatibility; it
            is not needed by the conversion itself.
        writer: TFRecord writer shared by all worker threads.
    """

    # Many threads share a single writer, and the writer is not documented as
    # thread-safe, so the write() calls are serialized to keep records intact.
    _write_lock = threading.Lock()

    def __init__(self, xml, image_path, xml_path, writer):
        super(LoadThread, self).__init__()
        self.xml = xml
        self.image_path = image_path
        self.xml_path = xml_path
        self.writer = writer

    def run(self):
        # Normalize separators to avoid path errors across platforms.
        xml = self.xml.replace('\\', '/')
        # splitext keeps dotted stems intact ("a.b.xml" -> "a.b"), unlike split('.')[0].
        img_name = os.path.splitext(os.path.basename(xml))[0] + FLAGS.img_format
        img_path = self.image_path + '/' + img_name
        print('xml:', xml)
        if not os.path.exists(img_path):
            print('{} is not exist!'.format(img_path))
            return self.xml  # skip: there is no image to serialize

        img_height, img_width, gtbox_label = read_xml_gtbox_and_label(xml)

        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread returns None (does not raise) for undecodable files.
            print('{} could not be decoded!'.format(img_path))
            return self.xml
        img = img[:, :, ::-1]  # cv2 loads BGR; reverse the channel axis to RGB

        # NOTE(review): height/width come from the XML <size>, while 'img'
        # holds the decoded pixels. They must agree for any reader that
        # reshapes 'img'; consider img.shape[:2] — TODO confirm.
        feature = tf.train.Features(feature={
            'img_name': _bytes_feature(img_name.encode()),
            'img_height': _int64_feature(img_height),
            'img_width': _int64_feature(img_width),
            # tobytes() replaces the deprecated ndarray.tostring(); same bytes.
            'img': _bytes_feature(img.tobytes()),
            'gtboxes_and_label': _bytes_feature(gtbox_label.tobytes()),
            'num_objects': _int64_feature(gtbox_label.shape[0])
        })
        serialized = tf.train.Example(features=feature).SerializeToString()

        with self._write_lock:
            self.writer.write(serialized)
        return self.xml

    def get_result(self):
        """Print and return the annotation path this thread was created with."""
        print(self.xml)
        return self.xml

def _int64_feature(value):
    """Return a ``tf.train.Feature`` whose int64 list holds the single ``value``."""
    wrapped = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=wrapped)


def _bytes_feature(value):
    """Return a ``tf.train.Feature`` whose bytes list holds the single ``value``."""
    wrapped = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=wrapped)
def read_xml_gtbox_and_label(xml_path):
    """Parse a VOC XML annotation into the image size and ground-truth boxes.

    :param xml_path: path of a VOC xml file (any source accepted by ``ET.parse``)
    :return: ``(img_height, img_width, gtbox_label)``. ``gtbox_label`` is an
        ``int32`` array of shape ``[num_of_gtboxes, 5]``; each row is
        ``[xmin, ymin, xmax, ymax, label]`` with coordinates rounded up.
        ``img_height``/``img_width`` are None when ``<size>`` lacks them.
    :raises ValueError: if an object with a ``<bndbox>`` has no ``<name>``, or
        a ``<bndbox>`` lacks one of its four coordinates.
    """
    # Misspelled class names found in the annotations -> intended class name.
    aliases = {'0002X': '0002', 'X0002': '0002', 'c0002': '0002',
               '000Z1': '0001', 'A0001': '0001'}
    # Only these names are looked up in NAME_LABEL_MAP; any other name gets label 1.
    known_names = ('0001', '0002', '0003')

    root = ET.parse(xml_path).getroot()
    img_width = None
    img_height = None
    box_list = []
    for child_of_root in root:
        if child_of_root.tag == 'size':
            for child_item in child_of_root:
                if child_item.tag == 'width':
                    img_width = int(child_item.text)
                if child_item.tag == 'height':
                    img_height = int(child_item.text)

        if child_of_root.tag == 'object':
            bndboxes = child_of_root.findall('bndbox')
            if not bndboxes:
                continue
            name_node = child_of_root.find('name')
            if name_node is None:
                raise ValueError('label is none, error')
            name = (name_node.text or '').strip()
            name = aliases.get(name, name)
            label = NAME_LABEL_MAP[name] if name in known_names else 1

            for bndbox in bndboxes:
                # Look coordinates up by tag so the child order does not matter.
                tmp_box = []
                for tag in ('xmin', 'ymin', 'xmax', 'ymax'):
                    node = bndbox.find(tag)
                    if node is None:
                        raise ValueError('bndbox is missing <{}>'.format(tag))
                    tmp_box.append(math.ceil(float(node.text)))
                tmp_box.append(label)
                box_list.append(tmp_box)

    # reshape keeps the documented [N, 5] shape even when there are no boxes.
    gtbox_label = np.array(box_list, dtype=np.int32).reshape(-1, 5)

    return img_height, img_width, gtbox_label
def convert_pascal_to_tfrecord():
    """Convert a VOC-style directory into a single TFRecord file.

    Every ``*.xml`` in ``FLAGS.VOC_dir/FLAGS.xml_dir`` is paired with its image
    in ``FLAGS.VOC_dir/FLAGS.image_dir``. The examples are written to
    ``FLAGS.save_dir/<dataset>_<save_name>.tfrecord`` using batches of up to
    ``threadnum`` worker threads. The file list is shuffled before conversion.
    """
    xml_path = os.path.join(FLAGS.VOC_dir, FLAGS.xml_dir)
    image_path = os.path.join(FLAGS.VOC_dir, FLAGS.image_dir)
    save_path = os.path.join(FLAGS.save_dir, FLAGS.dataset + '_' + FLAGS.save_name + '.tfrecord')
    mkdir(FLAGS.save_dir)

    # Only annotation files; other entries would make the XML parser fail.
    xml_list = [name for name in os.listdir(xml_path) if name.lower().endswith('.xml')]
    random.shuffle(xml_list)

    # writer_options = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.ZLIB)
    # writer = tf.python_io.TFRecordWriter(path=save_path, options=writer_options)
    writer = tf.python_io.TFRecordWriter(path=save_path)
    try:
        for start in range(0, len(xml_list), threadnum):
            batch = []
            for name in xml_list[start:start + threadnum]:
                thread = LoadThread(os.path.join(xml_path, name), image_path, xml_path, writer)
                thread.daemon = True
                batch.append(thread)
                thread.start()
            # Join the whole batch before the next one (and before closing the writer).
            for thread in batch:
                thread.join()
                print('img_name done:', thread.get_result())
    finally:
        # close() flushes buffered records; without it the tail of the file can be lost.
        writer.close()

    print('\nConversion is complete!')


if __name__ == '__main__':
    # Script entry point: convert the VOC directory configured via FLAGS.
    convert_pascal_to_tfrecord()

 



所属网站分类: 技术文章 > 博客

作者:头疼不是病

链接:https://www.pythonheidong.com/blog/article/148231/41c8822bbc90c0e00669/

来源:python黑洞网

任何形式的转载都请注明出处,如有侵权,一经发现,必将追究其法律责任。

4 0
收藏该文
已收藏

评论内容:(最多支持255个字符)