• pytorch note


    1.模型保存与加载

    1.1

    #a、保存 推荐仅仅保存模型的state_dict
    torch.save(model.state_dict(), MODELPATH) # .pt  .pth
    #b、加载
    model = TheModelClass(*args, **kwargs)
    model.load_state_dict(torch.load(PATH))
    model.eval()
    #Pytorch保存的模型后缀一般是.pt或者.pth
    #必须在加载模型后调用model.eval函数来将dropout及批归一化层设置为预测模式。如果不这么做结果出错。
    

    1.2 a、保存临时模型用于预测或再训练

    torch.save({
     'epoch': epoch,
     'model_state_dict': model.state_dict(),
     'optimizer_state_dict': optimizer.state_dict(),
     'loss': loss, ... },
     PATH)
    

    当保存一个临时模型用于预测或再训练时,需要保存比state_dict更多的参数。包括优化器的state_dict,迭代次数epoch,最后一层迭代的loss及其他任何需要的参数。
     当保存多个组件时,将多个组件以字典的形式组织,然后用torch.savee()来序列化该字典。在Pytorch中常用.tar文件后缀表示这种模型。
    b、加载

    model = TheModelClass(*args, **kwargs) 
    optimizer = TheOptimizerClass(*args, **kwargs)
     checkpoint = torch.load(PATH)
     model.load_state_dict(checkpoint['model_state_dict'])
     optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
     epoch = checkpoint['epoch']
     loss = checkpoint['loss']
     model.eval() #预测 # - or - model.train() #再训练
    
    

    e.g.

    save_checkpoint({
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'lr': args.lr,
                    'optimizer' : optimizer.state_dict(),
                }, checkpoint=args.checkpoint)
    

    2.maxpool 并记录最大位置 F.max_pool2d

        input_tensor = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,14,15,16]).reshape(4,4).float()
        input_tensor = input_tensor.unsqueeze(0)
        pool1, pool1_mask = F.max_pool2d(input_tensor, kernel_size=2, stride=2, return_indices=True)
        print(pool1)
        print(pool1_mask)
    

    输出如下:
    tensor([[[ 6., 8.],
    [14., 16.]]])
    tensor([[[ 5, 7],
    [13, 15]]])

  • 相关阅读:
    二叉树
    队列
    python3使用pdfminer3k解析pdf文件
    得到手机版新闻解析
    python连接redis并插入url
    Python使用requirements.txt安装类库
    (1366, "Incorrect string value: '\xF3\xB0\x84\xBC</...' for column 'content' at row 1")
    mysql中Incorrect string value乱码问题解决方案
    mysql命令
    requests ip代理池单ip和多ip设置方式
  • 原文地址:https://www.cnblogs.com/yanghailin/p/11607080.html
Copyright © 2020-2023  润新知