Cleanup FeatureInfo getters, add TF models sourced Xception41/65/71 weights

This commit is contained in:
Ross Wightman 2020-07-24 17:59:21 -07:00
parent 7ba5a384d3
commit 08016e839d
4 changed files with 40 additions and 34 deletions

View File

@ -441,7 +441,7 @@ class EfficientNetFeatures(nn.Module):
# Register feature extraction hooks with FeatureHooks helper # Register feature extraction hooks with FeatureHooks helper
self.feature_hooks = None self.feature_hooks = None
if feature_location != 'bottleneck': if feature_location != 'bottleneck':
hooks = self.feature_info.get_by_key(keys=('module', 'hook_type')) hooks = self.feature_info.get_dicts(keys=('module', 'hook_type'))
self.feature_hooks = FeatureHooks(hooks, self.named_modules()) self.feature_hooks = FeatureHooks(hooks, self.named_modules())
def forward(self, x) -> List[torch.Tensor]: def forward(self, x) -> List[torch.Tensor]:

View File

@ -8,7 +8,7 @@ Hacked together by Ross Wightman
from collections import OrderedDict, defaultdict from collections import OrderedDict, defaultdict
from copy import deepcopy from copy import deepcopy
from functools import partial from functools import partial
from typing import Dict, List, Tuple, Any from typing import Dict, List, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -30,42 +30,46 @@ class FeatureInfo:
def from_other(self, out_indices: Tuple[int]): def from_other(self, out_indices: Tuple[int]):
return FeatureInfo(deepcopy(self.info), out_indices) return FeatureInfo(deepcopy(self.info), out_indices)
def channels(self, idx=None): def get(self, key, idx=None):
""" feature channels accessor """ Get value by key at specified index (indices)
if idx == None, returns feature channel count at each output index if idx == None, returns value for key at each output index
if idx is an integer, return feature channel count for that feature module index if idx is an integer, return value for that feature module index (ignoring output indices)
if idx is a list/tupple, return value for each module index (ignoring output indices)
""" """
if isinstance(idx, int): if idx is None:
return self.info[idx]['num_chs'] return [self.info[i][key] for i in self.out_indices]
return [self.info[i]['num_chs'] for i in self.out_indices] if isinstance(idx, (tuple, list)):
return [self.info[i][key] for i in idx]
else:
return self.info[idx][key]
def reduction(self, idx=None): def get_dicts(self, keys=None, idx=None):
""" feature reduction (output stride) accessor """ return info dicts for specified keys (or all if None) at specified indices (or out_indices if None)
if idx == None, returns feature reduction factor at each output index
if idx is an integer, return feature channel count at that feature module index
""" """
if isinstance(idx, int): if idx is None:
return self.info[idx]['reduction']
return [self.info[i]['reduction'] for i in self.out_indices]
def module_name(self, idx=None):
""" feature module name accessor
if idx == None, returns feature module name at each output index
if idx is an integer, return feature module name at that feature module index
"""
if isinstance(idx, int):
return self.info[idx]['module']
return [self.info[i]['module'] for i in self.out_indices]
def get_by_key(self, idx=None, keys=None):
""" return info dicts for specified keys (or all if None) at specified idx (or out_indices if None)
"""
if isinstance(idx, int):
return self.info[idx] if keys is None else {k: self.info[idx][k] for k in keys}
if keys is None: if keys is None:
return [self.info[i] for i in self.out_indices] return [self.info[i] for i in self.out_indices]
else: else:
return [{k: self.info[i][k] for k in keys} for i in self.out_indices] return [{k: self.info[i][k] for k in keys} for i in self.out_indices]
if isinstance(idx, (tuple, list)):
return [self.info[i] if keys is None else {k: self.info[i][k] for k in keys} for i in idx]
else:
return self.info[idx] if keys is None else {k: self.info[idx][k] for k in keys}
def channels(self, idx=None):
""" feature channels accessor
"""
return self.get('num_chs', idx)
def reduction(self, idx=None):
""" feature reduction (output stride) accessor
"""
return self.get('reduction', idx)
def module_name(self, idx=None):
""" feature module name accessor
"""
return self.get('module', idx)
def __getitem__(self, item): def __getitem__(self, item):
return self.info[item] return self.info[item]
@ -253,11 +257,11 @@ class FeatureHookNet(nn.ModuleDict):
if hasattr(model, 'reset_classifier'): # make sure classifier is removed? if hasattr(model, 'reset_classifier'): # make sure classifier is removed?
model.reset_classifier(0) model.reset_classifier(0)
layers['body'] = model layers['body'] = model
hooks.extend(self.feature_info) hooks.extend(self.feature_info.get_dicts())
else: else:
modules = _module_list(model, flatten_sequential=flatten_sequential) modules = _module_list(model, flatten_sequential=flatten_sequential)
remaining = {f['module']: f['hook_type'] if 'hook_type' in f else default_hook_type remaining = {f['module']: f['hook_type'] if 'hook_type' in f else default_hook_type
for f in self.feature_info} for f in self.feature_info.get_dicts()}
for new_name, old_name, module in modules: for new_name, old_name, module in modules:
layers[new_name] = module layers[new_name] = module
for fn, fm in module.named_modules(prefix=old_name): for fn, fm in module.named_modules(prefix=old_name):

View File

@ -186,7 +186,7 @@ class MobileNetV3Features(nn.Module):
# Register feature extraction hooks with FeatureHooks helper # Register feature extraction hooks with FeatureHooks helper
self.feature_hooks = None self.feature_hooks = None
if feature_location != 'bottleneck': if feature_location != 'bottleneck':
hooks = self.feature_info.get_by_key(keys=('module', 'hook_type')) hooks = self.feature_info.get_dicts(keys=('module', 'hook_type'))
self.feature_hooks = FeatureHooks(hooks, self.named_modules()) self.feature_hooks = FeatureHooks(hooks, self.named_modules())
def forward(self, x) -> List[torch.Tensor]: def forward(self, x) -> List[torch.Tensor]:

View File

@ -31,9 +31,12 @@ def _cfg(url='', **kwargs):
default_cfgs = dict( default_cfgs = dict(
xception41=_cfg(url=''), xception41=_cfg(
xception65=_cfg(url=''), url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_xception_41-e6439c97.pth'),
xception71=_cfg(url=''), xception65=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_xception_65-c9ae96e8.pth'),
xception71=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_xception_71-8eec7df1.pth'),
) )
@ -216,7 +219,6 @@ def xception65(pretrained=False, **kwargs):
return _xception('xception65', pretrained=pretrained, **model_args) return _xception('xception65', pretrained=pretrained, **model_args)
@register_model @register_model
def xception71(pretrained=False, **kwargs): def xception71(pretrained=False, **kwargs):
""" Modified Aligned Xception-71 """ Modified Aligned Xception-71