mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Cleanup FeatureInfo getters, add TF models sourced Xception41/65/71 weights
This commit is contained in:
parent
7ba5a384d3
commit
08016e839d
@ -441,7 +441,7 @@ class EfficientNetFeatures(nn.Module):
|
|||||||
# Register feature extraction hooks with FeatureHooks helper
|
# Register feature extraction hooks with FeatureHooks helper
|
||||||
self.feature_hooks = None
|
self.feature_hooks = None
|
||||||
if feature_location != 'bottleneck':
|
if feature_location != 'bottleneck':
|
||||||
hooks = self.feature_info.get_by_key(keys=('module', 'hook_type'))
|
hooks = self.feature_info.get_dicts(keys=('module', 'hook_type'))
|
||||||
self.feature_hooks = FeatureHooks(hooks, self.named_modules())
|
self.feature_hooks = FeatureHooks(hooks, self.named_modules())
|
||||||
|
|
||||||
def forward(self, x) -> List[torch.Tensor]:
|
def forward(self, x) -> List[torch.Tensor]:
|
||||||
|
@ -8,7 +8,7 @@ Hacked together by Ross Wightman
|
|||||||
from collections import OrderedDict, defaultdict
|
from collections import OrderedDict, defaultdict
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Dict, List, Tuple, Any
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -30,42 +30,46 @@ class FeatureInfo:
|
|||||||
def from_other(self, out_indices: Tuple[int]):
|
def from_other(self, out_indices: Tuple[int]):
|
||||||
return FeatureInfo(deepcopy(self.info), out_indices)
|
return FeatureInfo(deepcopy(self.info), out_indices)
|
||||||
|
|
||||||
|
def get(self, key, idx=None):
|
||||||
|
""" Get value by key at specified index (indices)
|
||||||
|
if idx == None, returns value for key at each output index
|
||||||
|
if idx is an integer, return value for that feature module index (ignoring output indices)
|
||||||
|
if idx is a list/tupple, return value for each module index (ignoring output indices)
|
||||||
|
"""
|
||||||
|
if idx is None:
|
||||||
|
return [self.info[i][key] for i in self.out_indices]
|
||||||
|
if isinstance(idx, (tuple, list)):
|
||||||
|
return [self.info[i][key] for i in idx]
|
||||||
|
else:
|
||||||
|
return self.info[idx][key]
|
||||||
|
|
||||||
|
def get_dicts(self, keys=None, idx=None):
|
||||||
|
""" return info dicts for specified keys (or all if None) at specified indices (or out_indices if None)
|
||||||
|
"""
|
||||||
|
if idx is None:
|
||||||
|
if keys is None:
|
||||||
|
return [self.info[i] for i in self.out_indices]
|
||||||
|
else:
|
||||||
|
return [{k: self.info[i][k] for k in keys} for i in self.out_indices]
|
||||||
|
if isinstance(idx, (tuple, list)):
|
||||||
|
return [self.info[i] if keys is None else {k: self.info[i][k] for k in keys} for i in idx]
|
||||||
|
else:
|
||||||
|
return self.info[idx] if keys is None else {k: self.info[idx][k] for k in keys}
|
||||||
|
|
||||||
def channels(self, idx=None):
|
def channels(self, idx=None):
|
||||||
""" feature channels accessor
|
""" feature channels accessor
|
||||||
if idx == None, returns feature channel count at each output index
|
|
||||||
if idx is an integer, return feature channel count for that feature module index
|
|
||||||
"""
|
"""
|
||||||
if isinstance(idx, int):
|
return self.get('num_chs', idx)
|
||||||
return self.info[idx]['num_chs']
|
|
||||||
return [self.info[i]['num_chs'] for i in self.out_indices]
|
|
||||||
|
|
||||||
def reduction(self, idx=None):
|
def reduction(self, idx=None):
|
||||||
""" feature reduction (output stride) accessor
|
""" feature reduction (output stride) accessor
|
||||||
if idx == None, returns feature reduction factor at each output index
|
|
||||||
if idx is an integer, return feature channel count at that feature module index
|
|
||||||
"""
|
"""
|
||||||
if isinstance(idx, int):
|
return self.get('reduction', idx)
|
||||||
return self.info[idx]['reduction']
|
|
||||||
return [self.info[i]['reduction'] for i in self.out_indices]
|
|
||||||
|
|
||||||
def module_name(self, idx=None):
|
def module_name(self, idx=None):
|
||||||
""" feature module name accessor
|
""" feature module name accessor
|
||||||
if idx == None, returns feature module name at each output index
|
|
||||||
if idx is an integer, return feature module name at that feature module index
|
|
||||||
"""
|
"""
|
||||||
if isinstance(idx, int):
|
return self.get('module', idx)
|
||||||
return self.info[idx]['module']
|
|
||||||
return [self.info[i]['module'] for i in self.out_indices]
|
|
||||||
|
|
||||||
def get_by_key(self, idx=None, keys=None):
|
|
||||||
""" return info dicts for specified keys (or all if None) at specified idx (or out_indices if None)
|
|
||||||
"""
|
|
||||||
if isinstance(idx, int):
|
|
||||||
return self.info[idx] if keys is None else {k: self.info[idx][k] for k in keys}
|
|
||||||
if keys is None:
|
|
||||||
return [self.info[i] for i in self.out_indices]
|
|
||||||
else:
|
|
||||||
return [{k: self.info[i][k] for k in keys} for i in self.out_indices]
|
|
||||||
|
|
||||||
def __getitem__(self, item):
|
def __getitem__(self, item):
|
||||||
return self.info[item]
|
return self.info[item]
|
||||||
@ -253,11 +257,11 @@ class FeatureHookNet(nn.ModuleDict):
|
|||||||
if hasattr(model, 'reset_classifier'): # make sure classifier is removed?
|
if hasattr(model, 'reset_classifier'): # make sure classifier is removed?
|
||||||
model.reset_classifier(0)
|
model.reset_classifier(0)
|
||||||
layers['body'] = model
|
layers['body'] = model
|
||||||
hooks.extend(self.feature_info)
|
hooks.extend(self.feature_info.get_dicts())
|
||||||
else:
|
else:
|
||||||
modules = _module_list(model, flatten_sequential=flatten_sequential)
|
modules = _module_list(model, flatten_sequential=flatten_sequential)
|
||||||
remaining = {f['module']: f['hook_type'] if 'hook_type' in f else default_hook_type
|
remaining = {f['module']: f['hook_type'] if 'hook_type' in f else default_hook_type
|
||||||
for f in self.feature_info}
|
for f in self.feature_info.get_dicts()}
|
||||||
for new_name, old_name, module in modules:
|
for new_name, old_name, module in modules:
|
||||||
layers[new_name] = module
|
layers[new_name] = module
|
||||||
for fn, fm in module.named_modules(prefix=old_name):
|
for fn, fm in module.named_modules(prefix=old_name):
|
||||||
|
@ -186,7 +186,7 @@ class MobileNetV3Features(nn.Module):
|
|||||||
# Register feature extraction hooks with FeatureHooks helper
|
# Register feature extraction hooks with FeatureHooks helper
|
||||||
self.feature_hooks = None
|
self.feature_hooks = None
|
||||||
if feature_location != 'bottleneck':
|
if feature_location != 'bottleneck':
|
||||||
hooks = self.feature_info.get_by_key(keys=('module', 'hook_type'))
|
hooks = self.feature_info.get_dicts(keys=('module', 'hook_type'))
|
||||||
self.feature_hooks = FeatureHooks(hooks, self.named_modules())
|
self.feature_hooks = FeatureHooks(hooks, self.named_modules())
|
||||||
|
|
||||||
def forward(self, x) -> List[torch.Tensor]:
|
def forward(self, x) -> List[torch.Tensor]:
|
||||||
|
@ -31,9 +31,12 @@ def _cfg(url='', **kwargs):
|
|||||||
|
|
||||||
|
|
||||||
default_cfgs = dict(
|
default_cfgs = dict(
|
||||||
xception41=_cfg(url=''),
|
xception41=_cfg(
|
||||||
xception65=_cfg(url=''),
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_xception_41-e6439c97.pth'),
|
||||||
xception71=_cfg(url=''),
|
xception65=_cfg(
|
||||||
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_xception_65-c9ae96e8.pth'),
|
||||||
|
xception71=_cfg(
|
||||||
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_xception_71-8eec7df1.pth'),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -216,7 +219,6 @@ def xception65(pretrained=False, **kwargs):
|
|||||||
return _xception('xception65', pretrained=pretrained, **model_args)
|
return _xception('xception65', pretrained=pretrained, **model_args)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
def xception71(pretrained=False, **kwargs):
|
def xception71(pretrained=False, **kwargs):
|
||||||
""" Modified Aligned Xception-71
|
""" Modified Aligned Xception-71
|
||||||
|
Loading…
x
Reference in New Issue
Block a user