下面给出模型训练步骤的思路,在用pytorch(也包括其他框架)编写代码进行网络编写时,建议都按照这几个步骤来进行,形成一个清晰的思路
各模块简要说明
数据:涉及数据的收集、划分、读取及预处理等
模型:根据任务的复杂程度选择简单的线性模型或复杂的神经网络模型等
损失函数:根据任务的不同选择不同的损失函数,比如线性回归模型采用均方差函数,分类任务采用交叉熵函数等
优化器:根据损失函数求得的梯度来更新模型参数
迭代训练: 确立好前四大模块后,进行反复的迭代训练:前向传播;计算loss;反向传播;更新梯度;清空梯度;计算acc及可视化