将训练好的语义分割模型保存下来,重新加载之后,
通过下面这一步前向操作得到模型的输出:
output = self.model(image)
这里的 output 即为标签内容(如果 output 是各类别的得分而不是类别索引,需要先在类别维度上取 argmax,得到 (M, N) 的整数标签图),再通过下面的重新编码函数即可获得彩色图像.
def decode_segmap(label_mask, dataset, plot=False):
    """Decode segmentation class labels into a color image.

    Every pixel whose label is in ``range(n_classes)`` is painted with the
    matching row of the dataset palette. Pixels carrying any other value
    keep that raw value in all three channels (so e.g. a label of 255
    becomes white after the division by 255).

    Args:
        label_mask (np.ndarray): an (M, N) array of integer values denoting
            the class label at each spatial location.
        dataset (str): palette selector; 'pascal' and 'coco' use
            ``get_pascal_labels()`` with 21 classes, 'cityscapes' uses
            ``get_cityscapes_labels()`` with 19 classes.
        plot (bool, optional): whether to show the resulting color image
            in a figure instead of returning it.

    Returns:
        np.ndarray or None: the decoded (M, N, 3) float image, scaled by
        1/255, when ``plot`` is false; ``None`` when ``plot`` is true.

    Raises:
        NotImplementedError: if ``dataset`` is not one of the names above.
    """
    if dataset == 'pascal' or dataset == 'coco':
        n_classes = 21
        label_colours = get_pascal_labels()
    elif dataset == 'cityscapes':
        n_classes = 19
        label_colours = get_cityscapes_labels()
    else:
        raise NotImplementedError(
            'decode_segmap does not support dataset {!r}'.format(dataset))

    label_mask = np.asarray(label_mask)

    # Build the channels in float64 rather than in the mask's own dtype:
    # a narrow or boolean mask dtype would otherwise wrap or truncate
    # colour components (which go up to 255) when they are written in.
    channels = np.repeat(
        label_mask.astype(np.float64)[:, :, np.newaxis], 3, axis=2)

    for ll in range(n_classes):
        # Always compare against the untouched label_mask so that a colour
        # written for one class is never mistaken for a later class index.
        # NOTE(review): assumes the palette is a 2-D array with at least
        # n_classes rows of (R, G, B) - confirm against the palette helpers.
        channels[label_mask == ll] = label_colours[ll, :3]

    rgb = channels / 255.0

    if plot:
        plt.imshow(rgb)
        plt.show()
        return None
    return rgb
绘图的主函数在下面:
if __name__ == '__main__':
    # Visual smoke test: show the first two shuffled batches of the
    # training split, each image stacked above its colour-decoded label map.
    from dataloaders.utils import decode_segmap
    from torch.utils.data import DataLoader
    import matplotlib.pyplot as plt
    import argparse

    parser = argparse.ArgumentParser()
    args = parser.parse_args()
    args.base_size = 256
    args.crop_size = 256

    voc_train = VOCSegmentation(args, split='train')

    dataloader = DataLoader(voc_train, batch_size=5, shuffle=True, num_workers=0)

    # Constants used to undo the input normalisation before display.
    # NOTE(review): presumably these match the dataset's Normalize
    # transform - confirm against the transform pipeline.
    norm_std = np.array((0.229, 0.224, 0.225))
    norm_mean = np.array((0.485, 0.456, 0.406))

    for ii, sample in enumerate(dataloader):
        # Convert the whole batch once instead of once per image.
        images = sample['image'].numpy()
        labels = sample['label'].numpy()

        for img, gt in zip(images, labels):
            segmap = decode_segmap(
                np.array(gt).astype(np.uint8), dataset='pascal')

            # Channels-first -> channels-last for imshow, then de-normalise
            # out of place so the batch array is not modified.
            img_tmp = np.transpose(img, axes=[1, 2, 0])
            img_tmp = (img_tmp * norm_std + norm_mean) * 255.0
            # Clip so rounding error cannot wrap around in the uint8 cast.
            img_tmp = np.clip(img_tmp, 0, 255).astype(np.uint8)

            # A figure-level title: calling plt.title() before plt.subplot()
            # attaches the title to an implicit full-figure axes instead of
            # to the two panels.
            fig, (ax_img, ax_seg) = plt.subplots(2, 1)
            fig.suptitle('display')
            ax_img.imshow(img_tmp)
            ax_seg.imshow(segmap)

        if ii == 1:
            break

    plt.show(block=True)