Here is a gym detail I just learned, a pitfall when using gym.ObservationWrapper: overriding the reset and step functions can bypass the observation function.
The code:
```python
# Written against the older gym API (before 0.26):
# reset() returns obs, step() returns (obs, reward, done, info).
import gym


class Wrapper(gym.ObservationWrapper):
    def __init__(self, env):
        super(Wrapper, self).__init__(env)

    def reset(self):
        obs = self.env.reset()
        print(obs)
        return obs  # raw obs: observation() is NOT applied
        # return self.observation(obs)

    def step(self, action):
        obs, reward, is_done, info = self.env.step(action)
        print(obs)
        return obs, reward, is_done, info  # raw obs again
        # return self.observation(obs), reward, is_done, info

    def observation(self, observation):
        observation += 100  # in-place add on the NumPy array
        return observation


env = gym.make("CartPole-v0")
env = Wrapper(env)
print("reset:")
print(env.reset())
print("step:")
print(env.step(0)[0])
```
Output:
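As you can see, once you subclass gym.ObservationWrapper and override reset or step, the observation that the overridden function returns is no longer processed by the class's observation function.
If we comment out the reset and step overrides and run the script again, the result is: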
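To see the mechanism without installing gym, here is a minimal sketch built on hypothetical stand-ins: `DummyEnv`, `MiniObservationWrapper`, `OnlyObservation` and `OverridesStep` are illustrative names, not gym APIs. It assumes, as in the older gym versions this post uses, that the base class's reset and step are the only places where observation() gets called. A subclass that overrides step without calling observation() therefore skips the transform, while its inherited reset still applies it.

```python
class DummyEnv:
    """Hypothetical stand-in env: every observation is the list [0, 0]."""

    def reset(self):
        return [0, 0]

    def step(self, action):
        return [0, 0], 1.0, False, {}


class MiniObservationWrapper:
    """Hypothetical, simplified stand-in for the old (4-tuple) gym.ObservationWrapper."""

    def __init__(self, env):
        self.env = env

    def reset(self):
        # The base class is what routes the observation through observation().
        return self.observation(self.env.reset())

    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        return self.observation(obs), reward, done, info

    def observation(self, observation):
        raise NotImplementedError


class OnlyObservation(MiniObservationWrapper):
    # Inherits reset/step, so observation() is applied automatically.
    def observation(self, observation):
        return [x + 100 for x in observation]


class OverridesStep(OnlyObservation):
    # Overrides step without calling observation(): the +100 is silently skipped.
    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        return obs, reward, done, info


print(OnlyObservation(DummyEnv()).step(0)[0])  # [100, 100]
print(OverridesStep(DummyEnv()).step(0)[0])    # [0, 0]
print(OverridesStep(DummyEnv()).reset())       # [100, 100], reset is still inherited
```

Only the method you override loses the transform. That is plain Python method overriding, not special gym behavior.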
=======================================
So if you subclass gym.ObservationWrapper and override reset or step, but still want the returned observation to go through the class's observation function, you can change the code as follows:
```python
# Written against the older gym API (before 0.26):
# reset() returns obs, step() returns (obs, reward, done, info).
import gym


class Wrapper(gym.ObservationWrapper):
    def __init__(self, env):
        super(Wrapper, self).__init__(env)

    def reset(self):
        obs = self.env.reset()
        print(obs)  # printed before the transform
        # return obs
        return self.observation(obs)  # apply observation() explicitly

    def step(self, action):
        obs, reward, is_done, info = self.env.step(action)
        print(obs)  # printed before the transform
        # return obs, reward, is_done, info
        return self.observation(obs), reward, is_done, info

    def observation(self, observation):
        observation += 100  # in-place add on the NumPy array
        return observation


env = gym.make("CartPole-v0")
env = Wrapper(env)
print("reset:")
print(env.reset())
print("step:")
print(env.step(0)[0])
```
Output:
---------------------------------------------------------------------------
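Another option is to call the base class from inside the override, for example `super().step(action)`. This reuses the base class's call to observation() instead of repeating it. The sketch below again uses hypothetical stand-ins (`DummyEnv`, `MiniObservationWrapper`, `LoggingWrapper` are illustrative, not gym APIs). It assumes the base class's reset/step apply observation(), as in the older gym versions used here.

```python
class DummyEnv:
    """Hypothetical stand-in env: every observation is the list [0, 0]."""

    def reset(self):
        return [0, 0]

    def step(self, action):
        return [0, 0], 1.0, False, {}


class MiniObservationWrapper:
    """Hypothetical, simplified stand-in for the old (4-tuple) gym.ObservationWrapper."""

    def __init__(self, env):
        self.env = env

    def reset(self):
        return self.observation(self.env.reset())

    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        return self.observation(obs), reward, done, info

    def observation(self, observation):
        raise NotImplementedError


class LoggingWrapper(MiniObservationWrapper):
    def reset(self):
        obs = super().reset()  # base class already applied observation()
        print("reset ->", obs)
        return obs

    def step(self, action):
        obs, reward, done, info = super().step(action)  # obs is already transformed
        print("step ->", obs)
        return obs, reward, done, info

    def observation(self, observation):
        return [x + 100 for x in observation]


w = LoggingWrapper(DummyEnv())
w.reset()   # prints: reset -> [100, 100]
w.step(0)   # prints: step -> [100, 100]
```

This approach keeps the transform in a single place. The trade-off is that the override now sees the already-transformed observation, whereas the version above prints the raw one.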
To understand how reset and step behave in a gym.ObservationWrapper subclass, look at the implementation of the ObservationWrapper class:
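In the ObservationWrapper class, it is the base class's own reset and step that call observation(). Starting with gym 0.26 (and in Gymnasium) the signatures changed. reset accepts keyword-only `seed`/`options` and returns `(obs, info)`, and step returns five values `(obs, reward, terminated, truncated, info)`, so the 4-tuple code above will not run unchanged there. The dispatch rule is the same, though. Below is a simplified stand-in written against that newer signature. It is a sketch, not the library's source, and `DummyEnv`, `MiniObservationWrapper` and `Plus100` are hypothetical names.

```python
class DummyEnv:
    """Hypothetical stand-in env using the newer reset/step signatures."""

    def reset(self, *, seed=None, options=None):
        return [0, 0], {}

    def step(self, action):
        return [0, 0], 1.0, False, False, {}


class MiniObservationWrapper:
    """Simplified stand-in for the newer (5-tuple) ObservationWrapper; not library code."""

    def __init__(self, env):
        self.env = env

    def reset(self, *, seed=None, options=None):
        obs, info = self.env.reset(seed=seed, options=options)
        return self.observation(obs), info

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        return self.observation(obs), reward, terminated, truncated, info

    def observation(self, observation):
        raise NotImplementedError


class Plus100(MiniObservationWrapper):
    def observation(self, observation):
        return [x + 100 for x in observation]


env = Plus100(DummyEnv())
obs, info = env.reset(seed=0)
print(obs)  # [100, 100]
obs, reward, terminated, truncated, info = env.step(0)
print(obs, terminated, truncated)  # [100, 100] False False
```

Overriding reset or step in a subclass of this newer version bypasses observation() in exactly the same way, so the same fix applies.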