diff --git a/timm/models/tresnet.py b/timm/models/tresnet.py index 4d1d9d31..bb15a283 100644 --- a/timm/models/tresnet.py +++ b/timm/models/tresnet.py @@ -272,8 +272,8 @@ def checkpoint_filter_fn(state_dict, model): k = re.sub(r'conv(\d+)\.0.1', lambda x: f'conv{int(x.group(1))}.bn', k) k = re.sub(r'conv(\d+)\.0', lambda x: f'conv{int(x.group(1))}.conv', k) k = re.sub(r'conv(\d+)\.1', lambda x: f'conv{int(x.group(1))}.bn', k) - k = k.replace('downsample.1.0', 'downsample.1.conv') - k = k.replace('downsample.1.1', 'downsample.1.bn') + k = re.sub(r'downsample\.(\d+)\.0', lambda x: f'downsample.{int(x.group(1))}.conv', k) + k = re.sub(r'downsample\.(\d+)\.1', lambda x: f'downsample.{int(x.group(1))}.bn', k) if k.endswith('bn.weight'): # convert weight from inplace_abn to batchnorm v = v.abs().add(1e-5)