Skip to content

Commit

Permalink
Add context managers for temporary attribute and evaluation mode
Browse files Browse the repository at this point in the history
  • Loading branch information
RichieHakim committed Mar 5, 2024
1 parent 6862801 commit 4cba3a3
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 0 deletions.
30 changes: 30 additions & 0 deletions bnpm/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from pathlib import Path
import warnings
from typing import Callable, List, Any, Dict
from contextlib import contextmanager


def estimate_array_size(
Expand Down Expand Up @@ -391,6 +392,35 @@ def array_hasher():
return partial(xxhash.xxh64_hexdigest, seed=0)


@contextmanager
def temp_set_attr(obj, attr_name, new_value):
"""
Temporarily set an attribute of an object to a new value within a context
manager / closure.
RH 2024
Args:
obj (object):
Object to toggle attribute for.
attr_name (str):
Attribute to toggle.
new_value (Any):
New value to set attribute to.
Demo:
.. code-block:: python
with temp_set_attr(obj, attr, new_val):
# do something
"""
original_value = getattr(obj, attr_name)
setattr(obj, attr_name, new_value)
try:
yield
finally:
setattr(obj, attr_name, original_value)


#########################################################
############ INTRA-MODULE HELPER FUNCTIONS ##############
#########################################################
Expand Down
32 changes: 32 additions & 0 deletions bnpm/torch_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import gc
import copy
from typing import Union, List, Tuple, Dict, Callable, Optional, Any
from contextlib import contextmanager

import torch
from torch.utils.data import Dataset
Expand Down Expand Up @@ -50,6 +51,37 @@ def show_all_tensors(
print(string)



@contextmanager
def temp_eval(module):
"""
Temporarily sets the network to evaluation mode within a context manager.
RH 2024
Args:
module (torch.nn.Module):
The network to temporarily set to evaluation mode.
Yields:
(torch.nn.Module):
The network temporarily set to evaluation mode.
Demo:
.. highlight:: python
.. code-block:: python
with temp_eval(model):
y = model(x)
"""
state_train = module.training
module.eval()
try:
yield module
finally:
if state_train:
module.train()


######################################
############ CUDA STUFF ##############
######################################
Expand Down

0 comments on commit 4cba3a3

Please sign in to comment.