
Commit

address comments
priyakasimbeg committed Dec 8, 2023
1 parent 2d08326 commit 66430c8
Showing 4 changed files with 7 additions and 8 deletions.
algorithmic_efficiency/workloads/fastmri/fastmri_jax/models.py
@@ -55,6 +55,7 @@ class UNet(nn.Module):
   """
   num_channels: int = 32
   num_pool_layers: int = 4
+  out_channels = 1
   dropout_rate: Optional[float] = 0.0  # If None, defaults to 0.0.
   use_tanh: bool = False
   use_layer_norm: bool = False
@@ -125,8 +126,7 @@ def __call__(self, x, train=True):
     output = jnp.concatenate((output, downsample_layer), axis=-1)
     output = conv(output, train)

-    out_channels = 1
-    output = nn.Conv(out_channels, kernel_size=(1, 1), strides=(1, 1))(output)
+    output = nn.Conv(self.out_channels, kernel_size=(1, 1), strides=(1, 1))(output)
     return output.squeeze(-1)


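The net effect in the JAX model: out_channels moves from a local variable in __call__ to a class-level attribute read as self.out_channels. Because it has no type annotation, flax's dataclass transform treats it as a plain class attribute rather than a constructor field, so UNet(out_channels=2) would raise a TypeError. A minimal usage sketch, not part of the commit, with the input shape assumed from the workload's fake_batch below:

    # Sketch only: exercises the changed 1x1 output conv with default settings.
    import jax
    import jax.numpy as jnp
    from algorithmic_efficiency.workloads.fastmri.fastmri_jax.models import UNet

    model = UNet(num_pool_layers=4, num_channels=32)
    x = jnp.zeros((1, 320, 320))
    variables = model.init(jax.random.PRNGKey(0), x, train=False)
    out = model.apply(variables, x, train=False)
    print(out.shape)  # (1, 320, 320): the trailing singleton channel is squeezed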
algorithmic_efficiency/workloads/fastmri/fastmri_jax/workload.py
@@ -11,7 +11,7 @@
 from algorithmic_efficiency import param_utils
 from algorithmic_efficiency import spec
 import algorithmic_efficiency.random_utils as prng
-from algorithmic_efficiency.workloads.fastmri.fastmri_jax import models
+from algorithmic_efficiency.workloads.fastmri.fastmri_jax.models import UNet
 from algorithmic_efficiency.workloads.fastmri.fastmri_jax.ssim import ssim
 from algorithmic_efficiency.workloads.fastmri.workload import \
     BaseFastMRIWorkload
@@ -27,7 +27,7 @@ def init_model_fn(
     """aux_dropout_rate is unused."""
     del aux_dropout_rate
     fake_batch = jnp.zeros((13, 320, 320))
-    self._model = models.UNet(
+    self._model = UNet(
         num_pool_layers=self.num_pool_layers,
         num_channels=self.num_channels,
         use_tanh=self.use_tanh,
@@ -165,7 +165,6 @@ def _eval_model_on_split(self,


 class FastMRIModelSizeWorkload(FastMRIWorkload):
-
   @property
   def num_pool_layers(self) -> int:
     """Number of pooling layers in the UNet."""
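Since the workload now imports the class directly instead of the models module, a hypothetical smoke test looks like the sketch below. FastMRIWorkload and prng.PRNGKey are assumed from this file's context; neither appears in the diff itself.

    # Sketch only: initializes the JAX UNet through the workload.
    import algorithmic_efficiency.random_utils as prng
    from algorithmic_efficiency.workloads.fastmri.fastmri_jax.workload import \
        FastMRIWorkload

    workload = FastMRIWorkload()
    params, model_state = workload.init_model_fn(prng.PRNGKey(0))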
algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/models.py
@@ -85,7 +85,7 @@ def __init__(self,
     self.up_conv.append(
         nn.Sequential(
             ConvBlock(ch * 2, ch, dropout_rate, use_tanh, use_layer_norm),
-            nn.Conv2d(ch, 1, kernel_size=1, stride=1),
+            nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1),
         ))

     for m in self.modules():
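The PyTorch change mirrors the JAX one: the final 1x1 convolution's output width now comes from the module's out_chans attribute instead of a hard-coded 1. A small sanity check, not from the commit, assuming UNet's constructor defaults (out_chans=1) in this repo:

    # Sketch only: confirms the head conv honors out_chans.
    import torch
    from algorithmic_efficiency.workloads.fastmri.fastmri_pytorch.models import UNet

    model = UNet(num_pool_layers=4, num_channels=32)
    head = model.up_conv[-1][-1]  # the nn.Conv2d shown in the diff above
    assert isinstance(head, torch.nn.Conv2d)
    print(head.out_channels)  # 1, i.e. the model's out_chans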
algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/workload.py
@@ -13,7 +13,7 @@
 from algorithmic_efficiency import pytorch_utils
 from algorithmic_efficiency import spec
 import algorithmic_efficiency.random_utils as prng
-from algorithmic_efficiency.workloads.fastmri.fastmri_pytorch import models
+from algorithmic_efficiency.workloads.fastmri.fastmri_pytorch.models import UNet
 from algorithmic_efficiency.workloads.fastmri.fastmri_pytorch.ssim import ssim
 from algorithmic_efficiency.workloads.fastmri.workload import \
     BaseFastMRIWorkload
@@ -113,7 +113,7 @@ def init_model_fn(
       aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState:
     del aux_dropout_rate
     torch.random.manual_seed(rng[0])
-    model = models.UNet(
+    model = UNet(
         num_pool_layers=self.num_pool_layers,
         num_channels=self.num_channels,
         use_tanh=self.use_tanh,
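A hypothetical smoke test mirroring the JAX one above; FastMRIWorkload is assumed to be the class defined in this file, and init_model_fn seeds torch from rng[0], so any two-element integer array serves as the rng here:

    # Sketch only: initializes the PyTorch UNet through the workload.
    import numpy as np
    from algorithmic_efficiency.workloads.fastmri.fastmri_pytorch.workload import \
        FastMRIWorkload

    workload = FastMRIWorkload()
    params, model_state = workload.init_model_fn(np.array([0, 0], dtype=np.int64))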
