• libsvm代码阅读:关于Solver类分析(二)(转)


    如果你看完了上篇博文的伪代码,那么我们就可以开始谈谈它的源代码了。

    下面先贴出它的类定义,一些成员函数的具体实现先忽略。

    [cpp]   view plain copy 在CODE上查看代码片 派生到我的代码片
    <EMBED id=ZeroClipboardMovie_1 height=18 name=ZeroClipboardMovie_1 type=application/x-shockwave-flash align=middle pluginspage=http://www.macromedia.com/go/getflashplayer width=18 src=http://static.blog.csdn.net/scripts/ZeroClipboard/ZeroClipboard.swf wmode="transparent" flashvars="id=1&width=18&height=18" allowfullscreen="false" allowscriptaccess="always" bgcolor="#ffffff" quality="best" menu="false" loop="false">
    1. // An SMO algorithm in Fan et al., JMLR 6(2005), p. 1889--1918  
    2. // Solves:  
    3. //  min 0.5(alpha^T Q alpha) + p^T alpha  
    4. //  
    5. //      y^T alpha = delta  
    6. //      y_i = +1 or -1  
    7. //      0 <= alpha_i <= Cp for y_i = 1  
    8. //      0 <= alpha_i <= Cn for y_i = -1  
    9. //  
    10. // Given:  
    11. //  Q, p, y, Cp, Cn, and an initial feasible point alpha  
    12. //  l is the size of vectors and matrices  
    13. //  eps is the stopping tolerance  
    14. // solution will be put in alpha, objective value will be put in obj  
    15. //  
    16. class Solver {  
    17. public:  
    18.     Solver() {};  
    19.     virtual ~Solver() {};//用虚析构函数的原因是:保证根据实际运行适当的析构函数  
    20.   
    21.     struct SolutionInfo {  
    22.         double obj;  
    23.         double rho;  
    24.         double upper_bound_p;  
    25.         double upper_bound_n;  
    26.         double r;   // for Solver_NU  
    27.     };  
    28.   
    29.     void Solve(int l, const QMatrix& Q, const double *p_, const schar *y_,  
    30.            double *alpha_, double Cp, double Cn, double eps,  
    31.            SolutionInfo* si, int shrinking);  
    32. protected:  
    33.     int active_size;//计算时实际参加运算的样本数目,经过shrink处理后,该数目小于全部样本数  
    34.     schar *y;       //样本所属类别,该值只能取-1或+1。  
    35.     double *G;      // gradient of objective function = (Q alpha + p)  
    36.     enum { LOWER_BOUND, UPPER_BOUND, FREE };  
    37.     char *alpha_status; // LOWER_BOUND, UPPER_BOUND, FREE   
    38.     double *alpha;      //  
    39.     const QMatrix *Q;     
    40.     const double *QD;  
    41.     double eps;     //误差限  
    42.     double Cp,Cn;  
    43.     double *p;  
    44.     int *active_set;  
    45.     double *G_bar;      // gradient, if we treat free variables as 0  
    46.     int l;  
    47.     bool unshrink;  // XXX  
    48.     //返回对应于样本的C。设置不同的Cp和Cn是为了处理数据的不平衡  
    49.     double get_C(int i)  
    50.     {  
    51.         return (y[i] > 0)? Cp : Cn;  
    52.     }  
    53.   
    54.     void update_alpha_status(int i)  
    55.     {  
    56.         if(alpha[i] >= get_C(i))  
    57.             alpha_status[i] = UPPER_BOUND;  
    58.         else if(alpha[i] <= 0)  
    59.             alpha_status[i] = LOWER_BOUND;  
    60.         else alpha_status[i] = FREE;  
    61.     }  
    62.     bool is_upper_bound(int i) { return alpha_status[i] == UPPER_BOUND; }  
    63.     bool is_lower_bound(int i) { return alpha_status[i] == LOWER_BOUND; }  
    64.     bool is_free(int i) { return alpha_status[i] == FREE; }  
    65.     void swap_index(int i, int j);//交换样本i和j的内容,包括申请的内存的地址  
    66.     void reconstruct_gradient();  //重新计算梯度。  
    67.     virtual int select_working_set(int &i, int &j);//选择工作集  
    68.     virtual double calculate_rho();  
    69.     virtual void do_shrinking();//对样本集做缩减。  
    70. private:  
    71.     bool be_shrunk(int i, double Gmax1, double Gmax2);    
    72. };  

    下面我们来看看SMO如何选择工作集(working set B),选择的约束如下:

    [cpp]   view plain copy 在CODE上查看代码片 派生到我的代码片
    <EMBED id=ZeroClipboardMovie_2 height=18 name=ZeroClipboardMovie_2 type=application/x-shockwave-flash align=middle pluginspage=http://www.macromedia.com/go/getflashplayer width=18 src=http://static.blog.csdn.net/scripts/ZeroClipboard/ZeroClipboard.swf wmode="transparent" flashvars="id=2&width=18&height=18" allowfullscreen="false" allowscriptaccess="always" bgcolor="#ffffff" quality="best" menu="false" loop="false">
    1. // return i,j such that  
    2. // i: maximizes -y_i * grad(f)_i, i in I_up(alpha)  
    3. // j: minimizes the decrease of obj value  
    4. //    (if quadratic coefficeint <= 0, replace it with tau)  
    5. //    -y_j*grad(f)_j < -y_i*grad(f)_i, j in I_low(alpha)  

    论文中的公式如下:

    [cpp]   view plain copy 在CODE上查看代码片 派生到我的代码片
    <EMBED id=ZeroClipboardMovie_3 height=18 name=ZeroClipboardMovie_3 type=application/x-shockwave-flash align=middle pluginspage=http://www.macromedia.com/go/getflashplayer width=18 src=http://static.blog.csdn.net/scripts/ZeroClipboard/ZeroClipboard.swf wmode="transparent" flashvars="id=3&width=18&height=18" allowfullscreen="false" allowscriptaccess="always" bgcolor="#ffffff" quality="best" menu="false" loop="false">
    1. int Solver::select_working_set(int &out_i, int &out_j)  
    2. {  
    3.     // return i,j such that  
    4.     // i: maximizes -y_i * grad(f)_i, i in I_up(alpha)  
    5.     // j: minimizes the decrease of obj value  
    6.     //    (if quadratic coefficeint <= 0, replace it with tau)  
    7.     //    -y_j*grad(f)_j < -y_i*grad(f)_i, j in I_low(alpha)  
    8. //select i    
    9.     double Gmax = -INF;  
    10.     double Gmax2 = -INF;  
    11.     int Gmax_idx = -1;  
    12.     int Gmin_idx = -1;  
    13.     double obj_diff_min = INF;  
    14.   
    15.     for(int t=0;t<active_size;t++)  
    16.         if(y[t]==+1)    //若类别为1  
    17.         {  
    18.             if(!is_upper_bound(t))//若alpha<C  
    19.                 if(-G[t] >= Gmax)  
    20.                 {  
    21.                     Gmax = -G[t];// -y[t]*G[t]=-1*G[t]  
    22.                     Gmax_idx = t;  
    23.                 }  
    24.         }  
    25.         else  
    26.         {  
    27.             if(!is_lower_bound(t))  
    28.                 if(G[t] >= Gmax)  
    29.                 {  
    30.                     Gmax = G[t];  
    31.                     Gmax_idx = t;  
    32.                 }  
    33.         }  
    34.   
    35.     int i = Gmax_idx;  
    36.     const Qfloat *Q_i = NULL;  
    37.     if(i != -1) // NULL Q_i not accessed: Gmax=-INF if i=-1  
    38.         Q_i = Q->get_Q(i,active_size);  
    39. //select j  
    40.     for(int j=0;j<active_size;j++)  
    41.     {  
    42.         if(y[j]==+1)  
    43.         {  
    44.             if (!is_lower_bound(j))  
    45.             {  
    46.                 double grad_diff=Gmax+G[j];  
    47.                 if (G[j] >= Gmax2)  
    48.                     Gmax2 = G[j];  
    49.                 if (grad_diff > 0)  
    50.                 {  
    51.                     double obj_diff;   
    52.                     double quad_coef = QD[i]+QD[j]-2.0*y[i]*Q_i[j];  
    53.                     if (quad_coef > 0)  
    54.                         obj_diff = -(grad_diff*grad_diff)/quad_coef;  
    55.                     else  
    56.                         obj_diff = -(grad_diff*grad_diff)/TAU;  
    57.   
    58.                     if (obj_diff <= obj_diff_min)  
    59.                     {  
    60.                         Gmin_idx=j;  
    61.                         obj_diff_min = obj_diff;  
    62.                     }  
    63.                 }  
    64.             }  
    65.         }  
    66.         else  
    67.         {  
    68.             if (!is_upper_bound(j))  
    69.             {  
    70.                 double grad_diff= Gmax-G[j];  
    71.                 if (-G[j] >= Gmax2)  
    72.                     Gmax2 = -G[j];  
    73.                 if (grad_diff > 0)  
    74.                 {  
    75.                     double obj_diff;   
    76.                     double quad_coef = QD[i]+QD[j]+2.0*y[i]*Q_i[j];  
    77.                     if (quad_coef > 0)  
    78.                         obj_diff = -(grad_diff*grad_diff)/quad_coef;  
    79.                     else  
    80.                         obj_diff = -(grad_diff*grad_diff)/TAU;  
    81.   
    82.                     if (obj_diff <= obj_diff_min)  
    83.                     {  
    84.                         Gmin_idx=j;  
    85.                         obj_diff_min = obj_diff;  
    86.                     }  
    87.                 }  
    88.             }  
    89.         }  
    90.     }  
    91.   
    92.     if(Gmax+Gmax2 < eps)  
    93.         return 1;  
    94.   
    95.     out_i = Gmax_idx;  
    96.     out_j = Gmin_idx;  
    97.     return 0;  
    98. }  

    配合上面几个公式看,这段代码还是很清晰了。

    下面来看看它的构造函数,这个构造函数是solver类的核心。这个算法也结合上一篇博文的algorithm2来看。其中要注意的是get_Q是获取核函数。

    [cpp]   view plain copy 在CODE上查看代码片 派生到我的代码片
    <EMBED id=ZeroClipboardMovie_4 height=18 name=ZeroClipboardMovie_4 type=application/x-shockwave-flash align=middle pluginspage=http://www.macromedia.com/go/getflashplayer width=18 src=http://static.blog.csdn.net/scripts/ZeroClipboard/ZeroClipboard.swf wmode="transparent" flashvars="id=4&width=18&height=18" allowfullscreen="false" allowscriptaccess="always" bgcolor="#ffffff" quality="best" menu="false" loop="false">
    1. void Solver::Solve(int l, const QMatrix& Q, const double *p_, const schar *y_,  
    2.            double *alpha_, double Cp, double Cn, double eps,  
    3.            SolutionInfo* si, int shrinking)  
    4. {  
    5.     this->l = l;  
    6.     this->Q = &Q;  
    7.     QD=Q.get_QD();//这个是获取核函数(如果分类的话在SVC_Q中定义)  
    8.   
    9.     clone(p, p_,l);  
    10.     clone(y, y_,l);  
    11.     clone(alpha,alpha_,l);  
    12.   
    13.     this->Cp = Cp;  
    14.     this->Cn = Cn;  
    15.     this->eps = eps;  
    16.     unshrink = false;  
    17.   
    18.     // initialize alpha_status  
    19.     {  
    20.         alpha_status = new char[l];  
    21.         for(int i=0;i<l;i++)  
    22.             update_alpha_status(i);  
    23.     }  
    24.   
    25.     // initialize active set (for shrinking)  
    26.     {  
    27.         active_set = new int[l];  
    28.         for(int i=0;i<l;i++)  
    29.             active_set[i] = i;  
    30.         active_size = l;  
    31.     }  
    32.   
    33.     // initialize gradient  
    34.     {  
    35.         G = new double[l];  
    36.         G_bar = new double[l];  
    37.         int i;  
    38.         for(i=0;i<l;i++)  
    39.         {  
    40.             G[i] = p[i];  
    41.             G_bar[i] = 0;  
    42.         }  
    43.         for(i=0;i<l;i++)  
    44.             if(!is_lower_bound(i))  
    45.             {  
    46.                 const Qfloat *Q_i = Q.get_Q(i,l);  
    47.                 double alpha_i = alpha[i];  
    48.                 int j;  
    49.                 for(j=0;j<l;j++)  
    50.                     G[j] += alpha_i*Q_i[j];  
    51.                 if(is_upper_bound(i))  
    52.                     for(j=0;j<l;j++)  
    53.                         G_bar[j] += get_C(i) * Q_i[j]; //这里见文献LIBSVM: A Library for SVM公式(33)  
    54.             }  
    55.     }  
    56.   
    57.     // optimization step  
    58.   
    59.     int iter = 0;  
    60.     int max_iter = max(10000000, l>INT_MAX/100 ? INT_MAX : 100*l);  
    61.     int counter = min(l,1000)+1;  
    62.       
    63.     while(iter < max_iter)  
    64.     {  
    65.         // show progress and do shrinking  
    66.   
    67.         if(--counter == 0)  
    68.         {  
    69.             counter = min(l,1000);  
    70.             if(shrinking) do_shrinking();    
    71.             info(".");  
    72.         }  
    73.   
    74.         int i,j;  
    75.         if(select_working_set(i,j)!=0)  
    76.         {  
    77.             // reconstruct the whole gradient  
    78.             reconstruct_gradient();  
    79.             // reset active set size and check  
    80.             active_size = l;  
    81.             info("*");  
    82.             if(select_working_set(i,j)!=0)  
    83.                 break;  
    84.             else  
    85.                 counter = 1;    // do shrinking next iteration  
    86.         }  
    87.           
    88.         ++iter;  
    89.   
    90.         // update alpha[i] and alpha[j], handle bounds carefully  
    91.           
    92.         const Qfloat *Q_i = Q.get_Q(i,active_size);  
    93.         const Qfloat *Q_j = Q.get_Q(j,active_size);  
    94.   
    95.         double C_i = get_C(i);  
    96.         double C_j = get_C(j);  
    97.   
    98.         double old_alpha_i = alpha[i];  
    99.         double old_alpha_j = alpha[j];  
    100.   
    101.         if(y[i]!=y[j])  
    102.         {  
    103.             double quad_coef = QD[i]+QD[j]+2*Q_i[j];  
    104.             if (quad_coef <= 0)  
    105.                 quad_coef = TAU;  
    106.             double delta = (-G[i]-G[j])/quad_coef;  
    107.             double diff = alpha[i] - alpha[j];  
    108.             alpha[i] += delta;  
    109.             alpha[j] += delta;  
    110.               
    111.             if(diff > 0)  
    112.             {  
    113.                 if(alpha[j] < 0)  
    114.                 {  
    115.                     alpha[j] = 0;  
    116.                     alpha[i] = diff;  
    117.                 }  
    118.             }  
    119.             else  
    120.             {  
    121.                 if(alpha[i] < 0)  
    122.                 {  
    123.                     alpha[i] = 0;  
    124.                     alpha[j] = -diff;  
    125.                 }  
    126.             }  
    127.             if(diff > C_i - C_j)  
    128.             {  
    129.                 if(alpha[i] > C_i)  
    130.                 {  
    131.                     alpha[i] = C_i;  
    132.                     alpha[j] = C_i - diff;  
    133.                 }  
    134.             }  
    135.             else  
    136.             {  
    137.                 if(alpha[j] > C_j)  
    138.                 {  
    139.                     alpha[j] = C_j;  
    140.                     alpha[i] = C_j + diff;  
    141.                 }  
    142.             }  
    143.         }  
    144.         else  
    145.         {  
    146.             double quad_coef = QD[i]+QD[j]-2*Q_i[j];  
    147.             if (quad_coef <= 0)  
    148.                 quad_coef = TAU;  
    149.             double delta = (G[i]-G[j])/quad_coef;  
    150.             double sum = alpha[i] + alpha[j];  
    151.             alpha[i] -= delta;  
    152.             alpha[j] += delta;  
    153.   
    154.             if(sum > C_i)  
    155.             {  
    156.                 if(alpha[i] > C_i)  
    157.                 {  
    158.                     alpha[i] = C_i;  
    159.                     alpha[j] = sum - C_i;  
    160.                 }  
    161.             }  
    162.             else  
    163.             {  
    164.                 if(alpha[j] < 0)  
    165.                 {  
    166.                     alpha[j] = 0;  
    167.                     alpha[i] = sum;  
    168.                 }  
    169.             }  
    170.             if(sum > C_j)  
    171.             {  
    172.                 if(alpha[j] > C_j)  
    173.                 {  
    174.                     alpha[j] = C_j;  
    175.                     alpha[i] = sum - C_j;  
    176.                 }  
    177.             }  
    178.             else  
    179.             {  
    180.                 if(alpha[i] < 0)  
    181.                 {  
    182.                     alpha[i] = 0;  
    183.                     alpha[j] = sum;  
    184.                 }  
    185.             }  
    186.         }  
    187.   
    188.         // update G  
    189.   
    190.         double delta_alpha_i = alpha[i] - old_alpha_i;  
    191.         double delta_alpha_j = alpha[j] - old_alpha_j;  
    192.           
    193.         for(int k=0;k<active_size;k++)  
    194.         {  
    195.             G[k] += Q_i[k]*delta_alpha_i + Q_j[k]*delta_alpha_j;  
    196.         }  
    197.   
    198.         // update alpha_status and G_bar  
    199.   
    200.         {  
    201.             bool ui = is_upper_bound(i);  
    202.             bool uj = is_upper_bound(j);  
    203.             update_alpha_status(i);  
    204.             update_alpha_status(j);  
    205.             int k;  
    206.             if(ui != is_upper_bound(i))  
    207.             {  
    208.                 Q_i = Q.get_Q(i,l);  
    209.                 if(ui)  
    210.                     for(k=0;k<l;k++)  
    211.                         G_bar[k] -= C_i * Q_i[k];  
    212.                 else  
    213.                     for(k=0;k<l;k++)  
    214.                         G_bar[k] += C_i * Q_i[k];  
    215.             }  
    216.   
    217.             if(uj != is_upper_bound(j))  
    218.             {  
    219.                 Q_j = Q.get_Q(j,l);  
    220.                 if(uj)  
    221.                     for(k=0;k<l;k++)  
    222.                         G_bar[k] -= C_j * Q_j[k];  
    223.                 else  
    224.                     for(k=0;k<l;k++)  
    225.                         G_bar[k] += C_j * Q_j[k];  
    226.             }  
    227.         }  
    228.     }  
    229.   
    230.     if(iter >= max_iter)  
    231.     {  
    232.         if(active_size < l)  
    233.         {  
    234.             // reconstruct the whole gradient to calculate objective value  
    235.             reconstruct_gradient();  
    236.             active_size = l;  
    237.             info("*");  
    238.         }  
    239.         fprintf(stderr," WARNING: reaching max number of iterations ");  
    240.     }  
    241.   
    242.     // calculate rho  
    243.   
    244.     si->rho = calculate_rho();  
    245.   
    246.     // calculate objective value  
    247.     {  
    248.         double v = 0;  
    249.         int i;  
    250.         for(i=0;i<l;i++)  
    251.             v += alpha[i] * (G[i] + p[i]);  
    252.   
    253.         si->obj = v/2;  
    254.     }  
    255.   
    256.     // put back the solution  
    257.     {  
    258.         for(int i=0;i<l;i++)  
    259.             alpha_[active_set[i]] = alpha[i];  
    260.     }  
    261.   
    262.     // juggle everything back  
    263.     /*{ 
    264.         for(int i=0;i<l;i++) 
    265.             while(active_set[i] != i) 
    266.                 swap_index(i,active_set[i]); 
    267.                 // or Q.swap_index(i,active_set[i]); 
    268.     }*/  
    269.   
    270.     si->upper_bound_p = Cp;  
    271.     si->upper_bound_n = Cn;  
    272.   
    273.     info(" optimization finished, #iter = %d ",iter);  
    274.   
    275.     delete[] p;  
    276.     delete[] y;  
    277.     delete[] alpha;  
    278.     delete[] alpha_status;  
    279.     delete[] active_set;  
    280.     delete[] G;  
    281.     delete[] G_bar;  
    282. }  
  • 相关阅读:
    python-裴波那契数列
    python-装饰器
    登录权限,认证
    sevlet面试题总结
    ssm整合
    配置文件sshmvc
    利用反射来实现动态代理
    springmvc spring hibernate整合
    中期项目总结
    web servlet 网址http://www.cnblogs.com/mengdd/tag/Servlet/
  • 原文地址:https://www.cnblogs.com/Miliery/p/4394149.html
Copyright © 2020-2023  润新知