+关注
已关注

分类  

暂无分类

标签  

暂无标签

日期归档  

2019-08(82)

2019-09(116)

2019-10(2)

Mask RCNN训练自己的数据-笔记

发布于2019-09-11 14:51     阅读(845)     评论(0)     点赞(6)     收藏(2)


Mask RCNN 训练自己的数据-笔记

最近在学习CV一些的知识,老师需要我们做一个物体分割和识别的小项目以做学习。在疯狂看完RCNN、Fast RCNN、Faster RCNN等文章后决定使用Mask RCNN作为主要的框架,也跟着很多的文章和官网做了demo和shape训练的内容,就准备开始使用自己的数据进行训练,有一些心得想要写一下。
具体的代码可以参考:https://github.com/Kingqibin/EWATT2
mask rcnn GitHub:https://github.com/matterport/Mask_RCNN
由于我也是入门阶段,所以有写的不准确的地方请见谅。

环境

  1. ubuntu 16.04
  2. tensorflow
  3. mask rcnn等

代码前的准备

1. 数据标记

感觉数据标记是一个大学问,我之前也是疯狂Google一些使用mask rcnn训练自己数据的一些文章用作参考,在它们中有很多都是用labelme标记数据,我第一次也用的这个,但是,训练后的模型就是没法儿正常用,可能是我有什么地方写的不对?反正是没法儿用。因此我就又去找mask rcnn GitHub上给出来的示例,发现,我如果想要做的话应该根据balloon那个模型进行变化,在balloon中使用的数据是用VIA标记的,因此抱着试一试的心态,又把100个左右的图片标记了一下,这一次很快就训练出了模型,而且效果也很好。这也是为什么我的项目名称后面带2的原因。废话不多说了,推荐使用VIA进行标记,遇到同样问题的小伙伴们可以戳我一下,哈哈。

这是VIA的官网,把工具下载一下: http://www.robots.ox.ac.uk/~vgg/software/via/
在这里插入图片描述
下载之后不需要安装,直接用浏览器打开就行。具体怎么使用可以自行去Google,挺简单的。因为我这个项目中只需要识别一个物体,因此只需要标记处来就好,不需要加什么属性标签。

2. 存放数据

上一步标记完成的数据会导出一个via_region_data.json 的文件,我建议可以自己查看一下这个文件下所有的数据保存结构,保存了哪些数据,对后面源码理解会有帮助。
将训练集和验证集的图片分成两个文件夹存放,并将它们各自标记的数据文件存放在各自的目录下。

3. 环境配置

去mask rcnn的GitHub上把源码下载下来,使用pycharm创建个虚拟环境,把用到的一些包都安装一下比如tensorflow啥的。然后去Google下mask_rcnn_coco.h5(一个预加载模型,凡是用过mask rcnn的同学都知道)。然后就可以开始写代码了!

写代码中

1. 看看balloon模型是怎么写的

https://github.com/matterport/Mask_RCNN/blob/master/samples/balloon/balloon.py 打开放在一边
可以看到

  1. 开始导了一些包,定义了一些常量
  2. 定义BallonConfig,即定义了训练时的一些参数
  3. 定义BalloonDataset,即定义了数据模型
  4. train函数
  5. 后面是一些与检测相关的函数,以及程序的命令行运行选项,都不是核心,直接略过

也就是说,3、4步是咱们的核心

2. 开始写

  1. 首先将定义自己的config,参考BallonConfig
from mrcnn.config import Config
class EwattConfig(Config):
    NAME = 'ewatt'
    GPU_COUNT = 1
    IMAGES_PER_GPU = 1
    # 类型数目,本项目只需要识别一个天线,因此,为BG + antenna = 2
    NUM_CLASSES = 1 + 1
    # 图片的大小
    IMAGE_MIN_DIM = 640
    IMAGE_MAX_DIM = 1024
    # 候选区域大小
    RPN_ANCHOR_SCALES = (16 * 16, 32 * 128, 64 * 128, 32 * 256, 64 * 512)
    # 每张图片的ROI数目
    TRAIN_ROIS_PER_IMAGE = 32
    # 每个epoch的迭代数目
    STEPS_PER_EPOCH = 100
    VALIDATION_STEPS = 20
    DETECTION_MIN_CONFIDENCE = 0.9
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18

其它的一些参数,我也不是太明白,有兴趣的可以去找下相关文档或者直接看config文件里面的说明。

  1. 定义自己的dataset,参考BallonDataset
# -*- coding:utf8 -*-

