• 年龄_性别识别


    参考开源项目:年龄_性别识别

    1.识别效果如下图

    2. 将 Keras 模型转换为 pb 模型,方便模型迁移以及在 RKNN 平台上使用,代码 1 如下:

    from keras.models import load_model
    import tensorflow as tf
    import os
    import os.path as osp
    from keras import backend as K
    from wide_resnet import WideResNet
    import tensorflow as tf
    from tensorflow.python.framework import graph_io
    
    # Print the TensorFlow, Keras and Python interpreter versions for reference.
    import keras as ks
    import platform

    print(tf.__version__)
    print(ks.__version__)
    print(platform.python_version())
    
    def freeze_graph(graph, session, output_node_names, model_name):
        """Freeze `graph` and write it as a binary GraphDef under "tmp".

        Training-only nodes are removed, the variables needed by
        `output_node_names` are converted to constants using `session`, and
        the result is written to "tmp/<basename of model_name>.pb".
        """
        pb_filename = os.path.basename(model_name) + ".pb"
        with graph.as_default():
            inference_def = tf.graph_util.remove_training_nodes(graph.as_graph_def())
            frozen_def = tf.graph_util.convert_variables_to_constants(
                session, inference_def, output_node_names
            )
            graph_io.write_graph(frozen_def, "tmp", pb_filename, as_text=False)
            print("done")
    
    
    
    
    def pb_transfer():
        """Rebuild the WideResNet model, load its weights and freeze it to a .pb file.

        The output node names of the model are printed, then the graph of the
        Keras session is handed to `freeze_graph`, which derives the output
        file name from the weight file name.
        """
        # Raw string: in a normal literal "\a" (as in "\age-gender...") is the
        # BEL control character, which silently corrupts this Windows path.
        weight_file = r"E:\python_project\age-gender-estimation-master\pretrained_models\weights.28-3.73.hdf5"

        # Switch Keras to inference mode (learning phase 0) before building the model.
        tf.keras.backend.set_learning_phase(0)

        img_size = 64
        model = WideResNet(img_size, depth=16, k=8)()
        model.load_weights(weight_file)

        output_node_names = [out.op.name for out in model.outputs]
        for node_name in output_node_names:
            print(node_name)

        # NOTE(review): the session comes from tf.keras while this file also
        # imports standalone keras; confirm WideResNet is built on the same
        # backend so the loaded weights live in this session.
        session = tf.keras.backend.get_session()
        freeze_graph(session.graph, session, output_node_names, weight_file)



    if __name__ == '__main__':
        pb_transfer()
    View Code

    代码2如下:

    # coding=utf-8
    
    from keras.models import load_model
    import tensorflow as tf
    import os
    import os.path as osp
    from keras import backend as K
    # Path parameters.
    # Raw string: in a normal literal "\a" (as in "\age-gender...") is the BEL
    # control character, which silently corrupts this Windows path.
    weight_file_path = r"E:\python_project\age-gender-estimation-master\pretrained_models\weights.28-3.73.hdf5"
    output_graph_name = 'ttt.pb'
    def h5_to_pb(h5_model, output_dir, model_name, out_prefix="output_", log_tensorboard=True):
        """Freeze a Keras model into a binary TensorFlow GraphDef (.pb) file.

        Every model output is aliased by an identity op named
        ``<out_prefix><n>`` (n starting at 1); those names are the output
        nodes kept when the variables are converted to constants.

        Args:
            h5_model: Keras model whose ``outputs`` are exported.
            output_dir: Directory the .pb file is written to; created if missing.
            model_name: File name of the .pb file inside ``output_dir``.
            out_prefix: Prefix of the generated output node names.
            log_tensorboard: When true, the written .pb file is also imported
                into a TensorBoard log located in ``output_dir``.
        """
        # exist_ok avoids the separate existence check and also creates
        # missing parent directories.
        os.makedirs(output_dir, exist_ok=True)

        out_nodes = []
        # Iterate `outputs` (always a list). `output` is a bare tensor for a
        # single-output model, so indexing it would slice the tensor instead.
        for index, tensor in enumerate(h5_model.outputs, start=1):
            node_name = out_prefix + str(index)
            tf.identity(tensor, node_name)
            out_nodes.append(node_name)

        sess = K.get_session()
        from tensorflow.python.framework import graph_util, graph_io
        init_graph = sess.graph.as_graph_def()
        main_graph = graph_util.convert_variables_to_constants(sess, init_graph, out_nodes)
        graph_io.write_graph(main_graph, output_dir, name=model_name, as_text=False)

        if log_tensorboard:
            from tensorflow.python.tools import import_pb_to_tensorboard
            import_pb_to_tensorboard.import_to_tensorboard(osp.join(output_dir, model_name), output_dir)
    # Output directory.
    output_dir = "./"
    # Load the model.
    #h5_model = load_model(weight_file_path)
    from keras.models import load_model
    import tensorflow as tf
    import os
    import os.path as osp
    from keras import backend as K
    from wide_resnet import WideResNet
    import tensorflow as tf
    # Raw string: in a normal literal "\a" (as in "\age-gender...") is the BEL
    # control character, which silently corrupts this Windows path.
    weight_file = r"E:\python_project\age-gender-estimation-master\pretrained_models\weights.28-3.73.hdf5"

    output_fld ='./'
    # Switch Keras to inference mode (learning phase 0) before building the model.
    tf.keras.backend.set_learning_phase(0)
    img_size = 64
    # The weight file is loaded into a freshly built WideResNet rather than
    # through load_model (see the commented-out line above).
    model = WideResNet(img_size, depth=16, k=8)()
    model.load_weights(weight_file)
    h5_to_pb(model,output_dir = output_dir,model_name = output_graph_name)
    print('model saved')
    View Code

    3.推理代码如下:

    import tensorflow as tf
    from tensorflow.python.platform import gfile
    import os
    import cv2
    import numpy as np
    import time
    
    from keras.layers import Input, Activation, add, Dense, Flatten, Dropout
    
    # Alternative frozen graph (output of the first conversion script):
    #facenet_model_checkpoint = r"E:\python_project\age-gender-estimation-master\tmp\weights.28-3.73.hdf5.pb"
    # Raw string: in a normal literal "\a" and "\t" are the BEL and TAB control
    # characters, which silently corrupt this Windows path.
    facenet_model_checkpoint = r"E:\python_project\age-gender-estimation-master\ttt.pb"
    
    
    def load_model(model, input_map=None):
        """Import a frozen protobuf graph into the current default TensorFlow graph.

        Args:
            model: Path to a frozen GraphDef (.pb) file; "~" is expanded.
            input_map: Optional mapping forwarded to ``tf.import_graph_def``.

        Raises:
            FileNotFoundError: If ``model`` does not point to an existing file.
        """
        model_exp = os.path.expanduser(model)
        # Fail loudly: a wrong path used to be ignored silently, leaving an
        # empty graph and a confusing error at the first tensor lookup.
        if not os.path.isfile(model_exp):
            raise FileNotFoundError('Frozen graph file not found: %s' % model_exp)

        print('Model filename: %s' % model_exp)
        # GFile is the non-deprecated equivalent of FastGFile.
        with gfile.GFile(model_exp, 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
        # name='' keeps the original node names (no import prefix).
        tf.import_graph_def(graph_def, input_map=input_map, name='')
    
    
    
    def main():
        """Run age/gender inference on one test image in an endless timing loop.

        The frozen graph at `facenet_model_checkpoint` is imported once; each
        iteration then reads the test image, resizes it to 64x64, runs the
        network and prints the elapsed inference time, the predicted age and
        the predicted gender ("m" or "f"). The loop never terminates on its own.

        Raises:
            FileNotFoundError: If the test image cannot be read.
        """
        img_size = 64
        # Raw string: in a normal literal "\a" is BEL and "\003" is an octal
        # escape, which silently corrupt this Windows path.
        image_path = r"E:\python_project\age-gender-estimation-master\0036.jpg"
        # Column vector of the ages 0..100; loop-invariant, so built once.
        ages = np.arange(0, 101).reshape(101, 1)

        with tf.Graph().as_default():
            with tf.Session() as sess:
                print("load model:" + facenet_model_checkpoint)
                load_model(facenet_model_checkpoint)
                print("load over.")

                graph = tf.get_default_graph()
                images_placeholder = graph.get_tensor_by_name("input_1:0")
                gender = graph.get_tensor_by_name("pred_gender/Softmax:0")
                age = graph.get_tensor_by_name("pred_age/Softmax:0")
                while True:
                    img = cv2.imread(image_path)
                    # cv2.imread returns None instead of raising when the file
                    # is missing or unreadable; report that explicitly.
                    if img is None:
                        raise FileNotFoundError("Cannot read image: " + image_path)
                    faces = cv2.resize(img, (img_size, img_size))
                    # Add a leading batch axis so a single image is fed.
                    faces = faces[np.newaxis, :, :, :]

                    # Only the network evaluation and age reduction are timed.
                    start_time = time.time()
                    results = sess.run([gender, age], feed_dict={images_placeholder: faces})
                    predicted_genders = results[0]

                    # Weight each age 0..100 by its softmax output and sum them.
                    predicted_ages = results[1].dot(ages).flatten()
                    print("spend_time is", time.time() - start_time)
                    print(int(predicted_ages[0]))
                    if predicted_genders[0][0] < 0.5:
                        print("m")
                    else:
                        print("f")

    if __name__ == '__main__':
        main()
    View Code

    4.推理时间在tx2上为:70ms

  • 相关阅读:
    Java数据持久层
    一张图解决ThreadLocal
    类加载器及其加载原理
    手写LRU缓存淘汰算法
    使用归并排序思想解决逆序对数量问题
    Same Origin Policy 浏览器同源策略详解
    如何估算线程池的线程数?
    分布式锁为什么要选择Zookeeper而不是Redis?
    SpringBoot的SpringMVC使用FastJson依赖时LocalDateTime全局配置序列化格式
    数据库中的枚举值如何存储
  • 原文地址:https://www.cnblogs.com/liuwenhua/p/13140763.html
Copyright © 2020-2023  润新知