Pytorch 剪枝操作实现
首先需要版本为 1.4 以上,
目前很多模型都取得了十分好的结果, 但是还是参数太多, 占得权重太大, 所以我们的目标是得到一个稀疏的子系数矩阵.
这个例子是基于 LeNet 的 Pytorch 实现的例子, 我们从 CNN 的角度来剪枝, 其实在全连接层与 RNN 的剪枝应该是类似, 首先导入一些必要的模块
import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F
然后是 LeNet 的网络结构, 不知道为什么这里的网络结构是这样的, 算出来输入的图像是 26x26 的,
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class LeNet(nn.Module):
def __init__(self):
super(LeNet, self).__init__()
# 1 input image channel, 6 output channels, 3x3 square conv kernel
self.conv1 = nn.Conv2d(1, 6, 3)
# 第一个卷积层, 输出的向量维度是 6
self.conv2 = nn.Conv2d(6, 16, 3)
# 第二个卷积层, 输出的向量维度是 16
self.fc1 = nn.Linear(16 * 5 * 5, 120) # 5x5 image dimension
# 最后将二维向量变成一维
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
# 2*2 的池化层
x = F.max_pool2d(F.relu(self.conv2(x)), 2)
# relu 激活函数层
x = x.view(-1, int(x.nelement() / x.shape[0]))
# 除以 batch_size 的大小, 将维度变成一
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
model = LeNet().to(device=device)
这时查看模型的参数:
module = model.conv1
print(list(module.named_parameters()))
此时参数包含矩阵的权值与偏置.
为了剪枝一个模型, 首先要在 torch.nn.utils.prune
中选择一种剪枝方法, 或者使用子类 BasePruningMethod
实现自己的剪枝方法, 然后确定模型以及需要减去的参数, 最后,使用所选修剪技术所需的适当关键字参数,指定修剪参数. 在下面的例子中, 我们将要随机减去 conv1 层中的 30% 的权重参数, module 是函数的第一个参数, name 使用的是参数的字符串标识, amount 表示剪枝的百分比.
prune.random_unstructured(module, name="weight", amount=0.3)
剪枝行为将 weight 参数名称删除, 并将其替代为新的参数名称, weight_orig
, weight_orig存储未修剪的张量版本. 也就是说 weight_orig 是原来的权重,
上述的剪枝方法会产生一个 mask 矩阵, 叫做 weight_mask , 存储为一个 module buffer , 相当于一个 mask矩阵, 他的维度与 weight 的维度相同, 不同的是 mask 矩阵是一个 0/1 矩阵. 可以通过下面的函数查看 mask 矩阵:
print(list(module.named_buffers()))
剪枝之后的权重属性 weight 不再是权重的集合, 而是 mask 矩阵与原始矩阵的结合, 所以不再是模型的一个 parameter, 而是一个 attribute.
最后,使用 PyTorch 的forward_pre_hooks在每次正向传递之前应用修剪。具体来说,如我们在此处所做的那样,在剪枝模块部分,它将为与之相关的每个要修剪的参数获取一个forward_pre_hook。目前为止我们只修剪了名为weight的原始参数,因此将只存在一个 forward_pre_hook, 相当于没有一个剪枝参数就有一个 forward_pre_hook.
除了对 weight 剪枝, 还可以对 bias 剪枝, 下面是通过 L1 范式剪去三个单元
prune.l1_unstructured(module, name="bias", amount=3)
# Prunes tensor corresponding to parameter called name in module by removing the specified amount of (currently unpruned) units with the lowest L1-norm.
Iterative Pruning
相同的参数在一个模型中可以被多次剪枝, 相当于把多个剪枝核序列化成一个剪枝核, 新的 mask 矩阵与旧的 mask 矩阵的结合使用 PruningContainer
中的 compute_mask
方法. 比如在上面的 module 的 weight 中, 我们除了随机剪枝外还可以通过范式剪枝, 下面是个例子:
prune.ln_structured(module, name="weight", amount=0.5, n=2, dim=0)
# As we can verify, this will zero out all the connections corresponding to
# 50% (3 out of 6) of the channels, while preserving the action of the
# previous mask.
# 这里的 n 表示剪枝的范式, dim = 0, 表示参数矩阵的维度, 这里卷积层的 dim= 0, 就是核的个数
print(module.weight)
剪完之后, 核的个数变成原来的一半. mask 矩阵也会自动叠加.
还可以通过下面的方法查看我们使用了哪些方法剪枝, hook 记录了某个 attribute 的剪枝方法:
for hook in module._forward_pre_hooks.values():
if hook._tensor_name == "weight": # select out the correct hook
break
print(list(hook)) # pruning history in the container
Serializing a pruned model
所有相关的张量,包括掩码缓冲区和用于计算修剪的张量的原始参数,都存储在模型的 state_dict 中,因此可以根据需要轻松地序列化和保存.
我们可以通过下面的方法查看模型中的权重参数:
>> print(model.state_dict().keys())
>> odict_keys(['conv1.weight_orig', 'conv1.bias_orig', 'conv1.weight_mask', 'conv1.bias_mask', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias'])
Remove pruning re-parametrization
注意, 这里的删除剪枝的意思并不是真正的删除, 还原到未剪枝的状态. 举个例子, 剪枝之后, 我们的参数 parameters 中的 weight 会变成, 'weight_orig', 而 weight 变成一个属性, 他是 'weight_orig' 与 mask 矩阵结合后的结果, 那么
prune.remove(module, 'weight')
之后会发生什么呢?
print(list(module.named_parameters()))
('weight', Parameter containing:
tensor([[[[-0.0000, -0.0000, -0.0000],
[-0.0000, -0.0000, -0.0000],
[-0.0000, 0.0000, -0.0000]]],
.......
也就是说, weight 又变成了 parameters, 剪枝变成永久化.
Pruning multiple parameters in a model
多个参数, 多个网络结构的剪枝,
new_model = LeNet()
for name, module in new_model.named_modules():
# prune 20% of connections in all 2D-conv layers
if isinstance(module, torch.nn.Conv2d):
prune.l1_unstructured(module, name='weight', amount=0.2)
# 将所有卷积层的权重减去 20%
# prune 40% of connections in all linear layers
elif isinstance(module, torch.nn.Linear):
prune.l1_unstructured(module, name='weight', amount=0.4)
# 将所有全连接层的权重减去 40%
print(dict(new_model.named_buffers()).keys()) # to verify that all masks exist
Global pruning
之前的剪枝我们都是针对每一层每一层的剪枝, 减去某一层权重的百分比, 对于全局剪枝就是将模型的参数看成一个整体, 减去一部分参数, 对于每一层减去的比例可能不同.
剪枝的方法可以通过下面的方法:
model = LeNet()
parameters_to_prune = (
(model.conv1, 'weight'),
(model.conv2, 'weight'),
(model.fc1, 'weight'),
(model.fc2, 'weight'),
(model.fc3, 'weight'),
)
prune.global_unstructured(
parameters_to_prune,
pruning_method=prune.L1Unstructured,
amount=0.2,
)
使用自定义的方法剪枝
要实现自己的修剪功能,可以通过将 BasePruningMethod 基类作为子类来扩展 nn.utils.prune
模块,就像其他所有修剪方法一样. 基类以及完成了下面的方法:
__call__
, apply_mask
, apply
, prune
, and remove
除了一些特殊的情况, 你不需要重写这些方法以实现新的剪枝方法. 你需要实现的是:
__init__
构造器compute_mask
如何根据剪枝策略的逻辑为给定张量计算 mask- 需要说明是全局剪枝, 还是结构剪枝, 或者是非结构剪枝, 这决定了在迭代剪枝是如何结合 mask 矩阵, 换句话说,当剪枝需要剪枝的参数时,当前的剪枝策略应作用于参数的未剪枝部分。指定
PRUNING_TYPE
将启用 PruningContainer 正确识别要修剪的参数的范围.
比如说, 当我们希望剪枝一个张量中除了某一参数外的所有其他参数的时候, 或者说这个张量已经被部分剪枝的时候, 我们就需要设置: PRUNING_TYPE='unstructured'
因为他只是单独作用与一层, 而不是一个单元或者通道(对应于'structured'), 也不是作用于整个参数(对应于'global')
class FooBarPruningMethod(prune.BasePruningMethod):
# 继承自基类 BasePruningMethod
"""Prune every other entry in a tensor
"""
PRUNING_TYPE = 'unstructured'
# 类型为 unstructured 类型
def compute_mask(self, t, default_mask):
mask = default_mask.clone()
mask.view(-1)[::2] = 0
# 定义了 mask 矩阵的构成方法, 每两个数字一个 0
return mask
然后给出一个调用的例子:
def foobar_unstructured(module, name):
"""Prunes tensor corresponding to parameter called `name` in `module`
by removing every other entry in the tensors.
Modifies module in place (and also return the modified module)
by:
1) adding a named buffer called `name+'_mask'` corresponding to the
binary mask applied to the parameter `name` by the pruning method.
The parameter `name` is replaced by its pruned version, while the
original (unpruned) parameter is stored in a new parameter named
`name+'_orig'`.
Args:
module (nn.Module): module containing the tensor to prune
name (string): parameter name within `module` on which pruning
will act.
Returns:
module (nn.Module): modified (i.e. pruned) version of the input
module
Examples:
>>> m = nn.Linear(3, 4)
>>> foobar_unstructured(m, name='bias')
"""
FooBarPruningMethod.apply(module, name)
return module
model = LeNet()
foobar_unstructured(model.fc3, name='bias')
print(model.fc3.bias_mask)
以上就是Pytorch 剪枝的主要方法, 其实对于复杂的剪枝方法, 只要在 compute_mask
设置特殊的 mask 构成方法就可以了.