书到用时方恨少,知识到用时才知道没有学精通。函数求导的链式规则,我记得f(g(x))' = f'(g)*g'(x)。我甚至还记得更便于理解记忆的形式:df/dx = df/dg * dg/dx. 没错,看上去我记得很清楚。可是实际工作中碰到的函数是这样的:
已知:
f(g1, g2, ...., gm) // f 是 g1, g2, ..., gm 的函数
g1(u1, u2, ...., un) // g1 是 u1, u2, ...., un 的函数
g2(u1, u2, ...., un)
...
gm(u1, u2, ..., un)
u1(x1, x2, ... xk) // u1是x1, x2, ..., xk 的函数
u2(x1, x2, ... xk)
...
un(x1, x2, ... xk)
求:f'(x1),..., f'(xk) // 求 f 对 x1, x2, ..., xk 的导数
居然一下子傻眼了。这里,一个函数有多个变量,而变量本身又是函数,传递超过两层。折腾了半天,才了解链式规则的真谛:函数f对x的微分,等于f对其自身变量u的微分乘以u对x的微分。如果自身变量不止一个,则要把对所有变量的微分相加。所以,正确的结果是:
f'(x1) = f'(g1)*g1'(x1) + f'(g2)*g2'(x1) + .... + f'(gm)*gm'(x1)
...
f'(xk) = f'(g1)*g1'(xk) + f'(g2)*g2'(xk) + .... + f'(gm)*gm'(xk)
写成易于理解的形式,是:
df/dx1 = df/dg1 * dg1/dx1 + df/dg2 * dg2/dx1 + .... + df/dgm * dgm/dx1
...
df/dxk = df/dg1 * dg1/dxk + df/dg2 * dg2/dxk + ... + df/dgm * dgm/dxk
可是,g1'(x1),也就是dg1/dx1仍然是不知道的啊。答案是继续运用链式规则:
g1'(x1) = g1'(u1)*u1'(x1) + g1'(u2)*u2'(x1) + .... + g1'(un)*un'(x1)
写成易于理解的形式,是:
dg1/dx1 = dg1/du1 * du1/dx1 + dg1/du2 * du2/dx1 + ... + dg1/dun * dun/dx1
到这一步,u1'(x1),也就是dg1/du1就是已知的了。一路回代,就可以写出f'(x1) 的公式了。
后记:整个回代过程如果用程序实现,就是前向自动微分方法(Forward mode automatic differenation)。因为任意复杂的函数最后都可以写成基本数学函数(u1, u2, ..., un)的组合,而基本数学函数的微分公式我们是知道的,可以以硬编码的方式写进去。