import tensorflow as tf


def initialize_uninitialized(sess, verbose=True):
    """Initialize only the global variables that are not yet initialized.

    Unlike running ``tf.global_variables_initializer()``, this leaves every
    variable that already holds a value (e.g. one loaded via
    ``saver.restore()``) untouched.

    Args:
        sess: the ``tf.Session`` whose variable state is inspected/updated.
        verbose: if True (the default, matching the original behavior), print
            the names of the variables about to be initialized.

    Returns:
        The list of variables that were initialized (empty if none were).
    """
    global_vars = tf.global_variables()
    # One boolean per variable: True means it ALREADY holds a value.
    is_initialized = sess.run([tf.is_variable_initialized(var) for var in global_vars])
    not_initialized_vars = [v for v, initialized in zip(global_vars, is_initialized)
                            if not initialized]
    if verbose:
        print([str(v.name) for v in not_initialized_vars])  # only for testing
    if not_initialized_vars:
        sess.run(tf.variables_initializer(not_initialized_vars))
    return not_initialized_vars
上述代码定义了一个函数，用于只初始化那些尚未被初始化的变量（已经通过 saver.restore() 加载过的变量不受影响）。
需要注意的是，我们通常使用 tf.global_variables_initializer() 作为初始化 op，但如果在 saver.restore() 之后运行它，会把通过 restore 加载的变量重新初始化，覆盖已加载的值，因此这里不能采用这种方法。
另外，如果在 saver.restore() 之前执行 sess.run(tf.global_variables_initializer())，被加载变量的初始化结果会被随后的 restore 覆盖：restore 的本质是用检查点中保存的值，对 saver 的 var_list 中的变量重新赋值。因此"先初始化、再恢复"的顺序本身是可行的，只有不在 saver 的 var_list 中的变量才会保留初始化得到的值。
选择性加载变量时，可以用 variable scope 把不同子网络的变量隔离开：去掉变量名中的 scope 前缀，构造 {name: var} 形式的字典，作为 tf.train.Saver 的 var_list，以此决定从检查点中加载哪些变量。代码如下：
# stage_merged.py
# transform from single frame into multi-frame enhanced single raw
#
# Trains the 'stage_merged' network to reproduce, directly from a packed raw
# patch, the output of two pretrained networks chained together
# ('enhance_raw' then 'raw2rgb'). Those two are restored from checkpoints and
# are not in the optimizer's var_list, so only 'stage_merged' is trained.
from __future__ import division

import glob
import os
import platform
import time

import numpy as np
import rawpy
import scipy.io  # unused below; kept from the original script
import tensorflow as tf
from PIL import Image

from model_sid_latest import network_stages_merged, network_my_unet, network_enhance_raw

on_windows = platform.system() == 'Windows'

if on_windows:
    data_dir = 'D:/data/Sony/dataset/bbf-raw-selected/'
elif platform.system() == 'Linux':
    data_dir = './dataset/bbf-raw-selected/'
else:
    # raise rather than assert: asserts are stripped under `python -O`
    raise RuntimeError('platform not supported!')

os.environ["CUDA_VISIBLE_DEVICES"] = "6"

checkpoint_dir = './model_stage_merged/'
result_dir = './out_stage_merged/'
log_dir = './log_stage_merged/'

learning_rate = 1e-4
epoch_bound = 20000
save_model_every_n_epoch = 10
output_every_n_steps = 1 if on_windows else 100

if on_windows:
    ckpt_enhance_raw = 'D:/model/enhance_raw/'
    ckpt_raw2rgb = 'D:/model/raw2rgb-c1/'
else:
    ckpt_enhance_raw = './model/enhance_raw/'
    ckpt_raw2rgb = './model/raw2rgb-c1/'

# BBF100-2
bbf_w = 4032
bbf_h = 3024
max_level = 1023  # white level used for normalization
black_level = 64
patch_w = 512
patch_h = 512

# set up dataset
input_files = sorted(glob.glob(data_dir + '/*.dng'))


def preprocess(raw, bl, wl):
    """Return the visible raw mosaic as float32, black-subtracted and divided by (wl - bl).

    Values below ``bl`` are clamped to 0; nothing is clipped at the top here.
    """
    im = raw.raw_image_visible.astype(np.float32)
    im = np.maximum(im - bl, 0)
    return im / (wl - bl)


