• TensorFlow基础笔记(5) VGGnet_test


    参考

    http://blog.csdn.net/jsond/article/details/72667829

    资源

    1.相关的vgg模型下载网址

    http://www.vlfeat.org/matconvnet/models/beta16/imagenet-vgg-verydeep-19.mat

    2.ImageNet 1000种分类以及排列

    https://github.com/sh1r0/caffe-Android-demo/blob/master/app/src/main/assets/synset_words.txt （如果单独下载的 txt 文件格式不对，请下载整个仓库后再取出该文件）

     

     

    这里以 VGG19（即 VGG 论文中的 E 配置网络）作为测试模型。

    #coding=utf-8
    import numpy as np
    import scipy.misc
    import scipy.io as sio
    import tensorflow as tf
    import os
    
    
    def _conv_layer(input, weight, bias):
        """Convolve ``input`` with a fixed kernel (stride 1, SAME padding), then add ``bias``.

        ``weight`` is wrapped in ``tf.constant``, so the pretrained kernel is
        embedded in the graph as a constant rather than a variable.
        """
        kernel = tf.constant(weight)
        unit_strides = (1, 1, 1, 1)
        convolved = tf.nn.conv2d(input, kernel, strides=unit_strides, padding='SAME')
        return tf.nn.bias_add(convolved, bias)
    
    
    def _pool_layer(input):
        """Downsample ``input`` with 2x2 max pooling, stride 2, SAME padding."""
        window = (1, 2, 2, 1)  # the same tuple serves as both the kernel size and the stride
        return tf.nn.max_pool(input, ksize=window, strides=window, padding='SAME')
    
    
    def _fc_layer(input, weights, bias):
        """Fully connected layer: flatten all non-batch dims, then ``x @ weights + bias``.

        The non-batch dimensions of ``input`` must be statically known, because
        their product is computed from ``get_shape()`` while building the graph.
        """
        flat_size = 1
        for extent in input.get_shape().as_list()[1:]:
            flat_size = flat_size * extent
        flattened = tf.reshape(input, [-1, flat_size])
        logits = tf.matmul(flattened, weights)
        return tf.nn.bias_add(logits, bias)
    
    
    def _softmax_preds(input):
        """Turn logits into class probabilities via softmax (op name: 'prediction')."""
        return tf.nn.softmax(input, name='prediction')
    
    
    def _preprocess(image, mean_pixel):
        """Centre an image before it is fed to the network.

        Returns ``image - mean_pixel``; ``_unprocess`` performs the inverse.
        In this script ``mean_pixel`` is a per-channel mean, which broadcasts
        over the image's last axis.
        """
        centred = image - mean_pixel
        return centred
    
    
    def _unprocess(image, mean_pixel):
        """Undo ``_preprocess`` by adding ``mean_pixel`` back, e.g. before displaying an image."""
        restored = image + mean_pixel
        return restored
    
    
    def _get_img(src, img_size=False):
        """Read the image at ``src`` in RGB mode and return it as a float32 array.

        Args:
            src: path of the image file.
            img_size: ``False`` (default) keeps the original size; any other
                value is passed to ``scipy.misc.imresize`` as the target size,
                e.g. ``(224, 224, 3)``.

        An array that is not H x W x 3 is stacked three times along its last
        axis (intended for a single-channel H x W image).

        NOTE: ``scipy.misc.imread`` / ``imresize`` were removed from newer SciPy
        releases (1.2 / 1.3); running this needs an older SciPy with Pillow.
        """
        pixels = scipy.misc.imread(src, mode='RGB')
        has_three_channels = len(pixels.shape) == 3 and pixels.shape[2] == 3
        if not has_three_channels:
            pixels = np.dstack((pixels, pixels, pixels))
        if img_size != False:  # noqa: E712 -- False is the "do not resize" sentinel
            pixels = scipy.misc.imresize(pixels, img_size)
        return pixels.astype(np.float32)
    
    
    def list_files(in_path):
        """Return the bare names of the non-directory entries directly inside ``in_path``.

        Subdirectories are not descended into. If ``os.walk`` yields nothing
        (e.g. the directory does not exist), an empty list is returned.
        """
        # The first tuple from os.walk describes the top-level directory only.
        top_level = next(os.walk(in_path), None)
        if top_level is None:
            return []
        _dirpath, _subdirs, names = top_level
        return list(names)
    
    
    def _get_files(img_dir):
        """Return ``os.path.join(img_dir, name)`` for each file directly inside ``img_dir``."""
        paths = []
        for name in list_files(img_dir):
            paths.append(os.path.join(img_dir, name))
        return paths
    
    def _get_allClassificationName(file_path, encoding=None):
        """Read a class-label file and return its lines.

        The script indexes the returned list by class id, so line ``k`` of the
        file is taken as the label of class ``k``. Each string keeps its
        trailing newline, exactly as ``readlines()`` returns it.

        Args:
            file_path: path to the label file (e.g. ``synset_words.txt``).
            encoding: text encoding passed to ``open``. ``None`` (the default)
                uses the platform's locale encoding, as the original code did.

        Returns:
            list of str: the file's lines.
        """
        # ``with`` closes the file even if reading raises; the previous
        # open()/close() pair leaked the handle on that path.
        with open(file_path, 'r', encoding=encoding) as label_file:
            return label_file.readlines()
    
    def net(data, input_image):
        """Build the VGG-19 forward pass on ``input_image`` from MatConvNet weights.

        Args:
            data: the dict returned by ``scipy.io.loadmat`` for
                ``imagenet-vgg-verydeep-19.mat``. ``data['layers'][0][i]`` is
                assumed to hold the parameters of ``layers[i]`` (same order), and
                ``data['normalization'][0][0][0]`` holds the normalization mean
                that ``mean_pixel`` is derived from.
            input_image: float tensor of mean-subtracted images. The fully
                connected layers flatten using the static shape, so the
                non-batch dimensions must be known when the graph is built
                (224 x 224 x 3 in this script).

        Returns:
            (endpoints, mean_pixel, layers): a dict mapping every layer name to
            its output tensor (``endpoints['softmax']`` holds the class
            probabilities), the mean pixel averaged over axes (0, 1) of the
            normalization mean, and the tuple of layer names.
        """
        layers = (
            'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1',

            'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',

            'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2',
            'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3',

            'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2',
            'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4',

            'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2',
            'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4', 'pool5',

            'fc6', 'relu6',
            'fc7', 'relu7',
            'fc8', 'softmax'
        )

        # Derive the mean pixel from the model file itself. Previously this
        # function returned a module-level ``mean_pixel`` that only exists when
        # the file runs as a script, so importing and calling net() raised NameError.
        mean_pixel = np.mean(data['normalization'][0][0][0], axis=(0, 1))

        weights = data['layers'][0]
        endpoints = {}
        current = input_image
        for i, name in enumerate(layers):
            if name.startswith('conv'):
                kernels, bias = weights[i][0][0][0][0]
                # This script treats MatConvNet kernels as [W, H, in, out];
                # tf.nn.conv2d expects [H, W, in, out], so swap the first two axes.
                kernels = np.transpose(kernels, (1, 0, 2, 3))
                current = _conv_layer(current, kernels, bias.reshape(-1))
            elif name.startswith('relu'):
                current = tf.nn.relu(current)
            elif name.startswith('pool'):
                current = _pool_layer(current)
            elif name.startswith('fc'):
                kernels, bias = weights[i][0][0][0][0]
                if kernels.ndim == 4:
                    # Apply the same W/H swap as the conv layers so the flattened
                    # weight rows follow tf.reshape's (h, w, c) order of the NHWC
                    # feature map. This matters for fc6's 7x7 spatial kernel.
                    kernels = np.transpose(kernels, (1, 0, 2, 3))
                kernels = kernels.reshape(-1, kernels.shape[-1])
                current = _fc_layer(current, kernels, bias.reshape(-1))
            elif name.startswith('soft'):
                current = _softmax_preds(current)

            endpoints[name] = current
        assert len(endpoints) == len(layers)
        return endpoints, mean_pixel, layers
    
    
    if __name__ == '__main__':
        imagenet_path = 'imagenet-vgg-verydeep-19.mat'
        image_dir = 'images/'
        top_k = 3  # number of best-scoring classes to print per image

        data = sio.loadmat(imagenet_path)  # load the pretrained VGG-19 .mat model
        mean = data['normalization'][0][0][0]
        mean_pixel = np.mean(mean, axis=(0, 1))  # mean pixel, averaged over axes (0, 1)

        lines = _get_allClassificationName('synset_words.txt')  # ImageNet class labels
        images = _get_files(image_dir)  # paths of the images to classify

        # Build the network ONCE and feed each image through a placeholder
        # (TF 1.x API, matching tf.Session below). Previously the whole network,
        # with every weight embedded as a graph constant, was rebuilt for each
        # image, so the graph kept growing. Each image also ran the network
        # four times through separate .eval() calls.
        image_input = tf.placeholder(tf.float32, shape=(None, 224, 224, 3))
        nets, _, _ = net(data, image_input)
        preds = nets['softmax']

        with tf.Session() as sess:
            for imgPath in images:
                image = _get_img(imgPath, (224, 224, 3))  # load and resize to 224 x 224
                image_pre = np.expand_dims(_preprocess(image, mean_pixel), axis=0)
                probs = sess.run(preds, feed_dict={image_input: image_pre.astype(np.float32)})[0]

                # The '\n' escape was lost in the original, leaving a string
                # literal split across two lines (a SyntaxError).
                print('\n#####%s#######' % imgPath)
                for rank, class_index in enumerate(np.argsort(-probs)[:top_k], start=1):
                    # readlines() keeps the newline; strip it so the probability stays on the same line.
                    classificationName = lines[class_index].rstrip()
                    problity = probs[class_index]
                    print('%d.ClassificationName=%s  Problity=%f' % (rank, classificationName, problity))

    分类结果

    #####images/airplay.jpg#######
    1.ClassificationName=n04228054 ski
      Problity=0.177715
    2.ClassificationName=n04286575 spotlight, spot
      Problity=0.108483
    3.ClassificationName=n04127249 safety pin
      Problity=0.026277
    
    #####images/bird.jpg#######
    1.ClassificationName=n01608432 kite
      Problity=0.096818
    2.ClassificationName=n01833805 hummingbird
      Problity=0.072687
    3.ClassificationName=n02231487 walking stick, walkingstick, stick insect
      Problity=0.069186
    
    #####images/cat1.jpg#######
    1.ClassificationName=n02123045 tabby, tabby cat
      Problity=0.232015
    2.ClassificationName=n02123159 tiger cat
      Problity=0.094694
    3.ClassificationName=n02124075 Egyptian cat
      Problity=0.030673
    
    #####images/cat2.jpg#######
    1.ClassificationName=n02123045 tabby, tabby cat
      Problity=0.333797
    2.ClassificationName=n02123159 tiger cat
      Problity=0.164726
    3.ClassificationName=n02124075 Egyptian cat
      Problity=0.057272
    
    #####images/cat3.jpg#######
    1.ClassificationName=n03887697 paper towel
      Problity=0.086723
    2.ClassificationName=n02111889 Samoyed, Samoyede
      Problity=0.055845
    3.ClassificationName=n03131574 crib, cot
      Problity=0.052640
    
    #####images/dog1.jpg#######
    1.ClassificationName=n02096585 Boston bull, Boston terrier
      Problity=0.429622
    2.ClassificationName=n02108089 boxer
      Problity=0.199422
    3.ClassificationName=n02093256 Staffordshire bullterrier, Staffordshire bull terrier
      Problity=0.093615
    
    #####images/dog2.jpg#######
    1.ClassificationName=n02085936 Maltese dog, Maltese terrier, Maltese
      Problity=0.172208
    2.ClassificationName=n03445777 golf ball
      Problity=0.139949
    3.ClassificationName=n02259212 leafhopper
      Problity=0.118109
    
    #####images/lena.jpg#######
    1.ClassificationName=n02869837 bonnet, poke bonnet
      Problity=0.130357
    2.ClassificationName=n04356056 sunglasses, dark glasses, shades
      Problity=0.066170
    3.ClassificationName=n04355933 sunglass
      Problity=0.043199
    
    #####images/sky.jpg#######
    1.ClassificationName=n03733281 maze, labyrinth
      Problity=0.711163
    2.ClassificationName=n03065424 coil, spiral, volute, whorl, helix
      Problity=0.181123
    3.ClassificationName=n04259630 sombrero
      Problity=0.010005
  • 相关阅读:
    POJ 1741
    POJ 3107
    权限管理
    用户和组
    软件包管理简介
    制作网线
    认识vim编辑器
    linux 进阶命令
    linux 目录&基础命令
    在raw_input()中使用中文提示,在CMD下中文乱码问题解决。。。
  • 原文地址:https://www.cnblogs.com/adong7639/p/7652635.html
Copyright © 2020-2023  润新知