diff --git a/returnn/torch/frontend/bridge.py b/returnn/torch/frontend/bridge.py index 1b000104b2..0279dbaf37 100644 --- a/returnn/torch/frontend/bridge.py +++ b/returnn/torch/frontend/bridge.py @@ -121,5 +121,12 @@ def _apply(self, fn): # Update the corresponding RF Parameter. for name, rf_param in self._rf_module.named_parameters(recurse=False): pt_param = getattr(self, name) - assert isinstance(pt_param, torch.nn.Parameter) + if rf_param.auxiliary: + assert isinstance(pt_param, torch.Tensor) # but not torch.nn.Parameter + # See similar logic in torch.nn.Module._apply. + pt_param = torch.nn.Parameter(pt_param, pt_param.requires_grad) + else: + assert isinstance(pt_param, torch.nn.Parameter), ( + f"{self}.{name} is not a Parameter" f" but {type(pt_param).__name__}" + ) rf_param.raw_tensor = pt_param