• TensorFlow 笔记 ——1 —— GraphKeys,等待补充


    https://tensorflow.google.cn/versions/r1.15/api_docs/python/tf/GraphKeys?hl=zh_cn

    Class GraphKeys

    用于graph collections的标准名称

    别名:

    • Class tf.compat.v1.GraphKeys

    标准库使用各种众所周知的名称来收集和检索与图关联的值。例如,如果指定为none,tf.Optimizer子类默认优化在tf.GraphKeys.TRAINABLE_VARIABLES下收集的变量,但是也可以传递显式的变量列表。

    所以这个类下面的变量就是各种collections的keys值,通过这些值为String变量的名称,能够通过get_collection()方法来得到对应的值

    tf.get_collection

    Graph.get_collection()的封装,使用默认graph

    tf.get_collection(
        key,
        scope=None
    )

    参数:

    • key: collection的key值。例如,GraphKeys类包含许多集合的标准名称。
    • scope:(Optional):如果提供了此选项,则对结果列表进行筛选,以只包含名称属性使用re.match匹配的项。如果提供了scope值,并且choice或re.match表示没有特殊标记的作用域通过前缀筛选,则不返回没有name属性的项。

    返回:

    具有给定name的集合中的值列表,或者如果没有向该集合中添加值,则为空列表。该列表按收集值的顺序包含值。

    定义了如下的标准keys:

    • GLOBAL_VARIABLES:变量对象的默认集合,跨分布式环境共享(模型变量(model variable)是其中的子集)。更多细节请参见tf.compat.v1.global_variables。通常,所有TRAINABLE_VARIABLES变量都在MODEL_VARIABLES中,而所有MODEL_VARIABLES变量都在GLOBAL_VARIABLES中。
    • LOCAL_VARIABLES:位于每台机器局部的变量对象的子集。通常用于临时变量,如计数器。注意:使用tf.contrib.framework.local_variable将其添加到此集合中。
    • MODEL_VARIABLES:模型中用于推理(前向传播)的变量对象的子集。注意:使用tf.contrib.framework.model_variable将其添加到这个集合中。
    • TRAINABLE_VARIABLES:将由优化器训练的变量对象的子集。更多细节请参见tf.compat.v1.trainable_variables。
    • SUMMARIES:在graph中创建的summary张量对象。更多细节见tf.compat.v1.summary.merge_all。
    • QUEUE_RUNNERS:用于为计算生成输入的QueueRunner对象。参见tf.compat.v1.train.start_queue_runners了解更多细节。
    • MOVING_AVERAGE_VARIABLES:保持移动平均值的可变对象的子集。更多细节请参见tf.compat.v1.moving_average_variables。
    • REGULARIZATION_LOSSES:在graph构造过程中收集的正则化损失。

    还定义了以下标准keys,但是它们的集合不会像其他的那样自动填充,就是你可以手动使用add_to_collection()手动填充:

    • WEIGHTS
    • BIASES
    • ACTIVATIONS

    tf.add_to_collection

    Graph.add_to_collection()的封装,其使用了默认graph

    别名:

    • tf.compat.v1.add_to_collection
    tf.add_to_collection(
        name,
        value
    )

    根据给定的name值将value值存储在对应的collection中

    参数:

    • name:集合collection的键key。例如,GraphKeys类中包含的许多集合的标准名称。
    • value:要添加到该collection的值

    tf.add_to_collections

    Graph.add_to_collections()的封装,其使用了默认graph

    别名:

    • tf.compat.v1.add_to_collections
    tf.add_to_collections(
        names,
        value
    )

    将值存储在由name给出的集合中。

    注意,collections不是集合,因此可以多次向collection添加值。此函数确保忽略名称中的重复项,但它不会检查names中任何collections中value的预先存在的成员关系。

    names可以是任何可迭代的,但如果names是字符串,则将其视为单个集合名称。

    参数:

    • names:集合collection的键keys。例如,GraphKeys类中包含的许多集合的标准名称。
    • value:要添加到该collection的值

    其他的标准名称如下:

    • ACTIVATIONS = 'activations'
    • ASSET_FILEPATHS = 'asset_filepaths'
    • BIASES = 'biases'
    • CONCATENATED_VARIABLES = 'concatenated_variables'
    • COND_CONTEXT = 'cond_context'
    • EVAL_STEP = 'eval_step'
    • GLOBAL_STEP = 'global_step'
    • GLOBAL_VARIABLES = 'variables'
    • INIT_OP = 'init_op'
    • LOCAL_INIT_OP = 'local_init_op'
    • LOCAL_RESOURCES = 'local_resources'
    • LOCAL_VARIABLES = 'local_variables'
    • LOSSES = 'losses'
    • METRIC_VARIABLES = 'metric_variables'
    • MODEL_VARIABLES = 'model_variables'
    • MOVING_AVERAGE_VARIABLES = 'moving_average_variables'
    • QUEUE_RUNNERS = 'queue_runners'
    • READY_FOR_LOCAL_INIT_OP = 'ready_for_local_init_op'
    • READY_OP = 'ready_op'
    • REGULARIZATION_LOSSES = 'regularization_losses'
    • RESOURCES = 'resources'
    • SAVEABLE_OBJECTS = 'saveable_objects'
    • SAVERS = 'savers'
    • SUMMARIES = 'summaries'
    • SUMMARY_OP = 'summary_op'
    • TABLE_INITIALIZERS = 'table_initializer'
    • TRAINABLE_RESOURCE_VARIABLES = 'trainable_resource_variables'
    • TRAINABLE_VARIABLES = 'trainable_variables'
    • TRAIN_OP = 'train_op'
    • UPDATE_OPS = 'update_ops'
    • VARIABLES = 'variables'
    • WEIGHTS = 'weights'
    • WHILE_CONTEXT = 'while_context'

    下面举例说明几个常用的:

    1. tf.GraphKeys.GLOBAL_VARIABLES

    1) tf.global_variables

    别名:

    • tf.compat.v1.global_variables
    tf.global_variables(scope=None)

    全局变量是分布式环境中跨机器共享的变量。Variable()构造函数或get_variable()函数自动向Graph集合GraphKeys.GLOBAL_VARIABLES中添加新变量。这个方便的函数tf.global_variables()返回该集合的内容。

    局部变量是全局变量的替代品。看到tf.compat.v1.local_variables

    参数:

    • scope : (optional)一个字符串。如果提供了,则对结果列表进行筛选,以使用re.match返回只包含名称属性与作用域匹配的项。如果提供了作用域,则不会返回没有name属性的项。match的选择意味着没有特殊标记的范围通过前缀进行筛选。

    返回:

    • 变量对象列表

    我们熟悉的tf.global_variables_initializer()就是初始化这个集合内的Variables。

    2) tf.local_variables

    别名:

    • tf.compat.v1.local_variables
    tf.local_variables(scope=None)

    局部变量——每个进程变量,通常不保存/恢复到checkpoint,用于临时或中间值。例如,它们可以用作度量计算的计数器,或者机器读取数据的epoch数。函数的作用是:将新变量自动添加到GraphKeys.LOCAL_VARIABLES中。这个方便的函数返回该集合的内容。

    局部变量的替代方法是全局变量。看到tf.compat.v1.global_variables

    参数:

    • scope(optional):一个字符串。如果提供了,则对结果列表进行筛选,以使用re.match返回只包含名称属性与作用域匹配的项。如果提供了作用域,则不会返回没有name属性的项。match的选择意味着没有特殊标记的范围通过前缀进行筛选。

    返回:

    • 局部变量对象列表

    例子:

    import tensorflow as tf
    sess = tf.Session()
    
    #这里没有指定collections参数的值,则collections=None,等价于 collection=[tf.GraphKeys.GLOBAL_VARIABLES]
    a = tf.get_variable("a", [3,112,112])
    b = tf.get_variable("b", [64])
    print("a is ", a)
    print("b is  ", b)

    返回:

    a is  <tf.Variable 'a:0' shape=(3, 112, 112) dtype=float32_ref>
    b is   <tf.Variable 'b:0' shape=(64,) dtype=float32_ref>

    可见生成了两个变量a和b,名字分别为'a:0'和'b:0'

    print("tf.GraphKeys.GLOBAL_VARIABLES = ", tf.GraphKeys.GLOBAL_VARIABLES)
    global_variables_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
    print("global_variables_list is ", global_variables_list)

    返回:

    tf.GraphKeys.GLOBAL_VARIABLES =  variables
    global_variables_list is  [<tf.Variable 'a:0' shape=(3, 112, 112) dtype=float32_ref>, <tf.Variable 'b:0' shape=(64,) dtype=float32_ref>]

    可见tf.GraphKeys.GLOBAL_VARIABLES对应的字符串名称为"variables",且其对应的collections中果然有a和b两个变量

    使用自定义的collections

    c = tf.get_variable("c", [10], collections=["my_collections"])
    d = tf.get_variable("d", [20], collections=["my_collections"])
    
    my_variables_list = tf.get_collection("my_collections")
    print("my_variables_list is ", my_variables_list)

    返回:

    my_variables_list is  [<tf.Variable 'c:0' shape=(10,) dtype=float32_ref>, <tf.Variable 'd:0' shape=(20,) dtype=float32_ref>]

    tf.GraphKeys.REGULARIZATION_LOSSES

    在使用tf.get_variable()时如果使用的了参数regularizer指定使用的正则化函数,则将新创建的变量应用正则化后的结果将添加到tf.GraphKeys.REGULARIZATION_LOSSES集合中,可用于正则化。

    其实就是损失中添加了正则化项,输入即这些变量权重weight值

    接着上面的例子:

    weight_decay = 0.1
    l2_reg = tf.contrib.layers.l2_regularizer(weight_decay)
    tmp = tf.constant([0,1,2,3], dtype=tf.float32)
    k = tf.get_variable('k', regularizer=l2_reg, initializer=tmp)
    global_variables_list = tf.get_collection("variables")
    print("global_variables_list is ", global_variables_list)
    
    #regularizer定义会将k加入REGULARIZATION_LOSSES集合
    regular_variables_list = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
    print("regular_variables_list is ", regular_variables_list)
    
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print(sess.run(k))
        l2_loss = tf.add_n(regular_variables_list)#实现一个列表的元素的相加
        print("loss is ", sess.run(l2_loss))

    返回:

    global_variables_list is  [<tf.Variable 'a:0' shape=(3, 112, 112) dtype=float32_ref>, <tf.Variable 'b:0' shape=(64,) dtype=float32_ref>, <tf.Variable 'k:0' shape=(4,) dtype=float32_ref>]
    regular_variables_list is  [<tf.Tensor 'k/Regularizer/l2_regularizer:0' shape=() dtype=float32>]
    [0. 1. 2. 3.]
    loss is  0.7

    L2正则化的操作等价于:

    tf.reduce_sum(a*a)*weight_decay/2 = 0.1*(0*0+1*1+2*2+3*3)/2=0.7

    如果有多个变量都定义了regularizations参数,则:

    weight_decay = 0.1
    l2_reg = tf.contrib.layers.l2_regularizer(weight_decay)
    tmp = tf.constant([0,1,2,3], dtype=tf.float32)
    tmp2 = tf.constant([1,2,3,4], dtype=tf.float32)
    
    k = tf.get_variable('k', regularizer=l2_reg, initializer=tmp)
    k2 = tf.get_variable('k2', regularizer=l2_reg, initializer=tmp2)
    
    global_variables_list = tf.get_collection("variables")
    print("global_variables_list is ", global_variables_list)
    
    #regularizer定义会将k加入REGULARIZATION_LOSSES集合
    regular_variables_list = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
    print("regular_variables_list is ", regular_variables_list)
    
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        l2_loss = tf.add_n(regular_variables_list) #实现一个列表的元素的相加
        print("loss is ", sess.run(l2_loss))

    返回:

    global_variables_list is  [<tf.Variable 'a:0' shape=(3, 112, 112) dtype=float32_ref>, <tf.Variable 'b:0' shape=(64,) dtype=float32_ref>, <tf.Variable 'k:0' shape=(4,) dtype=float32_ref>, <tf.Variable 'k2:0' shape=(4,) dtype=float32_ref>]
    regular_variables_list is  [<tf.Tensor 'k/Regularizer/l2_regularizer:0' shape=() dtype=float32>, <tf.Tensor 'k2/Regularizer/l2_regularizer:0' shape=() dtype=float32>]
    loss is  2.2

    L2正则化的操作等价于:

    tf.reduce_sum(k*k)*weight_decay/2 = 0.1*(0*0+1*1+2*2+3*3)/2=0.7
    tf.reduce_sum(k2*k2)*weight_decay/2 = 0.1*(1*1+2*2+3*3+4*4)/2=1.5

     所以:

    l2_loss = tf.add_n([0.7, 1.5]) = 2.2

    tf.GraphKeys.UPDATE_OPS

    用来将一些在运行过程中需要更新,但是有不是随着梯度后向传播更新的参数添加到该collection中,然后用于更新参数

    比如AM_softmax实现中:

    def diam_softmax(prelogits, label, num_classes,
                        scale='auto', m=1.0, alpha=0.5, reuse=None):
        ''' Implementation of DIAM-Softmax, AM-Softmax with Dynamic Weight Imprinting (DWI), proposed in:
                Y. Shi and A. K. Jain. DocFace+: ID Document to Selfie Matching. arXiv:1809.05620, 2018.
            The weights in the DIAM-Softmax are dynamically updated using the mean features of training samples.
        '''
        num_features = prelogits.shape[1].value
        batch_size = tf.shape(prelogits)[0]
        with tf.variable_scope('AM-Softmax', reuse=reuse):
            weights = tf.get_variable('weights', shape=(num_classes, num_features),
                    initializer=slim.xavier_initializer(),
                    trainable=False,
                    dtype=tf.float32)
            _scale = tf.get_variable('_scale', shape=(),
                    regularizer=slim.l2_regularizer(1e-2),
                    initializer=tf.constant_initializer(0.0),
                    trainable=True,
                    dtype=tf.float32)
    
            # Normalizing the vecotors
            prelogits_normed = tf.nn.l2_normalize(prelogits, dim=1)
            weights_normed = tf.nn.l2_normalize(weights, dim=1)
    
            # Label and logits between batch and examplars
            label_mat_glob = tf.one_hot(label, num_classes, dtype=tf.float32)
            label_mask_pos_glob = tf.cast(label_mat_glob, tf.bool) #将0转成False,1转成True,只有对应的那个类是True,其他类对应为False
            label_mask_neg_glob = tf.logical_not(label_mask_pos_glob) #取反操作,其他类为True
    
            logits_glob = tf.matmul(prelogits_normed, tf.transpose(weights_normed))
            # logits_glob = -0.5 * euclidean_distance(prelogits_normed, tf.transpose(weights_normed))
            logits_pos_glob = tf.boolean_mask(logits_glob, label_mask_pos_glob)
            logits_neg_glob = tf.boolean_mask(logits_glob, label_mask_neg_glob)
    
            logits_pos = logits_pos_glob
            logits_neg = logits_neg_glob
    
            if scale == 'auto':
                # Automatic learned scale
                scale = tf.log(tf.exp(0.0) + tf.exp(_scale))
            else:
                # Assigned scale value
                assert type(scale) == float
    
            # Losses
            _logits_pos = tf.reshape(logits_pos, [batch_size, -1])
            _logits_neg = tf.reshape(logits_neg, [batch_size, -1])
    
            _logits_pos = _logits_pos * scale
            _logits_neg = _logits_neg * scale
            _logits_neg = tf.reduce_logsumexp(_logits_neg, axis=1)[:,None]
    
            loss_ = tf.nn.softplus(m + _logits_neg - _logits_pos)
            loss = tf.reduce_mean(loss_, name='diam_softmax')
    
            # Dynamic weight imprinting
            # We follow the CenterLoss to update the weights, which is equivalent to
            # imprinting the mean features
            # 如temp = [1,11,21,31,41,51,61], indice=[0,3,5]
            # tf.gather(temp, indice)返回[1,31,51]
            weights_batch = tf.gather(weights, label) #根据label指向的类去得到对应的weight
            diff_weights = weights_batch - prelogits_normed
            # x=[1, 1, 2, 4, 4, 4, 7, 8, 8]
            # y, idx, count = unique_with_counts(x)
            # y ==> [1, 2, 4, 7, 8]
            # idx ==> [0, 0, 1, 2, 2, 2, 3, 4, 4]
            # count ==> [2, 1, 3, 1, 2]
            unique_label, unique_idx, unique_count = tf.unique_with_counts(label)
            appear_times = tf.gather(unique_count, unique_idx)
            appear_times = tf.reshape(appear_times, [-1, 1])
            diff_weights = diff_weights / tf.cast(appear_times, tf.float32)
            diff_weights = alpha * diff_weights
            #将weights中特定位置label的数与diff_weights对应的值分别进行减法运算
            #没指定的位置值不变,来更新在mini_batch中出现的类的weight
            weights_update_op = tf.scatter_sub(weights, label, diff_weights)#将weights中特定位置label的数分别与diff_weights进行减法运算
            #tf.control_dependencies()此函数指定某些操作执行的依赖关系
            #即tf.assign()操作要在tf.control_dependencies([weights_update_op])指定的weights_update_op操作后才能执行
            with tf.control_dependencies([weights_update_op]): #weight权值通过weights_update_op操作更新后才执行下面的赋值操作
                # 之后sess.run(weights_update_op)后weights的值才变
                weights_update_op = tf.assign(weights, tf.nn.l2_normalize(weights,dim=1)) #将weights归一化后的值赋值给weights,返回的结果就是参数中归一化的weight
            tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, weights_update_op) #将元素weights_update_op添加到列表tf.GraphKeys.UPDATE_OPS中
            
            return loss

    这样后面通过:

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    sess.run([update_ops])

    进行更新

    tf.GraphKeys.TRAINABLE_VARIABLES

    由优化器Optimizer训练的变量对象的子集

    使用tf.get_variable()和tf.Variable()声明的变量(即trainable=True,默认就为True)就会默认加入该子集中

  • 相关阅读:
    关于js判断鼠标移入元素的方向--解释
    angularJs的学习笔记(一):angularJs的filter是根据value属性值来过滤的
    虚拟机设置网络连接
    [转载]23个经典JDK设计模式
    Ubuntu 17.04 开启 TCP BBR 拥塞控制算法
    解决DIGITALOCEAN后台被墙的两个方法
    远程访问服务器上的MySQL数据库,发现root远程连接不上
    jsp获取properties配置文件中的属性值
    去除底部“自豪地采用 WordPress”版权信息----最后附最新版的删除方法!!
    改91云linux服务器一键测试脚本(去除上传测试文件代码)
  • 原文地址:https://www.cnblogs.com/wanghui-garcia/p/13384518.html
Copyright © 2020-2023  润新知