• Keras学习笔记三:一个图像去噪训练并离线测试的例子,基于mnist


    训练模型需要的数据文件有:

    MNIST_data文件夹下的mnist_train、mnist_test、noisy_train、noisy_test四个文件夹。其中mnist_train（以及noisy_train）下有60000张图片，mnist_test（以及noisy_test）下有10000张图片。

    noisy_train、noisy_test下的图片是对原图加了椒盐噪声后得到的，文件序号与原图一一对应（训练时正是靠这一对应关系把噪声图与干净图配对）。

    离线测试需要的数据文件有:

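    MNIST_data文件夹下的my_model.hdf5和my_test文件夹。my_test下必须再嵌套一层子文件夹，测试图片放在这个子文件夹里（flow_from_directory会把每个子文件夹当作一个类别，直接放在my_test下的图片不会被读取）。注意训练脚本不会直接生成my_model.hdf5：ModelCheckpoint保存的文件名是weights-improvement-{epoch}-{val_loss}.hdf5，需要选一个（由于save_best_only=True，最后保存的那个val_loss最低）重命名为my_model.hdf5。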

    数据集准备参考:

    https://www.cnblogs.com/dzzy/p/10824072.html

    训练:

    import os
    import glob
    from PIL import Image
    import numpy as np
    from warnings import simplefilter
    simplefilter(action='ignore', category=FutureWarning)
    import matplotlib.pyplot as plt
    from keras.models import Model
    from keras.layers import Input, Dense, Conv2D, MaxPooling2D, UpSampling2D
    from keras.preprocessing.image import ImageDataGenerator,img_to_array, load_img
    from keras.callbacks import TensorBoard , ModelCheckpoint
    print("_________________________keras start_____________________________")


    def _load_gray_images(pattern):
        """Load every 28x28 grayscale image matching *pattern* into one array.

        Files are read in sorted path order. ``glob.glob`` makes no ordering
        guarantee, so without sorting the clean and noisy directories could be
        returned in different orders and the (noisy, clean) training pairs
        would be silently mis-aligned.

        Returns a float32 array of shape (num_files, 28, 28, 1) scaled to [0, 1].
        Raises FileNotFoundError when nothing matches *pattern*.
        """
        files = sorted(glob.glob(pattern))
        if not files:
            raise FileNotFoundError("no images match %s" % pattern)
        images = np.empty((len(files), 28, 28, 1), dtype='float32')
        for idx, image_file in enumerate(files):
            # convert('L') forces a single grayscale channel so the reshape
            # below also works for PNGs stored as RGB or palette images.
            with Image.open(image_file) as im:
                pixels = np.array(im.convert('L'))
            images[idx] = pixels.reshape(28, 28, 1).astype('float32') / 255.
        return images


    base_dir = 'MNIST_data'  # base directory holding all dataset folders

    # Clean target images (the dataset-prep step produces 60000 train / 10000 test).
    x_train = _load_gray_images(os.path.join(base_dir, 'mnist_train/*.png'))
    x_test = _load_gray_images(os.path.join(base_dir, 'mnist_test/*.png'))
    print(x_train.shape)
    print(x_test.shape)

    # Salt-and-pepper noisy inputs; each must pair 1:1 with a clean image.
    # NOTE(review): pairing relies on the noisy files having the same names /
    # numbering as the clean ones so that sorted order lines up -- confirm
    # against the dataset-prep script.
    x_train_noisy = _load_gray_images(os.path.join(base_dir, 'noisy_train/*.png'))
    x_test_noisy = _load_gray_images(os.path.join(base_dir, 'noisy_test/*.png'))
    if x_train_noisy.shape != x_train.shape or x_test_noisy.shape != x_test.shape:
        raise ValueError("noisy/clean image counts differ: train %s vs %s, test %s vs %s"
                         % (x_train_noisy.shape, x_train.shape,
                            x_test_noisy.shape, x_test.shape))
    print(x_train_noisy.shape)
    print(x_test_noisy.shape)
    
    # Optional sanity check: uncomment to view two clean/noisy training pairs.
    # plt.figure(figsize=(20, 4))
    # plt.subplot(4, 4, 1)
    # plt.imshow(x_train[0].reshape(28, 28))
    # plt.subplot(4, 4, 2)
    # plt.imshow(x_train_noisy[0].reshape(28, 28))
    # plt.subplot(4, 4, 3)
    # plt.imshow(x_train[1].reshape(28, 28))
    # plt.subplot(4, 4, 4)
    # plt.imshow(x_train_noisy[1].reshape(28, 28))
    # plt.show()

    # Build the convolutional denoising autoencoder.
    input_img = Input(shape=(28, 28, 1))  # adapt this if using `channels_first` image data format

    # Encoder: 28x28 -> 14x14 -> 7x7 with 32 feature maps ('same' padding).
    x = Conv2D(32, (3, 3), activation='relu', padding='same')(input_img)
    x = MaxPooling2D((2, 2), padding='same')(x)
    x = Conv2D(32, (3, 3), activation='relu', padding='same')(x)
    encoded = MaxPooling2D((2, 2), padding='same')(x)

    # at this point the representation is (7, 7, 32)

    # Decoder mirrors the encoder back to 28x28; the sigmoid keeps outputs in
    # [0, 1], matching the /255-scaled targets.
    x = Conv2D(32, (3, 3), activation='relu', padding='same')(encoded)
    x = UpSampling2D((2, 2))(x)
    x = Conv2D(32, (3, 3), activation='relu', padding='same')(x)
    x = UpSampling2D((2, 2))(x)
    decoded = Conv2D(1, (3, 3), activation='sigmoid', padding='same')(x)

    autoencoder = Model(input_img, decoded)
    # Binary cross-entropy is usable because targets are pixel intensities in [0, 1].
    # NOTE(review): tf.keras' Adadelta defaults to learning_rate=0.001 (standalone
    # Keras 2.x used 1.0); if training barely progresses, set the rate explicitly.
    autoencoder.compile(optimizer='adadelta', loss='binary_crossentropy')

    # Save the full model whenever val_loss improves. The offline test script
    # loads MNIST_data/my_model.hdf5, so rename the chosen checkpoint afterwards.
    file_path = os.path.join(base_dir, 'weights-improvement-{epoch:02d}-{val_loss:.2f}.hdf5')
    tensorboard = TensorBoard(log_dir='/tmp/tb', histogram_freq=0, write_graph=False)
    # `period` is deprecated in newer Keras releases (replaced by `save_freq`).
    checkpoint = ModelCheckpoint(filepath=file_path, verbose=1, monitor='val_loss',
                                 save_weights_only=False, mode='auto',
                                 save_best_only=True, period=1)
    # Train noisy -> clean; the noisy test set doubles as the validation set.
    autoencoder.fit(x_train_noisy, x_train,
                    epochs=100,
                    batch_size=128,
                    shuffle=True,
                    validation_data=(x_test_noisy, x_test),
                    callbacks=[checkpoint, tensorboard])
                    
    # Show results: row 1 noisy input, row 2 denoised output, row 3 clean target.
    n = min(10, len(x_test_noisy))
    # Run inference once, and only on the samples that are displayed, instead
    # of predicting the whole test set for every subplot.
    decoded_imgs = autoencoder.predict(x_test_noisy[:n])
    rows = (x_test_noisy, decoded_imgs, x_test)

    plt.figure(figsize=(20, 4))
    for i in range(n):
        for row, images in enumerate(rows):
            ax = plt.subplot(3, n, i + 1 + row * n)
            plt.imshow(images[i].reshape(28, 28))
            plt.gray()
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)
    plt.show()

    测试（同 https://www.cnblogs.com/dzzy/p/11387645.html）：

    import os
    import numpy as np
    from warnings import simplefilter
    simplefilter(action='ignore', category=FutureWarning)
    import matplotlib.pyplot as plt
    from keras.models import Model,Sequential,load_model
    from keras.layers import Input, Dense, Conv2D, MaxPooling2D, UpSampling2D
    from keras.preprocessing.image import ImageDataGenerator,img_to_array, load_img
    from keras.callbacks import TensorBoard , ModelCheckpoint
    print("_________________________keras start_____________________________")
    pic_num = 3  # maximum number of test images to denoise and display
    base_dir = 'MNIST_data'  # base directory
    # flow_from_directory treats every sub-folder as a class, so the test
    # images must sit one level below my_test (e.g. my_test/<any_name>/*.png).
    test_dir = os.path.join(base_dir, 'my_test')
    test_datagen = ImageDataGenerator(rescale=1. / 255)  # scale to [0, 1] as in training
    test_generator = test_datagen.flow_from_directory(test_dir,
                                                      target_size=(28, 28),
                                                      color_mode="grayscale",
                                                      batch_size=pic_num,
                                                      class_mode="categorical")
    if test_generator.samples == 0:
        raise FileNotFoundError("no test images found under %s "
                                "(they must be inside a sub-folder)" % test_dir)
    # One batch is all we need; the class labels are not used.
    noisy_imgs, _ = next(test_generator)
    noisy_imgs = np.reshape(noisy_imgs, (len(noisy_imgs), 28, 28, 1))

    # Load the trained autoencoder. compile=False: only inference is done here,
    # so no optimizer/loss/metrics need to be restored or configured.
    model = load_model(os.path.join(base_dir, 'my_model.hdf5'), compile=False)
    print("Created model and loaded weights from file")

    denoised_imgs = model.predict(noisy_imgs)

    # If my_test holds fewer than pic_num images the batch is smaller; plot
    # only what was actually loaded. Row 1: input, row 2: denoised output.
    n = len(noisy_imgs)
    for i in range(n):
        ax = plt.subplot(2, n, i + 1)
        plt.imshow(noisy_imgs[i].reshape(28, 28))
        plt.gray()
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        ax = plt.subplot(2, n, i + 1 + n)
        plt.imshow(denoised_imgs[i].reshape(28, 28))
        plt.gray()
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
    plt.show()
  • 相关阅读:
    javascript简繁转换函数
    在嵌套的repeater中加ItemDataBound事件
    asp.net url重写方法和步骤
    打开,另存为,属性,打印"等14个JS代码
    php中global的用法
    ini_get
    PHP学习笔记
    PHP isset()与empty()的使用区别详解
    PHP符号说明
    html禁止清除input文本输入缓存
  • 原文地址:https://www.cnblogs.com/dzzy/p/11393420.html
Copyright © 2020-2023  润新知