UP | HOME

TFrecord 使用技术

Table of Contents

1 如何生成 TFrecord

这里仍然给使用已经 pretrained 的模型来训练我们自己的任务, 也就是我们只训练最后一层.

google tensorflow 提供的 slim 可以用一种更加强大且自由的方式去定义自己的任务, 你需要编写下载数据集的代码, 转换为tfrecord的代码, 注册入 dataset 的代码, 以及 tfrecord读入内存的代码. 但幸运的是这些代码的模式都很固定, 且其文件夹下有实例可以参考.

…../models/research/slim/

我们主要使用, slim 文件夹下的 train_image_classifier.py,

  1. 首先做图片预处理, 把图片生成 .tfrecord 文件, 该文件类型底层使用 protobuffer — google 提供的 二进制 文件存储方式, 传输和运算效率非常高. 在进行模型训练的时候使用 .tfrecord 作为数据输入格式.
import math
import os
import random
import sys

import tensorflow as tf

# 验证集数量
_NUM_TEST = 500
# random seed
_RANDOM_SEED = 0
# 数据块数量
_NUM_SHARDS = 5
# 数据集路径
DATASET_DIR = "lec_8_2_data/train/"
# 生成的标签文件. 注意这里'生成'的意思, 数据图片是使用各自所在文件夹作为自己的
# 标签, '生成'的意思是把文件夹名字映射为数字.
LABELS_FILENAME = "lec8_2_produced_labels/labels.txt"


#定义tfrecord文件的路径+名字
def _get_dataset_filename(dataset_dir, split_name, shard_id):
    output_filename = 'image_%s_%05d-of-%05d.tfrecord' % (split_name, shard_id,
                                                          _NUM_SHARDS)
    return os.path.join(dataset_dir, output_filename)


def int64_feature(values):
    if not isinstance(values, (tuple, list)):
        values = [values]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=values))


def bytes_feature(values):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values]))


def image_to_tfexample(image_data, image_format, class_id):
    #abstract base class for protocol message.
    return tf.train.Example(
        features=tf.train.Features(
            feature={
                #可自己定义      如果是string/image => bytes_feature
                #------------- : ------------------------
                'image/encoded': bytes_feature(image_data),
                'iamge/format': bytes_feature(image_format),
                'image/class/label': int64_feature(class_id),
            }))


# 把数据转为 tfrecord 格式
def _convert_dataset(split_name, filenames, class_names_to_ids, dataset_dir):
    assert split_name in ['train', 'test']
    #计算每个数据块有多少数据
    num_per_shard = int(len(filenames) / _NUM_SHARDS)
    with tf.Graph().as_default():
        with tf.Session() as sess:
            for shard_id in range(_NUM_SHARDS):
                #定义tfrecord文件的路径+名字
                output_filename = _get_dataset_filename(
                    dataset_dir, split_name, shard_id)
                with tf.python_io.TFRecordWriter(
                        output_filename) as tfrecord_writer:  # 固定套路
                    #每一个数据块的开始位置
                    start_ndx = shard_id * num_per_shard
                    #每一个数据块的最后位置
                    end_ndx = min((shard_id + 1) * num_per_shard,
                                  len(filenames))
                    for i in range(start_ndx, end_ndx):
                        try:  #如果遇到损坏的图片文件, 则直接跳过不做处理
                            sys.stdout.write(
                                '\r>> Convert image %d/%d shard %d' %
                                (i + 1, len(filenames), shard_id))
                            sys.stdout.flush()
                            #读取图片
                            image_data = tf.gfile.FastGFile(
                                filenames[i], 'rb').read()
                            #获得图片的类别名称
                            class_name = os.path.basename(
                                os.path.dirname(filenames[i]))
                            #找到类别名称对应的id
                            class_id = class_names_to_ids[class_name]
                            #生成tfrecord文件
                            example = image_to_tfexample(
                                image_data, b'jpg', class_id)
                            tfrecord_writer.write(example.SerializeToString())
                        except IOError as e:
                            print('Could not read:', filenames[i])
                            print('Error:', e)
                            print('Skip it\n')

    sys.stdout.write('\n')
    sys.stdout.flush()


