VGG16
run/vgg16/vgg16_prune_demo.py运行:
python ./run/vgg16/vgg16_prune_demo.py --config ./run/vgg16/prune.json
报错:
Traceback (most recent call last):
File "./run/vgg16/vgg16_prune_demo.py", line 16, in <module>
from logger import logger
File "/Users/user/pytorch/gate-decorator-pruning/logger.py", line 67, in <module>
logger = Logger()
File "/Users/user/pytorch/gate-decorator-pruning/logger.py", line 42, in __init__
json.dump(cfg, fp)
File "/anaconda3/envs/deeplearning/lib/python3.7/json/__init__.py", line 179, in dump
for chunk in iterable:
File "/anaconda3/envs/deeplearning/lib/python3.7/json/encoder.py", line 438, in _iterencode
o = _default(o)
File "/anaconda3/envs/deeplearning/lib/python3.7/json/encoder.py", line 179, in default
raise TypeError(f'Object of type {o.__class__.__name__} '
TypeError: Object of type Config is not JSON serializable
原因是 cfg(或其中的某个值)是自定义的 Config 对象(见报错信息 `Object of type Config is not JSON serializable`),json 的默认编码器不知道如何序列化这种自定义类型。
解决办法:
将logger.py中的json.dump()改为:
with open(self.cfgfile, 'w') as fp:
    # `cls` must be a json.JSONEncoder subclass; dotdict is a dict-like
    # container, not an encoder, so `cls=dotdict` cannot work.
    # `default` is called for every object json cannot encode itself (here the
    # Config object) and must return something serializable.
    # NOTE(review): assumes Config keeps its settings as instance attributes.
    json.dump(cfg, fp, default=lambda o: o.__dict__)
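本意是显式指定自定义的序列化方式。但要注意:`json.dump` 的 `cls` 参数要求传入 `json.JSONEncoder` 的子类,而 `dotdict` 只是一个类似 dict 的容器,并不是编码器,不能这样用。更稳妥的写法是通过 `default` 参数把 json 无法直接序列化的对象转换成 dict,例如 `json.dump(cfg, fp, default=lambda o: o.__dict__)`(假设 Config 把配置项都保存为实例属性)。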
再出错:
AssertionError: Torch not compiled with CUDA enabled
将prune.json中的cuda:true改为false
报错:
FileNotFoundError: [Errno 2] No such file or directory: './logs/vgg16_cifar10/ckp.160.torch'
这是因为我没有按照顺序运行,没有先运行:
CUDA_VISIBLE_DEVICES=0 python main.py --config ./run/vgg16/baseline.json
该命令会生成一个ckp.160.torch文件
所以我使用pytorch给的预训练文件,将vgg16_prune_demo.py中的:
def get_pack():
    # Excerpt of the original get_pack(); the rest of the function is omitted here.
    set_seeds()
    pack = recover_pack()
    # Load the CIFAR-10 baseline checkpoint produced by `main.py --config ./run/vgg16/baseline.json`.
    model_dict = torch.load('./logs/vgg16_cifar10/ckp.160.torch', map_location='cpu' if not cfg.base.cuda else 'cuda')
    # `.module` only exists when the net is wrapped in torch.nn.DataParallel.
    pack.net.module.load_state_dict(model_dict)
改成:
def get_pack():
    # Modified excerpt: use PyTorch's published pretrained VGG16 weights instead of the baseline checkpoint.
    set_seeds()
    pack = recover_pack()
    # strict=False ignores keys whose names do not match, but shape mismatches still raise (see the error below).
    pack.net.load_state_dict(torch.utils.model_zoo.load_url('https://download.pytorch.org/models/vgg16-397923af.pth'), strict=False)
然后查看此时的网络结构:
# Inspect the network returned by get_pack(): each top-level child, then the list of GBN layers.
pack, GBNs = get_pack()
for name, child in pack.net.named_children():
    print(name)
    print(child)
print(GBNs)
后面运行出错:
Traceback (most recent call last):
File "./run/vgg16/vgg16_prune_demo.py", line 137, in <module>
run()
File "./run/vgg16/vgg16_prune_demo.py", line 112, in run
pack, GBNs = get_pack()
File "./run/vgg16/vgg16_prune_demo.py", line 29, in get_pack
pack.net.load_state_dict(torch.utils.model_zoo.load_url('https://download.pytorch.org/models/vgg16-397923af.pth'), strict=False)
File "/anaconda3/envs/deeplearning/lib/python3.7/site-packages/torch/nn/modules/module.py", line 845, in load_state_dict
self.__class__.__name__, "
".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for VGG:
size mismatch for features.7.weight: copying a param with shape torch.Size([128, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 64, 3, 3]).
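这是因为本仓库的 VGG16 在每个卷积层后都带有 BN 层,应该使用 vgg16_bn 的预训练权重。不带 BN 的 vgg16 权重中各层的编号与之错位:在 vgg16 中 features.7 是 128→128 的卷积(checkpoint 形状 [128, 128, 3, 3]),而在带 BN 的结构中 features.7 是 64→128 的卷积(当前模型形状 [128, 64, 3, 3]),所以报形状不匹配。strict=False 只会忽略名字对不上的键,形状不匹配仍然会报错。把模型地址改为 https://download.pytorch.org/models/vgg16_bn-6c64b313.pth 即可。注意:这些预训练权重并不是在 CIFAR-10 上训练的,分类层的键名和形状很可能对不上而不会被加载,剪枝前最好先在 CIFAR-10 上训练/微调(即先跑 baseline)。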
又报错:
Traceback (most recent call last):
File "./run/vgg16/vgg16_prune_demo.py", line 137, in <module>
run()
File "./run/vgg16/vgg16_prune_demo.py", line 114, in run
cloned, _ = clone_model(pack.net)
File "./run/vgg16/vgg16_prune_demo.py", line 54, in clone_model
gbns = GatedBatchNorm2d.transform(model.module)
File "/anaconda3/envs/deeplearning/lib/python3.7/site-packages/torch/nn/modules/module.py", line 591, in __getattr__
type(self).__name__, name))
AttributeError: 'VGG' object has no attribute 'module'
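把 model.module 改成 model 即可。只有被 torch.nn.DataParallel 包装过的模型才有 .module 属性,而我把 multi_gpus 设成了 False,下面这段包装代码不会执行(脚本中其他用到 .module 的地方,如 eval_prune、prune 中的 torch.save 以及 run,也要做同样的修改,否则会报同样的错误):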
if cfg.base.multi_gpus:  # multi_gpus is set to False here, so the model is never wrapped and has no `.module`
    model = torch.nn.DataParallel(model)
仅仅根据代码说说原理
感觉看了所有的代码后其工作原理是这样的,拿vgg16_prune_demo.py的prune()函数举例子:
# Create the Tick-Tock driver over the GBN layers.
# Note: IterRecoverFramework.__init__ accepts minium_filter but does not use it.
prune_agent = IterRecoverFramework(pack, GBNs, sparse_lambda = cfg.gbn.sparse_lambda, flops_eta = cfg.gbn.flops_eta, minium_filter = 3)
1)准备好了Tick-Tock
# Warm-up Tock first: cfg.gbn.tock_epoch epochs over the whole training set
prune_agent.tock(lr_min=cfg.gbn.lr_min, lr_max=cfg.gbn.lr_max, tock_epoch=cfg.gbn.tock_epoch)
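其实就相当于在原有模型上用全部训练数据微调 cfg.gbn.tock_epoch 个 epoch,不过损失中额外加了对所有门控参数 g 的 L1 稀疏正则项(sparse_lambda × Σ|g|),并且学习率在 lr_min 和 lr_max 之间先线性上升、再线性下降。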
2)然后就循环进行Tick操作:
def prune(pack, GBNs, BASE_FLOPS, BASE_PARAM):
    """Alternate Tick (score + prune) and Tock (sparse training) on pack.net.

    A checkpoint is saved the first time the (melted) FLOPs fall to 30%, 20%
    and 10% of BASE_FLOPS. The loop stops once all three are saved. Returns
    the list of per-iteration info dicts (the original returned None).
    """
    LOGS = []
    flops_save_points = set([30, 20, 10])
    iter_idx = 0

    pack.tick_trainset = pack.train_loader
    prune_agent = IterRecoverFramework(pack, GBNs, sparse_lambda = cfg.gbn.sparse_lambda, flops_eta = cfg.gbn.flops_eta, minium_filter = 3)

    # Warm-up Tock: cfg.gbn.tock_epoch epochs with the L1 penalty on the gates.
    prune_agent.tock(lr_min=cfg.gbn.lr_min, lr_max=cfg.gbn.lr_max, tock_epoch=cfg.gbn.tock_epoch)

    while True:
        left_filter = prune_agent.total_filters - prune_agent.pruned_filters
        # Each Tick prunes a fixed fraction p of the filters still alive.
        num_to_prune = int(left_filter * cfg.gbn.p)
        info = prune_agent.prune(num_to_prune, tick=True, lr=cfg.gbn.lr_min)
        # Size of the network as it would be after melting.
        flops, params = eval_prune(pack)
        info.update({
            'flops': '[%.2f%%] %.3f MFLOPS' % (flops/BASE_FLOPS * 100, flops / 1e6),
            'param': '[%.2f%%] %.3f M' % (params/BASE_PARAM * 100, params / 1e6)
        })
        LOGS.append(info)
        # Depending on the trainer, prune() may not store 'after_prune_test_acc'; read it defensively.
        print('Iter: %d, FLOPS: %s, Param: %s, Left: %d, Pruned Ratio: %.2f %%, Train Loss: %.4f, Test Acc: %.2f' %
              (iter_idx, info['flops'], info['param'], info['left'], info['total_pruned_ratio'] * 100,
               info['train_loss'], info.get('after_prune_test_acc', float('nan'))))

        iter_idx += 1
        if iter_idx % cfg.gbn.T == 0:
            # Every cfg.gbn.T Ticks, run a Tock of cfg.gbn.tock_epoch epochs.
            print('Tocking:')
            prune_agent.tock(lr_min=cfg.gbn.lr_min, lr_max=cfg.gbn.lr_max, tock_epoch=cfg.gbn.tock_epoch)

        # Current FLOPs as a percentage of the unpruned model.
        flops_ratio = flops/BASE_FLOPS * 100
        for point in list(flops_save_points):  # iterate over a copy: points are removed below
            if flops_ratio <= point:
                # `.module` only exists on a DataParallel-wrapped net.
                net_to_save = getattr(pack.net, 'module', pack.net)
                torch.save(net_to_save.state_dict(), './logs/vgg16_cifar10/gbn_%s.ckp' % str(point))
                flops_save_points.remove(point)

        # All save points reached (FLOPs <= 10%): stop Tick-Tock.
        if len(flops_save_points) == 0:
            break

    return LOGS
Tick操作就是在计算分数,决定剪去BN层的哪些channels
3)开始进行Tick-Tock前的网络结构就是将BN层换成了GBN层:
def get_pack():
    # Build the training pack, load pretrained VGG16-BN weights and wrap every BN in a GBN.
    set_seeds()
    pack = recover_pack()
    #model_dict = torch.load('./logs/vgg16_cifar10/ckp.160.torch', map_location='cpu' if not cfg.base.cuda else 'cuda')
    # Use PyTorch's pretrained vgg16_bn weights instead of the baseline checkpoint above;
    # strict=False skips keys whose names do not match (shape mismatches still raise).
    pack.net.load_state_dict(torch.utils.model_zoo.load_url('https://download.pytorch.org/models/vgg16_bn-6c64b313.pth'), strict=False)
    #pack.net.module.load_state_dict(model_dict)
    # Every BatchNorm2d becomes a GatedBatchNorm2d; its weight is folded into the gate g and frozen.
    GBNs = GatedBatchNorm2d.transform(pack.net)
    # GatedBatchNorm2d.__init__ already called extract_from_bn(); the BN weight is 1 after that,
    # so this second call leaves the parameters unchanged.
    for gbn in GBNs:
        gbn.extract_from_bn()
    # for name, child in pack.net.named_children():
    #     print(name)
    #     print(child)
    # Rebuild the optimizer so it also covers the new gate parameters g.
    pack.optimizer = optim.SGD(
        pack.net.parameters() ,
        lr=2e-3,
        momentum=cfg.train.momentum,
        weight_decay=cfg.train.weight_decay,
        nesterov=cfg.train.nesterov
    )
    return pack, GBNs
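GatedBatchNorm2d.transform(pack.net) 会把每个 BN 层包装成 GBN:GBN 的构造函数新增了门控参数 g,并调用 extract_from_bn() 把 BN 的缩放系数 weight 折算进 g(g ← g·weight,bias ← clamp(bias/weight, -10, 10)),然后把 weight 置为 1 并冻结。这样 BN 原来的缩放作用完全由 g 承担,层的输出不变(clamp 截断的情况除外)。注意被冻结的只有 BN 的 weight,BN 的 bias 以及卷积、全连接等其他参数在训练时仍然会更新: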
def extract_from_bn(self):
    # Fold the BN scale into the gate g and freeze the BN weight at 1:
    #   bias <- clamp(bias / weight, -10, 10), g <- g * weight, weight <- 1.
    # Because (x_hat + bias/weight) * (g * weight) == (weight * x_hat + bias) * g,
    # the layer output is unchanged (up to the clamp) while g now carries the scale.
    # freeze bn weight
    with torch.no_grad():
        self.bn.bias.set_(torch.clamp(self.bn.bias / self.bn.weight, -10, 10))
        self.g.set_(self.g * self.bn.weight.view(1, -1, 1, 1))
        self.bn.weight.set_(torch.ones_like(self.bn.weight))
        self.bn.weight.requires_grad = False
如论文中:
在这个基础上进行的 Tock 操作,其实就是在 BN 层带有门控参数 g、BN weight 被冻结的模型上,用整个训练数据集训练模型,同时在损失中加上 g 的 L1 稀疏正则项。
4)然后进行prune操作:
info = prune_agent.prune(num_to_prune, tick=True, lr=cfg.gbn.lr_min)  # one Tick (gate scores are collected), then mask out the lowest-scoring filters
其实就是进行Tick操作+prune操作
首先Tick操作是:
def tick(self, lr, test):
    ''' Tick: train with conv layers frozen while collecting gate scores (no mask is changed here). '''
    # Only non-conv parameters (gates g, BN bias, the linear layer) can change during the Tick.
    self.freeze_conv()
    # One trainer.train() call at rate lr on tick_trainset, accumulating |g * grad| scores;
    # also runs trainer.test() when `test` is true.
    info = self.recover(lr, test)
    # Put back each conv parameter's original requires_grad flag.
    self.restore_conv()
    return info
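会 freeze 住卷积层的参数,所以 tick 训练时更新的只有 GBN 的 g、BN 的 bias 和全连接层的参数(BN 的 weight 之前已经被冻结)。同时 recover() 会在每个 g 上注册梯度钩子,每次反向传播都把 |g × ∂L/∂g| 累加到 score 中,作为每个通道的重要性分数。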
接下来的就是剪枝prune操作:
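然后就是根据 Tick 训练过程中累积的分数来决定剪哪些 filter。一开始 bn_mask(见 prune/universal.py 中 GatedBatchNorm2d 的定义)全为 1,表示所有 filter 都保留。get_score() 返回 (score − flops_reg) × bn_mask,其中 flops_reg 是按 flops_eta 加权的 FLOPs 惩罚项;flops_eta 为 0 时就是 score × bn_mask,已被剪掉的通道分数为 0。接着 get_threshold() 把所有组的分数拼在一起排序,再取排序后下标为 num 的分数作为阈值 threshold。排序时有两类分数被排除:每组分数最高的 minimal 个通道受保护,不参与排序;分数为 0 的已剪通道也不参与。然后 set_mask() 为每个组算出 mask = (分数 > threshold) 或 (属于该组分数最高的 minimal 个通道),并执行 g.bn_mask = mask × g.bn_mask。这样每个 GBN 的 bn_mask 中为 0 的位置表示对应的 filter 被删除,为 1 表示保留;而且被剪掉的通道以后也不会再恢复。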
所以剪枝操作其实就是根据bn_mask的结果去剪枝,因为GatedBatchNorm2d类的forward操作中有:
def forward(self, x):
    # GBN forward: BN, then the per-channel gate g, then the 0/1 channel mask.
    x = self.bn(x) * self.g
    # Remember H * W of this feature map; get_score() uses it for the FLOPs penalty.
    self.area[0] = x.shape[-1] * x.shape[-2]
    if self.bn_mask is not None:
        # Channels with bn_mask == 0 become all-zero, i.e. effectively pruned.
        return x * self.bn_mask
    return x
因此在训练的时候,前向操作经过GBN层得到的结果就是x * self.bn_mask,bn_mask为0对应的x的channels的值就会全为0,就相当于剪掉了这个filter
5)接下来就是根据上面的剪枝结果去对应地将卷积层和全连接层中的channels数和GBN层对应起来:
# Wrap every Conv2d and the final Linear classifier in observers that record which channels are still used.
_ = Conv2dObserver.transform(pack.net.module)
pack.net.module.classifier = FinalLinearObserver(pack.net.module.classifier)
主要就是将它们分别封装成Conv2dObserver和FinalLinearObserver
Conv2dObserver 中有 in_mask 和 out_mask 两个累加张量(不是可训练参数)。前向钩子对输入、反向钩子对输出梯度,按通道累加绝对值(对 batch 和空间维求和);累加结果为 0 的通道说明已经被剪掉了:
def _forward_hook(self, m, _in, _out):
    # Accumulate, per input channel, the sum of |input| over the batch and spatial dims.
    x = _in[0]
    self.in_mask += x.data.abs().sum(2, keepdim=True).sum(3, keepdim=True).cpu().sum(0, keepdim=True).view(-1)

def _backward_hook(self, grad):
    # Accumulate, per output channel, the sum of |dL/d output|; a channel whose gradient
    # stays zero (presumably because the following GBN masks it) keeps out_mask == 0.
    self.out_mask += grad.data.abs().sum(2, keepdim=True).sum(3, keepdim=True).cpu().sum(0, keepdim=True).view(-1)
    # Replace the gradient propagated further back with all ones.
    new_grad = torch.ones_like(grad)
    return new_grad

def forward(self, x):
    output = self.conv(x)
    # Add standard-normal noise to the conv output.
    noise = torch.zeros_like(output).normal_()
    output = output + noise
    if self.training:
        output.register_hook(self._backward_hook)
    return output
FinalLinearObserver也是同样的概念
6)然后就是observe和melt_all操作:
# Run a few observation iterations so the observers record which channels are alive...
Meltable.observe(pack, 0.001)
# ...then replace every GBN / observer with a plain, physically smaller layer.
Meltable.melt_all(pack.net)
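observe 会先处理网络中所有 BN 层的 weight:先取绝对值,再加上一个极小值 1e-3。这里也包括每个 GBN 内部包着的 bn,因为 pack.net.modules() 会递归遍历到它们。observe 还会把 ReLU 换成 LeakyReLU。然后在 BN 处于 eval 模式(不更新 running 统计量)、update=False 的情况下调用训练;迭代钩子在第 OBSERVE_TIMES(=5)次时返回 FINISH_SIGNAL,让训练提前结束。结束后再把这些改动恢复。
它的目的其实是让 Conv2dObserver 和 FinalLinearObserver 通过前向/反向钩子统计出 in_mask 和 out_mask,供后面的 melt_all 使用,并不是真正的训练。把 weight 加上极小值、用 LeakyReLU 代替 ReLU,应该是为了保证未被剪掉的通道的激活和梯度不会恰好为 0;这样在统计结果中为 0 的只会是被 bn_mask 置 0 的通道。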
melt_all其实就是将所有的GBN、Conv2dObserver和FinalLinearObserver根据得到的in_mask和out_mask以及GBN中的self.bn_mask来恢复网络,删去不要的filter,只将对应的filter的参数赋值到新的网络结构中,调用的是这几个类中的melt()函数
7)最后再使用这个新的网络结构进行微调:
# Fine-tune the melted (physically smaller) network.
_ = finetune(pack, lr_min=cfg.gbn.lr_min, lr_max=cfg.gbn.lr_max, T=cfg.gbn.finetune_epoch)
要自己将微调后的模型保存下来
1》仅保存模型:
# Save only the weights; `.module` unwraps a torch.nn.DataParallel model.
torch.save(pack.net.module.state_dict(), os.path.join(saving_path, '30_finetune_state.pth'))
用.module是因为使用了:
model = torch.nn.DataParallel(model)  # the wrapper exposes the original model as `model.module`
如果没有使用可以删掉
2》保存模型和网络结构:
# Save the whole module object (architecture + weights). Presumably useful because the melted
# layer sizes no longer match a freshly built get_model() network.
torch.save(pack.net.module, os.path.join(saving_path, '30_finetune.pth'))
整个代码是:
import os
import sys
# Run from the repository root (and make it importable) regardless of the launch directory.
_r = os.getcwd().split('/')
_p = '/'.join(_r[:_r.index('gate-decorator-pruning')+1])
print('Change dir from %s to %s' % (os.getcwd(), _p))
os.chdir(_p)
sys.path.append(_p)

import torch
import torch.nn as nn
import numpy as np
import torch.optim as optim

from config import cfg
from logger import logger
from main import set_seeds, recover_pack, adjust_learning_rate, _step_lr, _sgdr
from models import get_model
from utils import dotdict

from prune.universal import Meltable, GatedBatchNorm2d, Conv2dObserver, IterRecoverFramework, FinalLinearObserver
from prune.utils import analyse_model, finetune


def _unwrap(net):
    """Return `net.module` for a DataParallel-wrapped model, else `net` itself.

    Hard-coding `.module` raises AttributeError when cfg.base.multi_gpus is
    False and the model was never wrapped.
    """
    return getattr(net, 'module', net)


def _probe_input(net):
    """Random 1x3x32x32 (CIFAR-sized) input on the same device as `net`'s parameters."""
    return torch.randn(1, 3, 32, 32, device=next(net.parameters()).device)


def get_pack():
    """Build the training pack, load pretrained VGG16-BN weights and insert GBN layers.

    Returns (pack, GBNs).
    """
    set_seeds()
    pack = recover_pack()
    # To start from the repository's own CIFAR-10 baseline instead, first run
    # `main.py --config ./run/vgg16/baseline.json` and load
    # './logs/vgg16_cifar10/ckp.160.torch' here.
    # strict=False skips keys whose names do not match; shape mismatches still raise.
    # Loading into the unwrapped model keeps the key names free of the wrapper prefix.
    _unwrap(pack.net).load_state_dict(
        torch.utils.model_zoo.load_url('https://download.pytorch.org/models/vgg16_bn-6c64b313.pth'),
        strict=False)
    # Every BatchNorm2d becomes a GatedBatchNorm2d; its weight is folded into the gate g and frozen.
    GBNs = GatedBatchNorm2d.transform(pack.net)
    # Already done by GatedBatchNorm2d.__init__; repeating it leaves the parameters unchanged.
    for gbn in GBNs:
        gbn.extract_from_bn()

    # Rebuild the optimizer so it also covers the new gate parameters g.
    pack.optimizer = optim.SGD(
        pack.net.parameters(),
        lr=2e-3,
        momentum=cfg.train.momentum,
        weight_decay=cfg.train.weight_decay,
        nesterov=cfg.train.nesterov
    )
    return pack, GBNs


def clone_model(net):
    """Return (fresh model with the same GBN structure and weights as `net`, its GBN list)."""
    model = get_model()
    # transform() recurses into submodules, so this works with or without DataParallel.
    gbns = GatedBatchNorm2d.transform(model)
    model.load_state_dict(net.state_dict())
    return model, gbns


def eval_prune(pack):
    """FLOPs and parameter count of pack.net as it would be after melting.

    Works on a clone, so pack.net keeps its GBN wrappers and masks.
    """
    cloned, _ = clone_model(pack.net)
    inner = _unwrap(cloned)
    # Observers record which conv / classifier channels are still alive.
    _ = Conv2dObserver.transform(inner)
    inner.classifier = FinalLinearObserver(inner.classifier)
    cloned_pack = dotdict(pack.copy())
    cloned_pack.net = cloned
    Meltable.observe(cloned_pack, 0.001)
    # Replace every GBN / observer by a physically smaller plain layer.
    Meltable.melt_all(cloned_pack.net)
    melted = _unwrap(cloned_pack.net)
    flops, params = analyse_model(melted, _probe_input(melted))
    del cloned
    del cloned_pack
    return flops, params


def prune(pack, GBNs, BASE_FLOPS, BASE_PARAM):
    """Alternate Tick (score + prune) and Tock (sparse training) on pack.net.

    A checkpoint is saved the first time the melted FLOPs fall to 30%, 20% and
    10% of BASE_FLOPS; the loop stops once all three are saved. Returns the
    per-iteration info dicts.
    """
    LOGS = []
    flops_save_points = set([30, 20, 10])
    iter_idx = 0

    pack.tick_trainset = pack.train_loader
    prune_agent = IterRecoverFramework(pack, GBNs, sparse_lambda = cfg.gbn.sparse_lambda, flops_eta = cfg.gbn.flops_eta, minium_filter = 3)

    # Warm-up Tock: cfg.gbn.tock_epoch epochs with the L1 penalty on the gates.
    prune_agent.tock(lr_min=cfg.gbn.lr_min, lr_max=cfg.gbn.lr_max, tock_epoch=cfg.gbn.tock_epoch)

    while True:
        left_filter = prune_agent.total_filters - prune_agent.pruned_filters
        # Each Tick prunes a fixed fraction p of the filters still alive.
        num_to_prune = int(left_filter * cfg.gbn.p)
        info = prune_agent.prune(num_to_prune, tick=True, lr=cfg.gbn.lr_min)
        flops, params = eval_prune(pack)
        info.update({
            'flops': '[%.2f%%] %.3f MFLOPS' % (flops/BASE_FLOPS * 100, flops / 1e6),
            'param': '[%.2f%%] %.3f M' % (params/BASE_PARAM * 100, params / 1e6)
        })
        LOGS.append(info)
        # Depending on the trainer, prune() may not store 'after_prune_test_acc'; read it defensively.
        print('Iter: %d, FLOPS: %s, Param: %s, Left: %d, Pruned Ratio: %.2f %%, Train Loss: %.4f, Test Acc: %.2f' %
              (iter_idx, info['flops'], info['param'], info['left'], info['total_pruned_ratio'] * 100,
               info['train_loss'], info.get('after_prune_test_acc', float('nan'))))

        iter_idx += 1
        if iter_idx % cfg.gbn.T == 0:
            # Every cfg.gbn.T Ticks, run a Tock of cfg.gbn.tock_epoch epochs.
            print('Tocking:')
            prune_agent.tock(lr_min=cfg.gbn.lr_min, lr_max=cfg.gbn.lr_max, tock_epoch=cfg.gbn.tock_epoch)

        # Current FLOPs as a percentage of the unpruned model.
        flops_ratio = flops/BASE_FLOPS * 100
        for point in list(flops_save_points):  # iterate over a copy: points are removed below
            if flops_ratio <= point:
                torch.save(_unwrap(pack.net).state_dict(), './logs/vgg16_cifar10/gbn_%s.ckp' % str(point))
                flops_save_points.remove(point)

        # All save points reached: stop Tick-Tock.
        if len(flops_save_points) == 0:
            break

    return LOGS


def run():
    """Prune VGG16 with Tick-Tock, melt the pruned network and fine-tune it."""
    pack, GBNs = get_pack()
    cloned, _ = clone_model(pack.net)
    # FLOPs / params of the unpruned network: the reference for every later ratio.
    BASE_FLOPS, BASE_PARAM = analyse_model(_unwrap(cloned), _probe_input(cloned))
    print('%.3f MFLOPS' % (BASE_FLOPS / 1e6))
    print('%.3f M' % (BASE_PARAM / 1e6))
    del cloned

    prune(pack, GBNs, BASE_FLOPS, BASE_PARAM)

    # Melt the pruned network for real.
    inner = _unwrap(pack.net)
    _ = Conv2dObserver.transform(inner)
    inner.classifier = FinalLinearObserver(inner.classifier)
    Meltable.observe(pack, 0.001)
    Meltable.melt_all(pack.net)

    # lr=1 is a placeholder; presumably finetune() sets the rate from lr_min/lr_max.
    pack.optimizer = optim.SGD(
        pack.net.parameters(),
        lr=1,
        momentum=cfg.train.momentum,
        weight_decay=cfg.train.weight_decay,
        nesterov=cfg.train.nesterov
    )
    _ = finetune(pack, lr_min=cfg.gbn.lr_min, lr_max=cfg.gbn.lr_max, T=cfg.gbn.finetune_epoch)


if __name__ == '__main__':
    run()
感觉比较重要的代码是prune/universal.py:
这里有转换bn层、conv层和FinalLinear层的类,还有Tick-Tock操作的类:
""" * Copyright (C) 2019 Zhonghui You * If you are using this code in your research, please cite the paper: * Gate Decorator: Global Filter Pruning Method for Accelerating Deep Convolutional Neural Networks, in NeurIPS 2019. """ import torch import torch.nn as nn import numpy as np import uuid OBSERVE_TIMES = 5 FINISH_SIGNAL = 'finish' class Meltable(nn.Module): def __init__(self): super(Meltable, self).__init__() @classmethod def melt_all(cls, net): def _melt(modules): keys = modules.keys() for k in keys: if len(modules[k]._modules) > 0: _melt(modules[k]._modules) if isinstance(modules[k], Meltable): modules[k] = modules[k].melt() #根据此时的g恢复所有的参数 _melt(net._modules) @classmethod def observe(cls, pack, lr): tmp = pack.train_loader if pack.tick_trainset is not None: pack.train_loader = pack.tick_trainset for m in pack.net.modules(): if isinstance(m, nn.BatchNorm2d): #这个用来干嘛的?? m.weight.data.abs_().add_(1e-3) def replace_relu(modules): keys = modules.keys() for k in keys: if len(modules[k]._modules) > 0: replace_relu(modules[k]._modules) if isinstance(modules[k], nn.ReLU): modules[k] = nn.LeakyReLU(inplace=True) replace_relu(pack.net._modules) count = 0 def _freeze_bn(curr_iter, total_iter): for m in pack.net.modules(): if isinstance(m, nn.BatchNorm2d): m.eval() nonlocal count count += 1 if count == OBSERVE_TIMES: return FINISH_SIGNAL info = pack.trainer.train(pack, iter_hook=_freeze_bn, update=False, mute=True) def recover_relu(modules): keys = modules.keys() for k in keys: if len(modules[k]._modules) > 0: recover_relu(modules[k]._modules) if isinstance(modules[k], nn.LeakyReLU): modules[k] = nn.ReLU(inplace=True) recover_relu(pack.net._modules) for m in pack.net.modules(): if isinstance(m, nn.BatchNorm2d): m.weight.data.abs_().add_(-1e-3) pack.train_loader = tmp class GatedBatchNorm2d(Meltable): def __init__(self, bn, minimal_ratio = 0.1): super(GatedBatchNorm2d, self).__init__() assert isinstance(bn, nn.BatchNorm2d) self.bn = bn self.group_id = uuid.uuid1() 
self.channel_size = bn.weight.shape[0] self.minimal_filter = max(1, int(self.channel_size * minimal_ratio)) self.device = bn.weight.device self._hook = None self.g = nn.Parameter(torch.ones(1, self.channel_size, 1, 1).to(self.device), requires_grad=True) self.register_buffer('area', torch.zeros(1).to(self.device)) #记录此时输入的图像大小 self.register_buffer('score', torch.zeros(1, self.channel_size, 1, 1).to(self.device)) #这个值要么0要么1,根据得到的分数来得到self.masks,然后设置g.bn_mask.set_(mask * g.bn_mask),这个才是决定channels留下来与否的值 self.register_buffer('bn_mask', torch.ones(1, self.channel_size, 1, 1).to(self.device)) self.extract_from_bn() def set_groupid(self, new_id): self.group_id = new_id def extra_repr(self): return '%d -> %d | ID: %s' % (self.channel_size, int(self.bn_mask.sum()), self.group_id) def extract_from_bn(self): # freeze bn weight with torch.no_grad(): self.bn.bias.set_(torch.clamp(self.bn.bias / self.bn.weight, -10, 10)) self.g.set_(self.g * self.bn.weight.view(1, -1, 1, 1)) self.bn.weight.set_(torch.ones_like(self.bn.weight)) self.bn.weight.requires_grad = False def reset_score(self): self.score.zero_() def cal_score(self, grad): # used for hook self.score += (grad * self.g).abs() def start_collecting_scores(self): if self._hook is not None: self._hook.remove() self._hook = self.g.register_hook(self.cal_score) def stop_collecting_scores(self): if self._hook is not None: self._hook.remove() self._hook = None def get_score(self, eta=0.0): #eta表示什么?,如果为0,其实就是score = self.score * self.bn_mask # use self.bn_mask.sum() to calculate the number of input channel. 
eta should had been normed flops_reg = eta * int(self.area[0]) * self.bn_mask.sum() return ((self.score - flops_reg) * self.bn_mask).view(-1) def forward(self, x): # train时就会调用这个函数,输出与self.g相乘 x = self.bn(x) * self.g self.area[0] = x.shape[-1] * x.shape[-2] if self.bn_mask is not None: return x * self.bn_mask #只留下bn_mask中值不为0的channels对应的x的值 return x def melt(self): #训练完了,恢复参数的函数 with torch.no_grad(): mask = self.bn_mask.view(-1) replacer = nn.BatchNorm2d(int(self.bn_mask.sum())).to(self.bn.weight.device) replacer.running_var.set_(self.bn.running_var[mask != 0]) replacer.running_mean.set_(self.bn.running_mean[mask != 0]) replacer.weight.set_((self.bn.weight * self.g.view(-1))[mask != 0]) replacer.bias.set_((self.bn.bias * self.g.view(-1))[mask != 0]) return replacer @classmethod def transform(cls, net, minimal_ratio=0.1): r = [] def _inject(modules): keys = modules.keys() for k in keys: if len(modules[k]._modules) > 0: _inject(modules[k]._modules) if isinstance(modules[k], nn.BatchNorm2d): modules[k] = GatedBatchNorm2d(modules[k], minimal_ratio) #将原来的BN层改成GBN层 r.append(modules[k]) _inject(net._modules) return r class FinalLinearObserver(Meltable): ''' assert was in the last layer. 
only input was masked ''' def __init__(self, linear): super(FinalLinearObserver, self).__init__() assert isinstance(linear, nn.Linear) self.linear = linear self.in_mask = torch.zeros(linear.weight.shape[1]).to('cpu') self.f_hook = linear.register_forward_hook(self._forward_hook) def extra_repr(self): return '(%d, %d) -> (%d, %d)' % ( int(self.linear.weight.shape[1]), int(self.linear.weight.shape[0]), int((self.in_mask != 0).sum()), int(self.linear.weight.shape[0])) def _forward_hook(self, m, _in, _out): x = _in[0] self.in_mask += x.data.abs().cpu().sum(0, keepdim=True).view(-1) def forward(self, x): return self.linear(x) def melt(self): with torch.no_grad(): replacer = nn.Linear(int((self.in_mask != 0).sum()), self.linear.weight.shape[0]).to(self.linear.weight.device) replacer.weight.set_(self.linear.weight[:, self.in_mask != 0]) replacer.bias.set_(self.linear.bias) return replacer class Conv2dObserver(Meltable): def __init__(self, conv): super(Conv2dObserver, self).__init__() assert isinstance(conv, nn.Conv2d) self.conv = conv self.in_mask = torch.zeros(conv.in_channels).to('cpu') self.out_mask = torch.zeros(conv.out_channels).to('cpu') self.f_hook = conv.register_forward_hook(self._forward_hook) def extra_repr(self): return '(%d, %d) -> (%d, %d)' % (self.conv.in_channels, self.conv.out_channels, int((self.in_mask != 0).sum()), int((self.out_mask != 0).sum())) def _forward_hook(self, m, _in, _out): x = _in[0] self.in_mask += x.data.abs().sum(2, keepdim=True).sum(3, keepdim=True).cpu().sum(0, keepdim=True).view(-1) def _backward_hook(self, grad): self.out_mask += grad.data.abs().sum(2, keepdim=True).sum(3, keepdim=True).cpu().sum(0, keepdim=True).view(-1) new_grad = torch.ones_like(grad) return new_grad def forward(self, x): output = self.conv(x) noise = torch.zeros_like(output).normal_() output = output + noise if self.training: output.register_hook(self._backward_hook) return output def melt(self): if self.conv.groups == 1: groups = 1 elif self.conv.groups == 
self.conv.out_channels: groups = int((self.out_mask != 0).sum()) else: assert False replacer = nn.Conv2d( in_channels = int((self.in_mask != 0).sum()), out_channels = int((self.out_mask != 0).sum()), kernel_size = self.conv.kernel_size, stride = self.conv.stride, padding = self.conv.padding, dilation = self.conv.dilation, groups = groups, bias = (self.conv.bias is not None) ).to(self.conv.weight.device) with torch.no_grad(): if self.conv.groups == 1: replacer.weight.set_(self.conv.weight[self.out_mask != 0][:, self.in_mask != 0]) else: replacer.weight.set_(self.conv.weight[self.out_mask != 0]) if self.conv.bias is not None: replacer.bias.set_(self.conv.bias[self.out_mask != 0]) return replacer @classmethod def transform(cls, net): r = [] def _inject(modules): keys = modules.keys() for k in keys: if len(modules[k]._modules) > 0: _inject(modules[k]._modules) if isinstance(modules[k], nn.Conv2d): modules[k] = Conv2dObserver(modules[k]) r.append(modules[k]) _inject(net._modules) return r # ------------------------------------------------------------------------------------------------------- def get_gate_sparse_loss(masks, sparse_lambda): def _loss_hook(data, label, logits): loss = 0.0 for gbn in masks: if isinstance(gbn, GatedBatchNorm2d): loss += gbn.g.abs().sum() return sparse_lambda * loss return _loss_hook class IterRecoverFramework(): def __init__(self, pack, masks, sparse_lambda=1e-5, flops_eta=0.0, minium_filter=10): self.pack = pack self.masks = masks self.sparse_loss_hook = get_gate_sparse_loss(masks, sparse_lambda) #计算tock的损失的后半部分 self.logs = [] # minium_filter would be delete # self.minium_filter = minium_filter self.sparse_lambda = sparse_lambda self.flops_eta = flops_eta self.eta_scale_factor = 1.0 self.total_filters = sum([m.bn.weight.shape[0] for m in masks]) self.pruned_filters = 0 def recover(self, lr, test): for gbn in self.masks: if isinstance(gbn, GatedBatchNorm2d): gbn.reset_score() gbn.start_collecting_scores() for g in 
self.pack.optimizer.param_groups: g['lr'] = lr tmp = self.pack.train_loader self.pack.train_loader = self.pack.tick_trainset #使用训练集的子集 info = self.pack.trainer.train(self.pack) #执行Tick,只更新gate φ和最后的线性层参数 self.pack.train_loader = tmp if test: info.update(self.pack.trainer.test(self.pack)) info.update({'LR': lr}) for gbn in self.masks: if isinstance(gbn, GatedBatchNorm2d): gbn.stop_collecting_scores() return info def get_threshold(self, status, num): ''' input score list from layers, and the number of filter to prune ''' total_filters, left_filters = 0, 0 filtered_score_list = [] for group_id, v in status.items(): total_filters += len(v['score']) * v['count'] #count>1,说明分group了,对应channels有着相同的分数 left_filters += int((v['score'] != 0).sum()) * v['count'] sorted_score = np.sort(v['score'])[:-v['minimal']] #-v['minimal']之后的分数是不要的,按比例抛弃 filtered_score = sorted_score[sorted_score != 0] for i in range(v['count']): filtered_score_list.append(filtered_score)#因为相同组的channels分数是相同的,所以append v['count']次 scores = np.concatenate(filtered_score_list) #将所有GBN中的channels的分数串联在一起 threshold = np.sort(scores)[num] #然后再排序,取num索引的值作为阈值 to_prune = int((scores <= threshold).sum()) #分数小于这个阈值的channels也prune info = {'left': left_filters, 'to_prune': to_prune, 'total_pruned_ratio': (total_filters - left_filters + to_prune) / total_filters} return threshold, info #这里的作用就是根据训练得到的分数和阈值计算出mask,用于更改之前的g.bn_mask,这个才是决定channels的值石佛耦留下来的值 # 因为在BGN的forward中输出的x为x * self.bn_mask def set_mask(self, status, threshold): #这里的self.masks是GBNs for group_id, v in status.items(): hard_threshold = float(np.sort(v['score'])[-v['minimal']]) #根据按比例得到的v['minimal'] = max(1, int(self.channel_size * minimal_ratio)),得到该位置的分数,说明小于这个分数的channels是一定要prune的 hard_mask = v['score'] >= hard_threshold #留下的channels更多 soft_mask = v['score'] > threshold #留下的channels少 v['mask'] = (hard_mask + soft_mask) with torch.no_grad(): for g in self.masks: if g.group_id in status: mask = 
torch.from_numpy(status[g.group_id]['mask'].astype('float32')).to(g.device).view(1, -1, 1, 1) g.bn_mask.set_(mask * g.bn_mask) def freeze_conv(self): #Tick训练时不更新conv的参数,所以freeze它们 self._status = {} for m in self.pack.net.modules(): if isinstance(m, nn.Conv2d): for p in m.parameters(): self._status[id(p)] = p.requires_grad p.requires_grad = False def restore_conv(self): for m in self.pack.net.modules(): if isinstance(m, nn.Conv2d): for p in m.parameters(): p.requires_grad = self._status[id(p)] def tock(self, lr_min=0.001, lr_max=0.01, tock_epoch = 20, mute=False, acc_step=1): #损失有一个额外的sparse loss,所以loss_hook有值,训练改变所有的参数,只是没有计算score,参数g的作用就是用于计算score logs = [] epoch = 0 T = tock_epoch def iter_hook(curr_iter, total_iter): total = T * total_iter half = total / 2 itered = epoch * total_iter + curr_iter if itered < half: _iter = epoch * total_iter + curr_iter _lr = (1- _iter / half) * lr_min + (_iter / half) * lr_max else: _iter = (epoch - T/2) * total_iter + curr_iter _lr = (1- _iter / half) * lr_max + (_iter / half) * lr_min for g in self.pack.optimizer.param_groups: g['lr'] = max(_lr, 0) # g['momentum'] = 0.9 for i in range(T):# 迭代T次 info = self.pack.trainer.train(self.pack, loss_hook = self.sparse_loss_hook, iter_hook = iter_hook, acc_step=acc_step) info.update(self.pack.trainer.test(self.pack)) info.update({'LR': self.pack.optimizer.param_groups[0]['lr']}) epoch += 1 if not mute: # print('Tock - %d, Test Loss: %.4f, Test Acc: %.2f, Final LR: %.5f' % (i, info['test_loss'], info['acc@1'], info['LR'])) print('Tock - %d, Test Loss: %.4f, Test age_correct Acc: %.2f, Test gender_correct Acc: %.2f, Final LR: %.5f' % (i, info['test_loss'], info['age_correct'], info['gender_correct'], info['LR'])) logs.append(info) return logs def tick(self, lr, test): ''' Do Prune ''' self.freeze_conv() info = self.recover(lr, test) self.restore_conv() return info def prune(self, num, tick=False, lr=0.01, test=True): info = {} if tick: info = self.tick(lr, test) area = [] for g in 
self.masks: area.append(int(g.area[0])) self.eta_scale_factor = min(area) status = {} for g in self.masks: if g.group_id in status: # assert the gbn in same group has the same channel size status[g.group_id]['score'] += g.get_score(self.flops_eta / self.eta_scale_factor).cpu().data.numpy() status[g.group_id]['count'] += 1 else: status[g.group_id] = { 'score': g.get_score(self.flops_eta / self.eta_scale_factor).cpu().data.numpy(), 'minimal': g.minimal_filter, 'count': 1, 'mask': None } threshold, r = self.get_threshold(status, num) info.update(r) threshold = float(threshold) self.set_mask(status, threshold) if test: info.update({'after_prune_test_age_acc': self.pack.trainer.test(self.pack)['age_correct']}) info.update({'after_prune_test_gender_acc': self.pack.trainer.test(self.pack)['gender_correct']}) self.logs.append(info) self.pruned_filters = self.total_filters - info['left'] info['total'] = self.total_filters return info
Resnet
感觉 ResNet 和 VGG 的主要差别在于 ResNet 有残差支路(shortcut):参与同一次相加的各个分支,其通道必须一一对应,所以需要把这些 BN 层分成一组,一起剪枝:
1)resnet-56/resnet56_prune.ipynb
一步步向下运行:
GBNs = GatedBatchNorm2d.transform(pack.net)  # wraps every BN in a GBN; GatedBatchNorm2d.__init__ already calls self.extract_from_bn()
print(GBNs)  # extra_repr() prints "channels -> kept channels | ID"; every ID is still distinct here
for gbn in GBNs:  # repeats the extraction on every GBN
    gbn.extract_from_bn()  # the BN weight is already 1 (and frozen), so this second call changes nothing; only g carries the scale
返回:
[GatedBatchNorm2d(
16 -> 16 | ID: 94d38830-2d48-11ea-ba2e-00e04c6841ff
(bn): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
), GatedBatchNorm2d(
16 -> 16 | ID: 94d5b330-2d48-11ea-ba2e-00e04c6841ff
(bn): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
), GatedBatchNorm2d(
16 -> 16 | ID: 94d5bb32-2d48-11ea-ba2e-00e04c6841ff
(bn): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
),
...
import uuid
def bottleneck_set_group(net):
    """Give the GBNs whose outputs meet in the same residual sums one shared group id.

    For each of layer1..layer3 of the DataParallel-wrapped ResNet in `net`,
    the group holds every BasicBlock's bn2 and the BN inside any non-empty
    shortcut. For layer1 it also holds the stem bn1 that feeds it. GBNs with the
    same group id are pruned with a single shared mask. Prints each group.
    """
    layers = [
        net.module.layer1,
        net.module.layer2,
        net.module.layer3
    ]
    for m in layers:
        masks = []
        if m is net.module.layer1:
            # Presumably the stem BN's output is added in layer1's identity shortcuts, so it joins that group.
            masks.append(net.module.bn1)
        for mm in m.modules():
            if mm.__class__.__name__ == 'BasicBlock':
                if len(mm.shortcut._modules) > 0:
                    # A non-empty shortcut is the projection used where the channel count changes;
                    # its BN is entry '1'.
                    masks.append(mm.shortcut._modules['1'])
                # bn2 feeds the residual addition. bn1 presumably only feeds conv2 inside the block,
                # so it keeps its own group.
                masks.append(mm.bn2)

        group_id = uuid.uuid1()
        for mk in masks:
            mk.set_groupid(group_id)  # every GBN collected above gets the same id
        print(masks)
bottleneck_set_group(pack.net)  # put the GBNs of each residual stage into one shared group
返回:
[GatedBatchNorm2d( 16 -> 16 | ID: 0e2ca1f2-2d4a-11ea-ba2e-00e04c6841ff (bn): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ), GatedBatchNorm2d( 16 -> 16 | ID: 0e2ca1f2-2d4a-11ea-ba2e-00e04c6841ff (bn): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ), GatedBatchNorm2d( 16 -> 16 | ID: 0e2ca1f2-2d4a-11ea-ba2e-00e04c6841ff (bn): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ), GatedBatchNorm2d( 16 -> 16 | ID: 0e2ca1f2-2d4a-11ea-ba2e-00e04c6841ff (bn): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ), GatedBatchNorm2d( 16 -> 16 | ID: 0e2ca1f2-2d4a-11ea-ba2e-00e04c6841ff (bn): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ), GatedBatchNorm2d( 16 -> 16 | ID: 0e2ca1f2-2d4a-11ea-ba2e-00e04c6841ff (bn): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ), GatedBatchNorm2d( 16 -> 16 | ID: 0e2ca1f2-2d4a-11ea-ba2e-00e04c6841ff (bn): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ), GatedBatchNorm2d( 16 -> 16 | ID: 0e2ca1f2-2d4a-11ea-ba2e-00e04c6841ff (bn): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ), GatedBatchNorm2d( 16 -> 16 | ID: 0e2ca1f2-2d4a-11ea-ba2e-00e04c6841ff (bn): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ), GatedBatchNorm2d( 16 -> 16 | ID: 0e2ca1f2-2d4a-11ea-ba2e-00e04c6841ff (bn): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) )] [GatedBatchNorm2d( 32 -> 32 | ID: 0e2cc826-2d4a-11ea-ba2e-00e04c6841ff (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ), GatedBatchNorm2d( 32 -> 32 | ID: 0e2cc826-2d4a-11ea-ba2e-00e04c6841ff (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ), GatedBatchNorm2d( 32 -> 32 | ID: 0e2cc826-2d4a-11ea-ba2e-00e04c6841ff (bn): 
BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ), GatedBatchNorm2d( 32 -> 32 | ID: 0e2cc826-2d4a-11ea-ba2e-00e04c6841ff (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ), GatedBatchNorm2d( 32 -> 32 | ID: 0e2cc826-2d4a-11ea-ba2e-00e04c6841ff (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ), GatedBatchNorm2d( 32 -> 32 | ID: 0e2cc826-2d4a-11ea-ba2e-00e04c6841ff (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ), GatedBatchNorm2d( 32 -> 32 | ID: 0e2cc826-2d4a-11ea-ba2e-00e04c6841ff (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ), GatedBatchNorm2d( 32 -> 32 | ID: 0e2cc826-2d4a-11ea-ba2e-00e04c6841ff (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ), GatedBatchNorm2d( 32 -> 32 | ID: 0e2cc826-2d4a-11ea-ba2e-00e04c6841ff (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ), GatedBatchNorm2d( 32 -> 32 | ID: 0e2cc826-2d4a-11ea-ba2e-00e04c6841ff (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) )] [GatedBatchNorm2d( 64 -> 64 | ID: 0e2d0728-2d4a-11ea-ba2e-00e04c6841ff (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ), GatedBatchNorm2d( 64 -> 64 | ID: 0e2d0728-2d4a-11ea-ba2e-00e04c6841ff (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ), GatedBatchNorm2d( 64 -> 64 | ID: 0e2d0728-2d4a-11ea-ba2e-00e04c6841ff (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ), GatedBatchNorm2d( 64 -> 64 | ID: 0e2d0728-2d4a-11ea-ba2e-00e04c6841ff (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ), GatedBatchNorm2d( 64 -> 64 | ID: 0e2d0728-2d4a-11ea-ba2e-00e04c6841ff (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ), 
GatedBatchNorm2d( 64 -> 64 | ID: 0e2d0728-2d4a-11ea-ba2e-00e04c6841ff (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ), GatedBatchNorm2d( 64 -> 64 | ID: 0e2d0728-2d4a-11ea-ba2e-00e04c6841ff (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ), GatedBatchNorm2d( 64 -> 64 | ID: 0e2d0728-2d4a-11ea-ba2e-00e04c6841ff (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ), GatedBatchNorm2d( 64 -> 64 | ID: 0e2d0728-2d4a-11ea-ba2e-00e04c6841ff (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ), GatedBatchNorm2d( 64 -> 64 | ID: 0e2d0728-2d4a-11ea-ba2e-00e04c6841ff (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) )]
可见分成了 3 组,同一组内各 GBN 的 ID 是相同的。
此时的网络结构是:
# Print the top-level children of the network (now containing grouped GBN layers).
for name, child in pack.net.named_children():
    print(name)
    print(child)
通过将同一个group的id设置为同一个id来说明它们是同一个group的,这样同一个组的BN的bn_mask值是相同的,即它们的channels数也是相同的,这样就能连起来了
然后克隆了一个一样的模型来计算模型的FLOPs和参数大小:
def clone_model(net):
    """Return an independent copy of ``net`` together with its GBN modules.

    A fresh model is built with ``get_model()``, its BatchNorm layers are
    wrapped via ``GatedBatchNorm2d.transform``, and finally the weights of
    ``net`` are loaded into it. The copy can therefore be modified (observed,
    melted) without touching ``net``.

    Returns:
        (model, gbns): the copied model and whatever
        ``GatedBatchNorm2d.transform`` returned for it.
    """
    replica = get_model()
    # Transform before loading: presumably ``net`` already contains GBN layers,
    # so the state-dict keys only line up after the same transform.
    replica_gbns = GatedBatchNorm2d.transform(replica.module)
    replica.load_state_dict(net.state_dict())
    return replica, replica_gbns
# Measure the baseline FLOPs / parameter count on a throw-away copy of pack.net;
# the copy is deleted at the end of this snippet.
cloned, _ = clone_model(pack.net)
# Build the dummy CIFAR-sized input on the same device as the model, so the
# same line works with cuda enabled and on a CPU-only PyTorch build (instead of
# hand-toggling a commented-out ``.cuda()`` variant).
BASE_FLOPS, BASE_PARAM = analyse_model(
    cloned.module,
    torch.randn(1, 3, 32, 32, device=next(cloned.parameters()).device),
)
print('%.3f MFLOPS' % (BASE_FLOPS / 1e6))
print('%.3f M' % (BASE_PARAM / 1e6))
for name, child in cloned.named_children():
    print(name)
    print(child)
del cloned
打印出的克隆模型与原模型相比,除了GBN的ID不同外其余完全一致;最后用del cloned删除了该模型,说明它只用于统计FLOPs和参数量,后面不再使用
返回:
127.932 MFLOPS
0.856 M
上面的clone_model(net)操作其实是为了下面的函数eval_prune(pack)服务的:
def eval_prune(pack):
    """Return (flops, params) of ``pack.net`` as it would be after melting.

    ``pack.net`` is cloned, observers are attached to the clone's Conv2d
    layers and to its final ``linear`` layer, and the clone is observed with
    ``Meltable.observe(..., 0.001)``. It is then melted (pruned according to
    the GBN masks) and measured with ``analyse_model``. ``pack.net`` itself is
    left untouched.

    Args:
        pack: the training pack (dotdict). Its entries are shallow-copied into
            a temporary pack whose ``net`` is replaced by the clone.

    Returns:
        (flops, params) of the melted clone, as returned by ``analyse_model``.
    """
    cloned, _ = clone_model(pack.net)
    _ = Conv2dObserver.transform(cloned.module)
    cloned.module.linear = FinalLinearObserver(cloned.module.linear)
    cloned_pack = dotdict(pack.copy())
    cloned_pack.net = cloned
    Meltable.observe(cloned_pack, 0.001)
    Meltable.melt_all(cloned_pack.net)
    # Put the probe input on the model's own device instead of calling
    # .cuda() unconditionally, which raised "Torch not compiled with CUDA
    # enabled" when running with cuda=false.
    device = next(cloned_pack.net.parameters()).device
    flops, params = analyse_model(
        cloned_pack.net.module, torch.randn(1, 3, 32, 32, device=device)
    )
    del cloned
    del cloned_pack
    return flops, params
其实就是计算当前(剪枝后)模型的FLOPs和参数量,与原始模型进行比较,看看此时压缩了多少
接下来就是测试此时的模型,看看效果:
# Evaluate the current model; the result is a dict such as {'test_loss': ..., 'acc@1': ...}.
pack.trainer.test(pack)
返回:
{'test_loss': 0.31250936849207817, 'acc@1': 92.92919303797468}
然后就是tick-tock操作:
import os

pack.tick_trainset = pack.train_loader
prune_agent = IterRecoverFramework(
    pack,
    GBNs,
    sparse_lambda=cfg.gbn.sparse_lambda,
    flops_eta=cfg.gbn.flops_eta,
    minium_filter=3,
)
LOGS = []
# Save a checkpoint the first time FLOPs drop to 40%, 38%, ... of the baseline.
flops_save_points = {40, 38, 35, 32, 30}
CKPT_DIR = './logs/resnet56_cifar10_ticktock'
# torch.save does not create missing directories; without this the first
# checkpoint save fails with FileNotFoundError.
os.makedirs(CKPT_DIR, exist_ok=True)
iter_idx = 0

# Start with one tock phase (tock_epoch epochs) before the first pruning step.
prune_agent.tock(lr_min=cfg.gbn.lr_min, lr_max=cfg.gbn.lr_max, tock_epoch=cfg.gbn.tock_epoch)
while True:
    # Tick: prune a fraction cfg.gbn.p of the filters that are still left.
    left_filter = prune_agent.total_filters - prune_agent.pruned_filters
    num_to_prune = int(left_filter * cfg.gbn.p)
    # NOTE(review): if num_to_prune reaches 0 before FLOPs fall to 30%, this
    # loop may never terminate — confirm against IterRecoverFramework.prune.
    info = prune_agent.prune(num_to_prune, tick=True, lr=cfg.gbn.lr_min)
    flops, params = eval_prune(pack)
    info.update({
        'flops': '[%.2f%%] %.3f MFLOPS' % (flops / BASE_FLOPS * 100, flops / 1e6),
        'param': '[%.2f%%] %.3f M' % (params / BASE_PARAM * 100, params / 1e6),
    })
    LOGS.append(info)
    print('Iter: %d, FLOPS: %s, Param: %s, Left: %d, Pruned Ratio: %.2f %%, Train Loss: %.4f, Test Acc: %.2f' % (
        iter_idx, info['flops'], info['param'], info['left'],
        info['total_pruned_ratio'] * 100, info['train_loss'], info['after_prune_test_acc']))

    iter_idx += 1
    # After every cfg.gbn.T ticks, run a tock phase of cfg.gbn.tock_epoch epochs.
    if iter_idx % cfg.gbn.T == 0:
        print('Tocking:')
        prune_agent.tock(lr_min=cfg.gbn.lr_min, lr_max=cfg.gbn.lr_max, tock_epoch=cfg.gbn.tock_epoch)

    # FLOPs of the current model as a percentage of the unpruned baseline.
    flops_ratio = flops / BASE_FLOPS * 100
    # Iterate over a snapshot because reached points are removed inside the loop.
    for point in list(flops_save_points):
        if flops_ratio <= point:
            torch.save(pack.net.module.state_dict(), '%s/%s.ckp' % (CKPT_DIR, point))
            flops_save_points.remove(point)

    # Stop once every save point (down to 30%) has been reached. This check
    # must sit at while-level: a break inside the for loop would not end pruning.
    if not flops_save_points:
        break
finetune操作见resnet-56/finetune.ipynb
核心就是先根据bn_mask的值对模型进行剪枝:不仅要剪BN层,还要剪Conv2d层和FinalLinear层,即:
# Wrap the BatchNorm layers of pack.net as GatedBatchNorm2d modules.
GBNs = GatedBatchNorm2d.transform(pack.net)
for gbn in GBNs:
    # Presumably copies the trained BN state into the gate — verify in GatedBatchNorm2d.
    gbn.extract_from_bn()
# Attach observers to the Conv2d layers and to the final linear layer so they
# take part in melting.
_ = Conv2dObserver.transform(pack.net.module)
pack.net.module.linear = FinalLinearObserver(pack.net.module.linear)
Meltable.observe(pack, 0.001)
Meltable.melt_all(pack.net)  # prune
然后进行finetune:
# Fine-tune the pruned network for T=cfg.gbn.finetune_epoch, with the learning rate bounded by lr_min/lr_max.
_ = finetune(pack, lr_min=cfg.gbn.lr_min, lr_max=cfg.gbn.lr_max, T=cfg.gbn.finetune_epoch)
这里有一个特别的写法,记录一下
保证路径为主路径的方法:python保证路径为主路径的方法