• Pytorch01_通用结构


    通用结构

    首先:导入相关库

    import numpy as np
    import pandas as pd
    import matplotlib.pyplot as plt
    import torch
    from torch import nn
    from torch.autograd import Variable
     
    plt.rcParams['font.sans-serif'] = ['SimHei'] 
    plt.rcParams['axes.unicode_minus'] = False

    step1:建立数据集,Train and Test

    step2:建立模型

    class BP(nn.Module):
        def __init__(self, input_size=576, output_size=288):
            super(BP, self).__init__()
            self.bp = nn.Sequential(
                nn.Linear(input_size, 288),
                nn.Linear(288, 144),
                nn.Linear(144, 72),
                nn.Linear(72, output_size),
            )
    
        def forward(self, x):
            result = self.bp(x)
            return result

    另外一种写法

    def create_net():
        net = nn.Sequential()
        net.add_module("linear1",nn.Linear(15,20))
        net.add_module("relu1",nn.ReLU())
        net.add_module("linear2",nn.Linear(20,15))
        net.add_module("relu2",nn.ReLU())
        net.add_module("linear3",nn.Linear(15,1))
        net.add_module("sigmoid",nn.Sigmoid())
        return net
    
    net = create_net()
    print(net)

    效果都一样

    Sequential(
    (linear1): Linear(in_features=15, out_features=20, bias=True)
    (relu1): ReLU()
    (linear2): Linear(in_features=20, out_features=15, bias=True)
    (relu2): ReLU()
    (linear3): Linear(in_features=15, out_features=1, bias=True)
    (sigmoid): Sigmoid()
    )

    step3:设置误差计算方法(均方差)及模型优化方法

    model = BP()
    criterion = nn.MSELoss() optimizer = torch.optim.Adam(model.parameters(), lr=1e-2) # 1e-2

    step4:训练

    for e in range(3000):
        var_x = Variable(Train_x)
        var_y = Variable(Train_y)
        # 前向传播
        var_x = torch.tensor(var_x, dtype=torch.float32)
        var_y = torch.tensor(var_y, dtype=torch.float32)
        out = model(var_x)
        loss = criterion(out, var_y)
        # 反向传播
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
        if (e + 1) % 10 == 0:  # 每 10 次输出结果
            print('Epoch: {}, Loss: {:.5f}'.format(e + 1, loss.item()))

    step5:使用测试集进行预测

    model = model.eval()  # 转换成测试模式
    Test_x = Variable(Test_x)
    Test_x = torch.tensor(Test_x, dtype=torch.float32)
    pred_test = model(Test_x)  # 测试集的预测结果
    # 改变输出的格式
    pred_test = pred_test.view(-1).data.numpy()

    常用损失函数总结

    nn.CrossEntropyLoss()  # 交叉熵,用于分类问题
    nn.MSELoss()  # 均方误差,拟合 回归问题

    常用优化器总结

    torch.optim.SGD   # 带动量SGD优化算法
    torch.optim.ASGD  # 表示随机平均梯度下降
    torch.optim.Adagrad  # 是自适应的为各个参数分配不同的学习率
    torch.optim.Adadelta  # Adadelta是Adagrad的改进。Adadelta分母中采用距离当前时间点比较近的累计项,这可以避免在训练后期,学习率过小。
    torch.optim.RMSprop   # 也是对Adagrad的一种改进。RMSprop采用均方根作为分母,可缓解Adagrad学习率下降较快的问题
    torch.optim.Adam(AMSGrad)  # Adam是一种自适应学习率的优化方法,Adam利用梯度的一阶矩估计和二阶矩估计动态的调整学习率
    我喜欢一致,可是世界并不一致
  • 相关阅读:
    关于打开MTK_SDCARD_SWAP 宏后MTK目前升级方案和 关于打开MTK_SHARED_SDCARD宏后MTK目前升级方案
    报表填报时,如何实现多个单元格绑定一个字段?
    双4G LTE
    报表移动端如何进行移动设备绑定与撤销
    各种卡的一些信息积累
    广佛肇城轨年内通车 佛山西站预计2017年中通车
    Web报表页面如何传递中文参数
    根据条件控制参数控件是否显示(可用)
    如何对报表的参数控件赋值
    Jquery前端分页插件pagination同步加载和异步加载
  • 原文地址:https://www.cnblogs.com/Haozi-D17/p/14866244.html
Copyright © 2020-2023  润新知