• MLLib实践Naive Bayes


    引言

    本文基于Spark (1.5.0) ml库提供的pipeline完整地实践一次文本分类。pipeline将串联单词分割(tokenize)、单词频数统计(TF),特征向量计算(TF-IDF),朴素贝叶斯(Naive Bayes)模型训练等。
    本文将基于“20 NewsGroups” 数据集训练并测试Naive Bayes模型。这二十个新闻组数据集合是收集大约20,000新闻组文档,均匀的分布在20个不同的集合。我将使用'20news-bydate.tar.gz'文件,因为该数据集中已经将数据划分为两类:train和test,非常方便我们对模型进行训练和评价。

    20news-bydate.tar.gz - 20 Newsgroups sorted by date; duplicates and some headers removed (18846 documents)

    Naive Bayes算法介绍

    NB算法属于有监督分类算法,对输入数据: _M_表示输入样本容量,我们的目标是将其对号入座到某一个分类结果:
    我们将选择可能性最大的那个分类结果,或者说概率最大的那个分类:
    (y^j= ext{arg}max_y{P(y|X^j)},jin{1,...,M})

    根据贝叶斯公式:
    (egin{align}P(y|(x_i,...,x_N))&=frac{P(x_1,...,x_N|y)P(y)}{P(x_1,...,x_N)}end{align})

    我们在分类时只需要考虑分子上的两项乘积,并由此可以得出结论:后验概率∝似然概率✖️先验概率(最大后验概率问题转化为最大似然问题)。
    进一步地,Naive Bayes模型假设了似然函数的计算时简单地假设_X_的各维度之间独立,这样可以简化似然概率计算公式为:
    即给定分类下某个输入_X_出现的概率等于该分类下输入_X_各个维度分别出现的概率乘积。
    综上,在naive bayes算法框架下,对于某个输入_X_:

    • 如果_X_属于某个分类_y_的概率的概率大于属于其它分类的概率,则判定该输入属于分类_y_;
    • _X_属于某个分类_y_的概率正比于分类_y_自身出现的概率✖️该分类_y_条件下_X_各个维度出现的概率的乘积。

    那么,模型训练的目标就很明朗了,我们需要基于给定的训练样本计算出:

    • 各个分类的先验概率:
    • 训练样本中,每个分类条件下,输入各个维度出现的似然概率:

    模型用于分类新数据的计算:


    spark mllib中算法流程

    spark中对NaiveBayes算法的实现非常清晰明了,算法通过combineByKey计算每个分类下:
    (p_k=frac{sum_{j=1}^Mmathbb{I}(y^j=y_k)}{M})( heta(,k)=frac{sum_{j=1}^Mmathbb{I}(y^j=y_k)cdot{X^j}+alpha}{sum_{j=1}^M{X^j}+alphacdot{M}})

    20 newsgroups实践

    数据集分为train和test两组,分别用于训练和测试。每组数据都分为20类,每类数据存放在各自子文件下:

    .
    ├── 20news-bydate-test
    │   ├── alt.atheism
    │   ├── comp.graphics
    │   ├── comp.os.ms-windows.misc
    │   ├── comp.sys.ibm.pc.hardware
    │   ├── comp.sys.mac.hardware
    │   ├── comp.windows.x
    │   ├── misc.forsale
    │   ├── rec.autos
    │   ├── rec.motorcycles
    │   ├── rec.sport.baseball
    │   ├── rec.sport.hockey
    │   ├── sci.crypt
    │   ├── sci.electronics
    │   ├── sci.med
    │   ├── sci.space
    │   ├── soc.religion.christian
    │   ├── talk.politics.guns
    │   ├── talk.politics.mideast
    │   ├── talk.politics.misc
    │   └── talk.religion.misc
    └── 20news-bydate-train
        ├── alt.atheism
        ├── comp.graphics
        ├── comp.os.ms-windows.misc
        ├── comp.sys.ibm.pc.hardware
        ├── comp.sys.mac.hardware
        ├── comp.windows.x
        ├── misc.forsale
        ├── rec.autos
        ├── rec.motorcycles
        ├── rec.sport.baseball
        ├── rec.sport.hockey
        ├── sci.crypt
        ├── sci.electronics
        ├── sci.med
        ├── sci.space
        ├── soc.religion.christian
        ├── talk.politics.guns
        ├── talk.politics.mideast
        ├── talk.politics.misc
        └── talk.religion.misc
    

    原始文档将经过如下流程训练得到NaiveBayes模型:

    代码中的几点注解:

    • 各类数据根据所在的子文件夹来分类,我们在写代码时需要利用子文件夹名称,这时可以通过调用sc.wholeTextFiles(...)函数得到RDD(String,String)类型的原始数据,_1表示文件的绝对路径,_2表示该文件的内容。我们进一步从_1中截取出子文件夹的名称f.split("/").takeRight(2).head.
    • pipeline框架基于DataFrame,所有我们需要将RDD转为DataFrame:
    import sqlContext.implicits._
    labelNameAndData.toDF("id", "sentence").cache()```
    - 所有的转换都使用ml提供的类,未做任何定制或改动,当前模型在测试集上的准确度为82%。
    
    代码:
    ```scala
    import org.apache.spark.ml.Pipeline
    import org.apache.spark.ml.classification.NaiveBayes
    import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
    import org.apache.spark.ml.feature.{HashingTF, IDF, Tokenizer}
    import org.apache.spark.{Logging, SparkConf, SparkContext}
    
    
    object NBTest extends App with Logging {
      def createRawDf(s: String) = {
        //sc.setLogLevel("INFO")
        val fileNameData = sc.wholeTextFiles(s)
    
        val uniqueLabels = Array("alt.atheism", "comp.graphics", "comp.os.ms-windows.misc", "comp.sys.ibm.pc.hardware", "comp.sys.mac.hardware", "comp.windows.x", "misc.forsale", "rec.autos", "rec.motorcycles", "rec.sport.baseball", "rec.sport.hockey", "sci.crypt", "sci.electronics", "sci.med", "sci.space", "soc.religion.christian", "talk.politics.guns", "talk.politics.mideast", "talk.politics.misc", "talk.religion.misc")
        val uniqueLabelsBc = sc.broadcast(uniqueLabels)
    
        val labelNameAndData = fileNameData
          .map { case (f, data) => (f.split("/").takeRight(2).head, data) }
          .mapPartitions {
            itrs =>
              val labelIdMap = uniqueLabelsBc.value.zipWithIndex.toMap
              itrs.map {
                case (labelName, data) => (labelIdMap(labelName), data)
              }
          }
    
        import sqlContext.implicits._
        labelNameAndData.toDF("id", "sentence").cache()
    
      }
    
      def createTrainPpline() = {
        val tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words")
    
        val hashingTF = new HashingTF().setInputCol("words").setOutputCol("rawFeatures")
    
        val idf = new IDF().setInputCol("rawFeatures").setOutputCol("features")
    
        //val vecAssembler = new VectorAssembler().setInputCols(Array("features")).setOutputCol("id")
    
        val nb = new NaiveBayes().setFeaturesCol("features").setLabelCol("id")
    
        new Pipeline().setStages(Array(tokenizer, hashingTF, idf, nb))
      }
    
      val conf = new SparkConf().setMaster("local[2]").setAppName("nb")
        .set("spark.ui.enabled", "false")
      val sc = new SparkContext(conf)
      val sqlContext = new org.apache.spark.sql.SQLContext(sc)
    
      val training = createRawDf("file:////root/work/test/20news-bydate-train/*")
    
      val ppline = createTrainPpline()
      val nbModel = ppline.fit(training)
    
      val test = createRawDf("file:////root/work/test/20news-bydate-test/*")
      val testRes = nbModel.transform(test)
    
      val evaluator = new MulticlassClassificationEvaluator().setLabelCol("id")
      val accuracy = evaluator.evaluate(testRes)
      println("Test Error = " + (1.0 - accuracy))
    
    }
    
  • 相关阅读:
    selenium的持续问题解决
    为什么使用Nginx
    [转]性能测试场景设计深度解析
    [转]CentOS7修改时区的正确姿势
    [转]利用Fiddler模拟恶劣网络环境
    [转]什么是微服务
    [转] WebSocket 教程
    [转] Python实现简单的Web服务器
    shell修改配置文件参数
    [转] linux shell 字符串操作(长度,查找,替换)详解
  • 原文地址:https://www.cnblogs.com/luweiseu/p/7699011.html
Copyright © 2020-2023  润新知