UP | HOME

学习 Tensorflow nets 模块

Table of Contents

1 如何改装 nets 下的 alexnet.py 适配自己的任务

上讲说明了 slim 模块下有如下文件夹, 并且重点讨论了如何改装 datasets 文件夹. 本节重点探讨如何改装 nets 文件夹下的NN模型代码 alexnet.py .

/home/yiddi/wellknown_proj_sourcecode/models/research/slim:

-rw-r--r--  1 yiddi yiddi  14K 7月  31 18:33 BUILD
drwxr-xr-x  2 yiddi yiddi 4.0K 7月  31 18:33 >>datasets<<
drwxr-xr-x  2 yiddi yiddi 4.0K 7月  31 18:33 >>deployment<<
-rw-r--r--  1 yiddi yiddi 2.3K 7月  31 18:33 download_and_convert_data.py
-rw-r--r--  1 yiddi yiddi 6.6K 7月  31 18:33 eval_image_classifier.py
-rw-r--r--  1 yiddi yiddi 4.6K 7月  31 18:33 export_inference_graph.py
-rw-r--r--  1 yiddi yiddi 1.4K 7月  31 18:33 export_inference_graph_test.py
-rw-r--r--  1 yiddi yiddi    0 7月  31 18:33 __init__.py
drwxr-xr-x  4 yiddi yiddi 4.0K 7月  31 18:33 >>nets<<--------
drwxr-xr-x  2 yiddi yiddi 4.0K 7月  31 18:33 >>preprocessing<<
-rw-r--r--  1 yiddi yiddi  26K 7月  31 18:33 README.md
drwxr-xr-x  2 yiddi yiddi 4.0K 7月  31 18:33 >>scripts<<
-rw-r--r--  1 yiddi yiddi  916 7月  31 18:33 setup.py
-rw-r--r--  1 yiddi yiddi  46K 7月  31 18:33 slim_walkthrough.ipynb
-rw-r--r--  1 yiddi yiddi  21K 7月  31 18:33 -> train_image_classifier.py <-
-rw-r--r--  1 yiddi yiddi    0 7月  31 18:33 WORKSPACE

1.0.1 nets_factory.py 介绍

datasets/datasets_factory.py 相同的是在 nets 文件夹下也有一个 factory: nets_factory.py, 里面可以利用 networks_map 通过模型名字调用对应的NN.

networks_map = {'alexnet_v2': alexnet.alexnet_v2,
                'cifarnet': cifarnet.cifarnet,
                'overfeat': overfeat.overfeat,
                'vgg_a': vgg.vgg_a,
                'vgg_16': vgg.vgg_16,
                'vgg_19': vgg.vgg_19,
                'inception_v1': inception.inception_v1,
                'inception_v2': inception.inception_v2,
                'inception_v3': inception.inception_v3,
                'inception_v4': inception.inception_v4,
                'inception_resnet_v2': inception.inception_resnet_v2,
                'lenet': lenet.lenet,
                'resnet_v1_50': resnet_v1.resnet_v1_50,
                'resnet_v1_101': resnet_v1.resnet_v1_101,
                'resnet_v1_152': resnet_v1.resnet_v1_152,
                'resnet_v1_200': resnet_v1.resnet_v1_200,
                'resnet_v2_50': resnet_v2.resnet_v2_50,
                'resnet_v2_101': resnet_v2.resnet_v2_101,
                'resnet_v2_152': resnet_v2.resnet_v2_152,
                'resnet_v2_200': resnet_v2.resnet_v2_200,
                'mobilenet_v1': mobilenet_v1.mobilenet_v1,
                'mobilenet_v1_075': mobilenet_v1.mobilenet_v1_075,
                'mobilenet_v1_050': mobilenet_v1.mobilenet_v1_050,
                'mobilenet_v1_025': mobilenet_v1.mobilenet_v1_025,
                'mobilenet_v2': mobilenet_v2.mobilenet,
                'mobilenet_v2_140': mobilenet_v2.mobilenet_v2_140,
                'mobilenet_v2_035': mobilenet_v2.mobilenet_v2_035,
                'nasnet_cifar': nasnet.build_nasnet_cifar,
                'nasnet_mobile': nasnet.build_nasnet_mobile,
                'nasnet_large': nasnet.build_nasnet_large,
                'pnasnet_large': pnasnet.build_pnasnet_large,
                'pnasnet_mobile': pnasnet.build_pnasnet_mobile,
               }

1.0.2 alexnet.py 介绍

关于 arg_scope

arg_scope 是一种可以被下层 arg_scope 集成的, 自动给其下指定类型的NN定义参数并赋初值的一种机制, 他会被 nets_factory.py 调用.

语法为:

with slim.arg_scope([适用类范围],
                    NN 参数1 = 值
                    NN 参数2 = 值
                    NN 参数3 = 值):
# 定义网络参数
def alexnet_v2_arg_scope(weight_decay=0.0005):

    # 其下定义的[conv2d, 和 fully_connected] 都默认使用:
    # 1. relu 作为激活函数
    # 2. biases 都初始化为 0.1
    # 3. 采用 L2-regularizer 并且其系数为 weight_decay=0.0005
    with slim.arg_scope([slim.conv2d, slim.fully_connected],
                        activation_fn=tf.nn.relu,
                        biases_initializer=tf.constant_initializer(0.1),
                        weights_regularizer=slim.l2_regularizer(weight_decay)):

        # 其下定义的[conv2d] 都默认使用:
        # 1. padding 方式为 "SAME"
        with slim.arg_scope([slim.conv2d], padding='SAME'):

            # 其下定义的[max_pool2d] 都默认使用:
            # 1. padding 方式为 "VALID"
            with slim.arg_scope([slim.max_pool2d], padding='VALID') as arg_sc:
                return arg_sc
# 定义网络结构
def alexnet_v2(inputs,
               num_classes=1000,
               is_training=True,
               dropout_keep_prob=0.5,
               spatial_squeeze=True,
               scope='alexnet_v2',
               global_pool=False):

构造 alexnet_v2 NN 需要指明的参数(如上):

  • inputs: a tensor of size [batch_size, height, width, channels]. 关于图片的 height, width 默认的值都是 224, 所以想直接使用 alexnet_v2 应该在图片预处理的时候都转换成 224*224 大小.
  • num_classes: the number of predicted classes. If 0 or None, the logits layer is omitted and the input features to the logits layer are returned instead. is_training: whether or not the model is being trained.
  • dropout_keep_prob: the probability that activations are kept in the dropout layers during training.
  • spatial_squeeze: whether or not should squeeze the spatial dimensions of the logits. Useful to remove unnecessary dimensions for classification.
  • scope: Optional scope for the variables.
  • global_pool: Optional boolean flag. If True, the input to the classification layer is avgpooled to size 1x1, for any input size. (This is not part of the original AlexNet.)