• 简单粗暴的 TensorFlow 模型导出


    # TensorFlow model export: train a small Keras MLP on MNIST, then save it
    # in SavedModel format with tf.saved_model.save.
    import tensorflow as tf
    from zh.model.utils import MNISTLoader
    
    num_epochs = 1
    batch_size = 50
    learning_rate = 0.001
    
    # Flatten -> Dense(100, ReLU) -> Dense(10) -> Softmax
    model = tf.keras.models.Sequential([
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(100, activation=tf.nn.relu),
        tf.keras.layers.Dense(10),
        tf.keras.layers.Softmax()
    ])
    
    data_loader = MNISTLoader()
    model.compile(
        # Use the learning_rate hyperparameter defined above rather than a
        # duplicated literal, so changing it actually affects training.
        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
        loss=tf.keras.losses.sparse_categorical_crossentropy,
        metrics=[tf.keras.metrics.sparse_categorical_accuracy]
    )
    model.fit(data_loader.train_data, data_loader.train_label, epochs=num_epochs, batch_size=batch_size)
    # tf.saved_model.save exports the trained model to the "saved/1" directory.
    # NOTE(review): the trailing "1" looks like a version subdirectory
    # (TF Serving layout) — confirm.
    tf.saved_model.save(model, "saved/1")
    
    # Evaluation: reload the exported SavedModel and measure its test accuracy.
    import tensorflow as tf
    from zh.model.utils import MNISTLoader
    batch_size = 50
    model = tf.saved_model.load("saved/1")  # tf.saved_model.load restores the exported model
    data_loader = MNISTLoader()
    sparse_categorical_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
    # Stride over the whole test set in steps of batch_size. The last slice may
    # be shorter than batch_size, so samples after the last full batch are
    # evaluated instead of being silently dropped (as floor division did).
    for start_index in range(0, int(data_loader.num_test_data), batch_size):
        end_index = start_index + batch_size
        y_pred = model(data_loader.test_data[start_index: end_index])
        sparse_categorical_accuracy.update_state(
            y_true=data_loader.test_label[start_index: end_index], y_pred=y_pred)
    print("test accuracy: %f" % sparse_categorical_accuracy.result())

    # A subclassed tf.keras.Model must run in graph mode before it can be saved,
    # which is why `call` below is decorated with @tf.function.
    class MLP(tf.keras.Model):
        """Simple perceptron: Flatten -> Dense(100, ReLU) -> Dense(10) -> softmax."""

        def __init__(self):
            super(MLP, self).__init__()
            self.flatten = tf.keras.layers.Flatten()
            self.dense1 = tf.keras.layers.Dense(units=100, activation=tf.nn.relu)
            self.dense2 = tf.keras.layers.Dense(units=10)

        @tf.function
        def call(self, inputs):
            """Map a batch of inputs to softmax class probabilities.

            Shapes: inputs [batch_size, 28, 28, 1] -> flattened [batch_size, 784]
            -> hidden [batch_size, 100] -> output [batch_size, 10].
            """
            hidden = self.dense1(self.flatten(inputs))
            return tf.nn.softmax(self.dense2(hidden))
    
    # Instantiate the subclassed model. The `...` presumably stands in for the
    # elided training / export steps, which are not shown in this snippet.
    model = MLP()
    ...
    
    # NOTE(review): start_index / end_index are not defined in this snippet;
    # presumably they come from a batching loop over the test set.
    y_pred = model.call(data_loader.test_data[start_index: end_index])  # when testing, invoke the `call` method explicitly
  • 相关阅读:
    Django基础
    MySQL(索引)
    MySQL(进阶部分)
    MySQL(Python+ORM)
    JavaScript的对象
    abc
    Let's Encrypt,免费好用的 HTTPS 证书
    Java调试那点事
    Memcache mutex 设计模式
    从 Nginx 默认不压缩 HTTP/1.0 说起
  • 原文地址:https://www.cnblogs.com/wuyuan2011woaini/p/15907795.html
Copyright © 2020-2023  润新知