• Chainer的初步学习


      人们都说Chainer是一块非常灵活you要用的框架,今天接着项目里面的应用,初步接触一下,涨涨姿势,直接上源码吧,看着好理解。其实跟Tensorflow等其他框架都是一个套路,个人感觉更简洁了。

     1 """
     2     测试使用
     3 """
     4 import pickle
     5 import time
     6 import numpy as np
     7 import matplotlib.pyplot as plt
     8 from chainer import Chain, Variable, optimizers, serializers
     9 import chainer.functions as F
    10 import chainer.links as L
    11 
    12 # 创建Chainer Variables变]量
    13 a = Variable(np.array([3], dtype=np.float32))
    14 b = Variable(np.array([4], dtype=np.float32))
    15 c = a**2 +b**2
    16 
    17 # 5通过data属性检查之前定义的变量
    18 print('a.data:{0}, b.data{1}, c.data{2}'.format(a.data, b.data, c.data))
    19 
    20 # 使用backward()方法,对变量c进行反向传播.对c进行求导
    21 c.backward()
    22 # 通过在变量中存储的grad属性,检查其导数
    23 print('dc/da = {0}, dc/db={1}, dc/dc={2}'.format(a.grad, b.grad, c.grad))
    24 
    25 # 在chainer中做线性回归
    26 x = 30*np.random.rand(1000).astype(np.float32)
    27 y = 7*x + 10
    28 y += 10*np.random.randn(1000).astype(np.float32)
    29 
    30 plt.scatter(x, y)
    31 plt.xlabel('x')
    32 plt.ylabel('y')
    33 plt.show()
    34 
    35 
    36 # 使用chainer做线性回归
    37 
    38 # 从一个变量到另一个变量建立一个线性连接
    39 linear_function = L.Linear(1, 1)
    40 # 设置x和y作为chainer变量,以确保能够变形到特定形态
    41 x_var = Variable(x.reshape(1000, -1))
    42 y_var = Variable(y.reshape(1000, -1))
    43 # 建立优化器
    44 optimizer = optimizers.MomentumSGD(lr=0.001)
    45 optimizer.setup(linear_function)
    46 
    47 
    48 # 定义一个前向传播函数,数据作为输入,线性函数作为输出
    49 def linear_forward(data):
    50     return linear_function(data)
    51 
    52 
    53 # 定义一个训练函数,给定输入数据,目标数据,迭代数
    54 def linear_train(train_data, train_traget, n_epochs=200):
    55     for _ in range(n_epochs):
    56         # 得到前向传播结果
    57         output = linear_forward(train_data)
    58         # 计算训练目标数据和实际标数据的损失
    59         loss = F.mean_squared_error(train_traget, output)
    60         # 在更新之前将梯度取零,线性函数和梯度有非常密切的关系
    61         # linear_function.zerograds()
    62         linear_function.cleargrads()
    63         # 计算并更新所有梯度
    64         loss.backward()
    65         # 优化器更新
    66         optimizer.update()
    67 
    68 
    69 # 绘制训练结果
    70 plt.scatter(x, y, alpha=0.5)
    71 for i in range(150):
    72     # 训练
    73     linear_train(x_var, y_var, n_epochs=5)
    74     # 预测值
    75     y_pred = linear_forward(x_var).data
    76     plt.plot(x, y_pred, color=plt.cm.cool(i / 150.), alpha=0.4, lw=3)
    77 
    78 slope = linear_function.W.data[0, 0]        # linear_function是之前定义的连接,线性连接有两个参数W和b,此种形式可以获取训练后参数的值,slope是斜率的意思
    79 intercept = linear_function.b.data[0]       # intercept是截距的意思
    80 plt.title("Final Line: {0:.3}x + {1:.3}".format(slope, intercept))
    81 plt.xlabel('x')
    82 plt.ylabel('y')
    83 plt.show()
  • 相关阅读:
    SilverLight入门实例(一)
    应聘成功了,要去沪江网上班啦!
    C#中(int)、int.Parse()、int.TryParse、Convert.ToInt32数据转换注意事项
    DataTable和DataSet什么区别
    SQL 通配符
    可以把 SQL 分为两个部分:数据操作语言 (DML) 和 数据定义语言 (DDL)
    发现jquery库的关键字冲突,造成了隐形BUG!(附代码)
    《转载》微软PostScirpt打印机驱动程序原理
    在应聘工作中,不知不觉的完成了一个比较困难的小项目
    在最新的Eclipse 3.6 上配置 Java ME 的开发环境!
  • 原文地址:https://www.cnblogs.com/demo-deng/p/9713471.html
Copyright © 2020-2023  润新知