• Training a Linear Regression Model with PyTorch


    Suppose the linear equation we want to fit is: y = 2x + 1

    x: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]

    y: [1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29]

    import torch
    import torch.nn as nn
    import numpy as np
    import matplotlib.pyplot as plt
    
    '''Generate the inputs and outputs'''
    x_values = [i for i in range(15)]
    x_train = np.array(x_values, dtype=np.float32)
    x_train = x_train.reshape(-1,1)
    
    y_values = [2*i+1 for i in x_values]
    y_train = np.array(y_values, dtype=np.float32)
    y_train = y_train.reshape(-1,1)
    
    '''Define the model'''
    class LinearRegressionModel(nn.Module):
        def __init__(self, input_dim, output_dim):
            super(LinearRegressionModel, self).__init__()     # call nn.Module's __init__
            self.linear = nn.Linear(input_dim, output_dim)    # the function we assume is linear
            
        def forward(self, x):
            out = self.linear(x)
            return out
        
    '''Instantiate the model, loss function, and optimizer'''
    input_dim = 1
    output_dim = 1
    model = LinearRegressionModel(input_dim, output_dim)
    criterion = nn.MSELoss()    # use mean squared error as the loss function
    
    learning_rate = 0.01
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
    
    '''Train the network'''
    epochs = 30
    for epoch in range(epochs):
        epoch += 1
        inputs = torch.from_numpy(x_train)
        labels = torch.from_numpy(y_train)
        # clear the accumulated gradients
        optimizer.zero_grad()
        # forward pass to get the outputs
        outputs = model(inputs)
        # compute the loss
        loss = criterion(outputs, labels)
        # backpropagation
        loss.backward()
        # update the parameters
        optimizer.step()
        
        print('epoch {}, loss {}'.format(epoch, loss.item()))
    

    The output is as follows:

    epoch 1, loss 290.4517517089844
    epoch 2, loss 39.308494567871094
    epoch 3, loss 5.320824146270752
    epoch 4, loss 0.721196711063385
    epoch 5, loss 0.09870971739292145
    epoch 6, loss 0.01445594523102045
    epoch 7, loss 0.003041634801775217
    epoch 8, loss 0.0014851536834612489
    epoch 9, loss 0.0012628223048523068
    epoch 10, loss 0.0012211636640131474
    epoch 11, loss 0.0012040861183777452
    epoch 12, loss 0.0011904657585546374
    epoch 13, loss 0.001177445170469582
    epoch 14, loss 0.0011646103812381625
    epoch 15, loss 0.0011519324034452438
    epoch 16, loss 0.0011393941240385175
    epoch 17, loss 0.0011269855313003063
    epoch 18, loss 0.0011147174518555403
    epoch 19, loss 0.001102585345506668
    epoch 20, loss 0.001090570935048163
    epoch 21, loss 0.0010787042556330562
    epoch 22, loss 0.0010669684270396829
    epoch 23, loss 0.0010553498286753893
    epoch 24, loss 0.001043855445459485
    epoch 25, loss 0.0010324924951419234
    epoch 26, loss 0.0010212488705292344
    epoch 27, loss 0.0010101287625730038
    epoch 28, loss 0.000999127165414393
    epoch 29, loss 0.0009882354643195868
    epoch 30, loss 0.0009774940554052591
    # we can see the loss shrinking steadily
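
    Since the target line is y = 2x + 1, the learned weight should end up close to 2 and the bias close to 1. A minimal check, assuming the trained model from the code above is still in scope:

    # inspect the learned parameters of the single linear layer
    w = model.linear.weight.item()    # should be close to 2
    b = model.linear.bias.item()      # should be close to 1
    print('learned weight: {:.4f}, learned bias: {:.4f}'.format(w, b))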
    

    Plot the results to inspect the fit:

    predicted = model(torch.from_numpy(x_train)).detach().numpy()
    
    plt.clf()
    plt.plot(x_train, y_train, 'go', label="True Value", alpha=0.5)
    
    plt.plot(x_train, predicted, '--', label='Predictions', alpha=0.5)
    
    plt.legend(loc='best')
    plt.show()
    

    The resulting plot shows the dashed prediction line lying almost on top of the true data points.
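
    To apply the fitted model to an input it has not seen, pass a new x value through it; for y = 2x + 1, an input of 20 should give roughly 41. A small sketch (the value 20 here is just an illustrative choice):

    x_new = torch.tensor([[20.0]])    # shape (1, 1), same layout as the training data
    with torch.no_grad():             # no gradient tracking needed for inference
        y_new = model(x_new)
    print(y_new.item())               # expected to be close to 2*20 + 1 = 41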

  • Original post: https://www.cnblogs.com/MartinLwx/p/10353706.html