• LEP+低秩+神经网络去噪


    from __future__ import print_function
    import matplotlib
    import matplotlib.pyplot as plt
    %matplotlib inline
    import scipy.misc
    import os
    import numpy as np
    
    from models.resnet import ResNet
    from models.unet import UNet
    from models.skip import skip
    import torch
    import torch.optim
    
    from utils.inpainting_utils import *
    # Let cuDNN benchmark and cache the fastest convolution algorithms.
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    # Tensor type every model/tensor is cast to via .type(dtype). Falls back to the
    # CPU type when no CUDA device is present; hard-coding the CUDA type made the
    # first tensor cast fail on CPU-only machines.
    dtype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor

    PLOT = True               # not used by the active code visible here
    imsize = -1               # passed to get_image(); presumably -1 keeps native size — confirm
    dim_div_by = 64           # not used by the active code visible here
    NET_TYPE = 'skip_depth6'  # not used by the active code visible here (net is built with skip())

    # Input folders: every file in iteation_LEP is processed, and LEP / ORI must each
    # contain a file with the same name (they are joined with the same img_name).
    iteation_LEP = '/home/hxj/桌面/PG/test/iteation+LEP/'
    LEP = '/home/hxj/桌面/PG/test/LEP-only/'
    ORI = '/home/hxj/gluon-tutorials/GAN/MultiPIE/YaleB_test_crop_gray/'
    img_name = 'yaleB38_P00A-130E+20.png'      # default; rebound by the per-file loop
    real_face_name = 'data/face/reSVD10.png'  # reference-face target shared by all images

    pad = 'reflection'   # padding mode passed to skip(); the author noted 'zero' as the alternative
    OPT_OVER = 'net'     # passed to get_params(); presumably selects which tensors are optimized
    OPTIMIZER = 'adam'   # passed to optimize()
    INPUT = 'noise'      # input kind passed to get_noise()
    input_depth = 32     # channel count of the random network input
    num_iter = 600       # passed to optimize()
    param_noise = False  # enables the weight-perturbation branch in closure()
    figsize = 5          # not used by the active code visible here
    reg_noise_std = 0.03  # scale of the Gaussian jitter added to the input each step
    LR = 0.01            # learning rate passed to optimize()
    mse = torch.nn.MSELoss().type(dtype)
    def closure():
        """Run one forward/backward pass and return the loss.

        Called with no arguments as the closure handed to
        ``optimize(OPTIMIZER, p, closure, LR, num_iter)``. It reads module-level
        state set up by the per-image loop: ``net``, ``net_input_saved``,
        ``noise``, ``img_name`` and the target tensors ``itLEP_var``,
        ``ORI_var``, ``LEP_var`` and ``RF_var``.

        Returns:
            The scalar loss tensor; ``backward()`` has already been called on
            it, so gradients are accumulated in ``net``'s parameters.
        """
        if param_noise:
            # Perturb every 4-D parameter (presumably conv kernels) in place.
            # Rebinding the loop variable (`n = n + ...`) would leave the
            # parameters unchanged, so use an in-place add outside autograd.
            with torch.no_grad():
                for n in [x for x in net.parameters() if len(x.size()) == 4]:
                    n.add_(n.detach().clone().normal_() * n.std() / 50)

        net_input = net_input_saved
        if reg_noise_std > 0:
            # Jitter the fixed input with fresh standard-normal noise scaled by
            # reg_noise_std; `noise` is a preallocated buffer refilled in place.
            net_input = net_input_saved + (noise.normal_() * reg_noise_std)

        out = net(net_input)

        # Weighted sum of MSE terms against the four targets
        # (weights 1.0 / 0.1 / 0.2 / 0.5).
        total_loss = (mse(out, itLEP_var)
                      + mse(out, ORI_var) * 0.1
                      + mse(out, LEP_var) * 0.2
                      + mse(out, RF_var) * 0.5)
        total_loss.backward()

        # '\r' with end='' rewrites the same console line on every call.
        # The label says "Iteration" but the value printed is the current image name.
        print('Iteration %s     Loss %f' % (img_name, total_loss.item()), '\r', end='')

        return total_loss
    # Reference-face image: one of the loss targets, shared by every processed file.
    RF_pil, RF_np = get_image(real_face_name, imsize)
    RF_var = np_to_torch(RF_np).type(dtype)

    # Results are written here; create the folder so the final save cannot fail
    # on a fresh checkout.
    out_dir = 'result/noise_input/0.01/'
    if not os.path.isdir(out_dir):
        os.makedirs(out_dir)

    # Every file in the iterated-LEP folder is processed; LEP and ORI must contain a
    # file with the same name. sorted() fixes the otherwise arbitrary listdir order.
    files = sorted(os.listdir(iteation_LEP))
    for img_name in files:
        # Per-image targets; closure() reads the *_var names as globals.
        itLEP_pil, itLEP_np = get_image(iteation_LEP + img_name, imsize)
        LEP_pil, LEP_np = get_image(LEP + img_name, imsize)
        ORI_pil, ORI_np = get_image(ORI + img_name, imsize)

        itLEP_var = np_to_torch(itLEP_np).type(dtype)
        LEP_var = np_to_torch(LEP_np).type(dtype)
        ORI_var = np_to_torch(ORI_np).type(dtype)

        # A freshly initialised network for each image. itLEP_np.shape[0] is passed
        # as skip()'s second argument — presumably the output channel count; confirm.
        net = skip(input_depth, itLEP_np.shape[0],
                   num_channels_down=[128] * 5,
                   num_channels_up=[128] * 5,
                   num_channels_skip=[128] * 5,
                   filter_size_up=3, filter_size_down=3,
                   upsample_mode='nearest', filter_skip_size=1,
                   need_sigmoid=True, need_bias=True, pad=pad,
                   act_fun='LeakyReLU').type(dtype)

        # Random network input with input_depth channels; itLEP_np.shape[1:] is
        # presumably the (H, W) spatial size.
        net_input = get_noise(input_depth, INPUT, itLEP_np.shape[1:]).type(dtype)
        net_input_saved = net_input.detach().clone()  # fixed input read by closure()
        noise = net_input.detach().clone()            # buffer refilled by closure()

        p = get_params(OPT_OVER, net, net_input)
        optimize(OPTIMIZER, p, closure, LR, num_iter)

        # Final prediction from the un-jittered input; no autograd graph needed.
        with torch.no_grad():
            out_np = torch_to_np(net(net_input))
        img_save = np.clip(out_np, 0, 1)[0]

        out_path = out_dir + img_name
        if hasattr(scipy.misc, 'toimage'):
            # Original writer: scales the [0, 1] range to 8-bit.
            scipy.misc.toimage(img_save, cmin=0.0, cmax=1.0).save(out_path)
        else:
            # scipy.misc.toimage was removed in SciPy 1.2; save with matplotlib
            # using the same [0, 1] intensity window.
            plt.imsave(out_path, img_save, cmap='gray', vmin=0.0, vmax=1.0)
        
  • 相关阅读:
    SSAS的维度表之间的关系只能有一个不能有多个
    SqlServer 在创建数据库时候指定的初始数据库大小是不能被收缩的
    SQL Server数据库的三种恢复模式:简单恢复模式、完整恢复模式和大容量日志恢复模式(转载)
    HttpHandler和ashx要实现IRequiresSessionState接口才能访问Session信息(转载)
    Jquery Ajax调用aspx页面方法 (转载)
    linux查找目录下的所有文件中是否含有某个字符串 <转>
    ubuntu下使用sdk manager 安装sdk 其他版本
    Ubuntu更新命令 <转>
    sudo: must be setuid root 解决方法 <转>
    Ubuntu 查看磁盘空间大小命令<转>
  • 原文地址:https://www.cnblogs.com/hxjbc/p/10817751.html
Copyright © 2020-2023  润新知