在求最优解时,前面很多地方都用梯度下降(Gradient Descent)的方法,但由于最优步长很难确定,可能会出现总是在最优解附近徘徊的情况,致使最优解的搜索过程很缓慢。牛顿法(Newton's Method)在最优解的搜索方面有了较大改进,它不仅利用了目标函数的一阶导数,还利用了搜索点处的二阶导数,使得搜索算法能更准确地指向最优解。
我们结合下图所示的一个实例来描述牛顿法的思想。假设我们想要求得参数( heta),使得(f( heta)=0)。算法的描述如下:
- 随机猜测一个解( heta^{(0)}),并令(t=0);
- 在( heta^{(t)})处用一根切线来近似(f( heta));
- 求得切线与横坐标的交点( heta^{(t+1)}),作为下一个可能的解;
- (t=t+1);
- 重复2-4,直到收敛,即(f( heta^{(t)})approx 0)。
那么( heta^{(t+1)})与( heta^{(t)})之间存在怎样的迭代关系呢?由切线的斜率可知 egin{equation} f'( heta)=frac{f( heta)}{vartriangle}Rightarrow vartriangle=frac{f( heta)}{f'( heta)} end{equation}
观察( heta^{(t+1)})与( heta^{(t)})在横坐标上的关系,可知 egin{equation} heta^{(t+1)}= heta^{(t)}-vartriangle= heta^{(t)}-frac{f( heta)}{f'( heta)} end{equation}
牛顿法给出了(f( heta)=0)的求解算法,那么怎样将其运用到求使似然函数(mathcal{L}( heta))最大化的参数上呢?一般最优参数( heta^{star})在(mathcal{L}( heta))的极值点出取得,即(mathcal{L}'( heta^{star})=0)。那么,令上面的(f( heta)=mathcal{L}'( heta)),我们很容易就得出了下列的迭代法则 egin{equation} heta^{(t+1)}= heta^{(t)}-frac{mathcal{L}'( heta^{(t)})}{mathcal{L}''( heta^{(t)})} end{equation} 最终求得使(mathcal{L}'( heta)=0)的参数( heta^star),也就是令似然函数(mathcal{L}( heta))最大的参数。
上面讨论的参数( hetainmathbb{R}),我们现在将牛顿法则推广到(n)维向量( hetainmathbb{R}^n),对应的迭代法则形式如下: egin{equation} heta^{(t+1)}= heta^{(t)}-H^{-1} abla_{ heta}mathcal{L} end{equation} 其中(H)为(mathcal{L})对向量( heta^{(t)})的二阶偏导,称为Hessian矩阵,(H_{ij}=frac{partial^2mathcal{L}}{partial heta^{(t)}_ipartial heta^{(t)}_j})。
接下来,我们从另外一个角度来考察牛顿法。用似然函数(mathcal{L}( heta))的二阶泰勒展开(mathcal{F}( heta))来对其进行逼近。 egin{equation} mathcal{L}( heta)approxmathcal{F}( heta)=mathcal{L}( heta^{(t})+ abla_{ heta^{(t)}}mathcal{L}( heta- heta^{(t)})+frac{1}{2}( heta- heta^{(t)})^TH( heta- heta^{(t)}) end{equation} 令( heta= heta^{(t+1)}),可得 egin{equation} egin{array}{ll} mathcal{F}( heta^{(t+1)})=&mathcal{L}( heta^{(t)})+ abla_{ heta^{(t)}}mathcal{L}^T( heta^{(t+1)}- heta^{(t)})\ &+frac{1}{2}( heta^{(t+1)}- heta^{(t)})^TH( heta^{(t+1)}- heta^{(t)}) end{array} end{equation} 现在,我们的目的是求得使(mathcal{F}( heta^{(t+1)}))最小的参数( heta^{(t+1)})。将上式对( heta^{(t+1)})求导并令导数为0,可得 egin{equation} frac{partialmathcal{F}( heta^{(t+1)})}{partial heta^{(t+1)}}= abla_{ heta^{(t)}}mathcal{L}+H( heta^{(t+1)}- heta^{(t)})=0 end{equation} 等式两侧同时左乘(H^{-1}),化简得 egin{equation} heta^{(t+1)}= heta^{(t)}-H^{-1} abla_{ heta}mathcal{L} end{equation}
我们用的是二阶泰勒展开式(mathcal{F}( heta))逼近似然函数(mathcal{L}( heta))。如果(mathcal{L}( heta))确实为二次函数,那么(mathcal{F}( heta))就是(mathcal{L}( heta))的准确展开式,利用牛顿法一步就可以直接求得最优解。一般情况下,(mathcal{L}( heta))并非二次函数,那么(mathcal{F}( heta))也就存在逼近误差,使得一次迭代不能求得最优解,当(mathcal{L}( heta))的次数很高时,往往要经历很多次迭代。一般而言,因为牛顿法利用了二阶导数来修正搜索方向和步长,收敛速度很更快。但是这同样也是要付出代价的,相比梯度下降而言,我们需要额外计算Hessian矩阵并求其逆,这两步的计算代价都很大。只要参数( heta)的维度(n)不是很大,可以考虑用牛顿迭代。另外还有一点,如果目标函数不是严格的凸函数,Hessian矩阵(H)很可能是奇异矩阵,也就是存在特征值为0的情况,那么它的逆矩阵是不存在的,也就无法用牛顿法。
今年有一道面试题是要求我们写出一段程序,求解(sqrt{n})。如果把牛顿法用上去,问题就迎刃而解了。我们设定目标函数为(f(x)=x^2-n),那么令(f(x)=0)的解很显然就是(pmsqrt{n})。要注意的是,我们要选择合理的迭代起始点,如果我们从正数开始迭代,求得的是(sqrt{n});如果从负数开始迭代,求得的就是(-sqrt{n});如果从0开始迭代,会出现未定义的计算(0作为除数)。我们根据前面讲的牛顿迭代法则,直接给出该题的迭代法则 egin{equation} x^{(t+1)}=x^{(t)}-frac{f(x^{(t)})}{f'(x^{(t)})}=x^{(t)}-frac{(x^{(t)})^2-n}{2x^{(t)}}=frac{1}{2}left(x^{(t)}+frac{n}{x^{(t)}} ight) end{equation} 下面是由该算法写出的一段精简的code,浓缩了牛顿算法的精髓
1 double mysqrt1(double n) 2 { 3 if (n<0) return -1; 4 if(n==0) return 0; 5 double eps=1e-5; 6 double x=0.1;//start from a positive value 7 while(fabs(x*x-n)>=eps) 8 x=(x+n/x)/2;//Newton's method 9 return x; 10 }
这道题我还想了另外一个算法,算法的启发点来源于((x-1)(x+1)+1=x^2=n)。用这个算法,我们的迭代起始点可以是0。算法的基本思想如下:给定一个初始步长step,从起始点开始每次向前走一个步长,直到超过了(sqrt{n});一旦超过了(sqrt{n}),就要开始慢慢向最终解靠近,每次前进或后退的步长都缩减为以前的一半。很明显,这个算法没有牛顿迭代法快。我只用了少数几个测试用例,两段程序的计算结果都和sqrt库函数的计算结果一致。代码如下:
1 double mysqrt2(double n) 2 { 3 if (n<0) return -1; 4 double x=0; 5 double step=10; 6 int threshold=0; 7 double eps=1e-5; 8 double res=x*x; 9 while(fabs(res-n)>=eps) 10 { 11 if(res<n) 12 { 13 //once we have passed the solution,we must walk forward slowly 14 if(threshold) step/=2; 15 x+=step;//walking forward 16 } 17 else//walk 18 { 19 threshold=1;//indicating we have passed the solution 20 step/=2;//reducing the step size to its half 21 x-=step;//walking back 22 } 23 res=x*x;//compute x*x to estimate real n 24 }//end while 25 return x; 26 }