本文简单介绍模型训练时候,使用准确率求解过程,不涉及精确率和召回率计算,
本文给出简要计算方法与代码。
计算方法:
使用top1计算为例(以下以2个batch,3个num_classes举列):
网络预测结果形式:pred=[b,num_classes] ,如pred=[[0.6,0.8,0.9],[0.7,0.4,0.3]]
真实标签形式:label=[b],如batch[1,0]
公式:预测正确/预测个数(即batch)
计算步骤:
步骤1:从pred找到预测最好的分别为[0.9,0.7],可知类别为[2,0]
步骤2:依次比较预测与label是否匹[2,0]--[1,0],可知第二个预测正确,则预测正确为1
步骤3:计算准确率为:1/2=50%
计算代码:
import torch def accuracy(output, target, topk=(1,)): maxk = max(topk) # topk=(1,)取top1准确率,topk=(1,5)取top1和top5准确率 batch_size = target.size(0) _, pred = output.topk(maxk, 1, True, True) # topk参数中,maxk取得是top1准确率,dim=1是按行取值, largest=1是取最大值 pred = pred.t() # 转置 correct = pred.eq(target.view(1, -1).expand_as(pred)).contiguous() # 比较是否相等 res = [] for k in topk: correct_k = correct[:k].view(-1).float().sum(0) res.append(correct_k.mul_(1 / batch_size)) return res import numpy as np if __name__ == '__main__': N=60 # batch C=9 # 类别 pred=np.random.rand(N,C) pred=torch.from_numpy(pred) label=np.random.randint(0,2,N) label=torch.from_numpy(label) r=accuracy(pred, label, topk=(1,5)) # topk=(1,5) 表示求解top1与top5的分类准确率 print(r) # print(label) # print(pred)
结果显示: