Skip to content

Commit

Permalink
fix: timestep extras default value (#207)
Browse files Browse the repository at this point in the history
  • Loading branch information
sash-a authored Jan 8, 2024
1 parent e8d51e0 commit a695150
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions jumanji/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import field
from typing import TYPE_CHECKING, Dict, Generic, Optional, Sequence, TypeVar, Union

if TYPE_CHECKING: # https://github.com/python/mypy/issues/6239
Expand Down Expand Up @@ -71,14 +72,14 @@ class TimeStep(Generic[Observation]):
extras: environment metric(s) or information returned by the environment but
not observed by the agent (hence not in the observation). For example, it
could be whether an invalid action was taken. In most environments, extras
is None.
is an empty dictionary.
"""

step_type: StepType
reward: Array
discount: Array
observation: Observation
extras: Optional[Dict] = None
extras: Dict = field(default_factory=dict)

def first(self) -> Array:
return self.step_type == StepType.FIRST
Expand Down Expand Up @@ -110,6 +111,7 @@ def restart(
Returns:
TimeStep identified as a reset.
"""
extras = extras or {}
return TimeStep(
step_type=StepType.FIRST,
reward=jnp.zeros(shape, dtype=float),
Expand Down Expand Up @@ -144,6 +146,7 @@ def transition(
TimeStep identified as a transition.
"""
discount = discount if discount is not None else jnp.ones(shape, dtype=float)
extras = extras or {}
return TimeStep(
step_type=StepType.MID,
reward=reward,
Expand Down Expand Up @@ -175,6 +178,7 @@ def termination(
Returns:
TimeStep identified as the termination of an episode.
"""
extras = extras or {}
return TimeStep(
step_type=StepType.LAST,
reward=reward,
Expand Down Expand Up @@ -208,6 +212,7 @@ def truncation(
TimeStep identified as the truncation of an episode.
"""
discount = discount if discount is not None else jnp.ones(shape, dtype=float)
extras = extras or {}
return TimeStep(
step_type=StepType.LAST,
reward=reward,
Expand Down

0 comments on commit a695150

Please sign in to comment.