Training YOLOv3 (PyTorch version) on Your Own Dataset

Published 2019-08-15 11:59


Copyright notice: this is an original article by the blogger, released under the CC 4.0 BY-SA license. Please include the original source link and this notice when reposting.
Original link: https://blog.csdn.net/qq_34795071/article/details/90769094

Data is the soul of deep learning; I will assume you already have your data ready.

1. Environment Setup

git clone https://github.com/ultralytics/yolov3.git
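
The pip install in the next step assumes you are inside the cloned repository (the clone creates a yolov3 directory), so change into it first:

cd yolov3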

2. Install Dependencies

pip install -r requirements.txt

3. Prepare the Data

 

Under the data directory, create three folders: Annotations, images, and ImageSets (the labels folder will be generated later by a script). Annotations holds the XML annotation files, images holds the image files, and inside ImageSets create a Main subfolder to hold the train and test lists (also generated by a script); labels will hold the label files. A small sketch for creating this layout follows.
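If you prefer to create the layout with a few lines of Python rather than by hand, here is a minimal sketch (assuming it is run from the repository root; the labels folder is created here too, although the label-generation script below also creates it):

import os

# create the dataset folders used in this guide
for folder in ('data/Annotations', 'data/images', 'data/ImageSets/Main', 'data/labels'):
    os.makedirs(folder, exist_ok=True)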

Split the data into a training set and a test set (to make full use of the dataset we only create these two splits); the generated files end up in ImageSets/Main. The script:

import os
import random

trainval_percent = 0.2  # fraction of images sent to the test set; adjust as needed
train_percent = 1
xmlfilepath = 'Annotations'
txtsavepath = 'ImageSets/Main'
total_xml = os.listdir(xmlfilepath)

num = len(total_xml)
indices = range(num)
tv = int(num * trainval_percent)
tr = int(tv * train_percent)
trainval = random.sample(indices, tv)
train = random.sample(trainval, tr)

# ftrainval = open('ImageSets/Main/trainval.txt', 'w')
ftest = open('ImageSets/Main/test.txt', 'w')
ftrain = open('ImageSets/Main/train.txt', 'w')
# fval = open('ImageSets/Main/val.txt', 'w')

for i in indices:
    name = total_xml[i][:-4] + '\n'  # strip the .xml extension
    if i in trainval:
        # ftrainval.write(name)
        if i in train:
            ftest.write(name)
        # else:
        #     fval.write(name)
    else:
        ftrain.write(name)

# ftrainval.close()
ftrain.close()
# fval.close()
ftest.close()
Create a voc_labels.py script to generate the label files under labels:

import xml.etree.ElementTree as ET
import pickle
import os
from os import listdir, getcwd
from os.path import join

sets = ['train', 'test']
classes = ['apple', 'orange']  # your own classes

def convert(size, box):
    # convert a VOC box (xmin, xmax, ymin, ymax) into a normalized
    # YOLO box (x_center, y_center, width, height)
    dw = 1. / size[0]
    dh = 1. / size[1]
    x = (box[0] + box[1]) / 2.0
    y = (box[2] + box[3]) / 2.0
    w = box[1] - box[0]
    h = box[3] - box[2]
    x = x * dw
    w = w * dw
    y = y * dh
    h = h * dh
    return (x, y, w, h)

def convert_annotation(image_id):
    in_file = open('data/Annotations/%s.xml' % (image_id))
    out_file = open('data/labels/%s.txt' % (image_id), 'w')
    tree = ET.parse(in_file)
    root = tree.getroot()
    size = root.find('size')
    w = int(size.find('width').text)
    h = int(size.find('height').text)
    for obj in root.iter('object'):
        difficult = obj.find('difficult').text
        cls = obj.find('name').text
        if cls not in classes or int(difficult) == 1:
            continue
        cls_id = classes.index(cls)
        xmlbox = obj.find('bndbox')
        b = (float(xmlbox.find('xmin').text), float(xmlbox.find('xmax').text),
             float(xmlbox.find('ymin').text), float(xmlbox.find('ymax').text))
        bb = convert((w, h), b)
        out_file.write(str(cls_id) + " " + " ".join([str(a) for a in bb]) + '\n')

wd = getcwd()
for image_set in sets:
    if not os.path.exists('data/labels/'):
        os.makedirs('data/labels/')
    image_ids = open('data/ImageSets/Main/%s.txt' % (image_set)).read().strip().split()
    list_file = open('data/%s.txt' % (image_set), 'w')
    for image_id in image_ids:
        list_file.write('data/images/%s.jpg\n' % (image_id))
        convert_annotation(image_id)
    list_file.close()
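
This script's paths start with data/, so run it from the repository root:

python voc_labels.py

It writes data/train.txt and data/test.txt containing the image paths, plus one YOLO-format label file per image under data/labels/.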

4. Configure the Training Files

In the data directory, create fruit.data to point at the training data, and fruit.names to hold the class names to predict:

classes=2
train=data/train.txt
valid=data/test.txt
names=data/fruit.names
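
fruit.names is a plain text file with one class name per line, in the same order as the classes list in voc_labels.py; for the two example classes here it would contain:

apple
orange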

For the meaning of the individual cfg parameters you can refer to my blog. Below is cfg/yolov3.cfg with classes and filters adjusted for two classes (only the relevant parts are shown):

[net]
# Testing
#batch=1
#subdivisions=1
# Training
batch=64
subdivisions=16
width=416
height=416
channels=3
momentum=0.9
decay=0.0005
angle=0
saturation = 1.5
exposure = 1.5
hue=.1
learning_rate=0.001
burn_in=1000
max_batches = 50000
policy=steps
steps=4000,45000
scales=.1,.1

[convolutional]
batch_normalize=1
filters=32
size=3
stride=1
pad=1
activation=leaky

# Downsample

[convolutional]
batch_normalize=1
filters=64
size=3
stride=2
pad=1
activation=leaky

[convolutional]
batch_normalize=1
filters=32
size=1
stride=1
pad=1
activation=leaky

# ... (intermediate layers omitted) ...

[convolutional]
size=1
stride=1
pad=1
filters=21  # 3*(num_classes+4+1)
activation=linear

[yolo]
mask = 6,7,8
anchors = 10,13, 16,30, 33,23, 30,61, 62,45, 59,119, 116,90, 156,198, 373,326
classes=2  # number of classes
num=9
jitter=.3
ignore_thresh = .7
truth_thresh = 1
random=1

[route]
layers = -4

[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky

[upsample]
stride=2

[route]
layers = -1, 61

[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky

[convolutional]
batch_normalize=1
size=3
stride=1
pad=1
filters=512
activation=leaky

[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky

[convolutional]
batch_normalize=1
size=3
stride=1
pad=1
filters=512
activation=leaky

[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky

[convolutional]
batch_normalize=1
size=3
stride=1
pad=1
filters=512
activation=leaky

[convolutional]
size=1
stride=1
pad=1
filters=21  # 3*(num_classes+4+1)
activation=linear

[yolo]
mask = 3,4,5
anchors = 10,13, 16,30, 33,23, 30,61, 62,45, 59,119, 116,90, 156,198, 373,326
classes=2  # number of classes
num=9
jitter=.3
ignore_thresh = .7
truth_thresh = 1
random=1

[route]
layers = -4

[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky

[upsample]
stride=2

[route]
layers = -1, 36

[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky

[convolutional]
batch_normalize=1
size=3
stride=1
pad=1
filters=256
activation=leaky

[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky

[convolutional]
batch_normalize=1
size=3
stride=1
pad=1
filters=256
activation=leaky

[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky

[convolutional]
batch_normalize=1
size=3
stride=1
pad=1
filters=256
activation=leaky

[convolutional]
size=1
stride=1
pad=1
filters=21  # 3*(num_classes+4+1)
activation=linear

[yolo]
mask = 0,1,2
anchors = 10,13, 16,30, 33,23, 30,61, 62,45, 59,119, 116,90, 156,198, 373,326
classes=2  # your own number of classes
num=9
jitter=.3
ignore_thresh = .7
truth_thresh = 1
random=1
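
As a sanity check on the filters values marked above: the convolutional layer immediately before each [yolo] layer predicts, for each of its 3 anchors, 4 box coordinates, 1 objectness score, and one score per class, so filters = 3 * (num_classes + 4 + 1). With the two classes used here that gives 3 * (2 + 4 + 1) = 21, which matches the cfg.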

5. Training

python train.py --data data/fruit.data --cfg cfg/yolov3.cfg --epochs 10   # adjust --epochs as needed
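
During training, checkpoints are written to the weights directory; the weights/latest.pt and weights/best.pt files used by the commands below are produced there.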

 

6. Testing

python detect.py --data-cfg data/fruit.data --cfg cfg/yolov3.cfg --weights weights/best.pt

7. Evaluate the Model

python test.py  --data data/fruit.data --cfg cfg/yolov3.cfg  --weights weights/latest.pt

8. Visualize the Results

python -c "from utils import utils; utils.plot_results()"

Emmm... and that's it, nothing more... the official guide is worth keeping around... I'm off.

                                    

 

To be continued...


