Skip to content

Commit

Permalink
RF bridge, fix aux param handling
Browse files Browse the repository at this point in the history
  • Loading branch information
albertz committed Nov 15, 2023
1 parent c68855b commit 35adf52
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion returnn/torch/frontend/bridge.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -121,5 +121,12 @@ def _apply(self, fn):
# Sync each PyTorch-side parameter back into its corresponding RF Parameter.
# NOTE(review): this is a fragment of RFModuleAsPTModule._apply; the enclosing
# method is not fully visible here — see torch.nn.Module._apply for the pattern.
for name, rf_param in self._rf_module.named_parameters(recurse=False):
    pt_param = getattr(self, name)
    if rf_param.auxiliary:
        # Auxiliary params are stored as plain tensors, not registered as
        # torch.nn.Parameter, so rewrap before assigning back.
        assert isinstance(pt_param, torch.Tensor)  # but not torch.nn.Parameter
        # See similar logic in torch.nn.Module._apply.
        pt_param = torch.nn.Parameter(pt_param, pt_param.requires_grad)
    else:
        assert isinstance(pt_param, torch.nn.Parameter), (
            f"{self}.{name} is not a Parameter" f" but {type(pt_param).__name__}"
        )
    rf_param.raw_tensor = pt_param

0 comments on commit 35adf52

Please sign in to comment.