• pytorch学习笔记——训练时显存逐渐增加,几个epoch后outofmemory


    问题起因:笔者想把别人的torch的代码复制到笔者的代码框架下,从而引起的显存爆炸问题

    该bug在困扰了笔者三天的情况下,和学长一同解决了该bug,故在此记录这次艰辛的debug之路。

    尝试思路1:检查是否存在保留loss的情况下是否使用了 item() 取值,经检查,并没有

    尝试思路2:按照网上的说法,添加两行下面的代码:

    torch.backends.cudnn.enabled = True
    
    torch.backends.cudnn.benchmark = True

    实测发现并没有用。

    尝试思路3:及时删除临时变量和清空显存的cache,例如每次训练一个batch就清除模型的输入输出。

    del inputs,loss
    gc.collect()
    torch.cuda.empty_cache()

    这样确实使得模型能够多训练几个epoch,但依旧没有解决显存持续增长的问题,而且由于频繁使用torch.cuda.empty_cache(),导致模型一个epoch的训练时长翻了3倍多

    尝试思路4:重新核对原模型代码,打印模型中所有parameters和register_buffer的require_grad,终于发现是因为模型中的某个register_buffer在训练过程中,它的require_grad本应该为False,然而迁移到我代码上的实际训练过程中变成了True,而这个buffer的占用数据空间也不大,可能是因为变为True之后,导致在显存中一直被保留,从而最终导致显存溢出。再将那个buffer在forward函数里的操作放在torch.no_grad()上下文中,问题解决!

    总结:如果训练时显存占用持续增加,需要谨慎的检查forward函数中的操作,尤其是在编写复杂代码的时候,更需要更细致的检查

  • 相关阅读:
    C#文件读写常用类介绍
    C#实现注销、重启和关机代码
    Mybatis学习---基础知识考核
    Linux操作系统各版本ISO镜像下载
    Java学习---JDK的安装和配置
    Java学习---基础知识学习
    Java学习---常见的模式
    Java实例---黑白五子棋[单机版]
    Java实例---简单的超市管理系统
    Java实例---简单的个人管理系统
  • 原文地址:https://www.cnblogs.com/ISGuXing/p/16079734.html
Copyright © 2020-2023  润新知