-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add and update network infrastructure to support the VICE agent
- Loading branch information
Showing
8 changed files
with
266 additions
and
17 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
from typing import Dict, Optional, Tuple, Type, Union | ||
|
||
import flax.linen as nn | ||
import jax | ||
import jax.numpy as jnp | ||
from flax.core.frozen_dict import FrozenDict | ||
|
||
from serl.networks import default_init | ||
from serl.networks.spatial import SpatialLearnedEmbeddings | ||
|
||
|
||
class EncodedEncoder(nn.Module):
    """Embed a pixel observation and hand the result to a downstream network.

    Processing order: ``SpatialLearnedEmbeddings`` -> dropout -> flatten ->
    optional stop-gradient -> ``Dense(512)`` -> ``LayerNorm`` -> ``tanh`` ->
    ``network_cls``.
    """

    # Factory for the downstream network; built with no arguments per call.
    network_cls: Type[nn.Module]
    # NOTE(review): this field is never read here; the projection width below
    # is the literal 512. Confirm whether it was meant to drive that width.
    latent_dim: int
    # When True, gradients are cut right after the flattened embedding.
    stop_gradient: bool = False
    # Key under which the image lives in the observation mapping.
    pixel_key: str = "pixels"
    dropout_rate: float = 0.1

    @nn.compact
    def __call__(
        self,
        observations: Union[FrozenDict, Dict],
        training: bool = False,
    ) -> jnp.ndarray:
        """Encode ``observations[pixel_key]`` and run ``network_cls`` on it.

        Args:
            observations: mapping that contains ``pixel_key``. The entry may
                be rank 3 (promoted to a batch of one) or already batched.
            training: enables dropout when True; also forwarded as the second
                positional argument to the ``network_cls`` instance.

        Returns:
            Whatever the ``network_cls`` instance returns.
        """
        pixels = FrozenDict(observations)[self.pixel_key]

        # A lone rank-3 observation gets a leading batch axis.
        if pixels.ndim == 3:
            pixels = pixels[None, :]

        embed = SpatialLearnedEmbeddings(*pixels.shape[1:], 8)
        features = embed(pixels)
        dropout = nn.Dropout(self.dropout_rate)
        features = dropout(features, deterministic=not training)

        # NOTE(review): any leading dimension of exactly 1 is flattened away,
        # so a genuine batch of one also loses its batch axis here.
        leading = features.shape[0]
        flat_shape = (-1,) if leading == 1 else (leading, -1)
        features = features.reshape(flat_shape)

        if self.stop_gradient:
            # We do not update conv layers with policy gradients.
            features = jax.lax.stop_gradient(features)

        projected = nn.Dense(512, kernel_init=default_init())(features)
        projected = nn.LayerNorm()(projected)
        projected = nn.tanh(projected)

        head = self.network_cls()
        return head(projected, training)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,4 @@ | ||
from serl.networks.encoders.d4pg_encoder import D4PGEncoder | ||
from serl.networks.encoders.two_d4pg_encoder import TwoD4PGEncoder | ||
from serl.networks.encoders.ln_resnet_encoder import ResNetV2Encoder, TwoResNetV2Encoder | ||
from serl.networks.encoders.two_resnet_encoder import TwoResNetEncoder | ||
from serl.networks.encoders.resnet_encoder import ResNetEncoder, ResNetAttentionEncoder | ||
from serl.networks.encoders.mobilenet_encoder import MobileNetEncoder | ||
from serl.networks.encoders.two_mobilenet_encoder import TwoMobileNetEncoder | ||
from serl.networks.encoders.pieg_gpt_encoder import PiegGPTEncoder |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,146 @@ | ||
# Based on: | ||
# https://github.com/google/flax/blob/main/examples/imagenet/models.py | ||
# and | ||
# https://github.com/google-research/big_transfer/blob/master/bit_jax/models.py | ||
from functools import partial | ||
from typing import Any, Callable, Tuple, Type | ||
|
||
import flax.linen as nn | ||
import jax.numpy as jnp | ||
from flax import linen as nn | ||
|
||
from serl.networks import default_init | ||
|
||
|
||
class ResNetV2Block(nn.Module):
    """Residual block that normalises and activates before each 3x3 conv."""

    filters: int
    conv_cls: Type[nn.Module]
    norm_cls: Type[nn.Module]
    act: Callable
    # Applied to the first conv (and to the projection on the skip path).
    strides: Tuple[int, int] = (1, 1)

    @nn.compact
    def __call__(self, x):
        """Return ``skip(x) + F(x)``, F being two norm -> act -> conv stages.

        The skip path is the identity unless the main path changed the shape,
        in which case ``x`` is projected with a strided 1x1 convolution.
        """
        shortcut = x

        # Submodules are created in the same order as before so that Flax's
        # automatic parameter names stay unchanged.
        out = self.act(self.norm_cls()(x))
        out = self.conv_cls(self.filters, (3, 3), self.strides)(out)
        out = self.act(self.norm_cls()(out))
        out = self.conv_cls(self.filters, (3, 3))(out)

        if shortcut.shape != out.shape:
            project = self.conv_cls(self.filters, (1, 1), self.strides)
            shortcut = project(shortcut)

        return shortcut + out
|
||
|
||
class MyGroupNorm(nn.GroupNorm):
    """``nn.GroupNorm`` variant that also accepts rank-3 inputs."""

    def __call__(self, x):
        """Normalise ``x``; a rank-3 input is wrapped in a temporary batch axis."""
        if x.ndim != 3:
            return super().__call__(x)

        # Add a leading axis for the parent call and strip it again afterwards.
        batched = super().__call__(x[jnp.newaxis])
        return batched[0]
|
||
class ResNetV2Encoder(nn.Module):
    """Single-stream ResNetV2 feature extractor built from ``ResNetV2Block``s.

    Attributes:
        stage_sizes: number of residual blocks in each stage.
        num_filters: filter count of the stem and of stage 0; stage ``i`` uses
            ``num_filters * 2**i``.
        dtype: dtype handed to every convolution and group norm.
        act: activation used inside the blocks and after the final norm.
    """

    stage_sizes: Tuple[int]
    num_filters: int = 64
    dtype: Any = jnp.float32
    act: Callable = nn.relu

    @nn.compact
    def __call__(self, x):
        """Encode ``x`` and flatten its last three axes into one.

        The previous version also sliced ``x`` into two halves that were never
        used; that dead code has been removed without changing the output.

        Args:
            x: input array. ``x.shape[-2] == 224`` selects the strided 7x7
                stem followed by max-pooling; any other size uses a single
                3x3 conv.

        Returns:
            The activations reshaped to ``(*x.shape[:-3], -1)``.
        """
        conv_cls = partial(
            nn.Conv, use_bias=False, dtype=self.dtype, kernel_init=default_init()
        )
        norm_cls = partial(MyGroupNorm, num_groups=4, epsilon=1e-5, dtype=self.dtype)

        if x.shape[-2] == 224:
            x = conv_cls(self.num_filters, (7, 7), (2, 2), padding=[(3, 3), (3, 3)])(x)
            x = nn.max_pool(x, (3, 3), strides=(2, 2), padding="SAME")
        else:
            x = conv_cls(self.num_filters, (3, 3))(x)

        for i, block_size in enumerate(self.stage_sizes):
            for j in range(block_size):
                # Downsample at the first block of every stage but the first.
                strides = (2, 2) if i > 0 and j == 0 else (1, 1)
                x = ResNetV2Block(
                    self.num_filters * 2**i,
                    strides=strides,
                    conv_cls=conv_cls,
                    norm_cls=norm_cls,
                    act=self.act,
                )(x)

        x = norm_cls()(x)
        x = self.act(x)
        return x.reshape((*x.shape[:-3], -1))
