UP | HOME

Tensorflow 模型保存与加载

Table of Contents

1 Tensorflow model save and load

<<包导入>>
# <<导入 projector: for embeddings 可视化>>


<<数据准备>>
# numpy构造(with/without noise)
# 数据集位置
# 数据集导入内存(one_hot or not)
# 截取部分数据集

<<图参数>>
# 批次大小, batch_size
# 批次数量, n_batch
# dropout 保留率, keep_prob
# + <<for rnn >>n_inputs: dimension of each item(a vector) of sequence
# + <<for rnn >>max_time: the max length of all sequences (maybe the size of Dataset); max_time is also the number of time steps the rnn layer iterates over
# + <<for rnn >>lstm_size: inside of each rnn layer, how many lstm units
# + <<for rnn >>n_classes: units of the full (fully-connected) nn layer

<<工具函数与工具声明>>
# 对某些 Variable 进行 OP 并 summary
# <<def Variable: for embeddings 可视化>> as untrainable Variable, stack front 3000 img, give name 'embeddings'
# <<file IO: for embeddings 可视化>> read in one_hot labels, argmax get true labels, write to file in one-label-one-line format
# W, b 初始化工具函数

<<图构造>>
# 一神: NN layers, name_scope for TB, 参数 summary
#   1. Placeholders
#      1.1 x: dataset placeholder,
#      + <<def OP: for img process, CNN[-1, height, width, channels], RNN[-1, max_time, n_inputs] >> reshape x  ------+
#      1.2 y: labelset placeholder,                                                                                   |
#      1.3 keep_prob: dropout, keep rate of certain layer's nodes                                                     |
#   2. Layers & Variables                                                                                             |
#      2.0 名称空间设置                                                                                               |
#      2.1 第一层权重 W,                  声明 summary tf.summary.scalar/image/histogram node                         |
#      2.2 第一层偏置 b,                  声明 summary tf.summary.scalar/image/histogram node                         |
#      2.3 第一层输出(active_fn(logits)), 声明 summary tf.summary.scalar/image/histogram node                         |
#      + <<conv2d layer: for CNN>> 只接受 [batch_size, height, width, channels] 格式 <--------------------------------+
#      + <<max_pool layer: for CNN>>                                                                                  |
#      + <<BasicLSTMCell: for RNN>>                                                                                   |
#      + <<dynamic_rnn(units, inputs): for RNN>>                                                                      |
#                               ^                                                                                     |
#                               +-------------------------------------------------------------------------------------+

# 两函:
#   1. err_fn:
#      1.1 名称空间设置
#      1.2 err fn(单点错误), 声明 summary, tf.summary.scalar/image/histogram node
#   2. loss_fn:
#      2.1 名称空间设置
#      2.2 loss fn(整体错误), 声明 summary, tf.summary.scalar/image/histogram node

# 三器:
#   1. 初始化器
#   2. 优化器
#      2.1 名称空间设置
#   3. 保存器

<<图构造善后>>
# 准确率
#   1. correct_prediction
#      1.1 名称空间设置
#   2. accuracy
#      2.1 名称空间设置
# 合并 summary
# + <<for embeddings 可视化>>配置 embeddings 可视化参数

<<图计算>>
# 运行初始化器
# summary Writer for TB
# for epoch_num: <<
#          1. for batch_num:
#                 1.1 x_y_of_next_batch;
#                 1.2 运行 优化器计算 and summary计算
#          2. 运行准确率计算
# 运行保存器
# matplot绘图
 1: import tensorflow as tf
 2: from tensorflow.examples.tutorials.mnist import input_data
 3: 
 4: # Load the MNIST data with one-hot encoded labels
 5: mnist = input_data.read_data_sets("MNIST", one_hot=True)
 6: 
 7: # Batch size
 8: batch_size = 100
 9: # Number of complete batches per epoch (floor division drops any remainder)
