• Understanding the FilterObservation(ObservationWrapper) class in the gym library


    Notes on the FilterObservation(ObservationWrapper) class defined in the filter_observation.py module.

    Code:

    import copy
    
    from gym import spaces
    from gym import ObservationWrapper
    
    
    class FilterObservation(ObservationWrapper):
        """Filter dictionary observations by their keys.
        
        Args:
            env: The environment to wrap.
            filter_keys: List of keys to be included in the observations.
    
        Raises:
            ValueError: If observation keys in not instance of None or
                iterable.
            ValueError: If any of the `filter_keys` are not included in
                the original `env`'s observation space
        
        """
        def __init__(self, env, filter_keys=None):
            super(FilterObservation, self).__init__(env)
    
            wrapped_observation_space = env.observation_space
            assert isinstance(wrapped_observation_space, spaces.Dict), (
                "FilterObservationWrapper is only usable with dict observations.")
    
            observation_keys = wrapped_observation_space.spaces.keys()
    
            if filter_keys is None:
                filter_keys = tuple(observation_keys)
    
            missing_keys = set(
                key for key in filter_keys if key not in observation_keys)
    
            if missing_keys:
                raise ValueError(
                    "All the filter_keys must be included in the "
                    "original obsrevation space.\n"
                    "Filter keys: {filter_keys}\n"
                    "Observation keys: {observation_keys}\n"
                    "Missing keys: {missing_keys}".format(
                        filter_keys=filter_keys,
                        observation_keys=observation_keys,
                        missing_keys=missing_keys,
                    ))
    
            self.observation_space = type(wrapped_observation_space)([
                (name, copy.deepcopy(space))
                for name, space in wrapped_observation_space.spaces.items()
                if name in filter_keys
            ])
    
            self._env = env
            self._filter_keys = tuple(filter_keys)
    
        def observation(self, observation):
            filter_observation = self._filter_observation(observation)
            return filter_observation
    
        def _filter_observation(self, observation):
            observation = type(observation)([
                (name, value)
                for name, value in observation.items()
                if name in self._filter_keys
            ])
            return observation

    A precondition of this class is that the observation space of the wrapped env must be an instance of spaces.Dict, as enforced here:

            assert isinstance(wrapped_observation_space, spaces.Dict), (
                "FilterObservationWrapper is only usable with dict observations.")

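    The purpose of the class is to take an env whose observation space is a spaces.Dict, keep only the entries whose keys appear in filter_keys, and drop all the others. If filter_keys is None, every key is kept.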

    The full set of keys of the wrapped env is given by env.observation_space.spaces.keys().
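
    As a rough illustration of how the filtered observation space is assembled from those keys, the sketch below mirrors the constructor logic with a plain OrderedDict standing in for spaces.Dict and lists standing in for its sub-spaces. The helper name build_filtered_space and the sample keys are hypothetical and are not part of gym.

```python
import copy
from collections import OrderedDict


def build_filtered_space(wrapped_spaces, filter_keys=None):
    """Keep only the sub-spaces named in filter_keys (stand-in for the constructor logic)."""
    observation_keys = wrapped_spaces.keys()
    # As in the wrapper, filter_keys=None means "keep every key".
    if filter_keys is None:
        filter_keys = tuple(observation_keys)
    # Rebuild a container of the same type, deep-copying each kept sub-space
    # so the result does not share objects with the wrapped space.
    return type(wrapped_spaces)([
        (name, copy.deepcopy(space))
        for name, space in wrapped_spaces.items()
        if name in filter_keys
    ])


# Hypothetical sub-spaces, represented as lists purely for illustration.
wrapped = OrderedDict([
    ("position", [0.0, 1.0]),
    ("velocity", [-1.0, 1.0]),
    ("image", [0, 255]),
])

subset = build_filtered_space(wrapped, filter_keys=["position", "image"])
everything = build_filtered_space(wrapped)
```

    In this sketch the order of the result follows the wrapped mapping rather than the order of filter_keys, because the comprehension iterates over the wrapped items.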

    Any key requested in filter_keys that does not exist in the wrapped env's observation space is collected into missing_keys:

            missing_keys = set(
                key for key in filter_keys if key not in observation_keys)

    If that set is not empty, a ValueError is raised:

            if missing_keys:
                raise ValueError(
                    "All the filter_keys must be included in the "
                    "original obsrevation space.\n"
                    "Filter keys: {filter_keys}\n"
                    "Observation keys: {observation_keys}\n"
                    "Missing keys: {missing_keys}".format(
                        filter_keys=filter_keys,
                        observation_keys=observation_keys,
                        missing_keys=missing_keys,
                    ))
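
    The check can be tried in isolation. The following is a minimal sketch that reproduces the same set comprehension outside of gym. The function names find_missing_keys and check_filter_keys, the shortened error message, and the sample keys are assumptions made for this example.

```python
def find_missing_keys(filter_keys, observation_keys):
    """Return the requested keys that are absent from the observation keys."""
    return set(key for key in filter_keys if key not in observation_keys)


def check_filter_keys(filter_keys, observation_keys):
    """Raise ValueError if any requested key is missing, as the wrapper does."""
    missing_keys = find_missing_keys(filter_keys, observation_keys)
    if missing_keys:
        raise ValueError(
            "Filter keys: {}\nObservation keys: {}\nMissing keys: {}".format(
                filter_keys, list(observation_keys), missing_keys))


# Hypothetical observation keys, taken from a plain dict for illustration.
observation_keys = {"position": None, "velocity": None}.keys()

# Every requested key exists, so this call returns without raising.
check_filter_keys(["position"], observation_keys)

# "goal" is not an observation key, so a ValueError is raised and captured.
try:
    check_filter_keys(["position", "goal"], observation_keys)
    error_message = None
except ValueError as exc:
    error_message = str(exc)
```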

    Every time an observation is obtained from the wrapped env, it is filtered by these keys and the filtered observation is passed up to the caller:

        def observation(self, observation):
            filter_observation = self._filter_observation(observation)
            return filter_observation
    
        def _filter_observation(self, observation):
            observation = type(observation)([
                (name, value)
                for name, value in observation.items()
                if name in self._filter_keys
            ])
            return observation
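
    One detail worth noting is that type(observation)(...) rebuilds the result with the same mapping type as the input. The sketch below applies the same expression to a plain dict and to an OrderedDict. The standalone function filter_observation and the sample values are hypothetical, written only to show this behavior without importing gym.

```python
from collections import OrderedDict


def filter_observation(observation, filter_keys):
    """Return a new mapping of the same type holding only the filtered keys."""
    return type(observation)([
        (name, value)
        for name, value in observation.items()
        if name in filter_keys
    ])


filter_keys = ("position",)

# A plain dict comes back as a plain dict.
plain = filter_observation({"position": 1.5, "velocity": -0.2}, filter_keys)

# An OrderedDict comes back as an OrderedDict.
ordered = filter_observation(
    OrderedDict([("velocity", -0.2), ("position", 1.5)]), filter_keys)
```

    Only the mapping itself is rebuilt here. The values are passed through by reference and are not copied.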

    =========================================

  • Original article: https://www.cnblogs.com/devilmaycry812839668/p/16035654.html