• 【深度学习】BP算法分类iris数据集


    这里写图片描述
    Network:

    package test2;
    
    import java.util.Random;
    
    /**
     * A small fully connected feed-forward network with one hidden layer, trained online
     * (one sample per {@link #train} call) by back-propagation with momentum.
     *
     * <p>Every unit uses the logistic {@link #sigmoid(double)} activation and there are no
     * bias/threshold terms. Instances are not thread-safe: every call mutates shared buffers.
     */
    public class Network {

        /** Learning rate applied by {@link #init(int, int, int)}. */
        private static final double DEFAULT_RATE = 0.2;
        /** Momentum coefficient applied by {@link #init(int, int, int)}. */
        private static final double DEFAULT_MOMENTUM = 0.3;

        private double[] input; // input layer activations (private copy of the current sample)
        private double[] hidden; // hidden layer activations
        private double[] output; // output layer activations
        private double[] target; // expected output vector (private copy)
        private double[][] i_h_weight; // input -> hidden weights, indexed [input][hidden]
        private double[][] h_o_weight; // hidden -> output weights, indexed [hidden][output]
        private double[][] i_h_weightUpdate; // previous input -> hidden weight change (momentum term)
        private double[][] h_o_weightUpdate; // previous hidden -> output weight change (momentum term)
        private double[] outputError; // output layer deltas
        private double[] hiddenError; // hidden layer deltas
        private double outputErrorSum; // sum of |output delta| from the last training step
        private double hiddenErrorSum; // sum of |hidden delta| from the last training step
        private double rate = DEFAULT_RATE;
        private double momentum = DEFAULT_MOMENTUM;

        private Random random;

        /**
         * Allocates all layers and randomizes the weights, using learning rate 0.2 and
         * momentum 0.3.
         *
         * @param inputSize number of input units
         * @param hiddenSize number of hidden units
         * @param outputSize number of output units
         * @throws IllegalArgumentException if any size is not positive
         */
        public void init(int inputSize, int hiddenSize, int outputSize) {
            init(inputSize, hiddenSize, outputSize, DEFAULT_RATE, DEFAULT_MOMENTUM);
        }

        /**
         * Allocates all layers and randomizes the weights with explicit hyper-parameters.
         *
         * @param inputSize number of input units
         * @param hiddenSize number of hidden units
         * @param outputSize number of output units
         * @param rate learning rate applied to each gradient step
         * @param momentum fraction of the previous weight change added to the current one
         * @throws IllegalArgumentException if any size is not positive
         */
        public void init(int inputSize, int hiddenSize, int outputSize, double rate, double momentum) {
            init(inputSize, hiddenSize, outputSize, rate, momentum, new Random());
        }

        /**
         * Allocates all layers and randomizes the weights from the given random source, which
         * makes training reproducible when a seeded {@link Random} is supplied.
         *
         * @param inputSize number of input units
         * @param hiddenSize number of hidden units
         * @param outputSize number of output units
         * @param rate learning rate applied to each gradient step
         * @param momentum fraction of the previous weight change added to the current one
         * @param random source of the initial weights; also retained by this network
         * @throws IllegalArgumentException if any size is not positive
         * @throws NullPointerException if {@code random} is null
         */
        public void init(int inputSize, int hiddenSize, int outputSize, double rate, double momentum,
                Random random) {
            if (inputSize <= 0 || hiddenSize <= 0 || outputSize <= 0) {
                throw new IllegalArgumentException("layer sizes must be positive: "
                        + inputSize + ", " + hiddenSize + ", " + outputSize);
            }
            if (random == null) {
                throw new NullPointerException("random");
            }
            input = new double[inputSize];
            hidden = new double[hiddenSize];
            output = new double[outputSize];
            target = new double[outputSize];

            i_h_weight = new double[inputSize][hiddenSize];
            h_o_weight = new double[hiddenSize][outputSize];
            i_h_weightUpdate = new double[inputSize][hiddenSize];
            h_o_weightUpdate = new double[hiddenSize][outputSize];

            outputError = new double[outputSize];
            hiddenError = new double[hiddenSize];
            outputErrorSum = 0;
            hiddenErrorSum = 0;

            this.rate = rate;
            this.momentum = momentum;
            this.random = random;
            randomWeights(i_h_weight);
            randomWeights(h_o_weight);
        }

        /**
         * Fills a weight matrix with values drawn uniformly from [-1, 1).
         *
         * @param matrix matrix to overwrite in place
         */
        private void randomWeights(double[][] matrix) {
            for (double[] row : matrix) {
                for (int j = 0; j < row.length; j++) {
                    // Sign and magnitude must be independent; the old "real > 0.5 ? real : -real"
                    // produced only [-0.5, 0] or (0.5, 1), a lopsided distribution.
                    row[j] = 2 * random.nextDouble() - 1;
                }
            }
        }

        /**
         * Performs one online training step: forward pass, error computation and weight update.
         *
         * @param trainData input sample; length must equal the input layer size
         * @param target expected output; length must equal the output layer size
         * @throws IllegalArgumentException if either length does not match
         */
        public void train(double[] trainData, double[] target) {
            loadInput(trainData);
            loadTarget(target);
            forward();
            calculateError();
            adjustWeight();
        }

        /**
         * Runs a forward pass on the given sample without training.
         *
         * @param inData input sample; length must equal the input layer size
         * @return a fresh copy of the output layer activations
         * @throws IllegalArgumentException if the length does not match
         */
        public double[] test(double[] inData) {
            loadInput(inData);
            forward();
            return getNetworkOutput();
        }

        /**
         * Returns the sum of absolute output-layer deltas computed by the most recent
         * {@link #train} call (0 before any training).
         *
         * @return output error sum
         */
        public double getOutputErrorSum() {
            return outputErrorSum;
        }

        /**
         * Returns the sum of absolute hidden-layer deltas computed by the most recent
         * {@link #train} call (0 before any training).
         *
         * @return hidden error sum
         */
        public double getHiddenErrorSum() {
            return hiddenErrorSum;
        }

        /**
         * Copies the output layer so callers cannot mutate internal state.
         *
         * @return copy of the output activations
         */
        private double[] getNetworkOutput() {
            return output.clone();
        }

        /**
         * Copies the expected output into the internal target buffer.
         *
         * @param target expected output vector
         * @throws IllegalArgumentException if the length does not match the output layer
         */
        private void loadTarget(double[] target) {
            if (this.target.length != target.length) {
                throw new IllegalArgumentException("长度不匹配.");
            }
            // copy instead of aliasing the caller's array
            System.arraycopy(target, 0, this.target, 0, target.length);
        }

        /**
         * Copies an input sample into the internal input buffer.
         *
         * @param input input sample
         * @throws IllegalArgumentException if the length does not match the input layer
         */
        private void loadInput(double[] input) {
            if (this.input.length != input.length) {
                throw new IllegalArgumentException("长度不匹配.");
            }
            // copy instead of aliasing the caller's array
            System.arraycopy(input, 0, this.input, 0, input.length);
        }

        /**
         * Propagates one layer forward: layer1[j] = sigmoid(sum_i weight[i][j] * layer0[i]).
         *
         * @param layer0 source activations
         * @param layer1 destination activations, overwritten
         * @param weight weights indexed [source][destination]
         */
        private void forward(double[] layer0, double[] layer1, double[][] weight) {
            for (int j = 0; j < layer1.length; j++) {
                double sum = 0;
                for (int i = 0; i < layer0.length; i++) {
                    sum += weight[i][j] * layer0[i];
                }
                layer1[j] = sigmoid(sum);
            }
        }

        /**
         * Runs the full forward pass: input -> hidden -> output.
         */
        public void forward() {
            forward(input, hidden, i_h_weight);
            forward(hidden, output, h_o_weight);
        }

        /**
         * Computes output deltas: o * (1 - o) * (target - o), i.e. the sigmoid derivative times
         * the output error (squared-error loss).
         */
        private void outputError() {
            double errSum = 0;
            for (int i = 0; i < outputError.length; i++) {
                double o = output[i];
                outputError[i] = o * (1d - o) * (target[i] - o);
                errSum += Math.abs(outputError[i]);
            }
            outputErrorSum = errSum;
        }

        /**
         * Back-propagates the output deltas to the hidden layer through the current
         * (not yet updated) hidden -> output weights.
         */
        private void hiddenError() {
            double errSum = 0;
            for (int i = 0; i < hiddenError.length; i++) {
                double o = hidden[i];
                double sum = 0;
                for (int j = 0; j < outputError.length; j++) {
                    sum += h_o_weight[i][j] * outputError[j];
                }
                hiddenError[i] = o * (1d - o) * sum;
                errSum += Math.abs(hiddenError[i]);
            }
            hiddenErrorSum = errSum;
        }

        /**
         * Computes output then hidden deltas; must run before {@link #adjustWeight()} because the
         * hidden deltas read the hidden -> output weights.
         */
        private void calculateError() {
            outputError();
            hiddenError();
        }

        /**
         * Applies a gradient step with momentum:
         * delta(t) = momentum * delta(t-1) + rate * error[i] * layer[j].
         *
         * @param error deltas of the destination layer
         * @param layer activations of the source layer
         * @param weight weights indexed [source][destination], updated in place
         * @param prevWeight previous weight changes, same shape as {@code weight}, updated in place
         */
        private void adjustWeight(double[] error, double[] layer, double[][] weight, double[][] prevWeight) {
            for (int i = 0; i < error.length; i++) {
                for (int j = 0; j < layer.length; j++) {
                    double newVal = momentum * prevWeight[j][i] + rate * error[i] * layer[j];
                    weight[j][i] += newVal;
                    prevWeight[j][i] = newVal;
                }
            }
        }

        /**
         * Updates both weight matrices from the deltas computed by {@link #calculateError()}.
         */
        private void adjustWeight() {
            // error sizes: (hidden, input, [input][hidden], [input][hidden])
            adjustWeight(hiddenError, input, i_h_weight, i_h_weightUpdate);
            adjustWeight(outputError, hidden, h_o_weight, h_o_weightUpdate);
        }

        /**
         * Logistic activation; output in (0, 1), point-symmetric about (0, 0.5).
         *
         * @param x net input
         * @return 1 / (1 + e^-x)
         */
        public double sigmoid(double x) {
            return 1 / (1 + Math.exp(-x));
        }

        /**
         * Hyperbolic tangent activation; output in [-1, 1], point-symmetric about (0, 0).
         * Delegates to {@link Math#tanh(double)}, which saturates to +/-1 instead of producing
         * NaN for large |x| as the explicit exp formula did.
         *
         * @param x net input
         * @return tanh(x)
         */
        public double tanh(double x) {
            return Math.tanh(x);
        }

    }
    

    Main:

    package test2;
    
    import java.io.IOException;
    import java.util.ArrayList;
    import java.util.Arrays;
    import java.util.List;
    import java.util.Random;
    
    import test.BP;
    
    /**
     * Trains a {@link Network} on the iris training file and reports its accuracy on a test file.
     */
    public class Main {

        /** Class names in the order of the network's output units. */
        private static final String[] CATEGORIES = { "Iris-setosa", "Iris-versicolor", "Iris-virginica" };

        /**
         * Number of consecutive rows per class in data/iris.txt. The training file is assumed to
         * list each class as one contiguous block of this many rows — TODO confirm against the data.
         */
        private static final int SAMPLES_PER_CLASS = 50;

        /**
         * Entry point: load data, train for a fixed number of epochs, then evaluate.
         *
         * @param args unused
         * @throws IOException if a data file cannot be read
         */
        public static void main(String[] args) throws IOException {
            System.out.println("->读取样本数据");
            ReadData rd = new ReadData();
            List<double[]> data = rd.loadData("data/iris.txt", 0, 3, ",");
            System.out.println("->读取完成");
            System.out.println("->初始化神经网络");
            int ipt = 4;
            int opt = 3;
            // empirical hidden-size rule: sqrt(inputs + outputs) + a constant (10 here)
            int hid = (int) (Math.sqrt(ipt + opt) + 10);
            Network bp = new Network();
            bp.init(ipt, hid, opt);
            System.out.println("->初始化完成");
            int maxLearn = 10000;
            System.out.println("->最大学习次数:" + maxLearn);
            System.out.println("->开始训练");

            // Targets depend only on the row index, so build them once rather than every epoch.
            double[][] targets = new double[data.size()][];
            for (int i = 0; i < data.size(); i++) {
                targets[i] = oneHotTarget(i, opt);
            }

            // nanoTime is monotonic and meant for measuring elapsed time
            long start = System.nanoTime();
            for (int j = 0; j < maxLearn; j++) {
                for (int i = 0; i < data.size(); i++) {
                    bp.train(data.get(i), targets[i]);
                }
            }
            long elapsedMs = (System.nanoTime() - start) / 1000000L;
            System.out.println("->训练完成,用时:" + elapsedMs + "ms");

            System.out.println("-------------");
            List<double[]> testData = rd.loadData("data/test.txt", 0, 3, ",");
            // Read the label column once; previously the whole file was re-read for every sample.
            List<?> labels = rd.getColumn("data/test.txt", 4, ",");
            int correct = 0;
            int error = 0;
            for (int i = 0; i < testData.size(); i++) {
                double[] result = bp.test(testData.get(i));
                if (classify(result).equals(labels.get(i))) {
                    correct++;
                } else {
                    error++;
                }
            }
            System.out.println("->测试数据:" + (correct + error) + "条," + "正确 " + correct + "条");
            System.out.println("->正确率:" + (float) correct / (correct + error));
        }

        /**
         * Builds the one-hot target for a training row. The class is inferred from the row
         * index; rows beyond the known classes get an all-zero target.
         *
         * @param row zero-based row index in the training data
         * @param size number of output units
         * @return a new one-hot vector of length {@code size}
         */
        private static double[] oneHotTarget(int row, int size) {
            double[] target = new double[size];
            int cls = row / SAMPLES_PER_CLASS;
            if (cls < size) {
                target[cls] = 1;
            }
            return target;
        }

        /**
         * Maps network output to a class name via argmax.
         *
         * @param result network output activations
         * @return the name of the strongest output, or "" if there is no usable maximum
         *         (empty/NaN output, or an index without a known name)
         */
        private static String classify(double[] result) {
            int best = -1;
            double max = Double.NEGATIVE_INFINITY;
            for (int i = 0; i < result.length; i++) {
                if (result[i] > max) {
                    max = result[i];
                    best = i;
                }
            }
            return (best >= 0 && best < CATEGORIES.length) ? CATEGORIES[best] : "";
        }

    }
    

    结果:
    这里写图片描述

  • 相关阅读:
    转载--gulp入门
    grunt之easy demo
    CentOS下vm虚拟机桥接联网
    Webstorm & PhpStorm
    2.使用Package Control组件安装
    virtual方法和abstract方法
    sql server 2008 跨服务器查询
    .NET开源项目常用记录
    vs2010 安装MVC 3.0
    所有运行命令指令大全
  • 原文地址:https://www.cnblogs.com/cnsec/p/13286785.html
Copyright © 2020-2023  润新知