• pytorch训练过程中Loss的保存与读取、绘制Loss图


    在训练神经网络的过程中往往要定时记录Loss的值,以便查看训练过程和方便调参。一般可以借助tensorboard等工具实时地可视化Loss情况,也可以手写实时绘制Loss的函数。基于自己的需要,我要将每次训练之后的Loss保存到文件夹中之后再统一整理,因此这里总结两种保存loss到文件的方法以及读取Loss并绘图的方法。

    一、采用torch.save(tensor, 'file_name')方法:

    for epoch in range(config.NUM_EPOCH)
        #...中间略
        loss = criterion(outputs,ground_truth)  # 计算损失值
        running_loss = loss.item()  # loss累加
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()  # 反向传播后参数更新
        if i % 200 == 199:
            Loss.append(running_loss)
            print('Epoch '+str(epoch)+' : '+str(i//200)+' , LOSS ='+str(running_loss))
    Loss0 = torch.tensor(Loss)
    torch.save(Loss0,'/home/wangshuyu/MobileNet_v1/loss2/epoch_{}'.format(epoch))        

    将每个epoch中的Loss存在一个list中,最后转成tensor类型存到文件中。

    二、采用np.save('file_name', np_array)方法

    for epoch in range(config.NUM_EPOCH)
        #...中间略
        loss = criterion(outputs,ground_truth)  # 计算损失值
        running_loss = loss.item()  # loss累加
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()  # 反向传播后参数更新
        if i % 200 == 199:
            Loss.append(running_loss)
            print('Epoch '+str(epoch)+' : '+str(i//200)+' , LOSS ='+str(running_loss))
    Loss0 = np.array(Loss)
    np.save('/home/wangshuyu/MobileNet_v1/loss2/epoch_{}'.format(epoch),Loss0) 

    np.save默认保存为.npy格式

    另外,也可以使用np.savez方法将每个epoch的Loss数据压缩保存在同一个文件中(.npz文件),详情可以参考博客:numpy数据存储(save、savetxt、savez)的区别

    三、读取并绘制Loss曲线

    import matplotlib.pyplot as plt
    import torch
    import numpy as np
    
    def plot_loss(n):
        y = []
        for i in range(0,n):
            enc = np.load('D:MobileNet_v1plan1-AddsingleLayerlossepoch_{}.npy'.format(i))
            # enc = torch.load('D:MobileNet_v1plan1-AddsingleLayerlossepoch_{}'.format(i))
            tempy = list(enc)
            y += tempy
        x = range(0,len(y))
        plt.plot(x, y, '.-')
        plt_title = 'BATCH_SIZE = 32; LEARNING_RATE:0.001'
        plt.title(plt_title)
        plt.xlabel('per 200 times')
        plt.ylabel('LOSS')
        # plt.savefig(file_name)
        plt.show()
    
    if __name__ == "__main__":
        plot_loss(20)

    得到曲线如下:

  • 相关阅读:
    XOR linked list
    Doubly Linked List&DLL(双向链表)
    circular linked list&CLL(循环链表)
    Linked list(单链表)
    malloc() & free()
    Mysql 1045
    DoublePointer (**)
    java.sql.SQLException: Operation not allowed after ResultSet closed
    Array
    java-方法
  • 原文地址:https://www.cnblogs.com/nekoneko-15/p/13691338.html
Copyright © 2020-2023  润新知