mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Fix out_indices handling breakage, should have left as per vgg approach.
This commit is contained in:
parent
a9f91483a6
commit
ccfeb06936
@ -408,14 +408,13 @@ class CspNet(nn.Module):
|
|||||||
|
|
||||||
def _create_cspnet(variant, pretrained=False, **kwargs):
|
def _create_cspnet(variant, pretrained=False, **kwargs):
|
||||||
cfg_variant = variant.split('_')[0]
|
cfg_variant = variant.split('_')[0]
|
||||||
if 'darknet' in variant:
|
# NOTE: DarkNet is one of few models with stride==1 features w/ 6 out_indices [0..5]
|
||||||
# NOTE: DarkNet is one of few models with stride==1 features w/ 6 out_indices [0..5]
|
out_indices = kwargs.get('out_indices', (0, 1, 2, 3, 4, 5) if 'darknet' in variant else (0, 1, 2, 3, 4))
|
||||||
kwargs.setdefault('out_indices', (0, 1, 2, 3, 4, 5))
|
|
||||||
return build_model_with_cfg(
|
return build_model_with_cfg(
|
||||||
CspNet, variant, pretrained,
|
CspNet, variant, pretrained,
|
||||||
default_cfg=default_cfgs[variant],
|
default_cfg=default_cfgs[variant],
|
||||||
model_cfg=model_cfgs[cfg_variant],
|
model_cfg=model_cfgs[cfg_variant],
|
||||||
feature_cfg=dict(flatten_sequential=True),
|
feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
|
||||||
**kwargs)
|
**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@ -180,12 +180,12 @@ def _filter_fn(state_dict):
|
|||||||
def _create_vgg(variant: str, pretrained: bool, **kwargs: Any) -> VGG:
|
def _create_vgg(variant: str, pretrained: bool, **kwargs: Any) -> VGG:
|
||||||
cfg = variant.split('_')[0]
|
cfg = variant.split('_')[0]
|
||||||
# NOTE: VGG is one of few models with stride==1 features w/ 6 out_indices [0..5]
|
# NOTE: VGG is one of few models with stride==1 features w/ 6 out_indices [0..5]
|
||||||
kwargs.setdefault('out_indices', (0, 1, 2, 3, 4, 5))
|
out_indices = kwargs.get('out_indices', (0, 1, 2, 3, 4, 5))
|
||||||
model = build_model_with_cfg(
|
model = build_model_with_cfg(
|
||||||
VGG, variant, pretrained,
|
VGG, variant, pretrained,
|
||||||
default_cfg=default_cfgs[variant],
|
default_cfg=default_cfgs[variant],
|
||||||
model_cfg=cfgs[cfg],
|
model_cfg=cfgs[cfg],
|
||||||
feature_cfg=dict(flatten_sequential=True),
|
feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
|
||||||
pretrained_filter_fn=_filter_fn,
|
pretrained_filter_fn=_filter_fn,
|
||||||
**kwargs)
|
**kwargs)
|
||||||
return model
|
return model
|
||||||
|
Loading…
x
Reference in New Issue
Block a user