RF fix incompat optimizer state dict
albertz committed Nov 16, 2023
1 parent f7efe87 commit b48fde5
Showing 1 changed file with 13 additions and 6 deletions.
returnn/torch/frontend/bridge.py: 13 additions & 6 deletions
@@ -34,15 +34,21 @@ def pt_module_to_wrapped_rf_module(pt_module: torch.nn.Module) -> Optional[rf.Module]:
     return None
 
 
-def rf_module_to_pt_module(rf_module: rf.Module) -> torch.nn.Module:
+def rf_module_to_pt_module(rf_module: rf.Module, *, aux_params_as_buffers: bool = False) -> torch.nn.Module:
     """
     :param rf_module: RF module
+    :param aux_params_as_buffers: whether to map RF auxiliary parameters to PyTorch buffers,
+        otherwise to normal parameters, i.e. they occur in model.named_parameters().
+        Note that even when they are part of model.named_parameters(),
+        aux params usually don't have a gradient, and then they are not updated by the optimizer.
+        Historically, this was False. For now, we keep that default
+        because optimizer state dicts are not compatible otherwise.
     :return: torch module
     """
     assert isinstance(rf_module, rf.Module)
     if isinstance(rf_module, _PTModuleAsRFModule):
         return rf_module.pt_module
-    return _RFModuleAsPTModule(rf_module=rf_module)
+    return _RFModuleAsPTModule(rf_module=rf_module, aux_params_as_buffers=aux_params_as_buffers)
 
 
 class _PTModuleAsRFModule(rf.Module):
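
A minimal usage sketch of the new flag (not from the diff itself; rf_mod is a hypothetical rf.Module that owns auxiliary parameters, e.g. batch-norm running statistics):

    from returnn.torch.frontend.bridge import rf_module_to_pt_module

    rf_mod = ...  # hypothetical rf.Module with auxiliary parameters

    # Default (aux_params_as_buffers=False): aux params stay in
    # pt_mod.named_parameters(), so the parameter set matches what existing
    # optimizer state dicts expect, per the docstring above.
    pt_mod = rf_module_to_pt_module(rf_mod)

    # Opt in: aux params are registered as PyTorch buffers instead, i.e. they
    # appear in pt_mod.named_buffers() and no longer in pt_mod.named_parameters().
    pt_mod_buffers = rf_module_to_pt_module(rf_mod, aux_params_as_buffers=True)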
@@ -83,21 +89,22 @@ def __call__(self, *args, **kwargs):
 
 
 class _RFModuleAsPTModule(torch.nn.Module):
-    def __init__(self, rf_module: rf.Module):
+    def __init__(self, rf_module: rf.Module, *, aux_params_as_buffers: bool):
         super().__init__()
         self._rf_module = rf_module
+        self._aux_params_as_buffers = aux_params_as_buffers
 
         # recurse=False because param names cannot contain "."
         for name, rf_param in rf_module.named_parameters(recurse=False):
             pt_param = rf_param.raw_tensor
             assert isinstance(pt_param, torch.nn.Parameter)
-            if rf_param.auxiliary:
+            if rf_param.auxiliary and aux_params_as_buffers:
                 self.register_buffer(name, pt_param)
             else:
                 self.register_parameter(name, pt_param)
 
         for name, rf_mod in rf_module.named_children():
-            pt_mod = rf_module_to_pt_module(rf_mod)
+            pt_mod = rf_module_to_pt_module(rf_mod, aux_params_as_buffers=aux_params_as_buffers)
             self.add_module(name, pt_mod)
 
     def _get_name(self):
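
For context, a hedged sketch (standard PyTorch behavior, not code from this commit) of why the parameter-vs-buffer mapping matters for the optimizer state dict: the optimizer is normally built from model.parameters() and keys its state by those parameters, so moving aux params into buffers changes the parameter list:

    import torch

    pt_mod = rf_module_to_pt_module(rf_mod)  # rf_mod as in the sketch above
    optimizer = torch.optim.Adam(pt_mod.parameters(), lr=1e-3)

    # The optimizer state dict is keyed by parameter position within the param
    # groups. If aux params moved from parameters() to buffers(), the number and
    # order of parameters would change, and loading an optimizer state dict saved
    # with the old mapping would mis-assign or drop entries.
    opt_state = optimizer.state_dict()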
@@ -121,7 +128,7 @@ def _apply(self, fn):
         # Update the corresponding RF Parameter.
         for name, rf_param in self._rf_module.named_parameters(recurse=False):
             pt_param = getattr(self, name)
-            if rf_param.auxiliary:
+            if rf_param.auxiliary and self._aux_params_as_buffers:
                 assert isinstance(pt_param, torch.Tensor)  # but not torch.nn.Parameter
                 # See similar logic in torch.nn.Module._apply.
                 pt_param = torch.nn.Parameter(pt_param, pt_param.requires_grad)
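
To illustrate what the _apply override is for (same assumptions as above; the rest of the method is collapsed in this view): torch.nn.Module._apply, used by .to(), .cuda(), .float() and similar, replaces the stored tensors, and per the comment above the replaced tensors are then written back into the corresponding RF parameters:

    pt_mod = rf_module_to_pt_module(rf_mod, aux_params_as_buffers=True)
    pt_mod.to(torch.float64)  # torch.nn.Module._apply swaps the underlying tensors

    # For aux params registered as buffers, the swapped tensor is a plain
    # torch.Tensor, so the code above re-wraps it as torch.nn.Parameter before
    # updating the RF side, keeping the RF module and the PT wrapper in sync.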