传入参数:
1)plt:调用对象中的pyplot。
如:import matplotlib.pyplot as plt
2)predict:调用对象中ML算法的predict函数,用于预测对X,Y构造网格后的预测。
如:clf = neighbors.KNeighborsClassifier(n_neighbors=15, weights='distance')
3)X, Y:绘图的参数,shape:nx1,n,1
4) classes_color:颜色列表,
如:classes_color=['#FFAAAA', '#AAFFAA', '#AAAAFF','#00000']
5)step=0.05 ,网格细分长度
import numpy as np
from matplotlib.colors import ListedColormap
cmpcolor = ['#FFAAAA', '#AAFFAA', '#AAAAFF']
def create_meshgrid_pic(plt, predict, X, Y, classes_color=cmpcolor, step=0.05):
# 确认训练集的边界
x_min, x_max = X[:].min() - 1, X[:].max() + 1
y_min, y_max = Y[:].min() - 1, Y[:].max() + 1
# 生成网格数据,xx:所有网格点的x坐标,形状也是网格性nxm。yy同样
xx, yy = np.meshgrid(np.arange(x_min, x_max, step),
np.arange(y_min, y_max, step))
# xx,yy的扁平化成一串坐标点(密密麻麻的网格点平摊开来)
d = np.c_[xx.ravel(), yy.ravel()]
# 对网格点进行类型预测
Z = predict(d)
# 预测类型后,重新变回网格的样子,因为后面pcolormesh接收网格形式的绘图数据
Z = Z.reshape(xx.shape)
# 获取类型数量
class_size = np.unique(Z).size
if class_size > len(classes_color):
print('颜色列表太少')
return AttributeError
classes_color = classes_color[:class_size]
cmap_light = ListedColormap(classes_color)
# 接收网格化的x,y,z
plt.pcolormesh(xx, yy, Z, cmap=cmap_light)
使用:
import matplotlib.pyplot as plt
from sklearn import neighbors
from **** import create_meshgrid_pic
X, Y = ()
clf = neighbors.KNeighborsClassifier(n_neighbors=15, weights='distance')
clf.fit(X, Y)
cmap_light = ['#FFAAAA', '#AAFFAA', '#AF0000']
create_meshgrid_pic(plt, clf, X[:, 0], X[:, 1], cmap_light, 0.02)
plt.show()