• 神经网络模型的参数量和FLOPs的计算,torchstat,thop,循环神经网络RNN


    概念

    FLOPS和FLOPs的区别:

    • FLOPS:注意全大写,是floating point operations per second的缩写,意指每秒浮点运算次数,理解为计算速度。是一个衡量硬件性能的指标。
    • FLOPs:注意s小写,是floating point operations的缩写(s表复数),意指浮点运算数,理解为计算量。可以用来衡量算法/模型的复杂度。

    卷积的FLOPs的计算过程可参考:CNN 模型所需的计算力flops是什么?怎么计算? - 知乎 (zhihu.com)

    在介绍torchstat包和thop包之前,先总结一下:

    • torchstat包可以统计卷积神经网络和全连接神经网络的参数和计算量。
    • thop包可以统计统计卷积神经网络、全连接神经网络以及循环神经网络的参数和计算量,程序示例等详见下文。

     

    torchstat包

    在实际操作中,我们可以调用torchstat包,帮助我们统计模型的parameters和FLOPs。如果不修改这个包里面的一些代码,那么这个包只适用于输入为3通道的图像的模型。

    程序示例一:统计卷积神经网络的参数和计算量:

    import torch
    import torch.nn as nn
    from torchstat import stat
    
    
    class Simple(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = nn.Conv2d(3, 16, 3, 1, padding=1, bias=False)
            self.conv2 = nn.Conv2d(16, 32, 3, 1, padding=1, bias=False)
    
        def forward(self, x):
            x = self.conv1(x)
            x = self.conv2(x)
            return x
    
    
    model = Simple()
    stat(model, (3, 244, 244))   # 统计模型的参数量和FLOPs,(3,244,244)是输入图像的size
    

    运行结果:

    如果把torchstat包中的一行程序进行一点点改动,那么这个包可以用来统计全连接神经网络的参数量和计算量。当然手动计算全连接神经网络的参数量和计算量也很快 =_= 。进入torchstat源代码之后,如下图所示,注释掉圈红的地方,就可以用torchstat包统计全连接神经网络的参数量和计算量了。

    程序示例二,用torchstat统计全连接神经网络的参数量和计算量::

    import torch.nn as nn
    from torchstat import stat
    
    class Simple(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = nn.Linear(10, 10)
    
        def forward(self, x):
            x = self.fc1(x)
            return x
    
    net = Simple()
    stat(net, (10,))
    

    运行结果:

    我本来想把torchstat源码改一下,使得它可以用来统计循环神经网络的参数量和计算量,但好像稍微有点麻烦,然后网上找到thop包,可以用来统计循环神经网络的参数量和计算量,所以改torchstat源码的动力就...没有了 =_=。

    thop包

    除了torchstat包之外,thop也可以统计神经网络模型的参数量。thop的安装可以参考:GitHub - Lyken17/pytorch-OpCounter: Count the MACs / FLOPs of your PyTorch model.

    程序示例三,用thop统计全连接神经网络的参数量和计算量:

    import torch
    import torch.nn as nn
    from thop import profile
    
    class Simple(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = nn.Linear(10, 10)
    
        def forward(self, x):
            x = self.fc1(x)
            return x
    
    net = Simple()
    input = torch.randn(1, 10)  # batchsize=1, 输入向量长度为10
    macs, params = profile(net, inputs=(input, ))
    print(' FLOPs: ', macs*2)   # 一般来讲,FLOPs是macs的两倍
    print('params: ', params)
    

    运行结果:

    程序示例四,用thop统计循环神经网络的参数量和计算量

    import torch
    import torch.nn as nn
    from thop import profile
    
    class Simple(nn.Module):
        def __init__(self):
            super().__init__()
            self.lstm = nn.LSTM(input_size=100, hidden_size=100, num_layers=1)
    
        def forward(self, x):
            x = self.lstm(x)
            return x
    
    net = Simple()
    input = torch.randn(1, 10, 100)  # batchsize=1,序列长度为10,序列中每个时间步的向量长度为100
    macs, params = profile(net, inputs=(input, ))
    print(' FLOPs: ', macs*2)   # 一般来讲,FLOPs是macs的两倍
    print('params: ', params)
    

    运行结果:

  • 相关阅读:
    va_start和va_end使用详解
    Visual Assist X设置
    google 快捷键
    /bin/sh^M: bad interpreter: No such file or directory 异常
    动态链接库的学习(一)
    sprintf函数的用法详解
    错误:在 C99 模式之外使用‘for’循环初始化声明
    VC6.0在win7下显示行号的插件
    错误: 程序中有游离的‘\302’ ‘\240’等
    Linux Shell编程笔记一:相关命令
  • 原文地址:https://www.cnblogs.com/picassooo/p/16343737.html
Copyright © 2020-2023  润新知