from mrcnn import utils
import os
import json
import numpy as np
import skimage.draw
from skimage import io
class EwattDataset(utils.Dataset):
   # 加载图片
   def load_antenna(self,dataset_dir,subset):
        """Load a subset of the Balloon dataset.
        dataset_dir: Root directory of the dataset.
        subset: Subset to load: train or val
        """
        # 添加类别,第一个参数是大类别名称,第二个参数是序号,第三个是小类别名称
        self.add_class("antenna",1,"antenna")
        # 如:
        # self.add_class("antenna",2,"antenna2")
        # 建议使用一个大类别名称,具体的用处,可以自己试一下
        
        # 数据子集,不用也行,可以直接指定路径
        assert subset in ["train","val"]
        dataset_dir = os.path.join(dataset_dir,subset)
        # 获取标记
        annotations = json.load(open(os.path.join(dataset_dir, "via_region_data.json")))
        annotations = list(annotations.values())
        annotations = [a for a in annotations if a['regions']]
###############################################################
#上面这些就是从文件中加载一些数据,之前让看json文件的作用就在这儿体现,你可以每一步print看一下,在这里我就不演示了#
###############################################################
        for a in annotations:
        	# 获取一些定义的属性(attribute)值比如region(标记区域)等,也可以获取一些自己定义的属性,并在后面做相应的修改
            if type(a['regions']) is dict:
                polygons = [r['shape_attributes'] for r in a['regions'].values()]
            else:
                polygons = [r['shape_attributes'] for r in a['regions']]
            
            image_path = os.path.join(dataset_dir, a['filename'])
            image = io.imread(image_path)
            height, width = image.shape[:2]
            self.add_image(
                "antenna", # 大类名称
                image_id=a['filename'],  # use file name as a unique image id
                path=image_path,
                # 如果又多个名称,请加上class_id,这个就是上面add_class 时的序号,我这里没有
                # 如:
                # class_id = 1
                width=width, height=height,
                polygons=polygons)

# load_mask 是继承过来的方法,因此,不要修改参数和名称,这个方法在load_antenna方法之后调用
   def load_mask(self, image_id):
   		# 获取到图片的信息(信息是在load_antenna中添加的)
        image_info = self.image_info[image_id]
        if image_info["source"] != "antenna":
            return super(self.__class__,self).load_mask(image_id)

        info = self.image_info[image_id]
        # 加载mask矩阵
        mask = np.zeros([info["height"],info["width"],len(info["polygons"])],dtype=np.uint8)
        for i, p in enumerate(info['polygons']):
        	# 加载每一个点
            rr, cc = skimage.draw.polygon(p['all_points_y'], p['all_points_x'])
            mask[rr, cc, i] = 1

        return mask.astype(np.bool),np.ones([mask.shape[-1]],dtype=np.int32)
        
# 这个函数我也不知道干啥的,修改下就完事儿了,好像是出错的时候报错的
   def image_reference(self, image_id):
       info = self.image_info[image_id]
       if info["source"] != "antenna":
           return info["path"]
       else:
           super(self.__class__,self).image_reference(image_id)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  1. 修改train文件
# -*- coding:utf8 -*-

import os
import tensorflow as tf
import mrcnn.model as modellib
import warnings
from EwattDataset import EwattDataset
from EwattConfig import EwattConfig

# 屏蔽一些不重要的warning(强迫症必备)
warnings.filterwarnings('ignore')
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
tf.logging.set_verbosity(tf.logging.ERROR)

# 获取根目录
ROOT_DIR = os.getcwd()
# print(ROOT_DIR)
# 训练后模型的存储文件夹
DEFAULT_LOGS_DIR = os.path.join(ROOT_DIR, "logs")
# coco 模型
COCO_WEIGHTS_PATH = os.path.join(ROOT_DIR, "mask_rcnn_coco.h5")

dataset_path = os.path.join(ROOT_DIR,"data")
config = EwattConfig()
def train():
	# train数据
    dataset_train = EwattDataset()
    dataset_train.load_antenna(dataset_path, "train")
    dataset_train.prepare()
    # Val数据
    dataset_val = EwattDataset()
    dataset_val.load_antenna(dataset_path, "val")
    dataset_val.prepare()
	# 定义model
    model = modellib.MaskRCNN(mode='training',config=config,model_dir=DEFAULT_LOGS_DIR)
    # 加载coco模型
    model.load_weights(COCO_WEIGHTS_PATH, by_name=True, exclude=[
        "mrcnn_class_logits", "mrcnn_bbox_fc",
        "mrcnn_bbox", "mrcnn_mask"])
    # 开始训练
    model.train(dataset_train, dataset_val,
                learning_rate=config.LEARNING_RATE,
                epochs=20,
                layers='all')
train()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45

训练

这一步只需要点一下运行按钮就行,不需要做什么,可以看一下中间输出的数据,如果发现又不是太合理的地方,就果断停止,找一下原因吧。

训练后的检测

检测就很简单了,根据demo https://github.com/matterport/Mask_RCNN/blob/master/samples/demo.ipynb 源码改一下就好了,没有什么难度。在这里就不写了。

我的结果

在这里插入图片描述


嘿嘿,结果还是比较让人满意的,因为只是学习下使用,因此也没有对一些东西做进一步的优化,如果有兴趣的同学可以调整一些参数做进一步的优化哟。到这里就结束了。如果大家有什么问题或者我有写的不合理的地方,欢迎在评论区留言。最近正在看YOLOv3的使用,等出成果了再写笔记,大家一起分享下。

转载请声明处处,谢谢



所属网站分类: 技术文章 > python文章

作者:23dh

链接: http://www.pythonheidong.com/blog/article/107339/

来源:python黑洞网 www.pythonheidong.com

任何形式的转载都请注明出处,如有侵权 一经发现 必将追究其法律责任

6 0

赞一赞 or 踩一踩

收藏该文
已收藏

评论内容:(最多支持255个字符)

相似文章

  MaskRCNN-Benchmark训练自己的数据集

  python 通过Sybase ASE ODBC Driver访问sybase数据库,无需配置DSN【自己整理的】

  The New And Improved Flask - 2018 pdf 下载

  全套尚硅谷大数据视频 hadoop spark kafka flume

  解决python中的Non-UTF-8 code starting with ‘\xbs4’ in file错误

  手把手教用matlab做无人驾驶(十)--纯跟踪算法(pure control)的补充l---python与matlab/simulink两种语言的编程实现

  【SimpleITK】胸部CT数据3D space归一化,以及3D plot

  Pandas数据分析①——数据读取(CSV/TXT/JSON)

  复写 keras_resnet 在cifar10数据集上分类

  mysql-python 安装错误: Cannot open include file: 'config-win.h': No such file or directory

优质资源排行榜

 python经典电子书大合集下载 下载次数 8108

 零基础java开发工程师视频教程全套,基础+进阶+项目实战(152G) 下载次数 7545

 零基础前端开发工程师视频教程全套,基础+进阶+项目实战(共120G) 下载次数 7439

 零基础大数据全套视频400G 下载次数 7002

 零基础php开发工程师视频教程全套,基础+进阶+项目实战(80G) 下载次数 6891

 零基础软件测试全套系统教程 下载次数 6502

 全套人工智能视频+pdf 下载次数 6436

 IOS全套视频教程 基础班+就业班 下载次数 4679

 编程小白的第一本python入门书(高清版)PDF下载 下载次数 3100

10  effective python编写高质量Python代码的59个有效方法 pdf下载 下载次数 3065

11  Python深度学习 pdf下载 下载次数 3044

12  使用python+pygame开发的小游戏《嗷大喵快跑》源码下载 下载次数 2998

13  python项目开发视频 下载次数 2997

14  python从入门到精通视频(全60集)python视频教程下载 下载次数 2994

15  黑马2017年java就业班全套视频教程 下载次数 2992

16  python实战项目 平铺图像板系统源码下载,适用于想要保存,标记和共享图像,视频和网页的用户 下载次数 2987

17  利用python实现程序内存监控脚本 下载次数 2986

18  树莓派Python编程指南 pdf下载 下载次数 2979

19  老男孩python自动化视频 下载次数 2979

20  老王python基础+进阶+项目视频教程 下载次数 2972

21  尚硅谷Go学科全套视频 下载次数 2972

22  某硅谷Python项目+AI课程+核心基础视频教程 下载次数 2967

23  Web前端实战精品课程 下载次数 2966

24  Python基础教程 pdf下载 下载次数 2966

25  tron python小游戏 下载次数 2962

26  [小甲鱼]零基础入门学习Python 下载次数 2959

27  老男孩python全栈开发15期 下载次数 2958

28  2017最新web前端开发完整视频教程附源码 下载次数 2948

29  最新全套完整JAVAWEB2018开发视频 下载次数 2926

30  Python算法教程_中文版 pdf下载 下载次数 2913

31  Spring boot实战视频6套下载 下载次数 2909

32  python全套视频十五期(116G) 下载次数 2901

33  Python项目实战 下载次数 2882

34  python全自动抢火车票教程-python视频教程下载 下载次数 2882

35  30个小时搞定Python网络爬虫 含源码 下载次数 2881

36  尚硅谷大数据之Hadoop视频 下载次数 2876

37  简明python教程 (A Byte of Python)pdf下载 下载次数 2873

38  Python A~B~C~ python视频教程下载 下载次数 2864

39  数据结构与算法视频(小甲鱼讲解-全) 下载次数 2863

40  web小程序表白天数倒计时源码下载 下载次数 2862

41  Python Cookbook第三版中文PDF下载高清完整扫描原版 下载次数 2862

42  python基础视频教程 下载次数 2862

43  Python高性能编程 pdf下载 下载次数 2862

44  go语言全套视频 下载次数 2853

45  利用Python进行数据分析 pdf下载 下载次数 2850

46  清华学霸尹成Python爬虫视频-ok 下载次数 2845

47  黑马前端36期最全视频和代码 下载次数 2841

48  2018最新全套web前端视频教程+源码下载 下载次数 2840

49  老男孩Python自动化开发12期 老男孩最强一期python高级运维开发课程 第二部分 70GB 下载次数 2832

50  流畅的Python PDF下载高清完整扫描原版 下载次数 2828