https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html
1 定义模型
跟一般模型定义并无区别。导出前需要调用torch_model.eval()(等价于torch_model.train(False))将模型切换为推理模式,因为dropout、batchnorm等层在训练模式和推理模式下的行为不同。
2 导出模型(torch.onnx.export())
(1)export会实际运行一遍模型(通过追踪计算过程来记录计算图),所以需要提供一个示例输入x。注意这里的x只用于导出,并非模型之后预测时的输入。
(2)x的取值可以任意(例如随机数),但形状(大小)和数据类型必须正确。
(3)如果不指定动态轴(dynamic_axes),导出模型的输入在各个维度上的大小都会固定为x的形状。下面的例子通过dynamic_axes把第0维声明为动态轴,所以[batch_size, 1, 224, 224]中的batch_size在推理时可以变化。
# Input to the model: a dummy tensor used only to trace the graph during
# export. Its values are arbitrary (random here); its shape and dtype must be
# what the model expects.
x = torch.randn(batch_size, 1, 224, 224, requires_grad=True)

# Reference result computed by PyTorch itself (without ONNX Runtime). It must
# not be commented out: the ONNX Runtime check later compares against it.
torch_out = torch_model(x)

# Export the model
torch.onnx.export(
    torch_model,  # model being run
    x,  # model input (or a tuple for multiple inputs)
    "super_resolution.onnx",  # where to save the model (can be a file or file-like object)
    export_params=True,  # store the trained parameter weights inside the model file
    opset_version=10,  # the ONNX version to export the model to
    do_constant_folding=True,  # whether to execute constant folding for optimization
    input_names=['input'],  # the model's input names
    output_names=['output'],  # the model's output names
    dynamic_axes={
        'input': {0: 'batch_size'},  # variable length axes
        'output': {0: 'batch_size'},
    },
)
3、加载、检测模型(onnx.load(),onnx.checker.check_model())
(1)onnx.load()加载后得到一个onnx.ModelProto对象,它是打包整个ML模型(计算图及参数)的顶层容器结构。onnx.checker.check_model()用于检查模型结构是否合法、是否符合ONNX规范。
import onnx

# Deserialize the exported file back into an onnx.ModelProto.
onnx_model = onnx.load("super_resolution.onnx")

# Validate the model's structure against the ONNX specification; raises if invalid.
onnx.checker.check_model(onnx_model)
4 运行、输入、输出模型(run())
(1)ONNX Runtime推理时的输入是一个字典,键为模型的输入名称,值为对应的NumPy数组:{'输入名称': to_numpy(x)}
# ONNX Runtime is fed by input *name*: look up the name of the model's first
# input, then pair it with the NumPy version of x.
input_name = ort_session.get_inputs()[0].name
ort_inputs = {input_name: to_numpy(x)}
(2)为了用ONNX Runtime运行模型,需要先创建一个推理会话(InferenceSession),会话可以通过参数(如providers)进行配置,例如:
# Create an inference session on the exported model. The execution provider is
# named explicitly: since ONNX Runtime 1.9, builds that ship more than one
# provider (e.g. onnxruntime-gpu) raise an error when `providers` is omitted.
ort_session = onnxruntime.InferenceSession(
    "super_resolution.onnx", providers=["CPUExecutionProvider"]
)
(3)run()返回一个list,依次包含ONNX Runtime计算出的各个模型输出。run()的第一个参数传None,表示返回全部输出。
import numpy as np  # needed for np.testing below; previously missing from this snippet
import onnxruntime

# Create the inference session. The execution provider is named explicitly:
# since ONNX Runtime 1.9, builds that ship more than one provider (e.g.
# onnxruntime-gpu) raise an error when `providers` is omitted.
ort_session = onnxruntime.InferenceSession(
    "super_resolution.onnx", providers=["CPUExecutionProvider"]
)


def to_numpy(tensor):
    """Return *tensor* as a NumPy array on the CPU.

    ``detach()`` is applied when the tensor requires grad, because
    ``Tensor.numpy()`` refuses tensors that are part of an autograd graph.
    """
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()


# compute ONNX Runtime output prediction; the model's first (and here only)
# input is fed by name
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
ort_outs = ort_session.run(None, ort_inputs)  # None -> return every output

# compare ONNX Runtime and PyTorch results
# (x and torch_out are the dummy input and PyTorch reference from the export step)
np.testing.assert_allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05)
print("Exported model has been tested with ONNXRuntime, and the result looks good!")
5 一个完整的案例:
模型:First, let’s create a SuperResolution model in PyTorch. This model uses the efficient sub-pixel convolution layer described in “Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network” - Shi et al for increasing the resolution of an image by an upscale factor. The model expects the Y component of the YCbCr of an image as an input, and outputs the upscaled Y component in super resolution.
import io
import numpy as np
from torch import nn
import torch.utils.model_zoo as model_zoo
import torch.onnx
import torch.nn.init as init
from PIL import Image
import torchvision.transforms as transforms
import onnxruntime
import onnx


