• 在多分类任务实验中手动实现dropout


    1 导入需要的包

    import torch
    import torch.nn as nn
    import numpy as np
    import torchvision
    import torchvision.transforms as transforms
    import matplotlib.pyplot as plt

    2 下载MNIST数据集以及读取数据

    # Download the MNIST handwritten-digit dataset (train split first, then test),
    # converting every image to a tensor with ToTensor().
    mnist_train, mnist_test = (
        torchvision.datasets.MNIST(
            root='../Datasets/MNIST',
            train=is_train_split,
            download=True,
            transform=transforms.ToTensor(),
        )
        for is_train_split in (True, False)
    )

    # Mini-batch loaders over both splits; only the training split is shuffled.
    batch_size = 256
    train_iter, test_iter = (
        torch.utils.data.DataLoader(
            split, batch_size=batch_size, shuffle=is_shuffled, num_workers=0
        )
        for split, is_shuffled in ((mnist_train, True), (mnist_test, False))
    )

    3 初始化模型参数

    # Network sizes and training hyper-parameters.
    num_inputs, num_hiddens, num_outputs = 784, 256, 10
    num_epochs = 30
    lr = 0.001


    def init_param():
        """Create freshly initialised parameters for the two-layer MLP.

        Weights are drawn from N(0, 0.01^2) and biases start at zero. Each bias
        has one entry per unit of its layer (instead of a single scalar shared
        by every unit), so each hidden/output unit can learn its own offset.

        Returns:
            (W1, b1, W2, b2): float32 tensors of shapes
            (num_hiddens, num_inputs), (num_hiddens,),
            (num_outputs, num_hiddens) and (num_outputs,),
            all with ``requires_grad=True``.
        """
        W1 = torch.tensor(np.random.normal(0, 0.01, (num_hiddens, num_inputs)), dtype=torch.float32)
        b1 = torch.zeros(num_hiddens, dtype=torch.float32)
        W2 = torch.tensor(np.random.normal(0, 0.01, (num_outputs, num_hiddens)), dtype=torch.float32)
        b2 = torch.zeros(num_outputs, dtype=torch.float32)
        for param in (W1, b1, W2, b2):
            param.requires_grad_(True)
        return W1, b1, W2, b2

    4 手动实现dropout

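    设丢弃概率为 $p$,那么隐藏单元 $h_i$ 有 $p$ 的概率被清零,有 $1-p$ 的概率除以 $1-p$ 做拉伸,即 $h_i' = \frac{\xi_i}{1-p} h_i$,其中 $\xi_i$ 以 $1-p$ 的概率取 1、以 $p$ 的概率取 0。由于 $E[\xi_i] = 1-p$,有 $E[h_i'] = h_i$,即dropout不改变输入的期望,因此测试时无需再做缩放。由此定义进行dropout操作的函数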

    def dropout(X, drop_prob):
        """Apply inverted dropout to ``X``.

        Each element is zeroed with probability ``drop_prob`` and otherwise
        scaled by ``1 / (1 - drop_prob)``, so the expected value of every
        element is unchanged and no rescaling is needed at evaluation time.

        Args:
            X: input tensor; it is converted to float32 first.
            drop_prob: probability in [0, 1] of dropping each element.

        Returns:
            A float tensor with the same shape (and device) as ``X``.

        Raises:
            ValueError: if ``drop_prob`` is outside [0, 1] (or NaN).
        """
        X = X.float()
        # Explicit check rather than ``assert``: asserts vanish under ``python -O``.
        if not 0 <= drop_prob <= 1:
            raise ValueError(f"drop_prob must be in [0, 1], got {drop_prob}")
        keep_prob = 1 - drop_prob
        if keep_prob == 0:
            # Everything is dropped; also avoids dividing by zero below.
            return torch.zeros_like(X)
        # rand_like samples on X's device/dtype, so CUDA inputs work too.
        mask = (torch.rand_like(X) < keep_prob).float()
        return mask * X / keep_prob

    5 定义模型

    def net(X, is_training=True, drop_prob=None):
        """Forward pass of the one-hidden-layer MLP used in this experiment.

        Reads the module-level parameters ``W1, b1, W2, b2`` and ``num_inputs``.

        Args:
            X: batch of inputs; flattened to shape (-1, num_inputs).
            is_training: apply dropout to the hidden layer when True.
            drop_prob: dropout probability for the hidden layer; when None the
                module-level ``drop_prob1`` is used (original behaviour).

        Returns:
            Raw class scores (logits). No ReLU is applied to the output layer:
            ``nn.CrossEntropyLoss`` expects unbounded logits, and clamping them
            at zero would prevent the model from pushing unlikely classes down.
        """
        X = X.reshape(-1, num_inputs)
        H1 = (torch.matmul(X, W1.t()) + b1).relu()
        if is_training:
            H1 = dropout(H1, drop_prob1 if drop_prob is None else drop_prob)
        return torch.matmul(H1, W2.t()) + b2

    6 定义训练函数

    def train(net,train_iter,test_iter,loss,num_epochs,batch_size,lr=None,optimizer=None):
        """Train ``net`` and record the per-sample mean train/test loss per epoch.

        Args:
            net: callable ``net(X, is_training=...)`` returning logits.
            train_iter, test_iter: iterables of ``(X, y)`` mini-batches.
            loss: loss function; assumed to return the mean loss over a batch
                (the default reduction of ``nn.CrossEntropyLoss``).
            num_epochs: number of passes over ``train_iter``.
            batch_size: unused; kept for interface compatibility.
            lr: unused; the learning rate comes from ``optimizer``.
            optimizer: a ``torch.optim`` optimizer over the model parameters.

        Returns:
            (train_ls, test_ls): lists of length ``num_epochs`` holding the
            average loss per sample on each split (NaN for an empty split), so
            the two curves are comparable even though the splits contain
            different numbers of batches.

        Raises:
            ValueError: if ``optimizer`` is None.
        """
        if optimizer is None:
            raise ValueError("train() requires an optimizer")
        train_ls, test_ls = [], []
        for epoch in range(num_epochs):
            total, count = 0.0, 0
            for X, y in train_iter:
                l = loss(net(X), y)
                optimizer.zero_grad()
                l.backward()
                optimizer.step()
                # l is a per-batch mean; weight it by the batch size so the epoch
                # value is a true per-sample average (the last batch may be smaller).
                total += l.item() * y.shape[0]
                count += y.shape[0]
            train_ls.append(total / count if count else float('nan'))
            total, count = 0.0, 0
            # Evaluation does not need an autograd graph.
            with torch.no_grad():
                for X, y in test_iter:
                    l = loss(net(X, is_training=False), y)
                    total += l.item() * y.shape[0]
                    count += y.shape[0]
            test_ls.append(total / count if count else float('nan'))
            if (epoch + 1) % 10 == 0:
                print('epoch: %d, train loss: %f, test loss: %f' % (epoch + 1, train_ls[-1], test_ls[-1]))
        return train_ls, test_ls

    7 比较不同dropout的影响

    # Train one freshly initialised model for each dropout probability
    # 0.0, 0.1, ..., 1.0. np.linspace (unlike np.arange with a float step)
    # always yields exactly 11 values ending at 1.0.
    drop_probs = np.linspace(0, 1, 11)
    loss = nn.CrossEntropyLoss()
    Train_ls, Test_ls = [], []
    for drop_prob in drop_probs:
        # net() reads drop_prob1 and W1, b1, W2, b2 as module-level globals,
        # so they must be (re)bound at this level for every run.
        drop_prob1 = float(drop_prob)
        print('drop_prob = %.1f' % drop_prob1)
        W1, b1, W2, b2 = init_param()
        # Use the shared lr hyper-parameter instead of a duplicated literal.
        optimizer = torch.optim.SGD([W1, b1, W2, b2], lr=lr)
        train_ls, test_ls = train(net, train_iter, test_iter, loss, num_epochs, batch_size, lr, optimizer)
        Train_ls.append(train_ls)
        Test_ls.append(test_ls)

    8 绘制不同丢弃概率下的训练损失曲线

    # One training-loss curve per dropout probability, plotted against epochs 1..N
    # (np.linspace(0, N, N) would place the points at non-integer positions).
    x = np.arange(1, len(train_ls) + 1)
    plt.figure(figsize=(10, 8))
    for drop_prob, curve in zip(drop_probs, Train_ls):
        plt.plot(x, curve, label='drop_prob=%.1f' % drop_prob, linewidth=1.5)
    plt.xlabel('epoch')
    plt.ylabel('loss')
    # Legend sits to the right of the axes so it never hides a curve.
    plt.legend(loc='upper left', bbox_to_anchor=(1.05, 1.0), borderaxespad=0.)
    plt.title('train loss with dropout')
    # Shrink the axes so the outside legend is not clipped off the figure.
    plt.tight_layout()
    plt.show()

    因上求缘,果上努力~~~~ 作者:每天卷学习,转载请注明原文链接:https://www.cnblogs.com/BlairGrowing/p/15513219.html

  • 相关阅读:
    Chrome 中的彩蛋,一款小游戏,你知道吗?
    Json对象与Json字符串互转(4种转换方式)
    [PHP自动化-进阶]005.Snoopy采集框架介绍
    [PHP自动化-进阶]004.Snoopy VS CURL 模拟Discuz.net登陆
    [PHP自动化-进阶]003.CURL处理Https请求访问
    [PHP自动化-进阶]002.CURL模拟登录带有验证码的网站
    [PHP自动化-进阶]001.CURL模拟登录并采集数据
    [注]2015中国程序员生存报告,你苦你先看@^@
    [JavaWeb基础] 016.Struts2 国际化配置
    [工具推荐]_iOS音频批量转换
  • 原文地址:https://www.cnblogs.com/BlairGrowing/p/15513219.html
Copyright © 2020-2023  润新知