TensorFlow Object Detection API
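Building accurate machine learning models that can localize and identify multiple objects in a single image remains a core challenge in computer vision. The TensorFlow Object Detection API is an open-source framework built on top of TensorFlow that makes it easy to construct, train and deploy object detection models. In the words of its README: "At Google we've certainly found this codebase to be useful for our computer vision needs, and we hope that you will as well."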
1. Install TensorFlow and download the Object Detection API
Install TensorFlow:
CPU version: pip install tensorflow
GPU version: pip install tensorflow-gpu
To upgrade to the latest release (1.4.0 at the time of writing): pip install --upgrade tensorflow-gpu
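To confirm which TensorFlow version the interpreter actually picks up (the CPU and GPU packages can both end up installed), a small check like the one below can help. It is a sketch; the helper names are hypothetical and 1.4.0 is simply the version this post was written against.

```python
def parse_version(text):
    """Turn a version string such as '1.4.0' or '1.4.0rc1' into a (major, minor, patch) tuple."""
    parts = []
    for piece in text.split(".")[:3]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break  # stop at suffixes such as 'rc1'
            digits += ch
        parts.append(int(digits) if digits else 0)
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts)


def installed_tensorflow_version():
    """Return tensorflow.__version__, or None if TensorFlow cannot be imported."""
    try:
        import tensorflow as tf
    except ImportError:
        return None
    return tf.__version__


if __name__ == "__main__":
    version = installed_tensorflow_version()
    if version is None:
        print("TensorFlow is not installed")
    elif parse_version(version) >= parse_version("1.4.0"):
        print("TensorFlow", version, "is 1.4.0 or newer")
    else:
        print("TensorFlow", version, "is older than 1.4.0; run the upgrade command")
```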
Install the required libraries:
sudo pip install pillow
sudo pip install lxml
sudo pip install jupyter
sudo pip install matplotlib
Compile the protobufs (run this in the tensorflow/models/research/ directory):
protoc object_detection/protos/*.proto --python_out=.
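The *.proto wildcard relies on the shell to expand it; some shells (for example the Windows command prompt) pass it through literally and protoc then fails. A minimal sketch, assuming Python 3, protoc on PATH, and that it is run from tensorflow/models/research/, which expands the file list itself before calling protoc (protoc_command is a hypothetical helper):

```python
import glob
import os
import subprocess


def protoc_command(research_dir="."):
    """Build the protoc command line with every object_detection/protos/*.proto listed."""
    pattern = os.path.join(research_dir, "object_detection", "protos", "*.proto")
    # Paths are made relative to research_dir, the directory protoc will run in.
    protos = [os.path.relpath(p, research_dir).replace(os.sep, "/")
              for p in sorted(glob.glob(pattern))]
    if not protos:
        raise FileNotFoundError("no .proto files match " + pattern)
    return ["protoc"] + protos + ["--python_out=."]


if __name__ == "__main__":
    research_dir = "."  # assumption: current directory is tensorflow/models/research/
    if os.path.isdir(os.path.join(research_dir, "object_detection", "protos")):
        subprocess.run(protoc_command(research_dir), check=True, cwd=research_dir)
    else:
        print("run this from tensorflow/models/research/")
```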
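Add the API and slim to PYTHONPATH (also run in tensorflow/models/research/). The export only lasts for the current terminal session, so repeat it in every new shell, or add it to ~/.bashrc with absolute paths instead of `pwd`: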
export PYTHONPATH=$PYTHONPATH:`pwd`:`pwd`/slim
Test the installation:
python object_detection/builders/model_builder_test.py
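model_builder_test.py exercises the model builder; for a quicker look at which dependencies Python can actually see, a check like the one below can help. It is a sketch assuming Python 3 (importlib.util.find_spec); object_detection and nets are only found while the PYTHONPATH export above is active in the current shell, and a missing nets is exactly what causes the ImportError in step 6.

```python
import importlib.util

# Modules used by the later steps. "object_detection" comes from research/ and
# "nets" from research/slim, so both depend on PYTHONPATH being set.
REQUIRED = ["tensorflow", "PIL", "lxml", "matplotlib", "object_detection", "nets"]


def missing_modules(names=REQUIRED):
    """Return the top-level module names that cannot be found on sys.path."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules()
    if missing:
        print("Not importable:", ", ".join(missing))
    else:
        print("All required modules are importable")
```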
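Download the Object Detection API (do this first: the protobuf compilation and PYTHONPATH steps above are run inside the cloned repository):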
git clone https://github.com/tensorflow/models.git
Finally, run the demo notebook object_detection_tutorial.ipynb to make sure everything works.
2. Prepare the training dataset
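Create a new folder named dataset (in my setup models/research/object_detection/dataset, as the paths below show), then convert my Pascal VOC-format dataset (VOC3000) to TFRecord format and store the output in the dataset folder: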
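Copy create_pascal_tf_record.py into the dataset folder and edit the copy as follows (the commands further down run the edited copy as data/create_pascal3000_tf_record.py; adjust that name and path to wherever you saved it):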
(1) Line 55: change it to
YEARS = ['VOC2007', 'VOC2012', 'VOC3000', 'merged']
(2) Line 58: change
def dict_to_tf_example(data,
to
def dict_to_tf_example(year, data,
(3) Line 84: change
img_path = os.path.join(data['folder'], image_subdirectory, data['filename'])
to
img_path = os.path.join(year, image_subdirectory, data['filename'])
(This builds the image path from the dataset folder name instead of the <folder> field in each annotation XML, which may not match your directory.)
(4) Line 152: change
years = ['VOC2007', 'VOC2012']
to
years = ['VOC2007', 'VOC2012', 'VOC3000']
(5) Line 163: change
examples_path = os.path.join(data_dir, year, 'ImageSets', 'Main',
                             'aeroplane_' + FLAGS.set + '.txt')
to
examples_path = os.path.join(data_dir, year, 'ImageSets', 'Main',
                             FLAGS.set + '.txt')
(A custom dataset usually has only train.txt and val.txt, not the per-class lists such as aeroplane_train.txt that standard VOC provides.)
(6) Line 175: change
tf_example = dict_to_tf_example(data, FLAGS.data_dir, label_map_dict,
                                FLAGS.ignore_difficult_instances)
to
tf_example = dict_to_tf_example(year, data, FLAGS.data_dir, label_map_dict,
                                FLAGS.ignore_difficult_instances)
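Changes (3) and (5) together decide which files the script opens: the image ids come from ImageSets/Main/<set>.txt, and each image is looked up under <data_dir>/<year>/JPEGImages/. The stdlib-only sketch below mirrors just that lookup so you can check it against your directory layout before converting. It assumes Python 3; read_examples_list imitates the first-token-per-line behaviour of the API's helper of the same name, image_paths is hypothetical, and neither builds tf.train.Example records.

```python
import os
import xml.etree.ElementTree as ET


def read_examples_list(path):
    """Return the first token of each non-empty line, e.g. '000005' from '000005 -1'."""
    with open(path) as f:
        return [line.strip().split(" ")[0] for line in f if line.strip()]


def image_paths(data_dir, year, split, image_subdirectory="JPEGImages"):
    """Yield (image_id, image_path) pairs the way the modified script locates images."""
    # Change (5): read <split>.txt rather than aeroplane_<split>.txt.
    examples_path = os.path.join(data_dir, year, "ImageSets", "Main", split + ".txt")
    annotations_dir = os.path.join(data_dir, year, "Annotations")
    for example in read_examples_list(examples_path):
        xml_path = os.path.join(annotations_dir, example + ".xml")
        filename = ET.parse(xml_path).getroot().findtext("filename")
        # Change (3): the folder comes from `year`, not from the XML <folder> tag.
        yield example, os.path.join(data_dir, year, image_subdirectory, filename)


if __name__ == "__main__":
    data_dir = "/data/models/research/object_detection/dataset/VOCdevkit"
    split_file = os.path.join(data_dir, "VOC3000", "ImageSets", "Main", "train.txt")
    if os.path.exists(split_file):
        for image_id, path in image_paths(data_dir, "VOC3000", "train"):
            if not os.path.exists(path):
                print("missing image for", image_id, "->", path)
```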
All of the paths involved must be adjusted to match your own dataset.
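The conversion commands read the class list from --label_map_path (pascal_label_map.pbtxt), which must list your own classes (10 in my case) rather than the 20 standard VOC ones, with names matching the <name> tags in the annotation XML. A minimal sketch that writes such a file, assuming the usual label map text format in which ids start at 1 (0 is reserved for the background); the class names in the example are placeholders, not my actual classes.

```python
def label_map_text(class_names):
    """Render class names as label map items with ids 1, 2, 3, ..."""
    items = []
    for idx, name in enumerate(class_names, start=1):
        items.append("item {\n  id: %d\n  name: '%s'\n}\n" % (idx, name))
    return "\n".join(items)


def write_label_map(class_names, path):
    """Write the label map text for class_names to path."""
    with open(path, "w") as f:
        f.write(label_map_text(class_names))


if __name__ == "__main__":
    # Placeholder names; replace them with the <name> values used in your XML files.
    print(label_map_text(["class_a", "class_b"]))
```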
Run the following commands to generate the TFRecord files for training and validation:
python data/create_pascal3000_tf_record.py \
    --data_dir=/data/models/research/object_detection/dataset/VOCdevkit \
    --label_map_path=/data/models/research/object_detection/dataset/pascal_label_map.pbtxt \
    --year=VOC3000 \
    --set=train \
    --output_path=/data/models/research/object_detection/dataset/pascal_train.record
python data/create_pascal3000_tf_record.py \
    --data_dir=/data/models/research/object_detection/dataset/VOCdevkit \
    --label_map_path=/data/models/research/object_detection/dataset/pascal_label_map.pbtxt \
    --year=VOC3000 \
    --set=val \
    --output_path=/data/models/research/object_detection/dataset/pascal_val.record
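Before training, it is worth checking that both files were written and that the train/val split has the sizes you expect. The sketch below counts records without TensorFlow by walking the TFRecord framing (an 8-byte little-endian length, a 4-byte CRC of the length, the payload, then a 4-byte CRC of the payload). It assumes uncompressed files and does not verify the CRCs; with TensorFlow 1.x, sum(1 for _ in tf.python_io.tf_record_iterator(path)) gives the same count.

```python
import os
import struct


def count_tfrecords(path):
    """Count records in an uncompressed TFRecord file (CRCs are not checked)."""
    count = 0
    with open(path, "rb") as f:
        while True:
            header = f.read(8)
            if not header:
                break  # clean end of file
            if len(header) < 8:
                raise ValueError("truncated length header in " + path)
            (length,) = struct.unpack("<Q", header)
            # Skip the length CRC, the payload itself and the payload CRC.
            body = f.read(4 + length + 4)
            if len(body) < 4 + length + 4:
                raise ValueError("truncated record in " + path)
            count += 1
    return count


if __name__ == "__main__":
    base = "/data/models/research/object_detection/dataset"
    for name in ("pascal_train.record", "pascal_val.record"):
        path = os.path.join(base, name)
        if os.path.exists(path):
            print(name, count_tfrecords(path))
```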
3. Extract the SSD MobileNet model (the archive was already downloaded along with the API)
tar -xvf ssd_mobilenet_v1_coco_2017_11_08.tar.gz
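This produces the following files:
Copy the three model.ckpt.* files from the extracted folder (model.ckpt.data-00000-of-00001, model.ckpt.index and model.ckpt.meta) into the dataset folder.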
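If you only need the checkpoint, the model.ckpt.* files can also be copied straight out of the archive without extracting everything. A sketch using Python's tarfile module; copy_checkpoint_files is a hypothetical helper, and it writes only each member's base name into the destination, so paths stored inside the archive cannot escape it.

```python
import os
import shutil
import tarfile


def copy_checkpoint_files(archive, dest_dir):
    """Copy the model.ckpt.* members of a model tarball into dest_dir (flattened)."""
    copied = []
    with tarfile.open(archive, "r:*") as tar:  # "r:*" detects gzip compression
        for member in tar.getmembers():
            name = os.path.basename(member.name)
            if member.isfile() and name.startswith("model.ckpt."):
                source = tar.extractfile(member)
                with open(os.path.join(dest_dir, name), "wb") as out:
                    shutil.copyfileobj(source, out)
                copied.append(name)
    return sorted(copied)


if __name__ == "__main__":
    archive = "ssd_mobilenet_v1_coco_2017_11_08.tar.gz"
    dataset_dir = "/data/models/research/object_detection/dataset"
    if os.path.exists(archive) and os.path.isdir(dataset_dir):
        print(copy_checkpoint_files(archive, dataset_dir))
```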
4. Modify the config file.
Copy object_detection/samples/configs/ssd_mobilenet_v1_pets.config into the dataset folder.
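Make the following changes:
(1) Set num_classes to your own number of classes (10 in my case).
(2) Update the paths (5 places): fine_tune_checkpoint, then input_path and label_map_path in train_input_reader, then input_path and label_map_path in eval_input_reader: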
fine_tune_checkpoint: "/data/models/research/object_detection/dataset/model.ckpt"
input_path: "/data/models/research/object_detection/dataset/pascal_train.record"
label_map_path: "/data/models/research/object_detection/dataset/pascal_label_map.pbtxt"
input_path: "/data/models/research/object_detection/dataset/pascal_val.record"
label_map_path: "/data/models/research/object_detection/dataset/pascal_label_map.pbtxt"
Save the config file under the new name ssd_mobilenet_v1_pascal.config. My dataset folder is shown in the figure.
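The step-4 edits can also be scripted, which helps when repeating them for another sample config such as the Faster R-CNN one used in step 5. A minimal sketch using regular expressions on the config text, assuming each field appears as in the sample configs (one num_classes, one fine_tune_checkpoint, and the train reader's input_path before the eval reader's); customize_config is a hypothetical helper, not part of the API.

```python
import os
import re


def customize_config(text, num_classes, checkpoint, train_record, val_record, label_map):
    """Apply the step-4 edits (num_classes plus five paths) to sample config text."""
    # Lambdas keep backslashes in paths from being treated as regex escapes.
    text = re.sub(r"num_classes:\s*\d+", lambda m: "num_classes: %d" % num_classes, text)
    text = re.sub(r'fine_tune_checkpoint:\s*"[^"]*"',
                  lambda m: 'fine_tune_checkpoint: "%s"' % checkpoint, text)
    # input_path occurs in train_input_reader first, then in eval_input_reader.
    records = iter([train_record, val_record])
    text = re.sub(r'input_path:\s*"[^"]*"',
                  lambda m: 'input_path: "%s"' % next(records), text, count=2)
    # Both readers use the same label map.
    text = re.sub(r'label_map_path:\s*"[^"]*"',
                  lambda m: 'label_map_path: "%s"' % label_map, text)
    return text


if __name__ == "__main__":
    base = "/data/models/research/object_detection/dataset"
    src = os.path.join(base, "ssd_mobilenet_v1_pets.config")
    if os.path.exists(src):
        with open(src) as f:
            text = f.read()
        text = customize_config(text, 10, base + "/model.ckpt",
                                base + "/pascal_train.record", base + "/pascal_val.record",
                                base + "/pascal_label_map.pbtxt")
        with open(os.path.join(base, "ssd_mobilenet_v1_pascal.config"), "w") as f:
            f.write(text)
```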
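5. Start training. (Here I switched to a different model, faster_rcnn_inception_resnet_v2_atrous; its config needs the same edits as in step 4, with fine_tune_checkpoint pointing to that model's checkpoint. The paths below also come from a different machine than those above, so adjust them to your own layout.)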
python train.py \
    --logtostderr \
    --train_dir=/home/amax/guo/models/object_detection/dataset/output \
    --pipeline_config_path=/home/amax/guo/models/object_detection/dataset/faster_rcnn_inception_resnet/faster_rcnn_inception_resnet_v2_atrous_pets.config
6. Evaluate the model
Create an evaluation folder inside the dataset folder, then run:
python eval.py \
    --logtostderr \
    --checkpoint_dir=/home/amax/guo/models/object_detection/dataset/output \
    --pipeline_config_path=/home/amax/guo/models/object_detection/dataset/faster_rcnn_inception_resnet/faster_rcnn_inception_resnet_v2_atrous_pets.config \
    --eval_dir=/home/amax/guo/models/object_detection/dataset/evaluation
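Error: ImportError: No module named nets
Cause and fix: the slim directory, which provides the nets package, is not on the Python path (the PYTHONPATH export from step 1 only applies to the shell it was run in). Either re-run that export, or add slim to sys.path at the top of eval.py: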
import sys
sys.path.append('/data/models/research/slim')
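The hard-coded path above only works on one machine (the training paths in this post already use a different layout). In both the research/ layout and the older models/ layout, slim is a sibling of the object_detection folder, so a sketch can derive it from eval.py's own location instead. It assumes the script sits directly inside object_detection/; the helper names are hypothetical.

```python
import os
import sys


def slim_dir_for(script_path):
    """Return the slim directory that sits next to the script's object_detection/ folder."""
    object_detection_dir = os.path.dirname(os.path.abspath(script_path))
    return os.path.join(os.path.dirname(object_detection_dir), "slim")


def add_slim_to_path(script_path):
    """Append the slim directory to sys.path once and return it."""
    slim = slim_dir_for(script_path)
    if slim not in sys.path:
        sys.path.append(slim)
    return slim


if __name__ == "__main__":
    # In eval.py itself, call add_slim_to_path(__file__) before the object_detection imports.
    print(add_slim_to_path(sys.argv[0]))
```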
7. View the results (TensorBoard serves on http://localhost:6006 by default)
tensorboard --logdir=/home/amax/guo/models/object_detection/dataset
8. Export a model that can be used for inference
python object_detection/export_inference_graph.py \
    --input_type image_tensor \
    --pipeline_config_path /home/amax/guo/models/object_detection/dataset/faster_rcnn_inception_resnet/faster_rcnn_inception_resnet_v2_atrous_pets.config \
    --trained_checkpoint_prefix /home/amax/guo/models/object_detection/dataset/output/model.ckpt-10000 \
    --output_directory /home/amax/guo/models/object_detection/dataset/savedModel
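--trained_checkpoint_prefix must name a checkpoint that still exists in the training directory, since older checkpoints may be deleted as training continues. With TensorFlow available, tf.train.latest_checkpoint(train_dir) returns the newest prefix recorded in the checkpoint file; the stdlib sketch below instead scans for model.ckpt-N.index files and picks the largest N (latest_checkpoint_prefix is a hypothetical helper).

```python
import glob
import os
import re


def latest_checkpoint_prefix(train_dir):
    """Return the model.ckpt-N prefix with the largest N, or None if there is none."""
    best_step, best_prefix = -1, None
    for index_file in glob.glob(os.path.join(train_dir, "model.ckpt-*.index")):
        match = re.search(r"model\.ckpt-(\d+)\.index$", index_file)
        if match and int(match.group(1)) > best_step:
            best_step = int(match.group(1))
            best_prefix = index_file[: -len(".index")]  # strip the suffix
    return best_prefix


if __name__ == "__main__":
    print(latest_checkpoint_prefix("/home/amax/guo/models/object_detection/dataset/output"))
```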
The exported model is shown in the figure:
9. Use the exported model
Modify the following settings in object_detection_tutorial.py:
PATH_TO_CKPT = '/home/amax/guo/models/object_detection/dataset/savedModel/frozen_inference_graph.pb'
PATH_TO_LABELS = '/home/amax/guo/models/object_detection/dataset/pascal_label_map.pbtxt'
NUM_CLASSES = 10
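NUM_CLASSES should equal the number of entries in the label map and the num_classes value in the config; a mismatch is an easy mistake when switching datasets. A quick check, using a regex-based sketch that handles simple item { id ... name ... } entries rather than the API's own label map parser (load_label_map is a hypothetical helper):

```python
import os
import re

NUM_CLASSES = 10


def load_label_map(path):
    """Parse simple label map items into {id: name} (not a full pbtxt parser)."""
    with open(path) as f:
        text = f.read()
    categories = {}
    for body in re.findall(r"item\s*\{([^}]*)\}", text):
        id_match = re.search(r"\bid:\s*(\d+)", body)
        # \bname: does not match display_name:, because '_' is a word character.
        name_match = re.search(r"\bname:\s*['\"]([^'\"]*)['\"]", body)
        if id_match and name_match:
            categories[int(id_match.group(1))] = name_match.group(1)
    return categories


if __name__ == "__main__":
    path = "/home/amax/guo/models/object_detection/dataset/pascal_label_map.pbtxt"
    if os.path.exists(path):
        categories = load_label_map(path)
        if len(categories) != NUM_CLASSES:
            print("label map has %d classes but NUM_CLASSES is %d"
                  % (len(categories), NUM_CLASSES))
```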
The results are as follows: