• gym库中from gym.wrappers import FlattenObservation的理解


    看代码的过程中看到有这样的调用:

    from gym.wrappers import FlattenObservation

    if sinstance(env.observation_space, gym.spaces.Dict):
         env = FlattenObservation(env)

    不是很理解这个代码的意思。

    ===============================================

    查看gym源码中类:

    FlattenObservation(ObservationWrapper)

    import numpy as np
    import gym.spaces as spaces
    from gym import ObservationWrapper
    
    
    class FlattenObservation(ObservationWrapper):
        r"""Observation wrapper that flattens the observation."""
        def __init__(self, env):
            super(FlattenObservation, self).__init__(env)
    
            flatdim = spaces.flatdim(env.observation_space)
            self.observation_space = spaces.Box(low=-float('inf'), high=float('inf'), shape=(flatdim,), dtype=np.float32)
    
        def observation(self, observation):
            return spaces.flatten(self.env.observation_space, observation)

    从gym的状态空间的转换可以看出这个类是要将observation的状态空间进行flatten操作。

    具体的flatten操作调用:

    spaces.flatten(self.env.observation_space, observation)

    查看spaces.flatten源代码:

    def flatten(space, x):
        if isinstance(space, Box):
            return np.asarray(x, dtype=np.float32).flatten()
        elif isinstance(space, Discrete):
            onehot = np.zeros(space.n, dtype=np.float32)
            onehot[x] = 1.0
            return onehot
        elif isinstance(space, Tuple):
            return np.concatenate([flatten(s, x_part) for x_part, s in zip(x, space.spaces)])
        elif isinstance(space, Dict):
            return np.concatenate([flatten(s, x[key]) for key, s in space.spaces.items()])
        elif isinstance(space, MultiBinary):
            return np.asarray(x).flatten()
        elif isinstance(space, MultiDiscrete):
            return np.asarray(x).flatten()
        else:
            raise NotImplementedError

    可以知道如果 env.observation_space属于Box类型,则直接调用np.array的flatten操作。

    如果 env.observation_space属于Discrete类型,则直接进行onehot编码的方法进行flatten操作。

    env.observation_space如果属于多个Box类型或Discrete类型组合而成的,也就是属于Tuple, Dict, 那么需要将其中的每个类型的状态空间都进行flatten操作后在进行拼接操作。

    即:(取出组合空间中的各个子状态空间迭代调用flatten操作从而实现对组合中的各个子observation_space进行flatten)

        elif isinstance(space, Tuple):
            return np.concatenate([flatten(s, x_part) for x_part, s in zip(x, space.spaces)])
        elif isinstance(space, Dict):
            return np.concatenate([flatten(s, x[key]) for key, s in space.spaces.items()])

    MultiBinary, MultiDiscrete类型直接转为np.array类型的数据再进行flatten操作。

    ===================================================

  • 相关阅读:
    动态加载方法(定时任务)
    安装 asp.net core 出错
    .NET:权限管理
    关于随机数
    博客园首弹
    C# MVC从其他系统获取文件流,显示文件
    Python中操作MySQL步骤
    MySql之_增删改查
    数据库之_SQL注入
    为什么上传到youtube上的视频很模糊
  • 原文地址:https://www.cnblogs.com/devilmaycry812839668/p/16035111.html
Copyright © 2020-2023  润新知