Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add np.ndarray as a recognized type for TB histograms. #1635

Merged
merged 15 commits into from
Aug 2, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ New Features:
^^^^^^^^^^^^^
- Added Python 3.11 support
- Added Gymnasium 0.29 support (@pseudo-rnd-thoughts)
- Enabled np.ndarray logging for TensorBoardOutputFormat as histogram (see GH#1634) (@iwishwasaneagle)

`SB3-Contrib`_
^^^^^^^^^^^^^^
Expand Down
14 changes: 12 additions & 2 deletions stable_baselines3/common/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,8 +412,18 @@ def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Tuple[str, .
else:
self.writer.add_scalar(key, value, step)

if isinstance(value, th.Tensor):
self.writer.add_histogram(key, value, step)
if isinstance(value, (th.Tensor, np.ndarray)):
try:
self.writer.add_histogram(key, value, step)
except TypeError:
warnings.warn(
"A numpy.ndarray was passed to write which threw a "
"TypeError. This is most likely due to an outdated numpy version (<1.24.0) and/or "
"an outdated torch version (<2.0.0). The ndarray will be converted to a torch.Tensor "
"as a workaround. For more information, "
"see https://github.com/DLR-RM/stable-baselines3/pull/1635"
)
self.writer.add_histogram(key, th.from_numpy(value), step)

if isinstance(value, Video):
self.writer.add_video(key, value.frames, step, value.fps)
Expand Down
57 changes: 57 additions & 0 deletions tests/test_logger.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import functools
import importlib.util
import os
import sys
Expand Down Expand Up @@ -250,6 +251,62 @@ def test_report_video_to_unsupported_format_raises_error(tmp_path, unsupported_f
writer.close()


_called = None


def get_fail_first_then_pass_fn(fn, exception=Exception):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you should not need that, just record the warnings with pytest and check that the correct warning is there (we have some examples in the tests)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well I guess with the current CI setup this warning is always hit. However, this test will then fail for anyone with newer versions of np and/pr torch.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you should be able to check the version of pytorch to know if a warning should be outputted or not?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So would you propose a check before the add_histogram to see if a warning and conversion is needed?

_called = False

@functools.wraps(fn)
def _fn(*args, **kwargs):
global _called
if not _called:
_called = True
raise exception()
return fn(*args, **kwargs)

return _fn


@pytest.mark.parametrize(
"histogram,fail_first_write",
[
(th.rand(100), False),
(np.random.rand(100), False),
(np.random.rand(100), True),
],
)
def test_report_histogram_to_tensorboard(tmp_path, read_log, fail_first_write, histogram):
pytest.importorskip("tensorboard")

writer = make_output_format("tensorboard", tmp_path)

if fail_first_write:
writer.writer.add_histogram = get_fail_first_then_pass_fn(writer.writer.add_histogram, TypeError)

writer.write({"data": histogram}, key_excluded={"data": ()})

log = read_log("tensorboard")

assert not log.empty
assert any("data" in f for f in log.lines)
assert any("Histogram" in f for f in log.lines)

writer.close()


@pytest.mark.parametrize("histogram", [list(np.random.rand(100)), tuple(np.random.rand(100)), "1 2 3 4"])
def test_report_unsupported_type_as_histogram_to_tensorboard(tmp_path, read_log, histogram):
pytest.importorskip("tensorboard")

writer = make_output_format("tensorboard", tmp_path)
writer.write({"data": histogram}, key_excluded={"data": ()})
araffin marked this conversation as resolved.
Show resolved Hide resolved

assert all("Histogram" not in f for f in read_log("tensorboard").lines)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe add a comment, something like "check that the values were not logged as histogram"
(I'm not sure if all of them are logged btw)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See 383ee76


writer.close()


def test_report_image_to_tensorboard(tmp_path, read_log):
pytest.importorskip("tensorboard")

Expand Down