• 神经网络学习笔记2


    预处理:将各个像素值除以255,进行了简单的正则化。

    批处理:可以减轻数据总线的负荷,相对于数据读入,可以将更多的时间用在计算上。批处理一次性计算大型数组比分开逐步计算各个小型数组要快得多。

    考虑打包输入多张图像情形,使用predict()一次性打包处理100张图像
    x形状:100*784

    for i in range(0,len(x),batch_size):
        x_batch=x[i:i+batch_size] # 通过x[i:i+batch_size]从输入数据中抽出批数据
        y_batch=predict(network,x_batch)
        p=np.argmax(y_batch,axis=1) #  通过argmax()获取值最大的元素索引
        accuracy_cnt+=np.sum(p==t[i:i+batch_size])

    range():指定为range(start,end),生成一个由start到end-1之间的整数构成的列表

    使用for语句逐一取出保存在x中的图像数据通过predict()函数进行分类。predict函数以numPy数组的形式输出各个标签对应的概率,取出概率列表中最大的值的索引作为预测结果。可以使用np.argmax(x)函数取出数组中最大值的索引。np.argmax(x)将获取被赋给参数x的数组中最大值元素的索引。最后,比较神经网络所预测的答案和正确解的标签,将回答正确的概率作为识别精度。

    通过x[i:i+batch_size]从输入数据中抽出批数据:会取出从i到第i+batch_n个之间的数据,取出样例如x[0:100],x[100:200]...从头开始以100为单位将数据提取为批数据
    然后通过argmax()函数获取值最大的元素的索引
    axis=1:指定在100*10数组中,沿着第一位方向找到值最大的元素索引

    使用批处理可以实现高速高效的运算

    完整实现代码如下:

    # coding: utf-8
    import sys, os
    sys.path.append(os.pardir)  # 为了导入父目录的文件而进行的设定
    import numpy as np
    import pickle
    from dataset.mnist import load_mnist
    from common.functions import sigmoid, softmax
    
    def get_data():
        (x_train,t_train),(x_test,t_test) = load_mnist(normalize=True,flatten=True,one_hot_label=False)
        return x_test,t_test
    
    def init_network():
        with open("sample_weight.pkl",'rb') as f:
            network =pickle.load(f)
        return network
    
    def predict(network,x):
        W1,W2,W3=network["W1"],network["W2"],network["W3"]
        b1,b2,b3=network["b1"],network["b2"],network["b3"]
        a1=np.dot(x,W1)+b1
        z1=sigmoid(a1)
        a2=np.dot(z1,W2)+b2
        z2=sigmoid(a2)
        a3=np.dot(z2,W3)+b3
        y=softmax(a3)
        return y
    
    x,t=get_data()
    network=init_network()
    batch_size=100
    accuracy_cnt=0
    
    for i in range(0,len(x),batch_size):
        x_batch=x[i:i+batch_size] # 通过x[i:i+batch_size]从输入数据中抽出批数据
        y_batch=predict(network,x_batch)
        p=np.argmax(y_batch,axis=1) #  通过argmax()获取值最大的元素索引
        accuracy_cnt+=np.sum(p==t[i:i+batch_size])
    
    print("Accuracy:"+str(float(accuracy_cnt)/len(x)))

    mini-batch 使用mini-batch进行学习


    梯度:梯度指示的方向是各个点函数值减少最多的方向

  • 相关阅读:
    Linux下命令行安装weblogic10.3.6
    11g新特性:Health Monitor Checks
    Oracle/PLSQL: ORA-06550
    DBMS_NETWORK_ACL_ADMIN
    【RDA】使用RDA(Remote Diagnostic Agent)工具对数据库进行健康检查
    ORA-39242 错误
    Yii2 中常用的增删改查操作总结
    PHP递归函数return返回null的问题
    PHP中生成随机字符串,数字+大小写字母随机组合
    使用layer.msg 时间设置不起作用
  • 原文地址:https://www.cnblogs.com/AKsnoopy/p/13467431.html
Copyright © 2020-2023  润新知