Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
priyakasimbeg committed Dec 7, 2023
1 parent 73b72b1 commit 174f0a4
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,6 @@ def init_model_fn(
dropout_rate=dropout_rate)
variables = jax.jit(self._model.init)({'params': rng}, fake_batch)
params = variables['params']
tabulate_fn = nn.tabulate(
self._model,
jax.random.PRNGKey(0),
console_kwargs={
'force_terminal': False, 'force_jupyter':False, 'width':240
}
)
print(tabulate_fn(fake_batch, train=False))
self._param_shapes = param_utils.jax_param_shapes(params)
self._param_types = param_utils.jax_param_types(self._param_shapes)
params = jax_utils.replicate(params)
Expand Down
3 changes: 2 additions & 1 deletion tests/modeldiffs/fastmri_layernorm/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ def sort_key(k):

keys = sorted(sd.keys(), key=sort_key)
c = 0
jax_weights_name = 'kernel'
for idx, k in enumerate(keys):
new_key = []
jax_weights_name = 'kernel'
for idx2, i in enumerate(k):
if 'ModuleList' in i or 'Sequential' in i:
continue
Expand All @@ -44,6 +44,7 @@ def sort_key(k):
i = i.replace('ConvTranspose2d', 'ConvTranspose')
if 'LayerNorm' in i:
jax_weights_name = 'scale'
continue
if 'weight' in i:
i = i.replace('weight', jax_weights_name)
new_key.append(i)
Expand Down

0 comments on commit 174f0a4

Please sign in to comment.