发布于2019-08-22 19:34 阅读(2139) 评论(0) 点赞(30) 收藏(3)
本文属于课程笔记,源自曹健老师的”人工智能实践:Tensorflow笔记”(侵删):https://www.icourse163.org/learn/PKU-1002536002#/learn/announce
tfrecords是一种二进制文件格式,理论上它可以保存任何格式的信息,可将图片和标签制作该改格式文件,使用tfrecords进行存储,可提高内存利用率。
其中:tf.train.Example用来存储训练数据 。 训练数据的特征用键值对的形式表示。
如“ img_raw ” :值, ”label ”: 值 。值的参数分别是 BytesList/FloatList/Int64List,别对应于取值为二进制数,浮点数,整数特征。SerializeToString( ) 把数据序列化成字符串存储。
下列代码为制作,获取,使用tfrecords格式数据集,其中输入图像为28*28的手写数字图像,输出为0-9的数字标签
#coding:utf-8
import tensorflow as tf
import numpy as np
from PIL import Image
import os
image_train_path='./mnist_data_jpg/mnist_train_jpg_60000/'
label_train_path='./mnist_data_jpg/mnist_train_jpg_60000.txt'
tfRecord_train='./data/mnist_train.tfrecords'
image_test_path='./mnist_data_jpg/mnist_test_jpg_10000/'
label_test_path='./mnist_data_jpg/mnist_test_jpg_10000.txt'
tfRecord_test='./data/mnist_test.tfrecords'
data_path='./data'
resize_height = 28
resize_width = 28
def write_tfRecord(tfRecordName, image_path, label_path):
#接收路径/文件名,图像路径,标签路径
#创建一个新的writer(实例化)
writer = tf.python_io.TFRecordWriter(tfRecordName)
#图片数量以显示进度
num_pic = 0
#打开标签文件(txt文件,格式:图片名(空格)标签),读取内容
f = open(label_path, 'r')
contents = f.readlines()
f.close()
for content in contents:
#分隔每行内容
value = content.split()
#图片路径:图片路径+图片名
img_path = image_path + value[0]
img = Image.open(img_path)
#转换为二进制数据
img_raw = img.tobytes()
#初始化,并将标签位赋值为1
labels = [0] * 10
labels[int(value[1])] = 1
# 把每张图片和标签封装到example中 (img_raw与labels)
example = tf.train.Example(features=tf.train.Features(feature={
'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])),
'label': tf.train.Feature(int64_list=tf.train.Int64List(value=labels))
}))
#将example序列化
writer.write(example.SerializeToString())
num_pic += 1
print ("the number of picture:", num_pic)
writer.close()
print("write tfrecord successful")
def generate_tfRecord():
#判断保存路径是否存在,不存在则创建路径,存在则打印已存在
isExists = os.path.exists(data_path)
if not isExists:
os.makedirs(data_path)
print 'The directory was created successfully'
else:
print 'directory already exists'
#调用write_tfRecord将训练集和验证集的图片和标签写成tfrecords文件。
write_tfRecord(tfRecord_train, image_train_path, label_train_path)
write_tfRecord(tfRecord_test, image_test_path, label_test_path)
def read_tfRecord(tfRecord_path):
#新建文件名队列,告知文件名队列包括那些文件
#tf.train.string_input_producer(string_tensor,num_epochs=None,shuffle=True,seed=None,capacity=32,#shared_name=None,name=None,cancel_op=None)
#该函数会生成一个先入先出的队列,文件阅读器会使用它来读取数据。
#string_tensor: 存储图像和标签信息的 TFRecord 文件名列表,
#num_epochs: 循环读取的轮数(可选),shuffle :布尔值(可选),如果为 True ,则在每轮随机打乱读取顺序,
#seed; 随机读取时设置的种子(可选),capacity :队列容量
#shared_name :(可选 如果设置,该队列将在多个会话中以给定名称共享。所有具有此队列的设备都可以通过 shared_name 访问它。在分布式设置中使用这种方法意味着每个名称只能被访问此操作的其中一个会话看到。
#name :操作的名称(可选),cancel_op :取消队列 None
filename_queue = tf.train.string_input_producer([tfRecord_path], shuffle=True)
reader = tf.TFRecordReader()
#将读出的每个样本保存到serialized_example中,进行解序列化
_, serialized_example = reader.read(filename_queue)
#标签要给出实际分类数
features = tf.parse_single_example(serialized_example,
features={
'label': tf.FixedLenFeature([10], tf.int64),
'img_raw': tf.FixedLenFeature([], tf.string)
})
img = tf.decode_raw(features['img_raw'], tf.uint8)
img.set_shape([784])
img = tf.cast(img, tf.float32) * (1. / 255)
label = tf.cast(features['label'], tf.float32)
return img, label
def get_tfrecord(num, isTrain=True):
#获取tfrecords文件,num每次获取的数据量,isTrain--训练集--True,测试集--False
if isTrain:
tfRecord_path = tfRecord_train
else:
tfRecord_path = tfRecord_test
img, label = read_tfRecord(tfRecord_path)
#这个函数随机读取一个batch的数据 。
#从总样本中顺序取出capacity组数据打乱顺序,每次输出batch_size组
#如果少于min_after_dequeue,会从总样中取数据填满capacity,共使用2个线程
img_batch, label_batch = tf.train.shuffle_batch([img, label],
batch_size = num,
num_threads = 2,
capacity = 1000,
min_after_dequeue = 700)
return img_batch, label_batch
def main():
generate_tfRecord()
if __name__ == '__main__':
main()
除此之外,还可以在方向传播过程(文件)中利用多线程提高图片和标签的批获取效率。
方法:将批获取的操作放到线程协调器开启和关闭之间
开启线程协调器:
coord = tf.train.Coordinator( )tf.train.Coordinator( )
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
关闭线程协调器:
coord.request_stop( )
coord.join(threads)
作者:44344df
链接:https://www.pythonheidong.com/blog/article/53438/a4095bdf797a9f34da9f/
来源:python黑洞网
任何形式的转载都请注明出处,如有侵权 一经发现 必将追究其法律责任
昵称:
评论内容:(最多支持255个字符)
---无人问津也好,技不如人也罢,你都要试着安静下来,去做自己该做的事,而不是让内心的烦躁、焦虑,坏掉你本来就不多的热情和定力
Copyright © 2018-2021 python黑洞网 All Rights Reserved 版权所有,并保留所有权利。 京ICP备18063182号-1
投诉与举报,广告合作请联系vgs_info@163.com或QQ3083709327
免责声明:网站文章均由用户上传,仅供读者学习交流使用,禁止用做商业用途。若文章涉及色情,反动,侵权等违法信息,请向我们举报,一经核实我们会立即删除!