mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Fix some broken tests for ResNetV2 BiT models
This commit is contained in:
parent
fd9061dbf7
commit
20516abc18
@ -15,7 +15,7 @@ if hasattr(torch._C, '_jit_set_profiling_executor'):
|
||||
|
||||
if 'GITHUB_ACTIONS' in os.environ: # and 'Linux' in platform.system():
|
||||
# GitHub Linux runner is slower and hits memory limits sooner than MacOS, exclude bigger models
|
||||
EXCLUDE_FILTERS = ['*efficientnet_l2*', '*resnext101_32x48d', 'vit_*']
|
||||
EXCLUDE_FILTERS = ['*efficientnet_l2*', '*resnext101_32x48d', 'vit_*', '*in21k', '*152x4_bitm']
|
||||
else:
|
||||
EXCLUDE_FILTERS = ['vit_*']
|
||||
MAX_FWD_SIZE = 384
|
||||
|
@ -331,7 +331,7 @@ def create_stem(in_chs, out_chs, stem_type='', preact=True, conv_layer=None, nor
|
||||
|
||||
if 'fixed' in stem_type:
|
||||
# 'fixed' SAME padding approximation that is used in BiT models
|
||||
stem['pad'] = nn.ConstantPad2d(1, 0)
|
||||
stem['pad'] = nn.ConstantPad2d(1, 0.)
|
||||
stem['pool'] = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)
|
||||
elif 'same' in stem_type:
|
||||
# full, input size based 'SAME' padding, used in ViT Hybrid model
|
||||
@ -421,7 +421,12 @@ class ResNetV2(nn.Module):
|
||||
import numpy as np
|
||||
weights = np.load(checkpoint_path)
|
||||
with torch.no_grad():
|
||||
self.stem.conv.weight.copy_(tf2th(weights[f'{prefix}root_block/standardized_conv2d/kernel']))
|
||||
stem_conv_w = tf2th(weights[f'{prefix}root_block/standardized_conv2d/kernel'])
|
||||
if self.stem.conv.weight.shape[1] == 1:
|
||||
self.stem.conv.weight.copy_(stem_conv_w.sum(dim=1, keepdim=True))
|
||||
# FIXME handle > 3 in_chans?
|
||||
else:
|
||||
self.stem.conv.weight.copy_(stem_conv_w)
|
||||
self.norm.weight.copy_(tf2th(weights[f'{prefix}group_norm/gamma']))
|
||||
self.norm.bias.copy_(tf2th(weights[f'{prefix}group_norm/beta']))
|
||||
self.head.fc.weight.copy_(tf2th(weights[f'{prefix}head/conv2d/kernel']))
|
||||
|
Loading…
x
Reference in New Issue
Block a user