训练模型需要的数据文件有:
MNIST_data文件夹下的mnist_train、mnist_test、noisy_train、noisy_test四个文件夹。mnist_train文件夹下有60000张图片,mnist_test文件夹下有10000张图片
noisy_train、noisy_test下的图片是在原图上添加椒盐噪声得到的,并与原图按序号一一对应
离线测试需要的数据文件有:
MNIST_data文件夹下的my_model.hdf5和my_test文件夹。my_model.hdf5可由训练时保存的weights-improvement-*.hdf5重命名得到。my_test文件夹下需要再嵌套一层子文件夹,测试图片放在该子文件夹中(flow_from_directory按子文件夹读取图片)
数据集准备参考:
https://www.cnblogs.com/dzzy/p/10824072.html
训练:
"""Train a convolutional denoising autoencoder on MNIST PNG images.

Expects MNIST_data/{mnist_train,mnist_test,noisy_train,noisy_test}/*.png,
where each noisy image corresponds to the clean image of the same index.
"""
import os
import glob
from warnings import simplefilter

import numpy as np
from PIL import Image

# Silence FutureWarnings raised while importing keras and its backend;
# this must run before those imports.
simplefilter(action='ignore', category=FutureWarning)
import matplotlib.pyplot as plt
from keras.models import Model
from keras.layers import Input, Dense, Conv2D, MaxPooling2D, UpSampling2D
from keras.preprocessing.image import ImageDataGenerator, img_to_array, load_img
from keras.callbacks import TensorBoard, ModelCheckpoint

BASE_DIR = 'MNIST_data'  # base data directory
IMG_SIZE = 28  # images are reshaped to IMG_SIZE x IMG_SIZE x 1


def load_images(pattern):
    """Load every image matching ``pattern`` into a float32 array in [0, 1].

    Files are read in sorted filename order. ``glob.glob`` alone returns
    files in filesystem-dependent order, which would silently misalign the
    clean and noisy sets that are paired by index. Assumes the clean and
    noisy folders use matching file names (TODO confirm for your dataset).

    Returns:
        Array of shape (num_files, IMG_SIZE, IMG_SIZE, 1).

    Raises:
        FileNotFoundError: if no file matches ``pattern``.
    """
    files = sorted(glob.glob(pattern))
    if not files:
        raise FileNotFoundError('no images match {!r}'.format(pattern))
    images = np.zeros((len(files), IMG_SIZE, IMG_SIZE, 1), dtype='float32')
    for idx, path in enumerate(files):
        # `with` closes the file handle; convert('L') guarantees one channel.
        with Image.open(path) as img:
            pixels = np.asarray(img.convert('L'), dtype='float32')
        images[idx] = pixels.reshape(IMG_SIZE, IMG_SIZE, 1) / 255.
    return images


def _hide_axes(ax):
    """Hide both axes of a matplotlib subplot."""
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)


def build_autoencoder():
    """Build and compile the convolutional autoencoder (28x28x1 -> 28x28x1)."""
    input_img = Input(shape=(IMG_SIZE, IMG_SIZE, 1))  # channels_last layout
    x = Conv2D(32, (3, 3), activation='relu', padding='same')(input_img)
    x = MaxPooling2D((2, 2), padding='same')(x)
    x = Conv2D(32, (3, 3), activation='relu', padding='same')(x)
    encoded = MaxPooling2D((2, 2), padding='same')(x)
    # at this point the representation is (7, 7, 32)
    x = Conv2D(32, (3, 3), activation='relu', padding='same')(encoded)
    x = UpSampling2D((2, 2))(x)
    x = Conv2D(32, (3, 3), activation='relu', padding='same')(x)
    x = UpSampling2D((2, 2))(x)
    # sigmoid keeps outputs in [0, 1], matching the binary_crossentropy target
    decoded = Conv2D(1, (3, 3), activation='sigmoid', padding='same')(x)
    autoencoder = Model(input_img, decoded)
    autoencoder.compile(optimizer='adadelta', loss='binary_crossentropy')
    return autoencoder


def show_results(x_noisy, x_denoised, x_clean, n):
    """Plot ``n`` samples in three rows: noisy input, model output, original."""
    plt.figure(figsize=(20, 4))
    for row, images in enumerate((x_noisy, x_denoised, x_clean)):
        for i in range(n):
            ax = plt.subplot(3, n, i + 1 + row * n)
            plt.imshow(images[i].reshape(IMG_SIZE, IMG_SIZE), cmap='gray')
            _hide_axes(ax)
    plt.show()


