脚本:
import numpy as np
import matplotlib.pyplot as plt
# 创建数据
np.random.seed(42)
X_xor = np.random.normal(0, 1, (300, 2))
y_xor = np.logical_xor(X_xor[:, 0] > 0 , X_xor[:, 1] > 0)
y_xor = np.where(y_xor, 1, -1)
# 绘制散点图
plt.scatter(x=X_xor[y_xor==1, 0], # 横轴坐标
y=X_xor[y_xor==1, 1], # 纵轴坐标
color='g', # green
marker='x',
label='1'
)
plt.scatter(x=X_xor[y_xor==-1, 0],
y=X_xor[y_xor==-1, 1],
color='r', # red
marker='s',
label='-1'
)
# 添加水平参考区域
plt.axhspan(ymin=-1,
ymax=1,
ls=':', # line style
facecolor='1',
edgecolor='k',
alpha=0.2,
)
# 添加垂直参区域
plt.axvspan(xmin=-1,
xmax=1,
ls=':',
facecolor='1',
edgecolor='k',
alpha=0.2,
)
# 显示图例
plt.legend(scatterpoints=2) # 图例中标记点的个数
plt.show()
图形: