Skip to content

Commit

Permalink
fix bug in logging timers (#189)
Browse files Browse the repository at this point in the history
This PR:

* Adds an overrideable name to the memory TableMemoryProxy
* Exposes that name as a property
* Extends the proxy base to allowed reading properties from the inner
classes
* Uses the above property for logging
  • Loading branch information
tgolsson authored Nov 10, 2023
1 parent daa87ec commit 1bf9da7
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 3 deletions.
27 changes: 24 additions & 3 deletions emote/memory/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ def __init__(
table: Table,
minimum_length_threshold: Optional[int] = None,
use_terminal: bool = False,
*,
name: str = "default",
):
self._store: Dict[AgentId, Episode] = {}
self._table = table
Expand All @@ -80,6 +82,11 @@ def __init__(
self._completed_episodes: set[AgentId] = set()
self._term_states = [EpisodeState.TERMINAL, EpisodeState.INTERRUPTED]
self._use_terminal = use_terminal
self._name = name

@property
def name(self):
return self._name

def size(self):
return self._table.size()
Expand Down Expand Up @@ -168,14 +175,28 @@ def __init__(self, inner: "MemoryProxyWrapper" | MemoryProxy, **kwargs):
super().__init__(**kwargs)
self._inner = inner

def _lookup_class_attr(self, name):
cls_attr = getattr(self._inner.__class__, name, None)
if cls_attr is None:
if isinstance(self._inner, MemoryProxyWrapper):
return self._inner._lookup_class_attr(name)

return None

return cls_attr

def __getattr__(self, name):
# get the attribute from inner.
# if it does not exist, exception will be raised.
#
# we look up the class attr to check if it is a property. Properties on the instance only
# resolve to the value, which would be string for example.
cls_attr = self._lookup_class_attr(name)
attr = getattr(self._inner, name)

# for some safety, make sure it is an method.
# we only want the memory proxy wrapper to forward methods.
if not inspect.ismethod(attr):
if not inspect.ismethod(attr) and not isinstance(cls_attr, property):
# NOTE: In python >= 3.10 we should specify
# 'obj' and 'name' on the AttributeError so Python can provide hints to the user.
raise AttributeError(
Expand Down Expand Up @@ -303,8 +324,8 @@ def _end_cycle(self):
self.log_scalar("episode/completed", self.completed_episodes)

for name, (mean, var) in self.timers().stats().items():
self.log_scalar(f"memory/{self._target_memory_name}/{name}/timing/mean", mean)
self.log_scalar(f"memory/{self._target_memory_name}/{name}/timing/var", var)
self.log_scalar(f"memory/{self.name}/{name}/timing/mean", mean)
self.log_scalar(f"memory/{self.name}/{name}/timing/var", var)

if "episode/reward" in self.windowed_scalar:
rewards = self.windowed_scalar["episode/reward"]
Expand Down
28 changes: 28 additions & 0 deletions tests/test_memory_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,3 +163,31 @@ def test_get_report(table_proxy, tmpdir):
assert out["histogram:ones"] == 1
assert out["one"] == 1 and out["one/cumulative"] == 3
assert out_lists["three"] == [3, 3]


def test_end_cycle(table_proxy, tmpdir):
proxy = LoggingProxyWrapper(
table_proxy,
SummaryWriter(
log_dir=tmpdir,
),
2,
)

state = EpisodeState.INITIAL
for s in range(10):
proxy.add(
{
0: DictObservation(
episode_state=state,
array_data={"obs": [1.0]},
rewards={"reward": None},
metadata=MetaData(info={"episode/reward": 10.0}, info_lists={}),
)
},
{0: DictResponse({"actions": [0.0]}, {})} if s < 9 else {},
)

state = EpisodeState.RUNNING if s < 8 else EpisodeState.TERMINAL

proxy._end_cycle()
19 changes: 19 additions & 0 deletions tests/test_memory_proxy_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ def __init__(self):
def say_hello(self):
return "hello world"

@property
def a_property(self):
return "a property"


class EmptyMemoryProxyWrapper(MemoryProxyWrapper):
pass
Expand Down Expand Up @@ -77,3 +81,18 @@ def test_wrapper_disallows_accessing_non_existing_attribute():

with pytest.raises(AttributeError):
wrapper.i_do_not_exist


def test_wrapper_allows_accessing_property():
dummy = DummyMemoryProxy()
wrapper = EmptyMemoryProxyWrapper(dummy)

wrapper.a_property


def test_wrapper_allows_accessing_property_nested():
dummy = DummyMemoryProxy()
wrapper = EmptyMemoryProxyWrapper(dummy)
wrapper = SayGoodbyeProxyWrapper(wrapper)

wrapper.a_property

0 comments on commit 1bf9da7

Please sign in to comment.