最近在做比赛的时候,遇到了一个最好结果,但是之后无论怎样都复现不出来最好结果了。猜测是不是跟Pytorch中的随机种子有关。
训练过程
在训练过程中,若相同的数据数据集,相同的训练集、测试集划分方式,相同的权重初始化,但是每次训练结果不同,可能有以下几个原因:
- Dropout的存在
- Pytorch、Python、Numpy中的随机种子没有固定
- 数据预处理、增强方式采用了概率,若没有设置固定的随机种子,结果可能不同。例如常用数据增强库albumentations就采用了Python的随机产生器。
- 训练数据集被随机打乱了顺序
- 向上采样和插值函数/类的向后是不确定的(Pytorch的问题)
另外,在Pytorch官方文档中说明了在Pytorch的不同提交、不同版本和不同平台上,不能保证完全可重现的结果。此外,即使使用相同的种子,因为存在不同的CPU和GPU,结果也不能重现。
但是对于一个特定的平台和PyTorch发行版上对您的特定问题进行确定性的计算,需要采取几个步骤。在can’t reproduce results even set all random seeds说明了两种解决方式:
can’t reproduce results even set all random seeds#7068 (comment1)建议采用下面方式解决:
- 在运行任何程序之前写入下面代码(可以放在主代码的开头)
1 | torch.manual_seed(seed) |
在Pytorch的
DataLoader
函数中填入为不同的work
设置初始化函数,确保您的dataloader在每次调用时都以相同的顺序加载样本(随机种子固定时)。如果进行裁剪或其他预处理步骤,请确保它们是确定性的。1
2
3def (worker_id):
np.random.seed(int(seed)
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True, worker_init_fn=_init_fn)向上采样和插值函数/类的向后是不确定的(请参见此处)。这意味着,如果你在训练图中使用这样的模块,无论你做什么,都永远不会得到确定性的结果。
torch.nn.ConvTranspose2d
函数是不确定的,除非你使用torch.backends.cudnn.deterministic = True
(原文中说you can try to make the operation deterministic ... by setting torch.backends.cudnn.deterministic = True
,所以这样做是否能够得到正确结果也是待定的)。
在can’t reproduce results even set all random seeds#7068 (comment2)建议采用下面方式解决:在运行任何程序之前写入下面代码(可以放在主代码的开头)
1 | def seed_torch(seed=1029): |
测试过程
相同的权重,相同的测试数据集,结果不同,可能有以下几个原因:
- 未设定
eval()
模式,因为模型中的Dropout和Batchnorm存在,导致结果不固定 - Pytorch、Python、Numpy中的随机种子没有固定,可能运行时依赖的一些第三方库有随机性
- 数据预处理方式中含有概率
- 向上采样和插值函数/类的向后是不确定的(Pytorch的问题)
代码随机种子的设定
有的时候,不同的随机种子对应的神经网络结果不同,我们并不想固定随机种子,使其能够搜索最优结果。但是又想能够根据复现最优结果,所以我们需要每次运行代码都根据当前时间设定不同的随机种子,并将随机种子保存下来。
可以使用下面代码产生随机种子,用于固定Pytorch、Python、Numpy中的随机种子,你可以将这个值保存到特定的文件中,用于之后使用。
1 | seed = int(time.time() * 256) |
Python默认随机种子
首先,确定Python随机模块所在位置
1 | import random |
例如,我这里的路径为/home/zdkit/miniconda3/lib/python3.7/random.py
。打开该文件,可以看到默认生成的随机种子采用如下方式:
1 | if a is None: |
也就是当没有给定随机种子的话,则此方法尝试使用OS提供的默认随机生成器,如果没有,则使用当前时间作为种子值。那么_urandom
是什么呢?我们继续进行探索。
1 | from os import urandom as _urandom |
上述代码得到了长度为2500的一串数字,这里我只给出了部分数字。这个随机数是由系统基于硬件中断给出的,硬件中断是非常随机的(它包括硬盘读取的中断、由用户键入的按键、移动鼠标等),所以已经很接近随机了。
参考
can’t reproduce results even set all random seeds
can’t reproduce results even set all random seeds#7068 (comment1)
can’t reproduce results even set all random seeds#7068 (comment2)