• Resuming training from a given epoch: loading the parameters


    import torch

    start_epoch = params.start_epoch
    stop_epoch = params.stop_epoch
    if params.resume != '':
        # get_resume_file returns the path of the <epoch>.tar checkpoint file
        resume_file = get_resume_file('%s/checkpoints/%s' % (params.save_dir, params.resume), params.resume_epoch)
        if resume_file is not None:
            tmp = torch.load(resume_file)
            start_epoch = tmp['epoch'] + 1  # continue from the epoch after the saved one
            model.load_state_dict(tmp['state'])
            print('  resuming training at epoch {} (model file {})'.format(start_epoch, resume_file))
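The snippet relies on `get_resume_file`, which the post does not show; its comment only says that it returns the `epoch.tar` file. Below is a hypothetical sketch, assuming checkpoints are saved as `<epoch>.tar` (e.g. `400.tar`) inside the checkpoint directory. It returns the requested epoch's file when `resume_epoch` is non-negative, otherwise the checkpoint with the highest epoch number, and `None` when nothing usable exists. The fallback-to-latest behaviour is an assumption, not the original helper.

```python
import glob
import os


def get_resume_file(checkpoint_dir, resume_epoch=-1):
    """Hypothetical sketch: locate the <epoch>.tar checkpoint to resume from.

    Assumes checkpoints are named '<epoch>.tar' (e.g. '400.tar').
    If resume_epoch >= 0, return that epoch's file, or None if it is missing.
    If resume_epoch is negative or None, return the checkpoint with the
    highest epoch number, or None if the directory has no such file.
    """
    if resume_epoch is not None and resume_epoch >= 0:
        path = os.path.join(checkpoint_dir, '{:d}.tar'.format(resume_epoch))
        return path if os.path.isfile(path) else None

    # Collect files whose basename (without '.tar') is an integer epoch number;
    # names such as 'best_model.tar' are ignored.
    epochs = []
    for path in glob.glob(os.path.join(checkpoint_dir, '*.tar')):
        stem = os.path.splitext(os.path.basename(path))[0]
        if stem.isdigit():
            epochs.append(int(stem))
    if not epochs:
        return None
    # Compare epochs numerically, so 200 beats 35 (a string sort would not).
    return os.path.join(checkpoint_dir, '{:d}.tar'.format(max(epochs)))
```

For `tmp['epoch']` and `tmp['state']` to exist when resuming, the checkpoint must have been written with those keys, for example with `torch.save({'epoch': epoch, 'state': model.state_dict()}, outfile)` at the end of each saved epoch (the exact save call in the original project is an assumption).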

    Load the .tar file and keep only the state-dict weights that are needed:

    import torch

    tmp = torch.load(modelfile)  # load the parameter file, e.g. 400.tar
    try:
        state = tmp['state']
    except KeyError:
        state = tmp['model_state']  # some checkpoints store the weights under this key instead
    state_keys = list(state.keys())  # copy the keys so the dict can be modified inside the loop
    for key in state_keys:
        # keep backbone weights (keys containing "feature."), but skip gamma/beta parameters
        if 'feature.' in key and 'gamma' not in key and 'beta' not in key:
            newkey = key.replace('feature.', '')
            state[newkey] = state.pop(key)  # remove the old key and re-insert its value under the stripped name
        else:
            state.pop(key)  # drop every other key

    model.load_state_dict(state)
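The key-filtering loop can be pulled out into a pure function and checked on a plain dict without a real model. A minimal sketch, assuming (as in the post) that the backbone's weights carry a `feature.` prefix and that any key containing `gamma` or `beta` should be dropped; `extract_feature_state` is a hypothetical name. Unlike the loop above, it builds a new dict instead of popping from the loaded one, so the checkpoint dict is left untouched.

```python
def extract_feature_state(state, prefix='feature.', skip=('gamma', 'beta')):
    """Hypothetical helper: keep only backbone weights and strip their prefix.

    Mirrors the loop above: a key is kept if it contains `prefix` and none
    of the substrings in `skip`; every occurrence of `prefix` is removed from
    the kept key (same as str.replace in the original). All other keys are dropped.
    """
    new_state = {}
    for key, value in state.items():
        if prefix in key and not any(s in key for s in skip):
            new_state[key.replace(prefix, '')] = value
    return new_state
```

Usage would be `model.load_state_dict(extract_feature_state(state))`. Filtering first matters because `load_state_dict` defaults to `strict=True` and raises an error on unexpected or missing keys.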
  • Original post: https://www.cnblogs.com/stepping/p/13403741.html