• pythoncaffe 2 使用python进行caffe的训练集与预测集的划分


    #!/usr/bin/env python
    # coding: utf-8
    #copyRight by heibanke 
    #如需转载请注明出处
    #<<用Python做深度学习2-caffe>>
    #http://study.163.com/course/courseMain.htm?courseId=1003491001
    
    import os
    import cv2
    import numpy as np
    import pdb
    
    def write_img_list(data, filename):
        with open(filename, 'w') as f:
            for i in xrange(len(data)):
                f.write(data[i][0]+' '+str(data[i][1])+'\n')
    
    
    image_size = 28
    s='ABCDEFGHIJ'
    
    filedir='/media/ye/WindowsFile2/001_ubuntu_workspace/caffe_learning/caffe_learn_code/5/notMNIST_small/'
    
    
    # 1. read file   獲取總文件夾中的中所有的文件序列
    filedir2 = os.listdir(filedir)
    
    #創建一個二維數組,用與保存圖像全路徑以及字符標識符
    datasets=[]
    #創建數組用於存儲原始數據
    data=[]
    #對子文件夾中所有文件進行遍歷
    for subdir in filedir2:
        #通過字符串相加判斷全路徑是否爲文件夾
        if os.path.isdir(filedir+subdir):
            #如果爲文件夾,則獲取其中每一個文件夾的名字
            files=os.listdir(filedir+subdir)
            #創建一個多維度數組,數組爲三維度,第一個維度爲圖像個數,第二以及第三維度爲圖像的長寬
            dataset = np.ndarray(shape=(len(files), image_size, image_size),
                             dtype=np.float32)
            
            num_image = 0
            #在每一個文件夾中獲取對應的文件名
            for file in files:
               #判斷文件名字是否爲png
                if file[-3:]=='png':
                    #讀入每一張圖片
                    tmp=cv2.imread(filedir+subdir+'/'+file,cv2.IMREAD_GRAYSCALE)
                    #判断图像大小是否符合要求,不符合则跳过
                    try:
                        if tmp.shape==(image_size,image_size):
                            # 使用rfind 查找字符串,並將字符串的名字與實際文件名字進行綁定,寫入數據庫
                            datasets.append((filedir+subdir+'/'+file, s.rfind(subdir)))
                            data.append(tmp)
                            num_image+=1
                        else:
                            print subdir,file,tmp.shape
                    except:
                        print subdir,file,tmp
                else:
                    print file
    
    #随机化数据序列 ,將圖像標識符與圖像路徑順序打亂
    np.random.shuffle(datasets)
    #計算均值
    print np.mean(np.array(data))
    
    TRAIN_NUM = 4*len(datasets)/5
     #訓練集使用0~4/5的數據
    write_img_list(datasets[0:TRAIN_NUM], 'train00.imglist')
    #驗證集合使用後面1/5的數據
    write_img_list(datasets[TRAIN_NUM:], 'test00.imglist')
  • 相关阅读:
    Tomcat8
    spring-framework-3.0.2RELEASE之后为啥没有依赖包了?
    foxmail6.5 不能收取电子邮件,反复提示输入密码?
    mysql中select distinct的用法
    mysql 批量更新
    java中数组与List相互转换的方法
    mysql 蠕虫复制
    鼠标聚焦到Text输入框时,按回车键刷新页面原因及解决方法
    com.sun.jdi.InvocationException occurred invoking method.
    linux 让一个程序开机自启动并把一个程序加为服务
  • 原文地址:https://www.cnblogs.com/codeAndlearn/p/16173547.html
Copyright © 2020-2023  润新知