• 小白学习之pytorch框架(7)之实战Kaggle比赛:房价预测(K折交叉验证、*args、**kwargs)


    本篇博客代码来自于《动手学深度学习》pytorch版,也是代码较多,解释较少的一篇。不过好多方法在我以前的博客都有提,所以这次没提。还有一个原因是,这篇博客的代码,只要好好看看肯定能看懂(前提是python语法大概了解),这是我不加很多解释的重要原因。

    K折交叉验证实现

    def get_k_fold_data(k, i, X, y):
        # 返回第i折交叉验证时所需要的训练和验证数据,分开放,X_train为训练数据,X_valid为验证数据
        assert k > 1
        fold_size = X.shape[0] // k  # 双斜杠表示除完后再向下取整
        X_train, y_train = None, None
        for j in range(k):
            idx = slice(j * fold_size, (j + 1) * fold_size)  #slice(start,end,step)切片函数
            X_part, y_part = X[idx, :], y[idx]
            if j == i:
                X_valid, y_valid = X_part, y_part
            elif X_train is None:
                X_train, y_train = X_part, y_part
            else:
                X_train = torch.cat((X_train, X_part), dim=0) #dim=0增加行数,竖着连接
                y_train = torch.cat((y_train, y_part), dim=0)
        return X_train, y_train, X_valid, y_valid
    
    def k_fold(k, X_train, y_train, num_epochs,learning_rate, weight_decay, batch_size):
        train_l_sum, valid_l_sum = 0, 0
        for i in range(k):
            data = get_k_fold_data(k, i, X_train, y_train) # 获取k折交叉验证的训练和验证数据
            net = get_net(X_train.shape[1])  #get_net在这是一个基本的线性回归模型,方法实现见附录1
            train_ls, valid_ls = train(net, *data, num_epochs, learning_rate,
                                       weight_decay, batch_size)  #train方法见后面附录2
            train_l_sum += train_ls[-1]
            valid_l_sum += valid_ls[-1]
            if i == 0:
                d2l.semilogy(range(1, num_epochs + 1), train_ls, 'epochs', 'rmse',
                             range(1, num_epochs + 1), valid_ls,
                             ['train', 'valid'])   #画图,且是对y求对数了,x未变。方法实现见附录3
            print('fold %d, train rmse %f, valid rmse %f' % (i, train_ls[-1], valid_ls[-1]))
        return train_l_sum / k, valid_l_sum / k
    

     *args:表示接受任意长度的参数,然后存放入一个元组中;如def fun(*args) print(args),‘fruit','animal','human'作为参数传进去,输出(‘fruit','animal','human')

    **kwargs:表示接受任意长的参数,然后存放入一个字典中;如

    def fun(**kwargs):   
        for key, value in kwargs.items():
            print("%s:%s" % (key,value)
    

    fun(a=1,b=2,c=3)会输出 a=1 b=2 c=3

    附录1

    loss = torch.nn.MSELoss()
    
    def get_net(feature_num):
        net = nn.Linear(feature_num, 1)
        for param in net.parameters():
            nn.init.normal_(param, mean=0, std=0.01) 
        return net

    附录2

    def train(net, train_features, train_labels, test_features, test_labels, num_epochs, learning_rate,weight_decay, batch_size):
        train_ls, test_ls = [], []
        dataset = torch.utils.data.TensorDataset(train_features, train_labels)
        train_iter = torch.utils.data.DataLoader(dataset, batch_size, shuffle=True) #TensorDataset和DataLoader的使用请查看我以前的博客
        
        #这里使用了Adam优化算法
        optimizer = torch.optim.Adam(params=net.parameters(), lr= learning_rate, weight_decay=weight_decay)
        net = net.float()
        for epoch in range(num_epochs):
            for X, y in train_iter:
                l = loss(net(X.float()), y.float())
                optimizer.zero_grad()
                l.backward()
                optimizer.step()
            train_ls.append(log_rmse(net, train_features, train_labels))
            if test_labels is not None:
                test_ls.append(log_rmse(net, test_features, test_labels))
        return train_ls, test_ls
    

     附录3

    def semilogy(x_vals, y_vals, x_label, y_label, x2_vals=None, y2_vals=None, legend=None, figsize=(3.5, 2.5)):
        set_figsize(figsize)
        plt.xlabel(x_label)
        plt.ylabel(y_label)
        plt.semilogy(x_vals, y_vals)
        if x2_vals and y2_vals:
            plt.semilogy(x2_vals, y2_vals, linestyle=':')
            plt.legend(legend)

     注:由于最近有其他任务,所以此博客写的匆忙,等我有时间后会丰富,也可能加详细解释。

  • 相关阅读:
    Javascript 跨域知识详细介绍
    jquery对象和DOM对象的相互转换详解
    jQuery 第四章 实例方法 DOM操作之data方法
    jQuery 第四章 实例方法 DOM操作_基于jQuery对象增删改查相关方法
    jQuery 第二章 实例方法 DOM操作取赋值相关方法
    jQuery 第三章 CSS操作
    jQuery 第二章 实例方法 DOM操作选择元素相关方法
    jQuery 第一章 $()选择器
    javaScript之实战 页面筛选功能
    javaScript 二分查找
  • 原文地址:https://www.cnblogs.com/JadenFK3326/p/12164519.html
Copyright © 2020-2023  润新知