• 梯度下降代码


    https://blog.csdn.net/panghaomingme/article/details/79384922

     1 #!usr/bin/python3
     2 # coding:utf-8
     3  
     4 # BGD 批梯度下降代码实现
     5 # SGD 随机梯度下降代码实现
     6 import numpy as np
     7  
     8 import random
     9  
    10  
    11 def batchGradientDescent(x, y, theta, alpha, m, maxInteration):
    12     x_train = x.transpose()
    13     for i in range(0, maxInteration):
    14         hypothesis = np.dot(x, theta)
    15         # 损失函数
    16         loss = hypothesis - y
    17         # 下降梯度
    18         gradient = np.dot(x_train, loss) / m
    19         # 求导之后得到theta
    20         theta = theta - alpha * gradient
    21     return theta
    22  
    23  
    24 def stochasticGradientDescent(x, y, theta, alpha, m, maxInteration):
    25     data = []
    26     for i in range(4):
    27         data.append(i)
    28     x_train = x.transpose()
    29     for i in range(0, maxInteration):
    30         hypothesis = np.dot(x, theta)
    31         # 损失函数
    32         loss = hypothesis - y
    33         # 选取一个随机数
    34         index = random.sample(data, 1)
    35         index1 = index[0]
    36         # 下降梯度
    37         gradient = loss[index1] * x[index1]
    38         # 求导之后得到theta
    39         theta = theta - alpha * gradient
    40     return theta
    41  
    42  
    43 def main():
    44     trainData = np.array([[1, 4, 2], [2, 5, 3], [5, 1, 6], [4, 2, 8]])
    45     trainLabel = np.array([19, 26, 19, 20])
    46     print(trainData)
    47     print(trainLabel)
    48     m, n = np.shape(trainData)
    49     theta = np.ones(n)
    50     print(theta.shape)
    51     maxInteration = 500
    52     alpha = 0.01
    53     theta1 = batchGradientDescent(trainData, trainLabel, theta, alpha, m, maxInteration)
    54     print(theta1)
    55     theta2 = stochasticGradientDescent(trainData, trainLabel, theta, alpha, m, maxInteration)
    56     print(theta2)
    57     return
    58  
    59  
    60 if __name__ == "__main__":
    61     main()
  • 相关阅读:
    $.extend 的相关用法
    boxsizing
    用localStorage来存储数据的一些经验
    让input光标一直在最右边
    函数声明和函数表达式的区别
    css动画和jq动画的简单区分
    apply与call简单用法以及判断数组的坑
    replace的运用
    onscroll事件没有响应的原因以及vue.js中添加onscroll事件监听的方法
    解决移动端touch事件(touchstart/touchend) 的穿透问题
  • 原文地址:https://www.cnblogs.com/zhangbojiangfeng/p/9474298.html
Copyright © 2020-2023  润新知