• 【人工智能导论:模型与算法】MOOC 8.3 误差后向传播(BP) 例题 编程验证


    8.3 误差后向传播(BP)

    原理和推导过程,参考慕课。
    https://www.icourse163.org/course/ZJU-1003377027

    (2022.4.29更正:上面的计算结果W1-W4是错误的。详细说明:【人工智能导论:模型与算法】MOOC 8.3 误差后向传播(BP) 例题 【第三版】 - HBU_DAVID - 博客园 (cnblogs.com)


    输入值:x1, x2 = 0.5,0.3

    输出值:y1, y2 =0.23, -0.07

    激活函数:sigmoid

    损失函数:MSE

    初始权值:0.2 -0.4 0.5 0.6 0.1 -0.5 -0.3 0.8

    目标:通过反向传播优化权值


     

    反向传播1轮,检验PPT数值

    =====正向计算:h1, h2, o1 ,o2=====0.56 0.5 0.48 0.53

    =====损失函数:均方误差=====0.21

    =====反向传播:误差传给每个权值=====0.01 0.01 0.01 0.01 0.03 0.08 0.03 0.07

    =====更新前的权值=====0.2 -0.4 0.5 0.6 0.1 -0.5 -0.3 0.8

    =====更新后的权值=====0.19 -0.41 0.49 0.59 0.07 -0.58 -0.33 0.73

    import numpy as np
    
    
    def sigmoid(z):
        a = 1 / (1 + np.exp(-z))
        return a
    
    
    if __name__ == "__main__":
        w1 = 0.2
        w2 = -0.4
        w3 = 0.5
        w4 = 0.6
        w5 = 0.1
        w6 = -0.5
        w7 = -0.3
        w8 = 0.8
    
        x1 = 0.5
        x2 = 0.3
    
        y1 = 0.23
        y2 = -0.07
    
        print("=====输入值:x1, x2;真实输出值:y1, y2=====")
        print(x1, x2, y1, y2)
    
        in_h1 = w1 * x1 + w3 * x2
        out_h1 = sigmoid(in_h1)
        in_h2 = w2 * x1 + w4 * x2
        out_h2 = sigmoid(in_h2)
    
        in_o1 = w5 * out_h1 + w7 * out_h2
        out_o1 = sigmoid(in_o1)
        in_o2 = w6 * out_h1 + w8 * out_h2
        out_o2 = sigmoid(in_o2)
    
        print("=====正向计算:h1, h2, o1 ,o2=====")
        print(round(out_h1, 2), round(out_h2, 2), round(out_o1, 2), round(out_o2, 2))
    
        error = (1 / 2) * (out_o1 - y1)**2 + (1 / 2) * (out_o2 - y2)**2
    
        print("=====损失函数:均方误差=====")
        print(round(error, 2))
    
        # 反向传播
        d_o1 = out_o1 - y1
        d_o2 = out_o2 - y2
        # print(round(d_o1, 2), round(d_o2, 2))
    
        d_w5 = d_o1 * out_o1 * (1 - out_o1) * out_h1
        d_w7 = d_o1 * out_o1 * (1 - out_o1) * out_h2
        # print(round(d_w5, 2), round(d_w7, 2))
        d_w6 = d_o2 * out_o2 * (1 - out_o2) * out_h1
        d_w8 = d_o2 * out_o2 * (1 - out_o2) * out_h2
        # print(round(d_w6, 2), round(d_w8, 2))
    
        d_w1 = (d_w5 + d_w6) * out_h1 * (1 - out_h1) * x1
        d_w3 = (d_w5 + d_w6) * out_h1 * (1 - out_h1) * x2
        # print(round(d_w1, 2), round(d_w3, 2))
    
        d_w2 = (d_w7 + d_w8) * out_h2 * (1 - out_h2) * x1
        d_w4 = (d_w7 + d_w8) * out_h2 * (1 - out_h2) * x2
        # print(round(d_w2, 2), round(d_w4, 2))
        print("=====反向传播:误差传给每个权值=====")
        print(round(d_w1, 2), round(d_w2, 2), round(d_w3, 2), round(d_w4, 2), round(d_w5, 2), round(d_w6, 2), round(d_w7, 2),
              round(d_w8, 2))
    
        print("=====更新前的权值=====")
        print(round(w1, 2), round(w2, 2), round(w3, 2), round(w4, 2), round(w5, 2), round(w6, 2), round(w7, 2),
              round(w8, 2))
    
        w1 = w1 - d_w1
        w2 = w2 - d_w2
        w3 = w3 - d_w3
        w4 = w4 - d_w4
        w5 = w5 - d_w5
        w6 = w6 - d_w6
        w7 = w7 - d_w7
        w8 = w8 - d_w8
    
        print("=====更新后的权值=====")
        print(round(w1, 2), round(w2, 2), round(w3, 2), round(w4, 2), round(w5, 2), round(w6, 2), round(w7, 2),
              round(w8, 2))
    View Code

    增加到5轮,测试收敛

    =====第6轮=====

    正向计算:h1, h2, o1 ,o2

    0.55 0.48 0.44 0.43

    损失函数:均方误差

    0.15

    import numpy as np
    
    
    def sigmoid(z):
        a = 1 / (1 + np.exp(-z))
        return a
    
    
    def forward_propagate(x1, x2, y1, y2, w1, w2, w3, w4, w5, w6, w7, w8):
        in_h1 = w1 * x1 + w3 * x2
        out_h1 = sigmoid(in_h1)
        in_h2 = w2 * x1 + w4 * x2
        out_h2 = sigmoid(in_h2)
    
        in_o1 = w5 * out_h1 + w7 * out_h2
        out_o1 = sigmoid(in_o1)
        in_o2 = w6 * out_h1 + w8 * out_h2
        out_o2 = sigmoid(in_o2)
    
        print("正向计算:h1, h2, o1 ,o2")
        print(round(out_h1, 2), round(out_h2, 2), round(out_o1, 2), round(out_o2, 2))
    
        error = (1 / 2) * (out_o1 - y1) ** 2 + (1 / 2) * (out_o2 - y2) ** 2
    
        print("损失函数:均方误差")
        print(round(error, 2))
    
        return out_o1, out_o2, out_h1, out_h2
    
    
    def back_propagate(out_o1, out_o2, out_h1, out_h2):
        # 反向传播
        d_o1 = out_o1 - y1
        d_o2 = out_o2 - y2
        # print(round(d_o1, 2), round(d_o2, 2))
    
        d_w5 = d_o1 * out_o1 * (1 - out_o1) * out_h1
        d_w7 = d_o1 * out_o1 * (1 - out_o1) * out_h2
        # print(round(d_w5, 2), round(d_w7, 2))
        d_w6 = d_o2 * out_o2 * (1 - out_o2) * out_h1
        d_w8 = d_o2 * out_o2 * (1 - out_o2) * out_h2
        # print(round(d_w6, 2), round(d_w8, 2))
    
        d_w1 = (d_w5 + d_w6) * out_h1 * (1 - out_h1) * x1
        d_w3 = (d_w5 + d_w6) * out_h1 * (1 - out_h1) * x2
        # print(round(d_w1, 2), round(d_w3, 2))
    
        d_w2 = (d_w7 + d_w8) * out_h2 * (1 - out_h2) * x1
        d_w4 = (d_w7 + d_w8) * out_h2 * (1 - out_h2) * x2
        # print(round(d_w2, 2), round(d_w4, 2))
        print("反向传播:误差传给每个权值")
        print(round(d_w1, 2), round(d_w2, 2), round(d_w3, 2), round(d_w4, 2), round(d_w5, 2), round(d_w6, 2),
              round(d_w7, 2), round(d_w8, 2))
    
        return d_w1, d_w2, d_w3, d_w4, d_w5, d_w6, d_w7, d_w8
    
    
    if __name__ == "__main__":
        w1 = 0.2
        w2 = -0.4
        w3 = 0.5
        w4 = 0.6
        w5 = 0.1
        w6 = -0.5
        w7 = -0.3
        w8 = 0.8
        x1 = 0.5
        x2 = 0.3
        y1 = 0.23
        y2 = -0.07
        print("=====输入值:x1, x2;真实输出值:y1, y2=====")
        print(x1, x2, y1, y2)
        print("=====更新前的权值=====")
        print(round(w1, 2), round(w2, 2), round(w3, 2), round(w4, 2), round(w5, 2), round(w6, 2), round(w7, 2),
              round(w8, 2))
    
        out_o1, out_o2, out_h1, out_h2 = forward_propagate(x1, x2, y1, y2, w1, w2, w3, w4, w5, w6, w7, w8)
        d_w1, d_w2, d_w3, d_w4, d_w5, d_w6, d_w7, d_w8 = back_propagate(out_o1, out_o2, out_h1, out_h2)
    
        # 步长
        step = 1
    
        w1 = w1 - step * d_w1
        w2 = w2 - step * d_w2
        w3 = w3 - step * d_w3
        w4 = w4 - step * d_w4
        w5 = w5 - step * d_w5
        w6 = w6 - step * d_w6
        w7 = w7 - step * d_w7
        w8 = w8 - step * d_w8
    
        print("第1轮更新后的权值")
        print(round(w1, 2), round(w2, 2), round(w3, 2), round(w4, 2), round(w5, 2), round(w6, 2), round(w7, 2),
              round(w8, 2))
    
        print("=====第2轮=====")
        out_o1, out_o2, out_h1, out_h2 = forward_propagate(x1, x2, y1, y2, w1, w2, w3, w4, w5, w6, w7, w8)
        d_w1, d_w2, d_w3, d_w4, d_w5, d_w6, d_w7, d_w8 = back_propagate(out_o1, out_o2, out_h1, out_h2)
        w1 = w1 - step * d_w1
        w2 = w2 - step * d_w2
        w3 = w3 - step * d_w3
        w4 = w4 - step * d_w4
        w5 = w5 - step * d_w5
        w6 = w6 - step * d_w6
        w7 = w7 - step * d_w7
        w8 = w8 - step * d_w8
    
        print("=====第3轮=====")
        out_o1, out_o2, out_h1, out_h2 = forward_propagate(x1, x2, y1, y2, w1, w2, w3, w4, w5, w6, w7, w8)
        d_w1, d_w2, d_w3, d_w4, d_w5, d_w6, d_w7, d_w8 = back_propagate(out_o1, out_o2, out_h1, out_h2)
        w1 = w1 - step * d_w1
        w2 = w2 - step * d_w2
        w3 = w3 - step * d_w3
        w4 = w4 - step * d_w4
        w5 = w5 - step * d_w5
        w6 = w6 - step * d_w6
        w7 = w7 - step * d_w7
        w8 = w8 - step * d_w8
    
        print("=====第4轮=====")
        out_o1, out_o2, out_h1, out_h2 = forward_propagate(x1, x2, y1, y2, w1, w2, w3, w4, w5, w6, w7, w8)
        d_w1, d_w2, d_w3, d_w4, d_w5, d_w6, d_w7, d_w8 = back_propagate(out_o1, out_o2, out_h1, out_h2)
        w1 = w1 - step * d_w1
        w2 = w2 - step * d_w2
        w3 = w3 - step * d_w3
        w4 = w4 - step * d_w4
        w5 = w5 - step * d_w5
        w6 = w6 - step * d_w6
        w7 = w7 - step * d_w7
        w8 = w8 - step * d_w8
    
        print("=====第5轮=====")
        out_o1, out_o2, out_h1, out_h2 = forward_propagate(x1, x2, y1, y2, w1, w2, w3, w4, w5, w6, w7, w8)
        d_w1, d_w2, d_w3, d_w4, d_w5, d_w6, d_w7, d_w8 = back_propagate(out_o1, out_o2, out_h1, out_h2)
        w1 = w1 - step * d_w1
        w2 = w2 - step * d_w2
        w3 = w3 - step * d_w3
        w4 = w4 - step * d_w4
        w5 = w5 - step * d_w5
        w6 = w6 - step * d_w6
        w7 = w7 - step * d_w7
        w8 = w8 - step * d_w8
    
        print("=====第6轮=====")
        out_o1, out_o2, out_h1, out_h2 = forward_propagate(x1, x2, y1, y2, w1, w2, w3, w4, w5, w6, w7, w8)
        print("更新后的权值")
        print(round(w1, 2), round(w2, 2), round(w3, 2), round(w4, 2), round(w5, 2), round(w6, 2), round(w7, 2),
              round(w8, 2))
    View Code

    改变步长(1变为50),看收敛速度

    =====第6轮=====

    正向计算:o1 ,o2

    0.23 0.03

    损失函数:均方误差

    0.01

    import numpy as np
    
    
    def sigmoid(z):
        a = 1 / (1 + np.exp(-z))
        return a
    
    
    def forward_propagate(x1, x2, y1, y2, w1, w2, w3, w4, w5, w6, w7, w8):
        in_h1 = w1 * x1 + w3 * x2
        out_h1 = sigmoid(in_h1)
        in_h2 = w2 * x1 + w4 * x2
        out_h2 = sigmoid(in_h2)
    
        in_o1 = w5 * out_h1 + w7 * out_h2
        out_o1 = sigmoid(in_o1)
        in_o2 = w6 * out_h1 + w8 * out_h2
        out_o2 = sigmoid(in_o2)
    
        print("正向计算:o1 ,o2")
        print(round(out_o1, 2), round(out_o2, 2))
    
        error = (1 / 2) * (out_o1 - y1) ** 2 + (1 / 2) * (out_o2 - y2) ** 2
    
        print("损失函数:均方误差")
        print(round(error, 2))
    
        return out_o1, out_o2, out_h1, out_h2
    
    
    def back_propagate(out_o1, out_o2, out_h1, out_h2):
        # 反向传播
        d_o1 = out_o1 - y1
        d_o2 = out_o2 - y2
        # print(round(d_o1, 2), round(d_o2, 2))
    
        d_w5 = d_o1 * out_o1 * (1 - out_o1) * out_h1
        d_w7 = d_o1 * out_o1 * (1 - out_o1) * out_h2
        # print(round(d_w5, 2), round(d_w7, 2))
        d_w6 = d_o2 * out_o2 * (1 - out_o2) * out_h1
        d_w8 = d_o2 * out_o2 * (1 - out_o2) * out_h2
        # print(round(d_w6, 2), round(d_w8, 2))
    
        d_w1 = (d_w5 + d_w6) * out_h1 * (1 - out_h1) * x1
        d_w3 = (d_w5 + d_w6) * out_h1 * (1 - out_h1) * x2
        # print(round(d_w1, 2), round(d_w3, 2))
    
        d_w2 = (d_w7 + d_w8) * out_h2 * (1 - out_h2) * x1
        d_w4 = (d_w7 + d_w8) * out_h2 * (1 - out_h2) * x2
        # print(round(d_w2, 2), round(d_w4, 2))
        print("反向传播:误差传给每个权值")
        print(round(d_w1, 2), round(d_w2, 2), round(d_w3, 2), round(d_w4, 2), round(d_w5, 2), round(d_w6, 2),
              round(d_w7, 2), round(d_w8, 2))
    
        return d_w1, d_w2, d_w3, d_w4, d_w5, d_w6, d_w7, d_w8
    
    
    def update_w(w1, w2, w3, w4, w5, w6, w7, w8):
        # 步长
        step = 50
        w1 = w1 - step * d_w1
        w2 = w2 - step * d_w2
        w3 = w3 - step * d_w3
        w4 = w4 - step * d_w4
        w5 = w5 - step * d_w5
        w6 = w6 - step * d_w6
        w7 = w7 - step * d_w7
        w8 = w8 - step * d_w8
        return w1, w2, w3, w4, w5, w6, w7, w8
    
    
    if __name__ == "__main__":
        w1 = 0.2
        w2 = -0.4
        w3 = 0.5
        w4 = 0.6
        w5 = 0.1
        w6 = -0.5
        w7 = -0.3
        w8 = 0.8
        x1 = 0.5
        x2 = 0.3
        y1 = 0.23
        y2 = -0.07
        print("=====输入值:x1, x2;真实输出值:y1, y2=====")
        print(x1, x2, y1, y2)
        print("=====更新前的权值=====")
        print(round(w1, 2), round(w2, 2), round(w3, 2), round(w4, 2), round(w5, 2), round(w6, 2), round(w7, 2),
              round(w8, 2))
    
        out_o1, out_o2, out_h1, out_h2 = forward_propagate(x1, x2, y1, y2, w1, w2, w3, w4, w5, w6, w7, w8)
        d_w1, d_w2, d_w3, d_w4, d_w5, d_w6, d_w7, d_w8 = back_propagate(out_o1, out_o2, out_h1, out_h2)
        w1, w2, w3, w4, w5, w6, w7, w8 = update_w(w1, w2, w3, w4, w5, w6, w7, w8)
    
        print("第1轮更新后的权值")
        print(round(w1, 2), round(w2, 2), round(w3, 2), round(w4, 2), round(w5, 2), round(w6, 2), round(w7, 2),
              round(w8, 2))
    
        print("=====第2轮=====")
        out_o1, out_o2, out_h1, out_h2 = forward_propagate(x1, x2, y1, y2, w1, w2, w3, w4, w5, w6, w7, w8)
        d_w1, d_w2, d_w3, d_w4, d_w5, d_w6, d_w7, d_w8 = back_propagate(out_o1, out_o2, out_h1, out_h2)
        w1, w2, w3, w4, w5, w6, w7, w8 = update_w(w1, w2, w3, w4, w5, w6, w7, w8)
    
        print("=====第3轮=====")
        out_o1, out_o2, out_h1, out_h2 = forward_propagate(x1, x2, y1, y2, w1, w2, w3, w4, w5, w6, w7, w8)
        d_w1, d_w2, d_w3, d_w4, d_w5, d_w6, d_w7, d_w8 = back_propagate(out_o1, out_o2, out_h1, out_h2)
        w1, w2, w3, w4, w5, w6, w7, w8 = update_w(w1, w2, w3, w4, w5, w6, w7, w8)
    
        print("=====第4轮=====")
        out_o1, out_o2, out_h1, out_h2 = forward_propagate(x1, x2, y1, y2, w1, w2, w3, w4, w5, w6, w7, w8)
        d_w1, d_w2, d_w3, d_w4, d_w5, d_w6, d_w7, d_w8 = back_propagate(out_o1, out_o2, out_h1, out_h2)
        w1, w2, w3, w4, w5, w6, w7, w8 = update_w(w1, w2, w3, w4, w5, w6, w7, w8)
    
        print("=====第5轮=====")
        out_o1, out_o2, out_h1, out_h2 = forward_propagate(x1, x2, y1, y2, w1, w2, w3, w4, w5, w6, w7, w8)
        d_w1, d_w2, d_w3, d_w4, d_w5, d_w6, d_w7, d_w8 = back_propagate(out_o1, out_o2, out_h1, out_h2)
        w1, w2, w3, w4, w5, w6, w7, w8 = update_w(w1, w2, w3, w4, w5, w6, w7, w8)
    
        print("=====第6轮=====")
        out_o1, out_o2, out_h1, out_h2 = forward_propagate(x1, x2, y1, y2, w1, w2, w3, w4, w5, w6, w7, w8)
        print("更新后的权值")
        print(round(w1, 2), round(w2, 2), round(w3, 2), round(w4, 2), round(w5, 2), round(w6, 2), round(w7, 2),
              round(w8, 2))
    View Code

    扩展到N轮,步长=5,训练N=1000次,查看效果

    =====第999轮=====

    正向计算:o1 ,o2

    0.23038 0.00954

    损失函数:均方误差

    0.00316

     

    import numpy as np
    
    
    def sigmoid(z):
        a = 1 / (1 + np.exp(-z))
        return a
    
    
    def forward_propagate(x1, x2, y1, y2, w1, w2, w3, w4, w5, w6, w7, w8):
        in_h1 = w1 * x1 + w3 * x2
        out_h1 = sigmoid(in_h1)
        in_h2 = w2 * x1 + w4 * x2
        out_h2 = sigmoid(in_h2)
    
        in_o1 = w5 * out_h1 + w7 * out_h2
        out_o1 = sigmoid(in_o1)
        in_o2 = w6 * out_h1 + w8 * out_h2
        out_o2 = sigmoid(in_o2)
    
        print("正向计算:o1 ,o2")
        print(round(out_o1, 5), round(out_o2, 5))
    
        error = (1 / 2) * (out_o1 - y1) ** 2 + (1 / 2) * (out_o2 - y2) ** 2
    
        print("损失函数:均方误差")
        print(round(error, 5))
    
        return out_o1, out_o2, out_h1, out_h2
    
    
    def back_propagate(out_o1, out_o2, out_h1, out_h2):
        # 反向传播
        d_o1 = out_o1 - y1
        d_o2 = out_o2 - y2
        # print(round(d_o1, 2), round(d_o2, 2))
    
        d_w5 = d_o1 * out_o1 * (1 - out_o1) * out_h1
        d_w7 = d_o1 * out_o1 * (1 - out_o1) * out_h2
        # print(round(d_w5, 2), round(d_w7, 2))
        d_w6 = d_o2 * out_o2 * (1 - out_o2) * out_h1
        d_w8 = d_o2 * out_o2 * (1 - out_o2) * out_h2
        # print(round(d_w6, 2), round(d_w8, 2))
    
        d_w1 = (d_w5 + d_w6) * out_h1 * (1 - out_h1) * x1
        d_w3 = (d_w5 + d_w6) * out_h1 * (1 - out_h1) * x2
        # print(round(d_w1, 2), round(d_w3, 2))
    
        d_w2 = (d_w7 + d_w8) * out_h2 * (1 - out_h2) * x1
        d_w4 = (d_w7 + d_w8) * out_h2 * (1 - out_h2) * x2
        # print(round(d_w2, 2), round(d_w4, 2))
        print("反向传播:误差传给每个权值")
        print(round(d_w1, 5), round(d_w2, 5), round(d_w3, 5), round(d_w4, 5), round(d_w5, 5), round(d_w6, 5),
              round(d_w7, 5), round(d_w8, 5))
    
        return d_w1, d_w2, d_w3, d_w4, d_w5, d_w6, d_w7, d_w8
    
    
    def update_w(w1, w2, w3, w4, w5, w6, w7, w8):
        # 步长
        step = 5
        w1 = w1 - step * d_w1
        w2 = w2 - step * d_w2
        w3 = w3 - step * d_w3
        w4 = w4 - step * d_w4
        w5 = w5 - step * d_w5
        w6 = w6 - step * d_w6
        w7 = w7 - step * d_w7
        w8 = w8 - step * d_w8
        return w1, w2, w3, w4, w5, w6, w7, w8
    
    
    if __name__ == "__main__":
        w1, w2, w3, w4, w5, w6, w7, w8 = 0.2, -0.4, 0.5, 0.6, 0.1, -0.5, -0.3, 0.8
        x1, x2 = 0.5, 0.3
        y1, y2 = 0.23, -0.07
        print("=====输入值:x1, x2;真实输出值:y1, y2=====")
        print(x1, x2, y1, y2)
        print("=====更新前的权值=====")
        print(round(w1, 2), round(w2, 2), round(w3, 2), round(w4, 2), round(w5, 2), round(w6, 2), round(w7, 2),
              round(w8, 2))
    
        for i in range(1000):
            print("=====第" + str(i) + "轮=====")
            out_o1, out_o2, out_h1, out_h2 = forward_propagate(x1, x2, y1, y2, w1, w2, w3, w4, w5, w6, w7, w8)
            d_w1, d_w2, d_w3, d_w4, d_w5, d_w6, d_w7, d_w8 = back_propagate(out_o1, out_o2, out_h1, out_h2)
            w1, w2, w3, w4, w5, w6, w7, w8 = update_w(w1, w2, w3, w4, w5, w6, w7, w8)
    
        print("更新后的权值")
        print(round(w1, 2), round(w2, 2), round(w3, 2), round(w4, 2), round(w5, 2), round(w6, 2), round(w7, 2),
              round(w8, 2))
    View Code

    修改输出值y2为正,收敛效果很好。

    原因是:sigmoid,输出值应在(0,1)区间,所以最开始的假设 y2=-0.07,在这个模型里,无法很好的拟合。


    优化后的源代码:

    import numpy as np
    import matplotlib.pyplot as plt


    def sigmoid(z):
    a = 1 / (1 + np.exp(-z))
    return a


    def forward_propagate(x1, x2, y1, y2, w1, w2, w3, w4, w5, w6, w7, w8): # 正向传播
    in_h1 = w1 * x1 + w3 * x2
    out_h1 = sigmoid(in_h1)
    in_h2 = w2 * x1 + w4 * x2
    out_h2 = sigmoid(in_h2)

    in_o1 = w5 * out_h1 + w7 * out_h2
    out_o1 = sigmoid(in_o1)
    in_o2 = w6 * out_h1 + w8 * out_h2
    out_o2 = sigmoid(in_o2)

    error = (1 / 2) * (out_o1 - y1) ** 2 + (1 / 2) * (out_o2 - y2) ** 2

    return out_o1, out_o2, out_h1, out_h2, error


    def back_propagate(out_o1, out_o2, out_h1, out_h2): # 反向传播
    d_o1 = out_o1 - y1
    d_o2 = out_o2 - y2

    d_w5 = d_o1 * out_o1 * (1 - out_o1) * out_h1
    d_w7 = d_o1 * out_o1 * (1 - out_o1) * out_h2
    d_w6 = d_o2 * out_o2 * (1 - out_o2) * out_h1
    d_w8 = d_o2 * out_o2 * (1 - out_o2) * out_h2

    d_w1 = (d_w5 + d_w6) * out_h1 * (1 - out_h1) * x1
    d_w3 = (d_w5 + d_w6) * out_h1 * (1 - out_h1) * x2
    d_w2 = (d_w7 + d_w8) * out_h2 * (1 - out_h2) * x1
    d_w4 = (d_w7 + d_w8) * out_h2 * (1 - out_h2) * x2

    return d_w1, d_w2, d_w3, d_w4, d_w5, d_w6, d_w7, d_w8


    def update_w(step,w1, w2, w3, w4, w5, w6, w7, w8): #梯度下降,更新权值
    w1 = w1 - step * d_w1
    w2 = w2 - step * d_w2
    w3 = w3 - step * d_w3
    w4 = w4 - step * d_w4
    w5 = w5 - step * d_w5
    w6 = w6 - step * d_w6
    w7 = w7 - step * d_w7
    w8 = w8 - step * d_w8
    return w1, w2, w3, w4, w5, w6, w7, w8


    if __name__ == "__main__":
    w1, w2, w3, w4, w5, w6, w7, w8 = 0.2, -0.4, 0.5, 0.6, 0.1, -0.5, -0.3, 0.8 # 可以给随机值,为配合PPT,给的指定值
    x1, x2 = 0.5, 0.3 # 输入值
    y1, y2 = 0.23, -0.07 # 正数可以准确收敛;负数不行。why? 因为用sigmoid输出,y1, y2 在 (0,1)范围内。
    N = 10 # 迭代次数
    step = 10 # 步长

    print("输入值:x1, x2;",x1, x2, "输出值:y1, y2:", y1, y2)
    eli = []
    lli = []
    for i in range(N):
    print("=====第" + str(i) + "轮=====")
    # 正向传播
    out_o1, out_o2, out_h1, out_h2, error = forward_propagate(x1, x2, y1, y2, w1, w2, w3, w4, w5, w6, w7, w8)
    print("正向传播:", round(out_o1, 5), round(out_o2, 5))
    print("损失函数:", round(error, 2))
    # 反向传播
    d_w1, d_w2, d_w3, d_w4, d_w5, d_w6, d_w7, d_w8 = back_propagate(out_o1, out_o2, out_h1, out_h2)
    # 梯度下降,更新权值
    w1, w2, w3, w4, w5, w6, w7, w8 = update_w(step,w1, w2, w3, w4, w5, w6, w7, w8)
    eli.append(i)
    lli.append(error)


    plt.plot(eli, lli)
    plt.ylabel('Loss')
    plt.xlabel('w')
    plt.show()

  • 相关阅读:
    Python网页信息采集:使用PhantomJS采集淘宝天猫商品内容
    让Scrapy的Spider更通用
    API例子:用Python驱动Firefox采集网页数据
    API例子:用Java/JavaScript下载内容提取器
    Python即时网络爬虫:API说明
    Python: xml转json
    git 更新本地代码
    数据库事务
    Python的线程、进程和协程
    Java基础语法
  • 原文地址:https://www.cnblogs.com/hbuwyg/p/16166814.html
Copyright © 2020-2023  润新知