• Spark MLlib Deep Learning Deep Belief Network (深度学习-深度信念网络)2.3


    Spark MLlib Deep Learning Deep Belief Network (深度学习-深度信念网络)2.3

    http://blog.csdn.net/sunbow0

    第二章Deep Belief Network (深度信念网络)

    3实例

    3.1 測试数据

    依照上例数据,或者新建图片识别数据。

    3.2 DBN实例

    //****************2(读取固定样本:来源于经典优化算法測试函数Sphere Model***********//

        //2 读取样本数据

        Logger.getRootLogger.setLevel(Level.WARN)

        valdata_path ="/user/huangmeiling/deeplearn/data1"

        valexamples =sc.textFile(data_path).cache()

        valtrain_d1 =examples.map { line =>

          valf1 = line.split(" ")

          valf =f1.map(f =>f.toDouble)

          valid =f(0)

          valy = Array(f(1))

          valx =f.slice(2,f.length)

          (id, new BDM(1,y.length,y),new BDM(1,x.length,x))

        }

        valtrain_d =train_d1.map(f => (f._2, f._3))

        valopts = Array(100.0,20.0,0.0) 

        //3 设置训练參数,建立DBN模型

        valDBNmodel =new DBN().

          setSize(Array(5, 7)).

          setLayer(2).

          setMomentum(0.1).

          setAlpha(1.0).

          DBNtrain(train_d, opts) 

        //4 DBN模型转化为NN模型

        valmynn =DBNmodel.dbnunfoldtonn(1)

        valnnopts = Array(100.0,50.0,0.0)

        valnumExamples =train_d.count()

        println(s"numExamples = $numExamples.")

        println(mynn._2)

        for (i <-0 tomynn._1.length -1) {

          print(mynn._1(i) +" ")

        }

        println()

        println("mynn_W1")

        valtmpw1 =mynn._3(0)

        for (i <-0 totmpw1.rows -1) {

          for (j <-0 totmpw1.cols -1) {

            print(tmpw1(i,j) +" ")

          }

          println()

        }

        valNNmodel =new NeuralNet().

          setSize(mynn._1).

          setLayer(mynn._2).

          setActivation_function("sigm").

          setOutput_function("sigm").

          setInitW(mynn._3).

          NNtrain(train_d, nnopts) 

        //5 NN模型測试

        valNNforecast =NNmodel.predict(train_d)

        valNNerror =NNmodel.Loss(NNforecast)

        println(s"NNerror = $NNerror.")

        valprintf1 =NNforecast.map(f => (f.label.data(0), f.predict_label.data(0))).take(200)

        println("预測结果——实际值:预測值:误差")

        for (i <-0 untilprintf1.length)

          println(printf1(i)._1 +" " +printf1(i)._2 +" " + (printf1(i)._2 -printf1(i)._1)) 

    转载请注明出处:

    http://blog.csdn.net/sunbow0

  • 相关阅读:
    BZOJ3473: 字符串
    BZOJ1088: [SCOI2005]扫雷Mine
    跪啃SAM
    BZOJ3932: [CQOI2015]任务查询系统
    BZOJ3545: [ONTAK2010]Peaks
    06.约束
    05.数据表的创建与简单操作
    04.数据库的创建
    安卓6.0后运行时权限封装
    OkGo使用缓存
  • 原文地址:https://www.cnblogs.com/mengfanrong/p/5228232.html
Copyright © 2020-2023  润新知