程序员最近都爱上了这个网站  程序员们快来瞅瞅吧!  it98k网:it98k.com

本站消息

站长简介/公众号

  出租广告位,需要合作请联系站长

+关注
已关注

分类  

暂无分类

标签  

暂无标签

日期归档  

segnet 迁移学习

发布于2019-08-20 18:37     阅读(992)     评论(0)     点赞(12)     收藏(2)


本文主要参考博客https://blog.csdn.net/u012426298/article/details/81386817

首先获取预训练模型,和相应的prototxt文件,连接就不上了,参考博客https://blog.csdn.net/u012426298/article/details/81386817

一、数据集

为了避免自己去标注太多的图片,所以采用了标注好了的nyu数据集。参考博客https://blog.csdn.net/weixin_43915709/article/details/88774325。对labels40.mat操作,得到 label 图。

  1. #-*- coding:UTF-8 -*-
  2. # 从mat文件提取labels
  3. # 需要注意这个文件里面的格式和官方有所不同,长宽需要互换,也就是进行转置
  4. import cv2
  5. import scipy.io as scio
  6. from PIL import Image
  7. import numpy as np
  8. import matplotlib.pyplot as plt
  9. import os
  10. dataFile = './labels40.mat'
  11. data = scio.loadmat(dataFile)
  12. labels=np.array(data["labels40"])
  13. path_converted='./nyu_labels40'
  14. if not os.path.isdir(path_converted):
  15. os.makedirs(path_converted)
  16. labels_number=[]
  17. for i in range(1449):
  18. labels_number.append(labels[:,:,i].transpose((1, 0))) # 转置
  19. labels_0=np.array(labels_number[i])
  20. #print labels_0.shape
  21. print (type(labels_0))
  22. label_img=Image.fromarray(np.uint8(labels_number[i]))
  23. #label_img = label_img.rotate(270)
  24. label_img = label_img.transpose(Image.ROTATE_270)
  25. iconpath='./nyu_labels40/'+str('%06d'%(i+1))+'.png'
  26. label_img.save(iconpath, optimize=True)

注意:需要对数据集进行resize,不然很可能出现图片过大或过小的问题。参考博客https://www.cnblogs.com/SweetBeens/p/8572655.html 的MATLAB程序。

  1. ObjDir = 'F:\STUDY\CamVid\trainannot\';%将被改变的图像地址,称为目标地址
  2. OtpDir = 'F:\CamVid\trainannot\';%输出图像地址,称为输出地址
  3. for i = 1:1:1449%我的图像标号是00000001到00001449
  4. bgFile = [ObjDir,num2str(i,'%08d'),'.png'];%这句话读取目标地址里面的格式为png的图片
  5. %num2str是先把数字i转换成string然后补零直到八位
  6. %举个例子:i=13,num2str(i,'%08d)=00000013,类型是string
  7. bgFile = imread(bgFile);%把图片读成matlab认识的,类型为:图片
  8. img = imresize(bgFile,[360,480]);%调整大小到高360,长480
  9. filename=[num2str(i,'%08d'),'.png'];%输出的图片名称是00000001.png
  10. path=fullfile(OtpDir,filename);%输出的路径
  11. imwrite(img,path,'png');%以png格式输出出去
  12. end

 

二、 需要确定 label 图中的类别以及classd_weighting ,参考以下代码class_weight.py:

命令

python class_weight.py --dir  ./labels_deal  # label 图的路径

class_weight.py 

  1. import numpy as np
  2. import argparse
  3. import os
  4. from PIL import Image
  5. from os import listdir
  6. import sys
  7. import collections
  8. # Import arguments
  9. parser = argparse.ArgumentParser()
  10. parser.add_argument('--dir', type=str, help='Path to the folder containing the images with annotations')
  11. args = parser.parse_args()
  12. if args.dir:
  13. cwd = args.dir
  14. if not args.dir.endswith('/'): cwd = cwd + '/'
  15. else:
  16. cwd = os.getcwd() + '/'
  17. image_names = listdir(cwd)
  18. # Keep only images and append image_names to directory
  19. image_list = [cwd + s for s in image_names if s.lower().endswith(('.png', '.jpg', '.jpeg'))]
  20. print "Number of images:", len(image_list)
  21. def count_all_pixels(image_list):
  22. dic_class_imgcount = dict()
  23. overall_pixelcount = dict()
  24. result = dict()
  25. for img in image_list:
  26. sys.stdout.write('.')
  27. sys.stdout.flush()
  28. for key, value in get_class_per_image(img).items():
  29. # Sum up the number of classes returned from get_class_per_image function
  30. overall_pixelcount[key] = overall_pixelcount.get(key, 0) + value
  31. # If the class is present in the image, then increase the value by one
  32. # shows in how many images a particular class is present
  33. dic_class_imgcount[key] = dic_class_imgcount.get(key, 0) + 1
  34. print "Done"
  35. # Save above 2 variables in a list
  36. for (k, v), (k2, v2) in zip(overall_pixelcount.items(), dic_class_imgcount.items()):
  37. if k != k2: print ("This was impossible to happen, but somehow it did"); exit()
  38. result[k] = [v, v2]
  39. return result
  40. def get_class_per_image(img):
  41. dic_class_pixelcount = dict()
  42. im = Image.open(img)
  43. pix = im.load()
  44. for x in range(im.size[0]):
  45. for y in range(im.size[1]):
  46. dic_class_pixelcount[pix[x, y]] = dic_class_pixelcount.get(pix[x, y], 0) + 1
  47. #del dic_class_pixelcount[11]
  48. return dic_class_pixelcount
  49. def cal_class_weights(image_list):
  50. freq_images = dict()
  51. weights = collections.OrderedDict()
  52. # calculate freq per class
  53. for k, (v1, v2) in count_all_pixels(image_list).items():
  54. freq_images[k] = v1 / (v2 * 360 * 480 * 1.0)
  55. # calculate median of freqs
  56. median = np.median(freq_images.values())
  57. # calculate weights
  58. for k, v in freq_images.items():
  59. weights[k] = median / v
  60. return weights
  61. results = cal_class_weights(image_list)
  62. # Print the results
  63. for k, v in results.items():
  64. print " class", k, "weight:", round(v, 4)
  65. print "Copy this:"
  66. for k, v in results.items():
  67. print " class_weighting:", round(v, 4)

我一直以为我的 label 图是40类,https://blog.csdn.net/u012455577/article/details/86316996 。之前还凑合用别人的40类class_weighting: 结果一直出错

F0725 17:02:42.888584 17046 math_functions.cu:121] Check failed: status == CUBLAS_STATUS_SUCCESS (11 vs. 0)  CUBLAS_STATUS_MAPPING_ERROR

