Skip to content

Commit

Permalink
more cleanup
Browse files (browse the repository at this point in the history)
  • Loading branch information
nisutte committed Oct 25, 2024
1 parent f52fa79 commit ed6957a
Show file tree
Hide file tree
Showing 4 changed files with 2 additions and 7 deletions.
4 changes: 1 addition & 3 deletions serl_launcher/serl_launcher/agents/continuous/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,9 +167,7 @@ def critic_loss_fn(self, batch, params: Params, rng: PRNGKey):
)
chex.assert_shape(target_q, (batch_size,))

if self.config[
"backup_entropy"
]: # not the same as in original jaxrl_m SAC implementation: https://github.com/dibyaghosh/jaxrl_m/blob/main/examples/mujoco/sac.py
if self.config["backup_entropy"]: # not the same as in original jaxrl_m SAC implementation: https://github.com/dibyaghosh/jaxrl_m/blob/main/examples/mujoco/sac.py
temperature = self.forward_temperature()
# target_q = target_q - temperature * next_actions_log_probs # serl original
target_q = (
Expand Down
1 change: 0 additions & 1 deletion serl_launcher/serl_launcher/networks/actor_critic_nets.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ def __call__(
inputs = jnp.concatenate([obs_enc, actions], -1)
outputs = self.network(inputs, train)
# train=train throws: "RuntimeWarning: kwargs are not supported in vmap, so "train" is(are) ignored"

if self.init_final is not None:
value = nn.Dense(
1,
Expand Down
3 changes: 1 addition & 2 deletions serl_launcher/serl_launcher/vision/data_augmentations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import jax
import jax.numpy as jnp
import jax.lax as lax


@partial(jax.jit, static_argnames="padding")
Expand Down Expand Up @@ -98,7 +97,7 @@ def _gaussian_blur_single_image(image, kernel_size, padding, sigma):
radius = int(kernel_size / 2)
kernel_size_ = 2 * radius + 1
x = jnp.arange(-radius, radius + 1).astype(jnp.float32)
blur_filter = jnp.exp(-(x ** 2) / (2.0 * sigma ** 2))
blur_filter = jnp.exp(-(x**2) / (2.0 * sigma**2))
blur_filter = blur_filter / jnp.sum(blur_filter)
blur_v = jnp.reshape(blur_filter, [kernel_size_, 1, 1, 1])
blur_h = jnp.reshape(blur_filter, [1, kernel_size_, 1, 1])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from ur_env.envs.basic_env.config import UR5BasicConfig



class BoxPickingBasicEnv(UR5Env):
def __init__(self, **kwargs):
super().__init__(**kwargs, config=UR5BasicConfig)
Expand Down

0 comments on commit ed6957a

Please sign in to comment.