Rewrite _reparametrize_module to use contextmanager (pytorch#138203)
Pull Request resolved: pytorch#138203
Approved by: https://github.com/zou3519
ghstack dependencies: pytorch#136033, pytorch#140604
guilhermeleobas authored and pytorchmergebot committed Dec 20, 2024
1 parent 1c817fe commit 7bf3b7c
Showing 1 changed file with 46 additions and 64 deletions.
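
Note: the core of the change is mechanical. The hand-written `__enter__`/`__exit__` class is replaced by a generator decorated with `@contextlib.contextmanager`, where the code before `yield` plays the role of `__enter__` and the `finally` block plays the role of `__exit__`. A minimal sketch of the same pattern, using hypothetical names (`_Swap`, `_swap`, `Box`) rather than the actual PyTorch helpers:

```python
import contextlib


# Class-based form (schematically, what _ReparametrizeModule did):
class _Swap:
    def __init__(self, obj, new_value):
        self.obj, self.new_value = obj, new_value

    def __enter__(self):
        self.old_value = self.obj.value
        self.obj.value = self.new_value

    def __exit__(self, exc_type, exc_value, traceback):
        self.obj.value = self.old_value  # restore even if the body raised


# Generator-based form (the shape this commit adopts):
@contextlib.contextmanager
def _swap(obj, new_value):
    old_value = obj.value
    obj.value = new_value
    try:
        yield  # body of the `with` block runs here
    finally:
        obj.value = old_value  # restore even if the body raised


class Box:
    def __init__(self, value):
        self.value = value


box = Box("original")
with _swap(box, "temporary"):
    print(box.value)  # "temporary"
print(box.value)      # "original", restored by the finally block
```

Because the restore logic sits in a `finally`, it runs even when the body of the `with` block raises, matching the old `__exit__` semantics.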
110 changes: 46 additions & 64 deletions torch/nn/utils/stateless.py
@@ -1,4 +1,5 @@
 # mypy: allow-untyped-defs
+import contextlib
 from typing import Any, Dict, Optional, Set, Tuple, Union
 from typing_extensions import deprecated
 
@@ -94,89 +95,70 @@ def _untie_named_tensors_map(
     return untied_parameters_and_buffers
 
 
-class _ReparametrizeModule:
-    def __init__(
-        self,
-        module: "torch.nn.Module",
-        parameters_and_buffers: Dict[str, Tensor],
-        tie_weights: bool = False,
-        strict: bool = False,
-        stack_weights: bool = False,
-    ):
-        self.parameters_and_buffers = parameters_and_buffers
-        self.stack_weights = stack_weights
+@contextlib.contextmanager
+def _reparametrize_module(
+    module: "torch.nn.Module",
+    parameters_and_buffers: Dict[str, Tensor],
+    tie_weights: bool = False,
+    strict: bool = False,
+    stack_weights: bool = False,
+):
+    parameters_and_buffers = parameters_and_buffers
+    stack_weights = stack_weights
 
-        if tie_weights:
-            self.untied_parameters_and_buffers = _untie_named_tensors_map(
-                module, parameters_and_buffers
-            )
-        else:
-            self.untied_parameters_and_buffers = parameters_and_buffers
+    if tie_weights:
+        untied_parameters_and_buffers = _untie_named_tensors_map(
+            module, parameters_and_buffers
+        )
+    else:
+        untied_parameters_and_buffers = parameters_and_buffers
 
-        self.accessor = NamedMemberAccessor(module)
-        if strict:
-            missing_keys, unexpected_keys = self.accessor.check_keys(
-                self.untied_parameters_and_buffers
+    accessor = NamedMemberAccessor(module)
+    if strict:
+        missing_keys, unexpected_keys = accessor.check_keys(
+            untied_parameters_and_buffers
+        )
+        error_msgs = []
+        if len(unexpected_keys) > 0:
+            error_msgs.append(
+                f"Unexpected key(s): {', '.join(map(repr, unexpected_keys))}."
             )
-            error_msgs = []
-            if len(unexpected_keys) > 0:
-                error_msgs.append(
-                    f"Unexpected key(s): {', '.join(map(repr, unexpected_keys))}."
-                )
-            if len(missing_keys) > 0:
-                error_msgs.append(
-                    f"Missing key(s): {', '.join(map(repr, missing_keys))}."
-                )
-            if len(error_msgs) > 0:
-                raise RuntimeError(
-                    "Error(s) in reparametrizing for {}:\n\t{}".format(
-                        module._get_name(), "\n\t".join(error_msgs)
-                    )
+        if len(missing_keys) > 0:
+            error_msgs.append(f"Missing key(s): {', '.join(map(repr, missing_keys))}.")
+        if len(error_msgs) > 0:
+            raise RuntimeError(
+                "Error(s) in reparametrizing for {}:\n\t{}".format(
+                    module._get_name(), "\n\t".join(error_msgs)
                 )
+            )
 
-    def __enter__(self):
-        self.orig_parameters_and_buffers, _ = self.accessor.swap_tensors_dict(
-            self.untied_parameters_and_buffers, allow_missing=True
+    orig_parameters_and_buffers: Dict[str, Tensor] = {}
+    try:
+        orig_parameters_and_buffers, _ = accessor.swap_tensors_dict(
+            untied_parameters_and_buffers, allow_missing=True
         )
-
-    def __exit__(self, exception_type, exception_value, traceback):
-        if self.stack_weights:
+        yield
+    finally:
+        if stack_weights:
             # When stacking is enabled, we will restore the weights in LIFO order.
-            self.orig_parameters_and_buffers = dict(
-                reversed(self.orig_parameters_and_buffers.items())
+            orig_parameters_and_buffers = dict(
+                reversed(orig_parameters_and_buffers.items())
             )
-        new_parameters_and_buffers, _ = self.accessor.swap_tensors_dict(
-            self.orig_parameters_and_buffers, allow_missing=True
+        new_parameters_and_buffers, _ = accessor.swap_tensors_dict(
+            orig_parameters_and_buffers, allow_missing=True
        )
         # Sometimes the module is not completely stateless and has some in-place modifications on
         # the _parameters and _buffers dictionaries.
         # Write the changed parameters and buffers back to the original dict.
-        self.parameters_and_buffers.update(
+        parameters_and_buffers.update(
             {
                 k: new_parameters_and_buffers[k]
-                for k in self.parameters_and_buffers
+                for k in parameters_and_buffers
                 if k in new_parameters_and_buffers
             }
         )
 
 
-def _reparametrize_module(
-    module: "torch.nn.Module",
-    parameters_and_buffers: Dict[str, Tensor],
-    *,
-    tie_weights: bool = False,
-    strict: bool = False,
-    stack_weights: bool = False,
-) -> _ReparametrizeModule:
-    return _ReparametrizeModule(
-        module,
-        parameters_and_buffers,
-        tie_weights=tie_weights,
-        strict=strict,
-        stack_weights=stack_weights,
-    )
-
-
 @deprecated(
     "`torch.nn.utils.stateless.functional_call` is deprecated as of PyTorch 2.0 "
     "and will be removed in a future version of PyTorch. "
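
As a rough usage sketch of the rewritten helper (an illustration only: `_reparametrize_module` is a private utility, normally driven by `torch.func.functional_call`, and the `nn.Linear(3, 2)` module and tensor names below are assumptions):

```python
import torch
from torch import nn
from torch.nn.utils.stateless import _reparametrize_module

module = nn.Linear(3, 2)
x = torch.randn(4, 3)

# Replacement tensors keyed by the names the module itself uses.
zeros = {"weight": torch.zeros(2, 3), "bias": torch.zeros(2)}

with _reparametrize_module(module, zeros):
    print(module(x))  # forward pass sees the swapped-in zero tensors

print(module(x))  # original weight/bias restored on exit, even after an exception
```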
