From 0df20d8c445261555004db0ff4e6988f6c38e177 Mon Sep 17 00:00:00 2001
From: Zheyuan Hu
Date: Thu, 2 Nov 2023 15:04:41 -0700
Subject: [PATCH] add & update network infra to support VICE agent

---
 serl/networks/__init__.py                   |   3 -
 serl/networks/encoded_encoder.py            |  47 +++++++
 serl/networks/encoders/__init__.py          |   4 -
 serl/networks/encoders/ln_resnet_encoder.py | 143 ++++++++++++++++++++
 serl/networks/encoders/mobilenet_encoder.py |  48 +++++++
 serl/networks/mlp.py                        |  11 +-
 serl/networks/nd_output.py                  |  21 +++
 serl/networks/one_d_output.py               |   2 +-
 8 files changed, 262 insertions(+), 17 deletions(-)
 create mode 100644 serl/networks/encoded_encoder.py
 create mode 100644 serl/networks/encoders/ln_resnet_encoder.py
 create mode 100644 serl/networks/encoders/mobilenet_encoder.py
 create mode 100644 serl/networks/nd_output.py

diff --git a/serl/networks/__init__.py b/serl/networks/__init__.py
index dc61950..2411f88 100644
--- a/serl/networks/__init__.py
+++ b/serl/networks/__init__.py
@@ -2,9 +2,6 @@
 from serl.networks.mlp import MLP, default_init
 from serl.networks.pixel_multiplexer import PixelMultiplexer
 from serl.networks.state_action_value import StateActionValue
-from serl.networks.two_pixel_encoder import TwoPixelEncoder
-from serl.networks.two_pixel_multiplexer import TwoPixelMultiplexer
-from serl.networks.two_encoded_multiplexer import TwoEncodedMultiplexer
 from serl.networks.encoded_encoder import EncodedEncoder
 from serl.networks.nd_output import NDimOutput
 from serl.networks.binary_classifier import BinaryClassifier
diff --git a/serl/networks/encoded_encoder.py b/serl/networks/encoded_encoder.py
new file mode 100644
index 0000000..7480b50
--- /dev/null
+++ b/serl/networks/encoded_encoder.py
@@ -0,0 +1,47 @@
+from typing import Dict, Type, Union
+
+import flax.linen as nn
+import jax
+import jax.numpy as jnp
+from flax.core.frozen_dict import FrozenDict
+
+from serl.networks import default_init
+from serl.networks.spatial import SpatialLearnedEmbeddings
+
+
+class EncodedEncoder(nn.Module):
+    network_cls: Type[nn.Module]
+    latent_dim: int
+    stop_gradient: bool = False
+    pixel_key: str = "pixels"
+    dropout_rate: float = 0.1
+
+    @nn.compact
+    def __call__(
+        self,
+        observations: Union[FrozenDict, Dict],
+        training: bool = False,
+    ) -> jnp.ndarray:
+        observations = FrozenDict(observations)
+        x = observations[self.pixel_key]
+
+        if x.ndim == 3:
+            x = x[None, ...]
+
+        x = SpatialLearnedEmbeddings(*(x.shape[1:]), 8)(x)
+        x = nn.Dropout(self.dropout_rate)(x, deterministic=not training)
+
+        if x.shape[0] == 1:
+            x = x.reshape(-1)
+        else:
+            x = x.reshape((x.shape[0], -1))
+
+        if self.stop_gradient:
+            # We do not update conv layers with policy gradients.
+            x = jax.lax.stop_gradient(x)
+
+        x = nn.Dense(self.latent_dim, kernel_init=default_init())(x)
+        x = nn.LayerNorm()(x)
+        x = nn.tanh(x)
+
+        return self.network_cls()(x, training)
diff --git a/serl/networks/encoders/__init__.py b/serl/networks/encoders/__init__.py
index 66d43ae..410bbfe 100644
--- a/serl/networks/encoders/__init__.py
+++ b/serl/networks/encoders/__init__.py
@@ -1,8 +1,4 @@
-from serl.networks.encoders.d4pg_encoder import D4PGEncoder
 from serl.networks.encoders.two_d4pg_encoder import TwoD4PGEncoder
 from serl.networks.encoders.ln_resnet_encoder import ResNetV2Encoder, TwoResNetV2Encoder
-from serl.networks.encoders.two_resnet_encoder import TwoResNetEncoder
-from serl.networks.encoders.resnet_encoder import ResNetEncoder, ResNetAttentionEncoder
 from serl.networks.encoders.mobilenet_encoder import MobileNetEncoder
 from serl.networks.encoders.two_mobilenet_encoder import TwoMobileNetEncoder
-from serl.networks.encoders.pieg_gpt_encoder import PiegGPTEncoder
diff --git a/serl/networks/encoders/ln_resnet_encoder.py b/serl/networks/encoders/ln_resnet_encoder.py
new file mode 100644
index 0000000..42a3d27
--- /dev/null
+++ b/serl/networks/encoders/ln_resnet_encoder.py
@@ -0,0 +1,143 @@
+# Based on:
+# https://github.com/google/flax/blob/main/examples/imagenet/models.py
+# and
+# https://github.com/google-research/big_transfer/blob/master/bit_jax/models.py
+from functools import partial
+from typing import Any, Callable, Tuple, Type
+
+import flax.linen as nn
+import jax.numpy as jnp
+
+from serl.networks import default_init
+
+
+class ResNetV2Block(nn.Module):
+    """ResNet block."""
+
+    filters: int
+    conv_cls: Type[nn.Module]
+    norm_cls: Type[nn.Module]
+    act: Callable
+    strides: Tuple[int, int] = (1, 1)
+
+    @nn.compact
+    def __call__(self, x):
+        residual = x
+        y = self.norm_cls()(x)
+        y = self.act(y)
+        y = self.conv_cls(self.filters, (3, 3), self.strides)(y)
+        y = self.norm_cls()(y)
+        y = self.act(y)
+        y = self.conv_cls(self.filters, (3, 3))(y)
+
+        if residual.shape != y.shape:
+            residual = self.conv_cls(self.filters, (1, 1), self.strides)(residual)
+
+        return residual + y
+
+
+class MyGroupNorm(nn.GroupNorm):
+    def __call__(self, x):
+        if x.ndim == 3:
+            x = x[jnp.newaxis]
+            x = super().__call__(x)
+            return x[0]
+        else:
+            return super().__call__(x)
+
+
+class ResNetV2Encoder(nn.Module):
+    """ResNetV2."""
+
+    stage_sizes: Tuple[int, ...]
+    num_filters: int = 64
+    dtype: Any = jnp.float32
+    act: Callable = nn.relu
+
+    @nn.compact
+    def __call__(self, x):
+        conv_cls = partial(
+            nn.Conv, use_bias=False, dtype=self.dtype, kernel_init=default_init()
+        )
+        norm_cls = partial(MyGroupNorm, num_groups=4, epsilon=1e-5, dtype=self.dtype)
+
+        if x.shape[-2] == 224:
+            x = conv_cls(self.num_filters, (7, 7), (2, 2), padding=[(3, 3), (3, 3)])(x)
+            x = nn.max_pool(x, (3, 3), strides=(2, 2), padding="SAME")
+        else:
+            x = conv_cls(self.num_filters, (3, 3))(x)
+
+        for i, block_size in enumerate(self.stage_sizes):
+            for j in range(block_size):
+                strides = (2, 2) if i > 0 and j == 0 else (1, 1)
+                x = ResNetV2Block(
+                    self.num_filters * 2**i,
+                    strides=strides,
+                    conv_cls=conv_cls,
+                    norm_cls=norm_cls,
+                    act=self.act,
+                )(x)
+
+        x = norm_cls()(x)
+        x = self.act(x)
+        return x.reshape((*x.shape[:-3], -1))
+
+
+class TwoResNetV2Encoder(nn.Module):
+    """ResNetV2 applied separately to two views stacked along the height axis."""
+
+    stage_sizes: Tuple[int, ...]
+    num_filters: int = 64
+    dtype: Any = jnp.float32
+    act: Callable = nn.relu
+
+    @nn.compact
+    def __call__(self, x):
+        if x.ndim == 3:
+            x_0 = x[:x.shape[0] // 2]
+            x_1 = x[x.shape[0] // 2:]
+        elif x.ndim == 4:
+            x_0 = x[:, :x.shape[1] // 2, ...]
+            x_1 = x[:, x.shape[1] // 2:, ...]
+        else:
+            raise ValueError(f"expected a 3D or 4D input, got ndim={x.ndim}")
+
+        conv_cls = partial(
+            nn.Conv, use_bias=False, dtype=self.dtype, kernel_init=default_init()
+        )
+        norm_cls = partial(MyGroupNorm, num_groups=4, epsilon=1e-5, dtype=self.dtype)
+
+        if x_0.shape[-2] == 224:
+            x_0 = conv_cls(self.num_filters, (7, 7), (2, 2), padding=[(3, 3), (3, 3)])(x_0)
+            x_0 = nn.max_pool(x_0, (3, 3), strides=(2, 2), padding="SAME")
+            x_1 = conv_cls(self.num_filters, (7, 7), (2, 2), padding=[(3, 3), (3, 3)])(x_1)
+            x_1 = nn.max_pool(x_1, (3, 3), strides=(2, 2), padding="SAME")
+        else:
+            x_0 = conv_cls(self.num_filters, (3, 3))(x_0)
+            x_1 = conv_cls(self.num_filters, (3, 3))(x_1)
+
+        for i, block_size in enumerate(self.stage_sizes):
+            for j in range(block_size):
+                strides = (2, 2) if i > 0 and j == 0 else (1, 1)
+                x_0 = ResNetV2Block(
+                    self.num_filters * 2**i,
+                    strides=strides,
+                    conv_cls=conv_cls,
+                    norm_cls=norm_cls,
+                    act=self.act,
+                )(x_0)
+                x_1 = ResNetV2Block(
+                    self.num_filters * 2**i,
+                    strides=strides,
+                    conv_cls=conv_cls,
+                    norm_cls=norm_cls,
+                    act=self.act,
+                )(x_1)
+
+        x_0 = norm_cls()(x_0)
+        x_0 = self.act(x_0)
+        x_1 = norm_cls()(x_1)
+        x_1 = self.act(x_1)
+        x_0 = x_0.reshape((*x_0.shape[:-3], -1))
+        x_1 = x_1.reshape((*x_1.shape[:-3], -1))
+        return x_0, x_1
diff --git a/serl/networks/encoders/mobilenet_encoder.py b/serl/networks/encoders/mobilenet_encoder.py
new file mode 100644
index 0000000..78ace8f
--- /dev/null
+++ b/serl/networks/encoders/mobilenet_encoder.py
@@ -0,0 +1,48 @@
+from typing import Callable
+import flax.linen as nn
+import jax.numpy as jnp
+from flax.core.frozen_dict import FrozenDict
+import numpy as np
+import jax
+
+
+class MobileNetEncoder(nn.Module):
+    mobilenet: Callable[..., Callable]
+    params: FrozenDict
+    stop_gradient: bool = False
+
+    @nn.compact
+    def __call__(self, x: jnp.ndarray, training=False, divide_by=False, reshape=False) -> jnp.ndarray:
+        '''
+        Encode an image with the MobileNet encoder.
+        TODO: this should work for any pretrained encoder, not just MobileNet.
+
+        :param x: input image
+        :param training: whether the network is in training mode
+        :param divide_by: whether to rescale the image by 255 and normalize it
+        :param reshape: whether to reshape the image before passing it to the encoder
+        :return: the encoded image
+        '''
+
+        mean = jnp.array((0.485, 0.456, 0.406))[None, ...]
+        std = jnp.array((0.229, 0.224, 0.225))[None, ...]
+
+        if reshape:
+            x = jnp.reshape(x, (*x.shape[:-2], -1))
+
+        if divide_by:
+            x = x.astype(jnp.float32) / 255.0
+            x = (x - mean) / std
+
+        if x.ndim == 3:
+            x = x[None, ...]
+            x = self.mobilenet.apply(self.params, x, mutable=False, training=False)
+        elif x.ndim == 4:
+            x = self.mobilenet.apply(self.params, x, mutable=False, training=False)
+        else:
+            raise NotImplementedError('ndim is not 3 or 4')
+
+        if self.stop_gradient:
+            x = jax.lax.stop_gradient(x)
+
+        return x
diff --git a/serl/networks/mlp.py b/serl/networks/mlp.py
index 5a7d0c4..6e3d8a0 100644
--- a/serl/networks/mlp.py
+++ b/serl/networks/mlp.py
@@ -2,7 +2,6 @@
 
 import flax.linen as nn
 import jax.numpy as jnp
-from serl.networks.spectral import SpectralNormalization
 
 default_init = nn.initializers.xavier_uniform
 
@@ -21,15 +20,9 @@ def __call__(self, x: jnp.ndarray, training: bool = False) -> jnp.ndarray:
 
         for i, size in enumerate(self.hidden_dims):
             if i + 1 == len(self.hidden_dims) and self.scale_final is not None:
-                if self.spectral_norm:
-                    x = SpectralNormalization(nn.Dense(size, kernel_init=default_init(self.scale_final)))(x, training=training)
-                else:
-                    x = nn.Dense(size, kernel_init=default_init(self.scale_final))(x)
+                x = nn.Dense(size, kernel_init=default_init(self.scale_final))(x)
             else:
-                if self.spectral_norm:
-                    x = SpectralNormalization(nn.Dense(size, kernel_init=default_init()))(x, training=training)
-                else:
-                    x = nn.Dense(size, kernel_init=default_init())(x)
+                x = nn.Dense(size, kernel_init=default_init())(x)
 
             if i + 1 < len(self.hidden_dims) or self.activate_final:
                 if self.dropout_rate is not None and self.dropout_rate > 0:
diff --git a/serl/networks/nd_output.py b/serl/networks/nd_output.py
new file mode 100644
index 0000000..bfa319b
--- /dev/null
+++ b/serl/networks/nd_output.py
@@ -0,0 +1,21 @@
+import flax.linen as nn
+import jax.numpy as jnp
+
+from serl.networks import default_init
+
+
+class NDimOutput(nn.Module):
+    n_dim: int
+    base_cls: nn.Module
+
+    @nn.compact
+    def __call__(
+        self, observations: jnp.ndarray, *args, **kwargs
+    ) -> jnp.ndarray:
+        if self.base_cls:
+            outputs = self.base_cls()(observations, *args, **kwargs)
+        else:
+            outputs = observations
+
+        value = nn.Dense(self.n_dim, kernel_init=default_init())(outputs)
+        return value
diff --git a/serl/networks/one_d_output.py b/serl/networks/one_d_output.py
index 2186754..97ac8d2 100644
--- a/serl/networks/one_d_output.py
+++ b/serl/networks/one_d_output.py
@@ -2,7 +2,7 @@
 import jax.numpy as jnp
 
 from serl.networks import default_init
-from serl.networks.spectral import SpectralNormalization
+
 
 class OneDimOutput(nn.Module):
     base_cls: nn.Module
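
Usage sketch (not part of the patch): the intended composition is EncodedEncoder feeding a network_cls head, e.g. NDimOutput over an MLP, which is how a VICE-style reward classifier can sit on top of frozen, pre-encoded pixels. The (7, 7, 1280) feature shape and the hyperparameters below are illustrative assumptions, as is the premise that serl.networks.spatial.SpatialLearnedEmbeddings accepts (height, width, channels, num_features) positionally, matching the call in encoded_encoder.py.

    from functools import partial

    import jax
    import jax.numpy as jnp

    from serl.networks import MLP, NDimOutput
    from serl.networks.encoded_encoder import EncodedEncoder

    # Two-logit head (e.g. a VICE success classifier) over an MLP trunk.
    mlp_cls = partial(MLP, hidden_dims=(256, 256), activate_final=True)
    head_cls = partial(NDimOutput, n_dim=2, base_cls=mlp_cls)
    model = EncodedEncoder(
        network_cls=head_cls,
        latent_dim=512,
        stop_gradient=True,  # do not backprop into the upstream encoder
    )

    # Hypothetical pre-encoded feature map, e.g. a MobileNet output.
    obs = {"pixels": jnp.zeros((7, 7, 1280))}
    params = model.init(jax.random.PRNGKey(0), obs)
    logits = model.apply(params, obs, training=False)

With training=True, apply additionally needs rngs={"dropout": ...} for the nn.Dropout layer.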
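
The ResNetV2 encoders are self-contained linen modules. A minimal sketch of driving both directly; the input sizes are illustrative and stage_sizes=(2, 2, 2, 2) mirrors the ResNet-18 stage layout:

    import jax
    import jax.numpy as jnp

    from serl.networks.encoders import ResNetV2Encoder, TwoResNetV2Encoder

    encoder = ResNetV2Encoder(stage_sizes=(2, 2, 2, 2))
    img = jnp.zeros((1, 128, 128, 3))
    params = encoder.init(jax.random.PRNGKey(0), img)
    feat = encoder.apply(params, img)  # (1, H/8 * W/8 * 512) for non-224 inputs

    # Two camera views stacked along the height axis; each view gets its own
    # (unshared) tower and the encoder returns one feature vector per view.
    two = TwoResNetV2Encoder(stage_sizes=(2, 2, 2, 2))
    pair = jnp.zeros((1, 256, 128, 3))  # two 128x128 views stacked vertically
    params2 = two.init(jax.random.PRNGKey(1), pair)
    f0, f1 = two.apply(params2, pair)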
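
MobileNetEncoder only wraps a frozen backbone: it normalizes, batches, and forwards to self.mobilenet.apply. A sketch with a stand-in backbone, since loading a real pretrained MobileNet (module plus params) is library-specific and outside this patch; TinyBackbone is a hypothetical placeholder with the call signature the encoder expects:

    import flax.linen as nn
    import jax
    import jax.numpy as jnp

    from serl.networks.encoders import MobileNetEncoder

    # Stand-in for a pretrained MobileNet: any linen module whose __call__
    # accepts (x, training=...) works here.
    class TinyBackbone(nn.Module):
        @nn.compact
        def __call__(self, x, training=False):
            return nn.Conv(32, (3, 3), strides=(2, 2))(x)

    backbone = TinyBackbone()
    rgb = jnp.zeros((1, 224, 224, 3), dtype=jnp.uint8)
    pretrained = backbone.init(jax.random.PRNGKey(0), rgb.astype(jnp.float32))

    encoder = MobileNetEncoder(mobilenet=backbone, params=pretrained, stop_gradient=True)
    # divide_by=True rescales uint8 pixels to [0, 1] and applies ImageNet stats;
    # the wrapper itself has no parameters, hence the empty variables dict.
    feats = encoder.apply({}, rgb, divide_by=True)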