# 判断tfrecord文件是否存在
def _dataset_exists(dataset_dir):
    for split_name in ['train', 'test']:
        for shard_id in range(_NUM_SHARDS):
            #定义tfrecord文件的路径+名字
            output_filename = _get_dataset_filename(dataset_dir, split_name,
                                                    shard_id)
        if not tf.gfile.Exists(output_filename):
            return False
    return True


def write_label_file(labels_to_class_names,
                     dataset_dir,
                     filename=LABELS_FILENAME):
    labels_filename = os.path.join(dataset_dir, filename)
    with tf.gfile.Open(labels_filename, 'w') as f:
        for label in labels_to_class_names:
            class_name = labels_to_class_names[label]
            f.writer('%d:%s\n' % (label, class_name))


#获取所有文件以及分类
def _get_dataset_filenames_and_classes(dataset_dir):
    #数据目录
    directories = []
    #分类名称
    class_names = []
    for filename in os.listdir(dataset_dir):
        #合并文件路径
        path = os.path.join(dataset_dir, filename)
        #判断该路径是否为目录
        if os.path.isdir(path):
            #加入数据目录
            directories.append(path)
            #加入类别名称, 文件夹名就是类型名
            class_names.append(filename)

    photo_filenames = []
    #循环每个分类的文件夹
    for directory in directories:
        for filename in os.listdir(directory):
            path = os.path.join(directory, filename)
            #把图片加入图片列表
            photo_filenames.append(path)

    return photo_filenames, class_names


if __name__ == '__main__':
    # 判断tfrecord文件是否存在, 如果存在就不用预处理数据集图片, 直接跳过预处理
    # 阶段.
    if _dataset_exists(DATASET_DIR):
        print('tfrecord文件已存在')
    else:
        #获得所有图片及分类
        photo_filenames, class_names = _get_dataset_filenames_and_classes(
            DATASET_DIR)
        #把分类转为字典格式, 类似于{'house':0, 'flower':1, 'plane':2}
        class_names_to_ids = dict(zip(class_names, range(len(class_names))))

        #把数据切分为训练集和测试集
        random.seed(_RANDOM_SEED)
        random.shuffle(photo_filenames)  # shuffle 会把list中的数据打乱
        training_filenames = photo_filenames[_NUM_TEST:]
        testing_filenames = photo_filenames[:_NUM_TEST]

        #数据转换
        _convert_dataset('train', training_filenames, class_names_to_ids,
                         DATASET_DIR)

        _convert_dataset('test', testing_filenames, class_names_to_ids,
                         DATASET_DIR)

如果数据集比较小, 只需要存放到一个 tfrecord 即可, 但是当你数据量较大,比如500个G, 这时候可以做数据集切分.

数据块开始与最后位置的示意图

每个点一个图片, 当我们切分的时候, 每一块 shard 的开始位置就是 shard_id * num_per_shard,
这个公式对所有 shard 都有效, 但是每一块 shard 的末尾位置就需要考虑整个数据集的图片数量:

min((shard_id+1) * num_per_shard, len(filename))

                                                          /--- 1200
..........................................................

+---------+----------+---------+----------+---------+----------+
|     0   |     1    |    2    |    3     |    4    |     5    |
+---------+----------+---------+----------+---------+----------+
 \  300  /                                                      \-- 1500



注意你读取文件的方式, 好几次错误都处在这里了

# RIGHT
image_data = tf.gfile.FastGFile(filenames[i], 'rb').read()
# WRONG
image_data = tf.gfile.FastGFile(filenames[i], 'r').read()

def bytes_feature(values):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values]))

