diff --git a/returnn/torch/frontend/bridge.py b/returnn/torch/frontend/bridge.py
index 0279dbaf37..d11bd6e2e2 100644
--- a/returnn/torch/frontend/bridge.py
+++ b/returnn/torch/frontend/bridge.py
@@ -34,15 +34,21 @@ def pt_module_to_wrapped_rf_module(pt_module: torch.nn.Module) -> Optional[rf.Mo
     return None
 
 
-def rf_module_to_pt_module(rf_module: rf.Module) -> torch.nn.Module:
+def rf_module_to_pt_module(rf_module: rf.Module, *, aux_params_as_buffers: bool = False) -> torch.nn.Module:
     """
     :param rf_module: RF module
+    :param aux_params_as_buffers: whether to map RF auxiliary parameters to PyTorch buffers,
+        otherwise to normal parameters, i.e. they occur in model.named_parameters().
+        Note that even when they are part of model.named_parameters(),
+        aux params usually don't have a gradient, and then they are not updated by the optimizer.
+        Historically, this was False. For now, we keep that default
+        because optimizer state dicts are not compatible otherwise.
     :return: torch module
     """
     assert isinstance(rf_module, rf.Module)
     if isinstance(rf_module, _PTModuleAsRFModule):
         return rf_module.pt_module
-    return _RFModuleAsPTModule(rf_module=rf_module)
+    return _RFModuleAsPTModule(rf_module=rf_module, aux_params_as_buffers=aux_params_as_buffers)
 
 
 class _PTModuleAsRFModule(rf.Module):
@@ -83,21 +89,22 @@ def __call__(self, *args, **kwargs):
 
 
 class _RFModuleAsPTModule(torch.nn.Module):
-    def __init__(self, rf_module: rf.Module):
+    def __init__(self, rf_module: rf.Module, *, aux_params_as_buffers: bool):
         super().__init__()
         self._rf_module = rf_module
+        self._aux_params_as_buffers = aux_params_as_buffers
 
         # recurse=False because param names cannot contain "."
         for name, rf_param in rf_module.named_parameters(recurse=False):
             pt_param = rf_param.raw_tensor
             assert isinstance(pt_param, torch.nn.Parameter)
-            if rf_param.auxiliary:
+            if rf_param.auxiliary and aux_params_as_buffers:
                 self.register_buffer(name, pt_param)
             else:
                 self.register_parameter(name, pt_param)
 
         for name, rf_mod in rf_module.named_children():
-            pt_mod = rf_module_to_pt_module(rf_mod)
+            pt_mod = rf_module_to_pt_module(rf_mod, aux_params_as_buffers=aux_params_as_buffers)
             self.add_module(name, pt_mod)
 
     def _get_name(self):
@@ -121,7 +128,7 @@ def _apply(self, fn):
         # Update the corresponding RF Parameter.
         for name, rf_param in self._rf_module.named_parameters(recurse=False):
             pt_param = getattr(self, name)
-            if rf_param.auxiliary:
+            if rf_param.auxiliary and self._aux_params_as_buffers:
                 assert isinstance(pt_param, torch.Tensor)  # but not torch.nn.Parameter
                 # See similar logic in torch.nn.Module._apply.
                 pt_param = torch.nn.Parameter(pt_param, pt_param.requires_grad)
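
For illustration, a minimal usage sketch of the new flag. The toy `MyRFModule`, its parameter names, and the `rf.select_backend_torch()` setup call are assumptions made for this example, not part of the change itself:

```python
# Minimal sketch, assuming an eager torch RF backend and a toy RF module
# with one normal and one auxiliary parameter (hypothetical names).
import returnn.frontend as rf
from returnn.torch.frontend.bridge import rf_module_to_pt_module

rf.select_backend_torch()  # assumed setup so rf.Parameter materializes torch tensors


class MyRFModule(rf.Module):
    def __init__(self):
        super().__init__()
        feat_dim = rf.Dim(4, name="feature")
        self.weight = rf.Parameter([feat_dim])  # normal trainable parameter
        self.running_mean = rf.Parameter([feat_dim], auxiliary=True)  # auxiliary parameter


rf_mod = MyRFModule()

# Default (aux_params_as_buffers=False): aux params stay in named_parameters(),
# which keeps existing optimizer state dicts compatible.
pt_mod = rf_module_to_pt_module(rf_mod)
assert "running_mean" in dict(pt_mod.named_parameters())

# aux_params_as_buffers=True: aux params become PyTorch buffers instead.
pt_mod = rf_module_to_pt_module(rf_mod, aux_params_as_buffers=True)
assert "running_mean" in dict(pt_mod.named_buffers())
assert "running_mean" not in dict(pt_mod.named_parameters())
```

The flag only changes on which PyTorch collection the auxiliary parameters appear; the default stays `False` so that existing optimizer state dicts keep loading unchanged.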