损失函数 (Loss functions)
Dice loss
class SoftDiceLoss(nn.Module):
    """Soft Dice loss, returned as ``1 - mean(dice coefficient)``.

    For every slice that remains after summing over ``dims``, the soft Dice
    coefficient is

        dc = (2 * TP + smooth) / (2 * TP + FP + FN + smooth)

    with TP = sum(x * y), FP = sum(x * (1 - y)), FN = sum((1 - x) * y).
    The coefficients are averaged over all remaining elements and the loss
    is ``1 - mean(dc)``.

    No activation is applied to ``x``; presumably callers pass probabilities
    (e.g. after sigmoid/softmax) -- TODO confirm at call sites.

    Args:
        smooth: Constant added to numerator and denominator; keeps the ratio
            defined when both ``x`` and ``y`` sum to zero over a slice.
        dims: Dimensions summed over. Defaults to the last two, presumably
            the spatial (H, W) axes -- TODO confirm.
    """

    def __init__(self, smooth=1., dims=(-2, -1)):
        super().__init__()
        self.smooth = smooth
        self.dims = dims

    def forward(self, x, y):
        """Return the soft Dice loss between prediction ``x`` and target ``y``.

        Args:
            x: Prediction tensor.
            y: Target tensor, broadcastable with ``x``. Float, integer and
                bool targets are accepted; non-floating targets are cast to
                ``x``'s dtype first.

        Returns:
            A 0-dim tensor equal to ``1 - mean(dc)``.
        """
        # Bool/int targets: cast so all arithmetic happens in x's float dtype
        # (``1 - y`` on a bool tensor raises in PyTorch).
        if not y.is_floating_point():
            y = y.to(x.dtype)
        tp = (x * y).sum(self.dims)
        # 2*TP + FP + FN == sum(x) + sum(y) algebraically. Summing (x + y)
        # keeps the same broadcasting as x * y while avoiding the two extra
        # full-size products x * (1 - y) and (1 - x) * y.
        denom = (x + y).sum(self.dims)
        dc = (2 * tp + self.smooth) / (denom + self.smooth)
        return 1 - dc.mean()