def relu_forward(x):
out = x * (x > 0) # * 对于 np.ndarray 而言表示 handmard 积,x > 0 得到的 0和1 构成的矩阵
return out, x
def relu_backward(dout, cache):
x = cache
dx = dout * (x >= 0)
return dx
- 传递回去的 x 作为反向传递时会用到的中间变量也即 cache