• Hopfield神经网络实现污染字体的识别


    这个网络内部使用的是 Hebb 学习规则。

    贴上两段代码:

      

    package geym.nn.hopfiled;
    
    import java.util.Arrays;
    
    import org.neuroph.core.data.DataSet;
    import org.neuroph.core.data.DataSetRow;
    import org.neuroph.nnet.Hopfield;
    import org.neuroph.nnet.comp.neuron.InputOutputNeuron;
    import org.neuroph.nnet.learning.HopfieldLearning;
    import org.neuroph.util.NeuronProperties;
    import org.neuroph.util.TransferFunctionType;
    
    /**
     * Demo: recognise a corrupted digit (0, 1 or 2) with a fully connected
     * Hopfield network.
     *
     * <p>Each digit is a 30-value pattern, written in the source as 5 rows of 6
     * values. Patterns are authored with 0/1 and converted to the bipolar -1/+1
     * encoding by {@link #format(double[])} before they reach the network.
     *
     * @author Administrator
     */
    public class HopfieldSample2 {

        /** Number of values in one pattern, and therefore neurons in the network. */
        private static final int PATTERN_SIZE = 30;

        /**
         * Upper bound on recall steps, so that a state which oscillates instead of
         * settling cannot hang the demo.
         */
        private static final int MAX_ITERATIONS = 100;

        /**
         * Converts a 0/1 pattern to bipolar form by replacing every 0 with -1.
         *
         * <p>The array is modified in place and the same instance is returned, so
         * callers must not pass an array they still need in its 0/1 form.
         *
         * @param data pattern to convert
         * @return {@code data} itself, after conversion
         */
        public static double[] format(double[] data) {
            for (int i = 0; i < data.length; i++) {
                if (data[i] == 0) {
                    data[i] = -1;
                }
            }
            return data;
        }

        /**
         * Returns a fresh bipolar copy of the clean "2" pattern. It is used both as
         * a training sample and as the expected recall result.
         */
        private static double[] digitTwo() {
            return format(new double[] {
                    1,0,0,0,0,0,
                    1,0,0,1,1,1,
                    1,0,0,1,0,1,
                    1,0,0,1,0,1,
                    0,1,1,0,0,1});
        }

        /**
         * Presents {@code input} to the network and steps it until two consecutive
         * outputs are identical, or {@link #MAX_ITERATIONS} steps have been taken.
         *
         * @param network trained Hopfield network
         * @param input   bipolar pattern to start from
         * @return the last output produced by the network
         */
        private static double[] recall(Hopfield network, double[] input) {
            network.setInput(input);

            double[] previous = null;
            for (int step = 0; step < MAX_ITERATIONS; step++) {
                network.calculate();
                // Copy the output: if the network returns a reused buffer, comparing
                // two references to it would always report "converged".
                double[] current = network.getOutput().clone();
                if (Arrays.equals(current, previous)) {
                    return current;
                }
                previous = current;
            }

            System.err.println("Network did not settle within " + MAX_ITERATIONS + " iterations");
            return previous;
        }

        /**
         * Trains the network on the digits 0, 1 and 2, then recalls a copy of "2"
         * with one flipped value and prints whether the clean "2" was recovered.
         *
         * @param args unused
         */
        public static void main(String[] args) {
            NeuronProperties neuronProperties = new NeuronProperties();
            neuronProperties.setProperty("neuronType", InputOutputNeuron.class);
            neuronProperties.setProperty("bias", Double.valueOf(0.0D));
            neuronProperties.setProperty("transferFunction", TransferFunctionType.STEP);
            neuronProperties.setProperty("transferFunction.yHigh", Double.valueOf(1.0D));
            neuronProperties.setProperty("transferFunction.yLow", Double.valueOf(-1.0D));

            // create training set: the digits 0, 1 and 2, 30 values each
            DataSet trainingSet = new DataSet(PATTERN_SIZE);
            trainingSet.addRow(new DataSetRow(format(new double[] {
                    0,1,1,1,1,0,
                    1,0,0,0,0,1,
                    1,0,0,0,0,1,
                    1,0,0,0,0,1,
                    0,1,1,1,1,0}))); // 0

            trainingSet.addRow(new DataSetRow(format(new double[] {
                    0,0,0,0,0,0,
                    1,0,0,0,0,0,
                    1,1,1,1,1,1,
                    0,0,0,0,0,0,
                    0,0,0,0,0,0}))); // 1

            trainingSet.addRow(new DataSetRow(digitTwo())); // 2

            // create hopfield network
            Hopfield myHopfield = new Hopfield(PATTERN_SIZE, neuronProperties);
            myHopfield.setLearningRule(new StandHopfieldLearning());
            // learn the training set
            myHopfield.learn(trainingSet);

            // test hopfield network
            System.out.println("Testing network");

            // A "2" with one value flipped (last value of the fourth row). It is only
            // a probe and is deliberately not added to the training set.
            double[] corruptedTwo = format(new double[] {
                    1,0,0,0,0,0,
                    1,0,0,1,1,1,
                    1,0,0,1,0,1,
                    1,0,0,1,0,0,
                    0,1,1,0,0,1}); // 2 bad

            double[] networkOutput = recall(myHopfield, corruptedTwo);

            System.out.print("Input: " + Arrays.toString(corruptedTwo));
            System.out.println(" Output: " + Arrays.toString(networkOutput));

            // true when the corrupted input was restored to the clean "2"
            System.out.println(Arrays.equals(digitTwo(), networkOutput));
        }

    }

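    下面就是 StandHopfieldLearning 类的实现(注意它声明的包与上面的示例不同,使用时需要导入该类或把两者放到同一个包下)。其中最内层循环对权重的累加就是 Hebb 学习规则:神经元 i 与神经元 j 之间的连接权重,等于所有训练样本中第 i 个分量与第 j 个分量的乘积之和,即 $w_{ij}=\sum_{k=1}^{M} x_i^{(k)} x_j^{(k)}$($i \neq j$):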

      

    package com.cgjr.com.hopfield;
    
    import org.neuroph.core.Connection;
    import org.neuroph.core.Layer;
    import org.neuroph.core.Neuron;
    import org.neuroph.core.data.DataSet;
    import org.neuroph.core.data.DataSetRow;
    import org.neuroph.core.learning.LearningRule;
    
    /**
     * Hebbian learning rule for a Hopfield network.
     *
     * <p>For every pair of distinct neurons (i, j) in layer 0 the connection
     * weight is set to the sum, over all training rows, of
     * {@code input[i] * input[j]}. The same value is written to both directions
     * of the pair, so the resulting weights are symmetric. Weights are
     * overwritten, not accumulated onto their previous values.
     *
     * @author Zoran Sevarac <sevarac@gmail.com>
     */
    public class StandHopfieldLearning extends LearningRule {

        /**
         * The class fingerprint that is set to indicate serialization
         * compatibility with a previous version of the class.
         */
        private static final long serialVersionUID = 1L;

        /**
         * Creates a new StandHopfieldLearning.
         */
        public StandHopfieldLearning() {
            super();
        }

        /**
         * Calculates the weights of the Hopfield layer from the given training set.
         *
         * <p>Each row's input array is indexed by neuron position, so it needs at
         * least as many values as layer 0 has neurons. Every pair of distinct
         * neurons is expected to be connected in both directions.
         *
         * @param trainingSet
         *            training set to learn
         */
        public void learn(DataSet trainingSet) {
            Layer hopfieldLayer = neuralNetwork.getLayerAt(0);
            int neuronCount = hopfieldLayer.getNeuronsCount();
            int patternCount = trainingSet.size();

            // Fetch every input array once instead of once per neuron pair.
            double[][] patterns = new double[patternCount][];
            for (int k = 0; k < patternCount; k++) {
                patterns[k] = trainingSet.getRowAt(k).getInput();
            }

            for (int i = 0; i < neuronCount; i++) {
                Neuron ni = hopfieldLayer.getNeuronAt(i);
                // The weight is symmetric, so visiting only j > i covers every pair
                // exactly once; the diagonal (j == i) is never touched.
                for (int j = i + 1; j < neuronCount; j++) {
                    Neuron nj = hopfieldLayer.getNeuronAt(j);

                    double weight = 0;
                    for (double[] pattern : patterns) {
                        weight += pattern[i] * pattern[j]; // Hebb learning rule
                    }

                    nj.getConnectionFrom(ni).getWeight().setValue(weight);
                    ni.getConnectionFrom(nj).getWeight().setValue(weight);
                } // j
            } // i
        }

    }
  • 相关阅读:
    龟兔赛跑(多线程练习题)
    进程和线程详解
    toString()方法详解
    使用IDEA的Git插件上传项目教程
    js运算符单竖杠“|”的用法和作用及js数据处理
    vue项目axios请求接口,后端代理请求接口404,问题出现在哪?
    jQuery的ajax的post请求json格式无法上传空数组
    es6 学习小记 扩展运算符 三个点(...)
    select2插件使用小记2
    js中多维数组转一维
  • 原文地址:https://www.cnblogs.com/beigongfengchen/p/5562032.html
Copyright © 2020-2023  润新知