图像的标签保存在一个 JSON 文件（instances.json）中，文件的每一行是一条 JSON 记录。
%matplotlib inline import json import gluonbook as gb import mxnet as mx from mxnet import autograd, gluon, image, init, nd from mxnet.gluon import data as gdata, loss as gloss, utils as gutils import sys from time import time train_Pedestrian_url = [] train_Cyclist_url = [] train_Others_url = [] with open('instances.json',encoding='utf-8') as f: for _ in range(100000): if len(train_Pedestrian_url) + len(train_Cyclist_url) + len(train_Others_url) >= 300: break line = f.readline() js = json.loads(line) if js['attrs']['ignore']=='yes' or js['attrs']['occlusion']=='heavily_occluded' or js['attrs']['occlusion']=='invisible': continue if js['attrs']['type'] == 'Pedestrian': if len(train_Pedestrian_url) >=100: continue train_Pedestrian_url.append(js['thumbnail_path']) elif js['attrs']['type'] == 'Cyclist': if len(train_Cyclist_url) >=100: continue train_Cyclist_url.append(js['thumbnail_path']) elif js['attrs']['type'] == 'Others': if len(train_Others_url) >=100: continue train_Others_url.append(js['thumbnail_path']) # img = image.imread(url) f.close() print(train_Cyclist_url) print(len(train_Pedestrian_url),len(train_Cyclist_url),len(train_Others_url)) img = image.imread('/mnt/hdfs-data-4/data/'+train_Cyclist_url[0]) img.astype('float32') labels = nd.zeros(shape=(30000,)) labels[10000:20000] = 1 labels[20000:] = 2
数据整理就差不多了，接下来就是搭建网络、训练模型了。