TFrecord 使用技术
Table of Contents
1 如何生成 TFrecord
这里仍然给使用已经 pretrained 的模型来训练我们自己的任务, 也就是我们只训练最后一层.
google tensorflow 提供的 slim 可以用一种更加强大且自由的方式去定义自己的任务, 你需要编写下载数据集的代码, 转换为tfrecord的代码, 注册入 dataset 的代码, 以及 tfrecord读入内存的代码. 但幸运的是这些代码的模式都很固定, 且其文件夹下有实例可以参考.
…../models/research/slim/
我们主要使用, slim 文件夹下的 train_image_classifier.py
,
- 首先做图片预处理, 把图片生成 .tfrecord 文件, 该文件类型底层使用 protobuffer — google 提供的 二进制 文件存储方式, 传输和运算效率非常高. 在进行模型训练的时候使用 .tfrecord 作为数据输入格式.
import math import os import random import sys import tensorflow as tf # 验证集数量 _NUM_TEST = 500 # random seed _RANDOM_SEED = 0 # 数据块数量 _NUM_SHARDS = 5 # 数据集路径 DATASET_DIR = "lec_8_2_data/train/" # 生成的标签文件. 注意这里'生成'的意思, 数据图片是使用各自所在文件夹作为自己的 # 标签, '生成'的意思是把文件夹名字映射为数字. LABELS_FILENAME = "lec8_2_produced_labels/labels.txt" #定义tfrecord文件的路径+名字 def _get_dataset_filename(dataset_dir, split_name, shard_id): output_filename = 'image_%s_%05d-of-%05d.tfrecord' % (split_name, shard_id, _NUM_SHARDS) return os.path.join(dataset_dir, output_filename) def int64_feature(values): if not isinstance(values, (tuple, list)): values = [values] return tf.train.Feature(int64_list=tf.train.Int64List(value=values)) def bytes_feature(values): return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values])) def image_to_tfexample(image_data, image_format, class_id): #abstract base class for protocol message. return tf.train.Example( features=tf.train.Features( feature={ #可自己定义 如果是string/image => bytes_feature #------------- : ------------------------ 'image/encoded': bytes_feature(image_data), 'iamge/format': bytes_feature(image_format), 'image/class/label': int64_feature(class_id), })) # 把数据转为 tfrecord 格式 def _convert_dataset(split_name, filenames, class_names_to_ids, dataset_dir): assert split_name in ['train', 'test'] #计算每个数据块有多少数据 num_per_shard = int(len(filenames) / _NUM_SHARDS) with tf.Graph().as_default(): with tf.Session() as sess: for shard_id in range(_NUM_SHARDS): #定义tfrecord文件的路径+名字 output_filename = _get_dataset_filename( dataset_dir, split_name, shard_id) with tf.python_io.TFRecordWriter( output_filename) as tfrecord_writer: # 固定套路 #每一个数据块的开始位置 start_ndx = shard_id * num_per_shard #每一个数据块的最后位置 end_ndx = min((shard_id + 1) * num_per_shard, len(filenames)) for i in range(start_ndx, end_ndx): try: #如果遇到损坏的图片文件, 则直接跳过不做处理 sys.stdout.write( '\r>> Convert image %d/%d shard %d' % (i + 1, len(filenames), shard_id)) sys.stdout.flush() #读取图片 image_data = tf.gfile.FastGFile( filenames[i], 'rb').read() #获得图片的类别名称 class_name = os.path.basename( os.path.dirname(filenames[i])) #找到类别名称对应的id class_id = class_names_to_ids[class_name] #生成tfrecord文件 example = image_to_tfexample( image_data, b'jpg', class_id) tfrecord_writer.write(example.SerializeToString()) except IOError as e: print('Could not read:', filenames[i]) print('Error:', e) print('Skip it\n') sys.stdout.write('\n') sys.stdout.flush() # 判断tfrecord文件是否存在 def _dataset_exists(dataset_dir): for split_name in ['train', 'test']: for shard_id in range(_NUM_SHARDS): #定义tfrecord文件的路径+名字 output_filename = _get_dataset_filename(dataset_dir, split_name, shard_id) if not tf.gfile.Exists(output_filename): return False return True def write_label_file(labels_to_class_names, dataset_dir, filename=LABELS_FILENAME): labels_filename = os.path.join(dataset_dir, filename) with tf.gfile.Open(labels_filename, 'w') as f: for label in labels_to_class_names: class_name = labels_to_class_names[label] f.writer('%d:%s\n' % (label, class_name)) #获取所有文件以及分类 def _get_dataset_filenames_and_classes(dataset_dir): #数据目录 directories = [] #分类名称 class_names = [] for filename in os.listdir(dataset_dir): #合并文件路径 path = os.path.join(dataset_dir, filename) #判断该路径是否为目录 if os.path.isdir(path): #加入数据目录 directories.append(path) #加入类别名称, 文件夹名就是类型名 class_names.append(filename) photo_filenames = [] #循环每个分类的文件夹 for directory in directories: for filename in os.listdir(directory): path = os.path.join(directory, filename) #把图片加入图片列表 photo_filenames.append(path) return photo_filenames, class_names if __name__ == '__main__': # 判断tfrecord文件是否存在, 如果存在就不用预处理数据集图片, 直接跳过预处理 # 阶段. if _dataset_exists(DATASET_DIR): print('tfrecord文件已存在') else: #获得所有图片及分类 photo_filenames, class_names = _get_dataset_filenames_and_classes( DATASET_DIR) #把分类转为字典格式, 类似于{'house':0, 'flower':1, 'plane':2} class_names_to_ids = dict(zip(class_names, range(len(class_names)))) #把数据切分为训练集和测试集 random.seed(_RANDOM_SEED) random.shuffle(photo_filenames) # shuffle 会把list中的数据打乱 training_filenames = photo_filenames[_NUM_TEST:] testing_filenames = photo_filenames[:_NUM_TEST] #数据转换 _convert_dataset('train', training_filenames, class_names_to_ids, DATASET_DIR) _convert_dataset('test', testing_filenames, class_names_to_ids, DATASET_DIR)
如果数据集比较小, 只需要存放到一个 tfrecord 即可, 但是当你数据量较大,比如500个G, 这时候可以做数据集切分.
数据块开始与最后位置的示意图
每个点一个图片, 当我们切分的时候, 每一块 shard 的开始位置就是 shard_id * num_per_shard, 这个公式对所有 shard 都有效, 但是每一块 shard 的末尾位置就需要考虑整个数据集的图片数量: min((shard_id+1) * num_per_shard, len(filename)) /--- 1200 .......................................................... +---------+----------+---------+----------+---------+----------+ | 0 | 1 | 2 | 3 | 4 | 5 | +---------+----------+---------+----------+---------+----------+ \ 300 / \-- 1500
注意你读取文件的方式, 好几次错误都处在这里了
# RIGHT image_data = tf.gfile.FastGFile(filenames[i], 'rb').read() # WRONG image_data = tf.gfile.FastGFile(filenames[i], 'r').read() def bytes_feature(values): return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values]))
上面注意, 必须以 binary 方式读入图片, 而不能以 string 方式.
think the source file read from .................... is a utf-8 encoding file, in this scenario it's JPEG not utf-8 format, so program down with error: . . 'utf-8' codec can't decode byte 0xff in position 0: invalid start byte. . . . . v . image_data read(图片, 'r') 图片 ===========================> string -------------->----------------------------+ RIGHT !!! WRONG!!! <utf-8> | tf.train.BytesList( image_data ) image_data read(图片, 'rb') | 图片 ===========================> bytes -------------->----------------------------+ RIGHT !!! RIGHT!!! <binary> 这里错不在第二步, 错在第一步不能以 'r' 模式读取图片文件, 因为 'r' 模式隐含的意思是 'r-utf8', 必须使用 'rb' 模式读取图片文件. tf.train.BytesList(xxx) 可以读取 bytes 文件, 也可以读取 string 文件
2 使用 tfrecord 和 slim 来处理的任务
官方 slim 所在位置及文件夹组成
/home/yiddi/wellknown_proj_sourcecode/models/research/slim: -rw-r--r-- 1 yiddi yiddi 14K 7月 31 18:33 BUILD drwxr-xr-x 2 yiddi yiddi 4.0K 7月 31 18:33 >>datasets<< drwxr-xr-x 2 yiddi yiddi 4.0K 7月 31 18:33 >>deployment<< -rw-r--r-- 1 yiddi yiddi 2.3K 7月 31 18:33 download_and_convert_data.py -rw-r--r-- 1 yiddi yiddi 6.6K 7月 31 18:33 eval_image_classifier.py -rw-r--r-- 1 yiddi yiddi 4.6K 7月 31 18:33 export_inference_graph.py -rw-r--r-- 1 yiddi yiddi 1.4K 7月 31 18:33 export_inference_graph_test.py -rw-r--r-- 1 yiddi yiddi 0 7月 31 18:33 __init__.py drwxr-xr-x 4 yiddi yiddi 4.0K 7月 31 18:33 >>nets<< drwxr-xr-x 2 yiddi yiddi 4.0K 7月 31 18:33 >>preprocessing<< -rw-r--r-- 1 yiddi yiddi 26K 7月 31 18:33 README.md drwxr-xr-x 2 yiddi yiddi 4.0K 7月 31 18:33 >>scripts<< -rw-r--r-- 1 yiddi yiddi 916 7月 31 18:33 setup.py -rw-r--r-- 1 yiddi yiddi 46K 7月 31 18:33 slim_walkthrough.ipynb -rw-r--r-- 1 yiddi yiddi 21K 7月 31 18:33 -> train_image_classifier.py <- -rw-r--r-- 1 yiddi yiddi 0 7月 31 18:33 WORKSPACE
其中被 >><< wrap 的都是文件夹, -> <- wrap 的就是 slim 代码的入口, 需要调用这个 python 文件来运行 slim
2.1 datasets 文件夹介绍
datasets 里面存放了下载 dataset 的 python 代码, 其中一个叫做 dataset_factory.py 是一个重要文件, 他是 train_image_classifier.py 获取dataset的入口, 如果你想使用自己的 dataset 做训练, 就必须要在这个 dataset_factory.py 中注册自己的数据集.
2.1.1 dataset_factory.py 介绍
dataset_factory.py, 需要做的修改如下 ==================================== from datasets import cifar10, flowers, imagenet, mnist, myimages -------- #^ datasets_map = { #| 'cifar10': cifar10, #| 'flowers': flowers, #| 'imagenet': imagenet, #这个自己加的 'mnist': mnist, 'myimages': myimages, #<- 这一行就是我们自己加的 -------------------- }
2.1.2 download_xxx.py xxx.py 介绍
除了 dataset_factory.py 其他文件都是两两成对的:
- 下载数据集转换成 tfrecord
- 将 tfrecord 读入内存
/home/yiddi/wellknown_proj_sourcecode/models/research/slim/datasets: download_and_convert_cifar10.py - 下载数据集转换成 tfrecord cifar10.py - 将 tfrecord 读入内存 download_and_convert_flowers.py - 下载数据集转换成 tfrecord flowers.py - 将 tfrecord 读入内存 download_and_convert_imagenet.sh - 下载数据集转换成 tfrecord imagenet.py - 将 tfrecord 读入内存 download_and_convert_mnist.py - 下载数据集转换成 tfrecord mnist.py - 将 tfrecord 读入内存 download_imagenet.sh - 下载数据集转换成 tfrecord build_imagenet_data.py - 将 tfrecord 读入内存 imagenet_2012_validation_synset_labels.txt imagenet_lsvrc_2015_synsets.txt imagenet_metadata.txt __init__.py preprocess_imagenet_validation_data.py process_bounding_boxes.py dataset_factory.py dataset_utils.py
download_xxxxx.py
: 声明一些重要参数, data_url, 等
其中 myimages 就是我们需要参考其他下载数据集源代码的源文件结构自己写的 myimages.py 的名字 基本都要提供如下参数值: # The URL where the CIFAR data can be downloaded. _DATA_URL = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz' # The number of training files. _NUM_TRAIN_FILES = 5 # The height and width of each image. _IMAGE_SIZE = 32 # The names of the classes. _CLASS_NAMES = [ 'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck', ] 最终每个下载程序都会在指定文件夹下生成 tfrecord 文件.
download_xxx.py
: dataset ===> tfrecord, 代码实例
features=tf.train.Features( feature={ #可自己定义 如果是string/image => bytes_feature #------------- : ------------------------ 'image/encoded': bytes_feature(image_data), 'iamge/format': bytes_feature(image_format), 'image/class/label': int64_feature(class_id), }))
xxx.py
: tfrecord ===> 内存, 代码实例
keys_to_features = { 'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''), 'image/format': tf.FixedLenFeature((), tf.string, default_value='png'), 'image/class/label': tf.FixedLenFeature( [], tf.int64, default_value=tf.zeros([], dtype=tf.int64)), }
xxx.py
: 返回 dataset
return slim.dataset.Dataset(
data_sources=file_pattern,
reader=reader,
decoder=decoder,
num_samples=SPLITS_TO_SIZES[split_name],
items_to_descriptions=_ITEMS_TO_DESCRIPTIONS,
num_classes=_NUM_CLASSES,
labels_to_names=labels_to_names)
2.2 编写 bash 文件执行 slim 程序
(require 'ob-async)
#!/bin/zsh python /home/yiddi/wellknown_proj_sourcecode/models/research/slim/train_image_classifier.py \ --train_dir= \ # 模型保存的位置 --dataset_name= \ # 我们在dataset/中编写的用于datasset-tfrecord->内存的.py文件 --dataset_split_name= \ # train or test --dataset_dir= \ # 图片存放的位置 --batch_size= \ # 如果GPU显存不够, 这里应该设小, 默认设置为32 --max_number_of_steps= \ # 默认一直训练, 可以声明最大循环次数 epoch --model_name=inception_v3 \ # 使用哪个模型
以上这些参数,都是以
tf.app.flags.DEFINE_string
tf.app.flags.DEFINE_integer
tf.app.flags.DEFINE_boolean
三种形式声明在 train_image_classifier.py
文件中, 可以在里面查看详细说明和使用守则.