• RBM


    Load the data into a List.
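The sketch below shows one way to do this step. It assumes a hypothetical input format where each sample is one line of whitespace-separated numbers such as `0 1 1 0`. It also uses plain `double[]` instead of the project's `SampleVector` type. `SampleLoader` is an illustrative name, not part of the original code.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical loader: one sample per line, values separated by whitespace.
class SampleLoader {
    // Parse lines such as "0 1 1 0" into one double[] per sample; blank lines are skipped.
    static List<double[]> parse(List<String> lines) {
        List<double[]> samples = new ArrayList<>();
        for (String raw : lines) {
            String line = raw.trim();
            if (line.isEmpty()) {
                continue;
            }
            String[] tokens = line.split("\\s+");
            double[] x = new double[tokens.length];
            for (int i = 0; i < tokens.length; i++) {
                x[i] = Double.parseDouble(tokens[i]);
            }
            samples.add(x);
        }
        return samples;
    }

    public static void main(String[] args) {
        List<String> lines = Arrays.asList("0 1 1 0", "1 0 0 1", "", "1 1 0 0");
        List<double[]> samples = parse(lines);
        System.out.println(samples.size()); // prints 3
    }
}
```

To read from a file instead, `SampleLoader.parse(Files.readAllLines(Paths.get(path)))` works with `java.nio.file.Files` and `java.nio.file.Paths`.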

    Split the dataset into training, validation, and test sets.
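The original post does not show how it splits the data. This is a minimal sketch under that gap, assuming a simple approach: shuffle a copy with a fixed seed so the split is reproducible, then cut it by fraction. `DatasetSplit` and the 80/10/10 ratios are illustrative choices.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

// Holds the three disjoint parts of a shuffled dataset.
class DatasetSplit<E> {
    final List<E> train;
    final List<E> validation;
    final List<E> test;

    DatasetSplit(List<E> train, List<E> validation, List<E> test) {
        this.train = train;
        this.validation = validation;
        this.test = test;
    }

    // Shuffles a copy of `data` with a fixed seed, then cuts it by the given fractions;
    // everything left after the train and validation parts becomes the test set.
    static <T> DatasetSplit<T> split(List<T> data, double trainFrac, double valFrac, long seed) {
        if (trainFrac < 0 || valFrac < 0 || trainFrac + valFrac > 1.0) {
            throw new IllegalArgumentException("fractions must be >= 0 and sum to at most 1");
        }
        List<T> shuffled = new ArrayList<>(data);
        Collections.shuffle(shuffled, new Random(seed));
        int n = shuffled.size();
        int nTrain = (int) (n * trainFrac);
        int nVal = Math.min((int) (n * valFrac), n - nTrain); // guard against rounding past n
        return new DatasetSplit<>(
                new ArrayList<>(shuffled.subList(0, nTrain)),
                new ArrayList<>(shuffled.subList(nTrain, nTrain + nVal)),
                new ArrayList<>(shuffled.subList(nTrain + nVal, n)));
    }

    public static void main(String[] args) {
        List<Integer> data = new ArrayList<>();
        for (int i = 0; i < 100; i++) data.add(i);
        DatasetSplit<Integer> parts = split(data, 0.8, 0.1, 42L);
        System.out.println(parts.train.size() + " / " + parts.validation.size() + " / " + parts.test.size()); // prints 80 / 10 / 10
    }
}
```

The three parts are copies, so changing one does not affect the others or the original list.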

    Create an RBM object and set the sizes of the visible and hidden layers.
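The two layer sizes fix the shapes of the model's parameters: an `nHidden x nVisible` weight matrix plus one bias vector per layer. The hypothetical `SimpleRBM` class below illustrates this. It is not the `SGDBase` subclass used by the original code. It assumes binary (Bernoulli) units and small Gaussian initial weights (std 0.01), which is a common convention. The 784/500 sizes in `main` are only an example.

```java
import java.util.Random;

// Hypothetical minimal Bernoulli-Bernoulli RBM.
class SimpleRBM {
    final int nVisible;
    final int nHidden;
    final double[][] w;   // w[j][i]: weight between hidden unit j and visible unit i
    final double[] b;     // visible biases
    final double[] c;     // hidden biases

    SimpleRBM(int nVisible, int nHidden, long seed) {
        if (nVisible <= 0 || nHidden <= 0) {
            throw new IllegalArgumentException("layer sizes must be positive");
        }
        this.nVisible = nVisible;
        this.nHidden = nHidden;
        this.w = new double[nHidden][nVisible];
        this.b = new double[nVisible]; // biases start at zero
        this.c = new double[nHidden];
        Random rand = new Random(seed);
        for (int j = 0; j < nHidden; j++) {
            for (int i = 0; i < nVisible; i++) {
                w[j][i] = 0.01 * rand.nextGaussian(); // small random weights, std 0.01
            }
        }
    }

    static double sigmoid(double x) {
        return 1.0 / (1.0 + Math.exp(-x));
    }

    // p(h_j = 1 | v) = sigmoid(c_j + sum_i w[j][i] * v_i)
    double[] hiddenProbs(double[] v) {
        double[] p = new double[nHidden];
        for (int j = 0; j < nHidden; j++) {
            double a = c[j];
            for (int i = 0; i < nVisible; i++) a += w[j][i] * v[i];
            p[j] = sigmoid(a);
        }
        return p;
    }

    // p(v_i = 1 | h) = sigmoid(b_i + sum_j w[j][i] * h_j)
    double[] visibleProbs(double[] h) {
        double[] p = new double[nVisible];
        for (int i = 0; i < nVisible; i++) {
            double a = b[i];
            for (int j = 0; j < nHidden; j++) a += w[j][i] * h[j];
            p[i] = sigmoid(a);
        }
        return p;
    }

    public static void main(String[] args) {
        SimpleRBM rbm = new SimpleRBM(784, 500, 1L); // e.g. 28x28 binary images, 500 hidden units
        System.out.println(rbm.w.length + " x " + rbm.w[0].length); // prints 500 x 784
    }
}
```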

    Train the RBM.
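The post does not show the RBM update rule itself. A common choice is one step of contrastive divergence (CD-1), and the sketch below assumes that choice with binary units. So that it stands alone, it works directly on the weight array `w[j][i]`, the visible biases `b`, and the hidden biases `c`. `CD1.step` is a hypothetical name. It returns the squared reconstruction error, which is a common way to monitor progress.

```java
import java.util.Random;

// One CD-1 (contrastive divergence, k = 1) step for a Bernoulli-Bernoulli RBM.
class CD1 {
    static double sigmoid(double x) {
        return 1.0 / (1.0 + Math.exp(-x));
    }

    // Updates w, b and c in place for one sample v0 and returns its squared reconstruction error.
    static double step(double[][] w, double[] b, double[] c, double[] v0, double lr, Random rand) {
        int nHidden = c.length;
        int nVisible = b.length;

        // Positive phase: p(h = 1 | v0), plus a binary sample h0 drawn from it.
        double[] ph0 = new double[nHidden];
        double[] h0 = new double[nHidden];
        for (int j = 0; j < nHidden; j++) {
            double a = c[j];
            for (int i = 0; i < nVisible; i++) a += w[j][i] * v0[i];
            ph0[j] = sigmoid(a);
            h0[j] = rand.nextDouble() < ph0[j] ? 1.0 : 0.0;
        }

        // Negative phase: reconstruct the visibles from h0 (as probabilities) ...
        double[] pv1 = new double[nVisible];
        for (int i = 0; i < nVisible; i++) {
            double a = b[i];
            for (int j = 0; j < nHidden; j++) a += w[j][i] * h0[j];
            pv1[i] = sigmoid(a);
        }
        // ... then recompute the hidden probabilities from the reconstruction.
        double[] ph1 = new double[nHidden];
        for (int j = 0; j < nHidden; j++) {
            double a = c[j];
            for (int i = 0; i < nVisible; i++) a += w[j][i] * pv1[i];
            ph1[j] = sigmoid(a);
        }

        // Approximate gradient: data statistics minus reconstruction statistics.
        for (int j = 0; j < nHidden; j++) {
            for (int i = 0; i < nVisible; i++) {
                w[j][i] += lr * (ph0[j] * v0[i] - ph1[j] * pv1[i]);
            }
            c[j] += lr * (ph0[j] - ph1[j]);
        }
        double err = 0.0;
        for (int i = 0; i < nVisible; i++) {
            b[i] += lr * (v0[i] - pv1[i]);
            err += (v0[i] - pv1[i]) * (v0[i] - pv1[i]);
        }
        return err;
    }

    public static void main(String[] args) {
        double[][] data = { {1, 1, 1, 0, 0, 0}, {0, 0, 0, 1, 1, 1} };
        int nVisible = 6, nHidden = 2;
        Random rand = new Random(7);
        double[][] w = new double[nHidden][nVisible];
        for (double[] row : w) {
            for (int i = 0; i < nVisible; i++) row[i] = 0.01 * rand.nextGaussian();
        }
        double[] b = new double[nVisible];
        double[] c = new double[nHidden];
        for (int epoch = 1; epoch <= 500; epoch++) {
            double total = 0.0;
            for (double[] v : data) total += step(w, b, c, v, 0.1, rand);
            if (epoch % 100 == 0) {
                System.out.println("epoch " + epoch + ", mean reconstruction error " + total / data.length);
            }
        }
    }
}
```

Using hidden probabilities instead of samples in the statistics, and visible probabilities in the reconstruction, reduces sampling noise in the update. Both choices are widely used conventions, not something the original post specifies.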

    Create the worker threads (one per model replica) and run the training loop:

    public static void train(SGDBase sgd, List<SampleVector> samples, SGDTrainConfig config) {
        int xy_n = samples.size(); // total number of samples, used to average the loss
        int nrModelReplica = config.getNbrModelReplica();

        // Randomly assign each sample to one of nrModelReplica shards (one shard per replica)
        HashMap<Integer, List<SampleVector>> list_map = new HashMap<Integer, List<SampleVector>>();
        for (int i = 0; i < nrModelReplica; i++) {
            list_map.put(i, new ArrayList<SampleVector>());
        }
        Random rand = new Random(System.currentTimeMillis());
        for (SampleVector v: samples) {
            int id = rand.nextInt(nrModelReplica);
            list_map.get(id).add(v);
        }
        // Create one training thread per replica, give it its shard, and pair it with a loss thread
        List<DeltaThread> threads = new ArrayList<DeltaThread>();
        List<LossThread> loss_threads = new ArrayList<LossThread>();
        for (int i = 0; i < nrModelReplica; i++) {
            threads.add(new DeltaThread(sgd, config, list_map.get(i)));
            loss_threads.add(new LossThread(sgd));
        }

        // start iteration
        for (int epoch = 1; epoch <= config.getMaxEpochs(); epoch++) {
            // start this epoch's training on every replica
            for(DeltaThread thread : threads) {
                thread.train(epoch);
            }

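            // wait until every training thread has finished (polls once per second)
            while (true) {
                try {
                    Thread.sleep(1000);
                } catch (InterruptedException e) {
                    // restore the interrupt flag and stop, instead of merging half-finished replicas
                    Thread.currentThread().interrupt();
                    return;
                }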
                boolean allStop = true;
                for(DeltaThread thread : threads) {
                    if (thread.isRunning()) {
                        allStop = false;
                        break;
                    }
                }
                if (allStop) {
                    break;
                }
            }

            // merge each replica's parameters into the shared model (mergeParam also receives the replica count)
            for(DeltaThread thread : threads) {
                sgd.mergeParam(thread.getParam(), nrModelReplica);
            }

            logger.info("train done for this iteration-" + epoch);

            // 1. every getParamOutputStep() epochs, write the parameters to disk
            if(config.isParamOutput() && (0 == (epoch % config.getParamOutputStep()))) {
                SGDPersistableWrite.output(config.getParamOutputPath(), sgd);
            }
           
            // 2. every getLossCalStep() epochs, compute and log the loss
            if(!config.isPrintLoss()) {
                continue;
            }
            if (0 != (epoch % config.getLossCalStep())) {
                continue;
            }

            // compute each replica's loss on its own shard in parallel
            for (int i = 0; i < nrModelReplica; i++) {
                loss_threads.get(i).sumLoss(threads.get(i).getSamples());
            }

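            // wait until every loss thread has finished (polls once per second)
            while (true) {
                try {
                    Thread.sleep(1000);
                } catch (InterruptedException e) {
                    // restore the interrupt flag and stop, instead of summing partial losses
                    Thread.currentThread().interrupt();
                    return;
                }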
                boolean allStop = true;
                for(LossThread thread : loss_threads) {
                    if (thread.isRunning()) {
                        allStop = false;
                        break;
                    }
                }
                if (allStop) {
                    break;
                }
            }

            // add up the per-replica errors and average over all samples; stop early once below getMinLoss()
            double totalError = 0;
            for(LossThread thread : loss_threads) {
                totalError += thread.getError();
            }
            totalError /= xy_n;
            logger.info("iteration-" + epoch + " done, total error is " + totalError);
            if (totalError <= config.getMinLoss()) {
                break;
            }
        }
    }
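In the method above, each epoch starts every replica and then polls `isRunning()` once a second until all replicas finish. After that, it merges their parameters. The sketch below shows the same synchronous pattern with `java.util.concurrent`. Here `ExecutorService.invokeAll` blocks until every task has completed, so no sleep loop is needed. The element-wise averaging is an assumption about what `mergeParam(param, nrModelReplica)` does. `ParallelEpoch`, `runEpoch`, and `localTrain` are hypothetical names, and the mean-estimation task in `main` is only a toy.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.function.BiFunction;

// Synchronous data-parallel epoch: every replica starts from the same global parameters,
// trains on its own shard, and the resulting parameter vectors are averaged.
class ParallelEpoch {
    // Element-wise mean of the replicas' parameter vectors.
    static double[] average(List<double[]> params) {
        double[] mean = new double[params.get(0).length];
        for (double[] p : params) {
            for (int k = 0; k < mean.length; k++) mean[k] += p[k];
        }
        for (int k = 0; k < mean.length; k++) mean[k] /= params.size();
        return mean;
    }

    // localTrain(startParams, shard) returns the replica's parameters after training on its shard.
    static <S> double[] runEpoch(ExecutorService pool, double[] global, List<S> shards,
                                 BiFunction<double[], S, double[]> localTrain)
            throws InterruptedException, ExecutionException {
        List<Callable<double[]>> tasks = new ArrayList<>();
        for (S shard : shards) {
            double[] start = global.clone(); // each replica gets its own copy
            tasks.add(() -> localTrain.apply(start, shard));
        }
        // invokeAll blocks until every task has finished; get() rethrows any task failure.
        List<double[]> results = new ArrayList<>();
        for (Future<double[]> f : pool.invokeAll(tasks)) {
            results.add(f.get());
        }
        return average(results);
    }

    public static void main(String[] args) throws Exception {
        // Toy task: each replica nudges a single parameter toward the values in its shard.
        List<double[]> shards = Arrays.asList(new double[]{1, 2, 3}, new double[]{4, 5, 6});
        ExecutorService pool = Executors.newFixedThreadPool(shards.size());
        try {
            double[] global = {0.0};
            for (int epoch = 1; epoch <= 20; epoch++) {
                global = runEpoch(pool, global, shards, (p, shard) -> {
                    double[] q = p.clone();
                    for (double x : shard) q[0] += 0.1 * (x - q[0]);
                    return q;
                });
            }
            System.out.println("estimate after 20 epochs: " + global[0]);
        } finally {
            pool.shutdown();
        }
    }
}
```

Because every replica trains on its own clone, replicas never see each other's updates within an epoch. The averaging step is the only synchronization point, and `invokeAll` also propagates a failed replica as an `ExecutionException` instead of silently skipping it.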

  • Original article: https://www.cnblogs.com/huiwq1990/p/3930144.html