• Training a model in parallel on multiple GPUs with Keras | keras multi gpu training


    This article was first published on my personal blog at https://kezunlin.me/post/95370db7/; the latest version is available there!

    keras multi gpu training

    Guide

    multi_gpu_model

    import tensorflow as tf
    from keras.applications import Xception
    from keras.utils import multi_gpu_model
    import numpy as np
    
    G = 8  # number of GPUs to replicate the model on
    batch_size_per_gpu = 32
    batch_size = batch_size_per_gpu * G  # global batch size (256)
    
    num_samples = 1000
    height = 224
    width = 224
    num_classes = 1000
    
    # Instantiate the base model (or "template" model).
    # We recommend doing this under a CPU device scope,
    # so that the model's weights are hosted on CPU memory.
    # Otherwise they may end up hosted on a GPU, which would
    # complicate weight sharing.
    with tf.device('/cpu:0'):
        model = Xception(weights=None,
                         input_shape=(height, width, 3),
                         classes=num_classes)
    
    # Replicates the model on 8 GPUs.
    # This assumes that your machine has 8 available GPUs.
    parallel_model = multi_gpu_model(model, gpus=G)
    parallel_model.compile(loss='categorical_crossentropy',
                           optimizer='rmsprop')
    
    # Generate dummy data.
    x = np.random.random((num_samples, height, width, 3))
    y = np.random.random((num_samples, num_classes))
    
    # This `fit` call will be distributed on 8 GPUs.
    # Since the batch size is 256, each GPU will process 32 samples.
    parallel_model.fit(x, y, epochs=20, batch_size=batch_size)
    
    # Save model via the template model (which shares the same weights):
    model.save('my_model.h5')
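
    A related gotcha is checkpointing: keras.callbacks.ModelCheckpoint saves whichever model it is attached to, so attaching it to parallel_model saves the multi-GPU wrapper rather than the template. Below is a minimal sketch of the usual workaround; TemplateModelCheckpoint is a hypothetical helper (not part of Keras) that saves the template model, which shares weights with the GPU replicas.

    import keras

    class TemplateModelCheckpoint(keras.callbacks.Callback):
        # Hypothetical helper: saves the single-GPU template model at the
        # end of every epoch instead of the multi_gpu_model wrapper.
        def __init__(self, template_model, filepath):
            super(TemplateModelCheckpoint, self).__init__()
            self.template_model = template_model
            self.filepath = filepath

        def on_epoch_end(self, epoch, logs=None):
            # The template and parallel models share weights, so saving
            # the template captures everything learned on the GPUs.
            self.template_model.save(self.filepath.format(epoch=epoch))

    # Usage (assumes model / parallel_model from the snippet above):
    # parallel_model.fit(x, y, epochs=20, batch_size=batch_size,
    #                    callbacks=[TemplateModelCheckpoint(model, 'ckpt_{epoch:02d}.h5')])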
    

    Results

    Results from "Multi-GPU training with Keras, Python, and deep learning", run on Onepanel.io:
    to validate this, we trained MiniGoogLeNet on the CIFAR-10 dataset with 4 V100 GPUs.

    Using a single GPU we obtained 63-second epochs and a total training time of 74m10s.
    By using multi-GPU training with Keras and Python we decreased that to 16-second epochs and a total training time of 19m3s.
    Roughly a 4x speedup!
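
    To collect per-epoch timings like the ones above, a simple Keras callback works. This is a minimal sketch (EpochTimer is a hypothetical helper, not part of Keras), assuming the parallel_model from the guide above:

    import time
    import keras

    class EpochTimer(keras.callbacks.Callback):
        # Hypothetical helper: records wall-clock seconds per epoch.
        def on_train_begin(self, logs=None):
            self.times = []

        def on_epoch_begin(self, epoch, logs=None):
            self.start = time.time()

        def on_epoch_end(self, epoch, logs=None):
            self.times.append(time.time() - self.start)
            print('epoch %d took %.1fs' % (epoch + 1, self.times[-1]))

    # timer = EpochTimer()
    # parallel_model.fit(x, y, epochs=20, batch_size=batch_size, callbacks=[timer])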

    History

    • 20190910: created.
