• Things to watch out for when reading a single image in PyTorch for classification prediction (OpenCV, PIL)


    Images are usually read with one of two libraries: OpenCV or PIL.

    1. Reading an image with OpenCV

    import cv2
    image=cv2.imread("/content/drive/My Drive/colab notebooks/image/cat1.jpg")
    print(image.shape)

    (490, 410, 3)

    2. Reading an image with PIL

    import PIL
    image=PIL.Image.open("/content/drive/My Drive/colab notebooks/image/cat1.jpg")
    print(image.shape)

    This raises an error, because a PIL image object has no shape attribute:

    AttributeError                            Traceback (most recent call last)
    
    <ipython-input-30-807ec7af434b> in <module>()
          1 import PIL
          2 image=PIL.Image.open("/content/drive/My Drive/colab notebooks/image/cat1.jpg")
    ----> 3 print(image.shape)
    
    AttributeError: 'JpegImageFile' object has no attribute 'shape'

    To print the shape, convert the image to a NumPy array first:

    import numpy as np
    print(np.array(image).shape)

    (490, 410, 3)
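A point worth keeping in mind when moving between the two libraries: the NumPy shape is (height, width, channels), so (490, 410, 3) means 490 pixels high and 410 pixels wide, whereas a PIL image's size attribute reports (width, height). cv2.resize likewise takes its target size as (width, height). The sketch below assumes Pillow and NumPy are installed and uses a blank image created with Image.new in place of cat1.jpg, sized to match the shape above.

```python
import numpy as np
from PIL import Image

# Blank stand-in for cat1.jpg: Image.new takes (width, height)
image = Image.new("RGB", (410, 490))

print(image.size)        # (410, 490) -> PIL reports (width, height)
arr = np.array(image)
print(arr.shape)         # (490, 410, 3) -> NumPy reports (height, width, channels)

height, width, channels = arr.shape
```

Unpacking the shape into named variables, as in the last line, makes it harder to mix up the two orders later.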

    Note:

    An image read with OpenCV is a NumPy array in BGR channel order, whereas PIL returns an Image object whose pixels are in RGB order.
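For a 3-channel image the difference is only the order of the last axis, so BGR and RGB are mirror images of each other along that axis. A minimal NumPy-only sketch, using a hypothetical single pure-red pixel rather than a real file; reversing the last axis is used here as a stand-in for cv2.cvtColor(img, cv2.COLOR_BGR2RGB), which has the same effect on 3-channel images.

```python
import numpy as np

# A single pure-red pixel as OpenCV stores it: channel order (B, G, R)
bgr = np.array([[[0, 0, 255]]], dtype=np.uint8)  # shape (1, 1, 3)

# Reversing the last axis gives RGB order, which is what PIL uses
rgb = bgr[:, :, ::-1]

print("BGR pixel:", bgr[0, 0].tolist())  # [0, 0, 255]
print("RGB pixel:", rgb[0, 0].tolist())  # [255, 0, 0]
```

If the BGR array were passed unchanged to code that assumes RGB, this pixel would be read as (0, 0, 255), i.e. pure blue. Red and blue are silently swapped, which can degrade the predictions of a model trained on RGB images.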

    3. Converting between OpenCV format and PIL format

    Reference: https://www.cnblogs.com/enumx/p/12359850.html

    (1) Converting OpenCV format to PIL format

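    import cv2
    from PIL import Image

    img = cv2.imread("plane.jpg")        # NumPy array, BGR channel order
    cv2.imshow("OpenCV", img)
    # Convert BGR -> RGB, then wrap the array as a PIL image
    image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    image.show()
    cv2.waitKey()                        # keep the OpenCV window open until a key is pressed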

    (2) Converting PIL format to OpenCV format

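    import cv2
    import numpy
    from PIL import Image

    image = Image.open("plane.jpg")      # PIL image, RGB channel order
    image.show()
    # Convert to a NumPy array, then RGB -> BGR for OpenCV
    img = cv2.cvtColor(numpy.asarray(image), cv2.COLOR_RGB2BGR)
    cv2.imshow("OpenCV", img)
    cv2.waitKey()                        # keep the OpenCV window open until a key is pressed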
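Both snippets above open display windows, which may not work in a headless environment such as Colab. The sketch below isolates just the array conversions as two hypothetical helpers, bgr_array_to_pil and pil_to_bgr_array. To stay OpenCV-free it swaps channels by reversing the last axis instead of calling cv2.cvtColor (equivalent for 3-channel images), and it uses a small synthetic array in place of plane.jpg. It assumes Pillow and NumPy are installed.

```python
import numpy as np
from PIL import Image


def bgr_array_to_pil(img_bgr: np.ndarray) -> Image.Image:
    """OpenCV-style BGR uint8 array (H, W, 3) -> RGB PIL image."""
    # Reverse channels BGR -> RGB; make the view contiguous before handing it to PIL
    return Image.fromarray(np.ascontiguousarray(img_bgr[:, :, ::-1]))


def pil_to_bgr_array(image: Image.Image) -> np.ndarray:
    """PIL image -> OpenCV-style BGR uint8 array (H, W, 3)."""
    # convert("RGB") also normalizes palette, RGBA or grayscale images to 3 channels
    rgb = np.asarray(image.convert("RGB"))
    return np.ascontiguousarray(rgb[:, :, ::-1])


# Synthetic 4x5 BGR image standing in for plane.jpg
img_bgr = np.arange(4 * 5 * 3, dtype=np.uint8).reshape(4, 5, 3)
pil_img = bgr_array_to_pil(img_bgr)
back = pil_to_bgr_array(pil_img)
print(pil_img.size, back.shape)  # (5, 4) (4, 5, 3)
```

The round trip returns the original array unchanged, which is a quick sanity check that the two directions are inverses.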

    4. Reading one image with PyTorch and running a classification prediction

    Two points need attention:

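    • The input must be shaped as [1, channel, H, W], i.e. a batch containing a single image.
    • When data augmentation (torchvision transforms) is applied to the input image, the image must be in PIL.Image format.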
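The first point is about memory layout: an image is loaded as H×W×C, but the network expects a batch in N×C×H×W order. The sketch below mirrors the tensor steps of the prediction code (permute(2, 0, 1), division by 255, adding a batch dimension) using NumPy so it runs without PyTorch. to_nchw_batch is a hypothetical helper name, and the input is assumed to be a uint8 RGB image already resized to 224×224.

```python
import numpy as np


def to_nchw_batch(image_hwc: np.ndarray) -> np.ndarray:
    """uint8 RGB image (H, W, C) -> float32 batch (1, C, H, W) scaled to [0, 1]."""
    chw = np.transpose(image_hwc, (2, 0, 1))   # like tensor.permute(2, 0, 1)
    chw = chw.astype(np.float32) / 255.0       # like .float() / 255.0
    return chw[np.newaxis, ...]                # add batch dim, like tensor.unsqueeze(0)


# Hypothetical 224x224 RGB image filled with mid-grey
image = np.full((224, 224, 3), 128, dtype=np.uint8)
batch = to_nchw_batch(image)
print(batch.shape, batch.dtype)  # (1, 3, 224, 224) float32
```

In PyTorch the same batch dimension is typically added with tensor.unsqueeze(0), which states the intent more clearly than a hard-coded reshape((1, 3, 224, 224)).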

    import sys
    import glob
    import cv2
    import numpy as np
    import torch
    import torch.nn as nn
    import torchvision
    import torchvision.transforms as transforms
    from PIL import Image

    sys.path.append("/content/drive/My Drive/colab notebooks")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Rebuild the architecture used in training: ResNet-18 with a 4-class output layer
    model = torchvision.models.resnet18(pretrained=False)
    model.fc = nn.Linear(model.fc.in_features, 4, bias=False)
    model.to(device)
    model.eval()  # inference mode: BatchNorm uses its running statistics

    save_path = "/content/drive/My Drive/colab notebooks/checkpoint/resnet18_best_v2.t7"
    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine
    checkpoint = torch.load(save_path, map_location=device)
    model.load_state_dict(checkpoint['model'])
    print("Current model accuracy:", checkpoint["epoch_acc"])

    images_path = "/content/drive/My Drive/colab notebooks/data/dataset/test/four"
    # torchvision transforms such as Resize take a PIL.Image as input
    transform = transforms.Compose([transforms.Resize((224, 224))])

    def predict():
        true_labels = []
        output_labels = []
        for image_file in glob.glob(images_path + "/*.png"):
            print(image_file)
            true_labels.append(0)  # every image in this folder has true label 0
            # PIL alternative: image = transform(Image.open(image_file).convert("RGB"))
            image = cv2.imread(image_file)              # BGR array, shape (H, W, 3)
            image = cv2.resize(image, (224, 224))
            image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))  # RGB PIL.Image
            # (H, W, C) uint8 -> (C, H, W) float in [0, 1]
            tensor = torch.from_numpy(np.array(image)).permute(2, 0, 1).float() / 255.0
            tensor = tensor.unsqueeze(0).to(device)     # add batch dim: [1, 3, 224, 224]
            with torch.no_grad():                       # no gradients needed for inference
                output = model(tensor)
            print(output)
            _, pred = torch.max(output, 1)              # index of the highest score
            output_labels.append(pred.item())
        return true_labels, output_labels

    true_labels, output_labels = predict()
    print("True labels:")
    print(true_labels)
    print("Predicted labels:")
    print(output_labels)
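Since predict() returns the true and predicted labels as two parallel lists, a natural follow-up is to count how many of them match. A minimal plain-Python sketch; the accuracy helper and the sample lists are hypothetical and are not output from the model above.

```python
def accuracy(true_labels, output_labels):
    """Fraction of positions where the predicted label equals the true label."""
    if len(true_labels) != len(output_labels):
        raise ValueError("label lists must have the same length")
    if not true_labels:
        return 0.0
    correct = sum(t == p for t, p in zip(true_labels, output_labels))
    return correct / len(true_labels)


# Hypothetical labels, for illustration only
true_labels = [0, 0, 0, 0]
output_labels = [0, 2, 0, 0]
print(f"accuracy: {accuracy(true_labels, output_labels):.2%}")  # accuracy: 75.00%
```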
  • Original post: https://www.cnblogs.com/xiximayou/p/13390166.html