• Sklearn,TensorFlow,keras模型保存与读取


    一、sklearn模型保存与读取 
    1、保存

    1 from sklearn.externals import joblib
    2 from sklearn import svm
    3 X = [[0, 0], [1, 1]]
    4 y = [0, 1]
    5 clf = svm.SVC()
    6 clf.fit(X, y)  
    7 joblib.dump(clf, "train_model.m")

    2、读取

    1 clf = joblib.load("train_model.m")
    2 clf.predit([0,0]) #此处test_X为特征集

    二、TensorFlow模型保存与读取(该方式tensorflow只能保存变量而不是保存整个网络,所以在提取模型时,我们还需要重新第一网络结构。) 
    1、保存

     1 import tensorflow as tf  
     2 import numpy as np  
     3 
     4 W = tf.Variable([[1,1,1],[2,2,2]],dtype = tf.float32,name='w')  
     5 b = tf.Variable([[0,1,2]],dtype = tf.float32,name='b')  
     6  
     7 init = tf.initialize_all_variables()  
     8 saver = tf.train.Saver()  
     9 with tf.Session() as sess:  
    10          sess.run(init)  
    11          save_path = saver.save(sess,"save/model.ckpt")  

    2、加载

    1 import tensorflow as tf  
    2 import numpy as np  
    3  
    4 W = tf.Variable(tf.truncated_normal(shape=(2,3)),dtype = tf.float32,name='w')  
    5 b = tf.Variable(tf.truncated_normal(shape=(1,3)),dtype = tf.float32,name='b')  
    6  
    7 saver = tf.train.Saver()  
    8 with tf.Session() as sess:  
    9       saver.restore(sess,"save/model.ckpt")  

    三、TensorFlow模型保存与读取(该方式tensorflow保存整个网络) 
    1、保存

     1 import tensorflow as tf
     2 
     3 # First, you design your mathematical operations
     4 # We are the default graph scope
     5 
     6 # Let's design a variable
     7 v1 = tf.Variable(1. , name="v1")
     8 v2 = tf.Variable(2. , name="v2")
     9 # Let's design an operation
    10 a = tf.add(v1, v2)
    11 
    12 # Let's create a Saver object
    13 # By default, the Saver handles every Variables related to the default graph
    14 all_saver = tf.train.Saver() 
    15 # But you can precise which vars you want to save under which name
    16 v2_saver = tf.train.Saver({"v2": v2}) 
    17 
    18 # By default the Session handles the default graph and all its included variables
    19 with tf.Session() as sess:
    20   # Init v and v2   
    21   sess.run(tf.global_variables_initializer())
    22   # Now v1 holds the value 1.0 and v2 holds the value 2.0
    23   # We can now save all those values
    24   all_saver.save(sess, 'data.chkp')
    25   # or saves only v2
    26   v2_saver.save(sess, 'data-v2.chkp')
    27 模型的权重是保存在 .chkp 文件中,模型的图是保存在 .chkp.meta 文件中。

    2、加载

     1 import tensorflow as tf
     2 
     3 # Let's laod a previous meta graph in the current graph in use: usually the default graph
     4 # This actions returns a Saver
     5 saver = tf.train.import_meta_graph('results/model.ckpt-1000.meta')
     6 
     7 # We can now access the default graph where all our metadata has been loaded
     8 graph = tf.get_default_graph()
     9 
    10 # Finally we can retrieve tensors, operations, etc.
    11 global_step_tensor = graph.get_tensor_by_name('loss/global_step:0')
    12 train_op = graph.get_operation_by_name('loss/train_op')
    13 hyperparameters = tf.get_collection('hyperparameters')
    14 
    15 恢复权重
    16 
    17 请记住,在实际的环境中,真实的权重只能存在于一个会话中。也就是说,restore 这个操作必须在一个会话中启动,然后将数据权重导入到图中。理解恢复操作的最好方法是将它简单的看做是一种数据初始化操作。
    18 with tf.Session() as sess:
    19     # To initialize values with saved data
    20     saver.restore(sess, 'results/model.ckpt-1000-00000-of-00001')
    21     print(sess.run(global_step_tensor)) # returns 1000

    四、keras模型保存和加载

    1 model.save('my_model.h5')  
    2 model = load_model('my_model.h5') 
  • 相关阅读:
    如果你正在找工作,也许这七个方法会帮到你
    WebSocket 浅析
    关系数据库涉及中的范式与反范式
    MySQL字段类型与合理的选择字段类型
    ER图,数据建模与数据字典
    详解慢查询
    MySQL的最佳索引攻略
    后端技术演进
    MySQL主从复制(BinaryLog)
    MySQL读写分离
  • 原文地址:https://www.cnblogs.com/tectal/p/9053205.html
Copyright © 2020-2023  润新知