def pack_raw_bbf(path):
    """Read the DNG at ``path`` and pack its 2x2 Bayer mosaic into 4 channels.

    Output channel order is R, G, B, G, where the first G comes from the
    mosaic's top row and the second from its bottom row. Each channel is
    multiplied by the camera's white-balance gains (normalized to green),
    clipped to 1.0, and stored as uint16 in [0, max_level - black_level].

    Returns:
        uint16 array of shape (H // 2, W // 2, 4) for an H x W mosaic.

    Raises:
        ValueError: if the CFA layout is not one of RGGB/BGGR/GRBG/GBRG as
            recognized below.
    """
    bl = black_level
    wl = max_level
    # Close the LibRaw handle as soon as the data we need has been copied out.
    with rawpy.imread(path) as raw:
        im = preprocess(raw, bl, wl)
        pattern = np.array(raw.raw_pattern)
        wb = np.array(raw.camera_whitebalance, dtype=np.float64)

    im = np.expand_dims(im, axis=2)
    H = im.shape[0]
    W = im.shape[1]
    if pattern[0, 0] == 0:  # CFA=RGGB
        out = np.concatenate((im[0:H:2, 0:W:2, :], im[0:H:2, 1:W:2, :],
                              im[1:H:2, 1:W:2, :], im[1:H:2, 0:W:2, :]), axis=2)
    elif pattern[0, 0] == 2:  # BGGR
        out = np.concatenate((im[1:H:2, 1:W:2, :], im[0:H:2, 1:W:2, :],
                              im[0:H:2, 0:W:2, :], im[1:H:2, 0:W:2, :]), axis=2)
    elif pattern[0, 0] == 1 and pattern[0, 1] == 0:  # GRBG
        out = np.concatenate((im[0:H:2, 1:W:2, :], im[0:H:2, 0:W:2, :],
                              im[1:H:2, 0:W:2, :], im[1:H:2, 1:W:2, :]), axis=2)
    elif pattern[0, 0] == 1 and pattern[0, 1] == 2:  # GBRG
        out = np.concatenate((im[1:H:2, 0:W:2, :], im[0:H:2, 0:W:2, :],
                              im[0:H:2, 1:W:2, :], im[1:H:2, 1:W:2, :]), axis=2)
    else:
        raise ValueError('unsupported CFA pattern %s in %s' % (pattern.tolist(), path))

    # Use the first green gain for the second green channel, then normalize to green.
    wb[3] = wb[1]
    wb = wb / wb[1]
    out = np.minimum(out * wb, 1.0)
    return np.uint16(out * (wl - bl))


tf.reset_default_graph()
gpu_options = tf.GPUOptions(allow_growth=True)
sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))

in_im = tf.placeholder(tf.float32, [1, patch_h, patch_w, 4], name='input')
with tf.variable_scope('enhance_raw', reuse=tf.AUTO_REUSE):
    enhanced_raw = network_enhance_raw(in_im, patch_h, patch_w)
with tf.variable_scope('raw2rgb', reuse=tf.AUTO_REUSE):
    gt_im = network_my_unet(enhanced_raw, patch_h, patch_w)
with tf.variable_scope('stage_merged', reuse=tf.AUTO_REUSE):
    out_im = network_stages_merged(in_im, patch_h, patch_w)

gt_im_cut = tf.minimum(tf.maximum(gt_im, 0.0), 1.0)
out_im_cut = tf.minimum(tf.maximum(out_im, 0.0), 1.0)
ssim_loss = 1 - tf.image.ssim_multiscale(gt_im_cut[0], out_im_cut[0], 1.0)
l1_loss = tf.reduce_mean(tf.reduce_sum(tf.abs(gt_im_cut - out_im_cut), axis=-1))
l2_loss = tf.reduce_mean(tf.reduce_sum(tf.square(gt_im_cut - out_im_cut), axis=-1))
G_loss = ssim_loss
# G_loss = l1_loss + l2_loss

tf.summary.scalar('G_loss', G_loss)
tf.summary.scalar('L1 Loss', l1_loss)
tf.summary.scalar('L2 Loss', l2_loss)


########## LOADING MODELS #############
def restore_scope(sess, scope, ckpt_dir):
    """Restore every global variable under ``scope/`` from the latest checkpoint in ``ckpt_dir``.

    The checkpoint names are expected to lack the ``scope/`` prefix, so each
    variable is registered in the Saver under its name with that leading
    prefix stripped.

    Raises:
        IOError: if ``ckpt_dir`` contains no checkpoint.
    """
    prefix = scope + '/'
    # Filter on 'scope/' so that e.g. a sibling 'raw2rgb_x' scope is not swept in.
    var_map = {v.op.name[len(prefix):]: v for v in tf.global_variables(prefix)}
    ckpt = tf.train.get_checkpoint_state(ckpt_dir)
    if not ckpt:
        raise IOError('Error: failed to load %s model!' % scope)
    tf.train.Saver(var_list=var_map).restore(sess, ckpt.model_checkpoint_path)
    print('loaded %s model: %s' % (scope, ckpt.model_checkpoint_path))


