Basic image-handling APIs
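Read an image (the file is opened in binary mode and its raw bytes are decoded):

    from mxnet import image

    with open('../img/cat1.jpg', 'rb') as f:
        img = image.imdecode(f.read())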
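Convert the image's data type (`imdecode` returns a uint8 array in height × width × channel order; casting to float32 makes it safe for arithmetic):

    img = img.astype('float32')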
Image augmentation pipeline
The tutorial illustrates each individual augmentation method in detail, so they are not repeated here.
A helper function that applies a list of augmentation functions to a single image, in order:
    def apply_aug_list(img, augs):
        for f in augs:
            img = f(img)
        return img
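The helper is plain left-to-right function composition: each augmenter receives the previous one's output. A minimal pure-Python sketch follows, in which the toy augmenters `flip_h` and `add_one` are hypothetical stand-ins for MXNet augmenters acting on a small nested-list image, together with an equivalent `functools.reduce` formulation (`compose` is also an illustrative name, not part of MXNet):

```python
from functools import reduce

def apply_aug_list(img, augs):
    # Feed the image through each augmenter in list order.
    for f in augs:
        img = f(img)
    return img

def compose(augs):
    # Same behaviour expressed as a fold over the list of functions.
    return lambda img: reduce(lambda acc, f: f(acc), augs, img)

# Hypothetical toy augmenters on a single-channel image stored as nested lists.
def flip_h(img):
    # Reverse every row (left-right mirror); returns new lists.
    return [row[::-1] for row in img]

def add_one(img):
    # Add 1 to every pixel value; returns new lists.
    return [[v + 1 for v in row] for row in img]

img = [[1, 2, 3],
       [4, 5, 6]]
print(apply_aug_list(img, [flip_h, add_one]))  # [[4, 3, 2], [7, 6, 5]]
print(compose([flip_h, add_one])(img))         # [[4, 3, 2], [7, 6, 5]]
```

An empty augmenter list returns the image unchanged, which is why the same helper can serve both the training and the test pipeline.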
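For training images we apply a random horizontal flip and a random crop; for test images we only take a center crop. We assume the crops are 28×28×3 (height × width × channels) and are used as the network input: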
    train_augs = [
        image.HorizontalFlipAug(.5),
        image.RandomCropAug((28, 28)),
    ]
    test_augs = [
        image.CenterCropAug((28, 28)),
    ]
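To make the two pipelines concrete, here is a pure-Python sketch of what the three augmenters do to an image stored as nested lists in height × width × channel order. The functions `random_hflip`, `random_crop`, `center_crop`, `make_image` and the 32×32 input size are hypothetical stand-ins chosen for illustration, not MXNet's implementation; the sketch assumes the crop is never larger than the image.

```python
import random

# An image is a nested list in height x width x channel (HWC) order:
# img[y][x] is the list of channel values of the pixel at row y, column x.

def make_image(h, w):
    # Hypothetical 3-channel test image whose pixels record their position.
    return [[[y, x, 0] for x in range(w)] for y in range(h)]

def random_hflip(img, p, rng):
    # Mirror left-right with probability p (the role of HorizontalFlipAug(p)).
    if rng.random() < p:
        return [row[::-1] for row in img]
    return img

def random_crop(img, crop_h, crop_w, rng):
    # Random top-left corner chosen so the crop stays inside the image.
    h, w = len(img), len(img[0])
    y0 = rng.randint(0, h - crop_h)
    x0 = rng.randint(0, w - crop_w)
    return [row[x0:x0 + crop_w] for row in img[y0:y0 + crop_h]]

def center_crop(img, crop_h, crop_w):
    # Deterministic crop centered in the image, used for test data.
    h, w = len(img), len(img[0])
    y0 = (h - crop_h) // 2
    x0 = (w - crop_w) // 2
    return [row[x0:x0 + crop_w] for row in img[y0:y0 + crop_h]]

rng = random.Random(0)
img = make_image(32, 32)

train_augs = [lambda im: random_hflip(im, 0.5, rng),
              lambda im: random_crop(im, 28, 28, rng)]
test_augs = [lambda im: center_crop(im, 28, 28)]

for name, augs in [("train", train_augs), ("test", test_augs)]:
    out = img
    for f in augs:
        out = f(out)
    # Height, width, channels of the augmented image: 28 28 3.
    print(name, len(out), len(out[0]), len(out[0][0]))
```

Because the flip and crop position are random, each pass over the training set sees a slightly different version of every image, while the test pipeline always produces the same center crop.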
The following closure builds a transform that applies the augmentations to every sample in a batch:
    from mxnet import nd

    def get_transform(augs):
        def transform(data, label):
            # data: sample x height x width x channel
            # label: sample
            data = data.astype('float32')
            if augs is not None:
                # apply to each sample one-by-one and then stack
                data = nd.stack(*[
                    apply_aug_list(d, augs) for d in data])
            # reorder to sample x channel x height x width (NCHW)
            data = nd.transpose(data, (0, 3, 1, 2))
            return data, label.astype('float32')
        return transform
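The closure pattern and the `(0, 3, 1, 2)` transpose can be traced without MXNet. In this pure-Python sketch, nested lists stand in for NDArrays, `float()` stands in for `astype('float32')`, and a list comprehension stands in for `nd.stack`; `transpose_nhwc_to_nchw`, the toy batch, and the `flip` augmenter are hypothetical and written only for illustration.

```python
def apply_aug_list(img, augs):
    # Feed the image through each augmenter in list order.
    for f in augs:
        img = f(img)
    return img

def transpose_nhwc_to_nchw(batch):
    # out[n][c][y][x] = batch[n][y][x][c], i.e. axes (0, 3, 1, 2).
    out = []
    for sample in batch:
        h, w, c = len(sample), len(sample[0]), len(sample[0][0])
        out.append([[[sample[y][x][ch] for x in range(w)]
                     for y in range(h)]
                    for ch in range(c)])
    return out

def get_transform(augs):
    # The outer call captures `augs`; the returned closure uses it later.
    def transform(data, label):
        # Cast every value to float (stand-in for astype('float32')).
        data = [[[[float(v) for v in px] for px in row] for row in img]
                for img in data]
        if augs is not None:
            # Augment each sample independently, then re-collect the batch.
            data = [apply_aug_list(img, augs) for img in data]
        data = transpose_nhwc_to_nchw(data)
        return data, [float(v) for v in label]
    return transform

# 1 sample, 2 x 3 pixels, 3 channels; value = 100*channel + 10*row + column.
batch = [[[[100 * c + 10 * y + x for c in range(3)]
           for x in range(3)]
          for y in range(2)]]
flip = lambda im: [row[::-1] for row in im]  # hypothetical augmenter

data, label = get_transform([flip])(batch, [7])
print(len(data), len(data[0]), len(data[0][0]), len(data[0][0][0]))  # 1 3 2 3
print(data[0][1][0])  # [102.0, 101.0, 100.0]
print(label)          # [7.0]
```

`get_transform` is called once per pipeline, so `get_transform(train_augs)` and `get_transform(test_augs)` yield two independent functions with the same `(data, label)` signature, which is what a data loader's transform hook expects.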
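That is the basic logic: each batch is cast to float32, every sample is augmented on its own, and the batch is transposed from (sample, height, width, channel) to (sample, channel, height, width), the layout the network expects.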