“他山之石,可以攻玉”,站在巨人的肩膀才能看得更高,走得更远。在科研的道路上,更需借助东风才能更快前行。为此,我们特别搜集整理了一些实用的代码链接,数据集,软件,编程技巧等,开辟“他山之石”专栏,助你乘风破浪,一路奋勇向前,敬请关注。

转自:https://www.meltycriss.com/2018/03/26/tech-gym/

本文首先介绍Gym的核心函数调用链,然后介绍如何创建自定义的Gym环境,最后给出一些使用Gym过程中碰到的问题及其解决方案


01

Gym核心函数调用链
一般来说,使用Gym的代码如下:
# main.pyimport gymdef choose_action(o):  ...env = gym.make('CartPole-v0')o = env.reset()while True:  a = choose_action(o)  o_, r, done, info = env.step(a)  o = o_  if done:    break
可见,关键的函数有:
  • env = gym.make('CartPole-v0')

  • env.reset()

  • env.step(a)

我们先关注env.reset()和env.step(a)。这两个函数是超类Env的成员函数,Env的相关代码如下:
# gym/core.pyclass Env(object):  ...    # Override in ALL subclasses    def _step(self, action): raise NotImplementedError    def _reset(self): raise NotImplementedError    ...    def step(self, action):        return self._step(action)    def reset(self):        return self._reset()    ...
可以看到这两个函数依赖于子类的_reset(self)和_step(self, action)实现,子类CartPoleEnv的相关代码如下:
# gym/envs/classic_control/CartPole.pyclass CartPoleEnv(gym.Env):  ...    def _step(self, action):        ...    def _reset(self):        ...    ...
综上,env.reset()和env.step(a)实际上是调用子类的_reset(self)和_step(self, action)。
下面我们关注gym.make('CartPole-v0'),它的实现如下:
# gym/envs/registration.py# Have a global registryregistry = EnvRegistry()...def make(id):    return registry.make(id)
可以看到gym.make依赖于类EnvRegistry的成员函数make,EnvRegistry的相关代码如下:
# gym/envs/registration.pyclass EnvRegistry(object):    def __init__(self):      # 注册表      # key:  环境名称(e.g., 'CartPole-v0')      # value:类型为EnvSpec,可以暂时理解为环境        self.env_specs = {}    def make(self, id):        ...        # 根据环境名称,通过成员函数找到对应的环境        spec = self.spec(id)        # 实例化环境        env = spec.make()        ...        return env    ...    def spec(self, id):        ...    ...
可见类EnvRegistry的成员函数make依赖于类EnvSpec的成员函数make,EnvSpec的相关代码如下:
# gym/envs/registration.pydef load(name):    ...# EnvSpec与Env之间的关系类似于说明商品规格的订单与商品之间的关系,# 下面用一个例子来说明:# 假设你网购看中了一款衣服,那么你会挑选该款衣服的颜色、码数,然后再下单。# 在这个例子里面,那款衣服就是Env,而说明该款衣服颜色、码数的订单就是EnvSpec。# 这就是为什么EnvRegistry.make(self, id)中,在得到spec之后还要再spec.make(),# 因为EnvSpec并不是Env,正如订单不是衣服。class EnvSpec(object):    def __init__(self, id, entry_point=None, ...):      self.id = id        ...        self._entry_point = entry_point        ...    def make(self):        ...        # 动态加载环境类        # 相当于以下代码        # from self._entry_point import classA        # cls = classA        cls = load(self._entry_point)        # 实例化环境        env = cls(**self._kwargs)        ...        return env    ...
至此,我们对Gym的核心函数调用链有了一个基本的了解:
  • gym.make(id):通过EnvRegistry中的注册表找到对应的EnvSpec,EnvSpec根据entry_point动态import对应的Env,并将其实例化;
  • env.reset()和env.step(a):子类的_reset(self)和_step(self, action)。


02

创建自定义环境
对Gym的核心函数调用链有了基本了解后,我们知道创建自定义环境的关键有两个:
  • 第一个是搭建自己的Env子类FooEnv;
  • 第二个是注册FooEnv(i.e., 将FooEnv添加到registry.env_specs中),使得gym.make(id)可以找到FooEnv。
官方文档推荐的自定义环境目录结构如下:
gym-foo/  README.md  setup.py       #将gym_foo这个package加到系统环境变量中  gym_foo/       #核心部分    __init__.py   #注册FooEnv    envs/      __init__.py      foo_env.py   #实现FooEnv
实现FooEnv没什么特别的,就是根据自己的需求,实现_step(self, action)、_reset(self)等函数。
值得一提的是注册FooEnv,我们无需自己实现注册环境的代码,因为Gym已经有现成的注册环境API,我们只需要调用该API即可。在我们的自定义环境中,负责注册FooEnv的文件为gym-foo/gym_foo/__init__.py,它的内容如下:
# gym-foo/gym_foo/__init__.pyfrom gym.envs.registration import registerregister(    id='foo-v0',   # 环境名    entry_point='gym_foo.envs:FooEnv',   # 环境类,之后就根据这个路径动态import环境)
可见,注册的关键是register函数,而register函数的实现如下:
# gym/envs/registration.py# Have a global registryregistry = EnvRegistry()# Gym的注册环境APIdef register(id, **kwargs):    return registry.register(id, **kwargs)def make(id):    return registry.make(id)
可以看到register的实现依赖于类EnvRegistry的成员函数register,其相关代码如下:
# gym/envs/registration.pyclass EnvRegistry(object):  ...    def register(self, id, **kwargs):        ...        # 将FooEnv对应的“订单”写到“注册表”上        self.env_specs[id] = EnvSpec(id, **kwargs)
综上,我们可以通过API函数register注册自定义的环境FooEnv。

03

注意事项

3.1 server render

假如你通过ssh连接server,在server上运行(i.e., python main.py)以下代码(关键点在使用env.render()保存录像):
# main.pyimport gymfrom gym import wrappersenv = gym.make('CartPole-v0')env = wrappers.Monitor(env, 'video')for i_episode in range(20):    observation = env.reset()    for t in range(100):        env.render()        action = env.action_space.sample()        observation, reward, done, info = env.step(action)        if done:            break
那么你会得到一个报错,报错的信息大概是pyglet.canvas.xlib.NoSuchDisplayException: Cannot connect to "None"。
原因大概是env.render()需要图形界面(就是弹出来的那个框框),而当你使用ssh连接server时是没有图形界面的。因此我们需要一个虚拟的图形界面,而xvfb-run就是一个提供虚拟图形界面的工具。
所以我们需要使用xvfb-run -a -s "-screen 0 1400x900x24 +extension RANDR" -- python main.py来运行我们的代码。
一般来说,运行上述指令是会报错的,报错的信息大概是pyglet requires an X server with GLX,主要原因在于显卡驱动以及cuda的安装有问题,没有加--no-opengl的flag。解决方案可以参考这里和这里

3.2 保存每一段episode的录像

wrappers.Monitor默认不会保存所有episode的录像,但我们可以通过以下代码来设置保存所有episode的录像:
env = wrappers.Monitor(env, 'video', video_callable=lambda episode_id: True)

3.3 动态修改episode的最大step

env._max_episode_steps = xxx。注意,这仅当env的类型为TimeLimit时可用。

3.4 关于wrapper

  • 相同的两个wrapper不能叠加(e.g., Monitor不可以和Monitor叠加,但是Monitor可以和TimeLimit叠加),否则会报double wrapper的错。
  • 在注册FooEnv时,加不加max_episode_steps=xxx会影响返回的Env的类型。假如加了,返回的是TimeLimit类型的wrapper;假如不加,返回的就是裸的FooEnv。
  • Monitor里面有两个recorder,一个是stat_recorder,用于保存数据(reward之类的);另一个是video_recorder,用于录像。Monitor会在每一次调用env.reset和env.step之后调用render

3.5 屏蔽log信息

# main.pyimport logging# suppress INFO level logging 'Making new env: ...'logging.getLogger('gym.envs.registration').setLevel(logging.WARNING)# suppress INFO level logging 'Starting new video recorder writing to ...'logging.getLogger('gym.monitoring.video_recorder').setLevel(logging.WARNING)# suppress INFO level logging 'Creating monitor directory ...'logging.getLogger('gym.wrappers.monitoring').setLevel(logging.WARNING)



他山之历史文章


更多他山之石专栏文章,

请点击文章底部“阅读原文”查看


Image
Image


分享、点赞、在看,给个三连击呗!