比如你在mnist的prototxt中定义图输入是单通道的,也就是channel=1,然后如果直接调用classify.py脚本来测试的话,是会报错,错误跟一下类似。
Source param shape is 128 3 32 32; target param shape is 128 1 32 32.
意思就是网络要求输入是1 channel,而你读入的数据是3 channels。
即使你再调用这个脚本之前,已经把图转换成灰度图了,也是不行。
那是因为caffe.io.load_image读入数据的时候,总是会把数据转成3 channels。
所以,我们需要换一种方式读入数据。
具体做法
- 找到classify.py中
inputs = [caffe.io.load_image(im_f) for im_f in glob.glob(args.input_file + '/*.' + args.ext)]
- 替换成
tmp = []
for _ in inputs:
img = skimage.img_as_float(skimage.io.imread(_)).astype(np.float32)
if len(img.shape) == 2:
# 设置channel为1
img = img.reshape(img.shape[0], img.shape[1], 1)
tmp.append(img)
inputs = tmp
这里是修改的测试集是从一个目录读入的,如果测试集是单独的一张图,修改方式也类似。