• TensorFlow 选择性初始化图中的变量


    import tensorflow as tf
    
    def initialize_uninitialized(sess):
        """Initialize only the global variables that do not have a value yet.

        Variables that already hold a value (e.g. loaded via saver.restore())
        are left untouched.

        Args:
            sess: the tf.Session in which to check and initialize variables.
        """
        global_vars = tf.global_variables()
        # One boolean per variable: True means the variable IS initialized.
        is_initialized = sess.run([tf.is_variable_initialized(var) for var in global_vars])
        not_initialized_vars = [v for v, ok in zip(global_vars, is_initialized) if not ok]

        print([str(i.name) for i in not_initialized_vars])  # only for testing
        if not_initialized_vars:
            sess.run(tf.variables_initializer(not_initialized_vars))

    上述函数用于初始化图中剩余的、尚未被初始化的变量。

    需要注意的是,我们一般使用的初始化op tf.global_variables_initializer() 会覆盖之前通过saver.restore()加载的变量值,因此在恢复了部分变量之后不可再采用此方法,而应只初始化那些尚未被初始化的变量。

    另外,如果在saver.restore()之前执行sess.run(tf.global_variables_initializer()),初始化看起来是“不起作用”的。原因在于restore本质上就是对变量执行赋值操作:它会用checkpoint中保存的值覆盖初始化得到的值。因此“先全局初始化、再restore”是可行的顺序,saver的var_list中的变量最终保留checkpoint中的值,未包含在var_list中的变量则保留初始化的值。

    选择性加载变量时,可以利用variable_scope进行隔离,提取出“checkpoint中的变量名: 变量”这样的键值对字典,作为Saver的var_list(即加载依据)。代码如下:

    # stage_merged.py
    # transform from single frame into multi-frame enhanced single raw
    from __future__ import division
    import os, time, scipy.io
    import tensorflow as tf
    import numpy as np
    import rawpy
    import glob
    from model_sid_latest import network_stages_merged, network_my_unet, network_enhance_raw
    import platform
    from PIL import Image
    
    # Platform-dependent paths and run settings; query the OS once.
    _system = platform.system()
    if _system == 'Windows':
        data_dir = 'D:/data/Sony/dataset/bbf-raw-selected/'
    elif _system == 'Linux':
        data_dir = './dataset/bbf-raw-selected/'
    else:
        print('platform not supported!')
        # Raise rather than `assert False`: asserts are stripped under
        # `python -O`, which would let the script run on with data_dir unset.
        raise RuntimeError('platform not supported!')

    # NOTE(review): hard-coded GPU index; overrides any value already set.
    os.environ["CUDA_VISIBLE_DEVICES"] = "6"
    checkpoint_dir = './model_stage_merged/'
    result_dir = './out_stage_merged/'
    log_dir = './log_stage_merged/'
    learning_rate = 1e-4
    epoch_bound = 20000
    save_model_every_n_epoch = 10

    # Windows settings: log every step and use local D:/ model paths.
    if _system == 'Windows':
        output_every_n_steps = 1
        ckpt_enhance_raw = 'D:/model/enhance_raw/'
        ckpt_raw2rgb = 'D:/model/raw2rgb-c1/'
    else:
        output_every_n_steps = 100
        ckpt_enhance_raw = './model/enhance_raw/'
        ckpt_raw2rgb = './model/raw2rgb-c1/'

    # BBF100-2 full sensor size in pixels.
    bbf_w = 4032
    bbf_h = 3024

    # Training patch size, in packed (half-resolution, 4-channel) pixels.
    patch_w = 512
    patch_h = 512

    # Sensor white level and black level used to normalize raw values.
    max_level = 1023
    black_level = 64

    # set up dataset
    input_files = sorted(glob.glob(data_dir + '/*.dng'))
    
    
    def preprocess(raw, bl, wl):
        """Black-level-subtract the visible raw image and normalize it.

        Pixels below bl are clamped to 0; the result is divided by (wl - bl).
        Values above wl are NOT clipped and may exceed 1.

        Args:
            raw: object exposing a `raw_image_visible` array.
            bl: black level.
            wl: white level.
        Returns:
            float32 array with the same shape as raw.raw_image_visible.
        """
        span = wl - bl
        pixels = raw.raw_image_visible.astype(np.float32)
        above_black = np.maximum(pixels - bl, 0)
        return above_black / span
    
    
    def cfa_plane_offsets(pattern):
        """Return (row, col) offsets of the R, G, B, G samples in a 2x2 CFA tile.

        Args:
            pattern: a 2x2 array like rawpy's `raw_pattern`.
        Raises:
            ValueError: for a CFA layout this script does not handle.
        """
        first, second = pattern[0, 0], pattern[0, 1]
        if first == 0:  # RGGB
            return ((0, 0), (0, 1), (1, 1), (1, 0))
        if first == 2:  # BGGR
            return ((1, 1), (0, 1), (0, 0), (1, 0))
        if first == 1 and second == 0:  # GRBG
            return ((0, 1), (0, 0), (1, 0), (1, 1))
        if first == 1 and second == 2:  # GBRG
            return ((1, 0), (0, 0), (0, 1), (1, 1))
        raise ValueError('unsupported CFA pattern: %r' % (pattern,))


    def pack_raw_bbf(path, bl=64, wl=1023):
        """Read a Bayer raw file and pack it into a white-balanced 4-channel array.

        The image is black-level subtracted and normalized (see preprocess),
        split into four half-resolution planes ordered R, G, B, G, multiplied by
        the camera white-balance gains (normalized so green is 1), clipped to 1,
        and finally scaled back by (wl - bl) and truncated to uint16.

        Args:
            path: path of the raw file.
            bl: black level (defaults match the module-level black_level).
            wl: white level (defaults match the module-level max_level).
        Returns:
            uint16 array of shape (H // 2, W // 2, 4).
        Raises:
            ValueError: if the sensor's CFA layout is not supported.
        """
        # Close the rawpy handle deterministically; everything needed after
        # this point is copied out (astype/np.array both make copies).
        with rawpy.imread(path) as raw:
            im = preprocess(raw, bl, wl)
            offsets = cfa_plane_offsets(raw.raw_pattern)
            wb = np.array(raw.camera_whitebalance)

        out = np.stack([im[dy::2, dx::2] for dy, dx in offsets], axis=2)
        wb[3] = wb[1]    # use the first green gain for the second green plane
        wb = wb / wb[1]  # normalize gains relative to green
        out = np.minimum(out * wb, 1.0)
        return (out * (wl - bl)).astype(np.uint16)
    
    
    # ---- Graph construction -------------------------------------------------
    tf.reset_default_graph()
    gpu_options = tf.GPUOptions(allow_growth=True)
    session_config = tf.ConfigProto(gpu_options=gpu_options)
    sess = tf.Session(config=session_config)
    # Batch of exactly one packed 4-channel raw patch.
    in_im = tf.placeholder(tf.float32, [1, patch_h, patch_w, 4], name='input')

    # Target pipeline: enhance_raw followed by raw2rgb produces gt_im.
    with tf.variable_scope('enhance_raw', reuse=tf.AUTO_REUSE):
        enhanced_raw = network_enhance_raw(in_im, patch_h, patch_w)
    with tf.variable_scope('raw2rgb', reuse=tf.AUTO_REUSE):
        gt_im = network_my_unet(enhanced_raw, patch_h, patch_w)
    # Merged network that maps the same input directly to the output.
    with tf.variable_scope('stage_merged', reuse=tf.AUTO_REUSE):
        out_im = network_stages_merged(in_im, patch_h, patch_w)


    def clip_to_unit(t):
        """Clamp tensor t elementwise into [0, 1]."""
        return tf.minimum(tf.maximum(t, 0.0), 1.0)


    gt_im_cut = clip_to_unit(gt_im)
    out_im_cut = clip_to_unit(out_im)
    # Multi-scale SSIM on the single image of the batch (max value 1.0).
    ssim_loss = 1 - tf.image.ssim_multiscale(gt_im_cut[0], out_im_cut[0], 1.0)
    # Channel-summed absolute / squared errors, averaged over pixels.
    l1_loss = tf.reduce_mean(tf.reduce_sum(tf.abs(gt_im_cut - out_im_cut), axis=-1))
    l2_loss = tf.reduce_mean(tf.reduce_sum(tf.square(gt_im_cut - out_im_cut), axis=-1))
    # Training objective (an alternative that was tried: l1_loss + l2_loss).
    G_loss = ssim_loss

    for summary_tag, summary_value in (('G_loss', G_loss),
                                       ('L1 Loss', l1_loss),
                                       ('L2 Loss', l2_loss)):
        tf.summary.scalar(summary_tag, summary_value)
    
    ########## LOADING MODELS #############
    def _scope_var_map(scope):
        """Build a {checkpoint_name: variable} dict for the globals under *scope*.

        Keys drop the leading '<scope>/' prefix and the trailing ':0' output
        suffix, so a variable 'enhance_raw/conv1/w:0' is looked up in the
        checkpoint as 'conv1/w'. Only a true prefix/suffix is stripped.
        """
        prefix = scope + '/'
        var_map = {}
        for var in tf.global_variables(scope):
            name = var.name
            if name.startswith(prefix):
                name = name[len(prefix):]
            if name.endswith(':0'):
                name = name[:-len(':0')]
            var_map[name] = var
        return var_map


    def _restore_scope(scope, ckpt_dir):
        """Restore the variables under *scope* from the latest checkpoint in *ckpt_dir*.

        Returns:
            (var_map, saver, checkpoint_state)
        Raises:
            RuntimeError: if no checkpoint is found. Both networks produce the
                training target (gt_im), so continuing without them would train
                against an uninitialized or randomly initialized target.
        """
        var_map = _scope_var_map(scope)
        scope_saver = tf.train.Saver(var_list=var_map)
        state = tf.train.get_checkpoint_state(ckpt_dir)
        if not state:
            msg = 'Error: failed to load %s model!' % scope
            print(msg)
            raise RuntimeError(msg)
        scope_saver.restore(sess, state.model_checkpoint_path)
        print('loaded %s model: %s' % (scope, state.model_checkpoint_path))
        return var_map, scope_saver, state


    enhance_raw_map, saver_enhance_raw, ckpt = _restore_scope('enhance_raw', ckpt_enhance_raw)
    #----------------------------------------
    raw2rgb_map, saver_raw2rgb, ckpt = _restore_scope('raw2rgb', ckpt_raw2rgb)
    #----------------------------------------
    
    
    def initialize_uninitialized(sess):
        """Run the initializer for every global variable that still has no value.

        The name of each variable about to be initialized is printed, one per
        line. Variables that already hold a value are left untouched.

        Args:
            sess: the tf.Session to query and initialize in.
        """
        candidates = tf.global_variables()
        flags = sess.run([tf.is_variable_initialized(v) for v in candidates])
        pending = [var for var, ready in zip(candidates, flags) if not ready]
        for var in pending:
            print(str(var.name))
        if not pending:
            return
        sess.run(tf.variables_initializer(pending))
    
    # Only the 'stage_merged' weights are optimized; the restored enhance_raw
    # and raw2rgb variables are excluded from var_list.
    t_vars = tf.trainable_variables(scope='stage_merged')
    lr = tf.placeholder(tf.float32)
    G_opt = tf.train.AdamOptimizer(learning_rate=lr).minimize(G_loss, var_list=t_vars)

    # The saver covers only t_vars. Adam's slot and beta-power variables are
    # not trainable, so they are neither saved nor restored here.
    saver = tf.train.Saver(var_list=t_vars)
    ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
    if ckpt:
        saver.restore(sess, ckpt.model_checkpoint_path)
        print('loaded ' + ckpt.model_checkpoint_path)
    else:
        sess.run(tf.variables_initializer(var_list=t_vars))
    # Must run in BOTH branches: after a restore, the optimizer's variables are
    # still uninitialized, and the first G_opt step would otherwise fail.
    initialize_uninitialized(sess)
    #######################################
    # Create output directories up front. TF1 Saver.save() will not create a
    # missing checkpoint directory, so the first save would fail without this.
    for _out_dir in (result_dir, checkpoint_dir):
        if not os.path.isdir(_out_dir):
            os.makedirs(_out_dir)

    # Lazily filled cache of packed images, one slot per input file.
    input_images = [None] * len(input_files)
    # Ring buffer of recent losses; the logging code averages only non-zero entries.
    g_loss = np.zeros([500, 1])

    merged = tf.summary.merge_all()
    writer = tf.summary.FileWriter(log_dir, sess.graph)

    steps = 0
    st = time.time()
    
    def random_patch(image, ph, pw, scale):
        """Draw one randomly cropped, flipped and transposed training patch.

        Args:
            image: (1, H, W, C) packed image (pack_raw_bbf + expand_dims).
            ph, pw: patch height and width.
            scale: divisor applied to the raw patch values.
        Returns:
            float32 (1, ph, pw, C) patch. Every valid crop corner (including
            the last row/column) can be drawn; bounds come from the image itself.
        """
        height, width = image.shape[1], image.shape[2]
        xx = np.random.randint(0, width - pw + 1)
        yy = np.random.randint(0, height - ph + 1)
        patch = np.float32(image[:, yy:yy + ph, xx:xx + pw, :]) / scale
        if np.random.randint(2, size=1)[0] == 1:  # vertical flip (height axis)
            patch = np.flip(patch, axis=1)
        if np.random.randint(2, size=1)[0] == 1:  # horizontal flip (width axis);
            patch = np.flip(patch, axis=2)        # axis=0 was the size-1 batch axis
        if np.random.randint(2, size=1)[0] == 1:  # random transpose (needs ph == pw)
            patch = np.transpose(patch, (0, 2, 1, 3))
        return patch


    def save_preview(path, input_patch, gt, output, ph, pw):
        """Save [input | target | output] side by side as an 8-bit image.

        gt/output are sampled every 2nd pixel to match the input patch size
        (presumably the networks output at 2x resolution -- confirm).
        Scaling matches the former scipy.misc.toimage(x * 255, cmin=0, cmax=255):
        scale by 255, clip to [0, 255], round to nearest.
        """
        strip = np.concatenate(
            (input_patch[0, :, :, :3],
             gt[0, 0:ph * 2:2, 0:pw * 2:2, :3],
             output[0, 0:ph * 2:2, 0:pw * 2:2, :3]), axis=1)
        pixels = (np.clip(strip * 255, 0, 255) + 0.5).astype(np.uint8)
        Image.fromarray(pixels).save(path)


    for epoch in range(0, epoch_bound):
        for ind in np.random.permutation(len(input_images)):
            steps += 1
            if input_images[ind] is None:
                input_images[ind] = np.expand_dims(pack_raw_bbf(input_files[ind]), axis=0)

            input_patch = random_patch(input_images[ind], patch_h, patch_w,
                                       max_level - black_level)

            summary, _, G_current, output, gt_im_ = sess.run(
                [merged, G_opt, G_loss, out_im_cut, gt_im_cut],
                feed_dict={
                    in_im: input_patch,
                    lr: learning_rate})
            g_loss[steps % len(g_loss)] = G_current

            if steps % output_every_n_steps == 0:
                loss_ = np.mean(g_loss[np.where(g_loss)])
                cost_ = (time.time() - st) / output_every_n_steps
                st = time.time()
                print("%d %d Loss=%.6f Speed=%.6f" % (epoch, steps, loss_, cost_))
                writer.add_summary(summary, global_step=steps)
                save_preview(result_dir + '/%d_%d.jpg' % (epoch, steps),
                             input_patch, gt_im_, output, patch_h, patch_w)

            # clean up the memory if necessary
            if platform.system() == 'Windows':
                input_images[ind] = None

        if epoch % save_model_every_n_epoch == 0:
            saver.save(sess, checkpoint_dir + '%d.ckpt' % epoch)
            print('model saved.')
  • 相关阅读:
    C++多态
    C++和C#实现剪切板数据交互
    通过CLR API实现C++调用C#代码交互
    COM方式实现C++调用C#代码的一些总结
    输入LPCWSTR类型字符串
    取得COM对象的UUID并以string输出
    springmvc xml文件配置中使用系统环境变量
    SpringMVC,SpringBoot上传文件简洁代码
    c语言实行泛型hashmap
    java使用nio(Paths,Files)遍历文件目录,转成java.io.File
  • 原文地址:https://www.cnblogs.com/thisisajoke/p/10407059.html
Copyright © 2020-2023  润新知