
A Collection of Practical TensorFlow Tips


1 The Saver must be created after the variables

The tf.train.Saver must be created after the variables that you want to restore (or save). Additionally it must be created in the same graph as those variables.


+-------------------------------------------+
|    var1 = tf.Variable(...)                |
|    var2 = tf.Variable(...)                |
|    saver = tf.train.Saver([var1, var2])   |
|    with tf.Session() as sess:             |
|         saver.save(sess, save_path)       |
+-------------------------------------------+


Saving with global_step=1000 writes these files: my_test_model-1000.index, my_test_model-1000.meta, my_test_model-1000.data-00000-of-00001, and checkpoint.
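
The ordering rule above can be sketched as follows. This is a minimal example, assuming the TF 1.x graph-mode API reached through `tensorflow.compat.v1`; the variable names, shapes, and the temporary output directory are hypothetical.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()  # the notes use graph mode and tf.Session

graph = tf.Graph()
with graph.as_default():
    # 1. Create the variables first.
    var1 = tf.get_variable("var1", shape=[2], initializer=tf.zeros_initializer())
    var2 = tf.get_variable("var2", shape=[3], initializer=tf.ones_initializer())
    # 2. Only then create the Saver, inside the same graph as the variables.
    saver = tf.train.Saver([var1, var2])
    init_op = tf.global_variables_initializer()

save_dir = tempfile.mkdtemp()  # hypothetical output directory
with tf.Session(graph=graph) as sess:
    sess.run(init_op)
    # save() is a method of the Saver; the session is passed in as an argument.
    prefix = saver.save(sess, os.path.join(save_dir, "my_test_model"), global_step=1000)

saved_files = sorted(os.listdir(save_dir))
print(saved_files)
```

Creating the Saver before any variable exists fails because it has nothing to save, which is why the order matters.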

2 ckpt

A ckpt stores a snapshot: you choose when the picture is taken, and it records the values of the Variables at that moment.

     +-----> saver.restore(sess, 'dir/my_net.ckpt')
     |                                -----------
     ^
     +-----> saver.restore(sess, tf.train.latest_checkpoint('dir'))
     |
     ^
     +-<---- a saver obj <== tf.train.import_meta_graph(             )
                                                             ^
                  saver.restore(sess, 'net/my_net.ckpt')     |
                            ^                                +--------------------------------------------|
                            |                                                                             |
                            |                                                                             |
                            |                         | my_net.ckpt.index                                 |
                                                        -----------                                       |
                                                      | my_net.ckpt.meta     => architecture of graphs ---+
                                              include   -----------
TensorBoard    ---<-----   ckpt => checkpoint --------| my_net.ckpt.data     => variable values
                                                        -----------
                                                      | checkpoint           => record of latest checkpoint
                            ^
                            |
                            |
                            |
                          saver.save(sess, save_path)
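                               - to save only some variables, not all, pass ~var_list~ to tf.train.Saver
                               - ~write_meta_graph~: write the .meta file or not (save argument)
                               - ~max_to_keep~: max number of recent checkpoints kept (Saver argument)
                               - ~keep_checkpoint_every_n_hours~ (Saver argument)
                               - ~global_step~: training step number appended to the checkpoint file name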
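
The save and restore paths in the diagram can be sketched as one round trip. This is a minimal example under assumptions: it uses `tensorflow.compat.v1`, a temporary directory, and a hypothetical variable `weight` with a named read op `weight_value` so the restored value can be fetched by name.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

ckpt_dir = tempfile.mkdtemp()  # hypothetical checkpoint directory

# Build a tiny graph and save one snapshot of it.
with tf.Graph().as_default():
    weight = tf.get_variable("weight", initializer=tf.constant([1.0, 2.0]))
    tf.identity(weight, name="weight_value")  # named op used to read the value later
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver.save(sess, os.path.join(ckpt_dir, "my_net.ckpt"))

# Restore into a fresh graph: the .meta file rebuilds the graph structure,
# the .data and .index files supply the variable values.
with tf.Graph().as_default() as restored_graph:
    latest = tf.train.latest_checkpoint(ckpt_dir)  # looked up in the "checkpoint" file
    restorer = tf.train.import_meta_graph(latest + ".meta")
    with tf.Session() as sess:
        restorer.restore(sess, latest)
        restored_weight = sess.run(restored_graph.get_tensor_by_name("weight_value:0"))

print(restored_weight)
```

When the graph-building code is still available, the `import_meta_graph` step can be skipped: rebuild the graph in code, create a `tf.train.Saver`, and call `restore` directly.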

3 summary

A summary stores a video: the values of all the Variables over a period of time.


graph building
==============
tf.summary.scalar           \
          .histogram         |  store values of
          .image             |> variables in       --+
                             |  continuous steps     |
                             |                       |
tf.summary.merge_all()      /                        |
                                                     +---> events.out.tfevents.1533011911.yiddiEdge
                                                     |                    ||
graph computing                                      |                    ||
===============                                      |                    ||
tf.summary.FileWriter('dir')  > store view of graph -+                    ||
                                                                          \/

                                                                      TensorBoard
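
The two phases in the diagram can be sketched as below. This is a minimal example, assuming `tensorflow.compat.v1`; the placeholder, the toy loss, and the temporary log directory are hypothetical stand-ins for a real training loop.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

log_dir = tempfile.mkdtemp()  # hypothetical log directory

with tf.Graph().as_default() as graph:
    # Graph building: attach summary ops to the tensors to be recorded.
    step_value = tf.placeholder(tf.float32, shape=[], name="step_value")
    loss = tf.square(step_value - 3.0, name="loss")  # toy stand-in for a real loss
    tf.summary.scalar("loss", loss)
    tf.summary.histogram("step_value", step_value)
    merged = tf.summary.merge_all()

    # Graph computing: evaluate the merged summary at every step and append it
    # to the event file; passing the graph also stores its structure.
    with tf.Session() as sess:
        writer = tf.summary.FileWriter(log_dir, graph)
        for step in range(5):
            summary = sess.run(merged, feed_dict={step_value: float(step)})
            writer.add_summary(summary, global_step=step)
        writer.close()

event_files = os.listdir(log_dir)
print(event_files)
```

Pointing TensorBoard at the same directory (`tensorboard --logdir=<log_dir>`) then shows both the recorded curves and the graph.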

4 download pretrained model, load to current session

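  1. Download the archive, extract it, and locate the .pb file
  2. Load the .pb file into the current session
    1. gfile.FastGFile(.pb_file, 'rb')
    2. Get a graph definition handle
    3. Read the contents of the .pb file
    4. Have the handle parse those contents as the graph definition
    5. Import the graph definition into TensorFlow
  3. Visualize the model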
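
Step 2 of the list above can be sketched as follows. This is a minimal example under assumptions: instead of the downloaded model it first writes a tiny hypothetical graph to a `.pb` file so that it runs without a download, it uses `tf.io.gfile.GFile` (the non-deprecated counterpart of `gfile.FastGFile`), and it imports the parsed definition with `tf.import_graph_def`, the call that accepts a `GraphDef`. The tensor names `input` and `output` are made up.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

# Stand-in for the downloaded model: serialize a tiny graph to a .pb file.
pb_path = os.path.join(tempfile.mkdtemp(), "tiny_graph_def.pb")
with tf.Graph().as_default() as source_graph:
    x = tf.placeholder(tf.float32, shape=[None], name="input")
    tf.multiply(x, 2.0, name="output")
with tf.io.gfile.GFile(pb_path, "wb") as f:
    f.write(source_graph.as_graph_def().SerializeToString())

with tf.Graph().as_default() as graph:
    # 2.1 open the .pb file in binary mode, 2.3 read its contents
    with tf.io.gfile.GFile(pb_path, "rb") as f:
        serialized = f.read()
    # 2.2 create an empty graph definition, 2.4 parse the bytes into it
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(serialized)
    # 2.5 import the definition into the current default graph
    tf.import_graph_def(graph_def, name="")

    with tf.Session() as sess:
        output = graph.get_tensor_by_name("output:0")
        result = sess.run(output, feed_dict={"input:0": [1.0, 2.0, 3.0]})

print(result)
```

Passing `name=""` keeps the imported tensor names unprefixed, so `get_tensor_by_name` can use the names stored in the file.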

               /         URL of tgz --->  file name of tgz -+
               |               |                            |--> absolute path of tgz
               |               |          local dir of tgz -+        |
               |               v                   |                 v
               |         download tgz   <----------|-----------------+
    download   |               |                   |
    to         |               |                   |
    local     <|               v                   |
               |          extract tgz   <----------+
               |               |
               |               |                                                                                                                                   predict label(one from
               |               v                                                  /-------------------------------------------------------------------------------------+  1000) of new image
               |    +--- classify_image_graph_def.pb                                                                                                                    |
               |    |    imagenet_2012_challenge_label_map_proto.pbtxt            int to uid ->+         parse function                                                 |
               |    |    imagenet_synset_to_human_label_map.txt           uid to description ->+------->-----------------> dict {int : description}                     |
               \    |    inception-2015-12-05.tgz                                                                                 |             |                       |
                    |                                                                                                             ^             v                       |
               /    +---------------------+  with tf.Session() as sess:                                                           |             |                       |
               |                          |  .........................                                       +---------->---------+             |                       |
               |                          |                          .                                       |                                  |                       |
               |                          v                          .                                                                          |                       |
               |         gfile.FastGFile(___, 'rb').read()           .                                  [ x,x,x,x,x ]                           |                       |
               |                                         |           .                                       ^                                  |                       |
               |                                         |           .                                       |   argsort[-5:][::-1]             |                       |
    load       |                                         |           .                                       |                                  v                       |
    to        <|                                         v           .                                   ******** (1000,)                       |                       |
    session    |         tf.GraphDef().ParsetFromString(___)         .                                       ^                                  |                       |
               |         -------------                               .                                       |   squeeze                        |                       |
               |               |                                     .                                       |                                  -------|                |
               |               |                                     .                                   ******** (1000,1) output                      |                |
               |               +---------------+                     .                                                                                 |                |
               |                               v                     .      get_tensor_by_name               ^                                         v                |
               \         tf.import_graph_def(___, name='')  ===>  graph  -----------------------> tensor->  ---                   /  new img         label       \      |
                                                                     .       get the activation             |||                  /   +----------+                 \     |
                                                             .........       of last layer                  |||                  |   |   ...   .|                 |     |
                                                             .                                              |||          |<----- |   |  ....   .|                 |     |
                                                             .                                              |||          |       \   |  ........|                 /     |
                                                             .                                              |||          |        \  +----------+ ,  ----------  /      |
               /         tf.summary.FileWriter(log_dir, sess.graph)                                         --- <---------                                              |
               |                                 |                                                 inception-v3                            ^                            |
    visualize <|                                 |                                                                                         |                            |
               |                                 v                                                                                         | image get function         |
               \         tensorboard --logdir=_____                                                                                        |                            |
                                                                                                                                           |-------- DIR                /
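
The right-hand side of the diagram (joining the two label files into one dict, then squeeze and argsort[-5:][::-1] on the output) can be sketched without TensorFlow. This is a minimal example under assumptions: the two strings are tiny made-up stand-ins that only imitate the layout of the label files, the score array is a hypothetical network output, and both helper functions are illustrative names.

```python
import re

import numpy as np

# Made-up stand-ins for the two label files; only the layout is meant to match.
LABEL_MAP_PBTXT = '''
entry {
  target_class: 0
  target_class_string: "n00000001"
}
entry {
  target_class: 1
  target_class_string: "n00000002"
}
entry {
  target_class: 2
  target_class_string: "n00000003"
}
'''
UID_TO_HUMAN_TXT = "n00000001\tcat\nn00000002\tdog\nn00000003\tcar\n"


def build_id_to_description(pbtxt, uid_txt):
    """Join int -> uid and uid -> description into int -> description."""
    uid_to_human = dict(line.split("\t", 1) for line in uid_txt.splitlines() if line)
    pairs = re.findall(r'target_class:\s*(\d+)\s*target_class_string:\s*"(\w+)"', pbtxt)
    return {int(node_id): uid_to_human[uid] for node_id, uid in pairs}


def top_k_labels(predictions, id_to_description, k=5):
    """Return the k highest-scoring (description, score) pairs, best first."""
    scores = np.squeeze(predictions)       # drop the size-1 axis
    top_ids = scores.argsort()[-k:][::-1]  # indices of the k largest scores
    return [(id_to_description[int(i)], float(scores[i])) for i in top_ids]


id_to_description = build_id_to_description(LABEL_MAP_PBTXT, UID_TO_HUMAN_TXT)
fake_scores = np.array([[0.1, 0.7, 0.2]])  # hypothetical network output
top_two = top_k_labels(fake_scores, id_to_description, k=2)
print(top_two)
```

`argsort` sorts ascending, so the last k indices are the best ones and `[::-1]` puts the best first.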
    
    

5 hub


                      xxx.pb     model          \
                      yyy.pbtxt  int-uid         | --------+
                      zzz.txt    uid-description/          |
                                                           | model.tgz
                                                           |
                                             /+--+    /+--+    /+--+    /+--+
+------------>--------------------------    +-+-+|   +-+-+|   +-+-+|   +-+-+|
|                               ...<....    |   |/   |   |/   |   |/   |   |/  .....
|                               .           +---+    +---+    +---+    +---+
|                               .      inception    resnet  imagenet   a3c
|                               .      _v3
|                  as transform .                               |
|                               .                               |  model_dir, the dir of downloaded tgz of pretrained model
|                               v                               |
|                               .                               |         ----------------------| output xxx.pb
|                               .                               v        /                      |                       1. open .pb
|                               .                        +--------------+                       v                       2. graph_def parse graph
|                          your own    image_dir         |              |     output_graph                                 from .pb
|            +------       Data set  ---------------->   |  retrain.py  | < --------------- a dir used to store ----+
|            |                  .                        |              |                   the new NN model        |   3. import graph_def
|            |                  .                        +--------------+                   suit for your task,     |      as it was defined
|            |               to .                          ^        ^   +--------------+    "xxx.pb" file           |      by us
|            |                  .                          |        |                  v                            |
|            |                  .          bottleneck_dir  |        | output_labels                                 |
|                          code vector  -------------------+        +---------------- a dir used to                 |
v                          files        1 vector 1 file                               store all parsed              |
|        train_data             .       same architecture with                        labels from                   |
|    ... |- cars                .       img dir                                       image_dir                     |                          test images
|    .        |- car1.jpng      .                                                     "yyy.txt" file,               |                         /
|    .        |- car2.jpng      v                                                     like                          v              session   /
|    .        |- car3.jpng      .                                                     "cars                                      +----------/+
|    .        |- car4.jpng      .                                                     animals                      *   graph     |         / |
|    .        |- ...            .                                                     flower"                     / \            |        |  |  predict
|    .                          .                                                                                *   *      ---> |run(_ , _ )| --------> test images' labels
|    ... |- animal              .                                                                                |   | \         |   /       |
|    .        |- animal1.jpng   .                                                                                *   *  *        |  /        |
|    .        |- animal2.jpng   .                                                                                                +-/---------+
|    .        |- animal3.jpng   .                                                                                                 /
|    .        |- animal4.jpng   .                                                                                                /
|    .        |- ...            .                                                                                               /
|    .                          .                                                                                              get tensor by name
|    ... |- flower              .
|    .        |- ...            .
|    .                          .
|    .                          .
|  as label                as data
|    |                          |
|    +------------+-------------+
|                 |
+------------<----+
 train the model
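
The "as label" / "as data" split in the diagram, where each sub-directory name of image_dir is a label and the files inside it are the data, can be sketched with the standard library. This is an illustration of the idea only, not the actual code of retrain.py; the helper name, the accepted extensions, and the temporary directory tree are assumptions.

```python
import os
import tempfile


def list_images_by_label(image_dir, extensions=(".jpg", ".jpeg", ".png")):
    """Map each sub-directory name (the label) to its sorted image paths."""
    result = {}
    for label in sorted(os.listdir(image_dir)):
        label_dir = os.path.join(image_dir, label)
        if not os.path.isdir(label_dir):
            continue  # loose files at the top level carry no label
        files = sorted(
            os.path.join(label_dir, name)
            for name in os.listdir(label_dir)
            if name.lower().endswith(extensions)
        )
        if files:
            result[label] = files
    return result


# Build a small hypothetical train_data tree with empty placeholder files.
train_data = os.path.join(tempfile.mkdtemp(), "train_data")
for label, count in (("cars", 4), ("animal", 4), ("flower", 1)):
    os.makedirs(os.path.join(train_data, label))
    for index in range(1, count + 1):
        path = os.path.join(train_data, label, "%s%d.jpg" % (label, index))
        open(path, "w").close()

images_by_label = list_images_by_label(train_data)
print({label: len(files) for label, files in images_by_label.items()})
```

The keys of the resulting dict correspond to what the diagram stores as output_labels, one label per line.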

6 tf.gradients(loss, embed) and tf.stop_gradient()

tf.gradients(loss, embed) computes the partial derivative of the tensor loss with respect to the tensor embed. TensorFlow computes this partial derivative by backpropagation, so it is expected behavior that evaluating the result of tf.gradients(…) performs backpropagation. However, evaluating that tensor does not perform any variable updates, because the expression does not include any assignment operations.

tf.stop_gradient() is an operation that acts as the identity function in the forward direction, but stops the accumulated gradient from flowing through that operator in the backward direction. It does not prevent backpropagation altogether, but instead prevents an individual tensor from contributing to the gradients that are computed for an expression. The documentation for the operation has more details about the operation, and when to use it.

Note: stop_gradient(tensor) does not stop backpropagation from computing gradients; it only keeps the tensor inside the parentheses from contributing to them.
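
The difference can be checked numerically. This is a minimal sketch, assuming `tensorflow.compat.v1`; the toy loss embed^2 + embed^2 is made up for illustration, with the second term wrapped in `tf.stop_gradient` in one of the two versions.

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

with tf.Graph().as_default():
    embed = tf.constant(3.0, name="embed")
    # Same forward value in both cases: embed^2 + embed^2.
    loss_full = tf.square(embed) + tf.square(embed)
    # Here the second term is treated as a constant during backpropagation.
    loss_stopped = tf.square(embed) + tf.stop_gradient(tf.square(embed))

    grad_full = tf.gradients(loss_full, embed)[0]        # d/d embed = 4 * embed
    grad_stopped = tf.gradients(loss_stopped, embed)[0]  # d/d embed = 2 * embed

    with tf.Session() as sess:
        values = sess.run([loss_full, loss_stopped, grad_full, grad_stopped])

loss_full_value, loss_stopped_value, grad_full_value, grad_stopped_value = (
    float(v) for v in values
)
print(loss_full_value, loss_stopped_value, grad_full_value, grad_stopped_value)
```

Both losses evaluate to the same number, but the stopped version yields half the gradient: backpropagation still runs, and only the wrapped term stops contributing.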