• Theano编写分类神经网络


    Theano编写分类神经网络

    1.导入模块并创建数据

    #!/usr/bin/env python2
    # -*- coding: utf-8 -*-
    """
    theano de classify
    """
    
    import numpy as np
    import theano.tensor as T
    import theano
    
    #计算分类准确率
    def compute_accuracy(y_target, y_predict):
        correct_prediction = np.equal(y_predict, y_target)
        accuracy = np.sum(correct_prediction)/len(correct_prediction)
        return accuracy
    
    
    #训练数据的个数
    N = 400
    #训练数据的特征数
    feats = 784
    
    #生成随机数
    D = (np.random.randn(N,feats), np.random.randint(size = N, low = 0, high =2))
    
    print(D)

    2.建立模型

    #构建神经网络
    #定义x y容器, 相当于tensorflow中的placeholder
    x = T.dmatrix("x")
    y = T.dvector("y")
    
    #初始化weights和bias, weights的数量和features一样
    
    W = theano.shared(np.random.randn(feats), name =  'w')
    b = theano.shared(0., name='b')
    
    #定义激活函数(sigmoid), 加入l1正则化
    p_1 = T.nnet.sigmoid(T.dot(x,W) + b)
    #sigmoid值大于0.5为true
    prediction = p_1 > 0.5
    #定义交叉熵损失函数
    xent = -y * T.log(p_1) - (1-y) * T.log(1-p_1) 
    
    #加入l2正则化,减少过拟合
    cost = xent.mean() + 0.01 * (W**2).sum()
    #定义梯度迭代的gW, gb,用于更新参数
    gW, gb = T.grad(cost, [W, b])

    3.激活模型

    #激活神经网络
    learning_rate = 0.1
    train = theano.function(
            inputs = [x, y],
            outputs = [prediction, xent.mean()],
            updates = ((W, W - learning_rate* gW), (b, b - learning_rate * gb))
            )
    
    predict = theano.function(inputs = [x], outputs = prediction)

    4.训练模型

    #训练模型
    for i in range(500):
        pred, err = train(D[0], D[1])
        if i % 50 ==0:
            print('cost', err)
            print('accuracy', compute_accuracy(D[1], predict(D[0])))
    
    print("target values for D:")
    print(D[1])
    print("prediction on D:")
    print(predict(D[0]))
  • 相关阅读:
    Python 3基础教程1-环境安装和运行环境
    Python+Selenium框架设计篇之-什么是POM
    CS/BS架构的特点
    HTTP 方法:GET 对比 POST
    apt的通讯信道是如何发现的?
    安全运维到运营
    僵尸网络是如何控制你的电脑的?
    了解资产,减少网络攻击面的威胁
    ISC安全大会感受
    mysql 在忘记密码不能登录情况下的设置方法
  • 原文地址:https://www.cnblogs.com/xmeo/p/7241275.html
Copyright © 2020-2023  润新知