• Unet project: notes on selected parts of the code


    GitHub repository: https://github.com/orobix/retina-unet

    Main script:

    ###################################################
    #
    #   Script to:
    #   - Load the images and extract the patches
    #   - Define the neural network
    #   - Define the training
    #
    ##################################################
    
    
    import numpy as np
    import configparser as ConfigParser
    
    from keras.models import Model
    from keras.layers import Input, concatenate, Conv2D, MaxPooling2D, UpSampling2D, Reshape, core, Dropout
    from keras.optimizers import Adam
    from keras.callbacks import ModelCheckpoint, LearningRateScheduler
    from keras import backend as K
    from keras.utils.vis_utils import plot_model as plot
    from keras.optimizers import SGD
    
    import sys
    sys.path.insert(0, './lib/')
    from help_functions import *
    
    #function to obtain data for training/testing (validation)
    from extract_patches import get_data_training
    
    
    
    #Define the neural network
    def get_unet(n_ch,patch_height,patch_width):
        inputs = Input(shape=(n_ch,patch_height,patch_width))
        conv1 = Conv2D(32, (3, 3), activation='relu', padding='same',data_format='channels_first')(inputs)
        conv1 = Dropout(0.2)(conv1)
        conv1 = Conv2D(32, (3, 3), activation='relu', padding='same',data_format='channels_first')(conv1)
        pool1 = MaxPooling2D((2, 2), data_format='channels_first')(conv1)
        #
        conv2 = Conv2D(64, (3, 3), activation='relu', padding='same',data_format='channels_first')(pool1)
        conv2 = Dropout(0.2)(conv2)
        conv2 = Conv2D(64, (3, 3), activation='relu', padding='same',data_format='channels_first')(conv2)
        pool2 = MaxPooling2D((2, 2), data_format='channels_first')(conv2)
        #
        conv3 = Conv2D(128, (3, 3), activation='relu', padding='same',data_format='channels_first')(pool2)
        conv3 = Dropout(0.2)(conv3)
        conv3 = Conv2D(128, (3, 3), activation='relu', padding='same',data_format='channels_first')(conv3)
    
        up1 = UpSampling2D(size=(2, 2), data_format='channels_first')(conv3)
        up1 = concatenate([conv2,up1],axis=1)
        conv4 = Conv2D(64, (3, 3), activation='relu', padding='same',data_format='channels_first')(up1)
        conv4 = Dropout(0.2)(conv4)
        conv4 = Conv2D(64, (3, 3), activation='relu', padding='same',data_format='channels_first')(conv4)
        #
        up2 = UpSampling2D(size=(2, 2), data_format='channels_first')(conv4)
        up2 = concatenate([conv1,up2], axis=1)
        conv5 = Conv2D(32, (3, 3), activation='relu', padding='same',data_format='channels_first')(up2)
        conv5 = Dropout(0.2)(conv5)
        conv5 = Conv2D(32, (3, 3), activation='relu', padding='same',data_format='channels_first')(conv5)
        #
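        # Classification head: a 1x1 conv maps the features to 2 class scores per
        # pixel; Reshape flattens the spatial dims to (2, H*W), Permute swaps this
        # to (H*W, 2), and the final softmax acts per pixel over the 2 classes.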
        conv6 = Conv2D(2, (1, 1), activation='relu',padding='same',data_format='channels_first')(conv5)
        conv6 = core.Reshape((2,patch_height*patch_width))(conv6)
        conv6 = core.Permute((2,1))(conv6)
        ############
        conv7 = core.Activation('softmax')(conv6)
    
        model = Model(inputs=inputs, outputs=conv7)
    
        # sgd = SGD(lr=0.01, decay=1e-6, momentum=0.3, nesterov=False)
        model.compile(optimizer='sgd', loss='categorical_crossentropy',metrics=['accuracy'])
    
        return model
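
    # Quick shape sanity check (a sketch; the 1-channel 48x48 patch size is an
    # assumption taken from the repo's default configuration.txt):
    # m = get_unet(1, 48, 48)
    # print(m.output_shape)   # expected: (None, 2304, 2), i.e. (None, 48*48, 2)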
    
    #Define the neural network gnet
    #to use this network, change the get_unet call in the model-construction section below to get_gnet
    def get_gnet(n_ch,patch_height,patch_width):
        inputs = Input(shape=(n_ch, patch_height, patch_width))
        conv1 = Conv2D(32, (3, 3), activation='relu', padding='same', data_format='channels_first')(inputs)
        conv1 = Dropout(0.2)(conv1)
        conv1 = Conv2D(32, (3, 3), activation='relu', padding='same', data_format='channels_first')(conv1)
        up1 = UpSampling2D(size=(2, 2), data_format='channels_first')(conv1)
        #
        conv2 = Conv2D(16, (3, 3), activation='relu', padding='same', data_format='channels_first')(up1)
        conv2 = Dropout(0.2)(conv2)
        conv2 = Conv2D(16, (3, 3), activation='relu', padding='same', data_format='channels_first')(conv2)
        pool1 = MaxPooling2D(pool_size=(2, 2), data_format='channels_first')(conv2)
        #
        conv3 = Conv2D(32, (3, 3), activation='relu', padding='same', data_format='channels_first')(pool1)
        conv3 = Dropout(0.2)(conv3)
        conv3 = Conv2D(32, (3, 3), activation='relu', padding='same', data_format='channels_first')(conv3)
        pool2 = MaxPooling2D(pool_size=(2, 2), data_format='channels_first')(conv3)
        #
        conv4 = Conv2D(64, (3, 3), activation='relu', padding='same', data_format='channels_first')(pool2)
        conv4 = Dropout(0.2)(conv4)
        conv4 = Conv2D(64, (3, 3), activation='relu', padding='same', data_format='channels_first')(conv4)
        pool3 = MaxPooling2D(pool_size=(2, 2), data_format='channels_first')(conv4)
        #
        conv5 = Conv2D(128, (3, 3), activation='relu', padding='same', data_format='channels_first')(pool3)
        conv5 = Dropout(0.2)(conv5)
        conv5 = Conv2D(128, (3, 3), activation='relu', padding='same', data_format='channels_first')(conv5)
        #
        up2 = concatenate([UpSampling2D(size=(2, 2), data_format='channels_first')(conv5), conv4], axis=1)
        conv6 = Conv2D(64, (3, 3), activation='relu', padding='same', data_format='channels_first')(up2)
        conv6 = Dropout(0.2)(conv6)
        conv6 = Conv2D(64, (3, 3), activation='relu', padding='same', data_format='channels_first')(conv6)
        #
        up3 = concatenate([UpSampling2D(size=(2, 2), data_format='channels_first')(conv6), conv3], axis=1)
        conv7 = Conv2D(32, (3, 3), activation='relu', padding='same', data_format='channels_first')(up3)
        conv7 = Dropout(0.2)(conv7)
        conv7 = Conv2D(32, (3, 3), activation='relu', padding='same', data_format='channels_first')(conv7)
        #
        up4 = concatenate([UpSampling2D(size=(2, 2), data_format='channels_first')(conv7), conv2], axis=1)
        conv8 = Conv2D(16, (3, 3), activation='relu', padding='same', data_format='channels_first')(up4)
        conv8 = Dropout(0.2)(conv8)
        conv8 = Conv2D(16, (3, 3), activation='relu', padding='same', data_format='channels_first')(conv8)
        #
        pool4 = MaxPooling2D(pool_size=(2, 2), data_format='channels_first')(conv8)
        conv9 = Conv2D(32, (3, 3), activation='relu', padding='same', data_format='channels_first')(pool4)
        conv9 = Dropout(0.2)(conv9)
        conv9 = Conv2D(32, (3, 3), activation='relu', padding='same', data_format='channels_first')(conv9)
        #
        conv10 = Conv2D(2, (1, 1), activation='relu', padding='same', data_format='channels_first')(conv9)
        conv10 = core.Reshape((2,patch_height*patch_width))(conv10)
        conv10 = core.Permute((2,1))(conv10)
        ############
        conv10 = core.Activation('softmax')(conv10)
    
        model = Model(inputs=inputs, outputs=conv10)
    
        # sgd = SGD(lr=0.01, decay=1e-6, momentum=0.3, nesterov=False)
        model.compile(optimizer='sgd', loss='categorical_crossentropy',metrics=['accuracy'])
    
        return model
    
    #========= Load settings from Config file
    config = ConfigParser.RawConfigParser()
    config.read('configuration.txt')
    #path to the datasets
    path_data = config.get('data paths', 'path_local')
    #Experiment name
    name_experiment = config.get('experiment name', 'name')
    #training settings
    N_epochs = int(config.get('training settings', 'N_epochs'))
    batch_size = int(config.get('training settings', 'batch_size'))
    
    
    
    #============ Load the data and divide it into patches
    patches_imgs_train, patches_masks_train = get_data_training(
        DRIVE_train_imgs_original = path_data + config.get('data paths', 'train_imgs_original'),
        DRIVE_train_groudTruth = path_data + config.get('data paths', 'train_groundTruth'),  #masks
        patch_height = int(config.get('data attributes', 'patch_height')),
        patch_width = int(config.get('data attributes', 'patch_width')),
        N_subimgs = int(config.get('training settings', 'N_subimgs')),
        inside_FOV = config.getboolean('training settings', 'inside_FOV') #select the patches only inside the FOV  (default == True)
    )
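    #the returned arrays are 4D channels-first tensors; both images and masks are
    #expected to have shape (N_subimgs, 1, patch_height, patch_width) here, which
    #is what the shape lookups below rely on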
    
    
    #========= Save a sample of what you're feeding to the neural network ==========
    N_sample = min(patches_imgs_train.shape[0],40)#show at most 40 sample patches
    visualize(group_images(patches_imgs_train[0:N_sample,:,:,:],5),'./'+name_experiment+'/'+"sample_input_imgs")#.show()
    visualize(group_images(patches_masks_train[0:N_sample,:,:,:],5),'./'+name_experiment+'/'+"sample_input_masks")#.show()
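    #group_images and visualize come from help_functions: group_images appears to
    #tile the patches into a grid (5 per row) and visualize to write the result to
    #disk as an image (an assumption based on the helper names and call sites)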
    #the resulting sample images are shown at the end of this post
    
    
    #=========== Construct and save the model architecture =====
    n_ch = patches_imgs_train.shape[1]#number of channels per patch
    patch_height = patches_imgs_train.shape[2]#height of each patch
    patch_width = patches_imgs_train.shape[3]#width of each patch
    model = get_unet(n_ch, patch_height, patch_width)  #the U-net model
    print ("Check: final output of the network:")
    print (model.output_shape)
    plot(model, to_file='./'+name_experiment+'/'+name_experiment + '_model.png')   #check what the model looks like
    json_string = model.to_json()#to_json returns a JSON string with the architecture only (no weights); the model can be rebuilt from it
    open('./'+name_experiment+'/'+name_experiment +'_architecture.json', 'w').write(json_string)
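
    # The saved JSON + weights can later be restored for inference; a minimal
    # sketch using the standard Keras API (patches_imgs_test is hypothetical here):
    # from keras.models import model_from_json
    # model = model_from_json(open('./'+name_experiment+'/'+name_experiment+'_architecture.json').read())
    # model.load_weights('./'+name_experiment+'/'+name_experiment+'_best_weights.h5')
    # predictions = model.predict(patches_imgs_test, batch_size=32)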
    
    
    
    #============  Training ==================================
    checkpointer = ModelCheckpoint(filepath='./'+name_experiment+'/'+name_experiment +'_best_weights.h5', verbose=1, monitor='val_loss', mode='auto', save_best_only=True) #save the weights whenever the validation loss improves
    
    
    # def step_decay(epoch):
    #     lrate = 0.01 #the initial learning rate (by default in keras)
    #     if epoch==100:
    #         return 0.005
    #     else:
    #         return lrate
    #
    # lrate_drop = LearningRateScheduler(step_decay)
    
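    # masks_Unet comes from help_functions; it converts the binary masks of shape
    # (N, 1, H, W) into per-pixel one-hot labels of shape (N, H*W, 2) so they match
    # the network's softmax output. A vectorized sketch of the same idea (an
    # assumption about its behavior, not the repo's exact code):
    # def masks_to_unet_labels(masks):
    #     n, _, h, w = masks.shape
    #     flat = masks.reshape(n, h*w)
    #     labels = np.empty((n, h*w, 2))
    #     labels[:, :, 0] = 1 - flat   # background class
    #     labels[:, :, 1] = flat       # vessel (foreground) class
    #     return labels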
    patches_masks_train = masks_Unet(patches_masks_train)  #convert the masks to the per-pixel two-class format expected by the softmax output
    model.fit(patches_imgs_train, patches_masks_train, epochs=N_epochs, batch_size=batch_size, verbose=2, shuffle=True, validation_split=0.1, callbacks=[checkpointer])
    
    
    #========== Save and test the last model ===================
    model.save_weights('./'+name_experiment+'/'+name_experiment +'_last_weights.h5', overwrite=True)
    #test the model
    # score = model.evaluate(patches_imgs_test, masks_Unet(patches_masks_test), verbose=0)
    # print('Test score:', score[0])
    # print('Test accuracy:', score[1])
    

    Experimental results: from top to bottom, the original image, the ground truth, and the predicted segmentation.

  • Original post: https://www.cnblogs.com/fourmi/p/8993631.html