mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Fix a few issues that came up in tests
This commit is contained in:
parent
d23a2697d0
commit
d0113f9cdb
@ -44,6 +44,7 @@ class MaxPool2dSame(nn.MaxPool2d):
|
||||
def __init__(self, kernel_size: int, stride=None, padding=0, dilation=1, ceil_mode=False, count_include_pad=True):
|
||||
kernel_size = tup_pair(kernel_size)
|
||||
stride = tup_pair(stride)
|
||||
dilation = tup_pair(dilation)
|
||||
super(MaxPool2dSame, self).__init__(kernel_size, stride, (0, 0), dilation, ceil_mode, count_include_pad)
|
||||
|
||||
def forward(self, x):
|
||||
|
@ -13,7 +13,7 @@ default_cfgs = {
|
||||
'url': 'http://data.lip6.fr/cadene/pretrainedmodels/nasnetalarge-a1897284.pth',
|
||||
'input_size': (3, 331, 331),
|
||||
'pool_size': (11, 11),
|
||||
'crop_pct': 0.875,
|
||||
'crop_pct': 0.911,
|
||||
'interpolation': 'bicubic',
|
||||
'mean': (0.5, 0.5, 0.5),
|
||||
'std': (0.5, 0.5, 0.5),
|
||||
|
@ -24,7 +24,7 @@ default_cfgs = {
|
||||
'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/pnasnet5large-bf079911.pth',
|
||||
'input_size': (3, 331, 331),
|
||||
'pool_size': (11, 11),
|
||||
'crop_pct': 0.875,
|
||||
'crop_pct': 0.911,
|
||||
'interpolation': 'bicubic',
|
||||
'mean': (0.5, 0.5, 0.5),
|
||||
'std': (0.5, 0.5, 0.5),
|
||||
|
@ -521,20 +521,23 @@ class ResNet(nn.Module):
|
||||
|
||||
def _create_resnet_with_cfg(variant, default_cfg, pretrained=False, **kwargs):
|
||||
assert isinstance(default_cfg, dict)
|
||||
load_strict, features = True, False
|
||||
features = False
|
||||
out_indices = None
|
||||
if kwargs.pop('features_only', False):
|
||||
load_strict, features = False, True
|
||||
features = True
|
||||
kwargs.pop('num_classes', 0)
|
||||
out_indices = kwargs.pop('out_indices', (0, 1, 2, 3, 4))
|
||||
pruned = kwargs.pop('pruned', False)
|
||||
|
||||
model = ResNet(**kwargs)
|
||||
model.default_cfg = copy.deepcopy(default_cfg)
|
||||
if kwargs.pop('pruned', False):
|
||||
|
||||
if pruned:
|
||||
model = adapt_model_from_file(model, variant)
|
||||
if pretrained:
|
||||
load_pretrained(
|
||||
model,
|
||||
num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3), strict=load_strict)
|
||||
num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3), strict=not features)
|
||||
if features:
|
||||
model = FeatureNet(model, out_indices=out_indices)
|
||||
return model
|
||||
|
Loading…
x
Reference in New Issue
Block a user