• pytorch 6 batch_train (batch training)


    import torch
    import torch.utils.data as Data
    
    torch.manual_seed(1)    # reproducible
    
    # BATCH_SIZE = 5  
    BATCH_SIZE = 8      # feed 8 samples into the network at a time
    
    x = torch.linspace(1, 10, 10)       # this is x data (torch tensor)
    y = torch.linspace(10, 1, 10)       # this is y data (torch tensor)
    
    torch_dataset = Data.TensorDataset(x, y)
    loader = Data.DataLoader(
        dataset=torch_dataset,      # torch TensorDataset format
        batch_size=BATCH_SIZE,      # mini batch size
        shuffle=False,              # do not shuffle the data between epochs
        num_workers=2,              # load the data with two subprocesses
    )
    
    
    def show_batch():
        for epoch in range(3):   # train over the entire dataset 3 times
            for step, (batch_x, batch_y) in enumerate(loader):  # for each training step
                # train your data...
                print('Epoch: ', epoch, '| Step: ', step, '| batch x: ',
                      batch_x.numpy(), '| batch y: ', batch_y.numpy())
    
    
    if __name__ == '__main__':  # the main guard is needed because num_workers > 0 starts subprocesses
        show_batch()
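
    With shuffle=False the loader returns the batches in the same order every epoch, as the printout below shows. For real training you would normally set shuffle=True so that each epoch sees the data in a fresh random order; a minimal sketch of the loader above with shuffling enabled (the earlier torch.manual_seed(1) keeps the shuffled order reproducible):

    loader = Data.DataLoader(
        dataset=torch_dataset,      # same TensorDataset as above
        batch_size=BATCH_SIZE,
        shuffle=True,               # reshuffle the data at the start of every epoch
        num_workers=2,
    )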
    

    With BATCH_SIZE = 8, all of the data is used three times:

    Epoch:  0 | Step:  0 | batch x:  [1. 2. 3. 4. 5. 6. 7. 8.] | batch y:  [10.  9.  8.  7.  6.  5.  4.  3.]
    Epoch:  0 | Step:  1 | batch x:  [ 9. 10.] | batch y:  [2. 1.]
    Epoch:  1 | Step:  0 | batch x:  [1. 2. 3. 4. 5. 6. 7. 8.] | batch y:  [10.  9.  8.  7.  6.  5.  4.  3.]
    Epoch:  1 | Step:  1 | batch x:  [ 9. 10.] | batch y:  [2. 1.]
    Epoch:  2 | Step:  0 | batch x:  [1. 2. 3. 4. 5. 6. 7. 8.] | batch y:  [10.  9.  8.  7.  6.  5.  4.  3.]
    Epoch:  2 | Step:  1 | batch x:  [ 9. 10.] | batch y:  [2. 1.]
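
    Because 10 samples do not divide evenly by BATCH_SIZE = 8, the last step of each epoch holds only the 2 leftover samples. If incomplete batches are unwanted, DataLoader's drop_last flag discards them; a minimal sketch:

    loader = Data.DataLoader(
        dataset=torch_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        drop_last=True,             # discard the final incomplete batch (samples 9 and 10 here)
    )

    With the commented-out BATCH_SIZE = 5 instead, the 10 samples would split into two full batches per epoch and nothing would be dropped.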
    


  • Original article: https://www.cnblogs.com/yangzhaonan/p/10439839.html