Some of the data I needed had no ready-made dataset, so the natural way to gather training data for the neural network was to collect it with a web crawler. I first tried a "dynamic" crawler script found online, but most of the images it fetched were duplicates and very few were actually usable.
Dynamic crawler:
from lxml import etree
import requests
import re
import time
import os

local_path = '/home/path/'
if not os.path.exists(local_path):
    os.makedirs(local_path)

keyword = input('Enter the keyword to search images for: ')
first_url = 'http://image.baidu.com/search/flip?tn=baiduimage&ipn=r&ct=201326592&cl=2&lm=-1&st=-1&fm=result&fr=&sf=1&fmq=1530850407660_R&pv=&ic=0&nc=1&z=&se=1&showtab=0&fb=0&width=&height=&face=0&istype=2&ie=utf-8&ctd=1530850407660%5E00_1651X792&word={}'.format(keyword)
want_download = input('Enter how many images to download: ')

page_num = 1
download_num = 0


# Work out the image format from the end of its URL
def get_format(pic_url):
    # The extension sits at the end of the URL; fall back to jpg when it is not a common image format
    t = pic_url.split('.')
    if t[-1].lower() not in ('bmp', 'gif', 'jpg', 'png'):
        pic_format = 'jpg'
    else:
        pic_format = t[-1]
    return pic_format


# Get the URL of the next result page
def get_next_page(page_url):
    global page_num
    html = requests.get(page_url).text
    with open('html_info.txt', 'w', encoding='utf-8') as h:
        h.write(html)
    selector = etree.HTML(html)
    try:
        msg = selector.xpath('//a[@class="n"]/@href')
        print(msg[0])
        next_page = 'http://image.baidu.com/' + msg[0]
        print('Now on page %d' % (page_num + 1))
    except Exception as e:
        print('No more pages')
        print(e)
        next_page = None
    page_num = page_num + 1
    return next_page


# Download the images and save them to disk
def download_img(pic_urls):
    count = 1
    global download_num
    for i in pic_urls:
        time.sleep(1)
        try:
            pic_format = get_format(i)
            pic = requests.get(i, timeout=15)
            # Save the image, naming it after the page and its position on that page
            with open(local_path + 'page%d_%d.%s' % (page_num, count, pic_format), 'wb') as f:
                f.write(pic.content)
            # print('Downloaded image %s: %s' % (str(count), str(pic.url)))
            count = count + 1
            download_num = download_num + 1
        except Exception as e:
            # print('Failed to download image %s: %s' % (str(count), str(pic.url)))
            print(e)
            count = count + 1
            continue
        finally:
            if int(want_download) == download_num:
                return 0


# Extract the image URLs from a result page
def get_pic_urls(web_url):
    html = requests.get(web_url).text
    # Find the image addresses ("objURL") with a regular expression
    pic_urls = re.findall('"objURL":"(.*?)",', html, re.S)
    # Return the image addresses as a list
    return pic_urls


if __name__ == "__main__":
    while True:
        pic_urls = get_pic_urls(first_url)
        if download_img(pic_urls) == 0:
            break
        next_url = get_next_page(first_url)
        if next_url is None:
            print('No more images')
            break
        first_url = next_url
    # print('Successfully downloaded %d images' % download_num)
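In hindsight, part of the duplication could probably have been caught before anything was downloaded, by deduplicating the extracted objURL list across pages. A minimal sketch under that assumption (the seen_urls set and the dedup_urls helper are names of my own, not part of the script above):

# Hypothetical helper: keep only image URLs that have not appeared on earlier pages.
seen_urls = set()

def dedup_urls(pic_urls, seen=seen_urls):
    fresh = []
    for u in pic_urls:
        if u not in seen:  # skip URLs that were already queued or downloaded
            seen.add(u)
            fresh.append(u)
    return fresh

# Usage idea: pic_urls = dedup_urls(get_pic_urls(first_url)) before calling download_img.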
To filter out the duplicate images, I then used a hashing algorithm (dHash) for deduplication.
# -*- coding: utf-8 -*-
"""
Use dHash to decide whether two photos are the same:
a hash based on gradient comparison. The explicit hash code
can be omitted (it is omitted here).
By Guanpx
"""
import os
from PIL import Image


def picPostfix():  # set of image file extensions
    postFix = set()
    postFix.update(['bmp', 'jpg', 'png', 'tiff', 'gif', 'pcx', 'tga', 'exif',
                    'fpx', 'svg', 'psd', 'cdr', 'pcd', 'dxf', 'ufo', 'eps', 'JPG', 'raw', 'jpeg'])
    return postFix


def getDiff(width, high, image):  # resize the image to width * high and compute its difference sequence
    diff = []
    im = image.resize((width, high))
    imgray = im.convert('L')  # convert to grayscale for easier processing
    pixels = list(imgray.getdata())  # pixel data, grayscale values 0-255

    for row in range(high):  # compare every pixel with the one to its left
        rowStart = row * width  # index where this row starts
        for index in range(width - 1):
            leftIndex = rowStart + index
            rightIndex = leftIndex + 1  # left/right positions
            diff.append(pixels[leftIndex] > pixels[rightIndex])

    return diff  # the difference sequence; it could be packed into a hash code


def getHamming(diff, diff2):  # brute-force Hamming distance between two sequences
    hamming_distance = 0
    for i in range(len(diff)):
        if diff[i] != diff2[i]:
            hamming_distance += 1
    return hamming_distance


if __name__ == '__main__':
    width = 32
    high = 32  # size after downscaling
    dirName = "/home/yourpath"  # album path
    allDiff = []
    postFix = picPostfix()  # set of image extensions

    dirList = os.listdir(dirName)
    cnt = 0
    for i in dirList:
        cnt += 1
        # print('Number of files processed:', cnt)
        if str(i).split('.')[-1] in postFix:  # check whether the extension looks like an image
            try:
                im = Image.open(os.path.join(dirName, i))
            except (OSError, IndexError) as err:  # IOError is an alias of OSError in Python 3
                os.remove(os.path.join(dirName, i))  # delete images that cannot be opened
                print('Error: {}'.format(err))
            else:
                diff = getDiff(width, high, im)
                allDiff.append((str(i), diff))

    for i in range(len(allDiff)):
        for j in range(i + 1, len(allDiff)):
            ans = getHamming(allDiff[i][1], allDiff[j][1])
            if ans <= 5:  # Hamming-distance threshold; tune it for your own data
                print(allDiff[i][0], "and", allDiff[j][0], "maybe same photo...")
                result = dirName + "/" + allDiff[j][0]
                if os.path.exists(result):
                    os.remove(result)
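The comment in getDiff hints that the boolean difference sequence can be packed into an actual hash code. A small sketch of that idea (to_hash and hamming_int are hypothetical helper names I chose, not part of the script above):

def to_hash(diff):
    # Pack the True/False difference sequence into one integer (the dHash value).
    value = 0
    for bit in diff:
        value = (value << 1) | int(bit)
    return value

def hamming_int(h1, h2):
    # Hamming distance between two packed hashes: XOR them and count the differing bits.
    return bin(h1 ^ h2).count('1')

Comparing packed integers this way gives the same distances as getHamming on the raw sequences, just with a cheaper per-pair comparison.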
After filtering with the hash algorithm I found that too many images were being removed, and the threshold was hard to tune. So I tried a "static" crawler instead; the results were fairly good, with few duplicates, which also made the deduplication step unnecessary.
Static crawler:
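For anyone who does want to stay with the hash approach, one way to get a feel for the threshold is a dry run that only counts how many pairs each threshold would flag, without deleting anything. A rough sketch that reuses allDiff and getHamming from the script above (threshold_report is a hypothetical helper of mine):

from collections import Counter

def threshold_report(allDiff, thresholds=(3, 5, 8, 10)):
    # Count, for each candidate threshold, how many image pairs would be flagged as duplicates.
    flagged = Counter()
    for i in range(len(allDiff)):
        for j in range(i + 1, len(allDiff)):
            d = getHamming(allDiff[i][1], allDiff[j][1])
            for t in thresholds:
                if d <= t:
                    flagged[t] += 1
    for t in sorted(flagged):
        print('threshold', t, '->', flagged[t], 'pairs flagged')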
# -*- coding: utf-8 -*-
import os
import time
import requests


# Query Baidu Images and parse the JSON result pages
def getManyPages(keyword, pages):
    '''
    keyword: the keyword of the images to download
    pages:   the number of result pages to fetch (30 images per page)
    '''
    params = []
    for i in range(30, 30 * pages + 30, 30):
        params.append({
            'tn': 'resultjson_com',
            'ipn': 'rj',
            'ct': 201326592,
            'is': '',
            'fp': 'result',
            'queryWord': keyword,
            'cl': 2,
            'lm': -1,
            'ie': 'utf-8',
            'oe': 'utf-8',
            'adpicid': '',
            'st': -1,
            'z': '',
            'ic': 0,
            'word': keyword,
            's': '',
            'se': '',
            'tab': '',
            'width': '',
            'height': '',
            'face': 0,
            'istype': 2,
            'qc': '',
            'nc': 1,
            'fr': '',
            'pn': i,
            'rn': 30,
            'gsm': '1e',
            '1488942260214': ''
        })
    url = 'https://image.baidu.com/search/acjson'
    urls = []
    for i in params:
        try:
            urls.append(requests.get(url, params=i).json().get('data'))
        except OSError as err:
            print('OS error : {}'.format(err))
        except IndexError as err:
            print('Index Error: {}'.format(err))
        except ValueError as err:
            print('JSON decode error : {}'.format(err))  # json.decoder.JSONDecodeError is a ValueError
        except Exception:
            print('Other error')
    return urls


# Download the images and save them
def getImg(dataList, localPath):
    '''
    dataList:  the list of parsed result pages returned by getManyPages
    localPath: the directory in which to save the downloaded images
    '''
    if not os.path.exists(localPath):  # create the save directory if it does not exist
        os.mkdir(localPath)
    x = 0
    for page in dataList:
        for i in page:
            if i.get('thumbURL') is not None:
                # print('Downloading: %s' % i.get('thumbURL'))
                ir = requests.get(i.get('thumbURL'))
                with open(localPath + '/' + '%d.jpg' % x, 'wb') as f:  # the '/' joins the folder and file name
                    f.write(ir.content)
                x += 1
            else:
                print('Image link does not exist')


# Download images for every sub-folder name under father_path
if __name__ == '__main__':
    father_path = "/home/yourpath/"
    t0 = time.time()
    for init in os.listdir(father_path):
        print('init is {}'.format(str(init)))
        for name in os.listdir(os.path.join(father_path, init)):
            print('name is {}'.format(str(name)))
            t1 = time.time()
            if not os.listdir(os.path.join(father_path, init, name)):  # only fill folders that are still empty
                dataList = getManyPages(name, 30)
                getImg(dataList, os.path.join(father_path, init, name))
            t2 = time.time()
            print('cost time is', t2 - t1)
    t3 = time.time()
    print('total time is', t3 - t0)
    # t1 = time.time()
    # dataList = getManyPages('keyword', page_number)  # arg 1: keyword, arg 2: number of pages to download
    # getImg(dataList, './file_path/')                 # arg 2: the save path
    # t2 = time.time()
    # print('cost time is', t2 - t1)
    #
    # parent_name = "/home/path"  # album path
    # dirList = os.listdir(parent_name)  # list of all sub-folders
    # for one_file in dirList:  # one of the sub-folders
    #     # son_list = os.listdir(one_file)
    #     son_list = os.path.join(parent_name, one_file)
    #     son_file = os.listdir(son_list)
    #     t1 = time.time()
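For a single keyword, the two functions can also be called directly, as the commented-out lines at the end suggest. A minimal usage sketch (the keyword 'cat' and the output folder are placeholders of mine; it assumes getManyPages and getImg from the script above are in scope):

# Fetch two 30-image result pages for one keyword and save the thumbnails locally.
dataList = getManyPages('cat', 2)   # keyword and the number of result pages
getImg(dataList, './cat_images')    # download into this (placeholder) folder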