forked from datamllab/rlcard
-
Notifications
You must be signed in to change notification settings - Fork 1
/
registration.py
89 lines (71 loc) · 2.62 KB
/
registration.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import importlib
# Baseline settings for every environment. The module-level make() copies
# this dict and overlays any caller-supplied keys on top of it.
DEFAULT_CONFIG = dict(
    allow_step_back=False,
    seed=None,
)
class EnvSpec(object):
    ''' Specification of a registered environment: its id and the callable
        (resolved from an entry-point string) used to build instances of it.
    '''

    def __init__(self, env_id, entry_point=None):
        ''' Initialize the spec and resolve its entry point immediately.

        Args:
            env_id (string): The name of the environment
            entry_point (string): A string of the form 'module.path:ClassName'
                that indicates the location of the environment class. The
                module is imported as part of construction.

        Raises:
            ValueError: If entry_point is missing or is not of the form
                'module:ClassName' (exactly one colon, both sides non-empty).
        '''
        self.env_id = env_id
        if not isinstance(entry_point, str):
            # The None default would otherwise fail later with an opaque
            # AttributeError on .split().
            raise ValueError(
                'Invalid entry_point for env_id {}: {!r} '
                '(expected "module:ClassName")'.format(env_id, entry_point))
        mod_name, sep, class_name = entry_point.partition(':')
        if not (sep and mod_name and class_name) or ':' in class_name:
            raise ValueError(
                'Invalid entry_point for env_id {}: {!r} '
                '(expected "module:ClassName")'.format(env_id, entry_point))
        self._entry_point = getattr(importlib.import_module(mod_name), class_name)

    def make(self, config=None):
        ''' Instantiate an instance of the environment.

        Args:
            config (dict): A dictionary of the environment settings. When
                omitted, a fresh copy of DEFAULT_CONFIG is used.

        Returns:
            env (Env): An instance of the environment
        '''
        if config is None:
            # Copy rather than share the module-level dict, so an environment
            # that mutates its config cannot alter the global defaults.
            config = DEFAULT_CONFIG.copy()
        return self._entry_point(config)
class EnvRegistry(object):
    ''' Register an environment (game) by ID
    '''

    def __init__(self):
        ''' Initialize an empty registry.
        '''
        # Maps env_id -> EnvSpec
        self.env_specs = {}

    def register(self, env_id, entry_point):
        ''' Register an environment

        Args:
            env_id (string): The name of the environment
            entry_point (string): A string that indicates the location of the
                environment class, e.g. 'package.module:ClassName'

        Raises:
            ValueError: If env_id is already registered.
        '''
        if env_id in self.env_specs:
            raise ValueError('Cannot re-register env_id: {}'.format(env_id))
        self.env_specs[env_id] = EnvSpec(env_id, entry_point)

    def make(self, env_id, config=None):
        ''' Create an environment instance

        Args:
            env_id (string): The name of the environment
            config (dict): A dictionary of the environment settings. When
                omitted, a fresh copy of DEFAULT_CONFIG is used.

        Returns:
            The environment built by the spec registered under env_id.

        Raises:
            ValueError: If env_id has not been registered.
        '''
        if env_id not in self.env_specs:
            raise ValueError('Cannot find env_id: {}'.format(env_id))
        if config is None:
            # Never hand the shared module-level dict to an environment
            # constructor; it could be mutated and leak into later calls.
            config = DEFAULT_CONFIG.copy()
        return self.env_specs[env_id].make(config)
# Single module-level registry shared by the register() and make() helpers.
registry: EnvRegistry = EnvRegistry()
def register(env_id, entry_point):
    ''' Add an environment to the module-level registry.

    Args:
        env_id (string): Unique name under which the environment is stored
        entry_point (string): Location of the environment class, written as
            'module.path:ClassName'

    Raises:
        ValueError: If env_id has already been registered.
    '''
    return registry.register(env_id=env_id, entry_point=entry_point)
def make(env_id, config=None):
    ''' Create an environment instance

    Args:
        env_id (string): The name of the environment
        config (dict): Optional environment settings. Keys given here override
            the matching entries of DEFAULT_CONFIG; keys not given keep their
            default values.

    Returns:
        The environment built by the entry point registered under env_id.

    Raises:
        ValueError: If env_id has not been registered.
    '''
    # Merge into a fresh dict so neither DEFAULT_CONFIG nor the caller's
    # config object is mutated.
    _config = DEFAULT_CONFIG.copy()
    if config is not None:
        _config.update(config)
    return registry.make(env_id, _config)