Tensorflow inception-v3 迁移学习
Table of Contents
1 下载google图像识别网络inception-v3并查看结构
import os import tarfile import requests import tensorflow as tf # 下载解压并获取.pb文件 # get URL of tgz file inception_pretrain_model_url = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz' # get local dir path to store tgz file inception_pretrain_model_dir = "inception_model" if not os.path.exists(inception_pretrain_model_dir): os.makedirs(inception_pretrain_model_dir) # get tgz file name from URL filename = inception_pretrain_model_url.split('/')[-1] # make the absolute path of tgz file filepath = os.path.join(inception_pretrain_model_dir, filename) # downloading tgz file as certain and in certain absolute path as we defined if not os.path.exists(filepath): print("download: ", filename) r = requests.get(inception_pretrain_model_url, stream=True) with open(filepath, 'wb') as f: for chunk in r.iter_content(chunk_size=1024): if chunk: f.write(chunk) print("finish: ", filename) # extract zip or tar file tarfile.open(filepath, 'r:gz').extractall(inception_pretrain_model_dir) # dir prepared for summary, after loading .pb log_dir = 'inception_log' if not os.path.exists(log_dir): os.makedirs(log_dir) # get path of .pb file inception_graph_def_file = os.path.join(inception_pretrain_model_dir, 'classify_image_graph_def.pb') # 把.pb模型加载进当前会话 with tf.Session() as sess: with tf.gfile.FastGFile(inception_graph_def_file, 'rb') as f: # 打開.pb模型的文件 graph_def = tf.GraphDef() # 獲取圖定義對象 graph_def.ParseFromString(f.read()) # 圖定義對象從FastGFile文件中讀取定義 tf.import_graph_def(graph_def, name='') # 在當前 session 下引入圖定義 # save the structure of graph writer = tf.summary.FileWriter(log_dir, sess.graph) # 总结當前圖定義,用于可视化 writer.close()
注意,這裏探討的是如何下載已經訓練好的模型 .pb, 以及如何通過tensorboard對其進行可視化. 想要可視化就一定需要 summary, 然後 tensorboard 讀取 summary. 關鍵的問題是我之前一直以爲必須保存點什麼才能可視化, 但其實這條語句已經保存了整張圖了,即使你 summary 一些變量的值, 這個graph 也會直接保存在 summary file 中, 並且被 tensorboard 加載
download =====> inception.tgz =====> extractall =====> classify_image_graph_def.pb | | | v tf.gfile.FastGFile open and read .pb | | | v GraphDef parse the content after read | | | v tf.import(GraphDef) | | | v summary into summary file under log_dir | | | v Tensorboard
上面的程序执行完, 就说明图对象已经被 summary 到本地文件. 运行下面的命令即可.
tensorboard --logdir=/home/yiddi/git_repos/on_ml_tensorflow/inception_log