|
||
class TwoResNetV2Encoder(nn.Module):
    """Two-stream ResNetV2: the input is split in half and each half is encoded.

    Each half passes through its own stem, residual blocks and final norm
    (separate submodules are created for the two streams).

    Attributes:
        stage_sizes: number of residual blocks in each stage (per stream).
        num_filters: filter count of the stem and of stage 0; stage ``i`` uses
            ``num_filters * 2**i``.
        dtype: dtype handed to every convolution and group norm.
        act: activation used inside the blocks and after the final norm.
    """

    stage_sizes: Tuple[int]
    num_filters: int = 64
    dtype: Any = jnp.float32
    act: Callable = nn.relu

    @nn.compact
    def __call__(self, x):
        """Split ``x`` in two, encode both halves and return the flat features.

        Args:
            x: rank-3 array (split along axis 0) or rank-4 array (split along
                axis 1). With an odd length the second half gets the extra
                row. Presumably two images stacked along the height axis --
                confirm with the caller.

        Returns:
            ``(x_0, x_1)``; each has its last three axes flattened into one.

        Raises:
            ValueError: if ``x`` is neither rank 3 nor rank 4. (This used to
                surface as an ``UnboundLocalError`` further down.)
        """
        if x.ndim == 3:
            half = x.shape[0] // 2
            x_0 = x[:half]
            x_1 = x[half:]
        elif x.ndim == 4:
            half = x.shape[1] // 2
            x_0 = x[:, :half, ...]
            x_1 = x[:, half:, ...]
        else:
            raise ValueError(
                f"TwoResNetV2Encoder expects a rank-3 or rank-4 input, got rank {x.ndim}"
            )

        conv_cls = partial(
            nn.Conv, use_bias=False, dtype=self.dtype, kernel_init=default_init()
        )
        norm_cls = partial(MyGroupNorm, num_groups=4, epsilon=1e-5, dtype=self.dtype)

        # NOTE: submodules are created in the same interleaved order
        # (stream 0, then stream 1) as before; Flax derives parameter names
        # from creation order, so reordering would re-map existing weights.
        if x_0.shape[-2] == 224:
            x_0 = conv_cls(self.num_filters, (7, 7), (2, 2), padding=[(3, 3), (3, 3)])(x_0)
            x_0 = nn.max_pool(x_0, (3, 3), strides=(2, 2), padding="SAME")
            x_1 = conv_cls(self.num_filters, (7, 7), (2, 2), padding=[(3, 3), (3, 3)])(x_1)
            x_1 = nn.max_pool(x_1, (3, 3), strides=(2, 2), padding="SAME")
        else:
            x_0 = conv_cls(self.num_filters, (3, 3))(x_0)
            x_1 = conv_cls(self.num_filters, (3, 3))(x_1)

        for i, block_size in enumerate(self.stage_sizes):
            for j in range(block_size):
                # Downsample at the first block of every stage but the first.
                strides = (2, 2) if i > 0 and j == 0 else (1, 1)
                make_block = partial(
                    ResNetV2Block,
                    self.num_filters * 2**i,
                    strides=strides,
                    conv_cls=conv_cls,
                    norm_cls=norm_cls,
                    act=self.act,
                )
                x_0 = make_block()(x_0)
                x_1 = make_block()(x_1)

        x_0 = norm_cls()(x_0)
        x_0 = self.act(x_0)
        x_1 = norm_cls()(x_1)
        x_1 = self.act(x_1)

        x_0 = x_0.reshape((*x_0.shape[:-3], -1))
        x_1 = x_1.reshape((*x_1.shape[:-3], -1))
        return x_0, x_1
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
from typing import Sequence, Callable | ||
import flax.linen as nn | ||
import jax.numpy as jnp | ||
from flax.core.frozen_dict import FrozenDict | ||
import numpy as np | ||
import jax | ||
|
||
|
||
class MobileNetEncoder(nn.Module):
    """Wrapper that runs a stored encoder with stored parameters on an image."""

    # Object exposing ``.apply(params, x, mutable=..., training=...)``.
    mobilenet: Callable[..., Callable]
    # Parameters passed verbatim as the first argument of ``mobilenet.apply``.
    params: FrozenDict
    # When True, gradients are blocked on the encoder output.
    stop_gradient: bool = False

    @nn.compact
    def __call__(self, x: jnp.ndarray, training=False, divide_by=False, reshape=False) -> jnp.ndarray:
        '''
        Encode an image with the wrapped encoder.
        TODO: it should work for all pretrained encoders, not just mobilenet.
        :param x: input image; after the optional reshape it must be rank 3
            (a batch axis is prepended) or rank 4
        :param training: accepted for interface compatibility; the wrapped
            encoder is always applied with ``training=False``
        :param divide_by: if True, cast to float32, divide by 255 and
            standardise with the per-channel mean/std below
        :param reshape: if True, merge the last two axes of ``x`` first
        :return: the encoder output, with gradients stopped when
            ``stop_gradient`` is set
        :raises NotImplementedError: if ``x`` is not rank 3 or 4 at apply time
        '''
        # Per-channel statistics, shaped (1, 3) so they broadcast over the
        # last axis. Presumably the ImageNet mean/std -- confirm.
        channel_mean = jnp.array((0.485, 0.456, 0.406))[None, ...]
        channel_std = jnp.array((0.229, 0.224, 0.225))[None, ...]

        if reshape:
            merged_shape = (*x.shape[:-2], -1)
            x = jnp.reshape(x, merged_shape)

        if divide_by:
            scaled = x.astype(jnp.float32) / 255.0
            x = (scaled - channel_mean) / channel_std

        # The rank is checked only after the optional reshape, which changes it.
        if x.ndim not in (3, 4):
            raise NotImplementedError('ndim is not 3 or 4')
        if x.ndim == 3:
            x = x[None, ...]

        encoded = self.mobilenet.apply(self.params, x, mutable=False, training=False)

        return jax.lax.stop_gradient(encoded) if self.stop_gradient else encoded
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
import flax.linen as nn | ||
import jax.numpy as jnp | ||
|
||
from serl.networks import default_init | ||
|
||
|
||
class NDimOutput(nn.Module):
    """Dense head of width ``n_dim`` on top of an optional base network."""

    # Output width of the final dense layer.
    n_dim: int
    # NOTE(review): called with no arguments to build the base network, so in
    # practice this is a module factory; a falsy value skips the base network.
    base_cls: nn.Module
    # Not read anywhere in this module.
    spectral_norm: bool = False

    @nn.compact
    def __call__(
        self, observations: jnp.ndarray, *args, **kwargs
    ) -> jnp.ndarray:
        """Run the base network (if any) and project its output to ``n_dim``.

        Extra positional and keyword arguments are forwarded to the base
        network only.
        """
        features = (
            self.base_cls()(observations, *args, **kwargs)
            if self.base_cls
            else observations
        )

        head = nn.Dense(self.n_dim, kernel_init=default_init())
        return head(features)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters