▶ EM 算法的引入,三硬币问题,体验一下不同初始值对收敛点的影响
● 代码
"""Three-coin problem: EM estimates from different starting points.

Coin A decides which of two other coins is thrown (heads -> coin B,
tails -> coin C); only the result of that second throw is observed.
The script prints the parameters EM ends up with for several
(true value, initial value) combinations.
"""
import numpy as np

try:
    # Kept from the original listing; nothing below needs matplotlib, so a
    # missing installation must not prevent the numeric code from running.
    import matplotlib.pyplot as plt
    from matplotlib.patches import Rectangle
except ImportError:
    plt = Rectangle = None

dataSize = 1000
trainDataRatio = 0.3
defaultTurn = 20
epsilon = 1E-10
randomSeed = 103


def dataSplit(dataY, part):
    """Split ``dataY`` at index ``part`` into a (train, test) pair."""
    return dataY[:part], dataY[part:]


def createData(realA, realB, realC, count=dataSize):
    """Simulate ``count`` observations of the three-coin experiment.

    Args:
        realA: probability that coin A shows heads, i.e. that coin B is thrown.
        realB: probability that coin B shows heads.
        realC: probability that coin C shows heads.
        count: number of observations to generate.

    Returns:
        Integer array of length ``count``; 1 means the thrown coin showed heads.

    The global NumPy generator is re-seeded with ``randomSeed`` on every call,
    so repeated calls with the same arguments return the same data.
    """
    np.random.seed(randomSeed)
    # ``rand() < p`` is true with probability p.  The original used ``>`` for
    # coins B and C, which produced heads with probability 1 - p and so
    # disagreed with the model fitted by ``em``.
    pickB = np.random.rand(count) < realA
    headsB = np.random.rand(count) < realB
    headsC = np.random.rand(count) < realC
    return np.where(pickB, headsB, headsC).astype(int)


def em(dataY, initialA, initialB, initialC, turn=defaultTurn):
    """Run ``turn`` EM iterations for the three-coin model.

    Args:
        dataY: 0/1 observations.
        initialA, initialB, initialC: starting values for P(coin B is thrown),
            P(heads | coin B) and P(heads | coin C).
        turn: number of iterations.

    Returns:
        Tuple ``(a, b, c)`` with the estimates after the last iteration.  For
        empty ``dataY`` the initial values are returned unchanged.
    """
    dataY = np.asarray(dataY)
    count = len(dataY)
    a, b, c = float(initialA), float(initialB), float(initialC)
    if count == 0:
        return a, b, c
    sumY = np.sum(dataY)

    for _ in range(turn):
        # E-step: posterior probability that each observation came from coin B.
        likeB = a * b ** dataY * (1 - b) ** (1 - dataY)
        likeC = (1 - a) * c ** dataY * (1 - c) ** (1 - dataY)
        total = likeB + likeC
        # If both coins give an observation zero likelihood the posterior is
        # 0/0; fall back to the prior ``a`` instead of propagating NaN.
        p = np.divide(likeB, total, out=np.full(total.shape, float(a)), where=total > 0)

        # M-step.  A coin that received no weight keeps its previous estimate
        # rather than being set to 0/0.
        sumP = np.sum(p)
        sumPY = np.sum(p * dataY)
        a = sumP / count
        if sumP > 0:
            b = sumPY / sumP
        if count - sumP > 0:
            c = (sumY - sumPY) / (count - sumP)
    return a, b, c


def test(realA, realB, realC, initialA, initialB, initialC):
    """Generate data for the true parameters, run EM from the given start, print both."""
    Y = createData(realA, realB, realC)

    para = em(Y, initialA, initialB, initialC)

    print("real=(%.3f, %.3f, %.3f),initial=(%.3f,%.3f,%.3f),train=(%.3f,%.3f,%.3f)" % (
        realA, realB, realC, initialA, initialB, initialC, para[0], para[1], para[2]))


# ``test`` is a demo driver, not a unit test; stop test runners from collecting it.
test.__test__ = False


