# -*- coding: utf-8 -*-
"""Split a two-level folder tree of files into train/test list files.

The expected layout is ``<root>/<group>/<class>/<file>``.  For every class
folder a random 20% of its file names go to the test list and the remaining
names go to the train list, each followed by the class's integer label.
"""
from pathlib import Path  # import Path from pathlib
import os
import fileinput
import random

# Dataset root; must keep the trailing 'total/' for the prefix logic below.
root_path = '/home/tay/Videos/trash/垃圾分类项目/total/'
# Default output list files (opened in append mode by gen_txt).
TRAIN_LIST = './trash_train.txt'
TEST_LIST = './trash_test.txt'
# Current working directory with a trailing slash.  The original comment
# called it "the val data path / training-set path"; nothing in this script
# reads it.
pwd = os.getcwd() + '/'


def gen_txt(root=None, train_list=TRAIN_LIST, test_list=TEST_LIST, rate=0.2):
    """Append a random train/test split of every class folder to list files.

    Every ``<root>/<group>/<class>`` directory is assigned the next integer
    label (starting at 1, in ``os.listdir`` order).  ``int(n * rate)`` of its
    ``n`` file names are drawn with ``random.sample`` for the test list; all
    other names go to the train list.  Each record is written as
    ``<prefix><group>/<class>/<name> <label> `` where ``<prefix>`` is the
    text of ``root`` after its last ``'total/'`` (empty for the default
    root).

    Args:
        root: Dataset root as a string; defaults to the module-level
            ``root_path``.
        train_list: Path of the train list file (appended to).
        test_list: Path of the test list file (appended to).
        rate: Fraction of each class folder sampled into the test list.

    Raises:
        ValueError: If ``rate`` is outside ``[0, 1]``.
    """
    if not 0 <= rate <= 1:
        raise ValueError('rate must be between 0 and 1, got {!r}'.format(rate))
    if root is None:
        root = root_path

    # Loop-invariant: part of the root that follows the last 'total/'.
    prefix = root.split('total/')[-1]
    label = 0

    # Context managers guarantee both list files are flushed and closed,
    # even if listing or sampling fails half-way through.
    with open(train_list, 'a') as train_file, open(test_list, 'a') as test_file:
        for group_name in os.listdir(root):
            group_dir = os.path.join(root, group_name)
            if not os.path.isdir(group_dir):
                # Stray files next to the group folders would make
                # os.listdir raise; they carry no samples, so skip them.
                continue
            print('file is{}'.format(str(group_name)))

            for class_name in os.listdir(group_dir):  # class sub-folder
                class_dir = os.path.join(group_dir, class_name)
                if not os.path.isdir(class_dir):
                    continue  # same reasoning as above; no label consumed
                print('init is{}'.format(str(class_name)))
                label += 1

                names = os.listdir(class_dir)
                print('pathDir is', names)

                # Randomly pick the test share of this class's file names.
                sample = random.sample(names, int(len(names) * rate))
                print('sample is', sample)

                # Set membership keeps the split linear in the folder size.
                picked = set(sample)
                same = [name for name in names if name in picked]
                diff = [name for name in names if name not in picked]

                head = prefix + group_name + '/' + class_name + '/'
                # NOTE(review): records end with a single space, not a
                # newline, exactly as in the original script - confirm the
                # consumer of these list files expects that.
                tail = ' ' + str(label) + ' '

                test_file.writelines(head + name + tail for name in sample)
                print('different', diff)
                print('same', same)
                train_file.writelines(head + name + tail for name in diff)


if __name__ == '__main__':
    gen_txt()
脚本使用 random.sample 函数,从每个类别文件夹中随机选取特定数量(20%)的文件名作为测试集;再通过比较两个列表、找出其中不同的元素(即未被抽中的文件名)来得到训练集的文件名。
总体上,整个过程就是在进行字符串操作:拼接路径和标签,并写入文本文件。