baselines算法库baselines/common/input.py模块代码:
import numpy as np import tensorflow as tf from gym.spaces import Discrete, Box, MultiDiscrete def observation_placeholder(ob_space, batch_size=None, name='Ob'): ''' Create placeholder to feed observations into of the size appropriate to the observation space Parameters: ---------- ob_space: gym.Space observation space batch_size: int size of the batch to be fed into input. Can be left None in most cases. name: str name of the placeholder Returns: ------- tensorflow placeholder tensor ''' assert isinstance(ob_space, Discrete) or isinstance(ob_space, Box) or isinstance(ob_space, MultiDiscrete), \ 'Can only deal with Discrete and Box observation spaces for now' dtype = ob_space.dtype if dtype == np.int8: dtype = np.uint8 return tf.placeholder(shape=(batch_size,) + ob_space.shape, dtype=dtype, name=name) def observation_input(ob_space, batch_size=None, name='Ob'): ''' Create placeholder to feed observations into of the size appropriate to the observation space, and add input encoder of the appropriate type. ''' placeholder = observation_placeholder(ob_space, batch_size, name) return placeholder, encode_observation(ob_space, placeholder) def encode_observation(ob_space, placeholder): ''' Encode input in the way that is appropriate to the observation space Parameters: ---------- ob_space: gym.Space observation space placeholder: tf.placeholder observation input placeholder ''' if isinstance(ob_space, Discrete): return tf.to_float(tf.one_hot(placeholder, ob_space.n)) elif isinstance(ob_space, Box): return tf.to_float(placeholder) elif isinstance(ob_space, MultiDiscrete): placeholder = tf.cast(placeholder, tf.int32) one_hots = [tf.to_float(tf.one_hot(placeholder[..., i], ob_space.nvec[i])) for i in range(placeholder.shape[-1])] return tf.concat(one_hots, axis=-1) else: raise NotImplementedError
可以看到input.py模块中一共有三个函数,其中只有一个函数对外提供服务,也就是 observation_input 。
可以看到observation_placeholder函数和encode_observation函数都已经被observation_input函数包装到了一起。
在observation_placeholder函数中根据传入的 env.observation_space变量即可生成对应shape的tf.placeholder变量:
return tf.placeholder(shape=(batch_size,) + ob_space.shape, dtype=dtype, name=name)
也就是说observation_input函数中的placeholder已经是tf.placeholder类型的了。
encoder_observation 函数根据gym.spaces.observation_space的类型对tf.placeholder进行reshape操作,而tf.placeholder已经是TensorFlow的tensor变量,因此这里对tf.placeholder的操作都是在图中的构建操作,属于TensorFlow的操作。
根据encoder_observation中的代码:
我们可以知道对placeholder的reshape操作主要是对gym.space.observation_space属于Discrete和MultiDiscrete类型进行的。
如果传入的gym.space.observation_space为Discrete类型则对其对应的placeholder进行 tf.one_hot 操作,即:
tf.one_hot(placeholder, ob_space.n)
如果传入的gym.space.observation_space为MultiDiscrete类型则对其对应的placeholder中的每个Discrete进行 tf.one_hot 操作然后在concat拼接,即:
one_hots = [tf.to_float(tf.one_hot(placeholder[..., i], ob_space.nvec[i])) for i in range(placeholder.shape[-1])] return tf.concat(one_hots, axis=-1)
其中,如果gym.space.observation_space为Discrete,则其observation空间的大小为env.observation_space.n 。
其中,如果gym.space.observation_space为MultiDiscrete,则其包含的第i个Discrete对应的observation空间的大小为env.observation_space.nvec[i] 。
=======================================
observation_input函数返回的两个变量,其中placeholder是为了给网络feed数据的,而对placeholder进行reshape后的变量是为了方便后面构建神经网络的。
可以知道在gym中gym.space.MultiDiscrete主要是任天堂的游戏使用的,Nintendo Game Controller 。
import gym env_space=gym.spaces.MultiDiscrete([ 5, 2, 2 ]) for i in range(env_space.shape[-1]): print(env_space.nvec[i])
可以看到MultiDiscrete中的每个Discrete的空间大小使用nvec对应的索引查询。
一般,gym的常见observation空间类型有:
gym.space.Box
gym.spaces.Discrete
gym.spaces.MultiDiscrete
gym.spaces.MultiBinary
gym.spaces.Tuple
gym.spaces.Dict
其中:
gym.spaces.Box
gym.spaces.Discrete
gym.spaces.MultiDiscrete
gym.spaces.MultiBinary
对应的observation均为np.array类型,在这里:
gym.spaces.Discrete
gym.spaces.MultiDiscrete
由于状态空间可以进行one_hot处理以便于后面的计算,因此在common/input.py模块中的obervation_input函数对这样的情况进行处理,对原始的placeholder进行one_hot处理。
由于gym.spaces.MultiBinary没法进行one_hot操作,而observation_space属于类型:gym.spaces.Tuple和gym.spaces.Dict本身已经在baselines库中的其他模块被处理,由于在baselines库中已经对gym的env的observation进行了包装处理,所以可以保证在env.step和env.reset后获得的observation一定是np.arrray类型的,也就是说gym.spaces.Dict和gym.spaces.Tuple类型已经被处理了,feed给神经网络的observation只能是:gym.spaces.Box、gym.spaces.Discrete、gym.spaces.MultiDiscrete、gym.spaces.MultiBinary类型。
查看了一下baselines库的源码对env.observation的处理的代码,该部分代码在common/vec_env/util.py代码中,在该代码中如果observation_space如果是 gym.spaces.Dict和gym.spaces.Tuple则均转为gym.spaces.Dict类型,其中gym.spaces.Tuple转为gym.spaces.Dict时key用从0开始的数字代替,也就是说baselines库中没有对gym.spaces.Dict和gym.spaces.Tuple类型的observation进行过多的处理,也就是说如果原生的env环境为gym.spaces.Dict和gym.spaces.Tuple则传给算法模块的env类型也只能是gym.spaces.Dict类型,这时如果对这样的observation生成的gym.spaces.Dict类型的observation进行placeholder操作则会报错:
或者说在deepq算法中baselines库只能接收的observation类型只可以为gym.spaces.Discrete、gym.spaces.Box、gym.spaces.MultiDiscrete 。
===============================================