Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Action Normalization #57

Open
wants to merge 7 commits into
base: master
Choose a base branch
from

Conversation

cheng-chi
Copy link

@cheng-chi cheng-chi commented May 2, 2023

As per discussion with @snasiriany, this is my current implementation of action normalization which is required for diffusion policy integration. These code are not fully tested and is meant to be a starting point for discussions.

DO NOT MERGE

Copy link
Collaborator

@snasiriany snasiriany left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks mostly good! Left some comments, mainly about naming and doc strings.

Also, should we infer hdf5_normlize_action from the dataset, rather than manually specifying it in the config?

@@ -421,7 +421,7 @@ class RolloutPolicy(object):
"""
Wraps @Algo object to make it easy to run policies in a rollout loop.
"""
def __init__(self, policy, obs_normalization_stats=None):
def __init__(self, policy, obs_normalization_stats=None, action_normalization_stats=None):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you add some comments in the function docstring for action_normalization_stats? Similar to how it's already done for obs_normalization_stats

@@ -474,4 +475,7 @@ def __call__(self, ob, goal=None):
if goal is not None:
goal = self._prepare_observation(goal)
ac = self.policy.get_action(obs_dict=ob, goal_dict=goal)
return TensorUtils.to_numpy(ac[0])
ac = TensorUtils.to_numpy(ac)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

any reason for changing ac[0] to ac? Can we keep things as ac[0]?

@@ -156,6 +156,8 @@ class has a default implementation that usually doesn't need to be overriden.
# of each observation in each dimension, computed across the training set. See SequenceDataset.normalize_obs
# in utils/dataset.py for more information.
self.train.hdf5_normalize_obs = False

self.train.hdf5_normalize_action = False
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you add a comment to describe the use case (similar to rest of file)

@@ -30,6 +30,7 @@ def __init__(
hdf5_cache_mode=None,
hdf5_use_swmr=True,
hdf5_normalize_obs=False,
hdf5_normalize_action=False,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

function docstring needs comment for this attribute

@@ -499,6 +499,17 @@ def normalize_obs(obs_dict, obs_normalization_stats):

return obs_dict

def normalize_actions(actions, action_normalization_stats):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

small nitpick: our convention here is to use " rather than '. can you make the style change?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also, both normalize_actions and unnormalize_actions need docstring

@@ -366,6 +372,99 @@ def get_obs_normalization_stats(self):
assert self.hdf5_normalize_obs, "not using observation normalization!"
return deepcopy(self.obs_normalization_stats)

def normalize_actions(self):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Naming of this function may be confused for normalize_actions in ObsUtils. How about renaming this to get_action_normalization_stats?

return obs_traj

ep = self.dataset.demos[0]
obs_traj = get_obs_traj(ep)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

naming of obs here might be confused for observations that we pass into the policy (eg. images). Can we replace this term to be more general? And all other places where we name things with obs in this function

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants