• baselines算法库baselines/common/input.py模块分析


    baselines算法库baselines/common/input.py模块代码:

    import numpy as np
    import tensorflow as tf
    from gym.spaces import Discrete, Box, MultiDiscrete
    
    def observation_placeholder(ob_space, batch_size=None, name='Ob'):
        '''
        Create placeholder to feed observations into of the size appropriate to the observation space
    
        Parameters:
        ----------
    
        ob_space: gym.Space     observation space
    
        batch_size: int         size of the batch to be fed into input. Can be left None in most cases.
    
        name: str               name of the placeholder
    
        Returns:
        -------
    
        tensorflow placeholder tensor
        '''
    
        assert isinstance(ob_space, Discrete) or isinstance(ob_space, Box) or isinstance(ob_space, MultiDiscrete), \
            'Can only deal with Discrete and Box observation spaces for now'
    
        dtype = ob_space.dtype
        if dtype == np.int8:
            dtype = np.uint8
    
        return tf.placeholder(shape=(batch_size,) + ob_space.shape, dtype=dtype, name=name)
    
    
    def observation_input(ob_space, batch_size=None, name='Ob'):
        '''
        Create placeholder to feed observations into of the size appropriate to the observation space, and add input
        encoder of the appropriate type.
        '''
    
        placeholder = observation_placeholder(ob_space, batch_size, name)
        return placeholder, encode_observation(ob_space, placeholder)
    
    def encode_observation(ob_space, placeholder):
        '''
        Encode input in the way that is appropriate to the observation space
    
        Parameters:
        ----------
    
        ob_space: gym.Space             observation space
    
        placeholder: tf.placeholder     observation input placeholder
        '''
        if isinstance(ob_space, Discrete):
            return tf.to_float(tf.one_hot(placeholder, ob_space.n))
        elif isinstance(ob_space, Box):
            return tf.to_float(placeholder)
        elif isinstance(ob_space, MultiDiscrete):
            placeholder = tf.cast(placeholder, tf.int32)
            one_hots = [tf.to_float(tf.one_hot(placeholder[..., i], ob_space.nvec[i])) for i in range(placeholder.shape[-1])]
            return tf.concat(one_hots, axis=-1)
        else:
            raise NotImplementedError

    可以看到input.py模块中一共有三个函数,其中只有一个函数对外提供服务,也就是 observation_input

    可以看到observation_placeholder函数和encode_observation函数都已经被observation_input函数包装到了一起。

    在observation_placeholder函数中根据传入的 env.observation_space变量即可生成对应shape的tf.placeholder变量:

    return tf.placeholder(shape=(batch_size,) + ob_space.shape, dtype=dtype, name=name)

    也就是说observation_input函数中的placeholder已经是tf.placeholder类型的了。

    encoder_observation 函数根据gym.spaces.observation_space的类型对tf.placeholder进行reshape操作,而tf.placeholder已经是TensorFlow的tensor变量,因此这里对tf.placeholder的操作都是在图中的构建操作,属于TensorFlow的操作。

    根据encoder_observation中的代码:

    我们可以知道对placeholder的reshape操作主要是对gym.space.observation_space属于Discrete和MultiDiscrete类型进行的。

    如果传入的gym.space.observation_space为Discrete类型则对其对应的placeholder进行  tf.one_hot  操作,即:

    tf.one_hot(placeholder, ob_space.n)

    如果传入的gym.space.observation_space为MultiDiscrete类型则对其对应的placeholder中的每个Discrete进行  tf.one_hot  操作然后在concat拼接,即:

    one_hots = [tf.to_float(tf.one_hot(placeholder[..., i], ob_space.nvec[i])) for i in range(placeholder.shape[-1])]
    return tf.concat(one_hots, axis=-1)

    其中,如果gym.space.observation_space为Discrete,则其observation空间的大小为env.observation_space.n 。

    其中,如果gym.space.observation_space为MultiDiscrete,则其包含的第i个Discrete对应的observation空间的大小为env.observation_space.nvec[i] 。

    =======================================

    observation_input函数返回的两个变量,其中placeholder是为了给网络feed数据的,而对placeholder进行reshape后的变量是为了方便后面构建神经网络的。

    可以知道在gym中gym.space.MultiDiscrete主要是任天堂的游戏使用的,Nintendo Game Controller 。

    import gym
    
    env_space=gym.spaces.MultiDiscrete([ 5, 2, 2 ])
    
    for i in range(env_space.shape[-1]):
        print(env_space.nvec[i])

    可以看到MultiDiscrete中的每个Discrete的空间大小使用nvec对应的索引查询。

    一般,gym的常见observation空间类型有:

    gym.space.Box

    gym.spaces.Discrete

    gym.spaces.MultiDiscrete

    gym.spaces.MultiBinary

    gym.spaces.Tuple

    gym.spaces.Dict

    其中:

    gym.spaces.Box

    gym.spaces.Discrete

    gym.spaces.MultiDiscrete

    gym.spaces.MultiBinary

    对应的observation均为np.array类型,在这里:

    gym.spaces.Discrete

    gym.spaces.MultiDiscrete

    由于状态空间可以进行one_hot处理以便于后面的计算,因此在common/input.py模块中的obervation_input函数对这样的情况进行处理,对原始的placeholder进行one_hot处理。

    由于gym.spaces.MultiBinary没法进行one_hot操作,而observation_space属于类型:gym.spaces.Tuple和gym.spaces.Dict本身已经在baselines库中的其他模块被处理,由于在baselines库中已经对gym的env的observation进行了包装处理,所以可以保证在env.step和env.reset后获得的observation一定是np.arrray类型的,也就是说gym.spaces.Dict和gym.spaces.Tuple类型已经被处理了,feed给神经网络的observation只能是:gym.spaces.Box、gym.spaces.Discrete、gym.spaces.MultiDiscrete、gym.spaces.MultiBinary类型。

    查看了一下baselines库的源码对env.observation的处理的代码,该部分代码在common/vec_env/util.py代码中,在该代码中如果observation_space如果是 gym.spaces.Dict和gym.spaces.Tuple则均转为gym.spaces.Dict类型,其中gym.spaces.Tuple转为gym.spaces.Dict时key用从0开始的数字代替,也就是说baselines库中没有对gym.spaces.Dict和gym.spaces.Tuple类型的observation进行过多的处理,也就是说如果原生的env环境为gym.spaces.Dict和gym.spaces.Tuple则传给算法模块的env类型也只能是gym.spaces.Dict类型,这时如果对这样的observation生成的gym.spaces.Dict类型的observation进行placeholder操作则会报错:

     或者说在deepq算法中baselines库只能接收的observation类型只可以为gym.spaces.Discrete、gym.spaces.Box、gym.spaces.MultiDiscrete 。

    ===============================================

  • 相关阅读:
    Python 内置数据结构之 set
    python 在指定的文件夹下生成随机的测验试卷文件
    Python 的 import 语句
    Python 和 R 中的一数多图
    Python 语法练习题
    python 3.x 和 2.x 不同
    R 的 plyr 包
    Python 和 R语言 中的折线图
    设置 Jupyter notebook 默认路径、启动快捷键、打开浏览器
    Python 虚拟环境
  • 原文地址:https://www.cnblogs.com/devilmaycry812839668/p/16080501.html
Copyright © 2020-2023  润新知