Skip to content

Commit

Permalink
feat: use nnx.Sequential over ModuleList (#5)
Browse files Browse the repository at this point in the history
Part of #1
  • Loading branch information
SauravMaheshkar authored Oct 2, 2024
1 parent e0cec84 commit 5da49d7
Showing 1 changed file with 23 additions and 25 deletions.
48 changes: 23 additions & 25 deletions jflux/autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,16 +340,13 @@ def __init__(
curr_res = resolution
in_ch_mult = (1,) + tuple(ch_mult)
self.in_ch_mult = in_ch_mult
# FIXME: Use nnx.Sequential instead
self.down = nnx.ModuleList()
block_in = self.ch
for i_level in range(self.num_resolutions):
# FIXME: Use nnx.Sequential instead
block = nnx.ModuleList()
blocks = []
block_in = ch * in_ch_mult[i_level]
block_out = ch * ch_mult[i_level]
for _ in range(self.num_res_blocks):
block.append(
blocks.append(
ResnetBlock(
in_channels=block_in,
out_channels=block_out,
Expand All @@ -359,17 +356,17 @@ def __init__(
)
)
block_in = block_out
down = nnx.Module()
down.block = block
if i_level != self.num_resolutions - 1:
down.downsample = Downsample(
in_channels=block_in,
rngs=rngs,
dtype=self.dtype,
param_dtype=self.param_dtype,
blocks.append(
Downsample(
in_channels=block_in,
rngs=rngs,
dtype=self.dtype,
param_dtype=self.param_dtype,
)
)
curr_res = curr_res // 2
self.down.append(down)
self.down = nnx.Sequential(*blocks)

# middle
self.middle = nnx.Sequential(
Expand Down Expand Up @@ -520,14 +517,12 @@ def __init__(
)

# upsampling
# FIXME: Use nnx.Sequential instead
self.up = nnx.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
# FIXME: Use nnx.Sequential instead
block = nnx.ModuleList()
blocks = []
block_out = ch * ch_mult[i_level]
for _ in range(self.num_res_blocks + 1):
block.append(
blocks.append(
ResnetBlock(
in_channels=block_in,
out_channels=block_out,
Expand All @@ -537,17 +532,20 @@ def __init__(
)
)
block_in = block_out
up = nnx.Module()
up.block = block

upsample_module = [*blocks]
if i_level != 0:
up.upsample = Upsample(
in_channels=block_in,
rngs=rngs,
dtype=self.dtype,
param_dtype=self.param_dtype,
upsample_module.append(
Upsample(
in_channels=block_in,
rngs=rngs,
dtype=self.dtype,
param_dtype=self.param_dtype,
)
)
curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order

self.up = nnx.Sequential(*upsample_module)

# end
self.norm_out = nnx.GroupNorm(
Expand Down

0 comments on commit 5da49d7

Please sign in to comment.