- Sponsors
- What's New
- Introduction
- Models
- Features
- Results
- Getting Started (Documentation)
- Train, Validation, Inference Scripts
- Awesome PyTorch Resources
- Licenses
- Citing
A big thank you to my GitHub Sponsors for their support!
In addition to the sponsors at the link above, I've received hardware and/or cloud resources from:
- Nvidia (https://www.nvidia.com/en-us/)
- TFRC (https://www.tensorflow.org/tfrc)
I'm fortunate to be able to dedicate significant time and money of my own to supporting this and other open source projects. However, as the projects increase in scope, outside support is needed to continue with the current trajectory of cloud services, hardware, and electricity costs.
- Merge `norm_norm_norm`. IMPORTANT: this update for a coming 0.6.x release will likely de-stabilize the master branch for a while. Branch `0.5.x` or a previous 0.5.x release can be used if stability is required.
- Significant weights update (all TPU trained) as described in this release
- `regnety_040` - 82.3 @ 224, 82.96 @ 288
- `regnety_064` - 83.0 @ 224, 83.65 @ 288
- `regnety_080` - 83.17 @ 224, 83.86 @ 288
- `regnetv_040` - 82.44 @ 224, 83.18 @ 288 (timm pre-act)
- `regnetv_064` - 83.1 @ 224, 83.71 @ 288 (timm pre-act)
- `regnetz_040` - 83.67 @ 256, 84.25 @ 320
- `regnetz_040h` - 83.77 @ 256, 84.5 @ 320 (w/ extra fc in head)
- `resnetv2_50d_gn` - 80.8 @ 224, 81.96 @ 288 (pre-act GroupNorm)
- `resnetv2_50d_evos` - 80.77 @ 224, 82.04 @ 288 (pre-act EvoNormS)
- `regnetz_c16_evos` - 81.9 @ 256, 82.64 @ 320 (EvoNormS)
- `regnetz_d8_evos` - 83.42 @ 256, 84.04 @ 320 (EvoNormS)
- `xception41p` - 82 @ 299 (timm pre-act)
- `xception65` - 83.17 @ 299
- `xception65p` - 83.14 @ 299 (timm pre-act)
- `resnext101_64x4d` - 82.46 @ 224, 83.16 @ 288
- `seresnext101_32x8d` - 83.57 @ 224, 84.270 @ 288
- `resnetrs200` - 83.85 @ 256, 84.44 @ 320
- HuggingFace hub support fixed w/ initial groundwork for allowing alternative 'config sources' for pretrained model definitions and weights (generic local file / remote url support soon)
- SwinTransformer-V2 implementation added. Submitted by Christoph Reich. Training experiments and model changes by myself are ongoing so expect compat breaks.
- MobileViT models w/ weights adapted from https://github.com/apple/ml-cvnets
- PoolFormer models w/ weights adapted from https://github.com/sail-sg/poolformer
- VOLO models w/ weights adapted from https://github.com/sail-sg/volo
- Significant work experimenting with non-BatchNorm norm layers such as EvoNorm, FilterResponseNorm, GroupNorm, etc.
- Enhanced support for alternate norm + act ('NormAct') layers added to a number of models, esp. EfficientNet/MobileNetV3, RegNet, and aligned Xception
- Grouped conv support added to EfficientNet family
- Add 'group matching' API to all models to allow grouping model parameters for application of 'layer-wise' LR decay, lr scale added to LR scheduler
- Gradient checkpointing support added to many models
- `forward_head(x, pre_logits=False)` fn added to all models to allow separate calls of `forward_features` + `forward_head`
- All vision transformer and vision MLP models updated to return non-pooled / non-token selected features from `forward_features`, for consistency with CNN models; token selection or pooling is now applied in `forward_head`
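Layer-wise LR decay (enabled by the 'group matching' API) boils down to giving each parameter group its own LR multiplier that shrinks geometrically toward the input. A minimal sketch in plain Python, assuming the groups are already numbered from stem (0) to head (last); `layer_lr_scales` and `grouped_param_lrs` are hypothetical helpers, not timm's API, and in timm the actual grouping comes from each model's group matching rules.

```python
from __future__ import annotations


def layer_lr_scales(num_layers: int, decay: float) -> list[float]:
    """One LR multiplier per layer group: index 0 is the stem, the last index the head.

    The head group keeps the full base LR; every group closer to the input is
    scaled down by one more factor of `decay`.
    """
    if not 0.0 < decay <= 1.0:
        raise ValueError("decay must be in (0, 1]")
    return [decay ** (num_layers - 1 - i) for i in range(num_layers)]


def grouped_param_lrs(group_ids: dict[str, int], base_lr: float, decay: float) -> dict[str, float]:
    """Map each (hypothetical) parameter name to the LR its group would use."""
    num_layers = max(group_ids.values()) + 1
    scales = layer_lr_scales(num_layers, decay)
    return {name: base_lr * scales[gid] for name, gid in group_ids.items()}


if __name__ == "__main__":
    # Toy grouping: stem -> 0, two blocks -> 1 and 2, classifier head -> 3.
    groups = {"stem.weight": 0, "blocks.0.weight": 1, "blocks.1.weight": 2, "head.weight": 3}
    for name, lr in grouped_param_lrs(groups, base_lr=1e-3, decay=0.5).items():
        print(f"{name}: {lr:.6f}")  # prints 0.000125, 0.000250, 0.000500, 0.001000 in turn
```

A geometric schedule keeps early, more generic layers close to their pretrained values during fine-tuning while the head adapts freely.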
- Chris Hughes posted an exhaustive run-through of `timm` on his blog yesterday. Well worth a read: Getting Started with PyTorch Image Models (timm): A Practitioner’s Guide
- I'm currently prepping to merge the `norm_norm_norm` branch back to master (ver 0.6.x) in the next week or so.
- The changes are more extensive than usual and may destabilize and break some model API use (aiming for full backwards compat). So, beware `pip install git+https://github.com/rwightman/pytorch-image-models` installs!
- `0.5.x` releases and a `0.5.x` branch will remain stable with a cherry pick or two until the dust clears. Recommend sticking to the pypi install for a bit if you want stable.
- Version 0.5.4 release to be pushed to pypi. It's been a while since the last pypi update and riskier changes will be merged to the main branch soon...
- Add ConvNeXt models w/ weights from official impl (https://github.com/facebookresearch/ConvNeXt), a few perf tweaks, compatible with timm features
- Tried training a few small (~1.8-3M param) / mobile optimized models, a few are good so far, more on the way...
- `mnasnet_small` - 65.6 top-1
- `mobilenetv2_050` - 65.9
- `lcnet_100/075/050` - 72.1 / 68.8 / 63.1
- `semnasnet_075` - 73
- `fbnetv3_b/d/g` - 79.1 / 79.7 / 82.0
- TinyNet models added by rsomani95
- LCNet added via MobileNetV3 architecture
- A number of updated weights and new model defs
- `eca_halonext26ts` - 79.5 @ 256
- `resnet50_gn` (new) - 80.1 @ 224, 81.3 @ 288
- `resnet50` - 80.7 @ 224, 80.9 @ 288 (trained at 176, not replacing current a1 weights as default since these don't scale as well to higher res, weights)
- `resnext50_32x4d` - 81.1 @ 224, 82.0 @ 288
- `sebotnet33ts_256` (new) - 81.2 @ 224
- `lamhalobotnet50ts_256` - 81.5 @ 256
- `halonet50ts` - 81.7 @ 256
- `halo2botnet50ts_256` - 82.0 @ 256
- `resnet101` - 82.0 @ 224, 82.8 @ 288
- `resnetv2_101` (new) - 82.1 @ 224, 83.0 @ 288
- `resnet152` - 82.8 @ 224, 83.5 @ 288
- `regnetz_d8` (new) - 83.5 @ 256, 84.0 @ 320
- `regnetz_e8` (new) - 84.5 @ 256, 85.0 @ 320
- `vit_base_patch8_224` (85.8 top-1) & `in21k` variant weights added, thanks Martins Bruveris
- Groundwork in for FX feature extraction thanks to Alexander Soare
- models updated for tracing compatibility (almost full support with some distilled transformer exceptions)
- ResNet strikes back (https://arxiv.org/abs/2110.00476) weights added, plus any extra training components used. Model weights and some more details here (https://github.com/rwightman/pytorch-image-models/releases/tag/v0.1-rsb-weights)
- BCE loss and Repeated Augmentation support for RSB paper
- 4 series of ResNet based attention model experiments being added (implemented across byobnet.py/byoanet.py). These include all sorts of attention, from channel attn like SE, ECA to 2D QKV self-attention layers such as Halo, Bottleneck, Lambda. Details here (https://github.com/rwightman/pytorch-image-models/releases/tag/v0.1-attn-weights)
- Working implementations of the following 2D self-attention modules (there are likely to be differences from the paper or eventual official impl):
- Halo (https://arxiv.org/abs/2103.12731)
- Bottleneck Transformer (https://arxiv.org/abs/2101.11605)
- LambdaNetworks (https://arxiv.org/abs/2102.08602)
- A RegNetZ series of models with some attention experiments (being added to). These do not follow the paper (https://arxiv.org/abs/2103.06877) in any way other than block architecture, details of official models are not available. See more here (https://github.com/rwightman/pytorch-image-models/releases/tag/v0.1-attn-weights)
- ConvMixer (https://openreview.net/forum?id=TVHS5Y4dNvM), CrossVit (https://arxiv.org/abs/2103.14899), and BeiT (https://arxiv.org/abs/2106.08254) architectures + weights added
- freeze/unfreeze helpers by Alexander Soare
- Optimizer bonanza!
- Add LAMB and LARS optimizers, incl trust ratio clipping options. Tweaked to work properly in PyTorch XLA (tested on TPUs w/ `timm bits` branch)
- Add MADGRAD from FB research w/ a few tweaks (decoupled decay option, step handling that works with PyTorch XLA)
- Some cleanup on all optimizers and factory. No more `.data`, a bit more consistency, unit tests for all!
- SGDP and AdamP still won't work with PyTorch XLA but others should (have yet to test Adabelief, Adafactor, Adahessian myself).
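LAMB and LARS both rescale each layer's step by a 'trust ratio' that compares the weight norm to the update norm, where the update is whatever step the base rule produces (an Adam-style update for LAMB, the gradient for LARS). A rough sketch in plain Python, assuming the simplest form ||w|| / ||u||, a fallback of 1.0 when either norm is zero, and clipping that caps the ratio at 1.0. The helper names are hypothetical, LARS's extra coefficient and weight decay term are left out, and the exact rules in timm's optimizers may differ.

```python
from __future__ import annotations
import math


def l2_norm(values: list[float]) -> float:
    return math.sqrt(sum(v * v for v in values))


def trust_ratio(weights: list[float], update: list[float], clip: bool = False) -> float:
    """Layer-wise trust ratio ||w|| / ||u||, or 1.0 if either norm is zero.

    With clip=True the ratio is capped at 1.0, so it can only shrink a step.
    """
    w_norm, u_norm = l2_norm(weights), l2_norm(update)
    if w_norm == 0.0 or u_norm == 0.0:
        return 1.0
    ratio = w_norm / u_norm
    return min(ratio, 1.0) if clip else ratio


def apply_update(weights: list[float], update: list[float], lr: float,
                 clip: bool = False) -> list[float]:
    """One layer step: w <- w - lr * trust_ratio * u."""
    ratio = trust_ratio(weights, update, clip)
    return [w - lr * ratio * u for w, u in zip(weights, update)]


if __name__ == "__main__":
    w, u = [3.0, 4.0], [0.3, 0.4]  # ||w|| = 5, ||u|| = 0.5, so the ratio is about 10
    print(trust_ratio(w, u), trust_ratio(w, u, clip=True))
    print(apply_update(w, u, lr=0.01))
```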
- EfficientNet-V2 XL TF ported weights added, but they don't validate well in PyTorch (L is better). The pre-processing for the V2 TF training is a bit different and the fine-tuned 21k -> 1k weights are very sensitive and less robust than the 1k weights.
- Added PyTorch trained EfficientNet-V2 'Tiny' w/ GlobalContext attn weights. Only .1-.2 top-1 better than the SE so more of a curiosity for those interested.
- Add XCiT models from official facebook impl. Contributed by Alexander Soare
- Add `efficientnetv2_rw_t` weights, a custom 'tiny' 13.6M param variant that is a bit better than (non NoisyStudent) B3 models. Both faster and better accuracy (at same or lower res)
- top-1 82.34 @ 288x288 and 82.54 @ 320x320
- Add SAM pretrained in1k weight for ViT B/16 (`vit_base_patch16_sam_224`) and B/32 (`vit_base_patch32_sam_224`) models.
- Add 'Aggregating Nested Transformer' (NesT) w/ weights converted from official Flax impl. Contributed by Alexander Soare.
- `jx_nest_base` - 83.534
- `jx_nest_small` - 83.120
- `jx_nest_tiny` - 81.426
- Reproduce gMLP model training, `gmlp_s16_224` trained to 79.6 top-1, matching paper. Hparams for this and other recent MLP training here
- Release Vision Transformer 'AugReg' weights from How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers
- .npz weight loading support added, can load any of the 50K+ weights from the AugReg series
- See example notebook from official impl for navigating the augreg weights
- Replaced all default weights w/ best AugReg variant (if possible). All AugReg 21k classifiers work.
- Highlights: `vit_large_patch16_384` (87.1 top-1), `vit_large_r50_s32_384` (86.2 top-1), `vit_base_patch16_384` (86.0 top-1)
- `vit_deit_*` renamed to just `deit_*`
- Remove my old small model, replace with DeiT compatible small w/ AugReg weights
- Add 1st training of my `gmixer_24_224` MLP w/ GLU, 78.1 top-1 w/ 25M params.
- Add weights from official ResMLP release (https://github.com/facebookresearch/deit)
- Add `eca_nfnet_l2` weights from my 'lightweight' series. 84.7 top-1 at 384x384.
- Add distilled BiT 50x1 student and 152x2 Teacher weights from Knowledge distillation: A good teacher is patient and consistent
- NFNets and ResNetV2-BiT models work w/ PyTorch XLA now
- weight standardization uses F.batch_norm instead of std_mean (std_mean wasn't lowered)
- eps values adjusted, will be slight differences but should be quite close
- Improve test coverage and classifier interface of non-conv (vision transformer and mlp) models
- Cleanup a few classifier / flatten details for models w/ conv classifiers or early global pool
- Please report any regressions, this PR touched quite a few models.
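For reference, weight standardization normalizes each output channel's weights to zero mean and unit variance before the conv is applied. Those per-channel statistics are what a training-mode batch norm computes anyway, which is presumably why `F.batch_norm` can stand in for `std_mean`. A plain-Python sketch over a flattened [out_channels][fan_in] weight; `standardize_weight` and its eps value are illustrative assumptions, not timm's defaults.

```python
from __future__ import annotations
import math


def standardize_weight(weight: list[list[float]], eps: float = 1e-5) -> list[list[float]]:
    """Weight Standardization for a [out_channels][fan_in] weight.

    Each output channel is shifted to zero mean and divided by sqrt(biased variance + eps),
    the same per-channel statistics a training-mode batch norm would compute.
    """
    out = []
    for row in weight:
        n = len(row)
        mean = sum(row) / n
        var = sum((v - mean) ** 2 for v in row) / n  # biased (population) variance
        inv_std = 1.0 / math.sqrt(var + eps)
        out.append([(v - mean) * inv_std for v in row])
    return out


if __name__ == "__main__":
    w = [[1.0, 2.0, 3.0, 4.0], [0.5, 0.5, 0.5, 0.5]]
    for row in standardize_weight(w):
        print([round(v, 4) for v in row])
```

The constant second row shows why eps matters: its variance is zero, so eps alone keeps the division finite.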
- Add first ResMLP weights, trained in PyTorch XLA on TPU-VM w/ my XLA branch. 24 block variant, 79.2 top-1.
- Add ResNet51-Q model w/ pretrained weights at 82.36 top-1.
- NFNet inspired block layout with quad layer stem and no maxpool
- Same param count (35.7M) and throughput as ResNetRS-50 but +1.5 top-1 @ 224x224 and +2.5 top-1 at 288x288
- Add LeViT, Visformer, ConViT (PR by Aman Arora), Twins (PR by paper authors) transformer models
- Add ResMLP and gMLP MLP vision models to the existing MLP Mixer impl
- Fix a number of torchscript issues with various vision transformer models
- Cleanup input_size/img_size override handling and improve testing / test coverage for all vision transformer and MLP models
- More flexible pos embedding resize (non-square) for ViT and TnT. Thanks Alexander Soare
- Add `efficientnetv2_rw_m` model and weights (started training before official code). 84.8 top-1, 53M params.
- Add EfficientNet-V2 official model defs w/ ported weights from official Tensorflow/Keras impl.
- 1k trained variants: `tf_efficientnetv2_s/m/l`
- 21k trained variants: `tf_efficientnetv2_s/m/l_in21k`
- 21k pretrained -> 1k fine-tuned: `tf_efficientnetv2_s/m/l_in21ft1k`
- v2 models w/ v1 scaling: `tf_efficientnetv2_b0` through `b3`
- Rename my prev V2 guess `efficientnet_v2s` -> `efficientnetv2_rw_s`
- Some blank `efficientnetv2_*` models in-place for future native PyTorch training
- Add MLP-Mixer models and port pretrained weights from Google JAX impl
- Add CaiT models and pretrained weights from FB
- Add ResNet-RS models and weights from TF. Thanks Aman Arora
- Add CoaT models and weights. Thanks Mohammed Rizin
- Add new ImageNet-21k weights & finetuned weights for TResNet, MobileNet-V3, ViT models. Thanks mrT
- Add GhostNet models and weights. Thanks Kai Han
- Update ByoaNet attention modules
- Improve SA module inits
- Hack together experimental stand-alone Swin based attn module and `swinnet`
- Consistent '26t' model defs for experiments.
- Add improved EfficientNet-V2S (prelim model def) weights. 83.8 top-1.
- WandB logging support
- Add Swin Transformer models and weights from https://github.com/microsoft/Swin-Transformer
- Add ECA-NFNet-L1 (slimmed down F1 w/ SiLU, 41M params) trained with this code. 84% top-1 @ 320x320. Trained at 256x256.
- Add EfficientNet-V2S model (unverified model definition) weights. 83.3 top-1 @ 288x288. Only trained single res 224. Working on progressive training.
- Add ByoaNet model definition (Bring-your-own-attention) w/ SelfAttention block and corresponding SA/SA-like modules and model defs
- Lambda Networks - https://arxiv.org/abs/2102.08602
- Bottleneck Transformers - https://arxiv.org/abs/2101.11605
- Halo Nets - https://arxiv.org/abs/2103.12731
- Adabelief optimizer contributed by Juntang Zhuang
- Add snazzy `benchmark.py` script for bulk `timm` model benchmarking of train and/or inference
- Add Pooling-based Vision Transformer (PiT) models (from https://github.com/naver-ai/pit)
- Merged distilled variant into main for torchscript compatibility
- Some `timm` cleanup/style tweaks and weights have hub download support
- Cleanup Vision Transformer (ViT) models
- Merge distilled (DeiT) model into main so that torchscript can work
- Support updated weight init (defaults to old still) that closer matches original JAX impl (possibly better training from scratch)
- Separate hybrid model defs into different file and add several new model defs to fiddle with, support patch_size != 1 for hybrids
- Fix fine-tuning num_class changes (PiT and ViT) and pos_embed resizing (ViT) with distilled variants
- nn.Sequential for block stack (does not break downstream compat)
- TnT (Transformer-in-Transformer) models contributed by author (from https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/TNT)
- Add RegNetY-160 weights from DeiT teacher model
- Add new NFNet-L0 w/ SE attn (rename `nfnet_l0b` -> `nfnet_l0`) weights 82.75 top-1 @ 288x288
- Some fixes/improvements for TFDS dataset wrapper
- Add new ECA-NFNet-L0 (rename `nfnet_l0c` -> `eca_nfnet_l0`) weights trained by myself.
- 82.6 top-1 @ 288x288, 82.8 @ 320x320, trained at 224x224
- Uses SiLU activation, approx 2x faster than `dm_nfnet_f0` and 50% faster than `nfnet_f0s` w/ 1/3 param count
- Integrate Hugging Face model hub into timm create_model and default_cfg handling for pretrained weight and config sharing (more on this soon!)
- Merge HardCoRe NAS models contributed by https://github.com/yoniaflalo
- Merge PyTorch trained EfficientNet-EL and pruned ES/EL variants contributed by DeGirum
- First 0.4.x PyPI release w/ NFNets (& related), ByoB (GPU-Efficient, RepVGG, etc).
- Change feature extraction for pre-activation nets (NFNets, ResNetV2) to return features before activation.
- Tested with PyTorch 1.8 release. Updated CI to use 1.8.
- Benchmarked several arch on RTX 3090, Titan RTX, and V100 across 1.7.1, 1.8, NGC 20.12, and 21.02. Some interesting performance variations to take note of https://gist.github.com/rwightman/bb59f9e245162cee0e38bd66bd8cd77f
- Add pretrained weights and model variants for NFNet-F* models from DeepMind Haiku impl.
- Models are prefixed with `dm_`. They require SAME padding conv, skipinit enabled, and activation gains applied in act fn.
- These models are big, expect to run out of GPU memory. With the GELU activation + other options, they are roughly 1/2 the inference speed of my SiLU PyTorch optimized `s` variants.
- Original model results are based on pre-processing that is not the same as all other models so you'll see different results in the results csv (once updated).
- Matching the original pre-processing as closely as possible I get these results:
- `dm_nfnet_f6` - 86.352
- `dm_nfnet_f5` - 86.100
- `dm_nfnet_f4` - 85.834
- `dm_nfnet_f3` - 85.676
- `dm_nfnet_f2` - 85.178
- `dm_nfnet_f1` - 84.696
- `dm_nfnet_f0` - 83.464
- Add Adaptive Gradient Clipping (AGC) as per https://arxiv.org/abs/2102.06171. Integrated w/ PyTorch gradient clipping via mode arg that defaults to prev 'norm' mode. For backward arg compat, clip-grad arg must be specified to enable when using train.py.
- AGC w/ default clipping factor `--clip-grad .01 --clip-mode agc`
- PyTorch global norm of 1.0 (old behaviour, always norm), `--clip-grad 1.0`
- PyTorch value clipping of 10, `--clip-grad 10. --clip-mode value`
- AGC performance is definitely sensitive to the clipping factor. More experimentation needed to determine good values for smaller batch sizes and optimizers besides those in paper. So far I've found .001-.005 is necessary for stable RMSProp training w/ NFNet/NF-ResNet.
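AGC clips unit-wise: each output unit's gradient is rescaled whenever its norm exceeds `clip_factor` times the norm of that unit's weights. A sketch for a 2D weight in plain Python, using the default factor of .01 from the `--clip-grad .01 --clip-mode agc` example and, as an assumption, eps = 1e-3 as a floor on the weight norm so zero-initialized units can still move. For conv weights a unit would be an output channel with the remaining dims flattened. `adaptive_clip_rows` is a hypothetical name, not timm's function.

```python
from __future__ import annotations
import math


def unit_norm(row: list[float]) -> float:
    return math.sqrt(sum(v * v for v in row))


def adaptive_clip_rows(weight: list[list[float]], grad: list[list[float]],
                       clip_factor: float = 0.01, eps: float = 1e-3) -> list[list[float]]:
    """Unit-wise adaptive gradient clipping for a 2D [out_units][fan_in] weight.

    Each unit's gradient row is rescaled so its L2 norm is at most
    clip_factor * max(||weight row||, eps); rows already within the limit are unchanged.
    """
    clipped = []
    for w_row, g_row in zip(weight, grad):
        max_norm = max(unit_norm(w_row), eps) * clip_factor
        g_norm = unit_norm(g_row)
        if g_norm > max_norm:
            clipped.append([g * (max_norm / g_norm) for g in g_row])
        else:
            clipped.append(list(g_row))
    return clipped


if __name__ == "__main__":
    w = [[3.0, 4.0], [0.0, 0.0]]  # unit weight norms 5.0 and 0.0
    g = [[0.3, 0.4], [1.0, 0.0]]  # unit grad norms 0.5 and 1.0
    print(adaptive_clip_rows(w, g))
```

Here the first row is scaled down 10x to a norm of 0.05, while the all-zero second row is limited by the eps floor, which illustrates why the clipping factor and eps interact so strongly with training stability.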
- Update Normalization-Free nets to include new NFNet-F (https://arxiv.org/abs/2102.06171) model defs
- First Normalization-Free model training experiments done:
- nf_resnet50 - 80.68 top-1 @ 288x288, 80.31 @ 256x256
- nf_regnet_b1 - 79.30 @ 288x288, 78.75 @ 256x256
- More model archs, incl a flexible ByobNet backbone ('Bring-your-own-blocks')
- GPU-Efficient-Networks (https://github.com/idstcv/GPU-Efficient-Networks), impl in `byobnet.py`
- RepVGG (https://github.com/DingXiaoH/RepVGG), impl in `byobnet.py`
- classic VGG (from torchvision, impl in `vgg.py`)
- Refinements to normalizer layer arg handling and normalizer+act layer handling in some models
- Default AMP mode changed to native PyTorch AMP instead of APEX. Issues not being fixed with APEX. Native works with `--channels-last` and `--torchscript` model training, APEX does not.
- Fix a few bugs introduced since last pypi release
- Add several ResNet weights with ECA attention. 26t & 50t trained @ 256, test @ 320. 269d train @ 256, fine-tune @ 320, test @ 352.
- `ecaresnet26t` - 79.88 top-1 @ 320x320, 79.08 @ 256x256
- `ecaresnet50t` - 82.35 top-1 @ 320x320, 81.52 @ 256x256
- `ecaresnet269d` - 84.93 top-1 @ 352x352, 84.87 @ 320x320
- Remove separate tiered (`t`) vs tiered_narrow (`tn`) ResNet model defs, all `tn` changed to `t` and `t` models removed (`seresnext26t_32x4d` only model w/ weights that was removed).
- Support model default_cfgs with separate train vs test resolution `test_input_size` and remove `extra_320` suffix ResNet model defs that were just for test.
- Add initial "Normalization Free" NF-RegNet-B* and NF-ResNet model definitions based on paper
- Add ResNetV2 Big Transfer (BiT) models w/ ImageNet-1k and 21k weights from https://github.com/google-research/big_transfer
- Add official R50+ViT-B/16 hybrid models + weights from https://github.com/google-research/vision_transformer
- ImageNet-21k ViT weights are added w/ model defs and representation layer (pre logits) support
- NOTE: ImageNet-21k classifier heads were zero'd in original weights, they are only useful for transfer learning
- Add model defs and weights for DeiT Vision Transformer models from https://github.com/facebookresearch/deit
- Refactor dataset classes into ImageDataset/IterableImageDataset + dataset specific parser classes
- Add Tensorflow-Datasets (TFDS) wrapper to allow use of TFDS image classification sets with train script
- Ex: `train.py /data/tfds --dataset tfds/oxford_iiit_pet --val-split test --model resnet50 -b 256 --amp --num-classes 37 --opt adamw --lr 3e-4 --weight-decay .001 --pretrained -j 2`
- Add improved .tar dataset parser that reads images from .tar, folder of .tar files, or .tar within .tar
- Run validation on full ImageNet-21k directly from tar w/ BiT model: `validate.py /data/fall11_whole.tar --model resnetv2_50x1_bitm_in21k --amp`
- Models in this update should be stable w/ possible exception of ViT/BiT, possibility of some regressions with train/val scripts and dataset handling
- Add SE-ResNet-152D weights
- 256x256 val, 0.94 crop top-1 - 83.75
- 320x320 val, 1.0 crop - 84.36
- Update results files
PyTorch Image Models (`timm`) is a collection of image models, layers, utilities, optimizers, schedulers, data-loaders / augmentations, and reference training / validation scripts that aim to pull together a wide variety of SOTA models with ability to reproduce ImageNet training results.
The work of many others is present here. I've tried to make sure all source material is acknowledged via links to github, arxiv papers, etc in the README, documentation, and code docstrings. Please let me know if I missed anything.
All model architecture families include variants with pretrained weights. Some specific model variants have no weights; this is NOT a bug. Help training new or better weights is always appreciated. Here are some example training hparams to get you started.
A full version of the list below with source links can be found in the documentation.
- Aggregating Nested Transformers - https://arxiv.org/abs/2105.12723
- BEiT - https://arxiv.org/abs/2106.08254
- Big Transfer ResNetV2 (BiT) - https://arxiv.org/abs/1912.11370
- Bottleneck Transformers - https://arxiv.org/abs/2101.11605
- CaiT (Class-Attention in Image Transformers) - https://arxiv.org/abs/2103.17239
- CoaT (Co-Scale Conv-Attentional Image Transformers) - https://arxiv.org/abs/2104.06399
- ConvNeXt - https://arxiv.org/abs/2201.03545
- ConViT (Soft Convolutional Inductive Biases Vision Transformers) - https://arxiv.org/abs/2103.10697
- CspNet (Cross-Stage Partial Networks) - https://arxiv.org/abs/1911.11929
- DeiT (Vision Transformer) - https://arxiv.org/abs/2012.12877
- DenseNet - https://arxiv.org/abs/1608.06993
- DLA - https://arxiv.org/abs/1707.06484
- DPN (Dual-Path Network) - https://arxiv.org/abs/1707.01629
- EfficientNet (MBConvNet Family)
- EfficientNet NoisyStudent (B0-B7, L2) - https://arxiv.org/abs/1911.04252
- EfficientNet AdvProp (B0-B8) - https://arxiv.org/abs/1911.09665
- EfficientNet (B0-B7) - https://arxiv.org/abs/1905.11946
- EfficientNet-EdgeTPU (S, M, L) - https://ai.googleblog.com/2019/08/efficientnet-edgetpu-creating.html
- EfficientNet V2 - https://arxiv.org/abs/2104.00298
- FBNet-C - https://arxiv.org/abs/1812.03443
- MixNet - https://arxiv.org/abs/1907.09595
- MNASNet B1, A1 (Squeeze-Excite), and Small - https://arxiv.org/abs/1807.11626
- MobileNet-V2 - https://arxiv.org/abs/1801.04381
- Single-Path NAS - https://arxiv.org/abs/1904.02877
- TinyNet - https://arxiv.org/abs/2010.14819
- GhostNet - https://arxiv.org/abs/1911.11907
- gMLP - https://arxiv.org/abs/2105.08050
- GPU-Efficient Networks - https://arxiv.org/abs/2006.14090
- Halo Nets - https://arxiv.org/abs/2103.12731
- HRNet - https://arxiv.org/abs/1908.07919
- Inception-V3 - https://arxiv.org/abs/1512.00567
- Inception-ResNet-V2 and Inception-V4 - https://arxiv.org/abs/1602.07261
- Lambda Networks - https://arxiv.org/abs/2102.08602
- LeViT (Vision Transformer in ConvNet's Clothing) - https://arxiv.org/abs/2104.01136
- MLP-Mixer - https://arxiv.org/abs/2105.01601
- MobileNet-V3 (MBConvNet w/ Efficient Head) - https://arxiv.org/abs/1905.02244
- FBNet-V3 - https://arxiv.org/abs/2006.02049
- HardCoRe-NAS - https://arxiv.org/abs/2102.11646
- LCNet - https://arxiv.org/abs/2109.15099
- NASNet-A - https://arxiv.org/abs/1707.07012
- NesT - https://arxiv.org/abs/2105.12723
- NFNet-F - https://arxiv.org/abs/2102.06171
- NF-RegNet / NF-ResNet - https://arxiv.org/abs/2101.08692
- PNasNet - https://arxiv.org/abs/1712.00559
- Pooling-based Vision Transformer (PiT) - https://arxiv.org/abs/2103.16302
- RegNet - https://arxiv.org/abs/2003.13678
- RepVGG - https://arxiv.org/abs/2101.03697
- ResMLP - https://arxiv.org/abs/2105.03404
- ResNet/ResNeXt
- ResNet (v1b/v1.5) - https://arxiv.org/abs/1512.03385
- ResNeXt - https://arxiv.org/abs/1611.05431
- 'Bag of Tricks' / Gluon C, D, E, S variations - https://arxiv.org/abs/1812.01187
- Weakly-supervised (WSL) Instagram pretrained / ImageNet tuned ResNeXt101 - https://arxiv.org/abs/1805.00932
- Semi-supervised (SSL) / Semi-weakly Supervised (SWSL) ResNet/ResNeXts - https://arxiv.org/abs/1905.00546
- ECA-Net (ECAResNet) - https://arxiv.org/abs/1910.03151v4
- Squeeze-and-Excitation Networks (SEResNet) - https://arxiv.org/abs/1709.01507
- ResNet-RS - https://arxiv.org/abs/2103.07579
- Res2Net - https://arxiv.org/abs/1904.01169
- ResNeSt - https://arxiv.org/abs/2004.08955
- ReXNet - https://arxiv.org/abs/2007.00992
- SelecSLS - https://arxiv.org/abs/1907.00837
- Selective Kernel Networks - https://arxiv.org/abs/1903.06586
- Swin Transformer - https://arxiv.org/abs/2103.14030
- Transformer-iN-Transformer (TNT) - https://arxiv.org/abs/2103.00112
- TResNet - https://arxiv.org/abs/2003.13630
- Twins (Spatial Attention in Vision Transformers) - https://arxiv.org/pdf/2104.13840.pdf
- Visformer - https://arxiv.org/abs/2104.12533
- Vision Transformer - https://arxiv.org/abs/2010.11929
- VovNet V2 and V1 - https://arxiv.org/abs/1911.06667
- Xception - https://arxiv.org/abs/1610.02357
- Xception (Modified Aligned, Gluon) - https://arxiv.org/abs/1802.02611
- Xception (Modified Aligned, TF) - https://arxiv.org/abs/1802.02611
- XCiT (Cross-Covariance Image Transformers) - https://arxiv.org/abs/2106.09681
Several (less common) features that I often utilize in my projects are included. Many of their additions are the reason why I maintain my own set of models, instead of using others' via PIP:
- All models have a common default configuration interface and API for
- accessing/changing the classifier - `get_classifier` and `reset_classifier`
- doing a forward pass on just the features - `forward_features` (see documentation)
- these make it easy to write consistent network wrappers that work with any of the models
- All models support multi-scale feature map extraction (feature pyramids) via create_model (see documentation)
- `create_model(name, features_only=True, out_indices=..., output_stride=...)`
- `out_indices` creation arg specifies which feature maps to return, these indices are 0 based and generally correspond to the `C(i + 1)` feature level.
- `output_stride` creation arg controls output stride of the network by using dilated convolutions. Most networks are stride 32 by default. Not all networks support this.
- feature map channel counts, reduction level (stride) can be queried AFTER model creation via the `.feature_info` member
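The index and stride bookkeeping for feature extraction can be sketched in a few lines, assuming a network with five stride-2 downsampling steps (stride 32 overall), which is common but not universal. `feature_reductions` and `plan_dilation` are hypothetical helpers for illustration only; real models report this information through `.feature_info` and apply dilation inside their own stage builders, with more detail than shown here.

```python
from __future__ import annotations


def feature_reductions(out_indices: tuple[int, ...], output_stride: int = 32) -> list[int]:
    """Stride of each requested level, assuming level i is C(i + 1) at stride 2 ** (i + 1).

    With output_stride < 32 (dilation), deeper levels stop shrinking at output_stride.
    """
    return [min(2 ** (i + 1), output_stride) for i in out_indices]


def plan_dilation(stage_strides: list[int], output_stride: int = 32) -> list[tuple[int, int]]:
    """(stride, dilation) per stage so the overall stride never exceeds output_stride.

    Once the accumulated stride reaches output_stride, each further stride-s step is
    replaced by stride 1 and the dilation is multiplied by s instead.
    """
    net_stride, dilation, plan = 1, 1, []
    for s in stage_strides:
        if net_stride >= output_stride:
            dilation *= s
            s = 1
        else:
            net_stride *= s
        plan.append((s, dilation))
    return plan


if __name__ == "__main__":
    print(feature_reductions((1, 2, 3, 4)))          # C2..C5 at the default stride 32
    print(plan_dilation([2] * 5, output_stride=8))   # last two stages dilated instead of strided
```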
- All models have a consistent pretrained weight loader that adapts last linear if necessary, and from 3 to 1 channel input if desired
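One natural way to adapt a pretrained RGB stem conv to 1-channel input is to sum its weights over the input-channel axis: for a grayscale image replicated across three channels, the summed kernel gives exactly the same response. A plain-Python sketch under that assumption, with a small check of the equivalence; the helper names are hypothetical and this is not necessarily how timm's loader does it.

```python
from __future__ import annotations


def rgb_conv_to_gray(weight: list) -> list:
    """Collapse a [out][3][kh][kw] conv weight to [out][1][kh][kw] by summing over input channels."""
    gray = []
    for kernel in weight:                      # kernel: [in_chans][kh][kw]
        kh, kw = len(kernel[0]), len(kernel[0][0])
        summed = [[sum(kernel[c][i][j] for c in range(len(kernel))) for j in range(kw)]
                  for i in range(kh)]
        gray.append([summed])                  # keep a singleton input-channel dim
    return gray


def conv_at(kernel: list, image: list, top: int = 0, left: int = 0) -> float:
    """One output value of a 'valid' conv for a single output channel at (top, left)."""
    return sum(
        kernel[c][i][j] * image[c][top + i][left + j]
        for c in range(len(kernel))
        for i in range(len(kernel[0]))
        for j in range(len(kernel[0][0]))
    )


if __name__ == "__main__":
    rgb_weight = [[[[1.0, 2.0]], [[0.5, -1.0]], [[0.0, 3.0]]]]  # 1 out channel, 3 in, 1x2 kernel
    gray_img = [[[2.0, 1.0]]]                                    # 1 channel, 1x2 image
    gray_weight = rgb_conv_to_gray(rgb_weight)
    # Same response from the RGB kernel on a replicated image and the summed kernel on the gray one.
    print(conv_at(rgb_weight[0], gray_img * 3), conv_at(gray_weight[0], gray_img))
```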
- High performance reference training, validation, and inference scripts that work in several process/GPU modes:
- NVIDIA DDP w/ a single GPU per process, multiple processes with APEX present (AMP mixed-precision optional)
- PyTorch DistributedDataParallel w/ multi-gpu, single process (AMP disabled as it crashes when enabled)
- PyTorch w/ single GPU single process (AMP optional)
- A dynamic global pool implementation that allows selecting from average pooling, max pooling, average + max, or concat([average, max]) at model creation. All global pooling is adaptive average by default and compatible with pretrained weights.
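The four pooling choices differ only in how they reduce each channel's spatial values. A plain-Python sketch over a [channels][spatial] feature map; `global_pool` is hypothetical, not timm's module, and 'avgmax' is assumed here to be the mean of the average and max results. Note that 'catavgmax' doubles the feature dimension, so the classifier input size has to match.

```python
from __future__ import annotations


def global_pool(fmap: list[list[float]], pool_type: str = "avg") -> list[float]:
    """Reduce a [channels][spatial] feature map to one vector.

    'avg' and 'max' return C values, 'avgmax' returns the mean of the two (C values),
    and 'catavgmax' concatenates them (2 * C values, so the classifier input doubles).
    """
    avg = [sum(ch) / len(ch) for ch in fmap]
    mx = [max(ch) for ch in fmap]
    if pool_type == "avg":
        return avg
    if pool_type == "max":
        return mx
    if pool_type == "avgmax":
        return [0.5 * (a + m) for a, m in zip(avg, mx)]
    if pool_type == "catavgmax":
        return avg + mx
    raise ValueError(f"unknown pool_type: {pool_type}")


if __name__ == "__main__":
    fmap = [[1.0, 2.0, 3.0, 6.0], [0.0, 0.0, 0.0, 4.0]]  # 2 channels, 2x2 spatial, flattened
    for kind in ("avg", "max", "avgmax", "catavgmax"):
        print(kind, global_pool(fmap, kind))
```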
- A 'Test Time Pool' wrapper that can wrap any of the included models and usually provides improved performance doing inference with input images larger than the training size. Idea adapted from original DPN implementation when I ported (https://github.com/cypw/DPNs)
- Learning rate schedulers
- Ideas adopted from
- AllenNLP schedulers
- FAIRseq lr_scheduler
- SGDR: Stochastic Gradient Descent with Warm Restarts (https://arxiv.org/abs/1608.03983)
- Schedulers include `step`, `cosine` w/ restarts, `tanh` w/ restarts, `plateau`
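The cosine schedule with restarts follows the SGDR formula lr = min_lr + 0.5 * (base_lr - min_lr) * (1 + cos(pi * t_cur / t_i)), where t_cur counts epochs since the last restart and t_i is the current cycle length. A per-epoch sketch in plain Python; `cosine_restart_lr` and its `cycle_mul` argument are illustrative, and timm's scheduler has more options (warmup, for example) that are not shown.

```python
from __future__ import annotations
import math


def cosine_restart_lr(epoch: int, base_lr: float, min_lr: float,
                      cycle_len: int, cycle_mul: float = 1.0) -> float:
    """SGDR-style cosine annealing with warm restarts, evaluated per epoch."""
    t_cur, t_i = epoch, cycle_len
    # Skip over completed cycles; each new cycle is cycle_mul times as long.
    while t_cur >= t_i:
        t_cur -= t_i
        t_i = max(1, round(t_i * cycle_mul))
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * t_cur / t_i))


if __name__ == "__main__":
    # 4-epoch cycles that double in length: restarts land on epochs 4 and 12.
    for epoch in range(14):
        print(epoch, round(cosine_restart_lr(epoch, 0.1, 0.0, 4, cycle_mul=2.0), 5))
```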
- Optimizers:
- `rmsprop_tf` adapted from PyTorch RMSProp by myself. Reproduces much improved Tensorflow RMSProp behaviour.
- `radam` by Liyuan Liu (https://arxiv.org/abs/1908.03265)
- `novograd` by Masashi Kimura (https://arxiv.org/abs/1905.11286)
- `lookahead` adapted from impl by Liam (https://arxiv.org/abs/1907.08610)
- `fused<name>` optimizers by name with NVIDIA Apex installed
- `adamp` and `sgdp` by Naver ClovAI (https://arxiv.org/abs/2006.08217)
- `adafactor` adapted from FAIRSeq impl (https://arxiv.org/abs/1804.04235)
- `adahessian` by David Samuel (https://arxiv.org/abs/2006.00719)
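Lookahead wraps an inner optimizer: the inner ('fast') weights take k steps, then the 'slow' weights move a fraction alpha toward them and the fast weights restart from the slow ones. A toy sketch minimising x**2 with plain SGD as the inner optimizer; the function name and the k and alpha values are illustrative choices, not timm defaults.

```python
from __future__ import annotations


def sgd_lookahead(x0: float, lr: float, steps: int, k: int = 5, alpha: float = 0.5) -> float:
    """Minimise f(x) = x**2 with plain SGD as the inner ('fast') optimizer wrapped in Lookahead."""
    slow = fast = x0
    for step in range(1, steps + 1):
        fast -= lr * 2.0 * fast            # inner SGD step, since df/dx = 2x
        if step % k == 0:                  # every k inner steps:
            slow += alpha * (fast - slow)  #   move slow weights toward fast weights
            fast = slow                    #   and restart the fast weights from there
    return fast


if __name__ == "__main__":
    # alpha = 1.0 reduces to plain SGD; smaller alpha takes more cautious outer steps.
    for alpha in (1.0, 0.5):
        print(alpha, sgd_lookahead(1.0, lr=0.1, steps=20, alpha=alpha))
```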
- Random Erasing from Zhun Zhong (https://arxiv.org/abs/1708.04896)
- Mixup (https://arxiv.org/abs/1710.09412)
- CutMix (https://arxiv.org/abs/1905.04899)
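Mixup blends two images and their targets with a weight lam (typically drawn from a Beta distribution), while CutMix pastes a box from one image into another and sets lam from the box area. A plain-Python sketch: the box sampling follows the CutMix paper's recipe (side length proportional to sqrt(1 - lam), uniform random center, clipped to the image), the helper names are hypothetical, and timm's own mixup implementation has more options than shown here.

```python
from __future__ import annotations
import math
import random


def mixup(x1: list[float], x2: list[float], y1: list[float], y2: list[float],
          lam: float) -> tuple[list[float], list[float]]:
    """Blend two flattened images and their one-hot targets: lam * first + (1 - lam) * second."""
    x = [lam * a + (1 - lam) * b for a, b in zip(x1, x2)]
    y = [lam * a + (1 - lam) * b for a, b in zip(y1, y2)]
    return x, y


def cutmix_box(h: int, w: int, lam: float,
               rng: random.Random) -> tuple[tuple[int, int, int, int], float]:
    """Sample a box covering about (1 - lam) of an h x w image.

    Returns (top, bottom, left, right) plus lam recomputed from the clipped box,
    i.e. 1 - box_area / image_area, the weight the first image's target keeps.
    """
    cut_ratio = math.sqrt(1.0 - lam)
    cut_h, cut_w = int(h * cut_ratio), int(w * cut_ratio)
    cy, cx = rng.randrange(h), rng.randrange(w)  # uniform box center
    top, bottom = max(cy - cut_h // 2, 0), min(cy + cut_h // 2, h)
    left, right = max(cx - cut_w // 2, 0), min(cx + cut_w // 2, w)
    lam_adj = 1.0 - (bottom - top) * (right - left) / (h * w)
    return (top, bottom, left, right), lam_adj


if __name__ == "__main__":
    rng = random.Random(0)
    lam = rng.betavariate(1.0, 1.0)  # alpha = 1.0 is an illustrative choice
    print(mixup([0.0, 1.0], [1.0, 1.0], [1.0, 0.0], [0.0, 1.0], lam))
    print(cutmix_box(224, 224, lam, rng))
```

Recomputing lam from the clipped box keeps the target weights honest when the box runs off the image edge.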
- AutoAugment (https://arxiv.org/abs/1805.09501) and RandAugment (https://arxiv.org/abs/1909.13719) ImageNet configurations modeled after impl for EfficientNet training (https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py)
- AugMix w/ JSD loss (https://arxiv.org/abs/1912.02781), JSD w/ clean + augmented mixing support works with AutoAugment and RandAugment as well
- SplitBatchNorm - allows splitting batch norm layers between clean and augmented (auxiliary batch norm) data
- DropPath aka "Stochastic Depth" (https://arxiv.org/abs/1603.09382)
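DropPath (stochastic depth) randomly drops a whole residual branch per sample during training and rescales the kept branches by 1 / keep_prob, so the expected output matches inference, where the branch is always kept. A plain-Python sketch with hypothetical helper names; a seeded `random.Random` is passed in so runs are reproducible.

```python
from __future__ import annotations
import random


def drop_path(branch_out: list[list[float]], drop_prob: float, training: bool,
              rng: random.Random) -> list[list[float]]:
    """Stochastic depth on a batch of residual-branch outputs (one vector per sample).

    In training, each sample's branch is zeroed with probability drop_prob and
    kept samples are scaled by 1 / keep_prob; in eval the input is returned as-is.
    """
    if not training or drop_prob == 0.0:
        return branch_out
    keep_prob = 1.0 - drop_prob
    out = []
    for sample in branch_out:
        keep = rng.random() < keep_prob  # one draw per sample, not per element
        out.append([v / keep_prob if keep else 0.0 for v in sample])
    return out


def residual_block(x, branch_fn, drop_prob, training, rng):
    """x + drop_path(branch_fn(x)), applied sample by sample."""
    branch = drop_path([branch_fn(s) for s in x], drop_prob, training, rng)
    return [[a + b for a, b in zip(s, r)] for s, r in zip(x, branch)]
```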
- DropBlock (https://arxiv.org/abs/1810.12890)
- Blur Pooling (https://arxiv.org/abs/1904.11486)
- Space-to-Depth by mrT23 (https://arxiv.org/abs/1801.04590) -- original paper?
- Adaptive Gradient Clipping (https://arxiv.org/abs/2102.06171, https://github.com/deepmind/deepmind-research/tree/master/nfnets)
- An extensive selection of channel and/or spatial attention modules:
- Bottleneck Transformer - https://arxiv.org/abs/2101.11605
- CBAM - https://arxiv.org/abs/1807.06521
- Effective Squeeze-Excitation (ESE) - https://arxiv.org/abs/1911.06667
- Efficient Channel Attention (ECA) - https://arxiv.org/abs/1910.03151
- Gather-Excite (GE) - https://arxiv.org/abs/1810.12348
- Global Context (GC) - https://arxiv.org/abs/1904.11492
- Halo - https://arxiv.org/abs/2103.12731
- Involution - https://arxiv.org/abs/2103.06255
- Lambda Layer - https://arxiv.org/abs/2102.08602
- Non-Local (NL) - https://arxiv.org/abs/1711.07971
- Squeeze-and-Excitation (SE) - https://arxiv.org/abs/1709.01507
- Selective Kernel (SK) - https://arxiv.org/abs/1903.06586
- Split (SPLAT) - https://arxiv.org/abs/2004.08955
- Shifted Window (SWIN) - https://arxiv.org/abs/2103.14030
Model validation results can be found in the documentation and in the results tables.
My current documentation for `timm` covers the basics.
Getting Started with PyTorch Image Models (timm): A Practitioner’s Guide by Chris Hughes is an extensive blog post covering many aspects of `timm` in detail.
timmdocs is quickly becoming a much more comprehensive set of documentation for `timm`. A big thanks to Aman Arora for his efforts creating timmdocs.
paperswithcode is a good resource for browsing the models within `timm`.
The root folder of the repository contains reference train, validation, and inference scripts that work with the included models and other features of this repository. They are adaptable for other datasets and use cases with a little hacking. See documentation for some basics and training hparams for some train examples that produce SOTA ImageNet results.
One of the greatest assets of PyTorch is the community and their contributions. A few of my favourite resources that pair well with the models and components here are listed below.
- Detectron2 - https://github.com/facebookresearch/detectron2
- Segmentation Models (Semantic) - https://github.com/qubvel/segmentation_models.pytorch
- EfficientDet (Obj Det, Semantic soon) - https://github.com/rwightman/efficientdet-pytorch
- Albumentations - https://github.com/albumentations-team/albumentations
- Kornia - https://github.com/kornia/kornia
- RepDistiller - https://github.com/HobbitLong/RepDistiller
- torchdistill - https://github.com/yoshitomo-matsubara/torchdistill
- PyTorch Metric Learning - https://github.com/KevinMusgrave/pytorch-metric-learning
- fastai - https://github.com/fastai/fastai
The code here is licensed Apache 2.0. I've taken care to make sure any third party code included or adapted has compatible (permissive) licenses such as MIT, BSD, etc. I've made an effort to avoid any GPL / LGPL conflicts. That said, it is your responsibility to ensure you comply with licenses here and conditions of any dependent licenses. Where applicable, I've linked the sources/references for various components in docstrings. If you think I've missed anything please create an issue.
So far all of the pretrained weights available here are pretrained on ImageNet with a select few that have some additional pretraining (see extra note below). ImageNet was released for non-commercial research purposes only (https://image-net.org/download). It's not clear what the implications of that are for the use of pretrained weights from that dataset. Any models I have trained with ImageNet are done for research purposes and one should assume that the original dataset license applies to the weights. It's best to seek legal advice if you intend to use the pretrained weights in a commercial product.
Several weights included or referenced here were pretrained with proprietary datasets that I do not have access to. These include the Facebook WSL, SSL, SWSL ResNe(Xt) and the Google Noisy Student EfficientNet models. The Facebook models have an explicit non-commercial license (CC-BY-NC 4.0, https://github.com/facebookresearch/semi-supervised-ImageNet1K-models, https://github.com/facebookresearch/WSL-Images). The Google models do not appear to have any restriction beyond the Apache 2.0 license (and ImageNet concerns). In either case, you should contact Facebook or Google with any questions.
@misc{rw2019timm,
author = {Ross Wightman},
title = {PyTorch Image Models},
year = {2019},
publisher = {GitHub},
journal = {GitHub repository},
doi = {10.5281/zenodo.4414861},
howpublished = {\url{https://github.com/rwightman/pytorch-image-models}}
}