之前在tensorflow的mnist例程中看到了使用 absl.flags的方法来载入和解析参数的,出于学习的目的,就自己试验了一下,
代码如下:
1 # *_*coding:utf-8 *_* 2 # athor:auto 3 4 import sys, os 5 from absl import app 6 from absl import flags 7 from official.utils.flags import core as flags_core 8 9 10 FLAGS = flags.FLAGS 11 flags.DEFINE_string('gpu', None, 'comma separated list of GPU to use.') 12 13 14 def flagtest(argv): 15 del argv 16 if FLAGS.gpu: 17 print("gpu is %s" % FLAGS.gpu) 18 os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu 19 else: 20 print('Please assign GPUs.') 21 exit() 22 23 def main(argv): 24 flags_core.define_base() 25 flags_core.define_performance(num_parallel_calls=False) 26 flags_core.define_image() 27 flags.adopt_module_key_flags(flags_core) 28 29 if __name__ == '__main__': 30 app.run(flagtest)
其中main中的几个调用都是源自于tensorflow的model/official,里面的函数大多是model/official/utils/flags/core.py内定义好的一些默认参数。
在mnist例子中还可以这样添加自定义项:
flags_core.set_defaults(data_dir='./tmp/mnist_data', model_dir='./tmp/mnist_model', batch_size=100, train_epochs=40, stop_threshold=0.998)
参考:
https://blog.csdn.net/faith_binyang/article/details/80551941