-
Notifications
You must be signed in to change notification settings - Fork 197
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Action Normalization #57
base: master
Are you sure you want to change the base?
Changes from 2 commits
abbee68
8662524
ae4bd7d
ed5e8c6
9a92b36
a2c2e47
c169db8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -421,7 +421,7 @@ class RolloutPolicy(object): | |
""" | ||
Wraps @Algo object to make it easy to run policies in a rollout loop. | ||
""" | ||
def __init__(self, policy, obs_normalization_stats=None): | ||
def __init__(self, policy, obs_normalization_stats=None, action_normalization_stats=None): | ||
""" | ||
Args: | ||
policy (Algo instance): @Algo object to wrap to prepare for rollouts | ||
|
@@ -433,6 +433,7 @@ def __init__(self, policy, obs_normalization_stats=None): | |
""" | ||
self.policy = policy | ||
self.obs_normalization_stats = obs_normalization_stats | ||
self.action_normalization_stats = action_normalization_stats | ||
|
||
def start_episode(self): | ||
""" | ||
|
@@ -474,4 +475,7 @@ def __call__(self, ob, goal=None): | |
if goal is not None: | ||
goal = self._prepare_observation(goal) | ||
ac = self.policy.get_action(obs_dict=ob, goal_dict=goal) | ||
return TensorUtils.to_numpy(ac[0]) | ||
ac = TensorUtils.to_numpy(ac) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. any reason for changing |
||
if self.action_normalization_stats is not None: | ||
ac = ObsUtils.unnormalize_actions(ac, self.action_normalization_stats) | ||
return ac |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -156,6 +156,8 @@ class has a default implementation that usually doesn't need to be overriden. | |
# of each observation in each dimension, computed across the training set. See SequenceDataset.normalize_obs | ||
# in utils/dataset.py for more information. | ||
self.train.hdf5_normalize_obs = False | ||
|
||
self.train.hdf5_normalize_action = False | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can you add a comment to describe the use case (similar to the rest of the file)? |
||
|
||
# if provided, use the list of demo keys under the hdf5 group "mask/@hdf5_filter_key" for training, instead | ||
# of the full dataset. This provides a convenient way to train on only a subset of the trajectories in a dataset. | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -30,6 +30,7 @@ def __init__( | |
hdf5_cache_mode=None, | ||
hdf5_use_swmr=True, | ||
hdf5_normalize_obs=False, | ||
hdf5_normalize_action=False, | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The function docstring needs a comment for this attribute. |
||
filter_by_attribute=None, | ||
load_next_obs=True, | ||
): | ||
|
@@ -85,6 +86,7 @@ def __init__( | |
self.hdf5_path = os.path.expanduser(hdf5_path) | ||
self.hdf5_use_swmr = hdf5_use_swmr | ||
self.hdf5_normalize_obs = hdf5_normalize_obs | ||
self.hdf5_normalize_action = hdf5_normalize_action | ||
self._hdf5_file = None | ||
|
||
assert hdf5_cache_mode in ["all", "low_dim", None] | ||
|
@@ -119,6 +121,10 @@ def __init__( | |
self.obs_normalization_stats = None | ||
if self.hdf5_normalize_obs: | ||
self.obs_normalization_stats = self.normalize_obs() | ||
|
||
self.action_normalization_stats = None | ||
if self.hdf5_normalize_action: | ||
self.action_normalization_stats = self.normalize_actions() | ||
|
||
# maybe store dataset in memory for fast access | ||
if self.hdf5_cache_mode in ["all", "low_dim"]: | ||
|
@@ -366,6 +372,99 @@ def get_obs_normalization_stats(self): | |
assert self.hdf5_normalize_obs, "not using observation normalization!" | ||
return deepcopy(self.obs_normalization_stats) | ||
|
||
def normalize_actions(self): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Naming of this function may be confused for |
||
""" | ||
Computes a dataset-wide min, max, mean and standard deviation for the actions | ||
(per dimension) and returns it. | ||
""" | ||
def _compute_traj_stats(traj_obs_dict): | ||
""" | ||
Helper function to compute statistics over a single trajectory of observations. | ||
""" | ||
traj_stats = { k : {} for k in traj_obs_dict } | ||
for k in traj_obs_dict: | ||
traj_stats[k]["n"] = traj_obs_dict[k].shape[0] | ||
traj_stats[k]["mean"] = traj_obs_dict[k].mean(axis=0, keepdims=True) # [1, ...] | ||
traj_stats[k]["sqdiff"] = ((traj_obs_dict[k] - traj_stats[k]["mean"]) ** 2).sum(axis=0, keepdims=True) # [1, ...] | ||
traj_stats[k]["min"] = traj_obs_dict[k].min(axis=0, keepdims=True) | ||
traj_stats[k]["max"] = traj_obs_dict[k].max(axis=0, keepdims=True) | ||
return traj_stats | ||
|
||
def _aggregate_traj_stats(traj_stats_a, traj_stats_b): | ||
""" | ||
Helper function to aggregate trajectory statistics. | ||
See https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm | ||
for more information. | ||
""" | ||
merged_stats = {} | ||
for k in traj_stats_a: | ||
n_a, avg_a, M2_a, min_a, max_a = traj_stats_a[k]["n"], traj_stats_a[k]["mean"], traj_stats_a[k]["sqdiff"], traj_stats_a[k]["min"], traj_stats_a[k]["max"] | ||
n_b, avg_b, M2_b, min_b, max_b = traj_stats_b[k]["n"], traj_stats_b[k]["mean"], traj_stats_b[k]["sqdiff"], traj_stats_b[k]["min"], traj_stats_b[k]["max"] | ||
n = n_a + n_b | ||
mean = (n_a * avg_a + n_b * avg_b) / n | ||
delta = (avg_b - avg_a) | ||
M2 = M2_a + M2_b + (delta ** 2) * (n_a * n_b) / n | ||
min_ = np.minimum(min_a, min_b) | ||
max_ = np.maximum(max_a, max_b) | ||
merged_stats[k] = dict(n=n, mean=mean, sqdiff=M2, min=min_, max=max_) | ||
return merged_stats | ||
|
||
# Run through all trajectories. For each one, compute minimal observation statistics, and then aggregate | ||
# with the previous statistics. | ||
def get_obs_traj(ep): | ||
obs_traj = dict() | ||
obs_traj['actions'] = self.dataset.hdf5_file["data/{}/actions".format(ep)][()].astype('float32') | ||
return obs_traj | ||
|
||
ep = self.dataset.demos[0] | ||
obs_traj = get_obs_traj(ep) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. naming of |
||
merged_stats = _compute_traj_stats(obs_traj) | ||
print("SequenceDataset: normalizing actions...") | ||
for ep in LogUtils.custom_tqdm(self.dataset.demos[1:]): | ||
obs_traj = get_obs_traj(ep) | ||
traj_stats = _compute_traj_stats(obs_traj) | ||
merged_stats = _aggregate_traj_stats(merged_stats, traj_stats) | ||
|
||
obs_normalization_stats = { k : {} for k in merged_stats } | ||
for k in merged_stats: | ||
obs_normalization_stats[k]["mean"] = merged_stats[k]["mean"].astype('float32') | ||
obs_normalization_stats[k]["std"] = np.sqrt(merged_stats[k]["sqdiff"] / merged_stats[k]["n"]).astype('float32') | ||
obs_normalization_stats[k]["min"] = merged_stats[k]["min"].astype('float32') | ||
obs_normalization_stats[k]["max"] = merged_stats[k]["max"].astype('float32') | ||
|
||
# convert min and max to scale and offset | ||
stats = obs_normalization_stats['actions'] | ||
range_eps = 1e-4 | ||
input_min = stats['min'] | ||
input_max = stats['max'] | ||
output_min = -1.0 | ||
output_max = 1.0 | ||
|
||
input_range = input_max - input_min | ||
ignore_dim = input_range < range_eps | ||
input_range[ignore_dim] = output_max - output_min | ||
scale = (output_max - output_min) / input_range | ||
offset = output_min - scale * input_min | ||
offset[ignore_dim] = (output_max + output_min) / 2 - input_min[ignore_dim] | ||
|
||
action_normalization_stats = { | ||
'scale': scale, | ||
'offset': offset | ||
} | ||
return action_normalization_stats | ||
|
||
def get_action_normalization_stats(self): | ||
""" | ||
Returns dictionary of min, max, mean and std for actions. | ||
|
||
Returns: | ||
action_normalization_stats (dict): a dictionary for action | ||
normalization with a "min", "max", "mean" and "std" of shape (1, ...) where ... is the default | ||
shape for the action. | ||
""" | ||
assert self.hdf5_normalize_action, "not using action normalization!" | ||
return deepcopy(self.action_normalization_stats) | ||
|
||
def get_dataset_for_ep(self, ep, key): | ||
""" | ||
Helper utility to get a dataset for a specific demonstration. | ||
|
@@ -443,6 +542,8 @@ def get_item(self, index): | |
) | ||
if self.hdf5_normalize_obs: | ||
meta["obs"] = ObsUtils.normalize_obs(meta["obs"], obs_normalization_stats=self.obs_normalization_stats) | ||
if self.hdf5_normalize_action: | ||
meta["actions"] = ObsUtils.normalize_actions(meta["actions"], action_normalization_stats=self.action_normalization_stats) | ||
|
||
if self.load_next_obs: | ||
meta["next_obs"] = self.get_obs_sequence_from_demo( | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -499,6 +499,17 @@ def normalize_obs(obs_dict, obs_normalization_stats): | |
|
||
return obs_dict | ||
|
||
def normalize_actions(actions, action_normalization_stats):
    """
    Applies the affine action normalization: actions * scale + offset.

    Args:
        actions: actions to normalize; anything supporting elementwise arithmetic
            with the entries of @action_normalization_stats
        action_normalization_stats (dict): must contain the keys "scale" and "offset"

    Returns:
        actions: the normalized actions (the input is not modified in-place)
    """
    # NOTE(review): a reviewer left a naming-convention nitpick on this signature; the
    # text of that comment is truncated in this capture, so it is not acted on here.
    gain = action_normalization_stats['scale']
    shift = action_normalization_stats['offset']
    return actions * gain + shift
|
||
def unnormalize_actions(actions, action_normalization_stats):
    """
    Inverts the affine action normalization: (actions - offset) / scale.

    Args:
        actions: normalized actions; anything supporting elementwise arithmetic
            with the entries of @action_normalization_stats
        action_normalization_stats (dict): must contain the keys "scale" and "offset"

    Returns:
        actions: the actions mapped back to their original range (the input is
            not modified in-place)
    """
    gain = action_normalization_stats['scale']
    shift = action_normalization_stats['offset']
    return (actions - shift) / gain
|
||
def has_modality(modality, obs_keys): | ||
""" | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you add some comments in the function docstring for `action_normalization_stats`, similar to how it's already done for `obs_normalization_stats`?