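• photo2cartoon 中 face_seg.py 针对 TensorFlow 2.x 版本兼容问题的修改
  TensorFlow 2.x 移除了 tf.ConfigProto、tf.Session、tf.GraphDef 等 1.x 顶层接口，因此下面的代码改用 tf.compat.v1 中的对应接口。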


    import os
    import cv2
    import numpy as np
    import tensorflow as tf
    from tensorflow.python.platform import gfile
    
    
    # Absolute directory containing this module; FaceSeg uses it as the default
    # location of the segmentation model file.
    curPath = os.path.dirname(os.path.abspath(__file__))
    
    
    class FaceSeg:
        """Face segmentation using a frozen TensorFlow graph (``.pb`` file).

        The graph is loaded and run through the ``tf.compat.v1`` graph/session API,
        so the class works under TensorFlow 2.x, where the top-level ``tf.Session``,
        ``tf.ConfigProto`` and ``tf.GraphDef`` aliases no longer exist.

        Each instance owns a ``tf.compat.v1.Session``. Call :meth:`close`, or use
        the instance in a ``with`` statement, to release it when finished.
        """

        def __init__(self, model_path=os.path.join(curPath, 'seg_model_384.pb')):
            """Load the frozen graph and look up its input/output tensors.

            Args:
                model_path: Path to the frozen ``GraphDef`` protobuf. Defaults to
                    ``seg_model_384.pb`` in this module's directory.
            """
            config = tf.compat.v1.ConfigProto()
            # Grow GPU memory on demand instead of reserving the whole device up front.
            config.gpu_options.allow_growth = True
            self._graph = tf.Graph()
            self._sess = tf.compat.v1.Session(config=config, graph=self._graph)

            self.pb_file_path = model_path
            try:
                self._restore_from_pb()
                self.input_op = self._graph.get_tensor_by_name('input_1:0')
                self.output_op = self._graph.get_tensor_by_name('sigmoid/Sigmoid:0')
            except Exception:
                # A missing/corrupt model or unexpected tensor names must not leak
                # the session that was just opened; clean up and propagate.
                self._sess.close()
                raise

        def _restore_from_pb(self):
            """Parse ``self.pb_file_path`` and import its nodes into ``self._graph``.

            Nodes are imported without a name prefix (``name=''``) so tensors can be
            fetched by their original names, e.g. ``'input_1:0'``.
            """
            with self._graph.as_default():
                with tf.io.gfile.GFile(self.pb_file_path, 'rb') as f:
                    graph_def = tf.compat.v1.GraphDef()
                    graph_def.ParseFromString(f.read())
                tf.compat.v1.import_graph_def(graph_def, name='')

        def input_transform(self, image):
            """Prepare an image for the network.

            Resizes to 384x384 with area interpolation, divides by 255 and prepends
            a batch axis. The indexing below requires a 3-D (H, W, C) array;
            presumably an 8-bit image, so values land in [0, 1] — confirm with callers.

            Returns:
                Array of shape (1, 384, 384, C).
            """
            image = cv2.resize(image, (384, 384), interpolation=cv2.INTER_AREA)
            image_input = (image / 255.)[np.newaxis, :, :, :]
            return image_input

        def output_transform(self, output, shape):
            """Resize the network output to ``shape`` and convert it to ``uint8``.

            Args:
                output: Single (un-batched) network output.
                shape: Target ``(height, width)``; note ``cv2.resize`` takes
                    ``(width, height)``, hence the swap below.

            Returns:
                ``uint8`` mask of the target size, ``output * 255`` truncated.
            """
            output = cv2.resize(output, (shape[1], shape[0]))
            image_output = (output * 255).astype(np.uint8)
            return image_output

        def get_mask(self, image):
            """Return the segmentation mask for ``image`` at its original height/width."""
            image_input = self.input_transform(image)
            # [0] drops the batch axis added by input_transform.
            output = self._sess.run(self.output_op, feed_dict={self.input_op: image_input})[0]
            return self.output_transform(output, shape=image.shape[:2])

        def close(self):
            """Release the TensorFlow session (and any device memory it holds)."""
            self._sess.close()

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc_value, traceback):
            # Always release the session; exceptions are not suppressed.
            self.close()
  • 相关阅读:
    Android 模拟系统事件(三)
    全民Scheme(2):来自星星的你
    Java经典23种设计模式之行为型模式(三)
    libmysqld,嵌入式MySQLserver库
    闲云控制台(一)控制台命令解析框架
    怎样改动android系统字体大小
    [多校2015.02.1006 高斯消元] hdu 5305 Friends
    换工作经历和心得
    安卓实训第七天---多线程下载实现(进度条)
    校园双选会,你都懂么
  • 原文地址:https://www.cnblogs.com/hxjbc/p/12836519.html
Copyright © 2020-2023  润新知