- opt_einsum, 爱因斯坦求和
- jax, 算svd时比numpy 包快很多
- flax
之前的一个问题, relu = max(x, 0) 在0点次梯度(不可导)的导数取了什么? 代码中是0。
import jax
from jax.experimental import optimizers, stax
from jax import grad, jacfwd
In [7]: grad(stax.relu)(2.0)
Out[7]: DeviceArray(1., dtype=float32)
In [8]: grad(stax.relu)(0.0)
Out[8]: DeviceArray(0., dtype=float32)
In [10]: grad(stax.relu)(0.001)
Out[10]: DeviceArray(1., dtype=float32)
In [11]: grad(stax.relu)(-0.001)
Out[11]: DeviceArray(0., dtype=float32)