• 梯度累加实现 “显存扩大"


    参考:PyTorch中在反向传播前为什么要手动将梯度清零? - Pascal的回答 - 知乎
    pytorch会在每一次backward()后进行梯度计算,但是梯度不会自动归零,如果不进行手动归零的话,梯度会不断累加

    1.1 传统的训练一个 batch 的流程如下:

    for i, (images, target) in enumerate(train_loader):
        # 1. input output
        images = images.cuda(non_blocking=True)
        target = torch.from_numpy(np.array(target)).float().cuda(non_blocking=True)
        outputs = model(images)
        loss = criterion(outputs, target)
        
        # 2. backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
    • 获取loss: 输入图像和标签,通过infer计算得到预测值,计算损失函数
    • optimizer.zero_grad()清空过往梯度
    • loss.backward()反向传播,计算当前梯度
    • optimizer.step()根据梯度更新网络参数

    即进来一个batch的数据,就计算一次梯度,更新一次网络

    1.2 使用梯度累加

    for i, (images, target) in enumerate(train_loader):
        # 1. input output
        images = images.cuda(non_blocking=True)
        target = torch.from_numpy(np.array(target)).float().cuda(non_blocking=True)
        outputs = model(imgaes)
        loss = criterion(outputs, target)
    
        # 2.1 loss regularization
        loss = loss / accumulation_steps # loss每次都会更新,因此每次都除以steps再加到原来的梯度上面去
        
        # 2.2 backward propagation
        loss.backward()
    
        # 3. update parameters of net
        if ((i+1)%accumulation)==0:
            # optimizer the net
            optimizer.step()
            optimizer.zero_grad() # reset grdient
    
    • 获取loss: 输入图像和标签,通过infer计算得到预测值,计算损失函数
    • loss.backward()反向传播,计算当前梯度
    • 多次循环步骤 1-2, 不清空梯度,使梯度累加在已有梯度上
    • 梯度累加一定次数后,先optimizer.step()根据累积的梯度更新网络参数,然后optimizer.zero_grad()清空过往梯度,为下一波梯度累加做准备

    总结来说:梯度累加就是,每次获取1个batch的数据,计算1次梯度,梯度不清空,不断累加,累加一定次数后,根据累加的梯度更新网络参数,然后清空梯度,进行下一次循环。
    一定条件下,batchsize越大训练效果越好,梯度累加则实现了batchsize的变相扩大,如果accumulation_steps为8,则batchsize '变相' 扩大了8倍,是我们这种乞丐实验室解决显存受限的一个不错的trick,使用时需要注意,学习率也要适当放大。
    BN的估算是在forward阶段就已经完成的,并不冲突,只是accumulation_steps=8和真实的batchsize放大八倍相比,效果自然是差一些,毕竟八倍Batchsize的BN估算出来的均值和方差肯定更精准一些。

  • 相关阅读:
    Android之Activity启动过程
    Android之Application进阶
    Android之Context进阶
    Thread之ThreadLocal
    Android 系统服务与Binder应用服务
    Android Binder
    Android SystemServer
    Android系统服务与服务注册
    Android Binder进阶扁一
    小米商城-题头3
  • 原文地址:https://www.cnblogs.com/qiulinzhang/p/11169236.html
Copyright © 2020-2023  润新知