
Commit

Merge remote-tracking branch 'origin/main' into master-patched
xwang233 committed Nov 12, 2024
2 parents 93d03e7 + e31e5d2 commit 817b074
Showing 3 changed files with 261 additions and 2 deletions.
45 changes: 44 additions & 1 deletion tests/test_layers.py
@@ -1,7 +1,7 @@
import torch
import torch.nn as nn

-from timm.layers import create_act_layer, set_layer_config
+from timm.layers import create_act_layer, set_layer_config, get_act_layer, get_act_fn

import importlib
import os
@@ -76,3 +76,46 @@ def test_hard_swish_grad():
def test_hard_mish_grad():
    for _ in range(100):
        _run_act_layer_grad('hard_mish')

def test_get_act_layer_empty_string():
    # Empty string should return None
    assert get_act_layer('') is None


def test_create_act_layer_inplace_error():
    class NoInplaceAct(nn.Module):
        def __init__(self):
            super().__init__()
        def forward(self, x):
            return x

    # Should recover when inplace arg causes TypeError
    layer = create_act_layer(NoInplaceAct, inplace=True)
    assert isinstance(layer, NoInplaceAct)


def test_create_act_layer_edge_cases():
    # Test None input
    assert create_act_layer(None) is None

    # Test TypeError handling for inplace
    class CustomAct(nn.Module):
        def __init__(self, **kwargs):
            super().__init__()
        def forward(self, x):
            return x

    result = create_act_layer(CustomAct, inplace=True)
    assert isinstance(result, CustomAct)


def test_get_act_fn_callable():
    def custom_act(x):
        return x
    assert get_act_fn(custom_act) is custom_act


def test_get_act_fn_none():
    assert get_act_fn(None) is None
    assert get_act_fn('') is None

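For context, a minimal usage sketch of the activation factory helpers exercised by the new tests above. This is a hedged illustration only; the activation name 'relu' is an assumed example, and the recovery behavior (retrying without `inplace` when a constructor rejects it) is what test_create_act_layer_inplace_error checks, not something added here.

from timm.layers import create_act_layer, get_act_layer, get_act_fn

# Resolve an activation *class* by name; the empty string resolves to None (see test above).
act_cls = get_act_layer('relu')               # assumed example name
act = create_act_layer('relu', inplace=True)  # instantiates, dropping inplace if unsupported

# Resolve an activation *function*; callables pass through unchanged, None and '' give None.
act_fn = get_act_fn('relu')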
81 changes: 81 additions & 0 deletions tests/test_optim.py
@@ -11,6 +11,8 @@
import torch
from torch.testing._internal.common_utils import TestCase
from torch.nn import Parameter

from timm.optim.optim_factory import param_groups_layer_decay, param_groups_weight_decay
from timm.scheduler import PlateauLRScheduler

from timm.optim import create_optimizer_v2
@@ -741,3 +743,82 @@ def test_lookahead_radam(optimizer):
        lambda params: create_optimizer_v2(params, optimizer, lr=1e-4)
    )


def test_param_groups_layer_decay_with_end_decay():
    model = torch.nn.Sequential(
        torch.nn.Linear(10, 5),
        torch.nn.ReLU(),
        torch.nn.Linear(5, 2)
    )

    param_groups = param_groups_layer_decay(
        model,
        weight_decay=0.05,
        layer_decay=0.75,
        end_layer_decay=0.5,
        verbose=True
    )

    assert len(param_groups) > 0
    # Verify layer scaling is applied with end decay
    for group in param_groups:
        assert 'lr_scale' in group
        assert group['lr_scale'] <= 1.0
        assert group['lr_scale'] >= 0.5

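The assertions above only bound lr_scale between end_layer_decay (0.5) and 1.0. As a hedged illustration of the kind of geometric layer-wise scaling this implies (not a quote of timm's implementation), the scale for layer i of n can be computed as layer_decay ** (n - i), so earlier layers train with smaller effective learning rates:

# Illustration only: geometric layer-wise LR scaling consistent with the bounds asserted above.
def layer_lr_scales(num_layers, layer_decay=0.75):
    # Last layer keeps scale 1.0; earlier layers shrink geometrically.
    return [layer_decay ** (num_layers - i) for i in range(num_layers + 1)]

print(layer_lr_scales(3, layer_decay=0.75))  # [0.421875, 0.5625, 0.75, 1.0]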

def test_param_groups_layer_decay_with_matcher():
    class ModelWithMatcher(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.layer1 = torch.nn.Linear(10, 5)
            self.layer2 = torch.nn.Linear(5, 2)

        def group_matcher(self, coarse=False):
            return lambda name: int(name.split('.')[0][-1])

    model = ModelWithMatcher()
    param_groups = param_groups_layer_decay(
        model,
        weight_decay=0.05,
        layer_decay=0.75,
        verbose=True
    )

    assert len(param_groups) > 0
    # Verify layer scaling is applied
    for group in param_groups:
        assert 'lr_scale' in group
        assert 'weight_decay' in group
        assert len(group['params']) > 0

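The model above supplies its own group_matcher, which param_groups_layer_decay uses to assign each parameter name to a layer index before applying per-layer scales. A small sketch of what the matcher defined in this test returns, grounded in the lambda above with nothing new assumed:

matcher = lambda name: int(name.split('.')[0][-1])
for name in ['layer1.weight', 'layer1.bias', 'layer2.weight', 'layer2.bias']:
    print(name, '->', matcher(name))  # layer1.* -> 1, layer2.* -> 2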

def test_param_groups_weight_decay():
    model = torch.nn.Sequential(
        torch.nn.Linear(10, 5),
        torch.nn.ReLU(),
        torch.nn.Linear(5, 2)
    )
    weight_decay = 0.01
    no_weight_decay_list = ['1.weight']

    param_groups = param_groups_weight_decay(
        model,
        weight_decay=weight_decay,
        no_weight_decay_list=no_weight_decay_list
    )

    assert len(param_groups) == 2
    assert param_groups[0]['weight_decay'] == 0.0
    assert param_groups[1]['weight_decay'] == weight_decay

    # Verify parameters are correctly grouped
    no_decay_params = set(param_groups[0]['params'])
    decay_params = set(param_groups[1]['params'])

    for name, param in model.named_parameters():
        if param.ndim <= 1 or name.endswith(".bias") or name in no_weight_decay_list:
            assert param in no_decay_params
        else:
            assert param in decay_params

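The two groups built above are ordinary param-group dicts (one with weight_decay=0.0 for biases and other 1-d parameters, one carrying the requested decay), so they can be handed directly to a PyTorch optimizer. A minimal sketch; the AdamW choice and learning rate are assumed examples, not part of this commit:

import torch
from timm.optim.optim_factory import param_groups_weight_decay

model = torch.nn.Sequential(torch.nn.Linear(10, 5), torch.nn.ReLU(), torch.nn.Linear(5, 2))
param_groups = param_groups_weight_decay(model, weight_decay=0.01)
optimizer = torch.optim.AdamW(param_groups, lr=1e-3)  # decay is applied only to the second group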
137 changes: 136 additions & 1 deletion tests/test_utils.py
@@ -2,8 +2,15 @@
from torchvision.ops.misc import FrozenBatchNorm2d

import timm
import pytest
from timm.utils.model import freeze, unfreeze
from timm.utils.model import ActivationStatsHook
from timm.utils.model import extract_spp_stats

from timm.utils.model import _freeze_unfreeze
from timm.utils.model import avg_sq_ch_mean, avg_ch_var, avg_ch_var_residual
from timm.utils.model import reparameterize_model
from timm.utils.model import get_state_dict

def test_freeze_unfreeze():
    model = timm.create_model('resnet18')
@@ -54,4 +61,132 @@ def test_freeze_unfreeze():
    freeze(model.layer1[0], ['bn1'])
    assert isinstance(model.layer1[0].bn1, FrozenBatchNorm2d)
    unfreeze(model.layer1[0], ['bn1'])
-    assert isinstance(model.layer1[0].bn1, BatchNorm2d)
+    assert isinstance(model.layer1[0].bn1, BatchNorm2d)

def test_activation_stats_hook_validation():
    model = timm.create_model('resnet18')

    def test_hook(model, input, output):
        return output.mean().item()

    # Test error case with mismatched lengths
    with pytest.raises(ValueError, match="Please provide `hook_fns` for each `hook_fn_locs`"):
        ActivationStatsHook(
            model,
            hook_fn_locs=['layer1.0.conv1', 'layer1.0.conv2'],
            hook_fns=[test_hook]
        )


def test_extract_spp_stats():
    model = timm.create_model('resnet18')

    def test_hook(model, input, output):
        return output.mean().item()

    stats = extract_spp_stats(
        model,
        hook_fn_locs=['layer1.0.conv1'],
        hook_fns=[test_hook],
        input_shape=[2, 3, 32, 32]
    )

    assert isinstance(stats, dict)
    assert test_hook.__name__ in stats
    assert isinstance(stats[test_hook.__name__], list)
    assert len(stats[test_hook.__name__]) > 0
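For context, the stat functions imported at the top of this file (avg_sq_ch_mean, avg_ch_var, avg_ch_var_residual) are themselves valid hook_fns, so extract_spp_stats can collect signal propagation statistics without a custom hook. A hedged sketch; the hook locations below are assumed examples:

import timm
from timm.utils.model import extract_spp_stats, avg_sq_ch_mean, avg_ch_var

model = timm.create_model('resnet18')
stats = extract_spp_stats(
    model,
    hook_fn_locs=['layer1.0.conv1', 'layer2.0.conv1'],  # assumed example locations
    hook_fns=[avg_sq_ch_mean, avg_ch_var],
    input_shape=[2, 3, 32, 32],
)
# stats maps each hook function's __name__ to a list of recorded float values.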

def test_freeze_unfreeze_bn_root():
    import torch.nn as nn
    from timm.layers import BatchNormAct2d

    # Create batch norm layers
    bn = nn.BatchNorm2d(10)
    bn_act = BatchNormAct2d(10)

    # Test with BatchNorm2d as root
    with pytest.raises(AssertionError):
        _freeze_unfreeze(bn, mode="freeze")

    # Test with BatchNormAct2d as root
    with pytest.raises(AssertionError):
        _freeze_unfreeze(bn_act, mode="freeze")


def test_activation_stats_functions():
    import torch

    # Create sample input tensor [batch, channels, height, width]
    x = torch.randn(2, 3, 4, 4)

    # Test avg_sq_ch_mean
    result1 = avg_sq_ch_mean(None, None, x)
    assert isinstance(result1, float)

    # Test avg_ch_var
    result2 = avg_ch_var(None, None, x)
    assert isinstance(result2, float)

    # Test avg_ch_var_residual
    result3 = avg_ch_var_residual(None, None, x)
    assert isinstance(result3, float)


def test_reparameterize_model():
    import torch.nn as nn

    class FusableModule(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 3, 1)

        def fuse(self):
            return nn.Identity()

    class ModelWithFusable(nn.Module):
        def __init__(self):
            super().__init__()
            self.fusable = FusableModule()
            self.normal = nn.Linear(10, 10)

    model = ModelWithFusable()

    # Test with inplace=False (should create a copy)
    new_model = reparameterize_model(model, inplace=False)
    assert isinstance(new_model.fusable, nn.Identity)
    assert isinstance(model.fusable, FusableModule)  # Original unchanged

    # Test with inplace=True
    reparameterize_model(model, inplace=True)
    assert isinstance(model.fusable, nn.Identity)

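For context, a short deployment-style sketch of the behavior verified above: reparameterize_model swaps any submodule that exposes a fusing method for its fused form, and inplace=False works on a copy so the training-time model is untouched. The FusableConv module below is a hypothetical stand-in mirroring the test's FusableModule:

import torch.nn as nn
from timm.utils.model import reparameterize_model

class FusableConv(nn.Module):  # hypothetical module for illustration
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, 1)
    def fuse(self):
        return nn.Identity()

model = nn.Sequential(FusableConv(), nn.ReLU())
deploy_model = reparameterize_model(model, inplace=False)
assert isinstance(deploy_model[0], nn.Identity)  # fused copy for inference
assert isinstance(model[0], FusableConv)         # original left unchanged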

def test_get_state_dict_custom_unwrap():
    import torch.nn as nn

    class CustomModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(10, 10)

    model = CustomModel()

    def custom_unwrap(m):
        return m

    state_dict = get_state_dict(model, unwrap_fn=custom_unwrap)
    assert 'linear.weight' in state_dict
    assert 'linear.bias' in state_dict


def test_freeze_unfreeze_string_input():
    model = timm.create_model('resnet18')

    # Test with string input
    _freeze_unfreeze(model, 'layer1', mode='freeze')
    assert model.layer1[0].conv1.weight.requires_grad == False

    # Test unfreezing with string input
    _freeze_unfreeze(model, 'layer1', mode='unfreeze')
    assert model.layer1[0].conv1.weight.requires_grad == True
