Merge remote-tracking branch 'origin/main' into master-patched
xwang233 committed Jun 18, 2024
2 parents a92242b + 427b3e4 commit 933a611
Showing 20 changed files with 81 additions and 37 deletions.
4 changes: 4 additions & 0 deletions README.md
@@ -31,6 +31,8 @@

| model |top1 |top1_err|top5 |top5_err|param_count|img_size|
|--------------------------------------------------------------------------------------------------|------|--------|------|--------|-----------|--------|
| [mobilenetv4_hybrid_large.e600_r384_in1k](http://hf.co/timm/mobilenetv4_hybrid_large.e600_r384_in1k) |84.266|15.734 |96.936 |3.064 |37.76 |448 |
| [mobilenetv4_hybrid_large.e600_r384_in1k](http://hf.co/timm/mobilenetv4_hybrid_large.e600_r384_in1k) |83.800|16.200 |96.770 |3.230 |37.76 |384 |
| [mobilenetv4_conv_large.e600_r384_in1k](http://hf.co/timm/mobilenetv4_conv_large.e600_r384_in1k) |83.392|16.608 |96.622 |3.378 |32.59 |448 |
| [mobilenetv4_conv_large.e600_r384_in1k](http://hf.co/timm/mobilenetv4_conv_large.e600_r384_in1k) |82.952|17.048 |96.266 |3.734 |32.59 |384 |
| [mobilenetv4_conv_large.e500_r256_in1k](http://hf.co/timm/mobilenetv4_conv_large.e500_r256_in1k) |82.674|17.326 |96.31 |3.69 |32.59 |320 |
@@ -43,7 +45,9 @@
| [mobilenetv4_conv_medium.e500_r224_in1k](http://hf.co/timm/mobilenetv4_conv_medium.e500_r224_in1k) |79.808|20.192 |95.186|4.814 |9.72 |256 |
| [mobilenetv4_conv_blur_medium.e500_r224_in1k](http://hf.co/timm/mobilenetv4_conv_blur_medium.e500_r224_in1k) |79.438|20.562 |94.932|5.068 |9.72 |224 |
| [mobilenetv4_conv_medium.e500_r224_in1k](http://hf.co/timm/mobilenetv4_conv_medium.e500_r224_in1k) |79.094|20.906 |94.77 |5.23 |9.72 |224 |
| [mobilenetv4_conv_small.e2400_r224_in1k](http://hf.co/timm/mobilenetv4_conv_small.e2400_r224_in1k) |74.616|25.384 |92.072|7.928 |3.77 |256 |
| [mobilenetv4_conv_small.e1200_r224_in1k](http://hf.co/timm/mobilenetv4_conv_small.e1200_r224_in1k) |74.292|25.708 |92.116|7.884 |3.77 |256 |
| [mobilenetv4_conv_small.e2400_r224_in1k](http://hf.co/timm/mobilenetv4_conv_small.e2400_r224_in1k) |73.756|26.244 |91.422|8.578 |3.77 |224 |
| [mobilenetv4_conv_small.e1200_r224_in1k](http://hf.co/timm/mobilenetv4_conv_small.e1200_r224_in1k) |73.454|26.546 |91.34 |8.66 |3.77 |224 |

* Apple MobileCLIP (https://arxiv.org/pdf/2311.17049, FastViT and ViT-B) image tower model support & weights added (part of OpenCLIP support).
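For context, the new checkpoints in the table above are regular `timm` models hosted on the Hugging Face Hub. A minimal sketch of loading one of them (model name taken from the table; downloading weights assumes network access and `huggingface_hub`):

```python
import torch
import timm
from timm.data import resolve_data_config, create_transform

# Load one of the MobileNetV4 ImageNet-1k checkpoints listed in the table above.
model = timm.create_model('mobilenetv4_conv_small.e2400_r224_in1k', pretrained=True)
model.eval()

# Resolve eval preprocessing (input size, crop pct, interpolation) from the
# model's pretrained config instead of hard-coding 224/256.
cfg = resolve_data_config({}, model=model)
transform = create_transform(**cfg)

with torch.inference_mode():
    x = torch.randn(1, *cfg['input_size'])  # stand-in for a transformed image batch
    probs = model(x).softmax(dim=-1)
    top5 = probs.topk(5)
```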
3 changes: 2 additions & 1 deletion requirements.txt
@@ -2,4 +2,5 @@
# torchvision
pyyaml
huggingface_hub
safetensors==0.3.2 # 0.3.3 doesn't have aarch64 pre-built binary from pypi, causing ARM build to fail, https://github.com/huggingface/safetensors/issues/346
safetensors>=0.2
numpy<2.0
2 changes: 1 addition & 1 deletion timm/models/efficientnet.py
@@ -156,7 +156,7 @@ def set_grad_checkpointing(self, enable=True):
def get_classifier(self) -> nn.Module:
return self.classifier

def reset_classifier(self, num_classes, global_pool='avg'):
def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
self.num_classes = num_classes
self.global_pool, self.classifier = create_classifier(
self.num_features, self.num_classes, pool_type=global_pool)
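This `reset_classifier()` change, repeated across most of the files below, only tightens the signature with type annotations (and, for the head-module models, makes `global_pool` optional); the public call pattern is unchanged. A rough usage sketch, assuming a backbone is being re-headed for a 10-class task:

```python
import timm

# Any timm classification model exposes get_classifier() / reset_classifier().
model = timm.create_model('efficientnet_b0', pretrained=False, num_classes=1000)

# Swap the 1000-class head for a fresh 10-class head, keeping average pooling.
model.reset_classifier(num_classes=10, global_pool='avg')

# num_classes=0 drops the classifier entirely, leaving pooled features as output.
backbone = timm.create_model('efficientnet_b0', pretrained=False, num_classes=0)
```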
2 changes: 1 addition & 1 deletion timm/models/ghostnet.py
@@ -273,7 +273,7 @@ def set_grad_checkpointing(self, enable=True):
def get_classifier(self) -> nn.Module:
return self.classifier

def reset_classifier(self, num_classes, global_pool='avg'):
def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
self.num_classes = num_classes
# cannot meaningfully change pooling of efficient head after creation
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
2 changes: 1 addition & 1 deletion timm/models/hrnet.py
@@ -739,7 +739,7 @@ def set_grad_checkpointing(self, enable=True):
def get_classifier(self) -> nn.Module:
return self.classifier

def reset_classifier(self, num_classes, global_pool='avg'):
def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
self.num_classes = num_classes
self.global_pool, self.classifier = create_classifier(
self.num_features, self.num_classes, pool_type=global_pool)
2 changes: 1 addition & 1 deletion timm/models/inception_v4.py
@@ -280,7 +280,7 @@ def set_grad_checkpointing(self, enable=True):
def get_classifier(self) -> nn.Module:
return self.last_linear

def reset_classifier(self, num_classes, global_pool='avg'):
def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
self.num_classes = num_classes
self.global_pool, self.last_linear = create_classifier(
self.num_features, self.num_classes, pool_type=global_pool)
6 changes: 3 additions & 3 deletions timm/models/metaformer.py
@@ -26,9 +26,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.


from collections import OrderedDict
from functools import partial
from typing import Optional

import torch
import torch.nn as nn
@@ -548,7 +548,7 @@ def __init__(
# if using MlpHead, dropout is handled by MlpHead
if num_classes > 0:
if self.use_mlp_head:
# FIXME hidden size
# FIXME not actually returning mlp hidden state right now as pre-logits.
final = MlpHead(self.num_features, num_classes, drop_rate=self.drop_rate)
self.head_hidden_size = self.num_features
else:
@@ -583,7 +583,7 @@ def set_grad_checkpointing(self, enable=True):
def get_classifier(self) -> nn.Module:
return self.head.fc

def reset_classifier(self, num_classes=0, global_pool=None):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
if global_pool is not None:
self.head.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.head.flatten = nn.Flatten(1) if global_pool else nn.Identity()
7 changes: 4 additions & 3 deletions timm/models/mobilenetv3.py
@@ -1024,12 +1024,13 @@ def _cfg(url: str = '', **kwargs):
'mobilenetv4_hybrid_medium.e500_r224_in1k': _cfg(
hf_hub_id='timm/',
crop_pct=0.95, test_input_size=(3, 256, 256), test_crop_pct=1.0, interpolation='bicubic'),
'mobilenetv4_hybrid_large.e600_r384_in1k': _cfg(
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12),
crop_pct=0.95, test_input_size=(3, 448, 448), test_crop_pct=1.0, interpolation='bicubic'),
'mobilenetv4_hybrid_large.r256': _cfg(
# hf_hub_id='timm/',
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.95, interpolation='bicubic'),
'mobilenetv4_hybrid_large.r384': _cfg(
# hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=0.95, interpolation='bicubic'),

# experimental
'mobilenetv4_conv_aa_medium.untrained': _cfg(
2 changes: 1 addition & 1 deletion timm/models/nasnet.py
@@ -518,7 +518,7 @@ def set_grad_checkpointing(self, enable=True):
def get_classifier(self) -> nn.Module:
return self.last_linear

def reset_classifier(self, num_classes, global_pool='avg'):
def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
self.num_classes = num_classes
self.global_pool, self.last_linear = create_classifier(
self.num_features, self.num_classes, pool_type=global_pool)
2 changes: 1 addition & 1 deletion timm/models/pnasnet.py
@@ -307,7 +307,7 @@ def set_grad_checkpointing(self, enable=True):
def get_classifier(self) -> nn.Module:
return self.last_linear

def reset_classifier(self, num_classes, global_pool='avg'):
def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
self.num_classes = num_classes
self.global_pool, self.last_linear = create_classifier(
self.num_features, self.num_classes, pool_type=global_pool)
2 changes: 1 addition & 1 deletion timm/models/regnet.py
@@ -514,7 +514,7 @@ def set_grad_checkpointing(self, enable=True):
def get_classifier(self) -> nn.Module:
return self.head.fc

def reset_classifier(self, num_classes, global_pool='avg'):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.head.reset(num_classes, pool_type=global_pool)

def forward_intermediates(
3 changes: 2 additions & 1 deletion timm/models/rexnet.py
@@ -12,6 +12,7 @@

from functools import partial
from math import ceil
from typing import Optional

import torch
import torch.nn as nn
@@ -229,7 +230,7 @@ def set_grad_checkpointing(self, enable=True):
def get_classifier(self) -> nn.Module:
return self.head.fc

def reset_classifier(self, num_classes, global_pool='avg'):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.num_classes = num_classes
self.head.reset(num_classes, global_pool)

2 changes: 1 addition & 1 deletion timm/models/selecsls.py
@@ -161,7 +161,7 @@ def set_grad_checkpointing(self, enable=True):
def get_classifier(self) -> nn.Module:
return self.fc

def reset_classifier(self, num_classes, global_pool='avg'):
def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
self.num_classes = num_classes
self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)

2 changes: 1 addition & 1 deletion timm/models/senet.py
@@ -337,7 +337,7 @@ def set_grad_checkpointing(self, enable=True):
def get_classifier(self) -> nn.Module:
return self.last_linear

def reset_classifier(self, num_classes, global_pool='avg'):
def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
self.num_classes = num_classes
self.global_pool, self.last_linear = create_classifier(
self.num_features, self.num_classes, pool_type=global_pool)
48 changes: 38 additions & 10 deletions timm/models/vision_transformer.py
@@ -386,6 +386,31 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return self._forward(x)


def global_pool_nlc(
x: torch.Tensor,
pool_type: str = 'token',
num_prefix_tokens: int = 1,
reduce_include_prefix: bool = False,
):
if not pool_type:
return x

if pool_type == 'token':
x = x[:, 0] # class token
else:
x = x if reduce_include_prefix else x[:, num_prefix_tokens:]
if pool_type == 'avg':
x = x.mean(dim=1)
elif pool_type == 'avgmax':
x = 0.5 * (x.amax(dim=1) + x.mean(dim=1))
elif pool_type == 'max':
x = x.amax(dim=1)
else:
assert not pool_type, f'Unknown pool type {pool_type}'

return x


class VisionTransformer(nn.Module):
""" Vision Transformer
@@ -400,7 +425,7 @@ def __init__(
patch_size: Union[int, Tuple[int, int]] = 16,
in_chans: int = 3,
num_classes: int = 1000,
global_pool: Literal['', 'avg', 'token', 'map'] = 'token',
global_pool: Literal['', 'avg', 'avgmax', 'max', 'token', 'map'] = 'token',
embed_dim: int = 768,
depth: int = 12,
num_heads: int = 12,
@@ -459,10 +484,10 @@ def __init__(
block_fn: Transformer block layer.
"""
super().__init__()
assert global_pool in ('', 'avg', 'token', 'map')
assert global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map')
assert class_token or global_pool != 'token'
assert pos_embed in ('', 'none', 'learn')
use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm
use_fc_norm = global_pool in ('avg', 'avgmax', 'max') if fc_norm is None else fc_norm
norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6)
act_layer = get_act_layer(act_layer) or nn.GELU

@@ -596,10 +621,10 @@ def set_grad_checkpointing(self, enable: bool = True) -> None:
def get_classifier(self) -> nn.Module:
return self.head

def reset_classifier(self, num_classes: int, global_pool = None) -> None:
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.num_classes = num_classes
if global_pool is not None:
assert global_pool in ('', 'avg', 'token', 'map')
assert global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map')
if global_pool == 'map' and self.attn_pool is None:
assert False, "Cannot currently add attention pooling in reset_classifier()."
elif global_pool != 'map' and self.attn_pool is not None:
@@ -756,13 +781,16 @@ def forward_features(self, x: torch.Tensor) -> torch.Tensor:
x = self.norm(x)
return x

def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
def pool(self, x: torch.Tensor, pool_type: Optional[str] = None) -> torch.Tensor:
if self.attn_pool is not None:
x = self.attn_pool(x)
elif self.global_pool == 'avg':
x = x[:, self.num_prefix_tokens:].mean(dim=1)
elif self.global_pool:
x = x[:, 0] # class token
return x
pool_type = self.global_pool if pool_type is None else pool_type
x = global_pool_nlc(x, pool_type=pool_type, num_prefix_tokens=self.num_prefix_tokens)
return x

def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
x = self.pool(x)
x = self.fc_norm(x)
x = self.head_drop(x)
return x if pre_logits else self.head(x)
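The `global_pool_nlc()` helper added above is a module-level function, so the new `'avgmax'` and `'max'` pooling modes can be exercised directly or through `reset_classifier()`. A sketch assuming a timm version that includes this commit (`vit_small_patch16_224` is just an arbitrary example model):

```python
import torch
import timm
from timm.models.vision_transformer import global_pool_nlc

# Token sequence in (batch, length, channels) layout with one class-token prefix.
tokens = torch.randn(2, 197, 384)
avg = global_pool_nlc(tokens, pool_type='avg', num_prefix_tokens=1)     # mean over patch tokens -> (2, 384)
avgmax = global_pool_nlc(tokens, pool_type='avgmax', num_prefix_tokens=1)  # 0.5 * (max + mean)
cls = global_pool_nlc(tokens, pool_type='token')                        # class token only

# The widened pooling options are also accepted when resetting a ViT head.
model = timm.create_model('vit_small_patch16_224', pretrained=False)
model.reset_classifier(num_classes=10, global_pool='avgmax')
```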
2 changes: 1 addition & 1 deletion timm/models/vision_transformer_relpos.py
@@ -381,7 +381,7 @@ def set_grad_checkpointing(self, enable=True):
def get_classifier(self) -> nn.Module:
return self.head

def reset_classifier(self, num_classes: int, global_pool=None):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.num_classes = num_classes
if global_pool is not None:
assert global_pool in ('', 'avg', 'token')
2 changes: 1 addition & 1 deletion timm/models/vision_transformer_sam.py
@@ -536,7 +536,7 @@ def set_grad_checkpointing(self, enable=True):
def get_classifier(self) -> nn.Module:
return self.head

def reset_classifier(self, num_classes=0, global_pool=None):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.head.reset(num_classes, global_pool)

def forward_intermediates(
21 changes: 15 additions & 6 deletions timm/models/vovnet.py
@@ -11,7 +11,7 @@
Hacked together by / Copyright 2020 Ross Wightman
"""

from typing import List
from typing import List, Optional

import torch
import torch.nn as nn
@@ -134,9 +134,17 @@ def __init__(
else:
drop_path = None
blocks += [OsaBlock(
in_chs, mid_chs, out_chs, layer_per_block, residual=residual and i > 0, depthwise=depthwise,
attn=attn if last_block else '', norm_layer=norm_layer, act_layer=act_layer, drop_path=drop_path)
]
in_chs,
mid_chs,
out_chs,
layer_per_block,
residual=residual and i > 0,
depthwise=depthwise,
attn=attn if last_block else '',
norm_layer=norm_layer,
act_layer=act_layer,
drop_path=drop_path
)]
in_chs = out_chs
self.blocks = nn.Sequential(*blocks)

@@ -252,8 +260,9 @@ def set_grad_checkpointing(self, enable=True):
def get_classifier(self) -> nn.Module:
return self.head.fc

def reset_classifier(self, num_classes, global_pool='avg'):
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
def reset_classifier(self, num_classes, global_pool: Optional[str] = None):
self.num_classes = num_classes
self.head.reset(num_classes, global_pool)

def forward_features(self, x):
x = self.stem(x)
2 changes: 1 addition & 1 deletion timm/models/xception.py
@@ -174,7 +174,7 @@ def set_grad_checkpointing(self, enable=True):
def get_classifier(self) -> nn.Module:
return self.fc

def reset_classifier(self, num_classes, global_pool='avg'):
def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
self.num_classes = num_classes
self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)

2 changes: 1 addition & 1 deletion timm/models/xception_aligned.py
@@ -274,7 +274,7 @@ def set_grad_checkpointing(self, enable=True):
def get_classifier(self) -> nn.Module:
return self.head.fc

def reset_classifier(self, num_classes, global_pool='avg'):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.head.reset(num_classes, pool_type=global_pool)

def forward_features(self, x):