10: n_batch = mnist.train.num_examples // batch_size
11: 
12: # Define two placeholders: x = flattened images (784 values), y = one-hot labels (10 classes)
13: x = tf.placeholder(tf.float32, [None, 784])
14: y = tf.placeholder(tf.float32, [None, 10])
15: 
16: # Build a simple neural network (no hidden layer)
17: W = tf.Variable(tf.zeros([784, 10]))
18: b = tf.Variable(tf.zeros([10]))
19: prediction = tf.nn.softmax(tf.matmul(x, W) + b)
20: 
21: # "Two functions, two operators": loss/train step; initializer & optimizer
22: init = tf.global_variables_initializer()
23: optimizer = tf.train.GradientDescentOptimizer(0.2)  # learning rate 0.2
24: loss = tf.reduce_mean(tf.square(y-prediction))  # mean squared error
25: train = optimizer.minimize(loss)
26: 
27: # Per-example correctness stored as a boolean vector
28: correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(prediction, 1))
29: # Accuracy = mean of the boolean vector cast to float
30: accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
31: 
32: 
33: saver = tf.train.Saver()
34: 
35: 
36: # Run the graph
37: with tf.Session() as sess:
38:     sess.run(init)
39:     # Train one epoch, then immediately test
40:     for epoch in range(21):
41:         # Train the model
42:         acc_train = 0  # NOTE: overwritten every batch below; ends up holding only the LAST batch's accuracy
43:         for batch in range(n_batch):
44:             batch_xs, batch_ys = mnist.train.next_batch(batch_size)
45:             _, acc_train = sess.run([train, accuracy], feed_dict={x:batch_xs, y:batch_ys})
46: 
47:         # Test the model
48:         # The test set must be evaluated with the model trained so far
49:         acc_test = sess.run(accuracy, feed_dict={x:mnist.test.images, y:mnist.test.labels})
50:         print("Iter " + str(epoch) + " ,Train:" + str(acc_train) + " ,Test:" + str(acc_test))
51: 
52:     # Save the model
53:     # Mind the indentation: this clearly runs only after training completes; it saves the session's variables
54:     saver.save(sess, 'net/my_net.ckpt')

上面的代码会在原本为空的 net/ 文件夹下产生如下四个文件:

-rw-r--r--  1 yiddi yiddi   79 7月  31 03:09 checkpoint
-rw-r--r--  1 yiddi yiddi  31K 7月  31 03:09 my_net.ckpt.data-00000-of-00001
-rw-r--r--  1 yiddi yiddi  159 7月  31 03:09 my_net.ckpt.index
-rw-r--r--  1 yiddi yiddi  16K 7月  31 03:09 my_net.ckpt.meta
 1: import tensorflow as tf
 2: from tensorflow.examples.tutorials.mnist import input_data
 3: 
 4: # 载入数据
 5: mnist = input_data.read_data_sets("MNIST", one_hot=True) (one_hot)
 6: 
 7: # 设置批次大小
 8: batch_size = 100 (batch_size)
 9: # 计算共有多少批次
10: n_batch = mnist.train.num_examples // batch_size (floor division)
11: 
12: # 定义两个 placeholder
13: x = tf.placeholder(tf.float32, [None, 784])
14: y = tf.placeholder(tf.float32, [None, 10])
15: 
16: # 创建简单神经网络(无隐藏层)
17: W = tf.Variable(tf.zeros([784, 10]))
18: b = tf.Variable(tf.zeros([10]))
19: prediction = tf.nn.softmax(tf.matmul(x, W) + b)
20: 
21: # 二函,二器
22: init = tf.global_variables_initializer()
23: optimizer = tf.train.GradientDescentOptimizer(0.2)
24: loss = tf.reduce_mean(tf.square(y-prediction))
25: train = optimizer.minimize(loss)
26: 
27: # 预测对错存在一个向量中
28: correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(prediction, 1)) (count correct prediction)
29: # 计算准确率
30: accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
31: 
32: 
33: saver = tf.train.Saver()
34: 
35: ####################################################################
36: # 在此之前, 也就是图构建过程与之前的程序完全一样
37: ####################################################################
38: 
39: # 图计算
40: with tf.Session() as sess:
41:     sess.run(init)
42: 
43:     print(sess.run(accuracy, feed_dict={x:mnist.test.images, y:mnist.test.labels}))
44:     saver.restore(sess, 'net/my_net.ckpt')
45:     print(sess.run(accuracy, feed_dict={x:mnist.test.images, y:mnist.test.labels}))

0.098
INFO:tensorflow:Restoring parameters from net/my_net.ckpt
0.9137