# Both restored networks produce the training target, so a missing
# checkpoint for either one is fatal.
restore_scope(sess, 'enhance_raw', ckpt_enhance_raw)
restore_scope(sess, 'raw2rgb', ckpt_raw2rgb)


def initialize_uninitialized(sess):
    """Initialize (and print) only the global variables that hold no value yet."""
    global_vars = tf.global_variables()
    bool_inits = sess.run([tf.is_variable_initialized(var) for var in global_vars])
    uninit_vars = [v for (v, b) in zip(global_vars, bool_inits) if not b]
    for v in uninit_vars:
        print(str(v.name))
    if uninit_vars:
        sess.run(tf.variables_initializer(uninit_vars))


# Only the 'stage_merged' variables are optimized and checkpointed.
t_vars = tf.trainable_variables(scope='stage_merged')
lr = tf.placeholder(tf.float32)
G_opt = tf.train.AdamOptimizer(learning_rate=lr).minimize(G_loss, var_list=t_vars)

saver = tf.train.Saver(var_list=t_vars)
ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
if ckpt:
    saver.restore(sess, ckpt.model_checkpoint_path)
    print('loaded ' + ckpt.model_checkpoint_path)
else:
    sess.run(tf.variables_initializer(var_list=t_vars))
# Picks up whatever is still unset (e.g. the Adam slot/beta variables).
initialize_uninitialized(sess)
#######################################

# tf.train.Saver.save refuses to write into a missing directory, so create
# checkpoint_dir as well as result_dir up front.
for out_dir in (result_dir, checkpoint_dir):
    if not os.path.isdir(out_dir):
        os.makedirs(out_dir)

input_images = [None] * len(input_files)
g_loss = np.zeros([500, 1])

merged = tf.summary.merge_all()
writer = tf.summary.FileWriter(log_dir, sess.graph)

steps = 0
st = time.time()
for epoch in range(0, epoch_bound):
    for ind in np.random.permutation(len(input_images)):
        steps += 1
        if input_images[ind] is None:
            input_images[ind] = np.expand_dims(pack_raw_bbf(input_files[ind]), axis=0)

        # random cropping; randint's upper bound is exclusive, hence the +1
        img_h, img_w = input_images[ind].shape[1:3]
        xx = np.random.randint(0, img_w - patch_w + 1)
        yy = np.random.randint(0, img_h - patch_h + 1)
        input_patch = np.float32(
            input_images[ind][:, yy:yy + patch_h, xx:xx + patch_w, :]) / (max_level - black_level)

        # random flipping; axis 0 is the batch axis (size 1), so the spatial axes are 1 and 2
        if np.random.randint(2, size=1)[0] == 1:  # vertical flip
            input_patch = np.flip(input_patch, axis=1)
        if np.random.randint(2, size=1)[0] == 1:  # horizontal flip
            input_patch = np.flip(input_patch, axis=2)
        if np.random.randint(2, size=1)[0] == 1:  # random transpose (requires patch_h == patch_w)
            input_patch = np.transpose(input_patch, (0, 2, 1, 3))

        summary, _, G_current, output, gt_im_ = sess.run(
            [merged, G_opt, G_loss, out_im_cut, gt_im_cut],
            feed_dict={in_im: input_patch, lr: learning_rate})
        g_loss[steps % len(g_loss)] = G_current

        if steps % output_every_n_steps == 0:
            loss_ = np.mean(g_loss[np.where(g_loss)])
            cost_ = (time.time() - st) / output_every_n_steps
            st = time.time()
            print("%d %d Loss=%.6f Speed=%.6f" % (epoch, steps, loss_, cost_))
            writer.add_summary(summary, global_step=steps)
            temp = np.concatenate(
                (input_patch[0, :, :, :3],
                 gt_im_[0, 0:patch_h * 2:2, 0:patch_w * 2:2, :3],
                 output[0, 0:patch_h * 2:2, 0:patch_w * 2:2, :3]), axis=1)
            # scipy.misc.toimage was removed in SciPy 1.2. This matches its
            # bytescale with cmin=0, cmax=255, low=0, high=255: clip, then round.
            Image.fromarray(np.uint8(np.clip(temp * 255, 0, 255) + 0.5)).save(
                result_dir + '/%d_%d.jpg' % (epoch, steps))

        # clean up the memory if necessary
        if on_windows:
            input_images[ind] = None

    if epoch % save_model_every_n_epoch == 0:
        saver.save(sess, checkpoint_dir + '%d.ckpt' % epoch)
        print('model saved.')