上面注意, 必须以 binary 方式读入图片, 而不能以 string 方式.

                       think the source file read from
  .................... is a utf-8 encoding file, in this scenario it's JPEG not utf-8 format, so program down with error:
  .                  . 'utf-8' codec can't decode byte 0xff in position 0: invalid start byte.
  .                  .
  .                  .
  v                  .              image_data
         read(图片, 'r')
 图片  ===========================> string    -------------->----------------------------+ RIGHT !!!
              WRONG!!!              <utf-8>                                              |

                                                                   tf.train.BytesList( image_data )
                                    image_data
         read(图片, 'rb')                                                                |
 图片  ===========================> bytes     -------------->----------------------------+ RIGHT !!!
              RIGHT!!!              <binary>


这里错不在第二步, 错在第一步不能以 'r' 模式读取图片文件, 因为 'r' 模式隐含的意思是 'r-utf8', 必须使用
'rb' 模式读取图片文件.

tf.train.BytesList(xxx) 可以读取 bytes 文件, 也可以读取 string 文件

2 使用 tfrecord 和 slim 来处理的任务

官方 slim 所在位置及文件夹组成

/home/yiddi/wellknown_proj_sourcecode/models/research/slim:

-rw-r--r--  1 yiddi yiddi  14K 7月  31 18:33 BUILD
drwxr-xr-x  2 yiddi yiddi 4.0K 7月  31 18:33 >>datasets<<
drwxr-xr-x  2 yiddi yiddi 4.0K 7月  31 18:33 >>deployment<<
-rw-r--r--  1 yiddi yiddi 2.3K 7月  31 18:33 download_and_convert_data.py
-rw-r--r--  1 yiddi yiddi 6.6K 7月  31 18:33 eval_image_classifier.py
-rw-r--r--  1 yiddi yiddi 4.6K 7月  31 18:33 export_inference_graph.py
-rw-r--r--  1 yiddi yiddi 1.4K 7月  31 18:33 export_inference_graph_test.py
-rw-r--r--  1 yiddi yiddi    0 7月  31 18:33 __init__.py
drwxr-xr-x  4 yiddi yiddi 4.0K 7月  31 18:33 >>nets<<
drwxr-xr-x  2 yiddi yiddi 4.0K 7月  31 18:33 >>preprocessing<<
-rw-r--r--  1 yiddi yiddi  26K 7月  31 18:33 README.md
drwxr-xr-x  2 yiddi yiddi 4.0K 7月  31 18:33 >>scripts<<
-rw-r--r--  1 yiddi yiddi  916 7月  31 18:33 setup.py
-rw-r--r--  1 yiddi yiddi  46K 7月  31 18:33 slim_walkthrough.ipynb
-rw-r--r--  1 yiddi yiddi  21K 7月  31 18:33 -> train_image_classifier.py <-
-rw-r--r--  1 yiddi yiddi    0 7月  31 18:33 WORKSPACE

其中被 >><< wrap 的都是文件夹, -> <- wrap 的就是 slim 代码的入口, 需要调用这个 python 文件来运行 slim

2.1 datasets 文件夹介绍

datasets 里面存放了下载 dataset 的 python 代码, 其中一个叫做 dataset_factory.py 是一个重要文件, 他是 train_image_classifier.py 获取dataset的入口, 如果你想使用自己的 dataset 做训练, 就必须要在这个 dataset_factory.py 中注册自己的数据集.

2.1.1 dataset_factory.py 介绍

dataset_factory.py, 需要做的修改如下
====================================

from datasets import cifar10, flowers, imagenet, mnist, myimages
                                                        --------
                                                         #^
datasets_map = {                                         #|
    'cifar10': cifar10,                                  #|
    'flowers': flowers,                                  #|
    'imagenet': imagenet,                                #这个自己加的
    'mnist': mnist,
    'myimages': myimages, #<- 这一行就是我们自己加的
    --------------------
}


2.1.2 download_xxx.py xxx.py 介绍

