• Mixtures of Gaussians and the EM Algorithm


    Acknowledgement: based on material from Stanford CS229.

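    Generative modeling is itself a kind of unsupervised learning task [1]. Given unlabelled data $\{x^{(1)}, \dots, x^{(n)}\}$, a mixture of Gaussians models each sample as generated in two stages: first a hidden component label $z^{(i)} \in \{1, \dots, m\}$ is drawn with $p(z^{(i)} = j) = \phi_j$, then $x^{(i)} \mid z^{(i)} = j \sim \mathcal{N}(\mu_j, \sigma_j^2)$. The parameters to estimate are the mixing weights $\phi$, the means $\mu$ and the standard deviations $\sigma$.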

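    To estimate the parameters, we can write the log-likelihood as

    $$\ell(\phi, \mu, \sigma) = \sum_{i=1}^{n} \log p(x^{(i)}; \phi, \mu, \sigma) = \sum_{i=1}^{n} \log \sum_{j=1}^{m} p(x^{(i)} \mid z^{(i)} = j; \mu, \sigma)\, p(z^{(i)} = j; \phi),$$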

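    which, written out with the Gaussian density, is

    $$\ell(\phi, \mu, \sigma) = \sum_{i=1}^{n} \log \sum_{j=1}^{m} \phi_j \frac{1}{\sqrt{2\pi\sigma_j^2}} \exp\left(-\frac{(x^{(i)} - \mu_j)^2}{2\sigma_j^2}\right).$$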

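    Because the labels $z^{(i)}$ are unobserved, this likelihood has no closed-form maximizer. The EM algorithm maximizes it iteratively by alternating two steps:

    E-step: $w_j^{(i)} = p(z^{(i)} = j \mid x^{(i)}; \phi, \mu, \sigma) = \dfrac{\phi_j\, \mathcal{N}(x^{(i)}; \mu_j, \sigma_j^2)}{\sum_{l=1}^{m} \phi_l\, \mathcal{N}(x^{(i)}; \mu_l, \sigma_l^2)}$

    M-step: $\phi_j = \dfrac{1}{n}\sum_{i=1}^{n} w_j^{(i)}$, $\quad \mu_j = \dfrac{\sum_{i=1}^{n} w_j^{(i)} x^{(i)}}{\sum_{i=1}^{n} w_j^{(i)}}$, $\quad \sigma_j^2 = \dfrac{\sum_{i=1}^{n} w_j^{(i)} (x^{(i)} - \mu_j)^2}{\sum_{i=1}^{n} w_j^{(i)}}$.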

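    An example is provided here. The data points are drawn from two Gaussian distributions, $\mathcal{N}(0, 1)$ and $\mathcal{N}(2, 1)$ (50 points each), and a two-component mixture is fitted to them with EM.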

     1 import numpy as np
     2 import operator
     3 np.random.seed(0)
     4 x0=np.random.normal(0,1,50)
     5 x0=np.concatenate((x0,np.random.normal(2,1,50)),
     6                  axis=0)
     7 
     8 mus0=np.array([
     9     0,1
    10 ])
    11 sigmas0=np.array([
    12     2,2
    13 ])
    14 def gauss(x,mu,sigma):
    15     """
    16 
    17     :param x:
    18     :param mu:
    19     :param sigma:
    20     :return: pdf(x)
    21     """
    22     # if np.abs((x-mu)/sigma)<1e-5:
    23     #     return
    24     # numerator=np.exp(
    25     #     -(x-mu)**2/(2*sigma**2)
    26     # )
    27     numerator=np.exp(
    28         -0.5*((x-mu)/sigma)**2
    29     )
    30     denominator=np.sqrt(2*np.pi*sigma**2)
    31     return numerator/denominator
    32 def e_step(mus=mus0,sigmas=sigmas0,x=x0,priors=np.ones(len(mus0))/len(mus0)):
    33     """
    34 
    35     :param mus: gaussian centers, an array of shape (m,)
    36     :param sigmas: gaussian standard deviations, an array of shape (m,)
    37     :param x: n samples with no labels
    38     :return: m by n array, where m is # classes
    39     """
    40     assert len(mus)==len(sigmas),"mus and sigmas doesn't have the same length"
    41     m=len(mus)
    42     n=len(x)
    43     w=np.zeros(shape=(m,n))
    44     for j in range(m):
    45         for i in range(n):
    46             w[j][i]=gauss(x=x[i],mu=mus[j],sigma=sigmas[j])*priors[j]
    47     w_sum_wrt_j=np.sum(w,axis=0)#note j is the row index
    48     for j in range(m):
    49         w[j,:]=w[j,:]/w_sum_wrt_j
    50     return w
    51 def m_step(w,current_mus,x=x0):
    52     """
    53 
    54     :param w: m by n array, where m is # classes
    55     :return: mus: gaussian centers, an array of shape (m,)
    56              sigmas: gaussian standard deviations, an array of shape (m,)
    57     """
    58     m,n=w.shape
    59     mus=np.zeros(shape=(m))
    60     sigmas=np.zeros(shape=(m))
    61     for j in range(m):
    62         mus[j]=np.dot(
    63             w[j,:],x
    64         )
    65     mus/=np.sum(w,axis=1)
    66     for j in range(m):
    67         sigmas[j]=np.sqrt(np.dot(
    68             w[j, :], (x-current_mus[j])**2
    69         ))
    70     sigmas/=np.sqrt(np.sum(w,axis=1))
    71 
    72     priors=np.zeros(shape=(len(mus)))
    73     for i in range(n):
    74         tmp=list(map(
    75             gauss,[x[i]]*m,mus,sigmas
    76         ))
    77         tmpmaxindex,tmpmax=max(
    78             enumerate(tmp),key=operator.itemgetter(1)
    79         )
    80         # print(tmp)
    81         # print(tmpmaxindex)
    82         priors[tmpmaxindex]+=1/n
    83     return mus,sigmas,priors
    84 def solve(x=x0,priors=np.ones(len(mus0))/len(mus0)):
    85     # print("priors={}".format(priors))
    86     mus=mus0
    87     sigmas=sigmas0
    88     for k in range(500):
    89         w=e_step(mus=mus,sigmas=sigmas,x=x,priors=priors)
    90         mus,sigmas,priors=m_step(w,current_mus=mus,x=x0)
    91         print("k={},mus={},sigmas={},priors={}".format(k,mus,sigmas,priors))
    92 
    93 if __name__ == '__main__':
    94     solve()

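    The log below shows the first 100 iterations (the loop in `solve` runs 500). The estimated means settle near −0.016 and 1.874, an approximation of the true means 0 and 2.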

      1 /usr/local/bin/python3.5 /home/csdl/review/fulcrum/gmm/gmm.py
      2 k=0,mus=[ 0.81734343  1.27122747],sigmas=[ 1.60216343  1.33905931],priors=[ 0.32  0.68]
      3 k=1,mus=[ 0.73393663  1.21263431],sigmas=[ 1.48989073  1.27140643],priors=[ 0.35  0.65]
      4 k=2,mus=[ 0.72025041  1.24207148],sigmas=[ 1.47760392  1.25840835],priors=[ 0.36  0.64]
      5 k=3,mus=[ 0.69405554  1.2656654 ],sigmas=[ 1.47453155  1.2480128 ],priors=[ 0.36  0.64]
      6 k=4,mus=[ 0.65993336  1.28545741],sigmas=[ 1.47454151  1.238417  ],priors=[ 0.36  0.64]
      7 k=5,mus=[ 0.62512005  1.30527053],sigmas=[ 1.4739642   1.22830782],priors=[ 0.36  0.64]
      8 k=6,mus=[ 0.59009448  1.32522573],sigmas=[ 1.47230468  1.21788024],priors=[ 0.36  0.64]
      9 k=7,mus=[ 0.55504913  1.34523959],sigmas=[ 1.46932309  1.20732602],priors=[ 0.36  0.64]
     10 k=8,mus=[ 0.52016003  1.36521637],sigmas=[ 1.46489424  1.19678812],priors=[ 0.36  0.64]
     11 k=9,mus=[ 0.4855794   1.38507002],sigmas=[ 1.45897578  1.18636451],priors=[ 0.36  0.64]
     12 k=10,mus=[ 0.45142496  1.4047313 ],sigmas=[ 1.45158393  1.17611449],priors=[ 0.36  0.64]
     13 k=11,mus=[ 0.41777707  1.42414967],sigmas=[ 1.44277296  1.16606539],priors=[ 0.36  0.64]
     14 k=12,mus=[ 0.38468177  1.44329208],sigmas=[ 1.43261873  1.15621962],priors=[ 0.37  0.63]
     15 k=13,mus=[ 0.36595587  1.46867892],sigmas=[ 1.41990091  1.14409883],priors=[ 0.37  0.63]
     16 k=14,mus=[ 0.33654056  1.48870368],sigmas=[ 1.4082571   1.13343039],priors=[ 0.37  0.63]
     17 k=15,mus=[ 0.30597566  1.50763142],sigmas=[ 1.39543174  1.12335036],priors=[ 0.37  0.63]
     18 k=16,mus=[ 0.27568252  1.52609593],sigmas=[ 1.38137316  1.11360905],priors=[ 0.37  0.63]
     19 k=17,mus=[ 0.24588996  1.54419117],sigmas=[ 1.36625212  1.10407072],priors=[ 0.37  0.63]
     20 k=18,mus=[ 0.21664299  1.56192385],sigmas=[ 1.35022455  1.09464684],priors=[ 0.37  0.63]
     21 k=19,mus=[ 0.18796432  1.57928065],sigmas=[ 1.33342219  1.0852798 ],priors=[ 0.37  0.63]
     22 k=20,mus=[ 0.1598861   1.59623644],sigmas=[ 1.31596643  1.07593506],priors=[ 0.37  0.63]
     23 k=21,mus=[ 0.13245872  1.61275414],sigmas=[ 1.2979812   1.06659735],priors=[ 0.39  0.61]
     24 k=22,mus=[ 0.13549936  1.64420575],sigmas=[ 1.28163006  1.05238461],priors=[ 0.4  0.6]
     25 k=23,mus=[ 0.13362832  1.67212388],sigmas=[ 1.26763647  1.03761078],priors=[ 0.4  0.6]
     26 k=24,mus=[ 0.11750175  1.69125511],sigmas=[ 1.25330002  1.02525138],priors=[ 0.41  0.59]
     27 k=25,mus=[ 0.11224826  1.71494289],sigmas=[ 1.23950504  1.01230038],priors=[ 0.42  0.58]
     28 k=26,mus=[ 0.11153847  1.7395662 ],sigmas=[ 1.22728752  0.99889453],priors=[ 0.42  0.58]
     29 k=27,mus=[ 0.0999276   1.75644918],sigmas=[ 1.21474604  0.98770556],priors=[ 0.43  0.57]
     30 k=28,mus=[ 0.09911993  1.77770261],sigmas=[ 1.20375615  0.97601043],priors=[ 0.43  0.57]
     31 k=29,mus=[ 0.08991339  1.79234269],sigmas=[ 1.19274093  0.96620904],priors=[ 0.43  0.57]
     32 k=30,mus=[ 0.07854133  1.80401995],sigmas=[ 1.18163507  0.95803992],priors=[ 0.43  0.57]
     33 k=31,mus=[ 0.06708472  1.81391145],sigmas=[ 1.1708709  0.9510143],priors=[ 0.43  0.57]
     34 k=32,mus=[ 0.05629168  1.82248392],sigmas=[ 1.16077864  0.94483468],priors=[ 0.43  0.57]
     35 k=33,mus=[ 0.04644144  1.8299628 ],sigmas=[ 1.15153082  0.93934709],priors=[ 0.43  0.57]
     36 k=34,mus=[ 0.03761987  1.83648519],sigmas=[ 1.14319449  0.93447086],priors=[ 0.43  0.57]
     37 k=35,mus=[ 0.02982246  1.84215374],sigmas=[ 1.13577403  0.93015559],priors=[ 0.43  0.57]
     38 k=36,mus=[ 0.02299928  1.84705693],sigmas=[ 1.12923679  0.92636035],priors=[ 0.43  0.57]
     39 k=37,mus=[ 0.01707735  1.85127645],sigmas=[ 1.12352817  0.92304533],priors=[ 0.43  0.57]
     40 k=38,mus=[ 0.01197298  1.85488949],sigmas=[ 1.11858109  0.92016939],priors=[ 0.43  0.57]
     41 k=39,mus=[ 0.00759925  1.85796875],sigmas=[ 1.11432244  0.91769023],priors=[ 0.43  0.57]
     42 k=40,mus=[ 0.00387068  1.86058202],sigmas=[ 1.11067765  0.91556544],priors=[ 0.43  0.57]
     43 k=41,mus=[  7.06082311e-04   1.86279150e+00],sigmas=[ 1.10757393  0.91375369],priors=[ 0.43  0.57]
     44 k=42,mus=[-0.00196965  1.86465346],sigmas=[ 1.10494246  0.91221583],priors=[ 0.43  0.57]
     45 k=43,mus=[-0.00422464  1.86621814],sigmas=[ 1.10271971  0.91091554],priors=[ 0.43  0.57]
     46 k=44,mus=[-0.00611974  1.86752982],sigmas=[ 1.10084819  0.9098198 ],priors=[ 0.43  0.57]
     47 k=45,mus=[-0.00770859  1.86862714],sigmas=[ 1.09927671  0.90889909],priors=[ 0.43  0.57]
     48 k=46,mus=[-0.00903796  1.86954354],sigmas=[ 1.09796019  0.90812732],priors=[ 0.43  0.57]
     49 k=47,mus=[-0.01014832  1.87030773],sigmas=[ 1.09685943  0.90748172],priors=[ 0.43  0.57]
     50 k=48,mus=[-0.01107441  1.87094421],sigmas=[ 1.09594057  0.90694261],priors=[ 0.43  0.57]
     51 k=49,mus=[-0.01184586  1.87147378],sigmas=[ 1.09517461  0.90649307],priors=[ 0.43  0.57]
     52 k=50,mus=[-0.01248783  1.87191401],sigmas=[ 1.09453685  0.90611867],priors=[ 0.43  0.57]
     53 k=51,mus=[-0.01302159  1.87227973],sigmas=[ 1.09400634  0.90580718],priors=[ 0.43  0.57]
     54 k=52,mus=[-0.01346505  1.87258336],sigmas=[ 1.09356541  0.90554823],priors=[ 0.43  0.57]
     55 k=53,mus=[-0.01383328  1.87283531],sigmas=[ 1.09319917  0.90533313],priors=[ 0.43  0.57]
     56 k=54,mus=[-0.01413888  1.87304431],sigmas=[ 1.09289515  0.90515454],priors=[ 0.43  0.57]
     57 k=55,mus=[-0.0143924   1.87321761],sigmas=[ 1.09264288  0.90500635],priors=[ 0.43  0.57]
     58 k=56,mus=[-0.01460264  1.87336127],sigmas=[ 1.09243365  0.90488343],priors=[ 0.43  0.57]
     59 k=57,mus=[-0.01477693  1.87348033],sigmas=[ 1.09226016  0.9047815 ],priors=[ 0.43  0.57]
     60 k=58,mus=[-0.01492139  1.87357899],sigmas=[ 1.09211635  0.90469701],priors=[ 0.43  0.57]
     61 k=59,mus=[-0.0150411   1.87366073],sigmas=[ 1.09199717  0.90462698],priors=[ 0.43  0.57]
     62 k=60,mus=[-0.01514028  1.87372844],sigmas=[ 1.09189842  0.90456896],priors=[ 0.43  0.57]
     63 k=61,mus=[-0.01522245  1.87378452],sigmas=[ 1.09181661  0.90452088],priors=[ 0.43  0.57]
     64 k=62,mus=[-0.01529051  1.87383097],sigmas=[ 1.09174884  0.90448106],priors=[ 0.43  0.57]
     65 k=63,mus=[-0.01534687  1.87386944],sigmas=[ 1.0916927   0.90444807],priors=[ 0.43  0.57]
     66 k=64,mus=[-0.01539356  1.87390129],sigmas=[ 1.09164621  0.90442075],priors=[ 0.43  0.57]
     67 k=65,mus=[-0.01543222  1.87392767],sigmas=[ 1.09160771  0.90439813],priors=[ 0.43  0.57]
     68 k=66,mus=[-0.01546423  1.87394951],sigmas=[ 1.09157583  0.90437939],priors=[ 0.43  0.57]
     69 k=67,mus=[-0.01549074  1.87396759],sigmas=[ 1.09154943  0.90436388],priors=[ 0.43  0.57]
     70 k=68,mus=[-0.01551269  1.87398257],sigmas=[ 1.09152757  0.90435103],priors=[ 0.43  0.57]
     71 k=69,mus=[-0.01553086  1.87399496],sigmas=[ 1.09150947  0.9043404 ],priors=[ 0.43  0.57]
     72 k=70,mus=[-0.0155459   1.87400523],sigmas=[ 1.09149449  0.90433159],priors=[ 0.43  0.57]
     73 k=71,mus=[-0.01555836  1.87401373],sigmas=[ 1.09148208  0.9043243 ],priors=[ 0.43  0.57]
     74 k=72,mus=[-0.01556868  1.87402076],sigmas=[ 1.09147181  0.90431826],priors=[ 0.43  0.57]
     75 k=73,mus=[-0.01557722  1.87402659],sigmas=[ 1.0914633   0.90431327],priors=[ 0.43  0.57]
     76 k=74,mus=[-0.01558428  1.87403141],sigmas=[ 1.09145626  0.90430913],priors=[ 0.43  0.57]
     77 k=75,mus=[-0.01559014  1.8740354 ],sigmas=[ 1.09145043  0.9043057 ],priors=[ 0.43  0.57]
     78 k=76,mus=[-0.01559498  1.87403871],sigmas=[ 1.09144561  0.90430287],priors=[ 0.43  0.57]
     79 k=77,mus=[-0.01559899  1.87404144],sigmas=[ 1.09144161  0.90430052],priors=[ 0.43  0.57]
     80 k=78,mus=[-0.01560232  1.87404371],sigmas=[ 1.0914383   0.90429857],priors=[ 0.43  0.57]
     81 k=79,mus=[-0.01560506  1.87404558],sigmas=[ 1.09143556  0.90429696],priors=[ 0.43  0.57]
     82 k=80,mus=[-0.01560734  1.87404714],sigmas=[ 1.0914333   0.90429563],priors=[ 0.43  0.57]
     83 k=81,mus=[-0.01560923  1.87404842],sigmas=[ 1.09143142  0.90429453],priors=[ 0.43  0.57]
     84 k=82,mus=[-0.01561079  1.87404948],sigmas=[ 1.09142987  0.90429362],priors=[ 0.43  0.57]
     85 k=83,mus=[-0.01561208  1.87405037],sigmas=[ 1.09142858  0.90429286],priors=[ 0.43  0.57]
     86 k=84,mus=[-0.01561315  1.8740511 ],sigmas=[ 1.09142751  0.90429223],priors=[ 0.43  0.57]
     87 k=85,mus=[-0.01561403  1.8740517 ],sigmas=[ 1.09142663  0.90429172],priors=[ 0.43  0.57]
     88 k=86,mus=[-0.01561476  1.8740522 ],sigmas=[ 1.0914259   0.90429129],priors=[ 0.43  0.57]
     89 k=87,mus=[-0.01561537  1.87405261],sigmas=[ 1.0914253   0.90429093],priors=[ 0.43  0.57]
     90 k=88,mus=[-0.01561587  1.87405295],sigmas=[ 1.0914248   0.90429064],priors=[ 0.43  0.57]
     91 k=89,mus=[-0.01561629  1.87405324],sigmas=[ 1.09142438  0.90429039],priors=[ 0.43  0.57]
     92 k=90,mus=[-0.01561663  1.87405347],sigmas=[ 1.09142404  0.90429019],priors=[ 0.43  0.57]
     93 k=91,mus=[-0.01561692  1.87405367],sigmas=[ 1.09142376  0.90429003],priors=[ 0.43  0.57]
     94 k=92,mus=[-0.01561715  1.87405383],sigmas=[ 1.09142352  0.90428989],priors=[ 0.43  0.57]
     95 k=93,mus=[-0.01561735  1.87405396],sigmas=[ 1.09142333  0.90428977],priors=[ 0.43  0.57]
     96 k=94,mus=[-0.01561751  1.87405407],sigmas=[ 1.09142317  0.90428968],priors=[ 0.43  0.57]
     97 k=95,mus=[-0.01561764  1.87405416],sigmas=[ 1.09142303  0.9042896 ],priors=[ 0.43  0.57]
     98 k=96,mus=[-0.01561775  1.87405424],sigmas=[ 1.09142292  0.90428954],priors=[ 0.43  0.57]
     99 k=97,mus=[-0.01561785  1.8740543 ],sigmas=[ 1.09142283  0.90428948],priors=[ 0.43  0.57]
    100 k=98,mus=[-0.01561792  1.87405435],sigmas=[ 1.09142276  0.90428944],priors=[ 0.43  0.57]
    101 k=99,mus=[-0.01561799  1.8740544 ],sigmas=[ 1.09142269  0.9042894 ],priors=[ 0.43  0.57]
    102 
    103 Process finished with exit code 0

      In addition, a scikit-learn example can be found at http://scikit-learn.org/stable/modules/mixture.html

    [1] Ian Goodfellow. https://www.quora.com/Why-could-generative-models-help-with-unsupervised-learning/answer/Ian-Goodfellow?srid=hTUVm

  • 相关阅读:
    Sharepoint 2007 Forms认证与File Not Found错误
    完全控制SharePoint站点菜单(Get full control of SharePoint ActionMenus) Part 1
    从WSS 3.0到MOSS 2007
    如何备份sharepoint中的文档库?
    图片与文本的对齐方式
    backgroundimage 背景图片的设置
    css中三种隐藏方式
    font(字体)所使用的属性
    display属性
    margin中的bug解决方法
  • 原文地址:https://www.cnblogs.com/cxxszz/p/8313163.html
Copyright © 2020-2023  润新知