# First half: define and export the model.
class SuperResolutionNet(nn.Module):
    """Sub-pixel convolution super-resolution network (Shi et al.).

    Four convolutions map a single-channel input to ``upscale_factor ** 2``
    channels. ``nn.PixelShuffle`` then rearranges those channels into one
    channel whose height and width are ``upscale_factor`` times larger.

    Args:
        upscale_factor: integer factor applied to height and width.
        inplace: forwarded to ``nn.ReLU``.
    """

    def __init__(self, upscale_factor, inplace=False):
        super().__init__()
        # Keep this construction order and these attribute names: they fix the
        # state_dict keys (conv1.weight, ...) expected by the pretrained
        # checkpoint, and the order in which layer initialisation draws random numbers.
        self.relu = nn.ReLU(inplace=inplace)
        self.conv1 = nn.Conv2d(1, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
        self.conv2 = nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.conv3 = nn.Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        # upscale_factor ** 2 output channels are folded into space by PixelShuffle.
        self.conv4 = nn.Conv2d(
            32, upscale_factor ** 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
        )
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)
        self._initialize_weights()

    def forward(self, x):
        """Apply three conv+ReLU stages, then conv4 followed by pixel shuffle."""
        for conv in (self.conv1, self.conv2, self.conv3):
            x = self.relu(conv(x))
        return self.pixel_shuffle(self.conv4(x))

    def _initialize_weights(self):
        """Orthogonally initialise the conv weights.

        The ReLU gain is applied to the three layers followed by ReLU;
        conv4 (no activation) uses the default gain.
        """
        relu_gain = init.calculate_gain('relu')
        for conv in (self.conv1, self.conv2, self.conv3):
            init.orthogonal_(conv.weight, relu_gain)
        init.orthogonal_(self.conv4.weight)


# Create the super-resolution model by using the above model definition.
torch_model = SuperResolutionNet(upscale_factor=3)

# Load pretrained model weights
model_url = 'https://s3.amazonaws.com/pytorch/test_data/export/superres_epoch100-44c6958e.pth'
batch_size = 1  # just a random number

# Initialize model with the pretrained weights.
# Without CUDA, remap every stored tensor onto the CPU; with CUDA available,
# map_location=None keeps the checkpoint's own device placement.
map_location = lambda storage, loc: storage
if torch.cuda.is_available():
    map_location = None
torch_model.load_state_dict(model_zoo.load_url(model_url, map_location=map_location))

# set the model to inference mode
torch_model.eval()

# Input to the model: dummy tensor used only to trace the graph during export.
x = torch.randn(batch_size, 1, 224, 224, requires_grad=True)
# PyTorch reference result, used below to validate the exported model.
torch_out = torch_model(x)

# Export the model
torch.onnx.export(
    torch_model,  # model being run
    x,  # model input (or a tuple for multiple inputs)
    "super_resolution.onnx",  # where to save the model (can be a file or file-like object)
    export_params=True,  # store the trained parameter weights inside the model file
    opset_version=10,  # the ONNX version to export the model to
    do_constant_folding=True,  # whether to execute constant folding for optimization
    input_names=['input'],  # the model's input names
    output_names=['output'],  # the model's output names
    dynamic_axes={
        'input': {0: 'batch_size'},  # variable length axes
        'output': {0: 'batch_size'},
    },
)

# Second half: load the exported model and run it with ONNX Runtime.
onnx_model = onnx.load("super_resolution.onnx")
onnx.checker.check_model(onnx_model)

# The provider is named explicitly: since ONNX Runtime 1.9, builds with more than
# one execution provider (e.g. onnxruntime-gpu) raise if `providers` is omitted.
ort_session = onnxruntime.InferenceSession(
    "super_resolution.onnx", providers=["CPUExecutionProvider"]
)


def to_numpy(tensor):
    """Return *tensor* as a CPU NumPy array, detaching it first if it requires grad."""
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()


# Validate the export: ONNX Runtime must reproduce the PyTorch output for x.
# (torch_out was previously computed but never checked.)
check_outs = ort_session.run(None, {ort_session.get_inputs()[0].name: to_numpy(x)})
np.testing.assert_allclose(to_numpy(torch_out), check_outs[0], rtol=1e-03, atol=1e-05)

# Preprocess the real image. The context manager closes the file handle that
# Image.open leaves open; the resized copy is an independent image.
with Image.open("cat.jpg") as img:
    resize = transforms.Resize([224, 224])
    img = resize(img)
img_ycbcr = img.convert('YCbCr')
# The model only super-resolves the Y (luma) channel; Cb/Cr are upscaled separately below.
img_y, img_cb, img_cr = img_ycbcr.split()
to_tensor = transforms.ToTensor()
img_y = to_tensor(img_y)
img_y.unsqueeze_(0)  # add the batch dimension

# Input
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(img_y)}
# Run
ort_outs = ort_session.run(None, ort_inputs)
# Post-process the output, following the post-processing step of the PyTorch
# implementation: scale to [0, 255], clip, and drop the batch/channel axes.
# A 2-D uint8 array is inferred as mode 'L' (8-bit grayscale), so the
# deprecated `mode` argument is not needed.
img_out_y = ort_outs[0]
img_out_y = Image.fromarray(np.uint8((img_out_y[0] * 255.0).clip(0, 255)[0]))

# Merge the super-resolved Y with bicubically upscaled Cb/Cr into the output image.
final_img = Image.merge(
    "YCbCr",
    [
        img_out_y,
        img_cb.resize(img_out_y.size, Image.BICUBIC),
        img_cr.resize(img_out_y.size, Image.BICUBIC),
    ],
).convert("RGB")

# Save the image, we will compare this with the output image from mobile device
final_img.save("cat_superres_with_ort.jpg")