• TensorFlow Serving


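    TensorFlow Serving makes it quick to deploy TensorFlow models and expose them through a gRPC or REST API.

    The official documentation recommends deploying with Docker, and it also provides a complete tutorial that goes from training to deployment: Servers: TFX for TensorFlow Serving. This post is simply a walkthrough of that tutorial, which is a good way to understand the whole path from training a TensorFlow model to serving it.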

    Preparing the environment

    Set up a TensorFlow environment and import the dependencies:

    import sys
    
    # Confirm that we're using Python 3
    assert sys.version_info.major == 3, 'Oops, not running Python 3. Use Runtime > Change runtime type'
    
    import tensorflow as tf
    from tensorflow import keras
    
    # Helper libraries
    import numpy as np
    import matplotlib.pyplot as plt
    import os
    import subprocess
    
    print(f'TensorFlow version: {tf.__version__}')
    print(f'TensorFlow GPU support: {tf.test.is_built_with_gpu_support()}')
    
    physical_gpus = tf.config.list_physical_devices('GPU')
    print(physical_gpus)
    for gpu in physical_gpus:
      # memory growth must be set before GPUs have been initialized
      tf.config.experimental.set_memory_growth(gpu, True)
    logical_gpus = tf.config.experimental.list_logical_devices('GPU')
    print(len(physical_gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    
    TensorFlow version: 2.4.1
    TensorFlow GPU support: True
    [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
    1 Physical GPUs, 1 Logical GPUs
    

    Creating the model

    Load the Fashion MNIST dataset:

    fashion_mnist = keras.datasets.fashion_mnist
    (train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
    
    # scale the values to 0.0 to 1.0
    train_images = train_images / 255.0
    test_images = test_images / 255.0
    
    # reshape for feeding into the model
    train_images = train_images.reshape(train_images.shape[0], 28, 28, 1)
    test_images = test_images.reshape(test_images.shape[0], 28, 28, 1)
    
    class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
                   'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
    
    print('\ntrain_images.shape: {}, of {}'.format(train_images.shape, train_images.dtype))
    print('test_images.shape: {}, of {}'.format(test_images.shape, test_images.dtype))
    
    train_images.shape: (60000, 28, 28, 1), of float64
    test_images.shape: (10000, 28, 28, 1), of float64
    

    Train the model with a very simple CNN:

    model = keras.Sequential([
      keras.layers.Conv2D(input_shape=(28,28,1), filters=8, kernel_size=3,
                          strides=2, activation='relu', name='Conv1'),
      keras.layers.Flatten(),
      keras.layers.Dense(10, name='Dense')
    ])
    model.summary()
    
    testing = False
    epochs = 5
    
    model.compile(optimizer='adam',
                  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  metrics=[keras.metrics.SparseCategoricalAccuracy()])
    model.fit(train_images, train_labels, epochs=epochs)
    
    test_loss, test_acc = model.evaluate(test_images, test_labels)
    print('\nTest accuracy: {}'.format(test_acc))
    
    Model: "sequential"
    _________________________________________________________________
    Layer (type)                 Output Shape              Param #
    =================================================================
    Conv1 (Conv2D)               (None, 13, 13, 8)         80
    _________________________________________________________________
    flatten (Flatten)            (None, 1352)              0
    _________________________________________________________________
    Dense (Dense)                (None, 10)                13530
    =================================================================
    Total params: 13,610
    Trainable params: 13,610
    Non-trainable params: 0
    _________________________________________________________________
    Epoch 1/5
    1875/1875 [==============================] - 3s 722us/step - loss: 0.7387 - sparse_categorical_accuracy: 0.7449
    Epoch 2/5
    1875/1875 [==============================] - 1s 793us/step - loss: 0.4561 - sparse_categorical_accuracy: 0.8408
    Epoch 3/5
    1875/1875 [==============================] - 1s 720us/step - loss: 0.4097 - sparse_categorical_accuracy: 0.8566
    Epoch 4/5
    1875/1875 [==============================] - 1s 718us/step - loss: 0.3899 - sparse_categorical_accuracy: 0.8636
    Epoch 5/5
    1875/1875 [==============================] - 1s 719us/step - loss: 0.3673 - sparse_categorical_accuracy: 0.8701
    313/313 [==============================] - 0s 782us/step - loss: 0.3937 - sparse_categorical_accuracy: 0.8630
    
    Test accuracy: 0.8629999756813049
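
    The final Dense layer has no activation and the loss is built with from_logits=True, so the model outputs raw logits rather than probabilities. A minimal NumPy sketch of turning such logits into probabilities is shown here; the softmax helper and the sample logits are made up for illustration and are not part of the tutorial.

```python
import numpy as np

def softmax(logits):
    """Convert raw logits to probabilities along the last axis."""
    logits = np.asarray(logits, dtype=np.float64)
    # Subtract the row-wise maximum for numerical stability.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)

# Made-up logits for two images over the 10 Fashion MNIST classes.
example_logits = np.array([
    [-3.0, -5.0, -1.0, -2.0, -4.0, 2.0, -1.5, 4.0, 0.5, 9.0],
    [1.0, -6.0, 8.0, -2.0, 3.0, -9.0, 2.5, -12.0, -1.0, -7.0],
])
probs = softmax(example_logits)
predicted = probs.argmax(axis=-1)

print(predicted)            # index of the most likely class per image
print(probs.sum(axis=-1))   # each row sums to 1
```

    Taking argmax directly on the logits gives the same class, so softmax is only needed when actual probabilities are wanted.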
    

    Saving the model

    Save the model in SavedModel format, with a version number in the path so that TensorFlow Serving can select a model version.

    # Fetch the Keras session and save the model
    # The signature definition is defined by the input and output tensors,
    # and stored with the default serving key
    import tempfile
    
    MODEL_DIR = os.path.join(tempfile.gettempdir(), 'tfx')
    version = 1
    export_path = os.path.join(MODEL_DIR, str(version))
    print('export_path = {}\n'.format(export_path))
    
    tf.keras.models.save_model(
        model,
        export_path,
        overwrite=True,
        include_optimizer=True,
        save_format=None,
        signatures=None,
        options=None
    )
    
    print('\nSaved model:')
    !ls -l {export_path}
    
    export_path = /tmp/tfx/1
    
    INFO:tensorflow:Assets written to: /tmp/tfx/1/assets
    
    Saved model:
    total 88
    drwxr-xr-x 2 john john  4096 Apr 13 15:10 assets
    -rw-rw-r-- 1 john john 78169 Apr 13 15:12 saved_model.pb
    drwxr-xr-x 2 john john  4096 Apr 13 15:12 variables
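
    The version above is hard-coded to 1. When exporting repeatedly, it can be handy to pick the next free version number instead. The sketch below is one possible way to do that; next_version_path is a hypothetical helper, not part of TensorFlow, and it assumes that version directories have purely numeric names.

```python
import os
import tempfile

def next_version_path(model_dir):
    """Return (version, path) for the next numeric version under model_dir."""
    existing = []
    if os.path.isdir(model_dir):
        for name in os.listdir(model_dir):
            # Only purely numeric directory names count as versions.
            if name.isdigit() and os.path.isdir(os.path.join(model_dir, name)):
                existing.append(int(name))
    version = max(existing, default=0) + 1
    return version, os.path.join(model_dir, str(version))

# Demonstrate on a throwaway directory rather than the real model directory.
demo_dir = tempfile.mkdtemp()
first_version, first_path = next_version_path(demo_dir)
os.makedirs(first_path)  # stands in for saving the first export
second_version, second_path = next_version_path(demo_dir)

print(first_version, second_version)
```

    The returned path could then be passed to tf.keras.models.save_model in place of export_path.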
    

    Inspecting the model

    Use the saved_model_cli tool to inspect the model's MetaGraphDefs (the models) and SignatureDefs (the methods you can call).

    !saved_model_cli show --dir '/tmp/tfx/1' --all
    
    2021-04-13 15:12:29.433576: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.11.0
    
    MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:
    
    signature_def['__saved_model_init_op']:
      The given SavedModel SignatureDef contains the following input(s):
      The given SavedModel SignatureDef contains the following output(s):
        outputs['__saved_model_init_op'] tensor_info:
            dtype: DT_INVALID
            shape: unknown_rank
            name: NoOp
      Method name is:
    
    signature_def['serving_default']:
      The given SavedModel SignatureDef contains the following input(s):
        inputs['Conv1_input'] tensor_info:
            dtype: DT_FLOAT
            shape: (-1, 28, 28, 1)
            name: serving_default_Conv1_input:0
      The given SavedModel SignatureDef contains the following output(s):
        outputs['Dense'] tensor_info:
            dtype: DT_FLOAT
            shape: (-1, 10)
            name: StatefulPartitionedCall:0
      Method name is: tensorflow/serving/predict
    
    Defined Functions:
      Function Name: '__call__'
        Option #1
          Callable with:
            Argument #1
              Conv1_input: TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32, name='Conv1_input')
            Argument #2
              DType: bool
              Value: False
            Argument #3
              DType: NoneType
              Value: None
        Option #2
          Callable with:
            Argument #1
              inputs: TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32, name='inputs')
            Argument #2
              DType: bool
              Value: False
            Argument #3
              DType: NoneType
              Value: None
        Option #3
          Callable with:
            Argument #1
              inputs: TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32, name='inputs')
            Argument #2
              DType: bool
              Value: True
            Argument #3
              DType: NoneType
              Value: None
        Option #4
          Callable with:
            Argument #1
              Conv1_input: TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32, name='Conv1_input')
            Argument #2
              DType: bool
              Value: True
            Argument #3
              DType: NoneType
              Value: None
      ...
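
    The serving_default signature above expects an input of shape (-1, 28, 28, 1), where -1 is the batch dimension. Before sending a request, a nested Python list can be checked against that shape. This is a rough sketch with hypothetical helpers (nested_shape, matches_signature); it only looks at the first element of each level, so it assumes the list is rectangular.

```python
def nested_shape(value):
    """Return the shape of a rectangular nested list as a tuple."""
    shape = []
    while isinstance(value, list):
        shape.append(len(value))
        if not value:
            break
        value = value[0]
    return tuple(shape)

def matches_signature(instances, signature_shape):
    """Check a nested list against a signature shape; -1 matches any size."""
    actual = nested_shape(instances)
    if len(actual) != len(signature_shape):
        return False
    return all(want == -1 or want == got
               for want, got in zip(signature_shape, actual))

# Shape copied from inputs['Conv1_input'] in the saved_model_cli output.
SIGNATURE_SHAPE = (-1, 28, 28, 1)

# Three blank images with a channel axis, and three without one.
good_batch = [[[[0.0] for _ in range(28)] for _ in range(28)] for _ in range(3)]
bad_batch = [[[0.0 for _ in range(28)] for _ in range(28)] for _ in range(3)]

print(matches_signature(good_batch, SIGNATURE_SHAPE))
print(matches_signature(bad_batch, SIGNATURE_SHAPE))
```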
    

    Deploying the model

    Installing Serving

    echo "deb [arch=amd64] http://storage.googleapis.com/tensorflow-serving-apt stable tensorflow-model-server tensorflow-model-server-universal" | sudo tee /etc/apt/sources.list.d/tensorflow-serving.list && 
    curl https://storage.googleapis.com/tensorflow-serving-apt/tensorflow-serving.release.pub.gpg | sudo apt-key add -
    
    sudo apt update
    sudo apt install tensorflow-model-server
    

    Starting Serving

    Start TensorFlow Serving and expose the REST API:

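    • rest_api_port: the port for REST requests.
    • model_name: a custom name that appears in the REST request URL.
    • model_base_path: the directory that contains the model (the parent of the version directories).

    nohup tensorflow_model_server \
      --rest_api_port=8501 \
      --model_name=fashion_model \
      --model_base_path="/tmp/tfx" >server.log 2>&1 &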
    
    $ tail server.log
    To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
    2021-04-13 15:12:10.706648: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:206] Restoring SavedModel bundle.
    2021-04-13 15:12:10.726722: I external/org_tensorflow/tensorflow/core/platform/profile_utils/cpu_utils.cc:112] CPU Frequency: 2599990000 Hz
    2021-04-13 15:12:10.756506: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:190] Running initialization op on SavedModel bundle at path: /tmp/tfx/1
    2021-04-13 15:12:10.759935: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:277] SavedModel load for tags { serve }; Status: success: OK. Took 110653 microseconds.
    2021-04-13 15:12:10.760277: I tensorflow_serving/servables/tensorflow/saved_model_warmup_util.cc:59] No warmup data file found at /tmp/tfx/1/assets.extra/tf_serving_warmup_requests
    2021-04-13 15:12:10.760486: I tensorflow_serving/core/loader_harness.cc:87] Successfully loaded servable version {name: fashion_model version: 1}
    2021-04-13 15:12:10.763938: I tensorflow_serving/model_servers/server.cc:371] Running gRPC ModelServer at 0.0.0.0:8500 ...
    [evhttp_server.cc : 238] NET_LOG: Entering the event loop ...
    2021-04-13 15:12:10.765308: I tensorflow_serving/model_servers/server.cc:391] Exporting HTTP/REST API at:localhost:8501 ...
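
    Besides reading server.log, the REST API can also be asked whether the model is loaded, via a GET on /v1/models/fashion_model. The sketch below uses only the standard library. The helper names are hypothetical, and sample_status is a hand-written example of what I understand the status response to look like, not captured output.

```python
import json
import urllib.request

def model_status_url(model_name, host="localhost", port=8501):
    """Build the model status URL for the REST API."""
    return "http://{}:{}/v1/models/{}".format(host, port, model_name)

def available_versions(status):
    """Return the versions reported as AVAILABLE in a status response dict."""
    return [
        entry["version"]
        for entry in status.get("model_version_status", [])
        if entry.get("state") == "AVAILABLE"
    ]

def fetch_model_status(model_name, host="localhost", port=8501, timeout=5):
    """Query a running server; only works while TensorFlow Serving is up."""
    url = model_status_url(model_name, host, port)
    with urllib.request.urlopen(url, timeout=timeout) as resp:
        return json.loads(resp.read().decode("utf-8"))

# Hand-written sample (an assumption about the response layout).
sample_status = {
    "model_version_status": [
        {
            "version": "1",
            "state": "AVAILABLE",
            "status": {"error_code": "OK", "error_message": ""},
        }
    ]
}

print(model_status_url("fashion_model"))
print(available_versions(sample_status))
```

    With the server running, available_versions(fetch_model_status("fashion_model")) would perform the real check.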
    

    Calling the service

    Show a random test image:

    def show(idx, title):
      plt.figure()
      plt.imshow(test_images[idx].reshape(28,28))
      plt.axis('off')
      plt.title('\n\n{}'.format(title), fontdict={'size': 16})
    
    import random
    rando = random.randint(0,len(test_images)-1)
    show(rando, 'An Example Image: {}'.format(class_names[test_labels[rando]]))
    

    Build a JSON object containing the three images to predict:

    import json
    data = json.dumps({"signature_name": "serving_default", "instances": test_images[0:3].tolist()})
    print('Data: {} ... {}'.format(data[:50], data[len(data)-52:]))
    
    Data: {"signature_name": "serving_default", "instances": ...  [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0]]]]}
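
    The request above uses the row format, with one entry per image under "instances". As far as I know, the predict API also accepts a columnar format under "inputs", keyed by input name. The sketch below builds both from a tiny stand-in batch; row_payload and columnar_payload are hypothetical helpers, and Conv1_input is the input name reported by saved_model_cli.

```python
import json

def row_payload(batch, signature_name="serving_default"):
    """Row format: one entry per example under "instances"."""
    return json.dumps({"signature_name": signature_name, "instances": batch})

def columnar_payload(batch, input_name, signature_name="serving_default"):
    """Columnar format: the whole batch keyed by input name under "inputs"."""
    return json.dumps({"signature_name": signature_name,
                       "inputs": {input_name: batch}})

# Tiny stand-in for test_images[0:3].tolist(): one 2x2 image with one channel.
tiny_batch = [[[[0.0], [0.5]], [[1.0], [0.25]]]]

row = json.loads(row_payload(tiny_batch))
col = json.loads(columnar_payload(tiny_batch, "Conv1_input"))

print(sorted(row))
print(sorted(col))
```

    For a model with a single input, as here, the row format is the simpler choice; a columnar request is answered under "outputs" rather than "predictions", so the parsing code would need to change accordingly.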
    

    REST requests

    Predict with the latest model version:

    !pip install -q requests
    
    import requests
    headers = {"content-type": "application/json"}
    json_response = requests.post('http://localhost:8501/v1/models/fashion_model:predict', data=data, headers=headers)
    predictions = json.loads(json_response.text)['predictions']
    
    show(0, 'The model thought this was a {} (class {}), and it was actually a {} (class {})'.format(
      class_names[np.argmax(predictions[0])], np.argmax(predictions[0]), class_names[test_labels[0]], test_labels[0]))
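
    The code above reads ['predictions'] directly, which raises a KeyError if the server answers with an error instead. A slightly more defensive sketch follows; parse_predictions is a hypothetical helper, the two response strings are made up, and it assumes that failures come back as a JSON object with an "error" key.

```python
import json

def parse_predictions(response_text):
    """Return the predictions list, or raise if the server reported an error."""
    body = json.loads(response_text)
    if "error" in body:
        raise RuntimeError("TensorFlow Serving error: {}".format(body["error"]))
    return body["predictions"]

def argmax(values):
    """Index of the largest value in a flat list."""
    return max(range(len(values)), key=lambda i: values[i])

# Made-up response bodies for illustration.
ok_text = json.dumps({"predictions": [[0.1, 2.5, -1.0], [3.0, 0.2, 0.1]]})
err_text = json.dumps({"error": "made-up error message"})

classes = [argmax(row) for row in parse_predictions(ok_text)]
print(classes)

try:
    parse_predictions(err_text)
    error_raised = False
except RuntimeError as exc:
    error_raised = True
    print(exc)
```

    In the request code, parse_predictions(json_response.text) could replace the direct json.loads(...)['predictions'] lookup.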
    

    Predict with a specific model version:

    headers = {"content-type": "application/json"}
    json_response = requests.post('http://localhost:8501/v1/models/fashion_model/versions/1:predict', data=data, headers=headers)
    predictions = json.loads(json_response.text)['predictions']
    
    for i in range(0,3):
      show(i, 'The model thought this was a {} (class {}), and it was actually a {} (class {})'.format(
        class_names[np.argmax(predictions[i])], np.argmax(predictions[i]), class_names[test_labels[i]], test_labels[i]))
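
    The two request URLs used above differ only in the /versions/1 segment. A small helper can build either form; predict_url is a hypothetical name, with defaults matching the host and port used in this post.

```python
def predict_url(model_name, version=None, host="localhost", port=8501):
    """Build the REST predict URL; omit version to target the latest one."""
    url = "http://{}:{}/v1/models/{}".format(host, port, model_name)
    if version is not None:
        url += "/versions/{}".format(version)
    return url + ":predict"

latest_url = predict_url("fashion_model")
pinned_url = predict_url("fashion_model", version=1)

print(latest_url)
print(pinned_url)
```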
    

    GoCoding shares experience from personal practice; follow the official account for more!

  • Original article: https://www.cnblogs.com/gocodinginmyway/p/14664043.html