Notes on the class FilterObservation(ObservationWrapper) in the filter_observation.py module.
Code:
import copy
from gym import spaces
from gym import ObservationWrapper


class FilterObservation(ObservationWrapper):
    """Filter dictionary observations by their keys.

    Args:
        env: The environment to wrap.
        filter_keys: List of keys to be included in the observations.

    Raises:
        ValueError: If observation keys in not instance of None or
            iterable.
        ValueError: If any of the `filter_keys` are not included in
            the original `env`'s observation space
    """
    def __init__(self, env, filter_keys=None):
        super(FilterObservation, self).__init__(env)

        wrapped_observation_space = env.observation_space
        assert isinstance(wrapped_observation_space, spaces.Dict), (
            "FilterObservationWrapper is only usable with dict observations.")

        observation_keys = wrapped_observation_space.spaces.keys()

        if filter_keys is None:
            filter_keys = tuple(observation_keys)

        missing_keys = set(
            key for key in filter_keys if key not in observation_keys)

        if missing_keys:
            raise ValueError(
                "All the filter_keys must be included in the "
                "original obsrevation space.\n"
                "Filter keys: {filter_keys}\n"
                "Observation keys: {observation_keys}\n"
                "Missing keys: {missing_keys}".format(
                    filter_keys=filter_keys,
                    observation_keys=observation_keys,
                    missing_keys=missing_keys,
                ))

        self.observation_space = type(wrapped_observation_space)([
            (name, copy.deepcopy(space))
            for name, space in wrapped_observation_space.spaces.items()
            if name in filter_keys
        ])

        self._env = env
        self._filter_keys = tuple(filter_keys)

    def observation(self, observation):
        filter_observation = self._filter_observation(observation)
        return filter_observation

    def _filter_observation(self, observation):
        observation = type(observation)([
            (name, value)
            for name, value in observation.items()
            if name in self._filter_keys
        ])
        return observation
One precondition of this class is that the observation space of the wrapped env must be an instance of spaces.Dict, as enforced here:
assert isinstance(wrapped_observation_space, spaces.Dict), (
    "FilterObservationWrapper is only usable with dict observations.")
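In short, given an env whose observation space is a spaces.Dict, the class keeps the entries whose keys appear in filter_keys and discards all the others. If filter_keys is None, every key is kept.
The full set of keys of the wrapped env is env.observation_space.spaces.keys().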
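The key selection that builds the new observation space can be sketched without gym. This is only an illustration under assumptions: a plain OrderedDict stands in for wrapped_observation_space.spaces, the helper name select_spaces and the keys "position", "velocity" and "image" are hypothetical, and the sub-spaces are represented by simple shape lists rather than real gym spaces.

```python
import copy
from collections import OrderedDict


def select_spaces(wrapped_spaces, filter_keys=None):
    """Keep only the sub-spaces named in filter_keys (all of them if None)."""
    if filter_keys is None:
        # Same default as the wrapper: keep every key of the wrapped env.
        filter_keys = tuple(wrapped_spaces.keys())
    # Deep-copy the kept entries so the result does not share
    # objects with the wrapped env's own space description.
    return OrderedDict(
        (name, copy.deepcopy(space))
        for name, space in wrapped_spaces.items()
        if name in filter_keys
    )


# Hypothetical sub-spaces, described here only by their shapes.
wrapped_spaces = OrderedDict([
    ("position", [3]),
    ("velocity", [3]),
    ("image", [64, 64, 3]),
])

kept = select_spaces(wrapped_spaces, filter_keys=("position", "velocity"))
everything = select_spaces(wrapped_spaces)
```

Here kept holds only the "position" and "velocity" entries, while everything holds all three because no filter was given. The deep copy mirrors the copy.deepcopy(space) call in the wrapper's constructor.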
If any key that should be kept does not exist in the wrapped env, it is collected into missing_keys:
missing_keys = set(
    key for key in filter_keys if key not in observation_keys)
and a ValueError is raised:
if missing_keys:
    raise ValueError(
        "All the filter_keys must be included in the "
        "original obsrevation space.\n"
        "Filter keys: {filter_keys}\n"
        "Observation keys: {observation_keys}\n"
        "Missing keys: {missing_keys}".format(
            filter_keys=filter_keys,
            observation_keys=observation_keys,
            missing_keys=missing_keys,
        ))
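The missing-key check can be tried in isolation. The sketch below is a simplified standalone version, not the wrapper itself: the function names find_missing_keys and validate_filter_keys, the shortened error message, and the keys "position", "velocity" and "goal" are all assumptions made for illustration.

```python
def find_missing_keys(filter_keys, observation_keys):
    """Return the requested keys that the wrapped env does not provide."""
    return set(key for key in filter_keys if key not in observation_keys)


def validate_filter_keys(filter_keys, observation_keys):
    """Raise ValueError if any requested key is absent (simplified message)."""
    missing_keys = find_missing_keys(filter_keys, observation_keys)
    if missing_keys:
        raise ValueError(
            "Missing keys: {missing_keys}".format(missing_keys=missing_keys))


# Hypothetical keys offered by the wrapped env.
observation_keys = ("position", "velocity")

ok = find_missing_keys(("position",), observation_keys)
missing = find_missing_keys(("position", "goal"), observation_keys)

message = None
try:
    validate_filter_keys(("position", "goal"), observation_keys)
except ValueError as err:
    # "goal" is not an observation key, so validation fails.
    message = str(err)
```

Because the check runs in the constructor, a misspelled key is reported when the wrapper is created rather than later, when the first observation is filtered.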
Each time an observation is obtained from the wrapped env, it is filtered by the retained keys, and the filtered observation is passed up to the caller:
def observation(self, observation):
    filter_observation = self._filter_observation(observation)
    return filter_observation

def _filter_observation(self, observation):
    observation = type(observation)([
        (name, value)
        for name, value in observation.items()
        if name in self._filter_keys
    ])
    return observation
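The type(observation)([...]) idiom rebuilds the result with the same mapping type as the input. A small standalone sketch of that behavior follows, assuming a free function named filter_observation and made-up observation values; neither is part of gym.

```python
from collections import OrderedDict


def filter_observation(observation, filter_keys):
    """Rebuild the observation with the same mapping type, keeping only filter_keys."""
    return type(observation)([
        (name, value)
        for name, value in observation.items()
        if name in filter_keys
    ])


# Hypothetical raw observation from a wrapped env.
raw = OrderedDict([
    ("position", 1.0),
    ("velocity", 0.5),
    ("image", "pixels"),
])

# The filter keys are given in a different order than the observation's keys.
filtered = filter_observation(raw, ("velocity", "position"))

# A plain dict input yields a plain dict output.
plain = filter_observation({"a": 1, "b": 2}, ("b",))
```

Two details are worth noting: the result keeps the input's type (an OrderedDict stays an OrderedDict), and the entries follow the order of the original observation rather than the order of the filter keys, since the loop iterates over observation.items().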
=========================================