如题，下面是一个基于 argparse 的命令行参数解析示例，除了预先声明的参数外，还能接收任意 `--key=value` 形式的额外参数：
import argparse


def arg_parser(argv=None):
    """Build the command-line parser and parse the declared arguments.

    Args:
        argv: list of argument strings to parse. ``None`` (the default)
            lets argparse read ``sys.argv[1:]``, exactly like the original
            zero-argument call.

    Returns:
        The ``(namespace, unknown)`` tuple from
        ``ArgumentParser.parse_known_args``: ``namespace`` holds the options
        declared below, ``unknown`` is the list of argument strings the
        parser did not recognise (e.g. ``['--aaa=me']``).
    """
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--env', help='environment ID', type=str, default='Reacher-v2')
    parser.add_argument('--env_type',
                        help='type of environment, used when the environment type cannot be automatically determined',
                        type=str)
    parser.add_argument('--seed', help='RNG seed', type=int, default=None)
    parser.add_argument('--alg', help='Algorithm', type=str, default='ppo2')
    parser.add_argument('--num_timesteps', type=float, default=1e6)
    parser.add_argument('--network', help='network type (mlp, cnn, lstm, cnn_lstm, conv_only)', default=None)
    parser.add_argument('--gamestate', help='game state to load (so far only used in retro games)', default=None)
    parser.add_argument('--num_env',
                        help='Number of environment copies being run in parallel. When not specified, set to number of cpus for Atari, and to 1 for Mujoco',
                        default=None, type=int)
    parser.add_argument('--reward_scale', help='Reward scale factor. Default: 1.0', default=1.0, type=float)
    parser.add_argument('--save_path', help='Path to save trained model to', default=None, type=str)
    parser.add_argument('--save_video_interval', help='Save video every x steps (0 = disabled)', default=0, type=int)
    parser.add_argument('--save_video_length', help='Length of recorded video. Default: 200', default=200, type=int)
    parser.add_argument('--log_path', help='Directory to save learning curve data.', default=None, type=str)
    parser.add_argument('--play', default=False, action='store_true')
    return parser.parse_known_args(argv)


def parse_unknown_args(args):
    """Collect arguments the parser did not consume into a dictionary.

    Two spellings are recognised:

    * ``--key=value``: everything after the FIRST ``=`` is the value, so
      values may themselves contain ``=`` (``--expr=a==b``).
    * ``--key value``: the next token that does not start with ``--``
      becomes the value.

    A bare ``--key`` with no following value is dropped, as is any token
    that is not attached to a key. Values are returned as strings; a later
    occurrence of a key overwrites an earlier one.

    >>> parse_unknown_args(['--a=1', '--b', 'x=y', 'stray'])
    {'a': '1', 'b': 'x=y'}
    """
    retval = {}
    pending_key = None  # key from a bare ``--key`` still waiting for its value
    for arg in args:
        if arg.startswith('--'):
            key, sep, value = arg[2:].partition('=')
            if sep:
                retval[key] = value
                # A complete --key=value must not leave an earlier bare
                # --key armed, otherwise a later stray token would
                # overwrite this value.
                pending_key = None
            else:
                pending_key = key
        elif pending_key is not None:
            retval[pending_key] = arg
            pending_key = None
    return retval


def parse_cmdline_kwargs(args, unknown_args):
    """Merge leftover ``--key=value`` arguments into the parsed namespace.

    Each value is evaluated as a Python expression when possible
    (``'11.11'`` -> ``11.11``, ``'True'`` -> ``True``, ``'1+99'`` -> ``100``);
    if evaluation fails the raw string is kept (``'me'`` -> ``'me'``).

    Args:
        args: the namespace returned by ``arg_parser``; it is updated in place.
        unknown_args: the list of argument strings the parser did not consume.

    Returns:
        The same ``args`` object, now carrying the extra attributes.
    """
    def parse(v):
        assert isinstance(v, str)
        try:
            # SECURITY NOTE(review): eval executes arbitrary Python taken from
            # the command line. That is acceptable only because the values
            # come from whoever launches the process; switch to
            # ast.literal_eval if inputs can ever be untrusted (note that it
            # rejects expressions such as '1+99').
            return eval(v)
        except (NameError, SyntaxError, TypeError, ValueError,
                ArithmeticError, AttributeError, LookupError):
            # Not a (successfully) evaluable expression: keep the raw string.
            return v

    vars(args).update({k: parse(v) for k, v in parse_unknown_args(unknown_args).items()})
    return args


if __name__ == '__main__':
    args, unknown_args = arg_parser()
    print(args)
    args = parse_cmdline_kwargs(args, unknown_args)
    print(args)
运行命令：
python test.py --aaa=me --xxx=11.11 --abc=True --cde=1+99
解析结果（第一行是 argparse 对已声明参数的解析结果，第二行是合并额外参数之后的结果，其中 `1+99` 被求值为 `100`）：
Namespace(alg='ppo2', env='Reacher-v2', env_type=None, gamestate=None, log_path=None, network=None, num_env=None, num_timesteps=1000000.0, play=False, reward_scale=1.0, save_path=None, save_video_interval=0, save_video_length=200, seed=None)
Namespace(aaa='me', abc=True, alg='ppo2', cde=100, env='Reacher-v2', env_type=None, gamestate=None, log_path=None, network=None, num_env=None, num_timesteps=1000000.0, play=False, reward_scale=1.0, save_path=None, save_video_interval=0, save_video_length=200, seed=None, xxx=11.11)
=======================================
这是一段较为规范的运行参数解析代码，便于后续代码统一调用这些参数。注意：额外参数的值会经过 `eval` 求值，因此只应用于可信的命令行输入。