From 5da49d7e194206e91a54396bf23c17ca959196c6 Mon Sep 17 00:00:00 2001 From: Saurav Maheshkar Date: Wed, 2 Oct 2024 15:06:24 +0100 Subject: [PATCH] feat: use `nnx.Sequential` over `ModuleList` (#5) Part of #1 --- jflux/autoencoder.py | 48 +++++++++++++++++++++----------------------- 1 file changed, 23 insertions(+), 25 deletions(-) diff --git a/jflux/autoencoder.py b/jflux/autoencoder.py index 22dbe07..3f1cd9f 100644 --- a/jflux/autoencoder.py +++ b/jflux/autoencoder.py @@ -340,16 +340,13 @@ def __init__( curr_res = resolution in_ch_mult = (1,) + tuple(ch_mult) self.in_ch_mult = in_ch_mult - # FIXME: Use nnx.Sequential instead - self.down = nnx.ModuleList() block_in = self.ch for i_level in range(self.num_resolutions): - # FIXME: Use nnx.Sequential instead - block = nnx.ModuleList() + blocks = [] block_in = ch * in_ch_mult[i_level] block_out = ch * ch_mult[i_level] for _ in range(self.num_res_blocks): - block.append( + blocks.append( ResnetBlock( in_channels=block_in, out_channels=block_out, @@ -359,17 +356,17 @@ def __init__( ) ) block_in = block_out - down = nnx.Module() - down.block = block if i_level != self.num_resolutions - 1: - down.downsample = Downsample( - in_channels=block_in, - rngs=rngs, - dtype=self.dtype, - param_dtype=self.param_dtype, + blocks.append( + Downsample( + in_channels=block_in, + rngs=rngs, + dtype=self.dtype, + param_dtype=self.param_dtype, + ) ) curr_res = curr_res // 2 - self.down.append(down) + self.down = nnx.Sequential(*blocks) # middle self.middle = nnx.Sequential( @@ -520,14 +517,12 @@ def __init__( ) # upsampling - # FIXME: Use nnx.Sequential instead self.up = nnx.ModuleList() for i_level in reversed(range(self.num_resolutions)): - # FIXME: Use nnx.Sequential instead - block = nnx.ModuleList() + blocks = [] block_out = ch * ch_mult[i_level] for _ in range(self.num_res_blocks + 1): - block.append( + blocks.append( ResnetBlock( in_channels=block_in, out_channels=block_out, @@ -537,17 +532,20 @@ def __init__( ) ) block_in = block_out - up = nnx.Module() - up.block = block + + upsample_module = [*blocks] if i_level != 0: - up.upsample = Upsample( - in_channels=block_in, - rngs=rngs, - dtype=self.dtype, - param_dtype=self.param_dtype, + upsample_module.append( + Upsample( + in_channels=block_in, + rngs=rngs, + dtype=self.dtype, + param_dtype=self.param_dtype, + ) ) curr_res = curr_res * 2 - self.up.insert(0, up) # prepend to get consistent order + + self.up = nnx.Sequential(*upsample_module) # end self.norm_out = nnx.GroupNorm(