Skip to content

Commit

Permalink
Merge pull request #116 from Bam4d/environment_generator_generator
Browse files Browse the repository at this point in the history
Environment generator generator
  • Loading branch information
Bam4d authored Jun 27, 2021
2 parents 2257850 + c8b6d2e commit fdb1151
Show file tree
Hide file tree
Showing 38 changed files with 1,501 additions and 197 deletions.
1 change: 1 addition & 0 deletions bindings/python.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ PYBIND11_MODULE(python_griddly, m) {
observer_type.value("BLOCK_2D", ObserverType::BLOCK_2D);
observer_type.value("ISOMETRIC", ObserverType::ISOMETRIC);
observer_type.value("VECTOR", ObserverType::VECTOR);
observer_type.value("ASCII", ObserverType::ASCII);

py::class_<NumpyWrapper<uint8_t>, std::shared_ptr<NumpyWrapper<uint8_t>>>(m, "Observation", py::buffer_protocol())
.def_buffer([](NumpyWrapper<uint8_t> &m) -> py::buffer_info {
Expand Down
Binary file modified python/examples/griddlyrts/griddly_rts_global.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified python/examples/griddlyrts/griddly_rts_p1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified python/examples/griddlyrts/griddly_rts_p2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
8 changes: 4 additions & 4 deletions python/examples/snippet.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@

def make_env(name):
wrapper = GymWrapperFactory()
wrapper.build_gym_from_yaml(name, 'Single-Player/Mini-Grid/minigrid-spiders.yaml',
wrapper.build_gym_from_yaml(name, 'Single-Player/GVGAI/spider-nest.yaml',
player_observer_type=gd.ObserverType.SPRITE_2D,
global_observer_type=gd.ObserverType.BLOCK_2D,
global_observer_type=gd.ObserverType.ASCII,
level=0,
max_steps=200)

Expand Down Expand Up @@ -38,8 +38,8 @@ def make_env(name):

frames += 1
obs, reward, done, info = env.step(action)
#env.render()
env.render(observer='global')
env.render()
print(env.render(observer='global'))

if frames % 1000 == 0:
end = timer()
Expand Down
43 changes: 40 additions & 3 deletions python/griddly/GymWrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
class GymWrapper(gym.Env):
metadata = {'render.modes': ['human', 'rgb_array']}

def __init__(self, yaml_file=None, level=0, global_observer_type=gd.ObserverType.VECTOR,
def __init__(self, yaml_file=None, yaml_string=None, level=0, global_observer_type=gd.ObserverType.VECTOR,
player_observer_type=gd.ObserverType.VECTOR, max_steps=None, gdy_path=None, image_path=None,
shader_path=None,
gdy=None, game=None, **kwargs):
Expand All @@ -29,10 +29,14 @@ def __init__(self, yaml_file=None, level=0, global_observer_type=gd.ObserverType
self._renderWindow = {}

# If we are loading a yaml file
if yaml_file is not None:
if yaml_file is not None or yaml_string is not None:
self._is_clone = False
loader = GriddlyLoader(gdy_path, image_path, shader_path)
self.gdy = loader.load(yaml_file)
if yaml_file is not None:
self.gdy = loader.load(yaml_file)
else:
self.gdy = loader.load_string(yaml_string)

self.game = self.gdy.create_game(global_observer_type)

if max_steps is not None:
Expand Down Expand Up @@ -204,10 +208,29 @@ def render(self, mode='human', observer=0):
observation = np.array(self.game.observe(), copy=False)
if self._global_observer_type == gd.ObserverType.VECTOR:
observation = self._vector2rgb.convert(observation)
if self._global_observer_type == gd.ObserverType.ASCII:
observation = observation \
.swapaxes(2, 0) \
.reshape(-1, observation.shape[0] * observation.shape[1]) \
.view('c')
ascii_string = ''.join(np.column_stack(
(observation, np.repeat(['\n'], observation.shape[0]))
).flatten().tolist())
return ascii_string

else:
observation = self._player_last_observation[observer]
if self._player_observer_type[observer] == gd.ObserverType.VECTOR:
observation = self._vector2rgb.convert(observation)
if self._player_observer_type[observer] == gd.ObserverType.ASCII:
observation = observation \
.swapaxes(2, 0) \
.reshape(-1, observation.shape[0] * observation.shape[1]) \
.view('c')
ascii_string = ''.join(np.column_stack(
(observation, np.repeat(['\n'], observation.shape[0]))
).flatten().tolist())
return ascii_string

if mode == 'human':
if self._renderWindow.get(observer) is None:
Expand Down Expand Up @@ -314,3 +337,17 @@ def build_gym_from_yaml(self, environment_name, yaml_file, global_observer_type=
'player_observer_type': player_observer_type
}
)

def build_gym_from_yaml_string(self, environment_name, yaml_string, global_observer_type=gd.ObserverType.SPRITE_2D,
player_observer_type=gd.ObserverType.SPRITE_2D, level=None, max_steps=None):
register(
id=f'GDY-{environment_name}-v0',
entry_point='griddly:GymWrapper',
kwargs={
'yaml_string': yaml_string,
'level': level,
'max_steps': max_steps,
'global_observer_type': global_observer_type,
'player_observer_type': player_observer_type
}
)
139 changes: 139 additions & 0 deletions python/griddly/util/environment_generator_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
import os

import gym
import numpy as np
import yaml

from griddly import GymWrapper, gd, GymWrapperFactory
from griddly.RenderTools import VideoRecorder
from griddly.util.wrappers import ValidActionSpaceWrapper


class EnvironmentGeneratorGenerator():

def __init__(self, gdy_path=None, yaml_file=None):
module_path = os.path.dirname(os.path.realpath(__file__))
self._gdy_path = os.path.realpath(
os.path.join(module_path, '../', 'resources', 'games')) if gdy_path is None else gdy_path
self._input_yaml_file = self._get_full_path(yaml_file)

def _get_full_path(self, gdy_path):
# Assume the file is relative first and if not, try to find it in the pre-defined games
fullpath = gdy_path if os.path.exists(gdy_path) else os.path.join(self._gdy_path, gdy_path)
# (for debugging only) look in parent directory resources because we might not have built the latest version
fullpath = fullpath if os.path.exists(fullpath) else os.path.realpath(
os.path.join(self._gdy_path + '../../../../resources/games', gdy_path))
return fullpath

def generate_env_yaml(self, level_shape):
level_generator_gdy = {}
with open(self._input_yaml_file, 'r') as fs:
self._gdy = yaml.load(fs, Loader=yaml.FullLoader)

objects = [o for o in self._gdy['Objects'] if 'MapCharacter' in o]
environment = self._gdy['Environment']

# Create the placement actions
actions = []
for obj in objects:
object_name = obj["Name"]
place_action = {
'InputMapping': {
'Inputs': {
'1': {'Description': f'Places objects of type \"{object_name}\"'}
}
},
'Name': f'place_{object_name.lower()}',
'Behaviours': [{
'Src': {
'Object': '_empty',
'Commands': [
{'spawn': object_name}
]
}
}]

}
actions.append(place_action)

level_generator_gdy['Actions'] = actions

# Copy the Objects
level_generator_gdy['Objects'] = [{
'Name': o['Name'],
'MapCharacter': o['MapCharacter'],
'Observers': o['Observers']
} for o in objects]

# Generate a default empty level
empty_level = np.empty(level_shape, dtype='str')
empty_level[:] = '.'

level_0_string = '\n'.join([' '.join(list(r)) for r in empty_level])

# Create the environment template
level_generator_gdy['Environment'] = {
'Name': f'{environment["Name"]} Generator',
'Description': f'Level Generator environment for {environment["Name"]}',
'Observers': {k: v for k, v in environment['Observers'].items() if k in ['Sprite2D', 'Isometric']},
'Player': {
'Observer': {
'TrackAvatar': False,
'Height': level_shape[1],
'Width': level_shape[0],
'OffsetX': 0,
'OffsetY': 0,
}
},
'Levels': [level_0_string],
}

return yaml.dump(level_generator_gdy)

def generate_env(self, size, **env_kwargs):
env_yaml = self.generate_env_yaml(size)

env_args = {
**env_kwargs,
'yaml_string': env_yaml,
}

return GymWrapper(*env_args)


if __name__ == '__main__':
wrapper_factory = GymWrapperFactory()
yaml_file = 'Single-Player/GVGAI/sokoban.yaml'

egg = EnvironmentGeneratorGenerator(yaml_file=yaml_file)

for i in range(100):
generator_yaml = egg.generate_env_yaml((10, 10))

env_name = f'test_{i}'
wrapper_factory.build_gym_from_yaml_string(
env_name,
yaml_string=generator_yaml,
# TODO: Change this to ASCII observer when its ready
global_observer_type=gd.ObserverType.VECTOR,
player_observer_type=gd.ObserverType.VECTOR,
)

env = gym.make(f'GDY-{env_name}-v0')
env.reset()
#env = ValidActionSpaceWrapper(env)

# visualization = env.render(observer=0, mode='rgb_array')
# video_recorder = VideoRecorder()
# video_recorder.start('generator_video_test.mp4', visualization.shape)

# Place 10 Random Objects
for i in range(0, 100):
action = env.action_space.sample()
obs, reward, done, info = env.step(action)

#state = env.get_state()

#visual = env.render(observer=0, mode='rgb_array')
# video_recorder.add_frame(visual)

43 changes: 43 additions & 0 deletions python/tests/egg_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import gym
import pytest

from griddly import GymWrapperFactory, gd
from griddly.util.environment_generator_generator import EnvironmentGeneratorGenerator


@pytest.fixture
def test_name(request):
return request.node.name

def build_generator(test_name, yaml_file):
wrapper_factory = GymWrapperFactory()
egg = EnvironmentGeneratorGenerator(yaml_file=yaml_file)
generator_yaml = egg.generate_env_yaml((10,10))

wrapper_factory.build_gym_from_yaml_string(
test_name,
yaml_string=generator_yaml,
global_observer_type=gd.ObserverType.ASCII,
player_observer_type=gd.ObserverType.ASCII,
)

env = gym.make(f'GDY-{test_name}-v0')
env.reset()
return env

def test_spider_nest_generator(test_name):

yaml_file = 'Single-Player/GVGAI/spider-nest.yaml'

for i in range(10):
genv = build_generator(test_name+f'{i}', yaml_file)

# Place 10 Random Objects
for i in range(0, 100):
action = genv.action_space.sample()
obs, reward, done, info = genv.step(action)

player_ascii_string = genv.render(observer=0)
global_ascii_string = genv.render(observer='global')

assert player_ascii_string == global_ascii_string
25 changes: 20 additions & 5 deletions src/Griddly/Core/GDY/Actions/Action.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,20 +49,35 @@ std::shared_ptr<Object> Action::getSourceObject() const {
if (sourceObject_ != nullptr) {
return sourceObject_;
} else {
return grid_->getObject(sourceLocation_);
auto srcObject = grid_->getObject(sourceLocation_);
if (srcObject != nullptr) {
return srcObject;
}

return grid_->getPlayerDefaultObject(playerId_);
}
}

std::shared_ptr<Object> Action::getDestinationObject() const {
switch (actionMode_) {
case ActionMode::SRC_LOC_DST_LOC:
case ActionMode::SRC_OBJ_DST_LOC:
return grid_->getObject(destinationLocation_);
case ActionMode::SRC_OBJ_DST_LOC: {
auto dstObject = grid_->getObject(destinationLocation_);
if (dstObject != nullptr) {
return dstObject;
}
return grid_->getPlayerDefaultObject(playerId_);
}
case ActionMode::SRC_OBJ_DST_OBJ:
return destinationObject_;
case ActionMode::SRC_OBJ_DST_VEC:
case ActionMode::SRC_OBJ_DST_VEC: {
auto destinationLocation = (getSourceLocation() + vectorToDest_);
return grid_->getObject(destinationLocation);
auto dstObject = grid_->getObject(destinationLocation);
if (dstObject != nullptr) {
return dstObject;
}
return grid_->getPlayerDefaultObject(playerId_);
}
}
}

Expand Down
Loading

0 comments on commit fdb1151

Please sign in to comment.