Cleanup FeatureInfo getters, add TF models sourced Xception41/65/71 weights

This commit is contained in:
Ross Wightman 2020-07-24 17:59:21 -07:00
parent 7ba5a384d3
commit 08016e839d
4 changed files with 40 additions and 34 deletions

View File

@ -441,7 +441,7 @@ class EfficientNetFeatures(nn.Module):
# Register feature extraction hooks with FeatureHooks helper # Register feature extraction hooks with FeatureHooks helper
self.feature_hooks = None self.feature_hooks = None
if feature_location != 'bottleneck': if feature_location != 'bottleneck':
hooks = self.feature_info.get_by_key(keys=('module', 'hook_type')) hooks = self.feature_info.get_dicts(keys=('module', 'hook_type'))
self.feature_hooks = FeatureHooks(hooks, self.named_modules()) self.feature_hooks = FeatureHooks(hooks, self.named_modules())
def forward(self, x) -> List[torch.Tensor]: def forward(self, x) -> List[torch.Tensor]:

View File

@ -8,7 +8,7 @@ Hacked together by Ross Wightman
from collections import OrderedDict, defaultdict from collections import OrderedDict, defaultdict
from copy import deepcopy from copy import deepcopy
from functools import partial from functools import partial
from typing import Dict, List, Tuple, Any from typing import Dict, List, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -30,42 +30,46 @@ class FeatureInfo:
def from_other(self, out_indices: Tuple[int]): def from_other(self, out_indices: Tuple[int]):
return FeatureInfo(deepcopy(self.info), out_indices) return FeatureInfo(deepcopy(self.info), out_indices)
def get(self, key, idx=None):
""" Get value by key at specified index (indices)
if idx == None, returns value for key at each output index
if idx is an integer, return value for that feature module index (ignoring output indices)
if idx is a list/tupple, return value for each module index (ignoring output indices)
"""
if idx is None:
return [self.info[i][key] for i in self.out_indices]
if isinstance(idx, (tuple, list)):
return [self.info[i][key] for i in idx]
else:
return self.info[idx][key]
def get_dicts(self, keys=None, idx=None):
""" return info dicts for specified keys (or all if None) at specified indices (or out_indices if None)
"""
if idx is None:
if keys is None:
return [self.info[i] for i in self.out_indices]
else:
return [{k: self.info[i][k] for k in keys} for i in self.out_indices]
if isinstance(idx, (tuple, list)):
return [self.info[i] if keys is None else {k: self.info[i][k] for k in keys} for i in idx]
else:
return self.info[idx] if keys is None else {k: self.info[idx][k] for k in keys}
def channels(self, idx=None): def channels(self, idx=None):
""" feature channels accessor """ feature channels accessor
if idx == None, returns feature channel count at each output index
if idx is an integer, return feature channel count for that feature module index
""" """
if isinstance(idx, int): return self.get('num_chs', idx)
return self.info[idx]['num_chs']
return [self.info[i]['num_chs'] for i in self.out_indices]
def reduction(self, idx=None): def reduction(self, idx=None):
""" feature reduction (output stride) accessor """ feature reduction (output stride) accessor
if idx == None, returns feature reduction factor at each output index
if idx is an integer, return feature channel count at that feature module index
""" """
if isinstance(idx, int): return self.get('reduction', idx)
return self.info[idx]['reduction']
return [self.info[i]['reduction'] for i in self.out_indices]
def module_name(self, idx=None): def module_name(self, idx=None):
""" feature module name accessor """ feature module name accessor
if idx == None, returns feature module name at each output index
if idx is an integer, return feature module name at that feature module index
""" """
if isinstance(idx, int): return self.get('module', idx)
return self.info[idx]['module']
return [self.info[i]['module'] for i in self.out_indices]
def get_by_key(self, idx=None, keys=None):
""" return info dicts for specified keys (or all if None) at specified idx (or out_indices if None)
"""
if isinstance(idx, int):
return self.info[idx] if keys is None else {k: self.info[idx][k] for k in keys}
if keys is None:
return [self.info[i] for i in self.out_indices]
else:
return [{k: self.info[i][k] for k in keys} for i in self.out_indices]
def __getitem__(self, item): def __getitem__(self, item):
return self.info[item] return self.info[item]
@ -253,11 +257,11 @@ class FeatureHookNet(nn.ModuleDict):
if hasattr(model, 'reset_classifier'): # make sure classifier is removed? if hasattr(model, 'reset_classifier'): # make sure classifier is removed?
model.reset_classifier(0) model.reset_classifier(0)
layers['body'] = model layers['body'] = model
hooks.extend(self.feature_info) hooks.extend(self.feature_info.get_dicts())
else: else:
modules = _module_list(model, flatten_sequential=flatten_sequential) modules = _module_list(model, flatten_sequential=flatten_sequential)
remaining = {f['module']: f['hook_type'] if 'hook_type' in f else default_hook_type remaining = {f['module']: f['hook_type'] if 'hook_type' in f else default_hook_type
for f in self.feature_info} for f in self.feature_info.get_dicts()}
for new_name, old_name, module in modules: for new_name, old_name, module in modules:
layers[new_name] = module layers[new_name] = module
for fn, fm in module.named_modules(prefix=old_name): for fn, fm in module.named_modules(prefix=old_name):

View File

@ -186,7 +186,7 @@ class MobileNetV3Features(nn.Module):
# Register feature extraction hooks with FeatureHooks helper # Register feature extraction hooks with FeatureHooks helper
self.feature_hooks = None self.feature_hooks = None
if feature_location != 'bottleneck': if feature_location != 'bottleneck':
hooks = self.feature_info.get_by_key(keys=('module', 'hook_type')) hooks = self.feature_info.get_dicts(keys=('module', 'hook_type'))
self.feature_hooks = FeatureHooks(hooks, self.named_modules()) self.feature_hooks = FeatureHooks(hooks, self.named_modules())
def forward(self, x) -> List[torch.Tensor]: def forward(self, x) -> List[torch.Tensor]:

View File

@ -31,9 +31,12 @@ def _cfg(url='', **kwargs):
default_cfgs = dict( default_cfgs = dict(
xception41=_cfg(url=''), xception41=_cfg(
xception65=_cfg(url=''), url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_xception_41-e6439c97.pth'),
xception71=_cfg(url=''), xception65=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_xception_65-c9ae96e8.pth'),
xception71=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_xception_71-8eec7df1.pth'),
) )
@ -216,7 +219,6 @@ def xception65(pretrained=False, **kwargs):
return _xception('xception65', pretrained=pretrained, **model_args) return _xception('xception65', pretrained=pretrained, **model_args)
@register_model @register_model
def xception71(pretrained=False, **kwargs): def xception71(pretrained=False, **kwargs):
""" Modified Aligned Xception-71 """ Modified Aligned Xception-71