Skip to content

Commit

Permalink
Add drowing lines, change color to green
Browse files Browse the repository at this point in the history
  • Loading branch information
tomekster committed Jan 18, 2024
1 parent f5e87cf commit 8c842e0
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 4 deletions.
Binary file added mo_gymnasium/envs/fruit_tree/assets/agent.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added mo_gymnasium/envs/fruit_tree/assets/node.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
126 changes: 122 additions & 4 deletions mo_gymnasium/envs/fruit_tree/fruit_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
import numpy as np
from gymnasium import spaces
from gymnasium.utils import EzPickle
import pygame
from os import path
from typing import Optional


FRUITS = {
Expand Down Expand Up @@ -264,17 +267,21 @@ class FruitTreeEnv(gym.Env, EzPickle):
The episode terminates when the agent reaches a leaf node.
"""

def __init__(self, depth=6):
metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 1}
window_padding = 15
font_size = 12

def __init__(self, depth=6, render_mode: Optional[str] = None):
assert depth in [5, 6, 7], "Depth must be 5, 6 or 7."
EzPickle.__init__(self, depth)

self.render_mode = render_mode
self.reward_dim = 6
self.tree_depth = depth # zero based depth
branches = np.zeros((int(2**self.tree_depth - 1), self.reward_dim))
fruits = np.array(FRUITS[str(depth)])
# fruits = np.random.randn(2**self.tree_depth, self.reward_dim)
# fruits = np.abs(fruits) / np.linalg.norm(fruits, 2, 1, True)
# print(fruits*10)
fruits = np.array(FRUITS[str(depth)])
self.tree = np.concatenate([branches, fruits])

self.max_reward = 10.0
Expand All @@ -288,8 +295,24 @@ def __init__(self, depth=6):
self.current_state = np.array([0, 0], dtype=np.int32)
self.terminal = False

# pygame
self.row_height = 20
self.window_size = (1200, self.row_height * self.tree_depth + 150)
self.pix_square_size = np.array([10, 10], dtype=np.int32)


self.window = None
self.clock = None
self.node_img = None
self.agent_img = None

def get_ind(self, pos):
return int(2 ** pos[0] - 1) + pos[1]

def ind_to_state(self, ind):
x = int(np.log2(ind + 1))
y = ind - 2 ** x + 1
return np.array([x, y], dtype=np.int32)

def get_tree_value(self, pos):
return np.array(self.tree[self.get_ind(pos)], dtype=np.float32)
Expand Down Expand Up @@ -325,5 +348,100 @@ def step(self, action):
reward = self.get_tree_value(self.current_state)
if self.current_state[0] == self.tree_depth:
self.terminal = True

return self.current_state.copy(), reward, self.terminal, False, {}

def get_pos_in_window(self, row, index_in_row):
pos_x = self.window_padding + (index_in_row + 0.5) * (self.window_size[0] - 2 * self.window_padding) / (2 ** (row))
pos_y = row * self.row_height
return np.array([pos_x, pos_y])

def render(self):
if self.render_mode is None:
assert self.spec is not None
gym.logger.warn(
"You are calling render method without specifying any render mode. "
"You can specify the render_mode at initialization, "
f'e.g. mo_gym.make("{self.spec.id}", render_mode="rgb_array")'
)
return

if self.window is None:
pygame.init()

if self.render_mode == "human":
pygame.display.init()
pygame.display.set_caption("Fruit Tree")
self.window = pygame.display.set_mode(self.window_size)
else:
self.window = pygame.Surface(self.window_size)

if self.clock is None:
self.clock = pygame.time.Clock()

if self.node_img is None:
filename = path.join(path.dirname(__file__), "assets", "node.png")
self.node_img = pygame.transform.scale(pygame.image.load(filename), self.pix_square_size)
self.node_img = pygame.transform.flip(self.node_img, flip_x=True, flip_y=False)
if self.agent_img is None:
filename = path.join(path.dirname(__file__), "assets", "agent.png")
self.agent_img = pygame.transform.scale(pygame.image.load(filename), self.pix_square_size)

# self.font = pygame.font.Font(path.join(path.dirname(__file__), "assets", "Minecraft.ttf"), 20)
self.font = pygame.font.Font(None, self.font_size)

canvas = pygame.Surface(self.window_size)
canvas.fill((0, 0, 0))

self.window.blit(canvas, (0, 0))

for ind, node in enumerate(self.tree):
row, index_in_row = self.ind_to_state(ind)

node_pos = self.get_pos_in_window(row, index_in_row)

# Get childerns' positions for drawing branches
child1_pos = self.get_pos_in_window(row + 1, 2 * index_in_row)
child2_pos = self.get_pos_in_window(row + 1, 2 * index_in_row + 1)

half_square = self.pix_square_size / 2

if (row, index_in_row) == tuple(self.current_state):
img = self.agent_img
font_color = (255, 0, 0) # Red digits for agent node
else:
img = self.node_img
font_color = (0, 255, 0) # Green digits for non-agent nodes

if row != self.tree_depth:
pygame.draw.line(self.window, (255,255,255), node_pos + half_square, child1_pos + half_square, 1)
pygame.draw.line(self.window, (255,255,255), node_pos + half_square, child2_pos + half_square, 1)


self.window.blit(img, np.array(node_pos))
if row == self.tree_depth:
values_imgs = [self.font.render(f'{val:.2f}', True, font_color) for val in node]
for i, val_img in enumerate(values_imgs):
self.window.blit(val_img, node_pos + np.array([- 5, (i+1)*self.font_size]))

if self.render_mode == "human":
pygame.event.pump()
pygame.display.update()
self.clock.tick(self.metadata["render_fps"])
elif self.render_mode == "rgb_array":
return np.transpose(np.array(pygame.surfarray.pixels3d(self.window)), axes=(1, 0, 2))

if __name__ == "__main__":
import mo_gymnasium as mo_gym
import time

env = mo_gym.make("fruit-tree", depth=6, render_mode="human")
terminated = False
env.reset()
while True:
env.render()
obs, r, terminated, truncated, info = env.step(env.action_space.sample())
if terminated or truncated:
env.render()
time.sleep(2)
env.reset()

0 comments on commit 8c842e0

Please sign in to comment.