From 42d8a0bd5d527a7d7097586a4e8bc21813fcd5dd Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Tue, 28 Nov 2023 18:05:31 +0000
Subject: [PATCH] fastmri variant isolating

---
 .../workloads/fastmri/fastmri_jax/models.py | 54 +++++++++++++------
 1 file changed, 36 insertions(+), 18 deletions(-)

diff --git a/algorithmic_efficiency/workloads/fastmri/fastmri_jax/models.py b/algorithmic_efficiency/workloads/fastmri/fastmri_jax/models.py
index 3f8359550..c0685d419 100644
--- a/algorithmic_efficiency/workloads/fastmri/fastmri_jax/models.py
+++ b/algorithmic_efficiency/workloads/fastmri/fastmri_jax/models.py
@@ -12,6 +12,7 @@
 Data:
 github.com/facebookresearch/fastMRI/tree/main/fastmri/data
 """
+import functools
 from typing import Optional
 
 import flax.linen as nn
@@ -65,6 +66,8 @@ class UNet(nn.Module):
   channels: int = 32
   num_pool_layers: int = 4
   dropout_rate: Optional[float] = 0.0  # If None, defaults to 0.0.
+  use_tanh: bool = False
+  use_layer_norm: bool = False
 
   @nn.compact
   def __call__(self, x, train=True):
@@ -134,6 +137,8 @@ class ConvBlock(nn.Module):
   """
   out_channels: int
   dropout_rate: float
+  use_tanh: bool
+  use_layer_norm: bool
 
   @nn.compact
   def __call__(self, x, train=True):
@@ -149,29 +154,34 @@ def __call__(self, x, train=True):
         features=self.out_channels,
         kernel_size=(3, 3),
         strides=(1, 1),
-        use_bias=False)(
-            x)
-    # InstanceNorm2d was run with no learnable params in reference code
-    # so this is a simple normalization along channels
-    x = _simple_instance_norm2d(x, (1, 2))
-    x = jax.nn.leaky_relu(x, negative_slope=0.2)
+        use_bias=False)(x)
+    if self.use_layer_norm:
+      x = nn.LayerNorm()(x)
+    else:
+      # InstanceNorm2d was run with no learnable params in reference code
+      # so this is a simple normalization along spatial dims.
+      x = _simple_instance_norm2d(x, (1, 2))
+    if self.use_tanh:
+      activation_fn = nn.tanh
+    else:
+      activation_fn = functools.partial(jax.nn.leaky_relu, negative_slope=0.2)
+    x = activation_fn(x)
     # Ref code uses dropout2d which applies the same mask for the entire channel
     # Replicated by using broadcast dims to have the same filter on HW
     x = nn.Dropout(
-        self.dropout_rate, broadcast_dims=(1, 2), deterministic=not train)(
-            x)
+        self.dropout_rate, broadcast_dims=(1, 2), deterministic=not train)(x)
     x = nn.Conv(
         features=self.out_channels,
         kernel_size=(3, 3),
         strides=(1, 1),
-        use_bias=False)(
-            x)
-    x = _simple_instance_norm2d(x, (1, 2))
-    x = jax.nn.leaky_relu(x, negative_slope=0.2)
+        use_bias=False)(x)
+    if self.use_layer_norm:
+      x = nn.LayerNorm()(x)
+    else:
+      x = _simple_instance_norm2d(x, (1, 2))
+    x = activation_fn(x)
     x = nn.Dropout(
-        self.dropout_rate, broadcast_dims=(1, 2), deterministic=not train)(
-            x)
-
+        self.dropout_rate, broadcast_dims=(1, 2), deterministic=not train)(x)
     return x
 
 
@@ -180,6 +190,8 @@ class TransposeConvBlock(nn.Module):
   out_channels: Number of channels in the output.
   """
   out_channels: int
+  use_tanh: bool
+  use_layer_norm: bool
 
   @nn.compact
   def __call__(self, x):
@@ -192,7 +204,13 @@ def __call__(self, x):
     x = nn.ConvTranspose(
         self.out_channels, kernel_size=(2, 2), strides=(2, 2), use_bias=False)(
             x)
-    x = _simple_instance_norm2d(x, (1, 2))
-    x = jax.nn.leaky_relu(x, negative_slope=0.2)
-
+    if self.use_layer_norm:
+      x = nn.LayerNorm()(x)
+    else:
+      x = _simple_instance_norm2d(x, (1, 2))
+    if self.use_tanh:
+      activation_fn = nn.tanh
+    else:
+      activation_fn = functools.partial(jax.nn.leaky_relu, negative_slope=0.2)
+    x = activation_fn(x)
     return x
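
For reference, here is a minimal, self-contained Flax sketch of the norm/activation switching the patch adds to the blocks. It is not part of the patch: the ConvBlockSketch name, the demo under __main__, and the body of _simple_instance_norm2d are illustrative assumptions (the diff only shows call sites for that helper, whose definition already lives in models.py).

# NOTE: illustrative sketch only, not code from the patch above.
import functools

import flax.linen as nn
import jax
import jax.numpy as jnp


def _simple_instance_norm2d(x, axes, epsilon=1e-5):
  # Assumed behavior of the existing helper: per-channel normalization over
  # the spatial axes with no learnable params (InstanceNorm2d, affine=False).
  mean = jnp.mean(x, axis=axes, keepdims=True)
  var = jnp.var(x, axis=axes, keepdims=True)
  return (x - mean) * jax.lax.rsqrt(var + epsilon)


class ConvBlockSketch(nn.Module):
  """One conv -> norm -> activation -> dropout step, with the new flags."""
  out_channels: int
  dropout_rate: float
  use_tanh: bool = False
  use_layer_norm: bool = False

  @nn.compact
  def __call__(self, x, train=True):
    x = nn.Conv(
        features=self.out_channels,
        kernel_size=(3, 3),
        strides=(1, 1),
        use_bias=False)(x)
    if self.use_layer_norm:
      x = nn.LayerNorm()(x)
    else:
      x = _simple_instance_norm2d(x, (1, 2))  # NHWC: normalize over H and W.
    if self.use_tanh:
      activation_fn = nn.tanh
    else:
      activation_fn = functools.partial(jax.nn.leaky_relu, negative_slope=0.2)
    x = activation_fn(x)
    # dropout2d equivalent: one mask per channel, broadcast over H and W.
    x = nn.Dropout(
        self.dropout_rate, broadcast_dims=(1, 2), deterministic=not train)(x)
    return x


if __name__ == '__main__':
  block = ConvBlockSketch(out_channels=32, dropout_rate=0.0, use_layer_norm=True)
  x = jnp.ones((1, 64, 64, 1))
  variables = block.init(jax.random.PRNGKey(0), x, train=False)
  y = block.apply(variables, x, train=False)
  print(y.shape)  # (1, 64, 64, 32)

The flags default to the baseline behavior (instance norm + LeakyReLU) on UNet, which appears intended to let held-out workload variants switch on LayerNorm or tanh without changing the base model path; note that ConvBlock and TransposeConvBlock receive them as fields with no defaults, so their call sites must pass the flags through.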