• SVM回归


    SVM回归任务是限制间隔违规情况下,尽量防止更多的样本在“街道”上。“街道”的宽度由超参数(epsilon)控制
    在随机生成的线性数据上,两个线性SVM回归模型,一个有较大的间隔((epsilon=1.5)),另一个间隔较小((epsilon=0.5)),训练情况如下:
    代码如下:
    造数据与训练:

    np.random.seed(42)
    m = 50
    X = 2 * np.random.randn(m,1)
    y = (4 + 3 * X + np.random.randn(m,1)).ravel()
    
    from sklearn.svm import LinearSVR
    
    svm_reg1 = LinearSVR(epsilon=1.5, random_state=42)
    svm_reg2 = LinearSVR(epsilon=0.5, random_state=42)
    
    svm_reg1.fit(X, y)
    svm_reg2.fit(X,y)
    

    可视化编码

    def find_support_vectors(svm_reg, X, y):
        y_pred = svm_reg.predict(X)
        off_margin = (np.abs(y - y_pred) >= svm_reg.epsilon)
        return np.argwhere(off_margin)
    
    svm_reg1.support_ = find_support_vectors(svm_reg1, X, y)
    svm_reg2.support_ = find_support_vectors(svm_reg2, X, y)
    
    eps_x1 = 1
    eps_y_pred = svm_reg1.predict([[eps_x1]])
    def plot_svm_regression(svm_reg, X, y, axes):
        x1s = np.linspace(axes[0], axes[1], 100).reshape(100, 1)
        y_pred = svm_reg.predict(x1s)
        plt.plot(x1s, y_pred, "k-", linewidth=2, label=r"$hat{y}$")
        plt.plot(x1s, y_pred + svm_reg.epsilon, "k--")
        plt.plot(x1s, y_pred - svm_reg.epsilon, "k--")
        plt.scatter(X[svm_reg.support_], y[svm_reg.support_], s=180, facecolors='#FFAAAA')
        plt.plot(X, y, "bo")
        plt.xlabel(r"$x_1$", fontsize=18)
        plt.legend(loc="upper left", fontsize=18)
        plt.axis(axes)
    
    plt.figure(figsize=(9, 4))
    plt.subplot(121)
    plot_svm_regression(svm_reg1, X, y, [0, 2, 3, 11])
    plt.title(r"$epsilon = {}$".format(svm_reg1.epsilon), fontsize=18)
    plt.ylabel(r"$y$", fontsize=18, rotation=0)
    #plt.plot([eps_x1, eps_x1], [eps_y_pred, eps_y_pred - svm_reg1.epsilon], "k-", linewidth=2)
    plt.annotate(
            '', xy=(eps_x1, eps_y_pred), xycoords='data',
            xytext=(eps_x1, eps_y_pred - svm_reg1.epsilon),
            textcoords='data', arrowprops={'arrowstyle': '<->', 'linewidth': 1.5}
        )
    plt.text(0.91, 5.6, r"$epsilon$", fontsize=20)
    plt.subplot(122)
    plot_svm_regression(svm_reg2, X, y, [0, 2, 3, 11])
    plt.title(r"$epsilon = {}$".format(svm_reg2.epsilon), fontsize=18)
    
    plt.show()
    

    可视化展示:

    非线性拟合

    造数据

    np.random.seed(42)
    m = 100
    X = 2 * np.random.rand(m, 1) - 1
    y = (0.2 + 0.1 * X + 0.5 * X**2 + np.random.randn(m, 1)/10).ravel()
    
    from sklearn.svm import SVR
    
    from sklearn.svm import SVR
    
    svm_poly_reg1 = SVR(kernel="poly", degree=2, C=100, epsilon=0.1, gamma="auto")
    svm_poly_reg2 = SVR(kernel="poly", degree=2, C=0.01, epsilon=0.1, gamma="auto")
    svm_poly_reg1.fit(X, y)
    svm_poly_reg2.fit(X, y)
    

    可视化编程

    plt.figure(figsize=(9, 4))
    plt.subplot(121)
    plot_svm_regression(svm_poly_reg1, X, y, [-1, 1, 0, 1])
    plt.title(r"$degree={}, C={}, epsilon = {}$".format(svm_poly_reg1.degree, svm_poly_reg1.C, svm_poly_reg1.epsilon), fontsize=18)
    plt.ylabel(r"$y$", fontsize=18, rotation=0)
    plt.subplot(122)
    plot_svm_regression(svm_poly_reg2, X, y, [-1, 1, 0, 1])
    plt.title(r"$degree={}, C={}, epsilon = {}$".format(svm_poly_reg2.degree, svm_poly_reg2.C, svm_poly_reg2.epsilon), fontsize=18)
    
    plt.show()
    

    可视化展示:

  • 相关阅读:
    D365FO Debug找不到w3cp进程
    D365FO 10.0PU32 开发环境 Data Management导出失败
    一张图看懂项目管理
    用户体验为什么重要?如何提升产品的用户体验?(写给产品小白)
    敏捷考证?你应该知道的敏捷体系认证(最全名单)
    漫画:禅道程序员的一天
    敏捷开发管理--任务分解经验之谈
    漫画:优秀程序员的必备特质有哪些?
    漫画:女生/男生告白攻略
    漫画:程序员脱单秘籍
  • 原文地址:https://www.cnblogs.com/whiteBear/p/13096398.html
Copyright © 2020-2023  润新知