import com.hankcs.hanlp.tokenizer.NLPTokenizer
import org.apache.hadoop.fs.Path
import org.apache.hadoop.io.{LongWritable, Text}
import org.apache.hadoop.mapred.TextInputFormat
import org.apache.log4j.{Level, Logger}
import org.apache.spark.ml.feature.{Word2Vec => SparkWord2Vec}
import org.apache.spark.sql.SparkSession

import scala.collection.JavaConverters._

/**
 * Segments a GBK-encoded news corpus with HanLP, trains a Spark ML Word2Vec model on the
 * segmented text, saves the model, and prints the words nearest to a probe word.
 *
 * Optional arguments (each falls back to the original hard-coded default):
 *   args(0) raw corpus path, args(1) segmented-text output dir, args(2) model output dir,
 *   args(3) probe word.
 *
 * Created by zhen on 2018/11/20.
 */
object Word2Vec {
  Logger.getLogger("org").setLevel(Level.WARN) // reduce log noise from org.* loggers (Spark, Hadoop)

  /** A token consisting solely of ASCII digits; such tokens are dropped before training. */
  private val NumericToken = """^\d+$""".r

  /**
   * Returns the trimmed text between the first '>' and the last '<' of a line containing
   * "content", or None if the line is not such a line or the markers are missing/out of order.
   */
  private def extractContent(row: String): Option[String] =
    if (!row.contains("content")) None
    else {
      val start = row.indexOf(">") + 1
      val end = row.lastIndexOf("<")
      // Guard against lines without a '>' or where '<' precedes '>' (substring would throw).
      if (start > 0 && end >= start) Some(row.substring(start, end).trim) else None
    }

  /** Segments text with HanLP and joins the words longer than one character with spaces. */
  private def segment(text: String): String =
    NLPTokenizer.segment(text).asScala
      .map(_.word.trim)
      .filter(_.length > 1) // drop words shorter than 2 characters
      .mkString(" ")

  /** Splits a segmented line into tokens, dropping empty and purely numeric tokens. */
  private def toTokens(line: String): Array[String] =
    line.split(" ").filter(token => token.nonEmpty && NumericToken.findFirstIn(token).isEmpty)

  /** Recursively deletes `dir` if it exists, so Spark's "output must not exist" check passes. */
  private def deleteIfExists(spark: SparkSession, dir: String): Unit = {
    val path = new Path(dir)
    val fs = path.getFileSystem(spark.sparkContext.hadoopConfiguration)
    if (fs.exists(path)) fs.delete(path, true)
  }

  def main(args: Array[String]): Unit = {
    val trainDataPath = args.lift(0).getOrElse("E:/BDS/newsparkml/src/news_tensite_xml.smarty.dat")
    val segmentedPath = args.lift(1).getOrElse("E:/BDS/newsparkml/src/分词结果")
    val modelPath = args.lift(2).getOrElse("E:/BDS/newsparkml/src/Word2VecModel")
    val probeWord = args.lift(3).getOrElse("中国")

    val spark = SparkSession.builder()
      .appName("Word2Vec")
      .master("local[2]")
      .getOrCreate()

    try {
      val sc = spark.sparkContext

      // Pre-processing: decode each record as GBK and keep only the body of "content" lines.
      val contents = sc
        .hadoopFile(trainDataPath, classOf[TextInputFormat], classOf[LongWritable], classOf[Text])
        // Text.getBytes exposes the backing buffer, which can be longer than the record;
        // only the first getLength bytes are valid.
        .map(pair => new String(pair._2.getBytes, 0, pair._2.getLength, "GBK"))
        .flatMap(row => extractContent(row).toList)

      // Word segmentation; saveAsTextFile fails if the directory already exists.
      deleteIfExists(spark, segmentedPath)
      contents.map(segment).saveAsTextFile(segmentedPath)

      // Reload the segmented corpus as training data: one token array per non-empty line.
      val sentences = sc.textFile(segmentedPath)
        .map(toTokens)
        .filter(_.nonEmpty)
        .map(Tuple1(_))

      val dataFrame = spark.createDataFrame(sentences).toDF("text")
      dataFrame.show(20, truncate = false)

      // Build the Word2Vec estimator.
      val word2Vec = new SparkWord2Vec()
        .setInputCol("text")
        .setOutputCol("result")
        .setVectorSize(50)
        .setNumPartitions(64)

      // Train the model.
      val model = word2Vec.fit(dataFrame)

      // Persist the model; overwrite() lets the job be re-run without manual cleanup.
      model.write.overwrite().save(modelPath)

      // Query: the 10 words the model considers most similar to the probe word.
      model.findSynonyms(probeWord, 10).show(truncate = false)
    } finally {
      spark.stop()
    }
  }
}
分词结果:
分词结果部分数据:
模型:
结果:
分析:
预测结果与训练集数据紧密相关：Word2Vec 根据训练集中词语在上下文窗口内的共现关系学习词向量，词与词之间的相似度由此决定。因此，要想获得较好的预测结果，需要有合适的训练集！