• 3、pytorch实现最基础的MLP网络


    %matplotlib inline
    import numpy as np
    import torch
    from torch import nn
    import matplotlib.pyplot as plt
    
    d = 1
    n = 200
    X = torch.rand(n,d)  #200*1, batch * feature_dim
    #y = 3*torch.sin(X) + 5* torch.cos(X**2)
    y = 4 * torch.sin(np.pi * X) * torch.cos(6*np.pi*X**2)
    
    #注意这里hid_dim 设置是超参数(如果太小,效果就不好),使用tanh还是relu效果也不同,优化器自选
    hid_dim_1 = 128
    hid_dim_2 = 32
    d_out = 1
    
    model = nn.Sequential(nn.Linear(d,hid_dim_1),
                         nn.Tanh(),
                         nn.Linear(hid_dim_1, hid_dim_2),
                         nn.Tanh(),
                         nn.Linear(hid_dim_2, d_out)
                         )
    loss_func = nn.MSELoss()
    optim = torch.optim.SGD(model.parameters(), 0.05)
    
    epochs = 6000
    print("epoch	 loss	")
    for i in range(epochs):
        y_hat = model(X)
        loss = loss_func(y_hat, y)
        optim.zero_grad()
        loss.backward()
        optim.step()
        if((i+1)%100 == 0):
            print("{}	 {:.5f}".format(i+1,loss.item()))
    
    #这个地方容易出错,测试时不要用原来的x,因为原来的x不是从小到达排序,导致x在连线时会混乱,所以要用np.linspace重新来构造
    test_x  = torch.tensor(np.linspace(0,1,50), dtype = torch.float32).reshape(-1,1)
    final_y = model(test_x)
    plt.scatter(X,y)
    plt.plot(test_x.detach(),final_y.detach(),"r")  #不使用detach会报错
    print("over")
    epoch	 loss	
    100	 3.84844
    200	 3.83552
    300	 3.78960
    400	 3.64596
    500	 3.43755
    600	 3.17153
    700	 2.59001
    800	 2.21228
    900	 1.87939
    1000	 1.55716
    1100	 1.41315
    1200	 1.26750
    1300	 1.05869
    1400	 0.91269
    1500	 0.81320
    1600	 0.74047
    1700	 0.67874
    1800	 0.61939
    1900	 0.56204
    2000	 0.51335
    2100	 0.47797
    2200	 0.45317
    2300	 0.43151
    2400	 0.40505
    2500	 0.37628
    2600	 0.34879
    2700	 0.32457
    2800	 0.30431
    2900	 0.28866
    3000	 0.30260
    3100	 0.26200
    3200	 0.30286
    3300	 0.25229
    3400	 0.21422
    3500	 0.22737
    3600	 0.22905
    3700	 0.19909
    3800	 0.24601
    3900	 0.17733
    4000	 0.22905
    4100	 0.15704
    4200	 0.21570
    4300	 0.14141
    4400	 0.14657
    4500	 0.14609
    4600	 0.11998
    4700	 0.12598
    4800	 0.10871
    4900	 0.08616
    5000	 0.18319
    5100	 0.08111
    5200	 0.08213
    5300	 0.11087
    5400	 0.06879
    5500	 0.07235
    5600	 0.11281
    5700	 0.06817
    5800	 0.08423
    5900	 0.06886
    6000	 0.06301

  • 相关阅读:
    delphi算法
    delphi 弹出选择目录窗口
    delphi 导出xml文件
    play 源码分析
    oracle指令
    delphi 环境问题
    如何启动redis
    关于整理和工作小结
    如何强制关闭服务
    delphi之事件
  • 原文地址:https://www.cnblogs.com/qiezi-online/p/13949296.html
Copyright © 2020-2023  润新知