x = torch.tensor(2.0)
x.requires_grad_(True)
y = 2 * x
z = 5 * x
w = y + z.detach()
w.backward()
print(x.grad)
=> 2
本来应该x的梯度为7,但是detach()那一路切段了梯度的传播,导致5没有向后传递
x = torch.tensor(2.0)
x.requires_grad_(True)
y = 2 * x
z = 5 * x
w = y + z.detach()
w.backward()
print(x.grad)
=> 2
本来应该x的梯度为7,但是detach()那一路切段了梯度的传播,导致5没有向后传递