Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)
Cleanup, refactoring of Feature extraction code, add tests, fix tests, non hook feature extraction working with torchscript
commit 4e61c6a12d
parent 6eec3fb4a4
@@ -106,3 +106,26 @@ def test_model_forward_torchscript(model_name, batch_size):

    assert outputs.shape[0] == batch_size
    assert not torch.isnan(outputs).any(), 'Output included NaNs'


EXCLUDE_FEAT_FILTERS = [
    'hrnet*', '*pruned*',  # hopefully fix at some point
    'legacy*',  # not going to bother
]


@pytest.mark.timeout(120)
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FEAT_FILTERS))
@pytest.mark.parametrize('batch_size', [1])
def test_model_forward_features(model_name, batch_size):
    """Run a single forward pass with each model in feature extraction mode"""
    model = create_model(model_name, pretrained=False, features_only=True)
    model.eval()
    expected_channels = model.feature_info.channels()
    assert len(expected_channels) >= 4  # all models here should have at least 4 feature levels by default, some 5 or 6
    input_size = (3, 128, 128)  # jit compile is already a bit slow and we've tested normal res already...
    outputs = model(torch.randn((batch_size, *input_size)))
    assert len(expected_channels) == len(outputs)
    for e, o in zip(expected_channels, outputs):
        assert e == o.shape[1]
        assert o.shape[0] == batch_size
        assert not torch.isnan(o).any()
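For context, the features_only path exercised by this new test is the user-facing way to get a multi-scale feature extractor. A minimal sketch (model name and input size are illustrative):

import torch
from timm import create_model

# Minimal sketch: wrap a registered backbone as a feature extractor.
model = create_model('resnet18', pretrained=False, features_only=True)
model.eval()
with torch.no_grad():
    outputs = model(torch.randn(1, 3, 224, 224))
# One tensor per feature level; channel dims line up with feature_info.channels().
for o, c in zip(outputs, model.feature_info.channels()):
    assert o.shape[1] == c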
@@ -75,8 +75,14 @@ class FeatureInfo:


class FeatureHooks:
    """ Feature Hook Helper

    def __init__(self, hooks, named_modules, out_as_dict=False, out_map=None, default_hook_type='forward'):
    This module helps with the setup and extraction of hooks for extracting features from
    internal nodes in a model by node name. This works quite well in eager Python but needs
    redesign for torchscript.
    """

    def __init__(self, hooks, named_modules, out_map=None, default_hook_type='forward'):
        # setup feature hooks
        modules = {k: v for k, v in named_modules}
        for i, h in enumerate(hooks):
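The class wraps the standard register_forward_hook / register_forward_pre_hook machinery, keeping a buffer keyed by device (so e.g. DataParallel replicas don't collide). A stripped-down sketch of the same idea in plain PyTorch (names are illustrative):

import torch
import torch.nn as nn
from collections import OrderedDict, defaultdict

# Stash hooked outputs per device, keyed by a hook id, as FeatureHooks does internally.
collected = defaultdict(OrderedDict)

def make_hook(hook_id):
    def hook(module, inputs, output):
        collected[output.device][hook_id] = output
    return hook

net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 16, 3))
net[0].register_forward_hook(make_hook('conv1'))
net[2].register_forward_hook(make_hook('conv2'))
net(torch.randn(1, 3, 32, 32))
feats = list(collected[torch.device('cpu')].values())  # analogous to get_output(device)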
@@ -92,7 +98,6 @@ class FeatureHooks:
            else:
                assert False, "Unsupported hook type"
        self._feature_outputs = defaultdict(OrderedDict)
        self.out_as_dict = out_as_dict

    def _collect_output_hook(self, hook_id, *args):
        x = args[-1]  # tensor we want is last argument, output for fwd, input for fwd_pre
@@ -100,11 +105,8 @@ class FeatureHooks:
            x = x[0]  # unwrap input tuple
        self._feature_outputs[x.device][hook_id] = x

    def get_output(self, device) -> List[torch.tensor]:  # FIXME deal with diff return types for torchscript?
        if self.out_as_dict:
            output = self._feature_outputs[device]
        else:
            output = list(self._feature_outputs[device].values())
    def get_output(self, device) -> Dict[str, torch.tensor]:
        output = self._feature_outputs[device]
        self._feature_outputs[device] = OrderedDict()  # clear after reading
        return output

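Constructing the helper directly looks roughly like this; hook specs are dicts naming a module (plus an optional hook_type), and get_output drains the per-device buffer. Import path and constructor details are assumed from the surrounding diff:

import torch
import torch.nn as nn
from timm.models.features import FeatureHooks  # import path assumed

net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 16, 3))
hooks = FeatureHooks([dict(module='0'), dict(module='2')], net.named_modules())
net(torch.randn(1, 3, 32, 32))
out = hooks.get_output(torch.device('cpu'))  # OrderedDict of hooked tensors, cleared after reading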
@@ -123,160 +125,8 @@ def _module_list(module, flatten_sequential=False):
    return ml


class LayerGetterHooks(nn.ModuleDict):
    """ LayerGetterHooks
    TODO
    """

    def __init__(self, model, feature_info, flatten_sequential=False, out_as_dict=False, out_map=None,
                 default_hook_type='forward'):
        modules = _module_list(model, flatten_sequential=flatten_sequential)
        remaining = {f['module']: f['hook_type'] if 'hook_type' in f else default_hook_type for f in feature_info}
        layers = OrderedDict()
        hooks = []
        for new_name, old_name, module in modules:
            layers[new_name] = module
            for fn, fm in module.named_modules(prefix=old_name):
                if fn in remaining:
                    hooks.append(dict(module=fn, hook_type=remaining[fn]))
                    del remaining[fn]
                if not remaining:
                    break
        assert not remaining, f'Return layers ({remaining}) are not present in model'
        super(LayerGetterHooks, self).__init__(layers)
        self.hooks = FeatureHooks(hooks, model.named_modules(), out_as_dict=out_as_dict, out_map=out_map)

    def forward(self, x) -> Dict[Any, torch.Tensor]:
        for name, module in self.items():
            x = module(x)
        return self.hooks.get_output(x.device)

class LayerGetterDict(nn.ModuleDict):
    """
    Module wrapper that returns intermediate layers from a model as a dictionary

    Originally based on concepts from IntermediateLayerGetter at
    https://github.com/pytorch/vision/blob/d88d8961ae51507d0cb680329d985b1488b1b76b/torchvision/models/_utils.py

    It has a strong assumption that the modules have been registered into the model in the same
    order as they are used. This means that one should **not** reuse the same nn.Module twice
    in the forward if you want this to work.

    Additionally, it is only able to query submodules that are directly assigned to the model
    class (`model.feature1`) or at most one Sequential container deep (`model.features.1`, so
    long as `features` is a sequential container assigned to the model).

    All Sequential containers that are directly assigned to the original model will have their
    modules assigned to this module with the name `model.features.1` being changed to `model.features_1`

    Arguments:
        model (nn.Module): model on which we will extract the features
        return_layers (Dict[name, new_name]): a dict containing the names
            of the modules for which the activations will be returned as
            the key of the dict, and the value of the dict is the name
            of the returned activation (which the user can specify).
        concat (bool): whether to concatenate intermediate features that are lists or tuples
            vs select element [0]
        flatten_sequential (bool): whether to flatten sequential modules assigned to model
    """

    def __init__(self, model, return_layers, concat=False, flatten_sequential=False):
        self.return_layers = {}
        self.concat = concat
        modules = _module_list(model, flatten_sequential=flatten_sequential)
        remaining = set(return_layers.keys())
        layers = OrderedDict()
        for new_name, old_name, module in modules:
            layers[new_name] = module
            if old_name in remaining:
                self.return_layers[new_name] = return_layers[old_name]
                remaining.remove(old_name)
            if not remaining:
                break
        assert not remaining and len(self.return_layers) == len(return_layers), \
            f'Return layers ({remaining}) are not present in model'
        super(LayerGetterDict, self).__init__(layers)

    def forward(self, x) -> Dict[Any, torch.Tensor]:
        out = OrderedDict()
        for name, module in self.items():
            x = module(x)
            if name in self.return_layers:
                out_id = self.return_layers[name]
                if isinstance(x, (tuple, list)):
                    # If model tap is a tuple or list, concat or select first element
                    # FIXME this may need to be more generic / flexible for some nets
                    out[out_id] = torch.cat(x, 1) if self.concat else x[0]
                else:
                    out[out_id] = x
        return out

class LayerGetterList(nn.Sequential):
    """
    Module wrapper that returns intermediate layers from a model as a list

    Originally based on concepts from IntermediateLayerGetter at
    https://github.com/pytorch/vision/blob/d88d8961ae51507d0cb680329d985b1488b1b76b/torchvision/models/_utils.py

    It has a strong assumption that the modules have been registered into the model in the same
    order as they are used. This means that one should **not** reuse the same nn.Module twice
    in the forward if you want this to work.

    Additionally, it is only able to query submodules that are directly assigned to the model
    class (`model.feature1`) or at most one Sequential container deep (`model.features.1`) so
    long as `features` is a sequential container assigned to the model and flatten_sequential=True.

    All Sequential containers that are directly assigned to the original model will have their
    modules assigned to this module with the name `model.features.1` being changed to `model.features_1`

    Arguments:
        model (nn.Module): model on which we will extract the features
        return_layers (Dict[name, new_name]): a dict containing the names
            of the modules for which the activations will be returned as
            the key of the dict, and the value of the dict is the name
            of the returned activation (which the user can specify).
        concat (bool): whether to concatenate intermediate features that are lists or tuples
            vs select element [0]
        flatten_sequential (bool): whether to flatten sequential modules assigned to model
    """

    def __init__(self, model, return_layers, concat=False, flatten_sequential=False):
        super(LayerGetterList, self).__init__()
        self.return_layers = {}
        self.concat = concat
        modules = _module_list(model, flatten_sequential=flatten_sequential)
        remaining = set(return_layers.keys())
        for new_name, orig_name, module in modules:
            self.add_module(new_name, module)
            if orig_name in remaining:
                self.return_layers[new_name] = return_layers[orig_name]
                remaining.remove(orig_name)
            if not remaining:
                break
        assert not remaining and len(self.return_layers) == len(return_layers), \
            f'Return layers ({remaining}) are not present in model'

    def forward(self, x) -> List[torch.Tensor]:
        out = []
        for name, module in self.named_children():
            x = module(x)
            if name in self.return_layers:
                if isinstance(x, (tuple, list)):
                    # If model tap is a tuple or list, concat or select first element
                    # FIXME this may need to be more generic / flexible for some nets
                    out.append(torch.cat(x, 1) if self.concat else x[0])
                else:
                    out.append(x)
        return out

def _resolve_feature_info(net, out_indices, feature_info=None):
    if feature_info is None:
        feature_info = getattr(net, 'feature_info')
def _get_feature_info(net, out_indices):
    feature_info = getattr(net, 'feature_info')
    if isinstance(feature_info, FeatureInfo):
        return feature_info.from_other(out_indices)
    elif isinstance(feature_info, (list, tuple)):
@@ -293,51 +143,135 @@ def _get_return_layers(feature_info, out_map):
    return return_layers


class FeatureNet(nn.Module):
    """ FeatureNet
class FeatureDictNet(nn.ModuleDict):
    """ Feature extractor with OrderedDict return

    Wrap a model and extract features as specified by the out indices, the network
    is partially re-built from contained modules using the LayerGetters.
    Wrap a model and extract features as specified by the out indices, the network is
    partially re-built from contained modules.

    Please read the docstrings of the LayerGetter classes, they will not work on all models.
    There is a strong assumption that the modules have been registered into the model in the same
    order as they are used. There should be no reuse of the same nn.Module more than once, including
    trivial modules like `self.relu = nn.ReLU`.

    Only submodules that are directly assigned to the model class (`model.feature1`) or at most
    one Sequential container deep (`model.features.1`, with flatten_sequential=True) can be captured.
    All Sequential containers that are directly assigned to the original model will have their
    modules assigned to this module with the name `model.features.1` being changed to `model.features_1`

    Arguments:
        model (nn.Module): model from which we will extract the features
        out_indices (tuple[int]): model output indices to extract features for
        out_map (sequence): list or tuple specifying desired return id for each out index,
            otherwise str(index) is used
        feature_concat (bool): whether to concatenate intermediate features that are lists or tuples
            vs select element [0]
        flatten_sequential (bool): whether to flatten sequential modules assigned to model
    """
    def __init__(
            self, net,
            out_indices=(0, 1, 2, 3, 4), out_map=None, out_as_dict=False, use_hooks=False,
            feature_info=None, feature_concat=False, flatten_sequential=False):
        super(FeatureNet, self).__init__()
        self.feature_info = _resolve_feature_info(net, out_indices, feature_info)
        if use_hooks:
            self.body = LayerGetterHooks(net, self.feature_info, out_as_dict=out_as_dict, out_map=out_map)
        else:
            return_layers = _get_return_layers(self.feature_info, out_map)
            lg_args = dict(return_layers=return_layers, concat=feature_concat, flatten_sequential=flatten_sequential)
            self.body = LayerGetterDict(net, **lg_args) if out_as_dict else LayerGetterList(net, **lg_args)
            self, model,
            out_indices=(0, 1, 2, 3, 4), out_map=None, feature_concat=False, flatten_sequential=False):
        super(FeatureDictNet, self).__init__()
        self.feature_info = _get_feature_info(model, out_indices)
        self.concat = feature_concat
        self.return_layers = {}
        return_layers = _get_return_layers(self.feature_info, out_map)
        modules = _module_list(model, flatten_sequential=flatten_sequential)
        remaining = set(return_layers.keys())
        layers = OrderedDict()
        for new_name, old_name, module in modules:
            layers[new_name] = module
            if old_name in remaining:
                # return id has to be consistently str type for torchscript
                self.return_layers[new_name] = str(return_layers[old_name])
                remaining.remove(old_name)
            if not remaining:
                break
        assert not remaining and len(self.return_layers) == len(return_layers), \
            f'Return layers ({remaining}) are not present in model'
        self.update(layers)

    def forward(self, x):
        output = self.body(x)
        return output
    def _collect(self, x) -> (Dict[str, torch.Tensor]):
        out = OrderedDict()
        for name, module in self.items():
            x = module(x)
            if name in self.return_layers:
                out_id = self.return_layers[name]
                if isinstance(x, (tuple, list)):
                    # If model tap is a tuple or list, concat or select first element
                    # FIXME this may need to be more generic / flexible for some nets
                    out[out_id] = torch.cat(x, 1) if self.concat else x[0]
                else:
                    out[out_id] = x
        return out

    def forward(self, x) -> Dict[str, torch.Tensor]:
        return self._collect(x)

class FeatureHookNet(nn.Module):
class FeatureListNet(FeatureDictNet):
    """ Feature extractor with list return

    See docstring for FeatureDictNet above, this class exists only to appease Torchscript typing constraints.
    In eager Python we could have returned List[Tensor] vs Dict[id, Tensor] based on a member bool.
    """
    def __init__(
            self, model,
            out_indices=(0, 1, 2, 3, 4), out_map=None, feature_concat=False, flatten_sequential=False):
        super(FeatureListNet, self).__init__(
            model, out_indices=out_indices, out_map=out_map, feature_concat=feature_concat,
            flatten_sequential=flatten_sequential)

    def forward(self, x) -> (List[torch.Tensor]):
        return list(self._collect(x).values())

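With the wrapper split into FeatureDictNet/FeatureListNet, the non-hook path has a fixed return type, which is what lets it pass torch.jit.script (per the commit message). A hedged sketch, assuming a timm build that includes this change:

import torch
from timm import create_model

# features_only without hooks builds a FeatureListNet, which now scripts cleanly.
model = create_model('resnet18', pretrained=False, features_only=True).eval()
scripted = torch.jit.script(model)
feats = scripted(torch.randn(1, 3, 224, 224))  # List[Tensor], one map per feature level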
class FeatureHookNet(nn.ModuleDict):
    """ FeatureHookNet

    Wrap a model and extract features specified by the out indices.
    Wrap a model and extract features specified by the out indices using forward/forward-pre hooks.

    Features are extracted via hooks without modifying the underlying network in any way. If only
    part of the model is used it is up to the caller to remove unneeded layers as this wrapper
    does not rewrite and remove unused top-level modules like FeatureNet with LayerGetter.
    If `no_rewrite` is True, features are extracted via hooks without modifying the underlying
    network in any way.

    If `no_rewrite` is False, the model will be re-written as in the
    FeatureList/FeatureDict case by folding first to second (Sequential only) level modules into this one.

    FIXME this does not currently work with Torchscript, see FeatureHooks class
    """
    def __init__(
            self, net,
            out_indices=(0, 1, 2, 3, 4), out_as_dict=False, out_map=None,
            feature_info=None, feature_concat=False):
            self, model,
            out_indices=(0, 1, 2, 3, 4), out_map=None, out_as_dict=False, no_rewrite=False,
            feature_concat=False, flatten_sequential=False, default_hook_type='forward'):
        super(FeatureHookNet, self).__init__()
        self.feature_info = _resolve_feature_info(net, out_indices, feature_info)
        self.body = net
        self.hooks = FeatureHooks(
            self.feature_info, self.body.named_modules(), out_as_dict=out_as_dict, out_map=out_map)
        assert not torch.jit.is_scripting()
        self.feature_info = _get_feature_info(model, out_indices)
        self.out_as_dict = out_as_dict
        layers = OrderedDict()
        hooks = []
        if no_rewrite:
            assert not flatten_sequential
            if hasattr(model, 'reset_classifier'):  # make sure classifier is removed?
                model.reset_classifier(0)
            layers['body'] = model
            hooks.extend(self.feature_info)
        else:
            modules = _module_list(model, flatten_sequential=flatten_sequential)
            remaining = {f['module']: f['hook_type'] if 'hook_type' in f else default_hook_type
                         for f in self.feature_info}
            for new_name, old_name, module in modules:
                layers[new_name] = module
                for fn, fm in module.named_modules(prefix=old_name):
                    if fn in remaining:
                        hooks.append(dict(module=fn, hook_type=remaining[fn]))
                        del remaining[fn]
                    if not remaining:
                        break
            assert not remaining, f'Return layers ({remaining}) are not present in model'
        self.update(layers)
        self.hooks = FeatureHooks(hooks, model.named_modules(), out_map=out_map)

    def forward(self, x):
        self.body(x)
        return self.hooks.get_output(x.device)
        for name, module in self.items():
            x = module(x)
        out = self.hooks.get_output(x.device)
        return out if self.out_as_dict else list(out.values())

@@ -252,7 +252,7 @@ class Xception65(nn.Module):
def _create_gluon_xception(variant, pretrained=False, **kwargs):
    return build_model_with_cfg(
        Xception65, variant, pretrained, default_cfg=default_cfgs[variant],
        feature_cfg=dict(use_hooks=True), **kwargs)
        feature_cfg=dict(feature_cls='hook'), **kwargs)


@register_model
@@ -8,7 +8,7 @@ import torch
import torch.nn as nn
import torch.utils.model_zoo as model_zoo

from .features import FeatureNet, FeatureHookNet
from .features import FeatureListNet, FeatureDictNet, FeatureHookNet
from .layers import Conv2dSame

@@ -234,15 +234,15 @@ def build_model_with_cfg(
            filter_fn=pretrained_filter_fn, strict=pretrained_strict)

    if features:
        feature_cls = feature_cfg.pop('feature_cls', FeatureNet)
        if isinstance(feature_cls, str):
            feature_cls = feature_cls.lower()
            if feature_cls == 'hook' or feature_cls == 'featurehooknet':
                feature_cls = FeatureHookNet
            else:
                assert False, f'Unknown feature class {feature_cls}'
        if feature_cls == FeatureHookNet and hasattr(model, 'reset_classifier'):
            model.reset_classifier(0)
        feature_cls = FeatureListNet
        if 'feature_cls' in feature_cfg:
            feature_cls = feature_cfg.pop('feature_cls')
            if isinstance(feature_cls, str):
                feature_cls = feature_cls.lower()
                if 'hook' in feature_cls:
                    feature_cls = FeatureHookNet
                else:
                    assert False, f'Unknown feature class {feature_cls}'
        model = feature_cls(model, **feature_cfg)

    return model
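From the user's side nothing changes: features_only still selects FeatureListNet by default, and out_indices narrows which levels are returned. A hedged sketch (assumes out_indices is forwarded through create_model kwargs into feature_cfg):

from timm import create_model

model = create_model('resnet18', pretrained=False, features_only=True, out_indices=(2, 3, 4))
print(model.feature_info.channels())  # channel counts for the three selected levels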
@@ -211,7 +211,8 @@ class MobileNetV3Features(nn.Module):
            return features
        else:
            self.blocks(x)
            return self.feature_hooks.get_output(x.device)
            out = self.feature_hooks.get_output(x.device)
            return list(out.values())


def _create_mnv3(model_kwargs, variant, pretrained=False):
@@ -220,6 +221,7 @@ def _create_mnv3(model_kwargs, variant, pretrained=False):
        model_kwargs.pop('num_classes', 0)
        model_kwargs.pop('num_features', 0)
        model_kwargs.pop('head_conv', None)
        model_kwargs.pop('head_bias', None)
        model_cls = MobileNetV3Features
    else:
        load_strict = True
@@ -554,7 +554,7 @@ class NASNetALarge(nn.Module):
def _create_nasnet(variant, pretrained=False, **kwargs):
    return build_model_with_cfg(
        NASNetALarge, variant, pretrained, default_cfg=default_cfgs[variant],
        feature_cfg=dict(feature_cls='hook'),  # not possible to re-write this model, must use FeatureHookNet
        feature_cfg=dict(feature_cls='hook', no_rewrite=True),  # not possible to re-write this model
        **kwargs)

@@ -337,7 +337,7 @@ class PNASNet5Large(nn.Module):
def _create_pnasnet(variant, pretrained=False, **kwargs):
    return build_model_with_cfg(
        PNASNet5Large, variant, pretrained, default_cfg=default_cfgs[variant],
        feature_cfg=dict(feature_cls='hook'),  # not possible to re-write this model, must use FeatureHookNet
        feature_cfg=dict(feature_cls='hook', no_rewrite=True),  # not possible to re-write this model
        **kwargs)

@@ -74,6 +74,29 @@ class SequentialList(nn.Sequential):
        return x


class SelectSeq(nn.Module):
    def __init__(self, mode='index', index=0):
        super(SelectSeq, self).__init__()
        self.mode = mode
        self.index = index

    @torch.jit._overload_method  # noqa: F811
    def forward(self, x):
        # type: (List[torch.Tensor]) -> (torch.Tensor)
        pass

    @torch.jit._overload_method  # noqa: F811
    def forward(self, x):
        # type: (Tuple[torch.Tensor]) -> (torch.Tensor)
        pass

    def forward(self, x) -> torch.Tensor:
        if self.mode == 'index':
            return x[self.index]
        else:
            return torch.cat(x, dim=1)


def conv_bn(in_chs, out_chs, k=3, stride=1, padding=None, dilation=1):
    if padding is None:
        padding = ((stride - 1) + dilation * (k - 1)) // 2
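SelectSeq exists so the List[Tensor] coming out of SequentialList can be reduced to a single Tensor through a module call that Torchscript can type-check, instead of a bare list index in forward_features. Rough usage (import path assumed; any mode other than 'index' falls through to torch.cat on dim 1):

import torch
from timm.models.selecsls import SelectSeq  # class defined above; import path assumed

seq_out = [torch.randn(1, 32, 28, 28), torch.randn(1, 64, 28, 28)]
print(SelectSeq(mode='index', index=0)(seq_out).shape)  # torch.Size([1, 32, 28, 28])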
@@ -137,8 +160,10 @@ class SelecSLS(nn.Module):

        self.stem = conv_bn(in_chans, 32, stride=2)
        self.features = SequentialList(*[cfg['block'](*block_args) for block_args in cfg['features']])
        self.from_seq = SelectSeq()  # from List[tensor] -> Tensor in module compatible way
        self.head = nn.Sequential(*[conv_bn(*conv_args) for conv_args in cfg['head']])
        self.num_features = cfg['num_features']
        self.feature_info = cfg['feature_info']

        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
        self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
@@ -165,7 +190,7 @@
    def forward_features(self, x):
        x = self.stem(x)
        x = self.features(x)
        x = self.head(x[0])
        x = self.head(self.from_seq(x))
        return x

    def forward(self, x):
@@ -297,6 +322,7 @@ def _create_selecsls(variant, pretrained, model_kwargs):
        ])
    else:
        raise ValueError('Invalid net configuration ' + variant + ' !!!')
    cfg['feature_info'] = feature_info

    # this model can do 6 feature levels by default, unlike most others, leave as 0-4 to avoid surprises?
    return build_model_with_cfg(
@@ -160,6 +160,9 @@ class Bottleneck(nn.Module):
                conv2d_iabn(planes, planes, kernel_size=3, stride=1, act_layer=act_layer, act_param=1e-3),
                aa_layer(channels=planes, filt_size=3, stride=2))

        reduce_layer_planes = max(planes * self.expansion // 8, 64)
        self.se = FastSEModule(planes, reduce_layer_planes) if use_se else None

        self.conv3 = conv2d_iabn(
            planes, planes * self.expansion, kernel_size=1, stride=1, act_layer="identity")

@@ -167,9 +170,6 @@
        self.downsample = downsample
        self.stride = stride

        reduce_layer_planes = max(planes * self.expansion // 8, 64)
        self.se = FastSEModule(planes, reduce_layer_planes) if use_se else None

    def forward(self, x):
        if self.downsample is not None:
            residual = self.downsample(x)
@@ -225,8 +225,8 @@ class TResNet(nn.Module):
            dict(num_chs=self.planes, reduction=2, module=''),  # Not with S2D?
            dict(num_chs=self.planes, reduction=4, module='body.layer1'),
            dict(num_chs=self.planes * 2, reduction=8, module='body.layer2'),
            dict(num_chs=self.planes * 4, reduction=16, module='body.layer3'),
            dict(num_chs=self.planes * 8, reduction=32, module='body.layer4'),
            dict(num_chs=self.planes * 4 * Bottleneck.expansion, reduction=16, module='body.layer3'),
            dict(num_chs=self.planes * 8 * Bottleneck.expansion, reduction=32, module='body.layer4'),
        ]

        # head
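Each feature_info entry follows the same schema used throughout the codebase: the channel count of the tap, its cumulative stride reduction, and the (possibly dotted) module name that produces it. Toy values for a generic 4-stage backbone:

# Toy declaration; real models fill these in as above (this commit fixes the
# layer3/layer4 channel counts to include the block expansion factor).
feature_info = [
    dict(num_chs=64, reduction=4, module='layer1'),
    dict(num_chs=128, reduction=8, module='layer2'),
    dict(num_chs=256, reduction=16, module='layer3'),
    dict(num_chs=512, reduction=32, module='layer4'),
]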
@@ -228,7 +228,7 @@ class Xception(nn.Module):
def _xception(variant, pretrained=False, **kwargs):
    return build_model_with_cfg(
        Xception, variant, pretrained, default_cfg=default_cfgs[variant],
        feature_cfg=dict(use_hooks=True), **kwargs)
        feature_cfg=dict(feature_cls='hook'), **kwargs)


@register_model
@@ -174,7 +174,7 @@ class XceptionAligned(nn.Module):
def _xception(variant, pretrained=False, **kwargs):
    return build_model_with_cfg(
        XceptionAligned, variant, pretrained, default_cfg=default_cfgs[variant],
        feature_cfg=dict(flatten_sequential=True, use_hooks=True), **kwargs)
        feature_cfg=dict(flatten_sequential=True, feature_cls='hook'), **kwargs)


@register_model
@@ -100,6 +100,8 @@ def validate(args):
    # might as well try to validate something
    args.pretrained = args.pretrained or not args.checkpoint
    args.prefetcher = not args.no_prefetcher
    if args.legacy_jit:
        set_jit_legacy()

    # create model
    model = create_model(
|
||||
model, test_time_pool = apply_test_time_pool(model, data_config, args)
|
||||
|
||||
if args.torchscript:
|
||||
if args.legacy_jit:
|
||||
set_jit_legacy()
|
||||
torch.jit.optimized_execution(True)
|
||||
model = torch.jit.script(model)
|
||||
|
||||
|