• pytorch使用说明2


    网络保存与加载

    1.保存

    torch.manual_seed(1)    # reproducible
    
    # 假数据
    x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)  # x data (tensor), shape=(100, 1)
    y = x.pow(2) + 0.2*torch.rand(x.size())  # noisy y data (tensor), shape=(100, 1)
    
    def save():
        # 建网络
        net1 = torch.nn.Sequential(
            torch.nn.Linear(1, 10),
            torch.nn.ReLU(),
            torch.nn.Linear(10, 1)
        )
        optimizer = torch.optim.SGD(net1.parameters(), lr=0.5)
        loss_func = torch.nn.MSELoss()
    
        # 训练
        for t in range(100):
            prediction = net1(x)
            loss = loss_func(prediction, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
    torch.save(net1, 'net.pkl')  # 保存整个网络
    torch.save(net.state_dict(), 'net_params.pkl') # 只保存网络中的参数(速度快,占内存少)
    
    

    2.加载网络

    def restore_net():
        # restore entire net1 to net2
        net2 = torch.load('net.pkl')
        prediction = net2(x)
        
    
    # 只提取网络参数
    def restore_params():
        # 新建 net3
        net3 = torch.nn.Sequential(
            torch.nn.Linear(1, 10),
            torch.nn.ReLU(),
            torch.nn.Linear(10, 1)
        )
    
        # 将保存的参数复制到 net3
        net3.load_state_dict(torch.load('net_params.pkl'))
        prediction = net3(x)
       
    
    # 保存 net1 (1. 整个网络, 2. 只有参数)
    save()
    # 提取整个网络
    restore_net()
    # 提取网络参数, 复制到新网络
    restore_params()
    
    
    
    

    3.批训练

    DataLoader是torch给你用来包装你的数据的工具。所以要将自己的(numpy array或其他)数据形式转换成Tensor, 然后再放进这个包装器中。使用DataLoader的好处就是帮你有效地迭代数据。

    import torch
    import torch.utils.data as Data
    torch.manual_seed(1)    # reproducible
    
    BATCH_SIZE = 5      # 批训练的数据个数
    
    x = torch.linspace(1, 10, 10)       # x data (torch tensor)
    y = torch.linspace(10, 1, 10)       # y data (torch tensor)
    
    # 先转换成 torch 能识别的 Dataset
    torch_dataset = Data.TensorDataset(data_tensor=x, target_tensor=y)
    
    # 把 dataset 放入 DataLoader
    loader = Data.DataLoader(
        dataset=torch_dataset,      # torch TensorDataset format
        batch_size=BATCH_SIZE,      # mini batch size
        shuffle=True,               # 要不要打乱数据 (打乱比较好)
        num_workers=2,              # 多线程来读数据
    )
    
    for epoch in range(3):   # 训练所有!整套!数据 3 次
        for step, (batch_x, batch_y) in enumerate(loader):  # 每一步 loader 释放一小批数据用来学习
            # 假设这里就是你训练的地方...
    
            # 打出来一些数据
            print('Epoch: ', epoch, '| Step: ', step, '| batch x: ',
                  batch_x.numpy(), '| batch y: ', batch_y.numpy())
    
    """
    Epoch:  0 | Step:  0 | batch x:  [ 6.  7.  2.  3.  1.] | batch y:  [  5.   4.   9.   8.  10.]
    Epoch:  0 | Step:  1 | batch x:  [  9.  10.   4.   8.   5.] | batch y:  [ 2.  1.  7.  3.  6.]
    Epoch:  1 | Step:  0 | batch x:  [  3.   4.   2.   9.  10.] | batch y:  [ 8.  7.  9.  2.  1.]
    Epoch:  1 | Step:  1 | batch x:  [ 1.  7.  8.  5.  6.] | batch y:  [ 10.   4.   3.   6.   5.]
    Epoch:  2 | Step:  0 | batch x:  [ 3.  9.  2.  6.  7.] | batch y:  [ 8.  2.  9.  5.  4.]
    Epoch:  2 | Step:  1 | batch x:  [ 10.   4.   8.   1.   5.] | batch y:  [  1.   7.   3.  10.   6.]
    """
    

    当数据最后不足batch时,就会返回这个epoch中剩下的数据。

  • 相关阅读:
    AB压力测试(Windows)
    Ensure You Are Not Adding To Global Scope in JavaScript(转)
    使用jasmine来对js进行单元测试
    HTML5安全:CORS(跨域资源共享)简介(转)
    asp.net+jquery Jsonp使用方法(转)
    在ios上时间无法parse返回 "Invalid Date"(转)
    用document.domain完美解决Ajax跨子域 (转)
    IE10、IE11 User-Agent 导致的 ASP.Net 网站无法写入Cookie 问题
    NodeJs:module.filename、__filename、__dirname、process.cwd()和require.main.filename 解惑(转)
    关于反射的一些总结(转)
  • 原文地址:https://www.cnblogs.com/o-v-o/p/10946146.html
Copyright © 2020-2023  润新知