• fashion_mnist多分类训练,两种模型的保存与加载


    from tensorflow.python.keras.preprocessing.image import load_img,img_to_array
    from tensorflow.python.keras.models import Sequential,Model
    from tensorflow.python.keras.layers import Dense,Flatten,Input
    import tensorflow as tf
    from tensorflow.python.keras.losses import sparse_categorical_crossentropy
    from tensorflow.python import keras
    import os
    import numpy as np
    
    class SingleNN(object):
    
        #建立神经网络模型
        model = keras.Sequential([
            keras.layers.Flatten(input_shape=(28,28)),
            keras.layers.Dense(128,activation=tf.nn.relu),
            keras.layers.Dense(10,activation=tf.nn.softmax)
        ])
    
        def __init__(self):
            (self.x_train,self.y_train),(self.x_test,self.y_test) = keras.datasets.fashion_mnist.load_data()
            #归一化
            self.x_train = self.x_train/255.0
            self.x_test = self.x_test/255.0
    
        def singlenn_compile(self):
            '''
            编译模型优化器、损失、准确率
            :return:
            '''
            SingleNN.model.compile(
                optimizer=keras.optimizers.SGD(lr=0.01),
                loss=keras.losses.sparse_categorical_crossentropy,
                metrics=['accuracy']
            )
    
        def singlenn_fit(self):
            """
            进行fit训练
            :return: 
            """
            SingleNN.model.fit(self.x_train,self.y_train,epochs=5)
    
        def single_evalute(self):
            '''
            模型评估
            :return: 
            '''
            test_loss,test_acc = SingleNN.model.evaluate(self.x_test,self.y_test)
            print(test_loss,test_acc)
    
        def single_predict(self):
            '''
            预测结果
            :return: 
            '''
            # if os.path.exists("./ckpt/checkpoink"):
            #     SingleNN.model.load_weights("./ckpt/SingleNN")
    
            if os.path.exists("./ckpt/SingleNN.h5"):
                SingleNN.model.load_weights("./ckpt/SingleNN.h5")
    
            predictions = SingleNN.model.predict(self.x_test)
    
            return predictions
    
    if __name__ == '__main__':
        snn = SingleNN()
        # snn.singlenn_compile()
        # snn.singlenn_fit()
        # snn.single_evalute()
        # # SingleNN.model.save_weights("./ckpt/SingleNN")
        # SingleNN.model.save_weights("./ckpt/SingleNN.h5")
        predictions = snn.single_predict()
        print(predictions)
        result = np.argmax(predictions,axis=1)
        print(result)
    

      

    多思考也是一种努力,做出正确的分析和选择,因为我们的时间和精力都有限,所以把时间花在更有价值的地方。
  • 相关阅读:
    webStrom 注释模板添加
    匹配正则 url 端口 域名
    监测数据类型封装方法
    base64图片展示(后端给base64数据,前端展示图片)
    倒计时
    机密16位
    mvc与mvvm的区别
    flex表格的使用
    flex中tab页面的实现
    flex中下拉框的实现
  • 原文地址:https://www.cnblogs.com/LiuXinyu12378/p/12250596.html
Copyright © 2020-2023  润新知