• 机器学习(三):朴素贝叶斯(NB)


    # -- coding: utf-8 --
    from numpy import *
    
    def loadDataSet():
        # 创建单词向量及对应的分类,1代表侮辱性文字,0代表正常言论
        postingList=[['my', 'dog', 'has', 'flea', 'problems', 'help', 'please'],
                     ['maybe', 'not', 'take', 'him', 'to', 'dog', 'park', 'stupid'],
                     ['my', 'dalmation', 'is', 'so', 'cute', 'I', 'love', 'him'],
                     ['stop', 'posting', 'stupid', 'worthless', 'garbage'],
                     ['mr', 'licks', 'ate', 'my', 'steak', 'how', 'to', 'stop', 'him'],
                     ['quit', 'buying', 'worthless', 'dog', 'food', 'stupid']]
        classVec = [0,1,0,1,0,1]
        return postingList,classVec
    
    def createVocabList(dataSet):               # 创建一个过滤dataSet重复数据的表
        vocabSet = set([])                      # 创建一个空集
        for document in dataSet:
            vocabSet = vocabSet | set(document) # 创建两个集合的并集
        return list(vocabSet)
    
    def setOfWords2Vec(vocabList, inputSet):    # 将文档转换成特征向量
        returnVec = [0]*len(vocabList)          # 创建一个长度与不重复词表一样的一维数组,元素默认为0
        for word in inputSet:
            if word in vocabList:               # 若词表单词在文档中出现过,则将元素改为1
                returnVec[vocabList.index(word)] = 1
            else: print "the word: %s is not in my Vocabulary!" % word
        return returnVec
    
    def trainNB0(trainMatrix,trainCategory):
        numTrainDocs = len(trainMatrix)         # 计算训练样本数量
        numWords = len(trainMatrix[0])          # 计算不重复词表中单词数量
        pAbusive = sum(trainCategory)/float(numTrainDocs) # 类别为1的训练样本的概率,即P(Y=c1)
        # 创建一个长度与不重复词表一样的一维数组,计算各单词出现次数,初始化为1
        p0Num = ones(numWords); p1Num = ones(numWords)
        p0Denom = 2.0; p1Denom = 2.0            # 将分母(所有单词量)初始化为2
        for i in range(numTrainDocs):
            if trainCategory[i] == 1:
                p1Num += trainMatrix[i]         # 若类别为1,将相应样本列相加,得该单词在全部文档中出现次数
                p1Denom += sum(trainMatrix[i])  # 计算类别为1的样本的所有单词量
            else:
                p0Num += trainMatrix[i]         # 若类别为0,将相应样本列相加,得该单词在全部文档中出现次数
                p0Denom += sum(trainMatrix[i])  # 计算类别为0的样本的所有单词量
        # 在类别为1的条件下,各单词在文档中出现的概率,并求其对数,即log(P(x=xi|Y=c1))
        p1Vect = log(p1Num/p1Denom)
        # 在类别为0的条件下,各单词在文档中出现的概率,并求其对数,即log(P(x=xi|Y=c0))
        p0Vect = log(p0Num/p0Denom)
        return p0Vect,p1Vect,pAbusive
    
    def classifyNB(vec2Classify, p0Vec, p1Vec, pClass1):
        # 假设传入的测试样本特征为第1,3,4个
        # 则vec2Classify * p0Vec表示为log(P(x=x1|Y=c0))+log(P(x=x3|Y=c0))+log(P(x=x4|Y=c0))
        # 则vec2Classify * p1Vec表示为log(P(x=x1|Y=c1))+log(P(x=x3|Y=c1))+log(P(x=x4|Y=c1))
    
        # p1=log(P(x=x1|Y=c1))+...+log(P(x=xn|Y=c1))+log(P(Y=c1))
        p1 = sum(vec2Classify * p1Vec) + log(pClass1)
        # p0=log(P(x=x1|Y=c0))+...+log(P(x=xn|Y=c0))+log(P(Y=c0))
        p0 = sum(vec2Classify * p0Vec) + log(1.0 - pClass1)
        # 对比p1和p0的大小,大的对应的值及为最终的分类结果
        if p1 > p0:
            return 1
        else:
            return 0
    
    def testingNB():
        listOPosts,listClasses = loadDataSet()     # 获取单词向量及对应分类
        myVocabList = createVocabList(listOPosts)  # 获取不重复的词表(此时假设每个特征同等重要)
        trainMat=[]
        for postinDoc in listOPosts:
            # 为每个单词构建一个特征
            # 输入某文档,输出文档向量,向量为1或0,分别表示词表myVocabList中的单词在输入文档是否出现
            trainMat.append(setOfWords2Vec(myVocabList, postinDoc))
        p0V,p1V,pAb = trainNB0(array(trainMat),array(listClasses))
        testEntry = ['love', 'my', 'dalmation']
        thisDoc = array(setOfWords2Vec(myVocabList, testEntry))
        print testEntry,'classified as: ',classifyNB(thisDoc,p0V,p1V,pAb)
        testEntry = ['stupid', 'garbage']
        thisDoc = array(setOfWords2Vec(myVocabList, testEntry))
        print testEntry,'classified as: ',classifyNB(thisDoc,p0V,p1V,pAb)
    

      
    以上全部内容参考书籍如下:
    李航《统计学习方法》
    Peter Harrington《Machine Learing in Action》
    《概率论与数理统计》高等教育出版社

  • 相关阅读:
    a[::-1]相当于 a[-1:-len(a)-1:-1],也就是从最后一个元素到第一个元素复制一遍。
    +=
    map 和reduce
    赋值语句
    高阶函数
    函数式编程
    迭代器
    如何判断一个对象是可迭代对象呢?方法是通过collections模块的Iterable类型判断:
    ie11升级的过程中遇到的问题以及解决办法
    .csporj 文件部分节点解析
  • 原文地址:https://www.cnblogs.com/pengfeiz/p/11392640.html
Copyright © 2020-2023  润新知