mirror of https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
extend existing unittests
This commit is contained in:
parent 9f5c279bad
commit 68d5a64e45
@@ -1,7 +1,7 @@
 import torch
 import torch.nn as nn

-from timm.layers import create_act_layer, set_layer_config
+from timm.layers import create_act_layer, set_layer_config, get_act_layer, get_act_fn

 import importlib
 import os
@@ -76,3 +76,46 @@ def test_hard_swish_grad():
 def test_hard_mish_grad():
     for _ in range(100):
         _run_act_layer_grad('hard_mish')
+
+
+def test_get_act_layer_empty_string():
+    # Empty string should return None
+    assert get_act_layer('') is None
+
+
+def test_create_act_layer_inplace_error():
+    class NoInplaceAct(nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x):
+            return x
+
+    # Should recover when inplace arg causes TypeError
+    layer = create_act_layer(NoInplaceAct, inplace=True)
+    assert isinstance(layer, NoInplaceAct)
+
+
+def test_create_act_layer_edge_cases():
+    # Test None input
+    assert create_act_layer(None) is None
+
+    # Test TypeError handling for inplace
+    class CustomAct(nn.Module):
+        def __init__(self, **kwargs):
+            super().__init__()
+
+        def forward(self, x):
+            return x
+
+    result = create_act_layer(CustomAct, inplace=True)
+    assert isinstance(result, CustomAct)
+
+
+def test_get_act_fn_callable():
+    def custom_act(x):
+        return x
+
+    assert get_act_fn(custom_act) is custom_act
+
+
+def test_get_act_fn_none():
+    assert get_act_fn(None) is None
+    assert get_act_fn('') is None
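A note on what test_create_act_layer_inplace_error exercises: per the test's own comment, create_act_layer recovers when passing inplace= raises a TypeError, retrying construction without the argument. A minimal sketch of that try/except fallback (create_act_sketch is an illustrative name, not timm's implementation):

def create_act_sketch(act_layer, inplace=None, **kwargs):
    # No inplace requested: construct directly.
    if inplace is None:
        return act_layer(**kwargs)
    try:
        # Many activation modules accept an inplace flag...
        return act_layer(inplace=inplace, **kwargs)
    except TypeError:
        # ...but ones like NoInplaceAct above do not; retry without it.
        return act_layer(**kwargs)

This also explains test_create_act_layer_edge_cases: CustomAct absorbs inplace via **kwargs, so no TypeError is raised and the instance comes back from the first attempt.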
@@ -13,7 +13,7 @@ from torch.testing._internal.common_utils import TestCase
 from torch.nn import Parameter
 from timm.scheduler import PlateauLRScheduler

-from timm.optim import create_optimizer_v2
+from timm.optim import create_optimizer_v2, param_groups_layer_decay, param_groups_weight_decay

 import importlib
 import os
@@ -741,3 +741,82 @@ def test_lookahead_radam(optimizer):
         lambda params: create_optimizer_v2(params, optimizer, lr=1e-4)
     )
+
+
+def test_param_groups_layer_decay_with_end_decay():
+    model = torch.nn.Sequential(
+        torch.nn.Linear(10, 5),
+        torch.nn.ReLU(),
+        torch.nn.Linear(5, 2)
+    )
+
+    param_groups = param_groups_layer_decay(
+        model,
+        weight_decay=0.05,
+        layer_decay=0.75,
+        end_layer_decay=0.5,
+        verbose=True
+    )
+
+    assert len(param_groups) > 0
+    # Verify layer scaling is applied with end decay
+    for group in param_groups:
+        assert 'lr_scale' in group
+        assert group['lr_scale'] <= 1.0
+        assert group['lr_scale'] >= 0.5
+
+
+def test_param_groups_layer_decay_with_matcher():
+    class ModelWithMatcher(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.layer1 = torch.nn.Linear(10, 5)
+            self.layer2 = torch.nn.Linear(5, 2)
+
+        def group_matcher(self, coarse=False):
+            return lambda name: int(name.split('.')[0][-1])
+
+    model = ModelWithMatcher()
+    param_groups = param_groups_layer_decay(
+        model,
+        weight_decay=0.05,
+        layer_decay=0.75,
+        verbose=True
+    )
+
+    assert len(param_groups) > 0
+    # Verify layer scaling is applied
+    for group in param_groups:
+        assert 'lr_scale' in group
+        assert 'weight_decay' in group
+        assert len(group['params']) > 0
+
+
+def test_param_groups_weight_decay():
+    model = torch.nn.Sequential(
+        torch.nn.Linear(10, 5),
+        torch.nn.ReLU(),
+        torch.nn.Linear(5, 2)
+    )
+    weight_decay = 0.01
+    no_weight_decay_list = ['1.weight']
+
+    param_groups = param_groups_weight_decay(
+        model,
+        weight_decay=weight_decay,
+        no_weight_decay_list=no_weight_decay_list
+    )
+
+    assert len(param_groups) == 2
+    assert param_groups[0]['weight_decay'] == 0.0
+    assert param_groups[1]['weight_decay'] == weight_decay
+
+    # Verify parameters are correctly grouped
+    no_decay_params = set(param_groups[0]['params'])
+    decay_params = set(param_groups[1]['params'])
+
+    for name, param in model.named_parameters():
+        if param.ndim <= 1 or name.endswith(".bias") or name in no_weight_decay_list:
+            assert param in no_decay_params
+        else:
+            assert param in decay_params
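Background for the lr_scale assertions above: layer decay assigns each parameter group a geometric learning-rate multiplier, smallest for the earliest layers and 1.0 for the last. A rough sketch of the scale computation, assuming end_layer_decay acts as a lower bound on the scale as the >= 0.5 assertion suggests (layer_lr_scales is hypothetical; timm derives the layer ids from the model or its group_matcher):

def layer_lr_scales(num_layers, layer_decay, end_layer_decay=None):
    # Scale for layer i of num_layers: layer_decay ** (num_layers - i),
    # so the final layer gets 1.0 and earlier layers decay geometrically.
    scales = [layer_decay ** (num_layers - i) for i in range(num_layers + 1)]
    if end_layer_decay is not None:
        # Assumed clamp so no group falls below end_layer_decay.
        scales = [max(s, end_layer_decay) for s in scales]
    return scales

print(layer_lr_scales(3, 0.75, 0.5))  # [0.5, 0.5625, 0.75, 1.0]

The weight-decay grouping rule, by contrast, is fully visible in test_param_groups_weight_decay itself: 1-d parameters, biases, and names in no_weight_decay_list land in the weight_decay=0.0 group, everything else in the decayed group.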
@@ -2,8 +2,15 @@ from torch.nn.modules.batchnorm import BatchNorm2d
 from torchvision.ops.misc import FrozenBatchNorm2d

 import timm
+import pytest
 from timm.utils.model import freeze, unfreeze
 from timm.utils.model import ActivationStatsHook
 from timm.utils.model import extract_spp_stats
+
+from timm.utils.model import _freeze_unfreeze
+from timm.utils.model import avg_sq_ch_mean, avg_ch_var, avg_ch_var_residual
+from timm.utils.model import reparameterize_model
+from timm.utils.model import get_state_dict

 def test_freeze_unfreeze():
     model = timm.create_model('resnet18')
@@ -55,3 +62,131 @@ def test_freeze_unfreeze():
     assert isinstance(model.layer1[0].bn1, FrozenBatchNorm2d)
     unfreeze(model.layer1[0], ['bn1'])
     assert isinstance(model.layer1[0].bn1, BatchNorm2d)
+
+
+def test_activation_stats_hook_validation():
+    model = timm.create_model('resnet18')
+
+    def test_hook(model, input, output):
+        return output.mean().item()
+
+    # Test error case with mismatched lengths
+    with pytest.raises(ValueError, match="Please provide `hook_fns` for each `hook_fn_locs`"):
+        ActivationStatsHook(
+            model,
+            hook_fn_locs=['layer1.0.conv1', 'layer1.0.conv2'],
+            hook_fns=[test_hook]
+        )
+
+
+def test_extract_spp_stats():
+    model = timm.create_model('resnet18')
+
+    def test_hook(model, input, output):
+        return output.mean().item()
+
+    stats = extract_spp_stats(
+        model,
+        hook_fn_locs=['layer1.0.conv1'],
+        hook_fns=[test_hook],
+        input_shape=[2, 3, 32, 32]
+    )
+
+    assert isinstance(stats, dict)
+    assert test_hook.__name__ in stats
+    assert isinstance(stats[test_hook.__name__], list)
+    assert len(stats[test_hook.__name__]) > 0
+
+
+def test_freeze_unfreeze_bn_root():
+    import torch.nn as nn
+    from timm.layers import BatchNormAct2d
+
+    # Create batch norm layers
+    bn = nn.BatchNorm2d(10)
+    bn_act = BatchNormAct2d(10)
+
+    # Test with BatchNorm2d as root
+    with pytest.raises(AssertionError):
+        _freeze_unfreeze(bn, mode="freeze")
+
+    # Test with BatchNormAct2d as root
+    with pytest.raises(AssertionError):
+        _freeze_unfreeze(bn_act, mode="freeze")
+
+
+def test_activation_stats_functions():
+    import torch
+
+    # Create sample input tensor [batch, channels, height, width]
+    x = torch.randn(2, 3, 4, 4)
+
+    # Test avg_sq_ch_mean
+    result1 = avg_sq_ch_mean(None, None, x)
+    assert isinstance(result1, float)
+
+    # Test avg_ch_var
+    result2 = avg_ch_var(None, None, x)
+    assert isinstance(result2, float)
+
+    # Test avg_ch_var_residual
+    result3 = avg_ch_var_residual(None, None, x)
+    assert isinstance(result3, float)
+
+
+def test_reparameterize_model():
+    import torch.nn as nn
+
+    class FusableModule(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv = nn.Conv2d(3, 3, 1)
+
+        def fuse(self):
+            return nn.Identity()
+
+    class ModelWithFusable(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.fusable = FusableModule()
+            self.normal = nn.Linear(10, 10)
+
+    model = ModelWithFusable()
+
+    # Test with inplace=False (should create a copy)
+    new_model = reparameterize_model(model, inplace=False)
+    assert isinstance(new_model.fusable, nn.Identity)
+    assert isinstance(model.fusable, FusableModule)  # Original unchanged
+
+    # Test with inplace=True
+    reparameterize_model(model, inplace=True)
+    assert isinstance(model.fusable, nn.Identity)
+
+
+def test_get_state_dict_custom_unwrap():
+    import torch.nn as nn
+
+    class CustomModel(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.linear = nn.Linear(10, 10)
+
+    model = CustomModel()
+
+    def custom_unwrap(m):
+        return m
+
+    state_dict = get_state_dict(model, unwrap_fn=custom_unwrap)
+    assert 'linear.weight' in state_dict
+    assert 'linear.bias' in state_dict
+
+
+def test_freeze_unfreeze_string_input():
+    model = timm.create_model('resnet18')
+
+    # Test with string input
+    _freeze_unfreeze(model, 'layer1', mode='freeze')
+    assert model.layer1[0].conv1.weight.requires_grad == False
+
+    # Test unfreezing with string input
+    _freeze_unfreeze(model, 'layer1', mode='unfreeze')
+    assert model.layer1[0].conv1.weight.requires_grad == True
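A note on the hook plumbing behind test_activation_stats_hook_validation and test_extract_spp_stats: ActivationStatsHook attaches forward hooks at the named submodules and collects each hook_fn's float return value in a dict keyed by the function's __name__, which is why test_hook.__name__ is the lookup key in the assertions. A standalone sketch of that pattern (StatsHookSketch is illustrative, not timm's class, which also handles pattern-based location matching):

import torch.nn as nn

class StatsHookSketch:
    def __init__(self, model: nn.Module, hook_fn_locs, hook_fns):
        # Mirror the validation the test above triggers.
        if len(hook_fn_locs) != len(hook_fns):
            raise ValueError("Please provide `hook_fns` for each `hook_fn_locs`")
        self.stats = {fn.__name__: [] for fn in hook_fns}
        for loc, fn in zip(hook_fn_locs, hook_fns):
            # Record the hook's scalar result on every forward pass;
            # append() returns None, so the module's output is untouched.
            model.get_submodule(loc).register_forward_hook(
                lambda mod, inp, out, fn=fn: self.stats[fn.__name__].append(fn(mod, inp, out))
            )

extract_spp_stats then amounts to attaching such hooks and running one random batch of input_shape through the model, so each registered hook_fn yields at least one entry, as the final assertion checks.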
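Similarly, the fuse() convention behind test_reparameterize_model: any submodule exposing a fuse() method is swapped for whatever fuse() returns, recursing into the rest of the tree, and inplace=False works on a deep copy so the original model is left untouched. A rough sketch of that traversal (reparam_sketch is illustrative, not timm's code, which may honor additional reparameterization hooks):

import copy
import torch.nn as nn

def reparam_sketch(model: nn.Module, inplace: bool = False) -> nn.Module:
    if not inplace:
        # Leave the caller's model untouched, as the test asserts.
        model = copy.deepcopy(model)
    for name, child in model.named_children():
        if hasattr(child, 'fuse'):
            # Replace the module with its fused equivalent.
            setattr(model, name, child.fuse())
        else:
            reparam_sketch(child, inplace=True)
    return model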