+关注
已关注

分类  

暂无分类

标签  

暂无标签

日期归档  

2019-08(101)

2019-09(114)

《My Deep Diary》之《Kaggle Histopathologic Cancer Detection比赛之Tensorflow2.0/Keras Eager Execution实现》

发布于2019-08-07 14:34     阅读(82)     评论(0)     点赞(5)     收藏(0)


Kaggle项目地址:https://www.kaggle.com/c/histopathologic-cancer-detection/overview

本文记录了一个使用Tensorflow2.0/Keras Eager Execution的实现,数据预处理采用了Tensorflow标准的Dataset的方式:

 

  1. # -*- coding: utf-8 -*-
  2. import tensorflow as tf
  3. AUTOTUNE = tf.data.experimental.AUTOTUNE
  4. # tf.enable_eager_execution()
  5. import numpy as np
  6. import os,sys,csv
  7. import cv2 as cv
  8. import matplotlib.pyplot as plt
  9. import pandas as pd
  10. from sklearn.model_selection import train_test_split
  11. from sklearn.utils import shuffle
  12. import myimageutil as iu
  13. """
  14. ====================================================================================
  15. <<1.初步了解掌握数据的情况>>
  16. ====================================================================================
  17. 用pandas简单处理一下CSV并画出来看一下
  18. 这里我借用了kaggle的这篇kernel里的plot的代码,有兴趣的童鞋可以读一下,
  19. https://www.kaggle.com/qitvision/a-complete-ml-pipeline-fast-ai
  20. """
  21. ROOT_PATH = 'D:/ai_data/histopathologic-cancer-detection'
  22. CSV_PATH = 'D:/ai_data/histopathologic-cancer-detection/train_labels.csv'
  23. TRAIN_PATH = 'D:/ai_data/histopathologic-cancer-detection/train'
  24. TEST_PATH = 'D:/ai_data/histopathologic-cancer-detection/test'
  25. print(">>>看一下根目录下有哪些东西:")
  26. print(os.listdir(ROOT_PATH))
  27. df = pd.read_csv(CSV_PATH) #pandas里的数据集叫dataframe,和scala里的一样,我们简称df
  28. # 接下来我们来看一下数据的情况
  29. print(">>>这个数据集的大小:")
  30. print(df.shape)
  31. print(">>>这个数据集的样本分布:")
  32. print(df['label'].value_counts())
  33. print(">>>看一下数据:")
  34. print(df.head())
  35. # 这边我想说明一下,之前我们的第一篇walkthrough里是直接从csv中获得文件列表的,这边最好检查一下列表里的文件和文件夹里的是不是一一对应
  36. print(">>>list一下训练图片文件夹里的图片:")
  37. from glob import glob
  38. train_file_paths = glob(TRAIN_PATH + '/*.tif')
  39. test_file_paths = glob(TEST_PATH + '/*.tif')
  40. print("train_file_paths size:", len(train_file_paths))
  41. print("test_file_paths size:", len(test_file_paths))
  42. import re
  43. def check_valid():
  44. assert len(train_file_paths) == len(df['id']),'图片数量不一致'
  45. ids_from_filepath = list(map(lambda filepath:''.join(re.findall(r'[a-z0-9]{40}',filepath)), train_file_paths))
  46. dif = list(set(ids_from_filepath)^set(df['id'])) #求两个list的差集,如果差集为0,那说明两个list相等
  47. if len(dif) == 0:
  48. print("文件名匹配正常")
  49. else:
  50. print("匹配异常,下列文件名有差异:")
  51. print(dif)
  52. exit()
  53. check_valid()
  54. # print(">>>数据没问题的话接下来看一下正负数据样例的图片:")
  55. # iu.plotSamples(df,TRAIN_PATH) #要注意本次的图片数据是使用中间32X32像素的内容为基准进行标注的,所以画图把中间一块标注出来了,但实际分类的时候不一定要把中间裁出来
  56. # print(">>>进入正题,我们拆分一下数据,把训练数据分成训练和测试2部分,比例为9:1")
  57. train, val = train_test_split(train_file_paths, test_size=0.1, shuffle=True)
  58. id_label_map = {k:v for k,v in zip(df.id.values, df.label.values)}
  59. def get_paths_labels(pathlist):
  60. ids = []
  61. labels = []
  62. for item in pathlist:
  63. id = ''.join(re.findall(r'[a-z0-9]{40}',item))
  64. label = id_label_map[id]
  65. ids.append(item)
  66. labels.append(label)
  67. return ids,labels
  68. train_paths,train_labels = get_paths_labels(train)
  69. val_paths,val_labels = get_paths_labels(val)
  70. # exit()
  71. """
  72. ====================================================================================
  73. <<2.图片处理和扩增>>
  74. ====================================================================================
  75. 图片处理主要是要匹配CNN的输入大小,扩增是为了降低过拟合风险
  76. 无论是图片处理还是扩增都有太多方法了,比较常用的imageaug或者tf.image进行数据扩增,其实openCV什么都能干
  77. imgaug堪称python里最强图片扩增工具,方法多,叠加方便,一个图像数据扩增100倍轻轻松松:
  78. https://github.com/aleju/imgaug
  79. 使用tensorflow自带的tf.image进行augmentation,特点是能结合tf.dataset无缝使用:
  80. http://androidkt.com/tensorflow-image-augmentation-using-tf-image/
  81. 这边我们使用imgaug进行处理,最后生成tf.dataset进行训练
  82. """
  83. BATCH_SIZE = 32
  84. #我们还是使用之前的方法读取tif文件,tensorflow本身不支持读取tif,所以只能用py_func调用外部函数来读取
  85. def image_aug_cv(filepath,label):
  86. image_decoded = cv.imread(filepath.numpy().decode(), 1)
  87. image_resized = tf.image.resize(image_decoded, [224, 224])
  88. return aug_image(image_resized), label
  89. def aug_image(image):
  90. return image / 255.0
  91. def prepare_train_ds(filepaths,labels):
  92. global BATCH_SIZE
  93. paths_ds = tf.data.Dataset.from_tensor_slices(filepaths)
  94. labels_ds = tf.data.Dataset.from_tensor_slices(labels)
  95. paths_labels_ds = tf.data.Dataset.zip((paths_ds,labels_ds))
  96. images_labels_ds = paths_labels_ds.shuffle(buffer_size=300000)
  97. images_labels_ds = images_labels_ds.map(lambda filename,label : tf.py_function( func=image_aug_cv,
  98. inp=[filename,label],
  99. Tout=[tf.float32,tf.float32]),
  100. num_parallel_calls=AUTOTUNE)
  101. # images_labels_ds = images_labels_ds.repeat()
  102. images_labels_ds = images_labels_ds.batch(BATCH_SIZE)
  103. images_labels_ds = images_labels_ds.prefetch(buffer_size = 200)
  104. return images_labels_ds
  105. train_ds = prepare_train_ds(train_paths,np.asarray(train_labels).astype('float32').reshape((-1,1)))
  106. val_ds = prepare_train_ds(val_paths,np.asarray(val_labels).astype('float32').reshape((-1,1)))
  107. """
  108. ====================================================================================
  109. <<3.建模>>
  110. ====================================================================================
  111. 使用keras和比较新的NASnet来建立模型,方法和walkthrough里的一摸一样
  112. """
  113. from tensorflow.keras.layers import concatenate, Activation, GlobalAveragePooling2D, Flatten
  114. from tensorflow.keras.layers import Dense, Input, Dropout, MaxPooling2D, Concatenate, GlobalMaxPooling2D, GlobalAveragePooling2D
  115. from tensorflow.keras.models import Model
  116. from tensorflow.keras.applications.nasnet import NASNetMobile
  117. # from tensorflow.keras.optimizers import Adam
  118. nasnet = NASNetMobile(include_top=False, input_shape=(224, 224, 3))
  119. x1 = GlobalMaxPooling2D()(nasnet.output)
  120. x2 = GlobalAveragePooling2D()(nasnet.output)
  121. x3 = Flatten()(nasnet.output)
  122. out = Concatenate(axis=-1)([x1, x2, x3])
  123. out = Dropout(0.5)(out)
  124. predictions = Dense(1, activation="sigmoid",name = 'predictions')(out)
  125. model = Model(inputs=nasnet.input, outputs=predictions)
  126. model.trainable = True
  127. # for layer in model.layers[:-3]:
  128. # layer.trainable = False
  129. optimizer = tf.keras.optimizers.Adam(lr = 0.0001, beta_1=0.9, beta_2=0.999, epsilon=None, decay=0.0, amsgrad=False)
  130. loss_func = tf.keras.losses.BinaryCrossentropy()
  131. # model.summary()
  132. train_loss = tf.keras.metrics.Mean(name='train_loss')
  133. train_accuracy = tf.keras.metrics.BinaryAccuracy(name='train_accuracy')
  134. val_loss = tf.keras.metrics.Mean(name='val_loss')
  135. val_accuracy = tf.keras.metrics.BinaryAccuracy(name='val_accuracy')
  136. """
  137. ====================================================================================
  138. <<3.训练>>
  139. ====================================================================================
  140. 这边使用官方标准的tensorflow 2.0 Eager Execution的训练方法来训练网络
  141. """
  142. # @tf.function
  143. def train_step(images, labels):
  144. with tf.GradientTape() as tape:
  145. predictions = model(images)
  146. loss = loss_func(labels, predictions)
  147. # print("train loss:"+str(loss.numpy()))
  148. gradients = tape.gradient(loss, model.trainable_variables)
  149. optimizer.apply_gradients(zip(gradients, model.trainable_variables))
  150. train_loss(loss)
  151. train_accuracy(labels, predictions)
  152. # @tf.function
  153. def val_step(images, labels):
  154. predictions = model(images)
  155. loss = loss_func(labels, predictions)
  156. # print("val loss:"+str(loss.numpy()))
  157. val_loss(loss)
  158. val_accuracy(labels, predictions)
  159. EPOCHS = 20
  160. import datetime
  161. for epoch in range(EPOCHS):
  162. for images, labels in train_ds:
  163. train_step(images, labels)
  164. for val_images, val_labels in val_ds:
  165. val_step(val_images, val_labels)
  166. template = 'Epoch {}, Loss: {}, Accuracy: {}, val Loss: {}, val Accuracy: {}'
  167. print (template.format(epoch+1,
  168. train_loss.result(),
  169. train_accuracy.result()*100,
  170. val_loss.result(),
  171. val_accuracy.result()*100))
  172. print(datetime.datetime.now())

 



