diff --git a/tests/test_layers.py b/tests/test_layers.py
index 92f6b683d3..2cc8420abf 100644
--- a/tests/test_layers.py
+++ b/tests/test_layers.py
@@ -1,7 +1,7 @@
 import torch
 import torch.nn as nn
 
-from timm.layers import create_act_layer, set_layer_config
+from timm.layers import create_act_layer, set_layer_config, get_act_layer, get_act_fn
 
 import importlib
 import os
@@ -76,3 +76,46 @@ def test_hard_swish_grad():
 def test_hard_mish_grad():
     for _ in range(100):
         _run_act_layer_grad('hard_mish')
+
+def test_get_act_layer_empty_string():
+    # Empty string should return None
+    assert get_act_layer('') is None
+
+
+def test_create_act_layer_inplace_error():
+    class NoInplaceAct(nn.Module):
+        def __init__(self):
+            super().__init__()
+        def forward(self, x):
+            return x
+
+    # Should recover when the inplace arg causes a TypeError
+    layer = create_act_layer(NoInplaceAct, inplace=True)
+    assert isinstance(layer, NoInplaceAct)
+
+
+def test_create_act_layer_edge_cases():
+    # Test None input
+    assert create_act_layer(None) is None
+
+    # An act layer taking **kwargs should absorb the inplace arg without error
+    class CustomAct(nn.Module):
+        def __init__(self, **kwargs):
+            super().__init__()
+        def forward(self, x):
+            return x
+
+    result = create_act_layer(CustomAct, inplace=True)
+    assert isinstance(result, CustomAct)
+
+
+def test_get_act_fn_callable():
+    def custom_act(x):
+        return x
+    assert get_act_fn(custom_act) is custom_act
+
+
+def test_get_act_fn_none():
+    assert get_act_fn(None) is None
+    assert get_act_fn('') is None
+
diff --git a/tests/test_optim.py b/tests/test_optim.py
index 38f625fb42..66aaadbf95 100644
--- a/tests/test_optim.py
+++ b/tests/test_optim.py
@@ -11,6 +11,8 @@ import torch
 from torch.testing._internal.common_utils import TestCase
 from torch.nn import Parameter
+
+from timm.optim.optim_factory import param_groups_layer_decay, param_groups_weight_decay
 from timm.scheduler import PlateauLRScheduler
 
 from timm.optim import create_optimizer_v2
 
@@ -741,3 +743,82 @@ def test_lookahead_radam(optimizer):
         lambda params: create_optimizer_v2(params, optimizer, lr=1e-4)
     )
 
+
+def test_param_groups_layer_decay_with_end_decay():
+    model = torch.nn.Sequential(
+        torch.nn.Linear(10, 5),
+        torch.nn.ReLU(),
+        torch.nn.Linear(5, 2)
+    )
+
+    param_groups = param_groups_layer_decay(
+        model,
+        weight_decay=0.05,
+        layer_decay=0.75,
+        end_layer_decay=0.5,
+        verbose=True
+    )
+
+    assert len(param_groups) > 0
+    # Verify layer scaling is applied and bounded by the end decay
+    for group in param_groups:
+        assert 'lr_scale' in group
+        assert group['lr_scale'] <= 1.0
+        assert group['lr_scale'] >= 0.5
+
+
+def test_param_groups_layer_decay_with_matcher():
+    class ModelWithMatcher(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.layer1 = torch.nn.Linear(10, 5)
+            self.layer2 = torch.nn.Linear(5, 2)
+
+        def group_matcher(self, coarse=False):
+            return lambda name: int(name.split('.')[0][-1])
+
+    model = ModelWithMatcher()
+    param_groups = param_groups_layer_decay(
+        model,
+        weight_decay=0.05,
+        layer_decay=0.75,
+        verbose=True
+    )
+
+    assert len(param_groups) > 0
+    # Verify layer scaling is applied
+    for group in param_groups:
+        assert 'lr_scale' in group
+        assert 'weight_decay' in group
+        assert len(group['params']) > 0
+
+
+def test_param_groups_weight_decay():
+    model = torch.nn.Sequential(
+        torch.nn.Linear(10, 5),
+        torch.nn.ReLU(),
+        torch.nn.Linear(5, 2)
+    )
+    weight_decay = 0.01
+    no_weight_decay_list = ['2.weight']  # index 1 is the ReLU (no params); '2.weight' is the second Linear
+
+    param_groups = param_groups_weight_decay(
+        model,
+        weight_decay=weight_decay,
+        no_weight_decay_list=no_weight_decay_list
+    )
+
+    assert len(param_groups) == 2
+    assert param_groups[0]['weight_decay'] == 0.0
+    assert param_groups[1]['weight_decay'] == weight_decay
+
+    # Verify parameters are correctly grouped
+    no_decay_params = set(param_groups[0]['params'])
+    decay_params = set(param_groups[1]['params'])
+
+    for name, param in model.named_parameters():
+        if param.ndim <= 1 or name.endswith(".bias") or name in no_weight_decay_list:
+            assert param in no_decay_params
+        else:
+            assert param in decay_params
+
diff --git a/tests/test_utils.py b/tests/test_utils.py
index b0f890d2fe..1e2126eead 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -2,8 +2,15 @@
 from torchvision.ops.misc import FrozenBatchNorm2d
 
 import timm
+import pytest
 from timm.utils.model import freeze, unfreeze
+from timm.utils.model import ActivationStatsHook
+from timm.utils.model import extract_spp_stats
+from timm.utils.model import _freeze_unfreeze
+from timm.utils.model import avg_sq_ch_mean, avg_ch_var, avg_ch_var_residual
+from timm.utils.model import reparameterize_model
+from timm.utils.model import get_state_dict
 
 
 def test_freeze_unfreeze():
     model = timm.create_model('resnet18')
@@ -54,4 +61,132 @@ def test_freeze_unfreeze():
     freeze(model.layer1[0], ['bn1'])
     assert isinstance(model.layer1[0].bn1, FrozenBatchNorm2d)
     unfreeze(model.layer1[0], ['bn1'])
-    assert isinstance(model.layer1[0].bn1, BatchNorm2d)
\ No newline at end of file
+    assert isinstance(model.layer1[0].bn1, BatchNorm2d)
+
+def test_activation_stats_hook_validation():
+    model = timm.create_model('resnet18')
+
+    def test_hook(model, input, output):
+        return output.mean().item()
+
+    # Mismatched lengths of hook_fn_locs and hook_fns should raise
+    with pytest.raises(ValueError, match="Please provide `hook_fns` for each `hook_fn_locs`"):
+        ActivationStatsHook(
+            model,
+            hook_fn_locs=['layer1.0.conv1', 'layer1.0.conv2'],
+            hook_fns=[test_hook]
+        )
+
+
+def test_extract_spp_stats():
+    model = timm.create_model('resnet18')
+
+    def test_hook(model, input, output):
+        return output.mean().item()
+
+    stats = extract_spp_stats(
+        model,
+        hook_fn_locs=['layer1.0.conv1'],
+        hook_fns=[test_hook],
+        input_shape=[2, 3, 32, 32]
+    )
+
+    assert isinstance(stats, dict)
+    assert test_hook.__name__ in stats
+    assert isinstance(stats[test_hook.__name__], list)
+    assert len(stats[test_hook.__name__]) > 0
+
+def test_freeze_unfreeze_bn_root():
+    import torch.nn as nn
+    from timm.layers import BatchNormAct2d
+
+    # Create batch norm layers
+    bn = nn.BatchNorm2d(10)
+    bn_act = BatchNormAct2d(10)
+
+    # Freezing with BatchNorm2d as root is not supported
+    with pytest.raises(AssertionError):
+        _freeze_unfreeze(bn, mode="freeze")
+
+    # Freezing with BatchNormAct2d as root is not supported either
+    with pytest.raises(AssertionError):
+        _freeze_unfreeze(bn_act, mode="freeze")
+
+
+def test_activation_stats_functions():
+    import torch
+
+    # Create sample input tensor [batch, channels, height, width]
+    x = torch.randn(2, 3, 4, 4)
+
+    # Test avg_sq_ch_mean
+    result1 = avg_sq_ch_mean(None, None, x)
+    assert isinstance(result1, float)
+
+    # Test avg_ch_var
+    result2 = avg_ch_var(None, None, x)
+    assert isinstance(result2, float)
+
+    # Test avg_ch_var_residual
+    result3 = avg_ch_var_residual(None, None, x)
+    assert isinstance(result3, float)
+
+
+def test_reparameterize_model():
+    import torch.nn as nn
+
+    class FusableModule(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv = nn.Conv2d(3, 3, 1)
+
+        def fuse(self):
+            return nn.Identity()
+
+    class ModelWithFusable(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.fusable = FusableModule()
+            self.normal = nn.Linear(10, 10)
+
+    model = ModelWithFusable()
+
+    # Test with inplace=False (should create a copy)
+    new_model = reparameterize_model(model, inplace=False)
+    assert isinstance(new_model.fusable, nn.Identity)
+    assert isinstance(model.fusable, FusableModule)  # Original unchanged
+
+    # Test with inplace=True
+    reparameterize_model(model, inplace=True)
+    assert isinstance(model.fusable, nn.Identity)
+
+
+def test_get_state_dict_custom_unwrap():
+    import torch.nn as nn
+
+    class CustomModel(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.linear = nn.Linear(10, 10)
+
+    model = CustomModel()
+
+    def custom_unwrap(m):
+        return m
+
+    state_dict = get_state_dict(model, unwrap_fn=custom_unwrap)
+    assert 'linear.weight' in state_dict
+    assert 'linear.bias' in state_dict
+
+
+def test_freeze_unfreeze_string_input():
+    model = timm.create_model('resnet18')
+
+    # Freeze accepts a single submodule name as a string
+    _freeze_unfreeze(model, 'layer1', mode='freeze')
+    assert not model.layer1[0].conv1.weight.requires_grad
+
+    # Unfreeze accepts a string as well
+    _freeze_unfreeze(model, 'layer1', mode='unfreeze')
+    assert model.layer1[0].conv1.weight.requires_grad
+
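Note for reviewers: test_param_groups_weight_decay encodes the usual timm convention that 1-d parameters, names ending in ".bias", and any explicitly listed names are excluded from weight decay. As a quick reference, here is a minimal standalone sketch of that grouping rule. The helper name split_decay_groups is hypothetical; the behavior actually under test is param_groups_weight_decay in timm.optim.optim_factory, which may differ in detail.

    import torch.nn as nn

    def split_decay_groups(model, weight_decay=0.01, skip=()):
        # 1-d params (biases, norm affine weights) and names in `skip`
        # go into a zero-decay group; all other params get `weight_decay`.
        decay, no_decay = [], []
        for name, param in model.named_parameters():
            if not param.requires_grad:
                continue
            if param.ndim <= 1 or name.endswith(".bias") or name in skip:
                no_decay.append(param)
            else:
                decay.append(param)
        return [
            {'params': no_decay, 'weight_decay': 0.0},
            {'params': decay, 'weight_decay': weight_decay},
        ]

    # Mirrors the test setup: once '2.weight' is skipped explicitly,
    # only '0.weight' should carry weight decay.
    model = nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 2))
    groups = split_decay_groups(model, weight_decay=0.01, skip={'2.weight'})
    assert len(groups[1]['params']) == 1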