1.png
2.png
在pytorch中使用torchvision的vutils函数实现对多张图片的拼接。具体操作就是将上面的两张图片,1.png和2.png的多张图片进行拼接形成一张图片,拼接后的效果如下图。
给出具体代码:
import matplotlib.pyplot as plt from PIL import Image import numpy as np import torch import torchvision.utils as vutils im1=Image.open("1.png").convert("RGB") im1 = im1.resize((1000, 1000)).rotate(-90) im2=Image.open("2.png").convert("RGB") im2 = im2.resize((1000, 1000)).rotate(-90) # 1000, 1000, 3 => 3, 1000, 1000 images = [np.moveaxis(np.array(im1), 2, 0), np.moveaxis(np.array(im2), 2, 0)]*8 images_tensor = vutils.make_grid(torch.tensor(images)/255.0, nrow=4, padding=0, normalize=True) print(images_tensor.shape) # 3, 1000, 1000 => 1000, 1000, 3 plt.imshow(images_tensor.numpy().transpose((1,2,0))) plt.show() vutils.save_image(images_tensor, "3.png") vutils.save_image(images_tensor, "3_back.png", nrow=2, padding=0, normalize=True) vutils.save_image(torch.tensor(images)/255.0, "4.png", nrow=8, padding=0, normalize=True)
=============================================
需要注意的地方:
- 1. 使用PIL读入的图片要转为RGB模式,然后要将图片对象转为numpy数组形式,在上面例子中转为数组后的单张图片维度为(1000,1000,3)。
- 2. 使用vutils.make_grid函数对图片进行拼接时,每张图片的数据类型都为torch.tensor,并且单张图片的格式应为(channel数,长,宽),上面例子中则是(3,1000,1000)。这样将16张图片拼接为每行4张图片的大图后,大图的维度为(3,4000,4000)。
- vutils.make_grid函数和vutils.save_image函数接受的pytorch.tensor的类型均为float,如果不能保证数据大小在0和1之间则需要设置正则项normalize=True 。
-------------------------------------------------------------