if __name__ == '__main__':
    test(0.5, 0.5, 0.5, 0.5, 0.5, 0.5)
    test(0.5, 0.5, 0.5, epsilon, epsilon, epsilon)
    test(0.5, 0.5, 0.5, 0.5, epsilon, epsilon)
    test(0.5, 0.5, 0.5, epsilon, 0.5, epsilon)
    test(0.5, 0.5, 0.5, epsilon, epsilon, 0.5)
    test(0.5, 0.5, 0.5, 1.0 - epsilon, epsilon, epsilon)
    test(0.5, 0.5, 0.5, epsilon, 1.0 - epsilon, epsilon)
    test(0.5, 0.5, 0.5, epsilon, epsilon, 1.0 - epsilon)
    test(0.5, 0.5, 0.5, 1.0 - epsilon, 1.0 - epsilon, 1.0 - epsilon)

    test(0.4, 0.5, 0.6, 0.4, 0.5, 0.6)
    test(0.4, 0.5, 0.6, epsilon, epsilon, epsilon)
    # NOTE(review): the next three calls repeat earlier (0.5, 0.5, 0.5) cases;
    # possibly (0.4, 0.5, 0.6) was intended - confirm with the author.
    test(0.5, 0.5, 0.5, 0.5, epsilon, epsilon)
    test(0.5, 0.5, 0.5, epsilon, 0.5, epsilon)
    test(0.5, 0.5, 0.5, epsilon, epsilon, 0.5)
● 输出结果,从不同的真实值和初始值得到不同的收敛点
real=(0.500, 0.500, 0.500),initial=(0.500,0.500,0.500),train=(0.500,0.516,0.516) real=(0.500, 0.500, 0.500),initial=(0.000,0.000,0.000),train=(0.000,0.516,0.516) real=(0.500, 0.500, 0.500),initial=(0.500,0.000,0.000),train=(0.500,0.516,0.516) real=(0.500, 0.500, 0.500),initial=(0.000,0.500,0.000),train=(0.172,1.000,0.415) real=(0.500, 0.500, 0.500),initial=(0.000,0.000,0.500),train=(0.000,0.000,0.516) real=(0.500, 0.500, 0.500),initial=(1.000,0.000,0.000),train=(1.000,0.516,0.516) real=(0.500, 0.500, 0.500),initial=(0.000,1.000,0.000),train=(0.258,1.000,0.348) real=(0.500, 0.500, 0.500),initial=(0.000,0.000,1.000),train=(0.242,0.000,0.681) real=(0.500, 0.500, 0.500),initial=(1.000,1.000,1.000),train=(1.000,0.516,0.516) real=(0.400, 0.500, 0.600),initial=(0.400,0.500,0.600),train=(0.409,0.406,0.506) real=(0.400, 0.500, 0.600),initial=(0.000,0.000,0.000),train=(0.000,0.465,0.465) real=(0.500, 0.500, 0.500),initial=(0.500,0.000,0.000),train=(0.500,0.516,0.516) real=(0.500, 0.500, 0.500),initial=(0.000,0.500,0.000),train=(0.172,1.000,0.415) real=(0.500, 0.500, 0.500),initial=(0.000,0.000,0.500),train=(0.000,0.000,0.516)
● 画图,散点位置表示初始取值,散点颜色 RGB 值表示收敛点取值。各图依次为:(真实值 ( 0.5,0.5,0.5 ),初始间隔 0.1,迭代 20 次),(真实值 ( 0.5,0.5,0.5 ),初始间隔 0.1,迭代 100 次),(真实值 ( 0.5,0.5,0.5 ),初始间隔 0.05,迭代 20 次),(真实值 ( 0.3,0.6,0.8 ),初始间隔 0.1,迭代 20 次)。可见:① 迭代 20 次以后就基本稳定了,更多次数迭代没有明显影响;② 随着初始点的连续移动,收敛点的取值也连续漂移,没有出现明显断层;③ 图中存在色彩饱和度较高的散点,说明收敛点并不能向真实值点明显靠拢,甚至有可能保持极端取值;④ 真实值点对收敛点在整个空间上的取值有影响(废话)
● 画图脚本
"""Three-coin problem: plot where EM converges for a grid of starting points.

Every grid point is one EM start (a, b, c); the marker position is the
start and the marker's RGB colour is the estimate EM reached from it.
"""
import numpy as np

dataSize = 1000
trainDataRatio = 0.3
defaultTurn = 20
epsilon = 1E-10
randomSeed = 103


