准备工作
使用colab训练crnn模型
将训练代码和数据集上传至colab,注意图片要定长,不然可能会出现loss nan
安装pytorch1.2.0
pip uninstall torch
pip install torch===1.2.0 torchvision===0.4.0 -f https://download.pytorch.org/whl/torch_stable.html
import torch
print(torch.__version__)
将colab连接到google driver
!apt-get install opam
!opam init
!opam update
!opam install depext
!opam depext google-drive-ocamlfuse
!opam install google-drive-ocamlfuse
#进行授权操作
from google.colab import auth
auth.authenticate_user()
from oauth2client.client import GoogleCredentials
creds = GoogleCredentials.get_application_default()
import getpass
!/root/.opam/system/bin/google-drive-ocamlfuse -headless -id={creds.client_id} -secret={creds.client_secret} < /dev/null 2>&1 | grep URL
vcode = getpass.getpass()
!echo {vcode} | /root/.opam/system/bin/google-drive-ocamlfuse -headless -id={creds.client_id} -secret={creds.client_secret}
#!!!注意,里面的/root/.opam/system/bin/google-drive-ocamlfuse换成你自己的路径,一般来说你也会得到和我一样的结果
# 指定Google Drive云端硬盘的根目录,名为drive
!mkdir -p drive
!/root/.opam/system/bin/google-drive-ocamlfuse drive
切换到对应目录
%cd /content/drive/colab/crnn
%ls
安装需要的包
pip install tensorboardX
运行训练代码
修改yaml文件中的相关参数,例如图片路径,batchsize,开始结束epoch等
!python train.py --cfg lib/config/OWN_config.yaml
测试模型效果
!python demonew.py --image_path 1_en.jpg
这里需要将demo.py的代码进行修改,第一个代码段中的路径修改为自己的路径
parser.add_argument('--checkpoint', type=str, default='output/checkpoints/base.pth',
help='the path to your checkpoints')
model = crnn.get_crnn(config).to(device)
print('loading pretrained model from {0}'.format(args.checkpoint))
##model.load_state_dict(torch.load(args.checkpoint)) ##作者训练的模型
checkpoint = torch.load(args.checkpoint, map_location='cpu') ##针对训练时新保存的模型
model.load_state_dict(checkpoint['state_dict'])