• Saliency map实现


    import PIL, torch, torchvision
    import matplotlib.pyplot as plt
    import sys
    import pandas as pd
    
    def normalize(image):
        """Min-max rescale *image* so its values span [0, 1].

        Works on any object supporting ``min()``, ``max()`` and arithmetic
        (torch tensors, numpy arrays). A constant input has no range to
        stretch: rather than computing 0 / 0 (NaN everywhere), it is mapped
        to all zeros.
        """
        low = image.min()
        span = image.max() - low
        shifted = image - low
        if span == 0:
            # Constant input: avoid division by zero.
            return shifted
        return shifted / span
    
    
    def show_saliency_map(img_path, model, size=100, cmap=plt.cm.hot,
                          class_csv='./1000class_dict.csv'):
        """Plot an image next to two vanilla-gradient saliency maps.

        Follows the image-specific class saliency map of Simonyan et al.
        (arXiv:1312.6034): back-propagate the top-scoring class score to the
        input pixels and visualise the magnitude of that gradient.

        Args:
            img_path: path of the image file to explain.
            model: classifier called on a (1, 3, size, size) float tensor;
                output[0, k] is taken as the score of class k. The model is
                switched to eval mode.
            size: side length the image is resized to before inference.
            cmap: matplotlib colormap for the single-channel (gray) map.
            class_csv: CSV of class names; column index 2 is printed as the
                Chinese name and column index 3 as the English name, with row
                position equal to the model's output index.

        Side effects: prints the predicted class names and shows a matplotlib
        figure with the resized image, the RGB map and the gray map.
        """
        # Explicit submodule import: a bare `import PIL` does not load PIL.Image.
        from PIL import Image

        model.eval()

        to_tensor = torchvision.transforms.Compose(
            [torchvision.transforms.Resize((size, size)),
             torchvision.transforms.ToTensor()])
        resize = torchvision.transforms.Resize((size, size))
        to_pil = torchvision.transforms.ToPILImage()

        # The context manager closes the file handle; convert() returns a fully
        # loaded copy, so `img` stays usable after the file is closed.
        with Image.open(img_path) as raw:
            img = raw.convert("RGB")

        # Leaf input tensor of shape (1, 3, size, size); we want d(score)/d(it).
        # NOTE(review): torchvision's pretrained ImageNet models expect
        # mean/std-normalised input, but raw [0, 1] pixels are fed here —
        # confirm this is intended for the model being explained.
        timg = to_tensor(img).unsqueeze(0)
        timg.requires_grad_(True)

        output = model(timg)
        class_idx = output.argmax(dim=1).item()

        names = pd.read_csv(class_csv)
        name_zh = names.iloc[class_idx, 2]
        name_en = names.iloc[class_idx, 3]
        print(name_zh, name_en)

        # Gradient of the winning class score w.r.t. the input pixels.
        # autograd.grad returns only this gradient and, unlike backward(),
        # does not accumulate .grad into the model's parameters.
        score = output[0, class_idx]
        (grad,) = torch.autograd.grad(score, timg)
        grad = grad[0].abs().cpu()  # (3, size, size)

        # Saliency map as defined in the paper: max |gradient| over channels.
        # (The paper's lambda*||I||^2 term belongs to its class-visualisation
        # objective, not to the saliency map; a scalar offset would be
        # cancelled by imshow's autoscaling anyway.)
        saliency_map_gray = grad.max(dim=0).values.numpy()

        # Normalise each colour channel independently so each fills [0, 1].
        saliency_map_rgb = torch.stack([normalize(channel) for channel in grad])

        fig, ax = plt.subplots(1, 3)
        ax[0].imshow(resize(img))
        ax[0].set_title(name_en)
        ax[1].imshow(to_pil(saliency_map_rgb))
        ax[1].set_title('RGB map')
        ax[2].imshow(saliency_map_gray, cmap=cmap)
        ax[2].set_title('gray map')
        plt.show()
    
    # Demo entry point: runs only when this file is executed as a script, so
    # importing it does not download weights or open a plot window.
    if __name__ == "__main__":
        img = './panda.png'
        # NOTE(review): `pretrained=True` is deprecated in torchvision >= 0.13
        # in favour of `weights=...`; kept here for older-version compatibility.
        model = torchvision.models.resnet18(pretrained=True)
        show_saliency_map(img, model, size=224)

    参考：Simonyan K., Vedaldi A., Zisserman A. Deep Inside Convolutional Networks: Visualising Image Classification Models and Saliency Maps. arXiv:1312.6034, 2013. https://arxiv.org/abs/1312.6034

  • 相关阅读:
    ansible 通过跳板机发布服务
    etcd API 神坑
    golang 条件编译
    服务治理
    golang web服务器处理前端HTTP请求跨域的方法
    rebar使用
    Apache/Tomcat/JBoss/Jetty/Nginx/WebLogic/WebSphere之间的区别
    运维专家写给年轻运维的6条人生忠告
    谷歌浏览器书签索引—知识的海洋都在里面
    关于认识、格局、多维度发展的感触
  • 原文地址:https://www.cnblogs.com/mydrizzle/p/13977924.html
Copyright © 2020-2023  润新知