• PyTorch 剪枝


    pytorch 实现剪枝的思路是 生成一个掩码,然后同时保存 原参数、mask、新参数,如下图

    pytorch 剪枝分为 局部剪枝、全局剪枝、自定义剪枝;

    局部剪枝 是对 模型内 的部分模块 的 部分参数 进行剪枝,全局剪枝是对  整个模型进行剪枝;

    本文旨在记录 pytorch 剪枝模块的用法,首先让我们构建一个模型

    import torch
    from torch import nn
    import torch.nn.utils.prune as prune
    import torch.nn.functional as F
    
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    class LeNet(nn.Module):
        def __init__(self):
            super(LeNet, self).__init__()
            # 1 input image channel, 6 output channels, 3x3 square conv kernel
            self.conv1 = nn.Conv2d(1, 6, 3)
            self.conv2 = nn.Conv2d(6, 16, 3)
            self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 5x5 image dimension
            self.fc2 = nn.Linear(120, 84)
            self.fc3 = nn.Linear(84, 10)
    
        def forward(self, x):
            x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
            x = F.max_pool2d(F.relu(self.conv2(x)), 2)
            x = x.view(-1, int(x.nelement() / x.shape[0]))
            x = F.relu(self.fc1(x))
            x = F.relu(self.fc2(x))
            x = self.fc3(x)
            return x
    
    model = LeNet().to(device=device)

    下面对 这个模型进行剪枝

    局部剪枝

    以修剪 第一层卷积  模块 为例

    module = model.conv1
    print(list(module.named_parameters()))
    print(list(module.buffers()))
    
    # 修剪是从 模块 中 删除 参数(如 weight),并用 weight_orig 保存该参数
    # random_unstructured 是一种裁剪技术,随机非结构化裁剪
    prune.random_unstructured(module, name="weight", amount=0.3)      # weight    bias
    print(list(module.named_parameters()))
    
    # 通过修剪技术会创建一个mask命名为 weight_mask 的模块缓冲区
    print(list(module.named_buffers()))
    
    # 新的参数保存为模块 的weight属性
    print(module.weight)
    # print(module.bias)
    
    print(module._forward_pre_hooks)
    # OrderedDict([(0, <torch.nn.utils.prune.RandomUnstructured object at 0x000002695EBCEC18>)])

    named_parameters() 内 存储的对象 除非手动删除,否则在剪枝过程中对其无影响

    迭代剪枝

    迭代剪枝 是 对 同一模块 进行 多种剪枝,执行逻辑是 顺序执行各剪枝操作

    在之前  随机非结构化剪枝 的基础上进行  L1 L2 非结构化剪枝

    ## 增加一个修剪,看看变化
    # l1范数修剪bias中3个最小条目
    prune.l1_unstructured(module, name="bias", amount=3)
    print(module.bias)
    print(module._forward_pre_hooks)
    # OrderedDict([(0, <torch.nn.utils.prune.RandomUnstructured object at 0x000002695EBCEC18>),
    #              (1, <torch.nn.utils.prune.L1Unstructured object at 0x000002695DE5CEB8>)])
    
    print(list(module.named_parameters()))
    print(list(module.named_buffers()))
    
    
    ### 迭代修剪
    # 一个模块中的同一参数可以被多次修剪,多次修剪会顺序执行
    # 如在之前的基础上,对 weight 参数继续修剪
    # l2 结构化裁剪,n=2代表l2,dim=0代表在weight的第0轴进行结构化裁剪
    prune.ln_structured(module, name="weight", amount=0.5, n=2, dim=0)
    
    # 查看 weight 参数的 剪枝 操作
    for hook in module._forward_pre_hooks.values():
        if hook._tensor_name == "weight":  # select out the correct hook
            break
    
    print(list(hook))
    # [<torch.nn.utils.prune.RandomUnstructured object at 0x0000020AE2A6EC18>,
    # <torch.nn.utils.prune.LnStructured object at 0x0000020AA872DE80>]
    
    print(module.state_dict().keys())
    # odict_keys(['weight_orig', 'bias_orig', 'weight_mask', 'bias_mask'])

    修剪模型中的多个参数

    ### 修剪模型中的多个参数
    new_model = LeNet()
    for name, module in new_model.named_modules():
        # prune 20% of connections in all 2D-conv layers
        if isinstance(module, torch.nn.Conv2d):
            prune.l1_unstructured(module, name='weight', amount=0.2)
        # prune 40% of connections in all linear layers
        elif isinstance(module, torch.nn.Linear):
            prune.l1_unstructured(module, name='weight', amount=0.4)
    
    print(dict(new_model.named_buffers()).keys())  # to verify that all masks exist

    全局剪枝

    以上研究通常被称为“局部”修剪方法,即通过比较每个条目的统计信息(权重,激活度,梯度等)来逐一修剪模型中的张量的做法。

    但是,一种常见且可能更强大的技术是通过删除整个模型中最低的 20%的连接,

    而不是删除每一层中最低的 20%的连接来修剪模型。

    这很可能导致每个层的修剪百分比不同。

    让我们看看如何使用torch.nn.utils.prune中的global_unstructured进行操作

    model = LeNet()
    
    parameters_to_prune = (
        (model.conv1, 'weight'),
        (model.conv2, 'weight'),
        (model.fc1, 'weight'),
        (model.fc2, 'weight'),
        (model.fc3, 'weight'),
    )
    
    prune.global_unstructured(
        parameters_to_prune,
        pruning_method=prune.L1Unstructured,
        amount=0.2,
    )
    # 检查每个修剪参数的稀疏性,该稀疏性不等于每层中的 20%。 但是,全局稀疏度将(大约)为 20%

    自定义剪枝

    见  参考资料3

    训练中剪枝实例

    见参考资料1

    参考资料:

    https://blog.csdn.net/qq_40268672/article/details/108631518  pytorch剪枝实战     训练时剪枝,类似 dropout 

    https://blog.csdn.net/ssunshining/article/details/125121066  PyTorch--模型剪枝案例

    https://www.w3cschool.cn/pytorch/pytorch-rnmi3bti.html  PyTorch 修剪教程

    https://www.bilibili.com/video/BV147411W7am?spm_id_from=333.337.search-card.all.click&vd_source=f0fc90583fffcc40abb645ed9d20da32  神经网络剪枝 Neural Network Pruning   自定义的剪枝

    https://github.com/mepeichun/Efficient-Neural-Network-Bilibili/tree/master/2-Pruning  上面视频的 代码  已下载

  • 相关阅读:
    读取 classes下的配置文件
    java中Class.getResource用法(用于配置文件的读取)
    windows 中 到底是用的哪个java.exe??? 删除了PATH变量的Java设置还是可以运行java.exe windows/system32
    mysql中null与“空值”的坑
    innodb架构理解
    mysql5.7性能提升一百倍调优宝典
    servlet 3.0笔记之servlet的动态注册
    前端性能优化建议
    了解CSRF攻击原理和预防
    vue的热更新配置
  • 原文地址:https://www.cnblogs.com/yanshw/p/16592678.html
Copyright © 2020-2023  润新知