看代码的过程中看到有这样的调用:
from gym.wrappers import FlattenObservation
if sinstance(env.observation_space, gym.spaces.Dict):
env = FlattenObservation(env)
不是很理解这个代码的意思。
===============================================
查看gym源码中类:
FlattenObservation(ObservationWrapper)
import numpy as np import gym.spaces as spaces from gym import ObservationWrapper class FlattenObservation(ObservationWrapper): r"""Observation wrapper that flattens the observation.""" def __init__(self, env): super(FlattenObservation, self).__init__(env) flatdim = spaces.flatdim(env.observation_space) self.observation_space = spaces.Box(low=-float('inf'), high=float('inf'), shape=(flatdim,), dtype=np.float32) def observation(self, observation): return spaces.flatten(self.env.observation_space, observation)
从gym的状态空间的转换可以看出这个类是要将observation的状态空间进行flatten操作。
具体的flatten操作调用:
spaces.flatten(self.env.observation_space, observation)
查看spaces.flatten源代码:
def flatten(space, x): if isinstance(space, Box): return np.asarray(x, dtype=np.float32).flatten() elif isinstance(space, Discrete): onehot = np.zeros(space.n, dtype=np.float32) onehot[x] = 1.0 return onehot elif isinstance(space, Tuple): return np.concatenate([flatten(s, x_part) for x_part, s in zip(x, space.spaces)]) elif isinstance(space, Dict): return np.concatenate([flatten(s, x[key]) for key, s in space.spaces.items()]) elif isinstance(space, MultiBinary): return np.asarray(x).flatten() elif isinstance(space, MultiDiscrete): return np.asarray(x).flatten() else: raise NotImplementedError
可以知道如果 env.observation_space属于Box类型,则直接调用np.array的flatten操作。
如果 env.observation_space属于Discrete类型,则直接进行onehot编码的方法进行flatten操作。
env.observation_space如果属于多个Box类型或Discrete类型组合而成的,也就是属于Tuple, Dict, 那么需要将其中的每个类型的状态空间都进行flatten操作后在进行拼接操作。
即:(取出组合空间中的各个子状态空间迭代调用flatten操作从而实现对组合中的各个子observation_space进行flatten)
elif isinstance(space, Tuple): return np.concatenate([flatten(s, x_part) for x_part, s in zip(x, space.spaces)]) elif isinstance(space, Dict): return np.concatenate([flatten(s, x[key]) for key, s in space.spaces.items()])
MultiBinary, MultiDiscrete类型直接转为np.array类型的数据再进行flatten操作。
===================================================