Skip to content

Commit

Permalink
chore: suggestions from review
Browse files Browse the repository at this point in the history
Co-authored-by: Clément Bonnet <[email protected]>
Co-authored-by: RuanJohn <[email protected]>
  • Loading branch information
3 people authored Mar 13, 2024
1 parent 9e2c6b0 commit 95389a6
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 27 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ problems.
| 🎲 RubiksCube | Logic | `RubiksCube-v0`<br/>`RubiksCube-partly-scrambled-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/logic/rubiks_cube/) | [doc](https://instadeepai.github.io/jumanji/environments/rubiks_cube/) |
| ✏️ Sudoku | Logic | `Sudoku-v0` <br/>`Sudoku-very-easy-v0`| [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/logic/sudoku/) | [doc](https://instadeepai.github.io/jumanji/environments/sudoku/) |
| 📦 BinPack (3D BinPacking Problem) | Packing | `BinPack-v1` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/packing/bin_pack/) | [doc](https://instadeepai.github.io/jumanji/environments/bin_pack/) |
| 🧩 FlatPack (2D Grid filling problem) | Packing | `FlatPack-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/packing/flat_pack/) | [doc](https://instadeepai.github.io/jumanji/environments/flat_pack/) |
| 🧩 FlatPack (2D Grid Filling Problem) | Packing | `FlatPack-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/packing/flat_pack/) | [doc](https://instadeepai.github.io/jumanji/environments/flat_pack/) |
| 🏭 JobShop (Job Shop Scheduling Problem) | Packing | `JobShop-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/packing/job_shop/) | [doc](https://instadeepai.github.io/jumanji/environments/job_shop/) |
| 🎒 Knapsack | Packing | `Knapsack-v1` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/packing/knapsack/) | [doc](https://instadeepai.github.io/jumanji/environments/knapsack/) |
| ▒ Tetris | Packing | `Tetris-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/packing/tetris/) | [doc](https://instadeepai.github.io/jumanji/environments/tetris/) |
Expand Down
2 changes: 1 addition & 1 deletion docs/environments/flat_pack.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ all blocks that can be placed.


## Action
The action space is a `MultiDiscreteArray`, specifically a tuple of an index between 0 and `num_blocks`,
The action space is a `MultiDiscreteArray`, specifically a tuple of an index between 0 and `num_blocks - 1`,
an index between 0 and 4 (since there are 4 possible rotations), an index between 0 and `num_rows-2`
(the possible row coordinates for placing a block) and an index between 0 and `num_cols-2`
(the possible column coordinates for placing a block). An action thus consists of four pieces of
Expand Down
29 changes: 15 additions & 14 deletions jumanji/environments/packing/flat_pack/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,15 @@ class FlatPack(Environment[State]):

"""The FlatPack environment with a configurable number of row and column blocks.
Here the goal of an agent is to completely fill an empty grid by placing all
available blocks. It can be thought of as 2D version of the `BinPack`
available blocks. It can be thought of as a discrete 2D version of the `BinPack`
environment.
- observation: `Observation`
- grid: jax array (float) of shape (num_rows, num_cols) with the
current state of the grid.
- blocks: jax array (float) of shape (num_blocks, 3, 3) with the blocks to
be placed on the grid. Here each block is a 2D array with shape (3, 3).
- action_mask: jax array (float) showing where which blocks can be placed on the grid.
- action_mask: jax array (bool) showing where which blocks can be placed on the grid.
this mask includes all possible rotations and possible placement locations
for each block on the grid.
Expand All @@ -69,26 +69,26 @@ class FlatPack(Environment[State]):
- if the agent has taken `num_blocks` steps in the environment.
- state: `State`
- num_blocks: jax array (float) of shape () with the
- num_blocks: jax array (int32) of shape () with the
number of blocks in the environment.
- blocks: jax array (float) of shape (num_blocks, 3, 3) with the blocks to
- blocks: jax array (int32) of shape (num_blocks, 3, 3) with the blocks to
be placed on the grid. Here each block is a 2D array with shape (3, 3).
- action_mask: jax array (float) showing where which blocks can be placed on the grid.
- action_mask: jax array (bool) showing where which blocks can be placed on the grid.
this mask includes all possible rotations and possible placement locations
for each block on the grid.
- placed_blocks: jax array (bool) of shape (num_blocks,) showing which blocks
have been placed on the grid.
- grid: jax array (float) of shape (num_rows, num_cols) with the
- grid: jax array (int32) of shape (num_rows, num_cols) with the
current state of the grid.
- step_count: jax array (float) of shape () with the number of steps taken
- step_count: jax array (int32) of shape () with the number of steps taken
in the environment.
- key: jax array (float) of shape (2,) with the random key used for board
- key: jax array of shape (2,) with the random key used for board
generation.
```python
from jumanji.environments import FlatPack
env = FlatPack()
key = jax.random.key(0)
key = jax.random.PRNGKey(0)
state, timestep = jax.jit(env.reset)(key)
env.render(state)
action = env.action_spec().generate_value()
Expand All @@ -103,11 +103,12 @@ def __init__(
reward_fn: Optional[RewardFn] = None,
viewer: Optional[Viewer[State]] = None,
):
"""Initializes the environment.
"""Initializes the FlatPack environment.
Args:
generator: Instance generator for the environment.
reward_fn: Reward function for the environment.
generator: Instance generator for the environment, default to `RandomFlatPackGenerator`
with a grid of 5 blocks per row and column.
reward_fn: Reward function for the environment, default to `CellDenseReward`.
viewer: Viewer for rendering the environment.
"""

Expand Down Expand Up @@ -275,15 +276,15 @@ def observation_spec(self) -> specs.Spec[Observation]:
shape=(self.num_rows, self.num_cols),
minimum=0,
maximum=self.num_blocks,
dtype=jnp.float32,
dtype=jnp.int32,
name="grid",
)

blocks = specs.BoundedArray(
shape=(self.num_blocks, 3, 3),
minimum=0,
maximum=self.num_blocks,
dtype=jnp.float32,
dtype=jnp.int32,
name="blocks",
)

Expand Down
22 changes: 11 additions & 11 deletions jumanji/environments/packing/flat_pack/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,23 +356,23 @@ def __call__(self, key: chex.PRNGKey) -> State:

solved_grid = jnp.array(
[
[1.0, 1.0, 1.0, 2.0, 2.0],
[1.0, 1.0, 2.0, 2.0, 2.0],
[3.0, 1.0, 4.0, 4.0, 2.0],
[3.0, 3.0, 4.0, 4.0, 4.0],
[3.0, 3.0, 3.0, 4.0, 4.0],
[1, 1, 1, 2, 2],
[1, 1, 2, 2, 2],
[3, 1, 4, 4, 2],
[3, 3, 4, 4, 4],
[3, 3, 3, 4, 4],
],
dtype=jnp.float32,
dtype=jnp.int32,
)

blocks = jnp.array(
[
[[1.0, 1.0, 1.0], [1.0, 1.0, 0.0], [0.0, 1.0, 0.0]],
[[0.0, 2.0, 2.0], [2.0, 2.0, 2.0], [0.0, 0.0, 2.0]],
[[3.0, 0.0, 0.0], [3.0, 3.0, 0.0], [3.0, 3.0, 3.0]],
[[4.0, 4.0, 0.0], [4.0, 4.0, 4.0], [0.0, 4.0, 4.0]],
[[1, 1, 1], [1, 1, 0], [0, 1, 0]],
[[0, 2, 2], [2, 2, 2], [0, 0, 2]],
[[3, 0, 0], [3, 3, 0], [3, 3, 3]],
[[4, 4, 0], [4, 4, 4], [0, 4, 4]],
],
dtype=jnp.float32,
dtype=jnp.int32,
)

return State(
Expand Down

0 comments on commit 95389a6

Please sign in to comment.