所属网站分类: 技术文章 > python文章

作者:倒车请注意

链接: http://www.pythonheidong.com/blog/article/11353/

来源:python黑洞网 www.pythonheidong.com

任何形式的转载都请注明出处,如有侵权 一经发现 必将追究其法律责任

5 0

赞一赞 or 踩一踩

收藏该文
已收藏

评论内容:(最多支持255个字符)

相似文章

  Timezone offset does not match system offset: 0 != -32400. Please, check your config files

  Think Python : How to Think Like a Computer Scientist pdf下载

  tensorflow遇到ImportError: Could not find 'cudart64_100.dll'错误解决

  解决python中的Non-UTF-8 code starting with ‘\xbs4’ in file错误

  Python Algorithms : Mastering Basic Algorithms in the Python Language pdf下载

  Python for Data Analysis. Data Wrangling with Pandas, NumPy, and IPython pdf下载

  Learn Python the Hard Way : A Very Simple Introduction to the Terrifyingly Beautiful World o pdf下载

  VS CODE Python setting and launch.json

  Caused by SSLError("Can’t connect to HTTPS URL because the SSL module is not available)

  手把手教用matlab做无人驾驶(十)--纯跟踪算法(pure control)的补充l---python与matlab/simulink两种语言的编程实现

优质资源排行榜

 python经典电子书大合集下载 下载次数 8104

 零基础java开发工程师视频教程全套,基础+进阶+项目实战(152G) 下载次数 7543

 零基础前端开发工程师视频教程全套,基础+进阶+项目实战(共120G) 下载次数 7438

 零基础大数据全套视频400G 下载次数 7001

 零基础php开发工程师视频教程全套,基础+进阶+项目实战(80G) 下载次数 6891

 零基础软件测试全套系统教程 下载次数 6501

 全套人工智能视频+pdf 下载次数 6436

 IOS全套视频教程 基础班+就业班 下载次数 4679

 编程小白的第一本python入门书(高清版)PDF下载 下载次数 3076

10  effective python编写高质量Python代码的59个有效方法 pdf下载 下载次数 3047

11  Python深度学习 pdf下载 下载次数 3035

12  使用python+pygame开发的小游戏《嗷大喵快跑》源码下载 下载次数 2998

13  python项目开发视频 下载次数 2996

14  python从入门到精通视频(全60集)python视频教程下载 下载次数 2993

15  黑马2017年java就业班全套视频教程 下载次数 2992

16  python实战项目 平铺图像板系统源码下载,适用于想要保存,标记和共享图像,视频和网页的用户 下载次数 2987

17  利用python实现程序内存监控脚本 下载次数 2986

18  老男孩python自动化视频 下载次数 2979

19  树莓派Python编程指南 pdf下载 下载次数 2976

20  尚硅谷Go学科全套视频 下载次数 2972

21  老王python基础+进阶+项目视频教程 下载次数 2971

22  某硅谷Python项目+AI课程+核心基础视频教程 下载次数 2966

23  Web前端实战精品课程 下载次数 2966

24  Python基础教程 pdf下载 下载次数 2962

25  tron python小游戏 下载次数 2962

26  [小甲鱼]零基础入门学习Python 下载次数 2959

27  老男孩python全栈开发15期 下载次数 2958

28  2017最新web前端开发完整视频教程附源码 下载次数 2948

29  最新全套完整JAVAWEB2018开发视频 下载次数 2926

30  Python算法教程_中文版 pdf下载 下载次数 2910

31  Spring boot实战视频6套下载 下载次数 2909

32  python全套视频十五期(116G) 下载次数 2901

33  Python项目实战 下载次数 2882

34  python全自动抢火车票教程-python视频教程下载 下载次数 2882

35  30个小时搞定Python网络爬虫 含源码 下载次数 2881

36  尚硅谷大数据之Hadoop视频 下载次数 2876

37  简明python教程 (A Byte of Python)pdf下载 下载次数 2870

38  Python A~B~C~ python视频教程下载 下载次数 2864

39  数据结构与算法视频(小甲鱼讲解-全) 下载次数 2863

40  web小程序表白天数倒计时源码下载 下载次数 2862

41  python基础视频教程 下载次数 2862

42  Python高性能编程 pdf下载 下载次数 2858

43  Python Cookbook第三版中文PDF下载高清完整扫描原版 下载次数 2856

44  go语言全套视频 下载次数 2852

45  清华学霸尹成Python爬虫视频-ok 下载次数 2845

46  黑马前端36期最全视频和代码 下载次数 2841

47  2018最新全套web前端视频教程+源码下载 下载次数 2839

48  利用Python进行数据分析 pdf下载 下载次数 2834

49  老男孩Python自动化开发12期 老男孩最强一期python高级运维开发课程 第二部分 70GB 下载次数 2832

50  python视频 神经网络 Tensorflow 下载次数 2827