Skip to content

Commit

Permalink
Refactor _cost_objective variable
Browse files Browse the repository at this point in the history
  • Loading branch information
LucasAlegre committed Oct 27, 2024
1 parent d7135c6 commit 76e7dc9
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 7 deletions.
4 changes: 2 additions & 2 deletions mo_gymnasium/envs/mujoco/ant_v4.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class MOAntEnv(AntEnv, EzPickle):
def __init__(self, cost_objective=True, **kwargs):
super().__init__(**kwargs)
EzPickle.__init__(self, cost_objective, **kwargs)
self.cost_objetive = cost_objective
self._cost_objetive = cost_objective
self.reward_dim = 3 if cost_objective else 2
self.reward_space = Box(low=-np.inf, high=np.inf, shape=(self.reward_dim,))

Expand All @@ -39,7 +39,7 @@ def step(self, action):
cost = info["reward_ctrl"]
healthy_reward = info["reward_survive"]

if self.cost_objetive:
if self._cost_objetive:
cost /= self._ctrl_cost_weight # Ignore the weight in the original AntEnv
vec_reward = np.array([x_velocity, y_velocity, cost], dtype=np.float32)
else:
Expand Down
4 changes: 2 additions & 2 deletions mo_gymnasium/envs/mujoco/ant_v5.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class MOAntEnv(AntEnv, EzPickle):
def __init__(self, cost_objective=True, **kwargs):
super().__init__(**kwargs)
EzPickle.__init__(self, cost_objective, **kwargs)
self.cost_objetive = cost_objective
self._cost_objetive = cost_objective
self.reward_dim = 3 if cost_objective else 2
self.reward_space = Box(low=-np.inf, high=np.inf, shape=(self.reward_dim,))

Expand All @@ -43,7 +43,7 @@ def step(self, action):
cost = info["reward_ctrl"]
healthy_reward = info["reward_survive"]

if self.cost_objetive:
if self._cost_objetive:
cost /= self._ctrl_cost_weight # Ignore the weight in the original AntEnv
vec_reward = np.array([x_velocity, y_velocity, cost], dtype=np.float32)
else:
Expand Down
4 changes: 2 additions & 2 deletions mo_gymnasium/envs/mujoco/hopper_v4.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class MOHopperEnv(HopperEnv, EzPickle):
def __init__(self, cost_objective=True, **kwargs):
super().__init__(**kwargs)
EzPickle.__init__(self, cost_objective, **kwargs)
self.cost_objetive = cost_objective
self._cost_objetive = cost_objective
self.reward_dim = 3 if cost_objective else 2
self.reward_space = Box(low=-np.inf, high=np.inf, shape=(self.reward_dim,))

Expand All @@ -53,7 +53,7 @@ def step(self, action):
height = 10 * (z - self.init_qpos[1])
energy_cost = np.sum(np.square(action))

if self.cost_objetive:
if self._cost_objetive:
vec_reward = np.array([x_velocity, height, -energy_cost], dtype=np.float32)
else:
vec_reward = np.array([x_velocity, height], dtype=np.float32)
Expand Down
2 changes: 1 addition & 1 deletion mo_gymnasium/envs/mujoco/hopper_v5.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class MOHopperEnv(HopperEnv, EzPickle):
def __init__(self, cost_objective=True, **kwargs):
super().__init__(**kwargs)
EzPickle.__init__(self, cost_objective, **kwargs)
self.cost_objetive = cost_objective
self._cost_objetive = cost_objective
self.reward_dim = 3 if cost_objective else 2
self.reward_space = Box(low=-np.inf, high=np.inf, shape=(self.reward_dim,))

Expand Down

0 comments on commit 76e7dc9

Please sign in to comment.