Published 2019-08-20 18:37
This post mainly follows the blog https://blog.csdn.net/u012426298/article/details/81386817.
First, download the pretrained model and the corresponding prototxt files. I won't repeat the links here; see the same blog.
1. Dataset
To avoid annotating a large number of images myself, I used the already-labeled NYU dataset, following https://blog.csdn.net/weixin_43915709/article/details/88774325. Processing labels40.mat produces the label images:
```python
# -*- coding: UTF-8 -*-
# Extract the label maps from the .mat file.
# Note: the layout of this file differs from the official one: height and width
# are swapped, so each label map has to be transposed.
import os

import numpy as np
import scipy.io as scio
from PIL import Image

dataFile = './labels40.mat'
data = scio.loadmat(dataFile)
labels = np.array(data["labels40"])

path_converted = './nyu_labels40'
if not os.path.isdir(path_converted):
    os.makedirs(path_converted)

for i in range(labels.shape[2]):  # 1449 labeled images in labels40.mat
    label = labels[:, :, i].transpose((1, 0))  # transpose: swap height and width
    label_img = Image.fromarray(np.uint8(label))
    label_img = label_img.transpose(Image.ROTATE_270)

    # save as 000001.png, 000002.png, ...
    iconpath = os.path.join(path_converted, '%06d.png' % (i + 1))
    label_img.save(iconpath, optimize=True)
```
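Note: the dataset must be resized to the network's input size (360×480 in the script below); otherwise you are likely to run into problems with images that are too large or too small. I used the MATLAB program from https://www.cnblogs.com/SweetBeens/p/8572655.html: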
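```matlab
ObjDir = 'F:\STUDY\CamVid\trainannot\'; % folder with the images to be resized (target folder)
OtpDir = 'F:\CamVid\trainannot\';       % folder for the resized images (output folder)
for i = 1:1449  % my images are numbered 00000001 to 00001449
    % num2str converts i to a string zero-padded to eight digits,
    % e.g. i = 13 -> num2str(i, '%08d') = '00000013'.
    % Adjust the width to your own file names (the Python script above writes '%06d').
    bgFile = [ObjDir, num2str(i, '%08d'), '.png'];
    bgImg = imread(bgFile);  % read the png image
    % resize to height 360, width 480; 'nearest' keeps label maps free of
    % interpolated class values (for RGB images the default bicubic is fine)
    img = imresize(bgImg, [360, 480], 'nearest');
    filename = [num2str(i, '%08d'), '.png'];  % output name, e.g. 00000001.png
    outPath = fullfile(OtpDir, filename);     % output path
    imwrite(img, outPath, 'png');             % write as png
end
```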
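If MATLAB is not available, the same resizing step can be sketched in Python with Pillow. This is an alternative under stated assumptions, not the referenced blog's method: `resize_pair` is a hypothetical helper, it assumes each RGB image and its label map are already opened as PIL images, and it uses bilinear interpolation for the photo and nearest-neighbour for the label map so no new class values are created.

```python
from PIL import Image

# PIL sizes are (width, height): 480 wide and 360 high, matching the MATLAB script
TARGET_SIZE = (480, 360)


def resize_pair(image, label, size=TARGET_SIZE):
    """Resize an RGB image and its label map to the same size (hypothetical helper).

    The photo is resized with bilinear interpolation; the label map with
    nearest-neighbour, so every output pixel keeps one of the original class values.
    """
    resized_image = image.resize(size, Image.BILINEAR)
    resized_label = label.resize(size, Image.NEAREST)
    return resized_image, resized_label
```

Typical use would be `img, lab = resize_pair(Image.open(img_path), Image.open(label_path))` followed by `img.save(...)` and `lab.save(...)` into the output folders.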
2. Determine the classes present in the label images and their class_weighting values, using the following script, class_weight.py.
Command:

```sh
python class_weight.py --dir ./labels_deal  # path to the label images
```

class_weight.py:
```python
# class_weight.py: list the classes present in the label images and compute
# median-frequency class weights for SegNet's class_weighting.
import argparse
import collections
import os
import sys

import numpy as np
from PIL import Image

# Import arguments
parser = argparse.ArgumentParser()
parser.add_argument('--dir', type=str,
                    help='Path to the folder containing the images with annotations')
args = parser.parse_args()

cwd = args.dir if args.dir else os.getcwd()
image_names = os.listdir(cwd)
# Keep only images and prepend the directory
image_list = [os.path.join(cwd, s) for s in image_names
              if s.lower().endswith(('.png', '.jpg', '.jpeg'))]

print("Number of images:", len(image_list))

# All label images are assumed to have been resized to 360 x 480
IMG_H, IMG_W = 360, 480


def get_class_per_image(img):
    """Return {class_value: pixel_count} for one single-channel label image."""
    arr = np.array(Image.open(img))
    values, counts = np.unique(arr, return_counts=True)
    return dict(zip(values.tolist(), counts.tolist()))


def count_all_pixels(image_list):
    """Return {class: [total pixel count, number of images containing the class]}."""
    result = dict()
    for img in image_list:
        sys.stdout.write('.')
        sys.stdout.flush()
        for key, value in get_class_per_image(img).items():
            entry = result.setdefault(key, [0, 0])
            entry[0] += value  # pixels of this class over the whole dataset
            entry[1] += 1      # in how many images this class is present
    print("Done")
    return result


def cal_class_weights(image_list):
    freq_images = dict()
    weights = collections.OrderedDict()
    # frequency of each class = its pixels / total pixels of the images containing it
    for k, (v1, v2) in count_all_pixels(image_list).items():
        freq_images[k] = v1 / (v2 * IMG_H * IMG_W * 1.0)
    # median of the frequencies
    median = np.median(list(freq_images.values()))
    # weight = median / freq, in ascending class order (class_weighting is indexed by label)
    for k in sorted(freq_images):
        weights[k] = median / freq_images[k]
    return weights


results = cal_class_weights(image_list)

# Print the results
for k, v in results.items():
    print(" class", k, "weight:", round(v, 4))

print("Copy this:")
for k, v in results.items():
    print(" class_weighting:", round(v, 4))
```
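To make the median frequency balancing in class_weight.py concrete: for each class c, freq(c) is the number of pixels of class c divided by the total pixels of the images that contain c, and the weight is median(freq) / freq(c), so rare classes get weights above 1 and common ones below 1. The pure-Python sketch below (the function name and the counts are made up for illustration) reproduces that calculation without reading any files.

```python
from statistics import median


def median_frequency_weights(stats, pixels_per_image):
    """Median frequency balancing (illustrative sketch).

    stats maps class -> (total pixels of that class,
                         number of images in which the class appears).
    Returns {class: weight} in ascending class order.
    """
    freq = {c: pixels / (n_images * pixels_per_image)
            for c, (pixels, n_images) in stats.items()}
    med = median(freq.values())
    return {c: med / freq[c] for c in sorted(freq)}


# Made-up counts for 100-pixel images:
# class 0: freq 150/200 = 0.75, class 1: 40/200 = 0.2, class 2: 10/100 = 0.1
example = median_frequency_weights({0: (150, 2), 1: (40, 2), 2: (10, 1)}, 100)
```

Here the median frequency is 0.2, so the weights come out as roughly 0.2667, 1.0 and 2.0 for classes 0, 1 and 2.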
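I had always assumed my label images had 40 classes (see https://blog.csdn.net/u012455577/article/details/86316996), and made do with someone else's 40-class class_weighting values. The result was a persistent error: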
F0725 17:02:42.888584 17046 math_functions.cu:121] Check failed: status == CUBLAS_STATUS_SUCCESS (11 vs. 0) CUBLAS_STATUS_MAPPING_ERROR
How silly of me... Only after running class_weight.py did I discover that there were actually 48 classes.
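A quick check before training can catch this kind of mismatch: collect the distinct values in the label maps and compare them with the number of classes configured in the prototxt. Label values outside that range are a likely cause of crashes like the CUBLAS error above, and one possible source of extra values is resizing label maps with a non-nearest interpolation. The helper below is a hypothetical sketch; it assumes the label maps are already loaded as NumPy arrays, e.g. with `np.array(Image.open(path))`.

```python
import numpy as np


def unexpected_label_values(label_arrays, num_classes, ignore_label=None):
    """Return the sorted label values outside 0..num_classes-1 (hypothetical helper).

    A value equal to ignore_label is treated as allowed.
    """
    seen = set()
    for arr in label_arrays:
        seen.update(np.unique(arr).tolist())
    allowed = set(range(num_classes))
    if ignore_label is not None:
        allowed.add(ignore_label)
    return sorted(seen - allowed)
```

An empty list means every label value fits the configured class count.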
3. Create train.txt and test.txt
txtfile.sh:
```sh
#!/usr/bin/env sh
# Image and label folders (here the same folders are used for training and testing)
DATA_train=/home/zml/data/nyu/40_label/images_deal
MASK_train=/home/zml/data/nyu/40_label/labels_deal
DATA_test=/home/zml/data/nyu/40_label/images_deal
MASK_test=/home/zml/data/nyu/40_label/labels_deal

MY=/home/zml/temp/transferlearn

################################################
rm -f "$MY/train.txt" "$MY/img.txt" "$MY/mask.txt"

echo "Create train.txt"
# sort both lists so that line N of img.txt and line N of mask.txt
# refer to the same sample (assumes images and labels share file names)
find "$DATA_train/" -name "*.png" | sort > "$MY/img.txt"
find "$MASK_train/" -name "*.png" | sort > "$MY/mask.txt"
# each line of train.txt: "<image path> <label path>"
paste -d " " "$MY/img.txt" "$MY/mask.txt" > "$MY/train.txt"

rm -f "$MY/img.txt" "$MY/mask.txt"

##################################################
rm -f "$MY/test.txt"

echo "Create test.txt"
find "$DATA_test/" -name "*.png" | sort > "$MY/img.txt"
find "$MASK_test/" -name "*.png" | sort > "$MY/mask.txt"
paste -d " " "$MY/img.txt" "$MY/mask.txt" > "$MY/test.txt"

rm -f "$MY/img.txt" "$MY/mask.txt"
```
Running `sh txtfile.sh` produces train.txt and test.txt.
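Since each line pairs an image with its label purely by position, it is worth checking that the pairing is right before training. The sketch below is hypothetical and assumes an image and its label map share the same file name (as with the images_deal / labels_deal folders above); it returns every line whose two file names differ or that does not have exactly two fields.

```python
import os


def mismatched_pairs(lines):
    """Return lines of train.txt/test.txt whose image and label names differ (hypothetical helper)."""
    bad = []
    for line in lines:
        parts = line.split()
        if len(parts) != 2:
            bad.append(line)  # malformed line: expected "<image> <label>"
            continue
        img, mask = parts
        if os.path.basename(img) != os.path.basename(mask):
            bad.append(line)
    return bad
```

For example, `mismatched_pairs(open("train.txt").read().splitlines())` should return an empty list when every line is paired correctly.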
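4. Modify segnet_train.prototxt
Set num_output of the last convolution layer to the total number of classes in your label images.
Set ignore_label to your number of classes, and replace the class_weighting entries with the values printed by class_weight.py.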
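The output of class_weight.py can be turned into the loss layer's parameters with a small helper. This is a hypothetical sketch: the field names (weight_by_label_freqs, ignore_label, class_weighting) follow the loss_param block of the SegNet tutorial's segnet_train.prototxt, and the helper refuses class indices with gaps, because each class_weighting entry is matched to a label value by its position.

```python
def loss_param_text(weights, ignore_label):
    """Build (num_output, loss_param text) from {class_index: weight} (hypothetical helper)."""
    classes = sorted(weights)
    if classes != list(range(len(classes))):
        raise ValueError("class indices must be 0..N-1 without gaps, got %s" % classes)
    lines = ["loss_param {",
             "  weight_by_label_freqs: true",
             "  ignore_label: %d" % ignore_label]
    # one class_weighting line per class, in label order
    lines += ["  class_weighting: %.4f" % weights[c] for c in classes]
    lines.append("}")
    return len(classes), "\n".join(lines)
```

The first returned value is the number to use for the last num_output; the text can be pasted into the SoftmaxWithLoss layer.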
5. Modify segnet_solver.prototxt: adjust the learning rate and other settings. This file has many parameters, so I won't go through them in detail.
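For reference, a solver file is a list of Caffe SolverParameter fields. The sketch below writes one from a Python dict; the helper is hypothetical, the values in EXAMPLE_PARAMS are illustrative placeholders (not the settings used in this post), and the segnet_train.prototxt and snapshot paths are assumptions. String fields are quoted, while the solver_mode enum and booleans are written unquoted.

```python
# Fields whose values are enums and must be written without quotes
ENUM_FIELDS = {"solver_mode"}

# Illustrative values only; paths are assumptions
EXAMPLE_PARAMS = {
    "net": "/home/zml/temp/transferlearn/file/segnet_train.prototxt",
    "base_lr": 0.001,
    "lr_policy": "step",
    "gamma": 0.1,
    "stepsize": 10000,
    "momentum": 0.9,
    "weight_decay": 0.0005,
    "display": 20,
    "max_iter": 40000,
    "snapshot": 1000,
    "snapshot_prefix": "/home/zml/temp/transferlearn/snapshots/segnet_nyu",
    "solver_mode": "GPU",
}


def solver_text(params):
    """Render a dict as solver prototxt text (hypothetical helper)."""
    lines = []
    for key, value in params.items():
        if isinstance(value, bool):  # check bool before numbers: True must become "true"
            lines.append("%s: %s" % (key, "true" if value else "false"))
        elif isinstance(value, str) and key not in ENUM_FIELDS:
            lines.append('%s: "%s"' % (key, value))
        else:
            lines.append("%s: %s" % (key, value))
    return "\n".join(lines) + "\n"
```

Writing `solver_text(EXAMPLE_PARAMS)` to segnet_solver.prototxt gives a file that can be passed to `caffe train -solver`.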
6. Run training
/home/zml/caffe/caffe-segnet-cudnn5/build/tools/caffe train -solver /home/zml/temp/transferlearn/file/segnet_solver.prototxt -weights /home/zml/temp/transferlearn/file/segnet_pascal.caffemodel -gpu 0
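To follow the training loss, the console output can be saved (for example by piping it through `2>&1 | tee train.log`) and parsed. The sketch below is a hypothetical helper; it assumes the solver prints lines of the form "Iteration 20, loss = 3.25", or "Iteration 40 (… iter/s, …), loss = 2.5" in newer Caffe versions, and ignores all other lines.

```python
import re

# "Iteration <n>" followed, on the same line, by ", loss = <value>"
LOSS_RE = re.compile(r"Iteration (\d+).*?, loss = (\S+)")


def parse_losses(lines):
    """Return [(iteration, loss)] from Caffe training log lines (hypothetical helper)."""
    out = []
    for line in lines:
        m = LOSS_RE.search(line)
        if m:
            out.append((int(m.group(1)), float(m.group(2))))
    return out
```

A loss that stays flat or becomes nan usually points back to the class count, class_weighting or learning-rate settings above.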
Author: 83748wuw
Link: https://www.pythonheidong.com/blog/article/49514/79eee7d619889956bc02/
Source: python黑洞网
Any reproduction must credit the source.