• Gate Decorator: Global Filter Pruning Method for Accelerating Deep Convolutional Neural Networks


    VGG16

    run/vgg16/vgg16_prune_demo.py运行:

     python ./run/vgg16/vgg16_prune_demo.py --config ./run/vgg16/prune.json

    报错:

    Traceback (most recent call last):
      File "./run/vgg16/vgg16_prune_demo.py", line 16, in <module>
        from logger import logger
      File "/Users/user/pytorch/gate-decorator-pruning/logger.py", line 67, in <module>
        logger = Logger()
      File "/Users/user/pytorch/gate-decorator-pruning/logger.py", line 42, in __init__
        json.dump(cfg, fp)
      File "/anaconda3/envs/deeplearning/lib/python3.7/json/__init__.py", line 179, in dump
        for chunk in iterable:
      File "/anaconda3/envs/deeplearning/lib/python3.7/json/encoder.py", line 438, in _iterencode
        o = _default(o)
      File "/anaconda3/envs/deeplearning/lib/python3.7/json/encoder.py", line 179, in default
        raise TypeError(f'Object of type {o.__class__.__name__} '
    TypeError: Object of type Config is not JSON serializable

    原因是 json 无法序列化自定义类型的对象:cfg 中含有自定义的 Config 对象(配置使用了自定义的 dotdict 类型),json 默认的编码器不认识它

    解决办法:

    将logger.py中的json.dump()改为:

                with open(self.cfgfile, 'w') as fp:
                    json.dump(cfg, fp, cls=dotdict)

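    即显式指定用 dotdict 来做序列化。注意:json.dump 的 cls 参数要求是 json.JSONEncoder 的子类,而从下文 dotdict(pack.copy()) 的用法看,dotdict 是 dict 的子类,这种写法在内部调用 iterencode 时很可能会出错。更稳妥的做法是传入 default 参数,把无法直接序列化的对象转成普通 dict,例如 json.dump(cfg, fp, default=lambda o: dict(o))(前提是 Config 对象可以用 dict() 转换)。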

    再出错:

    AssertionError: Torch not compiled with CUDA enabled

    将prune.json中的cuda:true改为false

    报错:

    FileNotFoundError: [Errno 2] No such file or directory: './logs/vgg16_cifar10/ckp.160.torch'

    这是因为我没有按照顺序运行,没有先运行:

    CUDA_VISIBLE_DEVICES=0 python main.py --config ./run/vgg16/baseline.json

    该命令会生成一个ckp.160.torch文件

    所以我使用pytorch给的预训练文件,将vgg16_prune_demo.py中的:

    def get_pack():
        set_seeds()
        pack = recover_pack()
    
        model_dict = torch.load('./logs/vgg16_cifar10/ckp.160.torch', map_location='cpu' if not cfg.base.cuda else 'cuda')
        pack.net.module.load_state_dict(model_dict)

    改成:

    def get_pack():
        set_seeds()
        pack = recover_pack()
        pack.net.load_state_dict(torch.utils.model_zoo.load_url('https://download.pytorch.org/models/vgg16-397923af.pth'), strict=False)

    然后查看此时的网络结构:

    pack, GBNs = get_pack()
    for name, child in pack.net.named_children():
        print(name)
        print(child)
    
    print(GBNs)

    后面运行出错:

    Traceback (most recent call last):
      File "./run/vgg16/vgg16_prune_demo.py", line 137, in <module>
        run()
      File "./run/vgg16/vgg16_prune_demo.py", line 112, in run
        pack, GBNs = get_pack()
      File "./run/vgg16/vgg16_prune_demo.py", line 29, in get_pack
        pack.net.load_state_dict(torch.utils.model_zoo.load_url('https://download.pytorch.org/models/vgg16-397923af.pth'), strict=False)
      File "/anaconda3/envs/deeplearning/lib/python3.7/site-packages/torch/nn/modules/module.py", line 845, in load_state_dict
        self.__class__.__name__, "\n\t".join(error_msgs)))
    RuntimeError: Error(s) in loading state_dict for VGG:
        size mismatch for features.7.weight: copying a param with shape torch.Size([128, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 64, 3, 3]).

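    这是因为仓库里的 VGG 带有 BN 层(对应 torchvision 的 vgg16_bn),而 vgg16 的预训练权重不含 BN 层,features 中各层的序号因此错位,导致形状不匹配(strict=False 只会忽略缺失或多余的 key,不会忽略形状不匹配)。所以应改用 vgg16_bn 的预训练模型 https://download.pytorch.org/models/vgg16_bn-6c64b313.pth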

    又报错:

    Traceback (most recent call last):
      File "./run/vgg16/vgg16_prune_demo.py", line 137, in <module>
        run()
      File "./run/vgg16/vgg16_prune_demo.py", line 114, in run
        cloned, _ = clone_model(pack.net)
      File "./run/vgg16/vgg16_prune_demo.py", line 54, in clone_model
        gbns = GatedBatchNorm2d.transform(model.module)
      File "/anaconda3/envs/deeplearning/lib/python3.7/site-packages/torch/nn/modules/module.py", line 591, in __getattr__
        type(self).__name__, name))
    AttributeError: 'VGG' object has no attribute 'module'

    把 model.module 改成 model 即可,因为我没有用 DataParallel 包装模型(multi_gpus 设为 False,下面这段代码不会执行):

        if cfg.base.multi_gpus: #设置了multi_gpus为False
            model = torch.nn.DataParallel(model)

    仅仅根据代码说说原理

    感觉看了所有的代码后其工作原理是这样的,拿vgg16_prune_demo.py的prune()函数举例子:

    prune_agent = IterRecoverFramework(pack, GBNs, sparse_lambda = cfg.gbn.sparse_lambda, flops_eta = cfg.gbn.flops_eta, minium_filter = 3)

    1)准备好了Tick-Tock

    # 先所有数据迭代cfg.gbn.tock_epoch次
        prune_agent.tock(lr_min=cfg.gbn.lr_min, lr_max=cfg.gbn.lr_max, tock_epoch=cfg.gbn.tock_epoch)

    这其实就相当于在原有模型上微调 cfg.gbn.tock_epoch 个 epoch,只是损失中额外加上了对所有 GBN 门控参数 g 的 L1 稀疏正则项(sparse_lambda * Σ|g|,见 get_gate_sparse_loss)

    2)然后就循环进行Tick操作:

    def prune(pack, GBNs, BASE_FLOPS, BASE_PARAM):
        LOGS = []
        flops_save_points = set([30, 20, 10])
        iter_idx = 0
    
        pack.tick_trainset = pack.train_loader
        prune_agent = IterRecoverFramework(pack, GBNs, sparse_lambda = cfg.gbn.sparse_lambda, flops_eta = cfg.gbn.flops_eta, minium_filter = 3)
        # 先所有数据迭代cfg.gbn.tock_epoch次
        prune_agent.tock(lr_min=cfg.gbn.lr_min, lr_max=cfg.gbn.lr_max, tock_epoch=cfg.gbn.tock_epoch)
        while True:
            left_filter = prune_agent.total_filters - prune_agent.pruned_filters
            num_to_prune = int(left_filter * cfg.gbn.p) # 用来确定阈值
            info = prune_agent.prune(num_to_prune, tick=True, lr=cfg.gbn.lr_min) #tick一次并计算分数
            flops, params = eval_prune(pack)
            info.update({ #查看这次剪枝后的结果
                'flops': '[%.2f%%] %.3f MFLOPS' % (flops/BASE_FLOPS * 100, flops / 1e6),
                'param': '[%.2f%%] %.3f M' % (params/BASE_PARAM * 100, params / 1e6)
            })
            LOGS.append(info)
            print('Iter: %d,	 FLOPS: %s,	 Param: %s,	 Left: %d,	 Pruned Ratio: %.2f %%,	 Train Loss: %.4f,	 Test Acc: %.2f' % 
                (iter_idx, info['flops'], info['param'], info['left'], info['total_pruned_ratio'] * 100, info['train_loss'], info['after_prune_test_acc']))
            
            iter_idx += 1
            if iter_idx % cfg.gbn.T == 0: #T=10,即10次Tick后来tock_epoch=10次Tock
                print('Tocking:')
                prune_agent.tock(lr_min=cfg.gbn.lr_min, lr_max=cfg.gbn.lr_max, tock_epoch=cfg.gbn.tock_epoch)
    
            flops_ratio = flops/BASE_FLOPS * 100 #减少到原来的多少
            for point in [i for i in list(flops_save_points)]:
                if flops_ratio <= point:#比如现在flops_ratio小于30%但是大于20%,就会存下现在的状态,并删掉对应的30 point
                    torch.save(pack.net.module.state_dict(), './logs/vgg16_cifar10/gbn_%s.ckp' % str(point))
                    flops_save_points.remove(point)
    
            if len(flops_save_points) == 0:#当为0的时候,该Tick-Tock就结束了
                break

    Tick操作就是在计算分数,决定剪去BN层的哪些channels

    3)开始进行Tick-Tock前的网络结构就是将BN层换成了GBN层:

    def get_pack():
        set_seeds()
        pack = recover_pack()
    
        #model_dict = torch.load('./logs/vgg16_cifar10/ckp.160.torch', map_location='cpu' if not cfg.base.cuda else 'cuda')
        pack.net.load_state_dict(torch.utils.model_zoo.load_url('https://download.pytorch.org/models/vgg16_bn-6c64b313.pth'), strict=False)
        #pack.net.module.load_state_dict(model_dict)
    
        
        GBNs = GatedBatchNorm2d.transform(pack.net) #这样操作之后BN层就变成了GBN层了,同时freeze该bn层的weight,不训练
        for gbn in GBNs:
            gbn.extract_from_bn()
            
    #     for name, child in pack.net.named_children():
    #         print(name)
    #         print(child)
            
        pack.optimizer = optim.SGD(
            pack.net.parameters() ,
            lr=2e-3,
            momentum=cfg.train.momentum,
            weight_decay=cfg.train.weight_decay,
            nesterov=cfg.train.nesterov
        )
    
        return pack, GBNs

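    GatedBatchNorm2d.transform(pack.net) 会把每个 BN 层包装成 GBN 层:构造 GBN 时会新建参数 g,并调用一次 extract_from_bn()。extract_from_bn() 先把 bias 改为 clamp(bias / weight, -10, 10),再把 bn 的 weight 吸收进 g(g = g * weight),最后把 weight 置为 1 并冻结(requires_grad = False)。这样 BN 原来的缩放作用完全由 g 承担,训练时 weight 不再更新(bias 等其他参数仍会训练)。由于构造时已经调用过一次,get_pack() 中再次调用 extract_from_bn() 不会改变任何数值: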

        def extract_from_bn(self):
            # freeze bn weight
            with torch.no_grad():
                self.bn.bias.set_(torch.clamp(self.bn.bias / self.bn.weight, -10, 10))
                self.g.set_(self.g * self.bn.weight.view(1, -1, 1, 1))
                self.bn.weight.set_(torch.ones_like(self.bn.weight))
                self.bn.weight.requires_grad = False

    如论文中:

    在这个基础上进行Tock操作其实就是在bn层加入g参数,并freeze weight参数的基础上使用整个训练数据集训练模型

    4)然后进行prune操作:

    info = prune_agent.prune(num_to_prune, tick=True, lr=cfg.gbn.lr_min) #tick一次并计算分数

    其实就是进行Tick操作+prune操作

    首先Tick操作是:

        def tick(self, lr, test):
            ''' Do Prune '''
            self.freeze_conv()
            info = self.recover(lr, test)
            self.restore_conv()
            return info

    会 freeze 住所有卷积层的参数,所以 Tick 训练时卷积层不更新,更新的是 GBN 层的参数(门控 g 以及 BN 的 bias;BN 的 weight 已在 extract_from_bn 中冻结)和全连接层的参数

    接下来的就是剪枝prune操作:

    然后接下来就是根据这个Tick训练的g计算每个bn层中filter的分数,一开始bn_mask(查看prune/universal.py文件中的类GatedBatchNorm2d定义)这个值全是1,即表示所有的filter都要,这样子self.score*self.bn_mask就能得到所有的filter的分数,然后再根据分数进行排序等操作来计算阈值分数值threshold,然后再根据阈值等信息得到一个self.mask的值,用这个值去更新self.bn_mask = mask * g.bn_mask,这样每个GBN层中的bn_mask值中为0就表示对应的filter是被删除的,1则表示该对应的filter留下

    所以剪枝操作其实就是根据bn_mask的结果去剪枝,因为GatedBatchNorm2d类的forward操作中有:

        def forward(self, x): 
            x = self.bn(x) * self.g
    
            self.area[0] = x.shape[-1] * x.shape[-2]
    
            if self.bn_mask is not None:
                return x * self.bn_mask
            return x

    因此前向传播经过 GBN 层得到的结果就是 x * self.bn_mask:bn_mask 为 0 的 channel 输出全为 0,相当于剪掉了对应的 filter

    5)接下来就是根据上面的剪枝结果去对应地将卷积层和全连接层中的channels数和GBN层对应起来:

        _ = Conv2dObserver.transform(pack.net.module)
        pack.net.module.classifier = FinalLinearObserver(pack.net.module.classifier)

    主要就是将它们分别封装成Conv2dObserver和FinalLinearObserver

    Conv2dObserver 中有 in_mask 和 out_mask 两个量:前向 hook 作用于卷积的输入,反向 hook 作用于输出的梯度,二者都在 batch 和空间维度上累加绝对值,得到每个 channel 的累计值;最后仍为 0 的 channel 说明已经被剪掉了:

        def _forward_hook(self, m, _in, _out):
            x = _in[0]
            self.in_mask += x.data.abs().sum(2, keepdim=True).sum(3, keepdim=True).cpu().sum(0, keepdim=True).view(-1)
    
        def _backward_hook(self, grad):
            self.out_mask += grad.data.abs().sum(2, keepdim=True).sum(3, keepdim=True).cpu().sum(0, keepdim=True).view(-1)
            new_grad = torch.ones_like(grad)
            return new_grad
    
        def forward(self, x):
            output = self.conv(x)
            noise = torch.zeros_like(output).normal_()
            output = output + noise
            if self.training:
                output.register_hook(self._backward_hook)
            return output

    FinalLinearObserver也是同样的概念

    6)然后就是observe和melt_all操作:

        Meltable.observe(pack, 0.001)
        Meltable.melt_all(pack.net)

    observe 会先把网络中所有 nn.BatchNorm2d(包括 GBN 内部的 bn)的 weight 取绝对值再加上一个极小值(1e-3),并把 ReLU 替换成 LeakyReLU。然后在每次迭代时把 BN 层设为 eval 模式(不更新 running 统计量),以 update=False 跑 OBSERVE_TIMES 次迭代,结束后再恢复原状(这里一直不太明白目的是啥)

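    突然明白这里是干嘛了:这里其实就是跑几次训练迭代,用来计算 Conv2dObserver 和 FinalLinearObserver 中 in_mask 和 out_mask 的结果,然后用于 melt_all。之所以要把 BN 的 weight 变成正数、把 ReLU 换成 LeakyReLU,应该是为了保证没有被剪掉的 channel 的激活和梯度不会恰好为 0,从而不会被误判为已剪掉的 channel。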

    melt_all其实就是将所有的GBN、Conv2dObserver和FinalLinearObserver根据得到的in_mask和out_mask以及GBN中的self.bn_mask来恢复网络,删去不要的filter,只将对应的filter的参数赋值到新的网络结构中,调用的是这几个类中的melt()函数

    7)最后再使用这个新的网络结构进行微调:

        _ = finetune(pack, lr_min=cfg.gbn.lr_min, lr_max=cfg.gbn.lr_max, T=cfg.gbn.finetune_epoch)

    要自己将微调后的模型保存下来

    1》仅保存模型:

    torch.save(pack.net.module.state_dict(), os.path.join(saving_path, '30_finetune_state.pth'))

    用.module是因为使用了:

    model = torch.nn.DataParallel(model)

    如果没有使用可以删掉

    2》保存模型和网络结构:

    torch.save(pack.net.module, os.path.join(saving_path, '30_finetune.pth'))

    整个代码是:

    import os
    import sys
    
    # Locate the repository root ("gate-decorator-pruning") from the current working
    # directory, make it the cwd and append it to sys.path so that the project modules
    # imported below (config, logger, main, models, utils, prune) can be found.
    # NOTE(review): raises ValueError if the cwd is not inside that repository, and
    # splitting on '/' assumes POSIX-style paths.
    _r = os.getcwd().split('/')
    _p = '/'.join(_r[:_r.index('gate-decorator-pruning')+1])
    print('Change dir from %s to %s' % (os.getcwd(), _p))
    os.chdir(_p)
    sys.path.append(_p)
    
    import torch
    import torch.nn as nn
    import numpy as np
    import torch.optim as optim
    
    from config import cfg
    from logger import logger
    from main import set_seeds, recover_pack, adjust_learning_rate, _step_lr, _sgdr
    from models import get_model
    from utils import dotdict
    
    from prune.universal import Meltable, GatedBatchNorm2d, Conv2dObserver, IterRecoverFramework, FinalLinearObserver
    from prune.utils import analyse_model, finetune
    
    def get_pack():
        """Build the training pack with pretrained weights and gated BN layers.

        The ImageNet ``vgg16_bn`` checkpoint is loaded with ``strict=False``: keys
        missing from / unexpected by the model are ignored, but shape mismatches
        still raise. Every ``nn.BatchNorm2d`` is then wrapped in a
        ``GatedBatchNorm2d``. A new SGD optimizer is created so that it also covers
        the freshly added gate parameters.

        Returns:
            (pack, GBNs): the pack and the list of injected GatedBatchNorm2d layers.
        """
        set_seeds()
        pack = recover_pack()

        pretrained_url = 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth'
        pretrained_state = torch.utils.model_zoo.load_url(pretrained_url)
        pack.net.load_state_dict(pretrained_state, strict=False)

        GBNs = GatedBatchNorm2d.transform(pack.net)
        for gated in GBNs:
            # Already applied once by the GatedBatchNorm2d constructor; with the BN
            # weight already set to 1 this second call leaves every value unchanged.
            gated.extract_from_bn()

        train_cfg = cfg.train
        pack.optimizer = optim.SGD(
            pack.net.parameters(),
            lr=2e-3,
            momentum=train_cfg.momentum,
            weight_decay=train_cfg.weight_decay,
            nesterov=train_cfg.nesterov,
        )
        return pack, GBNs
    # get_pack()
    
    def clone_model(net):
        """Return an independent copy of ``net`` with gated BN layers, plus its GBNs.

        A fresh model is built with ``get_model()``, its BN layers are wrapped in
        ``GatedBatchNorm2d`` and ``net``'s state_dict is loaded into it. The load is
        strict, so ``net`` must already have the same gated structure.
        """
        replica = get_model()
        replica_gbns = GatedBatchNorm2d.transform(replica)
        replica.load_state_dict(net.state_dict())
        return replica, replica_gbns
    
    
    def eval_prune(pack):
        """Estimate FLOPs and parameter count of the network after melting.

        Works on a clone of ``pack.net``:
        - the clone's convolutions and final classifier are wrapped in observers;
        - ``Meltable.observe`` runs a few iterations so the observers record which
          channels are still alive;
        - ``Meltable.melt_all`` rebuilds the clone with dead channels physically
          removed.
        The melted clone is analysed and then discarded.

        The model may or may not be wrapped in ``torch.nn.DataParallel``. The old
        version always dereferenced ``.module`` and raised AttributeError on a bare
        model. The dummy input is now created on the model's own device instead of
        being hard-coded to the CPU.

        Returns:
            (flops, params) as reported by ``analyse_model``.
        """
        cloned, _ = clone_model(pack.net)
        # Unwrap DataParallel if present; a bare model has no ``.module``.
        inner = getattr(cloned, 'module', cloned)
        _ = Conv2dObserver.transform(inner)  # wrap every conv in an observer
        inner.classifier = FinalLinearObserver(inner.classifier)  # and the final linear layer
        cloned_pack = dotdict(pack.copy())
        cloned_pack.net = cloned
        Meltable.observe(cloned_pack, 0.001)
        # Replaces the observers / GBNs inside ``inner`` in place.
        Meltable.melt_all(cloned_pack.net)
        device = next(inner.parameters()).device
        flops, params = analyse_model(inner, torch.randn(1, 3, 32, 32, device=device))
        del cloned
        del cloned_pack

        return flops, params
    
    
    def prune(pack, GBNs, BASE_FLOPS, BASE_PARAM):
        """Run the Tick-Tock pruning loop until every FLOPs save point is reached.

        One Tock (training of the whole network with an L1 penalty on the gates) is
        run first. Each loop iteration then does the following:
        - asks the agent to prune ``int(remaining_filters * cfg.gbn.p)`` filters
          with a Tick;
        - measures the melted network's cost with ``eval_prune``;
        - logs and prints the result.
        Every ``cfg.gbn.T`` iterations another Tock is run.

        When the FLOPs ratio (in % of ``BASE_FLOPS``) drops to or below one of the
        save points 30 / 20 / 10, the current (still gated, un-melted) weights are
        written to ``./logs/vgg16_cifar10/gbn_<point>.ckp``. The loop ends once all
        save points have been written. Works with or without DataParallel; the old
        version always used ``pack.net.module`` and crashed on a bare model.

        Returns:
            the list of per-iteration info dicts (previously built but discarded).
        """
        LOGS = []
        flops_save_points = {30, 20, 10}
        iter_idx = 0

        pack.tick_trainset = pack.train_loader
        prune_agent = IterRecoverFramework(pack, GBNs, sparse_lambda=cfg.gbn.sparse_lambda, flops_eta=cfg.gbn.flops_eta, minium_filter=3)
        # Initial Tock: train all parameters for cfg.gbn.tock_epoch epochs.
        prune_agent.tock(lr_min=cfg.gbn.lr_min, lr_max=cfg.gbn.lr_max, tock_epoch=cfg.gbn.tock_epoch)
        while True:
            left_filter = prune_agent.total_filters - prune_agent.pruned_filters
            num_to_prune = int(left_filter * cfg.gbn.p)  # used to pick the score threshold
            info = prune_agent.prune(num_to_prune, tick=True, lr=cfg.gbn.lr_min)  # one Tick + scoring + masking
            flops, params = eval_prune(pack)
            info.update({
                'flops': '[%.2f%%] %.3f MFLOPS' % (flops / BASE_FLOPS * 100, flops / 1e6),
                'param': '[%.2f%%] %.3f M' % (params / BASE_PARAM * 100, params / 1e6),
            })
            LOGS.append(info)
            print('Iter: %d,\t FLOPS: %s,\t Param: %s,\t Left: %d,\t Pruned Ratio: %.2f %%,\t Train Loss: %.4f,\t Test Acc: %.2f' %
                  (iter_idx, info['flops'], info['param'], info['left'], info['total_pruned_ratio'] * 100, info['train_loss'], info['after_prune_test_acc']))

            iter_idx += 1
            if iter_idx % cfg.gbn.T == 0:  # a Tock after every cfg.gbn.T Ticks
                print('Tocking:')
                prune_agent.tock(lr_min=cfg.gbn.lr_min, lr_max=cfg.gbn.lr_max, tock_epoch=cfg.gbn.tock_epoch)

            flops_ratio = flops / BASE_FLOPS * 100  # remaining FLOPs in % of the baseline
            # Collect first, then remove: the set must not shrink while being iterated.
            reached = [point for point in sorted(flops_save_points) if flops_ratio <= point]
            if reached:
                net = getattr(pack.net, 'module', pack.net)  # unwrap DataParallel if present
                os.makedirs('./logs/vgg16_cifar10', exist_ok=True)
                for point in reached:
                    torch.save(net.state_dict(), './logs/vgg16_cifar10/gbn_%s.ckp' % str(point))
                    flops_save_points.discard(point)

            if not flops_save_points:  # every save point written: Tick-Tock is done
                break

        return LOGS
    
    
    def run():
        """Prune VGG16 with Gate Decorator, melt the result and fine-tune it.

        Steps:
        1. Build the gated pack and measure the baseline cost on a clone.
        2. Run the Tick-Tock loop (``prune``).
        3. Wrap the convolutions and classifier in observers.
        4. Observe and melt the network into a physically smaller one.
        5. Fine-tune the melted network.

        Works with or without DataParallel; the old version always used
        ``pack.net.module`` and crashed on a bare model. The baseline is measured
        on the unwrapped model, the same way ``eval_prune`` measures each
        iteration, and the dummy input is placed on the model's device.

        Returns:
            the pack holding the melted, fine-tuned network, so the caller can save
            ``pack.net`` (previously nothing was returned).
        """
        pack, GBNs = get_pack()

        cloned, _ = clone_model(pack.net)
        cloned_inner = getattr(cloned, 'module', cloned)
        device = next(cloned_inner.parameters()).device
        BASE_FLOPS, BASE_PARAM = analyse_model(cloned_inner, torch.randn(1, 3, 32, 32, device=device))
        print('%.3f MFLOPS' % (BASE_FLOPS / 1e6))
        print('%.3f M' % (BASE_PARAM / 1e6))
        del cloned, cloned_inner

        prune(pack, GBNs, BASE_FLOPS, BASE_PARAM)  # Tick-Tock

        net = getattr(pack.net, 'module', pack.net)  # unwrap DataParallel if present
        _ = Conv2dObserver.transform(net)
        net.classifier = FinalLinearObserver(net.classifier)
        Meltable.observe(pack, 0.001)
        Meltable.melt_all(pack.net)

        # lr=1 is a placeholder; finetune receives its own lr_min/lr_max and
        # presumably sets the learning-rate schedule itself (confirm in prune.utils).
        pack.optimizer = optim.SGD(
            pack.net.parameters(),
            lr=1,
            momentum=cfg.train.momentum,
            weight_decay=cfg.train.weight_decay,
            nesterov=cfg.train.nesterov
        )

        _ = finetune(pack, lr_min=cfg.gbn.lr_min, lr_max=cfg.gbn.lr_max, T=cfg.gbn.finetune_epoch)
        return pack
    
    # Only start the (long) pruning run when executed as a script, not on import.
    if __name__ == '__main__':
        run()
    View Code

    感觉比较重要的代码是prune/universal.py:

    这里有转换bn层、conv层和FinalLinear层的类,还有Tick-Tock操作的类:

    """
     * Copyright (C) 2019 Zhonghui You
     * If you are using this code in your research, please cite the paper:
     * Gate Decorator: Global Filter Pruning Method for Accelerating Deep Convolutional Neural Networks, in NeurIPS 2019.
    """
    
    import torch
    import torch.nn as nn
    
    import numpy as np
    import uuid
    
    # Number of training iterations Meltable.observe runs; its iter_hook returns
    # FINISH_SIGNAL once this count is reached.
    OBSERVE_TIMES = 5
    # Sentinel returned by an iter_hook; presumably tells pack.trainer.train to stop
    # early (trainer not shown here - confirm).
    FINISH_SIGNAL = 'finish'
    
    class Meltable(nn.Module):
        """Base class for gate/observer wrappers that can be "melted".

        A subclass implements ``melt()`` and returns a plain replacement module
        (typically a smaller one with pruned channels removed):
        - ``melt_all`` swaps every Meltable in a network for its replacement;
        - ``observe`` runs a few training iterations so that observers can record
          which channels are still alive.
        """
        def __init__(self):
            super(Meltable, self).__init__()

        @classmethod
        def melt_all(cls, net):
            """Replace, in place, every Meltable inside ``net`` by its ``melt()`` result.

            Children are visited before their parent. ``net`` itself is never replaced.
            """
            def _melt(modules):
                for k in modules.keys():
                    if len(modules[k]._modules) > 0:
                        _melt(modules[k]._modules)
                    if isinstance(modules[k], Meltable):
                        modules[k] = modules[k].melt()

            _melt(net._modules)

        @classmethod
        def observe(cls, pack, lr):
            """Run ``OBSERVE_TIMES`` non-updating training iterations for the observers.

            During observation:
            - every BN weight becomes ``|w| + 1e-3`` (strictly positive);
            - every ``nn.ReLU`` is replaced by ``nn.LeakyReLU(inplace=True)``;
            - BN layers are switched to eval mode at each iteration;
            - ``pack.tick_trainset`` (if not None) is used as the train loader.
            ``pack.trainer.train`` is called with ``update=False``.

            Everything is restored afterwards, even if training raises. The BN
            weights are copied back from an exact snapshot; the old
            ``abs_() - 1e-3`` undo silently dropped the sign of negative weights.
            Only the ReLUs swapped out here are put back; the old version also
            turned LeakyReLUs that belonged to the original model into ReLUs.

            ``lr`` is unused and kept for interface compatibility.
            """
            tmp = pack.train_loader
            if pack.tick_trainset is not None:
                pack.train_loader = pack.tick_trainset

            saved_bn_weights = []  # (bn module, exact copy of its weight before the shift)
            replaced_relus = []    # (parent _modules dict, key, original ReLU module)
            try:
                for m in pack.net.modules():
                    if isinstance(m, nn.BatchNorm2d):
                        saved_bn_weights.append((m, m.weight.data.clone()))
                        m.weight.data.abs_().add_(1e-3)

                def replace_relu(modules):
                    for k in modules.keys():
                        if len(modules[k]._modules) > 0:
                            replace_relu(modules[k]._modules)
                        if isinstance(modules[k], nn.ReLU):
                            replaced_relus.append((modules, k, modules[k]))
                            modules[k] = nn.LeakyReLU(inplace=True)
                replace_relu(pack.net._modules)

                count = 0
                def _freeze_bn(curr_iter, total_iter):
                    nonlocal count
                    # BN in eval mode: running statistics are used and not updated.
                    for m in pack.net.modules():
                        if isinstance(m, nn.BatchNorm2d):
                            m.eval()
                    count += 1
                    if count == OBSERVE_TIMES:
                        return FINISH_SIGNAL

                pack.trainer.train(pack, iter_hook=_freeze_bn, update=False, mute=True)
            finally:
                for modules, k, relu in replaced_relus:
                    modules[k] = relu
                for m, weight in saved_bn_weights:
                    m.weight.data.copy_(weight)
                pack.train_loader = tmp
    
    
    class GatedBatchNorm2d(Meltable):
        """Wrap an ``nn.BatchNorm2d`` with a learnable per-channel gate and a 0/1 mask.

        ``forward`` returns ``bn(x) * g * bn_mask``. On construction the BN scale is
        folded into the gate ``g`` (see ``extract_from_bn``), so ``g`` alone carries
        each channel's scaling. Gradient hooks on ``g`` accumulate channel scores,
        and ``bn_mask`` switches pruned channels off. ``melt`` turns the module back
        into a plain, smaller ``nn.BatchNorm2d``.
        """
        def __init__(self, bn, minimal_ratio = 0.1):
            super(GatedBatchNorm2d, self).__init__()
            assert isinstance(bn, nn.BatchNorm2d)
            self.bn = bn
            self.group_id = uuid.uuid1()

            self.channel_size = bn.weight.shape[0]
            # Lower bound on the number of filters this layer keeps (at least 1).
            self.minimal_filter = max(1, int(self.channel_size * minimal_ratio))
            self.device = bn.weight.device
            self._hook = None

            self.g = nn.Parameter(torch.ones(1, self.channel_size, 1, 1).to(self.device), requires_grad=True)
            # Spatial size (H * W) of the last forward output, used by get_score's FLOPs term.
            self.register_buffer('area', torch.zeros(1).to(self.device))
            # Accumulated |grad(g) * g|, filled by cal_score while collecting.
            self.register_buffer('score', torch.zeros(1, self.channel_size, 1, 1).to(self.device))
            # 1 = channel kept, 0 = channel pruned; multiplied into the forward output.
            self.register_buffer('bn_mask', torch.ones(1, self.channel_size, 1, 1).to(self.device))

            self.extract_from_bn()

        def set_groupid(self, new_id):
            """Set the group id; all GBNs with the same id receive the same mask."""
            self.group_id = new_id

        def extra_repr(self):
            return '%d -> %d | ID: %s' % (self.channel_size, int(self.bn_mask.sum()), self.group_id)

        def extract_from_bn(self):
            """Fold the BN scale into the gate and freeze the BN weight.

            The steps are:
            - bias <- clamp(bias / weight, -10, 10)
            - g <- g * weight
            - weight <- 1, with requires_grad=False
            Because ``w * xhat + b == w * (xhat + b / w)``, the output is unchanged
            except where the clamp kicks in. Once the weight is 1, further calls
            change nothing.
            NOTE(review): a BN weight of exactly 0 with bias 0 gives 0/0 = nan here.
            """
            with torch.no_grad():
                self.bn.bias.set_(torch.clamp(self.bn.bias / self.bn.weight, -10, 10))
                self.g.set_(self.g * self.bn.weight.view(1, -1, 1, 1))
                self.bn.weight.set_(torch.ones_like(self.bn.weight))
                self.bn.weight.requires_grad = False

        def reset_score(self):
            self.score.zero_()

        def cal_score(self, grad):
            """Gradient hook on ``g``: accumulate ``|grad * g|``; the gradient is left unchanged."""
            self.score += (grad * self.g).abs()

        def start_collecting_scores(self):
            """(Re)install the ``cal_score`` hook on ``g``."""
            if self._hook is not None:
                self._hook.remove()

            self._hook = self.g.register_hook(self.cal_score)

        def stop_collecting_scores(self):
            """Remove the ``cal_score`` hook if installed."""
            if self._hook is not None:
                self._hook.remove()
                self._hook = None

        def get_score(self, eta=0.0):
            """Return the flattened channel scores minus a FLOPs penalty, zeroed where masked.

            The penalty is ``eta * area * (number of kept channels)``; ``eta`` is
            expected to be normalised already. With ``eta == 0`` this is simply
            ``score * bn_mask``.
            """
            flops_reg = eta * int(self.area[0]) * self.bn_mask.sum()
            return ((self.score - flops_reg) * self.bn_mask).view(-1)

        def forward(self, x):
            x = self.bn(x) * self.g

            self.area[0] = x.shape[-1] * x.shape[-2]

            if self.bn_mask is not None:
                # Channels whose mask is 0 produce all-zero output.
                return x * self.bn_mask
            return x

        def melt(self):
            """Return a plain ``nn.BatchNorm2d`` holding only the unmasked channels.

            The gate is folded back into the affine parameters:
            ``weight = bn.weight * g`` and ``bias = bn.bias * g``. This reproduces
            ``bn(x) * g`` for the kept channels. ``eps`` and ``momentum`` are copied
            from the wrapped BN. The old version silently used the
            ``nn.BatchNorm2d`` defaults, which changed the output of any layer
            built with a non-default ``eps``.
            """
            with torch.no_grad():
                keep = self.bn_mask.view(-1) != 0
                replacer = nn.BatchNorm2d(
                    int(self.bn_mask.sum()),
                    eps=self.bn.eps,
                    momentum=self.bn.momentum,
                ).to(self.bn.weight.device)
                replacer.running_var.set_(self.bn.running_var[keep])
                replacer.running_mean.set_(self.bn.running_mean[keep])
                replacer.weight.set_((self.bn.weight * self.g.view(-1))[keep])
                replacer.bias.set_((self.bn.bias * self.g.view(-1))[keep])
            return replacer

        @classmethod
        def transform(cls, net, minimal_ratio=0.1):
            """Replace, in place, every ``nn.BatchNorm2d`` in ``net`` by a GatedBatchNorm2d.

            Returns:
                the list of created GatedBatchNorm2d modules, in traversal order.
            """
            r = []
            def _inject(modules):
                for k in modules.keys():
                    if len(modules[k]._modules) > 0:
                        _inject(modules[k]._modules)
                    if isinstance(modules[k], nn.BatchNorm2d):
                        modules[k] = GatedBatchNorm2d(modules[k], minimal_ratio)
                        r.append(modules[k])
            _inject(net._modules)
            return r
    
    
    class FinalLinearObserver(Meltable):
        """Observer for the network's last ``nn.Linear``; only its inputs can be pruned.

        A forward hook accumulates, per input feature, the absolute input activation
        summed over the batch into ``in_mask``. ``melt`` drops the features whose
        accumulated value is 0. Assumes a 2-D ``(batch, features)`` input.
        """
        def __init__(self, linear):
            super(FinalLinearObserver, self).__init__()
            assert isinstance(linear, nn.Linear)
            self.linear = linear
            self.in_mask = torch.zeros(linear.weight.shape[1]).to('cpu')
            self.f_hook = linear.register_forward_hook(self._forward_hook)

        def extra_repr(self):
            return '(%d, %d) -> (%d, %d)' % (
                int(self.linear.weight.shape[1]),
                int(self.linear.weight.shape[0]),
                int((self.in_mask != 0).sum()),
                int(self.linear.weight.shape[0]))

        def _forward_hook(self, m, _in, _out):
            x = _in[0]
            self.in_mask += x.data.abs().cpu().sum(0, keepdim=True).view(-1)

        def forward(self, x):
            return self.linear(x)

        def melt(self):
            """Return an ``nn.Linear`` restricted to the input features seen as non-zero.

            All outputs are kept. The bias, if any, is carried over; a bias-free
            linear now melts into a bias-free one instead of failing on ``None``.
            """
            with torch.no_grad():
                keep = self.in_mask != 0
                has_bias = self.linear.bias is not None
                replacer = nn.Linear(
                    int(keep.sum()),
                    self.linear.weight.shape[0],
                    bias=has_bias,
                ).to(self.linear.weight.device)
                replacer.weight.set_(self.linear.weight[:, keep])
                if has_bias:
                    replacer.bias.set_(self.linear.bias)
            return replacer
    
    
    class Conv2dObserver(Meltable):
        """Wrap an ``nn.Conv2d`` to record which input/output channels are in use.

        Two per-channel accumulators are kept, each summing absolute values over
        batch and spatial positions:
        - ``in_mask``: the conv input, recorded by a forward hook;
        - ``out_mask``: the gradient arriving at the conv output, recorded by a
          tensor hook in training mode only.
        Channels whose accumulated value stays 0 are dropped by ``melt``.
        """
        def __init__(self, conv):
            super(Conv2dObserver, self).__init__()
            assert isinstance(conv, nn.Conv2d)
            self.conv = conv
            self.in_mask = torch.zeros(conv.in_channels).to('cpu')
            self.out_mask = torch.zeros(conv.out_channels).to('cpu')
            self.f_hook = conv.register_forward_hook(self._forward_hook)

        def extra_repr(self):
            return '(%d, %d) -> (%d, %d)' % (self.conv.in_channels, self.conv.out_channels, int((self.in_mask != 0).sum()), int((self.out_mask != 0).sum()))

        def _forward_hook(self, m, _in, _out):
            x = _in[0]
            self.in_mask += x.data.abs().sum(2, keepdim=True).sum(3, keepdim=True).cpu().sum(0, keepdim=True).view(-1)

        def _backward_hook(self, grad):
            # Record |grad| per output channel, then send a gradient of ones upstream
            # (the true gradient is replaced).
            self.out_mask += grad.data.abs().sum(2, keepdim=True).sum(3, keepdim=True).cpu().sum(0, keepdim=True).view(-1)
            new_grad = torch.ones_like(grad)
            return new_grad

        def forward(self, x):
            output = self.conv(x)
            # NOTE(review): standard-normal noise is added in eval mode too; presumably
            # meant to keep live channels away from exact zero while observing. Do not
            # run inference through an observer-wrapped model.
            noise = torch.zeros_like(output).normal_()
            output = output + noise
            if self.training:
                output.register_hook(self._backward_hook)
            return output

        def melt(self):
            """Return an ``nn.Conv2d`` keeping only the channels whose masks are non-zero.

            Supports ordinary (``groups == 1``) and depthwise
            (``groups == out_channels``) convolutions. Any other grouping raises
            AssertionError explicitly; the old ``assert False`` vanished under
            ``python -O`` and the code then failed with UnboundLocalError.
            """
            if self.conv.groups == 1:
                groups = 1
            elif self.conv.groups == self.conv.out_channels:
                groups = int((self.out_mask != 0).sum())
            else:
                raise AssertionError(
                    'Conv2dObserver.melt supports only groups == 1 or depthwise convolutions, got groups=%d'
                    % self.conv.groups)

            replacer = nn.Conv2d(
                in_channels = int((self.in_mask != 0).sum()),
                out_channels = int((self.out_mask != 0).sum()),
                kernel_size = self.conv.kernel_size,
                stride = self.conv.stride,
                padding = self.conv.padding,
                dilation = self.conv.dilation,
                groups = groups,
                bias = (self.conv.bias is not None)
            ).to(self.conv.weight.device)

            with torch.no_grad():
                if self.conv.groups == 1:
                    replacer.weight.set_(self.conv.weight[self.out_mask != 0][:, self.in_mask != 0])
                else:
                    replacer.weight.set_(self.conv.weight[self.out_mask != 0])
                if self.conv.bias is not None:
                    replacer.bias.set_(self.conv.bias[self.out_mask != 0])
            return replacer

        @classmethod
        def transform(cls, net):
            """Wrap, in place, every ``nn.Conv2d`` in ``net``; return the created observers."""
            r = []
            def _inject(modules):
                for k in modules.keys():
                    if len(modules[k]._modules) > 0:
                        _inject(modules[k]._modules)
                    if isinstance(modules[k], nn.Conv2d):
                        modules[k] = Conv2dObserver(modules[k])
                        r.append(modules[k])
            _inject(net._modules)
            return r
    
    # -------------------------------------------------------------------------------------------------------
    
    def get_gate_sparse_loss(masks, sparse_lambda):
        """Build a loss hook returning ``sparse_lambda * sum(|g|)`` over all GBN gates.

        ``masks`` is read each time the hook is called, and entries that are not
        GatedBatchNorm2d are skipped. The hook's ``(data, label, logits)``
        arguments are ignored.
        """
        def _loss_hook(data, label, logits):
            gate_l1 = sum(
                (gbn.g.abs().sum() for gbn in masks if isinstance(gbn, GatedBatchNorm2d)),
                0.0,
            )
            return sparse_lambda * gate_l1

        return _loss_hook
    
    class IterRecoverFramework():
        def __init__(self, pack, masks, sparse_lambda=1e-5, flops_eta=0.0, minium_filter=10):
            """Set up the Tick/Tock driver.

            Args:
                pack: training pack (net, optimizer, trainer, loaders, ...).
                masks: GatedBatchNorm2d layers to score and prune.
                sparse_lambda: weight of the L1 gate penalty added during Tock.
                flops_eta: FLOPs regularisation coefficient (stored only here;
                    presumably passed to GatedBatchNorm2d.get_score - confirm).
                minium_filter: ignored; kept for backward compatibility.
            """
            self.pack = pack
            self.masks = masks
            self.sparse_lambda = sparse_lambda
            self.flops_eta = flops_eta
            self.eta_scale_factor = 1.0
            # Extra loss term used as loss_hook by tock().
            self.sparse_loss_hook = get_gate_sparse_loss(masks, sparse_lambda)
            self.logs = []

            self.total_filters = sum(gbn.bn.weight.shape[0] for gbn in masks)
            self.pruned_filters = 0
    
        def recover(self, lr, test):
            for gbn in self.masks:
                if isinstance(gbn, GatedBatchNorm2d):
                    gbn.reset_score()
                    gbn.start_collecting_scores()
    
            for g in self.pack.optimizer.param_groups:
                g['lr'] = lr
    
            tmp = self.pack.train_loader
            self.pack.train_loader = self.pack.tick_trainset #使用训练集的子集
            info = self.pack.trainer.train(self.pack) #执行Tick,只更新gate φ和最后的线性层参数
            self.pack.train_loader = tmp
    
            if test:
                info.update(self.pack.trainer.test(self.pack))
    
            info.update({'LR': lr})
    
            for gbn in self.masks:
                if isinstance(gbn, GatedBatchNorm2d):
                    gbn.stop_collecting_scores()
            
            return info
    
        def get_threshold(self, status, num):
            '''
                input score list from layers, and the number of filter to prune
            '''
            total_filters, left_filters = 0, 0
            filtered_score_list = []
            
            for group_id, v in status.items():
                total_filters += len(v['score']) * v['count'] #count>1,说明分group了,对应channels有着相同的分数
                left_filters += int((v['score'] != 0).sum()) * v['count']
    
                sorted_score = np.sort(v['score'])[:-v['minimal']] #-v['minimal']之后的分数是不要的,按比例抛弃
                filtered_score = sorted_score[sorted_score != 0]
                for i in range(v['count']):
                    filtered_score_list.append(filtered_score)#因为相同组的channels分数是相同的,所以append v['count']次
    
            scores = np.concatenate(filtered_score_list) #将所有GBN中的channels的分数串联在一起
            threshold = np.sort(scores)[num] #然后再排序,取num索引的值作为阈值
            to_prune = int((scores <= threshold).sum()) #分数小于这个阈值的channels也prune
    
            info = {'left': left_filters, 'to_prune': to_prune, 'total_pruned_ratio': (total_filters - left_filters + to_prune) / total_filters}
            return threshold, info
    
        # 作用:根据训练得到的分数和阈值计算出 mask,用来更新 g.bn_mask;bn_mask 才是决定各 channel 是否保留的值,
        # 因为在 GBN 的 forward 中输出为 x * self.bn_mask
        def set_mask(self, status, threshold): #这里的self.masks是GBNs
            for group_id, v in status.items():
                hard_threshold = float(np.sort(v['score'])[-v['minimal']]) #根据按比例得到的v['minimal'] = max(1, int(self.channel_size * minimal_ratio)),得到该位置的分数,说明小于这个分数的channels是一定要prune的
                hard_mask = v['score'] >= hard_threshold #留下的channels更多
                soft_mask = v['score'] > threshold #留下的channels少
                v['mask'] = (hard_mask + soft_mask)
    
            with torch.no_grad():
                for g in self.masks:
                    if g.group_id in status:
                        mask = torch.from_numpy(status[g.group_id]['mask'].astype('float32')).to(g.device).view(1, -1, 1, 1)
                        g.bn_mask.set_(mask * g.bn_mask)
    
        def freeze_conv(self):
            """Freeze every Conv2d parameter of ``self.pack.net`` (conv weights are not updated during Tick).

            The previous ``requires_grad`` flag of each parameter is remembered in
            ``self._status``, keyed by ``id(param)``, so that ``restore_conv`` can undo it.
            """
            self._status = {}
            conv_layers = (mod for mod in self.pack.net.modules() if isinstance(mod, nn.Conv2d))
            for conv in conv_layers:
                for param in conv.parameters():
                    self._status[id(param)] = param.requires_grad
                    param.requires_grad = False
    
        def restore_conv(self):
            """Undo ``freeze_conv``: give each Conv2d parameter back its saved ``requires_grad``.

            Each parameter is looked up in ``self._status`` by ``id``, so a conv parameter not
            recorded by the last ``freeze_conv`` call raises ``KeyError``.
            """
            for module in self.pack.net.modules():
                if not isinstance(module, nn.Conv2d):
                    continue
                for param in module.parameters():
                    param.requires_grad = self._status[id(param)]
    
        def tock(self, lr_min=0.001, lr_max=0.01, tock_epoch=20, mute=False, acc_step=1):
            """Run the Tock phase: ``tock_epoch`` epochs of training with the extra sparse loss.

            Conv weights are not frozen here (compare ``tick``). ``self.sparse_loss_hook`` is
            passed to the trainer as ``loss_hook``, and the model is evaluated after every
            epoch. The learning rate follows a triangular schedule over all iterations of the
            phase: it ramps linearly from ``lr_min`` to ``lr_max`` over the first half, then
            back down to ``lr_min`` over the second half (clamped at 0).

            Args:
                lr_min, lr_max: bounds of the triangular learning-rate schedule.
                tock_epoch: number of epochs to train.
                mute: if True, skip the per-epoch progress line.
                acc_step: forwarded unchanged to ``trainer.train``.

            Returns:
                list with one dict per epoch: train info, test info and the final 'LR'.
            """
            n_epochs = tock_epoch
            # Mutable holder so the hook sees the current epoch whenever the trainer calls it.
            progress = {'epoch': 0}

            def set_lr(curr_iter, total_iter):
                # Called by the trainer once per iteration.
                half = n_epochs * total_iter / 2
                done = progress['epoch'] * total_iter + curr_iter
                if done < half:
                    frac = done / half
                    lr = (1 - frac) * lr_min + frac * lr_max
                else:
                    frac = ((progress['epoch'] - n_epochs / 2) * total_iter + curr_iter) / half
                    lr = (1 - frac) * lr_max + frac * lr_min
                for group in self.pack.optimizer.param_groups:
                    group['lr'] = max(lr, 0)

            history = []
            for epoch_idx in range(n_epochs):
                info = self.pack.trainer.train(self.pack, loss_hook=self.sparse_loss_hook, iter_hook=set_lr, acc_step=acc_step)
                info.update(self.pack.trainer.test(self.pack))
                info.update({'LR': self.pack.optimizer.param_groups[0]['lr']})
                progress['epoch'] = epoch_idx + 1
                if not mute:
                    print('Tock - %d,\t Test Loss: %.4f,\t Test age_correct Acc: %.2f, Test gender_correct Acc: %.2f, Final LR: %.5f' % (epoch_idx, info['test_loss'], info['age_correct'], info['gender_correct'], info['LR']))
                history.append(info)
            return history
    
        def tick(self, lr, test):
            ''' Do Prune: run the Tick phase (``self.recover``) with all Conv2d weights frozen.

            Args:
                lr: learning rate forwarded to ``self.recover``.
                test: flag forwarded to ``self.recover``.

            Returns:
                whatever ``self.recover`` returns (``prune`` uses it as its info dict).
            '''
            self.freeze_conv()
            try:
                return self.recover(lr, test)
            finally:
                # Always unfreeze, even if recover() raises; otherwise the conv layers would
                # silently stay frozen for every later Tock / fine-tune step.
                self.restore_conv()
    
        def prune(self, num, tick=False, lr=0.01, test=True):
            """Prune about ``num`` more filters across all gated BN layers.

            Optionally runs a Tick phase first and then refreshes ``self.eta_scale_factor``
            from the smallest GBN ``area``. Scores of GBNs that share a ``group_id`` are summed,
            and the whole group is masked with a single mask.

            Args:
                num: number of filters to prune in this step (passed to ``get_threshold``).
                tick: if True, run ``self.tick(lr, test)`` before scoring.
                lr: learning rate for the Tick phase.
                test: if True, evaluate once after pruning and record the age / gender accuracy.

            Returns:
                info dict with the threshold statistics, optional test accuracies and 'total'.
                The same dict is appended to ``self.logs``.
            """
            info = {}
            if tick:
                info = self.tick(lr, test)
                self.eta_scale_factor = min(int(g.area[0]) for g in self.masks)

            # Loop-invariant: every GBN is scored with the same eta.
            eta = self.flops_eta / self.eta_scale_factor
            status = {}
            for g in self.masks:
                score = g.get_score(eta).cpu().data.numpy()
                if g.group_id in status:
                    # assert the gbn in same group has the same channel size
                    status[g.group_id]['score'] += score
                    status[g.group_id]['count'] += 1
                else:
                    status[g.group_id] = {
                        'score': score,
                        'minimal': g.minimal_filter,
                        'count': 1,
                        'mask': None
                    }

            threshold, r = self.get_threshold(status, num)
            info.update(r)
            self.set_mask(status, float(threshold))
            if test:
                # Run the test set once and read both metrics from it (previously it was run
                # twice, once per metric).
                result = self.pack.trainer.test(self.pack)
                info.update({'after_prune_test_age_acc': result['age_correct']})
                info.update({'after_prune_test_gender_acc': result['gender_correct']})
            self.logs.append(info)
            self.pruned_filters = self.total_filters - info['left']
            info['total'] = self.total_filters
            return info
    View Code

    Resnet

    感觉ResNet和VGG的主要差别在于ResNet有侧支(shortcut):残差相加的两路输出通道数必须一致,所以需要对BN层分组:

    1)resnet-56/resnet56_prune.ipynb

    一步步向下运行:

    # Wrap the network's BN layers in GatedBatchNorm2d (GBN).
    # Author's note: the extract_from_bn() call made inside GBN.__init__ only handles the
    # outermost BN layer; the loop below initialises the rest.
    GBNs = GatedBatchNorm2d.transform(pack.net)
    print(GBNs)  # repr extended by GBN.extra_repr(); at this point every GBN has a different ID
    for gbn in GBNs:  # initialise the GBNs one by one
        gbn.extract_from_bn()  # author's note: BN weights stay fixed during training, only the gate g changes

    返回:

    [GatedBatchNorm2d(
      16 -> 16 | ID: 94d38830-2d48-11ea-ba2e-00e04c6841ff
      (bn): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    ), GatedBatchNorm2d(
      16 -> 16 | ID: 94d5b330-2d48-11ea-ba2e-00e04c6841ff
      (bn): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    ), GatedBatchNorm2d(
      16 -> 16 | ID: 94d5bb32-2d48-11ea-ba2e-00e04c6841ff
      (bn): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    ), 
    ...
    import uuid
    
    def bottleneck_set_group(net):
        """Put the GBNs of each ResNet stage (layer1, layer2, layer3) into one group.

        A stage's group holds the ``bn2`` of every BasicBlock in it, plus the BN of each
        non-empty shortcut. layer1's group also holds the BN right before it
        (``net.module.bn1``). Every GBN in a group gets the same fresh ``uuid1`` group id,
        so pruning treats them as one unit. Each group's list is printed.

        NOTE(review): a block's ``bn1`` is presumably left out because its output only feeds
        the block's own conv2, not the residual sum, so it can be pruned independently
        (BasicBlock code is not shown here; confirm).

        Args:
            net: the GBN-wrapped network; its sub-modules are reached through ``net.module``.
        """
        stages = [net.module.layer1, net.module.layer2, net.module.layer3]
        for stage in stages:
            masks = []
            if stage is net.module.layer1:
                # Use the `net` argument, not the global `pack.net`, so the function works for
                # any network passed in. Presumably bn1's output is the identity input of
                # layer1's first residual add.
                masks.append(net.module.bn1)
            for mm in stage.modules():
                if mm.__class__.__name__ != 'BasicBlock':
                    continue
                if len(mm.shortcut._modules) > 0:
                    # Non-empty shortcut: the transition between stages with different channel
                    # counts; per the original notes, its child '1' is a BN layer.
                    masks.append(mm.shortcut._modules['1'])
                masks.append(mm.bn2)

            group_id = uuid.uuid1()
            for mk in masks:
                mk.set_groupid(group_id)  # GatedBatchNorm2d API: overwrite this GBN's group id
            print(masks)
    
    # Apply the per-stage grouping to the GBN-wrapped network (prints each group's GBN list).
    bottleneck_set_group(pack.net)

    返回:

    [GatedBatchNorm2d(
      16 -> 16 | ID: 0e2ca1f2-2d4a-11ea-ba2e-00e04c6841ff
      (bn): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    ), GatedBatchNorm2d(
      16 -> 16 | ID: 0e2ca1f2-2d4a-11ea-ba2e-00e04c6841ff
      (bn): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    ), GatedBatchNorm2d(
      16 -> 16 | ID: 0e2ca1f2-2d4a-11ea-ba2e-00e04c6841ff
      (bn): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    ), GatedBatchNorm2d(
      16 -> 16 | ID: 0e2ca1f2-2d4a-11ea-ba2e-00e04c6841ff
      (bn): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    ), GatedBatchNorm2d(
      16 -> 16 | ID: 0e2ca1f2-2d4a-11ea-ba2e-00e04c6841ff
      (bn): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    ), GatedBatchNorm2d(
      16 -> 16 | ID: 0e2ca1f2-2d4a-11ea-ba2e-00e04c6841ff
      (bn): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    ), GatedBatchNorm2d(
      16 -> 16 | ID: 0e2ca1f2-2d4a-11ea-ba2e-00e04c6841ff
      (bn): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    ), GatedBatchNorm2d(
      16 -> 16 | ID: 0e2ca1f2-2d4a-11ea-ba2e-00e04c6841ff
      (bn): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    ), GatedBatchNorm2d(
      16 -> 16 | ID: 0e2ca1f2-2d4a-11ea-ba2e-00e04c6841ff
      (bn): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    ), GatedBatchNorm2d(
      16 -> 16 | ID: 0e2ca1f2-2d4a-11ea-ba2e-00e04c6841ff
      (bn): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )]
    [GatedBatchNorm2d(
      32 -> 32 | ID: 0e2cc826-2d4a-11ea-ba2e-00e04c6841ff
      (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    ), GatedBatchNorm2d(
      32 -> 32 | ID: 0e2cc826-2d4a-11ea-ba2e-00e04c6841ff
      (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    ), GatedBatchNorm2d(
      32 -> 32 | ID: 0e2cc826-2d4a-11ea-ba2e-00e04c6841ff
      (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    ), GatedBatchNorm2d(
      32 -> 32 | ID: 0e2cc826-2d4a-11ea-ba2e-00e04c6841ff
      (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    ), GatedBatchNorm2d(
      32 -> 32 | ID: 0e2cc826-2d4a-11ea-ba2e-00e04c6841ff
      (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    ), GatedBatchNorm2d(
      32 -> 32 | ID: 0e2cc826-2d4a-11ea-ba2e-00e04c6841ff
      (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    ), GatedBatchNorm2d(
      32 -> 32 | ID: 0e2cc826-2d4a-11ea-ba2e-00e04c6841ff
      (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    ), GatedBatchNorm2d(
      32 -> 32 | ID: 0e2cc826-2d4a-11ea-ba2e-00e04c6841ff
      (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    ), GatedBatchNorm2d(
      32 -> 32 | ID: 0e2cc826-2d4a-11ea-ba2e-00e04c6841ff
      (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    ), GatedBatchNorm2d(
      32 -> 32 | ID: 0e2cc826-2d4a-11ea-ba2e-00e04c6841ff
      (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )]
    [GatedBatchNorm2d(
      64 -> 64 | ID: 0e2d0728-2d4a-11ea-ba2e-00e04c6841ff
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    ), GatedBatchNorm2d(
      64 -> 64 | ID: 0e2d0728-2d4a-11ea-ba2e-00e04c6841ff
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    ), GatedBatchNorm2d(
      64 -> 64 | ID: 0e2d0728-2d4a-11ea-ba2e-00e04c6841ff
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    ), GatedBatchNorm2d(
      64 -> 64 | ID: 0e2d0728-2d4a-11ea-ba2e-00e04c6841ff
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    ), GatedBatchNorm2d(
      64 -> 64 | ID: 0e2d0728-2d4a-11ea-ba2e-00e04c6841ff
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    ), GatedBatchNorm2d(
      64 -> 64 | ID: 0e2d0728-2d4a-11ea-ba2e-00e04c6841ff
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    ), GatedBatchNorm2d(
      64 -> 64 | ID: 0e2d0728-2d4a-11ea-ba2e-00e04c6841ff
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    ), GatedBatchNorm2d(
      64 -> 64 | ID: 0e2d0728-2d4a-11ea-ba2e-00e04c6841ff
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    ), GatedBatchNorm2d(
      64 -> 64 | ID: 0e2d0728-2d4a-11ea-ba2e-00e04c6841ff
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    ), GatedBatchNorm2d(
      64 -> 64 | ID: 0e2d0728-2d4a-11ea-ba2e-00e04c6841ff
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )]
    View Code

    可见分成了3组,同一组内所有GBN的ID是相同的。

    此时的网络结构是:

    # Inspect the top-level children of the GBN-wrapped network.
    for name, child in pack.net.named_children():
        print(name)
        print(child)

    通过给同一组内的GBN设置相同的group id来表明它们属于同一组。剪枝时,同一组GBN的分数会累加,并共用同一个mask(作用到各自的bn_mask上),因此它们保留下来的channels数也相同,残差相加时通道才能对齐。

    然后克隆了一个一样的模型来计算模型的FLOPs和参数大小:

    def clone_model(net):
        """Build a fresh model (BNs wrapped in GBNs) that carries the same weights as ``net``.

        Judging by the structure printed below, the clone's GBN IDs differ from ``net``'s,
        so the group ids set on ``net`` are not carried over.

        Returns:
            (model, gbns): the new model and the list returned by GatedBatchNorm2d.transform.
        """
        fresh = get_model()
        fresh_gbns = GatedBatchNorm2d.transform(fresh.module)
        # Weights are copied after the transform; presumably this makes the GBN-wrapped
        # `net`'s state_dict keys match the clone's.
        fresh.load_state_dict(net.state_dict())
        return fresh, fresh_gbns
    
    cloned, _ = clone_model(pack.net)
    # Build the dummy input on the cloned model's device instead of hard-coding CPU or
    # .cuda(), so the baseline analysis works with cfg.base.cuda both true and false.
    dummy_input = torch.randn(1, 3, 32, 32, device=next(cloned.parameters()).device)
    BASE_FLOPS, BASE_PARAM = analyse_model(cloned.module, dummy_input)
    print('%.3f MFLOPS' % (BASE_FLOPS / 1e6))
    print('%.3f M' % (BASE_PARAM / 1e6))

    for name, child in cloned.named_children():
        print(name)
        print(child)

    # The clone is only needed for the baseline FLOPs / parameter count.
    del cloned

    输出的网络结构除了GBN的ID不同外,其余都一致;最后用del cloned删除了这个克隆模型,说明它只用于计算基准FLOPs和参数量,后面不再使用。

    返回:

    127.932 MFLOPS
    0.856 M

    上面的clone_model(net)操作其实是为了下面的函数eval_prune(pack)服务的:

    def eval_prune(pack):
        """Return (flops, params) of the network in ``pack`` with its masked channels actually removed.

        The structural changes are made on a clone built by ``clone_model``. The clone's convs
        and its final linear layer are wrapped in observers, then ``Meltable.observe`` and
        ``Meltable.melt_all`` run on a shallow copy of ``pack`` that points at the clone.
        Finally ``analyse_model`` measures the result.
        """
        cloned, _ = clone_model(pack.net)
        _ = Conv2dObserver.transform(cloned.module)
        cloned.module.linear = FinalLinearObserver(cloned.module.linear)
        cloned_pack = dotdict(pack.copy())
        cloned_pack.net = cloned
        Meltable.observe(cloned_pack, 0.001)
        Meltable.melt_all(cloned_pack.net)
        # Follow the model's device instead of a hard-coded .cuda(), which raises
        # "Torch not compiled with CUDA enabled" on CPU-only setups.
        device = next(cloned_pack.net.parameters()).device
        flops, params = analyse_model(cloned_pack.net.module, torch.randn(1, 3, 32, 32, device=device))
        del cloned
        del cloned_pack

        return flops, params

    其实就是在克隆模型上用Meltable.melt_all真正剪掉被mask掉的channels,再计算此时模型的FLOPs和参数量,与原始模型比较,查看此时压缩了多少。

    接下来就是测试此时的模型,看看效果:

    # Evaluate the current model on the test set; the notebook displays the returned dict.
    pack.trainer.test(pack)

    返回:

    {'test_loss': 0.31250936849207817, 'acc@1': 92.92919303797468}

    然后就是tick-tock操作:

    import os

    pack.tick_trainset = pack.train_loader
    prune_agent = IterRecoverFramework(pack, GBNs, sparse_lambda = cfg.gbn.sparse_lambda, flops_eta = cfg.gbn.flops_eta, minium_filter = 3)

    LOGS = []
    # Save a checkpoint the first time FLOPs drop to 40%, 38%, 35%, 32% and 30% of the baseline.
    flops_save_points = {40, 38, 35, 32, 30}
    save_dir = './logs/resnet56_cifar10_ticktock'
    # torch.save does not create missing directories.
    os.makedirs(save_dir, exist_ok=True)

    iter_idx = 0
    # Start with one Tock phase (cfg.gbn.tock_epoch epochs of training with the sparse loss).
    prune_agent.tock(lr_min=cfg.gbn.lr_min, lr_max=cfg.gbn.lr_max, tock_epoch=cfg.gbn.tock_epoch)
    while True:
        left_filter = prune_agent.total_filters - prune_agent.pruned_filters
        num_to_prune = int(left_filter * cfg.gbn.p)
        # Tick: recovery step with conv frozen, then prune num_to_prune more filters.
        info = prune_agent.prune(num_to_prune, tick=True, lr=cfg.gbn.lr_min)
        flops, params = eval_prune(pack)  # cost of the model as currently pruned
        info.update({
            'flops': '[%.2f%%] %.3f MFLOPS' % (flops/BASE_FLOPS * 100, flops / 1e6),
            'param': '[%.2f%%] %.3f M' % (params/BASE_PARAM * 100, params / 1e6)
        })
        LOGS.append(info)
        # NOTE(review): the modified IterRecoverFramework.prune in this post records
        # 'after_prune_test_age_acc' / 'after_prune_test_gender_acc' instead; with that version
        # the 'after_prune_test_acc' key below does not exist.
        print('Iter: %d,\t FLOPS: %s,\t Param: %s,\t Left: %d,\t Pruned Ratio: %.2f %%,\t Train Loss: %.4f,\t Test Acc: %.2f' %
              (iter_idx, info['flops'], info['param'], info['left'], info['total_pruned_ratio'] * 100, info['train_loss'], info['after_prune_test_acc']))

        iter_idx += 1
        if iter_idx % cfg.gbn.T == 0:  # after every cfg.gbn.T Tick steps, run another Tock phase
            print('Tocking:')
            prune_agent.tock(lr_min=cfg.gbn.lr_min, lr_max=cfg.gbn.lr_max, tock_epoch=cfg.gbn.tock_epoch)

        flops_ratio = flops/BASE_FLOPS * 100  # current FLOPs as a percentage of the baseline
        # sorted() iterates over a snapshot, so removing from the set inside the loop is safe.
        for point in sorted(flops_save_points, reverse=True):
            if flops_ratio <= point:
                torch.save(pack.net.module.state_dict(), os.path.join(save_dir, '%s.ckp' % str(point)))
                flops_save_points.remove(point)

        if not flops_save_points:  # every save point reached (FLOPs <= 30%): stop pruning
            break

    其finetune操作在resnet-56/finetune.ipynb

    核心就是先根据bn_mask的值对模型进行剪枝,不仅剪BN层,还要剪Conv2d层和FinalLinear层,即:

    # Wrap the BNs in GatedBatchNorm2d and call extract_from_bn() on each, as in the pruning notebook.
    GBNs = GatedBatchNorm2d.transform(pack.net)
    for gbn in GBNs:
        gbn.extract_from_bn()
    
    # Wrap the convs and the final linear layer in observers, then observe and melt.
    # Per the notes above, this physically removes the masked channels from the BN,
    # Conv2d and final linear layers.
    _ = Conv2dObserver.transform(pack.net.module)
    pack.net.module.linear = FinalLinearObserver(pack.net.module.linear)
    Meltable.observe(pack, 0.001)  # NOTE(review): meaning of 0.001 is not visible here — confirm in Meltable
    Meltable.melt_all(pack.net)  # prune

    然后进行finetune:

    # Fine-tune the melted model (T presumably = number of epochs); assigning the result to
    # `_` keeps the notebook from displaying the return value.
    _ = finetune(pack, lr_min=cfg.gbn.lr_min, lr_max=cfg.gbn.lr_max, T=cfg.gbn.finetune_epoch)

    这里有一个特别的写法,记录一下

    保证路径为主路径的方法:python保证路径为主路径的方法

  • 相关阅读:
    1343. Fairy Tale
    Codeforces Beta Round #97 (Div. 1)
    URAL1091. Tmutarakan Exams(容斥)
    1141. RSA Attack(RSA)
    hdu4003Find Metal Mineral(树形DP)
    hdu2196 Computer待续
    KMP
    莫比乌斯反演
    配对堆
    bzoj3224Treap
  • 原文地址:https://www.cnblogs.com/wanghui-garcia/p/12148709.html
Copyright © 2020-2023  润新知