def dataSplit(dataY, part):
    """Split ``dataY`` at index ``part`` into a (train, test) pair."""
    return dataY[:part], dataY[part:]


def myColor(x):
    """Map ``x`` in [0, 1] to an ``[r, g, b]`` list on a blue-cyan-green-yellow-red ramp.

    Values above 1 map to black.  Only referenced by the alternative colouring
    mentioned in ``test``.
    """
    r = np.select([x < 1/2, x < 3/4, x <= 1, True], [0, 4 * x - 2, 1, 0])
    g = np.select([x < 1/4, x < 3/4, x <= 1, True], [4 * x, 1, 4 - 4 * x, 0])
    b = np.select([x < 1/4, x < 1/2, x <= 1, True], [1, 2 - 4 * x, 0, 0])
    return [r, g, b]


def createData(realA, realB, realC, count=dataSize):
    """Simulate ``count`` observations of the three-coin experiment.

    Args:
        realA: probability that coin A shows heads, i.e. that coin B is thrown.
        realB: probability that coin B shows heads.
        realC: probability that coin C shows heads.
        count: number of observations to generate.

    Returns:
        Integer array of length ``count``; 1 means the thrown coin showed heads.

    The global NumPy generator is re-seeded with ``randomSeed`` on every call.
    """
    np.random.seed(randomSeed)
    # ``rand() < p`` is true with probability p.  The original used ``>`` for
    # coins B and C, which produced heads with probability 1 - p and so
    # disagreed with the model fitted by ``em``.
    pickB = np.random.rand(count) < realA
    headsB = np.random.rand(count) < realB
    headsC = np.random.rand(count) < realC
    return np.where(pickB, headsB, headsC).astype(int)


def em(dataY, initialA, initialB, initialC, turn=defaultTurn):
    """Run ``turn`` EM iterations for the three-coin model.

    Args:
        dataY: 0/1 observations.
        initialA, initialB, initialC: starting values for P(coin B is thrown),
            P(heads | coin B) and P(heads | coin C).
        turn: number of iterations.

    Returns:
        Tuple ``(a, b, c)`` with the estimates after the last iteration.  For
        empty ``dataY`` the initial values are returned unchanged.
    """
    dataY = np.asarray(dataY)
    count = len(dataY)
    a, b, c = float(initialA), float(initialB), float(initialC)
    if count == 0:
        return a, b, c
    sumY = np.sum(dataY)

    for _ in range(turn):
        # E-step: posterior probability that each observation came from coin B.
        likeB = a * b ** dataY * (1 - b) ** (1 - dataY)
        likeC = (1 - a) * c ** dataY * (1 - c) ** (1 - dataY)
        total = likeB + likeC
        # If both coins give an observation zero likelihood the posterior is
        # 0/0; fall back to the prior ``a`` instead of propagating NaN.
        p = np.divide(likeB, total, out=np.full(total.shape, float(a)), where=total > 0)

        # M-step.  A coin that received no weight keeps its previous estimate
        # rather than being set to 0/0.
        sumP = np.sum(p)
        sumPY = np.sum(p * dataY)
        a = sumP / count
        if sumP > 0:
            b = sumPY / sumP
        if count - sumP > 0:
            c = (sumY - sumPY) / (count - sumP)
    return a, b, c


