Extend existing unit tests

Tal 2024-11-10 06:57:39 +00:00 committed by Ross Wightman
parent 9f5c279bad
commit 68d5a64e45
3 changed files with 260 additions and 3 deletions

View File

@@ -1,7 +1,7 @@
import torch
import torch.nn as nn
from timm.layers import create_act_layer, set_layer_config, get_act_layer, get_act_fn
import importlib
import os
@@ -76,3 +76,46 @@ def test_hard_swish_grad():
def test_hard_mish_grad():
    for _ in range(100):
        _run_act_layer_grad('hard_mish')


def test_get_act_layer_empty_string():
    # Empty string should return None
    assert get_act_layer('') is None


def test_create_act_layer_inplace_error():
    class NoInplaceAct(nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x):
            return x

    # Should recover when inplace arg causes TypeError
    layer = create_act_layer(NoInplaceAct, inplace=True)
    assert isinstance(layer, NoInplaceAct)


def test_create_act_layer_edge_cases():
    # Test None input
    assert create_act_layer(None) is None

    # Test TypeError handling for inplace
    class CustomAct(nn.Module):
        def __init__(self, **kwargs):
            super().__init__()

        def forward(self, x):
            return x

    result = create_act_layer(CustomAct, inplace=True)
    assert isinstance(result, CustomAct)


def test_get_act_fn_callable():
    def custom_act(x):
        return x

    assert get_act_fn(custom_act) is custom_act


def test_get_act_fn_none():
    assert get_act_fn(None) is None
    assert get_act_fn('') is None
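For context on what these tests exercise: get_act_layer resolves an activation class from a name (or passes a class through), create_act_layer instantiates it and, as the inplace tests above show, recovers when the class does not accept an inplace argument, and get_act_fn resolves a functional form. A minimal usage sketch, assuming only the timm.layers API imported above and standard registered activation names such as 'relu', 'gelu', and 'sigmoid':

import torch
from timm.layers import create_act_layer, get_act_layer, get_act_fn

act_cls = get_act_layer('gelu')                 # resolve an activation class by name
act = create_act_layer('relu', inplace=True)    # instantiate; inplace only forwarded if supported
fn = get_act_fn('sigmoid')                      # resolve a functional (callable) form

x = torch.randn(2, 4)
print(act_cls()(x).shape, act(x).shape, fn(x).shape)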

View File

@@ -13,7 +13,7 @@ from torch.testing._internal.common_utils import TestCase
from torch.nn import Parameter
from timm.scheduler import PlateauLRScheduler
from timm.optim import create_optimizer_v2, param_groups_layer_decay, param_groups_weight_decay
import importlib
import os
@@ -741,3 +741,82 @@ def test_lookahead_radam(optimizer):
        lambda params: create_optimizer_v2(params, optimizer, lr=1e-4)
    )
def test_param_groups_layer_decay_with_end_decay():
    model = torch.nn.Sequential(
        torch.nn.Linear(10, 5),
        torch.nn.ReLU(),
        torch.nn.Linear(5, 2)
    )
    param_groups = param_groups_layer_decay(
        model,
        weight_decay=0.05,
        layer_decay=0.75,
        end_layer_decay=0.5,
        verbose=True
    )
    assert len(param_groups) > 0
    # Verify layer scaling is applied with end decay
    for group in param_groups:
        assert 'lr_scale' in group
        assert group['lr_scale'] <= 1.0
        assert group['lr_scale'] >= 0.5


def test_param_groups_layer_decay_with_matcher():
    class ModelWithMatcher(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.layer1 = torch.nn.Linear(10, 5)
            self.layer2 = torch.nn.Linear(5, 2)

        def group_matcher(self, coarse=False):
            return lambda name: int(name.split('.')[0][-1])

    model = ModelWithMatcher()
    param_groups = param_groups_layer_decay(
        model,
        weight_decay=0.05,
        layer_decay=0.75,
        verbose=True
    )
    assert len(param_groups) > 0
    # Verify layer scaling is applied
    for group in param_groups:
        assert 'lr_scale' in group
        assert 'weight_decay' in group
        assert len(group['params']) > 0


def test_param_groups_weight_decay():
    model = torch.nn.Sequential(
        torch.nn.Linear(10, 5),
        torch.nn.ReLU(),
        torch.nn.Linear(5, 2)
    )
    weight_decay = 0.01
    no_weight_decay_list = ['1.weight']
    param_groups = param_groups_weight_decay(
        model,
        weight_decay=weight_decay,
        no_weight_decay_list=no_weight_decay_list
    )
    assert len(param_groups) == 2
    assert param_groups[0]['weight_decay'] == 0.0
    assert param_groups[1]['weight_decay'] == weight_decay

    # Verify parameters are correctly grouped
    no_decay_params = set(param_groups[0]['params'])
    decay_params = set(param_groups[1]['params'])
    for name, param in model.named_parameters():
        if param.ndim <= 1 or name.endswith(".bias") or name in no_weight_decay_list:
            assert param in no_decay_params
        else:
            assert param in decay_params
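As a usage note, the lists returned by these helpers are plain optimizer param groups, so they can be handed directly to a torch optimizer (or to create_optimizer_v2); the per-group weight_decay values tested above are what the optimizer then applies, while lr_scale is consumed by timm's optimizer/scheduler machinery. A minimal sketch, assuming only the param_groups_weight_decay signature exercised in the test:

import torch
from timm.optim import param_groups_weight_decay

# Grouped parameters go straight into a torch optimizer, so biases and other
# 1-d params (the weight_decay=0.0 group) are excluded from decay.
model = torch.nn.Sequential(torch.nn.Linear(10, 5), torch.nn.ReLU(), torch.nn.Linear(5, 2))
groups = param_groups_weight_decay(model, weight_decay=0.01)
optimizer = torch.optim.AdamW(groups, lr=1e-3)
print([(len(g['params']), g['weight_decay']) for g in groups])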

View File

@@ -2,8 +2,15 @@ from torch.nn.modules.batchnorm import BatchNorm2d
from torchvision.ops.misc import FrozenBatchNorm2d
import timm
import pytest
from timm.utils.model import freeze, unfreeze
from timm.utils.model import ActivationStatsHook
from timm.utils.model import extract_spp_stats
from timm.utils.model import _freeze_unfreeze
from timm.utils.model import avg_sq_ch_mean, avg_ch_var, avg_ch_var_residual
from timm.utils.model import reparameterize_model
from timm.utils.model import get_state_dict


def test_freeze_unfreeze():
    model = timm.create_model('resnet18')
@@ -55,3 +62,131 @@ def test_freeze_unfreeze():
    assert isinstance(model.layer1[0].bn1, FrozenBatchNorm2d)
    unfreeze(model.layer1[0], ['bn1'])
    assert isinstance(model.layer1[0].bn1, BatchNorm2d)
def test_activation_stats_hook_validation():
    model = timm.create_model('resnet18')

    def test_hook(model, input, output):
        return output.mean().item()

    # Test error case with mismatched lengths
    with pytest.raises(ValueError, match="Please provide `hook_fns` for each `hook_fn_locs`"):
        ActivationStatsHook(
            model,
            hook_fn_locs=['layer1.0.conv1', 'layer1.0.conv2'],
            hook_fns=[test_hook]
        )


def test_extract_spp_stats():
    model = timm.create_model('resnet18')

    def test_hook(model, input, output):
        return output.mean().item()

    stats = extract_spp_stats(
        model,
        hook_fn_locs=['layer1.0.conv1'],
        hook_fns=[test_hook],
        input_shape=[2, 3, 32, 32]
    )
    assert isinstance(stats, dict)
    assert test_hook.__name__ in stats
    assert isinstance(stats[test_hook.__name__], list)
    assert len(stats[test_hook.__name__]) > 0


def test_freeze_unfreeze_bn_root():
    import torch.nn as nn
    from timm.layers import BatchNormAct2d

    # Create batch norm layers
    bn = nn.BatchNorm2d(10)
    bn_act = BatchNormAct2d(10)

    # Test with BatchNorm2d as root
    with pytest.raises(AssertionError):
        _freeze_unfreeze(bn, mode="freeze")

    # Test with BatchNormAct2d as root
    with pytest.raises(AssertionError):
        _freeze_unfreeze(bn_act, mode="freeze")


def test_activation_stats_functions():
    import torch

    # Create sample input tensor [batch, channels, height, width]
    x = torch.randn(2, 3, 4, 4)

    # Test avg_sq_ch_mean
    result1 = avg_sq_ch_mean(None, None, x)
    assert isinstance(result1, float)

    # Test avg_ch_var
    result2 = avg_ch_var(None, None, x)
    assert isinstance(result2, float)

    # Test avg_ch_var_residual
    result3 = avg_ch_var_residual(None, None, x)
    assert isinstance(result3, float)


def test_reparameterize_model():
    import torch.nn as nn

    class FusableModule(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 3, 1)

        def fuse(self):
            return nn.Identity()

    class ModelWithFusable(nn.Module):
        def __init__(self):
            super().__init__()
            self.fusable = FusableModule()
            self.normal = nn.Linear(10, 10)

    model = ModelWithFusable()

    # Test with inplace=False (should create a copy)
    new_model = reparameterize_model(model, inplace=False)
    assert isinstance(new_model.fusable, nn.Identity)
    assert isinstance(model.fusable, FusableModule)  # Original unchanged

    # Test with inplace=True
    reparameterize_model(model, inplace=True)
    assert isinstance(model.fusable, nn.Identity)


def test_get_state_dict_custom_unwrap():
    import torch.nn as nn

    class CustomModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(10, 10)

    model = CustomModel()

    def custom_unwrap(m):
        return m

    state_dict = get_state_dict(model, unwrap_fn=custom_unwrap)
    assert 'linear.weight' in state_dict
    assert 'linear.bias' in state_dict


def test_freeze_unfreeze_string_input():
    model = timm.create_model('resnet18')

    # Test with string input
    _freeze_unfreeze(model, 'layer1', mode='freeze')
    assert model.layer1[0].conv1.weight.requires_grad == False

    # Test unfreezing with string input
    _freeze_unfreeze(model, 'layer1', mode='unfreeze')
    assert model.layer1[0].conv1.weight.requires_grad == True
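For reference, a minimal sketch of how ActivationStatsHook is typically wired up outside the tests, assuming (as extract_spp_stats suggests) that the hook collects results into a stats dict keyed by hook function name:

import torch
import timm
from timm.utils.model import ActivationStatsHook, avg_sq_ch_mean

model = timm.create_model('resnet18')
# Attach a stats hook to a single conv and run one forward pass.
hook = ActivationStatsHook(model, hook_fn_locs=['layer1.0.conv1'], hook_fns=[avg_sq_ch_mean])
_ = model(torch.randn(2, 3, 224, 224))
# Assumed: collected values are exposed on a `stats` dict, one list per hook fn.
print(hook.stats[avg_sq_ch_mean.__name__])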