• tflearn：在每个 epoch 结束时保存模型


    关键代码:
    tflearn.DNN(net, checkpoint_path='model_resnet_cifar10', max_checkpoints=10, tensorboard_verbose=0, clip_gradients=0.)
    snapshot_epoch=True, # Snapshot (save & evaluate) model every epoch.
    我的demo:
    def get_model(width, height, classes=40, checkpoint_path='model_resnet_cifar10'):
        """Build a small ResNet image classifier wrapped in a tflearn.DNN.

        Args:
            width: image width in pixels (the value passed as image_shape[0]
                to image_preloader).
            height: image height in pixels (the value passed as image_shape[1]
                to image_preloader).
            classes: number of output classes, i.e. the size of the softmax layer.
            checkpoint_path: prefix for checkpoint files written by the DNN
                (e.g. per-epoch snapshots during fit). Defaults to the value
                that was previously hard-coded, so existing checkpoints such as
                'model_resnet_cifar10-<step>' keep their names.

        Returns:
            A tflearn.DNN model configured with Momentum SGD and categorical
            cross-entropy, keeping at most 10 checkpoints.
        """
        # TODO: modify model
        # image_preloader resizes with PIL using (image_shape[0], image_shape[1])
        # as (width, height), and a PIL image converts to an array shaped
        # (height, width, channels). The input layer must match that layout; the
        # former [None, width, height, 3] only worked for square images.
        network = input_data(shape=[None, height, width, 3])  # RGB, e.g. 224,224,3
        # Residual blocks. This layout has 6n+2 layers:
        # 32 layers: n=5, 56 layers: n=9, 110 layers: n=18 (n=2 -> 14 layers).
        n = 2
        net = tflearn.conv_2d(network, 16, 3, regularizer='L2', weight_decay=0.0001)
        net = tflearn.residual_block(net, n, 16)
        net = tflearn.residual_block(net, 1, 32, downsample=True)
        net = tflearn.residual_block(net, n - 1, 32)
        net = tflearn.residual_block(net, 1, 64, downsample=True)
        net = tflearn.residual_block(net, n - 1, 64)
        net = tflearn.batch_normalization(net)
        net = tflearn.activation(net, 'relu')
        net = tflearn.global_avg_pool(net)
        # Classification head: one softmax unit per class.
        net = tflearn.fully_connected(net, classes, activation='softmax')
        # Momentum SGD starting at lr=0.01, decayed by lr_decay=0.1 every
        # decay_step=2000 steps (staircase).
        mom = tflearn.Momentum(0.01, lr_decay=0.1, decay_step=2000, staircase=True)
        net = tflearn.regression(net, optimizer=mom,
                                 loss='categorical_crossentropy')
        # Training wrapper; checkpoints are written under checkpoint_path.
        model = tflearn.DNN(net, checkpoint_path=checkpoint_path,
                            max_checkpoints=10, tensorboard_verbose=0,
                            clip_gradients=0.)
        return model
    
    
    
    def main():
        """Train the ResNet on data/train, validate on data/test, then save it.

        The steps are:
        - Build the model.
        - Try to resume from a hard-coded checkpoint; on failure, train from scratch.
        - Train with per-epoch snapshots until EarlyStoppingCallback stops
          training or n_epoch is exhausted.
        - Save the model twice under tflearn_resnet/: once as-is, and once
          after clearing the TRAIN_OPS graph collection.

        Relies on module-level names defined elsewhere in the file: `width`,
        `height`, `tf`, `tflearn` helpers such as `image_preloader`, and
        `EarlyStoppingCallback`.
        """
        import os

        trainX, trainY = image_preloader("data/train", image_shape=(width, height, 3), mode='folder', categorical_labels=True, normalize=True)
        testX, testY = image_preloader("data/test", image_shape=(width, height, 3), mode='folder', categorical_labels=True, normalize=True)
        print("sample data:")
        print(trainX[0])
        print(trainY[0])
        print(testX[-1])
        print(testY[-1])

        model = get_model(width, height, classes=3755)

        save_dir = 'tflearn_resnet'
        filename = 'tflearn_resnet/model.tflearn'

        # Best-effort resume: if the checkpoint is missing or incompatible, say
        # so and train from scratch instead of silently swallowing the error
        # (the former bare `except:` also swallowed KeyboardInterrupt).
        checkpoint = "model_resnet_cifar10-195804"
        try:
            model.load(checkpoint)
            print("Model loaded OK. Resume training!")
        except Exception as err:
            print("Could not load checkpoint %s (%s); training from scratch." % (checkpoint, err))

        early_stopping_cb = EarlyStoppingCallback(val_acc_thresh=0.94)
        try:
            model.fit(trainX, trainY, validation_set=(testX, testY), n_epoch=500, shuffle=True,
                      snapshot_epoch=True,  # Snapshot (save & evaluate) model every epoch.
                      show_metric=True, batch_size=1024, callbacks=early_stopping_cb, run_id='cnn_handwrite')
        except StopIteration:
            # Presumably raised by EarlyStoppingCallback once val_acc_thresh is
            # reached -- confirm against its definition.
            print("OK, stop iterate!Good!")

        # TensorFlow's Saver refuses to write when the parent directory does
        # not exist, so create it before the first save.
        if not os.path.isdir(save_dir):
            os.makedirs(save_dir)
        model.save(filename)

        # Clear tflearn's TRAIN_OPS graph collection before the second save;
        # presumably this yields a copy meant for inference-only loading.
        del tf.get_collection_ref(tf.GraphKeys.TRAIN_OPS)[:]
        filename = 'tflearn_resnet/model-infer.tflearn'
        model.save(filename)
    
  • 相关阅读:
    上云,你真的只差一本葵花宝典
    Linux Kernel 4.11首个候选版本开放下载
    Windows 10 host where Credential Guard or Device Guard is enabled fails when running Workstation (2146361)
    .NET技术+25台服务器怎样支撑世界第54大网站
    Azure 订阅和服务限制、配额和约束
    python再议装饰器
    python的上下文管理器-1
    python的上下文管理器
    python小知识点
    python做简易记事本
  • 原文地址:https://www.cnblogs.com/bonelee/p/9006243.html
Copyright © 2020-2023  润新知