def test(realA, realB, realC, step=0.1, turn=defaultTurn, savePrefix="R:\\"):
    """Plot the EM result for a cubic grid of starting points and save it as PNG.

    Args:
        realA, realB, realC: true parameters used to generate the data.
        step: grid spacing; starts are ``np.arange(step, 1.0, step)`` per axis.
        turn: number of EM iterations per starting point.
        savePrefix: text placed in front of the file name
            ``(<realA>,<realB>,<realC>).png``.
    """
    # Imported here so the numeric helpers above work without matplotlib.
    import matplotlib.pyplot as plt
    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401  (registers '3d' on old matplotlib)

    dataY = createData(realA, realB, realC)
    axis = np.arange(step, 1.00, step)
    XX, YY, ZZ = np.meshgrid(axis, axis, axis)
    # Alternative start set used for some figures - a slanted plane:
    #   XX, YY = np.meshgrid(np.arange(0.05, 1.00, 0.05), np.arange(0.05, 1.00, 0.05))
    #   ZZ = (9 - 5 * XX - 4 * YY) / 12

    fig = plt.figure(figsize=(10, 8))
    # ``Axes3D(fig)`` is not attached to the figure on current matplotlib;
    # ``add_axes`` creates the same full-figure 3D axes and registers it.
    ax = fig.add_axes([0, 0, 1, 1], projection='3d')
    ax.set_xlim3d(0.0, 1.0)
    ax.set_ylim3d(0.0, 1.0)
    ax.set_zlim3d(0.0, 1.0)
    ax.set_xlabel('X', fontdict={'size': 15, 'color': 'r'})
    ax.set_ylabel('Y', fontdict={'size': 15, 'color': 'g'})
    ax.set_zlabel('Z', fontdict={'size': 15, 'color': 'b'})

    starts = np.column_stack([XX.ravel(), YY.ravel(), ZZ.ravel()])
    # Colour = converged (a, b, c), clipped so it is a valid RGB triple.
    colors = np.array([np.clip(em(dataY, s[0], s[1], s[2], turn), 0, 1) for s in starts])
    # Alternative colouring by squared distance to the true parameters:
    #   myColor(np.sum((np.array(para) - np.array([realA, realB, realC])) ** 2))
    ax.scatter(starts[:, 0], starts[:, 1], starts[:, 2], c=colors, s=20, label="P")

    fig.savefig(savePrefix + "(" + str(round(realA, 3)) + "," + str(round(realB, 3)) + ","
                + str(round(realC, 3)) + ").png")
    plt.close(fig)


# ``test`` is a demo driver, not a unit test; stop test runners from collecting it.
test.__test__ = False


if __name__ == '__main__':
    test(0.5, 0.5, 0.5)
● EM 算法用于高斯混合模型,代码
"""EM for a one-dimensional Gaussian mixture model on simulated data."""
import numpy as np
import scipy as sp

try:
    # Kept from the original listing; nothing below needs matplotlib, so a
    # missing installation must not prevent the numeric code from running.
    import matplotlib.pyplot as plt
    from matplotlib.patches import Rectangle
except ImportError:
    plt = Rectangle = None

dataSize = 1000
trainDataRatio = 0.3
defaultTurn = 100
epsilon = 1E-5
randomSeed = 103


def dataSplit(dataY, part):
    """Split ``dataY`` at index ``part`` into a (train, test) pair."""
    return dataY[:part], dataY[part:]


def normalCDF(x, μList, σList):
    """Return the normal probability *density* of ``x`` for means ``μList`` and stds ``σList``.

    Despite the name this is the PDF, not the CDF; the name is kept so
    existing callers keep working.  Arguments broadcast against each other.
    """
    return np.exp(-(x - μList)**2 / (2 * σList**2)) / (np.sqrt(2 * np.pi) * σList)


def createData(ndistributionCount, count=dataSize):
    """Draw ``count`` samples from a random Gaussian mixture.

    Args:
        ndistributionCount: number of mixture components.
        count: number of samples.

    Returns:
        ``(data, mu, sigma, alpha)``: the samples and the component means,
        standard deviations and weights used to generate them.

    The global NumPy generator is re-seeded with ``randomSeed`` on every call.
    """
    np.random.seed(randomSeed)
    X = np.random.randn(count, ndistributionCount)
    mu = np.cumsum(np.random.rand(ndistributionCount))
    sigma = np.random.rand(ndistributionCount)
    alpha = np.random.rand(ndistributionCount)
    alpha /= np.sum(alpha)
    # A mixture sample comes from ONE component, chosen with probability alpha.
    # The original returned sum_k alpha_k * (sigma_k * X_k + mu_k), which is a
    # single Gaussian and therefore cannot be separated into components by EM.
    component = np.random.choice(ndistributionCount, size=count, p=alpha)
    rows = np.arange(count)
    data = sigma[component] * X[rows, component] + mu[component]
    return data, mu, sigma, alpha


