UP | HOME

Tensorflow inception-v3 迁移学习

Table of Contents

1 下载google图像识别网络inception-v3并查看结构

import os
import tarfile

import requests
import tensorflow as tf

# 下载解压并获取.pb文件
# get URL of tgz file
inception_pretrain_model_url = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'

# get local dir path to store tgz file
inception_pretrain_model_dir = "inception_model"
if not os.path.exists(inception_pretrain_model_dir):
    os.makedirs(inception_pretrain_model_dir)

# get tgz file name from URL
filename = inception_pretrain_model_url.split('/')[-1]

# make the absolute path of tgz file
filepath = os.path.join(inception_pretrain_model_dir, filename)

# downloading tgz file as certain and in certain absolute path as we defined
if not os.path.exists(filepath):
    print("download: ", filename)
    r = requests.get(inception_pretrain_model_url, stream=True)
    with open(filepath, 'wb') as f:
        for chunk in r.iter_content(chunk_size=1024):
            if chunk:
                f.write(chunk)

print("finish: ", filename)

# extract zip or tar file
tarfile.open(filepath, 'r:gz').extractall(inception_pretrain_model_dir)

# dir prepared for summary, after loading .pb
log_dir = 'inception_log'
if not os.path.exists(log_dir):
    os.makedirs(log_dir)

# get path of .pb file
inception_graph_def_file = os.path.join(inception_pretrain_model_dir,
                                        'classify_image_graph_def.pb')

# 把.pb模型加载进当前会话
with tf.Session() as sess:
    with tf.gfile.FastGFile(inception_graph_def_file, 'rb') as f:  # 打開.pb模型的文件
        graph_def = tf.GraphDef()                                  # 獲取圖定義對象
        graph_def.ParseFromString(f.read())                        # 圖定義對象從FastGFile文件中讀取定義
        tf.import_graph_def(graph_def, name='')                    # 在當前 session 下引入圖定義
    # save the structure of graph
    writer = tf.summary.FileWriter(log_dir, sess.graph)            # 总结當前圖定義,用于可视化
    writer.close()

注意,這裏探討的是如何下載已經訓練好的模型 .pb, 以及如何通過tensorboard對其進行可視化. 想要可視化就一定需要 summary, 然後 tensorboard 讀取 summary. 關鍵的問題是我之前一直以爲必須保存點什麼才能可視化, 但其實這條語句已經保存了整張圖了,即使你 summary 一些變量的值, 這個graph 也會直接保存在 summary file 中, 並且被 tensorboard 加載


download  =====> inception.tgz =====> extractall =====> classify_image_graph_def.pb
                                                                   |
                                                                   |
                                                                   |
                                                                   v
                                                         tf.gfile.FastGFile open and read .pb
                                                                   |
                                                                   |
                                                                   |
                                                                   v
                                                         GraphDef parse the content after read
                                                                   |
                                                                   |
                                                                   |
                                                                   v
                                                         tf.import(GraphDef)
                                                                   |
                                                                   |
                                                                   |
                                                                   v
                                                summary into summary file under log_dir
                                                                   |
                                                                   |
                                                                   |
                                                                   v
                                                             Tensorboard

上面的程序执行完, 就说明图对象已经被 summary 到本地文件. 运行下面的命令即可.

tensorboard --logdir=/home/yiddi/git_repos/on_ml_tensorflow/inception_log