Skip to content

Commit

Permalink
Cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
araffin committed Aug 2, 2024
1 parent 880bd2d commit e46a6d5
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 55 deletions.
4 changes: 2 additions & 2 deletions docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Changelog
==========

Release 2.4.0a7 (WIP)
Release 2.4.0a8 (WIP)
--------------------------

.. note::
Expand All @@ -19,6 +19,7 @@ Breaking Changes:
New Features:
^^^^^^^^^^^^^
- Added support for ``pre_linear_modules`` and ``post_linear_modules`` in ``create_mlp`` (useful for adding normalization layers, like in DroQ or CrossQ)
- Enabled ``np.ndarray`` logging for ``TensorBoardOutputFormat`` as histogram (see GH#1634) (@iwishwasaneagle)

Bug Fixes:
^^^^^^^^^^
Expand Down Expand Up @@ -282,7 +283,6 @@ New Features:
^^^^^^^^^^^^^
- Added Python 3.11 support
- Added Gymnasium 0.29 support (@pseudo-rnd-thoughts)
- Enabled np.ndarray logging for TensorBoardOutputFormat as histogram (see GH#1634) (@iwishwasaneagle)

`SB3-Contrib`_
^^^^^^^^^^^^^^
Expand Down
13 changes: 2 additions & 11 deletions stable_baselines3/common/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,17 +413,8 @@ def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Tuple[str, .
self.writer.add_scalar(key, value, step)

if isinstance(value, (th.Tensor, np.ndarray)):
try:
self.writer.add_histogram(key, value, step)
except TypeError:
warnings.warn(
"A numpy.ndarray was passed to write which threw a "
"TypeError. This is most likely due to an outdated numpy version (<1.24.0) and/or "
"an outdated torch version (<2.0.0). The ndarray will be converted to a torch.Tensor "
"as a workaround. For more information, "
"see https://github.com/DLR-RM/stable-baselines3/pull/1635"
)
self.writer.add_histogram(key, th.from_numpy(value), step)
# Convert to Torch so it works with numpy<1.24 and torch<2.0
self.writer.add_histogram(key, th.as_tensor(value), step)

if isinstance(value, Video):
self.writer.add_video(key, value.frames, step, value.fps)
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
2.4.0a7
2.4.0a8
74 changes: 33 additions & 41 deletions tests/test_logger.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import functools
import importlib.util
import os
import sys
Expand Down Expand Up @@ -45,6 +44,7 @@
"f": np.array(1),
"g": np.array([[[1]]]),
"h": 'this ", ;is a \n tes:,t',
"i": th.ones(3),
}

KEY_EXCLUDED = {}
Expand Down Expand Up @@ -177,6 +177,9 @@ def test_main(tmp_path):
logger.record_mean("b", -22.5)
logger.record_mean("b", -44.4)
logger.record("a", 5.5)
# Converted to string:
logger.record("hist1", th.ones(2))
logger.record("hist2", np.ones(2))
logger.dump()

logger.record("a", "longasslongasslongasslongasslongasslongassvalue")
Expand Down Expand Up @@ -242,7 +245,7 @@ def is_moviepy_installed():


@pytest.mark.parametrize("unsupported_format", ["stdout", "log", "json", "csv"])
def test_report_video_to_unsupported_format_raises_error(tmp_path, unsupported_format):
def test_unsupported_video_format(tmp_path, unsupported_format):
writer = make_output_format(unsupported_format, tmp_path)

with pytest.raises(FormatUnsupportedError) as exec_info:
Expand All @@ -252,52 +255,41 @@ def test_report_video_to_unsupported_format_raises_error(tmp_path, unsupported_f
writer.close()


_called = None


def get_fail_first_then_pass_fn(fn, exception=Exception):
_called = False

@functools.wraps(fn)
def _fn(*args, **kwargs):
global _called
if not _called:
_called = True
raise exception()
return fn(*args, **kwargs)

return _fn


@pytest.mark.parametrize(
"histogram,fail_first_write",
"histogram",
[
(th.rand(100), False),
(np.random.rand(100), False),
(np.random.rand(100), True),
th.rand(100),
np.random.rand(100),
np.ones(1),
np.ones(1, dtype="int"),
],
)
def test_report_histogram_to_tensorboard(tmp_path, read_log, fail_first_write, histogram):
def test_log_histogram(tmp_path, read_log, histogram):
pytest.importorskip("tensorboard")

writer = make_output_format("tensorboard", tmp_path)

if fail_first_write:
writer.writer.add_histogram = get_fail_first_then_pass_fn(writer.writer.add_histogram, TypeError)

writer.write({"data": histogram}, key_excluded={"data": ()})

log = read_log("tensorboard")

assert not log.empty
assert any("data" in f for f in log.lines)
assert any("Histogram" in f for f in log.lines)
assert any("data" in line for line in log.lines)
assert any("Histogram" in line for line in log.lines)

writer.close()


@pytest.mark.parametrize("histogram", [list(np.random.rand(100)), tuple(np.random.rand(100)), "1 2 3 4"])
def test_report_unsupported_type_as_histogram_to_tensorboard(tmp_path, read_log, histogram):
@pytest.mark.parametrize(
"histogram",
[
list(np.random.rand(100)),
tuple(np.random.rand(100)),
"1 2 3 4",
np.ones(1).item(),
th.ones(1).item(),
],
)
def test_unsupported_type_histogram(tmp_path, read_log, histogram):
"""
Check that other types aren't accidentally logged as a Histogram
"""
Expand All @@ -306,7 +298,7 @@ def test_report_unsupported_type_as_histogram_to_tensorboard(tmp_path, read_log,
writer = make_output_format("tensorboard", tmp_path)
writer.write({"data": histogram}, key_excluded={"data": ()})

assert all("Histogram" not in f for f in read_log("tensorboard").lines)
assert all("Histogram" not in line for line in read_log("tensorboard").lines)

writer.close()

Expand All @@ -323,7 +315,7 @@ def test_report_image_to_tensorboard(tmp_path, read_log):


@pytest.mark.parametrize("unsupported_format", ["stdout", "log", "json", "csv"])
def test_report_image_to_unsupported_format_raises_error(tmp_path, unsupported_format):
def test_unsupported_image_format(tmp_path, unsupported_format):
writer = make_output_format(unsupported_format, tmp_path)

with pytest.raises(FormatUnsupportedError) as exec_info:
Expand All @@ -347,7 +339,7 @@ def test_report_figure_to_tensorboard(tmp_path, read_log):


@pytest.mark.parametrize("unsupported_format", ["stdout", "log", "json", "csv"])
def test_report_figure_to_unsupported_format_raises_error(tmp_path, unsupported_format):
def test_unsupported_figure_format(tmp_path, unsupported_format):
writer = make_output_format(unsupported_format, tmp_path)

with pytest.raises(FormatUnsupportedError) as exec_info:
Expand All @@ -360,7 +352,7 @@ def test_report_figure_to_unsupported_format_raises_error(tmp_path, unsupported_


@pytest.mark.parametrize("unsupported_format", ["stdout", "log", "json", "csv"])
def test_report_hparam_to_unsupported_format_raises_error(tmp_path, unsupported_format):
def test_unsupported_hparam(tmp_path, unsupported_format):
writer = make_output_format(unsupported_format, tmp_path)

with pytest.raises(FormatUnsupportedError) as exec_info:
Expand Down Expand Up @@ -479,9 +471,9 @@ def test_fps_no_div_zero(algo):
model.learn(total_timesteps=100)


def test_human_output_format_no_crash_on_same_keys_different_tags():
o = HumanOutputFormat(sys.stdout, max_length=60)
o.write(
def test_human_output_same_keys_different_tags():
human_out = HumanOutputFormat(sys.stdout, max_length=60)
human_out.write(
{"key1/foo": "value1", "key1/bar": "value2", "key2/bizz": "value3", "key2/foo": "value4"},
{"key1/foo": None, "key2/bizz": None, "key1/bar": None, "key2/foo": None},
)
Expand All @@ -499,7 +491,7 @@ def test_ep_buffers_stats_window_size(algo, stats_window_size):


@pytest.mark.parametrize("base_class", [object, TextIOBase])
def test_human_output_format_custom_test_io(base_class):
def test_human_out_custom_text_io(base_class):
class DummyTextIO(base_class):
def __init__(self) -> None:
super().__init__()
Expand Down Expand Up @@ -591,7 +583,7 @@ def step(self, action):
return self.observation_space.sample(), 0.0, False, truncated, info


def test_rollout_success_rate_on_policy_algorithm(tmp_path):
def test_rollout_success_rate_onpolicy_algo(tmp_path):
"""
Test if the rollout/success_rate information is correctly logged with on policy algorithms
Expand Down

0 comments on commit e46a6d5

Please sign in to comment.