Commit
add & update network infra to support VICE agent
Leo428 committed Nov 2, 2023
1 parent f02729d commit 0df20d8
Showing 8 changed files with 266 additions and 17 deletions.
3 changes: 0 additions & 3 deletions serl/networks/__init__.py
@@ -2,9 +2,6 @@
from serl.networks.mlp import MLP, default_init
from serl.networks.pixel_multiplexer import PixelMultiplexer
from serl.networks.state_action_value import StateActionValue
from serl.networks.two_pixel_encoder import TwoPixelEncoder
from serl.networks.two_pixel_multiplexer import TwoPixelMultiplexer
from serl.networks.two_encoded_multiplexer import TwoEncodedMultiplexer
from serl.networks.encoded_encoder import EncodedEncoder
from serl.networks.nd_output import NDimOutput
from serl.networks.binary_classifier import BinaryClassifier
47 changes: 47 additions & 0 deletions serl/networks/encoded_encoder.py
@@ -0,0 +1,47 @@
from typing import Dict, Optional, Tuple, Type, Union

import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict

from serl.networks import default_init
from serl.networks.spatial import SpatialLearnedEmbeddings


class EncodedEncoder(nn.Module):
network_cls: Type[nn.Module]
latent_dim: int
stop_gradient: bool = False
pixel_key: str = "pixels"
dropout_rate: float = 0.1

@nn.compact
def __call__(
self,
observations: Union[FrozenDict, Dict],
training: bool = False,
) -> jnp.ndarray:
observations = FrozenDict(observations)
x = observations[self.pixel_key]

if x.ndim == 3:
x = x[None, :]

x = SpatialLearnedEmbeddings(*(x.shape[1:]), 8)(x)
x = nn.Dropout(self.dropout_rate)(x, deterministic=not training)

if x.shape[0] == 1:
x = x.reshape(-1)
else:
x = x.reshape((x.shape[0], -1))

if self.stop_gradient:
# We do not update conv layers with policy gradients.
x = jax.lax.stop_gradient(x)

x = nn.Dense(512, kernel_init=default_init())(x)
x = nn.LayerNorm()(x)
x = nn.tanh(x)

return self.network_cls()(x, training)
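
Usage note (not part of the commit): a minimal sketch of how EncodedEncoder might be wired up, assuming a pre-encoded feature map under the default "pixels" key and the MLP from serl.networks (which takes hidden_dims, as seen in mlp.py below); the feature shape, hidden sizes, and latent_dim value are illustrative assumptions.

# Illustrative sketch only; shapes and sizes are assumptions, not values from the commit.
import functools
import jax
import jax.numpy as jnp

from serl.networks import MLP
from serl.networks.encoded_encoder import EncodedEncoder

features = jnp.ones((7, 7, 1280))  # hypothetical pre-encoded feature map (H, W, C)
model = EncodedEncoder(
    network_cls=functools.partial(MLP, hidden_dims=(256, 256)),
    latent_dim=64,
    stop_gradient=True,  # keep the upstream encoder out of policy gradients
)
rng = jax.random.PRNGKey(0)
variables = model.init(rng, {"pixels": features}, training=False)
out = model.apply(variables, {"pixels": features}, training=False)
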
4 changes: 0 additions & 4 deletions serl/networks/encoders/__init__.py
@@ -1,8 +1,4 @@
from serl.networks.encoders.d4pg_encoder import D4PGEncoder
from serl.networks.encoders.two_d4pg_encoder import TwoD4PGEncoder
from serl.networks.encoders.ln_resnet_encoder import ResNetV2Encoder, TwoResNetV2Encoder
from serl.networks.encoders.two_resnet_encoder import TwoResNetEncoder
from serl.networks.encoders.resnet_encoder import ResNetEncoder, ResNetAttentionEncoder
from serl.networks.encoders.mobilenet_encoder import MobileNetEncoder
from serl.networks.encoders.two_mobilenet_encoder import TwoMobileNetEncoder
from serl.networks.encoders.pieg_gpt_encoder import PiegGPTEncoder
146 changes: 146 additions & 0 deletions serl/networks/encoders/ln_resnet_encoder.py
@@ -0,0 +1,146 @@
# Based on:
# https://github.com/google/flax/blob/main/examples/imagenet/models.py
# and
# https://github.com/google-research/big_transfer/blob/master/bit_jax/models.py
from functools import partial
from typing import Any, Callable, Tuple, Type

import flax.linen as nn
import jax.numpy as jnp
from flax import linen as nn

from serl.networks import default_init


class ResNetV2Block(nn.Module):
"""ResNet block."""

filters: int
conv_cls: Type[nn.Module]
norm_cls: Type[nn.Module]
act: Callable
strides: Tuple[int, int] = (1, 1)

@nn.compact
def __call__(self, x):
residual = x
y = self.norm_cls()(x)
y = self.act(y)
y = self.conv_cls(self.filters, (3, 3), self.strides)(y)
y = self.norm_cls()(y)
y = self.act(y)
y = self.conv_cls(self.filters, (3, 3))(y)

if residual.shape != y.shape:
residual = self.conv_cls(self.filters, (1, 1), self.strides)(residual)

return residual + y


class MyGroupNorm(nn.GroupNorm):
def __call__(self, x):
if x.ndim == 3:
x = x[jnp.newaxis]
x = super().__call__(x)
return x[0]
else:
return super().__call__(x)

class ResNetV2Encoder(nn.Module):
"""ResNetV2."""

stage_sizes: Tuple[int]
num_filters: int = 64
dtype: Any = jnp.float32
act: Callable = nn.relu

@nn.compact
def __call__(self, x):
if x.ndim == 3:
x_0 = x[:x.shape[0] // 2]
x_1 = x[x.shape[0] // 2:]
elif x.ndim == 4:
x_0 = x[:, :x.shape[1] // 2, ...]
x_1 = x[:, x.shape[1] // 2:, ...]
conv_cls = partial(
nn.Conv, use_bias=False, dtype=self.dtype, kernel_init=default_init()
)
norm_cls = partial(MyGroupNorm, num_groups=4, epsilon=1e-5, dtype=self.dtype)

if x.shape[-2] == 224:
x = conv_cls(self.num_filters, (7, 7), (2, 2), padding=[(3, 3), (3, 3)])(x)
x = nn.max_pool(x, (3, 3), strides=(2, 2), padding="SAME")
else:
x = conv_cls(self.num_filters, (3, 3))(x)

for i, block_size in enumerate(self.stage_sizes):
for j in range(block_size):
strides = (2, 2) if i > 0 and j == 0 else (1, 1)
x = ResNetV2Block(
self.num_filters * 2**i,
strides=strides,
conv_cls=conv_cls,
norm_cls=norm_cls,
act=self.act,
)(x)

x = norm_cls()(x)
x = self.act(x)
return x.reshape((*x.shape[:-3], -1))

class TwoResNetV2Encoder(nn.Module):
"""ResNetV2."""

stage_sizes: Tuple[int]
num_filters: int = 64
dtype: Any = jnp.float32
act: Callable = nn.relu

@nn.compact
def __call__(self, x):
if x.ndim == 3:
x_0 = x[:x.shape[0] // 2]
x_1 = x[x.shape[0] // 2:]
elif x.ndim == 4:
x_0 = x[:, :x.shape[1] // 2, ...]
x_1 = x[:, x.shape[1] // 2:, ...]

conv_cls = partial(
nn.Conv, use_bias=False, dtype=self.dtype, kernel_init=default_init()
)
norm_cls = partial(MyGroupNorm, num_groups=4, epsilon=1e-5, dtype=self.dtype)
if x_0.shape[-2] == 224:
x_0 = conv_cls(self.num_filters, (7, 7), (2, 2), padding=[(3, 3), (3, 3)])(x_0)
x_0 = nn.max_pool(x_0, (3, 3), strides=(2, 2), padding="SAME")
x_1 = conv_cls(self.num_filters, (7, 7), (2, 2), padding=[(3, 3), (3, 3)])(x_1)
x_1 = nn.max_pool(x_1, (3, 3), strides=(2, 2), padding="SAME")
else:
x_0 = conv_cls(self.num_filters, (3, 3))(x_0)
x_1 = conv_cls(self.num_filters, (3, 3))(x_1)

for i, block_size in enumerate(self.stage_sizes):
for j in range(block_size):
strides = (2, 2) if i > 0 and j == 0 else (1, 1)
x_0 = ResNetV2Block(
self.num_filters * 2**i,
strides=strides,
conv_cls=conv_cls,
norm_cls=norm_cls,
act=self.act,
)(x_0)
x_1 = ResNetV2Block(
self.num_filters * 2**i,
strides=strides,
conv_cls=conv_cls,
norm_cls=norm_cls,
act=self.act,
)(x_1)

x_0 = norm_cls()(x_0)
x_0 = self.act(x_0)
x_1 = norm_cls()(x_1)
x_1 = self.act(x_1)
x_0 = x_0.reshape((*x_0.shape[:-3], -1))
x_1 = x_1.reshape((*x_1.shape[:-3], -1))
return x_0, x_1
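
Usage note (not part of the commit): a rough sketch of invoking the two new encoders. TwoResNetV2Encoder splits its input in half along the height axis and runs a separate ResNet stream over each half, so the example stacks two views vertically; resolutions and stage sizes are assumptions.

# Illustrative sketch only; resolutions and stage sizes are assumptions.
import jax
import jax.numpy as jnp

from serl.networks.encoders.ln_resnet_encoder import ResNetV2Encoder, TwoResNetV2Encoder

rng = jax.random.PRNGKey(0)
images = jnp.ones((4, 224, 224, 3))   # NHWC batch for the single-stream encoder
stacked = jnp.ones((4, 448, 224, 3))  # two 224x224 views stacked along the height axis

single = ResNetV2Encoder(stage_sizes=(2, 2, 2, 2))
feats = single.apply(single.init(rng, images), images)  # flattened per-image features

two = TwoResNetV2Encoder(stage_sizes=(2, 2, 2, 2))
feats_0, feats_1 = two.apply(two.init(rng, stacked), stacked)  # one feature vector per view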

48 changes: 48 additions & 0 deletions serl/networks/encoders/mobilenet_encoder.py
@@ -0,0 +1,48 @@
from typing import Sequence, Callable
import flax.linen as nn
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
import numpy as np
import jax


class MobileNetEncoder(nn.Module):
mobilenet: Callable[..., Callable]
params: FrozenDict
stop_gradient: bool = False

@nn.compact
def __call__(self, x: jnp.ndarray, training=False, divide_by=False, reshape=False) -> jnp.ndarray:
'''
encode an image using the mobilenet encoder
TODO: it should work for all pretrained encoders, not just mobilenet.
:param x: input image
:param training: whether the network is in training mode
:param divide_by: whether to divide the image by 255
:param reshape: whether to reshape the image before passing into encoder
:return: the encoded image
'''

mean = jnp.array((0.485, 0.456, 0.406))[None, ...]
std = jnp.array((0.229, 0.224, 0.225))[None, ...]

if reshape:
x = jnp.reshape(x, (*x.shape[:-2], -1))

if divide_by:
x = x.astype(jnp.float32) / 255.0
x = (x - mean) / std

if x.ndim == 3:
x = x[None, ...]
x = self.mobilenet.apply(self.params, x, mutable=False, training=False)
elif x.ndim == 4:
x = self.mobilenet.apply(self.params, x, mutable=False, training=False)
else:
raise NotImplementedError('ndim is not 3 or 4')

if self.stop_gradient:
x = jax.lax.stop_gradient(x)

return x
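
Usage note (not part of the commit): MobileNetEncoder wraps an externally loaded, pretrained backbone, and the commit does not show where that backbone comes from. The sketch below substitutes a tiny hypothetical DummyBackbone purely to illustrate how the module and its frozen params are passed in; the real code would pass the actual pretrained MobileNet module and weights instead.

# Illustrative sketch only; DummyBackbone is a hypothetical stand-in for the pretrained MobileNet.
import jax
import jax.numpy as jnp
import flax.linen as nn

from serl.networks.encoders.mobilenet_encoder import MobileNetEncoder

class DummyBackbone(nn.Module):
    # Any Flax module whose apply(params, x, mutable=False, training=...) returns features would work here.
    @nn.compact
    def __call__(self, x, training=False):
        x = nn.Conv(32, (3, 3), strides=(2, 2))(x)
        return x.mean(axis=(1, 2))  # global-average-pooled features

backbone = DummyBackbone()
image = jnp.ones((1, 224, 224, 3), dtype=jnp.float32)
backbone_params = backbone.init(jax.random.PRNGKey(0), image)

encoder = MobileNetEncoder(mobilenet=backbone, params=backbone_params, stop_gradient=True)
variables = encoder.init(jax.random.PRNGKey(1), image, divide_by=True)
features = encoder.apply(variables, image, divide_by=True)
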
11 changes: 2 additions & 9 deletions serl/networks/mlp.py
@@ -2,7 +2,6 @@

import flax.linen as nn
import jax.numpy as jnp
from serl.networks.spectral import SpectralNormalization

default_init = nn.initializers.xavier_uniform

@@ -21,15 +20,9 @@ def __call__(self, x: jnp.ndarray, training: bool = False) -> jnp.ndarray:

for i, size in enumerate(self.hidden_dims):
if i + 1 == len(self.hidden_dims) and self.scale_final is not None:
if self.spectral_norm:
x = SpectralNormalization(nn.Dense(size, kernel_init=default_init(self.scale_final)))(x, training=training)
else:
x = nn.Dense(size, kernel_init=default_init(self.scale_final))(x)
x = nn.Dense(size, kernel_init=default_init(self.scale_final))(x)
else:
if self.spectral_norm:
x = SpectralNormalization(nn.Dense(size, kernel_init=default_init()))(x, training=training)
else:
x = nn.Dense(size, kernel_init=default_init())(x)
x = nn.Dense(size, kernel_init=default_init())(x)

if i + 1 < len(self.hidden_dims) or self.activate_final:
if self.dropout_rate is not None and self.dropout_rate > 0:
22 changes: 22 additions & 0 deletions serl/networks/nd_output.py
@@ -0,0 +1,22 @@
import flax.linen as nn
import jax.numpy as jnp

from serl.networks import default_init


class NDimOutput(nn.Module):
n_dim: int
base_cls: nn.Module
spectral_norm: bool = False

@nn.compact
def __call__(
self, observations: jnp.ndarray, *args, **kwargs
) -> jnp.ndarray:
if self.base_cls:
outputs = self.base_cls()(observations, *args, **kwargs)
else:
outputs = observations

value = nn.Dense(self.n_dim, kernel_init=default_init())(outputs)
return value
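
Usage note (not part of the commit): NDimOutput simply appends an n_dim Dense head to an optional base network. A rough sketch wrapping the MLP from this repo follows; the hidden sizes and the 2-dimensional output are assumptions (e.g. logits for a binary classifier head), not values from the commit.

# Illustrative sketch only; hidden sizes and n_dim are assumptions.
import functools
import jax
import jax.numpy as jnp

from serl.networks import MLP, NDimOutput

head = NDimOutput(
    n_dim=2,
    base_cls=functools.partial(MLP, hidden_dims=(256, 256), activate_final=True),
)
obs = jnp.ones((8, 64))  # batch of hypothetical feature vectors
variables = head.init(jax.random.PRNGKey(0), obs, training=False)
logits = head.apply(variables, obs, training=False)  # shape (8, 2)
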
2 changes: 1 addition & 1 deletion serl/networks/one_d_output.py
@@ -2,7 +2,7 @@
import jax.numpy as jnp

from serl.networks import default_init
from serl.networks.spectral import SpectralNormalization


class OneDimOutput(nn.Module):
base_cls: nn.Module
