Skip to content

Commit

Permalink
Merge pull request #337 from alexhernandezgarcia/extend_common_tests
Browse files Browse the repository at this point in the history
Extend common tests and env base
  • Loading branch information
carriepl authored Aug 22, 2024
2 parents 262cda6 + 65814b5 commit cbac6cc
Show file tree
Hide file tree
Showing 5 changed files with 320 additions and 10 deletions.
63 changes: 53 additions & 10 deletions gflownet/envs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Base class of GFlowNet environments
"""

import numbers
import uuid
from abc import abstractmethod
from copy import deepcopy
Expand Down Expand Up @@ -814,6 +815,8 @@ def copy(self):

@staticmethod
def equal(state_x, state_y):
if isinstance(state_x, numbers.Number) or isinstance(state_x, str):
return state_x == state_y
if torch.is_tensor(state_x) and torch.is_tensor(state_y):
# Check for nans because (torch.nan == torch.nan) == False
x_nan = torch.isnan(state_x)
Expand All @@ -823,11 +826,36 @@ def equal(state_x, state_y):
return False
return torch.equal(state_x[~x_nan], state_y[~y_nan])
return torch.equal(state_x, state_y)
else:
return state_x == state_y
if isinstance(state_x, dict) and isinstance(state_y, dict):
if len(state_x) != len(state_y):
return False
return all(
[
key_x == key_y and GFlowNetEnv.equal(value_x, value_y)
for (key_x, value_x), (key_y, value_y) in zip(
sorted(state_x.items()), sorted(state_y.items())
)
]
)
if (isinstance(state_x, list) and isinstance(state_y, list)) or (
isinstance(state_x, tuple) and isinstance(state_y, tuple)
):
if len(state_x) != len(state_y):
return False
if len(state_x) == 0:
return True
if isinstance(state_x[0], numbers.Number) or isinstance(state_x[0], str):
value_type = type(state_x[0])
if all([isinstance(sx, value_type) for sx in state_x]) and all(
[isinstance(sy, value_type) for sy in state_y]
):
return state_x == state_y
return all([GFlowNetEnv.equal(sx, sy) for sx, sy in zip(state_x, state_y)])

@staticmethod
def isclose(state_x, state_y, atol=1e-8):
if isinstance(state_x, numbers.Number) or isinstance(state_x, str):
return np.isclose(state_x, state_y, atol=atol)
if torch.is_tensor(state_x) and torch.is_tensor(state_y):
# Check for nans because (torch.nan == torch.nan) == False
x_nan = torch.isnan(state_x)
Expand All @@ -840,15 +868,30 @@ def isclose(state_x, state_y, atol=1e-8):
)
return torch.equal(state_x, state_y)
if isinstance(state_x, dict) and isinstance(state_y, dict):
keys_equal = set(state_x.keys()) == set(state_y.keys())
values_close = np.all(
np.isclose(
sorted(state_x.values()), sorted(state_y.values()), atol=atol
)
if len(state_x) != len(state_y):
return False
return all(
[
key_x == key_y and GFlowNetEnv.isclose(value_x, value_y)
for (key_x, value_x), (key_y, value_y) in zip(
sorted(state_x.items()), sorted(state_y.items())
)
]
)
return keys_equal and values_close
else:
return np.all(np.isclose(state_x, state_y, atol=atol))
if (isinstance(state_x, list) and isinstance(state_y, list)) or (
isinstance(state_x, tuple) and isinstance(state_y, tuple)
):
if len(state_x) != len(state_y):
return False
if len(state_x) == 0:
return True
if isinstance(state_x[0], numbers.Number) or isinstance(state_x[0], str):
value_type = type(state_x[0])
if all([isinstance(sx, value_type) for sx in state_x]) and all(
[isinstance(sy, value_type) for sy in state_y]
):
return np.all(np.isclose(state_x, state_y, atol=atol))
return all([GFlowNetEnv.isclose(sx, sy) for sx, sy in zip(state_x, state_y)])

def get_max_traj_length(self):
return 1e3
Expand Down
16 changes: 16 additions & 0 deletions tests/gflownet/envs/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,22 @@ def test__get_parents_step_get_mask__are_compatible(self, n_repeat=1):
assert p_a in self.env.action_space
assert mask[self.env.action_space.index(p_a)] is False

def test__get_parents__all_parents_are_reached_with_different_actions(
self, n_repeat=1
):
if _get_current_method_name() in self.repeats:
n_repeat = self.repeats[_get_current_method_name()]

for _ in range(n_repeat):
self.env.reset()
while not self.env.done:
# Sample random action
state_next, action, valid = self.env.step_random()
if valid is False:
continue
_, parents_a = self.env.get_parents()
assert len(set(parents_a)) == len(parents_a)

def test__state2readable__is_reversible(self, n_repeat=1):
if _get_current_method_name() in self.repeats:
n_repeat = self.repeats[_get_current_method_name()]
Expand Down
249 changes: 249 additions & 0 deletions tests/gflownet/envs/test_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,249 @@
from collections import Counter

import pytest

from gflownet.envs.grid import Grid
from gflownet.envs.tetris import Tetris


@pytest.fixture
def grid():
return Grid()


@pytest.mark.parametrize(
"state_x, state_y, is_equal",
[
### Integers
(
0,
0,
True,
),
(
17,
17,
True,
),
(
17,
18,
False,
),
### Floats
(
0.0,
0.0,
True,
),
(
17.8,
17.8,
True,
),
(
17.0,
18.0,
False,
),
(
17.0,
18.0,
False,
),
### Lists
(
[],
[],
True,
),
(
[],
[0],
False,
),
(
[0],
[0],
True,
),
(
[0],
[1],
False,
),
(
[0, 1, -1],
[0, 1, -1],
True,
),
(
[0, 1, 1],
[0, 1, -1],
False,
),
(
[0, 1],
[0, 1, -1],
False,
),
(
[0.0, 1.0, -1.0],
[0.0, 1.0, -1.0],
True,
),
(
[0.0, 1.0, -1.0],
[0.0, 1.0, 1.0],
False,
),
(
["a", "b", -1, 1],
["a", "b", -1, 1],
True,
),
(
["a", "b", -1, 0],
["a", "b", -1, 1],
False,
),
### Lists of lists
(
[[0, 1], ["a", "b", -1, 0]],
[[0, 1], ["a", "b", -1, 0]],
True,
),
(
[[0, 1], ["a", "b", -1, 1]],
[[0, 1], ["a", "b", -1, 0]],
False,
),
(
[[0, 1], ["a", "b", -1, 0], 0.5],
[[0, 1], ["a", "b", -1, 0], 0.5],
True,
),
(
[[0, 1], ["a", "b", -1, 0], 0.5],
[[0, 1], ["a", "b", -1, 0], 1.5],
False,
),
### Dictionaries
(
{0: [1, 2, 3], 1: ["a", "b"]},
{0: [1, 2, 3], 1: ["a", "b"]},
True,
),
# Key is different
(
{0: [1, 2, 3], 1: ["a", "b"]},
{0: [1, 2, 3], 2: ["a", "b"]},
False,
),
# Value is different
(
{0: [1, 2, 3], 1: ["a", "b"]},
{0: [1, 2, 3], 1: ["a", "c"]},
False,
),
# Order of keys are different
(
{0: [1, 2, 3], 1: ["a", "b"]},
{1: ["a", "b"], 0: [1, 2, 3]},
True,
),
### Counters
(
Counter(),
Counter(),
True,
),
(
Counter({1: 1}),
Counter({1: 1}),
True,
),
(
Counter({1: 1}),
Counter({1: 2}),
False,
),
(
Counter({1: 1}),
Counter({2: 1}),
False,
),
(
Counter({1: 1}),
Counter(),
False,
),
### Tuples
(
(),
(),
True,
),
(
(1,),
(1,),
True,
),
(
(1,),
(2,),
False,
),
(
(1, 2),
(1, 2, 3),
False,
),
(
(1, [0, 1], "a", (2, 3)),
(1, [0, 1], "a", (2, 3)),
True,
),
(
(1, [0, 1], "a", (2, 3)),
(1, [0, 1], "b", (2, 3)),
False,
),
(
(1, [0, 1], "a", (2, 3)),
(1, [0, 1], "a", (2, 3, 1)),
False,
),
### List of Counter and tuple (Wyckomposition)
(
[Counter(), ()],
[Counter(), ()],
True,
),
(
[Counter({(1, 1): 2}), ()],
[Counter({(1, 1): 2}), ()],
True,
),
(
[Counter({(1, 1): 2}), ()],
[Counter({(1, 1): 3}), ()],
False,
),
(
[Counter({(1, 1): 2}), ()],
[Counter({(1, 1): 2}), (1, 2)],
False,
),
(
[Counter({(1, 1): 2}), ()],
[Counter(), ()],
False,
),
],
)
def test__equal__behaves_as_expected(grid, state_x, state_y, is_equal):
# The grid is use as a generic environment. Note that the values compared are not
# grid states, but it does not matter for the purposes of this test.
env = grid
assert env.equal(state_x, state_y) == is_equal
1 change: 1 addition & 0 deletions tests/gflownet/envs/test_scrabble.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,5 +143,6 @@ def setup(self, env):
self.env = env
self.repeats = {
"test__reset__state_is_source": 10,
"test__get_parents__all_parents_are_reached_with_different_actions": 10,
}
self.n_states = {} # TODO: Populate.
1 change: 1 addition & 0 deletions tests/gflownet/envs/test_tetris.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,6 +560,7 @@ def setup(self, env_full):
self.env = env_full
self.repeats = {
"test__reset__state_is_source": 10,
"test__get_parents__all_parents_are_reached_with_different_actions": 10,
"test__gflownet_minimal_runs": 0,
}
self.n_states = {} # TODO: Populate.
Expand Down

0 comments on commit cbac6cc

Please sign in to comment.