From e46a6d5edab8ee467b733385451f86fe06d06179 Mon Sep 17 00:00:00 2001
From: Antonin RAFFIN
Date: Fri, 2 Aug 2024 10:46:59 +0200
Subject: [PATCH] Cleanup

---
 docs/misc/changelog.rst            |  4 +-
 stable_baselines3/common/logger.py | 13 +-----
 stable_baselines3/version.txt      |  2 +-
 tests/test_logger.py               | 74 +++++++++++++-----------------
 4 files changed, 38 insertions(+), 55 deletions(-)

diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst
index e5383c0ed..9c461f6ae 100644
--- a/docs/misc/changelog.rst
+++ b/docs/misc/changelog.rst
@@ -3,7 +3,7 @@
 Changelog
 ==========
 
-Release 2.4.0a7 (WIP)
+Release 2.4.0a8 (WIP)
 --------------------------
 
 .. note::
@@ -19,6 +19,7 @@ Breaking Changes:
 New Features:
 ^^^^^^^^^^^^^
 - Added support for ``pre_linear_modules`` and ``post_linear_modules`` in ``create_mlp`` (useful for adding normalization layers, like in DroQ or CrossQ)
+- Enabled np.ndarray logging for TensorBoardOutputFormat as histogram (see GH#1634) (@iwishwasaneagle)
 
 Bug Fixes:
 ^^^^^^^^^^
@@ -282,7 +283,6 @@ New Features:
 ^^^^^^^^^^^^^
 - Added Python 3.11 support
 - Added Gymnasium 0.29 support (@pseudo-rnd-thoughts)
-- Enabled np.ndarray logging for TensorBoardOutputFormat as histogram (see GH#1634) (@iwishwasaneagle)
 
 `SB3-Contrib`_
 ^^^^^^^^^^^^^^
diff --git a/stable_baselines3/common/logger.py b/stable_baselines3/common/logger.py
index 35fea945f..8ceda71ed 100644
--- a/stable_baselines3/common/logger.py
+++ b/stable_baselines3/common/logger.py
@@ -413,17 +413,8 @@ def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Tuple[str, .
                     self.writer.add_scalar(key, value, step)
 
             if isinstance(value, (th.Tensor, np.ndarray)):
-                try:
-                    self.writer.add_histogram(key, value, step)
-                except TypeError:
-                    warnings.warn(
-                        "A numpy.ndarray was passed to write which threw a "
-                        "TypeError. This is most likely due to an outdated numpy version (<1.24.0) and/or "
-                        "an outdated torch version (<2.0.0). The ndarray will be converted to a torch.Tensor "
-                        "as a workaround. For more information, "
-                        "see https://github.com/DLR-RM/stable-baselines3/pull/1635"
-                    )
-                    self.writer.add_histogram(key, th.from_numpy(value), step)
+                # Convert to Torch so it works with numpy<1.24 and torch<2.0
+                self.writer.add_histogram(key, th.as_tensor(value), step)
 
             if isinstance(value, Video):
                 self.writer.add_video(key, value.frames, step, value.fps)
diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt
index f5230e413..ee717ba15 100644
--- a/stable_baselines3/version.txt
+++ b/stable_baselines3/version.txt
@@ -1 +1 @@
-2.4.0a7
+2.4.0a8
diff --git a/tests/test_logger.py b/tests/test_logger.py
index 90ae0c9f4..bc18bf2ce 100644
--- a/tests/test_logger.py
+++ b/tests/test_logger.py
@@ -1,4 +1,3 @@
-import functools
 import importlib.util
 import os
 import sys
@@ -45,6 +44,7 @@
     "f": np.array(1),
     "g": np.array([[[1]]]),
     "h": 'this ", ;is a \n tes:,t',
+    "i": th.ones(3),
 }
 
 KEY_EXCLUDED = {}
@@ -177,6 +177,9 @@ def test_main(tmp_path):
     logger.record_mean("b", -22.5)
     logger.record_mean("b", -44.4)
     logger.record("a", 5.5)
+    # Converted to string:
+    logger.record("hist1", th.ones(2))
+    logger.record("hist2", np.ones(2))
     logger.dump()
 
     logger.record("a", "longasslongasslongasslongasslongasslongassvalue")
@@ -242,7 +245,7 @@ def is_moviepy_installed():
 
 
 @pytest.mark.parametrize("unsupported_format", ["stdout", "log", "json", "csv"])
-def test_report_video_to_unsupported_format_raises_error(tmp_path, unsupported_format):
+def test_unsupported_video_format(tmp_path, unsupported_format):
     writer = make_output_format(unsupported_format, tmp_path)
 
     with pytest.raises(FormatUnsupportedError) as exec_info:
@@ -252,52 +255,41 @@ def test_report_video_to_unsupported_format_raises_error(tmp_path, unsupported_f
     writer.close()
 
 
-_called = None
-
-
-def get_fail_first_then_pass_fn(fn, exception=Exception):
-    _called = False
-
-    @functools.wraps(fn)
-    def _fn(*args, **kwargs):
-        global _called
-        if not _called:
-            _called = True
-            raise exception()
-        return fn(*args, **kwargs)
-
-    return _fn
-
-
 @pytest.mark.parametrize(
-    "histogram,fail_first_write",
+    "histogram",
     [
-        (th.rand(100), False),
-        (np.random.rand(100), False),
-        (np.random.rand(100), True),
+        th.rand(100),
+        np.random.rand(100),
+        np.ones(1),
+        np.ones(1, dtype="int"),
     ],
 )
-def test_report_histogram_to_tensorboard(tmp_path, read_log, fail_first_write, histogram):
+def test_log_histogram(tmp_path, read_log, histogram):
     pytest.importorskip("tensorboard")
 
     writer = make_output_format("tensorboard", tmp_path)
-
-    if fail_first_write:
-        writer.writer.add_histogram = get_fail_first_then_pass_fn(writer.writer.add_histogram, TypeError)
-
     writer.write({"data": histogram}, key_excluded={"data": ()})
 
     log = read_log("tensorboard")
     assert not log.empty
 
-    assert any("data" in f for f in log.lines)
-    assert any("Histogram" in f for f in log.lines)
+    assert any("data" in line for line in log.lines)
+    assert any("Histogram" in line for line in log.lines)
 
     writer.close()
 
 
-@pytest.mark.parametrize("histogram", [list(np.random.rand(100)), tuple(np.random.rand(100)), "1 2 3 4"])
-def test_report_unsupported_type_as_histogram_to_tensorboard(tmp_path, read_log, histogram):
+@pytest.mark.parametrize(
+    "histogram",
+    [
+        list(np.random.rand(100)),
+        tuple(np.random.rand(100)),
+        "1 2 3 4",
+        np.ones(1).item(),
+        th.ones(1).item(),
+    ],
+)
+def test_unsupported_type_histogram(tmp_path, read_log, histogram):
     """
     Check that other types aren't accidentally logged as a Histogram
     """
@@ -306,7 +298,7 @@ def test_report_unsupported_type_as_histogram_to_tensorboard(tmp_path, read_log,
     writer = make_output_format("tensorboard", tmp_path)
     writer.write({"data": histogram}, key_excluded={"data": ()})
 
-    assert all("Histogram" not in f for f in read_log("tensorboard").lines)
+    assert all("Histogram" not in line for line in read_log("tensorboard").lines)
 
     writer.close()
 
@@ -323,7 +315,7 @@ def test_report_image_to_tensorboard(tmp_path, read_log):
 
 
 @pytest.mark.parametrize("unsupported_format", ["stdout", "log", "json", "csv"])
-def test_report_image_to_unsupported_format_raises_error(tmp_path, unsupported_format):
+def test_unsupported_image_format(tmp_path, unsupported_format):
     writer = make_output_format(unsupported_format, tmp_path)
 
     with pytest.raises(FormatUnsupportedError) as exec_info:
@@ -347,7 +339,7 @@ def test_report_figure_to_tensorboard(tmp_path, read_log):
 
 
 @pytest.mark.parametrize("unsupported_format", ["stdout", "log", "json", "csv"])
-def test_report_figure_to_unsupported_format_raises_error(tmp_path, unsupported_format):
+def test_unsupported_figure_format(tmp_path, unsupported_format):
     writer = make_output_format(unsupported_format, tmp_path)
 
     with pytest.raises(FormatUnsupportedError) as exec_info:
@@ -360,7 +352,7 @@ def test_report_figure_to_unsupported_format_raises_error(tmp_path, unsupported_
 
 
 @pytest.mark.parametrize("unsupported_format", ["stdout", "log", "json", "csv"])
-def test_report_hparam_to_unsupported_format_raises_error(tmp_path, unsupported_format):
+def test_unsupported_hparam(tmp_path, unsupported_format):
    writer = make_output_format(unsupported_format, tmp_path)

    with pytest.raises(FormatUnsupportedError) as exec_info:
@@ -479,9 +471,9 @@ def test_fps_no_div_zero(algo):
     model.learn(total_timesteps=100)
 
 
-def test_human_output_format_no_crash_on_same_keys_different_tags():
-    o = HumanOutputFormat(sys.stdout, max_length=60)
-    o.write(
+def test_human_output_same_keys_different_tags():
+    human_out = HumanOutputFormat(sys.stdout, max_length=60)
+    human_out.write(
         {"key1/foo": "value1", "key1/bar": "value2", "key2/bizz": "value3", "key2/foo": "value4"},
         {"key1/foo": None, "key2/bizz": None, "key1/bar": None, "key2/foo": None},
     )
@@ -499,7 +491,7 @@ def test_ep_buffers_stats_window_size(algo, stats_window_size):
 
 
 @pytest.mark.parametrize("base_class", [object, TextIOBase])
-def test_human_output_format_custom_test_io(base_class):
+def test_human_out_custom_text_io(base_class):
     class DummyTextIO(base_class):
         def __init__(self) -> None:
             super().__init__()
@@ -591,7 +583,7 @@ def step(self, action):
         return self.observation_space.sample(), 0.0, False, truncated, info
 
 
-def test_rollout_success_rate_on_policy_algorithm(tmp_path):
+def test_rollout_success_rate_onpolicy_algo(tmp_path):
     """
     Test if the rollout/success_rate information is correctly logged with on policy algorithms
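
Usage sketch (reviewer note, not part of the patch): with the simplified write() above, any
np.ndarray or th.Tensor recorded through the logger now reaches add_histogram via th.as_tensor.
A minimal way to exercise that code path, assuming tensorboard is installed and using an
arbitrary, hypothetical log folder name:

    import numpy as np
    import torch as th

    from stable_baselines3.common.logger import configure

    # "tensorboard" selects TensorBoardOutputFormat for the given folder
    logger = configure("./sb3_histogram_demo", ["tensorboard"])
    logger.record("rollout/actions", np.random.rand(100))  # np.ndarray -> histogram
    logger.record("train/weights", th.rand(100))           # th.Tensor -> histogram
    logger.dump(step=0)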