Commit 42d8a0b: fastmri variant isolating
priyakasimbeg committed Nov 28, 2023 (1 parent: 6885e69)
Showing 1 changed file with 37 additions and 18 deletions.

algorithmic_efficiency/workloads/fastmri/fastmri_jax/models.py (37 additions, 18 deletions)
@@ -12,6 +12,7 @@
Data:
github.com/facebookresearch/fastMRI/tree/main/fastmri/data
"""
import functools
from typing import Optional

import flax.linen as nn
@@ -65,6 +66,8 @@ class UNet(nn.Module):
channels: int = 32
num_pool_layers: int = 4
dropout_rate: Optional[float] = 0.0 # If None, defaults to 0.0.
use_tanh: bool = False
use_layer_norm: bool = False

@nn.compact
def __call__(self, x, train=True):
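The two new fields switch the model between the baseline configuration and the variants this commit isolates: use_tanh swaps the leaky ReLU activation for tanh, and use_layer_norm swaps the parameter-free instance norm for nn.LayerNorm. A minimal usage sketch follows; the defaults of fields not shown in this hunk and the expected input shape are assumptions, not part of the diff.

import jax
import jax.numpy as jnp

# Hypothetical instantiation of the variant model; input shape is assumed.
model = UNet(use_tanh=True, use_layer_norm=True)
x = jnp.ones((4, 320, 320))  # (batch, height, width), assumed fastMRI-sized input
variables = model.init(jax.random.PRNGKey(0), x, train=False)
out = model.apply(variables, x, train=False)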
@@ -134,6 +137,8 @@ class ConvBlock(nn.Module):
"""
out_channels: int
dropout_rate: float
use_tanh: bool
use_layer_norm: bool

@nn.compact
def __call__(self, x, train=True):
@@ -149,29 +154,35 @@ def __call__(self, x, train=True):
features=self.out_channels,
kernel_size=(3, 3),
strides=(1, 1),
use_bias=False)(
x)
# InstanceNorm2d was run with no learnable params in reference code
# so this is a simple normalization along channels
x = _simple_instance_norm2d(x, (1, 2))
x = jax.nn.leaky_relu(x, negative_slope=0.2)
use_bias=False)(x)
if self.use_layer_norm:
  x = nn.LayerNorm()(x)
else:
  # InstanceNorm2d was run with no learnable params in reference code
  # so this is a simple normalization along spatial dims.
  x = _instance_norm2d(x, (1, 2))
if self.use_tanh:
  activation_fn = nn.tanh
else:
  activation_fn = functools.partial(jax.nn.leaky_relu, negative_slope=0.2)
x = activation_fn(x)
# Ref code uses dropout2d which applies the same mask for the entire channel
# Replicated by using broadcast dims to have the same filter on HW
x = nn.Dropout(
self.dropout_rate, broadcast_dims=(1, 2), deterministic=not train)(
x)
self.dropout_rate, broadcast_dims=(1, 2), deterministic=not train)(x)
x = nn.Conv(
features=self.out_channels,
kernel_size=(3, 3),
strides=(1, 1),
use_bias=False)(
x)
x = _simple_instance_norm2d(x, (1, 2))
x = jax.nn.leaky_relu(x, negative_slope=0.2)
use_bias=False)(x)
if self.use_layer_norm:
  x = nn.LayerNorm()(x)
else:
  x = _instance_norm2d(x, (1, 2))
x = activation_fn(x)
x = nn.Dropout(
self.dropout_rate, broadcast_dims=(1, 2), deterministic=not train)(
x)

self.dropout_rate, broadcast_dims=(1, 2), deterministic=not train)(x)
return x
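For reference, _instance_norm2d is a helper defined elsewhere in this file and its definition is not part of this diff. A minimal sketch of a parameter-free instance norm over the spatial dimensions, assuming NHWC inputs and an epsilon value that may differ from the real helper:

import jax.numpy as jnp

def _instance_norm2d(x, axes, epsilon=1e-5):
  # Normalize each (example, channel) slice over the spatial axes only,
  # with no learnable scale or bias (sketch; the actual helper is not shown here).
  mean = jnp.mean(x, axis=axes, keepdims=True)
  var = jnp.var(x, axis=axes, keepdims=True)
  return (x - mean) / jnp.sqrt(var + epsilon)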


@@ -180,6 +191,8 @@ class TransposeConvBlock(nn.Module):
out_channels: Number of channels in the output.
"""
out_channels: int
use_tanh: bool
use_layer_norm: bool

@nn.compact
def __call__(self, x):
@@ -192,7 +205,13 @@ def __call__(self, x):
x = nn.ConvTranspose(
self.out_channels, kernel_size=(2, 2), strides=(2, 2), use_bias=False)(
x)
x = _simple_instance_norm2d(x, (1, 2))
x = jax.nn.leaky_relu(x, negative_slope=0.2)

if self.use_layer_norm:
  x = nn.LayerNorm()(x)
else:
  x = _instance_norm2d(x, (1, 2))
if self.use_tanh:
  activation_fn = nn.tanh
else:
  activation_fn = functools.partial(jax.nn.leaky_relu, negative_slope=0.2)
x = activation_fn(x)
return x
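UNet.__call__ is not shown in this diff, but it presumably forwards the new flags when it builds the blocks. A hedged sketch of what those call sites might look like inside the model; the wiring is assumed, not taken from the commit:

# Hypothetical call sites inside UNet.__call__ (wiring assumed, not shown in the diff).
x = ConvBlock(
    out_channels=self.channels,
    dropout_rate=self.dropout_rate,
    use_tanh=self.use_tanh,
    use_layer_norm=self.use_layer_norm)(x, train=train)
x = TransposeConvBlock(
    out_channels=self.channels,
    use_tanh=self.use_tanh,
    use_layer_norm=self.use_layer_norm)(x)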