除了 dataset_factory.py 其他文件都是两两成对的:

  • 下载数据集转换成 tfrecord
  • 将 tfrecord 读入内存
/home/yiddi/wellknown_proj_sourcecode/models/research/slim/datasets:

download_and_convert_cifar10.py    - 下载数据集转换成 tfrecord
cifar10.py                         - 将 tfrecord 读入内存

download_and_convert_flowers.py    - 下载数据集转换成 tfrecord
flowers.py                         - 将 tfrecord 读入内存

download_and_convert_imagenet.sh   - 下载数据集转换成 tfrecord
imagenet.py                        - 将 tfrecord 读入内存

download_and_convert_mnist.py      - 下载数据集转换成 tfrecord
mnist.py                           - 将 tfrecord 读入内存

download_imagenet.sh               - 下载数据集转换成 tfrecord
build_imagenet_data.py             - 将 tfrecord 读入内存

imagenet_2012_validation_synset_labels.txt
imagenet_lsvrc_2015_synsets.txt
imagenet_metadata.txt

__init__.py
preprocess_imagenet_validation_data.py
process_bounding_boxes.py
dataset_factory.py
dataset_utils.py

download_xxxxx.py : 声明一些重要参数, data_url, 等

其中 myimages 就是我们需要参考其他下载数据集源代码的源文件结构自己写的 myimages.py 的名字
基本都要提供如下参数值:


# The URL where the CIFAR data can be downloaded.
_DATA_URL = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'

# The number of training files.
_NUM_TRAIN_FILES = 5

# The height and width of each image.
_IMAGE_SIZE = 32

# The names of the classes.
_CLASS_NAMES = [
    'airplane',
    'automobile',
    'bird',
    'cat',
    'deer',
    'dog',
    'frog',
    'horse',
    'ship',
    'truck',
]

最终每个下载程序都会在指定文件夹下生成 tfrecord 文件.

download_xxx.py : dataset ===> tfrecord, 代码实例

features=tf.train.Features(
    feature={
        #可自己定义      如果是string/image => bytes_feature
        #------------- : ------------------------
        'image/encoded': bytes_feature(image_data),
        'iamge/format': bytes_feature(image_format),
        'image/class/label': int64_feature(class_id),
    }))

xxx.py : tfrecord ===> 内存, 代码实例

keys_to_features = {
    'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
    'image/format': tf.FixedLenFeature((), tf.string, default_value='png'),
    'image/class/label': tf.FixedLenFeature(
        [], tf.int64, default_value=tf.zeros([], dtype=tf.int64)),
}

xxx.py : 返回 dataset

return slim.dataset.Dataset(
    data_sources=file_pattern,
    reader=reader,
    decoder=decoder,
    num_samples=SPLITS_TO_SIZES[split_name],
    items_to_descriptions=_ITEMS_TO_DESCRIPTIONS,
    num_classes=_NUM_CLASSES,
    labels_to_names=labels_to_names)

2.2 编写 bash 文件执行 slim 程序

(require 'ob-async)
#!/bin/zsh
python /home/yiddi/wellknown_proj_sourcecode/models/research/slim/train_image_classifier.py \
       --train_dir= \  # 模型保存的位置
       --dataset_name= \  # 我们在dataset/中编写的用于datasset-tfrecord->内存的.py文件
       --dataset_split_name= \ # train or test
       --dataset_dir= \ # 图片存放的位置
       --batch_size= \  # 如果GPU显存不够, 这里应该设小, 默认设置为32
       --max_number_of_steps= \ # 默认一直训练, 可以声明最大循环次数 epoch
       --model_name=inception_v3 \ # 使用哪个模型

以上这些参数,都是以

  • tf.app.flags.DEFINE_string
  • tf.app.flags.DEFINE_integer
  • tf.app.flags.DEFINE_boolean

三种形式声明在 train_image_classifier.py 文件中, 可以在里面查看详细说明和使用守则.