• A brief introduction to using libsvm


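    libsvm is an open-source implementation of the support vector machine (SVM) that uses the SMO algorithm. Its source code is written in a distinctive way and is worth reading.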

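    Common structures

    The svm_node structure

    Defines one element of an input feature vector: index is the feature index and value is the feature value. In the C version of libsvm, a node with index = -1 marks the end of a vector; the Java version takes the vector length from the array instead.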

    public class svm_node implements java.io.Serializable
    {
        public int index;
        public double value;
    }

    This borrows the sparse-matrix representation. A single input vector is a one-dimensional array of svm_node:

    svm_node[] pa = {pa0, pa1};

    The full set of input vectors is represented by a two-dimensional array:

    svm_node[][] datas = {pa, pb};
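    To make the sparse representation concrete, here is a minimal sketch of converting a dense vector into index/value nodes and computing a dot product over them. SparseNode is a hypothetical stand-in that only mirrors the two fields of svm_node, so the sketch compiles without the libsvm jar; the helper names toSparse and dot are likewise assumptions made for illustration, not part of the libsvm API.

```java
import java.util.ArrayList;
import java.util.List;

class SparseVectors {
    // Convert a dense vector to sparse form, skipping zero entries.
    // Indices are 1-based here, following the usual libsvm data-file convention.
    static SparseNode[] toSparse(double[] dense) {
        List<SparseNode> nodes = new ArrayList<>();
        for (int i = 0; i < dense.length; i++) {
            if (dense[i] != 0.0) {
                nodes.add(new SparseNode(i + 1, dense[i]));
            }
        }
        return nodes.toArray(new SparseNode[0]);
    }

    // Dot product of two sparse vectors whose indices are in ascending order.
    // Only indices present in both vectors contribute to the sum.
    static double dot(SparseNode[] x, SparseNode[] y) {
        double sum = 0.0;
        int i = 0;
        int j = 0;
        while (i < x.length && j < y.length) {
            if (x[i].index == y[j].index) {
                sum += x[i++].value * y[j++].value;
            } else if (x[i].index < y[j].index) {
                i++;
            } else {
                j++;
            }
        }
        return sum;
    }

    public static void main(String[] args) {
        SparseNode[] a = toSparse(new double[] {10.0, 0.0, 10.0});
        SparseNode[] b = toSparse(new double[] {0.0, 5.0, 2.0});
        System.out.println(a.length);  // number of stored (non-zero) entries
        System.out.println(dot(a, b)); // only index 3 is shared by both vectors
    }
}

// Hypothetical stand-in for libsvm's svm_node (same two fields).
class SparseNode {
    int index;
    double value;

    SparseNode(int index, double value) {
        this.index = index;
        this.value = value;
    }
}
```

    Storing only the non-zero entries keeps memory and kernel evaluation cost proportional to the number of non-zero features rather than to the full dimensionality.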

    Label array

    This is simply a double array holding one label for each input vector in datas.

    double[] labels = {1.0, -1.0};

    The svm_problem structure

    Defines the structure that holds the (X, Y) training samples:

    public class svm_problem implements java.io.Serializable
    {
        public int l;
        public double[] y;
        public svm_node[][] x;
    }

    Here l is the number of samples.

    The svm_parameter structure

    Defines the important parameters used during training:

    public class svm_parameter implements Cloneable,java.io.Serializable
    {
        /* svm_type */
        public static final int C_SVC = 0;
        public static final int NU_SVC = 1;
        public static final int ONE_CLASS = 2;
        public static final int EPSILON_SVR = 3;
        public static final int NU_SVR = 4;
    
        /* kernel_type */
        public static final int LINEAR = 0;
        public static final int POLY = 1;
        public static final int RBF = 2;
        public static final int SIGMOID = 3;
        public static final int PRECOMPUTED = 4;
    
        public int svm_type;
        public int kernel_type;
        public int degree;    // for poly
        public double gamma;    // for poly/rbf/sigmoid
        public double coef0;    // for poly/sigmoid
    
        // these are for training only
        public double cache_size; // in MB
        public double eps;    // stopping criteria
        public double C;    // for C_SVC, EPSILON_SVR and NU_SVR
        public int nr_weight;        // for C_SVC
        public int[] weight_label;    // for C_SVC
        public double[] weight;        // for C_SVC
        public double nu;    // for NU_SVC, ONE_CLASS, and NU_SVR
        public double p;    // for EPSILON_SVR
        public int shrinking;    // use the shrinking heuristics
        public int probability; // do probability estimates
    
        public Object clone() 
        {
            try 
            {
                return super.clone();
            } catch (CloneNotSupportedException e) 
            {
                return null;
            }
        }
    
    }

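    The parameters fall into two broad groups: those that describe the classifier's kernel function, and those that control the SMO training algorithm, such as the stopping tolerance eps.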
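    As a rough illustration of how degree, gamma and coef0 enter the kernel, the sketch below evaluates the four built-in kernel types on dense double[] vectors. The class and method names are assumptions made for this example (libsvm's own kernel code works on svm_node arrays); only the formulas are meant to match the kernel definitions that libsvm documents.

```java
class KernelSketch {
    static double dot(double[] u, double[] v) {
        double sum = 0.0;
        for (int i = 0; i < u.length; i++) {
            sum += u[i] * v[i];
        }
        return sum;
    }

    // LINEAR: u . v
    static double linear(double[] u, double[] v) {
        return dot(u, v);
    }

    // POLY: (gamma * u . v + coef0) ^ degree
    static double poly(double[] u, double[] v, double gamma, double coef0, int degree) {
        return Math.pow(gamma * dot(u, v) + coef0, degree);
    }

    // RBF: exp(-gamma * |u - v|^2)
    static double rbf(double[] u, double[] v, double gamma) {
        double squaredDistance = 0.0;
        for (int i = 0; i < u.length; i++) {
            double d = u[i] - v[i];
            squaredDistance += d * d;
        }
        return Math.exp(-gamma * squaredDistance);
    }

    // SIGMOID: tanh(gamma * u . v + coef0)
    static double sigmoid(double[] u, double[] v, double gamma, double coef0) {
        return Math.tanh(gamma * dot(u, v) + coef0);
    }

    public static void main(String[] args) {
        double[] u = {1.0, 2.0};
        double[] v = {3.0, 4.0};
        System.out.println(linear(u, v));
        System.out.println(poly(u, v, 1.0, 1.0, 2));
        System.out.println(rbf(u, v, 0.5));
        System.out.println(sigmoid(u, v, 0.01, 0.0));
    }
}
```

    The remaining fields (cache_size, eps, C, nu, p, shrinking and so on) do not change the kernel itself; they only affect how the optimization problem is posed and solved.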

    Training

    A model is trained by calling svm.svm_train():

    public static svm_model svm_train(svm_problem prob, svm_parameter param)

    It returns an svm_model object that represents the trained classifier.

    Prediction

    svm.svm_predict() uses the classifier to make a prediction:

    public static double svm_predict(svm_model model, svm_node[] x)

    It returns the class label.

    Example code follows. The training points are pa = (10.0, 10.0) with ya = 1.0 and pb = (-10.0, -10.0) with yb = -1.0.

    The test point is (-0.1, 0).

    import libsvm.svm;
    import libsvm.svm_model;
    import libsvm.svm_node;
    import libsvm.svm_parameter;
    import libsvm.svm_problem;

    public class SvmTest {
        public static void main(String[] args) {
            // Training point pa = (10.0, 10.0).
            // Feature indices are listed in ascending order. The Java version
            // takes the vector length from the array, so no index = -1
            // terminator node is needed.
            svm_node pa0 = new svm_node();
            pa0.index = 1;
            pa0.value = 10.0;

            svm_node pa1 = new svm_node();
            pa1.index = 2;
            pa1.value = 10.0;

            // Training point pb = (-10.0, -10.0).
            svm_node pb0 = new svm_node();
            pb0.index = 1;
            pb0.value = -10.0;

            svm_node pb1 = new svm_node();
            pb1.index = 2;
            pb1.value = -10.0;

            svm_node[] pa = {pa0, pa1};
            svm_node[] pb = {pb0, pb1};

            // All training vectors, with one label per vector.
            svm_node[][] datas = {pa, pb};
            double[] labels = {1.0, -1.0};

            // Bundle the samples into an svm_problem.
            svm_problem problem = new svm_problem();
            problem.l = 2;
            problem.x = datas;
            problem.y = labels;

            // C-SVC with a linear kernel; other fields keep their Java defaults.
            svm_parameter param = new svm_parameter();
            param.svm_type = svm_parameter.C_SVC;
            param.kernel_type = svm_parameter.LINEAR;
            param.cache_size = 100;
            param.eps = 0.00001;
            param.C = 1;

            // svm_check_parameter returns null when the parameters are valid,
            // otherwise an error message.
            System.out.println(svm.svm_check_parameter(problem, param));
            svm_model model = svm.svm_train(problem, param);

            // Test point pc = (-0.1, 0).
            svm_node pc0 = new svm_node();
            pc0.index = 1;
            pc0.value = -0.1;
            svm_node pc1 = new svm_node();
            pc1.index = 2;
            pc1.value = 0;

            svm_node[] pc = {pc0, pc1};

            // Print the predicted class label for the test point.
            System.out.println(svm.svm_predict(model, pc));
        }
    }
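    As a sanity check on what the example should predict, the following sketch works out the maximum-margin hyperplane for these two training points by hand, without calling libsvm. It assumes the closed form for exactly two points with a linear kernel, w = 2 (pos - neg) / |pos - neg|^2 and b = -w . (pos + neg) / 2, which is valid as long as the resulting dual coefficient 2 / |pos - neg|^2 (0.0025 here) does not exceed C. The class and method names are hypothetical.

```java
class TwoPointSvmCheck {
    static double dot(double[] u, double[] v) {
        double sum = 0.0;
        for (int i = 0; i < u.length; i++) {
            sum += u[i] * v[i];
        }
        return sum;
    }

    // w = 2 (pos - neg) / |pos - neg|^2
    static double[] weights(double[] pos, double[] neg) {
        double[] diff = new double[pos.length];
        for (int i = 0; i < pos.length; i++) {
            diff[i] = pos[i] - neg[i];
        }
        double scale = 2.0 / dot(diff, diff);
        double[] w = new double[pos.length];
        for (int i = 0; i < pos.length; i++) {
            w[i] = scale * diff[i];
        }
        return w;
    }

    // b = -w . (pos + neg) / 2, so the boundary passes through the midpoint.
    static double bias(double[] w, double[] pos, double[] neg) {
        double[] sum = new double[pos.length];
        for (int i = 0; i < pos.length; i++) {
            sum[i] = pos[i] + neg[i];
        }
        return -dot(w, sum) / 2.0;
    }

    // Decision value f(x) = w . x + b; its sign gives the predicted class.
    static double decision(double[] w, double b, double[] x) {
        return dot(w, x) + b;
    }

    public static void main(String[] args) {
        double[] pa = {10.0, 10.0};   // label +1
        double[] pb = {-10.0, -10.0}; // label -1
        double[] pc = {-0.1, 0.0};    // test point

        double[] w = weights(pa, pb);
        double b = bias(w, pa, pb);
        double f = decision(w, b, pc);
        System.out.println("f(pc) = " + f);
        System.out.println("label = " + (f >= 0 ? 1.0 : -1.0));
    }
}
```

    Under these assumptions the decision value at the test point is slightly negative, so svm.svm_predict() in the example above would be expected to return -1.0.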
  • Original article: https://www.cnblogs.com/zjgtan/p/3305720.html