def em(dataY, ndistributionCount, turn=defaultTurn):
    """Fit a Gaussian mixture to ``dataY`` with ``turn`` EM iterations.

    Args:
        dataY: one-dimensional sample.
        ndistributionCount: number of mixture components to fit.
        turn: number of iterations.

    Returns:
        ``(mu, sigma, alpha)`` arrays of length ``ndistributionCount``.

    Raises:
        ValueError: if ``dataY`` is empty.
    """
    dataY = np.asarray(dataY, dtype=float)
    count = len(dataY)
    if count == 0:
        raise ValueError("dataY must contain at least one observation")
    k = ndistributionCount
    Y = dataY[:, np.newaxis]  # column vector; broadcasts against the k parameters

    # Data-driven start: means spread over the quantiles, common std, equal
    # weights.  A start drawn from [0, 1) regardless of the data scale can
    # leave components with no responsibility at all.
    mu = np.quantile(dataY, (np.arange(k) + 0.5) / k)
    sigma = np.full(k, max(np.std(dataY), epsilon))
    alpha = np.full(k, 1.0 / k)
    tiny = np.finfo(float).tiny  # guards the divisions below against 0/0

    for _ in range(turn):
        # E-step: responsibility of every component for every sample.
        weighted = alpha * normalCDF(Y, mu, sigma)
        p = weighted / np.maximum(np.sum(weighted, axis=1, keepdims=True), tiny)
        # M-step.
        colSum = np.sum(p, axis=0)
        sumP = np.maximum(colSum, tiny)
        mu = np.sum(p * Y, axis=0) / sumP
        # Floor sigma so a component collapsing onto one point cannot reach 0.
        sigma = np.maximum(np.sqrt(np.sum(p * (Y - mu) ** 2, axis=0) / sumP), epsilon)
        alpha = colSum / count
    return mu, sigma, alpha


def test(ndistributionCount):
    """Generate a mixture with the given number of components, fit it, print true vs fitted."""
    dataY, mu, sigma, alpha = createData(ndistributionCount)

    muOut, sigmaOut, alphaOut = em(dataY, ndistributionCount)

    print("ndistributionCount = " + str(ndistributionCount))
    print("originμ = ", mu)
    print("train μ = ", muOut)
    print("originσ = ", sigma)
    print("train σ = ", sigmaOut)
    print("originα = ", alpha)
    print("train α = ", alphaOut)


# ``test`` is a demo driver, not a unit test; stop test runners from collecting it.
test.__test__ = False


if __name__ == '__main__':
    test(1)
    test(2)
    test(3)
    test(4)
    test(5)
● 输出结果,似乎只有一元的时候收敛,代码有点问题【坑】
ndistributionCount = 1 originμ = [0.67175814] train μ = [0.67842327] originσ = [0.14955569] train σ = [0.14782499] originα = [1.] train α = [1.] ndistributionCount = 2 originμ = [0.41731305 1.03497633] train μ = [0.71810584 0.80326904] originσ = [0.8746775 0.54866726] train σ = [0.60543201 0.46134878] originα = [0.4609127 0.5390873] train α = [0.3065564 0.6934436] ndistributionCount = 3 originμ = [0.56854648 1.11932014 1.58967158] train μ = [1.1911365 1.14824762 1.20370434] originσ = [0.20615474 0.13178869 0.09097129] train σ = [0.07795558 0.08028857 0.07615437] originα = [0.21194169 0.38980503 0.39825329] train α = [0.71143915 0.01148389 0.27707696] ndistributionCount = 4 originμ = [0.06525055 0.76467489 1.36233954 2.27413522] train μ = [0.89375387 1.32709678 1.28927068 0.99696908] originσ = [0.04627714 0.05849647 0.88877231 0.57707149] train σ = [0.248415 0.26329343 0.25491932 0.12993855] originα = [0.0825736 0.39384943 0.2717186 0.25185837] train α = [0.08450359 0.41897111 0.44427471 0.05225059] ndistributionCount = 5 originμ = [0.55679028 1.13666279 1.88269851 2.33842668 2.65599906] train μ = [1.71404135 1.14266766 1.61215492 0.9602133 1.58427791] originσ = [0.51241842 0.49056236 0.14953623 0.57604303 0.98916623] train σ = [0.22736403 0.14251132 0.26561206 0.02144093 0.2710761 ] originα = [0.22618106 0.29204021 0.10173233 0.21445053 0.16559587] train α = [0.05202092 0.03445172 0.59535521 0.0073266 0.31084555]