真的是蠢死。。。用了class_weight.py 后才发现有48类。 

三、制作 train.txt以及test.txt文件

txtfile.sh

  1. #!/usr/bin/env sh
  2. DATA_train=/home/zml/data/nyu/40_label/images_deal
  3. MASK_train=/home/zml/data/nyu/40_label/labels_deal
  4. DATA_test=/home/zml/data/nyu/40_label/images_deal
  5. MASK_test=/home/zml/data/nyu/40_label/labels_deal
  6. MY=/home/zml/temp/transferlearn/
  7. ################################################
  8. rm -rf $MY/train.txt
  9. echo "Create train.txt"
  10. find $DATA_train/ -name "*.png">>$MY/img.txt
  11. find $MASK_train/ -name "*.png">>$MY/mask.txt
  12. paste -d " " $MY/img.txt $MY/mask.txt>$MY/train.txt
  13. rm -rf $MY/img.txt
  14. rm -rf $MY/mask.txt
  15. ##################################################
  16. rm -rf $MY/test.txt
  17. echo "Create test.txt"
  18. find $DATA_test/ -name "*.png">>$MY/img.txt
  19. find $MASK_test/ -name "*.png">>$MY/mask.txt
  20. paste -d " " $MY/img.txt $MY/mask.txt>$MY/test.txt
  21. rm -rf $MY/img.txt
  22. rm -rf $MY/mask.txt

 用命令 sh  txtfile.sh 可得到train.txt和test.txt

 四、修改segnet_train.prototxt 

最后一个num_output 修改为自己数据集label图的类别总数: 

修改ignore_label,修改为自己的类别数。并根据之前的class_weight.py得到的class_weighting修改文件。 

五、修改segnet_solver.prototxt,修改里面的学习率等,这个文件比较多,所以就不细细讲了。

六、运行

/home/zml/caffe/caffe-segnet-cudnn5/build/tools/caffe  train -solver /home/zml/temp/transferlearn/file/segnet_solver.prototxt -weights  -solver /home/zml/temp/transferlearn/file/segnet_pascal.caffemodel -gpu 0

 



所属网站分类: 技术文章 > 博客

作者:83748wuw

链接:https://www.pythonheidong.com/blog/article/49514/79eee7d619889956bc02/

来源:python黑洞网

任何形式的转载都请注明出处,如有侵权 一经发现 必将追究其法律责任

12 0
收藏该文
已收藏

评论内容:(最多支持255个字符)