Commit: merge fixes
priyakasimbeg committed Nov 9, 2023
1 parent e4e2890 commit 472d349
Showing 3 changed files with 147 additions and 163 deletions.
169 changes: 85 additions & 84 deletions algorithmic_efficiency/workloads/fastmri/fastmri_jax/models.py
@@ -39,91 +39,7 @@ def _instance_norm2d(x, axes, epsilon=1e-5):
return y
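The hunk above shows only the tail of `_instance_norm2d`. Going by the hunk header and the comment in ConvBlock below (a simple normalization along spatial dims, with no learnable parameters), a minimal standalone sketch of such a helper for NHWC inputs could look like the following; the name `_instance_norm2d_sketch` is illustrative, not from the repository:

import jax
import jax.numpy as jnp

def _instance_norm2d_sketch(x, axes=(1, 2), epsilon=1e-5):
  # Normalize each (sample, channel) slice over the spatial axes,
  # with no learnable scale or bias.
  mean = jnp.mean(x, axis=axes, keepdims=True)
  var = jnp.mean(jnp.square(x - mean), axis=axes, keepdims=True)
  return (x - mean) * jax.lax.rsqrt(var + epsilon)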


class ConvBlock(nn.Module):
"""A Convolutional Block.
out_channels: Number of channels in the output.
dropout_rate: Dropout probability.
"""
out_channels: int
dropout_rate: float
use_tanh: bool
use_layer_norm: bool

@nn.compact
def __call__(self, x, train=True):
"""Forward function.
Note: PyTorch is NCHW and JAX/Flax is NHWC.
Args:
x: Input 4D tensor of shape `(N, H, W, in_channels)`.
train: Whether to run in training mode (dropout enabled); follows init2winit naming.
Returns:
jnp.array: Output tensor of shape `(N, H, W, out_channels)`.
"""
x = nn.Conv(
features=self.out_channels,
kernel_size=(3, 3),
strides=(1, 1),
use_bias=False)(x)
if self.use_layer_norm:
x = nn.LayerNorm()(x)
else:
# InstanceNorm2d was run with no learnable params in the reference code,
# so this is a simple normalization along the spatial dims.
x = _instance_norm2d(x, (1, 2))
if self.use_tanh:
activation_fn = nn.tanh
else:
activation_fn = functools.partial(jax.nn.leaky_relu, negative_slope=0.2)
x = activation_fn(x)
# The reference code uses Dropout2d, which applies the same mask to an entire
# channel. Replicated here with broadcast_dims so the mask is shared across H and W.
x = nn.Dropout(
self.dropout_rate, broadcast_dims=(1, 2), deterministic=not train)(x)
x = nn.Conv(
features=self.out_channels,
kernel_size=(3, 3),
strides=(1, 1),
use_bias=False)(x)
if self.use_layer_norm:
x = nn.LayerNorm()(x)
else:
x = _instance_norm2d(x, (1, 2))
x = activation_fn(x)
x = nn.Dropout(
self.dropout_rate, broadcast_dims=(1, 2), deterministic=not train)(x)
return x


class TransposeConvBlock(nn.Module):
"""A Transpose Convolutional Block.
out_channels: Number of channels in the output.
"""
out_channels: int
use_tanh: bool
use_layer_norm: bool

@nn.compact
def __call__(self, x):
"""Forward function.
Args:
x: Input 4D tensor of shape `(N, H, W, in_channels)`.
Returns:
jnp.array: Output tensor of shape `(N, H*2, W*2, out_channels)`.
"""
x = nn.ConvTranspose(
self.out_channels, kernel_size=(2, 2), strides=(2, 2), use_bias=False)(
x)
if self.use_layer_norm:
x = nn.LayerNorm()(x)
else:
x = _instance_norm2d(x, (1, 2))
if self.use_tanh:
activation_fn = nn.tanh
else:
activation_fn = functools.partial(jax.nn.leaky_relu, negative_slope=0.2)
x = activation_fn(x)
return x


class UNet(nn.Module):
@@ -213,3 +129,88 @@ def __call__(self, x, train=True):
out_channels = 1
output = nn.Conv(out_channels, kernel_size=(1, 1), strides=(1, 1))(output)
return output.squeeze(-1)
class ConvBlock(nn.Module):
"""A Convolutional Block.
out_channels: Number of channels in the output.
dropout_rate: Dropout probability.
"""
out_channels: int
dropout_rate: float
use_tanh: bool
use_layer_norm: bool

@nn.compact
def __call__(self, x, train=True):
"""Forward function.
Note: PyTorch is NCHW and JAX/Flax is NHWC.
Args:
x: Input 4D tensor of shape `(N, H, W, in_channels)`.
train: Whether to run in training mode (dropout enabled); follows init2winit naming.
Returns:
jnp.array: Output tensor of shape `(N, H, W, out_channels)`.
"""
x = nn.Conv(
features=self.out_channels,
kernel_size=(3, 3),
strides=(1, 1),
use_bias=False)(x)
if self.use_layer_norm:
x = nn.LayerNorm()(x)
else:
# InstanceNorm2d was run with no learnable params in the reference code,
# so this is a simple normalization along the spatial dims.
x = _instance_norm2d(x, (1, 2))
if self.use_tanh:
activation_fn = nn.tanh
else:
activation_fn = functools.partial(jax.nn.leaky_relu, negative_slope=0.2)
x = activation_fn(x)
# The reference code uses Dropout2d, which applies the same mask to an entire
# channel. Replicated here with broadcast_dims so the mask is shared across H and W.
x = nn.Dropout(
self.dropout_rate, broadcast_dims=(1, 2), deterministic=not train)(x)
x = nn.Conv(
features=self.out_channels,
kernel_size=(3, 3),
strides=(1, 1),
use_bias=False)(x)
if self.use_layer_norm:
x = nn.LayerNorm()(x)
else:
x = _instance_norm2d(x, (1, 2))
x = activation_fn(x)
x = nn.Dropout(
self.dropout_rate, broadcast_dims=(1, 2), deterministic=not train)(x)
return x


class TransposeConvBlock(nn.Module):
"""A Transpose Convolutional Block.
out_channels: Number of channels in the output.
"""
out_channels: int
use_tanh: bool
use_layer_norm: bool

@nn.compact
def __call__(self, x):
"""Forward function.
Args:
x: Input 4D tensor of shape `(N, H, W, in_channels)`.
Returns:
jnp.array: Output tensor of shape `(N, H*2, W*2, out_channels)`.
"""
x = nn.ConvTranspose(
self.out_channels, kernel_size=(2, 2), strides=(2, 2), use_bias=False)(
x)
if self.use_layer_norm:
x = nn.LayerNorm()(x)
else:
x = _instance_norm2d(x, (1, 2))
if self.use_tanh:
activation_fn = nn.tanh
else:
activation_fn = functools.partial(jax.nn.leaky_relu, negative_slope=0.2)
x = activation_fn(x)
return x
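The `broadcast_dims=(1, 2)` argument used above is what replicates PyTorch's Dropout2d in Flax: the dropout mask is drawn once per (sample, channel) and broadcast across the H and W axes. A minimal standalone sketch of that behavior (illustrative, not part of this commit):

import jax
import jax.numpy as jnp
from flax import linen as nn

x = jnp.ones((2, 8, 8, 3))  # NHWC
dropout = nn.Dropout(rate=0.5, broadcast_dims=(1, 2), deterministic=False)
y = dropout.apply({}, x, rngs={'dropout': jax.random.PRNGKey(0)})
# Each (sample, channel) slice y[n, :, :, c] is either all zeros or all 2.0
# (kept values are scaled by 1 / (1 - rate)), matching Dropout2d.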
120 changes: 62 additions & 58 deletions algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/models.py
@@ -14,64 +14,6 @@
from algorithmic_efficiency import init_utils


class ConvBlock(nn.Module):
# A Convolutional Block that consists of two convolution layers, each
# followed by normalization, activation (LeakyReLU or tanh) and dropout.

def __init__(self,
in_chans: int,
out_chans: int,
dropout_rate: float,
use_tanh: bool,
use_layer_norm: bool) -> None:
super().__init__()

if use_layer_norm:
norm_layer = nn.LayerNorm
else:
norm_layer = nn.InstanceNorm2d
if use_tanh:
activation_fn = nn.Tanh()  # nn.Tanh takes no inplace argument
else:
activation_fn = nn.LeakyReLU(negative_slope=0.2, inplace=True)
self.conv_layers = nn.Sequential(
nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False),
norm_layer(out_chans),
activation_fn,
nn.Dropout2d(dropout_rate),
nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False),
norm_layer(out_chans),
activation_fn,
nn.Dropout2d(dropout_rate),
)

def forward(self, x: Tensor) -> Tensor:
return self.conv_layers(x)


class TransposeConvBlock(nn.Module):
# A Transpose Convolutional Block that consists of one transpose convolution
# layer followed by normalization and activation (LeakyReLU or tanh).

def __init__(self,
in_chans: int,
out_chans: int,
use_tanh: bool,
use_layer_norm: bool):
super().__init__()
if use_layer_norm:
norm_layer = nn.LayerNorm
else:
norm_layer = nn.InstanceNorm2d
if use_tanh:
activation_fn = nn.Tanh()  # nn.Tanh takes no inplace argument
else:
activation_fn = nn.LeakyReLU(negative_slope=0.2, inplace=True)
self.layers = nn.Sequential(
nn.ConvTranspose2d(
in_chans, out_chans, kernel_size=2, stride=2, bias=False),
norm_layer(out_chans),
activation_fn,
)

def forward(self, x: Tensor) -> Tensor:
return self.layers(x)


class UNet(nn.Module):
@@ -83,6 +25,7 @@ class UNet(nn.Module):

def __init__(self,
in_chans: int = 1,
out_chans: int = 1,
chans: int = 32,
num_pool_layers: int = 4,
dropout_rate: Optional[float] = 0.0,
@@ -91,6 +34,7 @@ def __init__(self,
super().__init__()

self.in_chans = in_chans
self.out_chans = out_chans
self.chans = chans
self.num_pool_layers = num_pool_layers
if dropout_rate is None:
@@ -157,3 +101,63 @@ def forward(self, x: Tensor) -> Tensor:
output = conv(output)

return output


class ConvBlock(nn.Module):
# A Convolutional Block that consists of two convolution layers, each
# followed by normalization, activation (LeakyReLU or tanh) and dropout.

def __init__(self,
in_chans: int,
out_chans: int,
dropout_rate: float,
use_tanh: bool,
use_layer_norm: bool) -> None:
super().__init__()

if use_layer_norm:
norm_layer = nn.LayerNorm
else:
norm_layer = nn.InstanceNorm2d
if use_tanh:
activation_fn = nn.Tanh()  # nn.Tanh takes no inplace argument
else:
activation_fn = nn.LeakyReLU(negative_slope=0.2, inplace=True)
self.conv_layers = nn.Sequential(
nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False),
norm_layer(out_chans),
activation_fn,
nn.Dropout2d(dropout_rate),
nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False),
norm_layer(out_chans),
activation_fn,
nn.Dropout2d(dropout_rate),
)

def forward(self, x: Tensor) -> Tensor:
return self.conv_layers(x)


class TransposeConvBlock(nn.Module):
# A Transpose Convolutional Block that consists of one transpose convolution
# layer followed by normalization and activation (LeakyReLU or tanh).

def __init__(self,
in_chans: int,
out_chans: int,
use_tanh: bool,
use_layer_norm: bool):
super().__init__()
if use_layer_norm:
norm_layer = nn.LayerNorm
else:
norm_layer = nn.InstanceNorm2d
if use_tanh:
activation_fn = nn.Tanh()  # nn.Tanh takes no inplace argument
else:
activation_fn = nn.LeakyReLU(negative_slope=0.2, inplace=True)
self.layers = nn.Sequential(
nn.ConvTranspose2d(
in_chans, out_chans, kernel_size=2, stride=2, bias=False),
norm_layer(out_chans),
activation_fn,
)

def forward(self, x: Tensor) -> Tensor:
return self.layers(x)
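A quick shape check for the PyTorch blocks above, as an illustrative sketch rather than code from this commit; it assumes the classes are in scope and that TransposeConvBlock takes the same use_tanh/use_layer_norm flags as ConvBlock:

import torch

block = ConvBlock(in_chans=1, out_chans=32, dropout_rate=0.0,
                  use_tanh=False, use_layer_norm=False)
up = TransposeConvBlock(in_chans=32, out_chans=16,
                        use_tanh=False, use_layer_norm=False)
x = torch.randn(2, 1, 64, 64)  # NCHW
h = block(x)  # (2, 32, 64, 64): the 3x3 convs use padding=1, so H and W are preserved
y = up(h)     # (2, 16, 128, 128): the stride-2 transpose conv doubles H and W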

This file was deleted.
