• PyTorch: extracting intermediate-layer features


    Define a feature-extraction class:

    Reference (PyTorch forums): How to extract features of an image from a trained model

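    import torch
    import torch.nn as nn
    from torchvision.models import resnet18, ResNet18_Weights
    
    # ImageNet-pretrained ResNet-18 (torchvision >= 0.13; on older versions use
    # resnet18(pretrained=True)). eval() makes BatchNorm use its running statistics.
    myresnet = resnet18(weights=ResNet18_Weights.DEFAULT).eval()
    print(myresnet)
    
    class FeatureExtractor(nn.Module):
        """Runs the top-level children of `submodule` in registration order and
        collects the outputs of the layers named in `extracted_layers`."""
        def __init__(self, submodule, extracted_layers):
            super().__init__()
            self.submodule = submodule
            self.extracted_layers = extracted_layers
    
        def forward(self, x):
            outputs = []
            for name, module in self.submodule.named_children():
                if name == "fc":
                    x = x.flatten(1)  # ResNet flattens (N, 512, 1, 1) -> (N, 512) before fc
                x = module(x)  # the previous layer's output is this layer's input
                if name in self.extracted_layers:
                    outputs.append(x)
            return outputs
    
    extract_list = ["conv1", "layer1", "avgpool"]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    myextractor = FeatureExtractor(myresnet, extract_list).to(device)
    
    # Variable has been deprecated since PyTorch 0.4; plain tensors track gradients
    x = torch.rand(5, 3, 224, 224, device=device, requires_grad=True)
    
    y = myextractor(x)  # output shapes: 5x64x112x112, 5x64x56x56, 5x512x1x1
    print(myextractor)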
    
    print(type(y))
    print(type(y[0]))
    for feat in y:
        arr = feat.detach().cpu().numpy()  # detach from the autograd graph, move to CPU
        print(arr.size)   # total number of elements
        print(arr.shape)
    
    # Output (PyTorch >= 0.4, where Variable and Tensor are merged):
    # <class 'list'>
    # <class 'torch.Tensor'>
    # 4014080
    # (5, 64, 112, 112)
    # 1003520
    # (5, 64, 56, 56)
    # 2560
    # (5, 512, 1, 1)
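    # Visualize the features: the 64 channels of the conv1 output for the first sample
    import matplotlib.pyplot as plt
    conv1_maps = y[0][0].detach().cpu().numpy()  # shape (64, 112, 112)
    plt.figure(figsize=(12, 12))
    for i in range(64):
        ax = plt.subplot(8, 8, i + 1)
        ax.set_title('Channel #{}'.format(i))
        ax.axis('off')
        ax.imshow(conv1_maps[i], cmap='jet')
    plt.show()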
  • Original post: https://www.cnblogs.com/ranjiewen/p/9242223.html
Copyright © 2020-2023  润新知