• libsvm代码阅读:关于svm_train函数分析(转)


    在svm中,训练是一个十分重要的步骤,下面我们来看看svm的train部分。

    在libsvm中的svm_train中分别有回归和分类两部分,我只对其中分类做介绍。

    分类的步骤如下:

    • 统计类别总数,同时记录类别的标号,统计每个类的样本数目
    • 将属于相同类的样本分组,连续存放
    • 计算权重C
    • 训练n(n-1)/2 个模型
      • 初始化nozero数组,便于统计SV
      • //初始化概率数组
      • 训练过程中,需要重建子数据集,样本的特征不变,但样本的类别要改为+1/-1
      • //如有必要,先调用svm_binary_svc_probability
      • 训练子数据集svm_train_one
      • 统计一下nozero,如果nozero已经是真,就不变,否则改为真
    • 输出模型
      • 主要是填充svm_model
    • 清除内存

    函数中调用过程如下:

    svm_train-->svm_train_one-->solve_c_svc(for example)-->s.Solve

    [cpp]   view plain copy 在CODE上查看代码片 派生到我的代码片
    <EMBED id=ZeroClipboardMovie_1 height=18 name=ZeroClipboardMovie_1 type=application/x-shockwave-flash align=middle pluginspage=http://www.macromedia.com/go/getflashplayer width=18 src=http://static.blog.csdn.net/scripts/ZeroClipboard/ZeroClipboard.swf wmode="transparent" flashvars="id=1&width=18&height=18" allowfullscreen="false" allowscriptaccess="always" bgcolor="#ffffff" quality="best" menu="false" loop="false">
    1. //  
    2. // Interface functions  
    3. //重点函数:svm训练函数  
    4. //根据选择的算法,来组织参加训练的分样本,以及进行训练结果的保存。其中会对样本进行初步的统计。  
    5. svm_model *svm_train(const svm_problem *prob, const svm_parameter *param)  
    6. {  
    7.     svm_model *model = Malloc(svm_model,1);//#define Malloc(type,n) (type *)malloc((n)*sizeof(type))  
    8.     model->param = *param;  
    9.     model->free_sv = 0;  // XXX  
    10.   
    11.     if(param->svm_type == ONE_CLASS ||  
    12.        param->svm_type == EPSILON_SVR ||  
    13.        param->svm_type == NU_SVR)  
    14.     {  
    15.         // regression or one-class-svm  
    16.         model->nr_class = 2;  
    17.         model->label = NULL;  
    18.         model->nSV = NULL;  
    19.         model->probA = NULL; model->probB = NULL;  
    20.         model->sv_coef = Malloc(double *,1);  
    21.   
    22.         if(param->probability &&   
    23.            (param->svm_type == EPSILON_SVR ||  
    24.             param->svm_type == NU_SVR))  
    25.         {  
    26.             model->probA = Malloc(double,1);  
    27.             model->probA[0] = svm_svr_probability(prob,param);  
    28.         }  
    29.   
    30.         decision_function f = svm_train_one(prob,param,0,0);  
    31.         model->rho = Malloc(double,1);  
    32.         model->rho[0] = f.rho;  
    33.   
    34.         int nSV = 0;  
    35.         int i;  
    36.         for(i=0;i<prob->l;i++)  
    37.             if(fabs(f.alpha[i]) > 0) ++nSV;  
    38.         model->l = nSV;  
    39.         model->SV = Malloc(svm_node *,nSV);  
    40.         model->sv_coef[0] = Malloc(double,nSV);  
    41.         model->sv_indices = Malloc(int,nSV);  
    42.         int j = 0;  
    43.         for(i=0;i<prob->l;i++)  
    44.             if(fabs(f.alpha[i]) > 0)  
    45.             {  
    46.                 model->SV[j] = prob->x[i];  
    47.                 model->sv_coef[0][j] = f.alpha[i];  
    48.                 model->sv_indices[j] = i+1;  
    49.                 ++j;  
    50.             }         
    51.         free(f.alpha);  
    52.     }  
    53.     else  
    54.     {  
    55.         // classification  
    56.         int l = prob->l;  
    57.         int nr_class;  
    58.         int *label = NULL;  
    59.         int *start = NULL;  
    60.         int *count = NULL;  
    61.         int *perm = Malloc(int,l);  
    62.   
    63.         // group training data of the same class对训练样本进行处理,同类整合到一起  
    64.         svm_group_classes(prob,&nr_class,&label,&start,&count,perm);  
    65.         if(nr_class == 1)   
    66.             info("WARNING: training data in only one class. See README for details. ");  
    67.           
    68.         svm_node **x = Malloc(svm_node *,l);  
    69.         int i;  
    70.         for(i=0;i<l;i++)  
    71.             x[i] = prob->x[perm[i]];  
    72.   
    73.         // calculate weighted C  
    74.   
    75.         double *weighted_C = Malloc(double, nr_class);  
    76.         for(i=0;i<nr_class;i++)  
    77.             weighted_C[i] = param->C;  
    78.         for(i=0;i<param->nr_weight;i++)  
    79.         {     
    80.             int j;  
    81.             for(j=0;j<nr_class;j++)  
    82.                 if(param->weight_label[i] == label[j])  
    83.                     break;  
    84.             if(j == nr_class)  
    85.                 fprintf(stderr,"WARNING: class label %d specified in weight is not found ", param->weight_label[i]);  
    86.             else  
    87.                 weighted_C[j] *= param->weight[i];  
    88.         }  
    89.   
    90.         // train k*(k-1)/2 models  
    91.           
    92.         bool *nonzero = Malloc(bool,l);  
    93.         for(i=0;i<l;i++)  
    94.             nonzero[i] = false;  
    95.         decision_function *f = Malloc(decision_function,nr_class*(nr_class-1)/2);  
    96.   
    97.         double *probA=NULL,*probB=NULL;  
    98.         if (param->probability)  
    99.         {  
    100.             probA=Malloc(double,nr_class*(nr_class-1)/2);  
    101.             probB=Malloc(double,nr_class*(nr_class-1)/2);  
    102.         }  
    103.   
    104.         int p = 0;  
    105.         for(i=0;i<nr_class;i++)  
    106.             for(int j=i+1;j<nr_class;j++)  
    107.             {  
    108.                 svm_problem sub_prob;  
    109.                 int si = start[i], sj = start[j];  
    110.                 int ci = count[i], cj = count[j];  
    111.                 sub_prob.l = ci+cj;  
    112.                 sub_prob.x = Malloc(svm_node *,sub_prob.l);  
    113.                 sub_prob.y = Malloc(double,sub_prob.l);  
    114.                 int k;  
    115.                 for(k=0;k<ci;k++)  
    116.                 {  
    117.                     sub_prob.x[k] = x[si+k];  
    118.                     sub_prob.y[k] = +1;  
    119.                 }  
    120.                 for(k=0;k<cj;k++)  
    121.                 {  
    122.                     sub_prob.x[ci+k] = x[sj+k];  
    123.                     sub_prob.y[ci+k] = -1;  
    124.                 }  
    125.   
    126.                 if(param->probability)  
    127.                     svm_binary_svc_probability(&sub_prob,param,weighted_C[i],weighted_C[j],probA[p],probB[p]);  
    128.   
    129.                 f[p] = svm_train_one(&sub_prob,param,weighted_C[i],weighted_C[j]);  
    130.                 for(k=0;k<ci;k++)  
    131.                     if(!nonzero[si+k] && fabs(f[p].alpha[k]) > 0)  
    132.                         nonzero[si+k] = true;  
    133.                 for(k=0;k<cj;k++)  
    134.                     if(!nonzero[sj+k] && fabs(f[p].alpha[ci+k]) > 0)  
    135.                         nonzero[sj+k] = true;  
    136.                 free(sub_prob.x);  
    137.                 free(sub_prob.y);  
    138.                 ++p;  
    139.             }  
    140.   
    141.         // build output  
    142.   
    143.         model->nr_class = nr_class;  
    144.           
    145.         model->label = Malloc(int,nr_class);  
    146.         for(i=0;i<nr_class;i++)  
    147.             model->label[i] = label[i];  
    148.           
    149.         model->rho = Malloc(double,nr_class*(nr_class-1)/2);  
    150.         for(i=0;i<nr_class*(nr_class-1)/2;i++)  
    151.             model->rho[i] = f[i].rho;  
    152.   
    153.         if(param->probability)  
    154.         {  
    155.             model->probA = Malloc(double,nr_class*(nr_class-1)/2);  
    156.             model->probB = Malloc(double,nr_class*(nr_class-1)/2);  
    157.             for(i=0;i<nr_class*(nr_class-1)/2;i++)  
    158.             {  
    159.                 model->probA[i] = probA[i];  
    160.                 model->probB[i] = probB[i];  
    161.             }  
    162.         }  
    163.         else  
    164.         {  
    165.             model->probA=NULL;  
    166.             model->probB=NULL;  
    167.         }  
    168.   
    169.         int total_sv = 0;  
    170.         int *nz_count = Malloc(int,nr_class);  
    171.         model->nSV = Malloc(int,nr_class);  
    172.         for(i=0;i<nr_class;i++)  
    173.         {  
    174.             int nSV = 0;  
    175.             for(int j=0;j<count[i];j++)  
    176.                 if(nonzero[start[i]+j])  
    177.                 {     
    178.                     ++nSV;  
    179.                     ++total_sv;  
    180.                 }  
    181.             model->nSV[i] = nSV;  
    182.             nz_count[i] = nSV;  
    183.         }  
    184.           
    185.         info("Total nSV = %d ",total_sv);  
    186.   
    187.         model->l = total_sv;  
    188.         model->SV = Malloc(svm_node *,total_sv);  
    189.         model->sv_indices = Malloc(int,total_sv);  
    190.         p = 0;  
    191.         for(i=0;i<l;i++)  
    192.             if(nonzero[i])  
    193.             {  
    194.                 model->SV[p] = x[i];  
    195.                 model->sv_indices[p++] = perm[i] + 1;  
    196.             }  
    197.   
    198.         int *nz_start = Malloc(int,nr_class);  
    199.         nz_start[0] = 0;  
    200.         for(i=1;i<nr_class;i++)  
    201.             nz_start[i] = nz_start[i-1]+nz_count[i-1];  
    202.   
    203.         model->sv_coef = Malloc(double *,nr_class-1);  
    204.         for(i=0;i<nr_class-1;i++)  
    205.             model->sv_coef[i] = Malloc(double,total_sv);  
    206.   
    207.         p = 0;  
    208.         for(i=0;i<nr_class;i++)  
    209.             for(int j=i+1;j<nr_class;j++)  
    210.             {  
    211.                 // classifier (i,j): coefficients with  
    212.                 // i are in sv_coef[j-1][nz_start[i]...],  
    213.                 // j are in sv_coef[i][nz_start[j]...]  
    214.   
    215.                 int si = start[i];  
    216.                 int sj = start[j];  
    217.                 int ci = count[i];  
    218.                 int cj = count[j];  
    219.                   
    220.                 int q = nz_start[i];  
    221.                 int k;  
    222.                 for(k=0;k<ci;k++)  
    223.                     if(nonzero[si+k])  
    224.                         model->sv_coef[j-1][q++] = f[p].alpha[k];  
    225.                 q = nz_start[j];  
    226.                 for(k=0;k<cj;k++)  
    227.                     if(nonzero[sj+k])  
    228.                         model->sv_coef[i][q++] = f[p].alpha[ci+k];  
    229.                 ++p;  
    230.             }  
    231.           
    232.         free(label);  
    233.         free(probA);  
    234.         free(probB);  
    235.         free(count);  
    236.         free(perm);  
    237.         free(start);  
    238.         free(x);  
    239.         free(weighted_C);  
    240.         free(nonzero);  
    241.         for(i=0;i<nr_class*(nr_class-1)/2;i++)  
    242.             free(f[i].alpha);  
    243.         free(f);  
    244.         free(nz_count);  
    245.         free(nz_start);  
    246.     }  
    247.     return model;  
    248. }  
  • 相关阅读:
    判断客户端类型
    关于element-ui select组件change事件只要数据变化就会触发的解决办法
    log4net的基本配置及用法
    WCF自定义扩展,以实现aop!
    继承IDbConnection连接不同数据库
    MVC中你必须知道的13个扩展点
    Sql导出数据报错-->SQL Server 阻止了对组件 'Ad Hoc Distributed Queries' 的 STATEMENT'OpenRowset/OpenDatasource' 的访问
    "当前方法的代码已经过优化,无法计算表达式的值"的这个错误的解决方案!!!
    WCF中安全的那些事!!!
    Linq to sql 的语法
  • 原文地址:https://www.cnblogs.com/Miliery/p/4394151.html
Copyright © 2020-2023  润新知