• 机器学习: Python with Recurrent Neural Network


    之前我们介绍了Recurrent neural network (RNN) 的原理:

    http://blog.csdn.net/matrix_space/article/details/53374040

    http://blog.csdn.net/matrix_space/article/details/53376870  

    这里,我们构建一个简单的RNN网络,激活函数使用sigmoid函数,并利用这个网络学习二进制数的加法运算:每个时间步输入两个加数的同一位(从最低位开始),输出和的对应位。网络重复模块的表达式是:

    $$h_t = \sigma(W_h h_{t-1} + W_i X_t)$$

    $$o_t = \sigma(W_o h_t)$$

    $$e = \frac{1}{2}(y_t - o_t)^2$$

    import copy, numpy as np
    
    # Fix NumPy's global RNG seed so the weight initialization and the sampled
    # training operands below are reproducible from run to run.
    np.random.seed(0)
    
    def sigmoid(x):
        """Return the logistic sigmoid 1 / (1 + exp(-x)), element-wise.

        Used as the activation of both the hidden and the output layer.
        Accepts a Python scalar or a NumPy array.
        """
        exp_neg = np.exp(-x)
        return 1 / (1 + exp_neg)
    
    def sigmoid_output_to_derivative(output):
        """Return the sigmoid's derivative, given the sigmoid's *output* s.

        Uses the identity sigmoid'(x) = s * (1 - s), so only the already
        computed activation is needed, not the pre-activation x.
        """
        one_minus = 1 - output
        return output * one_minus
    
    # Training-data lookup table: every integer in [0, 2**binary_dim) maps to
    # its bit vector, most significant bit first, as a uint8 array.
    binary_dim = 8
    max_number = 2 ** binary_dim
    # Build the table with shifts instead of np.unpackbits on a uint8 array.
    # unpackbits always yields exactly 8 bits, and uint8 cannot even hold
    # values >= 256, so the old approach broke as soon as binary_dim != 8.
    # For binary_dim == 8 this produces the identical table.
    _bit_shifts = np.arange(binary_dim - 1, -1, -1)
    binary = ((np.arange(max_number)[:, None] >> _bit_shifts) & 1).astype(np.uint8)
    # binary holds all max_number bit vectors; row i is the encoding of i.
    int2binary = dict(enumerate(binary))

    # Hyper-parameters.
    alpha = 0.1       # learning rate
    input_dim = 2     # one bit from each operand per time step
    hidden_dim = 16   # size of the recurrent hidden state
    output_dim = 1    # one bit of the sum per time step

    # Weights drawn uniformly from [-1, 1). Keep the draw order (0, 1, h)
    # unchanged so seeded runs stay reproducible.
    synapse_0 = 2 * np.random.random((input_dim, hidden_dim)) - 1   # input  -> hidden
    synapse_1 = 2 * np.random.random((hidden_dim, output_dim)) - 1  # hidden -> output
    synapse_h = 2 * np.random.random((hidden_dim, hidden_dim)) - 1  # hidden -> hidden

    # Gradient accumulators; filled over one sequence, applied, then zeroed.
    synapse_0_update = np.zeros_like(synapse_0)
    synapse_1_update = np.zeros_like(synapse_1)
    synapse_h_update = np.zeros_like(synapse_h)
    
    # Train on one randomly drawn addition problem per iteration, using
    # backpropagation through time (BPTT) over the binary_dim bit positions.
    for j in range(10000):
        # Draw two operands in [0, max_number // 2) so their sum still fits in
        # binary_dim bits, and look up their bit vectors (MSB first).
        # Floor division keeps the bound an int under Python 3, where
        # max_number / 2 would be a float.
        a_int = np.random.randint(max_number // 2)
        a = int2binary[a_int]

        b_int = np.random.randint(max_number // 2)
        b = int2binary[b_int]

        # Target: the true sum.
        c_int = a_int + b_int
        c = int2binary[c_int]

        # The network's bitwise prediction, filled in during the forward pass.
        d = np.zeros_like(c)

        overallError = 0

        # Per-time-step output deltas and hidden states, kept for BPTT.
        # layer_1_values starts with an all-zero "previous" hidden state.
        layer_2_deltas = list()
        layer_1_values = list()
        layer_1_values.append(np.zeros(hidden_dim))

        # Forward pass: feed bits from least significant (index
        # binary_dim - 1) to most significant (index 0), the order in which
        # carries propagate in addition.
        for position in range(binary_dim):
            idx = binary_dim - position - 1

            # Input: the current bit of a and b; target: the same bit of c.
            X = np.array([[a[idx], b[idx]]])
            y = np.array([[c[idx]]]).T

            # Hidden state from the current input and the previous hidden state.
            layer_1 = sigmoid(np.dot(X, synapse_0) + np.dot(layer_1_values[-1], synapse_h))

            # Output-bit activation.
            layer_2 = sigmoid(np.dot(layer_1, synapse_1))

            # Output error and delta (error times sigmoid derivative).
            layer_2_error = y - layer_2
            layer_2_deltas.append((layer_2_error) * sigmoid_output_to_derivative(layer_2))
            overallError += np.abs(layer_2_error[0])

            # Record the rounded prediction for this bit.
            d[idx] = np.round(layer_2[0][0])

            layer_1_values.append(copy.deepcopy(layer_1))

        future_layer_1_delta = np.zeros(hidden_dim)

        # Backward pass: walk the time steps in reverse. Step `position` here
        # corresponds to forward step binary_dim - 1 - position, whose input
        # bits were a[position] and b[position].
        for position in range(binary_dim):

            X = np.array([[a[position], b[position]]])
            layer_1 = layer_1_values[-position - 1]
            pre_layer_1 = layer_1_values[-position - 2]

            layer_2_delta = layer_2_deltas[-position - 1]

            # The hidden delta combines the gradient from this step's output
            # with the next time step's hidden delta through the recurrent
            # weights.
            layer_1_delta = (future_layer_1_delta.dot(synapse_h.T) + layer_2_delta.dot(
                synapse_1.T)) * sigmoid_output_to_derivative(layer_1)

            # Accumulate weight updates over the whole sequence.
            synapse_1_update += np.atleast_2d(layer_1).T.dot(layer_2_delta)
            synapse_h_update += np.atleast_2d(pre_layer_1).T.dot(layer_1_delta)
            synapse_0_update += X.T.dot(layer_1_delta)

            future_layer_1_delta = layer_1_delta

        # Apply the accumulated updates. The error is y - output, so adding the
        # update moves the weights down the gradient of the squared error.
        synapse_0 += synapse_0_update * alpha
        synapse_1 += synapse_1_update * alpha
        synapse_h += synapse_h_update * alpha

        synapse_0_update *= 0
        synapse_1_update *= 0
        synapse_h_update *= 0

        # Print progress every 500 iterations.
        if (j % 500 == 0):
            print("Error: ", str(overallError))
            print("Pred:", str(d))
            print("True:", str(c))
            # Decode the predicted bit vector (MSB first) back into an integer.
            out = 0
            for index, x in enumerate(reversed(d)):
                out += x*pow(2, index)
            print(str(a_int) + "+" + str(b_int) + "=" + str(out))
            print("---------------")
    

    运行结果:

    ('Error: ', '[ 3.45638663]')
    ('Pred:', '[0 0 0 0 0 0 0 1]')
    ('True:', '[0 1 0 0 0 1 0 1]')
    9+60=1
    ---------------
    ('Error: ', '[ 4.02253884]')
    ('Pred:', '[0 1 1 0 1 0 1 1]')
    ('True:', '[1 0 0 0 0 0 0 1]')
    112+17=107
    ---------------
    ('Error: ', '[ 3.63389116]')
    ('Pred:', '[1 1 1 1 1 1 1 1]')
    ('True:', '[0 0 1 1 1 1 1 1]')
    28+35=255
    ---------------
    ('Error: ', '[ 3.99234598]')
    ('Pred:', '[1 1 0 1 1 0 1 0]')
    ('True:', '[1 0 1 1 0 0 1 1]')
    78+101=218
    ---------------
    ('Error: ', '[ 3.91366595]')
    ('Pred:', '[0 1 0 0 1 0 0 0]')
    ('True:', '[1 0 1 0 0 0 0 0]')
    116+44=72
    ---------------
    ('Error: ', '[ 3.65154804]')
    ('Pred:', '[1 1 0 1 1 0 1 0]')
    ('True:', '[1 1 0 1 1 1 1 0]')
    122+100=218
    ---------------
    ('Error: ', '[ 3.72191702]')
    ('Pred:', '[1 1 0 1 1 1 1 1]')
    ('True:', '[0 1 0 0 1 1 0 1]')
    4+73=223
    ---------------
    ('Error: ', '[ 3.35048888]')
    ('Pred:', '[1 0 0 1 1 0 0 1]')
    ('True:', '[1 0 0 1 0 0 0 1]')
    76+69=153
    ---------------
    ('Error: ', '[ 3.5852713]')
    ('Pred:', '[0 0 0 0 1 0 0 0]')
    ('True:', '[0 1 0 1 0 0 1 0]')
    71+11=8
    ---------------
    ('Error: ', '[ 2.43239777]')
    ('Pred:', '[0 1 1 0 1 0 1 1]')
    ('True:', '[0 1 1 0 1 0 1 1]')
    72+35=107
    ---------------
    ('Error: ', '[ 2.53352328]')
    ('Pred:', '[1 0 1 0 0 0 1 0]')
    ('True:', '[1 1 0 0 0 0 1 0]')
    81+113=162
    ---------------
    ('Error: ', '[ 1.87382863]')
    ('Pred:', '[0 1 1 0 0 0 1 0]')
    ('True:', '[0 1 1 0 0 0 1 0]')
    21+77=98
    ---------------
    ('Error: ', '[ 0.57691441]')
    ('Pred:', '[0 1 0 1 0 0 0 1]')
    ('True:', '[0 1 0 1 0 0 0 1]')
    81+0=81
    ---------------
    ('Error: ', '[ 0.75100965]')
    ('Pred:', '[0 0 1 1 1 1 0 0]')
    ('True:', '[0 0 1 1 1 1 0 0]')
    49+11=60
    ---------------
    ('Error: ', '[ 1.42589952]')
    ('Pred:', '[1 0 0 0 0 0 0 1]')
    ('True:', '[1 0 0 0 0 0 0 1]')
    4+125=129
    ---------------
    ('Error: ', '[ 0.6594703]')
    ('Pred:', '[0 1 1 0 1 1 0 0]')
    ('True:', '[0 1 1 0 1 1 0 0]')
    80+28=108
    ---------------
    ('Error: ', '[ 0.47477457]')
    ('Pred:', '[0 0 1 1 1 0 0 0]')
    ('True:', '[0 0 1 1 1 0 0 0]')
    39+17=56
    ---------------
    ('Error: ', '[ 0.7200904]')
    ('Pred:', '[1 0 1 0 1 0 0 0]')
    ('True:', '[1 0 1 0 1 0 0 0]')
    123+45=168
    ---------------
    ('Error: ', '[ 0.21595037]')
    ('Pred:', '[0 0 0 0 1 1 1 0]')
    ('True:', '[0 0 0 0 1 1 1 0]')
    11+3=14
    ---------------
    ('Error: ', '[ 0.52112049]')
    ('Pred:', '[1 0 1 0 1 0 1 1]')
    ('True:', '[1 0 1 0 1 0 1 1]')
    71+100=171
    ---------------

    参考来源:

    https://github.com/llSourcell/recurrent_neural_net_demo

  • 相关阅读:
    sqlserver如何查询一个表的主键都是哪些表的外键
    sql server nullif的使用技巧,除数为零的处理技巧
    如何解决数据库中,数字+null=null
    sql server truncate table 删除表数据限制条件
    eclipse需要的环境变量就两个,一个是java_home指向JDK。另一个是Tomcat,自己去preference-server下new一个
    解释Eclipse下Tomcat项目部署路径问题(.metadata\.plugins\org.eclipse.wst.server.core\tmp0\wtpwebapps)
    mysql登录退出命令
    代码svn下载到本地后,关于数据库问题
    MySQL配置文件详解
    mysql查看存储过程show procedure status;
  • 原文地址:https://www.cnblogs.com/mtcnn/p/9412437.html
Copyright © 2020-2023  润新知