一、序列化与反序列化
1、torch.save
主要参数:
- obj:对象
- f:输出路径
2、torch.load
主要参数:
- f:文件路径
- map_location:指定存放位置,cpu or gpu
二、模型保存与加载的两种方式
第一种方式:
保存整个Module
torch.save(net,path)
第二种方式:
state_dict = net.state_dict()
torch.save(state_dict , path)
三、模型断点续训练
模型微调Finetune
四、Transfer Learning & Model Finetune
Transfer Learning:机器学习分支,研究源域(source domain)的知识如何应用到目标域(target domain)
Model Finetune:模型的迁移学习
模型微调的步骤:
1、获取预训练模型参数
2、加载模型(load_state_dict)
3、修改输出层
模型微调训练方法
1、固定预训练的参数(requires_grad = False; lr = 0)
2、Features Extractor较小学习率(params_group)
五、PyTorch中的Finetune
Finetune Resnet-18 用于二分类
蚂蚁蜜蜂二分类数据
训练集:各120~张 验证集:各70~张