mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Cleanup FeatureInfo getters, add TF models sourced Xception41/65/71 weights
This commit is contained in:
parent
7ba5a384d3
commit
08016e839d
@ -441,7 +441,7 @@ class EfficientNetFeatures(nn.Module):
|
||||
# Register feature extraction hooks with FeatureHooks helper
|
||||
self.feature_hooks = None
|
||||
if feature_location != 'bottleneck':
|
||||
hooks = self.feature_info.get_by_key(keys=('module', 'hook_type'))
|
||||
hooks = self.feature_info.get_dicts(keys=('module', 'hook_type'))
|
||||
self.feature_hooks = FeatureHooks(hooks, self.named_modules())
|
||||
|
||||
def forward(self, x) -> List[torch.Tensor]:
|
||||
|
@ -8,7 +8,7 @@ Hacked together by Ross Wightman
|
||||
from collections import OrderedDict, defaultdict
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
from typing import Dict, List, Tuple, Any
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -30,42 +30,46 @@ class FeatureInfo:
|
||||
def from_other(self, out_indices: Tuple[int]):
|
||||
return FeatureInfo(deepcopy(self.info), out_indices)
|
||||
|
||||
def get(self, key, idx=None):
|
||||
""" Get value by key at specified index (indices)
|
||||
if idx == None, returns value for key at each output index
|
||||
if idx is an integer, return value for that feature module index (ignoring output indices)
|
||||
if idx is a list/tupple, return value for each module index (ignoring output indices)
|
||||
"""
|
||||
if idx is None:
|
||||
return [self.info[i][key] for i in self.out_indices]
|
||||
if isinstance(idx, (tuple, list)):
|
||||
return [self.info[i][key] for i in idx]
|
||||
else:
|
||||
return self.info[idx][key]
|
||||
|
||||
def get_dicts(self, keys=None, idx=None):
|
||||
""" return info dicts for specified keys (or all if None) at specified indices (or out_indices if None)
|
||||
"""
|
||||
if idx is None:
|
||||
if keys is None:
|
||||
return [self.info[i] for i in self.out_indices]
|
||||
else:
|
||||
return [{k: self.info[i][k] for k in keys} for i in self.out_indices]
|
||||
if isinstance(idx, (tuple, list)):
|
||||
return [self.info[i] if keys is None else {k: self.info[i][k] for k in keys} for i in idx]
|
||||
else:
|
||||
return self.info[idx] if keys is None else {k: self.info[idx][k] for k in keys}
|
||||
|
||||
def channels(self, idx=None):
|
||||
""" feature channels accessor
|
||||
if idx == None, returns feature channel count at each output index
|
||||
if idx is an integer, return feature channel count for that feature module index
|
||||
"""
|
||||
if isinstance(idx, int):
|
||||
return self.info[idx]['num_chs']
|
||||
return [self.info[i]['num_chs'] for i in self.out_indices]
|
||||
return self.get('num_chs', idx)
|
||||
|
||||
def reduction(self, idx=None):
|
||||
""" feature reduction (output stride) accessor
|
||||
if idx == None, returns feature reduction factor at each output index
|
||||
if idx is an integer, return feature channel count at that feature module index
|
||||
"""
|
||||
if isinstance(idx, int):
|
||||
return self.info[idx]['reduction']
|
||||
return [self.info[i]['reduction'] for i in self.out_indices]
|
||||
return self.get('reduction', idx)
|
||||
|
||||
def module_name(self, idx=None):
|
||||
""" feature module name accessor
|
||||
if idx == None, returns feature module name at each output index
|
||||
if idx is an integer, return feature module name at that feature module index
|
||||
"""
|
||||
if isinstance(idx, int):
|
||||
return self.info[idx]['module']
|
||||
return [self.info[i]['module'] for i in self.out_indices]
|
||||
|
||||
def get_by_key(self, idx=None, keys=None):
|
||||
""" return info dicts for specified keys (or all if None) at specified idx (or out_indices if None)
|
||||
"""
|
||||
if isinstance(idx, int):
|
||||
return self.info[idx] if keys is None else {k: self.info[idx][k] for k in keys}
|
||||
if keys is None:
|
||||
return [self.info[i] for i in self.out_indices]
|
||||
else:
|
||||
return [{k: self.info[i][k] for k in keys} for i in self.out_indices]
|
||||
return self.get('module', idx)
|
||||
|
||||
def __getitem__(self, item):
|
||||
return self.info[item]
|
||||
@ -253,11 +257,11 @@ class FeatureHookNet(nn.ModuleDict):
|
||||
if hasattr(model, 'reset_classifier'): # make sure classifier is removed?
|
||||
model.reset_classifier(0)
|
||||
layers['body'] = model
|
||||
hooks.extend(self.feature_info)
|
||||
hooks.extend(self.feature_info.get_dicts())
|
||||
else:
|
||||
modules = _module_list(model, flatten_sequential=flatten_sequential)
|
||||
remaining = {f['module']: f['hook_type'] if 'hook_type' in f else default_hook_type
|
||||
for f in self.feature_info}
|
||||
for f in self.feature_info.get_dicts()}
|
||||
for new_name, old_name, module in modules:
|
||||
layers[new_name] = module
|
||||
for fn, fm in module.named_modules(prefix=old_name):
|
||||
|
@ -186,7 +186,7 @@ class MobileNetV3Features(nn.Module):
|
||||
# Register feature extraction hooks with FeatureHooks helper
|
||||
self.feature_hooks = None
|
||||
if feature_location != 'bottleneck':
|
||||
hooks = self.feature_info.get_by_key(keys=('module', 'hook_type'))
|
||||
hooks = self.feature_info.get_dicts(keys=('module', 'hook_type'))
|
||||
self.feature_hooks = FeatureHooks(hooks, self.named_modules())
|
||||
|
||||
def forward(self, x) -> List[torch.Tensor]:
|
||||
|
@ -31,9 +31,12 @@ def _cfg(url='', **kwargs):
|
||||
|
||||
|
||||
default_cfgs = dict(
|
||||
xception41=_cfg(url=''),
|
||||
xception65=_cfg(url=''),
|
||||
xception71=_cfg(url=''),
|
||||
xception41=_cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_xception_41-e6439c97.pth'),
|
||||
xception65=_cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_xception_65-c9ae96e8.pth'),
|
||||
xception71=_cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_xception_71-8eec7df1.pth'),
|
||||
)
|
||||
|
||||
|
||||
@ -216,7 +219,6 @@ def xception65(pretrained=False, **kwargs):
|
||||
return _xception('xception65', pretrained=pretrained, **model_args)
|
||||
|
||||
|
||||
|
||||
@register_model
|
||||
def xception71(pretrained=False, **kwargs):
|
||||
""" Modified Aligned Xception-71
|
||||
|
Loading…
x
Reference in New Issue
Block a user