• 手动实现前馈神经网络解决多分类任务


    1 导入实验需要的包

    import torch
    import numpy as np
    import random
    from IPython import  display
    import matplotlib.pyplot as plt
    from torch.utils.data import DataLoader,TensorDataset
    from torchvision import transforms,datasets
    from torch import nn

    2 加载数据集

    # MNIST training and test splits; images are converted to tensors by
    # transforms.ToTensor() and the data is downloaded on first use.
    mnist_train = datasets.MNIST(
        root='./Datasets/MNIST/',
        train=True,
        download=True,
        transform=transforms.ToTensor(),
    )
    mnist_test = datasets.MNIST(
        root='./Datasets/MNIST/',
        train=False,
        download=True,
        transform=transforms.ToTensor(),
    )

    # Mini-batch iterators: the training set is reshuffled every epoch,
    # the test set is read in a fixed order.
    batch_size = 256
    train_iter = DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=0)
    test_iter = DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=0)

    3 初始化参数

    # Layer sizes: flattened input pixels, hidden units, output classes.
    num_input, num_hiddens, num_output = 784, 256, 10

    # Hidden layer: weights drawn from N(0, 0.01), stored as (out_features, in_features).
    W1 = torch.tensor(
        np.random.normal(0, 0.01, size=(num_hiddens, num_input)),
        dtype=torch.float32,
    )
    # One bias per hidden unit (a single shared scalar would tie all units together).
    b1 = torch.zeros(num_hiddens, dtype=torch.float32)

    # Output layer: same initialisation scheme, one bias per output class.
    W2 = torch.tensor(
        np.random.normal(0, 0.01, size=(num_output, num_hiddens)),
        dtype=torch.float32,
    )
    b2 = torch.zeros(num_output, dtype=torch.float32)

    # All trainable tensors, in the order [W1, b1, W2, b2]; enable autograd on each.
    params = [W1, b1, W2, b2]
    for param in params:
        param.requires_grad_(requires_grad=True)

    4 定义激活函数

    def ReLU(X):
        """Rectified linear unit: element-wise maximum of `X` and zero."""
        zero = torch.tensor(0.0)
        return torch.maximum(X, zero)

    5 定义网络模型

    def net(x):
        """Forward pass of the two-layer feed-forward network.

        The input is flattened to rows of `num_input` features, passed through
        the hidden layer (`W1`, `b1`) with ReLU, then through the output layer
        (`W2`, `b2`). The returned values are raw scores (logits); no softmax
        is applied here.
        """
        x = x.view(-1, num_input)
        H1 = ReLU(torch.matmul(x, W1.t()) + b1)
        # The bias is added to the layer output, not to the weight matrix:
        # `matmul(H1, W2.t() + b2)` would shift every weight instead.
        H2 = torch.matmul(H1, W2.t()) + b2
        return H2

    6 定义损失函数和优化算法

    # Multi-class cross-entropy loss. With default arguments the loss is
    # averaged over the samples in the batch.
    loss = torch.nn.CrossEntropyLoss()


    def SGD(params, lr):
        """Apply one in-place gradient-descent step: ``param -= lr * param.grad``.

        Args:
            params: iterable of tensors whose ``.grad`` has been populated by
                a preceding ``backward()`` call.
            lr: learning rate (step size).

        The gradient is not divided by the batch size: the loss defined above
        is already a batch mean, so its gradient is already averaged.
        Parameters without a gradient are left untouched.
        """
        for param in params:
            if param.grad is None:
                continue
            param.data -= lr * param.grad

    7 定义评价函数

    def evaluate_loss(data_iter, net):
        """Evaluate `net` on every batch of `data_iter`.

        Args:
            data_iter: iterable yielding ``(x, y)`` batches.
            net: callable mapping a batch ``x`` to per-class scores.

        Returns:
            ``(accuracy, mean_loss)`` where both values are averaged over the
            total number of samples seen.
        """
        acc_sum, loss_sum, n = 0, 0.0, 0
        # Evaluation needs no gradients; skip building the autograd graph.
        with torch.no_grad():
            for x, y in data_iter:
                y_pred = net(x)
                batch_n = y.shape[0]
                # `loss` (CrossEntropyLoss with default reduction) returns the
                # batch mean, so weight it by the batch size to get the summed
                # loss before dividing by the sample count below.
                loss_sum += loss(y_pred, y).item() * batch_n
                acc_sum += (y_pred.argmax(dim=1) == y).sum().item()
                n += batch_n
        return acc_sum / n, loss_sum / n
    # def evaluate_loss():
    #         n = mnist_test.data.shape[0]
    #         x = torch.tensor(mnist_test.data,dtype = torch.float32)
    #         y  = torch.tensor(mnist_test.targets,dtype = torch.float32)
    #         y_pred = net(x)
    #         acc_sum = (y_pred.argmax(dim = 1)==y).sum().item()
    #         loss_sum = loss(y_pred,mnist_test.targets).item()
    #         return acc_sum/n,loss_sum/n

    8 定义训练函数

    def train(net, train_iter, test_iter, loss, num_epochs, batch_size, lr):
        """Train `net` with mini-batch SGD and report metrics after each epoch.

        Args:
            net: callable mapping a batch of inputs to per-class scores.
            train_iter: iterable of ``(x, y)`` training batches.
            test_iter: iterable of ``(x, y)`` test batches, passed to
                ``evaluate_loss`` after every epoch.
            loss: loss callable returning the batch-mean loss.
            num_epochs: number of passes over `train_iter`.
            batch_size: unused here; kept so existing callers keep working.
            lr: learning rate forwarded to ``SGD``.

        Returns:
            ``(train_ls, test_ls)``: lists with one loss value per epoch.

        The parameters being optimised are read from the module-level
        ``params`` list.
        """
        train_ls, test_ls = [], []
        for epoch in range(num_epochs):
            train_l_sum, train_acc_num, n = 0.0, 0.0, 0
            for x, y in train_iter:
                y_pred = net(x)
                l = loss(y_pred, y)
                # Gradients accumulate across backward() calls, so clear them
                # first. Each parameter is checked individually because
                # `.grad` is None until its first backward pass.
                for param in params:
                    if param.grad is not None:
                        param.grad.data.zero_()
                l.backward()
                SGD(params, lr)
                batch_n = y.shape[0]
                # `l` is a batch mean; weight by batch size so the epoch value
                # below is a true per-sample average.
                train_l_sum += l.item() * batch_n
                train_acc_num += (y_pred.argmax(dim=1) == y).sum().item()
                n += batch_n
            train_ls.append(train_l_sum / n)
            test_acc, test_l = evaluate_loss(test_iter, net)
            test_ls.append(test_l)
            print('epoch %d, train_loss %.6f,test_loss %f,train_acc %.6f,test_acc %.6f'%(epoch+1, train_ls[epoch],test_ls[epoch],train_acc_num/n,test_acc))
        return train_ls, test_ls

    9 训练

    # Hyperparameters for this run.
    lr = 0.01
    num_epochs = 50

    # Run training and keep the per-epoch loss curves for plotting.
    train_loss, test_loss = train(
        net=net,
        train_iter=train_iter,
        test_iter=test_iter,
        loss=loss,
        num_epochs=num_epochs,
        batch_size=batch_size,
        lr=lr,
    )

    10 可视化

    # One x position per epoch, numbered 1..N. linspace(0, N, N) would place
    # the points at non-integer positions that do not match the epoch numbers.
    x = np.arange(1, len(train_loss) + 1)
    plt.plot(x, train_loss, label="train_loss", linewidth=1.5)
    plt.plot(x, test_loss, label="test_loss", linewidth=1.5)
    plt.xlabel("epoch")
    plt.ylabel("loss")
    plt.legend()
    plt.show()
  • 相关阅读:
    jwt
    mybatis的回顾
    swagger
    MySQl总结
    Java异常
    常用Dos命令
    C++初级项目——机房预约系统
    C++中将数字型字符串转变为int类型的方法
    C++中int *a; int &a; int & *a; int * &a
    #define_CRT_SECURE_NO_WARNINGS的用法
  • 原文地址:https://www.cnblogs.com/BlairGrowing/p/15970091.html
Copyright © 2020-2023  润新知