From 4cba3a3d4899b26793f2a16d0e58d3de82a67eba Mon Sep 17 00:00:00 2001 From: RichieHakim Date: Mon, 4 Mar 2024 19:15:40 -0500 Subject: [PATCH] Add context managers for temporary attribute and evaluation mode --- bnpm/misc.py | 30 ++++++++++++++++++++++++++++++ bnpm/torch_helpers.py | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 62 insertions(+) diff --git a/bnpm/misc.py b/bnpm/misc.py index 26f9a12..dde4450 100644 --- a/bnpm/misc.py +++ b/bnpm/misc.py @@ -5,6 +5,7 @@ from pathlib import Path import warnings from typing import Callable, List, Any, Dict +from contextlib import contextmanager def estimate_array_size( @@ -391,6 +392,35 @@ def array_hasher(): return partial(xxhash.xxh64_hexdigest, seed=0) +@contextmanager +def temp_set_attr(obj, attr_name, new_value): + """ + Temporarily set an attribute of an object to a new value within a context + manager / closure. + RH 2024 + + Args: + obj (object): + Object to toggle attribute for. + attr_name (str): + Attribute to toggle. + new_value (Any): + New value to set attribute to. + + Demo: + .. code-block:: python + + with temp_set_attr(obj, attr, new_val): + # do something + """ + original_value = getattr(obj, attr_name) + setattr(obj, attr_name, new_value) + try: + yield + finally: + setattr(obj, attr_name, original_value) + + ######################################################### ############ INTRA-MODULE HELPER FUNCTIONS ############## ######################################################### diff --git a/bnpm/torch_helpers.py b/bnpm/torch_helpers.py index 9a1f2d2..10c8e3c 100644 --- a/bnpm/torch_helpers.py +++ b/bnpm/torch_helpers.py @@ -2,6 +2,7 @@ import gc import copy from typing import Union, List, Tuple, Dict, Callable, Optional, Any +from contextlib import contextmanager import torch from torch.utils.data import Dataset @@ -50,6 +51,37 @@ def show_all_tensors( print(string) + +@contextmanager +def temp_eval(module): + """ + Temporarily sets the network to evaluation mode within a context manager. + RH 2024 + + Args: + module (torch.nn.Module): + The network to temporarily set to evaluation mode. + + Yields: + (torch.nn.Module): + The network temporarily set to evaluation mode. + + Demo: + .. highlight:: python + .. code-block:: python + + with temp_eval(model): + y = model(x) + """ + state_train = module.training + module.eval() + try: + yield module + finally: + if state_train: + module.train() + + ###################################### ############ CUDA STUFF ############## ######################################