def main():
    """Load the data, train with checkpointing, then plot sample results."""
    print("_________________________keras start_____________________________")
    x_train = load_images(os.path.join(BASE_DIR, 'mnist_train/*.png'))
    x_test = load_images(os.path.join(BASE_DIR, 'mnist_test/*.png'))
    print(x_train.shape)
    print(x_test.shape)
    x_train_noisy = load_images(os.path.join(BASE_DIR, 'noisy_train/*.png'))
    x_test_noisy = load_images(os.path.join(BASE_DIR, 'noisy_test/*.png'))
    print(x_train_noisy.shape)
    print(x_test_noisy.shape)
    # Noisy/clean pairs are matched by index, so the counts must agree.
    if len(x_train_noisy) != len(x_train) or len(x_test_noisy) != len(x_test):
        raise ValueError('noisy and clean image counts differ: '
                         'train {} vs {}, test {} vs {}'.format(
                             len(x_train_noisy), len(x_train),
                             len(x_test_noisy), len(x_test)))

    autoencoder = build_autoencoder()
    file_path = "MNIST_data/weights-improvement-{epoch:02d}-{val_loss:.2f}.hdf5"
    tensorboard = TensorBoard(log_dir='/tmp/tb', histogram_freq=0,
                              write_graph=False)
    # Save the full model whenever val_loss improves.
    checkpoint = ModelCheckpoint(filepath=file_path, verbose=1,
                                 monitor='val_loss', save_weights_only=False,
                                 mode='auto', save_best_only=True, period=1)
    autoencoder.fit(x_train_noisy, x_train,
                    epochs=100,
                    batch_size=128,
                    shuffle=True,
                    validation_data=(x_test_noisy, x_test),
                    callbacks=[checkpoint, tensorboard])

    n = min(10, len(x_test))
    # Predict once, and only for the samples that are actually plotted.
    decoded_imgs = autoencoder.predict(x_test_noisy[:n])
    show_results(x_test_noisy, decoded_imgs, x_test, n)


if __name__ == "__main__":
    main()
测试:(同 https://www.cnblogs.com/dzzy/p/11387645.html)
"""Denoise images from MNIST_data/my_test with a saved autoencoder and plot them.

MNIST_data/my_test must contain one sub-folder holding the test images,
because flow_from_directory reads images per sub-folder.
"""
import os
import numpy as np
from warnings import simplefilter

# Silence FutureWarnings raised while importing keras; must precede those imports.
simplefilter(action='ignore', category=FutureWarning)
import matplotlib.pyplot as plt
from keras.models import Model, Sequential, load_model
from keras.layers import Input, Dense, Conv2D, MaxPooling2D, UpSampling2D
from keras.preprocessing.image import ImageDataGenerator, img_to_array, load_img
from keras.callbacks import TensorBoard, ModelCheckpoint

print("_________________________keras start_____________________________")

pic_num = 3  # maximum number of test images to denoise and display
base_dir = 'MNIST_data'  # base directory
test_dir = os.path.join(base_dir, 'my_test')  # test image directory

# Rescale pixels to [0, 1], the same range the training script used.
test_datagen = ImageDataGenerator(rescale=1. / 255)
test_generator = test_datagen.flow_from_directory(test_dir,
                                                  target_size=(28, 28),
                                                  color_mode="grayscale",
                                                  batch_size=pic_num,
                                                  class_mode="categorical")
if test_generator.samples == 0:
    raise FileNotFoundError(
        'no images found in a sub-folder of {}'.format(test_dir))

# The directory iterator cycles forever; take exactly one batch.
# The class labels are not needed for denoising.
x_test, _labels = next(test_generator)
x_test = np.reshape(x_test, (len(x_test), 28, 28, 1))

# load_model restores both architecture and weights; re-compiling is not
# required just to call predict().
model = load_model('MNIST_data/my_model.hdf5')
print("Created model and loaded weights from file")

denoised = model.predict(x_test)

# The batch holds fewer than pic_num images when the folder is smaller,
# so plot however many were actually loaded.
n = len(x_test)
for i in range(n):
    # top row: input image
    ax = plt.subplot(2, n, i + 1)
    plt.imshow(x_test[i].reshape(28, 28), cmap='gray')
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    # bottom row: model output
    ax = plt.subplot(2, n, i + 1 + n)
    plt.imshow(denoised[i].reshape(28, 28), cmap='gray')
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
plt.show()