From 2b251fb291e7c2b62eb0177710318b79d6828b3c Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Mon, 6 Jan 2025 11:28:39 -0800
Subject: [PATCH 1/2] Wrap torch checkpoint() fn to default use_reentrant flag
 to False and allow env var override

---
 timm/layers/__init__.py                  |  3 +-
 timm/layers/config.py                    | 18 +++++++-
 timm/models/_features.py                 |  3 +-
 timm/models/_manipulate.py               | 55 ++++++++++++++++++------
 timm/models/beit.py                      |  3 +-
 timm/models/densenet.py                  |  5 +--
 timm/models/efficientnet.py              |  3 +-
 timm/models/eva.py                       |  2 +-
 timm/models/focalnet.py                  |  3 +-
 timm/models/gcvit.py                     |  3 +-
 timm/models/hiera.py                     |  3 +-
 timm/models/mobilenetv3.py               |  3 +-
 timm/models/mvitv2.py                    |  4 +-
 timm/models/pvt_v2.py                    |  2 +-
 timm/models/swin_transformer_v2.py       |  4 +-
 timm/models/swin_transformer_v2_cr.py    |  3 +-
 timm/models/tnt.py                       | 13 +++---
 timm/models/vision_transformer.py        |  2 -
 timm/models/vision_transformer_relpos.py |  3 +-
 timm/models/vision_transformer_sam.py    |  3 +-
 timm/models/volo.py                      |  2 +-
 timm/models/xcit.py                      |  5 +--
 22 files changed, 91 insertions(+), 54 deletions(-)

diff --git a/timm/layers/__init__.py b/timm/layers/__init__.py
index 5ec03219e8..c71ff30c82 100644
--- a/timm/layers/__init__.py
+++ b/timm/layers/__init__.py
@@ -8,7 +8,8 @@
 from .classifier import create_classifier, ClassifierHead, NormMlpClassifierHead, ClNormMlpClassifierHead
 from .cond_conv2d import CondConv2d, get_condconv_initializer
 from .config import is_exportable, is_scriptable, is_no_jit, use_fused_attn, \
-    set_exportable, set_scriptable, set_no_jit, set_layer_config, set_fused_attn
+    set_exportable, set_scriptable, set_no_jit, set_layer_config, set_fused_attn, \
+    set_reentrant_ckpt, use_reentrant_ckpt
 from .conv2d_same import Conv2dSame, conv2d_same
 from .conv_bn_act import ConvNormAct, ConvNormActAa, ConvBnAct
 from .create_act import create_act_layer, get_act_layer, get_act_fn
diff --git a/timm/layers/config.py b/timm/layers/config.py
index 47d5d0a341..f69f380317 100644
--- a/timm/layers/config.py
+++ b/timm/layers/config.py
@@ -8,7 +8,8 @@
 
 __all__ = [
     'is_exportable', 'is_scriptable', 'is_no_jit', 'use_fused_attn',
-    'set_exportable', 'set_scriptable', 'set_no_jit', 'set_layer_config', 'set_fused_attn'
+    'set_exportable', 'set_scriptable', 'set_no_jit', 'set_layer_config', 'set_fused_attn',
+    'set_reentrant_ckpt', 'use_reentrant_ckpt'
 ]
 
 # Set to True if prefer to have layers with no jit optimization (includes activations)
@@ -34,6 +35,12 @@
 _USE_FUSED_ATTN = 1  # 0 == off, 1 == on (for tested use), 2 == on (for experimental use)
 
 
+if 'TIMM_REENTRANT_CKPT' in os.environ:
+    _USE_REENTRANT_CKPT = bool(os.environ['TIMM_REENTRANT_CKPT'])
+else:
+    _USE_REENTRANT_CKPT = False  # defaults to disabled (off)
+
+
 def is_no_jit():
     return _NO_JIT
 
@@ -147,3 +154,12 @@ def set_fused_attn(enable: bool = True, experimental: bool = False):
         _USE_FUSED_ATTN = 1
     else:
         _USE_FUSED_ATTN = 0
+
+
+def use_reentrant_ckpt() -> bool:
+    return _USE_REENTRANT_CKPT
+
+
+def set_reentrant_ckpt(enable: bool = True):
+    global _USE_REENTRANT_CKPT
+    _USE_REENTRANT_CKPT = enable
\ No newline at end of file
diff --git a/timm/models/_features.py b/timm/models/_features.py
index 14d174f5d4..08ff3aa1f4 100644
--- a/timm/models/_features.py
+++ b/timm/models/_features.py
@@ -15,10 +15,9 @@
 
 import torch
 import torch.nn as nn
-from torch.utils.checkpoint import checkpoint
 
 from timm.layers import Format, _assert
-
+from ._manipulate import checkpoint
 
 __all__ = [
     'FeatureInfo', 'FeatureHooks', 'FeatureDictNet', 'FeatureListNet', 'FeatureHookNet', 'FeatureGetterNet',
diff --git a/timm/models/_manipulate.py b/timm/models/_manipulate.py
index e689b39276..f40ff9ac47 100644
--- a/timm/models/_manipulate.py
+++ b/timm/models/_manipulate.py
@@ -3,14 +3,17 @@
 import re
 from collections import defaultdict
 from itertools import chain
-from typing import Any, Callable, Dict, Iterator, Tuple, Type, Union
+from typing import Any, Callable, Dict, Iterator, Optional, Tuple, Type, Union
 
 import torch
+import torch.utils.checkpoint
 from torch import nn as nn
-from torch.utils.checkpoint import checkpoint
+
+from timm.layers import use_reentrant_ckpt
+
 
 __all__ = ['model_parameters', 'named_apply', 'named_modules', 'named_modules_with_params', 'adapt_input_conv',
-           'group_with_matcher', 'group_modules', 'group_parameters', 'flatten_modules', 'checkpoint_seq']
+           'group_with_matcher', 'group_modules', 'group_parameters', 'flatten_modules', 'checkpoint_seq', 'checkpoint']
 
 
 def model_parameters(model: nn.Module, exclude_head: bool = False):
@@ -183,13 +186,35 @@ def flatten_modules(
             yield name, module
 
 
+def checkpoint(
+        function,
+        *args,
+        use_reentrant: Optional[bool] = None,
+        **kwargs,
+):
+    """ checkpoint wrapper fn
+
+    A thin wrapper around torch.utils.checkpoint.checkpoint to default
+    use_reentrant to False
+    """
+    if use_reentrant is None:
+        use_reentrant = use_reentrant_ckpt()
+
+    return torch.utils.checkpoint.checkpoint(
+        function,
+        *args,
+        use_reentrant=use_reentrant,
+        **kwargs,
+    )
+
+
 def checkpoint_seq(
         functions,
         x,
-        every=1,
-        flatten=False,
-        skip_last=False,
-        preserve_rng_state=True
+        every: int = 1,
+        flatten: bool = False,
+        skip_last: bool = False,
+        use_reentrant: Optional[bool] = None,
 ):
     r"""A helper function for checkpointing sequential models.
 
@@ -215,10 +240,9 @@ def checkpoint_seq(
         functions: A :class:`torch.nn.Sequential` or the list of modules or functions to run sequentially.
         x: A Tensor that is input to :attr:`functions`
         every: checkpoint every-n functions (default: 1)
-        flatten (bool): flatten nn.Sequential of nn.Sequentials
-        skip_last (bool): skip checkpointing the last function in the sequence if True
-        preserve_rng_state (bool, optional, default=True):  Omit stashing and restoring
-            the RNG state during each checkpoint.
+        flatten: flatten nn.Sequential of nn.Sequentials
+        skip_last: skip checkpointing the last function in the sequence if True
+        use_reentrant: Use re-entrant checkpointing
 
     Returns:
         Output of running :attr:`functions` sequentially on :attr:`*inputs`
@@ -227,6 +251,9 @@ def checkpoint_seq(
     Example:
         >>> model = nn.Sequential(...)
         >>> input_var = checkpoint_seq(model, input_var, every=2)
     """
+    if use_reentrant is None:
+        use_reentrant = use_reentrant_ckpt()
+
     def run_function(start, end, functions):
         def forward(_x):
             for j in range(start, end + 1):
@@ -247,7 +274,11 @@ def forward(_x):
     end = -1
     for start in range(0, num_checkpointed, every):
         end = min(start + every - 1, num_checkpointed - 1)
-        x = checkpoint(run_function(start, end, functions), x, preserve_rng_state=preserve_rng_state)
+        x = torch.utils.checkpoint.checkpoint(
+            run_function(start, end, functions),
+            x,
+            use_reentrant=use_reentrant,
+        )
     if skip_last:
         return run_function(end + 1, len(functions) - 1, functions)(x)
     return x
diff --git a/timm/models/beit.py b/timm/models/beit.py
index c47ea395e8..5123a60627 100644
--- a/timm/models/beit.py
+++ b/timm/models/beit.py
@@ -44,15 +44,14 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.utils.checkpoint import checkpoint
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import PatchEmbed, Mlp, SwiGLU, LayerNorm, DropPath, trunc_normal_, use_fused_attn
 from timm.layers import resample_patch_embed, resample_abs_pos_embed, resize_rel_pos_bias_table, ndgrid
-
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
+from ._manipulate import checkpoint
 from ._registry import generate_default_cfgs, register_model
 
 __all__ = ['Beit']
diff --git a/timm/models/densenet.py b/timm/models/densenet.py
index 31d1f73f9c..d522965907 100644
--- a/timm/models/densenet.py
+++ b/timm/models/densenet.py
@@ -8,13 +8,12 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import torch.utils.checkpoint as cp
 from torch.jit.annotations import List
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import BatchNormAct2d, get_norm_act_layer, BlurPool2d, create_classifier
 from ._builder import build_model_with_cfg
-from ._manipulate import MATCH_PREV_GROUP
+from ._manipulate import MATCH_PREV_GROUP, checkpoint
 from ._registry import register_model, generate_default_cfgs, register_model_deprecations
 
 __all__ = ['DenseNet']
@@ -60,7 +59,7 @@ def call_checkpoint_bottleneck(self, x):
         def closure(*xs):
             return self.bottleneck_fn(xs)
 
-        return cp.checkpoint(closure, *x)
+        return checkpoint(closure, *x)
 
     @torch.jit._overload_method  # noqa: F811
     def forward(self, x):
diff --git a/timm/models/efficientnet.py b/timm/models/efficientnet.py
index 07bb250c84..b5bc35c036 100644
--- a/timm/models/efficientnet.py
+++ b/timm/models/efficientnet.py
@@ -41,7 +41,6 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.utils.checkpoint import checkpoint
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
 from timm.layers import create_conv2d, create_classifier, get_norm_act_layer, LayerType, \
@@ -51,7 +50,7 @@
 from ._efficientnet_builder import BlockArgs, EfficientNetBuilder, decode_arch_def, efficientnet_init_weights, \
     round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
 from ._features import FeatureInfo, FeatureHooks, feature_take_indices
-from ._manipulate import checkpoint_seq
+from ._manipulate import checkpoint_seq, checkpoint
 from ._registry import generate_default_cfgs, register_model, register_model_deprecations
 
 __all__ = ['EfficientNet', 'EfficientNetFeatures']
diff --git a/timm/models/eva.py b/timm/models/eva.py
index 552965947b..26c5278aa5 100644
--- a/timm/models/eva.py
+++ b/timm/models/eva.py
@@ -30,7 +30,6 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.utils.checkpoint import checkpoint
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
 from timm.layers import PatchEmbed, Mlp, GluMlp, SwiGLU, LayerNorm, DropPath, PatchDropout, RotaryEmbeddingCat, \
@@ -39,6 +38,7 @@
 
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
+from ._manipulate import checkpoint
 from ._registry import generate_default_cfgs, register_model
 
 __all__ = ['Eva']
diff --git a/timm/models/focalnet.py b/timm/models/focalnet.py
index f747001c73..7a5e7401da 100644
--- a/timm/models/focalnet.py
+++ b/timm/models/focalnet.py
@@ -22,12 +22,11 @@
 
 import torch
 import torch.nn as nn
-import torch.utils.checkpoint as checkpoint
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import Mlp, DropPath, LayerNorm2d, trunc_normal_, ClassifierHead, NormMlpClassifierHead
 from ._builder import build_model_with_cfg
-from ._manipulate import named_apply
+from ._manipulate import named_apply, checkpoint
 from ._registry import generate_default_cfgs, register_model
 
 __all__ = ['FocalNet']
diff --git a/timm/models/gcvit.py b/timm/models/gcvit.py
index 44660a3f6c..b31b5768bd 100644
--- a/timm/models/gcvit.py
+++ b/timm/models/gcvit.py
@@ -25,14 +25,13 @@
 
 import torch
 import torch.nn as nn
-import torch.utils.checkpoint as checkpoint
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import DropPath, to_2tuple, to_ntuple, Mlp, ClassifierHead, LayerNorm2d, \
     get_attn, get_act_layer, get_norm_layer, RelPosBias, _assert
 from ._builder import build_model_with_cfg
 from ._features_fx import register_notrace_function
-from ._manipulate import named_apply
+from ._manipulate import named_apply, checkpoint
 from ._registry import register_model, generate_default_cfgs
 
 __all__ = ['GlobalContextVit']
diff --git a/timm/models/hiera.py b/timm/models/hiera.py
index 34d6670fbe..bd38bf8866 100644
--- a/timm/models/hiera.py
+++ b/timm/models/hiera.py
@@ -29,7 +29,6 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.utils.checkpoint import checkpoint
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import DropPath, Mlp, LayerScale, ClNormMlpClassifierHead, use_fused_attn, \
@@ -39,7 +38,7 @@
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
 from ._features_fx import register_notrace_function
-from ._manipulate import named_apply
+from ._manipulate import named_apply, checkpoint
 
 __all__ = ['Hiera']
diff --git a/timm/models/mobilenetv3.py b/timm/models/mobilenetv3.py
index 79a51f7729..08dcb064fa 100644
--- a/timm/models/mobilenetv3.py
+++ b/timm/models/mobilenetv3.py
@@ -12,7 +12,6 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.utils.checkpoint import checkpoint
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
 from timm.layers import SelectAdaptivePool2d, Linear, LayerType, PadType, create_conv2d, get_norm_act_layer
@@ -21,7 +20,7 @@
 from ._efficientnet_builder import BlockArgs, EfficientNetBuilder, decode_arch_def, efficientnet_init_weights, \
     round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
 from ._features import FeatureInfo, FeatureHooks, feature_take_indices
-from ._manipulate import checkpoint_seq
+from ._manipulate import checkpoint_seq, checkpoint
 from ._registry import generate_default_cfgs, register_model, register_model_deprecations
 
 __all__ = ['MobileNetV3', 'MobileNetV3Features']
diff --git a/timm/models/mvitv2.py b/timm/models/mvitv2.py
index 167ebb9e82..f790fd0d13 100644
--- a/timm/models/mvitv2.py
+++ b/timm/models/mvitv2.py
@@ -20,7 +20,6 @@
 from typing import Union, List, Tuple, Optional
 
 import torch
-import torch.utils.checkpoint as checkpoint
 from torch import nn
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
@@ -28,7 +27,8 @@
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
 from ._features_fx import register_notrace_function
-from ._registry import register_model, register_model_deprecations, generate_default_cfgs
+from ._manipulate import checkpoint
+from ._registry import register_model, generate_default_cfgs
 
 __all__ = ['MultiScaleVit', 'MultiScaleVitCfg']  # model_registry will add each entrypoint fn to this
diff --git a/timm/models/pvt_v2.py b/timm/models/pvt_v2.py
index 9200bbd451..dd7011b042 100644
--- a/timm/models/pvt_v2.py
+++ b/timm/models/pvt_v2.py
@@ -21,11 +21,11 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import torch.utils.checkpoint as checkpoint
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import DropPath, to_2tuple, to_ntuple, trunc_normal_, LayerNorm, use_fused_attn
 from ._builder import build_model_with_cfg
+from ._manipulate import checkpoint
 from ._registry import register_model, generate_default_cfgs
 
 __all__ = ['PyramidVisionTransformerV2']
diff --git a/timm/models/swin_transformer_v2.py b/timm/models/swin_transformer_v2.py
index 2174b4840f..652efa3bdb 100644
--- a/timm/models/swin_transformer_v2.py
+++ b/timm/models/swin_transformer_v2.py
@@ -18,14 +18,14 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import torch.utils.checkpoint as checkpoint
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_, _assert, ClassifierHead,\
+from timm.layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_, ClassifierHead,\
     resample_patch_embed, ndgrid, get_act_layer, LayerType
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
 from ._features_fx import register_notrace_function
+from ._manipulate import checkpoint
 from ._registry import generate_default_cfgs, register_model, register_model_deprecations
 
 __all__ = ['SwinTransformerV2']  # model_registry will add each entrypoint fn to this
diff --git a/timm/models/swin_transformer_v2_cr.py b/timm/models/swin_transformer_v2_cr.py
index dceb3d5040..d8d247cdeb 100644
--- a/timm/models/swin_transformer_v2_cr.py
+++ b/timm/models/swin_transformer_v2_cr.py
@@ -34,14 +34,13 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import torch.utils.checkpoint as checkpoint
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import DropPath, Mlp, ClassifierHead, to_2tuple, _assert, ndgrid
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
 from ._features_fx import register_notrace_function
-from ._manipulate import named_apply
+from ._manipulate import named_apply, checkpoint
 from ._registry import generate_default_cfgs, register_model
 
 __all__ = ['SwinTransformerV2Cr']  # model_registry will add each entrypoint fn to this
diff --git a/timm/models/tnt.py b/timm/models/tnt.py
index 9e37770ac7..d97cfaae3e 100644
--- a/timm/models/tnt.py
+++ b/timm/models/tnt.py
@@ -11,13 +11,13 @@
 
 import torch
 import torch.nn as nn
-from torch.utils.checkpoint import checkpoint
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import Mlp, DropPath, trunc_normal_, _assert, to_2tuple
+from timm.layers import Mlp, DropPath, trunc_normal_, _assert, to_2tuple, resample_abs_pos_embed
 from ._builder import build_model_with_cfg
+from ._manipulate import checkpoint
 from ._registry import register_model
-from .vision_transformer import resize_pos_embed
+
 
 __all__ = ['TNT']  # model_registry will add each entrypoint fn to this
@@ -340,8 +340,11 @@ def forward(self, x):
 def checkpoint_filter_fn(state_dict, model):
     """ convert patch embedding weight from manual patchify + linear proj to conv"""
     if state_dict['patch_pos'].shape != model.patch_pos.shape:
-        state_dict['patch_pos'] = resize_pos_embed(state_dict['patch_pos'],
-            model.patch_pos, getattr(model, 'num_tokens', 1), model.pixel_embed.grid_size)
+        state_dict['patch_pos'] = resample_abs_pos_embed(
+            state_dict['patch_pos'],
+            new_size=model.pixel_embed.grid_size,
+            num_prefix_tokens=1,
+        )
     return state_dict
diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py
index 2368d353b3..9ddf8eb37e 100644
--- a/timm/models/vision_transformer.py
+++ b/timm/models/vision_transformer.py
@@ -37,7 +37,6 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import torch.utils.checkpoint
 from torch.jit import Final
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, \
@@ -1019,7 +1018,6 @@ def _n2p(_w, t=True, idx=None):
     else:
         pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
     if pos_embed_w.shape != model.pos_embed.shape:
-        old_shape = pos_embed_w.shape
         num_prefix_tokens = 0 if getattr(model, 'no_embed_class', False) else getattr(model, 'num_prefix_tokens', 1)
         pos_embed_w = resample_abs_pos_embed(  # resize pos embedding when different size from pretrained weights
             pos_embed_w,
diff --git a/timm/models/vision_transformer_relpos.py b/timm/models/vision_transformer_relpos.py
index 234195973f..030c24dc69 100644
--- a/timm/models/vision_transformer_relpos.py
+++ b/timm/models/vision_transformer_relpos.py
@@ -17,13 +17,12 @@
 import torch
 import torch.nn as nn
 from torch.jit import Final
-from torch.utils.checkpoint import checkpoint
 
 from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
 from timm.layers import PatchEmbed, Mlp, DropPath, RelPosMlp, RelPosBias, use_fused_attn, LayerType
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
-from ._manipulate import named_apply
+from ._manipulate import named_apply, checkpoint
 from ._registry import generate_default_cfgs, register_model
 from .vision_transformer import get_init_weights_vit
diff --git a/timm/models/vision_transformer_sam.py b/timm/models/vision_transformer_sam.py
index a57c166d75..3fb0b59f1f 100644
--- a/timm/models/vision_transformer_sam.py
+++ b/timm/models/vision_transformer_sam.py
@@ -16,7 +16,6 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import torch.utils.checkpoint
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
 from timm.layers import PatchEmbed, Mlp, DropPath, PatchDropout, LayerNorm2d, ClassifierHead, NormMlpClassifierHead, \
     Format, resample_abs_pos_embed_nhwc, RotaryEmbeddingCat, apply_rot_embed_cat, to_2tuple, use_fused_attn
@@ -25,7 +24,7 @@
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
 from ._features_fx import register_notrace_function
-from ._manipulate import checkpoint_seq
+from ._manipulate import checkpoint_seq, checkpoint
 from ._registry import generate_default_cfgs, register_model
 
 # model_registry will add each entrypoint fn to this
diff --git a/timm/models/volo.py b/timm/models/volo.py
index 0d273180fb..56fc73997c 100644
--- a/timm/models/volo.py
+++ b/timm/models/volo.py
@@ -26,12 +26,12 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.utils.checkpoint import checkpoint
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import DropPath, Mlp, to_2tuple, to_ntuple, trunc_normal_, use_fused_attn
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
+from ._manipulate import checkpoint
 from ._registry import register_model, generate_default_cfgs
 
 __all__ = ['VOLO']  # model_registry will add each entrypoint fn to this
diff --git a/timm/models/xcit.py b/timm/models/xcit.py
index 1e902ac23f..e6cf87b789 100644
--- a/timm/models/xcit.py
+++ b/timm/models/xcit.py
@@ -17,16 +17,15 @@
 
 import torch
 import torch.nn as nn
-from torch.utils.checkpoint import checkpoint
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import DropPath, trunc_normal_, to_2tuple, use_fused_attn
+from timm.layers import DropPath, trunc_normal_, to_2tuple, use_fused_attn, Mlp
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
 from ._features_fx import register_notrace_module
+from ._manipulate import checkpoint
 from ._registry import register_model, generate_default_cfgs, register_model_deprecations
 from .cait import ClassAttn
-from .vision_transformer import Mlp
 
 __all__ = ['Xcit']  # model_registry will add each entrypoint fn to this

From 155f6e7fea5e5d24446cb4d34cfd63b5b46653b4 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Mon, 6 Jan 2025 13:09:15 -0800
Subject: [PATCH 2/2] Update README, few minor fixups.

---
 README.md                             | 3 +++
 timm/layers/config.py                 | 2 +-
 timm/models/__init__.py               | 2 +-
 timm/models/vision_transformer_sam.py | 2 +-
 4 files changed, 6 insertions(+), 3 deletions(-)

diff --git a/README.md b/README.md
index afccc02d73..7b1f06f06e 100644
--- a/README.md
+++ b/README.md
@@ -12,6 +12,9 @@
 
 ## What's New
 
+## Jan 6, 2025
+* Add `torch.utils.checkpoint.checkpoint()` wrapper in `timm.models` that defaults `use_reentrant=False`, unless the `TIMM_REENTRANT_CKPT` env var is set (any non-empty value).
+
 ## Dec 31, 2024
 * `convnext_nano` 384x384 ImageNet-12k pretrain & fine-tune. https://huggingface.co/models?search=convnext_nano%20r384
 * Add AIM-v2 encoders from https://github.com/apple/ml-aim, see on Hub: https://huggingface.co/models?search=timm%20aimv2
diff --git a/timm/layers/config.py b/timm/layers/config.py
index f69f380317..e2a23b5ad7 100644
--- a/timm/layers/config.py
+++ b/timm/layers/config.py
@@ -162,4 +162,4 @@ def use_reentrant_ckpt() -> bool:
 
 def set_reentrant_ckpt(enable: bool = True):
     global _USE_REENTRANT_CKPT
-    _USE_REENTRANT_CKPT = enable
\ No newline at end of file
+    _USE_REENTRANT_CKPT = enable
diff --git a/timm/models/__init__.py b/timm/models/__init__.py
index c5b1984f20..3db5af6049 100644
--- a/timm/models/__init__.py
+++ b/timm/models/__init__.py
@@ -91,7 +91,7 @@
 from ._helpers import clean_state_dict, load_state_dict, load_checkpoint, remap_state_dict, resume_checkpoint
 from ._hub import load_model_config_from_hf, load_state_dict_from_hf, push_to_hf_hub
 from ._manipulate import model_parameters, named_apply, named_modules, named_modules_with_params, \
-    group_modules, group_parameters, checkpoint_seq, adapt_input_conv
+    group_modules, group_parameters, checkpoint_seq, checkpoint, adapt_input_conv
 from ._pretrained import PretrainedCfg, DefaultCfg, filter_pretrained_cfg
 from ._prune import adapt_model_from_string
 from ._registry import split_model_name_tag, get_arch_name, generate_default_cfgs, register_model, \
diff --git a/timm/models/vision_transformer_sam.py b/timm/models/vision_transformer_sam.py
index 3fb0b59f1f..3979b3b4a1 100644
--- a/timm/models/vision_transformer_sam.py
+++ b/timm/models/vision_transformer_sam.py
@@ -24,7 +24,7 @@
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
 from ._features_fx import register_notrace_function
-from ._manipulate import checkpoint_seq, checkpoint
+from ._manipulate import checkpoint_seq
 from ._registry import generate_default_cfgs, register_model
 
 # model_registry will add each entrypoint fn to this
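
Usage sketch of the wrapper introduced by this series (assumes both patches are
applied; the toy `Block` module, sizes, and variable names below are
illustrative only, not part of timm):

import torch
import torch.nn as nn

from timm.layers import set_reentrant_ckpt, use_reentrant_ckpt
from timm.models import checkpoint, checkpoint_seq


class Block(nn.Module):
    """Stand-in for a model block expensive enough to be worth checkpointing."""

    def __init__(self, dim: int = 32):
        super().__init__()
        self.fc = nn.Linear(dim, dim)

    def forward(self, x):
        return torch.relu(self.fc(x))


blocks = nn.Sequential(*[Block() for _ in range(4)])
x = torch.randn(2, 32, requires_grad=True)

# With no env var and no explicit argument, the wrapper resolves use_reentrant
# from the global flag, which defaults to False (non-reentrant checkpointing).
assert not use_reentrant_ckpt()
y = checkpoint(blocks[0], x)

# checkpoint_seq() resolves its use_reentrant default the same way; here each
# checkpointed segment covers 2 of the 4 blocks.
y = checkpoint_seq(blocks, x, every=2)
y.sum().backward()

# The global default can be flipped at runtime. The env var equivalent is
# TIMM_REENTRANT_CKPT, read at import time and parsed with bool(), so any
# non-empty value (including '0') enables re-entrant mode.
set_reentrant_ckpt(True)

# An explicit per-call argument always wins over the global flag.
y = checkpoint(blocks[-1], x, use_reentrant=False)

Note the design choice: models now route through timm's own `checkpoint` /
`checkpoint_seq` rather than calling `torch.utils.checkpoint` directly, so the
`use_reentrant` policy lives in one place instead of being repeated per model.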